summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorKawrakow <48489457+ikawrakow@users.noreply.github.com>2024-09-10 09:43:05 +0300
committerGitHub <noreply@github.com>2024-09-10 09:43:05 +0300
commita1f7a03f500451be80ec4aeae44665c58cde311f (patch)
tree0373ddcc1eaf00fa09d368fe5b83c739c5257b06
parent918ada20faf7747bbda6b78503b5d72a90157844 (diff)
IQ1_TN Metal implementation (#46)
* iq1_tn: Metal implementation Rquires to change the get_rows and matrix multiplication kernels to use a dequantizer type rather than a dequantization function. But once this is done, we can simply reuse the iq1_bn implementation. This change will also allow to add other quantization types that have meta data (such as a row scale) stored at the beginning of a row (or change existing quantization types to row-wise scales). * Some cleanup --------- Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
-rw-r--r--ggml/src/ggml-metal.m29
-rw-r--r--ggml/src/ggml-metal.metal367
2 files changed, 314 insertions, 82 deletions
diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m
index b3f6e60c..749396c9 100644
--- a/ggml/src/ggml-metal.m
+++ b/ggml/src/ggml-metal.m
@@ -93,6 +93,7 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S,
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_M,
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_BN,
+ GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_TN,
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_BN,
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_TN,
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL,
@@ -130,6 +131,7 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_BN_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_TN_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_BN_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_TN_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32,
@@ -162,6 +164,7 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_BN_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_TN_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_BN_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_TN_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32,
@@ -191,6 +194,7 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_BN_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_TN_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_BN_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_TN_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32,
@@ -220,6 +224,7 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_BN_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_TN_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_BN_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_TN_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32,
@@ -603,6 +608,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) {
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S, get_rows_iq1_s, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_M, get_rows_iq1_m, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_BN, get_rows_iq1_bn, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_TN, get_rows_iq1_tn, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_BN, get_rows_iq2_bn, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_TN, get_rows_iq2_tn, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL, get_rows_iq4_nl, true);
@@ -640,6 +646,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) {
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32, mul_mv_iq1_s_f32, ctx->support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32, mul_mv_iq1_m_f32, ctx->support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_BN_F32, mul_mv_iq1_bn_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_TN_F32, mul_mv_iq1_tn_f32, ctx->support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_BN_F32, mul_mv_iq2_bn_f32, ctx->support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_TN_F32, mul_mv_iq2_tn_f32, ctx->support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32, mul_mv_iq4_nl_f32, ctx->support_simdgroup_reduction);
@@ -672,6 +679,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) {
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32, mul_mv_id_iq1_s_f32, ctx->support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32, mul_mv_id_iq1_m_f32, ctx->support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_BN_F32, mul_mv_id_iq1_bn_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_TN_F32, mul_mv_id_iq1_tn_f32, ctx->support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_BN_F32, mul_mv_id_iq2_bn_f32, ctx->support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_TN_F32, mul_mv_id_iq2_tn_f32, ctx->support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32, mul_mv_id_iq4_nl_f32, ctx->support_simdgroup_reduction);
@@ -701,6 +709,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) {
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32, mul_mm_iq1_s_f32, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32, mul_mm_iq1_m_f32, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_BN_F32, mul_mm_iq1_bn_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_TN_F32, mul_mm_iq1_tn_f32, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_BN_F32, mul_mm_iq2_bn_f32, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_TN_F32, mul_mm_iq2_tn_f32, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, mul_mm_iq4_nl_f32, ctx->support_simdgroup_mm);
@@ -730,6 +739,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) {
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32, mul_mm_id_iq1_s_f32, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32, mul_mm_id_iq1_m_f32, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_BN_F32, mul_mm_id_iq1_bn_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_TN_F32, mul_mm_id_iq1_tn_f32, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_BN_F32, mul_mm_id_iq2_bn_f32, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_TN_F32, mul_mm_id_iq2_tn_f32, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, mul_mm_id_iq4_nl_f32, ctx->support_simdgroup_mm);
@@ -1869,6 +1879,7 @@ static enum ggml_status ggml_metal_graph_compute(
case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32 ].pipeline; break;
case GGML_TYPE_IQ1_M: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32 ].pipeline; break;
case GGML_TYPE_IQ1_BN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_BN_F32 ].pipeline; break;
+ case GGML_TYPE_IQ1_TN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_TN_F32 ].pipeline; break;
case GGML_TYPE_IQ2_BN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_BN_F32 ].pipeline; break;
case GGML_TYPE_IQ2_TN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_TN_F32 ].pipeline; break;
case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32 ].pipeline; break;
@@ -2042,6 +2053,12 @@ static enum ggml_status ggml_metal_graph_compute(
nth1 = 16;
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_BN_F32].pipeline;
} break;
+ case GGML_TYPE_IQ1_TN:
+ {
+ nth0 = 4;
+ nth1 = 16;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_TN_F32].pipeline;
+ } break;
case GGML_TYPE_IQ2_BN:
{
nth0 = 4;
@@ -2128,7 +2145,7 @@ static enum ggml_status ggml_metal_graph_compute(
src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K ||
src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S||
src0t == GGML_TYPE_IQ1_BN|| src0t == GGML_TYPE_IQ2_BN|| src0t == GGML_TYPE_IQ2_K||
- src0t == GGML_TYPE_IQ3_K || src0t == GGML_TYPE_IQ2_TN) {
+ src0t == GGML_TYPE_IQ3_K || src0t == GGML_TYPE_IQ2_TN|| src0t == GGML_TYPE_IQ1_TN) {
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
}
else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) {
@@ -2230,6 +2247,7 @@ static enum ggml_status ggml_metal_graph_compute(
case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32 ].pipeline; break;
case GGML_TYPE_IQ1_M: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32 ].pipeline; break;
case GGML_TYPE_IQ1_BN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_BN_F32 ].pipeline; break;
+ case GGML_TYPE_IQ1_TN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_TN_F32 ].pipeline; break;
case GGML_TYPE_IQ2_BN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_BN_F32 ].pipeline; break;
case GGML_TYPE_IQ2_TN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_TN_F32 ].pipeline; break;
case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32 ].pipeline; break;
@@ -2397,6 +2415,12 @@ static enum ggml_status ggml_metal_graph_compute(
nth1 = 16;
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_BN_F32].pipeline;
} break;
+ case GGML_TYPE_IQ1_TN:
+ {
+ nth0 = 4;
+ nth1 = 16;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_TN_F32].pipeline;
+ } break;
case GGML_TYPE_IQ2_BN:
{
nth0 = 4;
@@ -2494,7 +2518,7 @@ static enum ggml_status ggml_metal_graph_compute(
src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K ||
src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S||
src0t == GGML_TYPE_IQ1_BN|| src0t == GGML_TYPE_IQ2_BN|| src0t == GGML_TYPE_IQ2_K||
- src0t == GGML_TYPE_IQ3_K || src0t == GGML_TYPE_IQ2_TN) {
+ src0t == GGML_TYPE_IQ3_K || src0t == GGML_TYPE_IQ2_TN|| src0t == GGML_TYPE_IQ1_TN) {
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
}
else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) {
@@ -2555,6 +2579,7 @@ static enum ggml_status ggml_metal_graph_compute(
case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S ].pipeline; break;
case GGML_TYPE_IQ1_M: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_M ].pipeline; break;
case GGML_TYPE_IQ1_BN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_BN ].pipeline; break;
+ case GGML_TYPE_IQ1_TN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_TN ].pipeline; break;
case GGML_TYPE_IQ2_BN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_BN ].pipeline; break;
case GGML_TYPE_IQ2_TN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_TN ].pipeline; break;
case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL ].pipeline; break;
diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal
index d7af1800..e050fdc3 100644
--- a/ggml/src/ggml-metal.metal
+++ b/ggml/src/ggml-metal.metal
@@ -5367,6 +5367,107 @@ void kernel_mul_mv_iq1_bn_f32_impl(
}
}
+// TODO: unify with kernel_mul_mv_iq1_bn_f32_impl
+void kernel_mul_mv_iq1_tn_f32_impl(
+ device const void * src0,
+ device const float * src1,
+ device float * dst,
+ int64_t ne00,
+ int64_t ne01,
+ int64_t ne02,
+ int64_t ne10,
+ int64_t ne12,
+ int64_t ne0,
+ int64_t ne1,
+ uint r2,
+ uint r3,
+ threadgroup int8_t * shared_value,
+ uint3 tgpig,
+ uint tiisg,
+ uint sgitg) {
+
+ const int nb = ne00/QK_IQ1BN;
+ const int r0 = tgpig.x;
+ const int r1 = tgpig.y;
+ const int im = tgpig.z;
+
+ // Why are we not passing in src0->nb[0]?
+ // But because we are not, we need to use this hack
+ const uint row_size = sizeof(block_iq1_tn)*(ne00/QK_K);
+
+ const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
+
+ const uint i12 = im%ne12;
+ const uint i13 = im/ne12;
+
+ const uint offset0 = ((i12/r2)*ne01 + (i13/r3)*(ne01*ne02))*row_size;
+ device const char * cx = (device const char *) src0 + first_row*row_size + offset0;
+ device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
+
+ float yl[16];
+ float sumf[N_DST]={0.f};
+
+ const int nb32 = nb * (QK_IQ1BN / 32);
+
+ const int ix = tiisg/2;
+ const int ir = tiisg%2;
+
+ device const float * y4 = (device const float *)y + 32 * ix + 16 * ir;
+
+ uint32_t aux32[2];
+
+ const float values[3] = {-1.f, 0.f, 1.f};
+
+ constexpr uint8_t k_mult[5] = {81, 27, 9, 3, 1};
+
+ for (int ib32 = ix; ib32 < nb32; ib32 += 16) {
+
+ for (int j = 0; j < 16; ++j) yl[j] = y4[j];
+
+ const int ibl = ib32 / (QK_IQ1BN / 32);
+ const int ib = ib32 % (QK_IQ1BN / 32);
+ const int i16 = 2*ib + ir;
+
+ device const half * dh = (device const half *)cx;
+ device const block_iq1_bn * xr = (device const block_iq1_bn *)(dh + 1) + ibl;
+ device const uint8_t * ql = xr->ql + 3*i16;
+ device const uint8_t * extra = (device const uint8_t *)&xr->extra;
+
+ for (int row = 0; row < N_DST; row++) {
+
+ float acc = 0;
+ int i = 0;
+ for (int k = 0; k < 3; ++k) {
+ uint8_t q = ql[k];
+ for (int j = 0; j < 5; ++j) {
+ uint8_t v = k_mult[j]*q;
+ v = 3*v >> 8; //(v + (v >> 1)) >> 7;
+ acc += yl[i++] * values[v];
+ }
+ }
+ uint8_t v = k_mult[i16]*extra[0];
+ v = 3*v >> 8; //(v + (v >> 1)) >> 7;
+ acc += yl[15] * values[v];
+
+ sumf[row] += acc * (float)dh[0];
+
+ extra += row_size;
+ ql += row_size;
+ dh += row_size/2;
+ }
+
+ y4 += 32 * 16;
+ }
+
+ for (int row = 0; row < N_DST; row += 2) {
+ half2 r = {(half)sumf[row], (half)sumf[row+1]};
+ r = simd_sum(r);
+ if (tiisg < 2) {
+ dst[r1*ne0 + im*ne0*ne1 + first_row + row + tiisg] = r[tiisg];
+ }
+ }
+}
+
void kernel_mul_mv_iq2_bn_f32_impl(
device const void * src0,
device const float * src1,
@@ -6290,6 +6391,34 @@ kernel void kernel_mul_mv_iq1_bn_f32(
kernel_mul_mv_iq1_bn_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
}
+[[host_name("kernel_mul_mv_iq1_tn_f32")]]
+kernel void kernel_mul_mv_iq1_tn_f32(
+ device const void * src0,
+ device const float * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & r2,
+ constant uint & r3,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+
+ kernel_mul_mv_iq1_tn_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
+}
+
[[host_name("kernel_mul_mv_iq2_bn_f32")]]
kernel void kernel_mul_mv_iq2_bn_f32(
device const void * src0,
@@ -7080,6 +7209,38 @@ kernel void kernel_get_rows_q(
}
}
+template<typename Dequantizer>
+kernel void kernel_get_rows_q2(
+ device const void * src0,
+ device const void * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb1,
+ constant uint64_t & nb2,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiitg[[thread_index_in_threadgroup]],
+ uint3 tptg [[threads_per_threadgroup]]) {
+ const int64_t i10 = tgpig.x;
+ const int64_t i11 = tgpig.y;
+
+ const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*nb11 + i10*nb10))[0];
+
+ const int64_t i02 = i11;
+
+ Dequantizer deq((const device char *) src0 + r*nb01 + i02*nb02);
+
+ for (int64_t ind = tiitg; ind < ne00/16; ind += tptg.x) {
+ float4x4 temp;
+ deq.convert(ind, temp);
+ *(((device float4x4 *) ((device char *) dst + i11*nb2 + i10*nb1)) + ind) = temp;
+ }
+}
+
template<typename T>
kernel void kernel_get_rows_f(
device const void * src0,
@@ -7149,8 +7310,48 @@ kernel void kernel_get_rows_i32(
#define SG_MAT_SIZE 64 // simdgroup matrix is of shape 8x8
#define SG_MAT_ROW 8
+template <typename T4x4, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread T4x4 &)>
+struct DefaultDequantizer {
+ using type4x4 = T4x4;
+ using Block = block_q;
+ DefaultDequantizer(device const char * cx, short il) : x((device const block_q *)cx + il/nl), il(il) {}
+ inline void convert(thread T4x4& t) const { dequantize_func(x, il, t); }
+ inline void next() {
+ il = (il + 2 < nl) ? il + 2 : il % 2;
+ x = (il < 2) ? x + (2+nl-1)/nl : x;
+ }
+ device const block_q * x;
+ short il;
+};
+
+template <typename T4x4>
+struct DequantizerIQ1TN {
+ using type4x4 = T4x4;
+ using Block = block_iq1_bn;;
+ DequantizerIQ1TN(device const char * cx, short il = 0) : il(il) {
+ d = *(device const half *)cx;
+ x = (device const Block *)(cx + sizeof(half));
+ }
+ inline void convert(thread T4x4& t) const {
+ dequantize_iq1_bn(x, il, t);
+ t *= d;
+ }
+ inline void convert(int64_t ind, thread T4x4& t) {
+ dequantize_iq1_bn(x + ind/4, ind%4, t);
+ t *= d;
+ }
+ inline void next() {
+ constexpr int short nl = 4;
+ il = (il + 2 < nl) ? il + 2 : il % 2;
+ x = (il < 2) ? x + (2+nl-1)/nl : x;
+ }
+ device const Block * x;
+ short il;
+ half d;
+};
+
// each block_q contains 16*nl weights
-template<typename T, typename T4x4, typename simdgroup_T8x8, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread T4x4 &)>
+template<typename T, typename simdgroup_T8x8, typename Dequantizer>
kernel void kernel_mul_mm(device const uchar * src0,
device const uchar * src1,
device float * dst,
@@ -7199,18 +7400,19 @@ kernel void kernel_mul_mm(device const uchar * src0,
const uint i13 = im/ne12;
uint offset0 = (i12/r2)*nb02 + (i13/r3)*(nb02*ne02);
- ushort offset1 = il/nl;
- device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1;
- device const float * y = (device const float *)(src1
+ device const char * cx = (device const char *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0);
+ device const float * y = (device const float *)(src1
+ nb12 * im
+ nb11 * (r1 * BLOCK_SIZE_N + thread_col)
+ nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
+ Dequantizer deq(cx, il);
+
for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
// load data and store to threadgroup memory
- T4x4 temp_a;
- dequantize_func(x, il, temp_a);
+ typename Dequantizer::type4x4 temp_a;
+ deq.convert(temp_a);
threadgroup_barrier(mem_flags::mem_threadgroup);
#pragma unroll(16)
@@ -7222,8 +7424,7 @@ kernel void kernel_mul_mm(device const uchar * src0,
*(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) = *((device float2x4 *)y);
- il = (il + 2 < nl) ? il + 2 : il % 2;
- x = (il < 2) ? x + (2+nl-1)/nl : x;
+ deq.next();
y += BLOCK_SIZE_K;
threadgroup_barrier(mem_flags::mem_threadgroup);
@@ -7283,7 +7484,7 @@ kernel void kernel_mul_mm(device const uchar * src0,
}
// same as kernel_mul_mm_impl, but src1 and dst are accessed via indices stored in rowids
-template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
+template<typename Dequantizer>
void kernel_mul_mm_id_impl(
device const uchar * src0,
device const uchar * src1,
@@ -7330,20 +7531,20 @@ void kernel_mul_mm_id_impl(
}
short il = (tiitg % THREAD_PER_ROW);
- ushort offset1 = il/nl;
-
threadgroup const auto & id = rowids[r1 * BLOCK_SIZE_N + thread_col];
- device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01) + offset1;
- device const float * y = (device const float *)(src1
+ device const char * cx = (device const char *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01);
+ device const float * y = (device const float *)(src1
+ nb12 * id[1]
+ nb11 * (id[0] % ne11)
+ nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
+ Dequantizer deq(cx, il);
+
for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
// load data and store to threadgroup memory
half4x4 temp_a;
- dequantize_func(x, il, temp_a);
+ deq.convert(temp_a);
threadgroup_barrier(mem_flags::mem_threadgroup);
for (int i = 0; i < 16; i++) {
@@ -7354,8 +7555,7 @@ void kernel_mul_mm_id_impl(
*(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) = *((device float2x4 *)y);
- il = (il + 2 < nl) ? il + 2 : il % 2;
- x = (il < 2) ? x + (2+nl-1)/nl : x;
+ deq.next();
y += BLOCK_SIZE_K;
threadgroup_barrier(mem_flags::mem_threadgroup);
@@ -7405,7 +7605,7 @@ void kernel_mul_mm_id_impl(
}
}
-template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
+template<typename Dequantizer>
kernel void kernel_mul_mm_id(
device const uchar * src0s,
device const uchar * src1,
@@ -7456,7 +7656,7 @@ kernel void kernel_mul_mm_id(
threadgroup_barrier(mem_flags::mem_threadgroup);
- kernel_mul_mm_id_impl<block_q, nl, dequantize_func>(
+ kernel_mul_mm_id_impl<Dequantizer>(
src0,
src1,
rowids,
@@ -7519,78 +7719,84 @@ template [[host_name("kernel_get_rows_iq5_k")]] kernel get_rows_q_t kernel_get
template [[host_name("kernel_get_rows_iq6_k")]] kernel get_rows_q_t kernel_get_rows_q<block_iq6_k, QK_NL, dequantize_iq6_k>;
template [[host_name("kernel_get_rows_iq1_bn")]] kernel get_rows_q_t kernel_get_rows_q<block_iq1_bn, 4, dequantize_iq1_bn>;
template [[host_name("kernel_get_rows_iq2_bn")]] kernel get_rows_q_t kernel_get_rows_q<block_iq2_bn, 4, dequantize_iq2_bn>;
+template [[host_name("kernel_get_rows_iq1_tn")]] kernel get_rows_q_t kernel_get_rows_q2<DequantizerIQ1TN<float4x4>>;
//
// matrix-matrix multiplication
//
-typedef decltype(kernel_mul_mm<half, half4x4, simdgroup_half8x8, float4x4, 1, dequantize_f32>) mat_mm_t;
-
-template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, float4x4, 1, dequantize_f32>;
-template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half4x4, 1, dequantize_f16>;
-template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0>;
-template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1>;
-template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0>;
-template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_1, 2, dequantize_q5_1>;
-template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q8_0, 2, dequantize_q8_0>;
-template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K>;
-template [[host_name("kernel_mul_mm_iq2_tn_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq2_tn, QK_NL, dequantize_iq2_tn>;
-template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q3_K, QK_NL, dequantize_q3_K>;
-template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K>;
-template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_K, QK_NL, dequantize_q5_K>;
-template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q6_K, QK_NL, dequantize_q6_K>;
-template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
-template [[host_name("kernel_mul_mm_iq2_xs_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq2_xs, QK_NL, dequantize_iq2_xs>;
-template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
-template [[host_name("kernel_mul_mm_iq3_s_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq3_s, QK_NL, dequantize_iq3_s>;
-template [[host_name("kernel_mul_mm_iq2_s_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq2_s, QK_NL, dequantize_iq2_s>;
-template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq1_s, QK_NL, dequantize_iq1_s>;
-template [[host_name("kernel_mul_mm_iq1_m_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq1_m, QK_NL, dequantize_iq1_m>;
-template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq4_nl, 2, dequantize_iq4_nl>;
-template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq4_xs, QK_NL, dequantize_iq4_xs>;
-template [[host_name("kernel_mul_mm_iq2_k_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq2_k, QK_NL, dequantize_iq2_k>;
-template [[host_name("kernel_mul_mm_iq3_k_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq3_k, QK_NL, dequantize_iq3_k>;
-template [[host_name("kernel_mul_mm_iq4_k_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq4_k, QK_NL, dequantize_iq4_k>;
-template [[host_name("kernel_mul_mm_iq5_k_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq5_k, QK_NL, dequantize_iq5_k>;
-template [[host_name("kernel_mul_mm_iq6_k_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq6_k, QK_NL, dequantize_iq6_k>;
-template [[host_name("kernel_mul_mm_iq1_bn_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq1_bn, 4, dequantize_iq1_bn>;
-template [[host_name("kernel_mul_mm_iq2_bn_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq2_bn, 4, dequantize_iq2_bn>;
+template <typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
+using DD = DefaultDequantizer<half4x4, block_q, nl, dequantize_func>;
+
+typedef decltype(kernel_mul_mm<half, simdgroup_half8x8, DD<float4x4, 1, dequantize_f32>>) mat_mm_t;
+
+template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<float4x4, 1, dequantize_f32>>;
+template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<half4x4, 1, dequantize_f16>>;
+template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_q4_0, 2, dequantize_q4_0>>;
+template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_q4_1, 2, dequantize_q4_1>>;
+template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_q5_0, 2, dequantize_q5_0>>;
+template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_q5_1, 2, dequantize_q5_1>>;
+template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_q8_0, 2, dequantize_q8_0>>;
+template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_q2_K, QK_NL, dequantize_q2_K>>;
+template [[host_name("kernel_mul_mm_iq2_tn_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_iq2_tn, QK_NL, dequantize_iq2_tn>>;
+template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_q3_K, QK_NL, dequantize_q3_K>>;
+template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_q4_K, QK_NL, dequantize_q4_K>>;
+template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_q5_K, QK_NL, dequantize_q5_K>>;
+template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_q6_K, QK_NL, dequantize_q6_K>>;
+template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>>;
+template [[host_name("kernel_mul_mm_iq2_xs_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_iq2_xs, QK_NL, dequantize_iq2_xs>>;
+template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>>;
+template [[host_name("kernel_mul_mm_iq3_s_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_iq3_s, QK_NL, dequantize_iq3_s>>;
+template [[host_name("kernel_mul_mm_iq2_s_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_iq2_s, QK_NL, dequantize_iq2_s>>;
+template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_iq1_s, QK_NL, dequantize_iq1_s>>;
+template [[host_name("kernel_mul_mm_iq1_m_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_iq1_m, QK_NL, dequantize_iq1_m>>;
+template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_iq4_nl, 2, dequantize_iq4_nl>>;
+template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_iq4_xs, QK_NL, dequantize_iq4_xs>>;
+template [[host_name("kernel_mul_mm_iq2_k_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_iq2_k, QK_NL, dequantize_iq2_k>>;
+template [[host_name("kernel_mul_mm_iq3_k_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_iq3_k, QK_NL, dequantize_iq3_k>>;
+template [[host_name("kernel_mul_mm_iq4_k_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_iq4_k, QK_NL, dequantize_iq4_k>>;
+template [[host_name("kernel_mul_mm_iq5_k_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_iq5_k, QK_NL, dequantize_iq5_k>>;
+template [[host_name("kernel_mul_mm_iq6_k_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_iq6_k, QK_NL, dequantize_iq6_k>>;
+template [[host_name("kernel_mul_mm_iq1_bn_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_iq1_bn, 4, dequantize_iq1_bn>>;
+template [[host_name("kernel_mul_mm_iq2_bn_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_iq2_bn, 4, dequantize_iq2_bn>>;
+template [[host_name("kernel_mul_mm_iq1_tn_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DequantizerIQ1TN<half4x4>>;
//
// indirect matrix-matrix multiplication
//
-typedef decltype(kernel_mul_mm_id<float4x4, 1, dequantize_f32>) mat_mm_id_t;
-
-template [[host_name("kernel_mul_mm_id_f32_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<float4x4, 1, dequantize_f32>;
-template [[host_name("kernel_mul_mm_id_f16_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<half4x4, 1, dequantize_f16>;
-template [[host_name("kernel_mul_mm_id_q4_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_0, 2, dequantize_q4_0>;
-template [[host_name("kernel_mul_mm_id_q4_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_1, 2, dequantize_q4_1>;
-template [[host_name("kernel_mul_mm_id_q5_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_0, 2, dequantize_q5_0>;
-template [[host_name("kernel_mul_mm_id_q5_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_1, 2, dequantize_q5_1>;
-template [[host_name("kernel_mul_mm_id_q8_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q8_0, 2, dequantize_q8_0>;
-template [[host_name("kernel_mul_mm_id_q2_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q2_K, QK_NL, dequantize_q2_K>;
-template [[host_name("kernel_mul_mm_id_iq2_tn_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_tn, QK_NL, dequantize_iq2_tn>;
-template [[host_name("kernel_mul_mm_id_q3_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q3_K, QK_NL, dequantize_q3_K>;
-template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_K, QK_NL, dequantize_q4_K>;
-template [[host_name("kernel_mul_mm_id_q5_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_K, QK_NL, dequantize_q5_K>;
-template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q6_K, QK_NL, dequantize_q6_K>;
-template [[host_name("kernel_mul_mm_id_iq2_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
-template [[host_name("kernel_mul_mm_id_iq2_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
-template [[host_name("kernel_mul_mm_id_iq3_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
-template [[host_name("kernel_mul_mm_id_iq3_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq3_s, QK_NL, dequantize_iq3_s>;
-template [[host_name("kernel_mul_mm_id_iq2_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_s, QK_NL, dequantize_iq2_s>;
-template [[host_name("kernel_mul_mm_id_iq1_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq1_s, QK_NL, dequantize_iq1_s>;
-template [[host_name("kernel_mul_mm_id_iq1_m_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq1_m, QK_NL, dequantize_iq1_m>;
-template [[host_name("kernel_mul_mm_id_iq1_bn_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq1_bn, 4, dequantize_iq1_bn>;
-template [[host_name("kernel_mul_mm_id_iq2_bn_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_bn, 4, dequantize_iq2_bn>;
-template [[host_name("kernel_mul_mm_id_iq4_nl_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_nl, 2, dequantize_iq4_nl>;
-template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
-template [[host_name("kernel_mul_mm_id_iq2_k_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_k, QK_NL, dequantize_iq2_k>;
-template [[host_name("kernel_mul_mm_id_iq3_k_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq3_k, QK_NL, dequantize_iq3_k>;
-template [[host_name("kernel_mul_mm_id_iq4_k_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_k, QK_NL, dequantize_iq4_k>;
-template [[host_name("kernel_mul_mm_id_iq5_k_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq5_k, QK_NL, dequantize_iq5_k>;
-template [[host_name("kernel_mul_mm_id_iq6_k_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq6_k, QK_NL, dequantize_iq6_k>;
+typedef decltype(kernel_mul_mm_id<DD<float4x4, 1, dequantize_f32>>) mat_mm_id_t;
+
+template [[host_name("kernel_mul_mm_id_f32_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DD<float4x4, 1, dequantize_f32>>;
+template [[host_name("kernel_mul_mm_id_f16_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DD<half4x4, 1, dequantize_f16>>;
+template [[host_name("kernel_mul_mm_id_q4_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DD<block_q4_0, 2, dequantize_q4_0>>;
+template [[host_name("kernel_mul_mm_id_q4_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DD<block_q4_1, 2, dequantize_q4_1>>;
+template [[host_name("kernel_mul_mm_id_q5_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DD<block_q5_0, 2, dequantize_q5_0>>;
+template [[host_name("kernel_mul_mm_id_q5_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DD<block_q5_1, 2, dequantize_q5_1>>;
+template [[host_name("kernel_mul_mm_id_q8_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DD<block_q8_0, 2, dequantize_q8_0>>;
+template [[host_name("kernel_mul_mm_id_q2_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DD<block_q2_K, QK_NL, dequantize_q2_K>>;
+template [[host_name("kernel_mul_mm_id_iq2_tn_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DD<block_iq2_tn, QK_NL, dequantize_iq2_tn>>;
+template [[host_name("kernel_mul_mm_id_q3_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DD<block_q3_K, QK_NL, dequantize_q3_K>>;
+template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DD<block_q4_K, QK_NL, dequantize_q4_K>>;
+template [[host_name("kernel_mul_mm_id_q5_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DD<block_q5_K, QK_NL, dequantize_q5_K>>;
+template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DD<block_q6_K, QK_NL, dequantize_q6_K>>;
+template [[host_name("kernel_mul_mm_id_iq2_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DD<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>>;
+template [[host_name("kernel_mul_mm_id_iq2_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DD<block_iq2_xs, QK_NL, dequantize_iq2_xs>>;
+template [[host_name("kernel_mul_mm_id_iq3_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DD<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>>;
+template [[host_name("kernel_mul_mm_id_iq3_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DD<block_iq3_s, QK_NL, dequantize_iq3_s>>;
+template [[host_name("kernel_mul_mm_id_iq2_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DD<block_iq2_s, QK_NL, dequantize_iq2_s>>;
+template [[host_name("kernel_mul_mm_id_iq1_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DD<block_iq1_s, QK_NL, dequantize_iq1_s>>;
+template [[host_name("kernel_mul_mm_id_iq1_m_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DD<block_iq1_m, QK_NL, dequantize_iq1_m>>;
+template [[host_name("kernel_mul_mm_id_iq1_bn_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DD<block_iq1_bn, 4, dequantize_iq1_bn>>;
+template [[host_name("kernel_mul_mm_id_iq2_bn_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DD<block_iq2_bn, 4, dequantize_iq2_bn>>;
+template [[host_name("kernel_mul_mm_id_iq4_nl_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DD<block_iq4_nl, 2, dequantize_iq4_nl>>;
+template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DD<block_iq4_xs, QK_NL, dequantize_iq4_xs>>;
+template [[host_name("kernel_mul_mm_id_iq2_k_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DD<block_iq2_k, QK_NL, dequantize_iq2_k>>;
+template [[host_name("kernel_mul_mm_id_iq3_k_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DD<block_iq3_k, QK_NL, dequantize_iq3_k>>;
+template [[host_name("kernel_mul_mm_id_iq4_k_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DD<block_iq4_k, QK_NL, dequantize_iq4_k>>;
+template [[host_name("kernel_mul_mm_id_iq5_k_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DD<block_iq5_k, QK_NL, dequantize_iq5_k>>;
+template [[host_name("kernel_mul_mm_id_iq6_k_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DD<block_iq6_k, QK_NL, dequantize_iq6_k>>;
+template [[host_name("kernel_mul_mm_id_iq1_tn_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DequantizerIQ1TN<half4x4>>;
//
// matrix-vector multiplication
@@ -7795,6 +8001,7 @@ template [[host_name("kernel_mul_mv_id_q6_K_f32")]] kernel kernel_mul_mv_id_t
template [[host_name("kernel_mul_mv_id_iq1_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq1_s_f32_impl>>;
template [[host_name("kernel_mul_mv_id_iq1_m_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq1_m_f32_impl>>;
template [[host_name("kernel_mul_mv_id_iq1_bn_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq1_bn_f32_impl>>;
+template [[host_name("kernel_mul_mv_id_iq1_tn_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq1_tn_f32_impl>>;
template [[host_name("kernel_mul_mv_id_iq2_bn_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_bn_f32_impl>>;
template [[host_name("kernel_mul_mv_id_iq2_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_xxs_f32_impl>>;
template [[host_name("kernel_mul_mv_id_iq2_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_xs_f32_impl>>;