Diffstat (limited to 'convert_hf_to_gguf.py')
-rwxr-xr-x  convert_hf_to_gguf.py  42
1 file changed, 42 insertions, 0 deletions
diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py
index 3910aa1d..1ee82724 100755
--- a/convert_hf_to_gguf.py
+++ b/convert_hf_to_gguf.py
@@ -3123,6 +3123,7 @@ class ArcticModel(Model):
@Model.register("DeepseekV2ForCausalLM")
+@Model.register("DeepseekV3ForCausalLM")
class DeepseekV2Model(Model):
model_arch = gguf.MODEL_ARCH.DEEPSEEK2
@@ -3144,6 +3145,15 @@ class DeepseekV2Model(Model):
self.gguf_writer.add_expert_count(hparams["n_routed_experts"])
self.gguf_writer.add_expert_shared_count(hparams["n_shared_experts"])
self.gguf_writer.add_expert_weights_scale(hparams["routed_scaling_factor"])
+ self.gguf_writer.add_expert_weights_norm(hparams["norm_topk_prob"])
+
+ if hparams["scoring_func"] == "sigmoid":
+ self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID)
+ elif hparams["scoring_func"] == "softmax":
+ self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SOFTMAX)
+ else:
+ raise ValueError(f"Unsupported scoring_func value: {hparams['scoring_func']}")
+
self.gguf_writer.add_rope_dimension_count(hparams["qk_rope_head_dim"])
if self.hparams.get("rope_scaling") is not None and "factor" in self.hparams["rope_scaling"]:
@@ -3156,6 +3166,17 @@ class DeepseekV2Model(Model):
_experts: list[dict[str, Tensor]] | None = None
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+ # rename e_score_correction_bias tensors
+ if name.endswith("e_score_correction_bias"):
+ name = name.replace("e_score_correction_bias", "e_score_correction.bias")
+
+ # skip Multi-Token Prediction (MTP) layers
+ block_count = self.hparams["num_hidden_layers"]
+ match = re.match(r"model.layers.(\d+)", name)
+ if match and int(match.group(1)) >= block_count:
+ return []
+
+
# process the experts separately
if name.find("mlp.experts") != -1:
n_experts = self.hparams["n_routed_experts"]
@@ -3188,6 +3209,27 @@ class DeepseekV2Model(Model):
return tensors
else:
return []
+ if name.endswith("kv_b_proj.weight"):
+ name_kb = name.replace("kv_b_proj", "k_b_proj")
+ name_vb = name.replace("kv_b_proj", "v_b_proj")
+
+ n_head_kv = self.hparams["num_key_value_heads"]
+ v_head_dim = self.hparams["v_head_dim"]
+ qk_nope_head_dim = self.hparams["qk_nope_head_dim"]
+
+ assert data_torch.shape[0] == n_head_kv * (v_head_dim + qk_nope_head_dim)
+
+ kv_b = data_torch.view(n_head_kv, v_head_dim + qk_nope_head_dim, data_torch.shape[-1])
+ k_b, v_b = torch.split(kv_b, [qk_nope_head_dim, v_head_dim], dim=1)
+ k_b = k_b.transpose(1, 2)
+ k_b = k_b.reshape(n_head_kv * data_torch.shape[-1], qk_nope_head_dim)
+ v_b = v_b.reshape(n_head_kv * v_head_dim, data_torch.shape[-1])
+
+ return [
+ (self.map_tensor_name(name), data_torch),
+ (self.map_tensor_name(name_kb), k_b),
+ (self.map_tensor_name(name_vb), v_b)
+ ]
return [(self.map_tensor_name(name), data_torch)]
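
For reference, below is a minimal standalone sketch of the shape bookkeeping done by the new kv_b_proj branch above. The dimension values are made up for illustration only; in the converter they come from hparams ("num_key_value_heads", "qk_nope_head_dim", "v_head_dim") and from the last dimension of the checkpoint tensor (here given the assumed name kv_lora_rank, which the diff never names explicitly). This is not part of the conversion script, it only demonstrates that the split/transpose/reshape steps yield the k_b_proj and v_b_proj layouts the diff writes out.

# Standalone sketch of the kv_b_proj -> k_b_proj / v_b_proj split, mirroring the diff.
import torch

n_head_kv        = 4    # hypothetical head count ("num_key_value_heads")
qk_nope_head_dim = 8    # hypothetical per-head K (no-RoPE) dim
v_head_dim       = 6    # hypothetical per-head V dim
kv_lora_rank     = 16   # assumed name for the tensor's last dim (data_torch.shape[-1])

# The HF checkpoint stores kv_b_proj as one matrix whose rows are the
# per-head [k_nope ; v] blocks stacked together.
data_torch = torch.randn(n_head_kv * (qk_nope_head_dim + v_head_dim), kv_lora_rank)
assert data_torch.shape[0] == n_head_kv * (v_head_dim + qk_nope_head_dim)

# Same steps as the diff: un-flatten per head, split the K part from the V part,
# transpose the K block before flattening, leave the V block as-is.
kv_b = data_torch.view(n_head_kv, v_head_dim + qk_nope_head_dim, data_torch.shape[-1])
k_b, v_b = torch.split(kv_b, [qk_nope_head_dim, v_head_dim], dim=1)
k_b = k_b.transpose(1, 2).reshape(n_head_kv * data_torch.shape[-1], qk_nope_head_dim)
v_b = v_b.reshape(n_head_kv * v_head_dim, data_torch.shape[-1])

print(k_b.shape)  # torch.Size([64, 8])  -> (n_head_kv * kv_lora_rank, qk_nope_head_dim)
print(v_b.shape)  # torch.Size([24, 16]) -> (n_head_kv * v_head_dim,  kv_lora_rank)

With these toy numbers the original kv_b_proj tensor is (56, 16); the sketch ends up with a (64, 8) k_b_proj and a (24, 16) v_b_proj, matching the reshapes in the diff. The original combined kv_b_proj tensor is still emitted unchanged alongside the two new split tensors.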