summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorKawrakow <iwankawrakow@gmail.com>2025-06-13 07:58:15 +0300
committerGitHub <noreply@github.com>2025-06-13 07:58:15 +0300
commit066ed4fd1158ddaab0080ef0e77bd5b7e12ec114 (patch)
tree42707c91f1e27486ffe2e3b4dc974c6694760263
parentf72983f7fe16f02cda4af40172b87ff721920b46 (diff)
Faster CPU prompt processing for Q4_K and Q5_K (#525)
* q4_K: dequantize to q8_1_r8 for batch >= 32 We get 268 t/s, up from 186 t/s. * q4_K: GEMM with q8_2_X4 * q5_K: GEMM with q8_2_X4 and repack to q8_1_r8 * Remove the scales, they are not needed --------- Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
-rw-r--r--ggml/src/ggml.c8
-rw-r--r--ggml/src/iqk/iqk_gemm_kquants.cpp297
-rw-r--r--ggml/src/iqk/iqk_gemm_kquants.h2
-rw-r--r--ggml/src/iqk/iqk_gemm_legacy_quants.cpp78
-rw-r--r--ggml/src/iqk/iqk_mul_mat.cpp11
5 files changed, 391 insertions, 5 deletions
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index 3953cd7d..069533ae 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -976,7 +976,11 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.from_float = quantize_row_q4_K,
.from_float_ref = (ggml_from_float_t) quantize_row_q4_K_ref,
.vec_dot = ggml_vec_dot_q4_K_q8_K,
+#ifdef __AVX2__
+ .vec_dot_type = GGML_TYPE_Q8_2_X4,
+#else
.vec_dot_type = GGML_TYPE_Q8_K,
+#endif
.nrows = 1,
.row_meta_size = 0,
},
@@ -1002,7 +1006,11 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.from_float = quantize_row_q5_K,
.from_float_ref = (ggml_from_float_t) quantize_row_q5_K_ref,
.vec_dot = ggml_vec_dot_q5_K_q8_K,
+#ifdef __AVX2__
+ .vec_dot_type = GGML_TYPE_Q8_2_X4,
+#else
.vec_dot_type = GGML_TYPE_Q8_K,
+#endif
.nrows = 1,
.row_meta_size = 0,
},
diff --git a/ggml/src/iqk/iqk_gemm_kquants.cpp b/ggml/src/iqk/iqk_gemm_kquants.cpp
index dfbff710..589fbc26 100644
--- a/ggml/src/iqk/iqk_gemm_kquants.cpp
+++ b/ggml/src/iqk/iqk_gemm_kquants.cpp
@@ -719,6 +719,147 @@ static void mul_mat_qY_K_q8_K_T(int n, const void * vx, size_t bx, const DataInf
#endif
+// inline __m256i process_mins_and_scales(const uint8_t * data, float c, int i, const Q8& q8, __m256 * accd) {
+// make_q4_scales(data, utmp);
+// const __m256i mins_and_scales = _mm256_cvtepu8_epi16(_mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0]));
+// const __m128i mins128 = _mm256_extracti128_si256(mins_and_scales, 1);
+// accum_mins(mins128, q8, i, c, accd);
+// const __m128i sc128 = _mm256_extracti128_si256(mins_and_scales, 0);
+// return MM256_SET_M128I(sc128, sc128);
+// }
+//
+// inline void new_block(int i, const Q8& q8, __m256 * accd, __m512i * scales) {
+// d = GGML_FP16_TO_FP32(x[i].d);
+// bits.prepare(x[i].qs);
+// auto all_scales = s8k.process_mins_and_scales_64(x[i].scales, -GGML_FP16_TO_FP32(x[i].dmin), i, q8, accd);
+// scales[0] = _mm512_shuffle_epi8(all_scales, s8k.shuffles512[0]);
+// scales[1] = _mm512_shuffle_epi8(all_scales, s8k.shuffles512[1]);
+// }
+
+
+struct Q4Bits_AVX2 {
+ inline void prepare(const uint8_t * q4, int j) {
+ auto q4bits = _mm256_loadu_si256((const __m256i*)q4 + 2*j+0);
+ values[0] = _mm256_and_si256(q4bits, ml);
+ values[1] = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), ml);
+ q4bits = _mm256_loadu_si256((const __m256i*)q4 + 2*j+1);
+ values[2] = _mm256_and_si256(q4bits, ml);
+ values[3] = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), ml);
+ }
+ __m256i values[4];
+ const __m256i ml = _mm256_set1_epi8(0xf);
+};
+
+struct DequantizerQ4K_AVX2 final : public BaseDequantizer<block_q4_K> {
+ DequantizerQ4K_AVX2(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}
+ inline void prepare(int i, int j) {
+ bits.prepare(x[i].qs, j);
+ }
+ Q4Bits_AVX2 bits;
+};
+
+struct DequantizerQ5K_AVX2 final : public BaseDequantizer<block_q5_K> {
+ DequantizerQ5K_AVX2(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}
+ inline void prepare(int i, int j) {
+ bits.prepare(x[i].qs, j);
+ hbits = j == 0 ? _mm256_loadu_si256((const __m256i *)x[i].qh) : _mm256_srli_epi16(hbits, 4);
+ apply_hbits();
+ }
+ inline void apply_hbits() {
+ bits.values[0] = _mm256_or_si256(bits.values[0], _mm256_and_si256(_mm256_slli_epi16(hbits, 4), mh));
+ bits.values[1] = _mm256_or_si256(bits.values[1], _mm256_and_si256(_mm256_slli_epi16(hbits, 3), mh));
+ bits.values[2] = _mm256_or_si256(bits.values[2], _mm256_and_si256(_mm256_slli_epi16(hbits, 2), mh));
+ bits.values[3] = _mm256_or_si256(bits.values[3], _mm256_and_si256(_mm256_slli_epi16(hbits, 1), mh));
+ }
+
+ const __m256i mh = _mm256_set1_epi8(0x10);
+ Q4Bits_AVX2 bits;
+ __m256i hbits;
+};
+
+template <typename Dequantizer, int nrc_y>
+static void mul_mat_qX_K_q8_2_X4_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ assert(n % QK_K == 0);
+ const int nb = n / QK_K;
+
+ Q8<nrc_y, block_q8_2_x4> q8(info);
+
+ Dequantizer deq(vx, bx);
+
+ uint32_t utmp[4];
+ __m256 accd[nrc_y];
+ __m256 scales[2];
+ float d8[8*nrc_y];
+
+ for (int ix = 0; ix < nrc_x; ++ix) {
+
+ for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm256_setzero_ps();
+
+ deq.new_row(ix);
+
+ for (int i = 0; i < nb; ++i) {
+
+ deq.d = GGML_FP16_TO_FP32(deq.x[i].d);
+ auto vm = _mm256_cvtph_ps(_mm_set1_epi16(deq.x[i].dmin));
+ make_q4_scales(deq.x[i].scales, utmp);
+ auto mins = _mm256_mul_ps(vm, _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i *)(utmp + 2)))));
+ mins = _mm256_mul_ps(_mm256_set1_ps(-1.f), mins);
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto d4_1 = _mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)(q8.y[iy][2*i+0].d)));
+ auto d4_2 = _mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)(q8.y[iy][2*i+1].d)));
+ auto dy = _mm256_castsi256_ps(_mm256_slli_epi32(MM256_SET_M128I(d4_2, d4_1), 16));
+ _mm256_storeu_ps(d8 + 8*iy, dy);
+ auto m4_1 = _mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)(q8.y[iy][2*i+0].d+4)));
+ auto m4_2 = _mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)(q8.y[iy][2*i+1].d+4)));
+ auto my = _mm256_castsi256_ps(_mm256_slli_epi32(MM256_SET_M128I(m4_2, m4_1), 16));
+ accd[iy] = _mm256_fmadd_ps(my, mins, accd[iy]);
+ }
+
+ auto all_scales = _mm256_mul_ps(_mm256_set1_ps(deq.d), _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i *)utmp))));
+ scales[0] = _mm256_set_m128(_mm256_castps256_ps128(all_scales), _mm256_castps256_ps128(all_scales));
+ auto scales_h = _mm256_extractf128_ps(all_scales, 1);
+ scales[1] = _mm256_set_m128(scales_h, scales_h);
+
+ for (int j = 0; j < QK_K/128; ++j) {
+
+ deq.prepare(i, j);
+
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ const block_q8_2_x4& y = q8.y[iy][2*i+j];
+#ifdef HAVE_FANCY_SIMD
+ auto sumi1 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), deq.bits.values[0], _mm256_loadu_si256((const __m256i*)y.qs+0));
+ auto sumi2 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), deq.bits.values[1], _mm256_loadu_si256((const __m256i*)y.qs+1));
+ auto sumi3 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), deq.bits.values[2], _mm256_loadu_si256((const __m256i*)y.qs+2));
+ auto sumi4 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), deq.bits.values[3], _mm256_loadu_si256((const __m256i*)y.qs+3));
+ sumi1 = _mm256_add_epi32(_mm256_unpacklo_epi32(sumi1, sumi2), _mm256_unpackhi_epi32(sumi1, sumi2));
+ sumi3 = _mm256_add_epi32(_mm256_unpacklo_epi32(sumi3, sumi4), _mm256_unpackhi_epi32(sumi3, sumi4));
+ sumi1 = _mm256_add_epi32(_mm256_unpacklo_epi64(sumi1, sumi3), _mm256_unpackhi_epi64(sumi1, sumi3));
+#else
+ auto sumi1 = _mm256_maddubs_epi16(deq.bits.values[0], _mm256_loadu_si256((const __m256i*)y.qs+0));
+ auto sumi2 = _mm256_maddubs_epi16(deq.bits.values[1], _mm256_loadu_si256((const __m256i*)y.qs+1));
+ auto sumi3 = _mm256_maddubs_epi16(deq.bits.values[2], _mm256_loadu_si256((const __m256i*)y.qs+2));
+ auto sumi4 = _mm256_maddubs_epi16(deq.bits.values[3], _mm256_loadu_si256((const __m256i*)y.qs+3));
+ sumi1 = _mm256_add_epi16(_mm256_unpacklo_epi32(sumi1, sumi2), _mm256_unpackhi_epi32(sumi1, sumi2));
+ sumi3 = _mm256_add_epi16(_mm256_unpacklo_epi32(sumi3, sumi4), _mm256_unpackhi_epi32(sumi3, sumi4));
+ sumi1 = _mm256_add_epi16(_mm256_unpacklo_epi64(sumi1, sumi3), _mm256_unpackhi_epi64(sumi1, sumi3));
+ sumi1 = _mm256_madd_epi16(_mm256_set1_epi16(1), sumi1);
+#endif
+ auto dy4 = _mm_loadu_ps(d8 + 8*iy + 4*j);
+ auto d4d8 = _mm256_mul_ps(scales[j], _mm256_set_m128(dy4, dy4));
+ accd[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi1), accd[iy]);
+ }
+
+ }
+
+ }
+
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ info.store(ix, iy, hsum_float_8(accd[iy]));
+ }
+
+ }
+}
+
template <int nrc_y>
static void mul_mat_iq4_xs_r8_q8_k_avx2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(nrc_x%8 == 0);
@@ -1702,6 +1843,146 @@ static void mul_mat_q8_KV_r8_q8_KV(int n, const void * vx, size_t bx, const Data
}
}
+typedef struct {
+ ggml_half d[16];
+ int8_t qs[8*QK8_1];
+} block_q8_1_r8;
+
+void iqk_convert_q4_k_q8_1_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) {
+ GGML_ASSERT(n%QK_K == 0);
+ GGML_ASSERT(nrc_x%8 == 0);
+
+ int nb = n/QK_K;
+
+ const block_q4_K * x8[8];
+
+ block_q8_1_r8 * y = (block_q8_1_r8 *)vy;
+
+ ggml_half dh[16];
+ uint16_t all_ls[128];
+
+ uint32_t utmp[4];
+ const uint8_t * u8 = (const uint8_t *)utmp;
+ uint32_t block[8];
+
+ for (int ix = 0; ix < nrc_x; ix += 8) {
+ for (int k = 0; k < 8; ++k) x8[k] = (const block_q4_K *)((const char *)vx + (ix + k)*bx);
+ for (int i = 0; i < nb; ++i) {
+ for (int k = 0; k < 8; ++k) {
+ dh[k+0] = x8[k][i].d;
+ dh[k+8] = x8[k][i].dmin;
+ make_q4_scales(x8[k][i].scales, utmp);
+ auto qs = x8[k][i].qs;
+ for (int ib64 = 0; ib64 < 4; ++ib64) {
+ all_ls[8*(2*ib64 + 0) + k ] = u8[2*ib64+0];
+ all_ls[8*(2*ib64 + 1) + k ] = u8[2*ib64+1];
+ all_ls[8*(2*ib64 + 0) + k + 64] = u8[2*ib64+8];
+ all_ls[8*(2*ib64 + 1) + k + 64] = u8[2*ib64+9];
+ auto bits = _mm256_loadu_si256((const __m256i *)qs+ib64);
+ auto values1 = _mm256_and_si256(bits, _mm256_set1_epi8(0xf));
+ auto values2 = _mm256_and_si256(_mm256_srli_epi16(bits, 4), _mm256_set1_epi8(0xf));
+ _mm256_storeu_si256((__m256i *)block, values1);
+ auto q8 = (uint32_t *)y[2*ib64+0].qs;
+ for (int l = 0; l < 4; ++l) {
+ q8[8*l + k + 0] = block[l + 0];
+ q8[8*l + k + 32] = block[l + 4];
+ }
+ _mm256_storeu_si256((__m256i *)block, values2);
+ q8 = (uint32_t *)y[2*ib64+1].qs;
+ for (int l = 0; l < 4; ++l) {
+ q8[8*l + k + 0] = block[l + 0];
+ q8[8*l + k + 32] = block[l + 4];
+ }
+ }
+ }
+ auto vd = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)dh+0));
+ auto vm = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)dh+1));
+ vm = _mm256_mul_ps(_mm256_set1_ps(-1.f), vm);
+ for (int ib32 = 0; ib32 < QK_K/32; ++ib32) {
+ auto iscales16 = _mm_loadu_si128((const __m128i *)all_ls + ib32);
+ auto iscales32 = _mm256_cvtepi16_epi32(iscales16);
+ auto scales = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(iscales32));
+ _mm_storeu_si128((__m128i *)y[ib32].d+0, _mm256_cvtps_ph(scales, _MM_FROUND_TO_NEAREST_INT));
+ iscales16 = _mm_loadu_si128((const __m128i *)all_ls + ib32 + 8);
+ iscales32 = _mm256_cvtepi16_epi32(iscales16);
+ scales = _mm256_mul_ps(vm, _mm256_cvtepi32_ps(iscales32));
+ _mm_storeu_si128((__m128i *)y[ib32].d+1, _mm256_cvtps_ph(scales, _MM_FROUND_TO_NEAREST_INT));
+ }
+ y += QK_K/32;
+ }
+ }
+}
+
+void iqk_convert_q5_k_q8_1_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) {
+ GGML_ASSERT(n%QK_K == 0);
+ GGML_ASSERT(nrc_x%8 == 0);
+
+ int nb = n/QK_K;
+
+ const block_q5_K * x8[8];
+
+ block_q8_1_r8 * y = (block_q8_1_r8 *)vy;
+
+ ggml_half dh[16];
+ uint16_t all_ls[128];
+
+ uint32_t utmp[4];
+ const uint8_t * u8 = (const uint8_t *)utmp;
+ uint32_t block[8];
+
+ for (int ix = 0; ix < nrc_x; ix += 8) {
+ for (int k = 0; k < 8; ++k) x8[k] = (const block_q5_K *)((const char *)vx + (ix + k)*bx);
+ for (int i = 0; i < nb; ++i) {
+ for (int k = 0; k < 8; ++k) {
+ dh[k+0] = x8[k][i].d;
+ dh[k+8] = x8[k][i].dmin;
+ make_q4_scales(x8[k][i].scales, utmp);
+ auto qs = x8[k][i].qs;
+ auto hbits = _mm256_loadu_si256((const __m256i *)x8[k][i].qh);
+ for (int ib64 = 0; ib64 < 4; ++ib64) {
+ all_ls[8*(2*ib64 + 0) + k ] = u8[2*ib64+0];
+ all_ls[8*(2*ib64 + 1) + k ] = u8[2*ib64+1];
+ all_ls[8*(2*ib64 + 0) + k + 64] = u8[2*ib64+8];
+ all_ls[8*(2*ib64 + 1) + k + 64] = u8[2*ib64+9];
+ auto bits = _mm256_loadu_si256((const __m256i *)qs+ib64);
+ auto values1 = _mm256_and_si256(bits, _mm256_set1_epi8(0xf));
+ auto values2 = _mm256_and_si256(_mm256_srli_epi16(bits, 4), _mm256_set1_epi8(0xf));
+ values1 = _mm256_or_si256(values1, _mm256_and_si256(_mm256_set1_epi8(0x10), _mm256_slli_epi16(hbits, 4)));
+ values2 = _mm256_or_si256(values2, _mm256_and_si256(_mm256_set1_epi8(0x10), _mm256_slli_epi16(hbits, 3)));
+ hbits = _mm256_srli_epi16(hbits, 2);
+ _mm256_storeu_si256((__m256i *)block, values1);
+ auto q8 = (uint32_t *)y[2*ib64+0].qs;
+ for (int l = 0; l < 4; ++l) {
+ q8[8*l + k + 0] = block[l + 0];
+ q8[8*l + k + 32] = block[l + 4];
+ }
+ _mm256_storeu_si256((__m256i *)block, values2);
+ q8 = (uint32_t *)y[2*ib64+1].qs;
+ for (int l = 0; l < 4; ++l) {
+ q8[8*l + k + 0] = block[l + 0];
+ q8[8*l + k + 32] = block[l + 4];
+ }
+ }
+ }
+ auto vd = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)dh+0));
+ auto vm = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)dh+1));
+ vm = _mm256_mul_ps(_mm256_set1_ps(-1.f), vm);
+ for (int ib32 = 0; ib32 < QK_K/32; ++ib32) {
+ auto iscales16 = _mm_loadu_si128((const __m128i *)all_ls + ib32);
+ auto iscales32 = _mm256_cvtepi16_epi32(iscales16);
+ auto scales = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(iscales32));
+ _mm_storeu_si128((__m128i *)y[ib32].d+0, _mm256_cvtps_ph(scales, _MM_FROUND_TO_NEAREST_INT));
+ iscales16 = _mm_loadu_si128((const __m128i *)all_ls + ib32 + 8);
+ iscales32 = _mm256_cvtepi16_epi32(iscales16);
+ scales = _mm256_mul_ps(vm, _mm256_cvtepi32_ps(iscales32));
+ _mm_storeu_si128((__m128i *)y[ib32].d+1, _mm256_cvtps_ph(scales, _MM_FROUND_TO_NEAREST_INT));
+ }
+ y += QK_K/32;
+ }
+ }
+}
+
+
} // namespace
bool iqk_set_kernels_kquants(int ne00, int typeA, int typeB, std::array<mul_mat_t, IQK_MAX_NY>& kernels, mul_mat_t& func16) {
@@ -1710,6 +1991,7 @@ bool iqk_set_kernels_kquants(int ne00, int typeA, int typeB, std::array<mul_mat_
auto expected_type_B = etypeA == GGML_TYPE_IQ4_XS_R8 || etypeA == GGML_TYPE_Q4_K_R4 || etypeA == GGML_TYPE_Q5_K_R4 ? GGML_TYPE_Q8_K32
: etypeA == GGML_TYPE_Q8_K_R8 ? GGML_TYPE_Q8_KR8
: etypeA == GGML_TYPE_Q8_KV || etypeA == GGML_TYPE_Q8_KV_R8 ? GGML_TYPE_Q8_KV
+ : etypeA == GGML_TYPE_Q4_K || etypeA == GGML_TYPE_Q5_K ? GGML_TYPE_Q8_2_X4
: GGML_TYPE_Q8_K;
if (ne00%QK_K != 0 || ggml_type(typeB) != expected_type_B) {
@@ -1726,10 +2008,12 @@ bool iqk_set_kernels_kquants(int ne00, int typeA, int typeB, std::array<mul_mat_
set_functions<DequantizerQ3K>(kernels);
break;
case GGML_TYPE_Q4_K:
- set_functions<DequantizerQ4K>(kernels);
+ IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_K_q8_2_X4_T, DequantizerQ4K_AVX2, kernels);
+ //set_functions<DequantizerQ4K>(kernels);
break;
case GGML_TYPE_Q5_K:
- set_functions<DequantizerQ5K>(kernels);
+ IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_K_q8_2_X4_T, DequantizerQ5K_AVX2, kernels);
+ //set_functions<DequantizerQ5K>(kernels);
break;
case GGML_TYPE_Q6_K:
set_functions<DequantizerQ6K>(kernels);
@@ -1778,6 +2062,15 @@ bool iqk_set_kernels_kquants(int ne00, int typeA, int typeB, std::array<mul_mat_
}
+bool iqk_convert_kquants_q8X_r8(int type, int n, const void * vx, size_t bx, void * vy, int nrc_x) {
+ switch (ggml_type(type)) {
+ case GGML_TYPE_Q4_K: iqk_convert_q4_k_q8_1_r8(n, vx, bx, vy, nrc_x); break;
+ case GGML_TYPE_Q5_K: iqk_convert_q5_k_q8_1_r8(n, vx, bx, vy, nrc_x); break;
+ default: return false;
+ }
+ return true;
+}
+
#else
// --------------------------------- __aarch64__ --------------------------------------
diff --git a/ggml/src/iqk/iqk_gemm_kquants.h b/ggml/src/iqk/iqk_gemm_kquants.h
index 071d2e50..3518ebc4 100644
--- a/ggml/src/iqk/iqk_gemm_kquants.h
+++ b/ggml/src/iqk/iqk_gemm_kquants.h
@@ -10,4 +10,6 @@ bool iqk_set_kernels_kquants(int ne00, int typeA, int typeB, std::array<mul_mat_
void iqk_gemm_q8kv_fa(int D, int nq, int type_k, const char * k, size_t stride_k, DataInfo& info, int k_step);
+bool iqk_convert_kquants_q8X_r8(int type, int n, const void * vx, size_t bx, void * vy, int nrc_x);
+
#endif
diff --git a/ggml/src/iqk/iqk_gemm_legacy_quants.cpp b/ggml/src/iqk/iqk_gemm_legacy_quants.cpp
index 6e262aab..17d2dad3 100644
--- a/ggml/src/iqk/iqk_gemm_legacy_quants.cpp
+++ b/ggml/src/iqk/iqk_gemm_legacy_quants.cpp
@@ -1615,6 +1615,81 @@ static void mul_mat_q8_0_r8_q8_2(int n, const void * vx, size_t bx, const DataIn
}
#endif
+typedef struct {
+ ggml_half d[16];
+ uint8_t qs[256];
+} block_q8_1_r8;
+
+template <int nrc_y>
+static void mul_mat_q8_1_r8_q8_2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ GGML_ASSERT(nrc_x%8 == 0);
+ Q8<nrc_y, block_q8_2_x4> q8(info);
+ int nb = n / QK8_0;
+ __m256 acc[nrc_y] = {};
+ float d8[4*nrc_y];
+ __m256i qx[4];
+ auto dot = [&qx] (const int8_t * qy) {
+ auto y128 = _mm_loadu_si128((const __m128i*)qy);
+ auto y = MM256_SET_M128I(y128, y128);
+#ifdef HAVE_FANCY_SIMD
+ auto sumi = _mm256_setzero_si256();
+ sumi = _mm256_dpbusd_epi32(sumi, qx[0], _mm256_shuffle_epi32(y, 0x00));
+ sumi = _mm256_dpbusd_epi32(sumi, qx[1], _mm256_shuffle_epi32(y, 0x55));
+ sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(y, 0xaa));
+ sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(y, 0xff));
+ return sumi;
+#else
+ auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[0], _mm256_shuffle_epi32(y, 0x00)),
+ _mm256_maddubs_epi16(qx[1], _mm256_shuffle_epi32(y, 0x55)));
+ auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[2], _mm256_shuffle_epi32(y, 0xaa)),
+ _mm256_maddubs_epi16(qx[3], _mm256_shuffle_epi32(y, 0xff)));
+ return _mm256_add_epi32(_mm256_madd_epi16(_mm256_set1_epi16(1), sumi1), _mm256_madd_epi16(_mm256_set1_epi16(1), sumi2));
+#endif
+ };
+ for (int ix = 0; ix < nrc_x; ix += 8) {
+ const block_q8_1_r8 * iq8 = (const block_q8_1_r8 *)((const char *)vx + ix*bx);
+ for (int i4 = 0; i4 < nb/4; ++i4) {
+ {
+ __m256 mx[4];
+ for (int ib32 = 0; ib32 < 4; ++ib32) mx[ib32] = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq8[4*i4+ib32].d+1));
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto scales = _mm_castsi128_ps(_mm_slli_epi32(_mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)q8.y[iy][i4].d)), 16));
+ _mm_storeu_ps(d8 + 4*iy + 0, scales);
+ auto bsums4 = _mm_castsi128_ps(_mm_slli_epi32(_mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)(q8.y[iy][i4].d+4))), 16));
+ auto bsums = _mm256_set_m128(bsums4, bsums4);
+ acc[iy] = _mm256_fmadd_ps(mx[0], _mm256_shuffle_ps(bsums, bsums, 0x00), acc[iy]);
+ acc[iy] = _mm256_fmadd_ps(mx[1], _mm256_shuffle_ps(bsums, bsums, 0x55), acc[iy]);
+ acc[iy] = _mm256_fmadd_ps(mx[2], _mm256_shuffle_ps(bsums, bsums, 0xaa), acc[iy]);
+ acc[iy] = _mm256_fmadd_ps(mx[3], _mm256_shuffle_ps(bsums, bsums, 0xff), acc[iy]);
+ }
+ }
+ for (int ib32 = 0; ib32 < 4; ++ib32) {
+ auto scales = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq8[4*i4+ib32].d));
+ for (int j = 0; j < 4; ++j) {
+ qx[j] = _mm256_loadu_si256((const __m256i *)iq8[4*i4+ib32].qs+j);
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto sumi = dot(q8.y[iy][i4].qs+32*ib32);
+ auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(d8[4*iy+ib32]));
+ acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[iy]);
+ }
+ for (int j = 0; j < 4; ++j) {
+ qx[j] = _mm256_loadu_si256((const __m256i *)iq8[4*i4+ib32].qs+4+j);
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto sumi = dot(q8.y[iy][i4].qs+32*ib32+16);
+ auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(d8[4*iy+ib32]));
+ acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[iy]);
+ }
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ info.store(ix, iy, acc[iy]);
+ acc[iy] = _mm256_setzero_ps();
+ }
+ }
+}
+
template <typename Dequantizer> void set_functions(std::array<mul_mat_t, IQK_MAX_NY>& funcs) {
if constexpr (std::is_same_v<Dequantizer, Q4_0_Unpacker> || std::is_same_v<Dequantizer, Q5_0_Unpacker> ||
std::is_same_v<Dequantizer, Q8_0_Unpacker>) {
@@ -1694,6 +1769,9 @@ bool iqk_set_kernels_legacy_quants(int ne00, int typeA, int typeB, std::array<mu
case GGML_TYPE_IQ4_NL_R4:
IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq4_nl_r4_q8_2, kernels)
break;
+ case GGML_TYPE_Q8_1: // Note: we are misusing the Q8_1 type for Q8_1_R8
+ IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_q8_1_r8_q8_2, kernels)
+ break;
default:
return false;
}
diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp
index 53ce99a4..7c0d3aff 100644
--- a/ggml/src/iqk/iqk_mul_mat.cpp
+++ b/ggml/src/iqk/iqk_mul_mat.cpp
@@ -243,6 +243,8 @@ struct MulMat {
case GGML_TYPE_IQ3_XXS: return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
case GGML_TYPE_IQ3_S : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
case GGML_TYPE_IQ1_S : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
+ case GGML_TYPE_Q4_K : return nrc_y >= 32 ? GGML_TYPE_Q8_1 : type;
+ case GGML_TYPE_Q5_K : return nrc_y >= 32 ? GGML_TYPE_Q8_1 : type;
default: break;
}
#else
@@ -283,6 +285,7 @@ struct MulMat {
case GGML_TYPE_Q5_K_R4:
case GGML_TYPE_Q8_KV:
case GGML_TYPE_Q8_KV_R8:
+ case GGML_TYPE_Q8_1:
case GGML_TYPE_Q8_K_R8: return 8;
case GGML_TYPE_Q4_0_R8:
case GGML_TYPE_Q8_0_R8:
@@ -318,6 +321,7 @@ struct MulMat {
case GGML_TYPE_Q8_0_R8:
case GGML_TYPE_Q8_KV:
case GGML_TYPE_Q8_KV_R8:
+ case GGML_TYPE_Q8_1:
case GGML_TYPE_Q8_K_R8: return 8;
case GGML_TYPE_BF16_R16: return 16;
default: return 1;
@@ -341,8 +345,8 @@ bool iqk_convert_repack(int typeA, int n, const void * vx, size_t bx, void * vy,
// return iqk_set_kernels_float(ne00, typeA, typeB, mm.funcs);
//case GGML_TYPE_Q2_K:
//case GGML_TYPE_Q3_K:
- //case GGML_TYPE_Q4_K:
- //case GGML_TYPE_Q5_K:
+ case GGML_TYPE_Q4_K:
+ case GGML_TYPE_Q5_K:
//case GGML_TYPE_Q6_K:
//case GGML_TYPE_IQ4_XS:
//case GGML_TYPE_Q2_K_R4:
@@ -354,7 +358,7 @@ bool iqk_convert_repack(int typeA, int n, const void * vx, size_t bx, void * vy,
//case GGML_TYPE_Q8_K_R8:
//case GGML_TYPE_Q8_KV:
//case GGML_TYPE_Q8_KV_R8:
- // return iqk_set_kernels_kquants(ne00, typeA, typeB, mm.funcs, mm.func16);
+ return iqk_convert_kquants_q8X_r8(typeA, n, vx, bx, vy, nrc_x);
case GGML_TYPE_IQ2_XXS:
case GGML_TYPE_IQ2_XS:
case GGML_TYPE_IQ2_S:
@@ -790,6 +794,7 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) {
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q6_0:
case GGML_TYPE_Q8_0:
+ case GGML_TYPE_Q8_1:
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_Q4_0_R8:
case GGML_TYPE_Q5_0_R4: