diff options
author | 0cc4m <picard12@live.de> | 2024-06-16 07:17:31 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-06-16 07:17:31 +0200 |
commit | 7c7836d9d4062d6858e3fb337b135c417ccee6ce (patch) | |
tree | c896967a106e2985763bf1c7bfd7bfb8cbe4f0fd /vulkan-shaders/argsort.comp | |
parent | 0c7b3595b9e5ad2355818e259f06b0dc3f0065b3 (diff) |
Vulkan Shader Refactor, Memory Debugging Option (#7947)
* Refactor shaders, extract GLSL code from ggml_vk_generate_shaders.py into vulkan-shaders directory
* Improve debug log code
* Add memory debug output option
* Fix flake8
* Fix unnecessary high llama-3 VRAM use
Diffstat (limited to 'vulkan-shaders/argsort.comp')
-rw-r--r-- | vulkan-shaders/argsort.comp | 71 |
1 files changed, 71 insertions, 0 deletions
diff --git a/vulkan-shaders/argsort.comp b/vulkan-shaders/argsort.comp new file mode 100644 index 00000000..e55414b0 --- /dev/null +++ b/vulkan-shaders/argsort.comp @@ -0,0 +1,71 @@ +#version 450 + +#include "types.comp" + +#define BLOCK_SIZE 1024 +#define ASC 0 + +layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 1) buffer D {int data_d[];}; + +layout (push_constant) uniform parameter { + uint ncols; + uint ncols_pad; + uint order; +} p; + +shared int dst_row[BLOCK_SIZE]; + +void swap(uint idx0, uint idx1) { + int tmp = dst_row[idx0]; + dst_row[idx0] = dst_row[idx1]; + dst_row[idx1] = tmp; +} + +void main() { + // bitonic sort + const int col = int(gl_LocalInvocationID.x); + const uint row = gl_WorkGroupID.y; + + if (col >= p.ncols_pad) { + return; + } + + const uint row_offset = row * p.ncols; + + // initialize indices + dst_row[col] = col; + barrier(); + + for (uint k = 2; k <= p.ncols_pad; k *= 2) { + for (uint j = k / 2; j > 0; j /= 2) { + const uint ixj = col ^ j; + if (ixj > col) { + if ((col & k) == 0) { + if (dst_row[col] >= p.ncols || + (dst_row[ixj] < p.ncols && (p.order == ASC ? + data_a[row_offset + dst_row[col]] > data_a[row_offset + dst_row[ixj]] : + data_a[row_offset + dst_row[col]] < data_a[row_offset + dst_row[ixj]])) + ) { + swap(col, ixj); + } + } else { + if (dst_row[ixj] >= p.ncols || + (dst_row[col] < p.ncols && (p.order == ASC ? + data_a[row_offset + dst_row[col]] < data_a[row_offset + dst_row[ixj]] : + data_a[row_offset + dst_row[col]] > data_a[row_offset + dst_row[ixj]])) + ) { + swap(col, ixj); + } + } + } + barrier(); + } + } + + if (col < p.ncols) { + data_d[row_offset + col] = dst_row[col]; + } +} |