summaryrefslogtreecommitdiff
path: root/ggml-metal.m
diff options
context:
space:
mode:
Diffstat (limited to 'ggml-metal.m')
-rw-r--r--ggml-metal.m42
1 files changed, 36 insertions, 6 deletions
diff --git a/ggml-metal.m b/ggml-metal.m
index 9698e5a7..6e559443 100644
--- a/ggml-metal.m
+++ b/ggml-metal.m
@@ -89,6 +89,7 @@ struct ggml_metal_context {
GGML_METAL_DECL_KERNEL(get_rows_q6_K);
GGML_METAL_DECL_KERNEL(get_rows_i32);
GGML_METAL_DECL_KERNEL(get_rows_iq2_xxs);
+ GGML_METAL_DECL_KERNEL(get_rows_iq2_xs);
GGML_METAL_DECL_KERNEL(rms_norm);
GGML_METAL_DECL_KERNEL(group_norm);
GGML_METAL_DECL_KERNEL(norm);
@@ -108,6 +109,7 @@ struct ggml_metal_context {
GGML_METAL_DECL_KERNEL(mul_mv_q5_K_f32);
GGML_METAL_DECL_KERNEL(mul_mv_q6_K_f32);
GGML_METAL_DECL_KERNEL(mul_mv_iq2_xxs_f32);
+ GGML_METAL_DECL_KERNEL(mul_mv_iq2_xs_f32);
GGML_METAL_DECL_KERNEL(mul_mv_id_f32_f32);
//GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f16);
GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f32);
@@ -124,6 +126,7 @@ struct ggml_metal_context {
GGML_METAL_DECL_KERNEL(mul_mv_id_q5_K_f32);
GGML_METAL_DECL_KERNEL(mul_mv_id_q6_K_f32);
GGML_METAL_DECL_KERNEL(mul_mv_id_iq2_xxs_f32);
+ GGML_METAL_DECL_KERNEL(mul_mv_id_iq2_xs_f32);
GGML_METAL_DECL_KERNEL(mul_mm_f32_f32);
GGML_METAL_DECL_KERNEL(mul_mm_f16_f32);
GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32);
@@ -137,6 +140,7 @@ struct ggml_metal_context {
GGML_METAL_DECL_KERNEL(mul_mm_q5_K_f32);
GGML_METAL_DECL_KERNEL(mul_mm_q6_K_f32);
GGML_METAL_DECL_KERNEL(mul_mm_iq2_xxs_f32);
+ GGML_METAL_DECL_KERNEL(mul_mm_iq2_xs_f32);
GGML_METAL_DECL_KERNEL(mul_mm_id_f32_f32);
GGML_METAL_DECL_KERNEL(mul_mm_id_f16_f32);
GGML_METAL_DECL_KERNEL(mul_mm_id_q4_0_f32);
@@ -150,6 +154,7 @@ struct ggml_metal_context {
GGML_METAL_DECL_KERNEL(mul_mm_id_q5_K_f32);
GGML_METAL_DECL_KERNEL(mul_mm_id_q6_K_f32);
GGML_METAL_DECL_KERNEL(mul_mm_id_iq2_xxs_f32);
+ GGML_METAL_DECL_KERNEL(mul_mm_id_iq2_xs_f32);
GGML_METAL_DECL_KERNEL(rope_f32);
GGML_METAL_DECL_KERNEL(rope_f16);
GGML_METAL_DECL_KERNEL(alibi_f32);
@@ -385,6 +390,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
GGML_METAL_ADD_KERNEL(get_rows_q6_K);
GGML_METAL_ADD_KERNEL(get_rows_i32);
GGML_METAL_ADD_KERNEL(get_rows_iq2_xxs);
+ GGML_METAL_ADD_KERNEL(get_rows_iq2_xs);
GGML_METAL_ADD_KERNEL(rms_norm);
GGML_METAL_ADD_KERNEL(group_norm);
GGML_METAL_ADD_KERNEL(norm);
@@ -404,6 +410,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
GGML_METAL_ADD_KERNEL(mul_mv_q5_K_f32);
GGML_METAL_ADD_KERNEL(mul_mv_q6_K_f32);
GGML_METAL_ADD_KERNEL(mul_mv_iq2_xxs_f32);
+ GGML_METAL_ADD_KERNEL(mul_mv_iq2_xs_f32);
GGML_METAL_ADD_KERNEL(mul_mv_id_f32_f32);
//GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f16);
GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f32);
@@ -420,6 +427,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
GGML_METAL_ADD_KERNEL(mul_mv_id_q5_K_f32);
GGML_METAL_ADD_KERNEL(mul_mv_id_q6_K_f32);
GGML_METAL_ADD_KERNEL(mul_mv_id_iq2_xxs_f32);
+ GGML_METAL_ADD_KERNEL(mul_mv_id_iq2_xs_f32);
if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) {
GGML_METAL_ADD_KERNEL(mul_mm_f32_f32);
GGML_METAL_ADD_KERNEL(mul_mm_f16_f32);
@@ -434,6 +442,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
GGML_METAL_ADD_KERNEL(mul_mm_q5_K_f32);
GGML_METAL_ADD_KERNEL(mul_mm_q6_K_f32);
GGML_METAL_ADD_KERNEL(mul_mm_iq2_xxs_f32);
+ GGML_METAL_ADD_KERNEL(mul_mm_iq2_xs_f32);
GGML_METAL_ADD_KERNEL(mul_mm_id_f32_f32);
GGML_METAL_ADD_KERNEL(mul_mm_id_f16_f32);
GGML_METAL_ADD_KERNEL(mul_mm_id_q4_0_f32);
@@ -447,6 +456,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
GGML_METAL_ADD_KERNEL(mul_mm_id_q5_K_f32);
GGML_METAL_ADD_KERNEL(mul_mm_id_q6_K_f32);
GGML_METAL_ADD_KERNEL(mul_mm_id_iq2_xxs_f32);
+ GGML_METAL_ADD_KERNEL(mul_mm_id_iq2_xs_f32);
}
GGML_METAL_ADD_KERNEL(rope_f32);
GGML_METAL_ADD_KERNEL(rope_f16);
@@ -513,6 +523,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
GGML_METAL_DEL_KERNEL(get_rows_q6_K);
GGML_METAL_DEL_KERNEL(get_rows_i32);
GGML_METAL_DEL_KERNEL(get_rows_iq2_xxs);
+ GGML_METAL_DEL_KERNEL(get_rows_iq2_xs);
GGML_METAL_DEL_KERNEL(rms_norm);
GGML_METAL_DEL_KERNEL(group_norm);
GGML_METAL_DEL_KERNEL(norm);
@@ -532,6 +543,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
GGML_METAL_DEL_KERNEL(mul_mv_q5_K_f32);
GGML_METAL_DEL_KERNEL(mul_mv_q6_K_f32);
GGML_METAL_DEL_KERNEL(mul_mv_iq2_xxs_f32);
+ GGML_METAL_DEL_KERNEL(mul_mv_iq2_xs_f32);
GGML_METAL_DEL_KERNEL(mul_mv_id_f32_f32);
//GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f16);
GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f32);
@@ -548,6 +560,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
GGML_METAL_DEL_KERNEL(mul_mv_id_q5_K_f32);
GGML_METAL_DEL_KERNEL(mul_mv_id_q6_K_f32);
GGML_METAL_DEL_KERNEL(mul_mv_id_iq2_xxs_f32);
+ GGML_METAL_DEL_KERNEL(mul_mv_id_iq2_xs_f32);
if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) {
GGML_METAL_DEL_KERNEL(mul_mm_f32_f32);
GGML_METAL_DEL_KERNEL(mul_mm_f16_f32);
@@ -562,6 +575,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
GGML_METAL_DEL_KERNEL(mul_mm_q5_K_f32);
GGML_METAL_DEL_KERNEL(mul_mm_q6_K_f32);
GGML_METAL_DEL_KERNEL(mul_mm_iq2_xxs_f32);
+ GGML_METAL_DEL_KERNEL(mul_mm_iq2_xs_f32);
GGML_METAL_DEL_KERNEL(mul_mm_id_f32_f32);
GGML_METAL_DEL_KERNEL(mul_mm_id_f16_f32);
GGML_METAL_DEL_KERNEL(mul_mm_id_q4_0_f32);
@@ -575,6 +589,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
GGML_METAL_DEL_KERNEL(mul_mm_id_q5_K_f32);
GGML_METAL_DEL_KERNEL(mul_mm_id_q6_K_f32);
GGML_METAL_DEL_KERNEL(mul_mm_id_iq2_xxs_f32);
+ GGML_METAL_DEL_KERNEL(mul_mm_id_iq2_xs_f32);
}
GGML_METAL_DEL_KERNEL(rope_f32);
GGML_METAL_DEL_KERNEL(rope_f16);
@@ -1561,6 +1576,7 @@ bool ggml_metal_graph_compute(
case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q5_K_f32]; break;
case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q6_K_f32]; break;
case GGML_TYPE_IQ2_XXS: [encoder setComputePipelineState:ctx->pipeline_mul_mm_iq2_xxs_f32]; break;
+ case GGML_TYPE_IQ2_XS : [encoder setComputePipelineState:ctx->pipeline_mul_mm_iq2_xs_f32]; break;
default: GGML_ASSERT(false && "MUL MAT-MAT not implemented");
}
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
@@ -1679,6 +1695,12 @@ bool ggml_metal_graph_compute(
nth1 = 16;
[encoder setComputePipelineState:ctx->pipeline_mul_mv_iq2_xxs_f32];
} break;
+ case GGML_TYPE_IQ2_XS:
+ {
+ nth0 = 4;
+ nth1 = 16;
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_iq2_xs_f32];
+ } break;
default:
{
GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src0t);
@@ -1712,12 +1734,12 @@ bool ggml_metal_graph_compute(
if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
src0t == GGML_TYPE_Q5_0 || src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 ||
- //src0t == GGML_TYPE_IQ2_XXS ||
src0t == GGML_TYPE_Q2_K) { // || src0t == GGML_TYPE_Q4_K) {
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
}
- else if (src0t == GGML_TYPE_IQ2_XXS) {
- [encoder setThreadgroupMemoryLength:(256*8+128) atIndex:0];
+ else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) {
+ const int mem_size = src0t == GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128;
+ [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
}
else if (src0t == GGML_TYPE_Q4_K) {
@@ -1810,6 +1832,7 @@ bool ggml_metal_graph_compute(
case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q5_K_f32]; break;
case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q6_K_f32]; break;
case GGML_TYPE_IQ2_XXS: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_iq2_xxs_f32]; break;
+ case GGML_TYPE_IQ2_XS : [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_iq2_xs_f32]; break;
default: GGML_ASSERT(false && "MUL_MAT_ID not implemented");
}
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
@@ -1931,6 +1954,12 @@ bool ggml_metal_graph_compute(
nth1 = 16;
[encoder setComputePipelineState:ctx->pipeline_mul_mv_id_iq2_xxs_f32];
} break;
+ case GGML_TYPE_IQ2_XS:
+ {
+ nth0 = 4;
+ nth1 = 16;
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_iq2_xs_f32];
+ } break;
default:
{
GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src2t);
@@ -1980,12 +2009,12 @@ bool ggml_metal_graph_compute(
if (src2t == GGML_TYPE_Q4_0 || src2t == GGML_TYPE_Q4_1 ||
src2t == GGML_TYPE_Q5_0 || src2t == GGML_TYPE_Q5_1 || src2t == GGML_TYPE_Q8_0 ||
- //src2t == GGML_TYPE_IQ2_XXS ||
src2t == GGML_TYPE_Q2_K) { // || src2t == GGML_TYPE_Q4_K) {
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
}
- else if (src2t == GGML_TYPE_IQ2_XXS) {
- [encoder setThreadgroupMemoryLength:(256*8+128) atIndex:0];
+ else if (src2t == GGML_TYPE_IQ2_XXS || src2t == GGML_TYPE_IQ2_XS) {
+ const int mem_size = src2t == GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128;
+ [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
}
else if (src2t == GGML_TYPE_Q4_K) {
@@ -2026,6 +2055,7 @@ bool ggml_metal_graph_compute(
case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q6_K]; break;
case GGML_TYPE_I32: [encoder setComputePipelineState:ctx->pipeline_get_rows_i32]; break;
case GGML_TYPE_IQ2_XXS: [encoder setComputePipelineState:ctx->pipeline_get_rows_iq2_xxs]; break;
+ case GGML_TYPE_IQ2_XS : [encoder setComputePipelineState:ctx->pipeline_get_rows_iq2_xs]; break;
default: GGML_ASSERT(false && "not implemented");
}