summary | refs | log | tree | commit | diff
path: root/ggml/src/iqk/iqk_flash_attn.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'ggml/src/iqk/iqk_flash_attn.cpp')
-rw-r--r-- ggml/src/iqk/iqk_flash_attn.cpp | 2
1 files changed, 1 insertions, 1 deletions
diff --git a/ggml/src/iqk/iqk_flash_attn.cpp b/ggml/src/iqk/iqk_flash_attn.cpp
index fd0d5dd0..610f18b7 100644
--- a/ggml/src/iqk/iqk_flash_attn.cpp
+++ b/ggml/src/iqk/iqk_flash_attn.cpp
@@ -81,7 +81,7 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float
int int_type_k = int_type_k_in;
auto work_buffer = work_buffer_in;
- if (neq1 >= 8 || rk2 >= 8) {
+ if (neq1 >= 8 || (rk2 >= 8 && nek2 > 1)) {
uint64_t row_size = 0;
work_buffer = iqk_repack_k(int_type_k, Dk, nek1, nek2, nek3, stride_k, nbk2, nbk3, k, work_buffer_in, ith, nth, int_type_k, row_size);
if (int_type_k != int_type_k_in) {