summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--ggml/src/ggml-metal.metal18
1 files changed, 14 insertions, 4 deletions
diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal
index 88fe607b..53b2ddb8 100644
--- a/ggml/src/ggml-metal.metal
+++ b/ggml/src/ggml-metal.metal
@@ -5246,6 +5246,7 @@ void kernel_mul_mv_iq2_k_f32_impl(
uint32_t aux32[2];
thread const uint8_t * aux8 = (thread const uint8_t *)aux32;
+ uint16_t shift[4];
for (int ib = ix; ib < nb; ib += 4) {
@@ -5266,10 +5267,14 @@ void kernel_mul_mv_iq2_k_f32_impl(
thread const int8_t * s8 = (thread const int8_t *)&scales32;
uint16_t extra = xb.extra >> (8*iq + is);
+ shift[0] = (extra << 2) & 4;
+ shift[1] = (extra << 1) & 4;
+ shift[2] = (extra >> 0) & 4;
+ shift[3] = (extra >> 1) & 4;
+
float4 acc = {0.f};
for (int l = 0; l < 4; ++l) {
- constant float * values = kvalues_iq2k_f + 4*(extra & 1);
- extra >>= 2;
+ constant float * values = kvalues_iq2k_f + shift[l];
aux32[0] = (q32[0] >> 2*l) & 0x03030303;
aux32[1] = (q32[1] >> 2*l) & 0x03030303;
for (int j = 0; j < 8; ++j) acc[l] += yl[8*l+j] * values[aux8[j]];
@@ -5365,6 +5370,7 @@ void kernel_mul_mv_iq3_k_f32_impl(
uint32_t vl[2], vh[2];
uint32_t aux32[2];
thread const uint8_t * aux8 = (thread const uint8_t *)aux32;
+ uint16_t shift[4];
for (int ib = ix; ib < nb; ib += 4) {
@@ -5387,6 +5393,11 @@ void kernel_mul_mv_iq3_k_f32_impl(
uint16_t extra = xb.extra >> (8*iq + is);
uint16_t signs = xb.scales_h >> (8*iq + is);
+ shift[0] = (extra << 3) & 8;
+ shift[1] = (extra << 2) & 8;
+ shift[2] = (extra << 1) & 8;
+ shift[3] = (extra << 0) & 8;
+
vl[0] = ql16[0] | ql16[1] << 16;
vl[1] = ql16[2] | ql16[3] << 16;
vh[0] = ((qh16[0] | (qh16[1] << 16)) << 4*(1-iq)) >> 2;
@@ -5394,8 +5405,7 @@ void kernel_mul_mv_iq3_k_f32_impl(
float4 acc = {0.f};
for (int l = 0; l < 4; ++l) {
- constant float * values = kvalues_iq3k_f + 8*(extra & 1);
- extra >>= 2;
+ constant float * values = kvalues_iq3k_f + shift[l];
aux32[0] = (vl[0] & 0x03030303) | (vh[0] & 0x04040404);
aux32[1] = (vl[1] & 0x03030303) | (vh[1] & 0x04040404);
for (int j = 0; j < 8; ++j) acc[l] += yl[8*l+j] * values[aux8[j]];