author     Kawrakow <iwankawrakow@gmail.com>   2025-06-17 07:12:48 +0300
committer  GitHub <noreply@github.com>         2025-06-17 07:12:48 +0300
commit     0f8f8b32e2d0c7e3ac8bbafee6965dcd1305d002
tree       debb40f79053c891ff1cfa9839a7cca8505c1a04
parent     6fc5bbb657525bb1ef20b682e1cc4ab5fd44aba6
Much faster CPU prompt processing (part 1) (#531)
* q6_K dequantizing GEMM
* Much easier: just use different vec_dot types!
* WIP
* Finally q6_K x q8_2_x4 dot product works
* Very slightly better
* We don't need the changes in ggml.c
* Fix AVX2
* iq2_xs
* Fix AVX2
* iq2_s
* q3_K
* Fix q8_k_r8 on Zen4
* q3_K: repack to q8_k_r8 instead of q8_0_r8

  With that we hit 360 t/s for Llama-3.1-8B on a Ryzen-7950X. q8_k_r8 is 386 t/s,
  so for a batch size of 512, repacking costs ~7% of the time taken by the actual GEMM.

* q3_K: don't scale when all quants in a block are <= 127 when repacking
* iq2_s: repack to q8_k_r8 instead of q8_0_r8
* iq2_xs: repack to q8_k_r8
* WIP
* iq2_xs: repack to q8_k_r8
* iq3_xxs: repack to q8_k_r8
* iq3_s: use q8_k_r8
* iq1_s: repack to q8_k_r8
* iq1_m: repack to q8_k_r8
* iq1_m: slightly faster
* Slightly faster

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
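
The change in a nutshell: for prompt processing (large batches), the low-bit quant types no longer get their own GEMM kernels. Each tile of 8 weight rows is repacked on the fly into the row-interleaved q8_k_r8 format and the single fast q8_k_r8 GEMM is reused; repacking costs only a few percent of the GEMM time (see the q3_K numbers above). The AVX2 helper convert_to_q8_k_r8 added below decides, per 256-quant superblock, whether the scale-multiplied quants still fit into int8 or need one extra rescale. A minimal scalar sketch of that decision, assuming QK_K == 256; the function name is illustrative and the real code processes 8 interleaved rows at once:

#include <algorithm>
#include <cmath>
#include <cstdint>

// q[0..255] holds one row's superblock after multiplying grid values by the
// per-32-quant scales; d0 is the headroom target each converter passes in
// (e.g. 126, or 1/124 expressed as a multiplier in the iquants variant).
static float repack_block_sketch(const int16_t * q, float d0, int8_t * out) {
    int max_abs = 0;
    for (int j = 0; j < 256; ++j) max_abs = std::max(max_abs, std::abs((int)q[j]));
    float dnew = max_abs / d0;
    if (dnew < 1.f) {
        // "don't scale when all quants in a block are <= 127": store as-is
        for (int j = 0; j < 256; ++j) out[j] = (int8_t)q[j];
        return 1.f;
    }
    const float inv = 1.f/dnew;
    for (int j = 0; j < 256; ++j) out[j] = (int8_t)std::nearbyint(inv*q[j]);
    return dnew;   // folded into the per-row fp16 scale of block_q8_k_r8
}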
-rw-r--r--  ggml/src/ggml-common.h             |    3
-rw-r--r--  ggml/src/ggml.c                    |   23
-rw-r--r--  ggml/src/iqk/iqk_gemm_1bit.cpp     |  159
-rw-r--r--  ggml/src/iqk/iqk_gemm_iquants.cpp  |  820
-rw-r--r--  ggml/src/iqk/iqk_gemm_kquants.cpp  |  468
-rw-r--r--  ggml/src/iqk/iqk_mul_mat.cpp       |   37
-rw-r--r--  ggml/src/iqk/iqk_mul_mat.h         |    2
-rw-r--r--  ggml/src/iqk/iqk_quantize.cpp      |   34
8 files changed, 1451 insertions(+), 95 deletions(-)
diff --git a/ggml/src/ggml-common.h b/ggml/src/ggml-common.h
index 5fe27b29..2bfe5d39 100644
--- a/ggml/src/ggml-common.h
+++ b/ggml/src/ggml-common.h
@@ -386,10 +386,11 @@ static_assert(sizeof(block_q6_k_r4) == 4*sizeof(ggml_half) + QK_K/4 + 3*QK_K, "w
// This is only used for intermediate quantization and dot products
typedef struct {
float d; // delta
+ float sum; // sum of quants in the entire block
int8_t qs[QK_K]; // quants
int16_t bsums[QK_K/16]; // sum of quants in groups of 16
} block_q8_K;
-static_assert(sizeof(block_q8_K) == sizeof(float) + QK_K + QK_K/16*sizeof(int16_t), "wrong q8_K block size/padding");
+static_assert(sizeof(block_q8_K) == 2*sizeof(float) + QK_K + QK_K/16*sizeof(int16_t), "wrong q8_K block size/padding");
typedef struct {
float d; // delta
int8_t qs[64]; // quants
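
The only change here is the new per-block sum field in block_q8_K and the matching static_assert. With QK_K == 256 and no padding the block grows by 4 bytes; a worked size check, assuming QK_K == 256:

// d      float                4 bytes
// sum    float                4 bytes
// qs     int8_t[QK_K]       256 bytes
// bsums  int16_t[QK_K/16]    32 bytes
// total                     296 bytes
static_assert(2*sizeof(float) + 256 + (256/16)*sizeof(int16_t) == 296, "updated block_q8_K size");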
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index 069533ae..a6260136 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -1036,7 +1036,12 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.from_float = quantize_row_q6_K,
.from_float_ref = (ggml_from_float_t) quantize_row_q6_K_ref,
.vec_dot = ggml_vec_dot_q6_K_q8_K,
+#ifdef __AVX2__
+ .vec_dot_type = GGML_TYPE_Q8_2_X4,
+#else
.vec_dot_type = GGML_TYPE_Q8_K,
+#endif
+// .vec_dot_type = GGML_TYPE_Q8_K,
.nrows = 1,
.row_meta_size = 0,
},
@@ -1062,7 +1067,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.from_float = quantize_row_q8_k_r8,
.from_float_ref = (ggml_from_float_t) quantize_row_q8_k_r8_ref,
.vec_dot = vec_dot_q8_k_r8_q8_k,
- .vec_dot_type = GGML_TYPE_Q8_KR8,
+ .vec_dot_type = GGML_TYPE_Q8_K,
.nrows = 1,
.row_meta_size = 0,
},
@@ -1075,11 +1080,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.from_float = quantize_row_iq2_xxs,
.from_float_ref = (ggml_from_float_t)quantize_row_iq2_xxs_ref,
.vec_dot = ggml_vec_dot_iq2_xxs_q8_K,
-#ifdef __AVX2__
- .vec_dot_type = GGML_TYPE_Q8_2_X4,
-#else
.vec_dot_type = GGML_TYPE_Q8_K,
-#endif
.nrows = 1,
.row_meta_size = 0,
},
@@ -1131,11 +1132,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.from_float = quantize_row_iq3_xxs,
.from_float_ref = (ggml_from_float_t)quantize_row_iq3_xxs_ref,
.vec_dot = ggml_vec_dot_iq3_xxs_q8_K,
-#ifdef __AVX2__
- .vec_dot_type = GGML_TYPE_Q8_2_X4,
-#else
.vec_dot_type = GGML_TYPE_Q8_K,
-#endif
.nrows = 1,
.row_meta_size = 0,
},
@@ -1161,11 +1158,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.from_float = quantize_row_iq3_s,
.from_float_ref = (ggml_from_float_t)quantize_row_iq3_s_ref,
.vec_dot = ggml_vec_dot_iq3_s_q8_K,
-#ifdef __AVX2__
- .vec_dot_type = GGML_TYPE_Q8_2_X4,
-#else
.vec_dot_type = GGML_TYPE_Q8_K,
-#endif
.nrows = 1,
.row_meta_size = 0,
},
@@ -1217,11 +1210,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.from_float = quantize_row_iq1_s,
.from_float_ref = (ggml_from_float_t)quantize_row_iq1_s_ref,
.vec_dot = ggml_vec_dot_iq1_s_q8_K,
-#ifdef __AVX2__
- .vec_dot_type = GGML_TYPE_Q8_2_X4,
-#else
.vec_dot_type = GGML_TYPE_Q8_K,
-#endif
.nrows = 1,
.row_meta_size = 0,
},
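
These ggml.c edits only change which intermediate type the float activations are quantized to before the dot product: q6_K now advertises GGML_TYPE_Q8_2_X4 on AVX2, while the sub-4-bit IQ types revert to plain GGML_TYPE_Q8_K, since their fast path is now the q8_k_r8 repack added below. A rough sketch of how ggml consumes vec_dot_type, simplified from the generic mul_mat path (the iqk kernels hook in at the same point; row pointers and counts are placeholders, not literal code):

// src0: quantized weights, src1: fp32 activations
const ggml_type_traits_t * tt  = &type_traits[src0->type];  // e.g. GGML_TYPE_Q6_K
const enum ggml_type       vdt = tt->vec_dot_type;          // Q8_2_X4 on AVX2, else Q8_K

// 1) quantize one activation row into the vec_dot type
type_traits[vdt].from_float(src1_row_f32, quantized_row, ne10);

// 2) dot it with one quantized weight row
float result;
tt->vec_dot(ne00, &result, 0, weight_row, 0, quantized_row, 0, 1);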
diff --git a/ggml/src/iqk/iqk_gemm_1bit.cpp b/ggml/src/iqk/iqk_gemm_1bit.cpp
index 05196c1d..770fbf2c 100644
--- a/ggml/src/iqk/iqk_gemm_1bit.cpp
+++ b/ggml/src/iqk/iqk_gemm_1bit.cpp
@@ -1607,6 +1607,162 @@ static void mul_mat_iq2_bn_r4_q8_k16(int n, const void * vx, size_t bx, const Da
}
#endif
+inline float convert_to_q8_k_r8(int k, int d0, const __m256i * qx, const int16_t * scales, uint32_t * block, int8_t * q8_k) {
+ auto max_i16 = _mm256_setzero_si256();
+ for (int ib32 = 0; ib32 < 8; ++ib32) {
+ auto q16_l = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(qx[ib32]));
+ auto q16_h = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(qx[ib32], 1));
+ q16_l = _mm256_mullo_epi16(q16_l, _mm256_set1_epi16(scales[2*ib32+0]));
+ q16_h = _mm256_mullo_epi16(q16_h, _mm256_set1_epi16(scales[2*ib32+1]));
+ max_i16 = _mm256_max_epi16(max_i16, _mm256_sign_epi16(q16_l, q16_l));
+ max_i16 = _mm256_max_epi16(max_i16, _mm256_sign_epi16(q16_h, q16_h));
+ }
+ auto max_q32 = _mm256_cvtepi16_epi32(_mm_max_epi16(_mm256_castsi256_si128(max_i16), _mm256_extracti128_si256(max_i16, 1)));
+ auto imax4 = _mm_max_epi32(_mm256_castsi256_si128(max_q32), _mm256_extracti128_si256(max_q32, 1));
+ auto max4 = _mm_cvtepi32_ps(imax4);
+ max4 = _mm_max_ps(max4, _mm_movehl_ps(max4, max4));
+ max4 = _mm_max_ss(max4, _mm_movehdup_ps(max4));
+ bool needs_scaling = true;
+ float dnew = _mm_cvtss_f32(max4) / d0;
+ if (dnew < 1.f) {
+ dnew = 1.f; needs_scaling = false;
+ }
+ auto scale = _mm256_set1_ps(std::abs(dnew) > 1e-9f ? 1/dnew : 0.f);
+ for (int ib32 = 0; ib32 < 8; ++ib32) {
+ auto q16_l = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(qx[ib32]));
+ auto q16_h = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(qx[ib32], 1));
+ q16_l = _mm256_mullo_epi16(q16_l, _mm256_set1_epi16(scales[2*ib32+0]));
+ q16_h = _mm256_mullo_epi16(q16_h, _mm256_set1_epi16(scales[2*ib32+1]));
+ if (needs_scaling) {
+ auto i0 = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(q16_l));
+ auto i1 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(q16_l, 1));
+ auto i2 = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(q16_h));
+ auto i3 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(q16_h, 1));
+ i0 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(scale, _mm256_cvtepi32_ps(i0)), _MM_ROUND_NEAREST));
+ i1 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(scale, _mm256_cvtepi32_ps(i1)), _MM_ROUND_NEAREST));
+ i2 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(scale, _mm256_cvtepi32_ps(i2)), _MM_ROUND_NEAREST));
+ i3 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(scale, _mm256_cvtepi32_ps(i3)), _MM_ROUND_NEAREST));
+ i0 = _mm256_packs_epi32(i0, i1);
+ i2 = _mm256_packs_epi32(i2, i3);
+ i0 = _mm256_packs_epi16(i0, i2);
+ i0 = _mm256_permutevar8x32_epi32(i0, _mm256_setr_epi32(0, 4, 1, 5, 2, 6, 3, 7));
+ _mm256_storeu_si256((__m256i *)block, i0);
+ } else {
+ // 0, 1, 2, 3, 4, 5, 6, 7, 8, 16, 17, 18, 19, 20, 21, 22, 23, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31
+ auto i0 = _mm256_packs_epi16(q16_l, q16_h);
+ auto i0_l = _mm256_castsi256_si128(i0);
+ auto i0_h = _mm256_extracti128_si256(i0, 1);
+ _mm_storeu_si128((__m128i *)block+0, _mm_unpacklo_epi64(i0_l, i0_h));
+ _mm_storeu_si128((__m128i *)block+1, _mm_unpackhi_epi64(i0_l, i0_h));
+ }
+ auto qs = (uint32_t *)q8_k + 64*ib32;
+ for (int l = 0; l < 8; ++l) {
+ qs[8*l + k] = block[l];
+ }
+ }
+ return dnew;
+}
+
+void iqk_convert_iq1_s_q8_k_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) {
+ GGML_ASSERT(n%QK_K == 0);
+ GGML_ASSERT(nrc_x%8 == 0);
+
+ int nb = n/QK_K;
+
+ const block_iq1_s * x8[8];
+
+ block_q8_k_r8 * y = (block_q8_k_r8 *)vy;
+
+ int16_t ls[16];
+
+ uint32_t block[8];
+
+ __m256i qx[8];
+
+ for (int ix = 0; ix < nrc_x; ix += 8) {
+ for (int k = 0; k < 8; ++k) x8[k] = (const block_iq1_s *)((const char *)vx + (ix + k)*bx);
+ for (int i = 0; i < nb; ++i) {
+ for (int k = 0; k < 8; ++k) {
+ float d = 0.125f * GGML_FP16_TO_FP32(x8[k][i].d);
+ auto qs = x8[k][i].qs;
+ auto qh = x8[k][i].qh;
+ __m256i value;
+ for (int ib32 = 0; ib32 < 8; ++ib32) {
+ ls[2*ib32 + 0] = (2*((qh[ib32] >> 12) & 7) + 1);
+ ls[2*ib32 + 1] = ls[2*ib32 + 0];
+ value = _mm256_set_epi64x(iq1s_grid[qs[3] | ((qh[ib32] >> 1) & 0x700)], iq1s_grid[qs[2] | ((qh[ib32] << 2) & 0x700)],
+ iq1s_grid[qs[1] | ((qh[ib32] << 5) & 0x700)], iq1s_grid[qs[0] | ((qh[ib32] << 8) & 0x700)]);
+ value = _mm256_slli_epi16(_mm256_add_epi8(value, _mm256_set1_epi8(1)), 3);
+ int8_t delta = qh[ib32] & 0x8000 ? -9 : -7;
+ value = _mm256_add_epi8(value, _mm256_set1_epi8(delta));
+ qx[ib32] = value;
+ qs += 4;
+ }
+ float dnew = convert_to_q8_k_r8(k, 126, qx, ls, block, y[i].qs);
+ y[i].d[k] = GGML_FP32_TO_FP16(d*dnew);
+ }
+ }
+ y += nb;
+ }
+}
+
+void iqk_convert_iq1_m_q8_k_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) {
+ GGML_ASSERT(n%QK_K == 0);
+ GGML_ASSERT(nrc_x%8 == 0);
+
+ int nb = n/QK_K;
+
+ const block_iq1_m * x8[8];
+
+ block_q8_k_r8 * y = (block_q8_k_r8 *)vy;
+
+ int16_t ls[16];
+
+ uint32_t block[8];
+
+ __m256i qx[8];
+
+ auto mask = _mm256_setr_epi32(0x00000008, 0x00000008, 0x00000080, 0x00000080, 0x00080000, 0x00080000, 0x00800000, 0x00800000);
+
+ for (int ix = 0; ix < nrc_x; ix += 8) {
+ for (int k = 0; k < 8; ++k) x8[k] = (const block_iq1_m *)((const char *)vx + (ix + k)*bx);
+ for (int i = 0; i < nb; ++i) {
+ for (int k = 0; k < 8; ++k) {
+ const uint16_t * sc = (const uint16_t *)x8[k][i].scales;
+ iq1m_scale_t scale;
+ scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
+ float d = 0.125f * GGML_FP16_TO_FP32(scale.f16);
+ auto qs = x8[k][i].qs;
+ auto qh = x8[k][i].qh;
+ __m256i value;
+ for (int ib32 = 0; ib32 < 8; ++ib32) {
+ ls[2*ib32 + 0] = (2*((sc[ib32/2] >> (6*(ib32%2)+0)) & 0x7) + 1);
+ ls[2*ib32 + 1] = (2*((sc[ib32/2] >> (6*(ib32%2)+3)) & 0x7) + 1);
+ value = _mm256_set_epi64x(iq1s_grid[qs[3] | ((qh[1] << 4) & 0x700)], iq1s_grid[qs[2] | ((qh[1] << 8) & 0x700)],
+ iq1s_grid[qs[1] | ((qh[0] << 4) & 0x700)], iq1s_grid[qs[0] | ((qh[0] << 8) & 0x700)]);
+ value = _mm256_slli_epi16(_mm256_add_epi8(value, _mm256_set1_epi8(1)), 3);
+
+ auto delta_mask = _mm256_cmpeq_epi32(_mm256_and_si256(_mm256_set1_epi32(qh[0] | qh[1] << 16), mask), mask);
+ auto delta = _mm256_add_epi8(_mm256_set1_epi8(7), _mm256_and_si256(delta_mask, _mm256_set1_epi8(2)));
+ qx[ib32] = _mm256_sub_epi8(value, delta);
+
+ //int64_t delta1 = qh[0] & 0x08 ? 0x0909090909090909 : 0x0707070707070707;
+ //int64_t delta2 = qh[0] & 0x80 ? 0x0909090909090909 : 0x0707070707070707;
+ //int64_t delta3 = qh[1] & 0x08 ? 0x0909090909090909 : 0x0707070707070707;
+ //int64_t delta4 = qh[1] & 0x80 ? 0x0909090909090909 : 0x0707070707070707;
+ //value = _mm256_sub_epi8(value, _mm256_set_epi64x(delta4, delta3, delta2, delta1));
+ //qx[ib32] = value;
+ qs += 4;
+ qh += 2;
+ }
+ float dnew = convert_to_q8_k_r8(k, 126, qx, ls, block, y[i].qs);
+ y[i].d[k] = GGML_FP32_TO_FP16(d*dnew);
+ }
+ }
+ y += nb;
+ }
+}
+
void iqk_convert_iq1_s_q8_0_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) {
GGML_ASSERT(n%QK_K == 0);
GGML_ASSERT(nrc_x%8 == 0);
@@ -1722,7 +1878,8 @@ bool iqk_set_kernels_1bit(int ne00, int typeA, int typeB, std::array<mul_mat_t,
bool iqk_convert_1bit_q80_r8(int type, int n, const void * vx, size_t bx, void * vy, int nrc_x) {
if (n%QK_K != 0 || nrc_x%8 != 0) return false;
switch (ggml_type(type)) {
- case GGML_TYPE_IQ1_S: iqk_convert_iq1_s_q8_0_r8(n, vx, bx, vy, nrc_x); break;
+ case GGML_TYPE_IQ1_S: iqk_convert_iq1_s_q8_k_r8(n, vx, bx, vy, nrc_x); break;
+ case GGML_TYPE_IQ1_M: iqk_convert_iq1_m_q8_k_r8(n, vx, bx, vy, nrc_x); break;
default: return false;
}
return true;
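
Both converters above finish each 32-quant chunk with qs[8*l + k] = block[l], so the eight rows of a tile are interleaved at 4-byte granularity. A read-back sketch of that layout; the helper is hypothetical and assumes QK_K == 256, eight interleaved rows, and that the shuffles above leave the 32 bytes of block in natural order:

// Quant j (0..255) of row k (0..7) inside one block_q8_k_r8 superblock.
static inline int8_t q8_k_r8_quant_sketch(const block_q8_k_r8 * b, int k, int j) {
    const int ib32 = j / 32;         // 32-quant chunk
    const int l    = (j % 32) / 4;   // 4-byte group inside the chunk
    const int byte = j % 4;          // byte inside the group
    const uint32_t * qs = (const uint32_t *)b->qs + 64*ib32;
    const uint32_t packed = qs[8*l + k];   // rows 0..7 interleaved per group
    return (int8_t)((packed >> (8*byte)) & 0xff);
}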
diff --git a/ggml/src/iqk/iqk_gemm_iquants.cpp b/ggml/src/iqk/iqk_gemm_iquants.cpp
index 60396fee..c8688dc6 100644
--- a/ggml/src/iqk/iqk_gemm_iquants.cpp
+++ b/ggml/src/iqk/iqk_gemm_iquants.cpp
@@ -238,13 +238,17 @@ struct DequantizerIQ2XS final : public BaseDequantizer<block_iq2_xs> {
constexpr static int num_blocks = 16;
- inline __m256i load_scales(int i) {
- d = 0.125f * GGML_FP16_TO_FP32(x[i].d);
- auto tmp = _mm_loadl_epi64((const __m128i *)x[i].scales);
+ static inline __m256i make_scales(const uint8_t * scales) {
+ auto tmp = _mm_loadl_epi64((const __m128i *)scales);
auto all = _mm_and_si128(_mm_unpacklo_epi8(tmp, _mm_srli_epi16(tmp, 4)), _mm_set1_epi8(0xf));
auto scales8 = _mm_or_si128(_mm_slli_epi16(all, 1), _mm_set1_epi8(1));
return _mm256_cvtepi8_epi16(scales8);
}
+
+ inline __m256i load_scales(int i) {
+ d = 0.125f * GGML_FP16_TO_FP32(x[i].d);
+ return make_scales(x[i].scales);
+ }
inline static void prepare_scales(const __m256i& all, __m256i * scales) {
auto scales_l = _mm256_castsi256_si128(all);
auto scales_h = _mm256_extractf128_si256(all, 1);
@@ -296,8 +300,8 @@ struct DequantizerIQ2XS final : public BaseDequantizer<block_iq2_xs> {
signs = _mm256_cmpeq_epi8(_mm256_and_si256(signs, mask), mask);
value = _mm256_sign_epi8(value, _mm256_or_si256(signs, mone));
}
- inline void sign_values(const __m256i& data, __m256i * values) const {
#if defined HAVE_FANCY_SIMD && defined __AVX512VPOPCNTDQ__
+ static IQK_ALWAYS_INLINE inline void sign_values_popcnt(const __m256i& data, __m256i * values) {
auto partial_bits = _mm256_cvtepi16_epi8(_mm256_srli_epi16(data, 9));
auto pcnt = _mm_popcnt_epi8(partial_bits);
auto full_bits = _mm_or_si128(partial_bits, _mm_slli_epi16(_mm_and_si128(pcnt, _mm_set1_epi8(1)), 7));
@@ -307,7 +311,9 @@ struct DequantizerIQ2XS final : public BaseDequantizer<block_iq2_xs> {
values[1] = _mm256_mask_sub_epi8(values[1], m32[1], zero, values[1]);
values[2] = _mm256_mask_sub_epi8(values[2], m32[2], zero, values[2]);
values[3] = _mm256_mask_sub_epi8(values[3], m32[3], zero, values[3]);
+ }
#else
+ static IQK_ALWAYS_INLINE inline void sign_values_helper(const __m256i& data, const Helper& helper, __m256i * values) {
auto psb1 = _mm256_srli_epi16(data, 9);
auto psb2 = _mm256_srli_epi16(data, 13);
auto psbc = _mm256_xor_si256(psb1, psb2);
@@ -321,6 +327,13 @@ struct DequantizerIQ2XS final : public BaseDequantizer<block_iq2_xs> {
sign_value(full_1, helper.shuff2, helper.mask, helper.mone, values[1]);
sign_value(full_2, helper.shuff1, helper.mask, helper.mone, values[2]);
sign_value(full_2, helper.shuff2, helper.mask, helper.mone, values[3]);
+ }
+#endif
+ IQK_ALWAYS_INLINE inline void sign_values(const __m256i& data, __m256i * values) const {
+#if defined HAVE_FANCY_SIMD && defined __AVX512VPOPCNTDQ__
+ sign_values_popcnt(data, values);
+#else
+ sign_values_helper(data, helper, values);
#endif
}
inline void make4_signed(const uint16_t * qs, const __m256i& m511,
@@ -343,6 +356,17 @@ struct DequantizerIQ2XS final : public BaseDequantizer<block_iq2_xs> {
for (int k = 0; k < 4; ++k) q8_quants[k] = q8.load_quants(0, i, 4*j+k);
make4(x[i].qs + 16*j, idx_mask, bits.values, q8_quants);
}
+ inline void prepare_signed(int i, int j, __m256i * us) {
+ auto q2 = _mm256_loadu_si256((const __m256i *)x[i].qs+j);
+ make4(q2, idx_mask, us);
+ for (int k = 0; k < 4; ++k) bits.values[k] = us[k];
+ sign_values(q2, bits.values);
+ }
+ IQK_ALWAYS_INLINE inline void prepare_signed(int i, int j, __m256i * us, __m256i * s) {
+ auto q2 = _mm256_loadu_si256((const __m256i *)x[i].qs+j);
+ make4(q2, idx_mask, us);
+ sign_values(q2, s);
+ }
constexpr static int minv = 43;
@@ -360,13 +384,16 @@ struct DequantizerIQ2S final : public BaseDequantizer<block_iq2_s> {
constexpr static int num_blocks = 16;
- inline __m256i load_scales(int i) {
- d = 0.125f * GGML_FP16_TO_FP32(x[i].d);
- auto tmp = _mm_loadl_epi64((const __m128i *)x[i].scales);
+ static inline __m256i make_scales(const uint8_t * scales) {
+ auto tmp = _mm_loadl_epi64((const __m128i *)scales);
auto all = _mm_and_si128(_mm_unpacklo_epi8(tmp, _mm_srli_epi16(tmp, 4)), _mm_set1_epi8(0xf));
auto scales8 = _mm_or_si128(_mm_slli_epi16(all, 1), _mm_set1_epi8(1));
return _mm256_cvtepi8_epi16(scales8);
}
+ inline __m256i load_scales(int i) {
+ d = 0.125f * GGML_FP16_TO_FP32(x[i].d);
+ return make_scales(x[i].scales);
+ }
inline static void prepare_scales(const __m256i& all, __m256i * scales) {
auto scales_l = _mm256_castsi256_si128(all);
auto scales_h = _mm256_extractf128_si256(all, 1);
@@ -421,6 +448,38 @@ struct DequantizerIQ2S final : public BaseDequantizer<block_iq2_s> {
q8_quants[2] = _mm256_sign_epi8(q8.load_quants(0, i, 4*j+2), sh.make_signs(signs[4] | (signs[5] << 16)));
q8_quants[3] = _mm256_sign_epi8(q8.load_quants(0, i, 4*j+3), sh.make_signs(signs[6] | (signs[7] << 16)));
}
+ static inline void prepare(const uint8_t * qs, const uint8_t * qh, const uint16_t * signs, const SignHelper& sh, __m256i * values) {
+ auto idx_shift = _mm256_set_epi32(2, 4, 6, 8, 2, 4, 6, 8);
+ auto idx_mask = _mm256_set1_epi32(0x300);
+ make2(qs+0, qh+0, idx_shift, idx_mask, values+0);
+ make2(qs+8, qh+2, idx_shift, idx_mask, values+2);
+ values[0] = _mm256_sign_epi8(values[0], sh.make_signs(signs[0] | (signs[1] << 16)));
+ values[1] = _mm256_sign_epi8(values[1], sh.make_signs(signs[2] | (signs[3] << 16)));
+ values[2] = _mm256_sign_epi8(values[2], sh.make_signs(signs[4] | (signs[5] << 16)));
+ values[3] = _mm256_sign_epi8(values[3], sh.make_signs(signs[6] | (signs[7] << 16)));
+ }
+ inline void prepare_signed(int i, int j, __m256i * us, __m256i * s) {
+ auto qs = x[i].qs + 16*j;
+ auto qh = x[i].qh + 4*j;
+ const uint16_t * signs = (const uint16_t *)(x[i].qs + QK_K/8) + 8*j;
+ make2(qs+0, qh+0, idx_shift, idx_mask, us+0);
+ make2(qs+8, qh+2, idx_shift, idx_mask, us+2);
+ s[0] = _mm256_sign_epi8(s[0], sh.make_signs(signs[0] | (signs[1] << 16)));
+ s[1] = _mm256_sign_epi8(s[1], sh.make_signs(signs[2] | (signs[3] << 16)));
+ s[2] = _mm256_sign_epi8(s[2], sh.make_signs(signs[4] | (signs[5] << 16)));
+ s[3] = _mm256_sign_epi8(s[3], sh.make_signs(signs[6] | (signs[7] << 16)));
+ }
+ inline void prepare_signed(int i, int j, __m256i * us) {
+ auto qs = x[i].qs + 16*j;
+ auto qh = x[i].qh + 4*j;
+ const uint16_t * signs = (const uint16_t *)(x[i].qs + QK_K/8) + 8*j;
+ make2(qs+0, qh+0, idx_shift, idx_mask, us+0);
+ make2(qs+8, qh+2, idx_shift, idx_mask, us+2);
+ bits.values[0] = _mm256_sign_epi8(us[0], sh.make_signs(signs[0] | (signs[1] << 16)));
+ bits.values[1] = _mm256_sign_epi8(us[1], sh.make_signs(signs[2] | (signs[3] << 16)));
+ bits.values[2] = _mm256_sign_epi8(us[2], sh.make_signs(signs[4] | (signs[5] << 16)));
+ bits.values[3] = _mm256_sign_epi8(us[3], sh.make_signs(signs[6] | (signs[7] << 16)));
+ }
constexpr static int minv = 43;
@@ -1780,6 +1839,59 @@ static void mul_mat_iq3_s_r4_q8_k(int n, const void * vx, size_t bx, const DataI
}
}
+inline float convert_to_q8_k_r8(int k, float d0, const __m256i * qx, const int16_t * scales, uint32_t * block, int8_t * q8_k) {
+ auto max_i16 = _mm256_setzero_si256();
+ __m256i qs[16];
+ for (int ib32 = 0; ib32 < 8; ++ib32) {
+ qs[2*ib32+0] = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(qx[ib32]));
+ qs[2*ib32+1] = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(qx[ib32], 1));
+ qs[2*ib32+0] = _mm256_mullo_epi16(qs[2*ib32+0], _mm256_set1_epi16(scales[2*ib32+0]));
+ qs[2*ib32+1] = _mm256_mullo_epi16(qs[2*ib32+1], _mm256_set1_epi16(scales[2*ib32+1]));
+ max_i16 = _mm256_max_epi16(max_i16, _mm256_sign_epi16(qs[2*ib32+0], qs[2*ib32+0]));
+ max_i16 = _mm256_max_epi16(max_i16, _mm256_sign_epi16(qs[2*ib32+1], qs[2*ib32+1]));
+ }
+ auto max_q32 = _mm256_cvtepi16_epi32(_mm_max_epi16(_mm256_castsi256_si128(max_i16), _mm256_extracti128_si256(max_i16, 1)));
+ auto imax4 = _mm_max_epi32(_mm256_castsi256_si128(max_q32), _mm256_extracti128_si256(max_q32, 1));
+ auto max4 = _mm_cvtepi32_ps(imax4);
+ max4 = _mm_max_ps(max4, _mm_movehl_ps(max4, max4));
+ max4 = _mm_max_ss(max4, _mm_movehdup_ps(max4));
+ bool needs_scaling = true;
+ float dnew = _mm_cvtss_f32(max4) * d0;
+ if (dnew < 1.f) {
+ dnew = 1.f; needs_scaling = false;
+ }
+ auto scale = _mm256_set1_ps(std::abs(dnew) > 1e-9f ? 1/dnew : 0.f);
+ for (int ib32 = 0; ib32 < 8; ++ib32) {
+ if (needs_scaling) {
+ auto i0 = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(qs[2*ib32+0]));
+ auto i1 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(qs[2*ib32+0], 1));
+ auto i2 = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(qs[2*ib32+1]));
+ auto i3 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(qs[2*ib32+1], 1));
+ i0 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(scale, _mm256_cvtepi32_ps(i0)), _MM_ROUND_NEAREST));
+ i1 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(scale, _mm256_cvtepi32_ps(i1)), _MM_ROUND_NEAREST));
+ i2 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(scale, _mm256_cvtepi32_ps(i2)), _MM_ROUND_NEAREST));
+ i3 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(scale, _mm256_cvtepi32_ps(i3)), _MM_ROUND_NEAREST));
+ i0 = _mm256_packs_epi32(i0, i1);
+ i2 = _mm256_packs_epi32(i2, i3);
+ i0 = _mm256_packs_epi16(i0, i2);
+ i0 = _mm256_permutevar8x32_epi32(i0, _mm256_setr_epi32(0, 4, 1, 5, 2, 6, 3, 7));
+ _mm256_storeu_si256((__m256i *)block, i0);
+ } else {
+ // 0, 1, 2, 3, 4, 5, 6, 7, 8, 16, 17, 18, 19, 20, 21, 22, 23, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31
+ auto i0 = _mm256_packs_epi16(qs[2*ib32+0], qs[2*ib32+1]);
+ auto i0_l = _mm256_castsi256_si128(i0);
+ auto i0_h = _mm256_extracti128_si256(i0, 1);
+ _mm_storeu_si128((__m128i *)block+0, _mm_unpacklo_epi64(i0_l, i0_h));
+ _mm_storeu_si128((__m128i *)block+1, _mm_unpackhi_epi64(i0_l, i0_h));
+ }
+ auto qs = (uint32_t *)q8_k + 64*ib32;
+ for (int l = 0; l < 8; ++l) {
+ qs[8*l + k] = block[l];
+ }
+ }
+ return dnew;
+}
+
void iqk_convert_iq2_xxs_q8_0_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) {
GGML_ASSERT(n%QK_K == 0);
GGML_ASSERT(nrc_x%8 == 0);
@@ -1829,6 +1941,563 @@ void iqk_convert_iq2_xxs_q8_0_r8(int n, const void * vx, size_t bx, void * vy, i
}
}
+void iqk_convert_iq2_xxs_q8_k_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) {
+ GGML_ASSERT(n%QK_K == 0);
+ GGML_ASSERT(nrc_x%8 == 0);
+
+ int nb = n/QK_K;
+
+ const block_iq2_xxs * x8[8];
+
+ block_q8_k_r8 * y = (block_q8_k_r8 *)vy;
+
+ int16_t ls[16];
+ EvenSignHelper esh;
+
+ uint32_t block[8];
+ uint32_t aux32[2];
+ const uint8_t * aux8 = (const uint8_t *)aux32;
+
+ __m256i values[8];
+
+ for (int ix = 0; ix < nrc_x; ix += 8) {
+ for (int k = 0; k < 8; ++k) x8[k] = (const block_iq2_xxs *)((const char *)vx + (ix + k)*bx);
+ for (int i = 0; i < nb; ++i) {
+ // TODO: simdify
+ for (int k = 0; k < 8; ++k) {
+ float d = 0.125f * GGML_FP16_TO_FP32(x8[k][i].d);
+ for (int ib32 = 0; ib32 < 8; ++ib32) {
+ std::memcpy(aux32, x8[k][i].qs + 4*ib32, 2*sizeof(uint32_t));
+ ls[2*ib32+0] = (2*(aux32[1] >> 28) + 1);
+ ls[2*ib32+1] = ls[2*ib32+0];
+ values[ib32] = _mm256_set_epi64x(iq2xxs_grid[aux8[3]], iq2xxs_grid[aux8[2]], iq2xxs_grid[aux8[1]], iq2xxs_grid[aux8[0]]);
+ esh.sign_value(aux32[1], values[ib32]);
+ }
+ float dnew = convert_to_q8_k_r8(k, 1.f/124, values, ls, block, y[i].qs);
+ y[i].d[k] = GGML_FP32_TO_FP16(d*dnew);
+ }
+ }
+ y += nb;
+ }
+}
+
+void iqk_convert_iq2_xs_q8_k_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) {
+ GGML_ASSERT(n%QK_K == 0);
+ GGML_ASSERT(nrc_x%8 == 0);
+
+ int nb = n/QK_K;
+
+ const block_iq2_xs * x8[8];
+
+ block_q8_k_r8 * y = (block_q8_k_r8 *)vy;
+
+ uint32_t block[8];
+
+ union { __m256i vec; int16_t val[16]; } helper;
+ __m256i qx[8];
+
+#if !(defined HAVE_FANCY_SIMD && defined __AVX512VPOPCNTDQ__)
+ DequantizerIQ2XS::Helper sign_helper;
+#endif
+
+ for (int ix = 0; ix < nrc_x; ix += 8) {
+ for (int k = 0; k < 8; ++k) x8[k] = (const block_iq2_xs *)((const char *)vx + (ix + k)*bx);
+ for (int i = 0; i < nb; ++i) {
+ for (int k = 0; k < 8; ++k) {
+ float d = 0.125f * GGML_FP16_TO_FP32(x8[k][i].d);
+ helper.vec = DequantizerIQ2XS::make_scales(x8[k][i].scales);
+ auto q2l = _mm256_loadu_si256((const __m256i *)x8[k][i].qs+0);
+ auto q2h = _mm256_loadu_si256((const __m256i *)x8[k][i].qs+1);
+ DequantizerIQ2XS::make4(q2l, _mm256_set1_epi16(511), qx+0);
+ DequantizerIQ2XS::make4(q2h, _mm256_set1_epi16(511), qx+4);
+#if defined HAVE_FANCY_SIMD && defined __AVX512VPOPCNTDQ__
+ DequantizerIQ2XS::sign_values_popcnt(q2l, qx+0);
+ DequantizerIQ2XS::sign_values_popcnt(q2h, qx+4);
+#else
+ DequantizerIQ2XS::sign_values_helper(q2l, sign_helper, qx+0);
+ DequantizerIQ2XS::sign_values_helper(q2h, sign_helper, qx+4);
+#endif
+ float dnew = convert_to_q8_k_r8(k, 1.f/124, qx, helper.val, block, y[i].qs);
+ y[i].d[k] = GGML_FP32_TO_FP16(d*dnew);
+ }
+ }
+ y += nb;
+ }
+}
+
+void iqk_convert_iq2_xs_q8_0_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) {
+ GGML_ASSERT(n%QK_K == 0);
+ GGML_ASSERT(nrc_x%8 == 0);
+
+ int nb = n/QK_K;
+
+ const block_iq2_xs * x8[8];
+
+ block_q8_0_r8 * y = (block_q8_0_r8 *)vy;
+
+ float all_s[64];
+
+ uint32_t block[8];
+
+ union { __m256i vec; int16_t val[16]; } helper;
+ __m256i qx[8];
+
+#if !(defined HAVE_FANCY_SIMD && defined __AVX512VPOPCNTDQ__)
+ DequantizerIQ2XS::Helper sign_helper;
+#endif
+
+ for (int ix = 0; ix < nrc_x; ix += 8) {
+ for (int k = 0; k < 8; ++k) x8[k] = (const block_iq2_xs *)((const char *)vx + (ix + k)*bx);
+ for (int i = 0; i < nb; ++i) {
+ for (int k = 0; k < 8; ++k) {
+ float d = 0.125f * GGML_FP16_TO_FP32(x8[k][i].d);
+ helper.vec = DequantizerIQ2XS::make_scales(x8[k][i].scales);
+ auto q2l = _mm256_loadu_si256((const __m256i *)x8[k][i].qs+0);
+ auto q2h = _mm256_loadu_si256((const __m256i *)x8[k][i].qs+1);
+ DequantizerIQ2XS::make4(q2l, _mm256_set1_epi16(511), qx+0);
+ DequantizerIQ2XS::make4(q2h, _mm256_set1_epi16(511), qx+4);
+#if defined HAVE_FANCY_SIMD && defined __AVX512VPOPCNTDQ__
+ DequantizerIQ2XS::sign_values_popcnt(q2l, qx+0);
+ DequantizerIQ2XS::sign_values_popcnt(q2h, qx+4);
+#else
+ DequantizerIQ2XS::sign_values_helper(q2l, sign_helper, qx+0);
+ DequantizerIQ2XS::sign_values_helper(q2h, sign_helper, qx+4);
+#endif
+ for (int ib32 = 0; ib32 < 8; ++ib32) {
+ auto q16_l = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(qx[ib32]));
+ auto q16_h = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(qx[ib32], 1));
+ q16_l = _mm256_mullo_epi16(q16_l, _mm256_set1_epi16(helper.val[2*ib32+0]));
+ q16_h = _mm256_mullo_epi16(q16_h, _mm256_set1_epi16(helper.val[2*ib32+1]));
+ auto abs_q16_l = _mm256_sign_epi16(q16_l, q16_l);
+ auto abs_q16_h = _mm256_sign_epi16(q16_h, q16_h);
+ auto max_q16 = _mm256_max_epi16(abs_q16_l, abs_q16_h);
+ auto max_q32 = _mm256_cvtepi16_epi32(_mm_max_epi16(_mm256_castsi256_si128(max_q16), _mm256_extracti128_si256(max_q16, 1)));
+ auto imax4 = _mm_max_epi32(_mm256_castsi256_si128(max_q32), _mm256_extracti128_si256(max_q32, 1));
+ auto max4 = _mm_cvtepi32_ps(imax4);
+ max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
+ max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
+ float max = _mm_cvtss_f32(max4) / 127;
+ all_s[8*ib32+k] = d*max;
+ if (max > 1e-9f) {
+ auto scale = _mm256_set1_ps(1/max);
+ auto i0 = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(q16_l));
+ auto i1 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(q16_l, 1));
+ auto i2 = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(q16_h));
+ auto i3 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(q16_h, 1));
+ i0 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(scale, _mm256_cvtepi32_ps(i0)), _MM_ROUND_NEAREST));
+ i1 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(scale, _mm256_cvtepi32_ps(i1)), _MM_ROUND_NEAREST));
+ i2 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(scale, _mm256_cvtepi32_ps(i2)), _MM_ROUND_NEAREST));
+ i3 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(scale, _mm256_cvtepi32_ps(i3)), _MM_ROUND_NEAREST));
+ i0 = _mm256_packs_epi32(i0, i1);
+ i2 = _mm256_packs_epi32(i2, i3);
+ i0 = _mm256_packs_epi16(i0, i2);
+ i0 = _mm256_permutevar8x32_epi32(i0, _mm256_setr_epi32(0, 4, 1, 5, 2, 6, 3, 7));
+ _mm256_storeu_si256((__m256i *)block, i0);
+ } else {
+ _mm256_storeu_si256((__m256i *)block, _mm256_setzero_si256());
+ }
+ auto qs = (uint32_t *)y[ib32].qs;
+ for (int l = 0; l < 4; ++l) {
+ qs[8*l + k + 0] = block[l + 0];
+ qs[8*l + k + 32] = block[l + 4];
+ }
+ }
+ }
+ for (int ib32 = 0; ib32 < 8; ++ib32) {
+ _mm_storeu_si128((__m128i *)y[ib32].d, _mm256_cvtps_ph(_mm256_loadu_ps(all_s + 8*ib32), _MM_FROUND_TO_NEAREST_INT));
+ }
+ y += QK_K/32;
+ }
+ }
+}
+
+template <int nrc_y>
+static void mul_mat_iq2_xs_q8_2_X4(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ assert(n % QK_K == 0);
+ const int nb = n / QK_K;
+
+ Q8<nrc_y, block_q8_2_x4> q8(info);
+
+ DequantizerIQ2XS deq(vx, bx);
+
+ __m256 accd[nrc_y];
+ __m256 scales[2];
+ float d8[8*nrc_y];
+ __m256i us[4];
+
+ uint8_t k_shuff[32] = {0, 1, 4, 5, 8, 9, 12, 13, 2, 3, 6, 7, 10, 11, 14, 15, 0, 1, 4, 5, 8, 9, 12, 13, 2, 3, 6, 7, 10, 11, 14, 15};
+ auto shuff = _mm256_loadu_si256((const __m256i *)k_shuff);
+
+ for (int ix = 0; ix < nrc_x; ++ix) {
+
+ for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm256_setzero_ps();
+
+ deq.new_row(ix);
+
+ for (int i = 0; i < nb; ++i) {
+
+ deq.d = 0.125f * GGML_FP16_TO_FP32(deq.x[i].d);
+ auto vd = _mm256_set1_ps(deq.d);
+ auto sc16 = _mm256_shuffle_epi8(DequantizerIQ2XS::make_scales(deq.x[i].scales), shuff);
+ scales[0] = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_castsi256_si128(sc16))));
+ scales[1] = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(sc16, 1))));
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto d4_1 = _mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)(q8.y[iy][2*i+0].d)));
+ auto d4_2 = _mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)(q8.y[iy][2*i+1].d)));
+ auto dy = _mm256_castsi256_ps(_mm256_slli_epi32(MM256_SET_M128I(d4_2, d4_1), 16));
+ if constexpr (nrc_y == 1) {
+ auto dyh = _mm256_extractf128_ps(dy, 1);
+ scales[0] = _mm256_mul_ps(scales[0], _mm256_set_m128(_mm256_castps256_ps128(dy), _mm256_castps256_ps128(dy)));
+ scales[1] = _mm256_mul_ps(scales[1], _mm256_set_m128(dyh, dyh));
+ } else {
+ _mm256_storeu_ps(d8 + 8*iy, dy);
+ }
+ }
+
+ for (int j = 0; j < QK_K/128; ++j) {
+
+ if constexpr (nrc_y == 1) {
+ auto qs = q8.y[0][2*i+j].qs;
+ for (int k = 0; k < 4; ++k) us[k] = _mm256_loadu_si256((const __m256i*)qs+k);
+ deq.prepare_signed(i, j, deq.bits.values, us);
+#ifdef HAVE_FANCY_SIMD
+ auto sumi1 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), deq.bits.values[0], us[0]);
+ auto sumi2 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), deq.bits.values[1], us[1]);
+ auto sumi3 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), deq.bits.values[2], us[2]);
+ auto sumi4 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), deq.bits.values[3], us[3]);
+ sumi1 = _mm256_add_epi32(_mm256_unpacklo_epi32(sumi1, sumi2), _mm256_unpackhi_epi32(sumi1, sumi2));
+ sumi3 = _mm256_add_epi32(_mm256_unpacklo_epi32(sumi3, sumi4), _mm256_unpackhi_epi32(sumi3, sumi4));
+ sumi1 = _mm256_add_epi32(_mm256_unpacklo_epi64(sumi1, sumi3), _mm256_unpackhi_epi64(sumi1, sumi3));
+#else
+ auto sumi1 = _mm256_maddubs_epi16(deq.bits.values[0], us[0]);
+ auto sumi2 = _mm256_maddubs_epi16(deq.bits.values[1], us[1]);
+ auto sumi3 = _mm256_maddubs_epi16(deq.bits.values[2], us[2]);
+ auto sumi4 = _mm256_maddubs_epi16(deq.bits.values[3], us[3]);
+ sumi1 = _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_add_epi16(_mm256_unpacklo_epi32(sumi1, sumi2), _mm256_unpackhi_epi32(sumi1, sumi2)));
+ sumi3 = _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_add_epi16(_mm256_unpacklo_epi32(sumi3, sumi4), _mm256_unpackhi_epi32(sumi3, sumi4)));
+ sumi1 = _mm256_add_epi32(_mm256_unpacklo_epi64(sumi1, sumi3), _mm256_unpackhi_epi64(sumi1, sumi3));
+#endif
+ accd[0] = _mm256_fmadd_ps(scales[j], _mm256_cvtepi32_ps(sumi1), accd[0]);
+ }
+ else {
+ deq.prepare_signed(i, j, us);
+
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto qs = q8.y[iy][2*i+j].qs;
+#ifdef HAVE_FANCY_SIMD
+ // 0...31
+ auto sumi1 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), us[0], _mm256_sign_epi8(_mm256_loadu_si256((const __m256i*)qs+0), deq.bits.values[0]));
+ // 32...63
+ auto sumi2 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), us[1], _mm256_sign_epi8(_mm256_loadu_si256((const __m256i*)qs+1), deq.bits.values[1]));
+ // 64...95
+ auto sumi3 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), us[2], _mm256_sign_epi8(_mm256_loadu_si256((const __m256i*)qs+2), deq.bits.values[2]));
+ // 96...128
+ auto sumi4 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), us[3], _mm256_sign_epi8(_mm256_loadu_si256((const __m256i*)qs+3), deq.bits.values[3]));
+ // 0...3, 32...35, 4....7, 36...39, 16...19, 48...51, 20...23, 52...56 +
+ // 8..11, 40...43, 12...15, 44...47, 24...27, 56...59, 28...31, 60...63
+ // b0 b2 b0 b2 b1 b3 b1 b3
+ sumi1 = _mm256_add_epi32(_mm256_unpacklo_epi32(sumi1, sumi2), _mm256_unpackhi_epi32(sumi1, sumi2));
+ // same as above + 64, so
+ // b4 b6, b4 b6 b5 b7 b5 b7
+ sumi3 = _mm256_add_epi32(_mm256_unpacklo_epi32(sumi3, sumi4), _mm256_unpackhi_epi32(sumi3, sumi4));
+ // b0 b2 b4 b6 b1 b3 b5 b7 +
+ // b0 b2 b4 b6 b1 b3 b5 b7
+ sumi1 = _mm256_add_epi32(_mm256_unpacklo_epi64(sumi1, sumi3), _mm256_unpackhi_epi64(sumi1, sumi3));
+#else
+ auto sumi1 = _mm256_maddubs_epi16(us[0], _mm256_sign_epi8(_mm256_loadu_si256((const __m256i*)qs+0), deq.bits.values[0]));
+ auto sumi2 = _mm256_maddubs_epi16(us[1], _mm256_sign_epi8(_mm256_loadu_si256((const __m256i*)qs+1), deq.bits.values[1]));
+ auto sumi3 = _mm256_maddubs_epi16(us[2], _mm256_sign_epi8(_mm256_loadu_si256((const __m256i*)qs+2), deq.bits.values[2]));
+ auto sumi4 = _mm256_maddubs_epi16(us[3], _mm256_sign_epi8(_mm256_loadu_si256((const __m256i*)qs+3), deq.bits.values[3]));
+ sumi1 = _mm256_add_epi16(_mm256_unpacklo_epi32(sumi1, sumi2), _mm256_unpackhi_epi32(sumi1, sumi2));
+ sumi3 = _mm256_add_epi16(_mm256_unpacklo_epi32(sumi3, sumi4), _mm256_unpackhi_epi32(sumi3, sumi4));
+ sumi1 = _mm256_add_epi16(_mm256_unpacklo_epi64(sumi1, sumi3), _mm256_unpackhi_epi64(sumi1, sumi3));
+ sumi1 = _mm256_madd_epi16(_mm256_set1_epi16(1), sumi1);
+#endif
+ auto dy4 = _mm_loadu_ps(d8 + 8*iy + 4*j);
+ auto d4d8 = _mm256_mul_ps(scales[j], _mm256_set_m128(dy4, dy4));
+ accd[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi1), accd[iy]);
+ }
+ }
+
+ }
+
+ }
+
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ info.store(ix, iy, hsum_float_8(accd[iy]));
+ }
+
+ }
+}
+
+void iqk_convert_iq2_s_q8_0_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) {
+ GGML_ASSERT(n%QK_K == 0);
+ GGML_ASSERT(nrc_x%8 == 0);
+
+ int nb = n/QK_K;
+
+ const block_iq2_s * x8[8];
+
+ block_q8_0_r8 * y = (block_q8_0_r8 *)vy;
+
+ float all_s[64];
+
+ uint32_t block[8];
+
+ union { __m256i vec; int16_t val[16]; } helper;
+ __m256i qx[8];
+
+ SignHelper sh;
+
+ for (int ix = 0; ix < nrc_x; ix += 8) {
+ for (int k = 0; k < 8; ++k) x8[k] = (const block_iq2_s *)((const char *)vx + (ix + k)*bx);
+ for (int i = 0; i < nb; ++i) {
+ for (int k = 0; k < 8; ++k) {
+ float d = 0.125f * GGML_FP16_TO_FP32(x8[k][i].d);
+ helper.vec = DequantizerIQ2S::make_scales(x8[k][i].scales);
+ DequantizerIQ2S::prepare(x8[k][i].qs+ 0, x8[k][i].qh+0, (const uint16_t *)(x8[k][i].qs + QK_K/8) + 0, sh, qx+0);
+ DequantizerIQ2S::prepare(x8[k][i].qs+16, x8[k][i].qh+4, (const uint16_t *)(x8[k][i].qs + QK_K/8) + 8, sh, qx+4);
+ for (int ib32 = 0; ib32 < 8; ++ib32) {
+ auto q16_l = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(qx[ib32]));
+ auto q16_h = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(qx[ib32], 1));
+ q16_l = _mm256_mullo_epi16(q16_l, _mm256_set1_epi16(helper.val[2*ib32+0]));
+ q16_h = _mm256_mullo_epi16(q16_h, _mm256_set1_epi16(helper.val[2*ib32+1]));
+ auto abs_q16_l = _mm256_sign_epi16(q16_l, q16_l);
+ auto abs_q16_h = _mm256_sign_epi16(q16_h, q16_h);
+ auto max_q16 = _mm256_max_epi16(abs_q16_l, abs_q16_h);
+ auto max_q32 = _mm256_cvtepi16_epi32(_mm_max_epi16(_mm256_castsi256_si128(max_q16), _mm256_extracti128_si256(max_q16, 1)));
+ auto imax4 = _mm_max_epi32(_mm256_castsi256_si128(max_q32), _mm256_extracti128_si256(max_q32, 1));
+ auto max4 = _mm_cvtepi32_ps(imax4);
+ max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
+ max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
+ float max = _mm_cvtss_f32(max4) / 127;
+ all_s[8*ib32+k] = d*max;
+ if (max > 1e-9f) {
+ auto scale = _mm256_set1_ps(1/max);
+ auto i0 = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(q16_l));
+ auto i1 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(q16_l, 1));
+ auto i2 = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(q16_h));
+ auto i3 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(q16_h, 1));
+ i0 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(scale, _mm256_cvtepi32_ps(i0)), _MM_ROUND_NEAREST));
+ i1 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(scale, _mm256_cvtepi32_ps(i1)), _MM_ROUND_NEAREST));
+ i2 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(scale, _mm256_cvtepi32_ps(i2)), _MM_ROUND_NEAREST));
+ i3 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(scale, _mm256_cvtepi32_ps(i3)), _MM_ROUND_NEAREST));
+ i0 = _mm256_packs_epi32(i0, i1);
+ i2 = _mm256_packs_epi32(i2, i3);
+ i0 = _mm256_packs_epi16(i0, i2);
+ i0 = _mm256_permutevar8x32_epi32(i0, _mm256_setr_epi32(0, 4, 1, 5, 2, 6, 3, 7));
+ _mm256_storeu_si256((__m256i *)block, i0);
+ } else {
+ _mm256_storeu_si256((__m256i *)block, _mm256_setzero_si256());
+ }
+ auto qs = (uint32_t *)y[ib32].qs;
+ for (int l = 0; l < 4; ++l) {
+ qs[8*l + k + 0] = block[l + 0];
+ qs[8*l + k + 32] = block[l + 4];
+ }
+ }
+ }
+ for (int ib32 = 0; ib32 < 8; ++ib32) {
+ _mm_storeu_si128((__m128i *)y[ib32].d, _mm256_cvtps_ph(_mm256_loadu_ps(all_s + 8*ib32), _MM_FROUND_TO_NEAREST_INT));
+ }
+ y += QK_K/32;
+ }
+ }
+}
+
+void iqk_convert_iq2_s_q8_k_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) {
+ GGML_ASSERT(n%QK_K == 0);
+ GGML_ASSERT(nrc_x%8 == 0);
+
+ int nb = n/QK_K;
+
+ const block_iq2_s * x8[8];
+
+ block_q8_k_r8 * y = (block_q8_k_r8 *)vy;
+
+ uint32_t block[8];
+
+ union { __m256i vec; int16_t val[16]; } helper;
+ __m256i qx[8];
+
+ SignHelper sh;
+
+ for (int ix = 0; ix < nrc_x; ix += 8) {
+ for (int k = 0; k < 8; ++k) x8[k] = (const block_iq2_s *)((const char *)vx + (ix + k)*bx);
+ for (int i = 0; i < nb; ++i) {
+ for (int k = 0; k < 8; ++k) {
+ float d = 0.125f * GGML_FP16_TO_FP32(x8[k][i].d);
+ helper.vec = DequantizerIQ2S::make_scales(x8[k][i].scales);
+ DequantizerIQ2S::prepare(x8[k][i].qs+ 0, x8[k][i].qh+0, (const uint16_t *)(x8[k][i].qs + QK_K/8) + 0, sh, qx+0);
+ DequantizerIQ2S::prepare(x8[k][i].qs+16, x8[k][i].qh+4, (const uint16_t *)(x8[k][i].qs + QK_K/8) + 8, sh, qx+4);
+ float dnew = convert_to_q8_k_r8(k, 1.f/124, qx, helper.val, block, y[i].qs);
+ y[i].d[k] = GGML_FP32_TO_FP16(d*dnew);
+ }
+ }
+ y += nb;
+ }
+}
+
+template <int nrc_y>
+static void mul_mat_iq2_s_q8_2_X4(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ assert(n % QK_K == 0);
+ const int nb = n / QK_K;
+
+ Q8<nrc_y, block_q8_2_x4> q8(info);
+
+ DequantizerIQ2S deq(vx, bx);
+
+ __m256 accd[nrc_y];
+ __m256 scales[2];
+ float d8[8*nrc_y];
+ __m256i us[4];
+
+ uint8_t k_shuff[32] = {0, 1, 4, 5, 8, 9, 12, 13, 2, 3, 6, 7, 10, 11, 14, 15, 0, 1, 4, 5, 8, 9, 12, 13, 2, 3, 6, 7, 10, 11, 14, 15};
+ auto shuff = _mm256_loadu_si256((const __m256i *)k_shuff);
+
+ for (int ix = 0; ix < nrc_x; ++ix) {
+
+ for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm256_setzero_ps();
+
+ deq.new_row(ix);
+
+ for (int i = 0; i < nb; ++i) {
+
+ deq.d = 0.125f * GGML_FP16_TO_FP32(deq.x[i].d);
+ auto vd = _mm256_set1_ps(deq.d);
+ auto sc16 = _mm256_shuffle_epi8(DequantizerIQ2S::make_scales(deq.x[i].scales), shuff);
+ scales[0] = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_castsi256_si128(sc16))));
+ scales[1] = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(sc16, 1))));
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto d4_1 = _mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)(q8.y[iy][2*i+0].d)));
+ auto d4_2 = _mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)(q8.y[iy][2*i+1].d)));
+ auto dy = _mm256_castsi256_ps(_mm256_slli_epi32(MM256_SET_M128I(d4_2, d4_1), 16));
+ if constexpr (nrc_y == 1) {
+ auto dyh = _mm256_extractf128_ps(dy, 1);
+ scales[0] = _mm256_mul_ps(scales[0], _mm256_set_m128(_mm256_castps256_ps128(dy), _mm256_castps256_ps128(dy)));
+ scales[1] = _mm256_mul_ps(scales[1], _mm256_set_m128(dyh, dyh));
+ } else {
+ _mm256_storeu_ps(d8 + 8*iy, dy);
+ }
+ }
+
+ for (int j = 0; j < QK_K/128; ++j) {
+
+ if constexpr (nrc_y == 1) {
+ auto qs = q8.y[0][2*i+j].qs;
+ for (int k = 0; k < 4; ++k) us[k] = _mm256_loadu_si256((const __m256i*)qs+k);
+ deq.prepare_signed(i, j, deq.bits.values, us);
+#ifdef HAVE_FANCY_SIMD
+ auto sumi1 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), deq.bits.values[0], us[0]);
+ auto sumi2 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), deq.bits.values[1], us[1]);
+ auto sumi3 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), deq.bits.values[2], us[2]);
+ auto sumi4 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), deq.bits.values[3], us[3]);
+ sumi1 = _mm256_add_epi32(_mm256_unpacklo_epi32(sumi1, sumi2), _mm256_unpackhi_epi32(sumi1, sumi2));
+ sumi3 = _mm256_add_epi32(_mm256_unpacklo_epi32(sumi3, sumi4), _mm256_unpackhi_epi32(sumi3, sumi4));
+ sumi1 = _mm256_add_epi32(_mm256_unpacklo_epi64(sumi1, sumi3), _mm256_unpackhi_epi64(sumi1, sumi3));
+#else
+ auto sumi1 = _mm256_maddubs_epi16(deq.bits.values[0], us[0]);
+ auto sumi2 = _mm256_maddubs_epi16(deq.bits.values[1], us[1]);
+ auto sumi3 = _mm256_maddubs_epi16(deq.bits.values[2], us[2]);
+ auto sumi4 = _mm256_maddubs_epi16(deq.bits.values[3], us[3]);
+ sumi1 = _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_add_epi16(_mm256_unpacklo_epi32(sumi1, sumi2), _mm256_unpackhi_epi32(sumi1, sumi2)));
+ sumi3 = _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_add_epi16(_mm256_unpacklo_epi32(sumi3, sumi4), _mm256_unpackhi_epi32(sumi3, sumi4)));
+ sumi1 = _mm256_add_epi32(_mm256_unpacklo_epi64(sumi1, sumi3), _mm256_unpackhi_epi64(sumi1, sumi3));
+#endif
+ accd[0] = _mm256_fmadd_ps(scales[j], _mm256_cvtepi32_ps(sumi1), accd[0]);
+ }
+ else {
+ deq.prepare_signed(i, j, us);
+
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto qs = q8.y[iy][2*i+j].qs;
+#ifdef HAVE_FANCY_SIMD
+ // 0...31
+ auto sumi1 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), us[0], _mm256_sign_epi8(_mm256_loadu_si256((const __m256i*)qs+0), deq.bits.values[0]));
+ // 32...63
+ auto sumi2 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), us[1], _mm256_sign_epi8(_mm256_loadu_si256((const __m256i*)qs+1), deq.bits.values[1]));
+ // 64...95
+ auto sumi3 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), us[2], _mm256_sign_epi8(_mm256_loadu_si256((const __m256i*)qs+2), deq.bits.values[2]));
+ // 96...128
+ auto sumi4 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), us[3], _mm256_sign_epi8(_mm256_loadu_si256((const __m256i*)qs+3), deq.bits.values[3]));
+ // 0...3, 32...35, 4....7, 36...39, 16...19, 48...51, 20...23, 52...56 +
+ // 8..11, 40...43, 12...15, 44...47, 24...27, 56...59, 28...31, 60...63
+ // b0 b2 b0 b2 b1 b3 b1 b3
+ sumi1 = _mm256_add_epi32(_mm256_unpacklo_epi32(sumi1, sumi2), _mm256_unpackhi_epi32(sumi1, sumi2));
+ // same as above + 64, so
+ // b4 b6, b4 b6 b5 b7 b5 b7
+ sumi3 = _mm256_add_epi32(_mm256_unpacklo_epi32(sumi3, sumi4), _mm256_unpackhi_epi32(sumi3, sumi4));
+ // b0 b2 b4 b6 b1 b3 b5 b7 +
+ // b0 b2 b4 b6 b1 b3 b5 b7
+ sumi1 = _mm256_add_epi32(_mm256_unpacklo_epi64(sumi1, sumi3), _mm256_unpackhi_epi64(sumi1, sumi3));
+#else
+ auto sumi1 = _mm256_maddubs_epi16(us[0], _mm256_sign_epi8(_mm256_loadu_si256((const __m256i*)qs+0), deq.bits.values[0]));
+ auto sumi2 = _mm256_maddubs_epi16(us[1], _mm256_sign_epi8(_mm256_loadu_si256((const __m256i*)qs+1), deq.bits.values[1]));
+ auto sumi3 = _mm256_maddubs_epi16(us[2], _mm256_sign_epi8(_mm256_loadu_si256((const __m256i*)qs+2), deq.bits.values[2]));
+ auto sumi4 = _mm256_maddubs_epi16(us[3], _mm256_sign_epi8(_mm256_loadu_si256((const __m256i*)qs+3), deq.bits.values[3]));
+ sumi1 = _mm256_add_epi16(_mm256_unpacklo_epi32(sumi1, sumi2), _mm256_unpackhi_epi32(sumi1, sumi2));
+ sumi3 = _mm256_add_epi16(_mm256_unpacklo_epi32(sumi3, sumi4), _mm256_unpackhi_epi32(sumi3, sumi4));
+ sumi1 = _mm256_add_epi16(_mm256_unpacklo_epi64(sumi1, sumi3), _mm256_unpackhi_epi64(sumi1, sumi3));
+ sumi1 = _mm256_madd_epi16(_mm256_set1_epi16(1), sumi1);
+#endif
+ auto dy4 = _mm_loadu_ps(d8 + 8*iy + 4*j);
+ auto d4d8 = _mm256_mul_ps(scales[j], _mm256_set_m128(dy4, dy4));
+ accd[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi1), accd[iy]);
+ }
+ }
+
+ }
+
+ }
+
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ info.store(ix, iy, hsum_float_8(accd[iy]));
+ }
+
+ }
+}
+
+void iqk_convert_iq3_xxs_q8_k_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) {
+ GGML_ASSERT(n%QK_K == 0);
+ GGML_ASSERT(nrc_x%8 == 0);
+
+ int nb = n/QK_K;
+
+ const block_iq3_xxs * x8[8];
+
+ block_q8_k_r8 * y = (block_q8_k_r8 *)vy;
+
+ int16_t ls[16];
+ EvenSignHelper esh;
+
+ __m256i values[8];
+ uint32_t block[8];
+ uint32_t aux32;
+
+ for (int ix = 0; ix < nrc_x; ix += 8) {
+ for (int k = 0; k < 8; ++k) x8[k] = (const block_iq3_xxs *)((const char *)vx + (ix + k)*bx);
+ for (int i = 0; i < nb; ++i) {
+ for (int k = 0; k < 8; ++k) {
+ float d = 0.25f * GGML_FP16_TO_FP32(x8[k][i].d);
+ auto qs = x8[k][i].qs;
+ auto sas = qs + QK_K/4;
+ for (int ib32 = 0; ib32 < 8; ++ib32) {
+ std::memcpy(&aux32, sas + 4*ib32, sizeof(uint32_t));
+ ls[2*ib32 + 0] = (2*(aux32 >> 28) + 1);
+ ls[2*ib32 + 1] = ls[2*ib32 + 0];
+ values[ib32] = _mm256_set_epi32(iq3xxs_grid[qs[7]], iq3xxs_grid[qs[6]], iq3xxs_grid[qs[5]], iq3xxs_grid[qs[4]],
+ iq3xxs_grid[qs[3]], iq3xxs_grid[qs[2]], iq3xxs_grid[qs[1]], iq3xxs_grid[qs[0]]);
+ esh.sign_value(aux32, values[ib32]);
+ qs += 8;
+ }
+ float dnew = convert_to_q8_k_r8(k, 1.f/124, values, ls, block, y[i].qs);
+ y[i].d[k] = GGML_FP32_TO_FP16(d*dnew);
+ }
+ }
+ y += nb;
+ }
+}
+
void iqk_convert_iq3_xxs_q8_0_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) {
GGML_ASSERT(n%QK_K == 0);
GGML_ASSERT(nrc_x%8 == 0);
@@ -1881,6 +2550,49 @@ void iqk_convert_iq3_xxs_q8_0_r8(int n, const void * vx, size_t bx, void * vy, i
}
}
+void iqk_convert_iq3_s_q8_k_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) {
+ GGML_ASSERT(n%QK_K == 0);
+ GGML_ASSERT(nrc_x%8 == 0);
+
+ int nb = n/QK_K;
+
+ const block_iq3_s * x8[8];
+
+ block_q8_k_r8 * y = (block_q8_k_r8 *)vy;
+
+ int16_t ls[16];
+ SignHelper sh;
+ IndexHelperIQ3S helper;
+
+ uint32_t block[8];
+ __m256i values[8];
+
+ for (int ix = 0; ix < nrc_x; ix += 8) {
+ for (int k = 0; k < 8; ++k) x8[k] = (const block_iq3_s *)((const char *)vx + (ix + k)*bx);
+ for (int i = 0; i < nb; ++i) {
+ for (int k = 0; k < 8; ++k) {
+ float d = GGML_FP16_TO_FP32(x8[k][i].d);
+ auto qs = x8[k][i].qs;
+ auto qh = x8[k][i].qh;
+ auto signs = (const uint16_t *)x8[k][i].signs;
+ helper.make2(qs+ 0, qh+0, values+0);
+ helper.make2(qs+16, qh+2, values+2);
+ sh.sign_4_values(signs+0, values+0);
+ helper.make2(qs+32, qh+4, values+4);
+ helper.make2(qs+48, qh+6, values+6);
+ sh.sign_4_values(signs+8, values+4);
+ for (int ib32 = 0; ib32 < 8; ++ib32) {
+ ls[2*ib32 + 0] = (2*((x8[k][i].scales[ib32/2] >> 4*(ib32%2)) & 0xf) + 1);
+ ls[2*ib32 + 1] = ls[2*ib32 + 0];
+ }
+ float dnew = convert_to_q8_k_r8(k, 1.f/127, values, ls, block, y[i].qs);
+ y[i].d[k] = GGML_FP32_TO_FP16(d*dnew);
+ }
+ }
+ y += nb;
+ }
+}
+
void iqk_convert_iq3_s_q8_0_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) {
GGML_ASSERT(n%QK_K == 0);
GGML_ASSERT(nrc_x%8 == 0);
@@ -1952,40 +2664,58 @@ bool iqk_set_kernels_iquants(int ne00, int typeA, int typeB, std::array<mul_mat_
if (ne00%QK_K != 0) return false;
- if (ggml_type(typeA) == GGML_TYPE_IQ2_XXS) {
- if (ggml_type(typeB) == GGML_TYPE_Q8_2_X4) {
- IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_K_q8_2_IQ_N, DequantizerIQ2XXS, kernels);
- func16 = nullptr;
- return true;
- }
- return false;
- }
-
- if (ggml_type(typeA) == GGML_TYPE_IQ3_XXS) {
- if (ggml_type(typeB) == GGML_TYPE_Q8_2_X4) {
- IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_K_q8_2_IQ_N, DequantizerIQ3XXS, kernels);
- func16 = nullptr;
- return true;
- }
- return false;
- }
-
- if (ggml_type(typeA) == GGML_TYPE_IQ3_S) {
- if (ggml_type(typeB) == GGML_TYPE_Q8_2_X4) {
- //IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_K_q8_2_IQ_N, DequantizerIQ3S, kernels);
- kernels[0] = mul_mat_qX_K_q8_2_IQ_N<DequantizerIQ3S, 1, 8>;
- kernels[1] = mul_mat_qX_K_q8_2_IQ_N<DequantizerIQ3S, 2, 8>;
- kernels[2] = mul_mat_qX_K_q8_2_IQ_N<DequantizerIQ3S, 3, 8>;
- kernels[3] = mul_mat_qX_K_q8_2_IQ_N<DequantizerIQ3S, 4, 8>;
- kernels[4] = mul_mat_qX_K_q8_2_IQ_N<DequantizerIQ3S, 5, 8>;
- kernels[5] = mul_mat_qX_K_q8_2_IQ_N<DequantizerIQ3S, 6, 8>;
- kernels[6] = mul_mat_qX_K_q8_2_IQ_N<DequantizerIQ3S, 7, 8>;
- kernels[7] = mul_mat_qX_K_q8_2_IQ_N<DequantizerIQ3S, 8, 8>;
- func16 = nullptr;
- return true;
- }
- return false;
- }
+ //if (ggml_type(typeA) == GGML_TYPE_IQ2_XXS) {
+ // if (ggml_type(typeB) == GGML_TYPE_Q8_2_X4) {
+ // IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_K_q8_2_IQ_N, DequantizerIQ2XXS, kernels);
+ // func16 = nullptr;
+ // return true;
+ // }
+ // return false;
+ //}
+
+ //if (ggml_type(typeA) == GGML_TYPE_IQ2_XS) {
+ // if (ggml_type(typeB) == GGML_TYPE_Q8_2_X4) {
+ // IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq2_xs_q8_2_X4, kernels);
+ // func16 = nullptr;
+ // return true;
+ // }
+ // return false;
+ //}
+
+ //if (ggml_type(typeA) == GGML_TYPE_IQ2_S) {
+ // if (ggml_type(typeB) == GGML_TYPE_Q8_2_X4) {
+ // IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq2_s_q8_2_X4, kernels);
+ // func16 = nullptr;
+ // return true;
+ // }
+ // return false;
+ //}
+
+ //if (ggml_type(typeA) == GGML_TYPE_IQ3_XXS) {
+ // if (ggml_type(typeB) == GGML_TYPE_Q8_2_X4) {
+ // IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_K_q8_2_IQ_N, DequantizerIQ3XXS, kernels);
+ // func16 = nullptr;
+ // return true;
+ // }
+ // return false;
+ //}
+
+ //if (ggml_type(typeA) == GGML_TYPE_IQ3_S) {
+ // if (ggml_type(typeB) == GGML_TYPE_Q8_2_X4) {
+ // //IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_K_q8_2_IQ_N, DequantizerIQ3S, kernels);
+ // kernels[0] = mul_mat_qX_K_q8_2_IQ_N<DequantizerIQ3S, 1, 8>;
+ // kernels[1] = mul_mat_qX_K_q8_2_IQ_N<DequantizerIQ3S, 2, 8>;
+ // kernels[2] = mul_mat_qX_K_q8_2_IQ_N<DequantizerIQ3S, 3, 8>;
+ // kernels[3] = mul_mat_qX_K_q8_2_IQ_N<DequantizerIQ3S, 4, 8>;
+ // kernels[4] = mul_mat_qX_K_q8_2_IQ_N<DequantizerIQ3S, 5, 8>;
+ // kernels[5] = mul_mat_qX_K_q8_2_IQ_N<DequantizerIQ3S, 6, 8>;
+ // kernels[6] = mul_mat_qX_K_q8_2_IQ_N<DequantizerIQ3S, 7, 8>;
+ // kernels[7] = mul_mat_qX_K_q8_2_IQ_N<DequantizerIQ3S, 8, 8>;
+ // func16 = nullptr;
+ // return true;
+ // }
+ // return false;
+ //}
if (ggml_type(typeB) != GGML_TYPE_Q8_K) {
return false;
@@ -2044,9 +2774,11 @@ bool iqk_set_kernels_iquants(int ne00, int typeA, int typeB, std::array<mul_mat_
bool iqk_convert_iquants_q80_r8(int type, int n, const void * vx, size_t bx, void * vy, int nrc_x) {
if (n%QK_K != 0 || nrc_x%8 != 0) return false;
switch (ggml_type(type)) {
- case GGML_TYPE_IQ2_XXS: iqk_convert_iq2_xxs_q8_0_r8(n, vx, bx, vy, nrc_x); break;
- case GGML_TYPE_IQ3_XXS: iqk_convert_iq3_xxs_q8_0_r8(n, vx, bx, vy, nrc_x); break;
- case GGML_TYPE_IQ3_S : iqk_convert_iq3_s_q8_0_r8 (n, vx, bx, vy, nrc_x); break;
+ case GGML_TYPE_IQ2_XXS: iqk_convert_iq2_xxs_q8_k_r8(n, vx, bx, vy, nrc_x); break;
+ case GGML_TYPE_IQ2_XS : iqk_convert_iq2_xs_q8_k_r8 (n, vx, bx, vy, nrc_x); break;
+ case GGML_TYPE_IQ2_S : iqk_convert_iq2_s_q8_k_r8 (n, vx, bx, vy, nrc_x); break;
+ case GGML_TYPE_IQ3_XXS: iqk_convert_iq3_xxs_q8_k_r8(n, vx, bx, vy, nrc_x); break;
+ case GGML_TYPE_IQ3_S : iqk_convert_iq3_s_q8_k_r8 (n, vx, bx, vy, nrc_x); break;
default: return false;
}
return true;
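
With the new switch cases, IQ2_XS and IQ2_S also take the repack path. A hedged usage sketch of the dispatcher; the real call sites live in iqk_mul_mat.cpp (not part of this hunk) and the buffer names are illustrative:

#include <vector>

// n must be a multiple of QK_K, nrc_x a multiple of 8; the destination holds
// (nrc_x/8) * (n/QK_K) block_q8_k_r8 superblocks, each covering 8 rows.
std::vector<block_q8_k_r8> repacked((nrc_x/8) * (n/QK_K));
if (!iqk_convert_iquants_q80_r8(GGML_TYPE_IQ2_XS, n, weight_rows, row_stride_bytes,
                                repacked.data(), nrc_x)) {
    // type not handled here: fall back to the per-type kernels
}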
diff --git a/ggml/src/iqk/iqk_gemm_kquants.cpp b/ggml/src/iqk/iqk_gemm_kquants.cpp
index 589fbc26..43eff43c 100644
--- a/ggml/src/iqk/iqk_gemm_kquants.cpp
+++ b/ggml/src/iqk/iqk_gemm_kquants.cpp
@@ -6,6 +6,7 @@
#define GGML_COMMON_IMPL_C
#include "ggml-common.h"
+#include "ggml-quants.h"
#ifdef __x86_64__
@@ -860,6 +861,175 @@ static void mul_mat_qX_K_q8_2_X4_T(int n, const void * vx, size_t bx, const Data
}
}
+struct DequantizerQ6K_AVX2 final : public BaseDequantizer<block_q6_K> {
+ DequantizerQ6K_AVX2(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}
+ inline void prepare(int i, int j) {
+ auto lbits1 = _mm256_loadu_si256((const __m256i *)x[i].ql + 2*j+0);
+ auto lbits2 = _mm256_loadu_si256((const __m256i *)x[i].ql + 2*j+1);
+ auto hbits = _mm256_loadu_si256((const __m256i *)x[i].qh + j);
+ bits.values[0] = _mm256_or_si256(_mm256_and_si256(lbits1, bits.ml), _mm256_and_si256(_mm256_slli_epi16(hbits, 4), mh));
+ bits.values[1] = _mm256_or_si256(_mm256_and_si256(lbits2, bits.ml), _mm256_and_si256(_mm256_slli_epi16(hbits, 2), mh));
+ bits.values[2] = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(lbits1, 4), bits.ml), _mm256_and_si256(hbits, mh));
+ bits.values[3] = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(lbits2, 4), bits.ml), _mm256_and_si256(_mm256_srli_epi16(hbits, 2), mh));
+ }
+ inline void prepare_signed(int i, int j, __m256i * us) {
+ prepare(i, j);
+ for (int k = 0; k < 4; ++k) {
+ bits.values[k] = _mm256_add_epi8(bits.values[k], _mm256_set1_epi8(-32));
+ us[k] = _mm256_sign_epi8(bits.values[k], bits.values[k]);
+ }
+ }
+ inline __m256i make_scales(int i) const {
+ return _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)x[i].scales));
+ }
+
+ const __m256i mh = _mm256_set1_epi8(0x30);
+ Q4Bits_AVX2 bits;
+};
+
+struct SimpleBits {
+ __m256i values[4];
+};
+
+struct DequantizerQ3K_AVX2 final : public BaseDequantizer<block_q3_K> {
+ DequantizerQ3K_AVX2(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}
+
+ inline void prepare(int i, int j) {
+ hbits = j == 0 ? _mm256_loadu_si256((const __m256i *)x[i].hmask) : _mm256_srli_epi16(hbits, 4);
+ auto q2bits = _mm256_loadu_si256((const __m256i *)x[i].qs + j);
+ bits.values[0] = _mm256_and_si256(q2bits, ml);
+ bits.values[1] = _mm256_and_si256(_mm256_srli_epi16(q2bits, 2), ml);
+ bits.values[2] = _mm256_and_si256(_mm256_srli_epi16(q2bits, 4), ml);
+ bits.values[3] = _mm256_and_si256(_mm256_srli_epi16(q2bits, 6), ml);
+ bits.values[0] = _mm256_or_si256(bits.values[0], _mm256_and_si256(_mm256_slli_epi16(hbits, 2), mh));
+ bits.values[1] = _mm256_or_si256(bits.values[1], _mm256_and_si256(_mm256_slli_epi16(hbits, 1), mh));
+ bits.values[2] = _mm256_or_si256(bits.values[2], _mm256_and_si256(hbits, mh));
+ bits.values[3] = _mm256_or_si256(bits.values[3], _mm256_and_si256(_mm256_srli_epi16(hbits, 1), mh));
+ //bits.values[0] = _mm256_sub_epi8(bits.values[0], _mm256_xor_si256(mh, _mm256_and_si256(_mm256_slli_epi16(hbits, 2), mh)));
+ //bits.values[1] = _mm256_sub_epi8(bits.values[1], _mm256_xor_si256(mh, _mm256_and_si256(_mm256_slli_epi16(hbits, 1), mh)));
+ //bits.values[2] = _mm256_sub_epi8(bits.values[2], _mm256_xor_si256(mh, _mm256_and_si256(hbits, mh)));
+ //bits.values[3] = _mm256_sub_epi8(bits.values[3], _mm256_xor_si256(mh, _mm256_and_si256(_mm256_srli_epi16(hbits, 1), mh)));
+ }
+ inline void prepare_signed(int i, int j, __m256i * us) {
+ prepare(i, j);
+ for (int k = 0; k < 4; ++k) {
+ bits.values[k] = _mm256_sub_epi8(bits.values[k], mh);
+ us[k] = _mm256_sign_epi8(bits.values[k], bits.values[k]);
+ }
+ //for (int k = 0; k < 4; ++k) {
+ // us[k] = _mm256_sign_epi8(bits.values[k], bits.values[k]);
+ //}
+ }
+ inline __m256i make_scales(int i) const {
+ return _mm256_cvtepi8_epi16(sc3.make_scales((const uint16_t *)x[i].scales));
+ }
+
+ ScaleQ3 sc3;
+
+ __m256i hbits;
+ SimpleBits bits;
+ const __m256i ml = _mm256_set1_epi8(3);
+ const __m256i mh = _mm256_set1_epi8(4);
+};
+
+template <typename Dequantizer, int nrc_y>
+static void mul_mat_qY_K_q8_2_X4_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ assert(n % QK_K == 0);
+ const int nb = n / QK_K;
+
+ Q8<nrc_y, block_q8_2_x4> q8(info);
+
+ Dequantizer deq(vx, bx);
+
+ __m256 accd[nrc_y];
+ __m256 scales[2];
+ float d8[8*nrc_y];
+ __m256i us[4];
+
+ uint8_t k_shuff[32] = {0, 1, 4, 5, 8, 9, 12, 13, 2, 3, 6, 7, 10, 11, 14, 15, 0, 1, 4, 5, 8, 9, 12, 13, 2, 3, 6, 7, 10, 11, 14, 15};
+ auto shuff = _mm256_loadu_si256((const __m256i *)k_shuff);
+
+ for (int ix = 0; ix < nrc_x; ++ix) {
+
+ for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm256_setzero_ps();
+
+ deq.new_row(ix);
+
+ for (int i = 0; i < nb; ++i) {
+
+ deq.d = GGML_FP16_TO_FP32(deq.x[i].d);
+ auto vd = _mm256_set1_ps(deq.d);
+ auto sc16 = _mm256_shuffle_epi8(deq.make_scales(i), shuff);
+ scales[0] = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_castsi256_si128(sc16))));
+ scales[1] = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(sc16, 1))));
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto d4_1 = _mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)(q8.y[iy][2*i+0].d)));
+ auto d4_2 = _mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)(q8.y[iy][2*i+1].d)));
+ auto dy = _mm256_castsi256_ps(_mm256_slli_epi32(MM256_SET_M128I(d4_2, d4_1), 16));
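+                // dy holds the eight 16-bit block scales of the two q8_2_x4 blocks, widened to fp32
+                // by shifting them into the high half of each dword. For nrc_y == 1 they are folded
+                // into scales[] right away; otherwise they are stashed in d8[] and applied per
+                // 128-quant chunk in the inner loop.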
+ if constexpr (nrc_y == 1) {
+ auto dyh = _mm256_extractf128_ps(dy, 1);
+ scales[0] = _mm256_mul_ps(scales[0], _mm256_set_m128(_mm256_castps256_ps128(dy), _mm256_castps256_ps128(dy)));
+ scales[1] = _mm256_mul_ps(scales[1], _mm256_set_m128(dyh, dyh));
+ } else {
+ _mm256_storeu_ps(d8 + 8*iy, dy);
+ }
+ }
+
+ for (int j = 0; j < QK_K/128; ++j) {
+
+ deq.prepare_signed(i, j, us);
+
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto qs = q8.y[iy][2*i+j].qs;
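+                    // prepare_signed() left us[k] = |x_k| and bits.values[k] = x_k. maddubs/dpbusd
+                    // need an unsigned first operand, so we use |x_k| and transfer the sign of x_k
+                    // onto the q8 values with _mm256_sign_epi8, leaving each product x_k*y_k unchanged.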
+#ifdef HAVE_FANCY_SIMD
+ // 0...31
+ auto sumi1 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), us[0], _mm256_sign_epi8(_mm256_loadu_si256((const __m256i*)qs+0), deq.bits.values[0]));
+ // 32...63
+ auto sumi2 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), us[1], _mm256_sign_epi8(_mm256_loadu_si256((const __m256i*)qs+1), deq.bits.values[1]));
+ // 64...95
+ auto sumi3 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), us[2], _mm256_sign_epi8(_mm256_loadu_si256((const __m256i*)qs+2), deq.bits.values[2]));
+                    // 96...127
+ auto sumi4 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), us[3], _mm256_sign_epi8(_mm256_loadu_si256((const __m256i*)qs+3), deq.bits.values[3]));
+                    // 0...3, 32...35, 4...7, 36...39, 16...19, 48...51, 20...23, 52...55 +
+                    // 8...11, 40...43, 12...15, 44...47, 24...27, 56...59, 28...31, 60...63
+ // b0 b2 b0 b2 b1 b3 b1 b3
+ sumi1 = _mm256_add_epi32(_mm256_unpacklo_epi32(sumi1, sumi2), _mm256_unpackhi_epi32(sumi1, sumi2));
+ // same as above + 64, so
+ // b4 b6, b4 b6 b5 b7 b5 b7
+ sumi3 = _mm256_add_epi32(_mm256_unpacklo_epi32(sumi3, sumi4), _mm256_unpackhi_epi32(sumi3, sumi4));
+ // b0 b2 b4 b6 b1 b3 b5 b7 +
+ // b0 b2 b4 b6 b1 b3 b5 b7
+ sumi1 = _mm256_add_epi32(_mm256_unpacklo_epi64(sumi1, sumi3), _mm256_unpackhi_epi64(sumi1, sumi3));
+#else
+ auto sumi1 = _mm256_maddubs_epi16(us[0], _mm256_sign_epi8(_mm256_loadu_si256((const __m256i*)qs+0), deq.bits.values[0]));
+ auto sumi2 = _mm256_maddubs_epi16(us[1], _mm256_sign_epi8(_mm256_loadu_si256((const __m256i*)qs+1), deq.bits.values[1]));
+ auto sumi3 = _mm256_maddubs_epi16(us[2], _mm256_sign_epi8(_mm256_loadu_si256((const __m256i*)qs+2), deq.bits.values[2]));
+ auto sumi4 = _mm256_maddubs_epi16(us[3], _mm256_sign_epi8(_mm256_loadu_si256((const __m256i*)qs+3), deq.bits.values[3]));
+ sumi1 = _mm256_add_epi16(_mm256_unpacklo_epi32(sumi1, sumi2), _mm256_unpackhi_epi32(sumi1, sumi2));
+ sumi3 = _mm256_add_epi16(_mm256_unpacklo_epi32(sumi3, sumi4), _mm256_unpackhi_epi32(sumi3, sumi4));
+ sumi1 = _mm256_add_epi16(_mm256_unpacklo_epi64(sumi1, sumi3), _mm256_unpackhi_epi64(sumi1, sumi3));
+ sumi1 = _mm256_madd_epi16(_mm256_set1_epi16(1), sumi1);
+#endif
+ if constexpr (nrc_y > 1) {
+ auto dy4 = _mm_loadu_ps(d8 + 8*iy + 4*j);
+ auto d4d8 = _mm256_mul_ps(scales[j], _mm256_set_m128(dy4, dy4));
+ accd[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi1), accd[iy]);
+ } else {
+ accd[iy] = _mm256_fmadd_ps(scales[j], _mm256_cvtepi32_ps(sumi1), accd[iy]);
+ }
+ }
+
+ }
+
+ }
+
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ info.store(ix, iy, hsum_float_8(accd[iy]));
+ }
+
+ }
+}
+
template <int nrc_y>
static void mul_mat_iq4_xs_r8_q8_k_avx2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(nrc_x%8 == 0);
@@ -1669,14 +1839,13 @@ static void mul_mat_q8_k_r8_q8_k(int n, const void * vx, size_t bx, const DataIn
}
}
#ifdef HAVE_FANCY_SIMD
- auto m4 = _mm256_mul_ps(d4, _mm256_set1_ps(-128.f));
+ auto m4 = _mm256_mul_ps(d4, _mm256_set1_ps(-127.f));
#endif
for (int iy = 0; iy < nrc_y; ++iy) {
auto d4y = _mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(iy, ibl)));
acc[iy] = _mm256_fmadd_ps(d4y, _mm256_cvtepi32_ps(isum[iy]), acc[iy]);
#ifdef HAVE_FANCY_SIMD
- auto bsums = (const float *)q8.y[iy][ibl].bsums;
- acc[iy] = _mm256_fmadd_ps(m4, _mm256_set1_ps(bsums[0]), acc[iy]);
+ acc[iy] = _mm256_fmadd_ps(m4, _mm256_set1_ps(q8.y[iy][ibl].sum), acc[iy]);
#endif
isum[iy] = _mm256_setzero_si256();
}
@@ -1982,6 +2151,284 @@ void iqk_convert_q5_k_q8_1_r8(int n, const void * vx, size_t bx, void * vy, int
}
}
+void iqk_convert_q6_k_q8_0_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) {
+ GGML_ASSERT(n%QK_K == 0);
+ GGML_ASSERT(nrc_x%8 == 0);
+
+ int nb = n/QK_K;
+
+ const block_q6_K * x8[8];
+
+ block_q8_0_r8 * y = (block_q8_0_r8 *)vy;
+
+ float all_s[64];
+ uint32_t block[8];
+ __m256i values[8];
+
+ auto ml = _mm256_set1_epi8(0x0f);
+ auto mh = _mm256_set1_epi8(0x30);
+
+ for (int ix = 0; ix < nrc_x; ix += 8) {
+ for (int k = 0; k < 8; ++k) x8[k] = (const block_q6_K *)((const char *)vx + (ix + k)*bx);
+ for (int i = 0; i < nb; ++i) {
+ for (int k = 0; k < 8; ++k) {
+ float d = GGML_FP16_TO_FP32(x8[k][i].d);
+ auto ql = x8[k][i].ql;
+ auto qh = x8[k][i].qh;
+ for (int i128 = 0; i128 < 2; ++i128) {
+ auto lbits1 = _mm256_loadu_si256((const __m256i *)ql + 2*i128 + 0);
+ auto lbits2 = _mm256_loadu_si256((const __m256i *)ql + 2*i128 + 1);
+ auto hbits = _mm256_loadu_si256((const __m256i *)qh + i128);
+ values[4*i128+0] = _mm256_or_si256(_mm256_and_si256(lbits1, ml), _mm256_and_si256(_mm256_slli_epi16(hbits, 4), mh));
+ values[4*i128+1] = _mm256_or_si256(_mm256_and_si256(lbits2, ml), _mm256_and_si256(_mm256_slli_epi16(hbits, 2), mh));
+ values[4*i128+2] = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(lbits1, 4), ml), _mm256_and_si256(hbits, mh));
+ values[4*i128+3] = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(lbits2, 4), ml), _mm256_and_si256(_mm256_srli_epi16(hbits, 2), mh));
+ }
+ for (int ib32 = 0; ib32 < 8; ++ib32) {
+                    // Each group of 32 quants spans two sub-blocks of 16 with independent scales.
+                    // We multiply the quants by their scales, find the maximum absolute value, and
+                    // requantize to 8 bits with a single per-32 block scale.
+ auto q8 = _mm256_add_epi8(values[ib32], _mm256_set1_epi8(-32));
+ auto q16_l = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(q8));
+ auto q16_h = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(q8, 1));
+ q16_l = _mm256_mullo_epi16(q16_l, _mm256_set1_epi16(x8[k][i].scales[2*ib32+0]));
+ q16_h = _mm256_mullo_epi16(q16_h, _mm256_set1_epi16(x8[k][i].scales[2*ib32+1]));
+ auto abs_q16_l = _mm256_sign_epi16(q16_l, q16_l);
+ auto abs_q16_h = _mm256_sign_epi16(q16_h, q16_h);
+ auto max_q16 = _mm256_max_epi16(abs_q16_l, abs_q16_h);
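+                    // reduce the 16 scaled |q| values to a single scalar maximum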
+ auto max_q32 = _mm256_cvtepi16_epi32(_mm_max_epi16(_mm256_castsi256_si128(max_q16), _mm256_extracti128_si256(max_q16, 1)));
+ auto imax4 = _mm_max_epi32(_mm256_castsi256_si128(max_q32), _mm256_extracti128_si256(max_q32, 1));
+ auto max4 = _mm_cvtepi32_ps(imax4);
+ max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
+ max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
+ float max = _mm_cvtss_f32(max4) / 127;
+ all_s[8*ib32+k] = d*max;
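+                    // d*max is the per-32 block scale; dividing the int16 products by max below
+                    // maps them into [-127, 127] before packing to int8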
+ if (max > 1e-9f) {
+ auto scale = _mm256_set1_ps(1/max);
+ auto i0 = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(q16_l));
+ auto i1 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(q16_l, 1));
+ auto i2 = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(q16_h));
+ auto i3 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(q16_h, 1));
+ i0 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(scale, _mm256_cvtepi32_ps(i0)), _MM_ROUND_NEAREST));
+ i1 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(scale, _mm256_cvtepi32_ps(i1)), _MM_ROUND_NEAREST));
+ i2 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(scale, _mm256_cvtepi32_ps(i2)), _MM_ROUND_NEAREST));
+ i3 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(scale, _mm256_cvtepi32_ps(i3)), _MM_ROUND_NEAREST));
+ i0 = _mm256_packs_epi32(i0, i1);
+ i2 = _mm256_packs_epi32(i2, i3);
+ i0 = _mm256_packs_epi16(i0, i2);
+ i0 = _mm256_permutevar8x32_epi32(i0, _mm256_setr_epi32(0, 4, 1, 5, 2, 6, 3, 7));
+ _mm256_storeu_si256((__m256i *)block, i0);
+ } else {
+ _mm256_storeu_si256((__m256i *)block, _mm256_setzero_si256());
+ }
+ auto qs = (uint32_t *)y[ib32].qs;
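+                    // scatter this row's 32 quants into the row-interleaved q8_0_r8 layout:
+                    // 4-byte chunk l of row k lands at dword 8*l + k, chunks 4..7 go 32 dwords later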
+ for (int l = 0; l < 4; ++l) {
+ qs[8*l + k + 0] = block[l + 0];
+ qs[8*l + k + 32] = block[l + 4];
+ }
+ }
+ }
+ for (int ib32 = 0; ib32 < 8; ++ib32) {
+ _mm_storeu_si128((__m128i *)y[ib32].d, _mm256_cvtps_ph(_mm256_loadu_ps(all_s + 8*ib32), _MM_FROUND_TO_NEAREST_INT));
+ }
+ y += QK_K/32;
+ }
+ }
+}
+
+void iqk_convert_q3_k_q8_0_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) {
+ GGML_ASSERT(n%QK_K == 0);
+ GGML_ASSERT(nrc_x%8 == 0);
+
+ int nb = n/QK_K;
+
+ const block_q3_K * x8[8];
+
+ block_q8_0_r8 * y = (block_q8_0_r8 *)vy;
+
+ float all_s[64];
+ uint32_t block[8];
+ __m256i values[8];
+
+ ScaleQ3 sc3;
+ auto ml = _mm256_set1_epi8(0x03);
+ auto mh = _mm256_set1_epi8(0x04);
+
+ union { __m256i vec; int16_t val[16]; } helper;
+
+ for (int ix = 0; ix < nrc_x; ix += 8) {
+ for (int k = 0; k < 8; ++k) x8[k] = (const block_q3_K *)((const char *)vx + (ix + k)*bx);
+ for (int i = 0; i < nb; ++i) {
+ for (int k = 0; k < 8; ++k) {
+ float d = GGML_FP16_TO_FP32(x8[k][i].d);
+ auto hbits = _mm256_loadu_si256((const __m256i *)x8[k][i].hmask);
+ for (int i128 = 0; i128 < 2; ++i128) {
+ auto q2bits = _mm256_loadu_si256((const __m256i *)x8[k][i].qs + i128);
+ values[4*i128+0] = _mm256_and_si256(q2bits, ml);
+ values[4*i128+1] = _mm256_and_si256(_mm256_srli_epi16(q2bits, 2), ml);
+ values[4*i128+2] = _mm256_and_si256(_mm256_srli_epi16(q2bits, 4), ml);
+ values[4*i128+3] = _mm256_and_si256(_mm256_srli_epi16(q2bits, 6), ml);
+ values[4*i128+0] = _mm256_or_si256(values[4*i128+0], _mm256_and_si256(_mm256_slli_epi16(hbits, 2), mh));
+ values[4*i128+1] = _mm256_or_si256(values[4*i128+1], _mm256_and_si256(_mm256_slli_epi16(hbits, 1), mh));
+ values[4*i128+2] = _mm256_or_si256(values[4*i128+2], _mm256_and_si256(hbits, mh));
+ values[4*i128+3] = _mm256_or_si256(values[4*i128+3], _mm256_and_si256(_mm256_srli_epi16(hbits, 1), mh));
+ values[4*i128+0] = _mm256_sub_epi8(values[4*i128+0], mh);
+ values[4*i128+1] = _mm256_sub_epi8(values[4*i128+1], mh);
+ values[4*i128+2] = _mm256_sub_epi8(values[4*i128+2], mh);
+ values[4*i128+3] = _mm256_sub_epi8(values[4*i128+3], mh);
+ hbits = _mm256_srli_epi16(hbits, 4);
+ }
+ helper.vec = _mm256_cvtepi8_epi16(sc3.make_scales((const uint16_t *)x8[k][i].scales));
+ for (int ib32 = 0; ib32 < 8; ++ib32) {
+ auto q16_l = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(values[ib32]));
+ auto q16_h = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(values[ib32], 1));
+ q16_l = _mm256_mullo_epi16(q16_l, _mm256_set1_epi16(helper.val[2*ib32+0]));
+ q16_h = _mm256_mullo_epi16(q16_h, _mm256_set1_epi16(helper.val[2*ib32+1]));
+ auto abs_q16_l = _mm256_sign_epi16(q16_l, q16_l);
+ auto abs_q16_h = _mm256_sign_epi16(q16_h, q16_h);
+ auto max_q16 = _mm256_max_epi16(abs_q16_l, abs_q16_h);
+ auto max_q32 = _mm256_cvtepi16_epi32(_mm_max_epi16(_mm256_castsi256_si128(max_q16), _mm256_extracti128_si256(max_q16, 1)));
+ auto imax4 = _mm_max_epi32(_mm256_castsi256_si128(max_q32), _mm256_extracti128_si256(max_q32, 1));
+ auto max4 = _mm_cvtepi32_ps(imax4);
+ max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
+ max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
+ float max = _mm_cvtss_f32(max4) / 127;
+ all_s[8*ib32+k] = d*max;
+ if (max > 1e-9f) {
+ auto scale = _mm256_set1_ps(1/max);
+ auto i0 = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(q16_l));
+ auto i1 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(q16_l, 1));
+ auto i2 = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(q16_h));
+ auto i3 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(q16_h, 1));
+ i0 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(scale, _mm256_cvtepi32_ps(i0)), _MM_ROUND_NEAREST));
+ i1 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(scale, _mm256_cvtepi32_ps(i1)), _MM_ROUND_NEAREST));
+ i2 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(scale, _mm256_cvtepi32_ps(i2)), _MM_ROUND_NEAREST));
+ i3 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(scale, _mm256_cvtepi32_ps(i3)), _MM_ROUND_NEAREST));
+ i0 = _mm256_packs_epi32(i0, i1);
+ i2 = _mm256_packs_epi32(i2, i3);
+ i0 = _mm256_packs_epi16(i0, i2);
+ i0 = _mm256_permutevar8x32_epi32(i0, _mm256_setr_epi32(0, 4, 1, 5, 2, 6, 3, 7));
+ _mm256_storeu_si256((__m256i *)block, i0);
+ } else {
+ _mm256_storeu_si256((__m256i *)block, _mm256_setzero_si256());
+ }
+ auto qs = (uint32_t *)y[ib32].qs;
+ for (int l = 0; l < 4; ++l) {
+ qs[8*l + k + 0] = block[l + 0];
+ qs[8*l + k + 32] = block[l + 4];
+ }
+ }
+ }
+ for (int ib32 = 0; ib32 < 8; ++ib32) {
+ _mm_storeu_si128((__m128i *)y[ib32].d, _mm256_cvtps_ph(_mm256_loadu_ps(all_s + 8*ib32), _MM_FROUND_TO_NEAREST_INT));
+ }
+ y += QK_K/32;
+ }
+ }
+}
+
+void iqk_convert_q3_k_q8_k_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) {
+ GGML_ASSERT(n%QK_K == 0);
+ GGML_ASSERT(nrc_x%8 == 0);
+
+ int nb = n/QK_K;
+
+ const block_q3_K * x8[8];
+
+ block_q8_k_r8 * y = (block_q8_k_r8 *)vy;
+
+ uint32_t block[8];
+ __m256i values[8];
+
+ ScaleQ3 sc3;
+ auto ml = _mm256_set1_epi8(0x03);
+ auto mh = _mm256_set1_epi8(0x04);
+
+ union { __m256i vec; int16_t val[16]; } helper;
+
+ for (int ix = 0; ix < nrc_x; ix += 8) {
+ for (int k = 0; k < 8; ++k) x8[k] = (const block_q3_K *)((const char *)vx + (ix + k)*bx);
+ for (int i = 0; i < nb; ++i) {
+ for (int k = 0; k < 8; ++k) {
+ float d = GGML_FP16_TO_FP32(x8[k][i].d);
+ auto hbits = _mm256_loadu_si256((const __m256i *)x8[k][i].hmask);
+ helper.vec = _mm256_cvtepi8_epi16(sc3.make_scales((const uint16_t *)x8[k][i].scales));
+ auto max_i16 = _mm256_setzero_si256();
+ for (int i128 = 0; i128 < 2; ++i128) {
+ auto q2bits = _mm256_loadu_si256((const __m256i *)x8[k][i].qs + i128);
+ values[4*i128+0] = _mm256_and_si256(q2bits, ml);
+ values[4*i128+1] = _mm256_and_si256(_mm256_srli_epi16(q2bits, 2), ml);
+ values[4*i128+2] = _mm256_and_si256(_mm256_srli_epi16(q2bits, 4), ml);
+ values[4*i128+3] = _mm256_and_si256(_mm256_srli_epi16(q2bits, 6), ml);
+ values[4*i128+0] = _mm256_or_si256(values[4*i128+0], _mm256_and_si256(_mm256_slli_epi16(hbits, 2), mh));
+ values[4*i128+1] = _mm256_or_si256(values[4*i128+1], _mm256_and_si256(_mm256_slli_epi16(hbits, 1), mh));
+ values[4*i128+2] = _mm256_or_si256(values[4*i128+2], _mm256_and_si256(hbits, mh));
+ values[4*i128+3] = _mm256_or_si256(values[4*i128+3], _mm256_and_si256(_mm256_srli_epi16(hbits, 1), mh));
+ values[4*i128+0] = _mm256_sub_epi8(values[4*i128+0], mh);
+ values[4*i128+1] = _mm256_sub_epi8(values[4*i128+1], mh);
+ values[4*i128+2] = _mm256_sub_epi8(values[4*i128+2], mh);
+ values[4*i128+3] = _mm256_sub_epi8(values[4*i128+3], mh);
+ hbits = _mm256_srli_epi16(hbits, 4);
+
+ for (int l = 0; l < 4; ++l) {
+ auto q16_l = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(values[4*i128+l]));
+ auto q16_h = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(values[4*i128+l], 1));
+ q16_l = _mm256_mullo_epi16(_mm256_set1_epi16(helper.val[8*i128+2*l+0]), q16_l);
+ q16_h = _mm256_mullo_epi16(_mm256_set1_epi16(helper.val[8*i128+2*l+1]), q16_h);
+ max_i16 = _mm256_max_epi16(max_i16, _mm256_sign_epi16(q16_l, q16_l));
+ max_i16 = _mm256_max_epi16(max_i16, _mm256_sign_epi16(q16_h, q16_h));
+ }
+ }
+ auto max_q32 = _mm256_cvtepi16_epi32(_mm_max_epi16(_mm256_castsi256_si128(max_i16), _mm256_extracti128_si256(max_i16, 1)));
+ auto imax4 = _mm_max_epi32(_mm256_castsi256_si128(max_q32), _mm256_extracti128_si256(max_q32, 1));
+ auto max4 = _mm_cvtepi32_ps(imax4);
+ max4 = _mm_max_ps(max4, _mm_movehl_ps(max4, max4));
+ max4 = _mm_max_ss(max4, _mm_movehdup_ps(max4));
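+                // if max|scale*q| <= 127 the int16 products already fit into int8, so we skip the
+                // rescaling pass and just narrow them (dnew stays 1 and d is unchanged)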
+ bool needs_scaling = true;
+ float dnew = _mm_cvtss_f32(max4) / 127;
+ if (dnew < 1.f) {
+ dnew = 1.f; needs_scaling = false;
+ }
+ d *= dnew;
+ y[i].d[k] = GGML_FP32_TO_FP16(d);
+ auto scale = _mm256_set1_ps(std::abs(dnew) > 1e-9f ? 1/dnew : 0.f);
+ for (int ib32 = 0; ib32 < 8; ++ib32) {
+ auto q16_l = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(values[ib32]));
+ auto q16_h = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(values[ib32], 1));
+ q16_l = _mm256_mullo_epi16(q16_l, _mm256_set1_epi16(helper.val[2*ib32+0]));
+ q16_h = _mm256_mullo_epi16(q16_h, _mm256_set1_epi16(helper.val[2*ib32+1]));
+ if (needs_scaling) {
+ auto i0 = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(q16_l));
+ auto i1 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(q16_l, 1));
+ auto i2 = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(q16_h));
+ auto i3 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(q16_h, 1));
+ i0 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(scale, _mm256_cvtepi32_ps(i0)), _MM_ROUND_NEAREST));
+ i1 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(scale, _mm256_cvtepi32_ps(i1)), _MM_ROUND_NEAREST));
+ i2 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(scale, _mm256_cvtepi32_ps(i2)), _MM_ROUND_NEAREST));
+ i3 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(scale, _mm256_cvtepi32_ps(i3)), _MM_ROUND_NEAREST));
+ i0 = _mm256_packs_epi32(i0, i1);
+ i2 = _mm256_packs_epi32(i2, i3);
+ i0 = _mm256_packs_epi16(i0, i2);
+ i0 = _mm256_permutevar8x32_epi32(i0, _mm256_setr_epi32(0, 4, 1, 5, 2, 6, 3, 7));
+ _mm256_storeu_si256((__m256i *)block, i0);
+ } else {
+                        // _mm256_packs_epi16 interleaves the 128-bit lanes, giving byte order
+                        // 0...7, 16...23, 8...15, 24...31; the epi64 unpacks below restore 0...15 and 16...31
+ auto i0 = _mm256_packs_epi16(q16_l, q16_h);
+ auto i0_l = _mm256_castsi256_si128(i0);
+ auto i0_h = _mm256_extracti128_si256(i0, 1);
+ _mm_storeu_si128((__m128i *)block+0, _mm_unpacklo_epi64(i0_l, i0_h));
+ _mm_storeu_si128((__m128i *)block+1, _mm_unpackhi_epi64(i0_l, i0_h));
+ }
+ auto qs = (uint32_t *)y[i].qs + 64*ib32;
+ for (int l = 0; l < 8; ++l) {
+ qs[8*l + k] = block[l];
+ }
+ }
+ }
+ }
+ y += nb;
+ }
+}
+
} // namespace
@@ -1989,9 +2436,12 @@ bool iqk_set_kernels_kquants(int ne00, int typeA, int typeB, std::array<mul_mat_
auto etypeA = ggml_type(typeA);
auto expected_type_B = etypeA == GGML_TYPE_IQ4_XS_R8 || etypeA == GGML_TYPE_Q4_K_R4 || etypeA == GGML_TYPE_Q5_K_R4 ? GGML_TYPE_Q8_K32
- : etypeA == GGML_TYPE_Q8_K_R8 ? GGML_TYPE_Q8_KR8
+ //: etypeA == GGML_TYPE_Q8_K_R8 ? GGML_TYPE_Q8_KR8
: etypeA == GGML_TYPE_Q8_KV || etypeA == GGML_TYPE_Q8_KV_R8 ? GGML_TYPE_Q8_KV
- : etypeA == GGML_TYPE_Q4_K || etypeA == GGML_TYPE_Q5_K ? GGML_TYPE_Q8_2_X4
+ : etypeA == GGML_TYPE_Q4_K || etypeA == GGML_TYPE_Q5_K ||
+ etypeA == GGML_TYPE_Q6_K ? GGML_TYPE_Q8_2_X4
+ //etypeA == GGML_TYPE_Q6_K || etypeA == GGML_TYPE_Q3_K ? GGML_TYPE_Q8_2_X4
+ //: etypeA == GGML_TYPE_Q4_K || etypeA == GGML_TYPE_Q5_K ? GGML_TYPE_Q8_2_X4
: GGML_TYPE_Q8_K;
if (ne00%QK_K != 0 || ggml_type(typeB) != expected_type_B) {
@@ -2006,6 +2456,7 @@ bool iqk_set_kernels_kquants(int ne00, int typeA, int typeB, std::array<mul_mat_
break;
case GGML_TYPE_Q3_K:
set_functions<DequantizerQ3K>(kernels);
+ //IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qY_K_q8_2_X4_T, DequantizerQ3K_AVX2, kernels);
break;
case GGML_TYPE_Q4_K:
IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_K_q8_2_X4_T, DequantizerQ4K_AVX2, kernels);
@@ -2016,7 +2467,8 @@ bool iqk_set_kernels_kquants(int ne00, int typeA, int typeB, std::array<mul_mat_
//set_functions<DequantizerQ5K>(kernels);
break;
case GGML_TYPE_Q6_K:
- set_functions<DequantizerQ6K>(kernels);
+ IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qY_K_q8_2_X4_T, DequantizerQ6K_AVX2, kernels);
+ //set_functions<DequantizerQ6K>(kernels);
break;
case GGML_TYPE_IQ4_XS:
set_functions<DequantizerIQ4XS>(kernels);
@@ -2064,8 +2516,10 @@ bool iqk_set_kernels_kquants(int ne00, int typeA, int typeB, std::array<mul_mat_
bool iqk_convert_kquants_q8X_r8(int type, int n, const void * vx, size_t bx, void * vy, int nrc_x) {
switch (ggml_type(type)) {
+ case GGML_TYPE_Q3_K: iqk_convert_q3_k_q8_k_r8(n, vx, bx, vy, nrc_x); break;
case GGML_TYPE_Q4_K: iqk_convert_q4_k_q8_1_r8(n, vx, bx, vy, nrc_x); break;
case GGML_TYPE_Q5_K: iqk_convert_q5_k_q8_1_r8(n, vx, bx, vy, nrc_x); break;
+ case GGML_TYPE_Q6_K: iqk_convert_q6_k_q8_0_r8(n, vx, bx, vy, nrc_x); break;
default: return false;
}
return true;
@@ -3075,7 +3529,7 @@ bool iqk_set_kernels_kquants(int ne00, int typeA, int typeB, std::array<mul_mat_
auto etypeA = ggml_type(typeA);
auto expected_type_B = etypeA == GGML_TYPE_IQ4_XS_R8 || etypeA == GGML_TYPE_Q4_K_R4 || etypeA == GGML_TYPE_Q5_K_R4 ? GGML_TYPE_Q8_K32
- : etypeA == GGML_TYPE_Q8_K_R8 ? GGML_TYPE_Q8_KR8
+ //: etypeA == GGML_TYPE_Q8_K_R8 ? GGML_TYPE_Q8_KR8
: etypeA == GGML_TYPE_Q8_KV || etypeA == GGML_TYPE_Q8_KV_R8 ? GGML_TYPE_Q8_KV
: GGML_TYPE_Q8_K;
diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp
index b23dc6d4..0b29a572 100644
--- a/ggml/src/iqk/iqk_mul_mat.cpp
+++ b/ggml/src/iqk/iqk_mul_mat.cpp
@@ -239,12 +239,17 @@ struct MulMat {
case GGML_TYPE_IQ2_KT : return nrc_y >= 32 ? GGML_TYPE_F32 : type;
case GGML_TYPE_IQ3_KT : return nrc_y >= 32 ? GGML_TYPE_F32 : type;
case GGML_TYPE_IQ4_KT : return nrc_y >= 32 ? GGML_TYPE_F32 : type;
- case GGML_TYPE_IQ2_XXS: return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
- case GGML_TYPE_IQ3_XXS: return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
- case GGML_TYPE_IQ3_S : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
- case GGML_TYPE_IQ1_S : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
+ case GGML_TYPE_IQ2_XXS: return nrc_y >= 32 ? GGML_TYPE_Q8_K_R8 : type;
+ case GGML_TYPE_IQ2_XS : return nrc_y >= 32 ? GGML_TYPE_Q8_K_R8 : type;
+ case GGML_TYPE_IQ2_S : return nrc_y >= 16 ? GGML_TYPE_Q8_K_R8 : type;
+ case GGML_TYPE_IQ3_XXS: return nrc_y >= 32 ? GGML_TYPE_Q8_K_R8 : type;
+ case GGML_TYPE_IQ3_S : return nrc_y >= 32 ? GGML_TYPE_Q8_K_R8 : type;
+ case GGML_TYPE_IQ1_S : return nrc_y >= 32 ? GGML_TYPE_Q8_K_R8 : type;
+ case GGML_TYPE_IQ1_M : return nrc_y >= 32 ? GGML_TYPE_Q8_K_R8 : type;
+ case GGML_TYPE_Q3_K : return nrc_y >= 32 ? GGML_TYPE_Q8_K_R8 : type;
case GGML_TYPE_Q4_K : return nrc_y >= 32 ? GGML_TYPE_Q8_1 : type;
case GGML_TYPE_Q5_K : return nrc_y >= 32 ? GGML_TYPE_Q8_1 : type;
+ case GGML_TYPE_Q6_K : return nrc_y >= 64 ? GGML_TYPE_Q8_0_R8 : type;
default: break;
}
#else
@@ -344,10 +349,10 @@ bool iqk_convert_repack(int typeA, int n, const void * vx, size_t bx, void * vy,
//case GGML_TYPE_BF16_R16:
// return iqk_set_kernels_float(ne00, typeA, typeB, mm.funcs);
//case GGML_TYPE_Q2_K:
- //case GGML_TYPE_Q3_K:
+ case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q5_K:
- //case GGML_TYPE_Q6_K:
+ case GGML_TYPE_Q6_K:
//case GGML_TYPE_IQ4_XS:
//case GGML_TYPE_Q2_K_R4:
//case GGML_TYPE_Q3_K_R4:
@@ -404,6 +409,7 @@ bool iqk_convert_repack(int typeA, int n, const void * vx, size_t bx, void * vy,
//case GGML_TYPE_IQ4_NL_R4:
// return iqk_set_kernels_legacy_quants(ne00, typeA, typeB, mm.funcs, mm.func16);
case GGML_TYPE_IQ1_S:
+ case GGML_TYPE_IQ1_M:
//case GGML_TYPE_IQ1_S_R4:
//case GGML_TYPE_IQ1_M_R4:
//case GGML_TYPE_IQ1_BN:
@@ -420,6 +426,10 @@ bool iqk_convert_repack(int typeA, int n, const void * vx, size_t bx, void * vy,
}
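+// Expose the repacking decision to callers: returns the ggml type that rows of `type` are
+// converted to for a batch of Ny right-hand-side rows, or `type` itself if no repacking applies.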
+extern "C" IQK_API int iqk_dequant_type(int type, int Ny) {
+ return MulMat::is_dequant_better(ggml_type(type), Ny);
+}
+
extern "C" IQK_API bool iqk_mul_mat(long Nx, long Ny, long ne00,
int typeA, const void * A, long strideA,
int typeB, const void * B, long strideB,
@@ -597,7 +607,12 @@ extern "C" IQK_API bool iqk_mul_mat_moe(long Nx, long Ny, long ne00, int ne11,
MulMat mm;
auto etypeA = ggml_type(typeA);
- if (auto dequant_type = MulMat::is_dequant_better(etypeA, Ny); dequant_type != etypeA) {
+ //auto etypeB = ggml_type(typeB);
+ auto dequant_type = MulMat::is_dequant_better(etypeA, Ny);
+ //if (etypeB != GGML_TYPE_F32) {
+ // if (ith == 0) printf("%s: typeA = %s, typeB = %s, dequant_type = %s\n", __func__, ggml_type_name(etypeA), ggml_type_name(etypeB), ggml_type_name(dequant_type));
+ //}
+ if (dequant_type != etypeA) {
if (!MulMat::prepare(dequant_type, typeB, ne00, mm, Ny)) {
return false;
}
@@ -612,9 +627,7 @@ extern "C" IQK_API bool iqk_mul_mat_moe(long Nx, long Ny, long ne00, int ne11,
first_x *= num_rows;
nrc_x *= num_rows;
- auto type_size = ggml_type_size(dequant_type);
-
- size_t row_size_qx = ne00*type_size;
+ size_t row_size_qx = ggml_row_size(dequant_type, ne00);
size_t row_size_qy = strideB;
DataInfo info{C + first_x, (const char *)B, nb1/sizeof(float), row_size_qy, 0, ne11, row_mapping, nb2/sizeof(float)};
@@ -680,9 +693,7 @@ extern "C" IQK_API bool iqk_moe_fused_up_gate(long Nx, long Ny, long ne00, int n
first_x *= num_rows;
nrc_x *= num_rows;
- auto type_size = ggml_type_size(dequant_type);
-
- size_t row_size_qx = ne00*type_size;
+ size_t row_size_qx = ggml_row_size(dequant_type, ne00);
size_t row_size_qy = strideB;
DataInfo info{C + first_x, (const char *)B, nb1/sizeof(float), row_size_qy, 0, ne11, row_mapping, nb2/sizeof(float)};
diff --git a/ggml/src/iqk/iqk_mul_mat.h b/ggml/src/iqk/iqk_mul_mat.h
index 6f44af52..87722f6f 100644
--- a/ggml/src/iqk/iqk_mul_mat.h
+++ b/ggml/src/iqk/iqk_mul_mat.h
@@ -34,6 +34,8 @@ IQK_API bool iqk_moe_fused_up_gate(long Nx, long Ny, long ne00, int ne11, int un
int typeB, const void * B, long strideB,
float * C, long nb1, long nb2, const void * vrow_mapping, int ith, int nth);
+IQK_API int iqk_dequant_type(int type, int Ny);
+
typedef void (*barrier_t) (void *);
IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float max_bias,
diff --git a/ggml/src/iqk/iqk_quantize.cpp b/ggml/src/iqk/iqk_quantize.cpp
index 2eb53d1c..9261d02e 100644
--- a/ggml/src/iqk/iqk_quantize.cpp
+++ b/ggml/src/iqk/iqk_quantize.cpp
@@ -2831,6 +2831,8 @@ void iqk_quantize_row_q8_K_T(const float * x, void * vy, int64_t k) {
const __m256 mul = _mm256_set1_ps( id );
xx = xb;
int8_t * q8 = y[i].qs;
+ int block_sum_i32 = 0;
+ float block_sum_f32 = 0;
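+        // running totals over the whole super-block; stored in the new y[i].sum field below so the
+        // dot-product kernels can apply their per-block offset correction with a single fmadd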
for (int ib = 0; ib < QK_K/32; ++ib) {
__m256 v0 = _mm256_mul_ps(mul, _mm256_loadu_ps(xx)); xx += 8;
__m256 v1 = _mm256_mul_ps(mul, _mm256_loadu_ps(xx)); xx += 8;
@@ -2844,13 +2846,15 @@ void iqk_quantize_row_q8_K_T(const float * x, void * vy, int64_t k) {
__m256i i1 = _mm256_cvtps_epi32(v1);
__m256i i2 = _mm256_cvtps_epi32(v2);
__m256i i3 = _mm256_cvtps_epi32(v3);
- if constexpr (q8_type > 0) {
+ if constexpr (q8_type == 1) {
int bsum = hsum_i32_8(_mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3)));
auto bs = (float *)y[i].bsums;
bs[ib] = d*bsum;
+ block_sum_f32 += bs[ib];
} else {
y[i].bsums[2*ib+0] = hsum_i32_8(_mm256_add_epi32(i0, i1));
y[i].bsums[2*ib+1] = hsum_i32_8(_mm256_add_epi32(i2, i3));
+ block_sum_i32 += y[i].bsums[2*ib+0] + y[i].bsums[2*ib+1];
}
i0 = _mm256_packs_epi32( i0, i1 );
i2 = _mm256_packs_epi32( i2, i3 );
@@ -2859,12 +2863,17 @@ void iqk_quantize_row_q8_K_T(const float * x, void * vy, int64_t k) {
_mm256_storeu_si256((__m256i *)q8, i0);
q8 += 32;
}
- if constexpr (q8_type == 2) {
- auto bs = (float *)y[i].bsums;
- float sum = 0;
- for (int ib = 0; ib < QK_K/32; ++ib) sum += bs[ib];
- bs[0] = sum;
+ if constexpr (q8_type == 1) {
+ y[i].sum = block_sum_f32;
+ } else {
+ y[i].sum = d*block_sum_i32;
}
+ //if constexpr (q8_type == 2) {
+ // auto bs = (float *)y[i].bsums;
+ // float sum = 0;
+ // for (int ib = 0; ib < QK_K/32; ++ib) sum += bs[ib];
+ // bs[0] = sum;
+ //}
}
#else
for (int i = 0; i < nb; i++) {
@@ -2890,9 +2899,9 @@ void iqk_quantize_row_q8_K_T(const float * x, void * vy, int64_t k) {
int v = nearest_int(iscale*x[j]);
y[i].qs[j] = MIN(127, v);
}
- if constexpr (q8_type > 0) {
+ float d = 1/iscale;
+ if constexpr (q8_type == 1) {
auto bs = (float *)y[i].bsums;
- float d = 1/iscale;
             float tot = 0;
for (int j = 0; j < QK_K/32; ++j) {
int sum = 0;
@@ -2902,19 +2911,20 @@ void iqk_quantize_row_q8_K_T(const float * x, void * vy, int64_t k) {
bs[j] = d*sum;
                 tot += bs[j];
}
- if constexpr (q8_type == 2) {
- bs[0] = sum;
- }
+            y[i].sum = tot;
} else {
+ int tot = 0;
for (int j = 0; j < QK_K/16; ++j) {
int sum = 0;
for (int ii = 0; ii < 16; ++ii) {
sum += y[i].qs[j*16 + ii];
}
y[i].bsums[j] = sum;
+ tot += sum;
}
+ y[i].sum = d*tot;
}
- y[i].d = 1/iscale;
+ y[i].d = d;
x += QK_K;
}
#endif