| author | Iwan Kawrakow <iwan.kawrakow@gmail.com> | 2024-06-20 19:23:10 +0300 |
|---|---|---|
| committer | Iwan Kawrakow <iwan.kawrakow@gmail.com> | 2024-06-22 12:02:52 +0300 |
| commit | 729ba46f774a4ba9af48ce6708da653ee80d2296 (patch) | |
| tree | 1a82e24c139d77a8879d85bc8bbb9e3ec8f18ccd | |
| parent | f0325c5826c55bb9796485d49bc971a17735e96a (diff) | |
bitnet(scale in a separate tensor): CPU tweaks
I had ruined TG performance on AVX2 with the last commit.
I had only been testing at 8 threads, where we are completely
memory bound. But at 4 threads we had regressed to 41 t/s on
the Ryzen 7950X. Back to 51 t/s with this commit.
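What the change does: token generation is the nrc_y == 1 case, where all dot products feed a single accumulator register, so consecutive blocks form one latency-bound dependency chain. The commit adds a dedicated nrc_y == 1 path that processes two blocks per iteration into two independent accumulators and merges them once at the end. As a minimal scalar sketch of that pattern (illustrative only, not the committed kernel):

```cpp
// Sketch of the two-accumulator idea behind this commit, in scalar form.
#include <cstdint>
#include <cstddef>

// One chain: every add waits on the previous result, so the loop runs at
// the latency of the add rather than at its throughput.
int64_t sum_one_chain(const int32_t* x, size_t n) {
    int64_t acc = 0;
    for (size_t i = 0; i < n; ++i) acc += x[i];
    return acc;
}

// Two independent chains, merged once at the end -- the same pattern the
// nrc_y == 1 branch below uses with __m256i acc[2] and a final
// _mm256_add_epi32(acc[0], acc[1]).
int64_t sum_two_chains(const int32_t* x, size_t n) {
    int64_t acc0 = 0, acc1 = 0;
    for (size_t i = 0; i + 1 < n; i += 2) {
        acc0 += x[i];
        acc1 += x[i + 1];
    }
    if (n & 1) acc0 += x[n - 1];  // odd tail
    return acc0 + acc1;
}
```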
-rw-r--r-- | iqk_mul_mat.cpp | 85

1 file changed, 46 insertions(+), 39 deletions(-)
```diff
diff --git a/iqk_mul_mat.cpp b/iqk_mul_mat.cpp
index 1d382a41..7f57593b 100644
--- a/iqk_mul_mat.cpp
+++ b/iqk_mul_mat.cpp
@@ -1478,54 +1478,61 @@ IQK_NOINLINE void mul_mat_iq2bn_q8_K64(int n, const void * vx, size_t bx, const
         const block_iq2_bn * x = (const block_iq2_bn *)((const char *)vx + ix*bx);
-        {
-            auto q2bits_1 = _mm_loadu_si128((const __m128i *)x[0].qs);
-            auto q2bits_2 = _mm_loadu_si128((const __m128i *)x[1].qs);
-            auto q2_1 = MM256_SET_M128I(_mm_srli_epi16(q2bits_1, 2), q2bits_1);
-            auto q2_2 = MM256_SET_M128I(_mm_srli_epi16(q2bits_2, 2), q2bits_2);
-            auto v1 = _mm256_sub_epi8(_mm256_and_si256(q2_1, mask2), m1_8);
-            auto v3 = _mm256_sub_epi8(_mm256_and_si256(q2_2, mask2), m1_8);
-            auto v2 = _mm256_sub_epi8(_mm256_and_si256(_mm256_srli_epi16(q2_1, 4), mask2), m1_8);
-            auto v4 = _mm256_sub_epi8(_mm256_and_si256(_mm256_srli_epi16(q2_2, 4), mask2), m1_8);
-            for (int iy = 0; iy < nrc_y; ++iy) {
-                auto dot1 = _mm256_sign_epi8(q8.load_quants(iy, 0, 0), v1);
-                auto dot2 = _mm256_sign_epi8(q8.load_quants(iy, 0, 1), v2);
-                auto dot3 = _mm256_sign_epi8(q8.load_quants(iy, 0, 2), v3);
-                auto dot4 = _mm256_sign_epi8(q8.load_quants(iy, 0, 3), v4);
+        if constexpr (nrc_y == 1) {
+            __m256i acc[2] = {};
+            for (int i = 0; i < nb/2; ++i) {
+                auto q2bits_1 = _mm_loadu_si128((const __m128i *)x[2*i+0].qs);
+                auto q2bits_2 = _mm_loadu_si128((const __m128i *)x[2*i+1].qs);
+                auto q2_1 = MM256_SET_M128I(_mm_srli_epi16(q2bits_1, 2), q2bits_1);
+                auto q2_2 = MM256_SET_M128I(_mm_srli_epi16(q2bits_2, 2), q2bits_2);
+                auto v1 = _mm256_sub_epi8(_mm256_and_si256(q2_1, mask2), m1_8);
+                auto v3 = _mm256_sub_epi8(_mm256_and_si256(q2_2, mask2), m1_8);
+                auto v2 = _mm256_sub_epi8(_mm256_and_si256(_mm256_srli_epi16(q2_1, 4), mask2), m1_8);
+                auto v4 = _mm256_sub_epi8(_mm256_and_si256(_mm256_srli_epi16(q2_2, 4), mask2), m1_8);
 #if defined __AVX512VNNI__ && defined __AVX512VL__
-                accd[iy] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_dpbusd_epi32(
-                                _mm256_setzero_si256(), m1_8, dot1), m1_8, dot2), m1_8, dot3), m1_8, dot4);
+                acc[0] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(acc[0], m1_8, _mm256_sign_epi8(q8.load_quants(0, i, 0), v1)),
+                                             m1_8, _mm256_sign_epi8(q8.load_quants(0, i, 1), v2));
+                acc[1] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(acc[1], m1_8, _mm256_sign_epi8(q8.load_quants(0, i, 2), v3)),
+                                             m1_8, _mm256_sign_epi8(q8.load_quants(0, i, 3), v4));
 #else
-                accd[iy] = _mm256_madd_epi16(m1_16, _mm256_add_epi16(
-                            _mm256_add_epi16(_mm256_maddubs_epi16(m1_8, dot1), _mm256_maddubs_epi16(m1_8, dot2)),
-                            _mm256_add_epi16(_mm256_maddubs_epi16(m1_8, dot3), _mm256_maddubs_epi16(m1_8, dot4))));
+                auto dot1 = _mm256_add_epi16(_mm256_maddubs_epi16(m1_8, _mm256_sign_epi8(q8.load_quants(0, i, 0), v1)),
+                                             _mm256_maddubs_epi16(m1_8, _mm256_sign_epi8(q8.load_quants(0, i, 1), v2)));
+                auto dot2 = _mm256_add_epi16(_mm256_maddubs_epi16(m1_8, _mm256_sign_epi8(q8.load_quants(0, i, 2), v3)),
+                                             _mm256_maddubs_epi16(m1_8, _mm256_sign_epi8(q8.load_quants(0, i, 3), v4)));
+                acc[0] = _mm256_add_epi32(acc[0], _mm256_madd_epi16(m1_16, dot1));
+                acc[1] = _mm256_add_epi32(acc[1], _mm256_madd_epi16(m1_16, dot2));
 #endif
             }
+            accd[0] = _mm256_add_epi32(acc[0], acc[1]);
         }
+        else {
-        for (int i = 1; i < nb/2; ++i) {
-            auto q2bits_1 = _mm_loadu_si128((const __m128i *)x[2*i+0].qs);
-            auto q2bits_2 = _mm_loadu_si128((const __m128i *)x[2*i+1].qs);
-            auto q2_1 = MM256_SET_M128I(_mm_srli_epi16(q2bits_1, 2), q2bits_1);
-            auto q2_2 = MM256_SET_M128I(_mm_srli_epi16(q2bits_2, 2), q2bits_2);
-            auto v1 = _mm256_sub_epi8(_mm256_and_si256(q2_1, mask2), m1_8);
-            auto v3 = _mm256_sub_epi8(_mm256_and_si256(q2_2, mask2), m1_8);
-            auto v2 = _mm256_sub_epi8(_mm256_and_si256(_mm256_srli_epi16(q2_1, 4), mask2), m1_8);
-            auto v4 = _mm256_sub_epi8(_mm256_and_si256(_mm256_srli_epi16(q2_2, 4), mask2), m1_8);
-            for (int iy = 0; iy < nrc_y; ++iy) {
-                auto dot1 = _mm256_sign_epi8(q8.load_quants(iy, i, 0), v1);
-                auto dot2 = _mm256_sign_epi8(q8.load_quants(iy, i, 1), v2);
-                auto dot3 = _mm256_sign_epi8(q8.load_quants(iy, i, 2), v3);
-                auto dot4 = _mm256_sign_epi8(q8.load_quants(iy, i, 3), v4);
+            for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm256_setzero_si256();
+
+            for (int i = 0; i < nb/2; ++i) {
+                auto q2bits_1 = _mm_loadu_si128((const __m128i *)x[2*i+0].qs);
+                auto q2bits_2 = _mm_loadu_si128((const __m128i *)x[2*i+1].qs);
+                auto q2_1 = MM256_SET_M128I(_mm_srli_epi16(q2bits_1, 2), q2bits_1);
+                auto q2_2 = MM256_SET_M128I(_mm_srli_epi16(q2bits_2, 2), q2bits_2);
+                auto v1 = _mm256_sub_epi8(_mm256_and_si256(q2_1, mask2), m1_8);
+                auto v3 = _mm256_sub_epi8(_mm256_and_si256(q2_2, mask2), m1_8);
+                auto v2 = _mm256_sub_epi8(_mm256_and_si256(_mm256_srli_epi16(q2_1, 4), mask2), m1_8);
+                auto v4 = _mm256_sub_epi8(_mm256_and_si256(_mm256_srli_epi16(q2_2, 4), mask2), m1_8);
+                for (int iy = 0; iy < nrc_y; ++iy) {
+                    auto dot1 = _mm256_sign_epi8(q8.load_quants(iy, i, 0), v1);
+                    auto dot2 = _mm256_sign_epi8(q8.load_quants(iy, i, 1), v2);
+                    auto dot3 = _mm256_sign_epi8(q8.load_quants(iy, i, 2), v3);
+                    auto dot4 = _mm256_sign_epi8(q8.load_quants(iy, i, 3), v4);
 #if defined __AVX512VNNI__ && defined __AVX512VL__
-                accd[iy] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_dpbusd_epi32(
-                                accd[iy], m1_8, dot1), m1_8, dot2), m1_8, dot3), m1_8, dot4);
+                    accd[iy] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_dpbusd_epi32(
+                                    accd[iy], m1_8, dot1), m1_8, dot2), m1_8, dot3), m1_8, dot4);
 #else
-                auto dot = _mm256_madd_epi16(m1_16, _mm256_add_epi16(
-                            _mm256_add_epi16(_mm256_maddubs_epi16(m1_8, dot1), _mm256_maddubs_epi16(m1_8, dot2)),
-                            _mm256_add_epi16(_mm256_maddubs_epi16(m1_8, dot3), _mm256_maddubs_epi16(m1_8, dot4))));
-                accd[iy] = _mm256_add_epi32(dot, accd[iy]);
+                    auto dot = _mm256_madd_epi16(m1_16, _mm256_add_epi16(
+                                _mm256_add_epi16(_mm256_maddubs_epi16(m1_8, dot1), _mm256_maddubs_epi16(m1_8, dot2)),
+                                _mm256_add_epi16(_mm256_maddubs_epi16(m1_8, dot3), _mm256_maddubs_epi16(m1_8, dot4))));
+                    accd[iy] = _mm256_add_epi32(dot, accd[iy]);
 #endif
+                }
             }
         }
         int i = 2*(nb/2);
```
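For context (not part of the commit): the kernel's inner dot product rests on the fact that the decoded iq2_bn weights are ternary (the 2-bit values minus 1, i.e. -1/0/+1), so q8*v equals `_mm256_sign_epi8(q8, v)` elementwise, and the products can then be reduced against an all-ones vector with maddubs/madd, or a single dpbusd under AVX512-VNNI. Below is a self-contained AVX2 sketch of that identity; the names `dot_ref` and `dot_avx2` are illustrative, not from the repository.

```cpp
// Standalone AVX2 sketch (illustrative, not repository code): the
// sign-trick dot product used by mul_mat_iq2bn_q8_K64, in isolation.
// Compile with: g++ -O2 -mavx2 demo.cpp
#include <immintrin.h>
#include <cstdint>
#include <cstdio>

// Scalar reference: ternary weights v in {-1,0,+1} times int8 activations.
static int dot_ref(const int8_t* v, const int8_t* q8, int n) {
    int sum = 0;
    for (int i = 0; i < n; ++i) sum += int(v[i]) * int(q8[i]);
    return sum;
}

// AVX2 version. Because v is ternary, q8[i]*v[i] == sign_epi8(q8, v)[i]:
// sign_epi8 negates where v < 0, zeroes where v == 0, and passes through
// where v > 0. The byte products are then summed against all-ones vectors,
// exactly as the kernel does with m1_8/m1_16.
static int dot_avx2(const int8_t* v, const int8_t* q8) {
    const __m256i m1_8  = _mm256_set1_epi8(1);
    const __m256i m1_16 = _mm256_set1_epi16(1);
    __m256i vv   = _mm256_loadu_si256((const __m256i*)v);
    __m256i qq   = _mm256_loadu_si256((const __m256i*)q8);
    __m256i prod = _mm256_sign_epi8(qq, vv);          // q8*v per byte
    __m256i s16  = _mm256_maddubs_epi16(m1_8, prod);  // pairwise 8 -> 16 bit
    __m256i s32  = _mm256_madd_epi16(m1_16, s16);     // pairwise 16 -> 32 bit
    // Horizontal sum of the eight 32-bit lanes.
    __m128i s = _mm_add_epi32(_mm256_castsi256_si128(s32),
                              _mm256_extracti128_si256(s32, 1));
    s = _mm_add_epi32(s, _mm_shuffle_epi32(s, 0x4E));
    s = _mm_add_epi32(s, _mm_shuffle_epi32(s, 0xB1));
    return _mm_cvtsi128_si32(s);
}

int main() {
    int8_t v[32], q8[32];
    for (int i = 0; i < 32; ++i) {
        v[i]  = int8_t(i % 3 - 1);   // ternary pattern -1, 0, +1, ...
        q8[i] = int8_t(i - 16);
    }
    printf("ref=%d avx2=%d\n", dot_ref(v, q8, 32), dot_avx2(v, q8));
    return 0;
}
```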