From 61d1c88e155515dd03940913a5707ea84a8b119b Mon Sep 17 00:00:00 2001 From: 0cc4m Date: Tue, 5 Mar 2024 13:33:42 +0100 Subject: Vulkan Improvements (#5835) * Improve dequant shaders, add fast q4_0 dequant * Optimize dmmv non-kquants for GCN Remove unnecessary SPIR-V shader duplication * Fix q4_0 dequant dispatch sizes Fix backend free bug * Optimize dequant shaders for q4_1, q5_0, q5_1 and q8_0 * Add unary and binary op shader templates * Fix Vulkan check results * Enable non-contiguous support for simple ops * Add argsort Basic q4_0 mmq shader and unit test * Speed up q4_0 dequant code, enable mmq for q4_0 * Rework matmul pipeline selection * Add soft_max alibi support * Add q4_1, q5_0, q5_1 and q8_0 dequant mat mat mul shaders * Add environment variable GGML_VK_FORCE_MAX_ALLOCATION_SIZE to limit max buffer size Rename GGML_VULKAN_DISABLE_F16 to GGML_VK_DISABLE_F16 for consistency --- ggml_vk_generate_shaders.py | 1171 ++++++++++++++++++++++++++----------------- 1 file changed, 712 insertions(+), 459 deletions(-) (limited to 'ggml_vk_generate_shaders.py') diff --git a/ggml_vk_generate_shaders.py b/ggml_vk_generate_shaders.py index b2e86e18..4a6f5e32 100644 --- a/ggml_vk_generate_shaders.py +++ b/ggml_vk_generate_shaders.py @@ -64,6 +64,7 @@ struct block_q5_0 #define A_TYPE block_q5_0 """ shader_q5_1_defines = """ +#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require #define QUANT_K 32 #define QUANT_R 2 @@ -187,7 +188,8 @@ v = (v - 16.0f) * d; shader_q5_1_dequant_func = """ #define DEQUANT_FUNC const float d = float(data_a[ib].d); \ const float m = float(data_a[ib].m); \ -const ivec2 qh = ivec2(((data_a[ib].qh >> iqs) << 4) & 0x10, (data_a[ib].qh >> (iqs + 12)) & 0x10); \ +const uint uint_qh = data_a[ib].qh; \ +const ivec2 qh = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10); \ const uint vui = uint(data_a[ib].qs[iqs]); \ vec2 v = vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y); \ v = v*d + m; @@ -206,12 +208,15 @@ mulmat_head = """#version 450 #extension GL_EXT_control_flow_attributes : enable #extension GL_EXT_shader_16bit_storage : require -#ifndef LOAD_VEC -#define LOAD_VEC 1 +#ifndef LOAD_VEC_A +#define LOAD_VEC_A 1 +#endif +#ifndef LOAD_VEC_B +#define LOAD_VEC_B 1 #endif """ -mulmat_body = """ +mulmat_body1 = """ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; @@ -240,7 +245,7 @@ layout (push_constant) uniform parameter layout (constant_id = 1) const uint BM = 64; layout (constant_id = 2) const uint BN = 64; -layout (constant_id = 3) const uint BK = 16; +layout (constant_id = 3) const uint BK = 16; // Assumed to be 32 if working with a quant layout (constant_id = 4) const uint WM = 32; layout (constant_id = 5) const uint WN = 32; layout (constant_id = 6) const uint WMITER = 2; @@ -277,16 +282,19 @@ void main() { const uint tiwr = tiw % (WSUBM / TM); const uint tiwc = tiw / (WSUBM / TM); - const uint loadr = gl_LocalInvocationID.x % (BK / LOAD_VEC); - const uint loadc = gl_LocalInvocationID.x / (BK / LOAD_VEC); + const uint loadr_a = gl_LocalInvocationID.x % (BK / LOAD_VEC_A); + const uint loadc_a = gl_LocalInvocationID.x / (BK / LOAD_VEC_A); + const uint loadr_b = gl_LocalInvocationID.x % (BK / LOAD_VEC_B); + const uint loadc_b = gl_LocalInvocationID.x / (BK / LOAD_VEC_B); - const uint loadstride = gl_WorkGroupSize.x * LOAD_VEC / BK; + const uint loadstride_a = gl_WorkGroupSize.x * LOAD_VEC_A / BK; + const uint loadstride_b = gl_WorkGroupSize.x * LOAD_VEC_B / BK; const uint start_k = ik * p.k_split; const uint end_k = min(p.K, (ik + 1) * p.k_split); - uint pos_a = (batch_idx_a * p.batch_stride_a + ir * BM * p.stride_a + start_k) / LOAD_VEC; - uint pos_b = (gl_GlobalInvocationID.z * p.batch_stride_b + ic * BN * p.stride_b + start_k) / LOAD_VEC; + uint pos_a = (batch_idx_a * p.batch_stride_a + ir * BM * p.stride_a + start_k) / LOAD_VEC_A; + uint pos_b = (gl_GlobalInvocationID.z * p.batch_stride_b + ic * BN * p.stride_b + start_k) / LOAD_VEC_B; float sums[WMITER * TM * WNITER * TN]; FLOAT_TYPE cache_a[WMITER * TM]; @@ -297,61 +305,145 @@ void main() { } [[unroll]] for (uint block = start_k; block < end_k; block += BK) { - [[unroll]] for (uint l = 0; l < BM; l += loadstride) { -#if LOAD_VEC == 8 - const uint idx = pos_a + (loadc + l) * p.stride_a / LOAD_VEC + loadr; - buf_a[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 0] = FLOAT_TYPE(data_a[idx][0].x); - buf_a[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 1] = FLOAT_TYPE(data_a[idx][0].y); - buf_a[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 2] = FLOAT_TYPE(data_a[idx][0].z); - buf_a[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 3] = FLOAT_TYPE(data_a[idx][0].w); - buf_a[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 4] = FLOAT_TYPE(data_a[idx][1].x); - buf_a[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 5] = FLOAT_TYPE(data_a[idx][1].y); - buf_a[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 6] = FLOAT_TYPE(data_a[idx][1].z); - buf_a[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 7] = FLOAT_TYPE(data_a[idx][1].w); -#elif LOAD_VEC == 4 - const uint idx = pos_a + (loadc + l) * p.stride_a / LOAD_VEC + loadr; - buf_a[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 0] = FLOAT_TYPE(data_a[idx].x); - buf_a[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 1] = FLOAT_TYPE(data_a[idx].y); - buf_a[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 2] = FLOAT_TYPE(data_a[idx].z); - buf_a[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 3] = FLOAT_TYPE(data_a[idx].w); + [[unroll]] for (uint l = 0; l < BM; l += loadstride_a) {""" + +mulmat_load_scalar = """ +#if LOAD_VEC_A == 8 + const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; + const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A; + buf_a[buf_idx ] = FLOAT_TYPE(data_a[idx][0].x); + buf_a[buf_idx + 1] = FLOAT_TYPE(data_a[idx][0].y); + buf_a[buf_idx + 2] = FLOAT_TYPE(data_a[idx][0].z); + buf_a[buf_idx + 3] = FLOAT_TYPE(data_a[idx][0].w); + buf_a[buf_idx + 4] = FLOAT_TYPE(data_a[idx][1].x); + buf_a[buf_idx + 5] = FLOAT_TYPE(data_a[idx][1].y); + buf_a[buf_idx + 6] = FLOAT_TYPE(data_a[idx][1].z); + buf_a[buf_idx + 7] = FLOAT_TYPE(data_a[idx][1].w); +#elif LOAD_VEC_A == 4 + const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; + const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A; + buf_a[buf_idx ] = FLOAT_TYPE(data_a[idx].x); + buf_a[buf_idx + 1] = FLOAT_TYPE(data_a[idx].y); + buf_a[buf_idx + 2] = FLOAT_TYPE(data_a[idx].z); + buf_a[buf_idx + 3] = FLOAT_TYPE(data_a[idx].w); #else - if (ir * BM + loadc + l < p.M && block + loadr < end_k) { - buf_a[(loadc + l) * (BK+1) + loadr] = FLOAT_TYPE(data_a[pos_a + (loadc + l) * p.stride_a + loadr]); + if (ir * BM + loadc_a + l < p.M && block + loadr_a < end_k) { + buf_a[(loadc_a + l) * (BK+1) + loadr_a] = FLOAT_TYPE(data_a[pos_a + (loadc_a + l) * p.stride_a + loadr_a]); } else { - buf_a[(loadc + l) * (BK+1) + loadr] = FLOAT_TYPE(0.0f); + buf_a[(loadc_a + l) * (BK+1) + loadr_a] = FLOAT_TYPE(0.0f); } #endif +""" + +mulmat_load_q4_0 = """ + const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; + const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a; + + const uint ib = idx / 16; + const uint iqs = idx & 0xF; + + const float d = float(data_a[ib].d); + const uint vui = uint(data_a[ib].qs[iqs]); + const vec2 v = (vec2(vui & 0xF, vui >> 4) - 8.0f) * d; + + buf_a[buf_idx ] = FLOAT_TYPE(v.x); + buf_a[buf_idx + 16] = FLOAT_TYPE(v.y);""" + +mulmat_load_q4_1 = """ + const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; + const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a; + + const uint ib = idx / 16; + const uint iqs = idx & 0xF; + + const float d = float(data_a[ib].d); + const float m = float(data_a[ib].m); + const uint vui = uint(data_a[ib].qs[iqs]); + const vec2 v = vec2(vui & 0xF, vui >> 4) * d + m; + + buf_a[buf_idx ] = FLOAT_TYPE(v.x); + buf_a[buf_idx + 16] = FLOAT_TYPE(v.y);""" + +mulmat_load_q5_0 = """ + const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; + const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a; + + const uint ib = idx / 16; + const uint iqs = idx & 0xF; + + const float d = float(data_a[ib].d); + const uint uint_qh = uint(data_a[ib].qh[1]) << 16 | data_a[ib].qh[0]; + const ivec2 qh = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10); + const uint vui = uint(data_a[ib].qs[iqs]); + const vec2 v = (vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y) - 16.0f) * d; + + buf_a[buf_idx ] = FLOAT_TYPE(v.x); + buf_a[buf_idx + 16] = FLOAT_TYPE(v.y);""" + +mulmat_load_q5_1 = """ + const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; + const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a; + + const uint ib = idx / 16; + const uint iqs = idx & 0xF; + + const float d = float(data_a[ib].d); + const float m = float(data_a[ib].m); + const uint uint_qh = data_a[ib].qh; + const ivec2 qh = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10); + const uint vui = uint(data_a[ib].qs[iqs]); + const vec2 v = vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y) * d + m; + + buf_a[buf_idx ] = FLOAT_TYPE(v.x); + buf_a[buf_idx + 16] = FLOAT_TYPE(v.y);""" + +mulmat_load_q8_0 = """ + const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; + const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A; + + const uint ib = idx / 16; + const uint iqs = (idx & 0xF) * 2; + + const float d = float(data_a[ib].d); + const vec2 v = vec2(int(data_a[ib].qs[iqs]), int(data_a[ib].qs[iqs + 1])) * d; + + buf_a[buf_idx ] = FLOAT_TYPE(v.x); + buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);""" + +mulmat_body2 = """ } - [[unroll]] for (uint l = 0; l < BN; l += loadstride) { -#if LOAD_VEC == 8 - const uint idx = pos_b + (loadc + l) * p.stride_b / LOAD_VEC + loadr; - buf_b[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 0] = FLOAT_TYPE(data_b[idx][0].x); - buf_b[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 1] = FLOAT_TYPE(data_b[idx][0].y); - buf_b[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 2] = FLOAT_TYPE(data_b[idx][0].z); - buf_b[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 3] = FLOAT_TYPE(data_b[idx][0].w); - buf_b[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 4] = FLOAT_TYPE(data_b[idx][1].x); - buf_b[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 5] = FLOAT_TYPE(data_b[idx][1].y); - buf_b[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 6] = FLOAT_TYPE(data_b[idx][1].z); - buf_b[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 7] = FLOAT_TYPE(data_b[idx][1].w); -#elif LOAD_VEC == 4 - const uint idx = pos_b + (loadc + l) * p.stride_b / LOAD_VEC + loadr; - buf_b[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 0] = FLOAT_TYPE(data_b[idx].x); - buf_b[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 1] = FLOAT_TYPE(data_b[idx].y); - buf_b[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 2] = FLOAT_TYPE(data_b[idx].z); - buf_b[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 3] = FLOAT_TYPE(data_b[idx].w); + [[unroll]] for (uint l = 0; l < BN; l += loadstride_b) { +#if LOAD_VEC_B == 8 + const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b; + const uint buf_idx = (loadc_b + l) * (BK+1) + loadr_b * LOAD_VEC_B; + buf_b[buf_idx + 0] = FLOAT_TYPE(data_b[idx][0].x); + buf_b[buf_idx + 1] = FLOAT_TYPE(data_b[idx][0].y); + buf_b[buf_idx + 2] = FLOAT_TYPE(data_b[idx][0].z); + buf_b[buf_idx + 3] = FLOAT_TYPE(data_b[idx][0].w); + buf_b[buf_idx + 4] = FLOAT_TYPE(data_b[idx][1].x); + buf_b[buf_idx + 5] = FLOAT_TYPE(data_b[idx][1].y); + buf_b[buf_idx + 6] = FLOAT_TYPE(data_b[idx][1].z); + buf_b[buf_idx + 7] = FLOAT_TYPE(data_b[idx][1].w); +#elif LOAD_VEC_B == 4 + const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b; + const uint buf_idx = (loadc_b + l) * (BK+1) + loadr_b * LOAD_VEC_B; + buf_b[buf_idx + 0] = FLOAT_TYPE(data_b[idx].x); + buf_b[buf_idx + 1] = FLOAT_TYPE(data_b[idx].y); + buf_b[buf_idx + 2] = FLOAT_TYPE(data_b[idx].z); + buf_b[buf_idx + 3] = FLOAT_TYPE(data_b[idx].w); #else - if (ic * BN + loadc + l < p.N && block + loadr < end_k) { - buf_b[(loadc + l) * (BK+1) + loadr] = FLOAT_TYPE(data_b[pos_b + (loadc + l) * p.stride_b + loadr]); + if (ic * BN + loadc_b + l < p.N && block + loadr_b < end_k) { + buf_b[(loadc_b + l) * (BK+1) + loadr_b] = FLOAT_TYPE(data_b[pos_b + (loadc_b + l) * p.stride_b + loadr_b]); } else { - buf_b[(loadc + l) * (BK+1) + loadr] = FLOAT_TYPE(0.0f); + buf_b[(loadc_b + l) * (BK+1) + loadr_b] = FLOAT_TYPE(0.0f); } #endif } barrier(); - pos_a += BK / LOAD_VEC; - pos_b += BK / LOAD_VEC; + pos_a += BK / LOAD_VEC_A; + pos_b += BK / LOAD_VEC_B; for (uint i = 0; i < BK; i++) { // Load from shared into cache @@ -438,45 +530,191 @@ dequant_head = """#version 450 #extension GL_EXT_control_flow_attributes : require #extension GL_EXT_shader_16bit_storage : require + +layout (push_constant) uniform parameter +{ + uint M; + uint K; + uint stride_a; + uint stride_b; + uint nel; +} p; """ -dequant_body = """ +dequant_f32_body = """ layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; -layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 0) readonly buffer A {float data_a[];}; layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; -layout (push_constant) uniform parameter -{ - int M; - int K; - int stride_a; - int stride_b; -} p; +void main() { + const uint i = gl_GlobalInvocationID.x * 16; + + if (i >= p.nel) { + return; + } + + [[unroll]] for (uint l = 0; l < 16; l++) { + data_b[i + l] = D_TYPE(data_a[i + l]); + } +} +""" + +dequant_q4_0_body = """ +layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {block_q4_0 data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; void main() { - const int i = int(gl_GlobalInvocationID.x); + const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64; + + const uint tid = gl_LocalInvocationID.x % 64; + const uint il = tid/32; + const uint ir = tid%32; + const uint ib = 32*i + ir; + if (ib >= p.nel / 32) { + return; + } + + const uint b_idx = 1024*i + 32*ir + 8*il; + + const float d = float(data_a[ib].d); + const float dm = -8.0f * d; - // Transposed - const int row = i % (p.K / QUANT_K); - const int col = i / (p.K / QUANT_K); + const uint q_idx = 8*il; - if (row * QUANT_K >= p.K || col >= p.M) { + [[unroll]] for (uint l = 0; l < 8; ++l) { + data_b[b_idx + l + 0] = D_TYPE(d * (data_a[ib].qs[q_idx + l] & 0xF) + dm); + data_b[b_idx + l + 16] = D_TYPE(d * (data_a[ib].qs[q_idx + l] >> 4) + dm); + } +} +""" + +dequant_q4_1_body = """ +layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {block_q4_1 data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; + +void main() { + const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64; + + const uint tid = gl_LocalInvocationID.x % 64; + const uint il = tid/32; + const uint ir = tid%32; + const uint ib = 32*i + ir; + if (ib >= p.nel / 32) { return; } - const int stride_a = p.stride_a / QUANT_K; + const uint b_idx = 1024*i + 32*ir + 8*il; - const int ib = col * stride_a + row; + const float d = float(data_a[ib].d); + const float m = float(data_a[ib].m); - const int y_offset = QUANT_R == 1 ? 1 : QUANT_K/2; - const int step = QUANT_R == 1 ? 2 : 1; + const uint q_idx = 8*il; - [[unroll]] for (int iqs = 0; iqs < QUANT_K/QUANT_R; iqs += step) { - DEQUANT_FUNC + [[unroll]] for (uint l = 0; l < 8; ++l) { + data_b[b_idx + l + 0] = D_TYPE(d * (data_a[ib].qs[q_idx + l] & 0xF) + m); + data_b[b_idx + l + 16] = D_TYPE(d * (data_a[ib].qs[q_idx + l] >> 4) + m); + } +} +""" + +dequant_q5_0_body = """ +layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {block_q5_0 data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; + +void main() { + const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64; + + const uint tid = gl_LocalInvocationID.x % 64; + const uint il = tid/32; + const uint ir = tid%32; + const uint ib = 32*i + ir; + if (ib >= p.nel / 32) { + return; + } + + const uint b_idx = 1024*i + 32*ir + 8*il; + + const float d = float(data_a[ib].d); + const uint qh = uint(data_a[ib].qh[1]) << 16 | data_a[ib].qh[0]; + + const uint q_idx = 8*il; + + [[unroll]] for (uint l = 0; l < 8; ++l) { + const uint iqs = q_idx + l; + const uint vui = uint(data_a[ib].qs[iqs]); + data_b[b_idx + l + 0] = D_TYPE(d * (((vui & 0xF) | (((qh >> iqs) << 4) & 0x10)) - 16.0f)); + data_b[b_idx + l + 16] = D_TYPE(d * (((vui >> 4) | ((qh >> (iqs + 12)) & 0x10)) - 16.0f)); + } +} +""" + +dequant_q5_1_body = """ +layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {block_q5_1 data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; + +void main() { + const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64; + + const uint tid = gl_LocalInvocationID.x % 64; + const uint il = tid/32; + const uint ir = tid%32; + const uint ib = 32*i + ir; + if (ib >= p.nel / 32) { + return; + } + + const uint b_idx = 1024*i + 32*ir + 8*il; + + const float d = float(data_a[ib].d); + const float m = float(data_a[ib].m); + const uint qh = data_a[ib].qh; + + const uint q_idx = 8*il; + + [[unroll]] for (uint l = 0; l < 8; ++l) { + const uint iqs = q_idx + l; + const uint vui = uint(data_a[ib].qs[iqs]); + data_b[b_idx + l + 0] = D_TYPE(d * (((vui & 0xF) | (((qh >> iqs) << 4) & 0x10))) + m); + data_b[b_idx + l + 16] = D_TYPE(d * (((vui >> 4) | ((qh >> (iqs + 12)) & 0x10))) + m); + } +} +""" + +dequant_q8_0_body = """ +layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {block_q8_0 data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; + +void main() { + const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64; + + const uint tid = gl_LocalInvocationID.x % 64; + const uint il = tid/32; + const uint ir = tid%32; + const uint ib = 32*i + ir; + if (ib >= p.nel / 32) { + return; + } - data_b[col * p.stride_b + row*QUANT_K + iqs + 0 ] = D_TYPE(v.x); - data_b[col * p.stride_b + row*QUANT_K + iqs + y_offset] = D_TYPE(v.y); + const uint b_idx = 1024*i + 32*ir + 16*il; + + const float d = float(data_a[ib].d); + + const uint q_idx = 16*il; + + [[unroll]] for (uint l = 0; l < 16; l += 2) { + data_b[b_idx + l ] = D_TYPE(d * data_a[ib].qs[q_idx + l ]); + data_b[b_idx + l + 1] = D_TYPE(d * data_a[ib].qs[q_idx + l + 1]); } } """ @@ -488,29 +726,21 @@ layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in; layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; -layout (push_constant) uniform parameter -{ - int M; - int K; - int stride_a; - int stride_b; -} p; - void main() { - [[unroll]] for (int wgy = 0; wgy < 256; wgy++) { - const int i = int(gl_WorkGroupID.x * 256 + wgy); + [[unroll]] for (uint wgy = 0; wgy < 256; wgy++) { + const uint i = gl_WorkGroupID.x * 256 + wgy; if (i >= p.M * p.K / QUANT_K) { return; } - const int tid = int(gl_LocalInvocationID.x); - const int ip = tid / 32; - const int il = tid - 32 * ip; - const int is = 8 * ip + il / 16; + const uint tid = gl_LocalInvocationID.x; + const uint ip = tid / 32; + const uint il = tid - 32 * ip; + const uint is = 8 * ip + il / 16; - const int y_idx = i * QUANT_K + 128 * ip + il; + const uint y_idx = i * QUANT_K + 128 * ip + il; - const int ql_idx = 32 * ip + il; + const uint ql_idx = 32 * ip + il; const uint8_t qs = data_a[i].qs[32 * ip + il]; FLOAT_TYPE dall = FLOAT_TYPE(data_a[i].d.x); @@ -528,31 +758,23 @@ layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in; layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; -layout (push_constant) uniform parameter -{ - int M; - int K; - int stride_a; - int stride_b; -} p; - void main() { - [[unroll]] for (int wgy = 0; wgy < 256; wgy++) { - const int i = int(gl_WorkGroupID.x * 256 + wgy); + [[unroll]] for (uint wgy = 0; wgy < 256; wgy++) { + const uint i = uint(gl_WorkGroupID.x * 256 + wgy); if (i >= p.M * p.K / QUANT_K) { return; } - const int r = int(gl_LocalInvocationID.x) / 4; - const int tid = r / 2; - const int is0 = r % 2; - const int l0 = 16 * is0 + 4 * (int(gl_LocalInvocationID.x) % 4); - const int n = tid / 4; - const int j = tid - 4*n; + const uint r = gl_LocalInvocationID.x / 4; + const uint tid = r / 2; + const uint is0 = r % 2; + const uint l0 = 16 * is0 + 4 * (gl_LocalInvocationID.x % 4); + const uint n = tid / 4; + const uint j = tid - 4*n; const uint8_t m = uint8_t(1 << (4*n + j)); - const int is = 8*n + 2*j + is0; - const int shift = 2*j; + const uint is = 8*n + 2*j + is0; + const uint shift = 2*j; const int8_t us = int8_t(is < 4 ? (data_a[i].scales[is-0] & 0xF) | (((data_a[i].scales[is+8] >> 0) & 3) << 4) : is < 8 ? (data_a[i].scales[is-0] & 0xF) | (((data_a[i].scales[is+4] >> 2) & 3) << 4) : @@ -561,10 +783,10 @@ void main() { const FLOAT_TYPE d_all = FLOAT_TYPE(data_a[i].d); const FLOAT_TYPE dl = d_all * FLOAT_TYPE(us - 32); - const int y_idx = i * QUANT_K + 128 * n + 32 * j; - const int qs_idx = 32*n; + const uint y_idx = i * QUANT_K + 128 * n + 32 * j; + const uint qs_idx = 32*n; - for (int l = l0; l < l0 + 4; ++l) { + for (uint l = l0; l < l0 + 4; ++l) { data_b[y_idx + l] = D_TYPE(dl * FLOAT_TYPE(int8_t((data_a[i].qs[qs_idx + l] >> shift) & 3) - (((data_a[i].hmask[l] & m) != 0) ? 0 : 4))); } } @@ -576,32 +798,24 @@ layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in; layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; -layout (push_constant) uniform parameter -{ - int M; - int K; - int stride_a; - int stride_b; -} p; - void main() { - [[unroll]] for (int wgy = 0; wgy < 256; wgy++) { - const int i = int(gl_WorkGroupID.x * 256 + wgy); + [[unroll]] for (uint wgy = 0; wgy < 256; wgy++) { + const uint i = gl_WorkGroupID.x * 256 + wgy; if (i >= p.M * p.K / QUANT_K) { return; } - const int tid = int(gl_LocalInvocationID.x); - const int il = tid / 8; - const int ir = tid % 8; - const int is = 2 * il; - const int n = 4; + const uint tid = gl_LocalInvocationID.x; + const uint il = tid / 8; + const uint ir = tid % 8; + const uint is = 2 * il; + const uint n = 4; const FLOAT_TYPE dall = FLOAT_TYPE(data_a[i].d.x); const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[i].d.y); - const int y_idx = i * QUANT_K + 64 * il + n * ir; - const int qs_idx = 32*il + n * ir; + const uint y_idx = i * QUANT_K + 64 * il + n * ir; + const uint qs_idx = 32*il + n * ir; uint8_t sc; uint8_t m; @@ -625,7 +839,7 @@ void main() { const FLOAT_TYPE d2 = dall * sc; const FLOAT_TYPE m2 = dmin * m; - [[unroll]] for (int l = 0; l < n; ++l) { + [[unroll]] for (uint l = 0; l < n; ++l) { data_b[y_idx + l ] = D_TYPE(d1 * FLOAT_TYPE(data_a[i].qs[qs_idx + l] & 0xF) - m1); data_b[y_idx + l + 32] = D_TYPE(d2 * FLOAT_TYPE(data_a[i].qs[qs_idx + l] >> 4) - m2); } @@ -638,32 +852,24 @@ layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in; layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; -layout (push_constant) uniform parameter -{ - int M; - int K; - int stride_a; - int stride_b; -} p; - void main() { - [[unroll]] for (int wgy = 0; wgy < 256; wgy++) { - const int i = int(gl_WorkGroupID.x * 256 + wgy); + [[unroll]] for (uint wgy = 0; wgy < 256; wgy++) { + const uint i = gl_WorkGroupID.x * 256 + wgy; if (i >= p.M * p.K / QUANT_K) { return; } - const int tid = int(gl_LocalInvocationID.x); - const int il = tid / 16; - const int ir = tid % 16; - const int is = 2 * il; + const uint tid = gl_LocalInvocationID.x; + const uint il = tid / 16; + const uint ir = tid % 16; + const uint is = 2 * il; const FLOAT_TYPE dall = FLOAT_TYPE(data_a[i].d.x); const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[i].d.y); - const int y_idx = i * QUANT_K + 64 * il + 2 * ir; - const int qs_idx = 32*il + 2 * ir; - const int qh_idx = 2 * ir; + const uint y_idx = i * QUANT_K + 64 * il + 2 * ir; + const uint qs_idx = 32*il + 2 * ir; + const uint qh_idx = 2 * ir; uint8_t sc; uint8_t m; @@ -702,28 +908,20 @@ layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in; layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; -layout (push_constant) uniform parameter -{ - int M; - int K; - int stride_a; - int stride_b; -} p; - void main() { - [[unroll]] for (int wgy = 0; wgy < 256; wgy++) { - const int i = int(gl_WorkGroupID.x * 256 + wgy); + [[unroll]] for (uint wgy = 0; wgy < 256; wgy++) { + const uint i = gl_WorkGroupID.x * 256 + wgy; if (i >= p.M * p.K / QUANT_K) { return; } - const int tid = int(gl_LocalInvocationID.x); - const int ip = tid / 32; - const int il = tid - 32 * ip; - const int is = 8 * ip + il / 16; + const uint tid = gl_LocalInvocationID.x; + const uint ip = tid / 32; + const uint il = tid - 32 * ip; + const uint is = 8 * ip + il / 16; - const int y_idx = i * QUANT_K + 128 * ip + il; + const uint y_idx = i * QUANT_K + 128 * ip + il; - const int ql_idx = 64 * ip + il; + const uint ql_idx = 64 * ip + il; const uint8_t qh = data_a[i].qh[32 * ip + il]; const FLOAT_TYPE d = FLOAT_TYPE(data_a[i].d); @@ -742,49 +940,50 @@ mul_mat_vec_head = """#version 450 #extension GL_EXT_control_flow_attributes : enable #extension GL_EXT_shader_16bit_storage : require #extension GL_EXT_shader_8bit_storage : require + +layout (push_constant) uniform parameter +{ + uint ncols; + uint b_offset; + uint d_offset; +} p; """ mul_mat_vec_body = """ -layout(local_size_x = QUANT_K, local_size_y = 1, local_size_z = 1) in; +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; layout (binding = 2) writeonly buffer D {D_TYPE dst[];}; -layout (push_constant) uniform parameter -{ - int ncols; - int b_offset; - int d_offset; -} p; +layout (constant_id = 0) const uint BLOCK_SIZE = 32; -shared FLOAT_TYPE tmp[QUANT_K]; +shared FLOAT_TYPE tmp[BLOCK_SIZE]; void main() { - const int block_size = int(gl_WorkGroupSize.x); - const int row = int(gl_WorkGroupID.x); - const int tid = int(gl_LocalInvocationID.x); + const uint row = gl_WorkGroupID.x; + const uint tid = gl_LocalInvocationID.x; - const int y_offset = QUANT_R == 1 ? 1 : QUANT_K/2; + const uint y_offset = QUANT_R == 1 ? 1 : QUANT_K/2; tmp[tid] = FLOAT_TYPE(0.0f); - [[unroll]] for (int i = 0; i < p.ncols/block_size; i += 2) { - const int col = i*block_size + 2*tid; - const int ib = (row*p.ncols + col)/QUANT_K; // block index - const int iqs = (col%QUANT_K)/QUANT_R; // quant index - const int iybs = col - col%QUANT_K; // y block start index + [[unroll]] for (uint i = 0; i < p.ncols/BLOCK_SIZE; i += 2) { + const uint col = i*BLOCK_SIZE + 2*tid; + const uint ib = (row*p.ncols + col)/QUANT_K; // block index + const uint iqs = (col%QUANT_K)/QUANT_R; // quant index + const uint iybs = col - col%QUANT_K; // y block start index DEQUANT_FUNC // matrix multiplication - tmp[tid] += FLOAT_TYPE(v.x) * FLOAT_TYPE(data_b[p.b_offset + iybs + iqs + 0]); - tmp[tid] += FLOAT_TYPE(v.y) * FLOAT_TYPE(data_b[p.b_offset + iybs + iqs + y_offset]); + tmp[tid] += FLOAT_TYPE(v.x) * FLOAT_TYPE(data_b[p.b_offset + iybs + iqs + 0]) + + FLOAT_TYPE(v.y) * FLOAT_TYPE(data_b[p.b_offset + iybs + iqs + y_offset]); } // sum up partial sums and write back result barrier(); - [[unroll]] for (int s = block_size/2; s > 0; s >>= 1) { + [[unroll]] for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) { if (tid < s) { tmp[tid] += tmp[tid + s]; } @@ -804,38 +1003,31 @@ layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; layout (binding = 2) writeonly buffer D {D_TYPE dst[];}; -layout (push_constant) uniform parameter -{ - int ncols; - int b_offset; - int d_offset; -} p; - shared FLOAT_TYPE tmp[32]; void main() { - const int row = int(gl_WorkGroupID.x); + const uint row = gl_WorkGroupID.x; - const int num_blocks_per_row = p.ncols / QUANT_K; - const int ib0 = row*num_blocks_per_row; + const uint num_blocks_per_row = p.ncols / QUANT_K; + const uint ib0 = row*num_blocks_per_row; - const int tid = int(gl_LocalInvocationID.x)/K_QUANTS_PER_ITERATION; // 0...31 or 0...16 - const int ix = int(gl_LocalInvocationID.x)%K_QUANTS_PER_ITERATION; // 0 or 0, 1 + const uint tid = gl_LocalInvocationID.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16 + const uint ix = gl_LocalInvocationID.x%K_QUANTS_PER_ITERATION; // 0 or 0, 1 - const int step = 16/K_QUANTS_PER_ITERATION; // 16 or 8 + const uint step = 16/K_QUANTS_PER_ITERATION; // 16 or 8 - const int v_im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128... - const int v_in = tid - step*v_im; // 0...15 or 0...7 + const uint v_im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128... + const uint v_in = tid - step*v_im; // 0...15 or 0...7 - const int l0 = K_QUANTS_PER_ITERATION*v_in; // 0...15 - const int q_offset = 32*v_im + l0; - const int s_offset = 8*v_im; - const int y_offset = 128*v_im + l0; + const uint l0 = K_QUANTS_PER_ITERATION*v_in; // 0...15 + const uint q_offset = 32*v_im + l0; + const uint s_offset = 8*v_im; + const uint y_offset = 128*v_im + l0; tmp[16 * ix + tid] = FLOAT_TYPE(0.0); // partial sum for thread in warp - [[unroll]] for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) { - const int y_idx = i * QUANT_K + y_offset; + [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) { + const uint y_idx = i * QUANT_K + y_offset; const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib0 + i].d.x); const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib0 + i].d.y); @@ -865,7 +1057,7 @@ void main() { // sum up partial sums and write back result barrier(); - [[unroll]] for (int s = 16; s > 0; s >>= 1) { + [[unroll]] for (uint s = 16; s > 0; s >>= 1) { if (tid < s) { tmp[tid] += tmp[tid + s]; } @@ -883,41 +1075,34 @@ layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; layout (binding = 2) writeonly buffer D {D_TYPE dst[];}; -layout (push_constant) uniform parameter -{ - int ncols; - int b_offset; - int d_offset; -} p; - shared FLOAT_TYPE tmp[32]; void main() { - const int row = int(gl_WorkGroupID.x); + const uint row = gl_WorkGroupID.x; - const int num_blocks_per_row = p.ncols / QUANT_K; - const int ib0 = row*num_blocks_per_row; + const uint num_blocks_per_row = p.ncols / QUANT_K; + const uint ib0 = row*num_blocks_per_row; - const int tid = int(gl_LocalInvocationID.x)/K_QUANTS_PER_ITERATION; // 0...31 or 0...16 - const int ix = int(gl_LocalInvocationID.x)%K_QUANTS_PER_ITERATION; // 0 or 0, 1 + const uint tid = gl_LocalInvocationID.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16 + const uint ix = gl_LocalInvocationID.x%K_QUANTS_PER_ITERATION; // 0 or 0, 1 - const int step = 16/K_QUANTS_PER_ITERATION; // 16 or 8 + const uint step = 16/K_QUANTS_PER_ITERATION; // 16 or 8 - const int v_im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128... - const int v_in = tid - step*v_im; // 0...15 or 0...7 + const uint v_im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128... + const uint v_in = tid - step*v_im; // 0...15 or 0...7 const uint8_t m = uint8_t(1 << (4 * v_im)); - const int l0 = K_QUANTS_PER_ITERATION*v_in; // 0...15 - const int q_offset = 32*v_im + l0; - const int y_offset = 128*v_im + l0; + const uint l0 = K_QUANTS_PER_ITERATION*v_in; // 0...15 + const uint q_offset = 32*v_im + l0; + const uint y_offset = 128*v_im + l0; tmp[16 * ix + tid] = FLOAT_TYPE(0.0); // partial sum for thread in warp const uint s_shift = 4 * v_im; - [[unroll]] for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) { - const int y_idx = i * QUANT_K + y_offset; + [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) { + const uint y_idx = i * QUANT_K + y_offset; const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d); @@ -937,7 +1122,7 @@ void main() { // sum up partial sums and write back result barrier(); - [[unroll]] for (int s = 16; s > 0; s >>= 1) { + [[unroll]] for (uint s = 16; s > 0; s >>= 1) { if (tid < s) { tmp[tid] += tmp[tid + s]; } @@ -955,42 +1140,35 @@ layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; layout (binding = 2) writeonly buffer D {D_TYPE dst[];}; -layout (push_constant) uniform parameter -{ - int ncols; - int b_offset; - int d_offset; -} p; - shared FLOAT_TYPE tmp[32]; void main() { - const int row = int(gl_WorkGroupID.x); + const uint row = gl_WorkGroupID.x; - const int num_blocks_per_row = p.ncols / QUANT_K; - const int ib0 = row*num_blocks_per_row; + const uint num_blocks_per_row = p.ncols / QUANT_K; + const uint ib0 = row*num_blocks_per_row; - const int tid = int(gl_LocalInvocationID.x)/K_QUANTS_PER_ITERATION; // 0...31 or 0...16 - const int ix = int(gl_LocalInvocationID.x)%K_QUANTS_PER_ITERATION; // 0 or 0, 1 + const uint tid = gl_LocalInvocationID.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16 + const uint ix = gl_LocalInvocationID.x%K_QUANTS_PER_ITERATION; // 0 or 0, 1 - const int step = 8/K_QUANTS_PER_ITERATION; // 8 or 4 + const uint step = 8/K_QUANTS_PER_ITERATION; // 8 or 4 - const int il = tid/step; // 0...3 - const int ir = tid - step*il; // 0...7 or 0...3 - const int n = 2 * K_QUANTS_PER_ITERATION; // 2 or 4 + const uint il = tid/step; // 0...3 + const uint ir = tid - step*il; // 0...7 or 0...3 + const uint n = 2 * K_QUANTS_PER_ITERATION; // 2 or 4 - const int v_im = il / 2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224 - const int v_in = il % 2; + const uint v_im = il / 2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224 + const uint v_in = il % 2; - const int l0 = n * (2 * ir + v_in); // 0...15 - const int q_offset = 32*v_im + l0; - const int y_offset = 64*v_im + l0; + const uint l0 = n * (2 * ir + v_in); // 0...15 + const uint q_offset = 32*v_im + l0; + const uint y_offset = 64*v_im + l0; tmp[16 * ix + tid] = FLOAT_TYPE(0.0); // partial sum for thread in warp - [[unroll]] for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) { - const int y1_idx = i * QUANT_K + y_offset; - const int y2_idx = y1_idx + 128; + [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) { + const uint y1_idx = i * QUANT_K + y_offset; + const uint y2_idx = y1_idx + 128; const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib0 + i].d.x); const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib0 + i].d.y); @@ -1058,7 +1236,7 @@ void main() { // sum up partial sums and write back result barrier(); - [[unroll]] for (int s = 16; s > 0; s >>= 1) { + [[unroll]] for (uint s = 16; s > 0; s >>= 1) { if (tid < s) { tmp[tid] += tmp[tid + s]; } @@ -1076,42 +1254,35 @@ layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; layout (binding = 2) writeonly buffer D {D_TYPE dst[];}; -layout (push_constant) uniform parameter -{ - int ncols; - int b_offset; - int d_offset; -} p; - shared FLOAT_TYPE tmp[32]; void main() { - const int row = int(gl_WorkGroupID.x); + const uint row = gl_WorkGroupID.x; - const int num_blocks_per_row = p.ncols / QUANT_K; - const int ib0 = row*num_blocks_per_row; + const uint num_blocks_per_row = p.ncols / QUANT_K; + const uint ib0 = row*num_blocks_per_row; - const int tid = int(gl_LocalInvocationID.x)/2; // 0...31 or 0...16 - const int ix = int(gl_LocalInvocationID.x)%2; // 0 or 0, 1 + const uint tid = gl_LocalInvocationID.x/2; // 0...31 or 0...16 + const uint ix = gl_LocalInvocationID.x%2; // 0 or 0, 1 - const int il = tid/4; // 0...3 - const int ir = tid - 4*il; // 0...7 or 0...3 + const uint il = tid/4; // 0...3 + const uint ir = tid - 4*il; // 0...7 or 0...3 - const int v_im = il / 2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224 - const int v_in = il % 2; + const uint v_im = il / 2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224 + const uint v_in = il % 2; - const int l0 = 4*ir + 2*v_in; // 0...15 - const int q_offset = 32*v_im + l0; - const int y_offset = 64*v_im + l0; + const uint l0 = 4*ir + 2*v_in; // 0...15 + const uint q_offset = 32*v_im + l0; + const uint y_offset = 64*v_im + l0; const uint8_t hm1 = uint8_t(1 << (2*v_im)); const uint8_t hm2 = uint8_t(hm1 << 4); tmp[16 * ix + tid] = FLOAT_TYPE(0.0); // partial sum for thread in warp - [[unroll]] for (int i = ix; i < num_blocks_per_row; i += 2) { - const int y1_idx = i * QUANT_K + y_offset; - const int y2_idx = y1_idx + 128; + [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += 2) { + const uint y1_idx = i * QUANT_K + y_offset; + const uint y2_idx = y1_idx + 128; const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib0 + i].d.x); const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib0 + i].d.y); @@ -1175,7 +1346,7 @@ void main() { // sum up partial sums and write back result barrier(); - [[unroll]] for (int s = 16; s > 0; s >>= 1) { + [[unroll]] for (uint s = 16; s > 0; s >>= 1) { if (tid < s) { tmp[tid] += tmp[tid + s]; } @@ -1193,46 +1364,40 @@ layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; layout (binding = 2) writeonly buffer D {D_TYPE dst[];}; -layout (push_constant) uniform parameter -{ - int ncols; - int b_offset; - int d_offset; -} p; - shared FLOAT_TYPE tmp[32]; void main() { - const int row = int(gl_WorkGroupID.x); + const uint block_size = gl_WorkGroupSize.x; + const uint row = gl_WorkGroupID.x; - const int num_blocks_per_row = p.ncols / QUANT_K; - const int ib0 = row*num_blocks_per_row; + const uint num_blocks_per_row = p.ncols / QUANT_K; + const uint ib0 = row*num_blocks_per_row; - const int tid = int(gl_LocalInvocationID.x)/K_QUANTS_PER_ITERATION; // 0...31 or 0...16 - const int ix = int(gl_LocalInvocationID.x)%K_QUANTS_PER_ITERATION; // 0 or 0, 1 + const uint tid = gl_LocalInvocationID.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16 + const uint ix = gl_LocalInvocationID.x%K_QUANTS_PER_ITERATION; // 0 or 0, 1 - const int step = 16/K_QUANTS_PER_ITERATION; // 16 or 8 + const uint step = 16/K_QUANTS_PER_ITERATION; // 16 or 8 - const int v_im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128... - const int v_in = tid - step*v_im; // 0...15 or 0...7 + const uint v_im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128... + const uint v_in = tid - step*v_im; // 0...15 or 0...7 #if K_QUANTS_PER_ITERATION == 1 - const int l0 = K_QUANTS_PER_ITERATION*v_in; // 0...15 - const int is = 0; + const uint l0 = v_in; // 0...15 + const uint is = 0; #else - const int l0 = 4 * v_in; // 0, 4, 8, ..., 28 - const int is = v_in / 4; + const uint l0 = 4 * v_in; // 0, 4, 8, ..., 28 + const uint is = v_in / 4; #endif - const int ql_offset = 64*v_im + l0; - const int qh_offset = 32*v_im + l0; - const int s_offset = 8*v_im + is; - const int y_offset = 128*v_im + l0; + const uint ql_offset = 64*v_im + l0; + const uint qh_offset = 32*v_im + l0; + const uint s_offset = 8*v_im + is; + const uint y_offset = 128*v_im + l0; tmp[16 * ix + tid] = FLOAT_TYPE(0.0); // partial sum for thread in warp - [[unroll]] for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) { - const int y_idx = i * QUANT_K + y_offset; + [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) { + const uint y_idx = i * QUANT_K + y_offset; const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d); @@ -1260,7 +1425,7 @@ void main() { // sum up partial sums and write back result barrier(); - [[unroll]] for (int s = 16; s > 0; s >>= 1) { + [[unroll]] for (uint s = 16; s > 0; s >>= 1) { if (tid < s) { tmp[tid] += tmp[tid + s]; } @@ -1421,34 +1586,6 @@ void main() { } """ -# F16 to F32 -f32_to_f16_src = """#version 450 - -#extension GL_EXT_shader_16bit_storage : require - -layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in; - -layout (binding = 0) readonly buffer A {float data_a[];}; -layout (binding = 1) writeonly buffer D {float16_t data_b[];}; - -layout (push_constant) uniform parameter -{ - int M; - int K; - int stride_a; - int stride_b; -} p; - -void main() { - const int row = int(gl_GlobalInvocationID.x % p.K); - const int col = int(gl_GlobalInvocationID.x / p.K); - - if (row < p.K && col < p.M) { - data_b[col * p.stride_b + row] = float16_t(data_a[col * p.stride_a + row]); - } -} -""" - generic_head = """ #version 450 @@ -1463,136 +1600,147 @@ layout (push_constant) uniform parameter } p; """ -# MUL F32 -mul_body = """layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; - -layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; -layout (binding = 1) readonly buffer Y {B_TYPE data_b[];}; -layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; - -void main() { - const uint idx = gl_GlobalInvocationID.x; +generic_unary_op_head = """#version 450 - if (idx >= p.KX) { - return; - } +#extension GL_EXT_shader_16bit_storage : require - data_d[idx] = D_TYPE(FLOAT_TYPE(data_a[idx]) * FLOAT_TYPE(data_b[idx % p.KY])); -} -""" +layout (push_constant) uniform parameter +{ + uint ne; + uint ne00; uint ne01; uint ne02; uint ne03; uint nb00; uint nb01; uint nb02; uint nb03; + uint ne10; uint ne11; uint ne12; uint ne13; uint nb10; uint nb11; uint nb12; uint nb13; + uint d_offset; + float param1; float param2; +} p; -# ADD -add_body = """ layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; -layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; -layout (binding = 1) readonly buffer Y {B_TYPE data_b[];}; -layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; - -void main() { - const uint idx = gl_GlobalInvocationID.x; - - if (idx >= p.KX) { - return; - } - - data_d[idx] = D_TYPE(FLOAT_TYPE(data_a[idx]) + FLOAT_TYPE(data_b[idx % p.KY])); -} -""" - -# SCALE -scale_body = """layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; - -layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; -void main() { - const uint idx = gl_GlobalInvocationID.x; - - if (idx >= p.KX) { - return; - } - - data_d[idx] = D_TYPE(FLOAT_TYPE(data_a[idx]) * FLOAT_TYPE(p.param1)); +uint src0_idx(uint idx) { + const uint i03 = idx / (p.ne02*p.ne01*p.ne00); + const uint i03_offset = i03 * p.ne02*p.ne01*p.ne00; + const uint i02 = (idx - i03_offset) / (p.ne01*p.ne00); + const uint i02_offset = i02*p.ne01*p.ne00; + const uint i01 = (idx - i03_offset - i02_offset) / p.ne00; + const uint i00 = idx - i03_offset - i02_offset - i01*p.ne00; + return i03*p.nb03 + i02*p.nb02 + i01*p.nb01 + i00*p.nb00; } -""" - -# SQR -sqr_body = """layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; - -layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; -layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; - -void main() { - const uint idx = gl_GlobalInvocationID.x; - if (idx >= p.KX) { - return; - } - - const FLOAT_TYPE val = FLOAT_TYPE(data_a[idx]); - data_d[idx] = D_TYPE(val * val); +uint dst_idx(uint idx) { + const uint i13 = idx / (p.ne12*p.ne11*p.ne10); + const uint i13_offset = i13 * p.ne12*p.ne11*p.ne10; + const uint i12 = (idx - i13_offset) / (p.ne11*p.ne10); + const uint i12_offset = i12*p.ne11*p.ne10; + const uint i11 = (idx - i13_offset - i12_offset) / p.ne10; + const uint i10 = idx - i13_offset - i12_offset - i11*p.ne10; + return i13*p.nb13 + i12*p.nb12 + i11*p.nb11 + i10*p.nb10; } -""" - -# CLAMP -clamp_body = """layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; - -layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; -layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; void main() { - const uint idx = gl_GlobalInvocationID.x; - - if (idx >= p.KX) { + if (gl_GlobalInvocationID.x >= p.ne) { return; } - - const FLOAT_TYPE val = FLOAT_TYPE(data_a[idx]); - data_d[idx] = D_TYPE(val < p.param1 ? p.param1 : (val > p.param2 ? p.param2 : val)); -} """ -# CPY -cpy_src = """#version 450 +generic_binary_op_head = """#version 450 #extension GL_EXT_shader_16bit_storage : require layout (push_constant) uniform parameter { uint ne; - uint ne00; uint ne01; uint nb00; uint nb01; uint nb02; - uint ne10; uint ne11; uint nb10; uint nb11; uint nb12; + uint ne00; uint ne01; uint ne02; uint ne03; uint nb00; uint nb01; uint nb02; uint nb03; + uint ne10; uint ne11; uint ne12; uint ne13; uint nb10; uint nb11; uint nb12; uint nb13; + uint ne20; uint ne21; uint ne22; uint ne23; uint nb20; uint nb21; uint nb22; uint nb23; uint d_offset; + uint param1; uint param2; } p; layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; -layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; +layout (binding = 1) readonly buffer B {A_TYPE data_b[];}; +layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; + +uint src0_idx(uint idx) { + const uint i03 = idx / (p.ne02*p.ne01*p.ne00); + const uint i03_offset = i03 * p.ne02*p.ne01*p.ne00; + const uint i02 = (idx - i03_offset) / (p.ne01*p.ne00); + const uint i02_offset = i02*p.ne01*p.ne00; + const uint i01 = (idx - i03_offset - i02_offset) / p.ne00; + const uint i00 = idx - i03_offset - i02_offset - i01*p.ne00; + return i03*p.nb03 + i02*p.nb02 + i01*p.nb01 + i00*p.nb00; +} + +uint src1_idx(uint idx) { + const uint i03 = idx / (p.ne02*p.ne01*p.ne00); + const uint i03_offset = i03 * p.ne02*p.ne01*p.ne00; + const uint i02 = (idx - i03_offset) / (p.ne01*p.ne00); + const uint i02_offset = i02*p.ne01*p.ne00; + const uint i01 = (idx - i03_offset - i02_offset) / p.ne00; + const uint i00 = idx - i03_offset - i02_offset - i01*p.ne00; + + return (i03 % p.ne13)*p.nb13 + (i02 % p.ne12)*p.nb12 + (i01 % p.ne11)*p.nb11 + (i00 % p.ne10)*p.nb10; +} + +uint dst_idx(uint idx) { + const uint i23 = idx / (p.ne22*p.ne21*p.ne20); + const uint i23_offset = i23 * p.ne22*p.ne21*p.ne20; + const uint i22 = (idx - i23_offset) / (p.ne21*p.ne20); + const uint i22_offset = i22*p.ne21*p.ne20; + const uint i21 = (idx - i23_offset - i22_offset) / p.ne20; + const uint i20 = idx - i23_offset - i22_offset - i21*p.ne20; + return i23*p.nb23 + i22*p.nb22 + i21*p.nb21 + i20*p.nb20; +} void main() { if (gl_GlobalInvocationID.x >= p.ne) { return; } +""" + +# MUL F32 +mul_body = """ + data_d[p.d_offset + dst_idx(gl_GlobalInvocationID.x)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(gl_GlobalInvocationID.x)]) * FLOAT_TYPE(data_b[src1_idx(gl_GlobalInvocationID.x)])); +} +""" + +# ADD +add_body = """ + data_d[p.d_offset + dst_idx(gl_GlobalInvocationID.x)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(gl_GlobalInvocationID.x)]) + FLOAT_TYPE(data_b[src1_idx(gl_GlobalInvocationID.x)])); +} +""" - const uint i02 = gl_GlobalInvocationID.x / (p.ne00*p.ne01); - const uint i01 = (gl_GlobalInvocationID.x - i02*p.ne01*p.ne00) / p.ne00; - const uint i00 = gl_GlobalInvocationID.x - i02*p.ne01*p.ne00 - i01*p.ne00; - const uint a_idx = i00*p.nb00 + i01*p.nb01 + i02*p.nb02; +# SCALE +scale_body = """ + data_d[p.d_offset + dst_idx(gl_GlobalInvocationID.x)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(gl_GlobalInvocationID.x)]) + FLOAT_TYPE(p.param1)); +} +""" - const uint i12 = gl_GlobalInvocationID.x / (p.ne10*p.ne11); - const uint i11 = (gl_GlobalInvocationID.x - i12*p.ne11*p.ne10) / p.ne10; - const uint i10 = gl_GlobalInvocationID.x - i12*p.ne11*p.ne10 - i11*p.ne10; - const uint d_idx = i10*p.nb10 + i11*p.nb11 + i12*p.nb12; +# SQR +sqr_body = """ + const FLOAT_TYPE val = FLOAT_TYPE(data_a[src0_idx(gl_GlobalInvocationID.x)]); + data_d[p.d_offset + dst_idx(gl_GlobalInvocationID.x)] = D_TYPE(val * val); +} """ + +# CLAMP +clamp_body = """ + const FLOAT_TYPE val = FLOAT_TYPE(data_a[src0_idx(gl_GlobalInvocationID.x)]); + data_d[p.d_offset + dst_idx(gl_GlobalInvocationID.x)] = D_TYPE(val < p.param1 ? p.param1 : (val > p.param2 ? p.param2 : val)); +} +""" + +# CPY cpy_end = """ - data_d[p.d_offset + d_idx] = D_TYPE(data_a[a_idx]); + data_d[p.d_offset + dst_idx(gl_GlobalInvocationID.x)] = D_TYPE(data_a[src0_idx(gl_GlobalInvocationID.x)]); } """ # Causes an optimization error otherwise cpy_f16_f16_end = """ - data_d[p.d_offset + d_idx] = data_a[a_idx]; + data_d[p.d_offset + dst_idx(gl_GlobalInvocationID.x)] = data_a[src0_idx(gl_GlobalInvocationID.x)]; } """ @@ -1815,6 +1963,24 @@ void main() { """ # SOFT_MAX +soft_max_head = """ +#version 450 + +#extension GL_EXT_shader_16bit_storage : require + +layout (push_constant) uniform parameter +{ + uint KX; + uint KY; + uint KZ; + float scale; + float max_bias; + float m0; + float m1; + uint n_head_log2; +} p; +""" + soft_max_body = """ #extension GL_EXT_control_flow_attributes : enable #define BLOCK_SIZE 512 @@ -1823,7 +1989,8 @@ layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; layout (binding = 1) readonly buffer Y {B_TYPE data_b[];}; -layout (binding = 2) buffer D {D_TYPE data_d[];}; +layout (binding = 2) readonly buffer Z {C_TYPE data_c[];}; +layout (binding = 3) buffer D {D_TYPE data_d[];}; shared FLOAT_TYPE vals[BLOCK_SIZE]; @@ -1832,11 +1999,23 @@ void main() { const uint rowx = gl_WorkGroupID.x; const uint rowy = rowx % p.KY; + float slope = 0.0f; + + // ALiBi + if (p.max_bias > 0.0f) { + const uint h = rowx/p.KY; // head index + + const float base = h < p.n_head_log2 ? p.m0 : p.m1; + const uint exp = h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1; + + slope = pow(base, exp); + } + // Find max vals[tid] = uintBitsToFloat(0xFF800000); [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) { - vals[tid] = max(vals[tid], FLOAT_TYPE(data_a[rowx * p.KX + col]) * p.param1 + (p.KY > 0 ? FLOAT_TYPE(data_b[rowy * p.KX + col]) : FLOAT_TYPE(0.0f))); + vals[tid] = max(vals[tid], FLOAT_TYPE(data_a[rowx * p.KX + col]) * p.scale + (p.KY > 0 ? FLOAT_TYPE(data_b[rowy * p.KX + col]) : FLOAT_TYPE(0.0f)) + (p.KZ > 0 ? slope * data_c[col] : 0.0f)); } barrier(); @@ -1855,7 +2034,7 @@ void main() { [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) { const uint i = rowx * p.KX + col; - const FLOAT_TYPE val = exp(FLOAT_TYPE(data_a[i]) * p.param1 + (p.KY > 0 ? FLOAT_TYPE(data_b[rowy * p.KX + col]) : FLOAT_TYPE(0.0f)) - max_val); + const FLOAT_TYPE val = exp(FLOAT_TYPE(data_a[i]) * p.scale + (p.KY > 0 ? FLOAT_TYPE(data_b[rowy * p.KX + col]) : FLOAT_TYPE(0.0f)) - max_val); vals[tid] += val; data_d[i] = D_TYPE(val); } @@ -2028,6 +2207,65 @@ void main() { } """ +argsort_src = """ +#version 450 + +#extension GL_EXT_shader_16bit_storage : require + +layout(local_size_x = 1024, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 1) buffer D {int data_d[];}; + +layout (push_constant) uniform parameter { + uint ncols; + bool ascending; +} p; + +void swap(uint idx0, uint idx1) { + int tmp = data_d[idx0]; + data_d[idx0] = data_d[idx1]; + data_d[idx1] = tmp; +} + +void main() { + // bitonic sort + const int col = int(gl_LocalInvocationID.x); + const uint row = gl_WorkGroupID.y; + + if (col >= p.ncols) { + return; + } + + const uint a_idx = row * p.ncols; + const uint d_idx = row * p.ncols; + + // initialize indices + if (col < p.ncols) { + data_d[col] = col; + } + barrier(); + + for (uint k = 2; k <= p.ncols; k *= 2) { + for (uint j = k / 2; j > 0; j /= 2) { + const uint ixj = col ^ j; + if (ixj > col) { + if ((col & k) == 0) { + if (p.ascending ? data_a[a_idx + data_d[d_idx + col]] > data_a[a_idx + data_d[d_idx + ixj]] : data_a[a_idx + data_d[d_idx + col]] < data_a[a_idx + data_d[d_idx + ixj]]) { + swap(d_idx + col, d_idx + ixj); + } + } else { + if (p.ascending ? data_a[a_idx + data_d[d_idx + col]] < data_a[a_idx + data_d[d_idx + ixj]] : data_a[a_idx + data_d[d_idx + col]] > data_a[a_idx + data_d[d_idx + ixj]]) { + swap(d_idx + col, d_idx + ixj); + } + } + } + barrier(); + } + } +} +""" + GLSLC = "glslc" VK_NUM_TYPES = 16 @@ -2129,6 +2367,8 @@ async def main(): tasks = [] + stream = [] + for fp16 in (False, True): # mulmat if fp16: @@ -2142,28 +2382,41 @@ async def main(): vec_type_f16 = "f16vec4" vec_type = "vec4" - stream = [] - stream.extend((mulmat_head, shader_float_type, mulmat_body)) - tasks.append(string_to_spv("matmul_f32_l", "".join(stream), {"A_TYPE": "float", "B_TYPE": "float", "D_TYPE": "float"}, fp16)) - tasks.append(string_to_spv("matmul_f32_m", "".join(stream), {"A_TYPE": "float", "B_TYPE": "float", "D_TYPE": "float"}, fp16)) - tasks.append(string_to_spv("matmul_f32_s", "".join(stream), {"A_TYPE": "float", "B_TYPE": "float", "D_TYPE": "float"}, fp16)) - tasks.append(string_to_spv("matmul_f32_aligned_l", "".join(stream), {"LOAD_VEC": load_vec, "A_TYPE": vec_type, "B_TYPE": vec_type, "D_TYPE": "float"}, fp16)) - tasks.append(string_to_spv("matmul_f32_aligned_m", "".join(stream), {"LOAD_VEC": load_vec, "A_TYPE": vec_type, "B_TYPE": vec_type, "D_TYPE": "float"}, fp16)) - tasks.append(string_to_spv("matmul_f32_aligned_s", "".join(stream), {"LOAD_VEC": load_vec, "A_TYPE": vec_type, "B_TYPE": vec_type, "D_TYPE": "float"}, fp16)) - - tasks.append(string_to_spv("matmul_f16_l", "".join(stream), {"A_TYPE": "float16_t", "B_TYPE": "float16_t", "D_TYPE": "float"}, fp16)) - tasks.append(string_to_spv("matmul_f16_m", "".join(stream), {"A_TYPE": "float16_t", "B_TYPE": "float16_t", "D_TYPE": "float"}, fp16)) - tasks.append(string_to_spv("matmul_f16_s", "".join(stream), {"A_TYPE": "float16_t", "B_TYPE": "float16_t", "D_TYPE": "float"}, fp16)) - tasks.append(string_to_spv("matmul_f16_aligned_l", "".join(stream), {"LOAD_VEC": load_vec, "A_TYPE": vec_type_f16, "B_TYPE": vec_type_f16, "D_TYPE": "float"}, fp16)) - tasks.append(string_to_spv("matmul_f16_aligned_m", "".join(stream), {"LOAD_VEC": load_vec, "A_TYPE": vec_type_f16, "B_TYPE": vec_type_f16, "D_TYPE": "float"}, fp16)) - tasks.append(string_to_spv("matmul_f16_aligned_s", "".join(stream), {"LOAD_VEC": load_vec, "A_TYPE": vec_type_f16, "B_TYPE": vec_type_f16, "D_TYPE": "float"}, fp16)) - - tasks.append(string_to_spv("matmul_f16_f32_l", "".join(stream), {"A_TYPE": "float16_t", "B_TYPE": "float", "D_TYPE": "float"}, fp16)) - tasks.append(string_to_spv("matmul_f16_f32_m", "".join(stream), {"A_TYPE": "float16_t", "B_TYPE": "float", "D_TYPE": "float"}, fp16)) - tasks.append(string_to_spv("matmul_f16_f32_s", "".join(stream), {"A_TYPE": "float16_t", "B_TYPE": "float", "D_TYPE": "float"}, fp16)) - tasks.append(string_to_spv("matmul_f16_f32_aligned_l", "".join(stream), {"LOAD_VEC": load_vec, "A_TYPE": vec_type_f16, "B_TYPE": vec_type, "D_TYPE": "float"}, fp16)) - tasks.append(string_to_spv("matmul_f16_f32_aligned_m", "".join(stream), {"LOAD_VEC": load_vec, "A_TYPE": vec_type_f16, "B_TYPE": vec_type, "D_TYPE": "float"}, fp16)) - tasks.append(string_to_spv("matmul_f16_f32_aligned_s", "".join(stream), {"LOAD_VEC": load_vec, "A_TYPE": vec_type_f16, "B_TYPE": vec_type, "D_TYPE": "float"}, fp16)) + stream.clear() + stream.extend((mulmat_head, shader_float_type, mulmat_body1, mulmat_load_scalar, mulmat_body2)) + tasks.append(string_to_spv("matmul_f32", "".join(stream), {"A_TYPE": "float", "B_TYPE": "float", "D_TYPE": "float"}, fp16)) + tasks.append(string_to_spv("matmul_f32_aligned", "".join(stream), {"LOAD_VEC_A": load_vec, "LOAD_VEC_B": load_vec, "A_TYPE": vec_type, "B_TYPE": vec_type, "D_TYPE": "float"}, fp16)) + + tasks.append(string_to_spv("matmul_f16", "".join(stream), {"A_TYPE": "float16_t", "B_TYPE": "float16_t", "D_TYPE": "float"}, fp16)) + tasks.append(string_to_spv("matmul_f16_aligned", "".join(stream), {"LOAD_VEC_A": load_vec, "LOAD_VEC_B": load_vec, "A_TYPE": vec_type_f16, "B_TYPE": vec_type_f16, "D_TYPE": "float"}, fp16)) + + tasks.append(string_to_spv("matmul_f16_f32", "".join(stream), {"A_TYPE": "float16_t", "B_TYPE": "float", "D_TYPE": "float"}, fp16)) + tasks.append(string_to_spv("matmul_f16_f32_aligned", "".join(stream), {"LOAD_VEC_A": load_vec, "LOAD_VEC_B": load_vec, "A_TYPE": vec_type_f16, "B_TYPE": vec_type, "D_TYPE": "float"}, fp16)) + + stream.clear() + stream.extend((mulmat_head, shader_int8_ext, shader_float_type, shader_q4_0_defines, mulmat_body1, mulmat_load_q4_0, mulmat_body2)) + tasks.append(string_to_spv("matmul_q4_0_f32", "".join(stream), {"LOAD_VEC_A": 2, "A_TYPE": "block_q4_0", "B_TYPE": "float", "D_TYPE": "float"}, fp16)) + tasks.append(string_to_spv("matmul_q4_0_f32_aligned", "".join(stream), {"LOAD_VEC_A": 2, "LOAD_VEC_B": load_vec, "A_TYPE": "block_q4_0", "B_TYPE": vec_type, "D_TYPE": "float"}, fp16)) + + stream.clear() + stream.extend((mulmat_head, shader_int8_ext, shader_float_type, shader_q4_1_defines, mulmat_body1, mulmat_load_q4_1, mulmat_body2)) + tasks.append(string_to_spv("matmul_q4_1_f32", "".join(stream), {"LOAD_VEC_A": 2, "A_TYPE": "block_q4_1", "B_TYPE": "float", "D_TYPE": "float"}, fp16)) + tasks.append(string_to_spv("matmul_q4_1_f32_aligned", "".join(stream), {"LOAD_VEC_A": 2, "LOAD_VEC_B": load_vec, "A_TYPE": "block_q4_1", "B_TYPE": vec_type, "D_TYPE": "float"}, fp16)) + + stream.clear() + stream.extend((mulmat_head, shader_int8_ext, shader_float_type, shader_q5_0_defines, mulmat_body1, mulmat_load_q5_0, mulmat_body2)) + tasks.append(string_to_spv("matmul_q5_0_f32", "".join(stream), {"LOAD_VEC_A": 2, "A_TYPE": "block_q5_0", "B_TYPE": "float", "D_TYPE": "float"}, fp16)) + tasks.append(string_to_spv("matmul_q5_0_f32_aligned", "".join(stream), {"LOAD_VEC_A": 2, "LOAD_VEC_B": load_vec, "A_TYPE": "block_q5_0", "B_TYPE": vec_type, "D_TYPE": "float"}, fp16)) + + stream.clear() + stream.extend((mulmat_head, shader_int8_ext, shader_float_type, shader_q5_1_defines, mulmat_body1, mulmat_load_q5_1, mulmat_body2)) + tasks.append(string_to_spv("matmul_q5_1_f32", "".join(stream), {"LOAD_VEC_A": 2, "A_TYPE": "block_q5_1", "B_TYPE": "float", "D_TYPE": "float"}, fp16)) + tasks.append(string_to_spv("matmul_q5_1_f32_aligned", "".join(stream), {"LOAD_VEC_A": 2, "LOAD_VEC_B": load_vec, "A_TYPE": "block_q5_1", "B_TYPE": vec_type, "D_TYPE": "float"}, fp16)) + + stream.clear() + stream.extend((mulmat_head, shader_int8_ext, shader_float_type, shader_q8_0_defines, mulmat_body1, mulmat_load_q8_0, mulmat_body2)) + tasks.append(string_to_spv("matmul_q8_0_f32", "".join(stream), {"LOAD_VEC_A": 2, "A_TYPE": "block_q8_0", "B_TYPE": "float", "D_TYPE": "float"}, fp16)) + tasks.append(string_to_spv("matmul_q8_0_f32_aligned", "".join(stream), {"LOAD_VEC_A": 2, "LOAD_VEC_B": load_vec, "A_TYPE": "block_q8_0", "B_TYPE": vec_type, "D_TYPE": "float"}, fp16)) # Shaders where precision is needed, so no fp16 version @@ -2205,18 +2458,18 @@ async def main(): stream.extend((dequant_head, shader_int8_ext, shader_f32)) - if i == GGML_TYPE_F16: - stream.extend((shader_f16_defines, shader_f16_dequant_func, dequant_body)) + if i == GGML_TYPE_F32: + stream.append(dequant_f32_body) elif i == GGML_TYPE_Q4_0: - stream.extend((shader_q4_0_defines, shader_q4_0_dequant_func, dequant_body)) + stream.extend((shader_q4_0_defines, dequant_q4_0_body)) elif i == GGML_TYPE_Q4_1: - stream.extend((shader_q4_1_defines, shader_q4_1_dequant_func, dequant_body)) + stream.extend((shader_q4_1_defines, dequant_q4_1_body)) elif i == GGML_TYPE_Q5_0: - stream.extend((shader_q5_0_defines, shader_q5_0_dequant_func, dequant_body)) + stream.extend((shader_q5_0_defines, dequant_q5_0_body)) elif i == GGML_TYPE_Q5_1: - stream.extend((shader_q5_1_defines, shader_q5_1_dequant_func, dequant_body)) + stream.extend((shader_q5_1_defines, dequant_q5_1_body)) elif i == GGML_TYPE_Q8_0: - stream.extend((shader_q8_0_defines, shader_q8_0_dequant_func, dequant_body)) + stream.extend((shader_q8_0_defines, dequant_q8_0_body)) elif i == GGML_TYPE_Q2_K: stream.extend((shader_q2_K_defines, dequant_q2_K_body)) elif i == GGML_TYPE_Q3_K: @@ -2232,8 +2485,6 @@ async def main(): tasks.append(string_to_spv(f"dequant_{type_names[i]}", "".join(stream), {"D_TYPE": "float16_t"})) - tasks.append(string_to_spv("f32_to_f16", f32_to_f16_src, {})) - # get_rows for i in range(0, VK_NUM_TYPES): stream.clear() @@ -2264,20 +2515,20 @@ async def main(): tasks.append(string_to_spv("norm_f32", f"{generic_head}\n{shader_f32}\n{norm_body}", {"A_TYPE": "float", "D_TYPE": "float"})) tasks.append(string_to_spv("rms_norm_f32", f"{generic_head}\n{shader_f32}\n{rms_norm_body}", {"A_TYPE": "float", "D_TYPE": "float"})) - tasks.append(string_to_spv("cpy_f32_f32", f"{cpy_src}\n{cpy_end}", {"A_TYPE": "float", "D_TYPE": "float"})) - tasks.append(string_to_spv("cpy_f32_f16", f"{cpy_src}\n{cpy_end}", {"A_TYPE": "float", "D_TYPE": "float16_t"})) - tasks.append(string_to_spv("cpy_f16_f16", f"{cpy_src}\n{cpy_f16_f16_end}", {"A_TYPE": "float16_t", "D_TYPE": "float16_t"})) + tasks.append(string_to_spv("cpy_f32_f32", f"{generic_unary_op_head}\n{cpy_end}", {"A_TYPE": "float", "D_TYPE": "float"})) + tasks.append(string_to_spv("cpy_f32_f16", f"{generic_unary_op_head}\n{cpy_end}", {"A_TYPE": "float", "D_TYPE": "float16_t"})) + tasks.append(string_to_spv("cpy_f16_f16", f"{generic_unary_op_head}\n{cpy_f16_f16_end}", {"A_TYPE": "float16_t", "D_TYPE": "float16_t"})) - tasks.append(string_to_spv("add_f32", f"{generic_head}\n{shader_f32}\n{add_body}", {"A_TYPE": "float", "B_TYPE": "float", "D_TYPE": "float"})) + tasks.append(string_to_spv("add_f32", f"{generic_binary_op_head}\n{add_body}", {"A_TYPE": "float", "B_TYPE": "float", "D_TYPE": "float", "FLOAT_TYPE": "float"})) tasks.append(string_to_spv("split_k_reduce", mulmat_split_k_reduce_src, {})) - tasks.append(string_to_spv("mul_f32", f"{generic_head}\n{shader_f32}\n{mul_body}", {"A_TYPE": "float", "B_TYPE": "float", "D_TYPE": "float"})) + tasks.append(string_to_spv("mul_f32", f"{generic_binary_op_head}\n{mul_body}", {"A_TYPE": "float", "B_TYPE": "float", "D_TYPE": "float", "FLOAT_TYPE": "float"})) - tasks.append(string_to_spv("scale_f32", f"{generic_head}\n{shader_f32}\n{scale_body}", {"A_TYPE": "float", "D_TYPE": "float"})) + tasks.append(string_to_spv("scale_f32", f"{generic_unary_op_head}\n{scale_body}", {"A_TYPE": "float", "D_TYPE": "float", "FLOAT_TYPE": "float"})) - tasks.append(string_to_spv("sqr_f32", f"{generic_head}\n{shader_f32}\n{sqr_body}", {"A_TYPE": "float", "D_TYPE": "float"})) + tasks.append(string_to_spv("sqr_f32", f"{generic_unary_op_head}\n{sqr_body}", {"A_TYPE": "float", "D_TYPE": "float", "FLOAT_TYPE": "float"})) - tasks.append(string_to_spv("clamp_f32", f"{generic_head}\n{shader_f32}\n{clamp_body}", {"A_TYPE": "float", "D_TYPE": "float"})) + tasks.append(string_to_spv("clamp_f32", f"{generic_unary_op_head}\n{clamp_body}", {"A_TYPE": "float", "D_TYPE": "float", "FLOAT_TYPE": "float"})) tasks.append(string_to_spv("gelu_f32", f"{generic_head}\n{shader_f32}\n{gelu_body}", {"A_TYPE": "float", "D_TYPE": "float"})) tasks.append(string_to_spv("silu_f32", f"{generic_head}\n{shader_f32}\n{silu_body}", {"A_TYPE": "float", "D_TYPE": "float"})) @@ -2285,7 +2536,7 @@ async def main(): tasks.append(string_to_spv("diag_mask_inf_f32", f"{diag_mask_inf_head}\n{shader_f32}\n{diag_mask_inf_body}", {"A_TYPE": "float", "D_TYPE": "float"})) - tasks.append(string_to_spv("soft_max_f32", f"{generic_head}\n{shader_f32}\n{soft_max_body}", {"A_TYPE": "float", "B_TYPE": "float", "D_TYPE": "float"})) + tasks.append(string_to_spv("soft_max_f32", f"{soft_max_head}\n{shader_f32}\n{soft_max_body}", {"A_TYPE": "float", "B_TYPE": "float", "C_TYPE": "float", "D_TYPE": "float"})) tasks.append(string_to_spv("rope_f32", rope_src, {"A_TYPE": "float", "D_TYPE": "float"})) tasks.append(string_to_spv("rope_f16", rope_src, {"A_TYPE": "float16_t", "D_TYPE": "float16_t"})) @@ -2293,6 +2544,8 @@ async def main(): tasks.append(string_to_spv("rope_neox_f32", rope_neox_src, {"A_TYPE": "float", "D_TYPE": "float"})) tasks.append(string_to_spv("rope_neox_f16", rope_neox_src, {"A_TYPE": "float16_t", "D_TYPE": "float16_t"})) + tasks.append(string_to_spv("argsort_f32", argsort_src, {"A_TYPE": "float"})) + # Helper to decorate tasks with semaphore acquisition. async def withSemaphore(sem, task): async with sem: -- cgit v1.2.3