summaryrefslogtreecommitdiff
path: root/ggml/src/iqk/iqk_mul_mat.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'ggml/src/iqk/iqk_mul_mat.cpp')
-rw-r--r--ggml/src/iqk/iqk_mul_mat.cpp178
1 files changed, 175 insertions, 3 deletions
diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp
index 66d26a25..7cd0dbf5 100644
--- a/ggml/src/iqk/iqk_mul_mat.cpp
+++ b/ggml/src/iqk/iqk_mul_mat.cpp
@@ -1209,6 +1209,67 @@ struct DequantizerIQ4KS final : public BaseDequantizer<block_iq4_ks, true> {
};
};
+struct DequantizerIQ4KSS final : public BaseDequantizer<block_iq4_kss, true> {
+ DequantizerIQ4KSS(const void * vx, size_t bx) : BaseDequantizer(vx, bx), values(load_iq4nl_values_512()) {}
+ template <typename Q8>
+ inline void new_block(int i, const Q8& q8, __m256 * accm, __m512i * scales) {
+ uint32_t aux32[2];
+ auto b1 = _mm512_loadu_si512((const __m512i *)x[i].qs + 0);
+ auto b2 = _mm512_loadu_si512((const __m512i *)x[i].qs + 1);
+ auto bs1 = _mm512_and_si512(b1, mask15);
+ bs1 = _mm512_xor_si512(bs1, _mm512_srli_epi16(bs1, 1));
+ auto bs2 = _mm512_and_si512(b2, mask15);
+ bs2 = _mm512_xor_si512(bs2, _mm512_srli_epi16(bs2, 1));
+ bits.values[0] = _mm512_and_si512(bs1, bits.ml);
+ bits.values[1] = _mm512_and_si512(_mm512_srli_epi16(bs1, 4), bits.ml);
+ bits.values[2] = _mm512_and_si512(bs2, bits.ml);
+ bits.values[3] = _mm512_and_si512(_mm512_srli_epi16(bs2, 4), bits.ml);
+ auto tmp = _mm512_permutex2var_epi64(bits.values[0], permute1, bits.values[1]);
+ bits.values[1] = _mm512_shuffle_epi8(values, _mm512_permutex2var_epi64(bits.values[0], permute2, bits.values[1]));
+ bits.values[0] = _mm512_shuffle_epi8(values, tmp);
+ tmp = _mm512_permutex2var_epi64(bits.values[2], permute1, bits.values[3]);
+ bits.values[3] = _mm512_shuffle_epi8(values, _mm512_permutex2var_epi64(bits.values[2], permute2, bits.values[3]));
+ bits.values[2] = _mm512_shuffle_epi8(values, tmp);
+ //
+ // Now the more difficult part - prepare the scales
+ //
+ aux32[0] = _mm512_cmpeq_epi16_mask(_mm512_and_si512(b1, mask1), mask1);
+ aux32[1] = _mm512_cmpeq_epi16_mask(_mm512_and_si512(b2, mask1), mask1);
+
+ auto scales128 = _mm_cvtepu8_epi16(_mm_loadl_epi64((const __m128i *)aux32));
+ auto m1 = _mm512_castsi512_si128(mask1);
+ auto shifts = _mm_and_si128(_mm_cmpeq_epi16(_mm_and_si128(scales128, m1), m1), m4);
+ scales128 = _mm_add_epi16(_mm_and_si128(scales128, mask), m127);
+ auto scales_s = _mm_mullo_epi16(scales128, _mm_add_epi16(m128, shifts));
+ s8k.accum_mins(scales_s, q8, i, d, accm);
+ auto scales256 = MM256_SET_M128I(scales128, scales128);
+ auto all_scales = _mm512_inserti32x8(_mm512_castsi256_si512(scales256), scales256, 1);
+ scales[0] = _mm512_shuffle_epi8(all_scales, shuffles[0]);
+ scales[1] = _mm512_shuffle_epi8(all_scales, shuffles[1]);
+ scales[2] = _mm512_shuffle_epi8(all_scales, shuffles[2]);
+ scales[3] = _mm512_shuffle_epi8(all_scales, shuffles[3]);
+ }
+
+ Q4Bits bits;
+ Scales8KBase s8k;
+ const __m512i values;
+ const __m512i mask15 = _mm512_set1_epi16(0xfffe);
+ const __m512i mask1 = _mm512_set1_epi16(1);
+ const __m512i permute1 = _mm512_set_epi64(11, 10, 3, 2, 9, 8, 1, 0);
+ const __m512i permute2 = _mm512_set_epi64(15, 14, 7, 6, 13, 12, 5, 4);
+ const __m128i mask = _mm_set1_epi16(254);
+ const __m128i m127 = _mm_set1_epi16(-127);
+ const __m128i m128 = _mm_set1_epi16(-128);
+ const __m128i m4 = _mm_set1_epi16(4);
+ const __m512i shuffles[4] = {
+ _mm512_inserti32x8(_mm512_set1_epi16(0x0100), _mm256_set1_epi16(0x0302), 1),
+ _mm512_inserti32x8(_mm512_set1_epi16(0x0504), _mm256_set1_epi16(0x0706), 1),
+ _mm512_inserti32x8(_mm512_set1_epi16(0x0908), _mm256_set1_epi16(0x0b0a), 1),
+ _mm512_inserti32x8(_mm512_set1_epi16(0x0d0c), _mm256_set1_epi16(0x0f0e), 1),
+ };
+};
+
+
template <typename Q8>
inline void compute_block(int iy, int i, float d, const Q8& q8, const __m512i * values, const __m512i * scales, __m512 * accd) {
const __m512i p1 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), values[0], q8.load_quants64(iy, i, 0));
@@ -1821,8 +1882,54 @@ struct DequantizerIQ4KS final : public BaseDequantizer<block_iq4_ks, true> {
const __m128i m128 = _mm_set1_epi16(-128);
const __m128i m1 = _mm_set1_epi16(1);
const __m128i m4 = _mm_set1_epi16(4);
- const __m256i shuff1 = _mm256_set_epi64x(0x0706070605040504, 0x0302030201000100, 0x0706070605040504, 0x0302030201000100);
- const __m256i shuff2 = _mm256_set_epi64x(0x0f0e0f0e0d0c0d0c, 0x0b0a0b0a09080908, 0x0f0e0f0e0d0c0d0c, 0x0b0a0b0a09080908);
+};
+
+struct DequantizerIQ4KSS final : public BaseDequantizer<block_iq4_kss, true> {
+ DequantizerIQ4KSS(const void * vx, size_t bx) : BaseDequantizer(vx, bx), values(load_iq4nl_values_256()) {}
+ template <typename Q8>
+ inline __m256i new_block(int i, const Q8& q8, __m256 * accd) {
+ union { __m256i vec; uint16_t val[16]; } helper;
+ for (int k = 0; k < 4; ++k) {
+ data[k] = _mm256_loadu_si256((const __m256i *)x[i].qs + k);
+ auto p = _mm256_and_si256(_mm256_cmpeq_epi16(_mm256_and_si256(data[k], m1), m1), smask);
+ p = _mm256_add_epi32(_mm256_unpackhi_epi64(p, p), p);
+ p = _mm256_add_epi32(_mm256_shuffle_epi32(p, _MM_SHUFFLE(2, 3, 0, 1)), p);
+ helper.vec = _mm256_hadd_epi16(p, p);
+ aux[2*k+0] = helper.val[0];
+ aux[2*k+1] = helper.val[8];
+ data[k] = _mm256_and_si256(data[k], bmask);
+ data[k] = _mm256_xor_si256(data[k], _mm256_srli_epi16(data[k], 1));
+ }
+ auto scales128 = _mm_loadu_si128((const __m128i *)aux);
+ auto shifts = _mm_and_si128(_mm_cmpeq_epi16(_mm_and_si128(scales128, _mm256_castsi256_si128(m1)), _mm256_castsi256_si128(m1)), m4);
+ scales128 = _mm_add_epi16(_mm_and_si128(scales128, mask), m127);
+ auto scales_s = _mm_mullo_epi16(scales128, _mm_add_epi16(m128, shifts));
+ s8k.accum_mins(scales_s, q8, i, d, accd);
+ return MM256_SET_M128I(scales128, scales128);
+ }
+ inline void prepare(int, int j) {
+ for (int k = 0; k < 2; ++k) {
+ auto p1 = _mm256_castsi256_si128(data[2*j+k]);
+ auto p2 = _mm256_extractf128_si256(data[2*j+k], 1);
+ bits.values[2*k+0] = _mm256_and_si256(MM256_SET_M128I(_mm_srli_epi16(p1, 4), p1), bits.ml);
+ bits.values[2*k+0] = _mm256_shuffle_epi8(values, bits.values[2*k+0]);
+ bits.values[2*k+1] = _mm256_and_si256(MM256_SET_M128I(_mm_srli_epi16(p2, 4), p2), bits.ml);
+ bits.values[2*k+1] = _mm256_shuffle_epi8(values, bits.values[2*k+1]);
+ }
+ }
+
+ Q4Bits bits;
+ Scales8KBase s8k;
+ const __m256i values;
+ __m256i data[4];
+ const __m256i smask = _mm256_set_epi64x(0x0080004000200010, 0x0008000400020001, 0x0080004000200010, 0x0008000400020001);
+ const __m256i bmask = _mm256_set1_epi16(0xfffe);
+ const __m128i mask = _mm_set1_epi16(254);
+ const __m128i m127 = _mm_set1_epi16(-127);
+ const __m128i m128 = _mm_set1_epi16(-128);
+ const __m256i m1 = _mm256_set1_epi16(1);
+ const __m128i m4 = _mm_set1_epi16(4);
+ uint16_t aux[8];
};
struct DequantizerIQ2KS final : public BaseDequantizer<block_iq2_ks, true, true> {
@@ -3848,7 +3955,8 @@ template <typename Dequantizer> void MulMat::set_functions(MulMat& m) {
std::is_same_v<Dequantizer, DequantizerIQ4K> ||
std::is_same_v<Dequantizer, DequantizerIQ3K> ||
std::is_same_v<Dequantizer, DequantizerIQ4XS>||
- std::is_same_v<Dequantizer, DequantizerIQ4KS>) {
+ std::is_same_v<Dequantizer, DequantizerIQ4KS>||
+ std::is_same_v<Dequantizer, DequantizerIQ4KSS>) {
m.funcs[0] = mul_mat_iqX_k_q8_K_AVX512<Dequantizer, 1>;
m.funcs[1] = mul_mat_iqX_k_q8_K_AVX512<Dequantizer, 2>;
m.funcs[2] = mul_mat_iqX_k_q8_K_AVX512<Dequantizer, 3>;
@@ -4012,6 +4120,10 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) {
assert (ne00 % QK_K == 0);
MulMat::set_functions<DequantizerIQ4KS>(mm);
break;
+ case GGML_TYPE_IQ4_KSS:
+ assert (ne00 % QK_K == 0);
+ MulMat::set_functions<DequantizerIQ4KSS>(mm);
+ break;
case GGML_TYPE_IQ2_K:
assert (ne00 % QK_K == 0);
MulMat::set_functions<DequantizerIQ2K>(mm);
@@ -4945,6 +5057,63 @@ struct DequantizerIQ4KS final : public BaseDequantizer<block_iq4_ks, true> {
const int16x8_t m127 = vdupq_n_s16(-127);
};
+struct DequantizerIQ4KSS final : public BaseDequantizer<block_iq4_kss, true> {
+
+ DequantizerIQ4KSS(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc), values(vld1q_s8_x2(iq4k_values)) {}
+
+ constexpr static int num_blocks() { return 8; }
+ constexpr static bool should_scale_quants() { return false; }
+
+ template <typename Q8>
+ inline int32x4x2_t new_block(int i, const Q8& q8, float32x4_t * acc) {
+ (void)q8;
+ (void)acc;
+ auto q4bits_1 = vld1q_u16_x4((const uint16_t *)x[i].qs);
+ q4bits_2 = vld1q_u16_x4((const uint16_t *)x[i].qs + 32);
+ for (int k = 0; k < 4; ++k) {
+ aux[k+0] = vaddvq_s16(vshlq_s16(vandq_u16(q4bits_1.val[k], m1), shift));
+ aux[k+4] = vaddvq_s16(vshlq_s16(vandq_u16(q4bits_2.val[k], m1), shift));
+ q4bits_1.val[k] = vandq_u16(q4bits_1.val[k], bmask);
+ q4bits_1.val[k] = veorq_u16(q4bits_1.val[k], vshrq_n_u16(q4bits_1.val[k], 1));
+ q4bits_2.val[k] = vandq_u16(q4bits_2.val[k], bmask);
+ q4bits_2.val[k] = veorq_u16(q4bits_2.val[k], vshrq_n_u16(q4bits_2.val[k], 1));
+ }
+ make_quants(q4bits_1, bits, aux);
+ auto scales16 = vld1q_s16(aux);
+ scales16 = vaddq_s16(vandq_s16(scales16, mask), m127);
+ int32x4x2_t scales = {vmovl_s16(vget_low_s16(scales16)), vmovl_s16(vget_high_s16(scales16))};
+ return scales;
+ }
+ inline void make_quants(uint16x8x4_t& q4bits, Q4bits& bits, const int16_t * aux) const {
+ bits.b1.val[0] = vqtbl1q_s8(values.val[aux[0] & 1], vandq_u8(q4bits.val[0], bits.m4b));
+ bits.b1.val[1] = vqtbl1q_s8(values.val[aux[0] & 1], vshrq_n_u8(q4bits.val[0], 4));
+ bits.b1.val[2] = vqtbl1q_s8(values.val[aux[1] & 1], vandq_u8(q4bits.val[1], bits.m4b));
+ bits.b1.val[3] = vqtbl1q_s8(values.val[aux[1] & 1], vshrq_n_u8(q4bits.val[1], 4));
+ bits.b2.val[0] = vqtbl1q_s8(values.val[aux[2] & 1], vandq_u8(q4bits.val[2], bits.m4b));
+ bits.b2.val[1] = vqtbl1q_s8(values.val[aux[2] & 1], vshrq_n_u8(q4bits.val[2], 4));
+ bits.b2.val[2] = vqtbl1q_s8(values.val[aux[3] & 1], vandq_u8(q4bits.val[3], bits.m4b));
+ bits.b2.val[3] = vqtbl1q_s8(values.val[aux[3] & 1], vshrq_n_u8(q4bits.val[3], 4));
+ }
+ inline void prepare([[maybe_unused]] int i, int j) {
+ if (j == 0) return;
+ make_quants(q4bits_2, bits, aux+4);
+ }
+ static int16x8_t load_shift() {
+ static const int16_t k_shift[8] = {0, 1, 2, 3, 4, 5, 6, 7};
+ return vld1q_s16(k_shift);
+ }
+
+ Q4bits bits;
+ const int8x16x2_t values;
+ const uint16x8_t mask = vdupq_n_s16(254);
+ const uint16x8_t bmask = vdupq_n_u16(0xfffe);
+ const uint16x8_t m1 = vdupq_n_u16(1);
+ const int16x8_t shift = load_shift();
+ const int16x8_t m127 = vdupq_n_s16(-127);
+ uint16x8x4_t q4bits_2;
+ int16_t aux[8];
+};
+
struct DequantizerIQ2KS final : public BaseDequantizer<block_iq2_ks, true, true> {
DequantizerIQ2KS(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {}
@@ -6716,6 +6885,9 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) {
case GGML_TYPE_IQ4_KS:
MulMat::set_functions<DequantizerIQ4KS>(m);
break;
+ case GGML_TYPE_IQ4_KSS:
+ MulMat::set_functions<DequantizerIQ4KSS>(m);
+ break;
case GGML_TYPE_IQ2_KS:
MulMat::set_functions<DequantizerIQ2KS>(m);
break;