summaryrefslogtreecommitdiff
path: root/ggml-metal.m
diff options
context:
space:
mode:
authorcebtenzzre <cebtenzzre@gmail.com>2023-11-01 18:04:33 -0400
committerGitHub <noreply@github.com>2023-11-01 18:04:33 -0400
commit898aeca90a9bb992f506234cf3b8b7f7fa28a1df (patch)
tree125f8a9b466efd4534ecd3e64419ece001c86a7d /ggml-metal.m
parentc43c2da8afacaddfe51c09b21dbd9922cd0ea46b (diff)
llama : implement YaRN RoPE scaling (#2268)
Co-authored-by: cebtenzzre <cebtenzzre@gmail.com> Co-authored-by: Jeffrey Quesnelle <jquesnelle@gmail.com>
Diffstat (limited to 'ggml-metal.m')
-rw-r--r--ggml-metal.m24
1 files changed, 16 insertions, 8 deletions
diff --git a/ggml-metal.m b/ggml-metal.m
index 1f034150..611d5e17 100644
--- a/ggml-metal.m
+++ b/ggml-metal.m
@@ -1400,14 +1400,18 @@ void ggml_metal_graph_compute(
const int nth = MIN(1024, ne00);
- const int n_past = ((int32_t *) dst->op_params)[0];
- const int n_dims = ((int32_t *) dst->op_params)[1];
- const int mode = ((int32_t *) dst->op_params)[2];
-
- float freq_base;
- float freq_scale;
- memcpy(&freq_base, (int32_t *) dst->op_params + 4, sizeof(float));
- memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));
+ const int n_past = ((int32_t *) dst->op_params)[0];
+ const int n_dims = ((int32_t *) dst->op_params)[1];
+ const int mode = ((int32_t *) dst->op_params)[2];
+ const int n_orig_ctx = ((int32_t *) dst->op_params)[3];
+
+ float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
+ memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
+ memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
+ memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
+ memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
+ memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
+ memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
switch (src0->type) {
case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_rope_f32]; break;
@@ -1439,6 +1443,10 @@ void ggml_metal_graph_compute(
[encoder setBytes:&mode length:sizeof( int) atIndex:21];
[encoder setBytes:&freq_base length:sizeof(float) atIndex:22];
[encoder setBytes:&freq_scale length:sizeof(float) atIndex:23];
+ [encoder setBytes:&ext_factor length:sizeof(float) atIndex:24];
+ [encoder setBytes:&attn_factor length:sizeof(float) atIndex:25];
+ [encoder setBytes:&beta_fast length:sizeof(float) atIndex:26];
+ [encoder setBytes:&beta_slow length:sizeof(float) atIndex:27];
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
} break;