Diffstat (limited to 'ggml.c')
 ggml.c | 80
 1 file changed, 68 insertions(+), 12 deletions(-)
diff --git a/ggml.c b/ggml.c
index 4bd91152..37b16b7a 100644
--- a/ggml.c
+++ b/ggml.c
@@ -6231,6 +6231,7 @@ static struct ggml_tensor * ggml_rope_impl(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b,
+ struct ggml_tensor * c,
int n_dims,
int mode,
int n_ctx,
@@ -6248,6 +6249,11 @@ static struct ggml_tensor * ggml_rope_impl(
GGML_ASSERT(b->type == GGML_TYPE_I32);
GGML_ASSERT(a->ne[2] == b->ne[0]);
+ if (c) {
+ GGML_ASSERT(c->type == GGML_TYPE_F32);
+ GGML_ASSERT(c->ne[0] >= n_dims / 2);
+ }
+
bool is_node = false;
if (a->grad) {
@@ -6271,6 +6277,7 @@ static struct ggml_tensor * ggml_rope_impl(
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
result->src[0] = a;
result->src[1] = b;
+ result->src[2] = c;
return result;
}
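(For orientation: the new optional `c` tensor carries per-dimension frequency factors. A minimal sketch of building one, assuming a live `ggml_context` with host-writable tensor data; the variable names are illustrative, not part of this patch:)

    // Sketch: an optional freq-factors tensor for the new `c` argument.
    // It must be GGML_TYPE_F32 with at least n_dims/2 elements, matching
    // the asserts added in ggml_rope_impl above.
    struct ggml_tensor * freq_factors = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_dims/2);
    for (int i = 0; i < n_dims/2; ++i) {
        ((float *) freq_factors->data)[i] = 1.0f; // 1.0f == no change; a model would supply trained values
    }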
@@ -6283,7 +6290,7 @@ struct ggml_tensor * ggml_rope(
int mode,
int n_ctx) {
return ggml_rope_impl(
- ctx, a, b, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, false
+ ctx, a, b, NULL, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, false
);
}
@@ -6295,14 +6302,15 @@ struct ggml_tensor * ggml_rope_inplace(
int mode,
int n_ctx) {
return ggml_rope_impl(
- ctx, a, b, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, true
+ ctx, a, b, NULL, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, true
);
}
-struct ggml_tensor * ggml_rope_custom(
+struct ggml_tensor * ggml_rope_ext(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b,
+ struct ggml_tensor * c,
int n_dims,
int mode,
int n_ctx,
@@ -6314,15 +6322,16 @@ struct ggml_tensor * ggml_rope_custom(
float beta_fast,
float beta_slow) {
return ggml_rope_impl(
- ctx, a, b, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
+ ctx, a, b, c, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, false
);
}
-struct ggml_tensor * ggml_rope_custom_inplace(
+struct ggml_tensor * ggml_rope_ext_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b,
+ struct ggml_tensor * c,
int n_dims,
int mode,
int n_ctx,
@@ -6334,19 +6343,49 @@ struct ggml_tensor * ggml_rope_custom_inplace(
float beta_fast,
float beta_slow) {
return ggml_rope_impl(
- ctx, a, b, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
+ ctx, a, b, c, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, true
);
}
-struct ggml_tensor * ggml_rope_xpos_inplace(
+struct ggml_tensor * ggml_rope_custom(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b,
int n_dims,
- float base,
- bool down) {
- return ggml_rope_impl(ctx, a, b, n_dims, 0, 0, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, base, down, true);
+ int mode,
+ int n_ctx,
+ int n_orig_ctx,
+ float freq_base,
+ float freq_scale,
+ float ext_factor,
+ float attn_factor,
+ float beta_fast,
+ float beta_slow) {
+ return ggml_rope_impl(
+ ctx, a, b, NULL, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
+ ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, false
+ );
+}
+
+struct ggml_tensor * ggml_rope_custom_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ int n_dims,
+ int mode,
+ int n_ctx,
+ int n_orig_ctx,
+ float freq_base,
+ float freq_scale,
+ float ext_factor,
+ float attn_factor,
+ float beta_fast,
+ float beta_slow) {
+ return ggml_rope_impl(
+ ctx, a, b, NULL, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
+ ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, true
+ );
}
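(With the rename, `ggml_rope_ext` is the full-featured entry point and the re-added `ggml_rope_custom` wrappers keep the old signature by forwarding `c = NULL`. A hedged usage sketch; the tensors and hyperparameter values here are placeholders, not values from this patch:)

    // Equivalent apart from the factors:
    // ggml_rope_custom(...) == ggml_rope_ext(..., /*c =*/ NULL, ...).
    struct ggml_tensor * r0 = ggml_rope_custom(ctx, q, pos, n_dims, 2, n_ctx, n_orig_ctx,
            10000.0f, 1.0f, 0.0f, 1.0f, 32.0f, 1.0f);
    struct ggml_tensor * r1 = ggml_rope_ext(ctx, q, pos, freq_factors, n_dims, 2, n_ctx, n_orig_ctx,
            10000.0f, 1.0f, 0.0f, 1.0f, 32.0f, 1.0f);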
// ggml_rope_back
@@ -6355,6 +6394,7 @@ struct ggml_tensor * ggml_rope_back(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b,
+ struct ggml_tensor * c,
int n_dims,
int mode,
int n_ctx,
@@ -6370,6 +6410,7 @@ struct ggml_tensor * ggml_rope_back(
GGML_ASSERT(ggml_is_vector(b));
GGML_ASSERT(b->type == GGML_TYPE_I32);
GGML_ASSERT(a->ne[2] == b->ne[0]);
+ GGML_ASSERT(c == NULL && "freq factors not implemented yet");
GGML_ASSERT((mode & 4) == 0 && "ggml_rope_back() for ChatGLM not implemented yet");
@@ -14304,6 +14345,7 @@ static void ggml_compute_forward_rope_f32(
const struct ggml_tensor * src0 = dst->src[0];
const struct ggml_tensor * src1 = dst->src[1];
+ const struct ggml_tensor * src2 = dst->src[2];
if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
return;
@@ -14363,6 +14405,17 @@ static void ggml_compute_forward_rope_f32(
const bool is_neox = mode & 2;
const bool is_glm = mode & 4;
+ const float * freq_factors = NULL;
+ if (is_neox) {
+ if (src2 != NULL) {
+ GGML_ASSERT(src2->type == GGML_TYPE_F32);
+ GGML_ASSERT(src2->ne[0] >= n_dims / 2);
+ freq_factors = (const float *) src2->data;
+ }
+ } else {
+ GGML_ASSERT(src2 == NULL && "TODO: freq_factors not implemented for mode 1");
+ }
+
// backward process uses inverse rotation by cos and sin.
// cos and sin build a rotation matrix, where the inverse is the transpose.
// this essentially just switches the sign of sin.
@@ -14439,10 +14492,11 @@ static void ggml_compute_forward_rope_f32(
// simplified from `(ib * n_dims + ic) * inv_ndims`
float cur_rot = inv_ndims * ic - ib;
+ float freq_factor = freq_factors ? freq_factors[ic/2] : 1.0f;
float cos_theta, sin_theta;
rope_yarn(
- theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor,
+ theta_base/freq_factor, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor,
&cos_theta, &sin_theta
);
sin_theta *= sin_sign;
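(The rotation itself changes in exactly one place: each pair's base angle is divided by its factor before `rope_yarn`, so a factor of 1.0f reproduces the old behavior and factors > 1 slow the corresponding frequency. Schematically:)

    // Schematic (NeoX path only; see the freq_factors setup above):
    //   theta'(ic) = theta_base(ic) / freq_factors[ic/2]
    // With freq_factors == NULL every factor is 1.0f, so theta' == theta_base
    // and the pre-patch rotation is reproduced exactly.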
@@ -18387,6 +18441,7 @@ static struct ggml_tensor * ggml_sub_or_set(struct ggml_context * ctx, struct gg
static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor * tensor, struct ggml_hash_set zero_table) {
struct ggml_tensor * src0 = tensor->src[0];
struct ggml_tensor * src1 = tensor->src[1];
+ struct ggml_tensor * src2 = tensor->src[2];
switch (tensor->op) {
case GGML_OP_DUP:
@@ -18918,6 +18973,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
ggml_rope_back(ctx,
tensor->grad,
src1,
+ src2,
n_dims,
mode,
n_ctx,
@@ -18957,6 +19013,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
ggml_rope_impl(ctx,
tensor->grad,
src1,
+ src2,
n_dims,
mode,
n_ctx,
@@ -19038,7 +19095,6 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
masked);
}
- struct ggml_tensor * src2 = tensor->src[2];
const int64_t elem_q = ggml_nelements(src0);
const int64_t elem_k = ggml_nelements(src1);
const int64_t elem_v = ggml_nelements(src2);
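(Putting it together, a hedged end-to-end sketch of the new API. All names, shapes, and constants here are illustrative, and mode 2 selects the NeoX path, the only one that honors the factors in this patch:)

    #include "ggml.h"

    // Sketch: RoPE with per-dimension frequency factors. Assumes a context
    // created with no_alloc == false so tensor->data is host-writable.
    static struct ggml_tensor * rope_with_factors(
            struct ggml_context * ctx,
            struct ggml_tensor  * cur,    // activations to rotate
            struct ggml_tensor  * pos,    // GGML_TYPE_I32, one position per row of cur
            int                   n_dims) {
        struct ggml_tensor * factors = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_dims/2);
        for (int i = 0; i < n_dims/2; ++i) {
            ((float *) factors->data)[i] = 1.0f; // model-specific values would go here
        }
        return ggml_rope_ext(ctx, cur, pos, factors, n_dims, /*mode =*/ 2,
                /*n_ctx =*/ 0, /*n_orig_ctx =*/ 0, /*freq_base =*/ 10000.0f,
                /*freq_scale =*/ 1.0f, /*ext_factor =*/ 0.0f, /*attn_factor =*/ 1.0f,
                /*beta_fast =*/ 32.0f, /*beta_slow =*/ 1.0f);
    }

(Note that the backward path still asserts `c == NULL`, so a graph built this way can run forward but will abort if gradients through the rope are requested.)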