diff options
author | Kawrakow <iwankawrakow@gmail.com> | 2025-01-30 18:36:24 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2025-01-30 18:36:24 +0200 |
commit | ecf111a11ca56ff0731308f94bd6c5e96658b6ef (patch) | |
tree | f05decc6721785febc889b246955571c32b28b4f | |
parent | 2e6b523853a8659c63283a6deca805051ecd713a (diff) |
Deepseek-Lite (#184)
* Quantization mixes tweaks
* Make iq4_nl_r4 work with row size that are not a multiple of 128
... on Zen4
* Make iq4_nl_r4 work with row size that are not a multiple of 128
... on AVX2
* Make iq4_nl_r4 work with row size that are not a multiple of 128
... on AVX2
* Make q6_0_w4 work with row size that are not a multiple of 128
... on Zen4
* Make q6_0_w4 work with row size that are not a multiple of 128
... on Zen4
* Make q5_0_r4 work with row size that are not a multiple of 128
... on Zen4 and AVX2
* Make q5,6_0_r4, iq4_nl_e4 work with row size that are not a multiple of 128
also on NEON.
---------
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
-rw-r--r-- | ggml/src/iqk/iqk_mul_mat.cpp | 467 | ||||
-rw-r--r-- | src/llama.cpp | 18 |
2 files changed, 315 insertions, 170 deletions
diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index 7fd56c42..f633229d 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -2474,44 +2474,63 @@ static void mul_mat_iq4_nl_r4_q8_1(int n, const void * vx, size_t bx, const Data auto m4 = _mm512_set1_epi8(0xf); auto values = load_iq4nl_values_512(); int nb = n / QK4_NL; - GGML_ASSERT(nb%4 == 0); __m512 acc[2*nrc_y] = {}; __m512i qx[4]; + float d8[8*nrc_y]; + auto prepare = [&qx, &m4, &values] (const block_iq4_nl_r4& iq4l, const block_iq4_nl_r4& iq4h) { + auto scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq4l.d)); + auto scales1 = _mm256_set_m128(scales128, scales128); + scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq4h.d)); + auto scales2 = _mm256_set_m128(scales128, scales128); + auto scales = _mm512_insertf32x8(_mm512_castps256_ps512(scales1), scales2, 1); + auto bits1 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq4l.qs+0)), + _mm256_loadu_si256((const __m256i *)iq4h.qs+0), 1); + auto bits2 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq4l.qs+1)), + _mm256_loadu_si256((const __m256i *)iq4h.qs+1), 1); + qx[0] = _mm512_shuffle_epi8(values, _mm512_and_si512(bits1, m4)); + qx[1] = _mm512_shuffle_epi8(values, _mm512_and_si512(bits2, m4)); + qx[2] = _mm512_shuffle_epi8(values, _mm512_and_si512(_mm512_srli_epi16(bits1, 4), m4)); + qx[3] = _mm512_shuffle_epi8(values, _mm512_and_si512(_mm512_srli_epi16(bits2, 4), m4)); + return scales; + }; + auto dot = [&qx] (__m256i y8) { + auto y = _mm512_inserti32x8(_mm512_castsi256_si512(y8), y8, 1); + auto sumi = _mm512_setzero_si512(); + sumi = _mm512_dpbusd_epi32(sumi, qx[0], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00))); + sumi = _mm512_dpbusd_epi32(sumi, qx[1], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55))); + sumi = _mm512_dpbusd_epi32(sumi, qx[2], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa))); + sumi = _mm512_dpbusd_epi32(sumi, qx[3], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff))); + return sumi; + }; for (int ix = 0; ix < nrc_x; ix += 8) { const block_iq4_nl_r4 * iq4l = (const block_iq4_nl_r4 *)((const char *)vx + (ix+0)*bx); const block_iq4_nl_r4 * iq4h = (const block_iq4_nl_r4 *)((const char *)vx + (ix+4)*bx); for (int ib4 = 0; ib4 < nb/4; ++ib4) { + for (int iy = 0; iy < nrc_y; ++iy) { + _mm256_storeu_ps(d8+8*iy, _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)q8.y[iy][ib4].d))); + } for (int k = 0; k < 4; ++k) { - auto scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq4l[4*ib4+k].d)); - auto scales1 = _mm256_set_m128(scales128, scales128); - scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq4h[4*ib4+k].d)); - auto scales2 = _mm256_set_m128(scales128, scales128); - auto scales = _mm512_insertf32x8(_mm512_castps256_ps512(scales1), scales2, 1); - auto scales_m = _mm512_mul_ps(scales, _mm512_set1_ps(-64.f)); - auto bits1 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq4l[4*ib4+k].qs+0)), - _mm256_loadu_si256((const __m256i *)iq4h[4*ib4+k].qs+0), 1); - auto bits2 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq4l[4*ib4+k].qs+1)), - _mm256_loadu_si256((const __m256i *)iq4h[4*ib4+k].qs+1), 1); - qx[0] = _mm512_shuffle_epi8(values, _mm512_and_si512(bits1, m4)); - qx[1] = _mm512_shuffle_epi8(values, _mm512_and_si512(bits2, m4)); - qx[2] = _mm512_shuffle_epi8(values, _mm512_and_si512(_mm512_srli_epi16(bits1, 4), m4)); - qx[3] = _mm512_shuffle_epi8(values, _mm512_and_si512(_mm512_srli_epi16(bits2, 4), m4)); + auto scales = prepare(iq4l[4*ib4+k], iq4h[4*ib4+k]); for (int iy = 0; iy < nrc_y; ++iy) { - auto y8 = _mm256_loadu_si256((const __m256i*)q8.y[iy][ib4].qs+k); - auto y = _mm512_inserti32x8(_mm512_castsi256_si512(y8), y8, 1); - auto sumi = _mm512_setzero_si512(); - sumi = _mm512_dpbusd_epi32(sumi, qx[0], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00))); - sumi = _mm512_dpbusd_epi32(sumi, qx[1], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55))); - sumi = _mm512_dpbusd_epi32(sumi, qx[2], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa))); - sumi = _mm512_dpbusd_epi32(sumi, qx[3], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff))); - auto dy = _mm512_set1_ps(GGML_FP16_TO_FP32(q8.y[iy][ib4].d[k])); + auto sumi = dot(_mm256_loadu_si256((const __m256i*)q8.y[iy][ib4].qs+k)); + auto dy = _mm512_set1_ps(d8[8*iy+k]); acc[2*iy+0] = _mm512_fmadd_ps(_mm512_mul_ps(scales, dy), _mm512_cvtepi32_ps(sumi), acc[2*iy+0]); - acc[2*iy+1] = _mm512_fmadd_ps(scales_m, _mm512_set1_ps(GGML_FP16_TO_FP32(q8.y[iy][ib4].d[k+4])), acc[2*iy+1]); + acc[2*iy+1] = _mm512_fmadd_ps(scales, _mm512_set1_ps(d8[8*iy+k+4]), acc[2*iy+1]); } } } + for (int ib = 4*(nb/4); ib < nb; ++ib) { + auto scales = prepare(iq4l[ib], iq4h[ib]); + for (int iy = 0; iy < nrc_y; ++iy) { + auto qy = (const block_q8_1 *)q8.y[iy]; + auto sumi = dot(_mm256_loadu_si256((const __m256i*)qy[ib].qs)); + auto dy = _mm512_set1_ps(GGML_FP16_TO_FP32(qy[ib].d)); + acc[2*iy+0] = _mm512_fmadd_ps(_mm512_mul_ps(scales, dy), _mm512_cvtepi32_ps(sumi), acc[2*iy+0]); + acc[2*iy+1] = _mm512_fmadd_ps(scales, _mm512_set1_ps(GGML_FP16_TO_FP32(qy[ib].s)), acc[2*iy+1]); + } + } for (int iy = 0; iy < nrc_y; ++iy) { - auto sum512 = _mm512_add_ps(acc[2*iy+0], acc[2*iy+1]); + auto sum512 = _mm512_fmadd_ps(_mm512_set1_ps(-64.f), acc[2*iy+1], acc[2*iy+0]); acc[2*iy+0] = acc[2*iy+1] = _mm512_setzero_ps(); auto sum1 = _mm_add_ps(_mm512_extractf32x4_ps(sum512, 0), _mm512_extractf32x4_ps(sum512, 1)); auto sum2 = _mm_add_ps(_mm512_extractf32x4_ps(sum512, 2), _mm512_extractf32x4_ps(sum512, 3)); @@ -2530,37 +2549,57 @@ static void mul_mat_iq4_nl_r4_q8_1(int n, const void * vx, size_t bx, const Data auto values128 = _mm_loadu_si128((const __m128i *)iq4k_values); auto values = MM256_SET_M128I(values128, values128); int nb = n / QK4_NL; - GGML_ASSERT(nb%4 == 0); __m256 acc[nrc_y] = {}; - //__m256 acc[2*nrc_y] = {}; + __m256i qs[4]; + float d8[4*nrc_y]; + auto prepare = [&qs, &values, &m4] (const block_iq4_nl_r4& iq4) { + auto scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq4.d)); + auto scales = _mm256_set_m128(scales128, scales128); + auto bits1 = _mm256_loadu_si256((const __m256i *)iq4.qs+0); + auto bits2 = _mm256_loadu_si256((const __m256i *)iq4.qs+1); + qs[0] = _mm256_shuffle_epi8(values, _mm256_and_si256(bits1, m4)); + qs[1] = _mm256_shuffle_epi8(values, _mm256_and_si256(bits2, m4)); + qs[2] = _mm256_shuffle_epi8(values, _mm256_and_si256(_mm256_srli_epi16(bits1, 4), m4)); + qs[3] = _mm256_shuffle_epi8(values, _mm256_and_si256(_mm256_srli_epi16(bits2, 4), m4)); + return scales; + }; + auto dot = [&qs, &m1] (__m256i y) { + auto u1 = _mm256_sign_epi8(qs[0], qs[0]); + auto u2 = _mm256_sign_epi8(qs[1], qs[1]); + auto sumi1 = _mm256_add_epi32( + _mm256_madd_epi16(m1, _mm256_maddubs_epi16(u1, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), qs[0]))), + _mm256_madd_epi16(m1, _mm256_maddubs_epi16(u2, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), qs[1])))); + u1 = _mm256_sign_epi8(qs[2], qs[2]); + u2 = _mm256_sign_epi8(qs[3], qs[3]); + auto sumi2 = _mm256_add_epi32( + _mm256_madd_epi16(m1, _mm256_maddubs_epi16(u1, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), qs[2]))), + _mm256_madd_epi16(m1, _mm256_maddubs_epi16(u2, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), qs[3])))); + return _mm256_add_epi32(sumi1, sumi2); + }; for (int ix = 0; ix < nrc_x; ix += 4) { const block_iq4_nl_r4 * iq4 = (const block_iq4_nl_r4 *)((const char *)vx + ix*bx); for (int ib4 = 0; ib4 < nb/4; ++ib4) { + for (int iy = 0; iy < nrc_y; ++iy) { + _mm_storeu_ps(d8+4*iy, _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)q8.y[iy][ib4].d))); + } for (int k = 0; k < 4; ++k) { - auto scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq4[4*ib4+k].d)); - auto scales = _mm256_set_m128(scales128, scales128); - auto bits1 = _mm256_loadu_si256((const __m256i *)iq4[4*ib4+k].qs+0); - auto bits2 = _mm256_loadu_si256((const __m256i *)iq4[4*ib4+k].qs+1); - auto q1 = _mm256_shuffle_epi8(values, _mm256_and_si256(bits1, m4)); - auto q2 = _mm256_shuffle_epi8(values, _mm256_and_si256(bits2, m4)); - auto q3 = _mm256_shuffle_epi8(values, _mm256_and_si256(_mm256_srli_epi16(bits1, 4), m4)); - auto q4 = _mm256_shuffle_epi8(values, _mm256_and_si256(_mm256_srli_epi16(bits2, 4), m4)); - auto s1 = _mm256_sign_epi8(q1, q1); - auto s2 = _mm256_sign_epi8(q2, q2); - auto s3 = _mm256_sign_epi8(q3, q3); - auto s4 = _mm256_sign_epi8(q4, q4); - + auto scales = prepare(iq4[4*ib4+k]); for (int iy = 0; iy < nrc_y; ++iy) { - auto y = _mm256_loadu_si256((const __m256i*)q8.y[iy][ib4].qs+k); - auto sumi1 = _mm256_add_epi32(_mm256_madd_epi16(m1, _mm256_maddubs_epi16(s1, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), q1))), - _mm256_madd_epi16(m1, _mm256_maddubs_epi16(s2, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), q2)))); - auto sumi2 = _mm256_add_epi32(_mm256_madd_epi16(m1, _mm256_maddubs_epi16(s3, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), q3))), - _mm256_madd_epi16(m1, _mm256_maddubs_epi16(s4, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), q4)))); - auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(GGML_FP16_TO_FP32(q8.y[iy][ib4].d[k]))); - acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), acc[iy]); + auto sumi = dot(_mm256_loadu_si256((const __m256i*)q8.y[iy][ib4].qs+k)); + auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(d8[4*iy+k])); + acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[iy]); } } } + for (int ib = 4*(nb/4); ib < nb; ++ib) { + auto scales = prepare(iq4[ib]); + for (int iy = 0; iy < nrc_y; ++iy) { + auto qy = (const block_q8_1 *)q8.y[iy]; + auto sumi = dot(_mm256_loadu_si256((const __m256i*)qy[ib].qs)); + auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(GGML_FP16_TO_FP32(qy[ib].d))); + acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[iy]); + } + } for (int iy = 0; iy < nrc_y; ++iy) { auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1)); info.store(ix, iy, sum); @@ -2797,43 +2836,73 @@ static void mul_mat_q5_0_r4_q8_1_avx2(int n, const void * vx, size_t bx, const D Q8<nrc_y, block_q8_1_x4> q8(info); auto m4 = _mm256_set1_epi8(0xf); auto m5 = _mm256_set1_epi8(0x10); +#ifndef HAVE_FANCY_SIMD auto m1 = _mm256_set1_epi16(1); +#endif + auto mscale = _mm256_set_m128(_mm_set1_ps(-8.f), _mm_set1_ps(1.f)); int nb = n / QK5_0; - GGML_ASSERT(nb%4 == 0); __m256 acc[nrc_y] = {}; + __m256i qx[4]; float d8[8*nrc_y]; + auto prepare = [&qx, &m4, &m5] (const block_q5_0_r4& iq5) { + auto scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq5.d)); + auto scales = _mm256_set_m128(scales128, scales128); + auto bits1 = _mm256_loadu_si256((const __m256i *)iq5.qs+0); + auto bits2 = _mm256_loadu_si256((const __m256i *)iq5.qs+1); + auto hbits = _mm_loadu_si128((const __m128i *)iq5.qh); + auto hb = MM256_SET_M128I(_mm_srli_epi16(hbits, 1), hbits); + qx[0] = _mm256_or_si256(_mm256_and_si256(bits1, m4), _mm256_and_si256(_mm256_slli_epi16(hb, 4), m5)); + qx[1] = _mm256_or_si256(_mm256_and_si256(bits2, m4), _mm256_and_si256(_mm256_slli_epi16(hb, 2), m5)); + qx[2] = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(bits1, 4), m4), _mm256_and_si256(hb, m5)); + qx[3] = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(bits2, 4), m4), _mm256_and_si256(_mm256_srli_epi16(hb, 2), m5));; + return scales; + }; +#ifdef HAVE_FANCY_SIMD + auto dot = [&qx] (__m256i y) { + auto sumi = _mm256_setzero_si256(); + sumi = _mm256_dpbusd_epi32(sumi, qx[0], _mm256_shuffle_epi32(y, 0x00)); + sumi = _mm256_dpbusd_epi32(sumi, qx[1], _mm256_shuffle_epi32(y, 0x55)); + sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(y, 0xaa)); + sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(y, 0xff)); + return sumi; + }; +#else + auto dot = [&qx, &m1] (__m256i y) { + auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[0], _mm256_shuffle_epi32(y, 0x00)), + _mm256_maddubs_epi16(qx[1], _mm256_shuffle_epi32(y, 0x55))); + auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[2], _mm256_shuffle_epi32(y, 0xaa)), + _mm256_maddubs_epi16(qx[3], _mm256_shuffle_epi32(y, 0xff))); + auto sumi = _mm256_madd_epi16(m1, _mm256_add_epi16(sumi1, sumi2)); + return sumi; + }; +#endif for (int ix = 0; ix < nrc_x; ix += 4) { const block_q5_0_r4 * iq5 = (const block_q5_0_r4 *)((const char *)vx + ix*bx); for (int ib4 = 0; ib4 < nb/4; ++ib4) { for (int iy = 0; iy < nrc_y; ++iy) { auto scales = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)q8.y[iy][ib4].d)); - _mm256_storeu_ps(d8 + 8*iy, scales); + _mm256_storeu_ps(d8 + 8*iy, _mm256_mul_ps(mscale, scales)); } for (int k = 0; k < 4; ++k) { - auto scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq5[4*ib4+k].d)); - auto scales = _mm256_set_m128(scales128, scales128); - auto scales_m = _mm256_mul_ps(scales, _mm256_set1_ps(-8.f)); - auto bits1 = _mm256_loadu_si256((const __m256i *)iq5[4*ib4+k].qs+0); - auto bits2 = _mm256_loadu_si256((const __m256i *)iq5[4*ib4+k].qs+1); - auto hbits = _mm_loadu_si128((const __m128i *)iq5[4*ib4+k].qh); - auto hb = MM256_SET_M128I(_mm_srli_epi16(hbits, 1), hbits); - auto q1 = _mm256_or_si256(_mm256_and_si256(bits1, m4), _mm256_and_si256(_mm256_slli_epi16(hb, 4), m5)); - auto q2 = _mm256_or_si256(_mm256_and_si256(bits2, m4), _mm256_and_si256(_mm256_slli_epi16(hb, 2), m5)); - auto q3 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(bits1, 4), m4), _mm256_and_si256(hb, m5)); - auto q4 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(bits2, 4), m4), _mm256_and_si256(_mm256_srli_epi16(hb, 2), m5));; + auto scales = prepare(iq5[4*ib4+k]); for (int iy = 0; iy < nrc_y; ++iy) { - auto y = _mm256_loadu_si256((const __m256i*)q8.y[iy][ib4].qs+k); - auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(q1, _mm256_shuffle_epi32(y, 0x00)), - _mm256_maddubs_epi16(q2, _mm256_shuffle_epi32(y, 0x55))); - auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(q3, _mm256_shuffle_epi32(y, 0xaa)), - _mm256_maddubs_epi16(q4, _mm256_shuffle_epi32(y, 0xff))); - auto sumi = _mm256_madd_epi16(m1, _mm256_add_epi16(sumi1, sumi2)); + auto sumi = dot(_mm256_loadu_si256((const __m256i*)q8.y[iy][ib4].qs+k)); auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(d8[8*iy+k])); acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[iy]); - acc[iy] = _mm256_fmadd_ps(scales_m, _mm256_set1_ps(d8[8*iy+k+4]), acc[iy]); + acc[iy] = _mm256_fmadd_ps(scales, _mm256_set1_ps(d8[8*iy+k+4]), acc[iy]); } } } + for (int ib = 4*(nb/4); ib < nb; ++ib) { + auto scales = prepare(iq5[ib]); + for (int iy = 0; iy < nrc_y; ++iy) { + auto qy = (const block_q8_1 *)q8.y[iy]; + auto sumi = dot(_mm256_loadu_si256((const __m256i*)qy[ib].qs)); + auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(GGML_FP16_TO_FP32(qy[ib].d))); + acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[iy]); + acc[iy] = _mm256_fmadd_ps(scales, _mm256_set1_ps(-8.f*GGML_FP16_TO_FP32(qy[ib].s)), acc[iy]); + } + } for (int iy = 0; iy < nrc_y; ++iy) { auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1)); info.store(ix, iy, sum); @@ -2853,50 +2922,68 @@ static void mul_mat_q5_0_r4_q8_1(int n, const void * vx, size_t bx, const DataIn auto m4 = _mm512_set1_epi8(0xf); auto m5 = _mm512_set1_epi8(0x10); int nb = n / QK5_0; - GGML_ASSERT(nb%4 == 0); __m512 acc[2*nrc_y] = {}; __m512i qx[4]; + float d8[8*nrc_y]; + auto prepare = [&qx, &m4, &m5] (const block_q5_0_r4& iq5l, const block_q5_0_r4& iq5h) { + auto scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq5l.d)); + auto scales1 = _mm256_set_m128(scales128, scales128); + scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq5h.d)); + auto scales2 = _mm256_set_m128(scales128, scales128); + auto scales = _mm512_insertf32x8(_mm512_castps256_ps512(scales1), scales2, 1); + auto bits1 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq5l.qs+0)), + _mm256_loadu_si256((const __m256i *)iq5h.qs+0), 1); + auto bits2 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq5l.qs+1)), + _mm256_loadu_si256((const __m256i *)iq5h.qs+1), 1); + auto hbits1 = _mm_loadu_si128((const __m128i *)iq5l.qh); + auto hbits2 = _mm_loadu_si128((const __m128i *)iq5h.qh); + auto hb1 = MM256_SET_M128I(_mm_srli_epi16(hbits1, 1), hbits1); + auto hb2 = MM256_SET_M128I(_mm_srli_epi16(hbits2, 1), hbits2); + auto hb = _mm512_inserti32x8(_mm512_castsi256_si512(hb1), hb2, 1); + qx[0] = _mm512_or_si512(_mm512_and_si512(bits1, m4), _mm512_and_si512(_mm512_slli_epi16(hb, 4), m5)); + qx[1] = _mm512_or_si512(_mm512_and_si512(bits2, m4), _mm512_and_si512(_mm512_slli_epi16(hb, 2), m5)); + qx[2] = _mm512_or_si512(_mm512_and_si512(_mm512_srli_epi16(bits1, 4), m4), _mm512_and_si512(hb, m5)); + qx[3] = _mm512_or_si512(_mm512_and_si512(_mm512_srli_epi16(bits2, 4), m4), _mm512_and_si512(_mm512_srli_epi16(hb, 2), m5)); + return scales; + }; + auto dot = [&qx] (__m256i y8) { + auto y = _mm512_inserti32x8(_mm512_castsi256_si512(y8), y8, 1); + auto sumi = _mm512_setzero_si512(); + sumi = _mm512_dpbusd_epi32(sumi, qx[0], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00))); + sumi = _mm512_dpbusd_epi32(sumi, qx[1], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55))); + sumi = _mm512_dpbusd_epi32(sumi, qx[2], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa))); + sumi = _mm512_dpbusd_epi32(sumi, qx[3], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff))); + return sumi; + }; for (int ix = 0; ix < nrc_x; ix += 8) { const block_q5_0_r4 * iq5l = (const block_q5_0_r4 *)((const char *)vx + (ix+0)*bx); const block_q5_0_r4 * iq5h = (const block_q5_0_r4 *)((const char *)vx + (ix+4)*bx); for (int ib4 = 0; ib4 < nb/4; ++ib4) { + for (int iy = 0; iy < nrc_y; ++iy) { + _mm256_storeu_ps(d8+8*iy, _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)q8.y[iy][ib4].d))); + } for (int k = 0; k < 4; ++k) { - auto scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq5l[4*ib4+k].d)); - auto scales1 = _mm256_set_m128(scales128, scales128); - scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq5h[4*ib4+k].d)); - auto scales2 = _mm256_set_m128(scales128, scales128); - auto scales = _mm512_insertf32x8(_mm512_castps256_ps512(scales1), scales2, 1); - auto scales_m = _mm512_mul_ps(scales, _mm512_set1_ps(-8.f)); - auto bits1 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq5l[4*ib4+k].qs+0)), - _mm256_loadu_si256((const __m256i *)iq5h[4*ib4+k].qs+0), 1); - auto bits2 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq5l[4*ib4+k].qs+1)), - _mm256_loadu_si256((const __m256i *)iq5h[4*ib4+k].qs+1), 1); - auto hbits1 = _mm_loadu_si128((const __m128i *)iq5l[4*ib4+k].qh); - auto hbits2 = _mm_loadu_si128((const __m128i *)iq5h[4*ib4+k].qh); - auto hb1 = MM256_SET_M128I(_mm_srli_epi16(hbits1, 1), hbits1); - auto hb2 = MM256_SET_M128I(_mm_srli_epi16(hbits2, 1), hbits2); - auto hb = _mm512_inserti32x8(_mm512_castsi256_si512(hb1), hb2, 1); - qx[0] = _mm512_or_si512(_mm512_and_si512(bits1, m4), _mm512_and_si512(_mm512_slli_epi16(hb, 4), m5)); - qx[1] = _mm512_or_si512(_mm512_and_si512(bits2, m4), _mm512_and_si512(_mm512_slli_epi16(hb, 2), m5)); - //qx[2] = _mm512_and_si512(_mm512_srli_epi16(bits1, 4), m4) | _mm512_and_si512(_mm512_slli_epi16(hb, 2), m5); - qx[2] = _mm512_or_si512(_mm512_and_si512(_mm512_srli_epi16(bits1, 4), m4), _mm512_and_si512(hb, m5)); - qx[3] = _mm512_or_si512(_mm512_and_si512(_mm512_srli_epi16(bits2, 4), m4), _mm512_and_si512(_mm512_srli_epi16(hb, 2), m5)); + auto scales = prepare(iq5l[4*ib4+k], iq5h[4*ib4+k]); for (int iy = 0; iy < nrc_y; ++iy) { - auto y8 = _mm256_loadu_si256((const __m256i*)q8.y[iy][ib4].qs+k); - auto y = _mm512_inserti32x8(_mm512_castsi256_si512(y8), y8, 1); - auto sumi = _mm512_setzero_si512(); - sumi = _mm512_dpbusd_epi32(sumi, qx[0], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00))); - sumi = _mm512_dpbusd_epi32(sumi, qx[1], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55))); - sumi = _mm512_dpbusd_epi32(sumi, qx[2], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa))); - sumi = _mm512_dpbusd_epi32(sumi, qx[3], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff))); - auto dy = _mm512_set1_ps(GGML_FP16_TO_FP32(q8.y[iy][ib4].d[k])); + auto sumi = dot(_mm256_loadu_si256((const __m256i*)q8.y[iy][ib4].qs+k)); + auto dy = _mm512_set1_ps(d8[8*iy+k]); acc[2*iy+0] = _mm512_fmadd_ps(_mm512_mul_ps(scales, dy), _mm512_cvtepi32_ps(sumi), acc[2*iy+0]); - acc[2*iy+1] = _mm512_fmadd_ps(scales_m, _mm512_set1_ps(GGML_FP16_TO_FP32(q8.y[iy][ib4].d[k+4])), acc[2*iy+1]); + acc[2*iy+1] = _mm512_fmadd_ps(scales, _mm512_set1_ps(d8[8*iy+k+4]), acc[2*iy+1]); } } } + for (int ib = 4*(nb/4); ib < nb; ++ib) { + auto scales = prepare(iq5l[ib], iq5h[ib]); + for (int iy = 0; iy < nrc_y; ++iy) { + auto qy = (const block_q8_1 *)q8.y[iy]; + auto sumi = dot(_mm256_loadu_si256((const __m256i*)qy[ib].qs)); + auto dy = _mm512_set1_ps(GGML_FP16_TO_FP32(qy[ib].d)); + acc[2*iy+0] = _mm512_fmadd_ps(_mm512_mul_ps(scales, dy), _mm512_cvtepi32_ps(sumi), acc[2*iy+0]); + acc[2*iy+1] = _mm512_fmadd_ps(scales, _mm512_set1_ps(GGML_FP16_TO_FP32(qy[ib].s)), acc[2*iy+1]); + } + } for (int iy = 0; iy < nrc_y; ++iy) { - auto sum512 = _mm512_add_ps(acc[2*iy+0], acc[2*iy+1]); + auto sum512 = _mm512_fmadd_ps(_mm512_set1_ps(-8.f), acc[2*iy+1], acc[2*iy+0]); acc[2*iy+0] = acc[2*iy+1] = _mm512_setzero_ps(); auto sum1 = _mm_add_ps(_mm512_extractf32x4_ps(sum512, 0), _mm512_extractf32x4_ps(sum512, 1)); auto sum2 = _mm_add_ps(_mm512_extractf32x4_ps(sum512, 2), _mm512_extractf32x4_ps(sum512, 3)); @@ -2919,51 +3006,72 @@ static void mul_mat_q6_0_r4_q8_1_avx2(int n, const void * vx, size_t bx, const D Q8<nrc_y, block_q8_1_x4> q8(info); auto m4 = _mm256_set1_epi8(0xf); auto m6 = _mm256_set1_epi8(0x30); + auto mscale = _mm256_set_m128(_mm_set1_ps(-16.f), _mm_set1_ps(1.f)); #ifndef HAVE_FANCY_SIMD auto m1 = _mm256_set1_epi16(1); #endif int nb = n / QK6_0; - GGML_ASSERT(nb%4 == 0); __m256 acc[nrc_y] = {}; float d8[8*nrc_y]; + __m256i qx[4]; + auto prepare = [&qx, &m4, &m6] (const block_q6_0_r4& iq6) { + auto scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq6.d)); + auto scales = _mm256_set_m128(scales128, scales128); + auto bits1 = _mm256_loadu_si256((const __m256i *)iq6.qs+0); + auto bits2 = _mm256_loadu_si256((const __m256i *)iq6.qs+1); + auto hbits = _mm256_loadu_si256((const __m256i *)iq6.qh); + qx[0] = _mm256_or_si256(_mm256_and_si256(bits1, m4), _mm256_and_si256(_mm256_slli_epi16(hbits, 4), m6)); + qx[1] = _mm256_or_si256(_mm256_and_si256(bits2, m4), _mm256_and_si256(_mm256_slli_epi16(hbits, 2), m6)); + qx[2] = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(bits1, 4), m4), _mm256_and_si256(hbits, m6)); + qx[3] = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(bits2, 4), m4), _mm256_and_si256(_mm256_srli_epi16(hbits, 2), m6)); + return scales; + }; +#ifdef HAVE_FANCY_SIMD + auto dot = [&qx] (__m256i y) { + auto sumi = _mm256_dpbusd_epi32(_mm256_setzero_si256(), qx[0], _mm256_shuffle_epi32(y, 0x00)); + sumi = _mm256_dpbusd_epi32(sumi, qx[1], _mm256_shuffle_epi32(y, 0x55)); + sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(y, 0xaa)); + sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(y, 0xff)); + return sumi; + }; +#else + auto dot = [&qx, &m1] (__m256i y) { + auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[0], _mm256_shuffle_epi32(y, 0x00)), + _mm256_maddubs_epi16(qx[1], _mm256_shuffle_epi32(y, 0x55))); + auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[2], _mm256_shuffle_epi32(y, 0xaa)), + _mm256_maddubs_epi16(qx[3], _mm256_shuffle_epi32(y, 0xff))); + auto sumi = _mm256_add_epi32(_mm256_madd_epi16(m1, sumi1), _mm256_madd_epi16(m1, sumi2)); + return sumi; + }; +#endif for (int ix = 0; ix < nrc_x; ix += 4) { const block_q6_0_r4 * iq6 = (const block_q6_0_r4 *)((const char *)vx + ix*bx); for (int ib4 = 0; ib4 < nb/4; ++ib4) { for (int iy = 0; iy < nrc_y; ++iy) { auto scales = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)q8.y[iy][ib4].d)); - _mm256_storeu_ps(d8 + 8*iy, scales); + _mm256_storeu_ps(d8 + 8*iy, _mm256_mul_ps(scales, mscale)); } for (int k = 0; k < 4; ++k) { - auto scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq6[4*ib4+k].d)); - auto scales = _mm256_set_m128(scales128, scales128); - auto scales_m = _mm256_mul_ps(scales, _mm256_set1_ps(-16.f)); - auto bits1 = _mm256_loadu_si256((const __m256i *)iq6[4*ib4+k].qs+0); - auto bits2 = _mm256_loadu_si256((const __m256i *)iq6[4*ib4+k].qs+1); - auto hbits = _mm256_loadu_si256((const __m256i *)iq6[4*ib4+k].qh); - auto q1 = _mm256_or_si256(_mm256_and_si256(bits1, m4), _mm256_and_si256(_mm256_slli_epi16(hbits, 4), m6)); - auto q2 = _mm256_or_si256(_mm256_and_si256(bits2, m4), _mm256_and_si256(_mm256_slli_epi16(hbits, 2), m6)); - auto q3 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(bits1, 4), m4), _mm256_and_si256(hbits, m6)); - auto q4 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(bits2, 4), m4), _mm256_and_si256(_mm256_srli_epi16(hbits, 2), m6)); + auto scales = prepare(iq6[4*ib4+k]); for (int iy = 0; iy < nrc_y; ++iy) { - auto y = _mm256_loadu_si256((const __m256i*)q8.y[iy][ib4].qs+k); -#ifdef HAVE_FANCY_SIMD - auto sumi = _mm256_dpbusd_epi32(_mm256_setzero_si256(), q1, _mm256_shuffle_epi32(y, 0x00)); - sumi = _mm256_dpbusd_epi32(sumi, q2, _mm256_shuffle_epi32(y, 0x55)); - sumi = _mm256_dpbusd_epi32(sumi, q3, _mm256_shuffle_epi32(y, 0xaa)); - sumi = _mm256_dpbusd_epi32(sumi, q4, _mm256_shuffle_epi32(y, 0xff)); -#else - auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(q1, _mm256_shuffle_epi32(y, 0x00)), - _mm256_maddubs_epi16(q2, _mm256_shuffle_epi32(y, 0x55))); - auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(q3, _mm256_shuffle_epi32(y, 0xaa)), - _mm256_maddubs_epi16(q4, _mm256_shuffle_epi32(y, 0xff))); - auto sumi = _mm256_add_epi32(_mm256_madd_epi16(m1, sumi1), _mm256_madd_epi16(m1, sumi2)); -#endif + auto sumi = dot(_mm256_loadu_si256((const __m256i*)q8.y[iy][ib4].qs+k)); auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(d8[8*iy+k])); acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[iy]); - acc[iy] = _mm256_fmadd_ps(scales_m, _mm256_set1_ps(d8[8*iy+k+4]), acc[iy]); + acc[iy] = _mm256_fmadd_ps(scales, _mm256_set1_ps(d8[8*iy+k+4]), acc[iy]); } } } + for (int ib = 4*(nb/4); ib < nb; ++ib) { + auto scales = prepare(iq6[ib]); + for (int iy = 0; iy < nrc_y; ++iy) { + auto qy = (const block_q8_1 *)q8.y[iy]; + auto sumi = dot(_mm256_loadu_si256((const __m256i*)qy[ib].qs)); + auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(GGML_FP16_TO_FP32(qy[ib].d))); + acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[iy]); + acc[iy] = _mm256_fmadd_ps(scales, _mm256_set1_ps(-16.f*GGML_FP16_TO_FP32(qy[ib].s)), acc[iy]); + } + } + for (int iy = 0; iy < nrc_y; ++iy) { auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1)); info.store(ix, iy, sum); @@ -2983,47 +3091,67 @@ static void mul_mat_q6_0_r4_q8_1(int n, const void * vx, size_t bx, const DataIn auto m4 = _mm512_set1_epi8(0xf); auto m6 = _mm512_set1_epi8(0x30); int nb = n / QK6_0; - GGML_ASSERT(nb%4 == 0); __m512 acc[2*nrc_y] = {}; __m512i qx[4]; + float d8[8*nrc_y]; + auto prepare = [&qx, &m4, &m6] (const block_q6_0_r4& iq6l, const block_q6_0_r4& iq6h) { + auto scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq6l.d)); + auto scales1 = _mm256_set_m128(scales128, scales128); + scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq6h.d)); + auto scales2 = _mm256_set_m128(scales128, scales128); + auto scales = _mm512_insertf32x8(_mm512_castps256_ps512(scales1), scales2, 1); + auto bits1 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq6l.qs+0)), + _mm256_loadu_si256((const __m256i *)iq6h.qs+0), 1); + auto bits2 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq6l.qs+1)), + _mm256_loadu_si256((const __m256i *)iq6h.qs+1), 1); + auto hbits1 = _mm256_loadu_si256((const __m256i *)iq6l.qh); + auto hbits2 = _mm256_loadu_si256((const __m256i *)iq6h.qh); + auto hb = _mm512_inserti32x8(_mm512_castsi256_si512(hbits1), hbits2, 1); + qx[0] = _mm512_and_si512(bits1, m4) | _mm512_and_si512(_mm512_slli_epi16(hb, 4), m6); + qx[1] = _mm512_and_si512(bits2, m4) | _mm512_and_si512(_mm512_slli_epi16(hb, 2), m6);; + qx[2] = _mm512_and_si512(_mm512_srli_epi16(bits1, 4), m4) | _mm512_and_si512(hb, m6); + qx[3] = _mm512_and_si512(_mm512_srli_epi16(bits2, 4), m4) | _mm512_and_si512(_mm512_srli_epi16(hb, 2), m6); + return scales; + }; + auto dot = [&qx] (__m256i y8) { + auto y = _mm512_inserti32x8(_mm512_castsi256_si512(y8), y8, 1); + auto sumi = _mm512_setzero_si512(); + sumi = _mm512_dpbusd_epi32(sumi, qx[0], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00))); + sumi = _mm512_dpbusd_epi32(sumi, qx[1], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55))); + sumi = _mm512_dpbusd_epi32(sumi, qx[2], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa))); + sumi = _mm512_dpbusd_epi32(sumi, qx[3], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff))); + return sumi; + }; for (int ix = 0; ix < nrc_x; ix += 8) { const block_q6_0_r4 * iq6l = (const block_q6_0_r4 *)((const char *)vx + (ix+0)*bx); const block_q6_0_r4 * iq6h = (const block_q6_0_r4 *)((const char *)vx + (ix+4)*bx); for (int ib4 = 0; ib4 < nb/4; ++ib4) { + for (int iy = 0; iy < nrc_y; ++iy) { + auto scales = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)q8.y[iy][ib4].d)); + _mm256_storeu_ps(d8 + 8*iy, scales); + } for (int k = 0; k < 4; ++k) { - auto scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq6l[4*ib4+k].d)); - auto scales1 = _mm256_set_m128(scales128, scales128); - scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq6h[4*ib4+k].d)); - auto scales2 = _mm256_set_m128(scales128, scales128); - auto scales = _mm512_insertf32x8(_mm512_castps256_ps512(scales1), scales2, 1); - auto scales_m = _mm512_mul_ps(scales, _mm512_set1_ps(-16.f)); - auto bits1 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq6l[4*ib4+k].qs+0)), - _mm256_loadu_si256((const __m256i *)iq6h[4*ib4+k].qs+0), 1); - auto bits2 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq6l[4*ib4+k].qs+1)), - _mm256_loadu_si256((const __m256i *)iq6h[4*ib4+k].qs+1), 1); - auto hbits1 = _mm256_loadu_si256((const __m256i *)iq6l[4*ib4+k].qh); - auto hbits2 = _mm256_loadu_si256((const __m256i *)iq6h[4*ib4+k].qh); - auto hb = _mm512_inserti32x8(_mm512_castsi256_si512(hbits1), hbits2, 1); - qx[0] = _mm512_and_si512(bits1, m4) | _mm512_and_si512(_mm512_slli_epi16(hb, 4), m6); - qx[1] = _mm512_and_si512(bits2, m4) | _mm512_and_si512(_mm512_slli_epi16(hb, 2), m6);; - qx[2] = _mm512_and_si512(_mm512_srli_epi16(bits1, 4), m4) | _mm512_and_si512(hb, m6); - qx[3] = _mm512_and_si512(_mm512_srli_epi16(bits2, 4), m4) | _mm512_and_si512(_mm512_srli_epi16(hb, 2), m6); + auto scales = prepare(iq6l[4*ib4+k], iq6h[4*ib4+k]); for (int iy = 0; iy < nrc_y; ++iy) { - auto y8 = _mm256_loadu_si256((const __m256i*)q8.y[iy][ib4].qs+k); - auto y = _mm512_inserti32x8(_mm512_castsi256_si512(y8), y8, 1); - auto sumi = _mm512_setzero_si512(); - sumi = _mm512_dpbusd_epi32(sumi, qx[0], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00))); - sumi = _mm512_dpbusd_epi32(sumi, qx[1], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55))); - sumi = _mm512_dpbusd_epi32(sumi, qx[2], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa))); - sumi = _mm512_dpbusd_epi32(sumi, qx[3], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff))); - auto dy = _mm512_set1_ps(GGML_FP16_TO_FP32(q8.y[iy][ib4].d[k])); + auto sumi = dot(_mm256_loadu_si256((const __m256i*)q8.y[iy][ib4].qs+k)); + auto dy = _mm512_set1_ps(d8[8*iy+k]); acc[2*iy+0] = _mm512_fmadd_ps(_mm512_mul_ps(scales, dy), _mm512_cvtepi32_ps(sumi), acc[2*iy+0]); - acc[2*iy+1] = _mm512_fmadd_ps(scales_m, _mm512_set1_ps(GGML_FP16_TO_FP32(q8.y[iy][ib4].d[k+4])), acc[2*iy+1]); + acc[2*iy+1] = _mm512_fmadd_ps(scales, _mm512_set1_ps(d8[8*iy+k+4]), acc[2*iy+1]); } } } + for (int ib = 4*(nb/4); ib < nb; ++ib) { + auto scales = prepare(iq6l[ib], iq6h[ib]); + for (int iy = 0; iy < nrc_y; ++iy) { + auto qy = (const block_q8_1 *)q8.y[iy]; + auto sumi = dot(_mm256_loadu_si256((const __m256i*)qy[ib].qs)); + auto dy = _mm512_set1_ps(GGML_FP16_TO_FP32(qy[ib].d)); + acc[2*iy+0] = _mm512_fmadd_ps(_mm512_mul_ps(scales, dy), _mm512_cvtepi32_ps(sumi), acc[2*iy+0]); + acc[2*iy+1] = _mm512_fmadd_ps(scales, _mm512_set1_ps(GGML_FP16_TO_FP32(qy[ib].s)), acc[2*iy+1]); + } + } for (int iy = 0; iy < nrc_y; ++iy) { - auto sum512 = _mm512_add_ps(acc[2*iy+0], acc[2*iy+1]); + auto sum512 = _mm512_fmadd_ps(_mm512_set1_ps(-16.f), acc[2*iy+1], acc[2*iy+0]); acc[2*iy+0] = acc[2*iy+1] = _mm512_setzero_ps(); auto sum1 = _mm_add_ps(_mm512_extractf32x4_ps(sum512, 0), _mm512_extractf32x4_ps(sum512, 1)); auto sum2 = _mm_add_ps(_mm512_extractf32x4_ps(sum512, 2), _mm512_extractf32x4_ps(sum512, 3)); @@ -12087,7 +12215,6 @@ void mul_mat_qx_r4_q8_0(int n, const void * vx, size_t bx, const DataInfo& info, Q8<nrc_y, block_q8_0_x4> q8(info); Dequantizer deq(vx, bx); int nb = n / QK4_NL; - GGML_ASSERT(nb%4 == 0); int8x16_t qx[8]; float d8[4*nrc_y]; float32x4_t acc[nrc_y] = {}; @@ -12098,7 +12225,7 @@ void mul_mat_qx_r4_q8_0(int n, const void * vx, size_t bx, const DataInfo& info, vst1q_f32(d8+4*iy, vcvt_f32_f16(vld1_f16((const float16_t *)q8.y[iy][ib4].d))); } for (int k = 0; k < 4; ++k) { - auto scales = deq.prepare(ib4, k, qx); + auto scales = deq.prepare(4*ib4+k, qx); for (int iy = 0; iy < nrc_y; ++iy) { auto y = vld1q_s8_x2(q8.y[iy][ib4].qs+32*k); auto sumi = interleaved_dotq(qx, y); @@ -12107,6 +12234,16 @@ void mul_mat_qx_r4_q8_0(int n, const void * vx, size_t bx, const DataInfo& info, } } } + for (int ib = 4*(nb/4); ib < nb; ++ib) { + auto scales = deq.prepare(ib, qx); + for (int iy = 0; iy < nrc_y; ++iy) { + auto qy = (const block_q8_0 *)q8.y[iy]; + auto y = vld1q_s8_x2(qy[ib].qs); + auto sumi = interleaved_dotq(qx, y); + auto d4d8 = vmulq_f32(scales, vdupq_n_f32(GGML_FP16_TO_FP32(qy[ib].d))); + acc[iy] = vfmaq_f32(acc[iy], d4d8, vcvtq_f32_s32(sumi)); + } + } for (int iy = 0; iy < nrc_y; ++iy) { info.store(ix, iy, deq.result(acc[iy])); acc[iy] = vdupq_n_f32(0.f); @@ -12164,9 +12301,9 @@ void mul_mat_qx_r8_q8_0(int n, const void * vx, size_t bx, const DataInfo& info, struct IQ4_NL_R4_Dequantizer { IQ4_NL_R4_Dequantizer(const void * vx, size_t bx) : cx((const char *)vx), bx(bx), values(vld1q_s8(iq4k_values)) {} inline void new_row(int ix) { iq4 = (const block_iq4_nl_r4 *)(cx + ix*bx); } - inline float32x4_t prepare(int ib4, int k, int8x16_t * qx) const { - auto scales = vcvt_f32_f16(vld1_f16((const float16_t *)iq4[4*ib4+k].d)); - auto bits = vld1q_u8_x4(iq4[4*ib4+k].qs); + inline float32x4_t prepare(int ib, int8x16_t * qx) const { + auto scales = vcvt_f32_f16(vld1_f16((const float16_t *)iq4[ib].d)); + auto bits = vld1q_u8_x4(iq4[ib].qs); prepare_iq4_nl_quants(values, m4, bits, qx); return scales; } @@ -12242,10 +12379,10 @@ struct Q4_0_R8_Dequantizer { struct Q5_0_R4_Dequantizer { Q5_0_R4_Dequantizer(const void * vx, size_t bx) : cx((const char *)vx), bx(bx) {} inline void new_row(int ix) { iq5 = (const block_q5_0_r4 *)(cx + ix*bx); } - inline float32x4_t prepare(int ib4, int k, int8x16_t * qx) const { - auto scales = vcvt_f32_f16(vld1_f16((const float16_t *)iq5[4*ib4+k].d)); - auto lbits = vld1q_u8_x4(iq5[4*ib4+k].qs); - auto hbits = vld1q_u8(iq5[4*ib4+k].qh); + inline float32x4_t prepare(int ib, int8x16_t * qx) const { + auto scales = vcvt_f32_f16(vld1_f16((const float16_t *)iq5[ib].d)); + auto lbits = vld1q_u8_x4(iq5[ib].qs); + auto hbits = vld1q_u8(iq5[ib].qh); qx[0] = vaddq_s8(vandq_u8(lbits.val[0], m4) | vandq_u8(vshlq_n_u8(hbits, 4), m5), m16); // 0...3 qx[1] = vaddq_s8(vandq_u8(lbits.val[1], m4) | vandq_u8(vshlq_n_u8(hbits, 3), m5), m16); // 16..19 qx[2] = vaddq_s8(vandq_u8(lbits.val[2], m4) | vandq_u8(vshlq_n_u8(hbits, 2), m5), m16); // 4...7 @@ -12271,10 +12408,10 @@ struct Q5_0_R4_Dequantizer { struct Q6_0_R4_Dequantizer { Q6_0_R4_Dequantizer(const void * vx, size_t bx) : cx((const char *)vx), bx(bx) {} inline void new_row(int ix) { iq6 = (const block_q6_0_r4 *)(cx + ix*bx); } - inline float32x4_t prepare(int ib4, int k, int8x16_t * qx) const { - auto scales = vcvt_f32_f16(vld1_f16((const float16_t *)iq6[4*ib4+k].d)); - auto lbits = vld1q_u8_x4(iq6[4*ib4+k].qs); - auto hbits = vld1q_u8_x2(iq6[4*ib4+k].qh); + inline float32x4_t prepare(int ib, int8x16_t * qx) const { + auto scales = vcvt_f32_f16(vld1_f16((const float16_t *)iq6[ib].d)); + auto lbits = vld1q_u8_x4(iq6[ib].qs); + auto hbits = vld1q_u8_x2(iq6[ib].qh); qx[0] = vaddq_s8(vandq_u8(lbits.val[0], m4) | vandq_u8(vshlq_n_u8(hbits.val[0], 4), m6), m32); // 0...3 qx[1] = vaddq_s8(vandq_u8(lbits.val[1], m4) | vandq_u8(vshlq_n_u8(hbits.val[1], 4), m6), m32); // 16..19 qx[2] = vaddq_s8(vandq_u8(lbits.val[2], m4) | vandq_u8(vshlq_n_u8(hbits.val[0], 2), m6), m32); // 4...7 diff --git a/src/llama.cpp b/src/llama.cpp index b6a4a06d..570c056c 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -16075,7 +16075,10 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n else new_type = ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || is_iq2_m ? GGML_TYPE_IQ3_S : GGML_TYPE_Q2_K; ++qs.i_attention_wv; } - else if (qs.model.hparams.n_expert == 8 && name.find("attn_k.weight") != std::string::npos) { + else if (qs.model.hparams.n_expert >= 8 && name.find("attn_k") != std::string::npos) { + new_type = GGML_TYPE_Q4_K; + } + else if (qs.model.hparams.n_expert >= 8 && name.find("attn_q") != std::string::npos) { new_type = GGML_TYPE_Q4_K; } else if (name.find("attn_qkv.weight") != std::string::npos) { @@ -16088,7 +16091,7 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n ++qs.i_ffn_down; } else if (name.find("attn_output.weight") != std::string::npos) { - if (qs.model.hparams.n_expert == 8) { + if (qs.model.hparams.n_expert >= 4) { new_type = GGML_TYPE_Q5_K; } else { if (ftype == LLAMA_FTYPE_MOSTLY_IQ1_S || ftype == LLAMA_FTYPE_MOSTLY_IQ1_M) new_type = GGML_TYPE_IQ2_K; @@ -16188,9 +16191,9 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n else if (new_type == GGML_TYPE_Q5_K) new_type = GGML_TYPE_Q6_K; } ++qs.i_attention_wv; - } else if (name.find("attn_k.weight") != std::string::npos) { + } else if (name.find("attn_k") != std::string::npos) { if (qs.params->attn_k_type < GGML_TYPE_COUNT) new_type = qs.params->attn_k_type; - else if (qs.model.hparams.n_expert == 8) { + else if (qs.model.hparams.n_expert >= 8) { // for the 8-expert model, bumping this to Q8_0 trades just ~128MB // TODO: explore better strategies new_type = GGML_TYPE_Q8_0; @@ -16201,8 +16204,13 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS || ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS_R4) { new_type = GGML_TYPE_IQ2_S; } - } else if (name.find("attn_q.weight") != std::string::npos) { + } else if (name.find("attn_q") != std::string::npos) { if (qs.params->attn_q_type < GGML_TYPE_COUNT) new_type = qs.params->attn_q_type; + else if (qs.model.hparams.n_expert >= 8) { + // for the 8-expert model, bumping this to Q8_0 trades just ~128MB + // TODO: explore better strategies + new_type = GGML_TYPE_Q8_0; + } else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XS) { new_type = GGML_TYPE_IQ3_XXS; } |