summaryrefslogtreecommitdiff
path: root/ggml.c
diff options
context:
space:
mode:
Diffstat (limited to 'ggml.c')
-rw-r--r--ggml.c74
1 files changed, 36 insertions, 38 deletions
diff --git a/ggml.c b/ggml.c
index e6e2397b..b2b725f6 100644
--- a/ggml.c
+++ b/ggml.c
@@ -3221,7 +3221,11 @@ GGML_CALL bool ggml_is_contiguous(const struct ggml_tensor * tensor) {
tensor->nb[3] == tensor->nb[2]*tensor->ne[2];
}
-static inline bool ggml_is_contiguous_except_dim_1(const struct ggml_tensor * tensor) {
+GGML_CALL bool ggml_is_contiguous_0(const struct ggml_tensor * tensor) {
+ return ggml_is_contiguous(tensor);
+}
+
+GGML_CALL bool ggml_is_contiguous_1(const struct ggml_tensor * tensor) {
static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
return
@@ -3230,6 +3234,14 @@ static inline bool ggml_is_contiguous_except_dim_1(const struct ggml_tensor * te
tensor->nb[3] == tensor->nb[2]*tensor->ne[2];
}
+GGML_CALL bool ggml_is_contiguous_2(const struct ggml_tensor * tensor) {
+ static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
+
+ return
+ tensor->nb[0] == ggml_type_size(tensor->type) &&
+ tensor->nb[3] == tensor->nb[2]*tensor->ne[2];
+}
+
GGML_CALL bool ggml_is_permuted(const struct ggml_tensor * tensor) {
static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
@@ -11420,8 +11432,8 @@ static void ggml_compute_forward_gelu_f32(
const struct ggml_tensor * src0 = dst->src[0];
- GGML_ASSERT(ggml_is_contiguous_except_dim_1(src0));
- GGML_ASSERT(ggml_is_contiguous_except_dim_1(dst));
+ GGML_ASSERT(ggml_is_contiguous_1(src0));
+ GGML_ASSERT(ggml_is_contiguous_1(dst));
GGML_ASSERT(ggml_are_same_shape(src0, dst));
if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
@@ -11483,8 +11495,8 @@ static void ggml_compute_forward_gelu_quick_f32(
const struct ggml_tensor * src0 = dst->src[0];
- GGML_ASSERT(ggml_is_contiguous_except_dim_1(src0));
- GGML_ASSERT(ggml_is_contiguous_except_dim_1(dst));
+ GGML_ASSERT(ggml_is_contiguous_1(src0));
+ GGML_ASSERT(ggml_is_contiguous_1(dst));
GGML_ASSERT(ggml_are_same_shape(src0, dst));
if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
@@ -11546,8 +11558,8 @@ static void ggml_compute_forward_silu_f32(
const struct ggml_tensor * src0 = dst->src[0];
- GGML_ASSERT(ggml_is_contiguous_except_dim_1(src0));
- GGML_ASSERT(ggml_is_contiguous_except_dim_1(dst));
+ GGML_ASSERT(ggml_is_contiguous_1(src0));
+ GGML_ASSERT(ggml_is_contiguous_1(dst));
GGML_ASSERT(ggml_are_same_shape(src0, dst));
if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
@@ -11658,9 +11670,9 @@ static void ggml_compute_forward_silu_back_f32(
const struct ggml_tensor * src0 = dst->src[0];
const struct ggml_tensor * grad = dst->src[1];
- GGML_ASSERT(ggml_is_contiguous_except_dim_1(grad));
- GGML_ASSERT(ggml_is_contiguous_except_dim_1(src0));
- GGML_ASSERT(ggml_is_contiguous_except_dim_1(dst));
+ GGML_ASSERT(ggml_is_contiguous_1(grad));
+ GGML_ASSERT(ggml_is_contiguous_1(src0));
+ GGML_ASSERT(ggml_is_contiguous_1(dst));
GGML_ASSERT(ggml_are_same_shape(src0, dst));
GGML_ASSERT(ggml_are_same_shape(src0, grad));
@@ -14358,7 +14370,7 @@ static void ggml_compute_forward_rope_f32(
int ir = 0;
const float theta_scale = powf(freq_base, -2.0f/n_dims);
- const float inv_ndims = -1.f/n_dims;
+
float corr_dims[2];
ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims);
@@ -14407,7 +14419,7 @@ static void ggml_compute_forward_rope_f32(
const float cos_block_theta = cosf(block_theta);
const float sin_block_theta = sinf(block_theta) * sin_sign;
- theta_base *= theta_scale;
+ theta_base *= theta_scale;
block_theta *= theta_scale;
const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
@@ -14442,29 +14454,22 @@ static void ggml_compute_forward_rope_f32(
dst_data[1] = x0*sin_theta*zeta + x1*cos_theta*zeta;
}
} else {
- // TODO: this might be wrong for ne0 != n_dims - need double check
- // it seems we have to rope just the first n_dims elements and do nothing with the rest
- // ref: https://github.com/ml-explore/mlx/blob/dc2edc762c797e3b8de50b1dad4dc0a131691033/benchmarks/python/llama_jax_bench.py#L11-L26
- theta_base *= freq_scale;
+ // ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
for (int64_t ic = 0; ic < ne0; ic += 2) {
if (ic < n_dims) {
- const int64_t ib = 0;
+ const int64_t i0 = ic/2;
- // simplified from `(ib * n_dims + ic) * inv_ndims`
- float cur_rot = inv_ndims * ic - ib;
- float freq_factor = freq_factors ? freq_factors[ic/2] : 1.0f;
+ const float freq_factor = freq_factors ? freq_factors[i0] : 1.0f;
float cos_theta, sin_theta;
rope_yarn(
- theta_base/freq_factor, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor,
+ theta_base/freq_factor, freq_scale, corr_dims, ic, ext_factor, attn_factor,
&cos_theta, &sin_theta
);
- sin_theta *= sin_sign;
+ sin_theta *= sin_sign;
theta_base *= theta_scale;
- const int64_t i0 = ib*n_dims + ic/2;
-
const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
@@ -14543,7 +14548,7 @@ static void ggml_compute_forward_rope_f16(
int ir = 0;
const float theta_scale = powf(freq_base, -2.0f/n_dims);
- const float inv_ndims = -1.f/n_dims;
+
float corr_dims[2];
ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims);
@@ -14592,7 +14597,7 @@ static void ggml_compute_forward_rope_f16(
const float cos_block_theta = cosf(block_theta);
const float sin_block_theta = sinf(block_theta) * sin_sign;
- theta_base *= theta_scale;
+ theta_base *= theta_scale;
block_theta *= theta_scale;
const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
@@ -14623,29 +14628,22 @@ static void ggml_compute_forward_rope_f16(
dst_data[1] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
}
} else {
- // TODO: this might be wrong for ne0 != n_dims - need double check
- // it seems we have to rope just the first n_dims elements and do nothing with the rest
- // ref: https://github.com/ml-explore/mlx/blob/dc2edc762c797e3b8de50b1dad4dc0a131691033/benchmarks/python/llama_jax_bench.py#L11-L26
- theta_base *= freq_scale;
+ // ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
for (int64_t ic = 0; ic < ne0; ic += 2) {
if (ic < n_dims) {
- const int64_t ib = 0;
+ const int64_t i0 = ic/2;
- // simplified from `(ib * n_dims + ic) * inv_ndims`
- float cur_rot = inv_ndims * ic - ib;
- float freq_factor = freq_factors ? freq_factors[ic/2] : 1.0f;
+ const float freq_factor = freq_factors ? freq_factors[i0] : 1.0f;
float cos_theta, sin_theta;
rope_yarn(
- theta_base/freq_factor, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor,
+ theta_base/freq_factor, freq_scale, corr_dims, ic, ext_factor, attn_factor,
&cos_theta, &sin_theta
);
- sin_theta *= sin_sign;
+ sin_theta *= sin_sign;
theta_base *= theta_scale;
- const int64_t i0 = ib*n_dims + ic/2;
-
const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);