diff options
Diffstat (limited to 'ggml/src/vulkan-shaders/flash_attn.comp')
-rw-r--r-- | ggml/src/vulkan-shaders/flash_attn.comp | 8 |
1 files changed, 6 insertions, 2 deletions
diff --git a/ggml/src/vulkan-shaders/flash_attn.comp b/ggml/src/vulkan-shaders/flash_attn.comp index 454b3411..45c6e773 100644 --- a/ggml/src/vulkan-shaders/flash_attn.comp +++ b/ggml/src/vulkan-shaders/flash_attn.comp @@ -100,6 +100,10 @@ void main() { uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / 2; uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2; #endif + uint32_t m_offset = 0; + if (p.nem2 != 1 || p.nem3 != 1) { + m_offset = ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV; + } [[dont_unroll]] for (uint32_t j = start_j; j < end_j; ++j) { @@ -145,13 +149,13 @@ void main() { } } - if (p.mask != 0) { + if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) { [[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) { uint32_t c = (idx + tid) % Bc; uint32_t r = (idx + tid) / Bc; if (idx + tid < Bc * Br) { - masksh[c][r] = float(data_m[(i * Br + r) * m_stride + (j * Bc + c)]); + masksh[c][r] = float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]); } } barrier(); |