diff options
author | cebtenzzre <cebtenzzre@gmail.com> | 2023-11-01 18:04:33 -0400 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-11-01 18:04:33 -0400 |
commit | 898aeca90a9bb992f506234cf3b8b7f7fa28a1df (patch) | |
tree | 125f8a9b466efd4534ecd3e64419ece001c86a7d /ggml-metal.m | |
parent | c43c2da8afacaddfe51c09b21dbd9922cd0ea46b (diff) |
llama : implement YaRN RoPE scaling (#2268)
Co-authored-by: cebtenzzre <cebtenzzre@gmail.com>
Co-authored-by: Jeffrey Quesnelle <jquesnelle@gmail.com>
Diffstat (limited to 'ggml-metal.m')
-rw-r--r-- | ggml-metal.m | 24 |
1 files changed, 16 insertions, 8 deletions
diff --git a/ggml-metal.m b/ggml-metal.m index 1f034150..611d5e17 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -1400,14 +1400,18 @@ void ggml_metal_graph_compute( const int nth = MIN(1024, ne00); - const int n_past = ((int32_t *) dst->op_params)[0]; - const int n_dims = ((int32_t *) dst->op_params)[1]; - const int mode = ((int32_t *) dst->op_params)[2]; - - float freq_base; - float freq_scale; - memcpy(&freq_base, (int32_t *) dst->op_params + 4, sizeof(float)); - memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float)); + const int n_past = ((int32_t *) dst->op_params)[0]; + const int n_dims = ((int32_t *) dst->op_params)[1]; + const int mode = ((int32_t *) dst->op_params)[2]; + const int n_orig_ctx = ((int32_t *) dst->op_params)[3]; + + float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow; + memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float)); + memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float)); + memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float)); + memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float)); + memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float)); + memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float)); switch (src0->type) { case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_rope_f32]; break; @@ -1439,6 +1443,10 @@ void ggml_metal_graph_compute( [encoder setBytes:&mode length:sizeof( int) atIndex:21]; [encoder setBytes:&freq_base length:sizeof(float) atIndex:22]; [encoder setBytes:&freq_scale length:sizeof(float) atIndex:23]; + [encoder setBytes:&ext_factor length:sizeof(float) atIndex:24]; + [encoder setBytes:&attn_factor length:sizeof(float) atIndex:25]; + [encoder setBytes:&beta_fast length:sizeof(float) atIndex:26]; + [encoder setBytes:&beta_slow length:sizeof(float) atIndex:27]; [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; } break; |