Diffstat (limited to 'ggml/src')
-rw-r--r--  ggml/src/iqk/iqk_mul_mat.cpp  342
1 file changed, 342 insertions, 0 deletions
diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp
index 081ebb57..f8c876ae 100644
--- a/ggml/src/iqk/iqk_mul_mat.cpp
+++ b/ggml/src/iqk/iqk_mul_mat.cpp
@@ -119,12 +119,27 @@ typedef void (*mul_mat_t)(int n, const void * vx, size_t bx, const DataInfo& inf
struct MulMat {
std::array<mul_mat_t, 8> funcs = {};
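+ // Optional kernel that processes 16 y columns per call; used by mul_mat_NxM() below when it is set.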
+ mul_mat_t func16 = nullptr;
inline void mul_mat_NxM(int n, const void * vx, size_t bx, DataInfo& info, int nrc_x, int nrc_y) {
#ifdef __aarch64__
constexpr int k_x_step = 64; //8192; // Tiling does not seem to help on my M2 Max (the difference with and without tiling is small)
#else
constexpr int k_x_step = 64; // This works best on my Ryzen-7950X (but differences to other tile sizes are small)
#endif
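+ // If a 16-column kernel is available and at least 16 y columns remain, process them in chunks of 16, tiled over x in steps of k_x_step; any remainder falls through to the generic path below.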
+ if (func16 && nrc_y >= 16) {
+ int n_step = (nrc_y - info.cur_y)/16;
+ for (int ix = 0; ix < nrc_x; ix += k_x_step) {
+ auto this_info = info;
+ this_info.s += ix;
+ int this_nrc_x = ix + k_x_step <= nrc_x ? k_x_step : nrc_x - ix;
+ for (int iy = 0; iy < n_step; ++iy) {
+ func16(n, (const void *)((const char *)vx + ix*bx), bx, this_info, this_nrc_x);
+ this_info.cur_y += 16;
+ }
+ }
+ info.cur_y += 16 * n_step;
+ if (info.cur_y == nrc_y) return;
+ }
int ny = funcs.size();
while (ny > 0 && !funcs[ny-1]) --ny;
int n_step = (nrc_y - info.cur_y)/ny;
@@ -3425,6 +3440,165 @@ static void mul_mat_iq2_xs_r4_q8_k(int n, const void * vx, size_t bx, const Data
}
}
+static void mul_mat_iq2_xs_r4_q8_k_16(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ GGML_ASSERT(nrc_x%4 == 0);
+ constexpr int nrc_y = 16;
+ Q8<nrc_y, block_q8_K> q8(info);
+ int nbl = n / QK_K;
+#ifndef HAVE_FANCY_SIMD
+ auto smask = _mm256_set1_epi64x(0x8040201008040201);
+ auto sign_shuffle = _mm256_set_epi64x(0x0303030303030303, 0x0202020202020202, 0x0101010101010101, 0x0000000000000000);
+ auto m4 = _mm256_set1_epi8(4);
+#endif
+ __m256 acc[nrc_y] = {};
+#ifdef HAVE_FANCY_SIMD
+ __m256i shuffles[2] = {
+ _mm256_set_epi64x(0x0706070607060706, 0x0302030203020302, 0x0504050405040504, 0x0100010001000100),
+ _mm256_set_epi64x(0x0f0e0f0e0f0e0f0e, 0x0b0a0b0a0b0a0b0a, 0x0d0c0d0c0d0c0d0c, 0x0908090809080908)
+ };
+ __m256i isum[2*nrc_y] = {};
+#else
+ __m256i shuffles[4] = {
+ MM256_SET_M128I(_mm_set1_epi16(0x0302), _mm_set1_epi16(0x0100)),
+ MM256_SET_M128I(_mm_set1_epi16(0x0706), _mm_set1_epi16(0x0504)),
+ MM256_SET_M128I(_mm_set1_epi16(0x0b0a), _mm_set1_epi16(0x0908)),
+ MM256_SET_M128I(_mm_set1_epi16(0x0f0e), _mm_set1_epi16(0x0d0c)),
+ };
+ __m256i isum[nrc_y == 1 ? 4 : nrc_y] = {};
+#endif
+ auto s_shuffle = _mm_set_epi64x(0x0f0d0b0907050301, 0x0e0c0a0806040200);
+ __m256i qx[4];
+ union { __m256i vec; uint16_t val[16]; } helper;
+ for (int ix = 0; ix < nrc_x; ix += 4) {
+ auto iq2 = (const block_iq2_xs_r4 *)((const char *)vx + (ix+0)*bx);
+ for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256
+ auto dl = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq2[ibl].d));
+ auto d4 = _mm256_set_m128(dl, dl);
+ auto s32 = (const uint32_t *)iq2[ibl].scales;
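+ // The grid values are biased by +64 further down so that unsigned dot products can be used; this pre-pass subtracts 64 * (q8 block sums) * (block scales) from the accumulators to compensate.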
+ {
+ auto scale_bits = _mm256_loadu_si256((const __m256i *)iq2[ibl].scales);
+ auto scales1 = _mm256_and_si256(scale_bits, _mm256_set1_epi8(0xf));
+ auto scales2 = _mm256_and_si256(_mm256_srli_epi16(scale_bits, 4), _mm256_set1_epi8(0xf));
+ scales1 = _mm256_or_si256(_mm256_slli_epi16(scales1, 1), _mm256_set1_epi8(1));
+ scales2 = _mm256_or_si256(_mm256_slli_epi16(scales2, 1), _mm256_set1_epi8(1));
+ auto s1_8 = _mm256_unpacklo_epi8(scales1, scales2); // blocks 0...15, 32...47 (0...3, 8...11 from each row)
+ auto s2_8 = _mm256_unpackhi_epi8(scales1, scales2); // blocks 16...31, 48...63 (4...7, 12...15 from each row)
+ auto s1_16 = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(s1_8)); // 0...15 (0...3 from each row)
+ auto s2_16 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(s1_8, 1)); // 32...47 (8...11 from each row)
+ auto s3_16 = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(s2_8)); // 16...31 (4...7 from each row)
+ auto s4_16 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(s2_8, 1)); // 48...63 (12...15 from each row)
+ auto t1 = MM256_SET_M128I(_mm256_castsi256_si128(s2_16), _mm256_castsi256_si128(s1_16)); // 0,1 and 8,9 from each row
+ auto t2 = MM256_SET_M128I(_mm256_extracti128_si256(s2_16, 1), _mm256_extracti128_si256(s1_16, 1)); // 2,3 and 10,11 from each row
+ auto t3 = MM256_SET_M128I(_mm256_castsi256_si128(s4_16), _mm256_castsi256_si128(s3_16)); // 4,5 and 12,13 from each row
+ auto t4 = MM256_SET_M128I(_mm256_extracti128_si256(s4_16, 1), _mm256_extracti128_si256(s3_16, 1)); // 6,7 and 14,15 from each row
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto bsums = q8.load_bsums(iy, ibl);
+ auto sumi = _mm256_setzero_si256();
+#ifdef HAVE_FANCY_SIMD
+ sumi = _mm256_dpwssd_epi32(sumi, t1, _mm256_shuffle_epi32(bsums, 0x00));
+ sumi = _mm256_dpwssd_epi32(sumi, t2, _mm256_shuffle_epi32(bsums, 0x55));
+ sumi = _mm256_dpwssd_epi32(sumi, t3, _mm256_shuffle_epi32(bsums, 0xaa));
+ sumi = _mm256_dpwssd_epi32(sumi, t4, _mm256_shuffle_epi32(bsums, 0xff));
+#else
+ sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(t1, _mm256_shuffle_epi32(bsums, 0x00)));
+ sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(t2, _mm256_shuffle_epi32(bsums, 0x55)));
+ sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(t3, _mm256_shuffle_epi32(bsums, 0xaa)));
+ sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(t4, _mm256_shuffle_epi32(bsums, 0xff)));
+#endif
+ acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(d4, _mm256_set1_ps(-64.f*q8.scale(iy, ibl))), _mm256_cvtepi32_ps(sumi), acc[iy]);
+ }
+ }
+ for (int ib = 0; ib < QK_K/32; ++ib) {
+ auto val = _mm256_loadu_si256((const __m256i *)iq2[ibl].qs + ib);
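+ // Each 16-bit value packs a 9-bit iq2xs_grid index in the low bits and 7 sign bits in the high bits.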
+ helper.vec = _mm256_and_si256(val, _mm256_set1_epi16(511));
+ qx[0] = _mm256_set_epi64x(iq2xs_grid[helper.val[ 3]], iq2xs_grid[helper.val[ 2]], iq2xs_grid[helper.val[ 1]], iq2xs_grid[helper.val[ 0]]);
+ qx[1] = _mm256_set_epi64x(iq2xs_grid[helper.val[ 7]], iq2xs_grid[helper.val[ 6]], iq2xs_grid[helper.val[ 5]], iq2xs_grid[helper.val[ 4]]);
+ qx[2] = _mm256_set_epi64x(iq2xs_grid[helper.val[11]], iq2xs_grid[helper.val[10]], iq2xs_grid[helper.val[ 9]], iq2xs_grid[helper.val[ 8]]);
+ qx[3] = _mm256_set_epi64x(iq2xs_grid[helper.val[15]], iq2xs_grid[helper.val[14]], iq2xs_grid[helper.val[13]], iq2xs_grid[helper.val[12]]);
+ auto signs16 = _mm256_srli_epi16(val, 9);
+ signs16 = _mm256_xor_si256(signs16, _mm256_slli_epi16(signs16, 1));
+ auto signs128 = _mm_or_si128(_mm256_castsi256_si128(signs16), _mm_slli_epi16(_mm256_extracti128_si256(signs16, 1), 8));
+ signs128 = _mm_shuffle_epi8(signs128, s_shuffle);
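+ // s32[ib] broadcasts the four scale bytes of this 32-quant block (one per interleaved row); each byte packs two 4-bit sub-block scales, expanded here to 2*s+1, with the remaining 1/8 factor applied at the store.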
+ auto scales = _mm_set1_epi32(s32[ib]);
+ scales = _mm_and_si128(_mm_unpacklo_epi8(scales, _mm_srli_epi16(scales, 4)), _mm_set1_epi8(0xf));
+ scales = _mm_or_si128(_mm_slli_epi16(scales, 1), _mm_set1_epi8(1));
+ auto scales16 = _mm256_cvtepi8_epi16(scales); // 0...7, 0...7
+#ifdef HAVE_FANCY_SIMD
+ __m256i scs[2] = { _mm256_shuffle_epi8(scales16, shuffles[0]), _mm256_shuffle_epi8(scales16, shuffles[1]) };
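+ // signs128 carries one sign bit per quant byte; reinterpret it as four 32-bit masks and negate the selected bytes via masked subtraction from zero before re-biasing by +64.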
+ auto mask = (const __mmask32 *)&signs128;
+ qx[0] = _mm256_add_epi8(_mm256_set1_epi8(64), _mm256_mask_sub_epi8(qx[0], mask[0], _mm256_setzero_si256(), qx[0]));
+ qx[1] = _mm256_add_epi8(_mm256_set1_epi8(64), _mm256_mask_sub_epi8(qx[1], mask[1], _mm256_setzero_si256(), qx[1]));
+ qx[2] = _mm256_add_epi8(_mm256_set1_epi8(64), _mm256_mask_sub_epi8(qx[2], mask[2], _mm256_setzero_si256(), qx[2]));
+ qx[3] = _mm256_add_epi8(_mm256_set1_epi8(64), _mm256_mask_sub_epi8(qx[3], mask[3], _mm256_setzero_si256(), qx[3]));
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto y = _mm256_loadu_si256((const __m256i *)q8.y[iy][ibl].qs + ib);
+ auto sumi1 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), qx[0], y); // blocks: 0,0,0,0, 1,1,1,1, row 0
+ auto sumi2 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), qx[1], y); // blocks: 2,2,2,2, 3,3,3,3, row 1
+ auto sumi3 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), qx[2], y); // blocks: 4,4,4,4, 5,5,5,5, row 2
+ auto sumi4 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), qx[3], y); // blocks: 6,6,6,6, 7,7,7,7, row 3
+ auto s12 = _mm256_packs_epi32(sumi1, sumi2); // 0,0,0,0, 2,2,2,2, 1,1,1,1, 3,3,3,3
+ auto s34 = _mm256_packs_epi32(sumi3, sumi4); // 4,4,4,4, 6,6,6,6, 5,5,5,5, 7,7,7,7
+ isum[2*iy+0] = _mm256_add_epi32(isum[2*iy+0], _mm256_madd_epi16(scs[0], s12));
+ isum[2*iy+1] = _mm256_add_epi32(isum[2*iy+1], _mm256_madd_epi16(scs[1], s34));
+ }
+#else
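+ // Without AVX-512 mask ops: expand each packed sign byte to 8 byte lanes, test the bits against smask, and build +/-1 multipliers for _mm256_sign_epi8.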
+ auto signs = MM256_SET_M128I(signs128, signs128);
+ auto shuffle = sign_shuffle;
+ auto s = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(signs, shuffle), smask), smask), _mm256_set1_epi8(1));
+ shuffle = _mm256_add_epi8(shuffle, m4);
+ qx[0] = _mm256_add_epi8(_mm256_set1_epi8(64), _mm256_sign_epi8(qx[0], s));
+ s = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(signs, shuffle), smask), smask), _mm256_set1_epi8(1));
+ shuffle = _mm256_add_epi8(shuffle, m4);
+ qx[1] = _mm256_add_epi8(_mm256_set1_epi8(64), _mm256_sign_epi8(qx[1], s));
+ s = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(signs, shuffle), smask), smask), _mm256_set1_epi8(1));
+ shuffle = _mm256_add_epi8(shuffle, m4);
+ qx[2] = _mm256_add_epi8(_mm256_set1_epi8(64), _mm256_sign_epi8(qx[2], s));
+ s = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(signs, shuffle), smask), smask), _mm256_set1_epi8(1));
+ qx[3] = _mm256_add_epi8(_mm256_set1_epi8(64), _mm256_sign_epi8(qx[3], s));
+ __m256i scs[4] = {
+ _mm256_shuffle_epi8(scales16, shuffles[0]), _mm256_shuffle_epi8(scales16, shuffles[1]),
+ _mm256_shuffle_epi8(scales16, shuffles[2]), _mm256_shuffle_epi8(scales16, shuffles[3]),
+ };
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto y = _mm256_loadu_si256((const __m256i *)q8.y[iy][ibl].qs + ib);
+ auto sumi1 = _mm256_madd_epi16(scs[0], _mm256_maddubs_epi16(qx[0], y)); // blocks 4x0, 4x1, row 0
+ auto sumi2 = _mm256_madd_epi16(scs[1], _mm256_maddubs_epi16(qx[1], y)); // blocks 4x2, 4x3, row 1
+ auto sumi3 = _mm256_madd_epi16(scs[2], _mm256_maddubs_epi16(qx[2], y)); // blocks 4x4, 4x5, row 2
+ auto sumi4 = _mm256_madd_epi16(scs[3], _mm256_maddubs_epi16(qx[3], y)); // blocks 4x6, 4x7, row 3
+ auto s12 = _mm256_add_epi32(_mm256_unpacklo_epi32(sumi1, sumi2), _mm256_unpackhi_epi32(sumi1, sumi2)); // 0,1, 0,1, 0,1, 0,1
+ auto s34 = _mm256_add_epi32(_mm256_unpacklo_epi32(sumi3, sumi4), _mm256_unpackhi_epi32(sumi3, sumi4)); // 2,3, 2,3, 2,3, 2,3
+ auto sumi = _mm256_add_epi32(_mm256_unpacklo_epi64(s12, s34), _mm256_unpackhi_epi64(s12, s34)); // 0,1,2,3, 0,1,2,3
+ isum[iy] = _mm256_add_epi32(isum[iy], sumi);
+ }
+#endif
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+#ifdef HAVE_FANCY_SIMD
+ auto sumi = _mm256_hadd_epi32(isum[2*iy+0], isum[2*iy+1]);
+ acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(sumi), acc[iy]);
+ isum[2*iy+0] = isum[2*iy+1] = _mm256_setzero_si256();
+#else
+ if constexpr (nrc_y == 1) {
+ auto s12 = _mm256_add_epi32(_mm256_unpacklo_epi32(isum[0], isum[1]), _mm256_unpackhi_epi32(isum[0], isum[1]));
+ auto s34 = _mm256_add_epi32(_mm256_unpacklo_epi32(isum[2], isum[3]), _mm256_unpackhi_epi32(isum[2], isum[3]));
+ auto sumi = _mm256_add_epi32(_mm256_unpacklo_epi64(s12, s34), _mm256_unpackhi_epi64(s12, s34));
+ acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(sumi), acc[iy]);
+ isum[0] = isum[1] = isum[2] = isum[3] = _mm256_setzero_si256();
+ } else {
+ acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(isum[iy]), acc[iy]);
+ isum[iy] = _mm256_setzero_si256();
+ }
+#endif
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1));
+ info.store(ix, iy, _mm_mul_ps(_mm_set1_ps(0.125f), sum));
+ acc[iy] = _mm256_setzero_ps();
+ }
+ }
+}
+
template <int nrc_y>
static void mul_mat_iq2_s_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(nrc_x%4 == 0);
@@ -3547,6 +3721,154 @@ static void mul_mat_iq2_s_r4_q8_k(int n, const void * vx, size_t bx, const DataI
}
}
+static void mul_mat_iq2_s_r4_q8_k_16(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ GGML_ASSERT(nrc_x%4 == 0);
+ constexpr int nrc_y = 16;
+ Q8<nrc_y, block_q8_K> q8(info);
+ int nbl = n / QK_K;
+#ifndef HAVE_FANCY_SIMD
+ auto smask = _mm256_set1_epi64x(0x8040201008040201);
+ auto sign_shuffle = _mm256_set_epi64x(0x0303030303030303, 0x0202020202020202, 0x0101010101010101, 0x0000000000000000);
+ auto m4 = _mm256_set1_epi8(4);
+#endif
+ __m256 acc[nrc_y] = {};
+#ifdef HAVE_FANCY_SIMD
+ __m256i shuffles[2] = {
+ _mm256_set_epi64x(0x0706070607060706, 0x0302030203020302, 0x0504050405040504, 0x0100010001000100),
+ _mm256_set_epi64x(0x0f0e0f0e0f0e0f0e, 0x0b0a0b0a0b0a0b0a, 0x0d0c0d0c0d0c0d0c, 0x0908090809080908)
+ };
+ __m256i isum[2*nrc_y] = {};
+#else
+ __m256i shuffles[4] = {
+ MM256_SET_M128I(_mm_set1_epi16(0x0302), _mm_set1_epi16(0x0100)),
+ MM256_SET_M128I(_mm_set1_epi16(0x0706), _mm_set1_epi16(0x0504)),
+ MM256_SET_M128I(_mm_set1_epi16(0x0b0a), _mm_set1_epi16(0x0908)),
+ MM256_SET_M128I(_mm_set1_epi16(0x0f0e), _mm_set1_epi16(0x0d0c)),
+ };
+ __m256i isum[nrc_y == 1 ? 4 : nrc_y] = {};
+#endif
+ __m256i qx[4];
+ auto grid = iq2s_grid;
+ for (int ix = 0; ix < nrc_x; ix += 4) {
+ auto iq2 = (const block_iq2_s_r4 *)((const char *)vx + (ix+0)*bx);
+ for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256
+ auto dl = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq2[ibl].d));
+ auto d4 = _mm256_set_m128(dl, dl);
+ auto s32 = (const uint32_t *)iq2[ibl].scales;
+ auto ql = iq2[ibl].qs;
+ auto qh = iq2[ibl].qh;
+ {
+ auto scale_bits = _mm256_loadu_si256((const __m256i *)iq2[ibl].scales);
+ auto scales1 = _mm256_and_si256(scale_bits, _mm256_set1_epi8(0xf));
+ auto scales2 = _mm256_and_si256(_mm256_srli_epi16(scale_bits, 4), _mm256_set1_epi8(0xf));
+ scales1 = _mm256_or_si256(_mm256_slli_epi16(scales1, 1), _mm256_set1_epi8(1));
+ scales2 = _mm256_or_si256(_mm256_slli_epi16(scales2, 1), _mm256_set1_epi8(1));
+ auto s1_8 = _mm256_unpacklo_epi8(scales1, scales2); // blocks 0...15, 32...47 (0...3, 8...11 from each row)
+ auto s2_8 = _mm256_unpackhi_epi8(scales1, scales2); // blocks 16...31, 48...63 (4...7, 12...15 from each row)
+ auto s1_16 = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(s1_8)); // 0...15 (0...3 from each row)
+ auto s2_16 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(s1_8, 1)); // 32...47 (8...11 from each row)
+ auto s3_16 = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(s2_8)); // 16...31 (4...7 from each row)
+ auto s4_16 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(s2_8, 1)); // 48...63 (12...15 from each row)
+ auto t1 = MM256_SET_M128I(_mm256_castsi256_si128(s2_16), _mm256_castsi256_si128(s1_16)); // 0,1 and 8,9 from each row
+ auto t2 = MM256_SET_M128I(_mm256_extracti128_si256(s2_16, 1), _mm256_extracti128_si256(s1_16, 1)); // 2,3 and 10,11 from each row
+ auto t3 = MM256_SET_M128I(_mm256_castsi256_si128(s4_16), _mm256_castsi256_si128(s3_16)); // 4,5 and 12,13 from each row
+ auto t4 = MM256_SET_M128I(_mm256_extracti128_si256(s4_16, 1), _mm256_extracti128_si256(s3_16, 1)); // 6,7 and 14,15 from each row
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto bsums = q8.load_bsums(iy, ibl);
+ auto sumi = _mm256_setzero_si256();
+#ifdef HAVE_FANCY_SIMD
+ sumi = _mm256_dpwssd_epi32(sumi, t1, _mm256_shuffle_epi32(bsums, 0x00));
+ sumi = _mm256_dpwssd_epi32(sumi, t2, _mm256_shuffle_epi32(bsums, 0x55));
+ sumi = _mm256_dpwssd_epi32(sumi, t3, _mm256_shuffle_epi32(bsums, 0xaa));
+ sumi = _mm256_dpwssd_epi32(sumi, t4, _mm256_shuffle_epi32(bsums, 0xff));
+#else
+ sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(t1, _mm256_shuffle_epi32(bsums, 0x00)));
+ sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(t2, _mm256_shuffle_epi32(bsums, 0x55)));
+ sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(t3, _mm256_shuffle_epi32(bsums, 0xaa)));
+ sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(t4, _mm256_shuffle_epi32(bsums, 0xff)));
+#endif
+ acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(d4, _mm256_set1_ps(-64.f*q8.scale(iy, ibl))), _mm256_cvtepi32_ps(sumi), acc[iy]);
+ }
+ }
+ for (int ib = 0; ib < QK_K/32; ++ib) {
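+ // iq2_s stores 8 low index bits in ql and 2 high bits per value packed four-to-a-byte in qh; the sign masks come separately from iq2[ibl].signs.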
+ qx[0] = _mm256_set_epi64x(grid[ql[ 3] | ((qh[0] << 2) & 0x300)], grid[ql[ 2] | ((qh[0] << 4) & 0x300)], grid[ql[ 1] | ((qh[0] << 6) & 0x300)], grid[ql[ 0] | ((qh[0] << 8) & 0x300)]);
+ qx[1] = _mm256_set_epi64x(grid[ql[ 7] | ((qh[1] << 2) & 0x300)], grid[ql[ 6] | ((qh[1] << 4) & 0x300)], grid[ql[ 5] | ((qh[1] << 6) & 0x300)], grid[ql[ 4] | ((qh[1] << 8) & 0x300)]);
+ qx[2] = _mm256_set_epi64x(grid[ql[11] | ((qh[2] << 2) & 0x300)], grid[ql[10] | ((qh[2] << 4) & 0x300)], grid[ql[ 9] | ((qh[2] << 6) & 0x300)], grid[ql[ 8] | ((qh[2] << 8) & 0x300)]);
+ qx[3] = _mm256_set_epi64x(grid[ql[15] | ((qh[3] << 2) & 0x300)], grid[ql[14] | ((qh[3] << 4) & 0x300)], grid[ql[13] | ((qh[3] << 6) & 0x300)], grid[ql[12] | ((qh[3] << 8) & 0x300)]);
+ ql += 16; qh += 4;
+ auto signs128 = _mm_loadu_si128((const __m128i*)iq2[ibl].signs + ib);
+ auto scales = _mm_set1_epi32(s32[ib]);
+ scales = _mm_and_si128(_mm_unpacklo_epi8(scales, _mm_srli_epi16(scales, 4)), _mm_set1_epi8(0xf));
+ scales = _mm_or_si128(_mm_slli_epi16(scales, 1), _mm_set1_epi8(1));
+ auto scales16 = _mm256_cvtepi8_epi16(scales); // 0...7, 0...7
+#ifdef HAVE_FANCY_SIMD
+ __m256i scs[2] = { _mm256_shuffle_epi8(scales16, shuffles[0]), _mm256_shuffle_epi8(scales16, shuffles[1]) };
+ auto mask = (const __mmask32 *)&signs128;
+ qx[0] = _mm256_add_epi8(_mm256_set1_epi8(64), _mm256_mask_sub_epi8(qx[0], mask[0], _mm256_setzero_si256(), qx[0]));
+ qx[1] = _mm256_add_epi8(_mm256_set1_epi8(64), _mm256_mask_sub_epi8(qx[1], mask[1], _mm256_setzero_si256(), qx[1]));
+ qx[2] = _mm256_add_epi8(_mm256_set1_epi8(64), _mm256_mask_sub_epi8(qx[2], mask[2], _mm256_setzero_si256(), qx[2]));
+ qx[3] = _mm256_add_epi8(_mm256_set1_epi8(64), _mm256_mask_sub_epi8(qx[3], mask[3], _mm256_setzero_si256(), qx[3]));
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto y = _mm256_loadu_si256((const __m256i *)q8.y[iy][ibl].qs + ib);
+ auto sumi1 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), qx[0], y); // blocks: 0,0,0,0, 1,1,1,1, row 0
+ auto sumi2 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), qx[1], y); // blocks: 2,2,2,2, 3,3,3,3, row 1
+ auto sumi3 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), qx[2], y); // blocks: 4,4,4,4, 5,5,5,5, row 2
+ auto sumi4 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), qx[3], y); // blocks: 6,6,6,6, 7,7,7,7, row 3
+ auto s12 = _mm256_packs_epi32(sumi1, sumi2); // 0,0,0,0, 2,2,2,2, 1,1,1,1, 3,3,3,3
+ auto s34 = _mm256_packs_epi32(sumi3, sumi4); // 4,4,4,4, 6,6,6,6, 5,5,5,5, 7,7,7,7
+ isum[2*iy+0] = _mm256_add_epi32(isum[2*iy+0], _mm256_madd_epi16(scs[0], s12));
+ isum[2*iy+1] = _mm256_add_epi32(isum[2*iy+1], _mm256_madd_epi16(scs[1], s34));
+ }
+#else
+ auto signs = MM256_SET_M128I(signs128, signs128);
+ auto shuffle = sign_shuffle;
+ auto s = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(signs, shuffle), smask), smask), _mm256_set1_epi8(1));
+ shuffle = _mm256_add_epi8(shuffle, m4);
+ qx[0] = _mm256_add_epi8(_mm256_set1_epi8(64), _mm256_sign_epi8(qx[0], s));
+ s = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(signs, shuffle), smask), smask), _mm256_set1_epi8(1));
+ shuffle = _mm256_add_epi8(shuffle, m4);
+ qx[1] = _mm256_add_epi8(_mm256_set1_epi8(64), _mm256_sign_epi8(qx[1], s));
+ s = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(signs, shuffle), smask), smask), _mm256_set1_epi8(1));
+ shuffle = _mm256_add_epi8(shuffle, m4);
+ qx[2] = _mm256_add_epi8(_mm256_set1_epi8(64), _mm256_sign_epi8(qx[2], s));
+ s = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(signs, shuffle), smask), smask), _mm256_set1_epi8(1));
+ qx[3] = _mm256_add_epi8(_mm256_set1_epi8(64), _mm256_sign_epi8(qx[3], s));
+ __m256i scs[4] = {
+ _mm256_shuffle_epi8(scales16, shuffles[0]), _mm256_shuffle_epi8(scales16, shuffles[1]),
+ _mm256_shuffle_epi8(scales16, shuffles[2]), _mm256_shuffle_epi8(scales16, shuffles[3]),
+ };
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto y = _mm256_loadu_si256((const __m256i *)q8.y[iy][ibl].qs + ib);
+ auto sumi1 = _mm256_madd_epi16(scs[0], _mm256_maddubs_epi16(qx[0], y)); // blocks 4x0, 4x1, row 0
+ auto sumi2 = _mm256_madd_epi16(scs[1], _mm256_maddubs_epi16(qx[1], y)); // blocks 4x2, 4x3, row 1
+ auto sumi3 = _mm256_madd_epi16(scs[2], _mm256_maddubs_epi16(qx[2], y)); // blocks 4x4, 4x5, row 2
+ auto sumi4 = _mm256_madd_epi16(scs[3], _mm256_maddubs_epi16(qx[3], y)); // blocks 4x6, 4x7, row 3
+ auto s12 = _mm256_add_epi32(_mm256_unpacklo_epi32(sumi1, sumi2), _mm256_unpackhi_epi32(sumi1, sumi2)); // 0,1, 0,1, 0,1, 0,1
+ auto s34 = _mm256_add_epi32(_mm256_unpacklo_epi32(sumi3, sumi4), _mm256_unpackhi_epi32(sumi3, sumi4)); // 2,3, 2,3, 2,3, 2,3
+ auto sumi = _mm256_add_epi32(_mm256_unpacklo_epi64(s12, s34), _mm256_unpackhi_epi64(s12, s34)); // 0,1,2,3, 0,1,2,3
+ isum[iy] = _mm256_add_epi32(isum[iy], sumi);
+ }
+#endif
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+#ifdef HAVE_FANCY_SIMD
+ auto sumi = _mm256_hadd_epi32(isum[2*iy+0], isum[2*iy+1]);
+ acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(sumi), acc[iy]);
+ isum[2*iy+0] = isum[2*iy+1] = _mm256_setzero_si256();
+#else
+ acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(isum[iy]), acc[iy]);
+ isum[iy] = _mm256_setzero_si256();
+#endif
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1));
+ info.store(ix, iy, _mm_mul_ps(_mm_set1_ps(0.125f), sum));
+ acc[iy] = _mm256_setzero_ps();
+ }
+ }
+}
+
template <int nrc_y>
static void mul_mat_iq3_xxs_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(nrc_x%4 == 0);
@@ -7034,6 +7356,10 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) {
mm.funcs[5] = mul_mat_iq4_ks_r4_q8_k<6>;
mm.funcs[6] = mul_mat_iq4_ks_r4_q8_k<7>;
mm.funcs[7] = mul_mat_iq4_ks_r4_q8_k<8>;
+#ifndef HAVE_FANCY_SIMD
+ // For some reason this function is slower on Zen4, so it is only enabled when HAVE_FANCY_SIMD is not defined
+ mm.func16 = mul_mat_iq4_ks_r4_q8_k<16>;
+#endif
expected_typeB = GGML_TYPE_Q8_K32;
break;
case GGML_TYPE_IQ2_XXS_R4:
@@ -7046,6 +7372,7 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) {
mm.funcs[5] = mul_mat_iq2_xxs_r4_q8_k<6>;
mm.funcs[6] = mul_mat_iq2_xxs_r4_q8_k<7>;
mm.funcs[7] = mul_mat_iq2_xxs_r4_q8_k<8>;
+ mm.func16 = mul_mat_iq2_xxs_r4_q8_k<16>;
expected_typeB = GGML_TYPE_Q8_K;
break;
case GGML_TYPE_IQ2_XS_R4:
@@ -7058,6 +7385,10 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) {
mm.funcs[5] = mul_mat_iq2_xs_r4_q8_k<6>;
mm.funcs[6] = mul_mat_iq2_xs_r4_q8_k<7>;
mm.funcs[7] = mul_mat_iq2_xs_r4_q8_k<8>;
+#ifndef HAVE_FANCY_SIMD
+ // For some reason this function is slower on Zen4, so it is only enabled when HAVE_FANCY_SIMD is not defined
+ mm.func16 = mul_mat_iq2_xs_r4_q8_k_16;
+#endif
expected_typeB = GGML_TYPE_Q8_K;
break;
case GGML_TYPE_IQ2_S_R4:
@@ -7070,6 +7401,7 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) {
mm.funcs[5] = mul_mat_iq2_s_r4_q8_k<6>;
mm.funcs[6] = mul_mat_iq2_s_r4_q8_k<7>;
mm.funcs[7] = mul_mat_iq2_s_r4_q8_k<8>;
+ mm.func16 = mul_mat_iq2_s_r4_q8_k_16;
expected_typeB = GGML_TYPE_Q8_K;
break;
case GGML_TYPE_IQ3_XXS_R4:
@@ -7082,6 +7414,7 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) {
mm.funcs[5] = mul_mat_iq3_xxs_r4_q8_k<6>;
mm.funcs[6] = mul_mat_iq3_xxs_r4_q8_k<7>;
mm.funcs[7] = mul_mat_iq3_xxs_r4_q8_k<8>;
+ mm.func16 = mul_mat_iq3_xxs_r4_q8_k<16>;
expected_typeB = GGML_TYPE_Q8_K;
break;
case GGML_TYPE_Q2_K_R4:
@@ -7166,6 +7499,7 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) {
mm.funcs[5] = mul_mat_iq4_k_r4_q8_k<6>;
mm.funcs[6] = mul_mat_iq4_k_r4_q8_k<7>;
mm.funcs[7] = mul_mat_iq4_k_r4_q8_k<8>;
+ mm.func16 = mul_mat_iq4_k_r4_q8_k<16>;
expected_typeB = GGML_TYPE_Q8_K;
break;
case GGML_TYPE_IQ5_K_R4:
@@ -7178,6 +7512,7 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) {
mm.funcs[5] = mul_mat_iq5_k_r4_q8_k<6>;
mm.funcs[6] = mul_mat_iq5_k_r4_q8_k<7>;
mm.funcs[7] = mul_mat_iq5_k_r4_q8_k<8>;
+ mm.func16 = mul_mat_iq5_k_r4_q8_k<16>;
expected_typeB = GGML_TYPE_Q8_K;
break;
case GGML_TYPE_IQ2_K_R4:
@@ -7202,6 +7537,9 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) {
mm.funcs[5] = mul_mat_iq3_k_r4_q8_k<6>;
mm.funcs[6] = mul_mat_iq3_k_r4_q8_k<7>;
mm.funcs[7] = mul_mat_iq3_k_r4_q8_k<8>;
+#ifdef HAVE_FANCY_SIMD
+ mm.func16 = mul_mat_iq3_k_r4_q8_k<16>;
+#endif
expected_typeB = GGML_TYPE_Q8_K;
break;
case GGML_TYPE_Q4_0_R4:
@@ -11487,18 +11825,22 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) {
break;
case GGML_TYPE_IQ2_XXS_R4:
SET_MUL_MAT_FUNCTIONS(m, mul_mat_iq2_xxs_r4_q8_k);
+ m.func16 = mul_mat_iq2_xxs_r4_q8_k<16>;
expected_Btype = GGML_TYPE_Q8_K;
break;
case GGML_TYPE_IQ2_XS_R4:
SET_MUL_MAT_FUNCTIONS(m, mul_mat_iq2_xs_r4_q8_k);
+ m.func16 = mul_mat_iq2_xs_r4_q8_k<16>;
expected_Btype = GGML_TYPE_Q8_K;
break;
case GGML_TYPE_IQ2_S_R4:
SET_MUL_MAT_FUNCTIONS(m, mul_mat_iq2_s_r4_q8_k);
+ m.func16 = mul_mat_iq2_s_r4_q8_k<16>;
expected_Btype = GGML_TYPE_Q8_K;
break;
case GGML_TYPE_IQ3_XXS_R4:
SET_MUL_MAT_FUNCTIONS(m, mul_mat_iq3_xxs_r4_q8_k);
+ m.func16 = mul_mat_iq3_xxs_r4_q8_k<16>;
expected_Btype = GGML_TYPE_Q8_K;
break;
case GGML_TYPE_Q2_K_R4: