diff options
Diffstat (limited to 'ggml/src/iqk/iqk_mul_mat.cpp')
-rw-r--r-- | ggml/src/iqk/iqk_mul_mat.cpp | 178 |
1 files changed, 175 insertions, 3 deletions
diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index 66d26a25..7cd0dbf5 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -1209,6 +1209,67 @@ struct DequantizerIQ4KS final : public BaseDequantizer<block_iq4_ks, true> { }; }; +struct DequantizerIQ4KSS final : public BaseDequantizer<block_iq4_kss, true> { + DequantizerIQ4KSS(const void * vx, size_t bx) : BaseDequantizer(vx, bx), values(load_iq4nl_values_512()) {} + template <typename Q8> + inline void new_block(int i, const Q8& q8, __m256 * accm, __m512i * scales) { + uint32_t aux32[2]; + auto b1 = _mm512_loadu_si512((const __m512i *)x[i].qs + 0); + auto b2 = _mm512_loadu_si512((const __m512i *)x[i].qs + 1); + auto bs1 = _mm512_and_si512(b1, mask15); + bs1 = _mm512_xor_si512(bs1, _mm512_srli_epi16(bs1, 1)); + auto bs2 = _mm512_and_si512(b2, mask15); + bs2 = _mm512_xor_si512(bs2, _mm512_srli_epi16(bs2, 1)); + bits.values[0] = _mm512_and_si512(bs1, bits.ml); + bits.values[1] = _mm512_and_si512(_mm512_srli_epi16(bs1, 4), bits.ml); + bits.values[2] = _mm512_and_si512(bs2, bits.ml); + bits.values[3] = _mm512_and_si512(_mm512_srli_epi16(bs2, 4), bits.ml); + auto tmp = _mm512_permutex2var_epi64(bits.values[0], permute1, bits.values[1]); + bits.values[1] = _mm512_shuffle_epi8(values, _mm512_permutex2var_epi64(bits.values[0], permute2, bits.values[1])); + bits.values[0] = _mm512_shuffle_epi8(values, tmp); + tmp = _mm512_permutex2var_epi64(bits.values[2], permute1, bits.values[3]); + bits.values[3] = _mm512_shuffle_epi8(values, _mm512_permutex2var_epi64(bits.values[2], permute2, bits.values[3])); + bits.values[2] = _mm512_shuffle_epi8(values, tmp); + // + // Now the more difficult part - prepare the scales + // + aux32[0] = _mm512_cmpeq_epi16_mask(_mm512_and_si512(b1, mask1), mask1); + aux32[1] = _mm512_cmpeq_epi16_mask(_mm512_and_si512(b2, mask1), mask1); + + auto scales128 = _mm_cvtepu8_epi16(_mm_loadl_epi64((const __m128i *)aux32)); + auto m1 = _mm512_castsi512_si128(mask1); + auto shifts = _mm_and_si128(_mm_cmpeq_epi16(_mm_and_si128(scales128, m1), m1), m4); + scales128 = _mm_add_epi16(_mm_and_si128(scales128, mask), m127); + auto scales_s = _mm_mullo_epi16(scales128, _mm_add_epi16(m128, shifts)); + s8k.accum_mins(scales_s, q8, i, d, accm); + auto scales256 = MM256_SET_M128I(scales128, scales128); + auto all_scales = _mm512_inserti32x8(_mm512_castsi256_si512(scales256), scales256, 1); + scales[0] = _mm512_shuffle_epi8(all_scales, shuffles[0]); + scales[1] = _mm512_shuffle_epi8(all_scales, shuffles[1]); + scales[2] = _mm512_shuffle_epi8(all_scales, shuffles[2]); + scales[3] = _mm512_shuffle_epi8(all_scales, shuffles[3]); + } + + Q4Bits bits; + Scales8KBase s8k; + const __m512i values; + const __m512i mask15 = _mm512_set1_epi16(0xfffe); + const __m512i mask1 = _mm512_set1_epi16(1); + const __m512i permute1 = _mm512_set_epi64(11, 10, 3, 2, 9, 8, 1, 0); + const __m512i permute2 = _mm512_set_epi64(15, 14, 7, 6, 13, 12, 5, 4); + const __m128i mask = _mm_set1_epi16(254); + const __m128i m127 = _mm_set1_epi16(-127); + const __m128i m128 = _mm_set1_epi16(-128); + const __m128i m4 = _mm_set1_epi16(4); + const __m512i shuffles[4] = { + _mm512_inserti32x8(_mm512_set1_epi16(0x0100), _mm256_set1_epi16(0x0302), 1), + _mm512_inserti32x8(_mm512_set1_epi16(0x0504), _mm256_set1_epi16(0x0706), 1), + _mm512_inserti32x8(_mm512_set1_epi16(0x0908), _mm256_set1_epi16(0x0b0a), 1), + _mm512_inserti32x8(_mm512_set1_epi16(0x0d0c), _mm256_set1_epi16(0x0f0e), 1), + }; +}; + + template <typename Q8> inline void compute_block(int iy, int i, float d, const Q8& q8, const __m512i * values, const __m512i * scales, __m512 * accd) { const __m512i p1 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), values[0], q8.load_quants64(iy, i, 0)); @@ -1821,8 +1882,54 @@ struct DequantizerIQ4KS final : public BaseDequantizer<block_iq4_ks, true> { const __m128i m128 = _mm_set1_epi16(-128); const __m128i m1 = _mm_set1_epi16(1); const __m128i m4 = _mm_set1_epi16(4); - const __m256i shuff1 = _mm256_set_epi64x(0x0706070605040504, 0x0302030201000100, 0x0706070605040504, 0x0302030201000100); - const __m256i shuff2 = _mm256_set_epi64x(0x0f0e0f0e0d0c0d0c, 0x0b0a0b0a09080908, 0x0f0e0f0e0d0c0d0c, 0x0b0a0b0a09080908); +}; + +struct DequantizerIQ4KSS final : public BaseDequantizer<block_iq4_kss, true> { + DequantizerIQ4KSS(const void * vx, size_t bx) : BaseDequantizer(vx, bx), values(load_iq4nl_values_256()) {} + template <typename Q8> + inline __m256i new_block(int i, const Q8& q8, __m256 * accd) { + union { __m256i vec; uint16_t val[16]; } helper; + for (int k = 0; k < 4; ++k) { + data[k] = _mm256_loadu_si256((const __m256i *)x[i].qs + k); + auto p = _mm256_and_si256(_mm256_cmpeq_epi16(_mm256_and_si256(data[k], m1), m1), smask); + p = _mm256_add_epi32(_mm256_unpackhi_epi64(p, p), p); + p = _mm256_add_epi32(_mm256_shuffle_epi32(p, _MM_SHUFFLE(2, 3, 0, 1)), p); + helper.vec = _mm256_hadd_epi16(p, p); + aux[2*k+0] = helper.val[0]; + aux[2*k+1] = helper.val[8]; + data[k] = _mm256_and_si256(data[k], bmask); + data[k] = _mm256_xor_si256(data[k], _mm256_srli_epi16(data[k], 1)); + } + auto scales128 = _mm_loadu_si128((const __m128i *)aux); + auto shifts = _mm_and_si128(_mm_cmpeq_epi16(_mm_and_si128(scales128, _mm256_castsi256_si128(m1)), _mm256_castsi256_si128(m1)), m4); + scales128 = _mm_add_epi16(_mm_and_si128(scales128, mask), m127); + auto scales_s = _mm_mullo_epi16(scales128, _mm_add_epi16(m128, shifts)); + s8k.accum_mins(scales_s, q8, i, d, accd); + return MM256_SET_M128I(scales128, scales128); + } + inline void prepare(int, int j) { + for (int k = 0; k < 2; ++k) { + auto p1 = _mm256_castsi256_si128(data[2*j+k]); + auto p2 = _mm256_extractf128_si256(data[2*j+k], 1); + bits.values[2*k+0] = _mm256_and_si256(MM256_SET_M128I(_mm_srli_epi16(p1, 4), p1), bits.ml); + bits.values[2*k+0] = _mm256_shuffle_epi8(values, bits.values[2*k+0]); + bits.values[2*k+1] = _mm256_and_si256(MM256_SET_M128I(_mm_srli_epi16(p2, 4), p2), bits.ml); + bits.values[2*k+1] = _mm256_shuffle_epi8(values, bits.values[2*k+1]); + } + } + + Q4Bits bits; + Scales8KBase s8k; + const __m256i values; + __m256i data[4]; + const __m256i smask = _mm256_set_epi64x(0x0080004000200010, 0x0008000400020001, 0x0080004000200010, 0x0008000400020001); + const __m256i bmask = _mm256_set1_epi16(0xfffe); + const __m128i mask = _mm_set1_epi16(254); + const __m128i m127 = _mm_set1_epi16(-127); + const __m128i m128 = _mm_set1_epi16(-128); + const __m256i m1 = _mm256_set1_epi16(1); + const __m128i m4 = _mm_set1_epi16(4); + uint16_t aux[8]; }; struct DequantizerIQ2KS final : public BaseDequantizer<block_iq2_ks, true, true> { @@ -3848,7 +3955,8 @@ template <typename Dequantizer> void MulMat::set_functions(MulMat& m) { std::is_same_v<Dequantizer, DequantizerIQ4K> || std::is_same_v<Dequantizer, DequantizerIQ3K> || std::is_same_v<Dequantizer, DequantizerIQ4XS>|| - std::is_same_v<Dequantizer, DequantizerIQ4KS>) { + std::is_same_v<Dequantizer, DequantizerIQ4KS>|| + std::is_same_v<Dequantizer, DequantizerIQ4KSS>) { m.funcs[0] = mul_mat_iqX_k_q8_K_AVX512<Dequantizer, 1>; m.funcs[1] = mul_mat_iqX_k_q8_K_AVX512<Dequantizer, 2>; m.funcs[2] = mul_mat_iqX_k_q8_K_AVX512<Dequantizer, 3>; @@ -4012,6 +4120,10 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) { assert (ne00 % QK_K == 0); MulMat::set_functions<DequantizerIQ4KS>(mm); break; + case GGML_TYPE_IQ4_KSS: + assert (ne00 % QK_K == 0); + MulMat::set_functions<DequantizerIQ4KSS>(mm); + break; case GGML_TYPE_IQ2_K: assert (ne00 % QK_K == 0); MulMat::set_functions<DequantizerIQ2K>(mm); @@ -4945,6 +5057,63 @@ struct DequantizerIQ4KS final : public BaseDequantizer<block_iq4_ks, true> { const int16x8_t m127 = vdupq_n_s16(-127); }; +struct DequantizerIQ4KSS final : public BaseDequantizer<block_iq4_kss, true> { + + DequantizerIQ4KSS(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc), values(vld1q_s8_x2(iq4k_values)) {} + + constexpr static int num_blocks() { return 8; } + constexpr static bool should_scale_quants() { return false; } + + template <typename Q8> + inline int32x4x2_t new_block(int i, const Q8& q8, float32x4_t * acc) { + (void)q8; + (void)acc; + auto q4bits_1 = vld1q_u16_x4((const uint16_t *)x[i].qs); + q4bits_2 = vld1q_u16_x4((const uint16_t *)x[i].qs + 32); + for (int k = 0; k < 4; ++k) { + aux[k+0] = vaddvq_s16(vshlq_s16(vandq_u16(q4bits_1.val[k], m1), shift)); + aux[k+4] = vaddvq_s16(vshlq_s16(vandq_u16(q4bits_2.val[k], m1), shift)); + q4bits_1.val[k] = vandq_u16(q4bits_1.val[k], bmask); + q4bits_1.val[k] = veorq_u16(q4bits_1.val[k], vshrq_n_u16(q4bits_1.val[k], 1)); + q4bits_2.val[k] = vandq_u16(q4bits_2.val[k], bmask); + q4bits_2.val[k] = veorq_u16(q4bits_2.val[k], vshrq_n_u16(q4bits_2.val[k], 1)); + } + make_quants(q4bits_1, bits, aux); + auto scales16 = vld1q_s16(aux); + scales16 = vaddq_s16(vandq_s16(scales16, mask), m127); + int32x4x2_t scales = {vmovl_s16(vget_low_s16(scales16)), vmovl_s16(vget_high_s16(scales16))}; + return scales; + } + inline void make_quants(uint16x8x4_t& q4bits, Q4bits& bits, const int16_t * aux) const { + bits.b1.val[0] = vqtbl1q_s8(values.val[aux[0] & 1], vandq_u8(q4bits.val[0], bits.m4b)); + bits.b1.val[1] = vqtbl1q_s8(values.val[aux[0] & 1], vshrq_n_u8(q4bits.val[0], 4)); + bits.b1.val[2] = vqtbl1q_s8(values.val[aux[1] & 1], vandq_u8(q4bits.val[1], bits.m4b)); + bits.b1.val[3] = vqtbl1q_s8(values.val[aux[1] & 1], vshrq_n_u8(q4bits.val[1], 4)); + bits.b2.val[0] = vqtbl1q_s8(values.val[aux[2] & 1], vandq_u8(q4bits.val[2], bits.m4b)); + bits.b2.val[1] = vqtbl1q_s8(values.val[aux[2] & 1], vshrq_n_u8(q4bits.val[2], 4)); + bits.b2.val[2] = vqtbl1q_s8(values.val[aux[3] & 1], vandq_u8(q4bits.val[3], bits.m4b)); + bits.b2.val[3] = vqtbl1q_s8(values.val[aux[3] & 1], vshrq_n_u8(q4bits.val[3], 4)); + } + inline void prepare([[maybe_unused]] int i, int j) { + if (j == 0) return; + make_quants(q4bits_2, bits, aux+4); + } + static int16x8_t load_shift() { + static const int16_t k_shift[8] = {0, 1, 2, 3, 4, 5, 6, 7}; + return vld1q_s16(k_shift); + } + + Q4bits bits; + const int8x16x2_t values; + const uint16x8_t mask = vdupq_n_s16(254); + const uint16x8_t bmask = vdupq_n_u16(0xfffe); + const uint16x8_t m1 = vdupq_n_u16(1); + const int16x8_t shift = load_shift(); + const int16x8_t m127 = vdupq_n_s16(-127); + uint16x8x4_t q4bits_2; + int16_t aux[8]; +}; + struct DequantizerIQ2KS final : public BaseDequantizer<block_iq2_ks, true, true> { DequantizerIQ2KS(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {} @@ -6716,6 +6885,9 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) { case GGML_TYPE_IQ4_KS: MulMat::set_functions<DequantizerIQ4KS>(m); break; + case GGML_TYPE_IQ4_KSS: + MulMat::set_functions<DequantizerIQ4KSS>(m); + break; case GGML_TYPE_IQ2_KS: MulMat::set_functions<DequantizerIQ2KS>(m); break; |