diff options
author | Iwan Kawrakow <iwan.kawrakow@gmail.com> | 2024-07-31 08:44:19 +0200 |
---|---|---|
committer | Kawrakow <48489457+ikawrakow@users.noreply.github.com> | 2024-08-01 09:38:06 +0200 |
commit | 062313dab41381c6170175ea0c2075b2328b6f33 (patch) | |
tree | 68a2227f9bafe45a7ea071ed3ad1dcaf0cf95dee | |
parent | 57df5ccdd7495e67c4d3707cd0a0318f6d04f190 (diff) |
iq3_k: Metal dot product
Quite slow: 43 t/s for a 7B model
-rw-r--r-- | ggml/src/ggml-metal.metal | 63 |
1 files changed, 24 insertions, 39 deletions
diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal index 988a820f..03d9153c 100644 --- a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -3069,6 +3069,8 @@ constexpr constant static float kvalues_iq5k_f[64] = { constexpr constant static float kvalues_iq2k_f[8] = { -31.f, -13.f, 1.f, 17.f, -26.f, -8.f, 6.f, 22.f }; +constexpr constant static float kvalues_iq3k_f[16] = { -63.f, -40.f, -23.f, -10.f, 1.f, 13.f, 28.f, 47.f, -59.f, -36.f, -19.f, -6.f, 5.f, 17.f, 32.f, 51.f }; + kernel void kernel_cpy_f32_iq4_nl( device const float * src0, device void * dst, @@ -5314,7 +5316,6 @@ kernel void kernel_mul_mv_iq2_k_f32( kernel_mul_mv_iq2_k_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); } -// TODO void kernel_mul_mv_iq3_k_f32_impl( device const void * src0, device const float * src1, @@ -5346,14 +5347,12 @@ void kernel_mul_mv_iq3_k_f32_impl( const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); - device const block_iq2_k * x = (device const block_iq2_k *) src0 + ib_row + offset0; + device const block_iq3_k * x = (device const block_iq3_k *) src0 + ib_row + offset0; device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; float yl[32]; float sumf[N_DST]={0.f}, all_sum; - const int step = (sizeof(block_q2_K) * nb) / 4; - const int ix = tiisg/8; // 0...3 const int it = tiisg%8; // 0...7 const int iq = it/4; // 0 or 1 @@ -5362,18 +5361,12 @@ void kernel_mul_mv_iq3_k_f32_impl( device const float * y4 = y + ix * QK_K + 128 * iq + 8 * ir; - uint32_t aux32; - thread const uint8_t * aux8 = (thread const uint8_t *)&aux32; + uint32_t vl[2], vh[2]; + uint32_t aux32[2]; + thread const uint8_t * aux8 = (thread const uint8_t *)aux32; for (int ib = ix; ib < nb; ib += 4) { - //float4 sumy = {0.f, 0.f, 0.f, 0.f}; - //for (int i = 0; i < 8; ++i) { - // yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0]; - // yl[i+ 8] = y4[i+32]; sumy[1] += yl[i+ 8]; - // yl[i+16] = y4[i+64]; sumy[2] += yl[i+16]; - // yl[i+24] = y4[i+96]; sumy[3] += yl[i+24]; - //} for (int i = 0; i < 8; ++i) { yl[i+ 0] = y4[i+ 0]; yl[i+ 8] = y4[i+32]; @@ -5383,28 +5376,34 @@ void kernel_mul_mv_iq3_k_f32_impl( for (int row = 0; row < N_DST; row++) { - device const block_iq2_k & xb = x[row*nb + ib]; - device const uint32_t * q32 = (device const uint32_t *)xb.qs + 8*iq + 2*ir; - device const uint32_t * sc = (device const uint32_t *)xb.scales; + device const block_iq3_k & xb = x[row*nb + ib]; + device const uint16_t * ql16 = (device const uint16_t *)xb.qs + 16*iq + 4*ir; + device const uint16_t * qh16 = (device const uint16_t *)xb.qh + 4*ir; + device const uint32_t * sc = (device const uint32_t *)xb.scales_l; const uint32_t scales32 = ((sc[iq] >> 4*is) & 0x0f0f0f0f) << 1; thread const int8_t * s8 = (thread const int8_t *)&scales32; uint16_t extra = xb.extra >> (8*iq + is); + uint16_t signs = xb.scales_h >> (8*iq + is); + + vl[0] = ql16[0] | ql16[1] << 16; + vl[1] = ql16[2] | ql16[3] << 16; + vh[0] = ((qh16[0] | (qh16[1] << 16)) << 4*(1-iq)) >> 2; + vh[1] = ((qh16[2] | (qh16[3] << 16)) << 4*(1-iq)) >> 2; float4 acc = {0.f}; for (int l = 0; l < 4; ++l) { - constant float * values = kvalues_iq2k_f + 4*(extra & 1); + constant float * values = kvalues_iq3k_f + 8*(extra & 1); extra >>= 2; - for (int i = 0; i < 2; ++i) { - aux32 = (q32[i] >> 2*l) & 0x03030303; - acc[l] += values[aux8[0]] * yl[8*l + 4*i + 0] + - + values[aux8[1]] * yl[8*l + 4*i + 1] + - + values[aux8[2]] * yl[8*l + 4*i + 2] + - + values[aux8[3]] * yl[8*l + 4*i + 3]; - } + aux32[0] = (vl[0] & 0x03030303) | (vh[0] & 0x04040404); + aux32[1] = (vl[1] & 0x03030303) | (vh[1] & 0x04040404); + for (int j = 0; j < 8; ++j) acc[l] += yl[8*l+j] * values[aux8[j]]; + vl[0] >>= 2; vl[1] >>= 2; + vh[0] >>= 1; vh[1] >>= 1; } - sumf[row] += (float)xb.d * (acc[0] * (s8[0] - 15) + acc[1] * (s8[1] - 15) * acc[2] * (s8[2] - 15) + acc[3] * (s8[3] - 15)); + sumf[row] += (float)xb.d * (acc[0] * (signs & 0x01 ? -s8[0] : s8[0]) + acc[1] * (signs & 0x04 ? -s8[1] : s8[1]) + + acc[2] * (signs & 0x10 ? -s8[2] : s8[2]) + acc[3] * (signs & 0x40 ? -s8[3] : s8[3])); } @@ -6371,7 +6370,6 @@ void dequantize_iq2_k(device const block_iq2_k * xb, short il, thread type4x4 & } } -// TODO template <typename type4x4> void dequantize_iq3_k(device const block_iq3_k * xb, short il, thread type4x4 & reg) { // il is 0...15 for QK_K = 256 @@ -6379,19 +6377,6 @@ void dequantize_iq3_k(device const block_iq3_k * xb, short il, thread type4x4 & device const uint16_t * q16h = (device const uint16_t *)xb->qh + 8*(il&1); half d = xb->d * (2*((xb->scales_l[il/2] >> 4*(il&1)) & 0xf) + 1) * (xb->scales_h & (1 << il) ? -1 : 1); - //constant int8_t * int_values = iq3nl_values + 8*((xb->extra >> il) & 1); - //half values[8] = { d * int_values[0], d * int_values[1], d * int_values[2], d * int_values[3], - // d * int_values[4], d * int_values[5], d * int_values[6], d * int_values[7] }; - //const int shift = 2*((il%8)/2); - //uint32_t aux32; - //thread const uint8_t * aux8 = (thread const uint8_t *)&aux32; - //for (int i = 0; i < 4; ++i) { - // uint32_t vl = q16l[2*i+0] | (q16l[2*i+1] << 16); - // uint32_t vh = q16h[2*i+0] | (q16h[2*i+1] << 16); - // aux32 = ((vl >> shift) & 0x03030303) | (((vh >> ((il/2)%8)) << 2) & 0x04040404); - // for (int j = 0; j < 4; ++j) reg[i][j] = values[aux8[j]]; - //} - constant int8_t * values = iq3nl_values + 8*((xb->extra >> il) & 1); const int shift = 2*((il%8)/2); |