summary refs log tree commit diff
path: root/ggml.c
diff options
context:
space:
mode:
Diffstat (limited to 'ggml.c')
-rw-r--r-- ggml.c 46
1 files changed, 40 insertions, 6 deletions
diff --git a/ggml.c b/ggml.c
index ad546a73..6da65bd9 100644
--- a/ggml.c
+++ b/ggml.c
@@ -4098,6 +4098,14 @@ struct ggml_tensor * ggml_mul_mat(
return result;
}
+void ggml_mul_mat_set_prec( // records a precision setting in a tensor's op params; returns nothing
+ struct ggml_tensor * a, // tensor whose op params are written (presumably a ggml_mul_mat result - TODO confirm against callers)
+ enum ggml_prec prec) { // precision value to store
+ const int32_t prec_i32 = (int32_t) prec; // explicit enum -> int32_t conversion, since the setter below takes an i32 value
+
+ ggml_set_op_params_i32(a, 0, prec_i32); // stored at op-param index 0; NOTE(review): whatever was at index 0 is presumably overwritten - confirm
+}
+
// ggml_mul_mat_id
struct ggml_tensor * ggml_mul_mat_id(
@@ -9168,6 +9176,8 @@ static void ggml_compute_forward_norm_f32(
float eps;
memcpy(&eps, dst->op_params, sizeof(float));
+ GGML_ASSERT(eps > 0.0f);
+
// TODO: optimize
for (int64_t i03 = 0; i03 < ne03; i03++) {
for (int64_t i02 = 0; i02 < ne02; i02++) {
@@ -9237,6 +9247,8 @@ static void ggml_compute_forward_rms_norm_f32(
float eps;
memcpy(&eps, dst->op_params, sizeof(float));
+ GGML_ASSERT(eps > 0.0f);
+
// TODO: optimize
for (int64_t i03 = 0; i03 < ne03; i03++) {
for (int64_t i02 = 0; i02 < ne02; i02++) {
@@ -11562,10 +11574,13 @@ static void ggml_compute_forward_rope_f32(
}
} else {
// TODO: this might be wrong for ne0 != n_dims - need double check
- // ref: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py#LL251C1-L294C28
+ // it seems we have to rope just the first n_dims elements and do nothing with the rest
+ // ref: https://github.com/ml-explore/mlx/blob/dc2edc762c797e3b8de50b1dad4dc0a131691033/benchmarks/python/llama_jax_bench.py#L11-L26
theta_base *= freq_scale;
- for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
- for (int64_t ic = 0; ic < n_dims; ic += 2) {
+ for (int64_t ic = 0; ic < ne0; ic += 2) {
+ if (ic < n_dims) {
+ const int64_t ib = 0;
+
// simplified from `(ib * n_dims + ic) * inv_ndims`
float cur_rot = inv_ndims * ic - ib;
@@ -11588,6 +11603,14 @@ static void ggml_compute_forward_rope_f32(
dst_data[0] = x0*cos_theta - x1*sin_theta;
dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
+ } else {
+ const int64_t i0 = ic;
+
+ const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+ float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+ dst_data[0] = src[0];
+ dst_data[1] = src[1];
}
}
}
@@ -11715,10 +11738,13 @@ static void ggml_compute_forward_rope_f16(
}
} else {
// TODO: this might be wrong for ne0 != n_dims - need double check
- // ref: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py#LL251C1-L294C28
+ // it seems we have to rope just the first n_dims elements and do nothing with the rest
+ // ref: https://github.com/ml-explore/mlx/blob/dc2edc762c797e3b8de50b1dad4dc0a131691033/benchmarks/python/llama_jax_bench.py#L11-L26
theta_base *= freq_scale;
- for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
- for (int64_t ic = 0; ic < n_dims; ic += 2) {
+ for (int64_t ic = 0; ic < ne0; ic += 2) {
+ if (ic < n_dims) {
+ const int64_t ib = 0;
+
// simplified from `(ib * n_dims + ic) * inv_ndims`
float cur_rot = inv_ndims * ic - ib;
@@ -11741,6 +11767,14 @@ static void ggml_compute_forward_rope_f16(
dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
dst_data[n_dims/2] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
+ } else {
+ const int64_t i0 = ic;
+
+ const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+ ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+ dst_data[0] = src[0];
+ dst_data[1] = src[1];
}
}
}