diff options
Diffstat (limited to 'ggml-metal.m')
-rw-r--r-- | ggml-metal.m | 31 |
1 files changed, 20 insertions, 11 deletions
diff --git a/ggml-metal.m b/ggml-metal.m index cd9d0045..7a369b55 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -1657,6 +1657,10 @@ void ggml_metal_graph_compute( } }; + if (ggml_is_quantized(src0t)) { + GGML_ASSERT(ne00 >= nth0*nth1); + } + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; @@ -1715,6 +1719,9 @@ void ggml_metal_graph_compute( // TODO: make this more general GGML_ASSERT(n_as <= 8); + // max size of the src1ids array in the kernel stack + GGML_ASSERT(ne11 <= 512); + struct ggml_tensor * src2 = gf->nodes[i]->src[2]; const int64_t ne20 = src2 ? src2->ne[0] : 0; @@ -1732,9 +1739,6 @@ void ggml_metal_graph_compute( GGML_ASSERT(!ggml_is_transposed(src2)); GGML_ASSERT(!ggml_is_transposed(src1)); - GGML_ASSERT(ne20 % 32 == 0); - // !!!!!!!!! TODO: this assert is probably required but not sure! - //GGML_ASSERT(ne20 >= 64); GGML_ASSERT(src1t == GGML_TYPE_F32); const uint r2 = ne12/ne22; @@ -1742,22 +1746,22 @@ void ggml_metal_graph_compute( // find the break-even point where the matrix-matrix kernel becomes more efficient compared // to the matrix-vector kernel - int ne11_mm_min = 1; + int ne11_mm_min = n_as; const int idx = ((int32_t *) dst->op_params)[0]; // batch size GGML_ASSERT(ne01 == ne11); - const int64_t _ne1 = 1; // kernel_mul_mm_impl needs a reference in constant memory - // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel // !!! // TODO: for now, always use mat-vec kernels until we figure out how to improve the // indirect matrix multiplication // !!! - if ([ctx->device supportsFamily:MTLGPUFamilyApple7] && _ne1 > ne11_mm_min) { + if ([ctx->device supportsFamily:MTLGPUFamilyApple7] && + ne20 % 32 == 0 && ne20 >= 64 && + ne11 > ne11_mm_min) { switch (src2->type) { case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_f32_f32]; break; case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_f16_f32]; break; @@ -1787,7 +1791,7 @@ void ggml_metal_graph_compute( [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:11]; [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:12]; [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13]; - [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:14]; + [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14]; [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15]; [encoder setBytes:&r2 length:sizeof(r2) atIndex:16]; [encoder setBytes:&r3 length:sizeof(r3) atIndex:17]; @@ -1805,8 +1809,7 @@ void ggml_metal_graph_compute( [encoder setThreadgroupMemoryLength:8192 atIndex:0]; - // TODO: processing one row at a time (ne11 -> 1) is not efficient - [encoder dispatchThreadgroups:MTLSizeMake( (_ne1 + 31)/32, (ne21 + 63)/64, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake((ne11 + 31)/32, (ne21 + 63)/64, n_as*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; } else { int nth0 = 32; int nth1 = 1; @@ -1889,11 +1892,17 @@ void ggml_metal_graph_compute( } break; default: { - GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src0t); + GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src2t); GGML_ASSERT(false && "not implemented"); } }; + if (ggml_is_quantized(src2t)) { + GGML_ASSERT(ne20 >= nth0*nth1); + } + + const int64_t _ne1 = 1; // kernels needs a reference in constant memory + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; |