From 2b3389677a833cee0880226533a1768b1a9508d2 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 5 Jun 2024 11:29:20 +0300 Subject: ggml : refactor rope norm/neox (#7634) * ggml : unify rope norm/neox (CPU) * ggml : fix compile warning * ggml : remove GLM rope mode ggml-ci * metal : better rope implementation ggml-ci * cuda : better rope implementation ggml-ci * naming : n_orig_ctx -> n_ctx_orig ggml-ci * dev : add reminders to update backends ggml-ci * vulkan : fix ggml_rope_ext() usage * cuda : fix array size + indents ggml-ci --- ggml-kompute.cpp | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) (limited to 'ggml-kompute.cpp') diff --git a/ggml-kompute.cpp b/ggml-kompute.cpp index eabd70d5..5592741b 100644 --- a/ggml-kompute.cpp +++ b/ggml-kompute.cpp @@ -1192,7 +1192,7 @@ static void ggml_vk_rope( const std::shared_ptr& inB, const std::shared_ptr& out, uint32_t inAOff, uint32_t inBOff, uint32_t outOff, - ggml_type src0t, int32_t n_dims, int32_t mode, int32_t n_orig_ctx, + ggml_type src0t, int32_t n_dims, int32_t mode, int32_t n_ctx_orig, float freq_base, float freq_scale, float ext_factor, float attn_factor, float beta_fast, float beta_slow, int32_t ne01, int32_t ne02, int32_t ne03, uint32_t nb00, uint32_t nb01, uint32_t nb02, uint32_t nb03, @@ -1221,14 +1221,14 @@ static void ggml_vk_rope( struct PushConstants { uint32_t inAOff, inBOff, outOff; - int32_t n_dims, mode, n_orig_ctx; + int32_t n_dims, mode, n_ctx_orig; float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow; uint32_t nb00, nb01, nb02, nb03; int32_t ne0; uint32_t nb0, nb1, nb2, nb3; } pushConsts { safe_divide(inAOff, type_size), safe_divide(inBOff, 4), safe_divide(outOff, type_size), - n_dims, mode, n_orig_ctx, + n_dims, mode, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow, nb00, nb01, nb02, nb03, ne0, @@ -1692,13 +1692,16 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml #pragma message(" https://github.com/ggerganov/llama.cpp/pull/7225") GGML_ASSERT(dst->src[2] == nullptr && "phi3 frequency factors not implemented yet"); +#pragma message("TODO: update rope NORM mode to match NEOX mode") +#pragma message(" https://github.com/ggerganov/llama.cpp/pull/7634") + GGML_ASSERT(ne10 == ne02); GGML_ASSERT(src0t == dstt); // const int n_past = ((int32_t *) dst->op_params)[0]; const int n_dims = ((int32_t *) dst->op_params)[1]; const int mode = ((int32_t *) dst->op_params)[2]; // skip 3, n_ctx used in GLM RoPE, unimplemented in Vulkan - const int n_orig_ctx = ((int32_t *) dst->op_params)[4]; + const int n_ctx_orig = ((int32_t *) dst->op_params)[4]; float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow; memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float)); @@ -1708,7 +1711,7 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float)); memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float)); ggml_vk_rope( - seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, src0t, n_dims, mode, n_orig_ctx, + seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, src0t, n_dims, mode, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow, ne01, ne02, ne03, nb00, nb01, nb02, nb03, ne0, nb0, nb1, nb2, nb3 ); -- cgit v1.2.3