-rw-r--r--  src/llama.cpp | 40
1 file changed, 37 insertions(+), 3 deletions(-)
diff --git a/src/llama.cpp b/src/llama.cpp
index 939e6e4b..dd9d82dc 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -227,6 +227,7 @@ enum llm_arch {
LLM_ARCH_GLM4,
LLM_ARCH_BITNET,
LLM_ARCH_BITNET_25,
+ LLM_ARCH_BITNET_B158,
LLM_ARCH_T5,
LLM_ARCH_T5ENCODER,
LLM_ARCH_JAIS,
@@ -281,6 +282,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_GLM4, "glm4" },
{ LLM_ARCH_BITNET, "bitnet" },
{ LLM_ARCH_BITNET_25, "bitnet-25" },
+ { LLM_ARCH_BITNET_B158, "bitnet-b1.58" },
{ LLM_ARCH_T5, "t5" },
{ LLM_ARCH_T5ENCODER, "t5encoder" },
{ LLM_ARCH_JAIS, "jais" },
@@ -1420,6 +1422,34 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NAMES = {
},
},
{
+ LLM_ARCH_BITNET_B158,
+ {
+ { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+ { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
+ { LLM_TENSOR_OUTPUT, "output" },
+ { LLM_TENSOR_ROPE_FREQS, "rope_freqs" },
+ { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
+ { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
+ { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
+ { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
+ { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
+ { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" },
+ { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
+ { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
+ { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
+ { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
+ { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
+ { LLM_TENSOR_FFN_GATE_EXP, "blk.%d.ffn_gate.%d" },
+ { LLM_TENSOR_FFN_DOWN_EXP, "blk.%d.ffn_down.%d" },
+ { LLM_TENSOR_FFN_UP_EXP, "blk.%d.ffn_up.%d" },
+ { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
+ { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
+ { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
+ { LLM_TENSOR_ATTN_SUB_NORM, "blk.%d.attn_sub_norm" },
+ { LLM_TENSOR_FFN_SUB_NORM, "blk.%d.ffn_sub_norm" },
+ },
+ },
+ {
LLM_ARCH_T5,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
@@ -5235,7 +5265,7 @@ static void llm_load_hparams(
ml.get_key(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot, false);
- if (model.arch == LLM_ARCH_LLAMA || model.arch == LLM_ARCH_FALCON || model.arch == LLM_ARCH_BITNET_25) {
+ if (model.arch == LLM_ARCH_LLAMA || model.arch == LLM_ARCH_FALCON || model.arch == LLM_ARCH_BITNET_25 || model.arch == LLM_ARCH_BITNET_B158) {
if (hparams.n_rot != hparams.n_embd_head_k) {
throw std::runtime_error(format("invalid n_rot: %u, expected %u", hparams.n_rot, hparams.n_embd_head_k));
}
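
The extended condition says that for bitnet-b1.58, as for llama and falcon, RoPE must cover the full attention head: n_rot has to equal n_embd_head_k, which defaults to n_embd / n_head. A toy illustration of the relation; the dimensions below are made up for the example, not taken from a particular BitNet checkpoint:

```cpp
#include <cassert>
#include <cstdint>
#include <cstdio>

int main() {
    // Hypothetical hyperparameters, for illustration only.
    const uint32_t n_embd = 2048;
    const uint32_t n_head = 16;

    const uint32_t n_embd_head_k = n_embd / n_head; // per-head K dimension: 128
    const uint32_t n_rot         = n_embd_head_k;   // full-head RoPE

    // This is the condition llm_load_hparams enforces; a mismatch throws
    // "invalid n_rot: ..." at model load time.
    assert(n_rot == n_embd_head_k);
    printf("n_rot = %u, n_embd_head_k = %u\n", n_rot, n_embd_head_k);
    return 0;
}
```
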
@@ -5830,6 +5860,7 @@ static void llm_load_hparams(
default: model.type = e_model::MODEL_UNKNOWN;
}
} break;
+ case LLM_ARCH_BITNET_B158:
case LLM_ARCH_BITNET_25:
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
@@ -8350,6 +8381,7 @@ static bool llm_load_tensors(
layer.ffn_up_scale = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "scale", i), {1}, llama_model_loader::TENSOR_NOT_REQUIRED);
}
} break;
+ case LLM_ARCH_BITNET_B158:
case LLM_ARCH_BITNET_25:
{
model.tok_embd = create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
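
The token embedding is created with shape {n_embd, n_vocab}: in ggml the first dimension is the row length, so this is one n_embd-wide row per vocabulary id. A tiny ggml sketch of the same layout, with toy sizes and plain F32 in place of a quantized type:

```cpp
#include "ggml.h"
#include <cstdio>

int main() {
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(params);

    const int64_t n_embd  = 8;   // toy embedding width
    const int64_t n_vocab = 32;  // toy vocabulary size

    // Same layout as tok_embd above: ne[0] = n_embd, ne[1] = n_vocab.
    struct ggml_tensor * tok_embd = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_vocab);

    printf("tok_embd: %lld x %lld\n", (long long) tok_embd->ne[0], (long long) tok_embd->ne[1]);

    ggml_free(ctx);
    return 0;
}
```
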
@@ -15376,7 +15408,7 @@ struct llm_build_context {
return gf;
}
- struct ggml_cgraph * build_bitnet_25() {
+ struct ggml_cgraph * build_bitnet_158() {
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
// mutable variable, needed during the last layer of the computation to skip unused tokens
@@ -16599,9 +16631,10 @@ static struct ggml_cgraph * llama_build_graph(
{
result = llm.build_bitnet();
} break;
+ case LLM_ARCH_BITNET_B158:
case LLM_ARCH_BITNET_25:
{
- result = llm.build_bitnet_25();
+ result = llm.build_bitnet_158();
} break;
case LLM_ARCH_COHERE2:
{
@@ -20293,6 +20326,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
case LLM_ARCH_STABLELM:
case LLM_ARCH_BITNET:
case LLM_ARCH_BITNET_25:
+ case LLM_ARCH_BITNET_B158:
case LLM_ARCH_QWEN:
case LLM_ARCH_QWEN2:
case LLM_ARCH_QWEN2MOE:
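
End to end, the user-visible effect is that a GGUF whose general.architecture is "bitnet-b1.58" now loads and runs, and the case label above makes llama_rope_type() report the same RoPE flavour as the other architectures in that group. A usage sketch against the public API, assuming the llama.h entry points of roughly this vintage (llama_load_model_from_file and friends); the model path is a placeholder:

```cpp
#include "llama.h"
#include <cstdio>

int main(int argc, char ** argv) {
    if (argc < 2) {
        fprintf(stderr, "usage: %s model.gguf\n", argv[0]);
        return 1;
    }

    llama_backend_init();

    llama_model_params mparams = llama_model_default_params();
    llama_model * model = llama_load_model_from_file(argv[1], mparams);
    if (model == NULL) {
        fprintf(stderr, "failed to load %s\n", argv[1]);
        return 1;
    }

    // For a bitnet-b1.58 model this hits the case group extended above.
    printf("rope type: %d\n", (int) llama_rope_type(model));

    llama_free_model(model);
    llama_backend_free();
    return 0;
}
```
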