-rw-r--r--   examples/llama-bench/llama-bench.cpp |   4
-rw-r--r--   ggml/src/iqk/iqk_mul_mat.cpp         | 319
2 files changed, 68 insertions, 255 deletions
diff --git a/examples/llama-bench/llama-bench.cpp b/examples/llama-bench/llama-bench.cpp
index b46bd855..42320da8 100644
--- a/examples/llama-bench/llama-bench.cpp
+++ b/examples/llama-bench/llama-bench.cpp
@@ -756,7 +756,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
             continue;
         }
         cmd_params_instance instance = {
-            /* .test_kind = */ TEST_KIND_PP,
+            /* .test_kind = */ TEST_KIND_TG,
             /* .model     = */ m,
             /* .n_prompt  = */ 0,
             /* .n_gen     = */ n_gen,
@@ -784,7 +784,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
             continue;
         }
         cmd_params_instance instance = {
-            /* .test_kind = */ TEST_KIND_PP,
+            /* .test_kind = */ TEST_KIND_PG,
             /* .model     = */ m,
             /* .n_prompt  = */ n_pg.first,
             /* .n_gen     = */ n_pg.second,
diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp
index 308d0dca..7fd56c42 100644
--- a/ggml/src/iqk/iqk_mul_mat.cpp
+++ b/ggml/src/iqk/iqk_mul_mat.cpp
@@ -4430,17 +4430,47 @@ static void mul_mat_iq3_s_r4_q8_k(int n, const void * vx, size_t bx, const DataI
 }
 
 template <int nrc_y>
-static void mul_mat_q4_k_r4_q8_k_avx2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+inline void process_min_r4_b32(int ibl, __m256 m4, __m256i mins, const Q8<nrc_y, block_q8_K>& q8, __m256 * acc) {
+    auto mins_l = _mm256_castsi256_si128(mins);
+    auto mins_h = _mm256_extracti128_si256(mins, 1);
+    auto aux1 = _mm_unpacklo_epi32(mins_l, mins_h);
+    auto aux2 = _mm_unpackhi_epi32(mins_l, mins_h);
+    auto ic1 = _mm256_cvtepi8_epi32(aux1);
+    auto ic2 = _mm256_cvtepi8_epi32(_mm_shuffle_epi32(aux1, 0xee));
+    auto ic3 = _mm256_cvtepi8_epi32(aux2);
+    auto ic4 = _mm256_cvtepi8_epi32(_mm_shuffle_epi32(aux2, 0xee));
+    if constexpr (nrc_y == 1) {
+        auto bs = _mm256_loadu_ps((const float *)q8.y[0][ibl].bsums);
+        auto sumf = _mm256_mul_ps(_mm256_cvtepi32_ps(ic1), _mm256_shuffle_ps(bs, bs, 0x00));
+        sumf = _mm256_fmadd_ps(_mm256_cvtepi32_ps(ic2), _mm256_shuffle_ps(bs, bs, 0x55), sumf);
+        sumf = _mm256_fmadd_ps(_mm256_cvtepi32_ps(ic3), _mm256_shuffle_ps(bs, bs, 0xaa), sumf);
+        sumf = _mm256_fmadd_ps(_mm256_cvtepi32_ps(ic4), _mm256_shuffle_ps(bs, bs, 0xff), sumf);
+        acc[0] = _mm256_fmadd_ps(m4, sumf, acc[0]);
+    } else {
+        auto c1 = _mm256_mul_ps(m4, _mm256_cvtepi32_ps(ic1));
+        auto c2 = _mm256_mul_ps(m4, _mm256_cvtepi32_ps(ic2));
+        auto c3 = _mm256_mul_ps(m4, _mm256_cvtepi32_ps(ic3));
+        auto c4 = _mm256_mul_ps(m4, _mm256_cvtepi32_ps(ic4));
+        for (int iy = 0; iy < nrc_y; ++iy) {
+            auto bs = _mm256_loadu_ps((const float *)q8.y[iy][ibl].bsums);
+            acc[iy] = _mm256_fmadd_ps(c1, _mm256_shuffle_ps(bs, bs, 0x00), acc[iy]);
+            acc[iy] = _mm256_fmadd_ps(c2, _mm256_shuffle_ps(bs, bs, 0x55), acc[iy]);
+            acc[iy] = _mm256_fmadd_ps(c3, _mm256_shuffle_ps(bs, bs, 0xaa), acc[iy]);
+            acc[iy] = _mm256_fmadd_ps(c4, _mm256_shuffle_ps(bs, bs, 0xff), acc[iy]);
+        }
+    }
+}
+
+template <int nrc_y>
+static void mul_mat_q4_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
     GGML_ASSERT(nrc_x%4 == 0);
     Q8<nrc_y, block_q8_K> q8(info);
     auto mf = _mm256_set1_epi8(0xf);
     auto m3 = _mm256_set1_epi8(0x30);
-#ifndef HAVE_FANCY_SIMD
-    auto m1 = _mm256_set1_epi16(1);
-#endif
     int nbl = n / QK_K;
     union { __m256i vec; uint32_t val[8]; } hd;
     __m256 acc[nrc_y] = {};
+    __m256i isum[nrc_y] = {};
     __m256i qx[4];
     for (int ix = 0; ix < nrc_x; ix += 4) {
         const block_q4_k_r4 * iq4 = (const block_q4_k_r4 *)((const char *)vx + (ix+0)*bx);
@@ -4448,31 +4478,20 @@ static void mul_mat_q4_k_r4_q8_k_avx2(int n, const void * vx, size_t bx, const D
             auto dl = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq4[ibl].d));
             auto d4 = _mm256_set_m128(_mm256_castps256_ps128(dl), _mm256_castps256_ps128(dl));
             auto m4 = _mm256_mul_ps(_mm256_set1_ps(-1.0f), _mm256_set_m128(_mm256_extractf128_ps(dl, 1), _mm256_extractf128_ps(dl, 1)));
-            if constexpr (nrc_y == 1) {
-                d4 = _mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(0, ibl)));
-            }
             auto lbits = _mm256_loadu_si256((const __m256i *)iq4[ibl].scales_l);
             auto hbits128 = _mm_loadu_si128((const __m128i *)iq4[ibl].scales_h);
             auto hbits = MM256_SET_M128I(hbits128, _mm_slli_epi16(hbits128, 4));
             hd.vec = _mm256_or_si256(_mm256_and_si256(lbits, mf), _mm256_and_si256(hbits, m3));
             auto mins = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(lbits, 4), mf), _mm256_and_si256(_mm256_srli_epi16(hbits, 2), m3));
-            auto shuffle = _mm256_set1_epi64x(0x0000000400000000);
-            auto c1 = _mm256_mul_ps(m4, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm256_castsi256_si128(_mm256_permutevar8x32_epi32(mins, shuffle)))));
-            shuffle = _mm256_add_epi32(shuffle, _mm256_set1_epi32(1));
-            auto c2 = _mm256_mul_ps(m4, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm256_castsi256_si128(_mm256_permutevar8x32_epi32(mins, shuffle)))));
-            shuffle = _mm256_add_epi32(shuffle, _mm256_set1_epi32(1));
-            auto c3 = _mm256_mul_ps(m4, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm256_castsi256_si128(_mm256_permutevar8x32_epi32(mins, shuffle)))));
-            shuffle = _mm256_add_epi32(shuffle, _mm256_set1_epi32(1));
-            auto c4 = _mm256_mul_ps(m4, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm256_castsi256_si128(_mm256_permutevar8x32_epi32(mins, shuffle)))));
-            for (int iy = 0; iy < nrc_y; ++iy) {
-                auto bs = _mm256_loadu_ps((const float *)q8.y[iy][ibl].bsums);
-                acc[iy] = _mm256_fmadd_ps(c1, _mm256_shuffle_ps(bs, bs, 0x00), acc[iy]);
-                acc[iy] = _mm256_fmadd_ps(c2, _mm256_shuffle_ps(bs, bs, 0x55), acc[iy]);
-                acc[iy] = _mm256_fmadd_ps(c3, _mm256_shuffle_ps(bs, bs, 0xaa), acc[iy]);
-                acc[iy] = _mm256_fmadd_ps(c4, _mm256_shuffle_ps(bs, bs, 0xff), acc[iy]);
-            }
+            process_min_r4_b32(ibl, m4, mins, q8, acc);
             for (int ib = 0; ib < QK_K/32; ++ib) {
-                auto scales_d = _mm256_mul_ps(d4, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_set1_epi32(hd.val[ib]))));
+#ifdef HAVE_FANCY_SIMD
+                auto scales_d = _mm256_cvtepi8_epi32(_mm_set1_epi32(hd.val[ib]));
+#else
+                auto aux = _mm_set1_epi32(hd.val[ib]);
+                aux = _mm_cvtepu8_epi16(_mm_unpacklo_epi8(aux, aux));
+                auto scales_d = MM256_SET_M128I(aux, aux);
+#endif
                 auto bits1 = _mm256_loadu_si256((const __m256i *)iq4[ibl].qs+2*ib+0);
                 auto bits2 = _mm256_loadu_si256((const __m256i *)iq4[ibl].qs+2*ib+1);
                 qx[0] = _mm256_and_si256(bits1, mf);
@@ -4487,21 +4506,20 @@ static void mul_mat_q4_k_r4_q8_k_avx2(int n, const void * vx, size_t bx, const D
                     sumi = _mm256_dpbusd_epi32(sumi, qx[1], _mm256_shuffle_epi32(y, 0x55));
                     sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(y, 0xaa));
                     sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(y, 0xff));
+                    isum[iy] = _mm256_add_epi32(isum[iy], _mm256_mullo_epi32(scales_d, sumi));
 #else
                     auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[0], _mm256_shuffle_epi32(y, 0x00)), _mm256_maddubs_epi16(qx[1], _mm256_shuffle_epi32(y, 0x55)));
                     auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[2], _mm256_shuffle_epi32(y, 0xaa)), _mm256_maddubs_epi16(qx[3], _mm256_shuffle_epi32(y, 0xff)));
-                    auto sumi = _mm256_madd_epi16(m1, _mm256_add_epi16(sumi1, sumi2));
+                    isum[iy] = _mm256_add_epi32(isum[iy], _mm256_madd_epi16(scales_d, _mm256_add_epi16(sumi1, sumi2)));
 #endif
-                    if constexpr (nrc_y == 1) {
-                        acc[iy] = _mm256_fmadd_ps(scales_d, _mm256_cvtepi32_ps(sumi), acc[iy]);
-                    } else {
-                        float d8 = q8.scale(iy, ibl);
-                        acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(scales_d, _mm256_set1_ps(d8)), _mm256_cvtepi32_ps(sumi), acc[iy]);
-                    }
                 }
             }
+            for (int iy = 0; iy < nrc_y; ++iy) {
+                acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(isum[iy]), acc[iy]);
+                isum[iy] = _mm256_setzero_si256();
+            }
         }
         for (int iy = 0; iy < nrc_y; ++iy) {
             auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1));
@@ -4511,113 +4529,17 @@ static void mul_mat_q4_k_r4_q8_k_avx2(int n, const void * vx, size_t bx, const D
     }
 }
-#ifdef HAVE_FANCY_SIMD
-template <int nrc_y>
-static void mul_mat_q4_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
-    //mul_mat_q4_k_r4_q8_k_avx2<nrc_y>(n, vx, bx, info, nrc_x);
-    if constexpr (nrc_y == 1){
-        mul_mat_q4_k_r4_q8_k_avx2<1>(n, vx, bx, info, nrc_x);
-    } else {
-        GGML_ASSERT(nrc_x%8 == 0);
-        Q8<nrc_y, block_q8_K> q8(info);
-        auto mf = _mm512_set1_epi8(0xf);
-        int nbl = n / QK_K;
-        using helper_t = union { __m512i vec; uint32_t val[16]; };
-        helper_t hd, hm;
-        __m512 acc[nrc_y] = {};
-        __m512i isum[nrc_y] = {};
-        __m512i qx[4];
-        for (int ix = 0; ix < nrc_x; ix += 8) {
-            const block_q4_k_r4 * iq4l = (const block_q4_k_r4 *)((const char *)vx + (ix+0)*bx);
-            const block_q4_k_r4 * iq4h = (const block_q4_k_r4 *)((const char *)vx + (ix+4)*bx);
-            for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256
-                auto d1 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq4l[ibl].d));
-                auto d2 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq4h[ibl].d));
-                auto dl = _mm256_castps256_ps128(d1);
-                auto ml = _mm256_extractf128_ps(d1, 1);
-                auto dh = _mm256_castps256_ps128(d2);
-                auto mh = _mm256_extractf128_ps(d2, 1);
-                auto d4 = _mm512_insertf32x8(_mm512_castps256_ps512(_mm256_set_m128(dl, dl)), _mm256_set_m128(dh, dh), 1);
-                auto m4 = _mm512_insertf32x8(_mm512_castps256_ps512(_mm256_set_m128(ml, ml)), _mm256_set_m128(mh, mh), 1);
-                m4 = _mm512_mul_ps(m4, _mm512_set1_ps(-0.5f));
-                auto slbits_l = _mm256_loadu_si256((const __m256i *)iq4l[ibl].scales_l);
-                auto shbits_l = _mm256_loadu_si256((const __m256i *)iq4h[ibl].scales_l);
-                auto slb = _mm512_inserti32x8(_mm512_castsi256_si512(slbits_l), shbits_l, 1);
-                auto sld = _mm512_and_si512(slb, mf);
-                auto slm = _mm512_and_si512(_mm512_srli_epi16(slb, 4), mf);
-                auto slbits_h = _mm_loadu_si128((const __m128i *)iq4l[ibl].scales_h);
-                auto shbits_h = _mm_loadu_si128((const __m128i *)iq4h[ibl].scales_h);
-                auto slbits_h2 = MM256_SET_M128I(_mm_srli_epi16(slbits_h, 4), slbits_h);
-                auto shbits_h2 = MM256_SET_M128I(_mm_srli_epi16(shbits_h, 4), shbits_h);
-                auto shb = _mm512_inserti32x8(_mm512_castsi256_si512(slbits_h2), shbits_h2, 1);
-                auto shd = _mm512_and_si512(_mm512_slli_epi16(shb, 4), _mm512_set1_epi8(0x30));
-                auto shm = _mm512_and_si512(_mm512_slli_epi16(shb, 2), _mm512_set1_epi8(0x30));
-                hd.vec = _mm512_or_si512(sld, shd);
-                hm.vec = _mm512_or_si512(slm, shm);
-                for (int ib = 0; ib < QK_K/32; ++ib) {
-                    auto scales1 = _mm256_cvtepi8_epi32(_mm_set1_epi32(hd.val[ib+0]));
-                    auto scales2 = _mm256_cvtepi8_epi32(_mm_set1_epi32(hd.val[ib+8]));
-                    auto iscales = _mm512_inserti32x8(_mm512_castsi256_si512(scales1), scales2, 1);
-                    scales1 = _mm256_cvtepi8_epi32(_mm_set1_epi32(hm.val[ib+0]));
-                    scales2 = _mm256_cvtepi8_epi32(_mm_set1_epi32(hm.val[ib+8]));
-                    auto iscales_m = _mm512_inserti32x8(_mm512_castsi256_si512(scales1), scales2, 1);
-                    auto scales_m = _mm512_mul_ps(m4, _mm512_cvtepi32_ps(iscales_m));
-                    auto bits1 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq4l[ibl].qs+2*ib+0)),
-                                                    _mm256_loadu_si256((const __m256i *)iq4h[ibl].qs+2*ib+0), 1);
-                    auto bits2 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq4l[ibl].qs+2*ib+1)),
-                                                    _mm256_loadu_si256((const __m256i *)iq4h[ibl].qs+2*ib+1), 1);
-                    qx[0] = _mm512_and_si512(bits1, mf);
-                    qx[1] = _mm512_and_si512(bits2, mf);
-                    qx[2] = _mm512_and_si512(_mm512_srli_epi16(bits1, 4), mf);
-                    qx[3] = _mm512_and_si512(_mm512_srli_epi16(bits2, 4), mf);
-                    for (int iy = 0; iy < nrc_y; ++iy) {
-                        auto y8 = _mm256_loadu_si256((const __m256i*)q8.y[iy][ibl].qs+ib);
-                        auto y = _mm512_inserti32x8(_mm512_castsi256_si512(y8), y8, 1);
-                        auto sumi = _mm512_setzero_si512();
-                        sumi = _mm512_dpbusd_epi32(sumi, qx[0], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00)));
-                        sumi = _mm512_dpbusd_epi32(sumi, qx[1], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55)));
-                        sumi = _mm512_dpbusd_epi32(sumi, qx[2], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa)));
-                        sumi = _mm512_dpbusd_epi32(sumi, qx[3], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff)));
-                        isum[iy] = _mm512_add_epi32(isum[iy], _mm512_mullo_epi32(iscales, sumi));
-                        float m8 = ((const float *)q8.y[iy][ibl].bsums)[ib];
-                        acc[iy] = _mm512_fmadd_ps(scales_m, _mm512_set1_ps(m8), acc[iy]);
-                    }
-                }
-                for (int iy = 0; iy < nrc_y; ++iy) {
-                    acc[iy] = _mm512_fmadd_ps(_mm512_mul_ps(d4, _mm512_set1_ps(q8.scale(iy, ibl))), _mm512_cvtepi32_ps(isum[iy]), acc[iy]);
-                    isum[iy] = _mm512_setzero_si512();
-                }
-            }
-            for (int iy = 0; iy < nrc_y; ++iy) {
-                auto sum1 = _mm_add_ps(_mm512_extractf32x4_ps(acc[iy], 0), _mm512_extractf32x4_ps(acc[iy], 1));
-                auto sum2 = _mm_add_ps(_mm512_extractf32x4_ps(acc[iy], 2), _mm512_extractf32x4_ps(acc[iy], 3));
-                info.store(ix+0, iy, sum1);
-                info.store(ix+4, iy, sum2);
-                acc[iy] = _mm512_setzero_ps();
-            }
-        }
-    }
-}
-#else
 template <int nrc_y>
-static void mul_mat_q4_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
-    mul_mat_q4_k_r4_q8_k_avx2<nrc_y>(n, vx, bx, info, nrc_x);
-}
-#endif
-
-template <int nrc_y>
-static void mul_mat_q5_k_r4_q8_k_avx2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+static void mul_mat_q5_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
     GGML_ASSERT(nrc_x%4 == 0);
     Q8<nrc_y, block_q8_K> q8(info);
     auto mf = _mm256_set1_epi8(0xf);
     auto m10 = _mm256_set1_epi8(0x10);
     auto m30 = _mm256_set1_epi8(0x30);
-#ifndef HAVE_FANCY_SIMD
-    auto m1 = _mm256_set1_epi16(1);
-#endif
     int nbl = n / QK_K;
     union { __m256i vec; uint32_t val[8]; } hd;
     __m256 acc[nrc_y] = {};
+    __m256i isum[nrc_y] = {};
     __m256i qx[4];
     for (int ix = 0; ix < nrc_x; ix += 4) {
         const block_q5_k_r4 * iq5 = (const block_q5_k_r4 *)((const char *)vx + (ix+0)*bx);
@@ -4625,31 +4547,20 @@ static void mul_mat_q5_k_r4_q8_k_avx2(int n, const void * vx, size_t bx, const D
             auto dl = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq5[ibl].d));
             auto d4 = _mm256_set_m128(_mm256_castps256_ps128(dl), _mm256_castps256_ps128(dl));
             auto m4 = _mm256_mul_ps(_mm256_set1_ps(-1.0f), _mm256_set_m128(_mm256_extractf128_ps(dl, 1), _mm256_extractf128_ps(dl, 1)));
-            if constexpr (nrc_y == 1) {
-                d4 = _mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(0, ibl)));
-            }
             auto lbits = _mm256_loadu_si256((const __m256i *)iq5[ibl].scales_l);
             auto hbits128 = _mm_loadu_si128((const __m128i *)iq5[ibl].scales_h);
             auto hbits = MM256_SET_M128I(hbits128, _mm_slli_epi16(hbits128, 4));
             hd.vec = _mm256_or_si256(_mm256_and_si256(lbits, mf), _mm256_and_si256(hbits, m30));
             auto mins = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(lbits, 4), mf), _mm256_and_si256(_mm256_srli_epi16(hbits, 2), m30));
-            auto shuffle = _mm256_set1_epi64x(0x0000000400000000);
-            auto c1 = _mm256_mul_ps(m4, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm256_castsi256_si128(_mm256_permutevar8x32_epi32(mins, shuffle)))));
-            shuffle = _mm256_add_epi32(shuffle, _mm256_set1_epi32(1));
-            auto c2 = _mm256_mul_ps(m4, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm256_castsi256_si128(_mm256_permutevar8x32_epi32(mins, shuffle)))));
-            shuffle = _mm256_add_epi32(shuffle, _mm256_set1_epi32(1));
-            auto c3 = _mm256_mul_ps(m4, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm256_castsi256_si128(_mm256_permutevar8x32_epi32(mins, shuffle)))));
-            shuffle = _mm256_add_epi32(shuffle, _mm256_set1_epi32(1));
-            auto c4 = _mm256_mul_ps(m4, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm256_castsi256_si128(_mm256_permutevar8x32_epi32(mins, shuffle)))));
-            for (int iy = 0; iy < nrc_y; ++iy) {
-                auto bs = _mm256_loadu_ps((const float *)q8.y[iy][ibl].bsums);
-                acc[iy] = _mm256_fmadd_ps(c1, _mm256_shuffle_ps(bs, bs, 0x00), acc[iy]);
-                acc[iy] = _mm256_fmadd_ps(c2, _mm256_shuffle_ps(bs, bs, 0x55), acc[iy]);
-                acc[iy] = _mm256_fmadd_ps(c3, _mm256_shuffle_ps(bs, bs, 0xaa), acc[iy]);
-                acc[iy] = _mm256_fmadd_ps(c4, _mm256_shuffle_ps(bs, bs, 0xff), acc[iy]);
-            }
+            process_min_r4_b32(ibl, m4, mins, q8, acc);
             for (int ib = 0; ib < QK_K/32; ++ib) {
-                auto scales_d = _mm256_mul_ps(d4, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_set1_epi32(hd.val[ib]))));
+#ifdef HAVE_FANCY_SIMD
+                auto scales_d = _mm256_cvtepi8_epi32(_mm_set1_epi32(hd.val[ib]));
+#else
+                auto aux = _mm_set1_epi32(hd.val[ib]);
+                aux = _mm_cvtepu8_epi16(_mm_unpacklo_epi8(aux, aux));
+                auto scales_d = MM256_SET_M128I(aux, aux);
+#endif
                 auto lbits1 = _mm256_loadu_si256((const __m256i *)iq5[ibl].qs+2*ib+0);
                 auto lbits2 = _mm256_loadu_si256((const __m256i *)iq5[ibl].qs+2*ib+1);
                 auto hbits128 = _mm_loadu_si128((const __m128i*)iq5[ibl].qh + ib);
@@ -4666,21 +4577,22 @@ static void mul_mat_q5_k_r4_q8_k_avx2(int n, const void * vx, size_t bx, const D
                     sumi = _mm256_dpbusd_epi32(sumi, qx[1], _mm256_shuffle_epi32(y, 0x55));
                     sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(y, 0xaa));
                     sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(y, 0xff));
+                    isum[iy] = _mm256_add_epi32(isum[iy], _mm256_mullo_epi32(scales_d, sumi));
 #else
                    auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[0], _mm256_shuffle_epi32(y, 0x00)), _mm256_maddubs_epi16(qx[1], _mm256_shuffle_epi32(y, 0x55)));
                    auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[2], _mm256_shuffle_epi32(y, 0xaa)), _mm256_maddubs_epi16(qx[3], _mm256_shuffle_epi32(y, 0xff)));
-                   auto sumi = _mm256_madd_epi16(m1, _mm256_add_epi16(sumi1, sumi2));
+                   // To avoid overflow, we can only add up to 4 q5 x q8 products.
+                   auto sumi = _mm256_add_epi32(_mm256_madd_epi16(scales_d, sumi1), _mm256_madd_epi16(scales_d, sumi2));
+                   isum[iy] = _mm256_add_epi32(isum[iy], sumi);
 #endif
-                    if constexpr (nrc_y == 1) {
-                        acc[iy] = _mm256_fmadd_ps(scales_d, _mm256_cvtepi32_ps(sumi), acc[iy]);
-                    } else {
-                        float d8 = q8.scale(iy, ibl);
-                        acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(scales_d, _mm256_set1_ps(d8)), _mm256_cvtepi32_ps(sumi), acc[iy]);
-                    }
                 }
             }
+            for (int iy = 0; iy < nrc_y; ++iy) {
+                acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(isum[iy]), acc[iy]);
+                isum[iy] = _mm256_setzero_si256();
+            }
         }
         for (int iy = 0; iy < nrc_y; ++iy) {
             auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1));
@@ -4690,105 +4602,6 @@ static void mul_mat_q5_k_r4_q8_k_avx2(int n, const void * vx, size_t bx, const D
     }
 }
-#ifdef HAVE_FANCY_SIMD
-template <int nrc_y>
-static void mul_mat_q5_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
-    if constexpr (nrc_y == 1){
-        mul_mat_q5_k_r4_q8_k_avx2<1>(n, vx, bx, info, nrc_x);
-    } else {
-        GGML_ASSERT(nrc_x%8 == 0);
-        Q8<nrc_y, block_q8_K> q8(info);
-        auto mf = _mm512_set1_epi8(0xf);
-        auto m10 = _mm512_set1_epi8(0x10);
-        int nbl = n / QK_K;
-        using helper_t = union { __m512i vec; uint32_t val[16]; };
-        helper_t hd, hm;
-        __m512 acc[nrc_y] = {};
-        __m512i isum[nrc_y] = {};
-        __m512i qx[4];
-        for (int ix = 0; ix < nrc_x; ix += 8) {
-            const block_q5_k_r4 * iq5l = (const block_q5_k_r4 *)((const char *)vx + (ix+0)*bx);
-            const block_q5_k_r4 * iq5h = (const block_q5_k_r4 *)((const char *)vx + (ix+4)*bx);
-            for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256
-                auto d1 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq5l[ibl].d));
-                auto d2 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq5h[ibl].d));
-                auto dl = _mm256_castps256_ps128(d1);
-                auto ml = _mm256_extractf128_ps(d1, 1);
-                auto dh = _mm256_castps256_ps128(d2);
-                auto mh = _mm256_extractf128_ps(d2, 1);
-                auto d4 = _mm512_insertf32x8(_mm512_castps256_ps512(_mm256_set_m128(dl, dl)), _mm256_set_m128(dh, dh), 1);
-                auto m4 = _mm512_insertf32x8(_mm512_castps256_ps512(_mm256_set_m128(ml, ml)), _mm256_set_m128(mh, mh), 1);
-                m4 = _mm512_mul_ps(m4, _mm512_set1_ps(-0.5f));
-                auto slbits_l = _mm256_loadu_si256((const __m256i *)iq5l[ibl].scales_l);
-                auto shbits_l = _mm256_loadu_si256((const __m256i *)iq5h[ibl].scales_l);
-                auto slb = _mm512_inserti32x8(_mm512_castsi256_si512(slbits_l), shbits_l, 1);
-                auto sld = _mm512_and_si512(slb, mf);
-                auto slm = _mm512_and_si512(_mm512_srli_epi16(slb, 4), mf);
-                auto slbits_h = _mm_loadu_si128((const __m128i *)iq5l[ibl].scales_h);
-                auto shbits_h = _mm_loadu_si128((const __m128i *)iq5h[ibl].scales_h);
-                auto slbits_h2 = MM256_SET_M128I(_mm_srli_epi16(slbits_h, 4), slbits_h);
-                auto shbits_h2 = MM256_SET_M128I(_mm_srli_epi16(shbits_h, 4), shbits_h);
-                auto shb = _mm512_inserti32x8(_mm512_castsi256_si512(slbits_h2), shbits_h2, 1);
-                auto shd = _mm512_and_si512(_mm512_slli_epi16(shb, 4), _mm512_set1_epi8(0x30));
-                auto shm = _mm512_and_si512(_mm512_slli_epi16(shb, 2), _mm512_set1_epi8(0x30));
-                hd.vec = _mm512_or_si512(sld, shd);
-                hm.vec = _mm512_or_si512(slm, shm);
-                for (int ib = 0; ib < QK_K/32; ++ib) {
-                    auto scales1 = _mm256_cvtepi8_epi32(_mm_set1_epi32(hd.val[ib+0]));
-                    auto scales2 = _mm256_cvtepi8_epi32(_mm_set1_epi32(hd.val[ib+8]));
-                    auto iscales = _mm512_inserti32x8(_mm512_castsi256_si512(scales1), scales2, 1);
-                    scales1 = _mm256_cvtepi8_epi32(_mm_set1_epi32(hm.val[ib+0]));
-                    scales2 = _mm256_cvtepi8_epi32(_mm_set1_epi32(hm.val[ib+8]));
-                    auto iscales_m = _mm512_inserti32x8(_mm512_castsi256_si512(scales1), scales2, 1);
-                    auto scales_m = _mm512_mul_ps(m4, _mm512_cvtepi32_ps(iscales_m));
-                    auto lbits1 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq5l[ibl].qs+2*ib+0)),
-                                                     _mm256_loadu_si256((const __m256i *)iq5h[ibl].qs+2*ib+0), 1);
-                    auto lbits2 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq5l[ibl].qs+2*ib+1)),
-                                                     _mm256_loadu_si256((const __m256i *)iq5h[ibl].qs+2*ib+1), 1);
-                    auto hbits1 = _mm_loadu_si128((const __m128i*)iq5l[ibl].qh+ib);
-                    auto hbits2 = _mm_loadu_si128((const __m128i*)iq5h[ibl].qh+ib);
-                    auto hbl = MM256_SET_M128I(hbits1, _mm_slli_epi16(hbits1, 4));
-                    auto hbh = MM256_SET_M128I(hbits2, _mm_slli_epi16(hbits2, 4));
-                    auto hbits = _mm512_inserti32x8(_mm512_castsi256_si512(hbl), hbh, 1);
-                    qx[0] = _mm512_or_si512(_mm512_and_si512(lbits1, mf), _mm512_and_si512(m10, hbits));
-                    qx[1] = _mm512_or_si512(_mm512_and_si512(lbits2, mf), _mm512_and_si512(m10, _mm512_srli_epi16(hbits, 2)));
-                    qx[2] = _mm512_or_si512(_mm512_and_si512(_mm512_srli_epi16(lbits1, 4), mf), _mm512_and_si512(m10, _mm512_srli_epi16(hbits, 1)));
-                    qx[3] = _mm512_or_si512(_mm512_and_si512(_mm512_srli_epi16(lbits2, 4), mf), _mm512_and_si512(m10, _mm512_srli_epi16(hbits, 3)));
-                    for (int iy = 0; iy < nrc_y; ++iy) {
-                        auto y8 = _mm256_loadu_si256((const __m256i*)q8.y[iy][ibl].qs+ib);
-                        auto y = _mm512_inserti32x8(_mm512_castsi256_si512(y8), y8, 1);
-                        auto sumi = _mm512_setzero_si512();
-                        sumi = _mm512_dpbusd_epi32(sumi, qx[0], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00)));
-                        sumi = _mm512_dpbusd_epi32(sumi, qx[1], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55)));
-                        sumi = _mm512_dpbusd_epi32(sumi, qx[2], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa)));
-                        sumi = _mm512_dpbusd_epi32(sumi, qx[3], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff)));
-                        isum[iy] = _mm512_add_epi32(isum[iy], _mm512_mullo_epi32(iscales, sumi));
-                        float m8 = ((const float *)q8.y[iy][ibl].bsums)[ib];
-                        acc[iy] = _mm512_fmadd_ps(scales_m, _mm512_set1_ps(m8), acc[iy]);
-                    }
-                }
-                for (int iy = 0; iy < nrc_y; ++iy) {
-                    acc[iy] = _mm512_fmadd_ps(_mm512_mul_ps(d4, _mm512_set1_ps(q8.scale(iy, ibl))), _mm512_cvtepi32_ps(isum[iy]), acc[iy]);
-                    isum[iy] = _mm512_setzero_si512();
-                }
-            }
-            for (int iy = 0; iy < nrc_y; ++iy) {
-                auto sum1 = _mm_add_ps(_mm512_extractf32x4_ps(acc[iy], 0), _mm512_extractf32x4_ps(acc[iy], 1));
-                auto sum2 = _mm_add_ps(_mm512_extractf32x4_ps(acc[iy], 2), _mm512_extractf32x4_ps(acc[iy], 3));
-                info.store(ix+0, iy, sum1);
-                info.store(ix+4, iy, sum2);
-                acc[iy] = _mm512_setzero_ps();
-            }
-        }
-    }
-}
-#else
-template <int nrc_y>
-static void mul_mat_q5_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
-    mul_mat_q5_k_r4_q8_k_avx2<nrc_y>(n, vx, bx, info, nrc_x);
-}
-#endif
-
 template <int nrc_y>
 static void mul_mat_q2_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
     GGML_ASSERT(nrc_x%4 == 0);
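
Note on the block-scale handling above (a standalone sketch, not part of the patch): both AVX2 kernels decode the four 6-bit sub-block scales of the interleaved rows into hd.val[ib] through the { __m256i vec; uint32_t val[8]; } union, then broadcast them either as int32 lanes for _mm256_mullo_epi32 (HAVE_FANCY_SIMD path) or as duplicated int16 lanes for _mm256_madd_epi16 (plain AVX2 path), which applies the scale while widening the int16 dot products to int32. The q5 path applies _mm256_madd_epi16 to sumi1 and sumi2 separately because, per the comment in the patch, no more than 4 q5 x q8 products should be summed in int16: one product is bounded in magnitude by 31*128 = 3968, so four products stay within 15872 while eight could reach 31744 of the 32767 int16 range (q4 values of at most 15 cap eight products at 15360, which is why the q4 kernel can keep the single combined madd). The sketch below reproduces only the two broadcast variants with made-up scale values; MM256_SET_M128I is redefined locally as a stand-in for ggml's macro, and an AVX2-capable compiler is assumed.

#include <immintrin.h>
#include <cstdint>
#include <cstdio>

// Local stand-in for ggml's MM256_SET_M128I(hi, lo) helper.
#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)

int main() {
    // Hypothetical decoded 6-bit scales: 8 sub-blocks x 4 interleaved rows.
    uint8_t scales[32];
    for (int i = 0; i < 32; ++i) scales[i] = (uint8_t)(i + 1);

    union { __m256i vec; uint32_t val[8]; } hd;
    hd.vec = _mm256_loadu_si256((const __m256i *)scales);

    int ib = 3; // sub-block index, 0..7

    // HAVE_FANCY_SIMD-style broadcast: the 4 row scales as int32 lanes, repeated twice.
    // Sign- vs zero-extension does not matter here because 6-bit scales are non-negative.
    __m256i s32 = _mm256_cvtepi8_epi32(_mm_set1_epi32(hd.val[ib]));

    // Plain AVX2 broadcast: each row scale duplicated into adjacent int16 lanes, so a
    // single _mm256_madd_epi16 applies the scale and widens the dot products to int32.
    __m128i aux = _mm_set1_epi32(hd.val[ib]);
    aux = _mm_cvtepu8_epi16(_mm_unpacklo_epi8(aux, aux));
    __m256i s16 = MM256_SET_M128I(aux, aux);

    int32_t out32[8];
    int16_t out16[16];
    _mm256_storeu_si256((__m256i *)out32, s32);
    _mm256_storeu_si256((__m256i *)out16, s16);
    printf("int32 lanes:");
    for (int i = 0; i < 8; ++i) printf(" %d", (int)out32[i]);
    printf("\nint16 lanes:");
    for (int i = 0; i < 16; ++i) printf(" %d", (int)out16[i]);
    printf("\n");
    return 0;
}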