summary | refs | log | tree | commit | diff
path: root/ggml_vk_generate_shaders.py
diff options
context:
space:
mode:
author: 0cc4m <picard12@live.de> 2024-03-05 13:33:42 +0100
committer: GitHub <noreply@github.com> 2024-03-05 13:33:42 +0100
commit: 61d1c88e155515dd03940913a5707ea84a8b119b (patch)
tree: c2c7de9900b33a73a6fba4299523b54528676e1f /ggml_vk_generate_shaders.py
parent: 21b08674331e1ea1b599f17c5ca91f0ed173be31 (diff)
Vulkan Improvements (#5835)
* Improve dequant shaders, add fast q4_0 dequant
* Optimize dmmv non-kquants for GCN
  Remove unnecessary SPIR-V shader duplication
* Fix q4_0 dequant dispatch sizes
  Fix backend free bug
* Optimize dequant shaders for q4_1, q5_0, q5_1 and q8_0
* Add unary and binary op shader templates
* Fix Vulkan check results
* Enable non-contiguous support for simple ops
* Add argsort
  Basic q4_0 mmq shader and unit test
* Speed up q4_0 dequant code, enable mmq for q4_0
* Rework matmul pipeline selection
* Add soft_max alibi support
* Add q4_1, q5_0, q5_1 and q8_0 dequant mat mat mul shaders
* Add environment variable GGML_VK_FORCE_MAX_ALLOCATION_SIZE to limit max buffer size
  Rename GGML_VULKAN_DISABLE_F16 to GGML_VK_DISABLE_F16 for consistency
Diffstat (limited to 'ggml_vk_generate_shaders.py')
-rw-r--r-- ggml_vk_generate_shaders.py 1171
1 file changed, 712 insertions, 459 deletions
diff --git a/ggml_vk_generate_shaders.py b/ggml_vk_generate_shaders.py
index b2e86e18..4a6f5e32 100644
--- a/ggml_vk_generate_shaders.py
+++ b/ggml_vk_generate_shaders.py
@@ -64,6 +64,7 @@ struct block_q5_0
#define A_TYPE block_q5_0
"""
shader_q5_1_defines = """
+#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
#define QUANT_K 32
#define QUANT_R 2
@@ -187,7 +188,8 @@ v = (v - 16.0f) * d;
shader_q5_1_dequant_func = """
#define DEQUANT_FUNC const float d = float(data_a[ib].d); \
const float m = float(data_a[ib].m); \
-const ivec2 qh = ivec2(((data_a[ib].qh >> iqs) << 4) & 0x10, (data_a[ib].qh >> (iqs + 12)) & 0x10); \
+const uint uint_qh = data_a[ib].qh; \
+const ivec2 qh = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10); \
const uint vui = uint(data_a[ib].qs[iqs]); \
vec2 v = vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y); \
v = v*d + m;
@@ -206,12 +208,15 @@ mulmat_head = """#version 450
#extension GL_EXT_control_flow_attributes : enable
#extension GL_EXT_shader_16bit_storage : require
-#ifndef LOAD_VEC
-#define LOAD_VEC 1
+#ifndef LOAD_VEC_A
+#define LOAD_VEC_A 1
+#endif
+#ifndef LOAD_VEC_B
+#define LOAD_VEC_B 1
#endif
"""
-mulmat_body = """
+mulmat_body1 = """
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
@@ -240,7 +245,7 @@ layout (push_constant) uniform parameter
layout (constant_id = 1) const uint BM = 64;
layout (constant_id = 2) const uint BN = 64;
-layout (constant_id = 3) const uint BK = 16;
+layout (constant_id = 3) const uint BK = 16; // Assumed to be 32 if working with a quant
layout (constant_id = 4) const uint WM = 32;
layout (constant_id = 5) const uint WN = 32;
layout (constant_id = 6) const uint WMITER = 2;
@@ -277,16 +282,19 @@ void main() {
const uint tiwr = tiw % (WSUBM / TM);
const uint tiwc = tiw / (WSUBM / TM);
- const uint loadr = gl_LocalInvocationID.x % (BK / LOAD_VEC);
- const uint loadc = gl_LocalInvocationID.x / (BK / LOAD_VEC);
+ const uint loadr_a = gl_LocalInvocationID.x % (BK / LOAD_VEC_A);
+ const uint loadc_a = gl_LocalInvocationID.x / (BK / LOAD_VEC_A);
+ const uint loadr_b = gl_LocalInvocationID.x % (BK / LOAD_VEC_B);
+ const uint loadc_b = gl_LocalInvocationID.x / (BK / LOAD_VEC_B);
- const uint loadstride = gl_WorkGroupSize.x * LOAD_VEC / BK;
+ const uint loadstride_a = gl_WorkGroupSize.x * LOAD_VEC_A / BK;
+ const uint loadstride_b = gl_WorkGroupSize.x * LOAD_VEC_B / BK;
const uint start_k = ik * p.k_split;
const uint end_k = min(p.K, (ik + 1) * p.k_split);
- uint pos_a = (batch_idx_a * p.batch_stride_a + ir * BM * p.stride_a + start_k) / LOAD_VEC;
- uint pos_b = (gl_GlobalInvocationID.z * p.batch_stride_b + ic * BN * p.stride_b + start_k) / LOAD_VEC;
+ uint pos_a = (batch_idx_a * p.batch_stride_a + ir * BM * p.stride_a + start_k) / LOAD_VEC_A;
+ uint pos_b = (gl_GlobalInvocationID.z * p.batch_stride_b + ic * BN * p.stride_b + start_k) / LOAD_VEC_B;
float sums[WMITER * TM * WNITER * TN];
FLOAT_TYPE cache_a[WMITER * TM];
@@ -297,61 +305,145 @@ void main() {
}
[[unroll]] for (uint block = start_k; block < end_k; block += BK) {
- [[unroll]] for (uint l = 0; l < BM; l += loadstride) {
-#if LOAD_VEC == 8
- const uint idx = pos_a + (loadc + l) * p.stride_a / LOAD_VEC + loadr;
- buf_a[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 0] = FLOAT_TYPE(data_a[idx][0].x);
- buf_a[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 1] = FLOAT_TYPE(data_a[idx][0].y);
- buf_a[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 2] = FLOAT_TYPE(data_a[idx][0].z);
- buf_a[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 3] = FLOAT_TYPE(data_a[idx][0].w);
- buf_a[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 4] = FLOAT_TYPE(data_a[idx][1].x);
- buf_a[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 5] = FLOAT_TYPE(data_a[idx][1].y);
- buf_a[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 6] = FLOAT_TYPE(data_a[idx][1].z);
- buf_a[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 7] = FLOAT_TYPE(data_a[idx][1].w);
-#elif LOAD_VEC == 4
- const uint idx = pos_a + (loadc + l) * p.stride_a / LOAD_VEC + loadr;
- buf_a[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 0] = FLOAT_TYPE(data_a[idx].x);
- buf_a[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 1] = FLOAT_TYPE(data_a[idx].y);
- buf_a[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 2] = FLOAT_TYPE(data_a[idx].z);
- buf_a[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 3] = FLOAT_TYPE(data_a[idx].w);
+ [[unroll]] for (uint l = 0; l < BM; l += loadstride_a) {"""
+
+mulmat_load_scalar = """
+#if LOAD_VEC_A == 8
+ const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
+ const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A;
+ buf_a[buf_idx ] = FLOAT_TYPE(data_a[idx][0].x);
+ buf_a[buf_idx + 1] = FLOAT_TYPE(data_a[idx][0].y);
+ buf_a[buf_idx + 2] = FLOAT_TYPE(data_a[idx][0].z);
+ buf_a[buf_idx + 3] = FLOAT_TYPE(data_a[idx][0].w);
+ buf_a[buf_idx + 4] = FLOAT_TYPE(data_a[idx][1].x);
+ buf_a[buf_idx + 5] = FLOAT_TYPE(data_a[idx][1].y);
+ buf_a[buf_idx + 6] = FLOAT_TYPE(data_a[idx][1].z);
+ buf_a[buf_idx + 7] = FLOAT_TYPE(data_a[idx][1].w);
+#elif LOAD_VEC_A == 4
+ const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
+ const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A;
+ buf_a[buf_idx ] = FLOAT_TYPE(data_a[idx].x);
+ buf_a[buf_idx + 1] = FLOAT_TYPE(data_a[idx].y);
+ buf_a[buf_idx + 2] = FLOAT_TYPE(data_a[idx].z);
+ buf_a[buf_idx + 3] = FLOAT_TYPE(data_a[idx].w);
#else
- if (ir * BM + loadc + l < p.M && block + loadr < end_k) {
- buf_a[(loadc + l) * (BK+1) + loadr] = FLOAT_TYPE(data_a[pos_a + (loadc + l) * p.stride_a + loadr]);
+ if (ir * BM + loadc_a + l < p.M && block + loadr_a < end_k) {
+ buf_a[(loadc_a + l) * (BK+1) + loadr_a] = FLOAT_TYPE(data_a[pos_a + (loadc_a + l) * p.stride_a + loadr_a]);
} else {
- buf_a[(loadc + l) * (BK+1) + loadr] = FLOAT_TYPE(0.0f);
+ buf_a[(loadc_a + l) * (BK+1) + loadr_a] = FLOAT_TYPE(0.0f);
}
#endif
+"""
+
+mulmat_load_q4_0 = """
+ const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
+ const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a;
+
+ const uint ib = idx / 16;
+ const uint iqs = idx & 0xF;
+
+ const float d = float(data_a[ib].d);
+ const uint vui = uint(data_a[ib].qs[iqs]);
+ const vec2 v = (vec2(vui & 0xF, vui >> 4) - 8.0f) * d;
+
+ buf_a[buf_idx ] = FLOAT_TYPE(v.x);
+ buf_a[buf_idx + 16] = FLOAT_TYPE(v.y);"""
+
+mulmat_load_q4_1 = """
+ const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
+ const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a;
+
+ const uint ib = idx / 16;
+ const uint iqs = idx & 0xF;
+
+ const float d = float(data_a[ib].d);
+ const float m = float(data_a[ib].m);
+ const uint vui = uint(data_a[ib].qs[iqs]);
+ const vec2 v = vec2(vui & 0xF, vui >> 4) * d + m;
+
+ buf_a[buf_idx ] = FLOAT_TYPE(v.x);
+ buf_a[buf_idx + 16] = FLOAT_TYPE(v.y);"""
+
+mulmat_load_q5_0 = """
+ const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
+ const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a;
+
+ const uint ib = idx / 16;
+ const uint iqs = idx & 0xF;
+
+ const float d = float(data_a[ib].d);
+ const uint uint_qh = uint(data_a[ib].qh[1]) << 16 | data_a[ib].qh[0];
+ const ivec2 qh = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10);
+ const uint vui = uint(data_a[ib].qs[iqs]);
+ const vec2 v = (vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y) - 16.0f) * d;
+
+ buf_a[buf_idx ] = FLOAT_TYPE(v.x);
+ buf_a[buf_idx + 16] = FLOAT_TYPE(v.y);"""
+
+mulmat_load_q5_1 = """
+ const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
+ const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a;
+
+ const uint ib = idx / 16;
+ const uint iqs = idx & 0xF;
+
+ const float d = float(data_a[ib].d);
+ const float m = float(data_a[ib].m);
+ const uint uint_qh = data_a[ib].qh;
+ const ivec2 qh = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10);
+ const uint vui = uint(data_a[ib].qs[iqs]);
+ const vec2 v = vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y) * d + m;
+
+ buf_a[buf_idx ] = FLOAT_TYPE(v.x);
+ buf_a[buf_idx + 16] = FLOAT_TYPE(v.y);"""
+
+mulmat_load_q8_0 = """
+ const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
+ const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A;
+
+ const uint ib = idx / 16;
+ const uint iqs = (idx & 0xF) * 2;
+
+ const float d = float(data_a[ib].d);
+ const vec2 v = vec2(int(data_a[ib].qs[iqs]), int(data_a[ib].qs[iqs + 1])) * d;
+
+ buf_a[buf_idx ] = FLOAT_TYPE(v.x);
+ buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);"""
+
+mulmat_body2 = """
}
- [[unroll]] for (uint l = 0; l < BN; l += loadstride) {
-#if LOAD_VEC == 8
- const uint idx = pos_b + (loadc + l) * p.stride_b / LOAD_VEC + loadr;
- buf_b[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 0] = FLOAT_TYPE(data_b[idx][0].x);
- buf_b[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 1] = FLOAT_TYPE(data_b[idx][0].y);
- buf_b[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 2] = FLOAT_TYPE(data_b[idx][0].z);
- buf_b[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 3] = FLOAT_TYPE(data_b[idx][0].w);
- buf_b[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 4] = FLOAT_TYPE(data_b[idx][1].x);
- buf_b[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 5] = FLOAT_TYPE(data_b[idx][1].y);
- buf_b[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 6] = FLOAT_TYPE(data_b[idx][1].z);
- buf_b[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 7] = FLOAT_TYPE(data_b[idx][1].w);
-#elif LOAD_VEC == 4
- const uint idx = pos_b + (loadc + l) * p.stride_b / LOAD_VEC + loadr;
- buf_b[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 0] = FLOAT_TYPE(data_b[idx].x);
- buf_b[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 1] = FLOAT_TYPE(data_b[idx].y);
- buf_b[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 2] = FLOAT_TYPE(data_b[idx].z);
- buf_b[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 3] = FLOAT_TYPE(data_b[idx].w);
+ [[unroll]] for (uint l = 0; l < BN; l += loadstride_b) {
+#if LOAD_VEC_B == 8
+ const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b;
+ const uint buf_idx = (loadc_b + l) * (BK+1) + loadr_b * LOAD_VEC_B;
+ buf_b[buf_idx + 0] = FLOAT_TYPE(data_b[idx][0].x);
+ buf_b[buf_idx + 1] = FLOAT_TYPE(data_b[idx][0].y);
+ buf_b[buf_idx + 2] = FLOAT_TYPE(data_b[idx][0].z);
+ buf_b[buf_idx + 3] = FLOAT_TYPE(data_b[idx][0].w);
+ buf_b[buf_idx + 4] = FLOAT_TYPE(data_b[idx][1].x);
+ buf_b[buf_idx + 5] = FLOAT_TYPE(data_b[idx][1].y);
+ buf_b[buf_idx + 6] = FLOAT_TYPE(data_b[idx][1].z);
+ buf_b[buf_idx + 7] = FLOAT_TYPE(data_b[idx][1].w);
+#elif LOAD_VEC_B == 4
+ const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b;
+ const uint buf_idx = (loadc_b + l) * (BK+1) + loadr_b * LOAD_VEC_B;
+ buf_b[buf_idx + 0] = FLOAT_TYPE(data_b[idx].x);
+ buf_b[buf_idx + 1] = FLOAT_TYPE(data_b[idx].y);
+ buf_b[buf_idx + 2] = FLOAT_TYPE(data_b[idx].z);
+ buf_b[buf_idx + 3] = FLOAT_TYPE(data_b[idx].w);
#else
- if (ic * BN + loadc + l < p.N && block + loadr < end_k) {
- buf_b[(loadc + l) * (BK+1) + loadr] = FLOAT_TYPE(data_b[pos_b + (loadc + l) * p.stride_b + loadr]);
+ if (ic * BN + loadc_b + l < p.N && block + loadr_b < end_k) {
+ buf_b[(loadc_b + l) * (BK+1) + loadr_b] = FLOAT_TYPE(data_b[pos_b + (loadc_b + l) * p.stride_b + loadr_b]);
} else {
- buf_b[(loadc + l) * (BK+1) + loadr] = FLOAT_TYPE(0.0f);
+ buf_b[(loadc_b + l) * (BK+1) + loadr_b] = FLOAT_TYPE(0.0f);
}
#endif
}
barrier();
- pos_a += BK / LOAD_VEC;
- pos_b += BK / LOAD_VEC;
+ pos_a += BK / LOAD_VEC_A;
+ pos_b += BK / LOAD_VEC_B;
for (uint i = 0; i < BK; i++) {
// Load from shared into cache
@@ -438,45 +530,191 @@ dequant_head = """#version 450
#extension GL_EXT_control_flow_attributes : require
#extension GL_EXT_shader_16bit_storage : require
+
+layout (push_constant) uniform parameter
+{
+ uint M;
+ uint K;
+ uint stride_a;
+ uint stride_b;
+ uint nel;
+} p;
"""
-dequant_body = """
+dequant_f32_body = """
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
-layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
+layout (binding = 0) readonly buffer A {float data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
-layout (push_constant) uniform parameter
-{
- int M;
- int K;
- int stride_a;
- int stride_b;
-} p;
+void main() {
+ const uint i = gl_GlobalInvocationID.x * 16;
+
+ if (i >= p.nel) {
+ return;
+ }
+
+ [[unroll]] for (uint l = 0; l < 16; l++) {
+ data_b[i + l] = D_TYPE(data_a[i + l]);
+ }
+}
+"""
+
+dequant_q4_0_body = """
+layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer A {block_q4_0 data_a[];};
+layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
void main() {
- const int i = int(gl_GlobalInvocationID.x);
+ const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64;
+
+ const uint tid = gl_LocalInvocationID.x % 64;
+ const uint il = tid/32;
+ const uint ir = tid%32;
+ const uint ib = 32*i + ir;
+ if (ib >= p.nel / 32) {
+ return;
+ }
+
+ const uint b_idx = 1024*i + 32*ir + 8*il;
+
+ const float d = float(data_a[ib].d);
+ const float dm = -8.0f * d;
- // Transposed
- const int row = i % (p.K / QUANT_K);
- const int col = i / (p.K / QUANT_K);
+ const uint q_idx = 8*il;
- if (row * QUANT_K >= p.K || col >= p.M) {
+ [[unroll]] for (uint l = 0; l < 8; ++l) {
+ data_b[b_idx + l + 0] = D_TYPE(d * (data_a[ib].qs[q_idx + l] & 0xF) + dm);
+ data_b[b_idx + l + 16] = D_TYPE(d * (data_a[ib].qs[q_idx + l] >> 4) + dm);
+ }
+}
+"""
+
+dequant_q4_1_body = """
+layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer A {block_q4_1 data_a[];};
+layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
+
+void main() {
+ const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64;
+
+ const uint tid = gl_LocalInvocationID.x % 64;
+ const uint il = tid/32;
+ const uint ir = tid%32;
+ const uint ib = 32*i + ir;
+ if (ib >= p.nel / 32) {
return;
}
- const int stride_a = p.stride_a / QUANT_K;
+ const uint b_idx = 1024*i + 32*ir + 8*il;
- const int ib = col * stride_a + row;
+ const float d = float(data_a[ib].d);
+ const float m = float(data_a[ib].m);
- const int y_offset = QUANT_R == 1 ? 1 : QUANT_K/2;
- const int step = QUANT_R == 1 ? 2 : 1;
+ const uint q_idx = 8*il;
- [[unroll]] for (int iqs = 0; iqs < QUANT_K/QUANT_R; iqs += step) {
- DEQUANT_FUNC
+ [[unroll]] for (uint l = 0; l < 8; ++l) {
+ data_b[b_idx + l + 0] = D_TYPE(d * (data_a[ib].qs[q_idx + l] & 0xF) + m);
+ data_b[b_idx + l + 16] = D_TYPE(d * (data_a[ib].qs[q_idx + l] >> 4) + m);
+ }
+}
+"""
+
+dequant_q5_0_body = """
+layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer A {block_q5_0 data_a[];};
+layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
+
+void main() {
+ const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64;
+
+ const uint tid = gl_LocalInvocationID.x % 64;
+ const uint il = tid/32;
+ const uint ir = tid%32;
+ const uint ib = 32*i + ir;
+ if (ib >= p.nel / 32) {
+ return;
+ }
+
+ const uint b_idx = 1024*i + 32*ir + 8*il;
+
+ const float d = float(data_a[ib].d);
+ const uint qh = uint(data_a[ib].qh[1]) << 16 | data_a[ib].qh[0];
+
+ const uint q_idx = 8*il;
+
+ [[unroll]] for (uint l = 0; l < 8; ++l) {
+ const uint iqs = q_idx + l;
+ const uint vui = uint(data_a[ib].qs[iqs]);
+ data_b[b_idx + l + 0] = D_TYPE(d * (((vui & 0xF) | (((qh >> iqs) << 4) & 0x10)) - 16.0f));
+ data_b[b_idx + l + 16] = D_TYPE(d * (((vui >> 4) | ((qh >> (iqs + 12)) & 0x10)) - 16.0f));
+ }
+}
+"""
+
+dequant_q5_1_body = """
+layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer A {block_q5_1 data_a[];};
+layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
+
+void main() {
+ const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64;
+
+ const uint tid = gl_LocalInvocationID.x % 64;
+ const uint il = tid/32;
+ const uint ir = tid%32;
+ const uint ib = 32*i + ir;
+ if (ib >= p.nel / 32) {
+ return;
+ }
+
+ const uint b_idx = 1024*i + 32*ir + 8*il;
+
+ const float d = float(data_a[ib].d);
+ const float m = float(data_a[ib].m);
+ const uint qh = data_a[ib].qh;
+
+ const uint q_idx = 8*il;
+
+ [[unroll]] for (uint l = 0; l < 8; ++l) {
+ const uint iqs = q_idx + l;
+ const uint vui = uint(data_a[ib].qs[iqs]);
+ data_b[b_idx + l + 0] = D_TYPE(d * (((vui & 0xF) | (((qh >> iqs) << 4) & 0x10))) + m);
+ data_b[b_idx + l + 16] = D_TYPE(d * (((vui >> 4) | ((qh >> (iqs + 12)) & 0x10))) + m);
+ }
+}
+"""
+
+dequant_q8_0_body = """
+layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer A {block_q8_0 data_a[];};
+layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
+
+void main() {
+ const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64;
+
+ const uint tid = gl_LocalInvocationID.x % 64;
+ const uint il = tid/32;
+ const uint ir = tid%32;
+ const uint ib = 32*i + ir;
+ if (ib >= p.nel / 32) {
+ return;
+ }
- data_b[col * p.stride_b + row*QUANT_K + iqs + 0 ] = D_TYPE(v.x);
- data_b[col * p.stride_b + row*QUANT_K + iqs + y_offset] = D_TYPE(v.y);
+ const uint b_idx = 1024*i + 32*ir + 16*il;
+
+ const float d = float(data_a[ib].d);
+
+ const uint q_idx = 16*il;
+
+ [[unroll]] for (uint l = 0; l < 16; l += 2) {
+ data_b[b_idx + l ] = D_TYPE(d * data_a[ib].qs[q_idx + l ]);
+ data_b[b_idx + l + 1] = D_TYPE(d * data_a[ib].qs[q_idx + l + 1]);
}
}
"""
@@ -488,29 +726,21 @@ layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
-layout (push_constant) uniform parameter
-{
- int M;
- int K;
- int stride_a;
- int stride_b;
-} p;
-
void main() {
- [[unroll]] for (int wgy = 0; wgy < 256; wgy++) {
- const int i = int(gl_WorkGroupID.x * 256 + wgy);
+ [[unroll]] for (uint wgy = 0; wgy < 256; wgy++) {
+ const uint i = gl_WorkGroupID.x * 256 + wgy;
if (i >= p.M * p.K / QUANT_K) {
return;
}
- const int tid = int(gl_LocalInvocationID.x);
- const int ip = tid / 32;
- const int il = tid - 32 * ip;
- const int is = 8 * ip + il / 16;
+ const uint tid = gl_LocalInvocationID.x;
+ const uint ip = tid / 32;
+ const uint il = tid - 32 * ip;
+ const uint is = 8 * ip + il / 16;
- const int y_idx = i * QUANT_K + 128 * ip + il;
+ const uint y_idx = i * QUANT_K + 128 * ip + il;
- const int ql_idx = 32 * ip + il;
+ const uint ql_idx = 32 * ip + il;
const uint8_t qs = data_a[i].qs[32 * ip + il];
FLOAT_TYPE dall = FLOAT_TYPE(data_a[i].d.x);
@@ -528,31 +758,23 @@ layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
-layout (push_constant) uniform parameter
-{
- int M;
- int K;
- int stride_a;
- int stride_b;
-} p;
-
void main() {
- [[unroll]] for (int wgy = 0; wgy < 256; wgy++) {
- const int i = int(gl_WorkGroupID.x * 256 + wgy);
+ [[unroll]] for (uint wgy = 0; wgy < 256; wgy++) {
+ const uint i = uint(gl_WorkGroupID.x * 256 + wgy);
if (i >= p.M * p.K / QUANT_K) {
return;
}
- const int r = int(gl_LocalInvocationID.x) / 4;
- const int tid = r / 2;
- const int is0 = r % 2;
- const int l0 = 16 * is0 + 4 * (int(gl_LocalInvocationID.x) % 4);
- const int n = tid / 4;
- const int j = tid - 4*n;
+ const uint r = gl_LocalInvocationID.x / 4;
+ const uint tid = r / 2;
+ const uint is0 = r % 2;
+ const uint l0 = 16 * is0 + 4 * (gl_LocalInvocationID.x % 4);
+ const uint n = tid / 4;
+ const uint j = tid - 4*n;
const uint8_t m = uint8_t(1 << (4*n + j));
- const int is = 8*n + 2*j + is0;
- const int shift = 2*j;
+ const uint is = 8*n + 2*j + is0;
+ const uint shift = 2*j;
const int8_t us = int8_t(is < 4 ? (data_a[i].scales[is-0] & 0xF) | (((data_a[i].scales[is+8] >> 0) & 3) << 4) :
is < 8 ? (data_a[i].scales[is-0] & 0xF) | (((data_a[i].scales[is+4] >> 2) & 3) << 4) :
@@ -561,10 +783,10 @@ void main() {
const FLOAT_TYPE d_all = FLOAT_TYPE(data_a[i].d);
const FLOAT_TYPE dl = d_all * FLOAT_TYPE(us - 32);
- const int y_idx = i * QUANT_K + 128 * n + 32 * j;
- const int qs_idx = 32*n;
+ const uint y_idx = i * QUANT_K + 128 * n + 32 * j;
+ const uint qs_idx = 32*n;
- for (int l = l0; l < l0 + 4; ++l) {
+ for (uint l = l0; l < l0 + 4; ++l) {
data_b[y_idx + l] = D_TYPE(dl * FLOAT_TYPE(int8_t((data_a[i].qs[qs_idx + l] >> shift) & 3) - (((data_a[i].hmask[l] & m) != 0) ? 0 : 4)));
}
}
@@ -576,32 +798,24 @@ layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
-layout (push_constant) uniform parameter
-{
- int M;
- int K;
- int stride_a;
- int stride_b;
-} p;
-
void main() {
- [[unroll]] for (int wgy = 0; wgy < 256; wgy++) {
- const int i = int(gl_WorkGroupID.x * 256 + wgy);
+ [[unroll]] for (uint wgy = 0; wgy < 256; wgy++) {
+ const uint i = gl_WorkGroupID.x * 256 + wgy;
if (i >= p.M * p.K / QUANT_K) {
return;
}
- const int tid = int(gl_LocalInvocationID.x);
- const int il = tid / 8;
- const int ir = tid % 8;
- const int is = 2 * il;
- const int n = 4;
+ const uint tid = gl_LocalInvocationID.x;
+ const uint il = tid / 8;
+ const uint ir = tid % 8;
+ const uint is = 2 * il;
+ const uint n = 4;
const FLOAT_TYPE dall = FLOAT_TYPE(data_a[i].d.x);
const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[i].d.y);
- const int y_idx = i * QUANT_K + 64 * il + n * ir;
- const int qs_idx = 32*il + n * ir;
+ const uint y_idx = i * QUANT_K + 64 * il + n * ir;
+ const uint qs_idx = 32*il + n * ir;
uint8_t sc;
uint8_t m;
@@ -625,7 +839,7 @@ void main() {
const FLOAT_TYPE d2 = dall * sc;
const FLOAT_TYPE m2 = dmin * m;
- [[unroll]] for (int l = 0; l < n; ++l) {
+ [[unroll]] for (uint l = 0; l < n; ++l) {
data_b[y_idx + l ] = D_TYPE(d1 * FLOAT_TYPE(data_a[i].qs[qs_idx + l] & 0xF) - m1);
data_b[y_idx + l + 32] = D_TYPE(d2 * FLOAT_TYPE(data_a[i].qs[qs_idx + l] >> 4) - m2);
}
@@ -638,32 +852,24 @@ layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
-layout (push_constant) uniform parameter
-{
- int M;
- int K;
- int stride_a;
- int stride_b;
-} p;
-
void main() {
- [[unroll]] for (int wgy = 0; wgy < 256; wgy++) {
- const int i = int(gl_WorkGroupID.x * 256 + wgy);
+ [[unroll]] for (uint wgy = 0; wgy < 256; wgy++) {
+ const uint i = gl_WorkGroupID.x * 256 + wgy;
if (i >= p.M * p.K / QUANT_K) {
return;
}
- const int tid = int(gl_LocalInvocationID.x);
- const int il = tid / 16;
- const int ir = tid % 16;
- const int is = 2 * il;
+ const uint tid = gl_LocalInvocationID.x;
+ const uint il = tid / 16;
+ const uint ir = tid % 16;
+ const uint is = 2 * il;
const FLOAT_TYPE dall = FLOAT_TYPE(data_a[i].d.x);
const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[i].d.y);
- const int y_idx = i * QUANT_K + 64 * il + 2 * ir;
- const int qs_idx = 32*il + 2 * ir;
- const int qh_idx = 2 * ir;
+ const uint y_idx = i * QUANT_K + 64 * il + 2 * ir;
+ const uint qs_idx = 32*il + 2 * ir;
+ const uint qh_idx = 2 * ir;
uint8_t sc;
uint8_t m;
@@ -702,28 +908,20 @@ layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
-layout (push_constant) uniform parameter
-{
- int M;
- int K;
- int stride_a;
- int stride_b;
-} p;
-
void main() {
- [[unroll]] for (int wgy = 0; wgy < 256; wgy++) {
- const int i = int(gl_WorkGroupID.x * 256 + wgy);
+ [[unroll]] for (uint wgy = 0; wgy < 256; wgy++) {
+ const uint i = gl_WorkGroupID.x * 256 + wgy;
if (i >= p.M * p.K / QUANT_K) {
return;
}
- const int tid = int(gl_LocalInvocationID.x);
- const int ip = tid / 32;
- const int il = tid - 32 * ip;
- const int is = 8 * ip + il / 16;
+ const uint tid = gl_LocalInvocationID.x;
+ const uint ip = tid / 32;
+ const uint il = tid - 32 * ip;
+ const uint is = 8 * ip + il / 16;
- const int y_idx = i * QUANT_K + 128 * ip + il;
+ const uint y_idx = i * QUANT_K + 128 * ip + il;
- const int ql_idx = 64 * ip + il;
+ const uint ql_idx = 64 * ip + il;
const uint8_t qh = data_a[i].qh[32 * ip + il];
const FLOAT_TYPE d = FLOAT_TYPE(data_a[i].d);
@@ -742,49 +940,50 @@ mul_mat_vec_head = """#version 450
#extension GL_EXT_control_flow_attributes : enable
#extension GL_EXT_shader_16bit_storage : require
#extension GL_EXT_shader_8bit_storage : require
+
+layout (push_constant) uniform parameter
+{
+ uint ncols;
+ uint b_offset;
+ uint d_offset;
+} p;
"""
mul_mat_vec_body = """
-layout(local_size_x = QUANT_K, local_size_y = 1, local_size_z = 1) in;
+layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
layout (binding = 2) writeonly buffer D {D_TYPE dst[];};
-layout (push_constant) uniform parameter
-{
- int ncols;
- int b_offset;
- int d_offset;
-} p;
+layout (constant_id = 0) const uint BLOCK_SIZE = 32;
-shared FLOAT_TYPE tmp[QUANT_K];
+shared FLOAT_TYPE tmp[BLOCK_SIZE];
void main() {
- const int block_size = int(gl_WorkGroupSize.x);
- const int row = int(gl_WorkGroupID.x);
- const int tid = int(gl_LocalInvocationID.x);
+ const uint row = gl_WorkGroupID.x;
+ const uint tid = gl_LocalInvocationID.x;
- const int y_offset = QUANT_R == 1 ? 1 : QUANT_K/2;
+ const uint y_offset = QUANT_R == 1 ? 1 : QUANT_K/2;
tmp[tid] = FLOAT_TYPE(0.0f);
- [[unroll]] for (int i = 0; i < p.ncols/block_size; i += 2) {
- const int col = i*block_size + 2*tid;
- const int ib = (row*p.ncols + col)/QUANT_K; // block index
- const int iqs = (col%QUANT_K)/QUANT_R; // quant index
- const int iybs = col - col%QUANT_K; // y block start index
+ [[unroll]] for (uint i = 0; i < p.ncols/BLOCK_SIZE; i += 2) {
+ const uint col = i*BLOCK_SIZE + 2*tid;
+ const uint ib = (row*p.ncols + col)/QUANT_K; // block index
+ const uint iqs = (col%QUANT_K)/QUANT_R; // quant index
+ const uint iybs = col - col%QUANT_K; // y block start index
DEQUANT_FUNC
// matrix multiplication
- tmp[tid] += FLOAT_TYPE(v.x) * FLOAT_TYPE(data_b[p.b_offset + iybs + iqs + 0]);
- tmp[tid] += FLOAT_TYPE(v.y) * FLOAT_TYPE(data_b[p.b_offset + iybs + iqs + y_offset]);
+ tmp[tid] += FLOAT_TYPE(v.x) * FLOAT_TYPE(data_b[p.b_offset + iybs + iqs + 0]) +
+ FLOAT_TYPE(v.y) * FLOAT_TYPE(data_b[p.b_offset + iybs + iqs + y_offset]);
}
// sum up partial sums and write back result
barrier();
- [[unroll]] for (int s = block_size/2; s > 0; s >>= 1) {
+ [[unroll]] for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) {
if (tid < s) {
tmp[tid] += tmp[tid + s];
}
@@ -804,38 +1003,31 @@ layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
layout (binding = 2) writeonly buffer D {D_TYPE dst[];};
-layout (push_constant) uniform parameter
-{
- int ncols;
- int b_offset;
- int d_offset;
-} p;
-
shared FLOAT_TYPE tmp[32];
void main() {
- const int row = int(gl_WorkGroupID.x);
+ const uint row = gl_WorkGroupID.x;
- const int num_blocks_per_row = p.ncols / QUANT_K;
- const int ib0 = row*num_blocks_per_row;
+ const uint num_blocks_per_row = p.ncols / QUANT_K;
+ const uint ib0 = row*num_blocks_per_row;
- const int tid = int(gl_LocalInvocationID.x)/K_QUANTS_PER_ITERATION; // 0...31 or 0...16
- const int ix = int(gl_LocalInvocationID.x)%K_QUANTS_PER_ITERATION; // 0 or 0, 1
+ const uint tid = gl_LocalInvocationID.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16
+ const uint ix = gl_LocalInvocationID.x%K_QUANTS_PER_ITERATION; // 0 or 0, 1
- const int step = 16/K_QUANTS_PER_ITERATION; // 16 or 8
+ const uint step = 16/K_QUANTS_PER_ITERATION; // 16 or 8
- const int v_im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128...
- const int v_in = tid - step*v_im; // 0...15 or 0...7
+ const uint v_im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128...
+ const uint v_in = tid - step*v_im; // 0...15 or 0...7
- const int l0 = K_QUANTS_PER_ITERATION*v_in; // 0...15
- const int q_offset = 32*v_im + l0;
- const int s_offset = 8*v_im;
- const int y_offset = 128*v_im + l0;
+ const uint l0 = K_QUANTS_PER_ITERATION*v_in; // 0...15
+ const uint q_offset = 32*v_im + l0;
+ const uint s_offset = 8*v_im;
+ const uint y_offset = 128*v_im + l0;
tmp[16 * ix + tid] = FLOAT_TYPE(0.0); // partial sum for thread in warp
- [[unroll]] for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
- const int y_idx = i * QUANT_K + y_offset;
+ [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
+ const uint y_idx = i * QUANT_K + y_offset;
const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib0 + i].d.x);
const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib0 + i].d.y);
@@ -865,7 +1057,7 @@ void main() {
// sum up partial sums and write back result
barrier();
- [[unroll]] for (int s = 16; s > 0; s >>= 1) {
+ [[unroll]] for (uint s = 16; s > 0; s >>= 1) {
if (tid < s) {
tmp[tid] += tmp[tid + s];
}
@@ -883,41 +1075,34 @@ layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
layout (binding = 2) writeonly buffer D {D_TYPE dst[];};
-layout (push_constant) uniform parameter
-{
- int ncols;
- int b_offset;
- int d_offset;
-} p;
-
shared FLOAT_TYPE tmp[32];
void main() {
- const int row = int(gl_WorkGroupID.x);
+ const uint row = gl_WorkGroupID.x;
- const int num_blocks_per_row = p.ncols / QUANT_K;
- const int ib0 = row*num_blocks_per_row;
+ const uint num_blocks_per_row = p.ncols / QUANT_K;
+ const uint ib0 = row*num_blocks_per_row;
- const int tid = int(gl_LocalInvocationID.x)/K_QUANTS_PER_ITERATION; // 0...31 or 0...16
- const int ix = int(gl_LocalInvocationID.x)%K_QUANTS_PER_ITERATION; // 0 or 0, 1
+ const uint tid = gl_LocalInvocationID.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16
+ const uint ix = gl_LocalInvocationID.x%K_QUANTS_PER_ITERATION; // 0 or 0, 1
- const int step = 16/K_QUANTS_PER_ITERATION; // 16 or 8
+ const uint step = 16/K_QUANTS_PER_ITERATION; // 16 or 8
- const int v_im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128...
- const int v_in = tid - step*v_im; // 0...15 or 0...7
+ const uint v_im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128...
+ const uint v_in = tid - step*v_im; // 0...15 or 0...7
const uint8_t m = uint8_t(1 << (4 * v_im));
- const int l0 = K_QUANTS_PER_ITERATION*v_in; // 0...15
- const int q_offset = 32*v_im + l0;
- const int y_offset = 128*v_im + l0;
+ const uint l0 = K_QUANTS_PER_ITERATION*v_in; // 0...15
+ const uint q_offset = 32*v_im + l0;
+ const uint y_offset = 128*v_im + l0;
tmp[16 * ix + tid] = FLOAT_TYPE(0.0); // partial sum for thread in warp
const uint s_shift = 4 * v_im;
- [[unroll]] for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
- const int y_idx = i * QUANT_K + y_offset;
+ [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
+ const uint y_idx = i * QUANT_K + y_offset;
const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d);
@@ -937,7 +1122,7 @@ void main() {
// sum up partial sums and write back result
barrier();
- [[unroll]] for (int s = 16; s > 0; s >>= 1) {
+ [[unroll]] for (uint s = 16; s > 0; s >>= 1) {
if (tid < s) {
tmp[tid] += tmp[tid + s];
}
@@ -955,42 +1140,35 @@ layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
layout (binding = 2) writeonly buffer D {D_TYPE dst[];};
-layout (push_constant) uniform parameter
-{
- int ncols;
- int b_offset;
- int d_offset;
-} p;
-
shared FLOAT_TYPE tmp[32];
void main() {
- const int row = int(gl_WorkGroupID.x);
+ const uint row = gl_WorkGroupID.x;
- const int num_blocks_per_row = p.ncols / QUANT_K;
- const int ib0 = row*num_blocks_per_row;
+ const uint num_blocks_per_row = p.ncols / QUANT_K;
+ const uint ib0 = row*num_blocks_per_row;
- const int tid = int(gl_LocalInvocationID.x)/K_QUANTS_PER_ITERATION; // 0...31 or 0...16
- const int ix = int(gl_LocalInvocationID.x)%K_QUANTS_PER_ITERATION; // 0 or 0, 1
+ const uint tid = gl_LocalInvocationID.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16
+ const uint ix = gl_LocalInvocationID.x%K_QUANTS_PER_ITERATION; // 0 or 0, 1
- const int step = 8/K_QUANTS_PER_ITERATION; // 8 or 4
+ const uint step = 8/K_QUANTS_PER_ITERATION; // 8 or 4
- const int il = tid/step; // 0...3
- const int ir = tid - step*il; // 0...7 or 0...3
- const int n = 2 * K_QUANTS_PER_ITERATION; // 2 or 4
+ const uint il = tid/step; // 0...3
+ const uint ir = tid - step*il; // 0...7 or 0...3
+ const uint n = 2 * K_QUANTS_PER_ITERATION; // 2 or 4
- const int v_im = il / 2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
- const int v_in = il % 2;
+ const uint v_im = il / 2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
+ const uint v_in = il % 2;
- const int l0 = n * (2 * ir + v_in); // 0...15
- const int q_offset = 32*v_im + l0;
- const int y_offset = 64*v_im + l0;
+ const uint l0 = n * (2 * ir + v_in); // 0...15
+ const uint q_offset = 32*v_im + l0;
+ const uint y_offset = 64*v_im + l0;
tmp[16 * ix + tid] = FLOAT_TYPE(0.0); // partial sum for thread in warp
- [[unroll]] for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
- const int y1_idx = i * QUANT_K + y_offset;
- const int y2_idx = y1_idx + 128;
+ [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
+ const uint y1_idx = i * QUANT_K + y_offset;
+ const uint y2_idx = y1_idx + 128;
const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib0 + i].d.x);
const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib0 + i].d.y);
@@ -1058,7 +1236,7 @@ void main() {
// sum up partial sums and write back result
barrier();
- [[unroll]] for (int s = 16; s > 0; s >>= 1) {
+ [[unroll]] for (uint s = 16; s > 0; s >>= 1) {
if (tid < s) {
tmp[tid] += tmp[tid + s];
}
@@ -1076,42 +1254,35 @@ layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
layout (binding = 2) writeonly buffer D {D_TYPE dst[];};
-layout (push_constant) uniform parameter
-{
- int ncols;
- int b_offset;
- int d_offset;
-} p;
-
shared FLOAT_TYPE tmp[32];
void main() {
- const int row = int(gl_WorkGroupID.x);
+ const uint row = gl_WorkGroupID.x;
- const int num_blocks_per_row = p.ncols / QUANT_K;
- const int ib0 = row*num_blocks_per_row;
+ const uint num_blocks_per_row = p.ncols / QUANT_K;
+ const uint ib0 = row*num_blocks_per_row;
- const int tid = int(gl_LocalInvocationID.x)/2; // 0...31 or 0...16
- const int ix = int(gl_LocalInvocationID.x)%2; // 0 or 0, 1
+ const uint tid = gl_LocalInvocationID.x/2; // 0...31 or 0...16
+ const uint ix = gl_LocalInvocationID.x%2; // 0 or 0, 1
- const int il = tid/4; // 0...3
- const int ir = tid - 4*il; // 0...7 or 0...3
+ const uint il = tid/4; // 0...3
+ const uint ir = tid - 4*il; // 0...7 or 0...3
- const int v_im = il / 2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
- const int v_in = il % 2;
+ const uint v_im = il / 2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
+ const uint v_in = il % 2;
- const int l0 = 4*ir + 2*v_in; // 0...15
- const int q_offset = 32*v_im + l0;
- const int y_offset = 64*v_im + l0;
+ const uint l0 = 4*ir + 2*v_in; // 0...15
+ const uint q_offset = 32*v_im + l0;
+ const uint y_offset = 64*v_im + l0;
const uint8_t hm1 = uint8_t(1 << (2*v_im));
const uint8_t hm2 = uint8_t(hm1 << 4);
tmp[16 * ix + tid] = FLOAT_TYPE(0.0); // partial sum for thread in warp
- [[unroll]] for (int i = ix; i < num_blocks_per_row; i += 2) {
- const int y1_idx = i * QUANT_K + y_offset;
- const int y2_idx = y1_idx + 128;
+ [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += 2) {
+ const uint y1_idx = i * QUANT_K + y_offset;
+ const uint y2_idx = y1_idx + 128;
const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib0 + i].d.x);
const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib0 + i].d.y);
@@ -1175,7 +1346,7 @@ void main() {
// sum up partial sums and write back result
barrier();
- [[unroll]] for (int s = 16; s > 0; s >>= 1) {
+ [[unroll]] for (uint s = 16; s > 0; s >>= 1) {
if (tid < s) {
tmp[tid] += tmp[tid + s];
}
@@ -1193,46 +1364,40 @@ layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
layout (binding = 2) writeonly buffer D {D_TYPE dst[];};
-layout (push_constant) uniform parameter
-{
- int ncols;
- int b_offset;
- int d_offset;
-} p;
-
shared FLOAT_TYPE tmp[32];
void main() {
- const int row = int(gl_WorkGroupID.x);
+ const uint block_size = gl_WorkGroupSize.x;
+ const uint row = gl_WorkGroupID.x;
- const int num_blocks_per_row = p.ncols / QUANT_K;
- const int ib0 = row*num_blocks_per_row;
+ const uint num_blocks_per_row = p.ncols / QUANT_K;
+ const uint ib0 = row*num_blocks_per_row;
- const int tid = int(gl_LocalInvocationID.x)/K_QUANTS_PER_ITERATION; // 0...31 or 0...16
- const int ix = int(gl_LocalInvocationID.x)%K_QUANTS_PER_ITERATION; // 0 or 0, 1
+ const uint tid = gl_LocalInvocationID.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16
+ const uint ix = gl_LocalInvocationID.x%K_QUANTS_PER_ITERATION; // 0 or 0, 1
- const int step = 16/K_QUANTS_PER_ITERATION; // 16 or 8
+ const uint step = 16/K_QUANTS_PER_ITERATION; // 16 or 8
- const int v_im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128...
- const int v_in = tid - step*v_im; // 0...15 or 0...7
+ const uint v_im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128...
+ const uint v_in = tid - step*v_im; // 0...15 or 0...7
#if K_QUANTS_PER_ITERATION == 1
- const int l0 = K_QUANTS_PER_ITERATION*v_in; // 0...15
- const int is = 0;
+ const uint l0 = v_in; // 0...15
+ const uint is = 0;
#else
- const int l0 = 4 * v_in; // 0, 4, 8, ..., 28
- const int is = v_in / 4;
+ const uint l0 = 4 * v_in; // 0, 4, 8, ..., 28
+ const uint is = v_in / 4;
#endif
- const int ql_offset = 64*v_im + l0;
- const int qh_offset = 32*v_im + l0;
- const int s_offset = 8*v_im + is;
- const int y_offset = 128*v_im + l0;
+ const uint ql_offset = 64*v_im + l0;
+ const uint qh_offset = 32*v_im + l0;
+ const uint s_offset = 8*v_im + is;
+ const uint y_offset = 128*v_im + l0;
tmp[16 * ix + tid] = FLOAT_TYPE(0.0); // partial sum for thread in warp
- [[unroll]] for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
- const int y_idx = i * QUANT_K + y_offset;
+ [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
+ const uint y_idx = i * QUANT_K + y_offset;
const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d);
@@ -1260,7 +1425,7 @@ void main() {
// sum up partial sums and write back result
barrier();
- [[unroll]] for (int s = 16; s > 0; s >>= 1) {
+ [[unroll]] for (uint s = 16; s > 0; s >>= 1) {
if (tid < s) {
tmp[tid] += tmp[tid + s];
}
@@ -1421,34 +1586,6 @@ void main() {
}
"""
-# F16 to F32
-f32_to_f16_src = """#version 450
-
-#extension GL_EXT_shader_16bit_storage : require
-
-layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in;
-
-layout (binding = 0) readonly buffer A {float data_a[];};
-layout (binding = 1) writeonly buffer D {float16_t data_b[];};
-
-layout (push_constant) uniform parameter
-{
- int M;
- int K;
- int stride_a;
- int stride_b;
-} p;
-
-void main() {
- const int row = int(gl_GlobalInvocationID.x % p.K);
- const int col = int(gl_GlobalInvocationID.x / p.K);
-
- if (row < p.K && col < p.M) {
- data_b[col * p.stride_b + row] = float16_t(data_a[col * p.stride_a + row]);
- }
-}
-"""
-
generic_head = """
#version 450
@@ -1463,136 +1600,147 @@ layout (push_constant) uniform parameter
} p;
"""
-# MUL F32
-mul_body = """layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
-
-layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
-layout (binding = 1) readonly buffer Y {B_TYPE data_b[];};
-layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
-
-void main() {
- const uint idx = gl_GlobalInvocationID.x;
+generic_unary_op_head = """#version 450
- if (idx >= p.KX) {
- return;
- }
+#extension GL_EXT_shader_16bit_storage : require
- data_d[idx] = D_TYPE(FLOAT_TYPE(data_a[idx]) * FLOAT_TYPE(data_b[idx % p.KY]));
-}
-"""
+layout (push_constant) uniform parameter
+{
+ uint ne;
+ uint ne00; uint ne01; uint ne02; uint ne03; uint nb00; uint nb01; uint nb02; uint nb03;
+ uint ne10; uint ne11; uint ne12; uint ne13; uint nb10; uint nb11; uint nb12; uint nb13;
+ uint d_offset;
+ float param1; float param2;
+} p;
-# ADD
-add_body = """
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
-layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
-layout (binding = 1) readonly buffer Y {B_TYPE data_b[];};
-layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
-
-void main() {
- const uint idx = gl_GlobalInvocationID.x;
-
- if (idx >= p.KX) {
- return;
- }
-
- data_d[idx] = D_TYPE(FLOAT_TYPE(data_a[idx]) + FLOAT_TYPE(data_b[idx % p.KY]));
-}
-"""
-
-# SCALE
-scale_body = """layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
-
-layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
+layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
-void main() {
- const uint idx = gl_GlobalInvocationID.x;
-
- if (idx >= p.KX) {
- return;
- }
-
- data_d[idx] = D_TYPE(FLOAT_TYPE(data_a[idx]) * FLOAT_TYPE(p.param1));
+uint src0_idx(uint idx) {
+ const uint i03 = idx / (p.ne02*p.ne01*p.ne00);
+ const uint i03_offset = i03 * p.ne02*p.ne01*p.ne00;
+ const uint i02 = (idx - i03_offset) / (p.ne01*p.ne00);
+ const uint i02_offset = i02*p.ne01*p.ne00;
+ const uint i01 = (idx - i03_offset - i02_offset) / p.ne00;
+ const uint i00 = idx - i03_offset - i02_offset - i01*p.ne00;
+ return i03*p.nb03 + i02*p.nb02 + i01*p.nb01 + i00*p.nb00;
}
-"""
-
-# SQR
-sqr_body = """layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
-
-layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
-layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
-
-void main() {
- const uint idx = gl_GlobalInvocationID.x;
- if (idx >= p.KX) {
- return;
- }
-
- const FLOAT_TYPE val = FLOAT_TYPE(data_a[idx]);
- data_d[idx] = D_TYPE(val * val);
+uint dst_idx(uint idx) {
+ const uint i13 = idx / (p.ne12*p.ne11*p.ne10);
+ const uint i13_offset = i13 * p.ne12*p.ne11*p.ne10;
+ const uint i12 = (idx - i13_offset) / (p.ne11*p.ne10);
+ const uint i12_offset = i12*p.ne11*p.ne10;
+ const uint i11 = (idx - i13_offset - i12_offset) / p.ne10;
+ const uint i10 = idx - i13_offset - i12_offset - i11*p.ne10;
+ return i13*p.nb13 + i12*p.nb12 + i11*p.nb11 + i10*p.nb10;
}
-"""
-
-# CLAMP
-clamp_body = """layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
-
-layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
-layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
void main() {
- const uint idx = gl_GlobalInvocationID.x;
-
- if (idx >= p.KX) {
+ if (gl_GlobalInvocationID.x >= p.ne) {
return;
}
-
- const FLOAT_TYPE val = FLOAT_TYPE(data_a[idx]);
- data_d[idx] = D_TYPE(val < p.param1 ? p.param1 : (val > p.param2 ? p.param2 : val));
-}
"""
-# CPY
-cpy_src = """#version 450
+generic_binary_op_head = """#version 450
#extension GL_EXT_shader_16bit_storage : require
layout (push_constant) uniform parameter
{
uint ne;
- uint ne00; uint ne01; uint nb00; uint nb01; uint nb02;
- uint ne10; uint ne11; uint nb10; uint nb11; uint nb12;
+ uint ne00; uint ne01; uint ne02; uint ne03; uint nb00; uint nb01; uint nb02; uint nb03;
+ uint ne10; uint ne11; uint ne12; uint ne13; uint nb10; uint nb11; uint nb12; uint nb13;
+ uint ne20; uint ne21; uint ne22; uint ne23; uint nb20; uint nb21; uint nb22; uint nb23;
uint d_offset;
+ uint param1; uint param2;
} p;
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
-layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
+layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
+layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
+
+uint src0_idx(uint idx) {
+ const uint i03 = idx / (p.ne02*p.ne01*p.ne00);
+ const uint i03_offset = i03 * p.ne02*p.ne01*p.ne00;
+ const uint i02 = (idx - i03_offset) / (p.ne01*p.ne00);
+ const uint i02_offset = i02*p.ne01*p.ne00;
+ const uint i01 = (idx - i03_offset - i02_offset) / p.ne00;
+ const uint i00 = idx - i03_offset - i02_offset - i01*p.ne00;
+ return i03*p.nb03 + i02*p.nb02 + i01*p.nb01 + i00*p.nb00;
+}
+
+uint src1_idx(uint idx) {
+ const uint i03 = idx / (p.ne02*p.ne01*p.ne00);
+ const uint i03_offset = i03 * p.ne02*p.ne01*p.ne00;
+ const uint i02 = (idx - i03_offset) / (p.ne01*p.ne00);
+ const uint i02_offset = i02*p.ne01*p.ne00;
+ const uint i01 = (idx - i03_offset - i02_offset) / p.ne00;
+ const uint i00 = idx - i03_offset - i02_offset - i01*p.ne00;
+
+ return (i03 % p.ne13)*p.nb13 + (i02 % p.ne12)*p.nb12 + (i01 % p.ne11)*p.nb11 + (i00 % p.ne10)*p.nb10;
+}
+
+uint dst_idx(uint idx) {
+ const uint i23 = idx / (p.ne22*p.ne21*p.ne20);
+ const uint i23_offset = i23 * p.ne22*p.ne21*p.ne20;
+ const uint i22 = (idx - i23_offset) / (p.ne21*p.ne20);
+ const uint i22_offset = i22*p.ne21*p.ne20;
+ const uint i21 = (idx - i23_offset - i22_offset) / p.ne20;
+ const uint i20 = idx - i23_offset - i22_offset - i21*p.ne20;
+ return i23*p.nb23 + i22*p.nb22 + i21*p.nb21 + i20*p.nb20;
+}
void main() {
if (gl_GlobalInvocationID.x >= p.ne) {
return;
}
+"""
+
+# MUL F32: tail of the main() opened by generic_binary_op_head; dst = src0 * src1 for this invocation's element
+mul_body = """
+ data_d[p.d_offset + dst_idx(gl_GlobalInvocationID.x)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(gl_GlobalInvocationID.x)]) * FLOAT_TYPE(data_b[src1_idx(gl_GlobalInvocationID.x)]));
+}
+"""
+
+# ADD: tail of the main() opened by generic_binary_op_head; dst = src0 + src1 for this invocation's element
+add_body = """
+ data_d[p.d_offset + dst_idx(gl_GlobalInvocationID.x)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(gl_GlobalInvocationID.x)]) + FLOAT_TYPE(data_b[src1_idx(gl_GlobalInvocationID.x)]));
+}
+"""
- const uint i02 = gl_GlobalInvocationID.x / (p.ne00*p.ne01);
- const uint i01 = (gl_GlobalInvocationID.x - i02*p.ne01*p.ne00) / p.ne00;
- const uint i00 = gl_GlobalInvocationID.x - i02*p.ne01*p.ne00 - i01*p.ne00;
- const uint a_idx = i00*p.nb00 + i01*p.nb01 + i02*p.nb02;
+# SCALE: tail of the main() opened by generic_unary_op_head; dst = src0 * p.param1 (multiplicative scale, as in the shader it replaces)
+scale_body = """
+ data_d[p.d_offset + dst_idx(gl_GlobalInvocationID.x)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(gl_GlobalInvocationID.x)]) * FLOAT_TYPE(p.param1));
+}
+"""
- const uint i12 = gl_GlobalInvocationID.x / (p.ne10*p.ne11);
- const uint i11 = (gl_GlobalInvocationID.x - i12*p.ne11*p.ne10) / p.ne10;
- const uint i10 = gl_GlobalInvocationID.x - i12*p.ne11*p.ne10 - i11*p.ne10;
- const uint d_idx = i10*p.nb10 + i11*p.nb11 + i12*p.nb12;
+# SQR: tail of the main() opened by generic_unary_op_head; dst = src0 * src0 for this invocation's element
+sqr_body = """
+ const FLOAT_TYPE val = FLOAT_TYPE(data_a[src0_idx(gl_GlobalInvocationID.x)]);
+ data_d[p.d_offset + dst_idx(gl_GlobalInvocationID.x)] = D_TYPE(val * val);
+}
"""
+
+# CLAMP: tail of the main() opened by generic_unary_op_head; limits src0 to the range [p.param1, p.param2]
+clamp_body = """
+ const FLOAT_TYPE val = FLOAT_TYPE(data_a[src0_idx(gl_GlobalInvocationID.x)]);
+ data_d[p.d_offset + dst_idx(gl_GlobalInvocationID.x)] = D_TYPE(val < p.param1 ? p.param1 : (val > p.param2 ? p.param2 : val));
+}
+"""
+
+# CPY: tail of the main() opened by generic_unary_op_head; converts the src0 element to D_TYPE and stores it at its dst position
cpy_end = """
- data_d[p.d_offset + d_idx] = D_TYPE(data_a[a_idx]);
+ data_d[p.d_offset + dst_idx(gl_GlobalInvocationID.x)] = D_TYPE(data_a[src0_idx(gl_GlobalInvocationID.x)]);
}
"""
# Causes an optimization error otherwise
cpy_f16_f16_end = """
- data_d[p.d_offset + d_idx] = data_a[a_idx];
+ data_d[p.d_offset + dst_idx(gl_GlobalInvocationID.x)] = data_a[src0_idx(gl_GlobalInvocationID.x)];
}
"""
@@ -1815,6 +1963,24 @@ void main() {
"""
# SOFT_MAX
+soft_max_head = """
+#version 450
+
+#extension GL_EXT_shader_16bit_storage : require
+
+layout (push_constant) uniform parameter
+{
+ uint KX;
+ uint KY;
+ uint KZ;
+ float scale;
+ float max_bias;
+ float m0;
+ float m1;
+ uint n_head_log2;
+} p;
+"""
+
soft_max_body = """
#extension GL_EXT_control_flow_attributes : enable
#define BLOCK_SIZE 512
@@ -1823,7 +1989,8 @@ layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) readonly buffer Y {B_TYPE data_b[];};
-layout (binding = 2) buffer D {D_TYPE data_d[];};
+layout (binding = 2) readonly buffer Z {C_TYPE data_c[];};
+layout (binding = 3) buffer D {D_TYPE data_d[];};
shared FLOAT_TYPE vals[BLOCK_SIZE];
@@ -1832,11 +1999,23 @@ void main() {
const uint rowx = gl_WorkGroupID.x;
const uint rowy = rowx % p.KY;
+ float slope = 0.0f;
+
+ // ALiBi
+ if (p.max_bias > 0.0f) {
+ const uint h = rowx/p.KY; // head index
+
+ const float base = h < p.n_head_log2 ? p.m0 : p.m1;
+ const uint exp = h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1;
+
+ slope = pow(base, exp);
+ }
+
// Find max
vals[tid] = uintBitsToFloat(0xFF800000);
[[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
- vals[tid] = max(vals[tid], FLOAT_TYPE(data_a[rowx * p.KX + col]) * p.param1 + (p.KY > 0 ? FLOAT_TYPE(data_b[rowy * p.KX + col]) : FLOAT_TYPE(0.0f)));
+ vals[tid] = max(vals[tid], FLOAT_TYPE(data_a[rowx * p.KX + col]) * p.scale + (p.KY > 0 ? FLOAT_TYPE(data_b[rowy * p.KX + col]) : FLOAT_TYPE(0.0f)) + (p.KZ > 0 ? slope * data_c[col] : 0.0f));
}
barrier();
@@ -1855,7 +2034,7 @@ void main() {
[[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
const uint i = rowx * p.KX + col;
- const FLOAT_TYPE val = exp(FLOAT_TYPE(data_a[i]) * p.param1 + (p.KY > 0 ? FLOAT_TYPE(data_b[rowy * p.KX + col]) : FLOAT_TYPE(0.0f)) - max_val);
+ const FLOAT_TYPE val = exp(FLOAT_TYPE(data_a[i]) * p.scale + (p.KY > 0 ? FLOAT_TYPE(data_b[rowy * p.KX + col]) : FLOAT_TYPE(0.0f)) + (p.KZ > 0 ? slope * data_c[col] : 0.0f) - max_val);
vals[tid] += val;
data_d[i] = D_TYPE(val);
}
@@ -2028,6 +2207,65 @@ void main() {
}
"""
+argsort_src = """
+#version 450
+
+#extension GL_EXT_shader_16bit_storage : require
+
+layout(local_size_x = 1024, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
+layout (binding = 1) buffer D {int data_d[];};
+
+layout (push_constant) uniform parameter {
+ uint ncols;
+ bool ascending;
+} p;
+
+void swap(uint idx0, uint idx1) {
+ int tmp = data_d[idx0];
+ data_d[idx0] = data_d[idx1];
+ data_d[idx1] = tmp;
+}
+
+void main() {
+ // bitonic sort of one row per workgroup; data_d receives the sorted column indices
+ const int col = int(gl_LocalInvocationID.x);
+ const uint row = gl_WorkGroupID.y;
+
+ if (col >= p.ncols) {
+ return;
+ }
+
+ const uint a_idx = row * p.ncols;
+ const uint d_idx = row * p.ncols;
+
+ // initialize this row's slice of the index buffer (offset by d_idx like every later access)
+ if (col < p.ncols) {
+ data_d[d_idx + col] = col;
+ }
+ barrier();
+
+ for (uint k = 2; k <= p.ncols; k *= 2) {
+ for (uint j = k / 2; j > 0; j /= 2) {
+ const uint ixj = col ^ j;
+ if (ixj > col) {
+ if ((col & k) == 0) {
+ if (p.ascending ? data_a[a_idx + data_d[d_idx + col]] > data_a[a_idx + data_d[d_idx + ixj]] : data_a[a_idx + data_d[d_idx + col]] < data_a[a_idx + data_d[d_idx + ixj]]) {
+ swap(d_idx + col, d_idx + ixj);
+ }
+ } else {
+ if (p.ascending ? data_a[a_idx + data_d[d_idx + col]] < data_a[a_idx + data_d[d_idx + ixj]] : data_a[a_idx + data_d[d_idx + col]] > data_a[a_idx + data_d[d_idx + ixj]]) {
+ swap(d_idx + col, d_idx + ixj);
+ }
+ }
+ }
+ barrier();
+ }
+ }
+}
+"""
+
GLSLC = "glslc"
VK_NUM_TYPES = 16
@@ -2129,6 +2367,8 @@ async def main():
tasks = []
+ stream = []
+
for fp16 in (False, True):
# mulmat
if fp16:
@@ -2142,28 +2382,41 @@ async def main():
vec_type_f16 = "f16vec4"
vec_type = "vec4"
- stream = []
- stream.extend((mulmat_head, shader_float_type, mulmat_body))
- tasks.append(string_to_spv("matmul_f32_l", "".join(stream), {"A_TYPE": "float", "B_TYPE": "float", "D_TYPE": "float"}, fp16))
- tasks.append(string_to_spv("matmul_f32_m", "".join(stream), {"A_TYPE": "float", "B_TYPE": "float", "D_TYPE": "float"}, fp16))
- tasks.append(string_to_spv("matmul_f32_s", "".join(stream), {"A_TYPE": "float", "B_TYPE": "float", "D_TYPE": "float"}, fp16))
- tasks.append(string_to_spv("matmul_f32_aligned_l", "".join(stream), {"LOAD_VEC": load_vec, "A_TYPE": vec_type, "B_TYPE": vec_type, "D_TYPE": "float"}, fp16))
- tasks.append(string_to_spv("matmul_f32_aligned_m", "".join(stream), {"LOAD_VEC": load_vec, "A_TYPE": vec_type, "B_TYPE": vec_type, "D_TYPE": "float"}, fp16))
- tasks.append(string_to_spv("matmul_f32_aligned_s", "".join(stream), {"LOAD_VEC": load_vec, "A_TYPE": vec_type, "B_TYPE": vec_type, "D_TYPE": "float"}, fp16))
-
- tasks.append(string_to_spv("matmul_f16_l", "".join(stream), {"A_TYPE": "float16_t", "B_TYPE": "float16_t", "D_TYPE": "float"}, fp16))
- tasks.append(string_to_spv("matmul_f16_m", "".join(stream), {"A_TYPE": "float16_t", "B_TYPE": "float16_t", "D_TYPE": "float"}, fp16))
- tasks.append(string_to_spv("matmul_f16_s", "".join(stream), {"A_TYPE": "float16_t", "B_TYPE": "float16_t", "D_TYPE": "float"}, fp16))
- tasks.append(string_to_spv("matmul_f16_aligned_l", "".join(stream), {"LOAD_VEC": load_vec, "A_TYPE": vec_type_f16, "B_TYPE": vec_type_f16, "D_TYPE": "float"}, fp16))
- tasks.append(string_to_spv("matmul_f16_aligned_m", "".join(stream), {"LOAD_VEC": load_vec, "A_TYPE": vec_type_f16, "B_TYPE": vec_type_f16, "D_TYPE": "float"}, fp16))
- tasks.append(string_to_spv("matmul_f16_aligned_s", "".join(stream), {"LOAD_VEC": load_vec, "A_TYPE": vec_type_f16, "B_TYPE": vec_type_f16, "D_TYPE": "float"}, fp16))
-
- tasks.append(string_to_spv("matmul_f16_f32_l", "".join(stream), {"A_TYPE": "float16_t", "B_TYPE": "float", "D_TYPE": "float"}, fp16))
- tasks.append(string_to_spv("matmul_f16_f32_m", "".join(stream), {"A_TYPE": "float16_t", "B_TYPE": "float", "D_TYPE": "float"}, fp16))
- tasks.append(string_to_spv("matmul_f16_f32_s", "".join(stream), {"A_TYPE": "float16_t", "B_TYPE": "float", "D_TYPE": "float"}, fp16))
- tasks.append(string_to_spv("matmul_f16_f32_aligned_l", "".join(stream), {"LOAD_VEC": load_vec, "A_TYPE": vec_type_f16, "B_TYPE": vec_type, "D_TYPE": "float"}, fp16))
- tasks.append(string_to_spv("matmul_f16_f32_aligned_m", "".join(stream), {"LOAD_VEC": load_vec, "A_TYPE": vec_type_f16, "B_TYPE": vec_type, "D_TYPE": "float"}, fp16))
- tasks.append(string_to_spv("matmul_f16_f32_aligned_s", "".join(stream), {"LOAD_VEC": load_vec, "A_TYPE": vec_type_f16, "B_TYPE": vec_type, "D_TYPE": "float"}, fp16))
+ stream.clear()
+ stream.extend((mulmat_head, shader_float_type, mulmat_body1, mulmat_load_scalar, mulmat_body2))
+ tasks.append(string_to_spv("matmul_f32", "".join(stream), {"A_TYPE": "float", "B_TYPE": "float", "D_TYPE": "float"}, fp16))
+ tasks.append(string_to_spv("matmul_f32_aligned", "".join(stream), {"LOAD_VEC_A": load_vec, "LOAD_VEC_B": load_vec, "A_TYPE": vec_type, "B_TYPE": vec_type, "D_TYPE": "float"}, fp16))
+
+ tasks.append(string_to_spv("matmul_f16", "".join(stream), {"A_TYPE": "float16_t", "B_TYPE": "float16_t", "D_TYPE": "float"}, fp16))
+ tasks.append(string_to_spv("matmul_f16_aligned", "".join(stream), {"LOAD_VEC_A": load_vec, "LOAD_VEC_B": load_vec, "A_TYPE": vec_type_f16, "B_TYPE": vec_type_f16, "D_TYPE": "float"}, fp16))
+
+ tasks.append(string_to_spv("matmul_f16_f32", "".join(stream), {"A_TYPE": "float16_t", "B_TYPE": "float", "D_TYPE": "float"}, fp16))
+ tasks.append(string_to_spv("matmul_f16_f32_aligned", "".join(stream), {"LOAD_VEC_A": load_vec, "LOAD_VEC_B": load_vec, "A_TYPE": vec_type_f16, "B_TYPE": vec_type, "D_TYPE": "float"}, fp16))
+
+ stream.clear()
+ stream.extend((mulmat_head, shader_int8_ext, shader_float_type, shader_q4_0_defines, mulmat_body1, mulmat_load_q4_0, mulmat_body2))
+ tasks.append(string_to_spv("matmul_q4_0_f32", "".join(stream), {"LOAD_VEC_A": 2, "A_TYPE": "block_q4_0", "B_TYPE": "float", "D_TYPE": "float"}, fp16))
+ tasks.append(string_to_spv("matmul_q4_0_f32_aligned", "".join(stream), {"LOAD_VEC_A": 2, "LOAD_VEC_B": load_vec, "A_TYPE": "block_q4_0", "B_TYPE": vec_type, "D_TYPE": "float"}, fp16))
+
+ stream.clear()
+ stream.extend((mulmat_head, shader_int8_ext, shader_float_type, shader_q4_1_defines, mulmat_body1, mulmat_load_q4_1, mulmat_body2))
+ tasks.append(string_to_spv("matmul_q4_1_f32", "".join(stream), {"LOAD_VEC_A": 2, "A_TYPE": "block_q4_1", "B_TYPE": "float", "D_TYPE": "float"}, fp16))
+ tasks.append(string_to_spv("matmul_q4_1_f32_aligned", "".join(stream), {"LOAD_VEC_A": 2, "LOAD_VEC_B": load_vec, "A_TYPE": "block_q4_1", "B_TYPE": vec_type, "D_TYPE": "float"}, fp16))
+
+ stream.clear()
+ stream.extend((mulmat_head, shader_int8_ext, shader_float_type, shader_q5_0_defines, mulmat_body1, mulmat_load_q5_0, mulmat_body2))
+ tasks.append(string_to_spv("matmul_q5_0_f32", "".join(stream), {"LOAD_VEC_A": 2, "A_TYPE": "block_q5_0", "B_TYPE": "float", "D_TYPE": "float"}, fp16))
+ tasks.append(string_to_spv("matmul_q5_0_f32_aligned", "".join(stream), {"LOAD_VEC_A": 2, "LOAD_VEC_B": load_vec, "A_TYPE": "block_q5_0", "B_TYPE": vec_type, "D_TYPE": "float"}, fp16))
+
+ stream.clear()
+ stream.extend((mulmat_head, shader_int8_ext, shader_float_type, shader_q5_1_defines, mulmat_body1, mulmat_load_q5_1, mulmat_body2))
+ tasks.append(string_to_spv("matmul_q5_1_f32", "".join(stream), {"LOAD_VEC_A": 2, "A_TYPE": "block_q5_1", "B_TYPE": "float", "D_TYPE": "float"}, fp16))
+ tasks.append(string_to_spv("matmul_q5_1_f32_aligned", "".join(stream), {"LOAD_VEC_A": 2, "LOAD_VEC_B": load_vec, "A_TYPE": "block_q5_1", "B_TYPE": vec_type, "D_TYPE": "float"}, fp16))
+
+ stream.clear()
+ stream.extend((mulmat_head, shader_int8_ext, shader_float_type, shader_q8_0_defines, mulmat_body1, mulmat_load_q8_0, mulmat_body2))
+ tasks.append(string_to_spv("matmul_q8_0_f32", "".join(stream), {"LOAD_VEC_A": 2, "A_TYPE": "block_q8_0", "B_TYPE": "float", "D_TYPE": "float"}, fp16))
+ tasks.append(string_to_spv("matmul_q8_0_f32_aligned", "".join(stream), {"LOAD_VEC_A": 2, "LOAD_VEC_B": load_vec, "A_TYPE": "block_q8_0", "B_TYPE": vec_type, "D_TYPE": "float"}, fp16))
# Shaders where precision is needed, so no fp16 version
@@ -2205,18 +2458,18 @@ async def main():
stream.extend((dequant_head, shader_int8_ext, shader_f32))
- if i == GGML_TYPE_F16:
- stream.extend((shader_f16_defines, shader_f16_dequant_func, dequant_body))
+ if i == GGML_TYPE_F32:
+ stream.append(dequant_f32_body)
elif i == GGML_TYPE_Q4_0:
- stream.extend((shader_q4_0_defines, shader_q4_0_dequant_func, dequant_body))
+ stream.extend((shader_q4_0_defines, dequant_q4_0_body))
elif i == GGML_TYPE_Q4_1:
- stream.extend((shader_q4_1_defines, shader_q4_1_dequant_func, dequant_body))
+ stream.extend((shader_q4_1_defines, dequant_q4_1_body))
elif i == GGML_TYPE_Q5_0:
- stream.extend((shader_q5_0_defines, shader_q5_0_dequant_func, dequant_body))
+ stream.extend((shader_q5_0_defines, dequant_q5_0_body))
elif i == GGML_TYPE_Q5_1:
- stream.extend((shader_q5_1_defines, shader_q5_1_dequant_func, dequant_body))
+ stream.extend((shader_q5_1_defines, dequant_q5_1_body))
elif i == GGML_TYPE_Q8_0:
- stream.extend((shader_q8_0_defines, shader_q8_0_dequant_func, dequant_body))
+ stream.extend((shader_q8_0_defines, dequant_q8_0_body))
elif i == GGML_TYPE_Q2_K:
stream.extend((shader_q2_K_defines, dequant_q2_K_body))
elif i == GGML_TYPE_Q3_K:
@@ -2232,8 +2485,6 @@ async def main():
tasks.append(string_to_spv(f"dequant_{type_names[i]}", "".join(stream), {"D_TYPE": "float16_t"}))
- tasks.append(string_to_spv("f32_to_f16", f32_to_f16_src, {}))
-
# get_rows
for i in range(0, VK_NUM_TYPES):
stream.clear()
@@ -2264,20 +2515,20 @@ async def main():
tasks.append(string_to_spv("norm_f32", f"{generic_head}\n{shader_f32}\n{norm_body}", {"A_TYPE": "float", "D_TYPE": "float"}))
tasks.append(string_to_spv("rms_norm_f32", f"{generic_head}\n{shader_f32}\n{rms_norm_body}", {"A_TYPE": "float", "D_TYPE": "float"}))
- tasks.append(string_to_spv("cpy_f32_f32", f"{cpy_src}\n{cpy_end}", {"A_TYPE": "float", "D_TYPE": "float"}))
- tasks.append(string_to_spv("cpy_f32_f16", f"{cpy_src}\n{cpy_end}", {"A_TYPE": "float", "D_TYPE": "float16_t"}))
- tasks.append(string_to_spv("cpy_f16_f16", f"{cpy_src}\n{cpy_f16_f16_end}", {"A_TYPE": "float16_t", "D_TYPE": "float16_t"}))
+ tasks.append(string_to_spv("cpy_f32_f32", f"{generic_unary_op_head}\n{cpy_end}", {"A_TYPE": "float", "D_TYPE": "float"}))
+ tasks.append(string_to_spv("cpy_f32_f16", f"{generic_unary_op_head}\n{cpy_end}", {"A_TYPE": "float", "D_TYPE": "float16_t"}))
+ tasks.append(string_to_spv("cpy_f16_f16", f"{generic_unary_op_head}\n{cpy_f16_f16_end}", {"A_TYPE": "float16_t", "D_TYPE": "float16_t"}))
- tasks.append(string_to_spv("add_f32", f"{generic_head}\n{shader_f32}\n{add_body}", {"A_TYPE": "float", "B_TYPE": "float", "D_TYPE": "float"}))
+ tasks.append(string_to_spv("add_f32", f"{generic_binary_op_head}\n{add_body}", {"A_TYPE": "float", "B_TYPE": "float", "D_TYPE": "float", "FLOAT_TYPE": "float"}))
tasks.append(string_to_spv("split_k_reduce", mulmat_split_k_reduce_src, {}))
- tasks.append(string_to_spv("mul_f32", f"{generic_head}\n{shader_f32}\n{mul_body}", {"A_TYPE": "float", "B_TYPE": "float", "D_TYPE": "float"}))
+ tasks.append(string_to_spv("mul_f32", f"{generic_binary_op_head}\n{mul_body}", {"A_TYPE": "float", "B_TYPE": "float", "D_TYPE": "float", "FLOAT_TYPE": "float"}))
- tasks.append(string_to_spv("scale_f32", f"{generic_head}\n{shader_f32}\n{scale_body}", {"A_TYPE": "float", "D_TYPE": "float"}))
+ tasks.append(string_to_spv("scale_f32", f"{generic_unary_op_head}\n{scale_body}", {"A_TYPE": "float", "D_TYPE": "float", "FLOAT_TYPE": "float"}))
- tasks.append(string_to_spv("sqr_f32", f"{generic_head}\n{shader_f32}\n{sqr_body}", {"A_TYPE": "float", "D_TYPE": "float"}))
+ tasks.append(string_to_spv("sqr_f32", f"{generic_unary_op_head}\n{sqr_body}", {"A_TYPE": "float", "D_TYPE": "float", "FLOAT_TYPE": "float"}))
- tasks.append(string_to_spv("clamp_f32", f"{generic_head}\n{shader_f32}\n{clamp_body}", {"A_TYPE": "float", "D_TYPE": "float"}))
+ tasks.append(string_to_spv("clamp_f32", f"{generic_unary_op_head}\n{clamp_body}", {"A_TYPE": "float", "D_TYPE": "float", "FLOAT_TYPE": "float"}))
tasks.append(string_to_spv("gelu_f32", f"{generic_head}\n{shader_f32}\n{gelu_body}", {"A_TYPE": "float", "D_TYPE": "float"}))
tasks.append(string_to_spv("silu_f32", f"{generic_head}\n{shader_f32}\n{silu_body}", {"A_TYPE": "float", "D_TYPE": "float"}))
@@ -2285,7 +2536,7 @@ async def main():
tasks.append(string_to_spv("diag_mask_inf_f32", f"{diag_mask_inf_head}\n{shader_f32}\n{diag_mask_inf_body}", {"A_TYPE": "float", "D_TYPE": "float"}))
- tasks.append(string_to_spv("soft_max_f32", f"{generic_head}\n{shader_f32}\n{soft_max_body}", {"A_TYPE": "float", "B_TYPE": "float", "D_TYPE": "float"}))
+ tasks.append(string_to_spv("soft_max_f32", f"{soft_max_head}\n{shader_f32}\n{soft_max_body}", {"A_TYPE": "float", "B_TYPE": "float", "C_TYPE": "float", "D_TYPE": "float"}))
tasks.append(string_to_spv("rope_f32", rope_src, {"A_TYPE": "float", "D_TYPE": "float"}))
tasks.append(string_to_spv("rope_f16", rope_src, {"A_TYPE": "float16_t", "D_TYPE": "float16_t"}))
@@ -2293,6 +2544,8 @@ async def main():
tasks.append(string_to_spv("rope_neox_f32", rope_neox_src, {"A_TYPE": "float", "D_TYPE": "float"}))
tasks.append(string_to_spv("rope_neox_f16", rope_neox_src, {"A_TYPE": "float16_t", "D_TYPE": "float16_t"}))
+ tasks.append(string_to_spv("argsort_f32", argsort_src, {"A_TYPE": "float"}))
+
# Helper to decorate tasks with semaphore acquisition.
async def withSemaphore(sem, task):
async with sem: