summary | refs | log | tree | commit | diff
path: root/ggml_vk_generate_shaders.py
diff options
context:
space:
mode:
Diffstat (limited to 'ggml_vk_generate_shaders.py')
-rw-r--r-- ggml_vk_generate_shaders.py 7
1 file changed, 5 insertions, 2 deletions
diff --git a/ggml_vk_generate_shaders.py b/ggml_vk_generate_shaders.py
index 8096c03b..a8f7373d 100644
--- a/ggml_vk_generate_shaders.py
+++ b/ggml_vk_generate_shaders.py
@@ -2609,7 +2609,8 @@ layout(local_size_x = 1, local_size_y = 256, local_size_z = 1) in;
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) readonly buffer Y {int data_b[];};
-layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
+layout (binding = 2) readonly buffer Z {float data_freq_factors[];};
+layout (binding = 3) writeonly buffer D {D_TYPE data_d[];};
layout (push_constant) uniform parameter {
uint ncols;
@@ -2622,6 +2623,7 @@ layout (push_constant) uniform parameter {
float corr_dims[4];
float theta_scale;
float inv_ndims;
+ uint has_freq_facs;
} p;
float rope_yarn_ramp(const float low, const float high, const uint i0) {
@@ -2671,7 +2673,8 @@ void main() {
const float cur_rot = p.inv_ndims * ic - ib;
const int pos = data_b[i2];
- const float theta_base = pos*p.freq_scale*pow(p.theta_scale, col/2.0f);
+ const float freq_factor = p.has_freq_facs != 0 ? data_freq_factors[ic/2] : 1.0f;
+ const float theta_base = pos*p.freq_scale*pow(p.theta_scale, col/2.0f) / freq_factor;
float cos_theta, sin_theta;
rope_yarn(theta_base, uint(cur_rot), cos_theta, sin_theta);