diff options
author | Georgi Gerganov <ggerganov@gmail.com> | 2024-06-05 11:29:20 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-06-05 11:29:20 +0300 |
commit | 2b3389677a833cee0880226533a1768b1a9508d2 (patch) | |
tree | 3af4beed34ff6c1309d202a7028f5ab85ec43693 /ggml.c | |
parent | 9973e81c5ccf4f31b3980f5aa73f5cfea8699860 (diff) |
ggml : refactor rope norm/neox (#7634)
* ggml : unify rope norm/neox (CPU)
* ggml : fix compile warning
* ggml : remove GLM rope mode
ggml-ci
* metal : better rope implementation
ggml-ci
* cuda : better rope implementation
ggml-ci
* naming : n_orig_ctx -> n_ctx_orig
ggml-ci
* dev : add reminders to update backends
ggml-ci
* vulkan : fix ggml_rope_ext() usage
* cuda : fix array size + indents
ggml-ci
Diffstat (limited to 'ggml.c')
-rw-r--r-- | ggml.c | 330 |
1 files changed, 97 insertions, 233 deletions
@@ -6250,16 +6250,13 @@ static struct ggml_tensor * ggml_rope_impl( struct ggml_tensor * c, int n_dims, int mode, - int n_ctx, - int n_orig_ctx, + int n_ctx_orig, float freq_base, float freq_scale, float ext_factor, float attn_factor, float beta_fast, float beta_slow, - float xpos_base, - bool xpos_down, bool inplace) { GGML_ASSERT((mode & 1) == 0 && "mode & 1 == 1 is no longer supported"); @@ -6280,15 +6277,13 @@ static struct ggml_tensor * ggml_rope_impl( struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); - int32_t params[13] = { /*n_past*/ 0, n_dims, mode, n_ctx, n_orig_ctx }; + int32_t params[11] = { /*n_past*/ 0, n_dims, mode, /*n_ctx*/ 0, n_ctx_orig }; memcpy(params + 5, &freq_base, sizeof(float)); memcpy(params + 6, &freq_scale, sizeof(float)); memcpy(params + 7, &ext_factor, sizeof(float)); memcpy(params + 8, &attn_factor, sizeof(float)); memcpy(params + 9, &beta_fast, sizeof(float)); memcpy(params + 10, &beta_slow, sizeof(float)); - memcpy(params + 11, &xpos_base, sizeof(float)); - memcpy(params + 12, &xpos_down, sizeof(bool)); ggml_set_op_params(result, params, sizeof(params)); result->op = GGML_OP_ROPE; @@ -6305,10 +6300,9 @@ struct ggml_tensor * ggml_rope( struct ggml_tensor * a, struct ggml_tensor * b, int n_dims, - int mode, - int n_ctx) { + int mode) { return ggml_rope_impl( - ctx, a, b, NULL, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, false + ctx, a, b, NULL, n_dims, mode, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, false ); } @@ -6317,10 +6311,9 @@ struct ggml_tensor * ggml_rope_inplace( struct ggml_tensor * a, struct ggml_tensor * b, int n_dims, - int mode, - int n_ctx) { + int mode) { return ggml_rope_impl( - ctx, a, b, NULL, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, true + ctx, a, b, NULL, n_dims, mode, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, true ); } @@ -6331,8 +6324,7 @@ struct ggml_tensor * ggml_rope_ext( struct ggml_tensor * c, int n_dims, int mode, - int n_ctx, - int n_orig_ctx, + int n_ctx_orig, float freq_base, float freq_scale, float ext_factor, @@ -6340,8 +6332,8 @@ struct ggml_tensor * ggml_rope_ext( float beta_fast, float beta_slow) { return ggml_rope_impl( - ctx, a, b, c, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale, - ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, false + ctx, a, b, c, n_dims, mode, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow, false ); } @@ -6352,8 +6344,7 @@ struct ggml_tensor * ggml_rope_ext_inplace( struct ggml_tensor * c, int n_dims, int mode, - int n_ctx, - int n_orig_ctx, + int n_ctx_orig, float freq_base, float freq_scale, float ext_factor, @@ -6361,8 +6352,8 @@ struct ggml_tensor * ggml_rope_ext_inplace( float beta_fast, float beta_slow) { return ggml_rope_impl( - ctx, a, b, c, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale, - ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, true + ctx, a, b, c, n_dims, mode, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow, true ); } @@ -6372,8 +6363,7 @@ struct ggml_tensor * ggml_rope_custom( struct ggml_tensor * b, int n_dims, int mode, - int n_ctx, - int n_orig_ctx, + int n_ctx_orig, float freq_base, float freq_scale, float ext_factor, @@ -6381,8 +6371,8 @@ struct ggml_tensor * ggml_rope_custom( float beta_fast, float beta_slow) { return ggml_rope_impl( - ctx, a, b, NULL, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale, - ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, false + ctx, a, b, NULL, n_dims, mode, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow, false ); } @@ -6392,8 +6382,7 @@ struct ggml_tensor * ggml_rope_custom_inplace( struct ggml_tensor * b, int n_dims, int mode, - int n_ctx, - int n_orig_ctx, + int n_ctx_orig, float freq_base, float freq_scale, float ext_factor, @@ -6401,21 +6390,11 @@ struct ggml_tensor * ggml_rope_custom_inplace( float beta_fast, float beta_slow) { return ggml_rope_impl( - ctx, a, b, NULL, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale, - ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, true + ctx, a, b, NULL, n_dims, mode, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow, true ); } -struct ggml_tensor * ggml_rope_xpos_inplace( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - int n_dims, - float base, - bool down) { - return ggml_rope_impl(ctx, a, b, NULL, n_dims, 0, 0, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, base, down, true); -} - // ggml_rope_back struct ggml_tensor * ggml_rope_back( @@ -6425,16 +6404,13 @@ struct ggml_tensor * ggml_rope_back( struct ggml_tensor * c, int n_dims, int mode, - int n_ctx, - int n_orig_ctx, + int n_ctx_orig, float freq_base, float freq_scale, float ext_factor, float attn_factor, float beta_fast, - float beta_slow, - float xpos_base, - bool xpos_down) { + float beta_slow) { GGML_ASSERT(ggml_is_vector(b)); GGML_ASSERT(b->type == GGML_TYPE_I32); GGML_ASSERT(a->ne[2] == b->ne[0]); @@ -6450,15 +6426,13 @@ struct ggml_tensor * ggml_rope_back( struct ggml_tensor * result = ggml_dup_tensor(ctx, a); - int32_t params[13] = { /*n_past*/ 0, n_dims, mode, n_ctx, n_orig_ctx }; + int32_t params[11] = { /*n_past*/ 0, n_dims, mode, /*n_ctx*/ 0, n_ctx_orig }; memcpy(params + 5, &freq_base, sizeof(float)); memcpy(params + 6, &freq_scale, sizeof(float)); memcpy(params + 7, &ext_factor, sizeof(float)); memcpy(params + 8, &attn_factor, sizeof(float)); memcpy(params + 9, &beta_fast, sizeof(float)); memcpy(params + 10, &beta_slow, sizeof(float)); - memcpy(params + 11, &xpos_base, sizeof(float)); - memcpy(params + 12, &xpos_down, sizeof(bool)); ggml_set_op_params(result, params, sizeof(params)); result->op = GGML_OP_ROPE_BACK; @@ -14227,8 +14201,7 @@ static float rope_yarn_ramp(const float low, const float high, const int i0) { // MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng. static void rope_yarn( float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor, float mscale, - float * cos_theta, float * sin_theta -) { + float * cos_theta, float * sin_theta) { // Get n-d rotational scaling corrected for extrapolation float theta_interp = freq_scale * theta_extrap; float theta = theta_interp; @@ -14245,18 +14218,19 @@ static void rope_yarn( // Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get // `corr_dim(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))` -static float ggml_rope_yarn_corr_dim(int n_dims, int n_orig_ctx, float n_rot, float base) { - return n_dims * logf(n_orig_ctx / (n_rot * 2 * (float)M_PI)) / (2 * logf(base)); +static float ggml_rope_yarn_corr_dim(int n_dims, int n_ctx_orig, float n_rot, float base) { + return n_dims * logf(n_ctx_orig / (n_rot * 2 * (float)M_PI)) / (2 * logf(base)); } static void ggml_rope_cache_init( - float theta_base, float freq_scale, float corr_dims[2], int64_t ne0, float ext_factor, float mscale, - float * cache, float sin_sign, float theta_scale -) { + float theta_base, float freq_scale, const float * freq_factors, float corr_dims[2], int64_t ne0, float ext_factor, float mscale, + float * cache, float sin_sign, float theta_scale) { + // ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py float theta = theta_base; for (int64_t i0 = 0; i0 < ne0; i0 += 2) { + const float ff = freq_factors ? freq_factors[i0/2] : 1.0f; rope_yarn( - theta, freq_scale, corr_dims, i0, ext_factor, mscale, &cache[i0 + 0], &cache[i0 + 1] + theta/ff, freq_scale, corr_dims, i0, ext_factor, mscale, &cache[i0 + 0], &cache[i0 + 1] ); cache[i0 + 1] *= sin_sign; @@ -14265,11 +14239,11 @@ static void ggml_rope_cache_init( } GGML_CALL void ggml_rope_yarn_corr_dims( - int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2] + int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow, float dims[2] ) { // start and end correction dims - float start = floorf(ggml_rope_yarn_corr_dim(n_dims, n_orig_ctx, beta_fast, freq_base)); - float end = ceilf(ggml_rope_yarn_corr_dim(n_dims, n_orig_ctx, beta_slow, freq_base)); + float start = floorf(ggml_rope_yarn_corr_dim(n_dims, n_ctx_orig, beta_fast, freq_base)); + float end = ceilf(ggml_rope_yarn_corr_dim(n_dims, n_ctx_orig, beta_slow, freq_base)); dims[0] = MAX(0, start); dims[1] = MIN(n_dims - 1, end); } @@ -14289,15 +14263,11 @@ static void ggml_compute_forward_rope_f32( float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow; - // these two only relevant for xPos RoPE: - float xpos_base; - bool xpos_down; - //const int n_past = ((int32_t *) dst->op_params)[0]; const int n_dims = ((int32_t *) dst->op_params)[1]; const int mode = ((int32_t *) dst->op_params)[2]; - const int n_ctx = ((int32_t *) dst->op_params)[3]; - const int n_orig_ctx = ((int32_t *) dst->op_params)[4]; + //const int n_ctx = ((int32_t *) dst->op_params)[3]; + const int n_ctx_orig = ((int32_t *) dst->op_params)[4]; memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float)); memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float)); @@ -14305,8 +14275,6 @@ static void ggml_compute_forward_rope_f32( memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float)); memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float)); memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float)); - memcpy(&xpos_base, (int32_t *) dst->op_params + 11, sizeof(float)); - memcpy(&xpos_down, (int32_t *) dst->op_params + 12, sizeof(bool)); GGML_TENSOR_UNARY_OP_LOCALS @@ -14336,20 +14304,15 @@ static void ggml_compute_forward_rope_f32( const float theta_scale = powf(freq_base, -2.0f/n_dims); float corr_dims[2]; - ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims); + ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims); const bool is_neox = mode & 2; - const bool is_glm = mode & 4; const float * freq_factors = NULL; - if (is_neox) { - if (src2 != NULL) { - GGML_ASSERT(src2->type == GGML_TYPE_F32); - GGML_ASSERT(src2->ne[0] >= n_dims / 2); - freq_factors = (const float *) src2->data; - } - } else { - GGML_ASSERT(src2 == NULL && "TODO: freq_factors not implemented for !is_neox"); + if (src2 != NULL) { + GGML_ASSERT(src2->type == GGML_TYPE_F32); + GGML_ASSERT(src2->ne[0] >= n_dims / 2); + freq_factors = (const float *) src2->data; } // backward process uses inverse rotation by cos and sin. @@ -14364,94 +14327,50 @@ static void ggml_compute_forward_rope_f32( const int64_t p = pos[i2]; float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith; - if (!is_glm && !is_neox) { // TODO: cache sin/cos for glm, neox - ggml_rope_cache_init(p, freq_scale, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale); - } + ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale); for (int64_t i1 = 0; i1 < ne1; i1++) { if (ir++ < ir0) continue; if (ir > ir1) break; - float theta_base = (float)p; - - if (is_glm) { - theta_base = MIN(p, n_ctx - 2); - float block_theta = MAX(p - (n_ctx - 2), 0); - for (int64_t i0 = 0; i0 < ne0 / 4; i0++) { - const float cos_theta = cosf(theta_base); - const float sin_theta = sinf(theta_base) * sin_sign; - const float cos_block_theta = cosf(block_theta); - const float sin_block_theta = sinf(block_theta) * sin_sign; - - theta_base *= theta_scale; - block_theta *= theta_scale; - - const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); - float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - const float x0 = src[0]; - const float x1 = src[n_dims/2]; - const float x2 = src[n_dims]; - const float x3 = src[n_dims/2*3]; - - dst_data[0] = x0*cos_theta - x1*sin_theta; - dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta; - dst_data[n_dims] = x2*cos_block_theta - x3*sin_block_theta; - dst_data[n_dims/2*3] = x2*sin_block_theta + x3*cos_block_theta; - } - } else if (!is_neox) { - for (int64_t i0 = 0; i0 < ne0; i0 += 2) { + if (!is_neox) { + for (int64_t i0 = 0; i0 < n_dims; i0 += 2) { const float cos_theta = cache[i0 + 0]; const float sin_theta = cache[i0 + 1]; - // zeta scaling for xPos only: - float zeta = xpos_base != 0.0f ? powf((i0 + 0.4f * ne0) / (1.4f * ne0), p / xpos_base) : 1.0f; - if (xpos_down) zeta = 1.0f / zeta; - const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); const float x0 = src[0]; const float x1 = src[1]; - dst_data[0] = x0*cos_theta*zeta - x1*sin_theta*zeta; - dst_data[1] = x0*sin_theta*zeta + x1*cos_theta*zeta; + dst_data[0] = x0*cos_theta - x1*sin_theta; + dst_data[1] = x0*sin_theta + x1*cos_theta; } } else { - // ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py - for (int64_t ic = 0; ic < ne0; ic += 2) { - if (ic < n_dims) { - const int64_t i0 = ic/2; - - const float freq_factor = freq_factors ? freq_factors[i0] : 1.0f; - - float cos_theta, sin_theta; - rope_yarn( - theta_base/freq_factor, freq_scale, corr_dims, ic, ext_factor, attn_factor, - &cos_theta, &sin_theta - ); + for (int64_t i0 = 0; i0 < n_dims; i0 += 2) { + const int64_t ic = i0/2; - sin_theta *= sin_sign; - theta_base *= theta_scale; + const float cos_theta = cache[i0 + 0]; + const float sin_theta = cache[i0 + 1]; - const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); - float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00); + float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0); - const float x0 = src[0]; - const float x1 = src[n_dims/2]; + const float x0 = src[0]; + const float x1 = src[n_dims/2]; - dst_data[0] = x0*cos_theta - x1*sin_theta; - dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta; - } else { - const int64_t i0 = ic; + dst_data[0] = x0*cos_theta - x1*sin_theta; + dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta; + } + } - const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); - float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) { + const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - dst_data[0] = src[0]; - dst_data[1] = src[1]; - } - } + dst_data[0] = src[0]; + dst_data[1] = src[1]; } } } @@ -14477,8 +14396,8 @@ static void ggml_compute_forward_rope_f16( //const int n_past = ((int32_t *) dst->op_params)[0]; const int n_dims = ((int32_t *) dst->op_params)[1]; const int mode = ((int32_t *) dst->op_params)[2]; - const int n_ctx = ((int32_t *) dst->op_params)[3]; - const int n_orig_ctx = ((int32_t *) dst->op_params)[4]; + //const int n_ctx = ((int32_t *) dst->op_params)[3]; + const int n_ctx_orig = ((int32_t *) dst->op_params)[4]; memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float)); memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float)); memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float)); @@ -14514,20 +14433,15 @@ static void ggml_compute_forward_rope_f16( const float theta_scale = powf(freq_base, -2.0f/n_dims); float corr_dims[2]; - ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims); + ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims); const bool is_neox = mode & 2; - const bool is_glm = mode & 4; const float * freq_factors = NULL; - if (is_neox) { - if (src2 != NULL) { - GGML_ASSERT(src2->type == GGML_TYPE_F32); - GGML_ASSERT(src2->ne[0] >= n_dims / 2); - freq_factors = (const float *) src2->data; - } - } else { - GGML_ASSERT(src2 == NULL && "TODO: freq_factors not implemented for !is_neox"); + if (src2 != NULL) { + GGML_ASSERT(src2->type == GGML_TYPE_F32); + GGML_ASSERT(src2->ne[0] >= n_dims / 2); + freq_factors = (const float *) src2->data; } // backward process uses inverse rotation by cos and sin. @@ -14542,43 +14456,14 @@ static void ggml_compute_forward_rope_f16( const int64_t p = pos[i2]; float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith; - if (!is_glm && !is_neox) { // TODO: cache sin/cos for glm, neox - ggml_rope_cache_init(p, freq_scale, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale); - } + ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale); for (int64_t i1 = 0; i1 < ne1; i1++) { if (ir++ < ir0) continue; if (ir > ir1) break; - float theta_base = (float)p; - - if (is_glm) { - theta_base = MIN(p, n_ctx - 2); - float block_theta = MAX(p - (n_ctx - 2), 0); - for (int64_t i0 = 0; i0 < ne0 / 4; i0++) { - const float cos_theta = cosf(theta_base); - const float sin_theta = sinf(theta_base) * sin_sign; - const float cos_block_theta = cosf(block_theta); - const float sin_block_theta = sinf(block_theta) * sin_sign; - - theta_base *= theta_scale; - block_theta *= theta_scale; - - const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); - ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - const float x0 = GGML_FP16_TO_FP32(src[0]); - const float x1 = GGML_FP16_TO_FP32(src[n_dims/2]); - const float x2 = GGML_FP16_TO_FP32(src[n_dims]); - const float x3 = GGML_FP16_TO_FP32(src[n_dims/2*3]); - - dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta); - dst_data[n_dims/2] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta); - dst_data[n_dims] = GGML_FP32_TO_FP16(x2*cos_block_theta - x3*sin_block_theta); - dst_data[n_dims/2*3] = GGML_FP32_TO_FP16(x2*sin_block_theta + x3*cos_block_theta); - } - } else if (!is_neox) { - for (int64_t i0 = 0; i0 < ne0; i0 += 2) { + if (!is_neox) { + for (int64_t i0 = 0; i0 < n_dims; i0 += 2) { const float cos_theta = cache[i0 + 0]; const float sin_theta = cache[i0 + 1]; @@ -14592,40 +14477,29 @@ static void ggml_compute_forward_rope_f16( dst_data[1] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta); } } else { - // ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py - for (int64_t ic = 0; ic < ne0; ic += 2) { - if (ic < n_dims) { - const int64_t i0 = ic/2; - - const float freq_factor = freq_factors ? freq_factors[i0] : 1.0f; + for (int64_t i0 = 0; i0 < n_dims; i0 += 2) { + const int64_t ic = i0/2; - float cos_theta, sin_theta; - rope_yarn( - theta_base/freq_factor, freq_scale, corr_dims, ic, ext_factor, attn_factor, - &cos_theta, &sin_theta - ); - - sin_theta *= sin_sign; - theta_base *= theta_scale; + const float cos_theta = cache[i0 + 0]; + const float sin_theta = cache[i0 + 1]; - const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); - ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00); + ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0); - const float x0 = GGML_FP16_TO_FP32(src[0]); - const float x1 = GGML_FP16_TO_FP32(src[n_dims/2]); + const float x0 = GGML_FP16_TO_FP32(src[0]); + const float x1 = GGML_FP16_TO_FP32(src[n_dims/2]); - dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta); - dst_data[n_dims/2] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta); - } else { - const int64_t i0 = ic; + dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta); + dst_data[n_dims/2] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta); + } + } - const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); - ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) { + const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - dst_data[0] = src[0]; - dst_data[1] = src[1]; - } - } + dst_data[0] = src[0]; + dst_data[1] = src[1]; } } } @@ -18327,9 +18201,9 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor //const int n_past = ((int32_t *) tensor->op_params)[0]; const int n_dims = ((int32_t *) tensor->op_params)[1]; const int mode = ((int32_t *) tensor->op_params)[2]; - const int n_ctx = ((int32_t *) tensor->op_params)[3]; - const int n_orig_ctx = ((int32_t *) tensor->op_params)[4]; - float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow, xpos_base, xpos_down; + //const int n_ctx = ((int32_t *) tensor->op_params)[3]; + const int n_ctx_orig = ((int32_t *) tensor->op_params)[4]; + float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow; memcpy(&freq_base, (int32_t *) tensor->op_params + 5, sizeof(float)); memcpy(&freq_scale, (int32_t *) tensor->op_params + 6, sizeof(float)); @@ -18337,8 +18211,6 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor memcpy(&attn_factor, (int32_t *) tensor->op_params + 8, sizeof(float)); memcpy(&beta_fast, (int32_t *) tensor->op_params + 9, sizeof(float)); memcpy(&beta_slow, (int32_t *) tensor->op_params + 10, sizeof(float)); - memcpy(&xpos_base, (int32_t *) tensor->op_params + 11, sizeof(float)); - memcpy(&xpos_down, (int32_t *) tensor->op_params + 12, sizeof(bool)); src0->grad = ggml_add_or_set(ctx, src0->grad, @@ -18348,16 +18220,13 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor src2, n_dims, mode, - n_ctx, - n_orig_ctx, + n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, - beta_slow, - xpos_base, - xpos_down), + beta_slow), zero_table); } } break; @@ -18367,9 +18236,9 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor //const int n_past = ((int32_t *) tensor->op_params)[0]; const int n_dims = ((int32_t *) tensor->op_params)[1]; const int mode = ((int32_t *) tensor->op_params)[2]; - const int n_ctx = ((int32_t *) tensor->op_params)[3]; - const int n_orig_ctx = ((int32_t *) tensor->op_params)[4]; - float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow, xpos_base, xpos_down; + //const int n_ctx = ((int32_t *) tensor->op_params)[3]; + const int n_ctx_orig = ((int32_t *) tensor->op_params)[4]; + float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow; memcpy(&freq_base, (int32_t *) tensor->op_params + 5, sizeof(float)); memcpy(&freq_scale, (int32_t *) tensor->op_params + 6, sizeof(float)); @@ -18377,8 +18246,6 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor memcpy(&attn_factor, (int32_t *) tensor->op_params + 8, sizeof(float)); memcpy(&beta_fast, (int32_t *) tensor->op_params + 9, sizeof(float)); memcpy(&beta_slow, (int32_t *) tensor->op_params + 10, sizeof(float)); - memcpy(&xpos_base, (int32_t *) tensor->op_params + 11, sizeof(float)); - memcpy(&xpos_down, (int32_t *) tensor->op_params + 12, sizeof(bool)); src0->grad = ggml_add_or_set(ctx, src0->grad, @@ -18388,16 +18255,13 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor src2, n_dims, mode, - n_ctx, - n_orig_ctx, + n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow, - xpos_base, - xpos_down, false), zero_table); } |