diff options
Diffstat (limited to 'ggml-vulkan.cpp')
-rw-r--r-- | ggml-vulkan.cpp | 22 |
1 files changed, 8 insertions, 14 deletions
diff --git a/ggml-vulkan.cpp b/ggml-vulkan.cpp index 5e12ea9d..e0c512c0 100644 --- a/ggml-vulkan.cpp +++ b/ggml-vulkan.cpp @@ -3898,11 +3898,6 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const { const int mode = ((const int32_t *) dst->op_params)[2]; const bool is_neox = mode & 2; - const bool is_glm = mode & 4; - - if (is_glm) { - return nullptr; - } if (is_neox) { if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { @@ -4401,7 +4396,7 @@ static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context * subctx, con const int n_dims = ((int32_t *) dst->op_params)[1]; const int mode = ((int32_t *) dst->op_params)[2]; // const int n_ctx = ((int32_t *) dst->op_params)[3]; - const int n_orig_ctx = ((int32_t *) dst->op_params)[4]; + const int n_ctx_orig = ((int32_t *) dst->op_params)[4]; const float freq_base = ((float *) dst->op_params)[5]; const float freq_scale = ((float *) dst->op_params)[6]; const float ext_factor = ((float *) dst->op_params)[7]; @@ -4410,12 +4405,12 @@ static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context * subctx, con const float beta_slow = ((float *) dst->op_params)[10]; const bool is_neox = mode & 2; - const bool is_glm = mode & 4; - GGML_ASSERT(!is_glm); +#pragma message("TODO: update rope NORM mode to match NEOX mode") +#pragma message(" https://github.com/ggerganov/llama.cpp/pull/7634") float corr_dims[2]; - ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims); + ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims); if (is_neox) { const float theta_scale = powf(freq_base, -2.0f/n_dims); @@ -6485,9 +6480,8 @@ GGML_CALL static bool ggml_backend_vk_supports_op(ggml_backend_t backend, const case GGML_OP_ROPE: { const int mode = ((const int32_t *) op->op_params)[2]; - const bool is_glm = mode & 4; - return !is_glm; + return true; } break; case GGML_OP_NONE: case GGML_OP_RESHAPE: @@ -6992,15 +6986,15 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_compute_ } else if (tensor->op == GGML_OP_ROPE) { const int n_dims = ((int32_t *) tensor->op_params)[1]; const int mode = ((int32_t *) tensor->op_params)[2]; - const int n_ggml_ctx = ((int32_t *) tensor->op_params)[3]; - const int n_orig_ggml_ctx = ((int32_t *) tensor->op_params)[4]; + //const int n_ctx_ggml = ((int32_t *) tensor->op_params)[3]; + const int n_ctx_orig_ggml = ((int32_t *) tensor->op_params)[4]; float freq_base = ((float *) tensor->op_params)[5]; float freq_scale = ((float *) tensor->op_params)[6]; float ext_factor = ((float *) tensor->op_params)[7]; float attn_factor = ((float *) tensor->op_params)[8]; float beta_fast = ((float *) tensor->op_params)[9]; float beta_slow = ((float *) tensor->op_params)[10]; - tensor_clone = ggml_rope_ext(ggml_ctx, src0_clone, src1_clone, src2_clone, n_dims, mode, n_ggml_ctx, n_orig_ggml_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); + tensor_clone = ggml_rope_ext(ggml_ctx, src0_clone, src1_clone, src2_clone, n_dims, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); } else if (tensor->op == GGML_OP_UNARY) { switch (ggml_get_unary_op(tensor)) { case GGML_UNARY_OP_SILU: |