diff options
Diffstat (limited to 'ggml/src/iqk/iqk_gemm_iqk_quants.cpp')
-rw-r--r-- | ggml/src/iqk/iqk_gemm_iqk_quants.cpp | 3289 |
1 files changed, 3289 insertions, 0 deletions
diff --git a/ggml/src/iqk/iqk_gemm_iqk_quants.cpp b/ggml/src/iqk/iqk_gemm_iqk_quants.cpp new file mode 100644 index 00000000..15c963ca --- /dev/null +++ b/ggml/src/iqk/iqk_gemm_iqk_quants.cpp @@ -0,0 +1,3289 @@ +#include "iqk_gemm_iqk_quants.h" + +#ifdef IQK_IMPLEMENT + +#include "ggml-impl.h" + +#define GGML_COMMON_IMPL_C +#include "ggml-common.h" + +#ifdef __x86_64__ + +namespace { + +#ifdef HAVE_FANCY_SIMD + +struct IQXKScales { + IQXKScales(uint8_t shift, int8_t min_val) : eshift(_mm256_set1_epi16(shift)), min(_mm256_set1_epi16(min_val)) {} + template <typename Q8> + inline void process(int i, float d, uint16_t extra, __m128i scales8, const Q8& q8, __m256 * accm, __m512i * scales) const { + auto scales16 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales8, scale_shuffle)); + scales16 = _mm256_mullo_epi16(scales16, _mm256_mask_add_epi16(min, extra, min, eshift)); + for (int iy = 0; iy < Q8::nrc_y; ++iy) { + const __m256i prod = _mm256_madd_epi16(scales16, q8.load_bsums(iy, i)); + accm[iy] = _mm256_fmadd_ps(_mm256_set1_ps(d * q8.scale(iy, i)), _mm256_cvtepi32_ps(prod), accm[iy]); + } + scales16 = MM256_SET_M128I(scales8, scales8); + scales[0] = _mm512_cvtepi8_epi16(_mm256_shuffle_epi8(scales16, shuffle1)); + scales[1] = _mm512_cvtepi8_epi16(_mm256_shuffle_epi8(scales16, shuffle2)); + } + const __m256i eshift; + const __m256i min; + const __m128i scale_shuffle = _mm_set_epi32(0x0f070e06, 0x0d050c04, 0x0b030a02, 0x09010800); + const __m128i emask = _mm_set_epi32(0x80804040, 0x20201010, 0x08080404, 0x02020101); + const __m128i eshuffle = _mm_set_epi32(0x0f0d0b09, 0x07050301, 0x0e0c0a08, 0x06040200); + const __m256i shuffle1 = _mm256_set_epi64x(0x0b0b0b0b09090909, 0x0303030301010101, 0x0a0a0a0a08080808, 0x0202020200000000); + const __m256i shuffle2 = _mm256_set_epi64x(0x0f0f0f0f0d0d0d0d, 0x0707070705050505, 0x0e0e0e0e0c0c0c0c, 0x0606060604040404); +}; + +struct IQXKScales2 { + IQXKScales2(uint8_t shift, int8_t min_val) : eshift(_mm256_set1_epi16(shift)), min(_mm256_set1_epi16(min_val)) {} + template <typename Q8> + inline void process(int i, float d, uint16_t extra, __m128i scales8, const Q8& q8, __m256 * accm, __m512i * scales) const { + process(i, d, extra, _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales8, scale_shuffle)), q8, accm, scales); + } + template <typename Q8> + inline void process(int i, float d, uint16_t extra, __m256i scales16, const Q8& q8, __m256 * accm, __m512i * scales) const { + auto scales_s = _mm256_mullo_epi16(scales16, _mm256_mask_add_epi16(min, extra, min, eshift)); + for (int iy = 0; iy < Q8::nrc_y; ++iy) { + const __m256i prod = _mm256_madd_epi16(scales_s, q8.load_bsums(iy, i)); + accm[iy] = _mm256_fmadd_ps(_mm256_set1_ps(d * q8.scale(iy, i)), _mm256_cvtepi32_ps(prod), accm[iy]); + } + auto aux_1 = MM256_SET_M128I(_mm256_castsi256_si128(scales16), _mm256_castsi256_si128(scales16)); + auto aux_2 = MM256_SET_M128I(_mm256_extracti128_si256(scales16, 1), _mm256_extracti128_si256(scales16, 1)); + auto scales16_1 = _mm512_inserti32x8(_mm512_castsi256_si512(aux_1), aux_1, 1); + auto scales16_2 = _mm512_inserti32x8(_mm512_castsi256_si512(aux_2), aux_2, 1); + scales[0] = _mm512_shuffle_epi8(scales16_1, shuffles[0]); + scales[1] = _mm512_shuffle_epi8(scales16_1, shuffles[1]); + scales[2] = _mm512_shuffle_epi8(scales16_2, shuffles[0]); + scales[3] = _mm512_shuffle_epi8(scales16_2, shuffles[1]); + } + const __m256i eshift; + const __m256i min; + const __m128i scale_shuffle = _mm_set_epi32(0x0f070e06, 0x0d050c04, 0x0b030a02, 0x09010800); + const __m128i emask = _mm_set_epi32(0x80804040, 0x20201010, 0x08080404, 0x02020101); + const __m128i eshuffle = _mm_set_epi32(0x0f0d0b09, 0x07050301, 0x0e0c0a08, 0x06040200); + const __m512i shuffles[2] = { + _mm512_inserti32x4(_mm512_inserti32x4(_mm512_inserti32x4(_mm512_inserti32x4(_mm512_setzero_si512(), + _mm_set1_epi16(0x0100), 0), _mm_set1_epi16(0x0302), 1), _mm_set1_epi16(0x0504), 2), _mm_set1_epi16(0x0706), 3), + _mm512_inserti32x4(_mm512_inserti32x4(_mm512_inserti32x4(_mm512_inserti32x4(_mm512_setzero_si512(), + _mm_set1_epi16(0x0908), 0), _mm_set1_epi16(0x0b0a), 1), _mm_set1_epi16(0x0d0c), 2), _mm_set1_epi16(0x0f0e), 3) + }; +}; + +struct DequantizerIQ2KS final : public BaseDequantizer<block_iq2_ks, true, true> { + DequantizerIQ2KS(const void * vx, size_t bx) : BaseDequantizer(vx, bx), values(load_values()) {} + template <typename Q8> + inline void compute_block(int i, const Q8& q8, __m512 * acc) { + prepare(x[i].qs); + auto scales128 = make_scales(x[i].scales, x[i].extra >> 8); + auto shifts = _mm_and_si128(_mm_cmpeq_epi8(_mm_and_si128(_mm_set1_epi8(x[i].extra), hmask), hmask), m5); + auto mins128 = _mm_mullo_epi16(scales128, _mm_cvtepi8_epi16(_mm_add_epi8(m32, shifts))); + auto mins = MM256_SET_M128I(_mm_shuffle_epi8(mins128, s8k.shuffles[1]), _mm_shuffle_epi8(mins128, s8k.shuffles[0])); + auto scales256 = MM256_SET_M128I(scales128, scales128); + auto all_scales = _mm512_inserti32x8(_mm512_castsi256_si512(scales256), scales256, 1); + __m512i scales[4]; + for (int k = 0; k < 4; ++k) scales[k] = _mm512_shuffle_epi8(all_scales, shuffles[k]); + for (int iy = 0; iy < Q8::nrc_y; ++iy) { + auto q8s = q8.load_bsums(iy, i); + auto prod = _mm256_madd_epi16(mins, q8s); + auto sumi = _mm512_inserti32x8(_mm512_setzero_si512(), prod, 0); + for (int k = 0; k < 4; ++k) { + auto p = _mm512_maddubs_epi16(bits.values[k], q8.load_quants64(iy, i, k)); + sumi = _mm512_dpwssd_epi32(sumi, p, scales[k]); + } + acc[iy] = _mm512_fmadd_ps(_mm512_set1_ps(d*q8.scale(iy, i)), _mm512_cvtepi32_ps(sumi), acc[iy]); + } + } + inline void prepare(const uint8_t * q2) { + bits.prepare(q2); + bits.values[0] = _mm512_shuffle_epi8(values, bits.values[0]); + bits.values[1] = _mm512_shuffle_epi8(values, bits.values[1]); + bits.values[2] = _mm512_shuffle_epi8(values, bits.values[2]); + bits.values[3] = _mm512_shuffle_epi8(values, bits.values[3]); + } + static inline __m512i load_values() { + static const uint8_t kvalues_iq2nl[16] = {1, 19, 33, 49, 0, 0, 0, 0, 6, 24, 38, 54, 0, 0, 0, 0}; + auto val128 = _mm_loadu_si128((const __m128i *)kvalues_iq2nl); + auto val256 = MM256_SET_M128I(val128, val128); + return _mm512_inserti32x8(_mm512_castsi256_si512(val256), val256, 1); + } + inline __m128i make_scales(const uint8_t * scales_l, uint8_t scales_h) const { + const uint16_t * scales = (const uint16_t *)scales_l; + uint32_t aux32 = scales[0] | (uint32_t(scales[1]) << 16); + auto scl = _mm_srlv_epi32(_mm_set1_epi32(aux32), shift); + scl = _mm_and_si128(_mm_shuffle_epi8(scl, shuffle), _mm_set1_epi8(0xf)); + auto sch = _mm_set1_epi8(scales_h); + sch = _mm_and_si128(_mm_cmpeq_epi8(_mm_and_si128(sch, hmask), _mm_setzero_si128()), m16); + return _mm_cvtepi8_epi16(_mm_add_epi8(scl, sch)); + } + Q2Bits bits; + Scales8KBase s8k; + + const __m512i values; + const __m128i m16 = _mm_set1_epi8(-16); + const __m128i m5 = _mm_set1_epi8(5); + const __m128i m32 = _mm_set1_epi8(-32); + const __m128i hmask = _mm_set1_epi64x(0x8040201008040201); + const __m128i shuffle = _mm_set1_epi64x(0x0703060205010400); + const __m128i shift = _mm_set_epi32(0, 0, 4, 0); + const __m512i shuffles[4] = { + _mm512_inserti32x8(_mm512_set1_epi16(0x0100), _mm256_set1_epi16(0x0302), 1), + _mm512_inserti32x8(_mm512_set1_epi16(0x0504), _mm256_set1_epi16(0x0706), 1), + _mm512_inserti32x8(_mm512_set1_epi16(0x0908), _mm256_set1_epi16(0x0b0a), 1), + _mm512_inserti32x8(_mm512_set1_epi16(0x0d0c), _mm256_set1_epi16(0x0f0e), 1), + }; +}; + +struct DequantizerIQ2K final : public BaseDequantizer<block_iq2_k> { + DequantizerIQ2K(const void * vx, size_t bx) : BaseDequantizer(vx, bx), iqxk(IQXKScales(5, -32)), values(load_values()) {} + template <typename Q8> + inline void new_block(int i, const Q8& q8, __m256 * accm, __m512i * scales) { + d = GGML_FP16_TO_FP32(x[i].d); + prepare(x[i].qs); + iqxk.process(i, d, x[i].extra, make_scales(x[i].scales), q8, accm, scales); + } + inline void prepare(const uint8_t * q2) { + bits.prepare(q2); + bits.values[0] = _mm512_shuffle_epi8(values, bits.values[0]); + bits.values[1] = _mm512_shuffle_epi8(values, bits.values[1]); + bits.values[2] = _mm512_shuffle_epi8(values, bits.values[2]); + bits.values[3] = _mm512_shuffle_epi8(values, bits.values[3]); + } + static inline __m512i load_values() { + static const uint8_t kvalues_iq2nl[16] = {1, 19, 33, 49, 0, 0, 0, 0, 6, 24, 38, 54, 0, 0, 0, 0}; + auto val128 = _mm_loadu_si128((const __m128i *)kvalues_iq2nl); + auto val256 = MM256_SET_M128I(val128, val128); + return _mm512_inserti32x8(_mm512_castsi256_si512(val256), val256, 1); + } + inline __m128i make_scales(const uint8_t * scales_l) const { + uint64_t aux64; std::memcpy(&aux64, scales_l, 8); + auto scl = _mm_and_si128(_mm_set_epi64x(aux64 >> 4, aux64), _mm_set1_epi8(0xf)); + return _mm_add_epi8(scl, m8); + } + Q2Bits bits; + const IQXKScales iqxk; + + const __m512i values; + const __m128i m8 = _mm_set1_epi8(-8); +}; + +struct DequantizerIQ3K final : public BaseDequantizer<block_iq3_k> { + DequantizerIQ3K(const void * vx, size_t bx) : BaseDequantizer(vx, bx), iqxk(4, -64), values(load_values()) {} + template <typename Q8> + inline void new_block(int i, const Q8& q8, __m256 * accm, __m512i * scales) { + d = GGML_FP16_TO_FP32(x[i].d); + prepare(x[i].qs, x[i].qh); + iqxk.process(i, d, x[i].extra, make_scales(x[i].scales_h, x[i].scales_l), q8, accm, scales); + } + inline void prepare(const uint8_t * q2, const uint8_t * qh) { + bits.prepare(q2); + auto h256 = _mm256_loadu_si256((const __m256i *)qh); + auto hbits = _mm512_inserti32x8(_mm512_castsi256_si512(h256), _mm256_srli_epi16(h256, 1), 1); + bits.values[0] = _mm512_or_si512(bits.values[0], _mm512_and_si512(_mm512_slli_epi16(hbits, 2), hmask)); + bits.values[1] = _mm512_or_si512(bits.values[1], _mm512_and_si512(hbits, hmask)); + bits.values[2] = _mm512_or_si512(bits.values[2], _mm512_and_si512(_mm512_srli_epi16(hbits, 2), hmask)); + bits.values[3] = _mm512_or_si512(bits.values[3], _mm512_and_si512(_mm512_srli_epi16(hbits, 4), hmask)); + bits.values[0] = _mm512_shuffle_epi8(values, bits.values[0]); + bits.values[1] = _mm512_shuffle_epi8(values, bits.values[1]); + bits.values[2] = _mm512_shuffle_epi8(values, bits.values[2]); + bits.values[3] = _mm512_shuffle_epi8(values, bits.values[3]); + } + static inline __m512i load_values() { + static const uint8_t kvalues_iq3nl[16] = {1, 24, 41, 54, 65, 77, 92, 111, 5, 28, 45, 58, 69, 81, 96, 115}; + auto val128 = _mm_loadu_si128((const __m128i *)kvalues_iq3nl); + auto val256 = MM256_SET_M128I(val128, val128); + return _mm512_inserti32x8(_mm512_castsi256_si512(val256), val256, 1); + } + inline __m128i make_scales(uint16_t signs, const uint8_t * scales_l) const { + uint64_t aux64; std::memcpy(&aux64, scales_l, 8); + auto scl = _mm_and_si128(_mm_set_epi64x(aux64 >> 4, aux64), _mm_set1_epi8(0xf)); + scl = _mm_add_epi8(_mm_slli_epi16(scl, 1), m1); + const __m128i sc_signs = _mm_cmpeq_epi8(_mm_and_si128(_mm_set1_epi16(signs), sign_mask), sign_mask); + const __m128i sch = _mm_shuffle_epi8(_mm_or_si128(sc_signs, _mm_set1_epi8(1)), hshuff); + return _mm_sign_epi8(scl, sch); + } + Q2Bits bits; + const IQXKScales2 iqxk; + + const __m512i values; + const __m512i hmask = _mm512_set1_epi8(4); + const __m128i m1 = _mm_set1_epi8(1); + const __m128i sign_mask = _mm_set_epi64x(0x8080404020201010, 0x0808040402020101); + const __m128i hshuff = _mm_loadu_si128((const __m128i*)k_shuff); + constexpr static uint8_t k_shuff[16] = {0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15}; +}; + +struct DequantizerIQ4KSS final : public BaseDequantizer<block_iq4_kss, true> { + DequantizerIQ4KSS(const void * vx, size_t bx) : BaseDequantizer(vx, bx), values(load_iq4nl_values_512()) {} + template <typename Q8> + inline void new_block(int i, const Q8& q8, __m256 * accm, __m512i * scales) { + uint32_t aux32[2]; + auto b1 = _mm512_loadu_si512((const __m512i *)x[i].qs + 0); + auto b2 = _mm512_loadu_si512((const __m512i *)x[i].qs + 1); + auto bs1 = _mm512_and_si512(b1, mask15); + bs1 = _mm512_xor_si512(bs1, _mm512_srli_epi16(bs1, 1)); + auto bs2 = _mm512_and_si512(b2, mask15); + bs2 = _mm512_xor_si512(bs2, _mm512_srli_epi16(bs2, 1)); + bits.values[0] = _mm512_and_si512(bs1, bits.ml); + bits.values[1] = _mm512_and_si512(_mm512_srli_epi16(bs1, 4), bits.ml); + bits.values[2] = _mm512_and_si512(bs2, bits.ml); + bits.values[3] = _mm512_and_si512(_mm512_srli_epi16(bs2, 4), bits.ml); + auto tmp = _mm512_permutex2var_epi64(bits.values[0], permute1, bits.values[1]); + bits.values[1] = _mm512_shuffle_epi8(values, _mm512_permutex2var_epi64(bits.values[0], permute2, bits.values[1])); + bits.values[0] = _mm512_shuffle_epi8(values, tmp); + tmp = _mm512_permutex2var_epi64(bits.values[2], permute1, bits.values[3]); + bits.values[3] = _mm512_shuffle_epi8(values, _mm512_permutex2var_epi64(bits.values[2], permute2, bits.values[3])); + bits.values[2] = _mm512_shuffle_epi8(values, tmp); + // + // Now the more difficult part - prepare the scales + // + aux32[0] = _mm512_cmpeq_epi16_mask(_mm512_and_si512(b1, mask1), mask1); + aux32[1] = _mm512_cmpeq_epi16_mask(_mm512_and_si512(b2, mask1), mask1); + + auto scales128 = _mm_cvtepu8_epi16(_mm_loadl_epi64((const __m128i *)aux32)); + auto m1 = _mm512_castsi512_si128(mask1); + auto shifts = _mm_and_si128(_mm_cmpeq_epi16(_mm_and_si128(scales128, m1), m1), m4); + scales128 = _mm_add_epi16(_mm_and_si128(scales128, mask), m127); + auto scales_s = _mm_mullo_epi16(scales128, _mm_add_epi16(m128, shifts)); + s8k.accum_mins(scales_s, q8, i, d, accm); + auto scales256 = MM256_SET_M128I(scales128, scales128); + auto all_scales = _mm512_inserti32x8(_mm512_castsi256_si512(scales256), scales256, 1); + scales[0] = _mm512_shuffle_epi8(all_scales, shuffles[0]); + scales[1] = _mm512_shuffle_epi8(all_scales, shuffles[1]); + scales[2] = _mm512_shuffle_epi8(all_scales, shuffles[2]); + scales[3] = _mm512_shuffle_epi8(all_scales, shuffles[3]); + } + + Q4Bits bits; + Scales8KBase s8k; + const __m512i values; + const __m512i mask15 = _mm512_set1_epi16(-2); // value is 0xfffe, but to shut up the stupid compiler warning we use the signed value + const __m512i mask1 = _mm512_set1_epi16(1); + const __m512i permute1 = _mm512_set_epi64(11, 10, 3, 2, 9, 8, 1, 0); + const __m512i permute2 = _mm512_set_epi64(15, 14, 7, 6, 13, 12, 5, 4); + const __m128i mask = _mm_set1_epi16(254); + const __m128i m127 = _mm_set1_epi16(-127); + const __m128i m128 = _mm_set1_epi16(-128); + const __m128i m4 = _mm_set1_epi16(4); + const __m512i shuffles[4] = { + _mm512_inserti32x8(_mm512_set1_epi16(0x0100), _mm256_set1_epi16(0x0302), 1), + _mm512_inserti32x8(_mm512_set1_epi16(0x0504), _mm256_set1_epi16(0x0706), 1), + _mm512_inserti32x8(_mm512_set1_epi16(0x0908), _mm256_set1_epi16(0x0b0a), 1), + _mm512_inserti32x8(_mm512_set1_epi16(0x0d0c), _mm256_set1_epi16(0x0f0e), 1), + }; +}; + +struct DequantizerIQ4KS final : public BaseDequantizer<block_iq4_ks, true> { + DequantizerIQ4KS(const void * vx, size_t bx) : BaseDequantizer(vx, bx), values(load_iq4nl_values_512()) {} + template <typename Q8> + inline void new_block(int i, const Q8& q8, __m256 * accm, __m512i * scales) { + auto scales128 = _mm_cvtepu8_epi16(_mm_loadl_epi64((const __m128i *)x[i].scales)); + auto shifts = _mm_and_si128(_mm_cmpeq_epi16(_mm_and_si128(scales128, m1), m1), m4); + scales128 = _mm_add_epi16(_mm_and_si128(scales128, mask), m127); + auto scales_s = _mm_mullo_epi16(scales128, _mm_add_epi16(m128, shifts)); + s8k.accum_mins(scales_s, q8, i, d, accm); + auto scales256 = MM256_SET_M128I(scales128, scales128); + auto all_scales = _mm512_inserti32x8(_mm512_castsi256_si512(scales256), scales256, 1); + scales[0] = _mm512_shuffle_epi8(all_scales, shuffles[0]); + scales[1] = _mm512_shuffle_epi8(all_scales, shuffles[1]); + scales[2] = _mm512_shuffle_epi8(all_scales, shuffles[2]); + scales[3] = _mm512_shuffle_epi8(all_scales, shuffles[3]); + prepare(x[i].qs); + } + template <typename Q8> + inline void compute_block(int i, const Q8& q8, __m512 * acc) { + auto scales128 = _mm_cvtepu8_epi16(_mm_loadl_epi64((const __m128i *)x[i].scales)); + auto shifts = _mm_and_si128(_mm_cmpeq_epi16(_mm_and_si128(scales128, m1), m1), m4); + scales128 = _mm_add_epi16(_mm_and_si128(scales128, mask), m127); + auto mins128 = _mm_mullo_epi16(scales128, _mm_add_epi16(m128, shifts)); + auto mins = MM256_SET_M128I(_mm_shuffle_epi8(mins128, s8k.shuffles[1]), _mm_shuffle_epi8(mins128, s8k.shuffles[0])); + auto scales256 = MM256_SET_M128I(scales128, scales128); + auto all_scales = _mm512_inserti32x8(_mm512_castsi256_si512(scales256), scales256, 1); + __m512i scales[4]; + for (int k = 0; k < 4; ++k) scales[k] = _mm512_shuffle_epi8(all_scales, shuffles[k]); + prepare(x[i].qs); + for (int iy = 0; iy < Q8::nrc_y; ++iy) { + auto q8s = q8.load_bsums(iy, i); + auto prod = _mm256_madd_epi16(mins, q8s); + auto sumi = _mm512_inserti32x8(_mm512_setzero_si512(), prod, 0); + for (int k = 0; k < 4; ++k) { + auto p = _mm512_maddubs_epi16(bits.values[k], q8.load_quants64(iy, i, k)); + sumi = _mm512_dpwssd_epi32(sumi, p, scales[k]); + } + acc[iy] = _mm512_fmadd_ps(_mm512_set1_ps(d*q8.scale(iy, i)), _mm512_cvtepi32_ps(sumi), acc[iy]); + } + } + inline void prepare(const uint8_t * q4) { + bits.prepare64(q4); + // We now have in bits.valuse[0]: 0...15, 32...47, 64...79, 96...111 + // bits.valuse[1]: 16..31, 48...63, 80...95, 112..127 + // etc. + auto tmp = _mm512_permutex2var_epi64(bits.values[0], permute1, bits.values[1]); + bits.values[1] = _mm512_shuffle_epi8(values, _mm512_permutex2var_epi64(bits.values[0], permute2, bits.values[1])); + bits.values[0] = _mm512_shuffle_epi8(values, tmp); + tmp = _mm512_permutex2var_epi64(bits.values[2], permute1, bits.values[3]); + bits.values[3] = _mm512_shuffle_epi8(values, _mm512_permutex2var_epi64(bits.values[2], permute2, bits.values[3])); + bits.values[2] = _mm512_shuffle_epi8(values, tmp); + } + + Q4Bits bits; + Scales8KBase s8k; + const __m512i values; + const __m512i permute1 = _mm512_set_epi64(11, 10, 3, 2, 9, 8, 1, 0); + const __m512i permute2 = _mm512_set_epi64(15, 14, 7, 6, 13, 12, 5, 4); + const __m128i mask = _mm_set1_epi16(254); + const __m128i m127 = _mm_set1_epi16(-127); + const __m128i m128 = _mm_set1_epi16(-128); + const __m128i m1 = _mm_set1_epi16(1); + const __m128i m4 = _mm_set1_epi16(4); + const __m512i shuffles[4] = { + _mm512_inserti32x8(_mm512_set1_epi16(0x0100), _mm256_set1_epi16(0x0302), 1), + _mm512_inserti32x8(_mm512_set1_epi16(0x0504), _mm256_set1_epi16(0x0706), 1), + _mm512_inserti32x8(_mm512_set1_epi16(0x0908), _mm256_set1_epi16(0x0b0a), 1), + _mm512_inserti32x8(_mm512_set1_epi16(0x0d0c), _mm256_set1_epi16(0x0f0e), 1), + }; +}; + +struct DequantizerIQ4K final : public BaseDequantizer<block_iq4_k> { + DequantizerIQ4K(const void * vx, size_t bx) : BaseDequantizer(vx, bx), iqxk(4, -128), values(load_iq4nl_values_512()) {} + template <typename Q8> + inline void new_block(int i, const Q8& q8, __m256 * accm, __m512i * scales) { + d = GGML_FP16_TO_FP32(x[i].d); + prepare(x[i].qs); + iqxk.process(i, d, x[i].extra, make_scales(x[i].scales_l, (const uint16_t *)x[i].scales_h), q8, accm, scales); + } + inline void prepare(const uint8_t * q4) { + bits.prepare64(q4); + // We now have in bits.valuse[0]: 0...15, 32...47, 64...79, 96...111 + // bits.valuse[1]: 16..31, 48...63, 80...95, 112..127 + // etc. + auto tmp = _mm512_permutex2var_epi64(bits.values[0], permute1, bits.values[1]); + bits.values[1] = _mm512_shuffle_epi8(values, _mm512_permutex2var_epi64(bits.values[0], permute2, bits.values[1])); + bits.values[0] = _mm512_shuffle_epi8(values, tmp); + tmp = _mm512_permutex2var_epi64(bits.values[2], permute1, bits.values[3]); + bits.values[3] = _mm512_shuffle_epi8(values, _mm512_permutex2var_epi64(bits.values[2], permute2, bits.values[3])); + bits.values[2] = _mm512_shuffle_epi8(values, tmp); + } + __m128i make_scales(const uint8_t * scales_l, const uint16_t * scales_h) const { + uint64_t aux64; + memcpy(&aux64, scales_l, 8); + auto scl = _mm_and_si128(_mm_set_epi64x(aux64 >> 4, aux64), maskl); + const uint32_t aux32 = scales_h[0] | (scales_h[1] << 16); + auto aux = _mm_and_si128(_mm_set_epi32(aux32 >> 2, aux32, aux32 << 2, aux32 << 4), maskh); + auto sch = _mm_shuffle_epi8(aux, iqxk.scale_shuffle); + return _mm_add_epi8(_mm_or_si128(scl, sch), m32); + } + + Q4Bits bits; + const IQXKScales2 iqxk; + const __m512i values; + const __m512i permute1 = _mm512_set_epi64(11, 10, 3, 2, 9, 8, 1, 0); + const __m512i permute2 = _mm512_set_epi64(15, 14, 7, 6, 13, 12, 5, 4); + const __m128i maskl = _mm_set1_epi8(0xf); + const __m128i maskh = _mm_set1_epi8(0x30); + const __m128i m32 = _mm_set1_epi8(-32); +}; + +struct DequantizerIQ5KS final : public BaseDequantizer<block_iq5_ks, true> { + DequantizerIQ5KS(const void * vx, size_t bx) : BaseDequantizer(vx, bx) { load_values(values); } + template <typename Q8> + inline void new_block(int i, const Q8& q8, __m256 * accm, __m512i * scales) { + auto scales128 = _mm_cvtepu8_epi16(_mm_loadl_epi64((const __m128i *)x[i].scales)); + auto shifts = _mm_and_si128(_mm_cmpeq_epi16(_mm_and_si128(scales128, m1), m1), m2); + scales128 = _mm_add_epi16(_mm_and_si128(scales128, mask), m127); + auto scales_s = _mm_mullo_epi16(scales128, _mm_add_epi16(m128, shifts)); + s8k.accum_mins(scales_s, q8, i, d, accm); + auto scales256 = MM256_SET_M128I(scales128, scales128); + auto all_scales = _mm512_inserti32x8(_mm512_castsi256_si512(scales256), scales256, 1); + scales[0] = _mm512_shuffle_epi8(all_scales, shuffles[0]); + scales[1] = _mm512_shuffle_epi8(all_scales, shuffles[1]); + scales[2] = _mm512_shuffle_epi8(all_scales, shuffles[2]); + scales[3] = _mm512_shuffle_epi8(all_scales, shuffles[3]); + prepare(x[i].qs, x[i].qh); + } + template <typename Q8> + inline void compute_block(int i, const Q8& q8, __m512 * acc) { + auto scales128 = _mm_cvtepu8_epi16(_mm_loadl_epi64((const __m128i *)x[i].scales)); + auto shifts = _mm_and_si128(_mm_cmpeq_epi16(_mm_and_si128(scales128, m1), m1), m2); + scales128 = _mm_add_epi16(_mm_and_si128(scales128, mask), m127); + auto mins128 = _mm_mullo_epi16(scales128, _mm_add_epi16(m128, shifts)); + auto mins = MM256_SET_M128I(_mm_shuffle_epi8(mins128, s8k.shuffles[1]), _mm_shuffle_epi8(mins128, s8k.shuffles[0])); + auto scales256 = MM256_SET_M128I(scales128, scales128); + auto all_scales = _mm512_inserti32x8(_mm512_castsi256_si512(scales256), scales256, 1); + __m512i scales[4]; + for (int k = 0; k < 4; ++k) scales[k] = _mm512_shuffle_epi8(all_scales, shuffles[k]); + prepare(x[i].qs, x[i].qh); + for (int iy = 0; iy < Q8::nrc_y; ++iy) { + auto q8s = q8.load_bsums(iy, i); + auto prod = _mm256_madd_epi16(mins, q8s); + auto sumi = _mm512_inserti32x8(_mm512_setzero_si512(), prod, 0); + for (int k = 0; k < 4; ++k) { + auto p = _mm512_maddubs_epi16(bits.values[k], q8.load_quants64(iy, i, k)); + sumi = _mm512_dpwssd_epi32(sumi, p, scales[k]); + } + acc[iy] = _mm512_fmadd_ps(_mm512_set1_ps(d*q8.scale(iy, i)), _mm512_cvtepi32_ps(sumi), acc[iy]); + } + } + inline void prepare(const uint8_t * q4, const uint8_t * qh) { + bits.prepare64a(q4); + auto h256 = _mm256_loadu_si256((const __m256i *)qh); + auto hbits = _mm512_inserti32x8(_mm512_castsi256_si512(h256), _mm256_srli_epi16(h256, 1), 1); + auto m1 = _mm512_cmpeq_epi8_mask(_mm512_and_si512(hbits, hmask1), hmask1); + auto m2 = _mm512_cmpeq_epi8_mask(_mm512_and_si512(hbits, hmask2), hmask2); + bits.values[0] = _mm512_mask_shuffle_epi8(_mm512_maskz_shuffle_epi8(_knot_mask64(m1), values[0], bits.values[0]), m1, values[1], bits.values[0]); + bits.values[1] = _mm512_mask_shuffle_epi8(_mm512_maskz_shuffle_epi8(_knot_mask64(m2), values[0], bits.values[1]), m2, values[1], bits.values[1]); + hbits = _mm512_srli_epi16(hbits, 4); + m1 = _mm512_cmpeq_epi8_mask(_mm512_and_si512(hbits, hmask1), hmask1); + m2 = _mm512_cmpeq_epi8_mask(_mm512_and_si512(hbits, hmask2), hmask2); + bits.values[2] = _mm512_mask_shuffle_epi8(_mm512_maskz_shuffle_epi8(_knot_mask64(m1), values[0], bits.values[2]), m1, values[1], bits.values[2]); + bits.values[3] = _mm512_mask_shuffle_epi8(_mm512_maskz_shuffle_epi8(_knot_mask64(m2), values[0], bits.values[3]), m2, values[1], bits.values[3]); + } + static void load_values(__m512i * values) { + static const uint8_t kvalues_iq5nl[32] = { + 2, 14, 25, 36, 45, 54, 63, 71, 78, 85, 92, 98, 104, 110, 116, 122, 127, + 133, 139, 145, 151, 157, 164, 171, 179, 187, 196, 205, 215, 225, 237, 249, + }; + auto values128_1 = _mm_loadu_si128((const __m128i *)kvalues_iq5nl + 0); + auto values128_2 = _mm_loadu_si128((const __m128i *)kvalues_iq5nl + 1); + auto values256_1 = MM256_SET_M128I(values128_1, values128_1); + auto values256_2 = MM256_SET_M128I(values128_2, values128_2); + values[0] = _mm512_inserti32x8(_mm512_castsi256_si512(values256_1), values256_1, 1); + values[1] = _mm512_inserti32x8(_mm512_castsi256_si512(values256_2), values256_2, 1); + } + + Q4Bits bits; + Scales8KBase s8k; + __m512i values[2]; + const __m512i hmask1 = _mm512_set1_epi8(1); + const __m512i hmask2 = _mm512_set1_epi8(4); + const __m128i m127 = _mm_set1_epi16(-127); + const __m128i m128 = _mm_set1_epi16(-128); + const __m128i mask = _mm_set1_epi16(254); + const __m128i m1 = _mm_set1_epi16(1); + const __m128i m2 = _mm_set1_epi16(2); + const __m512i shuffles[4] = { + _mm512_inserti32x8(_mm512_set1_epi16(0x0100), _mm256_set1_epi16(0x0302), 1), + _mm512_inserti32x8(_mm512_set1_epi16(0x0504), _mm256_set1_epi16(0x0706), 1), + _mm512_inserti32x8(_mm512_set1_epi16(0x0908), _mm256_set1_epi16(0x0b0a), 1), + _mm512_inserti32x8(_mm512_set1_epi16(0x0d0c), _mm256_set1_epi16(0x0f0e), 1), + }; +}; + +struct DequantizerIQ5K final : public BaseDequantizer<block_iq5_k> { + DequantizerIQ5K(const void * vx, size_t bx) : BaseDequantizer(vx, bx), iqxk(2, -128) { load_values(values); } + template <typename Q8> + inline void new_block(int i, const Q8& q8, __m256 * accm, __m512i * scales) { + d = GGML_FP16_TO_FP32(x[i].d); + prepare(x[i].qs, x[i].qh); + iqxk.process(i, d, x[i].extra, make_scales(x[i].scales_l, (const uint16_t *)x[i].scales_h), q8, accm, scales); + } + inline void prepare(const uint8_t * q4, const uint8_t * qh) { + bits.prepare64(q4); + auto h256 = _mm256_loadu_si256((const __m256i *)qh); + auto hbits = _mm512_inserti32x8(_mm512_castsi256_si512(h256), _mm256_srli_epi16(h256, 2), 1); + auto m1 = _mm512_cmpeq_epi8_mask(_mm512_and_si512(hbits, hmask1), hmask1); + auto m2 = _mm512_cmpeq_epi8_mask(_mm512_and_si512(hbits, hmask2), hmask2); + bits.values[0] = _mm512_mask_shuffle_epi8(_mm512_maskz_shuffle_epi8(_knot_mask64(m1), values[0], bits.values[0]), m1, values[1], bits.values[0]); + bits.values[1] = _mm512_mask_shuffle_epi8(_mm512_maskz_shuffle_epi8(_knot_mask64(m2), values[0], bits.values[1]), m2, values[1], bits.values[1]); + hbits = _mm512_srli_epi16(hbits, 4); + m1 = _mm512_cmpeq_epi8_mask(_mm512_and_si512(hbits, hmask1), hmask1); + m2 = _mm512_cmpeq_epi8_mask(_mm512_and_si512(hbits, hmask2), hmask2); + bits.values[2] = _mm512_mask_shuffle_epi8(_mm512_maskz_shuffle_epi8(_knot_mask64(m1), values[0], bits.values[2]), m1, values[1], bits.values[2]); + bits.values[3] = _mm512_mask_shuffle_epi8(_mm512_maskz_shuffle_epi8(_knot_mask64(m2), values[0], bits.values[3]), m2, values[1], bits.values[3]); + // We now have in bits.valuse[0]: 0...31, 64...95 + // bits.valuse[1]: 32..63, 96..127 + // etc. + auto tmp = _mm512_permutex2var_epi64(bits.values[0], permute1, bits.values[1]); + bits.values[1] = _mm512_permutex2var_epi64(bits.values[0], permute2, bits.values[1]); + bits.values[0] = tmp; + tmp = _mm512_permutex2var_epi64(bits.values[2], permute1, bits.values[3]); + bits.values[3] = _mm512_permutex2var_epi64(bits.values[2], permute2, bits.values[3]); + bits.values[2] = tmp; + } + __m128i make_scales(const uint8_t * scales_l, const uint16_t * scales_h) const { + uint64_t aux64; + memcpy(&aux64, scales_l, 8); + auto scl = _mm_and_si128(_mm_set_epi64x(aux64 >> 4, aux64), maskl); + const uint32_t aux32 = scales_h[0] | (scales_h[1] << 16); + auto aux = _mm_and_si128(_mm_set_epi32(aux32 >> 2, aux32, aux32 << 2, aux32 << 4), maskh); + auto sch = _mm_shuffle_epi8(aux, iqxk.scale_shuffle); + return _mm_add_epi8(_mm_or_si128(scl, sch), m32); + } + static void load_values(__m512i * values) { + static const uint8_t kvalues_iq5nl[32] = { + 2, 14, 25, 36, 45, 54, 63, 71, 78, 85, 92, 98, 104, 110, 116, 122, 127, + 133, 139, 145, 151, 157, 164, 171, 179, 187, 196, 205, 215, 225, 237, 249, + }; + auto values128_1 = _mm_loadu_si128((const __m128i *)kvalues_iq5nl + 0); + auto values128_2 = _mm_loadu_si128((const __m128i *)kvalues_iq5nl + 1); + auto values256_1 = MM256_SET_M128I(values128_1, values128_1); + auto values256_2 = MM256_SET_M128I(values128_2, values128_2); + values[0] = _mm512_inserti32x8(_mm512_castsi256_si512(values256_1), values256_1, 1); + values[1] = _mm512_inserti32x8(_mm512_castsi256_si512(values256_2), values256_2, 1); + } + + Q4Bits bits; + const IQXKScales2 iqxk; + __m512i values[2]; + const __m512i hmask1 = _mm512_set1_epi8(1); + const __m512i hmask2 = _mm512_set1_epi8(2); + const __m512i permute1 = _mm512_set_epi64(11, 10, 9, 8, 3, 2, 1, 0); + const __m512i permute2 = _mm512_set_epi64(15, 14, 13, 12, 7, 6, 5, 4); + const __m128i maskl = _mm_set1_epi8(0xf); + const __m128i maskh = _mm_set1_epi8(0x30); + const __m128i m32 = _mm_set1_epi8(-32); +}; + +struct DequantizerIQ6K final : public BaseDequantizer<block_iq6_k> { + DequantizerIQ6K(const void * vx, size_t bx) : BaseDequantizer(vx, bx), iqxk(1, -128) { load_values(values); } + template <typename Q8> + inline void new_block(int i, const Q8& q8, __m256 * accm, __m512i * scales) { + d = GGML_FP16_TO_FP32(x[i].d); + prepare(x[i].qs, x[i].qh); + auto scales8 = _mm_loadu_si128((const __m128i*)x[i].scales); + iqxk.process(i, d, x[i].extra, _mm256_cvtepi8_epi16(scales8), q8, accm, scales); + } + inline __m512i make_one(__m512i l, __m512i h) const { + auto p = _mm512_shuffle_epi8(values[0], l); + p = _mm512_mask_shuffle_epi8(p, _mm512_cmpeq_epi8_mask(_mm512_and_si512(h, masks[0]), masks[0]), values[1], l); + p = _mm512_mask_shuffle_epi8(p, _mm512_cmpeq_epi8_mask(_mm512_and_si512(h, masks[1]), masks[1]), values[2], l); + p = _mm512_mask_shuffle_epi8(p, _mm512_cmpeq_epi8_mask(_mm512_and_si512(h, masks[2]), masks[2]), values[3], l); + return p; + } + inline void prepare(const uint8_t * q4, const uint8_t * qh) { + bits.prepare64(q4); + auto h256_1 = _mm256_loadu_si256((const __m256i *)qh + 0); + auto h256_2 = _mm256_loadu_si256((const __m256i *)qh + 1); + auto h1 = _mm512_inserti32x8(_mm512_castsi256_si512(h256_1), _mm256_srli_epi16(h256_1, 4), 1); + auto h2 = _mm512_inserti32x8(_mm512_castsi256_si512(h256_2), _mm256_srli_epi16(h256_2, 4), 1); + bits.values[0] = make_one(bits.values[0], h1); + bits.values[1] = make_one(bits.values[1], _mm512_srli_epi16(h1, 2)); + bits.values[2] = make_one(bits.values[2], h2); + bits.values[3] = make_one(bits.values[3], _mm512_srli_epi16(h2, 2)); + // We now have in bits.valuse[0]: 0...31, 64...95 + // bits.valuse[1]: 32..63, 96..127 + // etc. + auto tmp = _mm512_permutex2var_epi64(bits.values[0], permute1, bits.values[1]); + bits.values[1] = _mm512_permutex2var_epi64(bits.values[0], permute2, bits.values[1]); + bits.values[0] = tmp; + tmp = _mm512_permutex2var_epi64(bits.values[2], permute1, bits.values[3]); + bits.values[3] = _mm512_permutex2var_epi64(bits.values[2], permute2, bits.values[3]); + bits.values[2] = tmp; + } + static void load_values(__m512i * values) { + static const uint8_t kvalues_iq6nl[64] = { + 1, 7, 13, 19, 24, 30, 35, 40, 44, 49, 54, 58, 62, 66, 70, 74, + 77, 81, 84, 88, 91, 94, 97, 100, 103, 106, 109, 112, 115, 117, 120, 123, + 126, 128, 131, 134, 137, 140, 142, 145, 148, 151, 155, 158, 161, 164, 168, 172, + 175, 179, 183, 187, 191, 196, 200, 205, 210, 215, 220, 226, 231, 237, 243, 249, + }; + for (int k = 0; k < 4; ++k) { + auto values128 = _mm_loadu_si128((const __m128i *)kvalues_iq6nl + k); + auto values256 = MM256_SET_M128I(values128, values128); + values[k] = _mm512_inserti32x8(_mm512_castsi256_si512(values256), values256, 1); + } + } + + Q4Bits bits; + IQXKScales2 iqxk; + __m512i values[4]; + __m512i masks[3] = { _mm512_set1_epi8(0x01), _mm512_set1_epi8(0x02), _mm512_set1_epi8(0x03) }; + const __m512i permute1 = _mm512_set_epi64(11, 10, 9, 8, 3, 2, 1, 0); + const __m512i permute2 = _mm512_set_epi64(15, 14, 13, 12, 7, 6, 5, 4); +}; + +template <typename Dequantizer, int nrc_y> +static void mul_mat_iqX_k_q8_K_AVX512(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + assert(n % QK_K == 0); + const int nb = n / QK_K; + + Q8<nrc_y> q8(info); + + Dequantizer deq(vx, bx); + + __m256 accm[nrc_y]; + __m512 accd[nrc_y]; + __m512i scales[4]; + + for (int ix = 0; ix < nrc_x; ++ix) { + + for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm512_setzero_ps(); + for (int iy = 0; iy < nrc_y; ++iy) accm[iy] = _mm256_setzero_ps(); + + deq.new_row(ix); + + for (int i = 0; i < nb; ++i) { + + deq.new_block(i, q8, accm, scales); + + for (int iy = 0; iy < nrc_y; ++iy) { + const __m512i p1 = _mm512_maddubs_epi16(deq.bits.values[0], q8.load_quants64(iy, i, 0)); + const __m512i p2 = _mm512_maddubs_epi16(deq.bits.values[1], q8.load_quants64(iy, i, 1)); + const __m512i p3 = _mm512_maddubs_epi16(deq.bits.values[2], q8.load_quants64(iy, i, 2)); + const __m512i p4 = _mm512_maddubs_epi16(deq.bits.values[3], q8.load_quants64(iy, i, 3)); + auto sumi = _mm512_dpwssd_epi32(_mm512_dpwssd_epi32(_mm512_dpwssd_epi32(_mm512_dpwssd_epi32(_mm512_setzero_si512(), + p1, scales[0]), p2, scales[1]), p3, scales[2]), p4, scales[3]); + accd[iy] = _mm512_fmadd_ps(_mm512_set1_ps(deq.d*q8.scale(iy, i)), _mm512_cvtepi32_ps(sumi), accd[iy]); + } + + } + + for (int iy = 0; iy < nrc_y; ++iy) { + auto sum256 = _mm256_add_ps(_mm512_castps512_ps256(accd[iy]), _mm512_extractf32x8_ps(accd[iy], 1)); + info.store(ix, iy, hsum_float_8(_mm256_add_ps(accm[iy], sum256))); + } + + } +} + +template <typename Dequantizer, int nrc_y> +static void mul_mat_iqX_k_q8_K_AVX512_new(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + assert(n % QK_K == 0); + const int nb = n / QK_K; + + Q8<nrc_y> q8(info); + + Dequantizer deq(vx, bx); + + __m512 accd[nrc_y]; + + for (int ix = 0; ix < nrc_x; ++ix) { + + for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm512_setzero_ps(); + + deq.new_row(ix); + + for (int i = 0; i < nb; ++i) { + deq.compute_block(i, q8, accd); + } + + for (int iy = 0; iy < nrc_y; ++iy) { + info.store(ix, iy, _mm512_reduce_add_ps(accd[iy])); + } + + } +} + +template <typename Q8> +inline void compute_block(int iy, int i, float d, const Q8& q8, const __m512i * values, const __m512i * scales, __m512 * accd) { + const __m512i p1 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), values[0], q8.load_quants64(iy, i, 0)); + const __m512i p2 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), values[1], q8.load_quants64(iy, i, 1)); + const __m512i p3 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), values[2], q8.load_quants64(iy, i, 2)); + const __m512i p4 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), values[3], q8.load_quants64(iy, i, 3)); + auto sumi = _mm512_dpwssd_epi32(_mm512_setzero_si512(), scales[0], _mm512_packs_epi32(p1, p2)); + sumi = _mm512_dpwssd_epi32(sumi, scales[1], _mm512_packs_epi32(p3, p4)); + accd[iy] = _mm512_fmadd_ps(_mm512_set1_ps(d*q8.scale(iy, i)), _mm512_cvtepi32_ps(sumi), accd[iy]); +} + +template <typename Dequantizer> +static void mul_mat_qX_K_q8_K_AVX512_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + assert(n % QK_K == 0); + const int nb = n / QK_K; + + constexpr int k_nx = 2; + + Q8<1> q8(info); + + Dequantizer deq1(vx, bx); + Dequantizer deq2(vx, bx); + + Dequantizer * deq[k_nx]; + deq[0] = &deq1; + deq[1] = &deq2; + + __m512i scales[2*k_nx]; + + for (int ix = 0; ix < nrc_x; ++ix) { + + auto accd = _mm512_setzero_ps(); + auto accm = _mm256_setzero_ps(); + + for (int kx = 0; kx < k_nx; ++kx) deq[kx]->new_row(ix); + + for (int i = 0; i < nb/k_nx; ++i) { + + for (int kx = 0; kx < k_nx; ++kx) deq[kx]->new_block(k_nx*i+kx, q8, &accm, scales+2*kx); + + for (int kx = 0; kx < k_nx; ++kx) { + compute_block(0, k_nx*i+kx, deq[kx]->d, q8, deq[kx]->bits.values, scales+2*kx, &accd); + } + + } + if (2*(nb/2) < nb) { + int i0 = 2*(nb/2); + deq[0]->new_block(i0, q8, &accm, scales); + compute_block(0, i0, deq[0]->d, q8, deq[0]->bits.values, scales, &accd); + } + + auto sum256 = _mm256_add_ps(_mm512_castps512_ps256(accd), _mm512_extractf32x8_ps(accd, 1)); + info.store(ix, 0, hsum_float_8(_mm256_add_ps(accm, sum256))); + } +} + +template <typename Dequantizer, int nrc_y> +static void mul_mat_qX_K_q8_K_AVX512(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + assert(n % QK_K == 0); + const int nb = n / QK_K; + + Q8<nrc_y> q8(info); + + Dequantizer deq(vx, bx); + + __m256 accm[nrc_y]; + __m512 accd[nrc_y]; + __m512i scales[2]; + + for (int ix = 0; ix < nrc_x; ++ix) { + + for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm512_setzero_ps(); + for (int iy = 0; iy < nrc_y; ++iy) accm[iy] = _mm256_setzero_ps(); + + deq.new_row(ix); + + for (int i = 0; i < nb; ++i) { + + deq.new_block(i, q8, accm, scales); + + for (int iy = 0; iy < nrc_y; ++iy) { + const __m512i p1 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), deq.bits.values[0], q8.load_quants64(iy, i, 0)); + const __m512i p2 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), deq.bits.values[1], q8.load_quants64(iy, i, 1)); + const __m512i p3 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), deq.bits.values[2], q8.load_quants64(iy, i, 2)); + const __m512i p4 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), deq.bits.values[3], q8.load_quants64(iy, i, 3)); + auto sumi = _mm512_dpwssd_epi32(_mm512_setzero_si512(), scales[0], _mm512_packs_epi32(p1, p2)); + sumi = _mm512_dpwssd_epi32(sumi, scales[1], _mm512_packs_epi32(p3, p4)); + accd[iy] = _mm512_fmadd_ps(_mm512_set1_ps(deq.d*q8.scale(iy, i)), _mm512_cvtepi32_ps(sumi), accd[iy]); + } + + } + + for (int iy = 0; iy < nrc_y; ++iy) { + auto sum256 = _mm256_add_ps(_mm512_castps512_ps256(accd[iy]), _mm512_extractf32x8_ps(accd[iy], 1)); + info.store(ix, iy, hsum_float_8(_mm256_add_ps(accm[iy], sum256))); + } + + } +} + +#else + +inline void prepare_scales_16(const __m256i& all_scales, __m256i * scales) { + const __m128i l_scales = _mm256_extracti128_si256(all_scales, 0); + const __m128i h_scales = _mm256_extracti128_si256(all_scales, 1); + scales[0] = MM256_SET_M128I(l_scales, l_scales); + scales[1] = MM256_SET_M128I(h_scales, h_scales); +} + +struct IQXKScales { + IQXKScales(int8_t shift, int8_t min_val) : min(_mm256_set1_epi16(min_val)), eshift(_mm_set1_epi8(shift)) {} + template <typename Q8> + inline void process(int i, float d, uint16_t extra, __m128i scales8, const Q8& q8, __m256 * accm, __m256i * scales) const { + auto scales16 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales8, hshuff)); + process(i, d, extra, scales16, q8, accm, scales); + } + template <typename Q8> + inline void process(int i, float d, uint16_t extra, __m256i scales16, const Q8& q8, __m256 * accm, __m256i * scales) const { + auto extra128 = _mm_set1_epi16(extra); + extra128 = _mm_cmpeq_epi8(_mm_and_si128(extra128, emask), emask); + extra128 = _mm_and_si128(extra128, eshift); + extra128 = _mm_shuffle_epi8(extra128, eshuffle); + auto scales_s = _mm256_mullo_epi16(scales16, _mm256_add_epi16(min, _mm256_cvtepi8_epi16(extra128))); + for (int iy = 0; iy < Q8::nrc_y; ++iy) { + const __m256i prod = _mm256_madd_epi16(scales_s, q8.load_bsums(iy, i)); + accm[iy] = _mm256_fmadd_ps(_mm256_set1_ps(d * q8.scale(iy, i)), _mm256_cvtepi32_ps(prod), accm[iy]); + } + prepare_scales_16(scales16, scales); + } + + const __m256i min; + const __m128i eshift; + const __m128i hshuff = _mm_set_epi32(0x0f070e06, 0x0d050c04, 0x0b030a02, 0x09010800); + const __m128i emask = _mm_set_epi32(0x80804040, 0x20201010, 0x08080404, 0x02020101); + const __m128i eshuffle = _mm_set_epi32(0x0f0d0b09, 0x07050301, 0x0e0c0a08, 0x06040200); +}; + +struct DequantizerIQ2KS final : public BaseDequantizer<block_iq2_ks, true, true> { + DequantizerIQ2KS(const void * vx, size_t bx) : BaseDequantizer(vx, bx), values(load_values()) {} + template <typename Q8> + inline __m256i new_block(int i, const Q8& q8, __m256 * accm) { + auto scales128 = make_scales(x[i].scales, x[i].extra >> 8); + auto shifts = _mm_and_si128(_mm_cmpeq_epi8(_mm_and_si128(_mm_set1_epi8(x[i].extra), hmask), hmask), m5); + auto scales_s = _mm_mullo_epi16(scales128, _mm_cvtepi8_epi16(_mm_add_epi8(m32, shifts))); + s8k.accum_mins(scales_s, q8, i, d, accm); + return MM256_SET_M128I(scales128, scales128); + } + inline void prepare(int i, int j) { + bits.prepare(x[i].qs, j); + bits.values[0] = _mm256_shuffle_epi8(values, bits.values[0]); + bits.values[1] = _mm256_shuffle_epi8(values, bits.values[1]); + bits.values[2] = _mm256_shuffle_epi8(values, bits.values[2]); + bits.values[3] = _mm256_shuffle_epi8(values, bits.values[3]); + } + static inline __m256i load_values() { + static const uint8_t kvalues_iq2nl[16] = {1, 19, 33, 49, 0, 0, 0, 0, 6, 24, 38, 54, 0, 0, 0, 0}; + auto val128 = _mm_loadu_si128((const __m128i *)kvalues_iq2nl); + return MM256_SET_M128I(val128, val128); + } + inline __m128i make_scales(const uint8_t * scales_l, uint8_t scales_h) const { + const uint16_t * scales = (const uint16_t *)scales_l; + uint32_t aux32 = scales[0] | (uint32_t(scales[1]) << 16); + auto scl = _mm_srlv_epi32(_mm_set1_epi32(aux32), shift); + scl = _mm_and_si128(_mm_shuffle_epi8(scl, shuffle), _mm_set1_epi8(0xf)); + auto sch = _mm_set1_epi8(scales_h); + sch = _mm_and_si128(_mm_cmpeq_epi8(_mm_and_si128(sch, hmask), _mm_setzero_si128()), m16); + return _mm_cvtepi8_epi16(_mm_add_epi8(scl, sch)); + } + Q2Bits bits; + Scales8KBase s8k; + + const __m256i values; + const __m128i m16 = _mm_set1_epi8(-16); + const __m128i m5 = _mm_set1_epi8(5); + const __m128i m32 = _mm_set1_epi8(-32); + const __m128i hmask = _mm_set1_epi64x(0x8040201008040201); + const __m128i shuffle = _mm_set1_epi64x(0x0703060205010400); + const __m128i shift = _mm_set_epi32(0, 0, 4, 0); +}; + +struct DequantizerIQ2K final : public BaseDequantizer<block_iq2_k> { + DequantizerIQ2K(const void * vx, size_t bx) : BaseDequantizer(vx, bx), iqxk(5, -32), values(load_values()) {} + template <typename Q8> + inline void new_block(int i, const Q8& q8, __m256 * accm, __m256i * scales) { + d = GGML_FP16_TO_FP32(x[i].d); + iqxk.process(i, d, x[i].extra, make_scales(x[i].scales), q8, accm, scales); + } + inline void prepare(int i, int j) { + bits.prepare(x[i].qs, j); + bits.values[0] = _mm256_shuffle_epi8(values, bits.values[0]); + bits.values[1] = _mm256_shuffle_epi8(values, bits.values[1]); + bits.values[2] = _mm256_shuffle_epi8(values, bits.values[2]); + bits.values[3] = _mm256_shuffle_epi8(values, bits.values[3]); + } + static inline __m256i load_values() { + static const uint8_t kvalues_iq2nl[16] = {1, 19, 33, 49, 0, 0, 0, 0, 6, 24, 38, 54, 0, 0, 0, 0}; + auto val128 = _mm_loadu_si128((const __m128i *)kvalues_iq2nl); + return MM256_SET_M128I(val128, val128); + } + inline __m128i make_scales(const uint8_t * scales_l) const { + uint64_t aux64; std::memcpy(&aux64, scales_l, 8); + auto scl = _mm_and_si128(_mm_set_epi64x(aux64 >> 4, aux64), maskl); + return _mm_add_epi8(scl, m8); + } + + Q2Bits bits; + const IQXKScales iqxk; + const __m256i values; + const __m128i m8 = _mm_set1_epi8(-8); + const __m128i maskl = _mm_set1_epi8(0xf); +}; + +struct DequantizerIQ3K final : public BaseDequantizer<block_iq3_k> { + DequantizerIQ3K(const void * vx, size_t bx) : BaseDequantizer(vx, bx), iqxk(4, -64), values(load_values()) {} + template <typename Q8> + inline void new_block(int i, const Q8& q8, __m256 * accm, __m256i * scales) { + d = GGML_FP16_TO_FP32(x[i].d); + iqxk.process(i, d, x[i].extra, make_scales(x[i].scales_h, x[i].scales_l), q8, accm, scales); + hbits = _mm256_loadu_si256((const __m256i *)x[i].qh); + } + inline void prepare(int i, int j) { + bits.prepare(x[i].qs, j); + auto h256 = j == 0 ? hbits : _mm256_srli_epi16(hbits, 4); + bits.values[0] = _mm256_or_si256(bits.values[0], _mm256_and_si256(_mm256_slli_epi16(h256, 2), hmask)); + bits.values[1] = _mm256_or_si256(bits.values[1], _mm256_and_si256(_mm256_slli_epi16(h256, 1), hmask)); + bits.values[2] = _mm256_or_si256(bits.values[2], _mm256_and_si256(h256, hmask)); + bits.values[3] = _mm256_or_si256(bits.values[3], _mm256_and_si256(_mm256_srli_epi16(h256, 1), hmask)); + bits.values[0] = _mm256_shuffle_epi8(values, bits.values[0]); + bits.values[1] = _mm256_shuffle_epi8(values, bits.values[1]); + bits.values[2] = _mm256_shuffle_epi8(values, bits.values[2]); + bits.values[3] = _mm256_shuffle_epi8(values, bits.values[3]); + } + static inline __m256i load_values() { + static const uint8_t kvalues_iq3nl[16] = {1, 24, 41, 54, 65, 77, 92, 111, 5, 28, 45, 58, 69, 81, 96, 115}; + auto val128 = _mm_loadu_si128((const __m128i *)kvalues_iq3nl); + return MM256_SET_M128I(val128, val128); + } + inline __m128i make_scales(uint16_t signs, const uint8_t * scales_l) const { + uint64_t aux64; std::memcpy(&aux64, scales_l, 8); + auto scl = _mm_and_si128(_mm_set_epi64x(aux64 >> 4, aux64), _mm_set1_epi8(0xf)); + scl = _mm_add_epi8(_mm_slli_epi16(scl, 1), m1); + const __m128i sc_signs = _mm_cmpeq_epi8(_mm_and_si128(_mm_set1_epi16(signs), sign_mask), sign_mask); + const __m128i sch = _mm_shuffle_epi8(_mm_or_si128(sc_signs, _mm_set1_epi8(1)), hshuff); + return _mm_sign_epi8(scl, sch); + } + + Q2Bits bits; + const IQXKScales iqxk; + const __m256i values; + __m256i hbits; + const __m256i hmask = _mm256_set1_epi8(4); + const __m128i m1 = _mm_set1_epi8(1); + const __m128i sign_mask = _mm_set_epi64x(0x8080404020201010, 0x0808040402020101); + const __m128i hshuff = _mm_loadu_si128((const __m128i*)k_shuff); + constexpr static uint8_t k_shuff[16] = {0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15}; +}; + +struct DequantizerIQ4KSS final : public BaseDequantizer<block_iq4_kss, true> { + DequantizerIQ4KSS(const void * vx, size_t bx) : BaseDequantizer(vx, bx), values(load_iq4nl_values_256()) {} + template <typename Q8> + inline __m256i new_block(int i, const Q8& q8, __m256 * accd) { + union { __m256i vec; uint16_t val[16]; } helper; + for (int k = 0; k < 4; ++k) { + data[k] = _mm256_loadu_si256((const __m256i *)x[i].qs + k); + auto p = _mm256_and_si256(_mm256_cmpeq_epi16(_mm256_and_si256(data[k], m1), m1), smask); + p = _mm256_add_epi32(_mm256_unpackhi_epi64(p, p), p); + p = _mm256_add_epi32(_mm256_shuffle_epi32(p, _MM_SHUFFLE(2, 3, 0, 1)), p); + helper.vec = _mm256_hadd_epi16(p, p); + aux[2*k+0] = helper.val[0]; + aux[2*k+1] = helper.val[8]; + data[k] = _mm256_and_si256(data[k], bmask); + data[k] = _mm256_xor_si256(data[k], _mm256_srli_epi16(data[k], 1)); + } + auto scales128 = _mm_loadu_si128((const __m128i *)aux); + auto shifts = _mm_and_si128(_mm_cmpeq_epi16(_mm_and_si128(scales128, _mm256_castsi256_si128(m1)), _mm256_castsi256_si128(m1)), m4); + scales128 = _mm_add_epi16(_mm_and_si128(scales128, mask), m127); + auto scales_s = _mm_mullo_epi16(scales128, _mm_add_epi16(m128, shifts)); + s8k.accum_mins(scales_s, q8, i, d, accd); + return MM256_SET_M128I(scales128, scales128); + } + inline void prepare(int, int j) { + for (int k = 0; k < 2; ++k) { + auto p1 = _mm256_castsi256_si128(data[2*j+k]); + auto p2 = _mm256_extractf128_si256(data[2*j+k], 1); + bits.values[2*k+0] = _mm256_and_si256(MM256_SET_M128I(_mm_srli_epi16(p1, 4), p1), bits.ml); + bits.values[2*k+0] = _mm256_shuffle_epi8(values, bits.values[2*k+0]); + bits.values[2*k+1] = _mm256_and_si256(MM256_SET_M128I(_mm_srli_epi16(p2, 4), p2), bits.ml); + bits.values[2*k+1] = _mm256_shuffle_epi8(values, bits.values[2*k+1]); + } + } + + Q4Bits bits; + Scales8KBase s8k; + const __m256i values; + __m256i data[4]; + const __m256i smask = _mm256_set_epi64x(0x0080004000200010, 0x0008000400020001, 0x0080004000200010, 0x0008000400020001); + const __m256i bmask = _mm256_set1_epi16(-2); // 0xfffe; + const __m128i mask = _mm_set1_epi16(254); + const __m128i m127 = _mm_set1_epi16(-127); + const __m128i m128 = _mm_set1_epi16(-128); + const __m256i m1 = _mm256_set1_epi16(1); + const __m128i m4 = _mm_set1_epi16(4); + uint16_t aux[8]; +}; + +struct DequantizerIQ4KS final : public BaseDequantizer<block_iq4_ks, true> { + DequantizerIQ4KS(const void * vx, size_t bx) : BaseDequantizer(vx, bx) { load_values(); } + template <typename Q8> + inline __m256i new_block(int i, [[maybe_unused]] const Q8& q8, [[maybe_unused]] __m256 * accd) { + auto scales128 = _mm_cvtepu8_epi16(_mm_loadl_epi64((const __m128i *)x[i].scales)); + scales128 = _mm_add_epi16(_mm_and_si128(scales128, mask), m127); + return MM256_SET_M128I(scales128, scales128); + } + inline void prepare(int i, int j) { + bits.prepare16(x[i].qs, j); + bits.values[0] = _mm256_shuffle_epi8(values[x[i].scales[4*j+0] & 1], bits.values[0]); + bits.values[1] = _mm256_shuffle_epi8(values[x[i].scales[4*j+1] & 1], bits.values[1]); + bits.values[2] = _mm256_shuffle_epi8(values[x[i].scales[4*j+2] & 1], bits.values[2]); + bits.values[3] = _mm256_shuffle_epi8(values[x[i].scales[4*j+3] & 1], bits.values[3]); + } + void load_values() { + auto v1 = _mm_loadu_si128((const __m128i *)iq4k_values+0); + auto v2 = _mm_loadu_si128((const __m128i *)iq4k_values+1); + values[0] = MM256_SET_M128I(v1, v1); + values[1] = MM256_SET_M128I(v2, v2); + } + + + Q4Bits bits; + __m256i values[2]; + const __m128i mask = _mm_set1_epi16(254); + const __m128i m127 = _mm_set1_epi16(-127); +}; + +struct DequantizerIQ4K final : public BaseDequantizer<block_iq4_k> { + DequantizerIQ4K(const void * vx, size_t bx) : BaseDequantizer(vx, bx) { load_values(); } + template <typename Q8> + inline void new_block(int i, [[maybe_unused]] const Q8& q8, [[maybe_unused]] __m256 * accm, __m256i * scales) { + d = GGML_FP16_TO_FP32(x[i].d); + auto scales8 = make_scales(x[i].scales_l, (const uint16_t *)x[i].scales_h); + auto scales16 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales8, hshuff)); + prepare_scales_16(scales16, scales); + } + inline void prepare(int i, int j) { + bits.prepare16(x[i].qs, j); + auto extra = x[i].extra >> 8*j; + bits.values[0] = _mm256_shuffle_epi8(values[extra & 3], bits.values[0]); extra >>= 2; + bits.values[1] = _mm256_shuffle_epi8(values[extra & 3], bits.values[1]); extra >>= 2; + bits.values[2] = _mm256_shuffle_epi8(values[extra & 3], bits.values[2]); extra >>= 2; + bits.values[3] = _mm256_shuffle_epi8(values[extra & 3], bits.values[3]); + } + __m128i make_scales(const uint8_t * scales_l, const uint16_t * scales_h) const { + uint64_t aux64; + memcpy(&aux64, scales_l, 8); + auto scl = _mm_and_si128(_mm_set_epi64x(aux64 >> 4, aux64), maskl); + const uint32_t aux32 = scales_h[0] | (scales_h[1] << 16); + auto aux = _mm_and_si128(_mm_set_epi32(aux32 >> 2, aux32, aux32 << 2, aux32 << 4), maskh); + auto sch = _mm_shuffle_epi8(aux, hshuff); + return _mm_add_epi8(_mm_or_si128(scl, sch), m32); + } + void load_values() { + auto v1 = _mm_loadu_si128((const __m128i *)iq4k_values+0); + auto v2 = _mm_loadu_si128((const __m128i *)iq4k_values+1); + values[0] = MM256_SET_M128I(v1, v1); + values[1] = MM256_SET_M128I(v1, v2); + values[2] = MM256_SET_M128I(v2, v1); + values[3] = MM256_SET_M128I(v2, v2); + } + + Q4Bits bits; + const __m128i maskl = _mm_set1_epi8(0xf); + const __m128i maskh = _mm_set1_epi8(0x30); + const __m128i m32 = _mm_set1_epi8(-32); + const __m128i hshuff = _mm_set_epi32(0x0f070e06, 0x0d050c04, 0x0b030a02, 0x09010800); + __m256i values[4]; +}; + +struct DequantizerIQ5KS final : public BaseDequantizer<block_iq5_ks, true> { + DequantizerIQ5KS(const void * vx, size_t bx) : BaseDequantizer(vx, bx) { load_values(values); } + template <typename Q8> + inline __m256i new_block(int i, const Q8& q8, __m256 * accd) { + hbits = _mm256_loadu_si256((const __m256i *)x[i].qh); + auto scales128 = _mm_cvtepu8_epi16(_mm_loadl_epi64((const __m128i *)x[i].scales)); + auto shifts = _mm_and_si128(_mm_cmpeq_epi16(_mm_and_si128(scales128, m1), m1), m2); + scales128 = _mm_add_epi16(_mm_and_si128(scales128, mask), m127); + auto scales_s = _mm_mullo_epi16(scales128, _mm_add_epi16(m128, shifts)); + s8k.accum_mins(scales_s, q8, i, d, accd); + return MM256_SET_M128I(scales128, scales128); + } + inline void prepare(int i, int j) { + bits.prepare(x[i].qs, j); + auto h = j == 0 ? hbits : _mm256_srli_epi16(hbits, 4); + for (int k = 0; k < 4; ++k) { + auto qh = _mm256_and_si256(_mm256_slli_epi16(h, 7-k), mh); + auto q5vl = _mm256_or_si256(bits.values[k], qh); + auto q5vh = _mm256_or_si256(bits.values[k], _mm256_xor_si256(qh, mh)); + bits.values[k] = _mm256_or_si256(_mm256_shuffle_epi8(values[0], q5vl), _mm256_shuffle_epi8(values[1], q5vh)); + } + } + static void load_values(__m256i * values) { + static const uint8_t kvalues_iq5nl[32] = { + 2, 14, 25, 36, 45, 54, 63, 71, 78, 85, 92, 98, 104, 110, 116, 122, 127, + 133, 139, 145, 151, 157, 164, 171, 179, 187, 196, 205, 215, 225, 237, 249, + }; + auto values128_1 = _mm_loadu_si128((const __m128i *)kvalues_iq5nl + 0); + auto values128_2 = _mm_loadu_si128((const __m128i *)kvalues_iq5nl + 1); + values[0] = MM256_SET_M128I(values128_1, values128_1); + values[1] = MM256_SET_M128I(values128_2, values128_2); + } + + Q4Bits bits; + Scales8KBase s8k; + __m256i hbits; + __m256i values[2]; + const __m128i maskl = _mm_set1_epi8(0xf); + const __m128i maskh = _mm_set1_epi8(0x30); + const __m256i mh = _mm256_set1_epi8(-128); // to avoid stupid warning about 0x80 overflowing + const __m128i mask = _mm_set1_epi16(254); + const __m128i m127 = _mm_set1_epi16(-127); + const __m128i m128 = _mm_set1_epi16(-128); + const __m128i m1 = _mm_set1_epi16(1); + const __m128i m2 = _mm_set1_epi16(2); +}; + +struct DequantizerIQ5K final : public BaseDequantizer<block_iq5_k> { + DequantizerIQ5K(const void * vx, size_t bx) : BaseDequantizer(vx, bx), iqxk(2, 0) { load_values(values); } + template <typename Q8> + inline void new_block(int i, const Q8& q8, __m256 * accm, __m256i * scales) { + d = GGML_FP16_TO_FP32(x[i].d); + iqxk.process(i, d, x[i].extra, make_scales(x[i].scales_l, (const uint16_t *)x[i].scales_h), q8, accm, scales); + hbits = _mm256_loadu_si256((const __m256i *)x[i].qh); + } + inline void prepare(int i, int j) { + bits.prepare(x[i].qs, j); + auto h = j == 0 ? hbits : _mm256_srli_epi16(hbits, 4); + for (int k = 0; k < 4; ++k) { + auto qh = _mm256_and_si256(_mm256_slli_epi16(h, 7-k), mh); + auto q5vl = _mm256_or_si256(bits.values[k], qh); + auto q5vh = _mm256_or_si256(bits.values[k], _mm256_xor_si256(qh, mh)); + bits.values[k] = _mm256_or_si256(_mm256_shuffle_epi8(values[0], q5vl), _mm256_shuffle_epi8(values[1], q5vh)); + } + } + __m128i make_scales(const uint8_t * scales_l, const uint16_t * scales_h) const { + uint64_t aux64; + memcpy(&aux64, scales_l, 8); + auto scl = _mm_and_si128(_mm_set_epi64x(aux64 >> 4, aux64), maskl); + const uint32_t aux32 = scales_h[0] | (scales_h[1] << 16); + auto aux = _mm_and_si128(_mm_set_epi32(aux32 >> 2, aux32, aux32 << 2, aux32 << 4), maskh); + auto sch = _mm_shuffle_epi8(aux, iqxk.hshuff); + return _mm_add_epi8(_mm_or_si128(scl, sch), m32); + } + static void load_values(__m256i * values) { + auto values128_1 = _mm_loadu_si128((const __m128i *)iq5nl_values + 0); + auto values128_2 = _mm_loadu_si128((const __m128i *)iq5nl_values + 1); + values[0] = MM256_SET_M128I(values128_1, values128_1); + values[1] = MM256_SET_M128I(values128_2, values128_2); + } + + Q4Bits bits; + const IQXKScales iqxk; + __m256i hbits; + __m256i values[2]; + const __m128i maskl = _mm_set1_epi8(0xf); + const __m128i maskh = _mm_set1_epi8(0x30); + const __m128i m32 = _mm_set1_epi8(-32); + const __m256i mh = _mm256_set1_epi8(-128); // to avoid stupid warning about 0x80 overflowing +}; + +struct DequantizerIQ6K final : public BaseDequantizer<block_iq6_k> { + DequantizerIQ6K(const void * vx, size_t bx) : BaseDequantizer(vx, bx), iqxk(1, 0) { load_values(values); } + template <typename Q8> + inline void new_block(int i, const Q8& q8, __m256 * accm, __m256i * scales) { + d = GGML_FP16_TO_FP32(x[i].d); + auto scales8 = _mm_loadu_si128((const __m128i*)x[i].scales); + auto scales16 = _mm256_cvtepi8_epi16(scales8); + iqxk.process(i, d, x[i].extra, scales16, q8, accm, scales); + } + inline void prepare(int i, int j) { + bits.prepare(x[i].qs, j); + auto hbits = _mm256_loadu_si256((const __m256i *)x[i].qh + j); + for (int k = 0; k < 4; ++k) { + bits.values[k] = make_one(bits.values[k], hbits); + hbits = _mm256_srli_epi16(hbits, 2); + } + } + inline __m256i make_one(__m256i l, __m256i hbits) const { + auto mask4 = _mm256_cmpeq_epi8(_mm256_and_si256(hbits, mh3), mh3); + auto h1 = _mm256_andnot_si256(mask4, hbits); + auto mask2 = _mm256_cmpeq_epi8(_mm256_and_si256(h1, mh1), mh1); + auto mask3 = _mm256_cmpeq_epi8(_mm256_and_si256(h1, mh2), mh2); + auto mask1 = _mm256_andnot_si256(_mm256_or_si256(mask4, _mm256_or_si256(mask2, mask3)), _mm256_set1_epi8(-1)); // 0xff; + return _mm256_or_si256(_mm256_or_si256(_mm256_and_si256(mask1, _mm256_shuffle_epi8(values[0], l)), + _mm256_and_si256(mask2, _mm256_shuffle_epi8(values[1], l))), + _mm256_or_si256(_mm256_and_si256(mask3, _mm256_shuffle_epi8(values[2], l)), + _mm256_and_si256(mask4, _mm256_shuffle_epi8(values[3], l)))); + } + static void load_values(__m256i * values) { + for (int k = 0; k < 4; ++k) { + auto values128 = _mm_loadu_si128((const __m128i *)iq6nl_values + k); + values[k] = MM256_SET_M128I(values128, values128); + } + } + + Q4Bits bits; + const IQXKScales iqxk; + __m256i values[4]; + const __m256i mh1 = _mm256_set1_epi8(1); + const __m256i mh2 = _mm256_set1_epi8(2); + const __m256i mh3 = _mm256_set1_epi8(3); + const __m256i mh = _mm256_set1_epi8(-128); // to avoid stupid warning about 0x80 overflowing +}; + +inline __m256i get_scale_shuffle_16(int i) { + static const uint8_t k_shuffle[128] = { + 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, + 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, + 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11, + 12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13, 14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15, + }; + return _mm256_loadu_si256((const __m256i*)k_shuffle + i); +} + +inline void set_scales_16(const __m256i& all_scales, __m256i * scales) { + scales[0] = _mm256_shuffle_epi8(all_scales, get_scale_shuffle_16(0)); + scales[1] = _mm256_shuffle_epi8(all_scales, get_scale_shuffle_16(1)); + scales[2] = _mm256_shuffle_epi8(all_scales, get_scale_shuffle_16(2)); + scales[3] = _mm256_shuffle_epi8(all_scales, get_scale_shuffle_16(3)); +} + +inline __m256i get_scale_shuffle_8(int i) { + return _mm256_set1_epi16((2*i) | ((2*i+1) << 8)); +} + +inline void set_scales_8(const __m256i& all_scales, int j, __m256i * scales) { + scales[0] = _mm256_shuffle_epi8(all_scales, get_scale_shuffle_8(4*j+0)); + scales[1] = _mm256_shuffle_epi8(all_scales, get_scale_shuffle_8(4*j+1)); + scales[2] = _mm256_shuffle_epi8(all_scales, get_scale_shuffle_8(4*j+2)); + scales[3] = _mm256_shuffle_epi8(all_scales, get_scale_shuffle_8(4*j+3)); +} + +template <typename Dequantizer, int nrc_y> +static void mul_mat_qY_K_q8_K_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + assert(n%QK_K == 0); + const int nb = n/QK_K; + + Q8<nrc_y> q8(info); + + __m256i all_scales[2]; + __m256i scales[4]; + __m256 accd[nrc_y]; + + Dequantizer deq(vx, bx); + + for (int ix = 0; ix < nrc_x; ++ix) { + + deq.new_row(ix); + + for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm256_setzero_ps(); + + for (int i = 0; i < nb; ++i) { + + deq.new_block(i, q8, accd, all_scales); + + __m256i sumi[nrc_y]; + + for (int j = 0; j < QK_K/128; ++j) { + deq.prepare(i, j); + set_scales_16(all_scales[j], scales); + if constexpr (std::is_same_v<Dequantizer, DequantizerIQ4K> || + std::is_same_v<Dequantizer, DequantizerIQ5K> || + std::is_same_v<Dequantizer, DequantizerIQ6K>) { + multiply_add_avx2(deq.bits, scales, j, i, q8, sumi); + } else { + multiply_add(deq.bits, scales, j, i, q8, sumi); + } + } + + for (int iy = 0; iy < nrc_y; ++iy) { + accd[iy] = _mm256_fmadd_ps(_mm256_set1_ps(deq.d*q8.scale(iy, i)), _mm256_cvtepi32_ps(sumi[iy]), accd[iy]); + } + + } + + for (int iy = 0; iy < nrc_y; ++iy) { + info.store(ix, iy, hsum_float_8(accd[iy])); + } + + } + +} + +template <typename Dequantizer, int nrc_y> +static void mul_mat_qX_K_q8_K_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + assert(n % QK_K == 0); + const int nb = n / QK_K; + + Q8<nrc_y> q8(info); + + Dequantizer deq(vx, bx); + + __m256 accd[nrc_y]; + __m256i scales[4]; + + for (int ix = 0; ix < nrc_x; ++ix) { + + for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm256_setzero_ps(); + + deq.new_row(ix); + + for (int i = 0; i < nb; ++i) { + + auto all_scales = deq.new_block(i, q8, accd); + + __m256i sumi[nrc_y]; + + for (int j = 0; j < QK_K/128; ++j) { + + deq.prepare(i, j); + + set_scales_8(all_scales, j, scales); + + if constexpr (std::is_same_v<Dequantizer, DequantizerIQ4KS>) { + multiply_add_avx2(deq.bits, scales, j, i, q8, sumi); + } else { + multiply_add(deq.bits, scales, j, i, q8, sumi); + } + + } + + for (int iy = 0; iy < nrc_y; ++iy) { + const __m256 vd = _mm256_set1_ps(deq.d*q8.scale(iy, i)); + accd[iy] = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(sumi[iy]), accd[iy]); + } + + } + + for (int iy = 0; iy < nrc_y; ++iy) { + info.store(ix, iy, hsum_float_8(accd[iy])); + } + + } +} + +#endif + +template <int nrc_y> +//IQK_ALWAYS_INLINE void iq234_k_accum_mins(int ibl, __m256i i8scales1, __m256i i8scales2, const Q8<nrc_y, block_q8_K>& q8, __m256i shuff, +inline void iq234_k_accum_mins(int ibl, __m256i i8scales1, __m256i i8scales2, const Q8<nrc_y, block_q8_K>& q8, __m256i shuff, + __m256i * isum, int16_t min) { + auto t1 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(i8scales1, 0)), shuff); // blocks 0, 1, 2, 3 for each row + auto t2 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(i8scales1, 1)), shuff); // blocks 4, 5, 6, 7 for each row + auto t3 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(i8scales2, 0)), shuff); // blocks 8, 9, 10, 11 for each row + auto t4 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(i8scales2, 1)), shuff); // blocks 12, 13, 14, 15 for each row + if constexpr (nrc_y == 1) { + auto s1 = MM256_SET_M128I(_mm256_extracti128_si256(t3, 0), _mm256_extracti128_si256(t1, 0)); // blocks 0, 1, 8, 9 + auto s2 = MM256_SET_M128I(_mm256_extracti128_si256(t3, 1), _mm256_extracti128_si256(t1, 1)); // blocks 2, 3, 10, 11 + auto s3 = MM256_SET_M128I(_mm256_extracti128_si256(t4, 0), _mm256_extracti128_si256(t2, 0)); // blocks 4, 5, 12, 13 + auto s4 = MM256_SET_M128I(_mm256_extracti128_si256(t4, 1), _mm256_extracti128_si256(t2, 1)); // blocks 6, 7, 14, 15 + auto sumi = _mm256_setzero_si256(); + auto bsums = q8.load_bsums(0, ibl); +#ifdef HAVE_FANCY_SIMD + sumi = _mm256_dpwssd_epi32(sumi, s1, _mm256_shuffle_epi32(bsums, 0x00)); + sumi = _mm256_dpwssd_epi32(sumi, s2, _mm256_shuffle_epi32(bsums, 0x55)); + sumi = _mm256_dpwssd_epi32(sumi, s3, _mm256_shuffle_epi32(bsums, 0xaa)); + sumi = _mm256_dpwssd_epi32(sumi, s4, _mm256_shuffle_epi32(bsums, 0xff)); +#else + sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(s1, _mm256_shuffle_epi32(bsums, 0x00))); + sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(s2, _mm256_shuffle_epi32(bsums, 0x55))); + sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(s3, _mm256_shuffle_epi32(bsums, 0xaa))); + sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(s4, _mm256_shuffle_epi32(bsums, 0xff))); +#endif + isum[0] = _mm256_mullo_epi32(sumi, _mm256_set1_epi32(min)); + + } else { + auto s1 = _mm256_mullo_epi16(_mm256_set1_epi16(min), MM256_SET_M128I(_mm256_extracti128_si256(t3, 0), _mm256_extracti128_si256(t1, 0))); // blocks 0, 1, 8, 9 + auto s2 = _mm256_mullo_epi16(_mm256_set1_epi16(min), MM256_SET_M128I(_mm256_extracti128_si256(t3, 1), _mm256_extracti128_si256(t1, 1))); // blocks 2, 3, 10, 11 + auto s3 = _mm256_mullo_epi16(_mm256_set1_epi16(min), MM256_SET_M128I(_mm256_extracti128_si256(t4, 0), _mm256_extracti128_si256(t2, 0))); // blocks 4, 5, 12, 13 + auto s4 = _mm256_mullo_epi16(_mm256_set1_epi16(min), MM256_SET_M128I(_mm256_extracti128_si256(t4, 1), _mm256_extracti128_si256(t2, 1))); // blocks 6, 7, 14, 15 + for (int iy = 0; iy < nrc_y; ++iy) { + auto bsums = q8.load_bsums(iy, ibl); +#ifdef HAVE_FANCY_SIMD + isum[iy] = _mm256_dpwssd_epi32(isum[iy], s1, _mm256_shuffle_epi32(bsums, 0x00)); + isum[iy] = _mm256_dpwssd_epi32(isum[iy], s2, _mm256_shuffle_epi32(bsums, 0x55)); + isum[iy] = _mm256_dpwssd_epi32(isum[iy], s3, _mm256_shuffle_epi32(bsums, 0xaa)); + isum[iy] = _mm256_dpwssd_epi32(isum[iy], s4, _mm256_shuffle_epi32(bsums, 0xff)); +#else + isum[iy] = _mm256_add_epi32(isum[iy], _mm256_madd_epi16(s1, _mm256_shuffle_epi32(bsums, 0x00))); + isum[iy] = _mm256_add_epi32(isum[iy], _mm256_madd_epi16(s2, _mm256_shuffle_epi32(bsums, 0x55))); + isum[iy] = _mm256_add_epi32(isum[iy], _mm256_madd_epi16(s3, _mm256_shuffle_epi32(bsums, 0xaa))); + isum[iy] = _mm256_add_epi32(isum[iy], _mm256_madd_epi16(s4, _mm256_shuffle_epi32(bsums, 0xff))); +#endif + } + } +} + +template <int nrc_y> +inline void iq2345_k_accum_mins(int ibl, __m256i i8scales1, __m256i i8scales2, const Q8<nrc_y, block_q8_K>& q8, __m256i shuff, + __m256i extra, __m256i * isum, int8_t min, int8_t delta) { + auto mask = _mm256_set_epi64x(0x0808080808080808, 0x0404040404040404, 0x0202020202020202, 0x0101010101010101); + auto vdelta = _mm256_set1_epi8(delta); + auto vmin = _mm256_set1_epi8(min); + auto min1 = _mm256_add_epi8(vmin, _mm256_and_si256(vdelta, _mm256_cmpeq_epi8(_mm256_and_si256(extra, mask), mask))); + auto min2 = _mm256_add_epi8(vmin, _mm256_and_si256(vdelta, _mm256_cmpeq_epi8(_mm256_and_si256(_mm256_srli_epi16(extra, 4), mask), mask))); + auto t1 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(i8scales1, 0)), shuff); // blocks 0, 1, 2, 3 for each row + auto t2 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(i8scales1, 1)), shuff); // blocks 4, 5, 6, 7 for each row + auto t3 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(i8scales2, 0)), shuff); // blocks 8, 9, 10, 11 for each row + auto t4 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(i8scales2, 1)), shuff); // blocks 12, 13, 14, 15 for each row + auto m1 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(min1, 0)), shuff); // blocks 0, 1, 2, 3 for each row + auto m2 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(min1, 1)), shuff); // blocks 4, 5, 6, 7 for each row + auto m3 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(min2, 0)), shuff); // blocks 8, 9, 10, 11 for each row + auto m4 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(min2, 1)), shuff); // blocks 12, 13, 14, 15 for each row + auto s1 = _mm256_mullo_epi16(MM256_SET_M128I(_mm256_extracti128_si256(m3, 0), _mm256_extracti128_si256(m1, 0)), + MM256_SET_M128I(_mm256_extracti128_si256(t3, 0), _mm256_extracti128_si256(t1, 0))); // blocks 0, 1, 8, 9 + auto s2 = _mm256_mullo_epi16(MM256_SET_M128I(_mm256_extracti128_si256(m3, 1), _mm256_extracti128_si256(m1, 1)), + MM256_SET_M128I(_mm256_extracti128_si256(t3, 1), _mm256_extracti128_si256(t1, 1))); // blocks 2, 3, 10, 11 + auto s3 = _mm256_mullo_epi16(MM256_SET_M128I(_mm256_extracti128_si256(m4, 0), _mm256_extracti128_si256(m2, 0)), + MM256_SET_M128I(_mm256_extracti128_si256(t4, 0), _mm256_extracti128_si256(t2, 0))); // blocks 4, 5, 12, 13 + auto s4 = _mm256_mullo_epi16(MM256_SET_M128I(_mm256_extracti128_si256(m4, 1), _mm256_extracti128_si256(m2, 1)), + MM256_SET_M128I(_mm256_extracti128_si256(t4, 1), _mm256_extracti128_si256(t2, 1))); // blocks 6, 7, 14, 15 + for (int iy = 0; iy < nrc_y; ++iy) { + auto bsums = q8.load_bsums(iy, ibl); +#ifdef HAVE_FANCY_SIMD + isum[iy] = _mm256_dpwssd_epi32(isum[iy], s1, _mm256_shuffle_epi32(bsums, 0x00)); + isum[iy] = _mm256_dpwssd_epi32(isum[iy], s2, _mm256_shuffle_epi32(bsums, 0x55)); + isum[iy] = _mm256_dpwssd_epi32(isum[iy], s3, _mm256_shuffle_epi32(bsums, 0xaa)); + isum[iy] = _mm256_dpwssd_epi32(isum[iy], s4, _mm256_shuffle_epi32(bsums, 0xff)); +#else + isum[iy] = _mm256_add_epi32(isum[iy], _mm256_madd_epi16(s1, _mm256_shuffle_epi32(bsums, 0x00))); + isum[iy] = _mm256_add_epi32(isum[iy], _mm256_madd_epi16(s2, _mm256_shuffle_epi32(bsums, 0x55))); + isum[iy] = _mm256_add_epi32(isum[iy], _mm256_madd_epi16(s3, _mm256_shuffle_epi32(bsums, 0xaa))); + isum[iy] = _mm256_add_epi32(isum[iy], _mm256_madd_epi16(s4, _mm256_shuffle_epi32(bsums, 0xff))); +#endif + } +} + +template <int nrc_y> +static void mul_mat_iq2_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + GGML_ASSERT(nrc_x%4 == 0); + Q8<nrc_y, block_q8_K> q8(info); + auto m4 = _mm256_set1_epi8(0xf); + auto ms = _mm256_set1_epi8(4); + auto m03 = _mm256_set1_epi8(0x03); + auto shift_shuffle = _mm256_set_epi64x(0x0707070706060606, 0x0505050504040404, 0x0303030302020202, 0x0101010100000000); + static const uint8_t kvalues_iq2nl[32] = {1, 19, 33, 49, 6, 24, 38, 54, 1, 19, 33, 49, 6, 24, 38, 54, 1, 19, 33, 49, 6, 24, 38, 54, 1, 19, 33, 49, 6, 24, 38, 54}; + auto values = _mm256_loadu_si256((const __m256i*)kvalues_iq2nl); + static const uint8_t k_shuff[32] = {0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15, 0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15}; + auto shuff = _mm256_loadu_si256((const __m256i *)k_shuff); +#ifndef HAVE_FANCY_SIMD + auto s_shuffle = _mm256_set_epi64x(0x0f0e0f0e0d0c0d0c, 0x0b0a0b0a09080908, 0x0706070605040504, 0x0302030201000100); +#endif + int nbl = n / QK_K; + __m256 acc[nrc_y] = {}; + __m256i qx[4]; + uint64_t stored_scales[8]; + for (int ix = 0; ix < nrc_x; ix += 4) { + const block_iq2_k_r4 * iq2 = (const block_iq2_k_r4 *)((const char *)vx + (ix+0)*bx); + for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256 + auto dl = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq2[ibl].d)); + auto d4 = _mm256_set_m128(dl, dl); + auto extra = _mm256_set1_epi64x(*(const uint64_t *)iq2[ibl].extra); + auto slbits = _mm256_loadu_si256((const __m256i *)iq2[ibl].scales); + auto i8scales1 = _mm256_add_epi8(_mm256_and_si256(slbits, m4), _mm256_set1_epi8(-8)); + auto i8scales2 = _mm256_add_epi8(_mm256_and_si256(_mm256_srli_epi16(slbits, 4), m4), _mm256_set1_epi8(-8)); + _mm256_storeu_si256((__m256i *)stored_scales+0, i8scales1); + _mm256_storeu_si256((__m256i *)stored_scales+1, i8scales2); + __m256i isum[nrc_y] = {}; + iq234_k_accum_mins(ibl, i8scales1, i8scales2, q8, shuff, isum, -32); + for (int ib = 0; ib < QK_K/32; ++ib) { +#ifdef HAVE_FANCY_SIMD + auto scales = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(stored_scales + ib))); +#else + auto scales = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm_set1_epi64x(stored_scales[ib])), s_shuffle); +#endif + auto lb = _mm256_loadu_si256((const __m256i *)iq2[ibl].qs+ib); + auto shift = _mm256_and_si256(ms, _mm256_slli_epi16(extra, 2)); extra = _mm256_srli_epi16(extra, 1); + shift = _mm256_shuffle_epi8(shift, shift_shuffle); + qx[0] = _mm256_and_si256(lb, m03); + qx[1] = _mm256_and_si256(_mm256_srli_epi16(lb, 2), m03); + qx[2] = _mm256_and_si256(_mm256_srli_epi16(lb, 4), m03); + qx[3] = _mm256_and_si256(_mm256_srli_epi16(lb, 6), m03); + qx[0] = _mm256_shuffle_epi8(values, _mm256_add_epi8(qx[0], shift)); + qx[1] = _mm256_shuffle_epi8(values, _mm256_add_epi8(qx[1], shift)); + qx[2] = _mm256_shuffle_epi8(values, _mm256_add_epi8(qx[2], shift)); + qx[3] = _mm256_shuffle_epi8(values, _mm256_add_epi8(qx[3], shift)); + for (int iy = 0; iy < nrc_y; ++iy) { + auto y = _mm256_loadu_si256((const __m256i*)q8.y[iy][ibl].qs+ib); +#ifdef HAVE_FANCY_SIMD + auto sumi = _mm256_setzero_si256(); + sumi = _mm256_dpbusd_epi32(sumi, qx[0], _mm256_shuffle_epi32(y, 0x00)); + sumi = _mm256_dpbusd_epi32(sumi, qx[1], _mm256_shuffle_epi32(y, 0x55)); + sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(y, 0xaa)); + sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(y, 0xff)); + isum[iy] = _mm256_add_epi32(isum[iy], _mm256_mullo_epi32(scales, sumi)); +#else + auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[0], _mm256_shuffle_epi32(y, 0x00)), + _mm256_maddubs_epi16(qx[1], _mm256_shuffle_epi32(y, 0x55))); + auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[2], _mm256_shuffle_epi32(y, 0xaa)), + _mm256_maddubs_epi16(qx[3], _mm256_shuffle_epi32(y, 0xff))); + isum[iy] = _mm256_add_epi32(isum[iy], _mm256_add_epi32(_mm256_madd_epi16(scales, sumi1), _mm256_madd_epi16(scales, sumi2))); +#endif + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(isum[iy]), acc[iy]); + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1)); + acc[iy] = _mm256_setzero_ps(); + info.store(ix+0, iy, sum); + } + } +} + +template <int nrc_y> +static void mul_mat_iq3_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + GGML_ASSERT(nrc_x%4 == 0); + Q8<nrc_y, block_q8_K> q8(info); + auto m4 = _mm256_set1_epi8(0xf); + auto ms = _mm256_set1_epi8(8); + auto m03 = _mm256_set1_epi8(0x03); + auto m04 = _mm256_set1_epi8(0x04); + auto smask = _mm256_set_epi64x(0x0808080808080808, 0x0404040404040404, 0x0202020202020202, 0x0101010101010101); + auto shift_shuffle = _mm256_set_epi64x(0x0707070706060606, 0x0505050504040404, 0x0303030302020202, 0x0101010100000000); + auto values128 = _mm_loadu_si128((const __m128i *)iq3nl_values); + auto values = MM256_SET_M128I(values128, values128); + values = _mm256_add_epi8(values, _mm256_set1_epi8(64)); + static const uint8_t k_shuff[32] = {0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15, 0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15}; + auto shuff = _mm256_loadu_si256((const __m256i *)k_shuff); +#ifndef HAVE_FANCY_SIMD + auto s_shuffle = _mm256_set_epi64x(0x0f0e0f0e0d0c0d0c, 0x0b0a0b0a09080908, 0x0706070605040504, 0x0302030201000100); +#endif + int nbl = n / QK_K; + __m256 acc[nrc_y] = {}; + __m256i qx[4]; + uint64_t stored_scales[8]; + for (int ix = 0; ix < nrc_x; ix += 4) { + const block_iq3_k_r4 * iq3 = (const block_iq3_k_r4 *)((const char *)vx + (ix+0)*bx); + for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256 + auto dl = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq3[ibl].d)); + auto d4 = _mm256_set_m128(dl, dl); + auto extra = _mm256_set1_epi64x(*(const uint64_t *)iq3[ibl].extra); + auto slbits = _mm256_loadu_si256((const __m256i *)iq3[ibl].scales_l); + auto sl1 = _mm256_add_epi8(_mm256_slli_epi16(_mm256_and_si256(slbits, m4), 1), _mm256_set1_epi8(1)); + auto sl2 = _mm256_add_epi8(_mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(slbits, 4), m4), 1), _mm256_set1_epi8(1)); + auto sh = _mm256_set1_epi64x(((const uint64_t *)iq3[ibl].scales_h)[0]); + auto sh1 = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(sh, smask), smask), _mm256_set1_epi8(1)); + auto sh2 = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(_mm256_srli_epi16(sh, 4), smask), smask), _mm256_set1_epi8(1)); + auto i8scales1 = _mm256_sign_epi8(sl1, sh1); + auto i8scales2 = _mm256_sign_epi8(sl2, sh2); + _mm256_storeu_si256((__m256i *)stored_scales+0, i8scales1); + _mm256_storeu_si256((__m256i *)stored_scales+1, i8scales2); + __m256i isum[nrc_y] = {}; + iq234_k_accum_mins(ibl, i8scales1, i8scales2, q8, shuff, isum, -64); + for (int ib = 0; ib < QK_K/32; ++ib) { +#ifdef HAVE_FANCY_SIMD + auto scales = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(stored_scales + ib))); +#else + auto scales = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm_set1_epi64x(stored_scales[ib])), s_shuffle); +#endif + auto lb = _mm256_loadu_si256((const __m256i *)iq3[ibl].qs+ib); + auto hbits = _mm_loadu_si128((const __m128i *)iq3[ibl].qh+ib); + auto hb = MM256_SET_M128I(hbits, _mm_slli_epi16(hbits, 4)); + auto shift = _mm256_and_si256(ms, _mm256_slli_epi16(extra, 3)); extra = _mm256_srli_epi16(extra, 1); + shift = _mm256_shuffle_epi8(shift, shift_shuffle); + qx[0] = _mm256_or_si256(_mm256_and_si256(lb, m03), _mm256_and_si256(m04, _mm256_srli_epi16(hb, 2))); + qx[1] = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(lb, 2), m03), _mm256_and_si256(m04, _mm256_srli_epi16(hb, 3))); + qx[2] = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(lb, 4), m03), _mm256_and_si256(m04, _mm256_srli_epi16(hb, 4))); + qx[3] = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(lb, 6), m03), _mm256_and_si256(m04, _mm256_srli_epi16(hb, 5))); + qx[0] = _mm256_shuffle_epi8(values, _mm256_add_epi8(qx[0], shift)); + qx[1] = _mm256_shuffle_epi8(values, _mm256_add_epi8(qx[1], shift)); + qx[2] = _mm256_shuffle_epi8(values, _mm256_add_epi8(qx[2], shift)); + qx[3] = _mm256_shuffle_epi8(values, _mm256_add_epi8(qx[3], shift)); + for (int iy = 0; iy < nrc_y; ++iy) { + auto y = _mm256_loadu_si256((const __m256i*)q8.y[iy][ibl].qs+ib); +#ifdef HAVE_FANCY_SIMD + auto sumi = _mm256_setzero_si256(); + sumi = _mm256_dpbusd_epi32(sumi, qx[0], _mm256_shuffle_epi32(y, 0x00)); + sumi = _mm256_dpbusd_epi32(sumi, qx[1], _mm256_shuffle_epi32(y, 0x55)); + sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(y, 0xaa)); + sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(y, 0xff)); + isum[iy] = _mm256_add_epi32(isum[iy], _mm256_mullo_epi32(scales, sumi)); +#else + auto sumi1 = _mm256_maddubs_epi16(qx[0], _mm256_shuffle_epi32(y, 0x00)); + auto sumi2 = _mm256_maddubs_epi16(qx[1], _mm256_shuffle_epi32(y, 0x55)); + auto sumi3 = _mm256_maddubs_epi16(qx[2], _mm256_shuffle_epi32(y, 0xaa)); + auto sumi4 = _mm256_maddubs_epi16(qx[3], _mm256_shuffle_epi32(y, 0xff)); + isum[iy] = _mm256_add_epi32(isum[iy], _mm256_add_epi32(_mm256_madd_epi16(scales, sumi1), _mm256_madd_epi16(scales, sumi2))); + isum[iy] = _mm256_add_epi32(isum[iy], _mm256_add_epi32(_mm256_madd_epi16(scales, sumi3), _mm256_madd_epi16(scales, sumi4))); +#endif + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(isum[iy]), acc[iy]); + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1)); + acc[iy] = _mm256_setzero_ps(); + info.store(ix+0, iy, sum); + } + } +} + +template <int nrc_y> +static void mul_mat_iq4_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + GGML_ASSERT(nrc_x%4 == 0); + Q8<nrc_y, block_q8_K> q8(info); + auto m4 = _mm256_set1_epi8(0xf); + auto m30 = _mm256_set1_epi8(0x30); + auto m32 = _mm256_set1_epi8(32); + auto ms = _mm256_set1_epi8(4); + auto shift_shuffle = _mm256_set_epi64x(0x0707070706060606, 0x0505050504040404, 0x0303030302020202, 0x0101010100000000); +#ifdef HAVE_FANCY_SIMD + auto values = load_iq4nl_values_256(); + static const uint8_t k_shuff[32] = {0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15, 0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15}; + auto shuff = _mm256_loadu_si256((const __m256i *)k_shuff); +#else + auto s_shuffle = _mm256_set_epi64x(0x0f0e0f0e0d0c0d0c, 0x0b0a0b0a09080908, 0x0706070605040504, 0x0302030201000100); + auto values128 = _mm_loadu_si128((const __m128i *)iq4k_values); + auto values = MM256_SET_M128I(values128, values128); +#endif + int nbl = n / QK_K; + __m256 acc[nrc_y] = {}; + __m256i qx[4]; + uint64_t stored_scales[8]; + for (int ix = 0; ix < nrc_x; ix += 4) { + const block_iq4_k_r4 * iq4 = (const block_iq4_k_r4 *)((const char *)vx + (ix+0)*bx); + for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256 + auto dl = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq4[ibl].d)); + auto d4 = _mm256_set_m128(dl, dl); + auto extra = _mm256_set1_epi64x(*(const uint64_t *)iq4[ibl].extra); + auto slbits = _mm256_loadu_si256((const __m256i *)iq4[ibl].scales_l); + auto sl1 = _mm256_and_si256(slbits, m4); + auto sl2 = _mm256_and_si256(_mm256_srli_epi16(slbits, 4), m4); + auto shbits = _mm_loadu_si128((const __m128i*)iq4[ibl].scales_h); + auto sh = MM256_SET_M128I(_mm_srli_epi16(shbits, 2), shbits); + auto i8scales1 = _mm256_sub_epi8(_mm256_or_si256(sl1, _mm256_and_si256(m30, _mm256_slli_epi16(sh, 4))), m32); + auto i8scales2 = _mm256_sub_epi8(_mm256_or_si256(sl2, _mm256_and_si256(m30, sh)), m32); + _mm256_storeu_si256((__m256i *)stored_scales+0, i8scales1); + _mm256_storeu_si256((__m256i *)stored_scales+1, i8scales2); + __m256i isum[nrc_y] = {}; +#ifdef HAVE_FANCY_SIMD + iq234_k_accum_mins(ibl, i8scales1, i8scales2, q8, shuff, isum, -128); +#endif + for (int ib = 0; ib < QK_K/32; ++ib) { +#ifdef HAVE_FANCY_SIMD + auto scales = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(stored_scales + ib))); +#else + auto scales = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm_set1_epi64x(stored_scales[ib])), s_shuffle); +#endif + auto bits1 = _mm256_loadu_si256((const __m256i *)iq4[ibl].qs+2*ib+0); + auto bits2 = _mm256_loadu_si256((const __m256i *)iq4[ibl].qs+2*ib+1); + auto shift = _mm256_and_si256(ms, _mm256_slli_epi16(extra, 2)); extra = _mm256_srli_epi16(extra, 1); + shift = _mm256_shuffle_epi8(shift, shift_shuffle); + qx[0] = _mm256_add_epi8(shift, _mm256_shuffle_epi8(values, _mm256_and_si256(bits1, m4))); + qx[1] = _mm256_add_epi8(shift, _mm256_shuffle_epi8(values, _mm256_and_si256(bits2, m4))); + qx[2] = _mm256_add_epi8(shift, _mm256_shuffle_epi8(values, _mm256_and_si256(_mm256_srli_epi16(bits1, 4), m4))); + qx[3] = _mm256_add_epi8(shift, _mm256_shuffle_epi8(values, _mm256_and_si256(_mm256_srli_epi16(bits2, 4), m4))); +#ifndef HAVE_FANCY_SIMD + auto s1 = _mm256_sign_epi8(qx[0], qx[0]); + auto s2 = _mm256_sign_epi8(qx[1], qx[1]); + auto s3 = _mm256_sign_epi8(qx[2], qx[2]); + auto s4 = _mm256_sign_epi8(qx[3], qx[3]); +#endif + for (int iy = 0; iy < nrc_y; ++iy) { + auto y = _mm256_loadu_si256((const __m256i*)q8.y[iy][ibl].qs+ib); +#ifdef HAVE_FANCY_SIMD + auto sumi = _mm256_setzero_si256(); + sumi = _mm256_dpbusd_epi32(sumi, qx[0], _mm256_shuffle_epi32(y, 0x00)); + sumi = _mm256_dpbusd_epi32(sumi, qx[1], _mm256_shuffle_epi32(y, 0x55)); + sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(y, 0xaa)); + sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(y, 0xff)); + isum[iy] = _mm256_add_epi32(isum[iy], _mm256_mullo_epi32(scales, sumi)); +#else + auto sumi1 = _mm256_maddubs_epi16(s1, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), qx[0])); + auto sumi2 = _mm256_maddubs_epi16(s2, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), qx[1])); + auto sumi3 = _mm256_maddubs_epi16(s3, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), qx[2])); + auto sumi4 = _mm256_maddubs_epi16(s4, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), qx[3])); + isum[iy] = _mm256_add_epi32(isum[iy], _mm256_add_epi32(_mm256_madd_epi16(scales, sumi1), _mm256_madd_epi16(scales, sumi2))); + isum[iy] = _mm256_add_epi32(isum[iy], _mm256_add_epi32(_mm256_madd_epi16(scales, sumi3), _mm256_madd_epi16(scales, sumi4))); +#endif + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(isum[iy]), acc[iy]); + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1)); + acc[iy] = _mm256_setzero_ps(); + info.store(ix+0, iy, sum); + } + } +} + +static inline __m256i prepare_5bit_quants(const __m256i * values, __m256i ql, __m256i qh, __m256i mask) { + auto q5vl = _mm256_shuffle_epi8(values[0], ql); + auto q5vh = _mm256_shuffle_epi8(values[1], ql); +#ifdef HAVE_FANCY_SIMD + return _mm256_mask_blend_epi8(_mm256_cmpeq_epi8_mask(_mm256_and_si256(qh, mask), mask), q5vl, q5vh); +#else + return _mm256_blendv_epi8(q5vl, q5vh, _mm256_cmpeq_epi8(_mm256_and_si256(qh, mask), mask)); +#endif +} + +template <int nrc_y> +static void mul_mat_iq5_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + GGML_ASSERT(nrc_x%4 == 0); + Q8<nrc_y, block_q8_K> q8(info); + auto m4 = _mm256_set1_epi8(0xf); + auto m30 = _mm256_set1_epi8(0x30); + auto m32 = _mm256_set1_epi8(32); + auto ms = _mm256_set1_epi8(2); + auto shift_shuffle = _mm256_set_epi64x(0x0707070706060606, 0x0505050504040404, 0x0303030302020202, 0x0101010100000000); + __m256i values[2]; + { + auto val1 = _mm_loadu_si128((const __m128i *)iq5nl_values+0); + auto val2 = _mm_loadu_si128((const __m128i *)iq5nl_values+1); + values[0] = MM256_SET_M128I(val1, val1); + values[1] = MM256_SET_M128I(val2, val2); +#ifdef HAVE_FANCY_SIMD + values[0] = _mm256_sub_epi8(values[0], _mm256_set1_epi8(-128)); + values[1] = _mm256_sub_epi8(values[1], _mm256_set1_epi8(-128)); +#endif + } +#ifdef HAVE_FANCY_SIMD + static const uint8_t k_shuff[32] = {0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15, 0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15}; + auto shuff = _mm256_loadu_si256((const __m256i *)k_shuff); +#else + auto s_shuffle = _mm256_set_epi64x(0x0f0e0f0e0d0c0d0c, 0x0b0a0b0a09080908, 0x0706070605040504, 0x0302030201000100); +#endif + int nbl = n / QK_K; + __m256 acc[nrc_y] = {}; + __m256i qx[4]; + uint64_t stored_scales[8]; + for (int ix = 0; ix < nrc_x; ix += 4) { + const block_iq5_k_r4 * iq5 = (const block_iq5_k_r4 *)((const char *)vx + (ix+0)*bx); + for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256 + auto dl = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq5[ibl].d)); + auto d4 = _mm256_set_m128(dl, dl); + auto extra = _mm256_set1_epi64x(*(const uint64_t *)iq5[ibl].extra); + auto slbits = _mm256_loadu_si256((const __m256i *)iq5[ibl].scales_l); + auto sl1 = _mm256_and_si256(slbits, m4); + auto sl2 = _mm256_and_si256(_mm256_srli_epi16(slbits, 4), m4); + auto shbits = _mm_loadu_si128((const __m128i*)iq5[ibl].scales_h); + auto sh = MM256_SET_M128I(_mm_srli_epi16(shbits, 2), shbits); + auto i8scales1 = _mm256_sub_epi8(_mm256_or_si256(sl1, _mm256_and_si256(m30, _mm256_slli_epi16(sh, 4))), m32); + auto i8scales2 = _mm256_sub_epi8(_mm256_or_si256(sl2, _mm256_and_si256(m30, sh)), m32); + _mm256_storeu_si256((__m256i *)stored_scales+0, i8scales1); + _mm256_storeu_si256((__m256i *)stored_scales+1, i8scales2); + __m256i isum[nrc_y] = {}; +#ifdef HAVE_FANCY_SIMD + if constexpr (nrc_y == 1) { + iq234_k_accum_mins(ibl, i8scales1, i8scales2, q8, shuff, isum, -128); + } else { + iq2345_k_accum_mins(ibl, i8scales1, i8scales2, q8, shuff, extra, isum, -128, 2); + } +#endif + for (int ib = 0; ib < QK_K/32; ++ib) { +#ifdef HAVE_FANCY_SIMD + auto scales = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(stored_scales + ib))); +#else + auto scales = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm_set1_epi64x(stored_scales[ib])), s_shuffle); +#endif + auto lbits1 = _mm256_loadu_si256((const __m256i *)iq5[ibl].qs+2*ib+0); + auto lbits2 = _mm256_loadu_si256((const __m256i *)iq5[ibl].qs+2*ib+1); + auto hbits = _mm_loadu_si128((const __m128i *)iq5[ibl].qh+ib); + auto hb = MM256_SET_M128I(_mm_srli_epi16(hbits, 2), hbits); + qx[0] = _mm256_and_si256(lbits1, m4); + qx[1] = _mm256_and_si256(lbits2, m4); + qx[2] = _mm256_and_si256(_mm256_srli_epi16(lbits1, 4), m4); + qx[3] = _mm256_and_si256(_mm256_srli_epi16(lbits2, 4), m4); + + qx[0] = prepare_5bit_quants(values, qx[0], hb, _mm256_set1_epi8(0x01)); + qx[1] = prepare_5bit_quants(values, qx[1], hb, _mm256_set1_epi8(0x10)); + qx[2] = prepare_5bit_quants(values, qx[2], hb, _mm256_set1_epi8(0x02)); + qx[3] = prepare_5bit_quants(values, qx[3], hb, _mm256_set1_epi8(0x20)); +#ifdef HAVE_FANCY_SIMD + if constexpr (nrc_y == 1) { + auto shift = _mm256_and_si256(ms, _mm256_slli_epi16(extra, 1)); extra = _mm256_srli_epi16(extra, 1); + shift = _mm256_shuffle_epi8(shift, shift_shuffle); + qx[0] = _mm256_add_epi8(qx[0], shift); + qx[1] = _mm256_add_epi8(qx[1], shift); + qx[2] = _mm256_add_epi8(qx[2], shift); + qx[3] = _mm256_add_epi8(qx[3], shift); + } +#else + auto shift = _mm256_and_si256(ms, _mm256_slli_epi16(extra, 1)); extra = _mm256_srli_epi16(extra, 1); + shift = _mm256_shuffle_epi8(shift, shift_shuffle); + qx[0] = _mm256_add_epi8(qx[0], shift); + qx[1] = _mm256_add_epi8(qx[1], shift); + qx[2] = _mm256_add_epi8(qx[2], shift); + qx[3] = _mm256_add_epi8(qx[3], shift); + auto s1 = _mm256_sign_epi8(qx[0], qx[0]); + auto s2 = _mm256_sign_epi8(qx[1], qx[1]); + auto s3 = _mm256_sign_epi8(qx[2], qx[2]); + auto s4 = _mm256_sign_epi8(qx[3], qx[3]); +#endif + for (int iy = 0; iy < nrc_y; ++iy) { + auto y = _mm256_loadu_si256((const __m256i*)q8.y[iy][ibl].qs+ib); +#ifdef HAVE_FANCY_SIMD + auto sumi = _mm256_setzero_si256(); + sumi = _mm256_dpbusd_epi32(sumi, qx[0], _mm256_shuffle_epi32(y, 0x00)); + sumi = _mm256_dpbusd_epi32(sumi, qx[1], _mm256_shuffle_epi32(y, 0x55)); + sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(y, 0xaa)); + sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(y, 0xff)); + isum[iy] = _mm256_add_epi32(isum[iy], _mm256_mullo_epi32(scales, sumi)); +#else + auto sumi1 = _mm256_maddubs_epi16(s1, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), qx[0])); + auto sumi2 = _mm256_maddubs_epi16(s2, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), qx[1])); + auto sumi3 = _mm256_maddubs_epi16(s3, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), qx[2])); + auto sumi4 = _mm256_maddubs_epi16(s4, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), qx[3])); + isum[iy] = _mm256_add_epi32(isum[iy], _mm256_add_epi32(_mm256_madd_epi16(scales, sumi1), _mm256_madd_epi16(scales, sumi2))); + isum[iy] = _mm256_add_epi32(isum[iy], _mm256_add_epi32(_mm256_madd_epi16(scales, sumi3), _mm256_madd_epi16(scales, sumi4))); +#endif + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(isum[iy]), acc[iy]); + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1)); + acc[iy] = _mm256_setzero_ps(); + info.store(ix+0, iy, sum); + } + } +} + +template <int nrc_y> +static void mul_mat_iq4_ks_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + GGML_ASSERT(nrc_x%4 == 0); + Q8<nrc_y, block_q8_K> q8(info); + auto m4 = _mm256_set1_epi8(0xf); +#ifndef HAVE_FANCY_SIMD + auto s_shuffle = _mm256_set_epi64x(0x0f0e0f0e0d0c0d0c, 0x0b0a0b0a09080908, 0x0706070605040504, 0x0302030201000100); + auto values128 = _mm_loadu_si128((const __m128i *)iq4k_values); + auto values = MM256_SET_M128I(values128, values128); +#else + auto values = load_iq4nl_values_256(); +#endif + int nbl = n / QK_K; + using helper_t = union { __m256i vec; uint32_t val[8]; }; +#ifndef HAVE_FANCY_SIMD + helper_t h, h_shift; +#else + using helper512_t = union { __m512i vec; uint64_t val[8]; }; + helper_t h; + helper512_t h_shift; +#endif + __m256 acc[nrc_y] = {}; + __m256i isum[nrc_y] = {}; + __m256i qx[4]; + for (int ix = 0; ix < nrc_x; ix += 4) { + auto dptr = (const float *)((const char *)vx + (ix+0)*bx); + const block_iq4_ks_r4 * iq4 = (const block_iq4_ks_r4 *)(dptr + 4); + auto d4 = _mm_loadu_ps(dptr); + for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256 + auto scales = _mm256_loadu_si256((const __m256i *)iq4[ibl].scales); + h.vec = _mm256_sub_epi8(_mm256_and_si256(scales, _mm256_set1_epi8(-2)), _mm256_set1_epi8(127)); +#ifndef HAVE_FANCY_SIMD + h_shift.vec = _mm256_slli_epi16(_mm256_and_si256(scales, _mm256_set1_epi8(1)), 2); + { + __m256 v1 = _mm256_mul_ps(_mm256_cvtepi32_ps(MM256_SET_M128I(_mm_cvtepi8_epi32(_mm_set1_epi32(h.val[4])), _mm_cvtepi8_epi32(_mm_set1_epi32(h.val[0])))), + _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_cvtepi8_epi32(_mm_set1_epi32(h_shift.val[4])), _mm_cvtepi8_epi32(_mm_set1_epi32(h_shift.val[0]))))); + __m256 v2 = _mm256_mul_ps(_mm256_cvtepi32_ps(MM256_SET_M128I(_mm_cvtepi8_epi32(_mm_set1_epi32(h.val[5])), _mm_cvtepi8_epi32(_mm_set1_epi32(h.val[1])))), + _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_cvtepi8_epi32(_mm_set1_epi32(h_shift.val[5])), _mm_cvtepi8_epi32(_mm_set1_epi32(h_shift.val[1]))))); + __m256 v3 = _mm256_mul_ps(_mm256_cvtepi32_ps(MM256_SET_M128I(_mm_cvtepi8_epi32(_mm_set1_epi32(h.val[6])), _mm_cvtepi8_epi32(_mm_set1_epi32(h.val[2])))), + _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_cvtepi8_epi32(_mm_set1_epi32(h_shift.val[6])), _mm_cvtepi8_epi32(_mm_set1_epi32(h_shift.val[2]))))); + __m256 v4 = _mm256_mul_ps(_mm256_cvtepi32_ps(MM256_SET_M128I(_mm_cvtepi8_epi32(_mm_set1_epi32(h.val[7])), _mm_cvtepi8_epi32(_mm_set1_epi32(h.val[3])))), + _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_cvtepi8_epi32(_mm_set1_epi32(h_shift.val[7])), _mm_cvtepi8_epi32(_mm_set1_epi32(h_shift.val[3]))))); + for (int iy = 0; iy < nrc_y; ++iy) { + auto m8 = _mm256_loadu_ps((const float *)q8.y[iy][ibl].bsums); + acc[iy] = _mm256_fmadd_ps(v1, _mm256_shuffle_ps(m8, m8, 0x00), acc[iy]); + acc[iy] = _mm256_fmadd_ps(v2, _mm256_shuffle_ps(m8, m8, 0x55), acc[iy]); + acc[iy] = _mm256_fmadd_ps(v3, _mm256_shuffle_ps(m8, m8, 0xaa), acc[iy]); + acc[iy] = _mm256_fmadd_ps(v4, _mm256_shuffle_ps(m8, m8, 0xff), acc[iy]); + } + } +#else + auto shift = _mm256_add_epi8(_mm256_set1_epi8(-64), _mm256_slli_epi16(_mm256_and_si256(scales, _mm256_set1_epi8(1)), 1)); + h_shift.vec = _mm512_mullo_epi16(_mm512_cvtepi8_epi16(shift), _mm512_cvtepi8_epi16(h.vec)); +#endif + for (int ib = 0; ib < QK_K/32; ++ib) { +#ifdef HAVE_FANCY_SIMD + auto iscales = _mm256_cvtepi8_epi32(_mm_set1_epi32(h.val[ib])); + auto ishifts = _mm256_cvtepi16_epi32(_mm_set1_epi64x(h_shift.val[ib])); + auto scales_m = _mm256_cvtepi32_ps(ishifts); + for (int iy = 0; iy < nrc_y; ++iy) { + float m8 = ((const float *)q8.y[iy][ibl].bsums)[ib]; + acc[iy] = _mm256_fmadd_ps(scales_m, _mm256_set1_ps(m8), acc[iy]); + } +#endif + auto bits1 = _mm256_loadu_si256((const __m256i *)iq4[ibl].qs+2*ib+0); + auto bits2 = _mm256_loadu_si256((const __m256i *)iq4[ibl].qs+2*ib+1); + qx[0] = _mm256_shuffle_epi8(values, _mm256_and_si256(bits1, m4)); + qx[1] = _mm256_shuffle_epi8(values, _mm256_and_si256(bits2, m4)); + qx[2] = _mm256_shuffle_epi8(values, _mm256_and_si256(_mm256_srli_epi16(bits1, 4), m4)); + qx[3] = _mm256_shuffle_epi8(values, _mm256_and_si256(_mm256_srli_epi16(bits2, 4), m4)); +#ifndef HAVE_FANCY_SIMD + auto iscales = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm_set1_epi32(h.val[ib])), s_shuffle); + auto s1 = _mm256_sign_epi8(qx[0], qx[0]); + auto s2 = _mm256_sign_epi8(qx[1], qx[1]); + auto s3 = _mm256_sign_epi8(qx[2], qx[2]); + auto s4 = _mm256_sign_epi8(qx[3], qx[3]); +#endif + for (int iy = 0; iy < nrc_y; ++iy) { + auto y = _mm256_loadu_si256((const __m256i*)q8.y[iy][ibl].qs+ib); +#ifdef HAVE_FANCY_SIMD + auto sumi = _mm256_setzero_si256(); + sumi = _mm256_dpbusd_epi32(sumi, qx[0], _mm256_shuffle_epi32(y, 0x00)); + sumi = _mm256_dpbusd_epi32(sumi, qx[1], _mm256_shuffle_epi32(y, 0x55)); + sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(y, 0xaa)); + sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(y, 0xff)); + isum[iy] = _mm256_add_epi32(isum[iy], _mm256_mullo_epi32(iscales, sumi)); +#else + auto sumi1 = _mm256_maddubs_epi16(s1, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), qx[0])); + auto sumi2 = _mm256_maddubs_epi16(s2, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), qx[1])); + auto sumi3 = _mm256_maddubs_epi16(s3, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), qx[2])); + auto sumi4 = _mm256_maddubs_epi16(s4, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), qx[3])); + isum[iy] = _mm256_add_epi32(isum[iy], _mm256_add_epi32(_mm256_madd_epi16(iscales, sumi1), _mm256_madd_epi16(iscales, sumi2))); + isum[iy] = _mm256_add_epi32(isum[iy], _mm256_add_epi32(_mm256_madd_epi16(iscales, sumi3), _mm256_madd_epi16(iscales, sumi4))); +#endif + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + acc[iy] = _mm256_fmadd_ps(_mm256_set1_ps(q8.scale(iy, ibl)), _mm256_cvtepi32_ps(isum[iy]), acc[iy]); + isum[iy] = _mm256_setzero_si256(); + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1)); + acc[iy] = _mm256_setzero_ps(); + info.store(ix+0, iy, _mm_mul_ps(d4, sum)); + } + } +} + +template <int nrc_y> +static void mul_mat_iq5_ks_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + GGML_ASSERT(nrc_x%4 == 0); + Q8<nrc_y, block_q8_K> q8(info); + auto m4 = _mm256_set1_epi8(0xf); + __m256i values[2]; + { + auto val1 = _mm_loadu_si128((const __m128i *)iq5nl_values+0); + auto val2 = _mm_loadu_si128((const __m128i *)iq5nl_values+1); + values[0] = MM256_SET_M128I(val1, val1); + values[1] = MM256_SET_M128I(val2, val2); +#ifdef HAVE_FANCY_SIMD + values[0] = _mm256_sub_epi8(values[0], _mm256_set1_epi8(-128)); + values[1] = _mm256_sub_epi8(values[1], _mm256_set1_epi8(-128)); +#endif + } + int nbl = n / QK_K; + using helper_t = union { __m256i vec; uint32_t val[8]; }; +#ifndef HAVE_FANCY_SIMD + helper_t h, h_shift; + auto s_shuffle = _mm256_set_epi64x(0x0f0e0f0e0d0c0d0c, 0x0b0a0b0a09080908, 0x0706070605040504, 0x0302030201000100); +#else + using helper512_t = union { __m512i vec; uint64_t val[8]; }; + helper_t h; + helper512_t h_shift; +#endif + __m256 acc[nrc_y] = {}; + __m256i isum[nrc_y] = {}; + __m256i qx[4]; + for (int ix = 0; ix < nrc_x; ix += 4) { + auto dptr = (const float *)((const char *)vx + (ix+0)*bx); + const block_iq5_ks_r4 * iq5 = (const block_iq5_ks_r4 *)(dptr + 4); + auto d4 = _mm_loadu_ps(dptr); + for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256 + auto scales = _mm256_loadu_si256((const __m256i *)iq5[ibl].scales); + h.vec = _mm256_sub_epi8(_mm256_and_si256(scales, _mm256_set1_epi8(-2)), _mm256_set1_epi8(127)); +#ifndef HAVE_FANCY_SIMD + h_shift.vec = _mm256_slli_epi16(_mm256_and_si256(scales, _mm256_set1_epi8(1)), 1); + { + __m256 v1 = _mm256_mul_ps(_mm256_cvtepi32_ps(MM256_SET_M128I(_mm_cvtepi8_epi32(_mm_set1_epi32(h.val[4])), _mm_cvtepi8_epi32(_mm_set1_epi32(h.val[0])))), + _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_cvtepi8_epi32(_mm_set1_epi32(h_shift.val[4])), _mm_cvtepi8_epi32(_mm_set1_epi32(h_shift.val[0]))))); + __m256 v2 = _mm256_mul_ps(_mm256_cvtepi32_ps(MM256_SET_M128I(_mm_cvtepi8_epi32(_mm_set1_epi32(h.val[5])), _mm_cvtepi8_epi32(_mm_set1_epi32(h.val[1])))), + _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_cvtepi8_epi32(_mm_set1_epi32(h_shift.val[5])), _mm_cvtepi8_epi32(_mm_set1_epi32(h_shift.val[1]))))); + __m256 v3 = _mm256_mul_ps(_mm256_cvtepi32_ps(MM256_SET_M128I(_mm_cvtepi8_epi32(_mm_set1_epi32(h.val[6])), _mm_cvtepi8_epi32(_mm_set1_epi32(h.val[2])))), + _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_cvtepi8_epi32(_mm_set1_epi32(h_shift.val[6])), _mm_cvtepi8_epi32(_mm_set1_epi32(h_shift.val[2]))))); + __m256 v4 = _mm256_mul_ps(_mm256_cvtepi32_ps(MM256_SET_M128I(_mm_cvtepi8_epi32(_mm_set1_epi32(h.val[7])), _mm_cvtepi8_epi32(_mm_set1_epi32(h.val[3])))), + _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_cvtepi8_epi32(_mm_set1_epi32(h_shift.val[7])), _mm_cvtepi8_epi32(_mm_set1_epi32(h_shift.val[3]))))); + for (int iy = 0; iy < nrc_y; ++iy) { + auto m8 = _mm256_loadu_ps((const float *)q8.y[iy][ibl].bsums); + acc[iy] = _mm256_fmadd_ps(v1, _mm256_shuffle_ps(m8, m8, 0x00), acc[iy]); + acc[iy] = _mm256_fmadd_ps(v2, _mm256_shuffle_ps(m8, m8, 0x55), acc[iy]); + acc[iy] = _mm256_fmadd_ps(v3, _mm256_shuffle_ps(m8, m8, 0xaa), acc[iy]); + acc[iy] = _mm256_fmadd_ps(v4, _mm256_shuffle_ps(m8, m8, 0xff), acc[iy]); + } + } +#else + auto shift = _mm256_add_epi8(_mm256_set1_epi8(-64), _mm256_and_si256(scales, _mm256_set1_epi8(1))); + h_shift.vec = _mm512_mullo_epi16(_mm512_cvtepi8_epi16(shift), _mm512_cvtepi8_epi16(h.vec)); +#endif + for (int ib = 0; ib < QK_K/32; ++ib) { +#ifdef HAVE_FANCY_SIMD + auto iscales = _mm256_cvtepi8_epi32(_mm_set1_epi32(h.val[ib])); + auto ishifts = _mm256_cvtepi16_epi32(_mm_set1_epi64x(h_shift.val[ib])); + auto scales_m = _mm256_cvtepi32_ps(ishifts); + for (int iy = 0; iy < nrc_y; ++iy) { + float m8 = ((const float *)q8.y[iy][ibl].bsums)[ib]; + acc[iy] = _mm256_fmadd_ps(scales_m, _mm256_set1_ps(m8), acc[iy]); + } +#endif + auto lbits1 = _mm256_loadu_si256((const __m256i *)iq5[ibl].qs+2*ib+0); + auto lbits2 = _mm256_loadu_si256((const __m256i *)iq5[ibl].qs+2*ib+1); + auto hbits = _mm_loadu_si128((const __m128i *)iq5[ibl].qh+ib); + auto hb = MM256_SET_M128I(_mm_srli_epi16(hbits, 2), hbits); + qx[0] = _mm256_and_si256(lbits1, m4); + qx[1] = _mm256_and_si256(lbits2, m4); + qx[2] = _mm256_and_si256(_mm256_srli_epi16(lbits1, 4), m4); + qx[3] = _mm256_and_si256(_mm256_srli_epi16(lbits2, 4), m4); + + qx[0] = prepare_5bit_quants(values, qx[0], hb, _mm256_set1_epi8(0x01)); + qx[1] = prepare_5bit_quants(values, qx[1], hb, _mm256_set1_epi8(0x10)); + qx[2] = prepare_5bit_quants(values, qx[2], hb, _mm256_set1_epi8(0x02)); + qx[3] = prepare_5bit_quants(values, qx[3], hb, _mm256_set1_epi8(0x20)); + +#ifndef HAVE_FANCY_SIMD + auto iscales = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm_set1_epi32(h.val[ib])), s_shuffle); + auto s1 = _mm256_sign_epi8(qx[0], qx[0]); + auto s2 = _mm256_sign_epi8(qx[1], qx[1]); + auto s3 = _mm256_sign_epi8(qx[2], qx[2]); + auto s4 = _mm256_sign_epi8(qx[3], qx[3]); +#endif + for (int iy = 0; iy < nrc_y; ++iy) { + auto y = _mm256_loadu_si256((const __m256i*)q8.y[iy][ibl].qs+ib); +#ifdef HAVE_FANCY_SIMD + auto sumi = _mm256_setzero_si256(); + sumi = _mm256_dpbusd_epi32(sumi, qx[0], _mm256_shuffle_epi32(y, 0x00)); + sumi = _mm256_dpbusd_epi32(sumi, qx[1], _mm256_shuffle_epi32(y, 0x55)); + sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(y, 0xaa)); + sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(y, 0xff)); + isum[iy] = _mm256_add_epi32(isum[iy], _mm256_mullo_epi32(iscales, sumi)); +#else + auto sumi1 = _mm256_maddubs_epi16(s1, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), qx[0])); + auto sumi2 = _mm256_maddubs_epi16(s2, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), qx[1])); + auto sumi3 = _mm256_maddubs_epi16(s3, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), qx[2])); + auto sumi4 = _mm256_maddubs_epi16(s4, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), qx[3])); + isum[iy] = _mm256_add_epi32(isum[iy], _mm256_add_epi32(_mm256_madd_epi16(iscales, sumi1), _mm256_madd_epi16(iscales, sumi2))); + isum[iy] = _mm256_add_epi32(isum[iy], _mm256_add_epi32(_mm256_madd_epi16(iscales, sumi3), _mm256_madd_epi16(iscales, sumi4))); +#endif + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + acc[iy] = _mm256_fmadd_ps(_mm256_set1_ps(q8.scale(iy, ibl)), _mm256_cvtepi32_ps(isum[iy]), acc[iy]); + isum[iy] = _mm256_setzero_si256(); + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1)); + acc[iy] = _mm256_setzero_ps(); + info.store(ix+0, iy, _mm_mul_ps(d4, sum)); + } + } +} + + +template <typename Dequantizer> void set_functions(std::array<mul_mat_t, IQK_MAX_NY>& funcs) { +#ifdef HAVE_FANCY_SIMD + if constexpr (std::is_same_v<Dequantizer, DequantizerIQ2KS> || + std::is_same_v<Dequantizer, DequantizerIQ4KS> || + std::is_same_v<Dequantizer, DequantizerIQ5KS>) { + IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_iqX_k_q8_K_AVX512_new, Dequantizer, funcs) + } else if constexpr (std::is_same_v<Dequantizer, DequantizerIQ2K>) { + IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_K_q8_K_AVX512, Dequantizer, funcs); + funcs[0] = mul_mat_qX_K_q8_K_AVX512_1<Dequantizer>; + } else { + IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_iqX_k_q8_K_AVX512, Dequantizer, funcs); + } +#else + if constexpr (std::is_same_v<Dequantizer, DequantizerIQ2K>|| + std::is_same_v<Dequantizer, DequantizerIQ3K>|| + std::is_same_v<Dequantizer, DequantizerIQ4K>|| + std::is_same_v<Dequantizer, DequantizerIQ5K>|| + std::is_same_v<Dequantizer, DequantizerIQ6K>) { + IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qY_K_q8_K_T, Dequantizer, funcs); + } else { + IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_K_q8_K_T, Dequantizer, funcs); + } + +#endif +} + +} // namespace + +bool iqk_set_kernels_iqk_quants(int ne00, int typeA, int typeB, std::array<mul_mat_t, IQK_MAX_NY>& kernels, mul_mat_t& func16) { + + auto etypeA = ggml_type(typeA); + auto expected_type_B = etypeA == GGML_TYPE_IQ4_KS_R4 || etypeA == GGML_TYPE_IQ5_KS_R4 ? GGML_TYPE_Q8_K32 : GGML_TYPE_Q8_K; + if (ne00%QK_K != 0 || ggml_type(typeB) != expected_type_B) { + return false; + } + + func16 = nullptr; + + switch (typeA) { + case GGML_TYPE_IQ2_KS: + set_functions<DequantizerIQ2KS>(kernels); + break; + case GGML_TYPE_IQ2_K: + set_functions<DequantizerIQ2K>(kernels); + break; + case GGML_TYPE_IQ3_K: + set_functions<DequantizerIQ3K>(kernels); + break; + case GGML_TYPE_IQ4_KSS: + set_functions<DequantizerIQ4KSS>(kernels); + break; + case GGML_TYPE_IQ4_KS: + set_functions<DequantizerIQ4KS>(kernels); + break; + case GGML_TYPE_IQ4_K: + set_functions<DequantizerIQ4K>(kernels); + break; + case GGML_TYPE_IQ5_KS: + set_functions<DequantizerIQ5KS>(kernels); + break; + case GGML_TYPE_IQ5_K: + set_functions<DequantizerIQ5K>(kernels); + break; + case GGML_TYPE_IQ6_K: + set_functions<DequantizerIQ6K>(kernels); + break; + case GGML_TYPE_IQ2_K_R4: + IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq2_k_r4_q8_k, kernels); + break; + case GGML_TYPE_IQ3_K_R4: + IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq3_k_r4_q8_k, kernels); +#ifdef HAVE_FANCY_SIMD + func16 = mul_mat_iq3_k_r4_q8_k<16>; +#endif + break; + case GGML_TYPE_IQ4_K_R4: + IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq4_k_r4_q8_k, kernels); + func16 = mul_mat_iq4_k_r4_q8_k<16>; + break; + case GGML_TYPE_IQ4_KS_R4: + IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq4_ks_r4_q8_k, kernels); +#ifndef HAVE_FANCY_SIMD + // For some reason Zen4 does not like this particular function + func16 = mul_mat_iq4_ks_r4_q8_k<16>; +#endif + break; + case GGML_TYPE_IQ5_KS_R4: + IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq5_ks_r4_q8_k, kernels); +#ifndef HAVE_FANCY_SIMD + // For some reason Zen4 does not like this particular function + func16 = mul_mat_iq5_ks_r4_q8_k<16>; +#endif + break; + case GGML_TYPE_IQ5_K_R4: + IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq5_k_r4_q8_k, kernels); + func16 = mul_mat_iq5_k_r4_q8_k<16>; + break; + default: + return false; + } + + return true; + +} + +#else +// ----------------------------------------- __aarch64__ --------------------------------------------- + +namespace { + +inline int32x4x4_t make_wider(const int16x8x2_t& scales16) { + int32x4x4_t scales = { + vmovl_s16(vget_low_s16 (scales16.val[0])), + vmovl_s16(vget_high_s16(scales16.val[0])), + vmovl_s16(vget_low_s16 (scales16.val[1])), + vmovl_s16(vget_high_s16(scales16.val[1])), + }; + return scales; +} + +inline int32x4x4_t make_wider_8(const int8x16_t& scales8) { + int16x8x2_t scales16{vmovl_s8(vget_low_s8(scales8)), vmovl_s8(vget_high_s8(scales8))}; + return make_wider(scales16); +} + +template <typename Q8> +inline void accum_mins_16(const int16x8x2_t& mins, const Q8& q8, float32x4_t * acc, int i, float c) { + for (int iy = 0; iy < Q8::nrc_y; ++iy) { + auto q8s = q8.load_bsums(iy, i); + int32x4_t b1 = vmull_s16(vget_low_s16 (mins.val[0]), vget_low_s16 (q8s.val[0])); + int32x4_t b2 = vmull_s16(vget_high_s16(mins.val[0]), vget_high_s16(q8s.val[0])); + int32x4_t b3 = vmull_s16(vget_low_s16 (mins.val[1]), vget_low_s16 (q8s.val[1])); + int32x4_t b4 = vmull_s16(vget_high_s16(mins.val[1]), vget_high_s16(q8s.val[1])); + float32x4_t prod = vcvtq_f32_s32(vaddq_s32(vaddq_s32(b1, b2), vaddq_s32(b3, b4))); + acc[iy] = vmlaq_f32(acc[iy], prod, vdupq_n_f32(c*q8.scale(iy, i))); + } +} + +struct Scale16Extra { + template <typename Q8> + static inline int32x4x4_t new_block(int i, float d, uint16_t extra, uint8_t val, + const int8x16_t& scales8, const Q8& q8, float32x4_t * acc) { + uint8x16_t e8 = vreinterpretq_u8_u16(vdupq_n_u16(extra)); + e8 = vceqq_u8(vandq_u8(e8, emask), emask); + e8 = vqtbl1q_u8(vandq_u8(e8, vdupq_n_u8(val)), eshuff); + int16x8x2_t extra16 = {vmull_s8(vget_low_s8 (e8), vget_low_s8 (scales8)), + vmull_s8(vget_high_s8(e8), vget_high_s8(scales8))}; + accum_mins_16(extra16, q8, acc, i, d); + return make_wider_8(scales8); + } + + constexpr static uint32x4_t emask = {0x02020101, 0x08080404, 0x20201010, 0x80804040}; + constexpr static uint32x4_t eshuff = {0x06040200, 0x0e0c0a08, 0x07050301, 0x0f0d0b09}; +}; + +// Note: on ARM_NEON we cannot use the values shifted into the uint8_t range because +// the ARM_NEON only has vdotq_s32 or vdotq_u32, where both operands need to +// be signed or unsigned. As the Q8_K quants are signed, we need to have the +// iq4_s quants also signed. We can only use unsigned values in k-quants +// because they are all within the valid int8_t range. +struct DequantizerIQ4K final : public BaseDequantizer<block_iq4_k> { + DequantizerIQ4K(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc), values(vld1q_s8(iq4k_values)) {} + + constexpr static int num_blocks() { return 16; } + constexpr static bool should_scale_quants() { return false; } + + template <typename Q8> + inline int32x4x4_t new_block(int i, const Q8& q8, float32x4_t * acc) { + d = GGML_FP16_TO_FP32(x[i].d); + return Scale16Extra::new_block(i, d, x[i].extra, 4, make_scales(x[i].scales_l, x[i].scales_h), q8, acc); + } + inline void prepare(int i, int j) { + bits.prepare16(x[i].qs+64*j); + for (int k = 0; k < 4; ++k) { + bits.b1.val[k] = vqtbl1q_s8(values, bits.b1.val[k]); + bits.b2.val[k] = vqtbl1q_s8(values, bits.b2.val[k]); + } + } + inline int8x16_t make_scales(const uint8_t * scales_l, const uint8_t * scales_h) const { + uint8x8_t aux = vld1_u8(scales_l); + uint8x16_t scl8 = vandq_u8(vcombine_u8(aux, vshr_n_u8(aux, 4)), vdupq_n_u8(0xf)); + const uint32_t * aux32 = (const uint32_t *)scales_h; + uint32x4_t sch_32 = {aux32[0] << 4, aux32[0] << 2, aux32[0], aux32[0] >> 2}; + uint8x16_t sch8 = vandq_u8(vreinterpretq_u8_u32(sch_32), vdupq_n_u8(0x30)); + int8x16_t scales8 = vorrq_u8(scl8, vqtbl1q_u8(sch8, hshuff)); + return vaddq_s8(vqtbl1q_s8(scales8, hshuff), vdupq_n_s8(-32)); + } + + Q4bits bits; + const int8x16_t values; + const uint8x16_t hshuff = vreinterpretq_u8_u32(uint32x4_t{0x09010800, 0x0b030a02, 0x0d050c04, 0x0f070e06}); + +}; + +struct DequantizerIQ5K final : public BaseDequantizer<block_iq5_k> { + DequantizerIQ5K(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc), values(vld1q_s8_x2(iq5nl_values)) {} + + constexpr static int num_blocks() { return 16; } + constexpr static bool should_scale_quants() { return false; } + + template <typename Q8> + inline int32x4x4_t new_block(int i, const Q8& q8, float32x4_t * acc) { + d = GGML_FP16_TO_FP32(x[i].d); + hbits = vld1q_u8_x2(x[i].qh); // hbits.val[0] holds 0....15, 32...47, 64...79, 96...111, 128...143, 160...175, 192...207, 224...239 + // hbits.val[1] holds 16...31, 48...63, 80...95, 112..127, 144...159, 176...191, 208...223, 240...255 + return Scale16Extra::new_block(i, d, x[i].extra, 2, make_scales(x[i].scales_l, x[i].scales_h), q8, acc); + } + inline void prepare(int i, int j) { + bits.prepare(x[i].qs+64*j); + if (j == 1) { + for (int k = 0; k < 2; ++k) hbits.val[k] = vshrq_n_u8(hbits.val[k], 4); + } + bits.b1.val[0] = vorrq_u8(bits.b1.val[0], vandq_u8(vshlq_n_u8(hbits.val[0], 4), hm)); + bits.b1.val[1] = vorrq_u8(bits.b1.val[1], vandq_u8(vshlq_n_u8(hbits.val[1], 4), hm)); + bits.b1.val[2] = vorrq_u8(bits.b1.val[2], vandq_u8(vshlq_n_u8(hbits.val[0], 3), hm)); + bits.b1.val[3] = vorrq_u8(bits.b1.val[3], vandq_u8(vshlq_n_u8(hbits.val[1], 3), hm)); + bits.b2.val[0] = vorrq_u8(bits.b2.val[0], vandq_u8(vshlq_n_u8(hbits.val[0], 2), hm)); + bits.b2.val[1] = vorrq_u8(bits.b2.val[1], vandq_u8(vshlq_n_u8(hbits.val[1], 2), hm)); + bits.b2.val[2] = vorrq_u8(bits.b2.val[2], vandq_u8(vshlq_n_u8(hbits.val[0], 1), hm)); + bits.b2.val[3] = vorrq_u8(bits.b2.val[3], vandq_u8(vshlq_n_u8(hbits.val[1], 1), hm)); + for (int k = 0; k < 4; ++k) { + bits.b1.val[k] = vqtbl2q_s8(values, bits.b1.val[k]); + bits.b2.val[k] = vqtbl2q_s8(values, bits.b2.val[k]); + } + } + inline int8x16_t make_scales(const uint8_t * scales_l, const uint8_t * scales_h) const { + uint8x8_t aux = vld1_u8(scales_l); + uint8x16_t scl8 = vandq_u8(vcombine_u8(aux, vshr_n_u8(aux, 4)), vdupq_n_u8(0xf)); + const uint32_t * aux32 = (const uint32_t *)scales_h; + uint32x4_t sch_32 = {aux32[0] << 4, aux32[0] << 2, aux32[0], aux32[0] >> 2}; + uint8x16_t sch8 = vandq_u8(vreinterpretq_u8_u32(sch_32), vdupq_n_u8(0x30)); + int8x16_t scales8 = vorrq_u8(scl8, vqtbl1q_u8(sch8, hshuff)); + return vaddq_s8(vqtbl1q_s8(scales8, hshuff), vdupq_n_s8(-32)); + } + + Q4bits bits; + const int8x16x2_t values; + const uint8x16_t hshuff = vreinterpretq_u8_u32(uint32x4_t{0x09010800, 0x0b030a02, 0x0d050c04, 0x0f070e06}); + const uint8x16_t hm = vdupq_n_u8(0x10); + uint8x16x2_t hbits; + +}; + +struct DequantizerIQ6K final : public BaseDequantizer<block_iq6_k> { + DequantizerIQ6K(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc), values(vld1q_s8_x4(iq6nl_values)) {} + + constexpr static int num_blocks() { return 16; } + constexpr static bool should_scale_quants() { return false; } + + template <typename Q8> + inline int32x4x4_t new_block(int i, const Q8& q8, float32x4_t * acc) { + d = GGML_FP16_TO_FP32(x[i].d); + return Scale16Extra::new_block(i, d, x[i].extra, 1, vld1q_s8(x[i].scales), q8, acc); + } + inline void prepare(int i, int j) { + bits.prepare(x[i].qs+64*j); + auto hbits = vld1q_u8_x2(x[i].qh + 32*j); + bits.b1.val[0] = vorrq_u8(bits.b1.val[0], vandq_u8(vshlq_n_u8(hbits.val[0], 4), hm)); + bits.b1.val[1] = vorrq_u8(bits.b1.val[1], vandq_u8(vshlq_n_u8(hbits.val[1], 4), hm)); + bits.b1.val[2] = vorrq_u8(bits.b1.val[2], vandq_u8(vshlq_n_u8(hbits.val[0], 2), hm)); + bits.b1.val[3] = vorrq_u8(bits.b1.val[3], vandq_u8(vshlq_n_u8(hbits.val[1], 2), hm)); + bits.b2.val[0] = vorrq_u8(bits.b2.val[0], vandq_u8(hbits.val[0], hm)); + bits.b2.val[1] = vorrq_u8(bits.b2.val[1], vandq_u8(hbits.val[1], hm)); + bits.b2.val[2] = vorrq_u8(bits.b2.val[2], vandq_u8(vshrq_n_u8(hbits.val[0], 2), hm)); + bits.b2.val[3] = vorrq_u8(bits.b2.val[3], vandq_u8(vshrq_n_u8(hbits.val[1], 2), hm)); + for (int k = 0; k < 4; ++k) { + bits.b1.val[k] = vqtbl4q_s8(values, bits.b1.val[k]); + bits.b2.val[k] = vqtbl4q_s8(values, bits.b2.val[k]); + } + } + + Q4bits bits; + const int8x16x4_t values; + const uint8x16_t hm = vdupq_n_u8(0x30); + +}; + +struct DequantizerIQ2K final : public BaseDequantizer<block_iq2_k> { + DequantizerIQ2K(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {} + + constexpr static int num_blocks() { return 16; } + constexpr static bool should_scale_quants() { return false; } + + template <typename Q8> + inline int32x4x4_t new_block(int i, const Q8& q8, float32x4_t * acc) { + d = GGML_FP16_TO_FP32(x[i].d); + return Scale16Extra::new_block(i, d, x[i].extra, 5, make_scales(x[i].scales), q8, acc); + } + inline void prepare(int i, int j) { + bits.prepare(x[i].qs+32*j); + for (int k = 0; k < 4; ++k) { + bits.b1.val[k] = vqtbl1q_s8(values, bits.b1.val[k]); + bits.b2.val[k] = vqtbl1q_s8(values, bits.b2.val[k]); + } + } + inline int8x16_t make_scales(const uint8_t * scales_l) const { + uint8x8_t aux = vld1_u8(scales_l); + uint8x16_t scl8 = vandq_u8(vcombine_u8(aux, vshr_n_u8(aux, 4)), vdupq_n_u8(0xf)); + int8x16_t scales = vaddq_s8(vreinterpretq_s8_u8(scl8), vdupq_n_s8(-8)); + return vqtbl1q_s8(scales, hshuff); + } + + Q2bits bits; + const int8x16_t values = vreinterpretq_s8_u64(vdupq_n_u64(0x000000001101f3e1)); + const uint8x16_t hshuff = vreinterpretq_u8_u32(uint32x4_t{0x09010800, 0x0b030a02, 0x0d050c04, 0x0f070e06}); + +}; + +struct DequantizerIQ3K final : public BaseDequantizer<block_iq3_k> { + DequantizerIQ3K(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {} + + constexpr static int num_blocks() { return 16; } + constexpr static bool should_scale_quants() { return false; } + + template <typename Q8> + inline int32x4x4_t new_block(int i, const Q8& q8, float32x4_t * acc) { + d = GGML_FP16_TO_FP32(x[i].d); + return Scale16Extra::new_block(i, d, x[i].extra, 4, make_scales(x[i].scales_h, x[i].scales_l), q8, acc); + } + inline void prepare(int i, int j) { + bits.prepare(x[i].qs+32*j); + if (j == 0) { + hbits = vld1q_u8_x2(x[i].qh); + } + else { + hbits.val[0] = vshrq_n_u8(hbits.val[0], 4); + hbits.val[1] = vshrq_n_u8(hbits.val[1], 4); + } + bits.b1.val[0] = vorrq_u8(bits.b1.val[0], vandq_u8(vshlq_n_u8(hbits.val[0], 2), hmask)); + bits.b1.val[1] = vorrq_u8(bits.b1.val[1], vandq_u8(vshlq_n_u8(hbits.val[1], 2), hmask)); + bits.b1.val[2] = vorrq_u8(bits.b1.val[2], vandq_u8(vshlq_n_u8(hbits.val[0], 1), hmask)); + bits.b1.val[3] = vorrq_u8(bits.b1.val[3], vandq_u8(vshlq_n_u8(hbits.val[1], 1), hmask)); + bits.b2.val[0] = vorrq_u8(bits.b2.val[0], vandq_u8(hbits.val[0], hmask)); + bits.b2.val[1] = vorrq_u8(bits.b2.val[1], vandq_u8(hbits.val[1], hmask)); + bits.b2.val[2] = vorrq_u8(bits.b2.val[2], vandq_u8(vshrq_n_u8(hbits.val[0], 1), hmask)); + bits.b2.val[3] = vorrq_u8(bits.b2.val[3], vandq_u8(vshrq_n_u8(hbits.val[1], 1), hmask)); + for (int k = 0; k < 4; ++k) { + bits.b1.val[k] = vqtbl1q_s8(values, bits.b1.val[k]); + bits.b2.val[k] = vqtbl1q_s8(values, bits.b2.val[k]); + } + } + inline int8x16_t make_scales(uint16_t sign_bits, const uint8_t * scales_l) const { + uint8x8_t aux = vld1_u8(scales_l); + uint8x16_t scl8 = vandq_u8(vcombine_u8(aux, vshr_n_u8(aux, 4)), vdupq_n_u8(0xf)); + int8x16_t scales = vaddq_s8(vreinterpretq_s8_u8(vshlq_n_u8(scl8, 1)), vdupq_n_s8(1)); + uint8x16_t signs = vceqq_u8(vandq_u8(vreinterpretq_u8_u16(vdupq_n_u16(sign_bits)), sign_mask), sign_mask); + signs = vorrq_u8(signs, vdupq_n_u8(1)); + // scales are 0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7, 9, 11, 13, 15 + // signs are 0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15 + scales = vmulq_s8(scales, vreinterpretq_s8_u8(vqtbl1q_u8(signs, sign_shuffle))); + return vqtbl1q_s8(scales, hshuff); + } + inline static uint8x16_t load_sign_shuffle() { + static uint8_t k_shuff[16] = {0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15}; + return vld1q_u8(k_shuff); + } + + Q2bits bits; + uint8x16x2_t hbits; + const int8x16_t values = vreinterpretq_s8_u64(vdupq_n_u64(0x2f1c0d01f6e9d8c1)); + const uint8x16_t hshuff = vreinterpretq_u8_u32(uint32x4_t{0x09010800, 0x0b030a02, 0x0d050c04, 0x0f070e06}); + const uint8x16_t hmask = vdupq_n_u8(4); + const uint8x16_t sign_mask = vreinterpretq_u8_u64(uint64x2_t{0x0808040402020101, 0x8080404020201010}); + const uint8x16_t sign_shuffle = load_sign_shuffle(); + +}; + +struct DequantizerIQ4KS final : public BaseDequantizer<block_iq4_ks, true> { + + DequantizerIQ4KS(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc), values(vld1q_s8_x2(iq4k_values)) {} + + constexpr static int num_blocks() { return 8; } + constexpr static bool should_scale_quants() { return false; } + + template <typename Q8> + inline int32x4x2_t new_block(int i, const Q8& q8, float32x4_t * acc) { + (void)q8; + (void)acc; + auto scales16 = vaddq_s16(vreinterpretq_s16_u16(vandq_u16(vmovl_u8(vld1_u8(x[i].scales)), mask)), m127); + int32x4x2_t scales = {vmovl_s16(vget_low_s16(scales16)), vmovl_s16(vget_high_s16(scales16))}; + return scales; + } + inline void prepare(int i, int j) { + bits.prepare16(x[i].qs+64*j); + const uint32_t * scales32 = (const uint32_t *)x[i].scales; + uint32_t aux32 = scales32[j] & 0x01010101; + const uint8_t * aux8 = (const uint8_t *)&aux32; + for (int k = 0; k < 4; ++k) { + bits.b1.val[k] = vreinterpretq_u8_s8(vqtbl1q_s8(values.val[aux8[k/2+0]], bits.b1.val[k])); + bits.b2.val[k] = vreinterpretq_u8_s8(vqtbl1q_s8(values.val[aux8[k/2+2]], bits.b2.val[k])); + } + } + + Q4bits bits; + const int8x16x2_t values; + const uint16x8_t mask = vdupq_n_u16(254); + const int16x8_t m127 = vdupq_n_s16(-127); +}; + +struct DequantizerIQ5KS final : public BaseDequantizer<block_iq5_ks, true> { + DequantizerIQ5KS(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc), + values(vld1q_s8_x4(iq5nl_values)) {} + + constexpr static int num_blocks() { return 8; } + constexpr static bool should_scale_quants() { return false; } + + template <typename Q8> + inline int32x4x2_t new_block(int i, const Q8& q8, float32x4_t * acc) { + (void)q8; + (void)acc; + auto sas8 = vld1_u8(x[i].scales); + auto scales16 = vaddq_s16(vreinterpretq_s16_u16(vandq_u16(vmovl_u8(sas8), mask)), m127); + hbits = vld1q_u8_x2(x[i].qh); + sas = vcombine_u8(sas8, sas8); + sas = vshlq_n_u8(vandq_u8(sas, vdupq_n_u8(1)), 5); + int32x4x2_t scales = {vmovl_s16(vget_low_s16(scales16)), vmovl_s16(vget_high_s16(scales16))}; + return scales; + } + + inline void prepare(int i, int j) { + bits.prepare(x[i].qs+64*j); + if (j == 1) { + for (int k = 0; k < 2; ++k) hbits.val[k] = vshrq_n_u8(hbits.val[k], 4); + } + auto shift = vdupq_n_u8((x[i].scales[4*j+0] & 1) << 5); + bits.b1.val[0] = vaddq_u8(shift, vorrq_u8(bits.b1.val[0], vandq_u8(vshlq_n_u8(hbits.val[0], 4), hm))); + bits.b1.val[1] = vaddq_u8(shift, vorrq_u8(bits.b1.val[1], vandq_u8(vshlq_n_u8(hbits.val[1], 4), hm))); + shift = vdupq_n_u8((x[i].scales[4*j+1] & 1) << 5); + bits.b1.val[2] = vaddq_u8(shift, vorrq_u8(bits.b1.val[2], vandq_u8(vshlq_n_u8(hbits.val[0], 3), hm))); + bits.b1.val[3] = vaddq_u8(shift, vorrq_u8(bits.b1.val[3], vandq_u8(vshlq_n_u8(hbits.val[1], 3), hm))); + for (int k = 0; k < 4; ++k) bits.b1.val[k] = vqtbl4q_s8(values, bits.b1.val[k]); + shift = vdupq_n_u8((x[i].scales[4*j+2] & 1) << 5); + bits.b2.val[0] = vaddq_u8(shift, vorrq_u8(bits.b2.val[0], vandq_u8(vshlq_n_u8(hbits.val[0], 2), hm))); + bits.b2.val[1] = vaddq_u8(shift, vorrq_u8(bits.b2.val[1], vandq_u8(vshlq_n_u8(hbits.val[1], 2), hm))); + shift = vdupq_n_u8((x[i].scales[4*j+3] & 1) << 5); + bits.b2.val[2] = vaddq_u8(shift, vorrq_u8(bits.b2.val[2], vandq_u8(vshlq_n_u8(hbits.val[0], 1), hm))); + bits.b2.val[3] = vaddq_u8(shift, vorrq_u8(bits.b2.val[3], vandq_u8(vshlq_n_u8(hbits.val[1], 1), hm))); + for (int k = 0; k < 4; ++k) bits.b2.val[k] = vqtbl4q_s8(values, bits.b2.val[k]); + } + + Q4bits bits; + const int8x16x4_t values; + const uint8x16_t hm = vdupq_n_u8(0x10); + const uint16x8_t mask = vdupq_n_u16(254); + const int16x8_t m127 = vdupq_n_s16(-127); + uint8x16x2_t hbits; + uint8x16_t sas; + +}; + +struct DequantizerIQ4KSS final : public BaseDequantizer<block_iq4_kss, true> { + + DequantizerIQ4KSS(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc), values(vld1q_s8_x2(iq4k_values)) {} + + constexpr static int num_blocks() { return 8; } + constexpr static bool should_scale_quants() { return false; } + + template <typename Q8> + inline int32x4x2_t new_block(int i, const Q8& q8, float32x4_t * acc) { + (void)q8; + (void)acc; + auto q4bits_1 = vld1q_u16_x4((const uint16_t *)x[i].qs); + q4bits_2 = vld1q_u16_x4((const uint16_t *)x[i].qs + 32); + for (int k = 0; k < 4; ++k) { + aux[k+0] = vaddvq_s16(vshlq_s16(vandq_u16(q4bits_1.val[k], m1), shift)); + aux[k+4] = vaddvq_s16(vshlq_s16(vandq_u16(q4bits_2.val[k], m1), shift)); + q4bits_1.val[k] = vandq_u16(q4bits_1.val[k], bmask); + q4bits_1.val[k] = veorq_u16(q4bits_1.val[k], vshrq_n_u16(q4bits_1.val[k], 1)); + q4bits_2.val[k] = vandq_u16(q4bits_2.val[k], bmask); + q4bits_2.val[k] = veorq_u16(q4bits_2.val[k], vshrq_n_u16(q4bits_2.val[k], 1)); + } + make_quants(q4bits_1, bits, aux); + auto scales16 = vld1q_s16(aux); + scales16 = vaddq_s16(vandq_s16(scales16, mask), m127); + int32x4x2_t scales = {vmovl_s16(vget_low_s16(scales16)), vmovl_s16(vget_high_s16(scales16))}; + return scales; + } + inline void make_quants(uint16x8x4_t& q4bits, Q4bits& bits, const int16_t * aux) const { + bits.b1.val[0] = vqtbl1q_s8(values.val[aux[0] & 1], vandq_u8(q4bits.val[0], bits.m4b)); + bits.b1.val[1] = vqtbl1q_s8(values.val[aux[0] & 1], vshrq_n_u8(q4bits.val[0], 4)); + bits.b1.val[2] = vqtbl1q_s8(values.val[aux[1] & 1], vandq_u8(q4bits.val[1], bits.m4b)); + bits.b1.val[3] = vqtbl1q_s8(values.val[aux[1] & 1], vshrq_n_u8(q4bits.val[1], 4)); + bits.b2.val[0] = vqtbl1q_s8(values.val[aux[2] & 1], vandq_u8(q4bits.val[2], bits.m4b)); + bits.b2.val[1] = vqtbl1q_s8(values.val[aux[2] & 1], vshrq_n_u8(q4bits.val[2], 4)); + bits.b2.val[2] = vqtbl1q_s8(values.val[aux[3] & 1], vandq_u8(q4bits.val[3], bits.m4b)); + bits.b2.val[3] = vqtbl1q_s8(values.val[aux[3] & 1], vshrq_n_u8(q4bits.val[3], 4)); + } + inline void prepare([[maybe_unused]] int i, int j) { + if (j == 0) return; + make_quants(q4bits_2, bits, aux+4); + } + static int16x8_t load_shift() { + static const int16_t k_shift[8] = {0, 1, 2, 3, 4, 5, 6, 7}; + return vld1q_s16(k_shift); + } + + Q4bits bits; + const int8x16x2_t values; + const uint16x8_t mask = vdupq_n_s16(254); + const uint16x8_t bmask = vdupq_n_u16(0xfffe); + const uint16x8_t m1 = vdupq_n_u16(1); + const int16x8_t shift = load_shift(); + const int16x8_t m127 = vdupq_n_s16(-127); + uint16x8x4_t q4bits_2; + int16_t aux[8]; +}; + +struct DequantizerIQ2KS final : public BaseDequantizer<block_iq2_ks, true, true> { + DequantizerIQ2KS(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {} + + constexpr static int num_blocks() { return 8; } + constexpr static bool should_scale_quants() { return false; } + + template <typename Q8> + inline int32x4x2_t new_block(int i, [[maybe_unused]] const Q8& q8, [[maybe_unused]] float32x4_t * acc) { + const uint16_t * sc16 = (const uint16_t *)x[i].scales; + uint32_t aux32 = sc16[0] | (sc16[1] << 16); + uint8x8_t scales8 = vreinterpret_u8_u32(vdup_n_u32(aux32)); + scales8 = vand_u8(vzip1_u8(scales8, vshr_n_u8(scales8, 4)), vdup_n_u8(0xf)); + uint8x8_t sh = vand_u8(vceq_u8(vand_u8(vdup_n_u8(x[i].extra >> 8), hmask), vdup_n_u8(0)), vdup_n_u8(16)); + int16x8_t scales16 = vmovl_s8(vsub_s8(vreinterpret_s8_u8(scales8), vreinterpret_s8_u8(sh))); + int32x4x2_t scales = {vmovl_s16(vget_low_s16(scales16)), vmovl_s16(vget_high_s16(scales16))}; + return scales; + } + inline void prepare(int i, int j) { + uint8_t extra = x[i].extra >> 4*j; + bits.prepare(x[i].qs+32*j); + bits.b1.val[0] = vqtbl1q_s8(values.val[extra & 1], bits.b1.val[0]); + bits.b1.val[1] = vqtbl1q_s8(values.val[extra & 1], bits.b1.val[1]); extra >>= 1; + bits.b1.val[2] = vqtbl1q_s8(values.val[extra & 1], bits.b1.val[2]); + bits.b1.val[3] = vqtbl1q_s8(values.val[extra & 1], bits.b1.val[3]); extra >>= 1; + bits.b2.val[0] = vqtbl1q_s8(values.val[extra & 1], bits.b2.val[0]); + bits.b2.val[1] = vqtbl1q_s8(values.val[extra & 1], bits.b2.val[1]); extra >>= 1; + bits.b2.val[2] = vqtbl1q_s8(values.val[extra & 1], bits.b2.val[2]); + bits.b2.val[3] = vqtbl1q_s8(values.val[extra & 1], bits.b2.val[3]); + } + + Q2bits bits; + const uint8x8_t hmask = vreinterpret_u8_u64(vdup_n_u64(0x8040201008040201)); + const int8x16x2_t values = { vreinterpretq_s8_u64(vdupq_n_u64(0x1101f3e1)), vreinterpretq_s8_u64(vdupq_n_u64(0x1606f8e6)) }; + +}; + +template <int nrc_y> +void mul_mat_iq4_ks_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + GGML_ASSERT(nrc_x%4 == 0); + Q8<nrc_y, block_q8_K> q8(info); + auto m4 = vdupq_n_u8(0xf); + auto values = vld1q_s8(iq4k_values); + int nbl = n / QK_K; + int8x16_t qx[8]; + int16x8x4_t iscales; + int32x4x4_t scales; + float32x4_t acc[nrc_y] = {}; + int32x4_t isum[nrc_y] = {}; + for (int ix = 0; ix < nrc_x; ix += 4) { + auto dptr = (const float *)((const char *)vx + ix*bx); + auto d4 = vld1q_f32(dptr); + const block_iq4_ks_r4 * iq4 = (const block_iq4_ks_r4 *)(dptr + 4); + for (int ibl = 0; ibl < nbl; ++ibl) { + auto sas = vld1q_u8_x2(iq4[ibl].scales); + auto scale = vandq_u8(sas.val[0], vdupq_n_u8(254)); + iscales.val[0] = vaddq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_low_u8 (scale))), vdupq_n_s16(-127)); + iscales.val[1] = vaddq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(scale))), vdupq_n_s16(-127)); + scale = vandq_u8(sas.val[1], vdupq_n_u8(254)); + iscales.val[2] = vaddq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_low_u8 (scale))), vdupq_n_s16(-127)); + iscales.val[3] = vaddq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(scale))), vdupq_n_s16(-127)); + // Adding the block shifts costs us ~9% in performance drop. + // Is there a better way? + sas.val[0] = vshlq_n_u8(vandq_u8(sas.val[0], vdupq_n_u8(1)), 2); + sas.val[1] = vshlq_n_u8(vandq_u8(sas.val[1], vdupq_n_u8(1)), 2); + { + auto s16_1 = vmulq_s16(iscales.val[0], vmovl_u8(vget_low_u8 (sas.val[0]))); + auto s16_2 = vmulq_s16(iscales.val[1], vmovl_u8(vget_high_u8(sas.val[0]))); + auto s16_3 = vmulq_s16(iscales.val[2], vmovl_u8(vget_low_u8 (sas.val[1]))); + auto s16_4 = vmulq_s16(iscales.val[3], vmovl_u8(vget_high_u8(sas.val[1]))); + for (int iy = 0; iy < nrc_y; ++iy) { + auto bsums = vld1q_s16_x2(q8.y[iy][ibl].bsums); + auto bs = vpaddq_s16(bsums.val[0], bsums.val[1]); + auto b8 = vget_low_s16(bs); + isum[iy] = vmlal_lane_s16(isum[iy], vget_low_s16 (s16_1), b8, 0); + isum[iy] = vmlal_lane_s16(isum[iy], vget_high_s16(s16_1), b8, 1); + isum[iy] = vmlal_lane_s16(isum[iy], vget_low_s16 (s16_2), b8, 2); + isum[iy] = vmlal_lane_s16(isum[iy], vget_high_s16(s16_2), b8, 3); + b8 = vget_high_s16(bs); + isum[iy] = vmlal_lane_s16(isum[iy], vget_low_s16 (s16_3), b8, 0); + isum[iy] = vmlal_lane_s16(isum[iy], vget_high_s16(s16_3), b8, 1); + isum[iy] = vmlal_lane_s16(isum[iy], vget_low_s16 (s16_4), b8, 2); + isum[iy] = vmlal_lane_s16(isum[iy], vget_high_s16(s16_4), b8, 3); + } + } + for (int is = 0; is < 2; ++is) { + scales.val[0] = vmovl_s16(vget_low_s16 (iscales.val[2*is+0])); + scales.val[1] = vmovl_s16(vget_high_s16(iscales.val[2*is+0])); + scales.val[2] = vmovl_s16(vget_low_s16 (iscales.val[2*is+1])); + scales.val[3] = vmovl_s16(vget_high_s16(iscales.val[2*is+1])); + for (int ib = 0; ib < 4; ++ib) { + auto bits = vld1q_u8_x4(iq4[ibl].qs + 256*is + 64*ib); + prepare_iq4_nl_quants(values, m4, bits, qx); + for (int iy = 0; iy < nrc_y; ++iy) { + auto y = vld1q_s8_x2(q8.y[iy][ibl].qs+128*is+32*ib); + auto sumi = interleaved_dotq(qx, y); + isum[iy] = vmlaq_s32(isum[iy], scales.val[ib], sumi); + } + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + acc[iy] = vfmaq_f32(acc[iy], vdupq_n_f32(q8.scale(iy, ibl)), vcvtq_f32_s32(isum[iy])); + isum[iy] = vdupq_n_s32(0); + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + info.store(ix, iy, vmulq_f32(d4, acc[iy])); + acc[iy] = vdupq_n_f32(0.f); + } + } +} + +template <int nrc_y> +void mul_mat_iq5_ks_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + GGML_ASSERT(nrc_x%4 == 0); + Q8<nrc_y, block_q8_K> q8(info); + auto m4 = vdupq_n_u8(0xf); + auto m10 = vdupq_n_u8(0x10); + auto values = vld1q_s8_x2(iq5nl_values); + int nbl = n / QK_K; + int8x16_t qx[8]; + int16x8x4_t iscales; + int32x4x4_t scales; + float32x4_t acc[nrc_y] = {}; + int32x4_t isum[nrc_y] = {}; + for (int ix = 0; ix < nrc_x; ix += 4) { + auto dptr = (const float *)((const char *)vx + ix*bx); + auto d4 = vld1q_f32(dptr); + const block_iq5_ks_r4 * iq5 = (const block_iq5_ks_r4 *)(dptr + 4); + for (int ibl = 0; ibl < nbl; ++ibl) { + auto sas = vld1q_u8_x2(iq5[ibl].scales); + auto scale = vandq_u8(sas.val[0], vdupq_n_u8(254)); + iscales.val[0] = vaddq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_low_u8 (scale))), vdupq_n_s16(-127)); + iscales.val[1] = vaddq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(scale))), vdupq_n_s16(-127)); + scale = vandq_u8(sas.val[1], vdupq_n_u8(254)); + iscales.val[2] = vaddq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_low_u8 (scale))), vdupq_n_s16(-127)); + iscales.val[3] = vaddq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(scale))), vdupq_n_s16(-127)); + // Adding the block shifts costs us ~9% in performance drop. + // Is there a better way? + sas.val[0] = vshlq_n_u8(vandq_u8(sas.val[0], vdupq_n_u8(1)), 1); + sas.val[1] = vshlq_n_u8(vandq_u8(sas.val[1], vdupq_n_u8(1)), 1); + { + auto s16_1 = vmulq_s16(iscales.val[0], vmovl_u8(vget_low_u8 (sas.val[0]))); + auto s16_2 = vmulq_s16(iscales.val[1], vmovl_u8(vget_high_u8(sas.val[0]))); + auto s16_3 = vmulq_s16(iscales.val[2], vmovl_u8(vget_low_u8 (sas.val[1]))); + auto s16_4 = vmulq_s16(iscales.val[3], vmovl_u8(vget_high_u8(sas.val[1]))); + for (int iy = 0; iy < nrc_y; ++iy) { + auto bsums = vld1q_s16_x2(q8.y[iy][ibl].bsums); + auto bs = vpaddq_s16(bsums.val[0], bsums.val[1]); + auto b8 = vget_low_s16(bs); + isum[iy] = vmlal_lane_s16(isum[iy], vget_low_s16 (s16_1), b8, 0); + isum[iy] = vmlal_lane_s16(isum[iy], vget_high_s16(s16_1), b8, 1); + isum[iy] = vmlal_lane_s16(isum[iy], vget_low_s16 (s16_2), b8, 2); + isum[iy] = vmlal_lane_s16(isum[iy], vget_high_s16(s16_2), b8, 3); + b8 = vget_high_s16(bs); + isum[iy] = vmlal_lane_s16(isum[iy], vget_low_s16 (s16_3), b8, 0); + isum[iy] = vmlal_lane_s16(isum[iy], vget_high_s16(s16_3), b8, 1); + isum[iy] = vmlal_lane_s16(isum[iy], vget_low_s16 (s16_4), b8, 2); + isum[iy] = vmlal_lane_s16(isum[iy], vget_high_s16(s16_4), b8, 3); + } + } + for (int is = 0; is < 2; ++is) { + scales.val[0] = vmovl_s16(vget_low_s16 (iscales.val[2*is+0])); + scales.val[1] = vmovl_s16(vget_high_s16(iscales.val[2*is+0])); + scales.val[2] = vmovl_s16(vget_low_s16 (iscales.val[2*is+1])); + scales.val[3] = vmovl_s16(vget_high_s16(iscales.val[2*is+1])); + for (int ib = 0; ib < 4; ++ib) { + auto lbits = vld1q_u8_x4(iq5[ibl].qs + 256*is + 64*ib); + auto hbits = vld1q_u8(iq5[ibl].qh + 64*is + 16*ib); + qx[0] = vorrq_u8(vandq_u8(lbits.val[0], m4), vandq_u8(m10, vshlq_n_u8(hbits, 4))); + qx[1] = vorrq_u8(vandq_u8(lbits.val[1], m4), vandq_u8(m10, vshlq_n_u8(hbits, 2))); + qx[2] = vorrq_u8(vandq_u8(lbits.val[2], m4), vandq_u8(m10, hbits)); + qx[3] = vorrq_u8(vandq_u8(lbits.val[3], m4), vandq_u8(m10, vshrq_n_u8(hbits, 2))); + qx[4] = vorrq_u8(vshrq_n_u8(lbits.val[0], 4), vandq_u8(m10, vshlq_n_u8(hbits, 3))); + qx[5] = vorrq_u8(vshrq_n_u8(lbits.val[1], 4), vandq_u8(m10, vshlq_n_u8(hbits, 1))); + qx[6] = vorrq_u8(vshrq_n_u8(lbits.val[2], 4), vandq_u8(m10, vshrq_n_u8(hbits, 1))); + qx[7] = vorrq_u8(vshrq_n_u8(lbits.val[3], 4), vandq_u8(m10, vshrq_n_u8(hbits, 3))); + for (int l = 0; l < 8; ++l) qx[l] = vqtbl2q_s8(values, qx[l]); + for (int iy = 0; iy < nrc_y; ++iy) { + auto y = vld1q_s8_x2(q8.y[iy][ibl].qs+128*is+32*ib); + auto sumi = interleaved_dotq(qx, y); + isum[iy] = vmlaq_s32(isum[iy], scales.val[ib], sumi); + } + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + acc[iy] = vfmaq_f32(acc[iy], vdupq_n_f32(q8.scale(iy, ibl)), vcvtq_f32_s32(isum[iy])); + isum[iy] = vdupq_n_s32(0); + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + info.store(ix, iy, vmulq_f32(d4, acc[iy])); + acc[iy] = vdupq_n_f32(0.f); + } + } +} + +template <int nrc_y, int k_shift> +inline void iq3_4_add_shift(int ibl, const Q8<nrc_y, block_q8_K>& q8, const int8x16x4_t& i8scales, uint8x16_t extra, + int32x4_t * isum) { + auto ms = vdupq_n_s8(k_shift); + int8x16_t s8_1, s8_2; + if constexpr (k_shift == 5) { + auto m1 = vdupq_n_u8(1); + s8_1 = vmulq_s8(i8scales.val[0], vandq_s8(ms, vceqq_u8(vandq_u8(extra, m1), m1))); extra = vshrq_n_u8(extra, 2); + s8_2 = vmulq_s8(i8scales.val[1], vandq_s8(ms, vceqq_u8(vandq_u8(extra, m1), m1))); extra = vshrq_n_u8(extra, 2); + } else { + if constexpr (k_shift == 4) { + s8_1 = vmulq_s8(i8scales.val[0], vandq_u8(ms, vshlq_n_u8(extra, 2))); + s8_2 = vmulq_s8(i8scales.val[1], vandq_u8(ms, extra)); + } else { + s8_1 = vmulq_s8(i8scales.val[0], vandq_u8(ms, vshlq_n_u8(extra, 1))); + s8_2 = vmulq_s8(i8scales.val[1], vandq_u8(ms, vshrq_n_u8(extra, 1))); + } + } + auto s16_1 = vmovl_s8(vget_low_s8 (s8_1)); + auto s16_2 = vmovl_s8(vget_high_s8(s8_1)); + auto s16_3 = vmovl_s8(vget_low_s8 (s8_2)); + auto s16_4 = vmovl_s8(vget_high_s8(s8_2)); + for (int iy = 0; iy < nrc_y; ++iy) { + auto b8 = vld1_s16(q8.y[iy][ibl].bsums); + isum[iy] = vmlal_lane_s16(isum[iy], vget_low_s16 (s16_1), b8, 0); + isum[iy] = vmlal_lane_s16(isum[iy], vget_high_s16(s16_1), b8, 1); + isum[iy] = vmlal_lane_s16(isum[iy], vget_low_s16 (s16_2), b8, 2); + isum[iy] = vmlal_lane_s16(isum[iy], vget_high_s16(s16_2), b8, 3); + b8 = vld1_s16(q8.y[iy][ibl].bsums+4); + isum[iy] = vmlal_lane_s16(isum[iy], vget_low_s16 (s16_3), b8, 0); + isum[iy] = vmlal_lane_s16(isum[iy], vget_high_s16(s16_3), b8, 1); + isum[iy] = vmlal_lane_s16(isum[iy], vget_low_s16 (s16_4), b8, 2); + isum[iy] = vmlal_lane_s16(isum[iy], vget_high_s16(s16_4), b8, 3); + } + if constexpr (k_shift == 5) { + auto m1 = vdupq_n_u8(1); + s8_1 = vmulq_s8(i8scales.val[2], vandq_s8(ms, vceqq_u8(vandq_u8(extra, m1), m1))); extra = vshrq_n_u8(extra, 2); + s8_2 = vmulq_s8(i8scales.val[3], vandq_s8(ms, vceqq_u8(vandq_u8(extra, m1), m1))); extra = vshrq_n_u8(extra, 2); + } else { + if constexpr (k_shift == 4) { + s8_1 = vmulq_s8(i8scales.val[2], vandq_u8(ms, vshrq_n_u8(extra, 2))); + s8_2 = vmulq_s8(i8scales.val[3], vandq_u8(ms, vshrq_n_u8(extra, 4))); + } else { + s8_1 = vmulq_s8(i8scales.val[2], vandq_u8(ms, vshrq_n_u8(extra, 3))); + s8_2 = vmulq_s8(i8scales.val[3], vandq_u8(ms, vshrq_n_u8(extra, 5))); + } + } + s16_1 = vmovl_s8(vget_low_s8 (s8_1)); + s16_2 = vmovl_s8(vget_high_s8(s8_1)); + s16_3 = vmovl_s8(vget_low_s8 (s8_2)); + s16_4 = vmovl_s8(vget_high_s8(s8_2)); + for (int iy = 0; iy < nrc_y; ++iy) { + auto b8 = vld1_s16(q8.y[iy][ibl].bsums+8); + isum[iy] = vmlal_lane_s16(isum[iy], vget_low_s16 (s16_1), b8, 0); + isum[iy] = vmlal_lane_s16(isum[iy], vget_high_s16(s16_1), b8, 1); + isum[iy] = vmlal_lane_s16(isum[iy], vget_low_s16 (s16_2), b8, 2); + isum[iy] = vmlal_lane_s16(isum[iy], vget_high_s16(s16_2), b8, 3); + b8 = vld1_s16(q8.y[iy][ibl].bsums+12); + isum[iy] = vmlal_lane_s16(isum[iy], vget_low_s16 (s16_3), b8, 0); + isum[iy] = vmlal_lane_s16(isum[iy], vget_high_s16(s16_3), b8, 1); + isum[iy] = vmlal_lane_s16(isum[iy], vget_low_s16 (s16_4), b8, 2); + isum[iy] = vmlal_lane_s16(isum[iy], vget_high_s16(s16_4), b8, 3); + } +} + +template <int nrc_y> +void mul_mat_iq2_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + GGML_ASSERT(nrc_x%4 == 0); + Q8<nrc_y, block_q8_K> q8(info); + auto m4 = vdupq_n_u8(0xf); + auto m03 = vdupq_n_u8(0x03); + auto ms = vdupq_n_u8(4); + uint8x16x2_t shift_shuffle = { + vreinterpretq_u8_u64(uint64x2_t{0x0101010100000000, 0x0303030302020202}), + vreinterpretq_u8_u64(uint64x2_t{0x0505050504040404, 0x0707070706060606}) + }; + auto values8 = vld1_s8(iq2nl_values); + auto values = vcombine_s8(values8, values8); + int nbl = n / QK_K; + int8x16_t qx[4]; + int8x16x4_t i8scales; + int16x8x4_t i16scales; + float32x4_t acc[nrc_y] = {}; + for (int ix = 0; ix < nrc_x; ix += 4) { + const block_iq2_k_r4 * iq2 = (const block_iq2_k_r4 *)((const char *)vx + ix*bx); + for (int ibl = 0; ibl < nbl; ++ibl) { + auto d4 = vcvt_f32_f16(vld1_f16((const float16_t *)iq2[ibl].d)); + auto extra8 = vld1_u8(iq2[ibl].extra); + uint8x16_t extra; + if constexpr (nrc_y == 1) { + extra = vcombine_u8(extra8, vshr_n_u8(extra8,1)); + } else { + extra = vcombine_u8(extra8, extra8); + } + auto sl = vld1q_u8_x2(iq2[ibl].scales); + i8scales.val[0] = vaddq_s8(vandq_u8(sl.val[0], m4), vdupq_n_s8(-8)); + i8scales.val[1] = vaddq_s8(vandq_u8(sl.val[1], m4), vdupq_n_s8(-8)); + i8scales.val[2] = vaddq_s8(vshrq_n_u8(sl.val[0], 4), vdupq_n_s8(-8)); + i8scales.val[3] = vaddq_s8(vshrq_n_u8(sl.val[1], 4), vdupq_n_s8(-8)); + int32x4_t isum[nrc_y] = {}; + if constexpr (nrc_y == 1) { + iq3_4_add_shift<nrc_y, 5>(ibl, q8, i8scales, extra, isum); + } + for (int is = 0; is < 2; ++is) { + i16scales.val[0] = vmovl_s8(vget_low_s8 (i8scales.val[2*is+0])); + i16scales.val[1] = vmovl_s8(vget_high_s8(i8scales.val[2*is+0])); + i16scales.val[2] = vmovl_s8(vget_low_s8 (i8scales.val[2*is+1])); + i16scales.val[3] = vmovl_s8(vget_high_s8(i8scales.val[2*is+1])); + for (int ib = 0; ib < 4; ++ib) { + auto scales = vmovl_s16(vget_low_s16 (i16scales.val[ib])); + auto bits = vld1q_u8_x2(iq2[ibl].qs + 128*is + 32*ib); + qx[0] = vandq_u8( bits.val[0], m03); + qx[1] = vandq_u8(vshrq_n_u8(bits.val[0], 2), m03); + qx[2] = vandq_u8(vshrq_n_u8(bits.val[0], 4), m03); + qx[3] = vandq_u8(vshrq_n_u8(bits.val[0], 6), m03); + uint8x16_t shifts; + if constexpr (nrc_y == 1) { + qx[0] = vqtbl1q_s8(values, qx[0]); // 0...3 from the 4 rows + qx[1] = vqtbl1q_s8(values, qx[1]); // 4...7 + qx[2] = vqtbl1q_s8(values, qx[2]); // 8..11 + qx[3] = vqtbl1q_s8(values, qx[3]); // 12..15 + } else { + shifts = vandq_u8(ms, vshlq_n_u8(extra, 2)); + auto shift = vqtbl1q_u8(shifts, shift_shuffle.val[0]); + extra = vshrq_n_u8(extra, 1); + qx[0] = vqtbl1q_s8(values, vaddq_u8(shift, qx[0])); // 0...3 from the 4 rows + qx[1] = vqtbl1q_s8(values, vaddq_u8(shift, qx[1])); // 4...7 + qx[2] = vqtbl1q_s8(values, vaddq_u8(shift, qx[2])); // 8..11 + qx[3] = vqtbl1q_s8(values, vaddq_u8(shift, qx[3])); // 12..15 + } + for (int iy = 0; iy < nrc_y; ++iy) { + auto y = vld1q_s8(q8.y[iy][ibl].qs+128*is+32*ib); + auto sumi = interleaved_dotq(qx, y); + isum[iy] = vmlaq_s32(isum[iy], scales, sumi); + } + qx[0] = vandq_u8( bits.val[1], m03); + qx[1] = vandq_u8(vshrq_n_u8(bits.val[1], 2), m03); + qx[2] = vandq_u8(vshrq_n_u8(bits.val[1], 4), m03); + qx[3] = vandq_u8(vshrq_n_u8(bits.val[1], 6), m03); + if constexpr (nrc_y == 1) { + qx[0] = vqtbl1q_s8(values, qx[0]); // 0...3 from the 4 rows + qx[1] = vqtbl1q_s8(values, qx[1]); // 4...7 + qx[2] = vqtbl1q_s8(values, qx[2]); // 8..11 + qx[3] = vqtbl1q_s8(values, qx[3]); // 12..15 + } else { + auto shift = vqtbl1q_u8(shifts, shift_shuffle.val[1]); + qx[0] = vqtbl1q_s8(values, vaddq_u8(shift, qx[0])); // 0...3 from the 4 rows + qx[1] = vqtbl1q_s8(values, vaddq_u8(shift, qx[1])); // 4...7 + qx[2] = vqtbl1q_s8(values, vaddq_u8(shift, qx[2])); // 8..11 + qx[3] = vqtbl1q_s8(values, vaddq_u8(shift, qx[3])); // 12..15 + } + scales = vmovl_s16(vget_high_s16(i16scales.val[ib])); + for (int iy = 0; iy < nrc_y; ++iy) { + auto y = vld1q_s8(q8.y[iy][ibl].qs+128*is+32*ib+16); + auto sumi = interleaved_dotq(qx, y); + isum[iy] = vmlaq_s32(isum[iy], scales, sumi); + } + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + acc[iy] = vfmaq_f32(acc[iy], vmulq_f32(d4, vdupq_n_f32(q8.scale(iy, ibl))), vcvtq_f32_s32(isum[iy])); + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + info.store(ix, iy, acc[iy]); + acc[iy] = vdupq_n_f32(0.f); + } + } +} + +template <int nrc_y> +void mul_mat_iq3_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + GGML_ASSERT(nrc_x%4 == 0); + Q8<nrc_y, block_q8_K> q8(info); + auto m4 = vdupq_n_u8(0xf); + auto ms = nrc_y == 1 ? vdupq_n_u8(4) : vdupq_n_u8(8); + auto m03 = vdupq_n_u8(0x03); + auto m04 = vdupq_n_u8(0x04); + uint8x16x2_t shift_shuffle = { + vreinterpretq_u8_u64(uint64x2_t{0x0101010100000000, 0x0303030302020202}), + vreinterpretq_u8_u64(uint64x2_t{0x0505050504040404, 0x0707070706060606}) + }; + uint8x16x2_t smask = { vcombine_u8(vdup_n_u8(1), vdup_n_u8(2)), vcombine_u8(vdup_n_u8(4), vdup_n_u8(8)) }; + auto values = vld1q_s8(iq3nl_values); + int nbl = n / QK_K; + int8x16_t qx[4]; + int8x16x4_t i8scales; + int16x8x4_t i16scales; + float32x4_t acc[nrc_y] = {}; + for (int ix = 0; ix < nrc_x; ix += 4) { + const block_iq3_k_r4 * iq3 = (const block_iq3_k_r4 *)((const char *)vx + ix*bx); + for (int ibl = 0; ibl < nbl; ++ibl) { + auto d4 = vcvt_f32_f16(vld1_f16((const float16_t *)iq3[ibl].d)); + auto extra8 = vld1_u8(iq3[ibl].extra); + uint8x16_t extra; + if constexpr (nrc_y == 1) { + extra = vcombine_u8(extra8, vshr_n_u8(extra8,1)); + } else { + extra = vcombine_u8(extra8, extra8); + } + auto sl = vld1q_u8_x2(iq3[ibl].scales_l); + auto sh8 = vld1_u8(iq3[ibl].scales_h); + auto sh = vcombine_u8(sh8, sh8); + i8scales.val[0] = vaddq_s8(vshlq_n_u8(vandq_u8(sl.val[0], m4), 1), vdupq_n_s8(1)); + i8scales.val[1] = vaddq_s8(vshlq_n_u8(vandq_u8(sl.val[1], m4), 1), vdupq_n_s8(1)); + i8scales.val[2] = vaddq_s8(vshlq_n_u8(vshrq_n_u8(sl.val[0], 4), 1), vdupq_n_s8(1)); + i8scales.val[3] = vaddq_s8(vshlq_n_u8(vshrq_n_u8(sl.val[1], 4), 1), vdupq_n_s8(1)); + i8scales.val[0] = vmulq_s8(i8scales.val[0], vorrq_u8(vceqq_u8(vandq_u8(sh, smask.val[0]), smask.val[0]), vdupq_n_u8(1))); + i8scales.val[1] = vmulq_s8(i8scales.val[1], vorrq_u8(vceqq_u8(vandq_u8(sh, smask.val[1]), smask.val[1]), vdupq_n_u8(1))); + sh = vshrq_n_u8(sh, 4); + i8scales.val[2] = vmulq_s8(i8scales.val[2], vorrq_u8(vceqq_u8(vandq_u8(sh, smask.val[0]), smask.val[0]), vdupq_n_u8(1))); + i8scales.val[3] = vmulq_s8(i8scales.val[3], vorrq_u8(vceqq_u8(vandq_u8(sh, smask.val[1]), smask.val[1]), vdupq_n_u8(1))); + int32x4_t isum[nrc_y] = {}; + if constexpr (nrc_y == 1) { + iq3_4_add_shift<nrc_y, 4>(ibl, q8, i8scales, extra, isum); + } + for (int is = 0; is < 2; ++is) { + i16scales.val[0] = vmovl_s8(vget_low_s8 (i8scales.val[2*is+0])); + i16scales.val[1] = vmovl_s8(vget_high_s8(i8scales.val[2*is+0])); + i16scales.val[2] = vmovl_s8(vget_low_s8 (i8scales.val[2*is+1])); + i16scales.val[3] = vmovl_s8(vget_high_s8(i8scales.val[2*is+1])); + for (int ib = 0; ib < 4; ++ib) { + auto scales = vmovl_s16(vget_low_s16 (i16scales.val[ib])); + auto lbits = vld1q_u8_x2(iq3[ibl].qs + 128*is + 32*ib); + auto hbits = vld1q_u8(iq3[ibl].qh + 64*is + 16*ib); + qx[0] = vorrq_u8(vandq_u8( lbits.val[0], m03), vandq_u8(m04, vshlq_n_u8(hbits, 2))); + qx[1] = vorrq_u8(vandq_u8(vshrq_n_u8(lbits.val[0], 2), m03), vandq_u8(m04, vshlq_n_u8(hbits, 1))); + qx[2] = vorrq_u8(vandq_u8(vshrq_n_u8(lbits.val[0], 4), m03), vandq_u8(m04, hbits)); + qx[3] = vorrq_u8(vandq_u8(vshrq_n_u8(lbits.val[0], 6), m03), vandq_u8(m04, vshrq_n_u8(hbits, 1))); + uint8x16_t shifts; + if constexpr (nrc_y == 1) { + qx[0] = vqtbl1q_s8(values, qx[0]); // 0...3 from the 4 rows + qx[1] = vqtbl1q_s8(values, qx[1]); // 4...7 + qx[2] = vqtbl1q_s8(values, qx[2]); // 8..11 + qx[3] = vqtbl1q_s8(values, qx[3]); // 12..15 + } else { + shifts = vandq_u8(ms, vshlq_n_u8(extra, 3)); + auto shift = vqtbl1q_u8(shifts, shift_shuffle.val[0]); + extra = vshrq_n_u8(extra, 1); + qx[0] = vqtbl1q_s8(values, vaddq_u8(shift, qx[0])); // 0...3 from the 4 rows + qx[1] = vqtbl1q_s8(values, vaddq_u8(shift, qx[1])); // 4...7 + qx[2] = vqtbl1q_s8(values, vaddq_u8(shift, qx[2])); // 8..11 + qx[3] = vqtbl1q_s8(values, vaddq_u8(shift, qx[3])); // 12..15 + } + for (int iy = 0; iy < nrc_y; ++iy) { + auto y = vld1q_s8(q8.y[iy][ibl].qs+128*is+32*ib); + auto sumi = interleaved_dotq(qx, y); + isum[iy] = vmlaq_s32(isum[iy], scales, sumi); + } + qx[0] = vorrq_u8(vandq_u8( lbits.val[1], m03), vandq_u8(m04, vshrq_n_u8(hbits, 2))); + qx[1] = vorrq_u8(vandq_u8(vshrq_n_u8(lbits.val[1], 2), m03), vandq_u8(m04, vshrq_n_u8(hbits, 3))); + qx[2] = vorrq_u8(vandq_u8(vshrq_n_u8(lbits.val[1], 4), m03), vandq_u8(m04, vshrq_n_u8(hbits, 4))); + qx[3] = vorrq_u8(vandq_u8(vshrq_n_u8(lbits.val[1], 6), m03), vandq_u8(m04, vshrq_n_u8(hbits, 5))); + if constexpr (nrc_y == 1) { + qx[0] = vqtbl1q_s8(values, qx[0]); // 0...3 from the 4 rows + qx[1] = vqtbl1q_s8(values, qx[1]); // 4...7 + qx[2] = vqtbl1q_s8(values, qx[2]); // 8..11 + qx[3] = vqtbl1q_s8(values, qx[3]); // 12..15 + } else { + auto shift = vqtbl1q_u8(shifts, shift_shuffle.val[1]); + qx[0] = vqtbl1q_s8(values, vaddq_u8(shift, qx[0])); // 0...3 from the 4 rows + qx[1] = vqtbl1q_s8(values, vaddq_u8(shift, qx[1])); // 4...7 + qx[2] = vqtbl1q_s8(values, vaddq_u8(shift, qx[2])); // 8..11 + qx[3] = vqtbl1q_s8(values, vaddq_u8(shift, qx[3])); // 12..15 + } + scales = vmovl_s16(vget_high_s16(i16scales.val[ib])); + for (int iy = 0; iy < nrc_y; ++iy) { + auto y = vld1q_s8(q8.y[iy][ibl].qs+128*is+32*ib+16); + auto sumi = interleaved_dotq(qx, y); + isum[iy] = vmlaq_s32(isum[iy], scales, sumi); + } + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + acc[iy] = vfmaq_f32(acc[iy], vmulq_f32(d4, vdupq_n_f32(q8.scale(iy, ibl))), vcvtq_f32_s32(isum[iy])); + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + info.store(ix, iy, acc[iy]); + acc[iy] = vdupq_n_f32(0.f); + } + } +} + +template <int nrc_y> +void mul_mat_iq4_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + GGML_ASSERT(nrc_x%4 == 0); + Q8<nrc_y, block_q8_K> q8(info); + auto m4 = vdupq_n_u8(0xf); + auto m3 = vdupq_n_u8(0x30); + auto ms = vdupq_n_u8(4); + auto m32 = vdupq_n_s8(-32); + uint8x16x2_t shift_shuffle = { + vreinterpretq_u8_u64(uint64x2_t{0x0101010100000000, 0x0303030302020202}), + vreinterpretq_u8_u64(uint64x2_t{0x0505050504040404, 0x0707070706060606}) + }; + auto values = vld1q_s8(iq4k_values); + int nbl = n / QK_K; + int8x16_t qx[4]; + int8x16x4_t i8scales; + int16x8x4_t i16scales; + float32x4_t acc[nrc_y] = {}; + for (int ix = 0; ix < nrc_x; ix += 4) { + const block_iq4_k_r4 * iq4 = (const block_iq4_k_r4 *)((const char *)vx + ix*bx); + for (int ibl = 0; ibl < nbl; ++ibl) { + auto d4 = vcvt_f32_f16(vld1_f16((const float16_t *)iq4[ibl].d)); + auto extra8 = vld1_u8(iq4[ibl].extra); + uint8x16_t extra; + if constexpr (nrc_y == 1) { + extra = vcombine_u8(extra8, vshr_n_u8(extra8,1)); + } else { + extra = vcombine_u8(extra8, extra8); + } + auto sl = vld1q_u8_x2(iq4[ibl].scales_l); + auto sh = vld1q_u8(iq4[ibl].scales_h); + i8scales.val[0] = vaddq_s8(vorrq_u8(vandq_u8(sl.val[0], m4), vandq_u8(vshlq_n_u8(sh, 4), m3)), m32); + i8scales.val[1] = vaddq_s8(vorrq_u8(vandq_u8(sl.val[1], m4), vandq_u8(vshlq_n_u8(sh, 2), m3)), m32); + i8scales.val[2] = vaddq_s8(vorrq_u8(vshrq_n_u8(sl.val[0], 4), vandq_u8(sh, m3)), m32); + i8scales.val[3] = vaddq_s8(vorrq_u8(vshrq_n_u8(sl.val[1], 4), vandq_u8(vshrq_n_u8(sh, 2), m3)), m32); + int32x4_t isum[nrc_y] = {}; + if constexpr (nrc_y == 1) { + iq3_4_add_shift<nrc_y, 4>(ibl, q8, i8scales, extra, isum); + } + for (int is = 0; is < 2; ++is) { + i16scales.val[0] = vmovl_s8(vget_low_s8 (i8scales.val[2*is+0])); + i16scales.val[1] = vmovl_s8(vget_high_s8(i8scales.val[2*is+0])); + i16scales.val[2] = vmovl_s8(vget_low_s8 (i8scales.val[2*is+1])); + i16scales.val[3] = vmovl_s8(vget_high_s8(i8scales.val[2*is+1])); + for (int ib = 0; ib < 4; ++ib) { + auto bits = vld1q_u8_x4(iq4[ibl].qs + 256*is + 64*ib); + uint8x16_t shifts; + if constexpr (nrc_y == 1) { + qx[0] = vqtbl1q_s8(values, vandq_u8(bits.val[0], m4)); // 0...3 from the 4 rows + qx[1] = vqtbl1q_s8(values, vandq_u8(bits.val[2], m4)); // 4...7 + qx[2] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[0], 4)); // 8..11 + qx[3] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[2], 4)); // 12..15 + } else { + shifts = vandq_u8(ms, vshlq_n_u8(extra, 2)); + auto shift = vqtbl1q_u8(shifts, shift_shuffle.val[0]); + extra = vshrq_n_u8(extra, 1); + qx[0] = vaddq_s8(shift, vqtbl1q_s8(values, vandq_u8(bits.val[0], m4))); // 0...3 from the 4 rows + qx[1] = vaddq_s8(shift, vqtbl1q_s8(values, vandq_u8(bits.val[2], m4))); // 4...7 + qx[2] = vaddq_s8(shift, vqtbl1q_s8(values, vshrq_n_u8(bits.val[0], 4))); // 8..11 + qx[3] = vaddq_s8(shift, vqtbl1q_s8(values, vshrq_n_u8(bits.val[2], 4))); // 12..15 + } + auto scales = vmovl_s16(vget_low_s16 (i16scales.val[ib])); + for (int iy = 0; iy < nrc_y; ++iy) { + auto y = vld1q_s8(q8.y[iy][ibl].qs+128*is+32*ib); + auto sumi = interleaved_dotq(qx, y); + isum[iy] = vmlaq_s32(isum[iy], scales, sumi); + } + if constexpr (nrc_y == 1) { + qx[0] = vqtbl1q_s8(values, vandq_u8(bits.val[1], m4)); // 16..19 + qx[1] = vqtbl1q_s8(values, vandq_u8(bits.val[3], m4)); // 20..23 + qx[2] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[1], 4)); // 24..27 + qx[3] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[3], 4)); // 28..31 + } else { + auto shift = vqtbl1q_u8(shifts, shift_shuffle.val[1]); + qx[0] = vaddq_s8(shift, vqtbl1q_s8(values, vandq_u8(bits.val[1], m4))); // 16..19 + qx[1] = vaddq_s8(shift, vqtbl1q_s8(values, vandq_u8(bits.val[3], m4))); // 20..23 + qx[2] = vaddq_s8(shift, vqtbl1q_s8(values, vshrq_n_u8(bits.val[1], 4))); // 24..27 + qx[3] = vaddq_s8(shift, vqtbl1q_s8(values, vshrq_n_u8(bits.val[3], 4))); // 28..31 + } + scales = vmovl_s16(vget_high_s16(i16scales.val[ib])); + for (int iy = 0; iy < nrc_y; ++iy) { + auto y = vld1q_s8(q8.y[iy][ibl].qs+128*is+32*ib+16); + auto sumi = interleaved_dotq(qx, y); + isum[iy] = vmlaq_s32(isum[iy], scales, sumi); + } + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + acc[iy] = vfmaq_f32(acc[iy], vmulq_f32(d4, vdupq_n_f32(q8.scale(iy, ibl))), vcvtq_f32_s32(isum[iy])); + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + info.store(ix, iy, acc[iy]); + acc[iy] = vdupq_n_f32(0.f); + } + } +} + +template <int nrc_y> +void mul_mat_iq5_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + GGML_ASSERT(nrc_x%4 == 0); + Q8<nrc_y, block_q8_K> q8(info); + auto m4 = vdupq_n_u8(0xf); + auto m3 = vdupq_n_u8(0x30); + auto ms = vdupq_n_u8(2); + auto m32 = vdupq_n_s8(-32); + auto m10 = vdupq_n_u8(0x10); + uint8x16x2_t shift_shuffle = { + vreinterpretq_u8_u64(uint64x2_t{0x0101010100000000, 0x0303030302020202}), + vreinterpretq_u8_u64(uint64x2_t{0x0505050504040404, 0x0707070706060606}) + }; + auto values = vld1q_s8_x2(iq5nl_values); + int nbl = n / QK_K; + int8x16_t qx[4]; + int8x16x4_t i8scales; + int16x8x4_t i16scales; + float32x4_t acc[nrc_y] = {}; + for (int ix = 0; ix < nrc_x; ix += 4) { + const block_iq5_k_r4 * iq5 = (const block_iq5_k_r4 *)((const char *)vx + ix*bx); + for (int ibl = 0; ibl < nbl; ++ibl) { + auto d4 = vcvt_f32_f16(vld1_f16((const float16_t *)iq5[ibl].d)); + auto extra8 = vld1_u8(iq5[ibl].extra); + uint8x16_t extra; + if constexpr (nrc_y == 1) { + extra = vcombine_u8(extra8, vshr_n_u8(extra8,1)); + } else { + extra = vcombine_u8(extra8, extra8); + } + auto sl = vld1q_u8_x2(iq5[ibl].scales_l); + auto sh = vld1q_u8(iq5[ibl].scales_h); + i8scales.val[0] = vaddq_s8(vorrq_u8(vandq_u8(sl.val[0], m4), vandq_u8(vshlq_n_u8(sh, 4), m3)), m32); + i8scales.val[1] = vaddq_s8(vorrq_u8(vandq_u8(sl.val[1], m4), vandq_u8(vshlq_n_u8(sh, 2), m3)), m32); + i8scales.val[2] = vaddq_s8(vorrq_u8(vshrq_n_u8(sl.val[0], 4), vandq_u8(sh, m3)), m32); + i8scales.val[3] = vaddq_s8(vorrq_u8(vshrq_n_u8(sl.val[1], 4), vandq_u8(vshrq_n_u8(sh, 2), m3)), m32); + int32x4_t isum[nrc_y] = {}; + if constexpr (nrc_y == 1) { + iq3_4_add_shift<nrc_y, 2>(ibl, q8, i8scales, extra, isum); + } + for (int is = 0; is < 2; ++is) { + i16scales.val[0] = vmovl_s8(vget_low_s8 (i8scales.val[2*is+0])); + i16scales.val[1] = vmovl_s8(vget_high_s8(i8scales.val[2*is+0])); + i16scales.val[2] = vmovl_s8(vget_low_s8 (i8scales.val[2*is+1])); + i16scales.val[3] = vmovl_s8(vget_high_s8(i8scales.val[2*is+1])); + for (int ib = 0; ib < 4; ++ib) { + auto lbits = vld1q_u8_x4(iq5[ibl].qs + 256*is + 64*ib); + auto hbits = vld1q_u8(iq5[ibl].qh + 64*is + 16*ib); + qx[0] = vorrq_u8(vandq_u8(lbits.val[0], m4), vandq_u8(m10, vshlq_n_u8(hbits, 4))); // aligns with 1st half of qx[0] in AVX2 + qx[1] = vorrq_u8(vandq_u8(lbits.val[2], m4), vandq_u8(m10, hbits)); // aligns with 1st half of qx[1] in AVX2 + qx[2] = vorrq_u8(vshrq_n_u8(lbits.val[0], 4), vandq_u8(m10, vshlq_n_u8(hbits, 3))); // aligns with 1st half of qx[2] in AVX2 + qx[3] = vorrq_u8(vshrq_n_u8(lbits.val[2], 4), vandq_u8(m10, vshrq_n_u8(hbits, 1))); // aligns with 1st half of qx[3] in AVX2 + uint8x16_t shifts; + if constexpr (nrc_y == 1) { + qx[0] = vqtbl2q_s8(values, qx[0]); // 0...3 from the 4 rows + qx[1] = vqtbl2q_s8(values, qx[1]); // 4...7 + qx[2] = vqtbl2q_s8(values, qx[2]); // 8..11 + qx[3] = vqtbl2q_s8(values, qx[3]); // 12..15 + } else { + shifts = vandq_u8(ms, vshlq_n_u8(extra, 1)); + auto shift = vqtbl1q_u8(shifts, shift_shuffle.val[0]); + extra = vshrq_n_u8(extra, 1); + qx[0] = vaddq_s8(shift, vqtbl2q_s8(values, qx[0])); // 0...3 from the 4 rows + qx[1] = vaddq_s8(shift, vqtbl2q_s8(values, qx[1])); // 4...7 + qx[2] = vaddq_s8(shift, vqtbl2q_s8(values, qx[2])); // 8..11 + qx[3] = vaddq_s8(shift, vqtbl2q_s8(values, qx[3])); // 12..15 + } + auto scales = vmovl_s16(vget_low_s16 (i16scales.val[ib])); + for (int iy = 0; iy < nrc_y; ++iy) { + auto y = vld1q_s8(q8.y[iy][ibl].qs+128*is+32*ib); + auto sumi = interleaved_dotq(qx, y); + isum[iy] = vmlaq_s32(isum[iy], scales, sumi); + } + qx[0] = vorrq_u8(vandq_u8(lbits.val[1], m4), vandq_u8(m10, vshlq_n_u8(hbits, 2))); // aligns with 2nd half of qx[0] in AVX2 + qx[1] = vorrq_u8(vandq_u8(lbits.val[3], m4), vandq_u8(m10, vshrq_n_u8(hbits, 2))); // aligns with 2nd half of qx[1] in AVX2 + qx[2] = vorrq_u8(vshrq_n_u8(lbits.val[1], 4), vandq_u8(m10, vshlq_n_u8(hbits, 1))); // aligns with 2nd half of qx[2] in AVX2 + qx[3] = vorrq_u8(vshrq_n_u8(lbits.val[3], 4), vandq_u8(m10, vshrq_n_u8(hbits, 3))); // aligns with 2nd half of qx[3] in AVX2 + if constexpr (nrc_y == 1) { + qx[0] = vqtbl2q_s8(values, qx[0]); // 0...3 from the 4 rows + qx[1] = vqtbl2q_s8(values, qx[1]); // 4...7 + qx[2] = vqtbl2q_s8(values, qx[2]); // 8..11 + qx[3] = vqtbl2q_s8(values, qx[3]); // 12..15 + } else { + auto shift = vqtbl1q_u8(shifts, shift_shuffle.val[1]); + qx[0] = vaddq_s8(shift, vqtbl2q_s8(values, qx[0])); // 0...3 from the 4 rows + qx[1] = vaddq_s8(shift, vqtbl2q_s8(values, qx[1])); // 4...7 + qx[2] = vaddq_s8(shift, vqtbl2q_s8(values, qx[2])); // 8..11 + qx[3] = vaddq_s8(shift, vqtbl2q_s8(values, qx[3])); // 12..15 + } + scales = vmovl_s16(vget_high_s16(i16scales.val[ib])); + for (int iy = 0; iy < nrc_y; ++iy) { + auto y = vld1q_s8(q8.y[iy][ibl].qs+128*is+32*ib+16); + auto sumi = interleaved_dotq(qx, y); + isum[iy] = vmlaq_s32(isum[iy], scales, sumi); + } + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + acc[iy] = vfmaq_f32(acc[iy], vmulq_f32(d4, vdupq_n_f32(q8.scale(iy, ibl))), vcvtq_f32_s32(isum[iy])); + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + info.store(ix, iy, acc[iy]); + acc[iy] = vdupq_n_f32(0.f); + } + } +} + +} + +bool iqk_set_kernels_iqk_quants(int ne00, int typeA, int typeB, std::array<mul_mat_t, IQK_MAX_NY>& kernels, [[maybe_unused]] mul_mat_t& func16) { + + if (ne00%QK_K != 0 || ggml_type(typeB) != GGML_TYPE_Q8_K) { + return false; + } + + func16 = nullptr; + + switch (typeA) { + case GGML_TYPE_IQ2_KS: + IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_K_q8_K_T, DequantizerIQ2KS, kernels); + break; + case GGML_TYPE_IQ2_K: + IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_K_q8_K_T, DequantizerIQ2K, kernels); + break; + case GGML_TYPE_IQ3_K: + IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_K_q8_K_T, DequantizerIQ3K, kernels); + break; + case GGML_TYPE_IQ4_KSS: + IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_K_q8_K_T, DequantizerIQ4KSS, kernels); + break; + case GGML_TYPE_IQ4_KS: + IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_K_q8_K_T, DequantizerIQ4KS, kernels); + break; + case GGML_TYPE_IQ4_K: + IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_K_q8_K_T, DequantizerIQ4K, kernels); + break; + case GGML_TYPE_IQ5_KS: + IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_K_q8_K_T, DequantizerIQ5KS, kernels); + break; + case GGML_TYPE_IQ5_K: + IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_K_q8_K_T, DequantizerIQ5K, kernels); + break; + case GGML_TYPE_IQ6_K: + IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_K_q8_K_T, DequantizerIQ6K, kernels); + break; + case GGML_TYPE_IQ2_K_R4: + IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq2_k_r4_q8_k, kernels); + break; + case GGML_TYPE_IQ3_K_R4: + IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq3_k_r4_q8_k, kernels); + break; + case GGML_TYPE_IQ4_K_R4: + IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq4_k_r4_q8_k, kernels); + break; + case GGML_TYPE_IQ4_KS_R4: + IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq4_ks_r4_q8_k, kernels); + break; + case GGML_TYPE_IQ5_KS_R4: + IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq5_ks_r4_q8_k, kernels); + break; + case GGML_TYPE_IQ5_K_R4: + IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq5_k_r4_q8_k, kernels); + break; + default: + return false; + } + + return true; + +} + +#endif + +#endif |