diff options
author | Iwan Kawrakow <iwan.kawrakow@gmail.com> | 2024-06-17 10:48:26 +0200 |
---|---|---|
committer | Iwan Kawrakow <iwan.kawrakow@gmail.com> | 2024-06-22 12:02:51 +0300 |
commit | d9fb92b7104b929f0427323f7964ef7a4da33d2b (patch) | |
tree | 5fabacd57ad9edf624ad4fa447023e3be923b2cb | |
parent | 0c5a353ebdcc58e8b8051f2c38a92a8c23fa8092 (diff) |
iq1_bn: Metal now works
PP performance is decent (668 t/s v 724 t/s for q4_0),
but TG is kind of low (60 t/s vs 81 t/s for q4_0).
-rw-r--r-- | ggml-metal.m | 29 | ||||
-rw-r--r-- | ggml-metal.metal | 194 |
2 files changed, 221 insertions, 2 deletions
diff --git a/ggml-metal.m b/ggml-metal.m index 79902c9a..3a383d1e 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -76,6 +76,7 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S, GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S, GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_M, + GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_BN, GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL, GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS, GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, @@ -104,6 +105,7 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32, + GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_BN_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, @@ -128,6 +130,7 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32, + GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_BN_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, @@ -149,6 +152,7 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_BN_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, @@ -170,6 +174,7 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_BN_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32, @@ -532,6 +537,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) { GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S, get_rows_iq2_s, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S, get_rows_iq1_s, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_M, get_rows_iq1_m, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_BN, get_rows_iq1_bn, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL, get_rows_iq4_nl, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS, get_rows_iq4_xs, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true); @@ -560,6 +566,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) { GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32, mul_mv_iq2_s_f32, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32, mul_mv_iq1_s_f32, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32, mul_mv_iq1_m_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_BN_F32, mul_mv_iq1_bn_f32, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32, mul_mv_iq4_nl_f32, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32, mul_mv_iq4_xs_f32, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, mul_mv_id_f32_f32, ctx->support_simdgroup_reduction); @@ -584,6 +591,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) { GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32, mul_mv_id_iq2_s_f32, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32, mul_mv_id_iq1_s_f32, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32, mul_mv_id_iq1_m_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_BN_F32, mul_mv_id_iq1_bn_f32, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32, mul_mv_id_iq4_nl_f32, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32, mul_mv_id_iq4_xs_f32, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, ctx->support_simdgroup_mm); @@ -605,6 +613,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) { GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32, mul_mm_iq2_s_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32, mul_mm_iq1_s_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32, mul_mm_iq1_m_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_BN_F32, mul_mm_iq1_bn_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, mul_mm_iq4_nl_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, mul_mm_iq4_xs_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, ctx->support_simdgroup_mm); @@ -626,6 +635,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) { GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32, mul_mm_id_iq2_s_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32, mul_mm_id_iq1_s_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32, mul_mm_id_iq1_m_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_BN_F32, mul_mm_id_iq1_bn_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, mul_mm_id_iq4_nl_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, mul_mm_id_iq4_xs_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32, rope_norm_f32, true); @@ -1607,6 +1617,7 @@ static enum ggml_status ggml_metal_graph_compute( case GGML_TYPE_IQ2_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32 ].pipeline; break; case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32 ].pipeline; break; case GGML_TYPE_IQ1_M: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32 ].pipeline; break; + case GGML_TYPE_IQ1_BN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_BN_F32 ].pipeline; break; case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32 ].pipeline; break; case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32 ].pipeline; break; default: GGML_ASSERT(false && "MUL MAT-MAT not implemented"); @@ -1767,6 +1778,12 @@ static enum ggml_status ggml_metal_graph_compute( nth1 = 16; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32].pipeline; } break; + case GGML_TYPE_IQ1_BN: + { + nth0 = 4; + nth1 = 16; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_BN_F32].pipeline; + } break; case GGML_TYPE_IQ4_NL: { nth0 = 4; @@ -1813,7 +1830,7 @@ static enum ggml_status ggml_metal_graph_compute( if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 || src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K || - src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S) { + src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S|| src0t == GGML_TYPE_IQ1_BN) { [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) { @@ -1913,6 +1930,7 @@ static enum ggml_status ggml_metal_graph_compute( case GGML_TYPE_IQ2_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32 ].pipeline; break; case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32 ].pipeline; break; case GGML_TYPE_IQ1_M: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32 ].pipeline; break; + case GGML_TYPE_IQ1_BN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_BN_F32 ].pipeline; break; case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32 ].pipeline; break; case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32 ].pipeline; break; default: GGML_ASSERT(false && "MUL_MAT_ID not implemented"); @@ -2067,6 +2085,12 @@ static enum ggml_status ggml_metal_graph_compute( nth1 = 16; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32].pipeline; } break; + case GGML_TYPE_IQ1_BN: + { + nth0 = 4; + nth1 = 16; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_BN_F32].pipeline; + } break; case GGML_TYPE_IQ4_NL: { nth0 = 4; @@ -2120,7 +2144,7 @@ static enum ggml_status ggml_metal_graph_compute( if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 || src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K || - src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S) { + src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S|| src0t == GGML_TYPE_IQ1_BN) { [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) { @@ -2179,6 +2203,7 @@ static enum ggml_status ggml_metal_graph_compute( case GGML_TYPE_IQ2_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S ].pipeline; break; case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S ].pipeline; break; case GGML_TYPE_IQ1_M: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_M ].pipeline; break; + case GGML_TYPE_IQ1_BN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_BN ].pipeline; break; case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL ].pipeline; break; case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS ].pipeline; break; case GGML_TYPE_I32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_I32 ].pipeline; break; diff --git a/ggml-metal.metal b/ggml-metal.metal index e2796fd6..52a3133d 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -4992,6 +4992,126 @@ void kernel_mul_mv_iq1_m_f32_impl( } } +// Not working. Don't see the bug. +void kernel_mul_mv_iq1_bn_f32_impl( + device const void * src0, + device const float * src1, + device float * dst, + int64_t ne00, + int64_t ne01, + int64_t ne02, + int64_t ne10, + int64_t ne12, + int64_t ne0, + int64_t ne1, + uint r2, + uint r3, + threadgroup int8_t * shared_value, + uint3 tgpig, + uint tiisg, + uint sgitg) { + + const int nb = ne00/QK_IQ1BN; + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; + + const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; + const int ib_row = first_row * nb; + + const uint i12 = im%ne12; + const uint i13 = im/ne12; + + const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + device const block_iq1_bn * x = (device const block_iq1_bn *) src0 + ib_row + offset0; + device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + + float yl[32]; + float sumf[N_DST]={0.f}, all_sum; + float d1bn[N_DST]; + + const int nb32 = nb * (QK_IQ1BN / 32); + + const int ix = tiisg; + + device const float * y4 = y + 32 * ix; + + typedef union { float f; uint32_t i; } scale_t; + scale_t scale; + + for (int row = 0; row < N_DST; ++row) { + uint8_t u = x[nb*row].extra & 0xff; + scale.i = ((((u >> 4) | 0xf0) - 132) << 23) | ((u & 0x0f) << 19); + d1bn[row] = scale.f; + } + + uint32_t aux32; + thread const uint8_t * aux8 = (thread const uint8_t *)&aux32; + + for (int ib32 = ix; ib32 < nb32; ib32 += 32) { + + float4 sumy = {0.f}; + for (int i = 0; i < 8; ++i) { + yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0]; + yl[i+ 8] = y4[i+ 8]; sumy[1] += yl[i+ 8]; + yl[i+16] = y4[i+16]; sumy[2] += yl[i+16]; + yl[i+24] = y4[i+24]; sumy[3] += yl[i+24]; + } + + const int ibl = ib32 / (QK_IQ1BN / 32); + const int ib = ib32 % (QK_IQ1BN / 32); + + device const block_iq1_bn * xr = x + ibl; + device const uint16_t * extra = (device const uint16_t *)&xr->extra; + device const uint8_t * ql = xr->ql + 4 * ib; + device const uint8_t * qh = xr->qh + 2 * ib; + + for (int row = 0; row < N_DST; row++) { + + //uint8_t u = extra[0] & 0xff; + //scale.i = ((((u >> 4) | 0xf0) - 132) << 23) | ((u & 0x0f) << 19); + uint8_t signs = extra[0] >> (8 + 4*ib); + float4 acc = {0.f}; + for (int j = 0; j < 2; ++j) { + uint32_t v1 = iq1bn_grid_u16[ql[2*j+0] | ((qh[j] << 8) & 0x0f00)]; + uint32_t v2 = iq1bn_grid_u16[ql[2*j+1] | ((qh[j] << 4) & 0x0f00)]; + uint32_t v = v1 | (v2 << 16); + aux32 = v & 0x03030303; + acc[2*j+0] += yl[16*j + 0] * aux8[0] + yl[16*j + 4] * aux8[1]; + acc[2*j+1] += yl[16*j + 8] * aux8[2] + yl[16*j +12] * aux8[3]; + aux32 = (v >> 2) & 0x03030303; + acc[2*j+0] += yl[16*j + 1] * aux8[0] + yl[16*j + 5] * aux8[1]; + acc[2*j+1] += yl[16*j + 9] * aux8[2] + yl[16*j +13] * aux8[3]; + aux32 = (v >> 4) & 0x03030303; + acc[2*j+0] += yl[16*j + 2] * aux8[0] + yl[16*j + 6] * aux8[1]; + acc[2*j+1] += yl[16*j +10] * aux8[2] + yl[16*j +14] * aux8[3]; + aux32 = (v >> 6) & 0x03030303; + acc[2*j+0] += yl[16*j + 3] * aux8[0] + yl[16*j + 7] * aux8[1]; + acc[2*j+1] += yl[16*j +12] * aux8[2] + yl[16*j +15] * aux8[3]; + } + + float sum = (signs & 1 ? sumy[0] - acc[0] : acc[0] - sumy[0]) + + (signs & 2 ? sumy[1] - acc[1] : acc[1] - sumy[1]) + + (signs & 4 ? sumy[2] - acc[2] : acc[2] - sumy[2]) + + (signs & 8 ? sumy[3] - acc[3] : acc[3] - sumy[3]); + sumf[row] += sum; + + extra += nb*sizeof(block_iq1_bn)/2; + ql += nb*sizeof(block_iq1_bn); + qh += nb*sizeof(block_iq1_bn); + } + + y4 += 32 * 32; + } + + for (int row = 0; row < N_DST; ++row) { + all_sum = simd_sum(sumf[row]); + if (tiisg == 0) { + dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum * d1bn[row]; + } + } +} + void kernel_mul_mv_iq4_nl_f32_impl( device const void * src0, device const float * src1, @@ -5237,6 +5357,34 @@ kernel void kernel_mul_mv_iq1_m_f32( kernel_mul_mv_iq1_m_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); } +[[host_name("kernel_mul_mv_iq1_bn_f32")]] +kernel void kernel_mul_mv_iq1_bn_f32( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + + kernel_mul_mv_iq1_bn_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); +} + [[host_name("kernel_mul_mv_iq4_nl_f32")]] kernel void kernel_mul_mv_iq4_nl_f32( device const void * src0, @@ -5695,6 +5843,48 @@ void dequantize_iq1_m(device const block_iq1_m * xb, short il, thread type4x4 & } template <typename type4x4> +void dequantize_iq1_bn(device const block_iq1_bn * xb, short il, thread type4x4 & reg) { + // il is in 0...3 + typedef union { float f; uint32_t i; } scale_t; + scale_t scale; + uint8_t u = xb->extra & 0xff; + scale.i = ((((u >> 4) | 0xf0) - 132) << 23) | ((u & 0x0f) << 19); + //uint32_t u = xb->extra & 0xff; + //scale.i = (u << 19) + 905969664; + uint8_t gs = xb->extra >> (8 + 2*il); + const float d1 = gs & 1 ? -scale.f : scale.f; + const float d2 = gs & 2 ? -scale.f : scale.f; + + uint32_t v1 = iq1bn_grid_u16[xb->ql[2*il+0] | ((xb->qh[il] << 8) & 0x0f00)]; + uint32_t v2 = iq1bn_grid_u16[xb->ql[2*il+1] | ((xb->qh[il] << 4) & 0x0f00)]; + + uint32_t v = v1 | (v2 << 16); + uint32_t aux32; + thread const uint8_t * aux8 = (thread const uint8_t *)&aux32; + + for (int i = 0; i < 4; ++i) { + aux32 = (v >> 2*i) & 0x03030303; + reg[0][i] = d1*aux8[0] - d1; + reg[1][i] = d1*aux8[1] - d1; + reg[2][i] = d2*aux8[2] - d2; + reg[3][i] = d2*aux8[3] - d2; + } + + //Basically same performance as above. I guess, the compiler makes the transformation automatically + //uint16_t v1 = iq1bn_grid_u16[xb->ql[2*il+0] | ((xb->qh[il] << 8) & 0x0f00)]; + //uint16_t v2 = iq1bn_grid_u16[xb->ql[2*il+1] | ((xb->qh[il] << 4) & 0x0f00)]; + //for (int i = 0; i < 4; ++i) { + // reg[0][i] = d1*((v1 >> 2*i) & 3) - d1; + // reg[2][i] = d2*((v2 >> 2*i) & 3) - d2; + //} + //v1 >>= 8; v2 >>= 8; + //for (int i = 0; i < 4; ++i) { + // reg[1][i] = d1*((v1 >> 2*i) & 3) - d1; + // reg[3][i] = d2*((v2 >> 2*i) & 3) - d2; + //} +} + +template <typename type4x4> void dequantize_iq4_nl(device const block_iq4_nl * xb, short il, thread type4x4 & reg) { device const uint16_t * q4 = (device const uint16_t *)xb->qs; const float d = xb->d; @@ -6270,6 +6460,7 @@ template [[host_name("kernel_get_rows_iq3_s")]] kernel get_rows_t kernel_get_r template [[host_name("kernel_get_rows_iq2_s")]] kernel get_rows_t kernel_get_rows<block_iq2_s, QK_NL, dequantize_iq2_s>; template [[host_name("kernel_get_rows_iq1_s")]] kernel get_rows_t kernel_get_rows<block_iq1_s, QK_NL, dequantize_iq1_s>; template [[host_name("kernel_get_rows_iq1_m")]] kernel get_rows_t kernel_get_rows<block_iq1_m, QK_NL, dequantize_iq1_m>; +template [[host_name("kernel_get_rows_iq1_bn")]] kernel get_rows_t kernel_get_rows<block_iq1_bn, 4, dequantize_iq1_bn>; template [[host_name("kernel_get_rows_iq4_nl")]] kernel get_rows_t kernel_get_rows<block_iq4_nl, 2, dequantize_iq4_nl>; template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_t kernel_get_rows<block_iq4_xs, QK_NL, dequantize_iq4_xs>; @@ -6298,6 +6489,7 @@ template [[host_name("kernel_mul_mm_iq3_s_f32")]] kernel mat_mm_t kernel_mul_m template [[host_name("kernel_mul_mm_iq2_s_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_s, QK_NL, dequantize_iq2_s>; template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq1_s, QK_NL, dequantize_iq1_s>; template [[host_name("kernel_mul_mm_iq1_m_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq1_m, QK_NL, dequantize_iq1_m>; +template [[host_name("kernel_mul_mm_iq1_bn_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq1_bn, 4, dequantize_iq1_bn>; template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq4_nl, 2, dequantize_iq4_nl>; template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq4_xs, QK_NL, dequantize_iq4_xs>; @@ -6326,6 +6518,7 @@ template [[host_name("kernel_mul_mm_id_iq3_s_f32")]] kernel mat_mm_id_t kernel template [[host_name("kernel_mul_mm_id_iq2_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_s, QK_NL, dequantize_iq2_s>; template [[host_name("kernel_mul_mm_id_iq1_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq1_s, QK_NL, dequantize_iq1_s>; template [[host_name("kernel_mul_mm_id_iq1_m_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq1_m, QK_NL, dequantize_iq1_m>; +template [[host_name("kernel_mul_mm_id_iq1_bn_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq1_bn, 4, dequantize_iq1_bn>; template [[host_name("kernel_mul_mm_id_iq4_nl_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_nl, 2, dequantize_iq4_nl>; template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_xs, QK_NL, dequantize_iq4_xs>; @@ -6530,6 +6723,7 @@ template [[host_name("kernel_mul_mv_id_q5_K_f32")]] kernel kernel_mul_mv_id_t ke template [[host_name("kernel_mul_mv_id_q6_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q6_K_f32_impl>>; template [[host_name("kernel_mul_mv_id_iq1_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq1_s_f32_impl>>; template [[host_name("kernel_mul_mv_id_iq1_m_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq1_m_f32_impl>>; +template [[host_name("kernel_mul_mv_id_iq1_bn_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq1_bn_f32_impl>>; template [[host_name("kernel_mul_mv_id_iq2_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_xxs_f32_impl>>; template [[host_name("kernel_mul_mv_id_iq2_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_xs_f32_impl>>; template [[host_name("kernel_mul_mv_id_iq3_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq3_xxs_f32_impl>>; |