summaryrefslogtreecommitdiff
path: root/ggml-metal.m
diff options
context:
space:
mode:
author: Georgi Gerganov <ggerganov@gmail.com> 2024-05-28 11:04:19 +0300
committer: GitHub <noreply@github.com> 2024-05-28 11:04:19 +0300
commit: 0548a4187f2e53b8fc6d9ff0f4c71988f708ff42 (patch)
tree: 35ae0e19ecc36169939620b2702fd853c8e8c116 /ggml-metal.m
parent: 9335b969e86a222e247adacedf814d8abfff8847 (diff)
ggml : generalize GGML_OP_CONCAT (#7563)
* ggml : generalize GGML_OP_CONCAT (WIP)
  ggml-ci
* tests : add dim != 2 tests
* metal : generalize concat kernel
* tests : naming
* cuda : generalize concat kernel
  ggml-ci
* sycl : add warning and assert
* ggml : fix op params handling
* metal : bugfix kernel
  ggml-ci
* ggml : reimplement CPU and Metal
* cuda : add asserts
  ggml-ci
* ggml : fix ptrs
  ggml-ci
Diffstat (limited to 'ggml-metal.m')
-rw-r--r-- ggml-metal.m | 3
1 files changed, 3 insertions, 0 deletions
diff --git a/ggml-metal.m b/ggml-metal.m
index ff9ae55a..4ba498e8 100644
--- a/ggml-metal.m
+++ b/ggml-metal.m
@@ -990,6 +990,8 @@ static enum ggml_status ggml_metal_graph_compute(
{
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CONCAT].pipeline;
+ const int32_t dim = ((int32_t *) dst->op_params)[0];
+
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
@@ -1018,6 +1020,7 @@ static enum ggml_status ggml_metal_graph_compute(
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24];
[encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25];
[encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26];
+ [encoder setBytes:&dim length:sizeof(dim) atIndex:27];
const int nth = MIN(1024, ne0);