summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorIwan Kawrakow <iwan.kawrakow@gmail.com>2024-07-16 08:32:15 +0200
committerIwan Kawrakow <iwan.kawrakow@gmail.com>2024-07-16 08:32:15 +0200
commit6393e2682720893092f77c2a6d428a2c13ecccf7 (patch)
tree5abf83a8fa424ba13f213c2cc72fd4e54169047b
parent26a1a689c6140ac823f5a74259f0423d7394deed (diff)
iq1bn(no lookup): NEON attempts
We are at TG-128 = 25.7 t/s, which is quite a bit worse than lookup.
-rw-r--r--iqk_mul_mat.cpp43
1 files changed, 32 insertions, 11 deletions
diff --git a/iqk_mul_mat.cpp b/iqk_mul_mat.cpp
index 0b2e3552..d9aa074e 100644
--- a/iqk_mul_mat.cpp
+++ b/iqk_mul_mat.cpp
@@ -4393,9 +4393,19 @@ struct DequantizerIQ1BN {
return vld1q_s16(data);
}
static inline uint16x8_t load_mult() {
- static const uint16_t data[8] = {2187, 729, 243, 81, 27, 9, 3, 1};
+ //static const uint16_t data[8] = {2187, 729, 243, 81, 27, 9, 3, 1};
+ static const uint16_t data[8] = {2187*8, 729*8, 243*8, 81*8, 27*8, 9*8, 3*8, 1*8};
return vld1q_u16(data);
}
+ //static inline uint8x16x4_t load_shuffles(uint16_t s0) {
+ // uint8x16x4_t r;
+ // auto step = vdupq_n_u8(4);
+ // r.val[0] = vreinterpretq_u8_u16(vdupq_n_u16(s0));
+ // r.val[1] = vaddq_u8(r.val[0], step);
+ // r.val[2] = vaddq_u8(r.val[1], step);
+ // r.val[3] = vaddq_u8(r.val[2], step);
+ // return r;
+ //}
const uint8x16_t shuff_l = load_shuffle_l();
const uint8x16_t shuff_h = load_shuffle_h();
@@ -4405,22 +4415,33 @@ struct DequantizerIQ1BN {
const uint16x8_t mask_hh = vdupq_n_u16(4096);
const int16x8_t shift_hh = load_shift_hh();
const uint16x8_t mult = load_mult();
- const uint16x8_t mask = vdupq_n_u16(0x1fff);
- const uint16x8_t m3 = vdupq_n_u16(3);
-
- inline void prepare_iq1bn_quants(const block_iq1_bn * x, int8x16x4_t& v) const {
+ const uint8x16_t step = vdupq_n_u8(2);
+ const uint8x16_t shuff0 = vreinterpretq_u8_u16(vdupq_n_u16(0x0100));
+ //const uint8x16x4_t shuff1 = load_shuffles(0x0100);
+ //const uint8x16x4_t shuff2 = load_shuffles(0x0302);
+ //const uint16x8_t mask = vdupq_n_u16(0x1fff);
+ //const uint16x8_t m3 = vdupq_n_u16(3);
+
+ IQK_ALWAYS_INLINE void prepare_iq1bn_quants(const block_iq1_bn * x, int8x16x4_t& v) const {
auto data = vld1q_u8((const uint8_t *)x);
auto aux1 = vqtbl1q_u8(data, shuff_l);
auto aux2 = vandq_u16(vshlq_u32(vqtbl1q_u8(data, shuff_h), shift_h), mask_h);
auto aux3 = vandq_u16(vshlq_u16(vqtbl1q_u8(data, shuff_hh), shift_hh), mask_hh);
auto all = vorrq_u16(vorrq_u16(aux1, aux2), aux3);
- auto shuffle = vreinterpretq_u8_u16(vdupq_n_u16(0x0100));
- auto step = vdupq_n_u8(2);
+ auto shuffle = shuff0;
+ //auto shuffle = vreinterpretq_u8_u16(vdupq_n_u16(0x0100));
+ //auto step = vdupq_n_u8(2);
for (int k = 0; k < 4; ++k) {
auto v1 = vreinterpretq_u16_u8(vqtbl1q_u8(all, shuffle)); shuffle = vaddq_u8(shuffle, step);
auto v2 = vreinterpretq_u16_u8(vqtbl1q_u8(all, shuffle)); shuffle = vaddq_u8(shuffle, step);
- v1 = vshrq_n_u16(vmulq_u16(vandq_u16(vmulq_u16(v1, mult), mask), m3), 13);
- v2 = vshrq_n_u16(vmulq_u16(vandq_u16(vmulq_u16(v2, mult), mask), m3), 13);
+ //auto v1 = vreinterpretq_u16_u8(vqtbl1q_u8(all, shuff1.val[k]));
+ //auto v2 = vreinterpretq_u16_u8(vqtbl1q_u8(all, shuff2.val[k]));
+ v1 = vmulq_u16(v1, mult);
+ v2 = vmulq_u16(v2, mult);
+ v1 = vshrq_n_u16(vhaddq_u16(v1, vshrq_n_u16(v1, 1)), 14);
+ v2 = vshrq_n_u16(vhaddq_u16(v2, vshrq_n_u16(v2, 1)), 14);
+ //v1 = vshrq_n_u16(vmulq_u16(vandq_u16(vmulq_u16(v1, mult), mask), m3), 13);
+ //v2 = vshrq_n_u16(vmulq_u16(vandq_u16(vmulq_u16(v2, mult), mask), m3), 13);
v.val[k] = vsubq_s8(vreinterpretq_s8_u8(vcombine_u8(vmovn_u16(v1), vmovn_u16(v2))), m1);
}
}
@@ -4448,9 +4469,9 @@ static void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const DataIn
deq.prepare_iq1bn_quants(x+2*i+0, v1);
auto q = q8.load_quants64(0, i, 0);
for (int j = 0; j < 4; ++j) acc[j] = ggml_vdotq_s32(acc[j], q.val[j], v1.val[j]);
- deq.prepare_iq1bn_quants(x+2*i+1, v1);
+ deq.prepare_iq1bn_quants(x+2*i+1, v2);
q = q8.load_quants64(0, i, 1);
- for (int j = 0; j < 4; ++j) acc[j] = ggml_vdotq_s32(acc[j], q.val[j], v1.val[j]);
+ for (int j = 0; j < 4; ++j) acc[j] = ggml_vdotq_s32(acc[j], q.val[j], v2.val[j]);
}
accd[0] = vaddq_s32(vaddq_s32(acc[0], acc[1]), vaddq_s32(acc[2], acc[3]));
}