author    Iwan Kawrakow <iwan.kawrakow@gmail.com>  2024-06-17 14:16:24 +0200
committer Iwan Kawrakow <iwan.kawrakow@gmail.com>  2024-06-22 12:02:51 +0300
commit    30a771bd6bdaa78ca79e27b38783cedd000c7840 (patch)
tree      64f76b3f70df06434aa137466222b349c2714fa5 /iqk_mul_mat.cpp
parent    8222c9f3d1e91096ab554f62ffbc384535b1963e (diff)
iq1_bn: better NEON implementation
PP (prompt processing) is decent at 131 t/s (q4_0 reaches 150 t/s). TG (token generation) is better than the last commit but still poor at 33.1 t/s (q4_0 gets 52.3 t/s in comparison). I had to switch to the (0, 1, 2) table: Apple Silicon clearly does not like operations with signs.
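What "switching to the (0, 1, 2) table" means in practice: each ternary weight is stored as a plain 2-bit digit with value 0, 1, or 2, so decoding becomes shift, mask, subtract instead of the compare-and-subtract sign gymnastics of the previous commit. A scalar model of the new decode, for orientation (the function name is mine and the packing is inferred from the shift/mask constants in the diff below):

#include <cstdint>

// Decode one 16-bit grid entry into eight ternary weights in {-1, 0, 1}.
// Digit k occupies bits 2k..2k+1 with a value in {0, 1, 2}.
static inline void decode_iq1bn_entry(uint16_t entry, int8_t out[8]) {
    for (int k = 0; k < 8; ++k) {
        const int q = (entry >> (2 * k)) & 3; // NEON version: vshlq_u8 + vandq_u8
        out[k] = (int8_t)(q - 1);             // NEON version: vsubq_s8(..., m1); no sign ops
    }
}

The NEON kernel below does the same thing sixteen lanes at a time.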
Diffstat (limited to 'iqk_mul_mat.cpp')
-rw-r--r--  iqk_mul_mat.cpp | 56
1 file changed, 33 insertions(+), 23 deletions(-)
diff --git a/iqk_mul_mat.cpp b/iqk_mul_mat.cpp
index 09189fa7..4a2417b4 100644
--- a/iqk_mul_mat.cpp
+++ b/iqk_mul_mat.cpp
@@ -4018,6 +4018,7 @@ template <int nrc> struct Q8_K64 {
Q8_K64(const DataInfo& info) { for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const block_q8_K64 *)info.src1_row(iy); }
inline int8x16x4_t load_quants(int iy, int i) const { return vld1q_s8_x4(y[iy][i].qs); }
+ inline int8x16x2_t load_quants(int iy, int i, int j) const { return vld1q_s8_x2(y[iy][i].qs + 32*j); }
inline float scale(int iy, int i) const { return y[iy][i].d; }
const block_q8_K64 * y[nrc_y];
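The new three-argument loader fetches half a block at a time, so the multi-row path further down can interleave loads with the dot products. A sketch of the layout these loads assume (the real block_q8_K64 definition lives elsewhere in the repo; this shape is inferred from the 64-byte vld1q_s8_x4 load and the scalar d):

#include <arm_neon.h>
#include <cstdint>

// Inferred shape, for illustration only: one float scale plus 64 int8 quants.
struct block_q8_K64_sketch { float d; int8_t qs[64]; };

// Full-block load: four 16-byte vectors covering qs[0..63].
static inline int8x16x4_t load_all(const block_q8_K64_sketch& b) {
    return vld1q_s8_x4(b.qs);
}
// Half-block load: two 16-byte vectors covering qs[32*j .. 32*j + 31], j in {0, 1}.
static inline int8x16x2_t load_half(const block_q8_K64_sketch& b, int j) {
    return vld1q_s8_x2(b.qs + 32*j);
}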
@@ -4041,9 +4042,10 @@ static void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const DataIn
vreinterpretq_u8_u64(uint64x2_t{0x0404040404040404, 0x0505050505050505}),
vreinterpretq_u8_u64(uint64x2_t{0x0606060606060606, 0x0707070707070707}),
};
- const auto shuff1 = vreinterpretq_u8_u64(uint64x2_t{0x0000000000000000, 0x0808080808080808});
- const auto shuff2 = vaddq_u8(shuff1, m1);
- const auto mask1 = vreinterpretq_u8_u64(vdupq_n_u64(0x8040201008040201));
+ const auto shift = vreinterpretq_s8_u32(vdupq_n_u32(0xfafcfe00));
+ const auto qmask = vdupq_n_u8(3);
+ const auto shuff1 = vreinterpretq_u8_u64(uint64x2_t{0x0101010100000000, 0x0909090908080808});
+ const auto mask1 = vreinterpretq_u8_u64(vdupq_n_u64(0x8040201008040201));
const block_iq1_bn * x = (const block_iq1_bn *)((const char *)vx);
typedef union { float f; uint32_t i; } scale_t;
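For the constants just introduced: shuff1 replicates each grid byte into four adjacent lanes, shift (bytes 0x00, 0xfe, 0xfc, 0xfa, i.e. 0, -2, -4, -6, since vshlq_u8 shifts right for negative counts) lines up a different 2-bit digit in each lane, and qmask isolates it. A standalone sketch for a single entry broadcast into both 64-bit halves (in the kernel each half holds a different grid entry):

#include <arm_neon.h>
#include <cstdint>

// Expand one grid entry into sixteen int8 lanes in {-1, 0, 1}.
static inline int8x16_t expand_entry(uint64_t entry) {
    const uint8x16_t shuff1 = vreinterpretq_u8_u64(
        uint64x2_t{0x0101010100000000, 0x0909090908080808});  // bytes 0,1 (and 8,9) x4 each
    const int8x16_t  shift  = vreinterpretq_s8_u32(vdupq_n_u32(0xfafcfe00)); // shifts 0,-2,-4,-6
    const uint8x16_t qmask  = vdupq_n_u8(3);
    const uint8x16_t bytes  = vqtbl1q_u8(vreinterpretq_u8_u64(vdupq_n_u64(entry)), shuff1);
    const uint8x16_t q      = vandq_u8(vshlq_u8(bytes, shift), qmask);       // digits in {0, 1, 2}
    return vsubq_s8(vreinterpretq_s8_u8(q), vdupq_n_s8(1));                  // map to {-1, 0, 1}
}

This replaces the old path's two table lookups, two compares against mask1, and a signed subtract per vector.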
@@ -4069,29 +4071,37 @@ static void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const DataIn
auto ql = x[i].ql;
auto qh = x[i].qh;
- a.val[0] = uint64x2_t{iq1bn_grid_xxx[ql[0] | ((qh[0] << 8) & 0x0f00)], iq1bn_grid_xxx[ql[1] | ((qh[0] << 4) & 0x0f00)]};
- a.val[1] = uint64x2_t{iq1bn_grid_xxx[ql[2] | ((qh[1] << 8) & 0x0f00)], iq1bn_grid_xxx[ql[3] | ((qh[1] << 4) & 0x0f00)]};
- a.val[2] = uint64x2_t{iq1bn_grid_xxx[ql[4] | ((qh[2] << 8) & 0x0f00)], iq1bn_grid_xxx[ql[5] | ((qh[2] << 4) & 0x0f00)]};
- a.val[3] = uint64x2_t{iq1bn_grid_xxx[ql[6] | ((qh[3] << 8) & 0x0f00)], iq1bn_grid_xxx[ql[7] | ((qh[3] << 4) & 0x0f00)]};
-
- v.val[0] = vsubq_s8(vreinterpretq_s8_u8(vceqq_u8(vandq_u8(vqtbl1q_u8(vreinterpretq_u8_u64(a.val[0]), shuff2), mask1), mask1)),
- vreinterpretq_s8_u8(vceqq_u8(vandq_u8(vqtbl1q_u8(vreinterpretq_u8_u64(a.val[0]), shuff1), mask1), mask1)));
- v.val[1] = vsubq_s8(vreinterpretq_s8_u8(vceqq_u8(vandq_u8(vqtbl1q_u8(vreinterpretq_u8_u64(a.val[1]), shuff2), mask1), mask1)),
- vreinterpretq_s8_u8(vceqq_u8(vandq_u8(vqtbl1q_u8(vreinterpretq_u8_u64(a.val[1]), shuff1), mask1), mask1)));
- v.val[2] = vsubq_s8(vreinterpretq_s8_u8(vceqq_u8(vandq_u8(vqtbl1q_u8(vreinterpretq_u8_u64(a.val[2]), shuff2), mask1), mask1)),
- vreinterpretq_s8_u8(vceqq_u8(vandq_u8(vqtbl1q_u8(vreinterpretq_u8_u64(a.val[2]), shuff1), mask1), mask1)));
- v.val[3] = vsubq_s8(vreinterpretq_s8_u8(vceqq_u8(vandq_u8(vqtbl1q_u8(vreinterpretq_u8_u64(a.val[3]), shuff2), mask1), mask1)),
- vreinterpretq_s8_u8(vceqq_u8(vandq_u8(vqtbl1q_u8(vreinterpretq_u8_u64(a.val[3]), shuff1), mask1), mask1)));
-
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto q = q8.load_quants(iy, i);
+ a.val[0] = uint64x2_t{iq1bn_grid_u16[ql[0] | ((qh[0] << 8) & 0x0f00)], iq1bn_grid_u16[ql[1] | ((qh[0] << 4) & 0x0f00)]};
+ a.val[1] = uint64x2_t{iq1bn_grid_u16[ql[2] | ((qh[1] << 8) & 0x0f00)], iq1bn_grid_u16[ql[3] | ((qh[1] << 4) & 0x0f00)]};
+ a.val[2] = uint64x2_t{iq1bn_grid_u16[ql[4] | ((qh[2] << 8) & 0x0f00)], iq1bn_grid_u16[ql[5] | ((qh[2] << 4) & 0x0f00)]};
+ a.val[3] = uint64x2_t{iq1bn_grid_u16[ql[6] | ((qh[3] << 8) & 0x0f00)], iq1bn_grid_u16[ql[7] | ((qh[3] << 4) & 0x0f00)]};
+
+ v.val[0] = vsubq_s8(vandq_u8(vshlq_u8(vqtbl1q_u8(vreinterpretq_u8_u64(a.val[0]), shuff1), shift), qmask), m1);
+ v.val[1] = vsubq_s8(vandq_u8(vshlq_u8(vqtbl1q_u8(vreinterpretq_u8_u64(a.val[1]), shuff1), shift), qmask), m1);
+ v.val[2] = vsubq_s8(vandq_u8(vshlq_u8(vqtbl1q_u8(vreinterpretq_u8_u64(a.val[2]), shuff1), shift), qmask), m1);
+ v.val[3] = vsubq_s8(vandq_u8(vshlq_u8(vqtbl1q_u8(vreinterpretq_u8_u64(a.val[3]), shuff1), shift), qmask), m1);
+
+ v.val[0] = vmulq_s8(v.val[0], signs.val[0]);
+ v.val[1] = vmulq_s8(v.val[1], signs.val[1]);
+ v.val[2] = vmulq_s8(v.val[2], signs.val[2]);
+ v.val[3] = vmulq_s8(v.val[3], signs.val[3]);
+
+ if constexpr (nrc_y == 1) {
+ auto q = q8.load_quants(0, i);
int32x4_t sumi = vdupq_n_s32(0);
for (int j = 0; j < 4; ++j) {
- auto tmp = vmulq_s8(q.val[j], vreinterpretq_s8_u8(signs.val[j]));
- tmp = vmulq_s8(tmp, v.val[j]);
- sumi = ggml_vdotq_s32(sumi, m1, tmp);
+ sumi = ggml_vdotq_s32(sumi, q.val[j], v.val[j]);
+ }
+ accd[0] = vfmaq_f32(accd[0], vdupq_n_f32(q8.scale(0, i)), vcvtq_f32_s32(sumi));
+ } else {
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ int32x4_t sumi = vdupq_n_s32(0);
+ auto q = q8.load_quants(iy, i, 0);
+ sumi = ggml_vdotq_s32(ggml_vdotq_s32(sumi, q.val[0], v.val[0]), q.val[1], v.val[1]);
+ q = q8.load_quants(iy, i, 1);
+ sumi = ggml_vdotq_s32(ggml_vdotq_s32(sumi, q.val[0], v.val[2]), q.val[1], v.val[3]);
+ accd[iy] = vfmaq_f32(accd[iy], vdupq_n_f32(q8.scale(iy, i)), vcvtq_f32_s32(sumi));
}
- accd[iy] = vfmaq_f32(accd[iy], vdupq_n_f32(q8.scale(iy, i)), vcvtq_f32_s32(sumi));
}
}
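And the reshaped accumulation: with the signs already folded into v once per block (rather than per y row, as before), each row needs only four dot products and one FMA per superblock. A condensed model of the single-row path (assumes the dotprod extension; in the repo, ggml_vdotq_s32 falls back to widening multiply-adds where SDOT is unavailable):

#include <arm_neon.h>

// One superblock for one y row: sumi gathers 64 int8 products across
// four lanes, then the float accumulator picks up scale * sumi.
static inline float32x4_t step(float32x4_t acc, float scale,
                               const int8x16x4_t& q,   // q8 activations
                               const int8x16x4_t& v) { // decoded, sign-applied weights
    int32x4_t sumi = vdupq_n_s32(0);
    for (int j = 0; j < 4; ++j)
        sumi = vdotq_s32(sumi, q.val[j], v.val[j]);
    return vfmaq_f32(acc, vdupq_n_f32(scale), vcvtq_f32_s32(sumi));
}

The multi-row branch follows the same pattern but uses the new half-block loader so each pair of loads feeds the next pair of dot products.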