summary | refs | log | tree | commit | diff
path: root/ggml.c
diff options
context:
space:
mode:
Diffstat (limited to 'ggml.c')
-rw-r--r-- ggml.c 20
1 files changed, 18 insertions, 2 deletions
diff --git a/ggml.c b/ggml.c
index 37b16b7a..d316e3d3 100644
--- a/ggml.c
+++ b/ggml.c
@@ -6245,6 +6245,8 @@ static struct ggml_tensor * ggml_rope_impl(
float xpos_base,
bool xpos_down,
bool inplace) {
+ GGML_ASSERT((mode & 1) == 0 && "mode & 1 == 1 is no longer supported");
+
GGML_ASSERT(ggml_is_vector(b));
GGML_ASSERT(b->type == GGML_TYPE_I32);
GGML_ASSERT(a->ne[2] == b->ne[0]);
@@ -14413,7 +14415,7 @@ static void ggml_compute_forward_rope_f32(
freq_factors = (const float *) src2->data;
}
} else {
- GGML_ASSERT(src2 == NULL && "TODO: freq_factors not implemented for mode 1");
+ GGML_ASSERT(src2 == NULL && "TODO: freq_factors not implemented for !is_neox");
}
// backward process uses inverse rotation by cos and sin.
@@ -14529,6 +14531,7 @@ static void ggml_compute_forward_rope_f32(
}
}
+// TODO: deduplicate f16/f32 code
static void ggml_compute_forward_rope_f16(
const struct ggml_compute_params * params,
struct ggml_tensor * dst,
@@ -14536,6 +14539,7 @@ static void ggml_compute_forward_rope_f16(
const struct ggml_tensor * src0 = dst->src[0];
const struct ggml_tensor * src1 = dst->src[1];
+ const struct ggml_tensor * src2 = dst->src[2];
if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
return;
@@ -14588,6 +14592,17 @@ static void ggml_compute_forward_rope_f16(
const bool is_neox = mode & 2;
const bool is_glm = mode & 4;
+ const float * freq_factors = NULL;
+ if (is_neox) {
+ if (src2 != NULL) {
+ GGML_ASSERT(src2->type == GGML_TYPE_F32);
+ GGML_ASSERT(src2->ne[0] >= n_dims / 2);
+ freq_factors = (const float *) src2->data;
+ }
+ } else {
+ GGML_ASSERT(src2 == NULL && "TODO: freq_factors not implemented for !is_neox");
+ }
+
// backward process uses inverse rotation by cos and sin.
// cos and sin build a rotation matrix, where the inverse is the transpose.
// this essentially just switches the sign of sin.
@@ -14660,10 +14675,11 @@ static void ggml_compute_forward_rope_f16(
// simplified from `(ib * n_dims + ic) * inv_ndims`
float cur_rot = inv_ndims * ic - ib;
+ float freq_factor = freq_factors ? freq_factors[ic/2] : 1.0f;
float cos_theta, sin_theta;
rope_yarn(
- theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor,
+ theta_base/freq_factor, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor,
&cos_theta, &sin_theta
);
sin_theta *= sin_sign;