diff options
author | Shijie <821898965@qq.com> | 2024-04-16 23:40:48 +0800 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-04-16 18:40:48 +0300 |
commit | f4dea7da1841a92d2788b0535063abf2f0e28461 (patch) | |
tree | c7a729d974e4315c71c78eea84fa08dda920b649 /ggml-metal.m | |
parent | 8a56075b07a8b571bf95a912ffdce4c928c2b414 (diff) |
llama : add qwen2moe (#6074)
* support qwen2moe
* fix-review
* metal : support unary ops for nelements % 4 != 0
* metal : require contiguousness for float4 unary kernels
* metal : require contiguousness for float4 unary kernels (cont)
* fix-review
* names : for brevity "SHARED_EXP" -> "SHEXP"
* llama : reuse build_moe_ffn()
* llama : add model type name
---------
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Diffstat (limited to 'ggml-metal.m')
-rw-r--r-- | ggml-metal.m | 57 |
1 files changed, 42 insertions, 15 deletions
diff --git a/ggml-metal.m b/ggml-metal.m index 0207b787..ae6ddeac 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -41,8 +41,11 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_TANH, GGML_METAL_KERNEL_TYPE_RELU, GGML_METAL_KERNEL_TYPE_GELU, + GGML_METAL_KERNEL_TYPE_GELU_4, GGML_METAL_KERNEL_TYPE_GELU_QUICK, + GGML_METAL_KERNEL_TYPE_GELU_QUICK_4, GGML_METAL_KERNEL_TYPE_SILU, + GGML_METAL_KERNEL_TYPE_SILU_4, GGML_METAL_KERNEL_TYPE_SOFT_MAX, GGML_METAL_KERNEL_TYPE_SOFT_MAX_4, GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF, @@ -473,8 +476,11 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) { GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TANH, tanh, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RELU, relu, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU, gelu, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_4, gelu_4, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK, gelu_quick, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK_4, gelu_quick_4, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU, silu, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU_4, silu_4, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX, soft_max, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_4, soft_max_4, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF, diag_mask_inf, true); @@ -1178,6 +1184,9 @@ static enum ggml_status ggml_metal_graph_compute( } break; case GGML_OP_UNARY: switch (ggml_get_unary_op(gf->nodes[i])) { + // we are not taking into account the strides, so for now require contiguous tensors + GGML_ASSERT(ggml_is_contiguous(src0)); + case GGML_UNARY_OP_TANH: { id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_TANH].pipeline; @@ -1204,42 +1213,60 @@ static enum ggml_status ggml_metal_graph_compute( } break; case GGML_UNARY_OP_GELU: { - id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU].pipeline; + int64_t n = ggml_nelements(dst); + + id<MTLComputePipelineState> pipeline = nil; + + if (n % 4 == 0) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_4].pipeline; + n /= 4; + } else { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU].pipeline; + } [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - const int64_t n = ggml_nelements(dst); - GGML_ASSERT(n % 4 == 0); - - [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; } break; case GGML_UNARY_OP_GELU_QUICK: { - id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_QUICK].pipeline; + int64_t n = ggml_nelements(dst); + + id<MTLComputePipelineState> pipeline = nil; + + if (n % 4 == 0) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_QUICK_4].pipeline; + n /= 4; + } else { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_QUICK].pipeline; + } [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - const int64_t n = ggml_nelements(dst); - GGML_ASSERT(n % 4 == 0); - - [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; } break; case GGML_UNARY_OP_SILU: { - id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SILU].pipeline; + int64_t n = ggml_nelements(dst); + + id<MTLComputePipelineState> pipeline = nil; + + if (n % 4 == 0) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SILU_4].pipeline; + n /= 4; + } else { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SILU].pipeline; + } [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - const int64_t n = ggml_nelements(dst); - GGML_ASSERT(n % 4 == 0); - - [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; } break; default: { |