summaryrefslogtreecommitdiff
path: root/ggml/src/vulkan-shaders/flash_attn_base.comp
diff options
context:
space:
mode:
authorKawrakow <iwankawrakow@gmail.com>2025-07-05 15:14:12 +0200
committerGitHub <noreply@github.com>2025-07-05 15:14:12 +0200
commit4622fadc2a2665b731a5887f93e295f0331ed80e (patch)
tree31fef5de7e4282cef3fd9b6cd3505ddbfa104672 /ggml/src/vulkan-shaders/flash_attn_base.comp
parent0678427f82686e9bb37d02bf5842e451bb742808 (diff)
Vulkan: flash attention for DeepSeek models (#584)
* vulkan: support mixed/deepseekR1 FA head sizes (#14509) * vulkan: better parameterize FA by head sizes * vulkan: support mixed/deepseekR1 FA head sizes * Fix the FA cherry-pick --------- Co-authored-by: Jeff Bolz <jbolz@nvidia.com> Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
Diffstat (limited to 'ggml/src/vulkan-shaders/flash_attn_base.comp')
-rw-r--r--ggml/src/vulkan-shaders/flash_attn_base.comp8
1 files changed, 4 insertions, 4 deletions
diff --git a/ggml/src/vulkan-shaders/flash_attn_base.comp b/ggml/src/vulkan-shaders/flash_attn_base.comp
index 61d90e2d..1d3e6387 100644
--- a/ggml/src/vulkan-shaders/flash_attn_base.comp
+++ b/ggml/src/vulkan-shaders/flash_attn_base.comp
@@ -4,10 +4,10 @@ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
layout (constant_id = 0) const uint32_t WorkGroupSize = 128;
layout (constant_id = 1) const uint32_t Br = 1;
layout (constant_id = 2) const uint32_t Bc = 32;
-layout (constant_id = 3) const uint32_t D = 32;
-layout (constant_id = 4) const uint32_t Clamp = 0;
-layout (constant_id = 5) const uint32_t D_split = 16;
-
+layout (constant_id = 3) const uint32_t HSK = 32;
+layout (constant_id = 4) const uint32_t HSV = 32;
+layout (constant_id = 5) const uint32_t Clamp = 0;
+layout (constant_id = 6) const uint32_t D_split = 16;
layout (push_constant) uniform parameter {
uint32_t N;