diff options
author | 0cc4m <picard12@live.de> | 2024-05-18 08:10:58 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-05-18 08:10:58 +0200 |
commit | c1b295eea5c49887a066559527a74e8b94fe9db0 (patch) | |
tree | 29b2ef767f8ac67c5488dbc1eb1c0b2f53ffd2fd /ggml_vk_generate_shaders.py | |
parent | de731963441ff128248259e1b99573d75264d210 (diff) |
Update and fix Vulkan soft_max and argsort implementations (#7237)
* Update and fix Vulkan softmax implementation
* Update and fix Vulkan argsort implementation
Diffstat (limited to 'ggml_vk_generate_shaders.py')
-rw-r--r-- | ggml_vk_generate_shaders.py | 86 |
1 files changed, 59 insertions, 27 deletions
diff --git a/ggml_vk_generate_shaders.py b/ggml_vk_generate_shaders.py index 162cf5c6..3c22a986 100644 --- a/ggml_vk_generate_shaders.py +++ b/ggml_vk_generate_shaders.py @@ -2432,7 +2432,6 @@ layout (push_constant) uniform parameter { uint KX; uint KY; - uint KZ; float scale; float max_bias; float m0; @@ -2449,8 +2448,7 @@ layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; layout (binding = 1) readonly buffer Y {B_TYPE data_b[];}; -layout (binding = 2) readonly buffer Z {C_TYPE data_c[];}; -layout (binding = 3) buffer D {D_TYPE data_d[];}; +layout (binding = 2) buffer D {D_TYPE data_d[];}; shared FLOAT_TYPE vals[BLOCK_SIZE]; @@ -2459,7 +2457,7 @@ void main() { const uint rowx = gl_WorkGroupID.x; const uint rowy = rowx % p.KY; - float slope = 0.0f; + float slope = 1.0f; // ALiBi if (p.max_bias > 0.0f) { @@ -2472,11 +2470,18 @@ void main() { } // Find max - vals[tid] = uintBitsToFloat(0xFF800000); + FLOAT_TYPE max_val = uintBitsToFloat(0xFF800000); - [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) { - vals[tid] = max(vals[tid], FLOAT_TYPE(data_a[rowx * p.KX + col]) * p.scale + (p.KY > 0 ? FLOAT_TYPE(data_b[rowy * p.KX + col]) : FLOAT_TYPE(0.0f)) + (p.KZ > 0 ? slope * FLOAT_TYPE(data_c[col]) : 0.0f)); + [[unroll]] for (uint col0 = 0; col0 < p.KX; col0 += BLOCK_SIZE) { + const uint col = col0 + tid; + + if (col >= p.KX) { + break; + } + + max_val = max(max_val, FLOAT_TYPE(data_a[rowx * p.KX + col]) * p.scale + (p.KY > 0 ? slope * FLOAT_TYPE(data_b[rowy * p.KX + col]) : FLOAT_TYPE(0.0f))); } + vals[tid] = max_val; barrier(); [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) { @@ -2486,15 +2491,21 @@ void main() { barrier(); } - const FLOAT_TYPE max_val = vals[0]; + max_val = vals[0]; barrier(); // Sum up values vals[tid] = FLOAT_TYPE(0.0f); - [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) { + [[unroll]] for (uint col0 = 0; col0 < p.KX; col0 += BLOCK_SIZE) { + const uint col = col0 + tid; + + if (col >= p.KX) { + break; + } + const uint i = rowx * p.KX + col; - const FLOAT_TYPE val = exp(FLOAT_TYPE(data_a[i]) * p.scale + (p.KY > 0 ? FLOAT_TYPE(data_b[rowy * p.KX + col]) : FLOAT_TYPE(0.0f)) - max_val); + const FLOAT_TYPE val = exp(FLOAT_TYPE(data_a[i]) * p.scale + (p.KY > 0 ? slope * FLOAT_TYPE(data_b[rowy * p.KX + col]) : FLOAT_TYPE(0.0f)) - max_val); vals[tid] += val; data_d[i] = D_TYPE(val); } @@ -2509,7 +2520,13 @@ void main() { const D_TYPE divisor = D_TYPE(vals[0]); - [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) { + [[unroll]] for (uint col0 = 0; col0 < p.KX; col0 += BLOCK_SIZE) { + const uint col = col0 + tid; + + if (col >= p.KX) { + break; + } + data_d[rowx*p.KX + col] /= divisor; } } @@ -2672,20 +2689,26 @@ argsort_src = """ #extension GL_EXT_shader_16bit_storage : require -layout(local_size_x = 1024, local_size_y = 1, local_size_z = 1) in; +#define BLOCK_SIZE 1024 +#define ASC 0 + +layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; layout (binding = 1) buffer D {int data_d[];}; layout (push_constant) uniform parameter { uint ncols; - bool ascending; + uint ncols_pad; + uint order; } p; +shared int dst_row[BLOCK_SIZE]; + void swap(uint idx0, uint idx1) { - int tmp = data_d[idx0]; - data_d[idx0] = data_d[idx1]; - data_d[idx1] = tmp; + int tmp = dst_row[idx0]; + dst_row[idx0] = dst_row[idx1]; + dst_row[idx1] = tmp; } void main() { @@ -2693,36 +2716,45 @@ void main() { const int col = int(gl_LocalInvocationID.x); const uint row = gl_WorkGroupID.y; - if (col >= p.ncols) { + if (col >= p.ncols_pad) { return; } - const uint a_idx = row * p.ncols; - const uint d_idx = row * p.ncols; + const uint row_offset = row * p.ncols; // initialize indices - if (col < p.ncols) { - data_d[col] = col; - } + dst_row[col] = col; barrier(); - for (uint k = 2; k <= p.ncols; k *= 2) { + for (uint k = 2; k <= p.ncols_pad; k *= 2) { for (uint j = k / 2; j > 0; j /= 2) { const uint ixj = col ^ j; if (ixj > col) { if ((col & k) == 0) { - if (p.ascending ? data_a[a_idx + data_d[d_idx + col]] > data_a[a_idx + data_d[d_idx + ixj]] : data_a[a_idx + data_d[d_idx + col]] < data_a[a_idx + data_d[d_idx + ixj]]) { - swap(d_idx + col, d_idx + ixj); + if (dst_row[col] >= p.ncols || + (dst_row[ixj] < p.ncols && (p.order == ASC ? + data_a[row_offset + dst_row[col]] > data_a[row_offset + dst_row[ixj]] : + data_a[row_offset + dst_row[col]] < data_a[row_offset + dst_row[ixj]])) + ) { + swap(col, ixj); } } else { - if (p.ascending ? data_a[a_idx + data_d[d_idx + col]] < data_a[a_idx + data_d[d_idx + ixj]] : data_a[a_idx + data_d[d_idx + col]] > data_a[a_idx + data_d[d_idx + ixj]]) { - swap(d_idx + col, d_idx + ixj); + if (dst_row[ixj] >= p.ncols || + (dst_row[col] < p.ncols && (p.order == ASC ? + data_a[row_offset + dst_row[col]] < data_a[row_offset + dst_row[ixj]] : + data_a[row_offset + dst_row[col]] > data_a[row_offset + dst_row[ixj]])) + ) { + swap(col, ixj); } } } barrier(); } } + + if (col < p.ncols) { + data_d[row_offset + col] = dst_row[col]; + } } """ |