summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author: Iwan Kawrakow <iwan.kawrakow@gmail.com> 2024-07-31 08:44:19 +0200
committer: Kawrakow <48489457+ikawrakow@users.noreply.github.com> 2024-08-01 09:38:06 +0200
commit: 062313dab41381c6170175ea0c2075b2328b6f33 (patch)
tree: 68a2227f9bafe45a7ea071ed3ad1dcaf0cf95dee
parent: 57df5ccdd7495e67c4d3707cd0a0318f6d04f190 (diff)
iq3_k: Metal dot product
Quite slow: 43 t/s for a 7B model
-rw-r--r-- ggml/src/ggml-metal.metal 63
1 files changed, 24 insertions, 39 deletions
diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal
index 988a820f..03d9153c 100644
--- a/ggml/src/ggml-metal.metal
+++ b/ggml/src/ggml-metal.metal
@@ -3069,6 +3069,8 @@ constexpr constant static float kvalues_iq5k_f[64] = {
constexpr constant static float kvalues_iq2k_f[8] = { -31.f, -13.f, 1.f, 17.f, -26.f, -8.f, 6.f, 22.f };
+constexpr constant static float kvalues_iq3k_f[16] = { -63.f, -40.f, -23.f, -10.f, 1.f, 13.f, 28.f, 47.f, -59.f, -36.f, -19.f, -6.f, 5.f, 17.f, 32.f, 51.f };
+
kernel void kernel_cpy_f32_iq4_nl(
device const float * src0,
device void * dst,
@@ -5314,7 +5316,6 @@ kernel void kernel_mul_mv_iq2_k_f32(
kernel_mul_mv_iq2_k_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
}
-// TODO
void kernel_mul_mv_iq3_k_f32_impl(
device const void * src0,
device const float * src1,
@@ -5346,14 +5347,12 @@ void kernel_mul_mv_iq3_k_f32_impl(
const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
- device const block_iq2_k * x = (device const block_iq2_k *) src0 + ib_row + offset0;
+ device const block_iq3_k * x = (device const block_iq3_k *) src0 + ib_row + offset0;
device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
float yl[32];
float sumf[N_DST]={0.f}, all_sum;
- const int step = (sizeof(block_q2_K) * nb) / 4;
-
const int ix = tiisg/8; // 0...3
const int it = tiisg%8; // 0...7
const int iq = it/4; // 0 or 1
@@ -5362,18 +5361,12 @@ void kernel_mul_mv_iq3_k_f32_impl(
device const float * y4 = y + ix * QK_K + 128 * iq + 8 * ir;
- uint32_t aux32;
- thread const uint8_t * aux8 = (thread const uint8_t *)&aux32;
+ uint32_t vl[2], vh[2];
+ uint32_t aux32[2];
+ thread const uint8_t * aux8 = (thread const uint8_t *)aux32;
for (int ib = ix; ib < nb; ib += 4) {
- //float4 sumy = {0.f, 0.f, 0.f, 0.f};
- //for (int i = 0; i < 8; ++i) {
- // yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0];
- // yl[i+ 8] = y4[i+32]; sumy[1] += yl[i+ 8];
- // yl[i+16] = y4[i+64]; sumy[2] += yl[i+16];
- // yl[i+24] = y4[i+96]; sumy[3] += yl[i+24];
- //}
for (int i = 0; i < 8; ++i) {
yl[i+ 0] = y4[i+ 0];
yl[i+ 8] = y4[i+32];
@@ -5383,28 +5376,34 @@ void kernel_mul_mv_iq3_k_f32_impl(
for (int row = 0; row < N_DST; row++) {
- device const block_iq2_k & xb = x[row*nb + ib];
- device const uint32_t * q32 = (device const uint32_t *)xb.qs + 8*iq + 2*ir;
- device const uint32_t * sc = (device const uint32_t *)xb.scales;
+ device const block_iq3_k & xb = x[row*nb + ib];
+ device const uint16_t * ql16 = (device const uint16_t *)xb.qs + 16*iq + 4*ir;
+ device const uint16_t * qh16 = (device const uint16_t *)xb.qh + 4*ir;
+ device const uint32_t * sc = (device const uint32_t *)xb.scales_l;
const uint32_t scales32 = ((sc[iq] >> 4*is) & 0x0f0f0f0f) << 1;
thread const int8_t * s8 = (thread const int8_t *)&scales32;
uint16_t extra = xb.extra >> (8*iq + is);
+ uint16_t signs = xb.scales_h >> (8*iq + is);
+
+ vl[0] = ql16[0] | ql16[1] << 16;
+ vl[1] = ql16[2] | ql16[3] << 16;
+ vh[0] = ((qh16[0] | (qh16[1] << 16)) << 4*(1-iq)) >> 2;
+ vh[1] = ((qh16[2] | (qh16[3] << 16)) << 4*(1-iq)) >> 2;
float4 acc = {0.f};
for (int l = 0; l < 4; ++l) {
- constant float * values = kvalues_iq2k_f + 4*(extra & 1);
+ constant float * values = kvalues_iq3k_f + 8*(extra & 1);
extra >>= 2;
- for (int i = 0; i < 2; ++i) {
- aux32 = (q32[i] >> 2*l) & 0x03030303;
- acc[l] += values[aux8[0]] * yl[8*l + 4*i + 0] +
- + values[aux8[1]] * yl[8*l + 4*i + 1] +
- + values[aux8[2]] * yl[8*l + 4*i + 2] +
- + values[aux8[3]] * yl[8*l + 4*i + 3];
- }
+ aux32[0] = (vl[0] & 0x03030303) | (vh[0] & 0x04040404);
+ aux32[1] = (vl[1] & 0x03030303) | (vh[1] & 0x04040404);
+ for (int j = 0; j < 8; ++j) acc[l] += yl[8*l+j] * values[aux8[j]];
+ vl[0] >>= 2; vl[1] >>= 2;
+ vh[0] >>= 1; vh[1] >>= 1;
}
- sumf[row] += (float)xb.d * (acc[0] * (s8[0] - 15) + acc[1] * (s8[1] - 15) * acc[2] * (s8[2] - 15) + acc[3] * (s8[3] - 15));
+ sumf[row] += (float)xb.d * (acc[0] * (signs & 0x01 ? -s8[0] : s8[0]) + acc[1] * (signs & 0x04 ? -s8[1] : s8[1]) +
+ acc[2] * (signs & 0x10 ? -s8[2] : s8[2]) + acc[3] * (signs & 0x40 ? -s8[3] : s8[3]));
}
@@ -6371,7 +6370,6 @@ void dequantize_iq2_k(device const block_iq2_k * xb, short il, thread type4x4 &
}
}
-// TODO
template <typename type4x4>
void dequantize_iq3_k(device const block_iq3_k * xb, short il, thread type4x4 & reg) {
// il is 0...15 for QK_K = 256
@@ -6379,19 +6377,6 @@ void dequantize_iq3_k(device const block_iq3_k * xb, short il, thread type4x4 &
device const uint16_t * q16h = (device const uint16_t *)xb->qh + 8*(il&1);
half d = xb->d * (2*((xb->scales_l[il/2] >> 4*(il&1)) & 0xf) + 1) * (xb->scales_h & (1 << il) ? -1 : 1);
- //constant int8_t * int_values = iq3nl_values + 8*((xb->extra >> il) & 1);
- //half values[8] = { d * int_values[0], d * int_values[1], d * int_values[2], d * int_values[3],
- // d * int_values[4], d * int_values[5], d * int_values[6], d * int_values[7] };
- //const int shift = 2*((il%8)/2);
- //uint32_t aux32;
- //thread const uint8_t * aux8 = (thread const uint8_t *)&aux32;
- //for (int i = 0; i < 4; ++i) {
- // uint32_t vl = q16l[2*i+0] | (q16l[2*i+1] << 16);
- // uint32_t vh = q16h[2*i+0] | (q16h[2*i+1] << 16);
- // aux32 = ((vl >> shift) & 0x03030303) | (((vh >> ((il/2)%8)) << 2) & 0x04040404);
- // for (int j = 0; j < 4; ++j) reg[i][j] = values[aux8[j]];
- //}
-
constant int8_t * values = iq3nl_values + 8*((xb->extra >> il) & 1);
const int shift = 2*((il%8)/2);