summaryrefslogtreecommitdiff
path: root/ggml-metal.metal
diff options
context:
space:
mode:
Diffstat (limited to 'ggml-metal.metal')
-rw-r--r--ggml-metal.metal41
1 files changed, 21 insertions, 20 deletions
diff --git a/ggml-metal.metal b/ggml-metal.metal
index 34e77728..4ec98e11 100644
--- a/ggml-metal.metal
+++ b/ggml-metal.metal
@@ -5061,7 +5061,7 @@ void kernel_mul_mv_iq1_bn_f32_impl(
const int ix = tiisg/4;
const int ir = tiisg%4;
- device const float * y4 = y + 32 * ix + 8 * ir;
+ device const float4 * y4 = (device const float4 *)y + 8 * ix + 2 * ir;
uint32_t aux32[2];
thread const uint8_t * aux8 = (thread const uint8_t *)aux32;
@@ -5070,8 +5070,7 @@ void kernel_mul_mv_iq1_bn_f32_impl(
for (int ib32 = ix; ib32 < nb32; ib32 += 8) {
- yl[0] = {y4[0], y4[4], y4[2], y4[6]};
- yl[1] = {y4[1], y4[5], y4[3], y4[7]};
+ yl[0] = y4[0]; yl[1] = y4[1];
const int ibl = ib32 / (QK_IQ1BN / 32);
const int ib = ib32 % (QK_IQ1BN / 32);
@@ -5085,9 +5084,9 @@ void kernel_mul_mv_iq1_bn_f32_impl(
uint8_t signs = extra[0] >> (4*ib + ir);
- uint32_t v = iq1bn_grid_u16[ql[0] | ((qh[0] << (8 - 4*(ir%2))) & 0x0f00)];
- uint32_t v32 = v | (v << 12);
- aux32[0] = v32 & 0x03030303; aux32[1] = (v32 >> 2) & 0x03030303;
+ uint32_t v = iq1bn_grid_zzz[ql[0] | ((qh[0] << (8 - 4*(ir%2))) & 0x0f00)];
+ uint32_t v32 = v | (v << 14);
+ aux32[0] = v32 & 0x03030303; aux32[1] = (v32 >> 4) & 0x03030303;
float4 acc4 = yl[0] * float4{values[aux8[0]], values[aux8[1]], values[aux8[2]], values[aux8[3]]}
+ yl[1] * float4{values[aux8[4]], values[aux8[5]], values[aux8[6]], values[aux8[7]]};
float acc = acc4[0] + acc4[1] + acc4[2] + acc4[3];
@@ -5099,7 +5098,7 @@ void kernel_mul_mv_iq1_bn_f32_impl(
qh += nb*sizeof(block_iq1_bn);
}
- y4 += 32 * 8;
+ y4 += 32 * 2;
}
for (int row = 0; row < N_DST; row += 2) {
@@ -5956,26 +5955,28 @@ template <typename type4x4>
void dequantize_iq1_bn(device const block_iq1_bn * xb, short il, thread type4x4 & reg) {
// il is in 0...3
uint8_t gs = xb->extra >> 2*il;
- const half d1 = gs & 1 ? -1.h : 1.h;
- const half d2 = gs & 2 ? -1.h : 1.h;
- uint32_t v1 = iq1bn_grid_u16[xb->ql[2*il+0] | ((xb->qh[il] << 8) & 0x0f00)];
- uint32_t v2 = iq1bn_grid_u16[xb->ql[2*il+1] | ((xb->qh[il] << 4) & 0x0f00)];
+ uint16_t idx1 = xb->ql[2*il+0] | ((xb->qh[il] << 8) & 0x0f00);
+ uint16_t idx2 = xb->ql[2*il+1] | ((xb->qh[il] << 4) & 0x0f00);
+ uint16_t val1 = gs & 1 ? 0xaaaa - iq1bn_grid_zzz[idx1] : iq1bn_grid_zzz[idx1];
+ uint16_t val2 = gs & 2 ? 0xaaaa - iq1bn_grid_zzz[idx2] : iq1bn_grid_zzz[idx2];
- uint32_t v = v1 | (v2 << 16);
+ uint32_t v = val1 | (val1 << 14);
uint32_t aux32;
thread const uint8_t * aux8 = (thread const uint8_t *)&aux32;
const half values[3] = {-1.h, 0.h, 1.h};
-#pragma unroll(4)
- for (int i = 0; i < 4; ++i) {
- aux32 = (v >> 2*i) & 0x03030303;
- reg[0][i] = d1*values[aux8[0]];
- reg[1][i] = d1*values[aux8[1]];
- reg[2][i] = d2*values[aux8[2]];
- reg[3][i] = d2*values[aux8[3]];
- }
+ aux32 = v & 0x03030303;
+ reg[0][0] = values[aux8[0]]; reg[0][1] = values[aux8[1]]; reg[0][2] = values[aux8[2]]; reg[0][3] = values[aux8[3]];
+ aux32 = (v >> 4) & 0x03030303;
+ reg[1][0] = values[aux8[0]]; reg[1][1] = values[aux8[1]]; reg[1][2] = values[aux8[2]]; reg[1][3] = values[aux8[3]];
+
+ v = val2 | (val2 << 14);
+ aux32 = v & 0x03030303;
+ reg[2][0] = values[aux8[0]]; reg[2][1] = values[aux8[1]]; reg[2][2] = values[aux8[2]]; reg[2][3] = values[aux8[3]];
+ aux32 = (v >> 4) & 0x03030303;
+ reg[3][0] = values[aux8[0]]; reg[3][1] = values[aux8[1]]; reg[3][2] = values[aux8[2]]; reg[3][3] = values[aux8[3]];
}
template <typename type4x4>