-rw-r--r--   ggml/src/ggml.c                        4
-rw-r--r--   ggml/src/iqk/iqk_gemm_iqk_quants.cpp   686
-rw-r--r--   ggml/src/iqk/iqk_gemm_iqk_quants.h     2
-rw-r--r--   ggml/src/iqk/iqk_mul_mat.cpp           28
4 files changed, 710 insertions, 10 deletions
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index a6260136..69b1b46d 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -1699,7 +1699,11 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.from_float = quantize_row_iq5_k,
.from_float_ref = (ggml_from_float_t)quantize_row_iq5_k_ref,
.vec_dot = vec_dot_iq5_k_q8_k,
+//#ifdef __AVX2__
+// .vec_dot_type = GGML_TYPE_Q8_2_X4,
+//#else
.vec_dot_type = GGML_TYPE_Q8_K,
+//#endif
.nrows = 1,
.row_meta_size = 0,
},
diff --git a/ggml/src/iqk/iqk_gemm_iqk_quants.cpp b/ggml/src/iqk/iqk_gemm_iqk_quants.cpp
index 15c963ca..a01d7e4c 100644
--- a/ggml/src/iqk/iqk_gemm_iqk_quants.cpp
+++ b/ggml/src/iqk/iqk_gemm_iqk_quants.cpp
@@ -2053,8 +2053,694 @@ template <typename Dequantizer> void set_functions(std::array<mul_mat_t, IQK_MAX
#endif
}
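+// Helper shared by the converters below: qx holds the 256 dequantized int8 values of one
+// super-block for row k and scales the 16 per-block scales. The quants are multiplied by their
+// block scales in 16-bit, the absolute maximum determines a single super-block scale relative
+// to d0, and the values are re-quantized to 8 bits and scattered into the interleaved Q8_K_R8
+// layout in q8_k. The caller multiplies the returned factor with the row/block scale d.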
+inline float convert_to_q8_k_r8(int k, float d0, const __m256i * qx, const int16_t * scales, uint32_t * block, int8_t * q8_k) {
+ auto max_i16 = _mm256_setzero_si256();
+ __m256i qs[16];
+ for (int ib32 = 0; ib32 < 8; ++ib32) {
+ qs[2*ib32+0] = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(qx[ib32]));
+ qs[2*ib32+1] = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(qx[ib32], 1));
+ qs[2*ib32+0] = _mm256_mullo_epi16(qs[2*ib32+0], _mm256_set1_epi16(scales[2*ib32+0]));
+ qs[2*ib32+1] = _mm256_mullo_epi16(qs[2*ib32+1], _mm256_set1_epi16(scales[2*ib32+1]));
+ max_i16 = _mm256_max_epi16(max_i16, _mm256_sign_epi16(qs[2*ib32+0], qs[2*ib32+0]));
+ max_i16 = _mm256_max_epi16(max_i16, _mm256_sign_epi16(qs[2*ib32+1], qs[2*ib32+1]));
+ }
+ auto max_q32 = _mm256_cvtepi16_epi32(_mm_max_epi16(_mm256_castsi256_si128(max_i16), _mm256_extracti128_si256(max_i16, 1)));
+ auto imax4 = _mm_max_epi32(_mm256_castsi256_si128(max_q32), _mm256_extracti128_si256(max_q32, 1));
+ auto max4 = _mm_cvtepi32_ps(imax4);
+ max4 = _mm_max_ps(max4, _mm_movehl_ps(max4, max4));
+ max4 = _mm_max_ss(max4, _mm_movehdup_ps(max4));
+ bool needs_scaling = true;
+ float dnew = _mm_cvtss_f32(max4) * d0;
+ if (dnew < 1.f) {
+ dnew = 1.f; needs_scaling = false;
+ }
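+    // If the maximum scaled quant is small enough (dnew < 1), the int16 products are stored as is;
+    // otherwise they are multiplied by 1/dnew and rounded so that they fit into 8 bits.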
+ auto scale = _mm256_set1_ps(std::abs(dnew) > 1e-9f ? 1/dnew : 0.f);
+ for (int ib32 = 0; ib32 < 8; ++ib32) {
+ if (needs_scaling) {
+ auto i0 = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(qs[2*ib32+0]));
+ auto i1 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(qs[2*ib32+0], 1));
+ auto i2 = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(qs[2*ib32+1]));
+ auto i3 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(qs[2*ib32+1], 1));
+ i0 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(scale, _mm256_cvtepi32_ps(i0)), _MM_ROUND_NEAREST));
+ i1 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(scale, _mm256_cvtepi32_ps(i1)), _MM_ROUND_NEAREST));
+ i2 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(scale, _mm256_cvtepi32_ps(i2)), _MM_ROUND_NEAREST));
+ i3 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(scale, _mm256_cvtepi32_ps(i3)), _MM_ROUND_NEAREST));
+ i0 = _mm256_packs_epi32(i0, i1);
+ i2 = _mm256_packs_epi32(i2, i3);
+ i0 = _mm256_packs_epi16(i0, i2);
+ i0 = _mm256_permutevar8x32_epi32(i0, _mm256_setr_epi32(0, 4, 1, 5, 2, 6, 3, 7));
+ _mm256_storeu_si256((__m256i *)block, i0);
+ } else {
+            // After _mm256_packs_epi16 the byte order is 0..7, 16..23, 8..15, 24..31,
+            // so the 64-bit unpacks below restore the natural 0..31 order.
+ auto i0 = _mm256_packs_epi16(qs[2*ib32+0], qs[2*ib32+1]);
+ auto i0_l = _mm256_castsi256_si128(i0);
+ auto i0_h = _mm256_extracti128_si256(i0, 1);
+ _mm_storeu_si128((__m128i *)block+0, _mm_unpacklo_epi64(i0_l, i0_h));
+ _mm_storeu_si128((__m128i *)block+1, _mm_unpackhi_epi64(i0_l, i0_h));
+ }
+ auto qs = (uint32_t *)q8_k + 64*ib32;
+ for (int l = 0; l < 8; ++l) {
+ qs[8*l + k] = block[l];
+ }
+ }
+ return dnew;
+}
+
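+// Each iqk_convert_*_q8_k_r8 function below dequantizes 8 rows at a time of the corresponding
+// IQK quant type and repacks them as Q8_K_R8 (8 interleaved rows), so that large batches
+// (nrc_y >= 32, see iqk_mul_mat.cpp) can be processed via the Q8_K_R8 GEMM path.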
+void iqk_convert_iq2_ks_q8_k_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) {
+ GGML_ASSERT(n%QK_K == 0);
+ GGML_ASSERT(nrc_x%8 == 0);
+
+ int nb = n/QK_K;
+
+ const block_iq2_ks * x8[8];
+
+ block_q8_k_r8 * y = (block_q8_k_r8 *)vy;
+
+ __m256i values;
+ {
+ auto v = _mm_loadl_epi64((const __m128i *)iq2nl_values);
+ values = MM256_SET_M128I(v, v);
+ }
+
+ ggml_half dh[8];
+ float dnew[8];
+ uint32_t block[8];
+ int16_t ls[16];
+
+ __m256i xv[8];
+
+ auto ml = _mm256_set1_epi8(0x03);
+
+ for (int ix = 0; ix < nrc_x; ix += 8) {
+ for (int k = 0; k < 8; ++k) {
+ const ggml_half * dptr = (const ggml_half *)((const char *)vx + (ix+k)*bx);
+ dh[k] = dptr[0];
+ x8[k] = (const block_iq2_ks *)(dptr + 1);
+ }
+ for (int i = 0; i < nb; ++i) {
+ for (int k = 0; k < 8; ++k) {
+ auto extra = x8[k][i].extra;
+ for (int i128 = 0; i128 < 2; ++i128) {
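+                    // 4-bit block scales extended with one high bit from `extra`, biased by -16;
+                    // the low bits of `extra` add 4 to the 2-bit quants to select the shifted
+                    // half of the iq2nl_values table.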
+ ls[8*i128+0] = ls[8*i128+1] = ((x8[k][i].scales[2*i128+0] & 0xf) | ((extra >> 4) & 0x10)) - 16;
+ ls[8*i128+2] = ls[8*i128+3] = ((x8[k][i].scales[2*i128+0] >> 4) | ((extra >> 5) & 0x10)) - 16;
+ ls[8*i128+4] = ls[8*i128+5] = ((x8[k][i].scales[2*i128+1] & 0xf) | ((extra >> 6) & 0x10)) - 16;
+ ls[8*i128+6] = ls[8*i128+7] = ((x8[k][i].scales[2*i128+1] >> 4) | ((extra >> 7) & 0x10)) - 16;
+ auto bits = _mm256_loadu_si256((const __m256i *)x8[k][i].qs+i128);
+ xv[4*i128+0] = _mm256_and_si256(bits, ml);
+ xv[4*i128+1] = _mm256_and_si256(_mm256_srli_epi16(bits, 2), ml);
+ xv[4*i128+2] = _mm256_and_si256(_mm256_srli_epi16(bits, 4), ml);
+ xv[4*i128+3] = _mm256_and_si256(_mm256_srli_epi16(bits, 6), ml);
+ xv[4*i128+0] = _mm256_add_epi8(xv[4*i128+0], _mm256_set1_epi8((extra << 2) & 0x04));
+ xv[4*i128+1] = _mm256_add_epi8(xv[4*i128+1], _mm256_set1_epi8((extra << 1) & 0x04));
+ xv[4*i128+2] = _mm256_add_epi8(xv[4*i128+2], _mm256_set1_epi8((extra >> 0) & 0x04));
+ xv[4*i128+3] = _mm256_add_epi8(xv[4*i128+3], _mm256_set1_epi8((extra >> 1) & 0x04));
+ xv[4*i128+0] = _mm256_shuffle_epi8(values, xv[4*i128+0]);
+ xv[4*i128+1] = _mm256_shuffle_epi8(values, xv[4*i128+1]);
+ xv[4*i128+2] = _mm256_shuffle_epi8(values, xv[4*i128+2]);
+ xv[4*i128+3] = _mm256_shuffle_epi8(values, xv[4*i128+3]);
+ extra >>= 4;
+ }
+ dnew[k] = convert_to_q8_k_r8(k, 1.f/125, xv, ls, block, y[i].qs);
+ }
+ auto vd = _mm256_mul_ps(_mm256_loadu_ps(dnew), _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)dh)));
+ _mm_storeu_si128((__m128i *)y[i].d, _mm256_cvtps_ph(vd, _MM_ROUND_NEAREST));
+ }
+ y += nb;
+ }
+}
+
+void iqk_convert_iq2_k_q8_k_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) {
+ GGML_ASSERT(n%QK_K == 0);
+ GGML_ASSERT(nrc_x%8 == 0);
+
+ int nb = n/QK_K;
+
+ const block_iq2_k * x8[8];
+
+ block_q8_k_r8 * y = (block_q8_k_r8 *)vy;
+
+ __m256i values;
+ {
+ auto v = _mm_loadl_epi64((const __m128i *)iq2nl_values);
+ values = MM256_SET_M128I(v, v);
+ }
+
+ __m256i xv[8];
+ uint32_t block[8];
+
+ const __m128i scale_shuffle = _mm_set_epi32(0x0f070e06, 0x0d050c04, 0x0b030a02, 0x09010800);
+
+ union { __m256i vec; int16_t val[16]; } helper;
+
+ auto ml = _mm256_set1_epi8(0x03);
+
+ for (int ix = 0; ix < nrc_x; ix += 8) {
+ for (int k = 0; k < 8; ++k) x8[k] = (const block_iq2_k *)((const char *)vx + (ix+k)*bx);
+ for (int i = 0; i < nb; ++i) {
+ for (int k = 0; k < 8; ++k) {
+ float d = GGML_FP16_TO_FP32(x8[k][i].d);
+ uint64_t aux64; std::memcpy(&aux64, x8[k][i].scales, 8);
+ auto scl = _mm_and_si128(_mm_set_epi64x(aux64 >> 4, aux64), _mm_set1_epi8(0xf));
+ scl = _mm_add_epi8(scl, _mm_set1_epi8(-8));
+ helper.vec = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scl, scale_shuffle));
+ auto extra = x8[k][i].extra;
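+                // Each bit of `extra` adds 4 to the 2-bit quants of one 32-quant block, selecting
+                // the upper half of the 8-entry iq2nl_values table.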
+ for (int i128 = 0; i128 < 2; ++i128) {
+ auto bits = _mm256_loadu_si256((const __m256i *)x8[k][i].qs+i128);
+ xv[4*i128+0] = _mm256_and_si256(bits, ml);
+ xv[4*i128+1] = _mm256_and_si256(_mm256_srli_epi16(bits, 2), ml);
+ xv[4*i128+2] = _mm256_and_si256(_mm256_srli_epi16(bits, 4), ml);
+ xv[4*i128+3] = _mm256_and_si256(_mm256_srli_epi16(bits, 6), ml);
+ auto shift1 = MM256_SET_M128I(_mm_set1_epi8((extra & 0x02) << 1), _mm_set1_epi8((extra & 0x01) << 2));
+ auto shift2 = MM256_SET_M128I(_mm_set1_epi8((extra & 0x08) >> 1), _mm_set1_epi8((extra & 0x04) >> 0));
+ auto shift3 = MM256_SET_M128I(_mm_set1_epi8((extra & 0x20) >> 3), _mm_set1_epi8((extra & 0x10) >> 2));
+ auto shift4 = MM256_SET_M128I(_mm_set1_epi8((extra & 0x80) >> 5), _mm_set1_epi8((extra & 0x40) >> 4));
+ xv[4*i128+0] = _mm256_add_epi8(xv[4*i128+0], shift1);
+ xv[4*i128+1] = _mm256_add_epi8(xv[4*i128+1], shift2);
+ xv[4*i128+2] = _mm256_add_epi8(xv[4*i128+2], shift3);
+ xv[4*i128+3] = _mm256_add_epi8(xv[4*i128+3], shift4);
+ xv[4*i128+0] = _mm256_shuffle_epi8(values, xv[4*i128+0]);
+ xv[4*i128+1] = _mm256_shuffle_epi8(values, xv[4*i128+1]);
+ xv[4*i128+2] = _mm256_shuffle_epi8(values, xv[4*i128+2]);
+ xv[4*i128+3] = _mm256_shuffle_epi8(values, xv[4*i128+3]);
+ extra >>= 8;
+ }
+ float dnew = convert_to_q8_k_r8(k, 1.f/120, xv, helper.val, block, y[i].qs);
+ y[i].d[k] = GGML_FP32_TO_FP16(d*dnew);
+ }
+ }
+ y += nb;
+ }
+}
+
+void iqk_convert_iq3_k_q8_k_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) {
+ GGML_ASSERT(n%QK_K == 0);
+ GGML_ASSERT(nrc_x%8 == 0);
+
+ int nb = n/QK_K;
+
+ const block_iq3_k * x8[8];
+
+ block_q8_k_r8 * y = (block_q8_k_r8 *)vy;
+
+ __m256i values;
+ {
+ auto v = _mm_loadu_si128((const __m128i *)iq3nl_values);
+ values = MM256_SET_M128I(v, v);
+ }
+
+ __m256i xv[8];
+ uint32_t block[8];
+
+ constexpr static uint8_t k_shuff[16] = {0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15};
+ const __m128i sign_mask = _mm_set_epi64x(0x8080404020201010, 0x0808040402020101);
+ const __m128i hshuff = _mm_loadu_si128((const __m128i*)k_shuff);
+ const __m128i scale_shuffle = _mm_set_epi32(0x0f070e06, 0x0d050c04, 0x0b030a02, 0x09010800);
+
+ union { __m256i vec; int16_t val[16]; } helper;
+
+ auto ml = _mm256_set1_epi8(0x03);
+ auto hmask = _mm256_set1_epi8(4);
+
+ for (int ix = 0; ix < nrc_x; ix += 8) {
+ for (int k = 0; k < 8; ++k) x8[k] = (const block_iq3_k *)((const char *)vx + (ix+k)*bx);
+ for (int i = 0; i < nb; ++i) {
+ for (int k = 0; k < 8; ++k) {
+ float d = GGML_FP16_TO_FP32(x8[k][i].d);
+ uint64_t aux64; std::memcpy(&aux64, x8[k][i].scales_l, 8);
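+                // IQ3_K block scales: 4-bit magnitudes mapped to the odd values 2*s+1, with the
+                // sign of each block taken from the corresponding bit of scales_h.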
+ auto scl = _mm_and_si128(_mm_set_epi64x(aux64 >> 4, aux64), _mm_set1_epi8(0xf));
+ scl = _mm_add_epi8(_mm_slli_epi16(scl, 1), _mm_set1_epi8(1));
+ auto sc_signs = _mm_cmpeq_epi8(_mm_and_si128(_mm_set1_epi16(x8[k][i].scales_h), sign_mask), sign_mask);
+ auto sch = _mm_shuffle_epi8(_mm_or_si128(sc_signs, _mm_set1_epi8(1)), hshuff);
+ helper.vec = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(_mm_sign_epi8(scl, sch), scale_shuffle));
+ auto extra = x8[k][i].extra;
+ auto hbits = _mm256_loadu_si256((const __m256i *)x8[k][i].qh);
+ for (int i128 = 0; i128 < 2; ++i128) {
+ auto bits = _mm256_loadu_si256((const __m256i *)x8[k][i].qs+i128);
+ xv[4*i128+0] = _mm256_and_si256(bits, ml);
+ xv[4*i128+1] = _mm256_and_si256(_mm256_srli_epi16(bits, 2), ml);
+ xv[4*i128+2] = _mm256_and_si256(_mm256_srli_epi16(bits, 4), ml);
+ xv[4*i128+3] = _mm256_and_si256(_mm256_srli_epi16(bits, 6), ml);
+ xv[4*i128+0] = _mm256_or_si256(xv[4*i128+0], _mm256_and_si256(_mm256_slli_epi16(hbits, 2), hmask));
+ xv[4*i128+1] = _mm256_or_si256(xv[4*i128+1], _mm256_and_si256(_mm256_slli_epi16(hbits, 1), hmask));
+ xv[4*i128+2] = _mm256_or_si256(xv[4*i128+2], _mm256_and_si256(hbits, hmask));
+ xv[4*i128+3] = _mm256_or_si256(xv[4*i128+3], _mm256_and_si256(_mm256_srli_epi16(hbits, 1), hmask));
+ auto shift1 = MM256_SET_M128I(_mm_set1_epi8((extra & 0x02) << 2), _mm_set1_epi8((extra & 0x01) << 3));
+ auto shift2 = MM256_SET_M128I(_mm_set1_epi8((extra & 0x08) << 0), _mm_set1_epi8((extra & 0x04) << 1));
+ auto shift3 = MM256_SET_M128I(_mm_set1_epi8((extra & 0x20) >> 2), _mm_set1_epi8((extra & 0x10) >> 1));
+ auto shift4 = MM256_SET_M128I(_mm_set1_epi8((extra & 0x80) >> 4), _mm_set1_epi8((extra & 0x40) >> 3));
+ xv[4*i128+0] = _mm256_add_epi8(xv[4*i128+0], shift1);
+ xv[4*i128+1] = _mm256_add_epi8(xv[4*i128+1], shift2);
+ xv[4*i128+2] = _mm256_add_epi8(xv[4*i128+2], shift3);
+ xv[4*i128+3] = _mm256_add_epi8(xv[4*i128+3], shift4);
+ xv[4*i128+0] = _mm256_shuffle_epi8(values, xv[4*i128+0]);
+ xv[4*i128+1] = _mm256_shuffle_epi8(values, xv[4*i128+1]);
+ xv[4*i128+2] = _mm256_shuffle_epi8(values, xv[4*i128+2]);
+ xv[4*i128+3] = _mm256_shuffle_epi8(values, xv[4*i128+3]);
+ hbits = _mm256_srli_epi16(hbits, 4);
+ extra >>= 8;
+ }
+ float dnew = convert_to_q8_k_r8(k, 1.f/127, xv, helper.val, block, y[i].qs);
+ y[i].d[k] = GGML_FP32_TO_FP16(d*dnew);
+ }
+ }
+ y += nb;
+ }
+}
+
+void iqk_convert_iq4_ks_q8_k_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) {
+ GGML_ASSERT(n%QK_K == 0);
+ GGML_ASSERT(nrc_x%8 == 0);
+
+ int nb = n/QK_K;
+
+ const block_iq4_ks * x8[8];
+
+ block_q8_k_r8 * y = (block_q8_k_r8 *)vy;
+
+ __m256i values[2];
+ {
+ auto v1 = _mm_loadu_si128((const __m128i *)iq4k_values+0);
+ auto v2 = _mm_loadu_si128((const __m128i *)iq4k_values+1);
+ values[0] = MM256_SET_M128I(v1, v1);
+ values[1] = MM256_SET_M128I(v2, v2);
+ }
+
+ float drow[8];
+ float dnew[8];
+ int16_t ls[16];
+
+ __m256i xv[8];
+ uint32_t block[8];
+
+ for (int ix = 0; ix < nrc_x; ix += 8) {
+ for (int k = 0; k < 8; ++k) {
+ const float * dptr = (const float *)((const char *)vx + (ix + k)*bx);
+ drow[k] = dptr[0];
+ x8[k] = (const block_iq4_ks *)(dptr + 1);
+ }
+ auto vd = _mm256_loadu_ps(drow);
+ for (int i = 0; i < nb; ++i) {
+ for (int k = 0; k < 8; ++k) {
+ for (int ib32 = 0; ib32 < 8; ++ib32) {
+ ls[2*ib32+0] = (x8[k][i].scales[ib32] & 254) - 127;
+ ls[2*ib32+1] = ls[2*ib32+0];
+ auto aux128 = _mm_loadu_si128((const __m128i *)x8[k][i].qs+ib32);
+ xv[ib32] = _mm256_and_si256(MM256_SET_M128I(_mm_srli_epi16(aux128, 4), aux128), _mm256_set1_epi8(0xf));
+ xv[ib32] = _mm256_shuffle_epi8(values[x8[k][i].scales[ib32] & 1], xv[ib32]);
+ }
+ dnew[k] = convert_to_q8_k_r8(k, 1.f/127, xv, ls, block, y[i].qs);
+ }
+ _mm_storeu_si128((__m128i *)y[i].d, _mm256_cvtps_ph(_mm256_mul_ps(vd, _mm256_loadu_ps(dnew)), _MM_ROUND_NEAREST));
+ }
+ y += nb;
+ }
+}
+
+void iqk_convert_iq4_k_q8_k_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) {
+ GGML_ASSERT(n%QK_K == 0);
+ GGML_ASSERT(nrc_x%8 == 0);
+
+ int nb = n/QK_K;
+
+ const block_iq4_k * x8[8];
+
+ block_q8_k_r8 * y = (block_q8_k_r8 *)vy;
+
+ __m256i values[4];
+ {
+ auto v1 = _mm_loadu_si128((const __m128i *)iq4k_values+0);
+ auto v2 = _mm_loadu_si128((const __m128i *)iq4k_values+1);
+ values[0] = MM256_SET_M128I(v1, v1);
+ values[1] = MM256_SET_M128I(v1, v2);
+ values[2] = MM256_SET_M128I(v2, v1);
+ values[3] = MM256_SET_M128I(v2, v2);
+ }
+
+ __m256i xv[8];
+ uint32_t block[8];
+ int16_t ls[16];
+
+ //auto hshuff = _mm_set_epi32(0x0f070e06, 0x0d050c04, 0x0b030a02, 0x09010800);
+
+ //union { __m256i vec; int16_t val[16]; } helper;
+
+ for (int ix = 0; ix < nrc_x; ix += 8) {
+ for (int k = 0; k < 8; ++k) x8[k] = (const block_iq4_k *)((const char *)vx + (ix+k)*bx);
+ for (int i = 0; i < nb; ++i) {
+ for (int k = 0; k < 8; ++k) {
+ float d = GGML_FP16_TO_FP32(x8[k][i].d);
+ auto extra = x8[k][i].extra;
+ //uint64_t aux64;
+ //memcpy(&aux64, x8[k][i].scales_l, 8);
+ //auto scl = _mm_and_si128(_mm_set_epi64x(aux64 >> 4, aux64), _mm_set1_epi8(0xf));
+ //const uint32_t aux32 = *(const uint32_t *)x8[k][i].scales_h;
+ //auto aux = _mm_and_si128(_mm_set_epi32(aux32 >> 2, aux32, aux32 << 2, aux32 << 4), _mm_set1_epi8(0x30));
+ //auto sch = _mm_shuffle_epi8(aux, hshuff);
+ //aux = _mm_add_epi8(_mm_or_si128(scl, sch), _mm_set1_epi8(-32));
+ //helper.vec = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(aux, hshuff));
+ for (int ib32 = 0; ib32 < 8; ++ib32) {
+ const uint8_t sh = x8[k][i].scales_h[ib32/2] >> 4*(ib32%2);
+ ls[2*ib32+0] = ((x8[k][i].scales_l[ib32] & 0xf) | ((sh << 4) & 0x30)) - 32;
+ ls[2*ib32+1] = ((x8[k][i].scales_l[ib32] >> 4) | ((sh << 2) & 0x30)) - 32;
+ auto bits = _mm_loadu_si128((const __m128i *)x8[k][i].qs+ib32);
+ xv[ib32] = _mm256_and_si256(MM256_SET_M128I(_mm_srli_epi16(bits, 4), bits), _mm256_set1_epi8(0xf));
+ xv[ib32] = _mm256_shuffle_epi8(values[extra & 3], xv[ib32]); extra >>= 2;
+ }
+ //float dnew = convert_to_q8_k_r8(k, 1.f/127, xv, helper.val, block, y[i].qs);
+ float dnew = convert_to_q8_k_r8(k, 1.f/127, xv, ls, block, y[i].qs);
+ y[i].d[k] = GGML_FP32_TO_FP16(d*dnew);
+ }
+ }
+ y += nb;
+ }
+}
+
+void iqk_convert_iq5_ks_q8_k_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) {
+ GGML_ASSERT(n%QK_K == 0);
+ GGML_ASSERT(nrc_x%8 == 0);
+
+ int nb = n/QK_K;
+
+ const block_iq5_ks * x8[8];
+
+ block_q8_k_r8 * y = (block_q8_k_r8 *)vy;
+
+ __m256i values[2];
+ {
+ auto v1 = _mm_loadu_si128((const __m128i *)iq5nl_values+0);
+ auto v2 = _mm_loadu_si128((const __m128i *)iq5nl_values+1);
+ values[0] = MM256_SET_M128I(v1, v1);
+ values[1] = MM256_SET_M128I(v2, v2);
+ }
+
+ float drow[8];
+ float dnew[8];
+ int16_t ls[16];
+
+ __m256i xv[8];
+ uint32_t block[8];
+
+    auto mh = _mm256_set1_epi8(-128); // -128 instead of 0x80 to avoid a compiler warning about 0x80 overflowing signed char
+
+ for (int ix = 0; ix < nrc_x; ix += 8) {
+ for (int k = 0; k < 8; ++k) {
+ const float * dptr = (const float *)((const char *)vx + (ix + k)*bx);
+ drow[k] = dptr[0];
+ x8[k] = (const block_iq5_ks *)(dptr + 1);
+ }
+ auto vd = _mm256_loadu_ps(drow);
+ for (int i = 0; i < nb; ++i) {
+ for (int k = 0; k < 8; ++k) {
+ auto hbits = _mm256_loadu_si256((const __m256i *)x8[k][i].qh);
+ for (int ib64 = 0; ib64 < 4; ++ib64) {
+ ls[4*ib64+0] = (x8[k][i].scales[2*ib64+0] & 254) - 127;
+ ls[4*ib64+1] = ls[4*ib64+0];
+ ls[4*ib64+2] = (x8[k][i].scales[2*ib64+1] & 254) - 127;
+ ls[4*ib64+3] = ls[4*ib64+2];
+ auto bits = _mm256_loadu_si256((const __m256i *)x8[k][i].qs+ib64);
+ xv[2*ib64+0] = _mm256_and_si256(bits, _mm256_set1_epi8(0xf));
+ xv[2*ib64+1] = _mm256_and_si256(_mm256_srli_epi16(bits, 4), _mm256_set1_epi8(0xf));
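+                    // _mm256_shuffle_epi8 returns 0 for bytes whose selector has the high bit set,
+                    // so OR-ing the 4-bit quants with qh (and with qh^0x80) selects between the two
+                    // halves of the 32-entry iq5nl_values table; the two shuffles are then OR-ed.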
+ auto qh = _mm256_and_si256(_mm256_slli_epi16(hbits, 7), mh);
+ auto q5vl = _mm256_or_si256(xv[2*ib64+0], qh);
+ auto q5vh = _mm256_or_si256(xv[2*ib64+0], _mm256_xor_si256(qh, mh));
+ xv[2*ib64+0] = _mm256_or_si256(_mm256_shuffle_epi8(values[0], q5vl), _mm256_shuffle_epi8(values[1], q5vh));
+ qh = _mm256_and_si256(_mm256_slli_epi16(hbits, 6), mh);
+ q5vl = _mm256_or_si256(xv[2*ib64+1], qh);
+ q5vh = _mm256_or_si256(xv[2*ib64+1], _mm256_xor_si256(qh, mh));
+ xv[2*ib64+1] = _mm256_or_si256(_mm256_shuffle_epi8(values[0], q5vl), _mm256_shuffle_epi8(values[1], q5vh));
+ auto shift1 = _mm256_set1_epi8((x8[k][i].scales[2*ib64+0] & 1) << 1);
+ auto shift2 = _mm256_set1_epi8((x8[k][i].scales[2*ib64+1] & 1) << 1);
+ xv[2*ib64+0] = _mm256_add_epi8(xv[2*ib64+0], shift1);
+ xv[2*ib64+1] = _mm256_add_epi8(xv[2*ib64+1], shift2);
+ hbits = _mm256_srli_epi16(hbits, 2);
+ }
+ dnew[k] = convert_to_q8_k_r8(k, 1.f/127, xv, ls, block, y[i].qs);
+ }
+ _mm_storeu_si128((__m128i *)y[i].d, _mm256_cvtps_ph(_mm256_mul_ps(vd, _mm256_loadu_ps(dnew)), _MM_ROUND_NEAREST));
+ }
+ y += nb;
+ }
+}
+
+void iqk_convert_iq5_k_q8_k_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) {
+ GGML_ASSERT(n%QK_K == 0);
+ GGML_ASSERT(nrc_x%8 == 0);
+
+ int nb = n/QK_K;
+
+ const block_iq5_k * x8[8];
+
+ block_q8_k_r8 * y = (block_q8_k_r8 *)vy;
+
+ __m256i values[2];
+ {
+ auto v1 = _mm_loadu_si128((const __m128i *)iq5nl_values+0);
+ auto v2 = _mm_loadu_si128((const __m128i *)iq5nl_values+1);
+ values[0] = MM256_SET_M128I(v1, v1);
+ values[1] = MM256_SET_M128I(v2, v2);
+ }
+
+ __m256i xv[8];
+ uint32_t block[8];
+ int16_t ls[16];
+
+    auto mh = _mm256_set1_epi8(-128); // -128 instead of 0x80 to avoid a compiler warning about 0x80 overflowing signed char
+
+ for (int ix = 0; ix < nrc_x; ix += 8) {
+ for (int k = 0; k < 8; ++k) x8[k] = (const block_iq5_k *)((const char *)vx + (ix+k)*bx);
+ for (int i = 0; i < nb; ++i) {
+ for (int k = 0; k < 8; ++k) {
+ float d = GGML_FP16_TO_FP32(x8[k][i].d);
+ auto extra = x8[k][i].extra;
+ auto hbits = _mm256_loadu_si256((const __m256i *)x8[k][i].qh);
+ for (int ib64 = 0; ib64 < 4; ++ib64) {
+ ls[4*ib64+0] = ((x8[k][i].scales_l[2*ib64+0] & 0xf) | ((x8[k][i].scales_h[ib64] << 4) & 0x30)) - 32;
+ ls[4*ib64+1] = ((x8[k][i].scales_l[2*ib64+0] >> 4) | ((x8[k][i].scales_h[ib64] << 2) & 0x30)) - 32;
+ ls[4*ib64+2] = ((x8[k][i].scales_l[2*ib64+1] & 0xf) | ((x8[k][i].scales_h[ib64] >> 0) & 0x30)) - 32;
+ ls[4*ib64+3] = ((x8[k][i].scales_l[2*ib64+1] >> 4) | ((x8[k][i].scales_h[ib64] >> 2) & 0x30)) - 32;
+ auto bits = _mm256_loadu_si256((const __m256i *)x8[k][i].qs+ib64);
+ xv[2*ib64+0] = _mm256_and_si256(bits, _mm256_set1_epi8(0xf));
+ xv[2*ib64+1] = _mm256_and_si256(_mm256_srli_epi16(bits, 4), _mm256_set1_epi8(0xf));
+ auto qh = _mm256_and_si256(_mm256_slli_epi16(hbits, 7), mh);
+ auto q5vl = _mm256_or_si256(xv[2*ib64+0], qh);
+ auto q5vh = _mm256_or_si256(xv[2*ib64+0], _mm256_xor_si256(qh, mh));
+ xv[2*ib64+0] = _mm256_or_si256(_mm256_shuffle_epi8(values[0], q5vl), _mm256_shuffle_epi8(values[1], q5vh));
+ qh = _mm256_and_si256(_mm256_slli_epi16(hbits, 6), mh);
+ q5vl = _mm256_or_si256(xv[2*ib64+1], qh);
+ q5vh = _mm256_or_si256(xv[2*ib64+1], _mm256_xor_si256(qh, mh));
+ xv[2*ib64+1] = _mm256_or_si256(_mm256_shuffle_epi8(values[0], q5vl), _mm256_shuffle_epi8(values[1], q5vh));
+ auto shift1 = MM256_SET_M128I(_mm_set1_epi8((extra & 2) << 0), _mm_set1_epi8((extra & 1) << 1));
+ auto shift2 = MM256_SET_M128I(_mm_set1_epi8((extra & 8) >> 2), _mm_set1_epi8((extra & 4) >> 1));
+ xv[2*ib64+0] = _mm256_add_epi8(xv[2*ib64+0], shift1);
+ xv[2*ib64+1] = _mm256_add_epi8(xv[2*ib64+1], shift2);
+ hbits = _mm256_srli_epi16(hbits, 2);
+ extra >>= 4;
+ }
+ float dnew = convert_to_q8_k_r8(k, 1.f/127, xv, ls, block, y[i].qs);
+ y[i].d[k] = GGML_FP32_TO_FP16(d*dnew);
+ }
+ }
+ y += nb;
+ }
+}
+
+void iqk_convert_iq5_k_q8_0_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) {
+ GGML_ASSERT(n%QK_K == 0);
+ GGML_ASSERT(nrc_x%8 == 0);
+
+ int nb = n/QK_K;
+
+ const block_iq5_k * x8[8];
+
+ block_q8_0_r8 * y = (block_q8_0_r8 *)vy;
+
+ __m256i values[2];
+ {
+ auto v1 = _mm_loadu_si128((const __m128i *)iq5nl_values+0);
+ auto v2 = _mm_loadu_si128((const __m128i *)iq5nl_values+1);
+ values[0] = MM256_SET_M128I(v1, v1);
+ values[1] = MM256_SET_M128I(v2, v2);
+ }
+
+ __m256i xv[8];
+ uint32_t block[8];
+ int16_t ls[16];
+ float all_s[64];
+
+    auto mh = _mm256_set1_epi8(-128); // -128 instead of 0x80 to avoid a compiler warning about 0x80 overflowing signed char
+
+ for (int ix = 0; ix < nrc_x; ix += 8) {
+ for (int k = 0; k < 8; ++k) x8[k] = (const block_iq5_k *)((const char *)vx + (ix+k)*bx);
+ for (int i = 0; i < nb; ++i) {
+ for (int k = 0; k < 8; ++k) {
+ float d = GGML_FP16_TO_FP32(x8[k][i].d);
+ auto extra = x8[k][i].extra;
+ auto hbits = _mm256_loadu_si256((const __m256i *)x8[k][i].qh);
+ for (int ib64 = 0; ib64 < 4; ++ib64) {
+ ls[4*ib64+0] = ((x8[k][i].scales_l[2*ib64+0] & 0xf) | ((x8[k][i].scales_h[ib64] << 4) & 0x30)) - 32;
+ ls[4*ib64+1] = ((x8[k][i].scales_l[2*ib64+0] >> 4) | ((x8[k][i].scales_h[ib64] << 2) & 0x30)) - 32;
+ ls[4*ib64+2] = ((x8[k][i].scales_l[2*ib64+1] & 0xf) | ((x8[k][i].scales_h[ib64] >> 0) & 0x30)) - 32;
+ ls[4*ib64+3] = ((x8[k][i].scales_l[2*ib64+1] >> 4) | ((x8[k][i].scales_h[ib64] >> 2) & 0x30)) - 32;
+ auto bits = _mm256_loadu_si256((const __m256i *)x8[k][i].qs+ib64);
+ xv[2*ib64+0] = _mm256_and_si256(bits, _mm256_set1_epi8(0xf));
+ xv[2*ib64+1] = _mm256_and_si256(_mm256_srli_epi16(bits, 4), _mm256_set1_epi8(0xf));
+ auto qh = _mm256_and_si256(_mm256_slli_epi16(hbits, 7), mh);
+ auto q5vl = _mm256_or_si256(xv[2*ib64+0], qh);
+ auto q5vh = _mm256_or_si256(xv[2*ib64+0], _mm256_xor_si256(qh, mh));
+ xv[2*ib64+0] = _mm256_or_si256(_mm256_shuffle_epi8(values[0], q5vl), _mm256_shuffle_epi8(values[1], q5vh));
+ qh = _mm256_and_si256(_mm256_slli_epi16(hbits, 6), mh);
+ q5vl = _mm256_or_si256(xv[2*ib64+1], qh);
+ q5vh = _mm256_or_si256(xv[2*ib64+1], _mm256_xor_si256(qh, mh));
+ xv[2*ib64+1] = _mm256_or_si256(_mm256_shuffle_epi8(values[0], q5vl), _mm256_shuffle_epi8(values[1], q5vh));
+ auto shift1 = MM256_SET_M128I(_mm_set1_epi8((extra & 2) << 0), _mm_set1_epi8((extra & 1) << 1));
+ auto shift2 = MM256_SET_M128I(_mm_set1_epi8((extra & 8) >> 2), _mm_set1_epi8((extra & 4) >> 1));
+ xv[2*ib64+0] = _mm256_add_epi8(xv[2*ib64+0], shift1);
+ xv[2*ib64+1] = _mm256_add_epi8(xv[2*ib64+1], shift2);
+ hbits = _mm256_srli_epi16(hbits, 2);
+ extra >>= 4;
+ }
+ for (int ib32 = 0; ib32 < 8; ++ib32) {
+                    // We have two blocks of 16 quants with different scales. We multiply the quants
+                    // by their scales, find the maximum absolute value, and convert to 8-bit quants
+                    // with a single block scale.
+ auto q16_l = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(xv[ib32]));
+ auto q16_h = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(xv[ib32], 1));
+ q16_l = _mm256_mullo_epi16(q16_l, _mm256_set1_epi16(ls[2*ib32+0]));
+ q16_h = _mm256_mullo_epi16(q16_h, _mm256_set1_epi16(ls[2*ib32+1]));
+ auto abs_q16_l = _mm256_sign_epi16(q16_l, q16_l);
+ auto abs_q16_h = _mm256_sign_epi16(q16_h, q16_h);
+ auto max_q16 = _mm256_max_epi16(abs_q16_l, abs_q16_h);
+ auto max_q32 = _mm256_cvtepi16_epi32(_mm_max_epi16(_mm256_castsi256_si128(max_q16), _mm256_extracti128_si256(max_q16, 1)));
+ auto imax4 = _mm_max_epi32(_mm256_castsi256_si128(max_q32), _mm256_extracti128_si256(max_q32, 1));
+ auto max4 = _mm_cvtepi32_ps(imax4);
+ max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
+ max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
+ float max = _mm_cvtss_f32(max4) / 127;
+ all_s[8*ib32+k] = d*max;
+ if (max > 1e-9f) {
+ auto scale = _mm256_set1_ps(1/max);
+ auto i0 = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(q16_l));
+ auto i1 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(q16_l, 1));
+ auto i2 = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(q16_h));
+ auto i3 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(q16_h, 1));
+ i0 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(scale, _mm256_cvtepi32_ps(i0)), _MM_ROUND_NEAREST));
+ i1 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(scale, _mm256_cvtepi32_ps(i1)), _MM_ROUND_NEAREST));
+ i2 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(scale, _mm256_cvtepi32_ps(i2)), _MM_ROUND_NEAREST));
+ i3 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(scale, _mm256_cvtepi32_ps(i3)), _MM_ROUND_NEAREST));
+ i0 = _mm256_packs_epi32(i0, i1);
+ i2 = _mm256_packs_epi32(i2, i3);
+ i0 = _mm256_packs_epi16(i0, i2);
+ i0 = _mm256_permutevar8x32_epi32(i0, _mm256_setr_epi32(0, 4, 1, 5, 2, 6, 3, 7));
+ _mm256_storeu_si256((__m256i *)block, i0);
+ } else {
+ _mm256_storeu_si256((__m256i *)block, _mm256_setzero_si256());
+ }
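+                    // Scatter the 32 quants of row k into the interleaved q8_0_r8 layout: the first
+                    // 16 bytes of all 8 rows occupy the first half of qs, the last 16 bytes the second.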
+ auto qs = (uint32_t *)y[ib32].qs;
+ for (int l = 0; l < 4; ++l) {
+ qs[8*l + k + 0] = block[l + 0];
+ qs[8*l + k + 32] = block[l + 4];
+ }
+ }
+ }
+ for (int ib32 = 0; ib32 < 8; ++ib32) {
+ _mm_storeu_si128((__m128i *)y[ib32].d, _mm256_cvtps_ph(_mm256_loadu_ps(all_s + 8*ib32), _MM_FROUND_TO_NEAREST_INT));
+ }
+ y += QK_K/32;
+ }
+ }
+}
+
+void iqk_convert_iq6_k_q8_k_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) {
+ GGML_ASSERT(n%QK_K == 0);
+ GGML_ASSERT(nrc_x%8 == 0);
+
+ int nb = n/QK_K;
+
+ const block_iq6_k * x8[8];
+
+ block_q8_k_r8 * y = (block_q8_k_r8 *)vy;
+
+ __m256i values[4];
+ for (int k = 0; k < 4; ++k) {
+ auto values128 = _mm_loadu_si128((const __m128i *)iq6nl_values + k);
+ values[k] = MM256_SET_M128I(values128, values128);
+ }
+
+ __m256i xv[8];
+ uint32_t block[8];
+
+ union { __m256i vec; int16_t val[16]; } helper;
+
+ auto mh1 = _mm256_set1_epi8(1);
+ auto mh2 = _mm256_set1_epi8(2);
+ auto mh3 = _mm256_set1_epi8(3);
+
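+    // Select, per byte, one of the four 16-entry iq6nl_values tables based on the two high bits
+    // of each quant (taken from qh); the four masks are mutually exclusive, so the masked
+    // shuffles can be OR-ed.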
+ auto make_one = [&values, &mh1, &mh2, &mh3] (__m256i l, __m256i hbits) {
+ auto mask4 = _mm256_cmpeq_epi8(_mm256_and_si256(hbits, mh3), mh3);
+ auto h1 = _mm256_andnot_si256(mask4, hbits);
+ auto mask2 = _mm256_cmpeq_epi8(_mm256_and_si256(h1, mh1), mh1);
+ auto mask3 = _mm256_cmpeq_epi8(_mm256_and_si256(h1, mh2), mh2);
+ auto mask1 = _mm256_andnot_si256(_mm256_or_si256(mask4, _mm256_or_si256(mask2, mask3)), _mm256_set1_epi8(-1)); // 0xff;
+ return _mm256_or_si256(_mm256_or_si256(_mm256_and_si256(mask1, _mm256_shuffle_epi8(values[0], l)),
+ _mm256_and_si256(mask2, _mm256_shuffle_epi8(values[1], l))),
+ _mm256_or_si256(_mm256_and_si256(mask3, _mm256_shuffle_epi8(values[2], l)),
+ _mm256_and_si256(mask4, _mm256_shuffle_epi8(values[3], l))));
+ };
+
+ for (int ix = 0; ix < nrc_x; ix += 8) {
+ for (int k = 0; k < 8; ++k) x8[k] = (const block_iq6_k *)((const char *)vx + (ix+k)*bx);
+ for (int i = 0; i < nb; ++i) {
+ for (int k = 0; k < 8; ++k) {
+ float d = GGML_FP16_TO_FP32(x8[k][i].d);
+ helper.vec = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i*)x8[k][i].scales));
+ auto extra = x8[k][i].extra;
+ for (int i128 = 0; i128 < 2; ++i128) {
+ auto hbits = _mm256_loadu_si256((const __m256i *)x8[k][i].qh+i128);
+ auto bits = _mm256_loadu_si256((const __m256i *)x8[k][i].qs+2*i128+0);
+ xv[4*i128+0] = _mm256_and_si256(bits, _mm256_set1_epi8(0xf));
+ xv[4*i128+1] = _mm256_and_si256(_mm256_srli_epi16(bits, 4), _mm256_set1_epi8(0xf));
+ bits = _mm256_loadu_si256((const __m256i *)x8[k][i].qs+2*i128+1);
+ xv[4*i128+2] = _mm256_and_si256(bits, _mm256_set1_epi8(0xf));
+ xv[4*i128+3] = _mm256_and_si256(_mm256_srli_epi16(bits, 4), _mm256_set1_epi8(0xf));
+ for (int k = 0; k < 4; ++k) {
+ xv[4*i128+k] = make_one(xv[4*i128+k], hbits);
+ hbits = _mm256_srli_epi16(hbits, 2);
+ }
+ auto shift1 = MM256_SET_M128I(_mm_set1_epi8((extra >> 1) & 1), _mm_set1_epi8((extra >> 0) & 1));
+ auto shift2 = MM256_SET_M128I(_mm_set1_epi8((extra >> 3) & 1), _mm_set1_epi8((extra >> 2) & 1));
+ auto shift3 = MM256_SET_M128I(_mm_set1_epi8((extra >> 5) & 1), _mm_set1_epi8((extra >> 4) & 1));
+ auto shift4 = MM256_SET_M128I(_mm_set1_epi8((extra >> 7) & 1), _mm_set1_epi8((extra >> 6) & 1));
+ xv[4*i128+0] = _mm256_add_epi8(xv[4*i128+0], shift1);
+ xv[4*i128+1] = _mm256_add_epi8(xv[4*i128+1], shift2);
+ xv[4*i128+2] = _mm256_add_epi8(xv[4*i128+2], shift3);
+ xv[4*i128+3] = _mm256_add_epi8(xv[4*i128+3], shift4);
+ extra >>= 8;
+ }
+ float dnew = convert_to_q8_k_r8(k, 1.f/127, xv, helper.val, block, y[i].qs);
+ y[i].d[k] = GGML_FP32_TO_FP16(d*dnew);
+ }
+ }
+ y += nb;
+ }
+}
+
} // namespace
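+// Conversion entry point used by iqk_convert_repack() in iqk_mul_mat.cpp; returns false for
+// types that are not handled here.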
+bool iqk_convert_iqk_quants_q80_r8(int type, int n, const void * vx, size_t bx, void * vy, int nrc_x) {
+ if (n%QK_K != 0 || nrc_x%8 != 0) return false;
+ switch (ggml_type(type)) {
+ case GGML_TYPE_IQ2_KS : iqk_convert_iq2_ks_q8_k_r8(n, vx, bx, vy, nrc_x); break;
+ case GGML_TYPE_IQ2_K : iqk_convert_iq2_k_q8_k_r8 (n, vx, bx, vy, nrc_x); break;
+ case GGML_TYPE_IQ3_K : iqk_convert_iq3_k_q8_k_r8 (n, vx, bx, vy, nrc_x); break;
+ case GGML_TYPE_IQ4_KS : iqk_convert_iq4_ks_q8_k_r8(n, vx, bx, vy, nrc_x); break;
+ case GGML_TYPE_IQ4_K : iqk_convert_iq4_k_q8_k_r8 (n, vx, bx, vy, nrc_x); break;
+ case GGML_TYPE_IQ5_KS : iqk_convert_iq5_ks_q8_k_r8(n, vx, bx, vy, nrc_x); break;
+ case GGML_TYPE_IQ5_K : iqk_convert_iq5_k_q8_k_r8 (n, vx, bx, vy, nrc_x); break;
+ case GGML_TYPE_IQ6_K : iqk_convert_iq6_k_q8_k_r8 (n, vx, bx, vy, nrc_x); break;
+ default: return false;
+ }
+ return true;
+}
+
bool iqk_set_kernels_iqk_quants(int ne00, int typeA, int typeB, std::array<mul_mat_t, IQK_MAX_NY>& kernels, mul_mat_t& func16) {
auto etypeA = ggml_type(typeA);
diff --git a/ggml/src/iqk/iqk_gemm_iqk_quants.h b/ggml/src/iqk/iqk_gemm_iqk_quants.h
index cd076ff7..41beca63 100644
--- a/ggml/src/iqk/iqk_gemm_iqk_quants.h
+++ b/ggml/src/iqk/iqk_gemm_iqk_quants.h
@@ -8,4 +8,6 @@
bool iqk_set_kernels_iqk_quants(int ne00, int typeA, int typeB, std::array<mul_mat_t, IQK_MAX_NY>& kernels, mul_mat_t& func16);
+bool iqk_convert_iqk_quants_q80_r8(int type, int n, const void * vx, size_t bx, void * vy, int nrc_x);
+
#endif
diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp
index 0b29a572..81b5841d 100644
--- a/ggml/src/iqk/iqk_mul_mat.cpp
+++ b/ggml/src/iqk/iqk_mul_mat.cpp
@@ -250,6 +250,14 @@ struct MulMat {
case GGML_TYPE_Q4_K : return nrc_y >= 32 ? GGML_TYPE_Q8_1 : type;
case GGML_TYPE_Q5_K : return nrc_y >= 32 ? GGML_TYPE_Q8_1 : type;
case GGML_TYPE_Q6_K : return nrc_y >= 64 ? GGML_TYPE_Q8_0_R8 : type;
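+            // For nrc_y >= 32 the IQK quants below are converted to Q8_K_R8 so that the repacked
+            // GEMM path can be used.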
+ case GGML_TYPE_IQ2_KS : return nrc_y >= 32 ? GGML_TYPE_Q8_K_R8 : type;
+ case GGML_TYPE_IQ2_K : return nrc_y >= 32 ? GGML_TYPE_Q8_K_R8 : type;
+ case GGML_TYPE_IQ3_K : return nrc_y >= 32 ? GGML_TYPE_Q8_K_R8 : type;
+ case GGML_TYPE_IQ4_KS : return nrc_y >= 32 ? GGML_TYPE_Q8_K_R8 : type;
+ case GGML_TYPE_IQ4_K : return nrc_y >= 32 ? GGML_TYPE_Q8_K_R8 : type;
+ case GGML_TYPE_IQ5_KS : return nrc_y >= 32 ? GGML_TYPE_Q8_K_R8 : type;
+ case GGML_TYPE_IQ5_K : return nrc_y >= 32 ? GGML_TYPE_Q8_K_R8 : type;
+ case GGML_TYPE_IQ6_K : return nrc_y >= 32 ? GGML_TYPE_Q8_K_R8 : type;
default: break;
}
#else
@@ -375,22 +383,22 @@ bool iqk_convert_repack(int typeA, int n, const void * vx, size_t bx, void * vy,
case GGML_TYPE_IQ3_XXS_R4:
case GGML_TYPE_IQ3_S_R4:
return iqk_convert_iquants_q80_r8(typeA, n, vx, bx, vy, nrc_x);
- //case GGML_TYPE_IQ4_KS:
- //case GGML_TYPE_IQ5_KS:
- //case GGML_TYPE_IQ4_KSS:
- //case GGML_TYPE_IQ2_K:
- //case GGML_TYPE_IQ2_KS:
- //case GGML_TYPE_IQ3_K:
- //case GGML_TYPE_IQ4_K:
- //case GGML_TYPE_IQ5_K:
- //case GGML_TYPE_IQ6_K:
+ case GGML_TYPE_IQ2_KS:
+ case GGML_TYPE_IQ2_K:
+ case GGML_TYPE_IQ3_K:
+ case GGML_TYPE_IQ4_KSS:
+ case GGML_TYPE_IQ4_KS:
+ case GGML_TYPE_IQ4_K:
+ case GGML_TYPE_IQ5_KS:
+ case GGML_TYPE_IQ5_K:
+ case GGML_TYPE_IQ6_K:
//case GGML_TYPE_IQ2_K_R4:
//case GGML_TYPE_IQ3_K_R4:
//case GGML_TYPE_IQ4_K_R4:
//case GGML_TYPE_IQ5_K_R4:
//case GGML_TYPE_IQ4_KS_R4:
//case GGML_TYPE_IQ5_KS_R4:
- // return iqk_set_kernels_iqk_quants(ne00, typeA, typeB, mm.funcs, mm.func16);
+ return iqk_convert_iqk_quants_q80_r8(typeA, n, vx, bx, vy, nrc_x);
case GGML_TYPE_IQ2_KT:
case GGML_TYPE_IQ3_KT:
case GGML_TYPE_IQ4_KT: