diff options
Diffstat (limited to 'ggml/src/vulkan-shaders/flash_attn_cm1.comp')
-rw-r--r-- | ggml/src/vulkan-shaders/flash_attn_cm1.comp | 8 |
1 files changed, 6 insertions, 2 deletions
diff --git a/ggml/src/vulkan-shaders/flash_attn_cm1.comp b/ggml/src/vulkan-shaders/flash_attn_cm1.comp index ad7594fe..486735fe 100644 --- a/ggml/src/vulkan-shaders/flash_attn_cm1.comp +++ b/ggml/src/vulkan-shaders/flash_attn_cm1.comp @@ -125,6 +125,10 @@ void main() { uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / 2; uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2; #endif + uint32_t m_offset = 0; + if (p.nem2 != 1 || p.nem3 != 1) { + m_offset = ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV; + } [[dont_unroll]] for (uint32_t j = start_j; j < end_j; ++j) { @@ -178,12 +182,12 @@ void main() { barrier(); } - if (p.mask != 0) { + if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) { [[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) { uint32_t c = (idx + tid) % Bc; uint32_t r = (idx + tid) / Bc; if (idx + tid < Bc * Br || idx + gl_WorkGroupSize.x <= Bc * Br) { - sfsh[c * sfshstride + r] += ACC_TYPE(slope[r] * float(data_m[(i * Br + r) * m_stride + (j * Bc + c)])); + sfsh[c * sfshstride + r] += ACC_TYPE(slope[r] * float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)])); } } barrier(); |