diff options
Diffstat (limited to 'ggml/src/iqk/iqk_gemm_kquants.cpp')
-rw-r--r-- | ggml/src/iqk/iqk_gemm_kquants.cpp | 468 |
1 files changed, 461 insertions, 7 deletions
diff --git a/ggml/src/iqk/iqk_gemm_kquants.cpp b/ggml/src/iqk/iqk_gemm_kquants.cpp index 589fbc26..43eff43c 100644 --- a/ggml/src/iqk/iqk_gemm_kquants.cpp +++ b/ggml/src/iqk/iqk_gemm_kquants.cpp @@ -6,6 +6,7 @@ #define GGML_COMMON_IMPL_C #include "ggml-common.h" +#include "ggml-quants.h" #ifdef __x86_64__ @@ -860,6 +861,175 @@ static void mul_mat_qX_K_q8_2_X4_T(int n, const void * vx, size_t bx, const Data } } +struct DequantizerQ6K_AVX2 final : public BaseDequantizer<block_q6_K> { + DequantizerQ6K_AVX2(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {} + inline void prepare(int i, int j) { + auto lbits1 = _mm256_loadu_si256((const __m256i *)x[i].ql + 2*j+0); + auto lbits2 = _mm256_loadu_si256((const __m256i *)x[i].ql + 2*j+1); + auto hbits = _mm256_loadu_si256((const __m256i *)x[i].qh + j); + bits.values[0] = _mm256_or_si256(_mm256_and_si256(lbits1, bits.ml), _mm256_and_si256(_mm256_slli_epi16(hbits, 4), mh)); + bits.values[1] = _mm256_or_si256(_mm256_and_si256(lbits2, bits.ml), _mm256_and_si256(_mm256_slli_epi16(hbits, 2), mh)); + bits.values[2] = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(lbits1, 4), bits.ml), _mm256_and_si256(hbits, mh)); + bits.values[3] = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(lbits2, 4), bits.ml), _mm256_and_si256(_mm256_srli_epi16(hbits, 2), mh)); + } + inline void prepare_signed(int i, int j, __m256i * us) { + prepare(i, j); + for (int k = 0; k < 4; ++k) { + bits.values[k] = _mm256_add_epi8(bits.values[k], _mm256_set1_epi8(-32)); + us[k] = _mm256_sign_epi8(bits.values[k], bits.values[k]); + } + } + inline __m256i make_scales(int i) const { + return _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)x[i].scales)); + } + + const __m256i mh = _mm256_set1_epi8(0x30); + Q4Bits_AVX2 bits; +}; + +struct SimpleBits { + __m256i values[4]; +}; + +struct DequantizerQ3K_AVX2 final : public BaseDequantizer<block_q3_K> { + DequantizerQ3K_AVX2(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {} + + inline void prepare(int i, int j) { + hbits = j == 0 ? _mm256_loadu_si256((const __m256i *)x[i].hmask) : _mm256_srli_epi16(hbits, 4); + auto q2bits = _mm256_loadu_si256((const __m256i *)x[i].qs + j); + bits.values[0] = _mm256_and_si256(q2bits, ml); + bits.values[1] = _mm256_and_si256(_mm256_srli_epi16(q2bits, 2), ml); + bits.values[2] = _mm256_and_si256(_mm256_srli_epi16(q2bits, 4), ml); + bits.values[3] = _mm256_and_si256(_mm256_srli_epi16(q2bits, 6), ml); + bits.values[0] = _mm256_or_si256(bits.values[0], _mm256_and_si256(_mm256_slli_epi16(hbits, 2), mh)); + bits.values[1] = _mm256_or_si256(bits.values[1], _mm256_and_si256(_mm256_slli_epi16(hbits, 1), mh)); + bits.values[2] = _mm256_or_si256(bits.values[2], _mm256_and_si256(hbits, mh)); + bits.values[3] = _mm256_or_si256(bits.values[3], _mm256_and_si256(_mm256_srli_epi16(hbits, 1), mh)); + //bits.values[0] = _mm256_sub_epi8(bits.values[0], _mm256_xor_si256(mh, _mm256_and_si256(_mm256_slli_epi16(hbits, 2), mh))); + //bits.values[1] = _mm256_sub_epi8(bits.values[1], _mm256_xor_si256(mh, _mm256_and_si256(_mm256_slli_epi16(hbits, 1), mh))); + //bits.values[2] = _mm256_sub_epi8(bits.values[2], _mm256_xor_si256(mh, _mm256_and_si256(hbits, mh))); + //bits.values[3] = _mm256_sub_epi8(bits.values[3], _mm256_xor_si256(mh, _mm256_and_si256(_mm256_srli_epi16(hbits, 1), mh))); + } + inline void prepare_signed(int i, int j, __m256i * us) { + prepare(i, j); + for (int k = 0; k < 4; ++k) { + bits.values[k] = _mm256_sub_epi8(bits.values[k], mh); + us[k] = _mm256_sign_epi8(bits.values[k], bits.values[k]); + } + //for (int k = 0; k < 4; ++k) { + // us[k] = _mm256_sign_epi8(bits.values[k], bits.values[k]); + //} + } + inline __m256i make_scales(int i) const { + return _mm256_cvtepi8_epi16(sc3.make_scales((const uint16_t *)x[i].scales)); + } + + ScaleQ3 sc3; + + __m256i hbits; + SimpleBits bits; + const __m256i ml = _mm256_set1_epi8(3); + const __m256i mh = _mm256_set1_epi8(4); +}; + +template <typename Dequantizer, int nrc_y> +static void mul_mat_qY_K_q8_2_X4_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + assert(n % QK_K == 0); + const int nb = n / QK_K; + + Q8<nrc_y, block_q8_2_x4> q8(info); + + Dequantizer deq(vx, bx); + + __m256 accd[nrc_y]; + __m256 scales[2]; + float d8[8*nrc_y]; + __m256i us[4]; + + uint8_t k_shuff[32] = {0, 1, 4, 5, 8, 9, 12, 13, 2, 3, 6, 7, 10, 11, 14, 15, 0, 1, 4, 5, 8, 9, 12, 13, 2, 3, 6, 7, 10, 11, 14, 15}; + auto shuff = _mm256_loadu_si256((const __m256i *)k_shuff); + + for (int ix = 0; ix < nrc_x; ++ix) { + + for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm256_setzero_ps(); + + deq.new_row(ix); + + for (int i = 0; i < nb; ++i) { + + deq.d = GGML_FP16_TO_FP32(deq.x[i].d); + auto vd = _mm256_set1_ps(deq.d); + auto sc16 = _mm256_shuffle_epi8(deq.make_scales(i), shuff); + scales[0] = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_castsi256_si128(sc16)))); + scales[1] = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(sc16, 1)))); + for (int iy = 0; iy < nrc_y; ++iy) { + auto d4_1 = _mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)(q8.y[iy][2*i+0].d))); + auto d4_2 = _mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)(q8.y[iy][2*i+1].d))); + auto dy = _mm256_castsi256_ps(_mm256_slli_epi32(MM256_SET_M128I(d4_2, d4_1), 16)); + if constexpr (nrc_y == 1) { + auto dyh = _mm256_extractf128_ps(dy, 1); + scales[0] = _mm256_mul_ps(scales[0], _mm256_set_m128(_mm256_castps256_ps128(dy), _mm256_castps256_ps128(dy))); + scales[1] = _mm256_mul_ps(scales[1], _mm256_set_m128(dyh, dyh)); + } else { + _mm256_storeu_ps(d8 + 8*iy, dy); + } + } + + for (int j = 0; j < QK_K/128; ++j) { + + deq.prepare_signed(i, j, us); + + for (int iy = 0; iy < nrc_y; ++iy) { + auto qs = q8.y[iy][2*i+j].qs; +#ifdef HAVE_FANCY_SIMD + // 0...31 + auto sumi1 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), us[0], _mm256_sign_epi8(_mm256_loadu_si256((const __m256i*)qs+0), deq.bits.values[0])); + // 32...63 + auto sumi2 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), us[1], _mm256_sign_epi8(_mm256_loadu_si256((const __m256i*)qs+1), deq.bits.values[1])); + // 64...95 + auto sumi3 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), us[2], _mm256_sign_epi8(_mm256_loadu_si256((const __m256i*)qs+2), deq.bits.values[2])); + // 96...128 + auto sumi4 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), us[3], _mm256_sign_epi8(_mm256_loadu_si256((const __m256i*)qs+3), deq.bits.values[3])); + // 0...3, 32...35, 4....7, 36...39, 16...19, 48...51, 20...23, 52...56 + + // 8..11, 40...43, 12...15, 44...47, 24...27, 56...59, 28...31, 60...63 + // b0 b2 b0 b2 b1 b3 b1 b3 + sumi1 = _mm256_add_epi32(_mm256_unpacklo_epi32(sumi1, sumi2), _mm256_unpackhi_epi32(sumi1, sumi2)); + // same as above + 64, so + // b4 b6, b4 b6 b5 b7 b5 b7 + sumi3 = _mm256_add_epi32(_mm256_unpacklo_epi32(sumi3, sumi4), _mm256_unpackhi_epi32(sumi3, sumi4)); + // b0 b2 b4 b6 b1 b3 b5 b7 + + // b0 b2 b4 b6 b1 b3 b5 b7 + sumi1 = _mm256_add_epi32(_mm256_unpacklo_epi64(sumi1, sumi3), _mm256_unpackhi_epi64(sumi1, sumi3)); +#else + auto sumi1 = _mm256_maddubs_epi16(us[0], _mm256_sign_epi8(_mm256_loadu_si256((const __m256i*)qs+0), deq.bits.values[0])); + auto sumi2 = _mm256_maddubs_epi16(us[1], _mm256_sign_epi8(_mm256_loadu_si256((const __m256i*)qs+1), deq.bits.values[1])); + auto sumi3 = _mm256_maddubs_epi16(us[2], _mm256_sign_epi8(_mm256_loadu_si256((const __m256i*)qs+2), deq.bits.values[2])); + auto sumi4 = _mm256_maddubs_epi16(us[3], _mm256_sign_epi8(_mm256_loadu_si256((const __m256i*)qs+3), deq.bits.values[3])); + sumi1 = _mm256_add_epi16(_mm256_unpacklo_epi32(sumi1, sumi2), _mm256_unpackhi_epi32(sumi1, sumi2)); + sumi3 = _mm256_add_epi16(_mm256_unpacklo_epi32(sumi3, sumi4), _mm256_unpackhi_epi32(sumi3, sumi4)); + sumi1 = _mm256_add_epi16(_mm256_unpacklo_epi64(sumi1, sumi3), _mm256_unpackhi_epi64(sumi1, sumi3)); + sumi1 = _mm256_madd_epi16(_mm256_set1_epi16(1), sumi1); +#endif + if constexpr (nrc_y > 1) { + auto dy4 = _mm_loadu_ps(d8 + 8*iy + 4*j); + auto d4d8 = _mm256_mul_ps(scales[j], _mm256_set_m128(dy4, dy4)); + accd[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi1), accd[iy]); + } else { + accd[iy] = _mm256_fmadd_ps(scales[j], _mm256_cvtepi32_ps(sumi1), accd[iy]); + } + } + + } + + } + + for (int iy = 0; iy < nrc_y; ++iy) { + info.store(ix, iy, hsum_float_8(accd[iy])); + } + + } +} + template <int nrc_y> static void mul_mat_iq4_xs_r8_q8_k_avx2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { GGML_ASSERT(nrc_x%8 == 0); @@ -1669,14 +1839,13 @@ static void mul_mat_q8_k_r8_q8_k(int n, const void * vx, size_t bx, const DataIn } } #ifdef HAVE_FANCY_SIMD - auto m4 = _mm256_mul_ps(d4, _mm256_set1_ps(-128.f)); + auto m4 = _mm256_mul_ps(d4, _mm256_set1_ps(-127.f)); #endif for (int iy = 0; iy < nrc_y; ++iy) { auto d4y = _mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(iy, ibl))); acc[iy] = _mm256_fmadd_ps(d4y, _mm256_cvtepi32_ps(isum[iy]), acc[iy]); #ifdef HAVE_FANCY_SIMD - auto bsums = (const float *)q8.y[iy][ibl].bsums; - acc[iy] = _mm256_fmadd_ps(m4, _mm256_set1_ps(bsums[0]), acc[iy]); + acc[iy] = _mm256_fmadd_ps(m4, _mm256_set1_ps(q8.y[iy][ibl].sum), acc[iy]); #endif isum[iy] = _mm256_setzero_si256(); } @@ -1982,6 +2151,284 @@ void iqk_convert_q5_k_q8_1_r8(int n, const void * vx, size_t bx, void * vy, int } } +void iqk_convert_q6_k_q8_0_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) { + GGML_ASSERT(n%QK_K == 0); + GGML_ASSERT(nrc_x%8 == 0); + + int nb = n/QK_K; + + const block_q6_K * x8[8]; + + block_q8_0_r8 * y = (block_q8_0_r8 *)vy; + + float all_s[64]; + uint32_t block[8]; + __m256i values[8]; + + auto ml = _mm256_set1_epi8(0x0f); + auto mh = _mm256_set1_epi8(0x30); + + for (int ix = 0; ix < nrc_x; ix += 8) { + for (int k = 0; k < 8; ++k) x8[k] = (const block_q6_K *)((const char *)vx + (ix + k)*bx); + for (int i = 0; i < nb; ++i) { + for (int k = 0; k < 8; ++k) { + float d = GGML_FP16_TO_FP32(x8[k][i].d); + auto ql = x8[k][i].ql; + auto qh = x8[k][i].qh; + for (int i128 = 0; i128 < 2; ++i128) { + auto lbits1 = _mm256_loadu_si256((const __m256i *)ql + 2*i128 + 0); + auto lbits2 = _mm256_loadu_si256((const __m256i *)ql + 2*i128 + 1); + auto hbits = _mm256_loadu_si256((const __m256i *)qh + i128); + values[4*i128+0] = _mm256_or_si256(_mm256_and_si256(lbits1, ml), _mm256_and_si256(_mm256_slli_epi16(hbits, 4), mh)); + values[4*i128+1] = _mm256_or_si256(_mm256_and_si256(lbits2, ml), _mm256_and_si256(_mm256_slli_epi16(hbits, 2), mh)); + values[4*i128+2] = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(lbits1, 4), ml), _mm256_and_si256(hbits, mh)); + values[4*i128+3] = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(lbits2, 4), ml), _mm256_and_si256(_mm256_srli_epi16(hbits, 2), mh)); + } + for (int ib32 = 0; ib32 < 8; ++ib32) { + // We have two blocks of 16 with different scales + // We multiply the quants with the scales, find the max value, and convert to 8-bit quants with a single block scale. + auto q8 = _mm256_add_epi8(values[ib32], _mm256_set1_epi8(-32)); + auto q16_l = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(q8)); + auto q16_h = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(q8, 1)); + q16_l = _mm256_mullo_epi16(q16_l, _mm256_set1_epi16(x8[k][i].scales[2*ib32+0])); + q16_h = _mm256_mullo_epi16(q16_h, _mm256_set1_epi16(x8[k][i].scales[2*ib32+1])); + auto abs_q16_l = _mm256_sign_epi16(q16_l, q16_l); + auto abs_q16_h = _mm256_sign_epi16(q16_h, q16_h); + auto max_q16 = _mm256_max_epi16(abs_q16_l, abs_q16_h); + auto max_q32 = _mm256_cvtepi16_epi32(_mm_max_epi16(_mm256_castsi256_si128(max_q16), _mm256_extracti128_si256(max_q16, 1))); + auto imax4 = _mm_max_epi32(_mm256_castsi256_si128(max_q32), _mm256_extracti128_si256(max_q32, 1)); + auto max4 = _mm_cvtepi32_ps(imax4); + max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) ); + max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) ); + float max = _mm_cvtss_f32(max4) / 127; + all_s[8*ib32+k] = d*max; + if (max > 1e-9f) { + auto scale = _mm256_set1_ps(1/max); + auto i0 = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(q16_l)); + auto i1 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(q16_l, 1)); + auto i2 = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(q16_h)); + auto i3 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(q16_h, 1)); + i0 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(scale, _mm256_cvtepi32_ps(i0)), _MM_ROUND_NEAREST)); + i1 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(scale, _mm256_cvtepi32_ps(i1)), _MM_ROUND_NEAREST)); + i2 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(scale, _mm256_cvtepi32_ps(i2)), _MM_ROUND_NEAREST)); + i3 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(scale, _mm256_cvtepi32_ps(i3)), _MM_ROUND_NEAREST)); + i0 = _mm256_packs_epi32(i0, i1); + i2 = _mm256_packs_epi32(i2, i3); + i0 = _mm256_packs_epi16(i0, i2); + i0 = _mm256_permutevar8x32_epi32(i0, _mm256_setr_epi32(0, 4, 1, 5, 2, 6, 3, 7)); + _mm256_storeu_si256((__m256i *)block, i0); + } else { + _mm256_storeu_si256((__m256i *)block, _mm256_setzero_si256()); + } + auto qs = (uint32_t *)y[ib32].qs; + for (int l = 0; l < 4; ++l) { + qs[8*l + k + 0] = block[l + 0]; + qs[8*l + k + 32] = block[l + 4]; + } + } + } + for (int ib32 = 0; ib32 < 8; ++ib32) { + _mm_storeu_si128((__m128i *)y[ib32].d, _mm256_cvtps_ph(_mm256_loadu_ps(all_s + 8*ib32), _MM_FROUND_TO_NEAREST_INT)); + } + y += QK_K/32; + } + } +} + +void iqk_convert_q3_k_q8_0_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) { + GGML_ASSERT(n%QK_K == 0); + GGML_ASSERT(nrc_x%8 == 0); + + int nb = n/QK_K; + + const block_q3_K * x8[8]; + + block_q8_0_r8 * y = (block_q8_0_r8 *)vy; + + float all_s[64]; + uint32_t block[8]; + __m256i values[8]; + + ScaleQ3 sc3; + auto ml = _mm256_set1_epi8(0x03); + auto mh = _mm256_set1_epi8(0x04); + + union { __m256i vec; int16_t val[16]; } helper; + + for (int ix = 0; ix < nrc_x; ix += 8) { + for (int k = 0; k < 8; ++k) x8[k] = (const block_q3_K *)((const char *)vx + (ix + k)*bx); + for (int i = 0; i < nb; ++i) { + for (int k = 0; k < 8; ++k) { + float d = GGML_FP16_TO_FP32(x8[k][i].d); + auto hbits = _mm256_loadu_si256((const __m256i *)x8[k][i].hmask); + for (int i128 = 0; i128 < 2; ++i128) { + auto q2bits = _mm256_loadu_si256((const __m256i *)x8[k][i].qs + i128); + values[4*i128+0] = _mm256_and_si256(q2bits, ml); + values[4*i128+1] = _mm256_and_si256(_mm256_srli_epi16(q2bits, 2), ml); + values[4*i128+2] = _mm256_and_si256(_mm256_srli_epi16(q2bits, 4), ml); + values[4*i128+3] = _mm256_and_si256(_mm256_srli_epi16(q2bits, 6), ml); + values[4*i128+0] = _mm256_or_si256(values[4*i128+0], _mm256_and_si256(_mm256_slli_epi16(hbits, 2), mh)); + values[4*i128+1] = _mm256_or_si256(values[4*i128+1], _mm256_and_si256(_mm256_slli_epi16(hbits, 1), mh)); + values[4*i128+2] = _mm256_or_si256(values[4*i128+2], _mm256_and_si256(hbits, mh)); + values[4*i128+3] = _mm256_or_si256(values[4*i128+3], _mm256_and_si256(_mm256_srli_epi16(hbits, 1), mh)); + values[4*i128+0] = _mm256_sub_epi8(values[4*i128+0], mh); + values[4*i128+1] = _mm256_sub_epi8(values[4*i128+1], mh); + values[4*i128+2] = _mm256_sub_epi8(values[4*i128+2], mh); + values[4*i128+3] = _mm256_sub_epi8(values[4*i128+3], mh); + hbits = _mm256_srli_epi16(hbits, 4); + } + helper.vec = _mm256_cvtepi8_epi16(sc3.make_scales((const uint16_t *)x8[k][i].scales)); + for (int ib32 = 0; ib32 < 8; ++ib32) { + auto q16_l = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(values[ib32])); + auto q16_h = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(values[ib32], 1)); + q16_l = _mm256_mullo_epi16(q16_l, _mm256_set1_epi16(helper.val[2*ib32+0])); + q16_h = _mm256_mullo_epi16(q16_h, _mm256_set1_epi16(helper.val[2*ib32+1])); + auto abs_q16_l = _mm256_sign_epi16(q16_l, q16_l); + auto abs_q16_h = _mm256_sign_epi16(q16_h, q16_h); + auto max_q16 = _mm256_max_epi16(abs_q16_l, abs_q16_h); + auto max_q32 = _mm256_cvtepi16_epi32(_mm_max_epi16(_mm256_castsi256_si128(max_q16), _mm256_extracti128_si256(max_q16, 1))); + auto imax4 = _mm_max_epi32(_mm256_castsi256_si128(max_q32), _mm256_extracti128_si256(max_q32, 1)); + auto max4 = _mm_cvtepi32_ps(imax4); + max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) ); + max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) ); + float max = _mm_cvtss_f32(max4) / 127; + all_s[8*ib32+k] = d*max; + if (max > 1e-9f) { + auto scale = _mm256_set1_ps(1/max); + auto i0 = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(q16_l)); + auto i1 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(q16_l, 1)); + auto i2 = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(q16_h)); + auto i3 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(q16_h, 1)); + i0 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(scale, _mm256_cvtepi32_ps(i0)), _MM_ROUND_NEAREST)); + i1 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(scale, _mm256_cvtepi32_ps(i1)), _MM_ROUND_NEAREST)); + i2 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(scale, _mm256_cvtepi32_ps(i2)), _MM_ROUND_NEAREST)); + i3 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(scale, _mm256_cvtepi32_ps(i3)), _MM_ROUND_NEAREST)); + i0 = _mm256_packs_epi32(i0, i1); + i2 = _mm256_packs_epi32(i2, i3); + i0 = _mm256_packs_epi16(i0, i2); + i0 = _mm256_permutevar8x32_epi32(i0, _mm256_setr_epi32(0, 4, 1, 5, 2, 6, 3, 7)); + _mm256_storeu_si256((__m256i *)block, i0); + } else { + _mm256_storeu_si256((__m256i *)block, _mm256_setzero_si256()); + } + auto qs = (uint32_t *)y[ib32].qs; + for (int l = 0; l < 4; ++l) { + qs[8*l + k + 0] = block[l + 0]; + qs[8*l + k + 32] = block[l + 4]; + } + } + } + for (int ib32 = 0; ib32 < 8; ++ib32) { + _mm_storeu_si128((__m128i *)y[ib32].d, _mm256_cvtps_ph(_mm256_loadu_ps(all_s + 8*ib32), _MM_FROUND_TO_NEAREST_INT)); + } + y += QK_K/32; + } + } +} + +void iqk_convert_q3_k_q8_k_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) { + GGML_ASSERT(n%QK_K == 0); + GGML_ASSERT(nrc_x%8 == 0); + + int nb = n/QK_K; + + const block_q3_K * x8[8]; + + block_q8_k_r8 * y = (block_q8_k_r8 *)vy; + + uint32_t block[8]; + __m256i values[8]; + + ScaleQ3 sc3; + auto ml = _mm256_set1_epi8(0x03); + auto mh = _mm256_set1_epi8(0x04); + + union { __m256i vec; int16_t val[16]; } helper; + + for (int ix = 0; ix < nrc_x; ix += 8) { + for (int k = 0; k < 8; ++k) x8[k] = (const block_q3_K *)((const char *)vx + (ix + k)*bx); + for (int i = 0; i < nb; ++i) { + for (int k = 0; k < 8; ++k) { + float d = GGML_FP16_TO_FP32(x8[k][i].d); + auto hbits = _mm256_loadu_si256((const __m256i *)x8[k][i].hmask); + helper.vec = _mm256_cvtepi8_epi16(sc3.make_scales((const uint16_t *)x8[k][i].scales)); + auto max_i16 = _mm256_setzero_si256(); + for (int i128 = 0; i128 < 2; ++i128) { + auto q2bits = _mm256_loadu_si256((const __m256i *)x8[k][i].qs + i128); + values[4*i128+0] = _mm256_and_si256(q2bits, ml); + values[4*i128+1] = _mm256_and_si256(_mm256_srli_epi16(q2bits, 2), ml); + values[4*i128+2] = _mm256_and_si256(_mm256_srli_epi16(q2bits, 4), ml); + values[4*i128+3] = _mm256_and_si256(_mm256_srli_epi16(q2bits, 6), ml); + values[4*i128+0] = _mm256_or_si256(values[4*i128+0], _mm256_and_si256(_mm256_slli_epi16(hbits, 2), mh)); + values[4*i128+1] = _mm256_or_si256(values[4*i128+1], _mm256_and_si256(_mm256_slli_epi16(hbits, 1), mh)); + values[4*i128+2] = _mm256_or_si256(values[4*i128+2], _mm256_and_si256(hbits, mh)); + values[4*i128+3] = _mm256_or_si256(values[4*i128+3], _mm256_and_si256(_mm256_srli_epi16(hbits, 1), mh)); + values[4*i128+0] = _mm256_sub_epi8(values[4*i128+0], mh); + values[4*i128+1] = _mm256_sub_epi8(values[4*i128+1], mh); + values[4*i128+2] = _mm256_sub_epi8(values[4*i128+2], mh); + values[4*i128+3] = _mm256_sub_epi8(values[4*i128+3], mh); + hbits = _mm256_srli_epi16(hbits, 4); + + for (int l = 0; l < 4; ++l) { + auto q16_l = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(values[4*i128+l])); + auto q16_h = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(values[4*i128+l], 1)); + q16_l = _mm256_mullo_epi16(_mm256_set1_epi16(helper.val[8*i128+2*l+0]), q16_l); + q16_h = _mm256_mullo_epi16(_mm256_set1_epi16(helper.val[8*i128+2*l+1]), q16_h); + max_i16 = _mm256_max_epi16(max_i16, _mm256_sign_epi16(q16_l, q16_l)); + max_i16 = _mm256_max_epi16(max_i16, _mm256_sign_epi16(q16_h, q16_h)); + } + } + auto max_q32 = _mm256_cvtepi16_epi32(_mm_max_epi16(_mm256_castsi256_si128(max_i16), _mm256_extracti128_si256(max_i16, 1))); + auto imax4 = _mm_max_epi32(_mm256_castsi256_si128(max_q32), _mm256_extracti128_si256(max_q32, 1)); + auto max4 = _mm_cvtepi32_ps(imax4); + max4 = _mm_max_ps(max4, _mm_movehl_ps(max4, max4)); + max4 = _mm_max_ss(max4, _mm_movehdup_ps(max4)); + bool needs_scaling = true; + float dnew = _mm_cvtss_f32(max4) / 127; + if (dnew < 1.f) { + dnew = 1.f; needs_scaling = false; + } + d *= dnew; + y[i].d[k] = GGML_FP32_TO_FP16(d); + auto scale = _mm256_set1_ps(std::abs(dnew) > 1e-9f ? 1/dnew : 0.f); + for (int ib32 = 0; ib32 < 8; ++ib32) { + auto q16_l = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(values[ib32])); + auto q16_h = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(values[ib32], 1)); + q16_l = _mm256_mullo_epi16(q16_l, _mm256_set1_epi16(helper.val[2*ib32+0])); + q16_h = _mm256_mullo_epi16(q16_h, _mm256_set1_epi16(helper.val[2*ib32+1])); + if (needs_scaling) { + auto i0 = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(q16_l)); + auto i1 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(q16_l, 1)); + auto i2 = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(q16_h)); + auto i3 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(q16_h, 1)); + i0 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(scale, _mm256_cvtepi32_ps(i0)), _MM_ROUND_NEAREST)); + i1 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(scale, _mm256_cvtepi32_ps(i1)), _MM_ROUND_NEAREST)); + i2 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(scale, _mm256_cvtepi32_ps(i2)), _MM_ROUND_NEAREST)); + i3 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(scale, _mm256_cvtepi32_ps(i3)), _MM_ROUND_NEAREST)); + i0 = _mm256_packs_epi32(i0, i1); + i2 = _mm256_packs_epi32(i2, i3); + i0 = _mm256_packs_epi16(i0, i2); + i0 = _mm256_permutevar8x32_epi32(i0, _mm256_setr_epi32(0, 4, 1, 5, 2, 6, 3, 7)); + _mm256_storeu_si256((__m256i *)block, i0); + } else { + // 0, 1, 2, 3, 4, 5, 6, 7, 8, 16, 17, 18, 19, 20, 21, 22, 23, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31 + auto i0 = _mm256_packs_epi16(q16_l, q16_h); + auto i0_l = _mm256_castsi256_si128(i0); + auto i0_h = _mm256_extracti128_si256(i0, 1); + _mm_storeu_si128((__m128i *)block+0, _mm_unpacklo_epi64(i0_l, i0_h)); + _mm_storeu_si128((__m128i *)block+1, _mm_unpackhi_epi64(i0_l, i0_h)); + } + auto qs = (uint32_t *)y[i].qs + 64*ib32; + for (int l = 0; l < 8; ++l) { + qs[8*l + k] = block[l]; + } + } + } + } + y += nb; + } +} + } // namespace @@ -1989,9 +2436,12 @@ bool iqk_set_kernels_kquants(int ne00, int typeA, int typeB, std::array<mul_mat_ auto etypeA = ggml_type(typeA); auto expected_type_B = etypeA == GGML_TYPE_IQ4_XS_R8 || etypeA == GGML_TYPE_Q4_K_R4 || etypeA == GGML_TYPE_Q5_K_R4 ? GGML_TYPE_Q8_K32 - : etypeA == GGML_TYPE_Q8_K_R8 ? GGML_TYPE_Q8_KR8 + //: etypeA == GGML_TYPE_Q8_K_R8 ? GGML_TYPE_Q8_KR8 : etypeA == GGML_TYPE_Q8_KV || etypeA == GGML_TYPE_Q8_KV_R8 ? GGML_TYPE_Q8_KV - : etypeA == GGML_TYPE_Q4_K || etypeA == GGML_TYPE_Q5_K ? GGML_TYPE_Q8_2_X4 + : etypeA == GGML_TYPE_Q4_K || etypeA == GGML_TYPE_Q5_K || + etypeA == GGML_TYPE_Q6_K ? GGML_TYPE_Q8_2_X4 + //etypeA == GGML_TYPE_Q6_K || etypeA == GGML_TYPE_Q3_K ? GGML_TYPE_Q8_2_X4 + //: etypeA == GGML_TYPE_Q4_K || etypeA == GGML_TYPE_Q5_K ? GGML_TYPE_Q8_2_X4 : GGML_TYPE_Q8_K; if (ne00%QK_K != 0 || ggml_type(typeB) != expected_type_B) { @@ -2006,6 +2456,7 @@ bool iqk_set_kernels_kquants(int ne00, int typeA, int typeB, std::array<mul_mat_ break; case GGML_TYPE_Q3_K: set_functions<DequantizerQ3K>(kernels); + //IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qY_K_q8_2_X4_T, DequantizerQ3K_AVX2, kernels); break; case GGML_TYPE_Q4_K: IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_K_q8_2_X4_T, DequantizerQ4K_AVX2, kernels); @@ -2016,7 +2467,8 @@ bool iqk_set_kernels_kquants(int ne00, int typeA, int typeB, std::array<mul_mat_ //set_functions<DequantizerQ5K>(kernels); break; case GGML_TYPE_Q6_K: - set_functions<DequantizerQ6K>(kernels); + IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qY_K_q8_2_X4_T, DequantizerQ6K_AVX2, kernels); + //set_functions<DequantizerQ6K>(kernels); break; case GGML_TYPE_IQ4_XS: set_functions<DequantizerIQ4XS>(kernels); @@ -2064,8 +2516,10 @@ bool iqk_set_kernels_kquants(int ne00, int typeA, int typeB, std::array<mul_mat_ bool iqk_convert_kquants_q8X_r8(int type, int n, const void * vx, size_t bx, void * vy, int nrc_x) { switch (ggml_type(type)) { + case GGML_TYPE_Q3_K: iqk_convert_q3_k_q8_k_r8(n, vx, bx, vy, nrc_x); break; case GGML_TYPE_Q4_K: iqk_convert_q4_k_q8_1_r8(n, vx, bx, vy, nrc_x); break; case GGML_TYPE_Q5_K: iqk_convert_q5_k_q8_1_r8(n, vx, bx, vy, nrc_x); break; + case GGML_TYPE_Q6_K: iqk_convert_q6_k_q8_0_r8(n, vx, bx, vy, nrc_x); break; default: return false; } return true; @@ -3075,7 +3529,7 @@ bool iqk_set_kernels_kquants(int ne00, int typeA, int typeB, std::array<mul_mat_ auto etypeA = ggml_type(typeA); auto expected_type_B = etypeA == GGML_TYPE_IQ4_XS_R8 || etypeA == GGML_TYPE_Q4_K_R4 || etypeA == GGML_TYPE_Q5_K_R4 ? GGML_TYPE_Q8_K32 - : etypeA == GGML_TYPE_Q8_K_R8 ? GGML_TYPE_Q8_KR8 + //: etypeA == GGML_TYPE_Q8_K_R8 ? GGML_TYPE_Q8_KR8 : etypeA == GGML_TYPE_Q8_KV || etypeA == GGML_TYPE_Q8_KV_R8 ? GGML_TYPE_Q8_KV : GGML_TYPE_Q8_K; |