summaryrefslogtreecommitdiff
path: root/ggml/src/vulkan-shaders/roll.comp
diff options
context:
space:
mode:
Diffstat (limited to 'ggml/src/vulkan-shaders/roll.comp')
-rw-r--r--ggml/src/vulkan-shaders/roll.comp46
1 files changed, 46 insertions, 0 deletions
diff --git a/ggml/src/vulkan-shaders/roll.comp b/ggml/src/vulkan-shaders/roll.comp
new file mode 100644
index 00000000..b9abe8de
--- /dev/null
+++ b/ggml/src/vulkan-shaders/roll.comp
@@ -0,0 +1,46 @@
+#version 450
+
+#include "types.comp"
+#include "generic_unary_head.comp"
+
+layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
+
+uint wrap_idx(int i, uint ne) {
+ if (i < 0) {
+ return i + ne;
+ } else if (i >= ne) {
+ return i - ne;
+ }
+ return i;
+}
+
+void main() {
+ const uint idx = get_idx();
+ if (idx >= p.ne) {
+ return;
+ }
+
+ const uint i3 = fastdiv(idx, p.ne1_012mp, p.ne1_012L);
+ const uint i3_offset = i3 * p.ne12*p.ne11*p.ne10;
+ const uint i2 = fastdiv(idx - i3_offset, p.ne1_01mp, p.ne1_01L);
+ const uint i2_offset = i2*p.ne11*p.ne10;
+ const uint i1 = fastdiv(idx - i3_offset - i2_offset, p.ne1_0mp, p.ne1_0L);
+ const uint i0 = idx - i3_offset - i2_offset - i1*p.ne10;
+
+ const uint p1 = floatBitsToUint(p.param1);
+ const uint p2 = floatBitsToUint(p.param2);
+ const int s0 = int(p1 >> 16) - 0x8000;
+ const int s1 = int(p1 & 0xFFFF) - 0x8000;
+ const int s2 = int(p2 >> 16) - 0x8000;
+ const int s3 = int(p2 & 0xFFFF) - 0x8000;
+
+ const uint i00 = wrap_idx(int(i0) - s0, p.ne10);
+ const uint i01 = wrap_idx(int(i1) - s1, p.ne11);
+ const uint i02 = wrap_idx(int(i2) - s2, p.ne12);
+ const uint i03 = wrap_idx(int(i3) - s3, p.ne13);
+
+ const uint a_idx = i03*p.nb03 + i02*p.nb02 + i01*p.nb01 + i00*p.nb00;
+ const uint d_idx = i3 *p.nb13 + i2 *p.nb12 + i1 *p.nb11 + i0 *p.nb10;
+
+ data_d[get_doffset() + d_idx] = D_TYPE(data_a[get_aoffset() + a_idx]);
+}