Diffstat (limited to 'ggml/src/vulkan-shaders/flash_attn.comp')
-rw-r--r--  ggml/src/vulkan-shaders/flash_attn.comp  43
1 file changed, 22 insertions(+), 21 deletions(-)
diff --git a/ggml/src/vulkan-shaders/flash_attn.comp b/ggml/src/vulkan-shaders/flash_attn.comp
index ce230a8f..454b3411 100644
--- a/ggml/src/vulkan-shaders/flash_attn.comp
+++ b/ggml/src/vulkan-shaders/flash_attn.comp
@@ -11,7 +11,8 @@
#include "types.comp"
#include "flash_attn_base.comp"
-const uint32_t D_per_thread = D / D_split;
+const uint32_t HSK_per_thread = HSK / D_split;
+const uint32_t HSV_per_thread = HSV / D_split;
const uint32_t cols_per_iter = WorkGroupSize / D_split;
const uint32_t cols_per_thread = Bc / cols_per_iter;
@@ -29,7 +30,7 @@ layout (binding = 3) readonly buffer M {float16_t data_m[];};
// Rows index by Q's dimension 2, and the first N rows are valid.
D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
{
- uint32_t offset = (iq2 + r) * D + c;
+ uint32_t offset = (iq2 + r) * HSV + c;
data_o[o_offset + offset] = D_TYPE(elem);
return elem;
}
@@ -38,7 +39,7 @@ shared FLOAT_TYPE tmpsh[WorkGroupSize];
shared vec4 tmpshv4[WorkGroupSize];
shared float masksh[Bc][Br];
-shared vec4 Qf[Br][D / 4];
+shared vec4 Qf[Br][HSK / 4];
void main() {
#ifdef NEEDS_INIT_IQ_SHMEM
@@ -53,18 +54,18 @@ void main() {
uint32_t q_offset = (iq2*p.nb02+iq3*p.nb03) / 4;
- [[unroll]] for (uint32_t idx = 0; idx < Br * D / 4; idx += gl_WorkGroupSize.x) {
- uint32_t d = (idx + tid) % (D / 4);
- uint32_t r = (idx + tid) / (D / 4);
- if (r < Br && d < D / 4 &&
+ [[unroll]] for (uint32_t idx = 0; idx < Br * HSK / 4; idx += gl_WorkGroupSize.x) {
+ uint32_t d = (idx + tid) % (HSK / 4);
+ uint32_t r = (idx + tid) / (HSK / 4);
+ if (r < Br && d < HSK / 4 &&
i * Br + r < N) {
Qf[r][d] = vec4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d]) * p.scale;
}
}
barrier();
- vec4 Of[Br][D_per_thread / 4];
- [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
+ vec4 Of[Br][HSV_per_thread / 4];
+ [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
Of[r][d] = vec4(0.0);
}
@@ -112,7 +113,7 @@ void main() {
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
- [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
+ [[unroll]] for (uint32_t d = 0; d < HSK_per_thread / 4; ++d) {
#if BLOCK_SIZE > 1
uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
uint ib = coord / BLOCK_SIZE;
@@ -191,14 +192,14 @@ void main() {
Lf[r] = eMf[r]*Lf[r] + rowsumf[r];
}
- [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
+ [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
Of[r][d] = eMf[r] * Of[r][d];
}
}
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
- [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
+ [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
#if BLOCK_SIZE > 1
uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
uint ib = coord / BLOCK_SIZE;
@@ -255,7 +256,7 @@ void main() {
Lf[r] = tmpsh[d_tid];
barrier();
- [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
+ [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
Of[r][d] = eMf * Of[r][d];
tmpshv4[tid] = Of[r][d];
@@ -277,11 +278,11 @@ void main() {
// If there is split_k, then the split_k resolve shader does the final
// division by L. Store the intermediate O value and per-row m and L values.
if (p.k_num > 1) {
- uint32_t o_offset = D * p.ne1 * split_k_index;
+ uint32_t o_offset = HSV * p.ne1 * (split_k_index + iq3 * p.k_num);
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
if (r < N) {
- [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
+ [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
[[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
perElemOpGqaStore(r, 4*(d * D_split + d_tid) + comp, Of[r][d][comp], o_offset, iq2, N);
}
@@ -289,7 +290,7 @@ void main() {
}
}
- o_offset = D * p.ne1 * p.k_num + p.ne1 * split_k_index * 2;
+ o_offset = HSV * p.ne1 * p.ne3 * p.k_num + p.ne1 * (split_k_index + iq3 * p.k_num) * 2;
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
if (r < N) {
perElemOpStoreCol0(r, 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N);
@@ -305,18 +306,18 @@ void main() {
Lfrcp[r] = 1.0 / Lf[r];
}
- [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
+ [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
Of[r][d] *= Lfrcp[r];
}
}
- uint32_t o_offset = iq3*p.ne2*p.ne1;
+ uint32_t o_offset = iq3*p.ne2*p.ne1*HSV;
if (p.gqa_ratio > 1) {
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
if (r < N) {
- [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
+ [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
[[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
perElemOpGqaStore(r, 4*(d * D_split + d_tid) + comp, Of[r][d][comp], o_offset, iq2, N);
}
@@ -326,9 +327,9 @@ void main() {
} else {
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
if (i * Br + r < N) {
- [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
+ [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
[[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
- data_o[o_offset + iq2 * D + (i * Br + r) * p.ne1 * D + 4*(d * D_split + d_tid) + comp] = D_TYPE(Of[r][d][comp]);
+ data_o[o_offset + iq2 * HSV + (i * Br + r) * p.ne1 * HSV + 4*(d * D_split + d_tid) + comp] = D_TYPE(Of[r][d][comp]);
}
}
}
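
The per-thread work split introduced at the top of the patch is plain integer arithmetic over the shader's compile-time constants. A minimal C sketch of that arithmetic, using example values (the HSK, HSV, D_split, WorkGroupSize and Bc numbers below are illustrative assumptions, not taken from the patch; in the shader they are constants set at pipeline creation):

```c
#include <stdint.h>
#include <stdio.h>

int main(void) {
    /* Example values only; the real ones are shader constants. */
    const uint32_t HSK = 192, HSV = 128;     /* K and V head sizes, now allowed to differ */
    const uint32_t D_split = 16;             /* threads cooperating on one column          */
    const uint32_t WorkGroupSize = 128, Bc = 32;

    /* Same arithmetic as the patched constants in flash_attn.comp. */
    const uint32_t HSK_per_thread  = HSK / D_split;           /* 12 */
    const uint32_t HSV_per_thread  = HSV / D_split;           /* 8  */
    const uint32_t cols_per_iter   = WorkGroupSize / D_split; /* 8  */
    const uint32_t cols_per_thread = Bc / cols_per_iter;      /* 4  */

    printf("HSK_per_thread=%u HSV_per_thread=%u cols_per_iter=%u cols_per_thread=%u\n",
           HSK_per_thread, HSV_per_thread, cols_per_iter, cols_per_thread);
    return 0;
}
```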
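One way to read the new split_k addressing: each (iq3, split_k_index) pair now owns its own HSV * ne1 block of intermediate O values, and the per-row L/M pairs for all splits are packed after the last O block. A small C sketch of that layout, mirroring the offset expressions from the patch; the concrete sizes are made-up examples, not values from the patch:

```c
#include <stdint.h>
#include <stdio.h>

int main(void) {
    const uint32_t HSV   = 128;  /* V head size                       */
    const uint32_t ne1   = 64;   /* rows of O per batch               */
    const uint32_t ne3   = 2;    /* outer batch dimension (iq3 range) */
    const uint32_t k_num = 4;    /* number of split_k partitions      */

    for (uint32_t iq3 = 0; iq3 < ne3; ++iq3) {
        for (uint32_t split = 0; split < k_num; ++split) {
            /* Intermediate O block for this (iq3, split) pair. */
            uint32_t o_offset = HSV * ne1 * (split + iq3 * k_num);
            /* Per-row L and M pairs live after all O blocks. */
            uint32_t lm_offset = HSV * ne1 * ne3 * k_num
                               + ne1 * (split + iq3 * k_num) * 2;
            printf("iq3=%u split=%u  O at %u  L/M at %u\n",
                   iq3, split, o_offset, lm_offset);
        }
    }
    return 0;
}
```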
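Similarly, the final-store path strides by HSV rather than the old shared D. A hedged C sketch of the non-GQA index computation in the last hunk, again with example sizes and indices that are assumptions rather than values from the patch:

```c
#include <stdint.h>
#include <stdio.h>

int main(void) {
    const uint32_t HSV = 128, ne1 = 64, ne2 = 8;  /* V head size, rows, heads */
    const uint32_t D_split = 16, Br = 4;

    const uint32_t iq2 = 3, iq3 = 1;  /* head and batch indices of this workgroup */
    const uint32_t i = 2;             /* row-block index                          */
    const uint32_t r = 1, d = 0, d_tid = 5, comp = 2;

    /* o_offset now scales with HSV instead of the old single D. */
    uint32_t o_offset = iq3 * ne2 * ne1 * HSV;

    /* Non-GQA path: rows are p.ne1 * HSV apart, heads HSV apart. */
    uint32_t idx = o_offset + iq2 * HSV + (i * Br + r) * ne1 * HSV
                 + 4 * (d * D_split + d_tid) + comp;
    printf("data_o index = %u\n", idx);
    return 0;
}
```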