summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorIwan Kawrakow <iwan.kawrakow@gmail.com>2024-06-20 19:23:10 +0300
committerIwan Kawrakow <iwan.kawrakow@gmail.com>2024-06-22 12:02:52 +0300
commit729ba46f774a4ba9af48ce6708da653ee80d2296 (patch)
tree1a82e24c139d77a8879d85bc8bbb9e3ec8f18ccd
parentf0325c5826c55bb9796485d49bc971a17735e96a (diff)
bitnet(scale in a separate tensor): CPU tweaks
I had ruined TG performance on AVX2 with the last commit. Was just testing at 8 threads and there we are totally memory bound. But at 4 threads we had regressed to 41 t/s on the Ryzen7950. Back to 51 t/s with this commit.
-rw-r--r--iqk_mul_mat.cpp85
1 files changed, 46 insertions, 39 deletions
diff --git a/iqk_mul_mat.cpp b/iqk_mul_mat.cpp
index 1d382a41..7f57593b 100644
--- a/iqk_mul_mat.cpp
+++ b/iqk_mul_mat.cpp
@@ -1478,54 +1478,61 @@ IQK_NOINLINE void mul_mat_iq2bn_q8_K64(int n, const void * vx, size_t bx, const
const block_iq2_bn * x = (const block_iq2_bn *)((const char *)vx + ix*bx);
- {
- auto q2bits_1 = _mm_loadu_si128((const __m128i *)x[0].qs);
- auto q2bits_2 = _mm_loadu_si128((const __m128i *)x[1].qs);
- auto q2_1 = MM256_SET_M128I(_mm_srli_epi16(q2bits_1, 2), q2bits_1);
- auto q2_2 = MM256_SET_M128I(_mm_srli_epi16(q2bits_2, 2), q2bits_2);
- auto v1 = _mm256_sub_epi8(_mm256_and_si256(q2_1, mask2), m1_8);
- auto v3 = _mm256_sub_epi8(_mm256_and_si256(q2_2, mask2), m1_8);
- auto v2 = _mm256_sub_epi8(_mm256_and_si256(_mm256_srli_epi16(q2_1, 4), mask2), m1_8);
- auto v4 = _mm256_sub_epi8(_mm256_and_si256(_mm256_srli_epi16(q2_2, 4), mask2), m1_8);
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto dot1 = _mm256_sign_epi8(q8.load_quants(iy, 0, 0), v1);
- auto dot2 = _mm256_sign_epi8(q8.load_quants(iy, 0, 1), v2);
- auto dot3 = _mm256_sign_epi8(q8.load_quants(iy, 0, 2), v3);
- auto dot4 = _mm256_sign_epi8(q8.load_quants(iy, 0, 3), v4);
+ if constexpr (nrc_y == 1) {
+ __m256i acc[2] = {};
+ for (int i = 0; i < nb/2; ++i) {
+ auto q2bits_1 = _mm_loadu_si128((const __m128i *)x[2*i+0].qs);
+ auto q2bits_2 = _mm_loadu_si128((const __m128i *)x[2*i+1].qs);
+ auto q2_1 = MM256_SET_M128I(_mm_srli_epi16(q2bits_1, 2), q2bits_1);
+ auto q2_2 = MM256_SET_M128I(_mm_srli_epi16(q2bits_2, 2), q2bits_2);
+ auto v1 = _mm256_sub_epi8(_mm256_and_si256(q2_1, mask2), m1_8);
+ auto v3 = _mm256_sub_epi8(_mm256_and_si256(q2_2, mask2), m1_8);
+ auto v2 = _mm256_sub_epi8(_mm256_and_si256(_mm256_srli_epi16(q2_1, 4), mask2), m1_8);
+ auto v4 = _mm256_sub_epi8(_mm256_and_si256(_mm256_srli_epi16(q2_2, 4), mask2), m1_8);
#if defined __AVX512VNNI__ && defined __AVX512VL__
- accd[iy] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_dpbusd_epi32(
- _mm256_setzero_si256(), m1_8, dot1), m1_8, dot2), m1_8, dot3), m1_8, dot4);
+ acc[0] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(acc[0], m1_8, _mm256_sign_epi8(q8.load_quants(0, i, 0), v1)),
+ m1_8, _mm256_sign_epi8(q8.load_quants(0, i, 1), v2));
+ acc[1] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(acc[1], m1_8, _mm256_sign_epi8(q8.load_quants(0, i, 2), v3)),
+ m1_8, _mm256_sign_epi8(q8.load_quants(0, i, 3), v4));
#else
- accd[iy] = _mm256_madd_epi16(m1_16, _mm256_add_epi16(
- _mm256_add_epi16(_mm256_maddubs_epi16(m1_8, dot1), _mm256_maddubs_epi16(m1_8, dot2)),
- _mm256_add_epi16(_mm256_maddubs_epi16(m1_8, dot3), _mm256_maddubs_epi16(m1_8, dot4))));
+ auto dot1 = _mm256_add_epi16(_mm256_maddubs_epi16(m1_8, _mm256_sign_epi8(q8.load_quants(0, i, 0), v1)),
+ _mm256_maddubs_epi16(m1_8, _mm256_sign_epi8(q8.load_quants(0, i, 1), v2)));
+ auto dot2 = _mm256_add_epi16(_mm256_maddubs_epi16(m1_8, _mm256_sign_epi8(q8.load_quants(0, i, 2), v3)),
+ _mm256_maddubs_epi16(m1_8, _mm256_sign_epi8(q8.load_quants(0, i, 3), v4)));
+ acc[0] = _mm256_add_epi32(acc[0], _mm256_madd_epi16(m1_16, dot1));
+ acc[1] = _mm256_add_epi32(acc[1], _mm256_madd_epi16(m1_16, dot2));
#endif
}
+ accd[0] = _mm256_add_epi32(acc[0], acc[1]);
}
+ else {
- for (int i = 1; i < nb/2; ++i) {
- auto q2bits_1 = _mm_loadu_si128((const __m128i *)x[2*i+0].qs);
- auto q2bits_2 = _mm_loadu_si128((const __m128i *)x[2*i+1].qs);
- auto q2_1 = MM256_SET_M128I(_mm_srli_epi16(q2bits_1, 2), q2bits_1);
- auto q2_2 = MM256_SET_M128I(_mm_srli_epi16(q2bits_2, 2), q2bits_2);
- auto v1 = _mm256_sub_epi8(_mm256_and_si256(q2_1, mask2), m1_8);
- auto v3 = _mm256_sub_epi8(_mm256_and_si256(q2_2, mask2), m1_8);
- auto v2 = _mm256_sub_epi8(_mm256_and_si256(_mm256_srli_epi16(q2_1, 4), mask2), m1_8);
- auto v4 = _mm256_sub_epi8(_mm256_and_si256(_mm256_srli_epi16(q2_2, 4), mask2), m1_8);
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto dot1 = _mm256_sign_epi8(q8.load_quants(iy, i, 0), v1);
- auto dot2 = _mm256_sign_epi8(q8.load_quants(iy, i, 1), v2);
- auto dot3 = _mm256_sign_epi8(q8.load_quants(iy, i, 2), v3);
- auto dot4 = _mm256_sign_epi8(q8.load_quants(iy, i, 3), v4);
+ for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm256_setzero_si256();
+
+ for (int i = 0; i < nb/2; ++i) {
+ auto q2bits_1 = _mm_loadu_si128((const __m128i *)x[2*i+0].qs);
+ auto q2bits_2 = _mm_loadu_si128((const __m128i *)x[2*i+1].qs);
+ auto q2_1 = MM256_SET_M128I(_mm_srli_epi16(q2bits_1, 2), q2bits_1);
+ auto q2_2 = MM256_SET_M128I(_mm_srli_epi16(q2bits_2, 2), q2bits_2);
+ auto v1 = _mm256_sub_epi8(_mm256_and_si256(q2_1, mask2), m1_8);
+ auto v3 = _mm256_sub_epi8(_mm256_and_si256(q2_2, mask2), m1_8);
+ auto v2 = _mm256_sub_epi8(_mm256_and_si256(_mm256_srli_epi16(q2_1, 4), mask2), m1_8);
+ auto v4 = _mm256_sub_epi8(_mm256_and_si256(_mm256_srli_epi16(q2_2, 4), mask2), m1_8);
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto dot1 = _mm256_sign_epi8(q8.load_quants(iy, i, 0), v1);
+ auto dot2 = _mm256_sign_epi8(q8.load_quants(iy, i, 1), v2);
+ auto dot3 = _mm256_sign_epi8(q8.load_quants(iy, i, 2), v3);
+ auto dot4 = _mm256_sign_epi8(q8.load_quants(iy, i, 3), v4);
#if defined __AVX512VNNI__ && defined __AVX512VL__
- accd[iy] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_dpbusd_epi32(
- accd[iy], m1_8, dot1), m1_8, dot2), m1_8, dot3), m1_8, dot4);
+ accd[iy] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_dpbusd_epi32(
+ accd[iy], m1_8, dot1), m1_8, dot2), m1_8, dot3), m1_8, dot4);
#else
- auto dot = _mm256_madd_epi16(m1_16, _mm256_add_epi16(
- _mm256_add_epi16(_mm256_maddubs_epi16(m1_8, dot1), _mm256_maddubs_epi16(m1_8, dot2)),
- _mm256_add_epi16(_mm256_maddubs_epi16(m1_8, dot3), _mm256_maddubs_epi16(m1_8, dot4))));
- accd[iy] = _mm256_add_epi32(dot, accd[iy]);
+ auto dot = _mm256_madd_epi16(m1_16, _mm256_add_epi16(
+ _mm256_add_epi16(_mm256_maddubs_epi16(m1_8, dot1), _mm256_maddubs_epi16(m1_8, dot2)),
+ _mm256_add_epi16(_mm256_maddubs_epi16(m1_8, dot3), _mm256_maddubs_epi16(m1_8, dot4))));
+ accd[iy] = _mm256_add_epi32(dot, accd[iy]);
#endif
+ }
}
}
int i = 2*(nb/2);