summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorIwan Kawrakow <iwan.kawrakow@gmail.com>2024-06-17 10:48:26 +0200
committerIwan Kawrakow <iwan.kawrakow@gmail.com>2024-06-22 12:02:51 +0300
commitd9fb92b7104b929f0427323f7964ef7a4da33d2b (patch)
tree5fabacd57ad9edf624ad4fa447023e3be923b2cb
parent0c5a353ebdcc58e8b8051f2c38a92a8c23fa8092 (diff)
iq1_bn: Metal now works
PP performance is decent (668 t/s v 724 t/s for q4_0), but TG is kind of low (60 t/s vs 81 t/s for q4_0).
-rw-r--r--ggml-metal.m29
-rw-r--r--ggml-metal.metal194
2 files changed, 221 insertions, 2 deletions
diff --git a/ggml-metal.m b/ggml-metal.m
index 79902c9a..3a383d1e 100644
--- a/ggml-metal.m
+++ b/ggml-metal.m
@@ -76,6 +76,7 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S,
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S,
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_M,
+ GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_BN,
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL,
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS,
GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,
@@ -104,6 +105,7 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_BN_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32,
@@ -128,6 +130,7 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_BN_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32,
@@ -149,6 +152,7 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_BN_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32,
@@ -170,6 +174,7 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_BN_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32,
GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32,
@@ -532,6 +537,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S, get_rows_iq2_s, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S, get_rows_iq1_s, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_M, get_rows_iq1_m, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_BN, get_rows_iq1_bn, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL, get_rows_iq4_nl, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS, get_rows_iq4_xs, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true);
@@ -560,6 +566,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32, mul_mv_iq2_s_f32, ctx->support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32, mul_mv_iq1_s_f32, ctx->support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32, mul_mv_iq1_m_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_BN_F32, mul_mv_iq1_bn_f32, ctx->support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32, mul_mv_iq4_nl_f32, ctx->support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32, mul_mv_iq4_xs_f32, ctx->support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, mul_mv_id_f32_f32, ctx->support_simdgroup_reduction);
@@ -584,6 +591,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32, mul_mv_id_iq2_s_f32, ctx->support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32, mul_mv_id_iq1_s_f32, ctx->support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32, mul_mv_id_iq1_m_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_BN_F32, mul_mv_id_iq1_bn_f32, ctx->support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32, mul_mv_id_iq4_nl_f32, ctx->support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32, mul_mv_id_iq4_xs_f32, ctx->support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, ctx->support_simdgroup_mm);
@@ -605,6 +613,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32, mul_mm_iq2_s_f32, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32, mul_mm_iq1_s_f32, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32, mul_mm_iq1_m_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_BN_F32, mul_mm_iq1_bn_f32, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, mul_mm_iq4_nl_f32, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, mul_mm_iq4_xs_f32, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, ctx->support_simdgroup_mm);
@@ -626,6 +635,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32, mul_mm_id_iq2_s_f32, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32, mul_mm_id_iq1_s_f32, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32, mul_mm_id_iq1_m_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_BN_F32, mul_mm_id_iq1_bn_f32, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, mul_mm_id_iq4_nl_f32, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, mul_mm_id_iq4_xs_f32, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32, rope_norm_f32, true);
@@ -1607,6 +1617,7 @@ static enum ggml_status ggml_metal_graph_compute(
case GGML_TYPE_IQ2_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32 ].pipeline; break;
case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32 ].pipeline; break;
case GGML_TYPE_IQ1_M: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32 ].pipeline; break;
+ case GGML_TYPE_IQ1_BN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_BN_F32 ].pipeline; break;
case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32 ].pipeline; break;
case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32 ].pipeline; break;
default: GGML_ASSERT(false && "MUL MAT-MAT not implemented");
@@ -1767,6 +1778,12 @@ static enum ggml_status ggml_metal_graph_compute(
nth1 = 16;
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32].pipeline;
} break;
+ case GGML_TYPE_IQ1_BN:
+ {
+ nth0 = 4;
+ nth1 = 16;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_BN_F32].pipeline;
+ } break;
case GGML_TYPE_IQ4_NL:
{
nth0 = 4;
@@ -1813,7 +1830,7 @@ static enum ggml_status ggml_metal_graph_compute(
if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 ||
src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K ||
- src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S) {
+ src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S|| src0t == GGML_TYPE_IQ1_BN) {
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
}
else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) {
@@ -1913,6 +1930,7 @@ static enum ggml_status ggml_metal_graph_compute(
case GGML_TYPE_IQ2_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32 ].pipeline; break;
case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32 ].pipeline; break;
case GGML_TYPE_IQ1_M: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32 ].pipeline; break;
+ case GGML_TYPE_IQ1_BN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_BN_F32 ].pipeline; break;
case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32 ].pipeline; break;
case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32 ].pipeline; break;
default: GGML_ASSERT(false && "MUL_MAT_ID not implemented");
@@ -2067,6 +2085,12 @@ static enum ggml_status ggml_metal_graph_compute(
nth1 = 16;
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32].pipeline;
} break;
+ case GGML_TYPE_IQ1_BN:
+ {
+ nth0 = 4;
+ nth1 = 16;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_BN_F32].pipeline;
+ } break;
case GGML_TYPE_IQ4_NL:
{
nth0 = 4;
@@ -2120,7 +2144,7 @@ static enum ggml_status ggml_metal_graph_compute(
if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 ||
src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K ||
- src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S) {
+ src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S|| src0t == GGML_TYPE_IQ1_BN) {
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
}
else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) {
@@ -2179,6 +2203,7 @@ static enum ggml_status ggml_metal_graph_compute(
case GGML_TYPE_IQ2_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S ].pipeline; break;
case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S ].pipeline; break;
case GGML_TYPE_IQ1_M: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_M ].pipeline; break;
+ case GGML_TYPE_IQ1_BN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_BN ].pipeline; break;
case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL ].pipeline; break;
case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS ].pipeline; break;
case GGML_TYPE_I32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_I32 ].pipeline; break;
diff --git a/ggml-metal.metal b/ggml-metal.metal
index e2796fd6..52a3133d 100644
--- a/ggml-metal.metal
+++ b/ggml-metal.metal
@@ -4992,6 +4992,126 @@ void kernel_mul_mv_iq1_m_f32_impl(
}
}
+// Not working. Don't see the bug.
+void kernel_mul_mv_iq1_bn_f32_impl(
+ device const void * src0,
+ device const float * src1,
+ device float * dst,
+ int64_t ne00,
+ int64_t ne01,
+ int64_t ne02,
+ int64_t ne10,
+ int64_t ne12,
+ int64_t ne0,
+ int64_t ne1,
+ uint r2,
+ uint r3,
+ threadgroup int8_t * shared_value,
+ uint3 tgpig,
+ uint tiisg,
+ uint sgitg) {
+
+ const int nb = ne00/QK_IQ1BN;
+ const int r0 = tgpig.x;
+ const int r1 = tgpig.y;
+ const int im = tgpig.z;
+
+ const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
+ const int ib_row = first_row * nb;
+
+ const uint i12 = im%ne12;
+ const uint i13 = im/ne12;
+
+ const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+ device const block_iq1_bn * x = (device const block_iq1_bn *) src0 + ib_row + offset0;
+ device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
+
+ float yl[32];
+ float sumf[N_DST]={0.f}, all_sum;
+ float d1bn[N_DST];
+
+ const int nb32 = nb * (QK_IQ1BN / 32);
+
+ const int ix = tiisg;
+
+ device const float * y4 = y + 32 * ix;
+
+ typedef union { float f; uint32_t i; } scale_t;
+ scale_t scale;
+
+ for (int row = 0; row < N_DST; ++row) {
+ uint8_t u = x[nb*row].extra & 0xff;
+ scale.i = ((((u >> 4) | 0xf0) - 132) << 23) | ((u & 0x0f) << 19);
+ d1bn[row] = scale.f;
+ }
+
+ uint32_t aux32;
+ thread const uint8_t * aux8 = (thread const uint8_t *)&aux32;
+
+ for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
+
+ float4 sumy = {0.f};
+ for (int i = 0; i < 8; ++i) {
+ yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0];
+ yl[i+ 8] = y4[i+ 8]; sumy[1] += yl[i+ 8];
+ yl[i+16] = y4[i+16]; sumy[2] += yl[i+16];
+ yl[i+24] = y4[i+24]; sumy[3] += yl[i+24];
+ }
+
+ const int ibl = ib32 / (QK_IQ1BN / 32);
+ const int ib = ib32 % (QK_IQ1BN / 32);
+
+ device const block_iq1_bn * xr = x + ibl;
+ device const uint16_t * extra = (device const uint16_t *)&xr->extra;
+ device const uint8_t * ql = xr->ql + 4 * ib;
+ device const uint8_t * qh = xr->qh + 2 * ib;
+
+ for (int row = 0; row < N_DST; row++) {
+
+ //uint8_t u = extra[0] & 0xff;
+ //scale.i = ((((u >> 4) | 0xf0) - 132) << 23) | ((u & 0x0f) << 19);
+ uint8_t signs = extra[0] >> (8 + 4*ib);
+ float4 acc = {0.f};
+ for (int j = 0; j < 2; ++j) {
+ uint32_t v1 = iq1bn_grid_u16[ql[2*j+0] | ((qh[j] << 8) & 0x0f00)];
+ uint32_t v2 = iq1bn_grid_u16[ql[2*j+1] | ((qh[j] << 4) & 0x0f00)];
+ uint32_t v = v1 | (v2 << 16);
+ aux32 = v & 0x03030303;
+ acc[2*j+0] += yl[16*j + 0] * aux8[0] + yl[16*j + 4] * aux8[1];
+ acc[2*j+1] += yl[16*j + 8] * aux8[2] + yl[16*j +12] * aux8[3];
+ aux32 = (v >> 2) & 0x03030303;
+ acc[2*j+0] += yl[16*j + 1] * aux8[0] + yl[16*j + 5] * aux8[1];
+ acc[2*j+1] += yl[16*j + 9] * aux8[2] + yl[16*j +13] * aux8[3];
+ aux32 = (v >> 4) & 0x03030303;
+ acc[2*j+0] += yl[16*j + 2] * aux8[0] + yl[16*j + 6] * aux8[1];
+ acc[2*j+1] += yl[16*j +10] * aux8[2] + yl[16*j +14] * aux8[3];
+ aux32 = (v >> 6) & 0x03030303;
+ acc[2*j+0] += yl[16*j + 3] * aux8[0] + yl[16*j + 7] * aux8[1];
+ acc[2*j+1] += yl[16*j +12] * aux8[2] + yl[16*j +15] * aux8[3];
+ }
+
+ float sum = (signs & 1 ? sumy[0] - acc[0] : acc[0] - sumy[0])
+ + (signs & 2 ? sumy[1] - acc[1] : acc[1] - sumy[1])
+ + (signs & 4 ? sumy[2] - acc[2] : acc[2] - sumy[2])
+ + (signs & 8 ? sumy[3] - acc[3] : acc[3] - sumy[3]);
+ sumf[row] += sum;
+
+ extra += nb*sizeof(block_iq1_bn)/2;
+ ql += nb*sizeof(block_iq1_bn);
+ qh += nb*sizeof(block_iq1_bn);
+ }
+
+ y4 += 32 * 32;
+ }
+
+ for (int row = 0; row < N_DST; ++row) {
+ all_sum = simd_sum(sumf[row]);
+ if (tiisg == 0) {
+ dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum * d1bn[row];
+ }
+ }
+}
+
void kernel_mul_mv_iq4_nl_f32_impl(
device const void * src0,
device const float * src1,
@@ -5237,6 +5357,34 @@ kernel void kernel_mul_mv_iq1_m_f32(
kernel_mul_mv_iq1_m_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
}
+[[host_name("kernel_mul_mv_iq1_bn_f32")]]
+kernel void kernel_mul_mv_iq1_bn_f32(
+ device const void * src0,
+ device const float * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & r2,
+ constant uint & r3,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+
+ kernel_mul_mv_iq1_bn_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
+}
+
[[host_name("kernel_mul_mv_iq4_nl_f32")]]
kernel void kernel_mul_mv_iq4_nl_f32(
device const void * src0,
@@ -5695,6 +5843,48 @@ void dequantize_iq1_m(device const block_iq1_m * xb, short il, thread type4x4 &
}
template <typename type4x4>
+void dequantize_iq1_bn(device const block_iq1_bn * xb, short il, thread type4x4 & reg) {
+ // il is in 0...3
+ typedef union { float f; uint32_t i; } scale_t;
+ scale_t scale;
+ uint8_t u = xb->extra & 0xff;
+ scale.i = ((((u >> 4) | 0xf0) - 132) << 23) | ((u & 0x0f) << 19);
+ //uint32_t u = xb->extra & 0xff;
+ //scale.i = (u << 19) + 905969664;
+ uint8_t gs = xb->extra >> (8 + 2*il);
+ const float d1 = gs & 1 ? -scale.f : scale.f;
+ const float d2 = gs & 2 ? -scale.f : scale.f;
+
+ uint32_t v1 = iq1bn_grid_u16[xb->ql[2*il+0] | ((xb->qh[il] << 8) & 0x0f00)];
+ uint32_t v2 = iq1bn_grid_u16[xb->ql[2*il+1] | ((xb->qh[il] << 4) & 0x0f00)];
+
+ uint32_t v = v1 | (v2 << 16);
+ uint32_t aux32;
+ thread const uint8_t * aux8 = (thread const uint8_t *)&aux32;
+
+ for (int i = 0; i < 4; ++i) {
+ aux32 = (v >> 2*i) & 0x03030303;
+ reg[0][i] = d1*aux8[0] - d1;
+ reg[1][i] = d1*aux8[1] - d1;
+ reg[2][i] = d2*aux8[2] - d2;
+ reg[3][i] = d2*aux8[3] - d2;
+ }
+
+ //Basically same performance as above. I guess, the compiler makes the transformation automatically
+ //uint16_t v1 = iq1bn_grid_u16[xb->ql[2*il+0] | ((xb->qh[il] << 8) & 0x0f00)];
+ //uint16_t v2 = iq1bn_grid_u16[xb->ql[2*il+1] | ((xb->qh[il] << 4) & 0x0f00)];
+ //for (int i = 0; i < 4; ++i) {
+ // reg[0][i] = d1*((v1 >> 2*i) & 3) - d1;
+ // reg[2][i] = d2*((v2 >> 2*i) & 3) - d2;
+ //}
+ //v1 >>= 8; v2 >>= 8;
+ //for (int i = 0; i < 4; ++i) {
+ // reg[1][i] = d1*((v1 >> 2*i) & 3) - d1;
+ // reg[3][i] = d2*((v2 >> 2*i) & 3) - d2;
+ //}
+}
+
+template <typename type4x4>
void dequantize_iq4_nl(device const block_iq4_nl * xb, short il, thread type4x4 & reg) {
device const uint16_t * q4 = (device const uint16_t *)xb->qs;
const float d = xb->d;
@@ -6270,6 +6460,7 @@ template [[host_name("kernel_get_rows_iq3_s")]] kernel get_rows_t kernel_get_r
template [[host_name("kernel_get_rows_iq2_s")]] kernel get_rows_t kernel_get_rows<block_iq2_s, QK_NL, dequantize_iq2_s>;
template [[host_name("kernel_get_rows_iq1_s")]] kernel get_rows_t kernel_get_rows<block_iq1_s, QK_NL, dequantize_iq1_s>;
template [[host_name("kernel_get_rows_iq1_m")]] kernel get_rows_t kernel_get_rows<block_iq1_m, QK_NL, dequantize_iq1_m>;
+template [[host_name("kernel_get_rows_iq1_bn")]] kernel get_rows_t kernel_get_rows<block_iq1_bn, 4, dequantize_iq1_bn>;
template [[host_name("kernel_get_rows_iq4_nl")]] kernel get_rows_t kernel_get_rows<block_iq4_nl, 2, dequantize_iq4_nl>;
template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_t kernel_get_rows<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
@@ -6298,6 +6489,7 @@ template [[host_name("kernel_mul_mm_iq3_s_f32")]] kernel mat_mm_t kernel_mul_m
template [[host_name("kernel_mul_mm_iq2_s_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_s, QK_NL, dequantize_iq2_s>;
template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq1_s, QK_NL, dequantize_iq1_s>;
template [[host_name("kernel_mul_mm_iq1_m_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq1_m, QK_NL, dequantize_iq1_m>;
+template [[host_name("kernel_mul_mm_iq1_bn_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq1_bn, 4, dequantize_iq1_bn>;
template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq4_nl, 2, dequantize_iq4_nl>;
template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
@@ -6326,6 +6518,7 @@ template [[host_name("kernel_mul_mm_id_iq3_s_f32")]] kernel mat_mm_id_t kernel
template [[host_name("kernel_mul_mm_id_iq2_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_s, QK_NL, dequantize_iq2_s>;
template [[host_name("kernel_mul_mm_id_iq1_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq1_s, QK_NL, dequantize_iq1_s>;
template [[host_name("kernel_mul_mm_id_iq1_m_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq1_m, QK_NL, dequantize_iq1_m>;
+template [[host_name("kernel_mul_mm_id_iq1_bn_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq1_bn, 4, dequantize_iq1_bn>;
template [[host_name("kernel_mul_mm_id_iq4_nl_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_nl, 2, dequantize_iq4_nl>;
template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
@@ -6530,6 +6723,7 @@ template [[host_name("kernel_mul_mv_id_q5_K_f32")]] kernel kernel_mul_mv_id_t ke
template [[host_name("kernel_mul_mv_id_q6_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q6_K_f32_impl>>;
template [[host_name("kernel_mul_mv_id_iq1_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq1_s_f32_impl>>;
template [[host_name("kernel_mul_mv_id_iq1_m_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq1_m_f32_impl>>;
+template [[host_name("kernel_mul_mv_id_iq1_bn_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq1_bn_f32_impl>>;
template [[host_name("kernel_mul_mv_id_iq2_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_xxs_f32_impl>>;
template [[host_name("kernel_mul_mv_id_iq2_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_xs_f32_impl>>;
template [[host_name("kernel_mul_mv_id_iq3_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq3_xxs_f32_impl>>;