-rw-r--r--  include/llama.h      |   7 +-
-rw-r--r--  src/llama-vocab.cpp  |  21 ++
-rw-r--r--  src/llama.cpp        | 288 +++-
 3 files changed, 295 insertions(+), 21 deletions(-)
diff --git a/include/llama.h b/include/llama.h
index 3275857a..d7376d7d 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -100,7 +100,12 @@ extern "C" {
LLAMA_VOCAB_PRE_TYPE_TEKKEN = 20,
LLAMA_VOCAB_PRE_TYPE_SMOLLM = 21,
LLAMA_VOCAB_PRE_TYPE_CODESHELL = 22,
- LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM = 23, //llama.cpp lists this as 28
+ LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM = 28, //llama.cpp lists this as 28
+ LLAMA_VOCAB_PRE_TYPE_GPT4O = 29,
+ LLAMA_VOCAB_PRE_TYPE_SUPERBPE = 30,
+ LLAMA_VOCAB_PRE_TYPE_TRILLION = 31,
+ LLAMA_VOCAB_PRE_TYPE_BAILINGMOE = 32,
+ LLAMA_VOCAB_PRE_TYPE_LLAMA4 = 33,
};
// note: these values should be synchronized with ggml_rope
diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp
index 4bd5aa81..09399417 100644
--- a/src/llama-vocab.cpp
+++ b/src/llama-vocab.cpp
@@ -439,6 +439,27 @@ struct llm_tokenizer_bpe {
"[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))*((?=[\\p{L}])([^A-Z]))+|[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))+((?=[\\p{L}])([^A-Z]))*|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
};
break;
+ case LLAMA_VOCAB_PRE_TYPE_GPT4O:
+ regex_exprs = {
+ // original regex from tokenizer.json
+ // "[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]*[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?|[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]+[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
+ "[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))*((?=[\\p{L}])([^A-Z]))+(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?|[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))+((?=[\\p{L}])([^A-Z]))*(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
+ };
+ break;
+ case LLAMA_VOCAB_PRE_TYPE_SUPERBPE:
+ regex_exprs = {
+ "\\p{N}+",
+ "(?=(\\d{3})+(?!\\d))",
+ };
+ break;
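
Note: the second SUPERBPE expression is a zero-width lookahead: it matches at every position inside a digit run where the remaining digits form complete groups of three, so a run like 1234567 splits into 1 / 234 / 567. A minimal standalone sketch of that splitting (hypothetical helper using std::regex, not the regex engine bundled with llama.cpp):

#include <iostream>
#include <regex>
#include <string>
#include <vector>

// split a digit run at every position where the remaining digits
// form complete groups of three, mirroring "(?=(\d{3})+(?!\d))"
static std::vector<std::string> split_digit_run(const std::string & s) {
    static const std::regex re("(?=(\\d{3})+(?!\\d))");
    std::vector<std::string> out;
    size_t prev = 0;
    for (auto it = std::sregex_iterator(s.begin(), s.end(), re); it != std::sregex_iterator(); ++it) {
        const size_t p = (size_t) it->position();
        if (p > prev) {
            out.push_back(s.substr(prev, p - prev));
        }
        prev = p;
    }
    out.push_back(s.substr(prev));
    return out;
}

int main() {
    for (const auto & piece : split_digit_run("1234567")) {
        std::cout << piece << '\n'; // prints: 1, 234, 567
    }
}
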
+ case LLAMA_VOCAB_PRE_TYPE_BAILINGMOE:
+ regex_exprs = {
+ // original regex from tokenizer.json
+ // "'(?i:[sdmt]|ll|ve|re)|[^\\r\\n\\p{L}\\p{N}]?+\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]++[\\r\\n]*|\\s*[\\r\\n]|\\s+(?!\\S)|\\s+"
+ // FIXME? Changed possessive quantifiers (?+ and ++) to greedy to avoid errors and imatrix hanging (tried atomic grouping but it's not supported?)
+ "'(?:[sSdDmMtT]|[lL][lL]|[vV][eE]|[rR][eE])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]|\\s+(?!\\S)|\\s+",
+ };
+ break;
default:
// default regex for BPE tokenization pre-processing
regex_exprs = {
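
Note on the GPT4O and BAILINGMOE rewrites above: both replace the (?i:...) groups of the original tokenizer.json patterns with explicit two-character bracket classes (e.g. 're becomes '[rR][eE]) because the bundled regex engine has no inline case-insensitive groups. The transformation is mechanical; a sketch of it (hypothetical helper, not part of the patch):

#include <cctype>
#include <string>

// expand an ASCII literal into per-letter bracket classes so it
// matches case-insensitively without (?i:...) support,
// e.g. "'re" -> "'[rR][eE]"
static std::string expand_case_insensitive(const std::string & lit) {
    std::string out;
    for (unsigned char c : lit) {
        if (std::isalpha(c)) {
            out += '[';
            out += (char) std::tolower(c);
            out += (char) std::toupper(c);
            out += ']';
        } else {
            out += (char) c;
        }
    }
    return out;
}
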
diff --git a/src/llama.cpp b/src/llama.cpp
index 5f4642fb..5565ffcd 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -184,6 +184,7 @@ static std::string format(const char * fmt, ...) {
enum llm_arch {
LLM_ARCH_LLAMA,
+ LLM_ARCH_LLAMA4,
LLM_ARCH_FALCON,
LLM_ARCH_BAICHUAN,
LLM_ARCH_GROK,
@@ -232,6 +233,7 @@ enum llm_arch {
static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_LLAMA, "llama" },
+ { LLM_ARCH_LLAMA4, "llama4" },
{ LLM_ARCH_FALCON, "falcon" },
{ LLM_ARCH_GROK, "grok" },
{ LLM_ARCH_GPT2, "gpt2" },
@@ -319,6 +321,8 @@ enum llm_kv {
LLM_KV_TIME_DECAY_EXTRA_DIM,
LLM_KV_RESIDUAL_SCALE,
LLM_KV_EMBEDDING_SCALE,
+ LLM_KV_TOKEN_SHIFT_COUNT,
+ LLM_KV_INTERLEAVE_MOE_LAYER_STEP,
LLM_KV_ATTENTION_HEAD_COUNT,
LLM_KV_ATTENTION_HEAD_COUNT_KV,
@@ -422,6 +426,8 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
{ LLM_KV_FINAL_LOGIT_SOFTCAPPING, "%s.final_logit_softcapping" },
{ LLM_KV_RESIDUAL_SCALE, "%s.residual_scale" },
{ LLM_KV_EMBEDDING_SCALE, "%s.embedding_scale" },
+ { LLM_KV_TOKEN_SHIFT_COUNT, "%s.token_shift_count" },
+ { LLM_KV_INTERLEAVE_MOE_LAYER_STEP, "%s.interleave_moe_layer_step" },
{ LLM_KV_ATTENTION_HEAD_COUNT, "%s.attention.head_count" },
{ LLM_KV_ATTENTION_HEAD_COUNT_KV, "%s.attention.head_count_kv" },
@@ -615,6 +621,35 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
},
},
{
+ LLM_ARCH_LLAMA4,
+ {
+ { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+ { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
+ { LLM_TENSOR_OUTPUT, "output" },
+ { LLM_TENSOR_ROPE_FREQS, "rope_freqs" },
+ { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
+ { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
+ { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
+ { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
+ { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
+ { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" },
+ { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
+ { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
+ { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
+ { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
+ { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
+ { LLM_TENSOR_FFN_GATE_EXP, "blk.%d.ffn_gate.%d" },
+ { LLM_TENSOR_FFN_DOWN_EXP, "blk.%d.ffn_down.%d" },
+ { LLM_TENSOR_FFN_UP_EXP, "blk.%d.ffn_up.%d" },
+ { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
+ { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
+ { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
+ { LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" },
+ { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" },
+ { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
+ },
+ },
+ {
LLM_ARCH_BAICHUAN,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
@@ -1432,6 +1467,7 @@ enum llm_chat_template {
LLM_CHAT_TEMPLATE_GRANITE,
LLM_CHAT_TEMPLATE_GIGACHAT,
LLM_CHAT_TEMPLATE_MEGREZ,
+ LLM_CHAT_TEMPLATE_LLAMA4,
LLM_CHAT_TEMPLATE_UNKNOWN,
};
@@ -1467,6 +1503,7 @@ static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = {
{ "granite", LLM_CHAT_TEMPLATE_GRANITE },
{ "gigachat", LLM_CHAT_TEMPLATE_GIGACHAT },
{ "megrez", LLM_CHAT_TEMPLATE_MEGREZ },
+ { "llama4", LLM_CHAT_TEMPLATE_LLAMA4 },
};
@@ -2357,6 +2394,8 @@ enum e_model {
MODEL_10B_128x3_66B,
MODEL_57B_A14B,
MODEL_27B,
+ MODEL_17B_16E,
+ MODEL_17B_128E,
};
static const size_t kiB = 1024;
@@ -2444,6 +2483,14 @@ struct llama_hparams {
bool use_alibi = false;
bool attn_soft_cap = false;
+ uint32_t n_moe_layer_step = 0;
+ bool use_kq_norm = true;
+ uint32_t n_attn_chunk = 0;
+ // the values below seem to be fixed for llama4
+ uint32_t n_no_rope_layer_step = 4;
+ uint32_t n_attn_temp_floor_scale = 8192;
+ float f_attn_temp_scale = 0.1f;
+
// needed by encoder-decoder models (e.g. T5, FLAN-T5)
// ref: https://github.com/ggerganov/llama.cpp/pull/8141
llama_token dec_start_token_id = -1;
@@ -2939,6 +2986,8 @@ struct llama_context {
struct llama_kv_cache kv_self;
struct llama_control_vector cvec;
+ std::vector<float> scale_data;
+
std::unordered_map<struct llama_lora_adapter *, float> lora_adapters;
std::vector<ggml_backend_t> backends;
@@ -3015,6 +3064,7 @@ struct llama_context {
struct ggml_tensor * inp_pos_bucket; // I32 [n_batch|n_kv, n_batch]
struct ggml_tensor * inp_embd_enc; // F32 [n_embd, n_outputs_enc]
struct ggml_tensor * inp_KQ_mask_cross; // F32 [n_outputs_enc, n_batch]
+ struct ggml_tensor * inp_scale = nullptr; // F32 [n_tokens]
};
struct llama_lora_weight {
@@ -4945,6 +4995,8 @@ static const char * llama_model_type_name(e_model type) {
case MODEL_10B_128x3_66B: return "10B+128x3.66B";
case MODEL_57B_A14B: return "57B.A14B";
case MODEL_27B: return "27B";
+ case MODEL_17B_16E: return "17Bx16E (Scout)";
+ case MODEL_17B_128E: return "17Bx128E (Maverick)";
default: return "?B";
}
}
@@ -5106,6 +5158,25 @@ static void llm_load_hparams(
}
}
} break;
+ case LLM_ARCH_LLAMA4:
+ {
+ ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+ ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp);
+ ml.get_key(LLM_KV_INTERLEAVE_MOE_LAYER_STEP, hparams.n_moe_layer_step);
+ hparams.n_swa_pattern = 4; // pattern: 3 chunked - 1 full
+ hparams.n_attn_chunk = 8192; // should this be a gguf kv? currently it's the same for Scout and Maverick
+ hparams.n_swa = 1; // TODO @ngxson : this is added to trigger the SWA branch (we store the chunked attn mask in the SWA tensor), will need to clean this up later
+
+ switch (hparams.n_expert) {
+ case 16: model.type = MODEL_17B_16E; break;
+ case 128: model.type = MODEL_17B_128E; break;
+ default: model.type = MODEL_UNKNOWN;
+ }
+
+ if (model.type == MODEL_17B_128E) {
+ hparams.use_kq_norm = false;
+ }
+ } break;
case LLM_ARCH_MINICPM:
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
@@ -5929,6 +6000,23 @@ static void llm_load_vocab(
} else if (
tokenizer_pre == "codeshell") {
vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_CODESHELL;
+ } else if (
+ tokenizer_pre == "gpt-4o" ||
+ tokenizer_pre == "llama4") {
+ vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_GPT4O;
+ vocab.tokenizer_clean_spaces = false;
+ } else if (
+ tokenizer_pre == "superbpe") {
+ vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_SUPERBPE;
+ vocab.tokenizer_clean_spaces = false;
+ } else if (
+ tokenizer_pre == "trillion") {
+ vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_TRILLION;
+ vocab.tokenizer_clean_spaces = false;
+ } else if (
+ tokenizer_pre == "bailingmoe") {
+ vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_BAILINGMOE;
+ vocab.tokenizer_clean_spaces = false;
} else {
throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str()));
}
@@ -6552,6 +6640,7 @@ static bool llm_load_tensors(
const int64_t n_embd_gqa = n_embd_v_gqa;
const int64_t n_vocab = hparams.n_vocab;
const int64_t n_vocab_type = hparams.n_vocab_type;
+ const int64_t n_rot = hparams.n_rot;
const int64_t n_expert = hparams.n_expert;
const int64_t n_expert_used = hparams.n_expert_used;
const int64_t n_ctx_train = hparams.n_ctx_train;
@@ -6670,6 +6759,61 @@ static bool llm_load_tensors(
}
}
} break;
+ case LLM_ARCH_LLAMA4:
+ {
+ model.tok_embd = create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+ model.output = create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab},
+ llama_model_loader::TENSOR_NOT_REQUIRED);
+
+ // output
+ model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+
+ // if output is NULL, init from the input tok embed
+ if (model.output == NULL) {
+ model.output = create_tensor(ctx_output_split, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab},
+ llama_model_loader::TENSOR_DUPLICATED);
+ }
+
+ GGML_ASSERT(hparams.n_moe_layer_step > 0 && "Llama 4 requires n_moe_layer_step > 0");
+ for (int i = 0; i < n_layer; ++i) {
+ bool is_moe_layer = (i + 1) % hparams.n_moe_layer_step == 0;
+ ggml_context * ctx_layer = ctx_for_layer(i);
+ ggml_context * ctx_split = ctx_for_layer_split(i);
+
+ auto & layer = model.layers[i];
+
+ layer.attn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+
+ layer.wq = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
+ layer.wk = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0);
+ layer.wv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0);
+ layer.wo = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);
+
+ layer.ffn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+
+ layer.rope_freqs = create_tensor(ctx_layer, tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2},
+ llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
+
+ if (is_moe_layer) {
+ int n_ff_exp = hparams.n_ff_exp;
+
+ layer.ffn_gate_inp = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0);
+ layer.ffn_gate_exps = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, 0);
+ layer.ffn_down_exps = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0);
+ layer.ffn_up_exps = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, 0);
+
+ // Shared expert
+ const int64_t n_ff_shexp = n_ff_exp;
+ layer.ffn_gate_shexp = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), { n_embd, n_ff_shexp}, 0);
+ layer.ffn_down_shexp = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp, n_embd }, 0);
+ layer.ffn_up_shexp = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), { n_embd, n_ff_shexp}, 0);
+ } else {
+ layer.ffn_gate = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
+ layer.ffn_down = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
+ layer.ffn_up = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
+ }
+ }
+ } break;
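
Note: the (i + 1) % hparams.n_moe_layer_step == 0 test above interleaves dense and MoE layers, so with an interleave step of 2 the odd-indexed layers get expert tensors while the even-indexed ones get the plain ffn_gate/ffn_down/ffn_up set. A standalone illustration (values are illustrative, not read from a GGUF):

#include <cstdio>

int main() {
    const int n_layer = 8;          // illustrative
    const int n_moe_layer_step = 2; // illustrative interleave step
    for (int i = 0; i < n_layer; ++i) {
        const bool is_moe_layer = (i + 1) % n_moe_layer_step == 0;
        std::printf("layer %d: %s\n", i, is_moe_layer ? "MoE" : "dense");
    }
    // prints: layers 1, 3, 5, 7 -> MoE; layers 0, 2, 4, 6 -> dense
}
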
case LLM_ARCH_GROK:
{
if (n_expert == 0) {
@@ -8623,8 +8767,6 @@ static void llm_build_kv_store(
// note: storing RoPE-ed version of K in the KV cache
ggml_build_forward_expand(graph, ggml_cpy(ctx, k_cur, k_cache_view));
- assert(v_cur->ne[0] == n_embd_v_gqa && v_cur->ne[1] == n_tokens);
-
struct ggml_tensor * v_cache_view = nullptr;
if (cparams.flash_attn) {
@@ -8884,6 +9026,7 @@ llm_expert_gating_func_type gating_op,
int il) {
int64_t n_embd = cur->ne[0];
int64_t n_tokens = cur->ne[1];
+ bool weight_before_ffn = lctx.model.arch == LLM_ARCH_LLAMA4; // for llama4, we apply the sigmoid-ed weights before the FFN
ggml_tensor * logits = llm_build_lora_mm(lctx, ctx, gate_inp, cur); // [n_expert, n_tokens]
cb(logits, "ffn_moe_logits", il);
@@ -8912,6 +9055,12 @@ llm_expert_gating_func_type gating_op,
cb(selection_probs, "ffn_moe_probs_biased", il);
}
+ // llama4 doesn't have exp_probs_b, and sigmoid is only used after top_k
+ // see: https://github.com/meta-llama/llama-models/blob/699a02993512fb36936b1b0741e13c06790bcf98/models/llama4/moe.py#L183-L198
+ if (lctx.model.arch == LLM_ARCH_LLAMA4) {
+ selection_probs = logits;
+ }
+
// select experts
ggml_tensor * selected_experts = ggml_top_k_thresh(ctx, selection_probs, n_expert_used,
lctx.cparams.min_experts, lctx.cparams.thresh_experts); // [n_expert_used, n_tokens]
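
Note: together with the selection_probs = logits override above, llama4 selects the top-k experts on the raw router logits and applies the sigmoid only to the k selected scores, rather than gating before selection. A scalar sketch of that ordering (plain C++, independent of ggml):

#include <algorithm>
#include <cmath>
#include <numeric>
#include <utility>
#include <vector>

// top-k on raw logits first, sigmoid on the selected scores second
// (assumes k <= logits.size())
static std::vector<std::pair<int, float>> llama4_route(const std::vector<float> & logits, int k) {
    std::vector<int> idx(logits.size());
    std::iota(idx.begin(), idx.end(), 0);
    std::partial_sort(idx.begin(), idx.begin() + k, idx.end(),
                      [&](int a, int b) { return logits[a] > logits[b]; });
    std::vector<std::pair<int, float>> experts;
    for (int i = 0; i < k; ++i) {
        experts.push_back({idx[i], 1.0f / (1.0f + std::exp(-logits[idx[i]]))});
    }
    return experts; // (expert id, routing weight) pairs
}
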
@@ -8940,6 +9089,14 @@ llm_expert_gating_func_type gating_op,
cur = ggml_reshape_3d(ctx, cur, n_embd, 1, n_tokens);
+ if (weight_before_ffn) {
+ // TODO: this is a workaround as we don't yet have a repeat op that takes custom dim (ggml_repeat_4d)
+ ggml_tensor * repeated = ggml_new_tensor_3d(ctx, cur->type, n_embd, n_expert_used, n_tokens);
+ repeated = ggml_repeat(ctx, cur, repeated); // [n_embd, n_expert_used, n_tokens]
+ cur = ggml_mul(ctx, repeated, weights);
+ cb(cur, "ffn_moe_weighted", il);
+ }
+
ggml_tensor * par;
if (lctx.cparams.fused_moe_up_gate) {
par = ggml_moe_up_gate(ctx, up_exps, gate_exps, cur, selected_experts, type_op == LLM_FFN_SILU ? GGML_UNARY_OP_SILU : GGML_UNARY_OP_GELU);
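
Note: at this point cur is [n_embd, 1, n_tokens] (see the reshape above) and weights carries one scalar per (expert, token) pair, so the ggml_repeat broadcasts cur across the expert dimension before the elementwise multiply. The scalar math being computed is sketched below (hypothetical helper, flat row-major layout assumed):

#include <cstddef>
#include <vector>

// scalar equivalent of the repeat + mul above:
// out[t][e][:] = cur[t][:] * weights[t][e]
static std::vector<float> apply_weights_before_ffn(
        const std::vector<float> & cur,     // [n_tokens][n_embd]
        const std::vector<float> & weights, // [n_tokens][n_expert_used]
        int n_embd, int n_expert_used, int n_tokens) {
    std::vector<float> out((size_t) n_tokens*n_expert_used*n_embd);
    for (int t = 0; t < n_tokens; ++t) {
        for (int e = 0; e < n_expert_used; ++e) {
            for (int d = 0; d < n_embd; ++d) {
                out[((size_t) t*n_expert_used + e)*n_embd + d] =
                    cur[(size_t) t*n_embd + d] * weights[(size_t) t*n_expert_used + e];
            }
        }
    }
    return out; // [n_tokens][n_expert_used][n_embd]
}
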
@@ -8958,7 +9115,10 @@ llm_expert_gating_func_type gating_op,
ggml_tensor * experts = llm_build_lora_mm_id(lctx, ctx, down_exps, par, selected_experts); // [n_embd, n_expert_used, n_tokens]
cb(experts, "ffn_moe_down", il);
- experts = ggml_mul(ctx, experts, weights);
+ if (!weight_before_ffn) {
+ experts = ggml_mul(ctx, experts, weights);
+ cb(experts, "ffn_moe_weighted", il);
+ }
if (n_expert_used == 1) {
return ggml_cont(ctx, ggml_view_2d(ctx, experts, n_embd, n_tokens, experts->nb[2], 0));
@@ -9487,6 +9647,14 @@ struct llm_build_context {
return lctx.inp_pos;
}
+ struct ggml_tensor * build_inp_attn_scale(int n_tokens) {
+ int n_pos_per_token = 1;
+ lctx.inp_scale = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 1, 1, n_tokens*n_pos_per_token);
+ cb(lctx.inp_scale, "inp_scale", -1);
+ ggml_set_input(lctx.inp_scale);
+ return lctx.inp_scale;
+ }
+
struct ggml_tensor * build_rope_factors(int il) {
// choose long/short freq factors based on the context size
const auto n_ctx_pre_seq = cparams.n_ctx / cparams.n_seq_max;
@@ -9667,14 +9835,19 @@ struct llm_build_context {
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
GGML_ASSERT(n_embd_head == hparams.n_rot);
- struct ggml_tensor * cur;
- struct ggml_tensor * inpL;
+ ggml_tensor * cur;
+ ggml_tensor * inpL;
+ ggml_tensor * inp_attn_scale = nullptr;
inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
// inp_pos - contains the positions
struct ggml_tensor * inp_pos = build_inp_pos();
+ if (model.arch == LLM_ARCH_LLAMA4) {
+ inp_attn_scale = build_inp_attn_scale(n_tokens);
+ }
+
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
@@ -9683,6 +9856,8 @@ struct llm_build_context {
for (int il = 0; il < n_layer; ++il) {
struct ggml_tensor * inpSA = inpL;
+ bool use_rope = model.arch == LLM_ARCH_LLAMA4 ? (il + 1) % hparams.n_no_rope_layer_step != 0 : true;
+
// norm
cur = llm_build_norm(ctx0, inpL, hparams,
model.layers[il].attn_norm, NULL,
@@ -9720,19 +9895,29 @@ struct llm_build_context {
cb(Vcur, "Vcur", il);
}
- Qcur = ggml_rope_ext(
- ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, rope_factors,
- n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
- ext_factor, attn_factor, beta_fast, beta_slow
- );
- cb(Qcur, "Qcur", il);
+ if (use_rope) {
+ Qcur = ggml_rope_ext(ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, rope_factors,
+ n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+ ext_factor, attn_factor, beta_fast, beta_slow);
- Kcur = ggml_rope_ext(
- ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, rope_factors,
- n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
- ext_factor, attn_factor, beta_fast, beta_slow
- );
+ Kcur = ggml_rope_ext(ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, rope_factors,
+ n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+ ext_factor, attn_factor, beta_fast, beta_slow);
+ } else if (inp_attn_scale) {
+ Qcur = ggml_mul(ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_attn_scale);
+ }
+
+ cb(Qcur, "Qcur", il);
cb(Kcur, "Kcur", il);
+ cb(Vcur, "Vcur", il);
+
+ if (model.arch == LLM_ARCH_LLAMA4 && use_rope && hparams.use_kq_norm) {
+ // Llama4TextL2Norm
+ Qcur = ggml_rms_norm(ctx0, Qcur, 1e-6);
+ Kcur = ggml_rms_norm(ctx0, Kcur, 1e-6);
+ cb(Qcur, "Qcur_normed", il);
+ cb(Kcur, "Kcur_normed", il);
+ }
cur = llm_build_kv(ctx0, lctx, kv_self, gf,
model.layers[il].wo, model.layers[il].bo,
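
Note: with the fixed n_no_rope_layer_step = 4 from the hparams above, the use_rope test makes every fourth layer a NoPE layer: RoPE is skipped there and Q is instead multiplied by the position-dependent inp_attn_scale (the [1, 1, n_tokens] tensor broadcasts across head dimensions and heads). Illustration:

#include <cstdio>

int main() {
    const int n_no_rope_layer_step = 4; // fixed for llama4 per the hparams above
    for (int il = 0; il < 8; ++il) {
        const bool use_rope = (il + 1) % n_no_rope_layer_step != 0;
        std::printf("layer %d: %s\n", il, use_rope ? "RoPE" : "NoPE + attn temperature scale");
    }
    // layers 3 and 7 skip RoPE; the KQ norm above is likewise applied only on
    // RoPE layers (and only when hparams.use_kq_norm is set, i.e. not Maverick)
}
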
@@ -9758,6 +9943,7 @@ struct llm_build_context {
// feed-forward network
if (model.layers[il].ffn_gate_inp == nullptr) {
+ // non-MoE
cur = llm_build_norm(ctx0, ffn_inp, hparams,
model.layers[il].ffn_norm, NULL,
LLM_NORM_RMS, cb, il);
@@ -9770,6 +9956,37 @@ struct llm_build_context {
NULL,
LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
cb(cur, "ffn_out", il);
+ } else if (model.arch == LLM_ARCH_LLAMA4) {
+ // llama4 MoE
+ ggml_tensor * ffn_inp_normed = llm_build_norm(ctx0, ffn_inp, hparams,
+ model.layers[il].ffn_norm, NULL,
+ LLM_NORM_RMS, cb, il);
+ cb(cur, "ffn_norm", il);
+
+ ggml_tensor * moe_out = llm_build_moe_ffn(ctx0, lctx, ffn_inp_normed,
+ model.layers[il].ffn_gate_inp,
+ model.layers[il].ffn_up_exps,
+ model.layers[il].ffn_gate_exps,
+ model.layers[il].ffn_down_exps,
+ nullptr,
+ n_expert, n_expert_used,
+ LLM_FFN_SILU, false,
+ false, 0.0,
+ LLM_EXPERT_GATING_FUNC_SIGMOID,
+ cb, il);
+
+ // Shared experts
+ ggml_tensor * shexp_out = llm_build_ffn(ctx0, lctx, ffn_inp_normed,
+ model.layers[il].ffn_up_shexp, NULL, NULL,
+ model.layers[il].ffn_gate_shexp, NULL, NULL,
+ model.layers[il].ffn_down_shexp, NULL, NULL,
+ NULL,
+ LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
+ cb(shexp_out, "ffn_moe_shexp", il);
+
+ cur = ggml_add(ctx0, moe_out, shexp_out);
+ cb(cur, "ffn_moe_out_merged", il);
+
} else {
// MoE branch
cur = llm_build_norm(ctx0, ffn_inp, hparams,
@@ -9782,11 +9999,11 @@ struct llm_build_context {
model.layers[il].ffn_up_exps,
model.layers[il].ffn_gate_exps,
model.layers[il].ffn_down_exps,
- nullptr,
+ nullptr,
n_expert, n_expert_used,
LLM_FFN_SILU, true,
false, 0.0,
- LLM_EXPERT_GATING_FUNC_SOFTMAX,
+ LLM_EXPERT_GATING_FUNC_SOFTMAX,
cb, il);
cb(cur, "ffn_moe_out", il);
}
@@ -15162,6 +15379,7 @@ static struct ggml_cgraph * llama_build_graph(
switch (model.arch) {
case LLM_ARCH_LLAMA:
+ case LLM_ARCH_LLAMA4:
case LLM_ARCH_GRANITE:
case LLM_ARCH_GRANITE_MOE:
{
@@ -15423,6 +15641,17 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
ggml_backend_tensor_set(lctx.inp_pos, batch.pos, 0, n_tokens*ggml_element_size(lctx.inp_pos));
}
+ if (lctx.inp_pos && lctx.inp_scale) {
+ int n_tokens = batch.n_tokens;
+ GGML_ASSERT(ggml_nelements(lctx.inp_scale) >= n_tokens);
+ if (int(lctx.scale_data.size()) < n_tokens) lctx.scale_data.resize(n_tokens);
+ int n_pos_per_token = 1;
+ for (int i = 0; i < n_tokens; ++i) {
+ lctx.scale_data[i] = std::log(std::floor((batch.pos[i] + 1.0f) / hparams.n_attn_temp_floor_scale) + 1.0f) * hparams.f_attn_temp_scale + 1.0f;
+ }
+ ggml_backend_tensor_set(lctx.inp_scale, lctx.scale_data.data(), 0, n_tokens*n_pos_per_token*ggml_element_size(lctx.inp_scale));
+ }
+
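
Note: this loop fills the llama4 attention temperature, scale(p) = log(floor((p + 1) / 8192) + 1) * 0.1 + 1. It is exactly 1.0 for positions 0..8190 (the floor term is 0, log(1) = 0) and then grows logarithmically with each 8192-position step. A standalone check:

#include <cmath>
#include <cstdio>

int main() {
    const float floor_scale = 8192.0f; // hparams.n_attn_temp_floor_scale
    const float temp_scale  = 0.1f;    // hparams.f_attn_temp_scale
    for (int pos : {0, 8190, 8191, 65536}) {
        const float s = std::log(std::floor((pos + 1.0f) / floor_scale) + 1.0f) * temp_scale + 1.0f;
        std::printf("pos %6d -> scale %.4f\n", pos, s);
    }
    // pos 0 and 8190 -> 1.0000; pos 8191 -> ~1.0693; pos 65536 -> ~1.2197
}
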
if (hparams.causal_attn || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) {
GGML_ASSERT(lctx.inp_out_ids && "every model that can must skip unused outputs");
const int64_t n_tokens = batch.n_tokens;
@@ -15504,8 +15733,15 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
// may need to cut off old tokens for sliding window
if (data_swa) {
- if (pos - lctx.kv_self.cells[i].pos >= (int32_t)hparams.n_swa) {
- f = -INFINITY;
+ if (hparams.n_attn_chunk) {
+ llama_pos pos_chunk_start = (pos / hparams.n_attn_chunk) * hparams.n_attn_chunk;
+ if (lctx.kv_self.cells[i].pos < pos_chunk_start || pos < pos_chunk_start) {
+ f = -INFINITY;
+ }
+ } else {
+ if (pos - lctx.kv_self.cells[i].pos >= (int32_t)hparams.n_swa) {
+ f = -INFINITY;
+ }
}
data_swa[h*(n_kv*n_tokens) + j*n_kv + i] = f;
}
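
Note: the chunked branch above replaces the fixed sliding window: a query at position pos may only attend to KV entries at or after (pos / n_attn_chunk) * n_attn_chunk, i.e. within its own 8192-token chunk (causality is still enforced by the regular mask). A small-scale illustration:

#include <cstdio>

int main() {
    const int n_attn_chunk = 4; // llama4 uses 8192; 4 keeps the output readable
    for (int pos = 0; pos < 8; ++pos) {
        const int pos_chunk_start = (pos / n_attn_chunk) * n_attn_chunk;
        std::printf("q at pos %d -> visible kv range [%d..%d]\n", pos, pos_chunk_start, pos);
    }
    // positions 0-3 see only chunk [0..3], positions 4-7 see only chunk [4..7]
}
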
@@ -18949,6 +19185,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
// use what we call a normal RoPE, operating on pairs of consecutive head values
case LLM_ARCH_LLAMA:
+ case LLM_ARCH_LLAMA4:
case LLM_ARCH_BAICHUAN:
case LLM_ARCH_STARCODER:
case LLM_ARCH_PLAMO:
@@ -20774,6 +21011,8 @@ static llm_chat_template llama_chat_detect_template(const std::string & tmpl) {
return LLM_CHAT_TEMPLATE_GIGACHAT;
} else if (tmpl_contains("<|role_start|>")) {
return LLM_CHAT_TEMPLATE_MEGREZ;
+ } else if (tmpl_contains("<|header_start|>") && tmpl_contains("<|header_end|>")) {
+ return LLM_CHAT_TEMPLATE_LLAMA4;
}
return LLM_CHAT_TEMPLATE_UNKNOWN;
}
@@ -21155,6 +21394,15 @@ static int32_t llama_chat_apply_template_internal(
if (add_ass) {
ss << "<|role_start|>assistant<|role_end|>";
}
+ } else if (tmpl == LLM_CHAT_TEMPLATE_LLAMA4) {
+ // Llama 4
+ for (auto message : chat) {
+ std::string role(message->role);
+ ss << "<|header_start|>" << role << "<|header_end|>\n\n" << trim(message->content) << "<|eot|>";
+ }
+ if (add_ass) {
+ ss << "<|header_start|>assistant<|header_end|>\n\n";
+ }
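
Note: for a user/assistant exchange with add_ass set, the branch above renders a prompt of the following shape (sketch; the trailing header primes the assistant turn):

<|header_start|>user<|header_end|>

Hello<|eot|><|header_start|>assistant<|header_end|>

Hi there!<|eot|><|header_start|>assistant<|header_end|>
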
} else {
// template not supported
return -1;