summaryrefslogtreecommitdiff
path: root/ggml_vk_generate_shaders.py
diff options
context:
space:
mode:
Diffstat (limited to 'ggml_vk_generate_shaders.py')
-rw-r--r--ggml_vk_generate_shaders.py4
1 files changed, 1 insertions, 3 deletions
diff --git a/ggml_vk_generate_shaders.py b/ggml_vk_generate_shaders.py
index a8f7373d..7c85ca7b 100644
--- a/ggml_vk_generate_shaders.py
+++ b/ggml_vk_generate_shaders.py
@@ -2670,14 +2670,12 @@ void main() {
const uint i = row*p.ncols + ib*p.ndims + ic/2;
const uint i2 = row/p.p_delta_rows;
- const float cur_rot = p.inv_ndims * ic - ib;
-
const int pos = data_b[i2];
const float freq_factor = p.has_freq_facs != 0 ? data_freq_factors[ic/2] : 1.0f;
const float theta_base = pos*p.freq_scale*pow(p.theta_scale, col/2.0f) / freq_factor;
float cos_theta, sin_theta;
- rope_yarn(theta_base, uint(cur_rot), cos_theta, sin_theta);
+ rope_yarn(theta_base, ic, cos_theta, sin_theta);
const float x0 = float(data_a[i + 0]);
const float x1 = float(data_a[i + p.ndims/2]);