diff options
author | 0cc4m <picard12@live.de> | 2024-06-03 10:59:14 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-06-03 10:59:14 +0200 |
commit | 3d7ebf63123b8652fb7bbecef7ba731202309901 (patch) | |
tree | 8adfcc3dd20946ece9c0b8d15b131823b24455ae /ggml_vk_generate_shaders.py | |
parent | a10cda58d3199cd85305e0f03a8c6056714ae2e8 (diff) |
Vulkan Mixture of Experts (MoE) support (#7628)
* Finish Vulkan mul_mat_id implementation
* Add Vulkan sum_rows and div ops
* Fix MUL_MAT_ID matrix matrix shader
* Fix MUL_MAT_ID matrix vector shader dispatch size
* Fix MUL_MAT_ID matrix vector shader and dispatch code
* Update Vulkan CPU offload for MUL_MAT_ID
* Fix crash when using split mode none and setting a main GPU
Diffstat (limited to 'ggml_vk_generate_shaders.py')
-rw-r--r-- | ggml_vk_generate_shaders.py | 531 |
1 file changed, 220 insertions, 311 deletions
diff --git a/ggml_vk_generate_shaders.py b/ggml_vk_generate_shaders.py index 7c85ca7b..a905f570 100644 --- a/ggml_vk_generate_shaders.py +++ b/ggml_vk_generate_shaders.py @@ -225,10 +225,7 @@ mulmat_head = """#version 450 #extension GL_EXT_shader_16bit_storage : require #ifdef MUL_MAT_ID -#extension GL_EXT_buffer_reference2 : require -#extension GL_EXT_nonuniform_qualifier : require -#extension GL_EXT_scalar_block_layout : require -#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require +#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require #define EXPERT_COUNT 8 #endif @@ -260,30 +257,22 @@ layout (push_constant) uniform parameter uint stride_a; uint stride_b; uint stride_d; - uint k_split; - - uint ne02; - uint ne12; - uint broadcast2; - uint broadcast3; uint batch_stride_a; uint batch_stride_b; uint batch_stride_d; #ifdef MUL_MAT_ID - uint expert_stride_a; - uint expert_stride_b0; - uint expert_stride_b1; - uint expert_stride_d; - - uint ids_stride; - - uint n_as; uint nei0; uint nei1; uint nbi1; uint ne11; +#else + uint k_split; + uint ne02; + uint ne12; + uint broadcast2; + uint broadcast3; #endif } p; @@ -301,16 +290,14 @@ shared FLOAT_TYPE buf_a[BM * (BK+1)]; shared FLOAT_TYPE buf_b[BN * (BK+1)]; #ifdef MUL_MAT_ID -shared u8vec2 rowids[2048]; +shared u16vec2 row_ids[2048]; #endif void main() { #ifdef MUL_MAT_ID - const uint batch_idx = gl_GlobalInvocationID.z / p.n_as; - const uint expert_idx = gl_GlobalInvocationID.z % p.n_as; + const uint expert_idx = gl_GlobalInvocationID.z; #else const uint batch_idx = gl_GlobalInvocationID.z; -#endif const uint i13 = batch_idx / p.ne12; const uint i12 = batch_idx % p.ne12; @@ -319,6 +306,7 @@ void main() { const uint i02 = i12 / p.broadcast2; const uint batch_idx_a = i03 * p.ne02 + i02; +#endif const uint blocks_m = (p.M + BM - 1) / BM; const uint ir = gl_WorkGroupID.x % blocks_m; @@ -350,30 +338,38 @@ void main() { for (uint ii1 = 0; ii1 < p.nei1; ii1++) { for (uint ii0 = 0; ii0 < p.nei0; ii0++) 
{ if (data_ids[ii1*p.nbi1 + ii0] == expert_idx) { - rowids[_ne1] = u8vec2(ii0, ii1); + row_ids[_ne1] = u16vec2(ii0, ii1); _ne1++; } } } - const u8vec2 id = rowids[ir * BN + ic]; + barrier(); + + // Workgroup has no work + if (ic * BN >= _ne1) return; #endif +#ifdef MUL_MAT_ID + const uint start_k = 0; + const uint end_k = p.K; +#else const uint start_k = ik * p.k_split; const uint end_k = min(p.K, (ik + 1) * p.k_split); +#endif uint pos_a = ( #ifdef MUL_MAT_ID - expert_idx * p.expert_stride_a + + expert_idx * p.batch_stride_a + +#else + batch_idx_a * p.batch_stride_a + #endif - batch_idx_a * p.batch_stride_a + ir * BM * p.stride_a + start_k) / LOAD_VEC_A; - uint pos_b = ( + ir * BM * p.stride_a + start_k) / LOAD_VEC_A; #ifdef MUL_MAT_ID - id.y * p.expert_stride_b1 + - (id.x % p.ne11) * p.expert_stride_b0 + + uint pos_b = 0; +#else + uint pos_b = (batch_idx * p.batch_stride_b + ic * BN * p.stride_b + start_k) / LOAD_VEC_B; #endif - batch_idx * p.batch_stride_b + - ic * BN * p.stride_b + start_k) / LOAD_VEC_B; float sums[WMITER * TM * WNITER * TN]; FLOAT_TYPE cache_a[WMITER * TM]; @@ -620,7 +616,12 @@ mulmat_body2 = """ } [[unroll]] for (uint l = 0; l < BN; l += loadstride_b) { #if LOAD_VEC_B == 8 +#ifdef MUL_MAT_ID + const u16vec2 row_idx = row_ids[ic * BN + loadc_b + l]; + const uint idx = pos_b + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + loadr_b; +#else const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b; +#endif const uint buf_idx = (loadc_b + l) * (BK+1) + loadr_b * LOAD_VEC_B; buf_b[buf_idx + 0] = FLOAT_TYPE(data_b[idx][0].x); buf_b[buf_idx + 1] = FLOAT_TYPE(data_b[idx][0].y); @@ -631,18 +632,31 @@ mulmat_body2 = """ buf_b[buf_idx + 6] = FLOAT_TYPE(data_b[idx][1].z); buf_b[buf_idx + 7] = FLOAT_TYPE(data_b[idx][1].w); #elif LOAD_VEC_B == 4 +#ifdef MUL_MAT_ID + const u16vec2 row_idx = row_ids[ic * BN + loadc_b + l]; + const uint idx = pos_b + row_idx.y * p.batch_stride_b / LOAD_VEC_B + 
(row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + loadr_b; +#else const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b; +#endif const uint buf_idx = (loadc_b + l) * (BK+1) + loadr_b * LOAD_VEC_B; buf_b[buf_idx + 0] = FLOAT_TYPE(data_b[idx].x); buf_b[buf_idx + 1] = FLOAT_TYPE(data_b[idx].y); buf_b[buf_idx + 2] = FLOAT_TYPE(data_b[idx].z); buf_b[buf_idx + 3] = FLOAT_TYPE(data_b[idx].w); -#else +#elif !MUL_MAT_ID if (ic * BN + loadc_b + l < p.N && block + loadr_b < end_k) { buf_b[(loadc_b + l) * (BK+1) + loadr_b] = FLOAT_TYPE(data_b[pos_b + (loadc_b + l) * p.stride_b + loadr_b]); } else { buf_b[(loadc_b + l) * (BK+1) + loadr_b] = FLOAT_TYPE(0.0f); } +#else + const uint row_i = ic * BN + loadc_b + l; + if (row_i < _ne1) { + const u16vec2 row_idx = row_ids[row_i]; + buf_b[(loadc_b + l) * (BK+1) + loadr_b] = FLOAT_TYPE(data_b[pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + loadr_b]); + } else { + buf_b[(loadc_b + l) * (BK+1) + loadr_b] = FLOAT_TYPE(0.0f); + } #endif } @@ -681,11 +695,9 @@ mulmat_body2 = """ const uint dr = ir * BM + warp_r * WM; const uint dc = ic * BN + warp_c * WN; - const uint offsets = -#ifdef MUL_MAT_ID - expert_idx * p.expert_stride_d + +#ifndef MUL_MAT_ID + const uint offsets = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z; #endif - batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z; [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) { [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) { @@ -693,10 +705,20 @@ mulmat_body2 = """ const uint dr_warp = dr + wsir * WSUBM + tiwr * TM; const uint dc_warp = dc + wsic * WSUBN + tiwc * TN; [[unroll]] for (uint cc = 0; cc < TN; cc++) { +#ifdef MUL_MAT_ID + const uint row_i = dc_warp + cc; + if (row_i >= _ne1) break; + + const u16vec2 row_idx = row_ids[row_i]; +#endif [[unroll]] for (uint cr = 0; cr < TM; cr++) { +#ifdef MUL_MAT_ID + data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + 
cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]); +#else if (dr_warp + cr < p.M && dc_warp + cc < p.N) { data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]); } +#endif } } } @@ -1172,46 +1194,29 @@ layout (push_constant) uniform parameter uint stride_b; uint stride_d; - uint ne02; - uint ne12; - uint broadcast2; - uint broadcast3; - uint batch_stride_a; uint batch_stride_b; uint batch_stride_d; #ifdef MUL_MAT_ID - uint expert_stride_a; - uint expert_stride_b0; - uint expert_stride_b1; - uint expert_stride_d0; - uint expert_stride_d1; - - uint ne11; uint nei0; - uint nbi1; - uint n_as; + uint ne11; +#else + uint ne02; + uint ne12; + uint broadcast2; + uint broadcast3; #endif } p; -""" -mul_mat_vec_body = """ -layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; - -layout (constant_id = 0) const uint BLOCK_SIZE = 32; - -shared FLOAT_TYPE tmp[BLOCK_SIZE]; - -void main() { - const uint row = gl_WorkGroupID.x; - const uint tid = gl_LocalInvocationID.x; - const uint batch_idx = gl_GlobalInvocationID.y; +void get_offsets(out uint a_offset, out uint b_offset, out uint d_offset) { #ifdef MUL_MAT_ID - const uint expert_idx1 = gl_GlobalInvocationID.z / p.nei0; - const uint expert_idx0 = gl_GlobalInvocationID.z % p.nei0; + const uint expert_idx = gl_GlobalInvocationID.y; +#else + const uint batch_idx = gl_GlobalInvocationID.y; #endif +#ifndef MUL_MAT_ID const uint i13 = batch_idx / p.ne12; const uint i12 = batch_idx % p.ne12; @@ -1219,28 +1224,44 @@ void main() { const uint i02 = i12 / p.broadcast2; const uint batch_idx_a = i03 * p.ne02 + i02; - -#ifdef MUL_MAT_ID - const uint expert_id = data_ids[expert_idx1 * p.nbi1 + expert_idx0]; +#else + const uint expert_id = data_ids[expert_idx]; #endif - const uint a_offset = + a_offset = #ifdef MUL_MAT_ID - expert_id * p.expert_stride_a + -#endif + expert_id * p.batch_stride_a; +#else batch_idx_a * p.batch_stride_a; - const 
uint b_offset = -#ifdef MUL_MAT_ID - (expert_idx0 % p.ne11) * p.expert_stride_b0 + - expert_idx1 * p.expert_stride_b1 + #endif - batch_idx * p.batch_stride_b; - const uint d_offset = + b_offset = #ifdef MUL_MAT_ID - expert_idx0 * p.expert_stride_b0 + - expert_idx1 * p.expert_stride_b1 + + (expert_idx % p.ne11) * p.stride_b; +#else + batch_idx * p.batch_stride_b; #endif + d_offset = +#ifdef MUL_MAT_ID + expert_idx * p.stride_d; +#else batch_idx * p.batch_stride_d; +#endif +} +""" + +mul_mat_vec_body = """ +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +layout (constant_id = 0) const uint BLOCK_SIZE = 32; + +shared FLOAT_TYPE tmp[BLOCK_SIZE]; + +void main() { + const uint row = gl_WorkGroupID.x; + const uint tid = gl_LocalInvocationID.x; + + uint a_offset, b_offset, d_offset; + get_offsets(a_offset, b_offset, d_offset); const uint y_offset = QUANT_R == 1 ? 1 : QUANT_K/2; @@ -1281,41 +1302,9 @@ shared FLOAT_TYPE tmp[32]; void main() { const uint row = gl_WorkGroupID.x; - const uint batch_idx = gl_GlobalInvocationID.y; -#ifdef MUL_MAT_ID - const uint expert_idx1 = gl_GlobalInvocationID.z / p.nei0; - const uint expert_idx0 = gl_GlobalInvocationID.z % p.nei0; -#endif - - const uint i13 = batch_idx / p.ne12; - const uint i12 = batch_idx % p.ne12; - - const uint i03 = i13 / p.broadcast3; - const uint i02 = i12 / p.broadcast2; - - const uint batch_idx_a = i03 * p.ne02 + i02; -#ifdef MUL_MAT_ID - const uint expert_id = data_ids[expert_idx1 * p.nbi1 + expert_idx0]; -#endif - - const uint a_offset = -#ifdef MUL_MAT_ID - expert_id * p.expert_stride_a + -#endif - batch_idx_a * p.batch_stride_a; - const uint b_offset = -#ifdef MUL_MAT_ID - (expert_idx0 % p.ne11) * p.expert_stride_b0 + - expert_idx1 * p.expert_stride_b1 + -#endif - batch_idx * p.batch_stride_b; - const uint d_offset = -#ifdef MUL_MAT_ID - expert_idx0 * p.expert_stride_b0 + - expert_idx1 * p.expert_stride_b1 + -#endif - batch_idx * p.batch_stride_d; + uint a_offset, b_offset, d_offset; + 
get_offsets(a_offset, b_offset, d_offset); const uint num_blocks_per_row = p.ncols / QUANT_K; const uint ib0 = a_offset / QUANT_K + row*num_blocks_per_row; @@ -1384,41 +1373,9 @@ shared FLOAT_TYPE tmp[32]; void main() { const uint row = gl_WorkGroupID.x; - const uint batch_idx = gl_GlobalInvocationID.y; -#ifdef MUL_MAT_ID - const uint expert_idx1 = gl_GlobalInvocationID.z / p.nei0; - const uint expert_idx0 = gl_GlobalInvocationID.z % p.nei0; -#endif - const uint i13 = batch_idx / p.ne12; - const uint i12 = batch_idx % p.ne12; - - const uint i03 = i13 / p.broadcast3; - const uint i02 = i12 / p.broadcast2; - - const uint batch_idx_a = i03 * p.ne02 + i02; - -#ifdef MUL_MAT_ID - const uint expert_id = data_ids[expert_idx1 * p.nbi1 + expert_idx0]; -#endif - - const uint a_offset = -#ifdef MUL_MAT_ID - expert_id * p.expert_stride_a + -#endif - batch_idx_a * p.batch_stride_a; - const uint b_offset = -#ifdef MUL_MAT_ID - (expert_idx0 % p.ne11) * p.expert_stride_b0 + - expert_idx1 * p.expert_stride_b1 + -#endif - batch_idx * p.batch_stride_b; - const uint d_offset = -#ifdef MUL_MAT_ID - expert_idx0 * p.expert_stride_b0 + - expert_idx1 * p.expert_stride_b1 + -#endif - batch_idx * p.batch_stride_d; + uint a_offset, b_offset, d_offset; + get_offsets(a_offset, b_offset, d_offset); const uint num_blocks_per_row = p.ncols / QUANT_K; const uint ib0 = a_offset / QUANT_K + row*num_blocks_per_row; @@ -1480,41 +1437,9 @@ shared FLOAT_TYPE tmp[32]; void main() { const uint row = gl_WorkGroupID.x; - const uint batch_idx = gl_GlobalInvocationID.y; -#ifdef MUL_MAT_ID - const uint expert_idx1 = gl_GlobalInvocationID.z / p.nei0; - const uint expert_idx0 = gl_GlobalInvocationID.z % p.nei0; -#endif - - const uint i13 = batch_idx / p.ne12; - const uint i12 = batch_idx % p.ne12; - const uint i03 = i13 / p.broadcast3; - const uint i02 = i12 / p.broadcast2; - - const uint batch_idx_a = i03 * p.ne02 + i02; - -#ifdef MUL_MAT_ID - const uint expert_id = data_ids[expert_idx1 * p.nbi1 + expert_idx0]; 
-#endif - - const uint a_offset = -#ifdef MUL_MAT_ID - expert_id * p.expert_stride_a + -#endif - batch_idx_a * p.batch_stride_a; - const uint b_offset = -#ifdef MUL_MAT_ID - (expert_idx0 % p.ne11) * p.expert_stride_b0 + - expert_idx1 * p.expert_stride_b1 + -#endif - batch_idx * p.batch_stride_b; - const uint d_offset = -#ifdef MUL_MAT_ID - expert_idx0 * p.expert_stride_b0 + - expert_idx1 * p.expert_stride_b1 + -#endif - batch_idx * p.batch_stride_d; + uint a_offset, b_offset, d_offset; + get_offsets(a_offset, b_offset, d_offset); const uint num_blocks_per_row = p.ncols / QUANT_K; const uint ib0 = a_offset / QUANT_K + row*num_blocks_per_row; @@ -1625,41 +1550,9 @@ shared FLOAT_TYPE tmp[32]; void main() { const uint row = gl_WorkGroupID.x; - const uint batch_idx = gl_GlobalInvocationID.y; -#ifdef MUL_MAT_ID - const uint expert_idx1 = gl_GlobalInvocationID.z / p.nei0; - const uint expert_idx0 = gl_GlobalInvocationID.z % p.nei0; -#endif - - const uint i13 = batch_idx / p.ne12; - const uint i12 = batch_idx % p.ne12; - - const uint i03 = i13 / p.broadcast3; - const uint i02 = i12 / p.broadcast2; - const uint batch_idx_a = i03 * p.ne02 + i02; - -#ifdef MUL_MAT_ID - const uint expert_id = data_ids[expert_idx1 * p.nbi1 + expert_idx0]; -#endif - - const uint a_offset = -#ifdef MUL_MAT_ID - expert_id * p.expert_stride_a + -#endif - batch_idx_a * p.batch_stride_a; - const uint b_offset = -#ifdef MUL_MAT_ID - (expert_idx0 % p.ne11) * p.expert_stride_b0 + - expert_idx1 * p.expert_stride_b1 + -#endif - batch_idx * p.batch_stride_b; - const uint d_offset = -#ifdef MUL_MAT_ID - expert_idx0 * p.expert_stride_b0 + - expert_idx1 * p.expert_stride_b1 + -#endif - batch_idx * p.batch_stride_d; + uint a_offset, b_offset, d_offset; + get_offsets(a_offset, b_offset, d_offset); const uint num_blocks_per_row = p.ncols / QUANT_K; const uint ib0 = a_offset / QUANT_K + row*num_blocks_per_row; @@ -1766,41 +1659,9 @@ shared FLOAT_TYPE tmp[32]; void main() { const uint row = gl_WorkGroupID.x; - 
const uint batch_idx = gl_GlobalInvocationID.y; -#ifdef MUL_MAT_ID - const uint expert_idx1 = gl_GlobalInvocationID.z / p.nei0; - const uint expert_idx0 = gl_GlobalInvocationID.z % p.nei0; -#endif - const uint i13 = batch_idx / p.ne12; - const uint i12 = batch_idx % p.ne12; - - const uint i03 = i13 / p.broadcast3; - const uint i02 = i12 / p.broadcast2; - - const uint batch_idx_a = i03 * p.ne02 + i02; - -#ifdef MUL_MAT_ID - const uint expert_id = data_ids[expert_idx1 * p.nbi1 + expert_idx0]; -#endif - - const uint a_offset = -#ifdef MUL_MAT_ID - expert_id * p.expert_stride_a + -#endif - batch_idx_a * p.batch_stride_a; - const uint b_offset = -#ifdef MUL_MAT_ID - (expert_idx0 % p.ne11) * p.expert_stride_b0 + - expert_idx1 * p.expert_stride_b1 + -#endif - batch_idx * p.batch_stride_b; - const uint d_offset = -#ifdef MUL_MAT_ID - expert_idx0 * p.expert_stride_b0 + - expert_idx1 * p.expert_stride_b1 + -#endif - batch_idx * p.batch_stride_d; + uint a_offset, b_offset, d_offset; + get_offsets(a_offset, b_offset, d_offset); const uint num_blocks_per_row = p.ncols / QUANT_K; const uint ib0 = a_offset / QUANT_K + row*num_blocks_per_row; @@ -2143,12 +2004,18 @@ void main() { generic_binary_op_combined = f"{generic_binary_op_head}\n{generic_binary_op_layout}\n{generic_binary_op_funcs}\n{generic_binary_op_main}" -# MUL F32 +# MUL mul_body = """ data_d[p.d_offset + dst_idx(gl_GlobalInvocationID.x)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(gl_GlobalInvocationID.x)]) * FLOAT_TYPE(data_b[src1_idx(gl_GlobalInvocationID.x)])); } """ +# DIV +div_body = """ + data_d[p.d_offset + dst_idx(gl_GlobalInvocationID.x)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(gl_GlobalInvocationID.x)]) / FLOAT_TYPE(data_b[src1_idx(gl_GlobalInvocationID.x)])); +} +""" + # ADD add_body = """ data_d[p.d_offset + dst_idx(gl_GlobalInvocationID.x)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(gl_GlobalInvocationID.x)]) + FLOAT_TYPE(data_b[src1_idx(gl_GlobalInvocationID.x)])); @@ -2759,6 +2626,41 @@ void main() { } """ +sum_rows_src = 
""" +#extension GL_EXT_control_flow_attributes : enable +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +layout (constant_id = 0) const uint BLOCK_SIZE = 32; + +shared FLOAT_TYPE tmp[BLOCK_SIZE]; + +void main() { + const uint row = gl_WorkGroupID.x; + const uint col = gl_LocalInvocationID.x; + + tmp[col] = FLOAT_TYPE(0.0f); + + for (uint i = col; i < p.KX; i += BLOCK_SIZE) { + tmp[col] += FLOAT_TYPE(data_a[row*p.KX + i]); + } + + barrier(); + [[unroll]] for (int s = int(BLOCK_SIZE) / 2; s > 0; s >>= 1) { + if (col < s) { + tmp[col] += tmp[col + s]; + } + barrier(); + } + + if (col == 0) { + data_d[row] = D_TYPE(tmp[0]); + } +} +""" + GLSLC = "glslc" VK_NUM_TYPES = 16 @@ -2940,66 +2842,66 @@ async def main(): tasks.append(string_to_spv("matmul_q6_k_f32_aligned", "".join(stream), {"LOAD_VEC_A": 2, "LOAD_VEC_B": load_vec, "A_TYPE": "block_q6_K", "B_TYPE": vec_type, "D_TYPE": "float"}, fp16)) # MUL_MAT_ID - # stream.clear() - # stream.extend((mulmat_head, shader_float_type, mulmat_body1, mulmat_load_scalar, mulmat_body2)) - # tasks.append(string_to_spv("matmul_id_f32", "".join(stream), {"MUL_MAT_ID": "1", "A_TYPE": "float", "B_TYPE": "float", "D_TYPE": "float"}, fp16)) - # tasks.append(string_to_spv("matmul_id_f32_aligned", "".join(stream), {"MUL_MAT_ID": "1", "LOAD_VEC_A": load_vec, "LOAD_VEC_B": load_vec, "A_TYPE": vec_type, "B_TYPE": vec_type, "D_TYPE": "float"}, fp16)) - - # tasks.append(string_to_spv("matmul_id_f16", "".join(stream), {"MUL_MAT_ID": "1", "A_TYPE": "float16_t", "B_TYPE": "float16_t", "D_TYPE": "float"}, fp16)) - # tasks.append(string_to_spv("matmul_id_f16_aligned", "".join(stream), {"MUL_MAT_ID": "1", "LOAD_VEC_A": load_vec, "LOAD_VEC_B": load_vec, "A_TYPE": vec_type_f16, "B_TYPE": vec_type_f16, "D_TYPE": "float"}, fp16)) - - # tasks.append(string_to_spv("matmul_id_f16_f32", "".join(stream), {"MUL_MAT_ID": 
"1", "A_TYPE": "float16_t", "B_TYPE": "float", "D_TYPE": "float"}, fp16)) - # tasks.append(string_to_spv("matmul_id_f16_f32_aligned", "".join(stream), {"MUL_MAT_ID": "1", "LOAD_VEC_A": load_vec, "LOAD_VEC_B": load_vec, "A_TYPE": vec_type_f16, "B_TYPE": vec_type, "D_TYPE": "float"}, fp16)) - - # stream.clear() - # stream.extend((mulmat_head, shader_int8_ext, shader_float_type, shader_q4_0_defines, mulmat_body1, mulmat_load_q4_0, mulmat_body2)) - # tasks.append(string_to_spv("matmul_id_q4_0_f32", "".join(stream), {"MUL_MAT_ID": "1", "LOAD_VEC_A": 2, "A_TYPE": "block_q4_0", "B_TYPE": "float", "D_TYPE": "float"}, fp16)) - # tasks.append(string_to_spv("matmul_id_q4_0_f32_aligned", "".join(stream), {"MUL_MAT_ID": "1", "LOAD_VEC_A": 2, "LOAD_VEC_B": load_vec, "A_TYPE": "block_q4_0", "B_TYPE": vec_type, "D_TYPE": "float"}, fp16)) - - # stream.clear() - # stream.extend((mulmat_head, shader_int8_ext, shader_float_type, shader_q4_1_defines, mulmat_body1, mulmat_load_q4_1, mulmat_body2)) - # tasks.append(string_to_spv("matmul_id_q4_1_f32", "".join(stream), {"MUL_MAT_ID": "1", "LOAD_VEC_A": 2, "A_TYPE": "block_q4_1", "B_TYPE": "float", "D_TYPE": "float"}, fp16)) - # tasks.append(string_to_spv("matmul_id_q4_1_f32_aligned", "".join(stream), {"MUL_MAT_ID": "1", "LOAD_VEC_A": 2, "LOAD_VEC_B": load_vec, "A_TYPE": "block_q4_1", "B_TYPE": vec_type, "D_TYPE": "float"}, fp16)) - - # stream.clear() - # stream.extend((mulmat_head, shader_int8_ext, shader_float_type, shader_q5_0_defines, mulmat_body1, mulmat_load_q5_0, mulmat_body2)) - # tasks.append(string_to_spv("matmul_id_q5_0_f32", "".join(stream), {"MUL_MAT_ID": "1", "LOAD_VEC_A": 2, "A_TYPE": "block_q5_0", "B_TYPE": "float", "D_TYPE": "float"}, fp16)) - # tasks.append(string_to_spv("matmul_id_q5_0_f32_aligned", "".join(stream), {"MUL_MAT_ID": "1", "LOAD_VEC_A": 2, "LOAD_VEC_B": load_vec, "A_TYPE": "block_q5_0", "B_TYPE": vec_type, "D_TYPE": "float"}, fp16)) - - # stream.clear() - # stream.extend((mulmat_head, shader_int8_ext, 
shader_float_type, shader_q5_1_defines, mulmat_body1, mulmat_load_q5_1, mulmat_body2)) - # tasks.append(string_to_spv("matmul_id_q5_1_f32", "".join(stream), {"MUL_MAT_ID": "1", "LOAD_VEC_A": 2, "A_TYPE": "block_q5_1", "B_TYPE": "float", "D_TYPE": "float"}, fp16)) - # tasks.append(string_to_spv("matmul_id_q5_1_f32_aligned", "".join(stream), {"MUL_MAT_ID": "1", "LOAD_VEC_A": 2, "LOAD_VEC_B": load_vec, "A_TYPE": "block_q5_1", "B_TYPE": vec_type, "D_TYPE": "float"}, fp16)) - - # stream.clear() - # stream.extend((mulmat_head, shader_int8_ext, shader_float_type, shader_q8_0_defines, mulmat_body1, mulmat_load_q8_0, mulmat_body2)) - # tasks.append(string_to_spv("matmul_id_q8_0_f32", "".join(stream), {"MUL_MAT_ID": "1", "LOAD_VEC_A": 2, "A_TYPE": "block_q8_0", "B_TYPE": "float", "D_TYPE": "float"}, fp16)) - # tasks.append(string_to_spv("matmul_id_q8_0_f32_aligned", "".join(stream), {"MUL_MAT_ID": "1", "LOAD_VEC_A": 2, "LOAD_VEC_B": load_vec, "A_TYPE": "block_q8_0", "B_TYPE": vec_type, "D_TYPE": "float"}, fp16)) - - # stream.clear() - # stream.extend((mulmat_head, shader_int8_ext, shader_float_type, shader_q2_K_defines, mulmat_body1, mulmat_load_q2_K, mulmat_body2)) - # tasks.append(string_to_spv("matmul_id_q2_k_f32", "".join(stream), {"MUL_MAT_ID": "1", "LOAD_VEC_A": 2, "A_TYPE": "block_q2_K", "B_TYPE": "float", "D_TYPE": "float"}, fp16)) - # tasks.append(string_to_spv("matmul_id_q2_k_f32_aligned", "".join(stream), {"MUL_MAT_ID": "1", "LOAD_VEC_A": 2, "LOAD_VEC_B": load_vec, "A_TYPE": "block_q2_K", "B_TYPE": vec_type, "D_TYPE": "float"}, fp16)) - - # stream.clear() - # stream.extend((mulmat_head, shader_int8_ext, shader_float_type, shader_q3_K_defines, mulmat_body1, mulmat_load_q3_K, mulmat_body2)) - # tasks.append(string_to_spv("matmul_id_q3_k_f32", "".join(stream), {"MUL_MAT_ID": "1", "LOAD_VEC_A": 2, "A_TYPE": "block_q3_K", "B_TYPE": "float", "D_TYPE": "float"}, fp16)) - # tasks.append(string_to_spv("matmul_id_q3_k_f32_aligned", "".join(stream), {"MUL_MAT_ID": "1", 
"LOAD_VEC_A": 2, "LOAD_VEC_B": load_vec, "A_TYPE": "block_q3_K", "B_TYPE": vec_type, "D_TYPE": "float"}, fp16)) - - # stream.clear() - # stream.extend((mulmat_head, shader_int8_ext, shader_float_type, shader_q4_K_defines, mulmat_body1, mulmat_load_q4_K, mulmat_body2)) - # tasks.append(string_to_spv("matmul_id_q4_k_f32", "".join(stream), {"MUL_MAT_ID": "1", "LOAD_VEC_A": 2, "A_TYPE": "block_q4_K", "B_TYPE": "float", "D_TYPE": "float"}, fp16)) - # tasks.append(string_to_spv("matmul_id_q4_k_f32_aligned", "".join(stream), {"MUL_MAT_ID": "1", "LOAD_VEC_A": 2, "LOAD_VEC_B": load_vec, "A_TYPE": "block_q4_K", "B_TYPE": vec_type, "D_TYPE": "float"}, fp16)) - - # stream.clear() - # stream.extend((mulmat_head, shader_int8_ext, shader_float_type, shader_q5_K_defines, mulmat_body1, mulmat_load_q5_K, mulmat_body2)) - # tasks.append(string_to_spv("matmul_id_q5_k_f32", "".join(stream), {"MUL_MAT_ID": "1", "LOAD_VEC_A": 2, "A_TYPE": "block_q5_K", "B_TYPE": "float", "D_TYPE": "float"}, fp16)) - # tasks.append(string_to_spv("matmul_id_q5_k_f32_aligned", "".join(stream), {"MUL_MAT_ID": "1", "LOAD_VEC_A": 2, "LOAD_VEC_B": load_vec, "A_TYPE": "block_q5_K", "B_TYPE": vec_type, "D_TYPE": "float"}, fp16)) - - # stream.clear() - # stream.extend((mulmat_head, shader_int8_ext, shader_float_type, shader_q6_K_defines, mulmat_body1, mulmat_load_q6_K, mulmat_body2)) - # tasks.append(string_to_spv("matmul_id_q6_k_f32", "".join(stream), {"MUL_MAT_ID": "1", "LOAD_VEC_A": 2, "A_TYPE": "block_q6_K", "B_TYPE": "float", "D_TYPE": "float"}, fp16)) - # tasks.append(string_to_spv("matmul_id_q6_k_f32_aligned", "".join(stream), {"MUL_MAT_ID": "1", "LOAD_VEC_A": 2, "LOAD_VEC_B": load_vec, "A_TYPE": "block_q6_K", "B_TYPE": vec_type, "D_TYPE": "float"}, fp16)) + stream.clear() + stream.extend((mulmat_head, shader_float_type, mulmat_body1, mulmat_load_scalar, mulmat_body2)) + tasks.append(string_to_spv("matmul_id_f32", "".join(stream), {"MUL_MAT_ID": "1", "A_TYPE": "float", "B_TYPE": "float", "D_TYPE": "float"}, 
fp16)) + tasks.append(string_to_spv("matmul_id_f32_aligned", "".join(stream), {"MUL_MAT_ID": "1", "LOAD_VEC_A": load_vec, "LOAD_VEC_B": load_vec, "A_TYPE": vec_type, "B_TYPE": vec_type, "D_TYPE": "float"}, fp16)) + + tasks.append(string_to_spv("matmul_id_f16", "".join(stream), {"MUL_MAT_ID": "1", "A_TYPE": "float16_t", "B_TYPE": "float16_t", "D_TYPE": "float"}, fp16)) + tasks.append(string_to_spv("matmul_id_f16_aligned", "".join(stream), {"MUL_MAT_ID": "1", "LOAD_VEC_A": load_vec, "LOAD_VEC_B": load_vec, "A_TYPE": vec_type_f16, "B_TYPE": vec_type_f16, "D_TYPE": "float"}, fp16)) + + tasks.append(string_to_spv("matmul_id_f16_f32", "".join(stream), {"MUL_MAT_ID": "1", "A_TYPE": "float16_t", "B_TYPE": "float", "D_TYPE": "float"}, fp16)) + tasks.append(string_to_spv("matmul_id_f16_f32_aligned", "".join(stream), {"MUL_MAT_ID": "1", "LOAD_VEC_A": load_vec, "LOAD_VEC_B": load_vec, "A_TYPE": vec_type_f16, "B_TYPE": vec_type, "D_TYPE": "float"}, fp16)) + + stream.clear() + stream.extend((mulmat_head, shader_int8_ext, shader_float_type, shader_q4_0_defines, mulmat_body1, mulmat_load_q4_0, mulmat_body2)) + tasks.append(string_to_spv("matmul_id_q4_0_f32", "".join(stream), {"MUL_MAT_ID": "1", "LOAD_VEC_A": 2, "A_TYPE": "block_q4_0", "B_TYPE": "float", "D_TYPE": "float"}, fp16)) + tasks.append(string_to_spv("matmul_id_q4_0_f32_aligned", "".join(stream), {"MUL_MAT_ID": "1", "LOAD_VEC_A": 2, "LOAD_VEC_B": load_vec, "A_TYPE": "block_q4_0", "B_TYPE": vec_type, "D_TYPE": "float"}, fp16)) + + stream.clear() + stream.extend((mulmat_head, shader_int8_ext, shader_float_type, shader_q4_1_defines, mulmat_body1, mulmat_load_q4_1, mulmat_body2)) + tasks.append(string_to_spv("matmul_id_q4_1_f32", "".join(stream), {"MUL_MAT_ID": "1", "LOAD_VEC_A": 2, "A_TYPE": "block_q4_1", "B_TYPE": "float", "D_TYPE": "float"}, fp16)) + tasks.append(string_to_spv("matmul_id_q4_1_f32_aligned", "".join(stream), {"MUL_MAT_ID": "1", "LOAD_VEC_A": 2, "LOAD_VEC_B": load_vec, "A_TYPE": "block_q4_1", "B_TYPE": 
vec_type, "D_TYPE": "float"}, fp16)) + + stream.clear() + stream.extend((mulmat_head, shader_int8_ext, shader_float_type, shader_q5_0_defines, mulmat_body1, mulmat_load_q5_0, mulmat_body2)) + tasks.append(string_to_spv("matmul_id_q5_0_f32", "".join(stream), {"MUL_MAT_ID": "1", "LOAD_VEC_A": 2, "A_TYPE": "block_q5_0", "B_TYPE": "float", "D_TYPE": "float"}, fp16)) + tasks.append(string_to_spv("matmul_id_q5_0_f32_aligned", "".join(stream), {"MUL_MAT_ID": "1", "LOAD_VEC_A": 2, "LOAD_VEC_B": load_vec, "A_TYPE": "block_q5_0", "B_TYPE": vec_type, "D_TYPE": "float"}, fp16)) + + stream.clear() + stream.extend((mulmat_head, shader_int8_ext, shader_float_type, shader_q5_1_defines, mulmat_body1, mulmat_load_q5_1, mulmat_body2)) + tasks.append(string_to_spv("matmul_id_q5_1_f32", "".join(stream), {"MUL_MAT_ID": "1", "LOAD_VEC_A": 2, "A_TYPE": "block_q5_1", "B_TYPE": "float", "D_TYPE": "float"}, fp16)) + tasks.append(string_to_spv("matmul_id_q5_1_f32_aligned", "".join(stream), {"MUL_MAT_ID": "1", "LOAD_VEC_A": 2, "LOAD_VEC_B": load_vec, "A_TYPE": "block_q5_1", "B_TYPE": vec_type, "D_TYPE": "float"}, fp16)) + + stream.clear() + stream.extend((mulmat_head, shader_int8_ext, shader_float_type, shader_q8_0_defines, mulmat_body1, mulmat_load_q8_0, mulmat_body2)) + tasks.append(string_to_spv("matmul_id_q8_0_f32", "".join(stream), {"MUL_MAT_ID": "1", "LOAD_VEC_A": 2, "A_TYPE": "block_q8_0", "B_TYPE": "float", "D_TYPE": "float"}, fp16)) + tasks.append(string_to_spv("matmul_id_q8_0_f32_aligned", "".join(stream), {"MUL_MAT_ID": "1", "LOAD_VEC_A": 2, "LOAD_VEC_B": load_vec, "A_TYPE": "block_q8_0", "B_TYPE": vec_type, "D_TYPE": "float"}, fp16)) + + stream.clear() + stream.extend((mulmat_head, shader_int8_ext, shader_float_type, shader_q2_K_defines, mulmat_body1, mulmat_load_q2_K, mulmat_body2)) + tasks.append(string_to_spv("matmul_id_q2_k_f32", "".join(stream), {"MUL_MAT_ID": "1", "LOAD_VEC_A": 2, "A_TYPE": "block_q2_K", "B_TYPE": "float", "D_TYPE": "float"}, fp16)) + 
tasks.append(string_to_spv("matmul_id_q2_k_f32_aligned", "".join(stream), {"MUL_MAT_ID": "1", "LOAD_VEC_A": 2, "LOAD_VEC_B": load_vec, "A_TYPE": "block_q2_K", "B_TYPE": vec_type, "D_TYPE": "float"}, fp16)) + + stream.clear() + stream.extend((mulmat_head, shader_int8_ext, shader_float_type, shader_q3_K_defines, mulmat_body1, mulmat_load_q3_K, mulmat_body2)) + tasks.append(string_to_spv("matmul_id_q3_k_f32", "".join(stream), {"MUL_MAT_ID": "1", "LOAD_VEC_A": 2, "A_TYPE": "block_q3_K", "B_TYPE": "float", "D_TYPE": "float"}, fp16)) + tasks.append(string_to_spv("matmul_id_q3_k_f32_aligned", "".join(stream), {"MUL_MAT_ID": "1", "LOAD_VEC_A": 2, "LOAD_VEC_B": load_vec, "A_TYPE": "block_q3_K", "B_TYPE": vec_type, "D_TYPE": "float"}, fp16)) + + stream.clear() + stream.extend((mulmat_head, shader_int8_ext, shader_float_type, shader_q4_K_defines, mulmat_body1, mulmat_load_q4_K, mulmat_body2)) + tasks.append(string_to_spv("matmul_id_q4_k_f32", "".join(stream), {"MUL_MAT_ID": "1", "LOAD_VEC_A": 2, "A_TYPE": "block_q4_K", "B_TYPE": "float", "D_TYPE": "float"}, fp16)) + tasks.append(string_to_spv("matmul_id_q4_k_f32_aligned", "".join(stream), {"MUL_MAT_ID": "1", "LOAD_VEC_A": 2, "LOAD_VEC_B": load_vec, "A_TYPE": "block_q4_K", "B_TYPE": vec_type, "D_TYPE": "float"}, fp16)) + + stream.clear() + stream.extend((mulmat_head, shader_int8_ext, shader_float_type, shader_q5_K_defines, mulmat_body1, mulmat_load_q5_K, mulmat_body2)) + tasks.append(string_to_spv("matmul_id_q5_k_f32", "".join(stream), {"MUL_MAT_ID": "1", "LOAD_VEC_A": 2, "A_TYPE": "block_q5_K", "B_TYPE": "float", "D_TYPE": "float"}, fp16)) + tasks.append(string_to_spv("matmul_id_q5_k_f32_aligned", "".join(stream), {"MUL_MAT_ID": "1", "LOAD_VEC_A": 2, "LOAD_VEC_B": load_vec, "A_TYPE": "block_q5_K", "B_TYPE": vec_type, "D_TYPE": "float"}, fp16)) + + stream.clear() + stream.extend((mulmat_head, shader_int8_ext, shader_float_type, shader_q6_K_defines, mulmat_body1, mulmat_load_q6_K, mulmat_body2)) + 
tasks.append(string_to_spv("matmul_id_q6_k_f32", "".join(stream), {"MUL_MAT_ID": "1", "LOAD_VEC_A": 2, "A_TYPE": "block_q6_K", "B_TYPE": "float", "D_TYPE": "float"}, fp16)) + tasks.append(string_to_spv("matmul_id_q6_k_f32_aligned", "".join(stream), {"MUL_MAT_ID": "1", "LOAD_VEC_A": 2, "LOAD_VEC_B": load_vec, "A_TYPE": "block_q6_K", "B_TYPE": vec_type, "D_TYPE": "float"}, fp16)) # Shaders where precision is needed, so no fp16 version @@ -3008,7 +2910,9 @@ async def main(): stream.clear() stream.extend((mul_mat_vec_head, shader_int8_ext, shader_f32)) - if i == GGML_TYPE_F16: + if i == GGML_TYPE_F32: + stream.extend((shader_f32_defines, mul_mat_vec_layout, shader_float_dequant_func, mul_mat_vec_body)) + elif i == GGML_TYPE_F16: stream.extend((shader_f16_defines, mul_mat_vec_layout, shader_float_dequant_func, mul_mat_vec_body)) elif i == GGML_TYPE_Q4_0: stream.extend((shader_q4_0_defines, mul_mat_vec_layout, shader_q4_0_dequant_func, mul_mat_vec_body)) @@ -3036,7 +2940,7 @@ async def main(): tasks.append(string_to_spv(f"mul_mat_vec_{type_names[i]}_f32_f32", "".join(stream), {"B_TYPE": "float", "D_TYPE": "float", "K_QUANTS_PER_ITERATION": K_QUANTS_PER_ITERATION})) tasks.append(string_to_spv(f"mul_mat_vec_{type_names[i]}_f16_f32", "".join(stream), {"B_TYPE": "float16_t", "D_TYPE": "float", "K_QUANTS_PER_ITERATION": K_QUANTS_PER_ITERATION})) - # tasks.append(string_to_spv(f"mul_mat_vec_id_{type_names[i]}_f32", "".join(stream), {"MUL_MAT_ID": "1", "B_TYPE": "float", "D_TYPE": "float", "K_QUANTS_PER_ITERATION": K_QUANTS_PER_ITERATION})) + tasks.append(string_to_spv(f"mul_mat_vec_id_{type_names[i]}_f32", "".join(stream), {"MUL_MAT_ID": "1", "B_TYPE": "float", "D_TYPE": "float", "K_QUANTS_PER_ITERATION": K_QUANTS_PER_ITERATION})) # Dequant shaders for i in range(0, VK_NUM_TYPES): @@ -3115,8 +3019,11 @@ async def main(): tasks.append(string_to_spv("add_f32", f"{generic_binary_op_combined}\n{add_body}", {"A_TYPE": "float", "B_TYPE": "float", "D_TYPE": "float", "FLOAT_TYPE": 
"float"})) tasks.append(string_to_spv("split_k_reduce", mulmat_split_k_reduce_src, {})) + tasks.append(string_to_spv("mul_f32", f"{generic_binary_op_combined}\n{mul_body}", {"A_TYPE": "float", "B_TYPE": "float", "D_TYPE": "float", "FLOAT_TYPE": "float"})) + tasks.append(string_to_spv("div_f32", f"{generic_binary_op_combined}\n{div_body}", {"A_TYPE": "float", "B_TYPE": "float", "D_TYPE": "float", "FLOAT_TYPE": "float"})) + tasks.append(string_to_spv("scale_f32", f"{generic_unary_op_combined}\n{scale_body}", {"A_TYPE": "float", "D_TYPE": "float", "FLOAT_TYPE": "float"})) tasks.append(string_to_spv("sqr_f32", f"{generic_unary_op_combined}\n{sqr_body}", {"A_TYPE": "float", "D_TYPE": "float", "FLOAT_TYPE": "float"})) @@ -3140,6 +3047,8 @@ async def main(): tasks.append(string_to_spv("argsort_f32", argsort_src, {"A_TYPE": "float"})) + tasks.append(string_to_spv("sum_rows_f32", f"{generic_head}\n{shader_f32}\n{sum_rows_src}", {"A_TYPE": "float", "D_TYPE": "float"})) + # Helper to decorate tasks with semaphore acquisition. async def withSemaphore(sem, task): async with sem: |