summaryrefslogtreecommitdiff
path: root/ggml-metal.metal
diff options
context:
space:
mode:
Diffstat (limited to 'ggml-metal.metal')
-rw-r--r--ggml-metal.metal20
1 files changed, 19 insertions, 1 deletions
diff --git a/ggml-metal.metal b/ggml-metal.metal
index e8083734..744b2a8b 100644
--- a/ggml-metal.metal
+++ b/ggml-metal.metal
@@ -4497,7 +4497,9 @@ void kernel_mul_mv_iq1_m_f32_impl(
device const float * y4 = y + 32 * ix;
+#if QK_K != 64
iq1m_scale_t scale;
+#endif
for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
@@ -4519,7 +4521,9 @@ void kernel_mul_mv_iq1_m_f32_impl(
for (int row = 0; row < N_DST; row++) {
+#if QK_K != 64
scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
+#endif
constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700)));
constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((qh[0] << 4) & 0x700)));
@@ -4535,8 +4539,14 @@ void kernel_mul_mv_iq1_m_f32_impl(
}
const float delta1 = sumy[0] * (qh[0] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA) + sumy[1] * (qh[0] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
const float delta2 = sumy[2] * (qh[1] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA) + sumy[3] * (qh[1] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
+#if QK_K == 64
+ const float d = (float) *((device const half *)(sc - 1));
+ sumf[row] += d * ((sum[0] + delta1) * (2*((sc[0] >> (8*(ib%2)+0)) & 0xf) + 1) +
+ (sum[1] + delta2) * (2*((sc[0] >> (8*(ib%2)+4)) & 0xf) + 1));
+#else
sumf[row] += (float)scale.f16 * ((sum[0] + delta1) * (2*((sc[ib/2] >> (6*(ib%2)+0)) & 7) + 1) +
(sum[1] + delta2) * (2*((sc[ib/2] >> (6*(ib%2)+3)) & 7) + 1));
+#endif
sc += nb*sizeof(block_iq1_m)/2;
qs += nb*sizeof(block_iq1_m);
@@ -5277,13 +5287,21 @@ void dequantize_iq1_m(device const block_iq1_m * xb, short il, thread type4x4 &
// il is 0...15 for QK_K = 256 => index of block of 32 is il/2
const int ib32 = il/2;
il = il%2;
- iq1m_scale_t scale;
device const uint16_t * sc = (device const uint16_t *)xb->scales;
+#if QK_K == 64
+ const float d = xb->d;
+#else
+ iq1m_scale_t scale;
scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
const float d = scale.f16;
+#endif
device const uint8_t * qs = xb->qs + 4*ib32 + 2*il;
device const uint8_t * qh = xb->qh + 2*ib32 + il;
+#if QK_K == 64
+ const float dl = d * (2*((sc[ib32/2] >> (8*(ib32%2)+4*il)) & 0xf) + 1);
+#else
const float dl = d * (2*((sc[ib32/2] >> (6*(ib32%2)+3*il)) & 7) + 1);
+#endif
const float ml1 = dl * (qh[0] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
const float ml2 = dl * (qh[0] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700)));