diff options
-rw-r--r-- | ggml-metal.m | 6 | ||||
-rw-r--r-- | ggml-metal.metal | 6 |
2 files changed, 6 insertions, 6 deletions
diff --git a/ggml-metal.m b/ggml-metal.m index 00df2283..3cf80de7 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -1642,8 +1642,8 @@ static enum ggml_status ggml_metal_graph_compute( // TODO: make this more general GGML_ASSERT(n_as <= 8); - // max size of the src1ids array in the kernel stack - GGML_ASSERT(ne11 <= 512); + // max size of the src1ids array in the kernel shared buffer + GGML_ASSERT(ne11 <= 4096); const int64_t ne20 = src2 ? src2->ne[0] : 0; const int64_t ne21 = src2 ? src2->ne[1] : 0; @@ -1741,7 +1741,7 @@ static enum ggml_status ggml_metal_graph_compute( [encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:19 + j]; } - [encoder setThreadgroupMemoryLength:8192 atIndex:0]; + [encoder setThreadgroupMemoryLength:GGML_PAD(8192 + 2*ne11, 16) atIndex:0]; [encoder dispatchThreadgroups:MTLSizeMake((ne11 + 31)/32, (ne21 + 63)/64, n_as*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; } else { diff --git a/ggml-metal.metal b/ggml-metal.metal index 6ebbbd19..50185ae4 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -5386,7 +5386,7 @@ template<typename block_q, short nl, void (*dequantize_func)(device const block_ void kernel_mul_mm_id_impl( device const uchar * src0, device const uchar * src1, - thread short * src1ids, + threadgroup short * src1ids, device float * dst, constant int64_t & ne00, constant int64_t & ne02, @@ -5589,9 +5589,9 @@ kernel void kernel_mul_mm_id( tgpig.z = tgpig.z%(ne12*ne13); // row indices of src1 for expert id - int64_t _ne1 = 0; - short src1ids[512]; + threadgroup short * src1ids = (threadgroup short *)(shared_memory + 8192); + int64_t _ne1 = 0; for (int64_t i1 = 0; i1 < ne1; i1++) { if (((device int32_t *) (ids + i1*nbi1))[idx] == id) { src1ids[_ne1++] = i1; |