author    Kawrakow <iwankawrakow@gmail.com>  2025-06-11 11:12:30 +0300
committer GitHub <noreply@github.com>        2025-06-11 11:12:30 +0300
commit    e56061fa12e5fcb7a9a8fa5fca9e74c20d166bcc (patch)
tree      f0627d22fac777db7e397119b2f38421a2b289b8
parent    3c1f2c68fdbb9e5be26aad85902c3d86057d69b5 (diff)
IQ2_XXS: much faster CPU prompt processing (#515)
* Much faster iq2_xxs GEMM

  PP-512 = 290 t/s vs ~110 t/s (iq2_xxs) or 148 t/s (iq2_xxs_r4) on main.

* iq2_xxs: q8_2_x4 GEMM
* iq2_xxs: use template for q8_2_x4 GEMM
* Fix AVX2
* Cleanup
* NEON is not working yet, so still use Q8_K GEMM

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
-rw-r--r--  ggml/src/ggml.c                   |   4
-rw-r--r--  ggml/src/iqk/iqk_common.h         |  20
-rw-r--r--  ggml/src/iqk/iqk_gemm_iquants.cpp | 238
-rw-r--r--  ggml/src/iqk/iqk_gemm_iquants.h   |   2
-rw-r--r--  ggml/src/iqk/iqk_mul_mat.cpp      |  92
5 files changed, 346 insertions, 10 deletions
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index 3ff294cc..9e3c4b90 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -1067,7 +1067,11 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.from_float = quantize_row_iq2_xxs,
.from_float_ref = (ggml_from_float_t)quantize_row_iq2_xxs_ref,
.vec_dot = ggml_vec_dot_iq2_xxs_q8_K,
+#ifdef __AVX2__
+ .vec_dot_type = GGML_TYPE_Q8_2_X4,
+#else
.vec_dot_type = GGML_TYPE_Q8_K,
+#endif
.nrows = 1,
.row_meta_size = 0,
},
diff --git a/ggml/src/iqk/iqk_common.h b/ggml/src/iqk/iqk_common.h
index cce040dd..8d44c8f9 100644
--- a/ggml/src/iqk/iqk_common.h
+++ b/ggml/src/iqk/iqk_common.h
@@ -172,7 +172,6 @@ static inline void make_q4_scales(const uint8_t * scales8, uint32_t * aux32) {
aux32[0] = a0 & 0x3f3f3f3f;
}
-#if !(defined HAVE_FANCY_SIMD && defined __AVX512VPOPCNTDQ__)
const uint64_t keven_signs[128] = {
0x0101010101010101, 0xff010101010101ff, 0xff0101010101ff01, 0x010101010101ffff,
0xff01010101ff0101, 0x0101010101ff01ff, 0x0101010101ffff01, 0xff01010101ffffff,
@@ -207,7 +206,6 @@ const uint64_t keven_signs[128] = {
0x01ffffffff010101, 0xffffffffff0101ff, 0xffffffffff01ff01, 0x01ffffffff01ffff,
0xffffffffffff0101, 0x01ffffffffff01ff, 0x01ffffffffffff01, 0xffffffffffffffff,
};
-#endif
#ifdef __AVX2__
@@ -540,6 +538,24 @@ struct Q4Bits {
#endif
+inline void iqk_transpose_8x8(__m256 * m) {
+ for (int k = 0; k < 8; k += 4) {
+ auto t0 = _mm256_unpacklo_ps(m[k+0], m[k+1]);
+ auto t1 = _mm256_unpacklo_ps(m[k+2], m[k+3]);
+ auto t2 = _mm256_unpackhi_ps(m[k+0], m[k+1]);
+ auto t3 = _mm256_unpackhi_ps(m[k+2], m[k+3]);
+ m[k+0] = _mm256_castpd_ps(_mm256_unpacklo_pd(_mm256_castps_pd(t0), _mm256_castps_pd(t1)));
+ m[k+1] = _mm256_castpd_ps(_mm256_unpackhi_pd(_mm256_castps_pd(t0), _mm256_castps_pd(t1)));
+ m[k+2] = _mm256_castpd_ps(_mm256_unpacklo_pd(_mm256_castps_pd(t2), _mm256_castps_pd(t3)));
+ m[k+3] = _mm256_castpd_ps(_mm256_unpackhi_pd(_mm256_castps_pd(t2), _mm256_castps_pd(t3)));
+ }
+ for (int k = 0; k < 4; ++k) {
+ auto t = _mm256_set_m128(_mm256_extractf128_ps(m[k+4], 1), _mm256_extractf128_ps(m[k], 1));
+ m[k+0] = _mm256_set_m128(_mm256_castps256_ps128(m[k+4]), _mm256_castps256_ps128(m[k+0]));
+ m[k+4] = t;
+ }
+}
+
#else
// ------------------------------------ __aarch64__ --------------------------------------------------
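Note: the iqk_transpose_8x8() helper added above is a full 8x8 transpose of a float matrix held in eight __m256 registers (row i in m[i]): the unpacklo/unpackhi_ps plus unpack*_pd stage transposes 4x4 tiles within each 128-bit lane, and the second loop swaps the lane halves between m[k] and m[k+4]. A scalar reference of the same operation, as a minimal sketch:

#include <utility>

// Scalar equivalent of iqk_transpose_8x8(): in-place transpose of an 8x8 float
// matrix, where m[i] holds row i. Reference only, not the SIMD implementation.
static void transpose_8x8_ref(float m[8][8]) {
    for (int i = 0; i < 8; ++i)
        for (int j = i + 1; j < 8; ++j)
            std::swap(m[i][j], m[j][i]);
}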
diff --git a/ggml/src/iqk/iqk_gemm_iquants.cpp b/ggml/src/iqk/iqk_gemm_iquants.cpp
index 782e48d8..9e2ddc10 100644
--- a/ggml/src/iqk/iqk_gemm_iquants.cpp
+++ b/ggml/src/iqk/iqk_gemm_iquants.cpp
@@ -87,13 +87,12 @@ struct EvenSignHelper {
const __m256i shifts = _mm256_set_epi32(21, 14, 7, 0, 21, 14, 7, 0);
const __m256i mask = _mm256_set1_epi32(127);
const __m256i mone = _mm256_set1_epi32(1);
-#else
+#endif
inline void sign_value(uint32_t aux32, __m256i& value) const {
auto signs = _mm256_set_epi64x(keven_signs[(aux32 >> 21) & 127], keven_signs[(aux32 >> 14) & 127],
keven_signs[(aux32 >> 7) & 127], keven_signs[(aux32 >> 0) & 127]);
value = _mm256_sign_epi8(value, signs);
}
-#endif
};
struct SignHelper {
@@ -144,6 +143,35 @@ struct SignHelper {
const __m256i mone = _mm256_set1_epi8(1);
};
+// for (int i = 0; i < nb; ++i) {
+//
+// __m256i sumi[nrc_y], all_scales;
+// //for (int iy = 0; iy < nrc_y; ++iy) sumi[iy] = _mm256_setzero_si256();
+// __m256i mins;
+// float dmin = deq.new_block(i, &all_scales, mins);
+// for (int iy = 0; iy < nrc_y; ++iy) {
+// auto bsums = q8.load_bsums(iy, i);
+// auto prod = _mm256_madd_epi16(mins, bsums);
+// accd[iy] = _mm256_fmadd_ps(_mm256_set1_ps(dmin*q8.scale(iy, i)), _mm256_cvtepi32_ps(prod), accd[iy]);
+// }
+//
+// for (int j = 0; j < QK_K/128; ++j) {
+// deq.prepare(i, j);
+// set_scales_8(&all_scales, j, scales);
+// //multiply_add_iq(deq.bits, scales, j, i, q8, sumi);
+// multiply_add(deq.bits, scales, j, i, q8, sumi);
+// }
+// for (int iy = 0; iy < nrc_y; ++iy) {
+// const __m256 vd = _mm256_set1_ps(deq.d*q8.scale(iy, i));
+// accd[iy] = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(sumi[iy]), accd[iy]);
+// }
+// }
+//
+// for (int iy = 0; iy < nrc_y; ++iy) {
+// info.store(ix, iy, hsum_float_8(accd[iy]));
+// }
+// }
+
struct DequantizerIQ2XXS final : public BaseDequantizer<block_iq2_xxs> {
DequantizerIQ2XXS(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}
@@ -165,6 +193,16 @@ struct DequantizerIQ2XXS final : public BaseDequantizer<block_iq2_xxs> {
auto sc16 = load_scales(i);
scales[0] = MM256_SET_M128I(sc16, sc16);
}
+ inline void new_block_f(int i, __m256 * scales) {
+ auto sc16 = load_scales(i);
+ auto scf = _mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(sc16)));
+ auto scf_l = _mm256_castps256_ps128(scf);
+ auto scf_h = _mm256_extractf128_ps(scf, 1);
+ scales[0] = _mm256_set_m128(scf_l, scf_l);
+ scales[1] = _mm256_set_m128(scf_h, scf_h);
+ scales[2] = _mm256_mul_ps(scf, _mm256_set1_ps(-minv));
+ }
+
inline float new_block(int i, __m256i * scales, __m256i& mins) {
auto sc16 = load_scales(i);
mins = scb.shuffle(sc16);
@@ -730,6 +768,130 @@ static void mul_mat_qX_K_q8_K_IQ_N(int n, const void * vx, size_t bx, const Data
}
template <typename Dequantizer, int nrc_y>
+static void mul_mat_qX_K_q8_2_IQ_N(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ static_assert(Dequantizer::num_blocks == 8);
+ const int nb = n / QK_K;
+ Q8<nrc_y, block_q8_2_x4> q8(info);
+ Dequantizer deq(vx, bx);
+ __m256 scales[3];
+ __m256 accd[nrc_y];
+ __m256i sumi[4];
+
+ for (int ix = 0; ix < nrc_x; ++ix) {
+
+ for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm256_setzero_ps();
+
+ deq.new_row(ix);
+
+ for (int i = 0; i < nb; ++i) {
+
+ deq.new_block_f(i, scales);
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto my1 = _mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)(q8.y[iy][2*i+0].d + 4)));
+ auto my2 = _mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)(q8.y[iy][2*i+1].d + 4)));
+ auto my = _mm256_castsi256_ps(_mm256_slli_epi32(MM256_SET_M128I(my2, my1), 16));
+ accd[iy] = _mm256_fmadd_ps(scales[2], my, accd[iy]);
+ }
+
+ for (int j = 0; j < QK_K/128; ++j) {
+ deq.prepare(i, j);
+ auto& values = deq.bits.values;
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto qs = q8.y[iy][2*i+j].qs;
+#ifdef HAVE_FANCY_SIMD
+ sumi[0] = _mm256_dpbusd_epi32(_mm256_setzero_si256(), values[0], _mm256_loadu_si256((const __m256i*)qs+0));
+ sumi[1] = _mm256_dpbusd_epi32(_mm256_setzero_si256(), values[1], _mm256_loadu_si256((const __m256i*)qs+1));
+ sumi[2] = _mm256_dpbusd_epi32(_mm256_setzero_si256(), values[2], _mm256_loadu_si256((const __m256i*)qs+2));
+ sumi[3] = _mm256_dpbusd_epi32(_mm256_setzero_si256(), values[3], _mm256_loadu_si256((const __m256i*)qs+3));
+#else
+ sumi[0] = _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_maddubs_epi16(values[0], _mm256_loadu_si256((const __m256i*)qs+0)));
+ sumi[1] = _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_maddubs_epi16(values[1], _mm256_loadu_si256((const __m256i*)qs+1)));
+ sumi[2] = _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_maddubs_epi16(values[2], _mm256_loadu_si256((const __m256i*)qs+2)));
+ sumi[3] = _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_maddubs_epi16(values[3], _mm256_loadu_si256((const __m256i*)qs+3)));
+#endif
+ sumi[0] = _mm256_add_epi32(_mm256_unpacklo_epi32(sumi[0], sumi[1]), _mm256_unpackhi_epi32(sumi[0], sumi[1]));
+ sumi[2] = _mm256_add_epi32(_mm256_unpacklo_epi32(sumi[2], sumi[3]), _mm256_unpackhi_epi32(sumi[2], sumi[3]));
+ sumi[0] = _mm256_add_epi32(_mm256_unpacklo_epi64(sumi[0], sumi[2]), _mm256_unpackhi_epi64(sumi[0], sumi[2]));
+ auto d4 = _mm_castsi128_ps(_mm_slli_epi32(_mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)q8.y[iy][2*i+j].d)), 16));
+ auto dy = _mm256_set_m128(d4, d4);
+ accd[iy] = _mm256_fmadd_ps(_mm256_mul_ps(scales[j], dy), _mm256_cvtepi32_ps(sumi[0]), accd[iy]);
+ }
+ }
+ }
+
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ info.store(ix, iy, hsum_float_8(accd[iy]));
+ }
+ }
+}
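Note on the scale loads in this kernel: the 16-bit entries of block_q8_2_x4 are turned back into fp32 by zero-extending to 32 bits and shifting left by 16 (the _mm_slli_epi32(..., 16) / _mm256_slli_epi32(..., 16) followed by a cast to ps), i.e. they are stored as the upper half of an IEEE float, bf16-style. The first four entries give the per-sub-block activation scales (d4/dy), and the four entries at d + 4 appear to hold already-scaled sub-block sums of the activations, which get multiplied by scales[2] (the -minv term from new_block_f) to fold the constant offset of the quant values back out. A standalone scalar equivalent of the 16-bit-to-float reconstruction, as a sketch:

#include <cstdint>
#include <cstring>

// Scalar version of the shift-by-16 reconstruction used for the q8_2_x4 scales:
// the stored 16 bits are the upper half of a float32, so shift up and reinterpret.
static float bf16_bits_to_f32(uint16_t hi_bits) {
    uint32_t bits = (uint32_t)hi_bits << 16;  // lower 16 mantissa bits become zero
    float f;
    std::memcpy(&f, &bits, sizeof(f));        // bit-cast without type-punning UB
    return f;
}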
+
+template <int nrc_y>
+static void mul_mat_iq2_xxs_q8_2_IQ_N(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ const int nb = n / QK_K;
+ __m256 scales[2];
+ __m256 accd[nrc_y];
+ __m256i sumi[4];
+ __m256i xv[4];
+ EvenSignHelper esh;
+
+ for (int ix = 0; ix < nrc_x; ++ix) {
+
+ for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm256_setzero_ps();
+
+ const block_iq2_xxs * x = (const block_iq2_xxs *)((const char *)vx + ix*bx);
+
+ for (int i = 0; i < nb; ++i) {
+ const float d = GGML_FP16_TO_FP32(x[i].d)*0.125f;
+ const uint16_t * a16 = x[i].qs;
+ auto sc16 = _mm_set_epi16(a16[31], a16[27], a16[23], a16[19], a16[15], a16[11], a16[7], a16[3]);
+ sc16 = _mm_or_si128(_mm_slli_epi16(_mm_srli_epi16(sc16, 12), 1), _mm_set1_epi16(1));
+ auto sc32 = _mm256_cvtepi16_epi32(sc16);
+ auto all_scales = _mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(sc32));
+ auto all_mins = _mm256_mul_ps(all_scales, _mm256_set1_ps(-43.f));
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto y = (const block_q8_2_x4 *)info.src1_row(iy);
+ auto my1 = _mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)(y[2*i+0].d + 4)));
+ auto my2 = _mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)(y[2*i+1].d + 4)));
+ auto my = _mm256_castsi256_ps(_mm256_slli_epi32(MM256_SET_M128I(my2, my1), 16));
+ accd[iy] = _mm256_fmadd_ps(all_mins, my, accd[iy]);
+ }
+ auto scales_l = _mm256_castps256_ps128(all_scales);
+ auto scales_h = _mm256_extractf128_ps(all_scales, 1);
+ scales[0] = _mm256_set_m128(scales_l, scales_l);
+ scales[1] = _mm256_set_m128(scales_h, scales_h);
+
+ for (int j = 0; j < QK_K/128; ++j) {
+ const uint8_t * a8 = (const uint8_t *)(a16 + 16*j);
+ for (int k = 0; k < 4; ++k) {
+ auto a8k = a8 + 8*k;
+ xv[k] = _mm256_set_epi64x(iq2xxs_grid[a8k[3]], iq2xxs_grid[a8k[2]], iq2xxs_grid[a8k[1]], iq2xxs_grid[a8k[0]]);
+ uint32_t aux32; std::memcpy(&aux32, a8k+4, sizeof(uint32_t));
+ esh.sign_value(aux32, xv[k]);
+ xv[k] = _mm256_add_epi8(xv[k], _mm256_set1_epi8(43));
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto y = (const block_q8_2_x4 *)info.src1_row(iy);
+ sumi[0] = _mm256_dpbusd_epi32(_mm256_setzero_si256(), xv[0], _mm256_loadu_si256((const __m256i*)y[2*i+j].qs+0));
+ sumi[1] = _mm256_dpbusd_epi32(_mm256_setzero_si256(), xv[1], _mm256_loadu_si256((const __m256i*)y[2*i+j].qs+1));
+ sumi[2] = _mm256_dpbusd_epi32(_mm256_setzero_si256(), xv[2], _mm256_loadu_si256((const __m256i*)y[2*i+j].qs+2));
+ sumi[3] = _mm256_dpbusd_epi32(_mm256_setzero_si256(), xv[3], _mm256_loadu_si256((const __m256i*)y[2*i+j].qs+3));
+ sumi[0] = _mm256_add_epi32(_mm256_unpacklo_epi32(sumi[0], sumi[1]), _mm256_unpackhi_epi32(sumi[0], sumi[1]));
+ sumi[2] = _mm256_add_epi32(_mm256_unpacklo_epi32(sumi[2], sumi[3]), _mm256_unpackhi_epi32(sumi[2], sumi[3]));
+ sumi[0] = _mm256_add_epi32(_mm256_unpacklo_epi64(sumi[0], sumi[2]), _mm256_unpackhi_epi64(sumi[0], sumi[2]));
+ auto d4 = _mm_castsi128_ps(_mm_slli_epi32(_mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)y[2*i+j].d)), 16));
+ auto dy = _mm256_set_m128(d4, d4);
+ accd[iy] = _mm256_fmadd_ps(_mm256_mul_ps(scales[j], dy), _mm256_cvtepi32_ps(sumi[0]), accd[iy]);
+ }
+ }
+ }
+
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ info.store(ix, iy, hsum_float_8(accd[iy]));
+ }
+ }
+}
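Both q8_2_x4 kernels rely on dot products whose left operand must be unsigned (_mm256_dpbusd_epi32 and _mm256_maddubs_epi16 take u8 x s8). The iq2_xxs-specific variant above makes that explicit: every signed grid value is shifted up by 43 (_mm256_add_epi8(xv[k], _mm256_set1_epi8(43))) and the shift is cancelled by the all_mins = -43 * scales term multiplied by the per-sub-block activation sums. A small self-contained check of the identity this relies on, dot(x, y) == dot(x + 43, y) - 43 * sum(y):

#include <cassert>
#include <cstdint>

int main() {
    // example signed quant values (anything >= -43 keeps x + 43 inside u8 range)
    const int8_t x[8] = {-43, -25, -8, 8, 25, 43, -8, 25};
    const int8_t y[8] = {12, -7, 3, 100, -128, 5, 0, -1};
    int dot = 0, dot_shifted = 0, ysum = 0;
    for (int i = 0; i < 8; ++i) {
        dot         += x[i] * y[i];
        dot_shifted += (x[i] + 43) * y[i];
        ysum        += y[i];
    }
    assert(dot == dot_shifted - 43 * ysum);  // the offset is removed exactly
    return 0;
}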
+
+template <typename Dequantizer, int nrc_y>
static void mul_mat_qX_K_q8_K_IQ(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
assert(n % QK_K == 0);
#ifdef HAVE_FANCY_SIMD
@@ -1560,6 +1722,55 @@ static void mul_mat_iq3_s_r4_q8_k(int n, const void * vx, size_t bx, const DataI
}
}
+void iqk_convert_iq2_xxs_q8_0_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) {
+ GGML_ASSERT(n%QK_K == 0);
+ GGML_ASSERT(nrc_x%8 == 0);
+
+ int nb = n/QK_K;
+
+ const block_iq2_xxs * x8[8];
+
+ block_q8_0_r8 * y = (block_q8_0_r8 *)vy;
+
+ ggml_half dh[8];
+ uint16_t all_ls[64];
+ EvenSignHelper esh;
+
+ uint32_t block[8];
+ uint32_t aux32[2];
+ const uint8_t * aux8 = (const uint8_t *)aux32;
+
+ for (int ix = 0; ix < nrc_x; ix += 8) {
+ for (int k = 0; k < 8; ++k) x8[k] = (const block_iq2_xxs *)((const char *)vx + (ix + k)*bx);
+ for (int i = 0; i < nb; ++i) {
+ // TODO: simdify
+ for (int k = 0; k < 8; ++k) {
+ dh[k] = x8[k][i].d;
+ for (int ib32 = 0; ib32 < 8; ++ib32) {
+ std::memcpy(aux32, x8[k][i].qs + 4*ib32, 2*sizeof(uint32_t));
+ all_ls[8*ib32 + k] = (2*(aux32[1] >> 28) + 1);
+ auto value = _mm256_set_epi64x(iq2xxs_grid[aux8[3]], iq2xxs_grid[aux8[2]], iq2xxs_grid[aux8[1]], iq2xxs_grid[aux8[0]]);
+ esh.sign_value(aux32[1], value);
+ _mm256_storeu_si256((__m256i *)block, value);
+ auto qs = (uint32_t *)y[ib32].qs;
+ for (int l = 0; l < 4; ++l) {
+ qs[8*l + k + 0] = block[l + 0];
+ qs[8*l + k + 32] = block[l + 4];
+ }
+ }
+ }
+ auto vd = _mm256_mul_ps(_mm256_set1_ps(0.125f), _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)dh)));
+ for (int ib32 = 0; ib32 < QK_K/32; ++ib32) {
+ auto iscales16 = _mm_loadu_si128((const __m128i *)all_ls + ib32);
+ auto iscales32 = _mm256_cvtepi16_epi32(iscales16);
+ auto scales = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(iscales32));
+ _mm_storeu_si128((__m128i *)y[ib32].d, _mm256_cvtps_ph(scales, _MM_FROUND_TO_NEAREST_INT));
+ }
+ y += QK_K/32;
+ }
+ }
+}
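iqk_convert_iq2_xxs_q8_0_r8() is the repack used for large batches (see the dispatch change in iqk_mul_mat.cpp below): 8 iq2_xxs rows at a time are expanded into q8_0_r8 blocks, one block per 32-value sub-block, with the 8 per-row scales (0.125f * d * ls, ls being the odd sub-block scale 2*(aux32[1] >> 28) + 1) stored as fp16 in d and 8 x 32 int8 quants in qs. The scatter through qs[8*l + k] interleaves the rows in 4-byte chunks, low 16 bytes of every row first, then the high 16 bytes. The same index mapping on plain arrays, as a sketch of the resulting layout:

#include <cstdint>
#include <cstring>

// Sketch of the qs layout written above for one q8_0_r8 block: in[k] holds the 32
// int8 quants of row k, out_qs is the 64 x 4-byte qs area of the block.
static void interleave_8_rows(const int8_t in[8][32], uint32_t out_qs[64]) {
    for (int k = 0; k < 8; ++k) {
        uint32_t chunk[8];
        std::memcpy(chunk, in[k], 32);             // 32 quants = 8 four-byte chunks
        for (int l = 0; l < 4; ++l) {
            out_qs[8*l + k +  0] = chunk[l + 0];   // chunks 0..3 of every row first
            out_qs[8*l + k + 32] = chunk[l + 4];   // then chunks 4..7
        }
    }
}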
+
template <typename Dequantizer> void set_functions(std::array<mul_mat_t, IQK_MAX_NY>& funcs) {
funcs[0] = mul_mat_qX_K_q8_K_IQ<Dequantizer, 1>;
funcs[1] = mul_mat_qX_K_q8_K_IQ<Dequantizer, 2>;
@@ -1575,7 +1786,19 @@ template <typename Dequantizer> void set_functions(std::array<mul_mat_t, IQK_MAX
bool iqk_set_kernels_iquants(int ne00, int typeA, int typeB, std::array<mul_mat_t, IQK_MAX_NY>& kernels, mul_mat_t& func16) {
- if (ne00%QK_K != 0 || ggml_type(typeB) != GGML_TYPE_Q8_K) {
+ if (ne00%QK_K != 0) return false;
+
+ if (ggml_type(typeA) == GGML_TYPE_IQ2_XXS) {
+ if (ggml_type(typeB) == GGML_TYPE_Q8_2_X4) {
+ IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_K_q8_2_IQ_N, DequantizerIQ2XXS, kernels);
+ //IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq2_xxs_q8_2_IQ_N, kernels);
+ func16 = nullptr;
+ return true;
+ }
+ return false;
+ }
+
+ if (ggml_type(typeB) != GGML_TYPE_Q8_K) {
return false;
}
@@ -1629,6 +1852,15 @@ bool iqk_set_kernels_iquants(int ne00, int typeA, int typeB, std::array<mul_mat_
}
+bool iqk_convert_iquants_q80_r8(int type, int n, const void * vx, size_t bx, void * vy, int nrc_x) {
+ if (n%QK_K != 0 || nrc_x%8 != 0) return false;
+ switch (ggml_type(type)) {
+ case GGML_TYPE_IQ2_XXS: iqk_convert_iq2_xxs_q8_0_r8(n, vx, bx, vy, nrc_x); break;
+ default: return false;
+ }
+ return true;
+}
+
#else
// --------------------------------------- __aarch64__ ---------------------------------------------
diff --git a/ggml/src/iqk/iqk_gemm_iquants.h b/ggml/src/iqk/iqk_gemm_iquants.h
index 4182526a..82d1bb80 100644
--- a/ggml/src/iqk/iqk_gemm_iquants.h
+++ b/ggml/src/iqk/iqk_gemm_iquants.h
@@ -8,4 +8,6 @@
bool iqk_set_kernels_iquants(int ne00, int typeA, int typeB, std::array<mul_mat_t, IQK_MAX_NY>& kernels, mul_mat_t& func16);
+bool iqk_convert_iquants_q80_r8(int type, int n, const void * vx, size_t bx, void * vy, int nrc_x);
+
#endif
diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp
index d04ad22a..6f8c0106 100644
--- a/ggml/src/iqk/iqk_mul_mat.cpp
+++ b/ggml/src/iqk/iqk_mul_mat.cpp
@@ -239,6 +239,7 @@ struct MulMat {
case GGML_TYPE_IQ2_KT: return nrc_y >= 32 ? GGML_TYPE_F32 : type;
case GGML_TYPE_IQ3_KT: return nrc_y >= 32 ? GGML_TYPE_F32 : type;
case GGML_TYPE_IQ4_KT: return nrc_y >= 32 ? GGML_TYPE_F32 : type;
+ case GGML_TYPE_IQ2_XXS: return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
default: break;
}
#else
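This is the dispatch that enables the repack path: for IQ2_XXS with 32 or more activation rows (prompt processing) the weights are converted to Q8_0_R8 via iqk_convert_repack()/iqk_convert_iq2_xxs_q8_0_r8() and the existing q8_0_r8 GEMM is used, while smaller batches (e.g. single-token generation) go straight to the q8_2_x4 kernel above. A trivial standalone restatement of that routing, with hypothetical names; only the threshold and the two routes come from the diff:

enum class Iq2XxsRoute { Q8_2_X4_Gemm, RepackToQ8_0_R8 };

// Hypothetical helper mirroring the nrc_y >= 32 threshold in the hunk above.
static Iq2XxsRoute iq2_xxs_route(int n_activation_rows) {
    return n_activation_rows >= 32 ? Iq2XxsRoute::RepackToQ8_0_R8
                                   : Iq2XxsRoute::Q8_2_X4_Gemm;
}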
@@ -327,6 +328,89 @@ static std::vector<char> & thread_local_work_buffer() {
return f;
}
+bool iqk_convert_repack(int typeA, int n, const void * vx, size_t bx, void * vy, size_t stride_y, int nrc_x) {
+
+ switch (typeA) {
+ //case GGML_TYPE_F16:
+ //case GGML_TYPE_F32:
+ //case GGML_TYPE_BF16:
+ //case GGML_TYPE_BF16_R16:
+ // return iqk_set_kernels_float(ne00, typeA, typeB, mm.funcs);
+ //case GGML_TYPE_Q2_K:
+ //case GGML_TYPE_Q3_K:
+ //case GGML_TYPE_Q4_K:
+ //case GGML_TYPE_Q5_K:
+ //case GGML_TYPE_Q6_K:
+ //case GGML_TYPE_IQ4_XS:
+ //case GGML_TYPE_Q2_K_R4:
+ //case GGML_TYPE_Q3_K_R4:
+ //case GGML_TYPE_Q4_K_R4:
+ //case GGML_TYPE_Q5_K_R4:
+ //case GGML_TYPE_Q6_K_R4:
+ //case GGML_TYPE_IQ4_XS_R8:
+ //case GGML_TYPE_Q8_K_R8:
+ //case GGML_TYPE_Q8_KV:
+ //case GGML_TYPE_Q8_KV_R8:
+ // return iqk_set_kernels_kquants(ne00, typeA, typeB, mm.funcs, mm.func16);
+ case GGML_TYPE_IQ2_XXS:
+ case GGML_TYPE_IQ2_XS:
+ case GGML_TYPE_IQ2_S:
+ case GGML_TYPE_IQ3_XXS:
+ case GGML_TYPE_IQ3_S:
+ case GGML_TYPE_IQ2_XXS_R4:
+ case GGML_TYPE_IQ2_XS_R4:
+ case GGML_TYPE_IQ2_S_R4:
+ case GGML_TYPE_IQ3_XXS_R4:
+ case GGML_TYPE_IQ3_S_R4:
+ return iqk_convert_iquants_q80_r8(typeA, n, vx, bx, vy, nrc_x);
+ //case GGML_TYPE_IQ4_KS:
+ //case GGML_TYPE_IQ5_KS:
+ //case GGML_TYPE_IQ4_KSS:
+ //case GGML_TYPE_IQ2_K:
+ //case GGML_TYPE_IQ2_KS:
+ //case GGML_TYPE_IQ3_K:
+ //case GGML_TYPE_IQ4_K:
+ //case GGML_TYPE_IQ5_K:
+ //case GGML_TYPE_IQ6_K:
+ //case GGML_TYPE_IQ2_K_R4:
+ //case GGML_TYPE_IQ3_K_R4:
+ //case GGML_TYPE_IQ4_K_R4:
+ //case GGML_TYPE_IQ5_K_R4:
+ //case GGML_TYPE_IQ4_KS_R4:
+ //case GGML_TYPE_IQ5_KS_R4:
+ // return iqk_set_kernels_iqk_quants(ne00, typeA, typeB, mm.funcs, mm.func16);
+ case GGML_TYPE_IQ2_KT:
+ case GGML_TYPE_IQ3_KT:
+ case GGML_TYPE_IQ4_KT:
+ return iqk_dequantize_ktquants(typeA, n, vx, bx, vy, stride_y, nrc_x);
+ //case GGML_TYPE_Q4_0:
+ //case GGML_TYPE_Q4_1:
+ //case GGML_TYPE_Q5_0:
+ //case GGML_TYPE_Q5_1:
+ //case GGML_TYPE_Q6_0:
+ //case GGML_TYPE_Q8_0:
+ //case GGML_TYPE_IQ4_NL:
+ //case GGML_TYPE_Q4_0_R8:
+ //case GGML_TYPE_Q5_0_R4:
+ //case GGML_TYPE_Q6_0_R4:
+ //case GGML_TYPE_Q8_0_R8:
+ //case GGML_TYPE_IQ4_NL_R4:
+ // return iqk_set_kernels_legacy_quants(ne00, typeA, typeB, mm.funcs, mm.func16);
+ //case GGML_TYPE_IQ1_S:
+ //case GGML_TYPE_IQ1_S_R4:
+ //case GGML_TYPE_IQ1_M_R4:
+ //case GGML_TYPE_IQ1_BN:
+ //case GGML_TYPE_IQ2_BN:
+ //case GGML_TYPE_IQ2_BN_R4:
+ // return iqk_set_kernels_1bit(ne00, typeA, typeB, mm.funcs, mm.func16);
+
+ default:
+ return false;
+ }
+
+ return false;
+}
+
}
extern "C" IQK_API bool iqk_mul_mat(long Nx, long Ny, long ne00,
@@ -352,9 +436,7 @@ extern "C" IQK_API bool iqk_mul_mat(long Nx, long Ny, long ne00,
first_x *= num_rows;
nrc_x *= num_rows;
- auto type_size = ggml_type_size(dequant_type);
-
- size_t row_size_qx = ne00*type_size;
+ size_t row_size_qx = ggml_row_size(dequant_type, ne00);
size_t row_size_qy = strideB;
//printf("Dequant mul mat %s x %s: ne00 = %d, row_size = %d\n", ggml_type_name(dequant_type), ggml_type_name(ggml_type(typeB)), (int)ne00, (int)row_size_qx);
@@ -368,7 +450,7 @@ extern "C" IQK_API bool iqk_mul_mat(long Nx, long Ny, long ne00,
this_info.s += ix;
int this_nrc_x = ix + k_x_step <= nrc_x ? k_x_step : nrc_x - ix;
if (f.size() < row_size_qx*this_nrc_x) f.resize(row_size_qx*this_nrc_x);
- if (!iqk_dequantize_ktquants(typeA, ne00, (const char *)A + (first_x + ix)*strideA, strideA, f.data(), ne00, this_nrc_x)) {
+ if (!iqk_convert_repack(typeA, ne00, (const char *)A + (first_x + ix)*strideA, strideA, f.data(), ne00, this_nrc_x)) {
GGML_ABORT("Fatal error");
}
mm.mul_mat_NxM(ne00, f.data(), row_size_qx, this_info, this_nrc_x, Ny);
@@ -678,7 +760,7 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) {
case GGML_TYPE_IQ2_S_R4:
case GGML_TYPE_IQ3_XXS_R4:
case GGML_TYPE_IQ3_S_R4:
- return ggml_type(typeB) == GGML_TYPE_Q8_K ? iqk_set_kernels_iquants(ne00, typeA, typeB, mm.funcs, mm.func16) : false;
+ return iqk_set_kernels_iquants(ne00, typeA, typeB, mm.funcs, mm.func16);
case GGML_TYPE_IQ4_KS:
case GGML_TYPE_IQ5_KS:
case GGML_TYPE_IQ4_KSS: