summaryrefslogtreecommitdiff
path: root/ggml-metal.m
diff options
context:
space:
mode:
Diffstat (limited to 'ggml-metal.m')
-rw-r--r--ggml-metal.m43
1 files changed, 27 insertions, 16 deletions
diff --git a/ggml-metal.m b/ggml-metal.m
index d52a1c3c..6cfacf64 100644
--- a/ggml-metal.m
+++ b/ggml-metal.m
@@ -1028,20 +1028,27 @@ void ggml_metal_graph_compute(
int nth = 32; // SIMD width
if (ne00%4 == 0) {
+ while (nth < ne00/4 && nth < 256) {
+ nth *= 2;
+ }
[encoder setComputePipelineState:ctx->pipeline_soft_max_4];
} else {
- do {
+ while (nth < ne00 && nth < 1024) {
nth *= 2;
- } while (nth <= ne00 && nth <= 1024);
- nth /= 2;
+ }
[encoder setComputePipelineState:ctx->pipeline_soft_max];
}
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
- [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
- [encoder setThreadgroupMemoryLength:GGML_PAD(nth/32*sizeof(float), 16) atIndex:0];
+
+ const float scale = ((float *) dst->op_params)[0];
+
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
+ [encoder setBytes:&scale length:sizeof(scale) atIndex:6];
+ [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
[encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
} break;
@@ -1351,15 +1358,19 @@ void ggml_metal_graph_compute(
float eps;
memcpy(&eps, dst->op_params, sizeof(float));
- const int nth = MIN(512, ne00);
+ int nth = 32; // SIMD width
+
+ while (nth < ne00/4 && nth < 1024) {
+ nth *= 2;
+ }
[encoder setComputePipelineState:ctx->pipeline_rms_norm];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
- [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
- [encoder setBytes:&eps length:sizeof( float) atIndex:4];
- [encoder setThreadgroupMemoryLength:GGML_PAD(nth/32*sizeof(float), 16) atIndex:0];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
+ [encoder setBytes:&eps length:sizeof( float) atIndex:4];
+ [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
const int64_t nrows = ggml_nrows(src0);