summary refs log tree commit diff
path: root/ggml_vk_generate_shaders.py
diff options
context:
space:
mode:
Diffstat (limited to 'ggml_vk_generate_shaders.py')
-rw-r--r-- ggml_vk_generate_shaders.py 86
1 files changed, 59 insertions, 27 deletions
diff --git a/ggml_vk_generate_shaders.py b/ggml_vk_generate_shaders.py
index 162cf5c6..3c22a986 100644
--- a/ggml_vk_generate_shaders.py
+++ b/ggml_vk_generate_shaders.py
@@ -2432,7 +2432,6 @@ layout (push_constant) uniform parameter
{
uint KX;
uint KY;
- uint KZ;
float scale;
float max_bias;
float m0;
@@ -2449,8 +2448,7 @@ layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) readonly buffer Y {B_TYPE data_b[];};
-layout (binding = 2) readonly buffer Z {C_TYPE data_c[];};
-layout (binding = 3) buffer D {D_TYPE data_d[];};
+layout (binding = 2) buffer D {D_TYPE data_d[];};
shared FLOAT_TYPE vals[BLOCK_SIZE];
@@ -2459,7 +2457,7 @@ void main() {
const uint rowx = gl_WorkGroupID.x;
const uint rowy = rowx % p.KY;
- float slope = 0.0f;
+ float slope = 1.0f;
// ALiBi
if (p.max_bias > 0.0f) {
@@ -2472,11 +2470,18 @@ void main() {
}
// Find max
- vals[tid] = uintBitsToFloat(0xFF800000);
+ FLOAT_TYPE max_val = uintBitsToFloat(0xFF800000);
- [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
- vals[tid] = max(vals[tid], FLOAT_TYPE(data_a[rowx * p.KX + col]) * p.scale + (p.KY > 0 ? FLOAT_TYPE(data_b[rowy * p.KX + col]) : FLOAT_TYPE(0.0f)) + (p.KZ > 0 ? slope * FLOAT_TYPE(data_c[col]) : 0.0f));
+ [[unroll]] for (uint col0 = 0; col0 < p.KX; col0 += BLOCK_SIZE) {
+ const uint col = col0 + tid;
+
+ if (col >= p.KX) {
+ break;
+ }
+
+ max_val = max(max_val, FLOAT_TYPE(data_a[rowx * p.KX + col]) * p.scale + (p.KY > 0 ? slope * FLOAT_TYPE(data_b[rowy * p.KX + col]) : FLOAT_TYPE(0.0f)));
}
+ vals[tid] = max_val;
barrier();
[[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
@@ -2486,15 +2491,21 @@ void main() {
barrier();
}
- const FLOAT_TYPE max_val = vals[0];
+ max_val = vals[0];
barrier();
// Sum up values
vals[tid] = FLOAT_TYPE(0.0f);
- [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
+ [[unroll]] for (uint col0 = 0; col0 < p.KX; col0 += BLOCK_SIZE) {
+ const uint col = col0 + tid;
+
+ if (col >= p.KX) {
+ break;
+ }
+
const uint i = rowx * p.KX + col;
- const FLOAT_TYPE val = exp(FLOAT_TYPE(data_a[i]) * p.scale + (p.KY > 0 ? FLOAT_TYPE(data_b[rowy * p.KX + col]) : FLOAT_TYPE(0.0f)) - max_val);
+ const FLOAT_TYPE val = exp(FLOAT_TYPE(data_a[i]) * p.scale + (p.KY > 0 ? slope * FLOAT_TYPE(data_b[rowy * p.KX + col]) : FLOAT_TYPE(0.0f)) - max_val);
vals[tid] += val;
data_d[i] = D_TYPE(val);
}
@@ -2509,7 +2520,13 @@ void main() {
const D_TYPE divisor = D_TYPE(vals[0]);
- [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
+ [[unroll]] for (uint col0 = 0; col0 < p.KX; col0 += BLOCK_SIZE) {
+ const uint col = col0 + tid;
+
+ if (col >= p.KX) {
+ break;
+ }
+
data_d[rowx*p.KX + col] /= divisor;
}
}
@@ -2672,20 +2689,26 @@ argsort_src = """
#extension GL_EXT_shader_16bit_storage : require
-layout(local_size_x = 1024, local_size_y = 1, local_size_z = 1) in;
+#define BLOCK_SIZE 1024
+#define ASC 0
+
+layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 1) buffer D {int data_d[];};
layout (push_constant) uniform parameter {
uint ncols;
- bool ascending;
+ uint ncols_pad;
+ uint order;
} p;
+shared int dst_row[BLOCK_SIZE];
+
void swap(uint idx0, uint idx1) {
- int tmp = data_d[idx0];
- data_d[idx0] = data_d[idx1];
- data_d[idx1] = tmp;
+ int tmp = dst_row[idx0];
+ dst_row[idx0] = dst_row[idx1];
+ dst_row[idx1] = tmp;
}
void main() {
@@ -2693,36 +2716,45 @@ void main() {
const int col = int(gl_LocalInvocationID.x);
const uint row = gl_WorkGroupID.y;
- if (col >= p.ncols) {
+ if (col >= p.ncols_pad) {
return;
}
- const uint a_idx = row * p.ncols;
- const uint d_idx = row * p.ncols;
+ const uint row_offset = row * p.ncols;
// initialize indices
- if (col < p.ncols) {
- data_d[col] = col;
- }
+ dst_row[col] = col;
barrier();
- for (uint k = 2; k <= p.ncols; k *= 2) {
+ for (uint k = 2; k <= p.ncols_pad; k *= 2) {
for (uint j = k / 2; j > 0; j /= 2) {
const uint ixj = col ^ j;
if (ixj > col) {
if ((col & k) == 0) {
- if (p.ascending ? data_a[a_idx + data_d[d_idx + col]] > data_a[a_idx + data_d[d_idx + ixj]] : data_a[a_idx + data_d[d_idx + col]] < data_a[a_idx + data_d[d_idx + ixj]]) {
- swap(d_idx + col, d_idx + ixj);
+ if (dst_row[col] >= p.ncols ||
+ (dst_row[ixj] < p.ncols && (p.order == ASC ?
+ data_a[row_offset + dst_row[col]] > data_a[row_offset + dst_row[ixj]] :
+ data_a[row_offset + dst_row[col]] < data_a[row_offset + dst_row[ixj]]))
+ ) {
+ swap(col, ixj);
}
} else {
- if (p.ascending ? data_a[a_idx + data_d[d_idx + col]] < data_a[a_idx + data_d[d_idx + ixj]] : data_a[a_idx + data_d[d_idx + col]] > data_a[a_idx + data_d[d_idx + ixj]]) {
- swap(d_idx + col, d_idx + ixj);
+ if (dst_row[ixj] >= p.ncols ||
+ (dst_row[col] < p.ncols && (p.order == ASC ?
+ data_a[row_offset + dst_row[col]] < data_a[row_offset + dst_row[ixj]] :
+ data_a[row_offset + dst_row[col]] > data_a[row_offset + dst_row[ixj]]))
+ ) {
+ swap(col, ixj);
}
}
}
barrier();
}
}
+
+ if (col < p.ncols) {
+ data_d[row_offset + col] = dst_row[col];
+ }
}
"""