summaryrefslogtreecommitdiff
path: root/ggml-metal.m
diff options
context:
space:
mode:
Diffstat (limited to 'ggml-metal.m')
-rw-r--r--ggml-metal.m52
1 files changed, 28 insertions, 24 deletions
diff --git a/ggml-metal.m b/ggml-metal.m
index fddc44f7..946f1181 100644
--- a/ggml-metal.m
+++ b/ggml-metal.m
@@ -172,8 +172,10 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32,
- GGML_METAL_KERNEL_TYPE_ROPE_F32,
- GGML_METAL_KERNEL_TYPE_ROPE_F16,
+ GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32,
+ GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16,
+ GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32,
+ GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16,
GGML_METAL_KERNEL_TYPE_IM2COL_F16,
GGML_METAL_KERNEL_TYPE_IM2COL_F32,
GGML_METAL_KERNEL_TYPE_UPSCALE_F32,
@@ -626,8 +628,10 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32, mul_mm_id_iq1_m_f32, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, mul_mm_id_iq4_nl_f32, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, mul_mm_id_iq4_xs_f32, ctx->support_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F32, rope_f32, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F16, rope_f16, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32, rope_norm_f32, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16, rope_norm_f16, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32, rope_neox_f32, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16, rope_neox_f16, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true);
@@ -2285,7 +2289,7 @@ static enum ggml_status ggml_metal_graph_compute(
const int n_dims = ((int32_t *) dst->op_params)[1];
const int mode = ((int32_t *) dst->op_params)[2];
// skip 3, n_ctx, used in GLM RoPE, unimplemented in metal
- const int n_orig_ctx = ((int32_t *) dst->op_params)[4];
+ const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
float freq_base;
float freq_scale;
@@ -2302,22 +2306,23 @@ static enum ggml_status ggml_metal_graph_compute(
memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
const bool is_neox = mode & 2;
- const bool is_glm = mode & 4;
- GGML_ASSERT(!is_glm && "GLM RoPE not implemented in Metal");
+ id<MTLComputePipelineState> pipeline = nil;
if (!is_neox) {
- GGML_ASSERT(id_src2 == nil && "TODO: freq_factors not implemented for !is_neox");
+ switch (src0->type) {
+ case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32].pipeline; break;
+ case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16].pipeline; break;
+ default: GGML_ASSERT(false);
+ };
+ } else {
+ switch (src0->type) {
+ case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32].pipeline; break;
+ case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16].pipeline; break;
+ default: GGML_ASSERT(false);
+ };
}
- id<MTLComputePipelineState> pipeline = nil;
-
- switch (src0->type) {
- case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_F32].pipeline; break;
- case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_F16].pipeline; break;
- default: GGML_ASSERT(false);
- };
-
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
@@ -2345,14 +2350,13 @@ static enum ggml_status ggml_metal_graph_compute(
[encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:19];
[encoder setBytes:&n_past length:sizeof( int) atIndex:20];
[encoder setBytes:&n_dims length:sizeof( int) atIndex:21];
- [encoder setBytes:&mode length:sizeof( int) atIndex:22];
- [encoder setBytes:&n_orig_ctx length:sizeof( int) atIndex:23];
- [encoder setBytes:&freq_base length:sizeof( float) atIndex:24];
- [encoder setBytes:&freq_scale length:sizeof( float) atIndex:25];
- [encoder setBytes:&ext_factor length:sizeof( float) atIndex:26];
- [encoder setBytes:&attn_factor length:sizeof( float) atIndex:27];
- [encoder setBytes:&beta_fast length:sizeof( float) atIndex:28];
- [encoder setBytes:&beta_slow length:sizeof( float) atIndex:29];
+ [encoder setBytes:&n_ctx_orig length:sizeof( int) atIndex:22];
+ [encoder setBytes:&freq_base length:sizeof( float) atIndex:23];
+ [encoder setBytes:&freq_scale length:sizeof( float) atIndex:24];
+ [encoder setBytes:&ext_factor length:sizeof( float) atIndex:25];
+ [encoder setBytes:&attn_factor length:sizeof( float) atIndex:26];
+ [encoder setBytes:&beta_fast length:sizeof( float) atIndex:27];
+ [encoder setBytes:&beta_slow length:sizeof( float) atIndex:28];
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
} break;