summaryrefslogtreecommitdiff
path: root/iqk_mul_mat.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'iqk_mul_mat.cpp')
-rw-r--r--iqk_mul_mat.cpp142
1 files changed, 131 insertions, 11 deletions
diff --git a/iqk_mul_mat.cpp b/iqk_mul_mat.cpp
index 54db5ab4..b53d08e7 100644
--- a/iqk_mul_mat.cpp
+++ b/iqk_mul_mat.cpp
@@ -1313,12 +1313,7 @@ struct DequantizerIQ3XXS final : public BaseDequantizer<block_iq3_xxs> {
const __m256i min_value = _mm256_set1_epi8(minv);
};
-//inline void prepare_scales_16(const __m256i& all_scales, __m256i * scales) {
-// const __m128i l_scales = _mm256_extracti128_si256(all_scales, 0);
-// const __m128i h_scales = _mm256_extracti128_si256(all_scales, 1);
-// scales[0] = MM256_SET_M128I(l_scales, l_scales);
-// scales[1] = MM256_SET_M128I(h_scales, h_scales);
-//}
+
struct DequantizerIQ2S final : public BaseDequantizer<block_iq2_s> {
DequantizerIQ2S(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}
@@ -1327,11 +1322,18 @@ struct DequantizerIQ2S final : public BaseDequantizer<block_iq2_s> {
inline __m256i load_scales(int i) {
d = 0.125f * GGML_FP16_TO_FP32(x[i].d);
auto tmp = _mm_loadl_epi64((const __m128i *)x[i].scales);
- auto all = _mm_and_si128(_mm_or_si128(_mm_slli_si128(_mm_srli_epi16(tmp, 4), 8), tmp), _mm_set1_epi8(0xf));
+ auto all = _mm_and_si128(_mm_unpacklo_epi8(tmp, _mm_srli_epi16(tmp, 4)), _mm_set1_epi8(0xf));
auto scales8 = _mm_or_si128(_mm_slli_epi16(all, 1), _mm_set1_epi8(1));
- auto shuffle = _mm_set_epi64x(0x0f070e060d050c04, 0x0b030a0209010800);
- return _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales8, shuffle));
- }
+ return _mm256_cvtepi8_epi16(scales8);
+ }
+ //inline __m256i load_scales(int i) {
+ // d = 0.125f * GGML_FP16_TO_FP32(x[i].d);
+ // auto tmp = _mm_loadl_epi64((const __m128i *)x[i].scales);
+ // auto all = _mm_and_si128(_mm_or_si128(_mm_slli_si128(_mm_srli_epi16(tmp, 4), 8), tmp), _mm_set1_epi8(0xf));
+ // auto scales8 = _mm_or_si128(_mm_slli_epi16(all, 1), _mm_set1_epi8(1));
+ // auto shuffle = _mm_set_epi64x(0x0f070e060d050c04, 0x0b030a0209010800);
+ // return _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales8, shuffle));
+ //}
inline static void prepare_scales(const __m256i& all, __m256i * scales) {
auto scales_l = _mm256_castsi256_si128(all);
auto scales_h = _mm256_extractf128_si256(all, 1);
@@ -1403,6 +1405,120 @@ struct DequantizerIQ2S final : public BaseDequantizer<block_iq2_s> {
};
+struct DequantizerIQ2XS final : public BaseDequantizer<block_iq2_xs> {
+ DequantizerIQ2XS(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}
+
+ constexpr static int num_blocks = 16;
+
+ inline __m256i load_scales(int i) {
+ d = 0.125f * GGML_FP16_TO_FP32(x[i].d);
+ auto tmp = _mm_loadl_epi64((const __m128i *)x[i].scales);
+ auto all = _mm_and_si128(_mm_unpacklo_epi8(tmp, _mm_srli_epi16(tmp, 4)), _mm_set1_epi8(0xf));
+ auto scales8 = _mm_or_si128(_mm_slli_epi16(all, 1), _mm_set1_epi8(1));
+ return _mm256_cvtepi8_epi16(scales8);
+ }
+ inline static void prepare_scales(const __m256i& all, __m256i * scales) {
+ auto scales_l = _mm256_castsi256_si128(all);
+ auto scales_h = _mm256_extractf128_si256(all, 1);
+ scales[0] = MM256_SET_M128I(scales_l, scales_l);
+ scales[1] = MM256_SET_M128I(scales_h, scales_h);
+ }
+
+ inline void new_block(int i, __m256i * scales) {
+ prepare_scales(load_scales(i), scales);
+ }
+ template <typename Q8>
+ inline void new_block(int i, const Q8& q8, __m256 * accd, __m256i * scales) {
+ auto all_scales = load_scales(i);
+ for (int iy = 0; iy < Q8::nrc_y; ++iy) {
+ auto bsums = q8.load_bsums(iy, i);
+ auto prod = _mm256_madd_epi16(all_scales, bsums);
+ accd[iy] = _mm256_fmadd_ps(_mm256_set1_ps(-d*q8.scale(iy, i)*minv), _mm256_cvtepi32_ps(prod), accd[iy]);
+ }
+ prepare_scales(all_scales, scales);
+ }
+
+ struct Helper {
+ const __m256i mone = _mm256_set1_epi8(1);
+ const __m256i mask = _mm256_set1_epi64x(0x8040201008040201);
+ //const __m256i bhelper = _mm256_set_epi64x(0x8000008000808000, 0x0080800080000080, 0x8000008000808000, 0x0080800080000080);
+ const __m256i bhelper = load_bhelper();
+ const __m256i shuff1 = _mm256_set_epi64x(0x0606060606060606, 0x0404040404040404, 0x0202020202020202, 0x0000000000000000);
+ const __m256i shuff2 = _mm256_set_epi64x(0x0e0e0e0e0e0e0e0e, 0x0c0c0c0c0c0c0c0c, 0x0a0a0a0a0a0a0a0a, 0x0808080808080808);
+ static __m256i load_bhelper() {
+ static const uint8_t k_bit_helper[32] = {
+ 0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00,
+ 0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00,
+ };
+ return _mm256_loadu_si256((const __m256i*)k_bit_helper);
+ }
+ };
+
+ union index_t {
+ __m256i vec;
+ uint16_t val[8];
+ };
+
+ inline static void make4(const __m256i& data, const __m256i& mask, __m256i * values) {
+ index_t idx;
+ idx.vec = _mm256_and_si256(data, mask);
+ values[0] = _mm256_set_epi64x(iq2xs_grid[idx.val[ 3]], iq2xs_grid[idx.val[ 2]], iq2xs_grid[idx.val[ 1]], iq2xs_grid[idx.val[ 0]]);
+ values[1] = _mm256_set_epi64x(iq2xs_grid[idx.val[ 7]], iq2xs_grid[idx.val[ 6]], iq2xs_grid[idx.val[ 5]], iq2xs_grid[idx.val[ 4]]);
+ values[2] = _mm256_set_epi64x(iq2xs_grid[idx.val[11]], iq2xs_grid[idx.val[10]], iq2xs_grid[idx.val[ 9]], iq2xs_grid[idx.val[ 8]]);
+ values[3] = _mm256_set_epi64x(iq2xs_grid[idx.val[15]], iq2xs_grid[idx.val[14]], iq2xs_grid[idx.val[13]], iq2xs_grid[idx.val[12]]);
+ }
+ inline static void sign_value(const __m256i& sign_bits, const __m256i& shuffle, const __m256i& mask,
+ const __m256i& mone, __m256i& value) {
+ auto signs = _mm256_shuffle_epi8(sign_bits, shuffle);
+ signs = _mm256_cmpeq_epi8(_mm256_and_si256(signs, mask), mask);
+ value = _mm256_sign_epi8(value, _mm256_or_si256(signs, mone));
+ }
+ inline static void sign_values(const __m256i& data, const Helper& helper, __m256i * values) {
+ auto psb1 = _mm256_srli_epi16(data, 9);
+ auto psb2 = _mm256_srli_epi16(data, 13);
+ auto psbc = _mm256_xor_si256(psb1, psb2);
+ auto oddb = _mm256_shuffle_epi8(helper.bhelper, psbc);
+ auto full = _mm256_or_si256(psb1, oddb);
+ auto full_l = _mm256_castsi256_si128(full);
+ auto full_h = _mm256_extractf128_si256(full, 1);
+ auto full_1 = MM256_SET_M128I(full_l, full_l);
+ auto full_2 = MM256_SET_M128I(full_h, full_h);
+ sign_value(full_1, helper.shuff1, helper.mask, helper.mone, values[0]);
+ sign_value(full_1, helper.shuff2, helper.mask, helper.mone, values[1]);
+ sign_value(full_2, helper.shuff1, helper.mask, helper.mone, values[2]);
+ sign_value(full_2, helper.shuff2, helper.mask, helper.mone, values[3]);
+ }
+ inline static void make4_signed(const Helper& helper, const uint16_t * qs, const __m256i& m511,
+ const __m256i& min_value, __m256i * values) {
+ auto q2 = _mm256_loadu_si256((const __m256i *)qs);
+ make4(q2, m511, values);
+ sign_values(q2, helper, values);
+ for (int k = 0; k < 4; ++k) values[k] = _mm256_add_epi8(values[k], min_value);
+ }
+ inline static void make4(const Helper& helper, const uint16_t * qs, const __m256i& m511, __m256i * values, __m256i * q8) {
+ auto q2 = _mm256_loadu_si256((const __m256i *)qs);
+ make4(q2, m511, values);
+ sign_values(q2, helper, q8);
+ }
+
+ inline void prepare(int i, int j) {
+ make4_signed(helper, x[i].qs + 16*j, idx_mask, min_value, bits.values);
+ }
+ template <typename Q8>
+ inline void prepare(int i, int j, const Q8& q8, __m256i * q8_quants) {
+ for (int k = 0; k < 4; ++k) q8_quants[k] = q8.load_quants(0, i, 4*j+k);
+ make4(helper, x[i].qs + 16*j, idx_mask, bits.values, q8_quants);
+ }
+
+ constexpr static int minv = 43;
+
+ SimpleBits bits;
+ Helper helper;
+ const __m256i idx_mask = _mm256_set1_epi16(511);
+ const __m256i min_value = _mm256_set1_epi8(minv);
+
+};
+
//
// ============================== Legacy quants
//
@@ -1778,7 +1894,7 @@ template <typename Dequantizer> void MulMat::set_functions(MulMat& m) {
m.funcs[7] = mul_mat_qX_1_q8_1_T<Dequantizer, 8>;
}
else if constexpr (std::is_same_v<Dequantizer, DequantizerIQ3S> || std::is_same_v<Dequantizer, DequantizerIQ3XXS> ||
- std::is_same_v<Dequantizer, DequantizerIQ2S>) {
+ std::is_same_v<Dequantizer, DequantizerIQ2S> || std::is_same_v<Dequantizer, DequantizerIQ2XS>) {
m.funcs[0] = mul_mat_qX_K_q8_K_IQ<Dequantizer, 1>;
m.funcs[1] = mul_mat_qX_K_q8_K_IQ<Dequantizer, 2>;
m.funcs[2] = mul_mat_qX_K_q8_K_IQ<Dequantizer, 3>;
@@ -1870,6 +1986,10 @@ bool MulMat::set_mul_mat(int typeA, int ne00, MulMat& mm, int& row_size_q8, int
assert (ne00 % QK_K == 0);
MulMat::set_functions<DequantizerIQ2S>(mm);
break;
+ case GGML_TYPE_IQ2_XS:
+ assert (ne00 % QK_K == 0);
+ MulMat::set_functions<DequantizerIQ2XS>(mm);
+ break;
case GGML_TYPE_Q4_0:
assert (ne00 % QK4_0 == 0);
MulMat::set_functions<Q4_0_Unpacker>(mm);