summaryrefslogtreecommitdiff
path: root/ggml/src
diff options
context:
space:
mode:
authorKawrakow <iwankawrakow@gmail.com>2024-12-02 07:25:39 +0100
committerGitHub <noreply@github.com>2024-12-02 07:25:39 +0100
commit6d0462d4a39085a9f9da04e0a5fc7cc9d4578818 (patch)
treeb7fd71bda09bb8e2315feff8b6128ad0b7cbefc7 /ggml/src
parent8ad84b9fab9570c36220cb791f9a67a4d2c7fd2f (diff)
IQ4_NL_X4 (#118)
* Adding iq4_nl_x4 Looks very promising - I get PP-512(LLaMA-3.1-8B) = 230 t/s on the Ryzen-7950X! This is faster than any other quant and ~40% faster than iq4_nl. * iq4_nl_x4: getting amazing This Zen4 variant gets us to PP-512(LLaMA-3.1-8B) = 263 t/s! * iq4_nl_x4: AVX2 Here we gain only 25% compared to iq4_nl * iq4_nl_x4: NEON On M2-Max we get PP-512(LLaMA-3.1-8B) = 109.7 t/s, up from 82.4 t/s for iq4_nl. * iq4_nl_x4: minor NEON improvement and cleanup This gets us to 110.3 t/s. In comparison, IQ4_NL_4_4 in mainline llama.cpp achieves 92.3 t/s. * iq4_nl_x4: NEON specialization for matrix x vector --------- Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
Diffstat (limited to 'ggml/src')
-rw-r--r--ggml/src/ggml-common.h5
-rw-r--r--ggml/src/ggml-quants.c1
-rw-r--r--ggml/src/ggml.c26
-rw-r--r--ggml/src/iqk/iqk_mul_mat.cpp230
-rw-r--r--ggml/src/iqk/iqk_quantize.cpp114
-rw-r--r--ggml/src/iqk/iqk_quantize.h6
6 files changed, 374 insertions, 8 deletions
diff --git a/ggml/src/ggml-common.h b/ggml/src/ggml-common.h
index f0c1ae68..2af3323d 100644
--- a/ggml/src/ggml-common.h
+++ b/ggml/src/ggml-common.h
@@ -419,6 +419,11 @@ typedef struct {
uint8_t qs[QK4_NL/2];
} block_iq4_nl;
static_assert(sizeof(block_iq4_nl) == sizeof(ggml_half) + QK4_NL/2, "wrong iq4_nl block size/padding");
+typedef struct {
+ ggml_half d[4];
+ uint8_t qs[2*QK4_NL];
+} block_iq4_nl_x4;
+static_assert(sizeof(block_iq4_nl_x4) == 4*sizeof(ggml_half) + 2*QK4_NL, "wrong iq4_nl_x4 block size/padding");
typedef struct {
ggml_half d;
diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c
index d18b1981..376c97f8 100644
--- a/ggml/src/ggml-quants.c
+++ b/ggml/src/ggml-quants.c
@@ -15196,6 +15196,7 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte
case GGML_TYPE_IQ6_K: break;
case GGML_TYPE_IQ4_KS: break;
case GGML_TYPE_IQ4_KSS: break;
+ case GGML_TYPE_IQ4_NL_X4: break;
case GGML_TYPE_Q4_0_4_4:
case GGML_TYPE_Q4_0_4_8:
{
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index 39218ff4..c975212e 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -1245,6 +1245,23 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.nrows = 1,
.row_meta_size = 0,
},
+ [GGML_TYPE_IQ4_NL_X4] = {
+ .type_name = "iq4_nl_x4",
+ .blck_size = QK4_NL,
+ .type_size = sizeof(block_iq4_nl),
+ .is_quantized = true,
+ .to_float = (ggml_to_float_t) dequantize_row_iq4_nl_x4,
+ .from_float = quantize_row_iq4_nl_x4,
+ .from_float_ref = (ggml_from_float_t)quantize_row_iq4_nl_x4_ref,
+ .vec_dot = vec_dot_iq4_nl_x4_q8_0,
+#if GGML_USE_IQK_MULMAT && defined __AVX2__
+ .vec_dot_type = GGML_TYPE_Q8_1,
+#else
+ .vec_dot_type = GGML_TYPE_Q8_0,
+#endif
+ .nrows = 1,
+ .row_meta_size = 0,
+ },
};
// For internal test use
@@ -3903,6 +3920,7 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) {
case GGML_FTYPE_MOSTLY_IQ1_BN: wtype = GGML_TYPE_IQ1_BN; break;
case GGML_FTYPE_MOSTLY_IQ2_BN: wtype = GGML_TYPE_IQ2_BN; break;
case GGML_FTYPE_MOSTLY_IQ4_NL: wtype = GGML_TYPE_IQ4_NL; break;
+ case GGML_FTYPE_MOSTLY_IQ4_NL_X4: wtype = GGML_TYPE_IQ4_NL_X4;break;
case GGML_FTYPE_MOSTLY_IQ4_XS: wtype = GGML_TYPE_IQ4_XS; break;
case GGML_FTYPE_MOSTLY_IQ4_KS: wtype = GGML_TYPE_IQ4_KS; break;
case GGML_FTYPE_MOSTLY_IQ4_KSS: wtype = GGML_TYPE_IQ4_KSS; break;
@@ -10426,6 +10444,7 @@ static void ggml_compute_forward_add(
case GGML_TYPE_IQ1_BN:
case GGML_TYPE_IQ2_BN:
case GGML_TYPE_IQ4_NL:
+ case GGML_TYPE_IQ4_NL_X4:
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ4_KS:
case GGML_TYPE_IQ4_KSS:
@@ -10868,6 +10887,7 @@ static void ggml_compute_forward_add1(
case GGML_TYPE_IQ1_BN:
case GGML_TYPE_IQ2_BN:
case GGML_TYPE_IQ4_NL:
+ case GGML_TYPE_IQ4_NL_X4:
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ4_KS:
case GGML_TYPE_IQ4_KSS:
@@ -11007,6 +11027,7 @@ static void ggml_compute_forward_acc(
case GGML_TYPE_IQ1_BN:
case GGML_TYPE_IQ2_BN:
case GGML_TYPE_IQ4_NL:
+ case GGML_TYPE_IQ4_NL_X4:
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ4_KS:
case GGML_TYPE_IQ4_KSS:
@@ -14192,6 +14213,7 @@ static void ggml_compute_forward_out_prod(
case GGML_TYPE_IQ1_BN:
case GGML_TYPE_IQ2_BN:
case GGML_TYPE_IQ4_NL:
+ case GGML_TYPE_IQ4_NL_X4:
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ4_KS:
case GGML_TYPE_IQ4_KSS:
@@ -14571,6 +14593,7 @@ static void ggml_compute_forward_set(
case GGML_TYPE_IQ1_BN:
case GGML_TYPE_IQ2_BN:
case GGML_TYPE_IQ4_NL:
+ case GGML_TYPE_IQ4_NL_X4:
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ4_KS:
case GGML_TYPE_IQ4_KSS:
@@ -14844,6 +14867,7 @@ static void ggml_compute_forward_get_rows(
case GGML_TYPE_IQ1_BN:
case GGML_TYPE_IQ2_BN:
case GGML_TYPE_IQ4_NL:
+ case GGML_TYPE_IQ4_NL_X4:
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ4_KS:
case GGML_TYPE_IQ4_KSS:
@@ -15444,6 +15468,7 @@ static void ggml_compute_forward_clamp(
case GGML_TYPE_IQ1_BN:
case GGML_TYPE_IQ2_BN:
case GGML_TYPE_IQ4_NL:
+ case GGML_TYPE_IQ4_NL_X4:
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ4_KS:
case GGML_TYPE_IQ4_KSS:
@@ -22270,6 +22295,7 @@ size_t ggml_quantize_chunk(
case GGML_TYPE_IQ1_BN: result = quantize_iq1_bn (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_IQ2_BN: result = quantize_iq2_bn (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_IQ4_NL: result = quantize_iq4_nl (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
+ case GGML_TYPE_IQ4_NL_X4: result = quantize_iq4_nl_x4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_IQ4_XS: result = quantize_iq4_xs (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_IQ4_KS: result = quantize_iq4_ks (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_IQ4_KSS: result = quantize_iq4_kss(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp
index d7682e54..cfda9e18 100644
--- a/ggml/src/iqk/iqk_mul_mat.cpp
+++ b/ggml/src/iqk/iqk_mul_mat.cpp
@@ -23,6 +23,7 @@
#include "ggml-impl.h"
#include "ggml-quants.h"
#include "iqk_mul_mat.h"
+#include "iqk_quantize.h"
#define GGML_COMMON_IMPL_C
#include "ggml-common.h"
@@ -87,6 +88,16 @@ struct DataInfo {
inline void store(int ix, int iy, float result) const {
*(dst_row(iy) + ix) = result;
}
+#ifdef __AVX__
+ inline void store(int ix, int iy, __m128 result) const {
+ _mm_storeu_ps(dst_row(iy) + ix, result);
+ }
+#endif
+#ifdef __ARM_NEON
+ inline void store(int ix, int iy, float32x4_t result) const {
+ vst1q_f32(dst_row(iy) + ix, result);
+ }
+#endif
inline float * dst_row(int iy) const {
if (!row_mapping) return s + (cur_y + iy)*bs;
int i12 = row_mapping[cur_y + iy].i2;
@@ -2068,6 +2079,112 @@ static void mul_mat_qX_K_q8_K_T(int n, const void * vx, size_t bx, const DataInf
#endif // Zen4 or vanilla AVX2
+#ifdef HAVE_FANCY_SIMD
+template <int nrc_y>
+static void mul_mat_iq4_nl_x4_q8_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ GGML_ASSERT(nrc_x%8 == 0);
+ Q8<nrc_y, block_q8_1_x4> q8(info);
+ auto m4 = _mm512_set1_epi8(0xf);
+ auto values = load_iq4nl_values_512();
+ int nb = n / QK4_NL;
+ GGML_ASSERT(nb%4 == 0);
+ __m512 acc[2*nrc_y] = {};
+ __m512i qx[4];
+ for (int ix = 0; ix < nrc_x; ix += 8) {
+ const block_iq4_nl_x4 * iq4l = (const block_iq4_nl_x4 *)((const char *)vx + (ix+0)*bx);
+ const block_iq4_nl_x4 * iq4h = (const block_iq4_nl_x4 *)((const char *)vx + (ix+4)*bx);
+ for (int ib4 = 0; ib4 < nb/4; ++ib4) {
+ for (int k = 0; k < 4; ++k) {
+ auto scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq4l[4*ib4+k].d));
+ auto scales1 = _mm256_set_m128(scales128, scales128);
+ scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq4h[4*ib4+k].d));
+ auto scales2 = _mm256_set_m128(scales128, scales128);
+ auto scales = _mm512_insertf32x8(_mm512_castps256_ps512(scales1), scales2, 1);
+ auto scales_m = _mm512_mul_ps(scales, _mm512_set1_ps(-64.f));
+ auto bits1 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq4l[4*ib4+k].qs+0)),
+ _mm256_loadu_si256((const __m256i *)iq4h[4*ib4+k].qs+0), 1);
+ auto bits2 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq4l[4*ib4+k].qs+1)),
+ _mm256_loadu_si256((const __m256i *)iq4h[4*ib4+k].qs+1), 1);
+ qx[0] = _mm512_shuffle_epi8(values, _mm512_and_si512(bits1, m4));
+ qx[1] = _mm512_shuffle_epi8(values, _mm512_and_si512(bits2, m4));
+ qx[2] = _mm512_shuffle_epi8(values, _mm512_and_si512(_mm512_srli_epi16(bits1, 4), m4));
+ qx[3] = _mm512_shuffle_epi8(values, _mm512_and_si512(_mm512_srli_epi16(bits2, 4), m4));
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto y8 = _mm256_loadu_si256((const __m256i*)q8.y[iy][ib4].qs+k);
+ auto y = _mm512_inserti32x8(_mm512_castsi256_si512(y8), y8, 1);
+ auto sumi = _mm512_setzero_si512();
+ sumi = _mm512_dpbusd_epi32(sumi, qx[0], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00)));
+ sumi = _mm512_dpbusd_epi32(sumi, qx[1], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55)));
+ sumi = _mm512_dpbusd_epi32(sumi, qx[2], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa)));
+ sumi = _mm512_dpbusd_epi32(sumi, qx[3], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff)));
+ auto dy = _mm512_set1_ps(GGML_FP16_TO_FP32(q8.y[iy][ib4].d[k]));
+ acc[2*iy+0] = _mm512_fmadd_ps(_mm512_mul_ps(scales, dy), _mm512_cvtepi32_ps(sumi), acc[2*iy+0]);
+ acc[2*iy+1] = _mm512_fmadd_ps(scales_m, _mm512_set1_ps(GGML_FP16_TO_FP32(q8.y[iy][ib4].d[k+4])), acc[2*iy+1]);
+ }
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto sum512 = _mm512_add_ps(acc[2*iy+0], acc[2*iy+1]);
+ acc[2*iy+0] = acc[2*iy+1] = _mm512_setzero_ps();
+ auto sum1 = _mm_add_ps(_mm512_extractf32x4_ps(sum512, 0), _mm512_extractf32x4_ps(sum512, 1));
+ auto sum2 = _mm_add_ps(_mm512_extractf32x4_ps(sum512, 2), _mm512_extractf32x4_ps(sum512, 3));
+ info.store(ix+0, iy, sum1);
+ info.store(ix+4, iy, sum2);
+ }
+ }
+}
+#else
+template <int nrc_y>
+static void mul_mat_iq4_nl_x4_q8_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ GGML_ASSERT(nrc_x%4 == 0);
+ Q8<nrc_y, block_q8_1_x4> q8(info);
+ auto m4 = _mm256_set1_epi8(0xf);
+ auto m1 = _mm256_set1_epi16(1);
+ auto values = load_iq4nl_values_256();
+ int nb = n / QK4_NL;
+ GGML_ASSERT(nb%4 == 0);
+ __m256 acc[nrc_y] = {};
+ //__m256 acc[2*nrc_y] = {};
+ for (int ix = 0; ix < nrc_x; ix += 4) {
+ const block_iq4_nl_x4 * iq4 = (const block_iq4_nl_x4 *)((const char *)vx + ix*bx);
+ for (int ib4 = 0; ib4 < nb/4; ++ib4) {
+ for (int k = 0; k < 4; ++k) {
+ auto scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq4[4*ib4+k].d));
+ auto scales = _mm256_set_m128(scales128, scales128);
+ auto scales_m = _mm256_mul_ps(scales, _mm256_set1_ps(-64.f));
+ auto bits1 = _mm256_loadu_si256((const __m256i *)iq4[4*ib4+k].qs+0);
+ auto bits2 = _mm256_loadu_si256((const __m256i *)iq4[4*ib4+k].qs+1);
+ auto q1 = _mm256_shuffle_epi8(values, _mm256_and_si256(bits1, m4));
+ auto q2 = _mm256_shuffle_epi8(values, _mm256_and_si256(bits2, m4));
+ auto q3 = _mm256_shuffle_epi8(values, _mm256_and_si256(_mm256_srli_epi16(bits1, 4), m4));
+ auto q4 = _mm256_shuffle_epi8(values, _mm256_and_si256(_mm256_srli_epi16(bits2, 4), m4));
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto y = _mm256_loadu_si256((const __m256i*)q8.y[iy][ib4].qs+k);
+ auto sumi1 = _mm256_add_epi32(_mm256_madd_epi16(m1, _mm256_maddubs_epi16(q1, _mm256_shuffle_epi32(y, 0x00))),
+ _mm256_madd_epi16(m1, _mm256_maddubs_epi16(q2, _mm256_shuffle_epi32(y, 0x55))));
+ auto sumi2 = _mm256_add_epi32(_mm256_madd_epi16(m1, _mm256_maddubs_epi16(q3, _mm256_shuffle_epi32(y, 0xaa))),
+ _mm256_madd_epi16(m1, _mm256_maddubs_epi16(q4, _mm256_shuffle_epi32(y, 0xff))));
+ auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(GGML_FP16_TO_FP32(q8.y[iy][ib4].d[k])));
+ acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), acc[iy]);
+ acc[iy] = _mm256_fmadd_ps(scales_m, _mm256_set1_ps(GGML_FP16_TO_FP32(q8.y[iy][ib4].d[k+4])), acc[iy]);
+ //acc[2*iy+0] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), acc[2*iy+0]);
+ //acc[2*iy+1] = _mm256_fmadd_ps(scales_m, _mm256_set1_ps(GGML_FP16_TO_FP32(q8.y[iy][ib4].d[k+4])), acc[2*iy+1]);
+ }
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ //auto sum256 = _mm256_add_ps(acc[2*iy+0], acc[2*iy+1]);
+ //acc[2*iy+0] = acc[2*iy+1] = _mm256_setzero_ps();
+ //auto sum = _mm_add_ps(_mm256_castps256_ps128(sum256), _mm256_extractf128_ps(sum256, 1));
+ //info.store(ix, iy, sum);
+ auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1));
+ info.store(ix, iy, sum);
+ acc[iy] = _mm256_setzero_ps();
+ }
+ }
+}
+#endif
+
template <typename Bits>
inline void multiply_add_1(int j, const Bits& bits, const __m256i * scales, const __m256i * q8, __m256i * sumi) {
if (j == 0) {
@@ -4025,6 +4142,18 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) {
MulMat::set_functions<IQ4_NL_Unpacker>(mm);
expected_typeB = GGML_TYPE_Q8_1;
break;
+ case GGML_TYPE_IQ4_NL_X4:
+ assert (ne00 % QK4_NL == 0);
+ mm.funcs[0] = mul_mat_iq4_nl_x4_q8_1<1>;
+ mm.funcs[1] = mul_mat_iq4_nl_x4_q8_1<2>;
+ mm.funcs[2] = mul_mat_iq4_nl_x4_q8_1<3>;
+ mm.funcs[3] = mul_mat_iq4_nl_x4_q8_1<4>;
+ mm.funcs[4] = mul_mat_iq4_nl_x4_q8_1<5>;
+ mm.funcs[5] = mul_mat_iq4_nl_x4_q8_1<6>;
+ mm.funcs[6] = mul_mat_iq4_nl_x4_q8_1<7>;
+ mm.funcs[7] = mul_mat_iq4_nl_x4_q8_1<8>;
+ expected_typeB = GGML_TYPE_Q8_1;
+ break;
default:
return false;
@@ -6427,6 +6556,96 @@ static void mul_mat_iq2bn_q8_K64(int n, const void * vx, size_t bx, const DataIn
}
}
+template <int nrc_y>
+void mul_mat_iq4_nl_x4_q8_0(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ GGML_ASSERT(nrc_x%4 == 0);
+ Q8<nrc_y, block_q8_0_x4> q8(info);
+ auto m4 = vdupq_n_u8(0xf);
+ auto values = vld1q_s8(iq4k_values);
+ int nb = n / QK4_NL;
+ GGML_ASSERT(nb%4 == 0);
+ int8x16_t qx[8];
+ float32x4_t acc[nrc_y] = {};
+ for (int ix = 0; ix < nrc_x; ix += 4) {
+ const block_iq4_nl_x4 * iq4 = (const block_iq4_nl_x4 *)((const char *)vx + ix*bx);
+ for (int ib4 = 0; ib4 < nb/4; ++ib4) {
+ for (int k = 0; k < 4; ++k) {
+ auto scales = vcvt_f32_f16(vld1_f16((const float16_t *)iq4[4*ib4+k].d));
+ auto bits = vld1q_u8_x4(iq4[4*ib4+k].qs);
+ qx[0] = vqtbl1q_s8(values, vandq_u8(bits.val[0], m4)); // 0...3 from the 4 rows
+ qx[1] = vqtbl1q_s8(values, vandq_u8(bits.val[1], m4)); // 16..19
+ qx[2] = vqtbl1q_s8(values, vandq_u8(bits.val[2], m4)); // 4...7
+ qx[3] = vqtbl1q_s8(values, vandq_u8(bits.val[3], m4)); // 20..23
+ qx[4] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[0], 4)); // 8..11
+ qx[5] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[1], 4)); // 24..27
+ qx[6] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[2], 4)); // 12..15
+ qx[7] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[3], 4)); // 28..31
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto y = vld1q_s8_x2(q8.y[iy][ib4].qs+32*k);
+ auto sumi = vdupq_n_s32(0);
+ sumi = vdotq_laneq_s32(sumi, qx[0], y.val[0], 0);
+ sumi = vdotq_laneq_s32(sumi, qx[1], y.val[1], 0);
+ sumi = vdotq_laneq_s32(sumi, qx[2], y.val[0], 1);
+ sumi = vdotq_laneq_s32(sumi, qx[3], y.val[1], 1);
+ sumi = vdotq_laneq_s32(sumi, qx[4], y.val[0], 2);
+ sumi = vdotq_laneq_s32(sumi, qx[5], y.val[1], 2);
+ sumi = vdotq_laneq_s32(sumi, qx[6], y.val[0], 3);
+ sumi = vdotq_laneq_s32(sumi, qx[7], y.val[1], 3);
+ auto d4d8 = vmulq_f32(scales, vdupq_n_f32(GGML_FP16_TO_FP32(q8.y[iy][ib4].d[k])));
+ acc[iy] = vfmaq_f32(acc[iy], d4d8, vcvtq_f32_s32(sumi));
+ }
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ info.store(ix, iy, acc[iy]);
+ acc[iy] = vdupq_n_f32(0.f);
+ }
+ }
+}
+
+void mul_mat_iq4_nl_x4_q8_0_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ GGML_ASSERT(nrc_x%4 == 0);
+ Q8<1, block_q8_0_x4> q8(info);
+ auto m4 = vdupq_n_u8(0xf);
+ auto values = vld1q_s8(iq4k_values);
+ int nb = n / QK4_NL;
+ GGML_ASSERT(nb%4 == 0);
+ int8x16_t qx[8];
+ for (int ix = 0; ix < nrc_x; ix += 4) {
+ auto acc = vdupq_n_f32(0.f);
+ const block_iq4_nl_x4 * iq4 = (const block_iq4_nl_x4 *)((const char *)vx + ix*bx);
+ for (int ib4 = 0; ib4 < nb/4; ++ib4) {
+ auto y1 = vld1q_s8_x4(q8.y[0][ib4].qs);
+ auto y2 = vld1q_s8_x4(q8.y[0][ib4].qs+64);
+ for (int k = 0; k < 4; ++k) {
+ auto scales = vcvt_f32_f16(vld1_f16((const float16_t *)iq4[4*ib4+k].d));
+ auto d4d8 = vmulq_f32(scales, vdupq_n_f32(GGML_FP16_TO_FP32(q8.y[0][ib4].d[k])));
+ auto sumi = vdupq_n_s32(0);
+ const auto yval = k < 2 ? y1.val + 2*k : y2.val + 2*(k-2);
+ auto bits = vld1q_u8_x4(iq4[4*ib4+k].qs);
+ qx[0] = vqtbl1q_s8(values, vandq_u8(bits.val[0], m4)); // 0...3 from the 4 rows
+ qx[1] = vqtbl1q_s8(values, vandq_u8(bits.val[1], m4)); // 16..19
+ sumi = vdotq_laneq_s32(sumi, qx[0], yval[0], 0);
+ sumi = vdotq_laneq_s32(sumi, qx[1], yval[1], 0);
+ qx[2] = vqtbl1q_s8(values, vandq_u8(bits.val[2], m4)); // 4...7
+ qx[3] = vqtbl1q_s8(values, vandq_u8(bits.val[3], m4)); // 20..23
+ sumi = vdotq_laneq_s32(sumi, qx[2], yval[0], 1);
+ sumi = vdotq_laneq_s32(sumi, qx[3], yval[1], 1);
+ qx[4] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[0], 4)); // 8..11
+ qx[5] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[1], 4)); // 24..27
+ sumi = vdotq_laneq_s32(sumi, qx[4], yval[0], 2);
+ sumi = vdotq_laneq_s32(sumi, qx[5], yval[1], 2);
+ qx[6] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[2], 4)); // 12..15
+ qx[7] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[3], 4)); // 28..31
+ sumi = vdotq_laneq_s32(sumi, qx[6], yval[0], 3);
+ sumi = vdotq_laneq_s32(sumi, qx[7], yval[1], 3);
+ acc = vfmaq_f32(acc, d4d8, vcvtq_f32_s32(sumi));
+ }
+ }
+ info.store(ix, 0, acc);
+ }
+}
+
template <typename Dequantizer> void MulMat::set_functions(MulMat& m) {
if constexpr (std::is_same_v<Dequantizer, DequantizerQ40> || std::is_same_v<Dequantizer, DequantizerQ50> ||
std::is_same_v<Dequantizer, DequantizerQ80> || std::is_same_v<Dequantizer, DequantizerIQ4NL> ||
@@ -6596,6 +6815,17 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) {
MulMat::set_functions<DequantizerIQ4NL>(m);
expected_Btype = GGML_TYPE_Q8_0;
break;
+ case GGML_TYPE_IQ4_NL_X4:
+ m.funcs[0] = mul_mat_iq4_nl_x4_q8_0_1;
+ m.funcs[1] = mul_mat_iq4_nl_x4_q8_0<2>;
+ m.funcs[2] = mul_mat_iq4_nl_x4_q8_0<3>;
+ m.funcs[3] = mul_mat_iq4_nl_x4_q8_0<4>;
+ m.funcs[4] = mul_mat_iq4_nl_x4_q8_0<5>;
+ m.funcs[5] = mul_mat_iq4_nl_x4_q8_0<6>;
+ m.funcs[6] = mul_mat_iq4_nl_x4_q8_0<7>;
+ m.funcs[7] = mul_mat_iq4_nl_x4_q8_0<8>;
+ expected_Btype = GGML_TYPE_Q8_0;
+ break;
default:
return false;
}
diff --git a/ggml/src/iqk/iqk_quantize.cpp b/ggml/src/iqk/iqk_quantize.cpp
index b9d48237..88b5628c 100644
--- a/ggml/src/iqk/iqk_quantize.cpp
+++ b/ggml/src/iqk/iqk_quantize.cpp
@@ -669,12 +669,12 @@ void quantize_row_iq2_k_impl(const float * x, void * vy, int n_per_row, const fl
}
}
-void quantize_row_iq2_k_ref(const float * GGML_RESTRICT x, block_iq2_k * GGML_RESTRICT y, int64_t k) {
+void quantize_row_iq2_k_ref(const float * x, block_iq2_k * y, int64_t k) {
assert(k % QK_K == 0);
quantize_iq2_k(x, (void *)y, 1, k, nullptr);
}
-void quantize_row_iq2_k(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
+void quantize_row_iq2_k(const float * x, void * vy, int64_t k) {
assert(k % QK_K == 0);
block_iq2_k * y = (block_iq2_k *)vy;
quantize_row_iq2_k_ref(x, y, k);
@@ -692,7 +692,7 @@ size_t quantize_iq2_k(const float * src, void * dst, int64_t nrows, int64_t n_pe
return nrows * nblock * sizeof(block_iq2_k);
}
-void dequantize_row_iq2_k(const block_iq2_k * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
+void dequantize_row_iq2_k(const block_iq2_k * x, float * y, int64_t k) {
assert(k % QK_K == 0);
const int nb = k / QK_K;
@@ -723,7 +723,7 @@ void dequantize_row_iq2_k(const block_iq2_k * GGML_RESTRICT x, float * GGML_RES
}
-void vec_dot_iq2_k_q8_k(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
+void vec_dot_iq2_k_q8_k(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) {
assert(n % QK_K == 0);
assert(nrc == 1);
GGML_UNUSED(nrc);
@@ -967,12 +967,12 @@ void quantize_row_iq2_ks_impl(const float * x, void * vy, int n_per_row, const f
}
}
-void quantize_row_iq2_ks_ref(const float * GGML_RESTRICT x, block_iq2_ks * GGML_RESTRICT y, int64_t k) {
+void quantize_row_iq2_ks_ref(const float * x, block_iq2_ks * y, int64_t k) {
assert(k % QK_K == 0);
quantize_iq2_ks(x, (void *)y, 1, k, nullptr);
}
-void quantize_row_iq2_ks(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
+void quantize_row_iq2_ks(const float * x, void * vy, int64_t k) {
assert(k % QK_K == 0);
block_iq2_ks * y = (block_iq2_ks *)vy;
quantize_row_iq2_ks_ref(x, y, k);
@@ -994,7 +994,7 @@ size_t quantize_iq2_ks(const float * src, void * dst, int64_t nrows, int64_t n_p
return nrows * row_size;
}
-void dequantize_row_iq2_ks(const block_iq2_ks * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
+void dequantize_row_iq2_ks(const block_iq2_ks * x, float * y, int64_t k) {
assert(k % QK_K == 0);
const int nb = k / QK_K;
@@ -1334,7 +1334,7 @@ void dequantize_row_iq3_k(const block_iq3_k * x, float * y, int64_t k) {
}
}
-void vec_dot_iq3_k_q8_k(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
+void vec_dot_iq3_k_q8_k(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) {
assert(n % QK_K == 0);
assert(nrc == 1);
GGML_UNUSED(nrc);
@@ -3119,4 +3119,102 @@ void vec_dot_iq4_kss_q8_k(int n, float * s, size_t bs, const void * vx, size_t b
GGML_UNUSED(by);
}
+//
+// ========================================= x4
+//
+void quantize_row_iq4_nl_x4_ref(const float * x, block_iq4_nl_x4 * y, int64_t k) {
+ // we assume we are called with 4 rows
+ quantize_iq4_nl_x4(x, (void *)y, 4, k/4, nullptr);
+}
+
+void quantize_row_iq4_nl_x4(const float * x, void * y, int64_t k) {
+ // we assume we are called with 4 rows
+ quantize_iq4_nl_x4(x, y, 4, k/4, nullptr);
+}
+
+static void repack_iq4_nl(int nrows, int n_per_row, const block_iq4_nl * x, block_iq4_nl_x4 * y) {
+ GGML_ASSERT(nrows%4 == 0);
+ GGML_ASSERT(n_per_row%QK4_NL == 0);
+ int nblock = n_per_row/QK4_NL;
+ const block_iq4_nl * x4[4];
+ for (int row = 0; row < nrows; row += 4) {
+ for (int k = 0; k < 4; ++k) x4[k] = x + nblock*k;
+ for (int ib = 0; ib < nblock; ++ib) {
+ for (int k = 0; k < 4; ++k) y[ib].d[k] = x4[k][ib].d;
+ for (int k = 0; k < 4; ++k) for (int i = 0; i < 4; ++i) {
+ y[ib].qs[4*k+i+ 0] = (x4[k][ib].qs[i+0] & 0xf) | ((x4[k][ib].qs[i+ 8] & 0x0f) << 4); // 0....3 + 8...11 from each row
+ y[ib].qs[4*k+i+16] = (x4[k][ib].qs[i+0] >> 4) | ((x4[k][ib].qs[i+ 8] & 0xf0)); // 16...19 + 24...27 from each row
+ y[ib].qs[4*k+i+32] = (x4[k][ib].qs[i+4] & 0xf) | ((x4[k][ib].qs[i+12] & 0x0f) << 4); // 4....7 + 12...15 from each row
+ y[ib].qs[4*k+i+48] = (x4[k][ib].qs[i+4] >> 4) | ((x4[k][ib].qs[i+12] & 0xf0)); // 20...23 + 28...31 from each row
+ }
+ }
+ x += 4*nblock;
+ y += nblock;
+ }
+}
+
+size_t quantize_iq4_nl_x4(const float * src, void * dst, int64_t nrows, int64_t n_per_row, const float * imatrix) {
+ GGML_ASSERT(nrows%4 == 0);
+ auto row_size_nl = ggml_row_size(GGML_TYPE_IQ4_NL, n_per_row);
+ std::vector<char> qtmp(4*row_size_nl);
+ //std::vector<float> check1(4*n_per_row), check2(4*n_per_row);
+ char * qrow = (char *)dst;
+ for (int row = 0; row < nrows; row += 4) {
+ quantize_iq4_nl(src, qtmp.data(), 4, n_per_row, imatrix);
+ repack_iq4_nl(4, n_per_row, (const block_iq4_nl *)qtmp.data(), (block_iq4_nl_x4 *)qrow);
+ //dequantize_row_iq4_nl_x4((const block_iq4_nl_x4 *)qrow, check1.data(), 4*n_per_row);
+ //dequantize_row_iq4_nl((const block_iq4_nl *)qtmp.data(), check2.data(), 4*n_per_row);
+ //for (int k = 0; k < 4; ++k) {
+ // auto x1 = check1.data() + k*n_per_row;
+ // auto x2 = check2.data() + k*n_per_row;
+ // int nbad = 0;
+ // for (int j = 0; j < n_per_row; ++j) {
+ // if (std::abs(x1[j] - x2[j]) > 1e-8) {
+ // printf("Oops: %g vs %g\n", x1[j], x2[j]);
+ // if (++nbad > 20) GGML_ABORT("fatal error");
+ // }
+ // }
+ //}
+ src += 4*n_per_row;
+ qrow += 4*row_size_nl;
+ }
+ return nrows*row_size_nl;
+}
+
+void dequantize_row_iq4_nl_x4(const block_iq4_nl_x4 * x, float * y, int64_t k) {
+ // we assume we are called with 4 rows
+ int n_per_row = k/4;
+ int nb = n_per_row/QK4_NL;
+ float * yk[4];
+ for (int k = 0; k < 4; ++k) yk[k] = y + k*n_per_row;
+ for (int ib = 0; ib < nb; ++ib) {
+ for (int k = 0; k < 4; ++k) {
+ float scale = GGML_FP16_TO_FP32(x[ib].d[k]);
+ for (int i = 0; i < 4; ++i) {
+ yk[k][QK4_NL*ib+i+ 0] = scale * iq4k_values[x[ib].qs[4*k+i+ 0] & 0xf];
+ yk[k][QK4_NL*ib+i+ 8] = scale * iq4k_values[x[ib].qs[4*k+i+ 0] >> 4];
+ yk[k][QK4_NL*ib+i+16] = scale * iq4k_values[x[ib].qs[4*k+i+16] & 0xf];
+ yk[k][QK4_NL*ib+i+24] = scale * iq4k_values[x[ib].qs[4*k+i+16] >> 4];
+ yk[k][QK4_NL*ib+i+ 4] = scale * iq4k_values[x[ib].qs[4*k+i+32] & 0xf];
+ yk[k][QK4_NL*ib+i+12] = scale * iq4k_values[x[ib].qs[4*k+i+32] >> 4];
+ yk[k][QK4_NL*ib+i+20] = scale * iq4k_values[x[ib].qs[4*k+i+48] & 0xf];
+ yk[k][QK4_NL*ib+i+28] = scale * iq4k_values[x[ib].qs[4*k+i+48] >> 4];
+ }
+ }
+ }
+}
+
+void vec_dot_iq4_nl_x4_q8_0(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) {
+#if GGML_USE_IQK_MULMAT
+ if (iqk_mul_mat(1, 1, n, GGML_TYPE_IQ4_NL_X4, vx, 0, GGML_TYPE_Q8_0, vy, 0, s, 0, 0, 1)) {
+ return;
+ }
+#endif
+ GGML_ASSERT(n%QK4_NL == 0);
+ GGML_ASSERT(nrc == 1);
+ GGML_UNUSED(bs);
+ GGML_UNUSED(bx);
+ GGML_UNUSED(by);
+}
+
diff --git a/ggml/src/iqk/iqk_quantize.h b/ggml/src/iqk/iqk_quantize.h
index 50c425af..7942cc04 100644
--- a/ggml/src/iqk/iqk_quantize.h
+++ b/ggml/src/iqk/iqk_quantize.h
@@ -63,6 +63,12 @@ void vec_dot_iq2_ks_q8_k(int n, float * GGML_RESTRICT s, size_t bs, const void
void iqk_quantize_row_q8_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
+void quantize_row_iq4_nl_x4_ref(const float * GGML_RESTRICT x, block_iq4_nl_x4 * GGML_RESTRICT y, int64_t k);
+void quantize_row_iq4_nl_x4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+size_t quantize_iq4_nl_x4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
+void dequantize_row_iq4_nl_x4(const block_iq4_nl_x4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
+void vec_dot_iq4_nl_x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+
#ifdef __cplusplus
}
#endif