summaryrefslogtreecommitdiff
path: root/convert-hf-to-gguf.py
diff options
context:
space:
mode:
authorfairydreaming <166155368+fairydreaming@users.noreply.github.com>2024-05-28 17:07:05 +0200
committerGitHub <noreply@github.com>2024-05-28 17:07:05 +0200
commitee3dff6b8e39bb8c1cdea1782a7b95ef0118f970 (patch)
tree28eaea501c6c929f98442cc451eb14b380c02de6 /convert-hf-to-gguf.py
parentedc29433fa08b4e5aeb67649a29fc7713af13d04 (diff)
Add support for DeepseekV2ForCausalLM (#7519)
* common : increase max number of experts to 160 * common : add tensors ATTN_Q_A, ATTN_Q_A_NORM, ATTN_Q_B, ATTN_KV_A_MQA, ATTN_KV_A_NORM, ATTN_KV_B needed by DeepSeek-V2 MLA (multi-head latent attention) architecture * common : add model header parameters: leading_dense_block_count, expert_feed_forward_length, expert_shared_count, expert_weights_scale, attention.q_lora_rank, attention.kv_lora_rank, rope.scaling.yarn_log_multiplier * convert-hf : add model conversion support for DeepseekV2ForCausalLM * llama : add model types for DeepSeek-V2 and DeepSeek-V2-Lite models * llama : add two new llm_build_moe_ffn() arguments: scale_w (whether to scale weights of selected MoE experts) and w_scale (numerical value of the scaling factor) * llama : add inference support for LLM_ARCH_DEEPSEEK2 --------- Co-authored-by: Stanisław Szymczyk <sszymczy@gmail.com>
Diffstat (limited to 'convert-hf-to-gguf.py')
-rwxr-xr-xconvert-hf-to-gguf.py79
1 files changed, 79 insertions, 0 deletions
diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py
index a342f6b1..1b060e4e 100755
--- a/convert-hf-to-gguf.py
+++ b/convert-hf-to-gguf.py
@@ -2620,6 +2620,85 @@ class ArcticModel(Model):
raise ValueError(f"Unprocessed experts: {experts}")
+@Model.register("DeepseekV2ForCausalLM")
+class DeepseekV2Model(Model):
+ model_arch = gguf.MODEL_ARCH.DEEPSEEK2
+
+ def set_vocab(self):
+ self._set_vocab_gpt2()
+
+ def set_gguf_parameters(self):
+ super().set_gguf_parameters()
+ hparams = self.hparams
+
+ self.gguf_writer.add_leading_dense_block_count(hparams["first_k_dense_replace"])
+ self.gguf_writer.add_vocab_size(hparams["vocab_size"])
+ if "q_lora_rank" in hparams and hparams["q_lora_rank"] is not None:
+ self.gguf_writer.add_q_lora_rank(hparams["q_lora_rank"])
+ self.gguf_writer.add_kv_lora_rank(hparams["kv_lora_rank"])
+ self.gguf_writer.add_key_length(hparams["qk_nope_head_dim"] + hparams["qk_rope_head_dim"])
+ self.gguf_writer.add_value_length(hparams["v_head_dim"])
+ self.gguf_writer.add_expert_feed_forward_length(hparams["moe_intermediate_size"])
+ self.gguf_writer.add_expert_count(hparams["n_routed_experts"])
+ self.gguf_writer.add_expert_shared_count(hparams["n_shared_experts"])
+ self.gguf_writer.add_expert_weights_scale(hparams["routed_scaling_factor"])
+ self.gguf_writer.add_rope_dimension_count(hparams["qk_rope_head_dim"])
+
+ if self.hparams.get("rope_scaling") is not None and "factor" in self.hparams["rope_scaling"]:
+ if self.hparams["rope_scaling"].get("type") == "yarn":
+ self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
+ self.gguf_writer.add_rope_scaling_factor(self.hparams["rope_scaling"]["factor"])
+ self.gguf_writer.add_rope_scaling_orig_ctx_len(self.hparams["rope_scaling"]["original_max_position_embeddings"])
+ self.gguf_writer.add_rope_scaling_yarn_log_mul(0.1 * hparams["rope_scaling"]["mscale_all_dim"])
+
+ _experts: list[dict[str, Tensor]] | None = None
+
+ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+ # process the experts separately
+ if name.find("mlp.experts") != -1:
+ n_experts = self.hparams["n_routed_experts"]
+ assert bid is not None
+
+ if self._experts is None:
+ self._experts = [{} for _ in range(self.block_count)]
+
+ self._experts[bid][name] = data_torch
+
+ if len(self._experts[bid]) >= n_experts * 3:
+ tensors: list[tuple[str, Tensor]] = []
+
+ # merge the experts into a single 3d tensor
+ for w_name in ["down_proj", "gate_proj", "up_proj"]:
+ datas: list[Tensor] = []
+
+ for xid in range(n_experts):
+ ename = f"model.layers.{bid}.mlp.experts.{xid}.{w_name}.weight"
+ datas.append(self._experts[bid][ename])
+ del self._experts[bid][ename]
+
+ data_torch = torch.stack(datas, dim=0)
+
+ merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight"
+
+ new_name = self.map_tensor_name(merged_name)
+
+ tensors.append((new_name, data_torch))
+ return tensors
+ else:
+ return []
+
+ return [(self.map_tensor_name(name), data_torch)]
+
+ def write_tensors(self):
+ super().write_tensors()
+
+ if self._experts is not None:
+ # flatten `list[dict[str, Tensor]]` into `list[str]`
+ experts = [k for d in self._experts for k in d.keys()]
+ if len(experts) > 0:
+ raise ValueError(f"Unprocessed experts: {experts}")
+
+
###### CONVERSION LOGIC ######