author    Kawrakow <iwankawrakow@gmail.com>  2025-06-13 07:55:57 +0300
committer GitHub <noreply@github.com>        2025-06-13 07:55:57 +0300
commit    7a882f0b63897b22f3534f2c0c8ce34c20526360
tree      3b3213f0a0f0ce5456cdd112848abf6eaf8ef6a9
parent    b57bd8658bfb20e65ad0b601eef6732fee45b81f
Perhaps a slightly better version for IQ2_XXS, IQ3_XXS, IQ3_S GEMV (#524)
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
-rw-r--r--  ggml/src/iqk/iqk_gemm_iquants.cpp  160
1 file changed, 101 insertions(+), 59 deletions(-)
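In short: in the GEMV case (nrc_y == 1) the kernel now loads the four q8 vectors up front, lets the new prepare(i, j, q8_quants) overloads apply the block signs directly to those q8 values, and folds the four sub-block dot products into one vector with the new compute_dot_4 helper, whose n_sum parameter controls how many 8-bit products may accumulate in 16-bit lanes before widening to 32 bits.

The following minimal sketch (not part of the commit; reduce4 is a hypothetical name, AVX2 assumed) isolates the unpack-based 4-to-1 reduction that compute_dot_4 performs on its four partial sums:

#include <immintrin.h>
#include <cstdio>

// Fold four vectors of 32-bit partial sums into one vector whose lanes
// interleave the four results: lane k holds lanes 0..3 of input k summed,
// lane k+4 holds lanes 4..7 of input k summed.
static __m256i reduce4(__m256i s0, __m256i s1, __m256i s2, __m256i s3) {
    s0 = _mm256_add_epi32(_mm256_unpacklo_epi32(s0, s1), _mm256_unpackhi_epi32(s0, s1));
    s2 = _mm256_add_epi32(_mm256_unpacklo_epi32(s2, s3), _mm256_unpackhi_epi32(s2, s3));
    return _mm256_add_epi32(_mm256_unpacklo_epi64(s0, s2), _mm256_unpackhi_epi64(s0, s2));
}

int main() {
    // All lanes of each input are equal, so the result should read
    // 4 8 12 16 4 8 12 16 (four lanes of 1, 2, 3, 4 summed, twice).
    __m256i r = reduce4(_mm256_set1_epi32(1), _mm256_set1_epi32(2),
                        _mm256_set1_epi32(3), _mm256_set1_epi32(4));
    int out[8];
    _mm256_storeu_si256((__m256i *)out, r);
    for (int k = 0; k < 8; ++k) printf("%d ", out[k]);
    printf("\n");
    return 0;
}

Because lane k and lane k+4 together hold dot product k, the per-sub-block scales in dy (the same __m128 repeated across both halves) line up lane for lane, and the final hsum_float_8 completes all four scaled dot products in one pass.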
diff --git a/ggml/src/iqk/iqk_gemm_iquants.cpp b/ggml/src/iqk/iqk_gemm_iquants.cpp
index 7f0258c1..60396fee 100644
--- a/ggml/src/iqk/iqk_gemm_iquants.cpp
+++ b/ggml/src/iqk/iqk_gemm_iquants.cpp
@@ -145,35 +145,6 @@ struct SignHelper {
const __m256i mone = _mm256_set1_epi8(1);
};
-// for (int i = 0; i < nb; ++i) {
-//
-// __m256i sumi[nrc_y], all_scales;
-// //for (int iy = 0; iy < nrc_y; ++iy) sumi[iy] = _mm256_setzero_si256();
-// __m256i mins;
-// float dmin = deq.new_block(i, &all_scales, mins);
-// for (int iy = 0; iy < nrc_y; ++iy) {
-// auto bsums = q8.load_bsums(iy, i);
-// auto prod = _mm256_madd_epi16(mins, bsums);
-// accd[iy] = _mm256_fmadd_ps(_mm256_set1_ps(dmin*q8.scale(iy, i)), _mm256_cvtepi32_ps(prod), accd[iy]);
-// }
-//
-// for (int j = 0; j < QK_K/128; ++j) {
-// deq.prepare(i, j);
-// set_scales_8(&all_scales, j, scales);
-// //multiply_add_iq(deq.bits, scales, j, i, q8, sumi);
-// multiply_add(deq.bits, scales, j, i, q8, sumi);
-// }
-// for (int iy = 0; iy < nrc_y; ++iy) {
-// const __m256 vd = _mm256_set1_ps(deq.d*q8.scale(iy, i));
-// accd[iy] = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(sumi[iy]), accd[iy]);
-// }
-// }
-//
-// for (int iy = 0; iy < nrc_y; ++iy) {
-// info.store(ix, iy, hsum_float_8(accd[iy]));
-// }
-// }
-
struct DequantizerIQ2XXS final : public BaseDequantizer<block_iq2_xxs> {
DequantizerIQ2XXS(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}
@@ -221,7 +192,7 @@ struct DequantizerIQ2XXS final : public BaseDequantizer<block_iq2_xxs> {
}
IQK_ALWAYS_INLINE void sign_values(const uint32_t * aux32, __m256i * values) const {
-#if defined HAVE_FANCY_SIMD && defined __AVX512VPOPCNTDQ__
+#if defined z_HAVE_FANCY_SIMD && defined __AVX512VPOPCNTDQ__
esh.sign_2_values(MM256_SET_M128I(_mm_set1_epi32(aux32[3]), _mm_set1_epi32(aux32[1])), values+0);
esh.sign_2_values(MM256_SET_M128I(_mm_set1_epi32(aux32[7]), _mm_set1_epi32(aux32[5])), values+2);
#else
@@ -246,7 +217,11 @@ struct DequantizerIQ2XXS final : public BaseDequantizer<block_iq2_xxs> {
}
inline void prepare(int i, int j, const Q8<1>& q8, __m256i * q8_quants) {
for (int k = 0; k < 4; ++k) q8_quants[k] = q8.load_quants(0, i, 4*j+k);
- Data data; data.vec = _mm256_loadu_si256((const __m256i *)x[i].qs + j);
+ Data data; data.vec = _mm256_loadu_si256((const __m256i *)x[i].qs + j);
+ make4(data.val, bits.values, q8_quants);
+ }
+ inline void prepare(int i, int j, __m256i * q8_quants) {
+ Data data; data.vec = _mm256_loadu_si256((const __m256i *)x[i].qs + j);
make4(data.val, bits.values, q8_quants);
}
@@ -526,6 +501,13 @@ struct DequantizerIQ3XXS final : public BaseDequantizer<block_iq3_xxs> {
sign_2_values(signs+0, q8_quants+0);
sign_2_values(signs+4, q8_quants+2);
}
+ inline void prepare(int i, int j, __m256i * q8_quants) {
+ auto qs = x[i].qs + 32*j;
+ const uint16_t * signs = (const uint16_t *)(x[i].qs + QK_K/4) + 8*j;
+ make4_unsigned(qs, bits.values);
+ sign_2_values(signs+0, q8_quants+0);
+ sign_2_values(signs+4, q8_quants+2);
+ }
constexpr static int minv = 64;
@@ -625,6 +607,10 @@ struct DequantizerIQ3S final : public BaseDequantizer<block_iq3_s> {
for (int k = 0; k < 4; ++k) q8_quants[k] = q8.load_quants(0, i, 4*j+k);
sh.sign_4_values((const uint16_t *)x[i].signs + 8*j, q8_quants);
}
+ inline void prepare(int i, int j, __m256i * q8_quants) {
+ prepare_unsigned(i, j);
+ sh.sign_4_values((const uint16_t *)x[i].signs + 8*j, q8_quants);
+ }
inline void prepare_unsigned(int i, int j) {
auto qs = x[i].qs + 32*j;
@@ -787,15 +773,65 @@ static void mul_mat_qX_K_q8_K_IQ_N(int n, const void * vx, size_t bx, const Data
}
}
-template <typename Dequantizer, int nrc_y>
+template <int n_sum>
+inline __m256i compute_dot_4(const __m256i * x, const __m256i * y) {
+#ifdef HAVE_FANCY_SIMD
+ auto sumi0 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), x[0], y[0]);
+ auto sumi1 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), x[1], y[1]);
+ auto sumi2 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), x[2], y[2]);
+ auto sumi3 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), x[3], y[3]);
+ sumi0 = _mm256_add_epi32(_mm256_unpacklo_epi32(sumi0, sumi1), _mm256_unpackhi_epi32(sumi0, sumi1));
+ sumi2 = _mm256_add_epi32(_mm256_unpacklo_epi32(sumi2, sumi3), _mm256_unpackhi_epi32(sumi2, sumi3));
+ return _mm256_add_epi32(_mm256_unpacklo_epi64(sumi0, sumi2), _mm256_unpackhi_epi64(sumi0, sumi2));
+#else
+ auto m1 = _mm256_set1_epi16(1);
+ if constexpr (n_sum == 2) {
+ auto sumi0 = _mm256_madd_epi16(m1, _mm256_maddubs_epi16(x[0], y[0]));
+ auto sumi1 = _mm256_madd_epi16(m1, _mm256_maddubs_epi16(x[1], y[1]));
+ auto sumi2 = _mm256_madd_epi16(m1, _mm256_maddubs_epi16(x[2], y[2]));
+ auto sumi3 = _mm256_madd_epi16(m1, _mm256_maddubs_epi16(x[3], y[3]));
+ sumi0 = _mm256_add_epi32(_mm256_unpacklo_epi32(sumi0, sumi1), _mm256_unpackhi_epi32(sumi0, sumi1));
+ sumi2 = _mm256_add_epi32(_mm256_unpacklo_epi32(sumi2, sumi3), _mm256_unpackhi_epi32(sumi2, sumi3));
+ return _mm256_add_epi32(_mm256_unpacklo_epi64(sumi0, sumi2), _mm256_unpackhi_epi64(sumi0, sumi2));
+ }
+ else {
+ auto sumi0 = _mm256_maddubs_epi16(x[0], y[0]);
+ auto sumi1 = _mm256_maddubs_epi16(x[1], y[1]);
+ auto sumi2 = _mm256_maddubs_epi16(x[2], y[2]);
+ auto sumi3 = _mm256_maddubs_epi16(x[3], y[3]);
+ if constexpr (n_sum == 4) {
+ sumi0 = _mm256_add_epi16(_mm256_unpacklo_epi32(sumi0, sumi1), _mm256_unpackhi_epi32(sumi0, sumi1));
+ sumi2 = _mm256_add_epi16(_mm256_unpacklo_epi32(sumi2, sumi3), _mm256_unpackhi_epi32(sumi2, sumi3));
+ sumi0 = _mm256_madd_epi16(m1, sumi0);
+ sumi2 = _mm256_madd_epi16(m1, sumi2);
+ return _mm256_add_epi32(_mm256_unpacklo_epi64(sumi0, sumi2), _mm256_unpackhi_epi64(sumi0, sumi2));
+ }
+ else {
+ sumi0 = _mm256_add_epi16(_mm256_unpacklo_epi32(sumi0, sumi1), _mm256_unpackhi_epi32(sumi0, sumi1));
+ sumi2 = _mm256_add_epi16(_mm256_unpacklo_epi32(sumi2, sumi3), _mm256_unpackhi_epi32(sumi2, sumi3));
+ sumi0 = _mm256_add_epi16(_mm256_unpacklo_epi64(sumi0, sumi2), _mm256_unpackhi_epi64(sumi0, sumi2));
+ return _mm256_madd_epi16(m1, sumi0);
+ }
+ }
+#endif
+}
+
+template <typename Dequantizer, int nrc_y, int n_sum = 2>
static void mul_mat_qX_K_q8_2_IQ_N(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
static_assert(Dequantizer::num_blocks == 8);
+ static_assert(n_sum == 2 || n_sum == 4 || n_sum == 8);
+#ifdef HAVE_FANCY_SIMD
+ constexpr bool use_1_row = nrc_y == 1;
+#else
+ constexpr bool use_1_row = nrc_y == 1 && !std::is_same_v<Dequantizer, DequantizerIQ2XXS>;
+#endif
+
const int nb = n / QK_K;
Q8<nrc_y, block_q8_2_x4> q8(info);
Dequantizer deq(vx, bx);
__m256 scales[3];
__m256 accd[nrc_y];
- __m256i sumi[4];
+ __m256i vy[4];
for (int ix = 0; ix < nrc_x; ++ix) {
@@ -806,35 +842,33 @@ static void mul_mat_qX_K_q8_2_IQ_N(int n, const void * vx, size_t bx, const Data
for (int i = 0; i < nb; ++i) {
deq.new_block_f(i, scales);
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto my1 = _mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)(q8.y[iy][2*i+0].d + 4)));
- auto my2 = _mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)(q8.y[iy][2*i+1].d + 4)));
- auto my = _mm256_castsi256_ps(_mm256_slli_epi32(MM256_SET_M128I(my2, my1), 16));
- accd[iy] = _mm256_fmadd_ps(scales[2], my, accd[iy]);
+ if constexpr (!use_1_row) {
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto my1 = _mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)(q8.y[iy][2*i+0].d + 4)));
+ auto my2 = _mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)(q8.y[iy][2*i+1].d + 4)));
+ auto my = _mm256_castsi256_ps(_mm256_slli_epi32(MM256_SET_M128I(my2, my1), 16));
+ accd[iy] = _mm256_fmadd_ps(scales[2], my, accd[iy]);
+ }
}
for (int j = 0; j < QK_K/128; ++j) {
- deq.prepare(i, j);
- auto& values = deq.bits.values;
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto qs = q8.y[iy][2*i+j].qs;
-#ifdef HAVE_FANCY_SIMD
- sumi[0] = _mm256_dpbusd_epi32(_mm256_setzero_si256(), values[0], _mm256_loadu_si256((const __m256i*)qs+0));
- sumi[1] = _mm256_dpbusd_epi32(_mm256_setzero_si256(), values[1], _mm256_loadu_si256((const __m256i*)qs+1));
- sumi[2] = _mm256_dpbusd_epi32(_mm256_setzero_si256(), values[2], _mm256_loadu_si256((const __m256i*)qs+2));
- sumi[3] = _mm256_dpbusd_epi32(_mm256_setzero_si256(), values[3], _mm256_loadu_si256((const __m256i*)qs+3));
-#else
- sumi[0] = _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_maddubs_epi16(values[0], _mm256_loadu_si256((const __m256i*)qs+0)));
- sumi[1] = _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_maddubs_epi16(values[1], _mm256_loadu_si256((const __m256i*)qs+1)));
- sumi[2] = _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_maddubs_epi16(values[2], _mm256_loadu_si256((const __m256i*)qs+2)));
- sumi[3] = _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_maddubs_epi16(values[3], _mm256_loadu_si256((const __m256i*)qs+3)));
-#endif
- sumi[0] = _mm256_add_epi32(_mm256_unpacklo_epi32(sumi[0], sumi[1]), _mm256_unpackhi_epi32(sumi[0], sumi[1]));
- sumi[2] = _mm256_add_epi32(_mm256_unpacklo_epi32(sumi[2], sumi[3]), _mm256_unpackhi_epi32(sumi[2], sumi[3]));
- sumi[0] = _mm256_add_epi32(_mm256_unpacklo_epi64(sumi[0], sumi[2]), _mm256_unpackhi_epi64(sumi[0], sumi[2]));
- auto d4 = _mm_castsi128_ps(_mm_slli_epi32(_mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)q8.y[iy][2*i+j].d)), 16));
+ if constexpr (use_1_row) {
+ for (int k = 0; k < 4; ++k) vy[k] = _mm256_loadu_si256((const __m256i*)q8.y[0][2*i+j].qs+k);
+ deq.prepare(i, j, vy);
+ auto sumi = compute_dot_4<2*n_sum>(deq.bits.values, vy);
+ auto d4 = _mm_castsi128_ps(_mm_slli_epi32(_mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)q8.y[0][2*i+j].d)), 16));
auto dy = _mm256_set_m128(d4, d4);
- accd[iy] = _mm256_fmadd_ps(_mm256_mul_ps(scales[j], dy), _mm256_cvtepi32_ps(sumi[0]), accd[iy]);
+ accd[0] = _mm256_fmadd_ps(_mm256_mul_ps(scales[j], dy), _mm256_cvtepi32_ps(sumi), accd[0]);
+ } else {
+ deq.prepare(i, j);
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto qs = q8.y[iy][2*i+j].qs;
+ for (int k = 0; k < 4; ++k) vy[k] = _mm256_loadu_si256((const __m256i*)qs+k);
+ auto sumi = compute_dot_4<n_sum>(deq.bits.values, vy);
+ auto d4 = _mm_castsi128_ps(_mm_slli_epi32(_mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)q8.y[iy][2*i+j].d)), 16));
+ auto dy = _mm256_set_m128(d4, d4);
+ accd[iy] = _mm256_fmadd_ps(_mm256_mul_ps(scales[j], dy), _mm256_cvtepi32_ps(sumi), accd[iy]);
+ }
}
}
}
@@ -1934,7 +1968,15 @@ bool iqk_set_kernels_iquants(int ne00, int typeA, int typeB, std::array<mul_mat_
if (ggml_type(typeA) == GGML_TYPE_IQ3_S) {
if (ggml_type(typeB) == GGML_TYPE_Q8_2_X4) {
- IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_K_q8_2_IQ_N, DequantizerIQ3S, kernels);
+ //IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_K_q8_2_IQ_N, DequantizerIQ3S, kernels);
+ kernels[0] = mul_mat_qX_K_q8_2_IQ_N<DequantizerIQ3S, 1, 8>;
+ kernels[1] = mul_mat_qX_K_q8_2_IQ_N<DequantizerIQ3S, 2, 8>;
+ kernels[2] = mul_mat_qX_K_q8_2_IQ_N<DequantizerIQ3S, 3, 8>;
+ kernels[3] = mul_mat_qX_K_q8_2_IQ_N<DequantizerIQ3S, 4, 8>;
+ kernels[4] = mul_mat_qX_K_q8_2_IQ_N<DequantizerIQ3S, 5, 8>;
+ kernels[5] = mul_mat_qX_K_q8_2_IQ_N<DequantizerIQ3S, 6, 8>;
+ kernels[6] = mul_mat_qX_K_q8_2_IQ_N<DequantizerIQ3S, 7, 8>;
+ kernels[7] = mul_mat_qX_K_q8_2_IQ_N<DequantizerIQ3S, 8, 8>;
func16 = nullptr;
return true;
}
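A note on the n_sum template parameter for the non-AVX512 path: _mm256_maddubs_epi16 sums two u8*i8 products into a saturating signed 16-bit lane, so with the unsigned quants bounded by some Xmax, a lane that accumulates n products holds at most n*Xmax*127. Staying in 16 bits is therefore safe only while n*Xmax*127 <= 32767, i.e. Xmax <= 32 for n = 8. Registering the IQ3_S kernels with n_sum = 8 presumably relies on its unsigned values staying within that bound (an assumption here, not something the diff states), whereas IQ3_XXS, whose values carry the larger minv = 64 offset, keeps the default n_sum = 2. Note also that the GEMV branch calls compute_dot_4<2*n_sum>, doubling the 16-bit accumulation depth relative to the multi-row branch.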