diff options
Diffstat (limited to 'iqk_mul_mat.cpp')
-rw-r--r-- | iqk_mul_mat.cpp | 97 |
1 files changed, 52 insertions, 45 deletions
diff --git a/iqk_mul_mat.cpp b/iqk_mul_mat.cpp index 0f53e02c..0b2e3552 100644 --- a/iqk_mul_mat.cpp +++ b/iqk_mul_mat.cpp @@ -4375,47 +4375,54 @@ static const uint64_t kall_signs[257] = { struct DequantizerIQ1BN { const uint8x16_t m1 = vdupq_n_u8(1); - const uint8x16x4_t sign_shuffles = { - vreinterpretq_u8_u64(uint64x2_t{0x0000000000000000, 0x0101010101010101}), - vreinterpretq_u8_u64(uint64x2_t{0x0202020202020202, 0x0303030303030303}), - vreinterpretq_u8_u64(uint64x2_t{0x0404040404040404, 0x0505050505050505}), - vreinterpretq_u8_u64(uint64x2_t{0x0606060606060606, 0x0707070707070707}), - }; - const int8x16_t shift = vreinterpretq_s16_u64(vdupq_n_u64(0xfffafffcfffe0000)); - const uint8x16_t qmask = vdupq_n_u8(3); - const uint8x16_t shuff1 = vreinterpretq_u8_u64(uint64x2_t{0x0100010001000100, 0x0908090809080908}); - const uint8x16_t mask1 = vreinterpretq_u8_u64(vdupq_n_u64(0x8040201008040201)); - int8x16x4_t signs; - uint64x2x4_t a; - inline void prepare_iq1bn_quants(uint8_t extra, const uint8_t * ql, const uint8_t * qh, int8x16x4_t& v) { - auto all_signs = vld1q_u8((const uint8_t *)(kall_signs + extra)); - //auto all_signs = vdupq_n_u8(extra); - //all_signs = vorrq_u8(vceqq_u8(vandq_u8(all_signs, mask1), mask1), m1); - signs.val[0] = vqtbl1q_u8(all_signs, sign_shuffles.val[0]); - signs.val[1] = vqtbl1q_u8(all_signs, sign_shuffles.val[1]); - signs.val[2] = vqtbl1q_u8(all_signs, sign_shuffles.val[2]); - signs.val[3] = vqtbl1q_u8(all_signs, sign_shuffles.val[3]); - - uint32_t aux32[2]; - std::memcpy(aux32, qh, 4); - aux32[1] = aux32[0] & 0xf0f0f0f0; - aux32[0] &= 0x0f0f0f0f; - const uint8_t * h = (const uint8_t *)aux32; - a.val[0] = uint64x2_t{iq1bn_grid_u16[ql[0] | (h[0] << 8)], iq1bn_grid_u16[ql[1] | (h[4] << 4)]}; - a.val[1] = uint64x2_t{iq1bn_grid_u16[ql[2] | (h[1] << 8)], iq1bn_grid_u16[ql[3] | (h[5] << 4)]}; - a.val[2] = uint64x2_t{iq1bn_grid_u16[ql[4] | (h[2] << 8)], iq1bn_grid_u16[ql[5] | (h[6] << 4)]}; - a.val[3] = uint64x2_t{iq1bn_grid_u16[ql[6] | (h[3] << 8)], iq1bn_grid_u16[ql[7] | (h[7] << 4)]}; - - v.val[0] = vsubq_s8(vandq_u8(vshlq_u16(vqtbl1q_u8(vreinterpretq_u8_u64(a.val[0]), shuff1), shift), qmask), m1); - v.val[1] = vsubq_s8(vandq_u8(vshlq_u16(vqtbl1q_u8(vreinterpretq_u8_u64(a.val[1]), shuff1), shift), qmask), m1); - v.val[2] = vsubq_s8(vandq_u8(vshlq_u16(vqtbl1q_u8(vreinterpretq_u8_u64(a.val[2]), shuff1), shift), qmask), m1); - v.val[3] = vsubq_s8(vandq_u8(vshlq_u16(vqtbl1q_u8(vreinterpretq_u8_u64(a.val[3]), shuff1), shift), qmask), m1); - - v.val[0] = vmulq_s8(v.val[0], signs.val[0]); - v.val[1] = vmulq_s8(v.val[1], signs.val[1]); - v.val[2] = vmulq_s8(v.val[2], signs.val[2]); - v.val[3] = vmulq_s8(v.val[3], signs.val[3]); + static inline uint8x16_t load_shuffle_l() { + static const uint8_t data[16] = {1, 255, 2, 255, 3, 255, 4, 255, 5, 255, 6, 255, 7, 255, 8, 255}; + return vld1q_u8(data); + } + static inline uint8x16_t load_shuffle_h() { + static const uint8_t data[16] = {9, 255, 10, 255, 11, 255, 12, 255, 9, 255, 10, 255, 11, 255, 12, 255}; + return vld1q_u8(data); + } + static inline uint8x16_t load_shuffle_hh() { + static const uint8_t data[16] = {0, 255, 0, 255, 0, 255, 0, 255, 0, 255, 0, 255, 0, 255, 0, 255}; + return vld1q_u8(data); + } + static inline int16x8_t load_shift_hh() { + static const int16_t data[8] = {12, 11, 10, 9, 8, 7, 6, 5}; + return vld1q_s16(data); + } + static inline uint16x8_t load_mult() { + static const uint16_t data[8] = {2187, 729, 243, 81, 27, 9, 3, 1}; + return vld1q_u16(data); + } + + const uint8x16_t shuff_l = load_shuffle_l(); + const uint8x16_t shuff_h = load_shuffle_h(); + const int32x4_t shift_h = {8, 8, 4, 4}; + const uint16x8_t mask_h = vdupq_n_u16(0x0f00); + const uint8x16_t shuff_hh = load_shuffle_hh(); + const uint16x8_t mask_hh = vdupq_n_u16(4096); + const int16x8_t shift_hh = load_shift_hh(); + const uint16x8_t mult = load_mult(); + const uint16x8_t mask = vdupq_n_u16(0x1fff); + const uint16x8_t m3 = vdupq_n_u16(3); + + inline void prepare_iq1bn_quants(const block_iq1_bn * x, int8x16x4_t& v) const { + auto data = vld1q_u8((const uint8_t *)x); + auto aux1 = vqtbl1q_u8(data, shuff_l); + auto aux2 = vandq_u16(vshlq_u32(vqtbl1q_u8(data, shuff_h), shift_h), mask_h); + auto aux3 = vandq_u16(vshlq_u16(vqtbl1q_u8(data, shuff_hh), shift_hh), mask_hh); + auto all = vorrq_u16(vorrq_u16(aux1, aux2), aux3); + auto shuffle = vreinterpretq_u8_u16(vdupq_n_u16(0x0100)); + auto step = vdupq_n_u8(2); + for (int k = 0; k < 4; ++k) { + auto v1 = vreinterpretq_u16_u8(vqtbl1q_u8(all, shuffle)); shuffle = vaddq_u8(shuffle, step); + auto v2 = vreinterpretq_u16_u8(vqtbl1q_u8(all, shuffle)); shuffle = vaddq_u8(shuffle, step); + v1 = vshrq_n_u16(vmulq_u16(vandq_u16(vmulq_u16(v1, mult), mask), m3), 13); + v2 = vshrq_n_u16(vmulq_u16(vandq_u16(vmulq_u16(v2, mult), mask), m3), 13); + v.val[k] = vsubq_s8(vreinterpretq_s8_u8(vcombine_u8(vmovn_u16(v1), vmovn_u16(v2))), m1); + } } }; @@ -4438,10 +4445,10 @@ static void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const DataIn if constexpr (nrc_y == 1) { int32x4_t acc[4] = {}; for (int i = 0; i < nb/2; ++i) { - deq.prepare_iq1bn_quants(x[2*i+0].extra, x[2*i+0].ql, x[2*i+0].qh, v1); + deq.prepare_iq1bn_quants(x+2*i+0, v1); auto q = q8.load_quants64(0, i, 0); for (int j = 0; j < 4; ++j) acc[j] = ggml_vdotq_s32(acc[j], q.val[j], v1.val[j]); - deq.prepare_iq1bn_quants(x[2*i+1].extra, x[2*i+1].ql, x[2*i+1].qh, v1); + deq.prepare_iq1bn_quants(x+2*i+1, v1); q = q8.load_quants64(0, i, 1); for (int j = 0; j < 4; ++j) acc[j] = ggml_vdotq_s32(acc[j], q.val[j], v1.val[j]); } @@ -4453,8 +4460,8 @@ static void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const DataIn for (int i = 0; i < nb/2; ++i) { - deq.prepare_iq1bn_quants(x[2*i+0].extra, x[2*i+0].ql, x[2*i+0].qh, v1); - deq.prepare_iq1bn_quants(x[2*i+1].extra, x[2*i+1].ql, x[2*i+1].qh, v2); + deq.prepare_iq1bn_quants(x+2*i+0, v1); + deq.prepare_iq1bn_quants(x+2*i+1, v2); for (int iy = 0; iy < nrc_y; ++iy) { auto q = q8.load_quants(iy, i, 0); @@ -4470,7 +4477,7 @@ static void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const DataIn } int i = 2*(nb/2); if (i < nb) { - deq.prepare_iq1bn_quants(x[i].extra, x[i].ql, x[i].qh, v1); + deq.prepare_iq1bn_quants(x+i, v1); if constexpr (nrc_y == 1) { auto q = q8.load_quants(0, i/2, 0); for (int j = 0; j < 4; ++j) { |