summaryrefslogtreecommitdiff
path: root/vulkan-shaders/rope_neox.comp
diff options
context:
space:
mode:
Diffstat (limited to 'vulkan-shaders/rope_neox.comp')
-rw-r--r--vulkan-shaders/rope_neox.comp37
1 files changed, 37 insertions, 0 deletions
diff --git a/vulkan-shaders/rope_neox.comp b/vulkan-shaders/rope_neox.comp
new file mode 100644
index 00000000..83b46b69
--- /dev/null
+++ b/vulkan-shaders/rope_neox.comp
@@ -0,0 +1,37 @@
+#version 450
+
+#include "rope_head.comp"
+
+void main() {
+ const uint col = gl_GlobalInvocationID.y * 2;
+ const uint row = gl_GlobalInvocationID.x;
+
+ if (col >= p.ncols) {
+ return;
+ }
+
+ if (col >= p.n_dims) {
+ const uint i = row*p.ncols + col;
+
+ data_d[i + 0] = data_a[i + 0];
+ data_d[i + 1] = data_a[i + 1];
+
+ return;
+ }
+
+ const uint i = row*p.ncols + col/2;
+ const uint i2 = row/p.p_delta_rows;
+
+ const float theta_base = data_pos[i2] * pow(p.theta_scale, col/2.0f);
+
+ const float freq_factor = p.has_ff != 0 ? data_ff[col/2] : 1.0f;
+
+ float cos_theta, sin_theta;
+ rope_yarn(theta_base / freq_factor, col, cos_theta, sin_theta);
+
+ const float x0 = float(data_a[i + 0]);
+ const float x1 = float(data_a[i + p.n_dims/2]);
+
+ data_d[i + 0] = D_TYPE(x0*cos_theta - x1*sin_theta);
+ data_d[i + p.n_dims/2] = D_TYPE(x0*sin_theta + x1*cos_theta);
+}