Diffstat (limited to 'convert-hf-to-gguf.py')
-rwxr-xr-x  convert-hf-to-gguf.py | 96 ++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 96 insertions(+), 0 deletions(-)
diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py
index 63710676..e1ac09e0 100755
--- a/convert-hf-to-gguf.py
+++ b/convert-hf-to-gguf.py
@@ -1427,6 +1427,102 @@ class GrokModel(Model):
             self.gguf_writer.add_tensor(new_name, data)
+@Model.register("DbrxForCausalLM")
+class DbrxModel(Model):
+    model_arch = gguf.MODEL_ARCH.DBRX
+
+    def set_gguf_parameters(self):
+        ffn_config = self.hparams["ffn_config"]
+        attn_config = self.hparams["attn_config"]
+        self.gguf_writer.add_name(self.hparams["model_type"])
+        self.gguf_writer.add_block_count(self.hparams["n_layers"])
+
+        self.gguf_writer.add_context_length(self.hparams["max_seq_len"])
+        self.gguf_writer.add_embedding_length(self.hparams["d_model"])
+        self.gguf_writer.add_feed_forward_length(ffn_config["ffn_hidden_size"])
+
+        self.gguf_writer.add_head_count(self.hparams["n_heads"])
+        self.gguf_writer.add_head_count_kv(attn_config["kv_n_heads"])
+
+        self.gguf_writer.add_rope_freq_base(attn_config["rope_theta"])
+
+        self.gguf_writer.add_clamp_kqv(attn_config["clip_qkv"])
+        self.gguf_writer.add_file_type(self.ftype)
+
+        self.gguf_writer.add_expert_count(ffn_config["moe_num_experts"])
+        self.gguf_writer.add_expert_used_count(ffn_config["moe_top_k"])
+
+        self.gguf_writer.add_layer_norm_eps(1e-5)
+
+        self.gguf_writer.add_file_type(self.ftype)
+        print(f"gguf: file type = {self.ftype}")
+
+    def write_tensors(self):
+        block_count = self.hparams.get("n_layers")
+        tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
+        for name, data_torch in self.get_tensors():
+            n_expert = self.hparams["ffn_config"]["moe_num_experts"]
+            n_ff = self.hparams["ffn_config"]["ffn_hidden_size"]
+            n_embd = self.hparams["d_model"]
+
+            # Specific behavior for experts tensors: suffix with .weight, view as 3D and transpose.
+            # The original implementation expects (n_expert, n_ff, n_embd) for all experts weights,
+            # but the llama.cpp MoE graph works differently, and the dimensions in ggml are
+            # typically the reverse of the PyTorch dimensions: (n_expert, n_ff, n_embd) in PyTorch
+            # is {n_embd, n_ff, n_expert} in a ggml_tensor.
+            exp_tensor_names = {"ffn.experts.mlp.w1": None,       # LLM_TENSOR_FFN_GATE_EXPS ggml_tensor->ne{n_embd, n_ff,   n_expert}
+                                "ffn.experts.mlp.w2": (0, 2, 1),  # LLM_TENSOR_FFN_DOWN_EXPS ggml_tensor->ne{n_ff,   n_embd, n_expert}
+                                "ffn.experts.mlp.v1": None}       # LLM_TENSOR_FFN_UP_EXPS   ggml_tensor->ne{n_embd, n_ff,   n_expert}
+            experts = False
+            for exp_tensor_name in exp_tensor_names.keys():
+                if name.find(exp_tensor_name) != -1 and name.find(".weight") == -1:
+                    experts = True
+                    data_torch = data_torch.view(n_expert, n_ff, n_embd)
+                    if (permute_tensor := exp_tensor_names[exp_tensor_name]) is not None:
+                        data_torch = data_torch.permute(*permute_tensor)
+                    break
+
+            old_dtype = data_torch.dtype
+
+            # convert any unsupported data types to float32
+            if data_torch.dtype not in (torch.float16, torch.float32):
+                data_torch = data_torch.to(torch.float32)
+
+            data = data_torch.squeeze().numpy()
+
+            # map tensor names
+            # In MoE models the ffn tensors are typically most of the model weights,
+            # and need to be quantizable. Quantization expects tensor names to be suffixed by .weight.
+            # Every other model has its weight names ending in .weight, so we treat that as the
+            # convention, which dbrx does not follow:
+            # https://huggingface.co/databricks/dbrx-instruct/blob/main/model.safetensors.index.json#L15
+            new_name = tensor_map.get_name(name if not experts else name + ".weight", try_suffixes=(".weight",))
+            if new_name is None:
+                print(f"Can not map tensor {name!r}")
+                sys.exit()
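+            # Editorial note, not part of the original commit: an expert tensor name such as
+            # "transformer.blocks.0.ffn.experts.mlp.w2" therefore gets ".weight" appended before
+            # lookup, and is expected to map to a gguf name like "blk.0.ffn_down_exps.weight"
+            # (exact source and target names are assumptions for illustration).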
+
+            n_dims = len(data.shape)
+            data_dtype = data.dtype
+
+            # Most of the codebase that takes in 1D tensors only handles F32 tensors,
+            # and most of the output tensors are F32.
+            if data_dtype != np.float32 and n_dims == 1:
+                print(f"Can not map tensor {name!r}: all 1D tensors must be F32")
+                sys.exit()
+
+            # if f32 desired, convert any float16 to float32
+            if self.ftype == 0 and data_dtype == np.float16:
+                data = data.astype(np.float32)
+
+            # if f16 desired, convert any float32 2-dim weight tensors to float16
+            if self.ftype == 1 and data_dtype == np.float32 and n_dims > 1:
+                data = data.astype(np.float16)
+
+            print(f"{new_name}, n_dims = {n_dims}, shape = {data.shape}, {old_dtype} --> {data.dtype}")
+
+            self.gguf_writer.add_tensor(new_name, data)
+
+
 @Model.register("MiniCPMForCausalLM")
 class MiniCPMModel(Model):
     model_arch = gguf.MODEL_ARCH.MINICPM