layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; layout (constant_id = 0) const uint32_t WorkGroupSize = 128; layout (constant_id = 1) const uint32_t Br = 1; layout (constant_id = 2) const uint32_t Bc = 32; layout (constant_id = 3) const uint32_t HSK = 32; layout (constant_id = 4) const uint32_t HSV = 32; layout (constant_id = 5) const uint32_t Clamp = 0; layout (constant_id = 6) const uint32_t D_split = 16; layout (push_constant) uniform parameter { uint32_t N; uint32_t KV; uint32_t ne1; uint32_t ne2; uint32_t ne3; uint32_t neq2; uint32_t neq3; uint32_t nek2; uint32_t nek3; uint32_t nev2; uint32_t nev3; uint32_t nem1; uint32_t nem2; uint32_t nem3; uint32_t nb01; uint32_t nb02; uint32_t nb03; uint32_t nb11; uint32_t nb12; uint32_t nb13; uint32_t nb21; uint32_t nb22; uint32_t nb23; float scale; float max_bias; float logit_softcap; uint32_t mask_n_head_log2; float m0; float m1; uint32_t gqa_ratio; uint32_t split_kv; uint32_t k_num; } p; #define MASK_ENABLE_BIT (1<<16) #define N_LOG2_MASK 0xFFFF layout (binding = 4) writeonly buffer O {D_TYPE data_o[];}; #if defined(A_TYPE_PACKED16) #define BINDING_IDX_K 0 #define BINDING_IDX_V 1 layout (binding = 1) readonly buffer KV_PACKED16 {A_TYPE_PACKED16 data_packed16[];} kv_packed[2]; #endif #if defined(DATA_A_Q4_0) #define BLOCK_BYTE_SIZE 18 vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { uint vui_lo = uint(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]); uint vui_hi = uint(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]); uint shift = (iqs & 0x10) >> 2; vui_lo >>= shift; vui_hi >>= shift; return float(kv_packed[binding_idx].data_packed16[a_offset + ib].d) * (vec4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - 8.0f); } #endif #if defined(DATA_A_Q8_0) #define BLOCK_BYTE_SIZE 34 vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { const i8vec2 v0 = unpack8(int32_t(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[iqs / 2])).xy; // vec4 used due to #12147 const i8vec2 v1 = unpack8(int32_t(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[iqs / 2 + 1])).xy; return float(kv_packed[binding_idx].data_packed16[a_offset + ib].d) * vec4(v0.x, v0.y, v1.x, v1.y); } #endif #define CEIL_DIV(a, b) (((a) + (b) - 1) / (b)) // Store column zero. This is used to save per-row m and L values for split_k. ACC_TYPE perElemOpStoreCol0(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N) { if (r < N && c == 0) { uint32_t offset = iq2 + r; data_o[o_offset + offset] = D_TYPE(elem); } return elem; } // Load the slope matrix, indexed by Q's dimension 2. ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t iq2) { const uint32_t h = iq2 + (r % p.gqa_ratio); uint32_t n_head_log2 = p.mask_n_head_log2 & N_LOG2_MASK; const ACC_TYPE base = ACC_TYPE(h < n_head_log2 ? p.m0 : p.m1); const int exph = int(h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1); return ACC_TYPE(pow(base, ACC_TYPE(exph))); } uint32_t i, N, KV, split_k_index, Tr, start_j, end_j, iq2, iq3, rk2, rk3, rv2, rv3, ik2, ik3, iv2, iv3, q_stride, k_stride, v_stride, m_stride; void init_indices() { N = p.N; KV = p.KV; i = gl_WorkGroupID.x; split_k_index = 0; if (p.k_num > 1) { i = 0; split_k_index = gl_WorkGroupID.x; } Tr = CEIL_DIV(N, Br); start_j = split_k_index * p.split_kv / Bc; end_j = CEIL_DIV(min(KV, (split_k_index + 1) * p.split_kv), Bc); // When not using grouped query attention, all rows share the same iq2, equal to gl_WorkGroupID.y. // When using grouped query attention, each workgroup does gqa_ratio consecutive values of iq2. iq2 = gl_WorkGroupID.y * p.gqa_ratio; iq3 = gl_WorkGroupID.z; // broadcast factors rk2 = p.neq2/p.nek2; rk3 = p.neq3/p.nek3; rv2 = p.neq2/p.nev2; rv3 = p.neq3/p.nev3; // k indices ik3 = iq3 / rk3; ik2 = iq2 / rk2; // v indices iv3 = iq3 / rv3; iv2 = iq2 / rv2; // nb?1 are already divided by the type size and are in units of elements. // When using grouped query attention, Q is indexed by iq2, so the stride // should be nb02 (which is in bytes). q_stride = p.gqa_ratio > 1 ? (p.nb02 / 4) : p.nb01; k_stride = p.nb11; v_stride = p.nb21; // When using grouped query attention, all rows use the same mask (stride 0). // "p.gqa_ratio >> 16" is just a roundabout way of writing zero // that prevents the compiler from folding the "&" through the select // and breaking the alignment detection. m_stride = (p.gqa_ratio > 1) ? (p.gqa_ratio >> 16) : KV; }