summaryrefslogtreecommitdiff
path: root/ggml/src/vulkan-shaders/flash_attn.comp
diff options
context:
space:
mode:
Diffstat (limited to 'ggml/src/vulkan-shaders/flash_attn.comp')
-rw-r--r--ggml/src/vulkan-shaders/flash_attn.comp337
1 files changed, 337 insertions, 0 deletions
diff --git a/ggml/src/vulkan-shaders/flash_attn.comp b/ggml/src/vulkan-shaders/flash_attn.comp
new file mode 100644
index 00000000..ce230a8f
--- /dev/null
+++ b/ggml/src/vulkan-shaders/flash_attn.comp
@@ -0,0 +1,337 @@
+#version 450
+
+#extension GL_EXT_control_flow_attributes : enable
+#extension GL_EXT_shader_16bit_storage : require
+
+#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
+#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
+
+#extension GL_KHR_shader_subgroup_shuffle : enable
+
+#include "types.comp"
+#include "flash_attn_base.comp"
+
+const uint32_t D_per_thread = D / D_split;
+
+const uint32_t cols_per_iter = WorkGroupSize / D_split;
+const uint32_t cols_per_thread = Bc / cols_per_iter;
+
+
+layout (binding = 0) readonly buffer Q {float data_q[];};
+layout (binding = 0) readonly buffer QV4 {vec4 data_qv4[];};
+layout (binding = 1) readonly buffer K {float16_t data_k[];};
+layout (binding = 1) readonly buffer KV4 {f16vec4 data_kv4[];};
+layout (binding = 2) readonly buffer V {float16_t data_v[];};
+layout (binding = 2) readonly buffer VV4 {f16vec4 data_vv4[];};
+layout (binding = 3) readonly buffer M {float16_t data_m[];};
+
+// Store the output when doing grouped query attention.
+// Rows index by Q's dimension 2, and the first N rows are valid.
+D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
+{
+ uint32_t offset = (iq2 + r) * D + c;
+ data_o[o_offset + offset] = D_TYPE(elem);
+ return elem;
+}
+
+shared FLOAT_TYPE tmpsh[WorkGroupSize];
+shared vec4 tmpshv4[WorkGroupSize];
+
+shared float masksh[Bc][Br];
+shared vec4 Qf[Br][D / 4];
+
+void main() {
+#ifdef NEEDS_INIT_IQ_SHMEM
+ init_iq_shmem(gl_WorkGroupSize);
+#endif
+
+ init_indices();
+
+ const uint32_t tid = gl_LocalInvocationIndex;
+ const uint32_t d_tid = gl_LocalInvocationIndex % D_split;
+ const uint32_t col_tid = gl_LocalInvocationIndex / D_split;
+
+ uint32_t q_offset = (iq2*p.nb02+iq3*p.nb03) / 4;
+
+ [[unroll]] for (uint32_t idx = 0; idx < Br * D / 4; idx += gl_WorkGroupSize.x) {
+ uint32_t d = (idx + tid) % (D / 4);
+ uint32_t r = (idx + tid) / (D / 4);
+ if (r < Br && d < D / 4 &&
+ i * Br + r < N) {
+ Qf[r][d] = vec4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d]) * p.scale;
+ }
+ }
+ barrier();
+
+ vec4 Of[Br][D_per_thread / 4];
+ [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
+ [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
+ Of[r][d] = vec4(0.0);
+ }
+ }
+
+ float Lf[Br], Mf[Br];
+
+ // Use -FLT_MAX/2 rather than -inf to reduce the possibility of NaNs, e.g. when computing Mold-M.
+ const float NEG_FLT_MAX_OVER_2 = uintBitsToFloat(0xFEFFFFFF);
+
+ [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
+ Lf[r] = 0;
+ Mf[r] = NEG_FLT_MAX_OVER_2;
+ }
+
+ float slope[Br];
+ [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
+ slope[r] = 1.0;
+ }
+
+ // ALiBi
+ if (p.max_bias > 0.0f) {
+ [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
+ slope[r] = perElemOpComputeSlope(r, col_tid, ACC_TYPE(0), iq2);
+ }
+ }
+
+#if BLOCK_SIZE > 1
+ uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / BLOCK_BYTE_SIZE;
+ uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / BLOCK_BYTE_SIZE;
+#else
+ uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / 2;
+ uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2;
+#endif
+
+ [[dont_unroll]]
+ for (uint32_t j = start_j; j < end_j; ++j) {
+
+ float Sf[Br][cols_per_thread];
+ [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
+ [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
+ Sf[r][c] = 0.0;
+ }
+ }
+
+
+ [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
+ [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
+#if BLOCK_SIZE > 1
+ uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
+ uint ib = coord / BLOCK_SIZE;
+ uint iqs = (coord % BLOCK_SIZE);
+ vec4 K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K);
+#else
+ vec4 K_Tf = vec4(data_kv4[k_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * k_stride / 4 + d * D_split + d_tid]);
+#endif
+ [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
+ Sf[r][c] += dot(Qf[r][d * D_split + d_tid], K_Tf);
+ }
+ }
+ }
+
+ [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
+ // Compute sum across the D_split
+ [[unroll]] for (uint s = D_split / 2; s > 0; s >>= 1) {
+ [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
+ Sf[r][c] += subgroupShuffleXor(Sf[r][c], s);
+ }
+ }
+ }
+
+ if (p.logit_softcap != 0.0f) {
+ [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
+ [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
+ Sf[r][c] = p.logit_softcap * tanh(Sf[r][c]);
+ }
+ }
+ }
+
+ if (p.mask != 0) {
+
+ [[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
+ uint32_t c = (idx + tid) % Bc;
+ uint32_t r = (idx + tid) / Bc;
+ if (idx + tid < Bc * Br) {
+ masksh[c][r] = float(data_m[(i * Br + r) * m_stride + (j * Bc + c)]);
+ }
+ }
+ barrier();
+
+ [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
+ [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
+ float mvf = masksh[c * cols_per_iter + col_tid][r];
+
+ Sf[r][c] += slope[r]*mvf;
+ }
+ }
+ barrier();
+ }
+
+ float rowmaxf[Br], Pf[Br][cols_per_thread], rowsumf[Br], eMf[Br], Moldf[Br];
+ [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
+ rowmaxf[r] = Sf[r][0];
+ [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
+ rowmaxf[r] = max(rowmaxf[r], Sf[r][c]);
+ }
+ Moldf[r] = Mf[r];
+
+ // M = max(rowmax, Mold)
+ // P = e^(S - M)
+ // eM = e^(Mold - M)
+ Mf[r] = max(rowmaxf[r], Moldf[r]);
+ [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
+ Pf[r][c] = exp(Sf[r][c] - Mf[r]);
+ }
+ eMf[r] = exp(Moldf[r] - Mf[r]);
+
+ // Compute sum across row of P
+ rowsumf[r] = 0.0;
+ [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
+ rowsumf[r] += Pf[r][c];
+ }
+
+ Lf[r] = eMf[r]*Lf[r] + rowsumf[r];
+ }
+
+ [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
+ [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
+ Of[r][d] = eMf[r] * Of[r][d];
+ }
+ }
+
+ [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
+ [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
+#if BLOCK_SIZE > 1
+ uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
+ uint ib = coord / BLOCK_SIZE;
+ uint iqs = (coord % BLOCK_SIZE);
+ vec4 Vf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V);
+#else
+ vec4 Vf = vec4(data_vv4[v_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * v_stride / 4 + d * D_split + d_tid]);
+#endif
+ [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
+ Of[r][d] += Pf[r][c] * Vf;
+ }
+ }
+ }
+
+ barrier();
+ }
+
+ // reduce across threads
+
+ [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
+ float rowmaxf, eMf;
+
+ tmpsh[tid] = Mf[r];
+ // Compute max across the row
+ barrier();
+ [[unroll]] for (int s = int(gl_WorkGroupSize.x) / 2; s >= D_split; s >>= 1) {
+ if (tid < s) {
+ tmpsh[tid] = max(tmpsh[tid], tmpsh[tid + s]);
+ }
+ barrier();
+ }
+ rowmaxf = tmpsh[d_tid];
+ barrier();
+
+ float Moldf = Mf[r];
+
+ // M = max(rowmax, Mold)
+ // eM = e^(Mold - M)
+ Mf[r] = max(rowmaxf, Moldf);
+ eMf = exp(Moldf - Mf[r]);
+
+ Lf[r] = eMf*Lf[r];
+
+ tmpsh[tid] = Lf[r];
+
+ // Compute sum across the row
+ barrier();
+ [[unroll]] for (int s = int(gl_WorkGroupSize.x) / 2; s >= D_split; s >>= 1) {
+ if (tid < s) {
+ tmpsh[tid] = tmpsh[tid] + tmpsh[tid + s];
+ }
+ barrier();
+ }
+ Lf[r] = tmpsh[d_tid];
+ barrier();
+
+ [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
+
+ Of[r][d] = eMf * Of[r][d];
+ tmpshv4[tid] = Of[r][d];
+
+ barrier();
+ [[unroll]] for (int s = int(gl_WorkGroupSize.x) / 2; s >= D_split; s >>= 1) {
+ if (tid < s) {
+ Of[r][d] += tmpshv4[tid + s];
+ tmpshv4[tid] = Of[r][d];
+ }
+ barrier();
+ }
+ Of[r][d] = tmpshv4[d_tid];
+ barrier();
+ }
+ }
+
+
+ // If there is split_k, then the split_k resolve shader does the final
+ // division by L. Store the intermediate O value and per-row m and L values.
+ if (p.k_num > 1) {
+ uint32_t o_offset = D * p.ne1 * split_k_index;
+
+ [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
+ if (r < N) {
+ [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
+ [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
+ perElemOpGqaStore(r, 4*(d * D_split + d_tid) + comp, Of[r][d][comp], o_offset, iq2, N);
+ }
+ }
+ }
+ }
+
+ o_offset = D * p.ne1 * p.k_num + p.ne1 * split_k_index * 2;
+ [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
+ if (r < N) {
+ perElemOpStoreCol0(r, 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N);
+ perElemOpStoreCol0(r, 0u, ACC_TYPE(Mf[r]), o_offset + p.ne1, iq2, N);
+ }
+ }
+
+ return;
+ }
+
+ float Lfrcp[Br];
+ [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
+ Lfrcp[r] = 1.0 / Lf[r];
+ }
+
+ [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
+ [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
+ Of[r][d] *= Lfrcp[r];
+ }
+ }
+
+ uint32_t o_offset = iq3*p.ne2*p.ne1;
+
+ if (p.gqa_ratio > 1) {
+ [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
+ if (r < N) {
+ [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
+ [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
+ perElemOpGqaStore(r, 4*(d * D_split + d_tid) + comp, Of[r][d][comp], o_offset, iq2, N);
+ }
+ }
+ }
+ }
+ } else {
+ [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
+ if (i * Br + r < N) {
+ [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
+ [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
+ data_o[o_offset + iq2 * D + (i * Br + r) * p.ne1 * D + 4*(d * D_split + d_tid) + comp] = D_TYPE(Of[r][d][comp]);
+ }
+ }
+ }
+ }
+ }
+}