diff options
author | Georgi Gerganov <ggerganov@gmail.com> | 2024-06-05 11:29:20 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-06-05 11:29:20 +0300 |
commit | 2b3389677a833cee0880226533a1768b1a9508d2 (patch) | |
tree | 3af4beed34ff6c1309d202a7028f5ab85ec43693 /ggml-sycl.cpp | |
parent | 9973e81c5ccf4f31b3980f5aa73f5cfea8699860 (diff) |
ggml : refactor rope norm/neox (#7634)
* ggml : unify rope norm/neox (CPU)
* ggml : fix compile warning
* ggml : remove GLM rope mode
ggml-ci
* metal : better rope implementation
ggml-ci
* cuda : better rope implementation
ggml-ci
* naming : n_orig_ctx -> n_ctx_orig
ggml-ci
* dev : add reminders to update backends
ggml-ci
* vulkan : fix ggml_rope_ext() usage
* cuda : fix array size + indents
ggml-ci
Diffstat (limited to 'ggml-sycl.cpp')
-rw-r--r-- | ggml-sycl.cpp | 74 |
1 files changed, 7 insertions, 67 deletions
```diff
diff --git a/ggml-sycl.cpp b/ggml-sycl.cpp
index 5cd97e4f..3ff76474 100644
--- a/ggml-sycl.cpp
+++ b/ggml-sycl.cpp
@@ -8928,49 +8928,6 @@ static void rope_neox(
     dst[i + n_dims/2] = x0*sin_theta + x1*cos_theta;
 }
 
-static void rope_glm_f32(
-    const float * x, float * dst, int ncols, const int32_t * pos, float freq_scale, int p_delta_rows, float freq_base,
-    int n_ctx
-, const sycl::nd_item<3> &item_ct1) {
-    const int col = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
-                    item_ct1.get_local_id(2);
-    const int half_n_dims = ncols/4;
-
-    if (col >= half_n_dims) {
-        return;
-    }
-
-    const int row = item_ct1.get_local_range(1) * item_ct1.get_group(1) +
-                    item_ct1.get_local_id(1);
-    const int i = row*ncols + col;
-    const int i2 = row/p_delta_rows;
-
-    const float col_theta_scale = dpct::pow(freq_base, -2.0f * col / ncols);
-    // FIXME: this is likely wrong
-    const int p = pos != nullptr ? pos[i2] : 0;
-
-    const float theta = sycl::min(p, n_ctx - 2) * freq_scale * col_theta_scale;
-    const float sin_theta = sycl::sin((float)theta);
-    const float cos_theta = sycl::cos((float)theta);
-
-    const float x0 = x[i + 0];
-    const float x1 = x[i + half_n_dims];
-
-    dst[i + 0] = x0*cos_theta - x1*sin_theta;
-    dst[i + half_n_dims] = x0*sin_theta + x1*cos_theta;
-
-    const float block_theta =
-        ((float)sycl::max(p - n_ctx - 2, 0)) * col_theta_scale;
-    const float sin_block_theta = sycl::sin((float)block_theta);
-    const float cos_block_theta = sycl::cos((float)block_theta);
-
-    const float x2 = x[i + half_n_dims * 2];
-    const float x3 = x[i + half_n_dims * 3];
-
-    dst[i + half_n_dims * 2] = x2*cos_block_theta - x3*sin_block_theta;
-    dst[i + half_n_dims * 3] = x2*sin_block_theta + x3*cos_block_theta;
-}
-
 static void k_sum_rows_f32(const float * x, float * dst, const int ncols,
                            const sycl::nd_item<3> &item_ct1) {
     const int row = item_ct1.get_group(1);
@@ -12520,22 +12477,6 @@ static void rope_neox_sycl(const T *x, T *dst, int ncols, int n_dims, int nrows,
     }
 }
 
-static void rope_glm_f32_sycl(const float *x, float *dst, int ncols, int nrows,
-                              const int32_t *pos, float freq_scale,
-                              int p_delta_rows, float freq_base, int n_ctx,
-                              dpct::queue_ptr stream) {
-    GGML_ASSERT(ncols % 4 == 0);
-    const sycl::range<3> block_dims(1, 1, SYCL_ROPE_BLOCK_SIZE / 4);
-    const int num_blocks_x = (ncols + SYCL_ROPE_BLOCK_SIZE - 1) / SYCL_ROPE_BLOCK_SIZE;
-    const sycl::range<3> block_nums(1, nrows, num_blocks_x);
-    stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims),
-                         [=](sycl::nd_item<3> item_ct1) {
-                             rope_glm_f32(x, dst, ncols, pos, freq_scale,
-                                          p_delta_rows, freq_base, n_ctx,
-                                          item_ct1);
-                         });
-}
-
 static void sum_rows_f32_sycl(const float *x, float *dst, const int ncols,
                               const int nrows, dpct::queue_ptr stream) {
     const sycl::range<3> block_dims(1, 1, WARP_SIZE);
@@ -14066,8 +14007,8 @@ inline void ggml_sycl_op_rope(const ggml_tensor *src0, const ggml_tensor *src1,
     //const int n_past = ((int32_t *) dst->op_params)[0];
     const int n_dims = ((int32_t *) dst->op_params)[1];
     const int mode = ((int32_t *) dst->op_params)[2];
-    const int n_ctx = ((int32_t *) dst->op_params)[3];
-    const int n_orig_ctx = ((int32_t *) dst->op_params)[4];
+    //const int n_ctx = ((int32_t *) dst->op_params)[3];
+    const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
 
     // RoPE alteration for extended context
     float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
@@ -14087,7 +14028,9 @@ inline void ggml_sycl_op_rope(const ggml_tensor *src0, const ggml_tensor *src1,
     }
 
     const bool is_neox = mode & 2;
-    const bool is_glm = mode & 4;
+
+#pragma message("TODO: update rope NORM mode to match NEOX mode")
+#pragma message(" https://github.com/ggerganov/llama.cpp/pull/7634")
 
     if (is_neox) {
         pos = (const int32_t *) src1_dd;
@@ -14100,13 +14043,10 @@ inline void ggml_sycl_op_rope(const ggml_tensor *src0, const ggml_tensor *src1,
     }
 
     rope_corr_dims corr_dims;
-    ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims.v);
+    ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims.v);
 
     // compute
-    if (is_glm) {
-        GGML_ASSERT(false);
-        rope_glm_f32_sycl(src0_dd, dst_dd, ne00, nrows, pos, freq_scale, ne01, freq_base, n_ctx, main_stream);
-    } else if (is_neox) {
+    if (is_neox) {
         if (src0->type == GGML_TYPE_F32) {
             rope_neox_sycl(
                 (const float *)src0_dd, (float *)dst_dd, ne00, n_dims, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
```