diff options
author | Dave <dave-fl@users.noreply.github.com> | 2024-04-14 07:14:19 -0400 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-04-14 13:14:19 +0200 |
commit | 422c2aff1c9735853c9d8f5162104e41a364adc4 (patch) | |
tree | 4d6f263a1780612653ed26c830d2c7c0f4ecd516 | |
parent | 8800226d65d5c98cd34eede6a6c05c78405c52da (diff) |
Added support for GGML_OP_CLAMP in Metal (#6662)
* Added support for GGML_OP_CLAMP in Metal
* Corrected size
---------
Co-authored-by: dave-fl <dave@Davids-MacBook-Pro.local>
-rw-r--r-- | ggml-metal.m | 22 | ||||
-rw-r--r-- | ggml-metal.metal | 9 |
2 files changed, 31 insertions, 0 deletions
diff --git a/ggml-metal.m b/ggml-metal.m index 38da384b..0207b787 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -37,6 +37,7 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_DIV_ROW, GGML_METAL_KERNEL_TYPE_SCALE, GGML_METAL_KERNEL_TYPE_SCALE_4, + GGML_METAL_KERNEL_TYPE_CLAMP, GGML_METAL_KERNEL_TYPE_TANH, GGML_METAL_KERNEL_TYPE_RELU, GGML_METAL_KERNEL_TYPE_GELU, @@ -468,6 +469,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) { GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW, div_row, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE, scale, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE_4, scale_4, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CLAMP, clamp, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TANH, tanh, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RELU, relu, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU, gelu, true); @@ -713,6 +715,7 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const case GGML_OP_MUL: case GGML_OP_DIV: case GGML_OP_SCALE: + case GGML_OP_CLAMP: case GGML_OP_SQR: case GGML_OP_SUM_ROWS: return true; @@ -1154,6 +1157,25 @@ static enum ggml_status ggml_metal_graph_compute( [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; } break; + case GGML_OP_CLAMP: + { + id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CLAMP].pipeline; + + float min; + float max; + memcpy(&min, ((int32_t *) dst->op_params) + 0, sizeof(float)); + memcpy(&max, ((int32_t *) dst->op_params) + 1, sizeof(float)); + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&min length:sizeof(min) atIndex:2]; + [encoder setBytes:&max length:sizeof(max) atIndex:3]; + + const int64_t n = ggml_nelements(dst); + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; case GGML_OP_UNARY: switch (ggml_get_unary_op(gf->nodes[i])) { case GGML_UNARY_OP_TANH: diff --git a/ggml-metal.metal b/ggml-metal.metal index 3a823e65..56748166 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -213,6 +213,15 @@ kernel void kernel_scale_4( dst[tpig] = src0[tpig] * scale; } +kernel void kernel_clamp( + device const float * src0, + device float * dst, + constant float & min, + constant float & max, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = src0[tpig] < min ? min : (src0[tpig] > max ? max : src0[tpig]); +} + kernel void kernel_relu( device const float * src0, device float * dst, |