diff options
author | Kawrakow <iwankawrakow@gmail.com> | 2025-05-13 17:53:20 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2025-05-13 17:53:20 +0300 |
commit | 553c08b6b47008928653d5e377211cd38dfaeffc (patch) | |
tree | 6d0f351230cdd1dec5f212df7e457fed2c4b0787 | |
parent | 4ba6bbb44a39c874ed4a98d982a4a975287e23e7 (diff) |
Better CPU FA performance for DeepSeek-Lite (#410)
* Better CPU FA performance for DeepSeek-Lite
* It must be like this
---------
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
-rw-r--r-- | ggml/src/iqk/iqk_mul_mat.cpp | 8 |
1 files changed, 7 insertions, 1 deletions
diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index 54792c12..3cb7573b 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -17242,7 +17242,7 @@ struct FlashAttn { q_size = GGML_PAD(q_size, 64); if (q_size > kMaxOnStackSize) { auto qptr = get_q_storage(q_size); - if (nq1 >= 8) { + if (false && nq1 >= 8) { if constexpr (std::is_same_v<KHelper, HelperQ80<Dk, k_step>>) { #if FA_TIMING auto t1 = Perf::cur_time(); @@ -17929,6 +17929,12 @@ inline void iqk_deepseek_helper(KHelper& kh, VHelper& vh, if (M && S) { M += n; S += n; } return false; }; + if (nq1 >= 16) { + int n_step = nq1/16; + FlashAttn<576, 512, 16, step_k> fa(scale, softcap); + fa.compute(kh, vh, 16*n_step, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S); + if (update(16*n_step)) return; + } if (nq1 >= 8) { int n_step = nq1/8; FlashAttn<576, 512, 8, step_k> fa(scale, softcap); |