diff options
author | Georgi Gerganov <ggerganov@gmail.com> | 2024-05-28 11:04:19 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-05-28 11:04:19 +0300 |
commit | 0548a4187f2e53b8fc6d9ff0f4c71988f708ff42 (patch) | |
tree | 35ae0e19ecc36169939620b2702fd853c8e8c116 /ggml-metal.m | |
parent | 9335b969e86a222e247adacedf814d8abfff8847 (diff) |
ggml : generalize GGML_OP_CONCAT (#7563)
* ggml : generalize GGML_OP_CONCAT (WIP)
ggml-ci
* tests : add dim != 2 tests
* metal : generalize concat kernel
* tests : naming
* cuda : generalize concat kernel
ggml-ci
* sycl : add warning and assert
* ggml : fix op params handling
* metal : bugfix kernel
ggml-ci
* ggml : reimplement CPU and Metal
* cuda : add asserts
ggml-ci
* ggml : fix ptrs
ggml-ci
Diffstat (limited to 'ggml-metal.m')
-rw-r--r-- | ggml-metal.m | 3 |
1 files changed, 3 insertions, 0 deletions
diff --git a/ggml-metal.m b/ggml-metal.m index ff9ae55a..4ba498e8 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -990,6 +990,8 @@ static enum ggml_status ggml_metal_graph_compute( { id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CONCAT].pipeline; + const int32_t dim = ((int32_t *) dst->op_params)[0]; + [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; @@ -1018,6 +1020,7 @@ static enum ggml_status ggml_metal_graph_compute( [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24]; [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25]; [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26]; + [encoder setBytes:&dim length:sizeof(dim) atIndex:27]; const int nth = MIN(1024, ne0); |