summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--ggml-metal.m22
-rw-r--r--ggml-metal.metal9
2 files changed, 31 insertions, 0 deletions
diff --git a/ggml-metal.m b/ggml-metal.m
index 38da384b..0207b787 100644
--- a/ggml-metal.m
+++ b/ggml-metal.m
@@ -37,6 +37,7 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_DIV_ROW,
GGML_METAL_KERNEL_TYPE_SCALE,
GGML_METAL_KERNEL_TYPE_SCALE_4,
+ GGML_METAL_KERNEL_TYPE_CLAMP,
GGML_METAL_KERNEL_TYPE_TANH,
GGML_METAL_KERNEL_TYPE_RELU,
GGML_METAL_KERNEL_TYPE_GELU,
@@ -468,6 +469,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW, div_row, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE, scale, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE_4, scale_4, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CLAMP, clamp, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TANH, tanh, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RELU, relu, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU, gelu, true);
@@ -713,6 +715,7 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
case GGML_OP_MUL:
case GGML_OP_DIV:
case GGML_OP_SCALE:
+ case GGML_OP_CLAMP:
case GGML_OP_SQR:
case GGML_OP_SUM_ROWS:
return true;
@@ -1154,6 +1157,25 @@ static enum ggml_status ggml_metal_graph_compute(
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break;
+ case GGML_OP_CLAMP:
+ {
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CLAMP].pipeline;
+
+ float min;
+ float max;
+ memcpy(&min, ((int32_t *) dst->op_params) + 0, sizeof(float));
+ memcpy(&max, ((int32_t *) dst->op_params) + 1, sizeof(float));
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBytes:&min length:sizeof(min) atIndex:2];
+ [encoder setBytes:&max length:sizeof(max) atIndex:3];
+
+ const int64_t n = ggml_nelements(dst);
+
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ } break;
case GGML_OP_UNARY:
switch (ggml_get_unary_op(gf->nodes[i])) {
case GGML_UNARY_OP_TANH:
diff --git a/ggml-metal.metal b/ggml-metal.metal
index 3a823e65..56748166 100644
--- a/ggml-metal.metal
+++ b/ggml-metal.metal
@@ -213,6 +213,15 @@ kernel void kernel_scale_4(
dst[tpig] = src0[tpig] * scale;
}
+kernel void kernel_clamp(
+ device const float * src0,
+ device float * dst,
+ constant float & min,
+ constant float & max,
+ uint tpig[[thread_position_in_grid]]) {
+ dst[tpig] = src0[tpig] < min ? min : (src0[tpig] > max ? max : src0[tpig]);
+}
+
kernel void kernel_relu(
device const float * src0,
device float * dst,