Diffstat (limited to 'ggml-metal.m')
-rw-r--r--  ggml-metal.m  62
1 file changed, 59 insertions(+), 3 deletions(-)
diff --git a/ggml-metal.m b/ggml-metal.m
index 71fcca56..6b5a8fdf 100644
--- a/ggml-metal.m
+++ b/ggml-metal.m
@@ -163,6 +163,8 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_IM2COL_F32,
GGML_METAL_KERNEL_TYPE_UPSCALE_F32,
GGML_METAL_KERNEL_TYPE_PAD_F32,
+ GGML_METAL_KERNEL_TYPE_ARANGE_F32,
+ GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32,
GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC,
GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC,
GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32,
@@ -569,6 +571,8 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32, pad_f32, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32, timestep_embedding_f32, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARANGE_F32, arange_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, argsort_f32_i32_desc, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true);
@@ -697,6 +701,8 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
return false;
case GGML_OP_UPSCALE:
case GGML_OP_PAD:
+ case GGML_OP_ARANGE:
+ case GGML_OP_TIMESTEP_EMBEDDING:
case GGML_OP_ARGSORT:
case GGML_OP_LEAKY_RELU:
return true;
@@ -1091,7 +1097,8 @@ static bool ggml_metal_graph_compute(
{
GGML_ASSERT(ggml_is_contiguous(src0));
- const float scale = *(const float *) dst->op_params;
+ float scale;
+ memcpy(&scale, dst->op_params, sizeof(scale));
int64_t n = ggml_nelements(dst);
@@ -1250,11 +1257,15 @@ static bool ggml_metal_graph_compute(
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX].pipeline;
}
- const float scale = ((float *) dst->op_params)[0];
- const float max_bias = ((float *) dst->op_params)[1];
+ float scale;
+ float max_bias;
+
+ memcpy(&scale, ((int32_t *) dst->op_params) + 0, sizeof(scale));
+ memcpy(&max_bias, ((int32_t *) dst->op_params) + 1, sizeof(max_bias));
const int64_t nrows_x = ggml_nrows(src0);
const int64_t nrows_y = src0->ne[1];
+
const uint32_t n_head_kv = nrows_x/nrows_y;
const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));
@@ -2086,6 +2097,7 @@ static bool ggml_metal_graph_compute(
//const int n_past = ((int32_t *) dst->op_params)[0];
const int n_head = ((int32_t *) dst->op_params)[1];
+
float max_bias;
memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
@@ -2300,6 +2312,50 @@ static bool ggml_metal_graph_compute(
[encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
} break;
+ case GGML_OP_ARANGE:
+ {
+ GGML_ASSERT(dst->type == GGML_TYPE_F32);
+
+ float start;
+ float step;
+
+ memcpy(&start, ((int32_t *) dst->op_params) + 0, sizeof(float));
+ memcpy(&step, ((int32_t *) dst->op_params) + 2, sizeof(float));
+
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARANGE_F32].pipeline;
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:0];
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:1];
+ [encoder setBytes:&start length:sizeof(start) atIndex:2];
+ [encoder setBytes:&step length:sizeof(step) atIndex:3];
+
+ const int nth = MIN(1024, ne0);
+
+ [encoder dispatchThreadgroups:MTLSizeMake(1, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+ } break;
+ case GGML_OP_TIMESTEP_EMBEDDING:
+ {
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+
+ const int dim = dst->op_params[0];
+ const int max_period = dst->op_params[1];
+
+ const int half = dim / 2;
+
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32].pipeline;
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:2];
+ [encoder setBytes:&dim length:sizeof(dim) atIndex:3];
+ [encoder setBytes:&max_period length:sizeof(max_period) atIndex:4];
+
+ const int nth = MIN(1024, half);
+
+ [encoder dispatchThreadgroups:MTLSizeMake(ne00, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+ } break;
case GGML_OP_ARGSORT:
{
GGML_ASSERT(src0->type == GGML_TYPE_F32);
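
Note: the host-side changes above only register the new pipelines, bind buffers, and dispatch; the kernel bodies themselves are added to ggml-metal.metal in a separate hunk not shown here. For orientation, below is a minimal sketch of what the two new kernels could look like, written to match the argument indices and dispatch geometry used by the encoder code above (one threadgroup sweeping ne0 for ARANGE, one threadgroup per timestep row for TIMESTEP_EMBEDDING). The function names kernel_arange_f32 and kernel_timestep_embedding_f32 and the exact bodies are assumptions, not taken from this diff.

#include <metal_stdlib>
using namespace metal;

// Sketch only: mirrors the bindings (dst, ne0, start, step) and the
// single-threadgroup dispatch used by the GGML_OP_ARANGE case above.
kernel void kernel_arange_f32(
        device        char * dst,
        constant   int64_t & ne0,
        constant     float & start,
        constant     float & step,
        uint3 tpitg[[thread_position_in_threadgroup]],
        uint3   ntg[[threads_per_threadgroup]]) {
    device float * dst_ptr = (device float *) dst;

    // one threadgroup covers the whole row: dst[i0] = start + step*i0
    for (int64_t i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
        dst_ptr[i0] = start + step * i0;
    }
}

// Sketch only: one threadgroup per timestep (tgpig.x == row index), writing
// dim/2 cos/sin pairs per row, as dispatched by the GGML_OP_TIMESTEP_EMBEDDING
// case above (src0 holds the timesteps, nb1 is the dst row stride in bytes).
kernel void kernel_timestep_embedding_f32(
        device  const char * src0,
        device        char * dst,
        constant  uint64_t & nb1,
        constant       int & dim,
        constant       int & max_period,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3 tpitg[[thread_position_in_threadgroup]],
        uint3   ntg[[threads_per_threadgroup]]) {
    const int i = tgpig.x;

    const float timestep = ((device const float *) src0)[i];
    device float * embed = (device float *)(dst + i*nb1);

    const int half_dim = dim / 2;
    for (int j = tpitg.x; j < half_dim; j += ntg.x) {
        // log(max_period) is float, so the whole expression stays in float
        const float freq = exp(-log((float) max_period) * j / half_dim);
        const float arg  = timestep * freq;
        embed[j]            = cos(arg);
        embed[j + half_dim] = sin(arg);
    }

    // zero-pad the trailing element when dim is odd
    if (dim % 2 != 0 && tpitg.x == 0) {
        embed[dim] = 0.0f;
    }
}

The remaining hunks also switch reads of float parameters (scale, max_bias, start, step) from casts of dst->op_params to memcpy; this avoids type-punned pointer dereferences of the int32_t op_params array and keeps the reads well-defined.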