summaryrefslogtreecommitdiff
path: root/ggml-metal.m
diff options
context:
space:
mode:
Diffstat (limited to 'ggml-metal.m')
-rw-r--r--ggml-metal.m47
1 files changed, 44 insertions, 3 deletions
diff --git a/ggml-metal.m b/ggml-metal.m
index 87fa1721..c908106b 100644
--- a/ggml-metal.m
+++ b/ggml-metal.m
@@ -73,6 +73,8 @@ struct ggml_metal_context {
GGML_METAL_DECL_KERNEL(get_rows_f16);
GGML_METAL_DECL_KERNEL(get_rows_q4_0);
GGML_METAL_DECL_KERNEL(get_rows_q4_1);
+ GGML_METAL_DECL_KERNEL(get_rows_q5_0);
+ GGML_METAL_DECL_KERNEL(get_rows_q5_1);
GGML_METAL_DECL_KERNEL(get_rows_q8_0);
GGML_METAL_DECL_KERNEL(get_rows_q2_K);
GGML_METAL_DECL_KERNEL(get_rows_q3_K);
@@ -87,6 +89,8 @@ struct ggml_metal_context {
GGML_METAL_DECL_KERNEL(mul_mv_f16_f32_l4);
GGML_METAL_DECL_KERNEL(mul_mv_q4_0_f32);
GGML_METAL_DECL_KERNEL(mul_mv_q4_1_f32);
+ GGML_METAL_DECL_KERNEL(mul_mv_q5_0_f32);
+ GGML_METAL_DECL_KERNEL(mul_mv_q5_1_f32);
GGML_METAL_DECL_KERNEL(mul_mv_q8_0_f32);
GGML_METAL_DECL_KERNEL(mul_mv_q2_K_f32);
GGML_METAL_DECL_KERNEL(mul_mv_q3_K_f32);
@@ -97,6 +101,8 @@ struct ggml_metal_context {
GGML_METAL_DECL_KERNEL(mul_mm_f16_f32);
GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32);
GGML_METAL_DECL_KERNEL(mul_mm_q4_1_f32);
+ GGML_METAL_DECL_KERNEL(mul_mm_q5_0_f32);
+ GGML_METAL_DECL_KERNEL(mul_mm_q5_1_f32);
GGML_METAL_DECL_KERNEL(mul_mm_q8_0_f32);
GGML_METAL_DECL_KERNEL(mul_mm_q2_K_f32);
GGML_METAL_DECL_KERNEL(mul_mm_q3_K_f32);
@@ -254,6 +260,8 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
GGML_METAL_ADD_KERNEL(get_rows_f16);
GGML_METAL_ADD_KERNEL(get_rows_q4_0);
GGML_METAL_ADD_KERNEL(get_rows_q4_1);
+ GGML_METAL_ADD_KERNEL(get_rows_q5_0);
+ GGML_METAL_ADD_KERNEL(get_rows_q5_1);
GGML_METAL_ADD_KERNEL(get_rows_q8_0);
GGML_METAL_ADD_KERNEL(get_rows_q2_K);
GGML_METAL_ADD_KERNEL(get_rows_q3_K);
@@ -268,6 +276,8 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
GGML_METAL_ADD_KERNEL(mul_mv_f16_f32_l4);
GGML_METAL_ADD_KERNEL(mul_mv_q4_0_f32);
GGML_METAL_ADD_KERNEL(mul_mv_q4_1_f32);
+ GGML_METAL_ADD_KERNEL(mul_mv_q5_0_f32);
+ GGML_METAL_ADD_KERNEL(mul_mv_q5_1_f32);
GGML_METAL_ADD_KERNEL(mul_mv_q8_0_f32);
GGML_METAL_ADD_KERNEL(mul_mv_q2_K_f32);
GGML_METAL_ADD_KERNEL(mul_mv_q3_K_f32);
@@ -278,8 +288,10 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
GGML_METAL_ADD_KERNEL(mul_mm_f32_f32);
GGML_METAL_ADD_KERNEL(mul_mm_f16_f32);
GGML_METAL_ADD_KERNEL(mul_mm_q4_0_f32);
- GGML_METAL_ADD_KERNEL(mul_mm_q8_0_f32);
GGML_METAL_ADD_KERNEL(mul_mm_q4_1_f32);
+ GGML_METAL_ADD_KERNEL(mul_mm_q5_0_f32);
+ GGML_METAL_ADD_KERNEL(mul_mm_q5_1_f32);
+ GGML_METAL_ADD_KERNEL(mul_mm_q8_0_f32);
GGML_METAL_ADD_KERNEL(mul_mm_q2_K_f32);
GGML_METAL_ADD_KERNEL(mul_mm_q3_K_f32);
GGML_METAL_ADD_KERNEL(mul_mm_q4_K_f32);
@@ -346,6 +358,8 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
GGML_METAL_DEL_KERNEL(get_rows_f16);
GGML_METAL_DEL_KERNEL(get_rows_q4_0);
GGML_METAL_DEL_KERNEL(get_rows_q4_1);
+ GGML_METAL_DEL_KERNEL(get_rows_q5_0);
+ GGML_METAL_DEL_KERNEL(get_rows_q5_1);
GGML_METAL_DEL_KERNEL(get_rows_q8_0);
GGML_METAL_DEL_KERNEL(get_rows_q2_K);
GGML_METAL_DEL_KERNEL(get_rows_q3_K);
@@ -360,6 +374,8 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
GGML_METAL_DEL_KERNEL(mul_mv_f16_f32_l4);
GGML_METAL_DEL_KERNEL(mul_mv_q4_0_f32);
GGML_METAL_DEL_KERNEL(mul_mv_q4_1_f32);
+ GGML_METAL_DEL_KERNEL(mul_mv_q5_0_f32);
+ GGML_METAL_DEL_KERNEL(mul_mv_q5_1_f32);
GGML_METAL_DEL_KERNEL(mul_mv_q8_0_f32);
GGML_METAL_DEL_KERNEL(mul_mv_q2_K_f32);
GGML_METAL_DEL_KERNEL(mul_mv_q3_K_f32);
@@ -370,8 +386,10 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
GGML_METAL_DEL_KERNEL(mul_mm_f32_f32);
GGML_METAL_DEL_KERNEL(mul_mm_f16_f32);
GGML_METAL_DEL_KERNEL(mul_mm_q4_0_f32);
- GGML_METAL_DEL_KERNEL(mul_mm_q8_0_f32);
GGML_METAL_DEL_KERNEL(mul_mm_q4_1_f32);
+ GGML_METAL_DEL_KERNEL(mul_mm_q5_0_f32);
+ GGML_METAL_DEL_KERNEL(mul_mm_q5_1_f32);
+ GGML_METAL_DEL_KERNEL(mul_mm_q8_0_f32);
GGML_METAL_DEL_KERNEL(mul_mm_q2_K_f32);
GGML_METAL_DEL_KERNEL(mul_mm_q3_K_f32);
GGML_METAL_DEL_KERNEL(mul_mm_q4_K_f32);
@@ -1052,6 +1070,8 @@ void ggml_metal_graph_compute(
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f16_f32]; break;
case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_0_f32]; break;
case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_1_f32]; break;
+ case GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q5_0_f32]; break;
+ case GGML_TYPE_Q5_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q5_1_f32]; break;
case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q8_0_f32]; break;
case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q2_K_f32]; break;
case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q3_K_f32]; break;
@@ -1121,6 +1141,24 @@ void ggml_metal_graph_compute(
nth1 = 8;
[encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_1_f32];
} break;
+ case GGML_TYPE_Q5_0:
+ {
+ GGML_ASSERT(ne02 == 1);
+ GGML_ASSERT(ne12 == 1);
+
+ nth0 = 8;
+ nth1 = 8;
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_0_f32];
+ } break;
+ case GGML_TYPE_Q5_1:
+ {
+ GGML_ASSERT(ne02 == 1);
+ GGML_ASSERT(ne12 == 1);
+
+ nth0 = 8;
+ nth1 = 8;
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_1_f32];
+ } break;
case GGML_TYPE_Q8_0:
{
GGML_ASSERT(ne02 == 1);
@@ -1201,7 +1239,8 @@ void ggml_metal_graph_compute(
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:16];
[encoder setBytes:&gqa length:sizeof(gqa) atIndex:17];
- if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q8_0 ||
+ if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
+ src0t == GGML_TYPE_Q5_0 || src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 ||
src0t == GGML_TYPE_Q2_K) { // || src0t == GGML_TYPE_Q4_K) {
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
}
@@ -1233,6 +1272,8 @@ void ggml_metal_graph_compute(
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break;
case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_1]; break;
+ case GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q5_0]; break;
+ case GGML_TYPE_Q5_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q5_1]; break;
case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q8_0]; break;
case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q2_K]; break;
case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q3_K]; break;