summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author Shouzheng Liu <lshzh.hi@gmail.com> 2023-08-21 06:59:29 -0400
committer GitHub <noreply@github.com> 2023-08-21 13:59:29 +0300
commit dadbed99e65252d79f81101a392d0d6497b86caa (patch)
tree e2d6ec78c820b0afaff858c3840459c28055057c
parent cb1c0727bd59803b439b6a3af121c99e6393ff3d (diff)
metal : fix synchronization in new matrix multiplication kernel (#2686)
-rw-r--r-- ggml-metal.metal 3
1 files changed, 2 insertions, 1 deletions
diff --git a/ggml-metal.metal b/ggml-metal.metal
index 3f312523..88d48f6c 100644
--- a/ggml-metal.metal
+++ b/ggml-metal.metal
@@ -1898,10 +1898,11 @@ kernel void kernel_mul_mm(device const uchar * src0,
threadgroup float *temp_str = ((threadgroup float *)shared_memory) \
+ 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M;
for (int i = 0; i < 8; i++) {
+ threadgroup_barrier(mem_flags::mem_device);
simdgroup_store(c_res[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M);
}
- threadgroup_barrier(mem_flags::mem_threadgroup);
+ threadgroup_barrier(mem_flags::mem_device);
device float *C = dst + BLOCK_SIZE_M * r0 + (BLOCK_SIZE_N * r1) * ne0 + im*ne1*ne0;
if (sgitg==0) {
for (int i = 0; i < n_rows; i++) {