summaryrefslogtreecommitdiff
path: root/ggml.c
diff options
context:
space:
mode:
Diffstat (limited to 'ggml.c')
-rw-r--r--ggml.c241
1 files changed, 175 insertions, 66 deletions
diff --git a/ggml.c b/ggml.c
index 80d68225..2c7fe476 100644
--- a/ggml.c
+++ b/ggml.c
@@ -1,4 +1,5 @@
#define _CRT_SECURE_NO_DEPRECATE // Disables ridiculous "unsafe" warnigns on Windows
+#define _USE_MATH_DEFINES // For M_PI on MSVC
#include "ggml-impl.h"
#include "ggml-quants.h"
@@ -4845,8 +4846,13 @@ static struct ggml_tensor * ggml_rope_impl(
int n_dims,
int mode,
int n_ctx,
+ int n_orig_ctx,
float freq_base,
float freq_scale,
+ float ext_factor,
+ float attn_factor,
+ float beta_fast,
+ float beta_slow,
float xpos_base,
bool xpos_down,
bool inplace) {
@@ -4862,11 +4868,15 @@ static struct ggml_tensor * ggml_rope_impl(
struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
- int32_t params[8] = { /*n_past*/ 0, n_dims, mode, n_ctx };
- memcpy(params + 4, &freq_base, sizeof(float));
- memcpy(params + 5, &freq_scale, sizeof(float));
- memcpy(params + 6, &xpos_base, sizeof(float));
- memcpy(params + 7, &xpos_down, sizeof(bool));
+ int32_t params[13] = { /*n_past*/ 0, n_dims, mode, n_ctx, n_orig_ctx };
+ memcpy(params + 5, &freq_base, sizeof(float));
+ memcpy(params + 6, &freq_scale, sizeof(float));
+ memcpy(params + 7, &ext_factor, sizeof(float));
+ memcpy(params + 8, &attn_factor, sizeof(float));
+ memcpy(params + 9, &beta_fast, sizeof(float));
+ memcpy(params + 10, &beta_slow, sizeof(float));
+ memcpy(params + 11, &xpos_base, sizeof(float));
+ memcpy(params + 12, &xpos_down, sizeof(bool));
ggml_set_op_params(result, params, sizeof(params));
result->op = GGML_OP_ROPE;
@@ -4884,7 +4894,9 @@ struct ggml_tensor * ggml_rope(
int n_dims,
int mode,
int n_ctx) {
- return ggml_rope_impl(ctx, a, b, n_dims, mode, n_ctx, 10000.0f, 1.0f, 0.0f, false, false);
+ return ggml_rope_impl(
+ ctx, a, b, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, false
+ );
}
struct ggml_tensor * ggml_rope_inplace(
@@ -4894,7 +4906,9 @@ struct ggml_tensor * ggml_rope_inplace(
int n_dims,
int mode,
int n_ctx) {
- return ggml_rope_impl(ctx, a, b, n_dims, mode, n_ctx, 10000.0f, 1.0f, 0.0f, false, true);
+ return ggml_rope_impl(
+ ctx, a, b, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, true
+ );
}
struct ggml_tensor * ggml_rope_custom(
@@ -4904,9 +4918,17 @@ struct ggml_tensor * ggml_rope_custom(
int n_dims,
int mode,
int n_ctx,
+ int n_orig_ctx,
float freq_base,
- float freq_scale) {
- return ggml_rope_impl(ctx, a, b, n_dims, mode, n_ctx, freq_base, freq_scale, 0.0f, false, false);
+ float freq_scale,
+ float ext_factor,
+ float attn_factor,
+ float beta_fast,
+ float beta_slow) {
+ return ggml_rope_impl(
+ ctx, a, b, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
+ ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, false
+ );
}
struct ggml_tensor * ggml_rope_custom_inplace(
@@ -4916,9 +4938,17 @@ struct ggml_tensor * ggml_rope_custom_inplace(
int n_dims,
int mode,
int n_ctx,
+ int n_orig_ctx,
float freq_base,
- float freq_scale) {
- return ggml_rope_impl(ctx, a, b, n_dims, mode, n_ctx, freq_base, freq_scale, 0.0f, false, true);
+ float freq_scale,
+ float ext_factor,
+ float attn_factor,
+ float beta_fast,
+ float beta_slow) {
+ return ggml_rope_impl(
+ ctx, a, b, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
+ ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, true
+ );
}
struct ggml_tensor * ggml_rope_xpos_inplace(
@@ -4928,7 +4958,7 @@ struct ggml_tensor * ggml_rope_xpos_inplace(
int n_dims,
float base,
bool down) {
- return ggml_rope_impl(ctx, a, b, n_dims, 0, 0, 10000.0f, 1.0f, base, down, true);
+ return ggml_rope_impl(ctx, a, b, n_dims, 0, 0, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, base, down, true);
}
// ggml_rope_back
@@ -10901,6 +10931,45 @@ static void ggml_compute_forward_clamp(
// ggml_compute_forward_rope
+static float rope_yarn_ramp(const float low, const float high, const int i0) {
+ const float y = (i0 / 2 - low) / MAX(0.001f, high - low);
+ return 1 - MIN(1, MAX(0, y));
+}
+
+// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
+// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
+static void rope_yarn(
+ float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor, float mscale,
+ float * cos_theta, float * sin_theta
+) {
+ // Get n-d rotational scaling corrected for extrapolation
+ float theta_interp = freq_scale * theta_extrap;
+ float theta = theta_interp;
+ if (ext_factor != 0.0f) {
+ float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
+ theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
+
+ // Get n-d magnitude scaling corrected for interpolation
+ mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale);
+ }
+ *cos_theta = cosf(theta) * mscale;
+ *sin_theta = sinf(theta) * mscale;
+}
+
+// Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get
+// `corr_dim(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
+static float ggml_rope_yarn_corr_dim(int n_dims, int n_orig_ctx, float n_rot, float base) {
+ return n_dims * logf(n_orig_ctx / (n_rot * 2 * (float)M_PI)) / (2 * logf(base));
+}
+
+void ggml_rope_yarn_corr_dims(
+ int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2]
+) {
+ // start and end correction dims
+ dims[0] = MAX(0, floorf(ggml_rope_yarn_corr_dim(n_dims, n_orig_ctx, beta_fast, freq_base)));
+ dims[1] = MIN(n_dims - 1, ceilf(ggml_rope_yarn_corr_dim(n_dims, n_orig_ctx, beta_slow, freq_base)));
+}
+
static void ggml_compute_forward_rope_f32(
const struct ggml_compute_params * params,
const struct ggml_tensor * src0,
@@ -10910,21 +10979,26 @@ static void ggml_compute_forward_rope_f32(
return;
}
- float freq_base;
- float freq_scale;
+ float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
// these two only relevant for xPos RoPE:
float xpos_base;
bool xpos_down;
- //const int n_past = ((int32_t *) dst->op_params)[0];
- const int n_dims = ((int32_t *) dst->op_params)[1];
- const int mode = ((int32_t *) dst->op_params)[2];
- const int n_ctx = ((int32_t *) dst->op_params)[3];
- memcpy(&freq_base, (int32_t *) dst->op_params + 4, sizeof(float));
- memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));
- memcpy(&xpos_base, (int32_t *) dst->op_params + 6, sizeof(float));
- memcpy(&xpos_down, (int32_t *) dst->op_params + 7, sizeof(bool));
+ //const int n_past = ((int32_t *) dst->op_params)[0];
+ const int n_dims = ((int32_t *) dst->op_params)[1];
+ const int mode = ((int32_t *) dst->op_params)[2];
+ const int n_ctx = ((int32_t *) dst->op_params)[3];
+ const int n_orig_ctx = ((int32_t *) dst->op_params)[4];
+
+ memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
+ memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
+ memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
+ memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
+ memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
+ memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
+ memcpy(&xpos_base, (int32_t *) dst->op_params + 11, sizeof(float));
+ memcpy(&xpos_down, (int32_t *) dst->op_params + 12, sizeof(bool));
GGML_TENSOR_UNARY_OP_LOCALS
@@ -10952,6 +11026,9 @@ static void ggml_compute_forward_rope_f32(
int ir = 0;
const float theta_scale = powf(freq_base, -2.0f/n_dims);
+ const float inv_ndims = -1.f/n_dims;
+ float corr_dims[2];
+ ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims);
const bool is_neox = mode & 2;
const bool is_glm = mode & 4;
@@ -10965,18 +11042,18 @@ static void ggml_compute_forward_rope_f32(
if (ir++ < ir0) continue;
if (ir > ir1) break;
- float theta = freq_scale * (float)p;
+ float theta_base = (float)p;
if (is_glm) {
- theta = MIN(p, n_ctx - 2);
+ theta_base = MIN(p, n_ctx - 2);
float block_theta = MAX(p - (n_ctx - 2), 0);
for (int64_t i0 = 0; i0 < ne0 / 4; i0++) {
- const float cos_theta = cosf(theta);
- const float sin_theta = sinf(theta);
+ const float cos_theta = cosf(theta_base);
+ const float sin_theta = sinf(theta_base);
const float cos_block_theta = cosf(block_theta);
const float sin_block_theta = sinf(block_theta);
- theta *= theta_scale;
+ theta_base *= theta_scale;
block_theta *= theta_scale;
const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
@@ -10994,13 +11071,16 @@ static void ggml_compute_forward_rope_f32(
}
} else if (!is_neox) {
for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
- const float cos_theta = cosf(theta);
- const float sin_theta = sinf(theta);
+ float cos_theta, sin_theta;
+ rope_yarn(
+ theta_base, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta
+ );
+
// zeta scaling for xPos only:
float zeta = xpos_base != 0.0f ? powf((i0 + 0.4f * ne0) / (1.4f * ne0), p / xpos_base) : 1.0f;
if (xpos_down) zeta = 1.0f / zeta;
- theta *= theta_scale;
+ theta_base *= theta_scale;
const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
@@ -11014,12 +11094,19 @@ static void ggml_compute_forward_rope_f32(
} else {
// TODO: this might be wrong for ne0 != n_dims - need double check
// ref: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py#LL251C1-L294C28
+ theta_base *= freq_scale;
for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
for (int64_t ic = 0; ic < n_dims; ic += 2) {
- const float cos_theta = cosf(theta);
- const float sin_theta = sinf(theta);
+ // simplified from `(ib * n_dims + ic) * inv_ndims`
+ float cur_rot = inv_ndims * ic - ib;
+
+ float cos_theta, sin_theta;
+ rope_yarn(
+ theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor,
+ &cos_theta, &sin_theta
+ );
- theta *= theta_scale;
+ theta_base *= theta_scale;
const int64_t i0 = ib*n_dims + ic/2;
@@ -11048,15 +11135,19 @@ static void ggml_compute_forward_rope_f16(
return;
}
- float freq_base;
- float freq_scale;
+ float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
- //const int n_past = ((int32_t *) dst->op_params)[0];
- const int n_dims = ((int32_t *) dst->op_params)[1];
- const int mode = ((int32_t *) dst->op_params)[2];
- const int n_ctx = ((int32_t *) dst->op_params)[3];
- memcpy(&freq_base, (int32_t *) dst->op_params + 4, sizeof(float));
- memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));
+ //const int n_past = ((int32_t *) dst->op_params)[0];
+ const int n_dims = ((int32_t *) dst->op_params)[1];
+ const int mode = ((int32_t *) dst->op_params)[2];
+ const int n_ctx = ((int32_t *) dst->op_params)[3];
+ const int n_orig_ctx = ((int32_t *) dst->op_params)[4];
+ memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
+ memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
+ memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
+ memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
+ memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
+ memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
GGML_TENSOR_UNARY_OP_LOCALS
@@ -11084,6 +11175,9 @@ static void ggml_compute_forward_rope_f16(
int ir = 0;
const float theta_scale = powf(freq_base, -2.0f/n_dims);
+ const float inv_ndims = -1.f/n_dims;
+ float corr_dims[2];
+ ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims);
const bool is_neox = mode & 2;
const bool is_glm = mode & 4;
@@ -11097,18 +11191,18 @@ static void ggml_compute_forward_rope_f16(
if (ir++ < ir0) continue;
if (ir > ir1) break;
- float theta = freq_scale * (float)p;
+ float theta_base = (float)p;
if (is_glm) {
- theta = MIN(p, n_ctx - 2);
+ theta_base = MIN(p, n_ctx - 2);
float block_theta = MAX(p - (n_ctx - 2), 0);
for (int64_t i0 = 0; i0 < ne0 / 4; i0++) {
- const float cos_theta = cosf(theta);
- const float sin_theta = sinf(theta);
+ const float cos_theta = cosf(theta_base);
+ const float sin_theta = sinf(theta_base);
const float cos_block_theta = cosf(block_theta);
const float sin_block_theta = sinf(block_theta);
- theta *= theta_scale;
+ theta_base *= theta_scale;
block_theta *= theta_scale;
const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
@@ -11126,10 +11220,12 @@ static void ggml_compute_forward_rope_f16(
}
} else if (!is_neox) {
for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
- const float cos_theta = cosf(theta);
- const float sin_theta = sinf(theta);
+ float cos_theta, sin_theta;
+ rope_yarn(
+ theta_base, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta
+ );
- theta *= theta_scale;
+ theta_base *= theta_scale;
const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
@@ -11143,12 +11239,19 @@ static void ggml_compute_forward_rope_f16(
} else {
// TODO: this might be wrong for ne0 != n_dims - need double check
// ref: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py#LL251C1-L294C28
+ theta_base *= freq_scale;
for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
for (int64_t ic = 0; ic < n_dims; ic += 2) {
- const float cos_theta = cosf(theta);
- const float sin_theta = sinf(theta);
+ // simplified from `(ib * n_dims + ic) * inv_ndims`
+ float cur_rot = inv_ndims * ic - ib;
- theta *= theta_scale;
+ float cos_theta, sin_theta;
+ rope_yarn(
+ theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor,
+ &cos_theta, &sin_theta
+ );
+
+ theta_base *= theta_scale;
const int64_t i0 = ib*n_dims + ic/2;
@@ -11256,17 +11359,18 @@ static void ggml_compute_forward_rope_back_f32(
if (ir++ < ir0) continue;
if (ir > ir1) break;
- float theta = freq_scale * (float)p;
+ float theta_base = freq_scale * (float)p;
if (!is_neox) {
for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
- const float cos_theta = cosf(theta);
- const float sin_theta = sinf(theta);
+ const float cos_theta = cosf(theta_base);
+ const float sin_theta = sinf(theta_base);
+
// zeta scaling for xPos only:
float zeta = xpos_base != 0.0f ? powf((i0 + 0.4f * ne0) / (1.4f * ne0), p / xpos_base) : 1.0f;
if (xpos_down) zeta = 1.0f / zeta;
- theta *= theta_scale;
+ theta_base *= theta_scale;
const float * const dy = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
float * dx = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
@@ -11280,10 +11384,10 @@ static void ggml_compute_forward_rope_back_f32(
} else {
for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
for (int64_t ic = 0; ic < n_dims; ic += 2) {
- const float cos_theta = cosf(theta);
- const float sin_theta = sinf(theta);
+ const float cos_theta = cosf(theta_base);
+ const float sin_theta = sinf(theta_base);
- theta *= theta_scale;
+ theta_base *= theta_scale;
const int64_t i0 = ib*n_dims + ic/2;
@@ -11356,14 +11460,14 @@ static void ggml_compute_forward_rope_back_f16(
if (ir++ < ir0) continue;
if (ir > ir1) break;
- float theta = (float)p;
+ float theta_base = (float)p;
if (!is_neox) {
for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
- const float cos_theta = cosf(theta);
- const float sin_theta = sinf(theta);
+ const float cos_theta = cosf(theta_base);
+ const float sin_theta = sinf(theta_base);
- theta *= theta_scale;
+ theta_base *= theta_scale;
const ggml_fp16_t * const dy = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
ggml_fp16_t * dx = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
@@ -11377,10 +11481,10 @@ static void ggml_compute_forward_rope_back_f16(
} else {
for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
for (int64_t ic = 0; ic < n_dims; ic += 2) {
- const float cos_theta = cosf(theta);
- const float sin_theta = sinf(theta);
+ const float cos_theta = cosf(theta_base);
+ const float sin_theta = sinf(theta_base);
- theta *= theta_scale;
+ theta_base *= theta_scale;
const int64_t i0 = ib*n_dims + ic/2;
@@ -15505,9 +15609,14 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
src1,
n_dims,
mode,
+ 0,
n_ctx,
freq_base,
freq_scale,
+ 0.0f,
+ 1.0f,
+ 0.0f,
+ 0.0f,
xpos_base,
xpos_down,
false),