author     Iwan Kawrakow <iwan.kawrakow@gmail.com>    2024-06-17 19:07:38 +0300
committer  Iwan Kawrakow <iwan.kawrakow@gmail.com>    2024-06-22 12:02:52 +0300
commit     661698513587fea89191a08b9c28a1f5619ebac8 (patch)
tree       209dc4d7a09d2523c0e5dd7b832e4f2bc51c3448
parent     f6863cfa1bbc5ac42b78837b355e45d82246a472 (diff)
bitnet 2 bpw: AVX2 implementation
We get PP-512 = 322 t/s.
TG is already 51.6 t/s at 4 threads; it then saturates and
starts dropping beyond 8 threads.
-rw-r--r--  iqk_mul_mat.cpp | 90
1 file changed, 90 insertions, 0 deletions
diff --git a/iqk_mul_mat.cpp b/iqk_mul_mat.cpp
index 41c920de..a264ba94 100644
--- a/iqk_mul_mat.cpp
+++ b/iqk_mul_mat.cpp
@@ -1407,6 +1407,84 @@ IQK_NOINLINE void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const
     }
 }
 
+template <int nrc_y>
+IQK_NOINLINE void mul_mat_iq2bn_q8_K64(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+    const int nb = n / QK_IQ1BN;
+    Q8_K64<nrc_y> q8(info);
+    __m256 accd[nrc_y];
+    __m256i signs[2];
+
+    const auto m1_8 = _mm256_set1_epi8(1);
+    const auto shuff1 = _mm256_set_epi64x(0x0808080808080808, 0x0000000000000000, 0x0808080808080808, 0x0000000000000000);
+    const auto shuff2 = _mm256_add_epi8(shuff1, m1_8);
+    const auto shuff3 = _mm256_set_epi64x(0x0303030303030303, 0x0202020202020202, 0x0101010101010101, 0x0000000000000000);
+    const auto shuff4 = _mm256_set_epi64x(0x0707070707070707, 0x0606060606060606, 0x0505050505050505, 0x0404040404040404);
+    const auto mask1 = _mm256_set1_epi64x(0x8040201008040201);
+    const auto mask2 = _mm256_set1_epi8(3);
+#if !(defined __AVX512VNNI__ && defined __AVX512VL__)
+    const auto m1_16 = _mm256_set1_epi16(1);
+#endif
+
+    for (int ix = 0; ix < nrc_x; ++ix) {
+
+        const block_iq2_bn * x = (const block_iq2_bn *)((const char *)vx + ix*bx);
+        float d = GGML_FP16_TO_FP32(*(const ggml_half *)x);
+        auto extra_ptr = (const uint16_t *)x;
+
+        auto all_signs = _mm256_set1_epi8(extra_ptr[1]);
+        all_signs = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(all_signs, mask1), mask1), m1_8);
+        signs[0] = _mm256_shuffle_epi8(all_signs, shuff3);
+        signs[1] = _mm256_shuffle_epi8(all_signs, shuff4);
+
+        auto ql = (const uint8_t *)(extra_ptr + 2);
+        auto qh = ql + QK_IQ1BN/8;
+        auto aux1 = _mm256_set_epi64x(iq1bn_grid_xxx[ql[3] | ((qh[1] << 4) & 0x0f00)], iq1bn_grid_xxx[ql[2] | ((qh[1] << 8) & 0x0f00)],
+                                      iq1bn_grid_xxx[ql[1] | ((qh[0] << 4) & 0x0f00)], iq1bn_grid_xxx[ql[0] | ((qh[0] << 8) & 0x0f00)]);
+        auto aux2 = _mm256_set_epi64x(iq1bn_grid_xxx[ql[7] | ((qh[3] << 4) & 0x0f00)], iq1bn_grid_xxx[ql[6] | ((qh[3] << 8) & 0x0f00)],
+                                      iq1bn_grid_xxx[ql[5] | ((qh[2] << 4) & 0x0f00)], iq1bn_grid_xxx[ql[4] | ((qh[2] << 8) & 0x0f00)]);
+
+        auto v1 = _mm256_sub_epi8(_mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(aux1, shuff2), mask1), mask1),
+                                  _mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(aux1, shuff1), mask1), mask1));
+        auto v2 = _mm256_sub_epi8(_mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(aux2, shuff2), mask1), mask1),
+                                  _mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(aux2, shuff1), mask1), mask1));
+
+        v1 = _mm256_sign_epi8(v1, signs[0]);
+        v2 = _mm256_sign_epi8(v2, signs[1]);
+        for (int iy = 0; iy < nrc_y; ++iy) {
+            auto dot1 = _mm256_sign_epi8(q8.load_quants(iy, 0, 0), v1);
+            auto dot2 = _mm256_sign_epi8(q8.load_quants(iy, 0, 1), v2);
+#if defined __AVX512VNNI__ && defined __AVX512VL__
+            auto dot = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_setzero_si256(), m1_8, dot1), m1_8, dot2);
+#else
+            auto dot = _mm256_madd_epi16(m1_16, _mm256_add_epi16(_mm256_maddubs_epi16(m1_8, dot1), _mm256_maddubs_epi16(m1_8, dot2)));
+#endif
+            accd[iy] = _mm256_mul_ps(_mm256_set1_ps(q8.scale(iy, 0)), _mm256_cvtepi32_ps(dot));
+        }
+
+        for (int i = 1; i < nb; ++i) {
+            auto q2bits = _mm_loadu_si128((const __m128i *)x[i].qs);
+            auto q2 = MM256_SET_M128I(_mm_srli_epi16(q2bits, 2), q2bits);
+            auto v1 = _mm256_sub_epi8(_mm256_and_si256(q2, mask2), m1_8);
+            auto v2 = _mm256_sub_epi8(_mm256_and_si256(_mm256_srli_epi16(q2, 4), mask2), m1_8);
+            for (int iy = 0; iy < nrc_y; ++iy) {
+                auto dot1 = _mm256_sign_epi8(q8.load_quants(iy, i, 0), v1);
+                auto dot2 = _mm256_sign_epi8(q8.load_quants(iy, i, 1), v2);
+#if defined __AVX512VNNI__ && defined __AVX512VL__
+                auto dot = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_setzero_si256(), m1_8, dot1), m1_8, dot2);
+#else
+                auto dot = _mm256_madd_epi16(m1_16, _mm256_add_epi16(_mm256_maddubs_epi16(m1_8, dot1), _mm256_maddubs_epi16(m1_8, dot2)));
+#endif
+                accd[iy] = _mm256_fmadd_ps(_mm256_set1_ps(q8.scale(iy, i)), _mm256_cvtepi32_ps(dot), accd[iy]);
+            }
+        }
+
+        for (int iy = 0; iy < nrc_y; ++iy) {
+            info.store(ix, iy, d * hsum_float_8(accd[iy]));
+        }
+
+    }
+}
+
 template <typename Dequantizer, int nrc_y>
 static void mul_mat_qX_K_q8_K_IQ(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
     assert(n % QK_K == 0);
@@ -2610,6 +2688,18 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) {
             mm.funcs[7] = mul_mat_iq1bn_q8_K64<8>;
             expected_typeB = GGML_TYPE_Q8_K64;
             break;
+        case GGML_TYPE_IQ2_BN:
+            assert (ne00 % QK_IQ1BN == 0);
+            mm.funcs[0] = mul_mat_iq2bn_q8_K64<1>;
+            mm.funcs[1] = mul_mat_iq2bn_q8_K64<2>;
+            mm.funcs[2] = mul_mat_iq2bn_q8_K64<3>;
+            mm.funcs[3] = mul_mat_iq2bn_q8_K64<4>;
+            mm.funcs[4] = mul_mat_iq2bn_q8_K64<5>;
+            mm.funcs[5] = mul_mat_iq2bn_q8_K64<6>;
+            mm.funcs[6] = mul_mat_iq2bn_q8_K64<7>;
+            mm.funcs[7] = mul_mat_iq2bn_q8_K64<8>;
+            expected_typeB = GGML_TYPE_Q8_K64;
+            break;
         case GGML_TYPE_Q4_0:
            assert (ne00 % QK4_0 == 0);
            MulMat::set_functions<Q4_0_Unpacker>(mm);
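For readers following the AVX2 code, here is a minimal scalar sketch of what the inner loop computes for blocks with i >= 1. It is not part of the commit: the function name iq2bn_block_dot_ref and the assumption that the Q8 quants are stored in the same order the vector code consumes them (16 values per 2-bit field) are mine, and the first block of a row, which goes through the iq1bn grid tables and sign masks, is not covered. Each of the 16 bytes in a block's qs packs four 2-bit values; masking with 3 (mask2 above) and subtracting 1 (m1_8) maps them to the ternary weights {-1, 0, +1}, which are multiplied against the Q8 activation quants and scaled by d and the Q8 block scale.

// Hypothetical scalar reference (not from the diff): dot product of one
// IQ2_BN block of 64 ternary weights with 64 Q8 activation quants.
#include <cstdint>

constexpr int kBlockSize = 64;   // QK_IQ1BN in the diff

float iq2bn_block_dot_ref(const uint8_t * qs,   // 16 bytes, four 2-bit weights each
                          const int8_t  * q8,   // 64 activation quants
                          float d,              // row scale, read as fp16 in the diff
                          float q8_scale) {     // Q8 block scale, q8.scale(iy, i) above
    int32_t sum = 0;
    for (int k = 0; k < 4; ++k) {                  // which 2-bit field inside each byte
        for (int j = 0; j < kBlockSize/4; ++j) {
            int w = ((qs[j] >> (2*k)) & 3) - 1;    // {0,1,2} -> {-1,0,+1}
            sum += w * q8[16*k + j];               // assumed Q8 ordering
        }
    }
    return d * q8_scale * sum;   // mirrors the fmadd accumulation into accd[]
}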