diff options
Diffstat (limited to 'ggml/src/vulkan-shaders/flash_attn_cm2.comp')
-rw-r--r-- | ggml/src/vulkan-shaders/flash_attn_cm2.comp | 9 |
1 files changed, 7 insertions, 2 deletions
diff --git a/ggml/src/vulkan-shaders/flash_attn_cm2.comp b/ggml/src/vulkan-shaders/flash_attn_cm2.comp index 91caa184..274f48fc 100644 --- a/ggml/src/vulkan-shaders/flash_attn_cm2.comp +++ b/ggml/src/vulkan-shaders/flash_attn_cm2.comp @@ -130,6 +130,11 @@ void main() { coopMatPerElementNV(slopeMat, slopeMat, perElemOpComputeSlope, iq2); } + uint32_t m_offset = 0; + if (p.nem2 != 1 || p.nem3 != 1) { + m_offset = ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV * 2 /*sizeof(float16_t)*/; + } + [[dont_unroll]] for (uint32_t j = start_j; j < end_j; ++j) { @@ -148,14 +153,14 @@ void main() { } } - if (p.mask != 0) { + if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) { tensorLayoutNV<2, Clamp> tensorLayoutM = createTensorLayoutNV(2, Clamp); tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, p.nem1, KV); tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1); coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv; - coopMatLoadTensorNV(mv, data_m, 0, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc)); + coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc)); S += slopeMat*coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(mv); } |