summaryrefslogtreecommitdiff
path: root/ggml/src/vulkan-shaders/flash_attn_base.comp
diff options
context:
space:
mode:
Diffstat (limited to 'ggml/src/vulkan-shaders/flash_attn_base.comp')
-rw-r--r--ggml/src/vulkan-shaders/flash_attn_base.comp15
1 files changed, 10 insertions, 5 deletions
diff --git a/ggml/src/vulkan-shaders/flash_attn_base.comp b/ggml/src/vulkan-shaders/flash_attn_base.comp
index 1d3e6387..7defe72b 100644
--- a/ggml/src/vulkan-shaders/flash_attn_base.comp
+++ b/ggml/src/vulkan-shaders/flash_attn_base.comp
@@ -24,6 +24,8 @@ layout (push_constant) uniform parameter {
uint32_t nev2;
uint32_t nev3;
uint32_t nem1;
+ uint32_t nem2;
+ uint32_t nem3;
uint32_t nb01;
uint32_t nb02;
@@ -34,14 +36,12 @@ layout (push_constant) uniform parameter {
uint32_t nb21;
uint32_t nb22;
uint32_t nb23;
- uint32_t nb31;
float scale;
float max_bias;
float logit_softcap;
- uint32_t mask;
- uint32_t n_head_log2;
+ uint32_t mask_n_head_log2;
float m0;
float m1;
@@ -50,6 +50,9 @@ layout (push_constant) uniform parameter {
uint32_t k_num;
} p;
+#define MASK_ENABLE_BIT (1<<16)
+#define N_LOG2_MASK 0xFFFF
+
layout (binding = 4) writeonly buffer O {D_TYPE data_o[];};
#if defined(A_TYPE_PACKED16)
@@ -100,8 +103,10 @@ ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const i
{
const uint32_t h = iq2 + (r % p.gqa_ratio);
- const ACC_TYPE base = ACC_TYPE(h < p.n_head_log2 ? p.m0 : p.m1);
- const int exph = int(h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1);
+ uint32_t n_head_log2 = p.mask_n_head_log2 & N_LOG2_MASK;
+
+ const ACC_TYPE base = ACC_TYPE(h < n_head_log2 ? p.m0 : p.m1);
+ const int exph = int(h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1);
return ACC_TYPE(pow(base, ACC_TYPE(exph)));
}