summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
-rw-r--r-- iqk_mul_mat.cpp 97
1 files changed, 52 insertions, 45 deletions
diff --git a/iqk_mul_mat.cpp b/iqk_mul_mat.cpp
index 0f53e02c..0b2e3552 100644
--- a/iqk_mul_mat.cpp
+++ b/iqk_mul_mat.cpp
@@ -4375,47 +4375,54 @@ static const uint64_t kall_signs[257] = {
struct DequantizerIQ1BN {
const uint8x16_t m1 = vdupq_n_u8(1);
- const uint8x16x4_t sign_shuffles = {
- vreinterpretq_u8_u64(uint64x2_t{0x0000000000000000, 0x0101010101010101}),
- vreinterpretq_u8_u64(uint64x2_t{0x0202020202020202, 0x0303030303030303}),
- vreinterpretq_u8_u64(uint64x2_t{0x0404040404040404, 0x0505050505050505}),
- vreinterpretq_u8_u64(uint64x2_t{0x0606060606060606, 0x0707070707070707}),
- };
- const int8x16_t shift = vreinterpretq_s16_u64(vdupq_n_u64(0xfffafffcfffe0000));
- const uint8x16_t qmask = vdupq_n_u8(3);
- const uint8x16_t shuff1 = vreinterpretq_u8_u64(uint64x2_t{0x0100010001000100, 0x0908090809080908});
- const uint8x16_t mask1 = vreinterpretq_u8_u64(vdupq_n_u64(0x8040201008040201));
- int8x16x4_t signs;
- uint64x2x4_t a;
- inline void prepare_iq1bn_quants(uint8_t extra, const uint8_t * ql, const uint8_t * qh, int8x16x4_t& v) {
- auto all_signs = vld1q_u8((const uint8_t *)(kall_signs + extra));
- //auto all_signs = vdupq_n_u8(extra);
- //all_signs = vorrq_u8(vceqq_u8(vandq_u8(all_signs, mask1), mask1), m1);
- signs.val[0] = vqtbl1q_u8(all_signs, sign_shuffles.val[0]);
- signs.val[1] = vqtbl1q_u8(all_signs, sign_shuffles.val[1]);
- signs.val[2] = vqtbl1q_u8(all_signs, sign_shuffles.val[2]);
- signs.val[3] = vqtbl1q_u8(all_signs, sign_shuffles.val[3]);
-
- uint32_t aux32[2];
- std::memcpy(aux32, qh, 4);
- aux32[1] = aux32[0] & 0xf0f0f0f0;
- aux32[0] &= 0x0f0f0f0f;
- const uint8_t * h = (const uint8_t *)aux32;
- a.val[0] = uint64x2_t{iq1bn_grid_u16[ql[0] | (h[0] << 8)], iq1bn_grid_u16[ql[1] | (h[4] << 4)]};
- a.val[1] = uint64x2_t{iq1bn_grid_u16[ql[2] | (h[1] << 8)], iq1bn_grid_u16[ql[3] | (h[5] << 4)]};
- a.val[2] = uint64x2_t{iq1bn_grid_u16[ql[4] | (h[2] << 8)], iq1bn_grid_u16[ql[5] | (h[6] << 4)]};
- a.val[3] = uint64x2_t{iq1bn_grid_u16[ql[6] | (h[3] << 8)], iq1bn_grid_u16[ql[7] | (h[7] << 4)]};
-
- v.val[0] = vsubq_s8(vandq_u8(vshlq_u16(vqtbl1q_u8(vreinterpretq_u8_u64(a.val[0]), shuff1), shift), qmask), m1);
- v.val[1] = vsubq_s8(vandq_u8(vshlq_u16(vqtbl1q_u8(vreinterpretq_u8_u64(a.val[1]), shuff1), shift), qmask), m1);
- v.val[2] = vsubq_s8(vandq_u8(vshlq_u16(vqtbl1q_u8(vreinterpretq_u8_u64(a.val[2]), shuff1), shift), qmask), m1);
- v.val[3] = vsubq_s8(vandq_u8(vshlq_u16(vqtbl1q_u8(vreinterpretq_u8_u64(a.val[3]), shuff1), shift), qmask), m1);
-
- v.val[0] = vmulq_s8(v.val[0], signs.val[0]);
- v.val[1] = vmulq_s8(v.val[1], signs.val[1]);
- v.val[2] = vmulq_s8(v.val[2], signs.val[2]);
- v.val[3] = vmulq_s8(v.val[3], signs.val[3]);
+ static inline uint8x16_t load_shuffle_l() {
+ static const uint8_t data[16] = {1, 255, 2, 255, 3, 255, 4, 255, 5, 255, 6, 255, 7, 255, 8, 255};
+ return vld1q_u8(data);
+ }
+ static inline uint8x16_t load_shuffle_h() {
+ static const uint8_t data[16] = {9, 255, 10, 255, 11, 255, 12, 255, 9, 255, 10, 255, 11, 255, 12, 255};
+ return vld1q_u8(data);
+ }
+ static inline uint8x16_t load_shuffle_hh() {
+ static const uint8_t data[16] = {0, 255, 0, 255, 0, 255, 0, 255, 0, 255, 0, 255, 0, 255, 0, 255};
+ return vld1q_u8(data);
+ }
+ static inline int16x8_t load_shift_hh() {
+ static const int16_t data[8] = {12, 11, 10, 9, 8, 7, 6, 5};
+ return vld1q_s16(data);
+ }
+ static inline uint16x8_t load_mult() {
+ static const uint16_t data[8] = {2187, 729, 243, 81, 27, 9, 3, 1};
+ return vld1q_u16(data);
+ }
+
+ const uint8x16_t shuff_l = load_shuffle_l();
+ const uint8x16_t shuff_h = load_shuffle_h();
+ const int32x4_t shift_h = {8, 8, 4, 4};
+ const uint16x8_t mask_h = vdupq_n_u16(0x0f00);
+ const uint8x16_t shuff_hh = load_shuffle_hh();
+ const uint16x8_t mask_hh = vdupq_n_u16(4096);
+ const int16x8_t shift_hh = load_shift_hh();
+ const uint16x8_t mult = load_mult();
+ const uint16x8_t mask = vdupq_n_u16(0x1fff);
+ const uint16x8_t m3 = vdupq_n_u16(3);
+
+ inline void prepare_iq1bn_quants(const block_iq1_bn * x, int8x16x4_t& v) const {
+ auto data = vld1q_u8((const uint8_t *)x);
+ auto aux1 = vqtbl1q_u8(data, shuff_l);
+ auto aux2 = vandq_u16(vshlq_u32(vqtbl1q_u8(data, shuff_h), shift_h), mask_h);
+ auto aux3 = vandq_u16(vshlq_u16(vqtbl1q_u8(data, shuff_hh), shift_hh), mask_hh);
+ auto all = vorrq_u16(vorrq_u16(aux1, aux2), aux3);
+ auto shuffle = vreinterpretq_u8_u16(vdupq_n_u16(0x0100));
+ auto step = vdupq_n_u8(2);
+ for (int k = 0; k < 4; ++k) {
+ auto v1 = vreinterpretq_u16_u8(vqtbl1q_u8(all, shuffle)); shuffle = vaddq_u8(shuffle, step);
+ auto v2 = vreinterpretq_u16_u8(vqtbl1q_u8(all, shuffle)); shuffle = vaddq_u8(shuffle, step);
+ v1 = vshrq_n_u16(vmulq_u16(vandq_u16(vmulq_u16(v1, mult), mask), m3), 13);
+ v2 = vshrq_n_u16(vmulq_u16(vandq_u16(vmulq_u16(v2, mult), mask), m3), 13);
+ v.val[k] = vsubq_s8(vreinterpretq_s8_u8(vcombine_u8(vmovn_u16(v1), vmovn_u16(v2))), m1);
+ }
}
};
@@ -4438,10 +4445,10 @@ static void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const DataIn
if constexpr (nrc_y == 1) {
int32x4_t acc[4] = {};
for (int i = 0; i < nb/2; ++i) {
- deq.prepare_iq1bn_quants(x[2*i+0].extra, x[2*i+0].ql, x[2*i+0].qh, v1);
+ deq.prepare_iq1bn_quants(x+2*i+0, v1);
auto q = q8.load_quants64(0, i, 0);
for (int j = 0; j < 4; ++j) acc[j] = ggml_vdotq_s32(acc[j], q.val[j], v1.val[j]);
- deq.prepare_iq1bn_quants(x[2*i+1].extra, x[2*i+1].ql, x[2*i+1].qh, v1);
+ deq.prepare_iq1bn_quants(x+2*i+1, v1);
q = q8.load_quants64(0, i, 1);
for (int j = 0; j < 4; ++j) acc[j] = ggml_vdotq_s32(acc[j], q.val[j], v1.val[j]);
}
@@ -4453,8 +4460,8 @@ static void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const DataIn
for (int i = 0; i < nb/2; ++i) {
- deq.prepare_iq1bn_quants(x[2*i+0].extra, x[2*i+0].ql, x[2*i+0].qh, v1);
- deq.prepare_iq1bn_quants(x[2*i+1].extra, x[2*i+1].ql, x[2*i+1].qh, v2);
+ deq.prepare_iq1bn_quants(x+2*i+0, v1);
+ deq.prepare_iq1bn_quants(x+2*i+1, v2);
for (int iy = 0; iy < nrc_y; ++iy) {
auto q = q8.load_quants(iy, i, 0);
@@ -4470,7 +4477,7 @@ static void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const DataIn
}
int i = 2*(nb/2);
if (i < nb) {
- deq.prepare_iq1bn_quants(x[i].extra, x[i].ql, x[i].qh, v1);
+ deq.prepare_iq1bn_quants(x+i, v1);
if constexpr (nrc_y == 1) {
auto q = q8.load_quants(0, i/2, 0);
for (int j = 0; j < 4; ++j) {