diff options
Diffstat (limited to 'iqk_mul_mat.cpp')
-rw-r--r-- | iqk_mul_mat.cpp | 142 |
1 files changed, 131 insertions, 11 deletions
diff --git a/iqk_mul_mat.cpp b/iqk_mul_mat.cpp index 54db5ab4..b53d08e7 100644 --- a/iqk_mul_mat.cpp +++ b/iqk_mul_mat.cpp @@ -1313,12 +1313,7 @@ struct DequantizerIQ3XXS final : public BaseDequantizer<block_iq3_xxs> { const __m256i min_value = _mm256_set1_epi8(minv); }; -//inline void prepare_scales_16(const __m256i& all_scales, __m256i * scales) { -// const __m128i l_scales = _mm256_extracti128_si256(all_scales, 0); -// const __m128i h_scales = _mm256_extracti128_si256(all_scales, 1); -// scales[0] = MM256_SET_M128I(l_scales, l_scales); -// scales[1] = MM256_SET_M128I(h_scales, h_scales); -//} + struct DequantizerIQ2S final : public BaseDequantizer<block_iq2_s> { DequantizerIQ2S(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {} @@ -1327,11 +1322,18 @@ struct DequantizerIQ2S final : public BaseDequantizer<block_iq2_s> { inline __m256i load_scales(int i) { d = 0.125f * GGML_FP16_TO_FP32(x[i].d); auto tmp = _mm_loadl_epi64((const __m128i *)x[i].scales); - auto all = _mm_and_si128(_mm_or_si128(_mm_slli_si128(_mm_srli_epi16(tmp, 4), 8), tmp), _mm_set1_epi8(0xf)); + auto all = _mm_and_si128(_mm_unpacklo_epi8(tmp, _mm_srli_epi16(tmp, 4)), _mm_set1_epi8(0xf)); auto scales8 = _mm_or_si128(_mm_slli_epi16(all, 1), _mm_set1_epi8(1)); - auto shuffle = _mm_set_epi64x(0x0f070e060d050c04, 0x0b030a0209010800); - return _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales8, shuffle)); - } + return _mm256_cvtepi8_epi16(scales8); + } + //inline __m256i load_scales(int i) { + // d = 0.125f * GGML_FP16_TO_FP32(x[i].d); + // auto tmp = _mm_loadl_epi64((const __m128i *)x[i].scales); + // auto all = _mm_and_si128(_mm_or_si128(_mm_slli_si128(_mm_srli_epi16(tmp, 4), 8), tmp), _mm_set1_epi8(0xf)); + // auto scales8 = _mm_or_si128(_mm_slli_epi16(all, 1), _mm_set1_epi8(1)); + // auto shuffle = _mm_set_epi64x(0x0f070e060d050c04, 0x0b030a0209010800); + // return _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales8, shuffle)); + //} inline static void prepare_scales(const __m256i& all, __m256i * scales) { auto scales_l = _mm256_castsi256_si128(all); auto scales_h = _mm256_extractf128_si256(all, 1); @@ -1403,6 +1405,120 @@ struct DequantizerIQ2S final : public BaseDequantizer<block_iq2_s> { }; +struct DequantizerIQ2XS final : public BaseDequantizer<block_iq2_xs> { + DequantizerIQ2XS(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {} + + constexpr static int num_blocks = 16; + + inline __m256i load_scales(int i) { + d = 0.125f * GGML_FP16_TO_FP32(x[i].d); + auto tmp = _mm_loadl_epi64((const __m128i *)x[i].scales); + auto all = _mm_and_si128(_mm_unpacklo_epi8(tmp, _mm_srli_epi16(tmp, 4)), _mm_set1_epi8(0xf)); + auto scales8 = _mm_or_si128(_mm_slli_epi16(all, 1), _mm_set1_epi8(1)); + return _mm256_cvtepi8_epi16(scales8); + } + inline static void prepare_scales(const __m256i& all, __m256i * scales) { + auto scales_l = _mm256_castsi256_si128(all); + auto scales_h = _mm256_extractf128_si256(all, 1); + scales[0] = MM256_SET_M128I(scales_l, scales_l); + scales[1] = MM256_SET_M128I(scales_h, scales_h); + } + + inline void new_block(int i, __m256i * scales) { + prepare_scales(load_scales(i), scales); + } + template <typename Q8> + inline void new_block(int i, const Q8& q8, __m256 * accd, __m256i * scales) { + auto all_scales = load_scales(i); + for (int iy = 0; iy < Q8::nrc_y; ++iy) { + auto bsums = q8.load_bsums(iy, i); + auto prod = _mm256_madd_epi16(all_scales, bsums); + accd[iy] = _mm256_fmadd_ps(_mm256_set1_ps(-d*q8.scale(iy, i)*minv), _mm256_cvtepi32_ps(prod), accd[iy]); + } + prepare_scales(all_scales, scales); + } + + struct Helper { + const __m256i mone = _mm256_set1_epi8(1); + const __m256i mask = _mm256_set1_epi64x(0x8040201008040201); + //const __m256i bhelper = _mm256_set_epi64x(0x8000008000808000, 0x0080800080000080, 0x8000008000808000, 0x0080800080000080); + const __m256i bhelper = load_bhelper(); + const __m256i shuff1 = _mm256_set_epi64x(0x0606060606060606, 0x0404040404040404, 0x0202020202020202, 0x0000000000000000); + const __m256i shuff2 = _mm256_set_epi64x(0x0e0e0e0e0e0e0e0e, 0x0c0c0c0c0c0c0c0c, 0x0a0a0a0a0a0a0a0a, 0x0808080808080808); + static __m256i load_bhelper() { + static const uint8_t k_bit_helper[32] = { + 0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00, + 0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00, + }; + return _mm256_loadu_si256((const __m256i*)k_bit_helper); + } + }; + + union index_t { + __m256i vec; + uint16_t val[8]; + }; + + inline static void make4(const __m256i& data, const __m256i& mask, __m256i * values) { + index_t idx; + idx.vec = _mm256_and_si256(data, mask); + values[0] = _mm256_set_epi64x(iq2xs_grid[idx.val[ 3]], iq2xs_grid[idx.val[ 2]], iq2xs_grid[idx.val[ 1]], iq2xs_grid[idx.val[ 0]]); + values[1] = _mm256_set_epi64x(iq2xs_grid[idx.val[ 7]], iq2xs_grid[idx.val[ 6]], iq2xs_grid[idx.val[ 5]], iq2xs_grid[idx.val[ 4]]); + values[2] = _mm256_set_epi64x(iq2xs_grid[idx.val[11]], iq2xs_grid[idx.val[10]], iq2xs_grid[idx.val[ 9]], iq2xs_grid[idx.val[ 8]]); + values[3] = _mm256_set_epi64x(iq2xs_grid[idx.val[15]], iq2xs_grid[idx.val[14]], iq2xs_grid[idx.val[13]], iq2xs_grid[idx.val[12]]); + } + inline static void sign_value(const __m256i& sign_bits, const __m256i& shuffle, const __m256i& mask, + const __m256i& mone, __m256i& value) { + auto signs = _mm256_shuffle_epi8(sign_bits, shuffle); + signs = _mm256_cmpeq_epi8(_mm256_and_si256(signs, mask), mask); + value = _mm256_sign_epi8(value, _mm256_or_si256(signs, mone)); + } + inline static void sign_values(const __m256i& data, const Helper& helper, __m256i * values) { + auto psb1 = _mm256_srli_epi16(data, 9); + auto psb2 = _mm256_srli_epi16(data, 13); + auto psbc = _mm256_xor_si256(psb1, psb2); + auto oddb = _mm256_shuffle_epi8(helper.bhelper, psbc); + auto full = _mm256_or_si256(psb1, oddb); + auto full_l = _mm256_castsi256_si128(full); + auto full_h = _mm256_extractf128_si256(full, 1); + auto full_1 = MM256_SET_M128I(full_l, full_l); + auto full_2 = MM256_SET_M128I(full_h, full_h); + sign_value(full_1, helper.shuff1, helper.mask, helper.mone, values[0]); + sign_value(full_1, helper.shuff2, helper.mask, helper.mone, values[1]); + sign_value(full_2, helper.shuff1, helper.mask, helper.mone, values[2]); + sign_value(full_2, helper.shuff2, helper.mask, helper.mone, values[3]); + } + inline static void make4_signed(const Helper& helper, const uint16_t * qs, const __m256i& m511, + const __m256i& min_value, __m256i * values) { + auto q2 = _mm256_loadu_si256((const __m256i *)qs); + make4(q2, m511, values); + sign_values(q2, helper, values); + for (int k = 0; k < 4; ++k) values[k] = _mm256_add_epi8(values[k], min_value); + } + inline static void make4(const Helper& helper, const uint16_t * qs, const __m256i& m511, __m256i * values, __m256i * q8) { + auto q2 = _mm256_loadu_si256((const __m256i *)qs); + make4(q2, m511, values); + sign_values(q2, helper, q8); + } + + inline void prepare(int i, int j) { + make4_signed(helper, x[i].qs + 16*j, idx_mask, min_value, bits.values); + } + template <typename Q8> + inline void prepare(int i, int j, const Q8& q8, __m256i * q8_quants) { + for (int k = 0; k < 4; ++k) q8_quants[k] = q8.load_quants(0, i, 4*j+k); + make4(helper, x[i].qs + 16*j, idx_mask, bits.values, q8_quants); + } + + constexpr static int minv = 43; + + SimpleBits bits; + Helper helper; + const __m256i idx_mask = _mm256_set1_epi16(511); + const __m256i min_value = _mm256_set1_epi8(minv); + +}; + // // ============================== Legacy quants // @@ -1778,7 +1894,7 @@ template <typename Dequantizer> void MulMat::set_functions(MulMat& m) { m.funcs[7] = mul_mat_qX_1_q8_1_T<Dequantizer, 8>; } else if constexpr (std::is_same_v<Dequantizer, DequantizerIQ3S> || std::is_same_v<Dequantizer, DequantizerIQ3XXS> || - std::is_same_v<Dequantizer, DequantizerIQ2S>) { + std::is_same_v<Dequantizer, DequantizerIQ2S> || std::is_same_v<Dequantizer, DequantizerIQ2XS>) { m.funcs[0] = mul_mat_qX_K_q8_K_IQ<Dequantizer, 1>; m.funcs[1] = mul_mat_qX_K_q8_K_IQ<Dequantizer, 2>; m.funcs[2] = mul_mat_qX_K_q8_K_IQ<Dequantizer, 3>; @@ -1870,6 +1986,10 @@ bool MulMat::set_mul_mat(int typeA, int ne00, MulMat& mm, int& row_size_q8, int assert (ne00 % QK_K == 0); MulMat::set_functions<DequantizerIQ2S>(mm); break; + case GGML_TYPE_IQ2_XS: + assert (ne00 % QK_K == 0); + MulMat::set_functions<DequantizerIQ2XS>(mm); + break; case GGML_TYPE_Q4_0: assert (ne00 % QK4_0 == 0); MulMat::set_functions<Q4_0_Unpacker>(mm); |