diff options
Diffstat (limited to 'ggml-metal.m')
-rw-r--r-- | ggml-metal.m | 132 |
1 files changed, 68 insertions, 64 deletions
diff --git a/ggml-metal.m b/ggml-metal.m index 835c5f29..969cf7da 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -167,7 +167,9 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) { #define GGML_METAL_ADD_KERNEL(name) \ ctx->function_##name = [ctx->library newFunctionWithName:@"kernel_"#name]; \ ctx->pipeline_##name = [ctx->device newComputePipelineStateWithFunction:ctx->function_##name error:&error]; \ - fprintf(stderr, "%s: loaded %-32s %16p\n", __func__, "kernel_"#name, (void *) ctx->pipeline_##name); \ + fprintf(stderr, "%s: loaded %-32s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) ctx->pipeline_##name, \ + (int) ctx->pipeline_##name.maxTotalThreadsPerThreadgroup, \ + (int) ctx->pipeline_##name.threadExecutionWidth); \ if (error) { \ fprintf(stderr, "%s: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \ return NULL; \ @@ -218,12 +220,12 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) { #undef GGML_METAL_ADD_KERNEL } - fprintf(stderr, "%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0); - fprintf(stderr, "%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false"); + fprintf(stderr, "%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0); + fprintf(stderr, "%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false"); if (ctx->device.maxTransferRate != 0) { - fprintf(stderr, "%s: maxTransferRate = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1024.0 / 1024.0); + fprintf(stderr, "%s: maxTransferRate = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1024.0 / 1024.0); } else { - fprintf(stderr, "%s: maxTransferRate = built-in GPU\n", __func__); + fprintf(stderr, "%s: maxTransferRate = built-in GPU\n", __func__); } return ctx; @@ -537,8 +539,8 @@ void ggml_metal_graph_compute( id<MTLComputeCommandEncoder> encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc]; - const int node_start = (cb_idx + 0) * n_nodes_per_cb; - const int node_end = (cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb; + const int node_start = (cb_idx + 0) * n_nodes_per_cb; + const int node_end = MIN((cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb, n_nodes); for (int ind = node_start; ind < node_end; ++ind) { const int i = has_concur ? ctx->concur_list[ind] : ind; @@ -744,32 +746,31 @@ void ggml_metal_graph_compute( [ctx->device supportsFamily:MTLGPUFamilyApple7] && ne00%32 == 0 && ne11 > 1) { - switch (src0->type) { - case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f16_f32]; break; - case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_0_f32]; break; - case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_1_f32]; break; - case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q2_K_f32]; break; - case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q3_K_f32]; break; - case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_K_f32]; break; - case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q5_K_f32]; break; - case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q6_K_f32]; break; - default: GGML_ASSERT(false && "MUL MAT-MAT not implemented"); - } - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4]; - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:5]; - [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:6]; - [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:7]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:8]; - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:9]; - [encoder setBytes:&gqa length:sizeof(gqa) atIndex:10]; - [encoder setThreadgroupMemoryLength:8192 atIndex:0]; - [encoder dispatchThreadgroups:MTLSizeMake( (ne11+31)/32, (ne01+63) / 64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; + switch (src0->type) { + case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f16_f32]; break; + case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_0_f32]; break; + case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_1_f32]; break; + case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q2_K_f32]; break; + case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q3_K_f32]; break; + case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_K_f32]; break; + case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q5_K_f32]; break; + case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q6_K_f32]; break; + default: GGML_ASSERT(false && "MUL MAT-MAT not implemented"); } - else { + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; + [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4]; + [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:5]; + [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:6]; + [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:7]; + [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:8]; + [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:9]; + [encoder setBytes:&gqa length:sizeof(gqa) atIndex:10]; + [encoder setThreadgroupMemoryLength:8192 atIndex:0]; + [encoder dispatchThreadgroups:MTLSizeMake( (ne11+31)/32, (ne01+63) / 64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; + } else { int nth0 = 32; int nth1 = 1; @@ -868,24 +869,24 @@ void ggml_metal_graph_compute( [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:14]; [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:15]; [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:16]; - [encoder setBytes:&gqa length:sizeof(gqa) atIndex:17]; + [encoder setBytes:&gqa length:sizeof(gqa) atIndex:17]; if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_Q4_K) { - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7) / 8, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } else if (src0t == GGML_TYPE_Q3_K) { #ifdef GGML_QKK_64 - [encoder dispatchThreadgroups:MTLSizeMake((ne01+1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; #else - [encoder dispatchThreadgroups:MTLSizeMake((ne01+3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; #endif } else if (src0t == GGML_TYPE_Q5_K) { - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3) / 4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } else if (src0t == GGML_TYPE_Q6_K) { - [encoder dispatchThreadgroups:MTLSizeMake((ne01+1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } else { [encoder setThreadgroupMemoryLength:nth0*sizeof(float) atIndex:0]; [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; @@ -938,16 +939,17 @@ void ggml_metal_graph_compute( } break; case GGML_OP_NORM: { - const float eps = 1e-5f; + float eps; + memcpy(&eps, dst->op_params, sizeof(float)); const int nth = 256; [encoder setComputePipelineState:ctx->pipeline_norm]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; - [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3]; - [encoder setBytes:&eps length:sizeof( float) atIndex:4]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; + [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3]; + [encoder setBytes:&eps length:sizeof( float) atIndex:4]; [encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0]; const int64_t nrows = ggml_nrows(src0); @@ -990,7 +992,9 @@ void ggml_metal_graph_compute( [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16]; [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17]; [encoder setBytes:&m0 length:sizeof( float) atIndex:18]; + const int nth = 32; + [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; } break; case GGML_OP_ROPE: @@ -1005,8 +1009,8 @@ void ggml_metal_graph_compute( memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float)); [encoder setComputePipelineState:ctx->pipeline_rope]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3]; [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4]; @@ -1057,24 +1061,24 @@ void ggml_metal_graph_compute( default: GGML_ASSERT(false && "not implemented"); } - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; - [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3]; - [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4]; - [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5]; - [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6]; - [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7]; - [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8]; - [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9]; - [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10]; - [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11]; - [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12]; - [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13]; - [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14]; - [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15]; - [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16]; - [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; + [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3]; + [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4]; + [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5]; + [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6]; + [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7]; + [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8]; + [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9]; + [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10]; + [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11]; + [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12]; + [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13]; + [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14]; + [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15]; + [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16]; + [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17]; [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; } break; |