#version 450

#extension GL_EXT_control_flow_attributes : enable
#extension GL_EXT_shader_16bit_storage : require
#if USE_SUBGROUP_ADD
#extension GL_KHR_shader_subgroup_arithmetic : enable
#endif

#define FLOAT_TYPE float

layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;

layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
layout (binding = 2) writeonly buffer D {D_TYPE dst[];};

// vec4 views aliasing the same A and B bindings, used for aligned vector loads
layout (binding = 0) readonly buffer AV4 {A_TYPE_VEC4 data_a_v4[];};
layout (binding = 1) readonly buffer BV4 {B_TYPE_VEC4 data_b_v4[];};

layout(constant_id = 0) const int BLOCK_SIZE = 32;
// gqa_ratio is in the range [1,8]
layout(constant_id = 1) const uint gqa_ratio = 1;

layout (push_constant) uniform parameter
{
    uint ncols_x;
    uint nrows_x;
    uint nchannels_x;
    uint nchannels_y;
    uint b_offset;
    uint d_offset;
} p;

#if !USE_SUBGROUP_ADD
shared FLOAT_TYPE tmp[8][BLOCK_SIZE];
#endif

void main() {
    const uint tid   = gl_LocalInvocationID.x;
    const uint row_x = gl_GlobalInvocationID.y;

    uint channel, channel_x;

    // When gqa_ratio > 1, each invocation does multiple rows.
    // The row in the A matrix is channel / gqa_ratio, and the
    // rows in the B matrix are [channel, channel + gqa_ratio).
    // When gqa_ratio is 1, each invocation does one row.
    if (gqa_ratio > 1) {
        channel_x = gl_GlobalInvocationID.z;
        channel   = channel_x * gqa_ratio;
    } else {
        channel   = gl_GlobalInvocationID.z;
        channel_x = channel / (p.nchannels_y / p.nchannels_x);
    }

    const uint nrows_y   = p.ncols_x;
    const uint nrows_dst = p.nrows_x;
    const uint row_dst   = row_x;

    FLOAT_TYPE temp[8];
    [[unroll]] for (uint i = 0; i < 8; ++i) {
        temp[i] = FLOAT_TYPE(0.0f);
    }

    // Detect alignment for vector loads
    bool is_aligned = (p.ncols_x % 4) == 0 && (p.nchannels_x % 4) == 0 && (nrows_y % 4) == 0;

    for (uint col_x0 = 0; col_x0 < p.ncols_x; col_x0 += BLOCK_SIZE) {

        // Use vec4 loads if aligned
        if (col_x0 + 4*BLOCK_SIZE <= p.ncols_x && is_aligned) {
            uint col_x = col_x0 + 4*tid;
            const uint row_y = col_x;

            // x is transposed and permuted
            const uint ix = row_x*p.nchannels_x*p.ncols_x + channel_x*p.ncols_x + col_x;
            const vec4 av4 = vec4(data_a_v4[ix / 4]);

            [[unroll]] for (uint c = 0; c < gqa_ratio; ++c) {
                // y is not transposed but permuted
                const uint iy = (channel + c)*nrows_y + row_y;

                vec4 bv4 = data_b_v4[iy / 4];
                temp[c] += dot(av4, bv4);
            }

            // This iteration consumed 4*BLOCK_SIZE columns; the loop
            // increment adds the remaining BLOCK_SIZE.
            col_x0 += 3*BLOCK_SIZE;
        } else {
            const uint col_x = col_x0 + tid;
            if (col_x >= p.ncols_x) {
                break;
            }

            // x is transposed and permuted
            const uint ix = row_x*p.nchannels_x*p.ncols_x + channel_x*p.ncols_x + col_x;
            const FLOAT_TYPE xi = FLOAT_TYPE(data_a[ix]);

            const uint row_y = col_x;

            [[unroll]] for (uint c = 0; c < gqa_ratio; ++c) {
                // y is not transposed but permuted
                const uint iy = (channel + c)*nrows_y + row_y;

                temp[c] = fma(xi, FLOAT_TYPE(data_b[iy]), temp[c]);
            }
        }
    }

#if USE_SUBGROUP_ADD
    // reduce vec4 at a time
    vec4 t = vec4(temp[0], temp[1], temp[2], temp[3]);
    t = subgroupAdd(t);
    temp[0] = t[0];
    temp[1] = t[1];
    temp[2] = t[2];
    temp[3] = t[3];
    if (gqa_ratio > 4) {
        t = vec4(temp[4], temp[5], temp[6], temp[7]);
        t = subgroupAdd(t);
        temp[4] = t[0];
        temp[5] = t[1];
        temp[6] = t[2];
        temp[7] = t[3];
    }
#else
    [[unroll]] for (uint c = 0; c < gqa_ratio; ++c) {
        tmp[c][tid] = temp[c];
    }

    // sum up partial sums and write back result
    barrier();
    [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
        if (tid < s) {
            [[unroll]] for (uint c = 0; c < gqa_ratio; ++c) {
                temp[c] += tmp[c][tid + s];
                tmp[c][tid] = temp[c];
            }
        }
        barrier();
    }
    [[unroll]] for (uint c = 0; c < gqa_ratio; ++c) {
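        // Read the reduced sum back into registers; only tid 0's copy is
        // used for the write-out below.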
        temp[c] = tmp[c][tid];
    }
#endif

    if (tid == 0) {
        [[unroll]] for (uint c = 0; c < gqa_ratio; ++c) {
            // dst is not transposed and not permuted
            const uint idst = (channel + c)*nrows_dst + row_dst;

            dst[idst] = temp[c];
        }
    }
}
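
// Dispatch geometry implied by the indexing above (a sketch inferred from
// this shader, not a statement about the host code):
//   groups.x = 1                (BLOCK_SIZE lanes cooperate on one output row)
//   groups.y = p.nrows_x        (row_x = gl_GlobalInvocationID.y)
//   groups.z = p.nchannels_x when gqa_ratio > 1, else p.nchannels_y
// When USE_SUBGROUP_ADD is set, the workgroup must fit in a single subgroup
// (BLOCK_SIZE <= gl_SubgroupSize), since subgroupAdd() only reduces across one
// subgroup; otherwise the shared-memory tree reduction above is required.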