summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorȘtefan-Gabriel Muscalu <legraphista@users.noreply.github.com>2024-06-17 22:08:46 +0300
committerGitHub <noreply@github.com>2024-06-17 21:08:46 +0200
commita94e6ff8774b7c9f950d9545baf0ce35e8d1ed2f (patch)
treeabfa71d6bf6b3743185ead9f9c337c80c49acc04
parent5b6da187508f49a9fa9d95fa22ae804a0780d256 (diff)
update: support Qwen2-57B-A14B (#7835)
* update: convert-hf-to-gguf.py to support Qwen2-57B-A14B * fix: QWEN2MOE support for expert_feed_forward_length previously, expert ff was taken from n_ff (intermediate size) but it is now properly taken from LLM_KV_EXPERT_FEED_FORWARD_LENGTH n_ff_exp and n_ff_shared_exp are now properly calculated * update: convert-hf-to-gguf.py cleanup for Qwen2MoeForCausalLM * fix: QWEN2MOE support for expert_feed_forward_length previously, expert ff was taken from n_ff (intermediate size) but it is now properly taken from LLM_KV_EXPERT_FEED_FORWARD_LENGTH n_ff_exp and n_ff_shexp are now properly calculated
-rwxr-xr-xconvert-hf-to-gguf.py6
-rw-r--r--gguf-py/gguf/constants.py31
-rw-r--r--gguf-py/gguf/gguf_writer.py3
-rw-r--r--llama.cpp51
4 files changed, 57 insertions, 34 deletions
diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py
index 55ce502d..a6751cc8 100755
--- a/convert-hf-to-gguf.py
+++ b/convert-hf-to-gguf.py
@@ -1632,6 +1632,12 @@ class Qwen2MoeModel(Model):
super().set_gguf_parameters()
if (n_experts := self.hparams.get("num_experts")) is not None:
self.gguf_writer.add_expert_count(n_experts)
+ if (moe_intermediate_size := self.hparams.get("moe_intermediate_size")) is not None:
+ self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size)
+ logger.info(f"gguf: expert feed forward length = {moe_intermediate_size}")
+ if (shared_expert_intermediate_size := self.hparams.get('shared_expert_intermediate_size')) is not None:
+ self.gguf_writer.add_expert_shared_feed_forward_length(shared_expert_intermediate_size)
+ logger.info(f"gguf: expert shared feed forward length = {shared_expert_intermediate_size}")
_experts: list[dict[str, Tensor]] | None = None
diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py
index 8908585c..fb20cfab 100644
--- a/gguf-py/gguf/constants.py
+++ b/gguf-py/gguf/constants.py
@@ -33,21 +33,22 @@ class Keys:
FILE_TYPE = "general.file_type"
class LLM:
- VOCAB_SIZE = "{arch}.vocab_size"
- CONTEXT_LENGTH = "{arch}.context_length"
- EMBEDDING_LENGTH = "{arch}.embedding_length"
- BLOCK_COUNT = "{arch}.block_count"
- LEADING_DENSE_BLOCK_COUNT = "{arch}.leading_dense_block_count"
- FEED_FORWARD_LENGTH = "{arch}.feed_forward_length"
- EXPERT_FEED_FORWARD_LENGTH = "{arch}.expert_feed_forward_length"
- USE_PARALLEL_RESIDUAL = "{arch}.use_parallel_residual"
- TENSOR_DATA_LAYOUT = "{arch}.tensor_data_layout"
- EXPERT_COUNT = "{arch}.expert_count"
- EXPERT_USED_COUNT = "{arch}.expert_used_count"
- EXPERT_SHARED_COUNT = "{arch}.expert_shared_count"
- EXPERT_WEIGHTS_SCALE = "{arch}.expert_weights_scale"
- POOLING_TYPE = "{arch}.pooling_type"
- LOGIT_SCALE = "{arch}.logit_scale"
+ VOCAB_SIZE = "{arch}.vocab_size"
+ CONTEXT_LENGTH = "{arch}.context_length"
+ EMBEDDING_LENGTH = "{arch}.embedding_length"
+ BLOCK_COUNT = "{arch}.block_count"
+ LEADING_DENSE_BLOCK_COUNT = "{arch}.leading_dense_block_count"
+ FEED_FORWARD_LENGTH = "{arch}.feed_forward_length"
+ EXPERT_FEED_FORWARD_LENGTH = "{arch}.expert_feed_forward_length"
+ EXPERT_SHARED_FEED_FORWARD_LENGTH = "{arch}.expert_shared_feed_forward_length"
+ USE_PARALLEL_RESIDUAL = "{arch}.use_parallel_residual"
+ TENSOR_DATA_LAYOUT = "{arch}.tensor_data_layout"
+ EXPERT_COUNT = "{arch}.expert_count"
+ EXPERT_USED_COUNT = "{arch}.expert_used_count"
+ EXPERT_SHARED_COUNT = "{arch}.expert_shared_count"
+ EXPERT_WEIGHTS_SCALE = "{arch}.expert_weights_scale"
+ POOLING_TYPE = "{arch}.pooling_type"
+ LOGIT_SCALE = "{arch}.logit_scale"
class Attention:
HEAD_COUNT = "{arch}.attention.head_count"
diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py
index ed56abfb..a697f657 100644
--- a/gguf-py/gguf/gguf_writer.py
+++ b/gguf-py/gguf/gguf_writer.py
@@ -394,6 +394,9 @@ class GGUFWriter:
def add_expert_feed_forward_length(self, length: int) -> None:
self.add_uint32(Keys.LLM.EXPERT_FEED_FORWARD_LENGTH.format(arch=self.arch), length)
+ def add_expert_shared_feed_forward_length(self, length: int) -> None:
+ self.add_uint32(Keys.LLM.EXPERT_SHARED_FEED_FORWARD_LENGTH.format(arch=self.arch), length)
+
def add_parallel_residual(self, use: bool) -> None:
self.add_bool(Keys.LLM.USE_PARALLEL_RESIDUAL.format(arch=self.arch), use)
diff --git a/llama.cpp b/llama.cpp
index 61948751..e06c851a 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -286,6 +286,7 @@ enum llm_kv {
LLM_KV_LEADING_DENSE_BLOCK_COUNT,
LLM_KV_FEED_FORWARD_LENGTH,
LLM_KV_EXPERT_FEED_FORWARD_LENGTH,
+ LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH,
LLM_KV_USE_PARALLEL_RESIDUAL,
LLM_KV_TENSOR_DATA_LAYOUT,
LLM_KV_EXPERT_COUNT,
@@ -364,21 +365,22 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
{ LLM_KV_GENERAL_SOURCE_URL, "general.source.url" },
{ LLM_KV_GENERAL_SOURCE_HF_REPO, "general.source.huggingface.repository" },
- { LLM_KV_VOCAB_SIZE, "%s.vocab_size" },
- { LLM_KV_CONTEXT_LENGTH, "%s.context_length" },
- { LLM_KV_EMBEDDING_LENGTH, "%s.embedding_length" },
- { LLM_KV_BLOCK_COUNT, "%s.block_count" },
- { LLM_KV_LEADING_DENSE_BLOCK_COUNT, "%s.leading_dense_block_count" },
- { LLM_KV_FEED_FORWARD_LENGTH, "%s.feed_forward_length" },
- { LLM_KV_EXPERT_FEED_FORWARD_LENGTH, "%s.expert_feed_forward_length" },
- { LLM_KV_USE_PARALLEL_RESIDUAL, "%s.use_parallel_residual" },
- { LLM_KV_TENSOR_DATA_LAYOUT, "%s.tensor_data_layout" },
- { LLM_KV_EXPERT_COUNT, "%s.expert_count" },
- { LLM_KV_EXPERT_USED_COUNT, "%s.expert_used_count" },
- { LLM_KV_EXPERT_SHARED_COUNT, "%s.expert_shared_count" },
- { LLM_KV_EXPERT_WEIGHTS_SCALE, "%s.expert_weights_scale" },
- { LLM_KV_POOLING_TYPE , "%s.pooling_type" },
- { LLM_KV_LOGIT_SCALE, "%s.logit_scale" },
+ { LLM_KV_VOCAB_SIZE, "%s.vocab_size" },
+ { LLM_KV_CONTEXT_LENGTH, "%s.context_length" },
+ { LLM_KV_EMBEDDING_LENGTH, "%s.embedding_length" },
+ { LLM_KV_BLOCK_COUNT, "%s.block_count" },
+ { LLM_KV_LEADING_DENSE_BLOCK_COUNT, "%s.leading_dense_block_count" },
+ { LLM_KV_FEED_FORWARD_LENGTH, "%s.feed_forward_length" },
+ { LLM_KV_EXPERT_FEED_FORWARD_LENGTH, "%s.expert_feed_forward_length" },
+ { LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, "%s.expert_shared_feed_forward_length" },
+ { LLM_KV_USE_PARALLEL_RESIDUAL, "%s.use_parallel_residual" },
+ { LLM_KV_TENSOR_DATA_LAYOUT, "%s.tensor_data_layout" },
+ { LLM_KV_EXPERT_COUNT, "%s.expert_count" },
+ { LLM_KV_EXPERT_USED_COUNT, "%s.expert_used_count" },
+ { LLM_KV_EXPERT_SHARED_COUNT, "%s.expert_shared_count" },
+ { LLM_KV_EXPERT_WEIGHTS_SCALE, "%s.expert_weights_scale" },
+ { LLM_KV_POOLING_TYPE , "%s.pooling_type" },
+ { LLM_KV_LOGIT_SCALE, "%s.logit_scale" },
{ LLM_KV_ATTENTION_HEAD_COUNT, "%s.attention.head_count" },
{ LLM_KV_ATTENTION_HEAD_COUNT_KV, "%s.attention.head_count_kv" },
@@ -1970,6 +1972,7 @@ struct llama_hparams {
uint32_t n_lora_q = 0;
uint32_t n_lora_kv = 0;
uint32_t n_ff_exp = 0;
+ uint32_t n_ff_shexp = 0;
uint32_t n_expert_shared = 0;
float expert_weights_scale = 0.0;
@@ -2018,6 +2021,7 @@ struct llama_hparams {
if (this->n_lora_q != other.n_lora_q) return true;
if (this->n_lora_kv != other.n_lora_kv) return true;
if (this->n_ff_exp != other.n_ff_exp) return true;
+ if (this->n_ff_shexp != other.n_ff_shexp) return true;
if (this->n_expert_shared != other.n_expert_shared) return true;
if (this->rope_finetuned != other.rope_finetuned) return true;
@@ -4455,6 +4459,9 @@ static void llm_load_hparams(
} break;
case LLM_ARCH_QWEN2MOE:
{
+ ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false);
+ ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, false);
+
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
switch (hparams.n_layer) {
case 24: model.type = e_model::MODEL_A2_7B; break;
@@ -5240,6 +5247,11 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale);
LLAMA_LOG_INFO("%s: rope_yarn_log_mul = %.4f\n", __func__, hparams.rope_yarn_log_mul);
}
+
+ if (model.arch == LLM_ARCH_QWEN2MOE) {
+ LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp);
+ LLAMA_LOG_INFO("%s: n_ff_shexp = %d\n", __func__, hparams.n_ff_shexp);
+ }
}
// Returns false if cancelled by progress_callback
@@ -6026,16 +6038,17 @@ static bool llm_load_tensors(
GGML_ASSERT(hparams.n_expert_used > 0);
// MoE branch
- auto n_ff_exp = n_ff / hparams.n_expert_used;
+ auto n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / hparams.n_expert_used;
layer.ffn_gate_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert});
layer.ffn_down_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert});
layer.ffn_up_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert});
// Shared expert branch
+ auto n_ff_shexp = hparams.n_ff_shexp ? hparams.n_ff_shexp : n_ff;
layer.ffn_gate_inp_shexp = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP_SHEXP, "weight", i), {n_embd});
- layer.ffn_gate_shexp = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff});
- layer.ffn_down_shexp = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff, n_embd});
- layer.ffn_up_shexp = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff});
+ layer.ffn_gate_shexp = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_shexp});
+ layer.ffn_down_shexp = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp, n_embd});
+ layer.ffn_up_shexp = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_shexp});
}
} break;
case LLM_ARCH_PHI2: