diff options
author | Kawrakow <iwankawrakow@gmail.com> | 2025-06-13 07:58:15 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2025-06-13 07:58:15 +0300 |
commit | 066ed4fd1158ddaab0080ef0e77bd5b7e12ec114 (patch) | |
tree | 42707c91f1e27486ffe2e3b4dc974c6694760263 | |
parent | f72983f7fe16f02cda4af40172b87ff721920b46 (diff) |
Faster CPU prompt processing for Q4_K and Q5_K (#525)
* q4_K: dequantize to q8_1_r8 for batch >= 32
We get 268 t/s, up from 186 t/s.
* q4_K: GEMM with q8_2_X4
* q5_K: GEMM with q8_2_X4 and repack to q8_1_r8
* Remove the scales, they are not needed
---------
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
-rw-r--r-- | ggml/src/ggml.c | 8 | ||||
-rw-r--r-- | ggml/src/iqk/iqk_gemm_kquants.cpp | 297 | ||||
-rw-r--r-- | ggml/src/iqk/iqk_gemm_kquants.h | 2 | ||||
-rw-r--r-- | ggml/src/iqk/iqk_gemm_legacy_quants.cpp | 78 | ||||
-rw-r--r-- | ggml/src/iqk/iqk_mul_mat.cpp | 11 |
5 files changed, 391 insertions, 5 deletions
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 3953cd7d..069533ae 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -976,7 +976,11 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .from_float = quantize_row_q4_K, .from_float_ref = (ggml_from_float_t) quantize_row_q4_K_ref, .vec_dot = ggml_vec_dot_q4_K_q8_K, +#ifdef __AVX2__ + .vec_dot_type = GGML_TYPE_Q8_2_X4, +#else .vec_dot_type = GGML_TYPE_Q8_K, +#endif .nrows = 1, .row_meta_size = 0, }, @@ -1002,7 +1006,11 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .from_float = quantize_row_q5_K, .from_float_ref = (ggml_from_float_t) quantize_row_q5_K_ref, .vec_dot = ggml_vec_dot_q5_K_q8_K, +#ifdef __AVX2__ + .vec_dot_type = GGML_TYPE_Q8_2_X4, +#else .vec_dot_type = GGML_TYPE_Q8_K, +#endif .nrows = 1, .row_meta_size = 0, }, diff --git a/ggml/src/iqk/iqk_gemm_kquants.cpp b/ggml/src/iqk/iqk_gemm_kquants.cpp index dfbff710..589fbc26 100644 --- a/ggml/src/iqk/iqk_gemm_kquants.cpp +++ b/ggml/src/iqk/iqk_gemm_kquants.cpp @@ -719,6 +719,147 @@ static void mul_mat_qY_K_q8_K_T(int n, const void * vx, size_t bx, const DataInf #endif +// inline __m256i process_mins_and_scales(const uint8_t * data, float c, int i, const Q8& q8, __m256 * accd) { +// make_q4_scales(data, utmp); +// const __m256i mins_and_scales = _mm256_cvtepu8_epi16(_mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0])); +// const __m128i mins128 = _mm256_extracti128_si256(mins_and_scales, 1); +// accum_mins(mins128, q8, i, c, accd); +// const __m128i sc128 = _mm256_extracti128_si256(mins_and_scales, 0); +// return MM256_SET_M128I(sc128, sc128); +// } +// +// inline void new_block(int i, const Q8& q8, __m256 * accd, __m512i * scales) { +// d = GGML_FP16_TO_FP32(x[i].d); +// bits.prepare(x[i].qs); +// auto all_scales = s8k.process_mins_and_scales_64(x[i].scales, -GGML_FP16_TO_FP32(x[i].dmin), i, q8, accd); +// scales[0] = _mm512_shuffle_epi8(all_scales, s8k.shuffles512[0]); +// scales[1] = _mm512_shuffle_epi8(all_scales, s8k.shuffles512[1]); +// } + + +struct Q4Bits_AVX2 { + inline void prepare(const uint8_t * q4, int j) { + auto q4bits = _mm256_loadu_si256((const __m256i*)q4 + 2*j+0); + values[0] = _mm256_and_si256(q4bits, ml); + values[1] = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), ml); + q4bits = _mm256_loadu_si256((const __m256i*)q4 + 2*j+1); + values[2] = _mm256_and_si256(q4bits, ml); + values[3] = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), ml); + } + __m256i values[4]; + const __m256i ml = _mm256_set1_epi8(0xf); +}; + +struct DequantizerQ4K_AVX2 final : public BaseDequantizer<block_q4_K> { + DequantizerQ4K_AVX2(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {} + inline void prepare(int i, int j) { + bits.prepare(x[i].qs, j); + } + Q4Bits_AVX2 bits; +}; + +struct DequantizerQ5K_AVX2 final : public BaseDequantizer<block_q5_K> { + DequantizerQ5K_AVX2(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {} + inline void prepare(int i, int j) { + bits.prepare(x[i].qs, j); + hbits = j == 0 ? _mm256_loadu_si256((const __m256i *)x[i].qh) : _mm256_srli_epi16(hbits, 4); + apply_hbits(); + } + inline void apply_hbits() { + bits.values[0] = _mm256_or_si256(bits.values[0], _mm256_and_si256(_mm256_slli_epi16(hbits, 4), mh)); + bits.values[1] = _mm256_or_si256(bits.values[1], _mm256_and_si256(_mm256_slli_epi16(hbits, 3), mh)); + bits.values[2] = _mm256_or_si256(bits.values[2], _mm256_and_si256(_mm256_slli_epi16(hbits, 2), mh)); + bits.values[3] = _mm256_or_si256(bits.values[3], _mm256_and_si256(_mm256_slli_epi16(hbits, 1), mh)); + } + + const __m256i mh = _mm256_set1_epi8(0x10); + Q4Bits_AVX2 bits; + __m256i hbits; +}; + +template <typename Dequantizer, int nrc_y> +static void mul_mat_qX_K_q8_2_X4_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + assert(n % QK_K == 0); + const int nb = n / QK_K; + + Q8<nrc_y, block_q8_2_x4> q8(info); + + Dequantizer deq(vx, bx); + + uint32_t utmp[4]; + __m256 accd[nrc_y]; + __m256 scales[2]; + float d8[8*nrc_y]; + + for (int ix = 0; ix < nrc_x; ++ix) { + + for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm256_setzero_ps(); + + deq.new_row(ix); + + for (int i = 0; i < nb; ++i) { + + deq.d = GGML_FP16_TO_FP32(deq.x[i].d); + auto vm = _mm256_cvtph_ps(_mm_set1_epi16(deq.x[i].dmin)); + make_q4_scales(deq.x[i].scales, utmp); + auto mins = _mm256_mul_ps(vm, _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i *)(utmp + 2))))); + mins = _mm256_mul_ps(_mm256_set1_ps(-1.f), mins); + for (int iy = 0; iy < nrc_y; ++iy) { + auto d4_1 = _mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)(q8.y[iy][2*i+0].d))); + auto d4_2 = _mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)(q8.y[iy][2*i+1].d))); + auto dy = _mm256_castsi256_ps(_mm256_slli_epi32(MM256_SET_M128I(d4_2, d4_1), 16)); + _mm256_storeu_ps(d8 + 8*iy, dy); + auto m4_1 = _mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)(q8.y[iy][2*i+0].d+4))); + auto m4_2 = _mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)(q8.y[iy][2*i+1].d+4))); + auto my = _mm256_castsi256_ps(_mm256_slli_epi32(MM256_SET_M128I(m4_2, m4_1), 16)); + accd[iy] = _mm256_fmadd_ps(my, mins, accd[iy]); + } + + auto all_scales = _mm256_mul_ps(_mm256_set1_ps(deq.d), _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i *)utmp)))); + scales[0] = _mm256_set_m128(_mm256_castps256_ps128(all_scales), _mm256_castps256_ps128(all_scales)); + auto scales_h = _mm256_extractf128_ps(all_scales, 1); + scales[1] = _mm256_set_m128(scales_h, scales_h); + + for (int j = 0; j < QK_K/128; ++j) { + + deq.prepare(i, j); + + for (int iy = 0; iy < nrc_y; ++iy) { + const block_q8_2_x4& y = q8.y[iy][2*i+j]; +#ifdef HAVE_FANCY_SIMD + auto sumi1 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), deq.bits.values[0], _mm256_loadu_si256((const __m256i*)y.qs+0)); + auto sumi2 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), deq.bits.values[1], _mm256_loadu_si256((const __m256i*)y.qs+1)); + auto sumi3 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), deq.bits.values[2], _mm256_loadu_si256((const __m256i*)y.qs+2)); + auto sumi4 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), deq.bits.values[3], _mm256_loadu_si256((const __m256i*)y.qs+3)); + sumi1 = _mm256_add_epi32(_mm256_unpacklo_epi32(sumi1, sumi2), _mm256_unpackhi_epi32(sumi1, sumi2)); + sumi3 = _mm256_add_epi32(_mm256_unpacklo_epi32(sumi3, sumi4), _mm256_unpackhi_epi32(sumi3, sumi4)); + sumi1 = _mm256_add_epi32(_mm256_unpacklo_epi64(sumi1, sumi3), _mm256_unpackhi_epi64(sumi1, sumi3)); +#else + auto sumi1 = _mm256_maddubs_epi16(deq.bits.values[0], _mm256_loadu_si256((const __m256i*)y.qs+0)); + auto sumi2 = _mm256_maddubs_epi16(deq.bits.values[1], _mm256_loadu_si256((const __m256i*)y.qs+1)); + auto sumi3 = _mm256_maddubs_epi16(deq.bits.values[2], _mm256_loadu_si256((const __m256i*)y.qs+2)); + auto sumi4 = _mm256_maddubs_epi16(deq.bits.values[3], _mm256_loadu_si256((const __m256i*)y.qs+3)); + sumi1 = _mm256_add_epi16(_mm256_unpacklo_epi32(sumi1, sumi2), _mm256_unpackhi_epi32(sumi1, sumi2)); + sumi3 = _mm256_add_epi16(_mm256_unpacklo_epi32(sumi3, sumi4), _mm256_unpackhi_epi32(sumi3, sumi4)); + sumi1 = _mm256_add_epi16(_mm256_unpacklo_epi64(sumi1, sumi3), _mm256_unpackhi_epi64(sumi1, sumi3)); + sumi1 = _mm256_madd_epi16(_mm256_set1_epi16(1), sumi1); +#endif + auto dy4 = _mm_loadu_ps(d8 + 8*iy + 4*j); + auto d4d8 = _mm256_mul_ps(scales[j], _mm256_set_m128(dy4, dy4)); + accd[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi1), accd[iy]); + } + + } + + } + + for (int iy = 0; iy < nrc_y; ++iy) { + info.store(ix, iy, hsum_float_8(accd[iy])); + } + + } +} + template <int nrc_y> static void mul_mat_iq4_xs_r8_q8_k_avx2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { GGML_ASSERT(nrc_x%8 == 0); @@ -1702,6 +1843,146 @@ static void mul_mat_q8_KV_r8_q8_KV(int n, const void * vx, size_t bx, const Data } } +typedef struct { + ggml_half d[16]; + int8_t qs[8*QK8_1]; +} block_q8_1_r8; + +void iqk_convert_q4_k_q8_1_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) { + GGML_ASSERT(n%QK_K == 0); + GGML_ASSERT(nrc_x%8 == 0); + + int nb = n/QK_K; + + const block_q4_K * x8[8]; + + block_q8_1_r8 * y = (block_q8_1_r8 *)vy; + + ggml_half dh[16]; + uint16_t all_ls[128]; + + uint32_t utmp[4]; + const uint8_t * u8 = (const uint8_t *)utmp; + uint32_t block[8]; + + for (int ix = 0; ix < nrc_x; ix += 8) { + for (int k = 0; k < 8; ++k) x8[k] = (const block_q4_K *)((const char *)vx + (ix + k)*bx); + for (int i = 0; i < nb; ++i) { + for (int k = 0; k < 8; ++k) { + dh[k+0] = x8[k][i].d; + dh[k+8] = x8[k][i].dmin; + make_q4_scales(x8[k][i].scales, utmp); + auto qs = x8[k][i].qs; + for (int ib64 = 0; ib64 < 4; ++ib64) { + all_ls[8*(2*ib64 + 0) + k ] = u8[2*ib64+0]; + all_ls[8*(2*ib64 + 1) + k ] = u8[2*ib64+1]; + all_ls[8*(2*ib64 + 0) + k + 64] = u8[2*ib64+8]; + all_ls[8*(2*ib64 + 1) + k + 64] = u8[2*ib64+9]; + auto bits = _mm256_loadu_si256((const __m256i *)qs+ib64); + auto values1 = _mm256_and_si256(bits, _mm256_set1_epi8(0xf)); + auto values2 = _mm256_and_si256(_mm256_srli_epi16(bits, 4), _mm256_set1_epi8(0xf)); + _mm256_storeu_si256((__m256i *)block, values1); + auto q8 = (uint32_t *)y[2*ib64+0].qs; + for (int l = 0; l < 4; ++l) { + q8[8*l + k + 0] = block[l + 0]; + q8[8*l + k + 32] = block[l + 4]; + } + _mm256_storeu_si256((__m256i *)block, values2); + q8 = (uint32_t *)y[2*ib64+1].qs; + for (int l = 0; l < 4; ++l) { + q8[8*l + k + 0] = block[l + 0]; + q8[8*l + k + 32] = block[l + 4]; + } + } + } + auto vd = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)dh+0)); + auto vm = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)dh+1)); + vm = _mm256_mul_ps(_mm256_set1_ps(-1.f), vm); + for (int ib32 = 0; ib32 < QK_K/32; ++ib32) { + auto iscales16 = _mm_loadu_si128((const __m128i *)all_ls + ib32); + auto iscales32 = _mm256_cvtepi16_epi32(iscales16); + auto scales = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(iscales32)); + _mm_storeu_si128((__m128i *)y[ib32].d+0, _mm256_cvtps_ph(scales, _MM_FROUND_TO_NEAREST_INT)); + iscales16 = _mm_loadu_si128((const __m128i *)all_ls + ib32 + 8); + iscales32 = _mm256_cvtepi16_epi32(iscales16); + scales = _mm256_mul_ps(vm, _mm256_cvtepi32_ps(iscales32)); + _mm_storeu_si128((__m128i *)y[ib32].d+1, _mm256_cvtps_ph(scales, _MM_FROUND_TO_NEAREST_INT)); + } + y += QK_K/32; + } + } +} + +void iqk_convert_q5_k_q8_1_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) { + GGML_ASSERT(n%QK_K == 0); + GGML_ASSERT(nrc_x%8 == 0); + + int nb = n/QK_K; + + const block_q5_K * x8[8]; + + block_q8_1_r8 * y = (block_q8_1_r8 *)vy; + + ggml_half dh[16]; + uint16_t all_ls[128]; + + uint32_t utmp[4]; + const uint8_t * u8 = (const uint8_t *)utmp; + uint32_t block[8]; + + for (int ix = 0; ix < nrc_x; ix += 8) { + for (int k = 0; k < 8; ++k) x8[k] = (const block_q5_K *)((const char *)vx + (ix + k)*bx); + for (int i = 0; i < nb; ++i) { + for (int k = 0; k < 8; ++k) { + dh[k+0] = x8[k][i].d; + dh[k+8] = x8[k][i].dmin; + make_q4_scales(x8[k][i].scales, utmp); + auto qs = x8[k][i].qs; + auto hbits = _mm256_loadu_si256((const __m256i *)x8[k][i].qh); + for (int ib64 = 0; ib64 < 4; ++ib64) { + all_ls[8*(2*ib64 + 0) + k ] = u8[2*ib64+0]; + all_ls[8*(2*ib64 + 1) + k ] = u8[2*ib64+1]; + all_ls[8*(2*ib64 + 0) + k + 64] = u8[2*ib64+8]; + all_ls[8*(2*ib64 + 1) + k + 64] = u8[2*ib64+9]; + auto bits = _mm256_loadu_si256((const __m256i *)qs+ib64); + auto values1 = _mm256_and_si256(bits, _mm256_set1_epi8(0xf)); + auto values2 = _mm256_and_si256(_mm256_srli_epi16(bits, 4), _mm256_set1_epi8(0xf)); + values1 = _mm256_or_si256(values1, _mm256_and_si256(_mm256_set1_epi8(0x10), _mm256_slli_epi16(hbits, 4))); + values2 = _mm256_or_si256(values2, _mm256_and_si256(_mm256_set1_epi8(0x10), _mm256_slli_epi16(hbits, 3))); + hbits = _mm256_srli_epi16(hbits, 2); + _mm256_storeu_si256((__m256i *)block, values1); + auto q8 = (uint32_t *)y[2*ib64+0].qs; + for (int l = 0; l < 4; ++l) { + q8[8*l + k + 0] = block[l + 0]; + q8[8*l + k + 32] = block[l + 4]; + } + _mm256_storeu_si256((__m256i *)block, values2); + q8 = (uint32_t *)y[2*ib64+1].qs; + for (int l = 0; l < 4; ++l) { + q8[8*l + k + 0] = block[l + 0]; + q8[8*l + k + 32] = block[l + 4]; + } + } + } + auto vd = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)dh+0)); + auto vm = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)dh+1)); + vm = _mm256_mul_ps(_mm256_set1_ps(-1.f), vm); + for (int ib32 = 0; ib32 < QK_K/32; ++ib32) { + auto iscales16 = _mm_loadu_si128((const __m128i *)all_ls + ib32); + auto iscales32 = _mm256_cvtepi16_epi32(iscales16); + auto scales = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(iscales32)); + _mm_storeu_si128((__m128i *)y[ib32].d+0, _mm256_cvtps_ph(scales, _MM_FROUND_TO_NEAREST_INT)); + iscales16 = _mm_loadu_si128((const __m128i *)all_ls + ib32 + 8); + iscales32 = _mm256_cvtepi16_epi32(iscales16); + scales = _mm256_mul_ps(vm, _mm256_cvtepi32_ps(iscales32)); + _mm_storeu_si128((__m128i *)y[ib32].d+1, _mm256_cvtps_ph(scales, _MM_FROUND_TO_NEAREST_INT)); + } + y += QK_K/32; + } + } +} + + } // namespace bool iqk_set_kernels_kquants(int ne00, int typeA, int typeB, std::array<mul_mat_t, IQK_MAX_NY>& kernels, mul_mat_t& func16) { @@ -1710,6 +1991,7 @@ bool iqk_set_kernels_kquants(int ne00, int typeA, int typeB, std::array<mul_mat_ auto expected_type_B = etypeA == GGML_TYPE_IQ4_XS_R8 || etypeA == GGML_TYPE_Q4_K_R4 || etypeA == GGML_TYPE_Q5_K_R4 ? GGML_TYPE_Q8_K32 : etypeA == GGML_TYPE_Q8_K_R8 ? GGML_TYPE_Q8_KR8 : etypeA == GGML_TYPE_Q8_KV || etypeA == GGML_TYPE_Q8_KV_R8 ? GGML_TYPE_Q8_KV + : etypeA == GGML_TYPE_Q4_K || etypeA == GGML_TYPE_Q5_K ? GGML_TYPE_Q8_2_X4 : GGML_TYPE_Q8_K; if (ne00%QK_K != 0 || ggml_type(typeB) != expected_type_B) { @@ -1726,10 +2008,12 @@ bool iqk_set_kernels_kquants(int ne00, int typeA, int typeB, std::array<mul_mat_ set_functions<DequantizerQ3K>(kernels); break; case GGML_TYPE_Q4_K: - set_functions<DequantizerQ4K>(kernels); + IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_K_q8_2_X4_T, DequantizerQ4K_AVX2, kernels); + //set_functions<DequantizerQ4K>(kernels); break; case GGML_TYPE_Q5_K: - set_functions<DequantizerQ5K>(kernels); + IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_K_q8_2_X4_T, DequantizerQ5K_AVX2, kernels); + //set_functions<DequantizerQ5K>(kernels); break; case GGML_TYPE_Q6_K: set_functions<DequantizerQ6K>(kernels); @@ -1778,6 +2062,15 @@ bool iqk_set_kernels_kquants(int ne00, int typeA, int typeB, std::array<mul_mat_ } +bool iqk_convert_kquants_q8X_r8(int type, int n, const void * vx, size_t bx, void * vy, int nrc_x) { + switch (ggml_type(type)) { + case GGML_TYPE_Q4_K: iqk_convert_q4_k_q8_1_r8(n, vx, bx, vy, nrc_x); break; + case GGML_TYPE_Q5_K: iqk_convert_q5_k_q8_1_r8(n, vx, bx, vy, nrc_x); break; + default: return false; + } + return true; +} + #else // --------------------------------- __aarch64__ -------------------------------------- diff --git a/ggml/src/iqk/iqk_gemm_kquants.h b/ggml/src/iqk/iqk_gemm_kquants.h index 071d2e50..3518ebc4 100644 --- a/ggml/src/iqk/iqk_gemm_kquants.h +++ b/ggml/src/iqk/iqk_gemm_kquants.h @@ -10,4 +10,6 @@ bool iqk_set_kernels_kquants(int ne00, int typeA, int typeB, std::array<mul_mat_ void iqk_gemm_q8kv_fa(int D, int nq, int type_k, const char * k, size_t stride_k, DataInfo& info, int k_step); +bool iqk_convert_kquants_q8X_r8(int type, int n, const void * vx, size_t bx, void * vy, int nrc_x); + #endif diff --git a/ggml/src/iqk/iqk_gemm_legacy_quants.cpp b/ggml/src/iqk/iqk_gemm_legacy_quants.cpp index 6e262aab..17d2dad3 100644 --- a/ggml/src/iqk/iqk_gemm_legacy_quants.cpp +++ b/ggml/src/iqk/iqk_gemm_legacy_quants.cpp @@ -1615,6 +1615,81 @@ static void mul_mat_q8_0_r8_q8_2(int n, const void * vx, size_t bx, const DataIn } #endif +typedef struct { + ggml_half d[16]; + uint8_t qs[256]; +} block_q8_1_r8; + +template <int nrc_y> +static void mul_mat_q8_1_r8_q8_2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + GGML_ASSERT(nrc_x%8 == 0); + Q8<nrc_y, block_q8_2_x4> q8(info); + int nb = n / QK8_0; + __m256 acc[nrc_y] = {}; + float d8[4*nrc_y]; + __m256i qx[4]; + auto dot = [&qx] (const int8_t * qy) { + auto y128 = _mm_loadu_si128((const __m128i*)qy); + auto y = MM256_SET_M128I(y128, y128); +#ifdef HAVE_FANCY_SIMD + auto sumi = _mm256_setzero_si256(); + sumi = _mm256_dpbusd_epi32(sumi, qx[0], _mm256_shuffle_epi32(y, 0x00)); + sumi = _mm256_dpbusd_epi32(sumi, qx[1], _mm256_shuffle_epi32(y, 0x55)); + sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(y, 0xaa)); + sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(y, 0xff)); + return sumi; +#else + auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[0], _mm256_shuffle_epi32(y, 0x00)), + _mm256_maddubs_epi16(qx[1], _mm256_shuffle_epi32(y, 0x55))); + auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[2], _mm256_shuffle_epi32(y, 0xaa)), + _mm256_maddubs_epi16(qx[3], _mm256_shuffle_epi32(y, 0xff))); + return _mm256_add_epi32(_mm256_madd_epi16(_mm256_set1_epi16(1), sumi1), _mm256_madd_epi16(_mm256_set1_epi16(1), sumi2)); +#endif + }; + for (int ix = 0; ix < nrc_x; ix += 8) { + const block_q8_1_r8 * iq8 = (const block_q8_1_r8 *)((const char *)vx + ix*bx); + for (int i4 = 0; i4 < nb/4; ++i4) { + { + __m256 mx[4]; + for (int ib32 = 0; ib32 < 4; ++ib32) mx[ib32] = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq8[4*i4+ib32].d+1)); + for (int iy = 0; iy < nrc_y; ++iy) { + auto scales = _mm_castsi128_ps(_mm_slli_epi32(_mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)q8.y[iy][i4].d)), 16)); + _mm_storeu_ps(d8 + 4*iy + 0, scales); + auto bsums4 = _mm_castsi128_ps(_mm_slli_epi32(_mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)(q8.y[iy][i4].d+4))), 16)); + auto bsums = _mm256_set_m128(bsums4, bsums4); + acc[iy] = _mm256_fmadd_ps(mx[0], _mm256_shuffle_ps(bsums, bsums, 0x00), acc[iy]); + acc[iy] = _mm256_fmadd_ps(mx[1], _mm256_shuffle_ps(bsums, bsums, 0x55), acc[iy]); + acc[iy] = _mm256_fmadd_ps(mx[2], _mm256_shuffle_ps(bsums, bsums, 0xaa), acc[iy]); + acc[iy] = _mm256_fmadd_ps(mx[3], _mm256_shuffle_ps(bsums, bsums, 0xff), acc[iy]); + } + } + for (int ib32 = 0; ib32 < 4; ++ib32) { + auto scales = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq8[4*i4+ib32].d)); + for (int j = 0; j < 4; ++j) { + qx[j] = _mm256_loadu_si256((const __m256i *)iq8[4*i4+ib32].qs+j); + } + for (int iy = 0; iy < nrc_y; ++iy) { + auto sumi = dot(q8.y[iy][i4].qs+32*ib32); + auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(d8[4*iy+ib32])); + acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[iy]); + } + for (int j = 0; j < 4; ++j) { + qx[j] = _mm256_loadu_si256((const __m256i *)iq8[4*i4+ib32].qs+4+j); + } + for (int iy = 0; iy < nrc_y; ++iy) { + auto sumi = dot(q8.y[iy][i4].qs+32*ib32+16); + auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(d8[4*iy+ib32])); + acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[iy]); + } + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + info.store(ix, iy, acc[iy]); + acc[iy] = _mm256_setzero_ps(); + } + } +} + template <typename Dequantizer> void set_functions(std::array<mul_mat_t, IQK_MAX_NY>& funcs) { if constexpr (std::is_same_v<Dequantizer, Q4_0_Unpacker> || std::is_same_v<Dequantizer, Q5_0_Unpacker> || std::is_same_v<Dequantizer, Q8_0_Unpacker>) { @@ -1694,6 +1769,9 @@ bool iqk_set_kernels_legacy_quants(int ne00, int typeA, int typeB, std::array<mu case GGML_TYPE_IQ4_NL_R4: IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq4_nl_r4_q8_2, kernels) break; + case GGML_TYPE_Q8_1: // Note: we are misusing the Q8_1 type for Q8_1_R8 + IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_q8_1_r8_q8_2, kernels) + break; default: return false; } diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index 53ce99a4..7c0d3aff 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -243,6 +243,8 @@ struct MulMat { case GGML_TYPE_IQ3_XXS: return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type; case GGML_TYPE_IQ3_S : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type; case GGML_TYPE_IQ1_S : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type; + case GGML_TYPE_Q4_K : return nrc_y >= 32 ? GGML_TYPE_Q8_1 : type; + case GGML_TYPE_Q5_K : return nrc_y >= 32 ? GGML_TYPE_Q8_1 : type; default: break; } #else @@ -283,6 +285,7 @@ struct MulMat { case GGML_TYPE_Q5_K_R4: case GGML_TYPE_Q8_KV: case GGML_TYPE_Q8_KV_R8: + case GGML_TYPE_Q8_1: case GGML_TYPE_Q8_K_R8: return 8; case GGML_TYPE_Q4_0_R8: case GGML_TYPE_Q8_0_R8: @@ -318,6 +321,7 @@ struct MulMat { case GGML_TYPE_Q8_0_R8: case GGML_TYPE_Q8_KV: case GGML_TYPE_Q8_KV_R8: + case GGML_TYPE_Q8_1: case GGML_TYPE_Q8_K_R8: return 8; case GGML_TYPE_BF16_R16: return 16; default: return 1; @@ -341,8 +345,8 @@ bool iqk_convert_repack(int typeA, int n, const void * vx, size_t bx, void * vy, // return iqk_set_kernels_float(ne00, typeA, typeB, mm.funcs); //case GGML_TYPE_Q2_K: //case GGML_TYPE_Q3_K: - //case GGML_TYPE_Q4_K: - //case GGML_TYPE_Q5_K: + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q5_K: //case GGML_TYPE_Q6_K: //case GGML_TYPE_IQ4_XS: //case GGML_TYPE_Q2_K_R4: @@ -354,7 +358,7 @@ bool iqk_convert_repack(int typeA, int n, const void * vx, size_t bx, void * vy, //case GGML_TYPE_Q8_K_R8: //case GGML_TYPE_Q8_KV: //case GGML_TYPE_Q8_KV_R8: - // return iqk_set_kernels_kquants(ne00, typeA, typeB, mm.funcs, mm.func16); + return iqk_convert_kquants_q8X_r8(typeA, n, vx, bx, vy, nrc_x); case GGML_TYPE_IQ2_XXS: case GGML_TYPE_IQ2_XS: case GGML_TYPE_IQ2_S: @@ -790,6 +794,7 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) { case GGML_TYPE_Q5_1: case GGML_TYPE_Q6_0: case GGML_TYPE_Q8_0: + case GGML_TYPE_Q8_1: case GGML_TYPE_IQ4_NL: case GGML_TYPE_Q4_0_R8: case GGML_TYPE_Q5_0_R4: |