summaryrefslogtreecommitdiff
path: root/ggml_vk_generate_shaders.py
diff options
context:
space:
mode:
authorGeorgi Gerganov <ggerganov@gmail.com>2024-05-29 20:17:31 +0300
committerGitHub <noreply@github.com>2024-05-29 20:17:31 +0300
commitfb76ec31a9914b7761c1727303ab30380fd4f05c (patch)
treea0bcc5041d8cf3373ad853bea4befd0b96e098d4 /ggml_vk_generate_shaders.py
parentcce3dcffc5695bd24835f04e6080070a2a119873 (diff)
ggml : fix YARN + add tests + add asserts (#7617)
* tests : add rope tests ggml-ci * ggml : fixes (hopefully) ggml-ci * tests : add non-cont tests ggml-ci * cuda : add asserts for rope/norm + fix DS2 ggml-ci * ggml : assert contiguousness * tests : reduce RoPE tests ggml-ci
Diffstat (limited to 'ggml_vk_generate_shaders.py')
-rw-r--r--ggml_vk_generate_shaders.py4
1 files changed, 1 insertions, 3 deletions
diff --git a/ggml_vk_generate_shaders.py b/ggml_vk_generate_shaders.py
index a8f7373d..7c85ca7b 100644
--- a/ggml_vk_generate_shaders.py
+++ b/ggml_vk_generate_shaders.py
@@ -2670,14 +2670,12 @@ void main() {
const uint i = row*p.ncols + ib*p.ndims + ic/2;
const uint i2 = row/p.p_delta_rows;
- const float cur_rot = p.inv_ndims * ic - ib;
-
const int pos = data_b[i2];
const float freq_factor = p.has_freq_facs != 0 ? data_freq_factors[ic/2] : 1.0f;
const float theta_base = pos*p.freq_scale*pow(p.theta_scale, col/2.0f) / freq_factor;
float cos_theta, sin_theta;
- rope_yarn(theta_base, uint(cur_rot), cos_theta, sin_theta);
+ rope_yarn(theta_base, ic, cos_theta, sin_theta);
const float x0 = float(data_a[i + 0]);
const float x1 = float(data_a[i + p.ndims/2]);