summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorIwan Kawrakow <iwan.kawrakow@gmail.com>2024-06-19 16:46:23 +0300
committerIwan Kawrakow <iwan.kawrakow@gmail.com>2024-06-22 12:02:52 +0300
commit58d9e8f1d2efba4b6717043f7a5167be670a6f2e (patch)
treebc70f7b1197e9572c3efdfa84d349b729c41cf9b
parent927e251a12fa287e13c6bd9667ee97d783486c09 (diff)
bitnet: put the scale in a separate tensor
and correspondingly add an extra ggml_mul_mat operation. As per @ggerganov, this is how things should be done. It seems to be working, but as far as I can tell this results in a ~15% performance penalty for prompt processing. Commiting so I can go and test on othe platforms.
-rwxr-xr-xconvert-hf-to-gguf.py38
-rw-r--r--ggml-common.h7
-rw-r--r--iqk-quantize.cpp41
-rw-r--r--iqk_mul_mat.cpp17
-rw-r--r--llama.cpp66
5 files changed, 97 insertions, 72 deletions
diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py
index 0c08b800..ebd36a9b 100755
--- a/convert-hf-to-gguf.py
+++ b/convert-hf-to-gguf.py
@@ -1416,17 +1416,47 @@ class BitnetModel(Model):
dtype = weight.dtype
weight = weight.float()
s = 1 / weight.abs().mean().clamp(min=1e-5)
- result = (weight * s).round().clamp(-1, 1) / s
- return result.type(dtype)
+ weight = (weight * s).round().clamp(-1, 1) / s
+ scale = weight.abs().max().unsqueeze(0)
+ weight = torch.where(weight.abs().less(1e-6), 0, weight).type(dtype)
+ weight = torch.sign(weight).type(dtype)
+ return weight.type(dtype), scale.type(torch.float32)
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
# transform weight into 1/0/-1 (in fp32)
if name.endswith(("q_proj.weight", "k_proj.weight", "v_proj.weight",
"down_proj.weight", "up_proj.weight", "gate_proj.weight",
"o_proj.weight")):
- data_torch = self.weight_quant(data_torch)
+ weight_torch, scale_torch = self.weight_quant(data_torch)
- return [(self.map_tensor_name(name), data_torch)]
+ tensors: list[tuple[str, Tensor]] = []
+
+ if name.endswith("q_proj.weight"):
+ tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_Q, bid), weight_torch))
+ tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_Q, bid, suffix=".scale"), scale_torch))
+ elif name.endswith("k_proj.weight"):
+ tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_K, bid), weight_torch))
+ tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_K, bid, suffix=".scale"), scale_torch))
+ elif name.endswith("v_proj.weight"):
+ tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_V, bid), weight_torch))
+ tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_V, bid, suffix=".scale"), scale_torch))
+ elif name.endswith("o_proj.weight"):
+ tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_OUT, bid), weight_torch))
+ tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_OUT, bid, suffix=".scale"), scale_torch))
+ elif name.endswith("up_proj.weight"):
+ tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.FFN_UP, bid), weight_torch))
+ tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.FFN_UP, bid, suffix=".scale"), scale_torch))
+ elif name.endswith("down_proj.weight"):
+ tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.FFN_DOWN, bid), weight_torch))
+ tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.FFN_DOWN, bid, suffix=".scale"), scale_torch))
+ elif name.endswith("gate_proj.weight"):
+ tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE, bid), weight_torch))
+ tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE, bid, suffix=".scale"), scale_torch))
+
+ if len(tensors) == 0:
+ tensors.append((self.map_tensor_name(name), data_torch))
+
+ return tensors
@Model.register("GrokForCausalLM")
diff --git a/ggml-common.h b/ggml-common.h
index f5a35960..d3945975 100644
--- a/ggml-common.h
+++ b/ggml-common.h
@@ -375,20 +375,19 @@ static_assert(sizeof(block_iq1_m) == QK_K/8 + QK_K/16 + QK_K/32, "wrong iq1_m bl
//
#define QK_IQ1BN 64
typedef struct {
- uint16_t extra;
+ uint8_t extra;
uint8_t ql[QK_IQ1BN/8];
uint8_t qh[QK_IQ1BN/16];
} block_iq1_bn;
-static_assert(sizeof(block_iq1_bn) == sizeof(uint16_t) + QK_IQ1BN/8 + QK_IQ1BN/16, "wrong iq1_bn block size/padding");
+static_assert(sizeof(block_iq1_bn) == sizeof(uint8_t) + QK_IQ1BN/8 + QK_IQ1BN/16, "wrong iq1_bn block size/padding");
//
// Bitnet - implemented as 2.25 bpw
//
#define QK_IQ2BN 64
typedef struct {
- ggml_half d;
uint8_t qs[QK_IQ2BN/4];
} block_iq2_bn;
-static_assert(sizeof(block_iq2_bn) == sizeof(ggml_half) + QK_IQ2BN/4, "wrong iq2_bn block size/padding");
+static_assert(sizeof(block_iq2_bn) == QK_IQ2BN/4, "wrong iq2_bn block size/padding");
// Used by IQ1_M quants
typedef union {
diff --git a/iqk-quantize.cpp b/iqk-quantize.cpp
index 522ab2cd..6622d5ba 100644
--- a/iqk-quantize.cpp
+++ b/iqk-quantize.cpp
@@ -123,19 +123,10 @@ void IQ1BNQuantizer::quantize_one_row_1bn(const float * src, block_iq1_bn * y, i
const auto& iq1bn = get_iq1bn_data();
- auto max_in_row = row_max(n_per_row, src);
-
- max_in_row *= 1.015625f; // i.e., round to nearest in our fp8 representation
- if (max_in_row > iq1bn_max_value()) {
- fprintf(stderr, "%s: found scale %g, which is outside the range of out fp8 representation\n", __func__, max_in_row);
- }
- auto u = iq1bn_float_to_fp8(max_in_row);
-
for (int ib = 0; ib < nblock; ++ib) {
std::memset(&y[ib], 0, sizeof(block_iq1_bn));
auto xb = src + QK_IQ1BN*ib;
- auto extra = quantize_one_block_1bn(iq1bn, xb, L, y[ib].ql, y[ib].qh);
- y[ib].extra = u | (extra << 8);
+ y[ib].extra = quantize_one_block_1bn(iq1bn, xb, L, y[ib].ql, y[ib].qh);
}
}
@@ -145,13 +136,12 @@ void IQ1BNQuantizer::quantize_one_row_2bn(const float * src, block_iq2_bn * y, i
const int nblock = n_per_row/QK_IQ1BN;
- auto max_in_row = row_max(n_per_row, src);
- ggml_half dh = GGML_FP32_TO_FP16(max_in_row);
+ //auto max_in_row = row_max(n_per_row, src);
+ //printf("%s: max = %g\n", __func__, max_in_row);
constexpr int Nj = QK_IQ1BN/4;
for (int ib = 0; ib < nblock; ++ib) {
- y[ib].d = dh;
auto xb = src + QK_IQ1BN*ib;
for (int j = 0; j < QK_IQ1BN; ++j) {
L[j] = fabsf(xb[j]) < 1e-6f ? 1 : xb[j] < 0 ? 0 : 2;
@@ -229,9 +219,8 @@ void dequantize_row_iq2_bn(const block_iq2_bn * x, float * y, int64_t k) {
assert(k%QK_IQ1BN == 0);
int nblock = k / QK_IQ1BN;
- float d = GGML_FP16_TO_FP32(x[0].d);
- auto m = -d;
- auto d1 = d, d2 = d*0.25f, d3 = d2*0.25f, d4 = d3*0.25f;
+ auto d1 = 1.f, d2 = 0.25f, d3 = d2*0.25f, d4 = d3*0.25f;
+ auto m = -1.f;
constexpr int Nj = QK_IQ1BN/4;
for (int i = 0; i < nblock; ++i) {
for (int j = 0; j < Nj; ++j) {
@@ -260,8 +249,6 @@ void ggml_vec_dot_iq1_bn_q8_0 (int n, float * s, size_t bs, const void * vx, siz
float sumf = 0;
for (int i = 0; i < nblock; ++i) {
- float d = iq1bn_fp8_to_float(x[i].extra & 0xff);
- uint8_t extra = x[i].extra >> 8;
auto qh = x[i].qh;
auto ql = x[i].ql;
auto q8 = y[2*i+0].qs;
@@ -271,7 +258,7 @@ void ggml_vec_dot_iq1_bn_q8_0 (int n, float * s, size_t bs, const void * vx, siz
uint16_t val = iq1bn_grid_u16[idx];
int16_t sl = 0;
for (int j = 0; j < 8; ++j) sl += q8[j] * (((val >> 2*j) & 3) - 1);
- sumi1 += extra & (1 << k) ? -sl : sl;
+ sumi1 += x[i].extra & (1 << k) ? -sl : sl;
q8 += 8;
}
q8 = y[2*i+1].qs;
@@ -281,10 +268,10 @@ void ggml_vec_dot_iq1_bn_q8_0 (int n, float * s, size_t bs, const void * vx, siz
uint16_t val = iq1bn_grid_u16[idx];
int16_t sl = 0;
for (int j = 0; j < 8; ++j) sl += q8[j] * (((val >> 2*j) & 3) - 1);
- sumi2 += extra & (1 << k) ? -sl : sl;
+ sumi2 += x[i].extra & (1 << k) ? -sl : sl;
q8 += 8;
}
- sumf += d * (GGML_FP16_TO_FP32(y[2*i+0].d) * sumi1 + GGML_FP16_TO_FP32(y[2*i+1].d) * sumi2);
+ sumf += GGML_FP16_TO_FP32(y[2*i+0].d) * sumi1 + GGML_FP16_TO_FP32(y[2*i+1].d) * sumi2;
}
*s = sumf;
@@ -306,9 +293,7 @@ void ggml_vec_dot_iq1_bn_q8_K64(int n, float * s, size_t bs, const void * vx, si
float sumf = 0;
- float d = iq1bn_fp8_to_float(x[0].extra & 0xff);
for (int i = 0; i < nblock; ++i) {
- uint8_t extra = x[i].extra >> 8;
auto qh = x[i].qh;
auto ql = x[i].ql;
auto q8 = y[i].qs;
@@ -318,7 +303,7 @@ void ggml_vec_dot_iq1_bn_q8_K64(int n, float * s, size_t bs, const void * vx, si
uint16_t val = iq1bn_grid_u16[idx];
int16_t sl = 0;
for (int j = 0; j < 8; ++j) sl += q8[j] * (((val >> 2*j) & 3) - 1);
- sumi += extra & (1 << k) ? -sl : sl;
+ sumi += x[i].extra & (1 << k) ? -sl : sl;
q8 += 8;
}
for (int k = 4; k < 8; ++k) {
@@ -326,10 +311,10 @@ void ggml_vec_dot_iq1_bn_q8_K64(int n, float * s, size_t bs, const void * vx, si
uint16_t val = iq1bn_grid_u16[idx];
int16_t sl = 0;
for (int j = 0; j < 8; ++j) sl += q8[j] * (((val >> 2*j) & 3) - 1);
- sumi += extra & (1 << k) ? -sl : sl;
+ sumi += x[i].extra & (1 << k) ? -sl : sl;
q8 += 8;
}
- sumf += d * (y[i].d) * sumi;
+ sumf += y[i].d * sumi;
}
*s = sumf;
@@ -352,8 +337,6 @@ void ggml_vec_dot_iq2_bn_q8_K64(int n, float * s, size_t bs, const void * vx, si
float sumf = 0;
- float d = GGML_FP16_TO_FP32(x[0].d);
-
for (int i = 0; i < nblock; ++i) {
auto q8 = y[i].qs;
int s0 = 0, s1 = 0, s2 = 0, s3 = 0, s4 = 0;
@@ -367,7 +350,7 @@ void ggml_vec_dot_iq2_bn_q8_K64(int n, float * s, size_t bs, const void * vx, si
sumf += y[i].d * (s1 + 0.25f*s2 + 0.0625*s3 + 0.015625*s4 - s0);
}
- *s = sumf * d;
+ *s = sumf;
}
diff --git a/iqk_mul_mat.cpp b/iqk_mul_mat.cpp
index 08f954e1..0cee0ff4 100644
--- a/iqk_mul_mat.cpp
+++ b/iqk_mul_mat.cpp
@@ -1343,19 +1343,17 @@ IQK_NOINLINE void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const
const auto m1_16 = _mm256_set1_epi16(1);
#endif
- //auto step = bx / sizeof(block_iq1_bn);
const block_iq1_bn * x = (const block_iq1_bn *)((const char *)vx);
for (int ix = 0; ix < nrc_x; ++ix) {
x = (const block_iq1_bn *)((const char *)vx + ix*bx);
- float d1 = iq1bn_fp8_to_float(x[0].extra & 0xff);
for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm256_setzero_ps();
for (int i = 0; i < nb; ++i) {
- auto all_signs = _mm256_set1_epi8(x[i].extra >> 8);
+ auto all_signs = _mm256_set1_epi8(x[i].extra);
all_signs = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(all_signs, mask1), mask1), m1_8);
signs[0] = _mm256_shuffle_epi8(all_signs, shuff3);
signs[1] = _mm256_shuffle_epi8(all_signs, shuff4);
@@ -1398,7 +1396,7 @@ IQK_NOINLINE void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const
}
for (int iy = 0; iy < nrc_y; ++iy) {
- info.store(ix, iy, d1 * hsum_float_8(accd[iy]));
+ info.store(ix, iy, hsum_float_8(accd[iy]));
}
}
@@ -1419,7 +1417,6 @@ IQK_NOINLINE void mul_mat_iq2bn_q8_K64(int n, const void * vx, size_t bx, const
for (int ix = 0; ix < nrc_x; ++ix) {
const block_iq2_bn * x = (const block_iq2_bn *)((const char *)vx + ix*bx);
- float d = GGML_FP16_TO_FP32(x[0].d);
{
auto q2bits = _mm_loadu_si128((const __m128i *)x[0].qs);
@@ -1456,7 +1453,7 @@ IQK_NOINLINE void mul_mat_iq2bn_q8_K64(int n, const void * vx, size_t bx, const
}
for (int iy = 0; iy < nrc_y; ++iy) {
- info.store(ix, iy, d * hsum_float_8(accd[iy]));
+ info.store(ix, iy, hsum_float_8(accd[iy]));
}
}
@@ -4129,13 +4126,12 @@ static void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const DataIn
for (int ix = 0; ix < nrc_x; ++ix) {
x = (const block_iq1_bn *)((const char *)vx + ix*bx);
- float d1 = iq1bn_fp8_to_float(x[0].extra & 0xff);
for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = vdupq_n_f32(0.f);
for (int i = 0; i < nb; ++i) {
- auto all_signs = vdupq_n_u8(x[i].extra >> 8);
+ auto all_signs = vdupq_n_u8(x[i].extra);
all_signs = vorrq_u8(vceqq_u8(vandq_u8(all_signs, mask1), mask1), m1);
signs.val[0] = vqtbl1q_u8(all_signs, sign_shuffles.val[0]);
signs.val[1] = vqtbl1q_u8(all_signs, sign_shuffles.val[1]);
@@ -4179,7 +4175,7 @@ static void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const DataIn
}
for (int iy = 0; iy < nrc_y; ++iy) {
- info.store(ix, iy, d1 * vaddvq_f32(accd[iy]));
+ info.store(ix, iy, vaddvq_f32(accd[iy]));
}
}
@@ -4200,7 +4196,6 @@ static void mul_mat_iq2bn_q8_K64(int n, const void * vx, size_t bx, const DataIn
for (int ix = 0; ix < nrc_x; ++ix) {
const block_iq2_bn * x = (const block_iq2_bn *)((const char *)vx + ix*bx);
- const float d = GGML_FP16_TO_FP32(x[0].d);
{
auto q2bits = vld1q_u8(x[0].qs);
@@ -4236,7 +4231,7 @@ static void mul_mat_iq2bn_q8_K64(int n, const void * vx, size_t bx, const DataIn
}
for (int iy = 0; iy < nrc_y; ++iy) {
- info.store(ix, iy, d * vaddvq_f32(accd[iy]));
+ info.store(ix, iy, vaddvq_f32(accd[iy]));
}
}
diff --git a/llama.cpp b/llama.cpp
index 08cb6c35..46e94310 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -1121,19 +1121,19 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
{
LLM_ARCH_BITNET,
{
- { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
- { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
- { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
- { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
- { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
- { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
- { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
- { LLM_TENSOR_ATTN_SUB_NORM, "blk.%d.attn_sub_norm" },
- { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
- { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
- { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
- { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
- { LLM_TENSOR_FFN_SUB_NORM, "blk.%d.ffn_sub_norm" },
+ { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+ { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
+ { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
+ { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
+ { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
+ { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
+ { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
+ { LLM_TENSOR_ATTN_SUB_NORM, "blk.%d.attn_sub_norm" },
+ { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
+ { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
+ { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
+ { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
+ { LLM_TENSOR_FFN_SUB_NORM, "blk.%d.ffn_sub_norm" },
},
},
{
@@ -2210,6 +2210,15 @@ struct llama_layer {
// long rope factors
struct ggml_tensor * rope_long = nullptr;
struct ggml_tensor * rope_short = nullptr;
+
+ // bitnet scale
+ struct ggml_tensor * wq_scale;
+ struct ggml_tensor * wk_scale;
+ struct ggml_tensor * wv_scale;
+ struct ggml_tensor * wo_scale;
+ struct ggml_tensor * ffn_gate_scale;
+ struct ggml_tensor * ffn_up_scale;
+ struct ggml_tensor * ffn_down_scale;
};
struct llama_kv_cell {
@@ -6715,16 +6724,23 @@ static bool llm_load_tensors(
layer.attn_sub_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_SUB_NORM, "weight", i), {n_embd});
layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd});
+ layer.wq_scale = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "scale", i), {1});
layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa});
+ layer.wk_scale = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "scale", i), {1});
layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa});
+ layer.wv_scale = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "scale", i), {1});
layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
+ layer.wo_scale = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "scale", i), {1});
layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
layer.ffn_sub_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_SUB_NORM, "weight", i), {n_ff});
layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff});
+ layer.ffn_gate_scale = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "scale", i), {1});
layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd});
+ layer.ffn_down_scale = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "scale", i), {1});
layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff});
+ layer.ffn_up_scale = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "scale", i), {1});
}
} break;
default:
@@ -11810,6 +11826,7 @@ struct llm_build_context {
{
// compute Q and K and RoPE them
struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
+ Qcur = ggml_mul(ctx0, Qcur, model.layers[il].wq_scale);
cb(Qcur, "Qcur", il);
if (model.layers[il].bq) {
Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
@@ -11818,6 +11835,7 @@ struct llm_build_context {
// B1.K
struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
+ Kcur = ggml_mul(ctx0, Kcur, model.layers[il].wk_scale);
cb(Kcur, "Kcur", il);
if (model.layers[il].bk) {
Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
@@ -11826,6 +11844,7 @@ struct llm_build_context {
// B1.V
struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
+ Vcur = ggml_mul(ctx0, Vcur, model.layers[il].wv_scale);
cb(Vcur, "Vcur", il);
if (model.layers[il].bv) {
Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
@@ -11856,14 +11875,9 @@ struct llm_build_context {
const int64_t n_embd_head_v = hparams.n_embd_head_v;
const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa();
- struct ggml_tensor * q_cur = Qcur;
- struct ggml_tensor * kq_mask = KQ_mask;
float kq_scale = 1.0f/sqrtf(float(n_embd_head));
- struct ggml_tensor * attn_sub_norm = model.layers[il].attn_sub_norm;
- struct ggml_cgraph * graph = gf;
- struct ggml_tensor * wo = model.layers[il].wo;
struct ggml_tensor * cur_attn;
- struct ggml_tensor * q = ggml_permute(ctx0, q_cur, 0, 2, 1, 3);
+ struct ggml_tensor * q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
cb(q, "q", il);
struct ggml_tensor * k =
@@ -11885,14 +11899,14 @@ struct llm_build_context {
0);
cb(v, "v", il);
- cur_attn = ggml_flash_attn_ext(ctx0, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias);
+ cur_attn = ggml_flash_attn_ext(ctx0, q, k, v, KQ_mask, kq_scale, hparams.f_max_alibi_bias);
cur_attn = ggml_reshape_2d(ctx0, cur, n_embd_head_v*n_head, n_tokens);
} else {
struct ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
cb(kq, "kq", il);
- kq = ggml_soft_max_ext(ctx0, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias);
+ kq = ggml_soft_max_ext(ctx0, kq, KQ_mask, kq_scale, hparams.f_max_alibi_bias);
cb(kq, "kq_soft_max_ext", il);
GGML_ASSERT(kv_self.size == n_ctx);
@@ -11917,13 +11931,14 @@ struct llm_build_context {
}
cur_attn = llm_build_norm(ctx0, cur_attn, hparams,
- attn_sub_norm, NULL,
+ model.layers[il].attn_sub_norm, NULL,
LLM_NORM_RMS, cb, il);
cb(cur_attn, "attn_sub_norm", il);
- ggml_build_forward_expand(graph, cur_attn);
+ ggml_build_forward_expand(gf, cur_attn);
- cur = ggml_mul_mat(ctx0, wo, cur_attn);
+ cur = ggml_mul_mat(ctx0, model.layers[il].wo, cur_attn);
+ cur = ggml_mul(ctx0, cur, model.layers[il].wo_scale);
cb(cur, "kqv_out", il);
}
@@ -11946,10 +11961,12 @@ struct llm_build_context {
cb(cur, "ffn_norm", il);
struct ggml_tensor *tmp = ggml_mul_mat(ctx0, model.layers[il].ffn_up, cur);
+ tmp = ggml_mul(ctx0, tmp, model.layers[il].ffn_up_scale);
cb(tmp, "ffn_up", il);
cur = ggml_mul_mat(ctx0, model.layers[il].ffn_gate, cur);
+ cur = ggml_mul(ctx0, cur, model.layers[il].ffn_gate_scale);
cb(cur, "ffn_gate", il);
@@ -11966,6 +11983,7 @@ struct llm_build_context {
cb(cur, "ffn_sub_norm", il);
cur = ggml_mul_mat(ctx0, model.layers[il].ffn_down, cur);
+ cur = ggml_mul(ctx0, cur, model.layers[il].ffn_down_scale);
cb(cur, "ffn_down", il);
}
cur = ggml_add(ctx0, cur, ffn_inp);