diff options
Diffstat (limited to 'ggml-metal.m')
-rw-r--r-- | ggml-metal.m | 24 |
1 file changed, 16 insertions, 8 deletions
diff --git a/ggml-metal.m b/ggml-metal.m index 1f034150..611d5e17 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -1400,14 +1400,18 @@ void ggml_metal_graph_compute( const int nth = MIN(1024, ne00); - const int n_past = ((int32_t *) dst->op_params)[0]; - const int n_dims = ((int32_t *) dst->op_params)[1]; - const int mode = ((int32_t *) dst->op_params)[2]; - - float freq_base; - float freq_scale; - memcpy(&freq_base, (int32_t *) dst->op_params + 4, sizeof(float)); - memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float)); + const int n_past = ((int32_t *) dst->op_params)[0]; + const int n_dims = ((int32_t *) dst->op_params)[1]; + const int mode = ((int32_t *) dst->op_params)[2]; + const int n_orig_ctx = ((int32_t *) dst->op_params)[3]; + + float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow; + memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float)); + memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float)); + memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float)); + memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float)); + memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float)); + memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float)); switch (src0->type) { case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_rope_f32]; break; @@ -1439,6 +1443,10 @@ void ggml_metal_graph_compute( [encoder setBytes:&mode length:sizeof( int) atIndex:21]; [encoder setBytes:&freq_base length:sizeof(float) atIndex:22]; [encoder setBytes:&freq_scale length:sizeof(float) atIndex:23]; + [encoder setBytes:&ext_factor length:sizeof(float) atIndex:24]; + [encoder setBytes:&attn_factor length:sizeof(float) atIndex:25]; + [encoder setBytes:&beta_fast length:sizeof(float) atIndex:26]; + [encoder setBytes:&beta_slow length:sizeof(float) atIndex:27]; [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; } break; |