summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--iqk_mul_mat.cpp61
1 files changed, 48 insertions, 13 deletions
diff --git a/iqk_mul_mat.cpp b/iqk_mul_mat.cpp
index 44d0703a..0f53e02c 100644
--- a/iqk_mul_mat.cpp
+++ b/iqk_mul_mat.cpp
@@ -1342,8 +1342,11 @@ template <int nrc> struct Q8_K64 {
struct DequantizerIQ1BN {
const __m256i m1_8 = _mm256_set1_epi8(1);
- const __m128i mask1 = _mm_set1_epi8(0xf0);
+#ifdef HAVE_FANCY_SIMD
+ const __m128i shifthh = _mm_set_epi16(5, 6, 7, 8, 9, 10, 11, 12);
+#else
const __m128i mulhh = _mm_set_epi16(32, 64, 128, 256, 512, 1024, 2048, 4096);
+#endif
const __m128i maskhh = _mm_set1_epi16(4096);
const __m256i shuffles[4] = {
_mm256_set_epi64x(0x0302030203020302, 0x0302030203020302, 0x0100010001000100, 0x0100010001000100),
@@ -1353,22 +1356,54 @@ struct DequantizerIQ1BN {
};
const __m256i mult = _mm256_set_epi16(8, 24, 72, 216, 648, 1944, 5832, 17496, 8, 24, 72, 216, 648, 1944, 5832, 17496);
const __m256i m3 = _mm256_set1_epi16(3);
+ const __m128i shuff_l = _mm_set_epi8(-128, 8, -128, 7, -128, 6, -128, 5, -128, 4, -128, 3, -128, 2, -128, 1);
+ const __m128i shuff_h = _mm_set_epi8(12, -128, 11, -128, 10, -128, 9, -128, 12, -128, 11, -128, 10, -128, 9, -128);
+ const __m128i shift_h = _mm_set_epi32(4, 4, 0, 0);
+ const __m128i mask_h = _mm_set1_epi16(0x0f00);
+ const __m128i shuff_hh = _mm_set_epi8(-128, 0, -128, 0, -128, 0, -128, 0, -128, 0, -128, 0, -128, 0, -128, 0);
+#ifdef HAVE_FANCY_SIMD
+ const __m256i bmask = _mm256_set_epi8(62, 60, 58, 56, 54, 52, 50, 48, 46, 44, 42, 40, 38, 36, 34, 32, 30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0);
+#endif
- IQK_ALWAYS_INLINE void prepare_iq1bn_quants(uint8_t extra, const uint8_t * ql, const uint8_t * qh, __m256i& v1, __m256i& v2) {
-
- auto aux1 = _mm_cvtepu8_epi16(_mm_loadl_epi64((const __m128i *)ql));
- uint32_t aux32; std::memcpy(&aux32, qh, 4);
- auto aux2 = _mm_cvtepu8_epi16(_mm_and_si128(_mm_set_epi32(aux32, aux32, aux32, aux32 << 4), mask1));
- auto aux3 = _mm_and_si128(_mm_mullo_epi16(_mm_set1_epi16(extra), mulhh), maskhh);
- auto all128 = _mm_or_si128(_mm_slli_epi16(aux2, 4), _mm_or_si128(aux1, aux3));
+ IQK_ALWAYS_INLINE void prepare_iq1bn_quants(const block_iq1_bn * x, __m256i& v1, __m256i& v2) {
+ auto data = _mm_loadu_si128((const __m128i *)x); // Note: we load 16 instead of 13 bytes!
+ auto aux1 = _mm_shuffle_epi8(data, shuff_l);
+ auto aux2 = _mm_and_si128(_mm_srlv_epi32(_mm_shuffle_epi8(data, shuff_h), shift_h), mask_h);
+#ifdef HAVE_FANCY_SIMD
+ auto aux3 = _mm_and_si128(_mm_sllv_epi16(_mm_shuffle_epi8(data, shuff_hh), shifthh), maskhh);
+#else
+ auto aux3 = _mm_and_si128(_mm_mullo_epi16(_mm_shuffle_epi8(data, shuff_hh), mulhh), maskhh);
+#endif
+ auto all128 = _mm_or_si128(_mm_or_si128(aux1, aux2), aux3);
auto all = MM256_SET_M128I(all128, all128);
auto val1 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_shuffle_epi8(all, shuffles[0]), mult), m3);
auto val2 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_shuffle_epi8(all, shuffles[1]), mult), m3);
auto val3 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_shuffle_epi8(all, shuffles[2]), mult), m3);
auto val4 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_shuffle_epi8(all, shuffles[3]), mult), m3);
+#ifdef HAVE_FANCY_SIMD
+ v1 = _mm256_sub_epi8(_mm256_permutex2var_epi8(val1, bmask, val2), m1_8);
+ v2 = _mm256_sub_epi8(_mm256_permutex2var_epi8(val3, bmask, val4), m1_8);
+#else
v1 = _mm256_sub_epi8(_mm256_permute4x64_epi64(_mm256_packs_epi16(val1, val2), 216), m1_8);
v2 = _mm256_sub_epi8(_mm256_permute4x64_epi64(_mm256_packs_epi16(val3, val4), 216), m1_8);
+#endif
}
+
+ //IQK_ALWAYS_INLINE void prepare_iq1bn_quants(uint8_t extra, const uint8_t * ql, const uint8_t * qh, __m256i& v1, __m256i& v2) {
+
+ // auto aux1 = _mm_cvtepu8_epi16(_mm_loadl_epi64((const __m128i *)ql));
+ // uint32_t aux32; std::memcpy(&aux32, qh, 4);
+ // auto aux2 = _mm_cvtepu8_epi16(_mm_and_si128(_mm_set_epi32(aux32, aux32, aux32, aux32 << 4), mask1));
+ // auto aux3 = _mm_and_si128(_mm_mullo_epi16(_mm_set1_epi16(extra), mulhh), maskhh);
+ // auto all128 = _mm_or_si128(_mm_slli_epi16(aux2, 4), _mm_or_si128(aux1, aux3));
+ // auto all = MM256_SET_M128I(all128, all128);
+ // auto val1 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_shuffle_epi8(all, shuffles[0]), mult), m3);
+ // auto val2 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_shuffle_epi8(all, shuffles[1]), mult), m3);
+ // auto val3 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_shuffle_epi8(all, shuffles[2]), mult), m3);
+ // auto val4 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_shuffle_epi8(all, shuffles[3]), mult), m3);
+ // v1 = _mm256_sub_epi8(_mm256_permute4x64_epi64(_mm256_packs_epi16(val1, val2), 216), m1_8);
+ // v2 = _mm256_sub_epi8(_mm256_permute4x64_epi64(_mm256_packs_epi16(val3, val4), 216), m1_8);
+ //}
};
template <int nrc_y>
@@ -1392,8 +1427,8 @@ IQK_NOINLINE void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const
if constexpr (nrc_y == 1) {
__m256i acc1 = _mm256_setzero_si256(), acc2 = _mm256_setzero_si256();
for (int i = 0; i < nb/2; ++i) {
- deq.prepare_iq1bn_quants(x[2*i+0].extra, x[2*i+0].ql, x[2*i+0].qh, val[0], val[1]);
- deq.prepare_iq1bn_quants(x[2*i+1].extra, x[2*i+1].ql, x[2*i+1].qh, val[2], val[3]);
+ deq.prepare_iq1bn_quants(x + 2*i + 0, val[0], val[1]);
+ deq.prepare_iq1bn_quants(x + 2*i + 1, val[2], val[3]);
#if defined __AVX512VNNI__ && defined __AVX512VL__
auto dot1 = _mm256_sign_epi8(q8.load_quants(0, i, 0), val[0]);
auto dot2 = _mm256_sign_epi8(q8.load_quants(0, i, 1), val[1]);
@@ -1418,8 +1453,8 @@ IQK_NOINLINE void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const
for (int i = 0; i < nb/2; ++i) {
- deq.prepare_iq1bn_quants(x[2*i+0].extra, x[2*i+0].ql, x[2*i+0].qh, val[0], val[1]);
- deq.prepare_iq1bn_quants(x[2*i+1].extra, x[2*i+1].ql, x[2*i+1].qh, val[2], val[3]);
+ deq.prepare_iq1bn_quants(x + 2*i + 0, val[0], val[1]);
+ deq.prepare_iq1bn_quants(x + 2*i + 1, val[2], val[3]);
for (int iy = 0; iy < nrc_y; ++iy) {
#if defined __AVX512VNNI__ && defined __AVX512VL__
@@ -1442,7 +1477,7 @@ IQK_NOINLINE void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const
}
int i = 2*(nb/2);
if (i < nb) {
- deq.prepare_iq1bn_quants(x[i].extra, x[i].ql, x[i].qh, val[0], val[1]);
+ deq.prepare_iq1bn_quants(x + i, val[0], val[1]);
for (int iy = 0; iy < nrc_y; ++iy) {
auto dot1 = _mm256_sign_epi8(q8.load_quants(iy, i/2, 0), val[0]);
auto dot2 = _mm256_sign_epi8(q8.load_quants(iy, i/2, 1), val[1]);