-rw-r--r--  ggml/src/ggml-common.h        |  15
-rw-r--r--  ggml/src/ggml-quants.c        |   1
-rw-r--r--  ggml/src/iqk/iqk_mul_mat.cpp  | 690
-rw-r--r--  ggml/src/iqk/iqk_quantize.cpp | 150
-rw-r--r--  ggml/src/iqk/iqk_quantize.h   |   4
-rw-r--r--  src/llama.cpp                 |   8
6 files changed, 437 insertions(+), 431 deletions(-)
diff --git a/ggml/src/ggml-common.h b/ggml/src/ggml-common.h
index 7f79b27b..d08870ad 100644
--- a/ggml/src/ggml-common.h
+++ b/ggml/src/ggml-common.h
@@ -236,6 +236,11 @@ typedef struct {
int8_t qs[4*QK8_0];
} block_q8_0_x4;
static_assert(sizeof(block_q8_0_x4) == 4*sizeof(block_q8_0), "wrong q8_0_x4 block size/padding");
+typedef struct {
+ ggml_half d[8];
+ int8_t qs[8*QK8_0];
+} block_q8_0_r8;
+static_assert(sizeof(block_q8_0_r8) == 8*sizeof(block_q8_0), "wrong q8_0_r8 block size/padding");
typedef struct {
ggml_half d[4]; // deltas for 4 q4_0 blocks
@@ -534,12 +539,12 @@ typedef struct {
static_assert(sizeof(block_iq4_xs) == sizeof(ggml_half) + sizeof(uint16_t) + QK_K/64 + QK_K/2, "wrong iq4_xs block size/padding");
typedef struct {
- ggml_half d[4];
- uint8_t scales_h[QK_K/32];
- uint8_t scales_l[QK_K/16];
- uint8_t qs[QK_K*2];
+ ggml_half d[8];
+ uint8_t scales_h[QK_K/16];
+ uint8_t scales_l[QK_K/ 8];
+ uint8_t qs[QK_K*4];
} block_iq4_xs_r4;
-static_assert(sizeof(block_iq4_xs_r4) == 4*sizeof(ggml_half) + QK_K/32 + QK_K/16 + QK_K*2, "wrong iq4_xs_rs block size/padding");
+static_assert(sizeof(block_iq4_xs_r4) == 8*sizeof(block_iq4_xs), "wrong iq4_xs_rs block size/padding");
typedef struct {
uint8_t scales[QK_K/32];
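The new block_q8_0_r8 interleaves eight q8_0 rows per block: d[0..7] hold the eight row scales and the 8*32 quants are stored four bytes per row at a time. A minimal sketch of the implied index mapping, not part of the patch; the helper name is illustrative and the layout is derived from repack_q8_0 in iqk_quantize.cpp further down in this commit:

    /* Byte position inside block_q8_0_r8::qs of quant j (0..31)
     * of interleaved row k (0..7). Sketch only. */
    static inline int q8_0_r8_offset(int k, int j) {
        int half = j / 16;        /* first or second group of 16 quants per row */
        int l    = (j % 16) / 4;  /* 4-byte group within that half              */
        int i    = j % 4;         /* byte within the group                      */
        return 128*half + 32*l + 4*k + i;
    }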
diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c
index 23ac9915..391d9e2e 100644
--- a/ggml/src/ggml-quants.c
+++ b/ggml/src/ggml-quants.c
@@ -936,7 +936,6 @@ void quantize_row_q8_0(const float * restrict x, void * restrict vy, int64_t k)
#if defined(__ARM_NEON)
for (int i = 0; i < nb; i++) {
- int i4 = i/4, ir = i%4;
float32x4_t srcv [8];
float32x4_t asrcv[8];
float32x4_t amaxv[8];
diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp
index 7ddaee2a..d8273415 100644
--- a/ggml/src/iqk/iqk_mul_mat.cpp
+++ b/ggml/src/iqk/iqk_mul_mat.cpp
@@ -245,9 +245,7 @@ struct MulMat {
case GGML_TYPE_Q4_0_R4:
case GGML_TYPE_Q5_0_R4:
case GGML_TYPE_Q6_0_R4:
- case GGML_TYPE_Q8_0_R4:
case GGML_TYPE_IQ4_NL_R4:
- case GGML_TYPE_IQ4_XS_R4:
case GGML_TYPE_IQ2_K_R4:
case GGML_TYPE_IQ3_K_R4:
case GGML_TYPE_IQ4_K_R4:
@@ -259,6 +257,8 @@ struct MulMat {
case GGML_TYPE_IQ3_XXS_R4:
case GGML_TYPE_IQ3_S_R4:
case GGML_TYPE_IQ2_BN_R4: return 4;
+ case GGML_TYPE_IQ4_XS_R4:
+ case GGML_TYPE_Q8_0_R4:
case GGML_TYPE_Q8_K_R8: return 8;
case GGML_TYPE_BF16_R16: return 16;
default: return 1;
@@ -2902,91 +2902,103 @@ static void mul_mat_q6_0_r4_q8_1(int n, const void * vx, size_t bx, const DataIn
#ifdef HAVE_FANCY_SIMD
template <int nrc_y>
static void mul_mat_q8_0_r4_q8_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
- GGML_ASSERT(nrc_x%8 == 0);
+ GGML_ASSERT(nrc_x%16 == 0);
Q8<nrc_y, block_q8_1_x4> q8(info);
int nb = n / QK8_0;
GGML_ASSERT(nb%4 == 0);
if constexpr (nrc_y == 1) {
auto m127 = _mm256_set1_epi8(127);
- auto m1 = _mm256_set1_epi16(1);
- __m256 acc[nrc_y] = {};
- for (int ix = 0; ix < nrc_x; ix += 4) {
- const block_q8_0_x4 * iq8 = (const block_q8_0_x4 *)((const char *)vx + ix*bx);
+ __m256 acc[2] = {};
+ __m256i qx[8];
+ float d8[8];
+ for (int ix = 0; ix < nrc_x; ix += 8) {
+ const block_q8_0_r8 * iq8 = (const block_q8_0_r8 *)((const char *)vx + ix*bx);
for (int ib4 = 0; ib4 < nb/4; ++ib4) {
+ _mm256_storeu_ps(d8, _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)q8.y[0][ib4].d)));
for (int k = 0; k < 4; ++k) {
- auto scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq8[4*ib4+k].d));
- auto scales = _mm256_set_m128(scales128, scales128);
- auto scales_m = _mm256_mul_ps(scales, _mm256_set1_ps(-63.5f));
- auto q1 = _mm256_add_epi8(_mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+0), m127);
- auto q2 = _mm256_add_epi8(_mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+1), m127);
- auto q3 = _mm256_add_epi8(_mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+2), m127);
- auto q4 = _mm256_add_epi8(_mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+3), m127);
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto y = _mm256_loadu_si256((const __m256i*)q8.y[iy][ib4].qs+k);
- auto sumi1 = _mm256_add_epi32(_mm256_madd_epi16(m1, _mm256_maddubs_epi16(q1, _mm256_shuffle_epi32(y, 0x00))),
- _mm256_madd_epi16(m1, _mm256_maddubs_epi16(q2, _mm256_shuffle_epi32(y, 0x55))));
- auto sumi2 = _mm256_add_epi32(_mm256_madd_epi16(m1, _mm256_maddubs_epi16(q3, _mm256_shuffle_epi32(y, 0xaa))),
- _mm256_madd_epi16(m1, _mm256_maddubs_epi16(q4, _mm256_shuffle_epi32(y, 0xff))));
- auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(GGML_FP16_TO_FP32(q8.y[iy][ib4].d[k])));
- acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), acc[iy]);
- acc[iy] = _mm256_fmadd_ps(scales_m, _mm256_set1_ps(GGML_FP16_TO_FP32(q8.y[iy][ib4].d[k+4])), acc[iy]);
- }
+ auto scales = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq8[4*ib4+k].d));
+ auto scales_m = _mm256_mul_ps(scales, _mm256_set1_ps(-127.f));
+ qx[0] = _mm256_add_epi8(_mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+0), m127);
+ qx[1] = _mm256_add_epi8(_mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+1), m127);
+ qx[2] = _mm256_add_epi8(_mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+2), m127);
+ qx[3] = _mm256_add_epi8(_mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+3), m127);
+ qx[4] = _mm256_add_epi8(_mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+4), m127);
+ qx[5] = _mm256_add_epi8(_mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+5), m127);
+ qx[6] = _mm256_add_epi8(_mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+6), m127);
+ qx[7] = _mm256_add_epi8(_mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+7), m127);
+ auto y4l = _mm_loadu_si128((const __m128i*)q8.y[0][ib4].qs+2*k+0);
+ auto y4h = _mm_loadu_si128((const __m128i*)q8.y[0][ib4].qs+2*k+1);
+ auto yl = MM256_SET_M128I(y4l, y4l);
+ auto yh = MM256_SET_M128I(y4h, y4h);
+ auto sumi = _mm256_setzero_si256();
+ sumi = _mm256_dpbusd_epi32(sumi, qx[0], _mm256_shuffle_epi32(yl, 0x00));
+ sumi = _mm256_dpbusd_epi32(sumi, qx[1], _mm256_shuffle_epi32(yl, 0x55));
+ sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(yl, 0xaa));
+ sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(yl, 0xff));
+ sumi = _mm256_dpbusd_epi32(sumi, qx[4], _mm256_shuffle_epi32(yh, 0x00));
+ sumi = _mm256_dpbusd_epi32(sumi, qx[5], _mm256_shuffle_epi32(yh, 0x55));
+ sumi = _mm256_dpbusd_epi32(sumi, qx[6], _mm256_shuffle_epi32(yh, 0xaa));
+ sumi = _mm256_dpbusd_epi32(sumi, qx[7], _mm256_shuffle_epi32(yh, 0xff));
+ auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(d8[k]));
+ acc[0] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[0]);
+ acc[1] = _mm256_fmadd_ps(scales_m, _mm256_set1_ps(d8[k+4]), acc[1]);
}
}
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1));
- info.store(ix, iy, sum);
- acc[iy] = _mm256_setzero_ps();
- }
+ info.store(ix, 0, _mm256_add_ps(acc[0], acc[1]));
+ acc[0] = acc[1] = _mm256_setzero_ps();
}
} else {
__m512 acc[2*nrc_y] = {};
- __m512i qx[4];
+ __m512i qx[8];
+ float d8[8*nrc_y];
auto m127 = _mm512_set1_epi8(127);
- for (int ix = 0; ix < nrc_x; ix += 8) {
- const block_q8_0_x4 * q8l = (const block_q8_0_x4 *)((const char *)vx + (ix+0)*bx);
- const block_q8_0_x4 * q8h = (const block_q8_0_x4 *)((const char *)vx + (ix+4)*bx);
+ for (int ix = 0; ix < nrc_x; ix += 16) {
+ const block_q8_0_r8 * q8l = (const block_q8_0_r8 *)((const char *)vx + (ix+0)*bx);
+ const block_q8_0_r8 * q8h = (const block_q8_0_r8 *)((const char *)vx + (ix+8)*bx);
for (int ib4 = 0; ib4 < nb/4; ++ib4) {
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ _mm256_storeu_ps(d8+8*iy, _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)q8.y[iy][ib4].d)));
+ }
for (int k = 0; k < 4; ++k) {
- auto scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)q8l[4*ib4+k].d));
- auto scales1 = _mm256_set_m128(scales128, scales128);
- scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)q8h[4*ib4+k].d));
- auto scales2 = _mm256_set_m128(scales128, scales128);
- auto scales = _mm512_insertf32x8(_mm512_castps256_ps512(scales1), scales2, 1);
- auto scales_m = _mm512_mul_ps(scales, _mm512_set1_ps(-63.5f));
- qx[0] = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)q8l[4*ib4+k].qs+0)),
- _mm256_loadu_si256((const __m256i *)q8h[4*ib4+k].qs+0), 1);
- qx[1] = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)q8l[4*ib4+k].qs+1)),
- _mm256_loadu_si256((const __m256i *)q8h[4*ib4+k].qs+1), 1);
- qx[2] = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)q8l[4*ib4+k].qs+2)),
- _mm256_loadu_si256((const __m256i *)q8h[4*ib4+k].qs+2), 1);
- qx[3] = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)q8l[4*ib4+k].qs+3)),
- _mm256_loadu_si256((const __m256i *)q8h[4*ib4+k].qs+3), 1);
- qx[0] = _mm512_add_epi8(qx[0], m127);
- qx[1] = _mm512_add_epi8(qx[1], m127);
- qx[2] = _mm512_add_epi8(qx[2], m127);
- qx[3] = _mm512_add_epi8(qx[3], m127);
+ auto scales1 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)q8l[4*ib4+k].d));
+ auto scales2 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)q8h[4*ib4+k].d));
+ auto scales = _mm512_insertf32x8(_mm512_castps256_ps512(scales1), scales2, 1);
+ auto scales_m = _mm512_mul_ps(scales, _mm512_set1_ps(-127.f));
+ for (int j = 0; j < 8; ++j) {
+ qx[j] = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)q8l[4*ib4+k].qs+j)),
+ _mm256_loadu_si256((const __m256i *)q8h[4*ib4+k].qs+j), 1);
+ qx[j] = _mm512_add_epi8(qx[j], m127);
+ }
for (int iy = 0; iy < nrc_y; ++iy) {
- auto y8 = _mm256_loadu_si256((const __m256i*)q8.y[iy][ib4].qs+k);
- auto y = _mm512_inserti32x8(_mm512_castsi256_si512(y8), y8, 1);
+ auto y4l = _mm_loadu_si128((const __m128i*)q8.y[iy][ib4].qs+2*k+0);
+ auto y4h = _mm_loadu_si128((const __m128i*)q8.y[iy][ib4].qs+2*k+1);
+ auto y8l = MM256_SET_M128I(y4l, y4l);
+ auto y8h = MM256_SET_M128I(y4h, y4h);
+ auto yl = _mm512_inserti32x8(_mm512_castsi256_si512(y8l), y8l, 1);
+ auto yh = _mm512_inserti32x8(_mm512_castsi256_si512(y8h), y8h, 1);
auto sumi = _mm512_setzero_si512();
- sumi = _mm512_dpbusd_epi32(sumi, qx[0], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00)));
- sumi = _mm512_dpbusd_epi32(sumi, qx[1], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55)));
- sumi = _mm512_dpbusd_epi32(sumi, qx[2], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa)));
- sumi = _mm512_dpbusd_epi32(sumi, qx[3], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff)));
- auto dy = _mm512_set1_ps(GGML_FP16_TO_FP32(q8.y[iy][ib4].d[k]));
+ sumi = _mm512_dpbusd_epi32(sumi, qx[0], _mm512_shuffle_epi32(yl, _MM_PERM_ENUM(0x00)));
+ sumi = _mm512_dpbusd_epi32(sumi, qx[1], _mm512_shuffle_epi32(yl, _MM_PERM_ENUM(0x55)));
+ sumi = _mm512_dpbusd_epi32(sumi, qx[2], _mm512_shuffle_epi32(yl, _MM_PERM_ENUM(0xaa)));
+ sumi = _mm512_dpbusd_epi32(sumi, qx[3], _mm512_shuffle_epi32(yl, _MM_PERM_ENUM(0xff)));
+ sumi = _mm512_dpbusd_epi32(sumi, qx[4], _mm512_shuffle_epi32(yh, _MM_PERM_ENUM(0x00)));
+ sumi = _mm512_dpbusd_epi32(sumi, qx[5], _mm512_shuffle_epi32(yh, _MM_PERM_ENUM(0x55)));
+ sumi = _mm512_dpbusd_epi32(sumi, qx[6], _mm512_shuffle_epi32(yh, _MM_PERM_ENUM(0xaa)));
+ sumi = _mm512_dpbusd_epi32(sumi, qx[7], _mm512_shuffle_epi32(yh, _MM_PERM_ENUM(0xff)));
+ auto dy = _mm512_set1_ps(d8[8*iy+k]);
acc[2*iy+0] = _mm512_fmadd_ps(_mm512_mul_ps(scales, dy), _mm512_cvtepi32_ps(sumi), acc[2*iy+0]);
- acc[2*iy+1] = _mm512_fmadd_ps(scales_m, _mm512_set1_ps(GGML_FP16_TO_FP32(q8.y[iy][ib4].d[k+4])), acc[2*iy+1]);
+ acc[2*iy+1] = _mm512_fmadd_ps(scales_m, _mm512_set1_ps(d8[8*iy+k+4]), acc[2*iy+1]);
}
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
auto sum512 = _mm512_add_ps(acc[2*iy+0], acc[2*iy+1]);
+ info.store(ix, iy, sum512);
acc[2*iy+0] = acc[2*iy+1] = _mm512_setzero_ps();
- auto sum1 = _mm_add_ps(_mm512_extractf32x4_ps(sum512, 0), _mm512_extractf32x4_ps(sum512, 1));
- auto sum2 = _mm_add_ps(_mm512_extractf32x4_ps(sum512, 2), _mm512_extractf32x4_ps(sum512, 3));
- info.store(ix+0, iy, sum1);
- info.store(ix+4, iy, sum2);
+ //auto sum1 = _mm_add_ps(_mm512_extractf32x4_ps(sum512, 0), _mm512_extractf32x4_ps(sum512, 1));
+ //auto sum2 = _mm_add_ps(_mm512_extractf32x4_ps(sum512, 2), _mm512_extractf32x4_ps(sum512, 3));
+ //info.store(ix+0, iy, sum1);
+ //info.store(ix+4, iy, sum2);
}
}
}
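For context on the m127/scales_m terms in the kernel above: _mm256_dpbusd_epi32 and _mm512_dpbusd_epi32 take an unsigned first operand, so the row quants are biased by +127 and the bias is removed afterwards via the per-block activation sums, assuming d[4..7] of block_q8_1_x4 carry d_y*sum(y) as in the regular q8_1 format. A scalar sketch of the identity, with an illustrative function name:

    #include <stdint.h>

    /* dot(x+127, y) = dot(x, y) + 127*sum(y), so subtracting
     * 127*d_x*(d_y*sum(y)) recovers the unbiased result; the d4d8 and
     * scales_m terms above accumulate exactly these two pieces. */
    static float biased_q8_dot(const int8_t *x, const int8_t *y, float dx, float dy, int n) {
        int dot_biased = 0, sum_y = 0;
        for (int i = 0; i < n; ++i) {
            dot_biased += (x[i] + 127) * y[i];
            sum_y      += y[i];
        }
        return dx*dy*dot_biased - 127.0f*dx*(dy*sum_y);  /* == dx*dy*dot(x, y) */
    }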
@@ -2994,45 +3006,72 @@ static void mul_mat_q8_0_r4_q8_1(int n, const void * vx, size_t bx, const DataIn
#else
template <int nrc_y>
static void mul_mat_q8_0_r4_q8_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
- GGML_ASSERT(nrc_x%4 == 0);
+ GGML_ASSERT(nrc_x%8 == 0);
Q8<nrc_y, block_q8_1_x4> q8(info);
auto m1 = _mm256_set1_epi16(1);
int nb = n / QK8_0;
GGML_ASSERT(nb%4 == 0);
__m256 acc[nrc_y] = {};
float d8[4*nrc_y];
- for (int ix = 0; ix < nrc_x; ix += 4) {
- const block_q8_0_x4 * iq8 = (const block_q8_0_x4 *)((const char *)vx + ix*bx);
+ for (int ix = 0; ix < nrc_x; ix += 8) {
+ const block_q8_0_r8 * iq8 = (const block_q8_0_r8 *)((const char *)vx + ix*bx);
for (int ib4 = 0; ib4 < nb/4; ++ib4) {
for (int iy = 0; iy < nrc_y; ++iy) {
auto scales = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)q8.y[iy][ib4].d));
_mm_storeu_ps(d8 + 4*iy, scales);
}
for (int k = 0; k < 4; ++k) {
- auto scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq8[4*ib4+k].d));
- auto scales = _mm256_set_m128(scales128, scales128);
- auto q1 = _mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+0);
- auto q2 = _mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+1);
- auto q3 = _mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+2);
- auto q4 = _mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+3);
+ auto scales = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq8[4*ib4+k].d));
+ auto q0 = _mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+0);
+ auto q1 = _mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+1);
+ auto q2 = _mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+2);
+ auto q3 = _mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+3);
+ auto s0 = _mm256_sign_epi8(q0, q0);
auto s1 = _mm256_sign_epi8(q1, q1);
auto s2 = _mm256_sign_epi8(q2, q2);
auto s3 = _mm256_sign_epi8(q3, q3);
- auto s4 = _mm256_sign_epi8(q4, q4);
for (int iy = 0; iy < nrc_y; ++iy) {
- auto y = _mm256_loadu_si256((const __m256i*)q8.y[iy][ib4].qs+k);
- auto sumi1 = _mm256_add_epi32(_mm256_madd_epi16(m1, _mm256_maddubs_epi16(s1, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), q1))),
- _mm256_madd_epi16(m1, _mm256_maddubs_epi16(s2, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), q2))));
- auto sumi2 = _mm256_add_epi32(_mm256_madd_epi16(m1, _mm256_maddubs_epi16(s3, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), q3))),
- _mm256_madd_epi16(m1, _mm256_maddubs_epi16(s4, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), q4))));
+ auto y128 = _mm_loadu_si128((const __m128i*)q8.y[iy][ib4].qs+2*k+0);
+ auto y = MM256_SET_M128I(y128, y128);
+ auto sumi1 = _mm256_add_epi32(
+ _mm256_madd_epi16(m1, _mm256_maddubs_epi16(s0, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), q0))),
+ _mm256_madd_epi16(m1, _mm256_maddubs_epi16(s1, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), q1)))
+ );
+ auto sumi2 = _mm256_add_epi32(
+ _mm256_madd_epi16(m1, _mm256_maddubs_epi16(s2, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), q2))),
+ _mm256_madd_epi16(m1, _mm256_maddubs_epi16(s3, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), q3)))
+ );
+ auto sumi = _mm256_add_epi32(sumi1, sumi2);
auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(d8[4*iy+k]));
- acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), acc[iy]);
+ acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[iy]);
+ }
+ q0 = _mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+4);
+ q1 = _mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+5);
+ q2 = _mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+6);
+ q3 = _mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+7);
+ s0 = _mm256_sign_epi8(q0, q0);
+ s1 = _mm256_sign_epi8(q1, q1);
+ s2 = _mm256_sign_epi8(q2, q2);
+ s3 = _mm256_sign_epi8(q3, q3);
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto y128 = _mm_loadu_si128((const __m128i*)q8.y[iy][ib4].qs+2*k+1);
+ auto y = MM256_SET_M128I(y128, y128);
+ auto sumi1 = _mm256_add_epi32(
+ _mm256_madd_epi16(m1, _mm256_maddubs_epi16(s0, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), q0))),
+ _mm256_madd_epi16(m1, _mm256_maddubs_epi16(s1, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), q1)))
+ );
+ auto sumi2 = _mm256_add_epi32(
+ _mm256_madd_epi16(m1, _mm256_maddubs_epi16(s2, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), q2))),
+ _mm256_madd_epi16(m1, _mm256_maddubs_epi16(s3, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), q3)))
+ );
+ auto sumi = _mm256_add_epi32(sumi1, sumi2);
+ auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(d8[4*iy+k]));
+ acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[iy]);
}
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
- auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1));
- info.store(ix, iy, sum);
+ info.store(ix, iy, acc[iy]);
acc[iy] = _mm256_setzero_ps();
}
}
@@ -3041,9 +3080,11 @@ static void mul_mat_q8_0_r4_q8_1(int n, const void * vx, size_t bx, const DataIn
template <int nrc_y>
static void mul_mat_iq4_xs_r4_q8_k_avx2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
- GGML_ASSERT(nrc_x%4 == 0);
+ GGML_ASSERT(nrc_x%8 == 0);
Q8<nrc_y, block_q8_K> q8(info);
auto m4 = _mm256_set1_epi8(0xf);
+ auto m30 = _mm256_set1_epi8(0x30);
+ auto m32 = _mm256_set1_epi8(32);
#ifndef HAVE_FANCY_SIMD
auto s_shuffle = _mm256_set_epi64x(0x0f0e0f0e0d0c0d0c, 0x0b0a0b0a09080908, 0x0706070605040504, 0x0302030201000100);
auto values128 = _mm_loadu_si128((const __m128i *)iq4k_values);
@@ -3052,40 +3093,40 @@ static void mul_mat_iq4_xs_r4_q8_k_avx2(int n, const void * vx, size_t bx, const
auto values = load_iq4nl_values_256();
#endif
int nbl = n / QK_K;
- using helper_t = union { __m256i vec; uint32_t val[8]; };
+ using helper_t = union { __m256i vec[2]; uint64_t val[8]; };
helper_t h;
__m256 acc[nrc_y] = {};
- __m256i isum[nrc_y] = {};
__m256i qx[4];
- for (int ix = 0; ix < nrc_x; ix += 4) {
+ for (int ix = 0; ix < nrc_x; ix += 8) {
const block_iq4_xs_r4 * iq4 = (const block_iq4_xs_r4 *)((const char *)vx + (ix+0)*bx);
for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256
- auto dl = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq4[ibl].d));
- auto d4 = _mm256_set_m128(dl, dl);
- auto slbits = _mm_loadu_si128((const __m128i *)iq4[ibl].scales_l);
- auto sl = _mm256_and_si256(MM256_SET_M128I(_mm_srli_epi16(slbits, 4), slbits), _mm256_set1_epi8(0xf));
- auto aux64 = (const uint64_t *)iq4[ibl].scales_h;
- auto shbits = _mm_set_epi64x(aux64[0] >> 2, aux64[0]);
- auto sh = _mm256_and_si256(MM256_SET_M128I(shbits, _mm_slli_epi16(shbits, 4)), _mm256_set1_epi8(0x30));
- h.vec = _mm256_sub_epi8(_mm256_or_si256(sl, sh), _mm256_set1_epi8(32));
+ auto d4 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq4[ibl].d));
+ auto slbits = _mm256_loadu_si256((const __m256i *)iq4[ibl].scales_l);
+ auto sl1 = _mm256_and_si256(slbits, m4);
+ auto sl2 = _mm256_and_si256(_mm256_srli_epi16(slbits, 4), m4);
+ auto shbits = _mm_loadu_si128((const __m128i*)iq4[ibl].scales_h);
+ auto sh = MM256_SET_M128I(_mm_srli_epi16(shbits, 2), shbits);
+ h.vec[0] = _mm256_sub_epi8(_mm256_or_si256(sl1, _mm256_and_si256(_mm256_slli_epi16(sh, 4), m30)), m32);
+ h.vec[1] = _mm256_sub_epi8(_mm256_or_si256(sl2, _mm256_and_si256(sh, m30)), m32);
+ __m256i isum[nrc_y] = {};
for (int ib = 0; ib < QK_K/32; ++ib) {
#ifdef HAVE_FANCY_SIMD
- auto iscales = _mm256_cvtepi8_epi32(_mm_set1_epi32(h.val[ib]));
+ auto iscales = _mm256_cvtepi8_epi32(_mm_set1_epi64x(h.val[ib]));
auto scales = _mm256_mul_ps(d4, _mm256_cvtepi32_ps(iscales));
- auto scales_m = _mm256_mul_ps(scales, _mm256_set1_ps(-64.f));
+ auto scales_m = _mm256_mul_ps(scales, _mm256_set1_ps(-128.f));
for (int iy = 0; iy < nrc_y; ++iy) {
float m8 = ((const float *)q8.y[iy][ibl].bsums)[ib];
acc[iy] = _mm256_fmadd_ps(scales_m, _mm256_set1_ps(m8), acc[iy]);
}
#else
- auto iscales = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm_set1_epi32(h.val[ib])), s_shuffle);
+ auto iscales = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm_set1_epi64x(h.val[ib])), s_shuffle);
#endif
- auto bits1 = _mm256_loadu_si256((const __m256i *)iq4[ibl].qs+2*ib+0);
- auto bits2 = _mm256_loadu_si256((const __m256i *)iq4[ibl].qs+2*ib+1);
- qx[0] = _mm256_shuffle_epi8(values, _mm256_and_si256(bits1, m4));
- qx[1] = _mm256_shuffle_epi8(values, _mm256_and_si256(bits2, m4));
- qx[2] = _mm256_shuffle_epi8(values, _mm256_and_si256(_mm256_srli_epi16(bits1, 4), m4));
- qx[3] = _mm256_shuffle_epi8(values, _mm256_and_si256(_mm256_srli_epi16(bits2, 4), m4));
+ auto bits1 = _mm256_loadu_si256((const __m256i *)iq4[ibl].qs+4*ib+0);
+ auto bits2 = _mm256_loadu_si256((const __m256i *)iq4[ibl].qs+4*ib+1);
+ qx[0] = _mm256_shuffle_epi8(values, _mm256_and_si256(m4, bits1));
+ qx[1] = _mm256_shuffle_epi8(values, _mm256_and_si256(m4, _mm256_srli_epi16(bits1, 4)));
+ qx[2] = _mm256_shuffle_epi8(values, _mm256_and_si256(m4, bits2));
+ qx[3] = _mm256_shuffle_epi8(values, _mm256_and_si256(m4, _mm256_srli_epi16(bits2, 4)));
#ifndef HAVE_FANCY_SIMD
auto s1 = _mm256_sign_epi8(qx[0], qx[0]);
auto s2 = _mm256_sign_epi8(qx[1], qx[1]);
@@ -3093,7 +3134,8 @@ static void mul_mat_iq4_xs_r4_q8_k_avx2(int n, const void * vx, size_t bx, const
auto s4 = _mm256_sign_epi8(qx[3], qx[3]);
#endif
for (int iy = 0; iy < nrc_y; ++iy) {
- auto y = _mm256_loadu_si256((const __m256i*)q8.y[iy][ibl].qs+ib);
+ auto y128 = _mm_loadu_si128((const __m128i*)q8.y[iy][ibl].qs+2*ib+0);
+ auto y = MM256_SET_M128I(y128, y128);
#ifdef HAVE_FANCY_SIMD
auto sumi = _mm256_setzero_si256();
sumi = _mm256_dpbusd_epi32(sumi, qx[0], _mm256_shuffle_epi32(y, 0x00));
@@ -3106,20 +3148,51 @@ static void mul_mat_iq4_xs_r4_q8_k_avx2(int n, const void * vx, size_t bx, const
auto sumi2 = _mm256_maddubs_epi16(s2, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), qx[1]));
auto sumi3 = _mm256_maddubs_epi16(s3, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), qx[2]));
auto sumi4 = _mm256_maddubs_epi16(s4, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), qx[3]));
- isum[iy] = _mm256_add_epi32(isum[iy], _mm256_add_epi32(_mm256_madd_epi16(iscales, sumi1), _mm256_madd_epi16(iscales, sumi2)));
- isum[iy] = _mm256_add_epi32(isum[iy], _mm256_add_epi32(_mm256_madd_epi16(iscales, sumi3), _mm256_madd_epi16(iscales, sumi4)));
+ auto sumi = _mm256_add_epi32(_mm256_add_epi32(_mm256_madd_epi16(iscales, sumi1), _mm256_madd_epi16(iscales, sumi2)),
+ _mm256_add_epi32(_mm256_madd_epi16(iscales, sumi3), _mm256_madd_epi16(iscales, sumi4)));
+ isum[iy] = _mm256_add_epi32(isum[iy], sumi);
+#endif
+ }
+ bits1 = _mm256_loadu_si256((const __m256i *)iq4[ibl].qs+4*ib+2);
+ bits2 = _mm256_loadu_si256((const __m256i *)iq4[ibl].qs+4*ib+3);
+ qx[0] = _mm256_shuffle_epi8(values, _mm256_and_si256(m4, bits1));
+ qx[1] = _mm256_shuffle_epi8(values, _mm256_and_si256(m4, _mm256_srli_epi16(bits1, 4)));
+ qx[2] = _mm256_shuffle_epi8(values, _mm256_and_si256(m4, bits2));
+ qx[3] = _mm256_shuffle_epi8(values, _mm256_and_si256(m4, _mm256_srli_epi16(bits2, 4)));
+#ifndef HAVE_FANCY_SIMD
+ s1 = _mm256_sign_epi8(qx[0], qx[0]);
+ s2 = _mm256_sign_epi8(qx[1], qx[1]);
+ s3 = _mm256_sign_epi8(qx[2], qx[2]);
+ s4 = _mm256_sign_epi8(qx[3], qx[3]);
+#endif
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto y128 = _mm_loadu_si128((const __m128i*)q8.y[iy][ibl].qs+2*ib+1);
+ auto y = MM256_SET_M128I(y128, y128);
+#ifdef HAVE_FANCY_SIMD
+ auto sumi = _mm256_setzero_si256();
+ sumi = _mm256_dpbusd_epi32(sumi, qx[0], _mm256_shuffle_epi32(y, 0x00));
+ sumi = _mm256_dpbusd_epi32(sumi, qx[1], _mm256_shuffle_epi32(y, 0x55));
+ sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(y, 0xaa));
+ sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(y, 0xff));
+ isum[iy] = _mm256_add_epi32(isum[iy], _mm256_mullo_epi32(iscales, sumi));
+#else
+ auto sumi1 = _mm256_maddubs_epi16(s1, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), qx[0]));
+ auto sumi2 = _mm256_maddubs_epi16(s2, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), qx[1]));
+ auto sumi3 = _mm256_maddubs_epi16(s3, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), qx[2]));
+ auto sumi4 = _mm256_maddubs_epi16(s4, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), qx[3]));
+ auto sumi = _mm256_add_epi32(_mm256_add_epi32(_mm256_madd_epi16(iscales, sumi1), _mm256_madd_epi16(iscales, sumi2)),
+ _mm256_add_epi32(_mm256_madd_epi16(iscales, sumi3), _mm256_madd_epi16(iscales, sumi4)));
+ isum[iy] = _mm256_add_epi32(isum[iy], sumi);
#endif
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(isum[iy]), acc[iy]);
- isum[iy] = _mm256_setzero_si256();
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
- auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1));
+ info.store(ix, iy, acc[iy]);
acc[iy] = _mm256_setzero_ps();
- info.store(ix+0, iy, sum);
}
}
}
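The structure shared by these AVX2/AVX512 kernels: each step broadcasts four consecutive activation values to every 32-bit lane and pairs them with the matching four quants of all eight interleaved rows, so one dpbusd (or maddubs/madd pair) produces one partial sum per row. Over a full 32-value block this is equivalent to the scalar loop below; the names are illustrative and qx[j][4*row+i] stands for quant position 4*j+i of row `row`:

    #include <stdint.h>

    /* Scalar model of the interleaved 8-row dot product (not part of the patch). */
    static void dot_8rows(const int8_t qx[8][32], const int8_t y[32], int32_t sumi[8]) {
        for (int row = 0; row < 8; ++row) {
            int32_t s = 0;
            for (int j = 0; j < 8; ++j)           /* eight groups of four positions */
                for (int i = 0; i < 4; ++i)
                    s += qx[j][4*row + i] * y[4*j + i];
            sumi[row] = s;
        }
    }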
@@ -3127,6 +3200,8 @@ static void mul_mat_iq4_xs_r4_q8_k_avx2(int n, const void * vx, size_t bx, const
#ifdef HAVE_FANCY_SIMD
template <int nrc_y>
static void mul_mat_iq4_xs_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ mul_mat_iq4_xs_r4_q8_k_avx2<nrc_y>(n, vx, bx, info, nrc_x);
+ return;
if constexpr (nrc_y == 1){
mul_mat_iq4_xs_r4_q8_k_avx2<1>(n, vx, bx, info, nrc_x);
} else {
@@ -10529,6 +10604,13 @@ IQK_ALWAYS_INLINE void prepare_iq4_nl_quants(const int8x16_t& values, const uint
qx[7] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[3], 4)); // 28..31
}
+IQK_ALWAYS_INLINE void prepare_iq4_nl_quants_r8(const int8x16_t& values, const uint8x16_t& m4, const uint8x16x2_t& bits, int8x16_t * qx) {
+ qx[0] = vqtbl1q_s8(values, vandq_u8( bits.val[0], m4));
+ qx[1] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[0], 4));
+ qx[2] = vqtbl1q_s8(values, vandq_u8( bits.val[1], m4));
+ qx[3] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[1], 4));
+}
+
template <int nrc_y>
void mul_mat_iq4_xs_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(nrc_x%4 == 0);
@@ -10539,43 +10621,92 @@ void mul_mat_iq4_xs_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& i
auto values = vld1q_s8(iq4k_values);
int nbl = n / QK_K;
int8x16_t qx[8];
- int8x16x2_t iscales;
- int32x4x4_t scales;
- float32x4_t acc[nrc_y] = {};
- for (int ix = 0; ix < nrc_x; ix += 4) {
+ int8x16x4_t iscales;
+ int32x4x2_t scales;
+ float32x4_t acc[2*nrc_y] = {};
+ for (int ix = 0; ix < nrc_x; ix += 8) {
const block_iq4_xs_r4 * iq4 = (const block_iq4_xs_r4 *)((const char *)vx + ix*bx);
for (int ibl = 0; ibl < nbl; ++ibl) {
- auto d4 = vcvt_f32_f16(vld1_f16((const float16_t *)iq4[ibl].d));
- auto sl = vld1q_u8(iq4[ibl].scales_l);
- auto sh8 = vld1_u8(iq4[ibl].scales_h);
- auto sh = vcombine_u8(sh8, vshr_n_u8(sh8, 2));
- iscales.val[0] = vaddq_s8(vorrq_u8(vandq_u8(sl, m4), vandq_u8(vshlq_n_u8(sh, 4), m3)), m32);
- iscales.val[1] = vaddq_s8(vorrq_u8(vshrq_n_u8(sl, 4), vandq_u8(sh, m3)), m32);
+ auto d4_f16 = vld1q_f16((const float16_t *)iq4[ibl].d);
+ auto d4l = vcvt_f32_f16(vget_low_f16 (d4_f16));
+ auto d4h = vcvt_f32_f16(vget_high_f16(d4_f16));
+ auto sl = vld1q_u8_x2(iq4[ibl].scales_l);
+ auto sh = vld1q_u8(iq4[ibl].scales_h);
+ iscales.val[0] = vaddq_s8(vorrq_u8(vandq_u8(sl.val[0], m4), vandq_u8(vshlq_n_u8(sh, 4), m3)), m32);
+ iscales.val[1] = vaddq_s8(vorrq_u8(vandq_u8(sl.val[1], m4), vandq_u8(vshlq_n_u8(sh, 2), m3)), m32);
+ iscales.val[2] = vaddq_s8(vorrq_u8(vshrq_n_u8(sl.val[0], 4), vandq_u8(sh, m3)), m32);
+ iscales.val[3] = vaddq_s8(vorrq_u8(vshrq_n_u8(sl.val[1], 4), vandq_u8(vshrq_n_u8(sh, 2), m3)), m32);
int32x4_t isum[nrc_y] = {};
- for (int is = 0; is < 2; ++is) {
- auto iscales16_1 = vmovl_s8(vget_low_s8(iscales.val[is]));
- auto iscales16_2 = vmovl_s8(vget_high_s8(iscales.val[is]));
+ for (int ib64 = 0; ib64 < QK_K/64; ++ib64) {
+ auto iscales16_1 = vmovl_s8(vget_low_s8(iscales.val[ib64]));
+ auto iscales16_2 = vmovl_s8(vget_high_s8(iscales.val[ib64]));
scales.val[0] = vmovl_s16(vget_low_s16(iscales16_1));
- scales.val[1] = vmovl_s16(vget_high_s16(iscales16_1));
- scales.val[2] = vmovl_s16(vget_low_s16(iscales16_2));
- scales.val[3] = vmovl_s16(vget_high_s16(iscales16_2));
- for (int ib = 0; ib < 4; ++ib) {
- auto bits = vld1q_u8_x4(iq4[ibl].qs + 256*is + 64*ib);
- prepare_iq4_nl_quants(values, m4, bits, qx);
+ scales.val[1] = vmovl_s16(vget_low_s16(iscales16_2));
+ for (int l = 0; l < 2; ++l) {
+ uint8x16x2_t bits;
+ bits.val[0] = vld1q_u8(iq4[ibl].qs + 256*ib64 + 128*l);
+ bits.val[1] = vld1q_u8(iq4[ibl].qs + 256*ib64 + 128*l + 32);
+ prepare_iq4_nl_quants_r8(values, m4, bits, qx+0);
+ bits.val[0] = vld1q_u8(iq4[ibl].qs + 256*ib64 + 128*l + 64);
+ bits.val[1] = vld1q_u8(iq4[ibl].qs + 256*ib64 + 128*l + 96);
+ prepare_iq4_nl_quants_r8(values, m4, bits, qx+4);
for (int iy = 0; iy < nrc_y; ++iy) {
- auto y = vld1q_s8_x2(q8.y[iy][ibl].qs+128*is+32*ib);
- auto sumi = interleaved_dotq(qx, y);
- isum[iy] = vmlaq_s32(isum[iy], scales.val[ib], sumi);
+ auto y = vld1q_s8_x2(q8.y[iy][ibl].qs+64*ib64+32*l);
+ auto sumi = vdupq_n_s32(0);
+ sumi = vdotq_laneq_s32(sumi, qx[0], y.val[0], 0);
+ sumi = vdotq_laneq_s32(sumi, qx[1], y.val[0], 1);
+ sumi = vdotq_laneq_s32(sumi, qx[2], y.val[0], 2);
+ sumi = vdotq_laneq_s32(sumi, qx[3], y.val[0], 3);
+ sumi = vdotq_laneq_s32(sumi, qx[4], y.val[1], 0);
+ sumi = vdotq_laneq_s32(sumi, qx[5], y.val[1], 1);
+ sumi = vdotq_laneq_s32(sumi, qx[6], y.val[1], 2);
+ sumi = vdotq_laneq_s32(sumi, qx[7], y.val[1], 3);
+ isum[iy] = vmlaq_s32(isum[iy], sumi, scales.val[l]);
}
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
- acc[iy] = vfmaq_f32(acc[iy], vmulq_f32(d4, vdupq_n_f32(q8.scale(iy, ibl))), vcvtq_f32_s32(isum[iy]));
+ auto d8 = vdupq_n_f32(q8.scale(iy, ibl));
+ acc[2*iy+0] = vfmaq_f32(acc[2*iy+0], vmulq_f32(d4l, d8), vcvtq_f32_s32(isum[iy]));
+ isum[iy] = vdupq_n_s32(0);
+ }
+ for (int ib64 = 0; ib64 < QK_K/64; ++ib64) {
+ auto iscales16_1 = vmovl_s8(vget_low_s8(iscales.val[ib64]));
+ auto iscales16_2 = vmovl_s8(vget_high_s8(iscales.val[ib64]));
+ scales.val[0] = vmovl_s16(vget_high_s16(iscales16_1));
+ scales.val[1] = vmovl_s16(vget_high_s16(iscales16_2));
+ for (int l = 0; l < 2; ++l) {
+ uint8x16x2_t bits;
+ bits.val[0] = vld1q_u8(iq4[ibl].qs + 256*ib64 + 128*l + 16);
+ bits.val[1] = vld1q_u8(iq4[ibl].qs + 256*ib64 + 128*l + 48);
+ prepare_iq4_nl_quants_r8(values, m4, bits, qx+0);
+ bits.val[0] = vld1q_u8(iq4[ibl].qs + 256*ib64 + 128*l + 80);
+ bits.val[1] = vld1q_u8(iq4[ibl].qs + 256*ib64 + 128*l +112);
+ prepare_iq4_nl_quants_r8(values, m4, bits, qx+4);
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto y = vld1q_s8_x2(q8.y[iy][ibl].qs+64*ib64+32*l);
+ auto sumi = vdupq_n_s32(0);
+ sumi = vdotq_laneq_s32(sumi, qx[0], y.val[0], 0);
+ sumi = vdotq_laneq_s32(sumi, qx[1], y.val[0], 1);
+ sumi = vdotq_laneq_s32(sumi, qx[2], y.val[0], 2);
+ sumi = vdotq_laneq_s32(sumi, qx[3], y.val[0], 3);
+ sumi = vdotq_laneq_s32(sumi, qx[4], y.val[1], 0);
+ sumi = vdotq_laneq_s32(sumi, qx[5], y.val[1], 1);
+ sumi = vdotq_laneq_s32(sumi, qx[6], y.val[1], 2);
+ sumi = vdotq_laneq_s32(sumi, qx[7], y.val[1], 3);
+ isum[iy] = vmlaq_s32(isum[iy], sumi, scales.val[l]);
+ }
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto d8 = vdupq_n_f32(q8.scale(iy, ibl));
+ acc[2*iy+1] = vfmaq_f32(acc[2*iy+1], vmulq_f32(d4h, d8), vcvtq_f32_s32(isum[iy]));
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
- info.store(ix, iy, acc[iy]);
- acc[iy] = vdupq_n_f32(0.f);
+ info.store(ix+0, iy, acc[2*iy+0]);
+ info.store(ix+4, iy, acc[2*iy+1]);
+ acc[2*iy+0] = acc[2*iy+1] = vdupq_n_f32(0.f);
}
}
}
@@ -12045,81 +12176,54 @@ struct Q6_0_R4_Dequantizer {
template <int nrc_y>
void mul_mat_q8_0_r4_q8_0(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
- GGML_ASSERT(nrc_x%4 == 0);
+ GGML_ASSERT(nrc_x%8 == 0);
Q8<nrc_y, block_q8_0_x4> q8(info);
int nb = n / QK8_0;
GGML_ASSERT(nb%4 == 0);
- float32x4_t acc[nrc_y] = {};
+ float32x4_t acc[2*nrc_y] = {};
+ int8x16_t qx[16];
float d8[4*nrc_y];
- for (int ix = 0; ix < nrc_x; ix += 4) {
- const block_q8_0_x4 * iq8 = (const block_q8_0_x4 *)((const char *)vx + ix*bx);
+ for (int ix = 0; ix < nrc_x; ix += 8) {
+ const block_q8_0_r8 * iq8 = (const block_q8_0_r8 *)((const char *)vx + ix*bx);
for (int ib4 = 0; ib4 < nb/4; ++ib4) {
for (int iy = 0; iy < nrc_y; ++iy) {
vst1q_f32(d8+4*iy, vcvt_f32_f16(vld1_f16((const float16_t *)q8.y[iy][ib4].d)));
}
for (int k = 0; k < 4; ++k) {
- auto scales = vcvt_f32_f16(vld1_f16((const float16_t *)iq8[4*ib4+k].d));
- auto qx1 = vld1q_s8_x4(iq8[4*ib4+k].qs);
- auto qx2 = vld1q_s8_x4(iq8[4*ib4+k].qs+64);
+ auto scales16 = vld1q_f16((const float16_t *)iq8[4*ib4+k].d);
+ auto scales1 = vcvt_f32_f16(vget_low_f16 (scales16));
+ auto scales2 = vcvt_f32_f16(vget_high_f16(scales16));
+ for (int j = 0; j < 16; ++j) qx[j] = vld1q_s8(iq8[4*ib4+k].qs + 16*j);
for (int iy = 0; iy < nrc_y; ++iy) {
auto y = vld1q_s8_x2(q8.y[iy][ib4].qs+32*k);
- auto sumi = vdupq_n_s32(0);
- sumi = vdotq_laneq_s32(sumi, qx1.val[0], y.val[0], 0);
- sumi = vdotq_laneq_s32(sumi, qx1.val[1], y.val[1], 0);
- sumi = vdotq_laneq_s32(sumi, qx1.val[2], y.val[0], 1);
- sumi = vdotq_laneq_s32(sumi, qx1.val[3], y.val[1], 1);
- sumi = vdotq_laneq_s32(sumi, qx2.val[0], y.val[0], 2);
- sumi = vdotq_laneq_s32(sumi, qx2.val[1], y.val[1], 2);
- sumi = vdotq_laneq_s32(sumi, qx2.val[2], y.val[0], 3);
- sumi = vdotq_laneq_s32(sumi, qx2.val[3], y.val[1], 3);
- auto d4d8 = vmulq_f32(scales, vdupq_n_f32(d8[4*iy+k]));
- acc[iy] = vfmaq_f32(acc[iy], d4d8, vcvtq_f32_s32(sumi));
+ auto sumi1 = vdupq_n_s32(0);
+ auto sumi2 = vdupq_n_s32(0);
+ sumi1 = vdotq_laneq_s32(sumi1, qx[0], y.val[0], 0);
+ sumi2 = vdotq_laneq_s32(sumi2, qx[1], y.val[0], 0);
+ sumi1 = vdotq_laneq_s32(sumi1, qx[2], y.val[0], 1);
+ sumi2 = vdotq_laneq_s32(sumi2, qx[3], y.val[0], 1);
+ sumi1 = vdotq_laneq_s32(sumi1, qx[4], y.val[0], 2);
+ sumi2 = vdotq_laneq_s32(sumi2, qx[5], y.val[0], 2);
+ sumi1 = vdotq_laneq_s32(sumi1, qx[6], y.val[0], 3);
+ sumi2 = vdotq_laneq_s32(sumi2, qx[7], y.val[0], 3);
+ sumi1 = vdotq_laneq_s32(sumi1, qx[8+0], y.val[1], 0);
+ sumi2 = vdotq_laneq_s32(sumi2, qx[8+1], y.val[1], 0);
+ sumi1 = vdotq_laneq_s32(sumi1, qx[8+2], y.val[1], 1);
+ sumi2 = vdotq_laneq_s32(sumi2, qx[8+3], y.val[1], 1);
+ sumi1 = vdotq_laneq_s32(sumi1, qx[8+4], y.val[1], 2);
+ sumi2 = vdotq_laneq_s32(sumi2, qx[8+5], y.val[1], 2);
+ sumi1 = vdotq_laneq_s32(sumi1, qx[8+6], y.val[1], 3);
+ sumi2 = vdotq_laneq_s32(sumi2, qx[8+7], y.val[1], 3);
+ auto dy = vdupq_n_f32(d8[4*iy+k]);
+ acc[2*iy+0] = vfmaq_f32(acc[2*iy+0], vmulq_f32(scales1, dy), vcvtq_f32_s32(sumi1));
+ acc[2*iy+1] = vfmaq_f32(acc[2*iy+1], vmulq_f32(scales2, dy), vcvtq_f32_s32(sumi2));
}
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
- info.store(ix, iy, acc[iy]);
- acc[iy] = vdupq_n_f32(0.f);
- }
- }
-}
-
-template <int nrc_y>
-void mul_mat_q8_0_r4_q8_0_128(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
- GGML_ASSERT(nrc_x%4 == 0);
- GGML_ASSERT(n == 128);
- int8x16x4_t qx[8];
- float32x4_t scales[4];
- float32x4_t scales_y[4];
- for (int ix = 0; ix < nrc_x; ix += 4) {
- const block_q8_0_x4 * iq8 = (const block_q8_0_x4 *)((const char *)vx + ix*bx);
- for (int k = 0; k < 4; ++k) {
- scales[k] = vcvt_f32_f16(vld1_f16((const float16_t *)iq8[k].d));
- qx[2*k+0] = vld1q_s8_x4(iq8[k].qs);
- qx[2*k+1] = vld1q_s8_x4(iq8[k].qs+64);
- }
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto by = (const block_q8_0_x4 *)info.src1_row(iy);
- auto d8 = vcvt_f32_f16(vld1_f16((const float16_t *)by->d));
- scales_y[0] = vmulq_laneq_f32(scales[0], d8, 0);
- scales_y[1] = vmulq_laneq_f32(scales[1], d8, 1);
- scales_y[2] = vmulq_laneq_f32(scales[2], d8, 2);
- scales_y[3] = vmulq_laneq_f32(scales[3], d8, 3);
- auto sumf = vdupq_n_f32(0.f);
- for (int k = 0; k < 4; ++k) {
- auto y = vld1q_s8_x2(by->qs+32*k);
- auto sumi = vdupq_n_s32(0);
- sumi = vdotq_laneq_s32(sumi, qx[2*k+0].val[0], y.val[0], 0);
- sumi = vdotq_laneq_s32(sumi, qx[2*k+0].val[1], y.val[1], 0);
- sumi = vdotq_laneq_s32(sumi, qx[2*k+0].val[2], y.val[0], 1);
- sumi = vdotq_laneq_s32(sumi, qx[2*k+0].val[3], y.val[1], 1);
- sumi = vdotq_laneq_s32(sumi, qx[2*k+1].val[0], y.val[0], 2);
- sumi = vdotq_laneq_s32(sumi, qx[2*k+1].val[1], y.val[1], 2);
- sumi = vdotq_laneq_s32(sumi, qx[2*k+1].val[2], y.val[0], 3);
- sumi = vdotq_laneq_s32(sumi, qx[2*k+1].val[3], y.val[1], 3);
- sumf = vfmaq_f32(sumf, scales_y[k], vcvtq_f32_s32(sumi));
- }
- info.store(ix, iy, sumf);
+ info.store(ix+0, iy, acc[2*iy+0]);
+ info.store(ix+4, iy, acc[2*iy+1]);
+ acc[2*iy] = acc[2*iy+1] = vdupq_n_f32(0.f);
}
}
}
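On the NEON side the same structure maps onto vdotq_laneq_s32: each call accumulates, per 32-bit lane, four row quants against the four activation bytes selected by the lane index, so one 128-bit accumulator covers four rows and the sumi1/sumi2 pair covers all eight interleaved rows. A scalar model of a single call (illustrative name, sketch only):

    #include <stdint.h>

    /* sumi[r] += dot(qx[4r..4r+3], y[4*lane..4*lane+3]) for r = 0..3. */
    static void sdot_lane(int32_t sumi[4], const int8_t qx[16], const int8_t y[16], int lane) {
        for (int r = 0; r < 4; ++r)
            for (int i = 0; i < 4; ++i)
                sumi[r] += qx[4*r + i] * y[4*lane + i];
    }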
@@ -12763,22 +12867,25 @@ struct HelperQ80R4 : public BaseHelper<step> {
Base::stride = (D/QK8_0)*sizeof(block_q8_0);
}
- static std::vector<block_q8_0_x4> repack(int nk, const HelperQ80<D, step> q8) {
+ static std::vector<block_q8_0_r8> repack(int nk, const HelperQ80<D, step> q8) {
static_assert(D%QK8_0 == 0);
- GGML_ASSERT(nk%4 == 0);
+ GGML_ASSERT(nk%8 == 0);
constexpr int nblock = D/QK8_0;
- std::vector<block_q8_0_x4> result(nblock * nk/4);
+ std::vector<block_q8_0_r8> result(nblock * nk/8);
auto y = result.data();
- const block_q8_0 * x4[4];
- for (int row = 0; row < nk; row += 4) {
- for (int k = 0; k < 4; ++k) x4[k] = (const block_q8_0 *)(q8.data + (row + k)*q8.stride);
+ const block_q8_0 * x8[8];
+#ifdef __ARM_NEON
+ int8x16x2_t m0, m1, m2, m3;
+#endif
+ for (int row = 0; row < nk; row += 8) {
+ for (int k = 0; k < 8; ++k) x8[k] = (const block_q8_0 *)(q8.data + (row + k)*q8.stride);
for (int ib = 0; ib < nblock; ++ib) {
- for (int k = 0; k < 4; ++k) y[ib].d[k] = x4[k][ib].d;
+ for (int k = 0; k < 8; ++k) y[ib].d[k] = x8[k][ib].d;
#ifdef __AVX2__
- auto m0 = _mm256_loadu_si256((const __m256i *)x4[0][ib].qs);
- auto m1 = _mm256_loadu_si256((const __m256i *)x4[1][ib].qs);
- auto m2 = _mm256_loadu_si256((const __m256i *)x4[2][ib].qs);
- auto m3 = _mm256_loadu_si256((const __m256i *)x4[3][ib].qs);
+ auto m0 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[4][ib].qs), _mm_loadu_si128((const __m128i *)x8[0][ib].qs));
+ auto m1 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[5][ib].qs), _mm_loadu_si128((const __m128i *)x8[1][ib].qs));
+ auto m2 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[6][ib].qs), _mm_loadu_si128((const __m128i *)x8[2][ib].qs));
+ auto m3 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[7][ib].qs), _mm_loadu_si128((const __m128i *)x8[3][ib].qs));
auto t0 = _mm256_unpacklo_epi32(m0, m1);
auto t1 = _mm256_unpacklo_epi32(m2, m3);
auto t2 = _mm256_unpackhi_epi32(m0, m1);
@@ -12791,32 +12898,50 @@ struct HelperQ80R4 : public BaseHelper<step> {
_mm256_storeu_si256((__m256i *)y[ib].qs + 1, m1);
_mm256_storeu_si256((__m256i *)y[ib].qs + 2, m2);
_mm256_storeu_si256((__m256i *)y[ib].qs + 3, m3);
+ m0 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[4][ib].qs+1), _mm_loadu_si128((const __m128i *)x8[0][ib].qs+1));
+ m1 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[5][ib].qs+1), _mm_loadu_si128((const __m128i *)x8[1][ib].qs+1));
+ m2 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[6][ib].qs+1), _mm_loadu_si128((const __m128i *)x8[2][ib].qs+1));
+ m3 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[7][ib].qs+1), _mm_loadu_si128((const __m128i *)x8[3][ib].qs+1));
+ t0 = _mm256_unpacklo_epi32(m0, m1);
+ t1 = _mm256_unpacklo_epi32(m2, m3);
+ t2 = _mm256_unpackhi_epi32(m0, m1);
+ t3 = _mm256_unpackhi_epi32(m2, m3);
+ m0 = _mm256_unpacklo_epi64(t0, t1);
+ m1 = _mm256_unpackhi_epi64(t0, t1);
+ m2 = _mm256_unpacklo_epi64(t2, t3);
+ m3 = _mm256_unpackhi_epi64(t2, t3);
+ _mm256_storeu_si256((__m256i *)y[ib].qs + 4, m0);
+ _mm256_storeu_si256((__m256i *)y[ib].qs + 5, m1);
+ _mm256_storeu_si256((__m256i *)y[ib].qs + 6, m2);
+ _mm256_storeu_si256((__m256i *)y[ib].qs + 7, m3);
#elif defined __ARM_NEON
- auto m0 = vld1q_s8_x2(x4[0][ib].qs);
- auto m1 = vld1q_s8_x2(x4[1][ib].qs);
- auto m2 = vld1q_s8_x2(x4[2][ib].qs);
- auto m3 = vld1q_s8_x2(x4[3][ib].qs);
- auto row01 = vtrnq_s32(vreinterpretq_s32_s8(m0.val[0]), vreinterpretq_s32_s8(m1.val[0]));
- auto row23 = vtrnq_s32(vreinterpretq_s32_s8(m2.val[0]), vreinterpretq_s32_s8(m3.val[0]));
- m0.val[0] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0])));
- m1.val[0] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1])));
- m2.val[0] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0])));
- m3.val[0] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1])));
- row01 = vtrnq_s32(vreinterpretq_s32_s8(m0.val[1]), vreinterpretq_s32_s8(m1.val[1]));
- row23 = vtrnq_s32(vreinterpretq_s32_s8(m2.val[1]), vreinterpretq_s32_s8(m3.val[1]));
- m0.val[1] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0])));
- m1.val[1] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1])));
- m2.val[1] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0])));
- m3.val[1] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1])));
- vst1q_s8_x2(y[ib].qs + 0, m0);
- vst1q_s8_x2(y[ib].qs + 32, m1);
- vst1q_s8_x2(y[ib].qs + 64, m2);
- vst1q_s8_x2(y[ib].qs + 96, m3);
+ for (int l = 0; l < 2; ++l) {
+ m0.val[0] = vld1q_s8(x8[0][ib].qs+16*l); m0.val[1] = vld1q_s8(x8[4][ib].qs+16*l);
+ m1.val[0] = vld1q_s8(x8[1][ib].qs+16*l); m1.val[1] = vld1q_s8(x8[5][ib].qs+16*l);
+ m2.val[0] = vld1q_s8(x8[2][ib].qs+16*l); m2.val[1] = vld1q_s8(x8[6][ib].qs+16*l);
+ m3.val[0] = vld1q_s8(x8[3][ib].qs+16*l); m3.val[1] = vld1q_s8(x8[7][ib].qs+16*l);
+ auto row01 = vtrnq_s32(vreinterpretq_s32_s8(m0.val[0]), vreinterpretq_s32_s8(m1.val[0]));
+ auto row23 = vtrnq_s32(vreinterpretq_s32_s8(m2.val[0]), vreinterpretq_s32_s8(m3.val[0]));
+ m0.val[0] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0])));
+ m1.val[0] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1])));
+ m2.val[0] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0])));
+ m3.val[0] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1])));
+ row01 = vtrnq_s32(vreinterpretq_s32_s8(m0.val[1]), vreinterpretq_s32_s8(m1.val[1]));
+ row23 = vtrnq_s32(vreinterpretq_s32_s8(m2.val[1]), vreinterpretq_s32_s8(m3.val[1]));
+ m0.val[1] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0])));
+ m1.val[1] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1])));
+ m2.val[1] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0])));
+ m3.val[1] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1])));
+ vst1q_s8_x2(y[ib].qs + 0 + 128*l, m0);
+ vst1q_s8_x2(y[ib].qs + 32 + 128*l, m1);
+ vst1q_s8_x2(y[ib].qs + 64 + 128*l, m2);
+ vst1q_s8_x2(y[ib].qs + 96 + 128*l, m3);
+ }
#else
for (int l = 0; l < 4; ++l) {
- for (int k = 0; k < 4; ++k) for (int i = 0; i < 4; ++i) {
- y[ib].qs[32*l+4*k+i+ 0] = x4[k][ib].qs[i+4*l+ 0];
- y[ib].qs[32*l+4*k+i+16] = x4[k][ib].qs[i+4*l+16];
+ for (int k = 0; k < 8; ++k) for (int i = 0; i < 4; ++i) {
+ y[ib].qs[32*l+4*k+i+ 0] = x8[k][ib].qs[i+4*l+ 0];
+ y[ib].qs[32*l+4*k+i+128] = x8[k][ib].qs[i+4*l+16];
}
}
#endif
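Both repack paths above, the AVX2 unpacklo/unpackhi sequence and the NEON vtrnq_s32/vtrn1q_s64 sequence, amount to a 4x4 transpose of 32-bit elements per 128-bit lane, so the same four-byte quant group of four consecutive rows becomes contiguous in the repacked block. A scalar equivalent, as a sketch only:

    #include <stdint.h>

    /* In-place 4x4 transpose of 32-bit elements, the scalar counterpart of
     * the SIMD shuffles used by repack(). */
    static void transpose_4x4_u32(uint32_t m[4][4]) {
        for (int r = 0; r < 4; ++r)
            for (int c = r + 1; c < 4; ++c) {
                uint32_t t = m[r][c];
                m[r][c] = m[c][r];
                m[c][r] = t;
            }
    }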
@@ -12826,7 +12951,7 @@ struct HelperQ80R4 : public BaseHelper<step> {
return result;
}
- std::vector<block_q8_0_x4> r4;
+ std::vector<block_q8_0_r8> r4;
};
template <int D, int step>
@@ -13370,78 +13495,6 @@ struct FlashQKV {
qkv_cache_t qkv_cache[D*q_step] = {};
};
-#ifdef HAVE_FANCY_SIMD
-template <int nrc_y>
-static void mul_mat_q8_0_r4_q8_1_128([[maybe_unused]] int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
- GGML_ASSERT(nrc_x%8 == 0);
- GGML_ASSERT(n == 128);
- //Q8<nrc_y, block_q8_1_x4> q8(info);
- __m512i qx[16];
- __m512 scales[4];
- __m512 scales_m[4];
- __m512 dy[4];
- auto m127 = _mm512_set1_epi8(127);
- for (int ix = 0; ix < nrc_x; ix += 8) {
- const block_q8_0_x4 * q8l = (const block_q8_0_x4 *)((const char *)vx + (ix+0)*bx);
- const block_q8_0_x4 * q8h = (const block_q8_0_x4 *)((const char *)vx + (ix+4)*bx);
- for (int k = 0; k < 4; ++k) {
- auto scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)q8l[k].d));
- auto scales1 = _mm256_set_m128(scales128, scales128);
- scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)q8h[k].d));
- auto scales2 = _mm256_set_m128(scales128, scales128);
- scales[k] = _mm512_insertf32x8(_mm512_castps256_ps512(scales1), scales2, 1);
- scales_m[k] = _mm512_mul_ps(scales[k], _mm512_set1_ps(-63.5f));
- qx[4*k+0] = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)q8l[k].qs+0)),
- _mm256_loadu_si256((const __m256i *)q8h[k].qs+0), 1);
- qx[4*k+1] = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)q8l[k].qs+1)),
- _mm256_loadu_si256((const __m256i *)q8h[k].qs+1), 1);
- qx[4*k+2] = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)q8l[k].qs+2)),
- _mm256_loadu_si256((const __m256i *)q8h[k].qs+2), 1);
- qx[4*k+3] = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)q8l[k].qs+3)),
- _mm256_loadu_si256((const __m256i *)q8h[k].qs+3), 1);
- qx[4*k+0] = _mm512_add_epi8(qx[4*k+0], m127);
- qx[4*k+1] = _mm512_add_epi8(qx[4*k+1], m127);
- qx[4*k+2] = _mm512_add_epi8(qx[4*k+2], m127);
- qx[4*k+3] = _mm512_add_epi8(qx[4*k+3], m127);
- }
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto by = (const block_q8_1_x4 *)info.src1_row(iy);
- //auto dall = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)q8.y[iy][0].d));
- auto dall = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)by->d));
- auto d128 = _mm256_castps256_ps128(dall);
- auto m128 = _mm256_extractf128_ps(dall, 1);
- auto m256 = _mm256_set_m128(m128, m128);
- auto m512 = _mm512_insertf32x8(_mm512_castps256_ps512(m256), m256, 1);
- auto sumf = _mm512_mul_ps(scales_m[0], _mm512_shuffle_ps(m512, m512, 0x00));
- sumf = _mm512_fmadd_ps(scales_m[1], _mm512_shuffle_ps(m512, m512, 0x55), sumf);
- sumf = _mm512_fmadd_ps(scales_m[2], _mm512_shuffle_ps(m512, m512, 0xaa), sumf);
- sumf = _mm512_fmadd_ps(scales_m[3], _mm512_shuffle_ps(m512, m512, 0xff), sumf);
- auto d256 = _mm256_set_m128(d128, d128);
- auto d512 = _mm512_insertf32x8(_mm512_castps256_ps512(d256), d256, 1);
- dy[0] = _mm512_mul_ps(scales[0], _mm512_shuffle_ps(d512, d512, 0x00));
- dy[1] = _mm512_mul_ps(scales[1], _mm512_shuffle_ps(d512, d512, 0x55));
- dy[2] = _mm512_mul_ps(scales[2], _mm512_shuffle_ps(d512, d512, 0xaa));
- dy[3] = _mm512_mul_ps(scales[3], _mm512_shuffle_ps(d512, d512, 0xff));
- for (int k = 0; k < 4; ++k) {
- //auto y8 = _mm256_loadu_si256((const __m256i*)q8.y[iy][0].qs+k);
- auto y8 = _mm256_loadu_si256((const __m256i*)by->qs+k);
- auto y = _mm512_inserti32x8(_mm512_castsi256_si512(y8), y8, 1);
- auto sumi = _mm512_setzero_si512();
- sumi = _mm512_dpbusd_epi32(sumi, qx[4*k+0], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00)));
- sumi = _mm512_dpbusd_epi32(sumi, qx[4*k+1], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55)));
- sumi = _mm512_dpbusd_epi32(sumi, qx[4*k+2], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa)));
- sumi = _mm512_dpbusd_epi32(sumi, qx[4*k+3], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff)));
- sumf = _mm512_fmadd_ps(dy[k], _mm512_cvtepi32_ps(sumi), sumf);
- }
- auto sum1 = _mm_add_ps(_mm512_extractf32x4_ps(sumf, 0), _mm512_extractf32x4_ps(sumf, 1));
- auto sum2 = _mm_add_ps(_mm512_extractf32x4_ps(sumf, 2), _mm512_extractf32x4_ps(sumf, 3));
- info.store(ix+0, iy, sum1);
- info.store(ix+4, iy, sum2);
- }
- }
-}
-#endif
-
template <int D, int q_step, int k_step>
struct FlashQKfp32 {
static_assert(D%F16::block_size == 0 && D <= 256);
@@ -13706,45 +13759,10 @@ struct FlashQKfp32 {
}
else if constexpr (std::is_same_v<KHelper, HelperQ80R4<D, k_step>>) {
#ifdef __aarch64__
- if constexpr (D == 128) {
- if (q_step >= 64 && nq >= 64) {
- return std::make_pair(mul_mat_q8_0_r4_q8_0_128<64>, 64);
- }
- else if (q_step >= 32 && nq >= 32) {
- return std::make_pair(mul_mat_q8_0_r4_q8_0_128<32>, 32);
- }
- else if (q_step >= 16 && nq >= 16) {
- return std::make_pair(mul_mat_q8_0_r4_q8_0_128<16>, 16);
- }
- else {
- MAKE_FUNCS_ONLY_NRC(mul_mat_q8_0_r4_q8_0_128, nq);
- }
- } else {
- MAKE_FUNCS_ONLY_NRC(mul_mat_q8_0_r4_q8_0, nq);
- }
- //MAKE_FUNCS_ONLY_NRC(mul_mat_q8_0_r4_q8_0, nq);
-#else
-#ifdef HAVE_FANCY_SIMD
- if constexpr (D == 128) {
- if (q_step >= 64 && nq >= 64) {
- return std::make_pair(mul_mat_q8_0_r4_q8_1_128<64>, 64);
- }
- else if (q_step >= 32 && nq >= 32) {
- return std::make_pair(mul_mat_q8_0_r4_q8_1_128<32>, 32);
- }
- else if (q_step >= 16 && nq >= 16) {
- return std::make_pair(mul_mat_q8_0_r4_q8_1_128<16>, 16);
- }
- else {
- MAKE_FUNCS_ONLY_NRC(mul_mat_q8_0_r4_q8_1_128, nq);
- }
- } else {
- MAKE_FUNCS_ONLY_NRC(mul_mat_q8_0_r4_q8_1, nq);
- }
+ MAKE_FUNCS_ONLY_NRC(mul_mat_q8_0_r4_q8_0, nq);
#else
MAKE_FUNCS_ONLY_NRC(mul_mat_q8_0_r4_q8_1, nq);
#endif
-#endif
}
else if constexpr (std::is_same_v<KHelper, HelperQ41<D, k_step>>) {
#ifdef __aarch64__
diff --git a/ggml/src/iqk/iqk_quantize.cpp b/ggml/src/iqk/iqk_quantize.cpp
index 221bc48c..59a36c5c 100644
--- a/ggml/src/iqk/iqk_quantize.cpp
+++ b/ggml/src/iqk/iqk_quantize.cpp
@@ -3709,63 +3709,63 @@ void vec_dot_q4_0_r4_q8_0(int n, float * s, size_t bs, const void * vx, size_t b
//
// ========================================= q8_0_r4
//
-void quantize_row_q8_0_r4_ref(const float * x, block_q8_0_x4 * y, int64_t k) {
+void quantize_row_q8_0_r4_ref(const float * x, block_q8_0_r8 * y, int64_t k) {
// we assume we are called with 4 rows
- quantize_q8_0_r4(x, (void *)y, 4, k/4, nullptr);
+ quantize_q8_0_r4(x, (void *)y, 8, k/8, nullptr);
}
void quantize_row_q8_0_r4(const float * x, void * y, int64_t k) {
// we assume we are called with 4 rows
- quantize_q8_0_r4(x, y, 4, k/4, nullptr);
+ quantize_q8_0_r4(x, y, 8, k/8, nullptr);
}
-static void repack_q8_0(int nrows, int n_per_row, const block_q8_0 * x, block_q8_0_x4 * y) {
- GGML_ASSERT(nrows%4 == 0);
+static void repack_q8_0(int nrows, int n_per_row, const block_q8_0 * x, block_q8_0_r8 * y) {
+ GGML_ASSERT(nrows%8 == 0);
GGML_ASSERT(n_per_row%QK8_0 == 0);
int nblock = n_per_row/QK8_0;
- const block_q8_0 * x4[4];
- for (int row = 0; row < nrows; row += 4) {
- for (int k = 0; k < 4; ++k) x4[k] = x + nblock*k;
+ const block_q8_0 * x8[8];
+ for (int row = 0; row < nrows; row += 8) {
+ for (int k = 0; k < 8; ++k) x8[k] = x + nblock*k;
for (int ib = 0; ib < nblock; ++ib) {
- for (int k = 0; k < 4; ++k) y[ib].d[k] = x4[k][ib].d;
+ for (int k = 0; k < 8; ++k) y[ib].d[k] = x8[k][ib].d;
for (int l = 0; l < 4; ++l) {
- for (int k = 0; k < 4; ++k) for (int i = 0; i < 4; ++i) {
- y[ib].qs[32*l+4*k+i+ 0] = x4[k][ib].qs[i+4*l+ 0];
- y[ib].qs[32*l+4*k+i+16] = x4[k][ib].qs[i+4*l+16];
+ for (int k = 0; k < 8; ++k) for (int i = 0; i < 4; ++i) {
+ y[ib].qs[32*l+4*k+i+ 0] = x8[k][ib].qs[i+4*l+ 0];
+ y[ib].qs[32*l+4*k+i+128] = x8[k][ib].qs[i+4*l+16];
}
}
}
- x += 4*nblock;
+ x += 8*nblock;
y += nblock;
}
}
size_t quantize_q8_0_r4(const float * src, void * dst, int64_t nrows, int64_t n_per_row, const float * imatrix) {
- GGML_ASSERT(nrows%4 == 0);
+ GGML_ASSERT(nrows%8 == 0);
auto row_size_0 = ggml_row_size(GGML_TYPE_Q8_0, n_per_row);
- std::vector<char> qtmp(4*row_size_0);
+ std::vector<char> qtmp(8*row_size_0);
char * qrow = (char *)dst;
- for (int row = 0; row < nrows; row += 4) {
- quantize_q8_0(src, qtmp.data(), 4, n_per_row, imatrix);
- repack_q8_0(4, n_per_row, (const block_q8_0 *)qtmp.data(), (block_q8_0_x4 *)qrow);
- src += 4*n_per_row;
- qrow += 4*row_size_0;
+ for (int row = 0; row < nrows; row += 8) {
+ quantize_q8_0(src, qtmp.data(), 8, n_per_row, imatrix);
+ repack_q8_0(8, n_per_row, (const block_q8_0 *)qtmp.data(), (block_q8_0_r8 *)qrow);
+ src += 8*n_per_row;
+ qrow += 8*row_size_0;
}
return nrows*row_size_0;
}
-void dequantize_row_q8_0_r4(const block_q8_0_x4 * x, float * y, int64_t k) {
+void dequantize_row_q8_0_r4(const block_q8_0_r8 * x, float * y, int64_t k) {
// we assume we are called with 4 rows
- int n_per_row = k/4;
+ int n_per_row = k/8;
int nb = n_per_row/QK8_0;
- float * yk[4];
- for (int k = 0; k < 4; ++k) yk[k] = y + k*n_per_row;
+ float * yk[8];
+ for (int k = 0; k < 8; ++k) yk[k] = y + k*n_per_row;
for (int ib = 0; ib < nb; ++ib) {
- for (int k = 0; k < 4; ++k) {
+ for (int k = 0; k < 8; ++k) {
float scale = GGML_FP16_TO_FP32(x[ib].d[k]);
for (int l = 0; l < 4; ++l) for (int i = 0; i < 4; ++i) {
- yk[k][QK8_0*ib+4*l+i+ 0] = scale * x[ib].qs[QK8_0*l+4*k+i+ 0];
- yk[k][QK8_0*ib+4*l+i+16] = scale * x[ib].qs[QK8_0*l+4*k+i+16];
+ yk[k][QK8_0*ib+4*l+i+ 0] = scale * x[ib].qs[32*l+4*k+i+ 0];
+ yk[k][QK8_0*ib+4*l+i+16] = scale * x[ib].qs[32*l+4*k+i+128];
}
}
}
@@ -3987,93 +3987,77 @@ void vec_dot_q6_0_r4_q8_0(int n, float * s, size_t bs, const void * vx, size_t b
//
void quantize_row_iq4_xs_r4_ref(const float * x, block_iq4_xs_r4 * y, int64_t k) {
- quantize_iq4_xs_r4(x, (void *)y, 4, k/4, nullptr);
+ quantize_iq4_xs_r4(x, (void *)y, 8, k/8, nullptr);
}
void quantize_row_iq4_xs_r4(const float * x, void * y, int64_t k) {
- quantize_iq4_xs_r4(x, y, 4, k/4, nullptr);
+ quantize_iq4_xs_r4(x, y, 8, k/8, nullptr);
}
static void repack_iq4_xs(int nrows, int n_per_row, const block_iq4_xs * x, block_iq4_xs_r4 * y) {
- GGML_ASSERT(nrows%4 == 0);
+ GGML_ASSERT(nrows%8 == 0);
GGML_ASSERT(n_per_row%QK_K == 0);
int nblock = n_per_row/QK_K;
- const block_iq4_xs * x4[4];
- for (int row = 0; row < nrows; row += 4) {
- for (int k = 0; k < 4; ++k) x4[k] = x + nblock*k;
+ const block_iq4_xs * x8[8];
+ for (int row = 0; row < nrows; row += 8) {
+ for (int k = 0; k < 8; ++k) x8[k] = x + nblock*k;
for (int ibl = 0; ibl < nblock; ++ibl) {
- std::memset(y[ibl].scales_l, 0, QK_K/16);
- std::memset(y[ibl].scales_h, 0, QK_K/32);
- for (int k = 0; k < 4; ++k) {
- y[ibl].d[k] = x4[k][ibl].d;
+ std::memset(y[ibl].scales_l, 0, QK_K/8);
+ std::memset(y[ibl].scales_h, 0, QK_K/16);
+ for (int k = 0; k < 8; ++k) {
+ y[ibl].d[k] = x8[k][ibl].d;
for (int ib = 0; ib < QK_K/32; ++ib) {
- uint8_t sl = (x4[k][ibl].scales_l[ib/2] >> 4*(ib%2)) & 0xf;
- uint8_t sh = (x4[k][ibl].scales_h >> 2*ib) & 3;
- int i = 4*ib + k;
- y[ibl].scales_l[i%16] |= (sl << 4*(i/16));
- y[ibl].scales_h[i%8 ] |= (sh << 2*(i/8));
- }
- }
- for (int ib = 0; ib < QK_K/32; ++ib) {
- for (int k = 0; k < 4; ++k) for (int i = 0; i < 4; ++i) {
- y[ibl].qs[64*ib+4*k+i+ 0] = (x4[k][ibl].qs[16*ib+i+0] & 0xf) | ((x4[k][ibl].qs[16*ib+i+ 8] & 0x0f) << 4); // 0....3 + 8...11 from each row
- y[ibl].qs[64*ib+4*k+i+16] = (x4[k][ibl].qs[16*ib+i+0] >> 4) | ((x4[k][ibl].qs[16*ib+i+ 8] & 0xf0)); // 16...19 + 24...27 from each row
- y[ibl].qs[64*ib+4*k+i+32] = (x4[k][ibl].qs[16*ib+i+4] & 0xf) | ((x4[k][ibl].qs[16*ib+i+12] & 0x0f) << 4); // 4....7 + 12...15 from each row
- y[ibl].qs[64*ib+4*k+i+48] = (x4[k][ibl].qs[16*ib+i+4] >> 4) | ((x4[k][ibl].qs[16*ib+i+12] & 0xf0)); // 20...23 + 28...31 from each row
+ uint8_t sl = (x8[k][ibl].scales_l[ib/2] >> 4*(ib%2)) & 0xf;
+ uint8_t sh = (x8[k][ibl].scales_h >> 2*ib) & 3;
+ int i = 8*ib + k;
+ y[ibl].scales_l[i%32] |= (sl << 4*(i/32));
+ y[ibl].scales_h[i%16] |= (sh << 2*(i/16));
+ for (int i = 0; i < 4; ++i) {
+ y[ibl].qs[128*ib+4*k+i+ 0] = (x8[k][ibl].qs[16*ib+i+0] & 0xf) | ((x8[k][ibl].qs[16*ib+i+ 4] & 0xf) << 4);
+ y[ibl].qs[128*ib+4*k+i+32] = (x8[k][ibl].qs[16*ib+i+8] & 0xf) | ((x8[k][ibl].qs[16*ib+i+12] & 0xf) << 4);
+ y[ibl].qs[128*ib+4*k+i+64] = (x8[k][ibl].qs[16*ib+i+0] >> 4) | ((x8[k][ibl].qs[16*ib+i+ 4] >> 4) << 4);
+ y[ibl].qs[128*ib+4*k+i+96] = (x8[k][ibl].qs[16*ib+i+8] >> 4) | ((x8[k][ibl].qs[16*ib+i+12] >> 4) << 4);
+ }
}
}
}
- x += 4*nblock;
+ x += 8*nblock;
y += nblock;
}
}
size_t quantize_iq4_xs_r4(const float * src, void * dst, int64_t nrows, int64_t n_per_row, const float * imatrix) {
- GGML_ASSERT(nrows%4 == 0);
+ GGML_ASSERT(nrows%8 == 0);
GGML_ASSERT(n_per_row%QK_K == 0);
char * qcur = (char *)dst;
auto row_size = ggml_row_size(GGML_TYPE_IQ4_XS, n_per_row);
- std::vector<char> qtmp(4*row_size);
- for (int row = 0; row < nrows; row += 4) {
- quantize_iq4_xs(src, (void *)qtmp.data(), 4, n_per_row, imatrix);
- repack_iq4_xs(4, n_per_row, (const block_iq4_xs *)qtmp.data(), (block_iq4_xs_r4 *)qcur);
- qcur += 4*row_size;
- src += 4*n_per_row;
+ std::vector<char> qtmp(8*row_size);
+ for (int row = 0; row < nrows; row += 8) {
+ quantize_iq4_xs(src, (void *)qtmp.data(), 8, n_per_row, imatrix);
+ repack_iq4_xs(8, n_per_row, (const block_iq4_xs *)qtmp.data(), (block_iq4_xs_r4 *)qcur);
+ qcur += 8*row_size;
+ src += 8*n_per_row;
}
return nrows*row_size;
}
void dequantize_row_iq4_xs_r4(const block_iq4_xs_r4 * x, float * y, int64_t k) {
- auto n_per_row = k/4;
- float * y4[4] = {y, y + n_per_row, y + 2*n_per_row, y + 3*n_per_row};
+ auto n_per_row = k/8;
+ float * y8[8];
+ for (int k = 0; k < 8; ++k) y8[k] = y + n_per_row*k;
int nblock = n_per_row/QK_K;
for (int ibl = 0; ibl < nblock; ++ibl) {
- for (int k = 0; k < 4; ++k) {
+ for (int k = 0; k < 8; ++k) {
const float d = GGML_FP16_TO_FP32(x[ibl].d[k]);
for (int ib = 0; ib < QK_K/32; ++ib) {
- int is = 4*ib + k;
- float dl = d * ((((x[ibl].scales_l[is%16] >> 4*(is/16)) & 0xf) | (((x[ibl].scales_h[is%8] >> 2*(is/8)) & 3) << 4)) - 32);
- for (int i = 0; i < 4; ++i) {
- y4[k][QK_K*ibl+32*ib+i+ 0] = dl * iq4k_values[x[ibl].qs[64*ib+4*k+i+ 0] & 0xf];
- y4[k][QK_K*ibl+32*ib+i+ 8] = dl * iq4k_values[x[ibl].qs[64*ib+4*k+i+ 0] >> 4];
- y4[k][QK_K*ibl+32*ib+i+16] = dl * iq4k_values[x[ibl].qs[64*ib+4*k+i+16] & 0xf];
- y4[k][QK_K*ibl+32*ib+i+24] = dl * iq4k_values[x[ibl].qs[64*ib+4*k+i+16] >> 4];
- y4[k][QK_K*ibl+32*ib+i+ 4] = dl * iq4k_values[x[ibl].qs[64*ib+4*k+i+32] & 0xf];
- y4[k][QK_K*ibl+32*ib+i+12] = dl * iq4k_values[x[ibl].qs[64*ib+4*k+i+32] >> 4];
- y4[k][QK_K*ibl+32*ib+i+20] = dl * iq4k_values[x[ibl].qs[64*ib+4*k+i+48] & 0xf];
- y4[k][QK_K*ibl+32*ib+i+28] = dl * iq4k_values[x[ibl].qs[64*ib+4*k+i+48] >> 4];
+ int is = 8*ib + k;
+ float dl = d * ((((x[ibl].scales_l[is%32] >> 4*(is/32)) & 0xf) | (((x[ibl].scales_h[is%16] >> 2*(is/16)) & 3) << 4)) - 32);
+ for (int l = 0; l < 4; ++l) for (int i = 0; i < 4; ++i) {
+ y8[k][QK_K*ibl+32*ib+8*l+i+0] = dl * iq4k_values[x[ibl].qs[128*ib+4*k+i+32*l] & 0xf];
+ y8[k][QK_K*ibl+32*ib+8*l+i+4] = dl * iq4k_values[x[ibl].qs[128*ib+4*k+i+32*l] >> 4];
}
}
}
- //dequantize_row_iq4_xs(x + ib, ytmp, QK_K);
- //for (int k = 0; k < 4; ++k) {
- // for (int l = 0; l < 16; ++l) {
- // for (int i = 0; i < 4; ++i) {
- // //y4[k][ib*kBlockSize + i + 16*(l%4) + 4*(l/4)] = ytmp[16*l + 4*k + i];
- // y4[k][ib*kBlockSize + i + 8*(l%8) + 4*(l/8)] = ytmp[16*l + 4*k + i];
- // }
- // }
- //}
}
}
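The iq4_xs_r4 scale packing above spreads the 64 six-bit sub-block scales of eight rows over scales_l (two per byte) and scales_h (four per byte). A sketch of the decode for row k and 32-value sub-block ib, mirroring dequantize_row_iq4_xs_r4; the helper is illustrative and takes raw pointers instead of the struct:

    #include <stdint.h>

    /* Returns the signed 6-bit scale (-32..31) of row k (0..7), sub-block ib (0..7). */
    static inline int iq4_xs_r4_scale(const uint8_t *scales_l, const uint8_t *scales_h, int k, int ib) {
        int is = 8*ib + k;
        int lo = (scales_l[is%32] >> 4*(is/32)) & 0xf;  /* low 4 bits  */
        int hi = (scales_h[is%16] >> 2*(is/16)) & 3;    /* high 2 bits */
        return (lo | (hi << 4)) - 32;
    }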
@@ -6063,7 +6047,7 @@ void iqk_repack_tensor(struct ggml_tensor * tensor) {
{ GGML_TYPE_IQ3_K, { GGML_TYPE_IQ3_K_R4, 4, (Repack::repack_func)repack_iq3_k} },
{ GGML_TYPE_IQ4_K, { GGML_TYPE_IQ4_K_R4, 4, (Repack::repack_func)repack_iq4_k} },
{ GGML_TYPE_IQ5_K, { GGML_TYPE_IQ5_K_R4, 4, (Repack::repack_func)repack_iq5_k} },
- { GGML_TYPE_IQ4_XS, { GGML_TYPE_IQ4_XS_R4, 4, (Repack::repack_func)repack_iq4_xs} },
+ { GGML_TYPE_IQ4_XS, { GGML_TYPE_IQ4_XS_R4, 8, (Repack::repack_func)repack_iq4_xs} },
{ GGML_TYPE_IQ4_KS, { GGML_TYPE_IQ4_KS_R4, 4, (Repack::repack_func)repack_iq4_ks} },
{ GGML_TYPE_IQ4_NL, { GGML_TYPE_IQ4_NL_R4, 4, (Repack::repack_func)repack_iq4_nl} },
{ GGML_TYPE_IQ2_BN, { GGML_TYPE_IQ2_BN_R4, 4, (Repack::repack_func)repack_iq2_bn} },
@@ -6080,7 +6064,7 @@ void iqk_repack_tensor(struct ggml_tensor * tensor) {
{ GGML_TYPE_Q4_0, { GGML_TYPE_Q4_0_R4, 4, (Repack::repack_func)repack_q4_0} },
{ GGML_TYPE_Q5_0, { GGML_TYPE_Q5_0_R4, 4, (Repack::repack_func)repack_q5_0} },
{ GGML_TYPE_Q6_0, { GGML_TYPE_Q6_0_R4, 4, (Repack::repack_func)repack_q6_0} },
- { GGML_TYPE_Q8_0, { GGML_TYPE_Q8_0_R4, 4, (Repack::repack_func)repack_q8_0} },
+ { GGML_TYPE_Q8_0, { GGML_TYPE_Q8_0_R4, 8, (Repack::repack_func)repack_q8_0} },
{ GGML_TYPE_Q8_K, { GGML_TYPE_Q8_K_R8, 8, (Repack::repack_func)repack_q8_k} },
#ifdef __AVX512BF16__
{ GGML_TYPE_BF16, { GGML_TYPE_BF16_R16, 16, (Repack::repack_func)repack_bf16<ggml_bf16_t>}},
diff --git a/ggml/src/iqk/iqk_quantize.h b/ggml/src/iqk/iqk_quantize.h
index 729b0ec0..64860b4d 100644
--- a/ggml/src/iqk/iqk_quantize.h
+++ b/ggml/src/iqk/iqk_quantize.h
@@ -73,10 +73,10 @@ size_t quantize_q4_0_r4(const float * GGML_RESTRICT src, void * GGML_RESTRICT ds
void dequantize_row_q4_0_r4(const block_iq4_nl_r4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
void vec_dot_q4_0_r4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
-void quantize_row_q8_0_r4_ref(const float * GGML_RESTRICT x, block_q8_0_x4 * GGML_RESTRICT y, int64_t k);
+void quantize_row_q8_0_r4_ref(const float * GGML_RESTRICT x, block_q8_0_r8 * GGML_RESTRICT y, int64_t k);
void quantize_row_q8_0_r4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
size_t quantize_q8_0_r4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
-void dequantize_row_q8_0_r4(const block_q8_0_x4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
+void dequantize_row_q8_0_r4(const block_q8_0_r8 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
void vec_dot_q8_0_r4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void quantize_row_q5_0_r4_ref(const float * GGML_RESTRICT x, block_q5_0_r4 * GGML_RESTRICT y, int64_t k);
diff --git a/src/llama.cpp b/src/llama.cpp
index c2bc5cc0..836fd97a 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -16906,8 +16906,8 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
else chunk_size_multiplier = 4;
}
else if (new_type == GGML_TYPE_IQ4_XS_R4) {
- if (tensor->ne[1] % 4 != 0) new_type = GGML_TYPE_IQ4_XS;
- else chunk_size_multiplier = 4;
+ if (tensor->ne[1] % 8 != 0) new_type = GGML_TYPE_IQ4_XS;
+ else chunk_size_multiplier = 8;
}
else if (new_type == GGML_TYPE_Q4_0_R4) {
if (tensor->ne[1] % 4 != 0) new_type = GGML_TYPE_Q4_0;
@@ -16922,8 +16922,8 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
else chunk_size_multiplier = 4;
}
else if (new_type == GGML_TYPE_Q8_0_R4) {
- if (tensor->ne[1] % 4 != 0) new_type = GGML_TYPE_Q8_0;
- else chunk_size_multiplier = 4;
+ if (tensor->ne[1] % 8 != 0) new_type = GGML_TYPE_Q8_0;
+ else chunk_size_multiplier = 8;
}
else if (new_type == GGML_TYPE_Q2_K_R4) {
if (tensor->ne[1] % 4 != 0) new_type = GGML_TYPE_Q2_K;