Diffstat (limited to 'ggml-metal.m')
-rw-r--r--  ggml-metal.m  202
1 file changed, 123 insertions, 79 deletions
diff --git a/ggml-metal.m b/ggml-metal.m
index f8fa05dd..57c238dd 100644
--- a/ggml-metal.m
+++ b/ggml-metal.m
@@ -81,18 +81,18 @@ struct ggml_metal_context {
GGML_METAL_DECL_KERNEL(get_rows_q6_K);
GGML_METAL_DECL_KERNEL(rms_norm);
GGML_METAL_DECL_KERNEL(norm);
- GGML_METAL_DECL_KERNEL(mul_mat_f32_f32);
- GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);
- GGML_METAL_DECL_KERNEL(mul_mat_f16_f32_1row);
- GGML_METAL_DECL_KERNEL(mul_mat_f16_f32_l4);
- GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32);
- GGML_METAL_DECL_KERNEL(mul_mat_q4_1_f32);
- GGML_METAL_DECL_KERNEL(mul_mat_q8_0_f32);
- GGML_METAL_DECL_KERNEL(mul_mat_q2_K_f32);
- GGML_METAL_DECL_KERNEL(mul_mat_q3_K_f32);
- GGML_METAL_DECL_KERNEL(mul_mat_q4_K_f32);
- GGML_METAL_DECL_KERNEL(mul_mat_q5_K_f32);
- GGML_METAL_DECL_KERNEL(mul_mat_q6_K_f32);
+ GGML_METAL_DECL_KERNEL(mul_mv_f32_f32);
+ GGML_METAL_DECL_KERNEL(mul_mv_f16_f32);
+ GGML_METAL_DECL_KERNEL(mul_mv_f16_f32_1row);
+ GGML_METAL_DECL_KERNEL(mul_mv_f16_f32_l4);
+ GGML_METAL_DECL_KERNEL(mul_mv_q4_0_f32);
+ GGML_METAL_DECL_KERNEL(mul_mv_q4_1_f32);
+ GGML_METAL_DECL_KERNEL(mul_mv_q8_0_f32);
+ GGML_METAL_DECL_KERNEL(mul_mv_q2_K_f32);
+ GGML_METAL_DECL_KERNEL(mul_mv_q3_K_f32);
+ GGML_METAL_DECL_KERNEL(mul_mv_q4_K_f32);
+ GGML_METAL_DECL_KERNEL(mul_mv_q5_K_f32);
+ GGML_METAL_DECL_KERNEL(mul_mv_q6_K_f32);
GGML_METAL_DECL_KERNEL(mul_mm_f32_f32);
GGML_METAL_DECL_KERNEL(mul_mm_f16_f32);
GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32);
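For reference, the GGML_METAL_DECL_KERNEL / GGML_METAL_ADD_KERNEL / GGML_METAL_DEL_KERNEL invocations throughout this diff all follow one macro pattern: declare a Metal function/pipeline pair per kernel name, create it from the compiled library, and release it on teardown. The sketch below is a simplified, illustrative reconstruction of that pattern (the real macros live elsewhere in ggml-metal.m and carry error handling omitted here); it is consistent with the ctx->pipeline_mul_mv_* references that appear later in this diff.

    // illustrative sketch only - simplified from the real macros in ggml-metal.m
    #define GGML_METAL_DECL_KERNEL(name) \
        id<MTLFunction>             function_##name; \
        id<MTLComputePipelineState> pipeline_##name

    #define GGML_METAL_ADD_KERNEL(name) \
        ctx->function_##name = [ctx->library newFunctionWithName:@"kernel_" #name]; \
        ctx->pipeline_##name = [ctx->device newComputePipelineStateWithFunction:ctx->function_##name error:&error]

    #define GGML_METAL_DEL_KERNEL(name) \
        [ctx->function_##name release]; \
        [ctx->pipeline_##name release]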
@@ -262,28 +262,30 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
GGML_METAL_ADD_KERNEL(get_rows_q6_K);
GGML_METAL_ADD_KERNEL(rms_norm);
GGML_METAL_ADD_KERNEL(norm);
- GGML_METAL_ADD_KERNEL(mul_mat_f32_f32);
- GGML_METAL_ADD_KERNEL(mul_mat_f16_f32);
- GGML_METAL_ADD_KERNEL(mul_mat_f16_f32_1row);
- GGML_METAL_ADD_KERNEL(mul_mat_f16_f32_l4);
- GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32);
- GGML_METAL_ADD_KERNEL(mul_mat_q4_1_f32);
- GGML_METAL_ADD_KERNEL(mul_mat_q8_0_f32);
- GGML_METAL_ADD_KERNEL(mul_mat_q2_K_f32);
- GGML_METAL_ADD_KERNEL(mul_mat_q3_K_f32);
- GGML_METAL_ADD_KERNEL(mul_mat_q4_K_f32);
- GGML_METAL_ADD_KERNEL(mul_mat_q5_K_f32);
- GGML_METAL_ADD_KERNEL(mul_mat_q6_K_f32);
- GGML_METAL_ADD_KERNEL(mul_mm_f32_f32);
- GGML_METAL_ADD_KERNEL(mul_mm_f16_f32);
- GGML_METAL_ADD_KERNEL(mul_mm_q4_0_f32);
- GGML_METAL_ADD_KERNEL(mul_mm_q8_0_f32);
- GGML_METAL_ADD_KERNEL(mul_mm_q4_1_f32);
- GGML_METAL_ADD_KERNEL(mul_mm_q2_K_f32);
- GGML_METAL_ADD_KERNEL(mul_mm_q3_K_f32);
- GGML_METAL_ADD_KERNEL(mul_mm_q4_K_f32);
- GGML_METAL_ADD_KERNEL(mul_mm_q5_K_f32);
- GGML_METAL_ADD_KERNEL(mul_mm_q6_K_f32);
+ GGML_METAL_ADD_KERNEL(mul_mv_f32_f32);
+ GGML_METAL_ADD_KERNEL(mul_mv_f16_f32);
+ GGML_METAL_ADD_KERNEL(mul_mv_f16_f32_1row);
+ GGML_METAL_ADD_KERNEL(mul_mv_f16_f32_l4);
+ GGML_METAL_ADD_KERNEL(mul_mv_q4_0_f32);
+ GGML_METAL_ADD_KERNEL(mul_mv_q4_1_f32);
+ GGML_METAL_ADD_KERNEL(mul_mv_q8_0_f32);
+ GGML_METAL_ADD_KERNEL(mul_mv_q2_K_f32);
+ GGML_METAL_ADD_KERNEL(mul_mv_q3_K_f32);
+ GGML_METAL_ADD_KERNEL(mul_mv_q4_K_f32);
+ GGML_METAL_ADD_KERNEL(mul_mv_q5_K_f32);
+ GGML_METAL_ADD_KERNEL(mul_mv_q6_K_f32);
+ if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) {
+ GGML_METAL_ADD_KERNEL(mul_mm_f32_f32);
+ GGML_METAL_ADD_KERNEL(mul_mm_f16_f32);
+ GGML_METAL_ADD_KERNEL(mul_mm_q4_0_f32);
+ GGML_METAL_ADD_KERNEL(mul_mm_q8_0_f32);
+ GGML_METAL_ADD_KERNEL(mul_mm_q4_1_f32);
+ GGML_METAL_ADD_KERNEL(mul_mm_q2_K_f32);
+ GGML_METAL_ADD_KERNEL(mul_mm_q3_K_f32);
+ GGML_METAL_ADD_KERNEL(mul_mm_q4_K_f32);
+ GGML_METAL_ADD_KERNEL(mul_mm_q5_K_f32);
+ GGML_METAL_ADD_KERNEL(mul_mm_q6_K_f32);
+ }
GGML_METAL_ADD_KERNEL(rope_f32);
GGML_METAL_ADD_KERNEL(rope_f16);
GGML_METAL_ADD_KERNEL(alibi_f32);
@@ -296,8 +298,22 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
#undef GGML_METAL_ADD_KERNEL
}
- GGML_METAL_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
#if TARGET_OS_OSX
+ // print MTL GPU family:
+ GGML_METAL_LOG_INFO("%s: GPU name: %s\n", __func__, [[ctx->device name] UTF8String]);
+ GGML_METAL_LOG_INFO("%s: GPU arch: %s\n", __func__, [[ctx->device architecture].name UTF8String]);
+
+ // determine max supported GPU family
+ // https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
+ // https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf
+ for (int i = MTLGPUFamilyApple9 + 10; i >= MTLGPUFamilyApple1; --i) {
+ if ([ctx->device supportsFamily:i]) {
+ GGML_METAL_LOG_INFO("%s: GPU family: MTLGPUFamilyApple%d (%d)\n", __func__, i - MTLGPUFamilyApple1 + 1, i);
+ break;
+ }
+ }
+
+ GGML_METAL_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
GGML_METAL_LOG_INFO("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
if (ctx->device.maxTransferRate != 0) {
GGML_METAL_LOG_INFO("%s: maxTransferRate = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1024.0 / 1024.0);
@@ -339,28 +355,30 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
GGML_METAL_DEL_KERNEL(get_rows_q6_K);
GGML_METAL_DEL_KERNEL(rms_norm);
GGML_METAL_DEL_KERNEL(norm);
- GGML_METAL_DEL_KERNEL(mul_mat_f32_f32);
- GGML_METAL_DEL_KERNEL(mul_mat_f16_f32);
- GGML_METAL_DEL_KERNEL(mul_mat_f16_f32_1row);
- GGML_METAL_DEL_KERNEL(mul_mat_f16_f32_l4);
- GGML_METAL_DEL_KERNEL(mul_mat_q4_0_f32);
- GGML_METAL_DEL_KERNEL(mul_mat_q4_1_f32);
- GGML_METAL_DEL_KERNEL(mul_mat_q8_0_f32);
- GGML_METAL_DEL_KERNEL(mul_mat_q2_K_f32);
- GGML_METAL_DEL_KERNEL(mul_mat_q3_K_f32);
- GGML_METAL_DEL_KERNEL(mul_mat_q4_K_f32);
- GGML_METAL_DEL_KERNEL(mul_mat_q5_K_f32);
- GGML_METAL_DEL_KERNEL(mul_mat_q6_K_f32);
- GGML_METAL_DEL_KERNEL(mul_mm_f32_f32);
- GGML_METAL_DEL_KERNEL(mul_mm_f16_f32);
- GGML_METAL_DEL_KERNEL(mul_mm_q4_0_f32);
- GGML_METAL_DEL_KERNEL(mul_mm_q8_0_f32);
- GGML_METAL_DEL_KERNEL(mul_mm_q4_1_f32);
- GGML_METAL_DEL_KERNEL(mul_mm_q2_K_f32);
- GGML_METAL_DEL_KERNEL(mul_mm_q3_K_f32);
- GGML_METAL_DEL_KERNEL(mul_mm_q4_K_f32);
- GGML_METAL_DEL_KERNEL(mul_mm_q5_K_f32);
- GGML_METAL_DEL_KERNEL(mul_mm_q6_K_f32);
+ GGML_METAL_DEL_KERNEL(mul_mv_f32_f32);
+ GGML_METAL_DEL_KERNEL(mul_mv_f16_f32);
+ GGML_METAL_DEL_KERNEL(mul_mv_f16_f32_1row);
+ GGML_METAL_DEL_KERNEL(mul_mv_f16_f32_l4);
+ GGML_METAL_DEL_KERNEL(mul_mv_q4_0_f32);
+ GGML_METAL_DEL_KERNEL(mul_mv_q4_1_f32);
+ GGML_METAL_DEL_KERNEL(mul_mv_q8_0_f32);
+ GGML_METAL_DEL_KERNEL(mul_mv_q2_K_f32);
+ GGML_METAL_DEL_KERNEL(mul_mv_q3_K_f32);
+ GGML_METAL_DEL_KERNEL(mul_mv_q4_K_f32);
+ GGML_METAL_DEL_KERNEL(mul_mv_q5_K_f32);
+ GGML_METAL_DEL_KERNEL(mul_mv_q6_K_f32);
+ if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) {
+ GGML_METAL_DEL_KERNEL(mul_mm_f32_f32);
+ GGML_METAL_DEL_KERNEL(mul_mm_f16_f32);
+ GGML_METAL_DEL_KERNEL(mul_mm_q4_0_f32);
+ GGML_METAL_DEL_KERNEL(mul_mm_q8_0_f32);
+ GGML_METAL_DEL_KERNEL(mul_mm_q4_1_f32);
+ GGML_METAL_DEL_KERNEL(mul_mm_q2_K_f32);
+ GGML_METAL_DEL_KERNEL(mul_mm_q3_K_f32);
+ GGML_METAL_DEL_KERNEL(mul_mm_q4_K_f32);
+ GGML_METAL_DEL_KERNEL(mul_mm_q5_K_f32);
+ GGML_METAL_DEL_KERNEL(mul_mm_q6_K_f32);
+ }
GGML_METAL_DEL_KERNEL(rope_f32);
GGML_METAL_DEL_KERNEL(rope_f16);
GGML_METAL_DEL_KERNEL(alibi_f32);
@@ -986,21 +1004,46 @@ void ggml_metal_graph_compute(
} break;
case GGML_OP_MUL_MAT:
{
- // TODO: needs to be updated after PR: https://github.com/ggerganov/ggml/pull/224
-
GGML_ASSERT(ne00 == ne10);
- // GGML_ASSERT(ne02 == ne12); // Should be checked on individual data types until broadcast is implemented everywhere
- uint gqa = ne12/ne02;
GGML_ASSERT(ne03 == ne13);
+ const uint gqa = ne12/ne02;
+
+ // find the break-even point where the matrix-matrix kernel becomes more efficient compared
+ // to the matrix-vector kernel
+ int ne11_mm_min = 1;
+
+#if 0
+ // the numbers below are measured on M2 Ultra for 7B and 13B models
+ // these numbers do not translate to other devices or model sizes
+ // TODO: need to find a better approach
+ if ([ctx->device.name isEqualToString:@"Apple M2 Ultra"]) {
+ switch (src0t) {
+ case GGML_TYPE_F16: ne11_mm_min = 2; break;
+ case GGML_TYPE_Q8_0: ne11_mm_min = 7; break;
+ case GGML_TYPE_Q2_K: ne11_mm_min = 15; break;
+ case GGML_TYPE_Q3_K: ne11_mm_min = 7; break;
+ case GGML_TYPE_Q4_0:
+ case GGML_TYPE_Q4_1: ne11_mm_min = 15; break;
+ case GGML_TYPE_Q4_K: ne11_mm_min = 11; break;
+ case GGML_TYPE_Q5_0: // not tested yet
+ case GGML_TYPE_Q5_1: ne11_mm_min = 13; break; // not tested yet
+ case GGML_TYPE_Q5_K: ne11_mm_min = 7; break;
+ case GGML_TYPE_Q6_K: ne11_mm_min = 7; break;
+ default: ne11_mm_min = 1; break;
+ }
+ }
+#endif
+
// for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
// AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
- if (!ggml_is_transposed(src0) &&
+ if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
+ !ggml_is_transposed(src0) &&
!ggml_is_transposed(src1) &&
src1t == GGML_TYPE_F32 &&
- [ctx->device supportsFamily:MTLGPUFamilyApple7] &&
- ne00%32 == 0 &&
- ne11 > 2) {
+ ne00 % 32 == 0 &&
+ ne11 > ne11_mm_min) {
+ //printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
switch (src0->type) {
case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f32_f32]; break;
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f16_f32]; break;
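To summarize the gating introduced in the hunk above: the matrix-matrix (mul_mm) path is taken only on Apple7+ GPUs, with neither operand transposed, src1 in F32, ne00 a multiple of 32, and ne11 above the break-even point ne11_mm_min; every other case falls back to the matrix-vector (mul_mv) kernels. A simplified restatement as a standalone helper (use_mul_mm is a hypothetical name, not part of the commit; assumes Metal/Metal.h, ggml.h and stdbool.h are included):

    // illustrative restatement of the dispatch decision above
    static bool use_mul_mm(id<MTLDevice> device,
                           const struct ggml_tensor * src0,
                           const struct ggml_tensor * src1,
                           int64_t ne00, int64_t ne11, int ne11_mm_min) {
        return [device supportsFamily:MTLGPUFamilyApple7] && // simdgroup-matrix capable GPU (A14/M1 or newer)
               !ggml_is_transposed(src0)                  &&
               !ggml_is_transposed(src1)                  &&
               src1->type == GGML_TYPE_F32                &&
               ne00 % 32 == 0                             && // row size must be a multiple of 32
               ne11 > ne11_mm_min;                           // enough rows in src1 to amortize the mm kernel
    }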
@@ -1029,17 +1072,18 @@ void ggml_metal_graph_compute(
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:12];
[encoder setBytes:&gqa length:sizeof(gqa) atIndex:13];
[encoder setThreadgroupMemoryLength:8192 atIndex:0];
- [encoder dispatchThreadgroups:MTLSizeMake( (ne11+31)/32, (ne01+63) / 64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
+ [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
} else {
int nth0 = 32;
int nth1 = 1;
int nrows = 1;
+ //printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
// use custom matrix x vector kernel
switch (src0t) {
case GGML_TYPE_F32:
{
- [encoder setComputePipelineState:ctx->pipeline_mul_mat_f32_f32];
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_f32_f32];
nrows = 4;
} break;
case GGML_TYPE_F16:
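The mul_mm dispatch in the hunk above launches (ne11 + 31)/32 x (ne01 + 63)/64 x ne12 threadgroups of 128 threads each, with 8192 bytes of threadgroup memory, i.e. each threadgroup covers 64 values of ne01 and 32 values of ne11 in the output. A worked example with hypothetical tensor sizes:

    // worked example of the mul_mm dispatch arithmetic (sizes are hypothetical)
    const int64_t ne01 = 4096; // rows of src0
    const int64_t ne11 = 512;  // rows of src1
    const int64_t ne12 = 1;    // batch dimension

    const int64_t tg_x = (ne11 + 31)/32; // 16 threadgroups along ne11
    const int64_t tg_y = (ne01 + 63)/64; // 64 threadgroups along ne01
    const int64_t tg_z = ne12;           //  1
    // -> 16 * 64 * 1 = 1024 threadgroups, each with 128 threads and 8 KiB of threadgroup memory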
@@ -1047,12 +1091,12 @@ void ggml_metal_graph_compute(
nth0 = 32;
nth1 = 1;
if (ne11 * ne12 < 4) {
- [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32_1row];
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_1row];
} else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
- [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32_l4];
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_l4];
nrows = ne11;
} else {
- [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32];
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32];
nrows = 4;
}
} break;
@@ -1063,7 +1107,7 @@ void ggml_metal_graph_compute(
nth0 = 8;
nth1 = 8;
- [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_0_f32];
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_0_f32];
} break;
case GGML_TYPE_Q4_1:
{
@@ -1072,7 +1116,7 @@ void ggml_metal_graph_compute(
nth0 = 8;
nth1 = 8;
- [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_1_f32];
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_1_f32];
} break;
case GGML_TYPE_Q8_0:
{
@@ -1081,7 +1125,7 @@ void ggml_metal_graph_compute(
nth0 = 8;
nth1 = 8;
- [encoder setComputePipelineState:ctx->pipeline_mul_mat_q8_0_f32];
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_q8_0_f32];
} break;
case GGML_TYPE_Q2_K:
{
@@ -1090,7 +1134,7 @@ void ggml_metal_graph_compute(
nth0 = 2;
nth1 = 32;
- [encoder setComputePipelineState:ctx->pipeline_mul_mat_q2_K_f32];
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_q2_K_f32];
} break;
case GGML_TYPE_Q3_K:
{
@@ -1099,7 +1143,7 @@ void ggml_metal_graph_compute(
nth0 = 2;
nth1 = 32;
- [encoder setComputePipelineState:ctx->pipeline_mul_mat_q3_K_f32];
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_q3_K_f32];
} break;
case GGML_TYPE_Q4_K:
{
@@ -1108,7 +1152,7 @@ void ggml_metal_graph_compute(
nth0 = 4; //1;
nth1 = 8; //32;
- [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_K_f32];
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_K_f32];
} break;
case GGML_TYPE_Q5_K:
{
@@ -1117,7 +1161,7 @@ void ggml_metal_graph_compute(
nth0 = 2;
nth1 = 32;
- [encoder setComputePipelineState:ctx->pipeline_mul_mat_q5_K_f32];
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_K_f32];
} break;
case GGML_TYPE_Q6_K:
{
@@ -1126,7 +1170,7 @@ void ggml_metal_graph_compute(
nth0 = 2;
nth1 = 32;
- [encoder setComputePipelineState:ctx->pipeline_mul_mat_q6_K_f32];
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_q6_K_f32];
} break;
default:
{
@@ -1155,7 +1199,7 @@ void ggml_metal_graph_compute(
[encoder setBytes:&gqa length:sizeof(gqa) atIndex:17];
if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q8_0 ||
- src0t == GGML_TYPE_Q2_K) {// || src0t == GGML_TYPE_Q4_K) {
+ src0t == GGML_TYPE_Q2_K) { // || src0t == GGML_TYPE_Q4_K) {
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
}
else if (src0t == GGML_TYPE_Q4_K) {
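For comparison, the matrix-vector dispatch just above ((ne01 + 7)/8, ne11, ne12 threadgroups with nth0 x nth1 threads) assigns 8 rows of src0 to each threadgroup for the Q4_0/Q4_1/Q8_0/Q2_K cases. A worked example with hypothetical sizes:

    // worked example of the mul_mv dispatch arithmetic for Q4_0 (sizes are hypothetical)
    const int64_t ne01 = 4096; // rows of src0
    const int64_t ne11 = 1;    // rows of src1 (a single vector)
    const int64_t ne12 = 1;    // batch dimension
    const int     nth0 = 8, nth1 = 8;

    const int64_t tg_x = (ne01 + 7)/8; // 512 threadgroups along ne01 (8 rows each)
    const int64_t tg_y = ne11;         //   1
    const int64_t tg_z = ne12;         //   1
    // -> 512 threadgroups, each with nth0 * nth1 = 64 threads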