diff options
-rw-r--r-- | ggml-metal.metal | 72 |
1 files changed, 53 insertions, 19 deletions
diff --git a/ggml-metal.metal b/ggml-metal.metal index 12ab9cca..8054cc40 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -1674,7 +1674,7 @@ static float rope_yarn_ramp(const float low, const float high, const int i0) { // YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn // MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng. -static void rope_yarn( +static inline void rope_yarn( float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor, float mscale, thread float * cos_theta, thread float * sin_theta) { // Get n-d rotational scaling corrected for extrapolation @@ -1828,35 +1828,69 @@ kernel void kernel_rope_neox( const float theta_base = (float) pos[i2]; const float inv_ndims = -1.f/n_dims; + float theta = theta_base * pow(freq_base, 2*tiitg*inv_ndims); + const float theta_multiplier = pow(freq_base, 2*tptg.x*inv_ndims); + float cos_theta; float sin_theta; - for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) { - if (i0 < n_dims) { - const int64_t ic = i0/2; + int64_t i0 = 2*tiitg; + for ( ; i0 < n_dims; i0 += 2*tptg.x) { + const int64_t ic = i0/2; - const float theta = theta_base * pow(freq_base, inv_ndims*i0); + const float freq_factor = src2 != src0 ? src2[ic] : 1.0f; - const float freq_factor = src2 != src0 ? src2[ic] : 1.0f; + rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta); - rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta); + device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00); + device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0); - device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00); - device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0); + const float x0 = src[0]; + const float x1 = src[n_dims/2]; - const float x0 = src[0]; - const float x1 = src[n_dims/2]; + dst_data[0] = x0*cos_theta - x1*sin_theta; + dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta; - dst_data[0] = x0*cos_theta - x1*sin_theta; - dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta; - } else { - device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); - device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + theta *= theta_multiplier; + } + for ( ; i0 < ne0; i0 += 2*tptg.x) { + device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - dst_data[0] = src[0]; - dst_data[1] = src[1]; - } + dst_data[0] = src[0]; + dst_data[1] = src[1]; } + + // Original version + //for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) { + // if (i0 < n_dims) { + // const int64_t ic = i0/2; + + // // Who thought that having a pow() evaluation in a loop is a good idea? + // //const float theta = theta_base * pow(freq_base, inv_ndims*i0); + + // const float freq_factor = src2 != src0 ? src2[ic] : 1.0f; + + // rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta); + + // device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00); + // device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0); + + // const float x0 = src[0]; + // const float x1 = src[n_dims/2]; + + // dst_data[0] = x0*cos_theta - x1*sin_theta; + // dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta; + + // theta *= theta_multiplier; + // } else { + // device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + // device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + // dst_data[0] = src[0]; + // dst_data[1] = src[1]; + // } + //} } typedef decltype(kernel_rope_norm<float>) kernel_rope_norm_t; |