diff options
Diffstat (limited to 'ggml-metal.m')
-rw-r--r-- | ggml-metal.m | 41 |
1 files changed, 34 insertions, 7 deletions
diff --git a/ggml-metal.m b/ggml-metal.m index b73f51f2..658c392e 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -52,14 +52,18 @@ struct ggml_metal_context { GGML_METAL_DECL_KERNEL(get_rows_q4_0); GGML_METAL_DECL_KERNEL(get_rows_q4_1); GGML_METAL_DECL_KERNEL(get_rows_q2_k); + GGML_METAL_DECL_KERNEL(get_rows_q3_k); GGML_METAL_DECL_KERNEL(get_rows_q4_k); + GGML_METAL_DECL_KERNEL(get_rows_q5_k); GGML_METAL_DECL_KERNEL(get_rows_q6_k); GGML_METAL_DECL_KERNEL(rms_norm); GGML_METAL_DECL_KERNEL(mul_mat_f16_f32); GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32); GGML_METAL_DECL_KERNEL(mul_mat_q4_1_f32); GGML_METAL_DECL_KERNEL(mul_mat_q2_k_f32); + GGML_METAL_DECL_KERNEL(mul_mat_q3_k_f32); GGML_METAL_DECL_KERNEL(mul_mat_q4_k_f32); + GGML_METAL_DECL_KERNEL(mul_mat_q5_k_f32); GGML_METAL_DECL_KERNEL(mul_mat_q6_k_f32); GGML_METAL_DECL_KERNEL(rope); GGML_METAL_DECL_KERNEL(cpy_f32_f16); @@ -153,14 +157,18 @@ struct ggml_metal_context * ggml_metal_init(void) { GGML_METAL_ADD_KERNEL(get_rows_q4_0); GGML_METAL_ADD_KERNEL(get_rows_q4_1); GGML_METAL_ADD_KERNEL(get_rows_q2_k); + GGML_METAL_ADD_KERNEL(get_rows_q3_k); GGML_METAL_ADD_KERNEL(get_rows_q4_k); + GGML_METAL_ADD_KERNEL(get_rows_q5_k); GGML_METAL_ADD_KERNEL(get_rows_q6_k); GGML_METAL_ADD_KERNEL(rms_norm); GGML_METAL_ADD_KERNEL(mul_mat_f16_f32); GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32); GGML_METAL_ADD_KERNEL(mul_mat_q4_1_f32); GGML_METAL_ADD_KERNEL(mul_mat_q2_k_f32); + GGML_METAL_ADD_KERNEL(mul_mat_q3_k_f32); GGML_METAL_ADD_KERNEL(mul_mat_q4_k_f32); + GGML_METAL_ADD_KERNEL(mul_mat_q5_k_f32); GGML_METAL_ADD_KERNEL(mul_mat_q6_k_f32); GGML_METAL_ADD_KERNEL(rope); GGML_METAL_ADD_KERNEL(cpy_f32_f16); @@ -575,6 +583,15 @@ void ggml_metal_graph_compute( nth1 = 16; [encoder setComputePipelineState:ctx->pipeline_mul_mat_q2_k_f32]; } break; + case GGML_TYPE_Q3_K: + { + GGML_ASSERT(ne02 == 1); + GGML_ASSERT(ne12 == 1); + + nth0 = 4; + nth1 = 16; + [encoder setComputePipelineState:ctx->pipeline_mul_mat_q3_k_f32]; + } break; case GGML_TYPE_Q4_K: { GGML_ASSERT(ne02 == 1); @@ -584,6 +601,15 @@ void ggml_metal_graph_compute( nth1 = 16; [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_k_f32]; } break; + case GGML_TYPE_Q5_K: + { + GGML_ASSERT(ne02 == 1); + GGML_ASSERT(ne12 == 1); + + nth0 = 4; + nth1 = 16; + [encoder setComputePipelineState:ctx->pipeline_mul_mat_q5_k_f32]; + } break; case GGML_TYPE_Q6_K: { GGML_ASSERT(ne02 == 1); @@ -620,15 +646,14 @@ void ggml_metal_graph_compute( if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1) { [encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0]; [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } else if (src0t == GGML_TYPE_Q2_K) { + } + else if (src0t == GGML_TYPE_Q2_K || + src0t == GGML_TYPE_Q3_K || + src0t == GGML_TYPE_Q4_K || + src0t == GGML_TYPE_Q5_K || + src0t == GGML_TYPE_Q6_K) { [encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0]; [encoder dispatchThreadgroups:MTLSizeMake(ne01, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } else if (src0t == GGML_TYPE_Q4_K) { - [encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0]; - [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } else if (src0t == GGML_TYPE_Q6_K) { - [encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0]; - [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } else { [encoder setThreadgroupMemoryLength:nth0*sizeof(float) atIndex:0]; [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; @@ -646,7 +671,9 @@ void ggml_metal_graph_compute( case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break; case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_1]; break; case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q2_k]; break; + case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q3_k]; break; case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_k]; break; + case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q5_k]; break; case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q6_k]; break; default: GGML_ASSERT(false && "not implemented"); } |