diff options
author | Kawrakow <iwankawrakow@gmail.com> | 2025-05-16 17:25:15 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2025-05-16 17:25:15 +0300 |
commit | 134d5481737c05421eb1ba7cd7573136e3fdbd69 (patch) | |
tree | 975a20beae580a2e7b535555182195fbcd01d3ed | |
parent | 34ae71c4d7ceac8fc10479d5ccc996685aaf8a67 (diff) |
Fix AVX2 implementation of IQ4_K, IQ4_KS, IQ5_K, IQ6_K (#427)
* Fix IQ4_K on AVX2
* Fix IQ4_KS on AVX2
* Fix IQ5_K on AVX2
* Fix IQ6_K on AVX2
---------
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
-rw-r--r-- | ggml/src/iqk/iqk_mul_mat.cpp | 116 |
1 file changed, 75 insertions, 41 deletions
diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index 8c649de4..6072d56d 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -1673,6 +1673,29 @@ inline void multiply_add(const Bits& bits, const __m256i * scales, int j, int i, } } +template <typename Q8, typename Bits> +inline void multiply_add_avx2(const Bits& bits, const __m256i * scales, int j, int i, const Q8& q8, __m256i * sumi) { + __m256i p[4]; + if (j == 0) { + for (int iy = 0; iy < Q8::nrc_y; ++iy) { + for (int k = 0; k < 4; ++k) { + auto s = _mm256_sign_epi8(bits.values[k], bits.values[k]); + p[k] = _mm256_madd_epi16(scales[k], _mm256_maddubs_epi16(s, _mm256_sign_epi8(q8.load_quants(iy, i, k), bits.values[k]))); + } + sumi[iy] = _mm256_add_epi32(_mm256_add_epi32(p[0], p[1]), _mm256_add_epi32(p[2], p[3])); + } + } else { + for (int iy = 0; iy < Q8::nrc_y; ++iy) { + for (int k = 0; k < 4; ++k) { + auto s = _mm256_sign_epi8(bits.values[k], bits.values[k]); + p[k] = _mm256_madd_epi16(scales[k], _mm256_maddubs_epi16(s, _mm256_sign_epi8(q8.load_quants(iy, i, 4+k), bits.values[k]))); + } + sumi[iy] = _mm256_add_epi32(sumi[iy], _mm256_add_epi32(p[0], p[2])); + sumi[iy] = _mm256_add_epi32(sumi[iy], _mm256_add_epi32(p[1], p[3])); + } + } +} + struct SignHelper { inline __m256i make_signs(uint32_t sign_bits) const { auto aux256 = _mm256_set1_epi32(sign_bits); @@ -2892,18 +2915,21 @@ struct DequantizerIQ3K final : public BaseDequantizer<block_iq3_k> { }; struct DequantizerIQ4K final : public BaseDequantizer<block_iq4_k> { - DequantizerIQ4K(const void * vx, size_t bx) : BaseDequantizer(vx, bx), iqxk(4, -128), values(load_iq4nl_values_256()) {} + DequantizerIQ4K(const void * vx, size_t bx) : BaseDequantizer(vx, bx) { load_values(); } template <typename Q8> - inline void new_block(int i, const Q8& q8, __m256 * accm, __m256i * scales) { + inline void new_block(int i, [[maybe_unused]] const Q8& q8, [[maybe_unused]] __m256 * accm, __m256i * scales) { d = 
GGML_FP16_TO_FP32(x[i].d); - iqxk.process(i, d, x[i].extra, make_scales(x[i].scales_l, (const uint16_t *)x[i].scales_h), q8, accm, scales); + auto scales8 = make_scales(x[i].scales_l, (const uint16_t *)x[i].scales_h); + auto scales16 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales8, hshuff)); + prepare_scales_16(scales16, scales); } inline void prepare(int i, int j) { bits.prepare16(x[i].qs, j); - bits.values[0] = _mm256_shuffle_epi8(values, bits.values[0]); - bits.values[1] = _mm256_shuffle_epi8(values, bits.values[1]); - bits.values[2] = _mm256_shuffle_epi8(values, bits.values[2]); - bits.values[3] = _mm256_shuffle_epi8(values, bits.values[3]); + auto extra = x[i].extra >> 8*j; + bits.values[0] = _mm256_shuffle_epi8(values[extra & 3], bits.values[0]); extra >>= 2; + bits.values[1] = _mm256_shuffle_epi8(values[extra & 3], bits.values[1]); extra >>= 2; + bits.values[2] = _mm256_shuffle_epi8(values[extra & 3], bits.values[2]); extra >>= 2; + bits.values[3] = _mm256_shuffle_epi8(values[extra & 3], bits.values[3]); } __m128i make_scales(const uint8_t * scales_l, const uint16_t * scales_h) const { uint64_t aux64; @@ -2911,20 +2937,28 @@ struct DequantizerIQ4K final : public BaseDequantizer<block_iq4_k> { auto scl = _mm_and_si128(_mm_set_epi64x(aux64 >> 4, aux64), maskl); const uint32_t aux32 = scales_h[0] | (scales_h[1] << 16); auto aux = _mm_and_si128(_mm_set_epi32(aux32 >> 2, aux32, aux32 << 2, aux32 << 4), maskh); - auto sch = _mm_shuffle_epi8(aux, iqxk.hshuff); + auto sch = _mm_shuffle_epi8(aux, hshuff); return _mm_add_epi8(_mm_or_si128(scl, sch), m32); } + void load_values() { + auto v1 = _mm_loadu_si128((const __m128i *)iq4k_values+0); + auto v2 = _mm_loadu_si128((const __m128i *)iq4k_values+1); + values[0] = MM256_SET_M128I(v1, v1); + values[1] = MM256_SET_M128I(v1, v2); + values[2] = MM256_SET_M128I(v2, v1); + values[3] = MM256_SET_M128I(v2, v2); + } Q4Bits bits; - const IQXKScales iqxk; - const __m256i values; const __m128i maskl = _mm_set1_epi8(0xf); const 
__m128i maskh = _mm_set1_epi8(0x30); const __m128i m32 = _mm_set1_epi8(-32); + const __m128i hshuff = _mm_set_epi32(0x0f070e06, 0x0d050c04, 0x0b030a02, 0x09010800); + __m256i values[4]; }; struct DequantizerIQ5K final : public BaseDequantizer<block_iq5_k> { - DequantizerIQ5K(const void * vx, size_t bx) : BaseDequantizer(vx, bx), iqxk(2, -128) { load_values(values); } + DequantizerIQ5K(const void * vx, size_t bx) : BaseDequantizer(vx, bx), iqxk(2, 0) { load_values(values); } template <typename Q8> inline void new_block(int i, const Q8& q8, __m256 * accm, __m256i * scales) { d = GGML_FP16_TO_FP32(x[i].d); @@ -2951,12 +2985,8 @@ struct DequantizerIQ5K final : public BaseDequantizer<block_iq5_k> { return _mm_add_epi8(_mm_or_si128(scl, sch), m32); } static void load_values(__m256i * values) { - static const uint8_t kvalues_iq5nl[32] = { - 2, 14, 25, 36, 45, 54, 63, 71, 78, 85, 92, 98, 104, 110, 116, 122, 127, - 133, 139, 145, 151, 157, 164, 171, 179, 187, 196, 205, 215, 225, 237, 249, - }; - auto values128_1 = _mm_loadu_si128((const __m128i *)kvalues_iq5nl + 0); - auto values128_2 = _mm_loadu_si128((const __m128i *)kvalues_iq5nl + 1); + auto values128_1 = _mm_loadu_si128((const __m128i *)iq5nl_values + 0); + auto values128_2 = _mm_loadu_si128((const __m128i *)iq5nl_values + 1); values[0] = MM256_SET_M128I(values128_1, values128_1); values[1] = MM256_SET_M128I(values128_2, values128_2); } @@ -2972,7 +3002,7 @@ struct DequantizerIQ5K final : public BaseDequantizer<block_iq5_k> { }; struct DequantizerIQ6K final : public BaseDequantizer<block_iq6_k> { - DequantizerIQ6K(const void * vx, size_t bx) : BaseDequantizer(vx, bx), iqxk(1, -128) { load_values(values); } + DequantizerIQ6K(const void * vx, size_t bx) : BaseDequantizer(vx, bx), iqxk(1, 0) { load_values(values); } template <typename Q8> inline void new_block(int i, const Q8& q8, __m256 * accm, __m256i * scales) { d = GGML_FP16_TO_FP32(x[i].d); @@ -3000,14 +3030,8 @@ struct DequantizerIQ6K final : public 
BaseDequantizer<block_iq6_k> { _mm256_and_si256(mask4, _mm256_shuffle_epi8(values[3], l)))); } static void load_values(__m256i * values) { - static const uint8_t kvalues_iq6nl[64] = { - 1, 7, 13, 19, 24, 30, 35, 40, 44, 49, 54, 58, 62, 66, 70, 74, - 77, 81, 84, 88, 91, 94, 97, 100, 103, 106, 109, 112, 115, 117, 120, 123, - 126, 128, 131, 134, 137, 140, 142, 145, 148, 151, 155, 158, 161, 164, 168, 172, - 175, 179, 183, 187, 191, 196, 200, 205, 210, 215, 220, 226, 231, 237, 243, 249, - }; for (int k = 0; k < 4; ++k) { - auto values128 = _mm_loadu_si128((const __m128i *)kvalues_iq6nl + k); + auto values128 = _mm_loadu_si128((const __m128i *)iq6nl_values + k); values[k] = MM256_SET_M128I(values128, values128); } } @@ -3022,32 +3046,32 @@ struct DequantizerIQ6K final : public BaseDequantizer<block_iq6_k> { }; struct DequantizerIQ4KS final : public BaseDequantizer<block_iq4_ks, true> { - DequantizerIQ4KS(const void * vx, size_t bx) : BaseDequantizer(vx, bx), values(load_iq4nl_values_256()) {} + DequantizerIQ4KS(const void * vx, size_t bx) : BaseDequantizer(vx, bx) { load_values(); } template <typename Q8> - inline __m256i new_block(int i, const Q8& q8, __m256 * accd) { + inline __m256i new_block(int i, [[maybe_unused]] const Q8& q8, [[maybe_unused]] __m256 * accd) { auto scales128 = _mm_cvtepu8_epi16(_mm_loadl_epi64((const __m128i *)x[i].scales)); - auto shifts = _mm_and_si128(_mm_cmpeq_epi16(_mm_and_si128(scales128, m1), m1), m4); scales128 = _mm_add_epi16(_mm_and_si128(scales128, mask), m127); - auto scales_s = _mm_mullo_epi16(scales128, _mm_add_epi16(m128, shifts)); - s8k.accum_mins(scales_s, q8, i, d, accd); return MM256_SET_M128I(scales128, scales128); } inline void prepare(int i, int j) { bits.prepare16(x[i].qs, j); - bits.values[0] = _mm256_shuffle_epi8(values, bits.values[0]); - bits.values[1] = _mm256_shuffle_epi8(values, bits.values[1]); - bits.values[2] = _mm256_shuffle_epi8(values, bits.values[2]); - bits.values[3] = _mm256_shuffle_epi8(values, 
bits.values[3]); + bits.values[0] = _mm256_shuffle_epi8(values[x[i].scales[4*j+0] & 1], bits.values[0]); + bits.values[1] = _mm256_shuffle_epi8(values[x[i].scales[4*j+1] & 1], bits.values[1]); + bits.values[2] = _mm256_shuffle_epi8(values[x[i].scales[4*j+2] & 1], bits.values[2]); + bits.values[3] = _mm256_shuffle_epi8(values[x[i].scales[4*j+3] & 1], bits.values[3]); } + void load_values() { + auto v1 = _mm_loadu_si128((const __m128i *)iq4k_values+0); + auto v2 = _mm_loadu_si128((const __m128i *)iq4k_values+1); + values[0] = MM256_SET_M128I(v1, v1); + values[1] = MM256_SET_M128I(v2, v2); + } + Q4Bits bits; - Scales8KBase s8k; - const __m256i values; + __m256i values[2]; const __m128i mask = _mm_set1_epi16(254); const __m128i m127 = _mm_set1_epi16(-127); - const __m128i m128 = _mm_set1_epi16(-128); - const __m128i m1 = _mm_set1_epi16(1); - const __m128i m4 = _mm_set1_epi16(4); }; struct DequantizerIQ5KS final : public BaseDequantizer<block_iq5_ks, true> { @@ -3304,7 +3328,13 @@ static void mul_mat_qY_K_q8_K_T(int n, const void * vx, size_t bx, const DataInf for (int j = 0; j < QK_K/128; ++j) { deq.prepare(i, j); set_scales_16(all_scales[j], scales); - multiply_add(deq.bits, scales, j, i, q8, sumi); + if constexpr (std::is_same_v<Dequantizer, DequantizerIQ4K> || + std::is_same_v<Dequantizer, DequantizerIQ5K> || + std::is_same_v<Dequantizer, DequantizerIQ6K>) { + multiply_add_avx2(deq.bits, scales, j, i, q8, sumi); + } else { + multiply_add(deq.bits, scales, j, i, q8, sumi); + } } for (int iy = 0; iy < nrc_y; ++iy) { @@ -3351,7 +3381,11 @@ static void mul_mat_qX_K_q8_K_T(int n, const void * vx, size_t bx, const DataInf set_scales_8(all_scales, j, scales); - multiply_add(deq.bits, scales, j, i, q8, sumi); + if constexpr (std::is_same_v<Dequantizer, DequantizerIQ4KS>) { + multiply_add_avx2(deq.bits, scales, j, i, q8, sumi); + } else { + multiply_add(deq.bits, scales, j, i, q8, sumi); + } } |