diff options
Diffstat (limited to 'llama.cpp')
-rw-r--r-- | llama.cpp | 66 |
1 files changed, 42 insertions, 24 deletions
@@ -1121,19 +1121,19 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA { LLM_ARCH_BITNET, { - { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, - { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, - { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, - { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, - { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, - { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, - { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, - { LLM_TENSOR_ATTN_SUB_NORM, "blk.%d.attn_sub_norm" }, - { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, - { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, - { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, - { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, - { LLM_TENSOR_FFN_SUB_NORM, "blk.%d.ffn_sub_norm" }, + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_SUB_NORM, "blk.%d.attn_sub_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_SUB_NORM, "blk.%d.ffn_sub_norm" }, }, }, { @@ -2210,6 +2210,15 @@ struct llama_layer { // long rope factors struct ggml_tensor * rope_long = nullptr; struct ggml_tensor * rope_short = nullptr; + + // bitnet scale + struct ggml_tensor * wq_scale; + struct ggml_tensor * wk_scale; + struct ggml_tensor * wv_scale; + struct ggml_tensor * wo_scale; + struct ggml_tensor * ffn_gate_scale; + struct ggml_tensor * ffn_up_scale; + struct ggml_tensor * ffn_down_scale; }; struct llama_kv_cell { @@ -6715,16 +6724,23 @@ static bool llm_load_tensors( layer.attn_sub_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_SUB_NORM, "weight", i), {n_embd}); layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}); + layer.wq_scale = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "scale", i), {1}); layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}); + layer.wk_scale = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "scale", i), {1}); layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}); + layer.wv_scale = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "scale", i), {1}); layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); + layer.wo_scale = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "scale", i), {1}); layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); layer.ffn_sub_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_SUB_NORM, "weight", i), {n_ff}); layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}); + layer.ffn_gate_scale = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "scale", i), {1}); layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}); + layer.ffn_down_scale = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "scale", i), {1}); layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); + layer.ffn_up_scale = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "scale", i), {1}); } } break; default: @@ -11810,6 +11826,7 @@ struct llm_build_context { { // compute Q and K and RoPE them struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur); + Qcur = ggml_mul(ctx0, Qcur, model.layers[il].wq_scale); cb(Qcur, "Qcur", il); if (model.layers[il].bq) { Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); @@ -11818,6 +11835,7 @@ struct llm_build_context { // B1.K struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur); + Kcur = ggml_mul(ctx0, Kcur, model.layers[il].wk_scale); cb(Kcur, "Kcur", il); if (model.layers[il].bk) { Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); @@ -11826,6 +11844,7 @@ struct llm_build_context { // B1.V struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur); + Vcur = ggml_mul(ctx0, Vcur, model.layers[il].wv_scale); cb(Vcur, "Vcur", il); if (model.layers[il].bv) { Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); @@ -11856,14 +11875,9 @@ struct llm_build_context { const int64_t n_embd_head_v = hparams.n_embd_head_v; const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(); - struct ggml_tensor * q_cur = Qcur; - struct ggml_tensor * kq_mask = KQ_mask; float kq_scale = 1.0f/sqrtf(float(n_embd_head)); - struct ggml_tensor * attn_sub_norm = model.layers[il].attn_sub_norm; - struct ggml_cgraph * graph = gf; - struct ggml_tensor * wo = model.layers[il].wo; struct ggml_tensor * cur_attn; - struct ggml_tensor * q = ggml_permute(ctx0, q_cur, 0, 2, 1, 3); + struct ggml_tensor * q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3); cb(q, "q", il); struct ggml_tensor * k = @@ -11885,14 +11899,14 @@ struct llm_build_context { 0); cb(v, "v", il); - cur_attn = ggml_flash_attn_ext(ctx0, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias); + cur_attn = ggml_flash_attn_ext(ctx0, q, k, v, KQ_mask, kq_scale, hparams.f_max_alibi_bias); cur_attn = ggml_reshape_2d(ctx0, cur, n_embd_head_v*n_head, n_tokens); } else { struct ggml_tensor * kq = ggml_mul_mat(ctx0, k, q); cb(kq, "kq", il); - kq = ggml_soft_max_ext(ctx0, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias); + kq = ggml_soft_max_ext(ctx0, kq, KQ_mask, kq_scale, hparams.f_max_alibi_bias); cb(kq, "kq_soft_max_ext", il); GGML_ASSERT(kv_self.size == n_ctx); @@ -11917,13 +11931,14 @@ struct llm_build_context { } cur_attn = llm_build_norm(ctx0, cur_attn, hparams, - attn_sub_norm, NULL, + model.layers[il].attn_sub_norm, NULL, LLM_NORM_RMS, cb, il); cb(cur_attn, "attn_sub_norm", il); - ggml_build_forward_expand(graph, cur_attn); + ggml_build_forward_expand(gf, cur_attn); - cur = ggml_mul_mat(ctx0, wo, cur_attn); + cur = ggml_mul_mat(ctx0, model.layers[il].wo, cur_attn); + cur = ggml_mul(ctx0, cur, model.layers[il].wo_scale); cb(cur, "kqv_out", il); } @@ -11946,10 +11961,12 @@ struct llm_build_context { cb(cur, "ffn_norm", il); struct ggml_tensor *tmp = ggml_mul_mat(ctx0, model.layers[il].ffn_up, cur); + tmp = ggml_mul(ctx0, tmp, model.layers[il].ffn_up_scale); cb(tmp, "ffn_up", il); cur = ggml_mul_mat(ctx0, model.layers[il].ffn_gate, cur); + cur = ggml_mul(ctx0, cur, model.layers[il].ffn_gate_scale); cb(cur, "ffn_gate", il); @@ -11966,6 +11983,7 @@ struct llm_build_context { cb(cur, "ffn_sub_norm", il); cur = ggml_mul_mat(ctx0, model.layers[il].ffn_down, cur); + cur = ggml_mul(ctx0, cur, model.layers[il].ffn_down_scale); cb(cur, "ffn_down", il); } cur = ggml_add(ctx0, cur, ffn_inp); |