summary | refs | log | tree | commit | diff
path: root/ggml/src/vulkan-shaders/flash_attn.comp
diff options
context:
space:
mode:
Diffstat (limited to 'ggml/src/vulkan-shaders/flash_attn.comp')
-rw-r--r-- ggml/src/vulkan-shaders/flash_attn.comp | 8
1 files changed, 6 insertions, 2 deletions
diff --git a/ggml/src/vulkan-shaders/flash_attn.comp b/ggml/src/vulkan-shaders/flash_attn.comp
index 454b3411..45c6e773 100644
--- a/ggml/src/vulkan-shaders/flash_attn.comp
+++ b/ggml/src/vulkan-shaders/flash_attn.comp
@@ -100,6 +100,10 @@ void main() {
uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / 2;
uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2;
#endif
+ uint32_t m_offset = 0;
+ if (p.nem2 != 1 || p.nem3 != 1) {
+ m_offset = ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV;
+ }
[[dont_unroll]]
for (uint32_t j = start_j; j < end_j; ++j) {
@@ -145,13 +149,13 @@ void main() {
}
}
- if (p.mask != 0) {
+ if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
[[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
uint32_t c = (idx + tid) % Bc;
uint32_t r = (idx + tid) / Bc;
if (idx + tid < Bc * Br) {
- masksh[c][r] = float(data_m[(i * Br + r) * m_stride + (j * Bc + c)]);
+ masksh[c][r] = float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]);
}
}
barrier();