diff options
author | Kawrakow <iwankawrakow@gmail.com> | 2024-12-18 13:29:25 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-12-18 13:29:25 +0100 |
commit | 59d742b00fb48a3704f86c16afbb6b4ebcde8e68 (patch) | |
tree | ed95493397cc9ea45b8f8d897d1200d0b3588a86 | |
parent | 9b6d14a2991da41af4aa7ef64a712c63b73ad9fe (diff) |
IQ5_K_R4 (#149)
* iq5_k_r4: Zen4
Much slower than the others.
* iq5_k_r5: WIP
* Minor
* iq5_k_r4: fix AVX2 nrc_y = 1 case
* iq5_k_r4: better Zen4
But TG is still slower than iq5_k
* iq5_k_r4: slightly better AVX2
* iq5_k_r4: NEON
---------
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
-rw-r--r-- | examples/quantize/quantize.cpp | 1 | ||||
-rw-r--r-- | ggml/include/ggml.h | 2 | ||||
-rw-r--r-- | ggml/src/ggml-common.h | 10 | ||||
-rw-r--r-- | ggml/src/ggml-quants.c | 1 | ||||
-rw-r--r-- | ggml/src/ggml.c | 22 | ||||
-rw-r--r-- | ggml/src/iqk/iqk_mul_mat.cpp | 353 | ||||
-rw-r--r-- | ggml/src/iqk/iqk_quantize.cpp | 136 | ||||
-rw-r--r-- | ggml/src/iqk/iqk_quantize.h | 6 | ||||
-rw-r--r-- | include/llama.h | 1 | ||||
-rw-r--r-- | src/llama.cpp | 13 |
10 files changed, 532 insertions, 13 deletions
diff --git a/examples/quantize/quantize.cpp b/examples/quantize/quantize.cpp index 73485838..d505f493 100644 --- a/examples/quantize/quantize.cpp +++ b/examples/quantize/quantize.cpp @@ -61,6 +61,7 @@ static const std::vector<struct quant_option> QUANT_OPTIONS = { { "IQ4_K", LLAMA_FTYPE_MOSTLY_IQ4_K, " 4.5 bpw non-linear quantization", }, { "IQ4_K_R4", LLAMA_FTYPE_MOSTLY_IQ4_K_R4, "IQ4_K repacked", }, { "IQ5_K", LLAMA_FTYPE_MOSTLY_IQ5_K, " 5.5 bpw non-linear quantization", }, + { "IQ5_K_R4", LLAMA_FTYPE_MOSTLY_IQ5_K_R4, "IQ5_K repacked", }, { "IQ6_K", LLAMA_FTYPE_MOSTLY_IQ6_K, " 6.6 bpw non-linear quantization", }, { "Q4_K", LLAMA_FTYPE_MOSTLY_Q4_K_M, "alias for Q4_K_M", }, { "Q4_K_R4", LLAMA_FTYPE_MOSTLY_Q4_K_R4, "Q4_K_S repacked", }, diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 77ee0fb9..07142692 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -426,6 +426,7 @@ extern "C" { GGML_TYPE_IQ2_K_R4 = 337, GGML_TYPE_IQ3_K_R4 = 338, GGML_TYPE_IQ4_K_R4 = 339, + GGML_TYPE_IQ5_K_R4 = 340, GGML_TYPE_Q8_K_R8 = 399, GGML_TYPE_COUNT, }; @@ -502,6 +503,7 @@ extern "C" { GGML_FTYPE_MOSTLY_IQ2_K_R4 = 330, // except 1d tensors GGML_FTYPE_MOSTLY_IQ3_K_R4 = 331, // except 1d tensors GGML_FTYPE_MOSTLY_IQ4_K_R4 = 332, // except 1d tensors + GGML_FTYPE_MOSTLY_IQ5_K_R4 = 333, // except 1d tensors GGML_FTYPE_MOSTLY_Q8_K_R8 = 399, // except 1d tensors }; diff --git a/ggml/src/ggml-common.h b/ggml/src/ggml-common.h index 03cc3460..0af461c7 100644 --- a/ggml/src/ggml-common.h +++ b/ggml/src/ggml-common.h @@ -585,6 +585,16 @@ typedef struct { static_assert(sizeof(block_iq5_k) == sizeof(ggml_half) + sizeof(uint16_t) + QK_K/2 + QK_K/8 + 3*QK_K/64, "wrong iq5_k block size/padding"); typedef struct { + ggml_half d[4]; + uint8_t extra[8]; + uint8_t scales_h[QK_K/16]; + uint8_t scales_l[QK_K/8 ]; + uint8_t qs[QK_K*2]; + uint8_t qh[QK_K/2]; +} block_iq5_k_r4; +static_assert(sizeof(block_iq5_k_r4) == 4*sizeof(block_iq5_k), "wrong iq5_k_r4 block size/padding"); + +typedef struct { ggml_half d; uint16_t extra; int8_t scales[QK_K/16]; diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c index a3beba20..d0de3d0f 100644 --- a/ggml/src/ggml-quants.c +++ b/ggml/src/ggml-quants.c @@ -15210,6 +15210,7 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte case GGML_TYPE_IQ2_K_R4: break; case GGML_TYPE_IQ3_K_R4: break; case GGML_TYPE_IQ4_K_R4: break; + case GGML_TYPE_IQ5_K_R4: break; case GGML_TYPE_Q8_K_R8: break; case GGML_TYPE_BF16_R16: break; case GGML_TYPE_Q4_0_4_4: diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 45c873f2..526b1139 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -1399,6 +1399,19 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .nrows = 1, .row_meta_size = 0, }, + [GGML_TYPE_IQ5_K_R4] = { + .type_name = "iq5_k_r4", + .blck_size = QK_K, + .type_size = sizeof(block_iq5_k), + .is_quantized = true, + .to_float = (ggml_to_float_t) dequantize_row_iq5_k_r4, + .from_float = quantize_row_iq5_k_r4, + .from_float_ref = (ggml_from_float_t)quantize_row_iq5_k_r4_ref, + .vec_dot = vec_dot_iq5_k_r4_q8_k, + .vec_dot_type = GGML_TYPE_Q8_K, + .nrows = 1, + .row_meta_size = 0, + }, [GGML_TYPE_IQ6_K] = { .type_name = "iq6_k", .blck_size = QK_K, @@ -4193,6 +4206,7 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) { case GGML_FTYPE_MOSTLY_IQ3_K_R4: wtype = GGML_TYPE_IQ3_K_R4; break; case GGML_FTYPE_MOSTLY_IQ4_K_R4: wtype = GGML_TYPE_IQ4_K_R4; break; case GGML_FTYPE_MOSTLY_IQ5_K: wtype = GGML_TYPE_IQ5_K; break; + case GGML_FTYPE_MOSTLY_IQ5_K_R4: wtype = GGML_TYPE_IQ5_K_R4; break; case GGML_FTYPE_MOSTLY_IQ6_K: wtype = GGML_TYPE_IQ6_K; break; case GGML_FTYPE_MOSTLY_IQ3_S: wtype = GGML_TYPE_IQ3_S; break; case GGML_FTYPE_MOSTLY_IQ2_S: wtype = GGML_TYPE_IQ2_S; break; @@ -10732,6 +10746,7 @@ static void ggml_compute_forward_add( case GGML_TYPE_IQ3_K_R4: case GGML_TYPE_IQ4_K_R4: case GGML_TYPE_IQ5_K: + case GGML_TYPE_IQ5_K_R4: case GGML_TYPE_IQ6_K: case GGML_TYPE_IQ3_S: case GGML_TYPE_IQ2_S: @@ -11190,6 +11205,7 @@ static void ggml_compute_forward_add1( case GGML_TYPE_IQ3_K_R4: case GGML_TYPE_IQ4_K_R4: case GGML_TYPE_IQ5_K: + case GGML_TYPE_IQ5_K_R4: case GGML_TYPE_IQ6_K: case GGML_TYPE_IQ3_S: case GGML_TYPE_IQ2_S: @@ -11345,6 +11361,7 @@ static void ggml_compute_forward_acc( case GGML_TYPE_IQ3_K_R4: case GGML_TYPE_IQ4_K_R4: case GGML_TYPE_IQ5_K: + case GGML_TYPE_IQ5_K_R4: case GGML_TYPE_IQ6_K: case GGML_TYPE_IQ3_S: case GGML_TYPE_IQ2_S: @@ -14546,6 +14563,7 @@ static void ggml_compute_forward_out_prod( case GGML_TYPE_IQ3_K_R4: case GGML_TYPE_IQ4_K_R4: case GGML_TYPE_IQ5_K: + case GGML_TYPE_IQ5_K_R4: case GGML_TYPE_IQ6_K: case GGML_TYPE_IQ3_S: case GGML_TYPE_IQ2_S: @@ -14941,6 +14959,7 @@ static void ggml_compute_forward_set( case GGML_TYPE_IQ3_K_R4: case GGML_TYPE_IQ4_K_R4: case GGML_TYPE_IQ5_K: + case GGML_TYPE_IQ5_K_R4: case GGML_TYPE_IQ6_K: case GGML_TYPE_IQ3_S: case GGML_TYPE_IQ2_S: @@ -15230,6 +15249,7 @@ static void ggml_compute_forward_get_rows( case GGML_TYPE_IQ3_K_R4: case GGML_TYPE_IQ4_K_R4: case GGML_TYPE_IQ5_K: + case GGML_TYPE_IQ5_K_R4: case GGML_TYPE_IQ6_K: case GGML_TYPE_IQ3_S: case GGML_TYPE_IQ2_S: @@ -15848,6 +15868,7 @@ static void ggml_compute_forward_clamp( case GGML_TYPE_IQ3_K_R4: case GGML_TYPE_IQ4_K_R4: case GGML_TYPE_IQ5_K: + case GGML_TYPE_IQ5_K_R4: case GGML_TYPE_IQ6_K: case GGML_TYPE_IQ3_S: case GGML_TYPE_IQ2_S: @@ -22694,6 +22715,7 @@ size_t ggml_quantize_chunk( case GGML_TYPE_IQ3_K_R4:result = quantize_iq3_k_r4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ4_K_R4:result = quantize_iq4_k_r4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ5_K: result = quantize_iq5_k (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_IQ5_K_R4:result = quantize_iq5_k_r4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ6_K: result = quantize_iq6_k (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_Q4_0_4_4: result = quantize_q4_0_4x4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_Q4_0_4_8: result = quantize_q4_0_4x8(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index bfa68c1d..3e8cbf06 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -185,6 +185,7 @@ struct MulMat { case GGML_TYPE_IQ2_K_R4: case GGML_TYPE_IQ3_K_R4: case GGML_TYPE_IQ4_K_R4: + case GGML_TYPE_IQ5_K_R4: case GGML_TYPE_IQ2_BN_R4: return 4; case GGML_TYPE_Q8_K_R8: return 8; case GGML_TYPE_BF16_R16: return 16; @@ -3959,7 +3960,8 @@ static void mul_mat_bf16_r16_bf16(int n, const void * vx, size_t bx, const DataI #endif template <int nrc_y> -IQK_ALWAYS_INLINE void iq234_k_accum_mins(int ibl, __m256i i8scales1, __m256i i8scales2, const Q8<nrc_y, block_q8_K>& q8, __m256i shuff, +//IQK_ALWAYS_INLINE void iq234_k_accum_mins(int ibl, __m256i i8scales1, __m256i i8scales2, const Q8<nrc_y, block_q8_K>& q8, __m256i shuff, +inline void iq234_k_accum_mins(int ibl, __m256i i8scales1, __m256i i8scales2, const Q8<nrc_y, block_q8_K>& q8, __m256i shuff, __m256i * isum, int16_t min) { auto t1 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(i8scales1, 0)), shuff); // blocks 0, 1, 2, 3 for each row auto t2 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(i8scales1, 1)), shuff); // blocks 4, 5, 6, 7 for each row @@ -4008,6 +4010,46 @@ IQK_ALWAYS_INLINE void iq234_k_accum_mins(int ibl, __m256i i8scales1, __m256i i8 } template <int nrc_y> +inline void iq2345_k_accum_mins(int ibl, __m256i i8scales1, __m256i i8scales2, const Q8<nrc_y, block_q8_K>& q8, __m256i shuff, + __m256i extra, __m256i * isum, int8_t min, int8_t delta) { + auto mask = _mm256_set_epi64x(0x0808080808080808, 0x0404040404040404, 0x0202020202020202, 0x0101010101010101); + auto vdelta = _mm256_set1_epi8(delta); + auto vmin = _mm256_set1_epi8(min); + auto min1 = _mm256_add_epi8(vmin, _mm256_and_si256(vdelta, _mm256_cmpeq_epi8(_mm256_and_si256(extra, mask), mask))); + auto min2 = _mm256_add_epi8(vmin, _mm256_and_si256(vdelta, _mm256_cmpeq_epi8(_mm256_and_si256(_mm256_srli_epi16(extra, 4), mask), mask))); + auto t1 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(i8scales1, 0)), shuff); // blocks 0, 1, 2, 3 for each row + auto t2 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(i8scales1, 1)), shuff); // blocks 4, 5, 6, 7 for each row + auto t3 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(i8scales2, 0)), shuff); // blocks 8, 9, 10, 11 for each row + auto t4 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(i8scales2, 1)), shuff); // blocks 12, 13, 14, 15 for each row + auto m1 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(min1, 0)), shuff); // blocks 0, 1, 2, 3 for each row + auto m2 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(min1, 1)), shuff); // blocks 4, 5, 6, 7 for each row + auto m3 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(min2, 0)), shuff); // blocks 8, 9, 10, 11 for each row + auto m4 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(min2, 1)), shuff); // blocks 12, 13, 14, 15 for each row + auto s1 = _mm256_mullo_epi16(MM256_SET_M128I(_mm256_extracti128_si256(m3, 0), _mm256_extracti128_si256(m1, 0)), + MM256_SET_M128I(_mm256_extracti128_si256(t3, 0), _mm256_extracti128_si256(t1, 0))); // blocks 0, 1, 8, 9 + auto s2 = _mm256_mullo_epi16(MM256_SET_M128I(_mm256_extracti128_si256(m3, 1), _mm256_extracti128_si256(m1, 1)), + MM256_SET_M128I(_mm256_extracti128_si256(t3, 1), _mm256_extracti128_si256(t1, 1))); // blocks 2, 3, 10, 11 + auto s3 = _mm256_mullo_epi16(MM256_SET_M128I(_mm256_extracti128_si256(m4, 0), _mm256_extracti128_si256(m2, 0)), + MM256_SET_M128I(_mm256_extracti128_si256(t4, 0), _mm256_extracti128_si256(t2, 0))); // blocks 4, 5, 12, 13 + auto s4 = _mm256_mullo_epi16(MM256_SET_M128I(_mm256_extracti128_si256(m4, 1), _mm256_extracti128_si256(m2, 1)), + MM256_SET_M128I(_mm256_extracti128_si256(t4, 1), _mm256_extracti128_si256(t2, 1))); // blocks 6, 7, 14, 15 + for (int iy = 0; iy < nrc_y; ++iy) { + auto bsums = q8.load_bsums(iy, ibl); +#ifdef HAVE_FANCY_SIMD + isum[iy] = _mm256_dpwssd_epi32(isum[iy], s1, _mm256_shuffle_epi32(bsums, 0x00)); + isum[iy] = _mm256_dpwssd_epi32(isum[iy], s2, _mm256_shuffle_epi32(bsums, 0x55)); + isum[iy] = _mm256_dpwssd_epi32(isum[iy], s3, _mm256_shuffle_epi32(bsums, 0xaa)); + isum[iy] = _mm256_dpwssd_epi32(isum[iy], s4, _mm256_shuffle_epi32(bsums, 0xff)); +#else + isum[iy] = _mm256_add_epi32(isum[iy], _mm256_madd_epi16(s1, _mm256_shuffle_epi32(bsums, 0x00))); + isum[iy] = _mm256_add_epi32(isum[iy], _mm256_madd_epi16(s2, _mm256_shuffle_epi32(bsums, 0x55))); + isum[iy] = _mm256_add_epi32(isum[iy], _mm256_madd_epi16(s3, _mm256_shuffle_epi32(bsums, 0xaa))); + isum[iy] = _mm256_add_epi32(isum[iy], _mm256_madd_epi16(s4, _mm256_shuffle_epi32(bsums, 0xff))); +#endif + } +} + +template <int nrc_y> static void mul_mat_iq2_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { GGML_ASSERT(nrc_x%4 == 0); Q8<nrc_y, block_q8_K> q8(info); @@ -4268,6 +4310,159 @@ static void mul_mat_iq4_k_r4_q8_k(int n, const void * vx, size_t bx, const DataI } } +template <int nrc_y> +static void mul_mat_iq5_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + GGML_ASSERT(nrc_x%4 == 0); + Q8<nrc_y, block_q8_K> q8(info); + auto m4 = _mm256_set1_epi8(0xf); + auto m30 = _mm256_set1_epi8(0x30); + auto m32 = _mm256_set1_epi8(32); + auto ms = _mm256_set1_epi8(2); + auto shift_shuffle = _mm256_set_epi64x(0x0707070706060606, 0x0505050504040404, 0x0303030302020202, 0x0101010100000000); + __m256i values[2]; + { + auto val1 = _mm_loadu_si128((const __m128i *)iq5nl_values+0); + auto val2 = _mm_loadu_si128((const __m128i *)iq5nl_values+1); + values[0] = MM256_SET_M128I(val1, val1); + values[1] = MM256_SET_M128I(val2, val2); +#ifdef HAVE_FANCY_SIMD + values[0] = _mm256_sub_epi8(values[0], _mm256_set1_epi8(-128)); + values[1] = _mm256_sub_epi8(values[1], _mm256_set1_epi8(-128)); +#endif + } +#ifdef HAVE_FANCY_SIMD + static const uint8_t k_shuff[32] = {0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15, 0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15}; + auto shuff = _mm256_loadu_si256((const __m256i *)k_shuff); +#else + auto s_shuffle = _mm256_set_epi64x(0x0f0e0f0e0d0c0d0c, 0x0b0a0b0a09080908, 0x0706070605040504, 0x0302030201000100); +#endif + int nbl = n / QK_K; + __m256 acc[nrc_y] = {}; + __m256i qx[4]; + uint64_t stored_scales[8]; + for (int ix = 0; ix < nrc_x; ix += 4) { + const block_iq5_k_r4 * iq5 = (const block_iq5_k_r4 *)((const char *)vx + (ix+0)*bx); + for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256 + auto dl = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq5[ibl].d)); + auto d4 = _mm256_set_m128(dl, dl); + auto extra = _mm256_set1_epi64x(*(const uint64_t *)iq5[ibl].extra); + auto slbits = _mm256_loadu_si256((const __m256i *)iq5[ibl].scales_l); + auto sl1 = _mm256_and_si256(slbits, m4); + auto sl2 = _mm256_and_si256(_mm256_srli_epi16(slbits, 4), m4); + auto shbits = _mm_loadu_si128((const __m128i*)iq5[ibl].scales_h); + auto sh = MM256_SET_M128I(_mm_srli_epi16(shbits, 2), shbits); + auto i8scales1 = _mm256_sub_epi8(_mm256_or_si256(sl1, _mm256_and_si256(m30, _mm256_slli_epi16(sh, 4))), m32); + auto i8scales2 = _mm256_sub_epi8(_mm256_or_si256(sl2, _mm256_and_si256(m30, sh)), m32); + _mm256_storeu_si256((__m256i *)stored_scales+0, i8scales1); + _mm256_storeu_si256((__m256i *)stored_scales+1, i8scales2); + __m256i isum[nrc_y] = {}; +#ifdef HAVE_FANCY_SIMD + if constexpr (nrc_y == 1) { + iq234_k_accum_mins(ibl, i8scales1, i8scales2, q8, shuff, isum, -128); + } else { + iq2345_k_accum_mins(ibl, i8scales1, i8scales2, q8, shuff, extra, isum, -128, 2); + } +#endif + for (int ib = 0; ib < QK_K/32; ++ib) { +#ifdef HAVE_FANCY_SIMD + auto scales = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(stored_scales + ib))); +#else + auto scales = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm_set1_epi64x(stored_scales[ib])), s_shuffle); +#endif + auto lbits1 = _mm256_loadu_si256((const __m256i *)iq5[ibl].qs+2*ib+0); + auto lbits2 = _mm256_loadu_si256((const __m256i *)iq5[ibl].qs+2*ib+1); + auto hbits = _mm_loadu_si128((const __m128i *)iq5[ibl].qh+ib); + auto hb = MM256_SET_M128I(_mm_srli_epi16(hbits, 2), hbits); + qx[0] = _mm256_and_si256(lbits1, m4); + qx[1] = _mm256_and_si256(lbits2, m4); + qx[2] = _mm256_and_si256(_mm256_srli_epi16(lbits1, 4), m4); + qx[3] = _mm256_and_si256(_mm256_srli_epi16(lbits2, 4), m4); + +#ifdef HAVE_FANCY_SIMD + auto q5vl = _mm256_shuffle_epi8(values[0], qx[0]); + auto q5vh = _mm256_shuffle_epi8(values[1], qx[0]); + qx[0] = _mm256_mask_blend_epi8(_mm256_cmpeq_epi8_mask(_mm256_and_si256(hb, _mm256_set1_epi8(0x01)), _mm256_set1_epi8(0x01)), q5vl, q5vh); + + q5vl = _mm256_shuffle_epi8(values[0], qx[1]); + q5vh = _mm256_shuffle_epi8(values[1], qx[1]); + qx[1] = _mm256_mask_blend_epi8(_mm256_cmpeq_epi8_mask(_mm256_and_si256(hb, _mm256_set1_epi8(0x10)), _mm256_set1_epi8(0x10)), q5vl, q5vh); + + q5vl = _mm256_shuffle_epi8(values[0], qx[2]); + q5vh = _mm256_shuffle_epi8(values[1], qx[2]); + qx[2] = _mm256_mask_blend_epi8(_mm256_cmpeq_epi8_mask(_mm256_and_si256(hb, _mm256_set1_epi8(0x02)), _mm256_set1_epi8(0x02)), q5vl, q5vh); + + q5vl = _mm256_shuffle_epi8(values[0], qx[3]); + q5vh = _mm256_shuffle_epi8(values[1], qx[3]); + qx[3] = _mm256_mask_blend_epi8(_mm256_cmpeq_epi8_mask(_mm256_and_si256(hb, _mm256_set1_epi8(0x20)), _mm256_set1_epi8(0x20)), q5vl, q5vh); + + if constexpr (nrc_y == 1) { + auto shift = _mm256_and_si256(ms, _mm256_slli_epi16(extra, 1)); extra = _mm256_srli_epi16(extra, 1); + shift = _mm256_shuffle_epi8(shift, shift_shuffle); + qx[0] = _mm256_add_epi8(qx[0], shift); + qx[1] = _mm256_add_epi8(qx[1], shift); + qx[2] = _mm256_add_epi8(qx[2], shift); + qx[3] = _mm256_add_epi8(qx[3], shift); + } +#else + + auto q5vl = _mm256_shuffle_epi8(values[0], qx[0]); + auto q5vh = _mm256_shuffle_epi8(values[1], qx[0]); + qx[0] = _mm256_blendv_epi8(q5vl, q5vh, _mm256_cmpeq_epi8(_mm256_and_si256(hb, _mm256_set1_epi8(0x01)), _mm256_set1_epi8(0x01))); + + q5vl = _mm256_shuffle_epi8(values[0], qx[1]); + q5vh = _mm256_shuffle_epi8(values[1], qx[1]); + qx[1] = _mm256_blendv_epi8(q5vl, q5vh, _mm256_cmpeq_epi8(_mm256_and_si256(hb, _mm256_set1_epi8(0x10)), _mm256_set1_epi8(0x10))); + + q5vl = _mm256_shuffle_epi8(values[0], qx[2]); + q5vh = _mm256_shuffle_epi8(values[1], qx[2]); + qx[2] = _mm256_blendv_epi8(q5vl, q5vh, _mm256_cmpeq_epi8(_mm256_and_si256(hb, _mm256_set1_epi8(0x02)), _mm256_set1_epi8(0x02))); + + q5vl = _mm256_shuffle_epi8(values[0], qx[3]); + q5vh = _mm256_shuffle_epi8(values[1], qx[3]); + qx[3] = _mm256_blendv_epi8(q5vl, q5vh, _mm256_cmpeq_epi8(_mm256_and_si256(hb, _mm256_set1_epi8(0x20)), _mm256_set1_epi8(0x20))); + + auto shift = _mm256_and_si256(ms, _mm256_slli_epi16(extra, 1)); extra = _mm256_srli_epi16(extra, 1); + shift = _mm256_shuffle_epi8(shift, shift_shuffle); + qx[0] = _mm256_add_epi8(qx[0], shift); + qx[1] = _mm256_add_epi8(qx[1], shift); + qx[2] = _mm256_add_epi8(qx[2], shift); + qx[3] = _mm256_add_epi8(qx[3], shift); + auto s1 = _mm256_sign_epi8(qx[0], qx[0]); + auto s2 = _mm256_sign_epi8(qx[1], qx[1]); + auto s3 = _mm256_sign_epi8(qx[2], qx[2]); + auto s4 = _mm256_sign_epi8(qx[3], qx[3]); +#endif + for (int iy = 0; iy < nrc_y; ++iy) { + auto y = _mm256_loadu_si256((const __m256i*)q8.y[iy][ibl].qs+ib); +#ifdef HAVE_FANCY_SIMD + auto sumi = _mm256_setzero_si256(); + sumi = _mm256_dpbusd_epi32(sumi, qx[0], _mm256_shuffle_epi32(y, 0x00)); + sumi = _mm256_dpbusd_epi32(sumi, qx[1], _mm256_shuffle_epi32(y, 0x55)); + sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(y, 0xaa)); + sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(y, 0xff)); + isum[iy] = _mm256_add_epi32(isum[iy], _mm256_mullo_epi32(scales, sumi)); +#else + auto sumi1 = _mm256_maddubs_epi16(s1, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), qx[0])); + auto sumi2 = _mm256_maddubs_epi16(s2, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), qx[1])); + auto sumi3 = _mm256_maddubs_epi16(s3, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), qx[2])); + auto sumi4 = _mm256_maddubs_epi16(s4, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), qx[3])); + isum[iy] = _mm256_add_epi32(isum[iy], _mm256_add_epi32(_mm256_madd_epi16(scales, sumi1), _mm256_madd_epi16(scales, sumi2))); + isum[iy] = _mm256_add_epi32(isum[iy], _mm256_add_epi32(_mm256_madd_epi16(scales, sumi3), _mm256_madd_epi16(scales, sumi4))); +#endif + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(isum[iy]), acc[iy]); + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1)); + acc[iy] = _mm256_setzero_ps(); + info.store(ix+0, iy, sum); + } + } +} + template <typename Bits> inline void multiply_add_1(int j, const Bits& bits, const __m256i * scales, const __m256i * q8, __m256i * sumi) { if (j == 0) { @@ -6371,6 +6566,18 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) { mm.funcs[7] = mul_mat_iq4_k_r4_q8_k<8>; expected_typeB = GGML_TYPE_Q8_K; break; + case GGML_TYPE_IQ5_K_R4: + assert (ne00 % QK_K == 0); + mm.funcs[0] = mul_mat_iq5_k_r4_q8_k<1>; + mm.funcs[1] = mul_mat_iq5_k_r4_q8_k<2>; + mm.funcs[2] = mul_mat_iq5_k_r4_q8_k<3>; + mm.funcs[3] = mul_mat_iq5_k_r4_q8_k<4>; + mm.funcs[4] = mul_mat_iq5_k_r4_q8_k<5>; + mm.funcs[5] = mul_mat_iq5_k_r4_q8_k<6>; + mm.funcs[6] = mul_mat_iq5_k_r4_q8_k<7>; + mm.funcs[7] = mul_mat_iq5_k_r4_q8_k<8>; + expected_typeB = GGML_TYPE_Q8_K; + break; case GGML_TYPE_IQ2_K_R4: assert (ne00 % QK_K == 0); mm.funcs[0] = mul_mat_iq2_k_r4_q8_k<1>; @@ -9071,18 +9278,23 @@ void mul_mat_iq4_xs_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& i } } -template <int nrc_y, bool is_iq2k> +template <int nrc_y, int k_shift> inline void iq3_4_add_shift(int ibl, const Q8<nrc_y, block_q8_K>& q8, const int8x16x4_t& i8scales, uint8x16_t extra, int32x4_t * isum) { - auto ms = is_iq2k ? vdupq_n_s8(5) : vdupq_n_s8(4); + auto ms = vdupq_n_s8(k_shift); int8x16_t s8_1, s8_2; - if constexpr (is_iq2k) { + if constexpr (k_shift == 5) { auto m1 = vdupq_n_u8(1); s8_1 = vmulq_s8(i8scales.val[0], vandq_s8(ms, vceqq_u8(vandq_u8(extra, m1), m1))); extra = vshrq_n_u8(extra, 2); s8_2 = vmulq_s8(i8scales.val[1], vandq_s8(ms, vceqq_u8(vandq_u8(extra, m1), m1))); extra = vshrq_n_u8(extra, 2); } else { - s8_1 = vmulq_s8(i8scales.val[0], vandq_u8(ms, vshlq_n_u8(extra, 2))); - s8_2 = vmulq_s8(i8scales.val[1], vandq_u8(ms, extra)); + if constexpr (k_shift == 4) { + s8_1 = vmulq_s8(i8scales.val[0], vandq_u8(ms, vshlq_n_u8(extra, 2))); + s8_2 = vmulq_s8(i8scales.val[1], vandq_u8(ms, extra)); + } else { + s8_1 = vmulq_s8(i8scales.val[0], vandq_u8(ms, vshlq_n_u8(extra, 1))); + s8_2 = vmulq_s8(i8scales.val[1], vandq_u8(ms, vshrq_n_u8(extra, 1))); + } } auto s16_1 = vmovl_s8(vget_low_s8 (s8_1)); auto s16_2 = vmovl_s8(vget_high_s8(s8_1)); @@ -9100,13 +9312,18 @@ inline void iq3_4_add_shift(int ibl, const Q8<nrc_y, block_q8_K>& q8, const int8 isum[iy] = vmlal_lane_s16(isum[iy], vget_low_s16 (s16_4), b8, 2); isum[iy] = vmlal_lane_s16(isum[iy], vget_high_s16(s16_4), b8, 3); } - if constexpr (is_iq2k) { + if constexpr (k_shift == 5) { auto m1 = vdupq_n_u8(1); s8_1 = vmulq_s8(i8scales.val[2], vandq_s8(ms, vceqq_u8(vandq_u8(extra, m1), m1))); extra = vshrq_n_u8(extra, 2); s8_2 = vmulq_s8(i8scales.val[3], vandq_s8(ms, vceqq_u8(vandq_u8(extra, m1), m1))); extra = vshrq_n_u8(extra, 2); } else { - s8_1 = vmulq_s8(i8scales.val[2], vandq_u8(ms, vshrq_n_u8(extra, 2))); - s8_2 = vmulq_s8(i8scales.val[3], vandq_u8(ms, vshrq_n_u8(extra, 4))); + if constexpr (k_shift == 4) { + s8_1 = vmulq_s8(i8scales.val[2], vandq_u8(ms, vshrq_n_u8(extra, 2))); + s8_2 = vmulq_s8(i8scales.val[3], vandq_u8(ms, vshrq_n_u8(extra, 4))); + } else { + s8_1 = vmulq_s8(i8scales.val[2], vandq_u8(ms, vshrq_n_u8(extra, 3))); + s8_2 = vmulq_s8(i8scales.val[3], vandq_u8(ms, vshrq_n_u8(extra, 5))); + } } s16_1 = vmovl_s8(vget_low_s8 (s8_1)); s16_2 = vmovl_s8(vget_high_s8(s8_1)); @@ -9162,7 +9379,7 @@ void mul_mat_iq2_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& in i8scales.val[3] = vaddq_s8(vshrq_n_u8(sl.val[1], 4), vdupq_n_s8(-8)); int32x4_t isum[nrc_y] = {}; if constexpr (nrc_y == 1) { - iq3_4_add_shift<nrc_y, true>(ibl, q8, i8scales, extra, isum); + iq3_4_add_shift<nrc_y, 5>(ibl, q8, i8scales, extra, isum); } for (int is = 0; is < 2; ++is) { i16scales.val[0] = vmovl_s8(vget_low_s8 (i8scales.val[2*is+0])); @@ -9275,7 +9492,7 @@ void mul_mat_iq3_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& in i8scales.val[3] = vmulq_s8(i8scales.val[3], vorrq_u8(vceqq_u8(vandq_u8(sh, smask.val[1]), smask.val[1]), vdupq_n_u8(1))); int32x4_t isum[nrc_y] = {}; if constexpr (nrc_y == 1) { - iq3_4_add_shift<nrc_y, false>(ibl, q8, i8scales, extra, isum); + iq3_4_add_shift<nrc_y, 4>(ibl, q8, i8scales, extra, isum); } for (int is = 0; is < 2; ++is) { i16scales.val[0] = vmovl_s8(vget_low_s8 (i8scales.val[2*is+0])); @@ -9382,7 +9599,7 @@ void mul_mat_iq4_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& in i8scales.val[3] = vaddq_s8(vorrq_u8(vshrq_n_u8(sl.val[1], 4), vandq_u8(vshrq_n_u8(sh, 2), m3)), m32); int32x4_t isum[nrc_y] = {}; if constexpr (nrc_y == 1) { - iq3_4_add_shift<nrc_y, false>(ibl, q8, i8scales, extra, isum); + iq3_4_add_shift<nrc_y, 4>(ibl, q8, i8scales, extra, isum); } for (int is = 0; is < 2; ++is) { i16scales.val[0] = vmovl_s8(vget_low_s8 (i8scales.val[2*is+0])); @@ -9443,6 +9660,114 @@ void mul_mat_iq4_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& in } } +template <int nrc_y> +void mul_mat_iq5_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + GGML_ASSERT(nrc_x%4 == 0); + Q8<nrc_y, block_q8_K> q8(info); + auto m4 = vdupq_n_u8(0xf); + auto m3 = vdupq_n_u8(0x30); + auto ms = vdupq_n_u8(2); + auto m32 = vdupq_n_s8(-32); + auto m10 = vdupq_n_u8(0x10); + uint8x16x2_t shift_shuffle = { + vreinterpretq_u8_u64(uint64x2_t{0x0101010100000000, 0x0303030302020202}), + vreinterpretq_u8_u64(uint64x2_t{0x0505050504040404, 0x0707070706060606}) + }; + auto values = vld1q_s8_x2(iq5nl_values); + int nbl = n / QK_K; + int8x16_t qx[4]; + int8x16x4_t i8scales; + int16x8x4_t i16scales; + float32x4_t acc[nrc_y] = {}; + for (int ix = 0; ix < nrc_x; ix += 4) { + const block_iq5_k_r4 * iq5 = (const block_iq5_k_r4 *)((const char *)vx + ix*bx); + for (int ibl = 0; ibl < nbl; ++ibl) { + auto d4 = vcvt_f32_f16(vld1_f16((const float16_t *)iq5[ibl].d)); + auto extra8 = vld1_u8(iq5[ibl].extra); + uint8x16_t extra; + if constexpr (nrc_y == 1) { + extra = vcombine_u8(extra8, vshr_n_u8(extra8,1)); + } else { + extra = vcombine_u8(extra8, extra8); + } + auto sl = vld1q_u8_x2(iq5[ibl].scales_l); + auto sh = vld1q_u8(iq5[ibl].scales_h); + i8scales.val[0] = vaddq_s8(vorrq_u8(vandq_u8(sl.val[0], m4), vandq_u8(vshlq_n_u8(sh, 4), m3)), m32); + i8scales.val[1] = vaddq_s8(vorrq_u8(vandq_u8(sl.val[1], m4), vandq_u8(vshlq_n_u8(sh, 2), m3)), m32); + i8scales.val[2] = vaddq_s8(vorrq_u8(vshrq_n_u8(sl.val[0], 4), vandq_u8(sh, m3)), m32); + i8scales.val[3] = vaddq_s8(vorrq_u8(vshrq_n_u8(sl.val[1], 4), vandq_u8(vshrq_n_u8(sh, 2), m3)), m32); + int32x4_t isum[nrc_y] = {}; + if constexpr (nrc_y == 1) { + iq3_4_add_shift<nrc_y, 2>(ibl, q8, i8scales, extra, isum); + } + for (int is = 0; is < 2; ++is) { + i16scales.val[0] = vmovl_s8(vget_low_s8 (i8scales.val[2*is+0])); + i16scales.val[1] = vmovl_s8(vget_high_s8(i8scales.val[2*is+0])); + i16scales.val[2] = vmovl_s8(vget_low_s8 (i8scales.val[2*is+1])); + i16scales.val[3] = vmovl_s8(vget_high_s8(i8scales.val[2*is+1])); + for (int ib = 0; ib < 4; ++ib) { + auto lbits = vld1q_u8_x4(iq5[ibl].qs + 256*is + 64*ib); + auto hbits = vld1q_u8(iq5[ibl].qh + 64*is + 16*ib); + qx[0] = vorrq_u8(vandq_u8(lbits.val[0], m4), vandq_u8(m10, vshlq_n_u8(hbits, 4))); // aligns with 1st half of qx[0] in AVX2 + qx[1] = vorrq_u8(vandq_u8(lbits.val[2], m4), vandq_u8(m10, hbits)); // aligns with 1st half of qx[1] in AVX2 + qx[2] = vorrq_u8(vshrq_n_u8(lbits.val[0], 4), vandq_u8(m10, vshlq_n_u8(hbits, 3))); // aligns with 1st half of qx[2] in AVX2 + qx[3] = vorrq_u8(vshrq_n_u8(lbits.val[2], 4), vandq_u8(m10, vshrq_n_u8(hbits, 1))); // aligns with 1st half of qx[3] in AVX2 + uint8x16_t shifts; + if constexpr (nrc_y == 1) { + qx[0] = vqtbl2q_s8(values, qx[0]); // 0...3 from the 4 rows + qx[1] = vqtbl2q_s8(values, qx[1]); // 4...7 + qx[2] = vqtbl2q_s8(values, qx[2]); // 8..11 + qx[3] = vqtbl2q_s8(values, qx[3]); // 12..15 + } else { + shifts = vandq_u8(ms, vshlq_n_u8(extra, 1)); + auto shift = vqtbl1q_u8(shifts, shift_shuffle.val[0]); + extra = vshrq_n_u8(extra, 1); + qx[0] = vaddq_s8(shift, vqtbl2q_s8(values, qx[0])); // 0...3 from the 4 rows + qx[1] = vaddq_s8(shift, vqtbl2q_s8(values, qx[1])); // 4...7 + qx[2] = vaddq_s8(shift, vqtbl2q_s8(values, qx[2])); // 8..11 + qx[3] = vaddq_s8(shift, vqtbl2q_s8(values, qx[3])); // 12..15 + } + auto scales = vmovl_s16(vget_low_s16 (i16scales.val[ib])); + for (int iy = 0; iy < nrc_y; ++iy) { + auto y = vld1q_s8(q8.y[iy][ibl].qs+128*is+32*ib); + auto sumi = interleaved_dotq(qx, y); + isum[iy] = vmlaq_s32(isum[iy], scales, sumi); + } + qx[0] = vorrq_u8(vandq_u8(lbits.val[1], m4), vandq_u8(m10, vshlq_n_u8(hbits, 2))); // aligns with 2nd half of qx[0] in AVX2 + qx[1] = vorrq_u8(vandq_u8(lbits.val[3], m4), vandq_u8(m10, vshrq_n_u8(hbits, 2))); // aligns with 2nd half of qx[1] in AVX2 + qx[2] = vorrq_u8(vshrq_n_u8(lbits.val[1], 4), vandq_u8(m10, vshlq_n_u8(hbits, 1))); // aligns with 2nd half of qx[2] in AVX2 + qx[3] = vorrq_u8(vshrq_n_u8(lbits.val[3], 4), vandq_u8(m10, vshrq_n_u8(hbits, 3))); // aligns with 2nd half of qx[3] in AVX2 + if constexpr (nrc_y == 1) { + qx[0] = vqtbl2q_s8(values, qx[0]); // 0...3 from the 4 rows + qx[1] = vqtbl2q_s8(values, qx[1]); // 4...7 + qx[2] = vqtbl2q_s8(values, qx[2]); // 8..11 + qx[3] = vqtbl2q_s8(values, qx[3]); // 12..15 + } else { + auto shift = vqtbl1q_u8(shifts, shift_shuffle.val[1]); + qx[0] = vaddq_s8(shift, vqtbl2q_s8(values, qx[0])); // 0...3 from the 4 rows + qx[1] = vaddq_s8(shift, vqtbl2q_s8(values, qx[1])); // 4...7 + qx[2] = vaddq_s8(shift, vqtbl2q_s8(values, qx[2])); // 8..11 + qx[3] = vaddq_s8(shift, vqtbl2q_s8(values, qx[3])); // 12..15 + } + scales = vmovl_s16(vget_high_s16(i16scales.val[ib])); + for (int iy = 0; iy < nrc_y; ++iy) { + auto y = vld1q_s8(q8.y[iy][ibl].qs+128*is+32*ib+16); + auto sumi = interleaved_dotq(qx, y); + isum[iy] = vmlaq_s32(isum[iy], scales, sumi); + } + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + acc[iy] = vfmaq_f32(acc[iy], vmulq_f32(d4, vdupq_n_f32(q8.scale(iy, ibl))), vcvtq_f32_s32(isum[iy])); + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + info.store(ix, iy, acc[iy]); + acc[iy] = vdupq_n_f32(0.f); + } + } +} + IQK_ALWAYS_INLINE void prepare_q4_k_quants(const uint8x16_t& m4, const uint8x16x4_t& bits, int8x16_t * qx) { qx[0] = vandq_u8(bits.val[0], m4); // 0...3 from the 4 rows qx[1] = vandq_u8(bits.val[1], m4); // 16..19 @@ -10282,6 +10607,10 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) { SET_MUL_MAT_FUNCTIONS(m, mul_mat_iq4_k_r4_q8_k); expected_Btype = GGML_TYPE_Q8_K; break; + case GGML_TYPE_IQ5_K_R4: + SET_MUL_MAT_FUNCTIONS(m, mul_mat_iq5_k_r4_q8_k); + expected_Btype = GGML_TYPE_Q8_K; + break; case GGML_TYPE_Q4_0_R4: SET_MUL_MAT_FUNCTIONS_T(m, mul_mat_qx_r4_q8_0, Q4_0_R4_Dequantizer); expected_Btype = GGML_TYPE_Q8_0; diff --git a/ggml/src/iqk/iqk_quantize.cpp b/ggml/src/iqk/iqk_quantize.cpp index 3408d054..0007dc04 100644 --- a/ggml/src/iqk/iqk_quantize.cpp +++ b/ggml/src/iqk/iqk_quantize.cpp @@ -4685,6 +4685,142 @@ void vec_dot_iq4_k_r4_q8_k(int n, float * s, size_t bs, const void * vx, size_t } // +// ========================================= iq5_k_r4 +// + +void quantize_row_iq5_k_r4_ref(const float * x, block_iq5_k_r4 * y, int64_t k) { + quantize_iq5_k_r4(x, (void *)y, 4, k/4, nullptr); +} + +void quantize_row_iq5_k_r4(const float * x, void * y, int64_t k) { + quantize_iq5_k_r4(x, y, 4, k/4, nullptr); +} + +namespace { +inline void convert_iq5_k(const block_iq5_k& x, uint8_t * L) { + const uint8_t * qs = x.qs; + const uint8_t * qh = x.qh; + int shift = 0; + for (int ib64 = 0; ib64 < QK_K/64; ++ib64) { + for (int j = 0; j < 16; ++j) { + L[j+ 0] = (qs[j+ 0] & 0xf) | (((qh[j+ 0] >> shift) & 1) << 4); + L[j+16] = (qs[j+16] & 0xf) | (((qh[j+16] >> shift) & 1) << 4); + L[j+32] = (qs[j+ 0] >> 4) | (((qh[j+ 0] >> shift) & 2) << 3); + L[j+48] = (qs[j+16] >> 4) | (((qh[j+16] >> shift) & 2) << 3); + } + L += 64; + qs += 32; + shift += 2; + if (shift == 8) { qh += 32; shift = 0; } + } +} +} + +static void repack_iq5_k(int nrows, int n_per_row, const block_iq5_k * x, block_iq5_k_r4 * y) { + GGML_ASSERT(nrows%4 == 0); + GGML_ASSERT(n_per_row%QK_K == 0); + int nblock = n_per_row/QK_K; + const block_iq5_k * x4[4]; + uint8_t L[QK_K]; + for (int row = 0; row < nrows; row += 4) { + for (int k = 0; k < 4; ++k) x4[k] = x + nblock*k; + for (int ibl = 0; ibl < nblock; ++ibl) { + std::memset(y[ibl].extra, 0, 8); + std::memset(y[ibl].scales_l, 0, QK_K/8); + std::memset(y[ibl].scales_h, 0, QK_K/16); + for (int k = 0; k < 4; ++k) { + y[ibl].d[k] = x4[k][ibl].d; + auto extra = x4[k][ibl].extra; + convert_iq5_k(x4[k][ibl], L); + for (int ib = 0; ib < QK_K/32; ++ib) { + if (extra & 1) y[ibl].extra[k+0] |= (1 << ib); + if (extra & 2) y[ibl].extra[k+4] |= (1 << ib); + extra >>= 2; + uint8_t sl1 = x4[k][ibl].scales_l[ib] & 0xf; + uint8_t sl2 = x4[k][ibl].scales_l[ib] >> 4; + uint8_t sh = x4[k][ibl].scales_h[ib/2] >> 4*(ib%2); + uint8_t sh1 = (sh >> 0) & 3; + uint8_t sh2 = (sh >> 2) & 3; + int i = 8*ib + k; + y[ibl].scales_l[i%32] |= (sl1 << 4*(i/32)); + y[ibl].scales_h[i%16] |= (sh1 << 2*(i/16)); + i += 4; + y[ibl].scales_l[i%32] |= (sl2 << 4*(i/32)); + y[ibl].scales_h[i%16] |= (sh2 << 2*(i/16)); + for (int i = 0; i < 4; ++i) { + y[ibl].qs[64*ib+4*k+i+ 0] = (L[32*ib+i+ 0] & 0xf) | ((L[32*ib+i+ 8] & 0xf) << 4); // 0....3 + 8...11 from each row + y[ibl].qs[64*ib+4*k+i+16] = (L[32*ib+i+16] & 0xf) | ((L[32*ib+i+24] & 0xf) << 4); // 16...19 + 24...27 from each row + y[ibl].qs[64*ib+4*k+i+32] = (L[32*ib+i+ 4] & 0xf) | ((L[32*ib+i+12] & 0xf) << 4); // 4....7 + 12...15 from each row + y[ibl].qs[64*ib+4*k+i+48] = (L[32*ib+i+20] & 0xf) | ((L[32*ib+i+28] & 0xf) << 4); // 20...23 + 28...31 from each row + y[ibl].qh[16*ib+4*k+i ] = ((L[32*ib+i+ 0] >> 4) << 0) | ((L[32*ib+i+ 8] >> 4) << 1) | ((L[32*ib+i+16] >> 4) << 2) | ((L[32*ib+i+24] >> 4) << 3) + | ((L[32*ib+i+ 4] >> 4) << 4) | ((L[32*ib+i+12] >> 4) << 5) | ((L[32*ib+i+20] >> 4) << 6) | ((L[32*ib+i+28] >> 4) << 7); + } + } + } + } + x += 4*nblock; + y += nblock; + } +} + +size_t quantize_iq5_k_r4(const float * src, void * dst, int64_t nrows, int64_t n_per_row, const float * imatrix) { + GGML_ASSERT(nrows%4 == 0); + GGML_ASSERT(n_per_row%QK_K == 0); + char * qcur = (char *)dst; + auto row_size = ggml_row_size(GGML_TYPE_IQ5_K, n_per_row); + std::vector<char> qtmp(4*row_size); + for (int row = 0; row < nrows; row += 4) { + quantize_iq5_k(src, (void *)qtmp.data(), 4, n_per_row, imatrix); + repack_iq5_k(4, n_per_row, (const block_iq5_k *)qtmp.data(), (block_iq5_k_r4 *)qcur); + qcur += 4*row_size; + src += 4*n_per_row; + } + return nrows*row_size; +} + +void dequantize_row_iq5_k_r4(const block_iq5_k_r4 * x, float * y, int64_t k) { + auto n_per_row = k/4; + float * y4[4] = {y, y + n_per_row, y + 2*n_per_row, y + 3*n_per_row}; + int nblock = n_per_row/QK_K; + for (int ibl = 0; ibl < nblock; ++ibl) { + for (int k = 0; k < 4; ++k) { + const float d = GGML_FP16_TO_FP32(x[ibl].d[k]); + for (int ib = 0; ib < QK_K/32; ++ib) { + int is = 8*ib + k; + float dl1 = d * ((((x[ibl].scales_l[is%32] >> 4*(is/32)) & 0xf) | (((x[ibl].scales_h[is%16] >> 2*(is/16)) & 3) << 4)) - 32); + is += 4; + float dl2 = d * ((((x[ibl].scales_l[is%32] >> 4*(is/32)) & 0xf) | (((x[ibl].scales_h[is%16] >> 2*(is/16)) & 3) << 4)) - 32); + auto values1 = iq5nl_values + (x[ibl].extra[k+0] & (1 << ib) ? 32 : 0); + auto values2 = iq5nl_values + (x[ibl].extra[k+4] & (1 << ib) ? 32 : 0); + for (int i = 0; i < 4; ++i) { + y4[k][QK_K*ibl+32*ib+i+ 0] = dl1 * values1[(x[ibl].qs[64*ib+4*k+i+ 0] & 0xf) | (((x[ibl].qh[16*ib+4*k+i] >> 0) & 1) << 4)]; + y4[k][QK_K*ibl+32*ib+i+ 8] = dl1 * values1[(x[ibl].qs[64*ib+4*k+i+ 0] >> 4) | (((x[ibl].qh[16*ib+4*k+i] >> 1) & 1) << 4)]; + y4[k][QK_K*ibl+32*ib+i+16] = dl2 * values2[(x[ibl].qs[64*ib+4*k+i+16] & 0xf) | (((x[ibl].qh[16*ib+4*k+i] >> 2) & 1) << 4)]; + y4[k][QK_K*ibl+32*ib+i+24] = dl2 * values2[(x[ibl].qs[64*ib+4*k+i+16] >> 4) | (((x[ibl].qh[16*ib+4*k+i] >> 3) & 1) << 4)]; + y4[k][QK_K*ibl+32*ib+i+ 4] = dl1 * values1[(x[ibl].qs[64*ib+4*k+i+32] & 0xf) | (((x[ibl].qh[16*ib+4*k+i] >> 4) & 1) << 4)]; + y4[k][QK_K*ibl+32*ib+i+12] = dl1 * values1[(x[ibl].qs[64*ib+4*k+i+32] >> 4) | (((x[ibl].qh[16*ib+4*k+i] >> 5) & 1) << 4)]; + y4[k][QK_K*ibl+32*ib+i+20] = dl2 * values2[(x[ibl].qs[64*ib+4*k+i+48] & 0xf) | (((x[ibl].qh[16*ib+4*k+i] >> 6) & 1) << 4)]; + y4[k][QK_K*ibl+32*ib+i+28] = dl2 * values2[(x[ibl].qs[64*ib+4*k+i+48] >> 4) | (((x[ibl].qh[16*ib+4*k+i] >> 7) & 1) << 4)]; + } + } + } + } +} + +void vec_dot_iq5_k_r4_q8_k(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) { +#if GGML_USE_IQK_MULMAT + if (iqk_mul_mat(1, 1, n, GGML_TYPE_IQ5_K_R4, vx, 0, GGML_TYPE_Q8_K, vy, 0, s, 0, 0, 1)) { + return; + } +#endif + GGML_ASSERT(n%QK4_NL == 0); + GGML_ASSERT(nrc == 1); + GGML_UNUSED(bs); + GGML_UNUSED(bx); + GGML_UNUSED(by); +} + +// // ========================================= q8_k_r8 // diff --git a/ggml/src/iqk/iqk_quantize.h b/ggml/src/iqk/iqk_quantize.h index 7c568ded..b8604caa 100644 --- a/ggml/src/iqk/iqk_quantize.h +++ b/ggml/src/iqk/iqk_quantize.h @@ -139,6 +139,12 @@ size_t quantize_q6_k_r4(const float * GGML_RESTRICT src, void * GGML_RESTRICT ds void dequantize_row_q6_k_r4(const block_q6_k_r4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); void vec_dot_q6_k_r4_q8_k(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); +void quantize_row_iq5_k_r4_ref(const float * GGML_RESTRICT x, block_iq5_k_r4 * GGML_RESTRICT y, int64_t k); +void quantize_row_iq5_k_r4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +size_t quantize_iq5_k_r4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +void dequantize_row_iq5_k_r4(const block_iq5_k_r4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +void vec_dot_iq5_k_r4_q8_k(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); + void quantize_row_iq4_k_r4_ref(const float * GGML_RESTRICT x, block_iq4_k_r4 * GGML_RESTRICT y, int64_t k); void quantize_row_iq4_k_r4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); size_t quantize_iq4_k_r4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); diff --git a/include/llama.h b/include/llama.h index e63d76fe..7267fbd4 100644 --- a/include/llama.h +++ b/include/llama.h @@ -196,6 +196,7 @@ extern "C" { LLAMA_FTYPE_MOSTLY_IQ2_K_R4 = 338, // except 1d tensors LLAMA_FTYPE_MOSTLY_IQ3_K_R4 = 339, // except 1d tensors LLAMA_FTYPE_MOSTLY_IQ4_K_R4 = 340, // except 1d tensors + LLAMA_FTYPE_MOSTLY_IQ5_K_R4 = 341, // except 1d tensors LLAMA_FTYPE_MOSTLY_Q8_K_R8 = 399, // except 1d tensors LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file diff --git a/src/llama.cpp b/src/llama.cpp index 68e59758..f04d7f6f 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -3876,6 +3876,7 @@ struct llama_model_loader { case GGML_TYPE_IQ4_K: ftype = LLAMA_FTYPE_MOSTLY_IQ4_K; break; case GGML_TYPE_IQ4_K_R4:ftype = LLAMA_FTYPE_MOSTLY_IQ4_K_R4;break; case GGML_TYPE_IQ5_K: ftype = LLAMA_FTYPE_MOSTLY_IQ5_K; break; + case GGML_TYPE_IQ5_K_R4:ftype = LLAMA_FTYPE_MOSTLY_IQ5_K_R4;break; case GGML_TYPE_IQ6_K: ftype = LLAMA_FTYPE_MOSTLY_IQ6_K; break; case GGML_TYPE_IQ3_S: ftype = LLAMA_FTYPE_MOSTLY_IQ3_S; break; case GGML_TYPE_Q4_0_4_4: ftype = LLAMA_FTYPE_MOSTLY_Q4_0_4_4; break; @@ -4601,6 +4602,7 @@ static std::string llama_model_ftype_name(llama_ftype ftype) { case LLAMA_FTYPE_MOSTLY_IQ4_K: return "IQ4_K - 4.5 bpw"; case LLAMA_FTYPE_MOSTLY_IQ4_K_R4: return "IQ4_K_R4 - 4.5 bpw"; case LLAMA_FTYPE_MOSTLY_IQ5_K: return "IQ5_K - 5.5 bpw"; + case LLAMA_FTYPE_MOSTLY_IQ5_K_R4: return "IQ5_K_R4 - 5.5 bpw"; case LLAMA_FTYPE_MOSTLY_IQ6_K: return "IQ6_K - 6.6 bpw"; case LLAMA_FTYPE_MOSTLY_IQ1_BN: return "IQ1_BN - 1.625 bpw Bitnet"; case LLAMA_FTYPE_MOSTLY_IQ2_BN: return "IQ2_BN - 2.00 bpw Bitnet"; @@ -15854,6 +15856,9 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n else if (new_type == GGML_TYPE_IQ4_K_R4) { new_type = GGML_TYPE_IQ4_K; } + else if (new_type == GGML_TYPE_IQ5_K_R4) { + new_type = GGML_TYPE_IQ5_K; + } else if (new_type == GGML_TYPE_Q4_0_R4) { new_type = GGML_TYPE_Q4_0; } @@ -16150,7 +16155,7 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n new_type == GGML_TYPE_IQ2_KS || new_type == GGML_TYPE_IQ4_KSS || new_type == GGML_TYPE_Q6_K_R4 || new_type == GGML_TYPE_Q5_K_R4 || new_type == GGML_TYPE_Q3_K_R4 || new_type == GGML_TYPE_Q2_K_R4 || new_type == GGML_TYPE_IQ4_K_R4|| new_type == GGML_TYPE_Q8_K_R8 || new_type == GGML_TYPE_IQ3_K_R4|| - new_type == GGML_TYPE_IQ2_K_R4) { + new_type == GGML_TYPE_IQ2_K_R4|| new_type == GGML_TYPE_IQ5_K_R4) { int nx = tensor->ne[0]; int ny = tensor->ne[1]; if (nx % QK_K != 0) { @@ -16193,6 +16198,7 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n case GGML_TYPE_Q4_K_R4: case GGML_TYPE_Q4_K: new_type = GGML_TYPE_Q5_0; break; case GGML_TYPE_IQ5_K: + case GGML_TYPE_IQ5_K_R4: case GGML_TYPE_Q5_K_R4: case GGML_TYPE_Q5_K: new_type = GGML_TYPE_Q6_0; break; case GGML_TYPE_IQ6_K: @@ -16325,6 +16331,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s case LLAMA_FTYPE_MOSTLY_IQ4_K: default_type = GGML_TYPE_IQ4_K; break; case LLAMA_FTYPE_MOSTLY_IQ4_K_R4:default_type = GGML_TYPE_IQ4_K_R4;break; case LLAMA_FTYPE_MOSTLY_IQ5_K: default_type = GGML_TYPE_IQ5_K; break; + case LLAMA_FTYPE_MOSTLY_IQ5_K_R4:default_type = GGML_TYPE_IQ5_K_R4;break; case LLAMA_FTYPE_MOSTLY_IQ6_K: default_type = GGML_TYPE_IQ6_K; break; case LLAMA_FTYPE_MOSTLY_IQ3_S: default_type = GGML_TYPE_IQ3_S; break; case LLAMA_FTYPE_MOSTLY_IQ3_M: default_type = GGML_TYPE_IQ3_S; break; @@ -16741,6 +16748,10 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s if (tensor->ne[1] % 4 != 0) new_type = GGML_TYPE_IQ4_K; else chunk_size_multiplier = 4; } + else if (new_type == GGML_TYPE_IQ5_K_R4) { + if (tensor->ne[1] % 4 != 0) new_type = GGML_TYPE_IQ5_K; + else chunk_size_multiplier = 4; + } else if (new_type == GGML_TYPE_BF16_R16) { if (tensor->ne[1] % 16 != 0) new_type = GGML_TYPE_BF16; else chunk_size_multiplier = 16; |