summaryrefslogtreecommitdiff
path: root/ggml/src/vulkan-shaders/glu_main.comp
diff options
context:
space:
mode:
Diffstat (limited to 'ggml/src/vulkan-shaders/glu_main.comp')
-rw-r--r--ggml/src/vulkan-shaders/glu_main.comp29
1 files changed, 29 insertions, 0 deletions
diff --git a/ggml/src/vulkan-shaders/glu_main.comp b/ggml/src/vulkan-shaders/glu_main.comp
new file mode 100644
index 00000000..85cf65a9
--- /dev/null
+++ b/ggml/src/vulkan-shaders/glu_main.comp
@@ -0,0 +1,29 @@
+void main() {
+ const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
+
+ if (i >= p.N) {
+ return;
+ }
+
+ const uint row = i / p.ne20;
+ const uint col = i - row * p.ne20;
+
+ if (p.mode == 0) {
+ // Default
+ const uint offset = p.ne00 / 2;
+ const uint idx = row * p.ne00 + col;
+
+ data_d[row * offset + col] = D_TYPE(op(float(data_a[idx]), float(data_a[idx + offset])));
+ } else if (p.mode == 1) {
+ // Swapped
+ const uint offset = p.ne00 / 2;
+ const uint idx = row * p.ne00 + col;
+
+ data_d[row * offset + col] = D_TYPE(op(float(data_a[idx + offset]), float(data_a[idx])));
+ } else {
+ // Split
+ const uint idx = row * p.ne00 + col;
+
+ data_d[idx] = D_TYPE(op(float(data_a[idx]), float(data_b[idx])));
+ }
+}