diff options
author | Kawrakow <iwankawrakow@gmail.com> | 2025-06-18 07:29:33 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2025-06-18 07:29:33 +0300 |
commit | dc96820ddb45c639ea4e149e4bbfcb0b67fbcc2b (patch) | |
tree | 2ac3011164d541f5899db1afdad375cc59bfc142 | |
parent | 8b3002bba2ea64b1de9ca2ff87207d8c37b0f08e (diff) |
Much faster CPU prompt processing (part 2) (#533)
* iq4_ks
203 t/s -> 357 t/s. iq4_ks_r4 is 242 t/s.
* iq4_k
175 t/s -> 353 t/s. iq4_k_r4 is 208 t/s.
PPL is actually lower!
* iq5_ks
180 t/s -> 359 t/s. iq5_ks_r4 is 210 t/s.
PPL is actually lower - 7.4160 vs 7.4494 for LlaMA-3.1-8B-Instruct
* iq5_k - accuracy loss is too big
* iq5_k - there was a bug with the shifts
...and that's why PPL was so high. It is also high on main.
This fixes it.
* iq6_k
148 t/s -> 350 t/s. There is no iq6_k_r4
PPL is actually lower because we have a bug in the existing
implementation!
* iq3_k
169 t/s -> 363 t/s. iq3_k_r4 is at 200 t/s.
* iq2_k
190 t/s -> 364 t/s. iq2_k_r4 is at 232 t/s.
* iq2_ks
200 t/s -> 367 t/s. There is no iq2_ks_r4.
---------
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
-rw-r--r-- | ggml/src/ggml.c | 4 | ||||
-rw-r--r-- | ggml/src/iqk/iqk_gemm_iqk_quants.cpp | 686 | ||||
-rw-r--r-- | ggml/src/iqk/iqk_gemm_iqk_quants.h | 2 | ||||
-rw-r--r-- | ggml/src/iqk/iqk_mul_mat.cpp | 28 |
4 files changed, 710 insertions, 10 deletions
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index a6260136..69b1b46d 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -1699,7 +1699,11 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .from_float = quantize_row_iq5_k, .from_float_ref = (ggml_from_float_t)quantize_row_iq5_k_ref, .vec_dot = vec_dot_iq5_k_q8_k, +//#ifdef __AVX2__ +// .vec_dot_type = GGML_TYPE_Q8_2_X4, +//#else .vec_dot_type = GGML_TYPE_Q8_K, +//#endif .nrows = 1, .row_meta_size = 0, }, diff --git a/ggml/src/iqk/iqk_gemm_iqk_quants.cpp b/ggml/src/iqk/iqk_gemm_iqk_quants.cpp index 15c963ca..a01d7e4c 100644 --- a/ggml/src/iqk/iqk_gemm_iqk_quants.cpp +++ b/ggml/src/iqk/iqk_gemm_iqk_quants.cpp @@ -2053,8 +2053,694 @@ template <typename Dequantizer> void set_functions(std::array<mul_mat_t, IQK_MAX #endif } +inline float convert_to_q8_k_r8(int k, float d0, const __m256i * qx, const int16_t * scales, uint32_t * block, int8_t * q8_k) { + auto max_i16 = _mm256_setzero_si256(); + __m256i qs[16]; + for (int ib32 = 0; ib32 < 8; ++ib32) { + qs[2*ib32+0] = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(qx[ib32])); + qs[2*ib32+1] = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(qx[ib32], 1)); + qs[2*ib32+0] = _mm256_mullo_epi16(qs[2*ib32+0], _mm256_set1_epi16(scales[2*ib32+0])); + qs[2*ib32+1] = _mm256_mullo_epi16(qs[2*ib32+1], _mm256_set1_epi16(scales[2*ib32+1])); + max_i16 = _mm256_max_epi16(max_i16, _mm256_sign_epi16(qs[2*ib32+0], qs[2*ib32+0])); + max_i16 = _mm256_max_epi16(max_i16, _mm256_sign_epi16(qs[2*ib32+1], qs[2*ib32+1])); + } + auto max_q32 = _mm256_cvtepi16_epi32(_mm_max_epi16(_mm256_castsi256_si128(max_i16), _mm256_extracti128_si256(max_i16, 1))); + auto imax4 = _mm_max_epi32(_mm256_castsi256_si128(max_q32), _mm256_extracti128_si256(max_q32, 1)); + auto max4 = _mm_cvtepi32_ps(imax4); + max4 = _mm_max_ps(max4, _mm_movehl_ps(max4, max4)); + max4 = _mm_max_ss(max4, _mm_movehdup_ps(max4)); + bool needs_scaling = true; + float dnew = _mm_cvtss_f32(max4) * d0; + if (dnew < 1.f) { + dnew = 1.f; needs_scaling = false; + } + auto scale = _mm256_set1_ps(std::abs(dnew) > 1e-9f ? 1/dnew : 0.f); + for (int ib32 = 0; ib32 < 8; ++ib32) { + if (needs_scaling) { + auto i0 = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(qs[2*ib32+0])); + auto i1 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(qs[2*ib32+0], 1)); + auto i2 = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(qs[2*ib32+1])); + auto i3 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(qs[2*ib32+1], 1)); + i0 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(scale, _mm256_cvtepi32_ps(i0)), _MM_ROUND_NEAREST)); + i1 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(scale, _mm256_cvtepi32_ps(i1)), _MM_ROUND_NEAREST)); + i2 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(scale, _mm256_cvtepi32_ps(i2)), _MM_ROUND_NEAREST)); + i3 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(scale, _mm256_cvtepi32_ps(i3)), _MM_ROUND_NEAREST)); + i0 = _mm256_packs_epi32(i0, i1); + i2 = _mm256_packs_epi32(i2, i3); + i0 = _mm256_packs_epi16(i0, i2); + i0 = _mm256_permutevar8x32_epi32(i0, _mm256_setr_epi32(0, 4, 1, 5, 2, 6, 3, 7)); + _mm256_storeu_si256((__m256i *)block, i0); + } else { + // 0, 1, 2, 3, 4, 5, 6, 7, 8, 16, 17, 18, 19, 20, 21, 22, 23, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31 + auto i0 = _mm256_packs_epi16(qs[2*ib32+0], qs[2*ib32+1]); + auto i0_l = _mm256_castsi256_si128(i0); + auto i0_h = _mm256_extracti128_si256(i0, 1); + _mm_storeu_si128((__m128i *)block+0, _mm_unpacklo_epi64(i0_l, i0_h)); + _mm_storeu_si128((__m128i *)block+1, _mm_unpackhi_epi64(i0_l, i0_h)); + } + auto qs = (uint32_t *)q8_k + 64*ib32; + for (int l = 0; l < 8; ++l) { + qs[8*l + k] = block[l]; + } + } + return dnew; +} + +void iqk_convert_iq2_ks_q8_k_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) { + GGML_ASSERT(n%QK_K == 0); + GGML_ASSERT(nrc_x%8 == 0); + + int nb = n/QK_K; + + const block_iq2_ks * x8[8]; + + block_q8_k_r8 * y = (block_q8_k_r8 *)vy; + + __m256i values; + { + auto v = _mm_loadl_epi64((const __m128i *)iq2nl_values); + values = MM256_SET_M128I(v, v); + } + + ggml_half dh[8]; + float dnew[8]; + uint32_t block[8]; + int16_t ls[16]; + + __m256i xv[8]; + + auto ml = _mm256_set1_epi8(0x03); + + for (int ix = 0; ix < nrc_x; ix += 8) { + for (int k = 0; k < 8; ++k) { + const ggml_half * dptr = (const ggml_half *)((const char *)vx + (ix+k)*bx); + dh[k] = dptr[0]; + x8[k] = (const block_iq2_ks *)(dptr + 1); + } + for (int i = 0; i < nb; ++i) { + for (int k = 0; k < 8; ++k) { + auto extra = x8[k][i].extra; + for (int i128 = 0; i128 < 2; ++i128) { + ls[8*i128+0] = ls[8*i128+1] = ((x8[k][i].scales[2*i128+0] & 0xf) | ((extra >> 4) & 0x10)) - 16; + ls[8*i128+2] = ls[8*i128+3] = ((x8[k][i].scales[2*i128+0] >> 4) | ((extra >> 5) & 0x10)) - 16; + ls[8*i128+4] = ls[8*i128+5] = ((x8[k][i].scales[2*i128+1] & 0xf) | ((extra >> 6) & 0x10)) - 16; + ls[8*i128+6] = ls[8*i128+7] = ((x8[k][i].scales[2*i128+1] >> 4) | ((extra >> 7) & 0x10)) - 16; + auto bits = _mm256_loadu_si256((const __m256i *)x8[k][i].qs+i128); + xv[4*i128+0] = _mm256_and_si256(bits, ml); + xv[4*i128+1] = _mm256_and_si256(_mm256_srli_epi16(bits, 2), ml); + xv[4*i128+2] = _mm256_and_si256(_mm256_srli_epi16(bits, 4), ml); + xv[4*i128+3] = _mm256_and_si256(_mm256_srli_epi16(bits, 6), ml); + xv[4*i128+0] = _mm256_add_epi8(xv[4*i128+0], _mm256_set1_epi8((extra << 2) & 0x04)); + xv[4*i128+1] = _mm256_add_epi8(xv[4*i128+1], _mm256_set1_epi8((extra << 1) & 0x04)); + xv[4*i128+2] = _mm256_add_epi8(xv[4*i128+2], _mm256_set1_epi8((extra >> 0) & 0x04)); + xv[4*i128+3] = _mm256_add_epi8(xv[4*i128+3], _mm256_set1_epi8((extra >> 1) & 0x04)); + xv[4*i128+0] = _mm256_shuffle_epi8(values, xv[4*i128+0]); + xv[4*i128+1] = _mm256_shuffle_epi8(values, xv[4*i128+1]); + xv[4*i128+2] = _mm256_shuffle_epi8(values, xv[4*i128+2]); + xv[4*i128+3] = _mm256_shuffle_epi8(values, xv[4*i128+3]); + extra >>= 4; + } + dnew[k] = convert_to_q8_k_r8(k, 1.f/125, xv, ls, block, y[i].qs); + } + auto vd = _mm256_mul_ps(_mm256_loadu_ps(dnew), _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)dh))); + _mm_storeu_si128((__m128i *)y[i].d, _mm256_cvtps_ph(vd, _MM_ROUND_NEAREST)); + } + y += nb; + } +} + +void iqk_convert_iq2_k_q8_k_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) { + GGML_ASSERT(n%QK_K == 0); + GGML_ASSERT(nrc_x%8 == 0); + + int nb = n/QK_K; + + const block_iq2_k * x8[8]; + + block_q8_k_r8 * y = (block_q8_k_r8 *)vy; + + __m256i values; + { + auto v = _mm_loadl_epi64((const __m128i *)iq2nl_values); + values = MM256_SET_M128I(v, v); + } + + __m256i xv[8]; + uint32_t block[8]; + + const __m128i scale_shuffle = _mm_set_epi32(0x0f070e06, 0x0d050c04, 0x0b030a02, 0x09010800); + + union { __m256i vec; int16_t val[16]; } helper; + + auto ml = _mm256_set1_epi8(0x03); + + for (int ix = 0; ix < nrc_x; ix += 8) { + for (int k = 0; k < 8; ++k) x8[k] = (const block_iq2_k *)((const char *)vx + (ix+k)*bx); + for (int i = 0; i < nb; ++i) { + for (int k = 0; k < 8; ++k) { + float d = GGML_FP16_TO_FP32(x8[k][i].d); + uint64_t aux64; std::memcpy(&aux64, x8[k][i].scales, 8); + auto scl = _mm_and_si128(_mm_set_epi64x(aux64 >> 4, aux64), _mm_set1_epi8(0xf)); + scl = _mm_add_epi8(scl, _mm_set1_epi8(-8)); + helper.vec = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scl, scale_shuffle)); + auto extra = x8[k][i].extra; + for (int i128 = 0; i128 < 2; ++i128) { + auto bits = _mm256_loadu_si256((const __m256i *)x8[k][i].qs+i128); + xv[4*i128+0] = _mm256_and_si256(bits, ml); + xv[4*i128+1] = _mm256_and_si256(_mm256_srli_epi16(bits, 2), ml); + xv[4*i128+2] = _mm256_and_si256(_mm256_srli_epi16(bits, 4), ml); + xv[4*i128+3] = _mm256_and_si256(_mm256_srli_epi16(bits, 6), ml); + auto shift1 = MM256_SET_M128I(_mm_set1_epi8((extra & 0x02) << 1), _mm_set1_epi8((extra & 0x01) << 2)); + auto shift2 = MM256_SET_M128I(_mm_set1_epi8((extra & 0x08) >> 1), _mm_set1_epi8((extra & 0x04) >> 0)); + auto shift3 = MM256_SET_M128I(_mm_set1_epi8((extra & 0x20) >> 3), _mm_set1_epi8((extra & 0x10) >> 2)); + auto shift4 = MM256_SET_M128I(_mm_set1_epi8((extra & 0x80) >> 5), _mm_set1_epi8((extra & 0x40) >> 4)); + xv[4*i128+0] = _mm256_add_epi8(xv[4*i128+0], shift1); + xv[4*i128+1] = _mm256_add_epi8(xv[4*i128+1], shift2); + xv[4*i128+2] = _mm256_add_epi8(xv[4*i128+2], shift3); + xv[4*i128+3] = _mm256_add_epi8(xv[4*i128+3], shift4); + xv[4*i128+0] = _mm256_shuffle_epi8(values, xv[4*i128+0]); + xv[4*i128+1] = _mm256_shuffle_epi8(values, xv[4*i128+1]); + xv[4*i128+2] = _mm256_shuffle_epi8(values, xv[4*i128+2]); + xv[4*i128+3] = _mm256_shuffle_epi8(values, xv[4*i128+3]); + extra >>= 8; + } + float dnew = convert_to_q8_k_r8(k, 1.f/120, xv, helper.val, block, y[i].qs); + y[i].d[k] = GGML_FP32_TO_FP16(d*dnew); + } + } + y += nb; + } +} + +void iqk_convert_iq3_k_q8_k_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) { + GGML_ASSERT(n%QK_K == 0); + GGML_ASSERT(nrc_x%8 == 0); + + int nb = n/QK_K; + + const block_iq3_k * x8[8]; + + block_q8_k_r8 * y = (block_q8_k_r8 *)vy; + + __m256i values; + { + auto v = _mm_loadu_si128((const __m128i *)iq3nl_values); + values = MM256_SET_M128I(v, v); + } + + __m256i xv[8]; + uint32_t block[8]; + + constexpr static uint8_t k_shuff[16] = {0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15}; + const __m128i sign_mask = _mm_set_epi64x(0x8080404020201010, 0x0808040402020101); + const __m128i hshuff = _mm_loadu_si128((const __m128i*)k_shuff); + const __m128i scale_shuffle = _mm_set_epi32(0x0f070e06, 0x0d050c04, 0x0b030a02, 0x09010800); + + union { __m256i vec; int16_t val[16]; } helper; + + auto ml = _mm256_set1_epi8(0x03); + auto hmask = _mm256_set1_epi8(4); + + for (int ix = 0; ix < nrc_x; ix += 8) { + for (int k = 0; k < 8; ++k) x8[k] = (const block_iq3_k *)((const char *)vx + (ix+k)*bx); + for (int i = 0; i < nb; ++i) { + for (int k = 0; k < 8; ++k) { + float d = GGML_FP16_TO_FP32(x8[k][i].d); + uint64_t aux64; std::memcpy(&aux64, x8[k][i].scales_l, 8); + auto scl = _mm_and_si128(_mm_set_epi64x(aux64 >> 4, aux64), _mm_set1_epi8(0xf)); + scl = _mm_add_epi8(_mm_slli_epi16(scl, 1), _mm_set1_epi8(1)); + auto sc_signs = _mm_cmpeq_epi8(_mm_and_si128(_mm_set1_epi16(x8[k][i].scales_h), sign_mask), sign_mask); + auto sch = _mm_shuffle_epi8(_mm_or_si128(sc_signs, _mm_set1_epi8(1)), hshuff); + helper.vec = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(_mm_sign_epi8(scl, sch), scale_shuffle)); + auto extra = x8[k][i].extra; + auto hbits = _mm256_loadu_si256((const __m256i *)x8[k][i].qh); + for (int i128 = 0; i128 < 2; ++i128) { + auto bits = _mm256_loadu_si256((const __m256i *)x8[k][i].qs+i128); + xv[4*i128+0] = _mm256_and_si256(bits, ml); + xv[4*i128+1] = _mm256_and_si256(_mm256_srli_epi16(bits, 2), ml); + xv[4*i128+2] = _mm256_and_si256(_mm256_srli_epi16(bits, 4), ml); + xv[4*i128+3] = _mm256_and_si256(_mm256_srli_epi16(bits, 6), ml); + xv[4*i128+0] = _mm256_or_si256(xv[4*i128+0], _mm256_and_si256(_mm256_slli_epi16(hbits, 2), hmask)); + xv[4*i128+1] = _mm256_or_si256(xv[4*i128+1], _mm256_and_si256(_mm256_slli_epi16(hbits, 1), hmask)); + xv[4*i128+2] = _mm256_or_si256(xv[4*i128+2], _mm256_and_si256(hbits, hmask)); + xv[4*i128+3] = _mm256_or_si256(xv[4*i128+3], _mm256_and_si256(_mm256_srli_epi16(hbits, 1), hmask)); + auto shift1 = MM256_SET_M128I(_mm_set1_epi8((extra & 0x02) << 2), _mm_set1_epi8((extra & 0x01) << 3)); + auto shift2 = MM256_SET_M128I(_mm_set1_epi8((extra & 0x08) << 0), _mm_set1_epi8((extra & 0x04) << 1)); + auto shift3 = MM256_SET_M128I(_mm_set1_epi8((extra & 0x20) >> 2), _mm_set1_epi8((extra & 0x10) >> 1)); + auto shift4 = MM256_SET_M128I(_mm_set1_epi8((extra & 0x80) >> 4), _mm_set1_epi8((extra & 0x40) >> 3)); + xv[4*i128+0] = _mm256_add_epi8(xv[4*i128+0], shift1); + xv[4*i128+1] = _mm256_add_epi8(xv[4*i128+1], shift2); + xv[4*i128+2] = _mm256_add_epi8(xv[4*i128+2], shift3); + xv[4*i128+3] = _mm256_add_epi8(xv[4*i128+3], shift4); + xv[4*i128+0] = _mm256_shuffle_epi8(values, xv[4*i128+0]); + xv[4*i128+1] = _mm256_shuffle_epi8(values, xv[4*i128+1]); + xv[4*i128+2] = _mm256_shuffle_epi8(values, xv[4*i128+2]); + xv[4*i128+3] = _mm256_shuffle_epi8(values, xv[4*i128+3]); + hbits = _mm256_srli_epi16(hbits, 4); + extra >>= 8; + } + float dnew = convert_to_q8_k_r8(k, 1.f/127, xv, helper.val, block, y[i].qs); + y[i].d[k] = GGML_FP32_TO_FP16(d*dnew); + } + } + y += nb; + } +} + +void iqk_convert_iq4_ks_q8_k_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) { + GGML_ASSERT(n%QK_K == 0); + GGML_ASSERT(nrc_x%8 == 0); + + int nb = n/QK_K; + + const block_iq4_ks * x8[8]; + + block_q8_k_r8 * y = (block_q8_k_r8 *)vy; + + __m256i values[2]; + { + auto v1 = _mm_loadu_si128((const __m128i *)iq4k_values+0); + auto v2 = _mm_loadu_si128((const __m128i *)iq4k_values+1); + values[0] = MM256_SET_M128I(v1, v1); + values[1] = MM256_SET_M128I(v2, v2); + } + + float drow[8]; + float dnew[8]; + int16_t ls[16]; + + __m256i xv[8]; + uint32_t block[8]; + + for (int ix = 0; ix < nrc_x; ix += 8) { + for (int k = 0; k < 8; ++k) { + const float * dptr = (const float *)((const char *)vx + (ix + k)*bx); + drow[k] = dptr[0]; + x8[k] = (const block_iq4_ks *)(dptr + 1); + } + auto vd = _mm256_loadu_ps(drow); + for (int i = 0; i < nb; ++i) { + for (int k = 0; k < 8; ++k) { + for (int ib32 = 0; ib32 < 8; ++ib32) { + ls[2*ib32+0] = (x8[k][i].scales[ib32] & 254) - 127; + ls[2*ib32+1] = ls[2*ib32+0]; + auto aux128 = _mm_loadu_si128((const __m128i *)x8[k][i].qs+ib32); + xv[ib32] = _mm256_and_si256(MM256_SET_M128I(_mm_srli_epi16(aux128, 4), aux128), _mm256_set1_epi8(0xf)); + xv[ib32] = _mm256_shuffle_epi8(values[x8[k][i].scales[ib32] & 1], xv[ib32]); + } + dnew[k] = convert_to_q8_k_r8(k, 1.f/127, xv, ls, block, y[i].qs); + } + _mm_storeu_si128((__m128i *)y[i].d, _mm256_cvtps_ph(_mm256_mul_ps(vd, _mm256_loadu_ps(dnew)), _MM_ROUND_NEAREST)); + } + y += nb; + } +} + +void iqk_convert_iq4_k_q8_k_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) { + GGML_ASSERT(n%QK_K == 0); + GGML_ASSERT(nrc_x%8 == 0); + + int nb = n/QK_K; + + const block_iq4_k * x8[8]; + + block_q8_k_r8 * y = (block_q8_k_r8 *)vy; + + __m256i values[4]; + { + auto v1 = _mm_loadu_si128((const __m128i *)iq4k_values+0); + auto v2 = _mm_loadu_si128((const __m128i *)iq4k_values+1); + values[0] = MM256_SET_M128I(v1, v1); + values[1] = MM256_SET_M128I(v1, v2); + values[2] = MM256_SET_M128I(v2, v1); + values[3] = MM256_SET_M128I(v2, v2); + } + + __m256i xv[8]; + uint32_t block[8]; + int16_t ls[16]; + + //auto hshuff = _mm_set_epi32(0x0f070e06, 0x0d050c04, 0x0b030a02, 0x09010800); + + //union { __m256i vec; int16_t val[16]; } helper; + + for (int ix = 0; ix < nrc_x; ix += 8) { + for (int k = 0; k < 8; ++k) x8[k] = (const block_iq4_k *)((const char *)vx + (ix+k)*bx); + for (int i = 0; i < nb; ++i) { + for (int k = 0; k < 8; ++k) { + float d = GGML_FP16_TO_FP32(x8[k][i].d); + auto extra = x8[k][i].extra; + //uint64_t aux64; + //memcpy(&aux64, x8[k][i].scales_l, 8); + //auto scl = _mm_and_si128(_mm_set_epi64x(aux64 >> 4, aux64), _mm_set1_epi8(0xf)); + //const uint32_t aux32 = *(const uint32_t *)x8[k][i].scales_h; + //auto aux = _mm_and_si128(_mm_set_epi32(aux32 >> 2, aux32, aux32 << 2, aux32 << 4), _mm_set1_epi8(0x30)); + //auto sch = _mm_shuffle_epi8(aux, hshuff); + //aux = _mm_add_epi8(_mm_or_si128(scl, sch), _mm_set1_epi8(-32)); + //helper.vec = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(aux, hshuff)); + for (int ib32 = 0; ib32 < 8; ++ib32) { + const uint8_t sh = x8[k][i].scales_h[ib32/2] >> 4*(ib32%2); + ls[2*ib32+0] = ((x8[k][i].scales_l[ib32] & 0xf) | ((sh << 4) & 0x30)) - 32; + ls[2*ib32+1] = ((x8[k][i].scales_l[ib32] >> 4) | ((sh << 2) & 0x30)) - 32; + auto bits = _mm_loadu_si128((const __m128i *)x8[k][i].qs+ib32); + xv[ib32] = _mm256_and_si256(MM256_SET_M128I(_mm_srli_epi16(bits, 4), bits), _mm256_set1_epi8(0xf)); + xv[ib32] = _mm256_shuffle_epi8(values[extra & 3], xv[ib32]); extra >>= 2; + } + //float dnew = convert_to_q8_k_r8(k, 1.f/127, xv, helper.val, block, y[i].qs); + float dnew = convert_to_q8_k_r8(k, 1.f/127, xv, ls, block, y[i].qs); + y[i].d[k] = GGML_FP32_TO_FP16(d*dnew); + } + } + y += nb; + } +} + +void iqk_convert_iq5_ks_q8_k_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) { + GGML_ASSERT(n%QK_K == 0); + GGML_ASSERT(nrc_x%8 == 0); + + int nb = n/QK_K; + + const block_iq5_ks * x8[8]; + + block_q8_k_r8 * y = (block_q8_k_r8 *)vy; + + __m256i values[2]; + { + auto v1 = _mm_loadu_si128((const __m128i *)iq5nl_values+0); + auto v2 = _mm_loadu_si128((const __m128i *)iq5nl_values+1); + values[0] = MM256_SET_M128I(v1, v1); + values[1] = MM256_SET_M128I(v2, v2); + } + + float drow[8]; + float dnew[8]; + int16_t ls[16]; + + __m256i xv[8]; + uint32_t block[8]; + + auto mh = _mm256_set1_epi8(-128); // to avoid stupid warning about 0x80 overflowing + + for (int ix = 0; ix < nrc_x; ix += 8) { + for (int k = 0; k < 8; ++k) { + const float * dptr = (const float *)((const char *)vx + (ix + k)*bx); + drow[k] = dptr[0]; + x8[k] = (const block_iq5_ks *)(dptr + 1); + } + auto vd = _mm256_loadu_ps(drow); + for (int i = 0; i < nb; ++i) { + for (int k = 0; k < 8; ++k) { + auto hbits = _mm256_loadu_si256((const __m256i *)x8[k][i].qh); + for (int ib64 = 0; ib64 < 4; ++ib64) { + ls[4*ib64+0] = (x8[k][i].scales[2*ib64+0] & 254) - 127; + ls[4*ib64+1] = ls[4*ib64+0]; + ls[4*ib64+2] = (x8[k][i].scales[2*ib64+1] & 254) - 127; + ls[4*ib64+3] = ls[4*ib64+2]; + auto bits = _mm256_loadu_si256((const __m256i *)x8[k][i].qs+ib64); + xv[2*ib64+0] = _mm256_and_si256(bits, _mm256_set1_epi8(0xf)); + xv[2*ib64+1] = _mm256_and_si256(_mm256_srli_epi16(bits, 4), _mm256_set1_epi8(0xf)); + auto qh = _mm256_and_si256(_mm256_slli_epi16(hbits, 7), mh); + auto q5vl = _mm256_or_si256(xv[2*ib64+0], qh); + auto q5vh = _mm256_or_si256(xv[2*ib64+0], _mm256_xor_si256(qh, mh)); + xv[2*ib64+0] = _mm256_or_si256(_mm256_shuffle_epi8(values[0], q5vl), _mm256_shuffle_epi8(values[1], q5vh)); + qh = _mm256_and_si256(_mm256_slli_epi16(hbits, 6), mh); + q5vl = _mm256_or_si256(xv[2*ib64+1], qh); + q5vh = _mm256_or_si256(xv[2*ib64+1], _mm256_xor_si256(qh, mh)); + xv[2*ib64+1] = _mm256_or_si256(_mm256_shuffle_epi8(values[0], q5vl), _mm256_shuffle_epi8(values[1], q5vh)); + auto shift1 = _mm256_set1_epi8((x8[k][i].scales[2*ib64+0] & 1) << 1); + auto shift2 = _mm256_set1_epi8((x8[k][i].scales[2*ib64+1] & 1) << 1); + xv[2*ib64+0] = _mm256_add_epi8(xv[2*ib64+0], shift1); + xv[2*ib64+1] = _mm256_add_epi8(xv[2*ib64+1], shift2); + hbits = _mm256_srli_epi16(hbits, 2); + } + dnew[k] = convert_to_q8_k_r8(k, 1.f/127, xv, ls, block, y[i].qs); + } + _mm_storeu_si128((__m128i *)y[i].d, _mm256_cvtps_ph(_mm256_mul_ps(vd, _mm256_loadu_ps(dnew)), _MM_ROUND_NEAREST)); + } + y += nb; + } +} + +void iqk_convert_iq5_k_q8_k_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) { + GGML_ASSERT(n%QK_K == 0); + GGML_ASSERT(nrc_x%8 == 0); + + int nb = n/QK_K; + + const block_iq5_k * x8[8]; + + block_q8_k_r8 * y = (block_q8_k_r8 *)vy; + + __m256i values[2]; + { + auto v1 = _mm_loadu_si128((const __m128i *)iq5nl_values+0); + auto v2 = _mm_loadu_si128((const __m128i *)iq5nl_values+1); + values[0] = MM256_SET_M128I(v1, v1); + values[1] = MM256_SET_M128I(v2, v2); + } + + __m256i xv[8]; + uint32_t block[8]; + int16_t ls[16]; + + auto mh = _mm256_set1_epi8(-128); // to avoid stupid warning about 0x80 overflowing + + for (int ix = 0; ix < nrc_x; ix += 8) { + for (int k = 0; k < 8; ++k) x8[k] = (const block_iq5_k *)((const char *)vx + (ix+k)*bx); + for (int i = 0; i < nb; ++i) { + for (int k = 0; k < 8; ++k) { + float d = GGML_FP16_TO_FP32(x8[k][i].d); + auto extra = x8[k][i].extra; + auto hbits = _mm256_loadu_si256((const __m256i *)x8[k][i].qh); + for (int ib64 = 0; ib64 < 4; ++ib64) { + ls[4*ib64+0] = ((x8[k][i].scales_l[2*ib64+0] & 0xf) | ((x8[k][i].scales_h[ib64] << 4) & 0x30)) - 32; + ls[4*ib64+1] = ((x8[k][i].scales_l[2*ib64+0] >> 4) | ((x8[k][i].scales_h[ib64] << 2) & 0x30)) - 32; + ls[4*ib64+2] = ((x8[k][i].scales_l[2*ib64+1] & 0xf) | ((x8[k][i].scales_h[ib64] >> 0) & 0x30)) - 32; + ls[4*ib64+3] = ((x8[k][i].scales_l[2*ib64+1] >> 4) | ((x8[k][i].scales_h[ib64] >> 2) & 0x30)) - 32; + auto bits = _mm256_loadu_si256((const __m256i *)x8[k][i].qs+ib64); + xv[2*ib64+0] = _mm256_and_si256(bits, _mm256_set1_epi8(0xf)); + xv[2*ib64+1] = _mm256_and_si256(_mm256_srli_epi16(bits, 4), _mm256_set1_epi8(0xf)); + auto qh = _mm256_and_si256(_mm256_slli_epi16(hbits, 7), mh); + auto q5vl = _mm256_or_si256(xv[2*ib64+0], qh); + auto q5vh = _mm256_or_si256(xv[2*ib64+0], _mm256_xor_si256(qh, mh)); + xv[2*ib64+0] = _mm256_or_si256(_mm256_shuffle_epi8(values[0], q5vl), _mm256_shuffle_epi8(values[1], q5vh)); + qh = _mm256_and_si256(_mm256_slli_epi16(hbits, 6), mh); + q5vl = _mm256_or_si256(xv[2*ib64+1], qh); + q5vh = _mm256_or_si256(xv[2*ib64+1], _mm256_xor_si256(qh, mh)); + xv[2*ib64+1] = _mm256_or_si256(_mm256_shuffle_epi8(values[0], q5vl), _mm256_shuffle_epi8(values[1], q5vh)); + auto shift1 = MM256_SET_M128I(_mm_set1_epi8((extra & 2) << 0), _mm_set1_epi8((extra & 1) << 1)); + auto shift2 = MM256_SET_M128I(_mm_set1_epi8((extra & 8) >> 2), _mm_set1_epi8((extra & 4) >> 1)); + xv[2*ib64+0] = _mm256_add_epi8(xv[2*ib64+0], shift1); + xv[2*ib64+1] = _mm256_add_epi8(xv[2*ib64+1], shift2); + hbits = _mm256_srli_epi16(hbits, 2); + extra >>= 4; + } + float dnew = convert_to_q8_k_r8(k, 1.f/127, xv, ls, block, y[i].qs); + y[i].d[k] = GGML_FP32_TO_FP16(d*dnew); + } + } + y += nb; + } +} + +void iqk_convert_iq5_k_q8_0_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) { + GGML_ASSERT(n%QK_K == 0); + GGML_ASSERT(nrc_x%8 == 0); + + int nb = n/QK_K; + + const block_iq5_k * x8[8]; + + block_q8_0_r8 * y = (block_q8_0_r8 *)vy; + + __m256i values[2]; + { + auto v1 = _mm_loadu_si128((const __m128i *)iq5nl_values+0); + auto v2 = _mm_loadu_si128((const __m128i *)iq5nl_values+1); + values[0] = MM256_SET_M128I(v1, v1); + values[1] = MM256_SET_M128I(v2, v2); + } + + __m256i xv[8]; + uint32_t block[8]; + int16_t ls[16]; + float all_s[64]; + + auto mh = _mm256_set1_epi8(-128); // to avoid stupid warning about 0x80 overflowing + + for (int ix = 0; ix < nrc_x; ix += 8) { + for (int k = 0; k < 8; ++k) x8[k] = (const block_iq5_k *)((const char *)vx + (ix+k)*bx); + for (int i = 0; i < nb; ++i) { + for (int k = 0; k < 8; ++k) { + float d = GGML_FP16_TO_FP32(x8[k][i].d); + auto extra = x8[k][i].extra; + auto hbits = _mm256_loadu_si256((const __m256i *)x8[k][i].qh); + for (int ib64 = 0; ib64 < 4; ++ib64) { + ls[4*ib64+0] = ((x8[k][i].scales_l[2*ib64+0] & 0xf) | ((x8[k][i].scales_h[ib64] << 4) & 0x30)) - 32; + ls[4*ib64+1] = ((x8[k][i].scales_l[2*ib64+0] >> 4) | ((x8[k][i].scales_h[ib64] << 2) & 0x30)) - 32; + ls[4*ib64+2] = ((x8[k][i].scales_l[2*ib64+1] & 0xf) | ((x8[k][i].scales_h[ib64] >> 0) & 0x30)) - 32; + ls[4*ib64+3] = ((x8[k][i].scales_l[2*ib64+1] >> 4) | ((x8[k][i].scales_h[ib64] >> 2) & 0x30)) - 32; + auto bits = _mm256_loadu_si256((const __m256i *)x8[k][i].qs+ib64); + xv[2*ib64+0] = _mm256_and_si256(bits, _mm256_set1_epi8(0xf)); + xv[2*ib64+1] = _mm256_and_si256(_mm256_srli_epi16(bits, 4), _mm256_set1_epi8(0xf)); + auto qh = _mm256_and_si256(_mm256_slli_epi16(hbits, 7), mh); + auto q5vl = _mm256_or_si256(xv[2*ib64+0], qh); + auto q5vh = _mm256_or_si256(xv[2*ib64+0], _mm256_xor_si256(qh, mh)); + xv[2*ib64+0] = _mm256_or_si256(_mm256_shuffle_epi8(values[0], q5vl), _mm256_shuffle_epi8(values[1], q5vh)); + qh = _mm256_and_si256(_mm256_slli_epi16(hbits, 6), mh); + q5vl = _mm256_or_si256(xv[2*ib64+1], qh); + q5vh = _mm256_or_si256(xv[2*ib64+1], _mm256_xor_si256(qh, mh)); + xv[2*ib64+1] = _mm256_or_si256(_mm256_shuffle_epi8(values[0], q5vl), _mm256_shuffle_epi8(values[1], q5vh)); + auto shift1 = MM256_SET_M128I(_mm_set1_epi8((extra & 2) << 0), _mm_set1_epi8((extra & 1) << 1)); + auto shift2 = MM256_SET_M128I(_mm_set1_epi8((extra & 8) >> 2), _mm_set1_epi8((extra & 4) >> 1)); + xv[2*ib64+0] = _mm256_add_epi8(xv[2*ib64+0], shift1); + xv[2*ib64+1] = _mm256_add_epi8(xv[2*ib64+1], shift2); + hbits = _mm256_srli_epi16(hbits, 2); + extra >>= 4; + } + for (int ib32 = 0; ib32 < 8; ++ib32) { + // We have two blocks of 16 with different scales + // We multiply the quants with the scales, find the max value, and convert to 8-bit quants with a single block scale. + auto q16_l = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(xv[ib32])); + auto q16_h = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(xv[ib32], 1)); + q16_l = _mm256_mullo_epi16(q16_l, _mm256_set1_epi16(ls[2*ib32+0])); + q16_h = _mm256_mullo_epi16(q16_h, _mm256_set1_epi16(ls[2*ib32+1])); + auto abs_q16_l = _mm256_sign_epi16(q16_l, q16_l); + auto abs_q16_h = _mm256_sign_epi16(q16_h, q16_h); + auto max_q16 = _mm256_max_epi16(abs_q16_l, abs_q16_h); + auto max_q32 = _mm256_cvtepi16_epi32(_mm_max_epi16(_mm256_castsi256_si128(max_q16), _mm256_extracti128_si256(max_q16, 1))); + auto imax4 = _mm_max_epi32(_mm256_castsi256_si128(max_q32), _mm256_extracti128_si256(max_q32, 1)); + auto max4 = _mm_cvtepi32_ps(imax4); + max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) ); + max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) ); + float max = _mm_cvtss_f32(max4) / 127; + all_s[8*ib32+k] = d*max; + if (max > 1e-9f) { + auto scale = _mm256_set1_ps(1/max); + auto i0 = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(q16_l)); + auto i1 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(q16_l, 1)); + auto i2 = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(q16_h)); + auto i3 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(q16_h, 1)); + i0 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(scale, _mm256_cvtepi32_ps(i0)), _MM_ROUND_NEAREST)); + i1 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(scale, _mm256_cvtepi32_ps(i1)), _MM_ROUND_NEAREST)); + i2 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(scale, _mm256_cvtepi32_ps(i2)), _MM_ROUND_NEAREST)); + i3 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(scale, _mm256_cvtepi32_ps(i3)), _MM_ROUND_NEAREST)); + i0 = _mm256_packs_epi32(i0, i1); + i2 = _mm256_packs_epi32(i2, i3); + i0 = _mm256_packs_epi16(i0, i2); + i0 = _mm256_permutevar8x32_epi32(i0, _mm256_setr_epi32(0, 4, 1, 5, 2, 6, 3, 7)); + _mm256_storeu_si256((__m256i *)block, i0); + } else { + _mm256_storeu_si256((__m256i *)block, _mm256_setzero_si256()); + } + auto qs = (uint32_t *)y[ib32].qs; + for (int l = 0; l < 4; ++l) { + qs[8*l + k + 0] = block[l + 0]; + qs[8*l + k + 32] = block[l + 4]; + } + } + } + for (int ib32 = 0; ib32 < 8; ++ib32) { + _mm_storeu_si128((__m128i *)y[ib32].d, _mm256_cvtps_ph(_mm256_loadu_ps(all_s + 8*ib32), _MM_FROUND_TO_NEAREST_INT)); + } + y += QK_K/32; + } + } +} + +void iqk_convert_iq6_k_q8_k_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) { + GGML_ASSERT(n%QK_K == 0); + GGML_ASSERT(nrc_x%8 == 0); + + int nb = n/QK_K; + + const block_iq6_k * x8[8]; + + block_q8_k_r8 * y = (block_q8_k_r8 *)vy; + + __m256i values[4]; + for (int k = 0; k < 4; ++k) { + auto values128 = _mm_loadu_si128((const __m128i *)iq6nl_values + k); + values[k] = MM256_SET_M128I(values128, values128); + } + + __m256i xv[8]; + uint32_t block[8]; + + union { __m256i vec; int16_t val[16]; } helper; + + auto mh1 = _mm256_set1_epi8(1); + auto mh2 = _mm256_set1_epi8(2); + auto mh3 = _mm256_set1_epi8(3); + + auto make_one = [&values, &mh1, &mh2, &mh3] (__m256i l, __m256i hbits) { + auto mask4 = _mm256_cmpeq_epi8(_mm256_and_si256(hbits, mh3), mh3); + auto h1 = _mm256_andnot_si256(mask4, hbits); + auto mask2 = _mm256_cmpeq_epi8(_mm256_and_si256(h1, mh1), mh1); + auto mask3 = _mm256_cmpeq_epi8(_mm256_and_si256(h1, mh2), mh2); + auto mask1 = _mm256_andnot_si256(_mm256_or_si256(mask4, _mm256_or_si256(mask2, mask3)), _mm256_set1_epi8(-1)); // 0xff; + return _mm256_or_si256(_mm256_or_si256(_mm256_and_si256(mask1, _mm256_shuffle_epi8(values[0], l)), + _mm256_and_si256(mask2, _mm256_shuffle_epi8(values[1], l))), + _mm256_or_si256(_mm256_and_si256(mask3, _mm256_shuffle_epi8(values[2], l)), + _mm256_and_si256(mask4, _mm256_shuffle_epi8(values[3], l)))); + }; + + for (int ix = 0; ix < nrc_x; ix += 8) { + for (int k = 0; k < 8; ++k) x8[k] = (const block_iq6_k *)((const char *)vx + (ix+k)*bx); + for (int i = 0; i < nb; ++i) { + for (int k = 0; k < 8; ++k) { + float d = GGML_FP16_TO_FP32(x8[k][i].d); + helper.vec = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i*)x8[k][i].scales)); + auto extra = x8[k][i].extra; + for (int i128 = 0; i128 < 2; ++i128) { + auto hbits = _mm256_loadu_si256((const __m256i *)x8[k][i].qh+i128); + auto bits = _mm256_loadu_si256((const __m256i *)x8[k][i].qs+2*i128+0); + xv[4*i128+0] = _mm256_and_si256(bits, _mm256_set1_epi8(0xf)); + xv[4*i128+1] = _mm256_and_si256(_mm256_srli_epi16(bits, 4), _mm256_set1_epi8(0xf)); + bits = _mm256_loadu_si256((const __m256i *)x8[k][i].qs+2*i128+1); + xv[4*i128+2] = _mm256_and_si256(bits, _mm256_set1_epi8(0xf)); + xv[4*i128+3] = _mm256_and_si256(_mm256_srli_epi16(bits, 4), _mm256_set1_epi8(0xf)); + for (int k = 0; k < 4; ++k) { + xv[4*i128+k] = make_one(xv[4*i128+k], hbits); + hbits = _mm256_srli_epi16(hbits, 2); + } + auto shift1 = MM256_SET_M128I(_mm_set1_epi8((extra >> 1) & 1), _mm_set1_epi8((extra >> 0) & 1)); + auto shift2 = MM256_SET_M128I(_mm_set1_epi8((extra >> 3) & 1), _mm_set1_epi8((extra >> 2) & 1)); + auto shift3 = MM256_SET_M128I(_mm_set1_epi8((extra >> 5) & 1), _mm_set1_epi8((extra >> 4) & 1)); + auto shift4 = MM256_SET_M128I(_mm_set1_epi8((extra >> 7) & 1), _mm_set1_epi8((extra >> 6) & 1)); + xv[4*i128+0] = _mm256_add_epi8(xv[4*i128+0], shift1); + xv[4*i128+1] = _mm256_add_epi8(xv[4*i128+1], shift2); + xv[4*i128+2] = _mm256_add_epi8(xv[4*i128+2], shift3); + xv[4*i128+3] = _mm256_add_epi8(xv[4*i128+3], shift4); + extra >>= 8; + } + float dnew = convert_to_q8_k_r8(k, 1.f/127, xv, helper.val, block, y[i].qs); + y[i].d[k] = GGML_FP32_TO_FP16(d*dnew); + } + } + y += nb; + } +} + } // namespace +bool iqk_convert_iqk_quants_q80_r8(int type, int n, const void * vx, size_t bx, void * vy, int nrc_x) { + if (n%QK_K != 0 || nrc_x%8 != 0) return false; + switch (ggml_type(type)) { + case GGML_TYPE_IQ2_KS : iqk_convert_iq2_ks_q8_k_r8(n, vx, bx, vy, nrc_x); break; + case GGML_TYPE_IQ2_K : iqk_convert_iq2_k_q8_k_r8 (n, vx, bx, vy, nrc_x); break; + case GGML_TYPE_IQ3_K : iqk_convert_iq3_k_q8_k_r8 (n, vx, bx, vy, nrc_x); break; + case GGML_TYPE_IQ4_KS : iqk_convert_iq4_ks_q8_k_r8(n, vx, bx, vy, nrc_x); break; + case GGML_TYPE_IQ4_K : iqk_convert_iq4_k_q8_k_r8 (n, vx, bx, vy, nrc_x); break; + case GGML_TYPE_IQ5_KS : iqk_convert_iq5_ks_q8_k_r8(n, vx, bx, vy, nrc_x); break; + case GGML_TYPE_IQ5_K : iqk_convert_iq5_k_q8_k_r8 (n, vx, bx, vy, nrc_x); break; + case GGML_TYPE_IQ6_K : iqk_convert_iq6_k_q8_k_r8 (n, vx, bx, vy, nrc_x); break; + default: return false; + } + return true; +} + bool iqk_set_kernels_iqk_quants(int ne00, int typeA, int typeB, std::array<mul_mat_t, IQK_MAX_NY>& kernels, mul_mat_t& func16) { auto etypeA = ggml_type(typeA); diff --git a/ggml/src/iqk/iqk_gemm_iqk_quants.h b/ggml/src/iqk/iqk_gemm_iqk_quants.h index cd076ff7..41beca63 100644 --- a/ggml/src/iqk/iqk_gemm_iqk_quants.h +++ b/ggml/src/iqk/iqk_gemm_iqk_quants.h @@ -8,4 +8,6 @@ bool iqk_set_kernels_iqk_quants(int ne00, int typeA, int typeB, std::array<mul_mat_t, IQK_MAX_NY>& kernels, mul_mat_t& func16); +bool iqk_convert_iqk_quants_q80_r8(int type, int n, const void * vx, size_t bx, void * vy, int nrc_x); + #endif diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index 0b29a572..81b5841d 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -250,6 +250,14 @@ struct MulMat { case GGML_TYPE_Q4_K : return nrc_y >= 32 ? GGML_TYPE_Q8_1 : type; case GGML_TYPE_Q5_K : return nrc_y >= 32 ? GGML_TYPE_Q8_1 : type; case GGML_TYPE_Q6_K : return nrc_y >= 64 ? GGML_TYPE_Q8_0_R8 : type; + case GGML_TYPE_IQ2_KS : return nrc_y >= 32 ? GGML_TYPE_Q8_K_R8 : type; + case GGML_TYPE_IQ2_K : return nrc_y >= 32 ? GGML_TYPE_Q8_K_R8 : type; + case GGML_TYPE_IQ3_K : return nrc_y >= 32 ? GGML_TYPE_Q8_K_R8 : type; + case GGML_TYPE_IQ4_KS : return nrc_y >= 32 ? GGML_TYPE_Q8_K_R8 : type; + case GGML_TYPE_IQ4_K : return nrc_y >= 32 ? GGML_TYPE_Q8_K_R8 : type; + case GGML_TYPE_IQ5_KS : return nrc_y >= 32 ? GGML_TYPE_Q8_K_R8 : type; + case GGML_TYPE_IQ5_K : return nrc_y >= 32 ? GGML_TYPE_Q8_K_R8 : type; + case GGML_TYPE_IQ6_K : return nrc_y >= 32 ? GGML_TYPE_Q8_K_R8 : type; default: break; } #else @@ -375,22 +383,22 @@ bool iqk_convert_repack(int typeA, int n, const void * vx, size_t bx, void * vy, case GGML_TYPE_IQ3_XXS_R4: case GGML_TYPE_IQ3_S_R4: return iqk_convert_iquants_q80_r8(typeA, n, vx, bx, vy, nrc_x); - //case GGML_TYPE_IQ4_KS: - //case GGML_TYPE_IQ5_KS: - //case GGML_TYPE_IQ4_KSS: - //case GGML_TYPE_IQ2_K: - //case GGML_TYPE_IQ2_KS: - //case GGML_TYPE_IQ3_K: - //case GGML_TYPE_IQ4_K: - //case GGML_TYPE_IQ5_K: - //case GGML_TYPE_IQ6_K: + case GGML_TYPE_IQ2_KS: + case GGML_TYPE_IQ2_K: + case GGML_TYPE_IQ3_K: + case GGML_TYPE_IQ4_KSS: + case GGML_TYPE_IQ4_KS: + case GGML_TYPE_IQ4_K: + case GGML_TYPE_IQ5_KS: + case GGML_TYPE_IQ5_K: + case GGML_TYPE_IQ6_K: //case GGML_TYPE_IQ2_K_R4: //case GGML_TYPE_IQ3_K_R4: //case GGML_TYPE_IQ4_K_R4: //case GGML_TYPE_IQ5_K_R4: //case GGML_TYPE_IQ4_KS_R4: //case GGML_TYPE_IQ5_KS_R4: - // return iqk_set_kernels_iqk_quants(ne00, typeA, typeB, mm.funcs, mm.func16); + return iqk_convert_iqk_quants_q80_r8(typeA, n, vx, bx, vy, nrc_x); case GGML_TYPE_IQ2_KT: case GGML_TYPE_IQ3_KT: case GGML_TYPE_IQ4_KT: |