author    Kawrakow <iwankawrakow@gmail.com>    2025-07-05 15:14:12 +0200
committer GitHub <noreply@github.com>          2025-07-05 15:14:12 +0200
commit    4622fadc2a2665b731a5887f93e295f0331ed80e (patch)
tree      31fef5de7e4282cef3fd9b6cd3505ddbfa104672
parent    0678427f82686e9bb37d02bf5842e451bb742808 (diff)
Vulkan: flash attention for DeepSeek models (#584)
* vulkan: support mixed/deepseekR1 FA head sizes (#14509)
* vulkan: better parameterize FA by head sizes
* vulkan: support mixed/deepseekR1 FA head sizes
* Fix the FA cherry-pick

---------

Co-authored-by: Jeff Bolz <jbolz@nvidia.com>
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
-rw-r--r--  ggml/src/ggml-vulkan.cpp                      | 234
-rw-r--r--  ggml/src/vulkan-shaders/flash_attn.comp       |  43
-rw-r--r--  ggml/src/vulkan-shaders/flash_attn_base.comp  |   8
-rw-r--r--  ggml/src/vulkan-shaders/flash_attn_cm1.comp   |  62
-rw-r--r--  ggml/src/vulkan-shaders/flash_attn_cm2.comp   |  50
5 files changed, 211 insertions, 186 deletions
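
For context (not part of this patch): the two mixed (HSK, HSV) pairs added here come from DeepSeek's MLA attention. A minimal sketch of where the numbers come from, assuming DeepSeek-V3/R1-style config dimensions; the constant names below follow the model config convention, not this codebase:

// Hypothetical illustration; values are the published DeepSeek-V3/R1 dimensions.
constexpr int qk_nope_head_dim = 128;
constexpr int qk_rope_head_dim = 64;
constexpr int kv_lora_rank     = 512;
constexpr int v_head_dim       = 128;

// "Naive" MHA path: K carries the nope+rope parts, V has its own head size.
constexpr int hsk_naive = qk_nope_head_dim + qk_rope_head_dim; // 192
constexpr int hsv_naive = v_head_dim;                          // 128

// Absorbed/MQA MLA path: K is the latent plus the rope part, V is the latent.
constexpr int hsk_mla = kv_lora_rank + qk_rope_head_dim;       // 576
constexpr int hsv_mla = kv_lora_rank;                          // 512

static_assert(hsk_naive == 192 && hsv_naive == 128, "FA_HEAD_SIZE_192_128");
static_assert(hsk_mla   == 576 && hsv_mla   == 512, "FA_HEAD_SIZE_576_512");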
diff --git a/ggml/src/ggml-vulkan.cpp b/ggml/src/ggml-vulkan.cpp
index 758e6e6d..4091c89f 100644
--- a/ggml/src/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan.cpp
@@ -232,6 +232,21 @@ enum vk_device_architecture {
INTEL_XE2,
};
+// HSK x HSV
+enum FaHeadSizes {
+ FA_HEAD_SIZE_64,
+ FA_HEAD_SIZE_80,
+ FA_HEAD_SIZE_96,
+ FA_HEAD_SIZE_112,
+ FA_HEAD_SIZE_128,
+ FA_HEAD_SIZE_192,
+ FA_HEAD_SIZE_192_128,
+ FA_HEAD_SIZE_256,
+ FA_HEAD_SIZE_576_512,
+ FA_HEAD_SIZE_UNSUPPORTED,
+ FA_HEAD_SIZE_COUNT = FA_HEAD_SIZE_UNSUPPORTED,
+};
+
static vk_device_architecture get_device_architecture(const vk::PhysicalDevice& device) {
vk::PhysicalDeviceProperties props = device.getProperties();
@@ -471,26 +486,11 @@ struct vk_device_struct {
vk_pipeline pipeline_opt_step_adamw_f32;
// [2][2][2] is for {f16acc,f32acc}x{large,small_rows}x{unaligned, aligned}
- vk_pipeline pipeline_flash_attn_f32_f16_D64_cm2[GGML_TYPE_COUNT][2][2][2];
- vk_pipeline pipeline_flash_attn_f32_f16_D80_cm2[GGML_TYPE_COUNT][2][2][2];
- vk_pipeline pipeline_flash_attn_f32_f16_D96_cm2[GGML_TYPE_COUNT][2][2][2];
- vk_pipeline pipeline_flash_attn_f32_f16_D112_cm2[GGML_TYPE_COUNT][2][2][2];
- vk_pipeline pipeline_flash_attn_f32_f16_D128_cm2[GGML_TYPE_COUNT][2][2][2];
- vk_pipeline pipeline_flash_attn_f32_f16_D256_cm2[GGML_TYPE_COUNT][2][2][2];
-
- vk_pipeline pipeline_flash_attn_f32_f16_D64_cm1[GGML_TYPE_COUNT][2][2][2];
- vk_pipeline pipeline_flash_attn_f32_f16_D80_cm1[GGML_TYPE_COUNT][2][2][2];
- vk_pipeline pipeline_flash_attn_f32_f16_D96_cm1[GGML_TYPE_COUNT][2][2][2];
- vk_pipeline pipeline_flash_attn_f32_f16_D112_cm1[GGML_TYPE_COUNT][2][2][2];
- vk_pipeline pipeline_flash_attn_f32_f16_D128_cm1[GGML_TYPE_COUNT][2][2][2];
- vk_pipeline pipeline_flash_attn_f32_f16_D256_cm1[GGML_TYPE_COUNT][2][2][2];
-
- vk_pipeline pipeline_flash_attn_f32_f16_D64[GGML_TYPE_COUNT][2][2][2];
- vk_pipeline pipeline_flash_attn_f32_f16_D80[GGML_TYPE_COUNT][2][2][2];
- vk_pipeline pipeline_flash_attn_f32_f16_D96[GGML_TYPE_COUNT][2][2][2];
- vk_pipeline pipeline_flash_attn_f32_f16_D112[GGML_TYPE_COUNT][2][2][2];
- vk_pipeline pipeline_flash_attn_f32_f16_D128[GGML_TYPE_COUNT][2][2][2];
- vk_pipeline pipeline_flash_attn_f32_f16_D256[GGML_TYPE_COUNT][2][2][2];
+ vk_pipeline pipeline_flash_attn_f32_f16_cm2[GGML_TYPE_COUNT][FA_HEAD_SIZE_COUNT][2][2][2];
+
+ vk_pipeline pipeline_flash_attn_f32_f16_cm1[GGML_TYPE_COUNT][FA_HEAD_SIZE_COUNT][2][2][2];
+
+ vk_pipeline pipeline_flash_attn_f32_f16[GGML_TYPE_COUNT][FA_HEAD_SIZE_COUNT][2][2][2];
vk_pipeline pipeline_flash_attn_split_k_reduce;
@@ -1685,6 +1685,35 @@ enum FaCodePath {
FA_COOPMAT2,
};
+static FaHeadSizes fa_get_head_sizes(uint32_t hsk, uint32_t hsv) {
+ if (hsk != 192 && hsk != 576 && hsk != hsv) {
+ return FA_HEAD_SIZE_UNSUPPORTED;
+ }
+ switch (hsk) {
+ case 64: return FA_HEAD_SIZE_64;
+ case 80: return FA_HEAD_SIZE_80;
+ case 96: return FA_HEAD_SIZE_96;
+ case 112: return FA_HEAD_SIZE_112;
+ case 128: return FA_HEAD_SIZE_128;
+ case 192:
+ if (hsv == 192) {
+ return FA_HEAD_SIZE_192;
+ } else if (hsv == 128) {
+ return FA_HEAD_SIZE_192_128;
+ } else {
+ return FA_HEAD_SIZE_UNSUPPORTED;
+ }
+ case 256: return FA_HEAD_SIZE_256;
+ case 576:
+ if (hsv == 512) {
+ return FA_HEAD_SIZE_576_512;
+ } else {
+ return FA_HEAD_SIZE_UNSUPPORTED;
+ }
+ default: return FA_HEAD_SIZE_UNSUPPORTED;
+ }
+}
+
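
A quick sanity check of this mapping (illustration only, assuming the enum and function from the hunk above are in scope):

#include <cassert>

void check_fa_head_sizes() {
    assert(fa_get_head_sizes(128, 128) == FA_HEAD_SIZE_128);
    assert(fa_get_head_sizes(192, 192) == FA_HEAD_SIZE_192);
    assert(fa_get_head_sizes(192, 128) == FA_HEAD_SIZE_192_128);     // DeepSeek, naive path
    assert(fa_get_head_sizes(576, 512) == FA_HEAD_SIZE_576_512);     // DeepSeek, MLA path
    assert(fa_get_head_sizes(576, 576) == FA_HEAD_SIZE_UNSUPPORTED); // 576 only pairs with 512
    assert(fa_get_head_sizes( 64, 128) == FA_HEAD_SIZE_UNSUPPORTED); // other mixed sizes rejected
}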
// number of rows/cols for flash attention shader
static constexpr uint32_t flash_attention_num_small_rows = 32;
static constexpr uint32_t scalar_flash_attention_num_small_rows = 1;
@@ -1705,8 +1734,9 @@ static uint32_t get_fa_num_small_rows(FaCodePath path) {
}
}
-static std::array<uint32_t, 2> fa_rows_cols(FaCodePath path, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) {
+static std::array<uint32_t, 2> fa_rows_cols(FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows) {
GGML_UNUSED(clamp);
+ GGML_UNUSED(hsv);
if (path == FA_SCALAR) {
if (small_rows) {
@@ -1730,7 +1760,7 @@ static std::array<uint32_t, 2> fa_rows_cols(FaCodePath path, uint32_t D, uint32_
}
// small cols to reduce register count
- if (ggml_is_quantized(type) || D == 256) {
+ if (ggml_is_quantized(type) || hsk >= 256) {
return {64, 32};
}
return {64, 64};
@@ -2023,19 +2053,21 @@ static void ggml_vk_load_shaders(vk_device& device) {
parameter_count, wg_denoms, specialization_constants, disable_robustness, require_full_subgroups, required_subgroup_size));
};
- auto const &fa_wg_denoms = [&](FaCodePath path, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) -> std::array<uint32_t, 3> {
- return {fa_rows_cols(path, D, clamp, type, small_rows)[0], 1, 1};
+ auto const &fa_wg_denoms = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows) -> std::array<uint32_t, 3> {
+ return {fa_rows_cols(path, hsk, hsv, clamp, type, small_rows)[0], 1, 1};
};
- auto const &fa_spec_constants = [&](FaCodePath path, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) -> std::vector<uint32_t> {
+ auto const &fa_spec_constants = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows) -> std::vector<uint32_t> {
// For large number of rows, 128 invocations seems to work best.
// For small number of rows (e.g. N==1), 256 works better. But matrix granularity for 256 is 32, so we
// can't use 256 for D==80.
// For scalar, use 128 (arbitrary)
+ // The same D_split value is used for both HSK and HSV, so just base it on the union of the LSBs.
+ const uint32_t D = (hsk|hsv);
uint32_t wg_size = (path == FA_SCALAR || path == FA_COOPMAT1)
? scalar_flash_attention_workgroup_size
: ((small_rows && (D % 32) == 0) ? 256 : 128);
- auto rows_cols = fa_rows_cols(path, D, clamp, type, small_rows);
+ auto rows_cols = fa_rows_cols(path, hsk, hsv, clamp, type, small_rows);
// D_split can't be larger than a subgroup because we use subgroupShuffle to reduce it.
// D_split can't be larger than the LSB of D divided by 4 due to vectorization in the shader.
@@ -2044,26 +2076,29 @@ static void ggml_vk_load_shaders(vk_device& device) {
// mask dim1 is padded to 64, we rely on this to avoid clamping mask loads
GGML_ASSERT((GGML_KQ_MASK_PAD % rows_cols[0]) == 0);
- return {wg_size, rows_cols[0], rows_cols[1], (D), clamp, D_split};
+ return {wg_size, rows_cols[0], rows_cols[1], hsk, hsv, clamp, D_split};
};
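
Why basing D_split on (hsk|hsv) is safe: the lowest set bit of the OR equals the smaller of the two operands' lowest set bits, so a D_split bounded by it divides both head sizes. A small sketch of the arithmetic:

#include <cstdint>

// Lowest set bit; e.g. lsb(192) == 64, lsb(80) == 16.
constexpr uint32_t lsb(uint32_t x) { return x & ~(x - 1u); }

static_assert(lsb(192u | 128u) == 64, "HSK=192, HSV=128 -> D_split <= 64/4 = 16");
static_assert(lsb(576u | 512u) == 64, "HSK=576, HSV=512 -> D_split <= 64/4 = 16");
static_assert(lsb( 80u |  80u) == 16, "HSK=HSV=80      -> D_split <= 16/4 = 4");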
-#define CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, D) \
- ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][0][0], "flash_attn_f32_f16_D" #D "_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,1,TYPE,false), fa_spec_constants(FAPATH, D,1,TYPE,false), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
- ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][0][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,0,TYPE,false), fa_spec_constants(FAPATH, D,0,TYPE,false), fa_rows_cols(FAPATH,D,0,TYPE,false)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
- ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][0][0], "flash_attn_f32_f16_D" #D "_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,1,TYPE,false), fa_spec_constants(FAPATH, D,1,TYPE,false), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
- ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][0][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,0,TYPE,false), fa_spec_constants(FAPATH, D,0,TYPE,false), fa_rows_cols(FAPATH,D,0,TYPE,false)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
- ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][1][0], "flash_attn_f32_f16_D" #D "_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,1,TYPE,true), fa_spec_constants(FAPATH, D,1,TYPE,true), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
- ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][1][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,0,TYPE,true), fa_spec_constants(FAPATH, D,0,TYPE,true), fa_rows_cols(FAPATH,D,0,TYPE,true)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
- ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][1][0], "flash_attn_f32_f16_D" #D "_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,1,TYPE,true), fa_spec_constants(FAPATH, D,1,TYPE,true), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
- ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][1][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,0,TYPE,true), fa_spec_constants(FAPATH, D,0,TYPE,true), fa_rows_cols(FAPATH,D,0,TYPE,true)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
+#define CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, HSK, HSV, HEAD_SIZES) \
+ ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][0][0][0], "flash_attn_f32_f16_" #HEAD_SIZES "_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,false), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,false), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
+ ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][0][0][1], "flash_attn_f32_f16_" #HEAD_SIZES "_aligned_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,false), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,false), fa_rows_cols(FAPATH,HSK,HSV,0,TYPE,false)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
+ ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][1][0][0], "flash_attn_f32_f16_" #HEAD_SIZES "_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,false), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,false), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
+ ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][1][0][1], "flash_attn_f32_f16_" #HEAD_SIZES "_aligned_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,false), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,false), fa_rows_cols(FAPATH,HSK,HSV,0,TYPE,false)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
+ ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][0][1][0], "flash_attn_f32_f16_" #HEAD_SIZES "_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,true), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,true), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
+ ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][0][1][1], "flash_attn_f32_f16_" #HEAD_SIZES "_aligned_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,true), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,true), fa_rows_cols(FAPATH,HSK,HSV,0,TYPE,true)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
+ ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][1][1][0], "flash_attn_f32_f16_" #HEAD_SIZES "_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,true), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,true), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
+ ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][1][1][1], "flash_attn_f32_f16_" #HEAD_SIZES "_aligned_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,true), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,true), fa_rows_cols(FAPATH,HSK,HSV,0,TYPE,true)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
#define CREATE_FA(TYPE, NAMELC, FAPATH, SUFFIX) \
- CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 64) \
- CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 80) \
- CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 96) \
- CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 112) \
- CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 128) \
- CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 256)
+ CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 64, 64, 64) \
+ CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 80, 80, 80) \
+ CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 96, 96, 96) \
+ CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 112, 112, 112) \
+ CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 128, 128, 128) \
+ CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 192, 192, 192) \
+ CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 192, 128, 192_128) \
+ CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 256, 256, 256) \
+ CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 576, 512, 576_512)
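
With the head-size dimension folded into one 5-D table, dispatch code can select a pipeline with a single lookup instead of a per-D switch. A sketch of the lookup shape used further below (the trailing index picks unaligned [0] vs. aligned [1]):

// [tensor type][head-size combo][f32acc][small_rows][aligned]
vk_pipeline *pipelines =
    &ctx->device->pipeline_flash_attn_f32_f16_cm2[k->type][FA_HEAD_SIZE_576_512]
                                                 [f32acc][small_rows][0];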
CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, )
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, )
@@ -3665,7 +3700,6 @@ GGML_CALL void ggml_vk_instance_init() {
}
- size_t num_available_devices = vk_instance.instance.enumeratePhysicalDevices().size();
vk_perf_logger_enabled = getenv("GGML_VK_PERF_LOGGER") != nullptr;
// Emulate behavior of CUDA_VISIBLE_DEVICES for Vulkan
@@ -5956,24 +5990,47 @@ static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx
}
}
-static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const uint32_t D, bool f32acc) {
+static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const uint32_t hsk, uint32_t hsv) {
// Needs to be kept up to date on shader changes
+ GGML_UNUSED(hsv);
const uint32_t wg_size = scalar_flash_attention_workgroup_size;
const uint32_t Br = scalar_flash_attention_num_large_rows;
const uint32_t Bc = scalar_flash_attention_Bc;
+ const uint32_t tmpsh = wg_size * sizeof(float);
+ const uint32_t tmpshv4 = wg_size * 4 * sizeof(float);
+
+ const uint32_t masksh = Bc * Br * sizeof(float);
+
+ const uint32_t Qf = Br * (hsk / 4 + 2) * 4 * sizeof(float);
+
+ const uint32_t total_size = tmpsh + tmpshv4 + masksh + Qf;
+ const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize;
+
+ VK_LOG_DEBUG("ggml_vk_flash_attn_coopmat_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", total_size=" << total_size << ", supported=" << supported);
+
+ return supported;
+}
+
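
A rough budget for the new scalar check (wg_size = 128 per the workgroup-size comment earlier; Br = 8 and Bc = 64 are assumed values for the scalar large-rows path):

// HSK = 576 (DeepSeek MLA path), sizes in bytes:
const uint32_t tmpsh   = 128 * 4;                       //   512
const uint32_t tmpshv4 = 128 * 4 * 4;                   //  2048
const uint32_t masksh  = 64 * 8 * 4;                    //  2048
const uint32_t Qf      = 8 * (576 / 4 + 2) * 4 * 4;     // 18688
const uint32_t total   = tmpsh + tmpshv4 + masksh + Qf; // 23296

Under these assumptions, 23296 bytes fits a 48 KB (or 32 KB) shared memory limit but not the 16 KB Vulkan minimum for maxComputeSharedMemorySize, which is what the small_rows fallback added further down is for.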
+static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const uint32_t hsk, uint32_t hsv, bool f32acc) {
+ // Needs to be kept up to date on shader changes
+ GGML_UNUSED(hsv);
+ const uint32_t wg_size = scalar_flash_attention_workgroup_size;
+ const uint32_t Br = coopmat1_flash_attention_num_large_rows;
+ const uint32_t Bc = scalar_flash_attention_Bc;
+
const uint32_t acctype = f32acc ? 4 : 2;
const uint32_t f16vec4 = 8;
const uint32_t tmpsh = wg_size * sizeof(float);
const uint32_t tmpshv4 = wg_size * 4 * acctype;
- const uint32_t Qf = Br * (D / 4 + 2) * f16vec4;
+ const uint32_t Qf = Br * (hsk / 4 + 2) * f16vec4;
- const uint32_t sfshstride = (D <= 128) ? (Br + 8) : Br;
+ const uint32_t sfshstride = (hsk <= 128) ? (Br + 8) : Br;
const uint32_t sfsh = Bc * sfshstride * acctype;
- const uint32_t kshstride = D / 4 + 2;
+ const uint32_t kshstride = hsk / 4 + 2;
const uint32_t ksh = Bc * kshstride * f16vec4;
const uint32_t slope = Br * sizeof(float);
@@ -5981,7 +6038,7 @@ static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, co
const uint32_t total_size = tmpsh + tmpshv4 + Qf + sfsh + ksh + slope;
const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize;
- VK_LOG_DEBUG("ggml_vk_flash_attn_coopmat_shmem_support(D=" << D << ", f32acc=" << f32acc << ", total_size=" << total_size << ", supported=" << supported);
+ VK_LOG_DEBUG("ggml_vk_flash_attn_coopmat_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", f32acc=" << f32acc << ", total_size=" << total_size << ", supported=" << supported);
return supported;
}
@@ -6005,11 +6062,12 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
const uint32_t nem1 = mask ? mask->ne[1] : 0;
const uint32_t nbm1 = mask ? mask->nb[1] : 0;
- const uint32_t D = neq0;
+ const uint32_t HSK = nek0;
+ const uint32_t HSV = nev0;
uint32_t N = neq1;
const uint32_t KV = nek1;
- GGML_ASSERT(ne0 == D);
+ GGML_ASSERT(ne0 == HSV);
GGML_ASSERT(ne2 == N);
// input tensor rows must be contiguous
@@ -6017,12 +6075,9 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
GGML_ASSERT(nbk0 == ggml_type_size(k->type));
GGML_ASSERT(nbv0 == ggml_type_size(v->type));
- GGML_ASSERT(neq0 == D);
- GGML_ASSERT(nek0 == D);
- GGML_ASSERT(nev0 == D);
+ GGML_ASSERT(neq0 == HSK);
GGML_ASSERT(neq1 == N);
- GGML_ASSERT(nev0 == D);
GGML_ASSERT(nev1 == nek1);
@@ -6043,7 +6098,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
const bool coopmat_shape_supported = (dst->op_params[3] == GGML_PREC_F32 && ctx->device->coopmat_support_16x16x16_f32acc) ||
(dst->op_params[3] != GGML_PREC_F32 && ctx->device->coopmat_support_16x16x16_f16acc);
- const bool coopmat_shmem_supported = ggml_vk_flash_attn_coopmat_shmem_support(ctx->device, D, dst->op_params[3] == GGML_PREC_F32);
+ const bool coopmat_shmem_supported = ggml_vk_flash_attn_coopmat_shmem_support(ctx->device, HSK, HSV, dst->op_params[3] == GGML_PREC_F32);
if (!coopmat_shape_supported || !coopmat_shmem_supported) {
path = FA_SCALAR;
@@ -6096,47 +6151,25 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
path = FA_SCALAR;
}
+ // with large hsk/hsv, scalar path may need to use small_rows to fit in shared memory
+ if (path == FA_SCALAR &&
+ !ggml_vk_flash_attn_scalar_shmem_support(ctx->device, HSK, HSV)) {
+ small_rows = true;
+ }
+
bool f32acc = path == FA_SCALAR || dst->op_params[3] == GGML_PREC_F32;
+ FaHeadSizes head_sizes = fa_get_head_sizes(k->ne[0], v->ne[0]);
+
switch (path) {
case FA_SCALAR:
- switch (D) {
- case 64: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D64[k->type][f32acc][small_rows][0]; break;
- case 80: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D80[k->type][f32acc][small_rows][0]; break;
- case 96: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D96[k->type][f32acc][small_rows][0]; break;
- case 112: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D112[k->type][f32acc][small_rows][0]; break;
- case 128: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D128[k->type][f32acc][small_rows][0]; break;
- case 256: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D256[k->type][f32acc][small_rows][0]; break;
- default:
- GGML_ASSERT(!"unsupported D value");
- return;
- }
+ pipelines = &ctx->device->pipeline_flash_attn_f32_f16[k->type][head_sizes][f32acc][small_rows][0];
break;
case FA_COOPMAT1:
- switch (D) {
- case 64: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D64_cm1[k->type][f32acc][small_rows][0]; break;
- case 80: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D80_cm1[k->type][f32acc][small_rows][0]; break;
- case 96: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D96_cm1[k->type][f32acc][small_rows][0]; break;
- case 112: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D112_cm1[k->type][f32acc][small_rows][0]; break;
- case 128: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D128_cm1[k->type][f32acc][small_rows][0]; break;
- case 256: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D256_cm1[k->type][f32acc][small_rows][0]; break;
- default:
- GGML_ASSERT(!"unsupported D value");
- return;
- }
+ pipelines = &ctx->device->pipeline_flash_attn_f32_f16_cm1[k->type][head_sizes][f32acc][small_rows][0];
break;
case FA_COOPMAT2:
- switch (D) {
- case 64: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D64_cm2[k->type][f32acc][small_rows][0]; break;
- case 80: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D80_cm2[k->type][f32acc][small_rows][0]; break;
- case 96: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D96_cm2[k->type][f32acc][small_rows][0]; break;
- case 112: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D112_cm2[k->type][f32acc][small_rows][0]; break;
- case 128: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D128_cm2[k->type][f32acc][small_rows][0]; break;
- case 256: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D256_cm2[k->type][f32acc][small_rows][0]; break;
- default:
- GGML_ASSERT(!"unsupported D value");
- return;
- }
+ pipelines = &ctx->device->pipeline_flash_attn_f32_f16_cm2[k->type][head_sizes][f32acc][small_rows][0];
break;
default:
GGML_ASSERT(0);
@@ -6166,7 +6199,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
// Try to use split_k when KV is large enough to be worth the overhead
if (workgroups_x == 1 && shader_core_count > 0 && KV >= 512) {
// Try to run two workgroups per SM.
- split_k = ctx->device->shader_core_count * 2 / workgroups_y;
+ split_k = shader_core_count * 2 / (workgroups_y * workgroups_z);
if (split_k > 1) {
// Try to evenly split KV into split_k chunks, but it needs to be a multiple
// of "align", so recompute split_k based on that.
@@ -6176,9 +6209,9 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
}
}
- // Reserve space for split_k temporaries. For each split, we need to store the O matrix (D x ne1)
- // and the per-row m and L values (ne1 rows).
- const uint64_t split_k_size = split_k > 1 ? (D * ne1 * sizeof(float) + ne1 * sizeof(float) * 2) * split_k : 0;
+ // Reserve space for split_k temporaries. For each split x batch, we need to store the O matrix (D x ne1)
+ // and the per-row m and L values (ne1 rows). We store all the matrices first, followed by the rows.
+ const uint64_t split_k_size = split_k > 1 ? (HSV * ne1 * sizeof(float) + ne1 * sizeof(float) * 2) * split_k * ne3 : 0;
if (split_k_size > ctx->device->max_memory_allocation_size) {
GGML_ABORT("Requested preallocation size is too large");
}
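
Layout of the split_k scratch buffer implied by the new offsets (a sketch using the names from the function above, matching the shader-side addressing changed below): all O tiles come first, then the per-split L and m rows:

// For split s in [0, split_k) and batch i3 in [0, ne3), in float elements:
const uint64_t o_tile  = (uint64_t)HSV * ne1 * (s + i3 * split_k);          // O tile base
const uint64_t lm_base = (uint64_t)HSV * ne1 * ne3 * split_k;               // after all O tiles
const uint64_t lm_rows = lm_base + (uint64_t)ne1 * (s + i3 * split_k) * 2;  // L row, then m row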
@@ -6297,7 +6330,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
sizeof(vk_flash_attn_push_constants), &pc, { workgroups_x * pipeline->wg_denoms[0], workgroups_y, workgroups_z });
ggml_vk_sync_buffers(subctx);
- const std::array<uint32_t, 3> pc2 = { D, (uint32_t)ne1, split_k };
+ const std::array<uint32_t, 4> pc2 = { HSV, (uint32_t)ne1, (uint32_t)ne3, split_k };
ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_flash_attn_split_k_reduce,
{
vk_subbuffer{ctx->prealloc_split_k, 0, VK_WHOLE_SIZE},
@@ -9634,19 +9667,8 @@ GGML_CALL static bool ggml_backend_vk_supports_op(ggml_backend_t backend, const
{
ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
bool coopmat2 = ctx->device->coopmat2;
- switch (op->src[0]->ne[0]) {
- case 64:
- case 80:
- case 96:
- case 112:
- case 128:
- case 256:
- break;
- default:
- return false;
- }
- if (op->src[1]->ne[0] != op->src[2]->ne[0]) {
- // different head sizes of K and V are not supported yet
+ FaHeadSizes head_sizes = fa_get_head_sizes(op->src[1]->ne[0], op->src[2]->ne[0]);
+ if (head_sizes == FA_HEAD_SIZE_UNSUPPORTED) {
return false;
}
if (op->src[0]->type != GGML_TYPE_F32) {
diff --git a/ggml/src/vulkan-shaders/flash_attn.comp b/ggml/src/vulkan-shaders/flash_attn.comp
index ce230a8f..454b3411 100644
--- a/ggml/src/vulkan-shaders/flash_attn.comp
+++ b/ggml/src/vulkan-shaders/flash_attn.comp
@@ -11,7 +11,8 @@
#include "types.comp"
#include "flash_attn_base.comp"
-const uint32_t D_per_thread = D / D_split;
+const uint32_t HSK_per_thread = HSK / D_split;
+const uint32_t HSV_per_thread = HSV / D_split;
const uint32_t cols_per_iter = WorkGroupSize / D_split;
const uint32_t cols_per_thread = Bc / cols_per_iter;
@@ -29,7 +30,7 @@ layout (binding = 3) readonly buffer M {float16_t data_m[];};
// Rows index by Q's dimension 2, and the first N rows are valid.
D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
{
- uint32_t offset = (iq2 + r) * D + c;
+ uint32_t offset = (iq2 + r) * HSV + c;
data_o[o_offset + offset] = D_TYPE(elem);
return elem;
}
@@ -38,7 +39,7 @@ shared FLOAT_TYPE tmpsh[WorkGroupSize];
shared vec4 tmpshv4[WorkGroupSize];
shared float masksh[Bc][Br];
-shared vec4 Qf[Br][D / 4];
+shared vec4 Qf[Br][HSK / 4];
void main() {
#ifdef NEEDS_INIT_IQ_SHMEM
@@ -53,18 +54,18 @@ void main() {
uint32_t q_offset = (iq2*p.nb02+iq3*p.nb03) / 4;
- [[unroll]] for (uint32_t idx = 0; idx < Br * D / 4; idx += gl_WorkGroupSize.x) {
- uint32_t d = (idx + tid) % (D / 4);
- uint32_t r = (idx + tid) / (D / 4);
- if (r < Br && d < D / 4 &&
+ [[unroll]] for (uint32_t idx = 0; idx < Br * HSK / 4; idx += gl_WorkGroupSize.x) {
+ uint32_t d = (idx + tid) % (HSK / 4);
+ uint32_t r = (idx + tid) / (HSK / 4);
+ if (r < Br && d < HSK / 4 &&
i * Br + r < N) {
Qf[r][d] = vec4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d]) * p.scale;
}
}
barrier();
- vec4 Of[Br][D_per_thread / 4];
- [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
+ vec4 Of[Br][HSV_per_thread / 4];
+ [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
Of[r][d] = vec4(0.0);
}
@@ -112,7 +113,7 @@ void main() {
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
- [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
+ [[unroll]] for (uint32_t d = 0; d < HSK_per_thread / 4; ++d) {
#if BLOCK_SIZE > 1
uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
uint ib = coord / BLOCK_SIZE;
@@ -191,14 +192,14 @@ void main() {
Lf[r] = eMf[r]*Lf[r] + rowsumf[r];
}
- [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
+ [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
Of[r][d] = eMf[r] * Of[r][d];
}
}
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
- [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
+ [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
#if BLOCK_SIZE > 1
uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
uint ib = coord / BLOCK_SIZE;
@@ -255,7 +256,7 @@ void main() {
Lf[r] = tmpsh[d_tid];
barrier();
- [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
+ [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
Of[r][d] = eMf * Of[r][d];
tmpshv4[tid] = Of[r][d];
@@ -277,11 +278,11 @@ void main() {
// If there is split_k, then the split_k resolve shader does the final
// division by L. Store the intermediate O value and per-row m and L values.
if (p.k_num > 1) {
- uint32_t o_offset = D * p.ne1 * split_k_index;
+ uint32_t o_offset = HSV * p.ne1 * (split_k_index + iq3 * p.k_num);
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
if (r < N) {
- [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
+ [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
[[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
perElemOpGqaStore(r, 4*(d * D_split + d_tid) + comp, Of[r][d][comp], o_offset, iq2, N);
}
@@ -289,7 +290,7 @@ void main() {
}
}
- o_offset = D * p.ne1 * p.k_num + p.ne1 * split_k_index * 2;
+ o_offset = HSV * p.ne1 * p.ne3 * p.k_num + p.ne1 * (split_k_index + iq3 * p.k_num) * 2;
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
if (r < N) {
perElemOpStoreCol0(r, 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N);
@@ -305,18 +306,18 @@ void main() {
Lfrcp[r] = 1.0 / Lf[r];
}
- [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
+ [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
Of[r][d] *= Lfrcp[r];
}
}
- uint32_t o_offset = iq3*p.ne2*p.ne1;
+ uint32_t o_offset = iq3*p.ne2*p.ne1*HSV;
if (p.gqa_ratio > 1) {
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
if (r < N) {
- [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
+ [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
[[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
perElemOpGqaStore(r, 4*(d * D_split + d_tid) + comp, Of[r][d][comp], o_offset, iq2, N);
}
@@ -326,9 +327,9 @@ void main() {
} else {
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
if (i * Br + r < N) {
- [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
+ [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
[[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
- data_o[o_offset + iq2 * D + (i * Br + r) * p.ne1 * D + 4*(d * D_split + d_tid) + comp] = D_TYPE(Of[r][d][comp]);
+ data_o[o_offset + iq2 * HSV + (i * Br + r) * p.ne1 * HSV + 4*(d * D_split + d_tid) + comp] = D_TYPE(Of[r][d][comp]);
}
}
}
diff --git a/ggml/src/vulkan-shaders/flash_attn_base.comp b/ggml/src/vulkan-shaders/flash_attn_base.comp
index 61d90e2d..1d3e6387 100644
--- a/ggml/src/vulkan-shaders/flash_attn_base.comp
+++ b/ggml/src/vulkan-shaders/flash_attn_base.comp
@@ -4,10 +4,10 @@ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
layout (constant_id = 0) const uint32_t WorkGroupSize = 128;
layout (constant_id = 1) const uint32_t Br = 1;
layout (constant_id = 2) const uint32_t Bc = 32;
-layout (constant_id = 3) const uint32_t D = 32;
-layout (constant_id = 4) const uint32_t Clamp = 0;
-layout (constant_id = 5) const uint32_t D_split = 16;
-
+layout (constant_id = 3) const uint32_t HSK = 32;
+layout (constant_id = 4) const uint32_t HSV = 32;
+layout (constant_id = 5) const uint32_t Clamp = 0;
+layout (constant_id = 6) const uint32_t D_split = 16;
layout (push_constant) uniform parameter {
uint32_t N;
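
The spec-constant order above has to stay in lockstep with the host-side fa_spec_constants return value from the earlier hunk; the one-to-one mapping:

// Host fa_spec_constants return value  ->  shader constant_id
// wg_size                                  0: WorkGroupSize
// rows_cols[0]                             1: Br
// rows_cols[1]                             2: Bc
// hsk                                      3: HSK     (new)
// hsv                                      4: HSV     (new)
// clamp                                    5: Clamp   (was 4)
// D_split                                  6: D_split (was 5)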
diff --git a/ggml/src/vulkan-shaders/flash_attn_cm1.comp b/ggml/src/vulkan-shaders/flash_attn_cm1.comp
index da478be2..ad7594fe 100644
--- a/ggml/src/vulkan-shaders/flash_attn_cm1.comp
+++ b/ggml/src/vulkan-shaders/flash_attn_cm1.comp
@@ -13,7 +13,9 @@
#include "types.comp"
#include "flash_attn_base.comp"
-const uint32_t D_per_thread = D / D_split;
+const uint32_t HSK_per_thread = HSK / D_split;
+const uint32_t HSV_per_thread = HSV / D_split;
+
const uint32_t row_split = 4;
const uint32_t rows_per_thread = Br / row_split;
const uint32_t cols_per_iter = gl_WorkGroupSize.x / D_split / row_split;
@@ -32,7 +34,7 @@ layout (binding = 3) readonly buffer M {float16_t data_m[];};
// Rows index by Q's dimension 2, and the first N rows are valid.
D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
{
- uint32_t offset = (iq2 + r) * D + c;
+ uint32_t offset = (iq2 + r) * HSV + c;
data_o[o_offset + offset] = D_TYPE(elem);
return elem;
}
@@ -44,14 +46,14 @@ const uint32_t MatBc = 16;
shared FLOAT_TYPE tmpsh[gl_WorkGroupSize.x];
shared ACC_TYPEV4 tmpshv4[gl_WorkGroupSize.x];
-const uint32_t qstride = D / 4 + 2; // in units of f16vec4
+const uint32_t qstride = HSK / 4 + 2; // in units of f16vec4
shared f16vec4 Qf[Br * qstride];
-// Avoid padding for D==256 to make it fit in 48KB shmem.
-const uint32_t sfshstride = (D <= 128) ? (Br + 8) : Br;
+// Avoid padding for HSK > 128 to make it fit in 48KB shmem.
+const uint32_t sfshstride = (HSK <= 128) ? (Br + 8) : Br;
shared ACC_TYPE sfsh[Bc * sfshstride];
-const uint32_t kshstride = D / 4 + 2; // in units of f16vec4
+const uint32_t kshstride = HSK / 4 + 2; // in units of f16vec4
shared f16vec4 ksh[Bc * kshstride];
shared float slope[Br];
@@ -74,18 +76,18 @@ void main() {
uint32_t q_offset = (iq2*p.nb02+iq3*p.nb03) / 4;
- [[unroll]] for (uint32_t idx = 0; idx < Br * D / 4; idx += gl_WorkGroupSize.x) {
- uint32_t d = (idx + tid) % (D / 4);
- uint32_t r = (idx + tid) / (D / 4);
- if (r < Br && d < D / 4 &&
+ [[unroll]] for (uint32_t idx = 0; idx < Br * HSK / 4; idx += gl_WorkGroupSize.x) {
+ uint32_t d = (idx + tid) % (HSK / 4);
+ uint32_t r = (idx + tid) / (HSK / 4);
+ if (r < Br && d < HSK / 4 &&
i * Br + r < N) {
Qf[r * qstride + d] = f16vec4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d] * p.scale);
}
}
barrier();
- ACC_TYPEV4 Of[rows_per_thread][D_per_thread / 4];
- [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
+ ACC_TYPEV4 Of[rows_per_thread][HSV_per_thread / 4];
+ [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
Of[r][d] = ACC_TYPEV4(0.0);
}
@@ -127,10 +129,10 @@ void main() {
[[dont_unroll]]
for (uint32_t j = start_j; j < end_j; ++j) {
- [[unroll]] for (uint32_t idx = 0; idx < Bc * D / 4; idx += gl_WorkGroupSize.x) {
- uint32_t d = (idx + tid) % (D / 4);
- uint32_t c = (idx + tid) / (D / 4);
- if (c < Bc && d < D / 4) {
+ [[unroll]] for (uint32_t idx = 0; idx < Bc * HSK / 4; idx += gl_WorkGroupSize.x) {
+ uint32_t d = (idx + tid) % (HSK / 4);
+ uint32_t c = (idx + tid) / (HSK / 4);
+ if (c < Bc && d < HSK / 4) {
#if BLOCK_SIZE > 1
uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE + 4 * d;
uint ib = coord / BLOCK_SIZE;
@@ -145,14 +147,14 @@ void main() {
}
barrier();
- // K * Q^T -> S^T: Bc x D * D x Br -> Bc x Br
- // Bc split across workgroup (four subgroups), loop over D in chunks of 16: 16 x 16 * 16 x 16 -> 16 x 16
+ // K * Q^T -> S^T: Bc x HSK * HSK x Br -> Bc x Br
+ // Bc split across workgroup (four subgroups), loop over HSK in chunks of 16: 16 x 16 * 16 x 16 -> 16 x 16
// This is written transposed in order to allow for N being 8 if implementations need it
coopmat<ACC_TYPE, gl_ScopeSubgroup, MatBc, MatBr, gl_MatrixUseAccumulator> SfMat = coopmat<ACC_TYPE, gl_ScopeSubgroup, MatBc, MatBr, gl_MatrixUseAccumulator>(0);
coopmat<float16_t, gl_ScopeSubgroup, MatBc, 16, gl_MatrixUseA> KMat;
coopmat<float16_t, gl_ScopeSubgroup, 16, MatBr, gl_MatrixUseB> QMat;
- for (uint32_t d = 0; d < D / 16; ++d) {
+ for (uint32_t d = 0; d < HSK / 16; ++d) {
coopMatLoad(QMat, Qf, d * 16 / 4, qstride, gl_CooperativeMatrixLayoutColumnMajor);
uint coord = (gl_SubgroupID * MatBc) * kshstride + d * 16 / 4;
@@ -202,7 +204,7 @@ void main() {
eMf[r] = exp(Moldf - Mf[r]);
}
- [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
+ [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
Of[r][d] = float16_t(eMf[r]) * Of[r][d];
}
@@ -217,7 +219,7 @@ void main() {
Pf[r] = exp(sfsh[tile_row(r) + (c * cols_per_iter + col_tid) * sfshstride] - Mf[r]);
Lf[r] += Pf[r];
}
- [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
+ [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
#if BLOCK_SIZE > 1
uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
uint ib = coord / BLOCK_SIZE;
@@ -280,7 +282,7 @@ void main() {
}
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
- [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
+ [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
Of[r][d] = float16_t(eMf[r]) * Of[r][d];
tmpshv4[tid] = Of[r][d];
@@ -300,11 +302,11 @@ void main() {
// If there is split_k, then the split_k resolve shader does the final
// division by L. Store the intermediate O value and per-row m and L values.
if (p.k_num > 1) {
- uint32_t o_offset = D * p.ne1 * split_k_index;
+ uint32_t o_offset = HSV * p.ne1 * (split_k_index + iq3 * p.k_num);
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
if (tile_row(r) < N) {
- [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
+ [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
[[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
perElemOpGqaStore(tile_row(r), 4*(d * D_split + d_tid) + comp, float(Of[r][d][comp]), o_offset, iq2, N);
}
@@ -312,7 +314,7 @@ void main() {
}
}
- o_offset = D * p.ne1 * p.k_num + p.ne1 * split_k_index * 2;
+ o_offset = HSV * p.ne1 * p.ne3 * p.k_num + p.ne1 * (split_k_index + iq3 * p.k_num) * 2;
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
if (tile_row(r) < N) {
perElemOpStoreCol0(tile_row(r), 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N);
@@ -328,18 +330,18 @@ void main() {
Lfrcp[r] = 1.0 / Lf[r];
}
- [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
+ [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
Of[r][d] *= float16_t(Lfrcp[r]);
}
}
- uint32_t o_offset = iq3*p.ne2*p.ne1;
+ uint32_t o_offset = iq3*p.ne2*p.ne1*HSV;
if (p.gqa_ratio > 1) {
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
if (tile_row(r) < N) {
- [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
+ [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
[[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
perElemOpGqaStore(tile_row(r), 4*(d * D_split + d_tid) + comp, float(Of[r][d][comp]), o_offset, iq2, N);
}
@@ -349,9 +351,9 @@ void main() {
} else {
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
if (i * Br + tile_row(r) < N) {
- [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
+ [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
[[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
- data_o[o_offset + iq2 * D + (i * Br + tile_row(r)) * p.ne1 * D + 4*(d * D_split + d_tid) + comp] = D_TYPE(Of[r][d][comp]);
+ data_o[o_offset + iq2 * HSV + (i * Br + tile_row(r)) * p.ne1 * HSV + 4*(d * D_split + d_tid) + comp] = D_TYPE(Of[r][d][comp]);
}
}
}
diff --git a/ggml/src/vulkan-shaders/flash_attn_cm2.comp b/ggml/src/vulkan-shaders/flash_attn_cm2.comp
index 6acf67a0..91caa184 100644
--- a/ggml/src/vulkan-shaders/flash_attn_cm2.comp
+++ b/ggml/src/vulkan-shaders/flash_attn_cm2.comp
@@ -61,8 +61,8 @@ ACC_TYPE Max(const in uint32_t row, const in uint32_t col, const in ACC_TYPE ele
// Rows index by Q's dimension 2, and the first N rows are valid.
D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
{
- if (r < N && c < D) {
- uint32_t offset = (iq2 + r) * D + c;
+ if (r < N && c < HSV) {
+ uint32_t offset = (iq2 + r) * HSV + c;
data_o[o_offset + offset] = D_TYPE(elem);
}
return elem;
@@ -86,9 +86,9 @@ void main() {
tensorLayoutV = setTensorLayoutBlockSizeNV(tensorLayoutV, 1, BLOCK_SIZE);
#endif
- tensorLayoutQ = setTensorLayoutDimensionNV(tensorLayoutQ, N, D);
- tensorLayoutK = setTensorLayoutDimensionNV(tensorLayoutK, KV, D);
- tensorLayoutV = setTensorLayoutDimensionNV(tensorLayoutV, KV, D);
+ tensorLayoutQ = setTensorLayoutDimensionNV(tensorLayoutQ, N, HSK);
+ tensorLayoutK = setTensorLayoutDimensionNV(tensorLayoutK, KV, HSK);
+ tensorLayoutV = setTensorLayoutDimensionNV(tensorLayoutV, KV, HSV);
// hint to the compiler that strides are aligned for the aligned variant of the shader
if (Clamp != gl_CooperativeMatrixClampModeConstantNV)
@@ -104,16 +104,16 @@ void main() {
tensorLayoutK = setTensorLayoutStrideNV(tensorLayoutK, k_stride, 1);
tensorLayoutV = setTensorLayoutStrideNV(tensorLayoutV, v_stride, 1);
- coopmat<Q_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> Q;
- coopmat<float16_t, gl_ScopeWorkgroup, Br, D, gl_MatrixUseA> Qf16;
+ coopmat<Q_TYPE, gl_ScopeWorkgroup, Br, HSK, gl_MatrixUseAccumulator> Q;
+ coopmat<float16_t, gl_ScopeWorkgroup, Br, HSK, gl_MatrixUseA> Qf16;
uint32_t q_offset = iq2*p.nb02+iq3*p.nb03;
- coopMatLoadTensorNV(Q, data_q, q_offset, sliceTensorLayoutNV(tensorLayoutQ, i * Br, Br, 0, D));
+ coopMatLoadTensorNV(Q, data_q, q_offset, sliceTensorLayoutNV(tensorLayoutQ, i * Br, Br, 0, HSK));
- Qf16 = coopmat<float16_t, gl_ScopeWorkgroup, Br, D, gl_MatrixUseA>(Q);
+ Qf16 = coopmat<float16_t, gl_ScopeWorkgroup, Br, HSK, gl_MatrixUseA>(Q);
Qf16 *= float16_t(p.scale);
- coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> O = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator>(0);
+ coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator> O = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator>(0);
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> L, M;
@@ -135,10 +135,10 @@ void main() {
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> S = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0);
- coopmat<float16_t, gl_ScopeWorkgroup, D, Bc, gl_MatrixUseB> K_T;
+ coopmat<float16_t, gl_ScopeWorkgroup, HSK, Bc, gl_MatrixUseB> K_T;
uint32_t k_offset = ik2*p.nb12 + ik3*p.nb13;
- coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, D), tensorViewTranspose DECODEFUNC);
+ coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, HSK), tensorViewTranspose DECODEFUNC);
S = coopMatMulAdd(Qf16, K_T, S);
if (p.logit_softcap != 0.0f) {
@@ -203,42 +203,42 @@ void main() {
rowsum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0.0);
rowsum = coopMatMulAdd(P_A, One, rowsum);
- coopmat<float16_t, gl_ScopeWorkgroup, Bc, D, gl_MatrixUseB> V;
+ coopmat<float16_t, gl_ScopeWorkgroup, Bc, HSV, gl_MatrixUseB> V;
uint32_t v_offset = iv2*p.nb22 + iv3*p.nb23;
- coopMatLoadTensorNV(V, data_v, v_offset, sliceTensorLayoutNV(tensorLayoutV, j * Bc, Bc, 0, D) DECODEFUNC);
+ coopMatLoadTensorNV(V, data_v, v_offset, sliceTensorLayoutNV(tensorLayoutV, j * Bc, Bc, 0, HSV) DECODEFUNC);
L = eM*L + rowsum;
// This is the "diagonal" matrix in the paper, but since we do componentwise
// multiply rather than matrix multiply it has the diagonal element smeared
// across the row
- coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> eMdiag;
+ coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator> eMdiag;
// resize eM by using smear/reduce
coopMatReduceNV(eMdiag, eM, gl_CooperativeMatrixReduceRowNV, smearReduce);
// multiply with fp16 accumulation, then add to O.
- coopmat<float16_t, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> PV = coopmat<float16_t, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator>(0);
+ coopmat<float16_t, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator> PV = coopmat<float16_t, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator>(0);
PV = coopMatMulAdd(P_A, V, PV);
- O = eMdiag * O + coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator>(PV);
+ O = eMdiag * O + coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator>(PV);
}
// If there is split_k, then the split_k resolve shader does the final
// division by L. Store the intermediate O value and per-row m and L values.
if (p.k_num > 1) {
- coopmat<D_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator>(O);
+ coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator>(O);
- uint32_t o_offset = D * p.ne1 * split_k_index;
+ uint32_t o_offset = HSV * p.ne1 * (split_k_index + iq3 * p.k_num);
coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N);
- o_offset = D * p.ne1 * p.k_num + p.ne1 * split_k_index * 2;
+ o_offset = HSV * p.ne1 * p.ne3 * p.k_num + p.ne1 * (split_k_index + iq3 * p.k_num) * 2;
coopMatPerElementNV(L, L, perElemOpStoreCol0, o_offset, iq2, N);
coopMatPerElementNV(M, M, perElemOpStoreCol0, o_offset + p.ne1, iq2, N);
return;
}
- coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> Ldiag;
+ coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator> Ldiag;
// resize L by using smear/reduce
coopMatReduceNV(Ldiag, L, gl_CooperativeMatrixReduceRowNV, smearReduce);
@@ -250,18 +250,18 @@ void main() {
O = Ldiag*O;
- uint32_t o_offset = iq3*p.ne2*p.ne1;
+ uint32_t o_offset = iq3*p.ne2*p.ne1*HSV;
- coopmat<D_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator>(O);
+ coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator>(O);
if (p.gqa_ratio > 1) {
coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N);
} else {
tensorLayoutNV<3, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutD = createTensorLayoutNV(3, gl_CooperativeMatrixClampModeConstantNV);
- tensorLayoutD = setTensorLayoutDimensionNV(tensorLayoutD, p.ne2, p.ne1, D);
+ tensorLayoutD = setTensorLayoutDimensionNV(tensorLayoutD, p.ne2, p.ne1, HSV);
// permute dimensions
tensorViewNV<3, false, 1, 0, 2> tensorViewPermute = createTensorViewNV(3, false, 1, 0, 2);
- coopMatStoreTensorNV(O_D, data_o, o_offset, sliceTensorLayoutNV(tensorLayoutD, i * Br, Br, iq2, N, 0, D), tensorViewPermute);
+ coopMatStoreTensorNV(O_D, data_o, o_offset, sliceTensorLayoutNV(tensorLayoutD, i * Br, Br, iq2, N, 0, HSV), tensorViewPermute);
}
}