diff options
Diffstat (limited to 'ggml-metal.m')
-rw-r--r-- | ggml-metal.m | 52 |
1 files changed, 28 insertions, 24 deletions
diff --git a/ggml-metal.m b/ggml-metal.m index fddc44f7..946f1181 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -172,8 +172,10 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, - GGML_METAL_KERNEL_TYPE_ROPE_F32, - GGML_METAL_KERNEL_TYPE_ROPE_F16, + GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32, + GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16, + GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32, + GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16, GGML_METAL_KERNEL_TYPE_IM2COL_F16, GGML_METAL_KERNEL_TYPE_IM2COL_F32, GGML_METAL_KERNEL_TYPE_UPSCALE_F32, @@ -626,8 +628,10 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) { GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32, mul_mm_id_iq1_m_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, mul_mm_id_iq4_nl_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, mul_mm_id_iq4_xs_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F32, rope_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F16, rope_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32, rope_norm_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16, rope_norm_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32, rope_neox_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16, rope_neox_f16, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true); @@ -2285,7 +2289,7 @@ static enum ggml_status ggml_metal_graph_compute( const int n_dims = ((int32_t *) dst->op_params)[1]; const int mode = ((int32_t *) dst->op_params)[2]; // skip 3, n_ctx, used in GLM RoPE, unimplemented in metal - const int n_orig_ctx = ((int32_t *) dst->op_params)[4]; + const int n_ctx_orig = ((int32_t *) dst->op_params)[4]; float freq_base; float freq_scale; @@ -2302,22 +2306,23 @@ static enum ggml_status ggml_metal_graph_compute( memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float)); const bool is_neox = mode & 2; - const bool is_glm = mode & 4; - GGML_ASSERT(!is_glm && "GLM RoPE not implemented in Metal"); + id<MTLComputePipelineState> pipeline = nil; if (!is_neox) { - GGML_ASSERT(id_src2 == nil && "TODO: freq_factors not implemented for !is_neox"); + switch (src0->type) { + case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32].pipeline; break; + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16].pipeline; break; + default: GGML_ASSERT(false); + }; + } else { + switch (src0->type) { + case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32].pipeline; break; + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16].pipeline; break; + default: GGML_ASSERT(false); + }; } - id<MTLComputePipelineState> pipeline = nil; - - switch (src0->type) { - case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_F32].pipeline; break; - case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_F16].pipeline; break; - default: GGML_ASSERT(false); - }; - [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; @@ -2345,14 +2350,13 @@ static enum ggml_status ggml_metal_graph_compute( [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:19]; [encoder setBytes:&n_past length:sizeof( int) atIndex:20]; [encoder setBytes:&n_dims length:sizeof( int) atIndex:21]; - [encoder setBytes:&mode length:sizeof( int) atIndex:22]; - [encoder setBytes:&n_orig_ctx length:sizeof( int) atIndex:23]; - [encoder setBytes:&freq_base length:sizeof( float) atIndex:24]; - [encoder setBytes:&freq_scale length:sizeof( float) atIndex:25]; - [encoder setBytes:&ext_factor length:sizeof( float) atIndex:26]; - [encoder setBytes:&attn_factor length:sizeof( float) atIndex:27]; - [encoder setBytes:&beta_fast length:sizeof( float) atIndex:28]; - [encoder setBytes:&beta_slow length:sizeof( float) atIndex:29]; + [encoder setBytes:&n_ctx_orig length:sizeof( int) atIndex:22]; + [encoder setBytes:&freq_base length:sizeof( float) atIndex:23]; + [encoder setBytes:&freq_scale length:sizeof( float) atIndex:24]; + [encoder setBytes:&ext_factor length:sizeof( float) atIndex:25]; + [encoder setBytes:&attn_factor length:sizeof( float) atIndex:26]; + [encoder setBytes:&beta_fast length:sizeof( float) atIndex:27]; + [encoder setBytes:&beta_slow length:sizeof( float) atIndex:28]; [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; } break; |