From 1ea1df4b2d942ebd56efdcdfb922ec92d6dc1db7 Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Fri, 2 May 2025 07:09:09 +0200 Subject: Fix FA bug on AVX2 (#364) * Fix FA bug on AVX2 * Also this was wrong --------- Co-authored-by: Iwan Kawrakow --- ggml/src/iqk/iqk_mul_mat.cpp | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) (limited to 'ggml/src') diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index 5e49089a..6adc43cf 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -17120,11 +17120,12 @@ void compute_helper_q(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, vh.reset_block(); block_q8_0_r8 q8r8[Dk/QK8_0 * k_step/8]; HelperQ80R8 khr8((const char *)q8r8, Dk/QK8_0*sizeof(block_q8_0)); - HelperQ80::convert(q_step, stride_q, q, q8); + auto q8r = (typename HelperQ80R8::block_q8 *)qptr; + HelperQ80::convert(q_step, stride_q, q, q8r); auto mr = mask; for (int k1 = 0; k1 < nk1/k_step; ++k1) { - HelperQ80R8::repack(k_step, kh.data, kh.stride, q8r8); - KQHelper::mul_mask_kq(khr8, stride_m, q8, mr, fms); + HelperQ80R8::repack(k_step, kh.block, kh.stride, q8r8); + KQHelper::mul_mask_kq(khr8, stride_m, q8r, mr, fms); fqkv.accumulate_qkv(vh, fms); kh.next_block(); vh.next_block(); @@ -17236,7 +17237,8 @@ struct FlashAttn { std::is_same_v> || std::is_same_v>) { constexpr size_t kMaxOnStackSize = 576; - auto q_size = q_step*(Dk/KHelper::block_size_q)*sizeof(typename KHelper::block_q8); + //auto q_size = q_step*(Dk/KHelper::block_size_q)*sizeof(typename KHelper::block_q8); + auto q_size = q_step*(Dk/QK8_2*sizeof(block_q8_2)); q_size = GGML_PAD(q_size, 64); if (q_size > kMaxOnStackSize) { auto qptr = get_q_storage(q_size); -- cgit v1.2.3