diff options
Diffstat (limited to 'ggml/src/iqk/iqk_gemm_ktquants.cpp')
-rw-r--r-- | ggml/src/iqk/iqk_gemm_ktquants.cpp | 356 |
1 files changed, 335 insertions, 21 deletions
diff --git a/ggml/src/iqk/iqk_gemm_ktquants.cpp b/ggml/src/iqk/iqk_gemm_ktquants.cpp index 38e76e1e..bc7bcf8b 100644 --- a/ggml/src/iqk/iqk_gemm_ktquants.cpp +++ b/ggml/src/iqk/iqk_gemm_ktquants.cpp @@ -13,7 +13,7 @@ namespace { -static inline uint32_t trellis_next(uint32_t& val) { +inline uint32_t trellis_next(uint32_t& val) { constexpr uint32_t ka = 89226354; constexpr uint32_t kb = 64248484; constexpr uint32_t kmask = 0x8fff8fff; @@ -22,7 +22,7 @@ static inline uint32_t trellis_next(uint32_t& val) { return (val & kmask) ^ km32; } -static inline float trellis_gen(uint32_t& val, uint32_t* s) { +inline float trellis_gen(uint32_t& val, uint32_t* s) { const ggml_fp16_t * h = (const ggml_fp16_t *)s; s[0] = trellis_next(val); return GGML_FP16_TO_FP32(h[0]) + GGML_FP16_TO_FP32(h[1]); @@ -59,7 +59,7 @@ struct Trellis1 { } }; -static inline __m256 trellis_gen8(__m256i i8) { +inline __m256 trellis_gen8(__m256i i8) { // split upper and lower bits of each 32-bit lane into two 8xfloat16 `hlo`, `hhi` __m256i low_16_bits_mask = _mm256_set1_epi32(0x0000FFFF); __m256i lower_halves_lanes32 = _mm256_and_si256(i8, low_16_bits_mask); @@ -97,8 +97,47 @@ struct Trellis2 { } }; +void iqk_dequantize_iq2_kt(int n, const void * vx, size_t bx, float * y, size_t stride_y, int nrc_x) { + GGML_ASSERT(n%QK_K == 0); + const int nb = n/QK_K; + + Trellis1 trellis; + + auto shifts = _mm_set_epi32(0, 0, 4, 0); + auto values = _mm_loadu_si128((const __m128i *)iq4k_values); + + union { __m256 vec; float val[8]; } s_helper; + + for (int ix = 0; ix < nrc_x; ++ix) { + const float * dptr = (const float *)((const char*)vx + ix*bx); + auto d = _mm256_set1_ps(*dptr * 31.75f * 1.05f); + const block_iq2_kt * x = (const block_iq2_kt *)(dptr + 1); + + for (int i = 0; i < nb; ++i) { + const uint16_t * ql = (const uint16_t *)x[i].ql; + auto s8 = _mm_set1_epi32(*(const uint32_t *)x[i].scales); + s8 = _mm_and_si128(_mm_srlv_epi32(s8, shifts), _mm_set1_epi8(0xf)); + s8 = _mm_shuffle_epi8(values, s8); + auto s32 = _mm256_cvtepi8_epi32(s8); + s_helper.vec = _mm256_mul_ps(d, _mm256_cvtepi32_ps(s32)); + for (int ib = 0; ib < QK_K/64; ++ib) { + auto scale1 = _mm256_set1_ps(s_helper.val[2*ib+0]); + auto scale2 = _mm256_set1_ps(s_helper.val[2*ib+1]); + for (int j = 0; j < 4; ++j) { + auto xval1 = _mm256_mul_ps(scale1, trellis_gen8(trellis.next8(ql[8*ib+j+0]+4096))); + auto xval2 = _mm256_mul_ps(scale2, trellis_gen8(trellis.next8(ql[8*ib+j+4]+4096))); + _mm256_storeu_ps(y + i*QK_K + 64*ib + 8*j + 0, xval1); + _mm256_storeu_ps(y + i*QK_K + 64*ib + 8*j + 32, xval2); + } + } + } + + y += stride_y; + } +} + template <int nrc_y> -static void mul_mat_iq2_kt_F32_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { +void mul_mat_iq2_kt_F32_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { assert(n%QK_K == 0); const int nb = n/QK_K; @@ -159,14 +198,63 @@ static void mul_mat_iq2_kt_F32_T(int n, const void * vx, size_t bx, const DataIn } } -static inline __m256 abs_ps(__m256 vals) { +inline __m256 abs_ps(__m256 vals) { // Clear sign-bit of all the 32-bit floats in vals __m256 sign_bit = _mm256_set1_ps(-0.0f); return _mm256_andnot_ps(sign_bit, vals); } +void iqk_dequantize_iq3_kt(int n, const void * vx, size_t bx, float * y, size_t stride_y, int nrc_x) { + GGML_ASSERT(n%QK_K == 0); + const int nb = n/QK_K; + + Trellis1 trellis; + + union { __m256 vec; float val[8]; } s_helper; + + auto shifts = _mm_set_epi32(0, 0, 4, 0); + + __m256i all_signs[4]; + auto mask1 = _mm256_set1_epi32(0x01); + auto mask2 = _mm256_set1_epi32(0x10); + + for (int ix = 0; ix < nrc_x; ++ix) { + const float * dptr = (const float *)((const char*)vx + ix*bx); + auto d = _mm256_set1_ps(*dptr * 31.75f * 1.015f); + const block_iq3_kt * x = (const block_iq3_kt *)(dptr + 1); + + for (int i = 0; i < nb; ++i) { + const uint16_t * ql = (const uint16_t *)x[i].ql; + const uint8_t * qh = x[i].qh; + auto s8 = _mm_set1_epi32(*(const uint32_t *)x[i].scales); + s8 = _mm_and_si128(_mm_srlv_epi32(s8, shifts), _mm_set1_epi8(0xf)); + auto s32 = _mm256_cvtepi8_epi32(s8); + s_helper.vec = _mm256_mul_ps(d, _mm256_cvtepi32_ps(s32)); + for (int j = 0; j < 4; ++j) all_signs[j] = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i *)(qh + 8*j))); + for (int ib = 0; ib < 4; ++ib) { + auto scale1 = _mm256_set1_ps(s_helper.val[ib+0]); + auto scale2 = _mm256_set1_ps(s_helper.val[ib+4]); + for (int j = 0; j < 4; ++j) { + uint32_t val1 = ql[4*ib+j ] + 4096; + uint32_t val2 = ql[4*ib+j+16] + 4096; + auto sign1 = _mm256_and_si256(_mm256_cmpeq_epi32(_mm256_and_si256(all_signs[j], mask1), mask1), _mm256_set1_epi32(0x80000000)); + auto sign2 = _mm256_and_si256(_mm256_cmpeq_epi32(_mm256_and_si256(all_signs[j], mask2), mask2), _mm256_set1_epi32(0x80000000)); + all_signs[j] = _mm256_srli_epi32(all_signs[j], 1); + auto x_val1 = abs_ps(trellis_gen8(trellis.next8(val1))); + auto x_val2 = abs_ps(trellis_gen8(trellis.next8(val2))); + x_val1 = _mm256_mul_ps(scale1, _mm256_xor_ps(x_val1, _mm256_castsi256_ps(sign1))); + x_val2 = _mm256_mul_ps(scale2, _mm256_xor_ps(x_val2, _mm256_castsi256_ps(sign2))); + _mm256_storeu_ps(y + i*QK_K+32*ib+8*j , x_val1); + _mm256_storeu_ps(y + i*QK_K+32*ib+8*j+128, x_val2); + } + } + } + y += stride_y; + } +} + template <int nrc_y> -static void mul_mat_iq3_kt_F32_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { +void mul_mat_iq3_kt_F32_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { assert(n%QK_K == 0); const int nb = n/QK_K; @@ -227,8 +315,57 @@ static void mul_mat_iq3_kt_F32_T(int n, const void * vx, size_t bx, const DataIn } } +void iqk_dequantize_iq4_kt(int n, const void * vx, size_t bx, float * y, size_t stride_y, int nrc_x) { + GGML_ASSERT(n%QK_K == 0); + const int nb = n/QK_K; + constexpr int kNumGroups = 64; + + Trellis2 trellis; + + union { __m256 vec; float val[8]; } s_helper; + union { __m256i vec; uint32_t val[8]; } o_helper; + + for (int ix = 0; ix < nrc_x; ++ix) { + const float * dptr = (const float *)((const char*)vx + ix*bx); + auto d = _mm256_set1_ps(dptr[0] * 31.75f * 1.01f); + auto dav = _mm256_set1_ps(dptr[1]); + const block_iq4_kt * x = (const block_iq4_kt *)(dptr + 2); + + for (int i = 0; i < nb; ++i) { + auto vshb = _mm256_loadu_si256((const __m256i *)x[i].qs); + const uint32_t * shb = x[i].qs; + const uint8_t * ql = (const uint8_t *)(shb + 8); + const uint8_t * qh = ql + kNumGroups; + auto iscales = _mm256_srli_epi32(_mm256_and_si256(vshb, _mm256_set1_epi32(0xff)), 1); + s_helper.vec = _mm256_mul_ps(d, _mm256_cvtepi32_ps(_mm256_sub_epi32(iscales, _mm256_set1_epi32(64)))); + o_helper.vec = _mm256_add_epi32(_mm256_slli_epi32(_mm256_and_si256(vshb, _mm256_set1_epi32(1)), 15), _mm256_set1_epi32(4096)); + for (int ib = 0; ib < 4; ++ib) { + auto scale1 = _mm256_set1_ps(s_helper.val[ib+0]); + auto scale2 = _mm256_set1_ps(s_helper.val[ib+4]); + for (int j = 0; j < 4; ++j) { + const uint32_t sh1 = shb[ib+0] >> (8 + 6*j); + const uint32_t sh2 = shb[ib+4] >> (8 + 6*j); + uint32_t val1 = ql[8*ib+2*j+ 0] + ((qh[8*ib+2*j+0] << 8) & 0xf00) + ((sh1 & 7) << 12) + o_helper.val[ib+0]; + uint32_t val2 = ql[8*ib+2*j+32] + ((qh[8*ib+2*j+0] << 4) & 0xf00) + ((sh2 & 7) << 12) + o_helper.val[ib+4]; + uint32_t val3 = ql[8*ib+2*j+ 1] + ((qh[8*ib+2*j+1] << 8) & 0xf00) + ((sh1 & 56) << 9) + o_helper.val[ib+0]; + uint32_t val4 = ql[8*ib+2*j+33] + ((qh[8*ib+2*j+1] << 4) & 0xf00) + ((sh2 & 56) << 9) + o_helper.val[ib+4]; + auto x_val1 = _mm256_fmadd_ps(scale1, trellis_gen8(trellis.next8(val1, val3)), dav); + auto x_val2 = _mm256_fmadd_ps(scale2, trellis_gen8(trellis.next8(val2, val4)), dav); + + _mm256_storeu_ps(y + i*QK_K + 32*ib + 8*j, x_val1); + _mm256_storeu_ps(y + i*QK_K + 32*ib + 8*j + QK_K/2, x_val2); + + } + } + } + + y += stride_y; + + } +} + template <int nrc_y> -static void mul_mat_iq4_kt_F32_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { +void mul_mat_iq4_kt_F32_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { assert(n%QK_K == 0); const int nb = n/QK_K; constexpr int kNumGroups = 64; @@ -333,6 +470,16 @@ bool iqk_set_kernels_ktquants(int ne00, int typeA, int typeB, std::array<mul_mat } +bool iqk_dequantize_ktquants(int type, int n, const void * vx, size_t bx, void * y, size_t stride_y, int nrc_x) { + switch (type) { + case GGML_TYPE_IQ2_KT: iqk_dequantize_iq2_kt(n, vx, bx, (float *)y, stride_y, nrc_x); break; + case GGML_TYPE_IQ3_KT: iqk_dequantize_iq3_kt(n, vx, bx, (float *)y, stride_y, nrc_x); break; + case GGML_TYPE_IQ4_KT: iqk_dequantize_iq4_kt(n, vx, bx, (float *)y, stride_y, nrc_x); break; + default: return false; + } + return true; +} + #else // !__x86_64__ namespace { @@ -403,8 +550,52 @@ struct Trellis1 { } }; +void iqk_dequantize_iq2_kt(int n, const void * vx, size_t bx, float16_t * y, size_t stride_y, int nrc_x) { + GGML_ASSERT(n%QK_K == 0); + const int nb = n/QK_K; + + Trellis1 trellis; + + auto values = vld1q_s8(iq4k_values); + + union { float16x8_t vec; float16_t val[8]; } s_helper; + + for (int ix = 0; ix < nrc_x; ++ix) { + const float * dptr = (const float *)((const char*)vx + ix*bx); + const float d = *dptr * 31.75f * 1.05f; + auto vd = vdupq_n_f32(d); + const block_iq2_kt * x = (const block_iq2_kt *)(dptr + 1); + + for (int i = 0; i < nb; ++i) { + const uint16_t * ql = (const uint16_t *)x[i].ql; + auto u32 = *(const uint32_t *)x[i].scales; + auto s8_u32 = uint32x2_t{u32, u32 >> 4}; + s8_u32 = vand_u8(s8_u32, vdup_n_u32(0x0f0f0f0f)); + auto s8 = vqtbl1_s8(values, vreinterpret_u8_u32(s8_u32)); + auto s16 = vmovl_s8(s8); + auto s32l = vmovl_s16(vget_low_s16 (s16)); + auto s32h = vmovl_s16(vget_high_s16(s16)); + auto f32l = vmulq_f32(vd, vcvtq_f32_s32(s32l)); + auto f32h = vmulq_f32(vd, vcvtq_f32_s32(s32h)); + s_helper.vec = vcombine_f16(vcvt_f16_f32(f32l), vcvt_f16_f32(f32h)); + for (int ib = 0; ib < QK_K/64; ++ib) { + auto scale1 = vdupq_n_f16(s_helper.val[2*ib+0]); + auto scale2 = vdupq_n_f16(s_helper.val[2*ib+1]); + for (int j = 0; j < 4; ++j) { + auto xval1 = vmulq_f16(scale1, trellis.gen8(ql[8*ib+j+0]+4096)); + auto xval2 = vmulq_f16(scale2, trellis.gen8(ql[8*ib+j+4]+4096)); + vst1q_f16(y + i*QK_K + 64*ib + 8*j + 0, xval1); + vst1q_f16(y + i*QK_K + 64*ib + 8*j + 32, xval2); + } + } + } + + y += stride_y; + } +} + template <int nrc_y> -static void mul_mat_iq2_kt_F16_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { +void mul_mat_iq2_kt_F16_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { assert(n%QK_K == 0); const int nb = n/QK_K; @@ -466,8 +657,61 @@ static void mul_mat_iq2_kt_F16_T(int n, const void * vx, size_t bx, const DataIn } } +void iqk_dequantize_iq3_kt(int n, const void * vx, size_t bx, float16_t * y, size_t stride_y, int nrc_x) { + GGML_ASSERT(n%QK_K == 0); + const int nb = n/QK_K; + + Trellis1 trellis; + + union { float16x8_t vec; float16_t val[8]; } s_helper; + + uint16x8_t all_signs[4]; + auto mask1 = vdupq_n_u16(0x01); + auto mask2 = vdupq_n_u16(0x10); + + for (int ix = 0; ix < nrc_x; ++ix) { + const float * dptr = (const float *)((const char*)vx + ix*bx); + const float d = *dptr * 31.75f * 1.015f; + auto vd = vdupq_n_f32(d); + const block_iq3_kt * x = (const block_iq3_kt *)(dptr + 1); + + for (int i = 0; i < nb; ++i) { + const uint16_t * ql = (const uint16_t *)x[i].ql; + const uint8_t * qh = x[i].qh; + auto u32 = *(const uint32_t *)x[i].scales; + auto s8_u32 = uint32x2_t{u32, u32 >> 4}; + s8_u32 = vand_u8(s8_u32, vdup_n_u32(0x0f0f0f0f)); + auto s16 = vmovl_s8(vreinterpret_s8_u32(s8_u32)); + auto s32l = vmovl_s16(vget_low_s16 (s16)); + auto s32h = vmovl_s16(vget_high_s16(s16)); + auto f32l = vmulq_f32(vd, vcvtq_f32_s32(s32l)); + auto f32h = vmulq_f32(vd, vcvtq_f32_s32(s32h)); + s_helper.vec = vcombine_f16(vcvt_f16_f32(f32l), vcvt_f16_f32(f32h)); + for (int j = 0; j < 4; ++j) all_signs[j] = vmovl_u8(vld1_u8(qh + 8*j)); + for (int ib = 0; ib < 4; ++ib) { + auto scale1 = vdupq_n_f16(s_helper.val[ib+0]); + auto scale2 = vdupq_n_f16(s_helper.val[ib+4]); + for (int j = 0; j < 4; ++j) { + uint32_t val1 = ql[4*ib+j ] + 4096; + uint32_t val2 = ql[4*ib+j+16] + 4096; + auto sign1 = vshlq_n_u16(vandq_u16(all_signs[j], mask1), 15); + auto sign2 = vshlq_n_u16(vandq_u16(all_signs[j], mask2), 11); + all_signs[j] = vshrq_n_u16(all_signs[j], 1); + auto x_val1 = vabsq_f16(trellis.gen8(val1)); + auto x_val2 = vabsq_f16(trellis.gen8(val2)); + x_val1 = vmulq_f16(scale1, vreinterpretq_f16_u16(vorrq_u16(vreinterpretq_u16_f16(x_val1), sign1))); + x_val2 = vmulq_f16(scale2, vreinterpretq_f16_u16(vorrq_u16(vreinterpretq_u16_f16(x_val2), sign2))); + vst1q_f16(y + i*QK_K+32*ib+8*j , x_val1); + vst1q_f16(y + i*QK_K+32*ib+8*j+128, x_val2); + } + } + } + y += stride_y; + } +} + template <int nrc_y> -static void mul_mat_iq3_kt_F16_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { +void mul_mat_iq3_kt_F16_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { assert(n%QK_K == 0); const int nb = n/QK_K; @@ -527,8 +771,63 @@ static void mul_mat_iq3_kt_F16_T(int n, const void * vx, size_t bx, const DataIn } } +void iqk_dequantize_iq4_kt(int n, const void * vx, size_t bx, float16_t * y, size_t stride_y, int nrc_x) { + GGML_ASSERT(n%QK_K == 0); + const int nb = n/QK_K; + constexpr int kNumGroups = 64; + + Trellis1 trellis; + + union { float16x8_t vec; float16_t val[8]; } s_helper; + union { uint16x8_t vec; uint16_t val[8]; } o_helper; + + for (int ix = 0; ix < nrc_x; ++ix) { + const float * dptr = (const float *)((const char*)vx + ix*bx); + auto d = dptr[0] * 31.75f * 1.01f; + //auto dav = dptr[1]; + // Something goes wrong when we add the average. Why? + //auto vav = std::abs(dav) > 0.00006103515625f ? vdupq_n_f16(GGML_FP32_TO_FP16(dav)) : vdupq_n_f16(0); + auto vd = vdupq_n_f32(d); + const block_iq4_kt * x = (const block_iq4_kt *)(dptr + 2); + + for (int i = 0; i < nb; ++i) { + const uint32_t * shb = x[i].qs; + auto vshb = vld1q_u32_x2(shb); + auto vshb16 = vcombine_u16(vmovn_u32(vandq_u32(vshb.val[0], vdupq_n_u32(0xff))), vmovn_u32(vandq_u32(vshb.val[1], vdupq_n_u32(0xff)))); + const uint8_t * ql = (const uint8_t *)(shb + 8); + const uint8_t * qh = ql + kNumGroups; + auto iscales = vsubq_s16(vreinterpretq_s16_u16(vshrq_n_u16(vshb16, 1)), vdupq_n_s16(64)); + auto s32l = vmovl_s16(vget_low_s16(iscales)); + auto s32h = vmovl_s16(vget_high_s16(iscales)); + auto f32l = vmulq_f32(vd, vcvtq_f32_s32(s32l)); + auto f32h = vmulq_f32(vd, vcvtq_f32_s32(s32h)); + s_helper.vec = vcombine_f16(vcvt_f16_f32(f32l), vcvt_f16_f32(f32h)); + o_helper.vec = vaddq_u16(vshlq_n_u16(vandq_u16(vshb16, vdupq_n_u16(1)), 15), vdupq_n_u16(4096)); + for (int ib = 0; ib < 4; ++ib) { + auto scale1 = vdupq_n_f16(s_helper.val[ib+0]); + auto scale2 = vdupq_n_f16(s_helper.val[ib+4]); + for (int j = 0; j < 4; ++j) { + const uint32_t sh1 = shb[ib+0] >> (8 + 6*j); + const uint32_t sh2 = shb[ib+4] >> (8 + 6*j); + uint32_t val1 = ql[8*ib+2*j+ 0] + ((qh[8*ib+2*j+0] << 8) & 0xf00) + ((sh1 & 7) << 12) + o_helper.val[ib+0]; + uint32_t val2 = ql[8*ib+2*j+32] + ((qh[8*ib+2*j+0] << 4) & 0xf00) + ((sh2 & 7) << 12) + o_helper.val[ib+4]; + uint32_t val3 = ql[8*ib+2*j+ 1] + ((qh[8*ib+2*j+1] << 8) & 0xf00) + ((sh1 & 56) << 9) + o_helper.val[ib+0]; + uint32_t val4 = ql[8*ib+2*j+33] + ((qh[8*ib+2*j+1] << 4) & 0xf00) + ((sh2 & 56) << 9) + o_helper.val[ib+4]; + //auto x_val1 = vfmaq_f16(vav, scale1, trellis.gen8(val1, val3)); + //auto x_val2 = vfmaq_f16(vav, scale2, trellis.gen8(val2, val4)); + auto x_val1 = vmulq_f16(scale1, trellis.gen8(val1, val3)); + auto x_val2 = vmulq_f16(scale2, trellis.gen8(val2, val4)); + vst1q_f16(y + i*QK_K+32*ib+8*j+ 0, x_val1); + vst1q_f16(y + i*QK_K+32*ib+8*j+128, x_val2); + } + } + } + y += stride_y; + } +} + template <int nrc_y> -static void mul_mat_iq4_kt_F16_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { +void mul_mat_iq4_kt_F16_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { assert(n%QK_K == 0); const int nb = n/QK_K; constexpr int kNumGroups = 64; @@ -548,8 +847,6 @@ static void mul_mat_iq4_kt_F16_T(int n, const void * vx, size_t bx, const DataIn auto sum = vdupq_n_f16(0); for (int i = 0; i < n/8; ++i) sum = vaddq_f16(sum, vld1q_f16(y[iy] + 8*i)); auto sum32 = vaddq_f32(vcvt_f32_f16(vget_low_f16(sum)), vcvt_f32_f16(vget_high_f16(sum))); - //auto sum32 = vdupq_n_f32(0); - //for (int i = 0; i < n/4; ++i) sum32 = vaddq_f32(sum32, vcvt_f32_f16(vld1_f16(y[iy] + 4*i))); row_sum[iy] = vaddvq_f32(sum32); } @@ -557,6 +854,7 @@ static void mul_mat_iq4_kt_F16_T(int n, const void * vx, size_t bx, const DataIn const float * dptr = (const float *)((const char*)vx + ix*bx); auto d = dptr[0] * 31.75f * 1.01f; auto dav = dptr[1]; + auto vd = vdupq_n_f32(d); const block_iq4_kt * x = (const block_iq4_kt *)(dptr + 2); for (int iy = 0; iy < k_acc; ++iy) accd[iy] = vdupq_n_f16(0); @@ -568,7 +866,12 @@ static void mul_mat_iq4_kt_F16_T(int n, const void * vx, size_t bx, const DataIn const uint8_t * ql = (const uint8_t *)(shb + 8); const uint8_t * qh = ql + kNumGroups; auto iscales = vsubq_s16(vreinterpretq_s16_u16(vshrq_n_u16(vshb16, 1)), vdupq_n_s16(64)); - s_helper.vec = vcvtq_f16_s16(iscales); + auto s32l = vmovl_s16(vget_low_s16(iscales)); + auto s32h = vmovl_s16(vget_high_s16(iscales)); + auto f32l = vmulq_f32(vd, vcvtq_f32_s32(s32l)); + auto f32h = vmulq_f32(vd, vcvtq_f32_s32(s32h)); + s_helper.vec = vcombine_f16(vcvt_f16_f32(f32l), vcvt_f16_f32(f32h)); + //s_helper.vec = vcvtq_f16_s16(iscales); o_helper.vec = vaddq_u16(vshlq_n_u16(vandq_u16(vshb16, vdupq_n_u16(1)), 15), vdupq_n_u16(4096)); for (int ib = 0; ib < 4; ++ib) { auto scale1 = vdupq_n_f16(s_helper.val[ib+0]); @@ -602,18 +905,18 @@ static void mul_mat_iq4_kt_F16_T(int n, const void * vx, size_t bx, const DataIn if constexpr (nrc_y == 1) { auto sum16 = vaddq_f16(accd[0], accd[1]); auto sum = vaddq_f32(vcvt_f32_f16(vget_low_f16(sum16)), vcvt_f32_f16(vget_high_f16(sum16))); - info.store(ix, 0, d*vaddvq_f32(sum) + dav*row_sum[0]); + info.store(ix, 0, vaddvq_f32(sum) + dav*row_sum[0]); } else { for (int iy = 0; iy < nrc_y; ++iy) { auto sum = vaddq_f32(vcvt_f32_f16(vget_low_f16(accd[iy])), vcvt_f32_f16(vget_high_f16(accd[iy]))); - info.store(ix, iy, d*vaddvq_f32(sum) + dav*row_sum[iy]); + info.store(ix, iy, vaddvq_f32(sum) + dav*row_sum[iy]); } } } } template <int nrc_y> -static void mul_mat_iq4_kt_F32_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { +void mul_mat_iq4_kt_F32_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { assert(n%QK_K == 0); const int nb = n/QK_K; constexpr int kNumGroups = 64; @@ -693,11 +996,11 @@ static void mul_mat_iq4_kt_F32_T(int n, const void * vx, size_t bx, const DataIn bool iqk_set_kernels_ktquants(int ne00, int typeA, int typeB, std::array<mul_mat_t, IQK_MAX_NY>& kernels, mul_mat_t& func16) { - if (ne00%QK_K == 0 && ggml_type(typeB) == GGML_TYPE_F32 && ggml_type(typeA) == GGML_TYPE_IQ4_KT) { - IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq4_kt_F32_T, kernels); - func16 = nullptr; - return true; - } + //if (ne00%QK_K == 0 && ggml_type(typeB) == GGML_TYPE_F32 && ggml_type(typeA) == GGML_TYPE_IQ4_KT) { + // IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq4_kt_F32_T, kernels); + // func16 = nullptr; + // return true; + //} if (ne00%QK_K != 0 || ggml_type(typeB) != GGML_TYPE_F16) { return false; @@ -722,6 +1025,17 @@ bool iqk_set_kernels_ktquants(int ne00, int typeA, int typeB, std::array<mul_mat return true; } +bool iqk_dequantize_ktquants(int type, int n, const void * vx, size_t bx, void * y, size_t stride_y, int nrc_x) { + switch (type) { + case GGML_TYPE_IQ2_KT: iqk_dequantize_iq2_kt(n, vx, bx, (float16_t *)y, stride_y, nrc_x); break; + case GGML_TYPE_IQ3_KT: iqk_dequantize_iq3_kt(n, vx, bx, (float16_t *)y, stride_y, nrc_x); break; + case GGML_TYPE_IQ4_KT: iqk_dequantize_iq4_kt(n, vx, bx, (float16_t *)y, stride_y, nrc_x); break; + default: return false; + } + + return true; +} + #endif #endif |