diff options
Diffstat (limited to 'ggml/src/vulkan-shaders/glu_main.comp')
-rw-r--r-- | ggml/src/vulkan-shaders/glu_main.comp | 29 |
1 files changed, 29 insertions, 0 deletions
diff --git a/ggml/src/vulkan-shaders/glu_main.comp b/ggml/src/vulkan-shaders/glu_main.comp new file mode 100644 index 00000000..85cf65a9 --- /dev/null +++ b/ggml/src/vulkan-shaders/glu_main.comp @@ -0,0 +1,29 @@ +void main() { + const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; + + if (i >= p.N) { + return; + } + + const uint row = i / p.ne20; + const uint col = i - row * p.ne20; + + if (p.mode == 0) { + // Default + const uint offset = p.ne00 / 2; + const uint idx = row * p.ne00 + col; + + data_d[row * offset + col] = D_TYPE(op(float(data_a[idx]), float(data_a[idx + offset]))); + } else if (p.mode == 1) { + // Swapped + const uint offset = p.ne00 / 2; + const uint idx = row * p.ne00 + col; + + data_d[row * offset + col] = D_TYPE(op(float(data_a[idx + offset]), float(data_a[idx]))); + } else { + // Split + const uint idx = row * p.ne00 + col; + + data_d[idx] = D_TYPE(op(float(data_a[idx]), float(data_b[idx]))); + } +} |