diff options
Diffstat (limited to 'ggml-metal.m')
-rw-r--r-- | ggml-metal.m | 202 |
1 files changed, 123 insertions, 79 deletions
diff --git a/ggml-metal.m b/ggml-metal.m index f8fa05dd..57c238dd 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -81,18 +81,18 @@ struct ggml_metal_context { GGML_METAL_DECL_KERNEL(get_rows_q6_K); GGML_METAL_DECL_KERNEL(rms_norm); GGML_METAL_DECL_KERNEL(norm); - GGML_METAL_DECL_KERNEL(mul_mat_f32_f32); - GGML_METAL_DECL_KERNEL(mul_mat_f16_f32); - GGML_METAL_DECL_KERNEL(mul_mat_f16_f32_1row); - GGML_METAL_DECL_KERNEL(mul_mat_f16_f32_l4); - GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32); - GGML_METAL_DECL_KERNEL(mul_mat_q4_1_f32); - GGML_METAL_DECL_KERNEL(mul_mat_q8_0_f32); - GGML_METAL_DECL_KERNEL(mul_mat_q2_K_f32); - GGML_METAL_DECL_KERNEL(mul_mat_q3_K_f32); - GGML_METAL_DECL_KERNEL(mul_mat_q4_K_f32); - GGML_METAL_DECL_KERNEL(mul_mat_q5_K_f32); - GGML_METAL_DECL_KERNEL(mul_mat_q6_K_f32); + GGML_METAL_DECL_KERNEL(mul_mv_f32_f32); + GGML_METAL_DECL_KERNEL(mul_mv_f16_f32); + GGML_METAL_DECL_KERNEL(mul_mv_f16_f32_1row); + GGML_METAL_DECL_KERNEL(mul_mv_f16_f32_l4); + GGML_METAL_DECL_KERNEL(mul_mv_q4_0_f32); + GGML_METAL_DECL_KERNEL(mul_mv_q4_1_f32); + GGML_METAL_DECL_KERNEL(mul_mv_q8_0_f32); + GGML_METAL_DECL_KERNEL(mul_mv_q2_K_f32); + GGML_METAL_DECL_KERNEL(mul_mv_q3_K_f32); + GGML_METAL_DECL_KERNEL(mul_mv_q4_K_f32); + GGML_METAL_DECL_KERNEL(mul_mv_q5_K_f32); + GGML_METAL_DECL_KERNEL(mul_mv_q6_K_f32); GGML_METAL_DECL_KERNEL(mul_mm_f32_f32); GGML_METAL_DECL_KERNEL(mul_mm_f16_f32); GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32); @@ -262,28 +262,30 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) { GGML_METAL_ADD_KERNEL(get_rows_q6_K); GGML_METAL_ADD_KERNEL(rms_norm); GGML_METAL_ADD_KERNEL(norm); - GGML_METAL_ADD_KERNEL(mul_mat_f32_f32); - GGML_METAL_ADD_KERNEL(mul_mat_f16_f32); - GGML_METAL_ADD_KERNEL(mul_mat_f16_f32_1row); - GGML_METAL_ADD_KERNEL(mul_mat_f16_f32_l4); - GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32); - GGML_METAL_ADD_KERNEL(mul_mat_q4_1_f32); - GGML_METAL_ADD_KERNEL(mul_mat_q8_0_f32); - GGML_METAL_ADD_KERNEL(mul_mat_q2_K_f32); - GGML_METAL_ADD_KERNEL(mul_mat_q3_K_f32); - GGML_METAL_ADD_KERNEL(mul_mat_q4_K_f32); - GGML_METAL_ADD_KERNEL(mul_mat_q5_K_f32); - GGML_METAL_ADD_KERNEL(mul_mat_q6_K_f32); - GGML_METAL_ADD_KERNEL(mul_mm_f32_f32); - GGML_METAL_ADD_KERNEL(mul_mm_f16_f32); - GGML_METAL_ADD_KERNEL(mul_mm_q4_0_f32); - GGML_METAL_ADD_KERNEL(mul_mm_q8_0_f32); - GGML_METAL_ADD_KERNEL(mul_mm_q4_1_f32); - GGML_METAL_ADD_KERNEL(mul_mm_q2_K_f32); - GGML_METAL_ADD_KERNEL(mul_mm_q3_K_f32); - GGML_METAL_ADD_KERNEL(mul_mm_q4_K_f32); - GGML_METAL_ADD_KERNEL(mul_mm_q5_K_f32); - GGML_METAL_ADD_KERNEL(mul_mm_q6_K_f32); + GGML_METAL_ADD_KERNEL(mul_mv_f32_f32); + GGML_METAL_ADD_KERNEL(mul_mv_f16_f32); + GGML_METAL_ADD_KERNEL(mul_mv_f16_f32_1row); + GGML_METAL_ADD_KERNEL(mul_mv_f16_f32_l4); + GGML_METAL_ADD_KERNEL(mul_mv_q4_0_f32); + GGML_METAL_ADD_KERNEL(mul_mv_q4_1_f32); + GGML_METAL_ADD_KERNEL(mul_mv_q8_0_f32); + GGML_METAL_ADD_KERNEL(mul_mv_q2_K_f32); + GGML_METAL_ADD_KERNEL(mul_mv_q3_K_f32); + GGML_METAL_ADD_KERNEL(mul_mv_q4_K_f32); + GGML_METAL_ADD_KERNEL(mul_mv_q5_K_f32); + GGML_METAL_ADD_KERNEL(mul_mv_q6_K_f32); + if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) { + GGML_METAL_ADD_KERNEL(mul_mm_f32_f32); + GGML_METAL_ADD_KERNEL(mul_mm_f16_f32); + GGML_METAL_ADD_KERNEL(mul_mm_q4_0_f32); + GGML_METAL_ADD_KERNEL(mul_mm_q8_0_f32); + GGML_METAL_ADD_KERNEL(mul_mm_q4_1_f32); + GGML_METAL_ADD_KERNEL(mul_mm_q2_K_f32); + GGML_METAL_ADD_KERNEL(mul_mm_q3_K_f32); + GGML_METAL_ADD_KERNEL(mul_mm_q4_K_f32); + GGML_METAL_ADD_KERNEL(mul_mm_q5_K_f32); + GGML_METAL_ADD_KERNEL(mul_mm_q6_K_f32); + } GGML_METAL_ADD_KERNEL(rope_f32); GGML_METAL_ADD_KERNEL(rope_f16); GGML_METAL_ADD_KERNEL(alibi_f32); @@ -296,8 +298,22 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) { #undef GGML_METAL_ADD_KERNEL } - GGML_METAL_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false"); #if TARGET_OS_OSX + // print MTL GPU family: + GGML_METAL_LOG_INFO("%s: GPU name: %s\n", __func__, [[ctx->device name] UTF8String]); + GGML_METAL_LOG_INFO("%s: GPU arch: %s\n", __func__, [[ctx->device architecture].name UTF8String]); + + // determine max supported GPU family + // https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf + // https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf + for (int i = MTLGPUFamilyApple9 + 10; i >= MTLGPUFamilyApple1; --i) { + if ([ctx->device supportsFamily:i]) { + GGML_METAL_LOG_INFO("%s: GPU family: MTLGPUFamilyApple%d (%d)\n", __func__, i - MTLGPUFamilyApple1 + 1, i); + break; + } + } + + GGML_METAL_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false"); GGML_METAL_LOG_INFO("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0); if (ctx->device.maxTransferRate != 0) { GGML_METAL_LOG_INFO("%s: maxTransferRate = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1024.0 / 1024.0); @@ -339,28 +355,30 @@ void ggml_metal_free(struct ggml_metal_context * ctx) { GGML_METAL_DEL_KERNEL(get_rows_q6_K); GGML_METAL_DEL_KERNEL(rms_norm); GGML_METAL_DEL_KERNEL(norm); - GGML_METAL_DEL_KERNEL(mul_mat_f32_f32); - GGML_METAL_DEL_KERNEL(mul_mat_f16_f32); - GGML_METAL_DEL_KERNEL(mul_mat_f16_f32_1row); - GGML_METAL_DEL_KERNEL(mul_mat_f16_f32_l4); - GGML_METAL_DEL_KERNEL(mul_mat_q4_0_f32); - GGML_METAL_DEL_KERNEL(mul_mat_q4_1_f32); - GGML_METAL_DEL_KERNEL(mul_mat_q8_0_f32); - GGML_METAL_DEL_KERNEL(mul_mat_q2_K_f32); - GGML_METAL_DEL_KERNEL(mul_mat_q3_K_f32); - GGML_METAL_DEL_KERNEL(mul_mat_q4_K_f32); - GGML_METAL_DEL_KERNEL(mul_mat_q5_K_f32); - GGML_METAL_DEL_KERNEL(mul_mat_q6_K_f32); - GGML_METAL_DEL_KERNEL(mul_mm_f32_f32); - GGML_METAL_DEL_KERNEL(mul_mm_f16_f32); - GGML_METAL_DEL_KERNEL(mul_mm_q4_0_f32); - GGML_METAL_DEL_KERNEL(mul_mm_q8_0_f32); - GGML_METAL_DEL_KERNEL(mul_mm_q4_1_f32); - GGML_METAL_DEL_KERNEL(mul_mm_q2_K_f32); - GGML_METAL_DEL_KERNEL(mul_mm_q3_K_f32); - GGML_METAL_DEL_KERNEL(mul_mm_q4_K_f32); - GGML_METAL_DEL_KERNEL(mul_mm_q5_K_f32); - GGML_METAL_DEL_KERNEL(mul_mm_q6_K_f32); + GGML_METAL_DEL_KERNEL(mul_mv_f32_f32); + GGML_METAL_DEL_KERNEL(mul_mv_f16_f32); + GGML_METAL_DEL_KERNEL(mul_mv_f16_f32_1row); + GGML_METAL_DEL_KERNEL(mul_mv_f16_f32_l4); + GGML_METAL_DEL_KERNEL(mul_mv_q4_0_f32); + GGML_METAL_DEL_KERNEL(mul_mv_q4_1_f32); + GGML_METAL_DEL_KERNEL(mul_mv_q8_0_f32); + GGML_METAL_DEL_KERNEL(mul_mv_q2_K_f32); + GGML_METAL_DEL_KERNEL(mul_mv_q3_K_f32); + GGML_METAL_DEL_KERNEL(mul_mv_q4_K_f32); + GGML_METAL_DEL_KERNEL(mul_mv_q5_K_f32); + GGML_METAL_DEL_KERNEL(mul_mv_q6_K_f32); + if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) { + GGML_METAL_DEL_KERNEL(mul_mm_f32_f32); + GGML_METAL_DEL_KERNEL(mul_mm_f16_f32); + GGML_METAL_DEL_KERNEL(mul_mm_q4_0_f32); + GGML_METAL_DEL_KERNEL(mul_mm_q8_0_f32); + GGML_METAL_DEL_KERNEL(mul_mm_q4_1_f32); + GGML_METAL_DEL_KERNEL(mul_mm_q2_K_f32); + GGML_METAL_DEL_KERNEL(mul_mm_q3_K_f32); + GGML_METAL_DEL_KERNEL(mul_mm_q4_K_f32); + GGML_METAL_DEL_KERNEL(mul_mm_q5_K_f32); + GGML_METAL_DEL_KERNEL(mul_mm_q6_K_f32); + } GGML_METAL_DEL_KERNEL(rope_f32); GGML_METAL_DEL_KERNEL(rope_f16); GGML_METAL_DEL_KERNEL(alibi_f32); @@ -986,21 +1004,46 @@ void ggml_metal_graph_compute( } break; case GGML_OP_MUL_MAT: { - // TODO: needs to be updated after PR: https://github.com/ggerganov/ggml/pull/224 - GGML_ASSERT(ne00 == ne10); - // GGML_ASSERT(ne02 == ne12); // Should be checked on individual data types until broadcast is implemented everywhere - uint gqa = ne12/ne02; GGML_ASSERT(ne03 == ne13); + const uint gqa = ne12/ne02; + + // find the break-even point where the matrix-matrix kernel becomes more efficient compared + // to the matrix-vector kernel + int ne11_mm_min = 1; + +#if 0 + // the numbers below are measured on M2 Ultra for 7B and 13B models + // these numbers do not translate to other devices or model sizes + // TODO: need to find a better approach + if ([ctx->device.name isEqualToString:@"Apple M2 Ultra"]) { + switch (src0t) { + case GGML_TYPE_F16: ne11_mm_min = 2; break; + case GGML_TYPE_Q8_0: ne11_mm_min = 7; break; + case GGML_TYPE_Q2_K: ne11_mm_min = 15; break; + case GGML_TYPE_Q3_K: ne11_mm_min = 7; break; + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: ne11_mm_min = 15; break; + case GGML_TYPE_Q4_K: ne11_mm_min = 11; break; + case GGML_TYPE_Q5_0: // not tested yet + case GGML_TYPE_Q5_1: ne11_mm_min = 13; break; // not tested yet + case GGML_TYPE_Q5_K: ne11_mm_min = 7; break; + case GGML_TYPE_Q6_K: ne11_mm_min = 7; break; + default: ne11_mm_min = 1; break; + } + } +#endif + // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel - if (!ggml_is_transposed(src0) && + if ([ctx->device supportsFamily:MTLGPUFamilyApple7] && + !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1t == GGML_TYPE_F32 && - [ctx->device supportsFamily:MTLGPUFamilyApple7] && - ne00%32 == 0 && - ne11 > 2) { + ne00 % 32 == 0 && + ne11 > ne11_mm_min) { + //printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12); switch (src0->type) { case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f32_f32]; break; case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f16_f32]; break; @@ -1029,17 +1072,18 @@ void ggml_metal_graph_compute( [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:12]; [encoder setBytes:&gqa length:sizeof(gqa) atIndex:13]; [encoder setThreadgroupMemoryLength:8192 atIndex:0]; - [encoder dispatchThreadgroups:MTLSizeMake( (ne11+31)/32, (ne01+63) / 64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; } else { int nth0 = 32; int nth1 = 1; int nrows = 1; + //printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12); // use custom matrix x vector kernel switch (src0t) { case GGML_TYPE_F32: { - [encoder setComputePipelineState:ctx->pipeline_mul_mat_f32_f32]; + [encoder setComputePipelineState:ctx->pipeline_mul_mv_f32_f32]; nrows = 4; } break; case GGML_TYPE_F16: @@ -1047,12 +1091,12 @@ void ggml_metal_graph_compute( nth0 = 32; nth1 = 1; if (ne11 * ne12 < 4) { - [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32_1row]; + [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_1row]; } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) { - [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32_l4]; + [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_l4]; nrows = ne11; } else { - [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32]; + [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32]; nrows = 4; } } break; @@ -1063,7 +1107,7 @@ void ggml_metal_graph_compute( nth0 = 8; nth1 = 8; - [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_0_f32]; + [encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_0_f32]; } break; case GGML_TYPE_Q4_1: { @@ -1072,7 +1116,7 @@ void ggml_metal_graph_compute( nth0 = 8; nth1 = 8; - [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_1_f32]; + [encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_1_f32]; } break; case GGML_TYPE_Q8_0: { @@ -1081,7 +1125,7 @@ void ggml_metal_graph_compute( nth0 = 8; nth1 = 8; - [encoder setComputePipelineState:ctx->pipeline_mul_mat_q8_0_f32]; + [encoder setComputePipelineState:ctx->pipeline_mul_mv_q8_0_f32]; } break; case GGML_TYPE_Q2_K: { @@ -1090,7 +1134,7 @@ void ggml_metal_graph_compute( nth0 = 2; nth1 = 32; - [encoder setComputePipelineState:ctx->pipeline_mul_mat_q2_K_f32]; + [encoder setComputePipelineState:ctx->pipeline_mul_mv_q2_K_f32]; } break; case GGML_TYPE_Q3_K: { @@ -1099,7 +1143,7 @@ void ggml_metal_graph_compute( nth0 = 2; nth1 = 32; - [encoder setComputePipelineState:ctx->pipeline_mul_mat_q3_K_f32]; + [encoder setComputePipelineState:ctx->pipeline_mul_mv_q3_K_f32]; } break; case GGML_TYPE_Q4_K: { @@ -1108,7 +1152,7 @@ void ggml_metal_graph_compute( nth0 = 4; //1; nth1 = 8; //32; - [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_K_f32]; + [encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_K_f32]; } break; case GGML_TYPE_Q5_K: { @@ -1117,7 +1161,7 @@ void ggml_metal_graph_compute( nth0 = 2; nth1 = 32; - [encoder setComputePipelineState:ctx->pipeline_mul_mat_q5_K_f32]; + [encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_K_f32]; } break; case GGML_TYPE_Q6_K: { @@ -1126,7 +1170,7 @@ void ggml_metal_graph_compute( nth0 = 2; nth1 = 32; - [encoder setComputePipelineState:ctx->pipeline_mul_mat_q6_K_f32]; + [encoder setComputePipelineState:ctx->pipeline_mul_mv_q6_K_f32]; } break; default: { @@ -1155,7 +1199,7 @@ void ggml_metal_graph_compute( [encoder setBytes:&gqa length:sizeof(gqa) atIndex:17]; if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q8_0 || - src0t == GGML_TYPE_Q2_K) {// || src0t == GGML_TYPE_Q4_K) { + src0t == GGML_TYPE_Q2_K) { // || src0t == GGML_TYPE_Q4_K) { [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } else if (src0t == GGML_TYPE_Q4_K) { |