author     Kawrakow <iwankawrakow@gmail.com>           2025-04-10 09:05:21 +0200
committer  GitHub <noreply@github.com>                 2025-04-10 09:05:21 +0200
commit     474435f58b6a26bc549589966482207fee94aa60 (patch)
tree       6eb8ab42a565111cc46b1be6e79c450826b01b07 /src/llama.cpp
parent     5f44f4b3d006a24267ea02fe65490bb760a01447 (diff)
Llama-4 support (text only) (#321)
* llama4: WIP
* llama4: this seems to be working
---------
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
Diffstat (limited to 'src/llama.cpp')
-rw-r--r--  src/llama.cpp  288
1 file changed, 268 insertions, 20 deletions
diff --git a/src/llama.cpp b/src/llama.cpp index 5f4642fb..5565ffcd 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -184,6 +184,7 @@ static std::string format(const char * fmt, ...) { enum llm_arch { LLM_ARCH_LLAMA, + LLM_ARCH_LLAMA4, LLM_ARCH_FALCON, LLM_ARCH_BAICHUAN, LLM_ARCH_GROK, @@ -232,6 +233,7 @@ enum llm_arch { static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = { { LLM_ARCH_LLAMA, "llama" }, + { LLM_ARCH_LLAMA4, "llama4" }, { LLM_ARCH_FALCON, "falcon" }, { LLM_ARCH_GROK, "grok" }, { LLM_ARCH_GPT2, "gpt2" }, @@ -319,6 +321,8 @@ enum llm_kv { LLM_KV_TIME_DECAY_EXTRA_DIM, LLM_KV_RESIDUAL_SCALE, LLM_KV_EMBEDDING_SCALE, + LLM_KV_TOKEN_SHIFT_COUNT, + LLM_KV_INTERLEAVE_MOE_LAYER_STEP, LLM_KV_ATTENTION_HEAD_COUNT, LLM_KV_ATTENTION_HEAD_COUNT_KV, @@ -422,6 +426,8 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = { { LLM_KV_FINAL_LOGIT_SOFTCAPPING, "%s.final_logit_softcapping" }, { LLM_KV_RESIDUAL_SCALE, "%s.residual_scale" }, { LLM_KV_EMBEDDING_SCALE, "%s.embedding_scale" }, + { LLM_KV_TOKEN_SHIFT_COUNT, "%s.token_shift_count" }, + { LLM_KV_INTERLEAVE_MOE_LAYER_STEP, "%s.interleave_moe_layer_step" }, { LLM_KV_ATTENTION_HEAD_COUNT, "%s.attention.head_count" }, { LLM_KV_ATTENTION_HEAD_COUNT_KV, "%s.attention.head_count_kv" }, @@ -615,6 +621,35 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA }, }, { + LLM_ARCH_LLAMA4, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" }, + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_GATE_EXP, "blk.%d.ffn_gate.%d" }, + { LLM_TENSOR_FFN_DOWN_EXP, "blk.%d.ffn_down.%d" }, + { LLM_TENSOR_FFN_UP_EXP, "blk.%d.ffn_up.%d" }, + { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + { LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" }, + { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" }, + { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" }, + }, + }, + { LLM_ARCH_BAICHUAN, { { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, @@ -1432,6 +1467,7 @@ enum llm_chat_template { LLM_CHAT_TEMPLATE_GRANITE, LLM_CHAT_TEMPLATE_GIGACHAT, LLM_CHAT_TEMPLATE_MEGREZ, + LLM_CHAT_TEMPLATE_LLAMA4, LLM_CHAT_TEMPLATE_UNKNOWN, }; @@ -1467,6 +1503,7 @@ static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = { { "granite", LLM_CHAT_TEMPLATE_GRANITE }, { "gigachat", LLM_CHAT_TEMPLATE_GIGACHAT }, { "megrez", LLM_CHAT_TEMPLATE_MEGREZ }, + { "llama4", LLM_CHAT_TEMPLATE_LLAMA4 }, }; @@ -2357,6 +2394,8 @@ enum e_model { MODEL_10B_128x3_66B, MODEL_57B_A14B, MODEL_27B, + MODEL_17B_16E, + MODEL_17B_128E, }; static const size_t kiB = 1024; @@ -2444,6 +2483,14 @@ struct llama_hparams { bool use_alibi = false; bool attn_soft_cap = false; + uint32_t n_moe_layer_step = 0; + bool use_kq_norm = true; + uint32_t n_attn_chunk = 0; + // values below seems to be fixed on llama4 + uint32_t n_no_rope_layer_step = 4; + uint32_t 
n_attn_temp_floor_scale = 8192; + float f_attn_temp_scale = 0.1; + // needed by encoder-decoder models (e.g. T5, FLAN-T5) // ref: https://github.com/ggerganov/llama.cpp/pull/8141 llama_token dec_start_token_id = -1; @@ -2939,6 +2986,8 @@ struct llama_context { struct llama_kv_cache kv_self; struct llama_control_vector cvec; + std::vector<float> scale_data; + std::unordered_map<struct llama_lora_adapter *, float> lora_adapters; std::vector<ggml_backend_t> backends; @@ -3015,6 +3064,7 @@ struct llama_context { struct ggml_tensor * inp_pos_bucket; // I32 [n_batch|n_kv, n_batch] struct ggml_tensor * inp_embd_enc; // F32 [n_embd, n_outputs_enc] struct ggml_tensor * inp_KQ_mask_cross; // F32 [n_outputs_enc, n_batch] + struct ggml_tensor * inp_scale = nullptr; // F32 [n_tokens] }; struct llama_lora_weight { @@ -4945,6 +4995,8 @@ static const char * llama_model_type_name(e_model type) { case MODEL_10B_128x3_66B: return "10B+128x3.66B"; case MODEL_57B_A14B: return "57B.A14B"; case MODEL_27B: return "27B"; + case MODEL_17B_16E: return "17Bx16E (Scout)"; + case MODEL_17B_128E: return "17Bx128E (Maverick)"; default: return "?B"; } } @@ -5106,6 +5158,25 @@ static void llm_load_hparams( } } } break; + case LLM_ARCH_LLAMA4: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_INTERLEAVE_MOE_LAYER_STEP, hparams.n_moe_layer_step); + hparams.n_swa_pattern = 4; // pattern: 3 chunked - 1 full + hparams.n_attn_chunk = 8192; // should this be a gguf kv? currently it's the same for Scout and Maverick + hparams.n_swa = 1; // TODO @ngxson : this is added to trigger the SWA branch (we store the chunked attn mask in the SWA tensor), will need to clean this up later + + switch (hparams.n_expert) { + case 16: model.type = MODEL_17B_16E; break; + case 128: model.type = MODEL_17B_128E; break; + default: model.type = MODEL_UNKNOWN; + } + + if (model.type == MODEL_17B_128E) { + hparams.use_kq_norm = false; + } + } break; case LLM_ARCH_MINICPM: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); @@ -5929,6 +6000,23 @@ static void llm_load_vocab( } else if ( tokenizer_pre == "codeshell") { vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_CODESHELL; + } else if ( + tokenizer_pre == "gpt-4o" || + tokenizer_pre == "llama4") { + vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_GPT4O; + vocab.tokenizer_clean_spaces = false; + } else if ( + tokenizer_pre == "superbpe") { + vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_SUPERBPE; + vocab.tokenizer_clean_spaces = false; + } else if ( + tokenizer_pre == "trillion") { + vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_TRILLION; + vocab.tokenizer_clean_spaces = false; + } else if ( + tokenizer_pre == "bailingmoe") { + vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_BAILINGMOE; + vocab.tokenizer_clean_spaces = false; } else { throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str())); } @@ -6552,6 +6640,7 @@ static bool llm_load_tensors( const int64_t n_embd_gqa = n_embd_v_gqa; const int64_t n_vocab = hparams.n_vocab; const int64_t n_vocab_type = hparams.n_vocab_type; + const int64_t n_rot = hparams.n_rot; const int64_t n_expert = hparams.n_expert; const int64_t n_expert_used = hparams.n_expert_used; const int64_t n_ctx_train = hparams.n_ctx_train; @@ -6670,6 +6759,61 @@ static bool llm_load_tensors( } } } break; + case LLM_ARCH_LLAMA4: + { + model.tok_embd = create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + model.output = 
create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, + llama_model_loader::TENSOR_NOT_REQUIRED); + + // output + model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + + // if output is NULL, init from the input tok embed + if (model.output == NULL) { + model.output = create_tensor(ctx_output_split, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, + llama_model_loader::TENSOR_DUPLICATED); + } + + GGML_ASSERT(hparams.n_moe_layer_step > 0 && "Llama 4 requires n_moe_layer_step > 0"); + for (int i = 0; i < n_layer; ++i) { + bool is_moe_layer = (i + 1) % hparams.n_moe_layer_step == 0; + ggml_context * ctx_layer = ctx_for_layer(i); + ggml_context * ctx_split = ctx_for_layer_split(i); + + auto & layer = model.layers[i]; + + layer.attn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.wq = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); + layer.wk = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); + layer.wv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); + layer.wo = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + layer.ffn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.rope_freqs = create_tensor(ctx_layer, tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, + llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0)); + + if (is_moe_layer) { + int n_ff_exp = hparams.n_ff_exp; + + layer.ffn_gate_inp = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + layer.ffn_gate_exps = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, 0); + layer.ffn_down_exps = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, 0); + + // Shared expert + const int64_t n_ff_shexp = n_ff_exp; + layer.ffn_gate_shexp = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), { n_embd, n_ff_shexp}, 0); + layer.ffn_down_shexp = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp, n_embd }, 0); + layer.ffn_up_shexp = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), { n_embd, n_ff_shexp}, 0); + } else { + layer.ffn_gate = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } + } + } break; case LLM_ARCH_GROK: { if (n_expert == 0) { @@ -8623,8 +8767,6 @@ static void llm_build_kv_store( // note: storing RoPE-ed version of K in the KV cache ggml_build_forward_expand(graph, ggml_cpy(ctx, k_cur, k_cache_view)); - assert(v_cur->ne[0] == n_embd_v_gqa && v_cur->ne[1] == n_tokens); - struct ggml_tensor * v_cache_view = nullptr; if (cparams.flash_attn) { @@ -8884,6 +9026,7 @@ llm_expert_gating_func_type gating_op, int il) { int64_t n_embd = cur->ne[0]; int64_t n_tokens = cur->ne[1]; + bool weight_before_ffn = lctx.model.arch == LLM_ARCH_LLAMA4; // for llama4, we apply the sigmoid-ed weights before the FFN 
ggml_tensor * logits = llm_build_lora_mm(lctx, ctx, gate_inp, cur); // [n_expert, n_tokens] cb(logits, "ffn_moe_logits", il); @@ -8912,6 +9055,12 @@ llm_expert_gating_func_type gating_op, cb(selection_probs, "ffn_moe_probs_biased", il); } + // llama4 doesn't have exp_probs_b, and sigmoid is only used after top_k + // see: https://github.com/meta-llama/llama-models/blob/699a02993512fb36936b1b0741e13c06790bcf98/models/llama4/moe.py#L183-L198 + if (lctx.model.arch == LLM_ARCH_LLAMA4) { + selection_probs = logits; + } + // select experts ggml_tensor * selected_experts = ggml_top_k_thresh(ctx, selection_probs, n_expert_used, lctx.cparams.min_experts, lctx.cparams.thresh_experts); // [n_expert_used, n_tokens] @@ -8940,6 +9089,14 @@ llm_expert_gating_func_type gating_op, cur = ggml_reshape_3d(ctx, cur, n_embd, 1, n_tokens); + if (weight_before_ffn) { + // TODO: this is a workaround as we don't yet have a repeat op that takes custom dim (ggml_repeat_4d) + ggml_tensor * repeated = ggml_new_tensor_3d(ctx, cur->type, n_embd, n_expert_used, n_tokens); + repeated = ggml_repeat(ctx, cur, repeated); // [n_embd, n_expert_used, n_tokens] + cur = ggml_mul(ctx, repeated, weights); + cb(cur, "ffn_moe_weighted", il); + } + ggml_tensor * par; if (lctx.cparams.fused_moe_up_gate) { par = ggml_moe_up_gate(ctx, up_exps, gate_exps, cur, selected_experts, type_op == LLM_FFN_SILU ? GGML_UNARY_OP_SILU : GGML_UNARY_OP_GELU); @@ -8958,7 +9115,10 @@ llm_expert_gating_func_type gating_op, ggml_tensor * experts = llm_build_lora_mm_id(lctx, ctx, down_exps, par, selected_experts); // [n_embd, n_expert_used, n_tokens] cb(experts, "ffn_moe_down", il); - experts = ggml_mul(ctx, experts, weights); + if (!weight_before_ffn) { + experts = ggml_mul(ctx, experts, weights); + cb(cur, "ffn_moe_weighted", il); + } if (n_expert_used == 1) { return ggml_cont(ctx, ggml_view_2d(ctx, experts, n_embd, n_tokens, experts->nb[2], 0)); @@ -9487,6 +9647,14 @@ struct llm_build_context { return lctx.inp_pos; } + struct ggml_tensor * build_inpup_scale(int n_tokens) { + int n_pos_per_token = 1; + lctx.inp_scale = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 1, 1, n_tokens*n_pos_per_token); + cb(lctx.inp_scale, "inp_scale", -1); + ggml_set_input(lctx.inp_scale); + return lctx.inp_scale; + } + struct ggml_tensor * build_rope_factors(int il) { // choose long/short freq factors based on the context size const auto n_ctx_pre_seq = cparams.n_ctx / cparams.n_seq_max; @@ -9667,14 +9835,19 @@ struct llm_build_context { GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); GGML_ASSERT(n_embd_head == hparams.n_rot); - struct ggml_tensor * cur; - struct ggml_tensor * inpL; + ggml_tensor * cur; + ggml_tensor * inpL; + ggml_tensor * inp_attn_scale = nullptr; inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); // inp_pos - contains the positions struct ggml_tensor * inp_pos = build_inp_pos(); + if (model.arch == LLM_ARCH_LLAMA4) { + inp_attn_scale = build_inpup_scale(n_tokens); + } + // KQ_mask (mask for 1 head, it will be broadcasted to all heads) struct ggml_tensor * KQ_mask = build_inp_KQ_mask(); @@ -9683,6 +9856,8 @@ struct llm_build_context { for (int il = 0; il < n_layer; ++il) { struct ggml_tensor * inpSA = inpL; + bool use_rope = model.arch == LLM_ARCH_LLAMA4 ? 
(il + 1) % hparams.n_no_rope_layer_step != 0 : true; + // norm cur = llm_build_norm(ctx0, inpL, hparams, model.layers[il].attn_norm, NULL, @@ -9720,19 +9895,29 @@ struct llm_build_context { cb(Vcur, "Vcur", il); } - Qcur = ggml_rope_ext( - ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, rope_factors, - n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, - ext_factor, attn_factor, beta_fast, beta_slow - ); - cb(Qcur, "Qcur", il); + if (use_rope) { + Qcur = ggml_rope_ext(ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); - Kcur = ggml_rope_ext( - ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, rope_factors, - n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, - ext_factor, attn_factor, beta_fast, beta_slow - ); + Kcur = ggml_rope_ext(ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + } else if (inp_attn_scale) { + Qcur = ggml_mul(ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_attn_scale); + } + + cb(Qcur, "Qcur", il); cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + if (model.arch == LLM_ARCH_LLAMA4 && use_rope && hparams.use_kq_norm) { + // Llama4TextL2Norm + Qcur = ggml_rms_norm(ctx0, Qcur, 1e-6); + Kcur = ggml_rms_norm(ctx0, Kcur, 1e-6); + cb(Qcur, "Qcur_normed", il); + cb(Kcur, "Kcur_normed", il); + } cur = llm_build_kv(ctx0, lctx, kv_self, gf, model.layers[il].wo, model.layers[il].bo, @@ -9758,6 +9943,7 @@ struct llm_build_context { // feed-forward network if (model.layers[il].ffn_gate_inp == nullptr) { + // non-MoE cur = llm_build_norm(ctx0, ffn_inp, hparams, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, cb, il); @@ -9770,6 +9956,37 @@ struct llm_build_context { NULL, LLM_FFN_SILU, LLM_FFN_PAR, cb, il); cb(cur, "ffn_out", il); + } else if (model.arch == LLM_ARCH_LLAMA4) { + // llama4 MoE + ggml_tensor * ffn_inp_normed = llm_build_norm(ctx0, ffn_inp, hparams, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "ffn_norm", il); + + ggml_tensor * moe_out = llm_build_moe_ffn(ctx0, lctx, ffn_inp_normed, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + nullptr, + n_expert, n_expert_used, + LLM_FFN_SILU, false, + false, 0.0, + LLM_EXPERT_GATING_FUNC_SIGMOID, + cb, il); + + // Shared experts + ggml_tensor * shexp_out = llm_build_ffn(ctx0, lctx, ffn_inp_normed, + model.layers[il].ffn_up_shexp, NULL, NULL, + model.layers[il].ffn_gate_shexp, NULL, NULL, + model.layers[il].ffn_down_shexp, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, cb, il); + cb(shexp_out, "ffn_moe_shexp", il); + + cur = ggml_add(ctx0, moe_out, shexp_out); + cb(cur, "ffn_moe_out_merged", il); + } else { // MoE branch cur = llm_build_norm(ctx0, ffn_inp, hparams, @@ -9782,11 +9999,11 @@ struct llm_build_context { model.layers[il].ffn_up_exps, model.layers[il].ffn_gate_exps, model.layers[il].ffn_down_exps, - nullptr, + nullptr, n_expert, n_expert_used, LLM_FFN_SILU, true, false, 0.0, - LLM_EXPERT_GATING_FUNC_SOFTMAX, + LLM_EXPERT_GATING_FUNC_SOFTMAX, cb, il); cb(cur, "ffn_moe_out", il); } @@ -15162,6 +15379,7 @@ static struct ggml_cgraph * llama_build_graph( switch (model.arch) { case LLM_ARCH_LLAMA: + case LLM_ARCH_LLAMA4: case LLM_ARCH_GRANITE: case 
LLM_ARCH_GRANITE_MOE: { @@ -15423,6 +15641,17 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { ggml_backend_tensor_set(lctx.inp_pos, batch.pos, 0, n_tokens*ggml_element_size(lctx.inp_pos)); } + if (lctx.inp_pos && lctx.inp_scale) { + int n_tokens = batch.n_tokens; + GGML_ASSERT(ggml_nelements(lctx.inp_scale) >= n_tokens); + if (int(lctx.scale_data.size()) < n_tokens) lctx.scale_data.resize(n_tokens); + int n_pos_per_token = 1; + for (int i = 0; i < n_tokens; ++i) { + lctx.scale_data[i] = std::log(std::floor((batch.pos[i] + 1.0f) / hparams.n_attn_temp_floor_scale) + 1.0f) * hparams.f_attn_temp_scale + 1.0f; + } + ggml_backend_tensor_set(lctx.inp_scale, lctx.scale_data.data(), 0, n_tokens*n_pos_per_token*ggml_element_size(lctx.inp_scale)); + } + if (hparams.causal_attn || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) { GGML_ASSERT(lctx.inp_out_ids && "every model that can must skip unused outputs"); const int64_t n_tokens = batch.n_tokens; @@ -15504,8 +15733,15 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { // may need to cut off old tokens for sliding window if (data_swa) { - if (pos - lctx.kv_self.cells[i].pos >= (int32_t)hparams.n_swa) { - f = -INFINITY; + if (hparams.n_attn_chunk) { + llama_pos pos_chunk_start = (pos / hparams.n_attn_chunk) * hparams.n_attn_chunk; + if (lctx.kv_self.cells[i].pos < pos_chunk_start || pos < pos_chunk_start) { + f = -INFINITY; + } + } else { + if (pos - kv_self.cells[i].pos >= (int32_t)hparams.n_swa) { + f = -INFINITY; + } } data_swa[h*(n_kv*n_tokens) + j*n_kv + i] = f; } @@ -18949,6 +19185,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) { // use what we call a normal RoPE, operating on pairs of consecutive head values case LLM_ARCH_LLAMA: + case LLM_ARCH_LLAMA4: case LLM_ARCH_BAICHUAN: case LLM_ARCH_STARCODER: case LLM_ARCH_PLAMO: @@ -20774,6 +21011,8 @@ static llm_chat_template llama_chat_detect_template(const std::string & tmpl) { return LLM_CHAT_TEMPLATE_GIGACHAT; } else if (tmpl_contains("<|role_start|>")) { return LLM_CHAT_TEMPLATE_MEGREZ; + } else if (tmpl_contains("<|header_start|>") && tmpl_contains("<|header_end|>")) { + return LLM_CHAT_TEMPLATE_LLAMA4; } return LLM_CHAT_TEMPLATE_UNKNOWN; } @@ -21155,6 +21394,15 @@ static int32_t llama_chat_apply_template_internal( if (add_ass) { ss << "<|role_start|>assistant<|role_end|>"; } + } else if (tmpl == LLM_CHAT_TEMPLATE_LLAMA4) { + // Llama 4 + for (auto message : chat) { + std::string role(message->role); + ss << "<|header_start|>" << role << "<|header_end|>\n\n" << trim(message->content) << "<|eot|>"; + } + if (add_ass) { + ss << "<|header_start|>assistant<|header_end|>\n\n"; + } } else { // template not supported return -1; |
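The hunks above wire up two per-layer patterns for Llama-4: in llm_load_tensors() every n_moe_layer_step-th layer gets expert tensors instead of a dense FFN, and in the graph build every n_no_rope_layer_step-th layer (fixed to 4 in this patch) skips RoPE. The following standalone sketch only replays those two modulo tests outside of ggml; the layer count and the interleave step are illustrative placeholders, not values read from a GGUF.

    #include <cstdio>

    int main() {
        const int n_layer              = 12; // illustrative layer count
        const int n_moe_layer_step     = 2;  // LLM_KV_INTERLEAVE_MOE_LAYER_STEP (illustrative value)
        const int n_no_rope_layer_step = 4;  // fixed to 4 in the patch

        for (int il = 0; il < n_layer; ++il) {
            const bool is_moe_layer = (il + 1) % n_moe_layer_step     == 0; // expert FFN vs dense FFN
            const bool use_rope     = (il + 1) % n_no_rope_layer_step != 0; // every 4th layer skips RoPE
            std::printf("layer %2d: %s FFN, %s\n", il,
                        is_moe_layer ? "MoE  " : "dense",
                        use_rope     ? "RoPE" : "no RoPE (attention-temperature scaling instead)");
        }
        return 0;
    }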
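On the no-RoPE layers, Qcur is multiplied by inp_scale, which llama_set_inputs() fills with a per-position attention-temperature factor built from n_attn_temp_floor_scale = 8192 and f_attn_temp_scale = 0.1 in the hparams. A minimal sketch of that formula, detached from llama_context and batch handling:

    #include <cmath>
    #include <cstdio>
    #include <initializer_list>

    // mirrors: log(floor((pos + 1) / 8192) + 1) * 0.1 + 1
    static float attn_temp_scale(int pos,
                                 int   n_attn_temp_floor_scale = 8192,
                                 float f_attn_temp_scale       = 0.1f) {
        return std::log(std::floor((pos + 1.0f) / n_attn_temp_floor_scale) + 1.0f)
               * f_attn_temp_scale + 1.0f;
    }

    int main() {
        // the scale is exactly 1.0 for positions 0..8190 and then grows logarithmically
        for (int pos : {0, 8190, 8191, 16384, 65536, 262144}) {
            std::printf("pos %7d -> scale %.4f\n", pos, attn_temp_scale(pos));
        }
        return 0;
    }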
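The SWA branch of the KQ mask is reused to implement chunked attention: when n_attn_chunk is set (8192 here, per the TODO in llm_load_hparams the chunked mask is stored in the SWA tensor), a key is masked out unless it lies in the same chunk as the query position. A small standalone sketch of that test; the usual causal mask still applies on top of it:

    #include <cstdio>

    static bool can_attend_chunked(int pos, int kv_pos, int n_attn_chunk = 8192) {
        const int pos_chunk_start = (pos / n_attn_chunk) * n_attn_chunk;
        return kv_pos >= pos_chunk_start;   // otherwise the mask entry becomes -INFINITY
    }

    int main() {
        // with n_attn_chunk = 8192, a query at position 10000 sees keys 8192..10000 but not 0..8191
        std::printf("%d\n", can_attend_chunked(10000,  500));  // 0 (masked out)
        std::printf("%d\n", can_attend_chunked(10000, 9000));  // 1
        std::printf("%d\n", can_attend_chunked( 4000, 1000));  // 1 (both in the first chunk)
        return 0;
    }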
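In llm_build_moe_ffn(), the Llama-4 path differs from the generic MoE branch in the ways visible above: top-k selection runs on the raw router logits (selection_probs = logits), the gate uses LLM_EXPERT_GATING_FUNC_SIGMOID, and the gate is applied to the expert input rather than to its output (weight_before_ffn). The sketch below replays that routing with plain float vectors standing in for ggml tensors; n_expert and n_expert_used are illustrative values (Llama-4 checkpoints reportedly route each token to a single expert plus the shared expert, but those counts come from the GGUF, not from this patch).

    #include <algorithm>
    #include <cmath>
    #include <cstdio>
    #include <vector>

    int main() {
        const std::vector<float> logits = {0.2f, -1.3f, 2.1f, 0.7f}; // router logits for one token (illustrative)
        const int n_expert_used = 1;                                  // illustrative top-k

        // rank experts by raw logit: no softmax and no expert bias before selection
        std::vector<int> idx(logits.size());
        for (size_t i = 0; i < idx.size(); ++i) idx[i] = (int) i;
        std::partial_sort(idx.begin(), idx.begin() + n_expert_used, idx.end(),
                          [&](int a, int b) { return logits[a] > logits[b]; });

        for (int k = 0; k < n_expert_used; ++k) {
            const int   e = idx[k];
            const float w = 1.0f / (1.0f + std::exp(-logits[e])); // sigmoid gate of the selected logit
            // the real graph computes expert_ffn_e(w * x) here (weight_before_ffn) and then
            // adds the shared-expert FFN output, with no further weighting of the expert output
            std::printf("expert %d selected, gate = %.3f (applied to the FFN input)\n", e, w);
        }
        return 0;
    }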
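The new LLM_CHAT_TEMPLATE_LLAMA4 branch formats each message as <|header_start|>role<|header_end|>, a blank line, the trimmed content, and <|eot|>, optionally followed by an empty assistant header. A self-contained sketch of that loop; plain structs stand in for llama_chat_message, and trim() is a local stand-in for the helper used in llama.cpp:

    #include <cstdio>
    #include <string>
    #include <vector>

    struct msg { std::string role, content; };

    static std::string trim(const std::string & s) {
        const auto b = s.find_first_not_of(" \t\r\n");
        if (b == std::string::npos) return "";
        const auto e = s.find_last_not_of(" \t\r\n");
        return s.substr(b, e - b + 1);
    }

    int main() {
        const std::vector<msg> chat = {
            {"system", "You are a helpful assistant."},
            {"user",   "Hello!"},
        };
        const bool add_ass = true; // append the assistant header so generation continues from it

        std::string ss;
        for (const auto & m : chat) {
            ss += "<|header_start|>" + m.role + "<|header_end|>\n\n" + trim(m.content) + "<|eot|>";
        }
        if (add_ass) {
            ss += "<|header_start|>assistant<|header_end|>\n\n";
        }
        std::printf("%s", ss.c_str());
        return 0;
    }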