diff options
author | Kawrakow <iwankawrakow@gmail.com> | 2025-02-22 14:25:38 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2025-02-22 14:25:38 +0200 |
commit | 49261058442cfe382dab3270fcd86652296a75c0 (patch) | |
tree | 01447c36a1ddf2045d291484248194345d4368c6 | |
parent | 33646fc40949e0fdcf16d96f5b40d12bf93244a9 (diff) |
Fix #217 (#220)
* Fix #217
* Remove stuff committed by mistake
---------
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
-rw-r--r-- | ggml/src/iqk/iqk_mul_mat.cpp | 11 |
1 file changed, 3 insertions, 8 deletions
diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index fffddada..2c4987a0 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -15841,23 +15841,18 @@ struct FlashQKfp32 { #endif constexpr int qrem = q_step - nrc_q*(q_step/nrc_q); constexpr int krem = k_step - nrc_k*(k_step/nrc_k); + static_assert(krem == 0); DataInfo info{fms.cache, (const char *)q, k_step, stride_q*sizeof(q_float), 0, 1, nullptr}; for (int iq = 0; iq < q_step/nrc_q; ++iq) { for (int ik = 0; ik < k_step/nrc_k; ++ik) { mul_mat_Qx_Qy_MxN_fa4<QFT<q_float, nrc_q>, QFT<ggml_half, nrc_k>>(D, kh.block, kh.stride, ik*nrc_k, info); } - if constexpr (krem > 0) { - mul_mat_Qx_Qy_MxN_fa<QFT<q_float, nrc_q>, QFT<ggml_half, krem>>(D, kh.block, kh.stride, k_step - krem, info); - } info.cur_y += nrc_q; } if constexpr (qrem > 0) { for (int ik = 0; ik < k_step/nrc_k; ++ik) { mul_mat_Qx_Qy_MxN_fa4<QFT<q_float, qrem>, QFT<ggml_half, nrc_k>>(D, kh.block, kh.stride, ik*nrc_k, info); } - if constexpr (krem > 0) { - mul_mat_Qx_Qy_MxN_fa<QFT<q_float, qrem>, QFT<ggml_half, krem>>(D, kh.block, kh.stride, k_step - krem, info); - } } F16::Data vk[k_step/F16::block_size]; for (int j = 0; j < q_step; ++j) { @@ -15910,7 +15905,7 @@ struct FlashQKfp32 { constexpr int nrc_k = 8; #endif static_assert(k_step%nrc_k == 0); - int qrem = q_step - nrc_q*(q_step/nrc_q); + int qrem = nq - nrc_q*(nq/nrc_q); DataInfo info{fms.cache, (const char *)q, k_step, stride_q*sizeof(q_float), 0, 1, nullptr}; for (int iq = 0; iq < nq/nrc_q; ++iq) { for (int ik = 0; ik < k_step/nrc_k; ++ik) { @@ -15960,7 +15955,7 @@ struct FlashQKfp32 { } } F16::Data vk[k_step/F16::block_size]; - for (int j = 0; j < q_step; ++j) { + for (int j = 0; j < nq; ++j) { fms.update_M_S(j, vk, mask + stride_m*j); } } |