author     Kawrakow <iwankawrakow@gmail.com>   2025-01-30 18:36:24 +0200
committer  GitHub <noreply@github.com>         2025-01-30 18:36:24 +0200
commit     ecf111a11ca56ff0731308f94bd6c5e96658b6ef (patch)
tree       f05decc6721785febc889b246955571c32b28b4f
parent     2e6b523853a8659c63283a6deca805051ecd713a (diff)
Deepseek-Lite (#184)
* Quantization mixes tweaks
* Make iq4_nl_r4 work with row sizes that are not a multiple of 128 on Zen4
* Make iq4_nl_r4 work with row sizes that are not a multiple of 128 on AVX2
* Make q6_0_r4 work with row sizes that are not a multiple of 128 on Zen4
* Make q5_0_r4 work with row sizes that are not a multiple of 128 on Zen4 and AVX2
* Make q5_0_r4, q6_0_r4 and iq4_nl_r4 work with row sizes that are not a multiple of 128 also on NEON

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
-rw-r--r--  ggml/src/iqk/iqk_mul_mat.cpp  467
-rw-r--r--  src/llama.cpp                  18
2 files changed, 315 insertions, 170 deletions
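The common thread of the iqk_mul_mat.cpp changes below is that every affected kernel drops its GGML_ASSERT(nb%4 == 0) and gains a tail loop over the blocks left over when the number of 32-element blocks per row is not a multiple of 4, i.e. when the row size is not a multiple of 128. A minimal scalar sketch of that structure, with the block layout flattened and the names simplified rather than the real ggml structs or intrinsics:

    #include <cstdint>

    // Scalar sketch of the main-loop + tail-loop structure the commit introduces
    // (illustrative only; the real kernels work on packed r4 blocks with SIMD).
    static float mul_row_sketch(int nb, const float * x_d, const int8_t * x_q,
                                const float * y_d, const int8_t * y_q) {
        auto dot32 = [](const int8_t * a, const int8_t * b) {
            int s = 0;
            for (int i = 0; i < 32; ++i) s += int(a[i]) * int(b[i]);
            return s;
        };
        float acc = 0.f;
        for (int ib4 = 0; ib4 < nb/4; ++ib4)        // fast path: 4 blocks (128 values) at a time
            for (int k = 0; k < 4; ++k) {
                int ib = 4*ib4 + k;
                acc += x_d[ib] * y_d[ib] * dot32(x_q + 32*ib, y_q + 32*ib);
            }
        for (int ib = 4*(nb/4); ib < nb; ++ib)      // new tail: leftover blocks when nb%4 != 0
            acc += x_d[ib] * y_d[ib] * dot32(x_q + 32*ib, y_q + 32*ib);
        return acc;
    }

In the real kernels the fast path keeps reading the 4-blocks-packed block_q8_1_x4 activations, while the tail indexes them as plain block_q8_1 (block_q8_0 on the NEON path); the unpacking and dot-product code both paths share is what gets factored into the prepare/dot lambdas in the hunks below.
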
diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp
index 7fd56c42..f633229d 100644
--- a/ggml/src/iqk/iqk_mul_mat.cpp
+++ b/ggml/src/iqk/iqk_mul_mat.cpp
@@ -2474,44 +2474,63 @@ static void mul_mat_iq4_nl_r4_q8_1(int n, const void * vx, size_t bx, const Data
auto m4 = _mm512_set1_epi8(0xf);
auto values = load_iq4nl_values_512();
int nb = n / QK4_NL;
- GGML_ASSERT(nb%4 == 0);
__m512 acc[2*nrc_y] = {};
__m512i qx[4];
+ float d8[8*nrc_y];
+ auto prepare = [&qx, &m4, &values] (const block_iq4_nl_r4& iq4l, const block_iq4_nl_r4& iq4h) {
+ auto scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq4l.d));
+ auto scales1 = _mm256_set_m128(scales128, scales128);
+ scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq4h.d));
+ auto scales2 = _mm256_set_m128(scales128, scales128);
+ auto scales = _mm512_insertf32x8(_mm512_castps256_ps512(scales1), scales2, 1);
+ auto bits1 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq4l.qs+0)),
+ _mm256_loadu_si256((const __m256i *)iq4h.qs+0), 1);
+ auto bits2 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq4l.qs+1)),
+ _mm256_loadu_si256((const __m256i *)iq4h.qs+1), 1);
+ qx[0] = _mm512_shuffle_epi8(values, _mm512_and_si512(bits1, m4));
+ qx[1] = _mm512_shuffle_epi8(values, _mm512_and_si512(bits2, m4));
+ qx[2] = _mm512_shuffle_epi8(values, _mm512_and_si512(_mm512_srli_epi16(bits1, 4), m4));
+ qx[3] = _mm512_shuffle_epi8(values, _mm512_and_si512(_mm512_srli_epi16(bits2, 4), m4));
+ return scales;
+ };
+ auto dot = [&qx] (__m256i y8) {
+ auto y = _mm512_inserti32x8(_mm512_castsi256_si512(y8), y8, 1);
+ auto sumi = _mm512_setzero_si512();
+ sumi = _mm512_dpbusd_epi32(sumi, qx[0], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00)));
+ sumi = _mm512_dpbusd_epi32(sumi, qx[1], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55)));
+ sumi = _mm512_dpbusd_epi32(sumi, qx[2], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa)));
+ sumi = _mm512_dpbusd_epi32(sumi, qx[3], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff)));
+ return sumi;
+ };
for (int ix = 0; ix < nrc_x; ix += 8) {
const block_iq4_nl_r4 * iq4l = (const block_iq4_nl_r4 *)((const char *)vx + (ix+0)*bx);
const block_iq4_nl_r4 * iq4h = (const block_iq4_nl_r4 *)((const char *)vx + (ix+4)*bx);
for (int ib4 = 0; ib4 < nb/4; ++ib4) {
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ _mm256_storeu_ps(d8+8*iy, _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)q8.y[iy][ib4].d)));
+ }
for (int k = 0; k < 4; ++k) {
- auto scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq4l[4*ib4+k].d));
- auto scales1 = _mm256_set_m128(scales128, scales128);
- scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq4h[4*ib4+k].d));
- auto scales2 = _mm256_set_m128(scales128, scales128);
- auto scales = _mm512_insertf32x8(_mm512_castps256_ps512(scales1), scales2, 1);
- auto scales_m = _mm512_mul_ps(scales, _mm512_set1_ps(-64.f));
- auto bits1 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq4l[4*ib4+k].qs+0)),
- _mm256_loadu_si256((const __m256i *)iq4h[4*ib4+k].qs+0), 1);
- auto bits2 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq4l[4*ib4+k].qs+1)),
- _mm256_loadu_si256((const __m256i *)iq4h[4*ib4+k].qs+1), 1);
- qx[0] = _mm512_shuffle_epi8(values, _mm512_and_si512(bits1, m4));
- qx[1] = _mm512_shuffle_epi8(values, _mm512_and_si512(bits2, m4));
- qx[2] = _mm512_shuffle_epi8(values, _mm512_and_si512(_mm512_srli_epi16(bits1, 4), m4));
- qx[3] = _mm512_shuffle_epi8(values, _mm512_and_si512(_mm512_srli_epi16(bits2, 4), m4));
+ auto scales = prepare(iq4l[4*ib4+k], iq4h[4*ib4+k]);
for (int iy = 0; iy < nrc_y; ++iy) {
- auto y8 = _mm256_loadu_si256((const __m256i*)q8.y[iy][ib4].qs+k);
- auto y = _mm512_inserti32x8(_mm512_castsi256_si512(y8), y8, 1);
- auto sumi = _mm512_setzero_si512();
- sumi = _mm512_dpbusd_epi32(sumi, qx[0], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00)));
- sumi = _mm512_dpbusd_epi32(sumi, qx[1], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55)));
- sumi = _mm512_dpbusd_epi32(sumi, qx[2], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa)));
- sumi = _mm512_dpbusd_epi32(sumi, qx[3], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff)));
- auto dy = _mm512_set1_ps(GGML_FP16_TO_FP32(q8.y[iy][ib4].d[k]));
+ auto sumi = dot(_mm256_loadu_si256((const __m256i*)q8.y[iy][ib4].qs+k));
+ auto dy = _mm512_set1_ps(d8[8*iy+k]);
acc[2*iy+0] = _mm512_fmadd_ps(_mm512_mul_ps(scales, dy), _mm512_cvtepi32_ps(sumi), acc[2*iy+0]);
- acc[2*iy+1] = _mm512_fmadd_ps(scales_m, _mm512_set1_ps(GGML_FP16_TO_FP32(q8.y[iy][ib4].d[k+4])), acc[2*iy+1]);
+ acc[2*iy+1] = _mm512_fmadd_ps(scales, _mm512_set1_ps(d8[8*iy+k+4]), acc[2*iy+1]);
}
}
}
+ for (int ib = 4*(nb/4); ib < nb; ++ib) {
+ auto scales = prepare(iq4l[ib], iq4h[ib]);
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto qy = (const block_q8_1 *)q8.y[iy];
+ auto sumi = dot(_mm256_loadu_si256((const __m256i*)qy[ib].qs));
+ auto dy = _mm512_set1_ps(GGML_FP16_TO_FP32(qy[ib].d));
+ acc[2*iy+0] = _mm512_fmadd_ps(_mm512_mul_ps(scales, dy), _mm512_cvtepi32_ps(sumi), acc[2*iy+0]);
+ acc[2*iy+1] = _mm512_fmadd_ps(scales, _mm512_set1_ps(GGML_FP16_TO_FP32(qy[ib].s)), acc[2*iy+1]);
+ }
+ }
for (int iy = 0; iy < nrc_y; ++iy) {
- auto sum512 = _mm512_add_ps(acc[2*iy+0], acc[2*iy+1]);
+ auto sum512 = _mm512_fmadd_ps(_mm512_set1_ps(-64.f), acc[2*iy+1], acc[2*iy+0]);
acc[2*iy+0] = acc[2*iy+1] = _mm512_setzero_ps();
auto sum1 = _mm_add_ps(_mm512_extractf32x4_ps(sum512, 0), _mm512_extractf32x4_ps(sum512, 1));
auto sum2 = _mm_add_ps(_mm512_extractf32x4_ps(sum512, 2), _mm512_extractf32x4_ps(sum512, 3));
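Besides the tail loop, this hunk moves the constant bias correction out of the inner loop. The kernel keeps two accumulators per output row: acc[2*iy+0] collects dx*(dy*sum_i u_i*y_i), where u_i are the stored quants in the form _mm512_dpbusd_epi32 expects, and acc[2*iy+1] collects dx*S, where S is the per-block sum term block_q8_1 already carries (the second half of the x4 block's d[] here, .s in the tail). The decomposition the kernel relies on is

    dx*dy*sum_i(q_i*y_i) = dx*(dy*sum_i(u_i*y_i)) + C*(dx*S)

with C a per-type constant, -64.f for this iq4_nl path. The old code multiplied C into the row scales for every block (scales_m); the new code applies it exactly once per row in the final _mm512_fmadd_ps(_mm512_set1_ps(-64.f), acc[2*iy+1], acc[2*iy+0]). The AVX512 q5_0_r4 and q6_0_r4 hunks further down repeat the same refactor with -8.f and -16.f.
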
@@ -2530,37 +2549,57 @@ static void mul_mat_iq4_nl_r4_q8_1(int n, const void * vx, size_t bx, const Data
auto values128 = _mm_loadu_si128((const __m128i *)iq4k_values);
auto values = MM256_SET_M128I(values128, values128);
int nb = n / QK4_NL;
- GGML_ASSERT(nb%4 == 0);
__m256 acc[nrc_y] = {};
- //__m256 acc[2*nrc_y] = {};
+ __m256i qs[4];
+ float d8[4*nrc_y];
+ auto prepare = [&qs, &values, &m4] (const block_iq4_nl_r4& iq4) {
+ auto scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq4.d));
+ auto scales = _mm256_set_m128(scales128, scales128);
+ auto bits1 = _mm256_loadu_si256((const __m256i *)iq4.qs+0);
+ auto bits2 = _mm256_loadu_si256((const __m256i *)iq4.qs+1);
+ qs[0] = _mm256_shuffle_epi8(values, _mm256_and_si256(bits1, m4));
+ qs[1] = _mm256_shuffle_epi8(values, _mm256_and_si256(bits2, m4));
+ qs[2] = _mm256_shuffle_epi8(values, _mm256_and_si256(_mm256_srli_epi16(bits1, 4), m4));
+ qs[3] = _mm256_shuffle_epi8(values, _mm256_and_si256(_mm256_srli_epi16(bits2, 4), m4));
+ return scales;
+ };
+ auto dot = [&qs, &m1] (__m256i y) {
+ auto u1 = _mm256_sign_epi8(qs[0], qs[0]);
+ auto u2 = _mm256_sign_epi8(qs[1], qs[1]);
+ auto sumi1 = _mm256_add_epi32(
+ _mm256_madd_epi16(m1, _mm256_maddubs_epi16(u1, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), qs[0]))),
+ _mm256_madd_epi16(m1, _mm256_maddubs_epi16(u2, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), qs[1]))));
+ u1 = _mm256_sign_epi8(qs[2], qs[2]);
+ u2 = _mm256_sign_epi8(qs[3], qs[3]);
+ auto sumi2 = _mm256_add_epi32(
+ _mm256_madd_epi16(m1, _mm256_maddubs_epi16(u1, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), qs[2]))),
+ _mm256_madd_epi16(m1, _mm256_maddubs_epi16(u2, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), qs[3]))));
+ return _mm256_add_epi32(sumi1, sumi2);
+ };
for (int ix = 0; ix < nrc_x; ix += 4) {
const block_iq4_nl_r4 * iq4 = (const block_iq4_nl_r4 *)((const char *)vx + ix*bx);
for (int ib4 = 0; ib4 < nb/4; ++ib4) {
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ _mm_storeu_ps(d8+4*iy, _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)q8.y[iy][ib4].d)));
+ }
for (int k = 0; k < 4; ++k) {
- auto scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq4[4*ib4+k].d));
- auto scales = _mm256_set_m128(scales128, scales128);
- auto bits1 = _mm256_loadu_si256((const __m256i *)iq4[4*ib4+k].qs+0);
- auto bits2 = _mm256_loadu_si256((const __m256i *)iq4[4*ib4+k].qs+1);
- auto q1 = _mm256_shuffle_epi8(values, _mm256_and_si256(bits1, m4));
- auto q2 = _mm256_shuffle_epi8(values, _mm256_and_si256(bits2, m4));
- auto q3 = _mm256_shuffle_epi8(values, _mm256_and_si256(_mm256_srli_epi16(bits1, 4), m4));
- auto q4 = _mm256_shuffle_epi8(values, _mm256_and_si256(_mm256_srli_epi16(bits2, 4), m4));
- auto s1 = _mm256_sign_epi8(q1, q1);
- auto s2 = _mm256_sign_epi8(q2, q2);
- auto s3 = _mm256_sign_epi8(q3, q3);
- auto s4 = _mm256_sign_epi8(q4, q4);
-
+ auto scales = prepare(iq4[4*ib4+k]);
for (int iy = 0; iy < nrc_y; ++iy) {
- auto y = _mm256_loadu_si256((const __m256i*)q8.y[iy][ib4].qs+k);
- auto sumi1 = _mm256_add_epi32(_mm256_madd_epi16(m1, _mm256_maddubs_epi16(s1, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), q1))),
- _mm256_madd_epi16(m1, _mm256_maddubs_epi16(s2, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), q2))));
- auto sumi2 = _mm256_add_epi32(_mm256_madd_epi16(m1, _mm256_maddubs_epi16(s3, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), q3))),
- _mm256_madd_epi16(m1, _mm256_maddubs_epi16(s4, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), q4))));
- auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(GGML_FP16_TO_FP32(q8.y[iy][ib4].d[k])));
- acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), acc[iy]);
+ auto sumi = dot(_mm256_loadu_si256((const __m256i*)q8.y[iy][ib4].qs+k));
+ auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(d8[4*iy+k]));
+ acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[iy]);
}
}
}
+ for (int ib = 4*(nb/4); ib < nb; ++ib) {
+ auto scales = prepare(iq4[ib]);
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto qy = (const block_q8_1 *)q8.y[iy];
+ auto sumi = dot(_mm256_loadu_si256((const __m256i*)qy[ib].qs));
+ auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(GGML_FP16_TO_FP32(qy[ib].d)));
+ acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[iy]);
+ }
+ }
for (int iy = 0; iy < nrc_y; ++iy) {
auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1));
info.store(ix, iy, sum);
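The AVX2 build of the same kernel cannot rely on dpbusd, so its dot lambda uses the usual sign trick for _mm256_maddubs_epi16, which multiplies an unsigned byte by a signed byte: feed |q| as the unsigned operand and copy q's sign onto the q8 value, leaving the per-byte product unchanged. A scalar model of what each byte lane computes (not the actual intrinsics):

    #include <cstdint>
    #include <cstdlib>

    // |q| * sign(q)*y == q*y, which is what the pair
    // _mm256_sign_epi8(q, q) / _mm256_sign_epi8(y, q) feeds into maddubs above.
    static int signed_dot_model(const int8_t * q, const int8_t * y, int n) {
        int acc = 0;
        for (int i = 0; i < n; ++i) {
            int u = std::abs(int(q[i]));                                  // _mm256_sign_epi8(q, q)
            int s = q[i] < 0 ? -int(y[i]) : (q[i] == 0 ? 0 : int(y[i]));  // _mm256_sign_epi8(y, q)
            acc += u * s;                                                 // == q[i] * y[i]
        }
        return acc;
    }

The _mm256_madd_epi16(m1, ...) step that follows widens the 16-bit pair sums produced by maddubs into 32-bit lanes so they can be accumulated safely.
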
@@ -2797,43 +2836,73 @@ static void mul_mat_q5_0_r4_q8_1_avx2(int n, const void * vx, size_t bx, const D
Q8<nrc_y, block_q8_1_x4> q8(info);
auto m4 = _mm256_set1_epi8(0xf);
auto m5 = _mm256_set1_epi8(0x10);
+#ifndef HAVE_FANCY_SIMD
auto m1 = _mm256_set1_epi16(1);
+#endif
+ auto mscale = _mm256_set_m128(_mm_set1_ps(-8.f), _mm_set1_ps(1.f));
int nb = n / QK5_0;
- GGML_ASSERT(nb%4 == 0);
__m256 acc[nrc_y] = {};
+ __m256i qx[4];
float d8[8*nrc_y];
+ auto prepare = [&qx, &m4, &m5] (const block_q5_0_r4& iq5) {
+ auto scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq5.d));
+ auto scales = _mm256_set_m128(scales128, scales128);
+ auto bits1 = _mm256_loadu_si256((const __m256i *)iq5.qs+0);
+ auto bits2 = _mm256_loadu_si256((const __m256i *)iq5.qs+1);
+ auto hbits = _mm_loadu_si128((const __m128i *)iq5.qh);
+ auto hb = MM256_SET_M128I(_mm_srli_epi16(hbits, 1), hbits);
+ qx[0] = _mm256_or_si256(_mm256_and_si256(bits1, m4), _mm256_and_si256(_mm256_slli_epi16(hb, 4), m5));
+ qx[1] = _mm256_or_si256(_mm256_and_si256(bits2, m4), _mm256_and_si256(_mm256_slli_epi16(hb, 2), m5));
+ qx[2] = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(bits1, 4), m4), _mm256_and_si256(hb, m5));
+ qx[3] = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(bits2, 4), m4), _mm256_and_si256(_mm256_srli_epi16(hb, 2), m5));;
+ return scales;
+ };
+#ifdef HAVE_FANCY_SIMD
+ auto dot = [&qx] (__m256i y) {
+ auto sumi = _mm256_setzero_si256();
+ sumi = _mm256_dpbusd_epi32(sumi, qx[0], _mm256_shuffle_epi32(y, 0x00));
+ sumi = _mm256_dpbusd_epi32(sumi, qx[1], _mm256_shuffle_epi32(y, 0x55));
+ sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(y, 0xaa));
+ sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(y, 0xff));
+ return sumi;
+ };
+#else
+ auto dot = [&qx, &m1] (__m256i y) {
+ auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[0], _mm256_shuffle_epi32(y, 0x00)),
+ _mm256_maddubs_epi16(qx[1], _mm256_shuffle_epi32(y, 0x55)));
+ auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[2], _mm256_shuffle_epi32(y, 0xaa)),
+ _mm256_maddubs_epi16(qx[3], _mm256_shuffle_epi32(y, 0xff)));
+ auto sumi = _mm256_madd_epi16(m1, _mm256_add_epi16(sumi1, sumi2));
+ return sumi;
+ };
+#endif
for (int ix = 0; ix < nrc_x; ix += 4) {
const block_q5_0_r4 * iq5 = (const block_q5_0_r4 *)((const char *)vx + ix*bx);
for (int ib4 = 0; ib4 < nb/4; ++ib4) {
for (int iy = 0; iy < nrc_y; ++iy) {
auto scales = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)q8.y[iy][ib4].d));
- _mm256_storeu_ps(d8 + 8*iy, scales);
+ _mm256_storeu_ps(d8 + 8*iy, _mm256_mul_ps(mscale, scales));
}
for (int k = 0; k < 4; ++k) {
- auto scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq5[4*ib4+k].d));
- auto scales = _mm256_set_m128(scales128, scales128);
- auto scales_m = _mm256_mul_ps(scales, _mm256_set1_ps(-8.f));
- auto bits1 = _mm256_loadu_si256((const __m256i *)iq5[4*ib4+k].qs+0);
- auto bits2 = _mm256_loadu_si256((const __m256i *)iq5[4*ib4+k].qs+1);
- auto hbits = _mm_loadu_si128((const __m128i *)iq5[4*ib4+k].qh);
- auto hb = MM256_SET_M128I(_mm_srli_epi16(hbits, 1), hbits);
- auto q1 = _mm256_or_si256(_mm256_and_si256(bits1, m4), _mm256_and_si256(_mm256_slli_epi16(hb, 4), m5));
- auto q2 = _mm256_or_si256(_mm256_and_si256(bits2, m4), _mm256_and_si256(_mm256_slli_epi16(hb, 2), m5));
- auto q3 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(bits1, 4), m4), _mm256_and_si256(hb, m5));
- auto q4 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(bits2, 4), m4), _mm256_and_si256(_mm256_srli_epi16(hb, 2), m5));;
+ auto scales = prepare(iq5[4*ib4+k]);
for (int iy = 0; iy < nrc_y; ++iy) {
- auto y = _mm256_loadu_si256((const __m256i*)q8.y[iy][ib4].qs+k);
- auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(q1, _mm256_shuffle_epi32(y, 0x00)),
- _mm256_maddubs_epi16(q2, _mm256_shuffle_epi32(y, 0x55)));
- auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(q3, _mm256_shuffle_epi32(y, 0xaa)),
- _mm256_maddubs_epi16(q4, _mm256_shuffle_epi32(y, 0xff)));
- auto sumi = _mm256_madd_epi16(m1, _mm256_add_epi16(sumi1, sumi2));
+ auto sumi = dot(_mm256_loadu_si256((const __m256i*)q8.y[iy][ib4].qs+k));
auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(d8[8*iy+k]));
acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[iy]);
- acc[iy] = _mm256_fmadd_ps(scales_m, _mm256_set1_ps(d8[8*iy+k+4]), acc[iy]);
+ acc[iy] = _mm256_fmadd_ps(scales, _mm256_set1_ps(d8[8*iy+k+4]), acc[iy]);
}
}
}
+ for (int ib = 4*(nb/4); ib < nb; ++ib) {
+ auto scales = prepare(iq5[ib]);
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto qy = (const block_q8_1 *)q8.y[iy];
+ auto sumi = dot(_mm256_loadu_si256((const __m256i*)qy[ib].qs));
+ auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(GGML_FP16_TO_FP32(qy[ib].d)));
+ acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[iy]);
+ acc[iy] = _mm256_fmadd_ps(scales, _mm256_set1_ps(-8.f*GGML_FP16_TO_FP32(qy[ib].s)), acc[iy]);
+ }
+ }
for (int iy = 0; iy < nrc_y; ++iy) {
auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1));
info.store(ix, iy, sum);
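The AVX2 q5_0 kernel handles the same correction term a different way: instead of deferring the constant to the final reduction, it folds it into the pre-loaded q8 scales. mscale is {1, 1, 1, 1, -8, -8, -8, -8}, so after the _mm256_mul_ps in the preload d8[8*iy+k] still holds the plain dy values while d8[8*iy+k+4] already carries -8 times the block-sum term; the inner loop then needs only two plain fmadds and no per-block scales_m multiply, and the tail loop applies the -8 inline to qy[ib].s. A scalar model of the refactor, with shift standing in for the -8.f constant:

    // dx: row scale, dy: q8 block scale, s: q8 block-sum term, sumi: integer dot product.
    static float block_term_before(float dx, float dy, float s, int sumi, float shift) {
        float scales_m = dx * shift;               // recomputed for every block in the old code
        return dx * dy * float(sumi) + scales_m * s;
    }
    static float block_term_after(float dx, float dy, float s_pre, int sumi) {
        // s_pre = shift * s is produced once per 4-block group when d8[] is filled
        // (the mscale multiply), so only two fmadd-shaped terms remain per block.
        return dx * dy * float(sumi) + dx * s_pre;
    }

The AVX2 q6_0 kernel below uses the same trick with -16.f.
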
@@ -2853,50 +2922,68 @@ static void mul_mat_q5_0_r4_q8_1(int n, const void * vx, size_t bx, const DataIn
auto m4 = _mm512_set1_epi8(0xf);
auto m5 = _mm512_set1_epi8(0x10);
int nb = n / QK5_0;
- GGML_ASSERT(nb%4 == 0);
__m512 acc[2*nrc_y] = {};
__m512i qx[4];
+ float d8[8*nrc_y];
+ auto prepare = [&qx, &m4, &m5] (const block_q5_0_r4& iq5l, const block_q5_0_r4& iq5h) {
+ auto scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq5l.d));
+ auto scales1 = _mm256_set_m128(scales128, scales128);
+ scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq5h.d));
+ auto scales2 = _mm256_set_m128(scales128, scales128);
+ auto scales = _mm512_insertf32x8(_mm512_castps256_ps512(scales1), scales2, 1);
+ auto bits1 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq5l.qs+0)),
+ _mm256_loadu_si256((const __m256i *)iq5h.qs+0), 1);
+ auto bits2 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq5l.qs+1)),
+ _mm256_loadu_si256((const __m256i *)iq5h.qs+1), 1);
+ auto hbits1 = _mm_loadu_si128((const __m128i *)iq5l.qh);
+ auto hbits2 = _mm_loadu_si128((const __m128i *)iq5h.qh);
+ auto hb1 = MM256_SET_M128I(_mm_srli_epi16(hbits1, 1), hbits1);
+ auto hb2 = MM256_SET_M128I(_mm_srli_epi16(hbits2, 1), hbits2);
+ auto hb = _mm512_inserti32x8(_mm512_castsi256_si512(hb1), hb2, 1);
+ qx[0] = _mm512_or_si512(_mm512_and_si512(bits1, m4), _mm512_and_si512(_mm512_slli_epi16(hb, 4), m5));
+ qx[1] = _mm512_or_si512(_mm512_and_si512(bits2, m4), _mm512_and_si512(_mm512_slli_epi16(hb, 2), m5));
+ qx[2] = _mm512_or_si512(_mm512_and_si512(_mm512_srli_epi16(bits1, 4), m4), _mm512_and_si512(hb, m5));
+ qx[3] = _mm512_or_si512(_mm512_and_si512(_mm512_srli_epi16(bits2, 4), m4), _mm512_and_si512(_mm512_srli_epi16(hb, 2), m5));
+ return scales;
+ };
+ auto dot = [&qx] (__m256i y8) {
+ auto y = _mm512_inserti32x8(_mm512_castsi256_si512(y8), y8, 1);
+ auto sumi = _mm512_setzero_si512();
+ sumi = _mm512_dpbusd_epi32(sumi, qx[0], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00)));
+ sumi = _mm512_dpbusd_epi32(sumi, qx[1], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55)));
+ sumi = _mm512_dpbusd_epi32(sumi, qx[2], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa)));
+ sumi = _mm512_dpbusd_epi32(sumi, qx[3], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff)));
+ return sumi;
+ };
for (int ix = 0; ix < nrc_x; ix += 8) {
const block_q5_0_r4 * iq5l = (const block_q5_0_r4 *)((const char *)vx + (ix+0)*bx);
const block_q5_0_r4 * iq5h = (const block_q5_0_r4 *)((const char *)vx + (ix+4)*bx);
for (int ib4 = 0; ib4 < nb/4; ++ib4) {
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ _mm256_storeu_ps(d8+8*iy, _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)q8.y[iy][ib4].d)));
+ }
for (int k = 0; k < 4; ++k) {
- auto scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq5l[4*ib4+k].d));
- auto scales1 = _mm256_set_m128(scales128, scales128);
- scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq5h[4*ib4+k].d));
- auto scales2 = _mm256_set_m128(scales128, scales128);
- auto scales = _mm512_insertf32x8(_mm512_castps256_ps512(scales1), scales2, 1);
- auto scales_m = _mm512_mul_ps(scales, _mm512_set1_ps(-8.f));
- auto bits1 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq5l[4*ib4+k].qs+0)),
- _mm256_loadu_si256((const __m256i *)iq5h[4*ib4+k].qs+0), 1);
- auto bits2 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq5l[4*ib4+k].qs+1)),
- _mm256_loadu_si256((const __m256i *)iq5h[4*ib4+k].qs+1), 1);
- auto hbits1 = _mm_loadu_si128((const __m128i *)iq5l[4*ib4+k].qh);
- auto hbits2 = _mm_loadu_si128((const __m128i *)iq5h[4*ib4+k].qh);
- auto hb1 = MM256_SET_M128I(_mm_srli_epi16(hbits1, 1), hbits1);
- auto hb2 = MM256_SET_M128I(_mm_srli_epi16(hbits2, 1), hbits2);
- auto hb = _mm512_inserti32x8(_mm512_castsi256_si512(hb1), hb2, 1);
- qx[0] = _mm512_or_si512(_mm512_and_si512(bits1, m4), _mm512_and_si512(_mm512_slli_epi16(hb, 4), m5));
- qx[1] = _mm512_or_si512(_mm512_and_si512(bits2, m4), _mm512_and_si512(_mm512_slli_epi16(hb, 2), m5));
- //qx[2] = _mm512_and_si512(_mm512_srli_epi16(bits1, 4), m4) | _mm512_and_si512(_mm512_slli_epi16(hb, 2), m5);
- qx[2] = _mm512_or_si512(_mm512_and_si512(_mm512_srli_epi16(bits1, 4), m4), _mm512_and_si512(hb, m5));
- qx[3] = _mm512_or_si512(_mm512_and_si512(_mm512_srli_epi16(bits2, 4), m4), _mm512_and_si512(_mm512_srli_epi16(hb, 2), m5));
+ auto scales = prepare(iq5l[4*ib4+k], iq5h[4*ib4+k]);
for (int iy = 0; iy < nrc_y; ++iy) {
- auto y8 = _mm256_loadu_si256((const __m256i*)q8.y[iy][ib4].qs+k);
- auto y = _mm512_inserti32x8(_mm512_castsi256_si512(y8), y8, 1);
- auto sumi = _mm512_setzero_si512();
- sumi = _mm512_dpbusd_epi32(sumi, qx[0], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00)));
- sumi = _mm512_dpbusd_epi32(sumi, qx[1], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55)));
- sumi = _mm512_dpbusd_epi32(sumi, qx[2], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa)));
- sumi = _mm512_dpbusd_epi32(sumi, qx[3], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff)));
- auto dy = _mm512_set1_ps(GGML_FP16_TO_FP32(q8.y[iy][ib4].d[k]));
+ auto sumi = dot(_mm256_loadu_si256((const __m256i*)q8.y[iy][ib4].qs+k));
+ auto dy = _mm512_set1_ps(d8[8*iy+k]);
acc[2*iy+0] = _mm512_fmadd_ps(_mm512_mul_ps(scales, dy), _mm512_cvtepi32_ps(sumi), acc[2*iy+0]);
- acc[2*iy+1] = _mm512_fmadd_ps(scales_m, _mm512_set1_ps(GGML_FP16_TO_FP32(q8.y[iy][ib4].d[k+4])), acc[2*iy+1]);
+ acc[2*iy+1] = _mm512_fmadd_ps(scales, _mm512_set1_ps(d8[8*iy+k+4]), acc[2*iy+1]);
}
}
}
+ for (int ib = 4*(nb/4); ib < nb; ++ib) {
+ auto scales = prepare(iq5l[ib], iq5h[ib]);
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto qy = (const block_q8_1 *)q8.y[iy];
+ auto sumi = dot(_mm256_loadu_si256((const __m256i*)qy[ib].qs));
+ auto dy = _mm512_set1_ps(GGML_FP16_TO_FP32(qy[ib].d));
+ acc[2*iy+0] = _mm512_fmadd_ps(_mm512_mul_ps(scales, dy), _mm512_cvtepi32_ps(sumi), acc[2*iy+0]);
+ acc[2*iy+1] = _mm512_fmadd_ps(scales, _mm512_set1_ps(GGML_FP16_TO_FP32(qy[ib].s)), acc[2*iy+1]);
+ }
+ }
for (int iy = 0; iy < nrc_y; ++iy) {
- auto sum512 = _mm512_add_ps(acc[2*iy+0], acc[2*iy+1]);
+ auto sum512 = _mm512_fmadd_ps(_mm512_set1_ps(-8.f), acc[2*iy+1], acc[2*iy+0]);
acc[2*iy+0] = acc[2*iy+1] = _mm512_setzero_ps();
auto sum1 = _mm_add_ps(_mm512_extractf32x4_ps(sum512, 0), _mm512_extractf32x4_ps(sum512, 1));
auto sum2 = _mm_add_ps(_mm512_extractf32x4_ps(sum512, 2), _mm512_extractf32x4_ps(sum512, 3));
@@ -2919,51 +3006,72 @@ static void mul_mat_q6_0_r4_q8_1_avx2(int n, const void * vx, size_t bx, const D
Q8<nrc_y, block_q8_1_x4> q8(info);
auto m4 = _mm256_set1_epi8(0xf);
auto m6 = _mm256_set1_epi8(0x30);
+ auto mscale = _mm256_set_m128(_mm_set1_ps(-16.f), _mm_set1_ps(1.f));
#ifndef HAVE_FANCY_SIMD
auto m1 = _mm256_set1_epi16(1);
#endif
int nb = n / QK6_0;
- GGML_ASSERT(nb%4 == 0);
__m256 acc[nrc_y] = {};
float d8[8*nrc_y];
+ __m256i qx[4];
+ auto prepare = [&qx, &m4, &m6] (const block_q6_0_r4& iq6) {
+ auto scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq6.d));
+ auto scales = _mm256_set_m128(scales128, scales128);
+ auto bits1 = _mm256_loadu_si256((const __m256i *)iq6.qs+0);
+ auto bits2 = _mm256_loadu_si256((const __m256i *)iq6.qs+1);
+ auto hbits = _mm256_loadu_si256((const __m256i *)iq6.qh);
+ qx[0] = _mm256_or_si256(_mm256_and_si256(bits1, m4), _mm256_and_si256(_mm256_slli_epi16(hbits, 4), m6));
+ qx[1] = _mm256_or_si256(_mm256_and_si256(bits2, m4), _mm256_and_si256(_mm256_slli_epi16(hbits, 2), m6));
+ qx[2] = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(bits1, 4), m4), _mm256_and_si256(hbits, m6));
+ qx[3] = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(bits2, 4), m4), _mm256_and_si256(_mm256_srli_epi16(hbits, 2), m6));
+ return scales;
+ };
+#ifdef HAVE_FANCY_SIMD
+ auto dot = [&qx] (__m256i y) {
+ auto sumi = _mm256_dpbusd_epi32(_mm256_setzero_si256(), qx[0], _mm256_shuffle_epi32(y, 0x00));
+ sumi = _mm256_dpbusd_epi32(sumi, qx[1], _mm256_shuffle_epi32(y, 0x55));
+ sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(y, 0xaa));
+ sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(y, 0xff));
+ return sumi;
+ };
+#else
+ auto dot = [&qx, &m1] (__m256i y) {
+ auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[0], _mm256_shuffle_epi32(y, 0x00)),
+ _mm256_maddubs_epi16(qx[1], _mm256_shuffle_epi32(y, 0x55)));
+ auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[2], _mm256_shuffle_epi32(y, 0xaa)),
+ _mm256_maddubs_epi16(qx[3], _mm256_shuffle_epi32(y, 0xff)));
+ auto sumi = _mm256_add_epi32(_mm256_madd_epi16(m1, sumi1), _mm256_madd_epi16(m1, sumi2));
+ return sumi;
+ };
+#endif
for (int ix = 0; ix < nrc_x; ix += 4) {
const block_q6_0_r4 * iq6 = (const block_q6_0_r4 *)((const char *)vx + ix*bx);
for (int ib4 = 0; ib4 < nb/4; ++ib4) {
for (int iy = 0; iy < nrc_y; ++iy) {
auto scales = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)q8.y[iy][ib4].d));
- _mm256_storeu_ps(d8 + 8*iy, scales);
+ _mm256_storeu_ps(d8 + 8*iy, _mm256_mul_ps(scales, mscale));
}
for (int k = 0; k < 4; ++k) {
- auto scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq6[4*ib4+k].d));
- auto scales = _mm256_set_m128(scales128, scales128);
- auto scales_m = _mm256_mul_ps(scales, _mm256_set1_ps(-16.f));
- auto bits1 = _mm256_loadu_si256((const __m256i *)iq6[4*ib4+k].qs+0);
- auto bits2 = _mm256_loadu_si256((const __m256i *)iq6[4*ib4+k].qs+1);
- auto hbits = _mm256_loadu_si256((const __m256i *)iq6[4*ib4+k].qh);
- auto q1 = _mm256_or_si256(_mm256_and_si256(bits1, m4), _mm256_and_si256(_mm256_slli_epi16(hbits, 4), m6));
- auto q2 = _mm256_or_si256(_mm256_and_si256(bits2, m4), _mm256_and_si256(_mm256_slli_epi16(hbits, 2), m6));
- auto q3 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(bits1, 4), m4), _mm256_and_si256(hbits, m6));
- auto q4 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(bits2, 4), m4), _mm256_and_si256(_mm256_srli_epi16(hbits, 2), m6));
+ auto scales = prepare(iq6[4*ib4+k]);
for (int iy = 0; iy < nrc_y; ++iy) {
- auto y = _mm256_loadu_si256((const __m256i*)q8.y[iy][ib4].qs+k);
-#ifdef HAVE_FANCY_SIMD
- auto sumi = _mm256_dpbusd_epi32(_mm256_setzero_si256(), q1, _mm256_shuffle_epi32(y, 0x00));
- sumi = _mm256_dpbusd_epi32(sumi, q2, _mm256_shuffle_epi32(y, 0x55));
- sumi = _mm256_dpbusd_epi32(sumi, q3, _mm256_shuffle_epi32(y, 0xaa));
- sumi = _mm256_dpbusd_epi32(sumi, q4, _mm256_shuffle_epi32(y, 0xff));
-#else
- auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(q1, _mm256_shuffle_epi32(y, 0x00)),
- _mm256_maddubs_epi16(q2, _mm256_shuffle_epi32(y, 0x55)));
- auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(q3, _mm256_shuffle_epi32(y, 0xaa)),
- _mm256_maddubs_epi16(q4, _mm256_shuffle_epi32(y, 0xff)));
- auto sumi = _mm256_add_epi32(_mm256_madd_epi16(m1, sumi1), _mm256_madd_epi16(m1, sumi2));
-#endif
+ auto sumi = dot(_mm256_loadu_si256((const __m256i*)q8.y[iy][ib4].qs+k));
auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(d8[8*iy+k]));
acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[iy]);
- acc[iy] = _mm256_fmadd_ps(scales_m, _mm256_set1_ps(d8[8*iy+k+4]), acc[iy]);
+ acc[iy] = _mm256_fmadd_ps(scales, _mm256_set1_ps(d8[8*iy+k+4]), acc[iy]);
}
}
}
+ for (int ib = 4*(nb/4); ib < nb; ++ib) {
+ auto scales = prepare(iq6[ib]);
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto qy = (const block_q8_1 *)q8.y[iy];
+ auto sumi = dot(_mm256_loadu_si256((const __m256i*)qy[ib].qs));
+ auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(GGML_FP16_TO_FP32(qy[ib].d)));
+ acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[iy]);
+ acc[iy] = _mm256_fmadd_ps(scales, _mm256_set1_ps(-16.f*GGML_FP16_TO_FP32(qy[ib].s)), acc[iy]);
+ }
+ }
+
for (int iy = 0; iy < nrc_y; ++iy) {
auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1));
info.store(ix, iy, sum);
@@ -2983,47 +3091,67 @@ static void mul_mat_q6_0_r4_q8_1(int n, const void * vx, size_t bx, const DataIn
auto m4 = _mm512_set1_epi8(0xf);
auto m6 = _mm512_set1_epi8(0x30);
int nb = n / QK6_0;
- GGML_ASSERT(nb%4 == 0);
__m512 acc[2*nrc_y] = {};
__m512i qx[4];
+ float d8[8*nrc_y];
+ auto prepare = [&qx, &m4, &m6] (const block_q6_0_r4& iq6l, const block_q6_0_r4& iq6h) {
+ auto scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq6l.d));
+ auto scales1 = _mm256_set_m128(scales128, scales128);
+ scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq6h.d));
+ auto scales2 = _mm256_set_m128(scales128, scales128);
+ auto scales = _mm512_insertf32x8(_mm512_castps256_ps512(scales1), scales2, 1);
+ auto bits1 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq6l.qs+0)),
+ _mm256_loadu_si256((const __m256i *)iq6h.qs+0), 1);
+ auto bits2 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq6l.qs+1)),
+ _mm256_loadu_si256((const __m256i *)iq6h.qs+1), 1);
+ auto hbits1 = _mm256_loadu_si256((const __m256i *)iq6l.qh);
+ auto hbits2 = _mm256_loadu_si256((const __m256i *)iq6h.qh);
+ auto hb = _mm512_inserti32x8(_mm512_castsi256_si512(hbits1), hbits2, 1);
+ qx[0] = _mm512_and_si512(bits1, m4) | _mm512_and_si512(_mm512_slli_epi16(hb, 4), m6);
+ qx[1] = _mm512_and_si512(bits2, m4) | _mm512_and_si512(_mm512_slli_epi16(hb, 2), m6);;
+ qx[2] = _mm512_and_si512(_mm512_srli_epi16(bits1, 4), m4) | _mm512_and_si512(hb, m6);
+ qx[3] = _mm512_and_si512(_mm512_srli_epi16(bits2, 4), m4) | _mm512_and_si512(_mm512_srli_epi16(hb, 2), m6);
+ return scales;
+ };
+ auto dot = [&qx] (__m256i y8) {
+ auto y = _mm512_inserti32x8(_mm512_castsi256_si512(y8), y8, 1);
+ auto sumi = _mm512_setzero_si512();
+ sumi = _mm512_dpbusd_epi32(sumi, qx[0], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00)));
+ sumi = _mm512_dpbusd_epi32(sumi, qx[1], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55)));
+ sumi = _mm512_dpbusd_epi32(sumi, qx[2], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa)));
+ sumi = _mm512_dpbusd_epi32(sumi, qx[3], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff)));
+ return sumi;
+ };
for (int ix = 0; ix < nrc_x; ix += 8) {
const block_q6_0_r4 * iq6l = (const block_q6_0_r4 *)((const char *)vx + (ix+0)*bx);
const block_q6_0_r4 * iq6h = (const block_q6_0_r4 *)((const char *)vx + (ix+4)*bx);
for (int ib4 = 0; ib4 < nb/4; ++ib4) {
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto scales = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)q8.y[iy][ib4].d));
+ _mm256_storeu_ps(d8 + 8*iy, scales);
+ }
for (int k = 0; k < 4; ++k) {
- auto scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq6l[4*ib4+k].d));
- auto scales1 = _mm256_set_m128(scales128, scales128);
- scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq6h[4*ib4+k].d));
- auto scales2 = _mm256_set_m128(scales128, scales128);
- auto scales = _mm512_insertf32x8(_mm512_castps256_ps512(scales1), scales2, 1);
- auto scales_m = _mm512_mul_ps(scales, _mm512_set1_ps(-16.f));
- auto bits1 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq6l[4*ib4+k].qs+0)),
- _mm256_loadu_si256((const __m256i *)iq6h[4*ib4+k].qs+0), 1);
- auto bits2 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq6l[4*ib4+k].qs+1)),
- _mm256_loadu_si256((const __m256i *)iq6h[4*ib4+k].qs+1), 1);
- auto hbits1 = _mm256_loadu_si256((const __m256i *)iq6l[4*ib4+k].qh);
- auto hbits2 = _mm256_loadu_si256((const __m256i *)iq6h[4*ib4+k].qh);
- auto hb = _mm512_inserti32x8(_mm512_castsi256_si512(hbits1), hbits2, 1);
- qx[0] = _mm512_and_si512(bits1, m4) | _mm512_and_si512(_mm512_slli_epi16(hb, 4), m6);
- qx[1] = _mm512_and_si512(bits2, m4) | _mm512_and_si512(_mm512_slli_epi16(hb, 2), m6);;
- qx[2] = _mm512_and_si512(_mm512_srli_epi16(bits1, 4), m4) | _mm512_and_si512(hb, m6);
- qx[3] = _mm512_and_si512(_mm512_srli_epi16(bits2, 4), m4) | _mm512_and_si512(_mm512_srli_epi16(hb, 2), m6);
+ auto scales = prepare(iq6l[4*ib4+k], iq6h[4*ib4+k]);
for (int iy = 0; iy < nrc_y; ++iy) {
- auto y8 = _mm256_loadu_si256((const __m256i*)q8.y[iy][ib4].qs+k);
- auto y = _mm512_inserti32x8(_mm512_castsi256_si512(y8), y8, 1);
- auto sumi = _mm512_setzero_si512();
- sumi = _mm512_dpbusd_epi32(sumi, qx[0], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00)));
- sumi = _mm512_dpbusd_epi32(sumi, qx[1], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55)));
- sumi = _mm512_dpbusd_epi32(sumi, qx[2], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa)));
- sumi = _mm512_dpbusd_epi32(sumi, qx[3], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff)));
- auto dy = _mm512_set1_ps(GGML_FP16_TO_FP32(q8.y[iy][ib4].d[k]));
+ auto sumi = dot(_mm256_loadu_si256((const __m256i*)q8.y[iy][ib4].qs+k));
+ auto dy = _mm512_set1_ps(d8[8*iy+k]);
acc[2*iy+0] = _mm512_fmadd_ps(_mm512_mul_ps(scales, dy), _mm512_cvtepi32_ps(sumi), acc[2*iy+0]);
- acc[2*iy+1] = _mm512_fmadd_ps(scales_m, _mm512_set1_ps(GGML_FP16_TO_FP32(q8.y[iy][ib4].d[k+4])), acc[2*iy+1]);
+ acc[2*iy+1] = _mm512_fmadd_ps(scales, _mm512_set1_ps(d8[8*iy+k+4]), acc[2*iy+1]);
}
}
}
+ for (int ib = 4*(nb/4); ib < nb; ++ib) {
+ auto scales = prepare(iq6l[ib], iq6h[ib]);
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto qy = (const block_q8_1 *)q8.y[iy];
+ auto sumi = dot(_mm256_loadu_si256((const __m256i*)qy[ib].qs));
+ auto dy = _mm512_set1_ps(GGML_FP16_TO_FP32(qy[ib].d));
+ acc[2*iy+0] = _mm512_fmadd_ps(_mm512_mul_ps(scales, dy), _mm512_cvtepi32_ps(sumi), acc[2*iy+0]);
+ acc[2*iy+1] = _mm512_fmadd_ps(scales, _mm512_set1_ps(GGML_FP16_TO_FP32(qy[ib].s)), acc[2*iy+1]);
+ }
+ }
for (int iy = 0; iy < nrc_y; ++iy) {
- auto sum512 = _mm512_add_ps(acc[2*iy+0], acc[2*iy+1]);
+ auto sum512 = _mm512_fmadd_ps(_mm512_set1_ps(-16.f), acc[2*iy+1], acc[2*iy+0]);
acc[2*iy+0] = acc[2*iy+1] = _mm512_setzero_ps();
auto sum1 = _mm_add_ps(_mm512_extractf32x4_ps(sum512, 0), _mm512_extractf32x4_ps(sum512, 1));
auto sum2 = _mm_add_ps(_mm512_extractf32x4_ps(sum512, 2), _mm512_extractf32x4_ps(sum512, 3));
@@ -12087,7 +12215,6 @@ void mul_mat_qx_r4_q8_0(int n, const void * vx, size_t bx, const DataInfo& info,
Q8<nrc_y, block_q8_0_x4> q8(info);
Dequantizer deq(vx, bx);
int nb = n / QK4_NL;
- GGML_ASSERT(nb%4 == 0);
int8x16_t qx[8];
float d8[4*nrc_y];
float32x4_t acc[nrc_y] = {};
@@ -12098,7 +12225,7 @@ void mul_mat_qx_r4_q8_0(int n, const void * vx, size_t bx, const DataInfo& info,
vst1q_f32(d8+4*iy, vcvt_f32_f16(vld1_f16((const float16_t *)q8.y[iy][ib4].d)));
}
for (int k = 0; k < 4; ++k) {
- auto scales = deq.prepare(ib4, k, qx);
+ auto scales = deq.prepare(4*ib4+k, qx);
for (int iy = 0; iy < nrc_y; ++iy) {
auto y = vld1q_s8_x2(q8.y[iy][ib4].qs+32*k);
auto sumi = interleaved_dotq(qx, y);
@@ -12107,6 +12234,16 @@ void mul_mat_qx_r4_q8_0(int n, const void * vx, size_t bx, const DataInfo& info,
}
}
}
+ for (int ib = 4*(nb/4); ib < nb; ++ib) {
+ auto scales = deq.prepare(ib, qx);
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto qy = (const block_q8_0 *)q8.y[iy];
+ auto y = vld1q_s8_x2(qy[ib].qs);
+ auto sumi = interleaved_dotq(qx, y);
+ auto d4d8 = vmulq_f32(scales, vdupq_n_f32(GGML_FP16_TO_FP32(qy[ib].d)));
+ acc[iy] = vfmaq_f32(acc[iy], d4d8, vcvtq_f32_s32(sumi));
+ }
+ }
for (int iy = 0; iy < nrc_y; ++iy) {
info.store(ix, iy, deq.result(acc[iy]));
acc[iy] = vdupq_n_f32(0.f);
@@ -12164,9 +12301,9 @@ void mul_mat_qx_r8_q8_0(int n, const void * vx, size_t bx, const DataInfo& info,
struct IQ4_NL_R4_Dequantizer {
IQ4_NL_R4_Dequantizer(const void * vx, size_t bx) : cx((const char *)vx), bx(bx), values(vld1q_s8(iq4k_values)) {}
inline void new_row(int ix) { iq4 = (const block_iq4_nl_r4 *)(cx + ix*bx); }
- inline float32x4_t prepare(int ib4, int k, int8x16_t * qx) const {
- auto scales = vcvt_f32_f16(vld1_f16((const float16_t *)iq4[4*ib4+k].d));
- auto bits = vld1q_u8_x4(iq4[4*ib4+k].qs);
+ inline float32x4_t prepare(int ib, int8x16_t * qx) const {
+ auto scales = vcvt_f32_f16(vld1_f16((const float16_t *)iq4[ib].d));
+ auto bits = vld1q_u8_x4(iq4[ib].qs);
prepare_iq4_nl_quants(values, m4, bits, qx);
return scales;
}
@@ -12242,10 +12379,10 @@ struct Q4_0_R8_Dequantizer {
struct Q5_0_R4_Dequantizer {
Q5_0_R4_Dequantizer(const void * vx, size_t bx) : cx((const char *)vx), bx(bx) {}
inline void new_row(int ix) { iq5 = (const block_q5_0_r4 *)(cx + ix*bx); }
- inline float32x4_t prepare(int ib4, int k, int8x16_t * qx) const {
- auto scales = vcvt_f32_f16(vld1_f16((const float16_t *)iq5[4*ib4+k].d));
- auto lbits = vld1q_u8_x4(iq5[4*ib4+k].qs);
- auto hbits = vld1q_u8(iq5[4*ib4+k].qh);
+ inline float32x4_t prepare(int ib, int8x16_t * qx) const {
+ auto scales = vcvt_f32_f16(vld1_f16((const float16_t *)iq5[ib].d));
+ auto lbits = vld1q_u8_x4(iq5[ib].qs);
+ auto hbits = vld1q_u8(iq5[ib].qh);
qx[0] = vaddq_s8(vandq_u8(lbits.val[0], m4) | vandq_u8(vshlq_n_u8(hbits, 4), m5), m16); // 0...3
qx[1] = vaddq_s8(vandq_u8(lbits.val[1], m4) | vandq_u8(vshlq_n_u8(hbits, 3), m5), m16); // 16..19
qx[2] = vaddq_s8(vandq_u8(lbits.val[2], m4) | vandq_u8(vshlq_n_u8(hbits, 2), m5), m16); // 4...7
@@ -12271,10 +12408,10 @@ struct Q5_0_R4_Dequantizer {
struct Q6_0_R4_Dequantizer {
Q6_0_R4_Dequantizer(const void * vx, size_t bx) : cx((const char *)vx), bx(bx) {}
inline void new_row(int ix) { iq6 = (const block_q6_0_r4 *)(cx + ix*bx); }
- inline float32x4_t prepare(int ib4, int k, int8x16_t * qx) const {
- auto scales = vcvt_f32_f16(vld1_f16((const float16_t *)iq6[4*ib4+k].d));
- auto lbits = vld1q_u8_x4(iq6[4*ib4+k].qs);
- auto hbits = vld1q_u8_x2(iq6[4*ib4+k].qh);
+ inline float32x4_t prepare(int ib, int8x16_t * qx) const {
+ auto scales = vcvt_f32_f16(vld1_f16((const float16_t *)iq6[ib].d));
+ auto lbits = vld1q_u8_x4(iq6[ib].qs);
+ auto hbits = vld1q_u8_x2(iq6[ib].qh);
qx[0] = vaddq_s8(vandq_u8(lbits.val[0], m4) | vandq_u8(vshlq_n_u8(hbits.val[0], 4), m6), m32); // 0...3
qx[1] = vaddq_s8(vandq_u8(lbits.val[1], m4) | vandq_u8(vshlq_n_u8(hbits.val[1], 4), m6), m32); // 16..19
qx[2] = vaddq_s8(vandq_u8(lbits.val[2], m4) | vandq_u8(vshlq_n_u8(hbits.val[0], 2), m6), m32); // 4...7
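On NEON the same row-size relaxation goes through the shared mul_mat_qx_r4_q8_0 template: the per-type Dequantizer::prepare() methods above now take a flat block index instead of an (ib4, k) pair, so the new tail loop can hand them single leftover blocks whose activations are read as plain block_q8_0. A small compile-only sketch of that interface shape (a hypothetical dequantizer, not one of the real structs):

    #include <arm_neon.h>

    // Hypothetical dequantizer illustrating the signature change: prepare() used to take
    // (ib4, k) and always address block 4*ib4+k; it now takes any flat block index, which
    // is exactly what the tail loop over leftover blocks needs.
    struct SketchDequantizer {
        const float16_t * d;                       // 4 interleaved row scales per block (fp16)
        float32x4_t prepare(int ib, int8x16_t * qx) const {
            (void)qx;                              // quant unpacking elided in this sketch
            return vcvt_f32_f16(vld1_f16(d + 4*ib));
        }
    };
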
diff --git a/src/llama.cpp b/src/llama.cpp
index b6a4a06d..570c056c 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -16075,7 +16075,10 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n
else new_type = ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || is_iq2_m ? GGML_TYPE_IQ3_S : GGML_TYPE_Q2_K;
++qs.i_attention_wv;
}
- else if (qs.model.hparams.n_expert == 8 && name.find("attn_k.weight") != std::string::npos) {
+ else if (qs.model.hparams.n_expert >= 8 && name.find("attn_k") != std::string::npos) {
+ new_type = GGML_TYPE_Q4_K;
+ }
+ else if (qs.model.hparams.n_expert >= 8 && name.find("attn_q") != std::string::npos) {
new_type = GGML_TYPE_Q4_K;
}
else if (name.find("attn_qkv.weight") != std::string::npos) {
@@ -16088,7 +16091,7 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n
++qs.i_ffn_down;
}
else if (name.find("attn_output.weight") != std::string::npos) {
- if (qs.model.hparams.n_expert == 8) {
+ if (qs.model.hparams.n_expert >= 4) {
new_type = GGML_TYPE_Q5_K;
} else {
if (ftype == LLAMA_FTYPE_MOSTLY_IQ1_S || ftype == LLAMA_FTYPE_MOSTLY_IQ1_M) new_type = GGML_TYPE_IQ2_K;
@@ -16188,9 +16191,9 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n
else if (new_type == GGML_TYPE_Q5_K) new_type = GGML_TYPE_Q6_K;
}
++qs.i_attention_wv;
- } else if (name.find("attn_k.weight") != std::string::npos) {
+ } else if (name.find("attn_k") != std::string::npos) {
if (qs.params->attn_k_type < GGML_TYPE_COUNT) new_type = qs.params->attn_k_type;
- else if (qs.model.hparams.n_expert == 8) {
+ else if (qs.model.hparams.n_expert >= 8) {
// for the 8-expert model, bumping this to Q8_0 trades just ~128MB
// TODO: explore better strategies
new_type = GGML_TYPE_Q8_0;
@@ -16201,8 +16204,13 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n
else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS || ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS_R4) {
new_type = GGML_TYPE_IQ2_S;
}
- } else if (name.find("attn_q.weight") != std::string::npos) {
+ } else if (name.find("attn_q") != std::string::npos) {
if (qs.params->attn_q_type < GGML_TYPE_COUNT) new_type = qs.params->attn_q_type;
+ else if (qs.model.hparams.n_expert >= 8) {
+ // for the 8-expert model, bumping this to Q8_0 trades just ~128MB
+ // TODO: explore better strategies
+ new_type = GGML_TYPE_Q8_0;
+ }
else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XS) {
new_type = GGML_TYPE_IQ3_XXS;
}
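The llama.cpp part is the "quantization mixes tweaks" bullet: the expert-count gates are relaxed from n_expert == 8 to >= 8 (>= 4 for attn_output), attn_q is now bumped alongside attn_k, and the name tests drop the ".weight" suffix so they match by substring. The snippet below only illustrates that last consequence; the second tensor name is an example of the kind of split attention tensor the broader match presumably targets, not a claim about DeepSeek-Lite's exact tensor names:

    #include <iostream>
    #include <string>

    int main() {
        const std::string names[] = { "blk.0.attn_k.weight", "blk.0.attn_k_b.weight" };
        for (const auto & name : names) {
            bool before = name.find("attn_k.weight") != std::string::npos;  // old test
            bool after  = name.find("attn_k")        != std::string::npos;  // new test
            std::cout << name << "  before: " << before << "  after: " << after << '\n';
        }
        return 0;
    }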