summary | refs | log | tree | commit | diff
path: root/ggml/src/vulkan-shaders/flash_attn_cm2.comp
diff options
context:
space:
mode:
Diffstat (limited to 'ggml/src/vulkan-shaders/flash_attn_cm2.comp')
-rw-r--r-- ggml/src/vulkan-shaders/flash_attn_cm2.comp | 9
1 files changed, 7 insertions, 2 deletions
diff --git a/ggml/src/vulkan-shaders/flash_attn_cm2.comp b/ggml/src/vulkan-shaders/flash_attn_cm2.comp
index 91caa184..274f48fc 100644
--- a/ggml/src/vulkan-shaders/flash_attn_cm2.comp
+++ b/ggml/src/vulkan-shaders/flash_attn_cm2.comp
@@ -130,6 +130,11 @@ void main() {
coopMatPerElementNV(slopeMat, slopeMat, perElemOpComputeSlope, iq2);
}
+ uint32_t m_offset = 0;
+ if (p.nem2 != 1 || p.nem3 != 1) {
+ m_offset = ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV * 2 /*sizeof(float16_t)*/;
+ }
+
[[dont_unroll]]
for (uint32_t j = start_j; j < end_j; ++j) {
@@ -148,14 +153,14 @@ void main() {
}
}
- if (p.mask != 0) {
+ if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
tensorLayoutNV<2, Clamp> tensorLayoutM = createTensorLayoutNV(2, Clamp);
tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, p.nem1, KV);
tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);
coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv;
- coopMatLoadTensorNV(mv, data_m, 0, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
+ coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
S += slopeMat*coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(mv);
}