diff options
author | Iwan Kawrakow <iwan.kawrakow@gmail.com> | 2024-06-18 13:42:42 +0200 |
---|---|---|
committer | Iwan Kawrakow <iwan.kawrakow@gmail.com> | 2024-06-22 12:02:52 +0300 |
commit | fece7e1db7bf73497a32751af06c6dbf48c26b19 (patch) | |
tree | 98a89902822006586a4454e092b679d7e746b212 | |
parent | 4f51348d3d5c0f0bfee42d0a7efc81030f046d13 (diff) |
Bitnet(2.25 bpw): faster Metal dot product
With this we get TG-128 = 97 t/s.
-rw-r--r-- | ggml-metal.metal | 36 |
1 files changed, 19 insertions, 17 deletions
diff --git a/ggml-metal.metal b/ggml-metal.metal index b9be74b2..e5ef552c 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -5127,16 +5127,16 @@ void kernel_mul_mv_iq2_bn_f32_impl( device const block_iq2_bn * x = (device const block_iq2_bn *) src0 + ib_row + offset0; device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; - float yl[8]; + float yl[16]; float sumf[N_DST]={0.f}; float d1bn[N_DST]; const int nb32 = nb * (QK_IQ1BN / 32); - const int ix = tiisg/8; // 0...3 - const int ir = tiisg%8; // 0...7 + const int ix = tiisg/4; // 0...7 + const int ir = tiisg%4; // 0...3 - device const float * y4 = y + 64 * ix + 2 * ir; + device const float * y4 = y + 64 * ix + 4 * ir; typedef union { float f; uint32_t i; } scale_t; scale_t scale; @@ -5145,32 +5145,34 @@ void kernel_mul_mv_iq2_bn_f32_impl( d1bn[row] = x[nb*row].d; } - for (int ib = ix; ib < nb; ib += 4) { + for (int ib = ix; ib < nb; ib += 8) { float sumy = 0.f; - for (int i = 0; i < 2; ++i) { - yl[i+0] = y4[i+ 0]; sumy += yl[i]; - yl[i+2] = y4[i+16]; sumy += yl[i+2]; - yl[i+4] = y4[i+32]; sumy += yl[i+4]; - yl[i+6] = y4[i+48]; sumy += yl[i+6]; + for (int i = 0; i < 4; ++i) { + yl[i+ 0] = y4[i+ 0]; sumy += yl[i+ 0]; + yl[i+ 4] = y4[i+16]; sumy += yl[i+ 4]; + yl[i+ 8] = y4[i+32]; sumy += yl[i+ 8]; + yl[i+12] = y4[i+48]; sumy += yl[i+12]; } - device const uint8_t * qs = x[ib].qs + 2*ir; + device const uint8_t * qs = x[ib].qs + 4*ir; for (int row = 0; row < N_DST; row++) { - float acc = 0; - for (int j = 0; j < 2; ++j) { - acc += yl[j+0] * ((qs[j] >> 0) & 0x03) + yl[j+2] * ((qs[j] >> 2) & 0x03) - + yl[j+4] * ((qs[j] >> 4) & 0x03) + yl[j+6] * ((qs[j] >> 6) & 0x03); + float4 acc = {0.f}; + for (int j = 0; j < 4; ++j) { + acc[0] += yl[j+ 0] * (qs[j] & 0x03); + acc[1] += yl[j+ 4] * (qs[j] & 0x0c); + acc[2] += yl[j+ 8] * (qs[j] & 0x30); + acc[3] += yl[j+12] * (qs[j] & 0xc0); } - sumf[row] += acc - sumy; + sumf[row] += acc[0] + 0.25f*acc[1] + 0.0625*acc[2] + 0.015625f*acc[3] - sumy; qs += nb*sizeof(block_iq2_bn); } - y4 += 64 * 4; + y4 += 64 * 8; } for (int row = 0; row < N_DST; row += 2) { |