summaryrefslogtreecommitdiff
path: root/iqk_mul_mat.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'iqk_mul_mat.cpp')
-rw-r--r--iqk_mul_mat.cpp387
1 files changed, 266 insertions, 121 deletions
diff --git a/iqk_mul_mat.cpp b/iqk_mul_mat.cpp
index 0cee0ff4..5148d184 100644
--- a/iqk_mul_mat.cpp
+++ b/iqk_mul_mat.cpp
@@ -1318,27 +1318,56 @@ template <int nrc> struct Q8_K64 {
constexpr static int nrc_y = nrc;
- Q8_K64(const DataInfo& info) { for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const block_q8_K64 *)info.src1_row(iy); }
+ Q8_K64(const DataInfo& info) { for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const block_q8_K128 *)info.src1_row(iy); }
inline __m256i load_quants(int iy, int i, int j) const { return _mm256_loadu_si256((const __m256i*)y[iy][i].qs + j); }
inline float scale(int iy, int i) const { return y[iy][i].d; }
- const block_q8_K64 * y[nrc_y];
+ const block_q8_K128 * y[nrc_y];
+};
+
+struct DequantizerIQ1BN {
+ const __m256i m1_8 = _mm256_set1_epi8(1);
+ const __m256i shuff1 = _mm256_set_epi64x(0x0808080808080808, 0x0000000000000000, 0x0808080808080808, 0x0000000000000000);
+ const __m256i shuff2 = _mm256_add_epi8(shuff1, m1_8);
+ const __m256i shuff3 = _mm256_set_epi64x(0x0303030303030303, 0x0202020202020202, 0x0101010101010101, 0x0000000000000000);
+ const __m256i shuff4 = _mm256_set_epi64x(0x0707070707070707, 0x0606060606060606, 0x0505050505050505, 0x0404040404040404);
+ const __m256i mask1 = _mm256_set1_epi64x(0x8040201008040201);
+ //__m256i signs[2];
+
+ IQK_ALWAYS_INLINE void prepare_iq1bn_quants(uint8_t extra, const uint8_t * ql, const uint8_t * qh, __m256i& v1, __m256i& v2) {
+ //auto all_signs = _mm256_set1_epi8(extra);
+ //all_signs = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(all_signs, mask1), mask1), m1_8);
+ //signs[0] = _mm256_shuffle_epi8(all_signs, shuff3);
+ //signs[1] = _mm256_shuffle_epi8(all_signs, shuff4);
+
+ auto aux1 = _mm256_set_epi64x(iq1bn_grid_xxx[ql[3] | ((qh[1] << 4) & 0x0f00)], iq1bn_grid_xxx[ql[2] | ((qh[1] << 8) & 0x0f00)],
+ iq1bn_grid_xxx[ql[1] | ((qh[0] << 4) & 0x0f00)], iq1bn_grid_xxx[ql[0] | ((qh[0] << 8) & 0x0f00)]);
+ auto aux2 = _mm256_set_epi64x(iq1bn_grid_xxx[ql[7] | ((qh[3] << 4) & 0x0f00)], iq1bn_grid_xxx[ql[6] | ((qh[3] << 8) & 0x0f00)],
+ iq1bn_grid_xxx[ql[5] | ((qh[2] << 4) & 0x0f00)], iq1bn_grid_xxx[ql[4] | ((qh[2] << 8) & 0x0f00)]);
+
+ v1 = _mm256_sub_epi8(_mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(aux1, shuff2), mask1), mask1),
+ _mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(aux1, shuff1), mask1), mask1));
+ v2 = _mm256_sub_epi8(_mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(aux2, shuff2), mask1), mask1),
+ _mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(aux2, shuff1), mask1), mask1));
+ //v1 = _mm256_sign_epi8(v1, signs[0]);
+ //v2 = _mm256_sign_epi8(v2, signs[1]);
+
+ auto all_signs = _mm256_set1_epi8(extra);
+ all_signs = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(all_signs, mask1), mask1), m1_8);
+ v1 = _mm256_sign_epi8(v1, _mm256_shuffle_epi8(all_signs, shuff3));
+ v2 = _mm256_sign_epi8(v2, _mm256_shuffle_epi8(all_signs, shuff4));
+ }
};
template <int nrc_y>
IQK_NOINLINE void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
const int nb = n / QK_IQ1BN;
Q8_K64<nrc_y> q8(info);
+ DequantizerIQ1BN deq;
__m256 accd[nrc_y];
- __m256i signs[2];
+ __m256i val[4];
- const auto m1_8 = _mm256_set1_epi8(1);
- const auto shuff1 = _mm256_set_epi64x(0x0808080808080808, 0x0000000000000000, 0x0808080808080808, 0x0000000000000000);
- const auto shuff2 = _mm256_add_epi8(shuff1, m1_8);
- const auto shuff3 = _mm256_set_epi64x(0x0303030303030303, 0x0202020202020202, 0x0101010101010101, 0x0000000000000000);
- const auto shuff4 = _mm256_set_epi64x(0x0707070707070707, 0x0606060606060606, 0x0505050505050505, 0x0404040404040404);
- const auto mask1 = _mm256_set1_epi64x(0x8040201008040201);
#if !(defined __AVX512VNNI__ && defined __AVX512VL__)
const auto m1_16 = _mm256_set1_epi16(1);
#endif
@@ -1351,47 +1380,43 @@ IQK_NOINLINE void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const
for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm256_setzero_ps();
- for (int i = 0; i < nb; ++i) {
+ for (int i = 0; i < nb/2; ++i) {
- auto all_signs = _mm256_set1_epi8(x[i].extra);
- all_signs = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(all_signs, mask1), mask1), m1_8);
- signs[0] = _mm256_shuffle_epi8(all_signs, shuff3);
- signs[1] = _mm256_shuffle_epi8(all_signs, shuff4);
+ deq.prepare_iq1bn_quants(x[2*i+0].extra, x[2*i+0].ql, x[2*i+0].qh, val[0], val[1]);
+ deq.prepare_iq1bn_quants(x[2*i+1].extra, x[2*i+1].ql, x[2*i+1].qh, val[2], val[3]);
- auto ql = x[i].ql;
- auto qh = x[i].qh;
- auto aux1 = _mm256_set_epi64x(iq1bn_grid_xxx[ql[3] | ((qh[1] << 4) & 0x0f00)], iq1bn_grid_xxx[ql[2] | ((qh[1] << 8) & 0x0f00)],
- iq1bn_grid_xxx[ql[1] | ((qh[0] << 4) & 0x0f00)], iq1bn_grid_xxx[ql[0] | ((qh[0] << 8) & 0x0f00)]);
- auto aux2 = _mm256_set_epi64x(iq1bn_grid_xxx[ql[7] | ((qh[3] << 4) & 0x0f00)], iq1bn_grid_xxx[ql[6] | ((qh[3] << 8) & 0x0f00)],
- iq1bn_grid_xxx[ql[5] | ((qh[2] << 4) & 0x0f00)], iq1bn_grid_xxx[ql[4] | ((qh[2] << 8) & 0x0f00)]);
-
- auto v1 = _mm256_sub_epi8(_mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(aux1, shuff2), mask1), mask1),
- _mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(aux1, shuff1), mask1), mask1));
- auto v2 = _mm256_sub_epi8(_mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(aux2, shuff2), mask1), mask1),
- _mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(aux2, shuff1), mask1), mask1));
-
- if constexpr (nrc_y == 1) {
- auto dot1 = _mm256_sign_epi8(_mm256_sign_epi8(q8.load_quants(0, i, 0), signs[0]), v1);
- auto dot2 = _mm256_sign_epi8(_mm256_sign_epi8(q8.load_quants(0, i, 1), signs[1]), v2);
+ for (int iy = 0; iy < nrc_y; ++iy) {
#if defined __AVX512VNNI__ && defined __AVX512VL__
- auto dot = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_setzero_si256(), m1_8, dot1), m1_8, dot2);
+ auto dot1 = _mm256_sign_epi8(q8.load_quants(iy, i, 0), val[0]);
+ auto dot2 = _mm256_sign_epi8(q8.load_quants(iy, i, 1), val[1]);
+ auto dot3 = _mm256_sign_epi8(q8.load_quants(iy, i, 2), val[2]);
+ auto dot4 = _mm256_sign_epi8(q8.load_quants(iy, i, 3), val[3]);
+ auto dot = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_dpbusd_epi32(
+ _mm256_setzero_si256(), deq.m1_8, dot1), deq.m1_8, dot2), deq.m1_8, dot3), deq.m1_8, dot4);
+ accd[iy] = _mm256_fmadd_ps(_mm256_set1_ps(q8.scale(iy, i)), _mm256_cvtepi32_ps(dot), accd[iy]);
#else
- auto dot = _mm256_madd_epi16(m1_16, _mm256_add_epi16(_mm256_maddubs_epi16(m1_8, dot1), _mm256_maddubs_epi16(m1_8, dot2)));
+ auto dot1 = _mm256_add_epi16(_mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(iy, i, 0), val[0])),
+ _mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(iy, i, 1), val[1])));
+ auto dot2 = _mm256_add_epi16(_mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(iy, i, 2), val[2])),
+ _mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(iy, i, 3), val[3])));
+ dot1 = _mm256_madd_epi16(m1_16, _mm256_add_epi16(dot1, dot2));
+ accd[iy] = _mm256_fmadd_ps(_mm256_set1_ps(q8.scale(iy, i)), _mm256_cvtepi32_ps(dot1), accd[iy]);
#endif
- accd[0] = _mm256_fmadd_ps(_mm256_set1_ps(q8.scale(0, i)), _mm256_cvtepi32_ps(dot), accd[0]);
- } else {
- v1 = _mm256_sign_epi8(v1, signs[0]);
- v2 = _mm256_sign_epi8(v2, signs[1]);
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto dot1 = _mm256_sign_epi8(q8.load_quants(iy, i, 0), v1);
- auto dot2 = _mm256_sign_epi8(q8.load_quants(iy, i, 1), v2);
+ }
+ }
+ int i = 2*(nb/2);
+ if (i < nb) {
+ deq.prepare_iq1bn_quants(x[i].extra, x[i].ql, x[i].qh, val[0], val[1]);
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto dot1 = _mm256_sign_epi8(q8.load_quants(iy, i/2, 0), val[0]);
+ auto dot2 = _mm256_sign_epi8(q8.load_quants(iy, i/2, 1), val[1]);
#if defined __AVX512VNNI__ && defined __AVX512VL__
- auto dot = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_setzero_si256(), m1_8, dot1), m1_8, dot2);
+ auto dot = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_setzero_si256(), deq.m1_8, dot1), deq.m1_8, dot2);
#else
- auto dot = _mm256_madd_epi16(m1_16, _mm256_add_epi16(_mm256_maddubs_epi16(m1_8, dot1), _mm256_maddubs_epi16(m1_8, dot2)));
+ auto dot = _mm256_madd_epi16(m1_16,
+ _mm256_add_epi16(_mm256_maddubs_epi16(deq.m1_8, dot1), _mm256_maddubs_epi16(deq.m1_8, dot2)));
#endif
- accd[iy] = _mm256_fmadd_ps(_mm256_set1_ps(q8.scale(iy, i)), _mm256_cvtepi32_ps(dot), accd[iy]);
- }
+ accd[iy] = _mm256_fmadd_ps(_mm256_set1_ps(q8.scale(iy, i/2)), _mm256_cvtepi32_ps(dot), accd[iy]);
}
}
@@ -1419,38 +1444,73 @@ IQK_NOINLINE void mul_mat_iq2bn_q8_K64(int n, const void * vx, size_t bx, const
const block_iq2_bn * x = (const block_iq2_bn *)((const char *)vx + ix*bx);
{
- auto q2bits = _mm_loadu_si128((const __m128i *)x[0].qs);
- auto q2 = MM256_SET_M128I(_mm_srli_epi16(q2bits, 2), q2bits);
- auto v1 = _mm256_sub_epi8(_mm256_and_si256(q2, mask2), m1_8);
- auto v2 = _mm256_sub_epi8(_mm256_and_si256(_mm256_srli_epi16(q2, 4), mask2), m1_8);
+ auto q2bits_1 = _mm_loadu_si128((const __m128i *)x[0].qs);
+ auto q2bits_2 = _mm_loadu_si128((const __m128i *)x[1].qs);
+ auto q2_1 = MM256_SET_M128I(_mm_srli_epi16(q2bits_1, 2), q2bits_1);
+ auto q2_2 = MM256_SET_M128I(_mm_srli_epi16(q2bits_2, 2), q2bits_2);
+ auto v1 = _mm256_sub_epi8(_mm256_and_si256(q2_1, mask2), m1_8);
+ auto v3 = _mm256_sub_epi8(_mm256_and_si256(q2_2, mask2), m1_8);
+ auto v2 = _mm256_sub_epi8(_mm256_and_si256(_mm256_srli_epi16(q2_1, 4), mask2), m1_8);
+ auto v4 = _mm256_sub_epi8(_mm256_and_si256(_mm256_srli_epi16(q2_2, 4), mask2), m1_8);
for (int iy = 0; iy < nrc_y; ++iy) {
auto dot1 = _mm256_sign_epi8(q8.load_quants(iy, 0, 0), v1);
auto dot2 = _mm256_sign_epi8(q8.load_quants(iy, 0, 1), v2);
+ auto dot3 = _mm256_sign_epi8(q8.load_quants(iy, 0, 2), v3);
+ auto dot4 = _mm256_sign_epi8(q8.load_quants(iy, 0, 3), v4);
#if defined __AVX512VNNI__ && defined __AVX512VL__
- auto dot = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_setzero_si256(), m1_8, dot1), m1_8, dot2);
+ auto dot = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_dpbusd_epi32(
+ _mm256_setzero_si256(), m1_8, dot1), m1_8, dot2), m1_8, dot3), m1_8, dot4);
#else
- auto dot = _mm256_madd_epi16(m1_16, _mm256_add_epi16(_mm256_maddubs_epi16(m1_8, dot1), _mm256_maddubs_epi16(m1_8, dot2)));
+ auto dot = _mm256_madd_epi16(m1_16, _mm256_add_epi16(
+ _mm256_add_epi16(_mm256_maddubs_epi16(m1_8, dot1), _mm256_maddubs_epi16(m1_8, dot2)),
+ _mm256_add_epi16(_mm256_maddubs_epi16(m1_8, dot3), _mm256_maddubs_epi16(m1_8, dot4))));
#endif
accd[iy] = _mm256_mul_ps(_mm256_set1_ps(q8.scale(iy, 0)), _mm256_cvtepi32_ps(dot));
}
}
- for (int i = 1; i < nb; ++i) {
- auto q2bits = _mm_loadu_si128((const __m128i *)x[i].qs);
- auto q2 = MM256_SET_M128I(_mm_srli_epi16(q2bits, 2), q2bits);
- auto v1 = _mm256_sub_epi8(_mm256_and_si256(q2, mask2), m1_8);
- auto v2 = _mm256_sub_epi8(_mm256_and_si256(_mm256_srli_epi16(q2, 4), mask2), m1_8);
+ for (int i = 1; i < nb/2; ++i) {
+ auto q2bits_1 = _mm_loadu_si128((const __m128i *)x[2*i+0].qs);
+ auto q2bits_2 = _mm_loadu_si128((const __m128i *)x[2*i+1].qs);
+ auto q2_1 = MM256_SET_M128I(_mm_srli_epi16(q2bits_1, 2), q2bits_1);
+ auto q2_2 = MM256_SET_M128I(_mm_srli_epi16(q2bits_2, 2), q2bits_2);
+ auto v1 = _mm256_sub_epi8(_mm256_and_si256(q2_1, mask2), m1_8);
+ auto v3 = _mm256_sub_epi8(_mm256_and_si256(q2_2, mask2), m1_8);
+ auto v2 = _mm256_sub_epi8(_mm256_and_si256(_mm256_srli_epi16(q2_1, 4), mask2), m1_8);
+ auto v4 = _mm256_sub_epi8(_mm256_and_si256(_mm256_srli_epi16(q2_2, 4), mask2), m1_8);
for (int iy = 0; iy < nrc_y; ++iy) {
auto dot1 = _mm256_sign_epi8(q8.load_quants(iy, i, 0), v1);
auto dot2 = _mm256_sign_epi8(q8.load_quants(iy, i, 1), v2);
+ auto dot3 = _mm256_sign_epi8(q8.load_quants(iy, i, 2), v3);
+ auto dot4 = _mm256_sign_epi8(q8.load_quants(iy, i, 3), v4);
#if defined __AVX512VNNI__ && defined __AVX512VL__
- auto dot = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_setzero_si256(), m1_8, dot1), m1_8, dot2);
+ auto dot = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_dpbusd_epi32(
+ _mm256_setzero_si256(), m1_8, dot1), m1_8, dot2), m1_8, dot3), m1_8, dot4);
#else
- auto dot = _mm256_madd_epi16(m1_16, _mm256_add_epi16(_mm256_maddubs_epi16(m1_8, dot1), _mm256_maddubs_epi16(m1_8, dot2)));
+ auto dot = _mm256_madd_epi16(m1_16, _mm256_add_epi16(
+ _mm256_add_epi16(_mm256_maddubs_epi16(m1_8, dot1), _mm256_maddubs_epi16(m1_8, dot2)),
+ _mm256_add_epi16(_mm256_maddubs_epi16(m1_8, dot3), _mm256_maddubs_epi16(m1_8, dot4))));
#endif
accd[iy] = _mm256_fmadd_ps(_mm256_set1_ps(q8.scale(iy, i)), _mm256_cvtepi32_ps(dot), accd[iy]);
}
}
+ int i = 2*(nb/2);
+ if (i < nb) {
+ auto q2bits = _mm_loadu_si128((const __m128i *)x[i].qs);
+ auto q2_1 = MM256_SET_M128I(_mm_srli_epi16(q2bits, 2), q2bits);
+ auto v1 = _mm256_sub_epi8(_mm256_and_si256(q2_1, mask2), m1_8);
+ auto v2 = _mm256_sub_epi8(_mm256_and_si256(_mm256_srli_epi16(q2_1, 4), mask2), m1_8);
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto dot1 = _mm256_sign_epi8(q8.load_quants(iy, i/2, 0), v1);
+ auto dot2 = _mm256_sign_epi8(q8.load_quants(iy, i/2, 1), v2);
+#if defined __AVX512VNNI__ && defined __AVX512VL__
+ dot1 = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_setzero_si256(), m1_8, dot1), m1_8, dot2);
+#else
+ dot1 = _mm256_madd_epi16(m1_16, _mm256_add_epi16(_mm256_maddubs_epi16(m1_8, dot1), _mm256_maddubs_epi16(m1_8, dot2)));
+#endif
+ accd[iy] = _mm256_fmadd_ps(_mm256_set1_ps(q8.scale(iy, i/2)), _mm256_cvtepi32_ps(dot1), accd[iy]);
+ }
+ }
for (int iy = 0; iy < nrc_y; ++iy) {
info.store(ix, iy, hsum_float_8(accd[iy]));
@@ -4089,13 +4149,52 @@ template <int nrc> struct Q8_K64 {
constexpr static int nrc_y = nrc;
- Q8_K64(const DataInfo& info) { for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const block_q8_K64 *)info.src1_row(iy); }
+ Q8_K64(const DataInfo& info) { for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const block_q8_K128 *)info.src1_row(iy); }
- inline int8x16x4_t load_quants(int iy, int i) const { return vld1q_s8_x4(y[iy][i].qs); }
+ inline int8x16x4_t load_quants64(int iy, int i, int j) const { return vld1q_s8_x4(y[iy][i].qs + 64*j); }
inline int8x16x2_t load_quants(int iy, int i, int j) const { return vld1q_s8_x2(y[iy][i].qs + 32*j); }
inline float scale(int iy, int i) const { return y[iy][i].d; }
- const block_q8_K64 * y[nrc_y];
+ const block_q8_K128 * y[nrc_y];
+};
+
+struct DequantizerIQ1BN {
+ const uint8x16_t m1 = vdupq_n_u8(1);
+ const uint8x16x4_t sign_shuffles = {
+ vreinterpretq_u8_u64(uint64x2_t{0x0000000000000000, 0x0101010101010101}),
+ vreinterpretq_u8_u64(uint64x2_t{0x0202020202020202, 0x0303030303030303}),
+ vreinterpretq_u8_u64(uint64x2_t{0x0404040404040404, 0x0505050505050505}),
+ vreinterpretq_u8_u64(uint64x2_t{0x0606060606060606, 0x0707070707070707}),
+ };
+ const int8x16_t shift = vreinterpretq_s8_u32(vdupq_n_u32(0xfafcfe00));
+ const uint8x16_t qmask = vdupq_n_u8(3);
+ const uint8x16_t shuff1 = vreinterpretq_u8_u64(uint64x2_t{0x0101010100000000, 0x0909090908080808});
+ const uint8x16_t mask1 = vreinterpretq_u8_u64(vdupq_n_u64(0x8040201008040201));
+ int8x16x4_t signs;
+ uint64x2x4_t a;
+ inline void prepare_iq1bn_quants(uint8_t extra, const uint8_t * ql, const uint8_t * qh, int8x16x4_t& v) {
+ auto all_signs = vdupq_n_u8(extra);
+ all_signs = vorrq_u8(vceqq_u8(vandq_u8(all_signs, mask1), mask1), m1);
+ signs.val[0] = vqtbl1q_u8(all_signs, sign_shuffles.val[0]);
+ signs.val[1] = vqtbl1q_u8(all_signs, sign_shuffles.val[1]);
+ signs.val[2] = vqtbl1q_u8(all_signs, sign_shuffles.val[2]);
+ signs.val[3] = vqtbl1q_u8(all_signs, sign_shuffles.val[3]);
+
+ a.val[0] = uint64x2_t{iq1bn_grid_u16[ql[0] | ((qh[0] << 8) & 0x0f00)], iq1bn_grid_u16[ql[1] | ((qh[0] << 4) & 0x0f00)]};
+ a.val[1] = uint64x2_t{iq1bn_grid_u16[ql[2] | ((qh[1] << 8) & 0x0f00)], iq1bn_grid_u16[ql[3] | ((qh[1] << 4) & 0x0f00)]};
+ a.val[2] = uint64x2_t{iq1bn_grid_u16[ql[4] | ((qh[2] << 8) & 0x0f00)], iq1bn_grid_u16[ql[5] | ((qh[2] << 4) & 0x0f00)]};
+ a.val[3] = uint64x2_t{iq1bn_grid_u16[ql[6] | ((qh[3] << 8) & 0x0f00)], iq1bn_grid_u16[ql[7] | ((qh[3] << 4) & 0x0f00)]};
+
+ v.val[0] = vsubq_s8(vandq_u8(vshlq_u8(vqtbl1q_u8(vreinterpretq_u8_u64(a.val[0]), shuff1), shift), qmask), m1);
+ v.val[1] = vsubq_s8(vandq_u8(vshlq_u8(vqtbl1q_u8(vreinterpretq_u8_u64(a.val[1]), shuff1), shift), qmask), m1);
+ v.val[2] = vsubq_s8(vandq_u8(vshlq_u8(vqtbl1q_u8(vreinterpretq_u8_u64(a.val[2]), shuff1), shift), qmask), m1);
+ v.val[3] = vsubq_s8(vandq_u8(vshlq_u8(vqtbl1q_u8(vreinterpretq_u8_u64(a.val[3]), shuff1), shift), qmask), m1);
+
+ v.val[0] = vmulq_s8(v.val[0], signs.val[0]);
+ v.val[1] = vmulq_s8(v.val[1], signs.val[1]);
+ v.val[2] = vmulq_s8(v.val[2], signs.val[2]);
+ v.val[3] = vmulq_s8(v.val[3], signs.val[3]);
+ }
};
template <int nrc_y>
@@ -4103,23 +4202,10 @@ static void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const DataIn
const int nb = n / QK_IQ1BN;
Q8_K64<nrc_y> q8(info);
+ DequantizerIQ1BN deq;
float32x4_t accd[nrc_y];
- int8x16x4_t signs;
- uint64x2x4_t a;
- int8x16x4_t v;
-
- const auto m1 = vdupq_n_u8(1);
- const uint8x16x4_t sign_shuffles = {
- vreinterpretq_u8_u64(uint64x2_t{0x0000000000000000, 0x0101010101010101}),
- vreinterpretq_u8_u64(uint64x2_t{0x0202020202020202, 0x0303030303030303}),
- vreinterpretq_u8_u64(uint64x2_t{0x0404040404040404, 0x0505050505050505}),
- vreinterpretq_u8_u64(uint64x2_t{0x0606060606060606, 0x0707070707070707}),
- };
- const auto shift = vreinterpretq_s8_u32(vdupq_n_u32(0xfafcfe00));
- const auto qmask = vdupq_n_u8(3);
- const auto shuff1 = vreinterpretq_u8_u64(uint64x2_t{0x0101010100000000, 0x0909090908080808});
- const auto mask1 = vreinterpretq_u8_u64(vdupq_n_u64(0x8040201008040201));
+ int8x16x4_t v1, v2;
const block_iq1_bn * x = (const block_iq1_bn *)((const char *)vx);
@@ -4129,50 +4215,57 @@ static void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const DataIn
for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = vdupq_n_f32(0.f);
- for (int i = 0; i < nb; ++i) {
+ for (int i = 0; i < nb/2; ++i) {
- auto all_signs = vdupq_n_u8(x[i].extra);
- all_signs = vorrq_u8(vceqq_u8(vandq_u8(all_signs, mask1), mask1), m1);
- signs.val[0] = vqtbl1q_u8(all_signs, sign_shuffles.val[0]);
- signs.val[1] = vqtbl1q_u8(all_signs, sign_shuffles.val[1]);
- signs.val[2] = vqtbl1q_u8(all_signs, sign_shuffles.val[2]);
- signs.val[3] = vqtbl1q_u8(all_signs, sign_shuffles.val[3]);
-
- auto ql = x[i].ql;
- auto qh = x[i].qh;
- a.val[0] = uint64x2_t{iq1bn_grid_u16[ql[0] | ((qh[0] << 8) & 0x0f00)], iq1bn_grid_u16[ql[1] | ((qh[0] << 4) & 0x0f00)]};
- a.val[1] = uint64x2_t{iq1bn_grid_u16[ql[2] | ((qh[1] << 8) & 0x0f00)], iq1bn_grid_u16[ql[3] | ((qh[1] << 4) & 0x0f00)]};
- a.val[2] = uint64x2_t{iq1bn_grid_u16[ql[4] | ((qh[2] << 8) & 0x0f00)], iq1bn_grid_u16[ql[5] | ((qh[2] << 4) & 0x0f00)]};
- a.val[3] = uint64x2_t{iq1bn_grid_u16[ql[6] | ((qh[3] << 8) & 0x0f00)], iq1bn_grid_u16[ql[7] | ((qh[3] << 4) & 0x0f00)]};
-
- v.val[0] = vsubq_s8(vandq_u8(vshlq_u8(vqtbl1q_u8(vreinterpretq_u8_u64(a.val[0]), shuff1), shift), qmask), m1);
- v.val[1] = vsubq_s8(vandq_u8(vshlq_u8(vqtbl1q_u8(vreinterpretq_u8_u64(a.val[1]), shuff1), shift), qmask), m1);
- v.val[2] = vsubq_s8(vandq_u8(vshlq_u8(vqtbl1q_u8(vreinterpretq_u8_u64(a.val[2]), shuff1), shift), qmask), m1);
- v.val[3] = vsubq_s8(vandq_u8(vshlq_u8(vqtbl1q_u8(vreinterpretq_u8_u64(a.val[3]), shuff1), shift), qmask), m1);
-
- v.val[0] = vmulq_s8(v.val[0], signs.val[0]);
- v.val[1] = vmulq_s8(v.val[1], signs.val[1]);
- v.val[2] = vmulq_s8(v.val[2], signs.val[2]);
- v.val[3] = vmulq_s8(v.val[3], signs.val[3]);
+ deq.prepare_iq1bn_quants(x[2*i+0].extra, x[2*i+0].ql, x[2*i+0].qh, v1);
+ deq.prepare_iq1bn_quants(x[2*i+1].extra, x[2*i+1].ql, x[2*i+1].qh, v2);
+ int32x4_t sumi1 = vdupq_n_s32(0);
+ int32x4_t sumi2 = vdupq_n_s32(0);
if constexpr (nrc_y == 1) {
- auto q = q8.load_quants(0, i);
- int32x4_t sumi = vdupq_n_s32(0);
+ auto q1 = q8.load_quants64(0, i, 0);
+ auto q2 = q8.load_quants64(0, i, 1);
for (int j = 0; j < 4; ++j) {
- sumi = ggml_vdotq_s32(sumi, q.val[j], v.val[j]);
+ sumi1 = ggml_vdotq_s32(sumi1, q1.val[j], v1.val[j]);
+ sumi2 = ggml_vdotq_s32(sumi2, q2.val[j], v2.val[j]);
}
- accd[0] = vfmaq_f32(accd[0], vdupq_n_f32(q8.scale(0, i)), vcvtq_f32_s32(sumi));
+ accd[0] = vfmaq_f32(accd[0], vdupq_n_f32(q8.scale(0, i)), vcvtq_f32_s32(vaddq_s32(sumi1, sumi2)));
} else {
for (int iy = 0; iy < nrc_y; ++iy) {
int32x4_t sumi = vdupq_n_s32(0);
auto q = q8.load_quants(iy, i, 0);
- sumi = ggml_vdotq_s32(ggml_vdotq_s32(sumi, q.val[0], v.val[0]), q.val[1], v.val[1]);
+ sumi = ggml_vdotq_s32(ggml_vdotq_s32(sumi, q.val[0], v1.val[0]), q.val[1], v1.val[1]);
q = q8.load_quants(iy, i, 1);
- sumi = ggml_vdotq_s32(ggml_vdotq_s32(sumi, q.val[0], v.val[2]), q.val[1], v.val[3]);
+ sumi = ggml_vdotq_s32(ggml_vdotq_s32(sumi, q.val[0], v1.val[2]), q.val[1], v1.val[3]);
+ q = q8.load_quants(iy, i, 2);
+ sumi = ggml_vdotq_s32(ggml_vdotq_s32(sumi, q.val[0], v2.val[0]), q.val[1], v2.val[1]);
+ q = q8.load_quants(iy, i, 3);
+ sumi = ggml_vdotq_s32(ggml_vdotq_s32(sumi, q.val[0], v2.val[2]), q.val[1], v2.val[3]);
accd[iy] = vfmaq_f32(accd[iy], vdupq_n_f32(q8.scale(iy, i)), vcvtq_f32_s32(sumi));
}
}
}
+ int i = 2*(nb/2);
+ if (i < nb) {
+ deq.prepare_iq1bn_quants(x[i].extra, x[i].ql, x[i].qh, v1);
+ if constexpr (nrc_y == 1) {
+ auto q = q8.load_quants(0, i/2, 0);
+ int32x4_t sumi = vdupq_n_s32(0);
+ for (int j = 0; j < 4; ++j) {
+ sumi = ggml_vdotq_s32(sumi, q.val[j], v1.val[j]);
+ }
+ accd[0] = vfmaq_f32(accd[0], vdupq_n_f32(q8.scale(0, i/2)), vcvtq_f32_s32(sumi));
+ } else {
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ int32x4_t sumi = vdupq_n_s32(0);
+ auto q = q8.load_quants(iy, i/2, 0);
+ sumi = ggml_vdotq_s32(ggml_vdotq_s32(sumi, q.val[0], v1.val[0]), q.val[1], v1.val[1]);
+ q = q8.load_quants(iy, i/2, 1);
+ sumi = ggml_vdotq_s32(ggml_vdotq_s32(sumi, q.val[0], v1.val[2]), q.val[1], v1.val[3]);
+ accd[iy] = vfmaq_f32(accd[iy], vdupq_n_f32(q8.scale(iy, i/2)), vcvtq_f32_s32(sumi));
+ }
+ }
+ }
for (int iy = 0; iy < nrc_y; ++iy) {
info.store(ix, iy, vaddvq_f32(accd[iy]));
@@ -4188,7 +4281,7 @@ static void mul_mat_iq2bn_q8_K64(int n, const void * vx, size_t bx, const DataIn
Q8_K64<nrc_y> q8(info);
float32x4_t accd[nrc_y];
- int8x16x4_t v;
+ int8x16x4_t v1, v2;
const auto m1 = vdupq_n_u8(1);
const auto mask2 = vdupq_n_s8(3);
@@ -4199,34 +4292,86 @@ static void mul_mat_iq2bn_q8_K64(int n, const void * vx, size_t bx, const DataIn
{
auto q2bits = vld1q_u8(x[0].qs);
- v.val[0] = vsubq_s8(vandq_s8(q2bits, mask2), m1);
- v.val[1] = vsubq_s8(vandq_s8(vshrq_n_u8(q2bits, 2), mask2), m1);
- v.val[2] = vsubq_s8(vandq_s8(vshrq_n_u8(q2bits, 4), mask2), m1);
- v.val[3] = vsubq_s8(vshrq_n_u8(q2bits, 6), m1);
+ v1.val[0] = vsubq_s8(vandq_s8(q2bits, mask2), m1);
+ v1.val[1] = vsubq_s8(vandq_s8(vshrq_n_u8(q2bits, 2), mask2), m1);
+ v1.val[2] = vsubq_s8(vandq_s8(vshrq_n_u8(q2bits, 4), mask2), m1);
+ v1.val[3] = vsubq_s8(vshrq_n_u8(q2bits, 6), m1);
+ q2bits = vld1q_u8(x[1].qs);
+ v2.val[0] = vsubq_s8(vandq_s8(q2bits, mask2), m1);
+ v2.val[1] = vsubq_s8(vandq_s8(vshrq_n_u8(q2bits, 2), mask2), m1);
+ v2.val[2] = vsubq_s8(vandq_s8(vshrq_n_u8(q2bits, 4), mask2), m1);
+ v2.val[3] = vsubq_s8(vshrq_n_u8(q2bits, 6), m1);
for (int iy = 0; iy < nrc_y; ++iy) {
int32x4_t sumi = vdupq_n_s32(0);
auto q = q8.load_quants(iy, 0, 0);
- sumi = ggml_vdotq_s32(ggml_vdotq_s32(sumi, q.val[0], v.val[0]), q.val[1], v.val[1]);
+ sumi = ggml_vdotq_s32(ggml_vdotq_s32(sumi, q.val[0], v1.val[0]), q.val[1], v1.val[1]);
q = q8.load_quants(iy, 0, 1);
- sumi = ggml_vdotq_s32(ggml_vdotq_s32(sumi, q.val[0], v.val[2]), q.val[1], v.val[3]);
+ sumi = ggml_vdotq_s32(ggml_vdotq_s32(sumi, q.val[0], v1.val[2]), q.val[1], v1.val[3]);
+ q = q8.load_quants(iy, 0, 2);
+ sumi = ggml_vdotq_s32(ggml_vdotq_s32(sumi, q.val[0], v2.val[0]), q.val[1], v2.val[1]);
+ q = q8.load_quants(iy, 0, 3);
+ sumi = ggml_vdotq_s32(ggml_vdotq_s32(sumi, q.val[0], v2.val[2]), q.val[1], v2.val[3]);
accd[iy] = vmulq_f32(vdupq_n_f32(q8.scale(iy, 0)), vcvtq_f32_s32(sumi));
}
}
- for (int i = 1; i < nb; ++i) {
- auto q2bits = vld1q_u8(x[i].qs);
- v.val[0] = vsubq_s8(vandq_s8(q2bits, mask2), m1);
- v.val[1] = vsubq_s8(vandq_s8(vshrq_n_u8(q2bits, 2), mask2), m1);
- v.val[2] = vsubq_s8(vandq_s8(vshrq_n_u8(q2bits, 4), mask2), m1);
- v.val[3] = vsubq_s8(vshrq_n_u8(q2bits, 6), m1);
+ if constexpr (nrc_y == 1) {
+ for (int i = 1; i < nb/2; ++i) {
+ auto sumi1 = vdupq_n_s32(0);
+ auto sumi2 = vdupq_n_s32(0);
+ for (int j = 0; j < 2; ++j) {
+ auto q = q8.load_quants64(0, i, j);
+ auto q2bits = vld1q_u8(x[2*i+j].qs);
+ v1.val[0] = vsubq_s8(vandq_s8(q2bits, mask2), m1);
+ v1.val[1] = vsubq_s8(vandq_s8(vshrq_n_u8(q2bits, 2), mask2), m1);
+ v1.val[2] = vsubq_s8(vandq_s8(vshrq_n_u8(q2bits, 4), mask2), m1);
+ v1.val[3] = vsubq_s8(vshrq_n_u8(q2bits, 6), m1);
+ sumi1 = ggml_vdotq_s32(ggml_vdotq_s32(sumi1, q.val[0], v1.val[0]), q.val[1], v1.val[1]);
+ sumi2 = ggml_vdotq_s32(ggml_vdotq_s32(sumi2, q.val[2], v1.val[2]), q.val[3], v1.val[3]);
+ }
+ accd[0] = vfmaq_f32(accd[0], vdupq_n_f32(q8.scale(0, i)), vcvtq_f32_s32(vaddq_s32(sumi1, sumi2)));
+ }
+ } else {
+ for (int i = 1; i < nb/2; ++i) {
+ auto q2bits = vld1q_u8(x[2*i+0].qs);
+ v1.val[0] = vsubq_s8(vandq_s8(q2bits, mask2), m1);
+ v1.val[1] = vsubq_s8(vandq_s8(vshrq_n_u8(q2bits, 2), mask2), m1);
+ v1.val[2] = vsubq_s8(vandq_s8(vshrq_n_u8(q2bits, 4), mask2), m1);
+ v1.val[3] = vsubq_s8(vshrq_n_u8(q2bits, 6), m1);
+ q2bits = vld1q_u8(x[2*i+1].qs);
+ v2.val[0] = vsubq_s8(vandq_s8(q2bits, mask2), m1);
+ v2.val[1] = vsubq_s8(vandq_s8(vshrq_n_u8(q2bits, 2), mask2), m1);
+ v2.val[2] = vsubq_s8(vandq_s8(vshrq_n_u8(q2bits, 4), mask2), m1);
+ v2.val[3] = vsubq_s8(vshrq_n_u8(q2bits, 6), m1);
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ int32x4_t sumi = vdupq_n_s32(0);
+ auto q = q8.load_quants(iy, i, 0);
+ sumi = ggml_vdotq_s32(ggml_vdotq_s32(sumi, q.val[0], v1.val[0]), q.val[1], v1.val[1]);
+ q = q8.load_quants(iy, i, 1);
+ sumi = ggml_vdotq_s32(ggml_vdotq_s32(sumi, q.val[0], v1.val[2]), q.val[1], v1.val[3]);
+ q = q8.load_quants(iy, i, 2);
+ sumi = ggml_vdotq_s32(ggml_vdotq_s32(sumi, q.val[0], v2.val[0]), q.val[1], v2.val[1]);
+ q = q8.load_quants(iy, i, 3);
+ sumi = ggml_vdotq_s32(ggml_vdotq_s32(sumi, q.val[0], v2.val[2]), q.val[1], v2.val[3]);
+ accd[iy] = vfmaq_f32(accd[iy], vdupq_n_f32(q8.scale(iy, i)), vcvtq_f32_s32(sumi));
+ }
+ }
+ }
+ int i = 2*(nb/2);
+ if (i < nb) {
+ auto q2bits = vld1q_u8(x[2*i+0].qs);
+ v1.val[0] = vsubq_s8(vandq_s8(q2bits, mask2), m1);
+ v1.val[1] = vsubq_s8(vandq_s8(vshrq_n_u8(q2bits, 2), mask2), m1);
+ v1.val[2] = vsubq_s8(vandq_s8(vshrq_n_u8(q2bits, 4), mask2), m1);
+ v1.val[3] = vsubq_s8(vshrq_n_u8(q2bits, 6), m1);
for (int iy = 0; iy < nrc_y; ++iy) {
int32x4_t sumi = vdupq_n_s32(0);
- auto q = q8.load_quants(iy, i, 0);
- sumi = ggml_vdotq_s32(ggml_vdotq_s32(sumi, q.val[0], v.val[0]), q.val[1], v.val[1]);
- q = q8.load_quants(iy, i, 1);
- sumi = ggml_vdotq_s32(ggml_vdotq_s32(sumi, q.val[0], v.val[2]), q.val[1], v.val[3]);
- accd[iy] = vfmaq_f32(accd[iy], vdupq_n_f32(q8.scale(iy, i)), vcvtq_f32_s32(sumi));
+ auto q = q8.load_quants(iy, i/2, 0);
+ sumi = ggml_vdotq_s32(ggml_vdotq_s32(sumi, q.val[0], v1.val[0]), q.val[1], v1.val[1]);
+ q = q8.load_quants(iy, i/2, 1);
+ sumi = ggml_vdotq_s32(ggml_vdotq_s32(sumi, q.val[0], v1.val[2]), q.val[1], v1.val[3]);
+ accd[iy] = vfmaq_f32(accd[iy], vdupq_n_f32(q8.scale(iy, i/2)), vcvtq_f32_s32(sumi));
}
}