-rw-r--r--   ggml/src/ggml-common.h        |  15
-rw-r--r--   ggml/src/ggml-quants.c        |   1
-rw-r--r--   ggml/src/iqk/iqk_mul_mat.cpp  | 690
-rw-r--r--   ggml/src/iqk/iqk_quantize.cpp | 150
-rw-r--r--   ggml/src/iqk/iqk_quantize.h   |   4
-rw-r--r--   src/llama.cpp                 |   8
6 files changed, 437 insertions, 431 deletions
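
The bulk of this diff switches the repacked Q8_0 and IQ4_XS formats from 4-row to 8-row interleaving (block_q8_0_x4 becomes block_q8_0_r8, and block_iq4_xs_r4 is widened to hold 8 rows). As a reading aid, the following is a minimal, self-contained sketch of the new Q8_0 layout and of the scalar interleaving pattern used by the portable #else branch of repack_q8_0 below; it is an illustration written for this note, not code from the commit, and it assumes QK8_0 == 32 with the fp16 scales held as raw uint16_t.

// Hypothetical illustration of the 8-row interleaved Q8_0 layout (block_q8_0_r8).
// Names mirror ggml, but this is a sketch, not the commit's code.
#include <cstdint>

constexpr int QK8_0 = 32;

struct block_q8_0 {          // plain Q8_0: one scale + 32 int8 quants per row block
    uint16_t d;              // fp16 scale stored as raw bits (assumption for this sketch)
    int8_t   qs[QK8_0];
};

struct block_q8_0_r8 {       // 8 row blocks interleaved into one unit
    uint16_t d[8];           // one fp16 scale per source row
    int8_t   qs[8 * QK8_0];  // quants regrouped in 4-byte chunks per row
};

// Scalar repack of 8 rows for one block, following the same index pattern as the
// #else branch of repack_q8_0 in this diff: for each group l of 4 values, the
// 4-byte chunks of rows 0..7 are stored back to back, first halves of the rows
// into qs[0..127], second halves into qs[128..255].
inline void repack_8rows_one_block(const block_q8_0* x8[8], block_q8_0_r8& y) {
    for (int k = 0; k < 8; ++k) y.d[k] = x8[k]->d;
    for (int l = 0; l < 4; ++l) {
        for (int k = 0; k < 8; ++k) {
            for (int i = 0; i < 4; ++i) {
                y.qs[32*l + 4*k + i +   0] = x8[k]->qs[4*l + i +  0];  // bytes 0..15 of row k
                y.qs[32*l + 4*k + i + 128] = x8[k]->qs[4*l + i + 16];  // bytes 16..31 of row k
            }
        }
    }
}

The AVX2 and NEON branches of HelperQ80R4::repack in the diff arrive at the same byte order with unpack/transpose intrinsics; the scalar loop above is what their #else fallback spells out.
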
diff --git a/ggml/src/ggml-common.h b/ggml/src/ggml-common.h index 7f79b27b..d08870ad 100644 --- a/ggml/src/ggml-common.h +++ b/ggml/src/ggml-common.h @@ -236,6 +236,11 @@ typedef struct { int8_t qs[4*QK8_0]; } block_q8_0_x4; static_assert(sizeof(block_q8_0_x4) == 4*sizeof(block_q8_0), "wrong q8_0_x4 block size/padding"); +typedef struct { + ggml_half d[8]; + int8_t qs[8*QK8_0]; +} block_q8_0_r8; +static_assert(sizeof(block_q8_0_r8) == 8*sizeof(block_q8_0), "wrong q8_0_r8 block size/padding"); typedef struct { ggml_half d[4]; // deltas for 4 q4_0 blocks @@ -534,12 +539,12 @@ typedef struct { static_assert(sizeof(block_iq4_xs) == sizeof(ggml_half) + sizeof(uint16_t) + QK_K/64 + QK_K/2, "wrong iq4_xs block size/padding"); typedef struct { - ggml_half d[4]; - uint8_t scales_h[QK_K/32]; - uint8_t scales_l[QK_K/16]; - uint8_t qs[QK_K*2]; + ggml_half d[8]; + uint8_t scales_h[QK_K/16]; + uint8_t scales_l[QK_K/ 8]; + uint8_t qs[QK_K*4]; } block_iq4_xs_r4; -static_assert(sizeof(block_iq4_xs_r4) == 4*sizeof(ggml_half) + QK_K/32 + QK_K/16 + QK_K*2, "wrong iq4_xs_rs block size/padding"); +static_assert(sizeof(block_iq4_xs_r4) == 8*sizeof(block_iq4_xs), "wrong iq4_xs_rs block size/padding"); typedef struct { uint8_t scales[QK_K/32]; diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c index 23ac9915..391d9e2e 100644 --- a/ggml/src/ggml-quants.c +++ b/ggml/src/ggml-quants.c @@ -936,7 +936,6 @@ void quantize_row_q8_0(const float * restrict x, void * restrict vy, int64_t k) #if defined(__ARM_NEON) for (int i = 0; i < nb; i++) { - int i4 = i/4, ir = i%4; float32x4_t srcv [8]; float32x4_t asrcv[8]; float32x4_t amaxv[8]; diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index 7ddaee2a..d8273415 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -245,9 +245,7 @@ struct MulMat { case GGML_TYPE_Q4_0_R4: case GGML_TYPE_Q5_0_R4: case GGML_TYPE_Q6_0_R4: - case GGML_TYPE_Q8_0_R4: case GGML_TYPE_IQ4_NL_R4: - case GGML_TYPE_IQ4_XS_R4: case GGML_TYPE_IQ2_K_R4: case GGML_TYPE_IQ3_K_R4: case GGML_TYPE_IQ4_K_R4: @@ -259,6 +257,8 @@ struct MulMat { case GGML_TYPE_IQ3_XXS_R4: case GGML_TYPE_IQ3_S_R4: case GGML_TYPE_IQ2_BN_R4: return 4; + case GGML_TYPE_IQ4_XS_R4: + case GGML_TYPE_Q8_0_R4: case GGML_TYPE_Q8_K_R8: return 8; case GGML_TYPE_BF16_R16: return 16; default: return 1; @@ -2902,91 +2902,103 @@ static void mul_mat_q6_0_r4_q8_1(int n, const void * vx, size_t bx, const DataIn #ifdef HAVE_FANCY_SIMD template <int nrc_y> static void mul_mat_q8_0_r4_q8_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { - GGML_ASSERT(nrc_x%8 == 0); + GGML_ASSERT(nrc_x%16 == 0); Q8<nrc_y, block_q8_1_x4> q8(info); int nb = n / QK8_0; GGML_ASSERT(nb%4 == 0); if constexpr (nrc_y == 1) { auto m127 = _mm256_set1_epi8(127); - auto m1 = _mm256_set1_epi16(1); - __m256 acc[nrc_y] = {}; - for (int ix = 0; ix < nrc_x; ix += 4) { - const block_q8_0_x4 * iq8 = (const block_q8_0_x4 *)((const char *)vx + ix*bx); + __m256 acc[2] = {}; + __m256i qx[8]; + float d8[8]; + for (int ix = 0; ix < nrc_x; ix += 8) { + const block_q8_0_r8 * iq8 = (const block_q8_0_r8 *)((const char *)vx + ix*bx); for (int ib4 = 0; ib4 < nb/4; ++ib4) { + _mm256_storeu_ps(d8, _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)q8.y[0][ib4].d))); for (int k = 0; k < 4; ++k) { - auto scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq8[4*ib4+k].d)); - auto scales = _mm256_set_m128(scales128, scales128); - auto scales_m = _mm256_mul_ps(scales, _mm256_set1_ps(-63.5f)); - auto q1 = 
_mm256_add_epi8(_mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+0), m127); - auto q2 = _mm256_add_epi8(_mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+1), m127); - auto q3 = _mm256_add_epi8(_mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+2), m127); - auto q4 = _mm256_add_epi8(_mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+3), m127); - for (int iy = 0; iy < nrc_y; ++iy) { - auto y = _mm256_loadu_si256((const __m256i*)q8.y[iy][ib4].qs+k); - auto sumi1 = _mm256_add_epi32(_mm256_madd_epi16(m1, _mm256_maddubs_epi16(q1, _mm256_shuffle_epi32(y, 0x00))), - _mm256_madd_epi16(m1, _mm256_maddubs_epi16(q2, _mm256_shuffle_epi32(y, 0x55)))); - auto sumi2 = _mm256_add_epi32(_mm256_madd_epi16(m1, _mm256_maddubs_epi16(q3, _mm256_shuffle_epi32(y, 0xaa))), - _mm256_madd_epi16(m1, _mm256_maddubs_epi16(q4, _mm256_shuffle_epi32(y, 0xff)))); - auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(GGML_FP16_TO_FP32(q8.y[iy][ib4].d[k]))); - acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), acc[iy]); - acc[iy] = _mm256_fmadd_ps(scales_m, _mm256_set1_ps(GGML_FP16_TO_FP32(q8.y[iy][ib4].d[k+4])), acc[iy]); - } + auto scales = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq8[4*ib4+k].d)); + auto scales_m = _mm256_mul_ps(scales, _mm256_set1_ps(-127.f)); + qx[0] = _mm256_add_epi8(_mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+0), m127); + qx[1] = _mm256_add_epi8(_mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+1), m127); + qx[2] = _mm256_add_epi8(_mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+2), m127); + qx[3] = _mm256_add_epi8(_mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+3), m127); + qx[4] = _mm256_add_epi8(_mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+4), m127); + qx[5] = _mm256_add_epi8(_mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+5), m127); + qx[6] = _mm256_add_epi8(_mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+6), m127); + qx[7] = _mm256_add_epi8(_mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+7), m127); + auto y4l = _mm_loadu_si128((const __m128i*)q8.y[0][ib4].qs+2*k+0); + auto y4h = _mm_loadu_si128((const __m128i*)q8.y[0][ib4].qs+2*k+1); + auto yl = MM256_SET_M128I(y4l, y4l); + auto yh = MM256_SET_M128I(y4h, y4h); + auto sumi = _mm256_setzero_si256(); + sumi = _mm256_dpbusd_epi32(sumi, qx[0], _mm256_shuffle_epi32(yl, 0x00)); + sumi = _mm256_dpbusd_epi32(sumi, qx[1], _mm256_shuffle_epi32(yl, 0x55)); + sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(yl, 0xaa)); + sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(yl, 0xff)); + sumi = _mm256_dpbusd_epi32(sumi, qx[4], _mm256_shuffle_epi32(yh, 0x00)); + sumi = _mm256_dpbusd_epi32(sumi, qx[5], _mm256_shuffle_epi32(yh, 0x55)); + sumi = _mm256_dpbusd_epi32(sumi, qx[6], _mm256_shuffle_epi32(yh, 0xaa)); + sumi = _mm256_dpbusd_epi32(sumi, qx[7], _mm256_shuffle_epi32(yh, 0xff)); + auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(d8[k])); + acc[0] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[0]); + acc[1] = _mm256_fmadd_ps(scales_m, _mm256_set1_ps(d8[k+4]), acc[1]); } } - for (int iy = 0; iy < nrc_y; ++iy) { - auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1)); - info.store(ix, iy, sum); - acc[iy] = _mm256_setzero_ps(); - } + info.store(ix, 0, _mm256_add_ps(acc[0], acc[1])); + acc[0] = acc[1] = _mm256_setzero_ps(); } } else { __m512 acc[2*nrc_y] = {}; - __m512i qx[4]; + __m512i qx[8]; + float d8[8*nrc_y]; auto m127 = _mm512_set1_epi8(127); - for (int ix = 0; ix < nrc_x; ix += 8) { - const block_q8_0_x4 * q8l = 
(const block_q8_0_x4 *)((const char *)vx + (ix+0)*bx); - const block_q8_0_x4 * q8h = (const block_q8_0_x4 *)((const char *)vx + (ix+4)*bx); + for (int ix = 0; ix < nrc_x; ix += 16) { + const block_q8_0_r8 * q8l = (const block_q8_0_r8 *)((const char *)vx + (ix+0)*bx); + const block_q8_0_r8 * q8h = (const block_q8_0_r8 *)((const char *)vx + (ix+8)*bx); for (int ib4 = 0; ib4 < nb/4; ++ib4) { + for (int iy = 0; iy < nrc_y; ++iy) { + _mm256_storeu_ps(d8+8*iy, _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)q8.y[iy][ib4].d))); + } for (int k = 0; k < 4; ++k) { - auto scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)q8l[4*ib4+k].d)); - auto scales1 = _mm256_set_m128(scales128, scales128); - scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)q8h[4*ib4+k].d)); - auto scales2 = _mm256_set_m128(scales128, scales128); - auto scales = _mm512_insertf32x8(_mm512_castps256_ps512(scales1), scales2, 1); - auto scales_m = _mm512_mul_ps(scales, _mm512_set1_ps(-63.5f)); - qx[0] = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)q8l[4*ib4+k].qs+0)), - _mm256_loadu_si256((const __m256i *)q8h[4*ib4+k].qs+0), 1); - qx[1] = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)q8l[4*ib4+k].qs+1)), - _mm256_loadu_si256((const __m256i *)q8h[4*ib4+k].qs+1), 1); - qx[2] = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)q8l[4*ib4+k].qs+2)), - _mm256_loadu_si256((const __m256i *)q8h[4*ib4+k].qs+2), 1); - qx[3] = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)q8l[4*ib4+k].qs+3)), - _mm256_loadu_si256((const __m256i *)q8h[4*ib4+k].qs+3), 1); - qx[0] = _mm512_add_epi8(qx[0], m127); - qx[1] = _mm512_add_epi8(qx[1], m127); - qx[2] = _mm512_add_epi8(qx[2], m127); - qx[3] = _mm512_add_epi8(qx[3], m127); + auto scales1 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)q8l[4*ib4+k].d)); + auto scales2 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)q8h[4*ib4+k].d)); + auto scales = _mm512_insertf32x8(_mm512_castps256_ps512(scales1), scales2, 1); + auto scales_m = _mm512_mul_ps(scales, _mm512_set1_ps(-127.f)); + for (int j = 0; j < 8; ++j) { + qx[j] = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)q8l[4*ib4+k].qs+j)), + _mm256_loadu_si256((const __m256i *)q8h[4*ib4+k].qs+j), 1); + qx[j] = _mm512_add_epi8(qx[j], m127); + } for (int iy = 0; iy < nrc_y; ++iy) { - auto y8 = _mm256_loadu_si256((const __m256i*)q8.y[iy][ib4].qs+k); - auto y = _mm512_inserti32x8(_mm512_castsi256_si512(y8), y8, 1); + auto y4l = _mm_loadu_si128((const __m128i*)q8.y[iy][ib4].qs+2*k+0); + auto y4h = _mm_loadu_si128((const __m128i*)q8.y[iy][ib4].qs+2*k+1); + auto y8l = MM256_SET_M128I(y4l, y4l); + auto y8h = MM256_SET_M128I(y4h, y4h); + auto yl = _mm512_inserti32x8(_mm512_castsi256_si512(y8l), y8l, 1); + auto yh = _mm512_inserti32x8(_mm512_castsi256_si512(y8h), y8h, 1); auto sumi = _mm512_setzero_si512(); - sumi = _mm512_dpbusd_epi32(sumi, qx[0], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00))); - sumi = _mm512_dpbusd_epi32(sumi, qx[1], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55))); - sumi = _mm512_dpbusd_epi32(sumi, qx[2], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa))); - sumi = _mm512_dpbusd_epi32(sumi, qx[3], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff))); - auto dy = _mm512_set1_ps(GGML_FP16_TO_FP32(q8.y[iy][ib4].d[k])); + sumi = _mm512_dpbusd_epi32(sumi, qx[0], _mm512_shuffle_epi32(yl, _MM_PERM_ENUM(0x00))); + sumi = _mm512_dpbusd_epi32(sumi, qx[1], _mm512_shuffle_epi32(yl, _MM_PERM_ENUM(0x55))); 
+ sumi = _mm512_dpbusd_epi32(sumi, qx[2], _mm512_shuffle_epi32(yl, _MM_PERM_ENUM(0xaa))); + sumi = _mm512_dpbusd_epi32(sumi, qx[3], _mm512_shuffle_epi32(yl, _MM_PERM_ENUM(0xff))); + sumi = _mm512_dpbusd_epi32(sumi, qx[4], _mm512_shuffle_epi32(yh, _MM_PERM_ENUM(0x00))); + sumi = _mm512_dpbusd_epi32(sumi, qx[5], _mm512_shuffle_epi32(yh, _MM_PERM_ENUM(0x55))); + sumi = _mm512_dpbusd_epi32(sumi, qx[6], _mm512_shuffle_epi32(yh, _MM_PERM_ENUM(0xaa))); + sumi = _mm512_dpbusd_epi32(sumi, qx[7], _mm512_shuffle_epi32(yh, _MM_PERM_ENUM(0xff))); + auto dy = _mm512_set1_ps(d8[8*iy+k]); acc[2*iy+0] = _mm512_fmadd_ps(_mm512_mul_ps(scales, dy), _mm512_cvtepi32_ps(sumi), acc[2*iy+0]); - acc[2*iy+1] = _mm512_fmadd_ps(scales_m, _mm512_set1_ps(GGML_FP16_TO_FP32(q8.y[iy][ib4].d[k+4])), acc[2*iy+1]); + acc[2*iy+1] = _mm512_fmadd_ps(scales_m, _mm512_set1_ps(d8[8*iy+k+4]), acc[2*iy+1]); } } } for (int iy = 0; iy < nrc_y; ++iy) { auto sum512 = _mm512_add_ps(acc[2*iy+0], acc[2*iy+1]); + info.store(ix, iy, sum512); acc[2*iy+0] = acc[2*iy+1] = _mm512_setzero_ps(); - auto sum1 = _mm_add_ps(_mm512_extractf32x4_ps(sum512, 0), _mm512_extractf32x4_ps(sum512, 1)); - auto sum2 = _mm_add_ps(_mm512_extractf32x4_ps(sum512, 2), _mm512_extractf32x4_ps(sum512, 3)); - info.store(ix+0, iy, sum1); - info.store(ix+4, iy, sum2); + //auto sum1 = _mm_add_ps(_mm512_extractf32x4_ps(sum512, 0), _mm512_extractf32x4_ps(sum512, 1)); + //auto sum2 = _mm_add_ps(_mm512_extractf32x4_ps(sum512, 2), _mm512_extractf32x4_ps(sum512, 3)); + //info.store(ix+0, iy, sum1); + //info.store(ix+4, iy, sum2); } } } @@ -2994,45 +3006,72 @@ static void mul_mat_q8_0_r4_q8_1(int n, const void * vx, size_t bx, const DataIn #else template <int nrc_y> static void mul_mat_q8_0_r4_q8_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { - GGML_ASSERT(nrc_x%4 == 0); + GGML_ASSERT(nrc_x%8 == 0); Q8<nrc_y, block_q8_1_x4> q8(info); auto m1 = _mm256_set1_epi16(1); int nb = n / QK8_0; GGML_ASSERT(nb%4 == 0); __m256 acc[nrc_y] = {}; float d8[4*nrc_y]; - for (int ix = 0; ix < nrc_x; ix += 4) { - const block_q8_0_x4 * iq8 = (const block_q8_0_x4 *)((const char *)vx + ix*bx); + for (int ix = 0; ix < nrc_x; ix += 8) { + const block_q8_0_r8 * iq8 = (const block_q8_0_r8 *)((const char *)vx + ix*bx); for (int ib4 = 0; ib4 < nb/4; ++ib4) { for (int iy = 0; iy < nrc_y; ++iy) { auto scales = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)q8.y[iy][ib4].d)); _mm_storeu_ps(d8 + 4*iy, scales); } for (int k = 0; k < 4; ++k) { - auto scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq8[4*ib4+k].d)); - auto scales = _mm256_set_m128(scales128, scales128); - auto q1 = _mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+0); - auto q2 = _mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+1); - auto q3 = _mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+2); - auto q4 = _mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+3); + auto scales = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq8[4*ib4+k].d)); + auto q0 = _mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+0); + auto q1 = _mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+1); + auto q2 = _mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+2); + auto q3 = _mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+3); + auto s0 = _mm256_sign_epi8(q0, q0); auto s1 = _mm256_sign_epi8(q1, q1); auto s2 = _mm256_sign_epi8(q2, q2); auto s3 = _mm256_sign_epi8(q3, q3); - auto s4 = _mm256_sign_epi8(q4, q4); for (int iy = 0; iy < nrc_y; ++iy) { - auto y = _mm256_loadu_si256((const __m256i*)q8.y[iy][ib4].qs+k); - auto sumi1 = 
_mm256_add_epi32(_mm256_madd_epi16(m1, _mm256_maddubs_epi16(s1, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), q1))), - _mm256_madd_epi16(m1, _mm256_maddubs_epi16(s2, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), q2)))); - auto sumi2 = _mm256_add_epi32(_mm256_madd_epi16(m1, _mm256_maddubs_epi16(s3, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), q3))), - _mm256_madd_epi16(m1, _mm256_maddubs_epi16(s4, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), q4)))); + auto y128 = _mm_loadu_si128((const __m128i*)q8.y[iy][ib4].qs+2*k+0); + auto y = MM256_SET_M128I(y128, y128); + auto sumi1 = _mm256_add_epi32( + _mm256_madd_epi16(m1, _mm256_maddubs_epi16(s0, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), q0))), + _mm256_madd_epi16(m1, _mm256_maddubs_epi16(s1, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), q1))) + ); + auto sumi2 = _mm256_add_epi32( + _mm256_madd_epi16(m1, _mm256_maddubs_epi16(s2, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), q2))), + _mm256_madd_epi16(m1, _mm256_maddubs_epi16(s3, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), q3))) + ); + auto sumi = _mm256_add_epi32(sumi1, sumi2); auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(d8[4*iy+k])); - acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), acc[iy]); + acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[iy]); + } + q0 = _mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+4); + q1 = _mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+5); + q2 = _mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+6); + q3 = _mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+7); + s0 = _mm256_sign_epi8(q0, q0); + s1 = _mm256_sign_epi8(q1, q1); + s2 = _mm256_sign_epi8(q2, q2); + s3 = _mm256_sign_epi8(q3, q3); + for (int iy = 0; iy < nrc_y; ++iy) { + auto y128 = _mm_loadu_si128((const __m128i*)q8.y[iy][ib4].qs+2*k+1); + auto y = MM256_SET_M128I(y128, y128); + auto sumi1 = _mm256_add_epi32( + _mm256_madd_epi16(m1, _mm256_maddubs_epi16(s0, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), q0))), + _mm256_madd_epi16(m1, _mm256_maddubs_epi16(s1, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), q1))) + ); + auto sumi2 = _mm256_add_epi32( + _mm256_madd_epi16(m1, _mm256_maddubs_epi16(s2, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), q2))), + _mm256_madd_epi16(m1, _mm256_maddubs_epi16(s3, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), q3))) + ); + auto sumi = _mm256_add_epi32(sumi1, sumi2); + auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(d8[4*iy+k])); + acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[iy]); } } } for (int iy = 0; iy < nrc_y; ++iy) { - auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1)); - info.store(ix, iy, sum); + info.store(ix, iy, acc[iy]); acc[iy] = _mm256_setzero_ps(); } } @@ -3041,9 +3080,11 @@ static void mul_mat_q8_0_r4_q8_1(int n, const void * vx, size_t bx, const DataIn template <int nrc_y> static void mul_mat_iq4_xs_r4_q8_k_avx2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { - GGML_ASSERT(nrc_x%4 == 0); + GGML_ASSERT(nrc_x%8 == 0); Q8<nrc_y, block_q8_K> q8(info); auto m4 = _mm256_set1_epi8(0xf); + auto m30 = _mm256_set1_epi8(0x30); + auto m32 = _mm256_set1_epi8(32); #ifndef HAVE_FANCY_SIMD auto s_shuffle = _mm256_set_epi64x(0x0f0e0f0e0d0c0d0c, 0x0b0a0b0a09080908, 0x0706070605040504, 0x0302030201000100); auto values128 = _mm_loadu_si128((const __m128i *)iq4k_values); @@ -3052,40 +3093,40 @@ static void mul_mat_iq4_xs_r4_q8_k_avx2(int n, const void * vx, size_t bx, const auto values = load_iq4nl_values_256(); 
#endif int nbl = n / QK_K; - using helper_t = union { __m256i vec; uint32_t val[8]; }; + using helper_t = union { __m256i vec[2]; uint64_t val[8]; }; helper_t h; __m256 acc[nrc_y] = {}; - __m256i isum[nrc_y] = {}; __m256i qx[4]; - for (int ix = 0; ix < nrc_x; ix += 4) { + for (int ix = 0; ix < nrc_x; ix += 8) { const block_iq4_xs_r4 * iq4 = (const block_iq4_xs_r4 *)((const char *)vx + (ix+0)*bx); for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256 - auto dl = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq4[ibl].d)); - auto d4 = _mm256_set_m128(dl, dl); - auto slbits = _mm_loadu_si128((const __m128i *)iq4[ibl].scales_l); - auto sl = _mm256_and_si256(MM256_SET_M128I(_mm_srli_epi16(slbits, 4), slbits), _mm256_set1_epi8(0xf)); - auto aux64 = (const uint64_t *)iq4[ibl].scales_h; - auto shbits = _mm_set_epi64x(aux64[0] >> 2, aux64[0]); - auto sh = _mm256_and_si256(MM256_SET_M128I(shbits, _mm_slli_epi16(shbits, 4)), _mm256_set1_epi8(0x30)); - h.vec = _mm256_sub_epi8(_mm256_or_si256(sl, sh), _mm256_set1_epi8(32)); + auto d4 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq4[ibl].d)); + auto slbits = _mm256_loadu_si256((const __m256i *)iq4[ibl].scales_l); + auto sl1 = _mm256_and_si256(slbits, m4); + auto sl2 = _mm256_and_si256(_mm256_srli_epi16(slbits, 4), m4); + auto shbits = _mm_loadu_si128((const __m128i*)iq4[ibl].scales_h); + auto sh = MM256_SET_M128I(_mm_srli_epi16(shbits, 2), shbits); + h.vec[0] = _mm256_sub_epi8(_mm256_or_si256(sl1, _mm256_and_si256(_mm256_slli_epi16(sh, 4), m30)), m32); + h.vec[1] = _mm256_sub_epi8(_mm256_or_si256(sl2, _mm256_and_si256(sh, m30)), m32); + __m256i isum[nrc_y] = {}; for (int ib = 0; ib < QK_K/32; ++ib) { #ifdef HAVE_FANCY_SIMD - auto iscales = _mm256_cvtepi8_epi32(_mm_set1_epi32(h.val[ib])); + auto iscales = _mm256_cvtepi8_epi32(_mm_set1_epi64x(h.val[ib])); auto scales = _mm256_mul_ps(d4, _mm256_cvtepi32_ps(iscales)); - auto scales_m = _mm256_mul_ps(scales, _mm256_set1_ps(-64.f)); + auto scales_m = _mm256_mul_ps(scales, _mm256_set1_ps(-128.f)); for (int iy = 0; iy < nrc_y; ++iy) { float m8 = ((const float *)q8.y[iy][ibl].bsums)[ib]; acc[iy] = _mm256_fmadd_ps(scales_m, _mm256_set1_ps(m8), acc[iy]); } #else - auto iscales = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm_set1_epi32(h.val[ib])), s_shuffle); + auto iscales = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm_set1_epi64x(h.val[ib])), s_shuffle); #endif - auto bits1 = _mm256_loadu_si256((const __m256i *)iq4[ibl].qs+2*ib+0); - auto bits2 = _mm256_loadu_si256((const __m256i *)iq4[ibl].qs+2*ib+1); - qx[0] = _mm256_shuffle_epi8(values, _mm256_and_si256(bits1, m4)); - qx[1] = _mm256_shuffle_epi8(values, _mm256_and_si256(bits2, m4)); - qx[2] = _mm256_shuffle_epi8(values, _mm256_and_si256(_mm256_srli_epi16(bits1, 4), m4)); - qx[3] = _mm256_shuffle_epi8(values, _mm256_and_si256(_mm256_srli_epi16(bits2, 4), m4)); + auto bits1 = _mm256_loadu_si256((const __m256i *)iq4[ibl].qs+4*ib+0); + auto bits2 = _mm256_loadu_si256((const __m256i *)iq4[ibl].qs+4*ib+1); + qx[0] = _mm256_shuffle_epi8(values, _mm256_and_si256(m4, bits1)); + qx[1] = _mm256_shuffle_epi8(values, _mm256_and_si256(m4, _mm256_srli_epi16(bits1, 4))); + qx[2] = _mm256_shuffle_epi8(values, _mm256_and_si256(m4, bits2)); + qx[3] = _mm256_shuffle_epi8(values, _mm256_and_si256(m4, _mm256_srli_epi16(bits2, 4))); #ifndef HAVE_FANCY_SIMD auto s1 = _mm256_sign_epi8(qx[0], qx[0]); auto s2 = _mm256_sign_epi8(qx[1], qx[1]); @@ -3093,7 +3134,8 @@ static void mul_mat_iq4_xs_r4_q8_k_avx2(int n, const void * vx, size_t bx, const auto s4 = _mm256_sign_epi8(qx[3], 
qx[3]); #endif for (int iy = 0; iy < nrc_y; ++iy) { - auto y = _mm256_loadu_si256((const __m256i*)q8.y[iy][ibl].qs+ib); + auto y128 = _mm_loadu_si128((const __m128i*)q8.y[iy][ibl].qs+2*ib+0); + auto y = MM256_SET_M128I(y128, y128); #ifdef HAVE_FANCY_SIMD auto sumi = _mm256_setzero_si256(); sumi = _mm256_dpbusd_epi32(sumi, qx[0], _mm256_shuffle_epi32(y, 0x00)); @@ -3106,20 +3148,51 @@ static void mul_mat_iq4_xs_r4_q8_k_avx2(int n, const void * vx, size_t bx, const auto sumi2 = _mm256_maddubs_epi16(s2, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), qx[1])); auto sumi3 = _mm256_maddubs_epi16(s3, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), qx[2])); auto sumi4 = _mm256_maddubs_epi16(s4, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), qx[3])); - isum[iy] = _mm256_add_epi32(isum[iy], _mm256_add_epi32(_mm256_madd_epi16(iscales, sumi1), _mm256_madd_epi16(iscales, sumi2))); - isum[iy] = _mm256_add_epi32(isum[iy], _mm256_add_epi32(_mm256_madd_epi16(iscales, sumi3), _mm256_madd_epi16(iscales, sumi4))); + auto sumi = _mm256_add_epi32(_mm256_add_epi32(_mm256_madd_epi16(iscales, sumi1), _mm256_madd_epi16(iscales, sumi2)), + _mm256_add_epi32(_mm256_madd_epi16(iscales, sumi3), _mm256_madd_epi16(iscales, sumi4))); + isum[iy] = _mm256_add_epi32(isum[iy], sumi); +#endif + } + bits1 = _mm256_loadu_si256((const __m256i *)iq4[ibl].qs+4*ib+2); + bits2 = _mm256_loadu_si256((const __m256i *)iq4[ibl].qs+4*ib+3); + qx[0] = _mm256_shuffle_epi8(values, _mm256_and_si256(m4, bits1)); + qx[1] = _mm256_shuffle_epi8(values, _mm256_and_si256(m4, _mm256_srli_epi16(bits1, 4))); + qx[2] = _mm256_shuffle_epi8(values, _mm256_and_si256(m4, bits2)); + qx[3] = _mm256_shuffle_epi8(values, _mm256_and_si256(m4, _mm256_srli_epi16(bits2, 4))); +#ifndef HAVE_FANCY_SIMD + s1 = _mm256_sign_epi8(qx[0], qx[0]); + s2 = _mm256_sign_epi8(qx[1], qx[1]); + s3 = _mm256_sign_epi8(qx[2], qx[2]); + s4 = _mm256_sign_epi8(qx[3], qx[3]); +#endif + for (int iy = 0; iy < nrc_y; ++iy) { + auto y128 = _mm_loadu_si128((const __m128i*)q8.y[iy][ibl].qs+2*ib+1); + auto y = MM256_SET_M128I(y128, y128); +#ifdef HAVE_FANCY_SIMD + auto sumi = _mm256_setzero_si256(); + sumi = _mm256_dpbusd_epi32(sumi, qx[0], _mm256_shuffle_epi32(y, 0x00)); + sumi = _mm256_dpbusd_epi32(sumi, qx[1], _mm256_shuffle_epi32(y, 0x55)); + sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(y, 0xaa)); + sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(y, 0xff)); + isum[iy] = _mm256_add_epi32(isum[iy], _mm256_mullo_epi32(iscales, sumi)); +#else + auto sumi1 = _mm256_maddubs_epi16(s1, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), qx[0])); + auto sumi2 = _mm256_maddubs_epi16(s2, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), qx[1])); + auto sumi3 = _mm256_maddubs_epi16(s3, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), qx[2])); + auto sumi4 = _mm256_maddubs_epi16(s4, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), qx[3])); + auto sumi = _mm256_add_epi32(_mm256_add_epi32(_mm256_madd_epi16(iscales, sumi1), _mm256_madd_epi16(iscales, sumi2)), + _mm256_add_epi32(_mm256_madd_epi16(iscales, sumi3), _mm256_madd_epi16(iscales, sumi4))); + isum[iy] = _mm256_add_epi32(isum[iy], sumi); #endif } } for (int iy = 0; iy < nrc_y; ++iy) { acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(isum[iy]), acc[iy]); - isum[iy] = _mm256_setzero_si256(); } } for (int iy = 0; iy < nrc_y; ++iy) { - auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1)); + info.store(ix, iy, acc[iy]); acc[iy] = 
_mm256_setzero_ps(); - info.store(ix+0, iy, sum); } } } @@ -3127,6 +3200,8 @@ static void mul_mat_iq4_xs_r4_q8_k_avx2(int n, const void * vx, size_t bx, const #ifdef HAVE_FANCY_SIMD template <int nrc_y> static void mul_mat_iq4_xs_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + mul_mat_iq4_xs_r4_q8_k_avx2<nrc_y>(n, vx, bx, info, nrc_x); + return; if constexpr (nrc_y == 1){ mul_mat_iq4_xs_r4_q8_k_avx2<1>(n, vx, bx, info, nrc_x); } else { @@ -10529,6 +10604,13 @@ IQK_ALWAYS_INLINE void prepare_iq4_nl_quants(const int8x16_t& values, const uint qx[7] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[3], 4)); // 28..31 } +IQK_ALWAYS_INLINE void prepare_iq4_nl_quants_r8(const int8x16_t& values, const uint8x16_t& m4, const uint8x16x2_t& bits, int8x16_t * qx) { + qx[0] = vqtbl1q_s8(values, vandq_u8( bits.val[0], m4)); + qx[1] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[0], 4)); + qx[2] = vqtbl1q_s8(values, vandq_u8( bits.val[1], m4)); + qx[3] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[1], 4)); +} + template <int nrc_y> void mul_mat_iq4_xs_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { GGML_ASSERT(nrc_x%4 == 0); @@ -10539,43 +10621,92 @@ void mul_mat_iq4_xs_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& i auto values = vld1q_s8(iq4k_values); int nbl = n / QK_K; int8x16_t qx[8]; - int8x16x2_t iscales; - int32x4x4_t scales; - float32x4_t acc[nrc_y] = {}; - for (int ix = 0; ix < nrc_x; ix += 4) { + int8x16x4_t iscales; + int32x4x2_t scales; + float32x4_t acc[2*nrc_y] = {}; + for (int ix = 0; ix < nrc_x; ix += 8) { const block_iq4_xs_r4 * iq4 = (const block_iq4_xs_r4 *)((const char *)vx + ix*bx); for (int ibl = 0; ibl < nbl; ++ibl) { - auto d4 = vcvt_f32_f16(vld1_f16((const float16_t *)iq4[ibl].d)); - auto sl = vld1q_u8(iq4[ibl].scales_l); - auto sh8 = vld1_u8(iq4[ibl].scales_h); - auto sh = vcombine_u8(sh8, vshr_n_u8(sh8, 2)); - iscales.val[0] = vaddq_s8(vorrq_u8(vandq_u8(sl, m4), vandq_u8(vshlq_n_u8(sh, 4), m3)), m32); - iscales.val[1] = vaddq_s8(vorrq_u8(vshrq_n_u8(sl, 4), vandq_u8(sh, m3)), m32); + auto d4_f16 = vld1q_f16((const float16_t *)iq4[ibl].d); + auto d4l = vcvt_f32_f16(vget_low_f16 (d4_f16)); + auto d4h = vcvt_f32_f16(vget_high_f16(d4_f16)); + auto sl = vld1q_u8_x2(iq4[ibl].scales_l); + auto sh = vld1q_u8(iq4[ibl].scales_h); + iscales.val[0] = vaddq_s8(vorrq_u8(vandq_u8(sl.val[0], m4), vandq_u8(vshlq_n_u8(sh, 4), m3)), m32); + iscales.val[1] = vaddq_s8(vorrq_u8(vandq_u8(sl.val[1], m4), vandq_u8(vshlq_n_u8(sh, 2), m3)), m32); + iscales.val[2] = vaddq_s8(vorrq_u8(vshrq_n_u8(sl.val[0], 4), vandq_u8(sh, m3)), m32); + iscales.val[3] = vaddq_s8(vorrq_u8(vshrq_n_u8(sl.val[1], 4), vandq_u8(vshrq_n_u8(sh, 2), m3)), m32); int32x4_t isum[nrc_y] = {}; - for (int is = 0; is < 2; ++is) { - auto iscales16_1 = vmovl_s8(vget_low_s8(iscales.val[is])); - auto iscales16_2 = vmovl_s8(vget_high_s8(iscales.val[is])); + for (int ib64 = 0; ib64 < QK_K/64; ++ib64) { + auto iscales16_1 = vmovl_s8(vget_low_s8(iscales.val[ib64])); + auto iscales16_2 = vmovl_s8(vget_high_s8(iscales.val[ib64])); scales.val[0] = vmovl_s16(vget_low_s16(iscales16_1)); - scales.val[1] = vmovl_s16(vget_high_s16(iscales16_1)); - scales.val[2] = vmovl_s16(vget_low_s16(iscales16_2)); - scales.val[3] = vmovl_s16(vget_high_s16(iscales16_2)); - for (int ib = 0; ib < 4; ++ib) { - auto bits = vld1q_u8_x4(iq4[ibl].qs + 256*is + 64*ib); - prepare_iq4_nl_quants(values, m4, bits, qx); + scales.val[1] = vmovl_s16(vget_low_s16(iscales16_2)); + for (int l = 0; l < 2; ++l) { + uint8x16x2_t bits; + 
bits.val[0] = vld1q_u8(iq4[ibl].qs + 256*ib64 + 128*l); + bits.val[1] = vld1q_u8(iq4[ibl].qs + 256*ib64 + 128*l + 32); + prepare_iq4_nl_quants_r8(values, m4, bits, qx+0); + bits.val[0] = vld1q_u8(iq4[ibl].qs + 256*ib64 + 128*l + 64); + bits.val[1] = vld1q_u8(iq4[ibl].qs + 256*ib64 + 128*l + 96); + prepare_iq4_nl_quants_r8(values, m4, bits, qx+4); for (int iy = 0; iy < nrc_y; ++iy) { - auto y = vld1q_s8_x2(q8.y[iy][ibl].qs+128*is+32*ib); - auto sumi = interleaved_dotq(qx, y); - isum[iy] = vmlaq_s32(isum[iy], scales.val[ib], sumi); + auto y = vld1q_s8_x2(q8.y[iy][ibl].qs+64*ib64+32*l); + auto sumi = vdupq_n_s32(0); + sumi = vdotq_laneq_s32(sumi, qx[0], y.val[0], 0); + sumi = vdotq_laneq_s32(sumi, qx[1], y.val[0], 1); + sumi = vdotq_laneq_s32(sumi, qx[2], y.val[0], 2); + sumi = vdotq_laneq_s32(sumi, qx[3], y.val[0], 3); + sumi = vdotq_laneq_s32(sumi, qx[4], y.val[1], 0); + sumi = vdotq_laneq_s32(sumi, qx[5], y.val[1], 1); + sumi = vdotq_laneq_s32(sumi, qx[6], y.val[1], 2); + sumi = vdotq_laneq_s32(sumi, qx[7], y.val[1], 3); + isum[iy] = vmlaq_s32(isum[iy], sumi, scales.val[l]); } } } for (int iy = 0; iy < nrc_y; ++iy) { - acc[iy] = vfmaq_f32(acc[iy], vmulq_f32(d4, vdupq_n_f32(q8.scale(iy, ibl))), vcvtq_f32_s32(isum[iy])); + auto d8 = vdupq_n_f32(q8.scale(iy, ibl)); + acc[2*iy+0] = vfmaq_f32(acc[2*iy+0], vmulq_f32(d4l, d8), vcvtq_f32_s32(isum[iy])); + isum[iy] = vdupq_n_s32(0); + } + for (int ib64 = 0; ib64 < QK_K/64; ++ib64) { + auto iscales16_1 = vmovl_s8(vget_low_s8(iscales.val[ib64])); + auto iscales16_2 = vmovl_s8(vget_high_s8(iscales.val[ib64])); + scales.val[0] = vmovl_s16(vget_high_s16(iscales16_1)); + scales.val[1] = vmovl_s16(vget_high_s16(iscales16_2)); + for (int l = 0; l < 2; ++l) { + uint8x16x2_t bits; + bits.val[0] = vld1q_u8(iq4[ibl].qs + 256*ib64 + 128*l + 16); + bits.val[1] = vld1q_u8(iq4[ibl].qs + 256*ib64 + 128*l + 48); + prepare_iq4_nl_quants_r8(values, m4, bits, qx+0); + bits.val[0] = vld1q_u8(iq4[ibl].qs + 256*ib64 + 128*l + 80); + bits.val[1] = vld1q_u8(iq4[ibl].qs + 256*ib64 + 128*l +112); + prepare_iq4_nl_quants_r8(values, m4, bits, qx+4); + for (int iy = 0; iy < nrc_y; ++iy) { + auto y = vld1q_s8_x2(q8.y[iy][ibl].qs+64*ib64+32*l); + auto sumi = vdupq_n_s32(0); + sumi = vdotq_laneq_s32(sumi, qx[0], y.val[0], 0); + sumi = vdotq_laneq_s32(sumi, qx[1], y.val[0], 1); + sumi = vdotq_laneq_s32(sumi, qx[2], y.val[0], 2); + sumi = vdotq_laneq_s32(sumi, qx[3], y.val[0], 3); + sumi = vdotq_laneq_s32(sumi, qx[4], y.val[1], 0); + sumi = vdotq_laneq_s32(sumi, qx[5], y.val[1], 1); + sumi = vdotq_laneq_s32(sumi, qx[6], y.val[1], 2); + sumi = vdotq_laneq_s32(sumi, qx[7], y.val[1], 3); + isum[iy] = vmlaq_s32(isum[iy], sumi, scales.val[l]); + } + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + auto d8 = vdupq_n_f32(q8.scale(iy, ibl)); + acc[2*iy+1] = vfmaq_f32(acc[2*iy+1], vmulq_f32(d4h, d8), vcvtq_f32_s32(isum[iy])); } } for (int iy = 0; iy < nrc_y; ++iy) { - info.store(ix, iy, acc[iy]); - acc[iy] = vdupq_n_f32(0.f); + info.store(ix+0, iy, acc[2*iy+0]); + info.store(ix+4, iy, acc[2*iy+1]); + acc[2*iy+0] = acc[2*iy+1] = vdupq_n_f32(0.f); } } } @@ -12045,81 +12176,54 @@ struct Q6_0_R4_Dequantizer { template <int nrc_y> void mul_mat_q8_0_r4_q8_0(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { - GGML_ASSERT(nrc_x%4 == 0); + GGML_ASSERT(nrc_x%8 == 0); Q8<nrc_y, block_q8_0_x4> q8(info); int nb = n / QK8_0; GGML_ASSERT(nb%4 == 0); - float32x4_t acc[nrc_y] = {}; + float32x4_t acc[2*nrc_y] = {}; + int8x16_t qx[16]; float d8[4*nrc_y]; - for (int ix = 0; ix < nrc_x; ix += 4) 
{ - const block_q8_0_x4 * iq8 = (const block_q8_0_x4 *)((const char *)vx + ix*bx); + for (int ix = 0; ix < nrc_x; ix += 8) { + const block_q8_0_r8 * iq8 = (const block_q8_0_r8 *)((const char *)vx + ix*bx); for (int ib4 = 0; ib4 < nb/4; ++ib4) { for (int iy = 0; iy < nrc_y; ++iy) { vst1q_f32(d8+4*iy, vcvt_f32_f16(vld1_f16((const float16_t *)q8.y[iy][ib4].d))); } for (int k = 0; k < 4; ++k) { - auto scales = vcvt_f32_f16(vld1_f16((const float16_t *)iq8[4*ib4+k].d)); - auto qx1 = vld1q_s8_x4(iq8[4*ib4+k].qs); - auto qx2 = vld1q_s8_x4(iq8[4*ib4+k].qs+64); + auto scales16 = vld1q_f16((const float16_t *)iq8[4*ib4+k].d); + auto scales1 = vcvt_f32_f16(vget_low_f16 (scales16)); + auto scales2 = vcvt_f32_f16(vget_high_f16(scales16)); + for (int j = 0; j < 16; ++j) qx[j] = vld1q_s8(iq8[4*ib4+k].qs + 16*j); for (int iy = 0; iy < nrc_y; ++iy) { auto y = vld1q_s8_x2(q8.y[iy][ib4].qs+32*k); - auto sumi = vdupq_n_s32(0); - sumi = vdotq_laneq_s32(sumi, qx1.val[0], y.val[0], 0); - sumi = vdotq_laneq_s32(sumi, qx1.val[1], y.val[1], 0); - sumi = vdotq_laneq_s32(sumi, qx1.val[2], y.val[0], 1); - sumi = vdotq_laneq_s32(sumi, qx1.val[3], y.val[1], 1); - sumi = vdotq_laneq_s32(sumi, qx2.val[0], y.val[0], 2); - sumi = vdotq_laneq_s32(sumi, qx2.val[1], y.val[1], 2); - sumi = vdotq_laneq_s32(sumi, qx2.val[2], y.val[0], 3); - sumi = vdotq_laneq_s32(sumi, qx2.val[3], y.val[1], 3); - auto d4d8 = vmulq_f32(scales, vdupq_n_f32(d8[4*iy+k])); - acc[iy] = vfmaq_f32(acc[iy], d4d8, vcvtq_f32_s32(sumi)); + auto sumi1 = vdupq_n_s32(0); + auto sumi2 = vdupq_n_s32(0); + sumi1 = vdotq_laneq_s32(sumi1, qx[0], y.val[0], 0); + sumi2 = vdotq_laneq_s32(sumi2, qx[1], y.val[0], 0); + sumi1 = vdotq_laneq_s32(sumi1, qx[2], y.val[0], 1); + sumi2 = vdotq_laneq_s32(sumi2, qx[3], y.val[0], 1); + sumi1 = vdotq_laneq_s32(sumi1, qx[4], y.val[0], 2); + sumi2 = vdotq_laneq_s32(sumi2, qx[5], y.val[0], 2); + sumi1 = vdotq_laneq_s32(sumi1, qx[6], y.val[0], 3); + sumi2 = vdotq_laneq_s32(sumi2, qx[7], y.val[0], 3); + sumi1 = vdotq_laneq_s32(sumi1, qx[8+0], y.val[1], 0); + sumi2 = vdotq_laneq_s32(sumi2, qx[8+1], y.val[1], 0); + sumi1 = vdotq_laneq_s32(sumi1, qx[8+2], y.val[1], 1); + sumi2 = vdotq_laneq_s32(sumi2, qx[8+3], y.val[1], 1); + sumi1 = vdotq_laneq_s32(sumi1, qx[8+4], y.val[1], 2); + sumi2 = vdotq_laneq_s32(sumi2, qx[8+5], y.val[1], 2); + sumi1 = vdotq_laneq_s32(sumi1, qx[8+6], y.val[1], 3); + sumi2 = vdotq_laneq_s32(sumi2, qx[8+7], y.val[1], 3); + auto dy = vdupq_n_f32(d8[4*iy+k]); + acc[2*iy+0] = vfmaq_f32(acc[2*iy+0], vmulq_f32(scales1, dy), vcvtq_f32_s32(sumi1)); + acc[2*iy+1] = vfmaq_f32(acc[2*iy+1], vmulq_f32(scales2, dy), vcvtq_f32_s32(sumi2)); } } } for (int iy = 0; iy < nrc_y; ++iy) { - info.store(ix, iy, acc[iy]); - acc[iy] = vdupq_n_f32(0.f); - } - } -} - -template <int nrc_y> -void mul_mat_q8_0_r4_q8_0_128(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { - GGML_ASSERT(nrc_x%4 == 0); - GGML_ASSERT(n == 128); - int8x16x4_t qx[8]; - float32x4_t scales[4]; - float32x4_t scales_y[4]; - for (int ix = 0; ix < nrc_x; ix += 4) { - const block_q8_0_x4 * iq8 = (const block_q8_0_x4 *)((const char *)vx + ix*bx); - for (int k = 0; k < 4; ++k) { - scales[k] = vcvt_f32_f16(vld1_f16((const float16_t *)iq8[k].d)); - qx[2*k+0] = vld1q_s8_x4(iq8[k].qs); - qx[2*k+1] = vld1q_s8_x4(iq8[k].qs+64); - } - for (int iy = 0; iy < nrc_y; ++iy) { - auto by = (const block_q8_0_x4 *)info.src1_row(iy); - auto d8 = vcvt_f32_f16(vld1_f16((const float16_t *)by->d)); - scales_y[0] = vmulq_laneq_f32(scales[0], d8, 0); - scales_y[1] = 
vmulq_laneq_f32(scales[1], d8, 1); - scales_y[2] = vmulq_laneq_f32(scales[2], d8, 2); - scales_y[3] = vmulq_laneq_f32(scales[3], d8, 3); - auto sumf = vdupq_n_f32(0.f); - for (int k = 0; k < 4; ++k) { - auto y = vld1q_s8_x2(by->qs+32*k); - auto sumi = vdupq_n_s32(0); - sumi = vdotq_laneq_s32(sumi, qx[2*k+0].val[0], y.val[0], 0); - sumi = vdotq_laneq_s32(sumi, qx[2*k+0].val[1], y.val[1], 0); - sumi = vdotq_laneq_s32(sumi, qx[2*k+0].val[2], y.val[0], 1); - sumi = vdotq_laneq_s32(sumi, qx[2*k+0].val[3], y.val[1], 1); - sumi = vdotq_laneq_s32(sumi, qx[2*k+1].val[0], y.val[0], 2); - sumi = vdotq_laneq_s32(sumi, qx[2*k+1].val[1], y.val[1], 2); - sumi = vdotq_laneq_s32(sumi, qx[2*k+1].val[2], y.val[0], 3); - sumi = vdotq_laneq_s32(sumi, qx[2*k+1].val[3], y.val[1], 3); - sumf = vfmaq_f32(sumf, scales_y[k], vcvtq_f32_s32(sumi)); - } - info.store(ix, iy, sumf); + info.store(ix+0, iy, acc[2*iy+0]); + info.store(ix+4, iy, acc[2*iy+1]); + acc[2*iy] = acc[2*iy+1] = vdupq_n_f32(0.f); } } } @@ -12763,22 +12867,25 @@ struct HelperQ80R4 : public BaseHelper<step> { Base::stride = (D/QK8_0)*sizeof(block_q8_0); } - static std::vector<block_q8_0_x4> repack(int nk, const HelperQ80<D, step> q8) { + static std::vector<block_q8_0_r8> repack(int nk, const HelperQ80<D, step> q8) { static_assert(D%QK8_0 == 0); - GGML_ASSERT(nk%4 == 0); + GGML_ASSERT(nk%8 == 0); constexpr int nblock = D/QK8_0; - std::vector<block_q8_0_x4> result(nblock * nk/4); + std::vector<block_q8_0_r8> result(nblock * nk/8); auto y = result.data(); - const block_q8_0 * x4[4]; - for (int row = 0; row < nk; row += 4) { - for (int k = 0; k < 4; ++k) x4[k] = (const block_q8_0 *)(q8.data + (row + k)*q8.stride); + const block_q8_0 * x8[8]; +#ifdef __ARM_NEON + int8x16x2_t m0, m1, m2, m3; +#endif + for (int row = 0; row < nk; row += 8) { + for (int k = 0; k < 8; ++k) x8[k] = (const block_q8_0 *)(q8.data + (row + k)*q8.stride); for (int ib = 0; ib < nblock; ++ib) { - for (int k = 0; k < 4; ++k) y[ib].d[k] = x4[k][ib].d; + for (int k = 0; k < 8; ++k) y[ib].d[k] = x8[k][ib].d; #ifdef __AVX2__ - auto m0 = _mm256_loadu_si256((const __m256i *)x4[0][ib].qs); - auto m1 = _mm256_loadu_si256((const __m256i *)x4[1][ib].qs); - auto m2 = _mm256_loadu_si256((const __m256i *)x4[2][ib].qs); - auto m3 = _mm256_loadu_si256((const __m256i *)x4[3][ib].qs); + auto m0 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[4][ib].qs), _mm_loadu_si128((const __m128i *)x8[0][ib].qs)); + auto m1 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[5][ib].qs), _mm_loadu_si128((const __m128i *)x8[1][ib].qs)); + auto m2 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[6][ib].qs), _mm_loadu_si128((const __m128i *)x8[2][ib].qs)); + auto m3 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[7][ib].qs), _mm_loadu_si128((const __m128i *)x8[3][ib].qs)); auto t0 = _mm256_unpacklo_epi32(m0, m1); auto t1 = _mm256_unpacklo_epi32(m2, m3); auto t2 = _mm256_unpackhi_epi32(m0, m1); @@ -12791,32 +12898,50 @@ struct HelperQ80R4 : public BaseHelper<step> { _mm256_storeu_si256((__m256i *)y[ib].qs + 1, m1); _mm256_storeu_si256((__m256i *)y[ib].qs + 2, m2); _mm256_storeu_si256((__m256i *)y[ib].qs + 3, m3); + m0 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[4][ib].qs+1), _mm_loadu_si128((const __m128i *)x8[0][ib].qs+1)); + m1 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[5][ib].qs+1), _mm_loadu_si128((const __m128i *)x8[1][ib].qs+1)); + m2 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[6][ib].qs+1), _mm_loadu_si128((const __m128i *)x8[2][ib].qs+1)); + m3 = 
MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[7][ib].qs+1), _mm_loadu_si128((const __m128i *)x8[3][ib].qs+1)); + t0 = _mm256_unpacklo_epi32(m0, m1); + t1 = _mm256_unpacklo_epi32(m2, m3); + t2 = _mm256_unpackhi_epi32(m0, m1); + t3 = _mm256_unpackhi_epi32(m2, m3); + m0 = _mm256_unpacklo_epi64(t0, t1); + m1 = _mm256_unpackhi_epi64(t0, t1); + m2 = _mm256_unpacklo_epi64(t2, t3); + m3 = _mm256_unpackhi_epi64(t2, t3); + _mm256_storeu_si256((__m256i *)y[ib].qs + 4, m0); + _mm256_storeu_si256((__m256i *)y[ib].qs + 5, m1); + _mm256_storeu_si256((__m256i *)y[ib].qs + 6, m2); + _mm256_storeu_si256((__m256i *)y[ib].qs + 7, m3); #elif defined __ARM_NEON - auto m0 = vld1q_s8_x2(x4[0][ib].qs); - auto m1 = vld1q_s8_x2(x4[1][ib].qs); - auto m2 = vld1q_s8_x2(x4[2][ib].qs); - auto m3 = vld1q_s8_x2(x4[3][ib].qs); - auto row01 = vtrnq_s32(vreinterpretq_s32_s8(m0.val[0]), vreinterpretq_s32_s8(m1.val[0])); - auto row23 = vtrnq_s32(vreinterpretq_s32_s8(m2.val[0]), vreinterpretq_s32_s8(m3.val[0])); - m0.val[0] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0]))); - m1.val[0] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1]))); - m2.val[0] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0]))); - m3.val[0] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1]))); - row01 = vtrnq_s32(vreinterpretq_s32_s8(m0.val[1]), vreinterpretq_s32_s8(m1.val[1])); - row23 = vtrnq_s32(vreinterpretq_s32_s8(m2.val[1]), vreinterpretq_s32_s8(m3.val[1])); - m0.val[1] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0]))); - m1.val[1] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1]))); - m2.val[1] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0]))); - m3.val[1] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1]))); - vst1q_s8_x2(y[ib].qs + 0, m0); - vst1q_s8_x2(y[ib].qs + 32, m1); - vst1q_s8_x2(y[ib].qs + 64, m2); - vst1q_s8_x2(y[ib].qs + 96, m3); + for (int l = 0; l < 2; ++l) { + m0.val[0] = vld1q_s8(x8[0][ib].qs+16*l); m0.val[1] = vld1q_s8(x8[4][ib].qs+16*l); + m1.val[0] = vld1q_s8(x8[1][ib].qs+16*l); m1.val[1] = vld1q_s8(x8[5][ib].qs+16*l); + m2.val[0] = vld1q_s8(x8[2][ib].qs+16*l); m2.val[1] = vld1q_s8(x8[6][ib].qs+16*l); + m3.val[0] = vld1q_s8(x8[3][ib].qs+16*l); m3.val[1] = vld1q_s8(x8[7][ib].qs+16*l); + auto row01 = vtrnq_s32(vreinterpretq_s32_s8(m0.val[0]), vreinterpretq_s32_s8(m1.val[0])); + auto row23 = vtrnq_s32(vreinterpretq_s32_s8(m2.val[0]), vreinterpretq_s32_s8(m3.val[0])); + m0.val[0] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0]))); + m1.val[0] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1]))); + m2.val[0] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0]))); + m3.val[0] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1]))); + row01 = vtrnq_s32(vreinterpretq_s32_s8(m0.val[1]), vreinterpretq_s32_s8(m1.val[1])); + row23 = vtrnq_s32(vreinterpretq_s32_s8(m2.val[1]), vreinterpretq_s32_s8(m3.val[1])); + m0.val[1] = 
vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0]))); + m1.val[1] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1]))); + m2.val[1] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0]))); + m3.val[1] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1]))); + vst1q_s8_x2(y[ib].qs + 0 + 128*l, m0); + vst1q_s8_x2(y[ib].qs + 32 + 128*l, m1); + vst1q_s8_x2(y[ib].qs + 64 + 128*l, m2); + vst1q_s8_x2(y[ib].qs + 96 + 128*l, m3); + } #else for (int l = 0; l < 4; ++l) { - for (int k = 0; k < 4; ++k) for (int i = 0; i < 4; ++i) { - y[ib].qs[32*l+4*k+i+ 0] = x4[k][ib].qs[i+4*l+ 0]; - y[ib].qs[32*l+4*k+i+16] = x4[k][ib].qs[i+4*l+16]; + for (int k = 0; k < 8; ++k) for (int i = 0; i < 4; ++i) { + y[ib].qs[32*l+4*k+i+ 0] = x8[k][ib].qs[i+4*l+ 0]; + y[ib].qs[32*l+4*k+i+128] = x8[k][ib].qs[i+4*l+16]; } } #endif @@ -12826,7 +12951,7 @@ struct HelperQ80R4 : public BaseHelper<step> { return result; } - std::vector<block_q8_0_x4> r4; + std::vector<block_q8_0_r8> r4; }; template <int D, int step> @@ -13370,78 +13495,6 @@ struct FlashQKV { qkv_cache_t qkv_cache[D*q_step] = {}; }; -#ifdef HAVE_FANCY_SIMD -template <int nrc_y> -static void mul_mat_q8_0_r4_q8_1_128([[maybe_unused]] int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { - GGML_ASSERT(nrc_x%8 == 0); - GGML_ASSERT(n == 128); - //Q8<nrc_y, block_q8_1_x4> q8(info); - __m512i qx[16]; - __m512 scales[4]; - __m512 scales_m[4]; - __m512 dy[4]; - auto m127 = _mm512_set1_epi8(127); - for (int ix = 0; ix < nrc_x; ix += 8) { - const block_q8_0_x4 * q8l = (const block_q8_0_x4 *)((const char *)vx + (ix+0)*bx); - const block_q8_0_x4 * q8h = (const block_q8_0_x4 *)((const char *)vx + (ix+4)*bx); - for (int k = 0; k < 4; ++k) { - auto scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)q8l[k].d)); - auto scales1 = _mm256_set_m128(scales128, scales128); - scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)q8h[k].d)); - auto scales2 = _mm256_set_m128(scales128, scales128); - scales[k] = _mm512_insertf32x8(_mm512_castps256_ps512(scales1), scales2, 1); - scales_m[k] = _mm512_mul_ps(scales[k], _mm512_set1_ps(-63.5f)); - qx[4*k+0] = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)q8l[k].qs+0)), - _mm256_loadu_si256((const __m256i *)q8h[k].qs+0), 1); - qx[4*k+1] = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)q8l[k].qs+1)), - _mm256_loadu_si256((const __m256i *)q8h[k].qs+1), 1); - qx[4*k+2] = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)q8l[k].qs+2)), - _mm256_loadu_si256((const __m256i *)q8h[k].qs+2), 1); - qx[4*k+3] = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)q8l[k].qs+3)), - _mm256_loadu_si256((const __m256i *)q8h[k].qs+3), 1); - qx[4*k+0] = _mm512_add_epi8(qx[4*k+0], m127); - qx[4*k+1] = _mm512_add_epi8(qx[4*k+1], m127); - qx[4*k+2] = _mm512_add_epi8(qx[4*k+2], m127); - qx[4*k+3] = _mm512_add_epi8(qx[4*k+3], m127); - } - for (int iy = 0; iy < nrc_y; ++iy) { - auto by = (const block_q8_1_x4 *)info.src1_row(iy); - //auto dall = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)q8.y[iy][0].d)); - auto dall = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)by->d)); - auto d128 = _mm256_castps256_ps128(dall); - auto m128 = _mm256_extractf128_ps(dall, 1); - auto m256 = _mm256_set_m128(m128, m128); - auto 
m512 = _mm512_insertf32x8(_mm512_castps256_ps512(m256), m256, 1); - auto sumf = _mm512_mul_ps(scales_m[0], _mm512_shuffle_ps(m512, m512, 0x00)); - sumf = _mm512_fmadd_ps(scales_m[1], _mm512_shuffle_ps(m512, m512, 0x55), sumf); - sumf = _mm512_fmadd_ps(scales_m[2], _mm512_shuffle_ps(m512, m512, 0xaa), sumf); - sumf = _mm512_fmadd_ps(scales_m[3], _mm512_shuffle_ps(m512, m512, 0xff), sumf); - auto d256 = _mm256_set_m128(d128, d128); - auto d512 = _mm512_insertf32x8(_mm512_castps256_ps512(d256), d256, 1); - dy[0] = _mm512_mul_ps(scales[0], _mm512_shuffle_ps(d512, d512, 0x00)); - dy[1] = _mm512_mul_ps(scales[1], _mm512_shuffle_ps(d512, d512, 0x55)); - dy[2] = _mm512_mul_ps(scales[2], _mm512_shuffle_ps(d512, d512, 0xaa)); - dy[3] = _mm512_mul_ps(scales[3], _mm512_shuffle_ps(d512, d512, 0xff)); - for (int k = 0; k < 4; ++k) { - //auto y8 = _mm256_loadu_si256((const __m256i*)q8.y[iy][0].qs+k); - auto y8 = _mm256_loadu_si256((const __m256i*)by->qs+k); - auto y = _mm512_inserti32x8(_mm512_castsi256_si512(y8), y8, 1); - auto sumi = _mm512_setzero_si512(); - sumi = _mm512_dpbusd_epi32(sumi, qx[4*k+0], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00))); - sumi = _mm512_dpbusd_epi32(sumi, qx[4*k+1], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55))); - sumi = _mm512_dpbusd_epi32(sumi, qx[4*k+2], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa))); - sumi = _mm512_dpbusd_epi32(sumi, qx[4*k+3], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff))); - sumf = _mm512_fmadd_ps(dy[k], _mm512_cvtepi32_ps(sumi), sumf); - } - auto sum1 = _mm_add_ps(_mm512_extractf32x4_ps(sumf, 0), _mm512_extractf32x4_ps(sumf, 1)); - auto sum2 = _mm_add_ps(_mm512_extractf32x4_ps(sumf, 2), _mm512_extractf32x4_ps(sumf, 3)); - info.store(ix+0, iy, sum1); - info.store(ix+4, iy, sum2); - } - } -} -#endif - template <int D, int q_step, int k_step> struct FlashQKfp32 { static_assert(D%F16::block_size == 0 && D <= 256); @@ -13706,45 +13759,10 @@ struct FlashQKfp32 { } else if constexpr (std::is_same_v<KHelper, HelperQ80R4<D, k_step>>) { #ifdef __aarch64__ - if constexpr (D == 128) { - if (q_step >= 64 && nq >= 64) { - return std::make_pair(mul_mat_q8_0_r4_q8_0_128<64>, 64); - } - else if (q_step >= 32 && nq >= 32) { - return std::make_pair(mul_mat_q8_0_r4_q8_0_128<32>, 32); - } - else if (q_step >= 16 && nq >= 16) { - return std::make_pair(mul_mat_q8_0_r4_q8_0_128<16>, 16); - } - else { - MAKE_FUNCS_ONLY_NRC(mul_mat_q8_0_r4_q8_0_128, nq); - } - } else { - MAKE_FUNCS_ONLY_NRC(mul_mat_q8_0_r4_q8_0, nq); - } - //MAKE_FUNCS_ONLY_NRC(mul_mat_q8_0_r4_q8_0, nq); -#else -#ifdef HAVE_FANCY_SIMD - if constexpr (D == 128) { - if (q_step >= 64 && nq >= 64) { - return std::make_pair(mul_mat_q8_0_r4_q8_1_128<64>, 64); - } - else if (q_step >= 32 && nq >= 32) { - return std::make_pair(mul_mat_q8_0_r4_q8_1_128<32>, 32); - } - else if (q_step >= 16 && nq >= 16) { - return std::make_pair(mul_mat_q8_0_r4_q8_1_128<16>, 16); - } - else { - MAKE_FUNCS_ONLY_NRC(mul_mat_q8_0_r4_q8_1_128, nq); - } - } else { - MAKE_FUNCS_ONLY_NRC(mul_mat_q8_0_r4_q8_1, nq); - } + MAKE_FUNCS_ONLY_NRC(mul_mat_q8_0_r4_q8_0, nq); #else MAKE_FUNCS_ONLY_NRC(mul_mat_q8_0_r4_q8_1, nq); #endif -#endif } else if constexpr (std::is_same_v<KHelper, HelperQ41<D, k_step>>) { #ifdef __aarch64__ diff --git a/ggml/src/iqk/iqk_quantize.cpp b/ggml/src/iqk/iqk_quantize.cpp index 221bc48c..59a36c5c 100644 --- a/ggml/src/iqk/iqk_quantize.cpp +++ b/ggml/src/iqk/iqk_quantize.cpp @@ -3709,63 +3709,63 @@ void vec_dot_q4_0_r4_q8_0(int n, float * s, size_t bs, const void * vx, size_t b // // 
========================================= q8_0_r4 // -void quantize_row_q8_0_r4_ref(const float * x, block_q8_0_x4 * y, int64_t k) { +void quantize_row_q8_0_r4_ref(const float * x, block_q8_0_r8 * y, int64_t k) { // we assume we are called with 4 rows - quantize_q8_0_r4(x, (void *)y, 4, k/4, nullptr); + quantize_q8_0_r4(x, (void *)y, 8, k/8, nullptr); } void quantize_row_q8_0_r4(const float * x, void * y, int64_t k) { // we assume we are called with 4 rows - quantize_q8_0_r4(x, y, 4, k/4, nullptr); + quantize_q8_0_r4(x, y, 8, k/8, nullptr); } -static void repack_q8_0(int nrows, int n_per_row, const block_q8_0 * x, block_q8_0_x4 * y) { - GGML_ASSERT(nrows%4 == 0); +static void repack_q8_0(int nrows, int n_per_row, const block_q8_0 * x, block_q8_0_r8 * y) { + GGML_ASSERT(nrows%8 == 0); GGML_ASSERT(n_per_row%QK8_0 == 0); int nblock = n_per_row/QK8_0; - const block_q8_0 * x4[4]; - for (int row = 0; row < nrows; row += 4) { - for (int k = 0; k < 4; ++k) x4[k] = x + nblock*k; + const block_q8_0 * x8[8]; + for (int row = 0; row < nrows; row += 8) { + for (int k = 0; k < 8; ++k) x8[k] = x + nblock*k; for (int ib = 0; ib < nblock; ++ib) { - for (int k = 0; k < 4; ++k) y[ib].d[k] = x4[k][ib].d; + for (int k = 0; k < 8; ++k) y[ib].d[k] = x8[k][ib].d; for (int l = 0; l < 4; ++l) { - for (int k = 0; k < 4; ++k) for (int i = 0; i < 4; ++i) { - y[ib].qs[32*l+4*k+i+ 0] = x4[k][ib].qs[i+4*l+ 0]; - y[ib].qs[32*l+4*k+i+16] = x4[k][ib].qs[i+4*l+16]; + for (int k = 0; k < 8; ++k) for (int i = 0; i < 4; ++i) { + y[ib].qs[32*l+4*k+i+ 0] = x8[k][ib].qs[i+4*l+ 0]; + y[ib].qs[32*l+4*k+i+128] = x8[k][ib].qs[i+4*l+16]; } } } - x += 4*nblock; + x += 8*nblock; y += nblock; } } size_t quantize_q8_0_r4(const float * src, void * dst, int64_t nrows, int64_t n_per_row, const float * imatrix) { - GGML_ASSERT(nrows%4 == 0); + GGML_ASSERT(nrows%8 == 0); auto row_size_0 = ggml_row_size(GGML_TYPE_Q8_0, n_per_row); - std::vector<char> qtmp(4*row_size_0); + std::vector<char> qtmp(8*row_size_0); char * qrow = (char *)dst; - for (int row = 0; row < nrows; row += 4) { - quantize_q8_0(src, qtmp.data(), 4, n_per_row, imatrix); - repack_q8_0(4, n_per_row, (const block_q8_0 *)qtmp.data(), (block_q8_0_x4 *)qrow); - src += 4*n_per_row; - qrow += 4*row_size_0; + for (int row = 0; row < nrows; row += 8) { + quantize_q8_0(src, qtmp.data(), 8, n_per_row, imatrix); + repack_q8_0(8, n_per_row, (const block_q8_0 *)qtmp.data(), (block_q8_0_r8 *)qrow); + src += 8*n_per_row; + qrow += 8*row_size_0; } return nrows*row_size_0; } -void dequantize_row_q8_0_r4(const block_q8_0_x4 * x, float * y, int64_t k) { +void dequantize_row_q8_0_r4(const block_q8_0_r8 * x, float * y, int64_t k) { // we assume we are called with 4 rows - int n_per_row = k/4; + int n_per_row = k/8; int nb = n_per_row/QK8_0; - float * yk[4]; - for (int k = 0; k < 4; ++k) yk[k] = y + k*n_per_row; + float * yk[8]; + for (int k = 0; k < 8; ++k) yk[k] = y + k*n_per_row; for (int ib = 0; ib < nb; ++ib) { - for (int k = 0; k < 4; ++k) { + for (int k = 0; k < 8; ++k) { float scale = GGML_FP16_TO_FP32(x[ib].d[k]); for (int l = 0; l < 4; ++l) for (int i = 0; i < 4; ++i) { - yk[k][QK8_0*ib+4*l+i+ 0] = scale * x[ib].qs[QK8_0*l+4*k+i+ 0]; - yk[k][QK8_0*ib+4*l+i+16] = scale * x[ib].qs[QK8_0*l+4*k+i+16]; + yk[k][QK8_0*ib+4*l+i+ 0] = scale * x[ib].qs[32*l+4*k+i+ 0]; + yk[k][QK8_0*ib+4*l+i+16] = scale * x[ib].qs[32*l+4*k+i+128]; } } } @@ -3987,93 +3987,77 @@ void vec_dot_q6_0_r4_q8_0(int n, float * s, size_t bs, const void * vx, size_t b // void quantize_row_iq4_xs_r4_ref(const float * x, 
block_iq4_xs_r4 * y, int64_t k) { - quantize_iq4_xs_r4(x, (void *)y, 4, k/4, nullptr); + quantize_iq4_xs_r4(x, (void *)y, 8, k/8, nullptr); } void quantize_row_iq4_xs_r4(const float * x, void * y, int64_t k) { - quantize_iq4_xs_r4(x, y, 4, k/4, nullptr); + quantize_iq4_xs_r4(x, y, 8, k/8, nullptr); } static void repack_iq4_xs(int nrows, int n_per_row, const block_iq4_xs * x, block_iq4_xs_r4 * y) { - GGML_ASSERT(nrows%4 == 0); + GGML_ASSERT(nrows%8 == 0); GGML_ASSERT(n_per_row%QK_K == 0); int nblock = n_per_row/QK_K; - const block_iq4_xs * x4[4]; - for (int row = 0; row < nrows; row += 4) { - for (int k = 0; k < 4; ++k) x4[k] = x + nblock*k; + const block_iq4_xs * x8[8]; + for (int row = 0; row < nrows; row += 8) { + for (int k = 0; k < 8; ++k) x8[k] = x + nblock*k; for (int ibl = 0; ibl < nblock; ++ibl) { - std::memset(y[ibl].scales_l, 0, QK_K/16); - std::memset(y[ibl].scales_h, 0, QK_K/32); - for (int k = 0; k < 4; ++k) { - y[ibl].d[k] = x4[k][ibl].d; + std::memset(y[ibl].scales_l, 0, QK_K/8); + std::memset(y[ibl].scales_h, 0, QK_K/16); + for (int k = 0; k < 8; ++k) { + y[ibl].d[k] = x8[k][ibl].d; for (int ib = 0; ib < QK_K/32; ++ib) { - uint8_t sl = (x4[k][ibl].scales_l[ib/2] >> 4*(ib%2)) & 0xf; - uint8_t sh = (x4[k][ibl].scales_h >> 2*ib) & 3; - int i = 4*ib + k; - y[ibl].scales_l[i%16] |= (sl << 4*(i/16)); - y[ibl].scales_h[i%8 ] |= (sh << 2*(i/8)); - } - } - for (int ib = 0; ib < QK_K/32; ++ib) { - for (int k = 0; k < 4; ++k) for (int i = 0; i < 4; ++i) { - y[ibl].qs[64*ib+4*k+i+ 0] = (x4[k][ibl].qs[16*ib+i+0] & 0xf) | ((x4[k][ibl].qs[16*ib+i+ 8] & 0x0f) << 4); // 0....3 + 8...11 from each row - y[ibl].qs[64*ib+4*k+i+16] = (x4[k][ibl].qs[16*ib+i+0] >> 4) | ((x4[k][ibl].qs[16*ib+i+ 8] & 0xf0)); // 16...19 + 24...27 from each row - y[ibl].qs[64*ib+4*k+i+32] = (x4[k][ibl].qs[16*ib+i+4] & 0xf) | ((x4[k][ibl].qs[16*ib+i+12] & 0x0f) << 4); // 4....7 + 12...15 from each row - y[ibl].qs[64*ib+4*k+i+48] = (x4[k][ibl].qs[16*ib+i+4] >> 4) | ((x4[k][ibl].qs[16*ib+i+12] & 0xf0)); // 20...23 + 28...31 from each row + uint8_t sl = (x8[k][ibl].scales_l[ib/2] >> 4*(ib%2)) & 0xf; + uint8_t sh = (x8[k][ibl].scales_h >> 2*ib) & 3; + int i = 8*ib + k; + y[ibl].scales_l[i%32] |= (sl << 4*(i/32)); + y[ibl].scales_h[i%16] |= (sh << 2*(i/16)); + for (int i = 0; i < 4; ++i) { + y[ibl].qs[128*ib+4*k+i+ 0] = (x8[k][ibl].qs[16*ib+i+0] & 0xf) | ((x8[k][ibl].qs[16*ib+i+ 4] & 0xf) << 4); + y[ibl].qs[128*ib+4*k+i+32] = (x8[k][ibl].qs[16*ib+i+8] & 0xf) | ((x8[k][ibl].qs[16*ib+i+12] & 0xf) << 4); + y[ibl].qs[128*ib+4*k+i+64] = (x8[k][ibl].qs[16*ib+i+0] >> 4) | ((x8[k][ibl].qs[16*ib+i+ 4] >> 4) << 4); + y[ibl].qs[128*ib+4*k+i+96] = (x8[k][ibl].qs[16*ib+i+8] >> 4) | ((x8[k][ibl].qs[16*ib+i+12] >> 4) << 4); + } } } } - x += 4*nblock; + x += 8*nblock; y += nblock; } } size_t quantize_iq4_xs_r4(const float * src, void * dst, int64_t nrows, int64_t n_per_row, const float * imatrix) { - GGML_ASSERT(nrows%4 == 0); + GGML_ASSERT(nrows%8 == 0); GGML_ASSERT(n_per_row%QK_K == 0); char * qcur = (char *)dst; auto row_size = ggml_row_size(GGML_TYPE_IQ4_XS, n_per_row); - std::vector<char> qtmp(4*row_size); - for (int row = 0; row < nrows; row += 4) { - quantize_iq4_xs(src, (void *)qtmp.data(), 4, n_per_row, imatrix); - repack_iq4_xs(4, n_per_row, (const block_iq4_xs *)qtmp.data(), (block_iq4_xs_r4 *)qcur); - qcur += 4*row_size; - src += 4*n_per_row; + std::vector<char> qtmp(8*row_size); + for (int row = 0; row < nrows; row += 8) { + quantize_iq4_xs(src, (void *)qtmp.data(), 8, n_per_row, imatrix); + repack_iq4_xs(8, n_per_row, (const 
block_iq4_xs *)qtmp.data(), (block_iq4_xs_r4 *)qcur); + qcur += 8*row_size; + src += 8*n_per_row; } return nrows*row_size; } void dequantize_row_iq4_xs_r4(const block_iq4_xs_r4 * x, float * y, int64_t k) { - auto n_per_row = k/4; - float * y4[4] = {y, y + n_per_row, y + 2*n_per_row, y + 3*n_per_row}; + auto n_per_row = k/8; + float * y8[8]; + for (int k = 0; k < 8; ++k) y8[k] = y + n_per_row*k; int nblock = n_per_row/QK_K; for (int ibl = 0; ibl < nblock; ++ibl) { - for (int k = 0; k < 4; ++k) { + for (int k = 0; k < 8; ++k) { const float d = GGML_FP16_TO_FP32(x[ibl].d[k]); for (int ib = 0; ib < QK_K/32; ++ib) { - int is = 4*ib + k; - float dl = d * ((((x[ibl].scales_l[is%16] >> 4*(is/16)) & 0xf) | (((x[ibl].scales_h[is%8] >> 2*(is/8)) & 3) << 4)) - 32); - for (int i = 0; i < 4; ++i) { - y4[k][QK_K*ibl+32*ib+i+ 0] = dl * iq4k_values[x[ibl].qs[64*ib+4*k+i+ 0] & 0xf]; - y4[k][QK_K*ibl+32*ib+i+ 8] = dl * iq4k_values[x[ibl].qs[64*ib+4*k+i+ 0] >> 4]; - y4[k][QK_K*ibl+32*ib+i+16] = dl * iq4k_values[x[ibl].qs[64*ib+4*k+i+16] & 0xf]; - y4[k][QK_K*ibl+32*ib+i+24] = dl * iq4k_values[x[ibl].qs[64*ib+4*k+i+16] >> 4]; - y4[k][QK_K*ibl+32*ib+i+ 4] = dl * iq4k_values[x[ibl].qs[64*ib+4*k+i+32] & 0xf]; - y4[k][QK_K*ibl+32*ib+i+12] = dl * iq4k_values[x[ibl].qs[64*ib+4*k+i+32] >> 4]; - y4[k][QK_K*ibl+32*ib+i+20] = dl * iq4k_values[x[ibl].qs[64*ib+4*k+i+48] & 0xf]; - y4[k][QK_K*ibl+32*ib+i+28] = dl * iq4k_values[x[ibl].qs[64*ib+4*k+i+48] >> 4]; + int is = 8*ib + k; + float dl = d * ((((x[ibl].scales_l[is%32] >> 4*(is/32)) & 0xf) | (((x[ibl].scales_h[is%16] >> 2*(is/16)) & 3) << 4)) - 32); + for (int l = 0; l < 4; ++l) for (int i = 0; i < 4; ++i) { + y8[k][QK_K*ibl+32*ib+8*l+i+0] = dl * iq4k_values[x[ibl].qs[128*ib+4*k+i+32*l] & 0xf]; + y8[k][QK_K*ibl+32*ib+8*l+i+4] = dl * iq4k_values[x[ibl].qs[128*ib+4*k+i+32*l] >> 4]; } } } - //dequantize_row_iq4_xs(x + ib, ytmp, QK_K); - //for (int k = 0; k < 4; ++k) { - // for (int l = 0; l < 16; ++l) { - // for (int i = 0; i < 4; ++i) { - // //y4[k][ib*kBlockSize + i + 16*(l%4) + 4*(l/4)] = ytmp[16*l + 4*k + i]; - // y4[k][ib*kBlockSize + i + 8*(l%8) + 4*(l/8)] = ytmp[16*l + 4*k + i]; - // } - // } - //} } } @@ -6063,7 +6047,7 @@ void iqk_repack_tensor(struct ggml_tensor * tensor) { { GGML_TYPE_IQ3_K, { GGML_TYPE_IQ3_K_R4, 4, (Repack::repack_func)repack_iq3_k} }, { GGML_TYPE_IQ4_K, { GGML_TYPE_IQ4_K_R4, 4, (Repack::repack_func)repack_iq4_k} }, { GGML_TYPE_IQ5_K, { GGML_TYPE_IQ5_K_R4, 4, (Repack::repack_func)repack_iq5_k} }, - { GGML_TYPE_IQ4_XS, { GGML_TYPE_IQ4_XS_R4, 4, (Repack::repack_func)repack_iq4_xs} }, + { GGML_TYPE_IQ4_XS, { GGML_TYPE_IQ4_XS_R4, 8, (Repack::repack_func)repack_iq4_xs} }, { GGML_TYPE_IQ4_KS, { GGML_TYPE_IQ4_KS_R4, 4, (Repack::repack_func)repack_iq4_ks} }, { GGML_TYPE_IQ4_NL, { GGML_TYPE_IQ4_NL_R4, 4, (Repack::repack_func)repack_iq4_nl} }, { GGML_TYPE_IQ2_BN, { GGML_TYPE_IQ2_BN_R4, 4, (Repack::repack_func)repack_iq2_bn} }, @@ -6080,7 +6064,7 @@ void iqk_repack_tensor(struct ggml_tensor * tensor) { { GGML_TYPE_Q4_0, { GGML_TYPE_Q4_0_R4, 4, (Repack::repack_func)repack_q4_0} }, { GGML_TYPE_Q5_0, { GGML_TYPE_Q5_0_R4, 4, (Repack::repack_func)repack_q5_0} }, { GGML_TYPE_Q6_0, { GGML_TYPE_Q6_0_R4, 4, (Repack::repack_func)repack_q6_0} }, - { GGML_TYPE_Q8_0, { GGML_TYPE_Q8_0_R4, 4, (Repack::repack_func)repack_q8_0} }, + { GGML_TYPE_Q8_0, { GGML_TYPE_Q8_0_R4, 8, (Repack::repack_func)repack_q8_0} }, { GGML_TYPE_Q8_K, { GGML_TYPE_Q8_K_R8, 8, (Repack::repack_func)repack_q8_k} }, #ifdef __AVX512BF16__ { GGML_TYPE_BF16, { GGML_TYPE_BF16_R16, 16, 
(Repack::repack_func)repack_bf16<ggml_bf16_t>}}, diff --git a/ggml/src/iqk/iqk_quantize.h b/ggml/src/iqk/iqk_quantize.h index 729b0ec0..64860b4d 100644 --- a/ggml/src/iqk/iqk_quantize.h +++ b/ggml/src/iqk/iqk_quantize.h @@ -73,10 +73,10 @@ size_t quantize_q4_0_r4(const float * GGML_RESTRICT src, void * GGML_RESTRICT ds void dequantize_row_q4_0_r4(const block_iq4_nl_r4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); void vec_dot_q4_0_r4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); -void quantize_row_q8_0_r4_ref(const float * GGML_RESTRICT x, block_q8_0_x4 * GGML_RESTRICT y, int64_t k); +void quantize_row_q8_0_r4_ref(const float * GGML_RESTRICT x, block_q8_0_r8 * GGML_RESTRICT y, int64_t k); void quantize_row_q8_0_r4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); size_t quantize_q8_0_r4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); -void dequantize_row_q8_0_r4(const block_q8_0_x4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +void dequantize_row_q8_0_r4(const block_q8_0_r8 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); void vec_dot_q8_0_r4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void quantize_row_q5_0_r4_ref(const float * GGML_RESTRICT x, block_q5_0_r4 * GGML_RESTRICT y, int64_t k); diff --git a/src/llama.cpp b/src/llama.cpp index c2bc5cc0..836fd97a 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -16906,8 +16906,8 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s else chunk_size_multiplier = 4; } else if (new_type == GGML_TYPE_IQ4_XS_R4) { - if (tensor->ne[1] % 4 != 0) new_type = GGML_TYPE_IQ4_XS; - else chunk_size_multiplier = 4; + if (tensor->ne[1] % 8 != 0) new_type = GGML_TYPE_IQ4_XS; + else chunk_size_multiplier = 8; } else if (new_type == GGML_TYPE_Q4_0_R4) { if (tensor->ne[1] % 4 != 0) new_type = GGML_TYPE_Q4_0; @@ -16922,8 +16922,8 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s else chunk_size_multiplier = 4; } else if (new_type == GGML_TYPE_Q8_0_R4) { - if (tensor->ne[1] % 4 != 0) new_type = GGML_TYPE_Q8_0; - else chunk_size_multiplier = 4; + if (tensor->ne[1] % 8 != 0) new_type = GGML_TYPE_Q8_0; + else chunk_size_multiplier = 8; } else if (new_type == GGML_TYPE_Q2_K_R4) { if (tensor->ne[1] % 4 != 0) new_type = GGML_TYPE_Q2_K; |