diff options
-rw-r--r-- | ggml/src/iqk/iqk_mul_mat.cpp | 8 |
1 files changed, 7 insertions, 1 deletions
diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index 54792c12..3cb7573b 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -17242,7 +17242,7 @@ struct FlashAttn { q_size = GGML_PAD(q_size, 64); if (q_size > kMaxOnStackSize) { auto qptr = get_q_storage(q_size); - if (nq1 >= 8) { + if (false && nq1 >= 8) { if constexpr (std::is_same_v<KHelper, HelperQ80<Dk, k_step>>) { #if FA_TIMING auto t1 = Perf::cur_time(); @@ -17929,6 +17929,12 @@ inline void iqk_deepseek_helper(KHelper& kh, VHelper& vh, if (M && S) { M += n; S += n; } return false; }; + if (nq1 >= 16) { + int n_step = nq1/16; + FlashAttn<576, 512, 16, step_k> fa(scale, softcap); + fa.compute(kh, vh, 16*n_step, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S); + if (update(16*n_step)) return; + } if (nq1 >= 8) { int n_step = nq1/8; FlashAttn<576, 512, 8, step_k> fa(scale, softcap); |