diff options
author | Kawrakow <iwankawrakow@gmail.com> | 2025-05-17 08:57:26 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2025-05-17 08:57:26 +0300 |
commit | 7abdf2b099ecf9bea156a635a8f22d168483f2b1 (patch) | |
tree | 98a642e178695ba49f0a31a4a4b2700b28cdf69b | |
parent | 134d5481737c05421eb1ba7cd7573136e3fdbd69 (diff) |
IQ5_KS_R4: row-interleaved IQ5_KS (#426)
* iq5_ks_r4: basics
* iq5_ks_r4: Zen4 works
* iq5_ks_r4: AVX2 works
* iq5_ks_r4: NEON
* Fix iq5_ks on NEON
---------
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
-rw-r--r-- | examples/quantize/quantize.cpp | 1 | ||||
-rw-r--r-- | ggml/include/ggml.h | 2 | ||||
-rw-r--r-- | ggml/src/ggml-common.h | 7 | ||||
-rw-r--r-- | ggml/src/ggml-quants.c | 1 | ||||
-rw-r--r-- | ggml/src/ggml.c | 26 | ||||
-rw-r--r-- | ggml/src/iqk/iqk_mul_mat.cpp | 315 | ||||
-rw-r--r-- | ggml/src/iqk/iqk_quantize.cpp | 124 | ||||
-rw-r--r-- | ggml/src/iqk/iqk_quantize.h | 6 | ||||
-rw-r--r-- | include/llama.h | 1 | ||||
-rw-r--r-- | src/llama.cpp | 9 |
10 files changed, 441 insertions, 51 deletions
diff --git a/examples/quantize/quantize.cpp b/examples/quantize/quantize.cpp index 1b388a73..b5277ec1 100644 --- a/examples/quantize/quantize.cpp +++ b/examples/quantize/quantize.cpp @@ -67,6 +67,7 @@ static const std::vector<struct quant_option> QUANT_OPTIONS = { { "IQ4_XS", LLAMA_FTYPE_MOSTLY_IQ4_XS, " 4.25 bpw non-linear quantization", }, { "IQ4_KS", LLAMA_FTYPE_MOSTLY_IQ4_KS, " 4.25 bpw non-linear quantization", }, { "IQ4_KS_R4",LLAMA_FTYPE_MOSTLY_IQ4_KS_R4,"IQ4_KS repacked", }, + { "IQ5_KS_R4",LLAMA_FTYPE_MOSTLY_IQ5_KS_R4,"IQ5_KS repacked", }, { "IQ4_KSS", LLAMA_FTYPE_MOSTLY_IQ4_KSS, " 4.0 bpw non-linear quantization", }, { "IQ5_KS", LLAMA_FTYPE_MOSTLY_IQ5_KS, " 5.25 bpw non-linear quantization", }, { "IQ2_K", LLAMA_FTYPE_MOSTLY_IQ2_K, " 2.375 bpw non-linear quantization",}, diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index b6f461ed..a04c7d43 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -452,6 +452,7 @@ extern "C" { GGML_TYPE_IQ4_K_R4 = 339, GGML_TYPE_IQ5_K_R4 = 340, GGML_TYPE_IQ4_KS_R4 = 344, + GGML_TYPE_IQ5_KS_R4 = 352, GGML_TYPE_Q8_KV_R8 = 398, GGML_TYPE_Q8_K_R8 = 399, GGML_TYPE_COUNT, @@ -540,6 +541,7 @@ extern "C" { GGML_FTYPE_MOSTLY_IQ4_K_R4 = 332, // except 1d tensors GGML_FTYPE_MOSTLY_IQ5_K_R4 = 333, // except 1d tensors GGML_FTYPE_MOSTLY_IQ4_KS_R4 = 337, // except 1d tensors + GGML_FTYPE_MOSTLY_IQ5_KS_R4 = 341, // except 1d tensors GGML_FTYPE_MOSTLY_Q8_KV_R8 = 398, // except 1d tensors GGML_FTYPE_MOSTLY_Q8_K_R8 = 399, // except 1d tensors }; diff --git a/ggml/src/ggml-common.h b/ggml/src/ggml-common.h index 1c2d1b17..26041ac2 100644 --- a/ggml/src/ggml-common.h +++ b/ggml/src/ggml-common.h @@ -694,6 +694,13 @@ typedef struct { } block_iq5_ks; static_assert(sizeof(block_iq5_ks) == QK_K/32 + QK_K/2 + QK_K/8, "wrong iq5_ks block size/padding"); +typedef struct { + uint8_t scales[QK_K/8]; + uint8_t qs[QK_K*2]; + uint8_t qh[QK_K/2]; +} block_iq5_ks_r4; +static_assert(sizeof(block_iq5_ks_r4) == 4*sizeof(block_iq5_ks), "wrong iq5_ks_r4 block size/padding"); + #endif // GGML_COMMON_DECL #endif // GGML_COMMON_DECL diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c index 8ebb0d32..0e6aa677 100644 --- a/ggml/src/ggml-quants.c +++ b/ggml/src/ggml-quants.c @@ -15451,6 +15451,7 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte case GGML_TYPE_IQ4_K_R4: break; case GGML_TYPE_IQ5_K_R4: break; case GGML_TYPE_IQ4_KS_R4:break; + case GGML_TYPE_IQ5_KS_R4:break; case GGML_TYPE_Q8_KV_R8: break; case GGML_TYPE_Q8_K_R8: break; case GGML_TYPE_Q8_KV: break; diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index bc103ab7..7cbc0056 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -1343,6 +1343,23 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .nrows = 1, .row_meta_size = 4, }, + [GGML_TYPE_IQ5_KS_R4] = { + .type_name = "iq5_ks_r4", + .blck_size = QK_K, + .type_size = sizeof(block_iq5_ks), + .is_quantized = true, + .to_float = (ggml_to_float_t) dequantize_row_iq5_ks_r4, + .from_float = quantize_row_iq5_ks_r4, + .from_float_ref = (ggml_from_float_t)quantize_row_iq5_ks_r4_ref, + .vec_dot = vec_dot_iq5_ks_r4_q8_k, +#if defined __AVX2__ + .vec_dot_type = GGML_TYPE_Q8_K32, +#else + .vec_dot_type = GGML_TYPE_Q8_K, +#endif + .nrows = 1, + .row_meta_size = 4, + }, [GGML_TYPE_IQ4_KSS] = { .type_name = "iq4_kss", .blck_size = QK_K, @@ -4478,6 +4495,7 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) { case GGML_FTYPE_MOSTLY_IQ4_XS: wtype = GGML_TYPE_IQ4_XS; break; case GGML_FTYPE_MOSTLY_IQ4_KS: wtype = GGML_TYPE_IQ4_KS; break; case GGML_FTYPE_MOSTLY_IQ4_KS_R4: wtype = GGML_TYPE_IQ4_KS_R4;break; + case GGML_FTYPE_MOSTLY_IQ5_KS_R4: wtype = GGML_TYPE_IQ5_KS_R4;break; case GGML_FTYPE_MOSTLY_IQ4_KSS: wtype = GGML_TYPE_IQ4_KSS; break; case GGML_FTYPE_MOSTLY_IQ5_KS: wtype = GGML_TYPE_IQ5_KS; break; case GGML_FTYPE_MOSTLY_IQ2_K: wtype = GGML_TYPE_IQ2_K; break; @@ -11242,6 +11260,7 @@ static void ggml_compute_forward_add( case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_KS: case GGML_TYPE_IQ4_KS_R4: + case GGML_TYPE_IQ5_KS_R4: case GGML_TYPE_IQ4_KSS: case GGML_TYPE_IQ5_KS: case GGML_TYPE_IQ2_K: @@ -11715,6 +11734,7 @@ static void ggml_compute_forward_add1( case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_KS: case GGML_TYPE_IQ4_KS_R4: + case GGML_TYPE_IQ5_KS_R4: case GGML_TYPE_IQ4_KSS: case GGML_TYPE_IQ5_KS: case GGML_TYPE_IQ2_K: @@ -11885,6 +11905,7 @@ static void ggml_compute_forward_acc( case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_KS: case GGML_TYPE_IQ4_KS_R4: + case GGML_TYPE_IQ5_KS_R4: case GGML_TYPE_IQ4_KSS: case GGML_TYPE_IQ5_KS: case GGML_TYPE_IQ2_K: @@ -15382,6 +15403,7 @@ static void ggml_compute_forward_out_prod( case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_KS: case GGML_TYPE_IQ4_KS_R4: + case GGML_TYPE_IQ5_KS_R4: case GGML_TYPE_IQ4_KSS: case GGML_TYPE_IQ5_KS: case GGML_TYPE_IQ2_K: @@ -15792,6 +15814,7 @@ static void ggml_compute_forward_set( case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_KS: case GGML_TYPE_IQ4_KS_R4: + case GGML_TYPE_IQ5_KS_R4: case GGML_TYPE_IQ4_KSS: case GGML_TYPE_IQ5_KS: case GGML_TYPE_IQ2_K: @@ -16108,6 +16131,7 @@ static void ggml_compute_forward_get_rows( case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_KS: case GGML_TYPE_IQ4_KS_R4: + case GGML_TYPE_IQ5_KS_R4: case GGML_TYPE_IQ4_KSS: case GGML_TYPE_IQ5_KS: case GGML_TYPE_IQ2_K: @@ -16741,6 +16765,7 @@ static void ggml_compute_forward_clamp( case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_KS: case GGML_TYPE_IQ4_KS_R4: + case GGML_TYPE_IQ5_KS_R4: case GGML_TYPE_IQ4_KSS: case GGML_TYPE_IQ5_KS: case GGML_TYPE_IQ2_K: @@ -23810,6 +23835,7 @@ size_t ggml_quantize_chunk( case GGML_TYPE_IQ4_XS: result = quantize_iq4_xs (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ4_KS: result = quantize_iq4_ks (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ4_KS_R4:result = quantize_iq4_ks_r4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_IQ5_KS_R4:result = quantize_iq5_ks_r4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ4_KSS: result = quantize_iq4_kss(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ5_KS: result = quantize_iq5_ks (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ2_K: result = quantize_iq2_k (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index 6072d56d..7d7ae798 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -342,6 +342,7 @@ struct MulMat { case GGML_TYPE_IQ4_K_R4: case GGML_TYPE_IQ5_K_R4: case GGML_TYPE_IQ4_KS_R4: + case GGML_TYPE_IQ5_KS_R4: case GGML_TYPE_IQ2_XXS_R4: case GGML_TYPE_IQ2_XS_R4: case GGML_TYPE_IQ2_S_R4: @@ -379,6 +380,7 @@ struct MulMat { case GGML_TYPE_IQ4_K_R4: case GGML_TYPE_IQ5_K_R4: case GGML_TYPE_IQ4_KS_R4: + case GGML_TYPE_IQ5_KS_R4: case GGML_TYPE_IQ2_XXS_R4: case GGML_TYPE_IQ2_XS_R4: case GGML_TYPE_IQ2_S_R4: @@ -7353,6 +7355,16 @@ static void mul_mat_iq4_k_r4_q8_k(int n, const void * vx, size_t bx, const DataI } } +static inline __m256i prepare_5bit_quants(const __m256i * values, __m256i ql, __m256i qh, __m256i mask) { + auto q5vl = _mm256_shuffle_epi8(values[0], ql); + auto q5vh = _mm256_shuffle_epi8(values[1], ql); +#ifdef HAVE_FANCY_SIMD + return _mm256_mask_blend_epi8(_mm256_cmpeq_epi8_mask(_mm256_and_si256(qh, mask), mask), q5vl, q5vh); +#else + return _mm256_blendv_epi8(q5vl, q5vh, _mm256_cmpeq_epi8(_mm256_and_si256(qh, mask), mask)); +#endif +} + template <int nrc_y> static void mul_mat_iq5_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { GGML_ASSERT(nrc_x%4 == 0); @@ -7421,23 +7433,11 @@ static void mul_mat_iq5_k_r4_q8_k(int n, const void * vx, size_t bx, const DataI qx[2] = _mm256_and_si256(_mm256_srli_epi16(lbits1, 4), m4); qx[3] = _mm256_and_si256(_mm256_srli_epi16(lbits2, 4), m4); + qx[0] = prepare_5bit_quants(values, qx[0], hb, _mm256_set1_epi8(0x01)); + qx[1] = prepare_5bit_quants(values, qx[1], hb, _mm256_set1_epi8(0x10)); + qx[2] = prepare_5bit_quants(values, qx[2], hb, _mm256_set1_epi8(0x02)); + qx[3] = prepare_5bit_quants(values, qx[3], hb, _mm256_set1_epi8(0x20)); #ifdef HAVE_FANCY_SIMD - auto q5vl = _mm256_shuffle_epi8(values[0], qx[0]); - auto q5vh = _mm256_shuffle_epi8(values[1], qx[0]); - qx[0] = _mm256_mask_blend_epi8(_mm256_cmpeq_epi8_mask(_mm256_and_si256(hb, _mm256_set1_epi8(0x01)), _mm256_set1_epi8(0x01)), q5vl, q5vh); - - q5vl = _mm256_shuffle_epi8(values[0], qx[1]); - q5vh = _mm256_shuffle_epi8(values[1], qx[1]); - qx[1] = _mm256_mask_blend_epi8(_mm256_cmpeq_epi8_mask(_mm256_and_si256(hb, _mm256_set1_epi8(0x10)), _mm256_set1_epi8(0x10)), q5vl, q5vh); - - q5vl = _mm256_shuffle_epi8(values[0], qx[2]); - q5vh = _mm256_shuffle_epi8(values[1], qx[2]); - qx[2] = _mm256_mask_blend_epi8(_mm256_cmpeq_epi8_mask(_mm256_and_si256(hb, _mm256_set1_epi8(0x02)), _mm256_set1_epi8(0x02)), q5vl, q5vh); - - q5vl = _mm256_shuffle_epi8(values[0], qx[3]); - q5vh = _mm256_shuffle_epi8(values[1], qx[3]); - qx[3] = _mm256_mask_blend_epi8(_mm256_cmpeq_epi8_mask(_mm256_and_si256(hb, _mm256_set1_epi8(0x20)), _mm256_set1_epi8(0x20)), q5vl, q5vh); - if constexpr (nrc_y == 1) { auto shift = _mm256_and_si256(ms, _mm256_slli_epi16(extra, 1)); extra = _mm256_srli_epi16(extra, 1); shift = _mm256_shuffle_epi8(shift, shift_shuffle); @@ -7447,23 +7447,6 @@ static void mul_mat_iq5_k_r4_q8_k(int n, const void * vx, size_t bx, const DataI qx[3] = _mm256_add_epi8(qx[3], shift); } #else - - auto q5vl = _mm256_shuffle_epi8(values[0], qx[0]); - auto q5vh = _mm256_shuffle_epi8(values[1], qx[0]); - qx[0] = _mm256_blendv_epi8(q5vl, q5vh, _mm256_cmpeq_epi8(_mm256_and_si256(hb, _mm256_set1_epi8(0x01)), _mm256_set1_epi8(0x01))); - - q5vl = _mm256_shuffle_epi8(values[0], qx[1]); - q5vh = _mm256_shuffle_epi8(values[1], qx[1]); - qx[1] = _mm256_blendv_epi8(q5vl, q5vh, _mm256_cmpeq_epi8(_mm256_and_si256(hb, _mm256_set1_epi8(0x10)), _mm256_set1_epi8(0x10))); - - q5vl = _mm256_shuffle_epi8(values[0], qx[2]); - q5vh = _mm256_shuffle_epi8(values[1], qx[2]); - qx[2] = _mm256_blendv_epi8(q5vl, q5vh, _mm256_cmpeq_epi8(_mm256_and_si256(hb, _mm256_set1_epi8(0x02)), _mm256_set1_epi8(0x02))); - - q5vl = _mm256_shuffle_epi8(values[0], qx[3]); - q5vh = _mm256_shuffle_epi8(values[1], qx[3]); - qx[3] = _mm256_blendv_epi8(q5vl, q5vh, _mm256_cmpeq_epi8(_mm256_and_si256(hb, _mm256_set1_epi8(0x20)), _mm256_set1_epi8(0x20))); - auto shift = _mm256_and_si256(ms, _mm256_slli_epi16(extra, 1)); extra = _mm256_srli_epi16(extra, 1); shift = _mm256_shuffle_epi8(shift, shift_shuffle); qx[0] = _mm256_add_epi8(qx[0], shift); @@ -7506,6 +7489,128 @@ static void mul_mat_iq5_k_r4_q8_k(int n, const void * vx, size_t bx, const DataI } } +template <int nrc_y> +static void mul_mat_iq5_ks_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + GGML_ASSERT(nrc_x%4 == 0); + Q8<nrc_y, block_q8_K> q8(info); + auto m4 = _mm256_set1_epi8(0xf); + __m256i values[2]; + { + auto val1 = _mm_loadu_si128((const __m128i *)iq5nl_values+0); + auto val2 = _mm_loadu_si128((const __m128i *)iq5nl_values+1); + values[0] = MM256_SET_M128I(val1, val1); + values[1] = MM256_SET_M128I(val2, val2); +#ifdef HAVE_FANCY_SIMD + values[0] = _mm256_sub_epi8(values[0], _mm256_set1_epi8(-128)); + values[1] = _mm256_sub_epi8(values[1], _mm256_set1_epi8(-128)); +#endif + } + int nbl = n / QK_K; + using helper_t = union { __m256i vec; uint32_t val[8]; }; +#ifndef HAVE_FANCY_SIMD + helper_t h, h_shift; + auto s_shuffle = _mm256_set_epi64x(0x0f0e0f0e0d0c0d0c, 0x0b0a0b0a09080908, 0x0706070605040504, 0x0302030201000100); +#else + using helper512_t = union { __m512i vec; uint64_t val[8]; }; + helper_t h; + helper512_t h_shift; +#endif + __m256 acc[nrc_y] = {}; + __m256i isum[nrc_y] = {}; + __m256i qx[4]; + for (int ix = 0; ix < nrc_x; ix += 4) { + auto dptr = (const float *)((const char *)vx + (ix+0)*bx); + const block_iq5_ks_r4 * iq5 = (const block_iq5_ks_r4 *)(dptr + 4); + auto d4 = _mm_loadu_ps(dptr); + for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256 + auto scales = _mm256_loadu_si256((const __m256i *)iq5[ibl].scales); + h.vec = _mm256_sub_epi8(_mm256_and_si256(scales, _mm256_set1_epi8(-2)), _mm256_set1_epi8(127)); +#ifndef HAVE_FANCY_SIMD + h_shift.vec = _mm256_slli_epi16(_mm256_and_si256(scales, _mm256_set1_epi8(1)), 1); + { + __m256 v1 = _mm256_mul_ps(_mm256_cvtepi32_ps(MM256_SET_M128I(_mm_cvtepi8_epi32(_mm_set1_epi32(h.val[4])), _mm_cvtepi8_epi32(_mm_set1_epi32(h.val[0])))), + _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_cvtepi8_epi32(_mm_set1_epi32(h_shift.val[4])), _mm_cvtepi8_epi32(_mm_set1_epi32(h_shift.val[0]))))); + __m256 v2 = _mm256_mul_ps(_mm256_cvtepi32_ps(MM256_SET_M128I(_mm_cvtepi8_epi32(_mm_set1_epi32(h.val[5])), _mm_cvtepi8_epi32(_mm_set1_epi32(h.val[1])))), + _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_cvtepi8_epi32(_mm_set1_epi32(h_shift.val[5])), _mm_cvtepi8_epi32(_mm_set1_epi32(h_shift.val[1]))))); + __m256 v3 = _mm256_mul_ps(_mm256_cvtepi32_ps(MM256_SET_M128I(_mm_cvtepi8_epi32(_mm_set1_epi32(h.val[6])), _mm_cvtepi8_epi32(_mm_set1_epi32(h.val[2])))), + _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_cvtepi8_epi32(_mm_set1_epi32(h_shift.val[6])), _mm_cvtepi8_epi32(_mm_set1_epi32(h_shift.val[2]))))); + __m256 v4 = _mm256_mul_ps(_mm256_cvtepi32_ps(MM256_SET_M128I(_mm_cvtepi8_epi32(_mm_set1_epi32(h.val[7])), _mm_cvtepi8_epi32(_mm_set1_epi32(h.val[3])))), + _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_cvtepi8_epi32(_mm_set1_epi32(h_shift.val[7])), _mm_cvtepi8_epi32(_mm_set1_epi32(h_shift.val[3]))))); + for (int iy = 0; iy < nrc_y; ++iy) { + auto m8 = _mm256_loadu_ps((const float *)q8.y[iy][ibl].bsums); + acc[iy] = _mm256_fmadd_ps(v1, _mm256_shuffle_ps(m8, m8, 0x00), acc[iy]); + acc[iy] = _mm256_fmadd_ps(v2, _mm256_shuffle_ps(m8, m8, 0x55), acc[iy]); + acc[iy] = _mm256_fmadd_ps(v3, _mm256_shuffle_ps(m8, m8, 0xaa), acc[iy]); + acc[iy] = _mm256_fmadd_ps(v4, _mm256_shuffle_ps(m8, m8, 0xff), acc[iy]); + } + } +#else + auto shift = _mm256_add_epi8(_mm256_set1_epi8(-64), _mm256_and_si256(scales, _mm256_set1_epi8(1))); + h_shift.vec = _mm512_mullo_epi16(_mm512_cvtepi8_epi16(shift), _mm512_cvtepi8_epi16(h.vec)); +#endif + for (int ib = 0; ib < QK_K/32; ++ib) { +#ifdef HAVE_FANCY_SIMD + auto iscales = _mm256_cvtepi8_epi32(_mm_set1_epi32(h.val[ib])); + auto ishifts = _mm256_cvtepi16_epi32(_mm_set1_epi64x(h_shift.val[ib])); + auto scales_m = _mm256_cvtepi32_ps(ishifts); + for (int iy = 0; iy < nrc_y; ++iy) { + float m8 = ((const float *)q8.y[iy][ibl].bsums)[ib]; + acc[iy] = _mm256_fmadd_ps(scales_m, _mm256_set1_ps(m8), acc[iy]); + } +#endif + auto lbits1 = _mm256_loadu_si256((const __m256i *)iq5[ibl].qs+2*ib+0); + auto lbits2 = _mm256_loadu_si256((const __m256i *)iq5[ibl].qs+2*ib+1); + auto hbits = _mm_loadu_si128((const __m128i *)iq5[ibl].qh+ib); + auto hb = MM256_SET_M128I(_mm_srli_epi16(hbits, 2), hbits); + qx[0] = _mm256_and_si256(lbits1, m4); + qx[1] = _mm256_and_si256(lbits2, m4); + qx[2] = _mm256_and_si256(_mm256_srli_epi16(lbits1, 4), m4); + qx[3] = _mm256_and_si256(_mm256_srli_epi16(lbits2, 4), m4); + + qx[0] = prepare_5bit_quants(values, qx[0], hb, _mm256_set1_epi8(0x01)); + qx[1] = prepare_5bit_quants(values, qx[1], hb, _mm256_set1_epi8(0x10)); + qx[2] = prepare_5bit_quants(values, qx[2], hb, _mm256_set1_epi8(0x02)); + qx[3] = prepare_5bit_quants(values, qx[3], hb, _mm256_set1_epi8(0x20)); + +#ifndef HAVE_FANCY_SIMD + auto iscales = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm_set1_epi32(h.val[ib])), s_shuffle); + auto s1 = _mm256_sign_epi8(qx[0], qx[0]); + auto s2 = _mm256_sign_epi8(qx[1], qx[1]); + auto s3 = _mm256_sign_epi8(qx[2], qx[2]); + auto s4 = _mm256_sign_epi8(qx[3], qx[3]); +#endif + for (int iy = 0; iy < nrc_y; ++iy) { + auto y = _mm256_loadu_si256((const __m256i*)q8.y[iy][ibl].qs+ib); +#ifdef HAVE_FANCY_SIMD + auto sumi = _mm256_setzero_si256(); + sumi = _mm256_dpbusd_epi32(sumi, qx[0], _mm256_shuffle_epi32(y, 0x00)); + sumi = _mm256_dpbusd_epi32(sumi, qx[1], _mm256_shuffle_epi32(y, 0x55)); + sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(y, 0xaa)); + sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(y, 0xff)); + isum[iy] = _mm256_add_epi32(isum[iy], _mm256_mullo_epi32(iscales, sumi)); +#else + auto sumi1 = _mm256_maddubs_epi16(s1, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), qx[0])); + auto sumi2 = _mm256_maddubs_epi16(s2, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), qx[1])); + auto sumi3 = _mm256_maddubs_epi16(s3, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), qx[2])); + auto sumi4 = _mm256_maddubs_epi16(s4, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), qx[3])); + isum[iy] = _mm256_add_epi32(isum[iy], _mm256_add_epi32(_mm256_madd_epi16(iscales, sumi1), _mm256_madd_epi16(iscales, sumi2))); + isum[iy] = _mm256_add_epi32(isum[iy], _mm256_add_epi32(_mm256_madd_epi16(iscales, sumi3), _mm256_madd_epi16(iscales, sumi4))); +#endif + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + acc[iy] = _mm256_fmadd_ps(_mm256_set1_ps(q8.scale(iy, ibl)), _mm256_cvtepi32_ps(isum[iy]), acc[iy]); + isum[iy] = _mm256_setzero_si256(); + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1)); + acc[iy] = _mm256_setzero_ps(); + info.store(ix+0, iy, _mm_mul_ps(d4, sum)); + } + } +} + template <typename Bits> inline void multiply_add_1(int j, const Bits& bits, const __m256i * scales, const __m256i * q8, __m256i * sumi) { if (j == 0) { @@ -9949,6 +10054,22 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) { #endif expected_typeB = GGML_TYPE_Q8_K32; break; + case GGML_TYPE_IQ5_KS_R4: + assert (ne00 % QK_K == 0); + mm.funcs[0] = mul_mat_iq5_ks_r4_q8_k<1>; + mm.funcs[1] = mul_mat_iq5_ks_r4_q8_k<2>; + mm.funcs[2] = mul_mat_iq5_ks_r4_q8_k<3>; + mm.funcs[3] = mul_mat_iq5_ks_r4_q8_k<4>; + mm.funcs[4] = mul_mat_iq5_ks_r4_q8_k<5>; + mm.funcs[5] = mul_mat_iq5_ks_r4_q8_k<6>; + mm.funcs[6] = mul_mat_iq5_ks_r4_q8_k<7>; + mm.funcs[7] = mul_mat_iq5_ks_r4_q8_k<8>; +#ifndef HAVE_FANCY_SIMD + // For some reason Zen4 does not like this particular function + mm.func16 = mul_mat_iq5_ks_r4_q8_k<16>; +#endif + expected_typeB = GGML_TYPE_Q8_K32; + break; case GGML_TYPE_IQ2_XXS_R4: assert (ne00 % QK_K == 0); mm.funcs[0] = mul_mat_iq2_xxs_r4_q8_k<1>; @@ -11086,7 +11207,8 @@ struct DequantizerIQ4KS final : public BaseDequantizer<block_iq4_ks, true> { }; struct DequantizerIQ5KS final : public BaseDequantizer<block_iq5_ks, true> { - DequantizerIQ5KS(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc), values(vld1q_s8_x2(iq5nl_values)) {} + DequantizerIQ5KS(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc), + values(vld1q_s8_x4(iq5nl_values)) {} constexpr static int num_blocks() { return 8; } constexpr static bool should_scale_quants() { return false; } @@ -11095,7 +11217,11 @@ struct DequantizerIQ5KS final : public BaseDequantizer<block_iq5_ks, true> { inline int32x4x2_t new_block(int i, const Q8& q8, float32x4_t * acc) { (void)q8; (void)acc; - auto scales16 = vaddq_s16(vreinterpretq_s16_u16(vandq_u16(vmovl_u8(vld1_u8(x[i].scales)), mask)), m127); + auto sas8 = vld1_u8(x[i].scales); + auto scales16 = vaddq_s16(vreinterpretq_s16_u16(vandq_u16(vmovl_u8(sas8), mask)), m127); + hbits = vld1q_u8_x2(x[i].qh); + sas = vcombine_u8(sas8, sas8); + sas = vshlq_n_u8(vandq_u8(sas, vdupq_n_u8(1)), 5); int32x4x2_t scales = {vmovl_s16(vget_low_s16(scales16)), vmovl_s16(vget_high_s16(scales16))}; return scales; } @@ -11105,27 +11231,29 @@ struct DequantizerIQ5KS final : public BaseDequantizer<block_iq5_ks, true> { if (j == 1) { for (int k = 0; k < 2; ++k) hbits.val[k] = vshrq_n_u8(hbits.val[k], 4); } - bits.b1.val[0] = vorrq_u8(bits.b1.val[0], vandq_u8(vshlq_n_u8(hbits.val[0], 4), hm)); - bits.b1.val[1] = vorrq_u8(bits.b1.val[1], vandq_u8(vshlq_n_u8(hbits.val[1], 4), hm)); - bits.b1.val[2] = vorrq_u8(bits.b1.val[2], vandq_u8(vshlq_n_u8(hbits.val[0], 3), hm)); - bits.b1.val[3] = vorrq_u8(bits.b1.val[3], vandq_u8(vshlq_n_u8(hbits.val[1], 3), hm)); - bits.b2.val[0] = vorrq_u8(bits.b2.val[0], vandq_u8(vshlq_n_u8(hbits.val[0], 2), hm)); - bits.b2.val[1] = vorrq_u8(bits.b2.val[1], vandq_u8(vshlq_n_u8(hbits.val[1], 2), hm)); - bits.b2.val[2] = vorrq_u8(bits.b2.val[2], vandq_u8(vshlq_n_u8(hbits.val[0], 1), hm)); - bits.b2.val[3] = vorrq_u8(bits.b2.val[3], vandq_u8(vshlq_n_u8(hbits.val[1], 1), hm)); - for (int k = 0; k < 4; ++k) { - bits.b1.val[k] = vqtbl2q_s8(values, bits.b1.val[k]); - bits.b2.val[k] = vqtbl2q_s8(values, bits.b2.val[k]); - } + auto shift = vdupq_n_u8((x[i].scales[4*j+0] & 1) << 5); + bits.b1.val[0] = vaddq_u8(shift, vorrq_u8(bits.b1.val[0], vandq_u8(vshlq_n_u8(hbits.val[0], 4), hm))); + bits.b1.val[1] = vaddq_u8(shift, vorrq_u8(bits.b1.val[1], vandq_u8(vshlq_n_u8(hbits.val[1], 4), hm))); + shift = vdupq_n_u8((x[i].scales[4*j+1] & 1) << 5); + bits.b1.val[2] = vaddq_u8(shift, vorrq_u8(bits.b1.val[2], vandq_u8(vshlq_n_u8(hbits.val[0], 3), hm))); + bits.b1.val[3] = vaddq_u8(shift, vorrq_u8(bits.b1.val[3], vandq_u8(vshlq_n_u8(hbits.val[1], 3), hm))); + for (int k = 0; k < 4; ++k) bits.b1.val[k] = vqtbl4q_s8(values, bits.b1.val[k]); + shift = vdupq_n_u8((x[i].scales[4*j+2] & 1) << 5); + bits.b2.val[0] = vaddq_u8(shift, vorrq_u8(bits.b2.val[0], vandq_u8(vshlq_n_u8(hbits.val[0], 2), hm))); + bits.b2.val[1] = vaddq_u8(shift, vorrq_u8(bits.b2.val[1], vandq_u8(vshlq_n_u8(hbits.val[1], 2), hm))); + shift = vdupq_n_u8((x[i].scales[4*j+3] & 1) << 5); + bits.b2.val[2] = vaddq_u8(shift, vorrq_u8(bits.b2.val[2], vandq_u8(vshlq_n_u8(hbits.val[0], 1), hm))); + bits.b2.val[3] = vaddq_u8(shift, vorrq_u8(bits.b2.val[3], vandq_u8(vshlq_n_u8(hbits.val[1], 1), hm))); + for (int k = 0; k < 4; ++k) bits.b2.val[k] = vqtbl4q_s8(values, bits.b2.val[k]); } Q4bits bits; - const int8x16x2_t values; - const uint8x16_t hshuff = vreinterpretq_u8_u32(uint32x4_t{0x09010800, 0x0b030a02, 0x0d050c04, 0x0f070e06}); + const int8x16x4_t values; const uint8x16_t hm = vdupq_n_u8(0x10); const uint16x8_t mask = vdupq_n_u16(254); const int16x8_t m127 = vdupq_n_s16(-127); uint8x16x2_t hbits; + uint8x16_t sas; }; @@ -13069,6 +13197,91 @@ void mul_mat_iq4_ks_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& i } template <int nrc_y> +void mul_mat_iq5_ks_r4_q8_k_neon(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + GGML_ASSERT(nrc_x%4 == 0); + Q8<nrc_y, block_q8_K> q8(info); + auto m4 = vdupq_n_u8(0xf); + auto m10 = vdupq_n_u8(0x10); + auto values = vld1q_s8_x2(iq5nl_values); + int nbl = n / QK_K; + int8x16_t qx[8]; + int16x8x4_t iscales; + int32x4x4_t scales; + float32x4_t acc[nrc_y] = {}; + int32x4_t isum[nrc_y] = {}; + for (int ix = 0; ix < nrc_x; ix += 4) { + auto dptr = (const float *)((const char *)vx + ix*bx); + auto d4 = vld1q_f32(dptr); + const block_iq5_ks_r4 * iq5 = (const block_iq5_ks_r4 *)(dptr + 4); + for (int ibl = 0; ibl < nbl; ++ibl) { + auto sas = vld1q_u8_x2(iq5[ibl].scales); + auto scale = vandq_u8(sas.val[0], vdupq_n_u8(254)); + iscales.val[0] = vaddq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_low_u8 (scale))), vdupq_n_s16(-127)); + iscales.val[1] = vaddq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(scale))), vdupq_n_s16(-127)); + scale = vandq_u8(sas.val[1], vdupq_n_u8(254)); + iscales.val[2] = vaddq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_low_u8 (scale))), vdupq_n_s16(-127)); + iscales.val[3] = vaddq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(scale))), vdupq_n_s16(-127)); + // Adding the block shifts costs us ~9% in performance drop. + // Is there a better way? + sas.val[0] = vshlq_n_u8(vandq_u8(sas.val[0], vdupq_n_u8(1)), 1); + sas.val[1] = vshlq_n_u8(vandq_u8(sas.val[1], vdupq_n_u8(1)), 1); + { + auto s16_1 = vmulq_s16(iscales.val[0], vmovl_u8(vget_low_u8 (sas.val[0]))); + auto s16_2 = vmulq_s16(iscales.val[1], vmovl_u8(vget_high_u8(sas.val[0]))); + auto s16_3 = vmulq_s16(iscales.val[2], vmovl_u8(vget_low_u8 (sas.val[1]))); + auto s16_4 = vmulq_s16(iscales.val[3], vmovl_u8(vget_high_u8(sas.val[1]))); + for (int iy = 0; iy < nrc_y; ++iy) { + auto bsums = vld1q_s16_x2(q8.y[iy][ibl].bsums); + auto bs = vpaddq_s16(bsums.val[0], bsums.val[1]); + auto b8 = vget_low_s16(bs); + isum[iy] = vmlal_lane_s16(isum[iy], vget_low_s16 (s16_1), b8, 0); + isum[iy] = vmlal_lane_s16(isum[iy], vget_high_s16(s16_1), b8, 1); + isum[iy] = vmlal_lane_s16(isum[iy], vget_low_s16 (s16_2), b8, 2); + isum[iy] = vmlal_lane_s16(isum[iy], vget_high_s16(s16_2), b8, 3); + b8 = vget_high_s16(bs); + isum[iy] = vmlal_lane_s16(isum[iy], vget_low_s16 (s16_3), b8, 0); + isum[iy] = vmlal_lane_s16(isum[iy], vget_high_s16(s16_3), b8, 1); + isum[iy] = vmlal_lane_s16(isum[iy], vget_low_s16 (s16_4), b8, 2); + isum[iy] = vmlal_lane_s16(isum[iy], vget_high_s16(s16_4), b8, 3); + } + } + for (int is = 0; is < 2; ++is) { + scales.val[0] = vmovl_s16(vget_low_s16 (iscales.val[2*is+0])); + scales.val[1] = vmovl_s16(vget_high_s16(iscales.val[2*is+0])); + scales.val[2] = vmovl_s16(vget_low_s16 (iscales.val[2*is+1])); + scales.val[3] = vmovl_s16(vget_high_s16(iscales.val[2*is+1])); + for (int ib = 0; ib < 4; ++ib) { + auto lbits = vld1q_u8_x4(iq5[ibl].qs + 256*is + 64*ib); + auto hbits = vld1q_u8(iq5[ibl].qh + 64*is + 16*ib); + qx[0] = vorrq_u8(vandq_u8(lbits.val[0], m4), vandq_u8(m10, vshlq_n_u8(hbits, 4))); + qx[1] = vorrq_u8(vandq_u8(lbits.val[1], m4), vandq_u8(m10, vshlq_n_u8(hbits, 2))); + qx[2] = vorrq_u8(vandq_u8(lbits.val[2], m4), vandq_u8(m10, hbits)); + qx[3] = vorrq_u8(vandq_u8(lbits.val[3], m4), vandq_u8(m10, vshrq_n_u8(hbits, 2))); + qx[4] = vorrq_u8(vshrq_n_u8(lbits.val[0], 4), vandq_u8(m10, vshlq_n_u8(hbits, 3))); + qx[5] = vorrq_u8(vshrq_n_u8(lbits.val[1], 4), vandq_u8(m10, vshlq_n_u8(hbits, 1))); + qx[6] = vorrq_u8(vshrq_n_u8(lbits.val[2], 4), vandq_u8(m10, vshrq_n_u8(hbits, 1))); + qx[7] = vorrq_u8(vshrq_n_u8(lbits.val[3], 4), vandq_u8(m10, vshrq_n_u8(hbits, 3))); + for (int l = 0; l < 8; ++l) qx[l] = vqtbl2q_s8(values, qx[l]); + for (int iy = 0; iy < nrc_y; ++iy) { + auto y = vld1q_s8_x2(q8.y[iy][ibl].qs+128*is+32*ib); + auto sumi = interleaved_dotq(qx, y); + isum[iy] = vmlaq_s32(isum[iy], scales.val[ib], sumi); + } + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + acc[iy] = vfmaq_f32(acc[iy], vdupq_n_f32(q8.scale(iy, ibl)), vcvtq_f32_s32(isum[iy])); + isum[iy] = vdupq_n_s32(0); + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + info.store(ix, iy, vmulq_f32(d4, acc[iy])); + acc[iy] = vdupq_n_f32(0.f); + } + } +} + +template <int nrc_y> static void mul_mat_iq2_xxs_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { GGML_ASSERT(nrc_x%4 == 0); Q8<nrc_y, block_q8_K> q8(info); @@ -15274,6 +15487,10 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) { SET_MUL_MAT_FUNCTIONS(m, mul_mat_iq5_k_r4_q8_k); expected_Btype = GGML_TYPE_Q8_K; break; + case GGML_TYPE_IQ5_KS_R4: + SET_MUL_MAT_FUNCTIONS(m, mul_mat_iq5_ks_r4_q8_k_neon); + expected_Btype = GGML_TYPE_Q8_K; + break; case GGML_TYPE_Q4_0_R8: SET_MUL_MAT_FUNCTIONS_T(m, mul_mat_qx_r8_q8_0, Q4_0_R8_Dequantizer); expected_Btype = GGML_TYPE_Q8_0_X4; diff --git a/ggml/src/iqk/iqk_quantize.cpp b/ggml/src/iqk/iqk_quantize.cpp index 78b25525..93aa2180 100644 --- a/ggml/src/iqk/iqk_quantize.cpp +++ b/ggml/src/iqk/iqk_quantize.cpp @@ -5628,7 +5628,8 @@ void quantize_row_iq5_k_r4(const float * x, void * y, int64_t k) { } namespace { -inline void convert_iq5_k(const block_iq5_k& x, uint8_t * L) { +template <typename Block> +inline void convert_iq5_k(const Block& x, uint8_t * L) { const uint8_t * qs = x.qs; const uint8_t * qh = x.qh; int shift = 0; @@ -5752,6 +5753,126 @@ void vec_dot_iq5_k_r4_q8_k(int n, float * s, size_t bs, const void * vx, size_t } // +// ========================================= iq5_ks_r4 +// + +void quantize_row_iq5_ks_r4_ref(const float * x, block_iq5_ks_r4 * y, int64_t k) { + quantize_iq5_ks_r4(x, (void *)y, 4, k/4, nullptr); +} + +void quantize_row_iq5_ks_r4(const float * x, void * y, int64_t k) { + quantize_iq5_ks_r4(x, y, 4, k/4, nullptr); +} + +static void repack_iq5_ks(int nrows, int n_per_row, const block_iq5_ks * x, block_iq5_ks_r4 * y, [[maybe_unused]] bool online) { + GGML_ASSERT(nrows%4 == 0); + GGML_ASSERT(n_per_row%QK_K == 0); + auto row_size = ggml_row_size(GGML_TYPE_IQ5_KS, n_per_row); + int nblock = n_per_row/QK_K; + const block_iq5_ks * x4[4]; + uint8_t L[QK_K]; + char * cy = (char *)y; + const char * cx = (const char *)x; + for (int row = 0; row < nrows; row += 4) { + float * dptr = (float *)cy; + block_iq5_ks_r4 * y = (block_iq5_ks_r4 *)(dptr + 4); + for (int k = 0; k < 4; ++k) { + auto dk = (const float *)(cx + k*row_size); + dptr[k] = dk[0]; + x4[k] = (const block_iq5_ks *)(dk + 1); + } + for (int ibl = 0; ibl < nblock; ++ibl) { + for (int k = 0; k < 4; ++k) { + convert_iq5_k(x4[k][ibl], L); + for (int ib = 0; ib < QK_K/32; ++ib) { + y[ibl].scales[4*ib+k] = x4[k][ibl].scales[ib]; + for (int i = 0; i < 4; ++i) { + y[ibl].qs[64*ib+4*k+i+ 0] = (L[32*ib+i+ 0] & 0xf) | ((L[32*ib+i+ 8] & 0xf) << 4); // 0....3 + 8...11 from each row + y[ibl].qs[64*ib+4*k+i+16] = (L[32*ib+i+16] & 0xf) | ((L[32*ib+i+24] & 0xf) << 4); // 16...19 + 24...27 from each row + y[ibl].qs[64*ib+4*k+i+32] = (L[32*ib+i+ 4] & 0xf) | ((L[32*ib+i+12] & 0xf) << 4); // 4....7 + 12...15 from each row + y[ibl].qs[64*ib+4*k+i+48] = (L[32*ib+i+20] & 0xf) | ((L[32*ib+i+28] & 0xf) << 4); // 20...23 + 28...31 from each row + y[ibl].qh[16*ib+4*k+i ] = ((L[32*ib+i+ 0] >> 4) << 0) | ((L[32*ib+i+ 8] >> 4) << 1) | ((L[32*ib+i+16] >> 4) << 2) | ((L[32*ib+i+24] >> 4) << 3) + | ((L[32*ib+i+ 4] >> 4) << 4) | ((L[32*ib+i+12] >> 4) << 5) | ((L[32*ib+i+20] >> 4) << 6) | ((L[32*ib+i+28] >> 4) << 7); + } + } + } + } + cx += 4*row_size; + cy += 4*row_size; + } +} + +size_t quantize_iq5_ks_r4(const float * src, void * dst, int64_t nrows, int64_t n_per_row, const float * imatrix) { + GGML_ASSERT(nrows%4 == 0); + GGML_ASSERT(n_per_row%QK_K == 0); + char * qcur = (char *)dst; + auto row_size = ggml_row_size(GGML_TYPE_IQ5_KS, n_per_row); + std::vector<char> qtmp(4*row_size); + for (int row = 0; row < nrows; row += 4) { + quantize_iq5_ks(src, (void *)qtmp.data(), 4, n_per_row, imatrix); + repack_iq5_ks(4, n_per_row, (const block_iq5_ks *)qtmp.data(), (block_iq5_ks_r4 *)qcur, false); + qcur += 4*row_size; + src += 4*n_per_row; + } + return nrows*row_size; +} + +void dequantize_row_iq5_ks_r4(const block_iq5_ks_r4 * x, float * y, int64_t k) { + auto n_per_row = k/4; + float * y4[4] = {y, y + n_per_row, y + 2*n_per_row, y + 3*n_per_row}; + //auto row_size = ggml_row_size(GGML_TYPE_IQ5_KS, n_per_row); + int nblock = n_per_row/QK_K; + const float * dptr = (const float *)x; + x = (const block_iq5_ks_r4 *)(dptr + 4); + for (int ibl = 0; ibl < nblock; ++ibl) { + for (int k = 0; k < 4; ++k) { + const float d = dptr[k]; + //if (!isfinite(d)) { + // printf("Oops: d = %g for ibl = %d, k = %d\n", d, ibl, k); exit(1); + //} + for (int ib = 0; ib < QK_K/32; ++ib) { + uint8_t sc = x[ibl].scales[4*ib+k]; + float dl = d * ((sc & 254) - 127); + //if (!isfinite(dl)) { + // printf("Oops: dl = %g for ibl = %d, k = %d, ib = %d, d = %g, sc = %u\n", dl, ibl, k, ib, d, sc); exit(1); + //} + auto values = iq5nl_values + ((sc & 1) << 5); + for (int i = 0; i < 4; ++i) { + y4[k][QK_K*ibl+32*ib+i+ 0] = dl * values[(x[ibl].qs[64*ib+4*k+i+ 0] & 0xf) | (((x[ibl].qh[16*ib+4*k+i] >> 0) & 1) << 4)]; + y4[k][QK_K*ibl+32*ib+i+ 8] = dl * values[(x[ibl].qs[64*ib+4*k+i+ 0] >> 4) | (((x[ibl].qh[16*ib+4*k+i] >> 1) & 1) << 4)]; + y4[k][QK_K*ibl+32*ib+i+16] = dl * values[(x[ibl].qs[64*ib+4*k+i+16] & 0xf) | (((x[ibl].qh[16*ib+4*k+i] >> 2) & 1) << 4)]; + y4[k][QK_K*ibl+32*ib+i+24] = dl * values[(x[ibl].qs[64*ib+4*k+i+16] >> 4) | (((x[ibl].qh[16*ib+4*k+i] >> 3) & 1) << 4)]; + y4[k][QK_K*ibl+32*ib+i+ 4] = dl * values[(x[ibl].qs[64*ib+4*k+i+32] & 0xf) | (((x[ibl].qh[16*ib+4*k+i] >> 4) & 1) << 4)]; + y4[k][QK_K*ibl+32*ib+i+12] = dl * values[(x[ibl].qs[64*ib+4*k+i+32] >> 4) | (((x[ibl].qh[16*ib+4*k+i] >> 5) & 1) << 4)]; + y4[k][QK_K*ibl+32*ib+i+20] = dl * values[(x[ibl].qs[64*ib+4*k+i+48] & 0xf) | (((x[ibl].qh[16*ib+4*k+i] >> 6) & 1) << 4)]; + y4[k][QK_K*ibl+32*ib+i+28] = dl * values[(x[ibl].qs[64*ib+4*k+i+48] >> 4) | (((x[ibl].qh[16*ib+4*k+i] >> 7) & 1) << 4)]; + } + //for (int i = 0; i < 32; ++i) { + // if (!isfinite(y4[k][QK_K*ibl+32*ib+i])) { + // printf("Oops: y4[%d][%d, %d, %d] = %g\n", k, ibl, ib, i, y4[k][QK_K*ibl+32*ib+i]); + // printf("d = %g, dl = %g\n", d, dl); + // exit(1); + // } + //} + } + } + } +} + +void vec_dot_iq5_ks_r4_q8_k(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) { +#if GGML_USE_IQK_MULMAT + if (iqk_mul_mat(1, 1, n, GGML_TYPE_IQ5_KS_R4, vx, 0, GGML_TYPE_Q8_K, vy, 0, s, 0, 0, 1)) { + return; + } +#endif + GGML_ASSERT(n%QK4_NL == 0); + GGML_ASSERT(nrc == 1); + GGML_UNUSED(bs); + GGML_UNUSED(bx); + GGML_UNUSED(by); +} + +// // ========================================= q8_k_r8 // @@ -7182,6 +7303,7 @@ const Repack * get_repack_info(ggml_type type) { { GGML_TYPE_IQ5_K, { GGML_TYPE_IQ5_K_R4, 4, (Repack::repack_func)repack_iq5_k} }, { GGML_TYPE_IQ4_XS, { GGML_TYPE_IQ4_XS_R8, 8, (Repack::repack_func)repack_iq4_xs} }, { GGML_TYPE_IQ4_KS, { GGML_TYPE_IQ4_KS_R4, 4, (Repack::repack_func)repack_iq4_ks} }, + { GGML_TYPE_IQ5_KS, { GGML_TYPE_IQ5_KS_R4, 4, (Repack::repack_func)repack_iq5_ks} }, { GGML_TYPE_IQ4_NL, { GGML_TYPE_IQ4_NL_R4, 4, (Repack::repack_func)repack_iq4_nl} }, { GGML_TYPE_IQ2_BN, { GGML_TYPE_IQ2_BN_R4, 4, (Repack::repack_func)repack_iq2_bn} }, { GGML_TYPE_IQ2_XXS,{ GGML_TYPE_IQ2_XXS_R4,4, (Repack::repack_func)repack_iq2_xxs} }, diff --git a/ggml/src/iqk/iqk_quantize.h b/ggml/src/iqk/iqk_quantize.h index 0533d1f7..9c274d4b 100644 --- a/ggml/src/iqk/iqk_quantize.h +++ b/ggml/src/iqk/iqk_quantize.h @@ -181,6 +181,12 @@ size_t quantize_iq4_ks_r4(const float * GGML_RESTRICT src, void * GGML_RESTRICT void dequantize_row_iq4_ks_r4(const block_iq4_ks_r4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); void vec_dot_iq4_ks_r4_q8_k(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); +void quantize_row_iq5_ks_r4_ref(const float * GGML_RESTRICT x, block_iq5_ks_r4 * GGML_RESTRICT y, int64_t k); +void quantize_row_iq5_ks_r4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +size_t quantize_iq5_ks_r4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +void dequantize_row_iq5_ks_r4(const block_iq5_ks_r4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +void vec_dot_iq5_ks_r4_q8_k(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); + void quantize_row_iq2_xxs_r4_ref(const float * GGML_RESTRICT x, block_iq2_xxs_r4 * GGML_RESTRICT y, int64_t k); void quantize_row_iq2_xxs_r4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); size_t quantize_iq2_xxs_r4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); diff --git a/include/llama.h b/include/llama.h index 98b08bbd..b6b408de 100644 --- a/include/llama.h +++ b/include/llama.h @@ -220,6 +220,7 @@ extern "C" { LLAMA_FTYPE_MOSTLY_IQ4_K_R4 = 340, // except 1d tensors LLAMA_FTYPE_MOSTLY_IQ5_K_R4 = 341, // except 1d tensors LLAMA_FTYPE_MOSTLY_IQ4_KS_R4 = 345, // except 1d tensors + LLAMA_FTYPE_MOSTLY_IQ5_KS_R4 = 350, // except 1d tensors LLAMA_FTYPE_MOSTLY_Q8_KV_R8 = 398, // except 1d tensors LLAMA_FTYPE_MOSTLY_Q8_K_R8 = 399, // except 1d tensors diff --git a/src/llama.cpp b/src/llama.cpp index 838451f6..b7534420 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -4372,6 +4372,7 @@ struct llama_model_loader { case GGML_TYPE_IQ4_XS: ftype = LLAMA_FTYPE_MOSTLY_IQ4_XS; break; case GGML_TYPE_IQ4_KS: ftype = LLAMA_FTYPE_MOSTLY_IQ4_KS; break; case GGML_TYPE_IQ4_KS_R4:ftype = LLAMA_FTYPE_MOSTLY_IQ4_KS_R4; break; + case GGML_TYPE_IQ5_KS_R4:ftype = LLAMA_FTYPE_MOSTLY_IQ5_KS_R4; break; case GGML_TYPE_IQ4_KSS: ftype = LLAMA_FTYPE_MOSTLY_IQ4_KSS; break; case GGML_TYPE_IQ5_KS: ftype = LLAMA_FTYPE_MOSTLY_IQ5_KS; break; case GGML_TYPE_IQ2_K: ftype = LLAMA_FTYPE_MOSTLY_IQ2_K; break; @@ -5109,6 +5110,7 @@ static std::string llama_model_ftype_name(llama_ftype ftype) { case LLAMA_FTYPE_MOSTLY_IQ4_XS: return "IQ4_XS - 4.25 bpw"; case LLAMA_FTYPE_MOSTLY_IQ4_KS: return "IQ4_KS - 4.25 bpw"; case LLAMA_FTYPE_MOSTLY_IQ4_KS_R4:return "IQ4_KS_R4 - 4.25 bpw"; + case LLAMA_FTYPE_MOSTLY_IQ5_KS_R4:return "IQ5_KS_R4 - 5.25 bpw"; case LLAMA_FTYPE_MOSTLY_IQ4_KSS: return "IQ4_KSS - 4.0 bpw"; case LLAMA_FTYPE_MOSTLY_IQ5_KS: return "IQ5_KS - 5.25 bpw"; case LLAMA_FTYPE_MOSTLY_IQ2_K: return "IQ2_K - 2.375 bpw"; @@ -18621,7 +18623,8 @@ static ggml_type change_type_if_necessary(ggml_type new_type, int nx, int ny) { new_type == GGML_TYPE_IQ4_K_R4|| new_type == GGML_TYPE_Q8_K_R8 || new_type == GGML_TYPE_IQ3_K_R4|| new_type == GGML_TYPE_IQ2_K_R4|| new_type == GGML_TYPE_IQ5_K_R4|| new_type == GGML_TYPE_IQ4_KS_R4 || new_type == GGML_TYPE_IQ3_XXS_R4 || new_type == GGML_TYPE_IQ2_XXS_R4 || new_type == GGML_TYPE_IQ2_XS_R4 || - new_type == GGML_TYPE_IQ2_S_R4|| new_type == GGML_TYPE_IQ3_S_R4|| new_type == GGML_TYPE_IQ5_KS) { + new_type == GGML_TYPE_IQ2_S_R4|| new_type == GGML_TYPE_IQ3_S_R4|| + new_type == GGML_TYPE_IQ5_KS || new_type == GGML_TYPE_IQ5_KS_R4) { if (nx % QK_K != 0) { LLAMA_LOG_WARN("\n\n%s : tensor cols %d x %d are not divisible by %d, required for %s", __func__, nx, ny, QK_K, ggml_type_name(new_type)); convert_incompatible_tensor = true; @@ -18664,6 +18667,7 @@ static ggml_type change_type_if_necessary(ggml_type new_type, int nx, int ny) { case GGML_TYPE_IQ4_K_R4: case GGML_TYPE_Q4_K_R4: case GGML_TYPE_IQ5_KS: + case GGML_TYPE_IQ5_KS_R4: case GGML_TYPE_Q4_K: new_type = GGML_TYPE_Q5_0; break; case GGML_TYPE_IQ5_K: case GGML_TYPE_IQ5_K_R4: @@ -18708,6 +18712,7 @@ static std::pair<ggml_type, int> interleaved_properties(ggml_type type) { { GGML_TYPE_IQ3_K_R4, { GGML_TYPE_IQ3_K, 4} }, { GGML_TYPE_IQ4_K_R4, { GGML_TYPE_IQ4_K, 4} }, { GGML_TYPE_IQ4_KS_R4, { GGML_TYPE_IQ4_KS, 4} }, + { GGML_TYPE_IQ5_KS_R4, { GGML_TYPE_IQ5_KS, 4} }, { GGML_TYPE_IQ5_K_R4, { GGML_TYPE_IQ5_K, 4} }, { GGML_TYPE_Q8_KV_R8, { GGML_TYPE_Q8_KV, 8} }, { GGML_TYPE_Q8_K_R8, { GGML_TYPE_Q8_K, 8} }, @@ -19254,6 +19259,7 @@ static llama_ftype repacked_ftype(llama_ftype ftype) { { LLAMA_FTYPE_MOSTLY_IQ4_K, LLAMA_FTYPE_MOSTLY_IQ4_K_R4 }, { LLAMA_FTYPE_MOSTLY_IQ5_K, LLAMA_FTYPE_MOSTLY_IQ5_K_R4 }, { LLAMA_FTYPE_MOSTLY_IQ4_KS, LLAMA_FTYPE_MOSTLY_IQ4_KS_R4 }, + { LLAMA_FTYPE_MOSTLY_IQ5_KS, LLAMA_FTYPE_MOSTLY_IQ5_KS_R4 }, { LLAMA_FTYPE_MOSTLY_Q8_KV, LLAMA_FTYPE_MOSTLY_Q8_KV_R8 }, }; if (auto it = k_map.find(ftype); it != k_map.end()) return it->second; @@ -19323,6 +19329,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s case LLAMA_FTYPE_MOSTLY_IQ4_XS: default_type = GGML_TYPE_IQ4_XS; break; case LLAMA_FTYPE_MOSTLY_IQ4_KS: default_type = GGML_TYPE_IQ4_KS; break; case LLAMA_FTYPE_MOSTLY_IQ4_KS_R4:default_type = GGML_TYPE_IQ4_KS_R4;break; + case LLAMA_FTYPE_MOSTLY_IQ5_KS_R4:default_type = GGML_TYPE_IQ5_KS_R4;break; case LLAMA_FTYPE_MOSTLY_IQ4_KSS: default_type = GGML_TYPE_IQ4_KSS; break; case LLAMA_FTYPE_MOSTLY_IQ5_KS: default_type = GGML_TYPE_IQ5_KS; break; case LLAMA_FTYPE_MOSTLY_IQ2_K: default_type = GGML_TYPE_IQ2_K; break; |