Diffstat (limited to 'ggml/src/iqk/iqk_gemm_ktquants.cpp')
-rw-r--r--  ggml/src/iqk/iqk_gemm_ktquants.cpp  431
1 file changed, 428 insertions, 3 deletions
diff --git a/ggml/src/iqk/iqk_gemm_ktquants.cpp b/ggml/src/iqk/iqk_gemm_ktquants.cpp
index 19c30e2a..57702199 100644
--- a/ggml/src/iqk/iqk_gemm_ktquants.cpp
+++ b/ggml/src/iqk/iqk_gemm_ktquants.cpp
@@ -212,6 +212,56 @@ struct Trellis3 {
}
}
}
+ IQK_ALWAYS_INLINE inline void next_128(__m256i val, __m256i * result) const {
+ // Even though we only have 16 vector registers on AVX2, this is still faster
+ __m256i aux[16];
+ __m256i tmp[2];
+ tmp[0] = _mm256_cvtepu16_epi32(_mm256_castsi256_si128(val));
+ tmp[1] = _mm256_cvtepu16_epi32(_mm256_extracti128_si256(val, 1));
+ for (int k = 0; k < 2; ++k) {
+ auto vl = _mm256_castsi256_si128(tmp[k]);
+ auto v = MM256_SET_M128I(vl, vl);
+ aux[8*k+0] = _mm256_shuffle_epi32(v, 0x00);
+ aux[8*k+1] = _mm256_shuffle_epi32(v, 0x55);
+ aux[8*k+2] = _mm256_shuffle_epi32(v, 0xaa);
+ aux[8*k+3] = _mm256_shuffle_epi32(v, 0xff);
+ auto vh = _mm256_extracti128_si256(tmp[k], 1);
+ v = MM256_SET_M128I(vh, vh);
+ aux[8*k+4] = _mm256_shuffle_epi32(v, 0x00);
+ aux[8*k+5] = _mm256_shuffle_epi32(v, 0x55);
+ aux[8*k+6] = _mm256_shuffle_epi32(v, 0xaa);
+ aux[8*k+7] = _mm256_shuffle_epi32(v, 0xff);
+ }
+ for (int i = 0; i < 16; ++i) {
+ aux[i] = _mm256_mullo_epi32(aux[i], mka);
+ }
+ auto mask = _mm256_set1_epi32(0x3f3f3f3f);
+ for (int i = 0; i < 16; ++i) {
+ aux[i] = _mm256_and_si256(aux[i], mask);
+ }
+ auto offset = _mm256_set1_epi32(-126);
+#if defined(__AVX512F__) && defined(__AVX512VNNI__) && defined(__AVX512VL__)
+ auto m1 = _mm256_set1_epi32(0x01010101);
+#endif
+ for (int i = 0; i < 16; ++i) {
+#if defined(__AVX512F__) && defined(__AVX512VNNI__) && defined(__AVX512VL__)
+ aux[i] = _mm256_dpbusd_epi32(offset, aux[i], m1);
+#else
+ auto dot = _mm256_maddubs_epi16(aux[i], _mm256_set1_epi32(0x01010101));
+ aux[i] = _mm256_add_epi32(offset, _mm256_madd_epi16(dot, _mm256_set1_epi16(1)));
+#endif
+ }
+ for (int k = 0; k < 4; ++k) {
+ auto v1 = _mm256_packs_epi32(aux[4*k+0], aux[4*k+1]);
+ auto v2 = _mm256_packs_epi32(aux[4*k+2], aux[4*k+3]);
+ result[k] = _mm256_permutevar8x32_epi32(_mm256_packs_epi16(v1, v2), shuffle);
+ }
+ if constexpr (is_abs) {
+ for (int k = 0; k < 4; ++k) {
+ result[k] = _mm256_sign_epi8(result[k], result[k]);
+ }
+ }
+ }
IQK_ALWAYS_INLINE inline void next_128(const uint16_t * val, uint32_t v0, __m256i * result) const {
// Even though we only have 16 vector registers on AVX2, this is still faster
__m256i aux[16];
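
Note: the new next_128() overload above follows the same per-lane recurrence as the scalar trellis code: each 16-bit index is broadcast to a full register, multiplied lane-wise by the trellis multipliers held in mka, every byte of the product is masked to 6 bits, and the four bytes of each lane are summed with an offset of -126 (the dpbusd/maddubs step). A minimal scalar sketch of one lane, shown for illustration only (the name ka_pow stands for whichever multiplier a given lane uses; it is not part of the patch):

    static inline int8_t trellis_value(uint32_t idx, uint32_t ka_pow) {
        uint32_t v = (idx * ka_pow) & 0x3f3f3f3fu;                      // each byte limited to [0, 63]
        int sum = (v & 0xff) + ((v >> 8) & 0xff) + ((v >> 16) & 0xff) + (v >> 24);
        return (int8_t)(sum - 126);                                     // same as dpbusd with offset -126
    }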
@@ -463,6 +513,148 @@ void mul_mat_iq2_kt_F32_T(int n, const void * vx, size_t bx, const DataInfo& inf
}
}
+void iqk_dequantize_iq1_kt_q80_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) {
+ GGML_ASSERT(n%QK_K == 0);
+ GGML_ASSERT(nrc_x%8 == 0);
+ const int nb = n/QK_K;
+
+ Trellis3 trellis;
+
+ auto values = _mm_loadu_si128((const __m128i *)iq4k_values);
+
+ block_q8_0_r8 * y = (block_q8_0_r8 *)vy;
+
+ const block_iq1_kt * x8[8];
+ float dkt[8];
+ float ls[8];
+ float ls_all[64];
+ uint32_t idx[8];
+
+ for (int ix = 0; ix < nrc_x; ix += 8) {
+ for (int k = 0; k < 8; ++k) {
+ const float * dptr = (const float *)((const char*)vx + (ix+k)*bx);
+ dkt[k] = dptr[0];
+ x8[k] = (const block_iq1_kt *)(dptr + 1);
+ }
+ auto vd = _mm256_loadu_ps(dkt);
+
+ for (int i = 0; i < nb; ++i) {
+ for (int k = 0; k < 8; ++k) {
+ auto sh = _mm_loadl_epi64((const __m128i *)x8[k][i].sh);
+ auto s8 = _mm_shuffle_epi8(values, _mm_and_si128(sh, _mm_set1_epi8(0xf)));
+ auto s32 = _mm256_cvtepi8_epi32(s8);
+ _mm256_storeu_ps(ls_all + 8*k, _mm256_cvtepi32_ps(s32));
+ }
+ for (int ib = 0; ib < QK_K/32; ++ib) {
+ for (int k = 0; k < 8; ++k) ls[k] = ls_all[8*k+ib];
+ auto scales = _mm256_mul_ps(vd, _mm256_loadu_ps(ls));
+ _mm_storeu_si128((__m128i *)y[ib].d, _mm256_cvtps_ph(scales, _MM_FROUND_TO_NEAREST_INT));
+ for (int j = 0; j < 4; ++j) {
+ int jj = 4*ib + j;
+ for (int k = 0; k < 8; ++k) {
+ idx[k] = (x8[k][i].ql[jj] | ((x8[k][i].qh[jj%16] << (8 - 4*(jj/16))) & 0xf00) | ((x8[k][i].sh[jj/4] << (8 - (jj%4))) & 0x1000)) + 4096;
+ }
+ __m256i packed[2];
+ trellis.next64(idx, packed);
+ _mm256_storeu_si256((__m256i *)y[ib].qs+2*j+0, packed[0]);
+ _mm256_storeu_si256((__m256i *)y[ib].qs+2*j+1, packed[1]);
+ }
+ }
+ y += 8; // = QK_K/32;
+ }
+ }
+}
+
+template <int nrc_y>
+void mul_mat_iq1_kt_q8_2_x4_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ assert(n%QK_K == 0);
+ const int nb = n/QK_K;
+
+ Trellis3<true, false> trellis;
+
+ auto values = _mm_loadu_si128((const __m128i *)iq4k_values);
+
+ constexpr int k_acc = nrc_y;
+
+ __m256 accd[k_acc];
+ const block_q8_2_x4 * y[nrc_y];
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ y[iy] = (const block_q8_2_x4 *)info.src1_row(iy);
+ }
+
+ __m256i xv[4], dot[4];
+ __m256 scales[2];
+
+ auto sum_4 = [&dot] () {
+ // dot[k] has 8 values from block k
+ // 0 1 0 1 0 1 0 1
+ dot[0] = _mm256_add_epi32(_mm256_unpacklo_epi32(dot[0], dot[1]), _mm256_unpackhi_epi32(dot[0], dot[1]));
+ // 2 3 2 3 2 3 2 3
+ dot[2] = _mm256_add_epi32(_mm256_unpacklo_epi32(dot[2], dot[3]), _mm256_unpackhi_epi32(dot[2], dot[3]));
+ // 0 1 2 3 0 1 2 3
+ dot[0] = _mm256_add_epi32(_mm256_unpacklo_epi64(dot[0], dot[2]), _mm256_unpackhi_epi64(dot[0], dot[2]));
+ return _mm256_cvtepi32_ps(dot[0]);
+ };
+
+ auto compute_dot = [&dot, &xv] (const int8_t * y) {
+ for (int k = 0; k < 4; ++k) {
+ auto yv = _mm256_loadu_si256((const __m256i *)y + k);
+#ifdef HAVE_FANCY_SIMD
+ //dot[k] = _mm256_dpbusd_epi32(_mm256_setzero_si256(), xv[k], yv);
+ dot[k] = _mm256_dpbusd_epi32(_mm256_setzero_si256(), _mm256_sign_epi8(xv[k], xv[k]), _mm256_sign_epi8(yv, xv[k]));
+#else
+ auto p = _mm256_maddubs_epi16(_mm256_sign_epi8(xv[k], xv[k]), _mm256_sign_epi8(yv, xv[k]));
+ dot[k] = _mm256_madd_epi16(p, _mm256_set1_epi16(1));
+#endif
+ }
+ };
+
+ __m256i idx[2];
+
+ for (int ix = 0; ix < nrc_x; ++ix) {
+ const float * dptr = (const float *)((const char*)vx + ix*bx);
+ auto d = _mm256_set1_ps(dptr[0]);
+ const block_iq1_kt * x = (const block_iq1_kt *)(dptr + 1);
+
+ for (int iy = 0; iy < k_acc; ++iy) accd[iy] = _mm256_setzero_ps();
+
+ for (int i = 0; i < nb; ++i) {
+ auto sh = _mm_loadl_epi64((const __m128i *)x[i].sh);
+ auto s32 = _mm256_cvtepi8_epi32(_mm_shuffle_epi8(values, _mm_and_si128(sh, _mm_set1_epi8(0xf))));
+ auto all_scales = _mm256_mul_ps(d, _mm256_cvtepi32_ps(s32));
+ auto scales_l = _mm256_castps256_ps128(all_scales);
+ auto scales_h = _mm256_extractf128_ps(all_scales, 1);
+ scales[0] = _mm256_set_m128(scales_l, scales_l);
+ scales[1] = _mm256_set_m128(scales_h, scales_h);
+ auto qs8l = _mm_loadu_si128((const __m128i *)x[i].ql+0);
+ auto qs8h = _mm_loadu_si128((const __m128i *)x[i].ql+1);
+ auto qh16 = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)x[i].qh));
+ idx[0] = _mm256_or_si256(_mm256_cvtepu8_epi16(qs8l), _mm256_and_si256(_mm256_set1_epi16(0xf00), _mm256_slli_epi16(qh16, 8)));
+ idx[1] = _mm256_or_si256(_mm256_cvtepu8_epi16(qs8h), _mm256_and_si256(_mm256_set1_epi16(0xf00), _mm256_slli_epi16(qh16, 4)));
+ idx[0] = _mm256_add_epi16(idx[0], _mm256_set1_epi16(4096));
+ idx[1] = _mm256_add_epi16(idx[1], _mm256_set1_epi16(4096));
+ auto sh32 = _mm256_and_si256(_mm256_cvtepu8_epi32(sh), _mm256_set1_epi32(0xf0));
+ sh32 = _mm256_and_si256(_mm256_mullo_epi32(sh32, _mm256_set1_epi32(0x01020408)), _mm256_set1_epi8(-128));
+ idx[0] = _mm256_add_epi16(idx[0], _mm256_slli_epi16(_mm256_cvtepu8_epi16(_mm256_castsi256_si128(sh32)), 5));
+ idx[1] = _mm256_add_epi16(idx[1], _mm256_slli_epi16(_mm256_cvtepu8_epi16(_mm256_extracti128_si256(sh32, 1)), 5));
+ for (int i128 = 0; i128 < 2; ++i128) {
+ trellis.next_128(idx[i128], xv);
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ const block_q8_2_x4& yb = y[iy][2*i+i128];
+ auto dy4 = _mm_castsi128_ps(_mm_slli_epi32(_mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)yb.d)), 16));
+ auto dy8 = _mm256_mul_ps(scales[i128], _mm256_set_m128(dy4, dy4));
+ compute_dot(yb.qs);
+ accd[iy] = _mm256_fmadd_ps(dy8, sum_4(), accd[iy]);
+ }
+ }
+ }
+
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ info.store(ix, iy, hsum_float_8(accd[iy]));
+ }
+ }
+}
+
template <int nrc_y>
void mul_mat_iq2_kt_q8_2_x4_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
assert(n%QK_K == 0);
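
For reference, the 13-bit trellis index assembled above (explicitly in iqk_dequantize_iq1_kt_q80_r8, and in vectorized form in mul_mat_iq1_kt_q8_2_x4_T) packs 8 low bits from ql, 4 bits from a nibble of qh, and 1 bit from the high nibble of sh, then adds the fixed 4096 offset before feeding the trellis. An equivalent scalar helper, given only to spell out the bit layout (the helper name is ours, not part of the patch):

    static inline uint32_t iq1_kt_index(const uint8_t * ql, const uint8_t * qh,
                                        const uint8_t * sh, int jj) {        // jj = 0..63 within a block
        uint32_t lo  = ql[jj];                                               // 8 low bits
        uint32_t mid = (qh[jj % 16] << (8 - 4 * (jj / 16))) & 0xf00;         // 4 bits from a qh nibble
        uint32_t hi  = (sh[jj / 4]  << (8 - (jj % 4)))      & 0x1000;        // 1 bit from the sh high nibble
        return (lo | mid | hi) + 4096;
    }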
@@ -1091,11 +1283,11 @@ bool iqk_set_kernels_ktquants(int ne00, int typeA, int typeB, std::array<mul_mat
func16 = nullptr;
- if (typeA == GGML_TYPE_IQ4_KT) {
+ if (typeA == GGML_TYPE_IQ1_KT) {
if (typeB == GGML_TYPE_Q8_2_X4) {
- IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq4_kt_q8_2_x4_T, kernels);
+ IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq1_kt_q8_2_x4_T, kernels);
#ifdef HAVE_FANCY_SIMD
- func16 = mul_mat_iq4_kt_q8_2_x4_T<16>;
+ func16 = mul_mat_iq1_kt_q8_2_x4_T<16>;
#endif
return true;
}
@@ -1124,6 +1316,17 @@ bool iqk_set_kernels_ktquants(int ne00, int typeA, int typeB, std::array<mul_mat
return false;
}
+ if (typeA == GGML_TYPE_IQ4_KT) {
+ if (typeB == GGML_TYPE_Q8_2_X4) {
+ IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq4_kt_q8_2_x4_T, kernels);
+#ifdef HAVE_FANCY_SIMD
+ func16 = mul_mat_iq4_kt_q8_2_x4_T<16>;
+#endif
+ return true;
+ }
+ return false;
+ }
+
if (ggml_type(typeB) != GGML_TYPE_F32) {
return false;
}
@@ -1148,6 +1351,7 @@ bool iqk_set_kernels_ktquants(int ne00, int typeA, int typeB, std::array<mul_mat
bool iqk_dequantize_ktquants(int type, int n, const void * vx, size_t bx, void * y, [[maybe_unused]] size_t stride_y, int nrc_x) {
switch (type) {
+ case GGML_TYPE_IQ1_KT: iqk_dequantize_iq1_kt_q80_r8(n, vx, bx, y, nrc_x); break;
case GGML_TYPE_IQ2_KT: iqk_dequantize_iq2_kt_q80_r8(n, vx, bx, y, nrc_x); break;
case GGML_TYPE_IQ3_KT: iqk_dequantize_iq3_kt_q80_r8(n, vx, bx, y, nrc_x); break;
case GGML_TYPE_IQ4_KT: iqk_dequantize_iq4_kt_q80_r8(n, vx, bx, y, nrc_x); break;
@@ -1701,6 +1905,27 @@ struct Trellis3 {
}
return result;
}
+ inline int8x16x2_t next32(uint16x4_t val16) const {
+ auto vka3 = vdupq_n_u32(ka3);
+ int8x16x2_t result = {vdupq_n_s8(-126), vdupq_n_s8(-126)};
+ auto val32 = vmovl_u16(val16);
+ uint32x4x4_t aux32 = { vmulq_laneq_u32(mka, val32, 0), vmulq_laneq_u32(mka, val32, 1), vmulq_laneq_u32(mka, val32, 2), vmulq_laneq_u32(mka, val32, 3) };
+ int8x16x2_t i8;
+ auto mask = vdupq_n_u32(0x3f3f3f3f);
+ for (int i = 0; i < 2; ++i) {
+ i8.val[0] = vandq_u32(mask, aux32.val[2*i+0]);
+ i8.val[1] = vandq_u32(mask, vmulq_u32(vka3, aux32.val[2*i+0]));
+ auto s1 = vpaddq_s8(vreinterpretq_s8_u32(i8.val[0]), vreinterpretq_s8_u32(i8.val[1]));
+ i8.val[0] = vandq_u32(mask, aux32.val[2*i+1]);
+ i8.val[1] = vandq_u32(mask, vmulq_u32(vka3, aux32.val[2*i+1]));
+ auto s2 = vpaddq_s8(vreinterpretq_s8_u32(i8.val[0]), vreinterpretq_s8_u32(i8.val[1]));
+ result.val[i] = vaddq_s8(result.val[i], vpaddq_s8(s1, s2));
+ if constexpr (is_abs) {
+ result.val[i] = vreinterpretq_s8_u8(vabsq_s8(result.val[i]));
+ }
+ }
+ return result;
+ }
inline int8x16x2_t next32(const uint16_t * val, uint32_t v0) const {
auto vka3 = vdupq_n_u32(ka3);
int8x16x2_t result = {vdupq_n_s8(-126), vdupq_n_s8(-126)};
@@ -2028,6 +2253,196 @@ void iqk_dequantize_iq3_kt_q80_r8(int n, const void * vx, size_t bx, void * vy,
}
}
+void iqk_dequantize_iq1_kt_q80_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) {
+ GGML_ASSERT(n%QK_K == 0);
+ GGML_ASSERT(nrc_x%8 == 0);
+ const int nb = n/QK_K;
+
+ Trellis3 trellis;
+
+ auto values = vld1q_s8(iq4k_values);
+
+ block_q8_0_r8 * y = (block_q8_0_r8 *)vy;
+
+ const block_iq1_kt * x8[8];
+ float dkt[8];
+ float ls[8], ls_all[64];
+ uint16_t all_idx[256];
+ uint32_t idx[8];
+
+ for (int ix = 0; ix < nrc_x; ix += 8) {
+ for (int k = 0; k < 8; ++k) {
+ const float * dptr = (const float *)((const char*)vx + (ix+k)*bx);
+ dkt[k] = dptr[0];
+ x8[k] = (const block_iq1_kt *)(dptr + 1);
+ }
+ auto vd = vld1q_f32_x2(dkt);
+
+ for (int i = 0; i < nb; ++i) {
+ for (int k = 0; k < 8; ++k) {
+ auto sh = vld1_u8(x8[k][i].sh);
+ auto s16 = vmovl_s8(vqtbl1_s8(values, vand_u8(sh, vdup_n_u8(0xf))));
+ vst1q_f32(ls_all + 8*k + 0, vcvtq_f32_s32(vmovl_s16(vget_low_s16(s16))));
+ vst1q_f32(ls_all + 8*k + 4, vcvtq_f32_s32(vmovl_s16(vget_high_s16(s16))));
+ auto ql = vld1q_u8_x2(x8[k][i].ql);
+ auto qh = vld1q_u8(x8[k][i].qh);
+ auto qhl = vmovl_u8(vget_low_u8(qh));
+ auto qhh = vmovl_u8(vget_high_u8(qh));
+ uint16x8x4_t idx;
+ idx.val[0] = vaddq_u16(vmovl_u8(vget_low_u8 (ql.val[0])), vandq_u16(vdupq_n_u16(0xf00), vshlq_n_u16(qhl, 8)));
+ idx.val[1] = vaddq_u16(vmovl_u8(vget_high_u8(ql.val[0])), vandq_u16(vdupq_n_u16(0xf00), vshlq_n_u16(qhh, 8)));
+ idx.val[2] = vaddq_u16(vmovl_u8(vget_low_u8 (ql.val[1])), vandq_u16(vdupq_n_u16(0xf00), vshlq_n_u16(qhl, 4)));
+ idx.val[3] = vaddq_u16(vmovl_u8(vget_high_u8(ql.val[1])), vandq_u16(vdupq_n_u16(0xf00), vshlq_n_u16(qhh, 4)));
+ for (int k = 0; k < 4; ++k) idx.val[k] = vaddq_u16(idx.val[k], vdupq_n_u16(4096));
+ auto sh16 = vandq_u16(vmovl_u8(sh), vdupq_n_u16(0xf0));
+ auto sh32l = vandq_u8(vreinterpretq_u8_u32(vmulq_u32(vmovl_u16(vget_low_u16 (sh16)), vdupq_n_u32(0x01020408))), vdupq_n_u8(0x80));
+ auto sh32h = vandq_u8(vreinterpretq_u8_u32(vmulq_u32(vmovl_u16(vget_high_u16(sh16)), vdupq_n_u32(0x01020408))), vdupq_n_u8(0x80));
+ idx.val[0] = vaddq_u16(idx.val[0], vshlq_n_u16(vmovl_u8(vget_low_u8 (sh32l)), 5));
+ idx.val[1] = vaddq_u16(idx.val[1], vshlq_n_u16(vmovl_u8(vget_high_u8(sh32l)), 5));
+ idx.val[2] = vaddq_u16(idx.val[2], vshlq_n_u16(vmovl_u8(vget_low_u8 (sh32h)), 5));
+ idx.val[3] = vaddq_u16(idx.val[3], vshlq_n_u16(vmovl_u8(vget_high_u8(sh32h)), 5));
+ vst1q_u16_x4(all_idx + 32*k, idx);
+ }
+ for (int ib = 0; ib < QK_K/32; ++ib) {
+ for (int k = 0; k < 8; ++k) ls[k] = ls_all[8*k+ib];
+ auto scales1 = vmulq_f32(vd.val[0], vld1q_f32(ls+0));
+ auto scales2 = vmulq_f32(vd.val[1], vld1q_f32(ls+4));
+ vst1_f16((float16_t *)y[ib].d+0, vcvt_f16_f32(scales1));
+ vst1_f16((float16_t *)y[ib].d+4, vcvt_f16_f32(scales2));
+ for (int j = 0; j < 4; ++j) {
+ for (int k = 0; k < 8; ++k) idx[k] = all_idx[32*k + 4*ib + j];
+ vst1q_s8_x4(y[ib].qs+64*j, trellis.next64(idx));
+ }
+ }
+ y += 8; // = QK_K/32;
+ }
+ }
+}
+
+template <int nrc_y>
+void mul_mat_iq1_kt_q8_0_x4_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ assert(n%QK_K == 0);
+ const int nb = n/QK_K;
+
+ Trellis3 trellis;
+
+ auto values = vld1q_s8(iq4k_values);
+
+ constexpr int k_acc = nrc_y == 1 ? 2 : nrc_y;
+
+ float32x4_t accd[k_acc];
+
+ const block_q8_0_x4 * y[nrc_y];
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ y[iy] = (const block_q8_0_x4 *)info.src1_row(iy);
+ }
+
+ int8x16x2_t xv[8];
+ uint16x8x4_t idx;
+ int32x4x4_t dot;
+
+ auto compute_dot = [&dot] (const int8_t * y, const int8x16x2_t * xv) {
+ for (int k = 0; k < 4; ++k) {
+ auto yv = vld1q_s8_x2(y + 32*k);
+ dot.val[k] = vdotq_s32(vdotq_s32(vdupq_n_s32(0), xv[k].val[0], yv.val[0]), xv[k].val[1], yv.val[1]);
+ }
+ dot.val[0] = vpaddq_s32(dot.val[0], dot.val[1]);
+ dot.val[2] = vpaddq_s32(dot.val[2], dot.val[3]);
+ return vpaddq_s32(dot.val[0], dot.val[2]);
+ };
+
+ float32x4x2_t scales;
+
+ for (int ix = 0; ix < nrc_x; ++ix) {
+ const float * dptr = (const float *)((const char*)vx + ix*bx);
+ auto d = vdupq_n_f32(dptr[0]);
+ const block_iq1_kt * x = (const block_iq1_kt *)(dptr + 1);
+
+ for (int iy = 0; iy < k_acc; ++iy) accd[iy] = vdupq_n_f32(0);
+
+ for (int i = 0; i < nb; ++i) {
+ auto sh = vld1_u8(x[i].sh);
+ auto s16 = vmovl_s8(vqtbl1_s8(values, vand_u8(sh, vdup_n_u8(0xf))));
+ scales.val[0] = vmulq_f32(d, vcvtq_f32_s32(vmovl_s16(vget_low_s16 (s16))));
+ scales.val[1] = vmulq_f32(d, vcvtq_f32_s32(vmovl_s16(vget_high_s16(s16))));
+ auto ql = vld1q_u8_x2(x[i].ql);
+ auto qh = vld1q_u8(x[i].qh);
+ auto qhl = vmovl_u8(vget_low_u8(qh));
+ auto qhh = vmovl_u8(vget_high_u8(qh));
+ idx.val[0] = vaddq_u16(vmovl_u8(vget_low_u8 (ql.val[0])), vandq_u16(vdupq_n_u16(0xf00), vshlq_n_u16(qhl, 8)));
+ idx.val[1] = vaddq_u16(vmovl_u8(vget_high_u8(ql.val[0])), vandq_u16(vdupq_n_u16(0xf00), vshlq_n_u16(qhh, 8)));
+ idx.val[2] = vaddq_u16(vmovl_u8(vget_low_u8 (ql.val[1])), vandq_u16(vdupq_n_u16(0xf00), vshlq_n_u16(qhl, 4)));
+ idx.val[3] = vaddq_u16(vmovl_u8(vget_high_u8(ql.val[1])), vandq_u16(vdupq_n_u16(0xf00), vshlq_n_u16(qhh, 4)));
+ for (int k = 0; k < 4; ++k) idx.val[k] = vaddq_u16(idx.val[k], vdupq_n_u16(4096));
+ auto sh16 = vandq_u16(vmovl_u8(sh), vdupq_n_u16(0xf0));
+ auto sh32l = vandq_u8(vreinterpretq_u8_u32(vmulq_u32(vmovl_u16(vget_low_u16 (sh16)), vdupq_n_u32(0x01020408))), vdupq_n_u8(0x80));
+ auto sh32h = vandq_u8(vreinterpretq_u8_u32(vmulq_u32(vmovl_u16(vget_high_u16(sh16)), vdupq_n_u32(0x01020408))), vdupq_n_u8(0x80));
+ idx.val[0] = vaddq_u16(idx.val[0], vshlq_n_u16(vmovl_u8(vget_low_u8 (sh32l)), 5));
+ idx.val[1] = vaddq_u16(idx.val[1], vshlq_n_u16(vmovl_u8(vget_high_u8(sh32l)), 5));
+ idx.val[2] = vaddq_u16(idx.val[2], vshlq_n_u16(vmovl_u8(vget_low_u8 (sh32h)), 5));
+ idx.val[3] = vaddq_u16(idx.val[3], vshlq_n_u16(vmovl_u8(vget_high_u8(sh32h)), 5));
+ if constexpr (nrc_y == 1) {
+ const block_q8_0_x4& ybl = y[0][2*i+0];
+ const block_q8_0_x4& ybh = y[0][2*i+1];
+ auto dyl = vmulq_f32(scales.val[0], vcvt_f32_f16(vld1_f16((const float16_t *)ybl.d)));
+ auto dyh = vmulq_f32(scales.val[1], vcvt_f32_f16(vld1_f16((const float16_t *)ybh.d)));
+ int32x4x4_t suml = {};
+ int32x4x4_t sumh = {};
+ for (int ib = 0; ib < 2; ++ib) {
+ auto xl = trellis.next32(vget_low_u16(idx.val[ib+0]));
+ auto xh = trellis.next32(vget_low_u16(idx.val[ib+2]));
+ auto yl = vld1q_s8_x2(ybl.qs + 64*ib);
+ auto yh = vld1q_s8_x2(ybh.qs + 64*ib);
+ suml.val[2*ib+0] = vdotq_s32(vdotq_s32(vdupq_n_s32(0), xl.val[0], yl.val[0]), xl.val[1], yl.val[1]);
+ sumh.val[2*ib+0] = vdotq_s32(vdotq_s32(vdupq_n_s32(0), xh.val[0], yh.val[0]), xh.val[1], yh.val[1]);
+ xl = trellis.next32(vget_high_u16(idx.val[ib+0]));
+ xh = trellis.next32(vget_high_u16(idx.val[ib+2]));
+ yl = vld1q_s8_x2(ybl.qs + 64*ib + 32);
+ yh = vld1q_s8_x2(ybh.qs + 64*ib + 32);
+ suml.val[2*ib+1] = vdotq_s32(vdotq_s32(vdupq_n_s32(0), xl.val[0], yl.val[0]), xl.val[1], yl.val[1]);
+ sumh.val[2*ib+1] = vdotq_s32(vdotq_s32(vdupq_n_s32(0), xh.val[0], yh.val[0]), xh.val[1], yh.val[1]);
+ }
+ auto sl1 = vpaddq_s32(suml.val[0], suml.val[1]);
+ auto sl2 = vpaddq_s32(suml.val[2], suml.val[3]);
+ auto sl = vpaddq_s32(sl1, sl2);
+ auto sh1 = vpaddq_s32(sumh.val[0], sumh.val[1]);
+ auto sh2 = vpaddq_s32(sumh.val[2], sumh.val[3]);
+ auto sh = vpaddq_s32(sh1, sh2);
+ accd[0] = vfmaq_f32(accd[0], dyl, vcvtq_f32_s32(sl));
+ accd[1] = vfmaq_f32(accd[1], dyh, vcvtq_f32_s32(sh));
+ } else {
+ for (int k = 0; k < 4; ++k) {
+ xv[2*k+0] = trellis.next32(vget_low_u16 (idx.val[k]));
+ xv[2*k+1] = trellis.next32(vget_high_u16(idx.val[k]));
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ const block_q8_0_x4& ybl = y[iy][2*i+0];
+ const block_q8_0_x4& ybh = y[iy][2*i+1];
+ auto dyl = vmulq_f32(scales.val[0], vcvt_f32_f16(vld1_f16((const float16_t *)ybl.d)));
+ auto dyh = vmulq_f32(scales.val[1], vcvt_f32_f16(vld1_f16((const float16_t *)ybh.d)));
+ auto sumil = compute_dot(ybl.qs, xv+0);
+ auto sumih = compute_dot(ybh.qs, xv+4);
+ if constexpr (nrc_y == 1) {
+ accd[2*iy+0] = vfmaq_f32(accd[2*iy+0], dyl, vcvtq_f32_s32(sumil));
+ accd[2*iy+1] = vfmaq_f32(accd[2*iy+1], dyh, vcvtq_f32_s32(sumih));
+ } else {
+ accd[iy] = vfmaq_f32(accd[iy], dyl, vcvtq_f32_s32(sumil));
+ accd[iy] = vfmaq_f32(accd[iy], dyh, vcvtq_f32_s32(sumih));
+ }
+ }
+ }
+ }
+
+ if constexpr (nrc_y == 1) {
+ info.store(ix, 0, vaddvq_f32(vaddq_f32(accd[0], accd[1])));
+ } else {
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ info.store(ix, iy, vaddvq_f32(accd[iy]));
+ }
+ }
+ }
+}
+
template <int nrc_y>
void mul_mat_iq2_kt_q8_0_x4_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
assert(n%QK_K == 0);
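
The sh handling in both the AVX2 and NEON iq1_kt kernels above relies on a small bit-spreading trick: after masking sh to its high nibble, multiplying by 0x01020408 places bits 4..7 of sh into the sign bit of four consecutive bytes, and the subsequent widen-and-shift-left-by-5 turns each of those sign bits into bit 12 of a 16-bit index. A scalar sketch of just that step, illustrative only:

    static inline uint32_t spread_sh_bits(uint8_t sh) {
        uint32_t v = sh & 0xf0u;                      // keep the high nibble
        return (v * 0x01020408u) & 0x80808080u;       // byte k now holds bit (4 + k) of sh in its MSB
    }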
@@ -2284,6 +2699,15 @@ bool iqk_set_kernels_ktquants(int ne00, int typeA, int typeB, std::array<mul_mat
return false;
}
+ if (ggml_type(typeA) == GGML_TYPE_IQ1_KT) {
+ if (ggml_type(typeB) == GGML_TYPE_Q8_0_X4) {
+ IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq1_kt_q8_0_x4_T, kernels);
+ func16 = nullptr;
+ return true;
+ }
+ return false;
+ }
+
if (ggml_type(typeB) != GGML_TYPE_F16) {
return false;
}
@@ -2309,6 +2733,7 @@ bool iqk_set_kernels_ktquants(int ne00, int typeA, int typeB, std::array<mul_mat
bool iqk_dequantize_ktquants(int type, int n, const void * vx, size_t bx, void * y, [[maybe_unused]] size_t stride_y, int nrc_x) {
switch (type) {
+ case GGML_TYPE_IQ1_KT: iqk_dequantize_iq1_kt_q80_r8(n, vx, bx, y, nrc_x); break;
case GGML_TYPE_IQ2_KT: iqk_dequantize_iq2_kt_q80_r8(n, vx, bx, y, nrc_x); break;
case GGML_TYPE_IQ3_KT: iqk_dequantize_iq3_kt_q80_r8(n, vx, bx, y, nrc_x); break;
case GGML_TYPE_IQ4_KT: iqk_dequantize_iq4_kt_q80_r8(n, vx, bx, y, nrc_x); break;