| author | Iwan Kawrakow <iwan.kawrakow@gmail.com> | 2024-06-19 16:46:23 +0300 |
|---|---|---|
| committer | Iwan Kawrakow <iwan.kawrakow@gmail.com> | 2024-06-22 12:02:52 +0300 |
| commit | 58d9e8f1d2efba4b6717043f7a5167be670a6f2e (patch) | |
| tree | bc70f7b1197e9572c3efdfa84d349b729c41cf9b | |
| parent | 927e251a12fa287e13c6bd9667ee97d783486c09 (diff) | |
bitnet: put the scale in a separate tensor
and correspondingly apply it with an extra ggml_mul operation after
each ggml_mul_mat. As per @ggerganov, this is how things should be done.
It seems to be working, but as far as I can tell this
results in a ~15% performance penalty for prompt processing.
Committing so I can go and test on other platforms.
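For reference, the conversion script now factors every BitNet projection weight into a pure ternary tensor plus a single fp32 scale that is written out as a separate `.scale` tensor. A condensed sketch of that factorization (plain PyTorch, mirroring the new `weight_quant` in the diff below; the standalone function name here is illustrative):

```python
import torch

def weight_quant_with_scale(weight: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    # Classic BitNet round-to-ternary, as in the updated convert-hf-to-gguf.py
    dtype = weight.dtype
    weight = weight.float()
    s = 1 / weight.abs().mean().clamp(min=1e-5)
    weight = (weight * s).round().clamp(-1, 1) / s
    # The per-tensor scale is split out instead of being baked into the quant blocks
    scale = weight.abs().max().unsqueeze(0)
    # What is left is a pure {-1, 0, +1} tensor
    weight = torch.where(weight.abs().less(1e-6), 0, weight)
    return torch.sign(weight).type(dtype), scale.type(torch.float32)
```

On the GGUF side this yields tensor pairs such as `blk.%d.attn_q.weight` / `blk.%d.attn_q.scale`, which llama.cpp loads into the new `*_scale` members of `llama_layer`.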
-rwxr-xr-x | convert-hf-to-gguf.py | 38
-rw-r--r-- | ggml-common.h | 7
-rw-r--r-- | iqk-quantize.cpp | 41
-rw-r--r-- | iqk_mul_mat.cpp | 17
-rw-r--r-- | llama.cpp | 66
5 files changed, 97 insertions, 72 deletions
diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py
index 0c08b800..ebd36a9b 100755
--- a/convert-hf-to-gguf.py
+++ b/convert-hf-to-gguf.py
@@ -1416,17 +1416,47 @@ class BitnetModel(Model):
         dtype = weight.dtype
         weight = weight.float()
         s = 1 / weight.abs().mean().clamp(min=1e-5)
-        result = (weight * s).round().clamp(-1, 1) / s
-        return result.type(dtype)
+        weight = (weight * s).round().clamp(-1, 1) / s
+        scale = weight.abs().max().unsqueeze(0)
+        weight = torch.where(weight.abs().less(1e-6), 0, weight).type(dtype)
+        weight = torch.sign(weight).type(dtype)
+        return weight.type(dtype), scale.type(torch.float32)

     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         # transform weight into 1/0/-1 (in fp32)
         if name.endswith(("q_proj.weight", "k_proj.weight", "v_proj.weight",
                           "down_proj.weight", "up_proj.weight", "gate_proj.weight",
                           "o_proj.weight")):
-            data_torch = self.weight_quant(data_torch)
+            weight_torch, scale_torch = self.weight_quant(data_torch)

-        return [(self.map_tensor_name(name), data_torch)]
+        tensors: list[tuple[str, Tensor]] = []
+
+        if name.endswith("q_proj.weight"):
+            tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_Q, bid), weight_torch))
+            tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_Q, bid, suffix=".scale"), scale_torch))
+        elif name.endswith("k_proj.weight"):
+            tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_K, bid), weight_torch))
+            tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_K, bid, suffix=".scale"), scale_torch))
+        elif name.endswith("v_proj.weight"):
+            tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_V, bid), weight_torch))
+            tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_V, bid, suffix=".scale"), scale_torch))
+        elif name.endswith("o_proj.weight"):
+            tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_OUT, bid), weight_torch))
+            tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_OUT, bid, suffix=".scale"), scale_torch))
+        elif name.endswith("up_proj.weight"):
+            tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.FFN_UP, bid), weight_torch))
+            tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.FFN_UP, bid, suffix=".scale"), scale_torch))
+        elif name.endswith("down_proj.weight"):
+            tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.FFN_DOWN, bid), weight_torch))
+            tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.FFN_DOWN, bid, suffix=".scale"), scale_torch))
+        elif name.endswith("gate_proj.weight"):
+            tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE, bid), weight_torch))
+            tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE, bid, suffix=".scale"), scale_torch))
+
+        if len(tensors) == 0:
+            tensors.append((self.map_tensor_name(name), data_torch))
+
+        return tensors


 @Model.register("GrokForCausalLM")
diff --git a/ggml-common.h b/ggml-common.h
index f5a35960..d3945975 100644
--- a/ggml-common.h
+++ b/ggml-common.h
@@ -375,20 +375,19 @@ static_assert(sizeof(block_iq1_m) == QK_K/8 + QK_K/16 + QK_K/32, "wrong iq1_m bl
 //
 #define QK_IQ1BN 64
 typedef struct {
-    uint16_t extra;
+    uint8_t extra;
     uint8_t ql[QK_IQ1BN/8];
     uint8_t qh[QK_IQ1BN/16];
 } block_iq1_bn;
-static_assert(sizeof(block_iq1_bn) == sizeof(uint16_t) + QK_IQ1BN/8 + QK_IQ1BN/16, "wrong iq1_bn block size/padding");
+static_assert(sizeof(block_iq1_bn) == sizeof(uint8_t) + QK_IQ1BN/8 + QK_IQ1BN/16, "wrong iq1_bn block size/padding");
 //
 // Bitnet - implemented as 2.25 bpw
 //
 #define QK_IQ2BN 64
 typedef struct {
-    ggml_half d;
     uint8_t qs[QK_IQ2BN/4];
 } block_iq2_bn;
-static_assert(sizeof(block_iq2_bn) == sizeof(ggml_half) + QK_IQ2BN/4, "wrong iq2_bn block size/padding");
+static_assert(sizeof(block_iq2_bn) == QK_IQ2BN/4, "wrong iq2_bn block size/padding");

 // Used by IQ1_M quants
 typedef union {
diff --git a/iqk-quantize.cpp b/iqk-quantize.cpp
index 522ab2cd..6622d5ba 100644
--- a/iqk-quantize.cpp
+++ b/iqk-quantize.cpp
@@ -123,19 +123,10 @@ void IQ1BNQuantizer::quantize_one_row_1bn(const float * src, block_iq1_bn * y, i
     const auto& iq1bn = get_iq1bn_data();

-    auto max_in_row = row_max(n_per_row, src);
-
-    max_in_row *= 1.015625f; // i.e., round to nearest in our fp8 representation
-    if (max_in_row > iq1bn_max_value()) {
-        fprintf(stderr, "%s: found scale %g, which is outside the range of out fp8 representation\n", __func__, max_in_row);
-    }
-    auto u = iq1bn_float_to_fp8(max_in_row);
-
     for (int ib = 0; ib < nblock; ++ib) {
         std::memset(&y[ib], 0, sizeof(block_iq1_bn));
         auto xb = src + QK_IQ1BN*ib;
-        auto extra = quantize_one_block_1bn(iq1bn, xb, L, y[ib].ql, y[ib].qh);
-        y[ib].extra = u | (extra << 8);
+        y[ib].extra = quantize_one_block_1bn(iq1bn, xb, L, y[ib].ql, y[ib].qh);
     }
 }
@@ -145,13 +136,12 @@ void IQ1BNQuantizer::quantize_one_row_2bn(const float * src, block_iq2_bn * y, i
     const int nblock = n_per_row/QK_IQ1BN;

-    auto max_in_row = row_max(n_per_row, src);
-    ggml_half dh = GGML_FP32_TO_FP16(max_in_row);
+    //auto max_in_row = row_max(n_per_row, src);
+    //printf("%s: max = %g\n", __func__, max_in_row);

     constexpr int Nj = QK_IQ1BN/4;

     for (int ib = 0; ib < nblock; ++ib) {
-        y[ib].d = dh;
         auto xb = src + QK_IQ1BN*ib;
         for (int j = 0; j < QK_IQ1BN; ++j) {
             L[j] = fabsf(xb[j]) < 1e-6f ? 1 : xb[j] < 0 ? 0 : 2;
@@ -229,9 +219,8 @@ void dequantize_row_iq2_bn(const block_iq2_bn * x, float * y, int64_t k) {
     assert(k%QK_IQ1BN == 0);
     int nblock = k / QK_IQ1BN;

-    float d = GGML_FP16_TO_FP32(x[0].d);
-    auto m = -d;
-    auto d1 = d, d2 = d*0.25f, d3 = d2*0.25f, d4 = d3*0.25f;
+    auto d1 = 1.f, d2 = 0.25f, d3 = d2*0.25f, d4 = d3*0.25f;
+    auto m = -1.f;

     constexpr int Nj = QK_IQ1BN/4;
     for (int i = 0; i < nblock; ++i) {
         for (int j = 0; j < Nj; ++j) {
@@ -260,8 +249,6 @@ void ggml_vec_dot_iq1_bn_q8_0 (int n, float * s, size_t bs, const void * vx, siz
     float sumf = 0;

     for (int i = 0; i < nblock; ++i) {
-        float d = iq1bn_fp8_to_float(x[i].extra & 0xff);
-        uint8_t extra = x[i].extra >> 8;
         auto qh = x[i].qh;
         auto ql = x[i].ql;
         auto q8 = y[2*i+0].qs;
@@ -271,7 +258,7 @@ void ggml_vec_dot_iq1_bn_q8_0 (int n, float * s, size_t bs, const void * vx, siz
             uint16_t val = iq1bn_grid_u16[idx];
             int16_t sl = 0;
             for (int j = 0; j < 8; ++j) sl += q8[j] * (((val >> 2*j) & 3) - 1);
-            sumi1 += extra & (1 << k) ? -sl : sl;
+            sumi1 += x[i].extra & (1 << k) ? -sl : sl;
             q8 += 8;
         }
         q8 = y[2*i+1].qs;
@@ -281,10 +268,10 @@ void ggml_vec_dot_iq1_bn_q8_0 (int n, float * s, size_t bs, const void * vx, siz
             uint16_t val = iq1bn_grid_u16[idx];
             int16_t sl = 0;
             for (int j = 0; j < 8; ++j) sl += q8[j] * (((val >> 2*j) & 3) - 1);
-            sumi2 += extra & (1 << k) ? -sl : sl;
+            sumi2 += x[i].extra & (1 << k) ? -sl : sl;
             q8 += 8;
         }
-        sumf += d * (GGML_FP16_TO_FP32(y[2*i+0].d) * sumi1 + GGML_FP16_TO_FP32(y[2*i+1].d) * sumi2);
+        sumf += GGML_FP16_TO_FP32(y[2*i+0].d) * sumi1 + GGML_FP16_TO_FP32(y[2*i+1].d) * sumi2;
     }

     *s = sumf;
@@ -306,9 +293,7 @@ void ggml_vec_dot_iq1_bn_q8_K64(int n, float * s, size_t bs, const void * vx, si
     float sumf = 0;

-    float d = iq1bn_fp8_to_float(x[0].extra & 0xff);
     for (int i = 0; i < nblock; ++i) {
-        uint8_t extra = x[i].extra >> 8;
         auto qh = x[i].qh;
         auto ql = x[i].ql;
         auto q8 = y[i].qs;
@@ -318,7 +303,7 @@ void ggml_vec_dot_iq1_bn_q8_K64(int n, float * s, size_t bs, const void * vx, si
             uint16_t val = iq1bn_grid_u16[idx];
             int16_t sl = 0;
             for (int j = 0; j < 8; ++j) sl += q8[j] * (((val >> 2*j) & 3) - 1);
-            sumi += extra & (1 << k) ? -sl : sl;
+            sumi += x[i].extra & (1 << k) ? -sl : sl;
             q8 += 8;
         }
         for (int k = 4; k < 8; ++k) {
@@ -326,10 +311,10 @@ void ggml_vec_dot_iq1_bn_q8_K64(int n, float * s, size_t bs, const void * vx, si
             uint16_t val = iq1bn_grid_u16[idx];
             int16_t sl = 0;
             for (int j = 0; j < 8; ++j) sl += q8[j] * (((val >> 2*j) & 3) - 1);
-            sumi += extra & (1 << k) ? -sl : sl;
+            sumi += x[i].extra & (1 << k) ? -sl : sl;
             q8 += 8;
         }
-        sumf += d * (y[i].d) * sumi;
+        sumf += y[i].d * sumi;
     }

     *s = sumf;
@@ -352,8 +337,6 @@ void ggml_vec_dot_iq2_bn_q8_K64(int n, float * s, size_t bs, const void * vx, si
     float sumf = 0;

-    float d = GGML_FP16_TO_FP32(x[0].d);
-
     for (int i = 0; i < nblock; ++i) {
         auto q8 = y[i].qs;
         int s0 = 0, s1 = 0, s2 = 0, s3 = 0, s4 = 0;
@@ -367,7 +350,7 @@ void ggml_vec_dot_iq2_bn_q8_K64(int n, float * s, size_t bs, const void * vx, si
         sumf += y[i].d * (s1 + 0.25f*s2 + 0.0625*s3 + 0.015625*s4 - s0);
     }

-    *s = sumf * d;
+    *s = sumf;
 }
diff --git a/iqk_mul_mat.cpp b/iqk_mul_mat.cpp
index 08f954e1..0cee0ff4 100644
--- a/iqk_mul_mat.cpp
+++ b/iqk_mul_mat.cpp
@@ -1343,19 +1343,17 @@ IQK_NOINLINE void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const
     const auto m1_16 = _mm256_set1_epi16(1);
 #endif

-    //auto step = bx / sizeof(block_iq1_bn);
     const block_iq1_bn * x = (const block_iq1_bn *)((const char *)vx);

     for (int ix = 0; ix < nrc_x; ++ix) {

         x = (const block_iq1_bn *)((const char *)vx + ix*bx);
-        float d1 = iq1bn_fp8_to_float(x[0].extra & 0xff);

         for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm256_setzero_ps();

         for (int i = 0; i < nb; ++i) {

-            auto all_signs = _mm256_set1_epi8(x[i].extra >> 8);
+            auto all_signs = _mm256_set1_epi8(x[i].extra);
             all_signs = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(all_signs, mask1), mask1), m1_8);
             signs[0] = _mm256_shuffle_epi8(all_signs, shuff3);
             signs[1] = _mm256_shuffle_epi8(all_signs, shuff4);
@@ -1398,7 +1396,7 @@ IQK_NOINLINE void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const
         }

         for (int iy = 0; iy < nrc_y; ++iy) {
-            info.store(ix, iy, d1 * hsum_float_8(accd[iy]));
+            info.store(ix, iy, hsum_float_8(accd[iy]));
         }
     }
@@ -1419,7 +1417,6 @@ IQK_NOINLINE void mul_mat_iq2bn_q8_K64(int n, const void * vx, size_t bx, const
     for (int ix = 0; ix < nrc_x; ++ix) {

         const block_iq2_bn * x = (const block_iq2_bn *)((const char *)vx + ix*bx);
-        float d = GGML_FP16_TO_FP32(x[0].d);

         {
             auto q2bits = _mm_loadu_si128((const __m128i *)x[0].qs);
@@ -1456,7 +1453,7 @@ IQK_NOINLINE void mul_mat_iq2bn_q8_K64(int n, const void * vx, size_t bx, const
         }

         for (int iy = 0; iy < nrc_y; ++iy) {
-            info.store(ix, iy, d * hsum_float_8(accd[iy]));
+            info.store(ix, iy, hsum_float_8(accd[iy]));
         }
     }
@@ -4129,13 +4126,12 @@ static void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const DataIn
     for (int ix = 0; ix < nrc_x; ++ix) {

         x = (const block_iq1_bn *)((const char *)vx + ix*bx);
-        float d1 = iq1bn_fp8_to_float(x[0].extra & 0xff);

         for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = vdupq_n_f32(0.f);

         for (int i = 0; i < nb; ++i) {

-            auto all_signs = vdupq_n_u8(x[i].extra >> 8);
+            auto all_signs = vdupq_n_u8(x[i].extra);
             all_signs = vorrq_u8(vceqq_u8(vandq_u8(all_signs, mask1), mask1), m1);
             signs.val[0] = vqtbl1q_u8(all_signs, sign_shuffles.val[0]);
             signs.val[1] = vqtbl1q_u8(all_signs, sign_shuffles.val[1]);
@@ -4179,7 +4175,7 @@ static void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const DataIn
         }

         for (int iy = 0; iy < nrc_y; ++iy) {
-            info.store(ix, iy, d1 * vaddvq_f32(accd[iy]));
+            info.store(ix, iy, vaddvq_f32(accd[iy]));
         }
     }
@@ -4200,7 +4196,6 @@ static void mul_mat_iq2bn_q8_K64(int n, const void * vx, size_t bx, const DataIn
     for (int ix = 0; ix < nrc_x; ++ix) {

         const block_iq2_bn * x = (const block_iq2_bn *)((const char *)vx + ix*bx);
-        const float d = GGML_FP16_TO_FP32(x[0].d);

         {
             auto q2bits = vld1q_u8(x[0].qs);
@@ -4236,7 +4231,7 @@ static void mul_mat_iq2bn_q8_K64(int n, const void * vx, size_t bx, const DataIn
         }

         for (int iy = 0; iy < nrc_y; ++iy) {
-            info.store(ix, iy, d * vaddvq_f32(accd[iy]));
+            info.store(ix, iy, vaddvq_f32(accd[iy]));
         }
     }
diff --git a/llama.cpp b/llama.cpp
--- a/llama.cpp
+++ b/llama.cpp
@@ -1121,19 +1121,19 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
     {
         LLM_ARCH_BITNET,
         {
-            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
-            { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
-            { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
-            { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
-            { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
-            { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
-            { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
-            { LLM_TENSOR_ATTN_SUB_NORM, "blk.%d.attn_sub_norm" },
-            { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
-            { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
-            { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
-            { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
-            { LLM_TENSOR_FFN_SUB_NORM, "blk.%d.ffn_sub_norm" },
+            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
+            { LLM_TENSOR_ATTN_Q,          "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_K,          "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_V,          "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
+            { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_SUB_NORM,   "blk.%d.attn_sub_norm" },
+            { LLM_TENSOR_FFN_GATE,        "blk.%d.ffn_gate" },
+            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
+            { LLM_TENSOR_FFN_NORM,        "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_SUB_NORM,    "blk.%d.ffn_sub_norm" },
         },
     },
     {
@@ -2210,6 +2210,15 @@ struct llama_layer {
     // long rope factors
     struct ggml_tensor * rope_long = nullptr;
     struct ggml_tensor * rope_short = nullptr;
+
+    // bitnet scale
+    struct ggml_tensor * wq_scale;
+    struct ggml_tensor * wk_scale;
+    struct ggml_tensor * wv_scale;
+    struct ggml_tensor * wo_scale;
+    struct ggml_tensor * ffn_gate_scale;
+    struct ggml_tensor * ffn_up_scale;
+    struct ggml_tensor * ffn_down_scale;
 };

 struct llama_kv_cell {
@@ -6715,16 +6724,23 @@ static bool llm_load_tensors(
                         layer.attn_sub_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_SUB_NORM, "weight", i), {n_embd});

                         layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd});
+                        layer.wq_scale = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "scale", i), {1});
                         layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa});
+                        layer.wk_scale = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "scale", i), {1});
                         layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa});
+                        layer.wv_scale = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "scale", i), {1});
                         layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
+                        layer.wo_scale = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "scale", i), {1});

                         layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
                         layer.ffn_sub_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_SUB_NORM, "weight", i), {n_ff});

                         layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff});
+                        layer.ffn_gate_scale = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "scale", i), {1});
                         layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd});
+                        layer.ffn_down_scale = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "scale", i), {1});
                         layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff});
+                        layer.ffn_up_scale = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "scale", i), {1});
                     }
                 } break;
             default:
@@ -11810,6 +11826,7 @@ struct llm_build_context {
             {
                 // compute Q and K and RoPE them
                 struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
+                Qcur = ggml_mul(ctx0, Qcur, model.layers[il].wq_scale);
                 cb(Qcur, "Qcur", il);
                 if (model.layers[il].bq) {
                     Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
@@ -11818,6 +11835,7 @@ struct llm_build_context {

                 // B1.K
                 struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
+                Kcur = ggml_mul(ctx0, Kcur, model.layers[il].wk_scale);
                 cb(Kcur, "Kcur", il);
                 if (model.layers[il].bk) {
                     Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
@@ -11826,6 +11844,7 @@ struct llm_build_context {

                 // B1.V
                 struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
+                Vcur = ggml_mul(ctx0, Vcur, model.layers[il].wv_scale);
                 cb(Vcur, "Vcur", il);
                 if (model.layers[il].bv) {
                     Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
@@ -11856,14 +11875,9 @@ struct llm_build_context {
                 const int64_t n_embd_head_v = hparams.n_embd_head_v;
                 const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa();

-                struct ggml_tensor * q_cur = Qcur;
-                struct ggml_tensor * kq_mask = KQ_mask;
                 float kq_scale = 1.0f/sqrtf(float(n_embd_head));
-                struct ggml_tensor * attn_sub_norm = model.layers[il].attn_sub_norm;
-                struct ggml_cgraph * graph = gf;
-                struct ggml_tensor * wo = model.layers[il].wo;
                 struct ggml_tensor * cur_attn;
-                struct ggml_tensor * q = ggml_permute(ctx0, q_cur, 0, 2, 1, 3);
+                struct ggml_tensor * q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
                 cb(q, "q", il);

                 struct ggml_tensor * k =
@@ -11885,14 +11899,14 @@ struct llm_build_context {
                                 0);
                     cb(v, "v", il);

-                    cur_attn = ggml_flash_attn_ext(ctx0, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias);
+                    cur_attn = ggml_flash_attn_ext(ctx0, q, k, v, KQ_mask, kq_scale, hparams.f_max_alibi_bias);

                     cur_attn = ggml_reshape_2d(ctx0, cur, n_embd_head_v*n_head, n_tokens);
                 } else {
                     struct ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
                     cb(kq, "kq", il);

-                    kq = ggml_soft_max_ext(ctx0, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias);
+                    kq = ggml_soft_max_ext(ctx0, kq, KQ_mask, kq_scale, hparams.f_max_alibi_bias);
                     cb(kq, "kq_soft_max_ext", il);

                     GGML_ASSERT(kv_self.size == n_ctx);
@@ -11917,13 +11931,14 @@ struct llm_build_context {
                 }

                 cur_attn = llm_build_norm(ctx0, cur_attn, hparams,
-                        attn_sub_norm, NULL,
+                        model.layers[il].attn_sub_norm, NULL,
                         LLM_NORM_RMS, cb, il);
                 cb(cur_attn, "attn_sub_norm", il);

-                ggml_build_forward_expand(graph, cur_attn);
+                ggml_build_forward_expand(gf, cur_attn);

-                cur = ggml_mul_mat(ctx0, wo, cur_attn);
+                cur = ggml_mul_mat(ctx0, model.layers[il].wo, cur_attn);
+                cur = ggml_mul(ctx0, cur, model.layers[il].wo_scale);
                 cb(cur, "kqv_out", il);
             }
@@ -11946,10 +11961,12 @@ struct llm_build_context {
                 cb(cur, "ffn_norm", il);

                 struct ggml_tensor *tmp = ggml_mul_mat(ctx0, model.layers[il].ffn_up, cur);
+                tmp = ggml_mul(ctx0, tmp, model.layers[il].ffn_up_scale);

                 cb(tmp, "ffn_up", il);

                 cur = ggml_mul_mat(ctx0, model.layers[il].ffn_gate, cur);
+                cur = ggml_mul(ctx0, cur, model.layers[il].ffn_gate_scale);

                 cb(cur, "ffn_gate", il);
@@ -11966,6 +11983,7 @@ struct llm_build_context {
                 cb(cur, "ffn_sub_norm", il);

                 cur = ggml_mul_mat(ctx0, model.layers[il].ffn_down, cur);
+                cur = ggml_mul(ctx0, cur, model.layers[il].ffn_down_scale);
                 cb(cur, "ffn_down", il);
             }
             cur = ggml_add(ctx0, cur, ffn_inp);
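Since the scale now lives in its own {1}-shaped tensor, the graph builder re-applies it with a broadcast ggml_mul directly after each of the seven projection ggml_mul_mat calls (wq, wk, wv, wo, ffn_up, ffn_gate, ffn_down), while the dot-product kernels accumulate unscaled ternary products. The factorization is exact: with Wq the BitNet-quantized weight, scale = max|Wq| and T = sign(Wq), we have Wq @ x == scale * (T @ x). A small PyTorch check of that identity (shapes and tolerance chosen only for illustration):

```python
import torch

torch.manual_seed(0)
W = torch.randn(256, 256)

# BitNet quantization as in convert-hf-to-gguf.py
s     = 1 / W.abs().mean().clamp(min=1e-5)
Wq    = (W * s).round().clamp(-1, 1) / s     # quantized weight, values in {-1/s, 0, +1/s}
scale = Wq.abs().max()                       # the separate .scale tensor
T     = torch.sign(torch.where(Wq.abs() < 1e-6, torch.zeros_like(Wq), Wq))  # ternary part

x = torch.randn(256)
# applying the scale after the matmul reproduces the original product
assert torch.allclose(Wq @ x, scale * (T @ x), atol=1e-4)
```

This is also why the per-block scales removed in this commit (the fp8 value packed into block_iq1_bn.extra and the ggml_half d in block_iq2_bn) are no longer needed: the single per-tensor scale is applied once in the graph instead.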