summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--ggml-metal.metal72
1 files changed, 53 insertions, 19 deletions
diff --git a/ggml-metal.metal b/ggml-metal.metal
index 12ab9cca..8054cc40 100644
--- a/ggml-metal.metal
+++ b/ggml-metal.metal
@@ -1674,7 +1674,7 @@ static float rope_yarn_ramp(const float low, const float high, const int i0) {
// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
-static void rope_yarn(
+static inline void rope_yarn(
float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor, float mscale,
thread float * cos_theta, thread float * sin_theta) {
// Get n-d rotational scaling corrected for extrapolation
@@ -1828,35 +1828,69 @@ kernel void kernel_rope_neox(
const float theta_base = (float) pos[i2];
const float inv_ndims = -1.f/n_dims;
+ float theta = theta_base * pow(freq_base, 2*tiitg*inv_ndims);
+ const float theta_multiplier = pow(freq_base, 2*tptg.x*inv_ndims);
+
float cos_theta;
float sin_theta;
- for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) {
- if (i0 < n_dims) {
- const int64_t ic = i0/2;
+ int64_t i0 = 2*tiitg;
+ for ( ; i0 < n_dims; i0 += 2*tptg.x) {
+ const int64_t ic = i0/2;
- const float theta = theta_base * pow(freq_base, inv_ndims*i0);
+ const float freq_factor = src2 != src0 ? src2[ic] : 1.0f;
- const float freq_factor = src2 != src0 ? src2[ic] : 1.0f;
+ rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
- rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
+ device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
+ device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
- device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
- device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
+ const float x0 = src[0];
+ const float x1 = src[n_dims/2];
- const float x0 = src[0];
- const float x1 = src[n_dims/2];
+ dst_data[0] = x0*cos_theta - x1*sin_theta;
+ dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
- dst_data[0] = x0*cos_theta - x1*sin_theta;
- dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
- } else {
- device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
- device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+ theta *= theta_multiplier;
+ }
+ for ( ; i0 < ne0; i0 += 2*tptg.x) {
+ device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+ device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
- dst_data[0] = src[0];
- dst_data[1] = src[1];
- }
+ dst_data[0] = src[0];
+ dst_data[1] = src[1];
}
+
+ // Original version
+ //for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) {
+ // if (i0 < n_dims) {
+ // const int64_t ic = i0/2;
+
+ // // Who thought that having a pow() evaluation in a loop is a good idea?
+ // //const float theta = theta_base * pow(freq_base, inv_ndims*i0);
+
+ // const float freq_factor = src2 != src0 ? src2[ic] : 1.0f;
+
+ // rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
+
+ // device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
+ // device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
+
+ // const float x0 = src[0];
+ // const float x1 = src[n_dims/2];
+
+ // dst_data[0] = x0*cos_theta - x1*sin_theta;
+ // dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
+
+ // theta *= theta_multiplier;
+ // } else {
+ // device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+ // device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+ // dst_data[0] = src[0];
+ // dst_data[1] = src[1];
+ // }
+ //}
}
typedef decltype(kernel_rope_norm<float>) kernel_rope_norm_t;