From e4dc3babb59da45a60dd3fdf1a9a45e1e9390f37 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Mon, 15 Jul 2024 13:46:07 +0300 Subject: iq1bn(no lookup): somewhat better MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit We now have for Bitnet-3B: | threads | test | t/s | | ------: | ------------: | ---------------: | | 16 | pp512 | 308.97 ± 1.89 | | 16 | tg128 | 58.80 ± 0.07 | | 8 | tg128 | 49.79 ± 1.23 | | 4 | tg128 | 28.85 ± 0.02 | | 2 | tg128 | 15.39 ± 0.01 | --- iqk_mul_mat.cpp | 61 +++++++++++++++++++++++++++++++++++++++++++++------------ 1 file changed, 48 insertions(+), 13 deletions(-) (limited to 'iqk_mul_mat.cpp') diff --git a/iqk_mul_mat.cpp b/iqk_mul_mat.cpp index 44d0703a..0f53e02c 100644 --- a/iqk_mul_mat.cpp +++ b/iqk_mul_mat.cpp @@ -1342,8 +1342,11 @@ template struct Q8_K64 { struct DequantizerIQ1BN { const __m256i m1_8 = _mm256_set1_epi8(1); - const __m128i mask1 = _mm_set1_epi8(0xf0); +#ifdef HAVE_FANCY_SIMD + const __m128i shifthh = _mm_set_epi16(5, 6, 7, 8, 9, 10, 11, 12); +#else const __m128i mulhh = _mm_set_epi16(32, 64, 128, 256, 512, 1024, 2048, 4096); +#endif const __m128i maskhh = _mm_set1_epi16(4096); const __m256i shuffles[4] = { _mm256_set_epi64x(0x0302030203020302, 0x0302030203020302, 0x0100010001000100, 0x0100010001000100), @@ -1353,22 +1356,54 @@ struct DequantizerIQ1BN { }; const __m256i mult = _mm256_set_epi16(8, 24, 72, 216, 648, 1944, 5832, 17496, 8, 24, 72, 216, 648, 1944, 5832, 17496); const __m256i m3 = _mm256_set1_epi16(3); + const __m128i shuff_l = _mm_set_epi8(-128, 8, -128, 7, -128, 6, -128, 5, -128, 4, -128, 3, -128, 2, -128, 1); + const __m128i shuff_h = _mm_set_epi8(12, -128, 11, -128, 10, -128, 9, -128, 12, -128, 11, -128, 10, -128, 9, -128); + const __m128i shift_h = _mm_set_epi32(4, 4, 0, 0); + const __m128i mask_h = _mm_set1_epi16(0x0f00); + const __m128i shuff_hh = _mm_set_epi8(-128, 0, -128, 0, -128, 0, -128, 0, -128, 0, -128, 0, -128, 0, -128, 0); +#ifdef HAVE_FANCY_SIMD + const __m256i bmask = _mm256_set_epi8(62, 60, 58, 56, 54, 52, 50, 48, 46, 44, 42, 40, 38, 36, 34, 32, 30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0); +#endif - IQK_ALWAYS_INLINE void prepare_iq1bn_quants(uint8_t extra, const uint8_t * ql, const uint8_t * qh, __m256i& v1, __m256i& v2) { - - auto aux1 = _mm_cvtepu8_epi16(_mm_loadl_epi64((const __m128i *)ql)); - uint32_t aux32; std::memcpy(&aux32, qh, 4); - auto aux2 = _mm_cvtepu8_epi16(_mm_and_si128(_mm_set_epi32(aux32, aux32, aux32, aux32 << 4), mask1)); - auto aux3 = _mm_and_si128(_mm_mullo_epi16(_mm_set1_epi16(extra), mulhh), maskhh); - auto all128 = _mm_or_si128(_mm_slli_epi16(aux2, 4), _mm_or_si128(aux1, aux3)); + IQK_ALWAYS_INLINE void prepare_iq1bn_quants(const block_iq1_bn * x, __m256i& v1, __m256i& v2) { + auto data = _mm_loadu_si128((const __m128i *)x); // Note: we load 16 instead of 13 bytes! + auto aux1 = _mm_shuffle_epi8(data, shuff_l); + auto aux2 = _mm_and_si128(_mm_srlv_epi32(_mm_shuffle_epi8(data, shuff_h), shift_h), mask_h); +#ifdef HAVE_FANCY_SIMD + auto aux3 = _mm_and_si128(_mm_sllv_epi16(_mm_shuffle_epi8(data, shuff_hh), shifthh), maskhh); +#else + auto aux3 = _mm_and_si128(_mm_mullo_epi16(_mm_shuffle_epi8(data, shuff_hh), mulhh), maskhh); +#endif + auto all128 = _mm_or_si128(_mm_or_si128(aux1, aux2), aux3); auto all = MM256_SET_M128I(all128, all128); auto val1 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_shuffle_epi8(all, shuffles[0]), mult), m3); auto val2 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_shuffle_epi8(all, shuffles[1]), mult), m3); auto val3 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_shuffle_epi8(all, shuffles[2]), mult), m3); auto val4 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_shuffle_epi8(all, shuffles[3]), mult), m3); +#ifdef HAVE_FANCY_SIMD + v1 = _mm256_sub_epi8(_mm256_permutex2var_epi8(val1, bmask, val2), m1_8); + v2 = _mm256_sub_epi8(_mm256_permutex2var_epi8(val3, bmask, val4), m1_8); +#else v1 = _mm256_sub_epi8(_mm256_permute4x64_epi64(_mm256_packs_epi16(val1, val2), 216), m1_8); v2 = _mm256_sub_epi8(_mm256_permute4x64_epi64(_mm256_packs_epi16(val3, val4), 216), m1_8); +#endif } + + //IQK_ALWAYS_INLINE void prepare_iq1bn_quants(uint8_t extra, const uint8_t * ql, const uint8_t * qh, __m256i& v1, __m256i& v2) { + + // auto aux1 = _mm_cvtepu8_epi16(_mm_loadl_epi64((const __m128i *)ql)); + // uint32_t aux32; std::memcpy(&aux32, qh, 4); + // auto aux2 = _mm_cvtepu8_epi16(_mm_and_si128(_mm_set_epi32(aux32, aux32, aux32, aux32 << 4), mask1)); + // auto aux3 = _mm_and_si128(_mm_mullo_epi16(_mm_set1_epi16(extra), mulhh), maskhh); + // auto all128 = _mm_or_si128(_mm_slli_epi16(aux2, 4), _mm_or_si128(aux1, aux3)); + // auto all = MM256_SET_M128I(all128, all128); + // auto val1 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_shuffle_epi8(all, shuffles[0]), mult), m3); + // auto val2 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_shuffle_epi8(all, shuffles[1]), mult), m3); + // auto val3 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_shuffle_epi8(all, shuffles[2]), mult), m3); + // auto val4 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_shuffle_epi8(all, shuffles[3]), mult), m3); + // v1 = _mm256_sub_epi8(_mm256_permute4x64_epi64(_mm256_packs_epi16(val1, val2), 216), m1_8); + // v2 = _mm256_sub_epi8(_mm256_permute4x64_epi64(_mm256_packs_epi16(val3, val4), 216), m1_8); + //} }; template @@ -1392,8 +1427,8 @@ IQK_NOINLINE void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const if constexpr (nrc_y == 1) { __m256i acc1 = _mm256_setzero_si256(), acc2 = _mm256_setzero_si256(); for (int i = 0; i < nb/2; ++i) { - deq.prepare_iq1bn_quants(x[2*i+0].extra, x[2*i+0].ql, x[2*i+0].qh, val[0], val[1]); - deq.prepare_iq1bn_quants(x[2*i+1].extra, x[2*i+1].ql, x[2*i+1].qh, val[2], val[3]); + deq.prepare_iq1bn_quants(x + 2*i + 0, val[0], val[1]); + deq.prepare_iq1bn_quants(x + 2*i + 1, val[2], val[3]); #if defined __AVX512VNNI__ && defined __AVX512VL__ auto dot1 = _mm256_sign_epi8(q8.load_quants(0, i, 0), val[0]); auto dot2 = _mm256_sign_epi8(q8.load_quants(0, i, 1), val[1]); @@ -1418,8 +1453,8 @@ IQK_NOINLINE void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const for (int i = 0; i < nb/2; ++i) { - deq.prepare_iq1bn_quants(x[2*i+0].extra, x[2*i+0].ql, x[2*i+0].qh, val[0], val[1]); - deq.prepare_iq1bn_quants(x[2*i+1].extra, x[2*i+1].ql, x[2*i+1].qh, val[2], val[3]); + deq.prepare_iq1bn_quants(x + 2*i + 0, val[0], val[1]); + deq.prepare_iq1bn_quants(x + 2*i + 1, val[2], val[3]); for (int iy = 0; iy < nrc_y; ++iy) { #if defined __AVX512VNNI__ && defined __AVX512VL__ @@ -1442,7 +1477,7 @@ IQK_NOINLINE void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const } int i = 2*(nb/2); if (i < nb) { - deq.prepare_iq1bn_quants(x[i].extra, x[i].ql, x[i].qh, val[0], val[1]); + deq.prepare_iq1bn_quants(x + i, val[0], val[1]); for (int iy = 0; iy < nrc_y; ++iy) { auto dot1 = _mm256_sign_epi8(q8.load_quants(iy, i/2, 0), val[0]); auto dot2 = _mm256_sign_epi8(q8.load_quants(iy, i/2, 1), val[1]); -- cgit v1.2.3