summaryrefslogtreecommitdiff
path: root/ggml/src/iqk/iqk_common.h
diff options
context:
space:
mode:
Diffstat (limited to 'ggml/src/iqk/iqk_common.h')
-rw-r--r--ggml/src/iqk/iqk_common.h695
1 files changed, 693 insertions, 2 deletions
diff --git a/ggml/src/iqk/iqk_common.h b/ggml/src/iqk/iqk_common.h
index dc3e369f..6feeff1a 100644
--- a/ggml/src/iqk/iqk_common.h
+++ b/ggml/src/iqk/iqk_common.h
@@ -7,6 +7,8 @@
// SPDX-License-Identifier: MIT
//
+#pragma once
+
#include "iqk_config.h"
#if defined IQK_IMPLEMENT
@@ -14,6 +16,7 @@
#include <cstring>
#include <type_traits>
#include <vector>
+#include <cstdint>
#include "ggml-impl.h"
#include "ggml-quants.h"
@@ -79,8 +82,6 @@ struct Perf {
#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)
#endif
-namespace {
-
typedef struct {
int32_t i1;
int32_t i2;
@@ -135,4 +136,694 @@ struct DataInfo {
typedef void (*mul_mat_t)(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x);
+#define IQK_MAX_NY 8
+
+#define IQK_SET_MUL_MAT_FUNCTIONS_T(kernel, Dequantizer, funcs) \
+ funcs[0] = kernel<Dequantizer, 1>;\
+ funcs[1] = kernel<Dequantizer, 2>;\
+ funcs[2] = kernel<Dequantizer, 3>;\
+ funcs[3] = kernel<Dequantizer, 4>;\
+ funcs[4] = kernel<Dequantizer, 5>;\
+ funcs[5] = kernel<Dequantizer, 6>;\
+ funcs[6] = kernel<Dequantizer, 7>;\
+ funcs[7] = kernel<Dequantizer, 8>;\
+
+#define IQK_SET_MUL_MAT_FUNCTIONS(kernel, funcs) \
+ funcs[0] = kernel<1>;\
+ funcs[1] = kernel<2>;\
+ funcs[2] = kernel<3>;\
+ funcs[3] = kernel<4>;\
+ funcs[4] = kernel<5>;\
+ funcs[5] = kernel<6>;\
+ funcs[6] = kernel<7>;\
+ funcs[7] = kernel<8>;\
+
+
+// ==================================================================================================
+
+static inline void make_q4_scales(const uint8_t * scales8, uint32_t * aux32) {
+ const uint16_t * scales = (const uint16_t *)scales8;
+ const uint32_t a0 = scales[0] | (scales[1] << 16);
+ const uint32_t a1 = scales[2] | (scales[3] << 16);
+ const uint32_t a2 = scales[4] | (scales[5] << 16);
+ aux32[3] = ((a2 >> 4) & 0x0f0f0f0f) | ((a1 >> 2) & 0x30303030);
+ aux32[1] = ((a2 >> 0) & 0x0f0f0f0f) | ((a0 >> 2) & 0x30303030);
+ aux32[2] = a1 & 0x3f3f3f3f;
+ aux32[0] = a0 & 0x3f3f3f3f;
+}
+
+#if !(defined HAVE_FANCY_SIMD && defined __AVX512VPOPCNTDQ__)
+const uint64_t keven_signs[128] = {
+ 0x0101010101010101, 0xff010101010101ff, 0xff0101010101ff01, 0x010101010101ffff,
+ 0xff01010101ff0101, 0x0101010101ff01ff, 0x0101010101ffff01, 0xff01010101ffffff,
+ 0xff010101ff010101, 0x01010101ff0101ff, 0x01010101ff01ff01, 0xff010101ff01ffff,
+ 0x01010101ffff0101, 0xff010101ffff01ff, 0xff010101ffffff01, 0x01010101ffffffff,
+ 0xff0101ff01010101, 0x010101ff010101ff, 0x010101ff0101ff01, 0xff0101ff0101ffff,
+ 0x010101ff01ff0101, 0xff0101ff01ff01ff, 0xff0101ff01ffff01, 0x010101ff01ffffff,
+ 0x010101ffff010101, 0xff0101ffff0101ff, 0xff0101ffff01ff01, 0x010101ffff01ffff,
+ 0xff0101ffffff0101, 0x010101ffffff01ff, 0x010101ffffffff01, 0xff0101ffffffffff,
+ 0xff01ff0101010101, 0x0101ff01010101ff, 0x0101ff010101ff01, 0xff01ff010101ffff,
+ 0x0101ff0101ff0101, 0xff01ff0101ff01ff, 0xff01ff0101ffff01, 0x0101ff0101ffffff,
+ 0x0101ff01ff010101, 0xff01ff01ff0101ff, 0xff01ff01ff01ff01, 0x0101ff01ff01ffff,
+ 0xff01ff01ffff0101, 0x0101ff01ffff01ff, 0x0101ff01ffffff01, 0xff01ff01ffffffff,
+ 0x0101ffff01010101, 0xff01ffff010101ff, 0xff01ffff0101ff01, 0x0101ffff0101ffff,
+ 0xff01ffff01ff0101, 0x0101ffff01ff01ff, 0x0101ffff01ffff01, 0xff01ffff01ffffff,
+ 0xff01ffffff010101, 0x0101ffffff0101ff, 0x0101ffffff01ff01, 0xff01ffffff01ffff,
+ 0x0101ffffffff0101, 0xff01ffffffff01ff, 0xff01ffffffffff01, 0x0101ffffffffffff,
+ 0xffff010101010101, 0x01ff0101010101ff, 0x01ff01010101ff01, 0xffff01010101ffff,
+ 0x01ff010101ff0101, 0xffff010101ff01ff, 0xffff010101ffff01, 0x01ff010101ffffff,
+ 0x01ff0101ff010101, 0xffff0101ff0101ff, 0xffff0101ff01ff01, 0x01ff0101ff01ffff,
+ 0xffff0101ffff0101, 0x01ff0101ffff01ff, 0x01ff0101ffffff01, 0xffff0101ffffffff,
+ 0x01ff01ff01010101, 0xffff01ff010101ff, 0xffff01ff0101ff01, 0x01ff01ff0101ffff,
+ 0xffff01ff01ff0101, 0x01ff01ff01ff01ff, 0x01ff01ff01ffff01, 0xffff01ff01ffffff,
+ 0xffff01ffff010101, 0x01ff01ffff0101ff, 0x01ff01ffff01ff01, 0xffff01ffff01ffff,
+ 0x01ff01ffffff0101, 0xffff01ffffff01ff, 0xffff01ffffffff01, 0x01ff01ffffffffff,
+ 0x01ffff0101010101, 0xffffff01010101ff, 0xffffff010101ff01, 0x01ffff010101ffff,
+ 0xffffff0101ff0101, 0x01ffff0101ff01ff, 0x01ffff0101ffff01, 0xffffff0101ffffff,
+ 0xffffff01ff010101, 0x01ffff01ff0101ff, 0x01ffff01ff01ff01, 0xffffff01ff01ffff,
+ 0x01ffff01ffff0101, 0xffffff01ffff01ff, 0xffffff01ffffff01, 0x01ffff01ffffffff,
+ 0xffffffff01010101, 0x01ffffff010101ff, 0x01ffffff0101ff01, 0xffffffff0101ffff,
+ 0x01ffffff01ff0101, 0xffffffff01ff01ff, 0xffffffff01ffff01, 0x01ffffff01ffffff,
+ 0x01ffffffff010101, 0xffffffffff0101ff, 0xffffffffff01ff01, 0x01ffffffff01ffff,
+ 0xffffffffffff0101, 0x01ffffffffff01ff, 0x01ffffffffffff01, 0xffffffffffffffff,
+};
+#endif
+
+#ifdef __AVX2__
+
+#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)
+
+static inline float hsum_float_4(__m128 x) {
+ x = _mm_add_ps(x, _mm_movehl_ps(x, x));
+ x = _mm_add_ss(x, _mm_movehdup_ps(x));
+ return _mm_cvtss_f32(x);
+}
+static inline float hsum_float_8(__m256 x) {
+ return hsum_float_4(_mm_add_ps(_mm256_castps256_ps128(x), _mm256_extractf128_ps(x, 1)));
+}
+static inline int hsum_i32_8(const __m256i a) {
+ const __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(a), _mm256_extractf128_si256(a, 1));
+ const __m128i hi64 = _mm_unpackhi_epi64(sum128, sum128);
+ const __m128i sum64 = _mm_add_epi32(hi64, sum128);
+ const __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1));
+ return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32));
+}
+static inline float hmax_float_8(__m256 x) {
+ __m128 max4 = _mm_max_ps(_mm256_extractf128_ps(x, 1), _mm256_castps256_ps128(x));
+ max4 = _mm_max_ps( max4, _mm_movehl_ps(max4, max4));
+ max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4));
+ return _mm_cvtss_f32(max4);
+}
+
+static inline __m256 hsum_float_8x8(__m256 * accm) {
+ for (int i = 0; i < 4; ++i) {
+ accm[i] = _mm256_add_ps(_mm256_permute2f128_ps(accm[i], accm[i+4], 0x20), _mm256_permute2f128_ps(accm[i], accm[i+4], 0x31));
+ //accm[i] = _mm256_set_m128(_mm_add_ps(_mm256_castps256_ps128(accm[i+4]), _mm256_extractf128_ps(accm[i+4], 1)),
+ // _mm_add_ps(_mm256_castps256_ps128(accm[i+0]), _mm256_extractf128_ps(accm[i+0], 1)));
+ }
+ for (int i = 0; i < 2; ++i) accm[i] = _mm256_add_ps(_mm256_unpacklo_ps(accm[i], accm[i+2]), _mm256_unpackhi_ps(accm[i], accm[i+2]));
+ return _mm256_add_ps(_mm256_unpacklo_ps(accm[0], accm[1]), _mm256_unpackhi_ps(accm[0], accm[1]));
+}
+
+static inline __m128i load_iq4nl_values_128() {
+ static const uint8_t kvalues_iq4nl[16] = {1, 24, 45, 63, 79, 93, 106, 118, 129, 141, 153, 166, 181, 197, 217, 241};
+ return _mm_loadu_si128((const __m128i *)kvalues_iq4nl);
+}
+
+static inline __m256i load_iq4nl_values_256() {
+ auto val128 = load_iq4nl_values_128();
+ return MM256_SET_M128I(val128, val128);
+}
+
+#ifdef HAVE_FANCY_SIMD
+static inline __m512i load_iq4nl_values_512() {
+ auto val256 = load_iq4nl_values_256();
+ return _mm512_inserti32x8(_mm512_castsi256_si512(val256), val256, 1);
+}
+#endif
+
+static inline __m128i load_iq4k_values_128() {
+ return _mm_loadu_si128((const __m128i *)iq4k_values);
+}
+
+static inline __m256i load_iq4k_values_256() {
+ auto val128 = load_iq4k_values_128();
+ return MM256_SET_M128I(val128, val128);
+}
+
+template <int nrc, typename block_q8 = block_q8_K> struct Q8 {
+
+ constexpr static int nrc_y = nrc;
+
+ Q8(const DataInfo& info) {
+ for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const block_q8 *)info.src1_row(iy);
+ }
+
+#ifdef HAVE_FANCY_SIMD
+ inline __m512i load_quants64(int iy, int i, int j) const { return _mm512_loadu_si512((const __m512i*)y[iy][i].qs + j); }
+#endif
+ inline __m256i load_quants(int iy, int i, int j) const { return _mm256_loadu_si256((const __m256i*)y[iy][i].qs + j); }
+ inline __m256i load_bsums(int iy, int i) const { return _mm256_loadu_si256((const __m256i*)y[iy][i].bsums); }
+ inline float scale(int iy, int i) const { return y[iy][i].d; }
+
+ const block_q8 * y[nrc_y];
+};
+
+template <int nrc> struct Q8_16 {
+
+ constexpr static int nrc_y = nrc;
+
+ Q8_16(const DataInfo& info) {
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto ptr = (const float *)info.src1_row(iy);
+ std::memcpy(d + 5*iy, ptr, 5*sizeof(float));
+ y[iy] = (const int8_t *)(ptr + 5);
+ }
+ }
+
+#ifdef HAVE_FANCY_SIMD
+ inline __m512i load_quants64(int iy, int i) const { return _mm512_loadu_si512((const __m512i*)y[iy] + i); }
+#endif
+ inline __m256i load_quants(int iy, int i) const { return _mm256_loadu_si256((const __m256i*)y[iy] + i); }
+ inline float scale(int iy, int k) const { return d[5*iy+k]; }
+ inline float sum_row(int iy) const { return d[5*iy + 4]; }
+ inline __m128 scale(int iy) const { return _mm_loadu_ps(d + 5*iy); }
+
+ float d[5*nrc_y];
+ const int8_t * y[nrc_y];
+};
+
+struct Scales8KBase {
+ template <typename Q8>
+ inline void accum_mins(const __m128i& mins128, const Q8& q8, int i, float c, __m256 * accd) const {
+ const __m256i mins = MM256_SET_M128I(_mm_shuffle_epi8(mins128, shuffles[1]), _mm_shuffle_epi8(mins128, shuffles[0]));
+ for (int iy = 0; iy < Q8::nrc_y; ++iy) {
+ const __m256i q8s = q8.load_bsums(iy, i);
+ const __m256i prod = _mm256_madd_epi16(mins, q8s);
+ accd[iy] = _mm256_fmadd_ps(_mm256_set1_ps(c*q8.scale(iy, i)), _mm256_cvtepi32_ps(prod), accd[iy]);
+ }
+ }
+ inline __m256i shuffle(__m128i mins) const {
+ return MM256_SET_M128I(_mm_shuffle_epi8(mins, shuffles[1]), _mm_shuffle_epi8(mins, shuffles[0]));
+ }
+ const __m128i shuffles[2] = {_mm_set_epi32(0x07060706, 0x05040504, 0x03020302, 0x01000100),
+ _mm_set_epi32(0x0f0e0f0e, 0x0d0c0d0c, 0x0b0a0b0a, 0x09080908)};
+};
+
+template <typename Block, bool per_row_scale = false, bool is_f16 = false>
+struct BaseDequantizer {
+ BaseDequantizer(const void * vx, size_t bx) : vx(vx), bx(bx) {}
+ inline void new_row(int ix) {
+ if constexpr (per_row_scale) {
+ if constexpr (is_f16) {
+ const ggml_half * dptr = (const ggml_half *)((const char *)vx + bx*ix);
+ d = GGML_FP16_TO_FP32(*dptr);
+ x = (const Block *)(dptr + 1);
+ } else {
+ const float * dptr = (const float *)((const char *)vx + bx*ix);
+ d = *dptr;
+ x = (const Block *)(dptr + 1);
+ }
+ } else {
+ x = (const Block *)((const char *)vx + bx*ix);
+ }
+ }
+
+ const void * vx;
+ const size_t bx;
+ const Block * x;
+
+ float d;
+};
+
+template <typename Q8, typename Bits>
+static inline void multiply_add(const Bits& bits, const __m256i * scales, int j, int i, const Q8& q8, __m256i * sumi) {
+ if (j == 0) {
+#ifdef HAVE_FANCY_SIMD
+ for (int iy = 0; iy < Q8::nrc_y; ++iy) {
+ sumi[iy] = _mm256_dpwssd_epi32(_mm256_setzero_si256(), scales[0], _mm256_maddubs_epi16(bits.values[0], q8.load_quants(iy, i, 0)));
+ sumi[iy] = _mm256_dpwssd_epi32(sumi[iy], scales[1], _mm256_maddubs_epi16(bits.values[1], q8.load_quants(iy, i, 1)));
+ sumi[iy] = _mm256_dpwssd_epi32(sumi[iy], scales[2], _mm256_maddubs_epi16(bits.values[2], q8.load_quants(iy, i, 2)));
+ sumi[iy] = _mm256_dpwssd_epi32(sumi[iy], scales[3], _mm256_maddubs_epi16(bits.values[3], q8.load_quants(iy, i, 3)));
+ }
+#else
+ for (int iy = 0; iy < Q8::nrc_y; ++iy) {
+ const __m256i p1 = _mm256_madd_epi16(scales[0], _mm256_maddubs_epi16(bits.values[0], q8.load_quants(iy, i, 0)));
+ const __m256i p2 = _mm256_madd_epi16(scales[1], _mm256_maddubs_epi16(bits.values[1], q8.load_quants(iy, i, 1)));
+ const __m256i p3 = _mm256_madd_epi16(scales[2], _mm256_maddubs_epi16(bits.values[2], q8.load_quants(iy, i, 2)));
+ const __m256i p4 = _mm256_madd_epi16(scales[3], _mm256_maddubs_epi16(bits.values[3], q8.load_quants(iy, i, 3)));
+ sumi[iy] = _mm256_add_epi32(_mm256_add_epi32(p1, p3), _mm256_add_epi32(p2, p4));
+ }
+#endif
+ } else {
+#ifdef HAVE_FANCY_SIMD
+ for (int iy = 0; iy < Q8::nrc_y; ++iy) {
+ sumi[iy] = _mm256_dpwssd_epi32(sumi[iy], scales[0], _mm256_maddubs_epi16(bits.values[0], q8.load_quants(iy, i, 4)));
+ sumi[iy] = _mm256_dpwssd_epi32(sumi[iy], scales[1], _mm256_maddubs_epi16(bits.values[1], q8.load_quants(iy, i, 5)));
+ sumi[iy] = _mm256_dpwssd_epi32(sumi[iy], scales[2], _mm256_maddubs_epi16(bits.values[2], q8.load_quants(iy, i, 6)));
+ sumi[iy] = _mm256_dpwssd_epi32(sumi[iy], scales[3], _mm256_maddubs_epi16(bits.values[3], q8.load_quants(iy, i, 7)));
+ }
+#else
+ for (int iy = 0; iy < Q8::nrc_y; ++iy) {
+ const __m256i p1 = _mm256_madd_epi16(scales[0], _mm256_maddubs_epi16(bits.values[0], q8.load_quants(iy, i, 4)));
+ const __m256i p2 = _mm256_madd_epi16(scales[1], _mm256_maddubs_epi16(bits.values[1], q8.load_quants(iy, i, 5)));
+ const __m256i p3 = _mm256_madd_epi16(scales[2], _mm256_maddubs_epi16(bits.values[2], q8.load_quants(iy, i, 6)));
+ const __m256i p4 = _mm256_madd_epi16(scales[3], _mm256_maddubs_epi16(bits.values[3], q8.load_quants(iy, i, 7)));
+ sumi[iy] = _mm256_add_epi32(sumi[iy], _mm256_add_epi32(p1, p3));
+ sumi[iy] = _mm256_add_epi32(sumi[iy], _mm256_add_epi32(p2, p4));
+ }
+#endif
+ }
+}
+
+template <typename Q8, typename Bits>
+static inline void multiply_add_avx2(const Bits& bits, const __m256i * scales, int j, int i, const Q8& q8, __m256i * sumi) {
+ __m256i p[4];
+ if (j == 0) {
+ for (int iy = 0; iy < Q8::nrc_y; ++iy) {
+ for (int k = 0; k < 4; ++k) {
+ auto s = _mm256_sign_epi8(bits.values[k], bits.values[k]);
+ p[k] = _mm256_madd_epi16(scales[k], _mm256_maddubs_epi16(s, _mm256_sign_epi8(q8.load_quants(iy, i, k), bits.values[k])));
+ }
+ sumi[iy] = _mm256_add_epi32(_mm256_add_epi32(p[0], p[1]), _mm256_add_epi32(p[2], p[3]));
+ }
+ } else {
+ for (int iy = 0; iy < Q8::nrc_y; ++iy) {
+ for (int k = 0; k < 4; ++k) {
+ auto s = _mm256_sign_epi8(bits.values[k], bits.values[k]);
+ p[k] = _mm256_madd_epi16(scales[k], _mm256_maddubs_epi16(s, _mm256_sign_epi8(q8.load_quants(iy, i, 4+k), bits.values[k])));
+ }
+ sumi[iy] = _mm256_add_epi32(sumi[iy], _mm256_add_epi32(p[0], p[2]));
+ sumi[iy] = _mm256_add_epi32(sumi[iy], _mm256_add_epi32(p[1], p[3]));
+ }
+ }
+}
+
+#ifdef HAVE_FANCY_SIMD
+
+struct BlockPermuter {
+ const __m512i permute1 = _mm512_set_epi64(11, 10, 9, 8, 3, 2, 1, 0);
+ const __m512i permute2 = _mm512_set_epi64(15, 14, 13, 12, 7, 6, 5, 4);
+};
+
+struct Q4Bits {
+ inline void prepare(const uint8_t * q4) {
+ auto q4bits = _mm512_loadu_si512((const __m512i*)q4 + 0);
+ auto tmp1 = _mm512_and_si512(q4bits, ml);
+ auto tmp2 = _mm512_and_si512(_mm512_srli_epi16(q4bits, 4), ml);
+ values[0] = _mm512_permutex2var_epi64(tmp1, perm.permute1, tmp2);
+ values[1] = _mm512_permutex2var_epi64(tmp1, perm.permute2, tmp2);
+ q4bits = _mm512_loadu_si512((const __m512i*)q4 + 1);
+ tmp1 = _mm512_and_si512(q4bits, ml);
+ tmp2 = _mm512_and_si512(_mm512_srli_epi16(q4bits, 4), ml);
+ values[2] = _mm512_permutex2var_epi64(tmp1, perm.permute1, tmp2);
+ values[3] = _mm512_permutex2var_epi64(tmp1, perm.permute2, tmp2);
+ }
+ inline void prepare64(const uint8_t * q4) {
+ auto q4bits = _mm512_loadu_si512((const __m512i*)q4 + 0);
+ values[0] = _mm512_and_si512(q4bits, ml);
+ values[1] = _mm512_and_si512(_mm512_srli_epi16(q4bits, 4), ml);
+ q4bits = _mm512_loadu_si512((const __m512i*)q4 + 1);
+ values[2] = _mm512_and_si512(q4bits, ml);
+ values[3] = _mm512_and_si512(_mm512_srli_epi16(q4bits, 4), ml);
+ }
+ inline void prepare64a(const uint8_t * q4) {
+ for (int k = 0; k < 4; ++k) {
+ auto q4bits = _mm256_loadu_si256((const __m256i*)q4 + k);
+ values[k] = _mm512_inserti32x8(_mm512_castsi256_si512(q4bits), _mm256_srli_epi16(q4bits, 4), 1);
+ values[k] = _mm512_and_si512(values[k], ml);
+ }
+ }
+ __m512i values[4];
+ const __m512i ml = _mm512_set1_epi8(0xf);
+ const BlockPermuter perm;
+};
+
+struct Q2Bits {
+ inline void prepare(const uint8_t * q2) {
+
+ auto q2bits = _mm512_loadu_si512((const __m512i*)q2);
+ auto tmp = _mm512_srli_epi16(q2bits, 2);
+
+ values[0] = _mm512_permutex2var_epi64(q2bits, perm.permute1, tmp);
+ values[2] = _mm512_permutex2var_epi64(q2bits, perm.permute2, tmp);
+ values[1] = _mm512_and_si512(_mm512_srli_epi16(values[0], 4), ml);
+ values[3] = _mm512_and_si512(_mm512_srli_epi16(values[2], 4), ml);
+ values[0] = _mm512_and_si512(values[0], ml);
+ values[2] = _mm512_and_si512(values[2], ml);
+ }
+ __m512i values[4];
+ const __m512i ml = _mm512_set1_epi8(0x03);
+ BlockPermuter perm;
+};
+
+#else
+
+struct Q2Bits {
+ inline void prepare(const uint8_t * q2, int j) {
+ auto q2bits = _mm256_loadu_si256((const __m256i *)q2 + j);
+ values[0] = _mm256_and_si256(q2bits, ml);
+ values[1] = _mm256_and_si256(_mm256_srli_epi16(q2bits, 2), ml);
+ values[2] = _mm256_and_si256(_mm256_srli_epi16(q2bits, 4), ml);
+ values[3] = _mm256_and_si256(_mm256_srli_epi16(q2bits, 6), ml);
+ }
+ __m256i values[4];
+ const __m256i ml = _mm256_set1_epi8(0x03);
+};
+
+struct Q4Bits {
+ inline void prepare(const uint8_t * q4, int j) {
+ auto q4bits = _mm256_loadu_si256((const __m256i*)q4 + 2*j+0);
+ values[0] = _mm256_and_si256(q4bits, ml);
+ values[1] = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), ml);
+ q4bits = _mm256_loadu_si256((const __m256i*)q4 + 2*j+1);
+ values[2] = _mm256_and_si256(q4bits, ml);
+ values[3] = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), ml);
+ }
+ inline void prepare64(const uint8_t * q4, int j) {
+ auto q4bits = _mm256_loadu_si256((const __m256i*)q4 + 2*j+0);
+ values[0] = _mm256_and_si256(q4bits, ml);
+ values[2] = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), ml);
+ q4bits = _mm256_loadu_si256((const __m256i*)q4 + 2*j+1);
+ values[1] = _mm256_and_si256(q4bits, ml);
+ values[3] = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), ml);
+ }
+ inline void prepare16(const uint8_t * q4, int j) {
+ values[0] = dequant16(q4 + 64*j + 0);
+ values[1] = dequant16(q4 + 64*j + 16);
+ values[2] = dequant16(q4 + 64*j + 32);
+ values[3] = dequant16(q4 + 64*j + 48);
+ }
+ inline __m256i dequant16(const uint8_t * qs) const {
+ const __m128i aux128 = _mm_loadu_si128((const __m128i *)qs);
+ const __m256i aux256 = MM256_SET_M128I(_mm_srli_epi16(aux128, 4), aux128);
+ return _mm256_and_si256(ml, aux256);
+ }
+ __m256i values[4];
+ const __m256i ml = _mm256_set1_epi8(0xf);
+};
+
+#endif
+
+#else
+// ------------------------------------ __aarch64__ --------------------------------------------------
+
+template <int nrc, typename block_q8 = block_q8_K> struct Q8 {
+
+ constexpr static int nrc_y = nrc;
+
+ Q8(const DataInfo& info) {
+ for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const block_q8 *)info.src1_row(iy);
+ }
+
+ inline int8x16x2_t load_quants(int iy, int i, int j) const { return vld1q_s8_x2(y[iy][i].qs + 32*j); }
+ inline int8x16x4_t load_quants_64(int iy, int i, int j) const { return vld1q_s8_x4(y[iy][i].qs + 64*j); }
+ inline int16x8x2_t load_bsums(int iy, int i) const { return vld1q_s16_x2(y[iy][i].bsums); }
+ inline int16x8_t load_bsums8(int iy, int i) const {
+ auto q8s = vld1q_s16_x2(y[iy][i].bsums);
+ return vpaddq_s16(q8s.val[0], q8s.val[1]);
+ }
+ inline float scale(int iy, int i) const { return y[iy][i].d; }
+
+ const block_q8 * y[nrc_y];
+};
+
+template <typename block_q, bool has_row_scale = false, bool scale_is_f16 = false>
+struct BaseDequantizer {
+ BaseDequantizer(const void * vx, size_t bx, int nrc) : vx(vx), x(nullptr), bx(bx), nrc(nrc) {}
+ inline void new_row(int ix) {
+ if constexpr (has_row_scale) {
+ if constexpr (scale_is_f16) {
+ const ggml_half * dptr = (const ggml_half *)((const char *)vx + ix*bx);
+ d = GGML_FP16_TO_FP32(*dptr);
+ x = (const block_q *)(dptr + 1);
+ } else {
+ const float * dptr = (const float *)((const char *)vx + ix*bx);
+ d = *dptr;
+ x = (const block_q *)(dptr + 1);
+ }
+ } else {
+ x = (const block_q *)((const char *)vx + ix*bx);
+ }
+ }
+ const void * vx;
+ const block_q * x;
+ const size_t bx;
+ const int nrc;
+ float d;
+};
+
+struct Q4bits {
+ const uint8x16_t m4b = vdupq_n_u8(0xf);
+ uint8x16x4_t b1, b2;
+ inline void prepare4(uint8x16x4_t& b, const uint8x16_t * val) const {
+ b.val[0] = vandq_u8(val[0], m4b);
+ b.val[2] = vshrq_n_u8(val[0], 4);
+ b.val[1] = vandq_u8(val[1], m4b);
+ b.val[3] = vshrq_n_u8(val[1], 4);
+ }
+ inline void prepare4_16(uint8x16x4_t& b, const uint8x16_t * val) const {
+ b.val[0] = vandq_u8(val[0], m4b);
+ b.val[1] = vshrq_n_u8(val[0], 4);
+ b.val[2] = vandq_u8(val[1], m4b);
+ b.val[3] = vshrq_n_u8(val[1], 4);
+ }
+ inline void prepare(const uint8_t * qs) {
+ auto q4bits = vld1q_u8_x2(qs);
+ prepare4(b1, q4bits.val);
+ q4bits = vld1q_u8_x2(qs+32);
+ prepare4(b2, q4bits.val);
+ }
+ inline void prepare_v2(const uint8_t * qs) {
+ auto q4bits = vld1q_u8_x4(qs);
+ prepare4(b1, q4bits.val+0);
+ prepare4(b2, q4bits.val+2);
+ }
+ inline void prepare64(const uint8_t * qs) {
+ auto q4bits = vld1q_u8_x4(qs);
+ b1.val[0] = vandq_u8(q4bits.val[0], m4b);
+ b1.val[1] = vandq_u8(q4bits.val[1], m4b);
+ b1.val[2] = vandq_u8(q4bits.val[2], m4b);
+ b1.val[3] = vandq_u8(q4bits.val[3], m4b);
+ b2.val[0] = vshrq_n_u8(q4bits.val[0], 4);
+ b2.val[1] = vshrq_n_u8(q4bits.val[1], 4);
+ b2.val[2] = vshrq_n_u8(q4bits.val[2], 4);
+ b2.val[3] = vshrq_n_u8(q4bits.val[3], 4);
+ }
+ inline void prepare16(const uint8_t * qs) {
+ auto q4bits = vld1q_u8_x2(qs);
+ prepare4_16(b1, q4bits.val);
+ q4bits = vld1q_u8_x2(qs+32);
+ prepare4_16(b2, q4bits.val);
+ }
+ inline void prepare16_v2(const uint8_t * qs) {
+ auto q4bits = vld1q_u8_x4(qs);
+ prepare4_16(b1, q4bits.val+0);
+ prepare4_16(b2, q4bits.val+2);
+ }
+};
+
+struct Q2bits {
+ const uint8x16_t m4b = vdupq_n_u8(0x03);
+ uint8x16x4_t b1, b2;
+ inline void prepare(const uint8_t * qs) {
+ auto q2bits = vld1q_u8_x2(qs);
+ b1.val[0] = vandq_u8(q2bits.val[0], m4b);
+ b1.val[1] = vandq_u8(q2bits.val[1], m4b);
+
+ q2bits.val[0] = vshrq_n_u8(q2bits.val[0], 2);
+ q2bits.val[1] = vshrq_n_u8(q2bits.val[1], 2);
+ b1.val[2] = vandq_u8(q2bits.val[0], m4b);
+ b1.val[3] = vandq_u8(q2bits.val[1], m4b);
+
+ q2bits.val[0] = vshrq_n_u8(q2bits.val[0], 2);
+ q2bits.val[1] = vshrq_n_u8(q2bits.val[1], 2);
+ b2.val[0] = vandq_u8(q2bits.val[0], m4b);
+ b2.val[1] = vandq_u8(q2bits.val[1], m4b);
+
+ q2bits.val[0] = vshrq_n_u8(q2bits.val[0], 2);
+ q2bits.val[1] = vshrq_n_u8(q2bits.val[1], 2);
+ b2.val[2] = vandq_u8(q2bits.val[0], m4b);
+ b2.val[3] = vandq_u8(q2bits.val[1], m4b);
+ }
+};
+
+template <typename Q8>
+static inline void compute_8_blocks(const uint8x16x4_t& qx_1, const uint8x16x4_t& qx_2, const Q8& q8,
+ const int32x4x2_t& scales, int iy, int i, int j, int32x4_t& sumi) {
+ auto mzero = vdupq_n_s32(0);
+ auto q8b_1 = q8.load_quants(iy, i, 4*j+0);
+ auto p1 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_1.val[0]), q8b_1.val[0]),
+ vreinterpretq_s8_u8(qx_1.val[1]), q8b_1.val[1]); // block 1
+ auto q8b_2 = q8.load_quants(iy, i, 4*j+1);
+ auto p2 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_1.val[2]), q8b_2.val[0]),
+ vreinterpretq_s8_u8(qx_1.val[3]), q8b_2.val[1]); // block 2
+ auto p12 = vpaddq_s32(p1, p2);
+
+ auto q8b_3 = q8.load_quants(iy, i, 4*j+2);
+ auto p3 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_2.val[0]), q8b_3.val[0]),
+ vreinterpretq_s8_u8(qx_2.val[1]), q8b_3.val[1]); // block 1
+ auto q8b_4 = q8.load_quants(iy, i, 4*j+3);
+ auto p4 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_2.val[2]), q8b_4.val[0]),
+ vreinterpretq_s8_u8(qx_2.val[3]), q8b_4.val[1]); // block 2
+ auto p34 = vpaddq_s32(p3, p4);
+
+ auto pall = vpaddq_s32(p12, p34);
+ sumi = vmlaq_s32(sumi, scales.val[j], pall);
+}
+
+template <typename Q8>
+static inline void compute_16_blocks(const uint8x16x4_t& qx_1, const uint8x16x4_t& qx_2, const Q8& q8,
+ const int32x4x4_t& scales, int iy, int i, int j, int32x4_t& sumi) {
+
+ auto mzero = vdupq_n_s32(0);
+ auto q8b_1 = q8.load_quants(iy, i, 4*j+0);
+ auto p1 = vpaddq_s32(ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_1.val[0]), q8b_1.val[0]),
+ ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_1.val[1]), q8b_1.val[1])); // blocks 0, 0, 1, 1,
+ auto q8b_2 = q8.load_quants(iy, i, 4*j+1);
+ auto p2 = vpaddq_s32(ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_1.val[2]), q8b_2.val[0]),
+ ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_1.val[3]), q8b_2.val[1])); // blocks 3, 3, 4, 4,
+ auto p12 = vpaddq_s32(p1, p2); // blocks 0, 1, 2, 3
+ sumi = vmlaq_s32(sumi, scales.val[2*j+0], p12);
+
+ auto q8b_3 = q8.load_quants(iy, i, 4*j+2);
+ auto p3 = vpaddq_s32(ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_2.val[0]), q8b_3.val[0]),
+ ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_2.val[1]), q8b_3.val[1])); // block 4, 4, 5, 5,
+ auto q8b_4 = q8.load_quants(iy, i, 4*j+3);
+ auto p4 = vpaddq_s32(ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_2.val[2]), q8b_4.val[0]),
+ ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_2.val[3]), q8b_4.val[1])); // block 6, 6, 7, 7,
+ auto p34 = vpaddq_s32(p3, p4); // blocks 4, 5, 6, 7
+ sumi = vmlaq_s32(sumi, scales.val[2*j+1], p34);
+}
+
+struct SignHelper {
+
+ inline void init() { shuffle = vcombine_u8(vdup_n_u8(0), vdup_n_u8(1)); }
+
+ inline void apply_signs_1(uint8x16_t * b, const uint8x16_t& signs16) {
+ auto aux = vqtbl1q_u8(signs16, shuffle);
+ auto s = vreinterpretq_s8_u8(vorrq_u8(vceqq_u8(vandq_u8(aux, smask), smask), m1));
+ b[0] = vreinterpretq_u8_s8(vmulq_s8(vreinterpretq_s8_u8(b[0]), s));
+ shuffle = vaddq_u8(shuffle, step);
+ }
+
+ const uint8x16_t smask = vreinterpretq_u8_u64(vdupq_n_u64(0x8040201008040201));
+ const uint8x16_t m1 = vdupq_n_u8(1);
+ const uint8x16_t step = vdupq_n_u8(2);
+ uint8x16_t shuffle;
+};
+
+template <typename Dequantizer, int nrc_y>
+static void mul_mat_qX_K_q8_K_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ assert(n % QK_K == 0);
+ const int nb = n / QK_K;
+
+ Q8<nrc_y, block_q8_K> q8(info);
+
+ Dequantizer deq(vx, bx, nrc_y);
+
+ for (int ix = 0; ix < nrc_x; ++ix) {
+
+ deq.new_row(ix);
+
+ float32x4_t acc[nrc_y];
+ for (int iy = 0; iy < nrc_y; ++iy) acc[iy] = vdupq_n_f32(0.f);
+
+ for (int i = 0; i < nb; ++i) {
+
+ int32x4_t sumi[nrc_y];
+ for (int iy = 0; iy < nrc_y; ++iy) sumi[iy] = vdupq_n_s32(0);
+
+ if constexpr (nrc_y > 1 && Dequantizer::should_scale_quants()) {
+ deq.process_scales(i, q8, acc);
+ deq.prepare(i, 0);
+ deq.compute(q8, i, 0, sumi);
+ deq.prepare(i, 1);
+ deq.compute(q8, i, 1, sumi);
+ } else {
+ if constexpr (Dequantizer::num_blocks() == 8) {
+ auto scales = deq.new_block(i, q8, acc);
+ deq.prepare(i, 0);
+ for (int iy = 0; iy < nrc_y; ++iy) compute_8_blocks(deq.bits.b1, deq.bits.b2, q8, scales, iy, i, 0, sumi[iy]);
+ deq.prepare(i, 1);
+ for (int iy = 0; iy < nrc_y; ++iy) compute_8_blocks(deq.bits.b1, deq.bits.b2, q8, scales, iy, i, 1, sumi[iy]);
+ }
+ else if constexpr (Dequantizer::num_blocks() == 16) {
+ auto scales = deq.new_block(i, q8, acc);
+ deq.prepare(i, 0);
+ for (int iy = 0; iy < nrc_y; ++iy) compute_16_blocks(deq.bits.b1, deq.bits.b2, q8, scales, iy, i, 0, sumi[iy]);
+ deq.prepare(i, 1);
+ for (int iy = 0; iy < nrc_y; ++iy) compute_16_blocks(deq.bits.b1, deq.bits.b2, q8, scales, iy, i, 1, sumi[iy]);
+ }
+ else {
+ GGML_ASSERT(false);
+ }
+ }
+
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ acc[iy] = vmlaq_f32(acc[iy], vcvtq_f32_s32(sumi[iy]), vdupq_n_f32(deq.d*q8.scale(iy, i)));
+ }
+ }
+
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ info.store(ix, iy, vaddvq_f32(acc[iy]));
+ }
+ }
+}
+
+static IQK_ALWAYS_INLINE int32x4_t interleaved_dotq(const int8x16_t * qx, const int8x16x2_t& y) {
+ auto sumi = vdupq_n_s32(0);
+ sumi = vdotq_laneq_s32(sumi, qx[0], y.val[0], 0);
+ sumi = vdotq_laneq_s32(sumi, qx[1], y.val[1], 0);
+ sumi = vdotq_laneq_s32(sumi, qx[2], y.val[0], 1);
+ sumi = vdotq_laneq_s32(sumi, qx[3], y.val[1], 1);
+ sumi = vdotq_laneq_s32(sumi, qx[4], y.val[0], 2);
+ sumi = vdotq_laneq_s32(sumi, qx[5], y.val[1], 2);
+ sumi = vdotq_laneq_s32(sumi, qx[6], y.val[0], 3);
+ sumi = vdotq_laneq_s32(sumi, qx[7], y.val[1], 3);
+ return sumi;
+}
+
+static IQK_ALWAYS_INLINE int32x4x2_t interleaved_dotq_b16(const int8x16_t * qx, const int8x16x2_t& y) {
+ int32x4x2_t sumi = { vdupq_n_s32(0), vdupq_n_s32(0) };
+ sumi.val[0] = vdotq_laneq_s32(sumi.val[0], qx[0], y.val[0], 0);
+ sumi.val[1] = vdotq_laneq_s32(sumi.val[1], qx[1], y.val[1], 0);
+ sumi.val[0] = vdotq_laneq_s32(sumi.val[0], qx[2], y.val[0], 1);
+ sumi.val[1] = vdotq_laneq_s32(sumi.val[1], qx[3], y.val[1], 1);
+ sumi.val[0] = vdotq_laneq_s32(sumi.val[0], qx[4], y.val[0], 2);
+ sumi.val[1] = vdotq_laneq_s32(sumi.val[1], qx[5], y.val[1], 2);
+ sumi.val[0] = vdotq_laneq_s32(sumi.val[0], qx[6], y.val[0], 3);
+ sumi.val[1] = vdotq_laneq_s32(sumi.val[1], qx[7], y.val[1], 3);
+ return sumi;
+}
+
+static IQK_ALWAYS_INLINE int32x4_t interleaved_dotq(const int8x16_t * qx, const int8x16_t& y) {
+ auto sumi = vdupq_n_s32(0);
+ sumi = vdotq_laneq_s32(sumi, qx[0], y, 0);
+ sumi = vdotq_laneq_s32(sumi, qx[1], y, 1);
+ sumi = vdotq_laneq_s32(sumi, qx[2], y, 2);
+ sumi = vdotq_laneq_s32(sumi, qx[3], y, 3);
+ return sumi;
+}
+
+static IQK_ALWAYS_INLINE void prepare_iq4_nl_quants(const int8x16_t& values, const uint8x16_t& m4, const uint8x16x4_t& bits, int8x16_t * qx) {
+ qx[0] = vqtbl1q_s8(values, vandq_u8(bits.val[0], m4)); // 0...3 from the 4 rows
+ qx[1] = vqtbl1q_s8(values, vandq_u8(bits.val[1], m4)); // 16..19
+ qx[2] = vqtbl1q_s8(values, vandq_u8(bits.val[2], m4)); // 4...7
+ qx[3] = vqtbl1q_s8(values, vandq_u8(bits.val[3], m4)); // 20..23
+ qx[4] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[0], 4)); // 8..11
+ qx[5] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[1], 4)); // 24..27
+ qx[6] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[2], 4)); // 12..15
+ qx[7] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[3], 4)); // 28..31
+}
+
+static IQK_ALWAYS_INLINE void prepare_iq4_nl_quants_r8(const int8x16_t& values, const uint8x16_t& m4, const uint8x16x2_t& bits, int8x16_t * qx) {
+ qx[0] = vqtbl1q_s8(values, vandq_u8( bits.val[0], m4));
+ qx[1] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[0], 4));
+ qx[2] = vqtbl1q_s8(values, vandq_u8( bits.val[1], m4));
+ qx[3] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[1], 4));
+}
+
+#endif
+
#endif