summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
-rw-r--r-- ggml/src/ggml-metal.metal 58
1 file changed, 30 insertions, 28 deletions
diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal
index bc0ea9f5..287d8563 100644
--- a/ggml/src/ggml-metal.metal
+++ b/ggml/src/ggml-metal.metal
@@ -5459,16 +5459,15 @@ void kernel_mul_mv_iq1_bn_f32_impl(
device const float * y4 = (device const float *)y + 32 * ix + 16 * ir;
- const float values[3] = {-1.f, 0.f, 1.f};
-
- constexpr uint8_t k_mult[5] = {81, 27, 9, 3, 1};
+ constexpr uint16_t k_mult[5] = {81, 27, 9, 3, 1};
const int ib = ix % (QK_IQ1BN / 32);
const int i16 = 2*ib + ir;
+ float sumy = 0;
for (int ib32 = ix; ib32 < nb32; ib32 += 16) {
- for (int j = 0; j < 16; ++j) yl[j] = y4[j];
+ for (int j = 0; j < 16; ++j) { yl[j] = y4[j]; sumy += y4[j]; }
const int ibl = ib32 / (QK_IQ1BN / 32);
device const block_iq1_bn * xr = x + ibl;
@@ -5478,18 +5477,20 @@ void kernel_mul_mv_iq1_bn_f32_impl(
for (int row = 0; row < N_DST; row++) {
float acc = 0;
- int i = 0;
+ thread const float * yy = yl;
for (int k = 0; k < 3; ++k) {
- uint8_t q = ql[k];
- for (int j = 0; j < 5; ++j) {
- uint8_t v = k_mult[j]*q;
- v = 3*v >> 8; //(v + (v >> 1)) >> 7;
- acc += yl[i++] * values[v];
+ uint16_t q = ql[k];
+ for (int j = 4; j >= 0; --j) {
+ uint16_t v = q & 0xff;
+ v += v << 1;
+ acc += yy[j] * (v & 0xff00);
+ q += q << 1;
}
+ yy += 5;
}
- uint8_t v = k_mult[i16]*extra[0];
- v = 3*v >> 8; //(v + (v >> 1)) >> 7;
- acc += yl[15] * values[v];
+ uint16_t v = (k_mult[i16]*extra[0]) & 0xff;
+ v += v << 1;
+ acc += yl[15] * (v & 0xff00);
sumf[row] += acc;
@@ -5501,7 +5502,7 @@ void kernel_mul_mv_iq1_bn_f32_impl(
}
for (int row = 0; row < N_DST; row += 2) {
- float2 r = {sumf[row], sumf[row+1]};
+ float2 r = {0.00390625f * sumf[row] - sumy, 0.00390625 * sumf[row+1] - sumy};
r = simd_sum(r);
if (tiisg < 2) {
dst[r1*ne0 + im*ne0*ne1 + first_row + row + tiisg] = r[tiisg] * scale[row + tiisg];
@@ -7475,30 +7476,31 @@ template <typename type4x4>
void dequantize_iq1_bn(device const block_iq1_bn * xb, short il, thread type4x4 & reg) {
// il is in 0...3
- constexpr uint8_t k_mult[5] = {81, 27, 9, 3, 1};
+ constexpr uint16_t k_mult[5] = {81, 27, 9, 3, 1};
+ constexpr half k_values[3] = {-1.h, 0.h, 1.h};
- int i = 0;
for (int k = 0; k < 3; ++k) {
- uint8_t q = xb->ql[3*il + k];
- for (int j = 0; j < 5; ++j) {
- uint8_t v = k_mult[j]*q;
- int8_t vs = 3*v >> 8;
- //int8_t vs = (v + (v >> 1)) >> 7;
- reg[i/4][i%4] = vs - 1;
- ++i;
+ uint16_t q = xb->ql[3*il + k];
+ int i = 5*k + 4;
+ for (int j = 4; j >= 0; --j) {
+ uint16_t v = q & 0xff;
+ v += v << 1;
+ reg[i/4][i%4] = k_values[v >> 8];
+ q += q << 1;
+ --i;
}
}
- uint8_t v = k_mult[il]*xb->extra;
- int8_t vs = 3*v >> 8; //(v + (v >> 1)) >> 7;
- reg[3][3] = vs - 1;
+ uint16_t v = (k_mult[il]*xb->extra) & 0xff;
+ v += v << 1;
+ reg[3][3] = k_values[v >> 8];
}
template <typename type4x4>
void dequantize_iq2_bn(device const block_iq2_bn * xb, short il, thread type4x4 & reg) {
// il is in 0...3
- constexpr float k_scale[4] = {1.f, 0.25f, 0.0625f, 0.015625f};
+ constexpr half k_scale[4] = {1.h, 0.25h, 0.0625h, 0.015625h};
constexpr uint8_t k_mask[4] = {0x03, 0x0c, 0x30, 0xc0};
- const float d = k_scale[il];
+ const half d = k_scale[il];
uint8_t mask = k_mask[il];
for (int j = 0; j < 16; ++j) {