diff options
Diffstat (limited to 'ggml.c')
-rw-r--r-- | ggml.c | 46 |
1 files changed, 40 insertions, 6 deletions
@@ -4098,6 +4098,14 @@ struct ggml_tensor * ggml_mul_mat( return result; } +void ggml_mul_mat_set_prec( + struct ggml_tensor * a, + enum ggml_prec prec) { + const int32_t prec_i32 = (int32_t) prec; + + ggml_set_op_params_i32(a, 0, prec_i32); +} + // ggml_mul_mat_id struct ggml_tensor * ggml_mul_mat_id( @@ -9168,6 +9176,8 @@ static void ggml_compute_forward_norm_f32( float eps; memcpy(&eps, dst->op_params, sizeof(float)); + GGML_ASSERT(eps > 0.0f); + // TODO: optimize for (int64_t i03 = 0; i03 < ne03; i03++) { for (int64_t i02 = 0; i02 < ne02; i02++) { @@ -9237,6 +9247,8 @@ static void ggml_compute_forward_rms_norm_f32( float eps; memcpy(&eps, dst->op_params, sizeof(float)); + GGML_ASSERT(eps > 0.0f); + // TODO: optimize for (int64_t i03 = 0; i03 < ne03; i03++) { for (int64_t i02 = 0; i02 < ne02; i02++) { @@ -11562,10 +11574,13 @@ static void ggml_compute_forward_rope_f32( } } else { // TODO: this might be wrong for ne0 != n_dims - need double check - // ref: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py#LL251C1-L294C28 + // it seems we have to rope just the first n_dims elements and do nothing with the rest + // ref: https://github.com/ml-explore/mlx/blob/dc2edc762c797e3b8de50b1dad4dc0a131691033/benchmarks/python/llama_jax_bench.py#L11-L26 theta_base *= freq_scale; - for (int64_t ib = 0; ib < ne0/n_dims; ++ib) { - for (int64_t ic = 0; ic < n_dims; ic += 2) { + for (int64_t ic = 0; ic < ne0; ic += 2) { + if (ic < n_dims) { + const int64_t ib = 0; + // simplified from `(ib * n_dims + ic) * inv_ndims` float cur_rot = inv_ndims * ic - ib; @@ -11588,6 +11603,14 @@ static void ggml_compute_forward_rope_f32( dst_data[0] = x0*cos_theta - x1*sin_theta; dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta; + } else { + const int64_t i0 = ic; + + const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + dst_data[0] = src[0]; + dst_data[1] = src[1]; } } } @@ -11715,10 +11738,13 @@ static void ggml_compute_forward_rope_f16( } } else { // TODO: this might be wrong for ne0 != n_dims - need double check - // ref: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py#LL251C1-L294C28 + // it seems we have to rope just the first n_dims elements and do nothing with the rest + // ref: https://github.com/ml-explore/mlx/blob/dc2edc762c797e3b8de50b1dad4dc0a131691033/benchmarks/python/llama_jax_bench.py#L11-L26 theta_base *= freq_scale; - for (int64_t ib = 0; ib < ne0/n_dims; ++ib) { - for (int64_t ic = 0; ic < n_dims; ic += 2) { + for (int64_t ic = 0; ic < ne0; ic += 2) { + if (ic < n_dims) { + const int64_t ib = 0; + // simplified from `(ib * n_dims + ic) * inv_ndims` float cur_rot = inv_ndims * ic - ib; @@ -11741,6 +11767,14 @@ static void ggml_compute_forward_rope_f16( dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta); dst_data[n_dims/2] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta); + } else { + const int64_t i0 = ic; + + const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + dst_data[0] = src[0]; + dst_data[1] = src[1]; } } } |