summaryrefslogtreecommitdiff
path: root/ggml-metal.m
diff options
context:
space:
mode:
Diffstat (limited to 'ggml-metal.m')
-rw-r--r--ggml-metal.m35
1 files changed, 35 insertions, 0 deletions
diff --git a/ggml-metal.m b/ggml-metal.m
index 956e323a..0d4aa430 100644
--- a/ggml-metal.m
+++ b/ggml-metal.m
@@ -62,6 +62,7 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS,
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS,
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S,
+ GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL,
GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,
GGML_METAL_KERNEL_TYPE_RMS_NORM,
GGML_METAL_KERNEL_TYPE_GROUP_NORM,
@@ -85,6 +86,7 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32,
//GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16,
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32,
@@ -104,6 +106,7 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32,
@@ -120,6 +123,7 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32,
@@ -136,6 +140,7 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32,
GGML_METAL_KERNEL_TYPE_ROPE_F32,
GGML_METAL_KERNEL_TYPE_ROPE_F16,
GGML_METAL_KERNEL_TYPE_ALIBI_F32,
@@ -448,6 +453,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS, get_rows_iq2_xs, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS, get_rows_iq3_xxs, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S, get_rows_iq1_s, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL, get_rows_iq4_nl, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, ctx->support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, ctx->support_simdgroup_reduction);
@@ -471,6 +477,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32, mul_mv_iq2_xs_f32, ctx->support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32, mul_mv_iq3_xxs_f32, ctx->support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32, mul_mv_iq1_s_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32, mul_mv_iq4_nl_f32, ctx->support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, mul_mv_id_f32_f32, ctx->support_simdgroup_reduction);
//GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16, mul_mv_id_f16_f16, ctx->support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32, mul_mv_id_f16_f32, ctx->support_simdgroup_reduction);
@@ -490,6 +497,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32, mul_mv_id_iq2_xs_f32, ctx->support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32, mul_mv_id_iq3_xxs_f32, ctx->support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32, mul_mv_id_iq1_s_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32, mul_mv_id_iq4_nl_f32, ctx->support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, mul_mm_q4_0_f32, ctx->support_simdgroup_mm);
@@ -506,6 +514,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32, mul_mm_iq2_xs_f32, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32, mul_mm_iq3_xxs_f32, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32, mul_mm_iq1_s_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, mul_mm_iq4_nl_f32, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, mul_mm_id_f16_f32, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, mul_mm_id_q4_0_f32, ctx->support_simdgroup_mm);
@@ -522,6 +531,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32, mul_mm_id_iq2_xs_f32, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32, mul_mm_id_iq3_xxs_f32, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32, mul_mm_id_iq1_s_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, mul_mm_id_iq4_nl_f32, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F32, rope_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F16, rope_f16, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ALIBI_F32, alibi_f32, true);
@@ -1338,6 +1348,7 @@ static bool ggml_metal_graph_compute(
case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32 ].pipeline; break;
case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32].pipeline; break;
case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32 ].pipeline; break;
+ case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32 ].pipeline; break;
default: GGML_ASSERT(false && "MUL MAT-MAT not implemented");
}
@@ -1478,6 +1489,12 @@ static bool ggml_metal_graph_compute(
nth1 = 16;
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32].pipeline;
} break;
+ case GGML_TYPE_IQ4_NL:
+ {
+ nth0 = 4;
+ nth1 = 16;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32].pipeline;
+ } break;
default:
{
GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src0t);
@@ -1525,6 +1542,11 @@ static bool ggml_metal_graph_compute(
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
}
+ else if (src0t == GGML_TYPE_IQ4_NL) {
+ const int mem_size = 32*sizeof(float);
+ [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ }
else if (src0t == GGML_TYPE_Q4_K) {
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
}
@@ -1619,6 +1641,7 @@ static bool ggml_metal_graph_compute(
case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32 ].pipeline; break;
case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32].pipeline; break;
case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32 ].pipeline; break;
+ case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32 ].pipeline; break;
default: GGML_ASSERT(false && "MUL_MAT_ID not implemented");
}
@@ -1762,6 +1785,12 @@ static bool ggml_metal_graph_compute(
nth1 = 16;
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32].pipeline;
} break;
+ case GGML_TYPE_IQ4_NL:
+ {
+ nth0 = 4;
+ nth1 = 16;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32].pipeline;
+ } break;
default:
{
GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src2t);
@@ -1825,6 +1854,11 @@ static bool ggml_metal_graph_compute(
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
}
+ else if (src2t == GGML_TYPE_IQ4_NL) {
+ const int mem_size = 32*sizeof(float);
+ [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
+ [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ }
else if (src2t == GGML_TYPE_Q4_K) {
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
}
@@ -1867,6 +1901,7 @@ static bool ggml_metal_graph_compute(
case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS ].pipeline; break;
case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS].pipeline; break;
case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S ].pipeline; break;
+ case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL ].pipeline; break;
case GGML_TYPE_I32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_I32 ].pipeline; break;
default: GGML_ASSERT(false && "not implemented");
}