author     Kawrakow <iwankawrakow@gmail.com>   2025-01-30 09:28:53 +0200
committer  GitHub <noreply@github.com>         2025-01-30 09:28:53 +0200
commit     2e6b523853a8659c63283a6deca805051ecd713a (patch)
tree       2b64156fa1c7004403a070efb030cf4a61805825 /ggml/src
parent     4a73c250023a74bb1665875bbced7f1a3857b7f6 (diff)
Faster Q4_K_R4 and Q5_K_R4 on AVX2/Zen4 (#182)
* Slightly faster AVX2 implementation for q4_k_r4

* Even better AVX2 implementation for q4_k_r4

We now arrive at PP-512 = 328 t/s for LLaMA-3.1-8B on a Ryzen-5975WX CPU,
up from 291 t/s when I last measured on 3c5f8722. With FA and Q8_0 K-cache
we get to 339.5 t/s.

* Fix llama-bench labels that I broke with #181

* Faster AVX2 implementation for q5_k_r4

We arrive at 302 t/s for LLaMA-3.1-8B on a Ryzen-5975WX CPU, up from 273 t/s.

* Use AVX2 implementation of q4_k_r4 and q5_k_r4 also on Zen4

After the changes I made to AVX2, it ends up being slightly faster compared
to what I had for Zen4.

* Minor tweak

* Cleanup

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
Diffstat (limited to 'ggml/src')
-rw-r--r--   ggml/src/iqk/iqk_mul_mat.cpp   319
1 file changed, 66 insertions(+), 253 deletions(-)
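
The main restructuring in both kernels of this patch is that each per-sub-block integer dot product is now weighted by the integer sub-block scale and accumulated in isum, and the float super-block scale (d4 times the q8 scale) is applied only once per 256-quant block instead of once per 32-quant sub-block. Below is a minimal scalar sketch of that idea, not code from the repository: the names (SuperBlock, dot32, sub_scale, row_times_col) and the signed-quant simplification are assumptions, and the sketch ignores the _r4 four-row interleaving and the mins/bsums correction handled by process_min_r4_b32.

// Hypothetical scalar model of the accumulation restructuring; not the real kernel.
#include <cstdint>

constexpr int kSub         = 32;  // quants per sub-block
constexpr int kSubPerSuper = 8;   // 8 x 32 = 256 quants per super-block (QK_K)

struct SuperBlock {
    float  d;                        // float scale of the whole super-block
    int8_t sub_scale[kSubPerSuper];  // small integer sub-block scales
    int8_t q[kSubPerSuper][kSub];    // quants (signed here for simplicity)
};

// Integer dot product of one 32-quant sub-block with the q8 activations.
static int32_t dot32(const int8_t* x, const int8_t* y) {
    int32_t s = 0;
    for (int i = 0; i < kSub; ++i) s += int32_t(x[i]) * int32_t(y[i]);
    return s;
}

// Old structure: convert each sub-block dot product to float and apply
//   d * sub_scale[ib] * q8_scale inside the inner loop (one conversion+FMA per sub-block).
// New structure (this commit): accumulate sub_scale[ib] * dot32(...) as a
//   32-bit integer and convert/scale once per super-block.
float row_times_col(const SuperBlock* x, int nblocks, const int8_t* y, float q8_scale) {
    float acc = 0.0f;
    for (int ibl = 0; ibl < nblocks; ++ibl) {
        int32_t isum = 0;  // integer accumulator, reset per super-block
        for (int ib = 0; ib < kSubPerSuper; ++ib) {
            const int8_t* yb = y + (ibl*kSubPerSuper + ib)*kSub;
            isum += int32_t(x[ibl].sub_scale[ib]) * dot32(x[ibl].q[ib], yb);
        }
        acc += x[ibl].d * q8_scale * float(isum);  // one int->float conversion per 256 quants
    }
    return acc;
}

In the vectorized kernels the same shape appears as isum[iy] = _mm256_add_epi32(isum[iy], ...) inside the QK_K/32 loop, followed by a single _mm256_fmadd_ps with d4 multiplied by q8.scale(iy, ibl) per super-block, which also allows the previous nrc_y == 1 special casing of the scales to be dropped.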
diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp
index 308d0dca..7fd56c42 100644
--- a/ggml/src/iqk/iqk_mul_mat.cpp
+++ b/ggml/src/iqk/iqk_mul_mat.cpp
@@ -4430,17 +4430,47 @@ static void mul_mat_iq3_s_r4_q8_k(int n, const void * vx, size_t bx, const DataI
}
template <int nrc_y>
-static void mul_mat_q4_k_r4_q8_k_avx2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+inline void process_min_r4_b32(int ibl, __m256 m4, __m256i mins, const Q8<nrc_y, block_q8_K>& q8, __m256 * acc) {
+ auto mins_l = _mm256_castsi256_si128(mins);
+ auto mins_h = _mm256_extracti128_si256(mins, 1);
+ auto aux1 = _mm_unpacklo_epi32(mins_l, mins_h);
+ auto aux2 = _mm_unpackhi_epi32(mins_l, mins_h);
+ auto ic1 = _mm256_cvtepi8_epi32(aux1);
+ auto ic2 = _mm256_cvtepi8_epi32(_mm_shuffle_epi32(aux1, 0xee));
+ auto ic3 = _mm256_cvtepi8_epi32(aux2);
+ auto ic4 = _mm256_cvtepi8_epi32(_mm_shuffle_epi32(aux2, 0xee));
+ if constexpr (nrc_y == 1) {
+ auto bs = _mm256_loadu_ps((const float *)q8.y[0][ibl].bsums);
+ auto sumf = _mm256_mul_ps(_mm256_cvtepi32_ps(ic1), _mm256_shuffle_ps(bs, bs, 0x00));
+ sumf = _mm256_fmadd_ps(_mm256_cvtepi32_ps(ic2), _mm256_shuffle_ps(bs, bs, 0x55), sumf);
+ sumf = _mm256_fmadd_ps(_mm256_cvtepi32_ps(ic3), _mm256_shuffle_ps(bs, bs, 0xaa), sumf);
+ sumf = _mm256_fmadd_ps(_mm256_cvtepi32_ps(ic4), _mm256_shuffle_ps(bs, bs, 0xff), sumf);
+ acc[0] = _mm256_fmadd_ps(m4, sumf, acc[0]);
+ } else {
+ auto c1 = _mm256_mul_ps(m4, _mm256_cvtepi32_ps(ic1));
+ auto c2 = _mm256_mul_ps(m4, _mm256_cvtepi32_ps(ic2));
+ auto c3 = _mm256_mul_ps(m4, _mm256_cvtepi32_ps(ic3));
+ auto c4 = _mm256_mul_ps(m4, _mm256_cvtepi32_ps(ic4));
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto bs = _mm256_loadu_ps((const float *)q8.y[iy][ibl].bsums);
+ acc[iy] = _mm256_fmadd_ps(c1, _mm256_shuffle_ps(bs, bs, 0x00), acc[iy]);
+ acc[iy] = _mm256_fmadd_ps(c2, _mm256_shuffle_ps(bs, bs, 0x55), acc[iy]);
+ acc[iy] = _mm256_fmadd_ps(c3, _mm256_shuffle_ps(bs, bs, 0xaa), acc[iy]);
+ acc[iy] = _mm256_fmadd_ps(c4, _mm256_shuffle_ps(bs, bs, 0xff), acc[iy]);
+ }
+ }
+}
+
+template <int nrc_y>
+static void mul_mat_q4_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(nrc_x%4 == 0);
Q8<nrc_y, block_q8_K> q8(info);
auto mf = _mm256_set1_epi8(0xf);
auto m3 = _mm256_set1_epi8(0x30);
-#ifndef HAVE_FANCY_SIMD
- auto m1 = _mm256_set1_epi16(1);
-#endif
int nbl = n / QK_K;
union { __m256i vec; uint32_t val[8]; } hd;
__m256 acc[nrc_y] = {};
+ __m256i isum[nrc_y] = {};
__m256i qx[4];
for (int ix = 0; ix < nrc_x; ix += 4) {
const block_q4_k_r4 * iq4 = (const block_q4_k_r4 *)((const char *)vx + (ix+0)*bx);
@@ -4448,31 +4478,20 @@ static void mul_mat_q4_k_r4_q8_k_avx2(int n, const void * vx, size_t bx, const D
auto dl = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq4[ibl].d));
auto d4 = _mm256_set_m128(_mm256_castps256_ps128(dl), _mm256_castps256_ps128(dl));
auto m4 = _mm256_mul_ps(_mm256_set1_ps(-1.0f), _mm256_set_m128(_mm256_extractf128_ps(dl, 1), _mm256_extractf128_ps(dl, 1)));
- if constexpr (nrc_y == 1) {
- d4 = _mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(0, ibl)));
- }
auto lbits = _mm256_loadu_si256((const __m256i *)iq4[ibl].scales_l);
auto hbits128 = _mm_loadu_si128((const __m128i *)iq4[ibl].scales_h);
auto hbits = MM256_SET_M128I(hbits128, _mm_slli_epi16(hbits128, 4));
hd.vec = _mm256_or_si256(_mm256_and_si256(lbits, mf), _mm256_and_si256(hbits, m3));
auto mins = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(lbits, 4), mf), _mm256_and_si256(_mm256_srli_epi16(hbits, 2), m3));
- auto shuffle = _mm256_set1_epi64x(0x0000000400000000);
- auto c1 = _mm256_mul_ps(m4, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm256_castsi256_si128(_mm256_permutevar8x32_epi32(mins, shuffle)))));
- shuffle = _mm256_add_epi32(shuffle, _mm256_set1_epi32(1));
- auto c2 = _mm256_mul_ps(m4, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm256_castsi256_si128(_mm256_permutevar8x32_epi32(mins, shuffle)))));
- shuffle = _mm256_add_epi32(shuffle, _mm256_set1_epi32(1));
- auto c3 = _mm256_mul_ps(m4, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm256_castsi256_si128(_mm256_permutevar8x32_epi32(mins, shuffle)))));
- shuffle = _mm256_add_epi32(shuffle, _mm256_set1_epi32(1));
- auto c4 = _mm256_mul_ps(m4, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm256_castsi256_si128(_mm256_permutevar8x32_epi32(mins, shuffle)))));
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto bs = _mm256_loadu_ps((const float *)q8.y[iy][ibl].bsums);
- acc[iy] = _mm256_fmadd_ps(c1, _mm256_shuffle_ps(bs, bs, 0x00), acc[iy]);
- acc[iy] = _mm256_fmadd_ps(c2, _mm256_shuffle_ps(bs, bs, 0x55), acc[iy]);
- acc[iy] = _mm256_fmadd_ps(c3, _mm256_shuffle_ps(bs, bs, 0xaa), acc[iy]);
- acc[iy] = _mm256_fmadd_ps(c4, _mm256_shuffle_ps(bs, bs, 0xff), acc[iy]);
- }
+ process_min_r4_b32(ibl, m4, mins, q8, acc);
for (int ib = 0; ib < QK_K/32; ++ib) {
- auto scales_d = _mm256_mul_ps(d4, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_set1_epi32(hd.val[ib]))));
+#ifdef HAVE_FANCY_SIMD
+ auto scales_d = _mm256_cvtepi8_epi32(_mm_set1_epi32(hd.val[ib]));
+#else
+ auto aux = _mm_set1_epi32(hd.val[ib]);
+ aux = _mm_cvtepu8_epi16(_mm_unpacklo_epi8(aux, aux));
+ auto scales_d = MM256_SET_M128I(aux, aux);
+#endif
auto bits1 = _mm256_loadu_si256((const __m256i *)iq4[ibl].qs+2*ib+0);
auto bits2 = _mm256_loadu_si256((const __m256i *)iq4[ibl].qs+2*ib+1);
qx[0] = _mm256_and_si256(bits1, mf);
@@ -4487,21 +4506,20 @@ static void mul_mat_q4_k_r4_q8_k_avx2(int n, const void * vx, size_t bx, const D
sumi = _mm256_dpbusd_epi32(sumi, qx[1], _mm256_shuffle_epi32(y, 0x55));
sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(y, 0xaa));
sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(y, 0xff));
+ isum[iy] = _mm256_add_epi32(isum[iy], _mm256_mullo_epi32(scales_d, sumi));
#else
auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[0], _mm256_shuffle_epi32(y, 0x00)),
_mm256_maddubs_epi16(qx[1], _mm256_shuffle_epi32(y, 0x55)));
auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[2], _mm256_shuffle_epi32(y, 0xaa)),
_mm256_maddubs_epi16(qx[3], _mm256_shuffle_epi32(y, 0xff)));
- auto sumi = _mm256_madd_epi16(m1, _mm256_add_epi16(sumi1, sumi2));
+ isum[iy] = _mm256_add_epi32(isum[iy], _mm256_madd_epi16(scales_d, _mm256_add_epi16(sumi1, sumi2)));
#endif
- if constexpr (nrc_y == 1) {
- acc[iy] = _mm256_fmadd_ps(scales_d, _mm256_cvtepi32_ps(sumi), acc[iy]);
- } else {
- float d8 = q8.scale(iy, ibl);
- acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(scales_d, _mm256_set1_ps(d8)), _mm256_cvtepi32_ps(sumi), acc[iy]);
- }
}
}
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(isum[iy]), acc[iy]);
+ isum[iy] = _mm256_setzero_si256();
+ }
}
for (int iy = 0; iy < nrc_y; ++iy) {
auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1));
@@ -4511,113 +4529,17 @@ static void mul_mat_q4_k_r4_q8_k_avx2(int n, const void * vx, size_t bx, const D
}
}
-#ifdef HAVE_FANCY_SIMD
-template <int nrc_y>
-static void mul_mat_q4_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
- //mul_mat_q4_k_r4_q8_k_avx2<nrc_y>(n, vx, bx, info, nrc_x);
- if constexpr (nrc_y == 1){
- mul_mat_q4_k_r4_q8_k_avx2<1>(n, vx, bx, info, nrc_x);
- } else {
- GGML_ASSERT(nrc_x%8 == 0);
- Q8<nrc_y, block_q8_K> q8(info);
- auto mf = _mm512_set1_epi8(0xf);
- int nbl = n / QK_K;
- using helper_t = union { __m512i vec; uint32_t val[16]; };
- helper_t hd, hm;
- __m512 acc[nrc_y] = {};
- __m512i isum[nrc_y] = {};
- __m512i qx[4];
- for (int ix = 0; ix < nrc_x; ix += 8) {
- const block_q4_k_r4 * iq4l = (const block_q4_k_r4 *)((const char *)vx + (ix+0)*bx);
- const block_q4_k_r4 * iq4h = (const block_q4_k_r4 *)((const char *)vx + (ix+4)*bx);
- for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256
- auto d1 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq4l[ibl].d));
- auto d2 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq4h[ibl].d));
- auto dl = _mm256_castps256_ps128(d1);
- auto ml = _mm256_extractf128_ps(d1, 1);
- auto dh = _mm256_castps256_ps128(d2);
- auto mh = _mm256_extractf128_ps(d2, 1);
- auto d4 = _mm512_insertf32x8(_mm512_castps256_ps512(_mm256_set_m128(dl, dl)), _mm256_set_m128(dh, dh), 1);
- auto m4 = _mm512_insertf32x8(_mm512_castps256_ps512(_mm256_set_m128(ml, ml)), _mm256_set_m128(mh, mh), 1);
- m4 = _mm512_mul_ps(m4, _mm512_set1_ps(-0.5f));
- auto slbits_l = _mm256_loadu_si256((const __m256i *)iq4l[ibl].scales_l);
- auto shbits_l = _mm256_loadu_si256((const __m256i *)iq4h[ibl].scales_l);
- auto slb = _mm512_inserti32x8(_mm512_castsi256_si512(slbits_l), shbits_l, 1);
- auto sld = _mm512_and_si512(slb, mf);
- auto slm = _mm512_and_si512(_mm512_srli_epi16(slb, 4), mf);
- auto slbits_h = _mm_loadu_si128((const __m128i *)iq4l[ibl].scales_h);
- auto shbits_h = _mm_loadu_si128((const __m128i *)iq4h[ibl].scales_h);
- auto slbits_h2 = MM256_SET_M128I(_mm_srli_epi16(slbits_h, 4), slbits_h);
- auto shbits_h2 = MM256_SET_M128I(_mm_srli_epi16(shbits_h, 4), shbits_h);
- auto shb = _mm512_inserti32x8(_mm512_castsi256_si512(slbits_h2), shbits_h2, 1);
- auto shd = _mm512_and_si512(_mm512_slli_epi16(shb, 4), _mm512_set1_epi8(0x30));
- auto shm = _mm512_and_si512(_mm512_slli_epi16(shb, 2), _mm512_set1_epi8(0x30));
- hd.vec = _mm512_or_si512(sld, shd);
- hm.vec = _mm512_or_si512(slm, shm);
- for (int ib = 0; ib < QK_K/32; ++ib) {
- auto scales1 = _mm256_cvtepi8_epi32(_mm_set1_epi32(hd.val[ib+0]));
- auto scales2 = _mm256_cvtepi8_epi32(_mm_set1_epi32(hd.val[ib+8]));
- auto iscales = _mm512_inserti32x8(_mm512_castsi256_si512(scales1), scales2, 1);
- scales1 = _mm256_cvtepi8_epi32(_mm_set1_epi32(hm.val[ib+0]));
- scales2 = _mm256_cvtepi8_epi32(_mm_set1_epi32(hm.val[ib+8]));
- auto iscales_m = _mm512_inserti32x8(_mm512_castsi256_si512(scales1), scales2, 1);
- auto scales_m = _mm512_mul_ps(m4, _mm512_cvtepi32_ps(iscales_m));
- auto bits1 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq4l[ibl].qs+2*ib+0)),
- _mm256_loadu_si256((const __m256i *)iq4h[ibl].qs+2*ib+0), 1);
- auto bits2 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq4l[ibl].qs+2*ib+1)),
- _mm256_loadu_si256((const __m256i *)iq4h[ibl].qs+2*ib+1), 1);
- qx[0] = _mm512_and_si512(bits1, mf);
- qx[1] = _mm512_and_si512(bits2, mf);
- qx[2] = _mm512_and_si512(_mm512_srli_epi16(bits1, 4), mf);
- qx[3] = _mm512_and_si512(_mm512_srli_epi16(bits2, 4), mf);
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto y8 = _mm256_loadu_si256((const __m256i*)q8.y[iy][ibl].qs+ib);
- auto y = _mm512_inserti32x8(_mm512_castsi256_si512(y8), y8, 1);
- auto sumi = _mm512_setzero_si512();
- sumi = _mm512_dpbusd_epi32(sumi, qx[0], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00)));
- sumi = _mm512_dpbusd_epi32(sumi, qx[1], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55)));
- sumi = _mm512_dpbusd_epi32(sumi, qx[2], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa)));
- sumi = _mm512_dpbusd_epi32(sumi, qx[3], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff)));
- isum[iy] = _mm512_add_epi32(isum[iy], _mm512_mullo_epi32(iscales, sumi));
- float m8 = ((const float *)q8.y[iy][ibl].bsums)[ib];
- acc[iy] = _mm512_fmadd_ps(scales_m, _mm512_set1_ps(m8), acc[iy]);
- }
- }
- for (int iy = 0; iy < nrc_y; ++iy) {
- acc[iy] = _mm512_fmadd_ps(_mm512_mul_ps(d4, _mm512_set1_ps(q8.scale(iy, ibl))), _mm512_cvtepi32_ps(isum[iy]), acc[iy]);
- isum[iy] = _mm512_setzero_si512();
- }
- }
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto sum1 = _mm_add_ps(_mm512_extractf32x4_ps(acc[iy], 0), _mm512_extractf32x4_ps(acc[iy], 1));
- auto sum2 = _mm_add_ps(_mm512_extractf32x4_ps(acc[iy], 2), _mm512_extractf32x4_ps(acc[iy], 3));
- info.store(ix+0, iy, sum1);
- info.store(ix+4, iy, sum2);
- acc[iy] = _mm512_setzero_ps();
- }
- }
- }
-}
-#else
template <int nrc_y>
-static void mul_mat_q4_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
- mul_mat_q4_k_r4_q8_k_avx2<nrc_y>(n, vx, bx, info, nrc_x);
-}
-#endif
-
-template <int nrc_y>
-static void mul_mat_q5_k_r4_q8_k_avx2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+static void mul_mat_q5_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(nrc_x%4 == 0);
Q8<nrc_y, block_q8_K> q8(info);
auto mf = _mm256_set1_epi8(0xf);
auto m10 = _mm256_set1_epi8(0x10);
auto m30 = _mm256_set1_epi8(0x30);
-#ifndef HAVE_FANCY_SIMD
- auto m1 = _mm256_set1_epi16(1);
-#endif
int nbl = n / QK_K;
union { __m256i vec; uint32_t val[8]; } hd;
__m256 acc[nrc_y] = {};
+ __m256i isum[nrc_y] = {};
__m256i qx[4];
for (int ix = 0; ix < nrc_x; ix += 4) {
const block_q5_k_r4 * iq5 = (const block_q5_k_r4 *)((const char *)vx + (ix+0)*bx);
@@ -4625,31 +4547,20 @@ static void mul_mat_q5_k_r4_q8_k_avx2(int n, const void * vx, size_t bx, const D
auto dl = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq5[ibl].d));
auto d4 = _mm256_set_m128(_mm256_castps256_ps128(dl), _mm256_castps256_ps128(dl));
auto m4 = _mm256_mul_ps(_mm256_set1_ps(-1.0f), _mm256_set_m128(_mm256_extractf128_ps(dl, 1), _mm256_extractf128_ps(dl, 1)));
- if constexpr (nrc_y == 1) {
- d4 = _mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(0, ibl)));
- }
auto lbits = _mm256_loadu_si256((const __m256i *)iq5[ibl].scales_l);
auto hbits128 = _mm_loadu_si128((const __m128i *)iq5[ibl].scales_h);
auto hbits = MM256_SET_M128I(hbits128, _mm_slli_epi16(hbits128, 4));
hd.vec = _mm256_or_si256(_mm256_and_si256(lbits, mf), _mm256_and_si256(hbits, m30));
auto mins = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(lbits, 4), mf), _mm256_and_si256(_mm256_srli_epi16(hbits, 2), m30));
- auto shuffle = _mm256_set1_epi64x(0x0000000400000000);
- auto c1 = _mm256_mul_ps(m4, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm256_castsi256_si128(_mm256_permutevar8x32_epi32(mins, shuffle)))));
- shuffle = _mm256_add_epi32(shuffle, _mm256_set1_epi32(1));
- auto c2 = _mm256_mul_ps(m4, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm256_castsi256_si128(_mm256_permutevar8x32_epi32(mins, shuffle)))));
- shuffle = _mm256_add_epi32(shuffle, _mm256_set1_epi32(1));
- auto c3 = _mm256_mul_ps(m4, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm256_castsi256_si128(_mm256_permutevar8x32_epi32(mins, shuffle)))));
- shuffle = _mm256_add_epi32(shuffle, _mm256_set1_epi32(1));
- auto c4 = _mm256_mul_ps(m4, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm256_castsi256_si128(_mm256_permutevar8x32_epi32(mins, shuffle)))));
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto bs = _mm256_loadu_ps((const float *)q8.y[iy][ibl].bsums);
- acc[iy] = _mm256_fmadd_ps(c1, _mm256_shuffle_ps(bs, bs, 0x00), acc[iy]);
- acc[iy] = _mm256_fmadd_ps(c2, _mm256_shuffle_ps(bs, bs, 0x55), acc[iy]);
- acc[iy] = _mm256_fmadd_ps(c3, _mm256_shuffle_ps(bs, bs, 0xaa), acc[iy]);
- acc[iy] = _mm256_fmadd_ps(c4, _mm256_shuffle_ps(bs, bs, 0xff), acc[iy]);
- }
+ process_min_r4_b32(ibl, m4, mins, q8, acc);
for (int ib = 0; ib < QK_K/32; ++ib) {
- auto scales_d = _mm256_mul_ps(d4, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_set1_epi32(hd.val[ib]))));
+#ifdef HAVE_FANCY_SIMD
+ auto scales_d = _mm256_cvtepi8_epi32(_mm_set1_epi32(hd.val[ib]));
+#else
+ auto aux = _mm_set1_epi32(hd.val[ib]);
+ aux = _mm_cvtepu8_epi16(_mm_unpacklo_epi8(aux, aux));
+ auto scales_d = MM256_SET_M128I(aux, aux);
+#endif
auto lbits1 = _mm256_loadu_si256((const __m256i *)iq5[ibl].qs+2*ib+0);
auto lbits2 = _mm256_loadu_si256((const __m256i *)iq5[ibl].qs+2*ib+1);
auto hbits128 = _mm_loadu_si128((const __m128i*)iq5[ibl].qh + ib);
@@ -4666,21 +4577,22 @@ static void mul_mat_q5_k_r4_q8_k_avx2(int n, const void * vx, size_t bx, const D
sumi = _mm256_dpbusd_epi32(sumi, qx[1], _mm256_shuffle_epi32(y, 0x55));
sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(y, 0xaa));
sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(y, 0xff));
+ isum[iy] = _mm256_add_epi32(isum[iy], _mm256_mullo_epi32(scales_d, sumi));
#else
auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[0], _mm256_shuffle_epi32(y, 0x00)),
_mm256_maddubs_epi16(qx[1], _mm256_shuffle_epi32(y, 0x55)));
auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[2], _mm256_shuffle_epi32(y, 0xaa)),
_mm256_maddubs_epi16(qx[3], _mm256_shuffle_epi32(y, 0xff)));
- auto sumi = _mm256_madd_epi16(m1, _mm256_add_epi16(sumi1, sumi2));
+ // To avoid overflow, we can only add up to 4 q5 x q8 products.
+ auto sumi = _mm256_add_epi32(_mm256_madd_epi16(scales_d, sumi1), _mm256_madd_epi16(scales_d, sumi2));
+ isum[iy] = _mm256_add_epi32(isum[iy], sumi);
#endif
- if constexpr (nrc_y == 1) {
- acc[iy] = _mm256_fmadd_ps(scales_d, _mm256_cvtepi32_ps(sumi), acc[iy]);
- } else {
- float d8 = q8.scale(iy, ibl);
- acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(scales_d, _mm256_set1_ps(d8)), _mm256_cvtepi32_ps(sumi), acc[iy]);
- }
}
}
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(isum[iy]), acc[iy]);
+ isum[iy] = _mm256_setzero_si256();
+ }
}
for (int iy = 0; iy < nrc_y; ++iy) {
auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1));
@@ -4690,105 +4602,6 @@ static void mul_mat_q5_k_r4_q8_k_avx2(int n, const void * vx, size_t bx, const D
}
}
-#ifdef HAVE_FANCY_SIMD
-template <int nrc_y>
-static void mul_mat_q5_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
- if constexpr (nrc_y == 1){
- mul_mat_q5_k_r4_q8_k_avx2<1>(n, vx, bx, info, nrc_x);
- } else {
- GGML_ASSERT(nrc_x%8 == 0);
- Q8<nrc_y, block_q8_K> q8(info);
- auto mf = _mm512_set1_epi8(0xf);
- auto m10 = _mm512_set1_epi8(0x10);
- int nbl = n / QK_K;
- using helper_t = union { __m512i vec; uint32_t val[16]; };
- helper_t hd, hm;
- __m512 acc[nrc_y] = {};
- __m512i isum[nrc_y] = {};
- __m512i qx[4];
- for (int ix = 0; ix < nrc_x; ix += 8) {
- const block_q5_k_r4 * iq5l = (const block_q5_k_r4 *)((const char *)vx + (ix+0)*bx);
- const block_q5_k_r4 * iq5h = (const block_q5_k_r4 *)((const char *)vx + (ix+4)*bx);
- for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256
- auto d1 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq5l[ibl].d));
- auto d2 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq5h[ibl].d));
- auto dl = _mm256_castps256_ps128(d1);
- auto ml = _mm256_extractf128_ps(d1, 1);
- auto dh = _mm256_castps256_ps128(d2);
- auto mh = _mm256_extractf128_ps(d2, 1);
- auto d4 = _mm512_insertf32x8(_mm512_castps256_ps512(_mm256_set_m128(dl, dl)), _mm256_set_m128(dh, dh), 1);
- auto m4 = _mm512_insertf32x8(_mm512_castps256_ps512(_mm256_set_m128(ml, ml)), _mm256_set_m128(mh, mh), 1);
- m4 = _mm512_mul_ps(m4, _mm512_set1_ps(-0.5f));
- auto slbits_l = _mm256_loadu_si256((const __m256i *)iq5l[ibl].scales_l);
- auto shbits_l = _mm256_loadu_si256((const __m256i *)iq5h[ibl].scales_l);
- auto slb = _mm512_inserti32x8(_mm512_castsi256_si512(slbits_l), shbits_l, 1);
- auto sld = _mm512_and_si512(slb, mf);
- auto slm = _mm512_and_si512(_mm512_srli_epi16(slb, 4), mf);
- auto slbits_h = _mm_loadu_si128((const __m128i *)iq5l[ibl].scales_h);
- auto shbits_h = _mm_loadu_si128((const __m128i *)iq5h[ibl].scales_h);
- auto slbits_h2 = MM256_SET_M128I(_mm_srli_epi16(slbits_h, 4), slbits_h);
- auto shbits_h2 = MM256_SET_M128I(_mm_srli_epi16(shbits_h, 4), shbits_h);
- auto shb = _mm512_inserti32x8(_mm512_castsi256_si512(slbits_h2), shbits_h2, 1);
- auto shd = _mm512_and_si512(_mm512_slli_epi16(shb, 4), _mm512_set1_epi8(0x30));
- auto shm = _mm512_and_si512(_mm512_slli_epi16(shb, 2), _mm512_set1_epi8(0x30));
- hd.vec = _mm512_or_si512(sld, shd);
- hm.vec = _mm512_or_si512(slm, shm);
- for (int ib = 0; ib < QK_K/32; ++ib) {
- auto scales1 = _mm256_cvtepi8_epi32(_mm_set1_epi32(hd.val[ib+0]));
- auto scales2 = _mm256_cvtepi8_epi32(_mm_set1_epi32(hd.val[ib+8]));
- auto iscales = _mm512_inserti32x8(_mm512_castsi256_si512(scales1), scales2, 1);
- scales1 = _mm256_cvtepi8_epi32(_mm_set1_epi32(hm.val[ib+0]));
- scales2 = _mm256_cvtepi8_epi32(_mm_set1_epi32(hm.val[ib+8]));
- auto iscales_m = _mm512_inserti32x8(_mm512_castsi256_si512(scales1), scales2, 1);
- auto scales_m = _mm512_mul_ps(m4, _mm512_cvtepi32_ps(iscales_m));
- auto lbits1 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq5l[ibl].qs+2*ib+0)),
- _mm256_loadu_si256((const __m256i *)iq5h[ibl].qs+2*ib+0), 1);
- auto lbits2 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq5l[ibl].qs+2*ib+1)),
- _mm256_loadu_si256((const __m256i *)iq5h[ibl].qs+2*ib+1), 1);
- auto hbits1 = _mm_loadu_si128((const __m128i*)iq5l[ibl].qh+ib);
- auto hbits2 = _mm_loadu_si128((const __m128i*)iq5h[ibl].qh+ib);
- auto hbl = MM256_SET_M128I(hbits1, _mm_slli_epi16(hbits1, 4));
- auto hbh = MM256_SET_M128I(hbits2, _mm_slli_epi16(hbits2, 4));
- auto hbits = _mm512_inserti32x8(_mm512_castsi256_si512(hbl), hbh, 1);
- qx[0] = _mm512_or_si512(_mm512_and_si512(lbits1, mf), _mm512_and_si512(m10, hbits));
- qx[1] = _mm512_or_si512(_mm512_and_si512(lbits2, mf), _mm512_and_si512(m10, _mm512_srli_epi16(hbits, 2)));
- qx[2] = _mm512_or_si512(_mm512_and_si512(_mm512_srli_epi16(lbits1, 4), mf), _mm512_and_si512(m10, _mm512_srli_epi16(hbits, 1)));
- qx[3] = _mm512_or_si512(_mm512_and_si512(_mm512_srli_epi16(lbits2, 4), mf), _mm512_and_si512(m10, _mm512_srli_epi16(hbits, 3)));
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto y8 = _mm256_loadu_si256((const __m256i*)q8.y[iy][ibl].qs+ib);
- auto y = _mm512_inserti32x8(_mm512_castsi256_si512(y8), y8, 1);
- auto sumi = _mm512_setzero_si512();
- sumi = _mm512_dpbusd_epi32(sumi, qx[0], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00)));
- sumi = _mm512_dpbusd_epi32(sumi, qx[1], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55)));
- sumi = _mm512_dpbusd_epi32(sumi, qx[2], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa)));
- sumi = _mm512_dpbusd_epi32(sumi, qx[3], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff)));
- isum[iy] = _mm512_add_epi32(isum[iy], _mm512_mullo_epi32(iscales, sumi));
- float m8 = ((const float *)q8.y[iy][ibl].bsums)[ib];
- acc[iy] = _mm512_fmadd_ps(scales_m, _mm512_set1_ps(m8), acc[iy]);
- }
- }
- for (int iy = 0; iy < nrc_y; ++iy) {
- acc[iy] = _mm512_fmadd_ps(_mm512_mul_ps(d4, _mm512_set1_ps(q8.scale(iy, ibl))), _mm512_cvtepi32_ps(isum[iy]), acc[iy]);
- isum[iy] = _mm512_setzero_si512();
- }
- }
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto sum1 = _mm_add_ps(_mm512_extractf32x4_ps(acc[iy], 0), _mm512_extractf32x4_ps(acc[iy], 1));
- auto sum2 = _mm_add_ps(_mm512_extractf32x4_ps(acc[iy], 2), _mm512_extractf32x4_ps(acc[iy], 3));
- info.store(ix+0, iy, sum1);
- info.store(ix+4, iy, sum2);
- acc[iy] = _mm512_setzero_ps();
- }
- }
- }
-}
-#else
-template <int nrc_y>
-static void mul_mat_q5_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
- mul_mat_q5_k_r4_q8_k_avx2<nrc_y>(n, vx, bx, info, nrc_x);
-}
-#endif
-
template <int nrc_y>
static void mul_mat_q2_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(nrc_x%4 == 0);