summaryrefslogtreecommitdiff
path: root/ggml
diff options
context:
space:
mode:
Diffstat (limited to 'ggml')
-rw-r--r--ggml/src/iqk/iqk_mul_mat.cpp494
1 files changed, 339 insertions, 155 deletions
diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp
index 8d2b4090..308d0dca 100644
--- a/ggml/src/iqk/iqk_mul_mat.cpp
+++ b/ggml/src/iqk/iqk_mul_mat.cpp
@@ -111,6 +111,15 @@ struct Perf {
#define IQK_ALWAYS_INLINE __attribute__((__always_inline__))
#endif
+#if defined __x86_64__
+#if defined HAVE_FANCY_SIMD
+ #undef HAVE_FANCY_SIMD
+#endif
+#if defined(__AVX512F__) && defined(__AVX512VNNI__) && defined(__AVX512VL__) && defined(__AVX512BW__) && defined(__AVX512DQ__)
+ #define HAVE_FANCY_SIMD
+#endif
+#endif
+
namespace {
typedef struct {
@@ -236,6 +245,35 @@ struct MulMat {
}
static bool prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny);
static inline int num_rows(ggml_type type) {
+#ifdef HAVE_FANCY_SIMD
+ switch (type) {
+ case GGML_TYPE_Q2_K_R4:
+ case GGML_TYPE_Q3_K_R4:
+ case GGML_TYPE_Q6_K_R4:
+ case GGML_TYPE_IQ2_K_R4:
+ case GGML_TYPE_IQ3_K_R4:
+ case GGML_TYPE_IQ4_K_R4:
+ case GGML_TYPE_IQ5_K_R4:
+ case GGML_TYPE_IQ4_KS_R4:
+ case GGML_TYPE_IQ2_XXS_R4:
+ case GGML_TYPE_IQ2_XS_R4:
+ case GGML_TYPE_IQ2_S_R4:
+ case GGML_TYPE_IQ3_XXS_R4:
+ case GGML_TYPE_IQ3_S_R4: return 4;
+ case GGML_TYPE_IQ4_NL_R4:
+ case GGML_TYPE_Q5_0_R4:
+ case GGML_TYPE_Q6_0_R4:
+ case GGML_TYPE_IQ2_BN_R4:
+ case GGML_TYPE_IQ4_XS_R4:
+ case GGML_TYPE_Q4_K_R4:
+ case GGML_TYPE_Q5_K_R4:
+ case GGML_TYPE_Q8_K_R8: return 8;
+ case GGML_TYPE_Q4_0_R4:
+ case GGML_TYPE_Q8_0_R4:
+ case GGML_TYPE_BF16_R16: return 16;
+ default: return 1;
+ }
+#else
switch (type) {
case GGML_TYPE_Q2_K_R4:
case GGML_TYPE_Q3_K_R4:
@@ -263,6 +301,7 @@ struct MulMat {
case GGML_TYPE_BF16_R16: return 16;
default: return 1;
}
+#endif
}
private:
template <typename Dequantizer> static void set_functions(MulMat& m);
@@ -377,13 +416,6 @@ const uint64_t keven_signs[128] = {
#if defined __x86_64__
-#if defined HAVE_FANCY_SIMD
- #undef HAVE_FANCY_SIMD
-#endif
-#if defined(__AVX512F__) && defined(__AVX512VNNI__) && defined(__AVX512VL__) && defined(__AVX512BW__) && defined(__AVX512DQ__)
- #define HAVE_FANCY_SIMD
-#endif
-
namespace {
inline float hsum_float_4(__m128 x) {
@@ -2608,6 +2640,15 @@ static void mul_mat_q4_0_r4_q8_1_avx2(int n, const void * vx, size_t bx, const D
acc2 = _mm256_fmadd_ps(scales, _mm256_set1_ps(helper.val[k+4]), acc2);
}
}
+ for (int ib = 4*(nb/4); ib < nb; ++ib) {
+ auto qy = (const block_q8_1 *)q8.y[0];
+ auto scales = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq4[ib].d));
+ prepare_q4_0_quants_avx2(iq4[ib].qs, v, m4);
+ auto sumi = accum_q4_0_quants(v, qy[ib].qs);
+ auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(GGML_FP16_TO_FP32(qy[ib].d)));
+ acc1 = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc1);
+ acc2 = _mm256_fmadd_ps(scales, _mm256_set1_ps(GGML_FP16_TO_FP32(qy[ib].s)), acc2);
+ }
acc1 = _mm256_fmadd_ps(acc2, _mm256_set1_ps(-8.f), acc1);
info.store(ix, 0, acc1);
}
@@ -2645,6 +2686,18 @@ static void mul_mat_q4_0_r4_q8_1_avx2(int n, const void * vx, size_t bx, const D
}
}
}
+ for (int ib = 4*(nb/4); ib < nb; ++ib) {
+ auto scales = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq4[ib].d));
+ auto scales_m = _mm256_mul_ps(scales, _mm256_set1_ps(-8.f));
+ prepare_q4_0_quants_avx2(iq4[ib].qs, v, m4);
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto qy = (const block_q8_1 *)q8.y[iy];
+ auto sumi = accum_q4_0_quants(v, qy[ib].qs);
+ auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(GGML_FP16_TO_FP32(qy[ib].d)));
+ acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[iy]);
+ acc[iy] = _mm256_fmadd_ps(scales_m, _mm256_set1_ps(GGML_FP16_TO_FP32(qy[ib].s)), acc[iy]);
+ }
+ }
for (int iy = 0; iy < nrc_y; ++iy) {
info.store(ix, iy, acc[iy]);
acc[iy] = _mm256_setzero_ps();
@@ -2664,9 +2717,38 @@ static void mul_mat_q4_0_r4_q8_1(int n, const void * vx, size_t bx, const DataIn
Q8<nrc_y, block_q8_1_x4> q8(info);
auto m4 = _mm512_set1_epi8(0xf);
int nb = n / QK4_NL;
- GGML_ASSERT(nb%4 == 0);
__m512 acc[2*nrc_y] = {};
__m512i qx[8];
+ auto prepare = [&qx, &m4] (const block_iq4_nl_r8& iq4l, const block_iq4_nl_r8& iq4h) {
+ auto scales1 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq4l.d));
+ auto scales2 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq4h.d));
+ auto scales = _mm512_insertf32x8(_mm512_castps256_ps512(scales1), scales2, 1);
+ for (int j = 0; j < 4; ++j) {
+ auto bits = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq4l.qs+j)),
+ _mm256_loadu_si256((const __m256i *)iq4h.qs+j), 1);
+ qx[j+0] = _mm512_and_si512(bits, m4);
+ qx[j+4] = _mm512_and_si512(_mm512_srli_epi16(bits, 4), m4);
+ }
+ return scales;
+ };
+ auto dot = [&qx] (const int8_t * qy) {
+ auto y4l = _mm_loadu_si128((const __m128i*)qy+0);
+ auto y4h = _mm_loadu_si128((const __m128i*)qy+1);
+ auto y8l = MM256_SET_M128I(y4l, y4l);
+ auto y8h = MM256_SET_M128I(y4h, y4h);
+ auto yl = _mm512_inserti32x8(_mm512_castsi256_si512(y8l), y8l, 1);
+ auto yh = _mm512_inserti32x8(_mm512_castsi256_si512(y8h), y8h, 1);
+ auto sumi = _mm512_setzero_si512();
+ sumi = _mm512_dpbusd_epi32(sumi, qx[0], _mm512_shuffle_epi32(yl, _MM_PERM_ENUM(0x00)));
+ sumi = _mm512_dpbusd_epi32(sumi, qx[1], _mm512_shuffle_epi32(yl, _MM_PERM_ENUM(0x55)));
+ sumi = _mm512_dpbusd_epi32(sumi, qx[2], _mm512_shuffle_epi32(yl, _MM_PERM_ENUM(0xaa)));
+ sumi = _mm512_dpbusd_epi32(sumi, qx[3], _mm512_shuffle_epi32(yl, _MM_PERM_ENUM(0xff)));
+ sumi = _mm512_dpbusd_epi32(sumi, qx[4], _mm512_shuffle_epi32(yh, _MM_PERM_ENUM(0x00)));
+ sumi = _mm512_dpbusd_epi32(sumi, qx[5], _mm512_shuffle_epi32(yh, _MM_PERM_ENUM(0x55)));
+ sumi = _mm512_dpbusd_epi32(sumi, qx[6], _mm512_shuffle_epi32(yh, _MM_PERM_ENUM(0xaa)));
+ sumi = _mm512_dpbusd_epi32(sumi, qx[7], _mm512_shuffle_epi32(yh, _MM_PERM_ENUM(0xff)));
+ return sumi;
+ };
float d8[8*nrc_y];
for (int ix = 0; ix < nrc_x; ix += 16) {
const block_iq4_nl_r8 * iq4l = (const block_iq4_nl_r8 *)((const char *)vx + (ix+0)*bx);
@@ -2676,47 +2758,25 @@ static void mul_mat_q4_0_r4_q8_1(int n, const void * vx, size_t bx, const DataIn
_mm256_storeu_ps(d8+8*iy, _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)q8.y[iy][ib4].d)));
}
for (int k = 0; k < 4; ++k) {
- auto scales1 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq4l[4*ib4+k].d));
- auto scales2 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq4h[4*ib4+k].d));
- auto scales = _mm512_insertf32x8(_mm512_castps256_ps512(scales1), scales2, 1);
- auto bits1 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq4l[4*ib4+k].qs+0)),
- _mm256_loadu_si256((const __m256i *)iq4h[4*ib4+k].qs+0), 1);
- auto bits2 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq4l[4*ib4+k].qs+1)),
- _mm256_loadu_si256((const __m256i *)iq4h[4*ib4+k].qs+1), 1);
- auto bits3 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq4l[4*ib4+k].qs+2)),
- _mm256_loadu_si256((const __m256i *)iq4h[4*ib4+k].qs+2), 1);
- auto bits4 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq4l[4*ib4+k].qs+3)),
- _mm256_loadu_si256((const __m256i *)iq4h[4*ib4+k].qs+3), 1);
- qx[0] = _mm512_and_si512(bits1, m4);
- qx[1] = _mm512_and_si512(bits2, m4);
- qx[2] = _mm512_and_si512(bits3, m4);
- qx[3] = _mm512_and_si512(bits4, m4);
- qx[4] = _mm512_and_si512(_mm512_srli_epi16(bits1, 4), m4);
- qx[5] = _mm512_and_si512(_mm512_srli_epi16(bits2, 4), m4);
- qx[6] = _mm512_and_si512(_mm512_srli_epi16(bits3, 4), m4);
- qx[7] = _mm512_and_si512(_mm512_srli_epi16(bits4, 4), m4);
+ auto scales = prepare(iq4l[4*ib4+k], iq4h[4*ib4+k]);
for (int iy = 0; iy < nrc_y; ++iy) {
- auto y4l = _mm_loadu_si128((const __m128i*)q8.y[iy][ib4].qs+2*k+0);
- auto y4h = _mm_loadu_si128((const __m128i*)q8.y[iy][ib4].qs+2*k+1);
- auto y8l = MM256_SET_M128I(y4l, y4l);
- auto y8h = MM256_SET_M128I(y4h, y4h);
- auto yl = _mm512_inserti32x8(_mm512_castsi256_si512(y8l), y8l, 1);
- auto yh = _mm512_inserti32x8(_mm512_castsi256_si512(y8h), y8h, 1);
- auto sumi = _mm512_setzero_si512();
- sumi = _mm512_dpbusd_epi32(sumi, qx[0], _mm512_shuffle_epi32(yl, _MM_PERM_ENUM(0x00)));
- sumi = _mm512_dpbusd_epi32(sumi, qx[1], _mm512_shuffle_epi32(yl, _MM_PERM_ENUM(0x55)));
- sumi = _mm512_dpbusd_epi32(sumi, qx[2], _mm512_shuffle_epi32(yl, _MM_PERM_ENUM(0xaa)));
- sumi = _mm512_dpbusd_epi32(sumi, qx[3], _mm512_shuffle_epi32(yl, _MM_PERM_ENUM(0xff)));
- sumi = _mm512_dpbusd_epi32(sumi, qx[4], _mm512_shuffle_epi32(yh, _MM_PERM_ENUM(0x00)));
- sumi = _mm512_dpbusd_epi32(sumi, qx[5], _mm512_shuffle_epi32(yh, _MM_PERM_ENUM(0x55)));
- sumi = _mm512_dpbusd_epi32(sumi, qx[6], _mm512_shuffle_epi32(yh, _MM_PERM_ENUM(0xaa)));
- sumi = _mm512_dpbusd_epi32(sumi, qx[7], _mm512_shuffle_epi32(yh, _MM_PERM_ENUM(0xff)));
+ auto sumi = dot(q8.y[iy][ib4].qs+32*k);
auto dy = _mm512_set1_ps(d8[8*iy+k]);
acc[2*iy+0] = _mm512_fmadd_ps(_mm512_mul_ps(scales, dy), _mm512_cvtepi32_ps(sumi), acc[2*iy+0]);
acc[2*iy+1] = _mm512_fmadd_ps(scales, _mm512_set1_ps(d8[8*iy+k+4]), acc[2*iy+1]);
}
}
}
+ for (int ib = 4*(nb/4); ib < nb; ++ib) {
+ auto scales = prepare(iq4l[ib], iq4h[ib]);
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto qy = (const block_q8_1 *)q8.y[iy];
+ auto sumi = dot(qy[ib].qs);
+ auto dy = _mm512_set1_ps(GGML_FP16_TO_FP32(qy[ib].d));
+ acc[2*iy+0] = _mm512_fmadd_ps(_mm512_mul_ps(scales, dy), _mm512_cvtepi32_ps(sumi), acc[2*iy+0]);
+ acc[2*iy+1] = _mm512_fmadd_ps(scales, _mm512_set1_ps(GGML_FP16_TO_FP32(qy[ib].s)), acc[2*iy+1]);
+ }
+ }
for (int iy = 0; iy < nrc_y; ++iy) {
auto sum = _mm512_fmadd_ps(_mm512_set1_ps(-8.f), acc[2*iy+1], acc[2*iy+0]);
acc[2*iy+0] = acc[2*iy+1] = _mm512_setzero_ps();
@@ -2981,12 +3041,56 @@ static void mul_mat_q6_0_r4_q8_1(int n, const void * vx, size_t bx, const DataIn
#endif
#ifdef HAVE_FANCY_SIMD
+inline __m512i qx_r8_q8_dot_product(const __m512i * qx, const int8_t * y) {
+ auto y4l = _mm_loadu_si128((const __m128i*)y+0);
+ auto y4h = _mm_loadu_si128((const __m128i*)y+1);
+ auto y8l = MM256_SET_M128I(y4l, y4l);
+ auto y8h = MM256_SET_M128I(y4h, y4h);
+ auto yl = _mm512_inserti32x8(_mm512_castsi256_si512(y8l), y8l, 1);
+ auto yh = _mm512_inserti32x8(_mm512_castsi256_si512(y8h), y8h, 1);
+ auto sumi = _mm512_setzero_si512();
+ sumi = _mm512_dpbusd_epi32(sumi, qx[0], _mm512_shuffle_epi32(yl, _MM_PERM_ENUM(0x00)));
+ sumi = _mm512_dpbusd_epi32(sumi, qx[1], _mm512_shuffle_epi32(yl, _MM_PERM_ENUM(0x55)));
+ sumi = _mm512_dpbusd_epi32(sumi, qx[2], _mm512_shuffle_epi32(yl, _MM_PERM_ENUM(0xaa)));
+ sumi = _mm512_dpbusd_epi32(sumi, qx[3], _mm512_shuffle_epi32(yl, _MM_PERM_ENUM(0xff)));
+ sumi = _mm512_dpbusd_epi32(sumi, qx[4], _mm512_shuffle_epi32(yh, _MM_PERM_ENUM(0x00)));
+ sumi = _mm512_dpbusd_epi32(sumi, qx[5], _mm512_shuffle_epi32(yh, _MM_PERM_ENUM(0x55)));
+ sumi = _mm512_dpbusd_epi32(sumi, qx[6], _mm512_shuffle_epi32(yh, _MM_PERM_ENUM(0xaa)));
+ sumi = _mm512_dpbusd_epi32(sumi, qx[7], _mm512_shuffle_epi32(yh, _MM_PERM_ENUM(0xff)));
+ return sumi;
+}
+inline __m256i qx_r8_q8_dot_product(const __m256i * qx, const int8_t * y) {
+ auto y4l = _mm_loadu_si128((const __m128i*)y+0);
+ auto y4h = _mm_loadu_si128((const __m128i*)y+1);
+ auto yl = MM256_SET_M128I(y4l, y4l);
+ auto yh = MM256_SET_M128I(y4h, y4h);
+ auto sumi = _mm256_setzero_si256();
+ sumi = _mm256_dpbusd_epi32(sumi, qx[0], _mm256_shuffle_epi32(yl, 0x00));
+ sumi = _mm256_dpbusd_epi32(sumi, qx[1], _mm256_shuffle_epi32(yl, 0x55));
+ sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(yl, 0xaa));
+ sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(yl, 0xff));
+ sumi = _mm256_dpbusd_epi32(sumi, qx[4], _mm256_shuffle_epi32(yh, 0x00));
+ sumi = _mm256_dpbusd_epi32(sumi, qx[5], _mm256_shuffle_epi32(yh, 0x55));
+ sumi = _mm256_dpbusd_epi32(sumi, qx[6], _mm256_shuffle_epi32(yh, 0xaa));
+ sumi = _mm256_dpbusd_epi32(sumi, qx[7], _mm256_shuffle_epi32(yh, 0xff));
+ return sumi;
+}
+inline __m256i q8_0_r8_dot_product(const uint8_t * x, const int8_t * y, __m256i * qx) {
+ qx[0] = _mm256_loadu_si256((const __m256i *)x+0);
+ qx[1] = _mm256_loadu_si256((const __m256i *)x+1);
+ qx[2] = _mm256_loadu_si256((const __m256i *)x+2);
+ qx[3] = _mm256_loadu_si256((const __m256i *)x+3);
+ qx[4] = _mm256_loadu_si256((const __m256i *)x+4);
+ qx[5] = _mm256_loadu_si256((const __m256i *)x+5);
+ qx[6] = _mm256_loadu_si256((const __m256i *)x+6);
+ qx[7] = _mm256_loadu_si256((const __m256i *)x+7);
+ return qx_r8_q8_dot_product(qx, y);
+}
template <int nrc_y>
static void mul_mat_q8_0_r4_q8_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(nrc_x%16 == 0);
Q8<nrc_y, block_q8_1_x4> q8(info);
int nb = n / QK8_0;
- GGML_ASSERT(nb%4 == 0);
if constexpr (nrc_y == 1) {
__m256 acc[2] = {};
__m256i qx[8];
@@ -2997,32 +3101,22 @@ static void mul_mat_q8_0_r4_q8_1(int n, const void * vx, size_t bx, const DataIn
_mm256_storeu_ps(d8, _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)q8.y[0][ib4].d)));
for (int k = 0; k < 4; ++k) {
auto scales = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq8[4*ib4+k].d));
- qx[0] = _mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+0);
- qx[1] = _mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+1);
- qx[2] = _mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+2);
- qx[3] = _mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+3);
- qx[4] = _mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+4);
- qx[5] = _mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+5);
- qx[6] = _mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+6);
- qx[7] = _mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+7);
- auto y4l = _mm_loadu_si128((const __m128i*)q8.y[0][ib4].qs+2*k+0);
- auto y4h = _mm_loadu_si128((const __m128i*)q8.y[0][ib4].qs+2*k+1);
- auto yl = MM256_SET_M128I(y4l, y4l);
- auto yh = MM256_SET_M128I(y4h, y4h);
- auto sumi = _mm256_setzero_si256();
- sumi = _mm256_dpbusd_epi32(sumi, qx[0], _mm256_shuffle_epi32(yl, 0x00));
- sumi = _mm256_dpbusd_epi32(sumi, qx[1], _mm256_shuffle_epi32(yl, 0x55));
- sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(yl, 0xaa));
- sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(yl, 0xff));
- sumi = _mm256_dpbusd_epi32(sumi, qx[4], _mm256_shuffle_epi32(yh, 0x00));
- sumi = _mm256_dpbusd_epi32(sumi, qx[5], _mm256_shuffle_epi32(yh, 0x55));
- sumi = _mm256_dpbusd_epi32(sumi, qx[6], _mm256_shuffle_epi32(yh, 0xaa));
- sumi = _mm256_dpbusd_epi32(sumi, qx[7], _mm256_shuffle_epi32(yh, 0xff));
+ auto sumi = q8_0_r8_dot_product((const uint8_t *)iq8[4*ib4+k].qs, q8.y[0][ib4].qs+32*k, qx);
auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(d8[k]));
acc[0] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[0]);
acc[1] = _mm256_fmadd_ps(scales, _mm256_set1_ps(d8[k+4]), acc[1]);
}
}
+ if (4*(nb/4) < nb) {
+ auto qy = (const block_q8_1 *)q8.y[0];
+ for (int ib = 4*(nb/4); ib < nb; ++ib) {
+ auto scales = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq8[ib].d));
+ auto sumi = q8_0_r8_dot_product((const uint8_t *)iq8[ib].qs, qy[ib].qs, qx);
+ auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(GGML_FP16_TO_FP32(qy[ib].d)));
+ acc[0] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[0]);
+ acc[1] = _mm256_fmadd_ps(scales, _mm256_set1_ps(GGML_FP16_TO_FP32(qy[ib].s)), acc[1]);
+ }
+ }
info.store(ix, 0, _mm256_fmadd_ps(_mm256_set1_ps(-127.f), acc[1], acc[0]));
acc[0] = acc[1] = _mm256_setzero_ps();
}
@@ -3046,27 +3140,29 @@ static void mul_mat_q8_0_r4_q8_1(int n, const void * vx, size_t bx, const DataIn
_mm256_loadu_si256((const __m256i *)q8h[4*ib4+k].qs+j), 1);
}
for (int iy = 0; iy < nrc_y; ++iy) {
- auto y4l = _mm_loadu_si128((const __m128i*)q8.y[iy][ib4].qs+2*k+0);
- auto y4h = _mm_loadu_si128((const __m128i*)q8.y[iy][ib4].qs+2*k+1);
- auto y8l = MM256_SET_M128I(y4l, y4l);
- auto y8h = MM256_SET_M128I(y4h, y4h);
- auto yl = _mm512_inserti32x8(_mm512_castsi256_si512(y8l), y8l, 1);
- auto yh = _mm512_inserti32x8(_mm512_castsi256_si512(y8h), y8h, 1);
- auto sumi = _mm512_setzero_si512();
- sumi = _mm512_dpbusd_epi32(sumi, qx[0], _mm512_shuffle_epi32(yl, _MM_PERM_ENUM(0x00)));
- sumi = _mm512_dpbusd_epi32(sumi, qx[1], _mm512_shuffle_epi32(yl, _MM_PERM_ENUM(0x55)));
- sumi = _mm512_dpbusd_epi32(sumi, qx[2], _mm512_shuffle_epi32(yl, _MM_PERM_ENUM(0xaa)));
- sumi = _mm512_dpbusd_epi32(sumi, qx[3], _mm512_shuffle_epi32(yl, _MM_PERM_ENUM(0xff)));
- sumi = _mm512_dpbusd_epi32(sumi, qx[4], _mm512_shuffle_epi32(yh, _MM_PERM_ENUM(0x00)));
- sumi = _mm512_dpbusd_epi32(sumi, qx[5], _mm512_shuffle_epi32(yh, _MM_PERM_ENUM(0x55)));
- sumi = _mm512_dpbusd_epi32(sumi, qx[6], _mm512_shuffle_epi32(yh, _MM_PERM_ENUM(0xaa)));
- sumi = _mm512_dpbusd_epi32(sumi, qx[7], _mm512_shuffle_epi32(yh, _MM_PERM_ENUM(0xff)));
+ auto sumi = qx_r8_q8_dot_product(qx, q8.y[iy][ib4].qs+32*k);
auto dy = _mm512_set1_ps(d8[8*iy+k]);
acc[2*iy+0] = _mm512_fmadd_ps(_mm512_mul_ps(scales, dy), _mm512_cvtepi32_ps(sumi), acc[2*iy+0]);
acc[2*iy+1] = _mm512_fmadd_ps(scales, _mm512_set1_ps(d8[8*iy+k+4]), acc[2*iy+1]);
}
}
}
+ for (int ib = 4*(nb/4); ib < nb; ++ib) {
+ auto scales1 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)q8l[ib].d));
+ auto scales2 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)q8h[ib].d));
+ auto scales = _mm512_insertf32x8(_mm512_castps256_ps512(scales1), scales2, 1);
+ for (int j = 0; j < 8; ++j) {
+ qx[j] = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)q8l[ib].qs+j)),
+ _mm256_loadu_si256((const __m256i *)q8h[ib].qs+j), 1);
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto qy = (const block_q8_1 *)q8.y[iy];
+ auto sumi = qx_r8_q8_dot_product(qx, qy[ib].qs);
+ auto dy = _mm512_set1_ps(GGML_FP16_TO_FP32(qy[ib].d));
+ acc[2*iy+0] = _mm512_fmadd_ps(_mm512_mul_ps(scales, dy), _mm512_cvtepi32_ps(sumi), acc[2*iy+0]);
+ acc[2*iy+1] = _mm512_fmadd_ps(scales, _mm512_set1_ps(GGML_FP16_TO_FP32(qy[ib].s)), acc[2*iy+1]);
+ }
+ }
for (int iy = 0; iy < nrc_y; ++iy) {
auto sum512 = _mm512_fmadd_ps(_mm512_set1_ps(-127.f), acc[2*iy+1], acc[2*iy+0]);
info.store(ix, iy, sum512);
@@ -3082,9 +3178,22 @@ static void mul_mat_q8_0_r4_q8_1(int n, const void * vx, size_t bx, const DataIn
Q8<nrc_y, block_q8_1_x4> q8(info);
auto m1 = _mm256_set1_epi16(1);
int nb = n / QK8_0;
- GGML_ASSERT(nb%4 == 0);
__m256 acc[nrc_y] = {};
float d8[4*nrc_y];
+ __m256i qx[4], sx[4];
+ auto dot = [&qx, &sx, &m1] (const int8_t * qy) {
+ auto y128 = _mm_loadu_si128((const __m128i*)qy);
+ auto y = MM256_SET_M128I(y128, y128);
+ auto sumi1 = _mm256_add_epi32(
+ _mm256_madd_epi16(m1, _mm256_maddubs_epi16(sx[0], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), qx[0]))),
+ _mm256_madd_epi16(m1, _mm256_maddubs_epi16(sx[1], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), qx[1])))
+ );
+ auto sumi2 = _mm256_add_epi32(
+ _mm256_madd_epi16(m1, _mm256_maddubs_epi16(sx[2], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), qx[2]))),
+ _mm256_madd_epi16(m1, _mm256_maddubs_epi16(sx[3], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), qx[3])))
+ );
+ return _mm256_add_epi32(sumi1, sumi2);
+ };
for (int ix = 0; ix < nrc_x; ix += 8) {
const block_q8_0_r8 * iq8 = (const block_q8_0_r8 *)((const char *)vx + ix*bx);
for (int ib4 = 0; ib4 < nb/4; ++ib4) {
@@ -3094,54 +3203,49 @@ static void mul_mat_q8_0_r4_q8_1(int n, const void * vx, size_t bx, const DataIn
}
for (int k = 0; k < 4; ++k) {
auto scales = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq8[4*ib4+k].d));
- auto q0 = _mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+0);
- auto q1 = _mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+1);
- auto q2 = _mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+2);
- auto q3 = _mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+3);
- auto s0 = _mm256_sign_epi8(q0, q0);
- auto s1 = _mm256_sign_epi8(q1, q1);
- auto s2 = _mm256_sign_epi8(q2, q2);
- auto s3 = _mm256_sign_epi8(q3, q3);
+ for (int j = 0; j < 4; ++j) {
+ qx[j] = _mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+j);
+ sx[j] = _mm256_sign_epi8(qx[j], qx[j]);
+ }
for (int iy = 0; iy < nrc_y; ++iy) {
- auto y128 = _mm_loadu_si128((const __m128i*)q8.y[iy][ib4].qs+2*k+0);
- auto y = MM256_SET_M128I(y128, y128);
- auto sumi1 = _mm256_add_epi32(
- _mm256_madd_epi16(m1, _mm256_maddubs_epi16(s0, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), q0))),
- _mm256_madd_epi16(m1, _mm256_maddubs_epi16(s1, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), q1)))
- );
- auto sumi2 = _mm256_add_epi32(
- _mm256_madd_epi16(m1, _mm256_maddubs_epi16(s2, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), q2))),
- _mm256_madd_epi16(m1, _mm256_maddubs_epi16(s3, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), q3)))
- );
- auto sumi = _mm256_add_epi32(sumi1, sumi2);
+ auto sumi = dot(q8.y[iy][ib4].qs+32*k);
auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(d8[4*iy+k]));
acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[iy]);
}
- q0 = _mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+4);
- q1 = _mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+5);
- q2 = _mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+6);
- q3 = _mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+7);
- s0 = _mm256_sign_epi8(q0, q0);
- s1 = _mm256_sign_epi8(q1, q1);
- s2 = _mm256_sign_epi8(q2, q2);
- s3 = _mm256_sign_epi8(q3, q3);
+ for (int j = 0; j < 4; ++j) {
+ qx[j] = _mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+4+j);
+ sx[j] = _mm256_sign_epi8(qx[j], qx[j]);
+ }
for (int iy = 0; iy < nrc_y; ++iy) {
- auto y128 = _mm_loadu_si128((const __m128i*)q8.y[iy][ib4].qs+2*k+1);
- auto y = MM256_SET_M128I(y128, y128);
- auto sumi1 = _mm256_add_epi32(
- _mm256_madd_epi16(m1, _mm256_maddubs_epi16(s0, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), q0))),
- _mm256_madd_epi16(m1, _mm256_maddubs_epi16(s1, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), q1)))
- );
- auto sumi2 = _mm256_add_epi32(
- _mm256_madd_epi16(m1, _mm256_maddubs_epi16(s2, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), q2))),
- _mm256_madd_epi16(m1, _mm256_maddubs_epi16(s3, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), q3)))
- );
- auto sumi = _mm256_add_epi32(sumi1, sumi2);
+ auto sumi = dot(q8.y[iy][ib4].qs+32*k+16);
auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(d8[4*iy+k]));
acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[iy]);
}
}
}
+ for (int ib = 4*(nb/4); ib < nb; ++ib) {
+ auto scales = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq8[ib].d));
+ for (int j = 0; j < 4; ++j) {
+ qx[j] = _mm256_loadu_si256((const __m256i *)iq8[ib].qs+j);
+ sx[j] = _mm256_sign_epi8(qx[j], qx[j]);
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto qy = (const block_q8_1 *)q8.y[iy];
+ auto sumi = dot(qy[ib].qs);
+ auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(GGML_FP16_TO_FP32(qy[ib].d)));
+ acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[iy]);
+ }
+ for (int j = 0; j < 4; ++j) {
+ qx[j] = _mm256_loadu_si256((const __m256i *)iq8[ib].qs+4+j);
+ sx[j] = _mm256_sign_epi8(qx[j], qx[j]);
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto qy = (const block_q8_1 *)q8.y[iy];
+ auto sumi = dot(qy[ib].qs+16);
+ auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(GGML_FP16_TO_FP32(qy[ib].d)));
+ acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[iy]);
+ }
+ }
for (int iy = 0; iy < nrc_y; ++iy) {
info.store(ix, iy, acc[iy]);
acc[iy] = _mm256_setzero_ps();
@@ -7080,6 +7184,7 @@ struct QFBase {
static inline Acc acc_first(const Data& y, const Data& x) {
return _mm512_mul_ps(y, x);
}
+ static inline Acc add(Acc x, Acc y) { return _mm512_add_ps(x, y); }
static inline float hsum(Acc acc) {
return _mm512_reduce_add_ps(acc);
}
@@ -7118,6 +7223,7 @@ struct QFBase {
static inline Acc acc(Acc prev, const Data& y, const Data& x) {
return _mm256_fmadd_ps(y, x, prev);
}
+ static inline Acc add(Acc x, Acc y) { return _mm256_add_ps(x, y); }
static inline Acc acc_r4(Acc acc, const Data * xv, const Data& yv) {
acc = _mm256_fmadd_ps(xv[0], _mm256_shuffle_ps(yv, yv, 0x00), acc);
acc = _mm256_fmadd_ps(xv[1], _mm256_shuffle_ps(yv, yv, 0x55), acc);
@@ -7190,6 +7296,44 @@ template <typename Float, int nrc_in> struct QFT final : public QFBase {
const Float * y[nrc];
};
+// TBD if we want this
+//template <typename Qy, typename Qx>
+//IQK_NOINLINE void mul_mat_Qx_Qy_Mx1(int n, const char * cx, size_t bx, int ix0, const DataInfo& info) {
+// static_assert(Qy::nrc == 1);
+// int nb = n/QFBase::k_step;
+// int nb4 = n/4;
+// Qy y(info);
+// Qx x(cx + ix0*bx, bx);
+// QFBase::Data xv[2*Qx::nrc];
+// QFBase::Acc acc[2*Qx::nrc];
+// auto yv1 = y.load1(0, 0);
+// auto yv2 = y.load1(0, 1);
+// for (int ix = 0; ix < Qx::nrc; ++ix) {
+// xv[2*ix+0] = x.load1(ix, 0);
+// xv[2*ix+1] = x.load1(ix, 1);
+// acc[2*ix+0] = QFBase::acc_first(yv1, xv[2*ix+0]);
+// acc[2*ix+1] = QFBase::acc_first(yv2, xv[2*ix+1]);
+// }
+// for (int i = 1; i < nb/2; ++i) {
+// yv1 = y.load1(0, 2*i+0);
+// yv2 = y.load1(0, 2*i+1);
+// for (int ix = 0; ix < Qx::nrc; ++ix) {
+// xv[2*ix+0] = x.load1(ix, 2*i+0);
+// xv[2*ix+1] = x.load1(ix, 2*i+1);
+// acc[2*ix+0] = QFBase::acc(acc[2*ix+0], yv1, xv[2*ix+0]);
+// acc[2*ix+1] = QFBase::acc(acc[2*ix+1], yv2, xv[2*ix+1]);
+// }
+// }
+// for (int i = (QFBase::k_step/4)*nb; i < nb4; ++i) {
+// yv1 = y.load_tail(0, i);
+// for (int ix = 0; ix < Qx::nrc; ++ix) {
+// xv[ix] = x.load_tail(ix, i);
+// acc[2*ix+0] = QFBase::acc(acc[2*ix+0], yv1, xv[ix]);
+// }
+// }
+// for (int ix = 0; ix < Qx::nrc; ++ix) info.store(ix0+ix, 0, QFBase::hsum(QFBase::add(acc[2*ix+0], acc[2*ix+1])));
+//}
+
template <typename Qy, typename Qx>
IQK_NOINLINE void mul_mat_Qx_Qy_MxN(int n, const char * cx, size_t bx, int ix0, const DataInfo& info) {
int nb = n/QFBase::k_step;
@@ -7287,12 +7431,29 @@ inline void mul_mat_Qx_Qy_MxN_fa4(int D, const char * cx, size_t bx, int ix0, co
// f16, but I don't have a CPU capable of f16 vector arithmetic, so not doing it for now.
template <int nrc_y, typename FloatX, typename FloatY>
void mul_mat_fX_fY_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ const char * cx = (const char *)vx;
+ // TBD if we want this
+ //if constexpr (nrc_y == 1) {
+ // constexpr int k_nx = 2;
+ // for (int ix = 0; ix < nrc_x/k_nx; ++ix) {
+ // mul_mat_Qx_Qy_Mx1<QFT<FloatY, nrc_y>, QFT<FloatX, k_nx>>(n, cx, bx, ix*k_nx, info);
+ // }
+ // if (int lastx = k_nx*(nrc_x/k_nx); lastx < nrc_x) {
+ // int nx = nrc_x - lastx;
+ // switch (nx) {
+ // case 1: mul_mat_Qx_Qy_Mx1<QFT<FloatY, nrc_y>, QFT<FloatX, 1>>(n, cx, bx, lastx, info); break;
+ // case 2: mul_mat_Qx_Qy_Mx1<QFT<FloatY, nrc_y>, QFT<FloatX, 2>>(n, cx, bx, lastx, info); break;
+ // case 3: mul_mat_Qx_Qy_Mx1<QFT<FloatY, nrc_y>, QFT<FloatX, 3>>(n, cx, bx, lastx, info); break;
+ // }
+ // //mul_mat_Qx_Qy_Mx1<QFT<FloatY, nrc_y>, QFT<FloatX, 1>>(n, cx, bx, lastx, info);
+ // }
+ // return;
+ //}
#ifdef __AVX512F__
constexpr int k_nx = 5;
#else
constexpr int k_nx = nrc_y == 1 ? 4 : 2;
#endif
- const char * cx = (const char *)vx;
for (int ix = 0; ix < nrc_x/k_nx; ++ix) {
mul_mat_Qx_Qy_MxN<QFT<FloatY, nrc_y>, QFT<FloatX, k_nx>>(n, cx, bx, ix*k_nx, info);
}
@@ -12146,7 +12307,6 @@ void mul_mat_qx_r8_q8_0(int n, const void * vx, size_t bx, const DataInfo& info,
Q8<nrc_y, block_q8_0_x4> q8(info);
Dequantizer deq(vx, bx);
int nb = n / QK4_NL;
- GGML_ASSERT(nb%4 == 0);
int8x16_t qx[16];
float d8[4*nrc_y];
float32x4_t acc[2*nrc_y] = {};
@@ -12168,6 +12328,18 @@ void mul_mat_qx_r8_q8_0(int n, const void * vx, size_t bx, const DataInfo& info,
}
}
}
+ for (int ib = 4*(nb/4); ib < nb; ++ib) {
+ auto scales = deq.prepare(ib, 0, qx);
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto qy = (const block_q8_0 *)q8.y[iy];
+ auto y = vld1q_s8_x2(qy[ib].qs);
+ auto sumi1 = interleaved_dotq(qx+0, y);
+ auto sumi2 = interleaved_dotq(qx+8, y);
+ auto dy = vdupq_n_f32(GGML_FP16_TO_FP32(qy[ib].d));
+ acc[2*iy+0] = vfmaq_f32(acc[2*iy+0], vmulq_f32(scales.val[0], dy), vcvtq_f32_s32(sumi1));
+ acc[2*iy+1] = vfmaq_f32(acc[2*iy+1], vmulq_f32(scales.val[1], dy), vcvtq_f32_s32(sumi2));
+ }
+ }
for (int iy = 0; iy < nrc_y; ++iy) {
info.store(ix+0, iy, deq.result(acc[2*iy+0]));
info.store(ix+4, iy, deq.result(acc[2*iy+1]));
@@ -12312,12 +12484,32 @@ struct Q6_0_R4_Dequantizer {
const int8x16_t m32 = vdupq_n_s8(-32);
};
+inline void qx_0_q8_0_dot(const int8x16_t * qx, const int8_t * qy, int32x4_t& sumi1, int32x4_t& sumi2) {
+ auto y = vld1q_s8_x2(qy);
+ sumi1 = sumi2 = vdupq_n_s32(0);
+ sumi1 = vdotq_laneq_s32(sumi1, qx[0], y.val[0], 0);
+ sumi2 = vdotq_laneq_s32(sumi2, qx[1], y.val[0], 0);
+ sumi1 = vdotq_laneq_s32(sumi1, qx[2], y.val[0], 1);
+ sumi2 = vdotq_laneq_s32(sumi2, qx[3], y.val[0], 1);
+ sumi1 = vdotq_laneq_s32(sumi1, qx[4], y.val[0], 2);
+ sumi2 = vdotq_laneq_s32(sumi2, qx[5], y.val[0], 2);
+ sumi1 = vdotq_laneq_s32(sumi1, qx[6], y.val[0], 3);
+ sumi2 = vdotq_laneq_s32(sumi2, qx[7], y.val[0], 3);
+ sumi1 = vdotq_laneq_s32(sumi1, qx[8+0], y.val[1], 0);
+ sumi2 = vdotq_laneq_s32(sumi2, qx[8+1], y.val[1], 0);
+ sumi1 = vdotq_laneq_s32(sumi1, qx[8+2], y.val[1], 1);
+ sumi2 = vdotq_laneq_s32(sumi2, qx[8+3], y.val[1], 1);
+ sumi1 = vdotq_laneq_s32(sumi1, qx[8+4], y.val[1], 2);
+ sumi2 = vdotq_laneq_s32(sumi2, qx[8+5], y.val[1], 2);
+ sumi1 = vdotq_laneq_s32(sumi1, qx[8+6], y.val[1], 3);
+ sumi2 = vdotq_laneq_s32(sumi2, qx[8+7], y.val[1], 3);
+}
+
template <int nrc_y>
void mul_mat_q8_0_r4_q8_0(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(nrc_x%8 == 0);
Q8<nrc_y, block_q8_0_x4> q8(info);
int nb = n / QK8_0;
- GGML_ASSERT(nb%4 == 0);
float32x4_t acc[2*nrc_y] = {};
int8x16_t qx[16];
float d8[4*nrc_y];
@@ -12332,32 +12524,29 @@ void mul_mat_q8_0_r4_q8_0(int n, const void * vx, size_t bx, const DataInfo& inf
auto scales1 = vcvt_f32_f16(vget_low_f16 (scales16));
auto scales2 = vcvt_f32_f16(vget_high_f16(scales16));
for (int j = 0; j < 16; ++j) qx[j] = vld1q_s8(iq8[4*ib4+k].qs + 16*j);
+ int32x4_t sumi1, sumi2;
for (int iy = 0; iy < nrc_y; ++iy) {
- auto y = vld1q_s8_x2(q8.y[iy][ib4].qs+32*k);
- auto sumi1 = vdupq_n_s32(0);
- auto sumi2 = vdupq_n_s32(0);
- sumi1 = vdotq_laneq_s32(sumi1, qx[0], y.val[0], 0);
- sumi2 = vdotq_laneq_s32(sumi2, qx[1], y.val[0], 0);
- sumi1 = vdotq_laneq_s32(sumi1, qx[2], y.val[0], 1);
- sumi2 = vdotq_laneq_s32(sumi2, qx[3], y.val[0], 1);
- sumi1 = vdotq_laneq_s32(sumi1, qx[4], y.val[0], 2);
- sumi2 = vdotq_laneq_s32(sumi2, qx[5], y.val[0], 2);
- sumi1 = vdotq_laneq_s32(sumi1, qx[6], y.val[0], 3);
- sumi2 = vdotq_laneq_s32(sumi2, qx[7], y.val[0], 3);
- sumi1 = vdotq_laneq_s32(sumi1, qx[8+0], y.val[1], 0);
- sumi2 = vdotq_laneq_s32(sumi2, qx[8+1], y.val[1], 0);
- sumi1 = vdotq_laneq_s32(sumi1, qx[8+2], y.val[1], 1);
- sumi2 = vdotq_laneq_s32(sumi2, qx[8+3], y.val[1], 1);
- sumi1 = vdotq_laneq_s32(sumi1, qx[8+4], y.val[1], 2);
- sumi2 = vdotq_laneq_s32(sumi2, qx[8+5], y.val[1], 2);
- sumi1 = vdotq_laneq_s32(sumi1, qx[8+6], y.val[1], 3);
- sumi2 = vdotq_laneq_s32(sumi2, qx[8+7], y.val[1], 3);
+ qx_0_q8_0_dot(qx, q8.y[iy][ib4].qs+32*k, sumi1, sumi2);
auto dy = vdupq_n_f32(d8[4*iy+k]);
acc[2*iy+0] = vfmaq_f32(acc[2*iy+0], vmulq_f32(scales1, dy), vcvtq_f32_s32(sumi1));
acc[2*iy+1] = vfmaq_f32(acc[2*iy+1], vmulq_f32(scales2, dy), vcvtq_f32_s32(sumi2));
}
}
}
+ for (int ib = 4*(nb/4); ib < nb; ++ib) {
+ auto scales16 = vld1q_f16((const float16_t *)iq8[ib].d);
+ auto scales1 = vcvt_f32_f16(vget_low_f16 (scales16));
+ auto scales2 = vcvt_f32_f16(vget_high_f16(scales16));
+ for (int j = 0; j < 16; ++j) qx[j] = vld1q_s8(iq8[ib].qs + 16*j);
+ int32x4_t sumi1, sumi2;
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto qy = (const block_q8_0 *)q8.y[iy];
+ qx_0_q8_0_dot(qx, qy[ib].qs, sumi1, sumi2);
+ auto dy = vdupq_n_f32(GGML_FP16_TO_FP32(qy[ib].d));
+ acc[2*iy+0] = vfmaq_f32(acc[2*iy+0], vmulq_f32(scales1, dy), vcvtq_f32_s32(sumi1));
+ acc[2*iy+1] = vfmaq_f32(acc[2*iy+1], vmulq_f32(scales2, dy), vcvtq_f32_s32(sumi2));
+ }
+ }
for (int iy = 0; iy < nrc_y; ++iy) {
info.store(ix+0, iy, acc[2*iy+0]);
info.store(ix+4, iy, acc[2*iy+1]);
@@ -13033,10 +13222,10 @@ struct HelperQ80R4 : public BaseHelper<step> {
m2 = _mm256_unpacklo_epi64(t2, t3);
m3 = _mm256_unpackhi_epi64(t2, t3);
#ifdef HAVE_FANCY_SIMD
- m0 = _mm256_xor_si256(m0, _mm256_set1_epi8(-128));
- m1 = _mm256_xor_si256(m1, _mm256_set1_epi8(-128));
- m2 = _mm256_xor_si256(m2, _mm256_set1_epi8(-128));
- m3 = _mm256_xor_si256(m3, _mm256_set1_epi8(-128));
+ m0 = _mm256_add_epi8(m0, _mm256_set1_epi8(127));
+ m1 = _mm256_add_epi8(m1, _mm256_set1_epi8(127));
+ m2 = _mm256_add_epi8(m2, _mm256_set1_epi8(127));
+ m3 = _mm256_add_epi8(m3, _mm256_set1_epi8(127));
#endif
_mm256_storeu_si256((__m256i *)y[ib].qs + 0, m0);
_mm256_storeu_si256((__m256i *)y[ib].qs + 1, m1);
@@ -13055,10 +13244,10 @@ struct HelperQ80R4 : public BaseHelper<step> {
m2 = _mm256_unpacklo_epi64(t2, t3);
m3 = _mm256_unpackhi_epi64(t2, t3);
#ifdef HAVE_FANCY_SIMD
- m0 = _mm256_xor_si256(m0, _mm256_set1_epi8(-128));
- m1 = _mm256_xor_si256(m1, _mm256_set1_epi8(-128));
- m2 = _mm256_xor_si256(m2, _mm256_set1_epi8(-128));
- m3 = _mm256_xor_si256(m3, _mm256_set1_epi8(-128));
+ m0 = _mm256_add_epi8(m0, _mm256_set1_epi8(127));
+ m1 = _mm256_add_epi8(m1, _mm256_set1_epi8(127));
+ m2 = _mm256_add_epi8(m2, _mm256_set1_epi8(127));
+ m3 = _mm256_add_epi8(m3, _mm256_set1_epi8(127));
#endif
_mm256_storeu_si256((__m256i *)y[ib].qs + 4, m0);
_mm256_storeu_si256((__m256i *)y[ib].qs + 5, m1);
@@ -13895,16 +14084,11 @@ struct FlashQKfp32 {
#ifdef __aarch64__
MAKE_FUNCS(mul_mat_qX_0_q8_0<DequantizerQ80, nq);
#else
- if constexpr (D >= 128) {
#ifdef HAVE_FANCY_SIMD
- MAKE_FUNCS(mul_mat_qX_1_q8_1_T<Q8_0_1_Unpacker, nq);
+ MAKE_FUNCS(mul_mat_qX_1_q8_1_T<Q8_0_1_Unpacker, nq);
#else
- MAKE_FUNCS(mul_mat_qX_0_q8_0_T<Q8_0_Unpacker, nq);
+ MAKE_FUNCS(mul_mat_qX_0_q8_0_T<Q8_0_Unpacker, nq);
#endif
- } else {
- // This does not actually work until we fix K-cache to be quantized to Q8_0_x4 only if D%128 == 0
- MAKE_FUNCS(mul_mat_qX_0_q8_0_T<Q8_0_Unpacker, nq);
- }
#endif
}
else if constexpr (std::is_same_v<KHelper, HelperQ80R4<D, k_step>>) {