Diffstat (limited to 'ggml/src/iqk/iqk_gemm_ktquants.cpp')
-rw-r--r-- | ggml/src/iqk/iqk_gemm_ktquants.cpp | 431 |
1 file changed, 428 insertions, 3 deletions
diff --git a/ggml/src/iqk/iqk_gemm_ktquants.cpp b/ggml/src/iqk/iqk_gemm_ktquants.cpp
index 19c30e2a..57702199 100644
--- a/ggml/src/iqk/iqk_gemm_ktquants.cpp
+++ b/ggml/src/iqk/iqk_gemm_ktquants.cpp
@@ -212,6 +212,56 @@ struct Trellis3 {
             }
         }
     }
+    IQK_ALWAYS_INLINE inline void next_128(__m256i val, __m256i * result) const {
+        // Even though we only have 16 vector registers on AVX2, this is still faster
+        __m256i aux[16];
+        __m256i tmp[2];
+        tmp[0] = _mm256_cvtepu16_epi32(_mm256_castsi256_si128(val));
+        tmp[1] = _mm256_cvtepu16_epi32(_mm256_extracti128_si256(val, 1));
+        for (int k = 0; k < 2; ++k) {
+            auto vl = _mm256_castsi256_si128(tmp[k]);
+            auto v = MM256_SET_M128I(vl, vl);
+            aux[8*k+0] = _mm256_shuffle_epi32(v, 0x00);
+            aux[8*k+1] = _mm256_shuffle_epi32(v, 0x55);
+            aux[8*k+2] = _mm256_shuffle_epi32(v, 0xaa);
+            aux[8*k+3] = _mm256_shuffle_epi32(v, 0xff);
+            auto vh = _mm256_extracti128_si256(tmp[k], 1);
+            v = MM256_SET_M128I(vh, vh);
+            aux[8*k+4] = _mm256_shuffle_epi32(v, 0x00);
+            aux[8*k+5] = _mm256_shuffle_epi32(v, 0x55);
+            aux[8*k+6] = _mm256_shuffle_epi32(v, 0xaa);
+            aux[8*k+7] = _mm256_shuffle_epi32(v, 0xff);
+        }
+        for (int i = 0; i < 16; ++i) {
+            aux[i] = _mm256_mullo_epi32(aux[i], mka);
+        }
+        auto mask = _mm256_set1_epi32(0x3f3f3f3f);
+        for (int i = 0; i < 16; ++i) {
+            aux[i] = _mm256_and_si256(aux[i], mask);
+        }
+        auto offset = _mm256_set1_epi32(-126);
+#if defined(__AVX512F__) && defined(__AVX512VNNI__) && defined(__AVX512VL__)
+        auto m1 = _mm256_set1_epi32(0x01010101);
+#endif
+        for (int i = 0; i < 16; ++i) {
+#if defined(__AVX512F__) && defined(__AVX512VNNI__) && defined(__AVX512VL__)
+            aux[i] = _mm256_dpbusd_epi32(offset, aux[i], m1);
+#else
+            auto dot = _mm256_maddubs_epi16(aux[i], _mm256_set1_epi32(0x01010101));
+            aux[i] = _mm256_add_epi32(offset, _mm256_madd_epi16(dot, _mm256_set1_epi16(1)));
+#endif
+        }
+        for (int k = 0; k < 4; ++k) {
+            auto v1 = _mm256_packs_epi32(aux[4*k+0], aux[4*k+1]);
+            auto v2 = _mm256_packs_epi32(aux[4*k+2], aux[4*k+3]);
+            result[k] = _mm256_permutevar8x32_epi32(_mm256_packs_epi16(v1, v2), shuffle);
+        }
+        if constexpr (is_abs) {
+            for (int k = 0; k < 4; ++k) {
+                result[k] = _mm256_sign_epi8(result[k], result[k]);
+            }
+        }
+    }
     IQK_ALWAYS_INLINE inline void next_128(const uint16_t * val, uint32_t v0, __m256i * result) const {
         // Even though we only have 16 vector registers on AVX2, this is still faster
         __m256i aux[16];
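For orientation, here is a minimal scalar sketch of what the new next_128 computes per 16-bit index. It assumes mka holds eight successive powers of the Trellis3 multiplier ka (0xCBAC1FED elsewhere in the kt-quant code); the 0x3f3f3f3f mask and the -126 offset are the ones visible above. This is an illustration, not code from the patch:

#include <cstdint>

// Each index seeds a multiplicative recurrence; every step keeps 6 bits of each
// product byte (so each byte is in 0..63), sums the four bytes, and recenters by -126,
// so every output fits in int8 (range -126..126).
static void trellis3_scalar(uint16_t idx, int8_t out[8], bool is_abs = false) {
    constexpr uint32_t ka = 0xCBAC1FEDu;  // assumed trellis multiplier
    uint32_t x = idx;
    for (int j = 0; j < 8; ++j) {
        x *= ka;                          // lane j of mka is assumed to equal ka^(j+1)
        uint32_t m = x & 0x3f3f3f3fu;
        int v = int((m & 0xff) + ((m >> 8) & 0xff) + ((m >> 16) & 0xff) + (m >> 24)) - 126;
        out[j] = int8_t(is_abs ? (v < 0 ? -v : v) : v);  // is_abs mirrors the _mm256_sign_epi8 step
    }
}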
@@ -463,6 +513,148 @@ void mul_mat_iq2_kt_F32_T(int n, const void * vx, size_t bx, const DataInfo& inf
     }
 }
 
+void iqk_dequantize_iq1_kt_q80_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) {
+    GGML_ASSERT(n%QK_K == 0);
+    GGML_ASSERT(nrc_x%8 == 0);
+    const int nb = n/QK_K;
+
+    Trellis3 trellis;
+
+    auto values = _mm_loadu_si128((const __m128i *)iq4k_values);
+
+    block_q8_0_r8 * y = (block_q8_0_r8 *)vy;
+
+    const block_iq1_kt * x8[8];
+    float dkt[8];
+    float ls[8];
+    float ls_all[64];
+    uint32_t idx[8];
+
+    for (int ix = 0; ix < nrc_x; ix += 8) {
+        for (int k = 0; k < 8; ++k) {
+            const float * dptr = (const float *)((const char*)vx + (ix+k)*bx);
+            dkt[k] = dptr[0];
+            x8[k] = (const block_iq1_kt *)(dptr + 1);
+        }
+        auto vd = _mm256_loadu_ps(dkt);
+
+        for (int i = 0; i < nb; ++i) {
+            for (int k = 0; k < 8; ++k) {
+                auto sh = _mm_loadl_epi64((const __m128i *)x8[k][i].sh);
+                auto s8 = _mm_shuffle_epi8(values, _mm_and_si128(sh, _mm_set1_epi8(0xf)));
+                auto s32 = _mm256_cvtepi8_epi32(s8);
+                _mm256_storeu_ps(ls_all + 8*k, _mm256_cvtepi32_ps(s32));
+            }
+            for (int ib = 0; ib < QK_K/32; ++ib) {
+                for (int k = 0; k < 8; ++k) ls[k] = ls_all[8*k+ib];
+                auto scales = _mm256_mul_ps(vd, _mm256_loadu_ps(ls));
+                _mm_storeu_si128((__m128i *)y[ib].d, _mm256_cvtps_ph(scales, _MM_FROUND_TO_NEAREST_INT));
+                for (int j = 0; j < 4; ++j) {
+                    int jj = 4*ib + j;
+                    for (int k = 0; k < 8; ++k) {
+                        idx[k] = (x8[k][i].ql[jj] | ((x8[k][i].qh[jj%16] << (8 - 4*(jj/16))) & 0xf00) | ((x8[k][i].sh[jj/4] << (8 - (jj%4))) & 0x1000)) + 4096;
+                    }
+                    __m256i packed[2];
+                    trellis.next64(idx, packed);
+                    _mm256_storeu_si256((__m256i *)y[ib].qs+2*j+0, packed[0]);
+                    _mm256_storeu_si256((__m256i *)y[ib].qs+2*j+1, packed[1]);
+                }
+            }
+            y += 8; // = QK_K/32;
+        }
+    }
+}
+
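Written out as scalar code, the idx[k] expression in the loop above assembles a 13-bit trellis index for weight group jj (0..31) of a block from three bit fields plus a constant offset. A hedged restatement for reference, not part of the commit:

#include <cstdint>

// Hypothetical scalar form of the index expression used above. ql, qh, sh are the
// byte arrays of one block_iq1_kt; jj indexes the 32 groups of 8 weights in a block.
static uint32_t iq1_kt_index(const uint8_t * ql, const uint8_t * qh, const uint8_t * sh, int jj) {
    uint32_t lo  = ql[jj];                                      // bits 0..7
    uint32_t mid = (qh[jj % 16] << (8 - 4*(jj/16))) & 0xf00;    // bits 8..11: low or high nibble of qh
    uint32_t hi  = (sh[jj / 4] << (8 - (jj % 4))) & 0x1000;     // bit 12: one bit of sh's high nibble
    return (lo | mid | hi) + 4096;                              // same +4096 offset the kernels add
}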
+template <int nrc_y>
+void mul_mat_iq1_kt_q8_2_x4_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+    assert(n%QK_K == 0);
+    const int nb = n/QK_K;
+
+    Trellis3<true, false> trellis;
+
+    auto values = _mm_loadu_si128((const __m128i *)iq4k_values);
+
+    constexpr int k_acc = nrc_y;
+
+    __m256 accd[k_acc];
+    const block_q8_2_x4 * y[nrc_y];
+    for (int iy = 0; iy < nrc_y; ++iy) {
+        y[iy] = (const block_q8_2_x4 *)info.src1_row(iy);
+    }
+
+    __m256i xv[4], dot[4];
+    __m256 scales[2];
+
+    auto sum_4 = [&dot] () {
+        // dot[k] has 8 values from block k
+        // 0 1 0 1 0 1 0 1
+        dot[0] = _mm256_add_epi32(_mm256_unpacklo_epi32(dot[0], dot[1]), _mm256_unpackhi_epi32(dot[0], dot[1]));
+        // 2 3 2 3 2 3 2 3
+        dot[2] = _mm256_add_epi32(_mm256_unpacklo_epi32(dot[2], dot[3]), _mm256_unpackhi_epi32(dot[2], dot[3]));
+        // 0 1 2 3 0 1 2 3
+        dot[0] = _mm256_add_epi32(_mm256_unpacklo_epi64(dot[0], dot[2]), _mm256_unpackhi_epi64(dot[0], dot[2]));
+        return _mm256_cvtepi32_ps(dot[0]);
+    };
+
+    auto compute_dot = [&dot, &xv] (const int8_t * y) {
+        for (int k = 0; k < 4; ++k) {
+            auto yv = _mm256_loadu_si256((const __m256i *)y + k);
+#ifdef HAVE_FANCY_SIMD
+            //dot[k] = _mm256_dpbusd_epi32(_mm256_setzero_si256(), xv[k], yv);
+            dot[k] = _mm256_dpbusd_epi32(_mm256_setzero_si256(), _mm256_sign_epi8(xv[k], xv[k]), _mm256_sign_epi8(yv, xv[k]));
+#else
+            auto p = _mm256_maddubs_epi16(_mm256_sign_epi8(xv[k], xv[k]), _mm256_sign_epi8(yv, xv[k]));
+            dot[k] = _mm256_madd_epi16(p, _mm256_set1_epi16(1));
+#endif
+        }
+    };
+
+    __m256i idx[2];
+
+    for (int ix = 0; ix < nrc_x; ++ix) {
+        const float * dptr = (const float *)((const char*)vx + ix*bx);
+        auto d = _mm256_set1_ps(dptr[0]);
+        const block_iq1_kt * x = (const block_iq1_kt *)(dptr + 1);
+
+        for (int iy = 0; iy < k_acc; ++iy) accd[iy] = _mm256_setzero_ps();
+
+        for (int i = 0; i < nb; ++i) {
+            auto sh = _mm_loadl_epi64((const __m128i *)x[i].sh);
+            auto s32 = _mm256_cvtepi8_epi32(_mm_shuffle_epi8(values, _mm_and_si128(sh, _mm_set1_epi8(0xf))));
+            auto all_scales = _mm256_mul_ps(d, _mm256_cvtepi32_ps(s32));
+            auto scales_l = _mm256_castps256_ps128(all_scales);
+            auto scales_h = _mm256_extractf128_ps(all_scales, 1);
+            scales[0] = _mm256_set_m128(scales_l, scales_l);
+            scales[1] = _mm256_set_m128(scales_h, scales_h);
+            auto qs8l = _mm_loadu_si128((const __m128i *)x[i].ql+0);
+            auto qs8h = _mm_loadu_si128((const __m128i *)x[i].ql+1);
+            auto qh16 = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)x[i].qh));
+            idx[0] = _mm256_or_si256(_mm256_cvtepu8_epi16(qs8l), _mm256_and_si256(_mm256_set1_epi16(0xf00), _mm256_slli_epi16(qh16, 8)));
+            idx[1] = _mm256_or_si256(_mm256_cvtepu8_epi16(qs8h), _mm256_and_si256(_mm256_set1_epi16(0xf00), _mm256_slli_epi16(qh16, 4)));
+            idx[0] = _mm256_add_epi16(idx[0], _mm256_set1_epi16(4096));
+            idx[1] = _mm256_add_epi16(idx[1], _mm256_set1_epi16(4096));
+            auto sh32 = _mm256_and_si256(_mm256_cvtepu8_epi32(sh), _mm256_set1_epi32(0xf0));
+            sh32 = _mm256_and_si256(_mm256_mullo_epi32(sh32, _mm256_set1_epi32(0x01020408)), _mm256_set1_epi8(-128));
+            idx[0] = _mm256_add_epi16(idx[0], _mm256_slli_epi16(_mm256_cvtepu8_epi16(_mm256_castsi256_si128(sh32)), 5));
+            idx[1] = _mm256_add_epi16(idx[1], _mm256_slli_epi16(_mm256_cvtepu8_epi16(_mm256_extracti128_si256(sh32, 1)), 5));
+            for (int i128 = 0; i128 < 2; ++i128) {
+                trellis.next_128(idx[i128], xv);
+                for (int iy = 0; iy < nrc_y; ++iy) {
+                    const block_q8_2_x4& yb = y[iy][2*i+i128];
+                    auto dy4 = _mm_castsi128_ps(_mm_slli_epi32(_mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)yb.d)), 16));
+                    auto dy8 = _mm256_mul_ps(scales[i128], _mm256_set_m128(dy4, dy4));
+                    compute_dot(yb.qs);
+                    accd[iy] = _mm256_fmadd_ps(dy8, sum_4(), accd[iy]);
+                }
+            }
+        }
+
+        for (int iy = 0; iy < nrc_y; ++iy) {
+            info.store(ix, iy, hsum_float_8(accd[iy]));
+        }
+    }
+}
+
 template <int nrc_y>
 void mul_mat_iq2_kt_q8_2_x4_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
     assert(n%QK_K == 0);
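The commented-out _mm256_dpbusd_epi32 call in compute_dot above hints at why both branches shuffle signs: maddubs/dpbusd treat their first operand as unsigned bytes, so the kernel feeds |x| and flips y by the sign of x instead. The byte-wise identity being exploited, as a scalar sketch (illustration only):

#include <cstdint>

// For every byte pair, x*y == |x| * (sign(x)*y), with the product forced to 0 when x == 0.
// That is what _mm256_sign_epi8(xv, xv) (giving |x|) and _mm256_sign_epi8(yv, xv) (giving
// y carrying x's sign) feed into maddubs/dpbusd, whose first operand must be unsigned.
static int32_t dot_signed_via_unsigned(const int8_t * x, const int8_t * y, int n) {
    int32_t acc = 0;
    for (int i = 0; i < n; ++i) {
        int ax = x[i] < 0 ? -x[i] : x[i];                    // unsigned operand
        int sy = x[i] > 0 ? y[i] : (x[i] < 0 ? -y[i] : 0);   // sign moved onto y
        acc += ax * sy;                                      // equals x[i] * y[i]
    }
    return acc;
}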
@@ -1091,11 +1283,11 @@ bool iqk_set_kernels_ktquants(int ne00, int typeA, int typeB, std::array<mul_mat
 
     func16 = nullptr;
 
-    if (typeA == GGML_TYPE_IQ4_KT) {
+    if (typeA == GGML_TYPE_IQ1_KT) {
         if (typeB == GGML_TYPE_Q8_2_X4) {
-            IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq4_kt_q8_2_x4_T, kernels);
+            IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq1_kt_q8_2_x4_T, kernels);
 #ifdef HAVE_FANCY_SIMD
-            func16 = mul_mat_iq4_kt_q8_2_x4_T<16>;
+            func16 = mul_mat_iq1_kt_q8_2_x4_T<16>;
 #endif
             return true;
         }
@@ -1124,6 +1316,17 @@ bool iqk_set_kernels_ktquants(int ne00, int typeA, int typeB, std::array<mul_mat
         return false;
     }
 
+    if (typeA == GGML_TYPE_IQ4_KT) {
+        if (typeB == GGML_TYPE_Q8_2_X4) {
+            IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq4_kt_q8_2_x4_T, kernels);
+#ifdef HAVE_FANCY_SIMD
+            func16 = mul_mat_iq4_kt_q8_2_x4_T<16>;
+#endif
+            return true;
+        }
+        return false;
+    }
+
     if (ggml_type(typeB) != GGML_TYPE_F32) {
         return false;
     }
@@ -1148,6 +1351,7 @@ bool iqk_set_kernels_ktquants(int ne00, int typeA, int typeB, std::array<mul_mat
 
 bool iqk_dequantize_ktquants(int type, int n, const void * vx, size_t bx, void * y, [[maybe_unused]] size_t stride_y, int nrc_x) {
     switch (type) {
+        case GGML_TYPE_IQ1_KT: iqk_dequantize_iq1_kt_q80_r8(n, vx, bx, y, nrc_x); break;
         case GGML_TYPE_IQ2_KT: iqk_dequantize_iq2_kt_q80_r8(n, vx, bx, y, nrc_x); break;
         case GGML_TYPE_IQ3_KT: iqk_dequantize_iq3_kt_q80_r8(n, vx, bx, y, nrc_x); break;
         case GGML_TYPE_IQ4_KT: iqk_dequantize_iq4_kt_q80_r8(n, vx, bx, y, nrc_x); break;
@@ -1701,6 +1905,27 @@ struct Trellis3 {
         }
         return result;
     }
+    inline int8x16x2_t next32(uint16x4_t val16) const {
+        auto vka3 = vdupq_n_u32(ka3);
+        int8x16x2_t result = {vdupq_n_s8(-126), vdupq_n_s8(-126)};
+        auto val32 = vmovl_u16(val16);
+        uint32x4x4_t aux32 = { vmulq_laneq_u32(mka, val32, 0), vmulq_laneq_u32(mka, val32, 1), vmulq_laneq_u32(mka, val32, 2), vmulq_laneq_u32(mka, val32, 3) };
+        int8x16x2_t i8;
+        auto mask = vdupq_n_u32(0x3f3f3f3f);
+        for (int i = 0; i < 2; ++i) {
+            i8.val[0] = vandq_u32(mask, aux32.val[2*i+0]);
+            i8.val[1] = vandq_u32(mask, vmulq_u32(vka3, aux32.val[2*i+0]));
+            auto s1 = vpaddq_s8(vreinterpretq_s8_u32(i8.val[0]), vreinterpretq_s8_u32(i8.val[1]));
+            i8.val[0] = vandq_u32(mask, aux32.val[2*i+1]);
+            i8.val[1] = vandq_u32(mask, vmulq_u32(vka3, aux32.val[2*i+1]));
+            auto s2 = vpaddq_s8(vreinterpretq_s8_u32(i8.val[0]), vreinterpretq_s8_u32(i8.val[1]));
+            result.val[i] = vaddq_s8(result.val[i], vpaddq_s8(s1, s2));
+            if constexpr (is_abs) {
+                result.val[i] = vreinterpretq_s8_u8(vabsq_s8(result.val[i]));
+            }
+        }
+        return result;
+    }
     inline int8x16x2_t next32(const uint16_t * val, uint32_t v0) const {
         auto vka3 = vdupq_n_u32(ka3);
         int8x16x2_t result = {vdupq_n_s8(-126), vdupq_n_s8(-126)};
@@ -2028,6 +2253,196 @@ void iqk_dequantize_iq3_kt_q80_r8(int n, const void * vx, size_t bx, void * vy,
     }
 }
 
+void iqk_dequantize_iq1_kt_q80_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) {
+    GGML_ASSERT(n%QK_K == 0);
+    GGML_ASSERT(nrc_x%8 == 0);
+    const int nb = n/QK_K;
+
+    Trellis3 trellis;
+
+    auto values = vld1q_s8(iq4k_values);
+
+    block_q8_0_r8 * y = (block_q8_0_r8 *)vy;
+
+    const block_iq1_kt * x8[8];
+    float dkt[8];
+    float ls[8], ls_all[64];
+    uint16_t all_idx[256];
+    uint32_t idx[8];
+
+    for (int ix = 0; ix < nrc_x; ix += 8) {
+        for (int k = 0; k < 8; ++k) {
+            const float * dptr = (const float *)((const char*)vx + (ix+k)*bx);
+            dkt[k] = dptr[0];
+            x8[k] = (const block_iq1_kt *)(dptr + 1);
+        }
+        auto vd = vld1q_f32_x2(dkt);
+
+        for (int i = 0; i < nb; ++i) {
+            for (int k = 0; k < 8; ++k) {
+                auto sh = vld1_u8(x8[k][i].sh);
+                auto s16 = vmovl_s8(vqtbl1_s8(values, vand_u8(sh, vdup_n_u8(0xf))));
+                vst1q_f32(ls_all + 8*k + 0, vcvtq_f32_s32(vmovl_s16(vget_low_s16(s16))));
+                vst1q_f32(ls_all + 8*k + 4, vcvtq_f32_s32(vmovl_s16(vget_high_s16(s16))));
+                auto ql = vld1q_u8_x2(x8[k][i].ql);
+                auto qh = vld1q_u8(x8[k][i].qh);
+                auto qhl = vmovl_u8(vget_low_u8(qh));
+                auto qhh = vmovl_u8(vget_high_u8(qh));
+                uint16x8x4_t idx;
+                idx.val[0] = vaddq_u16(vmovl_u8(vget_low_u8 (ql.val[0])), vandq_u16(vdupq_n_u16(0xf00), vshlq_n_u16(qhl, 8)));
+                idx.val[1] = vaddq_u16(vmovl_u8(vget_high_u8(ql.val[0])), vandq_u16(vdupq_n_u16(0xf00), vshlq_n_u16(qhh, 8)));
+                idx.val[2] = vaddq_u16(vmovl_u8(vget_low_u8 (ql.val[1])), vandq_u16(vdupq_n_u16(0xf00), vshlq_n_u16(qhl, 4)));
+                idx.val[3] = vaddq_u16(vmovl_u8(vget_high_u8(ql.val[1])), vandq_u16(vdupq_n_u16(0xf00), vshlq_n_u16(qhh, 4)));
+                for (int k = 0; k < 4; ++k) idx.val[k] = vaddq_u16(idx.val[k], vdupq_n_u16(4096));
+                auto sh16 = vandq_u16(vmovl_u8(sh), vdupq_n_u16(0xf0));
+                auto sh32l = vandq_u8(vreinterpretq_u8_u32(vmulq_u32(vmovl_u16(vget_low_u16 (sh16)), vdupq_n_u32(0x01020408))), vdupq_n_u8(0x80));
+                auto sh32h = vandq_u8(vreinterpretq_u8_u32(vmulq_u32(vmovl_u16(vget_high_u16(sh16)), vdupq_n_u32(0x01020408))), vdupq_n_u8(0x80));
+                idx.val[0] = vaddq_u16(idx.val[0], vshlq_n_u16(vmovl_u8(vget_low_u8 (sh32l)), 5));
+                idx.val[1] = vaddq_u16(idx.val[1], vshlq_n_u16(vmovl_u8(vget_high_u8(sh32l)), 5));
+                idx.val[2] = vaddq_u16(idx.val[2], vshlq_n_u16(vmovl_u8(vget_low_u8 (sh32h)), 5));
+                idx.val[3] = vaddq_u16(idx.val[3], vshlq_n_u16(vmovl_u8(vget_high_u8(sh32h)), 5));
+                vst1q_u16_x4(all_idx + 32*k, idx);
+            }
+            for (int ib = 0; ib < QK_K/32; ++ib) {
+                for (int k = 0; k < 8; ++k) ls[k] = ls_all[8*k+ib];
+                auto scales1 = vmulq_f32(vd.val[0], vld1q_f32(ls+0));
+                auto scales2 = vmulq_f32(vd.val[1], vld1q_f32(ls+4));
+                vst1_f16((float16_t *)y[ib].d+0, vcvt_f16_f32(scales1));
+                vst1_f16((float16_t *)y[ib].d+4, vcvt_f16_f32(scales2));
+                for (int j = 0; j < 4; ++j) {
+                    for (int k = 0; k < 8; ++k) idx[k] = all_idx[32*k + 4*ib + j];
+                    vst1q_s8_x4(y[ib].qs+64*j, trellis.next64(idx));
+                }
+            }
+            y += 8; // = QK_K/32;
+        }
+    }
+}
+
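The sh32 computation in the AVX2 kernel and the sh32l/sh32h pair above use the same multiply trick to spread the four bits of sh's high nibble into separate byte lanes: v * 0x01020408 equals (v<<3) | (v<<10) | (v<<17) | (v<<24) with no overlapping carries, which parks bit 4+j of v at the top of byte j; masking each byte with 0x80, widening to 16 bits, and shifting left by 5 then yields the 0x1000 index bit per position. A scalar sketch of just the bit-spreading step (illustration only):

#include <cstdint>

// Hypothetical helper, not from the patch: byte j of the result has its MSB set
// exactly when bit 4+j of sh is set.
static void spread_high_nibble(uint8_t sh, uint8_t out[4]) {
    uint32_t v = sh & 0xf0u;
    uint32_t p = v * 0x01020408u;                 // == (v<<3)|(v<<10)|(v<<17)|(v<<24)
    for (int j = 0; j < 4; ++j)
        out[j] = uint8_t((p >> (8*j)) & 0x80u);   // 0x80 or 0x00 per byte
}
// Example: sh = 0xA0 (bits 5 and 7 set) gives out = {0x00, 0x80, 0x00, 0x80};
// widening each byte to 16 bits and shifting left by 5 turns 0x80 into 0x1000.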
+template <int nrc_y>
+void mul_mat_iq1_kt_q8_0_x4_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+    assert(n%QK_K == 0);
+    const int nb = n/QK_K;
+
+    Trellis3 trellis;
+
+    auto values = vld1q_s8(iq4k_values);
+
+    constexpr int k_acc = nrc_y == 1 ? 2 : nrc_y;
+
+    float32x4_t accd[k_acc];
+
+    const block_q8_0_x4 * y[nrc_y];
+    for (int iy = 0; iy < nrc_y; ++iy) {
+        y[iy] = (const block_q8_0_x4 *)info.src1_row(iy);
+    }
+
+    int8x16x2_t xv[8];
+    uint16x8x4_t idx;
+    int32x4x4_t dot;
+
+    auto compute_dot = [&dot] (const int8_t * y, const int8x16x2_t * xv) {
+        for (int k = 0; k < 4; ++k) {
+            auto yv = vld1q_s8_x2(y + 32*k);
+            dot.val[k] = vdotq_s32(vdotq_s32(vdupq_n_s32(0), xv[k].val[0], yv.val[0]), xv[k].val[1], yv.val[1]);
+        }
+        dot.val[0] = vpaddq_s32(dot.val[0], dot.val[1]);
+        dot.val[2] = vpaddq_s32(dot.val[2], dot.val[3]);
+        return vpaddq_s32(dot.val[0], dot.val[2]);
+    };
+
+    float32x4x2_t scales;
+
+    for (int ix = 0; ix < nrc_x; ++ix) {
+        const float * dptr = (const float *)((const char*)vx + ix*bx);
+        auto d = vdupq_n_f32(dptr[0]);
+        const block_iq1_kt * x = (const block_iq1_kt *)(dptr + 1);
+
+        for (int iy = 0; iy < k_acc; ++iy) accd[iy] = vdupq_n_f32(0);
+
+        for (int i = 0; i < nb; ++i) {
+            auto sh = vld1_u8(x[i].sh);
+            auto s16 = vmovl_s8(vqtbl1_s8(values, vand_u8(sh, vdup_n_u8(0xf))));
+            scales.val[0] = vmulq_f32(d, vcvtq_f32_s32(vmovl_s16(vget_low_s16 (s16))));
+            scales.val[1] = vmulq_f32(d, vcvtq_f32_s32(vmovl_s16(vget_high_s16(s16))));
+            auto ql = vld1q_u8_x2(x[i].ql);
+            auto qh = vld1q_u8(x[i].qh);
+            auto qhl = vmovl_u8(vget_low_u8(qh));
+            auto qhh = vmovl_u8(vget_high_u8(qh));
+            idx.val[0] = vaddq_u16(vmovl_u8(vget_low_u8 (ql.val[0])), vandq_u16(vdupq_n_u16(0xf00), vshlq_n_u16(qhl, 8)));
+            idx.val[1] = vaddq_u16(vmovl_u8(vget_high_u8(ql.val[0])), vandq_u16(vdupq_n_u16(0xf00), vshlq_n_u16(qhh, 8)));
+            idx.val[2] = vaddq_u16(vmovl_u8(vget_low_u8 (ql.val[1])), vandq_u16(vdupq_n_u16(0xf00), vshlq_n_u16(qhl, 4)));
+            idx.val[3] = vaddq_u16(vmovl_u8(vget_high_u8(ql.val[1])), vandq_u16(vdupq_n_u16(0xf00), vshlq_n_u16(qhh, 4)));
+            for (int k = 0; k < 4; ++k) idx.val[k] = vaddq_u16(idx.val[k], vdupq_n_u16(4096));
+            auto sh16 = vandq_u16(vmovl_u8(sh), vdupq_n_u16(0xf0));
+            auto sh32l = vandq_u8(vreinterpretq_u8_u32(vmulq_u32(vmovl_u16(vget_low_u16 (sh16)), vdupq_n_u32(0x01020408))), vdupq_n_u8(0x80));
+            auto sh32h = vandq_u8(vreinterpretq_u8_u32(vmulq_u32(vmovl_u16(vget_high_u16(sh16)), vdupq_n_u32(0x01020408))), vdupq_n_u8(0x80));
+            idx.val[0] = vaddq_u16(idx.val[0], vshlq_n_u16(vmovl_u8(vget_low_u8 (sh32l)), 5));
+            idx.val[1] = vaddq_u16(idx.val[1], vshlq_n_u16(vmovl_u8(vget_high_u8(sh32l)), 5));
+            idx.val[2] = vaddq_u16(idx.val[2], vshlq_n_u16(vmovl_u8(vget_low_u8 (sh32h)), 5));
+            idx.val[3] = vaddq_u16(idx.val[3], vshlq_n_u16(vmovl_u8(vget_high_u8(sh32h)), 5));
+            if constexpr (nrc_y == 1) {
+                const block_q8_0_x4& ybl = y[0][2*i+0];
+                const block_q8_0_x4& ybh = y[0][2*i+1];
+                auto dyl = vmulq_f32(scales.val[0], vcvt_f32_f16(vld1_f16((const float16_t *)ybl.d)));
+                auto dyh = vmulq_f32(scales.val[1], vcvt_f32_f16(vld1_f16((const float16_t *)ybh.d)));
+                int32x4x4_t suml = {};
+                int32x4x4_t sumh = {};
+                for (int ib = 0; ib < 2; ++ib) {
+                    auto xl = trellis.next32(vget_low_u16(idx.val[ib+0]));
+                    auto xh = trellis.next32(vget_low_u16(idx.val[ib+2]));
+                    auto yl = vld1q_s8_x2(ybl.qs + 64*ib);
+                    auto yh = vld1q_s8_x2(ybh.qs + 64*ib);
+                    suml.val[2*ib+0] = vdotq_s32(vdotq_s32(vdupq_n_s32(0), xl.val[0], yl.val[0]), xl.val[1], yl.val[1]);
+                    sumh.val[2*ib+0] = vdotq_s32(vdotq_s32(vdupq_n_s32(0), xh.val[0], yh.val[0]), xh.val[1], yh.val[1]);
+                    xl = trellis.next32(vget_high_u16(idx.val[ib+0]));
+                    xh = trellis.next32(vget_high_u16(idx.val[ib+2]));
+                    yl = vld1q_s8_x2(ybl.qs + 64*ib + 32);
+                    yh = vld1q_s8_x2(ybh.qs + 64*ib + 32);
+                    suml.val[2*ib+1] = vdotq_s32(vdotq_s32(vdupq_n_s32(0), xl.val[0], yl.val[0]), xl.val[1], yl.val[1]);
+                    sumh.val[2*ib+1] = vdotq_s32(vdotq_s32(vdupq_n_s32(0), xh.val[0], yh.val[0]), xh.val[1], yh.val[1]);
+                }
+                auto sl1 = vpaddq_s32(suml.val[0], suml.val[1]);
+                auto sl2 = vpaddq_s32(suml.val[2], suml.val[3]);
+                auto sl = vpaddq_s32(sl1, sl2);
+                auto sh1 = vpaddq_s32(sumh.val[0], sumh.val[1]);
+                auto sh2 = vpaddq_s32(sumh.val[2], sumh.val[3]);
+                auto sh = vpaddq_s32(sh1, sh2);
+                accd[0] = vfmaq_f32(accd[0], dyl, vcvtq_f32_s32(sl));
+                accd[1] = vfmaq_f32(accd[1], dyh, vcvtq_f32_s32(sh));
+            } else {
+                for (int k = 0; k < 4; ++k) {
+                    xv[2*k+0] = trellis.next32(vget_low_u16 (idx.val[k]));
+                    xv[2*k+1] = trellis.next32(vget_high_u16(idx.val[k]));
+                }
+                for (int iy = 0; iy < nrc_y; ++iy) {
+                    const block_q8_0_x4& ybl = y[iy][2*i+0];
+                    const block_q8_0_x4& ybh = y[iy][2*i+1];
+                    auto dyl = vmulq_f32(scales.val[0], vcvt_f32_f16(vld1_f16((const float16_t *)ybl.d)));
+                    auto dyh = vmulq_f32(scales.val[1], vcvt_f32_f16(vld1_f16((const float16_t *)ybh.d)));
+                    auto sumil = compute_dot(ybl.qs, xv+0);
+                    auto sumih = compute_dot(ybh.qs, xv+4);
+                    if constexpr (nrc_y == 1) {
+                        accd[2*iy+0] = vfmaq_f32(accd[2*iy+0], dyl, vcvtq_f32_s32(sumil));
+                        accd[2*iy+1] = vfmaq_f32(accd[2*iy+1], dyh, vcvtq_f32_s32(sumih));
+                    } else {
+                        accd[iy] = vfmaq_f32(accd[iy], dyl, vcvtq_f32_s32(sumil));
+                        accd[iy] = vfmaq_f32(accd[iy], dyh, vcvtq_f32_s32(sumih));
+                    }
+                }
+            }
+        }
+
+        if constexpr (nrc_y == 1) {
+            info.store(ix, 0, vaddvq_f32(vaddq_f32(accd[0], accd[1])));
+        } else {
+            for (int iy = 0; iy < nrc_y; ++iy) {
+                info.store(ix, iy, vaddvq_f32(accd[iy]));
+            }
+        }
+    }
+}
+
 template <int nrc_y>
 void mul_mat_iq2_kt_q8_0_x4_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
     assert(n%QK_K == 0);
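As a cross-check on what one iteration of the i-loop accumulates, here is a hedged scalar reference for half a superblock (128 weights). It assumes block_q8_0_x4 packs four 32-weight q8_0 sub-blocks (four fp16 scales followed by 128 int8 quants) and that xs[ib] is the dequantized iq1_kt sub-block scale (row scale d times iq4k_values[sh & 0xf]). Illustration only, not code from the patch:

#include <cstdint>

// xq holds the 128 int8 trellis values for this half-superblock; yd/yq are the four
// dequantized scales and 128 quants of one block_q8_0_x4.
static float iq1_kt_dot_128(const float * xs, const int8_t * xq, const float * yd, const int8_t * yq) {
    float acc = 0.0f;
    for (int ib = 0; ib < 4; ++ib) {
        int32_t isum = 0;
        for (int j = 0; j < 32; ++j) isum += int32_t(xq[32*ib + j]) * int32_t(yq[32*ib + j]);
        acc += xs[ib] * yd[ib] * float(isum);   // scales applied once per 32-weight sub-block
    }
    return acc;
}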
@@ -2284,6 +2699,15 @@ bool iqk_set_kernels_ktquants(int ne00, int typeA, int typeB, std::array<mul_mat
         return false;
     }
 
+    if (ggml_type(typeA) == GGML_TYPE_IQ1_KT) {
+        if (ggml_type(typeB) == GGML_TYPE_Q8_0_X4) {
+            IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq1_kt_q8_0_x4_T, kernels);
+            func16 = nullptr;
+            return true;
+        }
+        return false;
+    }
+
     if (ggml_type(typeB) != GGML_TYPE_F16) {
         return false;
     }
@@ -2309,6 +2733,7 @@ bool iqk_set_kernels_ktquants(int ne00, int typeA, int typeB, std::array<mul_mat
 
 bool iqk_dequantize_ktquants(int type, int n, const void * vx, size_t bx, void * y, [[maybe_unused]] size_t stride_y, int nrc_x) {
     switch (type) {
+        case GGML_TYPE_IQ1_KT: iqk_dequantize_iq1_kt_q80_r8(n, vx, bx, y, nrc_x); break;
         case GGML_TYPE_IQ2_KT: iqk_dequantize_iq2_kt_q80_r8(n, vx, bx, y, nrc_x); break;
         case GGML_TYPE_IQ3_KT: iqk_dequantize_iq3_kt_q80_r8(n, vx, bx, y, nrc_x); break;
        case GGML_TYPE_IQ4_KT: iqk_dequantize_iq4_kt_q80_r8(n, vx, bx, y, nrc_x); break;