diff options
author | Kawrakow <iwankawrakow@gmail.com> | 2025-06-13 07:55:57 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2025-06-13 07:55:57 +0300 |
commit | 7a882f0b63897b22f3534f2c0c8ce34c20526360 (patch) | |
tree | 3b3213f0a0f0ce5456cdd112848abf6eaf8ef6a9 | |
parent | b57bd8658bfb20e65ad0b601eef6732fee45b81f (diff) |
Perhaps a slightly better version for IQ2_XXS, IQ3_XXS, IQ3_S GEMV (#524)
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
-rw-r--r-- | ggml/src/iqk/iqk_gemm_iquants.cpp | 164 |
1 files changed, 105 insertions, 59 deletions
diff --git a/ggml/src/iqk/iqk_gemm_iquants.cpp b/ggml/src/iqk/iqk_gemm_iquants.cpp index 7f0258c1..60396fee 100644 --- a/ggml/src/iqk/iqk_gemm_iquants.cpp +++ b/ggml/src/iqk/iqk_gemm_iquants.cpp @@ -145,35 +145,6 @@ struct SignHelper { const __m256i mone = _mm256_set1_epi8(1); }; -// for (int i = 0; i < nb; ++i) { -// -// __m256i sumi[nrc_y], all_scales; -// //for (int iy = 0; iy < nrc_y; ++iy) sumi[iy] = _mm256_setzero_si256(); -// __m256i mins; -// float dmin = deq.new_block(i, &all_scales, mins); -// for (int iy = 0; iy < nrc_y; ++iy) { -// auto bsums = q8.load_bsums(iy, i); -// auto prod = _mm256_madd_epi16(mins, bsums); -// accd[iy] = _mm256_fmadd_ps(_mm256_set1_ps(dmin*q8.scale(iy, i)), _mm256_cvtepi32_ps(prod), accd[iy]); -// } -// -// for (int j = 0; j < QK_K/128; ++j) { -// deq.prepare(i, j); -// set_scales_8(&all_scales, j, scales); -// //multiply_add_iq(deq.bits, scales, j, i, q8, sumi); -// multiply_add(deq.bits, scales, j, i, q8, sumi); -// } -// for (int iy = 0; iy < nrc_y; ++iy) { -// const __m256 vd = _mm256_set1_ps(deq.d*q8.scale(iy, i)); -// accd[iy] = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(sumi[iy]), accd[iy]); -// } -// } -// -// for (int iy = 0; iy < nrc_y; ++iy) { -// info.store(ix, iy, hsum_float_8(accd[iy])); -// } -// } - struct DequantizerIQ2XXS final : public BaseDequantizer<block_iq2_xxs> { DequantizerIQ2XXS(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {} @@ -221,7 +192,7 @@ struct DequantizerIQ2XXS final : public BaseDequantizer<block_iq2_xxs> { } IQK_ALWAYS_INLINE void sign_values(const uint32_t * aux32, __m256i * values) const { -#if defined HAVE_FANCY_SIMD && defined __AVX512VPOPCNTDQ__ +#if defined z_HAVE_FANCY_SIMD && defined __AVX512VPOPCNTDQ__ esh.sign_2_values(MM256_SET_M128I(_mm_set1_epi32(aux32[3]), _mm_set1_epi32(aux32[1])), values+0); esh.sign_2_values(MM256_SET_M128I(_mm_set1_epi32(aux32[7]), _mm_set1_epi32(aux32[5])), values+2); #else @@ -246,7 +217,11 @@ struct DequantizerIQ2XXS final : public BaseDequantizer<block_iq2_xxs> { } inline void prepare(int i, int j, const Q8<1>& q8, __m256i * q8_quants) { for (int k = 0; k < 4; ++k) q8_quants[k] = q8.load_quants(0, i, 4*j+k); - Data data; data.vec = _mm256_loadu_si256((const __m256i *)x[i].qs + j); + Data data; data.vec = _mm256_loadu_si256((const __m256i *)x[i].qs + j); + make4(data.val, bits.values, q8_quants); + } + inline void prepare(int i, int j, __m256i * q8_quants) { + Data data; data.vec = _mm256_loadu_si256((const __m256i *)x[i].qs + j); make4(data.val, bits.values, q8_quants); } @@ -526,6 +501,13 @@ struct DequantizerIQ3XXS final : public BaseDequantizer<block_iq3_xxs> { sign_2_values(signs+0, q8_quants+0); sign_2_values(signs+4, q8_quants+2); } + inline void prepare(int i, int j, __m256i * q8_quants) { + auto qs = x[i].qs + 32*j; + const uint16_t * signs = (const uint16_t *)(x[i].qs + QK_K/4) + 8*j; + make4_unsigned(qs, bits.values); + sign_2_values(signs+0, q8_quants+0); + sign_2_values(signs+4, q8_quants+2); + } constexpr static int minv = 64; @@ -625,6 +607,10 @@ struct DequantizerIQ3S final : public BaseDequantizer<block_iq3_s> { for (int k = 0; k < 4; ++k) q8_quants[k] = q8.load_quants(0, i, 4*j+k); sh.sign_4_values((const uint16_t *)x[i].signs + 8*j, q8_quants); } + inline void prepare(int i, int j, __m256i * q8_quants) { + prepare_unsigned(i, j); + sh.sign_4_values((const uint16_t *)x[i].signs + 8*j, q8_quants); + } inline void prepare_unsigned(int i, int j) { auto qs = x[i].qs + 32*j; @@ -787,15 +773,69 @@ static void mul_mat_qX_K_q8_K_IQ_N(int n, const void * vx, size_t bx, const Data } } -template <typename Dequantizer, int nrc_y> +template <int n_sum> +inline __m256i compute_dot_4(const __m256i * x, const __m256i * y) { +#ifdef HAVE_FANCY_SIMD + auto sumi0 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), x[0], y[0]); + auto sumi1 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), x[1], y[1]); + auto sumi2 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), x[2], y[2]); + auto sumi3 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), x[3], y[3]); + sumi0 = _mm256_add_epi32(_mm256_unpacklo_epi32(sumi0, sumi1), _mm256_unpackhi_epi32(sumi0, sumi1)); + sumi2 = _mm256_add_epi32(_mm256_unpacklo_epi32(sumi2, sumi3), _mm256_unpackhi_epi32(sumi2, sumi3)); + return _mm256_add_epi32(_mm256_unpacklo_epi64(sumi0, sumi2), _mm256_unpackhi_epi64(sumi0, sumi2)); +#else + auto m1 = _mm256_set1_epi16(1); + if constexpr (n_sum == 2) { + auto sumi0 = _mm256_madd_epi16(m1, _mm256_maddubs_epi16(x[0], y[0])); + auto sumi1 = _mm256_madd_epi16(m1, _mm256_maddubs_epi16(x[1], y[1])); + auto sumi2 = _mm256_madd_epi16(m1, _mm256_maddubs_epi16(x[2], y[2])); + auto sumi3 = _mm256_madd_epi16(m1, _mm256_maddubs_epi16(x[3], y[3])); + sumi0 = _mm256_add_epi32(_mm256_unpacklo_epi32(sumi0, sumi1), _mm256_unpackhi_epi32(sumi0, sumi1)); + sumi2 = _mm256_add_epi32(_mm256_unpacklo_epi32(sumi2, sumi3), _mm256_unpackhi_epi32(sumi2, sumi3)); + return _mm256_add_epi32(_mm256_unpacklo_epi64(sumi0, sumi2), _mm256_unpackhi_epi64(sumi0, sumi2)); + } + else { + auto sumi0 = _mm256_maddubs_epi16(x[0], y[0]); + auto sumi1 = _mm256_maddubs_epi16(x[1], y[1]); + auto sumi2 = _mm256_maddubs_epi16(x[2], y[2]); + auto sumi3 = _mm256_maddubs_epi16(x[3], y[3]); + if constexpr (n_sum == 4) { + sumi0 = _mm256_add_epi16(_mm256_unpacklo_epi32(sumi0, sumi1), _mm256_unpackhi_epi32(sumi0, sumi1)); + sumi2 = _mm256_add_epi16(_mm256_unpacklo_epi32(sumi2, sumi3), _mm256_unpackhi_epi32(sumi2, sumi3)); + sumi0 = _mm256_madd_epi16(m1, sumi0); + sumi2 = _mm256_madd_epi16(m1, sumi2); + return _mm256_add_epi32(_mm256_unpacklo_epi64(sumi0, sumi2), _mm256_unpackhi_epi64(sumi0, sumi2)); + } + else { + auto sumi0 = _mm256_maddubs_epi16(x[0], y[0]); + auto sumi1 = _mm256_maddubs_epi16(x[1], y[1]); + auto sumi2 = _mm256_maddubs_epi16(x[2], y[2]); + auto sumi3 = _mm256_maddubs_epi16(x[3], y[3]); + sumi0 = _mm256_add_epi16(_mm256_unpacklo_epi32(sumi0, sumi1), _mm256_unpackhi_epi32(sumi0, sumi1)); + sumi2 = _mm256_add_epi16(_mm256_unpacklo_epi32(sumi2, sumi3), _mm256_unpackhi_epi32(sumi2, sumi3)); + sumi0 = _mm256_add_epi16(_mm256_unpacklo_epi64(sumi0, sumi2), _mm256_unpackhi_epi64(sumi0, sumi2)); + return _mm256_madd_epi16(m1, sumi0); + } + } +#endif +} + +template <typename Dequantizer, int nrc_y, int n_sum = 2> static void mul_mat_qX_K_q8_2_IQ_N(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { static_assert(Dequantizer::num_blocks == 8); + static_assert(n_sum == 2 || n_sum == 4 || n_sum == 8); +#ifdef HAVE_FANCY_SIMD + constexpr bool use_1_row = nrc_y == 1; +#else + constexpr bool use_1_row = nrc_y == 1 && !std::is_same_v<Dequantizer, DequantizerIQ2XXS>; +#endif + const int nb = n / QK_K; Q8<nrc_y, block_q8_2_x4> q8(info); Dequantizer deq(vx, bx); __m256 scales[3]; __m256 accd[nrc_y]; - __m256i sumi[4]; + __m256i vy[4]; for (int ix = 0; ix < nrc_x; ++ix) { @@ -806,35 +846,33 @@ static void mul_mat_qX_K_q8_2_IQ_N(int n, const void * vx, size_t bx, const Data for (int i = 0; i < nb; ++i) { deq.new_block_f(i, scales); - for (int iy = 0; iy < nrc_y; ++iy) { - auto my1 = _mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)(q8.y[iy][2*i+0].d + 4))); - auto my2 = _mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)(q8.y[iy][2*i+1].d + 4))); - auto my = _mm256_castsi256_ps(_mm256_slli_epi32(MM256_SET_M128I(my2, my1), 16)); - accd[iy] = _mm256_fmadd_ps(scales[2], my, accd[iy]); + if constexpr (!use_1_row) { + for (int iy = 0; iy < nrc_y; ++iy) { + auto my1 = _mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)(q8.y[iy][2*i+0].d + 4))); + auto my2 = _mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)(q8.y[iy][2*i+1].d + 4))); + auto my = _mm256_castsi256_ps(_mm256_slli_epi32(MM256_SET_M128I(my2, my1), 16)); + accd[iy] = _mm256_fmadd_ps(scales[2], my, accd[iy]); + } } for (int j = 0; j < QK_K/128; ++j) { - deq.prepare(i, j); - auto& values = deq.bits.values; - for (int iy = 0; iy < nrc_y; ++iy) { - auto qs = q8.y[iy][2*i+j].qs; -#ifdef HAVE_FANCY_SIMD - sumi[0] = _mm256_dpbusd_epi32(_mm256_setzero_si256(), values[0], _mm256_loadu_si256((const __m256i*)qs+0)); - sumi[1] = _mm256_dpbusd_epi32(_mm256_setzero_si256(), values[1], _mm256_loadu_si256((const __m256i*)qs+1)); - sumi[2] = _mm256_dpbusd_epi32(_mm256_setzero_si256(), values[2], _mm256_loadu_si256((const __m256i*)qs+2)); - sumi[3] = _mm256_dpbusd_epi32(_mm256_setzero_si256(), values[3], _mm256_loadu_si256((const __m256i*)qs+3)); -#else - sumi[0] = _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_maddubs_epi16(values[0], _mm256_loadu_si256((const __m256i*)qs+0))); - sumi[1] = _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_maddubs_epi16(values[1], _mm256_loadu_si256((const __m256i*)qs+1))); - sumi[2] = _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_maddubs_epi16(values[2], _mm256_loadu_si256((const __m256i*)qs+2))); - sumi[3] = _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_maddubs_epi16(values[3], _mm256_loadu_si256((const __m256i*)qs+3))); -#endif - sumi[0] = _mm256_add_epi32(_mm256_unpacklo_epi32(sumi[0], sumi[1]), _mm256_unpackhi_epi32(sumi[0], sumi[1])); - sumi[2] = _mm256_add_epi32(_mm256_unpacklo_epi32(sumi[2], sumi[3]), _mm256_unpackhi_epi32(sumi[2], sumi[3])); - sumi[0] = _mm256_add_epi32(_mm256_unpacklo_epi64(sumi[0], sumi[2]), _mm256_unpackhi_epi64(sumi[0], sumi[2])); - auto d4 = _mm_castsi128_ps(_mm_slli_epi32(_mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)q8.y[iy][2*i+j].d)), 16)); + if constexpr (use_1_row) { + for (int k = 0; k < 4; ++k) vy[k] = _mm256_loadu_si256((const __m256i*)q8.y[0][2*i+j].qs+k); + deq.prepare(i, j, vy); + auto sumi = compute_dot_4<2*n_sum>(deq.bits.values, vy); + auto d4 = _mm_castsi128_ps(_mm_slli_epi32(_mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)q8.y[0][2*i+j].d)), 16)); auto dy = _mm256_set_m128(d4, d4); - accd[iy] = _mm256_fmadd_ps(_mm256_mul_ps(scales[j], dy), _mm256_cvtepi32_ps(sumi[0]), accd[iy]); + accd[0] = _mm256_fmadd_ps(_mm256_mul_ps(scales[j], dy), _mm256_cvtepi32_ps(sumi), accd[0]); + } else { + deq.prepare(i, j); + for (int iy = 0; iy < nrc_y; ++iy) { + auto qs = q8.y[iy][2*i+j].qs; + for (int k = 0; k < 4; ++k) vy[k] = _mm256_loadu_si256((const __m256i*)qs+k); + auto sumi = compute_dot_4<n_sum>(deq.bits.values, vy); + auto d4 = _mm_castsi128_ps(_mm_slli_epi32(_mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)q8.y[iy][2*i+j].d)), 16)); + auto dy = _mm256_set_m128(d4, d4); + accd[iy] = _mm256_fmadd_ps(_mm256_mul_ps(scales[j], dy), _mm256_cvtepi32_ps(sumi), accd[iy]); + } } } } @@ -1934,7 +1972,15 @@ bool iqk_set_kernels_iquants(int ne00, int typeA, int typeB, std::array<mul_mat_ if (ggml_type(typeA) == GGML_TYPE_IQ3_S) { if (ggml_type(typeB) == GGML_TYPE_Q8_2_X4) { - IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_K_q8_2_IQ_N, DequantizerIQ3S, kernels); + //IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_K_q8_2_IQ_N, DequantizerIQ3S, kernels); + kernels[0] = mul_mat_qX_K_q8_2_IQ_N<DequantizerIQ3S, 1, 8>; + kernels[1] = mul_mat_qX_K_q8_2_IQ_N<DequantizerIQ3S, 2, 8>; + kernels[2] = mul_mat_qX_K_q8_2_IQ_N<DequantizerIQ3S, 3, 8>; + kernels[3] = mul_mat_qX_K_q8_2_IQ_N<DequantizerIQ3S, 4, 8>; + kernels[4] = mul_mat_qX_K_q8_2_IQ_N<DequantizerIQ3S, 5, 8>; + kernels[5] = mul_mat_qX_K_q8_2_IQ_N<DequantizerIQ3S, 6, 8>; + kernels[6] = mul_mat_qX_K_q8_2_IQ_N<DequantizerIQ3S, 7, 8>; + kernels[7] = mul_mat_qX_K_q8_2_IQ_N<DequantizerIQ3S, 8, 8>; func16 = nullptr; return true; } |