diff options
author | Kawrakow <iwankawrakow@gmail.com> | 2025-05-17 10:42:33 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2025-05-17 10:42:33 +0300 |
commit | c35a383bcd8e4bd334ba2b8d2eb96103e69f75d4 (patch) | |
tree | 8658d5c34801e47efb62322822b628fff3ae00e4 | |
parent | 7abdf2b099ecf9bea156a635a8f22d168483f2b1 (diff) |
Zen4: Faster PP for IQ2_KS, IQ4_KS, IQ5_KS (#428)
* Zen4: faster PP for iq4_ks and iq5_ks
* Zen4: faster PP for iq2_ks
---------
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
-rw-r--r-- | ggml/src/iqk/iqk_mul_mat.cpp | 141 |
1 files changed, 119 insertions, 22 deletions
diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index 7d7ae798..654cc706 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -1798,6 +1798,13 @@ struct Q4Bits { values[2] = _mm512_and_si512(q4bits, ml); values[3] = _mm512_and_si512(_mm512_srli_epi16(q4bits, 4), ml); } + inline void prepare64a(const uint8_t * q4) { + for (int k = 0; k < 4; ++k) { + auto q4bits = _mm256_loadu_si256((const __m256i*)q4 + k); + values[k] = _mm512_inserti32x8(_mm512_castsi256_si512(q4bits), _mm256_srli_epi16(q4bits, 4), 1); + values[k] = _mm512_and_si512(values[k], ml); + } + } __m512i values[4]; const __m512i ml = _mm512_set1_epi8(0xf); BlockPermuter perm; @@ -2106,16 +2113,26 @@ struct DequantizerIQ2K final : public BaseDequantizer<block_iq2_k> { struct DequantizerIQ2KS final : public BaseDequantizer<block_iq2_ks, true, true> { DequantizerIQ2KS(const void * vx, size_t bx) : BaseDequantizer(vx, bx), values(load_values()) {} template <typename Q8> - inline void new_block(int i, const Q8& q8, __m256 * accm, __m512i * scales) { + inline void compute_block(int i, const Q8& q8, __m512 * acc) { prepare(x[i].qs); auto scales128 = make_scales(x[i].scales, x[i].extra >> 8); auto shifts = _mm_and_si128(_mm_cmpeq_epi8(_mm_and_si128(_mm_set1_epi8(x[i].extra), hmask), hmask), m5); - auto scales_s = _mm_mullo_epi16(scales128, _mm_cvtepi8_epi16(_mm_add_epi8(m32, shifts))); - s8k.accum_mins(scales_s, q8, i, d, accm); + auto mins128 = _mm_mullo_epi16(scales128, _mm_cvtepi8_epi16(_mm_add_epi8(m32, shifts))); + auto mins = MM256_SET_M128I(_mm_shuffle_epi8(mins128, s8k.shuffles[1]), _mm_shuffle_epi8(mins128, s8k.shuffles[0])); auto scales256 = MM256_SET_M128I(scales128, scales128); auto all_scales = _mm512_inserti32x8(_mm512_castsi256_si512(scales256), scales256, 1); - scales[0] = _mm512_shuffle_epi8(all_scales, s8k.shuffles512[0]); - scales[1] = _mm512_shuffle_epi8(all_scales, s8k.shuffles512[1]); + __m512i scales[4]; + for (int k = 0; k < 4; ++k) scales[k] = _mm512_shuffle_epi8(all_scales, shuffles[k]); + for (int iy = 0; iy < Q8::nrc_y; ++iy) { + auto q8s = q8.load_bsums(iy, i); + auto prod = _mm256_madd_epi16(mins, q8s); + auto sumi = _mm512_inserti32x8(_mm512_setzero_si512(), prod, 0); + for (int k = 0; k < 4; ++k) { + auto p = _mm512_maddubs_epi16(bits.values[k], q8.load_quants64(iy, i, k)); + sumi = _mm512_dpwssd_epi32(sumi, p, scales[k]); + } + acc[iy] = _mm512_fmadd_ps(_mm512_set1_ps(d*q8.scale(iy, i)), _mm512_cvtepi32_ps(sumi), acc[iy]); + } } inline void prepare(const uint8_t * q2) { bits.prepare(q2); @@ -2140,7 +2157,7 @@ struct DequantizerIQ2KS final : public BaseDequantizer<block_iq2_ks, true, true> return _mm_cvtepi8_epi16(_mm_add_epi8(scl, sch)); } Q2Bits bits; - Scales8K s8k; + Scales8KBase s8k; const __m512i values; const __m128i m16 = _mm_set1_epi8(-16); @@ -2149,6 +2166,12 @@ struct DequantizerIQ2KS final : public BaseDequantizer<block_iq2_ks, true, true> const __m128i hmask = _mm_set1_epi64x(0x8040201008040201); const __m128i shuffle = _mm_set1_epi64x(0x0703060205010400); const __m128i shift = _mm_set_epi32(0, 0, 4, 0); + const __m512i shuffles[4] = { + _mm512_inserti32x8(_mm512_set1_epi16(0x0100), _mm256_set1_epi16(0x0302), 1), + _mm512_inserti32x8(_mm512_set1_epi16(0x0504), _mm256_set1_epi16(0x0706), 1), + _mm512_inserti32x8(_mm512_set1_epi16(0x0908), _mm256_set1_epi16(0x0b0a), 1), + _mm512_inserti32x8(_mm512_set1_epi16(0x0d0c), _mm256_set1_epi16(0x0f0e), 1), + }; }; struct DequantizerIQ3K final : public BaseDequantizer<block_iq3_k> { @@ -2377,6 +2400,29 @@ struct DequantizerIQ4KS final : public BaseDequantizer<block_iq4_ks, true> { scales[3] = _mm512_shuffle_epi8(all_scales, shuffles[3]); prepare(x[i].qs); } + template <typename Q8> + inline void compute_block(int i, const Q8& q8, __m512 * acc) { + auto scales128 = _mm_cvtepu8_epi16(_mm_loadl_epi64((const __m128i *)x[i].scales)); + auto shifts = _mm_and_si128(_mm_cmpeq_epi16(_mm_and_si128(scales128, m1), m1), m4); + scales128 = _mm_add_epi16(_mm_and_si128(scales128, mask), m127); + auto mins128 = _mm_mullo_epi16(scales128, _mm_add_epi16(m128, shifts)); + auto mins = MM256_SET_M128I(_mm_shuffle_epi8(mins128, s8k.shuffles[1]), _mm_shuffle_epi8(mins128, s8k.shuffles[0])); + auto scales256 = MM256_SET_M128I(scales128, scales128); + auto all_scales = _mm512_inserti32x8(_mm512_castsi256_si512(scales256), scales256, 1); + __m512i scales[4]; + for (int k = 0; k < 4; ++k) scales[k] = _mm512_shuffle_epi8(all_scales, shuffles[k]); + prepare(x[i].qs); + for (int iy = 0; iy < Q8::nrc_y; ++iy) { + auto q8s = q8.load_bsums(iy, i); + auto prod = _mm256_madd_epi16(mins, q8s); + auto sumi = _mm512_inserti32x8(_mm512_setzero_si512(), prod, 0); + for (int k = 0; k < 4; ++k) { + auto p = _mm512_maddubs_epi16(bits.values[k], q8.load_quants64(iy, i, k)); + sumi = _mm512_dpwssd_epi32(sumi, p, scales[k]); + } + acc[iy] = _mm512_fmadd_ps(_mm512_set1_ps(d*q8.scale(iy, i)), _mm512_cvtepi32_ps(sumi), acc[iy]); + } + } inline void prepare(const uint8_t * q4) { bits.prepare64(q4); // We now have in bits.valuse[0]: 0...15, 32...47, 64...79, 96...111 @@ -2425,10 +2471,33 @@ struct DequantizerIQ5KS final : public BaseDequantizer<block_iq5_ks, true> { scales[3] = _mm512_shuffle_epi8(all_scales, shuffles[3]); prepare(x[i].qs, x[i].qh); } + template <typename Q8> + inline void compute_block(int i, const Q8& q8, __m512 * acc) { + auto scales128 = _mm_cvtepu8_epi16(_mm_loadl_epi64((const __m128i *)x[i].scales)); + auto shifts = _mm_and_si128(_mm_cmpeq_epi16(_mm_and_si128(scales128, m1), m1), m2); + scales128 = _mm_add_epi16(_mm_and_si128(scales128, mask), m127); + auto mins128 = _mm_mullo_epi16(scales128, _mm_add_epi16(m128, shifts)); + auto mins = MM256_SET_M128I(_mm_shuffle_epi8(mins128, s8k.shuffles[1]), _mm_shuffle_epi8(mins128, s8k.shuffles[0])); + auto scales256 = MM256_SET_M128I(scales128, scales128); + auto all_scales = _mm512_inserti32x8(_mm512_castsi256_si512(scales256), scales256, 1); + __m512i scales[4]; + for (int k = 0; k < 4; ++k) scales[k] = _mm512_shuffle_epi8(all_scales, shuffles[k]); + prepare(x[i].qs, x[i].qh); + for (int iy = 0; iy < Q8::nrc_y; ++iy) { + auto q8s = q8.load_bsums(iy, i); + auto prod = _mm256_madd_epi16(mins, q8s); + auto sumi = _mm512_inserti32x8(_mm512_setzero_si512(), prod, 0); + for (int k = 0; k < 4; ++k) { + auto p = _mm512_maddubs_epi16(bits.values[k], q8.load_quants64(iy, i, k)); + sumi = _mm512_dpwssd_epi32(sumi, p, scales[k]); + } + acc[iy] = _mm512_fmadd_ps(_mm512_set1_ps(d*q8.scale(iy, i)), _mm512_cvtepi32_ps(sumi), acc[iy]); + } + } inline void prepare(const uint8_t * q4, const uint8_t * qh) { - bits.prepare64(q4); + bits.prepare64a(q4); auto h256 = _mm256_loadu_si256((const __m256i *)qh); - auto hbits = _mm512_inserti32x8(_mm512_castsi256_si512(h256), _mm256_srli_epi16(h256, 2), 1); + auto hbits = _mm512_inserti32x8(_mm512_castsi256_si512(h256), _mm256_srli_epi16(h256, 1), 1); auto m1 = _mm512_cmpeq_epi8_mask(_mm512_and_si512(hbits, hmask1), hmask1); auto m2 = _mm512_cmpeq_epi8_mask(_mm512_and_si512(hbits, hmask2), hmask2); bits.values[0] = _mm512_mask_shuffle_epi8(_mm512_maskz_shuffle_epi8(_knot_mask64(m1), values[0], bits.values[0]), m1, values[1], bits.values[0]); @@ -2438,15 +2507,6 @@ struct DequantizerIQ5KS final : public BaseDequantizer<block_iq5_ks, true> { m2 = _mm512_cmpeq_epi8_mask(_mm512_and_si512(hbits, hmask2), hmask2); bits.values[2] = _mm512_mask_shuffle_epi8(_mm512_maskz_shuffle_epi8(_knot_mask64(m1), values[0], bits.values[2]), m1, values[1], bits.values[2]); bits.values[3] = _mm512_mask_shuffle_epi8(_mm512_maskz_shuffle_epi8(_knot_mask64(m2), values[0], bits.values[3]), m2, values[1], bits.values[3]); - // We now have in bits.valuse[0]: 0...31, 64...95 - // bits.valuse[1]: 32..63, 96..127 - // etc. - auto tmp = _mm512_permutex2var_epi64(bits.values[0], permute1, bits.values[1]); - bits.values[1] = _mm512_permutex2var_epi64(bits.values[0], permute2, bits.values[1]); - bits.values[0] = tmp; - tmp = _mm512_permutex2var_epi64(bits.values[2], permute1, bits.values[3]); - bits.values[3] = _mm512_permutex2var_epi64(bits.values[2], permute2, bits.values[3]); - bits.values[2] = tmp; } static void load_values(__m512i * values) { static const uint8_t kvalues_iq5nl[32] = { @@ -2465,9 +2525,7 @@ struct DequantizerIQ5KS final : public BaseDequantizer<block_iq5_ks, true> { Scales8KBase s8k; __m512i values[2]; const __m512i hmask1 = _mm512_set1_epi8(1); - const __m512i hmask2 = _mm512_set1_epi8(2); - const __m512i permute1 = _mm512_set_epi64(11, 10, 9, 8, 3, 2, 1, 0); - const __m512i permute2 = _mm512_set_epi64(15, 14, 13, 12, 7, 6, 5, 4); + const __m512i hmask2 = _mm512_set1_epi8(4); const __m128i m127 = _mm_set1_epi16(-127); const __m128i m128 = _mm_set1_epi16(-128); const __m128i mask = _mm_set1_epi16(254); @@ -2651,6 +2709,34 @@ static void mul_mat_iqX_k_q8_K_AVX512(int n, const void * vx, size_t bx, const D } } +template <typename Dequantizer, int nrc_y> +static void mul_mat_iqX_k_q8_K_AVX512_new(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + assert(n % QK_K == 0); + const int nb = n / QK_K; + + Q8<nrc_y> q8(info); + + Dequantizer deq(vx, bx); + + __m512 accd[nrc_y]; + + for (int ix = 0; ix < nrc_x; ++ix) { + + for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm512_setzero_ps(); + + deq.new_row(ix); + + for (int i = 0; i < nb; ++i) { + deq.compute_block(i, q8, accd); + } + + for (int iy = 0; iy < nrc_y; ++iy) { + info.store(ix, iy, _mm512_reduce_add_ps(accd[iy])); + } + + } +} + template <typename Dequantizer> static void mul_mat_qX_K_q8_K_AVX512_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { assert(n % QK_K == 0); @@ -9713,8 +9799,8 @@ template <typename Dequantizer> void MulMat::set_functions(MulMat& m) { std::is_same_v<Dequantizer, DequantizerIQ4K> || std::is_same_v<Dequantizer, DequantizerIQ3K> || std::is_same_v<Dequantizer, DequantizerIQ4XS>|| - std::is_same_v<Dequantizer, DequantizerIQ4KS>|| - std::is_same_v<Dequantizer, DequantizerIQ5KS>|| + //std::is_same_v<Dequantizer, DequantizerIQ4KS>|| + //std::is_same_v<Dequantizer, DequantizerIQ5KS>|| std::is_same_v<Dequantizer, DequantizerIQ4KSS>) { m.funcs[0] = mul_mat_iqX_k_q8_K_AVX512<Dequantizer, 1>; m.funcs[1] = mul_mat_iqX_k_q8_K_AVX512<Dequantizer, 2>; @@ -9724,6 +9810,17 @@ template <typename Dequantizer> void MulMat::set_functions(MulMat& m) { m.funcs[5] = mul_mat_iqX_k_q8_K_AVX512<Dequantizer, 6>; m.funcs[6] = mul_mat_iqX_k_q8_K_AVX512<Dequantizer, 7>; m.funcs[7] = mul_mat_iqX_k_q8_K_AVX512<Dequantizer, 8>; + } else if constexpr (std::is_same_v<Dequantizer, DequantizerIQ2KS> || + std::is_same_v<Dequantizer, DequantizerIQ4KS> || + std::is_same_v<Dequantizer, DequantizerIQ5KS>) { + m.funcs[0] = mul_mat_iqX_k_q8_K_AVX512_new<Dequantizer, 1>; + m.funcs[1] = mul_mat_iqX_k_q8_K_AVX512_new<Dequantizer, 2>; + m.funcs[2] = mul_mat_iqX_k_q8_K_AVX512_new<Dequantizer, 3>; + m.funcs[3] = mul_mat_iqX_k_q8_K_AVX512_new<Dequantizer, 4>; + m.funcs[4] = mul_mat_iqX_k_q8_K_AVX512_new<Dequantizer, 5>; + m.funcs[5] = mul_mat_iqX_k_q8_K_AVX512_new<Dequantizer, 6>; + m.funcs[6] = mul_mat_iqX_k_q8_K_AVX512_new<Dequantizer, 7>; + m.funcs[7] = mul_mat_iqX_k_q8_K_AVX512_new<Dequantizer, 8>; } else { m.funcs[0] = mul_mat_qX_K_q8_K_AVX512_1<Dequantizer>; m.funcs[1] = mul_mat_qX_K_q8_K_AVX512<Dequantizer, 2>; |