| author | Kawrakow <iwankawrakow@gmail.com> | 2025-06-20 09:26:36 +0300 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-06-20 09:26:36 +0300 |
| commit | 1843ed22c56cea6a4016005e78e26afd6c0c3948 (patch) | |
| tree | 6dc69ccb0a3ec7687665bc2f3b0d59bdffa033ee | |
| parent | 144ee1c4c68ac288607210a0f3bcb30b30b8682d (diff) | |
New integer trellis on ARM_NEON (#544)
* Adapt iq3_kt to new trellis on NEON
* iq3_kt is now working on NEON
---------
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
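
The kernels in this commit decode quants with the `Trellis3` generator (multiplier `ka = 0xCBAC1FED`, byte mask `0x3f3f3f3f`, pairwise-add reduction). As a reading aid, here is a minimal scalar sketch of what one trellis step computes, assuming the usual integer-trellis construction of summing four 6-bit byte fields with a centering offset of -126; the lane ordering and the offset are inferred from the NEON code below, not copied from the repository:

```cpp
#include <cstdint>

constexpr uint32_t ka = 0xCBAC1FED; // Trellis3 multiplier (from the diff below)

// Hypothetical scalar model: one 16-bit index (biased by 4096, as in the
// iq3_kt kernels below) yields 8 int8 values. Each step advances a
// multiplicative generator, keeps 6 bits of every byte, and sums the fields.
inline void trellis3_gen8(uint16_t l, int8_t * out) {
    uint32_t x = l + 4096u;
    for (int i = 0; i < 8; ++i) {
        x *= ka;                                   // advance generator state
        uint32_t m = x & 0x3f3f3f3fu;              // 6 bits per byte
        int sum = (m & 0xff) + ((m >> 8) & 0xff)
                + ((m >> 16) & 0xff) + (m >> 24);  // sum of four 6-bit fields
        out[i] = (int8_t)(sum - 126);              // assumed centering offset
    }
}
```

`Trellis3<true>` (`is_abs`) additionally takes the absolute value of each decoded byte: iq3_kt stores only magnitudes through the trellis and keeps the signs as separate bits in `qh`.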
-rw-r--r-- | ggml/src/iqk/iqk_gemm_ktquants.cpp | 206 |
-rw-r--r-- | ggml/src/iqk/iqk_mul_mat.cpp | 4 |
2 files changed, 206 insertions, 4 deletions
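
For iq3_kt, matrix multiplications with many right-hand-side columns now go through a repack to `Q8_0_R8` (see `iqk_dequantize_iq3_kt_q80_r8` below). A sketch of the assumed target layout, mirroring how that routine writes `y[ib].d` (8 f16 scales) and `y[ib].qs` (4 stores of 64 bytes per 32-wide block); the authoritative definition lives elsewhere in the repository:

```cpp
#include <cstdint>

typedef uint16_t ggml_half; // f16 storage type, as in ggml

// Assumed layout of one Q8_0_R8 block: 8 rows of q8_0 interleaved,
// one f16 scale per row, then the 8x32 quants of one 32-wide block.
struct block_q8_0_r8 {
    ggml_half d[8];     // per-row scales
    int8_t    qs[8*32]; // quants for the 8 interleaved rows
};
```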
diff --git a/ggml/src/iqk/iqk_gemm_ktquants.cpp b/ggml/src/iqk/iqk_gemm_ktquants.cpp
index e69e3561..88b15eea 100644
--- a/ggml/src/iqk/iqk_gemm_ktquants.cpp
+++ b/ggml/src/iqk/iqk_gemm_ktquants.cpp
@@ -1585,6 +1585,7 @@ void mul_mat_iq4_kt_F32_T(int n, const void * vx, size_t bx, const DataInfo& inf
     }
 }
 
+template <bool is_abs = false>
 struct Trellis3 {
     constexpr static uint32_t ka = 0xCBAC1FED;
     constexpr static uint32_t ka1 = ka*ka;
@@ -1611,6 +1612,9 @@ struct Trellis3 {
             i8.val[1] = vandq_u32(i8.val[1], vdupq_n_u32(0x3f3f3f3f));
             auto s2 = vpaddq_s8(vreinterpretq_s8_u32(i8.val[0]), vreinterpretq_s8_u32(i8.val[1]));
             result.val[i] = vaddq_s8(result.val[i], vpaddq_s8(s1, s2));
+            if constexpr (is_abs) {
+                result.val[i] = vabsq_s8(result.val[i]);
+            }
         }
         return result;
     }
@@ -1630,6 +1634,9 @@ struct Trellis3 {
             i8.val[1] = vandq_u32(i8.val[1], vdupq_n_u32(0x3f3f3f3f));
             auto s2 = vpaddq_s8(vreinterpretq_s8_u32(i8.val[0]), vreinterpretq_s8_u32(i8.val[1]));
             result.val[i] = vaddq_s8(result.val[i], vpaddq_s8(s1, s2));
+            if constexpr (is_abs) {
+                result.val[i] = vreinterpretq_s8_u8(vabsq_s8(result.val[i]));
+            }
         }
         return result;
     }
@@ -1657,6 +1664,9 @@ struct Trellis3 {
             result.val[i+0] = vaddq_s8(result.val[i+0], vpaddq_s8(s1_1, s2_1));
             result.val[i+2] = vaddq_s8(result.val[i+2], vpaddq_s8(s1_2, s2_2));
         }
+        if constexpr (is_abs) {
+            for (int i = 0; i < 4; ++i) result.val[i] = vabsq_s8(result.val[i]);
+        }
         return result;
     }
     static uint8x16_t load_shuffle() {
@@ -1872,6 +1882,69 @@ void iqk_dequantize_iq2_kt_q80_r8(int n, const void * vx, size_t bx, void * vy,
     }
 }
 
+void iqk_dequantize_iq3_kt_q80_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) {
+    GGML_ASSERT(n%QK_K == 0);
+    GGML_ASSERT(nrc_x%8 == 0);
+    const int nb = n/QK_K;
+
+    Trellis3<true> trellis;
+
+    block_q8_0_r8 * y = (block_q8_0_r8 *)vy;
+
+    const block_iq3_kt * x8[8];
+
+    float dkt[8];
+    float ls[8], ls_all[64];
+    uint32_t idx[8];
+    uint32_t sign_bits[16];
+
+    for (int ix = 0; ix < nrc_x; ix += 8) {
+        for (int k = 0; k < 8; ++k) {
+            const float * dptr = (const float *)((const char*)vx + (ix+k)*bx);
+            dkt[k] = dptr[0] * 1.05f;
+            x8[k] = (const block_iq3_kt *)(dptr + 1);
+        }
+        auto vd = vld1q_f32_x2(dkt);
+
+        for (int i = 0; i < nb; ++i) {
+            for (int k = 0; k < 8; ++k) {
+                auto u32 = *(const uint32_t *)x8[k][i].scales;
+                auto s8_u32 = uint32x2_t{u32, u32 >> 4};
+                s8_u32 = vand_u8(s8_u32, vdup_n_u32(0x0f0f0f0f));
+                auto s16 = vmovl_s8(vreinterpret_s8_u32(s8_u32));
+                vst1q_f32(ls_all + 8*k + 0, vcvtq_f32_s32(vmovl_s16(vget_low_s16(s16))));
+                vst1q_f32(ls_all + 8*k + 4, vcvtq_f32_s32(vmovl_s16(vget_high_s16(s16))));
+            }
+            auto mask = vdupq_n_u8(1);
+            for (int ib = 0; ib < QK_K/32; ++ib) {
+                for (int k = 0; k < 8; ++k) ls[k] = ls_all[8*k+ib];
+                auto scales1 = vmulq_f32(vd.val[0], vld1q_f32(ls+0));
+                auto scales2 = vmulq_f32(vd.val[1], vld1q_f32(ls+4));
+                vst1_f16((float16_t *)y[ib].d+0, vcvt_f16_f32(scales1));
+                vst1_f16((float16_t *)y[ib].d+4, vcvt_f16_f32(scales2));
+                for (int j = 0; j < 4; ++j) {
+                    for (int k = 0; k < 8; ++k) {
+                        const uint16_t * ql = (const uint16_t *)x8[k][i].ql;
+                        idx[k] = ql[4*ib+j] + 4096;
+                        auto qh = (const uint32_t *)x8[k][i].qh;
+                        sign_bits[k+0] = qh[2*j+0];
+                        sign_bits[k+8] = qh[2*j+1];
+                    }
+                    auto packed = trellis.next64(idx);
+                    auto signs = vld1q_u8_x4((const uint8_t *)sign_bits);
+                    for (int l = 0; l < 4; ++l) {
+                        auto s = vorrq_u8(vceqq_u8(vandq_u8(signs.val[l], mask), mask), vdupq_n_u8(1));
+                        packed.val[l] = vmulq_s8(packed.val[l], vreinterpretq_s8_u8(s));
+                    }
+                    vst1q_s8_x4(y[ib].qs+64*j, packed);
+                }
+                mask = vshlq_n_u8(mask, 1);
+            }
+            y += 8; // = QK_K/32;
+        }
+    }
+}
+
 template <int nrc_y>
 void mul_mat_iq2_kt_q8_0_x4_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
     assert(n%QK_K == 0);
     const int nb = n/QK_K;
@@ -1974,6 +2047,126 @@ void mul_mat_iq2_kt_q8_0_x4_T(int n, const void * vx, size_t bx, const DataInfo&
     }
 }
 
+template <int nrc_y>
+void mul_mat_iq3_kt_q8_0_x4_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+    assert(n%QK_K == 0);
+    const int nb = n/QK_K;
+
+    Trellis3<true> trellis;
+
+    constexpr int k_acc = nrc_y == 1 ? 2 : nrc_y;
+
+    float32x4_t accd[k_acc];
+
+    const block_q8_0_x4 * y[nrc_y];
+    for (int iy = 0; iy < nrc_y; ++iy) {
+        y[iy] = (const block_q8_0_x4 *)info.src1_row(iy);
+    }
+
+    int8x16x2_t xv[8];
+    int32x4x4_t dot;
+
+    auto compute_dot = [&dot] (const int8_t * y, const int8x16x2_t * xv) {
+        for (int k = 0; k < 4; ++k) {
+            auto yv = vld1q_s8_x2(y + 32*k);
+            dot.val[k] = vdotq_s32(vdotq_s32(vdupq_n_s32(0), xv[k].val[0], yv.val[0]), xv[k].val[1], yv.val[1]);
+        }
+        dot.val[0] = vpaddq_s32(dot.val[0], dot.val[1]);
+        dot.val[2] = vpaddq_s32(dot.val[2], dot.val[3]);
+        return vpaddq_s32(dot.val[0], dot.val[2]);
+    };
+
+    float32x4x2_t scales;
+    auto mask  = vdupq_n_u8(1);
+    auto maskh = vdupq_n_u8(0x10);
+
+    for (int ix = 0; ix < nrc_x; ++ix) {
+        const float * dptr = (const float *)((const char*)vx + ix*bx);
+        auto d = vdupq_n_f32(dptr[0]*1.05f);
+        const block_iq3_kt * x = (const block_iq3_kt *)(dptr + 1);
+
+        for (int iy = 0; iy < k_acc; ++iy) accd[iy] = vdupq_n_f32(0);
+
+        for (int i = 0; i < nb; ++i) {
+            auto u32 = *(const uint32_t *)x[i].scales;
+            auto s8_u32 = uint32x2_t{u32, u32 >> 4};
+            s8_u32 = vand_u8(s8_u32, vdup_n_u32(0x0f0f0f0f));
+            auto s16 = vmovl_s8(vreinterpret_s8_u32(s8_u32));
+            scales.val[0] = vmulq_f32(d, vcvtq_f32_s32(vmovl_s16(vget_low_s16 (s16))));
+            scales.val[1] = vmulq_f32(d, vcvtq_f32_s32(vmovl_s16(vget_high_s16(s16))));
+            const uint16_t * ql = (const uint16_t *)x[i].ql;
+            auto sign_bits = vld1q_u8_x2(x[i].qh);
+            if constexpr (nrc_y == 1) {
+                const block_q8_0_x4& ybl = y[0][2*i+0];
+                const block_q8_0_x4& ybh = y[0][2*i+1];
+                auto dyl = vmulq_f32(scales.val[0], vcvt_f32_f16(vld1_f16((const float16_t *)ybl.d)));
+                auto dyh = vmulq_f32(scales.val[1], vcvt_f32_f16(vld1_f16((const float16_t *)ybh.d)));
+                int32x4x4_t suml = {};
+                int32x4x4_t sumh = {};
+                for (int ib = 0; ib < 4; ++ib) {
+                    auto xl = trellis.next32(ql + 4*ib + 0, 4096);
+                    auto signs1 = vorrq_u8(vceqq_u8(vandq_u8(sign_bits.val[0], mask), mask), mask);
+                    auto signs2 = vorrq_u8(vceqq_u8(vandq_u8(sign_bits.val[1], mask), mask), mask);
+                    xl.val[0] = vmulq_s8(xl.val[0], vreinterpretq_s8_u8(signs1));
+                    xl.val[1] = vmulq_s8(xl.val[1], vreinterpretq_s8_u8(signs2));
+                    auto xh = trellis.next32(ql + 4*ib + 16, 4096);
+                    signs1 = vorrq_u8(vceqq_u8(vandq_u8(sign_bits.val[0], maskh), maskh), mask);
+                    signs2 = vorrq_u8(vceqq_u8(vandq_u8(sign_bits.val[1], maskh), maskh), mask);
+                    xh.val[0] = vmulq_s8(xh.val[0], vreinterpretq_s8_u8(signs1));
+                    xh.val[1] = vmulq_s8(xh.val[1], vreinterpretq_s8_u8(signs2));
+                    auto yl = vld1q_s8_x2(ybl.qs + 32*ib);
+                    auto yh = vld1q_s8_x2(ybh.qs + 32*ib);
+                    suml.val[ib] = vdotq_s32(vdotq_s32(vdupq_n_s32(0), xl.val[0], yl.val[0]), xl.val[1], yl.val[1]);
+                    sumh.val[ib] = vdotq_s32(vdotq_s32(vdupq_n_s32(0), xh.val[0], yh.val[0]), xh.val[1], yh.val[1]);
+                    sign_bits.val[0] = vshrq_n_u8(sign_bits.val[0], 1);
+                    sign_bits.val[1] = vshrq_n_u8(sign_bits.val[1], 1);
+                }
+                auto sl1 = vpaddq_s32(suml.val[0], suml.val[1]);
+                auto sl2 = vpaddq_s32(suml.val[2], suml.val[3]);
+                auto sl = vpaddq_s32(sl1, sl2);
+                auto sh1 = vpaddq_s32(sumh.val[0], sumh.val[1]);
+                auto sh2 = vpaddq_s32(sumh.val[2], sumh.val[3]);
+                auto sh = vpaddq_s32(sh1, sh2);
+                accd[0] = vfmaq_f32(accd[0], dyl, vcvtq_f32_s32(sl));
+                accd[1] = vfmaq_f32(accd[1], dyh, vcvtq_f32_s32(sh));
+            } else {
+                for (int k = 0; k < 8; ++k) {
+                    xv[k] = trellis.next32(ql + 4*k, 4096);
+                    auto signs1 = vorrq_u8(vceqq_u8(vandq_u8(sign_bits.val[0], mask), mask), mask);
+                    auto signs2 = vorrq_u8(vceqq_u8(vandq_u8(sign_bits.val[1], mask), mask), mask);
+                    xv[k].val[0] = vmulq_s8(xv[k].val[0], vreinterpretq_s8_u8(signs1));
+                    xv[k].val[1] = vmulq_s8(xv[k].val[1], vreinterpretq_s8_u8(signs2));
+                    sign_bits.val[0] = vshrq_n_u8(sign_bits.val[0], 1);
+                    sign_bits.val[1] = vshrq_n_u8(sign_bits.val[1], 1);
+                }
+                for (int iy = 0; iy < nrc_y; ++iy) {
+                    const block_q8_0_x4& ybl = y[iy][2*i+0];
+                    const block_q8_0_x4& ybh = y[iy][2*i+1];
+                    auto dyl = vmulq_f32(scales.val[0], vcvt_f32_f16(vld1_f16((const float16_t *)ybl.d)));
+                    auto dyh = vmulq_f32(scales.val[1], vcvt_f32_f16(vld1_f16((const float16_t *)ybh.d)));
+                    auto sumil = compute_dot(ybl.qs, xv+0);
+                    auto sumih = compute_dot(ybh.qs, xv+4);
+                    if constexpr (nrc_y == 1) {
+                        accd[2*iy+0] = vfmaq_f32(accd[2*iy+0], dyl, vcvtq_f32_s32(sumil));
+                        accd[2*iy+1] = vfmaq_f32(accd[2*iy+1], dyh, vcvtq_f32_s32(sumih));
+                    } else {
+                        accd[iy] = vfmaq_f32(accd[iy], dyl, vcvtq_f32_s32(sumil));
+                        accd[iy] = vfmaq_f32(accd[iy], dyh, vcvtq_f32_s32(sumih));
+                    }
+                }
+            }
+        }
+
+        if constexpr (nrc_y == 1) {
+            info.store(ix, 0, vaddvq_f32(vaddq_f32(accd[0], accd[1])));
+        } else {
+            for (int iy = 0; iy < nrc_y; ++iy) {
+                info.store(ix, iy, vaddvq_f32(accd[iy]));
+            }
+        }
+    }
+}
+
 }
 
 bool iqk_set_kernels_ktquants(int ne00, int typeA, int typeB, std::array<mul_mat_t, IQK_MAX_NY>& kernels, mul_mat_t& func16) {
@@ -1990,6 +2183,15 @@ bool iqk_set_kernels_ktquants(int ne00, int typeA, int typeB, std::array<mul_mat
         return false;
     }
 
+    if (ggml_type(typeA) == GGML_TYPE_IQ3_KT) {
+        if (ggml_type(typeB) == GGML_TYPE_Q8_0_X4) {
+            IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq3_kt_q8_0_x4_T, kernels);
+            func16 = nullptr;
+            return true;
+        }
+        return false;
+    }
+
     if (ggml_type(typeA) == GGML_TYPE_IQ2_KT) {
         if (ggml_type(typeB) == GGML_TYPE_Q8_0_X4) {
             IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq2_kt_q8_0_x4_T, kernels);
@@ -2022,10 +2224,10 @@ bool iqk_set_kernels_ktquants(int ne00, int typeA, int typeB, std::array<mul_mat
     return true;
 }
 
-bool iqk_dequantize_ktquants(int type, int n, const void * vx, size_t bx, void * y, size_t stride_y, int nrc_x) {
+bool iqk_dequantize_ktquants(int type, int n, const void * vx, size_t bx, void * y, [[maybe_unused]] size_t stride_y, int nrc_x) {
     switch (type) {
         case GGML_TYPE_IQ2_KT: iqk_dequantize_iq2_kt_q80_r8(n, vx, bx, y, nrc_x); break;
-        case GGML_TYPE_IQ3_KT: iqk_dequantize_iq3_kt(n, vx, bx, (float16_t *)y, stride_y, nrc_x); break;
+        case GGML_TYPE_IQ3_KT: iqk_dequantize_iq3_kt_q80_r8(n, vx, bx, y, nrc_x); break;
         case GGML_TYPE_IQ4_KT: iqk_dequantize_iq4_kt_q80_r8(n, vx, bx, y, nrc_x); break;
         default: return false;
     }
diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp
index f718e43e..cf3d752d 100644
--- a/ggml/src/iqk/iqk_mul_mat.cpp
+++ b/ggml/src/iqk/iqk_mul_mat.cpp
@@ -272,7 +272,7 @@ struct MulMat {
 #else
         switch (type) {
             case GGML_TYPE_IQ2_KT: return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
-            case GGML_TYPE_IQ3_KT: return nrc_y >= 32 ? GGML_TYPE_F16 : type;
+            case GGML_TYPE_IQ3_KT: return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
            case GGML_TYPE_IQ4_KT: return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
            default: break;
        }
@@ -435,7 +435,7 @@ bool iqk_convert_repack(int typeA, int n, const void * vx, size_t bx, void * vy,
             return iqk_convert_1bit_q80_r8(typeA, n, vx, bx, vy, nrc_x);
         default:
-            return false;
+            break;
     }
     return false;
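
Both new kernels apply the `qh` sign bits branchlessly: `vceqq_u8` against the current bit mask yields `0xFF` (which is -1 as `int8_t`) in lanes whose sign bit is set, OR-ing with 1 turns the remaining lanes into +1, and a single `vmulq_s8` applies the result to the magnitudes coming out of `Trellis3<true>`. A scalar equivalent with illustrative names (not repository code):

```cpp
#include <cstdint>

// Scalar view of:
//   vmulq_s8(x, vorrq_u8(vceqq_u8(vandq_u8(bits, mask), mask), vdupq_n_u8(1)))
inline int8_t apply_sign(int8_t magnitude, uint8_t bits, uint8_t mask) {
    uint8_t eq   = ((bits & mask) == mask) ? 0xFF : 0x00; // vceqq_u8: all-ones on match
    int8_t  sign = (int8_t)(eq | 1);                      // 0xFF -> -1, 0x01 -> +1
    return (int8_t)(magnitude * sign);                    // vmulq_s8
}
```

The `iqk_mul_mat.cpp` hunk completes the switch-over: with 32 or more right-hand-side columns, IQ3_KT is now repacked to `GGML_TYPE_Q8_0_R8` via the new `iqk_dequantize_iq3_kt_q80_r8` instead of being dequantized to F16.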