diff options
-rw-r--r-- | ggml-metal.metal | 65 |
1 files changed, 30 insertions, 35 deletions
diff --git a/ggml-metal.metal b/ggml-metal.metal index 44055d2d..a709b86d 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -5026,15 +5026,16 @@ void kernel_mul_mv_iq1_bn_f32_impl( device const block_iq1_bn * x = (device const block_iq1_bn *) src0 + ib_row + offset0; device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; - float yl[32]; + float yl[16]; float sumf[N_DST]={0.f}, all_sum; float d1bn[N_DST]; const int nb32 = nb * (QK_IQ1BN / 32); - const int ix = tiisg; + const int ix = tiisg/2; + const int ir = tiisg%2; - device const float * y4 = y + 32 * ix; + device const float * y4 = y + 32 * ix + 16 * ir; typedef union { float f; uint32_t i; } scale_t; scale_t scale; @@ -5048,14 +5049,12 @@ void kernel_mul_mv_iq1_bn_f32_impl( uint32_t aux32; thread const uint8_t * aux8 = (thread const uint8_t *)&aux32; - for (int ib32 = ix; ib32 < nb32; ib32 += 32) { + for (int ib32 = ix; ib32 < nb32; ib32 += 16) { - float4 sumy = {0.f}; + float2 sumy = {0.f}; for (int i = 0; i < 8; ++i) { - yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0]; - yl[i+ 8] = y4[i+ 8]; sumy[1] += yl[i+ 8]; - yl[i+16] = y4[i+16]; sumy[2] += yl[i+16]; - yl[i+24] = y4[i+24]; sumy[3] += yl[i+24]; + yl[i+0] = y4[i+0]; sumy[0] += yl[i+ 0]; + yl[i+8] = y4[i+8]; sumy[1] += yl[i+ 8]; } const int ibl = ib32 / (QK_IQ1BN / 32); @@ -5063,36 +5062,32 @@ void kernel_mul_mv_iq1_bn_f32_impl( device const block_iq1_bn * xr = x + ibl; device const uint16_t * extra = (device const uint16_t *)&xr->extra; - device const uint8_t * ql = xr->ql + 4 * ib; - device const uint8_t * qh = xr->qh + 2 * ib; + device const uint8_t * ql = xr->ql + 4 * ib + 2*ir; + device const uint8_t * qh = xr->qh + 2 * ib + ir;; for (int row = 0; row < N_DST; row++) { - uint8_t signs = extra[0] >> (8 + 4*ib); - float4 acc = {0.f}; - for (int j = 0; j < 2; ++j) { - uint32_t v1 = iq1bn_grid_u16[ql[2*j+0] | ((qh[j] << 8) & 0x0f00)]; - uint32_t v2 = iq1bn_grid_u16[ql[2*j+1] | ((qh[j] << 4) & 0x0f00)]; - uint32_t v = v1 | (v2 << 16); - aux32 = v & 0x03030303; - acc[2*j+0] += yl[16*j + 0] * aux8[0] + yl[16*j + 4] * aux8[1]; - acc[2*j+1] += yl[16*j + 8] * aux8[2] + yl[16*j +12] * aux8[3]; - aux32 = (v >> 2) & 0x03030303; - acc[2*j+0] += yl[16*j + 1] * aux8[0] + yl[16*j + 5] * aux8[1]; - acc[2*j+1] += yl[16*j + 9] * aux8[2] + yl[16*j +13] * aux8[3]; - aux32 = (v >> 4) & 0x03030303; - acc[2*j+0] += yl[16*j + 2] * aux8[0] + yl[16*j + 6] * aux8[1]; - acc[2*j+1] += yl[16*j +10] * aux8[2] + yl[16*j +14] * aux8[3]; - aux32 = (v >> 6) & 0x03030303; - acc[2*j+0] += yl[16*j + 3] * aux8[0] + yl[16*j + 7] * aux8[1]; - acc[2*j+1] += yl[16*j +12] * aux8[2] + yl[16*j +15] * aux8[3]; - } + uint8_t signs = extra[0] >> (8 + 4*ib + 2*ir); + float2 acc = {0.f}; + + uint32_t v1 = iq1bn_grid_u16[ql[0] | ((qh[0] << 8) & 0x0f00)]; + uint32_t v2 = iq1bn_grid_u16[ql[1] | ((qh[0] << 4) & 0x0f00)]; + uint32_t v = v1 | (v2 << 16); + aux32 = v & 0x03030303; + acc[0] += yl[ 0] * aux8[0] + yl[ 4] * aux8[1]; + acc[1] += yl[ 8] * aux8[2] + yl[12] * aux8[3]; + aux32 = (v >> 2) & 0x03030303; + acc[0] += yl[ 1] * aux8[0] + yl[ 5] * aux8[1]; + acc[1] += yl[ 9] * aux8[2] + yl[13] * aux8[3]; + aux32 = (v >> 4) & 0x03030303; + acc[0] += yl[ 2] * aux8[0] + yl[ 6] * aux8[1]; + acc[1] += yl[10] * aux8[2] + yl[14] * aux8[3]; + aux32 = (v >> 6) & 0x03030303; + acc[0] += yl[ 3] * aux8[0] + yl[ 7] * aux8[1]; + acc[1] += yl[12] * aux8[2] + yl[15] * aux8[3]; acc -= sumy; - float sum = (signs & 1 ? -acc[0] : acc[0]) - + (signs & 2 ? -acc[1] : acc[1]) - + (signs & 4 ? -acc[2] : acc[2]) - + (signs & 8 ? -acc[3] : acc[3]); + float sum = (signs & 1 ? -acc[0] : acc[0]) + (signs & 2 ? -acc[1] : acc[1]); sumf[row] += sum; @@ -5101,7 +5096,7 @@ void kernel_mul_mv_iq1_bn_f32_impl( qh += nb*sizeof(block_iq1_bn); } - y4 += 32 * 32; + y4 += 32 * 16; } for (int row = 0; row < N_DST; ++row) { |