diff options
author | Kawrakow <iwankawrakow@gmail.com> | 2025-07-15 08:03:13 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2025-07-15 08:03:13 +0200 |
commit | 2081b3fccb9923699bf4d5e926d8719fc1d12c39 (patch) | |
tree | 61b3665214941b4857466fdea8220159d81a609e /ggml/src/vulkan-shaders/flash_attn.comp | |
parent | 45fae1a14444622478774f9a417e1d417af1ca46 (diff) |
Vulkan: a fresh start (#608)
* It compiles
* Seems to be working with coopmat
* Vulkan needs f32 precision for flash attention
* Vulkan: fix u_batch > 4096/n_active_experts
for coopmat1. Without this fix we get an assert.
We get the same assert in mainline too.
---------
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
Diffstat (limited to 'ggml/src/vulkan-shaders/flash_attn.comp')
-rw-r--r-- | ggml/src/vulkan-shaders/flash_attn.comp | 8 |
1 files changed, 6 insertions, 2 deletions
diff --git a/ggml/src/vulkan-shaders/flash_attn.comp b/ggml/src/vulkan-shaders/flash_attn.comp index 454b3411..45c6e773 100644 --- a/ggml/src/vulkan-shaders/flash_attn.comp +++ b/ggml/src/vulkan-shaders/flash_attn.comp @@ -100,6 +100,10 @@ void main() { uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / 2; uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2; #endif + uint32_t m_offset = 0; + if (p.nem2 != 1 || p.nem3 != 1) { + m_offset = ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV; + } [[dont_unroll]] for (uint32_t j = start_j; j < end_j; ++j) { @@ -145,13 +149,13 @@ void main() { } } - if (p.mask != 0) { + if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) { [[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) { uint32_t c = (idx + tid) % Bc; uint32_t r = (idx + tid) / Bc; if (idx + tid < Bc * Br) { - masksh[c][r] = float(data_m[(i * Br + r) * m_stride + (j * Bc + c)]); + masksh[c][r] = float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]); } } barrier(); |