summaryrefslogtreecommitdiff
path: root/ggml/src/iqk/iqk_mul_mat.cpp
diff options
context:
space:
mode:
authorIwan Kawrakow <iwan.kawrakow@gmail.com>2024-08-09 10:32:07 +0300
committerKawrakow <48489457+ikawrakow@users.noreply.github.com>2024-08-09 16:00:31 +0200
commitf0d7a0d53b0ecdd43ba85bcd49309b291372ca67 (patch)
tree9af6feced146997d2c93daa200c75f9a890dd016 /ggml/src/iqk/iqk_mul_mat.cpp
parentc77dba5273777c6c43d9745fc96114eba867f6c2 (diff)
Fix Zen4 implementation of iq3_k, iq4_k, iq5_k
See comments in f3a823ce729a7db33e7d4375eae7291bbe6196db
Diffstat (limited to 'ggml/src/iqk/iqk_mul_mat.cpp')
-rw-r--r--ggml/src/iqk/iqk_mul_mat.cpp95
1 files changed, 54 insertions, 41 deletions
diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp
index 3b6edb19..32ddb3ff 100644
--- a/ggml/src/iqk/iqk_mul_mat.cpp
+++ b/ggml/src/iqk/iqk_mul_mat.cpp
@@ -774,6 +774,40 @@ struct IQXKScales {
const __m256i shuffle1 = _mm256_set_epi64x(0x0b0b0b0b09090909, 0x0303030301010101, 0x0a0a0a0a08080808, 0x0202020200000000);
const __m256i shuffle2 = _mm256_set_epi64x(0x0f0f0f0f0d0d0d0d, 0x0707070705050505, 0x0e0e0e0e0c0c0c0c, 0x0606060604040404);
};
+struct IQXKScales2 {
+ IQXKScales2(uint8_t shift, int8_t min_val) : eshift(_mm256_set1_epi16(shift)), min(_mm256_set1_epi16(min_val)) {}
+ template <typename Q8>
+ inline void process(int i, float d, uint16_t extra, __m128i scales8, const Q8& q8, __m256 * accm, __m512i * scales) const {
+ process(i, d, extra, _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales8, scale_shuffle)), q8, accm, scales);
+ }
+ template <typename Q8>
+ inline void process(int i, float d, uint16_t extra, __m256i scales16, const Q8& q8, __m256 * accm, __m512i * scales) const {
+ auto scales_s = _mm256_mullo_epi16(scales16, _mm256_mask_add_epi16(min, extra, min, eshift));
+ for (int iy = 0; iy < Q8::nrc_y; ++iy) {
+ const __m256i prod = _mm256_madd_epi16(scales_s, q8.load_bsums(iy, i));
+ accm[iy] = _mm256_fmadd_ps(_mm256_set1_ps(d * q8.scale(iy, i)), _mm256_cvtepi32_ps(prod), accm[iy]);
+ }
+ auto aux_1 = MM256_SET_M128I(_mm256_castsi256_si128(scales16), _mm256_castsi256_si128(scales16));
+ auto aux_2 = MM256_SET_M128I(_mm256_extracti128_si256(scales16, 1), _mm256_extracti128_si256(scales16, 1));
+ auto scales16_1 = _mm512_inserti32x8(_mm512_castsi256_si512(aux_1), aux_1, 1);
+ auto scales16_2 = _mm512_inserti32x8(_mm512_castsi256_si512(aux_2), aux_2, 1);
+ scales[0] = _mm512_shuffle_epi8(scales16_1, shuffles[0]);
+ scales[1] = _mm512_shuffle_epi8(scales16_1, shuffles[1]);
+ scales[2] = _mm512_shuffle_epi8(scales16_2, shuffles[0]);
+ scales[3] = _mm512_shuffle_epi8(scales16_2, shuffles[1]);
+ }
+ const __m256i eshift;
+ const __m256i min;
+ const __m128i scale_shuffle = _mm_set_epi32(0x0f070e06, 0x0d050c04, 0x0b030a02, 0x09010800);
+ const __m128i emask = _mm_set_epi32(0x80804040, 0x20201010, 0x08080404, 0x02020101);
+ const __m128i eshuffle = _mm_set_epi32(0x0f0d0b09, 0x07050301, 0x0e0c0a08, 0x06040200);
+ const __m512i shuffles[2] = {
+ _mm512_inserti32x4(_mm512_inserti32x4(_mm512_inserti32x4(_mm512_inserti32x4(_mm512_setzero_si512(),
+ _mm_set1_epi16(0x0100), 0), _mm_set1_epi16(0x0302), 1), _mm_set1_epi16(0x0504), 2), _mm_set1_epi16(0x0706), 3),
+ _mm512_inserti32x4(_mm512_inserti32x4(_mm512_inserti32x4(_mm512_inserti32x4(_mm512_setzero_si512(),
+ _mm_set1_epi16(0x0908), 0), _mm_set1_epi16(0x0b0a), 1), _mm_set1_epi16(0x0d0c), 2), _mm_set1_epi16(0x0f0e), 3)
+ };
+};
struct DequantizerIQ2K final : public BaseDequantizer<block_iq2_k> {
DequantizerIQ2K(const void * vx, size_t bx) : BaseDequantizer(vx, bx), iqxk(IQXKScales(5, -32)), values(load_values()) {}
@@ -809,7 +843,7 @@ struct DequantizerIQ2K final : public BaseDequantizer<block_iq2_k> {
};
struct DequantizerIQ3K final : public BaseDequantizer<block_iq3_k> {
- DequantizerIQ3K(const void * vx, size_t bx) : BaseDequantizer(vx, bx), iqxk(IQXKScales(4, -64)), values(load_values()) {}
+ DequantizerIQ3K(const void * vx, size_t bx) : BaseDequantizer(vx, bx), iqxk(4, -64), values(load_values()) {}
template <typename Q8>
inline void new_block(int i, const Q8& q8, __m256 * accm, __m512i * scales) {
d = GGML_FP16_TO_FP32(x[i].d);
@@ -844,7 +878,7 @@ struct DequantizerIQ3K final : public BaseDequantizer<block_iq3_k> {
return _mm_sign_epi8(scl, sch);
}
Q2Bits bits;
- const IQXKScales iqxk;
+ const IQXKScales2 iqxk;
const __m512i values;
const __m512i hmask = _mm512_set1_epi8(4);
@@ -885,7 +919,7 @@ struct DequantizerIQ4K final : public BaseDequantizer<block_iq4_k> {
}
Q4Bits bits;
- const IQXKScales iqxk;
+ const IQXKScales2 iqxk;
const __m512i values;
const __m512i permute1 = _mm512_set_epi64(11, 10, 3, 2, 9, 8, 1, 0);
const __m512i permute2 = _mm512_set_epi64(15, 14, 7, 6, 13, 12, 5, 4);
@@ -948,7 +982,7 @@ struct DequantizerIQ5K final : public BaseDequantizer<block_iq5_k> {
}
Q4Bits bits;
- const IQXKScales iqxk;
+ const IQXKScales2 iqxk;
__m512i values[2];
const __m512i hmask1 = _mm512_set1_epi8(1);
const __m512i hmask2 = _mm512_set1_epi8(2);
@@ -960,26 +994,13 @@ struct DequantizerIQ5K final : public BaseDequantizer<block_iq5_k> {
};
struct DequantizerIQ6K final : public BaseDequantizer<block_iq6_k> {
- DequantizerIQ6K(const void * vx, size_t bx) : BaseDequantizer(vx, bx) { load_values(values); }
+ DequantizerIQ6K(const void * vx, size_t bx) : BaseDequantizer(vx, bx), iqxk(1, -128) { load_values(values); }
template <typename Q8>
inline void new_block(int i, const Q8& q8, __m256 * accm, __m512i * scales) {
d = GGML_FP16_TO_FP32(x[i].d);
prepare(x[i].qs, x[i].qh);
auto scales8 = _mm_loadu_si128((const __m128i*)x[i].scales);
- auto scales16 = _mm256_cvtepi8_epi16(scales8);
- auto bs = _mm256_mullo_epi16(scales16, _mm256_mask_add_epi16(min, x[i].extra, min, shift));
- for (int iy = 0; iy < Q8::nrc_y; ++iy) {
- auto prod = _mm256_madd_epi16(bs, q8.load_bsums(iy, i));
- accm[iy] = _mm256_fmadd_ps(_mm256_set1_ps(d * q8.scale(iy, i)), _mm256_cvtepi32_ps(prod), accm[iy]);
- }
- auto aux_1 = MM256_SET_M128I(_mm256_castsi256_si128(scales16), _mm256_castsi256_si128(scales16));
- auto aux_2 = MM256_SET_M128I(_mm256_extracti128_si256(scales16, 1), _mm256_extracti128_si256(scales16, 1));
- auto scales16_1 = _mm512_inserti32x8(_mm512_castsi256_si512(aux_1), aux_1, 1);
- auto scales16_2 = _mm512_inserti32x8(_mm512_castsi256_si512(aux_2), aux_2, 1);
- scales[0] = _mm512_shuffle_epi8(scales16_1, shuffles[0]);
- scales[1] = _mm512_shuffle_epi8(scales16_1, shuffles[1]);
- scales[2] = _mm512_shuffle_epi8(scales16_2, shuffles[0]);
- scales[3] = _mm512_shuffle_epi8(scales16_2, shuffles[1]);
+ iqxk.process(i, d, x[i].extra, _mm256_cvtepi8_epi16(scales8), q8, accm, scales);
}
inline __m512i make_one(__m512i l, __m512i h) const {
auto p = _mm512_shuffle_epi8(values[0], l);
@@ -994,8 +1015,6 @@ struct DequantizerIQ6K final : public BaseDequantizer<block_iq6_k> {
auto h256_2 = _mm256_loadu_si256((const __m256i *)qh + 1);
auto h1 = _mm512_inserti32x8(_mm512_castsi256_si512(h256_1), _mm256_srli_epi16(h256_1, 4), 1);
auto h2 = _mm512_inserti32x8(_mm512_castsi256_si512(h256_2), _mm256_srli_epi16(h256_2, 4), 1);
- //auto h1 = _mm512_loadu_si512((const __m512i *)qh);
- //auto h2 = _mm512_srli_epi16(h1, 4);
bits.values[0] = make_one(bits.values[0], h1);
bits.values[1] = make_one(bits.values[1], _mm512_srli_epi16(h1, 2));
bits.values[2] = make_one(bits.values[2], h2);
@@ -1025,18 +1044,9 @@ struct DequantizerIQ6K final : public BaseDequantizer<block_iq6_k> {
}
Q4Bits bits;
+ IQXKScales2 iqxk;
__m512i values[4];
__m512i masks[3] = { _mm512_set1_epi8(0x01), _mm512_set1_epi8(0x02), _mm512_set1_epi8(0x03) };
- const __m256i min = _mm256_set1_epi16(-128);
- const __m256i shift = _mm256_set1_epi16(1);
- const __m512i shuffles[2] = {
- _mm512_inserti32x4(_mm512_inserti32x4(_mm512_inserti32x4(_mm512_inserti32x4(_mm512_setzero_si512(),
- _mm_set1_epi16(0x0100), 0), _mm_set1_epi16(0x0302), 1), _mm_set1_epi16(0x0504), 2), _mm_set1_epi16(0x0706), 3),
- _mm512_inserti32x4(_mm512_inserti32x4(_mm512_inserti32x4(_mm512_inserti32x4(_mm512_setzero_si512(),
- _mm_set1_epi16(0x0908), 0), _mm_set1_epi16(0x0b0a), 1), _mm_set1_epi16(0x0d0c), 2), _mm_set1_epi16(0x0f0e), 3)
- };
- const __m256i shuffle1 = _mm256_set_epi64x(0x0707070703030303, 0x0606060602020202, 0x0505050501010101, 0x0404040400000000);
- const __m256i shuffle2 = _mm256_set_epi64x(0x0f0f0f0f0b0b0b0b, 0x0e0e0e0e0a0a0a0a, 0x0d0d0d0d09090909, 0x0c0c0c0c08080808);
const __m512i permute1 = _mm512_set_epi64(11, 10, 9, 8, 3, 2, 1, 0);
const __m512i permute2 = _mm512_set_epi64(15, 14, 13, 12, 7, 6, 5, 4);
};
@@ -1552,15 +1562,15 @@ struct DequantizerIQ6K final : public BaseDequantizer<block_iq6_k> {
}
}
inline __m256i make_one(__m256i l, __m256i hbits) const {
- auto mask4 = _mm256_cmpeq_epi8(_mm256_and_si256(hbits, mh3), mh3);
- auto h1 = _mm256_andnot_si256(mask4, hbits);
- auto mask2 = _mm256_cmpeq_epi8(_mm256_and_si256(h1, mh1), mh1);
- auto mask3 = _mm256_cmpeq_epi8(_mm256_and_si256(h1, mh2), mh2);
- auto mask1 = _mm256_andnot_si256(_mm256_or_si256(mask4, _mm256_or_si256(mask2, mask3)), _mm256_set1_epi8(0xff));
- return _mm256_or_si256(_mm256_or_si256(_mm256_and_si256(mask1, _mm256_shuffle_epi8(values[0], l)),
- _mm256_and_si256(mask2, _mm256_shuffle_epi8(values[1], l))),
- _mm256_or_si256(_mm256_and_si256(mask3, _mm256_shuffle_epi8(values[2], l)),
- _mm256_and_si256(mask4, _mm256_shuffle_epi8(values[3], l))));
+ auto mask4 = _mm256_cmpeq_epi8(_mm256_and_si256(hbits, mh3), mh3);
+ auto h1 = _mm256_andnot_si256(mask4, hbits);
+ auto mask2 = _mm256_cmpeq_epi8(_mm256_and_si256(h1, mh1), mh1);
+ auto mask3 = _mm256_cmpeq_epi8(_mm256_and_si256(h1, mh2), mh2);
+ auto mask1 = _mm256_andnot_si256(_mm256_or_si256(mask4, _mm256_or_si256(mask2, mask3)), _mm256_set1_epi8(0xff));
+ return _mm256_or_si256(_mm256_or_si256(_mm256_and_si256(mask1, _mm256_shuffle_epi8(values[0], l)),
+ _mm256_and_si256(mask2, _mm256_shuffle_epi8(values[1], l))),
+ _mm256_or_si256(_mm256_and_si256(mask3, _mm256_shuffle_epi8(values[2], l)),
+ _mm256_and_si256(mask4, _mm256_shuffle_epi8(values[3], l))));
}
static void load_values(__m256i * values) {
static const uint8_t kvalues_iq6nl[64] = {
@@ -3375,7 +3385,10 @@ template <typename Dequantizer> void MulMat::set_functions(MulMat& m) {
}
else {
#ifdef HAVE_FANCY_SIMD
- if constexpr (std::is_same_v<Dequantizer, DequantizerIQ6K>) {
+ if constexpr (std::is_same_v<Dequantizer, DequantizerIQ6K> ||
+ std::is_same_v<Dequantizer, DequantizerIQ5K> ||
+ std::is_same_v<Dequantizer, DequantizerIQ4K> ||
+ std::is_same_v<Dequantizer, DequantizerIQ3K>) {
m.funcs[0] = mul_mat_iqX_k_q8_K_AVX512<Dequantizer, 1>;
m.funcs[1] = mul_mat_iqX_k_q8_K_AVX512<Dequantizer, 2>;
m.funcs[2] = mul_mat_iqX_k_q8_K_AVX512<Dequantizer, 3>;