author | compilade <git@compilade.net> | 2024-05-13 14:10:51 -0400
---|---|---
committer | GitHub <noreply@github.com> | 2024-05-13 14:10:51 -0400
commit | ee52225067622babc277371511b8124884e1c797 |
tree | 8150564487416fc038952c8b85f1462f3b1c98cf /convert-hf-to-gguf.py |
parent | 614d3b914e1c3e02596f869649eb4f1d3b68614d |
convert-hf : support direct Q8_0 conversion (#7234)
* convert-hf : support q8_0 conversion
* convert-hf : add missing ftype
This was messing with the checksums otherwise.
* convert-hf : add missing ftype to Baichuan and Xverse
I didn't notice these on my first pass.
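For background: Q8_0 is ggml's simplest quantization format — each block of 32 weights is stored as one float16 scale followed by 32 signed 8-bit quants, 34 bytes per block. The sketch below is an illustrative numpy reference of that packing, not the actual `gguf.quantize_q8_0` implementation; the function name is hypothetical, and numpy's half-to-even rounding may differ from ggml's rounding on exact ties.

```python
import numpy as np

QK8_0 = 32  # elements per Q8_0 block

def quantize_q8_0_ref(x: np.ndarray) -> np.ndarray:
    """Illustrative Q8_0 packing: per 32-element block, one fp16 scale
    followed by 32 int8 quants (34 bytes per block)."""
    blocks = x.astype(np.float32).reshape(-1, QK8_0)
    d = np.abs(blocks).max(axis=1, keepdims=True) / 127.0   # per-block scale
    scaled = np.divide(blocks, d, out=np.zeros_like(blocks), where=d != 0)
    q = np.round(scaled).astype(np.int8)
    scale_bytes = d.astype(np.float16).view(np.uint8)       # (n_blocks, 2)
    return np.concatenate([scale_bytes, q.view(np.uint8)], axis=1).reshape(-1)
```

With this change, a model can be converted straight to Q8_0 via `python convert-hf-to-gguf.py --outtype q8_0 <model-dir>`, skipping the intermediate f16/f32 GGUF and separate quantize step that were previously required.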
Diffstat (limited to 'convert-hf-to-gguf.py')
-rwxr-xr-x | convert-hf-to-gguf.py | 72
1 file changed, 28 insertions, 44 deletions
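Besides the Q8_0 path, the diff below also simplifies `LazyTorchTensor`: a numpy `as_strided` trick and two torch helpers are folded into a single `meta_with_dtype_and_shape` built on PyTorch's meta device, which carries dtype and shape without backing storage. A minimal illustration of that device's behavior (sizes here are arbitrary):

```python
import torch

# A tensor on the "meta" device records shape and dtype but allocates no data,
# which is what the lazy-conversion machinery needs to propagate metadata
# without touching the actual weights.
t = torch.empty(size=(4096, 4096), dtype=torch.float16, device="meta")
print(t.shape, t.dtype, t.is_meta)  # torch.Size([4096, 4096]) torch.float16 True
```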
```diff
diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py
index d6e5dece..cd875fa4 100755
--- a/convert-hf-to-gguf.py
+++ b/convert-hf-to-gguf.py
@@ -240,23 +240,6 @@ class Model:
         return False
 
     def write_tensors(self):
-        # same as ggml_compute_fp32_to_bf16 in ggml-impl.h
-        def np_fp32_to_bf16(n: np.ndarray):
-            # force nan to quiet
-            n = np.where((n & 0x7fffffff) > 0x7f800000, (n & 0xffff0000) | (64 << 16), n)
-            # flush subnormals to zero
-            n = np.where((n & 0x7f800000) == 0, n & 0x80000000, n)
-            # round to nearest even
-            n = (n + (0x7fff + ((n >> 16) & 1))) >> 16
-            return n.astype(np.int16)
-
-        # Doing this row-wise is much, much faster than element-wise, hence the signature
-        v_fp32_to_bf16 = np.vectorize(np_fp32_to_bf16, otypes=[np.int16], signature="(n)->(n)")
-        if self.lazy:
-            # TODO: find a way to implicitly wrap np.vectorize functions
-            # NOTE: the type is changed to reflect otypes passed to np.vectorize above
-            v_fp32_to_bf16 = gguf.LazyNumpyTensor._wrap_fn(v_fp32_to_bf16, meta_noop=np.int16)
-
         max_name_len = max(len(s) for _, s in self.tensor_map.mapping.values()) + len(".weight,")
 
         for name, data_torch in self.get_tensors():
@@ -309,27 +292,31 @@ class Model:
                 ))
 
                 if self.ftype != gguf.LlamaFileType.ALL_F32 and extra_f16 and not extra_f32:
-                    if self.ftype == gguf.LlamaFileType.MOSTLY_F16:
+                    if self.ftype == gguf.LlamaFileType.MOSTLY_BF16:
+                        data = gguf.quantize_bf16(data)
+                        assert data.dtype == np.int16
+                        data_qtype = gguf.GGMLQuantizationType.BF16
+
+                    elif self.ftype == gguf.LlamaFileType.MOSTLY_Q8_0 and gguf.can_quantize_to_q8_0(data):
+                        data = gguf.quantize_q8_0(data)
+                        assert data.dtype == np.uint8
+                        data_qtype = gguf.GGMLQuantizationType.Q8_0
+
+                    else:  # default to float16 for quantized tensors
                         if data_dtype != np.float16:
                             data = data.astype(np.float16)
                         data_qtype = gguf.GGMLQuantizationType.F16
 
-                    elif self.ftype == gguf.LlamaFileType.MOSTLY_BF16:
-                        if data_dtype != np.float32:
-                            data = data.astype(np.float32)
-                        data = v_fp32_to_bf16(data.view(np.int32))
-                        assert data.dtype == np.int16
-                        data_qtype = gguf.GGMLQuantizationType.BF16
-
-                else:  # by default, convert to float32
+                if data_qtype is None:  # by default, convert to float32
                     if data_dtype != np.float32:
                         data = data.astype(np.float32)
                     data_qtype = gguf.GGMLQuantizationType.F32
 
-                assert data_qtype is not None
-
+                block_size, type_size = gguf.GGML_QUANT_SIZES[data_qtype]
                 # reverse shape to make it similar to the internal ggml dimension order
-                shape_str = f"{{{', '.join(str(n) for n in reversed(data.shape))}}}"
+                shape_str = f"""{{{', '.join(str(n) for n in reversed(
+                    (*data.shape[:-1], data.shape[-1] * data.dtype.itemsize // type_size * block_size))
+                )}}}"""
 
                 # n_dims is implicit in the shape
                 logger.info(f"{f'%-{max_name_len}s' % f'{new_name},'} {old_dtype} --> {data_qtype.name}, shape = {shape_str}")
@@ -859,6 +846,7 @@ class BaichuanModel(Model):
         self.gguf_writer.add_head_count(head_count)
         self.gguf_writer.add_head_count_kv(head_count_kv)
         self.gguf_writer.add_layer_norm_rms_eps(self.hparams["rms_norm_eps"])
+        self.gguf_writer.add_file_type(self.ftype)
 
         if self.hparams.get("rope_scaling") is not None and "factor" in self.hparams["rope_scaling"]:
             if self.hparams["rope_scaling"].get("type") == "linear":
@@ -981,6 +969,7 @@ class XverseModel(Model):
         self.gguf_writer.add_head_count(head_count)
         self.gguf_writer.add_head_count_kv(head_count_kv)
         self.gguf_writer.add_layer_norm_rms_eps(self.hparams["rms_norm_eps"])
+        self.gguf_writer.add_file_type(self.ftype)
 
         if self.hparams.get("rope_scaling") is not None and "factor" in self.hparams["rope_scaling"]:
             if self.hparams["rope_scaling"].get("type") == "linear":
@@ -1215,6 +1204,7 @@ class StableLMModel(Model):
         self.gguf_writer.add_head_count_kv(hparams["num_key_value_heads"])
         self.gguf_writer.add_parallel_residual(hparams["use_parallel_residual"] if "use_parallel_residual" in hparams else True)
         self.gguf_writer.add_layer_norm_eps(self.find_hparam(["layer_norm_eps", "norm_eps"]))
+        self.gguf_writer.add_file_type(self.ftype)
 
     _q_norms: list[dict[str, Tensor]] | None = None
     _k_norms: list[dict[str, Tensor]] | None = None
@@ -1591,6 +1581,7 @@ class QwenModel(Model):
         self.gguf_writer.add_rope_dimension_count(self.hparams["hidden_size"] // self.hparams["num_attention_heads"])
         self.gguf_writer.add_head_count(self.hparams["num_attention_heads"])
         self.gguf_writer.add_layer_norm_rms_eps(self.hparams["layer_norm_epsilon"])
+        self.gguf_writer.add_file_type(self.ftype)
 
 
 @Model.register("Qwen2ForCausalLM")
@@ -1828,6 +1819,7 @@ class PlamoModel(Model):
         self.gguf_writer.add_head_count(hparams["num_attention_heads"])
         self.gguf_writer.add_head_count_kv(5)  # hparams["num_key_value_heads"]) is wrong
         self.gguf_writer.add_layer_norm_rms_eps(hparams["rms_norm_eps"])
+        self.gguf_writer.add_file_type(self.ftype)
 
     def shuffle_attn_q_weight(self, data_torch):
         assert data_torch.size() == (5120, 5120)
@@ -2007,6 +1999,7 @@ in chat mode so that the conversation can end normally.")
         self.gguf_writer.add_head_count(self.hparams["num_attention_heads"])
         self.gguf_writer.add_layer_norm_rms_eps(self.hparams["rms_norm_eps"])
         self.gguf_writer.add_head_count_kv(self.hparams["num_key_value_heads"])
+        self.gguf_writer.add_file_type(self.ftype)
 
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         num_heads = self.hparams["num_attention_heads"]
@@ -2415,25 +2408,15 @@ class LazyTorchTensor(gguf.LazyBase):
     def numpy(self) -> gguf.LazyNumpyTensor:
         dtype = self._dtype_map[self.dtype]
         return gguf.LazyNumpyTensor(
-            meta=np.lib.stride_tricks.as_strided(np.zeros(1, dtype), self.shape, (0 for _ in self.shape)),
+            meta=gguf.LazyNumpyTensor.meta_with_dtype_and_shape(dtype, self.shape),
             lazy=self._lazy,
             args=(self,),
             func=(lambda s: s[0].numpy())
         )
 
     @classmethod
-    def eager_to_meta(cls, t: Tensor) -> Tensor:
-        if t.is_meta:
-            return t
-        return t.detach().to("meta")
-
-    @classmethod
-    def meta_with_dtype(cls, m: Tensor, dtype: torch.dtype) -> Tensor:
-        m = m.detach()
-        if not m.is_meta:
-            m = m.to("meta")
-        m.dtype = dtype
-        return m
+    def meta_with_dtype_and_shape(cls, dtype: torch.dtype, shape: torch.Size) -> Tensor:
+        return torch.empty(size=shape, dtype=dtype, device="meta")
 
     @classmethod
     def __torch_function__(cls, func, types, args=(), kwargs=None):
@@ -2464,8 +2447,8 @@ def parse_args() -> argparse.Namespace:
         help="path to write to; default: based on input. {ftype} will be replaced by the outtype.",
     )
     parser.add_argument(
-        "--outtype", type=str, choices=["f32", "f16", "bf16", "auto"], default="f16",
-        help="output format - use f32 for float32, f16 for float16, bf16 for bfloat16, auto for the highest-fidelity 16-bit float type depending on the first loaded tensor type",
+        "--outtype", type=str, choices=["f32", "f16", "bf16", "q8_0", "auto"], default="f16",
+        help="output format - use f32 for float32, f16 for float16, bf16 for bfloat16, q8_0 for Q8_0, auto for the highest-fidelity 16-bit float type depending on the first loaded tensor type",
     )
     parser.add_argument(
         "--bigendian", action="store_true",
@@ -2523,6 +2506,7 @@ def main() -> None:
         "f32": gguf.LlamaFileType.ALL_F32,
         "f16": gguf.LlamaFileType.MOSTLY_F16,
         "bf16": gguf.LlamaFileType.MOSTLY_BF16,
+        "q8_0": gguf.LlamaFileType.MOSTLY_Q8_0,
         "auto": gguf.LlamaFileType.GUESSED,
     }
```
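One subtlety in the `shape_str` change above: once a tensor is quantized, the numpy array holds packed bytes rather than elements, so the logged shape must be recovered from the per-type `(block_size, type_size)` pair in `gguf.GGML_QUANT_SIZES`. A small self-contained sketch of that arithmetic (the dict here uses string keys for brevity; the real table is keyed by `GGMLQuantizationType`):

```python
import numpy as np

# (elements per block, bytes per block) — values for F16 and Q8_0 as in ggml
GGML_QUANT_SIZES = {"F16": (1, 2), "Q8_0": (32, 34)}

def logical_shape(data: np.ndarray, qtype: str) -> tuple[int, ...]:
    """Recover the logical element count of the last dimension from raw bytes."""
    block_size, type_size = GGML_QUANT_SIZES[qtype]
    n_bytes = data.shape[-1] * data.dtype.itemsize
    return (*data.shape[:-1], n_bytes // type_size * block_size)

# a (2, 4096) float32 tensor quantized to Q8_0 packs each row into
# 4096 / 32 = 128 blocks of 34 bytes:
packed = np.zeros((2, 128 * 34), dtype=np.uint8)
print(logical_shape(packed, "Q8_0"))  # (2, 4096)
```

The converter then reverses this tuple before logging, to match ggml's internal dimension order.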