diff options
author | Jhen-Jie Hong <iainst0409@gmail.com> | 2023-10-18 07:21:48 -0500 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-10-18 15:21:48 +0300 |
commit | c67fe68e417f766970fb1feaf2e66458aa24116a (patch) | |
tree | 0146618767b8b9ae811a233814bf3a217784a55b /ggml-metal.m | |
parent | 1117d06607d2d885640ac501f05f0aae5494e2c5 (diff) |
metal : implement q5_0 and q5_1 kernels (#3648)
* metal : implement dequantize_q5_0
* metal : block_q_n_dot_y for block_q5_0 (broken)
* metal : revert unnecessary change
* metal : implement dequantize_q5_1
* metal : block_q_n_dot_y for q5_1 (broken)
* metal : fix block_q_n_dot_y
* minor : spaces / formatting
---------
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Diffstat (limited to 'ggml-metal.m')
-rw-r--r-- | ggml-metal.m | 47 |
1 file changed, 44 insertions, 3 deletions
diff --git a/ggml-metal.m b/ggml-metal.m index 87fa1721..c908106b 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -73,6 +73,8 @@ struct ggml_metal_context { GGML_METAL_DECL_KERNEL(get_rows_f16); GGML_METAL_DECL_KERNEL(get_rows_q4_0); GGML_METAL_DECL_KERNEL(get_rows_q4_1); + GGML_METAL_DECL_KERNEL(get_rows_q5_0); + GGML_METAL_DECL_KERNEL(get_rows_q5_1); GGML_METAL_DECL_KERNEL(get_rows_q8_0); GGML_METAL_DECL_KERNEL(get_rows_q2_K); GGML_METAL_DECL_KERNEL(get_rows_q3_K); @@ -87,6 +89,8 @@ struct ggml_metal_context { GGML_METAL_DECL_KERNEL(mul_mv_f16_f32_l4); GGML_METAL_DECL_KERNEL(mul_mv_q4_0_f32); GGML_METAL_DECL_KERNEL(mul_mv_q4_1_f32); + GGML_METAL_DECL_KERNEL(mul_mv_q5_0_f32); + GGML_METAL_DECL_KERNEL(mul_mv_q5_1_f32); GGML_METAL_DECL_KERNEL(mul_mv_q8_0_f32); GGML_METAL_DECL_KERNEL(mul_mv_q2_K_f32); GGML_METAL_DECL_KERNEL(mul_mv_q3_K_f32); @@ -97,6 +101,8 @@ struct ggml_metal_context { GGML_METAL_DECL_KERNEL(mul_mm_f16_f32); GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32); GGML_METAL_DECL_KERNEL(mul_mm_q4_1_f32); + GGML_METAL_DECL_KERNEL(mul_mm_q5_0_f32); + GGML_METAL_DECL_KERNEL(mul_mm_q5_1_f32); GGML_METAL_DECL_KERNEL(mul_mm_q8_0_f32); GGML_METAL_DECL_KERNEL(mul_mm_q2_K_f32); GGML_METAL_DECL_KERNEL(mul_mm_q3_K_f32); @@ -254,6 +260,8 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) { GGML_METAL_ADD_KERNEL(get_rows_f16); GGML_METAL_ADD_KERNEL(get_rows_q4_0); GGML_METAL_ADD_KERNEL(get_rows_q4_1); + GGML_METAL_ADD_KERNEL(get_rows_q5_0); + GGML_METAL_ADD_KERNEL(get_rows_q5_1); GGML_METAL_ADD_KERNEL(get_rows_q8_0); GGML_METAL_ADD_KERNEL(get_rows_q2_K); GGML_METAL_ADD_KERNEL(get_rows_q3_K); @@ -268,6 +276,8 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) { GGML_METAL_ADD_KERNEL(mul_mv_f16_f32_l4); GGML_METAL_ADD_KERNEL(mul_mv_q4_0_f32); GGML_METAL_ADD_KERNEL(mul_mv_q4_1_f32); + GGML_METAL_ADD_KERNEL(mul_mv_q5_0_f32); + GGML_METAL_ADD_KERNEL(mul_mv_q5_1_f32); GGML_METAL_ADD_KERNEL(mul_mv_q8_0_f32); GGML_METAL_ADD_KERNEL(mul_mv_q2_K_f32); 
GGML_METAL_ADD_KERNEL(mul_mv_q3_K_f32); @@ -278,8 +288,10 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) { GGML_METAL_ADD_KERNEL(mul_mm_f32_f32); GGML_METAL_ADD_KERNEL(mul_mm_f16_f32); GGML_METAL_ADD_KERNEL(mul_mm_q4_0_f32); - GGML_METAL_ADD_KERNEL(mul_mm_q8_0_f32); GGML_METAL_ADD_KERNEL(mul_mm_q4_1_f32); + GGML_METAL_ADD_KERNEL(mul_mm_q5_0_f32); + GGML_METAL_ADD_KERNEL(mul_mm_q5_1_f32); + GGML_METAL_ADD_KERNEL(mul_mm_q8_0_f32); GGML_METAL_ADD_KERNEL(mul_mm_q2_K_f32); GGML_METAL_ADD_KERNEL(mul_mm_q3_K_f32); GGML_METAL_ADD_KERNEL(mul_mm_q4_K_f32); @@ -346,6 +358,8 @@ void ggml_metal_free(struct ggml_metal_context * ctx) { GGML_METAL_DEL_KERNEL(get_rows_f16); GGML_METAL_DEL_KERNEL(get_rows_q4_0); GGML_METAL_DEL_KERNEL(get_rows_q4_1); + GGML_METAL_DEL_KERNEL(get_rows_q5_0); + GGML_METAL_DEL_KERNEL(get_rows_q5_1); GGML_METAL_DEL_KERNEL(get_rows_q8_0); GGML_METAL_DEL_KERNEL(get_rows_q2_K); GGML_METAL_DEL_KERNEL(get_rows_q3_K); @@ -360,6 +374,8 @@ void ggml_metal_free(struct ggml_metal_context * ctx) { GGML_METAL_DEL_KERNEL(mul_mv_f16_f32_l4); GGML_METAL_DEL_KERNEL(mul_mv_q4_0_f32); GGML_METAL_DEL_KERNEL(mul_mv_q4_1_f32); + GGML_METAL_DEL_KERNEL(mul_mv_q5_0_f32); + GGML_METAL_DEL_KERNEL(mul_mv_q5_1_f32); GGML_METAL_DEL_KERNEL(mul_mv_q8_0_f32); GGML_METAL_DEL_KERNEL(mul_mv_q2_K_f32); GGML_METAL_DEL_KERNEL(mul_mv_q3_K_f32); @@ -370,8 +386,10 @@ void ggml_metal_free(struct ggml_metal_context * ctx) { GGML_METAL_DEL_KERNEL(mul_mm_f32_f32); GGML_METAL_DEL_KERNEL(mul_mm_f16_f32); GGML_METAL_DEL_KERNEL(mul_mm_q4_0_f32); - GGML_METAL_DEL_KERNEL(mul_mm_q8_0_f32); GGML_METAL_DEL_KERNEL(mul_mm_q4_1_f32); + GGML_METAL_DEL_KERNEL(mul_mm_q5_0_f32); + GGML_METAL_DEL_KERNEL(mul_mm_q5_1_f32); + GGML_METAL_DEL_KERNEL(mul_mm_q8_0_f32); GGML_METAL_DEL_KERNEL(mul_mm_q2_K_f32); GGML_METAL_DEL_KERNEL(mul_mm_q3_K_f32); GGML_METAL_DEL_KERNEL(mul_mm_q4_K_f32); @@ -1052,6 +1070,8 @@ void ggml_metal_graph_compute( case GGML_TYPE_F16: [encoder 
setComputePipelineState:ctx->pipeline_mul_mm_f16_f32]; break; case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_0_f32]; break; case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_1_f32]; break; + case GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q5_0_f32]; break; + case GGML_TYPE_Q5_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q5_1_f32]; break; case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q8_0_f32]; break; case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q2_K_f32]; break; case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q3_K_f32]; break; @@ -1121,6 +1141,24 @@ void ggml_metal_graph_compute( nth1 = 8; [encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_1_f32]; } break; + case GGML_TYPE_Q5_0: + { + GGML_ASSERT(ne02 == 1); + GGML_ASSERT(ne12 == 1); + + nth0 = 8; + nth1 = 8; + [encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_0_f32]; + } break; + case GGML_TYPE_Q5_1: + { + GGML_ASSERT(ne02 == 1); + GGML_ASSERT(ne12 == 1); + + nth0 = 8; + nth1 = 8; + [encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_1_f32]; + } break; case GGML_TYPE_Q8_0: { GGML_ASSERT(ne02 == 1); @@ -1201,7 +1239,8 @@ void ggml_metal_graph_compute( [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:16]; [encoder setBytes:&gqa length:sizeof(gqa) atIndex:17]; - if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q8_0 || + if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || + src0t == GGML_TYPE_Q5_0 || src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K) { // || src0t == GGML_TYPE_Q4_K) { [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } @@ -1233,6 +1272,8 @@ void ggml_metal_graph_compute( case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break; case 
GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break; case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_1]; break; + case GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q5_0]; break; + case GGML_TYPE_Q5_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q5_1]; break; case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q8_0]; break; case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q2_K]; break; case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q3_K]; break; |