Diffstat (limited to 'ggml-metal.m')
-rw-r--r--  ggml-metal.m | 31 ++++++++++++++++++++-----------
1 file changed, 20 insertions(+), 11 deletions(-)
diff --git a/ggml-metal.m b/ggml-metal.m
index cd9d0045..7a369b55 100644
--- a/ggml-metal.m
+++ b/ggml-metal.m
@@ -1657,6 +1657,10 @@ void ggml_metal_graph_compute(
}
};
+ if (ggml_is_quantized(src0t)) {
+ GGML_ASSERT(ne00 >= nth0*nth1);
+ }
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
@@ -1715,6 +1719,9 @@ void ggml_metal_graph_compute(
// TODO: make this more general
GGML_ASSERT(n_as <= 8);
+ // max size of the src1ids array in the kernel stack
+ GGML_ASSERT(ne11 <= 512);
+
struct ggml_tensor * src2 = gf->nodes[i]->src[2];
const int64_t ne20 = src2 ? src2->ne[0] : 0;
@@ -1732,9 +1739,6 @@ void ggml_metal_graph_compute(
GGML_ASSERT(!ggml_is_transposed(src2));
GGML_ASSERT(!ggml_is_transposed(src1));
- GGML_ASSERT(ne20 % 32 == 0);
- // !!!!!!!!! TODO: this assert is probably required but not sure!
- //GGML_ASSERT(ne20 >= 64);
GGML_ASSERT(src1t == GGML_TYPE_F32);
const uint r2 = ne12/ne22;
@@ -1742,22 +1746,22 @@ void ggml_metal_graph_compute(
// find the break-even point where the matrix-matrix kernel becomes more efficient compared
// to the matrix-vector kernel
- int ne11_mm_min = 1;
+ int ne11_mm_min = n_as;
const int idx = ((int32_t *) dst->op_params)[0];
// batch size
GGML_ASSERT(ne01 == ne11);
- const int64_t _ne1 = 1; // kernel_mul_mm_impl needs a reference in constant memory
-
// for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
// AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
// !!!
// TODO: for now, always use mat-vec kernels until we figure out how to improve the
// indirect matrix multiplication
// !!!
- if ([ctx->device supportsFamily:MTLGPUFamilyApple7] && _ne1 > ne11_mm_min) {
+ if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
+ ne20 % 32 == 0 && ne20 >= 64 &&
+ ne11 > ne11_mm_min) {
switch (src2->type) {
case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_f32_f32]; break;
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_f16_f32]; break;
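
As a standalone illustration of the kernel-selection change in the hunk above: the ne20 % 32 == 0 and ne20 >= 64 checks that used to be asserts now gate the matrix-matrix path at runtime, and the break-even point is raised from 1 to n_as, so shapes that fail either test fall back to the matrix-vector kernel instead of asserting. The C sketch below mirrors that condition; the helper name use_mat_mat_id_kernel and the example shapes are invented for illustration and are not part of the patch.

#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>

/* Hypothetical helper mirroring the condition above: take the matrix-matrix
 * mul_mm_id path only on Apple7+ GPUs, when ne20 is a multiple of 32 and at
 * least 64, and when ne11 exceeds the break-even point (now n_as). */
static bool use_mat_mat_id_kernel(bool supports_apple7, int64_t ne20, int64_t ne11, int n_as) {
    const int ne11_mm_min = n_as;
    return supports_apple7 && ne20 % 32 == 0 && ne20 >= 64 && ne11 > ne11_mm_min;
}

int main(void) {
    printf("%d\n", use_mat_mat_id_kernel(true, 4096, 32, 8)); /* 1: matrix-matrix kernel  */
    printf("%d\n", use_mat_mat_id_kernel(true, 4096,  4, 8)); /* 0: ne11 <= n_as, mat-vec */
    printf("%d\n", use_mat_mat_id_kernel(true,   48, 32, 8)); /* 0: ne20 < 64, mat-vec    */
    return 0;
}
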
@@ -1787,7 +1791,7 @@ void ggml_metal_graph_compute(
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:11];
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:12];
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13];
- [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:14];
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14];
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
[encoder setBytes:&r2 length:sizeof(r2) atIndex:16];
[encoder setBytes:&r3 length:sizeof(r3) atIndex:17];
@@ -1805,8 +1809,7 @@ void ggml_metal_graph_compute(
[encoder setThreadgroupMemoryLength:8192 atIndex:0];
- // TODO: processing one row at a time (ne11 -> 1) is not efficient
- [encoder dispatchThreadgroups:MTLSizeMake( (_ne1 + 31)/32, (ne21 + 63)/64, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
+ [encoder dispatchThreadgroups:MTLSizeMake((ne11 + 31)/32, (ne21 + 63)/64, n_as*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
} else {
int nth0 = 32;
int nth1 = 1;
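
The dispatch change in the hunk above replaces the one-row-at-a-time grid (_ne1 = 1 in the x dimension) with one that tiles all ne11 rows of src1 and assigns a z slice to every (expert, batch) combination via n_as*ne12*ne13. A minimal C sketch of the new grid arithmetic follows; the dimension values are made-up examples, only the formulas come from the diff.

#include <stdint.h>
#include <stdio.h>

int main(void) {
    /* Made-up example sizes for the dimensions used in the dispatch:
     * 512 src1 rows, ne21 = 4096, 8 experts, no extra batch dims. */
    const int64_t ne11 = 512, ne21 = 4096, ne12 = 1, ne13 = 1;
    const int64_t n_as = 8;

    /* New grid: ceil-divide ne11 into 32-row tiles (x), ne21 into 64-wide
     * tiles (y), and give each (expert, batch) pair its own z slice. */
    const int64_t tg_x = (ne11 + 31)/32;      /* 16 */
    const int64_t tg_y = (ne21 + 63)/64;      /* 64 */
    const int64_t tg_z = n_as*ne12*ne13;      /*  8 */

    printf("threadgroups: %lld x %lld x %lld, 128 threads each\n",
           (long long) tg_x, (long long) tg_y, (long long) tg_z);
    return 0;
}
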
@@ -1889,11 +1892,17 @@ void ggml_metal_graph_compute(
} break;
default:
{
- GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src0t);
+ GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src2t);
GGML_ASSERT(false && "not implemented");
}
};
+ if (ggml_is_quantized(src2t)) {
+ GGML_ASSERT(ne20 >= nth0*nth1);
+ }
+
+ const int64_t _ne1 = 1; // kernels need a reference in constant memory
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];