From f853f6c6a54c5b5b59ef2c029f90dc5e21a7beec Mon Sep 17 00:00:00 2001 From: Kawrakow <48489457+ikawrakow@users.noreply.github.com> Date: Fri, 13 Sep 2024 07:19:47 +0300 Subject: Fix bug and D < 128 case for Q8_0 k-cache (#52) Co-authored-by: Iwan Kawrakow --- ggml/src/iqk/iqk_mul_mat.cpp | 42 ++++++++++++++++++++++++++++++------------ 1 file changed, 30 insertions(+), 12 deletions(-) (limited to 'ggml/src') diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index ce868514..a70310d4 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -7492,7 +7492,11 @@ struct FlashQKfp32 { #ifdef __aarch64__ mul_mat_qX_0_q8_0(D, kh.block, kh.stride, info, k_step); #else - mul_mat_qX_0_q8_0_T(D, kh.block, kh.stride, info, k_step); + if constexpr (D >= 128) { + mul_mat_qX_0_q8_0_T(D, kh.block, kh.stride, info, k_step); + } else { + mul_mat_qX_0_q8_0_T(D, kh.block, kh.stride, info, k_step); + } #endif } else if constexpr (std::is_same_v>) { @@ -7523,7 +7527,7 @@ struct FlashQKfp32 { const block_q8 * q, const char * mask, FlashMS& fms) { GGML_ASSERT(nq < 8); if constexpr (std::is_same_v>) { - DataInfo info{fms.cache, (const char *)q, D*sizeof(float), (D/QK8_0)*sizeof(block_q8), 0, 1, nullptr}; + DataInfo info{fms.cache, (const char *)q, k_step, (D/QK8_0)*sizeof(block_q8), 0, 1, nullptr}; switch (nq) { #ifdef __aarch64__ case 1: mul_mat_qX_0_q8_0(D, kh.block, kh.stride, info, k_step); break; @@ -7545,9 +7549,9 @@ struct FlashQKfp32 { } } else if constexpr (std::is_same_v>) { - DataInfo info{fms.cache, (const char *)q, D*sizeof(float), (D/QK8_0)*sizeof(block_q8), 0, 1, nullptr}; - switch (nq) { + DataInfo info{fms.cache, (const char *)q, k_step, (D/QK8_0)*sizeof(block_q8), 0, 1, nullptr}; #ifdef __aarch64__ + switch (nq) { case 1: mul_mat_qX_0_q8_0(D, kh.block, kh.stride, info, k_step); break; case 2: mul_mat_qX_0_q8_0(D, kh.block, kh.stride, info, k_step); break; case 3: mul_mat_qX_0_q8_0(D, kh.block, kh.stride, info, k_step); break; @@ -7555,16 +7559,30 @@ struct FlashQKfp32 { case 5: mul_mat_qX_0_q8_0(D, kh.block, kh.stride, info, k_step); break; case 6: mul_mat_qX_0_q8_0(D, kh.block, kh.stride, info, k_step); break; case 7: mul_mat_qX_0_q8_0(D, kh.block, kh.stride, info, k_step); break; + } #else - case 1: mul_mat_qX_0_q8_0_T(D, kh.block, kh.stride, info, k_step); break; - case 2: mul_mat_qX_0_q8_0_T(D, kh.block, kh.stride, info, k_step); break; - case 3: mul_mat_qX_0_q8_0_T(D, kh.block, kh.stride, info, k_step); break; - case 4: mul_mat_qX_0_q8_0_T(D, kh.block, kh.stride, info, k_step); break; - case 5: mul_mat_qX_0_q8_0_T(D, kh.block, kh.stride, info, k_step); break; - case 6: mul_mat_qX_0_q8_0_T(D, kh.block, kh.stride, info, k_step); break; - case 7: mul_mat_qX_0_q8_0_T(D, kh.block, kh.stride, info, k_step); break; -#endif + if constexpr (D >= 128) { + switch (nq) { + case 1: mul_mat_qX_0_q8_0_T(D, kh.block, kh.stride, info, k_step); break; + case 2: mul_mat_qX_0_q8_0_T(D, kh.block, kh.stride, info, k_step); break; + case 3: mul_mat_qX_0_q8_0_T(D, kh.block, kh.stride, info, k_step); break; + case 4: mul_mat_qX_0_q8_0_T(D, kh.block, kh.stride, info, k_step); break; + case 5: mul_mat_qX_0_q8_0_T(D, kh.block, kh.stride, info, k_step); break; + case 6: mul_mat_qX_0_q8_0_T(D, kh.block, kh.stride, info, k_step); break; + case 7: mul_mat_qX_0_q8_0_T(D, kh.block, kh.stride, info, k_step); break; + } + } else { + switch (nq) { + case 1: mul_mat_qX_0_q8_0_T(D, kh.block, kh.stride, info, k_step); break; + case 2: mul_mat_qX_0_q8_0_T(D, kh.block, kh.stride, info, k_step); break; + case 3: mul_mat_qX_0_q8_0_T(D, kh.block, kh.stride, info, k_step); break; + case 4: mul_mat_qX_0_q8_0_T(D, kh.block, kh.stride, info, k_step); break; + case 5: mul_mat_qX_0_q8_0_T(D, kh.block, kh.stride, info, k_step); break; + case 6: mul_mat_qX_0_q8_0_T(D, kh.block, kh.stride, info, k_step); break; + case 7: mul_mat_qX_0_q8_0_T(D, kh.block, kh.stride, info, k_step); break; + } } +#endif } else if constexpr (std::is_same_v>) { DataInfo info{fms.cache, (const char *)q, k_step, (D/QK8_1)*sizeof(block_q8), 0, 1, nullptr}; -- cgit v1.2.3