diff options
author | 0cc4m <picard12@live.de> | 2024-05-09 20:39:54 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-05-09 20:39:54 +0200 |
commit | befddd0f15de6efb15d7e7f5b527dfb671f4196f (patch) | |
tree | 7c0aa80c4b4f8fef76aa284982502a3bf8dae1d5 /ggml_vk_generate_shaders.py | |
parent | d46dbc76f8770caec0175f1e57777173c70556a0 (diff) |
Vulkan Bugfixes and Improvements (#7084)
* Modify mat mat mul shader for mul_mat_id, modify mat vec mul shaders for single call batch operation
* Further work towards MoE, disabled for now
* Disable MoE code (not ready yet), fix a number of bugs in shaders and Vulkan code
* Add softmax with f16 mask and pos buffer support
* Disable mul_mat_id shaders for now
* Fix flake8
* Fix validation errors caused by empty buffers on larger batch sizes
Diffstat (limited to 'ggml_vk_generate_shaders.py')
-rw-r--r-- | ggml_vk_generate_shaders.py | 716 |
1 files changed, 541 insertions, 175 deletions
diff --git a/ggml_vk_generate_shaders.py b/ggml_vk_generate_shaders.py index 1d9a0cc8..162cf5c6 100644 --- a/ggml_vk_generate_shaders.py +++ b/ggml_vk_generate_shaders.py @@ -167,47 +167,54 @@ struct block_q6_K # Dequant functions shader_float_dequant_func = """ -#define DEQUANT_FUNC vec2 v = vec2(ib, ib); // data_a[ib], data_a[ib + 1]); +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + return vec2(data_a[a_offset + ib], data_a[a_offset + ib + 1]); +} """ shader_q4_0_dequant_func = """ -#define DEQUANT_FUNC const float d = float(data_a[ib].d); \ -const uint vui = uint(data_a[ib].qs[iqs]); \ -vec2 v = vec2(vui & 0xF, vui >> 4); \ -v = (v - 8.0f)*d; +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + const float d = float(data_a[a_offset + ib].d); + const uint vui = uint(data_a[a_offset + ib].qs[iqs]); + return (vec2(vui & 0xF, vui >> 4) - 8.0f) * d; +} """ shader_q4_1_dequant_func = """ -#define DEQUANT_FUNC const float d = float(data_a[ib].d); \ -const float m = float(data_a[ib].m); \ -const uint vui = uint(data_a[ib].qs[iqs]); \ -vec2 v = vec2(vui & 0xF, vui >> 4); \ -v = v*d + m; +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + const float d = float(data_a[a_offset + ib].d); + const float m = float(data_a[a_offset + ib].m); + const uint vui = uint(data_a[a_offset + ib].qs[iqs]); + return vec2(vui & 0xF, vui >> 4) * d + m; +} """ shader_q5_0_dequant_func = """ -#define DEQUANT_FUNC const float d = float(data_a[ib].d); \ -const uint uint_qh = uint(data_a[ib].qh[1]) << 16 | data_a[ib].qh[0]; \ -const ivec2 qh = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10); \ -const uint vui = uint(data_a[ib].qs[iqs]); \ -vec2 v = vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y); \ -v = (v - 16.0f) * d; +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + const float d = float(data_a[a_offset + ib].d); + const uint uint_qh = uint(data_a[a_offset + ib].qh[1]) << 16 | data_a[a_offset + ib].qh[0]; + const ivec2 qh = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10); + const uint vui = uint(data_a[a_offset + ib].qs[iqs]); + return (vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y) - 16.0f) * d; +} """ shader_q5_1_dequant_func = """ -#define DEQUANT_FUNC const float d = float(data_a[ib].d); \ -const float m = float(data_a[ib].m); \ -const uint uint_qh = data_a[ib].qh; \ -const ivec2 qh = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10); \ -const uint vui = uint(data_a[ib].qs[iqs]); \ -vec2 v = vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y); \ -v = v*d + m; +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + const float d = float(data_a[a_offset + ib].d); + const float m = float(data_a[a_offset + ib].m); + const uint uint_qh = data_a[a_offset + ib].qh; + const ivec2 qh = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10); + const uint vui = uint(data_a[a_offset + ib].qs[iqs]); + return vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y) * d + m; +} """ shader_q8_0_dequant_func = """ -#define DEQUANT_FUNC const float d = float(data_a[ib].d); \ -vec2 v = vec2(int(data_a[ib].qs[iqs]), int(data_a[ib].qs[iqs + 1])); \ -v = v * d; +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + const float d = float(data_a[a_offset + ib].d); + return vec2(int(data_a[a_offset + ib].qs[iqs]), int(data_a[a_offset + ib].qs[iqs + 1])) * d; +} """ # MULMAT @@ -217,6 +224,15 @@ mulmat_head = """#version 450 #extension GL_EXT_control_flow_attributes : enable #extension GL_EXT_shader_16bit_storage : require +#ifdef MUL_MAT_ID +#extension GL_EXT_buffer_reference2 : require +#extension GL_EXT_nonuniform_qualifier : require +#extension GL_EXT_scalar_block_layout : require +#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require + +#define EXPERT_COUNT 8 +#endif + #ifndef LOAD_VEC_A #define LOAD_VEC_A 1 #endif @@ -232,6 +248,10 @@ layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; +#ifdef MUL_MAT_ID +layout (binding = 3) readonly buffer IDS {int data_ids[];}; +#endif + layout (push_constant) uniform parameter { uint M; @@ -250,6 +270,21 @@ layout (push_constant) uniform parameter uint batch_stride_a; uint batch_stride_b; uint batch_stride_d; + +#ifdef MUL_MAT_ID + uint expert_stride_a; + uint expert_stride_b0; + uint expert_stride_b1; + uint expert_stride_d; + + uint ids_stride; + + uint n_as; + uint nei0; + uint nei1; + uint nbi1; + uint ne11; +#endif } p; layout (constant_id = 1) const uint BM = 64; @@ -265,9 +300,20 @@ layout (constant_id = 9) const uint WARP = 32; shared FLOAT_TYPE buf_a[BM * (BK+1)]; shared FLOAT_TYPE buf_b[BN * (BK+1)]; +#ifdef MUL_MAT_ID +shared u8vec2 rowids[2048]; +#endif + void main() { - const uint i13 = gl_GlobalInvocationID.z / p.ne12; - const uint i12 = gl_GlobalInvocationID.z % p.ne12; +#ifdef MUL_MAT_ID + const uint batch_idx = gl_GlobalInvocationID.z / p.n_as; + const uint expert_idx = gl_GlobalInvocationID.z % p.n_as; +#else + const uint batch_idx = gl_GlobalInvocationID.z; +#endif + + const uint i13 = batch_idx / p.ne12; + const uint i12 = batch_idx % p.ne12; const uint i03 = i13 / p.broadcast3; const uint i02 = i12 / p.broadcast2; @@ -299,11 +345,35 @@ void main() { const uint loadstride_a = gl_WorkGroupSize.x * LOAD_VEC_A / BK; const uint loadstride_b = gl_WorkGroupSize.x * LOAD_VEC_B / BK; +#ifdef MUL_MAT_ID + uint _ne1 = 0; + for (uint ii1 = 0; ii1 < p.nei1; ii1++) { + for (uint ii0 = 0; ii0 < p.nei0; ii0++) { + if (data_ids[ii1*p.nbi1 + ii0] == expert_idx) { + rowids[_ne1] = u8vec2(ii0, ii1); + _ne1++; + } + } + } + + const u8vec2 id = rowids[ir * BN + ic]; +#endif + const uint start_k = ik * p.k_split; const uint end_k = min(p.K, (ik + 1) * p.k_split); - uint pos_a = (batch_idx_a * p.batch_stride_a + ir * BM * p.stride_a + start_k) / LOAD_VEC_A; - uint pos_b = (gl_GlobalInvocationID.z * p.batch_stride_b + ic * BN * p.stride_b + start_k) / LOAD_VEC_B; + uint pos_a = ( +#ifdef MUL_MAT_ID + expert_idx * p.expert_stride_a + +#endif + batch_idx_a * p.batch_stride_a + ir * BM * p.stride_a + start_k) / LOAD_VEC_A; + uint pos_b = ( +#ifdef MUL_MAT_ID + id.y * p.expert_stride_b1 + + (id.x % p.ne11) * p.expert_stride_b0 + +#endif + batch_idx * p.batch_stride_b + + ic * BN * p.stride_b + start_k) / LOAD_VEC_B; float sums[WMITER * TM * WNITER * TN]; FLOAT_TYPE cache_a[WMITER * TM]; @@ -611,7 +681,11 @@ mulmat_body2 = """ const uint dr = ir * BM + warp_r * WM; const uint dc = ic * BN + warp_c * WN; - const uint offsets = gl_GlobalInvocationID.z * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z; + const uint offsets = +#ifdef MUL_MAT_ID + expert_idx * p.expert_stride_d + +#endif + batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z; [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) { [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) { @@ -1077,21 +1151,54 @@ mul_mat_vec_head = """#version 450 #extension GL_EXT_shader_16bit_storage : require #extension GL_EXT_shader_8bit_storage : require +#ifdef MUL_MAT_ID +#define EXPERT_COUNT 8 +#endif +""" + + +mul_mat_vec_layout = """ +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; +layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; +#ifdef MUL_MAT_ID +layout (binding = 3) readonly buffer IDS {int data_ids[];}; +#endif + layout (push_constant) uniform parameter { uint ncols; - uint b_offset; - uint d_offset; + uint stride_a; + uint stride_b; + uint stride_d; + + uint ne02; + uint ne12; + uint broadcast2; + uint broadcast3; + + uint batch_stride_a; + uint batch_stride_b; + uint batch_stride_d; + +#ifdef MUL_MAT_ID + uint expert_stride_a; + uint expert_stride_b0; + uint expert_stride_b1; + uint expert_stride_d0; + uint expert_stride_d1; + + uint ne11; + uint nei0; + uint nbi1; + uint n_as; +#endif } p; """ mul_mat_vec_body = """ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; -layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; -layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; -layout (binding = 2) writeonly buffer D {D_TYPE dst[];}; - layout (constant_id = 0) const uint BLOCK_SIZE = 32; shared FLOAT_TYPE tmp[BLOCK_SIZE]; @@ -1099,6 +1206,41 @@ shared FLOAT_TYPE tmp[BLOCK_SIZE]; void main() { const uint row = gl_WorkGroupID.x; const uint tid = gl_LocalInvocationID.x; + const uint batch_idx = gl_GlobalInvocationID.y; +#ifdef MUL_MAT_ID + const uint expert_idx1 = gl_GlobalInvocationID.z / p.nei0; + const uint expert_idx0 = gl_GlobalInvocationID.z % p.nei0; +#endif + + const uint i13 = batch_idx / p.ne12; + const uint i12 = batch_idx % p.ne12; + + const uint i03 = i13 / p.broadcast3; + const uint i02 = i12 / p.broadcast2; + + const uint batch_idx_a = i03 * p.ne02 + i02; + +#ifdef MUL_MAT_ID + const uint expert_id = data_ids[expert_idx1 * p.nbi1 + expert_idx0]; +#endif + + const uint a_offset = +#ifdef MUL_MAT_ID + expert_id * p.expert_stride_a + +#endif + batch_idx_a * p.batch_stride_a; + const uint b_offset = +#ifdef MUL_MAT_ID + (expert_idx0 % p.ne11) * p.expert_stride_b0 + + expert_idx1 * p.expert_stride_b1 + +#endif + batch_idx * p.batch_stride_b; + const uint d_offset = +#ifdef MUL_MAT_ID + expert_idx0 * p.expert_stride_b0 + + expert_idx1 * p.expert_stride_b1 + +#endif + batch_idx * p.batch_stride_d; const uint y_offset = QUANT_R == 1 ? 1 : QUANT_K/2; @@ -1110,11 +1252,11 @@ void main() { const uint iqs = (col%QUANT_K)/QUANT_R; // quant index const uint iybs = col - col%QUANT_K; // y block start index - DEQUANT_FUNC + vec2 v = dequantize(ib, iqs, a_offset / QUANT_K); // matrix multiplication - tmp[tid] += FLOAT_TYPE(v.x) * FLOAT_TYPE(data_b[p.b_offset + iybs + iqs + 0]) + - FLOAT_TYPE(v.y) * FLOAT_TYPE(data_b[p.b_offset + iybs + iqs + y_offset]); + tmp[tid] += FLOAT_TYPE(v.x) * FLOAT_TYPE(data_b[b_offset + iybs + iqs]) + + FLOAT_TYPE(v.y) * FLOAT_TYPE(data_b[b_offset + iybs + iqs + y_offset]); } // sum up partial sums and write back result @@ -1126,7 +1268,7 @@ void main() { barrier(); } if (tid == 0) { - dst[p.d_offset + row] = D_TYPE(tmp[0]); + data_d[d_offset + row] = D_TYPE(tmp[0]); } } """ @@ -1135,17 +1277,48 @@ void main() { mul_mat_vec_q2_K_body = """ layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in; -layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; -layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; -layout (binding = 2) writeonly buffer D {D_TYPE dst[];}; - shared FLOAT_TYPE tmp[32]; void main() { const uint row = gl_WorkGroupID.x; + const uint batch_idx = gl_GlobalInvocationID.y; +#ifdef MUL_MAT_ID + const uint expert_idx1 = gl_GlobalInvocationID.z / p.nei0; + const uint expert_idx0 = gl_GlobalInvocationID.z % p.nei0; +#endif + + const uint i13 = batch_idx / p.ne12; + const uint i12 = batch_idx % p.ne12; + + const uint i03 = i13 / p.broadcast3; + const uint i02 = i12 / p.broadcast2; + + const uint batch_idx_a = i03 * p.ne02 + i02; + +#ifdef MUL_MAT_ID + const uint expert_id = data_ids[expert_idx1 * p.nbi1 + expert_idx0]; +#endif + + const uint a_offset = +#ifdef MUL_MAT_ID + expert_id * p.expert_stride_a + +#endif + batch_idx_a * p.batch_stride_a; + const uint b_offset = +#ifdef MUL_MAT_ID + (expert_idx0 % p.ne11) * p.expert_stride_b0 + + expert_idx1 * p.expert_stride_b1 + +#endif + batch_idx * p.batch_stride_b; + const uint d_offset = +#ifdef MUL_MAT_ID + expert_idx0 * p.expert_stride_b0 + + expert_idx1 * p.expert_stride_b1 + +#endif + batch_idx * p.batch_stride_d; const uint num_blocks_per_row = p.ncols / QUANT_K; - const uint ib0 = row*num_blocks_per_row; + const uint ib0 = a_offset / QUANT_K + row*num_blocks_per_row; const uint tid = gl_LocalInvocationID.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16 const uint ix = gl_LocalInvocationID.x%K_QUANTS_PER_ITERATION; // 0 or 0, 1 @@ -1171,22 +1344,22 @@ void main() { FLOAT_TYPE sum1 = FLOAT_TYPE(0.0); FLOAT_TYPE sum2 = FLOAT_TYPE(0.0); for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) { - sum1 += FLOAT_TYPE(data_b[p.b_offset + y_idx + l + 0]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 0] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l + 0] >> 0) & 3) - + FLOAT_TYPE(data_b[p.b_offset + y_idx + l + 16]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 1] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l +16] >> 0) & 3) - + FLOAT_TYPE(data_b[p.b_offset + y_idx + l + 32]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 2] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l + 0] >> 2) & 3) - + FLOAT_TYPE(data_b[p.b_offset + y_idx + l + 48]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 3] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l +16] >> 2) & 3) - + FLOAT_TYPE(data_b[p.b_offset + y_idx + l + 64]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 4] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l + 0] >> 4) & 3) - + FLOAT_TYPE(data_b[p.b_offset + y_idx + l + 80]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 5] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l +16] >> 4) & 3) - + FLOAT_TYPE(data_b[p.b_offset + y_idx + l + 96]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 6] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l + 0] >> 6) & 3) - + FLOAT_TYPE(data_b[p.b_offset + y_idx + l +112]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 7] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l +16] >> 6) & 3); - sum2 += FLOAT_TYPE(data_b[p.b_offset + y_idx + l + 0]) * FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 0] >> 4) & 0xF) - + FLOAT_TYPE(data_b[p.b_offset + y_idx + l + 16]) * FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 1] >> 4) & 0xF) - + FLOAT_TYPE(data_b[p.b_offset + y_idx + l + 32]) * FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 2] >> 4) & 0xF) - + FLOAT_TYPE(data_b[p.b_offset + y_idx + l + 48]) * FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 3] >> 4) & 0xF) - + FLOAT_TYPE(data_b[p.b_offset + y_idx + l + 64]) * FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 4] >> 4) & 0xF) - + FLOAT_TYPE(data_b[p.b_offset + y_idx + l + 80]) * FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 5] >> 4) & 0xF) - + FLOAT_TYPE(data_b[p.b_offset + y_idx + l + 96]) * FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 6] >> 4) & 0xF) - + FLOAT_TYPE(data_b[p.b_offset + y_idx + l +112]) * FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 7] >> 4) & 0xF); + sum1 += FLOAT_TYPE(data_b[b_offset + y_idx + l + 0]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 0] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l + 0] >> 0) & 3) + + FLOAT_TYPE(data_b[b_offset + y_idx + l + 16]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 1] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l +16] >> 0) & 3) + + FLOAT_TYPE(data_b[b_offset + y_idx + l + 32]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 2] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l + 0] >> 2) & 3) + + FLOAT_TYPE(data_b[b_offset + y_idx + l + 48]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 3] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l +16] >> 2) & 3) + + FLOAT_TYPE(data_b[b_offset + y_idx + l + 64]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 4] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l + 0] >> 4) & 3) + + FLOAT_TYPE(data_b[b_offset + y_idx + l + 80]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 5] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l +16] >> 4) & 3) + + FLOAT_TYPE(data_b[b_offset + y_idx + l + 96]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 6] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l + 0] >> 6) & 3) + + FLOAT_TYPE(data_b[b_offset + y_idx + l +112]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 7] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l +16] >> 6) & 3); + sum2 += FLOAT_TYPE(data_b[b_offset + y_idx + l + 0]) * FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 0] >> 4) & 0xF) + + FLOAT_TYPE(data_b[b_offset + y_idx + l + 16]) * FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 1] >> 4) & 0xF) + + FLOAT_TYPE(data_b[b_offset + y_idx + l + 32]) * FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 2] >> 4) & 0xF) + + FLOAT_TYPE(data_b[b_offset + y_idx + l + 48]) * FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 3] >> 4) & 0xF) + + FLOAT_TYPE(data_b[b_offset + y_idx + l + 64]) * FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 4] >> 4) & 0xF) + + FLOAT_TYPE(data_b[b_offset + y_idx + l + 80]) * FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 5] >> 4) & 0xF) + + FLOAT_TYPE(data_b[b_offset + y_idx + l + 96]) * FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 6] >> 4) & 0xF) + + FLOAT_TYPE(data_b[b_offset + y_idx + l +112]) * FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 7] >> 4) & 0xF); } tmp[16 * ix + tid] += dall * sum1 - dmin * sum2; } @@ -1200,24 +1373,55 @@ void main() { barrier(); } if (tid == 0) { - dst[p.d_offset + row] = D_TYPE(tmp[0]); + data_d[d_offset + row] = D_TYPE(tmp[0]); } } """ mul_mat_vec_q3_K_body = """ layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in; -layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; -layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; -layout (binding = 2) writeonly buffer D {D_TYPE dst[];}; - shared FLOAT_TYPE tmp[32]; void main() { const uint row = gl_WorkGroupID.x; + const uint batch_idx = gl_GlobalInvocationID.y; +#ifdef MUL_MAT_ID + const uint expert_idx1 = gl_GlobalInvocationID.z / p.nei0; + const uint expert_idx0 = gl_GlobalInvocationID.z % p.nei0; +#endif + + const uint i13 = batch_idx / p.ne12; + const uint i12 = batch_idx % p.ne12; + + const uint i03 = i13 / p.broadcast3; + const uint i02 = i12 / p.broadcast2; + + const uint batch_idx_a = i03 * p.ne02 + i02; + +#ifdef MUL_MAT_ID + const uint expert_id = data_ids[expert_idx1 * p.nbi1 + expert_idx0]; +#endif + + const uint a_offset = +#ifdef MUL_MAT_ID + expert_id * p.expert_stride_a + +#endif + batch_idx_a * p.batch_stride_a; + const uint b_offset = +#ifdef MUL_MAT_ID + (expert_idx0 % p.ne11) * p.expert_stride_b0 + + expert_idx1 * p.expert_stride_b1 + +#endif + batch_idx * p.batch_stride_b; + const uint d_offset = +#ifdef MUL_MAT_ID + expert_idx0 * p.expert_stride_b0 + + expert_idx1 * p.expert_stride_b1 + +#endif + batch_idx * p.batch_stride_d; const uint num_blocks_per_row = p.ncols / QUANT_K; - const uint ib0 = row*num_blocks_per_row; + const uint ib0 = a_offset / QUANT_K + row*num_blocks_per_row; const uint tid = gl_LocalInvocationID.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16 const uint ix = gl_LocalInvocationID.x%K_QUANTS_PER_ITERATION; // 0 or 0, 1 @@ -1244,14 +1448,14 @@ void main() { FLOAT_TYPE sum = FLOAT_TYPE(0.0); for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) { - sum += FLOAT_TYPE(data_b[p.b_offset + y_idx + l + 0]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[0] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[ 8] >> (s_shift + 0) & 0x3) << 4)) - 32) * FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] ) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 0)) != 0) ? 0 : 4)) - + FLOAT_TYPE(data_b[p.b_offset + y_idx + l + 32]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[2] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[10] >> (s_shift + 0) & 0x3) << 4)) - 32) * FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] >> 2) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 1)) != 0) ? 0 : 4)) - + FLOAT_TYPE(data_b[p.b_offset + y_idx + l + 64]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[4] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[ 8] >> (s_shift + 2) & 0x3) << 4)) - 32) * FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] >> 4) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 2)) != 0) ? 0 : 4)) - + FLOAT_TYPE(data_b[p.b_offset + y_idx + l + 96]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[6] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[10] >> (s_shift + 2) & 0x3) << 4)) - 32) * FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] >> 6) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 3)) != 0) ? 0 : 4)) - + FLOAT_TYPE(data_b[p.b_offset + y_idx + l + 16]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[1] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[ 9] >> (s_shift + 0) & 0x3) << 4)) - 32) * FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] ) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 0)) != 0) ? 0 : 4)) - + FLOAT_TYPE(data_b[p.b_offset + y_idx + l + 48]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[3] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[11] >> (s_shift + 0) & 0x3) << 4)) - 32) * FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 2) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 1)) != 0) ? 0 : 4)) - + FLOAT_TYPE(data_b[p.b_offset + y_idx + l + 80]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[5] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[ 9] >> (s_shift + 2) & 0x3) << 4)) - 32) * FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 4) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 2)) != 0) ? 0 : 4)) - + FLOAT_TYPE(data_b[p.b_offset + y_idx + l +112]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[7] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[11] >> (s_shift + 2) & 0x3) << 4)) - 32) * FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 6) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 3)) != 0) ? 0 : 4)); + sum += FLOAT_TYPE(data_b[b_offset + y_idx + l + 0]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[0] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[ 8] >> (s_shift + 0) & 0x3) << 4)) - 32) * FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] ) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 0)) != 0) ? 0 : 4)) + + FLOAT_TYPE(data_b[b_offset + y_idx + l + 32]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[2] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[10] >> (s_shift + 0) & 0x3) << 4)) - 32) * FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] >> 2) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 1)) != 0) ? 0 : 4)) + + FLOAT_TYPE(data_b[b_offset + y_idx + l + 64]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[4] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[ 8] >> (s_shift + 2) & 0x3) << 4)) - 32) * FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] >> 4) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 2)) != 0) ? 0 : 4)) + + FLOAT_TYPE(data_b[b_offset + y_idx + l + 96]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[6] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[10] >> (s_shift + 2) & 0x3) << 4)) - 32) * FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] >> 6) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 3)) != 0) ? 0 : 4)) + + FLOAT_TYPE(data_b[b_offset + y_idx + l + 16]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[1] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[ 9] >> (s_shift + 0) & 0x3) << 4)) - 32) * FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] ) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 0)) != 0) ? 0 : 4)) + + FLOAT_TYPE(data_b[b_offset + y_idx + l + 48]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[3] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[11] >> (s_shift + 0) & 0x3) << 4)) - 32) * FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 2) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 1)) != 0) ? 0 : 4)) + + FLOAT_TYPE(data_b[b_offset + y_idx + l + 80]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[5] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[ 9] >> (s_shift + 2) & 0x3) << 4)) - 32) * FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 4) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 2)) != 0) ? 0 : 4)) + + FLOAT_TYPE(data_b[b_offset + y_idx + l +112]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[7] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[11] >> (s_shift + 2) & 0x3) << 4)) - 32) * FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 6) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 3)) != 0) ? 0 : 4)); } tmp[16 * ix + tid] += d * sum; } @@ -1265,24 +1469,55 @@ void main() { barrier(); } if (tid == 0) { - dst[p.d_offset + row] = D_TYPE(tmp[0]); + data_d[d_offset + row] = D_TYPE(tmp[0]); } } """ mul_mat_vec_q4_K_body = """ layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in; -layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; -layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; -layout (binding = 2) writeonly buffer D {D_TYPE dst[];}; - shared FLOAT_TYPE tmp[32]; void main() { const uint row = gl_WorkGroupID.x; + const uint batch_idx = gl_GlobalInvocationID.y; +#ifdef MUL_MAT_ID + const uint expert_idx1 = gl_GlobalInvocationID.z / p.nei0; + const uint expert_idx0 = gl_GlobalInvocationID.z % p.nei0; +#endif + + const uint i13 = batch_idx / p.ne12; + const uint i12 = batch_idx % p.ne12; + + const uint i03 = i13 / p.broadcast3; + const uint i02 = i12 / p.broadcast2; + + const uint batch_idx_a = i03 * p.ne02 + i02; + +#ifdef MUL_MAT_ID + const uint expert_id = data_ids[expert_idx1 * p.nbi1 + expert_idx0]; +#endif + + const uint a_offset = +#ifdef MUL_MAT_ID + expert_id * p.expert_stride_a + +#endif + batch_idx_a * p.batch_stride_a; + const uint b_offset = +#ifdef MUL_MAT_ID + (expert_idx0 % p.ne11) * p.expert_stride_b0 + + expert_idx1 * p.expert_stride_b1 + +#endif + batch_idx * p.batch_stride_b; + const uint d_offset = +#ifdef MUL_MAT_ID + expert_idx0 * p.expert_stride_b0 + + expert_idx1 * p.expert_stride_b1 + +#endif + batch_idx * p.batch_stride_d; const uint num_blocks_per_row = p.ncols / QUANT_K; - const uint ib0 = row*num_blocks_per_row; + const uint ib0 = a_offset / QUANT_K + row*num_blocks_per_row; const uint tid = gl_LocalInvocationID.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16 const uint ix = gl_LocalInvocationID.x%K_QUANTS_PER_ITERATION; // 0 or 0, 1 @@ -1336,15 +1571,15 @@ void main() { const uint8_t q4_14 = uint8_t(data_a[ib0 + i].qs[q_offset + 66] >> 4); const uint8_t q4_15 = uint8_t(data_a[ib0 + i].qs[q_offset + 67] >> 4); - const FLOAT_TYPE sx = FLOAT_TYPE(data_b[p.b_offset + y1_idx] * q4_0 + data_b[p.b_offset + y1_idx + 1] * q4_1 + data_b[p.b_offset + y1_idx + 2] * q4_2 + data_b[p.b_offset + y1_idx + 3] * q4_3); - const FLOAT_TYPE sy = FLOAT_TYPE(data_b[p.b_offset + y1_idx + 32] * q4_4 + data_b[p.b_offset + y1_idx + 33] * q4_5 + data_b[p.b_offset + y1_idx + 34] * q4_6 + data_b[p.b_offset + y1_idx + 35] * q4_7); - const FLOAT_TYPE sz = FLOAT_TYPE(data_b[p.b_offset + y2_idx] * q4_8 + data_b[p.b_offset + y2_idx + 1] * q4_9 + data_b[p.b_offset + y2_idx + 2] * q4_10 + data_b[p.b_offset + y2_idx + 3] * q4_11); - const FLOAT_TYPE sw = FLOAT_TYPE(data_b[p.b_offset + y2_idx + 32] * q4_12 + data_b[p.b_offset + y2_idx + 33] * q4_13 + data_b[p.b_offset + y2_idx + 34] * q4_14 + data_b[p.b_offset + y2_idx + 35] * q4_15); + const FLOAT_TYPE sx = FLOAT_TYPE(FLOAT_TYPE(data_b[b_offset + y1_idx]) * q4_0 + FLOAT_TYPE(data_b[b_offset + y1_idx + 1]) * q4_1 + FLOAT_TYPE(data_b[b_offset + y1_idx + 2]) * q4_2 + FLOAT_TYPE(data_b[b_offset + y1_idx + 3]) * q4_3); + const FLOAT_TYPE sy = FLOAT_TYPE(FLOAT_TYPE(data_b[b_offset + y1_idx + 32]) * q4_4 + FLOAT_TYPE(data_b[b_offset + y1_idx + 33]) * q4_5 + FLOAT_TYPE(data_b[b_offset + y1_idx + 34]) * q4_6 + FLOAT_TYPE(data_b[b_offset + y1_idx + 35]) * q4_7); + const FLOAT_TYPE sz = FLOAT_TYPE(FLOAT_TYPE(data_b[b_offset + y2_idx]) * q4_8 + FLOAT_TYPE(data_b[b_offset + y2_idx + 1]) * q4_9 + FLOAT_TYPE(data_b[b_offset + y2_idx + 2]) * q4_10 + FLOAT_TYPE(data_b[b_offset + y2_idx + 3]) * q4_11); + const FLOAT_TYPE sw = FLOAT_TYPE(FLOAT_TYPE(data_b[b_offset + y2_idx + 32]) * q4_12 + FLOAT_TYPE(data_b[b_offset + y2_idx + 33]) * q4_13 + FLOAT_TYPE(data_b[b_offset + y2_idx + 34]) * q4_14 + FLOAT_TYPE(data_b[b_offset + y2_idx + 35]) * q4_15); const FLOAT_TYPE smin = FLOAT_TYPE( - data_b[p.b_offset + y1_idx ] * sc2 + data_b[p.b_offset + y1_idx + 32] * sc3 + data_b[p.b_offset + y2_idx ] * sc6 + data_b[p.b_offset + y2_idx + 32] * sc7 - + data_b[p.b_offset + y1_idx + 1] * sc2 + data_b[p.b_offset + y1_idx + 33] * sc3 + data_b[p.b_offset + y2_idx + 1] * sc6 + data_b[p.b_offset + y2_idx + 33] * sc7 - + data_b[p.b_offset + y1_idx + 2] * sc2 + data_b[p.b_offset + y1_idx + 34] * sc3 + data_b[p.b_offset + y2_idx + 2] * sc6 + data_b[p.b_offset + y2_idx + 34] * sc7 - + data_b[p.b_offset + y1_idx + 3] * sc2 + data_b[p.b_offset + y1_idx + 35] * sc3 + data_b[p.b_offset + y2_idx + 3] * sc6 + data_b[p.b_offset + y2_idx + 35] * sc7 + FLOAT_TYPE(data_b[b_offset + y1_idx ]) * sc2 + FLOAT_TYPE(data_b[b_offset + y1_idx + 32]) * sc3 + FLOAT_TYPE(data_b[b_offset + y2_idx ]) * sc6 + FLOAT_TYPE(data_b[b_offset + y2_idx + 32]) * sc7 + + FLOAT_TYPE(data_b[b_offset + y1_idx + 1]) * sc2 + FLOAT_TYPE(data_b[b_offset + y1_idx + 33]) * sc3 + FLOAT_TYPE(data_b[b_offset + y2_idx + 1]) * sc6 + FLOAT_TYPE(data_b[b_offset + y2_idx + 33]) * sc7 + + FLOAT_TYPE(data_b[b_offset + y1_idx + 2]) * sc2 + FLOAT_TYPE(data_b[b_offset + y1_idx + 34]) * sc3 + FLOAT_TYPE(data_b[b_offset + y2_idx + 2]) * sc6 + FLOAT_TYPE(data_b[b_offset + y2_idx + 34]) * sc7 + + FLOAT_TYPE(data_b[b_offset + y1_idx + 3]) * sc2 + FLOAT_TYPE(data_b[b_offset + y1_idx + 35]) * sc3 + FLOAT_TYPE(data_b[b_offset + y2_idx + 3]) * sc6 + FLOAT_TYPE(data_b[b_offset + y2_idx + 35]) * sc7 ); tmp[16 * ix + tid] += FLOAT_TYPE(dall * (sx * sc0 + sy * sc1 + sz * sc4 + sw * sc5) - dmin * smin); #else @@ -1357,13 +1592,13 @@ void main() { const uint8_t q4_6 = uint8_t(data_a[ib0 + i].qs[q_offset + 64] >> 4); const uint8_t q4_7 = uint8_t(data_a[ib0 + i].qs[q_offset + 65] >> 4); - const FLOAT_TYPE sx = FLOAT_TYPE(data_b[p.b_offset + y1_idx ] * q4_0 + data_b[p.b_offset + y1_idx + 1] * q4_1); - const FLOAT_TYPE sy = FLOAT_TYPE(data_b[p.b_offset + y1_idx + 32] * q4_2 + data_b[p.b_offset + y1_idx + 33] * q4_3); - const FLOAT_TYPE sz = FLOAT_TYPE(data_b[p.b_offset + y2_idx ] * q4_4 + data_b[p.b_offset + y2_idx + 1] * q4_5); - const FLOAT_TYPE sw = FLOAT_TYPE(data_b[p.b_offset + y2_idx + 32] * q4_6 + data_b[p.b_offset + y2_idx + 33] * q4_7); + const FLOAT_TYPE sx = FLOAT_TYPE(FLOAT_TYPE(data_b[b_offset + y1_idx ]) * q4_0 + FLOAT_TYPE(data_b[b_offset + y1_idx + 1]) * q4_1); + const FLOAT_TYPE sy = FLOAT_TYPE(FLOAT_TYPE(data_b[b_offset + y1_idx + 32]) * q4_2 + FLOAT_TYPE(data_b[b_offset + y1_idx + 33]) * q4_3); + const FLOAT_TYPE sz = FLOAT_TYPE(FLOAT_TYPE(data_b[b_offset + y2_idx ]) * q4_4 + FLOAT_TYPE(data_b[b_offset + y2_idx + 1]) * q4_5); + const FLOAT_TYPE sw = FLOAT_TYPE(FLOAT_TYPE(data_b[b_offset + y2_idx + 32]) * q4_6 + FLOAT_TYPE(data_b[b_offset + y2_idx + 33]) * q4_7); const FLOAT_TYPE smin = FLOAT_TYPE( - data_b[p.b_offset + y1_idx] * sc2 + data_b[p.b_offset + y1_idx + 32] * sc3 + data_b[p.b_offset + y2_idx] * sc6 + data_b[p.b_offset + y2_idx + 32] * sc7 - + data_b[p.b_offset + y1_idx + 1] * sc2 + data_b[p.b_offset + y1_idx + 33] * sc3 + data_b[p.b_offset + y2_idx + 1] * sc6 + data_b[p.b_offset + y2_idx + 33] * sc7 + FLOAT_TYPE(data_b[b_offset + y1_idx]) * sc2 + FLOAT_TYPE(data_b[b_offset + y1_idx + 32]) * sc3 + FLOAT_TYPE(data_b[b_offset + y2_idx]) * sc6 + FLOAT_TYPE(data_b[b_offset + y2_idx + 32]) * sc7 + + FLOAT_TYPE(data_b[b_offset + y1_idx + 1]) * sc2 + FLOAT_TYPE(data_b[b_offset + y1_idx + 33]) * sc3 + FLOAT_TYPE(data_b[b_offset + y2_idx + 1]) * sc6 + FLOAT_TYPE(data_b[b_offset + y2_idx + 33]) * sc7 ); tmp[16 * ix + tid] += FLOAT_TYPE(dall * (sx * FLOAT_TYPE(data_a[ib0 + i].scales[v_im] & 0x3f) + sy * FLOAT_TYPE(data_a[ib0 + i].scales[v_im + 1] & 0x3f) + sz * FLOAT_TYPE((data_a[ib0 + i].scales[v_im + 4] & 0x0f) | ((data_a[ib0 + i].scales[v_im] & 0xc0) >> 2)) + sw * FLOAT_TYPE((data_a[ib0 + i].scales[v_im + 5] & 0x0f) | ((data_a[ib0 + i].scales[v_im + 1] & 0xc0) >> 2))) - dmin * smin); @@ -1379,24 +1614,55 @@ void main() { barrier(); } if (tid == 0) { - dst[p.d_offset + row] = D_TYPE(tmp[0]); + data_d[d_offset + row] = D_TYPE(tmp[0]); } } """ mul_mat_vec_q5_K_body = """ layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in; -layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; -layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; -layout (binding = 2) writeonly buffer D {D_TYPE dst[];}; - shared FLOAT_TYPE tmp[32]; void main() { const uint row = gl_WorkGroupID.x; + const uint batch_idx = gl_GlobalInvocationID.y; +#ifdef MUL_MAT_ID + const uint expert_idx1 = gl_GlobalInvocationID.z / p.nei0; + const uint expert_idx0 = gl_GlobalInvocationID.z % p.nei0; +#endif + + const uint i13 = batch_idx / p.ne12; + const uint i12 = batch_idx % p.ne12; + + const uint i03 = i13 / p.broadcast3; + const uint i02 = i12 / p.broadcast2; + + const uint batch_idx_a = i03 * p.ne02 + i02; + +#ifdef MUL_MAT_ID + const uint expert_id = data_ids[expert_idx1 * p.nbi1 + expert_idx0]; +#endif + + const uint a_offset = +#ifdef MUL_MAT_ID + expert_id * p.expert_stride_a + +#endif + batch_idx_a * p.batch_stride_a; + const uint b_offset = +#ifdef MUL_MAT_ID + (expert_idx0 % p.ne11) * p.expert_stride_b0 + + expert_idx1 * p.expert_stride_b1 + +#endif + batch_idx * p.batch_stride_b; + const uint d_offset = +#ifdef MUL_MAT_ID + expert_idx0 * p.expert_stride_b0 + + expert_idx1 * p.expert_stride_b1 + +#endif + batch_idx * p.batch_stride_d; const uint num_blocks_per_row = p.ncols / QUANT_K; - const uint ib0 = row*num_blocks_per_row; + const uint ib0 = a_offset / QUANT_K + row*num_blocks_per_row; const uint tid = gl_LocalInvocationID.x/2; // 0...31 or 0...16 const uint ix = gl_LocalInvocationID.x%2; // 0 or 0, 1 @@ -1450,32 +1716,32 @@ void main() { const uint8_t q4_15 = uint8_t(data_a[ib0 + i].qs[q_offset + 81] >> 4); const FLOAT_TYPE sx = FLOAT_TYPE( - data_b[p.b_offset + y1_idx ] * (q4_0 + (((data_a[ib0 + i].qh[l0 ] & hm1) != 0) ? 16 : 0)) - + data_b[p.b_offset + y1_idx + 1] * (q4_1 + (((data_a[ib0 + i].qh[l0 + 1] & hm1) != 0) ? 16 : 0)) - + data_b[p.b_offset + y1_idx + 16] * (q4_2 + (((data_a[ib0 + i].qh[l0 + 16] & hm1) != 0) ? 16 : 0)) - + data_b[p.b_offset + y1_idx + 17] * (q4_3 + (((data_a[ib0 + i].qh[l0 + 17] & hm1) != 0) ? 16 : 0)) + FLOAT_TYPE(data_b[b_offset + y1_idx ]) * (q4_0 + (((data_a[ib0 + i].qh[l0 ] & hm1) != 0) ? 16 : 0)) + + FLOAT_TYPE(data_b[b_offset + y1_idx + 1]) * (q4_1 + (((data_a[ib0 + i].qh[l0 + 1] & hm1) != 0) ? 16 : 0)) + + FLOAT_TYPE(data_b[b_offset + y1_idx + 16]) * (q4_2 + (((data_a[ib0 + i].qh[l0 + 16] & hm1) != 0) ? 16 : 0)) + + FLOAT_TYPE(data_b[b_offset + y1_idx + 17]) * (q4_3 + (((data_a[ib0 + i].qh[l0 + 17] & hm1) != 0) ? 16 : 0)) ); const FLOAT_TYPE sy = FLOAT_TYPE( - data_b[p.b_offset + y1_idx + 32] * (q4_4 + (((data_a[ib0 + i].qh[l0 ] & (hm1 << 1)) != 0) ? 16 : 0)) - + data_b[p.b_offset + y1_idx + 33] * (q4_5 + (((data_a[ib0 + i].qh[l0 + 1] & (hm1 << 1)) != 0) ? 16 : 0)) - + data_b[p.b_offset + y1_idx + 48] * (q4_6 + (((data_a[ib0 + i].qh[l0 + 16] & (hm1 << 1)) != 0) ? 16 : 0)) - + data_b[p.b_offset + y1_idx + 49] * (q4_7 + (((data_a[ib0 + i].qh[l0 + 17] & (hm1 << 1)) != 0) ? 16 : 0)) + FLOAT_TYPE(data_b[b_offset + y1_idx + 32]) * (q4_4 + (((data_a[ib0 + i].qh[l0 ] & (hm1 << 1)) != 0) ? 16 : 0)) + + FLOAT_TYPE(data_b[b_offset + y1_idx + 33]) * (q4_5 + (((data_a[ib0 + i].qh[l0 + 1] & (hm1 << 1)) != 0) ? 16 : 0)) + + FLOAT_TYPE(data_b[b_offset + y1_idx + 48]) * (q4_6 + (((data_a[ib0 + i].qh[l0 + 16] & (hm1 << 1)) != 0) ? 16 : 0)) + + FLOAT_TYPE(data_b[b_offset + y1_idx + 49]) * (q4_7 + (((data_a[ib0 + i].qh[l0 + 17] & (hm1 << 1)) != 0) ? 16 : 0)) ); const FLOAT_TYPE sz = FLOAT_TYPE( - data_b[p.b_offset + y2_idx ] * (q4_8 + (((data_a[ib0 + i].qh[l0 ] & hm2) != 0) ? 16 : 0)) - + data_b[p.b_offset + y2_idx + 1] * (q4_9 + (((data_a[ib0 + i].qh[l0 + 1] & hm2) != 0) ? 16 : 0)) - + data_b[p.b_offset + y2_idx + 16] * (q4_10 + (((data_a[ib0 + i].qh[l0 + 16] & hm2) != 0) ? 16 : 0)) - + data_b[p.b_offset + y2_idx + 17] * (q4_11 + (((data_a[ib0 + i].qh[l0 + 17] & hm2) != 0) ? 16 : 0)) + FLOAT_TYPE(data_b[b_offset + y2_idx ]) * (q4_8 + (((data_a[ib0 + i].qh[l0 ] & hm2) != 0) ? 16 : 0)) + + FLOAT_TYPE(data_b[b_offset + y2_idx + 1]) * (q4_9 + (((data_a[ib0 + i].qh[l0 + 1] & hm2) != 0) ? 16 : 0)) + + FLOAT_TYPE(data_b[b_offset + y2_idx + 16]) * (q4_10 + (((data_a[ib0 + i].qh[l0 + 16] & hm2) != 0) ? 16 : 0)) + + FLOAT_TYPE(data_b[b_offset + y2_idx + 17]) * (q4_11 + (((data_a[ib0 + i].qh[l0 + 17] & hm2) != 0) ? 16 : 0)) ); const FLOAT_TYPE sw = FLOAT_TYPE( - data_b[p.b_offset + y2_idx + 32] * (q4_12 + (((data_a[ib0 + i].qh[l0 ] & (hm2 << 1)) != 0) ? 16 : 0)) - + data_b[p.b_offset + y2_idx + 33] * (q4_13 + (((data_a[ib0 + i].qh[l0 + 1] & (hm2 << 1)) != 0) ? 16 : 0)) - + data_b[p.b_offset + y2_idx + 48] * (q4_14 + (((data_a[ib0 + i].qh[l0 + 16] & (hm2 << 1)) != 0) ? 16 : 0)) - + data_b[p.b_offset + y2_idx + 49] * (q4_15 + (((data_a[ib0 + i].qh[l0 + 17] & (hm2 << 1)) != 0) ? 16 : 0)) + FLOAT_TYPE(data_b[b_offset + y2_idx + 32]) * (q4_12 + (((data_a[ib0 + i].qh[l0 ] & (hm2 << 1)) != 0) ? 16 : 0)) + + FLOAT_TYPE(data_b[b_offset + y2_idx + 33]) * (q4_13 + (((data_a[ib0 + i].qh[l0 + 1] & (hm2 << 1)) != 0) ? 16 : 0)) + + FLOAT_TYPE(data_b[b_offset + y2_idx + 48]) * (q4_14 + (((data_a[ib0 + i].qh[l0 + 16] & (hm2 << 1)) != 0) ? 16 : 0)) + + FLOAT_TYPE(data_b[b_offset + y2_idx + 49]) * (q4_15 + (((data_a[ib0 + i].qh[l0 + 17] & (hm2 << 1)) != 0) ? 16 : 0)) ); const FLOAT_TYPE smin = FLOAT_TYPE( - (data_b[p.b_offset + y1_idx] + data_b[p.b_offset + y1_idx + 1] + data_b[p.b_offset + y1_idx + 16] + data_b[p.b_offset + y1_idx + 17]) * sc2 + (data_b[p.b_offset + y1_idx + 32] + data_b[p.b_offset + y1_idx + 33] + data_b[p.b_offset + y1_idx + 48] + data_b[p.b_offset + y1_idx + 49]) * sc3 - + (data_b[p.b_offset + y2_idx] + data_b[p.b_offset + y2_idx + 1] + data_b[p.b_offset + y2_idx + 16] + data_b[p.b_offset + y2_idx + 17]) * sc6 + (data_b[p.b_offset + y2_idx + 32] + data_b[p.b_offset + y2_idx + 33] + data_b[p.b_offset + y2_idx + 48] + data_b[p.b_offset + y2_idx + 49]) * sc7 + (FLOAT_TYPE(data_b[b_offset + y1_idx]) + FLOAT_TYPE(data_b[b_offset + y1_idx + 1]) + FLOAT_TYPE(data_b[b_offset + y1_idx + 16]) + FLOAT_TYPE(data_b[b_offset + y1_idx + 17])) * sc2 + (FLOAT_TYPE(data_b[b_offset + y1_idx + 32]) + FLOAT_TYPE(data_b[b_offset + y1_idx + 33]) + FLOAT_TYPE(data_b[b_offset + y1_idx + 48]) + FLOAT_TYPE(data_b[b_offset + y1_idx + 49])) * sc3 + + (FLOAT_TYPE(data_b[b_offset + y2_idx]) + FLOAT_TYPE(data_b[b_offset + y2_idx + 1]) + FLOAT_TYPE(data_b[b_offset + y2_idx + 16]) + FLOAT_TYPE(data_b[b_offset + y2_idx + 17])) * sc6 + (FLOAT_TYPE(data_b[b_offset + y2_idx + 32]) + FLOAT_TYPE(data_b[b_offset + y2_idx + 33]) + FLOAT_TYPE(data_b[b_offset + y2_idx + 48]) + FLOAT_TYPE(data_b[b_offset + y2_idx + 49])) * sc7 ); tmp[16 * ix + tid] += FLOAT_TYPE(dall * (sx * sc0 + sy * sc1 + sz * sc4 + sw * sc5) - dmin * smin); } @@ -1489,25 +1755,55 @@ void main() { barrier(); } if (tid == 0) { - dst[p.d_offset + row] = D_TYPE(tmp[0]); + data_d[d_offset + row] = D_TYPE(tmp[0]); } } """ mul_mat_vec_q6_K_body = """ layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in; -layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; -layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; -layout (binding = 2) writeonly buffer D {D_TYPE dst[];}; - shared FLOAT_TYPE tmp[32]; void main() { - const uint block_size = gl_WorkGroupSize.x; const uint row = gl_WorkGroupID.x; + const uint batch_idx = gl_GlobalInvocationID.y; +#ifdef MUL_MAT_ID + const uint expert_idx1 = gl_GlobalInvocationID.z / p.nei0; + const uint expert_idx0 = gl_GlobalInvocationID.z % p.nei0; +#endif + + const uint i13 = batch_idx / p.ne12; + const uint i12 = batch_idx % p.ne12; + + const uint i03 = i13 / p.broadcast3; + const uint i02 = i12 / p.broadcast2; + + const uint batch_idx_a = i03 * p.ne02 + i02; + +#ifdef MUL_MAT_ID + const uint expert_id = data_ids[expert_idx1 * p.nbi1 + expert_idx0]; +#endif + + const uint a_offset = +#ifdef MUL_MAT_ID + expert_id * p.expert_stride_a + +#endif + batch_idx_a * p.batch_stride_a; + const uint b_offset = +#ifdef MUL_MAT_ID + (expert_idx0 % p.ne11) * p.expert_stride_b0 + + expert_idx1 * p.expert_stride_b1 + +#endif + batch_idx * p.batch_stride_b; + const uint d_offset = +#ifdef MUL_MAT_ID + expert_idx0 * p.expert_stride_b0 + + expert_idx1 * p.expert_stride_b1 + +#endif + batch_idx * p.batch_stride_d; const uint num_blocks_per_row = p.ncols / QUANT_K; - const uint ib0 = row*num_blocks_per_row; + const uint ib0 = a_offset / QUANT_K + row*num_blocks_per_row; const uint tid = gl_LocalInvocationID.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16 const uint ix = gl_LocalInvocationID.x%K_QUANTS_PER_ITERATION; // 0 or 0, 1 @@ -1538,22 +1834,22 @@ void main() { const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d); #if K_QUANTS_PER_ITERATION == 1 - FLOAT_TYPE sum = FLOAT_TYPE(data_b[p.b_offset + y_idx + 0]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 0]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 0] & 0xF) | ((data_a[ib0 + i].qh[qh_offset + 0] & 0x03) << 4)) - 32) - + FLOAT_TYPE(data_b[p.b_offset + y_idx + 16]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 1]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 16] & 0xF) | ((data_a[ib0 + i].qh[qh_offset + 16] & 0x03) << 4)) - 32) - + FLOAT_TYPE(data_b[p.b_offset + y_idx + 32]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 2]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 32] & 0xF) | ((data_a[ib0 + i].qh[qh_offset + 0] & 0x0c) << 2)) - 32) - + FLOAT_TYPE(data_b[p.b_offset + y_idx + 48]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 3]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 48] & 0xF) | ((data_a[ib0 + i].qh[qh_offset + 16] & 0x0c) << 2)) - 32) - + FLOAT_TYPE(data_b[p.b_offset + y_idx + 64]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 4]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 0] >> 4) | ((data_a[ib0 + i].qh[qh_offset + 0] & 0x30) >> 0)) - 32) - + FLOAT_TYPE(data_b[p.b_offset + y_idx + 80]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 5]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 16] >> 4) | ((data_a[ib0 + i].qh[qh_offset + 16] & 0x30) >> 0)) - 32) - + FLOAT_TYPE(data_b[p.b_offset + y_idx + 96]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 6]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 32] >> 4) | ((data_a[ib0 + i].qh[qh_offset + 0] & 0xc0) >> 2)) - 32) - + FLOAT_TYPE(data_b[p.b_offset + y_idx +112]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 7]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 48] >> 4) | ((data_a[ib0 + i].qh[qh_offset + 16] & 0xc0) >> 2)) - 32); + FLOAT_TYPE sum = FLOAT_TYPE(data_b[b_offset + y_idx + 0]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 0]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 0] & 0xF) | ((data_a[ib0 + i].qh[qh_offset + 0] & 0x03) << 4)) - 32) + + FLOAT_TYPE(data_b[b_offset + y_idx + 16]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 1]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 16] & 0xF) | ((data_a[ib0 + i].qh[qh_offset + 16] & 0x03) << 4)) - 32) + + FLOAT_TYPE(data_b[b_offset + y_idx + 32]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 2]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 32] & 0xF) | ((data_a[ib0 + i].qh[qh_offset + 0] & 0x0c) << 2)) - 32) + + FLOAT_TYPE(data_b[b_offset + y_idx + 48]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 3]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 48] & 0xF) | ((data_a[ib0 + i].qh[qh_offset + 16] & 0x0c) << 2)) - 32) + + FLOAT_TYPE(data_b[b_offset + y_idx + 64]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 4]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 0] >> 4) | ((data_a[ib0 + i].qh[qh_offset + 0] & 0x30) >> 0)) - 32) + + FLOAT_TYPE(data_b[b_offset + y_idx + 80]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 5]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 16] >> 4) | ((data_a[ib0 + i].qh[qh_offset + 16] & 0x30) >> 0)) - 32) + + FLOAT_TYPE(data_b[b_offset + y_idx + 96]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 6]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 32] >> 4) | ((data_a[ib0 + i].qh[qh_offset + 0] & 0xc0) >> 2)) - 32) + + FLOAT_TYPE(data_b[b_offset + y_idx +112]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 7]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 48] >> 4) | ((data_a[ib0 + i].qh[qh_offset + 16] & 0xc0) >> 2)) - 32); tmp[16 * ix + tid] += sum; #else FLOAT_TYPE sum = FLOAT_TYPE(0.0); [[unroll]] for (int l = 0; l < 4; ++l) { - sum += FLOAT_TYPE(data_b[p.b_offset + y_idx + l+ 0]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 0]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + l+ 0] & 0xF) | (((data_a[ib0 + i].qh[qh_offset + l] >> 0) & 3) << 4)) - 32) - + FLOAT_TYPE(data_b[p.b_offset + y_idx + l+32]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 2]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + l+32] & 0xF) | (((data_a[ib0 + i].qh[qh_offset + l] >> 2) & 3) << 4)) - 32) - + FLOAT_TYPE(data_b[p.b_offset + y_idx + l+64]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 4]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + l+ 0] >> 4) | (((data_a[ib0 + i].qh[qh_offset + l] >> 4) & 3) << 4)) - 32) - + FLOAT_TYPE(data_b[p.b_offset + y_idx + l+96]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 6]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + l+32] >> 4) | (((data_a[ib0 + i].qh[qh_offset + l] >> 6) & 3) << 4)) - 32); + sum += FLOAT_TYPE(data_b[b_offset + y_idx + l+ 0]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 0]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + l+ 0] & 0xF) | (((data_a[ib0 + i].qh[qh_offset + l] >> 0) & 3) << 4)) - 32) + + FLOAT_TYPE(data_b[b_offset + y_idx + l+32]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 2]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + l+32] & 0xF) | (((data_a[ib0 + i].qh[qh_offset + l] >> 2) & 3) << 4)) - 32) + + FLOAT_TYPE(data_b[b_offset + y_idx + l+64]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 4]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + l+ 0] >> 4) | (((data_a[ib0 + i].qh[qh_offset + l] >> 4) & 3) << 4)) - 32) + + FLOAT_TYPE(data_b[b_offset + y_idx + l+96]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 6]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + l+32] >> 4) | (((data_a[ib0 + i].qh[qh_offset + l] >> 6) & 3) << 4)) - 32); } tmp[16 * ix + tid] += sum; #endif @@ -1568,7 +1864,7 @@ void main() { barrier(); } if (tid == 0) { - dst[p.d_offset + row] = D_TYPE(tmp[0]); + data_d[d_offset + row] = D_TYPE(tmp[0]); } } """ @@ -1749,12 +2045,13 @@ layout (push_constant) uniform parameter float param1; float param2; } p;""" -generic_unary_op_funcs = """ +generic_unary_op_layout = """ layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; -layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};""" +generic_unary_op_funcs = """ uint src0_idx(uint idx) { const uint i03 = idx / (p.ne02*p.ne01*p.ne00); const uint i03_offset = i03 * p.ne02*p.ne01*p.ne00; @@ -1782,7 +2079,7 @@ void main() { } """ -generic_unary_op_combined = f"{generic_unary_op_head}\n{generic_unary_op_funcs}\n{generic_unary_op_main}" +generic_unary_op_combined = f"{generic_unary_op_head}\n{generic_unary_op_layout}\n{generic_unary_op_funcs}\n{generic_unary_op_main}" generic_binary_op_head = """#version 450 @@ -1798,13 +2095,14 @@ layout (push_constant) uniform parameter float param1; float param2; } p;""" -generic_binary_op_funcs = """ +generic_binary_op_layout = """ layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; -layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; +layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};""" +generic_binary_op_funcs = """ uint src0_idx(uint idx) { const uint i03 = idx / (p.ne02*p.ne01*p.ne00); const uint i03_offset = i03 * p.ne02*p.ne01*p.ne00; @@ -1843,7 +2141,7 @@ void main() { } """ -generic_binary_op_combined = f"{generic_binary_op_head}\n{generic_binary_op_funcs}\n{generic_binary_op_main}" +generic_binary_op_combined = f"{generic_binary_op_head}\n{generic_binary_op_layout}\n{generic_binary_op_funcs}\n{generic_binary_op_main}" # MUL F32 mul_body = """ @@ -1859,7 +2157,7 @@ add_body = """ # SCALE scale_body = """ - data_d[p.d_offset + dst_idx(gl_GlobalInvocationID.x)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(gl_GlobalInvocationID.x)]) + FLOAT_TYPE(p.param1)); + data_d[p.d_offset + dst_idx(gl_GlobalInvocationID.x)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(gl_GlobalInvocationID.x)]) * FLOAT_TYPE(p.param1)); } """ @@ -1934,7 +2232,7 @@ void main() { const uint iybs = i00 - i00%QUANT_K; // dst block start index const uint y_offset = QUANT_R == 1 ? 1 : QUANT_K/2; - DEQUANT_FUNC + vec2 v = dequantize(ib, iqs, 0); data_d[d_offset + iybs + iqs ] = D_TYPE(v.x); data_d[d_offset + iybs + iqs + y_offset] = D_TYPE(v.y); @@ -2033,7 +2331,11 @@ void main() { } const uint i = row*p.ncols + col; - data_d[i] = D_TYPE(data_a[i] - float(uint(col > p.n_past + row % p.rows_per_channel) * 0xFFFFFFFF)); + if (col > p.n_past + row % p.rows_per_channel) { + data_d[i] = D_TYPE(uintBitsToFloat(0xFF800000)); + } else { + data_d[i] = D_TYPE(data_a[i]); + } } """ @@ -2053,8 +2355,6 @@ void main() { const uint row = gl_WorkGroupID.x; const uint tid = gl_LocalInvocationID.x; - const float eps = 1e-5f; - sum[tid] = vec2(0.0f, 0.0f); [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) { @@ -2074,7 +2374,7 @@ void main() { const float mean = sum[0].x / p.KX; const float var = sum[0].y / p.KX - mean * mean; - const float inv_std = inversesqrt(var + 1e-5f); + const float inv_std = inversesqrt(var + p.param1); [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) { data_d[row*p.KX + col] = D_TYPE((float(data_a[row*p.KX + col]) - mean) * inv_std); @@ -2175,7 +2475,7 @@ void main() { vals[tid] = uintBitsToFloat(0xFF800000); [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) { - vals[tid] = max(vals[tid], FLOAT_TYPE(data_a[rowx * p.KX + col]) * p.scale + (p.KY > 0 ? FLOAT_TYPE(data_b[rowy * p.KX + col]) : FLOAT_TYPE(0.0f)) + (p.KZ > 0 ? slope * data_c[col] : 0.0f)); + vals[tid] = max(vals[tid], FLOAT_TYPE(data_a[rowx * p.KX + col]) * p.scale + (p.KY > 0 ? FLOAT_TYPE(data_b[rowy * p.KX + col]) : FLOAT_TYPE(0.0f)) + (p.KZ > 0 ? slope * FLOAT_TYPE(data_c[col]) : 0.0f)); } barrier(); @@ -2545,13 +2845,13 @@ async def main(): stream.clear() stream.extend((mulmat_head, shader_float_type, mulmat_body1, mulmat_load_scalar, mulmat_body2)) tasks.append(string_to_spv("matmul_f32", "".join(stream), {"A_TYPE": "float", "B_TYPE": "float", "D_TYPE": "float"}, fp16)) - tasks.append(string_to_spv("matmul_f32_aligned", "".join(stream), {"LOAD_VEC_A": load_vec, "LOAD_VEC_B": load_vec, "A_TYPE": vec_type, "B_TYPE": vec_type, "D_TYPE": "float"}, fp16)) + tasks.append(string_to_spv("matmul_f32_aligned", "".join(stream), {"LOAD_VEC_A": 1, "LOAD_VEC_B": load_vec, "A_TYPE": "float", "B_TYPE": vec_type, "D_TYPE": "float"}, fp16)) tasks.append(string_to_spv("matmul_f16", "".join(stream), {"A_TYPE": "float16_t", "B_TYPE": "float16_t", "D_TYPE": "float"}, fp16)) - tasks.append(string_to_spv("matmul_f16_aligned", "".join(stream), {"LOAD_VEC_A": load_vec, "LOAD_VEC_B": load_vec, "A_TYPE": vec_type_f16, "B_TYPE": vec_type_f16, "D_TYPE": "float"}, fp16)) + tasks.append(string_to_spv("matmul_f16_aligned", "".join(stream), {"LOAD_VEC_A": 1, "LOAD_VEC_B": load_vec, "A_TYPE": "float16_t", "B_TYPE": vec_type_f16, "D_TYPE": "float"}, fp16)) tasks.append(string_to_spv("matmul_f16_f32", "".join(stream), {"A_TYPE": "float16_t", "B_TYPE": "float", "D_TYPE": "float"}, fp16)) - tasks.append(string_to_spv("matmul_f16_f32_aligned", "".join(stream), {"LOAD_VEC_A": load_vec, "LOAD_VEC_B": load_vec, "A_TYPE": vec_type_f16, "B_TYPE": vec_type, "D_TYPE": "float"}, fp16)) + tasks.append(string_to_spv("matmul_f16_f32_aligned", "".join(stream), {"LOAD_VEC_A": 1, "LOAD_VEC_B": load_vec, "A_TYPE": "float16_t", "B_TYPE": vec_type, "D_TYPE": "float"}, fp16)) stream.clear() stream.extend((mulmat_head, shader_int8_ext, shader_float_type, shader_q4_0_defines, mulmat_body1, mulmat_load_q4_0, mulmat_body2)) @@ -2603,6 +2903,68 @@ async def main(): tasks.append(string_to_spv("matmul_q6_k_f32", "".join(stream), {"LOAD_VEC_A": 2, "A_TYPE": "block_q6_K", "B_TYPE": "float", "D_TYPE": "float"}, fp16)) tasks.append(string_to_spv("matmul_q6_k_f32_aligned", "".join(stream), {"LOAD_VEC_A": 2, "LOAD_VEC_B": load_vec, "A_TYPE": "block_q6_K", "B_TYPE": vec_type, "D_TYPE": "float"}, fp16)) + # MUL_MAT_ID + # stream.clear() + # stream.extend((mulmat_head, shader_float_type, mulmat_body1, mulmat_load_scalar, mulmat_body2)) + # tasks.append(string_to_spv("matmul_id_f32", "".join(stream), {"MUL_MAT_ID": "1", "A_TYPE": "float", "B_TYPE": "float", "D_TYPE": "float"}, fp16)) + # tasks.append(string_to_spv("matmul_id_f32_aligned", "".join(stream), {"MUL_MAT_ID": "1", "LOAD_VEC_A": load_vec, "LOAD_VEC_B": load_vec, "A_TYPE": vec_type, "B_TYPE": vec_type, "D_TYPE": "float"}, fp16)) + + # tasks.append(string_to_spv("matmul_id_f16", "".join(stream), {"MUL_MAT_ID": "1", "A_TYPE": "float16_t", "B_TYPE": "float16_t", "D_TYPE": "float"}, fp16)) + # tasks.append(string_to_spv("matmul_id_f16_aligned", "".join(stream), {"MUL_MAT_ID": "1", "LOAD_VEC_A": load_vec, "LOAD_VEC_B": load_vec, "A_TYPE": vec_type_f16, "B_TYPE": vec_type_f16, "D_TYPE": "float"}, fp16)) + + # tasks.append(string_to_spv("matmul_id_f16_f32", "".join(stream), {"MUL_MAT_ID": "1", "A_TYPE": "float16_t", "B_TYPE": "float", "D_TYPE": "float"}, fp16)) + # tasks.append(string_to_spv("matmul_id_f16_f32_aligned", "".join(stream), {"MUL_MAT_ID": "1", "LOAD_VEC_A": load_vec, "LOAD_VEC_B": load_vec, "A_TYPE": vec_type_f16, "B_TYPE": vec_type, "D_TYPE": "float"}, fp16)) + + # stream.clear() + # stream.extend((mulmat_head, shader_int8_ext, shader_float_type, shader_q4_0_defines, mulmat_body1, mulmat_load_q4_0, mulmat_body2)) + # tasks.append(string_to_spv("matmul_id_q4_0_f32", "".join(stream), {"MUL_MAT_ID": "1", "LOAD_VEC_A": 2, "A_TYPE": "block_q4_0", "B_TYPE": "float", "D_TYPE": "float"}, fp16)) + # tasks.append(string_to_spv("matmul_id_q4_0_f32_aligned", "".join(stream), {"MUL_MAT_ID": "1", "LOAD_VEC_A": 2, "LOAD_VEC_B": load_vec, "A_TYPE": "block_q4_0", "B_TYPE": vec_type, "D_TYPE": "float"}, fp16)) + + # stream.clear() + # stream.extend((mulmat_head, shader_int8_ext, shader_float_type, shader_q4_1_defines, mulmat_body1, mulmat_load_q4_1, mulmat_body2)) + # tasks.append(string_to_spv("matmul_id_q4_1_f32", "".join(stream), {"MUL_MAT_ID": "1", "LOAD_VEC_A": 2, "A_TYPE": "block_q4_1", "B_TYPE": "float", "D_TYPE": "float"}, fp16)) + # tasks.append(string_to_spv("matmul_id_q4_1_f32_aligned", "".join(stream), {"MUL_MAT_ID": "1", "LOAD_VEC_A": 2, "LOAD_VEC_B": load_vec, "A_TYPE": "block_q4_1", "B_TYPE": vec_type, "D_TYPE": "float"}, fp16)) + + # stream.clear() + # stream.extend((mulmat_head, shader_int8_ext, shader_float_type, shader_q5_0_defines, mulmat_body1, mulmat_load_q5_0, mulmat_body2)) + # tasks.append(string_to_spv("matmul_id_q5_0_f32", "".join(stream), {"MUL_MAT_ID": "1", "LOAD_VEC_A": 2, "A_TYPE": "block_q5_0", "B_TYPE": "float", "D_TYPE": "float"}, fp16)) + # tasks.append(string_to_spv("matmul_id_q5_0_f32_aligned", "".join(stream), {"MUL_MAT_ID": "1", "LOAD_VEC_A": 2, "LOAD_VEC_B": load_vec, "A_TYPE": "block_q5_0", "B_TYPE": vec_type, "D_TYPE": "float"}, fp16)) + + # stream.clear() + # stream.extend((mulmat_head, shader_int8_ext, shader_float_type, shader_q5_1_defines, mulmat_body1, mulmat_load_q5_1, mulmat_body2)) + # tasks.append(string_to_spv("matmul_id_q5_1_f32", "".join(stream), {"MUL_MAT_ID": "1", "LOAD_VEC_A": 2, "A_TYPE": "block_q5_1", "B_TYPE": "float", "D_TYPE": "float"}, fp16)) + # tasks.append(string_to_spv("matmul_id_q5_1_f32_aligned", "".join(stream), {"MUL_MAT_ID": "1", "LOAD_VEC_A": 2, "LOAD_VEC_B": load_vec, "A_TYPE": "block_q5_1", "B_TYPE": vec_type, "D_TYPE": "float"}, fp16)) + + # stream.clear() + # stream.extend((mulmat_head, shader_int8_ext, shader_float_type, shader_q8_0_defines, mulmat_body1, mulmat_load_q8_0, mulmat_body2)) + # tasks.append(string_to_spv("matmul_id_q8_0_f32", "".join(stream), {"MUL_MAT_ID": "1", "LOAD_VEC_A": 2, "A_TYPE": "block_q8_0", "B_TYPE": "float", "D_TYPE": "float"}, fp16)) + # tasks.append(string_to_spv("matmul_id_q8_0_f32_aligned", "".join(stream), {"MUL_MAT_ID": "1", "LOAD_VEC_A": 2, "LOAD_VEC_B": load_vec, "A_TYPE": "block_q8_0", "B_TYPE": vec_type, "D_TYPE": "float"}, fp16)) + + # stream.clear() + # stream.extend((mulmat_head, shader_int8_ext, shader_float_type, shader_q2_K_defines, mulmat_body1, mulmat_load_q2_K, mulmat_body2)) + # tasks.append(string_to_spv("matmul_id_q2_k_f32", "".join(stream), {"MUL_MAT_ID": "1", "LOAD_VEC_A": 2, "A_TYPE": "block_q2_K", "B_TYPE": "float", "D_TYPE": "float"}, fp16)) + # tasks.append(string_to_spv("matmul_id_q2_k_f32_aligned", "".join(stream), {"MUL_MAT_ID": "1", "LOAD_VEC_A": 2, "LOAD_VEC_B": load_vec, "A_TYPE": "block_q2_K", "B_TYPE": vec_type, "D_TYPE": "float"}, fp16)) + + # stream.clear() + # stream.extend((mulmat_head, shader_int8_ext, shader_float_type, shader_q3_K_defines, mulmat_body1, mulmat_load_q3_K, mulmat_body2)) + # tasks.append(string_to_spv("matmul_id_q3_k_f32", "".join(stream), {"MUL_MAT_ID": "1", "LOAD_VEC_A": 2, "A_TYPE": "block_q3_K", "B_TYPE": "float", "D_TYPE": "float"}, fp16)) + # tasks.append(string_to_spv("matmul_id_q3_k_f32_aligned", "".join(stream), {"MUL_MAT_ID": "1", "LOAD_VEC_A": 2, "LOAD_VEC_B": load_vec, "A_TYPE": "block_q3_K", "B_TYPE": vec_type, "D_TYPE": "float"}, fp16)) + + # stream.clear() + # stream.extend((mulmat_head, shader_int8_ext, shader_float_type, shader_q4_K_defines, mulmat_body1, mulmat_load_q4_K, mulmat_body2)) + # tasks.append(string_to_spv("matmul_id_q4_k_f32", "".join(stream), {"MUL_MAT_ID": "1", "LOAD_VEC_A": 2, "A_TYPE": "block_q4_K", "B_TYPE": "float", "D_TYPE": "float"}, fp16)) + # tasks.append(string_to_spv("matmul_id_q4_k_f32_aligned", "".join(stream), {"MUL_MAT_ID": "1", "LOAD_VEC_A": 2, "LOAD_VEC_B": load_vec, "A_TYPE": "block_q4_K", "B_TYPE": vec_type, "D_TYPE": "float"}, fp16)) + + # stream.clear() + # stream.extend((mulmat_head, shader_int8_ext, shader_float_type, shader_q5_K_defines, mulmat_body1, mulmat_load_q5_K, mulmat_body2)) + # tasks.append(string_to_spv("matmul_id_q5_k_f32", "".join(stream), {"MUL_MAT_ID": "1", "LOAD_VEC_A": 2, "A_TYPE": "block_q5_K", "B_TYPE": "float", "D_TYPE": "float"}, fp16)) + # tasks.append(string_to_spv("matmul_id_q5_k_f32_aligned", "".join(stream), {"MUL_MAT_ID": "1", "LOAD_VEC_A": 2, "LOAD_VEC_B": load_vec, "A_TYPE": "block_q5_K", "B_TYPE": vec_type, "D_TYPE": "float"}, fp16)) + + # stream.clear() + # stream.extend((mulmat_head, shader_int8_ext, shader_float_type, shader_q6_K_defines, mulmat_body1, mulmat_load_q6_K, mulmat_body2)) + # tasks.append(string_to_spv("matmul_id_q6_k_f32", "".join(stream), {"MUL_MAT_ID": "1", "LOAD_VEC_A": 2, "A_TYPE": "block_q6_K", "B_TYPE": "float", "D_TYPE": "float"}, fp16)) + # tasks.append(string_to_spv("matmul_id_q6_k_f32_aligned", "".join(stream), {"MUL_MAT_ID": "1", "LOAD_VEC_A": 2, "LOAD_VEC_B": load_vec, "A_TYPE": "block_q6_K", "B_TYPE": vec_type, "D_TYPE": "float"}, fp16)) + # Shaders where precision is needed, so no fp16 version # mul mat vec @@ -2611,31 +2973,34 @@ async def main(): stream.extend((mul_mat_vec_head, shader_int8_ext, shader_f32)) if i == GGML_TYPE_F16: - stream.extend((shader_f16_defines, shader_float_dequant_func, mul_mat_vec_body)) + stream.extend((shader_f16_defines, mul_mat_vec_layout, shader_float_dequant_func, mul_mat_vec_body)) elif i == GGML_TYPE_Q4_0: - stream.extend((shader_q4_0_defines, shader_q4_0_dequant_func, mul_mat_vec_body)) + stream.extend((shader_q4_0_defines, mul_mat_vec_layout, shader_q4_0_dequant_func, mul_mat_vec_body)) elif i == GGML_TYPE_Q4_1: - stream.extend((shader_q4_1_defines, shader_q4_1_dequant_func, mul_mat_vec_body)) + stream.extend((shader_q4_1_defines, mul_mat_vec_layout, shader_q4_1_dequant_func, mul_mat_vec_body)) elif i == GGML_TYPE_Q5_0: - stream.extend((shader_q5_0_defines, shader_q5_0_dequant_func, mul_mat_vec_body)) + stream.extend((shader_q5_0_defines, mul_mat_vec_layout, shader_q5_0_dequant_func, mul_mat_vec_body)) elif i == GGML_TYPE_Q5_1: - stream.extend((shader_q5_1_defines, shader_q5_1_dequant_func, mul_mat_vec_body)) + stream.extend((shader_q5_1_defines, mul_mat_vec_layout, shader_q5_1_dequant_func, mul_mat_vec_body)) elif i == GGML_TYPE_Q8_0: - stream.extend((shader_q8_0_defines, shader_q8_0_dequant_func, mul_mat_vec_body)) + stream.extend((shader_q8_0_defines, mul_mat_vec_layout, shader_q8_0_dequant_func, mul_mat_vec_body)) elif i == GGML_TYPE_Q2_K: - stream.extend((shader_q2_K_defines, mul_mat_vec_q2_K_body)) + stream.extend((shader_q2_K_defines, mul_mat_vec_layout, mul_mat_vec_q2_K_body)) elif i == GGML_TYPE_Q3_K: - stream.extend((shader_q3_K_defines, mul_mat_vec_q3_K_body)) + stream.extend((shader_q3_K_defines, mul_mat_vec_layout, mul_mat_vec_q3_K_body)) elif i == GGML_TYPE_Q4_K: - stream.extend((shader_q4_K_defines, mul_mat_vec_q4_K_body)) + stream.extend((shader_q4_K_defines, mul_mat_vec_layout, mul_mat_vec_q4_K_body)) elif i == GGML_TYPE_Q5_K: - stream.extend((shader_q5_K_defines, mul_mat_vec_q5_K_body)) + stream.extend((shader_q5_K_defines, mul_mat_vec_layout, mul_mat_vec_q5_K_body)) elif i == GGML_TYPE_Q6_K: - stream.extend((shader_q6_K_defines, mul_mat_vec_q6_K_body)) + stream.extend((shader_q6_K_defines, mul_mat_vec_layout, mul_mat_vec_q6_K_body)) else: continue - tasks.append(string_to_spv(f"mul_mat_vec_{type_names[i]}_f32", "".join(stream), {"B_TYPE": "float", "D_TYPE": "float", "K_QUANTS_PER_ITERATION": K_QUANTS_PER_ITERATION})) + tasks.append(string_to_spv(f"mul_mat_vec_{type_names[i]}_f32_f32", "".join(stream), {"B_TYPE": "float", "D_TYPE": "float", "K_QUANTS_PER_ITERATION": K_QUANTS_PER_ITERATION})) + tasks.append(string_to_spv(f"mul_mat_vec_{type_names[i]}_f16_f32", "".join(stream), {"B_TYPE": "float16_t", "D_TYPE": "float", "K_QUANTS_PER_ITERATION": K_QUANTS_PER_ITERATION})) + + # tasks.append(string_to_spv(f"mul_mat_vec_id_{type_names[i]}_f32", "".join(stream), {"MUL_MAT_ID": "1", "B_TYPE": "float", "D_TYPE": "float", "K_QUANTS_PER_ITERATION": K_QUANTS_PER_ITERATION})) # Dequant shaders for i in range(0, VK_NUM_TYPES): @@ -2677,20 +3042,20 @@ async def main(): optimization_workaround = False if i == GGML_TYPE_F32: - stream.extend((shader_f32_defines, generic_binary_op_funcs, get_rows_float_body)) + stream.extend((shader_f32_defines, generic_binary_op_layout, generic_binary_op_funcs, get_rows_float_body)) elif i == GGML_TYPE_F16: - stream.extend((shader_f16_defines, generic_binary_op_funcs, get_rows_float_body)) + stream.extend((shader_f16_defines, generic_binary_op_layout, generic_binary_op_funcs, get_rows_float_body)) optimization_workaround = True elif i == GGML_TYPE_Q4_0: - stream.extend((shader_q4_0_defines, shader_q4_0_dequant_func, generic_binary_op_funcs, get_rows_body)) + stream.extend((shader_q4_0_defines, generic_binary_op_layout, shader_q4_0_dequant_func, generic_binary_op_funcs, get_rows_body)) elif i == GGML_TYPE_Q4_1: - stream.extend((shader_q4_1_defines, shader_q4_1_dequant_func, generic_binary_op_funcs, get_rows_body)) + stream.extend((shader_q4_1_defines, generic_binary_op_layout, shader_q4_1_dequant_func, generic_binary_op_funcs, get_rows_body)) elif i == GGML_TYPE_Q5_0: - stream.extend((shader_q5_0_defines, shader_q5_0_dequant_func, generic_binary_op_funcs, get_rows_body)) + stream.extend((shader_q5_0_defines, generic_binary_op_layout, shader_q5_0_dequant_func, generic_binary_op_funcs, get_rows_body)) elif i == GGML_TYPE_Q5_1: - stream.extend((shader_q5_1_defines, shader_q5_1_dequant_func, generic_binary_op_funcs, get_rows_body)) + stream.extend((shader_q5_1_defines, generic_binary_op_layout, shader_q5_1_dequant_func, generic_binary_op_funcs, get_rows_body)) elif i == GGML_TYPE_Q8_0: - stream.extend((shader_q8_0_defines, shader_q8_0_dequant_func, generic_binary_op_funcs, get_rows_body)) + stream.extend((shader_q8_0_defines, generic_binary_op_layout, shader_q8_0_dequant_func, generic_binary_op_funcs, get_rows_body)) else: continue @@ -2729,6 +3094,7 @@ async def main(): tasks.append(string_to_spv("diag_mask_inf_f32", f"{diag_mask_inf_head}\n{shader_f32}\n{diag_mask_inf_body}", {"A_TYPE": "float", "D_TYPE": "float"})) tasks.append(string_to_spv("soft_max_f32", f"{soft_max_head}\n{shader_f32}\n{soft_max_body}", {"A_TYPE": "float", "B_TYPE": "float", "C_TYPE": "float", "D_TYPE": "float"})) + tasks.append(string_to_spv("soft_max_f32_f16", f"{soft_max_head}\n{shader_f32}\n{soft_max_body}", {"A_TYPE": "float", "B_TYPE": "float16_t", "C_TYPE": "float16_t", "D_TYPE": "float"})) tasks.append(string_to_spv("rope_f32", rope_src, {"A_TYPE": "float", "D_TYPE": "float"})) tasks.append(string_to_spv("rope_f16", rope_src, {"A_TYPE": "float16_t", "D_TYPE": "float16_t"})) |