diff options
author | saood06 <saood05@gmail.com> | 2025-04-22 01:34:13 -0500 |
---|---|---|
committer | GitHub <noreply@github.com> | 2025-04-22 08:34:13 +0200 |
commit | cc398007238cbbb064609cbdc9bc4aab03c658d7 (patch) | |
tree | a414370696faa5ad8a89ac5aa7ec7cf2171a78a4 | |
parent | 93cd77b65501246603061c7ee2801a992e3c6312 (diff) |
Add support for bitnet2b_2501 model (#337)
* add support for bitnet2b_2501 model
* Fixes
* Support both model names
---------
Co-authored-by: potassiummmm <zhou.hansong@outlook.com>
-rwxr-xr-x | convert_hf_to_gguf.py | 1 | ||||
-rw-r--r-- | gguf-py/gguf/constants.py | 24 | ||||
-rw-r--r-- | gguf-py/gguf/tensor_mapping.py | 5 | ||||
-rw-r--r-- | src/llama.cpp | 301 |
4 files changed, 330 insertions, 1 deletions
diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 1ee82724..a6ab09c0 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -1598,6 +1598,7 @@ class LlamaModel(Model): @Model.register("BitnetForCausalLM") +@Model.register("BitNetForCausalLM") class BitnetModel(Model): model_arch = gguf.MODEL_ARCH.BITNET diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 93c614d6..22cde145 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -219,6 +219,7 @@ class MODEL_ARCH(IntEnum): DEEPSEEK2 = auto() CHATGLM = auto() BITNET = auto() + BITNET_25 = auto() T5 = auto() T5ENCODER = auto() JAIS = auto() @@ -351,6 +352,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = { MODEL_ARCH.DEEPSEEK2: "deepseek2", MODEL_ARCH.CHATGLM: "chatglm", MODEL_ARCH.BITNET: "bitnet", + MODEL_ARCH.BITNET_25: "bitnet-25", MODEL_ARCH.T5: "t5", MODEL_ARCH.T5ENCODER: "t5encoder", MODEL_ARCH.JAIS: "jais", @@ -1019,6 +1021,28 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.ATTN_SUB_NORM, MODEL_TENSOR.FFN_SUB_NORM, ], + MODEL_ARCH.BITNET_25: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ROPE_FREQS, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.ATTN_ROT_EMBD, + MODEL_TENSOR.FFN_GATE_INP, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + MODEL_TENSOR.FFN_GATE_EXP, + MODEL_TENSOR.FFN_DOWN_EXP, + MODEL_TENSOR.FFN_UP_EXP, + MODEL_TENSOR.ATTN_SUB_NORM, + MODEL_TENSOR.FFN_SUB_NORM, + ], MODEL_ARCH.T5: [ MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.OUTPUT, diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index e8725426..1dea6a82 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -131,6 +131,7 @@ class TensorNameMap: "model.layers.{bid}.self_attn.qkv_proj", # phi3 "encoder.layers.{bid}.self_attention.query_key_value", # chatglm "transformer.layers.{bid}.attn.qkv_proj", # openelm + "layers.{bid}.attention.wqkv", ), # Attention query @@ -464,10 +465,14 @@ class TensorNameMap: MODEL_TENSOR.ATTN_SUB_NORM: ( "model.layers.{bid}.self_attn.inner_attn_ln", # bitnet + "layers.{bid}.attention.attn_sub_norm", # bitnet + "model.layers.{bid}.self_attn.attn_sub_norm", ), MODEL_TENSOR.FFN_SUB_NORM: ( "model.layers.{bid}.mlp.ffn_layernorm", # bitnet + "layers.{bid}.feed_forward.ffn_sub_norm", # bitnet + "model.layers.{bid}.mlp.ffn_sub_norm", ), MODEL_TENSOR.DEC_ATTN_NORM: ( diff --git a/src/llama.cpp b/src/llama.cpp index 5340d847..dcac3833 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -223,6 +223,7 @@ enum llm_arch { LLM_ARCH_DEEPSEEK2, LLM_ARCH_CHATGLM, LLM_ARCH_BITNET, + LLM_ARCH_BITNET_25, LLM_ARCH_T5, LLM_ARCH_T5ENCODER, LLM_ARCH_JAIS, @@ -272,6 +273,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = { { LLM_ARCH_DEEPSEEK2, "deepseek2" }, { LLM_ARCH_CHATGLM, "chatglm" }, { LLM_ARCH_BITNET, "bitnet" }, + { LLM_ARCH_BITNET_25, "bitnet-25" }, { LLM_ARCH_T5, "t5" }, { LLM_ARCH_T5ENCODER, "t5encoder" }, { LLM_ARCH_JAIS, "jais" }, @@ -1324,6 +1326,34 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA }, }, { + LLM_ARCH_BITNET_25, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" }, + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_GATE_EXP, "blk.%d.ffn_gate.%d" }, + { LLM_TENSOR_FFN_DOWN_EXP, "blk.%d.ffn_down.%d" }, + { LLM_TENSOR_FFN_UP_EXP, "blk.%d.ffn_up.%d" }, + { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + { LLM_TENSOR_ATTN_SUB_NORM, "blk.%d.attn_sub_norm" }, + { LLM_TENSOR_FFN_SUB_NORM, "blk.%d.ffn_sub_norm" }, + }, + }, + { LLM_ARCH_T5, { { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, @@ -1468,6 +1498,7 @@ enum llm_chat_template { LLM_CHAT_TEMPLATE_GIGACHAT, LLM_CHAT_TEMPLATE_MEGREZ, LLM_CHAT_TEMPLATE_LLAMA4, + LLM_CHAT_TEMPLATE_BITNET, LLM_CHAT_TEMPLATE_UNKNOWN, }; @@ -1504,6 +1535,7 @@ static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = { { "gigachat", LLM_CHAT_TEMPLATE_GIGACHAT }, { "megrez", LLM_CHAT_TEMPLATE_MEGREZ }, { "llama4", LLM_CHAT_TEMPLATE_LLAMA4 }, + { "bitnet", LLM_CHAT_TEMPLATE_BITNET }, }; @@ -5120,7 +5152,7 @@ static void llm_load_hparams( ml.get_key(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot, false); - if (model.arch == LLM_ARCH_LLAMA || model.arch == LLM_ARCH_FALCON) { + if (model.arch == LLM_ARCH_LLAMA || model.arch == LLM_ARCH_FALCON || model.arch == LLM_ARCH_BITNET_25) { if (hparams.n_rot != hparams.n_embd_head_k) { throw std::runtime_error(format("invalid n_rot: %u, expected %u", hparams.n_rot, hparams.n_embd_head_k)); } @@ -5690,6 +5722,15 @@ static void llm_load_hparams( default: model.type = e_model::MODEL_UNKNOWN; } } break; + case LLM_ARCH_BITNET_25: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + switch (hparams.n_layer) { + case 30: model.type = e_model::MODEL_2B; break; // bitnet2b_2501 + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; case LLM_ARCH_T5: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); @@ -8107,6 +8148,90 @@ static bool llm_load_tensors( layer.ffn_up_scale = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "scale", i), {1}, llama_model_loader::TENSOR_NOT_REQUIRED); } } break; + case LLM_ARCH_BITNET_25: + { + model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + + // output + { + model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); + model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (model.output == NULL) { + model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); + } + } + + for (int i = 0; i < n_layer; ++i) { + ggml_context * ctx_layer = ctx_for_layer(i); + ggml_context * ctx_split = ctx_for_layer_split(i); + + auto & layer = model.layers[i]; + + layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); + + layer.attn_sub_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_SUB_NORM, "weight", i), {n_embd}); + layer.ffn_sub_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_SUB_NORM, "weight", i), {n_ff}); + + layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}); + layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}); + layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}); + layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}); + + // optional bias tensors + layer.bq = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.bk = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.bv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.bo = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); + + layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); + + layer.rope_freqs = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ROPE_FREQS, "weight"), {n_rot/2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0)); + + if (n_expert == 0) { + layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}); + layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}); + layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); + + // optional MLP bias + layer.ffn_gate_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE, "bias", i), {n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.ffn_down_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.ffn_up_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED); + } else { + layer.ffn_gate_inp = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}); + + layer.ffn_gate_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, llama_model_loader::TENSOR_NOT_REQUIRED); + if (layer.ffn_gate_exps) { + layer.ffn_down_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff, n_embd, n_expert}); + layer.ffn_up_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}); + } else { + // merge split expert into a single tensor for compatibility with older models + // requires disabling mmap + use_mmap_buffer = false; + + ggml_type type_gate = ml.require_tensor_meta(tn(LLM_TENSOR_FFN_GATE_EXP, "weight", i, 0).c_str())->type; + ggml_type type_down = ml.require_tensor_meta(tn(LLM_TENSOR_FFN_DOWN_EXP, "weight", i, 0).c_str())->type; + ggml_type type_up = ml.require_tensor_meta(tn(LLM_TENSOR_FFN_UP_EXP, "weight", i, 0).c_str())->type; + + layer.ffn_gate_exps = ggml_new_tensor_3d(ctx_split, type_gate, n_embd, n_ff, n_expert); + layer.ffn_down_exps = ggml_new_tensor_3d(ctx_split, type_down, n_ff, n_embd, n_expert); + layer.ffn_up_exps = ggml_new_tensor_3d(ctx_split, type_up, n_embd, n_ff, n_expert); + + ggml_set_name(layer.ffn_gate_exps, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i).c_str()); + ggml_set_name(layer.ffn_down_exps, tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i).c_str()); + ggml_set_name(layer.ffn_up_exps, tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i).c_str()); + + for (uint32_t x = 0; x < n_expert; ++x) { + // the individual experts are loaded into a view of the merged tensor + ml.create_tensor_as_view(ctx_split, layer.ffn_gate_exps, tn(LLM_TENSOR_FFN_GATE_EXP, "weight", i, x), { n_embd, n_ff }, layer.ffn_gate_exps->nb[2]*x); + ml.create_tensor_as_view(ctx_split, layer.ffn_down_exps, tn(LLM_TENSOR_FFN_DOWN_EXP, "weight", i, x), { n_ff, n_embd }, layer.ffn_down_exps->nb[2]*x); + ml.create_tensor_as_view(ctx_split, layer.ffn_up_exps, tn(LLM_TENSOR_FFN_UP_EXP, "weight", i, x), { n_embd, n_ff }, layer.ffn_up_exps->nb[2]*x); + } + } + } + } + } break; case LLM_ARCH_T5: { const auto n_rel_attn_bkts = hparams.n_rel_attn_bkts; @@ -14731,6 +14856,156 @@ struct llm_build_context { return gf; } + struct ggml_cgraph * build_bitnet_25() { + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); + + // mutable variable, needed during the last layer of the computation to skip unused tokens + int32_t n_tokens = this->n_tokens; + + const int64_t n_embd_head = hparams.n_embd_head_v; + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + struct ggml_tensor * cur; + struct ggml_tensor * inpL; + + inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); + + // inp_pos - contains the positions + struct ggml_tensor * inp_pos = build_inp_pos(); + + // KQ_mask (mask for 1 head, it will be broadcasted to all heads) + struct ggml_tensor * KQ_mask = build_inp_KQ_mask(); + + for (int il = 0; il < n_layer; ++il) { + struct ggml_tensor * inpSA = inpL; + + // norm + cur = llm_build_norm(ctx0, inpL, hparams, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "attn_norm", il); + + // self-attention + { + // rope freq factors for llama3; may return nullptr for llama2 and other models + struct ggml_tensor * rope_factors = build_rope_factors(il); + // printf("%f\n\n\n\n",((float*)rope_factors->data)[1]); + + // compute Q and K and RoPE them + struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + if (model.layers[il].bq) { + Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); + cb(Qcur, "Qcur", il); + } + + struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + if (model.layers[il].bk) { + Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); + cb(Kcur, "Kcur", il); + } + + struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + if (model.layers[il].bv) { + Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); + cb(Vcur, "Vcur", il); + } + + Qcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + cb(Qcur, "Qcur", il); + + Kcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + cb(Kcur, "Kcur", il); + + cur = llm_build_kv(ctx0, lctx, kv_self, gf, + NULL, NULL, + Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + + cur = llm_build_norm(ctx0, cur, hparams, + model.layers[il].attn_sub_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "attn_sub_norm", il); + + cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wo, cur); + if (model.layers[il].wo_scale) { + cur = ggml_mul(ctx0, cur, model.layers[il].wo_scale); + } + if (model.layers[il].bo) { + cur = ggml_add(ctx0, cur, model.layers[il].bo); + } + cb(cur, "attn_o_out", il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + struct ggml_tensor * inp_out_ids = build_inp_out_ids(); + // n_tokens = n_outputs; + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + cur = llm_build_norm(ctx0, ffn_inp, hparams, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "ffn_norm", il); + + cur = llm_build_ffn(ctx0, lctx, cur, + model.layers[il].ffn_up, NULL, model.layers[il].ffn_up_scale, + model.layers[il].ffn_gate, NULL, model.layers[il].ffn_gate_scale, + NULL, NULL, NULL, + NULL, + LLM_FFN_RELU_SQR, LLM_FFN_PAR, cb, il); + cb(cur, "ffn_out", il); + + cur = llm_build_norm(ctx0, cur, hparams, + model.layers[il].ffn_sub_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "ffn_sub_norm", il); + + cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].ffn_down, cur); + if (model.layers[il].ffn_down_scale) { + cur = ggml_mul(ctx0, cur, model.layers[il].ffn_down_scale); + } + cb(cur, "ffn_down", il); + + cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = llm_build_norm(ctx0, cur, hparams, + model.output_norm, NULL, + LLM_NORM_RMS, cb, -1); + cb(cur, "result_norm", -1); + + // lm_head + cur = llm_build_lora_mm(lctx, ctx0, model.tok_embd, cur); + + cb(cur, "result_output", -1); + + ggml_build_forward_expand(gf, cur); + + return gf; + } + struct ggml_cgraph * build_t5_encoder() { struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); @@ -15527,6 +15802,10 @@ static struct ggml_cgraph * llama_build_graph( { result = llm.build_bitnet(); } break; + case LLM_ARCH_BITNET_25: + { + result = llm.build_bitnet_25(); + } break; case LLM_ARCH_T5: { if (lctx.is_encoding) { @@ -19210,6 +19489,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) { case LLM_ARCH_NOMIC_BERT: case LLM_ARCH_STABLELM: case LLM_ARCH_BITNET: + case LLM_ARCH_BITNET_25: case LLM_ARCH_QWEN: case LLM_ARCH_QWEN2: case LLM_ARCH_QWEN2MOE: @@ -21403,6 +21683,25 @@ static int32_t llama_chat_apply_template_internal( if (add_ass) { ss << "<|header_start|>assistant<|header_end|>\n\n"; } + } else if (tmpl == LLM_CHAT_TEMPLATE_BITNET) { + // bitnet-25 + std::string system_prompt = ""; + for (auto message : chat) { + std::string role(message->role); + if (role == "system") { + ss << "System: "; + ss << message->content; + } else if (role == "user") { + ss << "User: "; + if (!system_prompt.empty()) { + ss << system_prompt; + system_prompt = ""; + } + ss << message->content << "<|eot_id|>Assistant: "; + } else { + ss << message->content; + } + } } else { // template not supported return -1; |