summaryrefslogtreecommitdiff
path: root/ggml/src/iqk/iqk_flash_attn.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'ggml/src/iqk/iqk_flash_attn.cpp')
-rw-r--r--ggml/src/iqk/iqk_flash_attn.cpp2
1 files changed, 1 insertions, 1 deletions
diff --git a/ggml/src/iqk/iqk_flash_attn.cpp b/ggml/src/iqk/iqk_flash_attn.cpp
index 639b79af..0de68b94 100644
--- a/ggml/src/iqk/iqk_flash_attn.cpp
+++ b/ggml/src/iqk/iqk_flash_attn.cpp
@@ -154,7 +154,7 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float
}
if (neq3 == 1 && rk2 > 1 && rk2 == rv2 && neq1 == 1 && nth >= 1 && nek2*nek1 >= 32*nth) {
- int nk = 32 * (nek2*nek1/(32*nth));
+ int nk = std::max(1, 32 * (nek2*nek1/(32*nth)));
int nkk = (nek1 + nk - 1)/nk;
int nstep_k = nek2*nkk;
auto result_size = (Dv + 16)*rk2*sizeof(float);