summaryrefslogtreecommitdiff
path: root/llama.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'llama.cpp')
-rw-r--r--llama.cpp252
1 files changed, 243 insertions, 9 deletions
diff --git a/llama.cpp b/llama.cpp
index 38e59362..340e68fd 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -105,7 +105,7 @@
#endif
#define LLAMA_MAX_NODES 8192
-#define LLAMA_MAX_EXPERTS 16
+#define LLAMA_MAX_EXPERTS 60
//
@@ -209,6 +209,7 @@ enum llm_arch {
LLM_ARCH_STABLELM,
LLM_ARCH_QWEN,
LLM_ARCH_QWEN2,
+ LLM_ARCH_QWEN2MOE,
LLM_ARCH_PHI2,
LLM_ARCH_PLAMO,
LLM_ARCH_CODESHELL,
@@ -242,6 +243,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_STABLELM, "stablelm" },
{ LLM_ARCH_QWEN, "qwen" },
{ LLM_ARCH_QWEN2, "qwen2" },
+ { LLM_ARCH_QWEN2MOE, "qwen2moe" },
{ LLM_ARCH_PHI2, "phi2" },
{ LLM_ARCH_PLAMO, "plamo" },
{ LLM_ARCH_CODESHELL, "codeshell" },
@@ -437,6 +439,7 @@ enum llm_tensor {
LLM_TENSOR_ATTN_OUT_NORM,
LLM_TENSOR_ATTN_ROT_EMBD,
LLM_TENSOR_FFN_GATE_INP,
+ LLM_TENSOR_FFN_GATE_INP_SHEXP,
LLM_TENSOR_FFN_NORM,
LLM_TENSOR_FFN_GATE,
LLM_TENSOR_FFN_DOWN,
@@ -448,6 +451,9 @@ enum llm_tensor {
LLM_TENSOR_FFN_DOWN_EXPS, // merged experts
LLM_TENSOR_FFN_GATE_EXPS,
LLM_TENSOR_FFN_UP_EXPS,
+ LLM_TENSOR_FFN_DOWN_SHEXP,
+ LLM_TENSOR_FFN_GATE_SHEXP,
+ LLM_TENSOR_FFN_UP_SHEXP,
LLM_TENSOR_ATTN_Q_NORM,
LLM_TENSOR_ATTN_K_NORM,
LLM_TENSOR_LAYER_OUT_NORM,
@@ -746,6 +752,28 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
},
},
{
+ LLM_ARCH_QWEN2MOE,
+ {
+ { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+ { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
+ { LLM_TENSOR_OUTPUT, "output" },
+ { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
+ { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
+ { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
+ { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
+ { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
+ { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
+ { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
+ { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
+ { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
+ { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
+ { LLM_TENSOR_FFN_GATE_INP_SHEXP, "blk.%d.ffn_gate_inp_shexp" },
+ { LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" },
+ { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" },
+ { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
+ },
+ },
+ {
LLM_ARCH_PHI2,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
@@ -1731,6 +1759,7 @@ enum e_model {
MODEL_MEDIUM,
MODEL_LARGE,
MODEL_XL,
+ MODEL_A2_7B,
MODEL_8x7B,
MODEL_8x22B,
MODEL_16x12B,
@@ -1917,6 +1946,12 @@ struct llama_layer {
struct ggml_tensor * ffn_down_exps;
struct ggml_tensor * ffn_up_exps ;
+ // ff shared expert (shexp)
+ struct ggml_tensor * ffn_gate_inp_shexp;
+ struct ggml_tensor * ffn_gate_shexp;
+ struct ggml_tensor * ffn_down_shexp;
+ struct ggml_tensor * ffn_up_shexp;
+
// ff bias
struct ggml_tensor * ffn_down_b; // b2
struct ggml_tensor * ffn_up_b; // b3
@@ -3587,6 +3622,7 @@ static const char * llama_model_type_name(e_model type) {
case MODEL_MEDIUM: return "0.4B";
case MODEL_LARGE: return "0.8B";
case MODEL_XL: return "1.5B";
+ case MODEL_A2_7B: return "A2.7B";
case MODEL_8x7B: return "8x7B";
case MODEL_8x22B: return "8x22B";
case MODEL_16x12B: return "16x12B";
@@ -3886,6 +3922,14 @@ static void llm_load_hparams(
default: model.type = e_model::MODEL_UNKNOWN;
}
} break;
+ case LLM_ARCH_QWEN2MOE:
+ {
+ ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+ switch (hparams.n_layer) {
+ case 24: model.type = e_model::MODEL_A2_7B; break;
+ default: model.type = e_model::MODEL_UNKNOWN;
+ }
+ } break;
case LLM_ARCH_PHI2:
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
@@ -5156,6 +5200,54 @@ static bool llm_load_tensors(
layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff});
}
} break;
+ case LLM_ARCH_QWEN2MOE:
+ {
+ model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
+
+ // output
+ {
+ model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
+ model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab});
+ }
+
+ for (int i = 0; i < n_layer; ++i) {
+ ggml_context * ctx_layer = ctx_for_layer(i);
+ ggml_context * ctx_split = ctx_for_layer_split(i);
+
+ auto & layer = model.layers[i];
+
+ layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
+
+ layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd});
+ layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa});
+ layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa});
+ layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
+
+ // optional bias tensors
+ layer.bq = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd});
+ layer.bk = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa});
+ layer.bv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa});
+
+ layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
+
+ layer.ffn_gate_inp = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert});
+
+ GGML_ASSERT(hparams.n_expert > 0);
+ GGML_ASSERT(hparams.n_expert_used > 0);
+
+ // MoE branch
+ auto n_ff_exp = n_ff / hparams.n_expert_used;
+ layer.ffn_gate_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert});
+ layer.ffn_down_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert});
+ layer.ffn_up_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert});
+
+ // Shared expert branch
+ layer.ffn_gate_inp_shexp = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP_SHEXP, "weight", i), {n_embd});
+ layer.ffn_gate_shexp = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff});
+ layer.ffn_down_shexp = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff, n_embd});
+ layer.ffn_up_shexp = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff});
+ }
+ } break;
case LLM_ARCH_PHI2:
{
model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
@@ -6532,7 +6624,7 @@ struct llm_build_context {
LLM_NORM_RMS, cb, il);
cb(cur, "ffn_norm", il);
- cur = build_moe_ffn(cur, n_tokens, LLM_FFN_SILU, il);
+ cur = build_moe_ffn(cur, n_tokens, LLM_FFN_SILU, true, il);
}
cur = ggml_add(ctx0, cur, ffn_inp);
@@ -6565,7 +6657,7 @@ struct llm_build_context {
}
// REVIEW: will be replaced by https://github.com/ggerganov/llama.cpp/pull/6505
- ggml_tensor * build_moe_ffn(ggml_tensor * cur, int32_t n_tokens, llm_ffn_op_type type_op, int il) {
+ ggml_tensor * build_moe_ffn(ggml_tensor * cur, int32_t n_tokens, llm_ffn_op_type type_op, bool norm_w, int il) {
ggml_tensor * logits = ggml_mul_mat(ctx0, model.layers[il].ffn_gate_inp, cur); // [n_tokens, num_experts]
cb(logits, "ffn_moe_logits", il);
@@ -6582,11 +6674,13 @@ struct llm_build_context {
weights = ggml_reshape_2d(ctx0, weights, n_expert_used, n_tokens); // [n_tokens, num_experts_per_tok]
- ggml_tensor * weights_sum = ggml_sum_rows(ctx0, weights);
- cb(weights_sum, "ffn_moe_weights_sum", il);
+ if (norm_w) {
+ ggml_tensor * weights_sum = ggml_sum_rows(ctx0, weights);
+ cb(weights_sum, "ffn_moe_weights_sum", il);
- weights = ggml_div(ctx0, weights, weights_sum); // [n_tokens, num_experts_per_tok]
- cb(weights, "ffn_moe_weights_norm", il);
+ weights = ggml_div(ctx0, weights, weights_sum); // [n_tokens, num_experts_per_tok]
+ cb(weights, "ffn_moe_weights_norm", il);
+ }
// compute expert outputs
ggml_tensor * moe_out = nullptr;
@@ -7083,7 +7177,7 @@ struct llm_build_context {
LLM_NORM_RMS, cb, il);
cb(cur, "ffn_norm", il);
- cur = build_moe_ffn(cur, n_tokens, LLM_FFN_GELU, il);
+ cur = build_moe_ffn(cur, n_tokens, LLM_FFN_GELU, true, il);
// Grok
// if layer_out_norm is present then apply it before adding the input
@@ -7219,7 +7313,7 @@ struct llm_build_context {
LLM_NORM, cb, il);
cb(cur, "attn_out_norm", il);
- cur = build_moe_ffn(cur, n_tokens, LLM_FFN_SILU, il);
+ cur = build_moe_ffn(cur, n_tokens, LLM_FFN_SILU, true, il);
cur = ggml_add(ctx0, cur, ffn_inp);
cb(cur, "ffn_out", il);
@@ -8434,6 +8528,141 @@ struct llm_build_context {
return gf;
}
+ struct ggml_cgraph * build_qwen2moe() {
+ struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+
+ // mutable variable, needed during the last layer of the computation to skip unused tokens
+ int32_t n_tokens = this->n_tokens;
+
+ const int64_t n_embd_head = hparams.n_embd_head_v;
+ GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+ GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+ struct ggml_tensor * cur;
+ struct ggml_tensor * inpL;
+
+ inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+
+ // inp_pos - contains the positions
+ struct ggml_tensor * inp_pos = build_inp_pos();
+
+ // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
+ struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
+
+ for (int il = 0; il < n_layer; ++il) {
+ struct ggml_tensor * inpSA = inpL;
+
+ // norm
+ cur = llm_build_norm(ctx0, inpL, hparams,
+ model.layers[il].attn_norm, NULL,
+ LLM_NORM_RMS, cb, il);
+ cb(cur, "attn_norm", il);
+
+ // self_attention
+ {
+ // compute Q and K and RoPE them
+ struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
+ cb(Qcur, "Qcur", il);
+ Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
+ cb(Qcur, "Qcur", il);
+
+ struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
+ cb(Kcur, "Kcur", il);
+ Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
+ cb(Kcur, "Kcur", il);
+
+ struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
+ cb(Vcur, "Vcur", il);
+ Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
+ cb(Vcur, "Vcur", il);
+
+ Qcur = ggml_rope_custom(
+ ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos,
+ n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
+ ext_factor, attn_factor, beta_fast, beta_slow
+ );
+ cb(Qcur, "Qcur", il);
+
+ Kcur = ggml_rope_custom(
+ ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos,
+ n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
+ ext_factor, attn_factor, beta_fast, beta_slow
+ );
+ cb(Kcur, "Kcur", il);
+
+ cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
+ model.layers[il].wo, model.layers[il].bo,
+ Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+ }
+
+ if (il == n_layer - 1) {
+ // skip computing output for unused tokens
+ struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+ n_tokens = n_outputs;
+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+ inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+ }
+
+ struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+ cb(ffn_inp, "ffn_inp", il);
+
+ // MoE branch
+ cur = llm_build_norm(ctx0, ffn_inp, hparams,
+ model.layers[il].ffn_norm, NULL,
+ LLM_NORM_RMS, cb, il);
+ cb(cur, "ffn_norm", il);
+
+ ggml_tensor * moe_out = build_moe_ffn(cur, n_tokens, LLM_FFN_SILU, false, il);
+
+ // FFN shared expert
+ {
+ ggml_tensor * cur_gate_inp = ggml_mul_mat(ctx0, model.layers[il].ffn_gate_inp_shexp, cur);
+ cb(cur_gate_inp, "ffn_shexp_gate_inp", il);
+
+ // sigmoid
+ ggml_tensor * cur_gate = ggml_div(ctx0, ggml_silu(ctx0, cur_gate_inp), cur_gate_inp);
+ cb(cur_gate, "ffn_shexp_gate", il);
+
+ ggml_tensor * cur_ffn = llm_build_ffn(ctx0, cur,
+ model.layers[il].ffn_up_shexp, NULL,
+ model.layers[il].ffn_gate_shexp, NULL,
+ model.layers[il].ffn_down_shexp, NULL,
+ NULL,
+ LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
+ cb(cur_ffn, "ffn_shexp", il);
+
+ ggml_tensor * ffn_shexp_out = ggml_mul(ctx0, cur_ffn, cur_gate);
+ cb(ffn_shexp_out, "ffn_shexp_out", il);
+
+ moe_out = ggml_add(ctx0, moe_out, ffn_shexp_out);
+ cb(moe_out, "ffn_out", il);
+
+ cur = moe_out;
+ }
+
+ cur = ggml_add(ctx0, cur, ffn_inp);
+ cb(cur, "l_out", il);
+
+ // input for next layer
+ inpL = cur;
+ }
+
+ cur = inpL;
+
+ cur = llm_build_norm(ctx0, cur, hparams,
+ model.output_norm, NULL,
+ LLM_NORM_RMS, cb, -1);
+ cb(cur, "result_norm", -1);
+
+ // lm_head
+ cur = ggml_mul_mat(ctx0, model.output, cur);
+ cb(cur, "result_output", -1);
+
+ ggml_build_forward_expand(gf, cur);
+
+ return gf;
+ }
+
struct ggml_cgraph * build_phi2() {
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
@@ -9917,6 +10146,10 @@ static struct ggml_cgraph * llama_build_graph(
{
result = llm.build_qwen2();
} break;
+ case LLM_ARCH_QWEN2MOE:
+ {
+ result = llm.build_qwen2moe();
+ } break;
case LLM_ARCH_PHI2:
{
result = llm.build_phi2();
@@ -14834,6 +15067,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
case LLM_ARCH_STABLELM:
case LLM_ARCH_QWEN:
case LLM_ARCH_QWEN2:
+ case LLM_ARCH_QWEN2MOE:
case LLM_ARCH_PHI2:
case LLM_ARCH_GEMMA:
case LLM_ARCH_STARCODER2: