-rw-r--r--   iqk_mul_mat.cpp   43
1 file changed, 32 insertions, 11 deletions
diff --git a/iqk_mul_mat.cpp b/iqk_mul_mat.cpp
index 0b2e3552..d9aa074e 100644
--- a/iqk_mul_mat.cpp
+++ b/iqk_mul_mat.cpp
@@ -4393,9 +4393,19 @@ struct DequantizerIQ1BN {
         return vld1q_s16(data);
     }
     static inline uint16x8_t load_mult() {
-        static const uint16_t data[8] = {2187, 729, 243, 81, 27, 9, 3, 1};
+        //static const uint16_t data[8] = {2187, 729, 243, 81, 27, 9, 3, 1};
+        static const uint16_t data[8] = {2187*8, 729*8, 243*8, 81*8, 27*8, 9*8, 3*8, 1*8};
         return vld1q_u16(data);
     }
+    //static inline uint8x16x4_t load_shuffles(uint16_t s0) {
+    //    uint8x16x4_t r;
+    //    auto step = vdupq_n_u8(4);
+    //    r.val[0] = vreinterpretq_u8_u16(vdupq_n_u16(s0));
+    //    r.val[1] = vaddq_u8(r.val[0], step);
+    //    r.val[2] = vaddq_u8(r.val[1], step);
+    //    r.val[3] = vaddq_u8(r.val[2], step);
+    //    return r;
+    //}
 
     const uint8x16_t shuff_l = load_shuffle_l();
     const uint8x16_t shuff_h = load_shuffle_h();
@@ -4405,22 +4415,33 @@ struct DequantizerIQ1BN {
     const uint16x8_t mask_hh = vdupq_n_u16(4096);
     const int16x8_t shift_hh = load_shift_hh();
     const uint16x8_t mult = load_mult();
-    const uint16x8_t mask = vdupq_n_u16(0x1fff);
-    const uint16x8_t m3 = vdupq_n_u16(3);
-
-    inline void prepare_iq1bn_quants(const block_iq1_bn * x, int8x16x4_t& v) const {
+    const uint8x16_t step = vdupq_n_u8(2);
+    const uint8x16_t shuff0 = vreinterpretq_u8_u16(vdupq_n_u16(0x0100));
+    //const uint8x16x4_t shuff1 = load_shuffles(0x0100);
+    //const uint8x16x4_t shuff2 = load_shuffles(0x0302);
+    //const uint16x8_t mask = vdupq_n_u16(0x1fff);
+    //const uint16x8_t m3 = vdupq_n_u16(3);
+
+    IQK_ALWAYS_INLINE void prepare_iq1bn_quants(const block_iq1_bn * x, int8x16x4_t& v) const {
         auto data = vld1q_u8((const uint8_t *)x);
         auto aux1 = vqtbl1q_u8(data, shuff_l);
         auto aux2 = vandq_u16(vshlq_u32(vqtbl1q_u8(data, shuff_h), shift_h), mask_h);
         auto aux3 = vandq_u16(vshlq_u16(vqtbl1q_u8(data, shuff_hh), shift_hh), mask_hh);
         auto all = vorrq_u16(vorrq_u16(aux1, aux2), aux3);
-        auto shuffle = vreinterpretq_u8_u16(vdupq_n_u16(0x0100));
-        auto step = vdupq_n_u8(2);
+        auto shuffle = shuff0;
+        //auto shuffle = vreinterpretq_u8_u16(vdupq_n_u16(0x0100));
+        //auto step = vdupq_n_u8(2);
         for (int k = 0; k < 4; ++k) {
             auto v1 = vreinterpretq_u16_u8(vqtbl1q_u8(all, shuffle)); shuffle = vaddq_u8(shuffle, step);
             auto v2 = vreinterpretq_u16_u8(vqtbl1q_u8(all, shuffle)); shuffle = vaddq_u8(shuffle, step);
-            v1 = vshrq_n_u16(vmulq_u16(vandq_u16(vmulq_u16(v1, mult), mask), m3), 13);
-            v2 = vshrq_n_u16(vmulq_u16(vandq_u16(vmulq_u16(v2, mult), mask), m3), 13);
+            //auto v1 = vreinterpretq_u16_u8(vqtbl1q_u8(all, shuff1.val[k]));
+            //auto v2 = vreinterpretq_u16_u8(vqtbl1q_u8(all, shuff2.val[k]));
+            v1 = vmulq_u16(v1, mult);
+            v2 = vmulq_u16(v2, mult);
+            v1 = vshrq_n_u16(vhaddq_u16(v1, vshrq_n_u16(v1, 1)), 14);
+            v2 = vshrq_n_u16(vhaddq_u16(v2, vshrq_n_u16(v2, 1)), 14);
+            //v1 = vshrq_n_u16(vmulq_u16(vandq_u16(vmulq_u16(v1, mult), mask), m3), 13);
+            //v2 = vshrq_n_u16(vmulq_u16(vandq_u16(vmulq_u16(v2, mult), mask), m3), 13);
             v.val[k] = vsubq_s8(vreinterpretq_s8_u8(vcombine_u8(vmovn_u16(v1), vmovn_u16(v2))), m1);
         }
     }
@@ -4448,9 +4469,9 @@ static void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const DataIn
             deq.prepare_iq1bn_quants(x+2*i+0, v1);
             auto q = q8.load_quants64(0, i, 0);
             for (int j = 0; j < 4; ++j) acc[j] = ggml_vdotq_s32(acc[j], q.val[j], v1.val[j]);
-            deq.prepare_iq1bn_quants(x+2*i+1, v1);
+            deq.prepare_iq1bn_quants(x+2*i+1, v2);
             q = q8.load_quants64(0, i, 1);
-            for (int j = 0; j < 4; ++j) acc[j] = ggml_vdotq_s32(acc[j], q.val[j], v1.val[j]);
+            for (int j = 0; j < 4; ++j) acc[j] = ggml_vdotq_s32(acc[j], q.val[j], v2.val[j]);
         }
         accd[0] = vaddq_s32(vaddq_s32(acc[0], acc[1]), vaddq_s32(acc[2], acc[3]));
     }
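
Why the rewrite works: the old path extracted each ternary digit as ((q*3^k & 0x1fff) * 3) >> 13. Scaling the multiplier table by 8 makes the 0x1fff mask implicit, because (q*3^k*8) mod 2^16 equals 8*((q*3^k) mod 2^13): the 13 relevant bits now sit at the top of the 16-bit lane. And since the scaled product u is a multiple of 8, the halving add vhaddq_u16(u, u >> 1) computes 3u/4 exactly, so the final shift by 14 yields (3u) >> 16, the same digit as before. This trades a vector multiply and a mask for two shifts and a halving add, which are typically cheaper on ARM cores. The v1 -> v2 change in mul_mat_iq1bn_q8_K64 is independent: the two prepare_iq1bn_quants calls now write separate registers, so the second half of the loop body no longer waits on the first.

Below is a minimal scalar sketch that checks the equivalence exhaustively. It is written for this note (names like d_old/d_new are ad hoc), not code from the repository:

#include <cstdint>
#include <cstdio>

// Scalar check of the digit-extraction rewrite above. Illustrative only:
//   old per lane: t = uint16(q*m);     d = ((t & 0x1fff) * 3) >> 13
//   new per lane: u = uint16(q*(8*m)); d = ((u + (u >> 1)) >> 1) >> 14
// where (u + (u >> 1)) >> 1 is what vhaddq_u16(u, vshrq_n_u16(u, 1)) computes.
int main() {
    const uint16_t mult[8] = {2187, 729, 243, 81, 27, 9, 3, 1}; // 3^7 .. 3^0
    for (int k = 0; k < 8; ++k) {
        for (uint32_t q = 0; q < 65536; ++q) {
            uint16_t t     = uint16_t(q * mult[k]);
            uint16_t d_old = uint16_t(((t & 0x1fff) * 3) >> 13);
            // u == 8*(t & 0x1fff): the *8 baked into the multiplier table replaces the mask.
            uint16_t u     = uint16_t(q * uint16_t(8 * mult[k]));
            // Halving add is exact here because u is a multiple of 8.
            uint16_t h     = uint16_t((uint32_t(u) + (u >> 1)) >> 1);
            uint16_t d_new = uint16_t(h >> 14);
            if (d_old != d_new) { std::printf("mismatch: k=%d q=%u\n", k, q); return 1; }
        }
    }
    std::printf("old and new extraction agree for all 16-bit inputs\n");
    return 0;
}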