summaryrefslogtreecommitdiff
path: root/ggml/src/iqk/iqk_mul_mat.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'ggml/src/iqk/iqk_mul_mat.cpp')
-rw-r--r--ggml/src/iqk/iqk_mul_mat.cpp172
1 files changed, 172 insertions, 0 deletions
diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp
index 6c3a3575..8c649de4 100644
--- a/ggml/src/iqk/iqk_mul_mat.cpp
+++ b/ggml/src/iqk/iqk_mul_mat.cpp
@@ -2383,6 +2383,79 @@ struct DequantizerIQ4KS final : public BaseDequantizer<block_iq4_ks, true> {
};
};
+struct DequantizerIQ5KS final : public BaseDequantizer<block_iq5_ks, true> {
+ DequantizerIQ5KS(const void * vx, size_t bx) : BaseDequantizer(vx, bx) { load_values(values); }
+ template <typename Q8>
+ inline void new_block(int i, const Q8& q8, __m256 * accm, __m512i * scales) {
+ auto scales128 = _mm_cvtepu8_epi16(_mm_loadl_epi64((const __m128i *)x[i].scales));
+ auto shifts = _mm_and_si128(_mm_cmpeq_epi16(_mm_and_si128(scales128, m1), m1), m2);
+ scales128 = _mm_add_epi16(_mm_and_si128(scales128, mask), m127);
+ auto scales_s = _mm_mullo_epi16(scales128, _mm_add_epi16(m128, shifts));
+ s8k.accum_mins(scales_s, q8, i, d, accm);
+ auto scales256 = MM256_SET_M128I(scales128, scales128);
+ auto all_scales = _mm512_inserti32x8(_mm512_castsi256_si512(scales256), scales256, 1);
+ scales[0] = _mm512_shuffle_epi8(all_scales, shuffles[0]);
+ scales[1] = _mm512_shuffle_epi8(all_scales, shuffles[1]);
+ scales[2] = _mm512_shuffle_epi8(all_scales, shuffles[2]);
+ scales[3] = _mm512_shuffle_epi8(all_scales, shuffles[3]);
+ prepare(x[i].qs, x[i].qh);
+ }
+ inline void prepare(const uint8_t * q4, const uint8_t * qh) {
+ bits.prepare64(q4);
+ auto h256 = _mm256_loadu_si256((const __m256i *)qh);
+ auto hbits = _mm512_inserti32x8(_mm512_castsi256_si512(h256), _mm256_srli_epi16(h256, 2), 1);
+ auto m1 = _mm512_cmpeq_epi8_mask(_mm512_and_si512(hbits, hmask1), hmask1);
+ auto m2 = _mm512_cmpeq_epi8_mask(_mm512_and_si512(hbits, hmask2), hmask2);
+ bits.values[0] = _mm512_mask_shuffle_epi8(_mm512_maskz_shuffle_epi8(_knot_mask64(m1), values[0], bits.values[0]), m1, values[1], bits.values[0]);
+ bits.values[1] = _mm512_mask_shuffle_epi8(_mm512_maskz_shuffle_epi8(_knot_mask64(m2), values[0], bits.values[1]), m2, values[1], bits.values[1]);
+ hbits = _mm512_srli_epi16(hbits, 4);
+ m1 = _mm512_cmpeq_epi8_mask(_mm512_and_si512(hbits, hmask1), hmask1);
+ m2 = _mm512_cmpeq_epi8_mask(_mm512_and_si512(hbits, hmask2), hmask2);
+ bits.values[2] = _mm512_mask_shuffle_epi8(_mm512_maskz_shuffle_epi8(_knot_mask64(m1), values[0], bits.values[2]), m1, values[1], bits.values[2]);
+ bits.values[3] = _mm512_mask_shuffle_epi8(_mm512_maskz_shuffle_epi8(_knot_mask64(m2), values[0], bits.values[3]), m2, values[1], bits.values[3]);
+ // We now have in bits.valuse[0]: 0...31, 64...95
+ // bits.valuse[1]: 32..63, 96..127
+ // etc.
+ auto tmp = _mm512_permutex2var_epi64(bits.values[0], permute1, bits.values[1]);
+ bits.values[1] = _mm512_permutex2var_epi64(bits.values[0], permute2, bits.values[1]);
+ bits.values[0] = tmp;
+ tmp = _mm512_permutex2var_epi64(bits.values[2], permute1, bits.values[3]);
+ bits.values[3] = _mm512_permutex2var_epi64(bits.values[2], permute2, bits.values[3]);
+ bits.values[2] = tmp;
+ }
+ static void load_values(__m512i * values) {
+ static const uint8_t kvalues_iq5nl[32] = {
+ 2, 14, 25, 36, 45, 54, 63, 71, 78, 85, 92, 98, 104, 110, 116, 122, 127,
+ 133, 139, 145, 151, 157, 164, 171, 179, 187, 196, 205, 215, 225, 237, 249,
+ };
+ auto values128_1 = _mm_loadu_si128((const __m128i *)kvalues_iq5nl + 0);
+ auto values128_2 = _mm_loadu_si128((const __m128i *)kvalues_iq5nl + 1);
+ auto values256_1 = MM256_SET_M128I(values128_1, values128_1);
+ auto values256_2 = MM256_SET_M128I(values128_2, values128_2);
+ values[0] = _mm512_inserti32x8(_mm512_castsi256_si512(values256_1), values256_1, 1);
+ values[1] = _mm512_inserti32x8(_mm512_castsi256_si512(values256_2), values256_2, 1);
+ }
+
+ Q4Bits bits;
+ Scales8KBase s8k;
+ __m512i values[2];
+ const __m512i hmask1 = _mm512_set1_epi8(1);
+ const __m512i hmask2 = _mm512_set1_epi8(2);
+ const __m512i permute1 = _mm512_set_epi64(11, 10, 9, 8, 3, 2, 1, 0);
+ const __m512i permute2 = _mm512_set_epi64(15, 14, 13, 12, 7, 6, 5, 4);
+ const __m128i m127 = _mm_set1_epi16(-127);
+ const __m128i m128 = _mm_set1_epi16(-128);
+ const __m128i mask = _mm_set1_epi16(254);
+ const __m128i m1 = _mm_set1_epi16(1);
+ const __m128i m2 = _mm_set1_epi16(2);
+ const __m512i shuffles[4] = {
+ _mm512_inserti32x8(_mm512_set1_epi16(0x0100), _mm256_set1_epi16(0x0302), 1),
+ _mm512_inserti32x8(_mm512_set1_epi16(0x0504), _mm256_set1_epi16(0x0706), 1),
+ _mm512_inserti32x8(_mm512_set1_epi16(0x0908), _mm256_set1_epi16(0x0b0a), 1),
+ _mm512_inserti32x8(_mm512_set1_epi16(0x0d0c), _mm256_set1_epi16(0x0f0e), 1),
+ };
+};
+
struct DequantizerIQ4KSS final : public BaseDequantizer<block_iq4_kss, true> {
DequantizerIQ4KSS(const void * vx, size_t bx) : BaseDequantizer(vx, bx), values(load_iq4nl_values_512()) {}
template <typename Q8>
@@ -2977,6 +3050,53 @@ struct DequantizerIQ4KS final : public BaseDequantizer<block_iq4_ks, true> {
const __m128i m4 = _mm_set1_epi16(4);
};
+struct DequantizerIQ5KS final : public BaseDequantizer<block_iq5_ks, true> {
+ DequantizerIQ5KS(const void * vx, size_t bx) : BaseDequantizer(vx, bx) { load_values(values); }
+ template <typename Q8>
+ inline __m256i new_block(int i, const Q8& q8, __m256 * accd) {
+ hbits = _mm256_loadu_si256((const __m256i *)x[i].qh);
+ auto scales128 = _mm_cvtepu8_epi16(_mm_loadl_epi64((const __m128i *)x[i].scales));
+ auto shifts = _mm_and_si128(_mm_cmpeq_epi16(_mm_and_si128(scales128, m1), m1), m2);
+ scales128 = _mm_add_epi16(_mm_and_si128(scales128, mask), m127);
+ auto scales_s = _mm_mullo_epi16(scales128, _mm_add_epi16(m128, shifts));
+ s8k.accum_mins(scales_s, q8, i, d, accd);
+ return MM256_SET_M128I(scales128, scales128);
+ }
+ inline void prepare(int i, int j) {
+ bits.prepare(x[i].qs, j);
+ auto h = j == 0 ? hbits : _mm256_srli_epi16(hbits, 4);
+ for (int k = 0; k < 4; ++k) {
+ auto qh = _mm256_and_si256(_mm256_slli_epi16(h, 7-k), mh);
+ auto q5vl = _mm256_or_si256(bits.values[k], qh);
+ auto q5vh = _mm256_or_si256(bits.values[k], _mm256_xor_si256(qh, mh));
+ bits.values[k] = _mm256_or_si256(_mm256_shuffle_epi8(values[0], q5vl), _mm256_shuffle_epi8(values[1], q5vh));
+ }
+ }
+ static void load_values(__m256i * values) {
+ static const uint8_t kvalues_iq5nl[32] = {
+ 2, 14, 25, 36, 45, 54, 63, 71, 78, 85, 92, 98, 104, 110, 116, 122, 127,
+ 133, 139, 145, 151, 157, 164, 171, 179, 187, 196, 205, 215, 225, 237, 249,
+ };
+ auto values128_1 = _mm_loadu_si128((const __m128i *)kvalues_iq5nl + 0);
+ auto values128_2 = _mm_loadu_si128((const __m128i *)kvalues_iq5nl + 1);
+ values[0] = MM256_SET_M128I(values128_1, values128_1);
+ values[1] = MM256_SET_M128I(values128_2, values128_2);
+ }
+
+ Q4Bits bits;
+ Scales8KBase s8k;
+ __m256i hbits;
+ __m256i values[2];
+ const __m128i maskl = _mm_set1_epi8(0xf);
+ const __m128i maskh = _mm_set1_epi8(0x30);
+ const __m256i mh = _mm256_set1_epi8(-128); // to avoid stupid warning about 0x80 overflowing
+ const __m128i mask = _mm_set1_epi16(254);
+ const __m128i m127 = _mm_set1_epi16(-127);
+ const __m128i m128 = _mm_set1_epi16(-128);
+ const __m128i m1 = _mm_set1_epi16(1);
+ const __m128i m2 = _mm_set1_epi16(2);
+};
+
struct DequantizerIQ4KSS final : public BaseDequantizer<block_iq4_kss, true> {
DequantizerIQ4KSS(const void * vx, size_t bx) : BaseDequantizer(vx, bx), values(load_iq4nl_values_256()) {}
template <typename Q8>
@@ -9455,6 +9575,7 @@ template <typename Dequantizer> void MulMat::set_functions(MulMat& m) {
std::is_same_v<Dequantizer, DequantizerIQ3K> ||
std::is_same_v<Dequantizer, DequantizerIQ4XS>||
std::is_same_v<Dequantizer, DequantizerIQ4KS>||
+ std::is_same_v<Dequantizer, DequantizerIQ5KS>||
std::is_same_v<Dequantizer, DequantizerIQ4KSS>) {
m.funcs[0] = mul_mat_iqX_k_q8_K_AVX512<Dequantizer, 1>;
m.funcs[1] = mul_mat_iqX_k_q8_K_AVX512<Dequantizer, 2>;
@@ -9620,6 +9741,10 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) {
assert (ne00 % QK_K == 0);
MulMat::set_functions<DequantizerIQ4KS>(mm);
break;
+ case GGML_TYPE_IQ5_KS:
+ assert (ne00 % QK_K == 0);
+ MulMat::set_functions<DequantizerIQ5KS>(mm);
+ break;
case GGML_TYPE_IQ4_KSS:
assert (ne00 % QK_K == 0);
MulMat::set_functions<DequantizerIQ4KSS>(mm);
@@ -10926,6 +11051,50 @@ struct DequantizerIQ4KS final : public BaseDequantizer<block_iq4_ks, true> {
const int16x8_t m127 = vdupq_n_s16(-127);
};
+struct DequantizerIQ5KS final : public BaseDequantizer<block_iq5_ks, true> {
+ DequantizerIQ5KS(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc), values(vld1q_s8_x2(iq5nl_values)) {}
+
+ constexpr static int num_blocks() { return 8; }
+ constexpr static bool should_scale_quants() { return false; }
+
+ template <typename Q8>
+ inline int32x4x2_t new_block(int i, const Q8& q8, float32x4_t * acc) {
+ (void)q8;
+ (void)acc;
+ auto scales16 = vaddq_s16(vreinterpretq_s16_u16(vandq_u16(vmovl_u8(vld1_u8(x[i].scales)), mask)), m127);
+ int32x4x2_t scales = {vmovl_s16(vget_low_s16(scales16)), vmovl_s16(vget_high_s16(scales16))};
+ return scales;
+ }
+
+ inline void prepare(int i, int j) {
+ bits.prepare(x[i].qs+64*j);
+ if (j == 1) {
+ for (int k = 0; k < 2; ++k) hbits.val[k] = vshrq_n_u8(hbits.val[k], 4);
+ }
+ bits.b1.val[0] = vorrq_u8(bits.b1.val[0], vandq_u8(vshlq_n_u8(hbits.val[0], 4), hm));
+ bits.b1.val[1] = vorrq_u8(bits.b1.val[1], vandq_u8(vshlq_n_u8(hbits.val[1], 4), hm));
+ bits.b1.val[2] = vorrq_u8(bits.b1.val[2], vandq_u8(vshlq_n_u8(hbits.val[0], 3), hm));
+ bits.b1.val[3] = vorrq_u8(bits.b1.val[3], vandq_u8(vshlq_n_u8(hbits.val[1], 3), hm));
+ bits.b2.val[0] = vorrq_u8(bits.b2.val[0], vandq_u8(vshlq_n_u8(hbits.val[0], 2), hm));
+ bits.b2.val[1] = vorrq_u8(bits.b2.val[1], vandq_u8(vshlq_n_u8(hbits.val[1], 2), hm));
+ bits.b2.val[2] = vorrq_u8(bits.b2.val[2], vandq_u8(vshlq_n_u8(hbits.val[0], 1), hm));
+ bits.b2.val[3] = vorrq_u8(bits.b2.val[3], vandq_u8(vshlq_n_u8(hbits.val[1], 1), hm));
+ for (int k = 0; k < 4; ++k) {
+ bits.b1.val[k] = vqtbl2q_s8(values, bits.b1.val[k]);
+ bits.b2.val[k] = vqtbl2q_s8(values, bits.b2.val[k]);
+ }
+ }
+
+ Q4bits bits;
+ const int8x16x2_t values;
+ const uint8x16_t hshuff = vreinterpretq_u8_u32(uint32x4_t{0x09010800, 0x0b030a02, 0x0d050c04, 0x0f070e06});
+ const uint8x16_t hm = vdupq_n_u8(0x10);
+ const uint16x8_t mask = vdupq_n_u16(254);
+ const int16x8_t m127 = vdupq_n_s16(-127);
+ uint8x16x2_t hbits;
+
+};
+
struct DequantizerIQ4KSS final : public BaseDequantizer<block_iq4_kss, true> {
DequantizerIQ4KSS(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc), values(vld1q_s8_x2(iq4k_values)) {}
@@ -14894,6 +15063,9 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) {
case GGML_TYPE_IQ5_K:
MulMat::set_functions<DequantizerIQ5K>(m);
break;
+ case GGML_TYPE_IQ5_KS:
+ MulMat::set_functions<DequantizerIQ5KS>(m);
+ break;
case GGML_TYPE_IQ6_K:
MulMat::set_functions<DequantizerIQ6K>(m);
break;