From 0c5a353ebdcc58e8b8051f2c38a92a8c23fa8092 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Mon, 17 Jun 2024 08:05:06 +0200 Subject: iqk_mul_mat(iq1_bn): WIP NEON - don't see why it is not working --- iqk_mul_mat.cpp | 50 ++++++++++++++++++++++++++------------------------ 1 file changed, 26 insertions(+), 24 deletions(-) (limited to 'iqk_mul_mat.cpp') diff --git a/iqk_mul_mat.cpp b/iqk_mul_mat.cpp index 78c02347..9f4224cc 100644 --- a/iqk_mul_mat.cpp +++ b/iqk_mul_mat.cpp @@ -4026,19 +4026,21 @@ template struct Q8_K64 { template static void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { const int nb = n / QK_IQ1BN; + Q8_K64 q8(info); - float32x4_t accd[nrc_y]; - int8x16x4_t signs; - uint64x2x4_t aux; - uint8x16x4_t vp, vm; + float32x4_t accd[nrc_y]; + int8x16x4_t signs; + uint64x2x4_t a; + int8x16x4_t v; const auto m1 = vdupq_n_u8(1); - uint8x16x4_t sign_shuffles; - sign_shuffles.val[0] = vreinterpretq_u8_u64(uint64x2_t{0x0000000000000000, 0x0101010101010101}); - sign_shuffles.val[1] = vreinterpretq_u8_u64(uint64x2_t{0x0202020202020202, 0x0303030303030303}); - sign_shuffles.val[2] = vreinterpretq_u8_u64(uint64x2_t{0x0404040404040404, 0x0505050505050505}); - sign_shuffles.val[3] = vreinterpretq_u8_u64(uint64x2_t{0x0606060606060606, 0x0707070707070707}); + const uint8x16x4_t sign_shuffles = { + vreinterpretq_u8_u64(uint64x2_t{0x0000000000000000, 0x0101010101010101}), + vreinterpretq_u8_u64(uint64x2_t{0x0202020202020202, 0x0303030303030303}), + vreinterpretq_u8_u64(uint64x2_t{0x0404040404040404, 0x0505050505050505}), + vreinterpretq_u8_u64(uint64x2_t{0x0606060606060606, 0x0707070707070707}), + }; const auto shuff1 = vreinterpretq_u8_u64(uint64x2_t{0x0000000000000000, 0x0808080808080808}); const auto shuff2 = vaddq_u8(shuff1, m1); const auto mask1 = vreinterpretq_u8_u64(vdupq_n_u64(0x8040201008040201)); @@ -4067,26 +4069,26 @@ static void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const DataIn auto ql = x[i].ql; auto qh = x[i].qh; - aux.val[0] = uint64x2_t{iq1bn_grid_xxx[ql[0] | ((qh[0] << 8) & 0x0f00)], iq1bn_grid_xxx[ql[1] | ((qh[0] << 4) & 0x0f00)]}; - aux.val[1] = uint64x2_t{iq1bn_grid_xxx[ql[2] | ((qh[1] << 8) & 0x0f00)], iq1bn_grid_xxx[ql[3] | ((qh[1] << 4) & 0x0f00)]}; - aux.val[2] = uint64x2_t{iq1bn_grid_xxx[ql[4] | ((qh[2] << 8) & 0x0f00)], iq1bn_grid_xxx[ql[5] | ((qh[2] << 4) & 0x0f00)]}; - aux.val[3] = uint64x2_t{iq1bn_grid_xxx[ql[6] | ((qh[3] << 8) & 0x0f00)], iq1bn_grid_xxx[ql[7] | ((qh[3] << 4) & 0x0f00)]}; - - vp.val[0] = vceqq_u8(vandq_u8(vqtbl1q_u8(aux.val[0], shuff1), mask1), mask1); - vp.val[1] = vceqq_u8(vandq_u8(vqtbl1q_u8(aux.val[1], shuff1), mask1), mask1); - vp.val[2] = vceqq_u8(vandq_u8(vqtbl1q_u8(aux.val[2], shuff1), mask1), mask1); - vp.val[3] = vceqq_u8(vandq_u8(vqtbl1q_u8(aux.val[3], shuff1), mask1), mask1); - vm.val[0] = vceqq_u8(vandq_u8(vqtbl1q_u8(aux.val[0], shuff2), mask1), mask1); - vm.val[1] = vceqq_u8(vandq_u8(vqtbl1q_u8(aux.val[1], shuff2), mask1), mask1); - vm.val[2] = vceqq_u8(vandq_u8(vqtbl1q_u8(aux.val[2], shuff2), mask1), mask1); - vm.val[3] = vceqq_u8(vandq_u8(vqtbl1q_u8(aux.val[3], shuff2), mask1), mask1); + a.val[0] = uint64x2_t{iq1bn_grid_xxx[ql[0] | ((qh[0] << 8) & 0x0f00)], iq1bn_grid_xxx[ql[1] | ((qh[0] << 4) & 0x0f00)]}; + a.val[1] = uint64x2_t{iq1bn_grid_xxx[ql[2] | ((qh[1] << 8) & 0x0f00)], iq1bn_grid_xxx[ql[3] | ((qh[1] << 4) & 0x0f00)]}; + a.val[2] = uint64x2_t{iq1bn_grid_xxx[ql[4] | ((qh[2] << 8) & 0x0f00)], iq1bn_grid_xxx[ql[5] | ((qh[2] << 4) & 0x0f00)]}; + a.val[3] = uint64x2_t{iq1bn_grid_xxx[ql[6] | ((qh[3] << 8) & 0x0f00)], iq1bn_grid_xxx[ql[7] | ((qh[3] << 4) & 0x0f00)]}; + + v.val[0] = vsubq_s8(vreinterpretq_s8_u8(vceqq_u8(vandq_u8(vqtbl1q_u8(vreinterpretq_u8_u64(a.val[0]), shuff2), mask1), mask1)), + vreinterpretq_s8_u8(vceqq_u8(vandq_u8(vqtbl1q_u8(vreinterpretq_u8_u64(a.val[0]), shuff1), mask1), mask1))); + v.val[1] = vsubq_s8(vreinterpretq_s8_u8(vceqq_u8(vandq_u8(vqtbl1q_u8(vreinterpretq_u8_u64(a.val[1]), shuff2), mask1), mask1)), + vreinterpretq_s8_u8(vceqq_u8(vandq_u8(vqtbl1q_u8(vreinterpretq_u8_u64(a.val[1]), shuff1), mask1), mask1))); + v.val[2] = vsubq_s8(vreinterpretq_s8_u8(vceqq_u8(vandq_u8(vqtbl1q_u8(vreinterpretq_u8_u64(a.val[2]), shuff2), mask1), mask1)), + vreinterpretq_s8_u8(vceqq_u8(vandq_u8(vqtbl1q_u8(vreinterpretq_u8_u64(a.val[2]), shuff1), mask1), mask1))); + v.val[3] = vsubq_s8(vreinterpretq_s8_u8(vceqq_u8(vandq_u8(vqtbl1q_u8(vreinterpretq_u8_u64(a.val[3]), shuff2), mask1), mask1)), + vreinterpretq_s8_u8(vceqq_u8(vandq_u8(vqtbl1q_u8(vreinterpretq_u8_u64(a.val[3]), shuff1), mask1), mask1))); for (int iy = 0; iy < nrc_y; ++iy) { auto q = q8.load_quants(iy, i); int32x4_t sumi = vdupq_n_s32(0); for (int j = 0; j < 4; ++j) { - auto tmp = vmulq_s8(q.val[j], signs.val[j]); - tmp = vsubq_s8(vmulq_s8(q.val[j], vm.val[j]), vmulq_s8(q.val[j], vp.val[j])); + auto tmp = vmulq_s8(q.val[j], vreinterpretq_s8_u8(signs.val[j])); + tmp = vmulq_s8(q.val[j], v.val[j]); sumi = ggml_vdotq_s32(sumi, m1, tmp); } accd[iy] = vfmaq_f32(accd[iy], vdupq_n_f32(q8.scale(iy, i)), vcvtq_f32_s32(sumi)); -- cgit v1.2.3