summaryrefslogtreecommitdiff
path: root/llama.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'llama.cpp')
-rw-r--r--llama.cpp281
1 files changed, 281 insertions, 0 deletions
diff --git a/llama.cpp b/llama.cpp
index a05a52b4..619ffa4e 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -225,6 +225,7 @@ enum llm_arch {
LLM_ARCH_OLMO,
LLM_ARCH_ARCTIC,
LLM_ARCH_DEEPSEEK2,
+ LLM_ARCH_BITNET,
LLM_ARCH_UNKNOWN,
};
@@ -263,6 +264,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_OLMO, "olmo" },
{ LLM_ARCH_ARCTIC, "arctic" },
{ LLM_ARCH_DEEPSEEK2, "deepseek2" },
+ { LLM_ARCH_BITNET, "bitnet" },
{ LLM_ARCH_UNKNOWN, "(unknown)" },
};
@@ -500,6 +502,8 @@ enum llm_tensor {
LLM_TENSOR_ATTN_KV_B,
LLM_TENSOR_ATTN_Q_A_NORM,
LLM_TENSOR_ATTN_KV_A_NORM,
+ LLM_TENSOR_ATTN_SUB_NORM,
+ LLM_TENSOR_FFN_SUB_NORM,
};
static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NAMES = {
@@ -1114,6 +1118,24 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
},
},
{
+ LLM_ARCH_BITNET,
+ {
+ { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+ { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
+ { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
+ { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
+ { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
+ { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
+ { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
+ { LLM_TENSOR_ATTN_SUB_NORM, "blk.%d.attn_sub_norm" },
+ { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
+ { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
+ { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
+ { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
+ { LLM_TENSOR_FFN_SUB_NORM, "blk.%d.ffn_sub_norm" },
+ },
+ },
+ {
LLM_ARCH_UNKNOWN,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
@@ -2118,6 +2140,8 @@ struct llama_layer {
struct ggml_tensor * attn_out_norm_b;
struct ggml_tensor * attn_q_a_norm;
struct ggml_tensor * attn_kv_a_norm;
+ struct ggml_tensor * attn_sub_norm;
+ struct ggml_tensor * ffn_sub_norm;
// attention
struct ggml_tensor * wq;
@@ -4710,6 +4734,15 @@ static void llm_load_hparams(
default: model.type = e_model::MODEL_UNKNOWN;
}
} break;
+ case LLM_ARCH_BITNET:
+ {
+ ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+
+ switch (hparams.n_layer) {
+ case 26: model.type = e_model::MODEL_3B; break;
+ default: model.type = e_model::MODEL_UNKNOWN;
+ }
+ } break;
default: (void)0;
}
@@ -6655,6 +6688,40 @@ static bool llm_load_tensors(
}
}
} break;
+ case LLM_ARCH_BITNET:
+ {
+ model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
+
+ // output
+ {
+ model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
+ }
+
+ const uint32_t n_ff = hparams.n_ff;
+ model.layers.resize(n_layer);
+
+ for (int i = 0; i < n_layer; ++i) {
+ ggml_context * ctx_layer = ctx_for_layer(i);
+ ggml_context * ctx_split = ctx_for_layer_split(i);
+
+ auto & layer = model.layers[i];
+
+ layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
+ layer.attn_sub_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_SUB_NORM, "weight", i), {n_embd});
+
+ layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd});
+ layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa});
+ layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa});
+ layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
+
+ layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
+ layer.ffn_sub_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_SUB_NORM, "weight", i), {n_ff});
+
+ layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff});
+ layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd});
+ layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff});
+ }
+ } break;
default:
throw std::runtime_error("unknown architecture");
}
@@ -11709,6 +11776,215 @@ struct llm_build_context {
return gf;
}
+ struct ggml_cgraph * build_bitnet() {
+ struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+
+ const int64_t n_embd_head = hparams.n_embd_head_v;
+ GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+
+ struct ggml_tensor * cur;
+ struct ggml_tensor * inpL;
+
+ inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+
+ // inp_pos - contains the positions
+ struct ggml_tensor * inp_pos = build_inp_pos();
+
+ // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
+ struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
+
+ for (int il = 0; il < n_layer; ++il) {
+ struct ggml_tensor * inpSA = inpL;
+
+ cur = llm_build_norm(ctx0, inpL, hparams,
+ model.layers[il].attn_norm, NULL,
+ LLM_NORM_RMS, cb, il);
+ cb(cur, "attn_norm", il);
+
+ // self-attention
+ {
+ // compute Q and K and RoPE them
+ struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
+ cb(Qcur, "Qcur", il);
+ if (model.layers[il].bq) {
+ Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
+ cb(Qcur, "Qcur", il);
+ }
+
+ // B1.K
+ struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
+ cb(Kcur, "Kcur", il);
+ if (model.layers[il].bk) {
+ Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
+ cb(Kcur, "Kcur", il);
+ }
+
+ // B1.V
+ struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
+ cb(Vcur, "Vcur", il);
+ if (model.layers[il].bv) {
+ Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
+ cb(Vcur, "Vcur", il);
+ }
+
+ Qcur = ggml_rope_ext(
+ ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
+ n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+ ext_factor, attn_factor, beta_fast, beta_slow
+ );
+ cb(Qcur, "Qcur", il);
+
+ Kcur = ggml_rope_ext(
+ ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
+ n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+ ext_factor, attn_factor, beta_fast, beta_slow
+ );
+ cb(Kcur, "Kcur", il);
+
+ llm_build_kv_store(ctx0, hparams, cparams, kv_self, gf, Kcur, Vcur, n_tokens, kv_head, cb, il);
+
+ const int64_t n_ctx = cparams.n_ctx;
+ const int64_t n_head = hparams.n_head;
+ const int64_t n_head_kv = hparams.n_head_kv;
+ const int64_t n_embd_head_k = hparams.n_embd_head_k;
+ const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa();
+ const int64_t n_embd_head_v = hparams.n_embd_head_v;
+ const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa();
+
+ struct ggml_tensor * q_cur = Qcur;
+ struct ggml_tensor * kq_mask = KQ_mask;
+ float kq_scale = 1.0f/sqrtf(float(n_embd_head));
+ struct ggml_tensor * attn_sub_norm = model.layers[il].attn_sub_norm;
+ struct ggml_cgraph * graph = gf;
+ struct ggml_tensor * wo = model.layers[il].wo;
+ struct ggml_tensor * cur_attn;
+ struct ggml_tensor * q = ggml_permute(ctx0, q_cur, 0, 2, 1, 3);
+ cb(q, "q", il);
+
+ struct ggml_tensor * k =
+ ggml_view_3d(ctx0, kv_self.k_l[il],
+ n_embd_head_k, n_kv, n_head_kv,
+ ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa),
+ ggml_row_size(kv_self.k_l[il]->type, n_embd_head_k),
+ 0);
+ cb(k, "k", il);
+
+ if (cparams.flash_attn) {
+
+ // split cached v into n_head heads (not transposed)
+ struct ggml_tensor * v =
+ ggml_view_3d(ctx0, kv_self.v_l[il],
+ n_embd_head_v, n_kv, n_head_kv,
+ ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa),
+ ggml_row_size(kv_self.v_l[il]->type, n_embd_head_v),
+ 0);
+ cb(v, "v", il);
+
+ cur_attn = ggml_flash_attn_ext(ctx0, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias);
+
+ cur_attn = ggml_reshape_2d(ctx0, cur, n_embd_head_v*n_head, n_tokens);
+ } else {
+ struct ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
+ cb(kq, "kq", il);
+
+ kq = ggml_soft_max_ext(ctx0, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias);
+ cb(kq, "kq_soft_max_ext", il);
+
+ GGML_ASSERT(kv_self.size == n_ctx);
+
+ // split cached v into n_head heads
+ struct ggml_tensor * v =
+ ggml_view_3d(ctx0, kv_self.v_l[il],
+ n_kv, n_embd_head_v, n_head_kv,
+ ggml_element_size(kv_self.v_l[il])*n_ctx,
+ ggml_element_size(kv_self.v_l[il])*n_ctx*n_embd_head_v,
+ 0);
+ cb(v, "v", il);
+
+ struct ggml_tensor * kqv = ggml_mul_mat(ctx0, v, kq);
+ cb(kqv, "kqv", il);
+
+ struct ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
+ cb(kqv_merged, "kqv_merged", il);
+
+ cur_attn = ggml_cont_2d(ctx0, kqv_merged, n_embd_head_v*n_head, n_tokens);
+ cb(cur_attn, "kqv_merged_cont", il);
+ }
+
+ cur_attn = llm_build_norm(ctx0, cur_attn, hparams,
+ attn_sub_norm, NULL,
+ LLM_NORM_RMS, cb, il);
+ cb(cur_attn, "attn_sub_norm", il);
+
+ ggml_build_forward_expand(graph, cur_attn);
+
+ cur = ggml_mul_mat(ctx0, wo, cur_attn);
+
+ cb(cur, "kqv_out", il);
+ }
+
+ if (il == n_layer - 1) {
+ // skip computing output for unused tokens
+ struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+ inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+ }
+
+ struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+ cb(ffn_inp, "ffn_inp", il);
+
+ // feed-forward forward
+ if (model.layers[il].ffn_gate_inp == nullptr) {
+ cur = llm_build_norm(ctx0, ffn_inp, hparams,
+ model.layers[il].ffn_norm, NULL,
+ LLM_NORM_RMS, cb, il);
+ cb(cur, "ffn_norm", il);
+
+ struct ggml_tensor *tmp = ggml_mul_mat(ctx0, model.layers[il].ffn_up, cur);
+
+ cb(tmp, "ffn_up", il);
+
+ cur = ggml_mul_mat(ctx0, model.layers[il].ffn_gate, cur);
+
+ cb(cur, "ffn_gate", il);
+
+
+ cur = ggml_silu(ctx0, cur);
+ cb(cur, "ffn_silu", il);
+
+ cur = ggml_mul(ctx0, cur, tmp);
+ cb(cur, "ffn_gate_par", il);
+
+ cur = llm_build_norm(ctx0, cur, hparams,
+ model.layers[il].ffn_sub_norm, NULL,
+ LLM_NORM_RMS, cb, il);
+ cb(cur, "ffn_sub_norm", il);
+
+ cur = ggml_mul_mat(ctx0, model.layers[il].ffn_down, cur);
+ cb(cur, "ffn_down", il);
+ }
+ cur = ggml_add(ctx0, cur, ffn_inp);
+ cb(cur, "l_out", il);
+
+ // input for next layer
+ inpL = cur;
+ }
+
+ cur = inpL;
+
+ cur = llm_build_norm(ctx0, cur, hparams,
+ model.output_norm, NULL,
+ LLM_NORM_RMS, cb, -1);
+ cb(cur, "result_norm", -1);
+
+ // lm_head
+ cur = ggml_mul_mat(ctx0, model.tok_embd, cur);
+ cb(cur, "result_output", -1);
+
+ ggml_build_forward_expand(gf, cur);
+ return gf;
+ }
+
};
static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const std::vector<uint32_t> & ids) {
@@ -11932,6 +12208,10 @@ static struct ggml_cgraph * llama_build_graph(
{
result = llm.build_deepseek2();
} break;
+ case LLM_ARCH_BITNET:
+ {
+ result = llm.build_bitnet();
+ } break;
default:
GGML_ASSERT(false);
}
@@ -16760,6 +17040,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
case LLM_ARCH_BERT:
case LLM_ARCH_NOMIC_BERT:
case LLM_ARCH_STABLELM:
+ case LLM_ARCH_BITNET:
case LLM_ARCH_QWEN:
case LLM_ARCH_QWEN2:
case LLM_ARCH_QWEN2MOE: