diff options
Diffstat (limited to 'iqk_mul_mat.cpp')
-rw-r--r-- | iqk_mul_mat.cpp | 85 |
1 files changed, 36 insertions, 49 deletions
diff --git a/iqk_mul_mat.cpp b/iqk_mul_mat.cpp index 581eb401..75221048 100644 --- a/iqk_mul_mat.cpp +++ b/iqk_mul_mat.cpp @@ -1145,6 +1145,21 @@ struct SignHelper { //aux256 = _mm256_and_si256(_mm256_shuffle_epi8(aux256, mask1), mask2); //return _mm256_or_si256(_mm256_cmpeq_epi8(aux256, mask2), mone); } + inline void sign_4_values(const uint16_t * sign_bits, __m256i * values) const { + auto s128 = _mm_loadu_si128((const __m128i *)sign_bits); + auto s256 = MM256_SET_M128I(s128, s128); + __m256i aux256; + auto shuffle = mask1; + auto step = _mm256_set1_epi8(4); + aux256 = _mm256_and_si256(_mm256_shuffle_epi8(s256, shuffle), mask2); shuffle = _mm256_add_epi8(shuffle, step); + values[0] = _mm256_sign_epi8(values[0], _mm256_or_si256(_mm256_cmpeq_epi8(aux256, mask2), mone)); + aux256 = _mm256_and_si256(_mm256_shuffle_epi8(s256, shuffle), mask2); shuffle = _mm256_add_epi8(shuffle, step); + values[1] = _mm256_sign_epi8(values[1], _mm256_or_si256(_mm256_cmpeq_epi8(aux256, mask2), mone)); + aux256 = _mm256_and_si256(_mm256_shuffle_epi8(s256, shuffle), mask2); shuffle = _mm256_add_epi8(shuffle, step); + values[2] = _mm256_sign_epi8(values[2], _mm256_or_si256(_mm256_cmpeq_epi8(aux256, mask2), mone)); + aux256 = _mm256_and_si256(_mm256_shuffle_epi8(s256, shuffle), mask2); shuffle = _mm256_add_epi8(shuffle, step); + values[3] = _mm256_sign_epi8(values[3], _mm256_or_si256(_mm256_cmpeq_epi8(aux256, mask2), mone)); + } const __m256i mask1 = _mm256_set_epi64x(0x0303030303030303, 0x0202020202020202, 0x0101010101010101, 0x0000000000000000); const __m256i mask2 = _mm256_set1_epi64x(0x8040201008040201ull); const __m256i mone = _mm256_set1_epi8(1); @@ -1181,65 +1196,37 @@ struct DequantizerIQ3S final : public BaseDequantizer<block_iq3_s> { uint32_t val[8]; }; - struct SignSelf { - SignSelf(const SignHelper& sh, const __m256i& min_value, __m256i * values, const uint16_t * sidx) : - sh(sh), min_value(min_value), values(values), sidx(sidx) {} - inline void apply(int k) { - values[k] = _mm256_add_epi8(_mm256_sign_epi8(values[k], sh.make_signs(sidx+2*k)), min_value); - } - const SignHelper& sh; - const __m256i& min_value; - __m256i * values; - const uint16_t * sidx; - }; - template <typename Q8> - struct SignQ8 { - SignQ8(const Q8& q8, const SignHelper& sh, __m256i * values, const uint16_t * sidx, int i, int j) : - q8(q8), sh(sh), values(values), sidx(sidx), i(i), j(j) {} - inline void apply(int k) { - values[k] = _mm256_sign_epi8(q8.load_quants(0, i, 4*j+k), sh.make_signs(sidx+2*k)); - } - const Q8& q8; - const SignHelper& sh; - __m256i * values; - const uint16_t * sidx; - int i; - int j; - }; - - template <typename ApplySignes> - inline static void make1(int k, const __m128i& idx_l, uint8_t qh, __m256i * values, const __m256i& idx_shift, const __m256i& idx_mask, - ApplySignes& as) { + inline static void make2(const uint8_t * qs, const uint8_t * qh, __m256i * values, const __m256i& idx_shift, const __m256i& idx_mask) { index_t idx; - idx.vec = _mm256_set1_epi32(qh); - idx.vec = _mm256_and_si256(_mm256_sllv_epi32(idx.vec, idx_shift), idx_mask); - idx.vec = _mm256_or_si256(idx.vec, _mm256_cvtepi16_epi32(idx_l)); - values[k] = _mm256_set_epi32(iq3s_grid[idx.val[7]], iq3s_grid[idx.val[6]], iq3s_grid[idx.val[5]], iq3s_grid[idx.val[4]], + auto idx_l = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i *)qs)); + auto idx_h = _mm256_and_si256(_mm256_sllv_epi32(_mm256_set1_epi32(qh[0]), idx_shift), idx_mask); + idx.vec = _mm256_or_si256(idx_h, idx_l); + values[0] = _mm256_set_epi32(iq3s_grid[idx.val[7]], iq3s_grid[idx.val[6]], iq3s_grid[idx.val[5]], iq3s_grid[idx.val[4]], + iq3s_grid[idx.val[3]], iq3s_grid[idx.val[2]], iq3s_grid[idx.val[1]], iq3s_grid[idx.val[0]]); + idx_l = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i *)(qs + 8))); + idx_h = _mm256_and_si256(_mm256_sllv_epi32(_mm256_set1_epi32(qh[1]), idx_shift), idx_mask); + idx.vec = _mm256_or_si256(idx_h, idx_l); + values[1] = _mm256_set_epi32(iq3s_grid[idx.val[7]], iq3s_grid[idx.val[6]], iq3s_grid[idx.val[5]], iq3s_grid[idx.val[4]], iq3s_grid[idx.val[3]], iq3s_grid[idx.val[2]], iq3s_grid[idx.val[1]], iq3s_grid[idx.val[0]]); - as.apply(k); - } - template <typename ApplySignes> - inline static void make2(int k, const uint8_t * qs, const uint8_t * qh, - __m256i * values, const __m256i& idx_shift, const __m256i& idx_mask, ApplySignes& as) { - auto idx_l = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)qs)); - make1(k+0, _mm256_castsi256_si128 (idx_l ), qh[0], values, idx_shift, idx_mask, as); - make1(k+1, _mm256_extractf128_si256(idx_l, 1), qh[1], values, idx_shift, idx_mask, as); } inline void prepare(int i, int j) { - auto qs = x[i].qs + 32*j; - auto qh = x[i].qh + 4*j; - SignSelf ss(sh, min_value, bits.values, (const uint16_t *)x[i].signs + 8*j); - make2(0, qs+ 0, qh+0, bits.values, idx_shift, idx_mask, ss); - make2(2, qs+16, qh+2, bits.values, idx_shift, idx_mask, ss); + prepare_unsigned(i, j); + sh.sign_4_values((const uint16_t *)x[i].signs + 8*j, bits.values); + for (int k = 0; k < 4; ++k) bits.values[k] = _mm256_add_epi8(bits.values[k], min_value); } template <typename Q8> inline void prepare(int i, int j, const Q8& q8, __m256i * q8_quants) { + prepare_unsigned(i, j); + for (int k = 0; k < 4; ++k) q8_quants[k] = q8.load_quants(0, i, 4*j+k); + sh.sign_4_values((const uint16_t *)x[i].signs + 8*j, q8_quants); + } + + inline void prepare_unsigned(int i, int j) { auto qs = x[i].qs + 32*j; auto qh = x[i].qh + 4*j; - SignQ8 sq8(q8, sh, q8_quants, (const uint16_t *)x[i].signs + 8*j, i, j); - make2(0, qs+ 0, qh+0, bits.values, idx_shift, idx_mask, sq8); - make2(2, qs+16, qh+2, bits.values, idx_shift, idx_mask, sq8); + make2(qs+ 0, qh+0, bits.values+0, idx_shift, idx_mask); + make2(qs+16, qh+2, bits.values+2, idx_shift, idx_mask); } constexpr static int minv = 16; |