summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--iqk_mul_mat.cpp50
1 files changed, 26 insertions, 24 deletions
diff --git a/iqk_mul_mat.cpp b/iqk_mul_mat.cpp
index 78c02347..9f4224cc 100644
--- a/iqk_mul_mat.cpp
+++ b/iqk_mul_mat.cpp
@@ -4026,19 +4026,21 @@ template <int nrc> struct Q8_K64 {
template <int nrc_y>
static void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
const int nb = n / QK_IQ1BN;
+
Q8_K64<nrc_y> q8(info);
- float32x4_t accd[nrc_y];
- int8x16x4_t signs;
- uint64x2x4_t aux;
- uint8x16x4_t vp, vm;
+ float32x4_t accd[nrc_y];
+ int8x16x4_t signs;
+ uint64x2x4_t a;
+ int8x16x4_t v;
const auto m1 = vdupq_n_u8(1);
- uint8x16x4_t sign_shuffles;
- sign_shuffles.val[0] = vreinterpretq_u8_u64(uint64x2_t{0x0000000000000000, 0x0101010101010101});
- sign_shuffles.val[1] = vreinterpretq_u8_u64(uint64x2_t{0x0202020202020202, 0x0303030303030303});
- sign_shuffles.val[2] = vreinterpretq_u8_u64(uint64x2_t{0x0404040404040404, 0x0505050505050505});
- sign_shuffles.val[3] = vreinterpretq_u8_u64(uint64x2_t{0x0606060606060606, 0x0707070707070707});
+ const uint8x16x4_t sign_shuffles = {
+ vreinterpretq_u8_u64(uint64x2_t{0x0000000000000000, 0x0101010101010101}),
+ vreinterpretq_u8_u64(uint64x2_t{0x0202020202020202, 0x0303030303030303}),
+ vreinterpretq_u8_u64(uint64x2_t{0x0404040404040404, 0x0505050505050505}),
+ vreinterpretq_u8_u64(uint64x2_t{0x0606060606060606, 0x0707070707070707}),
+ };
const auto shuff1 = vreinterpretq_u8_u64(uint64x2_t{0x0000000000000000, 0x0808080808080808});
const auto shuff2 = vaddq_u8(shuff1, m1);
const auto mask1 = vreinterpretq_u8_u64(vdupq_n_u64(0x8040201008040201));
@@ -4067,26 +4069,26 @@ static void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const DataIn
auto ql = x[i].ql;
auto qh = x[i].qh;
- aux.val[0] = uint64x2_t{iq1bn_grid_xxx[ql[0] | ((qh[0] << 8) & 0x0f00)], iq1bn_grid_xxx[ql[1] | ((qh[0] << 4) & 0x0f00)]};
- aux.val[1] = uint64x2_t{iq1bn_grid_xxx[ql[2] | ((qh[1] << 8) & 0x0f00)], iq1bn_grid_xxx[ql[3] | ((qh[1] << 4) & 0x0f00)]};
- aux.val[2] = uint64x2_t{iq1bn_grid_xxx[ql[4] | ((qh[2] << 8) & 0x0f00)], iq1bn_grid_xxx[ql[5] | ((qh[2] << 4) & 0x0f00)]};
- aux.val[3] = uint64x2_t{iq1bn_grid_xxx[ql[6] | ((qh[3] << 8) & 0x0f00)], iq1bn_grid_xxx[ql[7] | ((qh[3] << 4) & 0x0f00)]};
-
- vp.val[0] = vceqq_u8(vandq_u8(vqtbl1q_u8(aux.val[0], shuff1), mask1), mask1);
- vp.val[1] = vceqq_u8(vandq_u8(vqtbl1q_u8(aux.val[1], shuff1), mask1), mask1);
- vp.val[2] = vceqq_u8(vandq_u8(vqtbl1q_u8(aux.val[2], shuff1), mask1), mask1);
- vp.val[3] = vceqq_u8(vandq_u8(vqtbl1q_u8(aux.val[3], shuff1), mask1), mask1);
- vm.val[0] = vceqq_u8(vandq_u8(vqtbl1q_u8(aux.val[0], shuff2), mask1), mask1);
- vm.val[1] = vceqq_u8(vandq_u8(vqtbl1q_u8(aux.val[1], shuff2), mask1), mask1);
- vm.val[2] = vceqq_u8(vandq_u8(vqtbl1q_u8(aux.val[2], shuff2), mask1), mask1);
- vm.val[3] = vceqq_u8(vandq_u8(vqtbl1q_u8(aux.val[3], shuff2), mask1), mask1);
+ a.val[0] = uint64x2_t{iq1bn_grid_xxx[ql[0] | ((qh[0] << 8) & 0x0f00)], iq1bn_grid_xxx[ql[1] | ((qh[0] << 4) & 0x0f00)]};
+ a.val[1] = uint64x2_t{iq1bn_grid_xxx[ql[2] | ((qh[1] << 8) & 0x0f00)], iq1bn_grid_xxx[ql[3] | ((qh[1] << 4) & 0x0f00)]};
+ a.val[2] = uint64x2_t{iq1bn_grid_xxx[ql[4] | ((qh[2] << 8) & 0x0f00)], iq1bn_grid_xxx[ql[5] | ((qh[2] << 4) & 0x0f00)]};
+ a.val[3] = uint64x2_t{iq1bn_grid_xxx[ql[6] | ((qh[3] << 8) & 0x0f00)], iq1bn_grid_xxx[ql[7] | ((qh[3] << 4) & 0x0f00)]};
+
+ v.val[0] = vsubq_s8(vreinterpretq_s8_u8(vceqq_u8(vandq_u8(vqtbl1q_u8(vreinterpretq_u8_u64(a.val[0]), shuff2), mask1), mask1)),
+ vreinterpretq_s8_u8(vceqq_u8(vandq_u8(vqtbl1q_u8(vreinterpretq_u8_u64(a.val[0]), shuff1), mask1), mask1)));
+ v.val[1] = vsubq_s8(vreinterpretq_s8_u8(vceqq_u8(vandq_u8(vqtbl1q_u8(vreinterpretq_u8_u64(a.val[1]), shuff2), mask1), mask1)),
+ vreinterpretq_s8_u8(vceqq_u8(vandq_u8(vqtbl1q_u8(vreinterpretq_u8_u64(a.val[1]), shuff1), mask1), mask1)));
+ v.val[2] = vsubq_s8(vreinterpretq_s8_u8(vceqq_u8(vandq_u8(vqtbl1q_u8(vreinterpretq_u8_u64(a.val[2]), shuff2), mask1), mask1)),
+ vreinterpretq_s8_u8(vceqq_u8(vandq_u8(vqtbl1q_u8(vreinterpretq_u8_u64(a.val[2]), shuff1), mask1), mask1)));
+ v.val[3] = vsubq_s8(vreinterpretq_s8_u8(vceqq_u8(vandq_u8(vqtbl1q_u8(vreinterpretq_u8_u64(a.val[3]), shuff2), mask1), mask1)),
+ vreinterpretq_s8_u8(vceqq_u8(vandq_u8(vqtbl1q_u8(vreinterpretq_u8_u64(a.val[3]), shuff1), mask1), mask1)));
for (int iy = 0; iy < nrc_y; ++iy) {
auto q = q8.load_quants(iy, i);
int32x4_t sumi = vdupq_n_s32(0);
for (int j = 0; j < 4; ++j) {
- auto tmp = vmulq_s8(q.val[j], signs.val[j]);
- tmp = vsubq_s8(vmulq_s8(q.val[j], vm.val[j]), vmulq_s8(q.val[j], vp.val[j]));
+ auto tmp = vmulq_s8(q.val[j], vreinterpretq_s8_u8(signs.val[j]));
+ tmp = vmulq_s8(q.val[j], v.val[j]);
sumi = ggml_vdotq_s32(sumi, m1, tmp);
}
accd[iy] = vfmaq_f32(accd[iy], vdupq_n_f32(q8.scale(iy, i)), vcvtq_f32_s32(sumi));