Diffstat (limited to 'ggml/src/vulkan-shaders/flash_attn_base.comp')
-rw-r--r-- ggml/src/vulkan-shaders/flash_attn_base.comp | 162
1 file changed, 162 insertions(+), 0 deletions(-)
diff --git a/ggml/src/vulkan-shaders/flash_attn_base.comp b/ggml/src/vulkan-shaders/flash_attn_base.comp
new file mode 100644
index 00000000..61d90e2d
--- /dev/null
+++ b/ggml/src/vulkan-shaders/flash_attn_base.comp
@@ -0,0 +1,162 @@
+
+layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
+
+layout (constant_id = 0) const uint32_t WorkGroupSize = 128;
+layout (constant_id = 1) const uint32_t Br = 1;
+layout (constant_id = 2) const uint32_t Bc = 32;
+layout (constant_id = 3) const uint32_t D = 32;
+layout (constant_id = 4) const uint32_t Clamp = 0;
+layout (constant_id = 5) const uint32_t D_split = 16;
+
+
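+// Push constant fields follow the ggml naming convention: ne* are element
+// counts and nb* are strides, with the src index selecting the tensor
+// (0 = Q, 1 = K, 2 = V, 3 = mask), e.g. nb11 is a K stride and nem1 a mask size.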
+layout (push_constant) uniform parameter {
+ uint32_t N;
+ uint32_t KV;
+
+ uint32_t ne1;
+ uint32_t ne2;
+ uint32_t ne3;
+
+ uint32_t neq2;
+ uint32_t neq3;
+ uint32_t nek2;
+ uint32_t nek3;
+ uint32_t nev2;
+ uint32_t nev3;
+ uint32_t nem1;
+
+ uint32_t nb01;
+ uint32_t nb02;
+ uint32_t nb03;
+ uint32_t nb11;
+ uint32_t nb12;
+ uint32_t nb13;
+ uint32_t nb21;
+ uint32_t nb22;
+ uint32_t nb23;
+ uint32_t nb31;
+
+ float scale;
+ float max_bias;
+ float logit_softcap;
+
+ uint32_t mask;
+ uint32_t n_head_log2;
+ float m0;
+ float m1;
+
+ uint32_t gqa_ratio;
+ uint32_t split_kv;
+ uint32_t k_num;
+} p;
+
+layout (binding = 4) writeonly buffer O {D_TYPE data_o[];};
+
+#if defined(A_TYPE_PACKED16)
+#define BINDING_IDX_K 0
+#define BINDING_IDX_V 1
+layout (binding = 1) readonly buffer KV_PACKED16 {A_TYPE_PACKED16 data_packed16[];} kv_packed[2];
+#endif
+
+#if defined(DATA_A_Q4_0)
+#define BLOCK_BYTE_SIZE 18
+
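+// Dequantize four consecutive Q4_0 values. Each 18-byte block holds an f16
+// scale d followed by 16 bytes of packed 4-bit quants; iqs selects which four
+// nibbles to extract, and the stored values are offset by 8 (dequant = d * (q - 8)).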
+vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
+ uint vui_lo = uint(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]);
+ uint vui_hi = uint(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]);
+ uint shift = (iqs & 0x10) >> 2;
+ vui_lo >>= shift;
+ vui_hi >>= shift;
+
+ return float(kv_packed[binding_idx].data_packed16[a_offset + ib].d) * (vec4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - 8.0f);
+}
+#endif
+
+#if defined(DATA_A_Q8_0)
+#define BLOCK_BYTE_SIZE 34
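+// Dequantize four consecutive Q8_0 values. Each 34-byte block holds an f16
+// scale d followed by 32 signed 8-bit quants; each 16-bit load unpacks into
+// two int8 values.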
+vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
+ const i8vec2 v0 = unpack8(int32_t(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[iqs / 2])).xy; // vec4 used due to #12147
+ const i8vec2 v1 = unpack8(int32_t(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[iqs / 2 + 1])).xy;
+
+ return float(kv_packed[binding_idx].data_packed16[a_offset + ib].d) * vec4(v0.x, v0.y, v1.x, v1.y);
+}
+#endif
+
+#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))
+
+
+// Store column zero. This is used to save per-row m and L values for split_k.
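+// Only the invocation handling column 0 of each valid row writes, so each
+// row's value lands exactly once at data_o[o_offset + iq2 + r]; the split_k
+// reduction pass later combines these partial results.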
+ACC_TYPE perElemOpStoreCol0(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
+{
+ if (r < N && c == 0) {
+ uint32_t offset = iq2 + r;
+ data_o[o_offset + offset] = D_TYPE(elem);
+ }
+ return elem;
+}
+
+// Load the slope matrix, indexed by Q's dimension 2.
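+// This implements the standard ALiBi slope schedule: the host is expected to
+// pass m0 = 2^(-max_bias / n_head_log2) and m1 = 2^(-max_bias/2 / n_head_log2),
+// and head h gets slope m0^(h+1) for h < n_head_log2, or m1^(2*(h-n_head_log2)+1)
+// otherwise.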
+ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t iq2)
+{
+ const uint32_t h = iq2 + (r % p.gqa_ratio);
+
+ const ACC_TYPE base = ACC_TYPE(h < p.n_head_log2 ? p.m0 : p.m1);
+ const int exph = int(h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1);
+
+ return ACC_TYPE(pow(base, ACC_TYPE(exph)));
+}
+
+uint32_t i, N, KV, split_k_index, Tr, start_j, end_j,
+ iq2, iq3, rk2, rk3, rv2, rv3, ik2, ik3, iv2, iv3,
+ q_stride, k_stride, v_stride, m_stride;
+
+void init_indices()
+{
+ N = p.N;
+ KV = p.KV;
+
+ i = gl_WorkGroupID.x;
+ split_k_index = 0;
+
+ if (p.k_num > 1) {
+ i = 0;
+ split_k_index = gl_WorkGroupID.x;
+ }
+
+ Tr = CEIL_DIV(N, Br);
+
+ start_j = split_k_index * p.split_kv / Bc;
+ end_j = CEIL_DIV(min(KV, (split_k_index + 1) * p.split_kv), Bc);
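+ // Illustrative example: with KV = 4096, k_num = 4, split_kv = 1024 and
+ // Bc = 32, split_k_index 1 processes KV columns [1024, 2048), i.e.
+ // column blocks j = 32..63.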
+
+ // When not using grouped query attention, all rows share the same iq2, equal to gl_WorkGroupID.y.
+ // When using grouped query attention, each workgroup handles gqa_ratio consecutive values of iq2.
+ iq2 = gl_WorkGroupID.y * p.gqa_ratio;
+ iq3 = gl_WorkGroupID.z;
+
+ // broadcast factors
+ rk2 = p.neq2/p.nek2;
+ rk3 = p.neq3/p.nek3;
+
+ rv2 = p.neq2/p.nev2;
+ rv3 = p.neq3/p.nev3;
+
+ // k indices
+ ik3 = iq3 / rk3;
+ ik2 = iq2 / rk2;
+
+ // v indices
+ iv3 = iq3 / rv3;
+ iv2 = iq2 / rv2;
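+ // Illustrative example: with neq2 = 32 query heads and nek2 = nev2 = 8
+ // K/V heads, rk2 = rv2 = 4, so query heads iq2 = 0..3 all read K/V head
+ // ik2 = iv2 = 0.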
+
+ // nb?1 are already divided by the type size and are in units of elements.
+ // When using grouped query attention, Q is indexed by iq2, so the row stride
+ // must come from nb02 instead; nb02 is still in bytes, hence the division
+ // by the 4-byte (f32) element size below.
+ q_stride = p.gqa_ratio > 1 ? (p.nb02 / 4) : p.nb01;
+ k_stride = p.nb11;
+ v_stride = p.nb21;
+ // When using grouped query attention, all rows use the same mask (stride 0).
+ // "p.gqa_ratio >> 16" is just a roundabout way of writing zero
+ // that prevents the compiler from folding the "&" through the select
+ // and breaking the alignment detection.
+ m_stride = (p.gqa_ratio > 1) ? (p.gqa_ratio >> 16) : KV;
+}