summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r-- ggml-metal.m 2
-rw-r--r-- ggml-metal.metal 38
2 files changed, 29 insertions, 11 deletions
diff --git a/ggml-metal.m b/ggml-metal.m
index e929c4b0..8c3c64f5 100644
--- a/ggml-metal.m
+++ b/ggml-metal.m
@@ -840,7 +840,7 @@ void ggml_metal_graph_compute(
switch (src0t) {
case GGML_TYPE_F16:
{
- nth0 = 64;
+ nth0 = 32;
nth1 = 1;
[encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32];
} break;
diff --git a/ggml-metal.metal b/ggml-metal.metal
index 82e1a0c7..02db5323 100644
--- a/ggml-metal.metal
+++ b/ggml-metal.metal
@@ -528,24 +528,42 @@ kernel void kernel_mul_mat_f16_f32(
device const half * x = (device const half *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
- sum[tpitg.x] = 0.0f;
+ uint ith = tpitg.x;
+ uint nth = tptg.x;
- for (int i = tpitg.x; i < ne00; i += tptg.x) {
- sum[tpitg.x] += (float) x[i] * (float) y[i];
+ sum[ith] = 0.0f;
+
+ for (int i = ith; i < ne00; i += nth) {
+ sum[ith] += (float) x[i] * (float) y[i];
}
// accumulate the sum from all threads in the threadgroup
threadgroup_barrier(mem_flags::mem_threadgroup);
- for (uint i = tptg.x/2; i > 0; i /= 2) {
- if (tpitg.x < i) {
- sum[tpitg.x] += sum[tpitg.x + i];
- }
- threadgroup_barrier(mem_flags::mem_threadgroup);
+ if (ith%4 == 0) {
+ for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i];
}
-
- if (tpitg.x == 0) {
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+ if (ith%16 == 0) {
+ for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i];
+ }
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+ if (ith == 0) {
+ for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
dst[im*ne1*ne0 + r1*ne0 + r0] = sum[0];
}
+
+ // Original implementation. Left behind commented out for now
+ //threadgroup_barrier(mem_flags::mem_threadgroup);
+ //for (uint i = tptg.x/2; i > 0; i /= 2) {
+ // if (tpitg.x < i) {
+ // sum[tpitg.x] += sum[tpitg.x + i];
+ // }
+ // threadgroup_barrier(mem_flags::mem_threadgroup);
+ //}
+ //
+ //if (tpitg.x == 0) {
+ // dst[im*ne1*ne0 + r1*ne0 + r0] = sum[0];
+ //}
}
kernel void kernel_alibi_f32(