diff options
author | Kawrakow <iwankawrakow@gmail.com> | 2024-10-16 14:13:03 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-10-16 14:13:03 +0300 |
commit | 993ca95e9e3108f0352fa2a3384cab0775c7f7c1 (patch) | |
tree | 5fd1e52f04382acf4e3ed1226e4fe8084c06dd1e | |
parent | ff23008ed4f73c2c7091e7333495e36c268156bc (diff) |
iq4_ks: faster dot product on Metal (#90)
TG-128(LLaMA-3.1-8B) goes to 52.5 t/s up from 48.4 t/s.
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
-rw-r--r-- | ggml/src/ggml-metal.metal | 24 |
1 files changed, 14 insertions, 10 deletions
diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal index 72595c91..dff9326f 100644 --- a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -6079,6 +6079,7 @@ void kernel_mul_mv_iq4_ks_f32_impl( float4 yl[4]; float2 sumf = 0.f; + float d[2]; device const float * yb = y + ix * QK_K + ib * 32 + il * 8; @@ -6087,22 +6088,25 @@ void kernel_mul_mv_iq4_ks_f32_impl( float4 qf1, qf2; + device const float * dptr = (device const float *)cx; + d[0] = *dptr; + device const block_iq4_ks * x = (device const block_iq4_ks *)(dptr + 1) + ix; + dptr += row_size/4; + d[1] = *dptr; + for (int ibl = ix; ibl < nb; ibl += 2) { device const float4 * y4 = (device const float4 *)yb; yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5]; - device const float * dptr = (device const float *)cx; + device const uint8_t * scales = x->scales; for (int row = 0; row < 2; ++row) { - //device const float * dptr = (device const float *)(cx + row*row_size); - const float d = *dptr; - device const block_iq4_ks * x = (device const block_iq4_ks *)(dptr + 1); - device const block_iq4_ks & xb = x[ibl]; - device const uint32_t * q4 = (device const uint32_t *)(xb.qs + 16*ib + 8*il); + threadgroup const float * block_values = shared_values + ((scales[ib] & 1) << 4); + const float ls = ((scales[ib] & 254) - 127); - threadgroup const float * block_values = shared_values + ((xb.scales[ib] & 1) << 4); + device const uint32_t * q4 = (device const uint32_t *)scales + QK_K/128 + 4*ib + 2*il; float4 acc1 = {0.f}, acc2 = {0.f}; @@ -6122,14 +6126,14 @@ void kernel_mul_mv_iq4_ks_f32_impl( acc1 += acc2; - const int ls = (xb.scales[ib] & 254) - 127; - sumf[row] += d * ls * (acc1[0] + acc1[1] + acc1[2] + acc1[3]); + sumf[row] += d[row] * ls * (acc1[0] + acc1[1] + acc1[2] + acc1[3]); - dptr += row_size/4; + scales += row_size; } yb += 2 * QK_K; + x += 2; } sumf = simd_sum(sumf); |