diff options
author | Iwan Kawrakow <iwan.kawrakow@gmail.com> | 2024-06-25 11:16:13 +0200 |
---|---|---|
committer | Iwan Kawrakow <iwan.kawrakow@gmail.com> | 2024-06-25 11:16:13 +0200 |
commit | 7de9559cf258a05f7730006780c4962be8ccccc4 (patch) | |
tree | 95603d2529fb68c716d49ad0e0085f199e7d9b6a | |
parent | aa14a06b44ff12be7e4461a6e169a657275a5b20 (diff) |
Bitnet: adapt NEON and Metal to the alternative grid
-rw-r--r-- | ggml-metal.metal | 41 | ||||
-rw-r--r-- | iqk_mul_mat.cpp | 20 |
2 files changed, 31 insertions, 30 deletions
diff --git a/ggml-metal.metal b/ggml-metal.metal index 34e77728..4ec98e11 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -5061,7 +5061,7 @@ void kernel_mul_mv_iq1_bn_f32_impl( const int ix = tiisg/4; const int ir = tiisg%4; - device const float * y4 = y + 32 * ix + 8 * ir; + device const float4 * y4 = (device const float4 *)y + 8 * ix + 2 * ir; uint32_t aux32[2]; thread const uint8_t * aux8 = (thread const uint8_t *)aux32; @@ -5070,8 +5070,7 @@ void kernel_mul_mv_iq1_bn_f32_impl( for (int ib32 = ix; ib32 < nb32; ib32 += 8) { - yl[0] = {y4[0], y4[4], y4[2], y4[6]}; - yl[1] = {y4[1], y4[5], y4[3], y4[7]}; + yl[0] = y4[0]; yl[1] = y4[1]; const int ibl = ib32 / (QK_IQ1BN / 32); const int ib = ib32 % (QK_IQ1BN / 32); @@ -5085,9 +5084,9 @@ void kernel_mul_mv_iq1_bn_f32_impl( uint8_t signs = extra[0] >> (4*ib + ir); - uint32_t v = iq1bn_grid_u16[ql[0] | ((qh[0] << (8 - 4*(ir%2))) & 0x0f00)]; - uint32_t v32 = v | (v << 12); - aux32[0] = v32 & 0x03030303; aux32[1] = (v32 >> 2) & 0x03030303; + uint32_t v = iq1bn_grid_zzz[ql[0] | ((qh[0] << (8 - 4*(ir%2))) & 0x0f00)]; + uint32_t v32 = v | (v << 14); + aux32[0] = v32 & 0x03030303; aux32[1] = (v32 >> 4) & 0x03030303; float4 acc4 = yl[0] * float4{values[aux8[0]], values[aux8[1]], values[aux8[2]], values[aux8[3]]} + yl[1] * float4{values[aux8[4]], values[aux8[5]], values[aux8[6]], values[aux8[7]]}; float acc = acc4[0] + acc4[1] + acc4[2] + acc4[3]; @@ -5099,7 +5098,7 @@ void kernel_mul_mv_iq1_bn_f32_impl( qh += nb*sizeof(block_iq1_bn); } - y4 += 32 * 8; + y4 += 32 * 2; } for (int row = 0; row < N_DST; row += 2) { @@ -5956,26 +5955,28 @@ template <typename type4x4> void dequantize_iq1_bn(device const block_iq1_bn * xb, short il, thread type4x4 & reg) { // il is in 0...3 uint8_t gs = xb->extra >> 2*il; - const half d1 = gs & 1 ? -1.h : 1.h; - const half d2 = gs & 2 ? -1.h : 1.h; - uint32_t v1 = iq1bn_grid_u16[xb->ql[2*il+0] | ((xb->qh[il] << 8) & 0x0f00)]; - uint32_t v2 = iq1bn_grid_u16[xb->ql[2*il+1] | ((xb->qh[il] << 4) & 0x0f00)]; + uint16_t idx1 = xb->ql[2*il+0] | ((xb->qh[il] << 8) & 0x0f00); + uint16_t idx2 = xb->ql[2*il+1] | ((xb->qh[il] << 4) & 0x0f00); + uint16_t val1 = gs & 1 ? 0xaaaa - iq1bn_grid_zzz[idx1] : iq1bn_grid_zzz[idx1]; + uint16_t val2 = gs & 2 ? 0xaaaa - iq1bn_grid_zzz[idx2] : iq1bn_grid_zzz[idx2]; - uint32_t v = v1 | (v2 << 16); + uint32_t v = val1 | (val1 << 14); uint32_t aux32; thread const uint8_t * aux8 = (thread const uint8_t *)&aux32; const half values[3] = {-1.h, 0.h, 1.h}; -#pragma unroll(4) - for (int i = 0; i < 4; ++i) { - aux32 = (v >> 2*i) & 0x03030303; - reg[0][i] = d1*values[aux8[0]]; - reg[1][i] = d1*values[aux8[1]]; - reg[2][i] = d2*values[aux8[2]]; - reg[3][i] = d2*values[aux8[3]]; - } + aux32 = v & 0x03030303; + reg[0][0] = values[aux8[0]]; reg[0][1] = values[aux8[1]]; reg[0][2] = values[aux8[2]]; reg[0][3] = values[aux8[3]]; + aux32 = (v >> 4) & 0x03030303; + reg[1][0] = values[aux8[0]]; reg[1][1] = values[aux8[1]]; reg[1][2] = values[aux8[2]]; reg[1][3] = values[aux8[3]]; + + v = val2 | (val2 << 14); + aux32 = v & 0x03030303; + reg[2][0] = values[aux8[0]]; reg[2][1] = values[aux8[1]]; reg[2][2] = values[aux8[2]]; reg[2][3] = values[aux8[3]]; + aux32 = (v >> 4) & 0x03030303; + reg[3][0] = values[aux8[0]]; reg[3][1] = values[aux8[1]]; reg[3][2] = values[aux8[2]]; reg[3][3] = values[aux8[3]]; } template <typename type4x4> diff --git a/iqk_mul_mat.cpp b/iqk_mul_mat.cpp index 907b0d19..df4dfc5f 100644 --- a/iqk_mul_mat.cpp +++ b/iqk_mul_mat.cpp @@ -4285,9 +4285,9 @@ struct DequantizerIQ1BN { vreinterpretq_u8_u64(uint64x2_t{0x0404040404040404, 0x0505050505050505}), vreinterpretq_u8_u64(uint64x2_t{0x0606060606060606, 0x0707070707070707}), }; - const int8x16_t shift = vreinterpretq_s8_u32(vdupq_n_u32(0xfafcfe00)); + const int8x16_t shift = vreinterpretq_s16_u64(vdupq_n_u64(0xfffafffcfffe0000)); const uint8x16_t qmask = vdupq_n_u8(3); - const uint8x16_t shuff1 = vreinterpretq_u8_u64(uint64x2_t{0x0101010100000000, 0x0909090908080808}); + const uint8x16_t shuff1 = vreinterpretq_u8_u64(uint64x2_t{0x0100010001000100, 0x0908090809080908}); const uint8x16_t mask1 = vreinterpretq_u8_u64(vdupq_n_u64(0x8040201008040201)); int8x16x4_t signs; uint64x2x4_t a; @@ -4299,15 +4299,15 @@ struct DequantizerIQ1BN { signs.val[2] = vqtbl1q_u8(all_signs, sign_shuffles.val[2]); signs.val[3] = vqtbl1q_u8(all_signs, sign_shuffles.val[3]); - a.val[0] = uint64x2_t{iq1bn_grid_u16[ql[0] | ((qh[0] << 8) & 0x0f00)], iq1bn_grid_u16[ql[1] | ((qh[0] << 4) & 0x0f00)]}; - a.val[1] = uint64x2_t{iq1bn_grid_u16[ql[2] | ((qh[1] << 8) & 0x0f00)], iq1bn_grid_u16[ql[3] | ((qh[1] << 4) & 0x0f00)]}; - a.val[2] = uint64x2_t{iq1bn_grid_u16[ql[4] | ((qh[2] << 8) & 0x0f00)], iq1bn_grid_u16[ql[5] | ((qh[2] << 4) & 0x0f00)]}; - a.val[3] = uint64x2_t{iq1bn_grid_u16[ql[6] | ((qh[3] << 8) & 0x0f00)], iq1bn_grid_u16[ql[7] | ((qh[3] << 4) & 0x0f00)]}; + a.val[0] = uint64x2_t{iq1bn_grid_zzz[ql[0] | ((qh[0] << 8) & 0x0f00)], iq1bn_grid_zzz[ql[1] | ((qh[0] << 4) & 0x0f00)]}; + a.val[1] = uint64x2_t{iq1bn_grid_zzz[ql[2] | ((qh[1] << 8) & 0x0f00)], iq1bn_grid_zzz[ql[3] | ((qh[1] << 4) & 0x0f00)]}; + a.val[2] = uint64x2_t{iq1bn_grid_zzz[ql[4] | ((qh[2] << 8) & 0x0f00)], iq1bn_grid_zzz[ql[5] | ((qh[2] << 4) & 0x0f00)]}; + a.val[3] = uint64x2_t{iq1bn_grid_zzz[ql[6] | ((qh[3] << 8) & 0x0f00)], iq1bn_grid_zzz[ql[7] | ((qh[3] << 4) & 0x0f00)]}; - v.val[0] = vsubq_s8(vandq_u8(vshlq_u8(vqtbl1q_u8(vreinterpretq_u8_u64(a.val[0]), shuff1), shift), qmask), m1); - v.val[1] = vsubq_s8(vandq_u8(vshlq_u8(vqtbl1q_u8(vreinterpretq_u8_u64(a.val[1]), shuff1), shift), qmask), m1); - v.val[2] = vsubq_s8(vandq_u8(vshlq_u8(vqtbl1q_u8(vreinterpretq_u8_u64(a.val[2]), shuff1), shift), qmask), m1); - v.val[3] = vsubq_s8(vandq_u8(vshlq_u8(vqtbl1q_u8(vreinterpretq_u8_u64(a.val[3]), shuff1), shift), qmask), m1); + v.val[0] = vsubq_s8(vandq_u8(vshlq_u16(vqtbl1q_u8(vreinterpretq_u8_u64(a.val[0]), shuff1), shift), qmask), m1); + v.val[1] = vsubq_s8(vandq_u8(vshlq_u16(vqtbl1q_u8(vreinterpretq_u8_u64(a.val[1]), shuff1), shift), qmask), m1); + v.val[2] = vsubq_s8(vandq_u8(vshlq_u16(vqtbl1q_u8(vreinterpretq_u8_u64(a.val[2]), shuff1), shift), qmask), m1); + v.val[3] = vsubq_s8(vandq_u8(vshlq_u16(vqtbl1q_u8(vreinterpretq_u8_u64(a.val[3]), shuff1), shift), qmask), m1); v.val[0] = vmulq_s8(v.val[0], signs.val[0]); v.val[1] = vmulq_s8(v.val[1], signs.val[1]); |