diff options
Diffstat (limited to 'ggml-metal.metal')
-rw-r--r-- | ggml-metal.metal | 61 |
1 files changed, 55 insertions, 6 deletions
diff --git a/ggml-metal.metal b/ggml-metal.metal
index f3152778..471d7d39 100644
--- a/ggml-metal.metal
+++ b/ggml-metal.metal
@@ -1061,6 +1061,45 @@ kernel void kernel_alibi_f32(
     }
 }
 
+static float rope_yarn_ramp(const float low, const float high, const int i0) {
+    const float y = (i0 / 2 - low) / max(0.001f, high - low); // i0 / 2 is integer division; denominator is clamped to >= 0.001
+    return 1.0f - min(1.0f, max(0.0f, y)); // clamp y to [0, 1], then invert
+}
+
+// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
+// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
+static void rope_yarn(
+    float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor, float mscale,
+    thread float * cos_theta, thread float * sin_theta // outputs: cos/sin of the final angle, each multiplied by mscale
+) {
+    // Get n-d rotational scaling corrected for extrapolation
+    float theta_interp = freq_scale * theta_extrap;
+    float theta = theta_interp;
+    if (ext_factor != 0.0f) {
+        float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor; // blend weight between interpolated and extrapolated angle
+        theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
+
+        // Get n-d magnitude scaling corrected for interpolation
+        mscale *= 1.0f + 0.1f * log(1.0f / freq_scale); // mscale is passed by value, so this only affects the outputs below
+    }
+    *cos_theta = cos(theta) * mscale;
+    *sin_theta = sin(theta) * mscale;
+}
+
+// Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get
+// `corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
+static float rope_yarn_corr_factor(int n_dims, int n_orig_ctx, float n_rot, float base) {
+    return n_dims * log(n_orig_ctx / (n_rot * 2 * M_PI_F)) / (2 * log(base));
+}
+
+static void rope_yarn_corr_dims(
+    int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2]
+) {
+    // start and end correction dims
+    dims[0] = max(0.0f, floor(rope_yarn_corr_factor(n_dims, n_orig_ctx, beta_fast, freq_base))); // lower bound clamped to 0
+    dims[1] = min(n_dims - 1.0f, ceil(rope_yarn_corr_factor(n_dims, n_orig_ctx, beta_slow, freq_base))); // upper bound clamped to n_dims - 1
+}
+
 typedef void (rope_t)(
         device const void * src0,
         device const int32_t * src1,
@@ -1116,6 +1155,10 @@ kernel void kernel_rope(
         constant int & mode,
         constant float & freq_base,
         constant float & freq_scale,
+        constant float & ext_factor,
+        constant float & attn_factor,
+        constant float & beta_fast,
+        constant float & beta_slow,
         uint tiitg[[thread_index_in_threadgroup]],
         uint3 tptg[[threads_per_threadgroup]],
         uint3 tgpig[[threadgroup_position_in_grid]]) {
@@ -1125,19 +1168,22 @@ kernel void kernel_rope(
 
     const bool is_neox = mode & 2;
 
+    float corr_dims[2];
+    rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims); // NOTE(review): n_orig_ctx is not declared in any hunk shown here - confirm it is a kernel argument
+
     device const int32_t * pos = src1;
 
     const int64_t p = pos[i2];
 
-    const float theta_0 = freq_scale * (float)p;
+    const float theta_0 = (float)p;
     const float inv_ndims = -1.f/n_dims;
 
     if (!is_neox) {
         for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) {
 
             const float theta = theta_0 * pow(freq_base, inv_ndims*i0);
-            const float cos_theta = cos(theta);
-            const float sin_theta = sin(theta);
+            float cos_theta, sin_theta;
+            rope_yarn(theta, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
 
             device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
             device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
@@ -1152,9 +1198,12 @@ kernel void kernel_rope(
         for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
             for (int64_t ic = 2*tiitg; ic < n_dims; ic += 2*tptg.x) {
 
-                const float theta = theta_0 * pow(freq_base, inv_ndims*ic - ib);
-                const float cos_theta = cos(theta);
-                const float sin_theta = sin(theta);
+                // simplified from `(ib * n_dims + ic) * inv_ndims`
+                const float cur_rot = inv_ndims*ic - ib;
+
+                const float theta = theta_0 * pow(freq_base, cur_rot);
+                float cos_theta, sin_theta;
+                rope_yarn(theta, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta); // NOTE(review): cur_rot is a float passed as the int64_t i0 parameter - confirm this is intended
 
                 const int64_t i0 = ib*n_dims + ic/2;
 