| author | saood06 <saood05@gmail.com> | 2025-05-02 00:07:24 -0500 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-05-02 07:07:24 +0200 |
| commit | d37add8b39634a6bc31a2e3bda8bbe24297c66db (patch) | |
| tree | a37f6c0ef91a2a83a54a4f059751da92e466a485 /src | |
| parent | 98d1626469879d35faba9cb7e9d0b1ddaf853eee (diff) | |
Fix model architecture name (#366)
Co-authored-by: junhuihe <junhui-he@outlook.com>
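
The commit registers "bitnet-b1.58" as an additional architecture string and routes it through the same hyperparameter, tensor-loading, and graph-building paths as the existing "bitnet-25" architecture. The sketch below is a minimal, self-contained illustration (not the repository's own loader code) of how an architecture string read from a model file is resolved against the LLM_ARCH_NAMES table this commit extends; the enum values and strings are a reduced subset of those in src/llama.cpp, and arch_from_string is a hypothetical stand-in for the loader's reverse lookup.

```cpp
#include <cstdio>
#include <map>
#include <string>

// Reduced subset of the llm_arch enum and LLM_ARCH_NAMES table from src/llama.cpp.
enum llm_arch {
    LLM_ARCH_BITNET,
    LLM_ARCH_BITNET_25,
    LLM_ARCH_BITNET_B158,
    LLM_ARCH_UNKNOWN,
};

static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
    { LLM_ARCH_BITNET,      "bitnet"       },
    { LLM_ARCH_BITNET_25,   "bitnet-25"    },
    { LLM_ARCH_BITNET_B158, "bitnet-b1.58" },
};

// Hypothetical stand-in for the reverse lookup performed on the
// architecture string stored in the model file.
static llm_arch arch_from_string(const std::string & name) {
    for (const auto & kv : LLM_ARCH_NAMES) {
        if (name == kv.second) {
            return kv.first;
        }
    }
    return LLM_ARCH_UNKNOWN;
}

int main() {
    // Before this commit, "bitnet-b1.58" was not in the table, so a model
    // reporting that architecture name could not be resolved.
    printf("bitnet-b1.58 -> %d\n", (int) arch_from_string("bitnet-b1.58")); // 2 (LLM_ARCH_BITNET_B158)
    printf("bitnet-2b    -> %d\n", (int) arch_from_string("bitnet-2b"));    // 3 (LLM_ARCH_UNKNOWN)
}
```

With the table entry in place, the new `case LLM_ARCH_BITNET_B158:` labels in the diff below simply fall through to the existing bitnet-25 handling.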
Diffstat (limited to 'src')
-rw-r--r-- | src/llama.cpp | 40
1 file changed, 37 insertions, 3 deletions
diff --git a/src/llama.cpp b/src/llama.cpp
index 939e6e4b..dd9d82dc 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -227,6 +227,7 @@ enum llm_arch {
     LLM_ARCH_GLM4,
     LLM_ARCH_BITNET,
     LLM_ARCH_BITNET_25,
+    LLM_ARCH_BITNET_B158,
     LLM_ARCH_T5,
     LLM_ARCH_T5ENCODER,
     LLM_ARCH_JAIS,
@@ -281,6 +282,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_GLM4, "glm4" },
     { LLM_ARCH_BITNET, "bitnet" },
     { LLM_ARCH_BITNET_25, "bitnet-25" },
+    { LLM_ARCH_BITNET_B158, "bitnet-b1.58" },
     { LLM_ARCH_T5, "t5" },
     { LLM_ARCH_T5ENCODER, "t5encoder" },
     { LLM_ARCH_JAIS, "jais" },
@@ -1420,6 +1422,34 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
         },
     },
     {
+        LLM_ARCH_BITNET_B158,
+        {
+            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
+            { LLM_TENSOR_OUTPUT, "output" },
+            { LLM_TENSOR_ROPE_FREQS, "rope_freqs" },
+            { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
+            { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" },
+            { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
+            { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
+            { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
+            { LLM_TENSOR_FFN_GATE_EXP, "blk.%d.ffn_gate.%d" },
+            { LLM_TENSOR_FFN_DOWN_EXP, "blk.%d.ffn_down.%d" },
+            { LLM_TENSOR_FFN_UP_EXP, "blk.%d.ffn_up.%d" },
+            { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
+            { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
+            { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
+            { LLM_TENSOR_ATTN_SUB_NORM, "blk.%d.attn_sub_norm" },
+            { LLM_TENSOR_FFN_SUB_NORM, "blk.%d.ffn_sub_norm" },
+        },
+    },
+    {
         LLM_ARCH_T5,
         {
             { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
@@ -5235,7 +5265,7 @@ static void llm_load_hparams(
     ml.get_key(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot, false);
 
-    if (model.arch == LLM_ARCH_LLAMA || model.arch == LLM_ARCH_FALCON || model.arch == LLM_ARCH_BITNET_25) {
+    if (model.arch == LLM_ARCH_LLAMA || model.arch == LLM_ARCH_FALCON || model.arch == LLM_ARCH_BITNET_25 || model.arch == LLM_ARCH_BITNET_B158) {
         if (hparams.n_rot != hparams.n_embd_head_k) {
             throw std::runtime_error(format("invalid n_rot: %u, expected %u", hparams.n_rot, hparams.n_embd_head_k));
         }
@@ -5830,6 +5860,7 @@ static void llm_load_hparams(
                     default: model.type = e_model::MODEL_UNKNOWN;
                 }
             } break;
+        case LLM_ARCH_BITNET_B158:
         case LLM_ARCH_BITNET_25:
             {
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
@@ -8350,6 +8381,7 @@ static bool llm_load_tensors(
                     layer.ffn_up_scale = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "scale", i), {1}, llama_model_loader::TENSOR_NOT_REQUIRED);
                 }
             } break;
+        case LLM_ARCH_BITNET_B158:
         case LLM_ARCH_BITNET_25:
             {
                 model.tok_embd = create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
@@ -15376,7 +15408,7 @@ struct llm_build_context {
         return gf;
     }
 
-    struct ggml_cgraph * build_bitnet_25() {
+    struct ggml_cgraph * build_bitnet_158() {
         struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
 
         // mutable variable, needed during the last layer of the computation to skip unused tokens
@@ -16599,9 +16631,10 @@ static struct ggml_cgraph * llama_build_graph(
             {
                 result = llm.build_bitnet();
             } break;
+        case LLM_ARCH_BITNET_B158:
         case LLM_ARCH_BITNET_25:
             {
-                result = llm.build_bitnet_25();
+                result = llm.build_bitnet_158();
             } break;
         case LLM_ARCH_COHERE2:
             {
@@ -20293,6 +20326,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
         case LLM_ARCH_STABLELM:
         case LLM_ARCH_BITNET:
         case LLM_ARCH_BITNET_25:
+        case LLM_ARCH_BITNET_B158:
         case LLM_ARCH_QWEN:
         case LLM_ARCH_QWEN2:
         case LLM_ARCH_QWEN2MOE:
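
For reference, the per-layer entries registered for LLM_ARCH_BITNET_B158 are printf-style templates; llm_load_tensors combines them with a layer index and a suffix (e.g. "weight" or "scale") via the tn(...) calls visible above. The sketch below is a simplified, hypothetical stand-in for that expansion, assuming a snprintf-based formatting step; tensor_name is not the repository's helper.

```cpp
#include <cstdio>
#include <string>

// Hypothetical, simplified stand-in for tensor-name expansion: fill the
// "blk.%d.*" template with the layer index, then append an optional
// suffix such as "weight" or "scale".
static std::string tensor_name(const char * templ, int layer, const char * suffix = nullptr) {
    char buf[256];
    std::snprintf(buf, sizeof(buf), templ, layer); // "blk.%d.attn_sub_norm" -> "blk.3.attn_sub_norm"
    std::string name = buf;
    if (suffix) {
        name += ".";
        name += suffix;                            // -> "blk.3.attn_sub_norm.weight"
    }
    return name;
}

int main() {
    // Names a loader would look up for the BitNet-specific sub-norm and
    // scale tensors covered by the LLM_TENSOR_NAMES entry added above.
    printf("%s\n", tensor_name("blk.%d.attn_sub_norm", 3, "weight").c_str()); // blk.3.attn_sub_norm.weight
    printf("%s\n", tensor_name("blk.%d.ffn_up", 12, "scale").c_str());        // blk.12.ffn_up.scale
}
```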