author     Kawrakow <iwankawrakow@gmail.com>  2024-09-17 10:54:42 +0300
committer  GitHub <noreply@github.com>        2024-09-17 10:54:42 +0300
commit     4ee889f15891b6e8785c88c932ec27f129cae285 (patch)
tree       77afce9b1c5e2807bce1553a15514faa6843f747
parent     2874b984006c6c8d0691ce000dcd9ca2cf9ff6fd (diff)
BF16 support on Metal (#56)
* BF16 support on Metal

* Faster BF16 Metal dot product

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
-rw-r--r--  ggml/src/ggml-metal.m      |  37
-rw-r--r--  ggml/src/ggml-metal.metal  | 125
2 files changed, 161 insertions(+), 1 deletion(-)
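The whole patch leans on one property of bfloat16: a bf16 value is exactly the high 16 bits of an IEEE-754 float32 (same sign and exponent, truncated mantissa), so widening to f32 is a single 16-bit left shift and needs no hardware bf16 support. A minimal host-side C++ sketch of both directions (helper names are illustrative, not ggml API):

```cpp
#include <cstdint>
#include <cstring>

// bf16 -> f32: place the 16 stored bits in the high half of a u32 and
// reinterpret. This is the same trick the Metal kernels below use via a
// union (aux.u = x[i] << 16).
static float bf16_to_f32(uint16_t h) {
    uint32_t u = (uint32_t)h << 16;
    float f;
    std::memcpy(&f, &u, sizeof(f));
    return f;
}

// f32 -> bf16 by truncation: keep the top 16 bits. (Real encoders usually
// round-to-nearest-even instead; truncation keeps the sketch minimal.)
static uint16_t f32_to_bf16(float f) {
    uint32_t u;
    std::memcpy(&u, &f, sizeof(u));
    return (uint16_t)(u >> 16);
}
```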
diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m
index 749396c9..02794e3c 100644
--- a/ggml/src/ggml-metal.m
+++ b/ggml/src/ggml-metal.m
@@ -113,6 +113,8 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW,
GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F16,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32,
@@ -146,6 +148,7 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32,
//GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW,
//GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_ID_BF16_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32,
@@ -176,6 +179,7 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ6_K_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32,
@@ -206,6 +210,7 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ6_K_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32,
@@ -628,6 +633,8 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) {
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, ctx->support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW, mul_mv_f16_f32_1row, ctx->support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4, mul_mv_f16_f32_l4, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F16, mul_mv_bf16_f16, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32, mul_mv_bf16_f32, ctx->support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32, mul_mv_q4_0_f32, ctx->support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32, mul_mv_q4_1_f32, ctx->support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, mul_mv_q5_0_f32, ctx->support_simdgroup_reduction);
@@ -661,6 +668,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) {
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32, mul_mv_id_f16_f32, ctx->support_simdgroup_reduction);
//GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW, mul_mv_id_f16_f32_1row, ctx->support_simdgroup_reduction);
//GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4, mul_mv_id_f16_f32_l4, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_BF16_F32, mul_mv_id_bf16_f32, ctx->support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32, mul_mv_id_q4_0_f32, ctx->support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32, mul_mv_id_q4_1_f32, ctx->support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32, mul_mv_id_q5_0_f32, ctx->support_simdgroup_reduction);
@@ -691,6 +699,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) {
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ6_K_F32, mul_mv_id_iq6_k_f32, ctx->support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32, mul_mm_bf16_f32, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, mul_mm_q4_0_f32, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32, mul_mm_q4_1_f32, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32, mul_mm_q5_0_f32, ctx->support_simdgroup_mm);
@@ -721,6 +730,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) {
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ6_K_F32, mul_mm_iq6_k_f32, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, mul_mm_id_f16_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F32, mul_mm_id_bf16_f32, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, mul_mm_id_q4_0_f32, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32, mul_mm_id_q4_1_f32, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32, mul_mm_id_q5_0_f32, ctx->support_simdgroup_mm);
@@ -856,9 +866,12 @@ static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_tensor * t, size_t * offs
}
static bool ggml_metal_supports_op(const struct ggml_backend_metal_context * ctx, const struct ggml_tensor * op) {
+
for (size_t i = 0, n = 3; i < n; ++i) {
if (op->src[i] != NULL && op->src[i]->type == GGML_TYPE_BF16) {
- return false;
+ if (op->op != GGML_OP_MUL_MAT && op->op != GGML_OP_MUL_MAT_ID && op->op != GGML_OP_GET_ROWS) {
+ return false;
+ }
}
}
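Previously any BF16 source tensor made the Metal backend reject the op; now only ops without a BF16 kernel are rejected. The gate above boils down to this predicate (a sketch for clarity, not code from the patch; assumes ggml.h is on the include path):

```cpp
#include "ggml.h"

// Sketch of the gate in ggml_metal_supports_op: BF16 sources are accepted
// on Metal only for these three ops; anything else still returns false
// and falls back to another backend.
static bool metal_allows_bf16_src(enum ggml_op op) {
    return op == GGML_OP_MUL_MAT
        || op == GGML_OP_MUL_MAT_ID
        || op == GGML_OP_GET_ROWS;
}
```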
@@ -1861,6 +1874,7 @@ static enum ggml_status ggml_metal_graph_compute(
switch (src0->type) {
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32 ].pipeline; break;
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32 ].pipeline; break;
+ case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32 ].pipeline; break;
case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32 ].pipeline; break;
case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32 ].pipeline; break;
case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32 ].pipeline; break;
@@ -1945,6 +1959,19 @@ static enum ggml_status ggml_metal_graph_compute(
nrows = 4;
}
} break;
+ case GGML_TYPE_BF16:
+ {
+ if (src1t == GGML_TYPE_F32) {
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32].pipeline;
+ }
+ else if (src1t == GGML_TYPE_F16) {
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F16].pipeline;
+ }
+ else {
+ GGML_ABORT("not implemented");
+ }
+ nrows = 4;
+ } break;
case GGML_TYPE_Q4_0:
{
nth0 = 8;
@@ -2229,6 +2256,7 @@ static enum ggml_status ggml_metal_graph_compute(
switch (src0->type) {
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32 ].pipeline; break;
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32 ].pipeline; break;
+ case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F32 ].pipeline; break;
case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32 ].pipeline; break;
case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32 ].pipeline; break;
case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32 ].pipeline; break;
@@ -2307,6 +2335,13 @@ static enum ggml_status ggml_metal_graph_compute(
nth1 = 1;
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32].pipeline;
} break;
+ case GGML_TYPE_BF16:
+ {
+ GGML_ASSERT(src1t == GGML_TYPE_F32);
+ nth0 = 32;
+ nth1 = 1;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_BF16_F32].pipeline;
+ } break;
case GGML_TYPE_Q4_0:
{
nth0 = 8;
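That is the entire host-side change: register the new pipelines and route GGML_TYPE_BF16 through the mul_mm, mul_mv, and _id dispatch switches. For reference, a scalar C++ version of what the new mat-vec kernels compute, handy for validating the Metal path on small shapes (a hypothetical helper, not part of ggml):

```cpp
#include <cstdint>
#include <cstring>

// bf16 is the high half of an f32, so widening is a 16-bit shift.
static float bf16_to_f32(uint16_t h) {
    uint32_t u = (uint32_t)h << 16;
    float f;
    std::memcpy(&f, &u, sizeof(f));
    return f;
}

// dst[r0] = dot(row r0 of x, y); x is bf16, row-major [ne01 x ne00],
// y is a single f32 vector of length ne00.
static void mul_mv_bf16_f32_ref(const uint16_t * x, const float * y,
                                float * dst, int64_t ne01, int64_t ne00) {
    for (int64_t r0 = 0; r0 < ne01; ++r0) {
        float sum = 0.0f;
        for (int64_t k = 0; k < ne00; ++k) {
            sum += bf16_to_f32(x[r0*ne00 + k]) * y[k];
        }
        dst[r0] = sum;
    }
}
```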
diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal
index e050fdc3..6553f465 100644
--- a/ggml/src/ggml-metal.metal
+++ b/ggml/src/ggml-metal.metal
@@ -1667,11 +1667,121 @@ kernel void kernel_mul_mv(
tiisg);
}
+template<typename T1>
+void kernel_mul_mv_bf16_impl(
+ device const char * src0,
+ device const char * src1,
+ device float * dst,
+ int64_t ne00,
+ int64_t ne01,
+ int64_t ne02,
+ uint64_t nb00,
+ uint64_t nb01,
+ uint64_t nb02,
+ int64_t ne10,
+ int64_t ne11,
+ int64_t ne12,
+ uint64_t nb10,
+ uint64_t nb11,
+ uint64_t nb12,
+ int64_t ne0,
+ int64_t ne1,
+ uint r2,
+ uint r3,
+ uint3 tgpig,
+ uint tiisg) {
+ const int64_t r0 = tgpig.x;
+ const int64_t rb = tgpig.y*N_MV_T_T;
+ const int64_t im = tgpig.z;
+
+ const uint i12 = im%ne12;
+ const uint i13 = im/ne12;
+
+ const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
+
+ device const uint16_t * x = (device const uint16_t *) (src0 + offset0);
+
+ typedef union { uint32_t u[4]; float f[4]; } aux_t;
+ aux_t aux;
+
+ for (int row = 0; row < N_MV_T_T; ++row) {
+ int r1 = rb + row;
+ if (r1 >= ne11) {
+ break;
+ }
+
+ device const T1 * y = (device const T1 *) (src1 + r1*nb11 + im*nb12);
+
+ float sumf = 0;
+ for (int i = tiisg; i < ne00/4; i += 32) {
+ aux.u[0] = x[4*i+0] << 16;
+ aux.u[1] = x[4*i+1] << 16;
+ aux.u[2] = x[4*i+2] << 16;
+ aux.u[3] = x[4*i+3] << 16;
+ sumf += aux.f[0] * (float)y[4*i+0] + aux.f[1] * (float)y[4*i+1] + aux.f[2] * (float)y[4*i+2] + aux.f[3] * (float)y[4*i+3];
+ }
+
+ float all_sum = simd_sum(sumf);
+ if (tiisg == 0) {
+ dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
+ }
+ }
+}
+
+template<typename T1>
+kernel void kernel_mul_mv_bf16(
+ device const char * src0,
+ device const char * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & r2,
+ constant uint & r3,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiisg[[thread_index_in_simdgroup]]) {
+ kernel_mul_mv_bf16_impl<T1>(
+ src0,
+ src1,
+ dst,
+ ne00,
+ ne01,
+ ne02,
+ nb00,
+ nb01,
+ nb02,
+ ne10,
+ ne11,
+ ne12,
+ nb10,
+ nb11,
+ nb12,
+ ne0,
+ ne1,
+ r2,
+ r3,
+ tgpig,
+ tiisg);
+}
+
typedef decltype(kernel_mul_mv<half, half4, half, half4>) mul_mv_t;
template [[host_name("kernel_mul_mv_f32_f32")]] kernel mul_mv_t kernel_mul_mv<float, float4, float, float4>;
template [[host_name("kernel_mul_mv_f16_f32")]] kernel mul_mv_t kernel_mul_mv<half, half4, float, float4>;
template [[host_name("kernel_mul_mv_f16_f16")]] kernel mul_mv_t kernel_mul_mv<half, half4, half, half4>;
+template [[host_name("kernel_mul_mv_bf16_f16")]] kernel mul_mv_t kernel_mul_mv_bf16<half>;
+template [[host_name("kernel_mul_mv_bf16_f32")]] kernel mul_mv_t kernel_mul_mv_bf16<float>;
template<typename T, typename T4>
kernel void kernel_mul_mv_1row(
@@ -6612,6 +6722,17 @@ void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg)
}
template <typename type4x4>
+void dequantize_bf16(device const half4x4 * src, short il, thread type4x4 & reg) {
+ device const uint16_t * src_u16 = (device const uint16_t *)src;
+ typedef union { uint32_t u; float f; } aux_t;
+ aux_t aux;
+ for (int i = 0; i < 16; ++i) {
+ aux.u = (uint32_t)src_u16[i] << 16;
+ reg[i/4][i%4] = aux.f;
+ }
+}
+
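dequantize_bf16 takes device const half4x4 * only to fit the same template slot as dequantize_f16; since f16 and bf16 are both 16-bit types, tile addressing is unchanged, and the body reinterprets the storage as raw uint16_t before widening. A host-side mirror (a sketch; float reg[4][4] stands in for Metal's thread type4x4 &):

```cpp
#include <cstdint>
#include <cstring>

// Host-side mirror of dequantize_bf16 above: read the 4x4 tile of 16-bit
// storage as raw uint16_t and widen each bf16 value into an f32 slot.
static void dequantize_bf16_ref(const uint16_t * src_u16, float reg[4][4]) {
    for (int i = 0; i < 16; ++i) {
        uint32_t u = (uint32_t)src_u16[i] << 16;  // bf16 -> high half of f32
        float f;
        std::memcpy(&f, &u, sizeof(f));
        reg[i/4][i%4] = f;
    }
}
```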
+template <typename type4x4>
void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) {
device const uint16_t * qs = ((device const uint16_t *)xb + 1);
const float d1 = il ? (xb->d / 16.h) : xb->d;
@@ -7692,6 +7813,7 @@ template [[host_name("kernel_get_rows_f16")]] kernel get_rows_f_t kernel_get_ro
typedef decltype(kernel_get_rows_q<block_q4_0, 2, dequantize_q4_0>) get_rows_q_t;
+template [[host_name("kernel_get_rows_bf16")]] kernel get_rows_q_t kernel_get_rows_q<half4x4, 1, dequantize_bf16>;
template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_q_t kernel_get_rows_q<block_q4_0, 2, dequantize_q4_0>;
template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_q_t kernel_get_rows_q<block_q4_1, 2, dequantize_q4_1>;
template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_q_t kernel_get_rows_q<block_q5_0, 2, dequantize_q5_0>;
@@ -7732,6 +7854,7 @@ typedef decltype(kernel_mul_mm<half, simdgroup_half8x8, DD<float4x4, 1, dequanti
template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<float4x4, 1, dequantize_f32>>;
template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<half4x4, 1, dequantize_f16>>;
+template [[host_name("kernel_mul_mm_bf16_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<half4x4, 1, dequantize_bf16>>;
template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_q4_0, 2, dequantize_q4_0>>;
template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_q4_1, 2, dequantize_q4_1>>;
template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_q5_0, 2, dequantize_q5_0>>;
@@ -7769,6 +7892,7 @@ typedef decltype(kernel_mul_mm_id<DD<float4x4, 1, dequantize_f32>>) mat_mm_id_t;
template [[host_name("kernel_mul_mm_id_f32_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DD<float4x4, 1, dequantize_f32>>;
template [[host_name("kernel_mul_mm_id_f16_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DD<half4x4, 1, dequantize_f16>>;
+template [[host_name("kernel_mul_mm_id_bf16_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DD<half4x4, 1, dequantize_bf16>>;
template [[host_name("kernel_mul_mm_id_q4_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DD<block_q4_0, 2, dequantize_q4_0>>;
template [[host_name("kernel_mul_mm_id_q4_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DD<block_q4_1, 2, dequantize_q4_1>>;
template [[host_name("kernel_mul_mm_id_q5_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DD<block_q5_0, 2, dequantize_q5_0>>;
@@ -7987,6 +8111,7 @@ typedef decltype(kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<float, float4, float
template [[host_name("kernel_mul_mv_id_f32_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<float, float4, float, float4>>>;
template [[host_name("kernel_mul_mv_id_f16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<half, half4, float, float4>>>;
+template [[host_name("kernel_mul_mv_id_bf16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_bf16_impl<float>>>;
template [[host_name("kernel_mul_mv_id_q8_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q8_0_f32_impl>>;
template [[host_name("kernel_mul_mv_id_q4_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
template [[host_name("kernel_mul_mv_id_q4_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;