Diffstat (limited to 'iqk_mul_mat.cpp')
-rw-r--r--  iqk_mul_mat.cpp | 106
1 file changed, 106 insertions, 0 deletions
diff --git a/iqk_mul_mat.cpp b/iqk_mul_mat.cpp
index eaa263aa..83d3a472 100644
--- a/iqk_mul_mat.cpp
+++ b/iqk_mul_mat.cpp
@@ -4212,6 +4212,101 @@ static void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const DataIn
     }
 }
 
+template <int nrc_y>
+static void mul_mat_iq2bn_q8_K64(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+    const int nb = n / QK_IQ1BN;
+
+    Q8_K64<nrc_y> q8(info);
+
+    float32x4_t accd[nrc_y];
+    int8x16x4_t signs;
+    uint64x2x4_t a;
+    int8x16x4_t v;
+
+    const auto m1 = vdupq_n_u8(1);
+    const uint8x16x4_t sign_shuffles = {
+        vreinterpretq_u8_u64(uint64x2_t{0x0000000000000000, 0x0101010101010101}),
+        vreinterpretq_u8_u64(uint64x2_t{0x0202020202020202, 0x0303030303030303}),
+        vreinterpretq_u8_u64(uint64x2_t{0x0404040404040404, 0x0505050505050505}),
+        vreinterpretq_u8_u64(uint64x2_t{0x0606060606060606, 0x0707070707070707}),
+    };
+    const auto shift = vreinterpretq_s8_u32(vdupq_n_u32(0xfafcfe00));
+    const auto qmask = vdupq_n_u8(3);
+    const auto shuff1 = vreinterpretq_u8_u64(uint64x2_t{0x0101010100000000, 0x0909090908080808});
+    const auto mask1 = vreinterpretq_u8_u64(vdupq_n_u64(0x8040201008040201));
+    const auto mask2 = vdupq_n_s8(3);
+
+    for (int ix = 0; ix < nrc_x; ++ix) {
+
+        const block_iq2_bn * x = (const block_iq2_bn *)((const char *)vx + ix*bx);
+        float d = GGML_FP16_TO_FP32(*(const ggml_half *)x);
+        auto extra_ptr = (const uint16_t *)x;
+
+        auto all_signs = vdupq_n_u8(extra_ptr[1]);
+        all_signs = vorrq_u8(vceqq_u8(vandq_u8(all_signs, mask1), mask1), m1);
+        signs.val[0] = vqtbl1q_u8(all_signs, sign_shuffles.val[0]);
+        signs.val[1] = vqtbl1q_u8(all_signs, sign_shuffles.val[1]);
+        signs.val[2] = vqtbl1q_u8(all_signs, sign_shuffles.val[2]);
+        signs.val[3] = vqtbl1q_u8(all_signs, sign_shuffles.val[3]);
+
+        auto ql = (const uint8_t *)(extra_ptr + 2);
+        auto qh = ql + QK_IQ1BN/8;
+        a.val[0] = uint64x2_t{iq1bn_grid_u16[ql[0] | ((qh[0] << 8) & 0x0f00)], iq1bn_grid_u16[ql[1] | ((qh[0] << 4) & 0x0f00)]};
+        a.val[1] = uint64x2_t{iq1bn_grid_u16[ql[2] | ((qh[1] << 8) & 0x0f00)], iq1bn_grid_u16[ql[3] | ((qh[1] << 4) & 0x0f00)]};
+        a.val[2] = uint64x2_t{iq1bn_grid_u16[ql[4] | ((qh[2] << 8) & 0x0f00)], iq1bn_grid_u16[ql[5] | ((qh[2] << 4) & 0x0f00)]};
+        a.val[3] = uint64x2_t{iq1bn_grid_u16[ql[6] | ((qh[3] << 8) & 0x0f00)], iq1bn_grid_u16[ql[7] | ((qh[3] << 4) & 0x0f00)]};
+
+        v.val[0] = vsubq_s8(vandq_u8(vshlq_u8(vqtbl1q_u8(vreinterpretq_u8_u64(a.val[0]), shuff1), shift), qmask), m1);
+        v.val[1] = vsubq_s8(vandq_u8(vshlq_u8(vqtbl1q_u8(vreinterpretq_u8_u64(a.val[1]), shuff1), shift), qmask), m1);
+        v.val[2] = vsubq_s8(vandq_u8(vshlq_u8(vqtbl1q_u8(vreinterpretq_u8_u64(a.val[2]), shuff1), shift), qmask), m1);
+        v.val[3] = vsubq_s8(vandq_u8(vshlq_u8(vqtbl1q_u8(vreinterpretq_u8_u64(a.val[3]), shuff1), shift), qmask), m1);
+
+        v.val[0] = vmulq_s8(v.val[0], signs.val[0]);
+        v.val[1] = vmulq_s8(v.val[1], signs.val[1]);
+        v.val[2] = vmulq_s8(v.val[2], signs.val[2]);
+        v.val[3] = vmulq_s8(v.val[3], signs.val[3]);
+
+        if constexpr (nrc_y == 1) {
+            auto q = q8.load_quants(0, 0);
+            int32x4_t sumi = vdupq_n_s32(0);
+            for (int j = 0; j < 4; ++j) {
+                sumi = ggml_vdotq_s32(sumi, q.val[j], v.val[j]);
+            }
+            accd[0] = vmulq_f32(vdupq_n_f32(q8.scale(0, 0)), vcvtq_f32_s32(sumi));
+        } else {
+            for (int iy = 0; iy < nrc_y; ++iy) {
+                int32x4_t sumi = vdupq_n_s32(0);
+                auto q = q8.load_quants(iy, 0, 0);
+                sumi = ggml_vdotq_s32(ggml_vdotq_s32(sumi, q.val[0], v.val[0]), q.val[1], v.val[1]);
+                q = q8.load_quants(iy, 0, 1);
+                sumi = ggml_vdotq_s32(ggml_vdotq_s32(sumi, q.val[0], v.val[2]), q.val[1], v.val[3]);
+                accd[iy] = vmulq_f32(vdupq_n_f32(q8.scale(iy, 0)), vcvtq_f32_s32(sumi));
+            }
+        }
+
+        for (int i = 1; i < nb; ++i) {
+            auto q2bits = vld1q_u8(x[i].qs);
+            v.val[0] = vsubq_s8(vandq_s8(q2bits, mask2), m1);
+            v.val[1] = vsubq_s8(vandq_s8(vshrq_n_u8(q2bits, 2), mask2), m1);
+            v.val[2] = vsubq_s8(vandq_s8(vshrq_n_u8(q2bits, 4), mask2), m1);
+            v.val[3] = vsubq_s8(vshrq_n_u8(q2bits, 6), m1);
+            for (int iy = 0; iy < nrc_y; ++iy) {
+                int32x4_t sumi = vdupq_n_s32(0);
+                auto q = q8.load_quants(iy, i, 0);
+                sumi = ggml_vdotq_s32(ggml_vdotq_s32(sumi, q.val[0], v.val[0]), q.val[1], v.val[1]);
+                q = q8.load_quants(iy, i, 1);
+                sumi = ggml_vdotq_s32(ggml_vdotq_s32(sumi, q.val[0], v.val[2]), q.val[1], v.val[3]);
+                accd[iy] = vfmaq_f32(accd[iy], vdupq_n_f32(q8.scale(iy, i)), vcvtq_f32_s32(sumi));
+            }
+        }
+
+        for (int iy = 0; iy < nrc_y; ++iy) {
+            info.store(ix, iy, d * vaddvq_f32(accd[iy]));
+        }
+
+    }
+}
+
 template <typename Dequantizer>
 void MulMat::set_functions(MulMat& m) {
     if constexpr (std::is_same_v<Dequantizer, DequantizerQ40> || std::is_same_v<Dequantizer, DequantizerQ50> || std::is_same_v<Dequantizer, DequantizerQ80>) {
@@ -4306,6 +4401,17 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) {
             m.funcs[7] = mul_mat_iq1bn_q8_K64<8>;
             expected_Btype = GGML_TYPE_Q8_K64;
             break;
+        case GGML_TYPE_IQ2_BN:
+            m.funcs[0] = mul_mat_iq2bn_q8_K64<1>;
+            m.funcs[1] = mul_mat_iq2bn_q8_K64<2>;
+            m.funcs[2] = mul_mat_iq2bn_q8_K64<3>;
+            m.funcs[3] = mul_mat_iq2bn_q8_K64<4>;
+            m.funcs[4] = mul_mat_iq2bn_q8_K64<5>;
+            m.funcs[5] = mul_mat_iq2bn_q8_K64<6>;
+            m.funcs[6] = mul_mat_iq2bn_q8_K64<7>;
+            m.funcs[7] = mul_mat_iq2bn_q8_K64<8>;
+            expected_Btype = GGML_TYPE_Q8_K64;
+            break;
         case GGML_TYPE_Q4_0:
             MulMat::set_functions<DequantizerQ40>(m);
             expected_Btype = GGML_TYPE_Q8_0;
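The kernel leans on a few NEON idioms that are worth spelling out. The sketches below are reconstructions for exposition only, not part of the commit, and helper names such as signs_from_bits are hypothetical. First, the sign setup: AND-ing a broadcast byte with the per-lane constants {0x01, 0x02, ..., 0x80} (mask1) and comparing for equality turns bit k of the byte into an all-ones lane, and OR-ing with 1 then yields a per-lane multiplier of -1 or +1. This is what the all_signs computation produces before the sign_shuffles tables redistribute the lanes.

    #include <arm_neon.h>
    #include <stdint.h>

    // Expand 8 sign bits into 16 int8 lanes holding -1 (bit set) or +1 (bit clear).
    // Lanes 0-7 and 8-15 both test bits 0-7, as in the kernel before the
    // sign_shuffles tables spread each sign across its group of values.
    static inline int8x16_t signs_from_bits(uint8_t bits) {
        const uint8x16_t mask1 = vreinterpretq_u8_u64(vdupq_n_u64(0x8040201008040201ULL));
        uint8x16_t b = vdupq_n_u8(bits);                         // broadcast the sign byte
        uint8x16_t m = vceqq_u8(vandq_u8(b, mask1), mask1);      // 0xFF where bit k is set
        return vreinterpretq_s8_u8(vorrq_u8(m, vdupq_n_u8(1)));  // 0xFF -> -1, 0x00 -> +1
    }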
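Second, the shift constant: reinterpreted as int8 lanes, 0xfafcfe00 reads {0, -2, -4, -6} repeated four times, and NEON's VSHL shifts right when the signed per-lane count is negative. Combined with the shuff1 table, which replicates each selected source byte into four consecutive lanes, a single vshlq_u8 plus a mask extracts all four 2-bit fields of a byte at once. A minimal sketch of that step:

    #include <arm_neon.h>

    // Given a vector whose lanes 4k..4k+3 all hold the same source byte
    // (the output of vqtbl1q_u8 with shuff1), move field j of that byte
    // (bits 2j..2j+1) into lane 4k+j.
    static inline uint8x16_t extract_2bit_fields(uint8x16_t dup4) {
        // As int8 lanes, 0xfafcfe00 is {0x00, 0xfe, 0xfc, 0xfa} = {0, -2, -4, -6}.
        const int8x16_t  shift = vreinterpretq_s8_u32(vdupq_n_u32(0xfafcfe00));
        const uint8x16_t qmask = vdupq_n_u8(3);
        return vandq_u8(vshlq_u8(dup4, shift), qmask);  // negative count = right shift
    }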
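Finally, the main loop (blocks i >= 1): each of the 16 bytes of x[i].qs packs four 2-bit values stored as 0, 1 or 2, so subtracting 1 recovers the ternary set {-1, 0, +1}, which is exactly what the vsubq_s8(vandq_s8(...), m1) sequence computes 16 lanes at a time. A scalar reference under the assumption that the low 2-bit plane holds the first 16 elements of the block (inferred from how v.val[0..3] pair with the Q8 quants; decode_iq2bn_qs is a hypothetical name):

    #include <stdint.h>

    // Scalar decode of one 16-byte iq2_bn quant payload into 64 ternary values.
    // Assumed layout: bits 2j..2j+1 of byte k hold element k + 16*j.
    static void decode_iq2bn_qs(const uint8_t qs[16], int8_t out[64]) {
        for (int k = 0; k < 16; ++k) {
            for (int j = 0; j < 4; ++j) {
                out[k + 16*j] = (int8_t)(((qs[k] >> 2*j) & 3) - 1);  // 0/1/2 -> -1/0/+1
            }
        }
    }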