summary | refs | log | tree | commit | diff
path: root/ggml/src/vulkan-shaders/flash_attn_cm1.comp
diff options
context:
space:
mode:
Diffstat (limited to 'ggml/src/vulkan-shaders/flash_attn_cm1.comp')
-rw-r--r-- ggml/src/vulkan-shaders/flash_attn_cm1.comp | 8
1 files changed, 6 insertions, 2 deletions
diff --git a/ggml/src/vulkan-shaders/flash_attn_cm1.comp b/ggml/src/vulkan-shaders/flash_attn_cm1.comp
index ad7594fe..486735fe 100644
--- a/ggml/src/vulkan-shaders/flash_attn_cm1.comp
+++ b/ggml/src/vulkan-shaders/flash_attn_cm1.comp
@@ -125,6 +125,10 @@ void main() {
uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / 2;
uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2;
#endif
+ uint32_t m_offset = 0;
+ if (p.nem2 != 1 || p.nem3 != 1) {
+ m_offset = ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV;
+ }
[[dont_unroll]]
for (uint32_t j = start_j; j < end_j; ++j) {
@@ -178,12 +182,12 @@ void main() {
barrier();
}
- if (p.mask != 0) {
+ if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
[[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
uint32_t c = (idx + tid) % Bc;
uint32_t r = (idx + tid) / Bc;
if (idx + tid < Bc * Br || idx + gl_WorkGroupSize.x <= Bc * Br) {
- sfsh[c * sfshstride + r] += ACC_TYPE(slope[r] * float(data_m[(i * Br + r) * m_stride + (j * Bc + c)]));
+ sfsh[c * sfshstride + r] += ACC_TYPE(slope[r] * float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]));
}
}
barrier();