From f7b05a09ddb2b2579f6301a6223d894f5b97c494 Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Sat, 26 Oct 2024 10:59:59 +0200 Subject: Faster IQ1_BN Metal implementation (#107) * iq1_bn: faster Metal dot product 82 t/s -> 87.9 t/s * iq1_bn(Metal): 87.9 -> 89.0 t/s for TG-128 * iq1_bn(Metal): 89.0 -> 94.7 t/s for TG-128 So, total improvement is ~15%. Not bad. * iq1_bn(Metal): 686 -> 702 t/s for PP-512 * iq2_bn(Metal): 710 -> 714 t/s for PP-512 --------- Co-authored-by: Iwan Kawrakow --- ggml/src/ggml-metal.metal | 58 ++++++++++++++++++++++++----------------------- 1 file changed, 30 insertions(+), 28 deletions(-) (limited to 'ggml/src') diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal index bc0ea9f5..287d8563 100644 --- a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -5459,16 +5459,15 @@ void kernel_mul_mv_iq1_bn_f32_impl( device const float * y4 = (device const float *)y + 32 * ix + 16 * ir; - const float values[3] = {-1.f, 0.f, 1.f}; - - constexpr uint8_t k_mult[5] = {81, 27, 9, 3, 1}; + constexpr uint16_t k_mult[5] = {81, 27, 9, 3, 1}; const int ib = ix % (QK_IQ1BN / 32); const int i16 = 2*ib + ir; + float sumy = 0; for (int ib32 = ix; ib32 < nb32; ib32 += 16) { - for (int j = 0; j < 16; ++j) yl[j] = y4[j]; + for (int j = 0; j < 16; ++j) { yl[j] = y4[j]; sumy += y4[j]; } const int ibl = ib32 / (QK_IQ1BN / 32); device const block_iq1_bn * xr = x + ibl; @@ -5478,18 +5477,20 @@ void kernel_mul_mv_iq1_bn_f32_impl( for (int row = 0; row < N_DST; row++) { float acc = 0; - int i = 0; + thread const float * yy = yl; for (int k = 0; k < 3; ++k) { - uint8_t q = ql[k]; - for (int j = 0; j < 5; ++j) { - uint8_t v = k_mult[j]*q; - v = 3*v >> 8; //(v + (v >> 1)) >> 7; - acc += yl[i++] * values[v]; + uint16_t q = ql[k]; + for (int j = 4; j >= 0; --j) { + uint16_t v = q & 0xff; + v += v << 1; + acc += yy[j] * (v & 0xff00); + q += q << 1; } + yy += 5; } - uint8_t v = k_mult[i16]*extra[0]; - v = 3*v >> 8; //(v + (v >> 1)) >> 7; - acc += yl[15] * values[v]; + uint16_t v = (k_mult[i16]*extra[0]) & 0xff; + v += v << 1; + acc += yl[15] * (v & 0xff00); sumf[row] += acc; @@ -5501,7 +5502,7 @@ void kernel_mul_mv_iq1_bn_f32_impl( } for (int row = 0; row < N_DST; row += 2) { - float2 r = {sumf[row], sumf[row+1]}; + float2 r = {0.00390625f * sumf[row] - sumy, 0.00390625 * sumf[row+1] - sumy}; r = simd_sum(r); if (tiisg < 2) { dst[r1*ne0 + im*ne0*ne1 + first_row + row + tiisg] = r[tiisg] * scale[row + tiisg]; @@ -7475,30 +7476,31 @@ template void dequantize_iq1_bn(device const block_iq1_bn * xb, short il, thread type4x4 & reg) { // il is in 0...3 - constexpr uint8_t k_mult[5] = {81, 27, 9, 3, 1}; + constexpr uint16_t k_mult[5] = {81, 27, 9, 3, 1}; + constexpr half k_values[3] = {-1.h, 0.h, 1.h}; - int i = 0; for (int k = 0; k < 3; ++k) { - uint8_t q = xb->ql[3*il + k]; - for (int j = 0; j < 5; ++j) { - uint8_t v = k_mult[j]*q; - int8_t vs = 3*v >> 8; - //int8_t vs = (v + (v >> 1)) >> 7; - reg[i/4][i%4] = vs - 1; - ++i; + uint16_t q = xb->ql[3*il + k]; + int i = 5*k + 4; + for (int j = 4; j >= 0; --j) { + uint16_t v = q & 0xff; + v += v << 1; + reg[i/4][i%4] = k_values[v >> 8]; + q += q << 1; + --i; } } - uint8_t v = k_mult[il]*xb->extra; - int8_t vs = 3*v >> 8; //(v + (v >> 1)) >> 7; - reg[3][3] = vs - 1; + uint16_t v = (k_mult[il]*xb->extra) & 0xff; + v += v << 1; + reg[3][3] = k_values[v >> 8]; } template void dequantize_iq2_bn(device const block_iq2_bn * xb, short il, thread type4x4 & reg) { // il is in 0...3 - constexpr float k_scale[4] = {1.f, 0.25f, 0.0625f, 0.015625f}; + constexpr half k_scale[4] = {1.h, 0.25h, 0.0625h, 0.015625h}; constexpr uint8_t k_mask[4] = {0x03, 0x0c, 0x30, 0xc0}; - const float d = k_scale[il]; + const half d = k_scale[il]; uint8_t mask = k_mask[il]; for (int j = 0; j < 16; ++j) { -- cgit v1.2.3