summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--gguf-py/gguf/constants.py2
-rw-r--r--src/llama.cpp208
2 files changed, 210 insertions, 0 deletions
diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py
index 037837da..93c614d6 100644
--- a/gguf-py/gguf/constants.py
+++ b/gguf-py/gguf/constants.py
@@ -207,6 +207,7 @@ class MODEL_ARCH(IntEnum):
MINICPM = auto()
GEMMA = auto()
GEMMA2 = auto()
+ GEMMA3 = auto()
STARCODER2 = auto()
MAMBA = auto()
XVERSE = auto()
@@ -338,6 +339,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
MODEL_ARCH.MINICPM: "minicpm",
MODEL_ARCH.GEMMA: "gemma",
MODEL_ARCH.GEMMA2: "gemma2",
+ MODEL_ARCH.GEMMA3: "gemma3",
MODEL_ARCH.STARCODER2: "starcoder2",
MODEL_ARCH.MAMBA: "mamba",
MODEL_ARCH.XVERSE: "xverse",
diff --git a/src/llama.cpp b/src/llama.cpp
index 186cb5a5..54e0bb01 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -203,6 +203,7 @@ enum llm_arch {
LLM_ARCH_MINICPM,
LLM_ARCH_GEMMA,
LLM_ARCH_GEMMA2,
+ LLM_ARCH_GEMMA3,
LLM_ARCH_STARCODER2,
LLM_ARCH_MAMBA,
LLM_ARCH_XVERSE,
@@ -250,6 +251,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_MINICPM, "minicpm" },
{ LLM_ARCH_GEMMA, "gemma" },
{ LLM_ARCH_GEMMA2, "gemma2" },
+ { LLM_ARCH_GEMMA3, "gemma3" },
{ LLM_ARCH_STARCODER2, "starcoder2" },
{ LLM_ARCH_MAMBA, "mamba" },
{ LLM_ARCH_XVERSE, "xverse" },
@@ -1057,6 +1059,26 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
},
},
{
+ LLM_ARCH_GEMMA3,
+ {
+ { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+ { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
+ { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
+ { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
+ { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
+ { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
+ { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
+ { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
+ { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
+ { LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" },
+ { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
+ { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
+ { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
+ { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
+ { LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" },
+ },
+ },
+ {
LLM_ARCH_STARCODER2,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
@@ -2313,6 +2335,7 @@ struct llama_hparams {
uint32_t n_layer;
uint32_t n_rot;
uint32_t n_swa = 0; // sliding window attention (SWA)
+ uint32_t n_swa_pattern = 1; // by default, all layers use non-sliding-window attention
uint32_t n_embd_head_k; // dimension of keys (d_k). d_q is assumed to be the same, but there are n_head q heads, and only n_head_kv k-v heads
uint32_t n_embd_head_v; // dimension of values (d_v) aka n_embd_head
uint32_t n_expert = 0;
@@ -2342,7 +2365,9 @@ struct llama_hparams {
float rope_attn_factor = 1.0f;
float rope_freq_base_train;
+ float rope_freq_base_train_swa;
float rope_freq_scale_train;
+ float rope_freq_scale_train_swa;
uint32_t n_ctx_orig_yarn;
float rope_yarn_log_mul;
@@ -4963,6 +4988,10 @@ static void llm_load_hparams(
}
hparams.rope_freq_scale_train = ropescale == 0.0f ? 1.0f : 1.0f/ropescale;
+ // by default assume that the sliding-window layers use the same scaling type as the non-sliding-window layers
+ hparams.rope_freq_base_train_swa = hparams.rope_freq_base_train;
+ hparams.rope_freq_scale_train_swa = hparams.rope_freq_scale_train;
+
ml.get_key(LLM_KV_ROPE_SCALING_ATTN_FACTOR, hparams.rope_attn_factor, false);
// non-transformer models do not have attention heads
@@ -5310,6 +5339,28 @@ static void llm_load_hparams(
default: model.type = e_model::MODEL_UNKNOWN;
}
} break;
+ case LLM_ARCH_GEMMA3:
+ {
+ hparams.n_swa_pattern = 6;
+
+ hparams.rope_freq_base_train_swa = 10000.0f;
+ hparams.rope_freq_scale_train_swa = 1.0f;
+
+ ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa);
+ ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+
+ switch (hparams.n_layer) {
+ case 26: model.type = e_model::MODEL_2B; break;
+ case 34: model.type = e_model::MODEL_4B; break;
+ case 48: model.type = e_model::MODEL_12B; break;
+ case 62: model.type = e_model::MODEL_27B; break;
+ default: model.type = e_model::MODEL_UNKNOWN;
+ }
+
+ hparams.f_attention_scale = model.type == e_model::MODEL_27B
+ ? 1.0f / std::sqrt(float(hparams.n_embd / hparams.n_head(0)))
+ : 1.0f / std::sqrt(float(hparams.n_embd_head_k));
+ } break;
case LLM_ARCH_STARCODER2:
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
@@ -7411,6 +7462,38 @@ static bool llm_load_tensors(
layer.ffn_post_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd});
}
} break;
+ case LLM_ARCH_GEMMA3:
+ {
+ model.tok_embd = create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+ // output
+ model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+ model.output = create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab},
+ llama_model_loader::TENSOR_DUPLICATED); // same as tok_embd, duplicated to allow offloading
+
+ for (int i = 0; i < n_layer; ++i) {
+ ggml_context * ctx_layer = ctx_for_layer(i);
+ ggml_context * ctx_split = ctx_for_layer_split(i);
+ auto & layer = model.layers[i];
+
+ layer.attn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+
+ layer.wq = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
+ layer.wk = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0);
+ layer.wv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0);
+ layer.wo = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);
+
+ layer.attn_post_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0);
+ layer.attn_k_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0);
+ layer.attn_q_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0);
+
+ layer.ffn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+ layer.ffn_gate = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
+ layer.ffn_up = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
+ layer.ffn_down = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
+ layer.ffn_post_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0);
+ }
+ } break;
case LLM_ARCH_STARCODER2:
{
model.tok_embd = create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
@@ -12664,6 +12747,126 @@ struct llm_build_context {
return gf;
}
+ struct ggml_cgraph * build_gemma3() {
+ struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+
+ const int64_t n_embd_head_k = hparams.n_embd_head_k;
+
+ struct ggml_tensor * cur;
+ struct ggml_tensor * inpL;
+
+ inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+
+ // important: do not normalize weights for raw embeddings input (i.e. encoded image emdeddings)
+ if (batch.token) {
+ inpL = ggml_scale(ctx0, inpL, sqrtf(n_embd));
+ cb(inpL, "inp_scaled", -1);
+ }
+
+ // inp_pos - contains the positions
+ struct ggml_tensor * inp_pos = build_inp_pos();
+
+ // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
+ // gemma3 requires different mask for layers using sliding window (SWA)
+ struct ggml_tensor * KQ_mask = build_inp_KQ_mask(true);
+ struct ggml_tensor * KQ_mask_swa = build_inp_KQ_mask_swa(true);
+
+ // "5-to-1 interleaved attention"
+ // 5 layers of local attention followed by 1 layer of global attention
+ static const int sliding_window_pattern = 6;
+
+ for (int il = 0; il < n_layer; ++il) {
+ const bool is_sliding = (il + 1) % sliding_window_pattern;
+ const float freq_base_l = is_sliding ? 10000.0f : freq_base;
+ const float freq_scale_l = is_sliding ? 1.0f : freq_scale;
+ struct ggml_tensor * KQ_mask_l = is_sliding ? KQ_mask_swa : KQ_mask;
+
+ // norm
+ cur = llm_build_norm(ctx0, inpL, hparams, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, cb, il);
+ cb(cur, "attn_norm", il);
+
+ // self-attention
+ {
+ // compute Q and K and RoPE them
+ struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
+ cb(Qcur, "Qcur", il);
+
+ struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
+ cb(Kcur, "Kcur", il);
+
+ struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
+ cb(Vcur, "Vcur", il);
+
+ Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head_k, n_head, n_tokens);
+ Qcur = llm_build_norm(ctx0, Qcur, hparams, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, cb, il);
+ cb(Qcur, "Qcur_normed", il);
+
+ Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
+ ext_factor, attn_factor, beta_fast, beta_slow);
+ cb(Qcur, "Qcur", il);
+
+ Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head_k, n_head_kv, n_tokens);
+ Kcur = llm_build_norm(ctx0, Kcur, hparams, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, cb, il);
+ cb(Kcur, "Kcur_normed", il);
+
+ Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
+ ext_factor, attn_factor, beta_fast, beta_slow);
+ cb(Kcur, "Kcur", il);
+
+ cur = llm_build_kv(ctx0, lctx, kv_self, gf, model.layers[il].wo, NULL,
+ Kcur, Vcur, Qcur, KQ_mask_l, n_tokens, kv_head, n_kv, hparams.f_attention_scale, cb, il);
+ }
+
+ cur = llm_build_norm(ctx0, cur, hparams, model.layers[il].attn_post_norm, NULL, LLM_NORM_RMS, cb, il);
+ cb(cur, "attn_post_norm", il);
+
+ if (il == n_layer - 1) {
+ // skip computing output for unused tokens
+ struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+ inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
+ }
+
+ struct ggml_tensor * sa_out = ggml_add(ctx0, cur, inpL);
+ cb(sa_out, "sa_out", il);
+
+ cur = llm_build_norm(ctx0, sa_out, hparams, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, cb, il);
+ cb(cur, "ffn_norm", il);
+
+ // feed-forward network
+ cur = llm_build_ffn(ctx0, lctx, cur,
+ model.layers[il].ffn_up, NULL, NULL,
+ model.layers[il].ffn_gate, NULL, NULL,
+ model.layers[il].ffn_down, NULL, NULL,
+ NULL,
+ LLM_FFN_GELU, LLM_FFN_PAR, cb, il);
+ cb(cur, "ffn_out", il);
+
+ cur = llm_build_norm(ctx0, cur, hparams, model.layers[il].ffn_post_norm, NULL, LLM_NORM_RMS, cb, -1);
+ cb(cur, "ffn_post_norm", -1);
+
+ cur = ggml_add(ctx0, cur, sa_out);
+ cur = lctx.cvec.apply_to(ctx0, cur, il);
+ cb(cur, "l_out", il);
+
+ // input for next layer
+ inpL = cur;
+ }
+
+ cur = inpL;
+
+ cur = llm_build_norm(ctx0, cur, hparams, model.output_norm, NULL, LLM_NORM_RMS, cb, -1);
+ cb(cur, "result_norm", -1);
+
+ // lm_head
+ cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
+
+ cb(cur, "result_output", -1);
+
+ ggml_build_forward_expand(gf, cur);
+
+ return gf;
+ }
struct ggml_cgraph * build_starcoder2() {
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
@@ -14996,6 +15199,10 @@ static struct ggml_cgraph * llama_build_graph(
{
result = llm.build_gemma2();
} break;
+ case LLM_ARCH_GEMMA3:
+ {
+ result = llm.build_gemma3();
+ } break;
case LLM_ARCH_STARCODER2:
{
result = llm.build_starcoder2();
@@ -18826,6 +19033,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
case LLM_ARCH_PHI3:
case LLM_ARCH_GEMMA:
case LLM_ARCH_GEMMA2:
+ case LLM_ARCH_GEMMA3:
case LLM_ARCH_STARCODER2:
case LLM_ARCH_OPENELM:
case LLM_ARCH_GPTNEOX: