author     Kawrakow <iwankawrakow@gmail.com>          2025-05-22 10:05:51 +0300
committer  GitHub <noreply@github.com>                2025-05-22 10:05:51 +0300
commit     b94cd3b632a78dfb46b18d52b84be66bcf26166a (patch)
tree       b65a5e45edad37f95301174d6971614950a8d489 /ggml
parent     a2b5057a0c9a2758830b6f841bb22150d2511bb1 (diff)
Refactor iqk_mul_mat.cpp (#435)
* Refactor iqk: WIP
* Refactor iqk: Factor out float GEMM (AVX2/AVX512)
* Refactor iqk: Factor out GEMM for legacy quants (AVX2/AVX512)
* Refactor iqk: Factor out GEMM for k-quants (AVX2/AVX512)
* Refactor iqk: fix AVX2
* Refactor iqk: Factor out GEMM for i-quants (AVX2/AVX512)
* Refactor iqk: fix AVX2
* Refactor iqk: Factor out GEMM for iqk-quants (AVX2/AVX512)
* Refactor iqk: fix AVX2
* Refactor iqk: Factor out GEMM for 1-bit quants (AVX2/AVX512)
* Refactor iqk: fix AVX2
* Refactor iqk: Factor out GEMM for iq1_bn, iq2_bn, iq2_bn_r4
* Refactor iqk: Factor out GEMM for repacked legacy quants
* Refactor iqk: Factor out GEMM for q8_K_R8, q8_KV
* Refactor iqk: Factor out GEMM for repacked i-quants
* Refactor iqk: GEMM kernels are refactored on AVX2/AVX512
* Refactor iqk: factor out 1-bit quants (NEON)
* Refactor iqk: factor out k-quants (NEON)
* Refactor iqk: factor out floats (NEON)
* Also iq4_xs belongs to k-quants
* Refactor iqk: factor out iqk quants (NEON)
* Refactor iqk: factor out legacy quants (NEON)
* Refactor iqk: factor out repacked legacy quants (NEON)
* Refactor iqk: factor out repacked k-quants (NEON)
* Refactor iqk: factor out repacked iqk quants (NEON)
* Refactor iqk: GEMM kernels are refactored on NEON
* Refactor iqk: FA compiles
  Whether it works is a different story. Current compile time: 107.3 seconds on the Ryzen-7950X
* Refactor iqk: FA refactored (Zen4)
  Compile time for the FA files is now ~21 seconds on my Ryzen-7950X, so still slightly too long for my taste but much better than the 142 seconds we had before.
* Adding forgotten file
* Most helpers don't need to be templates
  Also hide Q4_0 and Q8_KV behind IQK_FA_ALL_QUANTS. Compilation time drops to 14 seconds on the Ryzen-5975WX
* Fix bf16
* Refactor iqk: FA refactored (NEON)
* Forgotten MMQ ref and typo (#431)
* Adding forgotten iq5_k_r4
* Fix iq4_k_r4 on NEON
* Fix iq4_ks on NEON
  It was broken before the refactoring (the shifts were not correctly applied).
* Fix q8_0 on NEON
* Fix q6_0 K cache
---------
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
Co-authored-by: Nexes the Elder <124105151+Nexesenex@users.noreply.github.com>
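For orientation, here is a minimal, self-contained sketch of the dispatch pattern the new fa/iqk_fa_*.cpp translation units follow (see the files added below): each (Dk, Dv) head-size pair gets its own file, and its IQK_FA_CASE body picks the largest K step (128, 64, or 32) that divides the number of K rows before instantiating the templated helper (the bf16-specific path only considers 64/32). The names flash_helper_stub and fa_case are stand-ins for illustration only, not the actual iqk API.

#include <cstdio>

// Stand-in for iqk_flash_helper_T<Dk, Dv, k_step>(...): just reports which
// instantiation would be selected.
template <int Dk, int Dv, int k_step>
bool flash_helper_stub(int nq, int nk) {
    std::printf("Dk=%d Dv=%d k_step=%d (nq=%d, nk=%d)\n", Dk, Dv, k_step, nq, nk);
    return true;
}

// Stand-in for one IQK_FA_CASE body: choose the largest step that divides nk.
template <int Dk, int Dv>
bool fa_case(int nq, int nk) {
    if (nk % 128 == 0) return flash_helper_stub<Dk, Dv, 128>(nq, nk);
    if (nk %  64 == 0) return flash_helper_stub<Dk, Dv,  64>(nq, nk);
    return flash_helper_stub<Dk, Dv, 32>(nq, nk);
}

int main() {
    fa_case<128, 128>(8, 4096); // 4096 % 128 == 0 -> k_step = 128
    fa_case<192, 128>(8,  576); // 576 = 9*64      -> k_step = 64
    return 0;
}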
Diffstat (limited to 'ggml')
-rw-r--r--  ggml/src/CMakeLists.txt                  |    25
-rw-r--r--  ggml/src/iqk/fa/iqk_fa_128_128.cpp       |    45
-rw-r--r--  ggml/src/iqk/fa/iqk_fa_192_128.cpp       |    45
-rw-r--r--  ggml/src/iqk/fa/iqk_fa_256_256.cpp       |    45
-rw-r--r--  ggml/src/iqk/fa/iqk_fa_576_512.cpp       |   120
-rw-r--r--  ggml/src/iqk/fa/iqk_fa_64_64.cpp         |    45
-rw-r--r--  ggml/src/iqk/fa/iqk_fa_96_96.cpp         |    45
-rw-r--r--  ggml/src/iqk/fa/iqk_fa_templates.h       |  2207
-rw-r--r--  ggml/src/iqk/iqk_common.h                |   695
-rw-r--r--  ggml/src/iqk/iqk_gemm_1bit.cpp           |  2282
-rw-r--r--  ggml/src/iqk/iqk_gemm_1bit.h             |    11
-rw-r--r--  ggml/src/iqk/iqk_gemm_floats.cpp         |  1048
-rw-r--r--  ggml/src/iqk/iqk_gemm_floats.h           |    13
-rw-r--r--  ggml/src/iqk/iqk_gemm_iqk_quants.cpp     |  3289
-rw-r--r--  ggml/src/iqk/iqk_gemm_iqk_quants.h       |    11
-rw-r--r--  ggml/src/iqk/iqk_gemm_iquants.cpp        |  2252
-rw-r--r--  ggml/src/iqk/iqk_gemm_iquants.h          |    11
-rw-r--r--  ggml/src/iqk/iqk_gemm_kquants.cpp        |  3121
-rw-r--r--  ggml/src/iqk/iqk_gemm_kquants.h          |    13
-rw-r--r--  ggml/src/iqk/iqk_gemm_legacy_quants.cpp  |  2763
-rw-r--r--  ggml/src/iqk/iqk_gemm_legacy_quants.h    |    14
-rw-r--r--  ggml/src/iqk/iqk_mul_mat.cpp             | 17996
-rw-r--r--  ggml/src/iqk/iqk_utils.h                 |   207
23 files changed, 18415 insertions, 17888 deletions
diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt
index 14650d03..9872b3de 100644
--- a/ggml/src/CMakeLists.txt
+++ b/ggml/src/CMakeLists.txt
@@ -258,8 +258,29 @@ set (GGML_HEADERS_IQK iqk/iqk_config.h)
if (GGML_IQK_MUL_MAT)
message(STATUS "Using optimized iqk matrix multiplications")
add_compile_definitions(GGML_USE_IQK_MULMAT)
- set(GGML_SOURCES_IQK_MM iqk/iqk_mul_mat.cpp iqk/iqk_flash_attn.cpp)
- set(GGML_HEADERS_IQK_MM iqk/iqk_mul_mat.h iqk/iqk_flash_impl.h)
+ set(GGML_SOURCES_IQK_MM iqk/iqk_mul_mat.cpp
+ iqk/iqk_flash_attn.cpp
+ iqk/fa/iqk_fa_576_512.cpp
+ iqk/fa/iqk_fa_192_128.cpp
+ iqk/fa/iqk_fa_256_256.cpp
+ iqk/fa/iqk_fa_128_128.cpp
+ iqk/fa/iqk_fa_96_96.cpp
+ iqk/fa/iqk_fa_64_64.cpp
+ iqk/iqk_gemm_floats.cpp
+ iqk/iqk_gemm_kquants.cpp
+ iqk/iqk_gemm_iquants.cpp
+ iqk/iqk_gemm_iqk_quants.cpp
+ iqk/iqk_gemm_1bit.cpp
+ iqk/iqk_gemm_legacy_quants.cpp)
+ set(GGML_HEADERS_IQK_MM iqk/iqk_mul_mat.h
+ iqk/iqk_flash_impl.h
+ iqk/fa/iqk_fa_templates.h
+ iqk/iqk_gemm_floats.h
+ iqk/iqk_gemm_kquants.h
+ iqk/iqk_gemm_iquants.h
+ iqk/iqk_gemm_iqk_quants.h
+ iqk/iqk_gemm_1bit.h
+ iqk/iqk_gemm_legacy_quants.h)
if (GGML_IQK_FLASH_ATTENTION)
message(STATUS "Enabling IQK Flash Attention kernels")
add_compile_definitions(GGML_IQK_FLASH_ATTENTION)
diff --git a/ggml/src/iqk/fa/iqk_fa_128_128.cpp b/ggml/src/iqk/fa/iqk_fa_128_128.cpp
new file mode 100644
index 00000000..52eb289d
--- /dev/null
+++ b/ggml/src/iqk/fa/iqk_fa_128_128.cpp
@@ -0,0 +1,45 @@
+#include "iqk/iqk_config.h"
+
+#if defined IQK_IMPLEMENT && defined GGML_IQK_FLASH_ATTENTION
+
+#include "iqk/fa/iqk_fa_templates.h"
+
+IQK_FA_CASE(iqk_fa_128_128) {
+
+ auto type_k = ggml_type(int_type_k);
+ auto type_v = ggml_type(int_type_v);
+
+ stride_q /= sizeof(float); // q stride as float
+ auto ck = (const char *)k;
+ auto cv = (const char *)v;
+ auto cm = (const char *)mask;
+
+#ifdef __AVX512BF16__
+ if (type_k == GGML_TYPE_BF16) {
+ if (type_v != GGML_TYPE_BF16) return false; // we do not support mixing bf16 k-cache with other types
+ if (nk%64 == 0) {
+ iqk_flash_helper_T<128, 128, 64>(nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
+ q, ck, cv, cm, scale, softcap, qkv, M, S);
+ return true;
+ }
+ iqk_flash_helper_T<128, 128, 32>(nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
+ q, ck, cv, cm, scale, softcap, qkv, M, S);
+ return true;
+ }
+#endif
+
+ if (nk%128 == 0) {
+ return iqk_flash_helper_T<128, 128, 128>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
+ q, ck, cv, cm, scale, softcap, qkv, M, S);
+ }
+ if (nk%64 == 0) {
+ return iqk_flash_helper_T<128, 128, 64>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
+ q, ck, cv, cm, scale, softcap, qkv, M, S);
+ }
+
+ return iqk_flash_helper_T<128, 128, 32>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
+ q, ck, cv, cm, scale, softcap, qkv, M, S);
+
+}
+
+#endif
diff --git a/ggml/src/iqk/fa/iqk_fa_192_128.cpp b/ggml/src/iqk/fa/iqk_fa_192_128.cpp
new file mode 100644
index 00000000..6c4c51fb
--- /dev/null
+++ b/ggml/src/iqk/fa/iqk_fa_192_128.cpp
@@ -0,0 +1,45 @@
+#include "iqk/iqk_config.h"
+
+#if defined IQK_IMPLEMENT && defined GGML_IQK_FLASH_ATTENTION
+
+#include "iqk/fa/iqk_fa_templates.h"
+
+IQK_FA_CASE(iqk_fa_192_128) {
+
+ auto type_k = ggml_type(int_type_k);
+ auto type_v = ggml_type(int_type_v);
+
+ stride_q /= sizeof(float); // q stride as float
+ auto ck = (const char *)k;
+ auto cv = (const char *)v;
+ auto cm = (const char *)mask;
+
+#ifdef __AVX512BF16__
+ if (type_k == GGML_TYPE_BF16) {
+ if (type_v != GGML_TYPE_BF16) return false; // we do not support mixing bf16 k-cache with other types
+ if (nk%64 == 0) {
+ iqk_flash_helper_T<192, 128, 64>(nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
+ q, ck, cv, cm, scale, softcap, qkv, M, S);
+ return true;
+ }
+ iqk_flash_helper_T<192, 128, 32>(nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
+ q, ck, cv, cm, scale, softcap, qkv, M, S);
+ return true;
+ }
+#endif
+
+ if (nk%128 == 0) {
+ return iqk_flash_helper_T<192, 128, 128>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
+ q, ck, cv, cm, scale, softcap, qkv, M, S);
+ }
+ if (nk%64 == 0) {
+ return iqk_flash_helper_T<192, 128, 64>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
+ q, ck, cv, cm, scale, softcap, qkv, M, S);
+ }
+
+ return iqk_flash_helper_T<192, 128, 32>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
+ q, ck, cv, cm, scale, softcap, qkv, M, S);
+
+}
+
+#endif
diff --git a/ggml/src/iqk/fa/iqk_fa_256_256.cpp b/ggml/src/iqk/fa/iqk_fa_256_256.cpp
new file mode 100644
index 00000000..b0bc35e3
--- /dev/null
+++ b/ggml/src/iqk/fa/iqk_fa_256_256.cpp
@@ -0,0 +1,45 @@
+#include "iqk/iqk_config.h"
+
+#if defined IQK_IMPLEMENT && defined GGML_IQK_FLASH_ATTENTION
+
+#include "iqk/fa/iqk_fa_templates.h"
+
+IQK_FA_CASE(iqk_fa_256_256) {
+
+ auto type_k = ggml_type(int_type_k);
+ auto type_v = ggml_type(int_type_v);
+
+ stride_q /= sizeof(float); // q stride as float
+ auto ck = (const char *)k;
+ auto cv = (const char *)v;
+ auto cm = (const char *)mask;
+
+#ifdef __AVX512BF16__
+ if (type_k == GGML_TYPE_BF16) {
+ if (type_v != GGML_TYPE_BF16) return false; // we do not support mixing bf16 k-cache with other types
+ if (nk%64 == 0) {
+ iqk_flash_helper_T<256, 256, 64>(nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
+ q, ck, cv, cm, scale, softcap, qkv, M, S);
+ return true;
+ }
+ iqk_flash_helper_T<256, 256, 32>(nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
+ q, ck, cv, cm, scale, softcap, qkv, M, S);
+ return true;
+ }
+#endif
+
+ if (nk%128 == 0) {
+ return iqk_flash_helper_T<256, 256, 128>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
+ q, ck, cv, cm, scale, softcap, qkv, M, S);
+ }
+ if (nk%64 == 0) {
+ return iqk_flash_helper_T<256, 256, 64>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
+ q, ck, cv, cm, scale, softcap, qkv, M, S);
+ }
+
+ return iqk_flash_helper_T<256, 256, 32>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
+ q, ck, cv, cm, scale, softcap, qkv, M, S);
+
+}
+
+#endif
diff --git a/ggml/src/iqk/fa/iqk_fa_576_512.cpp b/ggml/src/iqk/fa/iqk_fa_576_512.cpp
new file mode 100644
index 00000000..5174be30
--- /dev/null
+++ b/ggml/src/iqk/fa/iqk_fa_576_512.cpp
@@ -0,0 +1,120 @@
+#include "iqk/iqk_config.h"
+
+#if defined IQK_IMPLEMENT && defined GGML_IQK_FLASH_ATTENTION
+
+#include "iqk/fa/iqk_fa_templates.h"
+
+namespace {
+
+template <int step_k, typename KHelper, typename VHelper>
+inline void iqk_deepseek_helper(KHelper& kh, VHelper& vh,
+ int nq1, int nk1, int stride_q, int stride_m, int stride_qkv,
+ const float * q, const char * mask, float scale, float softcap, float * qkv, float * M, float * S) {
+ auto update = [&nq1, &mask, &q, &qkv, &M, &S, stride_q, stride_m, stride_qkv] (int n) {
+ nq1 -= n;
+ if (nq1 == 0) return true;
+ q += n*stride_q;
+ mask += n*stride_m;
+ qkv += n*stride_qkv;
+ if (M && S) { M += n; S += n; }
+ return false;
+ };
+ if (nq1 >= 16) {
+ int n_step = nq1/16;
+ FlashAttn<576, 512, 16, step_k> fa(scale, softcap);
+ fa.compute(kh, vh, 16*n_step, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S);
+ if (update(16*n_step)) return;
+ }
+ if (nq1 >= 8) {
+ int n_step = nq1/8;
+ FlashAttn<576, 512, 8, step_k> fa(scale, softcap);
+ fa.compute(kh, vh, 8*n_step, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S);
+ if (update(8*n_step)) return;
+ }
+ if (nq1 >= 4) {
+ int n_step = nq1/4;
+ FlashAttn<576, 512, 4, step_k> fa(scale, softcap);
+ fa.compute(kh, vh, 4*n_step, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S);
+ if (update(4*n_step)) return;
+ }
+ if (nq1 >= 2) {
+ int n_step = nq1/2;
+ FlashAttn<576, 512, 2, step_k> fa(scale, softcap);
+ fa.compute(kh, vh, 2*n_step, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S);
+ if (update(2*n_step)) return;
+ }
+ FlashAttn<576, 512, 1, step_k> fa(scale, softcap);
+ fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S);
+}
+
+template <int step_k>
+inline bool iqk_deepseek_helper(ggml_type type_k,
+ int nq1, int nk1, int stride_q, int stride_k, int stride_v, int stride_m, int stride_qkv,
+ const float * q, const char * k, const char * v, const char * mask,
+ float scale, float softcap, float * qkv, float * M, float * S) {
+ if (type_k == GGML_TYPE_Q8_0) {
+ HelperQ80 kh((const char *)k, stride_k);
+ HelperQ80 vh((const char *)v, stride_v);
+ iqk_deepseek_helper<step_k>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S);
+ return true;
+ }
+ if (type_k == GGML_TYPE_Q8_0_R8) {
+ HelperQ80R8<576> kh((const char *)k, stride_k);
+ HelperQ80 vh((const char *)v, stride_v);
+ iqk_deepseek_helper<step_k>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S);
+ return true;
+ }
+ if (type_k == GGML_TYPE_Q6_0) {
+ HelperQ60 kh((const char *)k, stride_k);
+ HelperQ60 vh((const char *)v, stride_v);
+ iqk_deepseek_helper<step_k>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S);
+ return true;
+ }
+#if GGML_IQK_FA_ALL_QUANTS
+ if (type_k == GGML_TYPE_Q8_KV) {
+ HelperQ8KV<576> kh((const char *)k, stride_k);
+ HelperQ8KV<512> vh((const char *)v, stride_v);
+ iqk_deepseek_helper<step_k>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S);
+ return true;
+ }
+#endif
+ if (type_k == GGML_TYPE_F16) {
+ HelperF16 kh((const char *)k, stride_k);
+ HelperF16 vh((const char *)v, stride_v);
+ iqk_deepseek_helper<step_k>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S);
+ return true;
+ }
+#ifdef __AVX512BF16__
+ if (type_k == GGML_TYPE_BF16) {
+ HelperBF16<576, step_k> kh((const char *)k, stride_k);
+ HelperBF16<512, step_k> vh((const char *)v, stride_v);
+ if (nq1 % 8 == 0) {
+ FlashAttnBF16<576, 512, 8, step_k> fa(scale, softcap);
+ fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
+ } else {
+ FlashAttnBF16<576, 512, 1, step_k> fa(scale, softcap);
+ fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
+ }
+ return true;
+ }
+#endif
+ return false;
+}
+
+}
+
+IQK_FA_CASE(iqk_fa_576_512) {
+
+ auto type_k = ggml_type(int_type_k);
+ auto type_v = ggml_type(int_type_v);
+
+ if (!(type_k == type_v || (type_k == GGML_TYPE_Q8_0_R8 && type_v == GGML_TYPE_Q8_0))) {
+ return false;
+ }
+ stride_q /= sizeof(float); // q stride as float
+ return iqk_deepseek_helper<32>(type_k, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
+ q, (const char *)k, (const char *)v, (const char *)mask, scale, softcap, qkv, M, S);
+
+}
+
+#endif
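The 576/512 case above (the DeepSeek MLA head sizes) additionally tiles the query rows greedily: as many 16-row tiles as possible, then 8, 4, 2, and finally single rows, each tile size using its own FlashAttn<576, 512, q_step, step_k> instantiation. Below is a stand-alone toy that reproduces only that tiling arithmetic; tile_queries is a made-up name, not part of the patch.

#include <cstdio>

// Greedy decomposition of nq1 query rows into tiles of 16, 8, 4, 2 and 1 rows,
// mirroring the update() lambda and the descending if-chain in iqk_deepseek_helper.
void tile_queries(int nq1) {
    const int steps[] = {16, 8, 4, 2, 1};
    for (int q_step : steps) {
        if (nq1 < q_step) continue;
        int n_step = nq1 / q_step;      // number of tiles of this size
        std::printf("%d tile(s) of %d row(s)\n", n_step, q_step);
        nq1 -= q_step * n_step;         // the real code also advances q/mask/qkv pointers here
        if (nq1 == 0) return;
    }
}

int main() {
    tile_queries(23); // -> 1 tile of 16, 1 of 4, 1 of 2, 1 of 1
    return 0;
}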
diff --git a/ggml/src/iqk/fa/iqk_fa_64_64.cpp b/ggml/src/iqk/fa/iqk_fa_64_64.cpp
new file mode 100644
index 00000000..652f682b
--- /dev/null
+++ b/ggml/src/iqk/fa/iqk_fa_64_64.cpp
@@ -0,0 +1,45 @@
+#include "iqk/iqk_config.h"
+
+#if defined IQK_IMPLEMENT && defined GGML_IQK_FLASH_ATTENTION
+
+#include "iqk/fa/iqk_fa_templates.h"
+
+IQK_FA_CASE(iqk_fa_64_64) {
+
+ auto type_k = ggml_type(int_type_k);
+ auto type_v = ggml_type(int_type_v);
+
+ stride_q /= sizeof(float); // q stride as float
+ auto ck = (const char *)k;
+ auto cv = (const char *)v;
+ auto cm = (const char *)mask;
+
+#ifdef __AVX512BF16__
+ if (type_k == GGML_TYPE_BF16) {
+ if (type_v != GGML_TYPE_BF16) return false; // we do not support mixing bf16 k-cache with other types
+ if (nk%64 == 0) {
+ iqk_flash_helper_T<64, 64, 64>(nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
+ q, ck, cv, cm, scale, softcap, qkv, M, S);
+ return true;
+ }
+ iqk_flash_helper_T<64, 64, 32>(nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
+ q, ck, cv, cm, scale, softcap, qkv, M, S);
+ return true;
+ }
+#endif
+
+ if (nk%128 == 0) {
+ return iqk_flash_helper_T<64, 64, 128>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
+ q, ck, cv, cm, scale, softcap, qkv, M, S);
+ }
+ if (nk%64 == 0) {
+ return iqk_flash_helper_T<64, 64, 64>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
+ q, ck, cv, cm, scale, softcap, qkv, M, S);
+ }
+
+ return iqk_flash_helper_T<64, 64, 32>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
+ q, ck, cv, cm, scale, softcap, qkv, M, S);
+
+}
+
+#endif
diff --git a/ggml/src/iqk/fa/iqk_fa_96_96.cpp b/ggml/src/iqk/fa/iqk_fa_96_96.cpp
new file mode 100644
index 00000000..fed49cb0
--- /dev/null
+++ b/ggml/src/iqk/fa/iqk_fa_96_96.cpp
@@ -0,0 +1,45 @@
+#include "iqk/iqk_config.h"
+
+#if defined IQK_IMPLEMENT && defined GGML_IQK_FLASH_ATTENTION
+
+#include "iqk/fa/iqk_fa_templates.h"
+
+IQK_FA_CASE(iqk_fa_96_96) {
+
+ auto type_k = ggml_type(int_type_k);
+ auto type_v = ggml_type(int_type_v);
+
+ stride_q /= sizeof(float); // q stride as float
+ auto ck = (const char *)k;
+ auto cv = (const char *)v;
+ auto cm = (const char *)mask;
+
+#ifdef __AVX512BF16__
+ if (type_k == GGML_TYPE_BF16) {
+ if (type_v != GGML_TYPE_BF16) return false; // we do not support mixing bf16 k-cache with other types
+ if (nk%64 == 0) {
+ iqk_flash_helper_T<96, 96, 64>(nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
+ q, ck, cv, cm, scale, softcap, qkv, M, S);
+ return true;
+ }
+ iqk_flash_helper_T<96, 96, 32>(nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
+ q, ck, cv, cm, scale, softcap, qkv, M, S);
+ return true;
+ }
+#endif
+
+ if (nk%128 == 0) {
+ return iqk_flash_helper_T<96, 96, 128>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
+ q, ck, cv, cm, scale, softcap, qkv, M, S);
+ }
+ if (nk%64 == 0) {
+ return iqk_flash_helper_T<96, 96, 64>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
+ q, ck, cv, cm, scale, softcap, qkv, M, S);
+ }
+
+ return iqk_flash_helper_T<96, 96, 32>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
+ q, ck, cv, cm, scale, softcap, qkv, M, S);
+
+}
+
+#endif
diff --git a/ggml/src/iqk/fa/iqk_fa_templates.h b/ggml/src/iqk/fa/iqk_fa_templates.h
new file mode 100644
index 00000000..6de2acea
--- /dev/null
+++ b/ggml/src/iqk/fa/iqk_fa_templates.h
@@ -0,0 +1,2207 @@
+// -*- mode:c++;indent-tabs-mode:nil;c-basic-offset:4;coding:utf-8 -*-
+// vi: set et ft=cpp fenc=utf-8 :vi
+//
+//
+// Copyright (C) 2024 Iwan Kawrakow
+// MIT license
+// SPDX-License-Identifier: MIT
+//
+
+#pragma once
+
+#include "iqk/iqk_config.h"
+
+#if defined IQK_IMPLEMENT && defined GGML_IQK_FLASH_ATTENTION
+
+#include <cstring>
+#include <type_traits>
+#include <vector>
+
+#include "ggml-impl.h"
+#include "ggml-quants.h"
+#include "iqk/iqk_quantize.h"
+#include "iqk/iqk_gemm_floats.h"
+#include "iqk/iqk_gemm_kquants.h"
+#include "iqk/iqk_gemm_legacy_quants.h"
+#include "iqk/iqk_utils.h"
+
+#define GGML_COMMON_IMPL_C
+#include "ggml-common.h"
+
+// clang-format off
+
+namespace {
+
+struct BaseHelper {
+ BaseHelper(const char * data, int stride) : data(data), block(data), stride(stride) {}
+
+ //inline void set_block(int k1) { block = data + k1*k_step*stride; }
+ inline void reset_block() { block = data; }
+ inline void next_block(int step) { block += step*stride; }
+ inline const char * lblock(int l1) const { return block + l1*stride; }
+
+ const char * data;
+ const char * block;
+ int stride;
+
+};
+
+struct F16 {
+#ifdef __AVX512F__
+ using Data = __m512;
+ constexpr static int block_size = 16;
+ constexpr static int num_registers = 32;
+ constexpr static int q_step = 8;
+ static inline Data zero() { return _mm512_setzero_ps(); }
+ static inline Data load(const char * ptr, int i) { return _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)ptr + i)); }
+ static inline Data set1(float val) { return _mm512_set1_ps(val); }
+ static inline Data mul(Data v1, Data v2) { return _mm512_mul_ps(v1, v2); }
+ static inline Data sub(Data v1, Data v2) { return _mm512_sub_ps(v1, v2); }
+ static inline Data load(const float * ptr) { return _mm512_loadu_ps(ptr); }
+ static inline void store(float * ptr, Data data) { _mm512_storeu_ps(ptr, data); }
+ static inline Data fmadd(Data prev, Data v1, Data v2) { return _mm512_fmadd_ps(v1, v2, prev); }
+ static inline float reduce_max(Data data) { return _mm512_reduce_max_ps(data); }
+ static inline float reduce_add(Data data) { return _mm512_reduce_add_ps(data); }
+ static inline Data max(Data v1, Data v2) { return _mm512_max_ps(v1, v2); }
+ static inline Data add(Data v1, Data v2) { return _mm512_add_ps(v1, v2); }
+ static inline Data set4(const float * ptr) {
+ auto v128 = _mm_loadu_ps(ptr);
+ auto v256 = _mm256_set_m128(v128, v128);
+ return _mm512_insertf32x8(_mm512_castps256_ps512(v256), v256, 1);
+ }
+ static inline void set4(const float * ptr, Data * vs) {
+ auto v = set4(ptr);
+ vs[0] = _mm512_shuffle_ps(v, v, 0x00);
+ vs[1] = _mm512_shuffle_ps(v, v, 0x55);
+ vs[2] = _mm512_shuffle_ps(v, v, 0xaa);
+ vs[3] = _mm512_shuffle_ps(v, v, 0xff);
+ }
+ static inline Data fmadd_lane0(Data prev, Data v1, Data v2) { return _mm512_fmadd_ps(v1, _mm512_shuffle_ps(v2, v2, 0x00), prev); }
+ static inline Data fmadd_lane1(Data prev, Data v1, Data v2) { return _mm512_fmadd_ps(v1, _mm512_shuffle_ps(v2, v2, 0x55), prev); }
+ static inline Data fmadd_lane2(Data prev, Data v1, Data v2) { return _mm512_fmadd_ps(v1, _mm512_shuffle_ps(v2, v2, 0xaa), prev); }
+ static inline Data fmadd_lane3(Data prev, Data v1, Data v2) { return _mm512_fmadd_ps(v1, _mm512_shuffle_ps(v2, v2, 0xff), prev); }
+#elif defined __AVX2__
+ using Data = __m256;
+ constexpr static int block_size = 8;
+ constexpr static int num_registers = 16;
+ constexpr static int q_step = 8;
+ static inline Data zero() { return _mm256_setzero_ps(); }
+ static inline Data load(const char * ptr, int i) { return _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)ptr + i)); }
+ static inline Data set1(float val) { return _mm256_set1_ps(val); }
+ static inline Data mul(Data v1, Data v2) { return _mm256_mul_ps(v1, v2); }
+ static inline Data load(const float * ptr) { return _mm256_loadu_ps(ptr); }
+ static inline Data sub(Data v1, Data v2) { return _mm256_sub_ps(v1, v2); }
+ static inline void store(float * ptr, Data data) { _mm256_storeu_ps(ptr, data); }
+ static inline Data fmadd(Data prev, Data v1, Data v2) { return _mm256_fmadd_ps(v1, v2, prev); }
+ static inline float reduce_max(Data data) { return hmax_float_8(data); }
+ static inline float reduce_add(Data data) { return hsum_float_8(data); }
+ static inline Data max(Data v1, Data v2) { return _mm256_max_ps(v1, v2); }
+ static inline Data add(Data v1, Data v2) { return _mm256_add_ps(v1, v2); }
+ static inline Data set4(const float * ptr) {
+ auto v128 = _mm_loadu_ps(ptr);
+ return _mm256_set_m128(v128, v128);
+ }
+ static inline void set4(const float * ptr, Data * vs) {
+ auto v = set4(ptr);
+ vs[0] = _mm256_shuffle_ps(v, v, 0x00);
+ vs[1] = _mm256_shuffle_ps(v, v, 0x55);
+ vs[2] = _mm256_shuffle_ps(v, v, 0xaa);
+ vs[3] = _mm256_shuffle_ps(v, v, 0xff);
+ }
+ static inline Data fmadd_lane0(Data prev, Data v1, Data v2) { return _mm256_fmadd_ps(v1, _mm256_shuffle_ps(v2, v2, 0x00), prev); }
+ static inline Data fmadd_lane1(Data prev, Data v1, Data v2) { return _mm256_fmadd_ps(v1, _mm256_shuffle_ps(v2, v2, 0x55), prev); }
+ static inline Data fmadd_lane2(Data prev, Data v1, Data v2) { return _mm256_fmadd_ps(v1, _mm256_shuffle_ps(v2, v2, 0xaa), prev); }
+ static inline Data fmadd_lane3(Data prev, Data v1, Data v2) { return _mm256_fmadd_ps(v1, _mm256_shuffle_ps(v2, v2, 0xff), prev); }
+#else
+ using Data = float16x8_t;
+ constexpr static int block_size = 8;
+ //constexpr static int num_registers = 32;
+ //constexpr static int q_step = 8;
+ static inline Data zero() { return vdupq_n_f16(0); }
+ static inline Data load(const char * ptr, int i) { return vld1q_f16((const float16_t *)ptr + block_size*i); }
+ static inline Data load(const float16_t * ptr, int i) { return vld1q_f16(ptr + block_size*i); }
+ static inline Data load(const float16_t * ptr) { return vld1q_f16(ptr); }
+ static inline Data load(const float * ptr) {
+ auto val1 = vld1q_f32(ptr);
+ auto val2 = vld1q_f32(ptr+4);
+ return vcombine_f16(vcvt_f16_f32(val1), vcvt_f16_f32(val2));
+ }
+ static inline Data set1(float val) { return vdupq_n_f16(val); }
+ static inline Data mul(Data v1, Data v2) { return vmulq_f16(v1, v2); }
+ static inline Data sub(Data v1, Data v2) { return vsubq_f16(v1, v2); }
+ static inline void store(float * ptr, Data data) {
+ vst1q_f32(ptr+0, vcvt_f32_f16(vget_low_f16(data)));
+ vst1q_f32(ptr+4, vcvt_f32_f16(vget_high_f16(data)));
+ }
+ static inline void store(float16_t * ptr, Data data) { vst1q_f16(ptr, data); }
+ static inline void store(float * ptr, float32x4_t data) { vst1q_f32(ptr, data); }
+ static inline Data fmadd(Data prev, Data v1, Data v2) { return vfmaq_f16(prev, v1, v2); }
+ static inline float reduce_max(Data data) { return vmaxvq_f16(data); }
+ static inline float reduce_add(Data data) {
+ auto sum = vadd_f16(vget_low_f16(data), vget_high_f16(data));
+ return vaddvq_f32(vcvt_f32_f16(sum));
+ }
+ static inline Data max(Data v1, Data v2) { return vmaxq_f16(v1, v2); }
+ static inline Data add(Data v1, Data v2) { return vaddq_f16(v1, v2); }
+ static inline float16x4_t set4(const float * ptr) {
+ auto val32 = vld1q_f32(ptr);
+ return vcvt_f16_f32(val32);
+ }
+ static inline Data fmadd_lane0(Data prev, Data v1, float16x4_t v2) { return vfmaq_lane_f16(prev, v1, v2, 0); }
+ static inline Data fmadd_lane1(Data prev, Data v1, float16x4_t v2) { return vfmaq_lane_f16(prev, v1, v2, 1); }
+ static inline Data fmadd_lane2(Data prev, Data v1, float16x4_t v2) { return vfmaq_lane_f16(prev, v1, v2, 2); }
+ static inline Data fmadd_lane3(Data prev, Data v1, float16x4_t v2) { return vfmaq_lane_f16(prev, v1, v2, 3); }
+#endif
+ template <int k_step> static inline float reduce_max(const Data * data) {
+ return reduce_T<k_step, &F16::max, &F16::reduce_max>(data);
+ }
+ template <int k_step> static inline float reduce_add(const Data * data) {
+ return reduce_T<k_step, &F16::add, &F16::reduce_add>(data);
+ }
+ template <int k_step, Data (*Op_combine)(Data, Data), float (*Op)(Data)>
+ static float reduce_T(const Data * data) {
+ float result;
+ if constexpr (k_step/block_size == 1) {
+ result = Op(data[0]);
+ }
+ else if constexpr (k_step/block_size == 2) {
+ result = Op(Op_combine(data[0], data[1]));
+ }
+ else {
+ auto vmax = Op_combine(data[0], data[1]);
+ for (int l = 2; l < k_step/block_size; ++l) vmax = Op_combine(vmax, data[l]);
+ result = Op(vmax);
+ }
+ return result;
+ }
+};
+
+struct HelperF16 final : public BaseHelper {
+ using Base = BaseHelper;
+ HelperF16(const char * data, int stride) : Base(data, stride) {}
+
+ inline void load(int l1, int i, F16::Data& v1, F16::Data& v2) const {
+ //auto dr = (const ggml_half *)Base::lblock(l1);
+ auto dr = Base::lblock(l1);
+ v1 = F16::load(dr, i + 0);
+ v2 = F16::load(dr, i + 1);
+ }
+};
+
+template <int D> struct block_q8_KV {
+ float d;
+ int s;
+ int8_t qs[D];
+};
+
+template <int D>
+struct HelperQ8KV final : public BaseHelper {
+ using Base = BaseHelper;
+ using block_q8 = block_q8_KV<D>;
+ constexpr static ggml_type type = GGML_TYPE_Q8_KV;
+ constexpr static int block_size_q = D;
+ HelperQ8KV(const char * data, int stride) : Base(data, stride) {}
+
+ // Needed for v * softmax(k * q)
+ inline void load(int l1, int i, F16::Data& v1, F16::Data& v2) const {
+ auto q8 = (const block_q8_KV<D> *)Base::lblock(l1);
+#ifdef __aarch64__
+ auto vd = F16::set1(q8->d);
+ auto qs = vld1_s8_x2(q8->qs + 8*i);
+ v1 = vmulq_f16(vd, vcvtq_f16_s16(vmovl_s8(qs.val[0])));
+ v2 = vmulq_f16(vd, vcvtq_f16_s16(vmovl_s8(qs.val[1])));
+#else
+ auto vd = F16::set1(q8->d);
+#ifdef __AVX512F__
+ v1 = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)q8->qs+i+0))));
+ v2 = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)q8->qs+i+1))));
+#else
+ v1 = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(q8->qs+8*i+0)))));
+ v2 = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(q8->qs+8*i+8)))));
+#endif
+#endif
+ }
+};
+
+struct HelperQ80 final : public BaseHelper {
+ using Base = BaseHelper;
+ constexpr static ggml_type type = GGML_TYPE_Q8_0;
+#ifdef HAVE_FANCY_SIMD
+ using block_q8 = block_q8_2;
+ constexpr static int block_size_q = QK8_2;
+#else
+ using block_q8 = block_q8_0;
+ constexpr static int block_size_q = QK8_0;
+#endif
+ HelperQ80(const char * data, int stride) : Base(data, stride) {}
+
+ // Needed for v * softmax(k * q)
+ inline void load(int l1, int i, F16::Data& v1, F16::Data& v2) const {
+ int j = F16::block_size*i;
+ auto dl = (const block_q8_0 *)Base::lblock(l1) + j/QK8_0;
+#ifdef __aarch64__
+ auto vd = F16::set1(GGML_FP16_TO_FP32(dl->d));
+ int ii = j%QK8_0;
+ auto qs = vld1_s8_x2(dl->qs + ii);
+ v1 = vmulq_f16(vd, vcvtq_f16_s16(vmovl_s8(qs.val[0])));
+ v2 = vmulq_f16(vd, vcvtq_f16_s16(vmovl_s8(qs.val[1])));
+#else
+ auto vd = F16::set1(GGML_FP16_TO_FP32(dl->d));
+#ifdef __AVX512F__
+ v1 = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)dl->qs+0))));
+ v2 = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)dl->qs+1))));
+#else
+ int ii = j%QK8_0;
+ v1 = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(dl->qs+ii+0)))));
+ v2 = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(dl->qs+ii+8)))));
+#endif
+#endif
+ }
+
+ template <int D>
+ static inline void convert(int nq, int stride_q, const float * q, block_q8_0 * y) {
+ for (int i = 0; i < nq; ++i) {
+ quantize_row_q8_0_x4(q, y, D);
+ q += stride_q;
+ y += D/QK8_0;
+ }
+ }
+
+ template <int D>
+ static inline void convert(int nq, int stride_q, const float * q, block_q8_1 * y) {
+ for (int i = 0; i < nq; ++i) {
+ quantize_row_q8_1_x4(q, y, D);
+ q += stride_q;
+ y += D/QK8_1;
+ }
+ }
+
+ template <int D>
+ static inline void convert(int nq, int stride_q, const float * q, block_q8_2 * y) {
+ for (int i = 0; i < nq; ++i) {
+ quantize_row_q8_2_x4(q, y, D);
+ q += stride_q;
+ y += D/QK8_2;
+ }
+ }
+
+ template <int D>
+ static inline void convert(int nq, int stride_q, const float * q, block_q8_KV<D> * y) {
+ for (int i = 0; i < nq; ++i) {
+ quantize_row_q8_KV(q, y, D);
+ q += stride_q;
+ ++y;
+ }
+ }
+};
+
+template <int D>
+struct HelperQ80R8 : public BaseHelper {
+ using Base = BaseHelper;
+ constexpr static ggml_type type = GGML_TYPE_Q8_0_R8;
+#ifdef __AVX2__
+ constexpr static int block_size_q = QK8_2;
+ using block_q8 = block_q8_2;
+#else
+ constexpr static int block_size_q = QK8_0;
+ using block_q8 = block_q8_0;
+#endif
+ HelperQ80R8(const char * data, int stride) : Base(data, stride) {}
+ HelperQ80R8(int nk, const HelperQ80& q8) : Base(q8.data, q8.stride) {
+ r4 = repack(nk, q8);
+ Base::data = (const char *)r4.data();
+ Base::stride = (D/QK8_0)*sizeof(block_q8_0);
+ }
+
+ static void repack(int nk, const char * q8_data, int q8_stride, block_q8_0_r8 * y) {
+ constexpr int nblock = D/QK8_0;
+ const block_q8_0 * x8[8];
+#ifdef __ARM_NEON
+ int8x16x2_t m0, m1, m2, m3;
+#endif
+ for (int row = 0; row < nk; row += 8) {
+ for (int k = 0; k < 8; ++k) x8[k] = (const block_q8_0 *)(q8_data + (row + k)*q8_stride);
+ for (int ib = 0; ib < nblock; ++ib) {
+ for (int k = 0; k < 8; ++k) y[ib].d[k] = x8[k][ib].d;
+#ifdef __AVX2__
+ auto m0 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[4][ib].qs), _mm_loadu_si128((const __m128i *)x8[0][ib].qs));
+ auto m1 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[5][ib].qs), _mm_loadu_si128((const __m128i *)x8[1][ib].qs));
+ auto m2 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[6][ib].qs), _mm_loadu_si128((const __m128i *)x8[2][ib].qs));
+ auto m3 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[7][ib].qs), _mm_loadu_si128((const __m128i *)x8[3][ib].qs));
+ auto t0 = _mm256_unpacklo_epi32(m0, m1);
+ auto t1 = _mm256_unpacklo_epi32(m2, m3);
+ auto t2 = _mm256_unpackhi_epi32(m0, m1);
+ auto t3 = _mm256_unpackhi_epi32(m2, m3);
+ m0 = _mm256_unpacklo_epi64(t0, t1);
+ m1 = _mm256_unpackhi_epi64(t0, t1);
+ m2 = _mm256_unpacklo_epi64(t2, t3);
+ m3 = _mm256_unpackhi_epi64(t2, t3);
+//#ifdef HAVE_FANCY_SIMD
+// m0 = _mm256_add_epi8(m0, _mm256_set1_epi8(127));
+// m1 = _mm256_add_epi8(m1, _mm256_set1_epi8(127));
+// m2 = _mm256_add_epi8(m2, _mm256_set1_epi8(127));
+// m3 = _mm256_add_epi8(m3, _mm256_set1_epi8(127));
+//#endif
+ _mm256_storeu_si256((__m256i *)y[ib].qs + 0, m0);
+ _mm256_storeu_si256((__m256i *)y[ib].qs + 1, m1);
+ _mm256_storeu_si256((__m256i *)y[ib].qs + 2, m2);
+ _mm256_storeu_si256((__m256i *)y[ib].qs + 3, m3);
+ m0 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[4][ib].qs+1), _mm_loadu_si128((const __m128i *)x8[0][ib].qs+1));
+ m1 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[5][ib].qs+1), _mm_loadu_si128((const __m128i *)x8[1][ib].qs+1));
+ m2 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[6][ib].qs+1), _mm_loadu_si128((const __m128i *)x8[2][ib].qs+1));
+ m3 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[7][ib].qs+1), _mm_loadu_si128((const __m128i *)x8[3][ib].qs+1));
+ t0 = _mm256_unpacklo_epi32(m0, m1);
+ t1 = _mm256_unpacklo_epi32(m2, m3);
+ t2 = _mm256_unpackhi_epi32(m0, m1);
+ t3 = _mm256_unpackhi_epi32(m2, m3);
+ m0 = _mm256_unpacklo_epi64(t0, t1);
+ m1 = _mm256_unpackhi_epi64(t0, t1);
+ m2 = _mm256_unpacklo_epi64(t2, t3);
+ m3 = _mm256_unpackhi_epi64(t2, t3);
+//#ifdef HAVE_FANCY_SIMD
+// m0 = _mm256_add_epi8(m0, _mm256_set1_epi8(127));
+// m1 = _mm256_add_epi8(m1, _mm256_set1_epi8(127));
+// m2 = _mm256_add_epi8(m2, _mm256_set1_epi8(127));
+// m3 = _mm256_add_epi8(m3, _mm256_set1_epi8(127));
+//#endif
+ _mm256_storeu_si256((__m256i *)y[ib].qs + 4, m0);
+ _mm256_storeu_si256((__m256i *)y[ib].qs + 5, m1);
+ _mm256_storeu_si256((__m256i *)y[ib].qs + 6, m2);
+ _mm256_storeu_si256((__m256i *)y[ib].qs + 7, m3);
+#elif defined __ARM_NEON
+ for (int l = 0; l < 2; ++l) {
+ m0.val[0] = vld1q_s8(x8[0][ib].qs+16*l); m0.val[1] = vld1q_s8(x8[4][ib].qs+16*l);
+ m1.val[0] = vld1q_s8(x8[1][ib].qs+16*l); m1.val[1] = vld1q_s8(x8[5][ib].qs+16*l);
+ m2.val[0] = vld1q_s8(x8[2][ib].qs+16*l); m2.val[1] = vld1q_s8(x8[6][ib].qs+16*l);
+ m3.val[0] = vld1q_s8(x8[3][ib].qs+16*l); m3.val[1] = vld1q_s8(x8[7][ib].qs+16*l);
+ auto row01 = vtrnq_s32(vreinterpretq_s32_s8(m0.val[0]), vreinterpretq_s32_s8(m1.val[0]));
+ auto row23 = vtrnq_s32(vreinterpretq_s32_s8(m2.val[0]), vreinterpretq_s32_s8(m3.val[0]));
+ m0.val[0] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0])));
+ m1.val[0] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1])));
+ m2.val[0] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0])));
+ m3.val[0] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1])));
+ row01 = vtrnq_s32(vreinterpretq_s32_s8(m0.val[1]), vreinterpretq_s32_s8(m1.val[1]));
+ row23 = vtrnq_s32(vreinterpretq_s32_s8(m2.val[1]), vreinterpretq_s32_s8(m3.val[1]));
+ m0.val[1] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0])));
+ m1.val[1] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1])));
+ m2.val[1] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0])));
+ m3.val[1] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1])));
+ vst1q_s8_x2(y[ib].qs + 0 + 128*l, m0);
+ vst1q_s8_x2(y[ib].qs + 32 + 128*l, m1);
+ vst1q_s8_x2(y[ib].qs + 64 + 128*l, m2);
+ vst1q_s8_x2(y[ib].qs + 96 + 128*l, m3);
+ }
+#else
+ for (int l = 0; l < 4; ++l) {
+ for (int k = 0; k < 8; ++k) for (int i = 0; i < 4; ++i) {
+ y[ib].qs[32*l+4*k+i+ 0] = x8[k][ib].qs[i+4*l+ 0];
+ y[ib].qs[32*l+4*k+i+128] = x8[k][ib].qs[i+4*l+16];
+ }
+ }
+#endif
+ }
+ y += nblock;
+ }
+ }
+
+ static std::vector<block_q8_0_r8> repack(int nk, const HelperQ80& q8) {
+ static_assert(D%QK8_0 == 0);
+ GGML_ASSERT(nk%8 == 0);
+ constexpr int nblock = D/QK8_0;
+ std::vector<block_q8_0_r8> result(nblock * nk/8);
+ auto y = result.data();
+ repack(nk, q8.data, q8.stride, y);
+ return result;
+ }
+
+ std::vector<block_q8_0_r8> r4;
+};
+
+// TODO: unite this with the above
+template <int D>
+struct HelperQ8KVR8 : public BaseHelper {
+ using Base = BaseHelper;
+ constexpr static ggml_type type = GGML_TYPE_Q8_KV_R8;
+ constexpr static int block_size_q = D;
+ using block_q8 = block_q8_KV<D>;
+
+ struct block_q8_KV_r8 {
+ float d[8];
+ int8_t qs[8*D];
+ };
+
+ HelperQ8KVR8(int nk, const HelperQ8KV<D>& q8) : Base(q8.data, q8.stride) {
+ r4 = repack(nk, q8);
+ Base::data = (const char *)r4.data();
+ Base::stride = sizeof(block_q8_KV_r8)/8;
+ }
+
+ static std::vector<block_q8_KV_r8> repack(int nk, const HelperQ8KV<D>& q8) {
+ static_assert(D%32 == 0);
+ GGML_ASSERT(nk%8 == 0);
+ std::vector<block_q8_KV_r8> result(nk/8);
+ auto y = result.data();
+#ifdef __ARM_NEON
+ int8x16x2_t m0, m1, m2, m3;
+#endif
+ const int8_t * x8[8];
+ for (int ix = 0; ix < nk/8; ++ix) {
+ for (int k = 0; k < 8; ++k) {
+ auto dptr = (const float *)(q8.data + (8*ix + k)*q8.stride);
+ y[ix].d[k] = dptr[0];
+ x8[k] = (const int8_t *)(dptr + 2);
+ }
+ for (int ib = 0; ib < D/16; ++ib) {
+#ifdef __AVX2__
+ auto m0 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[4]+ib), _mm_loadu_si128((const __m128i *)x8[0]+ib));
+ auto m1 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[5]+ib), _mm_loadu_si128((const __m128i *)x8[1]+ib));
+ auto m2 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[6]+ib), _mm_loadu_si128((const __m128i *)x8[2]+ib));
+ auto m3 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[7]+ib), _mm_loadu_si128((const __m128i *)x8[3]+ib));
+ auto t0 = _mm256_unpacklo_epi32(m0, m1);
+ auto t1 = _mm256_unpacklo_epi32(m2, m3);
+ auto t2 = _mm256_unpackhi_epi32(m0, m1);
+ auto t3 = _mm256_unpackhi_epi32(m2, m3);
+ m0 = _mm256_unpacklo_epi64(t0, t1);
+ m1 = _mm256_unpackhi_epi64(t0, t1);
+ m2 = _mm256_unpacklo_epi64(t2, t3);
+ m3 = _mm256_unpackhi_epi64(t2, t3);
+//#ifdef HAVE_FANCY_SIMD
+// m0 = _mm256_add_epi8(m0, _mm256_set1_epi8(127));
+// m1 = _mm256_add_epi8(m1, _mm256_set1_epi8(127));
+// m2 = _mm256_add_epi8(m2, _mm256_set1_epi8(127));
+// m3 = _mm256_add_epi8(m3, _mm256_set1_epi8(127));
+//#endif
+ _mm256_storeu_si256((__m256i *)y[ix].qs + 4*ib+0, m0);
+ _mm256_storeu_si256((__m256i *)y[ix].qs + 4*ib+1, m1);
+ _mm256_storeu_si256((__m256i *)y[ix].qs + 4*ib+2, m2);
+ _mm256_storeu_si256((__m256i *)y[ix].qs + 4*ib+3, m3);
+#elif defined __ARM_NEON
+ // TODO
+ m0.val[0] = vld1q_s8(x8[0]+16*ib); m0.val[1] = vld1q_s8(x8[4]+16*ib);
+ m1.val[0] = vld1q_s8(x8[1]+16*ib); m1.val[1] = vld1q_s8(x8[5]+16*ib);
+ m2.val[0] = vld1q_s8(x8[2]+16*ib); m2.val[1] = vld1q_s8(x8[6]+16*ib);
+ m3.val[0] = vld1q_s8(x8[3]+16*ib); m3.val[1] = vld1q_s8(x8[7]+16*ib);
+ auto row01 = vtrnq_s32(vreinterpretq_s32_s8(m0.val[0]), vreinterpretq_s32_s8(m1.val[0]));
+ auto row23 = vtrnq_s32(vreinterpretq_s32_s8(m2.val[0]), vreinterpretq_s32_s8(m3.val[0]));
+ m0.val[0] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0])));
+ m1.val[0] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1])));
+ m2.val[0] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0])));
+ m3.val[0] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1])));
+ row01 = vtrnq_s32(vreinterpretq_s32_s8(m0.val[1]), vreinterpretq_s32_s8(m1.val[1]));
+ row23 = vtrnq_s32(vreinterpretq_s32_s8(m2.val[1]), vreinterpretq_s32_s8(m3.val[1]));
+ m0.val[1] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0])));
+ m1.val[1] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1])));
+ m2.val[1] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0])));
+ m3.val[1] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1])));
+ vst1q_s8_x2(y[ix].qs + 0 + 128*ib, m0);
+ vst1q_s8_x2(y[ix].qs + 32 + 128*ib, m1);
+ vst1q_s8_x2(y[ix].qs + 64 + 128*ib, m2);
+ vst1q_s8_x2(y[ix].qs + 96 + 128*ib, m3);
+#else
+ // TODO
+ for (int l = 0; l < 4; ++l) {
+ for (int k = 0; k < 8; ++k) for (int i = 0; i < 4; ++i) {
+ y[ib].qs[32*l+4*k+i+ 0] = x8[k][ib].qs[i+4*l+ 0];
+ y[ib].qs[32*l+4*k+i+128] = x8[k][ib].qs[i+4*l+16];
+ }
+ }
+#endif
+ }
+ }
+ return result;
+ }
+
+ std::vector<block_q8_KV_r8> r4;
+};
+
+struct HelperQ40 final : public BaseHelper {
+ using Base = BaseHelper;
+ constexpr static ggml_type type = GGML_TYPE_Q4_0;
+#if defined __AVX2__
+ using block_q8 = block_q8_2;
+ constexpr static int block_size_q = QK8_2;
+#else
+ using block_q8 = block_q8_0;
+ constexpr static int block_size_q = QK8_0;
+#endif
+ HelperQ40(const char * data, int stride) : Base(data, stride) {}
+
+ // Needed for v * softmax(k * q)
+ inline void load(int l1, int i, F16::Data& v1, F16::Data& v2) const {
+ int j = F16::block_size*i;
+ auto dl = (const block_q4_0 *)Base::lblock(l1) + j/QK4_0;
+#ifdef __aarch64__
+ auto vd = F16::set1(*(const float16_t *)&dl->d);
+ auto q = vld1q_u8(dl->qs);
+ q = j%QK4_0 ? vshrq_n_u8(q, 4) : vandq_u8(q, mask);
+ q = vaddq_s8(q, m8);
+ v1 = vmulq_f16(vd, vcvtq_f16_s16(vmovl_s8(vget_low_s8(q))));
+ v2 = vmulq_f16(vd, vcvtq_f16_s16(vmovl_s8(vget_high_s8(q))));
+#else
+ auto vd = F16::set1(GGML_FP16_TO_FP32(dl->d));
+ auto q = _mm_loadu_si128((const __m128i *)dl->qs);
+#ifdef __AVX512F__
+ auto ql = _mm_add_epi8(_mm_and_si128(q, mask), m8);
+ auto qh = _mm_add_epi8(_mm_and_si128(_mm_srli_epi16(q, 4), mask), m8);
+ v1 = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(ql)));
+ v2 = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(qh)));
+#else
+ if (j%QK4_0) q = _mm_srli_epi16(q, 4);
+ auto q16 = _mm256_cvtepi8_epi16(_mm_add_epi8(_mm_and_si128(q, mask), m8));
+ v1 = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_castsi256_si128(q16))));
+ v2 = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(q16, 1))));
+#endif
+#endif
+ }
+
+#ifdef __AVX2__
+ const __m128i mask = _mm_set1_epi8(0xf);
+ const __m128i m8 = _mm_set1_epi8(-8);
+#else
+ const uint8x16_t mask = vdupq_n_u8(0xf);
+ const int8x16_t m8 = vdupq_n_s8(-8);
+#endif
+};
+
+struct HelperQ41 final : public BaseHelper {
+ using Base = BaseHelper;
+ using block_q8 = block_q8_2;
+ constexpr static ggml_type type = GGML_TYPE_Q4_1;
+ constexpr static int block_size_q = QK8_2;
+ HelperQ41(const char * data, int stride) : Base(data, stride) {}
+
+ // Needed for v * softmax(k * q)
+ inline void load(int l1, int i, F16::Data& v1, F16::Data& v2) const {
+ int j = F16::block_size*i;
+ auto dl = (const block_q4_1 *)Base::lblock(l1) + j/QK4_1;
+#ifdef __aarch64__
+ auto vd = F16::set1(*(const float16_t *)&dl->d);
+ auto vm = F16::set1(*(const float16_t *)&dl->m);
+ auto q = vld1q_u8(dl->qs);
+ q = (j%QK4_1) ? vshrq_n_u8(q, 4) : vandq_u8(q, mask);
+ v1 = vfmaq_f16(vm, vd, vcvtq_f16_u16(vmovl_u8(vget_low_u8(q))));
+ v2 = vfmaq_f16(vm, vd, vcvtq_f16_u16(vmovl_u8(vget_high_u8(q))));
+#else
+ auto vd = F16::set1(GGML_FP16_TO_FP32(dl->d));
+ auto vm = F16::set1(GGML_FP16_TO_FP32(dl->m));
+ auto q = _mm_loadu_si128((const __m128i *)dl->qs);
+#ifdef __AVX512F__
+ auto ql = _mm_and_si128(q, mask);
+ auto qh = _mm_and_si128(_mm_srli_epi16(q, 4), mask);
+ v1 = _mm512_fmadd_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(ql)), vm);
+ v2 = _mm512_fmadd_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(qh)), vm);
+#else
+ if (j%QK4_1) q = _mm_srli_epi16(q, 4);
+ auto q16 = _mm256_cvtepi8_epi16(_mm_and_si128(q, mask));
+ v1 = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_castsi256_si128(q16))), vm);
+ v2 = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(q16, 1))), vm);
+#endif
+#endif
+ }
+
+#ifdef __aarch64__
+ const uint8x16_t mask = vdupq_n_u8(0xf);
+#else
+ const __m128i mask = _mm_set1_epi8(0xf);
+#endif
+};
+
+struct HelperIQ4nl final : public BaseHelper {
+ using Base = BaseHelper;
+ constexpr static ggml_type type = GGML_TYPE_IQ4_NL;
+#ifdef __aarch64__
+ using block_q8 = block_q8_0;
+ HelperIQ4nl(const char * data, int stride) : Base(data, stride), values(vld1q_s8(iq4k_values)) {}
+ constexpr static int block_size_q = QK8_0;
+#else
+ HelperIQ4nl(const char * data, int stride) : Base(data, stride) {}
+#ifdef HAVE_FANCY_SIMD
+ using block_q8 = block_q8_2;
+ constexpr static int block_size_q = QK8_2;
+#else
+ using block_q8 = block_q8_0;
+ constexpr static int block_size_q = QK8_0;
+#endif
+#endif
+
+ // Needed for v * softmax(k * q)
+ inline void load(int l1, int i, F16::Data& v1, F16::Data& v2) const {
+ int j = F16::block_size*i;
+ auto dl = (const block_iq4_nl *)Base::lblock(l1) + j/QK4_0;
+#ifdef __aarch64__
+ auto vd = F16::set1(*(const float16_t *)&dl->d);
+ auto q = vld1q_u8(dl->qs);
+ q = j%QK4_0 ? vshrq_n_u8(q, 4) : vandq_u8(q, mask);
+ q = vqtbl1q_s8(values, q);
+ v1 = vmulq_f16(vd, vcvtq_f16_s16(vmovl_s8(vget_low_s8(q))));
+ v2 = vmulq_f16(vd, vcvtq_f16_s16(vmovl_s8(vget_high_s8(q))));
+#else
+ auto vd = F16::set1(GGML_FP16_TO_FP32(dl->d));
+ auto q = _mm_loadu_si128((const __m128i *)dl->qs);
+#ifdef __AVX512F__
+ auto ql = _mm_shuffle_epi8(values, _mm_and_si128(q, mask));
+ auto qh = _mm_shuffle_epi8(values, _mm_and_si128(_mm_srli_epi16(q, 4), mask));
+ v1 = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(ql)));
+ v2 = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(qh)));
+#else
+ if (j%QK4_0) q = _mm_srli_epi16(q, 4);
+ auto q16 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(values, _mm_and_si128(q, mask)));
+ v1 = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_castsi256_si128(q16))));
+ v2 = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(q16, 1))));
+#endif
+#endif
+ }
+
+#ifdef __aarch64__
+ const uint8x16_t mask = vdupq_n_u8(0xf);
+ const int8x16_t values;
+#else
+ const __m128i mask = _mm_set1_epi8(0xf);
+ const __m128i values = _mm_loadu_si128((const __m128i *)iq4k_values);
+#endif
+};
+
+struct HelperQ60 final : public BaseHelper {
+ constexpr static ggml_type type = GGML_TYPE_Q6_0;
+#ifdef __aarch64__
+ using block_q8 = block_q8_0;
+ constexpr static int block_size_q = QK8_0;
+#else
+ using block_q8 = block_q8_2;
+ constexpr static int block_size_q = QK8_2;
+#endif
+ using Base = BaseHelper;
+ HelperQ60(const char * data, int stride) : Base(data, stride) {}
+
+ // Needed for v * softmax(k * q)
+ inline void load(int l1, int i, F16::Data& v1, F16::Data& v2) const {
+ int j = F16::block_size*i;
+ auto dl = (const block_q6_0 *)Base::lblock(l1) + j/QK6_0;
+#ifdef __aarch64__
+ // TODO
+ const float16_t * d16 = (const float16_t *)&dl->d;
+ auto vd = F16::set1(d16[0]);
+ //auto vd = F16::set1(*(const float16_t *)&dl->d);
+ auto qh8 = vld1_u8(dl->qh);
+ auto qh = vcombine_u8(vshl_n_u8(qh8, 4), qh8);
+ auto qs = vld1q_u8(dl->qs);
+ qs = j%QK4_0 ? vshrq_n_u8(qs, 4) : vandq_u8(qs, mask_l);
+ qs = vorrq_u8(qs, vandq_u8(mask_h, j%QK4_0 ? vshrq_n_u8(qh, 2) : qh));
+ qs = vaddq_s8(qs, m32);
+ v1 = vmulq_f16(vd, vcvtq_f16_s16(vmovl_s8(vget_low_s8(qs))));
+ v2 = vmulq_f16(vd, vcvtq_f16_s16(vmovl_s8(vget_high_s8(qs))));
+#else
+ auto vd = F16::set1(GGML_FP16_TO_FP32(dl->d));
+ auto bl = _mm_loadu_si128((const __m128i *)dl->qs);
+ uint64_t aux64; std::memcpy(&aux64, dl->qh, 8);
+ auto bh = _mm_set_epi64x(aux64, aux64 << 4);
+#ifdef __AVX512F__
+ auto ql = _mm_add_epi8(_mm_or_si128(_mm_and_si128(bl, mask_l), _mm_and_si128(bh, mask_h)), m32);
+ auto qh = _mm_add_epi8(_mm_or_si128(_mm_and_si128(_mm_srli_epi16(bl, 4), mask_l), _mm_and_si128(_mm_srli_epi16(bh, 2), mask_h)), m32);
+ v1 = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(ql)));
+ v2 = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(qh)));
+#else
+ if (j%QK4_0) {
+ bl = _mm_srli_epi16(bl, 4);
+ bh = _mm_srli_epi16(bh, 2);
+ }
+ auto q16 = _mm256_cvtepi8_epi16(_mm_add_epi8(_mm_or_si128(_mm_and_si128(bl, mask_l), _mm_and_si128(bh, mask_h)), m32));
+ v1 = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_castsi256_si128(q16))));
+ v2 = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(q16, 1))));
+#endif
+#endif
+ }
+
+#ifdef __AVX2__
+ const __m128i mask_l = _mm_set1_epi8(0x0f);
+ const __m128i mask_h = _mm_set1_epi8(0x30);
+ const __m128i m32 = _mm_set1_epi8(-32);
+#else
+ const uint8x16_t mask_l = vdupq_n_u8(0x0f);
+ const uint8x16_t mask_h = vdupq_n_u8(0x30);
+ const int8x16_t m32 = vdupq_n_s8(-32);
+#endif
+};
+
+template <int q_step_in, int k_step_in>
+struct FlashMS {
+ constexpr static int q_step = q_step_in;
+ constexpr static int k_step = k_step_in;
+// Something goes wrong when storing and manipulating K*Q as fp16.
+// It works for some models (e.g., Gemma-2), but not for others (e.g., LLaMA-3.1-8B).
+// As I wasn't able to find where we lose precision, let's comment this out
+// for now and do the K*Q part in fp32.
+//#ifdef __aarch64__
+// using cache_t = float16_t;
+//#else
+// using cache_t = float;
+//#endif
+ using cache_t = float;
+
+ FlashMS(float scale, float softcap) : vscale(F16::set1(scale)), softcap(softcap), h_inf(GGML_FP32_TO_FP16(-INFINITY)) {}
+
+ inline void init_qstep() {
+ for (int j = 0; j < q_step; ++j) {
+ S[j] = 0; M[j] = -INFINITY;
+ }
+ }
+
+ inline void update_M(int j, float smax) {
+ if (smax == -INFINITY) {
+ std::memset(cache + k_step*j, 0, k_step*sizeof(float));
+ need_scaling[j] = M[j] == -INFINITY ? 2 : 0;
+ return;
+ }
+ need_scaling[j] = 0;
+ if (smax > M[j]) {
+ if (M[j] > -INFINITY) {
+ float m = expf(M[j] - smax);
+ vms[j] = m;
+ need_scaling[j] = 1;
+ S[j] *= m;
+ } else {
+ need_scaling[j] = 2;
+ S[j] = 0;
+ }
+ M[j] = smax;
+ }
+ }
+
+#ifdef __aarch64__
+ inline void update_S(int j, float32x4_t * vk) {
+ auto vm = vdupq_n_f32(M[j]);
+ auto vsum = vdupq_n_f32(0);
+ for (int l = 0; l < k_step/4; ++l) {
+ vk[l] = v_expf(vsubq_f32(vk[l], vm));
+ vsum = vaddq_f32(vsum, vk[l]);
+ F16::store(cache + k_step*j + 4*l, vk[l]);
+ }
+ S[j] += vaddvq_f32(vsum);
+ }
+#else
+ inline void update_S(int j, F16::Data * vk) {
+ auto vm = F16::set1(M[j]);
+ for (int l = 0; l < k_step/F16::block_size; ++l) {
+ vk[l] = v_expf(F16::sub(vk[l], vm));
+ F16::store(cache + k_step*j + F16::block_size*l, vk[l]);
+ }
+ S[j] += F16::reduce_add<k_step>(vk);
+ }
+#endif
+
+#ifdef __aarch64__
+ inline float load_and_scale(int j, float32x4_t * vk) {
+ float32x4_t vmax = vdupq_n_f32(-INFINITY);
+ // Something goes wrong when storing and manipulating K*Q as fp16.
+ // It works for some models (e.g., Gemma-2), but not for others (e.g., LLaMA-3.1-8B).
+ // As I wasn't able to find where we lose precision, let's comment this out
+ // for now and do the K*Q part in fp32.
+ //if (softcap <= 0.0f) {
+ // for (int l = 0; l < k_step/F16::block_size; ++l) {
+ // auto val = F16::mul(vscale, F16::load(cache + k_step*j + F16::block_size*l));
+ // vk[2*l+0] = vcvt_f32_f16(vget_low_f16(val));
+ // vk[2*l+1] = vcvt_f32_f16(vget_high_f16(val));
+ // vmax = vmaxq_f32(vmax, vmaxq_f32(vk[2*l+0], vk[2*l+1]));
+ // }
+ //} else {
+ // auto v_softcap = vdupq_n_f32(softcap);
+ // for (int l = 0; l < k_step/F16::block_size; ++l) {
+ // auto val = F16::mul(vscale, F16::load(cache + k_step*j + F16::block_size*l));
+ // vk[2*l+0] = vcvt_f32_f16(vget_low_f16(val));
+ // vk[2*l+1] = vcvt_f32_f16(vget_high_f16(val));
+ // vk[2*l+0] = vmulq_f32(v_softcap, v_tanh(vk[2*l+0]));
+ // vk[2*l+1] = vmulq_f32(v_softcap, v_tanh(vk[2*l+1]));
+ // vmax = vmaxq_f32(vmax, vmaxq_f32(vk[2*l+0], vk[2*l+1]));
+ // }
+ //}
+ auto vscale32 = vcvt_f32_f16(vget_low_f16(vscale));
+ if (softcap <= 0.0f) {
+ for (int l = 0; l < k_step/4; ++l) {
+ vk[l] = vmulq_f32(vscale32, vld1q_f32(cache + k_step*j + 4*l));
+ vmax = vmaxq_f32(vmax, vk[l]);
+ }
+ } else {
+ auto v_softcap = vdupq_n_f32(softcap);
+ for (int l = 0; l < k_step/4; ++l) {
+ vk[l] = vmulq_f32(vscale32, vld1q_f32(cache + k_step*j + 4*l));
+ vk[l] = vmulq_f32(v_softcap, v_tanh(vk[l]));
+ vmax = vmaxq_f32(vmax, vk[l]);
+ }
+ }
+ return vmaxvq_f32(vmax);
+ }
+ inline float load_apply_mask_and_scale(int j, float32x4_t * vk, const char * mask) {
+ auto vzero = vdupq_n_f16(0);
+ auto vinf = vdupq_n_f32(-INFINITY);
+ for (int l = 0; l < k_step/8; ++l) {
+ auto vm = vceqq_f16(vzero, vld1q_f16((const float16_t *)mask + 8*l));
+ auto vm1 = vzip1q_u16(vm, vm);
+ auto vm2 = vzip2q_u16(vm, vm);
+ auto kq = vld1q_f32_x2(cache + k_step*j + 8*l);
+ vk[2*l+0] = vreinterpretq_f32_u32(vorrq_u32(vandq_u32(vreinterpretq_u32_f32(kq.val[0]), vm1),
+ vbicq_u32(vreinterpretq_u32_f32(vinf), vm1)));
+ vk[2*l+1] = vreinterpretq_f32_u32(vorrq_u32(vandq_u32(vreinterpretq_u32_f32(kq.val[1]), vm2),
+ vbicq_u32(vreinterpretq_u32_f32(vinf), vm2)));
+ }
+ float32x4_t vmax = vdupq_n_f32(-INFINITY);
+ auto vscale32 = vcvt_f32_f16(vget_low_f16(vscale));
+ if (softcap <= 0.0f) {
+ for (int l = 0; l < k_step/4; ++l) {
+ vk[l] = vmulq_f32(vscale32, vk[l]);
+ vmax = vmaxq_f32(vmax, vk[l]);
+ }
+ } else {
+ auto v_softcap = vdupq_n_f32(softcap);
+ for (int l = 0; l < k_step/4; ++l) {
+ vk[l] = vmulq_f32(vscale32, vk[l]);
+ vk[l] = vmulq_f32(v_softcap, v_tanh(vk[l]));
+ vmax = vmaxq_f32(vmax, vk[l]);
+ }
+ }
+ return vmaxvq_f32(vmax);
+ }
+#else
+ inline float load_and_scale(int j, F16::Data * vk) {
+ if (softcap <= 0.0f) {
+ for (int l = 0; l < k_step/F16::block_size; ++l) vk[l] = F16::mul(vscale, F16::load(cache + k_step*j + F16::block_size*l));
+ } else {
+ auto v_softcap = F16::set1(softcap);
+ for (int l = 0; l < k_step/F16::block_size; ++l) {
+ auto val = F16::load(cache + k_step*j + F16::block_size*l);
+ vk[l] = F16::mul(v_softcap, v_tanh(F16::mul(vscale, val)));
+ }
+ }
+ return F16::reduce_max<k_step>(vk);
+ }
+ static inline __m256 apply_mask(int l, const char * mask, __m256 val, [[maybe_unused]] __m256 vinf) {
+ return _mm256_add_ps(val, _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)mask+l)));
+ //auto m128 = _mm_loadu_si128((const __m128i *)mask+l);
+ //m128 = _mm_cmpeq_epi16(m128, _mm_setzero_si128());
+ //auto m256 = _mm256_cvtepi16_epi32(m128);
+ //auto mf = _mm256_castsi256_ps(_mm256_or_si256(m256, _mm256_slli_epi32(m256, 16)));
+ //return _mm256_or_ps(_mm256_and_ps(mf, val), _mm256_andnot_ps(mf, vinf));
+ }
+#ifdef __AVX512F__
+ static inline __m512 apply_mask(int l, const char * mask, __m512 val, __m512 vinf) {
+ auto m256 = _mm256_loadu_si256((const __m256i *)mask+l);
+ m256 = _mm256_cmpeq_epi16(m256, _mm256_setzero_si256());
+ auto m512 = _mm512_cvtepi16_epi32(m256);
+ auto mf = _mm512_castsi512_ps(_mm512_or_si512(m512, _mm512_slli_epi32(m512, 16)));
+ return _mm512_or_ps(_mm512_and_ps(mf, val), _mm512_andnot_ps(mf, vinf));
+ }
+#endif
+ inline float load_apply_mask_and_scale(int j, F16::Data * vk, const char * mask) {
+#ifdef HAVE_FANCY_SIMD
+ auto vzero = _mm256_set1_epi16(0);
+ auto vinf = _mm512_set1_ps(-INFINITY);
+ if (softcap <= 0) {
+ for (int l = 0; l < k_step/F16::block_size; ++l) {
+ auto m16 = _mm256_cmpeq_epi16_mask(_mm256_loadu_si256((const __m256i *)mask + l), vzero);
+ vk[l] = _mm512_mask_mul_ps(vinf, m16, vscale, F16::load(cache + k_step*j + F16::block_size*l));
+ }
+ } else {
+ auto v_softcap = F16::set1(softcap);
+ for (int l = 0; l < k_step/F16::block_size; ++l) {
+ auto m16 = _mm256_cmpeq_epi16_mask(_mm256_loadu_si256((const __m256i *)mask + l), vzero);
+ vk[l] = _mm512_mask_mul_ps(vinf, m16, v_softcap, v_tanh(F16::mul(vscale, F16::load(cache + k_step*j + F16::block_size*l))));
+ }
+ }
+#else
+ auto vinf = F16::set1(-INFINITY);
+ for (int l = 0; l < k_step/F16::block_size; ++l) {
+ vk[l] = apply_mask(l, mask, F16::load(cache + k_step*j + F16::block_size*l), vinf);
+ }
+ if (softcap <= 0) {
+ for (int l = 0; l < k_step/F16::block_size; ++l) vk[l] = F16::mul(vscale, vk[l]);
+ } else {
+ auto v_softcap = F16::set1(softcap);
+ for (int l = 0; l < k_step/F16::block_size; ++l) vk[l] = F16::mul(v_softcap, v_tanh(F16::mul(vscale, vk[l])));
+ }
+#endif
+ return F16::reduce_max<k_step>(vk);
+ }
+#endif
+
+#ifdef __aarch64__
+ inline void update_M_S(int j, float32x4_t * vk) {
+ float smax = load_and_scale(j, vk);
+ update_M(j, smax);
+ if (M[j] > -INFINITY) update_S(j, vk);
+ }
+ inline void update_M_S(int j, float32x4_t * vk, const char * mask) {
+ float smax = load_apply_mask_and_scale(j, vk, mask);
+ update_M(j, smax);
+ if (M[j] > -INFINITY) update_S(j, vk);
+ }
+#else
+ inline void update_M_S(int j, F16::Data * vk) {
+ float smax = load_and_scale(j, vk);
+ update_M(j, smax);
+ if (M[j] > -INFINITY) update_S(j, vk);
+ }
+ inline void update_M_S(int j, F16::Data * vk, const char * mask) {
+ float smax = load_apply_mask_and_scale(j, vk, mask);
+ update_M(j, smax);
+ if (M[j] > -INFINITY) update_S(j, vk);
+ }
+#endif
+
+ cache_t cache[q_step*k_step];
+ float S[q_step], M[q_step];
+ int need_scaling[q_step];
+ float vms[q_step];
+ const F16::Data vscale;
+ const float softcap;
+ const ggml_half h_inf;
+
+};
+
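+// Accumulates the V rows weighted by the softmax factors in fms.cache into qkv_cache, rescaling (or zeroing) previously
+// accumulated rows when fms.need_scaling[j] indicates that the running maximum has changed.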
+template <int D, int q_step, int k_step>
+struct FlashQKV {
+
+#ifdef __aarch64__
+ using qkv_cache_t = float16_t;
+#else
+ using qkv_cache_t = float;
+#endif
+
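+    // Single-row (q_step == 1) case: keep the D-wide accumulator in registers while walking the k_step V rows in groups of 4.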
+ template <typename VHelper, typename FMS>
+ inline void accumulate_qkv_1(const VHelper& vh, const FMS& fms) {
+ static_assert(q_step == FMS::q_step);
+ F16::Data vq[D/F16::block_size];
+ if (fms.need_scaling[0] == 2) {
+ for (int i = 0; i < D/F16::block_size; ++i) vq[i] = F16::zero();
+ } else {
+ for (int i = 0; i < D/F16::block_size; ++i) vq[i] = F16::load(qkv_cache + F16::block_size*i);
+ if (fms.need_scaling[0] == 1) {
+ auto vms = F16::set1(fms.vms[0]);
+ for (int i = 0; i < D/F16::block_size; ++i) vq[i] = F16::mul(vms, vq[i]);
+ }
+ }
+ F16::Data v0, v1;
+ for (int l = 0; l < k_step; l += 4) {
+ auto vs0 = F16::set1(fms.cache[l + 0]);
+ auto vs1 = F16::set1(fms.cache[l + 1]);
+ auto vs2 = F16::set1(fms.cache[l + 2]);
+ auto vs3 = F16::set1(fms.cache[l + 3]);
+ for (int i = 0; i < D/F16::block_size; i += 2) {
+ vh.load(l+0, i, v0, v1);
+ vq[i+0] = F16::fmadd(vq[i+0], v0, vs0);
+ vq[i+1] = F16::fmadd(vq[i+1], v1, vs0);
+ vh.load(l+1, i, v0, v1);
+ vq[i+0] = F16::fmadd(vq[i+0], v0, vs1);
+ vq[i+1] = F16::fmadd(vq[i+1], v1, vs1);
+ vh.load(l+2, i, v0, v1);
+ vq[i+0] = F16::fmadd(vq[i+0], v0, vs2);
+ vq[i+1] = F16::fmadd(vq[i+1], v1, vs2);
+ vh.load(l+3, i, v0, v1);
+ vq[i+0] = F16::fmadd(vq[i+0], v0, vs3);
+ vq[i+1] = F16::fmadd(vq[i+1], v1, vs3);
+ }
+ }
+ for (int i = 0; i < D/F16::block_size; ++i) F16::store(qkv_cache + F16::block_size*i, vq[i]);
+ }
+
+    // This fails for head sizes 80 and 112 because D/16 is odd there, so we cannot process the head dimension in steps of 2.
+    // Hence, for now, we do not handle head sizes 80 and 112.
+ template <typename VHelper, typename FMS>
+ inline void accumulate_qkv(const VHelper& vh, const FMS& fms) {
+ static_assert(q_step == FMS::q_step);
+ if constexpr (q_step == 1) {
+ accumulate_qkv_1(vh, fms);
+ return;
+ }
+ for (int j = 0; j < q_step; ++j) {
+ auto R = qkv_cache + D*j;
+ if (fms.need_scaling[j] == 2) {
+ std::memset(R, 0, D*sizeof(qkv_cache_t));
+ }
+ else if (fms.need_scaling[j] == 1) {
+ auto vms = F16::set1(fms.vms[j]);
+ for (int i = 0; i < D/F16::block_size; ++i) {
+ F16::store(R + F16::block_size*i, F16::mul(vms, F16::load(R + F16::block_size*i)));
+ }
+ }
+ }
+#ifdef __AVX512F__
+ if constexpr ((D/F16::block_size)%4 == 0) {
+ F16::Data v[16];
+ F16::Data vs[4];
+ for (int i = 0; i < D/F16::block_size; i += 4) {
+ for (int l = 0; l < k_step; l += 4) {
+ for (int k = 0; k < 4; ++k) {
+ vh.load(l+k, i+0, v[4*k+0], v[4*k+1]);
+ vh.load(l+k, i+2, v[4*k+2], v[4*k+3]);
+ }
+ for (int j = 0; j < q_step; ++j) {
+ auto R = qkv_cache + D*j;
+ auto s1 = F16::load(R + F16::block_size*(i+0));
+ auto s2 = F16::load(R + F16::block_size*(i+1));
+ auto s3 = F16::load(R + F16::block_size*(i+2));
+ auto s4 = F16::load(R + F16::block_size*(i+3));
+ F16::set4(fms.cache + k_step*j + l, vs);
+ for (int k = 0; k < 4; ++k) {
+ s1 = F16::fmadd(s1, v[4*k+0], vs[k]);
+ s2 = F16::fmadd(s2, v[4*k+1], vs[k]);
+ s3 = F16::fmadd(s3, v[4*k+2], vs[k]);
+ s4 = F16::fmadd(s4, v[4*k+3], vs[k]);
+ }
+ F16::store(R + F16::block_size*(i+0), s1);
+ F16::store(R + F16::block_size*(i+1), s2);
+ F16::store(R + F16::block_size*(i+2), s3);
+ F16::store(R + F16::block_size*(i+3), s4);
+ }
+ }
+ }
+ return;
+ }
+#endif
+ F16::Data v[8];
+#ifdef __AVX2__
+ F16::Data vs[4];
+#endif
+ for (int i = 0; i < D/F16::block_size; i += 2) {
+ for (int l = 0; l < k_step; l += 4) {
+ vh.load(l+0, i, v[0], v[4]);
+ vh.load(l+1, i, v[1], v[5]);
+ vh.load(l+2, i, v[2], v[6]);
+ vh.load(l+3, i, v[3], v[7]);
+ for (int j = 0; j < q_step; ++j) {
+ auto R = qkv_cache + D*j;
+ auto s1 = F16::load(R + F16::block_size*(i+0));
+ auto s2 = F16::load(R + F16::block_size*(i+1));
+#ifdef __AVX2__
+ F16::set4(fms.cache + k_step*j + l, vs);
+ for (int k = 0; k < 4; ++k) {
+ s1 = F16::fmadd(s1, v[k+0], vs[k]);
+ s2 = F16::fmadd(s2, v[k+4], vs[k]);
+ }
+#else
+ auto vs = F16::set4(fms.cache + k_step*j + l);
+ s1 = F16::fmadd_lane0(s1, v[0], vs);
+ s2 = F16::fmadd_lane0(s2, v[4], vs);
+ s1 = F16::fmadd_lane1(s1, v[1], vs);
+ s2 = F16::fmadd_lane1(s2, v[5], vs);
+ s1 = F16::fmadd_lane2(s1, v[2], vs);
+ s2 = F16::fmadd_lane2(s2, v[6], vs);
+ s1 = F16::fmadd_lane3(s1, v[3], vs);
+ s2 = F16::fmadd_lane3(s2, v[7], vs);
+#endif
+ F16::store(R + F16::block_size*(i+0), s1);
+ F16::store(R + F16::block_size*(i+1), s2);
+ }
+ }
+ }
+ }
+
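+    // Tail variant for nq1 < q_step rows; same logic, but the row loop bound is a runtime parameter.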
+ template <typename VHelper, typename FMS>
+ inline void accumulate_qkv(int nq1, const VHelper& vh, const FMS& fms) {
+ static_assert(q_step == FMS::q_step);
+ if (nq1 == 1) {
+ accumulate_qkv_1(vh, fms);
+ return;
+ }
+ F16::Data v[8];
+ for (int j = 0; j < nq1; ++j) {
+ auto R = qkv_cache + D*j;
+ if (fms.need_scaling[j] == 2) {
+ std::memset(R, 0, D*sizeof(qkv_cache_t));
+ }
+ else if (fms.need_scaling[j] == 1) {
+ auto vms = F16::set1(fms.vms[j]);
+ for (int i = 0; i < D/F16::block_size; ++i) {
+ F16::store(R + F16::block_size*i, F16::mul(vms, F16::load(R + F16::block_size*i)));
+ }
+ }
+ }
+ for (int i = 0; i < D/F16::block_size; i += 2) {
+ for (int l = 0; l < k_step; l += 4) {
+ vh.load(l+0, i, v[0], v[4]);
+ vh.load(l+1, i, v[1], v[5]);
+ vh.load(l+2, i, v[2], v[6]);
+ vh.load(l+3, i, v[3], v[7]);
+ for (int j = 0; j < nq1; ++j) {
+ auto R = qkv_cache + D*j;
+ auto s1 = F16::load(R + F16::block_size*(i+0));
+ auto s2 = F16::load(R + F16::block_size*(i+1));
+ auto vs = F16::set4(fms.cache + k_step*j + l);
+ s1 = F16::fmadd_lane0(s1, v[0], vs);
+ s2 = F16::fmadd_lane0(s2, v[4], vs);
+ s1 = F16::fmadd_lane1(s1, v[1], vs);
+ s2 = F16::fmadd_lane1(s2, v[5], vs);
+ s1 = F16::fmadd_lane2(s1, v[2], vs);
+ s2 = F16::fmadd_lane2(s2, v[6], vs);
+ s1 = F16::fmadd_lane3(s1, v[3], vs);
+ s2 = F16::fmadd_lane3(s2, v[7], vs);
+ F16::store(R + F16::block_size*(i+0), s1);
+ F16::store(R + F16::block_size*(i+1), s2);
+ }
+ }
+ }
+ }
+
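+    // Divide one accumulated row by its softmax normalizer S[j] and write it to the output.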
+ template <typename FMS>
+ inline void normalize_and_store_1row(const FMS& fms, int j, const qkv_cache_t * R, float * qkv) const {
+ static_assert(q_step == FMS::q_step);
+ GGML_ASSERT(fms.S[j] > 0);
+ auto norm = F16::set1(1/fms.S[j]);
+ //auto norm = F16::set1(fms.S[j] > 0 ? 1/fms.S[j] : 0.f);
+ for (int i = 0; i < D/F16::block_size; ++i) {
+ auto r = F16::load(R + F16::block_size*i);
+ F16::store(qkv + F16::block_size*i, F16::mul(norm, r));
+ }
+ }
+
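+    // If M and S output buffers are provided, the rows are stored unnormalized together with their M and S values;
+    // otherwise each row is normalized by 1/S here.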
+ template <typename FMS>
+ inline void normalize_and_store(const FMS& fms, int nq1, int stride_qkv, float * qkv, float * M, float * S) const {
+ static_assert(q_step == FMS::q_step);
+ if (M && S) {
+ std::memcpy(M, fms.M, nq1*sizeof(float));
+ std::memcpy(S, fms.S, nq1*sizeof(float));
+ auto R = qkv_cache;
+ for (int j = 0; j < nq1; ++j) {
+#ifdef __aarch64__
+ for (int i = 0; i < D/F16::block_size; ++i) {
+ F16::store(qkv + F16::block_size*i, F16::load(R + F16::block_size*i));
+ }
+#else
+ std::memcpy(qkv, R, D*sizeof(float));
+#endif
+ qkv += stride_qkv;
+ R += D;
+ }
+ } else {
+ auto R = qkv_cache;
+ for (int j = 0; j < nq1; ++j) {
+ normalize_and_store_1row(fms, j, R, qkv);
+ qkv += stride_qkv;
+ R += D;
+ }
+ }
+ }
+
+ template <typename FMS>
+ inline void normalize_and_store(const FMS& fms, int stride_qkv, float * qkv, float * M, float * S) const {
+ static_assert(q_step == FMS::q_step);
+ if (M && S) {
+ std::memcpy(M, fms.M, q_step*sizeof(float));
+ std::memcpy(S, fms.S, q_step*sizeof(float));
+ auto R = qkv_cache;
+ for (int j = 0; j < q_step; ++j) {
+#ifdef __aarch64__
+ for (int i = 0; i < D/F16::block_size; ++i) {
+ F16::store(qkv + F16::block_size*i, F16::load(R + F16::block_size*i));
+ }
+#else
+ std::memcpy(qkv, R, D*sizeof(float));
+#endif
+ qkv += stride_qkv;
+ R += D;
+ }
+ } else {
+ auto R = qkv_cache;
+ for (int j = 0; j < q_step; ++j) {
+ normalize_and_store_1row(fms, j, R, qkv);
+ qkv += stride_qkv;
+ R += D;
+ }
+ }
+ }
+
+ // qkv_cache_t qkv_cache[D*q_step];
+ // The initializer is not actually required. But the compiler cannot figure out that when qkv_cache is
+ // first used for q_step rows, fms.need_scaling[j] is always 2, which zeroes the content of qkv_cache.
+ // As a result, we get an infinite stream of warnings about uninitialized variable use (one for each
+ // combination of D, q_step, k_step), which is extremely annoying. Hence, I succumb to the trend of
+ // constantly being saved by others (the compiler in this case), and add this 100% unnecessary initialization.
+ qkv_cache_t qkv_cache[D*q_step]; // = {};
+ //qkv_cache_t * qkv_cache;
+};
+
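+// Computes the K^T*q tile with the generic float GEMM (iqk_gemm_default_floats), then applies mask, scale and the
+// optional soft-cap through FlashMS::update_M_S.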
+template <int D, int q_step, int k_step>
+struct FlashQKfp32 {
+ static_assert(D%F16::block_size == 0 && D <= 576);
+ static_assert(k_step%F16::block_size == 0);
+ static_assert(q_step <= 4 || q_step%4 == 0);
+
+ template <typename KHelper, typename q_float>
+ static inline void multiply_mask_kq(const KHelper& kh, int stride_q, int stride_m, const q_float * q, const char * mask,
+ FlashMS<q_step, k_step>& fms) {
+#ifdef __AVX2__
+ constexpr int nrc_k = 8;
+ static_assert(k_step%nrc_k == 0);
+#endif
+ DataInfo info{fms.cache, (const char *)q, k_step, stride_q*sizeof(q_float), 0, 1, nullptr};
+ iqk_gemm_default_floats(D, q_step, kh.block, kh.stride, info, k_step);
+#ifdef __AVX2__
+ F16::Data vk[k_step/F16::block_size];
+#else
+ float32x4_t vk[k_step/4];
+#endif
+ for (int j = 0; j < q_step; ++j) {
+ fms.update_M_S(j, vk, mask + stride_m*j);
+ }
+ }
+
+ template <typename KHelper, typename q_float>
+ static inline void multiply_mask_kq(int nq, const KHelper& kh, int stride_q, int stride_m, const q_float * q, const char * mask,
+ FlashMS<q_step, k_step>& fms) {
+#ifdef __AVX2__
+ constexpr int nrc_k = 8;
+ static_assert(k_step%nrc_k == 0);
+#endif
+ DataInfo info{fms.cache, (const char *)q, k_step, stride_q*sizeof(q_float), 0, 1, nullptr};
+ iqk_gemm_default_floats(D, nq, kh.block, kh.stride, info, k_step);
+#ifdef __AVX2__
+ F16::Data vk[k_step/F16::block_size];
+#else
+ float32x4_t vk[k_step/4];
+#endif
+ for (int j = 0; j < nq; ++j) {
+ fms.update_M_S(j, vk, mask + stride_m*j);
+ }
+ }
+
+#ifdef __aarch64__
+ static inline void convert(int nq, int stride_q, const float * q, float16_t * q_f16) {
+ for (int i = 0; i < nq; ++i) {
+ for (int j = 0; j < D; j += 8) {
+ auto val1_f32 = vld1q_f32(q + j + 0);
+ auto val2_f32 = vld1q_f32(q + j + 4);
+ auto val_f16 = vcombine_f16(vcvt_f16_f32(val1_f32), vcvt_f16_f32(val2_f32));
+ vst1q_f16(q_f16 + j, val_f16);
+ }
+ q += stride_q;
+ q_f16 += D;
+ }
+ }
+#endif
+
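+    // Same as multiply_mask_kq, but for quantized K: q has already been converted to the matching block_q8 type,
+    // so we dispatch to the q8_KV or legacy-quant FA GEMM.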
+ template <typename KHelper, typename block_q8>
+ static inline void mul_mask_kq(const KHelper& kh, int stride_m,
+ const block_q8 * q, const char * mask, FlashMS<q_step, k_step>& fms) {
+ constexpr int kMaxQ = 8;
+ static_assert(q_step < kMaxQ || q_step%kMaxQ == 0);
+ DataInfo info{fms.cache, (const char *)q, k_step, (D/KHelper::block_size_q)*sizeof(block_q8), 0, 1, nullptr};
+ if constexpr (std::is_same_v<KHelper, HelperQ8KVR8<D>> ||
+ std::is_same_v<KHelper, HelperQ8KV<D>>) {
+ iqk_gemm_q8kv_fa(D, q_step, kh.type, kh.block, kh.stride, info, k_step);
+ } else {
+ iqk_gemm_legacy_fa(D, q_step, kh.type, kh.block, kh.stride, info, k_step);
+ }
+#ifdef __aarch64__
+ float32x4_t vk[k_step/4];
+ for (int j = 0; j < q_step; ++j) {
+ fms.update_M_S(j, vk, mask + stride_m*j);
+ }
+#else
+ F16::Data vk[k_step/F16::block_size];
+ for (int j = 0; j < q_step; ++j) {
+ fms.update_M_S(j, vk, mask + stride_m*j);
+ }
+#endif
+ }
+
+ template <typename KHelper, typename block_q8>
+ static inline void mul_mask_kq(int nq, const KHelper& kh, int stride_m,
+ const block_q8 * q, const char * mask, FlashMS<q_step, k_step>& fms) {
+ GGML_ASSERT(nq < q_step);
+ DataInfo info{fms.cache, (const char *)q, k_step, (D/KHelper::block_size_q)*sizeof(block_q8), 0, 1, nullptr};
+ if constexpr (std::is_same_v<KHelper, HelperQ8KVR8<D>> ||
+ std::is_same_v<KHelper, HelperQ8KV<D>>) {
+ iqk_gemm_q8kv_fa(D, nq, kh.type, kh.block, kh.stride, info, k_step);
+ } else {
+ iqk_gemm_legacy_fa(D, nq, kh.type, kh.block, kh.stride, info, k_step);
+ }
+#ifdef __aarch64__
+ float32x4_t vk[k_step/4];
+ for (int j = 0; j < nq; ++j) {
+ fms.update_M_S(j, vk, mask + stride_m*j);
+ }
+#else
+ F16::Data vk[k_step/F16::block_size];
+ for (int j = 0; j < nq; ++j) {
+ fms.update_M_S(j, vk, mask + stride_m*j);
+ }
+#endif
+ }
+};
+
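+// Tile loop for float K: process q_step query rows against successive k_step blocks of K/V, then handle the
+// remaining (< q_step) rows with the runtime-nq variants.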
+template <int Dk, int Dv, int q_step, int k_step, typename KHelper, typename VHelper, typename KQHelper>
+void compute_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv,
+ FlashMS<q_step, k_step>& fms,
+ FlashQKV<Dv, q_step, k_step>& fqkv,
+ const float * q, const char * mask, float * qkv,
+ float * M, float * S) {
+#ifdef __aarch64__
+ float16_t q_f16[Dk*q_step];
+#endif
+
+ for (int i1 = 0; i1 < nq1/q_step; ++i1) {
+ fms.init_qstep();
+ kh.reset_block();
+ vh.reset_block();
+#ifdef __aarch64__
+ KQHelper::convert(q_step, stride_q, q, q_f16);
+#endif
+ auto mr = mask;
+ for (int k1 = 0; k1 < nk1/k_step; ++k1) {
+#ifdef __aarch64__
+ KQHelper::multiply_mask_kq(kh, Dk, stride_m, q_f16, mr, fms);
+#else
+ KQHelper::multiply_mask_kq(kh, stride_q, stride_m, q, mr, fms);
+#endif
+ fqkv.accumulate_qkv(vh, fms);
+ kh.next_block(k_step);
+ vh.next_block(k_step);
+ mr += k_step*sizeof(ggml_half);
+ }
+ fqkv.normalize_and_store(fms, stride_qkv, qkv, M, S);
+
+ q += q_step*stride_q;
+ mask += q_step*stride_m;
+ qkv += q_step*stride_qkv;
+ if (M && S) { M += q_step; S += q_step; }
+ }
+ int n_left = nq1 - q_step*(nq1/q_step);
+ if (n_left > 0) {
+ fms.init_qstep();
+ kh.reset_block();
+ vh.reset_block();
+#ifdef __aarch64__
+ KQHelper::convert(n_left, stride_q, q, q_f16);
+#endif
+ auto mr = mask;
+ for (int k1 = 0; k1 < nk1/k_step; ++k1) {
+#ifdef __aarch64__
+ KQHelper::multiply_mask_kq(n_left, kh, Dk, stride_m, q_f16, mr, fms);
+#else
+ KQHelper::multiply_mask_kq(n_left, kh, stride_q, stride_m, q, mr, fms);
+#endif
+ fqkv.accumulate_qkv(n_left, vh, fms);
+ kh.next_block(k_step);
+ vh.next_block(k_step);
+ mr += k_step*sizeof(ggml_half);
+ }
+ fqkv.normalize_and_store(fms, n_left, stride_qkv, qkv, M, S);
+ }
+}
+
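+// Same tile loop for quantized K: Q is first quantized via HelperQ80::convert. For the nq1 == q_step case with a
+// Q8_0 K cache, each K block is additionally repacked to q8_0_r8 on the fly.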
+template <int Dk, int Dv, int q_step, int k_step, typename KHelper, typename VHelper, typename KQHelper>
+void compute_helper_q(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv,
+ FlashMS<q_step, k_step>& fms,
+ FlashQKV<Dv, q_step, k_step>& fqkv,
+ const float * q, const char * mask, float * qkv,
+ float * M, float * S, char * qptr) {
+ auto q8 = (typename KHelper::block_q8 *)qptr;
+ if constexpr (q_step > 1 && std::is_same_v<KHelper, HelperQ80>) {
+ if (nq1 == q_step) {
+ fms.init_qstep();
+ kh.reset_block();
+ vh.reset_block();
+ block_q8_0_r8 q8r8[Dk/QK8_0 * k_step/8];
+ HelperQ80R8<Dk> khr8((const char *)q8r8, Dk/QK8_0*sizeof(block_q8_0));
+ auto q8r = (typename HelperQ80R8<Dk>::block_q8 *)qptr;
+ HelperQ80::convert<Dk>(q_step, stride_q, q, q8r);
+ auto mr = mask;
+ for (int k1 = 0; k1 < nk1/k_step; ++k1) {
+ HelperQ80R8<Dk>::repack(k_step, kh.block, kh.stride, q8r8);
+ KQHelper::mul_mask_kq(khr8, stride_m, q8r, mr, fms);
+ fqkv.accumulate_qkv(vh, fms);
+ kh.next_block(k_step);
+ vh.next_block(k_step);
+ mr += k_step*sizeof(ggml_half);
+ }
+ fqkv.normalize_and_store(fms, stride_qkv, qkv, M, S);
+ return;
+ }
+ }
+#if FA_TIMING
+ Perf perf(false);
+#endif
+ for (int i1 = 0; i1 < nq1/q_step; ++i1) {
+#if FA_TIMING
+ auto t1 = Perf::cur_time();
+#endif
+ fms.init_qstep();
+ kh.reset_block();
+ vh.reset_block();
+ HelperQ80::convert<Dk>(q_step, stride_q, q, q8);
+#if FA_TIMING
+ perf.accum_nolock(0, t1);
+#endif
+ auto mr = mask;
+ for (int k1 = 0; k1 < nk1/k_step; ++k1) {
+#if FA_TIMING
+ t1 = Perf::cur_time();
+ KQHelper::mul_mask_kq(kh, stride_m, q8, mr, fms);
+ perf.accum_nolock(1, t1);
+ t1 = Perf::cur_time();
+ fqkv.accumulate_qkv(vh, fms);
+ perf.accum_nolock(2, t1);
+#else
+ KQHelper::mul_mask_kq(kh, stride_m, q8, mr, fms);
+ fqkv.accumulate_qkv(vh, fms);
+#endif
+ kh.next_block(k_step);
+ vh.next_block(k_step);
+ mr += k_step*sizeof(ggml_half);
+ }
+#if FA_TIMING
+ t1 = Perf::cur_time();
+ fqkv.normalize_and_store(fms, stride_qkv, qkv, M, S);
+ perf.accum_nolock(3, t1);
+#else
+ fqkv.normalize_and_store(fms, stride_qkv, qkv, M, S);
+#endif
+
+ q += q_step*stride_q;
+ mask += q_step*stride_m;
+ qkv += q_step*stride_qkv;
+ if (M && S) { M += q_step; S += q_step; }
+ }
+ int n_left = nq1 - q_step*(nq1/q_step);
+ if (n_left > 0) {
+ fms.init_qstep();
+ kh.reset_block();
+ vh.reset_block();
+ HelperQ80::convert<Dk>(n_left, stride_q, q, q8);
+ auto mr = mask;
+ for (int k1 = 0; k1 < nk1/k_step; ++k1) {
+ KQHelper::mul_mask_kq(n_left, kh, stride_m, q8, mr, fms);
+ fqkv.accumulate_qkv(n_left, vh, fms);
+ kh.next_block(k_step);
+ vh.next_block(k_step);
+ mr += k_step*sizeof(ggml_half);
+ }
+ fqkv.normalize_and_store(fms, n_left, stride_qkv, qkv, M, S);
+ }
+#if FA_TIMING
+ Perf::instance().add(perf);
+#endif
+}
+
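+// Thread-local scratch buffer for the quantized copy of Q when it is too large to live on the stack.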
+char * get_q_storage(size_t size) {
+ thread_local std::vector<char> q_storage;
+ if (q_storage.size() < size) q_storage.resize(size);
+ return q_storage.data();
+}
+
+// Some of the methods in FlashAttn have two nearly identical implementations that differ only in that
+// one version uses a loop over the template parameter q_step, while the other loops over an input
+// parameter nq (these are loops over the rows of q^T). I dislike this a lot,
+// but performance drops significantly if I remove the version with fixed q_step iterations.
+// We only instantiate FlashAttn with q_step = 1 and q_step = 4 or 8 (depending on head size D),
+// so when we have to process Nq rows, we process q_step*(Nq/q_step) of them using fixed q_step loops,
+// and use the variable nq version (with lower performance) only for the remaining up to q_step-1
+// rows (if Nq is not a multiple of q_step). One could have made the number of q^T rows to
+// process a template parameter of such functions, but this would result in the compiler generating
+// q_step-1 versions of these functions for us, which I thought was too much with q_step = 8.
+template <int Dk, int Dv, int q_step, int k_step>
+struct FlashAttn {
+ static_assert(Dk%F16::block_size == 0 && Dk <= 576);
+ static_assert(Dv%F16::block_size == 0 && Dv <= 512);
+ static_assert(k_step%F16::block_size == 0);
+ static_assert(q_step <= 4 || q_step%4 == 0);
+
+ FlashAttn(float scale, float softcap) : fms(scale, softcap) {}
+
+ template <typename KHelper, typename VHelper>
+ void compute(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv,
+ const float * q, const char * mask, float * qkv, [[maybe_unused]] float * M, [[maybe_unused]] float * S) {
+ if constexpr (std::is_same_v<KHelper, HelperQ40> ||
+ std::is_same_v<KHelper, HelperQ41> ||
+ std::is_same_v<KHelper, HelperIQ4nl> ||
+ std::is_same_v<KHelper, HelperQ60> ||
+ std::is_same_v<KHelper, HelperQ80R8<Dk>> ||
+ std::is_same_v<KHelper, HelperQ80> ||
+ std::is_same_v<KHelper, HelperQ8KV<Dk>> ||
+ std::is_same_v<KHelper, HelperQ8KVR8<Dk>>) {
+ constexpr size_t kMaxOnStackSize = 576;
+ //auto q_size = q_step*(Dk/KHelper::block_size_q)*sizeof(typename KHelper::block_q8);
+ auto q_size = q_step*(Dk/QK8_2*sizeof(block_q8_2));
+ q_size = GGML_PAD(q_size, 64);
+ if (q_size > kMaxOnStackSize) {
+ auto qptr = get_q_storage(q_size);
+ if (false && nq1 >= 8) {
+ if constexpr (std::is_same_v<KHelper, HelperQ80>) {
+#if FA_TIMING
+ auto t1 = Perf::cur_time();
+ HelperQ80R8<Dk, k_step> khr4(nk1, kh);
+ Perf::instance().accum(4, t1);
+#else
+ HelperQ80R8<Dk> khr4(nk1, kh);
+#endif
+ compute_helper_q<Dk, Dv, q_step, k_step, HelperQ80R8<Dk>, VHelper, FlashQKfp32<Dk, q_step, k_step>>(
+ khr4, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, M, S, qptr);
+ return;
+
+ }
+#if GGML_IQK_FA_ALL_QUANTS
+ if constexpr (std::is_same_v<KHelper, HelperQ8KV<Dk>>) {
+#if FA_TIMING
+ auto t1 = Perf::cur_time();
+ HelperQ8KVR8<Dk, k_step> khr4(nk1, kh);
+ Perf::instance().accum(4, t1);
+#else
+ HelperQ8KVR8<Dk> khr4(nk1, kh);
+#endif
+ compute_helper_q<Dk, Dv, q_step, k_step, HelperQ8KVR8<Dk>, VHelper, FlashQKfp32<Dk, q_step, k_step>>(
+ khr4, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, M, S, qptr);
+ return;
+ }
+#endif
+ }
+ compute_helper_q<Dk, Dv, q_step, k_step, KHelper, VHelper, FlashQKfp32<Dk, q_step, k_step>>(
+ kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, M, S, qptr);
+
+ }
+ else {
+ typename KHelper::block_q8 q8[q_step*(Dk/KHelper::block_size_q)];
+ compute_helper_q<Dk, Dv, q_step, k_step, KHelper, VHelper, FlashQKfp32<Dk, q_step, k_step>>(
+ kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, M, S, (char *)q8);
+ }
+ }
+ else {
+ compute_helper<Dk, Dv, q_step, k_step, KHelper, VHelper, FlashQKfp32<Dk, q_step, k_step>>(
+ kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, M, S);
+ }
+ }
+
+ FlashMS<q_step, k_step> fms;
+ FlashQKV<Dv, q_step, k_step> fqkv;
+
+};
+
+#ifdef __AVX512BF16__
+
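+// K/V helper for bf16 caches: loads 1, 2, 4 or 8 rows as __m512bh vectors for the AVX512-BF16 dot-product kernels below.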
+template <int D, int step>
+struct HelperBF16 final : public BaseHelper {
+ using Base = BaseHelper;
+ HelperBF16(const char * data, int stride) : Base(data, stride) {}
+ inline void load(int l1, __m512bh * vk) const {
+ auto dr = Base::lblock(l1);
+ for (int i = 0; i < D/32; ++i) vk[i] = __m512bh(_mm512_loadu_si512((const __m512i*)dr + i));
+ }
+
+ inline void load(int l1, int i, __m512& v1, __m512& v2) const {
+ auto dr = Base::lblock(l1);
+ v1 = _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(_mm256_loadu_si256((const __m256i *)dr + i + 0)), 16));
+ v2 = _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(_mm256_loadu_si256((const __m256i *)dr + i + 1)), 16));
+ }
+
+ inline void load_2(int l1, __m512bh * vk) const {
+ load(l1+0, vk+0);
+ load(l1+1, vk+D/32);
+ }
+
+ inline void load_4(int l1, __m512bh * vk) const {
+ load(l1+0, vk+0);
+ load(l1+1, vk+1*D/32);
+ load(l1+2, vk+2*D/32);
+ load(l1+3, vk+3*D/32);
+ }
+
+ inline void load_8(int l1, __m512bh * vk) const {
+ for (int k = 0; k < 8; ++k) load(l1 + k, vk + k*D/32);
+ }
+};
+
+template <int D, int q_step, int k_step>
+struct FlashQKbf16 {
+ //static_assert(D%32 == 0 && D <= 256);
+ static_assert(D%32 == 0 && D <= 576);
+ static_assert(k_step%32 == 0);
+ static_assert(q_step <= 4 || q_step%4 == 0);
+
+ static inline void mult_mask_kq_one(int l1, int m1, int stride_q, int stride_m, const float * q, const char * mask,
+ __m512bh * qv, const __m512bh * vkh, FlashMS<q_step, k_step>& fms) {
+ // q index is q_step*i1 + m1
+ // k index is k_step*k1 + l1
+ const ggml_half * mp = (const ggml_half *)(mask + stride_m*m1);
+ fms.cache[k_step*m1 + l1 + 0] = fms.cache[k_step*m1 + l1 + 1] = -INFINITY;
+ if (mp[l1+0] == fms.h_inf && mp[l1+1] == fms.h_inf) {
+ return;
+ }
+ auto qr = q + m1*stride_q;
+ for (int i = 0; i < D/32; ++i) {
+ auto val1 = _mm512_loadu_ps(qr + 32*i);
+ auto val2 = _mm512_loadu_ps(qr + 32*i + 16);
+ qv[i] = _mm512_cvtne2ps_pbh(val2, val1);
+ }
+ if (mp[l1+0] != fms.h_inf) {
+ auto vsum = _mm512_setzero_ps();
+ for (int i = 0; i < D/32; ++i) vsum = _mm512_dpbf16_ps(vsum, vkh[i], qv[i]);
+ fms.cache[k_step*m1 + l1 + 0] = _mm512_reduce_add_ps(vsum);
+ }
+ if (mp[l1+1] != fms.h_inf) {
+ auto vsum = _mm512_setzero_ps();
+ for (int i = 0; i < D/32; ++i) vsum = _mm512_dpbf16_ps(vsum, vkh[i+D/32], qv[i]);
+ fms.cache[k_step*m1 + l1 + 1] = _mm512_reduce_add_ps(vsum);
+ }
+ }
+
+ static inline void mult_mask_kq_one(int l1, int m1, int stride_m, const ggml_bf16_t * q, const char * mask,
+ __m512bh * qv, const __m512bh * vkh, FlashMS<q_step, k_step>& fms) {
+ // q index is q_step*i1 + m1
+ // k index is k_step*k1 + l1
+ const ggml_half * mp = (const ggml_half *)(mask + stride_m*m1);
+ fms.cache[k_step*m1 + l1 + 0] = fms.cache[k_step*m1 + l1 + 1] = -INFINITY;
+ if (mp[l1+0] == fms.h_inf && mp[l1+1] == fms.h_inf) {
+ return;
+ }
+ auto qr = q + m1*D;
+ for (int i = 0; i < D/32; ++i) qv[i] = __m512bh(_mm512_loadu_si512((const __m512i*)qr + i));
+ if (mp[l1+0] != fms.h_inf) {
+ auto vsum = _mm512_setzero_ps();
+ for (int i = 0; i < D/32; ++i) vsum = _mm512_dpbf16_ps(vsum, vkh[i], qv[i]);
+ fms.cache[k_step*m1 + l1 + 0] = _mm512_reduce_add_ps(vsum);
+ }
+ if (mp[l1+1] != fms.h_inf) {
+ auto vsum = _mm512_setzero_ps();
+ for (int i = 0; i < D/32; ++i) vsum = _mm512_dpbf16_ps(vsum, vkh[i+D/32], qv[i]);
+ fms.cache[k_step*m1 + l1 + 1] = _mm512_reduce_add_ps(vsum);
+ }
+ }
+
+ static inline void mult_mask_kq_4(int l1, int m1, int stride_q, int stride_m, const float * q, const char * mask,
+ __m512bh * qv, const __m512bh * vkh, FlashMS<q_step, k_step>& fms) {
+ // q index is q_step*i1 + m1
+ // k index is k_step*k1 + l1
+ const ggml_half * mp = (const ggml_half *)(mask + stride_m*m1);
+ fms.cache[k_step*m1 + l1 + 0] = fms.cache[k_step*m1 + l1 + 1] =
+ fms.cache[k_step*m1 + l1 + 2] = fms.cache[k_step*m1 + l1 + 3] = -INFINITY;
+ if (mp[l1+0] == fms.h_inf && mp[l1+1] == fms.h_inf && mp[l1+2] == fms.h_inf && mp[l1+3] == fms.h_inf) {
+ return;
+ }
+ auto qr = q + m1*stride_q;
+ for (int i = 0; i < D/32; ++i) {
+ auto val1 = _mm512_loadu_ps(qr + 32*i);
+ auto val2 = _mm512_loadu_ps(qr + 32*i + 16);
+ qv[i] = _mm512_cvtne2ps_pbh(val2, val1);
+ }
+ for (int k = 0; k < 4; ++k) {
+ if (mp[l1+k] == fms.h_inf) continue;
+ auto vsum = _mm512_setzero_ps();
+ for (int i = 0; i < D/32; ++i) vsum = _mm512_dpbf16_ps(vsum, vkh[i+k*(D/32)], qv[i]);
+ fms.cache[k_step*m1 + l1 + k] = _mm512_reduce_add_ps(vsum);
+ }
+ }
+
+ static inline void mult_mask_kq_4(int l1, int m1, int stride_m, const ggml_bf16_t * q, const char * mask,
+ __m512bh * qv, const __m512bh * vkh, FlashMS<q_step, k_step>& fms) {
+ // q index is q_step*i1 + m1
+ // k index is k_step*k1 + l1
+ const ggml_half * mp = (const ggml_half *)(mask + stride_m*m1);
+ fms.cache[k_step*m1 + l1 + 0] = fms.cache[k_step*m1 + l1 + 1] =
+ fms.cache[k_step*m1 + l1 + 2] = fms.cache[k_step*m1 + l1 + 3] = -INFINITY;
+ if (mp[l1+0] == fms.h_inf && mp[l1+1] == fms.h_inf && mp[l1+2] == fms.h_inf && mp[l1+3] == fms.h_inf) {
+ return;
+ }
+ auto qr = q + m1*D;
+ for (int i = 0; i < D/32; ++i) qv[i] = __m512bh(_mm512_loadu_si512((const __m512i *)qr + i));
+ for (int k = 0; k < 4; ++k) {
+ if (mp[l1+k] == fms.h_inf) continue;
+ auto vsum = _mm512_setzero_ps();
+ for (int i = 0; i < D/32; ++i) vsum = _mm512_dpbf16_ps(vsum, vkh[i+k*(D/32)], qv[i]);
+ fms.cache[k_step*m1 + l1 + k] = _mm512_reduce_add_ps(vsum);
+ }
+ }
+
+ static inline __m128 hsum_float_4x4(__m128 * a) {
+ for (int i = 0; i < 2; ++i) a[i] = _mm_add_ps(_mm_unpacklo_ps(a[i], a[i+2]), _mm_unpackhi_ps(a[i], a[i+2]));
+ return _mm_add_ps(_mm_unpacklo_ps(a[0], a[1]), _mm_unpackhi_ps(a[0], a[1]));
+ }
+
+ template <typename KHelper>
+ static inline void multiply_mask_kq(const KHelper& kh, int stride_q, int stride_m, const float * q,
+ const char * mask, FlashMS<q_step, k_step>& fms) {
+ {
+ __m512bh qv[D/32];
+ if constexpr (D <= 128) {
+ __m512bh vkh[D/8];
+ for (int l1 = 0; l1 < k_step; l1 += 4) {
+ kh.load_4(l1, vkh);
+ for (int j = 0; j < q_step; ++j) {
+ mult_mask_kq_4(l1, j, stride_q, stride_m, q, mask, qv, vkh, fms);
+ }
+ }
+ } else {
+ __m512bh vkh[D/16];
+ for (int l1 = 0; l1 < k_step; l1 += 2) {
+ kh.load_2(l1, vkh);
+ for (int j = 0; j < q_step; ++j) {
+ mult_mask_kq_one(l1, j, stride_q, stride_m, q, mask, qv, vkh, fms);
+ }
+ }
+ }
+ }
+ __m512 vk[k_step/16];
+ for (int j = 0; j < q_step; ++j) {
+ fms.update_M_S(j, vk);
+ }
+ }
+
+ static inline void mult_mask_kq_4(int l1, int m1, const ggml_bf16_t * q,
+ __m512bh * qv, const __m512bh * vkh, FlashMS<q_step, k_step>& fms) {
+ auto qr = q + m1*D;
+ for (int i = 0; i < D/32; ++i) qv[i] = __m512bh(_mm512_loadu_si512((const __m512i *)qr + i));
+ __m128 sum[4];
+ for (int k = 0; k < 4; ++k) {
+ auto vsum = _mm512_setzero_ps();
+ for (int i = 0; i < D/32; ++i) vsum = _mm512_dpbf16_ps(vsum, vkh[i+k*(D/32)], qv[i]);
+ auto aux = _mm256_add_ps(_mm512_castps512_ps256(vsum), _mm512_extractf32x8_ps(vsum, 1));
+ sum[k] = _mm_add_ps(_mm256_castps256_ps128(aux), _mm256_extractf128_ps(aux, 1));
+ }
+ //auto sum4 = _mm_mask_blend_ps(m8, hsum_float_4x4(sum), _mm_set1_ps(-INFINITY));
+ //_mm_storeu_ps(fms.cache + k_step*m1 + l1, sum4);
+ _mm_storeu_ps(fms.cache + k_step*m1 + l1, hsum_float_4x4(sum));
+ }
+
+ static IQK_ALWAYS_INLINE __m256 hsum_float_8x8(__m256 * accm) {
+ for (int i = 0; i < 4; ++i) {
+ accm[i] = _mm256_add_ps(_mm256_permute2f128_ps(accm[i], accm[i+4], 0x20), _mm256_permute2f128_ps(accm[i], accm[i+4], 0x31));
+ //accm[i] = _mm256_set_m128(_mm_add_ps(_mm256_castps256_ps128(accm[i+4]), _mm256_extractf128_ps(accm[i+4], 1)),
+ // _mm_add_ps(_mm256_castps256_ps128(accm[i+0]), _mm256_extractf128_ps(accm[i+0], 1)));
+ }
+ for (int i = 0; i < 2; ++i) accm[i] = _mm256_add_ps(_mm256_unpacklo_ps(accm[i], accm[i+2]), _mm256_unpackhi_ps(accm[i], accm[i+2]));
+ return _mm256_add_ps(_mm256_unpacklo_ps(accm[0], accm[1]), _mm256_unpackhi_ps(accm[0], accm[1]));
+ }
+
+ static inline void mult_mask_kq_8(int l1, int m1, const ggml_bf16_t * q,
+ __m512bh * qv, const __m512bh * vkh, FlashMS<q_step, k_step>& fms) {
+ auto qr = q + m1*D;
+ for (int i = 0; i < D/32; ++i) qv[i] = __m512bh(_mm512_loadu_si512((const __m512i *)qr + i));
+ __m256 sum[8];
+ for (int k = 0; k < 8; ++k) {
+ auto vsum = _mm512_setzero_ps();
+ for (int i = 0; i < D/32; ++i) vsum = _mm512_dpbf16_ps(vsum, vkh[i+k*(D/32)], qv[i]);
+ sum[k] = _mm256_add_ps(_mm512_castps512_ps256(vsum), _mm512_extractf32x8_ps(vsum, 1));
+ }
+ _mm256_storeu_ps(fms.cache + k_step*m1 + l1, hsum_float_8x8(sum));
+ }
+
+ static inline void mult_mask_kq_one(int l1, int m1, const ggml_bf16_t * q,
+ __m512bh * qv, const __m512bh * vkh, FlashMS<q_step, k_step>& fms) {
+ auto qr = q + m1*D;
+ for (int i = 0; i < D/32; ++i) qv[i] = __m512bh(_mm512_loadu_si512((const __m512i*)qr + i));
+ auto vsum = _mm512_setzero_ps();
+ for (int i = 0; i < D/32; ++i) vsum = _mm512_dpbf16_ps(vsum, vkh[i], qv[i]);
+ fms.cache[k_step*m1 + l1 + 0] = _mm512_reduce_add_ps(vsum);
+ vsum = _mm512_setzero_ps();
+ for (int i = 0; i < D/32; ++i) vsum = _mm512_dpbf16_ps(vsum, vkh[i+D/32], qv[i]);
+ fms.cache[k_step*m1 + l1 + 1] = _mm512_reduce_add_ps(vsum);
+ }
+
+#if FA_TIMING
+ template <typename KHelper>
+ static inline void multiply_mask_kq(const KHelper& kh, int stride_m, const ggml_bf16_t * q,
+ const char * mask, FlashMS<q_step, k_step>& fms, Perf& perf) {
+ auto t1 = Perf::cur_time();
+#else
+ template <typename KHelper>
+ static inline void multiply_mask_kq(const KHelper& kh, int stride_m, const ggml_bf16_t * q,
+ const char * mask, FlashMS<q_step, k_step>& fms) {
+#endif
+ if constexpr (q_step == 1) {
+ __m512bh vq[D/32];
+ __m512bh vk[D/32];
+ __m256 sum[8];
+ for (int i = 0; i < D/32; ++i) vq[i] = __m512bh(_mm512_loadu_si512((const __m512i *)q + i));
+ for (int l = 0; l < k_step; l += 8) {
+ for (int k = 0; k < 8; ++k) {
+ kh.load(l+k, vk);
+ auto vsum = _mm512_setzero_ps();
+ for (int i = 0; i < D/32; ++i) vsum = _mm512_dpbf16_ps(vsum, vk[i], vq[i]);
+ sum[k] = _mm256_add_ps(_mm512_castps512_ps256(vsum), _mm512_extractf32x8_ps(vsum, 1));
+ }
+ _mm256_storeu_ps(fms.cache + l, hsum_float_8x8(sum));
+ }
+ }
+ else {
+ __m512bh qv[D/32];
+ if constexpr (D <= 128) {
+ __m512bh vkh[D/4];
+ for (int l1 = 0; l1 < k_step; l1 += 8) {
+ kh.load_8(l1, vkh);
+ for (int j = 0; j < q_step; ++j) mult_mask_kq_8(l1, j, q, qv, vkh, fms);
+ }
+ } else {
+ __m512bh vkh[D/16];
+ for (int l1 = 0; l1 < k_step; l1 += 2) {
+ kh.load_2(l1, vkh);
+ for (int j = 0; j < q_step; ++j) mult_mask_kq_one(l1, j, q, qv, vkh, fms);
+ }
+ }
+ }
+#if FA_TIMING
+ perf.accum_nolock(1, t1);
+ t1 = Perf::cur_time();
+#endif
+ F16::Data vk[k_step/16];
+ for (int j = 0; j < q_step; ++j) {
+ fms.update_M_S(j, vk, mask + stride_m*j);
+ }
+#if FA_TIMING
+ perf.accum_nolock(2, t1);
+#endif
+ }
+
+ template <typename KHelper>
+ static inline void multiply_mask_kq(int nq, const KHelper& kh, int stride_m, const ggml_bf16_t * q,
+ const char * mask, FlashMS<q_step, k_step>& fms) {
+ {
+ __m512bh qv[D/32];
+ if constexpr (D <= 128) {
+ __m512bh vkh[D/8];
+ for (int l1 = 0; l1 < k_step; l1 += 4) {
+ kh.load_4(l1, vkh);
+ for (int j = 0; j < nq; ++j) mult_mask_kq_4(l1, j, q, qv, vkh, fms);
+ }
+ } else {
+ __m512bh vkh[D/16];
+ for (int l1 = 0; l1 < k_step; l1 += 2) {
+ kh.load_2(l1, vkh);
+ for (int j = 0; j < nq; ++j) mult_mask_kq_one(l1, j, q, qv, vkh, fms);
+ }
+ }
+ }
+ F16::Data vk[k_step/16];
+ for (int j = 0; j < nq; ++j) {
+ fms.update_M_S(j, vk, mask + stride_m*j);
+ }
+ }
+
+ template <typename KHelper>
+ static inline void multiply_mask_kq(int nq, const KHelper& kh, int stride_q, int stride_m, const float * q,
+ const char * mask, FlashMS<q_step, k_step>& fms) {
+ {
+ __m512bh qv[D/32];
+ __m512bh vkh[D/16];
+ for (int l1 = 0; l1 < k_step; l1 += 2) {
+ kh.load_2(l1, vkh);
+ for (int m1 = 0; m1 < nq; ++m1) {
+ mult_mask_kq_one(l1, m1, stride_q, stride_m, q, mask, qv, vkh, fms);
+ }
+ }
+ }
+ __m512 vk[k_step/16];
+ for (int j = 0; j < nq; ++j) {
+ fms.update_M_S(j, vk);
+ }
+ }
+
+ static inline void convert(int stride_q, const float * q, ggml_bf16_t * bf16) {
+ auto qr = q;
+ for (int j = 0; j < q_step; ++j) {
+ for (int i = 0; i < D/32; ++i) {
+ auto val1 = _mm512_loadu_ps(qr + 32*i);
+ auto val2 = _mm512_loadu_ps(qr + 32*i + 16);
+ _mm512_storeu_si512((__m512i *)bf16 + i, (__m512i)_mm512_cvtne2ps_pbh(val2, val1));
+ }
+ qr += stride_q;
+ bf16 += D;
+ }
+ }
+
+ static inline void convert(int nq, int stride_q, const float * q, ggml_bf16_t * bf16) {
+ auto qr = q;
+ for (int j = 0; j < nq; ++j) {
+ for (int i = 0; i < D/32; ++i) {
+ auto val1 = _mm512_loadu_ps(qr + 32*i);
+ auto val2 = _mm512_loadu_ps(qr + 32*i + 16);
+ _mm512_storeu_si512((__m512i *)bf16 + i, (__m512i)_mm512_cvtne2ps_pbh(val2, val1));
+ }
+ qr += stride_q;
+ bf16 += D;
+ }
+ }
+};
+
+template <int Dk, int Dv, int q_step, int k_step>
+struct FlashAttnBF16 {
+ //static_assert(Dk%32 == 0 && Dk <= 256);
+ //static_assert(Dv%32 == 0 && Dv <= 256);
+ static_assert(Dk%32 == 0 && Dk <= 576);
+ static_assert(Dv%32 == 0 && Dv <= 512);
+ static_assert(k_step%32 == 0);
+ static_assert(q_step <= 4 || q_step%4 == 0);
+
+ FlashAttnBF16(float scale, float softcap) : fms(scale, softcap) {}
+
+ template <typename KHelper, typename VHelper>
+ void compute(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv,
+ const float * q, const char * mask, float * qkv, [[maybe_unused]] float * M, [[maybe_unused]] float * S) {
+ ggml_bf16_t q_bf16[q_step*Dk];
+#if FA_TIMING
+ Perf perf(false);
+#endif
+ for (int i1 = 0; i1 < nq1/q_step; ++i1) {
+#if FA_TIMING
+ auto t1 = Perf::cur_time();
+#endif
+ fms.init_qstep();
+ kh.reset_block();
+ vh.reset_block();
+ FlashQKbf16<Dk, q_step, k_step>::convert(stride_q, q, q_bf16);
+#if FA_TIMING
+ perf.accum_nolock(0, t1);
+#endif
+ auto mr = mask;
+ for (int k1 = 0; k1 < nk1/k_step; ++k1) {
+#if FA_TIMING
+ //t1 = Perf::cur_time();
+ FlashQKbf16<Dk, q_step, k_step>::multiply_mask_kq(kh, stride_m, q_bf16, mr, fms, perf);
+ //perf.accum_nolock(1, t1);
+ t1 = Perf::cur_time();
+ fqkv.accumulate_qkv(vh, fms);
+ perf.accum_nolock(3, t1);
+#else
+ FlashQKbf16<Dk, q_step, k_step>::multiply_mask_kq(kh, stride_m, q_bf16, mr, fms);
+ fqkv.accumulate_qkv(vh, fms);
+#endif
+ kh.next_block(k_step);
+ vh.next_block(k_step);
+ mr += k_step*sizeof(ggml_half);
+ }
+#if FA_TIMING
+ t1 = Perf::cur_time();
+#endif
+ fqkv.normalize_and_store(fms, stride_qkv, qkv, M, S);
+#if FA_TIMING
+ perf.accum_nolock(4, t1);
+#endif
+
+ q += q_step*stride_q;
+ mask += q_step*stride_m;
+ qkv += q_step*stride_qkv;
+ }
+ int n_left = nq1 - q_step*(nq1/q_step);
+ if (n_left > 0) {
+ fms.init_qstep();
+ kh.reset_block();
+ vh.reset_block();
+ FlashQKbf16<Dk, q_step, k_step>::convert(n_left, stride_q, q, q_bf16);
+ auto mr = mask;
+ for (int k1 = 0; k1 < nk1/k_step; ++k1) {
+ FlashQKbf16<Dk, q_step, k_step>::multiply_mask_kq(n_left, kh, stride_m, q_bf16, mr, fms);
+ fqkv.accumulate_qkv(n_left, vh, fms);
+ kh.next_block(k_step);
+ vh.next_block(k_step);
+ mr += k_step*sizeof(ggml_half);
+ }
+ fqkv.normalize_and_store(fms, n_left, stride_qkv, qkv, M, S);
+ }
+#if FA_TIMING
+ Perf::instance().add(perf);
+#endif
+ }
+
+ FlashMS<q_step, k_step> fms;
+ FlashQKV<Dv, q_step, k_step> fqkv;
+};
+#endif
+
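+// Choose a compile-time q_step based on how many query rows there are; leftover rows fall through to the
+// q_step = 1 instantiation.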
+template <int Dk, int Dv, int k_step, typename KHelper, typename VHelper>
+inline void iqk_flash_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv,
+ const float * q, const char * mask, float scale, float softcap, float * qkv, float * M, float * S) {
+
+ auto update = [&nq1, &mask, &q, &qkv, &M, &S, stride_q, stride_m, stride_qkv] (int n) {
+ nq1 -= n;
+ if (nq1 == 0) return true;
+ q += n*stride_q;
+ mask += n*stride_m;
+ qkv += n*stride_qkv;
+ if (M && S) { M += n; S += n; }
+ return false;
+ };
+ if (nk1 >= 512) {
+ if (nq1 >= 128) {
+ int n_step = nq1/128;
+ FlashAttn<Dk, Dv, 64, k_step> fa(scale, softcap);
+ fa.compute(kh, vh, 128*n_step, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
+ if (update(128*n_step)) return;
+ }
+ if (nq1 >= 64) {
+ int n_step = nq1/64;
+ FlashAttn<Dk, Dv, 64, k_step> fa(scale, softcap);
+ fa.compute(kh, vh, 64*n_step, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
+ if (update(64*n_step)) return;
+ }
+ if (nq1 >= 32) {
+ int n_step = nq1/32;
+ FlashAttn<Dk, Dv, 32, k_step> fa(scale, softcap);
+ fa.compute(kh, vh, 32*n_step, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
+ if (update(32*n_step)) return;
+ }
+ if (nq1 >= 16) {
+ int n_step = nq1/16;
+ FlashAttn<Dk, Dv, 16, k_step> fa(scale, softcap);
+ fa.compute(kh, vh, 16*n_step, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
+ if (update(16*n_step)) return;
+ }
+ }
+ if (nq1 >= 8) {
+ int n_step = nq1/8;
+ FlashAttn<Dk, Dv, 8, k_step> fa(scale, softcap);
+ fa.compute(kh, vh, 8*n_step, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
+ if (update(8*n_step)) return;
+ }
+ else if (nq1 >= 4) {
+ int n_step = nq1/4;
+ FlashAttn<Dk, Dv, 4, k_step> fa(scale, softcap);
+ fa.compute(kh, vh, 4*n_step, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
+ if (update(4*n_step)) return;
+ }
+ else if (nq1 >= 2) {
+ int n_step = nq1/2;
+ FlashAttn<Dk, Dv, 2, k_step> fa(scale, softcap);
+ fa.compute(kh, vh, 2*n_step, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
+ if (update(2*n_step)) return;
+ }
+ FlashAttn<Dk, Dv, 1, k_step> fa(scale, softcap);
+ fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
+}
+
+#ifdef __AVX512BF16__
+template <int Dk, int Dv, int k_step>
+inline void iqk_flash_helper_T(int nq1, int nk1, int stride_q, int stride_k, int stride_v, int stride_m, int stride_qkv,
+ const float * q, const char * k, const char * v, const char * mask,
+ float scale, float softcap, float * qkv, float * M, float * S) {
+ HelperBF16<Dk, k_step> kh(k, stride_k);
+ HelperBF16<Dv, k_step> vh(v, stride_v);
+ if (nk1 >= 4096) {
+ if (nq1 >= 64) {
+ FlashAttnBF16<Dk, Dv, 64, k_step> fa(scale, softcap);
+ fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
+ return;
+ }
+ else if (nq1 >= 16) {
+ FlashAttnBF16<Dk, Dv, 16, k_step> fa(scale, softcap);
+ fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
+ return;
+ }
+ }
+ if (nq1 >= 8) {
+ FlashAttnBF16<Dk, Dv, 8, k_step> fa(scale, softcap);
+ fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
+ } else {
+ FlashAttnBF16<Dk, Dv, 1, k_step> fa(scale, softcap);
+ fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
+ }
+}
+#endif
+
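+// Dispatch on the V cache type to construct the matching VHelper.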
+template <int Dk, int Dv, int k_step, typename KHelper>
+inline bool iqk_flash_helper_T(KHelper& kh, ggml_type type_v,
+ int nq1, int nk1, int stride_q, int stride_v, int stride_m, int stride_qkv,
+ const float * q, const char * v, const char * mask,
+ float scale, float softcap, float * qkv, float * M, float * S) {
+
+ switch (type_v) {
+ case GGML_TYPE_F16: {
+ HelperF16 vh(v, stride_v);
+ iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S);
+ } break;
+#ifdef __AVX512BF16__
+ case GGML_TYPE_BF16: {
+ HelperBF16<Dv, k_step> vh(v, stride_v);
+ iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S);
+ } break;
+#endif
+ case GGML_TYPE_Q8_0: {
+ HelperQ80 vh(v, stride_v);
+ iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S);
+ } break;
+ case GGML_TYPE_Q8_KV: {
+ HelperQ8KV<Dv> vh(v, stride_v);
+ iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S);
+ } break;
+ case GGML_TYPE_Q6_0: {
+ HelperQ60 vh(v, stride_v);
+ iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S);
+ } break;
+#if GGML_IQK_FA_ALL_QUANTS
+ case GGML_TYPE_Q4_0: {
+ HelperQ40 vh(v, stride_v);
+ iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S);
+ } break;
+ case GGML_TYPE_Q4_1: {
+ HelperQ41 vh(v, stride_v);
+ iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S);
+ } break;
+ case GGML_TYPE_IQ4_NL: {
+ HelperIQ4nl vh(v, stride_v);
+ iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S);
+ } break;
+#endif
+ default: return false;
+ }
+ return true;
+}
+
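+// Outer dispatch on the K cache type; returns false if the K/V type combination is not supported.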
+template <int Dk, int Dv, int k_step>
+inline bool iqk_flash_helper_T(ggml_type type_k, ggml_type type_v,
+ int nq1, int nk1, int stride_q, int stride_k, int stride_v, int stride_m, int stride_qkv,
+ const float * q, const char * k, const char * v, const char * mask,
+ float scale, float softcap, float * qkv, float * M, float * S) {
+
+ bool result = false;
+ switch (type_k) {
+ case GGML_TYPE_F16: {
+ HelperF16 kh(k, stride_k);
+ result = iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, M, S);
+ } break;
+ case GGML_TYPE_Q8_0: {
+ HelperQ80 kh(k, stride_k);
+ result = iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, M, S);
+ } break;
+ case GGML_TYPE_Q8_0_R8: {
+ HelperQ80R8<Dk> kh(k, stride_k);
+ result = iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, M, S);
+ } break;
+ case GGML_TYPE_Q6_0: {
+ HelperQ60 kh(k, stride_k);
+ result = iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, M, S);
+ } break;
+#if GGML_IQK_FA_ALL_QUANTS
+ case GGML_TYPE_Q8_KV: {
+ HelperQ8KV<Dk> kh(k, stride_k);
+ result = iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, M, S);
+ } break;
+ case GGML_TYPE_Q4_0: {
+ HelperQ40 kh(k, stride_k);
+ result = iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, M, S);
+ } break;
+ case GGML_TYPE_Q4_1: {
+ HelperQ41 kh(k, stride_k);
+ result = iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, M, S);
+ } break;
+ case GGML_TYPE_IQ4_NL: {
+ HelperIQ4nl kh(k, stride_k);
+ result = iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, M, S);
+ } break;
+#endif
+ default: break;
+ }
+
+ return result;
+}
+
+}
+
+#define IQK_FA_CASE(name) bool name(int int_type_k, int int_type_v,int nq,int nk,\
+ int stride_q, int stride_k, int stride_v, int stride_m, int stride_qkv,\
+ const float * q, const void * k, const void * v, const void * mask,\
+ float scale, float softcap,\
+ float * qkv, float * M, float * S)
+
+IQK_FA_CASE(iqk_fa_576_512);
+IQK_FA_CASE(iqk_fa_192_128);
+IQK_FA_CASE(iqk_fa_256_256);
+IQK_FA_CASE(iqk_fa_128_128);
+IQK_FA_CASE(iqk_fa_96_96);
+IQK_FA_CASE(iqk_fa_64_64);
+
+#endif
+
diff --git a/ggml/src/iqk/iqk_common.h b/ggml/src/iqk/iqk_common.h
index dc3e369f..6feeff1a 100644
--- a/ggml/src/iqk/iqk_common.h
+++ b/ggml/src/iqk/iqk_common.h
@@ -7,6 +7,8 @@
// SPDX-License-Identifier: MIT
//
+#pragma once
+
#include "iqk_config.h"
#if defined IQK_IMPLEMENT
@@ -14,6 +16,7 @@
#include <cstring>
#include <type_traits>
#include <vector>
+#include <cstdint>
#include "ggml-impl.h"
#include "ggml-quants.h"
@@ -79,8 +82,6 @@ struct Perf {
#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)
#endif
-namespace {
-
typedef struct {
int32_t i1;
int32_t i2;
@@ -135,4 +136,694 @@ struct DataInfo {
typedef void (*mul_mat_t)(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x);
+#define IQK_MAX_NY 8
+
+#define IQK_SET_MUL_MAT_FUNCTIONS_T(kernel, Dequantizer, funcs) \
+ funcs[0] = kernel<Dequantizer, 1>;\
+ funcs[1] = kernel<Dequantizer, 2>;\
+ funcs[2] = kernel<Dequantizer, 3>;\
+ funcs[3] = kernel<Dequantizer, 4>;\
+ funcs[4] = kernel<Dequantizer, 5>;\
+ funcs[5] = kernel<Dequantizer, 6>;\
+ funcs[6] = kernel<Dequantizer, 7>;\
+ funcs[7] = kernel<Dequantizer, 8>;\
+
+#define IQK_SET_MUL_MAT_FUNCTIONS(kernel, funcs) \
+ funcs[0] = kernel<1>;\
+ funcs[1] = kernel<2>;\
+ funcs[2] = kernel<3>;\
+ funcs[3] = kernel<4>;\
+ funcs[4] = kernel<5>;\
+ funcs[5] = kernel<6>;\
+ funcs[6] = kernel<7>;\
+ funcs[7] = kernel<8>;\
+
+
+// ==================================================================================================
+
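+// Unpack the 12 bytes of packed 6-bit k-quant scales/mins into four 32-bit words (8 scales and 8 mins).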
+static inline void make_q4_scales(const uint8_t * scales8, uint32_t * aux32) {
+ const uint16_t * scales = (const uint16_t *)scales8;
+ const uint32_t a0 = scales[0] | (scales[1] << 16);
+ const uint32_t a1 = scales[2] | (scales[3] << 16);
+ const uint32_t a2 = scales[4] | (scales[5] << 16);
+ aux32[3] = ((a2 >> 4) & 0x0f0f0f0f) | ((a1 >> 2) & 0x30303030);
+ aux32[1] = ((a2 >> 0) & 0x0f0f0f0f) | ((a0 >> 2) & 0x30303030);
+ aux32[2] = a1 & 0x3f3f3f3f;
+ aux32[0] = a0 & 0x3f3f3f3f;
+}
+
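+// Precomputed ±1 sign-byte patterns (even parity) used by the i-quant kernels when the AVX512-VPOPCNTDQ path is not available.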
+#if !(defined HAVE_FANCY_SIMD && defined __AVX512VPOPCNTDQ__)
+const uint64_t keven_signs[128] = {
+ 0x0101010101010101, 0xff010101010101ff, 0xff0101010101ff01, 0x010101010101ffff,
+ 0xff01010101ff0101, 0x0101010101ff01ff, 0x0101010101ffff01, 0xff01010101ffffff,
+ 0xff010101ff010101, 0x01010101ff0101ff, 0x01010101ff01ff01, 0xff010101ff01ffff,
+ 0x01010101ffff0101, 0xff010101ffff01ff, 0xff010101ffffff01, 0x01010101ffffffff,
+ 0xff0101ff01010101, 0x010101ff010101ff, 0x010101ff0101ff01, 0xff0101ff0101ffff,
+ 0x010101ff01ff0101, 0xff0101ff01ff01ff, 0xff0101ff01ffff01, 0x010101ff01ffffff,
+ 0x010101ffff010101, 0xff0101ffff0101ff, 0xff0101ffff01ff01, 0x010101ffff01ffff,
+ 0xff0101ffffff0101, 0x010101ffffff01ff, 0x010101ffffffff01, 0xff0101ffffffffff,
+ 0xff01ff0101010101, 0x0101ff01010101ff, 0x0101ff010101ff01, 0xff01ff010101ffff,
+ 0x0101ff0101ff0101, 0xff01ff0101ff01ff, 0xff01ff0101ffff01, 0x0101ff0101ffffff,
+ 0x0101ff01ff010101, 0xff01ff01ff0101ff, 0xff01ff01ff01ff01, 0x0101ff01ff01ffff,
+ 0xff01ff01ffff0101, 0x0101ff01ffff01ff, 0x0101ff01ffffff01, 0xff01ff01ffffffff,
+ 0x0101ffff01010101, 0xff01ffff010101ff, 0xff01ffff0101ff01, 0x0101ffff0101ffff,
+ 0xff01ffff01ff0101, 0x0101ffff01ff01ff, 0x0101ffff01ffff01, 0xff01ffff01ffffff,
+ 0xff01ffffff010101, 0x0101ffffff0101ff, 0x0101ffffff01ff01, 0xff01ffffff01ffff,
+ 0x0101ffffffff0101, 0xff01ffffffff01ff, 0xff01ffffffffff01, 0x0101ffffffffffff,
+ 0xffff010101010101, 0x01ff0101010101ff, 0x01ff01010101ff01, 0xffff01010101ffff,
+ 0x01ff010101ff0101, 0xffff010101ff01ff, 0xffff010101ffff01, 0x01ff010101ffffff,
+ 0x01ff0101ff010101, 0xffff0101ff0101ff, 0xffff0101ff01ff01, 0x01ff0101ff01ffff,
+ 0xffff0101ffff0101, 0x01ff0101ffff01ff, 0x01ff0101ffffff01, 0xffff0101ffffffff,
+ 0x01ff01ff01010101, 0xffff01ff010101ff, 0xffff01ff0101ff01, 0x01ff01ff0101ffff,
+ 0xffff01ff01ff0101, 0x01ff01ff01ff01ff, 0x01ff01ff01ffff01, 0xffff01ff01ffffff,
+ 0xffff01ffff010101, 0x01ff01ffff0101ff, 0x01ff01ffff01ff01, 0xffff01ffff01ffff,
+ 0x01ff01ffffff0101, 0xffff01ffffff01ff, 0xffff01ffffffff01, 0x01ff01ffffffffff,
+ 0x01ffff0101010101, 0xffffff01010101ff, 0xffffff010101ff01, 0x01ffff010101ffff,
+ 0xffffff0101ff0101, 0x01ffff0101ff01ff, 0x01ffff0101ffff01, 0xffffff0101ffffff,
+ 0xffffff01ff010101, 0x01ffff01ff0101ff, 0x01ffff01ff01ff01, 0xffffff01ff01ffff,
+ 0x01ffff01ffff0101, 0xffffff01ffff01ff, 0xffffff01ffffff01, 0x01ffff01ffffffff,
+ 0xffffffff01010101, 0x01ffffff010101ff, 0x01ffffff0101ff01, 0xffffffff0101ffff,
+ 0x01ffffff01ff0101, 0xffffffff01ff01ff, 0xffffffff01ffff01, 0x01ffffff01ffffff,
+ 0x01ffffffff010101, 0xffffffffff0101ff, 0xffffffffff01ff01, 0x01ffffffff01ffff,
+ 0xffffffffffff0101, 0x01ffffffffff01ff, 0x01ffffffffffff01, 0xffffffffffffffff,
+};
+#endif
+
+#ifdef __AVX2__
+
+#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)
+
+static inline float hsum_float_4(__m128 x) {
+ x = _mm_add_ps(x, _mm_movehl_ps(x, x));
+ x = _mm_add_ss(x, _mm_movehdup_ps(x));
+ return _mm_cvtss_f32(x);
+}
+static inline float hsum_float_8(__m256 x) {
+ return hsum_float_4(_mm_add_ps(_mm256_castps256_ps128(x), _mm256_extractf128_ps(x, 1)));
+}
+static inline int hsum_i32_8(const __m256i a) {
+ const __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(a), _mm256_extractf128_si256(a, 1));
+ const __m128i hi64 = _mm_unpackhi_epi64(sum128, sum128);
+ const __m128i sum64 = _mm_add_epi32(hi64, sum128);
+ const __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1));
+ return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32));
+}
+static inline float hmax_float_8(__m256 x) {
+ __m128 max4 = _mm_max_ps(_mm256_extractf128_ps(x, 1), _mm256_castps256_ps128(x));
+ max4 = _mm_max_ps( max4, _mm_movehl_ps(max4, max4));
+ max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4));
+ return _mm_cvtss_f32(max4);
+}
+
+static inline __m256 hsum_float_8x8(__m256 * accm) {
+ for (int i = 0; i < 4; ++i) {
+ accm[i] = _mm256_add_ps(_mm256_permute2f128_ps(accm[i], accm[i+4], 0x20), _mm256_permute2f128_ps(accm[i], accm[i+4], 0x31));
+ //accm[i] = _mm256_set_m128(_mm_add_ps(_mm256_castps256_ps128(accm[i+4]), _mm256_extractf128_ps(accm[i+4], 1)),
+ // _mm_add_ps(_mm256_castps256_ps128(accm[i+0]), _mm256_extractf128_ps(accm[i+0], 1)));
+ }
+ for (int i = 0; i < 2; ++i) accm[i] = _mm256_add_ps(_mm256_unpacklo_ps(accm[i], accm[i+2]), _mm256_unpackhi_ps(accm[i], accm[i+2]));
+ return _mm256_add_ps(_mm256_unpacklo_ps(accm[0], accm[1]), _mm256_unpackhi_ps(accm[0], accm[1]));
+}
+
+static inline __m128i load_iq4nl_values_128() {
+ static const uint8_t kvalues_iq4nl[16] = {1, 24, 45, 63, 79, 93, 106, 118, 129, 141, 153, 166, 181, 197, 217, 241};
+ return _mm_loadu_si128((const __m128i *)kvalues_iq4nl);
+}
+
+static inline __m256i load_iq4nl_values_256() {
+ auto val128 = load_iq4nl_values_128();
+ return MM256_SET_M128I(val128, val128);
+}
+
+#ifdef HAVE_FANCY_SIMD
+static inline __m512i load_iq4nl_values_512() {
+ auto val256 = load_iq4nl_values_256();
+ return _mm512_inserti32x8(_mm512_castsi256_si512(val256), val256, 1);
+}
+#endif
+
+static inline __m128i load_iq4k_values_128() {
+ return _mm_loadu_si128((const __m128i *)iq4k_values);
+}
+
+static inline __m256i load_iq4k_values_256() {
+ auto val128 = load_iq4k_values_128();
+ return MM256_SET_M128I(val128, val128);
+}
+
+template <int nrc, typename block_q8 = block_q8_K> struct Q8 {
+
+ constexpr static int nrc_y = nrc;
+
+ Q8(const DataInfo& info) {
+ for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const block_q8 *)info.src1_row(iy);
+ }
+
+#ifdef HAVE_FANCY_SIMD
+ inline __m512i load_quants64(int iy, int i, int j) const { return _mm512_loadu_si512((const __m512i*)y[iy][i].qs + j); }
+#endif
+ inline __m256i load_quants(int iy, int i, int j) const { return _mm256_loadu_si256((const __m256i*)y[iy][i].qs + j); }
+ inline __m256i load_bsums(int iy, int i) const { return _mm256_loadu_si256((const __m256i*)y[iy][i].bsums); }
+ inline float scale(int iy, int i) const { return y[iy][i].d; }
+
+ const block_q8 * y[nrc_y];
+};
+
+template <int nrc> struct Q8_16 {
+
+ constexpr static int nrc_y = nrc;
+
+ Q8_16(const DataInfo& info) {
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto ptr = (const float *)info.src1_row(iy);
+ std::memcpy(d + 5*iy, ptr, 5*sizeof(float));
+ y[iy] = (const int8_t *)(ptr + 5);
+ }
+ }
+
+#ifdef HAVE_FANCY_SIMD
+ inline __m512i load_quants64(int iy, int i) const { return _mm512_loadu_si512((const __m512i*)y[iy] + i); }
+#endif
+ inline __m256i load_quants(int iy, int i) const { return _mm256_loadu_si256((const __m256i*)y[iy] + i); }
+ inline float scale(int iy, int k) const { return d[5*iy+k]; }
+ inline float sum_row(int iy) const { return d[5*iy + 4]; }
+ inline __m128 scale(int iy) const { return _mm_loadu_ps(d + 5*iy); }
+
+ float d[5*nrc_y];
+ const int8_t * y[nrc_y];
+};
+
+struct Scales8KBase {
+ template <typename Q8>
+ inline void accum_mins(const __m128i& mins128, const Q8& q8, int i, float c, __m256 * accd) const {
+ const __m256i mins = MM256_SET_M128I(_mm_shuffle_epi8(mins128, shuffles[1]), _mm_shuffle_epi8(mins128, shuffles[0]));
+ for (int iy = 0; iy < Q8::nrc_y; ++iy) {
+ const __m256i q8s = q8.load_bsums(iy, i);
+ const __m256i prod = _mm256_madd_epi16(mins, q8s);
+ accd[iy] = _mm256_fmadd_ps(_mm256_set1_ps(c*q8.scale(iy, i)), _mm256_cvtepi32_ps(prod), accd[iy]);
+ }
+ }
+ inline __m256i shuffle(__m128i mins) const {
+ return MM256_SET_M128I(_mm_shuffle_epi8(mins, shuffles[1]), _mm_shuffle_epi8(mins, shuffles[0]));
+ }
+ const __m128i shuffles[2] = {_mm_set_epi32(0x07060706, 0x05040504, 0x03020302, 0x01000100),
+ _mm_set_epi32(0x0f0e0f0e, 0x0d0c0d0c, 0x0b0a0b0a, 0x09080908)};
+};
+
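+// Common base for dequantizers: new_row() points x at row ix and, for per-row-scale formats, first reads the
+// row scale d (stored as f16 or f32).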
+template <typename Block, bool per_row_scale = false, bool is_f16 = false>
+struct BaseDequantizer {
+ BaseDequantizer(const void * vx, size_t bx) : vx(vx), bx(bx) {}
+ inline void new_row(int ix) {
+ if constexpr (per_row_scale) {
+ if constexpr (is_f16) {
+ const ggml_half * dptr = (const ggml_half *)((const char *)vx + bx*ix);
+ d = GGML_FP16_TO_FP32(*dptr);
+ x = (const Block *)(dptr + 1);
+ } else {
+ const float * dptr = (const float *)((const char *)vx + bx*ix);
+ d = *dptr;
+ x = (const Block *)(dptr + 1);
+ }
+ } else {
+ x = (const Block *)((const char *)vx + bx*ix);
+ }
+ }
+
+ const void * vx;
+ const size_t bx;
+ const Block * x;
+
+ float d;
+};
+
+template <typename Q8, typename Bits>
+static inline void multiply_add(const Bits& bits, const __m256i * scales, int j, int i, const Q8& q8, __m256i * sumi) {
+ if (j == 0) {
+#ifdef HAVE_FANCY_SIMD
+ for (int iy = 0; iy < Q8::nrc_y; ++iy) {
+ sumi[iy] = _mm256_dpwssd_epi32(_mm256_setzero_si256(), scales[0], _mm256_maddubs_epi16(bits.values[0], q8.load_quants(iy, i, 0)));
+ sumi[iy] = _mm256_dpwssd_epi32(sumi[iy], scales[1], _mm256_maddubs_epi16(bits.values[1], q8.load_quants(iy, i, 1)));
+ sumi[iy] = _mm256_dpwssd_epi32(sumi[iy], scales[2], _mm256_maddubs_epi16(bits.values[2], q8.load_quants(iy, i, 2)));
+ sumi[iy] = _mm256_dpwssd_epi32(sumi[iy], scales[3], _mm256_maddubs_epi16(bits.values[3], q8.load_quants(iy, i, 3)));
+ }
+#else
+ for (int iy = 0; iy < Q8::nrc_y; ++iy) {
+ const __m256i p1 = _mm256_madd_epi16(scales[0], _mm256_maddubs_epi16(bits.values[0], q8.load_quants(iy, i, 0)));
+ const __m256i p2 = _mm256_madd_epi16(scales[1], _mm256_maddubs_epi16(bits.values[1], q8.load_quants(iy, i, 1)));
+ const __m256i p3 = _mm256_madd_epi16(scales[2], _mm256_maddubs_epi16(bits.values[2], q8.load_quants(iy, i, 2)));
+ const __m256i p4 = _mm256_madd_epi16(scales[3], _mm256_maddubs_epi16(bits.values[3], q8.load_quants(iy, i, 3)));
+ sumi[iy] = _mm256_add_epi32(_mm256_add_epi32(p1, p3), _mm256_add_epi32(p2, p4));
+ }
+#endif
+ } else {
+#ifdef HAVE_FANCY_SIMD
+ for (int iy = 0; iy < Q8::nrc_y; ++iy) {
+ sumi[iy] = _mm256_dpwssd_epi32(sumi[iy], scales[0], _mm256_maddubs_epi16(bits.values[0], q8.load_quants(iy, i, 4)));
+ sumi[iy] = _mm256_dpwssd_epi32(sumi[iy], scales[1], _mm256_maddubs_epi16(bits.values[1], q8.load_quants(iy, i, 5)));
+ sumi[iy] = _mm256_dpwssd_epi32(sumi[iy], scales[2], _mm256_maddubs_epi16(bits.values[2], q8.load_quants(iy, i, 6)));
+ sumi[iy] = _mm256_dpwssd_epi32(sumi[iy], scales[3], _mm256_maddubs_epi16(bits.values[3], q8.load_quants(iy, i, 7)));
+ }
+#else
+ for (int iy = 0; iy < Q8::nrc_y; ++iy) {
+ const __m256i p1 = _mm256_madd_epi16(scales[0], _mm256_maddubs_epi16(bits.values[0], q8.load_quants(iy, i, 4)));
+ const __m256i p2 = _mm256_madd_epi16(scales[1], _mm256_maddubs_epi16(bits.values[1], q8.load_quants(iy, i, 5)));
+ const __m256i p3 = _mm256_madd_epi16(scales[2], _mm256_maddubs_epi16(bits.values[2], q8.load_quants(iy, i, 6)));
+ const __m256i p4 = _mm256_madd_epi16(scales[3], _mm256_maddubs_epi16(bits.values[3], q8.load_quants(iy, i, 7)));
+ sumi[iy] = _mm256_add_epi32(sumi[iy], _mm256_add_epi32(p1, p3));
+ sumi[iy] = _mm256_add_epi32(sumi[iy], _mm256_add_epi32(p2, p4));
+ }
+#endif
+ }
+}
+
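+// multiply_add_avx2: _mm256_maddubs_epi16 needs an unsigned left operand, so the (signed) x
+// quants are made non-negative with _mm256_sign_epi8(x, x) and their signs are transferred to
+// the q8 quants via _mm256_sign_epi8(q8, x) before the multiply-add.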
+template <typename Q8, typename Bits>
+static inline void multiply_add_avx2(const Bits& bits, const __m256i * scales, int j, int i, const Q8& q8, __m256i * sumi) {
+ __m256i p[4];
+ if (j == 0) {
+ for (int iy = 0; iy < Q8::nrc_y; ++iy) {
+ for (int k = 0; k < 4; ++k) {
+ auto s = _mm256_sign_epi8(bits.values[k], bits.values[k]);
+ p[k] = _mm256_madd_epi16(scales[k], _mm256_maddubs_epi16(s, _mm256_sign_epi8(q8.load_quants(iy, i, k), bits.values[k])));
+ }
+ sumi[iy] = _mm256_add_epi32(_mm256_add_epi32(p[0], p[1]), _mm256_add_epi32(p[2], p[3]));
+ }
+ } else {
+ for (int iy = 0; iy < Q8::nrc_y; ++iy) {
+ for (int k = 0; k < 4; ++k) {
+ auto s = _mm256_sign_epi8(bits.values[k], bits.values[k]);
+ p[k] = _mm256_madd_epi16(scales[k], _mm256_maddubs_epi16(s, _mm256_sign_epi8(q8.load_quants(iy, i, 4+k), bits.values[k])));
+ }
+ sumi[iy] = _mm256_add_epi32(sumi[iy], _mm256_add_epi32(p[0], p[2]));
+ sumi[iy] = _mm256_add_epi32(sumi[iy], _mm256_add_epi32(p[1], p[3]));
+ }
+ }
+}
+
+#ifdef HAVE_FANCY_SIMD
+
+struct BlockPermuter {
+ const __m512i permute1 = _mm512_set_epi64(11, 10, 9, 8, 3, 2, 1, 0);
+ const __m512i permute2 = _mm512_set_epi64(15, 14, 13, 12, 7, 6, 5, 4);
+};
+
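+// Q4Bits::prepare() loads 128 bytes of packed 4-bit quants and splits them into low and high
+// nibbles; the permutex2var with permute1/permute2 stitches the two halves back together so
+// that, with the usual k-quant layout (low nibbles of a 32-byte group holding the first 32
+// values and the high nibbles the next 32), values[0..3] end up holding the 256 quants in
+// source order. prepare64() keeps the simpler low-then-high ordering, and prepare64a() builds
+// each register from one 32-byte load with the high nibbles in its upper 256-bit half.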
+struct Q4Bits {
+ inline void prepare(const uint8_t * q4) {
+ auto q4bits = _mm512_loadu_si512((const __m512i*)q4 + 0);
+ auto tmp1 = _mm512_and_si512(q4bits, ml);
+ auto tmp2 = _mm512_and_si512(_mm512_srli_epi16(q4bits, 4), ml);
+ values[0] = _mm512_permutex2var_epi64(tmp1, perm.permute1, tmp2);
+ values[1] = _mm512_permutex2var_epi64(tmp1, perm.permute2, tmp2);
+ q4bits = _mm512_loadu_si512((const __m512i*)q4 + 1);
+ tmp1 = _mm512_and_si512(q4bits, ml);
+ tmp2 = _mm512_and_si512(_mm512_srli_epi16(q4bits, 4), ml);
+ values[2] = _mm512_permutex2var_epi64(tmp1, perm.permute1, tmp2);
+ values[3] = _mm512_permutex2var_epi64(tmp1, perm.permute2, tmp2);
+ }
+ inline void prepare64(const uint8_t * q4) {
+ auto q4bits = _mm512_loadu_si512((const __m512i*)q4 + 0);
+ values[0] = _mm512_and_si512(q4bits, ml);
+ values[1] = _mm512_and_si512(_mm512_srli_epi16(q4bits, 4), ml);
+ q4bits = _mm512_loadu_si512((const __m512i*)q4 + 1);
+ values[2] = _mm512_and_si512(q4bits, ml);
+ values[3] = _mm512_and_si512(_mm512_srli_epi16(q4bits, 4), ml);
+ }
+ inline void prepare64a(const uint8_t * q4) {
+ for (int k = 0; k < 4; ++k) {
+ auto q4bits = _mm256_loadu_si256((const __m256i*)q4 + k);
+ values[k] = _mm512_inserti32x8(_mm512_castsi256_si512(q4bits), _mm256_srli_epi16(q4bits, 4), 1);
+ values[k] = _mm512_and_si512(values[k], ml);
+ }
+ }
+ __m512i values[4];
+ const __m512i ml = _mm512_set1_epi8(0xf);
+ const BlockPermuter perm;
+};
+
+struct Q2Bits {
+ inline void prepare(const uint8_t * q2) {
+
+ auto q2bits = _mm512_loadu_si512((const __m512i*)q2);
+ auto tmp = _mm512_srli_epi16(q2bits, 2);
+
+ values[0] = _mm512_permutex2var_epi64(q2bits, perm.permute1, tmp);
+ values[2] = _mm512_permutex2var_epi64(q2bits, perm.permute2, tmp);
+ values[1] = _mm512_and_si512(_mm512_srli_epi16(values[0], 4), ml);
+ values[3] = _mm512_and_si512(_mm512_srli_epi16(values[2], 4), ml);
+ values[0] = _mm512_and_si512(values[0], ml);
+ values[2] = _mm512_and_si512(values[2], ml);
+ }
+ __m512i values[4];
+ const __m512i ml = _mm512_set1_epi8(0x03);
+ BlockPermuter perm;
+};
+
+#else
+
+struct Q2Bits {
+ inline void prepare(const uint8_t * q2, int j) {
+ auto q2bits = _mm256_loadu_si256((const __m256i *)q2 + j);
+ values[0] = _mm256_and_si256(q2bits, ml);
+ values[1] = _mm256_and_si256(_mm256_srli_epi16(q2bits, 2), ml);
+ values[2] = _mm256_and_si256(_mm256_srli_epi16(q2bits, 4), ml);
+ values[3] = _mm256_and_si256(_mm256_srli_epi16(q2bits, 6), ml);
+ }
+ __m256i values[4];
+ const __m256i ml = _mm256_set1_epi8(0x03);
+};
+
+struct Q4Bits {
+ inline void prepare(const uint8_t * q4, int j) {
+ auto q4bits = _mm256_loadu_si256((const __m256i*)q4 + 2*j+0);
+ values[0] = _mm256_and_si256(q4bits, ml);
+ values[1] = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), ml);
+ q4bits = _mm256_loadu_si256((const __m256i*)q4 + 2*j+1);
+ values[2] = _mm256_and_si256(q4bits, ml);
+ values[3] = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), ml);
+ }
+ inline void prepare64(const uint8_t * q4, int j) {
+ auto q4bits = _mm256_loadu_si256((const __m256i*)q4 + 2*j+0);
+ values[0] = _mm256_and_si256(q4bits, ml);
+ values[2] = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), ml);
+ q4bits = _mm256_loadu_si256((const __m256i*)q4 + 2*j+1);
+ values[1] = _mm256_and_si256(q4bits, ml);
+ values[3] = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), ml);
+ }
+ inline void prepare16(const uint8_t * q4, int j) {
+ values[0] = dequant16(q4 + 64*j + 0);
+ values[1] = dequant16(q4 + 64*j + 16);
+ values[2] = dequant16(q4 + 64*j + 32);
+ values[3] = dequant16(q4 + 64*j + 48);
+ }
+ inline __m256i dequant16(const uint8_t * qs) const {
+ const __m128i aux128 = _mm_loadu_si128((const __m128i *)qs);
+ const __m256i aux256 = MM256_SET_M128I(_mm_srli_epi16(aux128, 4), aux128);
+ return _mm256_and_si256(ml, aux256);
+ }
+ __m256i values[4];
+ const __m256i ml = _mm256_set1_epi8(0xf);
+};
+
+#endif
+
+#else
+// ------------------------------------ __aarch64__ --------------------------------------------------
+
+template <int nrc, typename block_q8 = block_q8_K> struct Q8 {
+
+ constexpr static int nrc_y = nrc;
+
+ Q8(const DataInfo& info) {
+ for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const block_q8 *)info.src1_row(iy);
+ }
+
+ inline int8x16x2_t load_quants(int iy, int i, int j) const { return vld1q_s8_x2(y[iy][i].qs + 32*j); }
+ inline int8x16x4_t load_quants_64(int iy, int i, int j) const { return vld1q_s8_x4(y[iy][i].qs + 64*j); }
+ inline int16x8x2_t load_bsums(int iy, int i) const { return vld1q_s16_x2(y[iy][i].bsums); }
+ inline int16x8_t load_bsums8(int iy, int i) const {
+ auto q8s = vld1q_s16_x2(y[iy][i].bsums);
+ return vpaddq_s16(q8s.val[0], q8s.val[1]);
+ }
+ inline float scale(int iy, int i) const { return y[iy][i].d; }
+
+ const block_q8 * y[nrc_y];
+};
+
+template <typename block_q, bool has_row_scale = false, bool scale_is_f16 = false>
+struct BaseDequantizer {
+ BaseDequantizer(const void * vx, size_t bx, int nrc) : vx(vx), x(nullptr), bx(bx), nrc(nrc) {}
+ inline void new_row(int ix) {
+ if constexpr (has_row_scale) {
+ if constexpr (scale_is_f16) {
+ const ggml_half * dptr = (const ggml_half *)((const char *)vx + ix*bx);
+ d = GGML_FP16_TO_FP32(*dptr);
+ x = (const block_q *)(dptr + 1);
+ } else {
+ const float * dptr = (const float *)((const char *)vx + ix*bx);
+ d = *dptr;
+ x = (const block_q *)(dptr + 1);
+ }
+ } else {
+ x = (const block_q *)((const char *)vx + ix*bx);
+ }
+ }
+ const void * vx;
+ const block_q * x;
+ const size_t bx;
+ const int nrc;
+ float d;
+};
+
+struct Q4bits {
+ const uint8x16_t m4b = vdupq_n_u8(0xf);
+ uint8x16x4_t b1, b2;
+ inline void prepare4(uint8x16x4_t& b, const uint8x16_t * val) const {
+ b.val[0] = vandq_u8(val[0], m4b);
+ b.val[2] = vshrq_n_u8(val[0], 4);
+ b.val[1] = vandq_u8(val[1], m4b);
+ b.val[3] = vshrq_n_u8(val[1], 4);
+ }
+ inline void prepare4_16(uint8x16x4_t& b, const uint8x16_t * val) const {
+ b.val[0] = vandq_u8(val[0], m4b);
+ b.val[1] = vshrq_n_u8(val[0], 4);
+ b.val[2] = vandq_u8(val[1], m4b);
+ b.val[3] = vshrq_n_u8(val[1], 4);
+ }
+ inline void prepare(const uint8_t * qs) {
+ auto q4bits = vld1q_u8_x2(qs);
+ prepare4(b1, q4bits.val);
+ q4bits = vld1q_u8_x2(qs+32);
+ prepare4(b2, q4bits.val);
+ }
+ inline void prepare_v2(const uint8_t * qs) {
+ auto q4bits = vld1q_u8_x4(qs);
+ prepare4(b1, q4bits.val+0);
+ prepare4(b2, q4bits.val+2);
+ }
+ inline void prepare64(const uint8_t * qs) {
+ auto q4bits = vld1q_u8_x4(qs);
+ b1.val[0] = vandq_u8(q4bits.val[0], m4b);
+ b1.val[1] = vandq_u8(q4bits.val[1], m4b);
+ b1.val[2] = vandq_u8(q4bits.val[2], m4b);
+ b1.val[3] = vandq_u8(q4bits.val[3], m4b);
+ b2.val[0] = vshrq_n_u8(q4bits.val[0], 4);
+ b2.val[1] = vshrq_n_u8(q4bits.val[1], 4);
+ b2.val[2] = vshrq_n_u8(q4bits.val[2], 4);
+ b2.val[3] = vshrq_n_u8(q4bits.val[3], 4);
+ }
+ inline void prepare16(const uint8_t * qs) {
+ auto q4bits = vld1q_u8_x2(qs);
+ prepare4_16(b1, q4bits.val);
+ q4bits = vld1q_u8_x2(qs+32);
+ prepare4_16(b2, q4bits.val);
+ }
+ inline void prepare16_v2(const uint8_t * qs) {
+ auto q4bits = vld1q_u8_x4(qs);
+ prepare4_16(b1, q4bits.val+0);
+ prepare4_16(b2, q4bits.val+2);
+ }
+};
+
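+// Q2bits unpacks 32 bytes of packed 2-bit quants: the four 2-bit fields of every byte are
+// peeled off by repeated >>2 followed by a mask with 0x03 (the m4b name is inherited from
+// Q4bits; here it holds the 2-bit mask). Fields 0 and 1 go to b1, fields 2 and 3 to b2.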
+struct Q2bits {
+ const uint8x16_t m4b = vdupq_n_u8(0x03);
+ uint8x16x4_t b1, b2;
+ inline void prepare(const uint8_t * qs) {
+ auto q2bits = vld1q_u8_x2(qs);
+ b1.val[0] = vandq_u8(q2bits.val[0], m4b);
+ b1.val[1] = vandq_u8(q2bits.val[1], m4b);
+
+ q2bits.val[0] = vshrq_n_u8(q2bits.val[0], 2);
+ q2bits.val[1] = vshrq_n_u8(q2bits.val[1], 2);
+ b1.val[2] = vandq_u8(q2bits.val[0], m4b);
+ b1.val[3] = vandq_u8(q2bits.val[1], m4b);
+
+ q2bits.val[0] = vshrq_n_u8(q2bits.val[0], 2);
+ q2bits.val[1] = vshrq_n_u8(q2bits.val[1], 2);
+ b2.val[0] = vandq_u8(q2bits.val[0], m4b);
+ b2.val[1] = vandq_u8(q2bits.val[1], m4b);
+
+ q2bits.val[0] = vshrq_n_u8(q2bits.val[0], 2);
+ q2bits.val[1] = vshrq_n_u8(q2bits.val[1], 2);
+ b2.val[2] = vandq_u8(q2bits.val[0], m4b);
+ b2.val[3] = vandq_u8(q2bits.val[1], m4b);
+ }
+};
+
+template <typename Q8>
+static inline void compute_8_blocks(const uint8x16x4_t& qx_1, const uint8x16x4_t& qx_2, const Q8& q8,
+ const int32x4x2_t& scales, int iy, int i, int j, int32x4_t& sumi) {
+ auto mzero = vdupq_n_s32(0);
+ auto q8b_1 = q8.load_quants(iy, i, 4*j+0);
+ auto p1 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_1.val[0]), q8b_1.val[0]),
+ vreinterpretq_s8_u8(qx_1.val[1]), q8b_1.val[1]); // block 1
+ auto q8b_2 = q8.load_quants(iy, i, 4*j+1);
+ auto p2 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_1.val[2]), q8b_2.val[0]),
+ vreinterpretq_s8_u8(qx_1.val[3]), q8b_2.val[1]); // block 2
+ auto p12 = vpaddq_s32(p1, p2);
+
+ auto q8b_3 = q8.load_quants(iy, i, 4*j+2);
+    auto p3 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_2.val[0]), q8b_3.val[0]),
+                              vreinterpretq_s8_u8(qx_2.val[1]), q8b_3.val[1]); // block 3
+ auto q8b_4 = q8.load_quants(iy, i, 4*j+3);
+    auto p4 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_2.val[2]), q8b_4.val[0]),
+                              vreinterpretq_s8_u8(qx_2.val[3]), q8b_4.val[1]); // block 4
+ auto p34 = vpaddq_s32(p3, p4);
+
+ auto pall = vpaddq_s32(p12, p34);
+ sumi = vmlaq_s32(sumi, scales.val[j], pall);
+}
+
+template <typename Q8>
+static inline void compute_16_blocks(const uint8x16x4_t& qx_1, const uint8x16x4_t& qx_2, const Q8& q8,
+ const int32x4x4_t& scales, int iy, int i, int j, int32x4_t& sumi) {
+
+ auto mzero = vdupq_n_s32(0);
+ auto q8b_1 = q8.load_quants(iy, i, 4*j+0);
+ auto p1 = vpaddq_s32(ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_1.val[0]), q8b_1.val[0]),
+ ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_1.val[1]), q8b_1.val[1])); // blocks 0, 0, 1, 1,
+ auto q8b_2 = q8.load_quants(iy, i, 4*j+1);
+    auto p2 = vpaddq_s32(ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_1.val[2]), q8b_2.val[0]),
+                          ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_1.val[3]), q8b_2.val[1])); // blocks 2, 2, 3, 3,
+ auto p12 = vpaddq_s32(p1, p2); // blocks 0, 1, 2, 3
+ sumi = vmlaq_s32(sumi, scales.val[2*j+0], p12);
+
+ auto q8b_3 = q8.load_quants(iy, i, 4*j+2);
+ auto p3 = vpaddq_s32(ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_2.val[0]), q8b_3.val[0]),
+ ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_2.val[1]), q8b_3.val[1])); // block 4, 4, 5, 5,
+ auto q8b_4 = q8.load_quants(iy, i, 4*j+3);
+ auto p4 = vpaddq_s32(ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_2.val[2]), q8b_4.val[0]),
+ ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_2.val[3]), q8b_4.val[1])); // block 6, 6, 7, 7,
+ auto p34 = vpaddq_s32(p3, p4); // blocks 4, 5, 6, 7
+ sumi = vmlaq_s32(sumi, scales.val[2*j+1], p34);
+}
+
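+// SignHelper applies packed sign bits to int8 vectors: init() must be called first (it seeds
+// `shuffle` so that the low 8 lanes select sign byte 0 and the high 8 lanes sign byte 1).
+// Each apply_signs_1() call isolates one bit per lane with smask, turns it into +1/-1 via the
+// compare + OR-with-1 trick, multiplies it into b[0], and advances shuffle by 2 so the next
+// call consumes the next pair of sign bytes from signs16.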
+struct SignHelper {
+
+ inline void init() { shuffle = vcombine_u8(vdup_n_u8(0), vdup_n_u8(1)); }
+
+ inline void apply_signs_1(uint8x16_t * b, const uint8x16_t& signs16) {
+ auto aux = vqtbl1q_u8(signs16, shuffle);
+ auto s = vreinterpretq_s8_u8(vorrq_u8(vceqq_u8(vandq_u8(aux, smask), smask), m1));
+ b[0] = vreinterpretq_u8_s8(vmulq_s8(vreinterpretq_s8_u8(b[0]), s));
+ shuffle = vaddq_u8(shuffle, step);
+ }
+
+ const uint8x16_t smask = vreinterpretq_u8_u64(vdupq_n_u64(0x8040201008040201));
+ const uint8x16_t m1 = vdupq_n_u8(1);
+ const uint8x16_t step = vdupq_n_u8(2);
+ uint8x16_t shuffle;
+};
+
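+// mul_mat_qX_K_q8_K_T is the generic NEON GEMM for super-block quants: one x row at a time
+// against nrc_y q8_K rows. Per super-block the Dequantizer supplies scales and unpacked quants
+// (bits.b1/bits.b2), int32 dot products are accumulated per row and then folded into the float
+// accumulators with d * q8.scale; a final horizontal add stores one float per (ix, iy).
+// A Dequantizer plugged in here needs roughly the interface below (a hypothetical sketch
+// inferred from the calls in the template; the names are illustrative only):
+//
+//   struct DequantizerQXK final : public BaseDequantizer<block_qX_K> {
+//       Q4bits bits;                                          // unpacked quants for one super-block half
+//       constexpr static int num_blocks() { return 8; }       // 8 or 16 blocks per super-block
+//       constexpr static bool should_scale_quants() { return false; }
+//       template <typename Q8>
+//       int32x4x2_t new_block(int i, const Q8& q8, float32x4_t * acc);  // int32x4x4_t for 16 blocks
+//       void prepare(int i, int j);                           // fill bits.b1/bits.b2 for half j
+//   };
+//
+// Dequantizers that return true from should_scale_quants() instead provide process_scales(),
+// prepare() and compute(), which the template uses when nrc_y > 1.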
+template <typename Dequantizer, int nrc_y>
+static void mul_mat_qX_K_q8_K_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ assert(n % QK_K == 0);
+ const int nb = n / QK_K;
+
+ Q8<nrc_y, block_q8_K> q8(info);
+
+ Dequantizer deq(vx, bx, nrc_y);
+
+ for (int ix = 0; ix < nrc_x; ++ix) {
+
+ deq.new_row(ix);
+
+ float32x4_t acc[nrc_y];
+ for (int iy = 0; iy < nrc_y; ++iy) acc[iy] = vdupq_n_f32(0.f);
+
+ for (int i = 0; i < nb; ++i) {
+
+ int32x4_t sumi[nrc_y];
+ for (int iy = 0; iy < nrc_y; ++iy) sumi[iy] = vdupq_n_s32(0);
+
+ if constexpr (nrc_y > 1 && Dequantizer::should_scale_quants()) {
+ deq.process_scales(i, q8, acc);
+ deq.prepare(i, 0);
+ deq.compute(q8, i, 0, sumi);
+ deq.prepare(i, 1);
+ deq.compute(q8, i, 1, sumi);
+ } else {
+ if constexpr (Dequantizer::num_blocks() == 8) {
+ auto scales = deq.new_block(i, q8, acc);
+ deq.prepare(i, 0);
+ for (int iy = 0; iy < nrc_y; ++iy) compute_8_blocks(deq.bits.b1, deq.bits.b2, q8, scales, iy, i, 0, sumi[iy]);
+ deq.prepare(i, 1);
+ for (int iy = 0; iy < nrc_y; ++iy) compute_8_blocks(deq.bits.b1, deq.bits.b2, q8, scales, iy, i, 1, sumi[iy]);
+ }
+ else if constexpr (Dequantizer::num_blocks() == 16) {
+ auto scales = deq.new_block(i, q8, acc);
+ deq.prepare(i, 0);
+ for (int iy = 0; iy < nrc_y; ++iy) compute_16_blocks(deq.bits.b1, deq.bits.b2, q8, scales, iy, i, 0, sumi[iy]);
+ deq.prepare(i, 1);
+ for (int iy = 0; iy < nrc_y; ++iy) compute_16_blocks(deq.bits.b1, deq.bits.b2, q8, scales, iy, i, 1, sumi[iy]);
+ }
+ else {
+ GGML_ASSERT(false);
+ }
+ }
+
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ acc[iy] = vmlaq_f32(acc[iy], vcvtq_f32_s32(sumi[iy]), vdupq_n_f32(deq.d*q8.scale(iy, i)));
+ }
+ }
+
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ info.store(ix, iy, vaddvq_f32(acc[iy]));
+ }
+ }
+}
+
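+// The interleaved_dotq helpers below work on row-interleaved layouts: each qx[k] holds 4
+// consecutive quants from every one of the 4 interleaved rows, and vdotq_laneq_s32 pairs it
+// with the matching 4-byte lane of the q8 vector(s), so each of the four result lanes
+// accumulates the dot product for one row. The b16 variant keeps two partial accumulators.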
+static IQK_ALWAYS_INLINE int32x4_t interleaved_dotq(const int8x16_t * qx, const int8x16x2_t& y) {
+ auto sumi = vdupq_n_s32(0);
+ sumi = vdotq_laneq_s32(sumi, qx[0], y.val[0], 0);
+ sumi = vdotq_laneq_s32(sumi, qx[1], y.val[1], 0);
+ sumi = vdotq_laneq_s32(sumi, qx[2], y.val[0], 1);
+ sumi = vdotq_laneq_s32(sumi, qx[3], y.val[1], 1);
+ sumi = vdotq_laneq_s32(sumi, qx[4], y.val[0], 2);
+ sumi = vdotq_laneq_s32(sumi, qx[5], y.val[1], 2);
+ sumi = vdotq_laneq_s32(sumi, qx[6], y.val[0], 3);
+ sumi = vdotq_laneq_s32(sumi, qx[7], y.val[1], 3);
+ return sumi;
+}
+
+static IQK_ALWAYS_INLINE int32x4x2_t interleaved_dotq_b16(const int8x16_t * qx, const int8x16x2_t& y) {
+ int32x4x2_t sumi = { vdupq_n_s32(0), vdupq_n_s32(0) };
+ sumi.val[0] = vdotq_laneq_s32(sumi.val[0], qx[0], y.val[0], 0);
+ sumi.val[1] = vdotq_laneq_s32(sumi.val[1], qx[1], y.val[1], 0);
+ sumi.val[0] = vdotq_laneq_s32(sumi.val[0], qx[2], y.val[0], 1);
+ sumi.val[1] = vdotq_laneq_s32(sumi.val[1], qx[3], y.val[1], 1);
+ sumi.val[0] = vdotq_laneq_s32(sumi.val[0], qx[4], y.val[0], 2);
+ sumi.val[1] = vdotq_laneq_s32(sumi.val[1], qx[5], y.val[1], 2);
+ sumi.val[0] = vdotq_laneq_s32(sumi.val[0], qx[6], y.val[0], 3);
+ sumi.val[1] = vdotq_laneq_s32(sumi.val[1], qx[7], y.val[1], 3);
+ return sumi;
+}
+
+static IQK_ALWAYS_INLINE int32x4_t interleaved_dotq(const int8x16_t * qx, const int8x16_t& y) {
+ auto sumi = vdupq_n_s32(0);
+ sumi = vdotq_laneq_s32(sumi, qx[0], y, 0);
+ sumi = vdotq_laneq_s32(sumi, qx[1], y, 1);
+ sumi = vdotq_laneq_s32(sumi, qx[2], y, 2);
+ sumi = vdotq_laneq_s32(sumi, qx[3], y, 3);
+ return sumi;
+}
+
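+// prepare_iq4_nl_quants uses vqtbl1q_s8 to turn every 4-bit index into the corresponding entry
+// of the 16-value non-linear codebook passed in `values`, i.e. dequantization is a single
+// table lookup per vector; the _r8 variant does the same for a pair of 16-byte vectors.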
+static IQK_ALWAYS_INLINE void prepare_iq4_nl_quants(const int8x16_t& values, const uint8x16_t& m4, const uint8x16x4_t& bits, int8x16_t * qx) {
+ qx[0] = vqtbl1q_s8(values, vandq_u8(bits.val[0], m4)); // 0...3 from the 4 rows
+ qx[1] = vqtbl1q_s8(values, vandq_u8(bits.val[1], m4)); // 16..19
+ qx[2] = vqtbl1q_s8(values, vandq_u8(bits.val[2], m4)); // 4...7
+ qx[3] = vqtbl1q_s8(values, vandq_u8(bits.val[3], m4)); // 20..23
+ qx[4] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[0], 4)); // 8..11
+ qx[5] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[1], 4)); // 24..27
+ qx[6] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[2], 4)); // 12..15
+ qx[7] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[3], 4)); // 28..31
+}
+
+static IQK_ALWAYS_INLINE void prepare_iq4_nl_quants_r8(const int8x16_t& values, const uint8x16_t& m4, const uint8x16x2_t& bits, int8x16_t * qx) {
+ qx[0] = vqtbl1q_s8(values, vandq_u8( bits.val[0], m4));
+ qx[1] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[0], 4));
+ qx[2] = vqtbl1q_s8(values, vandq_u8( bits.val[1], m4));
+ qx[3] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[1], 4));
+}
+
+#endif
+
#endif
diff --git a/ggml/src/iqk/iqk_gemm_1bit.cpp b/ggml/src/iqk/iqk_gemm_1bit.cpp
new file mode 100644
index 00000000..728604f9
--- /dev/null
+++ b/ggml/src/iqk/iqk_gemm_1bit.cpp
@@ -0,0 +1,2282 @@
+#include "iqk_gemm_1bit.h"
+
+#ifdef IQK_IMPLEMENT
+
+#include "ggml-impl.h"
+
+#define GGML_COMMON_IMPL_C
+#include "ggml-common.h"
+
+namespace {
+
+#ifdef __AVX2__
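+// iq1s_grid_us: the IQ1_S grid in unsigned form (values shifted to 0/1/2). Each 64-bit entry
+// packs 8 grid values, one byte per value; the non-AVX2 fallback table below stores the same
+// grid with 4 bits per value.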
+static const uint64_t iq1s_grid_us[2048] = {
+ 0x0000000000000000, 0x0000000000000002, 0x0000000000000101, 0x0000000000000200,
+ 0x0000000000000202, 0x0000000000010001, 0x0000000000010101, 0x0000000000020000,
+ 0x0000000000020002, 0x0000000000020200, 0x0000000000020202, 0x0000000001000101,
+ 0x0000000001010001, 0x0000000001010100, 0x0000000001010102, 0x0000000001020101,
+ 0x0000000002000000, 0x0000000002000002, 0x0000000002000200, 0x0000000002000202,
+ 0x0000000002010101, 0x0000000002020000, 0x0000000002020002, 0x0000000002020200,
+ 0x0000000002020202, 0x0000000100000100, 0x0000000100000101, 0x0000000100010001,
+ 0x0000000100010100, 0x0000000100010102, 0x0000000100010201, 0x0000000100010202,
+ 0x0000000100020101, 0x0000000101000001, 0x0000000101000102, 0x0000000101000201,
+ 0x0000000101010002, 0x0000000101010101, 0x0000000101010202, 0x0000000101020001,
+ 0x0000000101020100, 0x0000000101020102, 0x0000000101020200, 0x0000000102000101,
+ 0x0000000102010001, 0x0000000102010100, 0x0000000102010102, 0x0000000102020101,
+ 0x0000000200000000, 0x0000000200000002, 0x0000000200000200, 0x0000000200000202,
+ 0x0000000200010101, 0x0000000200020000, 0x0000000200020002, 0x0000000200020200,
+ 0x0000000200020202, 0x0000000201000101, 0x0000000201010001, 0x0000000201010201,
+ 0x0000000201020100, 0x0000000201020201, 0x0000000202000000, 0x0000000202000002,
+ 0x0000000202000200, 0x0000000202000202, 0x0000000202010001, 0x0000000202010101,
+ 0x0000000202010201, 0x0000000202020000, 0x0000000202020002, 0x0000000202020200,
+ 0x0000000202020202, 0x0000010000010001, 0x0000010000010100, 0x0000010000010102,
+ 0x0000010000020101, 0x0000010001000001, 0x0000010001000201, 0x0000010001010101,
+ 0x0000010001010202, 0x0000010001020100, 0x0000010001020101, 0x0000010002010001,
+ 0x0000010002010201, 0x0000010002020101, 0x0000010100000001, 0x0000010100000100,
+ 0x0000010100000101, 0x0000010100000102, 0x0000010100010101, 0x0000010100010200,
+ 0x0000010100010202, 0x0000010100020201, 0x0000010101000000, 0x0000010101000101,
+ 0x0000010101000202, 0x0000010101010000, 0x0000010101010001, 0x0000010101010100,
+ 0x0000010101010101, 0x0000010101010102, 0x0000010101010201, 0x0000010101020000,
+ 0x0000010101020002, 0x0000010101020101, 0x0000010101020200, 0x0000010101020202,
+ 0x0000010102000001, 0x0000010102010001, 0x0000010102010101, 0x0000010102010200,
+ 0x0000010102010202, 0x0000010102020001, 0x0000010102020100, 0x0000010102020101,
+ 0x0000010102020102, 0x0000010102020201, 0x0000010200010100, 0x0000010200010201,
+ 0x0000010201000001, 0x0000010201000100, 0x0000010201010000, 0x0000010201010002,
+ 0x0000010201010101, 0x0000010201010200, 0x0000010201020000, 0x0000010201020001,
+ 0x0000010201020102, 0x0000010201020201, 0x0000010202000101, 0x0000010202010001,
+ 0x0000010202010100, 0x0000010202010201, 0x0000020000000000, 0x0000020000000002,
+ 0x0000020000000200, 0x0000020000000202, 0x0000020000010101, 0x0000020000020000,
+ 0x0000020000020002, 0x0000020000020200, 0x0000020000020202, 0x0000020001000101,
+ 0x0000020001010001, 0x0000020001010102, 0x0000020001020101, 0x0000020002000000,
+ 0x0000020002000002, 0x0000020002000200, 0x0000020002000202, 0x0000020002010101,
+ 0x0000020002020000, 0x0000020002020002, 0x0000020002020200, 0x0000020002020202,
+ 0x0000020100000101, 0x0000020100010001, 0x0000020100010100, 0x0000020100010201,
+ 0x0000020100020100, 0x0000020100020101, 0x0000020101000001, 0x0000020101010000,
+ 0x0000020101010001, 0x0000020101010101, 0x0000020101020001, 0x0000020101020100,
+ 0x0000020101020201, 0x0000020102010001, 0x0000020102010100, 0x0000020102010102,
+ 0x0000020102010201, 0x0000020102020101, 0x0000020200000000, 0x0000020200000002,
+ 0x0000020200000200, 0x0000020200000202, 0x0000020200010101, 0x0000020200020000,
+ 0x0000020200020002, 0x0000020200020200, 0x0000020200020202, 0x0000020201000101,
+ 0x0000020201010001, 0x0000020201010201, 0x0000020201020001, 0x0000020201020101,
+ 0x0000020202000000, 0x0000020202000002, 0x0000020202000101, 0x0000020202000200,
+ 0x0000020202000202, 0x0000020202010101, 0x0000020202020000, 0x0000020202020002,
+ 0x0000020202020200, 0x0000020202020202, 0x0001000000010000, 0x0001000000010001,
+ 0x0001000000010100, 0x0001000000010201, 0x0001000000020100, 0x0001000000020101,
+ 0x0001000001000001, 0x0001000001000100, 0x0001000001010000, 0x0001000001010101,
+ 0x0001000001010200, 0x0001000001020001, 0x0001000001020100, 0x0001000001020101,
+ 0x0001000001020201, 0x0001000002010001, 0x0001000002010100, 0x0001000002010102,
+ 0x0001000002020001, 0x0001000002020101, 0x0001000100000001, 0x0001000100000100,
+ 0x0001000100000102, 0x0001000100000201, 0x0001000100010000, 0x0001000100010002,
+ 0x0001000100010101, 0x0001000100010200, 0x0001000100020001, 0x0001000100020100,
+ 0x0001000100020201, 0x0001000101000101, 0x0001000101000202, 0x0001000101010000,
+ 0x0001000101010001, 0x0001000101010002, 0x0001000101010100, 0x0001000101010101,
+ 0x0001000101010102, 0x0001000101010201, 0x0001000101020000, 0x0001000101020101,
+ 0x0001000102000100, 0x0001000102010002, 0x0001000102010101, 0x0001000102020001,
+ 0x0001000102020100, 0x0001000200010001, 0x0001000200010100, 0x0001000200010102,
+ 0x0001000200020101, 0x0001000201000000, 0x0001000201000102, 0x0001000201000201,
+ 0x0001000201010002, 0x0001000201010101, 0x0001000201010200, 0x0001000201010202,
+ 0x0001000201020100, 0x0001000201020102, 0x0001000202000101, 0x0001000202010001,
+ 0x0001000202010100, 0x0001000202010102, 0x0001000202020101, 0x0001010000000001,
+ 0x0001010000000102, 0x0001010000000201, 0x0001010000010100, 0x0001010000010101,
+ 0x0001010000010200, 0x0001010000010201, 0x0001010000020001, 0x0001010000020102,
+ 0x0001010001000001, 0x0001010001000101, 0x0001010001000102, 0x0001010001000200,
+ 0x0001010001000202, 0x0001010001010001, 0x0001010001010100, 0x0001010001010101,
+ 0x0001010001010102, 0x0001010001010201, 0x0001010001020002, 0x0001010001020101,
+ 0x0001010001020200, 0x0001010002000100, 0x0001010002000201, 0x0001010002010000,
+ 0x0001010002010100, 0x0001010002010101, 0x0001010002010200, 0x0001010002010201,
+ 0x0001010002010202, 0x0001010002020001, 0x0001010002020100, 0x0001010002020101,
+ 0x0001010002020201, 0x0001010100000002, 0x0001010100000101, 0x0001010100000202,
+ 0x0001010100010001, 0x0001010100010100, 0x0001010100010101, 0x0001010100010102,
+ 0x0001010100010201, 0x0001010100020000, 0x0001010100020002, 0x0001010100020101,
+ 0x0001010100020200, 0x0001010100020202, 0x0001010101000001, 0x0001010101000100,
+ 0x0001010101000101, 0x0001010101000102, 0x0001010101010001, 0x0001010101010002,
+ 0x0001010101010100, 0x0001010101010101, 0x0001010101010102, 0x0001010101010201,
+ 0x0001010101010202, 0x0001010101020001, 0x0001010101020100, 0x0001010101020101,
+ 0x0001010101020102, 0x0001010101020201, 0x0001010102000000, 0x0001010102000002,
+ 0x0001010102000100, 0x0001010102000101, 0x0001010102000200, 0x0001010102000202,
+ 0x0001010102010000, 0x0001010102010001, 0x0001010102010100, 0x0001010102010101,
+ 0x0001010102010102, 0x0001010102010201, 0x0001010102010202, 0x0001010102020000,
+ 0x0001010102020002, 0x0001010102020101, 0x0001010200000001, 0x0001010200000100,
+ 0x0001010200000101, 0x0001010200000102, 0x0001010200010101, 0x0001010200010102,
+ 0x0001010200010200, 0x0001010200010202, 0x0001010200020001, 0x0001010200020102,
+ 0x0001010201000000, 0x0001010201000002, 0x0001010201000100, 0x0001010201000101,
+ 0x0001010201000200, 0x0001010201000202, 0x0001010201010001, 0x0001010201010101,
+ 0x0001010201010102, 0x0001010201010200, 0x0001010201010201, 0x0001010201020001,
+ 0x0001010201020100, 0x0001010201020101, 0x0001010201020200, 0x0001010201020201,
+ 0x0001010201020202, 0x0001010202000102, 0x0001010202000202, 0x0001010202010002,
+ 0x0001010202010101, 0x0001010202020100, 0x0001010202020201, 0x0001020000010001,
+ 0x0001020000010102, 0x0001020000020101, 0x0001020001000001, 0x0001020001000100,
+ 0x0001020001000102, 0x0001020001000201, 0x0001020001010000, 0x0001020001010101,
+ 0x0001020001010200, 0x0001020001010202, 0x0001020001020000, 0x0001020001020001,
+ 0x0001020001020100, 0x0001020001020102, 0x0001020001020201, 0x0001020002000101,
+ 0x0001020002010001, 0x0001020002010100, 0x0001020002020101, 0x0001020100010000,
+ 0x0001020100010002, 0x0001020100010101, 0x0001020100010202, 0x0001020100020001,
+ 0x0001020100020101, 0x0001020101000002, 0x0001020101000100, 0x0001020101000101,
+ 0x0001020101000200, 0x0001020101010001, 0x0001020101010100, 0x0001020101010101,
+ 0x0001020101010102, 0x0001020101010201, 0x0001020101010202, 0x0001020101020000,
+ 0x0001020101020101, 0x0001020101020202, 0x0001020102000201, 0x0001020102010001,
+ 0x0001020102010002, 0x0001020102010101, 0x0001020102010200, 0x0001020102020001,
+ 0x0001020102020102, 0x0001020102020201, 0x0001020200000201, 0x0001020200010102,
+ 0x0001020200020100, 0x0001020200020102, 0x0001020201000100, 0x0001020201000102,
+ 0x0001020201000201, 0x0001020201010000, 0x0001020201010002, 0x0001020201010101,
+ 0x0001020201010200, 0x0001020201020001, 0x0001020201020102, 0x0001020201020201,
+ 0x0001020202000101, 0x0001020202010001, 0x0001020202010102, 0x0001020202010202,
+ 0x0002000000000000, 0x0002000000000002, 0x0002000000000200, 0x0002000000000202,
+ 0x0002000000010101, 0x0002000000020000, 0x0002000000020002, 0x0002000000020101,
+ 0x0002000000020200, 0x0002000000020202, 0x0002000001000101, 0x0002000001010001,
+ 0x0002000001010201, 0x0002000001020001, 0x0002000001020101, 0x0002000002000000,
+ 0x0002000002000002, 0x0002000002000200, 0x0002000002000202, 0x0002000002010101,
+ 0x0002000002020000, 0x0002000002020002, 0x0002000002020101, 0x0002000002020200,
+ 0x0002000002020202, 0x0002000100000101, 0x0002000100010001, 0x0002000100010100,
+ 0x0002000100010201, 0x0002000100020101, 0x0002000101000002, 0x0002000101000100,
+ 0x0002000101000201, 0x0002000101010101, 0x0002000101010200, 0x0002000101010202,
+ 0x0002000101020001, 0x0002000101020100, 0x0002000101020101, 0x0002000101020102,
+ 0x0002000102000101, 0x0002000102010000, 0x0002000102010102, 0x0002000102010201,
+ 0x0002000102020101, 0x0002000200000001, 0x0002000200000200, 0x0002000200000202,
+ 0x0002000200010001, 0x0002000200010101, 0x0002000200020000, 0x0002000200020002,
+ 0x0002000200020200, 0x0002000200020202, 0x0002000201000101, 0x0002000201010001,
+ 0x0002000201010102, 0x0002000201010201, 0x0002000201020101, 0x0002000202000001,
+ 0x0002000202000200, 0x0002000202000202, 0x0002000202010001, 0x0002000202010101,
+ 0x0002000202020000, 0x0002000202020002, 0x0002000202020200, 0x0002000202020202,
+ 0x0002010000000101, 0x0002010000010100, 0x0002010000010102, 0x0002010000010201,
+ 0x0002010000020101, 0x0002010001000100, 0x0002010001000101, 0x0002010001000102,
+ 0x0002010001000201, 0x0002010001010002, 0x0002010001010101, 0x0002010001010200,
+ 0x0002010001010202, 0x0002010001020102, 0x0002010002000101, 0x0002010002010001,
+ 0x0002010002010100, 0x0002010002010201, 0x0002010002020001, 0x0002010002020101,
+ 0x0002010100000201, 0x0002010100010101, 0x0002010100020001, 0x0002010100020201,
+ 0x0002010101000000, 0x0002010101000101, 0x0002010101000200, 0x0002010101010001,
+ 0x0002010101010100, 0x0002010101010101, 0x0002010101010201, 0x0002010101020002,
+ 0x0002010101020101, 0x0002010101020200, 0x0002010102000201, 0x0002010102010000,
+ 0x0002010102010100, 0x0002010102010101, 0x0002010102010200, 0x0002010102010202,
+ 0x0002010102020001, 0x0002010102020100, 0x0002010102020102, 0x0002010102020201,
+ 0x0002010200000101, 0x0002010200010000, 0x0002010200010002, 0x0002010200010201,
+ 0x0002010200020101, 0x0002010201000001, 0x0002010201000201, 0x0002010201010101,
+ 0x0002010201020000, 0x0002010201020001, 0x0002010201020201, 0x0002010202000100,
+ 0x0002010202000102, 0x0002010202010000, 0x0002010202010202, 0x0002020000000000,
+ 0x0002020000000002, 0x0002020000000200, 0x0002020000000202, 0x0002020000010101,
+ 0x0002020000020000, 0x0002020000020002, 0x0002020000020200, 0x0002020000020202,
+ 0x0002020001000101, 0x0002020001010001, 0x0002020001010100, 0x0002020001020101,
+ 0x0002020002000000, 0x0002020002000002, 0x0002020002000200, 0x0002020002000202,
+ 0x0002020002020000, 0x0002020002020002, 0x0002020002020200, 0x0002020002020202,
+ 0x0002020100000201, 0x0002020100010001, 0x0002020100010100, 0x0002020100010201,
+ 0x0002020100020101, 0x0002020101000102, 0x0002020101000201, 0x0002020101010002,
+ 0x0002020101010101, 0x0002020101020001, 0x0002020101020100, 0x0002020101020102,
+ 0x0002020101020201, 0x0002020102000101, 0x0002020102010000, 0x0002020102010102,
+ 0x0002020102010201, 0x0002020102020100, 0x0002020102020101, 0x0002020200000000,
+ 0x0002020200000002, 0x0002020200000200, 0x0002020200000202, 0x0002020200020000,
+ 0x0002020200020002, 0x0002020200020200, 0x0002020200020202, 0x0002020201000101,
+ 0x0002020201010001, 0x0002020201010102, 0x0002020201010201, 0x0002020201020101,
+ 0x0002020202000000, 0x0002020202000002, 0x0002020202000200, 0x0002020202000202,
+ 0x0002020202010101, 0x0002020202020000, 0x0002020202020002, 0x0002020202020200,
+ 0x0002020202020202, 0x0100000000000101, 0x0100000000010001, 0x0100000000010102,
+ 0x0100000000020101, 0x0100000001000201, 0x0100000001010002, 0x0100000001010101,
+ 0x0100000001010200, 0x0100000001010202, 0x0100000001020001, 0x0100000001020100,
+ 0x0100000001020102, 0x0100000002010100, 0x0100000002010201, 0x0100000002020001,
+ 0x0100000002020102, 0x0100000100000000, 0x0100000100000001, 0x0100000100000100,
+ 0x0100000100000102, 0x0100000100000201, 0x0100000100010002, 0x0100000100010101,
+ 0x0100000100010102, 0x0100000100010200, 0x0100000100010202, 0x0100000100020001,
+ 0x0100000100020102, 0x0100000100020201, 0x0100000101000101, 0x0100000101000200,
+ 0x0100000101000202, 0x0100000101010001, 0x0100000101010100, 0x0100000101010101,
+ 0x0100000101010102, 0x0100000101010201, 0x0100000101010202, 0x0100000101020101,
+ 0x0100000101020200, 0x0100000101020202, 0x0100000102000001, 0x0100000102000100,
+ 0x0100000102000102, 0x0100000102010000, 0x0100000102010002, 0x0100000102010101,
+ 0x0100000102020000, 0x0100000102020001, 0x0100000102020002, 0x0100000200000101,
+ 0x0100000200010001, 0x0100000200010100, 0x0100000200010102, 0x0100000200020101,
+ 0x0100000201000001, 0x0100000201010002, 0x0100000201010101, 0x0100000201010202,
+ 0x0100000201020100, 0x0100000201020201, 0x0100000202000201, 0x0100000202010100,
+ 0x0100000202020101, 0x0100010000000001, 0x0100010000010101, 0x0100010000010201,
+ 0x0100010000020201, 0x0100010001000101, 0x0100010001000200, 0x0100010001000202,
+ 0x0100010001010001, 0x0100010001010100, 0x0100010001010101, 0x0100010001010102,
+ 0x0100010001020001, 0x0100010001020002, 0x0100010001020101, 0x0100010001020200,
+ 0x0100010001020202, 0x0100010002000001, 0x0100010002000102, 0x0100010002000201,
+ 0x0100010002010000, 0x0100010002010002, 0x0100010002010101, 0x0100010002020000,
+ 0x0100010002020001, 0x0100010002020201, 0x0100010100000001, 0x0100010100000002,
+ 0x0100010100000101, 0x0100010100000202, 0x0100010100010001, 0x0100010100010100,
+ 0x0100010100010101, 0x0100010100010102, 0x0100010100010201, 0x0100010100020000,
+ 0x0100010100020101, 0x0100010100020202, 0x0100010101000001, 0x0100010101000100,
+ 0x0100010101000101, 0x0100010101000102, 0x0100010101000201, 0x0100010101010000,
+ 0x0100010101010001, 0x0100010101010100, 0x0100010101010101, 0x0100010101010102,
+ 0x0100010101010200, 0x0100010101010201, 0x0100010101020001, 0x0100010101020100,
+ 0x0100010101020101, 0x0100010101020102, 0x0100010101020201, 0x0100010102000002,
+ 0x0100010102000100, 0x0100010102000101, 0x0100010102000200, 0x0100010102010001,
+ 0x0100010102010100, 0x0100010102010101, 0x0100010102010102, 0x0100010102010201,
+ 0x0100010102010202, 0x0100010102020101, 0x0100010102020200, 0x0100010102020202,
+ 0x0100010200000001, 0x0100010200000101, 0x0100010200000201, 0x0100010200010100,
+ 0x0100010200010101, 0x0100010200010200, 0x0100010200010202, 0x0100010200020001,
+ 0x0100010200020100, 0x0100010200020201, 0x0100010201000000, 0x0100010201000002,
+ 0x0100010201000101, 0x0100010201000200, 0x0100010201010000, 0x0100010201010001,
+ 0x0100010201010002, 0x0100010201010101, 0x0100010201010102, 0x0100010201010201,
+ 0x0100010201020002, 0x0100010201020101, 0x0100010201020200, 0x0100010202000001,
+ 0x0100010202000101, 0x0100010202000202, 0x0100010202010100, 0x0100010202010101,
+ 0x0100010202020001, 0x0100010202020100, 0x0100010202020102, 0x0100020000000101,
+ 0x0100020000010001, 0x0100020000010101, 0x0100020000010202, 0x0100020000020101,
+ 0x0100020001000002, 0x0100020001000201, 0x0100020001010000, 0x0100020001010101,
+ 0x0100020001010200, 0x0100020001020001, 0x0100020001020100, 0x0100020001020102,
+ 0x0100020001020201, 0x0100020002000101, 0x0100020002010001, 0x0100020002010100,
+ 0x0100020002010102, 0x0100020002010201, 0x0100020002020101, 0x0100020100000001,
+ 0x0100020100000101, 0x0100020100000102, 0x0100020100000202, 0x0100020100010000,
+ 0x0100020100010100, 0x0100020100010101, 0x0100020100010200, 0x0100020100020001,
+ 0x0100020100020100, 0x0100020100020102, 0x0100020101000000, 0x0100020101000101,
+ 0x0100020101000202, 0x0100020101010001, 0x0100020101010002, 0x0100020101010100,
+ 0x0100020101010101, 0x0100020101010102, 0x0100020101010201, 0x0100020101020000,
+ 0x0100020101020002, 0x0100020101020101, 0x0100020101020102, 0x0100020101020202,
+ 0x0100020102000102, 0x0100020102000201, 0x0100020102010002, 0x0100020102010101,
+ 0x0100020102010102, 0x0100020102010200, 0x0100020102020001, 0x0100020102020100,
+ 0x0100020102020102, 0x0100020102020201, 0x0100020200010102, 0x0100020201000100,
+ 0x0100020201000102, 0x0100020201000201, 0x0100020201010101, 0x0100020201010200,
+ 0x0100020201010202, 0x0100020201020100, 0x0100020201020201, 0x0100020202010100,
+ 0x0100020202020101, 0x0101000000000001, 0x0101000000000100, 0x0101000000000101,
+ 0x0101000000000102, 0x0101000000000201, 0x0101000000010002, 0x0101000000010101,
+ 0x0101000000010202, 0x0101000000020001, 0x0101000000020100, 0x0101000000020201,
+ 0x0101000001000000, 0x0101000001000101, 0x0101000001000200, 0x0101000001010001,
+ 0x0101000001010100, 0x0101000001010101, 0x0101000001010102, 0x0101000001010201,
+ 0x0101000001020101, 0x0101000001020200, 0x0101000002000102, 0x0101000002000201,
+ 0x0101000002010101, 0x0101000002010200, 0x0101000002020000, 0x0101000002020001,
+ 0x0101000002020102, 0x0101000002020201, 0x0101000100000101, 0x0101000100000200,
+ 0x0101000100000201, 0x0101000100000202, 0x0101000100010001, 0x0101000100010100,
+ 0x0101000100010101, 0x0101000100010102, 0x0101000100010200, 0x0101000100010201,
+ 0x0101000100020000, 0x0101000100020101, 0x0101000100020102, 0x0101000100020200,
+ 0x0101000100020202, 0x0101000101000001, 0x0101000101000100, 0x0101000101000101,
+ 0x0101000101000102, 0x0101000101000201, 0x0101000101010000, 0x0101000101010001,
+ 0x0101000101010002, 0x0101000101010100, 0x0101000101010101, 0x0101000101010102,
+ 0x0101000101010200, 0x0101000101010201, 0x0101000101010202, 0x0101000101020001,
+ 0x0101000101020100, 0x0101000101020101, 0x0101000101020102, 0x0101000101020201,
+ 0x0101000102000002, 0x0101000102000101, 0x0101000102010001, 0x0101000102010100,
+ 0x0101000102010101, 0x0101000102010102, 0x0101000102010201, 0x0101000102020000,
+ 0x0101000102020101, 0x0101000102020202, 0x0101000200000001, 0x0101000200000102,
+ 0x0101000200010002, 0x0101000200010101, 0x0101000200010202, 0x0101000200020001,
+ 0x0101000200020100, 0x0101000201000002, 0x0101000201000101, 0x0101000201000202,
+ 0x0101000201010001, 0x0101000201010100, 0x0101000201010101, 0x0101000201010102,
+ 0x0101000201010201, 0x0101000201020002, 0x0101000201020101, 0x0101000202000101,
+ 0x0101000202010000, 0x0101000202010002, 0x0101000202010101, 0x0101000202010201,
+ 0x0101000202010202, 0x0101000202020100, 0x0101010000000100, 0x0101010000000101,
+ 0x0101010000010001, 0x0101010000010100, 0x0101010000010101, 0x0101010000010102,
+ 0x0101010000010200, 0x0101010000010201, 0x0101010000020001, 0x0101010000020101,
+ 0x0101010000020200, 0x0101010000020202, 0x0101010001000001, 0x0101010001000100,
+ 0x0101010001000101, 0x0101010001000102, 0x0101010001000201, 0x0101010001000202,
+ 0x0101010001010000, 0x0101010001010001, 0x0101010001010100, 0x0101010001010101,
+ 0x0101010001010102, 0x0101010001010200, 0x0101010001010201, 0x0101010001010202,
+ 0x0101010001020001, 0x0101010001020002, 0x0101010001020100, 0x0101010001020101,
+ 0x0101010001020102, 0x0101010001020201, 0x0101010002000000, 0x0101010002000200,
+ 0x0101010002000202, 0x0101010002010001, 0x0101010002010100, 0x0101010002010101,
+ 0x0101010002010102, 0x0101010002010201, 0x0101010002020001, 0x0101010002020100,
+ 0x0101010002020101, 0x0101010002020202, 0x0101010100000001, 0x0101010100000002,
+ 0x0101010100000100, 0x0101010100000101, 0x0101010100000102, 0x0101010100000201,
+ 0x0101010100010000, 0x0101010100010001, 0x0101010100010002, 0x0101010100010100,
+ 0x0101010100010101, 0x0101010100010102, 0x0101010100010201, 0x0101010100010202,
+ 0x0101010100020001, 0x0101010100020100, 0x0101010100020101, 0x0101010100020102,
+ 0x0101010100020201, 0x0101010101000000, 0x0101010101000001, 0x0101010101000002,
+ 0x0101010101000100, 0x0101010101000101, 0x0101010101000102, 0x0101010101000200,
+ 0x0101010101000201, 0x0101010101010000, 0x0101010101010001, 0x0101010101010002,
+ 0x0101010101010100, 0x0101010101010101, 0x0101010101010102, 0x0101010101010200,
+ 0x0101010101010201, 0x0101010101010202, 0x0101010101020000, 0x0101010101020001,
+ 0x0101010101020100, 0x0101010101020101, 0x0101010101020102, 0x0101010101020200,
+ 0x0101010101020201, 0x0101010101020202, 0x0101010102000001, 0x0101010102000100,
+ 0x0101010102000101, 0x0101010102000201, 0x0101010102000202, 0x0101010102010000,
+ 0x0101010102010001, 0x0101010102010100, 0x0101010102010101, 0x0101010102010102,
+ 0x0101010102010200, 0x0101010102010201, 0x0101010102020001, 0x0101010102020100,
+ 0x0101010102020101, 0x0101010102020102, 0x0101010102020201, 0x0101010200000000,
+ 0x0101010200000001, 0x0101010200000002, 0x0101010200000100, 0x0101010200000102,
+ 0x0101010200000200, 0x0101010200000201, 0x0101010200010001, 0x0101010200010100,
+ 0x0101010200010101, 0x0101010200010200, 0x0101010200010201, 0x0101010200020000,
+ 0x0101010200020001, 0x0101010200020002, 0x0101010200020100, 0x0101010200020101,
+ 0x0101010200020102, 0x0101010200020200, 0x0101010200020201, 0x0101010201000001,
+ 0x0101010201000101, 0x0101010201000102, 0x0101010201000200, 0x0101010201000201,
+ 0x0101010201000202, 0x0101010201010000, 0x0101010201010001, 0x0101010201010002,
+ 0x0101010201010100, 0x0101010201010101, 0x0101010201010102, 0x0101010201010200,
+ 0x0101010201010201, 0x0101010201010202, 0x0101010201020001, 0x0101010201020100,
+ 0x0101010201020101, 0x0101010201020201, 0x0101010202000002, 0x0101010202000101,
+ 0x0101010202000102, 0x0101010202000200, 0x0101010202000201, 0x0101010202000202,
+ 0x0101010202010001, 0x0101010202010101, 0x0101010202010202, 0x0101010202020002,
+ 0x0101010202020101, 0x0101010202020102, 0x0101010202020200, 0x0101010202020201,
+ 0x0101020000000100, 0x0101020000000101, 0x0101020000000102, 0x0101020000000201,
+ 0x0101020000010000, 0x0101020000010101, 0x0101020000010200, 0x0101020000020001,
+ 0x0101020000020202, 0x0101020001000101, 0x0101020001000200, 0x0101020001000202,
+ 0x0101020001010001, 0x0101020001010100, 0x0101020001010101, 0x0101020001010102,
+ 0x0101020001010200, 0x0101020001010201, 0x0101020001020000, 0x0101020001020002,
+ 0x0101020001020100, 0x0101020001020101, 0x0101020002000002, 0x0101020002000201,
+ 0x0101020002010000, 0x0101020002010002, 0x0101020002010101, 0x0101020002010200,
+ 0x0101020002020001, 0x0101020002020201, 0x0101020100000001, 0x0101020100000002,
+ 0x0101020100000101, 0x0101020100000202, 0x0101020100010001, 0x0101020100010100,
+ 0x0101020100010101, 0x0101020100010102, 0x0101020100010201, 0x0101020100020101,
+ 0x0101020101000001, 0x0101020101000100, 0x0101020101000101, 0x0101020101000102,
+ 0x0101020101000201, 0x0101020101010000, 0x0101020101010001, 0x0101020101010002,
+ 0x0101020101010100, 0x0101020101010101, 0x0101020101010102, 0x0101020101010200,
+ 0x0101020101010201, 0x0101020101010202, 0x0101020101020001, 0x0101020101020100,
+ 0x0101020101020101, 0x0101020101020102, 0x0101020101020201, 0x0101020102000001,
+ 0x0101020102000101, 0x0101020102000201, 0x0101020102010001, 0x0101020102010100,
+ 0x0101020102010101, 0x0101020102010102, 0x0101020102010200, 0x0101020102010201,
+ 0x0101020102020101, 0x0101020200000100, 0x0101020200000200, 0x0101020200010101,
+ 0x0101020200010202, 0x0101020200020000, 0x0101020200020101, 0x0101020200020102,
+ 0x0101020200020201, 0x0101020201000101, 0x0101020201000200, 0x0101020201000201,
+ 0x0101020201010001, 0x0101020201010101, 0x0101020201010102, 0x0101020201010200,
+ 0x0101020201010201, 0x0101020201020002, 0x0101020201020101, 0x0101020201020200,
+ 0x0101020201020202, 0x0101020202000001, 0x0101020202000202, 0x0101020202010002,
+ 0x0101020202010101, 0x0101020202010102, 0x0101020202010200, 0x0101020202010202,
+ 0x0101020202020001, 0x0102000000000101, 0x0102000000010100, 0x0102000000010102,
+ 0x0102000000010201, 0x0102000000020101, 0x0102000001000100, 0x0102000001010000,
+ 0x0102000001010101, 0x0102000001010102, 0x0102000001010200, 0x0102000001010202,
+ 0x0102000001020001, 0x0102000001020100, 0x0102000001020102, 0x0102000001020201,
+ 0x0102000002000001, 0x0102000002010102, 0x0102000002020101, 0x0102000100000001,
+ 0x0102000100000100, 0x0102000100000102, 0x0102000100000201, 0x0102000100010002,
+ 0x0102000100010101, 0x0102000100020001, 0x0102000100020002, 0x0102000100020102,
+ 0x0102000100020201, 0x0102000101000101, 0x0102000101000201, 0x0102000101010001,
+ 0x0102000101010101, 0x0102000101010102, 0x0102000101010201, 0x0102000101020101,
+ 0x0102000101020102, 0x0102000101020202, 0x0102000102000100, 0x0102000102000202,
+ 0x0102000102010002, 0x0102000102010101, 0x0102000102020001, 0x0102000102020102,
+ 0x0102000102020201, 0x0102000200010001, 0x0102000200010102, 0x0102000200010201,
+ 0x0102000201000000, 0x0102000201000001, 0x0102000201000102, 0x0102000201010101,
+ 0x0102000201010102, 0x0102000201010200, 0x0102000201020000, 0x0102000202000101,
+ 0x0102000202010001, 0x0102000202010102, 0x0102000202020101, 0x0102010000010001,
+ 0x0102010000010002, 0x0102010000010101, 0x0102010000010102, 0x0102010000010202,
+ 0x0102010000020001, 0x0102010000020102, 0x0102010000020201, 0x0102010001000000,
+ 0x0102010001000002, 0x0102010001000101, 0x0102010001000200, 0x0102010001000202,
+ 0x0102010001010001, 0x0102010001010100, 0x0102010001010101, 0x0102010001010102,
+ 0x0102010001010201, 0x0102010001010202, 0x0102010001020000, 0x0102010001020002,
+ 0x0102010001020101, 0x0102010002000100, 0x0102010002000101, 0x0102010002000201,
+ 0x0102010002010000, 0x0102010002010002, 0x0102010002010100, 0x0102010002010101,
+ 0x0102010002010102, 0x0102010002010200, 0x0102010002010202, 0x0102010002020001,
+ 0x0102010002020100, 0x0102010002020201, 0x0102010100000101, 0x0102010100000200,
+ 0x0102010100000202, 0x0102010100010001, 0x0102010100010101, 0x0102010100010102,
+ 0x0102010100010201, 0x0102010101000100, 0x0102010101000101, 0x0102010101000102,
+ 0x0102010101000201, 0x0102010101010000, 0x0102010101010001, 0x0102010101010100,
+ 0x0102010101010101, 0x0102010101010102, 0x0102010101010201, 0x0102010101020001,
+ 0x0102010101020100, 0x0102010101020101, 0x0102010101020102, 0x0102010101020201,
+ 0x0102010102000102, 0x0102010102000201, 0x0102010102000202, 0x0102010102010001,
+ 0x0102010102010101, 0x0102010102010102, 0x0102010102010201, 0x0102010102010202,
+ 0x0102010102020002, 0x0102010102020101, 0x0102010102020102, 0x0102010102020200,
+ 0x0102010200000002, 0x0102010200000201, 0x0102010200010101, 0x0102010200020000,
+ 0x0102010200020102, 0x0102010200020200, 0x0102010200020201, 0x0102010201000000,
+ 0x0102010201000101, 0x0102010201000200, 0x0102010201000202, 0x0102010201010001,
+ 0x0102010201010100, 0x0102010201010101, 0x0102010201010102, 0x0102010201010200,
+ 0x0102010201010202, 0x0102010201020000, 0x0102010201020101, 0x0102010201020200,
+ 0x0102010202000000, 0x0102010202000002, 0x0102010202000101, 0x0102010202000202,
+ 0x0102010202010100, 0x0102010202010102, 0x0102010202010200, 0x0102010202010201,
+ 0x0102010202020000, 0x0102010202020100, 0x0102010202020102, 0x0102010202020202,
+ 0x0102020000010102, 0x0102020000010201, 0x0102020000020101, 0x0102020001000001,
+ 0x0102020001010002, 0x0102020001010101, 0x0102020001010202, 0x0102020001020001,
+ 0x0102020001020201, 0x0102020002000101, 0x0102020002010001, 0x0102020002010200,
+ 0x0102020002020102, 0x0102020100000001, 0x0102020100000100, 0x0102020100010000,
+ 0x0102020100010101, 0x0102020100020001, 0x0102020100020100, 0x0102020100020102,
+ 0x0102020100020201, 0x0102020101000000, 0x0102020101000001, 0x0102020101000101,
+ 0x0102020101000102, 0x0102020101000200, 0x0102020101010001, 0x0102020101010100,
+ 0x0102020101010101, 0x0102020101010102, 0x0102020101010201, 0x0102020101020000,
+ 0x0102020101020101, 0x0102020101020202, 0x0102020102000002, 0x0102020102000100,
+ 0x0102020102000202, 0x0102020102010101, 0x0102020102020001, 0x0102020102020100,
+ 0x0102020102020101, 0x0102020102020201, 0x0102020200010001, 0x0102020200010102,
+ 0x0102020200010200, 0x0102020201000001, 0x0102020201000100, 0x0102020201000201,
+ 0x0102020201010000, 0x0102020201010101, 0x0102020201010200, 0x0102020201010202,
+ 0x0102020201020100, 0x0102020201020101, 0x0102020201020201, 0x0102020202000102,
+ 0x0102020202010100, 0x0102020202010200, 0x0102020202010202, 0x0102020202020102,
+ 0x0200000000000000, 0x0200000000000002, 0x0200000000000200, 0x0200000000000202,
+ 0x0200000000020000, 0x0200000000020002, 0x0200000000020200, 0x0200000000020202,
+ 0x0200000001000101, 0x0200000001010000, 0x0200000001010001, 0x0200000001010100,
+ 0x0200000001010102, 0x0200000001010201, 0x0200000001020101, 0x0200000002000000,
+ 0x0200000002000002, 0x0200000002000200, 0x0200000002000202, 0x0200000002010101,
+ 0x0200000002020000, 0x0200000002020002, 0x0200000002020200, 0x0200000002020202,
+ 0x0200000100000101, 0x0200000100010001, 0x0200000100010100, 0x0200000100010102,
+ 0x0200000100010201, 0x0200000100020101, 0x0200000101000001, 0x0200000101000100,
+ 0x0200000101000201, 0x0200000101010000, 0x0200000101010002, 0x0200000101010101,
+ 0x0200000101010102, 0x0200000101010200, 0x0200000101010201, 0x0200000101020100,
+ 0x0200000101020102, 0x0200000101020201, 0x0200000102000101, 0x0200000102000201,
+ 0x0200000102010100, 0x0200000102010102, 0x0200000102010201, 0x0200000102020101,
+ 0x0200000200000000, 0x0200000200000002, 0x0200000200000200, 0x0200000200000202,
+ 0x0200000200010101, 0x0200000200020000, 0x0200000200020002, 0x0200000200020200,
+ 0x0200000200020202, 0x0200000201010001, 0x0200000201010100, 0x0200000201010201,
+ 0x0200000201020101, 0x0200000202000000, 0x0200000202000002, 0x0200000202000200,
+ 0x0200000202000202, 0x0200000202010101, 0x0200000202020000, 0x0200000202020002,
+ 0x0200000202020200, 0x0200000202020202, 0x0200010000010100, 0x0200010000010201,
+ 0x0200010001000001, 0x0200010001000100, 0x0200010001010001, 0x0200010001010101,
+ 0x0200010001010202, 0x0200010001020001, 0x0200010001020100, 0x0200010001020201,
+ 0x0200010002010100, 0x0200010002010201, 0x0200010100000001, 0x0200010100000201,
+ 0x0200010100010002, 0x0200010100010101, 0x0200010100010202, 0x0200010100020102,
+ 0x0200010100020201, 0x0200010101000000, 0x0200010101000001, 0x0200010101000101,
+ 0x0200010101000200, 0x0200010101010001, 0x0200010101010100, 0x0200010101010101,
+ 0x0200010101010102, 0x0200010101010201, 0x0200010101010202, 0x0200010101020101,
+ 0x0200010101020102, 0x0200010101020200, 0x0200010101020202, 0x0200010102000001,
+ 0x0200010102000100, 0x0200010102000102, 0x0200010102000201, 0x0200010102010000,
+ 0x0200010102010002, 0x0200010102010101, 0x0200010102010200, 0x0200010102020102,
+ 0x0200010200010001, 0x0200010200010102, 0x0200010200010201, 0x0200010200020101,
+ 0x0200010201000001, 0x0200010201000100, 0x0200010201000201, 0x0200010201000202,
+ 0x0200010201010000, 0x0200010201010101, 0x0200010201010201, 0x0200010201010202,
+ 0x0200010201020001, 0x0200010201020102, 0x0200010201020202, 0x0200010202000101,
+ 0x0200010202010001, 0x0200010202010202, 0x0200010202020100, 0x0200020000000000,
+ 0x0200020000000002, 0x0200020000000200, 0x0200020000000202, 0x0200020000010101,
+ 0x0200020000020000, 0x0200020000020002, 0x0200020000020200, 0x0200020000020202,
+ 0x0200020001000001, 0x0200020001000101, 0x0200020001010001, 0x0200020001010100,
+ 0x0200020001010201, 0x0200020001020101, 0x0200020001020201, 0x0200020002000000,
+ 0x0200020002000002, 0x0200020002000200, 0x0200020002000202, 0x0200020002010101,
+ 0x0200020002020000, 0x0200020002020002, 0x0200020002020200, 0x0200020002020202,
+ 0x0200020100000101, 0x0200020100000102, 0x0200020100010001, 0x0200020100010100,
+ 0x0200020100010102, 0x0200020100020101, 0x0200020101000001, 0x0200020101000100,
+ 0x0200020101000102, 0x0200020101000201, 0x0200020101010000, 0x0200020101010002,
+ 0x0200020101010101, 0x0200020101010202, 0x0200020101020001, 0x0200020101020100,
+ 0x0200020102000101, 0x0200020102010102, 0x0200020102010201, 0x0200020102020101,
+ 0x0200020200000000, 0x0200020200000002, 0x0200020200000200, 0x0200020200000202,
+ 0x0200020200010101, 0x0200020200020000, 0x0200020200020002, 0x0200020200020200,
+ 0x0200020200020202, 0x0200020201000101, 0x0200020201010001, 0x0200020201010100,
+ 0x0200020201010102, 0x0200020202000000, 0x0200020202000002, 0x0200020202000200,
+ 0x0200020202000202, 0x0200020202010101, 0x0200020202020000, 0x0200020202020002,
+ 0x0200020202020200, 0x0200020202020202, 0x0201000000000101, 0x0201000000010001,
+ 0x0201000000010102, 0x0201000000010200, 0x0201000000010201, 0x0201000000020101,
+ 0x0201000001000001, 0x0201000001000102, 0x0201000001000201, 0x0201000001010101,
+ 0x0201000001010200, 0x0201000001010202, 0x0201000001020201, 0x0201000001020202,
+ 0x0201000002000101, 0x0201000002010001, 0x0201000002010100, 0x0201000002010102,
+ 0x0201000002010201, 0x0201000002020101, 0x0201000100000001, 0x0201000100000100,
+ 0x0201000100000102, 0x0201000100000201, 0x0201000100010000, 0x0201000100010101,
+ 0x0201000100010200, 0x0201000100010202, 0x0201000100020001, 0x0201000100020100,
+ 0x0201000100020102, 0x0201000100020201, 0x0201000101000000, 0x0201000101000101,
+ 0x0201000101010000, 0x0201000101010001, 0x0201000101010100, 0x0201000101010101,
+ 0x0201000101010102, 0x0201000101010201, 0x0201000101020002, 0x0201000101020101,
+ 0x0201000102000100, 0x0201000102000102, 0x0201000102010002, 0x0201000102010101,
+ 0x0201000102010200, 0x0201000102020001, 0x0201000102020100, 0x0201000102020102,
+ 0x0201000102020201, 0x0201000200000101, 0x0201000200010001, 0x0201000200010100,
+ 0x0201000200010201, 0x0201000200020101, 0x0201000201000100, 0x0201000201000102,
+ 0x0201000201000201, 0x0201000201010000, 0x0201000201010002, 0x0201000201010101,
+ 0x0201000201010200, 0x0201000201020102, 0x0201000201020201, 0x0201000202000101,
+ 0x0201000202010100, 0x0201000202010102, 0x0201000202020201, 0x0201010000000001,
+ 0x0201010000000100, 0x0201010000000102, 0x0201010000010000, 0x0201010000010101,
+ 0x0201010000010200, 0x0201010000020102, 0x0201010001000000, 0x0201010001000202,
+ 0x0201010001010001, 0x0201010001010100, 0x0201010001010101, 0x0201010001010102,
+ 0x0201010001010200, 0x0201010001010201, 0x0201010001020000, 0x0201010001020001,
+ 0x0201010001020002, 0x0201010001020101, 0x0201010002000100, 0x0201010002000102,
+ 0x0201010002010002, 0x0201010002010100, 0x0201010002010101, 0x0201010002010200,
+ 0x0201010002020001, 0x0201010002020201, 0x0201010100000000, 0x0201010100000101,
+ 0x0201010100000200, 0x0201010100000202, 0x0201010100010000, 0x0201010100010001,
+ 0x0201010100010100, 0x0201010100010101, 0x0201010100010102, 0x0201010100010201,
+ 0x0201010100020001, 0x0201010100020101, 0x0201010100020201, 0x0201010100020202,
+ 0x0201010101000001, 0x0201010101000100, 0x0201010101000101, 0x0201010101000102,
+ 0x0201010101000201, 0x0201010101010000, 0x0201010101010001, 0x0201010101010002,
+ 0x0201010101010100, 0x0201010101010101, 0x0201010101010102, 0x0201010101010200,
+ 0x0201010101010201, 0x0201010101010202, 0x0201010101020001, 0x0201010101020100,
+ 0x0201010101020101, 0x0201010101020102, 0x0201010101020201, 0x0201010102000001,
+ 0x0201010102000101, 0x0201010102000200, 0x0201010102010001, 0x0201010102010002,
+ 0x0201010102010100, 0x0201010102010101, 0x0201010102010102, 0x0201010102010201,
+ 0x0201010102010202, 0x0201010102020000, 0x0201010102020002, 0x0201010102020101,
+ 0x0201010102020200, 0x0201010102020202, 0x0201010200000001, 0x0201010200000100,
+ 0x0201010200010000, 0x0201010200010101, 0x0201010200010201, 0x0201010200020000,
+ 0x0201010200020102, 0x0201010200020201, 0x0201010201000101, 0x0201010201000200,
+ 0x0201010201000201, 0x0201010201010001, 0x0201010201010002, 0x0201010201010101,
+ 0x0201010201010102, 0x0201010201010201, 0x0201010201020101, 0x0201010201020200,
+ 0x0201010202000002, 0x0201010202000100, 0x0201010202000201, 0x0201010202000202,
+ 0x0201010202010002, 0x0201010202010100, 0x0201010202010101, 0x0201010202020100,
+ 0x0201010202020102, 0x0201010202020201, 0x0201020000000101, 0x0201020000010102,
+ 0x0201020000010201, 0x0201020000020101, 0x0201020001000001, 0x0201020001000102,
+ 0x0201020001010000, 0x0201020001010002, 0x0201020001010101, 0x0201020001010102,
+ 0x0201020001010202, 0x0201020001020100, 0x0201020001020101, 0x0201020002000101,
+ 0x0201020002010001, 0x0201020002010102, 0x0201020002010201, 0x0201020002020101,
+ 0x0201020100000100, 0x0201020100000102, 0x0201020100000201, 0x0201020100010000,
+ 0x0201020100010002, 0x0201020100010101, 0x0201020100010200, 0x0201020100010202,
+ 0x0201020100020000, 0x0201020100020001, 0x0201020100020100, 0x0201020100020102,
+ 0x0201020101000000, 0x0201020101000002, 0x0201020101000101, 0x0201020101000200,
+ 0x0201020101000202, 0x0201020101010001, 0x0201020101010100, 0x0201020101010101,
+ 0x0201020101010102, 0x0201020101010201, 0x0201020101020002, 0x0201020101020101,
+ 0x0201020101020102, 0x0201020101020202, 0x0201020102000001, 0x0201020102000100,
+ 0x0201020102010000, 0x0201020102010002, 0x0201020102010101, 0x0201020102010202,
+ 0x0201020102020001, 0x0201020102020102, 0x0201020200000101, 0x0201020200010101,
+ 0x0201020200020101, 0x0201020201000100, 0x0201020201000102, 0x0201020201000201,
+ 0x0201020201010000, 0x0201020201010101, 0x0201020201010200, 0x0201020201020001,
+ 0x0201020202000101, 0x0201020202010001, 0x0201020202010100, 0x0201020202010101,
+ 0x0201020202010102, 0x0202000000000000, 0x0202000000000002, 0x0202000000000200,
+ 0x0202000000000202, 0x0202000000010101, 0x0202000000020000, 0x0202000000020002,
+ 0x0202000000020200, 0x0202000000020202, 0x0202000001000101, 0x0202000001010001,
+ 0x0202000001010100, 0x0202000001010102, 0x0202000001010201, 0x0202000002000000,
+ 0x0202000002000002, 0x0202000002000200, 0x0202000002000202, 0x0202000002010101,
+ 0x0202000002020000, 0x0202000002020002, 0x0202000002020200, 0x0202000002020202,
+ 0x0202000100000101, 0x0202000100000201, 0x0202000100010001, 0x0202000100010100,
+ 0x0202000100010102, 0x0202000100010201, 0x0202000100010202, 0x0202000101000102,
+ 0x0202000101000201, 0x0202000101010001, 0x0202000101010101, 0x0202000101010200,
+ 0x0202000101010202, 0x0202000101020001, 0x0202000101020100, 0x0202000102000101,
+ 0x0202000102010000, 0x0202000102010002, 0x0202000102010102, 0x0202000102010201,
+ 0x0202000200000002, 0x0202000200000200, 0x0202000200000202, 0x0202000200010000,
+ 0x0202000200010201, 0x0202000200020002, 0x0202000200020200, 0x0202000200020202,
+ 0x0202000201000101, 0x0202000201010001, 0x0202000201010102, 0x0202000201010201,
+ 0x0202000201020101, 0x0202000202000000, 0x0202000202000002, 0x0202000202000200,
+ 0x0202000202000202, 0x0202000202010101, 0x0202000202020000, 0x0202000202020002,
+ 0x0202000202020200, 0x0202000202020202, 0x0202010000010201, 0x0202010000020101,
+ 0x0202010001000001, 0x0202010001000100, 0x0202010001010000, 0x0202010001010100,
+ 0x0202010001010101, 0x0202010001010200, 0x0202010001010202, 0x0202010001020001,
+ 0x0202010001020101, 0x0202010001020102, 0x0202010001020200, 0x0202010001020201,
+ 0x0202010002000101, 0x0202010100000102, 0x0202010100000201, 0x0202010100010000,
+ 0x0202010100010002, 0x0202010100010101, 0x0202010100010200, 0x0202010100020102,
+ 0x0202010100020201, 0x0202010101000002, 0x0202010101000101, 0x0202010101010001,
+ 0x0202010101010100, 0x0202010101010101, 0x0202010101010102, 0x0202010101010201,
+ 0x0202010101020101, 0x0202010101020202, 0x0202010102000001, 0x0202010102000100,
+ 0x0202010102000101, 0x0202010102000102, 0x0202010102000201, 0x0202010102010002,
+ 0x0202010102010101, 0x0202010102010200, 0x0202010200000101, 0x0202010200010001,
+ 0x0202010200010102, 0x0202010200010202, 0x0202010200020001, 0x0202010200020101,
+ 0x0202010201000100, 0x0202010201000102, 0x0202010201000202, 0x0202010201010002,
+ 0x0202010201010101, 0x0202010201010102, 0x0202010201010200, 0x0202010201020000,
+ 0x0202010201020002, 0x0202010202000102, 0x0202010202010000, 0x0202010202010101,
+ 0x0202010202010102, 0x0202010202010201, 0x0202010202020001, 0x0202010202020100,
+ 0x0202010202020102, 0x0202020000000000, 0x0202020000000002, 0x0202020000000200,
+ 0x0202020000000202, 0x0202020000020000, 0x0202020000020002, 0x0202020000020200,
+ 0x0202020000020202, 0x0202020001010001, 0x0202020001010100, 0x0202020001010102,
+ 0x0202020001010201, 0x0202020002000000, 0x0202020002000002, 0x0202020002000200,
+ 0x0202020002000202, 0x0202020002010101, 0x0202020002020000, 0x0202020002020002,
+ 0x0202020002020200, 0x0202020002020202, 0x0202020100000101, 0x0202020100010100,
+ 0x0202020100010201, 0x0202020100020001, 0x0202020100020101, 0x0202020101000001,
+ 0x0202020101010000, 0x0202020101010101, 0x0202020101010202, 0x0202020101020001,
+ 0x0202020101020102, 0x0202020101020201, 0x0202020102010000, 0x0202020102010102,
+ 0x0202020200000000, 0x0202020200000002, 0x0202020200000200, 0x0202020200000202,
+ 0x0202020200020000, 0x0202020200020002, 0x0202020200020200, 0x0202020200020202,
+ 0x0202020201010001, 0x0202020201010100, 0x0202020201010102, 0x0202020202000000,
+ 0x0202020202000002, 0x0202020202000200, 0x0202020202000202, 0x0202020202010101,
+ 0x0202020202020000, 0x0202020202020002, 0x0202020202020200, 0x0202020202020202,
+};
+#else
+static const uint32_t iq1s_grid_us[2048] = {
+ 0x00000000, 0x00000002, 0x00000101, 0x00000200, 0x00000202, 0x00010001, 0x00010101, 0x00020000,
+ 0x00020002, 0x00020200, 0x00020202, 0x01000101, 0x01010001, 0x01010100, 0x01010102, 0x01020101,
+ 0x02000000, 0x02000002, 0x02000200, 0x02000202, 0x02010101, 0x02020000, 0x02020002, 0x02020200,
+ 0x02020202, 0x00000110, 0x00000111, 0x00010011, 0x00010110, 0x00010112, 0x00010211, 0x00010212,
+ 0x00020111, 0x01000011, 0x01000112, 0x01000211, 0x01010012, 0x01010111, 0x01010212, 0x01020011,
+ 0x01020110, 0x01020112, 0x01020210, 0x02000111, 0x02010011, 0x02010110, 0x02010112, 0x02020111,
+ 0x00000020, 0x00000022, 0x00000220, 0x00000222, 0x00010121, 0x00020020, 0x00020022, 0x00020220,
+ 0x00020222, 0x01000121, 0x01010021, 0x01010221, 0x01020120, 0x01020221, 0x02000020, 0x02000022,
+ 0x02000220, 0x02000222, 0x02010021, 0x02010121, 0x02010221, 0x02020020, 0x02020022, 0x02020220,
+ 0x02020222, 0x00011001, 0x00011100, 0x00011102, 0x00021101, 0x01001001, 0x01001201, 0x01011101,
+ 0x01011202, 0x01021100, 0x01021101, 0x02011001, 0x02011201, 0x02021101, 0x00001011, 0x00001110,
+ 0x00001111, 0x00001112, 0x00011111, 0x00011210, 0x00011212, 0x00021211, 0x01001010, 0x01001111,
+ 0x01001212, 0x01011010, 0x01011011, 0x01011110, 0x01011111, 0x01011112, 0x01011211, 0x01021010,
+ 0x01021012, 0x01021111, 0x01021210, 0x01021212, 0x02001011, 0x02011011, 0x02011111, 0x02011210,
+ 0x02011212, 0x02021011, 0x02021110, 0x02021111, 0x02021112, 0x02021211, 0x00011120, 0x00011221,
+ 0x01001021, 0x01001120, 0x01011020, 0x01011022, 0x01011121, 0x01011220, 0x01021020, 0x01021021,
+ 0x01021122, 0x01021221, 0x02001121, 0x02011021, 0x02011120, 0x02011221, 0x00002000, 0x00002002,
+ 0x00002200, 0x00002202, 0x00012101, 0x00022000, 0x00022002, 0x00022200, 0x00022202, 0x01002101,
+ 0x01012001, 0x01012102, 0x01022101, 0x02002000, 0x02002002, 0x02002200, 0x02002202, 0x02012101,
+ 0x02022000, 0x02022002, 0x02022200, 0x02022202, 0x00002111, 0x00012011, 0x00012110, 0x00012211,
+ 0x00022110, 0x00022111, 0x01002011, 0x01012010, 0x01012011, 0x01012111, 0x01022011, 0x01022110,
+ 0x01022211, 0x02012011, 0x02012110, 0x02012112, 0x02012211, 0x02022111, 0x00002020, 0x00002022,
+ 0x00002220, 0x00002222, 0x00012121, 0x00022020, 0x00022022, 0x00022220, 0x00022222, 0x01002121,
+ 0x01012021, 0x01012221, 0x01022021, 0x01022121, 0x02002020, 0x02002022, 0x02002121, 0x02002220,
+ 0x02002222, 0x02012121, 0x02022020, 0x02022022, 0x02022220, 0x02022222, 0x00110000, 0x00110001,
+ 0x00110100, 0x00110201, 0x00120100, 0x00120101, 0x01100001, 0x01100100, 0x01110000, 0x01110101,
+ 0x01110200, 0x01120001, 0x01120100, 0x01120101, 0x01120201, 0x02110001, 0x02110100, 0x02110102,
+ 0x02120001, 0x02120101, 0x00100011, 0x00100110, 0x00100112, 0x00100211, 0x00110010, 0x00110012,
+ 0x00110111, 0x00110210, 0x00120011, 0x00120110, 0x00120211, 0x01100111, 0x01100212, 0x01110010,
+ 0x01110011, 0x01110012, 0x01110110, 0x01110111, 0x01110112, 0x01110211, 0x01120010, 0x01120111,
+ 0x02100110, 0x02110012, 0x02110111, 0x02120011, 0x02120110, 0x00110021, 0x00110120, 0x00110122,
+ 0x00120121, 0x01100020, 0x01100122, 0x01100221, 0x01110022, 0x01110121, 0x01110220, 0x01110222,
+ 0x01120120, 0x01120122, 0x02100121, 0x02110021, 0x02110120, 0x02110122, 0x02120121, 0x00101001,
+ 0x00101102, 0x00101201, 0x00111100, 0x00111101, 0x00111200, 0x00111201, 0x00121001, 0x00121102,
+ 0x01101001, 0x01101101, 0x01101102, 0x01101200, 0x01101202, 0x01111001, 0x01111100, 0x01111101,
+ 0x01111102, 0x01111201, 0x01121002, 0x01121101, 0x01121200, 0x02101100, 0x02101201, 0x02111000,
+ 0x02111100, 0x02111101, 0x02111200, 0x02111201, 0x02111202, 0x02121001, 0x02121100, 0x02121101,
+ 0x02121201, 0x00101012, 0x00101111, 0x00101212, 0x00111011, 0x00111110, 0x00111111, 0x00111112,
+ 0x00111211, 0x00121010, 0x00121012, 0x00121111, 0x00121210, 0x00121212, 0x01101011, 0x01101110,
+ 0x01101111, 0x01101112, 0x01111011, 0x01111012, 0x01111110, 0x01111111, 0x01111112, 0x01111211,
+ 0x01111212, 0x01121011, 0x01121110, 0x01121111, 0x01121112, 0x01121211, 0x02101010, 0x02101012,
+ 0x02101110, 0x02101111, 0x02101210, 0x02101212, 0x02111010, 0x02111011, 0x02111110, 0x02111111,
+ 0x02111112, 0x02111211, 0x02111212, 0x02121010, 0x02121012, 0x02121111, 0x00101021, 0x00101120,
+ 0x00101121, 0x00101122, 0x00111121, 0x00111122, 0x00111220, 0x00111222, 0x00121021, 0x00121122,
+ 0x01101020, 0x01101022, 0x01101120, 0x01101121, 0x01101220, 0x01101222, 0x01111021, 0x01111121,
+ 0x01111122, 0x01111220, 0x01111221, 0x01121021, 0x01121120, 0x01121121, 0x01121220, 0x01121221,
+ 0x01121222, 0x02101122, 0x02101222, 0x02111022, 0x02111121, 0x02121120, 0x02121221, 0x00112001,
+ 0x00112102, 0x00122101, 0x01102001, 0x01102100, 0x01102102, 0x01102201, 0x01112000, 0x01112101,
+ 0x01112200, 0x01112202, 0x01122000, 0x01122001, 0x01122100, 0x01122102, 0x01122201, 0x02102101,
+ 0x02112001, 0x02112100, 0x02122101, 0x00112010, 0x00112012, 0x00112111, 0x00112212, 0x00122011,
+ 0x00122111, 0x01102012, 0x01102110, 0x01102111, 0x01102210, 0x01112011, 0x01112110, 0x01112111,
+ 0x01112112, 0x01112211, 0x01112212, 0x01122010, 0x01122111, 0x01122212, 0x02102211, 0x02112011,
+ 0x02112012, 0x02112111, 0x02112210, 0x02122011, 0x02122112, 0x02122211, 0x00102221, 0x00112122,
+ 0x00122120, 0x00122122, 0x01102120, 0x01102122, 0x01102221, 0x01112020, 0x01112022, 0x01112121,
+ 0x01112220, 0x01122021, 0x01122122, 0x01122221, 0x02102121, 0x02112021, 0x02112122, 0x02112222,
+ 0x00200000, 0x00200002, 0x00200200, 0x00200202, 0x00210101, 0x00220000, 0x00220002, 0x00220101,
+ 0x00220200, 0x00220202, 0x01200101, 0x01210001, 0x01210201, 0x01220001, 0x01220101, 0x02200000,
+ 0x02200002, 0x02200200, 0x02200202, 0x02210101, 0x02220000, 0x02220002, 0x02220101, 0x02220200,
+ 0x02220202, 0x00200111, 0x00210011, 0x00210110, 0x00210211, 0x00220111, 0x01200012, 0x01200110,
+ 0x01200211, 0x01210111, 0x01210210, 0x01210212, 0x01220011, 0x01220110, 0x01220111, 0x01220112,
+ 0x02200111, 0x02210010, 0x02210112, 0x02210211, 0x02220111, 0x00200021, 0x00200220, 0x00200222,
+ 0x00210021, 0x00210121, 0x00220020, 0x00220022, 0x00220220, 0x00220222, 0x01200121, 0x01210021,
+ 0x01210122, 0x01210221, 0x01220121, 0x02200021, 0x02200220, 0x02200222, 0x02210021, 0x02210121,
+ 0x02220020, 0x02220022, 0x02220220, 0x02220222, 0x00201101, 0x00211100, 0x00211102, 0x00211201,
+ 0x00221101, 0x01201100, 0x01201101, 0x01201102, 0x01201201, 0x01211002, 0x01211101, 0x01211200,
+ 0x01211202, 0x01221102, 0x02201101, 0x02211001, 0x02211100, 0x02211201, 0x02221001, 0x02221101,
+ 0x00201211, 0x00211111, 0x00221011, 0x00221211, 0x01201010, 0x01201111, 0x01201210, 0x01211011,
+ 0x01211110, 0x01211111, 0x01211211, 0x01221012, 0x01221111, 0x01221210, 0x02201211, 0x02211010,
+ 0x02211110, 0x02211111, 0x02211210, 0x02211212, 0x02221011, 0x02221110, 0x02221112, 0x02221211,
+ 0x00201121, 0x00211020, 0x00211022, 0x00211221, 0x00221121, 0x01201021, 0x01201221, 0x01211121,
+ 0x01221020, 0x01221021, 0x01221221, 0x02201120, 0x02201122, 0x02211020, 0x02211222, 0x00202000,
+ 0x00202002, 0x00202200, 0x00202202, 0x00212101, 0x00222000, 0x00222002, 0x00222200, 0x00222202,
+ 0x01202101, 0x01212001, 0x01212100, 0x01222101, 0x02202000, 0x02202002, 0x02202200, 0x02202202,
+ 0x02222000, 0x02222002, 0x02222200, 0x02222202, 0x00202211, 0x00212011, 0x00212110, 0x00212211,
+ 0x00222111, 0x01202112, 0x01202211, 0x01212012, 0x01212111, 0x01222011, 0x01222110, 0x01222112,
+ 0x01222211, 0x02202111, 0x02212010, 0x02212112, 0x02212211, 0x02222110, 0x02222111, 0x00202020,
+ 0x00202022, 0x00202220, 0x00202222, 0x00222020, 0x00222022, 0x00222220, 0x00222222, 0x01202121,
+ 0x01212021, 0x01212122, 0x01212221, 0x01222121, 0x02202020, 0x02202022, 0x02202220, 0x02202222,
+ 0x02212121, 0x02222020, 0x02222022, 0x02222220, 0x02222222, 0x10000101, 0x10010001, 0x10010102,
+ 0x10020101, 0x11000201, 0x11010002, 0x11010101, 0x11010200, 0x11010202, 0x11020001, 0x11020100,
+ 0x11020102, 0x12010100, 0x12010201, 0x12020001, 0x12020102, 0x10000010, 0x10000011, 0x10000110,
+ 0x10000112, 0x10000211, 0x10010012, 0x10010111, 0x10010112, 0x10010210, 0x10010212, 0x10020011,
+ 0x10020112, 0x10020211, 0x11000111, 0x11000210, 0x11000212, 0x11010011, 0x11010110, 0x11010111,
+ 0x11010112, 0x11010211, 0x11010212, 0x11020111, 0x11020210, 0x11020212, 0x12000011, 0x12000110,
+ 0x12000112, 0x12010010, 0x12010012, 0x12010111, 0x12020010, 0x12020011, 0x12020012, 0x10000121,
+ 0x10010021, 0x10010120, 0x10010122, 0x10020121, 0x11000021, 0x11010022, 0x11010121, 0x11010222,
+ 0x11020120, 0x11020221, 0x12000221, 0x12010120, 0x12020121, 0x10001001, 0x10011101, 0x10011201,
+ 0x10021201, 0x11001101, 0x11001200, 0x11001202, 0x11011001, 0x11011100, 0x11011101, 0x11011102,
+ 0x11021001, 0x11021002, 0x11021101, 0x11021200, 0x11021202, 0x12001001, 0x12001102, 0x12001201,
+ 0x12011000, 0x12011002, 0x12011101, 0x12021000, 0x12021001, 0x12021201, 0x10001011, 0x10001012,
+ 0x10001111, 0x10001212, 0x10011011, 0x10011110, 0x10011111, 0x10011112, 0x10011211, 0x10021010,
+ 0x10021111, 0x10021212, 0x11001011, 0x11001110, 0x11001111, 0x11001112, 0x11001211, 0x11011010,
+ 0x11011011, 0x11011110, 0x11011111, 0x11011112, 0x11011210, 0x11011211, 0x11021011, 0x11021110,
+ 0x11021111, 0x11021112, 0x11021211, 0x12001012, 0x12001110, 0x12001111, 0x12001210, 0x12011011,
+ 0x12011110, 0x12011111, 0x12011112, 0x12011211, 0x12011212, 0x12021111, 0x12021210, 0x12021212,
+ 0x10001021, 0x10001121, 0x10001221, 0x10011120, 0x10011121, 0x10011220, 0x10011222, 0x10021021,
+ 0x10021120, 0x10021221, 0x11001020, 0x11001022, 0x11001121, 0x11001220, 0x11011020, 0x11011021,
+ 0x11011022, 0x11011121, 0x11011122, 0x11011221, 0x11021022, 0x11021121, 0x11021220, 0x12001021,
+ 0x12001121, 0x12001222, 0x12011120, 0x12011121, 0x12021021, 0x12021120, 0x12021122, 0x10002101,
+ 0x10012001, 0x10012101, 0x10012202, 0x10022101, 0x11002002, 0x11002201, 0x11012000, 0x11012101,
+ 0x11012200, 0x11022001, 0x11022100, 0x11022102, 0x11022201, 0x12002101, 0x12012001, 0x12012100,
+ 0x12012102, 0x12012201, 0x12022101, 0x10002011, 0x10002111, 0x10002112, 0x10002212, 0x10012010,
+ 0x10012110, 0x10012111, 0x10012210, 0x10022011, 0x10022110, 0x10022112, 0x11002010, 0x11002111,
+ 0x11002212, 0x11012011, 0x11012012, 0x11012110, 0x11012111, 0x11012112, 0x11012211, 0x11022010,
+ 0x11022012, 0x11022111, 0x11022112, 0x11022212, 0x12002112, 0x12002211, 0x12012012, 0x12012111,
+ 0x12012112, 0x12012210, 0x12022011, 0x12022110, 0x12022112, 0x12022211, 0x10012122, 0x11002120,
+ 0x11002122, 0x11002221, 0x11012121, 0x11012220, 0x11012222, 0x11022120, 0x11022221, 0x12012120,
+ 0x12022121, 0x10100001, 0x10100100, 0x10100101, 0x10100102, 0x10100201, 0x10110002, 0x10110101,
+ 0x10110202, 0x10120001, 0x10120100, 0x10120201, 0x11100000, 0x11100101, 0x11100200, 0x11110001,
+ 0x11110100, 0x11110101, 0x11110102, 0x11110201, 0x11120101, 0x11120200, 0x12100102, 0x12100201,
+ 0x12110101, 0x12110200, 0x12120000, 0x12120001, 0x12120102, 0x12120201, 0x10100111, 0x10100210,
+ 0x10100211, 0x10100212, 0x10110011, 0x10110110, 0x10110111, 0x10110112, 0x10110210, 0x10110211,
+ 0x10120010, 0x10120111, 0x10120112, 0x10120210, 0x10120212, 0x11100011, 0x11100110, 0x11100111,
+ 0x11100112, 0x11100211, 0x11110010, 0x11110011, 0x11110012, 0x11110110, 0x11110111, 0x11110112,
+ 0x11110210, 0x11110211, 0x11110212, 0x11120011, 0x11120110, 0x11120111, 0x11120112, 0x11120211,
+ 0x12100012, 0x12100111, 0x12110011, 0x12110110, 0x12110111, 0x12110112, 0x12110211, 0x12120010,
+ 0x12120111, 0x12120212, 0x10100021, 0x10100122, 0x10110022, 0x10110121, 0x10110222, 0x10120021,
+ 0x10120120, 0x11100022, 0x11100121, 0x11100222, 0x11110021, 0x11110120, 0x11110121, 0x11110122,
+ 0x11110221, 0x11120022, 0x11120121, 0x12100121, 0x12110020, 0x12110022, 0x12110121, 0x12110221,
+ 0x12110222, 0x12120120, 0x10101100, 0x10101101, 0x10111001, 0x10111100, 0x10111101, 0x10111102,
+ 0x10111200, 0x10111201, 0x10121001, 0x10121101, 0x10121200, 0x10121202, 0x11101001, 0x11101100,
+ 0x11101101, 0x11101102, 0x11101201, 0x11101202, 0x11111000, 0x11111001, 0x11111100, 0x11111101,
+ 0x11111102, 0x11111200, 0x11111201, 0x11111202, 0x11121001, 0x11121002, 0x11121100, 0x11121101,
+ 0x11121102, 0x11121201, 0x12101000, 0x12101200, 0x12101202, 0x12111001, 0x12111100, 0x12111101,
+ 0x12111102, 0x12111201, 0x12121001, 0x12121100, 0x12121101, 0x12121202, 0x10101011, 0x10101012,
+ 0x10101110, 0x10101111, 0x10101112, 0x10101211, 0x10111010, 0x10111011, 0x10111012, 0x10111110,
+ 0x10111111, 0x10111112, 0x10111211, 0x10111212, 0x10121011, 0x10121110, 0x10121111, 0x10121112,
+ 0x10121211, 0x11101010, 0x11101011, 0x11101012, 0x11101110, 0x11101111, 0x11101112, 0x11101210,
+ 0x11101211, 0x11111010, 0x11111011, 0x11111012, 0x11111110, 0x11111111, 0x11111112, 0x11111210,
+ 0x11111211, 0x11111212, 0x11121010, 0x11121011, 0x11121110, 0x11121111, 0x11121112, 0x11121210,
+ 0x11121211, 0x11121212, 0x12101011, 0x12101110, 0x12101111, 0x12101211, 0x12101212, 0x12111010,
+ 0x12111011, 0x12111110, 0x12111111, 0x12111112, 0x12111210, 0x12111211, 0x12121011, 0x12121110,
+ 0x12121111, 0x12121112, 0x12121211, 0x10101020, 0x10101021, 0x10101022, 0x10101120, 0x10101122,
+ 0x10101220, 0x10101221, 0x10111021, 0x10111120, 0x10111121, 0x10111220, 0x10111221, 0x10121020,
+ 0x10121021, 0x10121022, 0x10121120, 0x10121121, 0x10121122, 0x10121220, 0x10121221, 0x11101021,
+ 0x11101121, 0x11101122, 0x11101220, 0x11101221, 0x11101222, 0x11111020, 0x11111021, 0x11111022,
+ 0x11111120, 0x11111121, 0x11111122, 0x11111220, 0x11111221, 0x11111222, 0x11121021, 0x11121120,
+ 0x11121121, 0x11121221, 0x12101022, 0x12101121, 0x12101122, 0x12101220, 0x12101221, 0x12101222,
+ 0x12111021, 0x12111121, 0x12111222, 0x12121022, 0x12121121, 0x12121122, 0x12121220, 0x12121221,
+ 0x10102100, 0x10102101, 0x10102102, 0x10102201, 0x10112000, 0x10112101, 0x10112200, 0x10122001,
+ 0x10122202, 0x11102101, 0x11102200, 0x11102202, 0x11112001, 0x11112100, 0x11112101, 0x11112102,
+ 0x11112200, 0x11112201, 0x11122000, 0x11122002, 0x11122100, 0x11122101, 0x12102002, 0x12102201,
+ 0x12112000, 0x12112002, 0x12112101, 0x12112200, 0x12122001, 0x12122201, 0x10102011, 0x10102012,
+ 0x10102111, 0x10102212, 0x10112011, 0x10112110, 0x10112111, 0x10112112, 0x10112211, 0x10122111,
+ 0x11102011, 0x11102110, 0x11102111, 0x11102112, 0x11102211, 0x11112010, 0x11112011, 0x11112012,
+ 0x11112110, 0x11112111, 0x11112112, 0x11112210, 0x11112211, 0x11112212, 0x11122011, 0x11122110,
+ 0x11122111, 0x11122112, 0x11122211, 0x12102011, 0x12102111, 0x12102211, 0x12112011, 0x12112110,
+ 0x12112111, 0x12112112, 0x12112210, 0x12112211, 0x12122111, 0x10102120, 0x10102220, 0x10112121,
+ 0x10112222, 0x10122020, 0x10122121, 0x10122122, 0x10122221, 0x11102121, 0x11102220, 0x11102221,
+ 0x11112021, 0x11112121, 0x11112122, 0x11112220, 0x11112221, 0x11122022, 0x11122121, 0x11122220,
+ 0x11122222, 0x12102021, 0x12102222, 0x12112022, 0x12112121, 0x12112122, 0x12112220, 0x12112222,
+ 0x12122021, 0x10200101, 0x10210100, 0x10210102, 0x10210201, 0x10220101, 0x11200100, 0x11210000,
+ 0x11210101, 0x11210102, 0x11210200, 0x11210202, 0x11220001, 0x11220100, 0x11220102, 0x11220201,
+ 0x12200001, 0x12210102, 0x12220101, 0x10200011, 0x10200110, 0x10200112, 0x10200211, 0x10210012,
+ 0x10210111, 0x10220011, 0x10220012, 0x10220112, 0x10220211, 0x11200111, 0x11200211, 0x11210011,
+ 0x11210111, 0x11210112, 0x11210211, 0x11220111, 0x11220112, 0x11220212, 0x12200110, 0x12200212,
+ 0x12210012, 0x12210111, 0x12220011, 0x12220112, 0x12220211, 0x10210021, 0x10210122, 0x10210221,
+ 0x11200020, 0x11200021, 0x11200122, 0x11210121, 0x11210122, 0x11210220, 0x11220020, 0x12200121,
+ 0x12210021, 0x12210122, 0x12220121, 0x10211001, 0x10211002, 0x10211101, 0x10211102, 0x10211202,
+ 0x10221001, 0x10221102, 0x10221201, 0x11201000, 0x11201002, 0x11201101, 0x11201200, 0x11201202,
+ 0x11211001, 0x11211100, 0x11211101, 0x11211102, 0x11211201, 0x11211202, 0x11221000, 0x11221002,
+ 0x11221101, 0x12201100, 0x12201101, 0x12201201, 0x12211000, 0x12211002, 0x12211100, 0x12211101,
+ 0x12211102, 0x12211200, 0x12211202, 0x12221001, 0x12221100, 0x12221201, 0x10201111, 0x10201210,
+ 0x10201212, 0x10211011, 0x10211111, 0x10211112, 0x10211211, 0x11201110, 0x11201111, 0x11201112,
+ 0x11201211, 0x11211010, 0x11211011, 0x11211110, 0x11211111, 0x11211112, 0x11211211, 0x11221011,
+ 0x11221110, 0x11221111, 0x11221112, 0x11221211, 0x12201112, 0x12201211, 0x12201212, 0x12211011,
+ 0x12211111, 0x12211112, 0x12211211, 0x12211212, 0x12221012, 0x12221111, 0x12221112, 0x12221210,
+ 0x10201022, 0x10201221, 0x10211121, 0x10221020, 0x10221122, 0x10221220, 0x10221221, 0x11201020,
+ 0x11201121, 0x11201220, 0x11201222, 0x11211021, 0x11211120, 0x11211121, 0x11211122, 0x11211220,
+ 0x11211222, 0x11221020, 0x11221121, 0x11221220, 0x12201020, 0x12201022, 0x12201121, 0x12201222,
+ 0x12211120, 0x12211122, 0x12211220, 0x12211221, 0x12221020, 0x12221120, 0x12221122, 0x12221222,
+ 0x10212102, 0x10212201, 0x10222101, 0x11202001, 0x11212002, 0x11212101, 0x11212202, 0x11222001,
+ 0x11222201, 0x12202101, 0x12212001, 0x12212200, 0x12222102, 0x10202011, 0x10202110, 0x10212010,
+ 0x10212111, 0x10222011, 0x10222110, 0x10222112, 0x10222211, 0x11202010, 0x11202011, 0x11202111,
+ 0x11202112, 0x11202210, 0x11212011, 0x11212110, 0x11212111, 0x11212112, 0x11212211, 0x11222010,
+ 0x11222111, 0x11222212, 0x12202012, 0x12202110, 0x12202212, 0x12212111, 0x12222011, 0x12222110,
+ 0x12222111, 0x12222211, 0x10212021, 0x10212122, 0x10212220, 0x11202021, 0x11202120, 0x11202221,
+ 0x11212020, 0x11212121, 0x11212220, 0x11212222, 0x11222120, 0x11222121, 0x11222221, 0x12202122,
+ 0x12212120, 0x12212220, 0x12212222, 0x12222122, 0x20000000, 0x20000002, 0x20000200, 0x20000202,
+ 0x20020000, 0x20020002, 0x20020200, 0x20020202, 0x21000101, 0x21010000, 0x21010001, 0x21010100,
+ 0x21010102, 0x21010201, 0x21020101, 0x22000000, 0x22000002, 0x22000200, 0x22000202, 0x22010101,
+ 0x22020000, 0x22020002, 0x22020200, 0x22020202, 0x20000111, 0x20010011, 0x20010110, 0x20010112,
+ 0x20010211, 0x20020111, 0x21000011, 0x21000110, 0x21000211, 0x21010010, 0x21010012, 0x21010111,
+ 0x21010112, 0x21010210, 0x21010211, 0x21020110, 0x21020112, 0x21020211, 0x22000111, 0x22000211,
+ 0x22010110, 0x22010112, 0x22010211, 0x22020111, 0x20000020, 0x20000022, 0x20000220, 0x20000222,
+ 0x20010121, 0x20020020, 0x20020022, 0x20020220, 0x20020222, 0x21010021, 0x21010120, 0x21010221,
+ 0x21020121, 0x22000020, 0x22000022, 0x22000220, 0x22000222, 0x22010121, 0x22020020, 0x22020022,
+ 0x22020220, 0x22020222, 0x20011100, 0x20011201, 0x21001001, 0x21001100, 0x21011001, 0x21011101,
+ 0x21011202, 0x21021001, 0x21021100, 0x21021201, 0x22011100, 0x22011201, 0x20001011, 0x20001211,
+ 0x20011012, 0x20011111, 0x20011212, 0x20021112, 0x20021211, 0x21001010, 0x21001011, 0x21001111,
+ 0x21001210, 0x21011011, 0x21011110, 0x21011111, 0x21011112, 0x21011211, 0x21011212, 0x21021111,
+ 0x21021112, 0x21021210, 0x21021212, 0x22001011, 0x22001110, 0x22001112, 0x22001211, 0x22011010,
+ 0x22011012, 0x22011111, 0x22011210, 0x22021112, 0x20011021, 0x20011122, 0x20011221, 0x20021121,
+ 0x21001021, 0x21001120, 0x21001221, 0x21001222, 0x21011020, 0x21011121, 0x21011221, 0x21011222,
+ 0x21021021, 0x21021122, 0x21021222, 0x22001121, 0x22011021, 0x22011222, 0x22021120, 0x20002000,
+ 0x20002002, 0x20002200, 0x20002202, 0x20012101, 0x20022000, 0x20022002, 0x20022200, 0x20022202,
+ 0x21002001, 0x21002101, 0x21012001, 0x21012100, 0x21012201, 0x21022101, 0x21022201, 0x22002000,
+ 0x22002002, 0x22002200, 0x22002202, 0x22012101, 0x22022000, 0x22022002, 0x22022200, 0x22022202,
+ 0x20002111, 0x20002112, 0x20012011, 0x20012110, 0x20012112, 0x20022111, 0x21002011, 0x21002110,
+ 0x21002112, 0x21002211, 0x21012010, 0x21012012, 0x21012111, 0x21012212, 0x21022011, 0x21022110,
+ 0x22002111, 0x22012112, 0x22012211, 0x22022111, 0x20002020, 0x20002022, 0x20002220, 0x20002222,
+ 0x20012121, 0x20022020, 0x20022022, 0x20022220, 0x20022222, 0x21002121, 0x21012021, 0x21012120,
+ 0x21012122, 0x22002020, 0x22002022, 0x22002220, 0x22002222, 0x22012121, 0x22022020, 0x22022022,
+ 0x22022220, 0x22022222, 0x20100101, 0x20110001, 0x20110102, 0x20110200, 0x20110201, 0x20120101,
+ 0x21100001, 0x21100102, 0x21100201, 0x21110101, 0x21110200, 0x21110202, 0x21120201, 0x21120202,
+ 0x22100101, 0x22110001, 0x22110100, 0x22110102, 0x22110201, 0x22120101, 0x20100011, 0x20100110,
+ 0x20100112, 0x20100211, 0x20110010, 0x20110111, 0x20110210, 0x20110212, 0x20120011, 0x20120110,
+ 0x20120112, 0x20120211, 0x21100010, 0x21100111, 0x21110010, 0x21110011, 0x21110110, 0x21110111,
+ 0x21110112, 0x21110211, 0x21120012, 0x21120111, 0x22100110, 0x22100112, 0x22110012, 0x22110111,
+ 0x22110210, 0x22120011, 0x22120110, 0x22120112, 0x22120211, 0x20100121, 0x20110021, 0x20110120,
+ 0x20110221, 0x20120121, 0x21100120, 0x21100122, 0x21100221, 0x21110020, 0x21110022, 0x21110121,
+ 0x21110220, 0x21120122, 0x21120221, 0x22100121, 0x22110120, 0x22110122, 0x22120221, 0x20101001,
+ 0x20101100, 0x20101102, 0x20111000, 0x20111101, 0x20111200, 0x20121102, 0x21101000, 0x21101202,
+ 0x21111001, 0x21111100, 0x21111101, 0x21111102, 0x21111200, 0x21111201, 0x21121000, 0x21121001,
+ 0x21121002, 0x21121101, 0x22101100, 0x22101102, 0x22111002, 0x22111100, 0x22111101, 0x22111200,
+ 0x22121001, 0x22121201, 0x20101010, 0x20101111, 0x20101210, 0x20101212, 0x20111010, 0x20111011,
+ 0x20111110, 0x20111111, 0x20111112, 0x20111211, 0x20121011, 0x20121111, 0x20121211, 0x20121212,
+ 0x21101011, 0x21101110, 0x21101111, 0x21101112, 0x21101211, 0x21111010, 0x21111011, 0x21111012,
+ 0x21111110, 0x21111111, 0x21111112, 0x21111210, 0x21111211, 0x21111212, 0x21121011, 0x21121110,
+ 0x21121111, 0x21121112, 0x21121211, 0x22101011, 0x22101111, 0x22101210, 0x22111011, 0x22111012,
+ 0x22111110, 0x22111111, 0x22111112, 0x22111211, 0x22111212, 0x22121010, 0x22121012, 0x22121111,
+ 0x22121210, 0x22121212, 0x20101021, 0x20101120, 0x20111020, 0x20111121, 0x20111221, 0x20121020,
+ 0x20121122, 0x20121221, 0x21101121, 0x21101220, 0x21101221, 0x21111021, 0x21111022, 0x21111121,
+ 0x21111122, 0x21111221, 0x21121121, 0x21121220, 0x22101022, 0x22101120, 0x22101221, 0x22101222,
+ 0x22111022, 0x22111120, 0x22111121, 0x22121120, 0x22121122, 0x22121221, 0x20102101, 0x20112102,
+ 0x20112201, 0x20122101, 0x21102001, 0x21102102, 0x21112000, 0x21112002, 0x21112101, 0x21112102,
+ 0x21112202, 0x21122100, 0x21122101, 0x22102101, 0x22112001, 0x22112102, 0x22112201, 0x22122101,
+ 0x20102110, 0x20102112, 0x20102211, 0x20112010, 0x20112012, 0x20112111, 0x20112210, 0x20112212,
+ 0x20122010, 0x20122011, 0x20122110, 0x20122112, 0x21102010, 0x21102012, 0x21102111, 0x21102210,
+ 0x21102212, 0x21112011, 0x21112110, 0x21112111, 0x21112112, 0x21112211, 0x21122012, 0x21122111,
+ 0x21122112, 0x21122212, 0x22102011, 0x22102110, 0x22112010, 0x22112012, 0x22112111, 0x22112212,
+ 0x22122011, 0x22122112, 0x20102121, 0x20112121, 0x20122121, 0x21102120, 0x21102122, 0x21102221,
+ 0x21112020, 0x21112121, 0x21112220, 0x21122021, 0x22102121, 0x22112021, 0x22112120, 0x22112121,
+ 0x22112122, 0x20200000, 0x20200002, 0x20200200, 0x20200202, 0x20210101, 0x20220000, 0x20220002,
+ 0x20220200, 0x20220202, 0x21200101, 0x21210001, 0x21210100, 0x21210102, 0x21210201, 0x22200000,
+ 0x22200002, 0x22200200, 0x22200202, 0x22210101, 0x22220000, 0x22220002, 0x22220200, 0x22220202,
+ 0x20200111, 0x20200211, 0x20210011, 0x20210110, 0x20210112, 0x20210211, 0x20210212, 0x21200112,
+ 0x21200211, 0x21210011, 0x21210111, 0x21210210, 0x21210212, 0x21220011, 0x21220110, 0x22200111,
+ 0x22210010, 0x22210012, 0x22210112, 0x22210211, 0x20200022, 0x20200220, 0x20200222, 0x20210020,
+ 0x20210221, 0x20220022, 0x20220220, 0x20220222, 0x21200121, 0x21210021, 0x21210122, 0x21210221,
+ 0x21220121, 0x22200020, 0x22200022, 0x22200220, 0x22200222, 0x22210121, 0x22220020, 0x22220022,
+ 0x22220220, 0x22220222, 0x20211201, 0x20221101, 0x21201001, 0x21201100, 0x21211000, 0x21211100,
+ 0x21211101, 0x21211200, 0x21211202, 0x21221001, 0x21221101, 0x21221102, 0x21221200, 0x21221201,
+ 0x22201101, 0x20201112, 0x20201211, 0x20211010, 0x20211012, 0x20211111, 0x20211210, 0x20221112,
+ 0x20221211, 0x21201012, 0x21201111, 0x21211011, 0x21211110, 0x21211111, 0x21211112, 0x21211211,
+ 0x21221111, 0x21221212, 0x22201011, 0x22201110, 0x22201111, 0x22201112, 0x22201211, 0x22211012,
+ 0x22211111, 0x22211210, 0x20201121, 0x20211021, 0x20211122, 0x20211222, 0x20221021, 0x20221121,
+ 0x21201120, 0x21201122, 0x21201222, 0x21211022, 0x21211121, 0x21211122, 0x21211220, 0x21221020,
+ 0x21221022, 0x22201122, 0x22211020, 0x22211121, 0x22211122, 0x22211221, 0x22221021, 0x22221120,
+ 0x22221122, 0x20202000, 0x20202002, 0x20202200, 0x20202202, 0x20222000, 0x20222002, 0x20222200,
+ 0x20222202, 0x21212001, 0x21212100, 0x21212102, 0x21212201, 0x22202000, 0x22202002, 0x22202200,
+ 0x22202202, 0x22212101, 0x22222000, 0x22222002, 0x22222200, 0x22222202, 0x20202111, 0x20212110,
+ 0x20212211, 0x20222011, 0x20222111, 0x21202011, 0x21212010, 0x21212111, 0x21212212, 0x21222011,
+ 0x21222112, 0x21222211, 0x22212010, 0x22212112, 0x20202020, 0x20202022, 0x20202220, 0x20202222,
+ 0x20222020, 0x20222022, 0x20222220, 0x20222222, 0x21212021, 0x21212120, 0x21212122, 0x22202020,
+ 0x22202022, 0x22202220, 0x22202222, 0x22212121, 0x22222020, 0x22222022, 0x22222220, 0x22222222,
+};
+#endif
+
+}
+
+#ifdef __x86_64__
+
+namespace {
+template <int nrc_y>
+void mul_mat_iq1_s_q8_K(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ GGML_ASSERT(n%QK_K == 0);
+ Q8<nrc_y, block_q8_K> q8(info);
+ __m256i qx[8];
+ __m256i scales[4];
+ __m256 acc[nrc_y] = {};
+ auto delta_mask = _mm_set1_epi16(-32768); // to avoid stupid overflow warnings when using 0x8000
+ __m256i shuffle0 = _mm256_set_epi64x(0x0302030203020302, 0x0100010001000100, 0x0302030203020302, 0x0100010001000100);
+ for (int ix = 0; ix < nrc_x; ++ix) {
+ auto iq1s = (const block_iq1_s *)((const char *)vx + ix*bx);
+ for (int ibl = 0; ibl < n/QK_K; ++ibl) {
+ float d = GGML_FP16_TO_FP32(iq1s[ibl].d);
+ auto qhb = _mm_loadu_si128((const __m128i *)iq1s[ibl].qh);
+ auto scales128 = _mm_and_si128(_mm_srli_epi16(qhb, 12), _mm_set1_epi16(7));
+ scales128 = _mm_add_epi16(_mm_slli_epi16(scales128, 1), _mm_set1_epi16(1));
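+ // iq1_s grid values are stored with a +1 offset (0..2 instead of -1..1). The -7/-9 factors below,
+ // multiplied by the q8 block sums, fold that offset together with the +/-1/8 block shift; the
+ // common 1/8 is applied once as the 0.125f factor at the final store.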
+#ifdef HAVE_FANCY_SIMD
+ auto mask = _mm_cmpeq_epi16_mask(_mm_and_si128(qhb, delta_mask), delta_mask);
+ auto deltas128 = _mm_mask_blend_epi16(mask, _mm_set1_epi16(-7), _mm_set1_epi16(-9));
+#else
+ auto mask = _mm_cmpeq_epi16(_mm_and_si128(qhb, delta_mask), delta_mask);
+ auto deltas128 = _mm_or_si128(_mm_and_si128(mask, _mm_set1_epi16(-9)), _mm_andnot_si128(mask, _mm_set1_epi16(-7)));
+#endif
+ deltas128 = _mm_mullo_epi16(scales128, deltas128);
+ scales128 = _mm_slli_epi16(scales128, 3);
+ auto deltas_l = _mm_unpacklo_epi16(deltas128, deltas128);
+ auto deltas_h = _mm_unpackhi_epi16(deltas128, deltas128);
+ auto deltas = MM256_SET_M128I(deltas_h, deltas_l); // blocks 0,0, 1,1, 2,2, ..., 7,7
+ auto all_scales = MM256_SET_M128I(scales128, scales128);
+ auto shuffle = shuffle0;
+ for (int ib64 = 0; ib64 < QK_K/64; ++ib64) {
+ scales[ib64] = _mm256_shuffle_epi8(all_scales, shuffle);
+ shuffle = _mm256_add_epi8(shuffle, _mm256_set1_epi8(4));
+ }
+ const uint8_t * qs = iq1s[ibl].qs;
+ const uint16_t * qh = iq1s[ibl].qh;
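+ // Each qs byte supplies the low 8 bits of a grid index; three bits from qh, shifted into
+ // bits 8..10 (hence the & 0x700 masks), complete the index into iq1s_grid_us.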
+ for (int ib = 0; ib < QK_K/32; ib += 2) {
+ qx[ib+0] = _mm256_set_epi64x(iq1s_grid_us[qs[3] | ((qh[ib+0] >> 1) & 0x700)], iq1s_grid_us[qs[2] | ((qh[ib+0] << 2) & 0x700)],
+ iq1s_grid_us[qs[1] | ((qh[ib+0] << 5) & 0x700)], iq1s_grid_us[qs[0] | ((qh[ib+0] << 8) & 0x700)]);
+ qx[ib+1] = _mm256_set_epi64x(iq1s_grid_us[qs[7] | ((qh[ib+1] >> 1) & 0x700)], iq1s_grid_us[qs[6] | ((qh[ib+1] << 2) & 0x700)],
+ iq1s_grid_us[qs[5] | ((qh[ib+1] << 5) & 0x700)], iq1s_grid_us[qs[4] | ((qh[ib+1] << 8) & 0x700)]);
+ qs += 8;
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto bsums = q8.load_bsums(iy, ibl);
+ auto sumi = _mm256_setzero_si256();
+ for (int ib64 = 0; ib64 < QK_K/64; ++ib64) {
+ auto qy1 = q8.load_quants(iy, ibl, 2*ib64+0);
+ auto qy2 = q8.load_quants(iy, ibl, 2*ib64+1);
+#ifdef HAVE_FANCY_SIMD
+ auto dot1 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), qx[2*ib64+0], qy1);
+ auto dot2 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), qx[2*ib64+1], qy2);
+ sumi = _mm256_dpwssd_epi32(sumi, scales[ib64], _mm256_packs_epi32(dot1, dot2));
+#else
+ auto dot1 = _mm256_maddubs_epi16(qx[2*ib64+0], qy1);
+ auto dot2 = _mm256_maddubs_epi16(qx[2*ib64+1], qy2);
+ auto dot = _mm256_add_epi16(_mm256_unpacklo_epi64(dot1, dot2), _mm256_unpackhi_epi64(dot1, dot2));
+ sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(scales[ib64], dot));
+#endif
+ }
+#ifdef HAVE_FANCY_SIMD
+ sumi = _mm256_dpwssd_epi32(sumi, bsums, deltas);
+#else
+ sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(bsums, deltas));
+#endif
+ acc[iy] = _mm256_fmadd_ps(_mm256_set1_ps(d*q8.scale(iy, ibl)), _mm256_cvtepi32_ps(sumi), acc[iy]);
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ info.store(ix, iy, 0.125f*hsum_float_8(acc[iy]));
+ acc[iy] = _mm256_setzero_ps();
+ }
+ }
+}
+
+template <int nrc_y>
+static void mul_mat_iq1_s_r4_q8_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ GGML_ASSERT(nrc_x%4 == 0);
+ Q8<nrc_y, block_q8_K128> q8(info);
+ int nb = n / 32;
+ GGML_ASSERT(nb%4 == 0);
+ __m256i qx[4];
+ __m256 acc[nrc_y] = {};
+ auto m1 = _mm256_set1_epi16(1);
+ auto ms = _mm_set1_epi16(-32768);
+ float d8[4*nrc_y];
+ union { __m256i vec; uint16_t val[16]; } helper;
+ struct aux_iq1_s_r4 {
+ uint8_t qs[16];
+ uint64_t qh;
+ };
+ for (int ix = 0; ix < nrc_x; ix += 4) {
+ auto dptr = (const ggml_half *)((const char *)vx + ix*bx);
+ auto d1 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)dptr));
+ auto x = (const aux_iq1_s_r4 *)(dptr + 4);
+ for (int ib = 0; ib < nb/4; ++ib) {
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto bsums = _mm_cvtepi16_epi32(_mm_loadl_epi64((const __m128i *)q8.y[iy][ib].bsums));
+ _mm_storeu_ps(d8 + 4*iy, _mm_mul_ps(_mm_set1_ps(q8.y[iy][ib].d), _mm_cvtepi32_ps(bsums)));
+ }
+ for (int k = 0; k < 4; ++k) {
+ auto idxh = _mm256_set1_epi64x(x[4*ib+k].qh);
+ auto sas = _mm256_castsi256_si128(idxh);
+ auto scales4 = _mm_and_si128(_mm_srli_epi16(sas, 12), _mm_set1_epi16(7));
+ scales4 = _mm_or_si128(_mm_slli_epi16(scales4, 1), _mm_set1_epi16(1));
+ auto signs = _mm_or_si128(_mm_cmpeq_epi16(_mm_and_si128(sas, ms), ms), _mm256_castsi256_si128(m1));
+ signs = _mm_add_epi16(_mm_set1_epi16(-8), signs);
+ signs = _mm_mullo_epi16(signs, scales4);
+ auto delta4 = _mm_mul_ps(_mm_set1_ps(0.0625f), _mm_cvtepi32_ps(_mm_cvtepi16_epi32(signs)));
+ auto delta = _mm256_set_m128(delta4, delta4);
+ scales4 = _mm_unpacklo_epi16(scales4, scales4); // 0,0, 1,1, 2,2, 3,3
+ auto scales = MM256_SET_M128I(scales4, scales4);
+ auto idxl = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)x[4*ib+k].qs));
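+ // idxh still holds the packed qh of the four interleaved rows (scales and delta signs were taken
+ // from its top bits above); the per-lane variable shifts align the 3 high bits of each grid index
+ // with bits 8..10 of its 16-bit lane, which are then OR-ed with the low byte from qs.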
+ idxh = _mm256_sllv_epi64(idxh, _mm256_set_epi64x(0, 2, 5, 8));
+ idxh = _mm256_srlv_epi64(idxh, _mm256_set_epi64x(1, 0, 0, 0));
+ helper.vec = _mm256_or_si256(idxl, _mm256_and_si256(_mm256_set1_epi16(0x0700), idxh));
+ qx[0] = _mm256_set_epi64x(iq1s_grid_us[helper.val[ 9]], iq1s_grid_us[helper.val[ 8]],
+ iq1s_grid_us[helper.val[ 1]], iq1s_grid_us[helper.val[ 0]]);
+ qx[1] = _mm256_set_epi64x(iq1s_grid_us[helper.val[13]], iq1s_grid_us[helper.val[12]],
+ iq1s_grid_us[helper.val[ 5]], iq1s_grid_us[helper.val[ 4]]);
+ qx[2] = _mm256_set_epi64x(iq1s_grid_us[helper.val[11]], iq1s_grid_us[helper.val[10]],
+ iq1s_grid_us[helper.val[ 3]], iq1s_grid_us[helper.val[ 2]]);
+ qx[3] = _mm256_set_epi64x(iq1s_grid_us[helper.val[15]], iq1s_grid_us[helper.val[14]],
+ iq1s_grid_us[helper.val[ 7]], iq1s_grid_us[helper.val[ 6]]);
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto y = _mm256_loadu_si256((const __m256i *)q8.y[iy][ib].qs + k);
+#ifdef HAVE_FANCY_SIMD
+ // 0,0, 1,1, 0,0, 1,1 as int32_t
+ auto sumi1 = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_setzero_si256(),
+ qx[0], _mm256_shuffle_epi32(y, 0x44)), qx[1], _mm256_shuffle_epi32(y, 0xee));
+ // 2,2, 3,3, 2,2, 3,3 as int32_t
+ auto sumi2 = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_setzero_si256(),
+ qx[2], _mm256_shuffle_epi32(y, 0x44)), qx[3], _mm256_shuffle_epi32(y, 0xee));
+ auto sumi = _mm256_packs_epi32(sumi1, sumi2);
+#else
+ // 4 x row 0, 4 x row 1, 4 x row 0, 4 x row 1
+ auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[0], _mm256_shuffle_epi32(y, 0x44)),
+ _mm256_maddubs_epi16(qx[1], _mm256_shuffle_epi32(y, 0xee)));
+ // 4 x row 2, 4 x row 3, 4 x row 2, 4 x row 3
+ auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[2], _mm256_shuffle_epi32(y, 0x44)),
+ _mm256_maddubs_epi16(qx[3], _mm256_shuffle_epi32(y, 0xee)));
+ // 0,0, 1,1, 0,0, 1,1 as int32_t
+ sumi1 = _mm256_madd_epi16(m1, sumi1);
+ // 2,2, 3,3, 2,2, 3,3 as int32_t
+ sumi2 = _mm256_madd_epi16(m1, sumi2);
+ // 0,0, 1,1, 2,2, 3,3, 0,0, 1,1, 2,2, 3,3 as int16_t
+ auto sumi = _mm256_packs_epi32(sumi1, sumi2);
+#endif
+ sumi = _mm256_madd_epi16(scales, sumi);
+ acc[iy] = _mm256_fmadd_ps(_mm256_set1_ps(q8.y[iy][ib].d), _mm256_cvtepi32_ps(sumi), acc[iy]);
+ acc[iy] = _mm256_fmadd_ps(_mm256_set1_ps(d8[4*iy+k]), delta, acc[iy]);
+ }
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto sumf = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1));
+ info.store(ix, iy, _mm_mul_ps(d1, sumf));
+ acc[iy] = _mm256_setzero_ps();
+ }
+ }
+}
+
+template <int nrc_y>
+static void mul_mat_iq1_m_r4_q8_0(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ GGML_ASSERT(nrc_x%4 == 0);
+ Q8<nrc_y, block_q8_K128> q8(info);
+ int nb = n / 32;
+ GGML_ASSERT(nb%4 == 0);
+ auto shuffle0 = _mm256_set_epi64x(0x0909090909090909, 0x0808080808080808, 0x0101010101010101, 0x0000000000000000);
+ auto step = _mm256_set1_epi8(2);
+#ifndef HAVE_FANCY_SIMD
+ auto m1 = _mm256_set1_epi16(1);
+#endif
+ __m256i qx[4];
+ __m256 acc[nrc_y] = {};
+ __m256i isum[nrc_y] = {};
+ auto ms = _mm_set1_epi8(0x08);
+ union { __m256i vec; uint16_t val[16]; } helper;
+ for (int ix = 0; ix < nrc_x; ix += 4) {
+ auto dptr = (const ggml_half *)((const char *)vx + ix*bx);
+ auto d1 = _mm_mul_ps(_mm_set1_ps(0.125f), _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)dptr)));
+ auto x = (const block_iq1_m_r4 *)(dptr + 4);
+ for (int ib = 0; ib < nb/4; ++ib) {
+ for (int k = 0; k < 4; ++k) {
+ auto qh = (const uint32_t *)x[4*ib+k].qh;
+ auto idxh = _mm_set_epi32(qh[1] >> 4, qh[1], qh[0] >> 4, qh[0]);
+ auto scales4 = _mm_set1_epi32(((const uint32_t *)x[4*ib+k].scales)[0]);
+ scales4 = _mm_and_si128(_mm_srlv_epi32(scales4, _mm_set_epi32(4, 0, 4, 0)), _mm_set1_epi8(0xf));
+ scales4 = _mm_cvtepu8_epi16(scales4);
+ auto scales = MM256_SET_M128I(_mm_unpackhi_epi16(scales4, scales4), _mm_unpacklo_epi16(scales4, scales4));
+
+ auto signs128 = _mm_or_si128(_mm_cmpeq_epi8(_mm_and_si128(idxh, ms), ms), _mm_set1_epi8(1));
+ signs128 = _mm_add_epi8(_mm_set1_epi8(-8), signs128);
+ auto signs = MM256_SET_M128I(signs128, signs128);
+ auto idxl = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)x[4*ib+k].qs));
+ idxh = _mm_and_si128(idxh, _mm_set1_epi8(0x07));
+ helper.vec = _mm256_or_si256(idxl, _mm256_slli_epi16(_mm256_cvtepu8_epi16(idxh), 8));
+ qx[0] = _mm256_set_epi64x(iq1s_grid_us[helper.val[ 9]], iq1s_grid_us[helper.val[ 8]],
+ iq1s_grid_us[helper.val[ 1]], iq1s_grid_us[helper.val[ 0]]);
+ qx[1] = _mm256_set_epi64x(iq1s_grid_us[helper.val[13]], iq1s_grid_us[helper.val[12]],
+ iq1s_grid_us[helper.val[ 5]], iq1s_grid_us[helper.val[ 4]]);
+ qx[2] = _mm256_set_epi64x(iq1s_grid_us[helper.val[11]], iq1s_grid_us[helper.val[10]],
+ iq1s_grid_us[helper.val[ 3]], iq1s_grid_us[helper.val[ 2]]);
+ qx[3] = _mm256_set_epi64x(iq1s_grid_us[helper.val[15]], iq1s_grid_us[helper.val[14]],
+ iq1s_grid_us[helper.val[ 7]], iq1s_grid_us[helper.val[ 6]]);
+ qx[0] = _mm256_add_epi8(_mm256_slli_epi16(qx[0], 3), _mm256_shuffle_epi8(signs, shuffle0));
+ auto shuffle = _mm256_add_epi8(shuffle0, step);
+ qx[2] = _mm256_add_epi8(_mm256_slli_epi16(qx[2], 3), _mm256_shuffle_epi8(signs, shuffle));
+ shuffle = _mm256_add_epi8(shuffle, step);
+ qx[1] = _mm256_add_epi8(_mm256_slli_epi16(qx[1], 3), _mm256_shuffle_epi8(signs, shuffle));
+ shuffle = _mm256_add_epi8(shuffle, step);
+ qx[3] = _mm256_add_epi8(_mm256_slli_epi16(qx[3], 3), _mm256_shuffle_epi8(signs, shuffle));
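+ // _mm256_maddubs_epi16/_mm256_dpbusd_epi32 require an unsigned first operand: take |qx| here and
+ // transfer the sign of qx onto the q8 values via _mm256_sign_epi8 in the loop below.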
+ auto s0 = _mm256_sign_epi8(qx[0], qx[0]);
+ auto s1 = _mm256_sign_epi8(qx[1], qx[1]);
+ auto s2 = _mm256_sign_epi8(qx[2], qx[2]);
+ auto s3 = _mm256_sign_epi8(qx[3], qx[3]);
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto y = _mm256_loadu_si256((const __m256i *)q8.y[iy][ib].qs + k);
+ auto y1 = _mm256_shuffle_epi32(y, 0x44);
+ auto y2 = _mm256_shuffle_epi32(y, 0xee);
+#ifdef HAVE_FANCY_SIMD
+ // 0,0, 1,1, 0,0, 1,1 as int32_t
+ auto sumi1 = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_setzero_si256(),
+ s0, _mm256_sign_epi8(y1, qx[0])), s1, _mm256_sign_epi8(y2, qx[1]));
+ // 2,2, 3,3, 2,2, 3,3 as int32_t
+ auto sumi2 = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_setzero_si256(),
+ s2, _mm256_sign_epi8(y1, qx[2])), s3, _mm256_sign_epi8(y2, qx[3]));
+ auto sumi = _mm256_packs_epi32(sumi1, sumi2);
+#else
+ // 4 x row 0, 4 x row 1, 4 x row 0, 4 x row 1
+ auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(s0, _mm256_sign_epi8(y1, qx[0])),
+ _mm256_maddubs_epi16(s1, _mm256_sign_epi8(y2, qx[1])));
+ // 4 x row 2, 4 x row 3, 4 x row 2, 4 x row 3
+ auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(s2, _mm256_sign_epi8(y1, qx[2])),
+ _mm256_maddubs_epi16(s3, _mm256_sign_epi8(y2, qx[3])));
+ // 0,0, 1,1, 0,0, 1,1 as int32_t
+ sumi1 = _mm256_madd_epi16(m1, sumi1);
+ // 2,2, 3,3, 2,2, 3,3 as int32_t
+ sumi2 = _mm256_madd_epi16(m1, sumi2);
+ // 0,0, 1,1, 2,2, 3,3, 0,0, 1,1, 2,2, 3,3 as int16_t
+ auto sumi = _mm256_packs_epi32(sumi1, sumi2);
+#endif
+ isum[iy] = _mm256_add_epi32(isum[iy], _mm256_madd_epi16(scales, sumi));
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ acc[iy] = _mm256_fmadd_ps(_mm256_set1_ps(q8.y[iy][ib].d), _mm256_cvtepi32_ps(isum[iy]), acc[iy]);
+ isum[iy] = _mm256_setzero_si256();
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto sumf = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1));
+ info.store(ix, iy, _mm_mul_ps(d1, sumf));
+ acc[iy] = _mm256_setzero_ps();
+ }
+ }
+}
+
+template <int nrc> struct Q8_K64 {
+
+ constexpr static int nrc_y = nrc;
+
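+ // Each q8 row starts with 8 floats: 4 scales (returned by scale()) and 4 correction terms
+ // (returned by minus()) that the kernels subtract from scale*sum; the int8 quants follow.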
+ Q8_K64(const DataInfo& info) {
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ const float * dptr = (const float *)info.src1_row(iy);
+ std::memcpy(d + 8*iy, dptr, 8*sizeof(float));
+ y[iy] = (const int8_t *)(dptr + 8);
+ }
+ }
+
+ inline __m256i load_quants(int iy, int i, int j) const { return _mm256_loadu_si256((const __m256i*)y[iy] + 4*i + j); }
+ inline __m128 scale(int iy) const { return _mm_loadu_ps(d + 8*iy); }
+ inline __m128 minus(int iy) const { return _mm_loadu_ps(d + 8*iy + 4); }
+
+ float d[8*nrc_y];
+ const int8_t * y[nrc_y];
+};
+
+struct DequantizerIQ1BN {
+ const __m256i m1_8 = _mm256_set1_epi8(1);
+ static __m256i load_shuffle(int i) {
+ static const uint8_t data[128] = {
+ 0, 255, 0, 255, 0, 255, 0, 255, 0, 255, 1, 255, 1, 255, 1, 255, 1, 255, 1, 255, 2, 255, 2, 255, 2, 255, 2, 255, 2, 255, 12, 255,
+ 3, 255, 3, 255, 3, 255, 3, 255, 3, 255, 4, 255, 4, 255, 4, 255, 4, 255, 4, 255, 5, 255, 5, 255, 5, 255, 5, 255, 5, 255, 12, 255,
+ 6, 255, 6, 255, 6, 255, 6, 255, 6, 255, 7, 255, 7, 255, 7, 255, 7, 255, 7, 255, 8, 255, 8, 255, 8, 255, 8, 255, 8, 255, 12, 255,
+ 9, 255, 9, 255, 9, 255, 9, 255, 9, 255, 10, 255, 10, 255, 10, 255, 10, 255, 10, 255, 11, 255, 11, 255, 11, 255, 11, 255, 11, 255, 12, 255,
+ };
+ return _mm256_loadu_si256((const __m256i*)data + i);
+ }
+ const __m256i shuff[4] = { load_shuffle(0), load_shuffle(1), load_shuffle(2), load_shuffle(3) };
+ const __m256i mult[4] = {
+ _mm256_set_epi64x(0x5100010003000900, 0x1b00510001000300, 0x09001b0051000100, 0x030009001b005100),
+ _mm256_set_epi64x(0x1b00010003000900, 0x1b00510001000300, 0x09001b0051000100, 0x030009001b005100),
+ _mm256_set_epi64x(0x0900010003000900, 0x1b00510001000300, 0x09001b0051000100, 0x030009001b005100),
+ _mm256_set_epi64x(0x0300010003000900, 0x1b00510001000300, 0x09001b0051000100, 0x030009001b005100),
+ };
+ const __m256i m3 = _mm256_set1_epi16(3);
+#if defined HAVE_FANCY_SIMD && defined __AVX512VBMI__
+ const __m256i bmask = _mm256_set_epi8(62, 60, 58, 56, 54, 52, 50, 48, 46, 44, 42, 40, 38, 36, 34, 32, 30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0);
+#endif
+
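+ // A block_iq1_bn packs 64 ternary values into 13 bytes (5 per byte; byte 12 is shared between the
+ // four shuffle rows). The shuffle/mullo/mulhi sequence below decodes them into bytes in 0..2.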
+ IQK_ALWAYS_INLINE void prepare_iq1bn_quants(const block_iq1_bn * x, __m256i& v1, __m256i& v2) const {
+ auto data128 = _mm_loadu_si128((const __m128i *)x); // Note: we load 16 instead of 13 bytes!
+ auto data = MM256_SET_M128I(data128, data128);
+ auto val1 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_shuffle_epi8(data, shuff[0]), mult[0]), m3);
+ auto val2 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_shuffle_epi8(data, shuff[1]), mult[1]), m3);
+ auto val3 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_shuffle_epi8(data, shuff[2]), mult[2]), m3);
+ auto val4 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_shuffle_epi8(data, shuff[3]), mult[3]), m3);
+#if defined HAVE_FANCY_SIMD && defined __AVX512VBMI__
+ v1 = _mm256_permutex2var_epi8(val1, bmask, val2);
+ v2 = _mm256_permutex2var_epi8(val3, bmask, val4);
+#else
+ v1 = _mm256_permute4x64_epi64(_mm256_packs_epi16(val1, val2), 216);
+ v2 = _mm256_permute4x64_epi64(_mm256_packs_epi16(val3, val4), 216);
+#endif
+ }
+
+};
+
+template <int nrc_y>
+IQK_NOINLINE void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ const int nb = n / QK_IQ1BN;
+ Q8_K64<nrc_y> q8(info);
+ DequantizerIQ1BN deq;
+ __m256i accd[nrc_y];
+ __m256i val[4];
+
+#ifndef HAVE_FANCY_SIMD
+ const auto m1_16 = _mm256_set1_epi16(1);
+#endif
+
+ const block_iq1_bn * x;
+ const char * cx0 = (const char *)vx;
+ float scale;
+ ggml_half d16;
+
+ for (int ix = 0; ix < nrc_x; ++ix) {
+
+ const char * cx = cx0 + ix*bx;
+ std::memcpy(&d16, cx, sizeof(d16));
+ scale = GGML_FP16_TO_FP32(d16);
+ cx += sizeof(d16);
+ x = (const block_iq1_bn *)cx;
+
+ if constexpr (nrc_y == 1) {
+ __m256i acc1 = _mm256_setzero_si256(), acc2 = _mm256_setzero_si256();
+ for (int i = 0; i < nb/2; ++i) {
+ deq.prepare_iq1bn_quants(x + 2*i + 0, val[0], val[1]);
+ deq.prepare_iq1bn_quants(x + 2*i + 1, val[2], val[3]);
+#ifdef HAVE_FANCY_SIMD
+ acc1 = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(acc1, val[0], q8.load_quants(0, i, 0)), val[1], q8.load_quants(0, i, 1));
+ acc2 = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(acc2, val[2], q8.load_quants(0, i, 2)), val[3], q8.load_quants(0, i, 3));
+#else
+ auto dot1 = _mm256_add_epi16(_mm256_maddubs_epi16(val[0], q8.load_quants(0, i, 0)),
+ _mm256_maddubs_epi16(val[1], q8.load_quants(0, i, 1)));
+ auto dot2 = _mm256_add_epi16(_mm256_maddubs_epi16(val[2], q8.load_quants(0, i, 2)),
+ _mm256_maddubs_epi16(val[3], q8.load_quants(0, i, 3)));
+ acc1 = _mm256_add_epi32(acc1, _mm256_madd_epi16(m1_16, dot1));
+ acc2 = _mm256_add_epi32(acc2, _mm256_madd_epi16(m1_16, dot2));
+#endif
+ }
+ accd[0] = _mm256_add_epi32(acc1, acc2);
+ }
+ else {
+
+ for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm256_setzero_si256();
+
+ for (int i = 0; i < nb/2; ++i) {
+
+ deq.prepare_iq1bn_quants(x + 2*i + 0, val[0], val[1]);
+ deq.prepare_iq1bn_quants(x + 2*i + 1, val[2], val[3]);
+
+ for (int iy = 0; iy < nrc_y; ++iy) {
+#ifdef HAVE_FANCY_SIMD
+ accd[iy] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_dpbusd_epi32(accd[iy],
+ val[0], q8.load_quants(iy, i, 0)),
+ val[1], q8.load_quants(iy, i, 1)),
+ val[2], q8.load_quants(iy, i, 2)),
+ val[3], q8.load_quants(iy, i, 3));
+#else
+ auto dot1 = _mm256_add_epi16(_mm256_maddubs_epi16(val[0], q8.load_quants(iy, i, 0)),
+ _mm256_maddubs_epi16(val[1], q8.load_quants(iy, i, 1)));
+ auto dot2 = _mm256_add_epi16(_mm256_maddubs_epi16(val[2], q8.load_quants(iy, i, 2)),
+ _mm256_maddubs_epi16(val[3], q8.load_quants(iy, i, 3)));
+ dot1 = _mm256_madd_epi16(m1_16, _mm256_add_epi16(dot1, dot2));
+ accd[iy] = _mm256_add_epi32(dot1, accd[iy]);
+#endif
+ }
+ }
+ }
+ int i = 2*(nb/2);
+ if (i < nb) {
+ deq.prepare_iq1bn_quants(x + i, val[0], val[1]);
+ for (int iy = 0; iy < nrc_y; ++iy) {
+#ifdef HAVE_FANCY_SIMD
+ accd[iy] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(accd[iy],
+ val[0], q8.load_quants(iy, i/2, 0)), val[1], q8.load_quants(iy, i/2, 1));
+#else
+ auto dot = _mm256_madd_epi16(m1_16, _mm256_add_epi16(_mm256_maddubs_epi16(val[0], q8.load_quants(iy, i/2, 0)),
+ _mm256_maddubs_epi16(val[1], q8.load_quants(iy, i/2, 1))));
+ accd[iy] = _mm256_add_epi32(dot, accd[iy]);
+#endif
+ }
+ }
+
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto vd = q8.scale(iy);
+ auto sumi = _mm_add_epi32(_mm256_castsi256_si128(accd[iy]), _mm256_extractf128_si256(accd[iy], 1));
+ auto sumf = _mm_fmsub_ps(vd, _mm_cvtepi32_ps(sumi), q8.minus(iy));
+ info.store(ix, iy, scale*hsum_float_4(sumf));
+ }
+
+ }
+}
+
+struct DequantizeIQ2BN final : public BaseDequantizer<block_iq2_bn, true> {
+ DequantizeIQ2BN(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}
+
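+ // A block_iq2_bn stores 64 2-bit values in 16 bytes. prepare4() unpacks two consecutive blocks
+ // from one 32-byte load into four vectors of values 0..3; prepare2() handles a single block.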
+ IQK_ALWAYS_INLINE void prepare4(int i, __m256i * val) const {
+ auto q2bits_1 = _mm256_loadu_si256((const __m256i *)x[2*i].qs);
+ auto q2bits_2 = _mm256_srli_epi16(q2bits_1, 2);
+ make2(_mm256_permute2x128_si256(q2bits_1, q2bits_2, 0x20), val+0);
+ make2(_mm256_permute2x128_si256(q2bits_1, q2bits_2, 0x31), val+2);
+ }
+ IQK_ALWAYS_INLINE void make2(__m256i q2_1, __m256i * val) const {
+ val[0] = _mm256_and_si256(q2_1, mask2);
+ val[1] = _mm256_and_si256(_mm256_srli_epi16(q2_1, 4), mask2);
+ }
+ IQK_ALWAYS_INLINE void prepare2(int i, __m256i * val) const {
+ auto q2bits_1 = _mm_loadu_si128((const __m128i *)x[i].qs);
+ make2(MM256_SET_M128I(_mm_srli_epi16(q2bits_1, 2), q2bits_1), val);
+ }
+ const __m256i m1_8 = _mm256_set1_epi8(1);
+ const __m256i mf_8 = _mm256_set1_epi8(16);
+ const __m256i mask2 = _mm256_set1_epi8(0x03);
+ const __m256i mask3 = _mm256_set1_epi8(0x30);
+};
+
+template <int nrc_y>
+IQK_NOINLINE void mul_mat_iq2bn_q8_K64(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ const int nb = n / QK_IQ1BN;
+ Q8_K64<nrc_y> q8(info);
+ DequantizeIQ2BN deq(vx, bx);
+ __m256i accd[nrc_y];
+ __m256i val[4];
+
+#ifndef HAVE_FANCY_SIMD
+ const auto m1_16 = _mm256_set1_epi16(1);
+#endif
+
+ for (int ix = 0; ix < nrc_x; ++ix) {
+
+ deq.new_row(ix);
+
+ if constexpr (nrc_y == 1) {
+ __m256i acc[2] = {};
+ for (int i = 0; i < nb/2; ++i) {
+ deq.prepare4(i, val);
+#ifdef HAVE_FANCY_SIMD
+ acc[0] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(acc[0], val[0], q8.load_quants(0, i, 0)),
+ val[1], q8.load_quants(0, i, 1));
+ acc[1] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(acc[1], val[2], q8.load_quants(0, i, 2)),
+ val[3], q8.load_quants(0, i, 3));
+#else
+ auto dot1 = _mm256_add_epi16(_mm256_maddubs_epi16(val[0], q8.load_quants(0, i, 0)),
+ _mm256_maddubs_epi16(val[1], q8.load_quants(0, i, 1)));
+ auto dot2 = _mm256_add_epi16(_mm256_maddubs_epi16(val[2], q8.load_quants(0, i, 2)),
+ _mm256_maddubs_epi16(val[3], q8.load_quants(0, i, 3)));
+ acc[0] = _mm256_add_epi32(acc[0], _mm256_madd_epi16(m1_16, dot1));
+ acc[1] = _mm256_add_epi32(acc[1], _mm256_madd_epi16(m1_16, dot2));
+#endif
+ }
+ accd[0] = _mm256_add_epi32(acc[0], acc[1]);
+ }
+ else {
+
+ for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm256_setzero_si256();
+
+ for (int i = 0; i < nb/2; ++i) {
+ deq.prepare4(i, val);
+ for (int iy = 0; iy < nrc_y; ++iy) {
+#ifdef HAVE_FANCY_SIMD
+ accd[iy] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_dpbusd_epi32(accd[iy],
+ val[0], q8.load_quants(iy, i, 0)), val[1], q8.load_quants(iy, i, 1)),
+ val[2], q8.load_quants(iy, i, 2)), val[3], q8.load_quants(iy, i, 3));
+#else
+ auto dot = _mm256_madd_epi16(m1_16, _mm256_add_epi16(
+ _mm256_add_epi16(_mm256_maddubs_epi16(val[0], q8.load_quants(iy, i, 0)),
+ _mm256_maddubs_epi16(val[1], q8.load_quants(iy, i, 1))),
+ _mm256_add_epi16(_mm256_maddubs_epi16(val[2], q8.load_quants(iy, i, 2)),
+ _mm256_maddubs_epi16(val[3], q8.load_quants(iy, i, 3)))));
+ accd[iy] = _mm256_add_epi32(dot, accd[iy]);
+#endif
+ }
+ }
+ }
+ int i = 2*(nb/2);
+ if (i < nb) {
+ deq.prepare2(i, val);
+ for (int iy = 0; iy < nrc_y; ++iy) {
+#ifdef HAVE_FANCY_SIMD
+ accd[iy] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(accd[iy], val[0], q8.load_quants(iy, i/2, 0)),
+ val[1], q8.load_quants(iy, i/2, 1));
+#else
+ auto dot = _mm256_madd_epi16(m1_16, _mm256_add_epi16(_mm256_maddubs_epi16(val[0], q8.load_quants(iy, i/2, 0)),
+ _mm256_maddubs_epi16(val[1], q8.load_quants(iy, i/2, 1))));
+ accd[iy] = _mm256_add_epi32(dot, accd[iy]);
+#endif
+ }
+ }
+
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto vd = q8.scale(iy);
+ auto sumi = _mm_add_epi32(_mm256_castsi256_si128(accd[iy]), _mm256_extractf128_si256(accd[iy], 1));
+ auto sumf = _mm_fmsub_ps(vd, _mm_cvtepi32_ps(sumi), q8.minus(iy));
+ info.store(ix, iy, deq.d*hsum_float_4(sumf));
+ }
+ }
+}
+
+template <int nrc_y>
+static void mul_mat_iq2_bn_r4_q8_k16_avx2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ if (nrc_x%4) {
+ printf("%s: %d is not a multiple of 4\n", __func__, nrc_x);
+ GGML_ABORT("fatal error");
+ }
+ Q8_16<nrc_y> q8(info);
+ auto m3 = _mm256_set1_epi8(0x3);
+ auto m1 = _mm256_set1_epi16(1);
+ int nb = n / QK_IQ1BN;
+ __m256i qx[4];
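+ // For nrc_y > 4 the low and high 32-byte halves of each row group are processed in two separate
+ // passes so that only nrc_y integer accumulators are live at a time (presumably to limit register
+ // pressure); the nrc_y <= 4 path below keeps two accumulators per right-hand-side row instead.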
+ if constexpr (nrc_y > 4) {
+ __m256i acc[nrc_y] = {};
+ __m128 sum4[nrc_y];
+ for (int ix = 0; ix < nrc_x; ix += 4) {
+ const float * dptr = (const float *)((const char *)vx + ix*bx);
+ auto dl = _mm_loadu_ps(dptr);
+ const uint8_t * iq2l = (const uint8_t *)(dptr + 4);
+ for (int ib = 0; ib < nb; ++ib) {
+ auto bits = _mm256_loadu_si256((const __m256i *)iq2l + 2*ib+0);
+ qx[0] = _mm256_and_si256(bits, m3);
+ qx[1] = _mm256_and_si256(_mm256_srli_epi16(bits, 2), m3);
+ qx[2] = _mm256_and_si256(_mm256_srli_epi16(bits, 4), m3);
+ qx[3] = _mm256_and_si256(_mm256_srli_epi16(bits, 6), m3);
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto y = q8.load_quants(iy, 2*ib+0);
+ auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[0], _mm256_shuffle_epi32(y, 0x00)),
+ _mm256_maddubs_epi16(qx[1], _mm256_shuffle_epi32(y, 0x55)));
+ auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[2], _mm256_shuffle_epi32(y, 0xaa)),
+ _mm256_maddubs_epi16(qx[3], _mm256_shuffle_epi32(y, 0xff)));
+ acc[iy] = _mm256_add_epi32(acc[iy], _mm256_madd_epi16(m1, _mm256_add_epi16(sumi1, sumi2)));
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto dy = q8.scale(iy);
+ auto sumf1 = _mm256_cvtepi32_ps(acc[iy]);
+ auto s4 = _mm_mul_ps(_mm256_extractf128_ps(sumf1, 0), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0x00)));
+ s4 = _mm_fmadd_ps(_mm256_extractf128_ps(sumf1, 1), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0x55)), s4);
+ sum4[iy] = _mm_fmadd_ps(dl, _mm_set1_ps(-q8.sum_row(iy)), s4);
+ acc[iy] = _mm256_setzero_si256();
+ }
+ for (int ib = 0; ib < nb; ++ib) {
+ auto bits = _mm256_loadu_si256((const __m256i *)iq2l + 2*ib+1);
+ qx[0] = _mm256_and_si256(bits, m3);
+ qx[1] = _mm256_and_si256(_mm256_srli_epi16(bits, 2), m3);
+ qx[2] = _mm256_and_si256(_mm256_srli_epi16(bits, 4), m3);
+ qx[3] = _mm256_and_si256(_mm256_srli_epi16(bits, 6), m3);
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto y = q8.load_quants(iy, 2*ib+1);
+ auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[0], _mm256_shuffle_epi32(y, 0x00)),
+ _mm256_maddubs_epi16(qx[1], _mm256_shuffle_epi32(y, 0x55)));
+ auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[2], _mm256_shuffle_epi32(y, 0xaa)),
+ _mm256_maddubs_epi16(qx[3], _mm256_shuffle_epi32(y, 0xff)));
+ acc[iy] = _mm256_add_epi32(acc[iy], _mm256_madd_epi16(m1, _mm256_add_epi16(sumi1, sumi2)));
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto dy = q8.scale(iy);
+ auto sumf1 = _mm256_cvtepi32_ps(acc[iy]);
+ auto s4 = _mm_fmadd_ps(_mm256_extractf128_ps(sumf1, 0), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0xaa)), sum4[iy]);
+ s4 = _mm_fmadd_ps(_mm256_extractf128_ps(sumf1, 1), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0xff)), s4);
+ info.store(ix, iy, s4);
+ acc[iy] = _mm256_setzero_si256();
+ }
+ }
+ } else {
+ __m256i acc[2*nrc_y] = {};
+ for (int ix = 0; ix < nrc_x; ix += 4) {
+ const float * dptr = (const float *)((const char *)vx + ix*bx);
+ auto dl = _mm_loadu_ps(dptr);
+ const uint8_t * iq2l = (const uint8_t *)(dptr + 4);
+ for (int ib = 0; ib < nb; ++ib) {
+ auto bits = _mm256_loadu_si256((const __m256i *)iq2l + 2*ib+0);
+ qx[0] = _mm256_and_si256(bits, m3);
+ qx[1] = _mm256_and_si256(_mm256_srli_epi16(bits, 2), m3);
+ qx[2] = _mm256_and_si256(_mm256_srli_epi16(bits, 4), m3);
+ qx[3] = _mm256_and_si256(_mm256_srli_epi16(bits, 6), m3);
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto y = q8.load_quants(iy, 2*ib+0);
+ auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[0], _mm256_shuffle_epi32(y, 0x00)),
+ _mm256_maddubs_epi16(qx[1], _mm256_shuffle_epi32(y, 0x55)));
+ auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[2], _mm256_shuffle_epi32(y, 0xaa)),
+ _mm256_maddubs_epi16(qx[3], _mm256_shuffle_epi32(y, 0xff)));
+ acc[2*iy+0] = _mm256_add_epi32(acc[2*iy+0], _mm256_madd_epi16(m1, _mm256_add_epi16(sumi1, sumi2)));
+ }
+ bits = _mm256_loadu_si256((const __m256i *)iq2l + 2*ib+1);
+ qx[0] = _mm256_and_si256(bits, m3);
+ qx[1] = _mm256_and_si256(_mm256_srli_epi16(bits, 2), m3);
+ qx[2] = _mm256_and_si256(_mm256_srli_epi16(bits, 4), m3);
+ qx[3] = _mm256_and_si256(_mm256_srli_epi16(bits, 6), m3);
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto y = q8.load_quants(iy, 2*ib+1);
+ auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[0], _mm256_shuffle_epi32(y, 0x00)),
+ _mm256_maddubs_epi16(qx[1], _mm256_shuffle_epi32(y, 0x55)));
+ auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[2], _mm256_shuffle_epi32(y, 0xaa)),
+ _mm256_maddubs_epi16(qx[3], _mm256_shuffle_epi32(y, 0xff)));
+ acc[2*iy+1] = _mm256_add_epi32(acc[2*iy+1], _mm256_madd_epi16(m1, _mm256_add_epi16(sumi1, sumi2)));
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto dy = q8.scale(iy);
+ auto sumf1 = _mm256_cvtepi32_ps(acc[2*iy+0]);
+ auto sumf2 = _mm256_cvtepi32_ps(acc[2*iy+1]);
+ auto sum4 = _mm_mul_ps(_mm256_extractf128_ps(sumf1, 0), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0x00)));
+ sum4 = _mm_fmadd_ps(_mm256_extractf128_ps(sumf1, 1), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0x55)), sum4);
+ sum4 = _mm_fmadd_ps(_mm256_extractf128_ps(sumf2, 0), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0xaa)), sum4);
+ sum4 = _mm_fmadd_ps(_mm256_extractf128_ps(sumf2, 1), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0xff)), sum4);
+ sum4 = _mm_fmadd_ps(dl, _mm_set1_ps(-q8.sum_row(iy)), sum4);
+ info.store(ix, iy, sum4);
+ acc[2*iy+0] = acc[2*iy+1] = _mm256_setzero_si256();
+ }
+ }
+ }
+}
+
+
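+// iq2_bn_r4 x q8_K16 matrix multiplication. With AVX512-VNNI (HAVE_FANCY_SIMD) we process two
+// interleaved groups of 4 rows (8 rows) per iteration via _mm512_dpbusd_epi32, with a single
+// 4-row group as tail, and fall back to the AVX2 kernel for nrc_y == 1. Without AVX512-VNNI we
+// simply forward to the AVX2 implementation above.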
+#ifdef HAVE_FANCY_SIMD
+template <int nrc_y>
+static void mul_mat_iq2_bn_r4_q8_k16(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ if (nrc_x%4) {
+ printf("%s: %d is not a multiple of 4\n", __func__, nrc_x);
+ GGML_ABORT("fatal error");
+ }
+ if constexpr (nrc_y == 1) {
+ mul_mat_iq2_bn_r4_q8_k16_avx2<1>(n, vx, bx, info, nrc_x);
+ } else {
+ Q8_16<nrc_y> q8(info);
+ auto m3 = _mm512_set1_epi8(0x3);
+ int nb = n / QK_IQ1BN;
+ __m512i acc[2*nrc_y] = {};
+ __m512i qx[8];
+ for (int ix = 0; ix < nrc_x/8; ++ix) {
+ const float * dptr1 = (const float *)((const char *)vx + (8*ix+0)*bx);
+ const float * dptr2 = (const float *)((const char *)vx + (8*ix+4)*bx);
+ auto dl = _mm_loadu_ps(dptr1);
+ auto dh = _mm_loadu_ps(dptr2);
+ const uint8_t * iq2l = (const uint8_t *)(dptr1 + 4);
+ const uint8_t * iq2h = (const uint8_t *)(dptr2 + 4);
+ for (int ib = 0; ib < nb; ++ib) {
+ auto bits_l = _mm512_loadu_si512((const __m512i *)iq2l + ib);
+ auto bits_h = _mm512_loadu_si512((const __m512i *)iq2h + ib);
+ qx[0] = _mm512_and_si512(bits_l, m3);
+ qx[1] = _mm512_and_si512(bits_h, m3);
+ qx[2] = _mm512_and_si512(_mm512_srli_epi16(bits_l, 2), m3);
+ qx[3] = _mm512_and_si512(_mm512_srli_epi16(bits_h, 2), m3);
+ qx[4] = _mm512_and_si512(_mm512_srli_epi16(bits_l, 4), m3);
+ qx[5] = _mm512_and_si512(_mm512_srli_epi16(bits_h, 4), m3);
+ qx[6] = _mm512_and_si512(_mm512_srli_epi16(bits_l, 6), m3);
+ qx[7] = _mm512_and_si512(_mm512_srli_epi16(bits_h, 6), m3);
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto y = q8.load_quants64(iy, ib);
+ auto sy = _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00));
+ acc[2*iy+0] = _mm512_dpbusd_epi32(acc[2*iy+0], qx[0], sy);
+ acc[2*iy+1] = _mm512_dpbusd_epi32(acc[2*iy+1], qx[1], sy);
+ sy = _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55));
+ acc[2*iy+0] = _mm512_dpbusd_epi32(acc[2*iy+0], qx[2], sy);
+ acc[2*iy+1] = _mm512_dpbusd_epi32(acc[2*iy+1], qx[3], sy);
+ sy = _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa));
+ acc[2*iy+0] = _mm512_dpbusd_epi32(acc[2*iy+0], qx[4], sy);
+ acc[2*iy+1] = _mm512_dpbusd_epi32(acc[2*iy+1], qx[5], sy);
+ sy = _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff));
+ acc[2*iy+0] = _mm512_dpbusd_epi32(acc[2*iy+0], qx[6], sy);
+ acc[2*iy+1] = _mm512_dpbusd_epi32(acc[2*iy+1], qx[7], sy);
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto dy = q8.scale(iy);
+ __m128 sum4;
+ for (int k = 0; k < 2; ++k) {
+ const auto& dx = k == 0 ? dl : dh;
+ auto sumf = _mm512_cvtepi32_ps(acc[2*iy+k]);
+ sum4 = _mm_mul_ps (_mm512_extractf32x4_ps(sumf, 0), _mm_mul_ps(dx, _mm_shuffle_ps(dy, dy, 0x00)));
+ sum4 = _mm_fmadd_ps(_mm512_extractf32x4_ps(sumf, 1), _mm_mul_ps(dx, _mm_shuffle_ps(dy, dy, 0x55)), sum4);
+ sum4 = _mm_fmadd_ps(_mm512_extractf32x4_ps(sumf, 2), _mm_mul_ps(dx, _mm_shuffle_ps(dy, dy, 0xaa)), sum4);
+ sum4 = _mm_fmadd_ps(_mm512_extractf32x4_ps(sumf, 3), _mm_mul_ps(dx, _mm_shuffle_ps(dy, dy, 0xff)), sum4);
+ sum4 = _mm_fmadd_ps(dx, _mm_set1_ps(-q8.sum_row(iy)), sum4);
+ info.store(8*ix + 4*k, iy, sum4);
+ }
+ acc[2*iy+0] = acc[2*iy+1] = _mm512_setzero_si512();
+ }
+ }
+ if (int ix = 8*(nrc_x/8); ix < nrc_x) {
+ const float * dptr = (const float *)((const char *)vx + ix*bx);
+ auto dl = _mm_loadu_ps(dptr);
+ const uint8_t * iq2l = (const uint8_t *)(dptr + 4);
+ for (int ib = 0; ib < nb; ++ib) {
+ auto bits_l = _mm512_loadu_si512((const __m512i *)iq2l + ib);
+ qx[0] = _mm512_and_si512(bits_l, m3);
+ qx[1] = _mm512_and_si512(_mm512_srli_epi16(bits_l, 2), m3);
+ qx[2] = _mm512_and_si512(_mm512_srli_epi16(bits_l, 4), m3);
+ qx[3] = _mm512_and_si512(_mm512_srli_epi16(bits_l, 6), m3);
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto y = q8.load_quants64(iy, ib);
+ acc[iy] = _mm512_dpbusd_epi32(acc[iy], qx[0], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00)));
+ acc[iy] = _mm512_dpbusd_epi32(acc[iy], qx[1], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55)));
+ acc[iy] = _mm512_dpbusd_epi32(acc[iy], qx[2], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa)));
+ acc[iy] = _mm512_dpbusd_epi32(acc[iy], qx[3], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff)));
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto dy = q8.scale(iy);
+ auto sumf = _mm512_cvtepi32_ps(acc[iy]);
+ auto sum4 = _mm_mul_ps(_mm512_extractf32x4_ps(sumf, 0), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0x00)));
+ sum4 = _mm_fmadd_ps(_mm512_extractf32x4_ps(sumf, 1), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0x55)), sum4);
+ sum4 = _mm_fmadd_ps(_mm512_extractf32x4_ps(sumf, 2), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0xaa)), sum4);
+ sum4 = _mm_fmadd_ps(_mm512_extractf32x4_ps(sumf, 3), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0xff)), sum4);
+ sum4 = _mm_fmadd_ps(dl, _mm_set1_ps(-q8.sum_row(iy)), sum4);
+ info.store(ix, iy, sum4);
+ }
+ }
+ }
+}
+#else
+template <int nrc_y>
+static void mul_mat_iq2_bn_r4_q8_k16(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ if (nrc_x%4) {
+ printf("%s: %d is not a multiple of 4\n", __func__, nrc_x);
+ GGML_ABORT("fatal error");
+ }
+ mul_mat_iq2_bn_r4_q8_k16_avx2<nrc_y>(n, vx, bx, info, nrc_x);
+}
+#endif
+
+
+} // namespace
+
+bool iqk_set_kernels_1bit(int ne00, int typeA, int typeB, std::array<mul_mat_t, IQK_MAX_NY>& funcs, mul_mat_t& func16) {
+
+ auto expected_typeB = GGML_TYPE_Q8_K128;
+
+ func16 = nullptr;
+
+ switch (typeA) {
+ case GGML_TYPE_IQ1_S:
+ if (ne00%QK_K != 0) return false;
+ IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq1_s_q8_K, funcs);
+#ifdef HAVE_FANCY_SIMD
+ func16 = mul_mat_iq1_s_q8_K<16>;
+#endif
+ expected_typeB = GGML_TYPE_Q8_K;
+ break;
+ case GGML_TYPE_IQ1_S_R4:
+ if (ne00%128 != 0) return false;
+ IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq1_s_r4_q8_1, funcs);
+#ifdef HAVE_FANCY_SIMD
+ func16 = mul_mat_iq1_s_r4_q8_1<16>;
+#endif
+ break;
+ case GGML_TYPE_IQ1_M_R4:
+ if (ne00%128 != 0) return false;
+ IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq1_m_r4_q8_0, funcs);
+#ifdef HAVE_FANCY_SIMD
+ func16 = mul_mat_iq1_m_r4_q8_0<16>;
+#endif
+ break;
+ case GGML_TYPE_IQ1_BN:
+ if (ne00 % QK_IQ1BN != 0) return false;
+ IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq1bn_q8_K64, funcs);
+ expected_typeB = GGML_TYPE_Q8_K64;
+ break;
+ case GGML_TYPE_IQ2_BN:
+ if (ne00 % QK_IQ1BN != 0) return false;
+ IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq2bn_q8_K64, funcs);
+ expected_typeB = GGML_TYPE_Q8_K64;
+ break;
+ case GGML_TYPE_IQ2_BN_R4:
+ if (ne00 % QK_IQ1BN != 0) return false;
+ IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq2_bn_r4_q8_k16, funcs);
+ expected_typeB = GGML_TYPE_Q8_K16;
+ break;
+
+ default:
+ return false;
+ }
+
+ return ggml_type(typeB) == expected_typeB;
+
+}
+
+#else
+// -------------------------------- __aarch64__
+
+namespace {
+
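+// Per-row view of the Q8_K64 activations used by the iq1_bn/iq2_bn kernels below: each row
+// begins with 8 floats - the 4 block scales returned by scale() and the 4 correction sums
+// returned by minus() - followed by the int8 quants.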
+template <int nrc> struct Q8_K64 {
+
+ constexpr static int nrc_y = nrc;
+
+ Q8_K64(const DataInfo& info) {
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto dptr = (const float *)info.src1_row(iy);
+ std::memcpy(d + 8*iy, dptr, 8*sizeof(float));
+ y[iy] = (const int8_t *)(dptr + 8);
+ }
+ }
+
+ inline int8x16x4_t load_quants64(int iy, int i, int j) const { return vld1q_s8_x4(y[iy] + 128*i + 64*j); }
+ inline int8x16x2_t load_quants(int iy, int i, int j) const { return vld1q_s8_x2(y[iy] + 128*i + 32*j); }
+ inline float32x4_t scale(int iy) const { return vld1q_f32(d + 8*iy); }
+ inline float32x4_t minus(int iy) const { return vld1q_f32(d + 8*iy + 4); }
+
+ float d[8*nrc_y];
+ const int8_t * y[nrc_y];
+};
+
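+// Unpacks a block_iq1_bn into 64 ternary values. Each of the first 12 data bytes packs 5
+// ternary digits and the 13th packs the remaining 4; the table lookup replicates each byte,
+// and the multiply by {81, 27, 9, 3, 1} plus the halving-add/shift sequence isolates one digit
+// per lane. prepare_iq1bn_quants() yields values in {-1, 0, 1}, the _nosub variant in {0, 1, 2}.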
+struct DequantizerIQ1BN {
+ const uint8x16_t m1 = vdupq_n_u8(1);
+
+ static inline uint8x16x4_t load_shuffles() {
+ static const uint8_t data[64] = {0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 12,
+ 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 12,
+ 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 8, 8, 8, 8, 8, 12,
+ 9, 9, 9, 9, 9, 10, 10, 10, 10, 10, 11, 11, 11, 11, 11, 12};
+ return vld1q_u8_x4(data);
+ }
+ static inline uint8x16x4_t load_mult() {
+ static const uint8_t data[64] = {81, 27, 9, 3, 1, 81, 27, 9, 3, 1, 81, 27, 9, 3, 1, 81,
+ 81, 27, 9, 3, 1, 81, 27, 9, 3, 1, 81, 27, 9, 3, 1, 27,
+ 81, 27, 9, 3, 1, 81, 27, 9, 3, 1, 81, 27, 9, 3, 1, 9,
+ 81, 27, 9, 3, 1, 81, 27, 9, 3, 1, 81, 27, 9, 3, 1, 3};
+ return vld1q_u8_x4(data);
+ }
+ const uint8x16x4_t shuff = load_shuffles();
+ const uint8x16x4_t mult = load_mult();
+
+ IQK_ALWAYS_INLINE void prepare_iq1bn_quants(const block_iq1_bn * x, int8x16x4_t& v) const {
+ auto data = vld1q_u8((const uint8_t *)x);
+ for (int k = 0; k < 4; ++k) {
+ auto val = vmulq_u8(vqtbl1q_u8(data, shuff.val[k]), mult.val[k]);
+ val = vshrq_n_u8(vhaddq_u8(val, vshrq_n_u8(val, 1)), 6);
+ v.val[k] = vsubq_s8(vreinterpretq_s8_u8(val), m1);
+ }
+ }
+
+ IQK_ALWAYS_INLINE void prepare_iq1bn_quants_nosub(const block_iq1_bn * x, int8x16x4_t& v) const {
+ auto data = vld1q_u8((const uint8_t *)x);
+ for (int k = 0; k < 4; ++k) {
+ auto val = vmulq_u8(vqtbl1q_u8(data, shuff.val[k]), mult.val[k]);
+ v.val[k] = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(val, vshrq_n_u8(val, 1)), 6));
+ }
+ }
+};
+
+template <int nrc_y>
+static void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ const int nb = n / QK_IQ1BN;
+
+ Q8_K64<nrc_y> q8(info);
+ DequantizerIQ1BN deq;
+
+ int32x4_t accd[nrc_y];
+ int8x16x4_t v1, v2;
+
+ float scale;
+ ggml_half d16;
+ char * c16 = (char *)&d16;
+
+ for (int ix = 0; ix < nrc_x; ++ix) {
+
+ const char * cx = ((const char *)vx + ix*bx);
+ c16[0] = cx[0]; c16[1] = cx[1];
+ //std::memcpy(&d16, cx, sizeof(d16));
+ cx += sizeof(d16);
+ scale = GGML_FP16_TO_FP32(d16);
+
+ const block_iq1_bn * x = (const block_iq1_bn *)cx;
+
+ if constexpr (nrc_y == 1) {
+ int32x4_t acc[4] = {};
+ for (int i = 0; i < nb/2; ++i) {
+ deq.prepare_iq1bn_quants_nosub(x+2*i+0, v1);
+ auto q = q8.load_quants64(0, i, 0);
+ for (int j = 0; j < 4; ++j) acc[j] = ggml_vdotq_s32(acc[j], q.val[j], v1.val[j]);
+ deq.prepare_iq1bn_quants_nosub(x+2*i+1, v2);
+ q = q8.load_quants64(0, i, 1);
+ for (int j = 0; j < 4; ++j) acc[j] = ggml_vdotq_s32(acc[j], q.val[j], v2.val[j]);
+ }
+ accd[0] = vaddq_s32(vaddq_s32(acc[0], acc[1]), vaddq_s32(acc[2], acc[3]));
+ }
+ else {
+
+ for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = vdupq_n_s32(0);
+
+ for (int i = 0; i < nb/2; ++i) {
+
+ deq.prepare_iq1bn_quants_nosub(x+2*i+0, v1);
+ deq.prepare_iq1bn_quants_nosub(x+2*i+1, v2);
+
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto q = q8.load_quants(iy, i, 0);
+ accd[iy] = ggml_vdotq_s32(ggml_vdotq_s32(accd[iy], q.val[0], v1.val[0]), q.val[1], v1.val[1]);
+ q = q8.load_quants(iy, i, 1);
+ accd[iy] = ggml_vdotq_s32(ggml_vdotq_s32(accd[iy], q.val[0], v1.val[2]), q.val[1], v1.val[3]);
+ q = q8.load_quants(iy, i, 2);
+ accd[iy] = ggml_vdotq_s32(ggml_vdotq_s32(accd[iy], q.val[0], v2.val[0]), q.val[1], v2.val[1]);
+ q = q8.load_quants(iy, i, 3);
+ accd[iy] = ggml_vdotq_s32(ggml_vdotq_s32(accd[iy], q.val[0], v2.val[2]), q.val[1], v2.val[3]);
+ }
+ }
+ }
+ int i = 2*(nb/2);
+ if (i < nb) {
+ deq.prepare_iq1bn_quants_nosub(x+i, v1);
+ if constexpr (nrc_y == 1) {
+ auto q = q8.load_quants(0, i/2, 0);
+ for (int j = 0; j < 4; ++j) {
+ accd[0] = ggml_vdotq_s32(accd[0], q.val[j], v1.val[j]);
+ }
+ } else {
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto q = q8.load_quants(iy, i/2, 0);
+ accd[iy] = ggml_vdotq_s32(ggml_vdotq_s32(accd[iy], q.val[0], v1.val[0]), q.val[1], v1.val[1]);
+ q = q8.load_quants(iy, i/2, 1);
+ accd[iy] = ggml_vdotq_s32(ggml_vdotq_s32(accd[iy], q.val[0], v1.val[2]), q.val[1], v1.val[3]);
+ }
+ }
+ }
+
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ info.store(ix, iy, -scale * vaddvq_f32(vfmsq_f32(q8.minus(iy), q8.scale(iy), vcvtq_f32_s32(accd[iy]))));
+ }
+
+ }
+}
+
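+// Per-row view of the Q8_K16 activations: each row begins with 5 floats - the 4 block scales
+// and the row sum - followed by the int8 quants.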
+template <int nrc> struct Q8_16 {
+
+ constexpr static int nrc_y = nrc;
+
+ Q8_16(const DataInfo& info) {
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto ptr = (const float *)info.src1_row(iy);
+ std::memcpy(d + 5*iy, ptr, 5*sizeof(float));
+ y[iy] = (const int8_t *)(ptr + 5);
+ }
+ }
+
+ inline int8x16x4_t load_quants(int iy, int i) const { return vld1q_s8_x4(y[iy] + 64*i); }
+ inline int8x16x2_t load_quants_32(int iy, int i) const { return vld1q_s8_x2(y[iy] + 32*i); }
+ inline float scale(int iy, int k) const { return d[5*iy+k]; }
+ inline float sum_row(int iy) const { return d[5*iy + 4]; }
+ inline float32x4_t scale(int iy) const { return vld1q_f32(d + 5*iy); }
+
+ float d[5*nrc_y];
+ const int8_t * y[nrc_y];
+};
+
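+// NEON kernel for 4-row-interleaved iq2_bn (iq2_bn_r4). For nrc_y == 1 the 2-bit groups masked
+// with 0xc are accumulated separately and rescaled by 0.25f at the end; the generic path fully
+// unpacks all eight 2-bit groups of each 32-byte chunk before taking the dot products.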
+template <int nrc_y>
+static IQK_NOINLINE void mul_mat_iq2_bn_r4_q8_k16(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ if (nrc_x%4) {
+ printf("%s: %d is not a multiple of 4\n", __func__, nrc_x);
+ GGML_ABORT("fatal error");
+ }
+ Q8_16<nrc_y> q8(info);
+ auto m3 = vdupq_n_u8(0x3);
+ int nb = n / QK_IQ1BN;
+ if constexpr (nrc_y == 1) {
+ auto mc = vdupq_n_u8(0xc);
+ int32x4_t acc[8];
+ for (int ix = 0; ix < nrc_x; ix += 4) {
+ for (int k = 0; k < 8; ++k) acc[k] = vdupq_n_s32(0);
+ const float * dptr = (const float *)((const char *)vx + ix*bx);
+ auto dl = vld1q_f32(dptr);
+ const uint8_t * iq2 = (const uint8_t *)(dptr + 4);
+ for (int ib = 0; ib < nb; ++ib) {
+ auto y = q8.load_quants(0, ib);
+ for (int j = 0; j < 4; ++j) {
+ auto bits1 = vld1q_u8(iq2 + 64*ib + 16*j);
+ auto bits2 = vshrq_n_u8(bits1, 4);
+ acc[2*j+0] = vdotq_laneq_s32(acc[2*j+0], vandq_u8(bits1, m3), y.val[j], 0);
+ acc[2*j+1] = vdotq_laneq_s32(acc[2*j+1], vandq_u8(bits1, mc), y.val[j], 1);
+ acc[2*j+0] = vdotq_laneq_s32(acc[2*j+0], vandq_u8(bits2, m3), y.val[j], 2);
+ acc[2*j+1] = vdotq_laneq_s32(acc[2*j+1], vandq_u8(bits2, mc), y.val[j], 3);
+ }
+ }
+ auto dy = vmulq_f32(dl, vdupq_n_f32(q8.scale(0, 0)));
+ auto sumf1 = vmulq_f32( vcvtq_f32_s32(acc[0]), dy);
+ auto sumf2 = vmulq_f32( vcvtq_f32_s32(acc[1]), dy);
+ dy = vmulq_f32(dl, vdupq_n_f32(q8.scale(0, 1)));
+ sumf1 = vfmaq_f32(sumf1, vcvtq_f32_s32(acc[2]), dy);
+ sumf2 = vfmaq_f32(sumf2, vcvtq_f32_s32(acc[3]), dy);
+ dy = vmulq_f32(dl, vdupq_n_f32(q8.scale(0, 2)));
+ sumf1 = vfmaq_f32(sumf1, vcvtq_f32_s32(acc[4]), dy);
+ sumf2 = vfmaq_f32(sumf2, vcvtq_f32_s32(acc[5]), dy);
+ dy = vmulq_f32(dl, vdupq_n_f32(q8.scale(0, 3)));
+ sumf1 = vfmaq_f32(sumf1, vcvtq_f32_s32(acc[6]), dy);
+ sumf2 = vfmaq_f32(sumf2, vcvtq_f32_s32(acc[7]), dy);
+ auto sumf = vfmaq_f32(sumf1, vdupq_n_f32(0.25f), sumf2);
+ sumf = vfmaq_f32(sumf, dl, vdupq_n_f32(-q8.sum_row(0)));
+ info.store(ix, 0, sumf);
+ }
+ } else {
+ int32x4_t acc[4*nrc_y] = {};
+ uint8x16_t qx[8];
+ for (int ix = 0; ix < nrc_x; ix += 4) {
+ const float * dptr = (const float *)((const char *)vx + ix*bx);
+ auto dl = vld1q_f32(dptr);
+ const uint8_t * iq2 = (const uint8_t *)(dptr + 4);
+ for (int ib = 0; ib < nb; ++ib) {
+ auto bits = vld1q_u8_x2(iq2 + 64*ib);
+ qx[0] = vandq_u8(bits.val[0], m3);
+ qx[1] = vandq_u8(vshrq_n_u8(bits.val[0], 2), m3);
+ qx[2] = vandq_u8(vshrq_n_u8(bits.val[0], 4), m3);
+ qx[3] = vshrq_n_u8(bits.val[0], 6);
+ qx[4] = vandq_u8(bits.val[1], m3);
+ qx[5] = vandq_u8(vshrq_n_u8(bits.val[1], 2), m3);
+ qx[6] = vandq_u8(vshrq_n_u8(bits.val[1], 4), m3);
+ qx[7] = vshrq_n_u8(bits.val[1], 6);
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto y = q8.load_quants_32(iy, 2*ib+0);
+ acc[4*iy + 0] = vdotq_laneq_s32(acc[4*iy + 0], qx[0], y.val[0], 0);
+ acc[4*iy + 0] = vdotq_laneq_s32(acc[4*iy + 0], qx[1], y.val[0], 1);
+ acc[4*iy + 0] = vdotq_laneq_s32(acc[4*iy + 0], qx[2], y.val[0], 2);
+ acc[4*iy + 0] = vdotq_laneq_s32(acc[4*iy + 0], qx[3], y.val[0], 3);
+ acc[4*iy + 1] = vdotq_laneq_s32(acc[4*iy + 1], qx[4], y.val[1], 0);
+ acc[4*iy + 1] = vdotq_laneq_s32(acc[4*iy + 1], qx[5], y.val[1], 1);
+ acc[4*iy + 1] = vdotq_laneq_s32(acc[4*iy + 1], qx[6], y.val[1], 2);
+ acc[4*iy + 1] = vdotq_laneq_s32(acc[4*iy + 1], qx[7], y.val[1], 3);
+ }
+ bits = vld1q_u8_x2(iq2 + 64*ib + 32);
+ qx[0] = vandq_u8(bits.val[0], m3);
+ qx[1] = vandq_u8(vshrq_n_u8(bits.val[0], 2), m3);
+ qx[2] = vandq_u8(vshrq_n_u8(bits.val[0], 4), m3);
+ qx[3] = vshrq_n_u8(bits.val[0], 6);
+ qx[4] = vandq_u8(bits.val[1], m3);
+ qx[5] = vandq_u8(vshrq_n_u8(bits.val[1], 2), m3);
+ qx[6] = vandq_u8(vshrq_n_u8(bits.val[1], 4), m3);
+ qx[7] = vshrq_n_u8(bits.val[1], 6);
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto y = q8.load_quants_32(iy, 2*ib+1);
+ acc[4*iy + 2] = vdotq_laneq_s32(acc[4*iy + 2], qx[0], y.val[0], 0);
+ acc[4*iy + 2] = vdotq_laneq_s32(acc[4*iy + 2], qx[1], y.val[0], 1);
+ acc[4*iy + 2] = vdotq_laneq_s32(acc[4*iy + 2], qx[2], y.val[0], 2);
+ acc[4*iy + 2] = vdotq_laneq_s32(acc[4*iy + 2], qx[3], y.val[0], 3);
+ acc[4*iy + 3] = vdotq_laneq_s32(acc[4*iy + 3], qx[4], y.val[1], 0);
+ acc[4*iy + 3] = vdotq_laneq_s32(acc[4*iy + 3], qx[5], y.val[1], 1);
+ acc[4*iy + 3] = vdotq_laneq_s32(acc[4*iy + 3], qx[6], y.val[1], 2);
+ acc[4*iy + 3] = vdotq_laneq_s32(acc[4*iy + 3], qx[7], y.val[1], 3);
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto dy = q8.scale(iy);
+ float32x4_t sumf = vmulq_f32(vcvtq_f32_s32(acc[4*iy+0]), vmulq_laneq_f32(dl, dy, 0));
+ sumf = vfmaq_f32(sumf, vcvtq_f32_s32(acc[4*iy+1]), vmulq_laneq_f32(dl, dy, 1));
+ sumf = vfmaq_f32(sumf, vcvtq_f32_s32(acc[4*iy+2]), vmulq_laneq_f32(dl, dy, 2));
+ sumf = vfmaq_f32(sumf, vcvtq_f32_s32(acc[4*iy+3]), vmulq_laneq_f32(dl, dy, 3));
+ sumf = vfmaq_f32(sumf, dl, vdupq_n_f32(-q8.sum_row(iy)));
+ info.store(ix, iy, sumf);
+ acc[4*iy+0] = acc[4*iy+1] = acc[4*iy+2] = acc[4*iy+3] = vdupq_n_s32(0);
+ }
+ }
+ }
+}
+
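+// iq2_bn x Q8_K64: the float row scale is stored in front of the blocks and each byte of qs
+// packs four 2-bit values in {0, 1, 2}; the shift to {-1, 0, 1} is accounted for via the
+// q8.minus() sums in the final fused multiply-subtract.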
+template <int nrc_y>
+static void mul_mat_iq2bn_q8_K64(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ const int nb = n / QK_IQ1BN;
+
+ Q8_K64<nrc_y> q8(info);
+
+ int32x4_t accd[nrc_y];
+
+ const auto mask2 = vdupq_n_s8(3);
+
+ for (int ix = 0; ix < nrc_x; ++ix) {
+
+ const float * dptr = (const float *)((const char *)vx + ix*bx);
+ const float d = *dptr;
+ const block_iq2_bn * x = (const block_iq2_bn *)(dptr + 1);
+
+ if constexpr (nrc_y == 1) {
+ int8x16x4_t v1;
+ int32x4_t acc[4] = {};
+ for (int i = 0; i < nb/2; ++i) {
+ for (int j = 0; j < 2; ++j) {
+ auto q = q8.load_quants64(0, i, j);
+ auto q2bits = vld1q_u8(x[2*i+j].qs);
+ v1.val[0] = vandq_s8(q2bits, mask2);
+ v1.val[1] = vandq_s8(vshrq_n_u8(q2bits, 2), mask2);
+ v1.val[2] = vandq_s8(vshrq_n_u8(q2bits, 4), mask2);
+ v1.val[3] = vshrq_n_u8(q2bits, 6);
+ acc[0] = ggml_vdotq_s32(acc[0], q.val[0], v1.val[0]);
+ acc[1] = ggml_vdotq_s32(acc[1], q.val[1], v1.val[1]);
+ acc[2] = ggml_vdotq_s32(acc[2], q.val[2], v1.val[2]);
+ acc[3] = ggml_vdotq_s32(acc[3], q.val[3], v1.val[3]);
+ }
+ }
+ accd[0] = vaddq_s32(vaddq_s32(acc[0], acc[1]), vaddq_s32(acc[2], acc[3]));
+ } else {
+ int8x16x4_t v1, v2;
+ for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = vdupq_n_s32(0);
+ for (int i = 0; i < nb/2; ++i) {
+ auto q2bits = vld1q_u8(x[2*i+0].qs);
+ v1.val[0] = vandq_s8(q2bits, mask2);
+ v1.val[1] = vandq_s8(vshrq_n_u8(q2bits, 2), mask2);
+ v1.val[2] = vandq_s8(vshrq_n_u8(q2bits, 4), mask2);
+ v1.val[3] = vshrq_n_u8(q2bits, 6);
+ q2bits = vld1q_u8(x[2*i+1].qs);
+ v2.val[0] = vandq_s8(q2bits, mask2);
+ v2.val[1] = vandq_s8(vshrq_n_u8(q2bits, 2), mask2);
+ v2.val[2] = vandq_s8(vshrq_n_u8(q2bits, 4), mask2);
+ v2.val[3] = vshrq_n_u8(q2bits, 6);
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto q = q8.load_quants(iy, i, 0);
+ accd[iy] = ggml_vdotq_s32(ggml_vdotq_s32(accd[iy], q.val[0], v1.val[0]), q.val[1], v1.val[1]);
+ q = q8.load_quants(iy, i, 1);
+ accd[iy] = ggml_vdotq_s32(ggml_vdotq_s32(accd[iy], q.val[0], v1.val[2]), q.val[1], v1.val[3]);
+ q = q8.load_quants(iy, i, 2);
+ accd[iy] = ggml_vdotq_s32(ggml_vdotq_s32(accd[iy], q.val[0], v2.val[0]), q.val[1], v2.val[1]);
+ q = q8.load_quants(iy, i, 3);
+ accd[iy] = ggml_vdotq_s32(ggml_vdotq_s32(accd[iy], q.val[0], v2.val[2]), q.val[1], v2.val[3]);
+ }
+ }
+ }
+ int i = 2*(nb/2);
+ if (i < nb) {
+ auto q2bits = vld1q_u8(x[i].qs);
+ int8x16x4_t v1;
+ v1.val[0] = vandq_s8(q2bits, mask2);
+ v1.val[1] = vandq_s8(vshrq_n_u8(q2bits, 2), mask2);
+ v1.val[2] = vandq_s8(vshrq_n_u8(q2bits, 4), mask2);
+ v1.val[3] = vshrq_n_u8(q2bits, 6);
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto q = q8.load_quants(iy, i/2, 0);
+ accd[iy] = ggml_vdotq_s32(ggml_vdotq_s32(accd[iy], q.val[0], v1.val[0]), q.val[1], v1.val[1]);
+ q = q8.load_quants(iy, i/2, 1);
+ accd[iy] = ggml_vdotq_s32(ggml_vdotq_s32(accd[iy], q.val[0], v1.val[2]), q.val[1], v1.val[3]);
+ }
+ }
+
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ info.store(ix, iy, -d*vaddvq_f32(vfmsq_f32(q8.minus(iy), q8.scale(iy), vcvtq_f32_s32(accd[iy]))));
+ }
+ }
+}
+
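+// iq1_s_r4 x q8_K128, 4 interleaved rows. The 3-bit block scales and the delta sign come from
+// the top bits of qh, the grid indices combine qs with 3 more bits of qh, and the per-block
+// IQ1S_DELTA correction (the 0.125f factor) is applied through the q8 block sums cached in d8.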
+template <int nrc_y>
+static void mul_mat_iq1_s_r4_q8_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ GGML_ASSERT(nrc_x%4 == 0);
+ Q8<nrc_y, block_q8_K128> q8(info);
+ int nb = n / 32;
+ GGML_ASSERT(nb%4 == 0);
+ uint8x16_t qx[8];
+ float32x4_t acc[nrc_y] = {};
+ auto ms = vdup_n_u16(0x8000);
+ auto mask = vdupq_n_s8(0x03);
+ float d8[4*nrc_y];
+ for (int ix= 0; ix < nrc_x; ix += 4) {
+ auto dptr = (const ggml_half *)((const char *)vx + ix*bx);
+ auto d1 = vcvt_f32_f16(vld1_f16((const float16_t *)dptr));
+ auto x = (const block_iq1_s_r4 *)(dptr + 4);
+ for (int ib = 0; ib < nb/4; ++ib) {
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto scales = vcvtq_f32_s32(vmovl_s16(vld1_s16(q8.y[iy][ib].bsums)));
+ vst1q_f32(d8+4*iy, vmulq_f32(vdupq_n_f32(q8.y[iy][ib].d), scales));
+ }
+ for (int k = 0; k < 4; ++k) {
+ auto sas = vld1_u16(x[4*ib+k].qh);
+ auto scales4 = vand_u16(vshr_n_u16(sas, 12), vdup_n_u16(7));
+ scales4 = vorr_u16(vshl_n_u16(scales4, 1), vdup_n_u16(1));
+ auto signs = vreinterpret_s16_u16(vorr_u16(vceq_u16(vand_u16(sas, ms), ms), vdup_n_u16(1)));
+ signs = vadd_s16(vdup_n_s16(-8), signs);
+ auto delta4 = vmulq_f32(vdupq_n_f32(0.125f), vcvtq_f32_s32(vmull_s16(signs, scales4)));
+ qx[0] = vreinterpretq_u8_u32(uint32x4_t{iq1s_grid_us[x[4*ib+k].qs[ 0] | ((x[4*ib+k].qh[0] << 8) & 0x0700)],
+ iq1s_grid_us[x[4*ib+k].qs[ 1] | ((x[4*ib+k].qh[1] << 8) & 0x0700)],
+ iq1s_grid_us[x[4*ib+k].qs[ 2] | ((x[4*ib+k].qh[2] << 8) & 0x0700)],
+ iq1s_grid_us[x[4*ib+k].qs[ 3] | ((x[4*ib+k].qh[3] << 8) & 0x0700)]});
+ qx[2] = vreinterpretq_u8_u32(uint32x4_t{iq1s_grid_us[x[4*ib+k].qs[ 4] | ((x[4*ib+k].qh[0] << 5) & 0x0700)],
+ iq1s_grid_us[x[4*ib+k].qs[ 5] | ((x[4*ib+k].qh[1] << 5) & 0x0700)],
+ iq1s_grid_us[x[4*ib+k].qs[ 6] | ((x[4*ib+k].qh[2] << 5) & 0x0700)],
+ iq1s_grid_us[x[4*ib+k].qs[ 7] | ((x[4*ib+k].qh[3] << 5) & 0x0700)]});
+ qx[4] = vreinterpretq_u8_u32(uint32x4_t{iq1s_grid_us[x[4*ib+k].qs[ 8] | ((x[4*ib+k].qh[0] << 2) & 0x0700)],
+ iq1s_grid_us[x[4*ib+k].qs[ 9] | ((x[4*ib+k].qh[1] << 2) & 0x0700)],
+ iq1s_grid_us[x[4*ib+k].qs[10] | ((x[4*ib+k].qh[2] << 2) & 0x0700)],
+ iq1s_grid_us[x[4*ib+k].qs[11] | ((x[4*ib+k].qh[3] << 2) & 0x0700)]});
+ qx[6] = vreinterpretq_u8_u32(uint32x4_t{iq1s_grid_us[x[4*ib+k].qs[12] | ((x[4*ib+k].qh[0] >> 1) & 0x0700)],
+ iq1s_grid_us[x[4*ib+k].qs[13] | ((x[4*ib+k].qh[1] >> 1) & 0x0700)],
+ iq1s_grid_us[x[4*ib+k].qs[14] | ((x[4*ib+k].qh[2] >> 1) & 0x0700)],
+ iq1s_grid_us[x[4*ib+k].qs[15] | ((x[4*ib+k].qh[3] >> 1) & 0x0700)]});
+ qx[1] = vandq_u8(vshrq_n_u8(qx[0], 4), mask); qx[0] = vandq_u8(qx[0], mask);
+ qx[3] = vandq_u8(vshrq_n_u8(qx[2], 4), mask); qx[2] = vandq_u8(qx[2], mask);
+ qx[5] = vandq_u8(vshrq_n_u8(qx[4], 4), mask); qx[4] = vandq_u8(qx[4], mask);
+ qx[7] = vandq_u8(vshrq_n_u8(qx[6], 4), mask); qx[6] = vandq_u8(qx[6], mask);
+ auto scales = vmovl_u16(scales4);
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto y = vld1q_s8_x2(q8.y[iy][ib].qs + 32*k);
+ auto sumi = vdupq_n_s32(0);
+ sumi = vdotq_laneq_s32(sumi, vreinterpretq_s8_u8(qx[0]), y.val[0], 0);
+ sumi = vdotq_laneq_s32(sumi, vreinterpretq_s8_u8(qx[1]), y.val[0], 1);
+ sumi = vdotq_laneq_s32(sumi, vreinterpretq_s8_u8(qx[2]), y.val[0], 2);
+ sumi = vdotq_laneq_s32(sumi, vreinterpretq_s8_u8(qx[3]), y.val[0], 3);
+ sumi = vdotq_laneq_s32(sumi, vreinterpretq_s8_u8(qx[4]), y.val[1], 0);
+ sumi = vdotq_laneq_s32(sumi, vreinterpretq_s8_u8(qx[5]), y.val[1], 1);
+ sumi = vdotq_laneq_s32(sumi, vreinterpretq_s8_u8(qx[6]), y.val[1], 2);
+ sumi = vdotq_laneq_s32(sumi, vreinterpretq_s8_u8(qx[7]), y.val[1], 3);
+ sumi = vmulq_s32(scales, sumi);
+ acc[iy] = vfmaq_f32(acc[iy], vdupq_n_f32(q8.y[iy][ib].d), vcvtq_f32_s32(sumi));
+ acc[iy] = vfmaq_f32(acc[iy], vdupq_n_f32(d8[4*iy+k]), delta4);
+ }
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ info.store(ix, iy, vmulq_f32(d1, acc[iy]));
+ acc[iy] = vdupq_n_f32(0.f);
+ }
+ }
+}
+
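+// iq1_m_r4 x q8_K128, 4 interleaved rows. Block scales are 4-bit values packed in x[].scales,
+// the per-group delta sign comes from the 0x08 bit of qh, and the common 1/8 factor is folded
+// into the row scales d1.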
+template <int nrc_y>
+static void mul_mat_iq1_m_r4_q8_0(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ GGML_ASSERT(nrc_x%4 == 0);
+ Q8<nrc_y, block_q8_K128> q8(info);
+ int nb = n / 32;
+ GGML_ASSERT(nb%4 == 0);
+ int8x16_t qx[8];
+ float32x4_t acc[nrc_y] = {};
+ int32x4_t isum[nrc_y] = {};
+ auto shuffle0 = uint32x4_t{0x00000000, 0x01010101, 0x02020202, 0x03030303};
+ auto step = vdupq_n_u8(4);
+ auto ms = vdupq_n_u8(0x08);
+ auto mask = vdupq_n_s8(0x18);
+ for (int ix= 0; ix < nrc_x; ix += 4) {
+ auto dptr = (const ggml_half *)((const char *)vx + ix*bx);
+ auto d1 = vmulq_f32(vdupq_n_f32(0.125f), vcvt_f32_f16(vld1_f16((const float16_t *)dptr)));
+ auto x = (const block_iq1_m_r4 *)(dptr + 4);
+ for (int ib = 0; ib < nb/4; ++ib) {
+ for (int k = 0; k < 4; ++k) {
+ auto scales4 = vdup_n_u32(((const uint32_t *)x[4*ib+k].scales)[0]);
+ scales4 = vand_u8(vshl_u32(scales4, int32x2_t{0, -4}), vdup_n_u8(0xf));
+ auto scales16 = vmovl_u8(scales4);
+ auto scales1 = vmovl_u16(vget_low_u16(scales16));
+ auto scales2 = vmovl_u16(vget_high_u16(scales16));
+ auto qh = (const uint32_t *)x[4*ib+k].qh;
+ auto idxh = uint32x4_t{qh[0], qh[0] >> 4, qh[1], qh[1] >> 4};
+ auto signs = vreinterpretq_s8_u8(vorrq_u8(vceqq_u8(vandq_u8(idxh, ms), ms), vdupq_n_u8(1)));
+ signs = vaddq_s8(signs, vdupq_n_s8(-8));
+ qx[0] = vreinterpretq_s8_u32(uint32x4_t{iq1s_grid_us[x[4*ib+k].qs[ 0] | ((x[4*ib+k].qh[0] << 8) & 0x0700)],
+ iq1s_grid_us[x[4*ib+k].qs[ 1] | ((x[4*ib+k].qh[1] << 8) & 0x0700)],
+ iq1s_grid_us[x[4*ib+k].qs[ 2] | ((x[4*ib+k].qh[2] << 8) & 0x0700)],
+ iq1s_grid_us[x[4*ib+k].qs[ 3] | ((x[4*ib+k].qh[3] << 8) & 0x0700)]});
+ qx[2] = vreinterpretq_s8_u32(uint32x4_t{iq1s_grid_us[x[4*ib+k].qs[ 4] | ((x[4*ib+k].qh[0] << 4) & 0x0700)],
+ iq1s_grid_us[x[4*ib+k].qs[ 5] | ((x[4*ib+k].qh[1] << 4) & 0x0700)],
+ iq1s_grid_us[x[4*ib+k].qs[ 6] | ((x[4*ib+k].qh[2] << 4) & 0x0700)],
+ iq1s_grid_us[x[4*ib+k].qs[ 7] | ((x[4*ib+k].qh[3] << 4) & 0x0700)]});
+ qx[4] = vreinterpretq_s8_u32(uint32x4_t{iq1s_grid_us[x[4*ib+k].qs[ 8] | ((x[4*ib+k].qh[4] << 8) & 0x0700)],
+ iq1s_grid_us[x[4*ib+k].qs[ 9] | ((x[4*ib+k].qh[5] << 8) & 0x0700)],
+ iq1s_grid_us[x[4*ib+k].qs[10] | ((x[4*ib+k].qh[6] << 8) & 0x0700)],
+ iq1s_grid_us[x[4*ib+k].qs[11] | ((x[4*ib+k].qh[7] << 8) & 0x0700)]});
+ qx[6] = vreinterpretq_s8_u32(uint32x4_t{iq1s_grid_us[x[4*ib+k].qs[12] | ((x[4*ib+k].qh[4] << 4) & 0x0700)],
+ iq1s_grid_us[x[4*ib+k].qs[13] | ((x[4*ib+k].qh[5] << 4) & 0x0700)],
+ iq1s_grid_us[x[4*ib+k].qs[14] | ((x[4*ib+k].qh[6] << 4) & 0x0700)],
+ iq1s_grid_us[x[4*ib+k].qs[15] | ((x[4*ib+k].qh[7] << 4) & 0x0700)]});
+ auto shuffle = shuffle0;
+ for (int j = 0; j < 4; ++j) {
+ auto s = vqtbl1q_s8(signs, shuffle);
+ qx[2*j+1] = vaddq_s8(s, vandq_s8(vshrq_n_s8(qx[2*j+0], 1), mask));
+ qx[2*j+0] = vaddq_s8(s, vandq_s8(vshlq_n_s8(qx[2*j+0], 3), mask));
+ shuffle = vaddq_u8(shuffle, step);
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto y = vld1q_s8_x2(q8.y[iy][ib].qs + 32*k);
+ auto sumi1 = vdupq_n_s32(0);
+ auto sumi2 = vdupq_n_s32(0);
+ sumi1 = vdotq_laneq_s32(sumi1, vreinterpretq_s8_u8(qx[0]), y.val[0], 0);
+ sumi1 = vdotq_laneq_s32(sumi1, vreinterpretq_s8_u8(qx[1]), y.val[0], 1);
+ sumi1 = vdotq_laneq_s32(sumi1, vreinterpretq_s8_u8(qx[2]), y.val[0], 2);
+ sumi1 = vdotq_laneq_s32(sumi1, vreinterpretq_s8_u8(qx[3]), y.val[0], 3);
+ sumi2 = vdotq_laneq_s32(sumi2, vreinterpretq_s8_u8(qx[4]), y.val[1], 0);
+ sumi2 = vdotq_laneq_s32(sumi2, vreinterpretq_s8_u8(qx[5]), y.val[1], 1);
+ sumi2 = vdotq_laneq_s32(sumi2, vreinterpretq_s8_u8(qx[6]), y.val[1], 2);
+ sumi2 = vdotq_laneq_s32(sumi2, vreinterpretq_s8_u8(qx[7]), y.val[1], 3);
+ isum[iy] = vmlaq_s32(vmlaq_s32(isum[iy], sumi1, scales1), sumi2, scales2);
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ acc[iy] = vfmaq_f32(acc[iy], vdupq_n_f32(q8.y[iy][ib].d), vcvtq_f32_s32(isum[iy]));
+ isum[iy] = vdupq_n_s32(0);
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ info.store(ix, iy, vmulq_f32(d1, acc[iy]));
+ acc[iy] = vdupq_n_f32(0.f);
+ }
+ }
+}
+
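+// nrc_y == 1 specialization of mul_mat_iq1_s_r4_q8_1: the integer dot products and the
+// sign*scale delta terms are kept in separate accumulators (acc[0] and acc[1]) and combined
+// with IQ1S_DELTA only at the final store.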
+void mul_mat_iq1_s_r4_q8_1_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ GGML_ASSERT(nrc_x%4 == 0);
+ Q8<1, block_q8_K128> q8(info);
+ int nb = n / 32;
+ GGML_ASSERT(nb%4 == 0);
+ int8x16_t qx[8];
+ float32x4_t acc[2] = {};
+ int32x4_t isum[8];
+ auto ms = vdup_n_u16(0x8000);
+ for (int ix= 0; ix < nrc_x; ix += 4) {
+ auto dptr = (const ggml_half *)((const char *)vx + ix*bx);
+ auto d1 = vcvt_f32_f16(vld1_f16((const float16_t *)dptr));
+ auto x = (const block_iq1_s_r4 *)(dptr + 4);
+ for (int ib = 0; ib < nb/4; ++ib) {
+ auto scale_yd = vdupq_n_f32(q8.y[0][ib].d);
+ auto scale_ym = vmulq_f32(scale_yd, vcvtq_f32_s32(vmovl_s16(vld1_s16(q8.y[0][ib].bsums))));
+ for (int k = 0; k < 4; ++k) {
+ auto sas = vld1_u16(x[4*ib+k].qh);
+ auto scales4 = vand_u16(vshr_n_u16(sas, 12), vdup_n_u16(7));
+ scales4 = vorr_u16(vshl_n_u16(scales4, 1), vdup_n_u16(1));
+ auto signs = vreinterpret_s16_u16(vorr_u16(vceq_u16(vand_u16(sas, ms), ms), vdup_n_u16(1)));
+ isum[k+4] = vmull_s16(signs, scales4);
+ qx[0] = vreinterpretq_s8_u64(uint64x2_t{iq1s_grid[x[4*ib+k].qs[ 0] | ((x[4*ib+k].qh[0] << 8) & 0x0700)],
+ iq1s_grid[x[4*ib+k].qs[ 4] | ((x[4*ib+k].qh[0] << 5) & 0x0700)]});
+ qx[1] = vreinterpretq_s8_u64(uint64x2_t{iq1s_grid[x[4*ib+k].qs[ 8] | ((x[4*ib+k].qh[0] << 2) & 0x0700)],
+ iq1s_grid[x[4*ib+k].qs[12] | ((x[4*ib+k].qh[0] >> 1) & 0x0700)]});
+ qx[2] = vreinterpretq_s8_u64(uint64x2_t{iq1s_grid[x[4*ib+k].qs[ 1] | ((x[4*ib+k].qh[1] << 8) & 0x0700)],
+ iq1s_grid[x[4*ib+k].qs[ 5] | ((x[4*ib+k].qh[1] << 5) & 0x0700)]});
+ qx[3] = vreinterpretq_s8_u64(uint64x2_t{iq1s_grid[x[4*ib+k].qs[ 9] | ((x[4*ib+k].qh[1] << 2) & 0x0700)],
+ iq1s_grid[x[4*ib+k].qs[13] | ((x[4*ib+k].qh[1] >> 1) & 0x0700)]});
+ qx[4] = vreinterpretq_s8_u64(uint64x2_t{iq1s_grid[x[4*ib+k].qs[ 2] | ((x[4*ib+k].qh[2] << 8) & 0x0700)],
+ iq1s_grid[x[4*ib+k].qs[ 6] | ((x[4*ib+k].qh[2] << 5) & 0x0700)]});
+ qx[5] = vreinterpretq_s8_u64(uint64x2_t{iq1s_grid[x[4*ib+k].qs[10] | ((x[4*ib+k].qh[2] << 2) & 0x0700)],
+ iq1s_grid[x[4*ib+k].qs[14] | ((x[4*ib+k].qh[2] >> 1) & 0x0700)]});
+ qx[6] = vreinterpretq_s8_u64(uint64x2_t{iq1s_grid[x[4*ib+k].qs[ 3] | ((x[4*ib+k].qh[3] << 8) & 0x0700)],
+ iq1s_grid[x[4*ib+k].qs[ 7] | ((x[4*ib+k].qh[3] << 5) & 0x0700)]});
+ qx[7] = vreinterpretq_s8_u64(uint64x2_t{iq1s_grid[x[4*ib+k].qs[11] | ((x[4*ib+k].qh[3] << 2) & 0x0700)],
+ iq1s_grid[x[4*ib+k].qs[15] | ((x[4*ib+k].qh[3] >> 1) & 0x0700)]});
+ auto scales = vmovl_u16(scales4);
+ auto y = vld1q_s8_x2(q8.y[0][ib].qs + 32*k);
+ auto sumi1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[0], y.val[0]), qx[1], y.val[1]);
+ auto sumi2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[2], y.val[0]), qx[3], y.val[1]);
+ auto sumi3 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[4], y.val[0]), qx[5], y.val[1]);
+ auto sumi4 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[6], y.val[0]), qx[7], y.val[1]);
+ sumi1 = vpaddq_s32(sumi1, sumi2);
+ sumi3 = vpaddq_s32(sumi3, sumi4);
+ isum[k] = vmulq_s32(scales, vpaddq_s32(sumi1, sumi3));
+ }
+ acc[0] = vfmaq_laneq_f32(acc[0], vcvtq_f32_s32(isum[0]), scale_yd, 0);
+ acc[0] = vfmaq_laneq_f32(acc[0], vcvtq_f32_s32(isum[1]), scale_yd, 1);
+ acc[0] = vfmaq_laneq_f32(acc[0], vcvtq_f32_s32(isum[2]), scale_yd, 2);
+ acc[0] = vfmaq_laneq_f32(acc[0], vcvtq_f32_s32(isum[3]), scale_yd, 3);
+ acc[1] = vfmaq_laneq_f32(acc[1], vcvtq_f32_s32(isum[4]), scale_ym, 0);
+ acc[1] = vfmaq_laneq_f32(acc[1], vcvtq_f32_s32(isum[5]), scale_ym, 1);
+ acc[1] = vfmaq_laneq_f32(acc[1], vcvtq_f32_s32(isum[6]), scale_ym, 2);
+ acc[1] = vfmaq_laneq_f32(acc[1], vcvtq_f32_s32(isum[7]), scale_ym, 3);
+ }
+ info.store(ix, 0, vmulq_f32(d1, vfmaq_f32(acc[0], acc[1], vdupq_n_f32(IQ1S_DELTA))));
+ acc[0] = acc[1] = vdupq_n_f32(0.f);
+ }
+}
+
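+// Plain iq1_s x q8_K. Block scales are 2*s+1 taken from the top bits of qh; the signed delta
+// contribution is computed from the q8 block sums (deltas128), and because the scales are
+// pre-shifted left by 3, the single 0.125f factor at the end cancels that shift for the dot
+// term while supplying IQ1S_DELTA for the delta term.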
+template <int nrc_y>
+void mul_mat_iq1_s_q8_K(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ GGML_ASSERT(n%QK_K == 0);
+ Q8<nrc_y, block_q8_K> q8(info);
+ int8x16_t qx[16];
+ int32x4_t scales[2];
+ int16x4_t deltas[2];
+ float32x4_t acc[nrc_y] = {};
+ auto delta_mask = vdupq_n_u16(0x8000);
+ for (int ix = 0; ix < nrc_x; ++ix) {
+ auto iq1s = (const block_iq1_s *)((const char *)vx + ix*bx);
+ for (int ibl = 0; ibl < n/QK_K; ++ibl) {
+ float d = GGML_FP16_TO_FP32(iq1s[ibl].d);
+ auto qhb = vld1q_u16(iq1s[ibl].qh);
+ auto scales128 = vandq_u16(vshrq_n_u16(qhb, 12), vdupq_n_u16(7));
+ scales128 = vaddq_u16(vshlq_n_u16(scales128, 1), vdupq_n_u16(1));
+ auto mask = vceqq_u16(vandq_u16(qhb, delta_mask), delta_mask);
+            // Note: we explicitly assume IQ1S_DELTA = 0.125
+ auto deltas128 = vsubq_s16(vbicq_s16(scales128, mask), vandq_s16(scales128, mask));
+ //auto deltas128 = vorrq_s16(vandq_s16(vdupq_n_s16(-1), mask), vbicq_s16(vdupq_n_s16(1), mask));
+ //deltas128 = vmulq_s16(scales128, deltas128);
+ scales128 = vshlq_n_u16(scales128, 3);
+ auto qs = iq1s[ibl].qs;
+ auto qh = iq1s[ibl].qh;
+ for (int ib64 = 0; ib64 < QK_K/64; ++ib64) {
+ qx[4*ib64+0] = vreinterpretq_s8_u64(uint64x2_t{iq1s_grid[qs[0] | ((qh[2*ib64+0] << 8) & 0x700)], iq1s_grid[qs[1] | ((qh[2*ib64+0] << 5) & 0x700)]});
+ qx[4*ib64+1] = vreinterpretq_s8_u64(uint64x2_t{iq1s_grid[qs[2] | ((qh[2*ib64+0] << 2) & 0x700)], iq1s_grid[qs[3] | ((qh[2*ib64+0] >> 1) & 0x700)]});
+ qx[4*ib64+2] = vreinterpretq_s8_u64(uint64x2_t{iq1s_grid[qs[4] | ((qh[2*ib64+1] << 8) & 0x700)], iq1s_grid[qs[5] | ((qh[2*ib64+1] << 5) & 0x700)]});
+ qx[4*ib64+3] = vreinterpretq_s8_u64(uint64x2_t{iq1s_grid[qs[6] | ((qh[2*ib64+1] << 2) & 0x700)], iq1s_grid[qs[7] | ((qh[2*ib64+1] >> 1) & 0x700)]});
+ qs += 8;
+ }
+ scales[0] = vreinterpretq_s32_u32(vmovl_u16(vget_low_u16 (scales128)));
+ scales[1] = vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(scales128)));
+ deltas[0] = vget_low_s16 (deltas128);
+ deltas[1] = vget_high_s16(deltas128);
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto bsums = q8.load_bsums8(iy, ibl);
+ auto sumi = vdupq_n_s32(0);
+ sumi = vmlal_s16(sumi, deltas[0], vget_low_s16 (bsums));
+ sumi = vmlal_s16(sumi, deltas[1], vget_high_s16(bsums));
+ for (int k = 0; k < QK_K/128; ++k) {
+ auto qy = q8.load_quants_64(iy, ibl, 2*k+0);
+ auto dot1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[8*k+0], qy.val[0]), qx[8*k+1], qy.val[1]);
+ auto dot2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[8*k+2], qy.val[2]), qx[8*k+3], qy.val[3]);
+ auto dot12 = vpaddq_s32(dot1, dot2);
+ qy = q8.load_quants_64(iy, ibl, 2*k+1);
+ auto dot3 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[8*k+4], qy.val[0]), qx[8*k+5], qy.val[1]);
+ auto dot4 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[8*k+6], qy.val[2]), qx[8*k+7], qy.val[3]);
+ auto dot34 = vpaddq_s32(dot3, dot4);
+ auto dot = vpaddq_s32(dot12, dot34);
+ sumi = vmlaq_s32(sumi, dot, scales[k]);
+ }
+ acc[iy] = vfmaq_f32(acc[iy], vdupq_n_f32(d*q8.scale(iy, ibl)), vcvtq_f32_s32(sumi));
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ info.store(ix, iy, 0.125f*vaddvq_f32(acc[iy]));
+ acc[iy] = vdupq_n_f32(0);
+ }
+ }
+}
+} // namespace
+
+bool iqk_set_kernels_1bit(int ne00, int typeA, int typeB, std::array<mul_mat_t, IQK_MAX_NY>& funcs, mul_mat_t& func16) {
+
+ auto expected_Btype = GGML_TYPE_Q8_K128;
+
+ func16 = nullptr;
+
+ switch (typeA) {
+ case GGML_TYPE_IQ1_BN:
+ if (ne00 % QK_IQ1BN != 0) return false;
+ IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq1bn_q8_K64, funcs);
+ expected_Btype = GGML_TYPE_Q8_K64;
+ break;
+ case GGML_TYPE_IQ2_BN:
+ if (ne00 % QK_IQ1BN != 0) return false;
+ IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq2bn_q8_K64, funcs);
+ expected_Btype = GGML_TYPE_Q8_K64;
+ break;
+ case GGML_TYPE_IQ2_BN_R4:
+ if (ne00 % QK_IQ1BN != 0) return false;
+ funcs[0] = mul_mat_iq2_bn_r4_q8_k16<1>;
+ funcs[1] = mul_mat_iq2_bn_r4_q8_k16<2>;
+ funcs[2] = mul_mat_iq2_bn_r4_q8_k16<3>;
+ funcs[3] = mul_mat_iq2_bn_r4_q8_k16<4>;
+ funcs[4] = mul_mat_iq2_bn_r4_q8_k16<5>;
+ //funcs[5] = mul_mat_iq2_bn_r4_q8_k16<6>;
+ //funcs[6] = mul_mat_iq2_bn_r4_q8_k16<7>;
+ //funcs[7] = mul_mat_iq2_bn_r4_q8_k16<8>;
+ expected_Btype = GGML_TYPE_Q8_K16;
+ break;
+ case GGML_TYPE_IQ1_S:
+ if (ne00%QK_K != 0) return false;
+ IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq1_s_q8_K, funcs);
+ func16 = mul_mat_iq1_s_q8_K<16>;
+ expected_Btype = GGML_TYPE_Q8_K;
+ break;
+ case GGML_TYPE_IQ1_S_R4:
+ if (ne00%128 != 0) return false;
+ IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq1_s_r4_q8_1, funcs);
+ funcs[0] = mul_mat_iq1_s_r4_q8_1_1;
+ func16 = mul_mat_iq1_s_r4_q8_1<16>;
+ expected_Btype = GGML_TYPE_Q8_K128;
+ break;
+ case GGML_TYPE_IQ1_M_R4:
+ if (ne00%128 != 0) return false;
+ IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq1_m_r4_q8_0, funcs);
+ func16 = mul_mat_iq1_m_r4_q8_0<16>;
+ expected_Btype = GGML_TYPE_Q8_K128;
+ break;
+ default:
+ return false;
+ }
+
+ return ggml_type(typeB) == expected_Btype;
+
+}
+
+#endif
+
+#endif
diff --git a/ggml/src/iqk/iqk_gemm_1bit.h b/ggml/src/iqk/iqk_gemm_1bit.h
new file mode 100644
index 00000000..80309187
--- /dev/null
+++ b/ggml/src/iqk/iqk_gemm_1bit.h
@@ -0,0 +1,11 @@
+#pragma once
+
+#include "iqk_common.h"
+
+#ifdef IQK_IMPLEMENT
+
+#include <array>
+
+bool iqk_set_kernels_1bit(int ne00, int typeA, int typeB, std::array<mul_mat_t, IQK_MAX_NY>& kernels, mul_mat_t& func16);
+
+#endif
diff --git a/ggml/src/iqk/iqk_gemm_floats.cpp b/ggml/src/iqk/iqk_gemm_floats.cpp
new file mode 100644
index 00000000..5165eb98
--- /dev/null
+++ b/ggml/src/iqk/iqk_gemm_floats.cpp
@@ -0,0 +1,1048 @@
+#include "iqk_gemm_floats.h"
+
+#ifdef IQK_IMPLEMENT
+
+#include "ggml-impl.h"
+
+#define GGML_COMMON_IMPL_C
+#include "ggml-common.h"
+
+#ifdef __x86_64__
+
+namespace {
+
+// float matrices - we handle f16, bf16 (if native bf16 support is available) and f32, but the result is always f32
+
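+// Common building blocks for the float GEMM: loads of f16/bf16/f32 data converted to f32,
+// fused multiply-add accumulation, and horizontal sums, in an AVX512 and an AVX2 flavor.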
+struct QFBase {
+#ifdef __AVX512F__
+ constexpr static int k_step = 16;
+ using Data = __m512;
+ using Acc = __m512;
+ static inline Data load(const ggml_half * x) { return _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)x)); }
+ static inline Data load(const float * x) { return _mm512_loadu_ps(x); }
+ static inline Data load(const ggml_bf16_t * x) {
+ return _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(_mm256_loadu_si256((const __m256i*)x)), 16));
+ }
+ static inline Acc acc(Acc prev, const Data& y, const Data& x) {
+ return _mm512_fmadd_ps(y, x, prev);
+ }
+ static inline Acc acc_first(const Data& y, const Data& x) {
+ return _mm512_mul_ps(y, x);
+ }
+ static inline Acc add(Acc x, Acc y) { return _mm512_add_ps(x, y); }
+ static inline float hsum(Acc acc) {
+ return _mm512_reduce_add_ps(acc);
+ }
+ template <typename Float>
+ static inline Data load4Floats(const Float * x) {
+ return _mm512_insertf32x4(_mm512_setzero_ps(), load128(x), 0);
+ }
+ static inline Acc acc_r4(Acc acc, const Data * xv, const Data& yv) {
+ acc = _mm512_fmadd_ps(xv[0], _mm512_shuffle_ps(yv, yv, 0x00), acc);
+ acc = _mm512_fmadd_ps(xv[1], _mm512_shuffle_ps(yv, yv, 0x55), acc);
+ acc = _mm512_fmadd_ps(xv[2], _mm512_shuffle_ps(yv, yv, 0xaa), acc);
+ acc = _mm512_fmadd_ps(xv[3], _mm512_shuffle_ps(yv, yv, 0xff), acc);
+ return acc;
+ }
+ static inline Acc acc_r4_first(const Data * xv, const Data& yv) {
+ auto acc = _mm512_mul_ps(xv[0], _mm512_shuffle_ps(yv, yv, 0x00));
+ acc = _mm512_fmadd_ps(xv[1], _mm512_shuffle_ps(yv, yv, 0x55), acc);
+ acc = _mm512_fmadd_ps(xv[2], _mm512_shuffle_ps(yv, yv, 0xaa), acc);
+ acc = _mm512_fmadd_ps(xv[3], _mm512_shuffle_ps(yv, yv, 0xff), acc);
+ return acc;
+ }
+ static inline __m128 hsum_r4(Acc acc) {
+ auto sum1 = _mm_add_ps(_mm512_extractf32x4_ps(acc, 0), _mm512_extractf32x4_ps(acc, 1));
+ auto sum2 = _mm_add_ps(_mm512_extractf32x4_ps(acc, 2), _mm512_extractf32x4_ps(acc, 3));
+ return _mm_add_ps(sum1, sum2);
+ }
+#else
+ constexpr static int k_step = 8;
+ using Data = __m256;
+ using Acc = __m256;
+ static inline Data load(const ggml_half * x) { return _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)x)); }
+ static inline Data load(const float * x) { return _mm256_loadu_ps(x); }
+ static inline Data load(const ggml_bf16_t * x) {
+ return _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i*)x)), 16));
+ }
+ static inline Acc acc(Acc prev, const Data& y, const Data& x) {
+ return _mm256_fmadd_ps(y, x, prev);
+ }
+ static inline Acc add(Acc x, Acc y) { return _mm256_add_ps(x, y); }
+ static inline Acc acc_r4(Acc acc, const Data * xv, const Data& yv) {
+ acc = _mm256_fmadd_ps(xv[0], _mm256_shuffle_ps(yv, yv, 0x00), acc);
+ acc = _mm256_fmadd_ps(xv[1], _mm256_shuffle_ps(yv, yv, 0x55), acc);
+ acc = _mm256_fmadd_ps(xv[2], _mm256_shuffle_ps(yv, yv, 0xaa), acc);
+ acc = _mm256_fmadd_ps(xv[3], _mm256_shuffle_ps(yv, yv, 0xff), acc);
+ return acc;
+ }
+ static inline Acc acc_r4_first(const Data * xv, const Data& yv) {
+ auto acc = _mm256_mul_ps(xv[0], _mm256_shuffle_ps(yv, yv, 0x00));
+ acc = _mm256_fmadd_ps(xv[1], _mm256_shuffle_ps(yv, yv, 0x55), acc);
+ acc = _mm256_fmadd_ps(xv[2], _mm256_shuffle_ps(yv, yv, 0xaa), acc);
+ acc = _mm256_fmadd_ps(xv[3], _mm256_shuffle_ps(yv, yv, 0xff), acc);
+ return acc;
+ }
+ static inline Acc acc_first(const Data& y, const Data& x) {
+ return _mm256_mul_ps(y, x);
+ }
+ static inline float hsum(Acc acc) {
+ return hsum_float_8(acc);
+ }
+ static inline __m128 hsum_r4(Acc acc) {
+ return _mm_add_ps(_mm256_castps256_ps128(acc), _mm256_extractf128_ps(acc, 1));
+ }
+ template <typename Float>
+ static inline Data load4Floats(const Float * x) {
+ return _mm256_insertf128_ps(_mm256_setzero_ps(), load128(x), 0);
+ }
+#endif
+ static inline __m128 load128(const ggml_half * x) { return _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)x)); }
+ static inline __m128 load128(const float * x) { return _mm_loadu_ps(x); }
+ static inline __m128 load128(const ggml_bf16_t * x) {
+ return _mm_castsi128_ps(_mm_slli_epi32(_mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i*)x)), 16));
+ }
+};
+
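+// Holds nrc row pointers, either into the activations (via DataInfo) or into the left-hand
+// matrix; load_r4() loads 4 rows and transposes them in groups of 4 floats so that acc_r4()
+// can consume them with in-lane broadcasts.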
+template <typename Float, int nrc_in> struct QFT final : public QFBase {
+ constexpr static int nrc = nrc_in;
+ QFT(const DataInfo& info) {
+ for (int iy = 0; iy < nrc; ++iy) y[iy] = (const Float *)info.src1_row(iy);
+ }
+ QFT(const char * cx, size_t bx) {
+ for (int iy = 0; iy < nrc; ++iy) y[iy] = (const Float *)(cx + iy*bx);
+ }
+ IQK_ALWAYS_INLINE Data load1(int iy, int i) const { return load(y[iy] + k_step*i); }
+ IQK_ALWAYS_INLINE Data load_tail(int iy, int i) const { return load4Floats(y[iy] + 4*i); }
+ IQK_ALWAYS_INLINE void load_r4(int ix, int i, Data * xv) const {
+ xv[0] = load1(ix+0, i);
+ xv[1] = load1(ix+1, i);
+ xv[2] = load1(ix+2, i);
+ xv[3] = load1(ix+3, i);
+#ifdef __AVX512F__
+ auto t0 = _mm512_unpacklo_ps(xv[0], xv[1]);
+ auto t1 = _mm512_unpacklo_ps(xv[2], xv[3]);
+ auto t2 = _mm512_unpackhi_ps(xv[0], xv[1]);
+ auto t3 = _mm512_unpackhi_ps(xv[2], xv[3]);
+ xv[0] = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(t0), _mm512_castps_pd(t1)));
+ xv[1] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(t0), _mm512_castps_pd(t1)));
+ xv[2] = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(t2), _mm512_castps_pd(t3)));
+ xv[3] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(t2), _mm512_castps_pd(t3)));
+#else
+ auto t0 = _mm256_unpacklo_ps(xv[0], xv[1]);
+ auto t1 = _mm256_unpacklo_ps(xv[2], xv[3]);
+ auto t2 = _mm256_unpackhi_ps(xv[0], xv[1]);
+ auto t3 = _mm256_unpackhi_ps(xv[2], xv[3]);
+ xv[0] = _mm256_castpd_ps(_mm256_unpacklo_pd(_mm256_castps_pd(t0), _mm256_castps_pd(t1)));
+ xv[1] = _mm256_castpd_ps(_mm256_unpackhi_pd(_mm256_castps_pd(t0), _mm256_castps_pd(t1)));
+ xv[2] = _mm256_castpd_ps(_mm256_unpacklo_pd(_mm256_castps_pd(t2), _mm256_castps_pd(t3)));
+ xv[3] = _mm256_castpd_ps(_mm256_unpackhi_pd(_mm256_castps_pd(t2), _mm256_castps_pd(t3)));
+#endif
+ }
+ const Float * y[nrc];
+};
+
+// TBD if we want this
+//template <typename Qy, typename Qx>
+//IQK_NOINLINE void mul_mat_Qx_Qy_Mx1(int n, const char * cx, size_t bx, int ix0, const DataInfo& info) {
+// static_assert(Qy::nrc == 1);
+// int nb = n/QFBase::k_step;
+// int nb4 = n/4;
+// Qy y(info);
+// Qx x(cx + ix0*bx, bx);
+// QFBase::Data xv[2*Qx::nrc];
+// QFBase::Acc acc[2*Qx::nrc];
+// auto yv1 = y.load1(0, 0);
+// auto yv2 = y.load1(0, 1);
+// for (int ix = 0; ix < Qx::nrc; ++ix) {
+// xv[2*ix+0] = x.load1(ix, 0);
+// xv[2*ix+1] = x.load1(ix, 1);
+// acc[2*ix+0] = QFBase::acc_first(yv1, xv[2*ix+0]);
+// acc[2*ix+1] = QFBase::acc_first(yv2, xv[2*ix+1]);
+// }
+// for (int i = 1; i < nb/2; ++i) {
+// yv1 = y.load1(0, 2*i+0);
+// yv2 = y.load1(0, 2*i+1);
+// for (int ix = 0; ix < Qx::nrc; ++ix) {
+// xv[2*ix+0] = x.load1(ix, 2*i+0);
+// xv[2*ix+1] = x.load1(ix, 2*i+1);
+// acc[2*ix+0] = QFBase::acc(acc[2*ix+0], yv1, xv[2*ix+0]);
+// acc[2*ix+1] = QFBase::acc(acc[2*ix+1], yv2, xv[2*ix+1]);
+// }
+// }
+// for (int i = (QFBase::k_step/4)*nb; i < nb4; ++i) {
+// yv1 = y.load_tail(0, i);
+// for (int ix = 0; ix < Qx::nrc; ++ix) {
+// xv[ix] = x.load_tail(ix, i);
+// acc[2*ix+0] = QFBase::acc(acc[2*ix+0], yv1, xv[ix]);
+// }
+// }
+// for (int ix = 0; ix < Qx::nrc; ++ix) info.store(ix0+ix, 0, QFBase::hsum(QFBase::add(acc[2*ix+0], acc[2*ix+1])));
+//}
+
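+// Generic Qx::nrc x Qy::nrc float GEMM tile: k_step elements per iteration plus a 4-element
+// tail loop for row lengths that are not a multiple of k_step.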
+template <typename Qy, typename Qx>
+IQK_NOINLINE void mul_mat_Qx_Qy_MxN(int n, const char * cx, size_t bx, int ix0, const DataInfo& info) {
+ int nb = n/QFBase::k_step;
+ int nb4 = n/4;
+ Qy y(info);
+ Qx x(cx + ix0*bx, bx);
+ QFBase::Data xv[Qx::nrc];
+ QFBase::Acc acc[Qx::nrc*Qy::nrc];
+ auto yv = y.load1(0, 0);
+ for (int ix = 0; ix < Qx::nrc; ++ix) {
+ xv[ix] = x.load1(ix, 0);
+ acc[ix] = QFBase::acc_first(yv, xv[ix]);
+ }
+ for (int iy = 1; iy < Qy::nrc; ++iy) {
+ yv = y.load1(iy, 0);
+ for (int ix = 0; ix < Qx::nrc; ++ix) acc[Qx::nrc*iy + ix] = QFBase::acc_first(yv, xv[ix]);
+ }
+ for (int i = 1; i < nb; ++i) {
+ yv = y.load1(0, i);
+ for (int ix = 0; ix < Qx::nrc; ++ix) {
+ xv[ix] = x.load1(ix, i);
+ acc[ix] = QFBase::acc(acc[ix], yv, xv[ix]);
+ }
+ for (int iy = 1; iy < Qy::nrc; ++iy) {
+ yv = y.load1(iy, i);
+ for (int ix = 0; ix < Qx::nrc; ++ix) acc[Qx::nrc*iy + ix] = QFBase::acc(acc[Qx::nrc*iy + ix], yv, xv[ix]);
+ }
+ }
+ for (int i = (QFBase::k_step/4)*nb; i < nb4; ++i) {
+ yv = y.load_tail(0, i);
+ for (int ix = 0; ix < Qx::nrc; ++ix) {
+ xv[ix] = x.load_tail(ix, i);
+ acc[ix] = QFBase::acc(acc[ix], yv, xv[ix]);
+ }
+ for (int iy = 1; iy < Qy::nrc; ++iy) {
+ yv = y.load_tail(iy, i);
+ for (int ix = 0; ix < Qx::nrc; ++ix) acc[Qx::nrc*iy + ix] = QFBase::acc(acc[Qx::nrc*iy + ix], yv, xv[ix]);
+ }
+ }
+ for (int iy = 0; iy < Qy::nrc; ++iy) for (int ix = 0; ix < Qx::nrc; ++ix) info.store(ix0+ix, iy, QFBase::hsum(acc[Qx::nrc*iy+ix]));
+}
+
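+// Same tile as above but without the 4-element tail loop; the row length is assumed to be a
+// multiple of k_step.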
+template <typename Qy, typename Qx>
+inline void mul_mat_Qx_Qy_MxN_fa(int n, const char * cx, size_t bx, int ix0, const DataInfo& info) {
+ int nb = n/QFBase::k_step;
+ Qy y(info);
+ Qx x(cx + ix0*bx, bx);
+ QFBase::Data xv[Qx::nrc];
+ QFBase::Acc acc[Qx::nrc*Qy::nrc];
+ auto yv = y.load1(0, 0);
+ for (int ix = 0; ix < Qx::nrc; ++ix) {
+ xv[ix] = x.load1(ix, 0);
+ acc[ix] = QFBase::acc_first(yv, xv[ix]);
+ }
+ for (int iy = 1; iy < Qy::nrc; ++iy) {
+ yv = y.load1(iy, 0);
+ for (int ix = 0; ix < Qx::nrc; ++ix) acc[Qx::nrc*iy + ix] = QFBase::acc_first(yv, xv[ix]);
+ }
+ for (int i = 1; i < nb; ++i) {
+ yv = y.load1(0, i);
+ for (int ix = 0; ix < Qx::nrc; ++ix) {
+ xv[ix] = x.load1(ix, i);
+ acc[ix] = QFBase::acc(acc[ix], yv, xv[ix]);
+ }
+ for (int iy = 1; iy < Qy::nrc; ++iy) {
+ yv = y.load1(iy, i);
+ for (int ix = 0; ix < Qx::nrc; ++ix) acc[Qx::nrc*iy + ix] = QFBase::acc(acc[Qx::nrc*iy + ix], yv, xv[ix]);
+ }
+ }
+ for (int iy = 0; iy < Qy::nrc; ++iy) for (int ix = 0; ix < Qx::nrc; ++ix) info.store(ix0+ix, iy, QFBase::hsum(acc[Qx::nrc*iy+ix]));
+}
+
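+// Variant that consumes the left-hand rows 4 at a time via load_r4()/acc_r4() and stores 4
+// results per hsum_r4(); Qx::nrc must be a multiple of 4.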
+template <typename Qy, typename Qx>
+inline void mul_mat_Qx_Qy_MxN_fa4(int D, const char * cx, size_t bx, int ix0, const DataInfo& info) {
+ static_assert(Qx::nrc%4 == 0);
+ int nb = D/QFBase::k_step;
+ Qy y(info);
+ Qx x(cx + ix0*bx, bx);
+ QFBase::Data xv[Qx::nrc];
+ QFBase::Acc acc[Qx::nrc*Qy::nrc/4] = {};
+ for (int i = 0; i < nb; ++i) {
+ for (int ix = 0; ix < Qx::nrc/4; ++ix) x.load_r4(4*ix, i, xv + 4*ix);
+ for (int iy = 0; iy < Qy::nrc; ++iy) {
+ auto yv = y.load1(iy, i);
+ for (int ix = 0; ix < Qx::nrc/4; ++ix) acc[ix*Qy::nrc + iy] = QFBase::acc_r4(acc[ix*Qy::nrc + iy], xv + 4*ix, yv);
+ }
+ }
+ for (int iy = 0; iy < Qy::nrc; ++iy) {
+ for (int ix = 0; ix < Qx::nrc/4; ++ix) info.store(ix0+4*ix, iy, QFBase::hsum_r4(acc[ix*Qy::nrc + iy]));
+ }
+}
+
+// This will handle any of f16 x f32, f32 x f16, f16 x f16, f32 x f32, with computations done
+// in f32 (i.e., f16 is first converted to f32). It would be easy to extend to computations done
+// in f16, but I don't have a CPU capable of f16 vector arithmetic, so I'm not doing it for now.
+template <int nrc_y, typename FloatX, typename FloatY>
+void mul_mat_fX_fY_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ const char * cx = (const char *)vx;
+ // TBD if we want this
+ //if constexpr (nrc_y == 1) {
+ // constexpr int k_nx = 2;
+ // for (int ix = 0; ix < nrc_x/k_nx; ++ix) {
+ // mul_mat_Qx_Qy_Mx1<QFT<FloatY, nrc_y>, QFT<FloatX, k_nx>>(n, cx, bx, ix*k_nx, info);
+ // }
+ // if (int lastx = k_nx*(nrc_x/k_nx); lastx < nrc_x) {
+ // int nx = nrc_x - lastx;
+ // switch (nx) {
+ // case 1: mul_mat_Qx_Qy_Mx1<QFT<FloatY, nrc_y>, QFT<FloatX, 1>>(n, cx, bx, lastx, info); break;
+ // case 2: mul_mat_Qx_Qy_Mx1<QFT<FloatY, nrc_y>, QFT<FloatX, 2>>(n, cx, bx, lastx, info); break;
+ // case 3: mul_mat_Qx_Qy_Mx1<QFT<FloatY, nrc_y>, QFT<FloatX, 3>>(n, cx, bx, lastx, info); break;
+ // }
+ // //mul_mat_Qx_Qy_Mx1<QFT<FloatY, nrc_y>, QFT<FloatX, 1>>(n, cx, bx, lastx, info);
+ // }
+ // return;
+ //}
+#ifdef __AVX512F__
+ constexpr int k_nx = 5;
+#else
+ constexpr int k_nx = nrc_y == 1 ? 4 : 2;
+#endif
+ for (int ix = 0; ix < nrc_x/k_nx; ++ix) {
+ mul_mat_Qx_Qy_MxN<QFT<FloatY, nrc_y>, QFT<FloatX, k_nx>>(n, cx, bx, ix*k_nx, info);
+ }
+ int last_x = k_nx*(nrc_x/k_nx);
+ if (last_x == nrc_x) return;
+ int nx = nrc_x - last_x;
+#ifdef __AVX512F__
+ switch (nx) {
+ case 1: mul_mat_Qx_Qy_MxN<QFT<FloatY, nrc_y>, QFT<FloatX, 1>>(n, cx, bx, last_x, info); break;
+ case 2: mul_mat_Qx_Qy_MxN<QFT<FloatY, nrc_y>, QFT<FloatX, 2>>(n, cx, bx, last_x, info); break;
+ case 3: mul_mat_Qx_Qy_MxN<QFT<FloatY, nrc_y>, QFT<FloatX, 3>>(n, cx, bx, last_x, info); break;
+ case 4: mul_mat_Qx_Qy_MxN<QFT<FloatY, nrc_y>, QFT<FloatX, 4>>(n, cx, bx, last_x, info); break;
+ }
+#else
+ if constexpr (nrc_y == 1) {
+ switch (nx) {
+ case 1: mul_mat_Qx_Qy_MxN<QFT<FloatY, nrc_y>, QFT<FloatX, 1>>(n, cx, bx, last_x, info); break;
+ case 2: mul_mat_Qx_Qy_MxN<QFT<FloatY, nrc_y>, QFT<FloatX, 2>>(n, cx, bx, last_x, info); break;
+ case 3: mul_mat_Qx_Qy_MxN<QFT<FloatY, nrc_y>, QFT<FloatX, 3>>(n, cx, bx, last_x, info); break;
+ }
+ } else {
+ switch (nx) {
+ case 1: mul_mat_Qx_Qy_MxN<QFT<FloatY, nrc_y>, QFT<FloatX, 1>>(n, cx, bx, last_x, info); break;
+ }
+ }
+#endif
+}
+
+#ifdef __AVX512BF16__
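+// GEMM for bf16 weights repacked 16 rows at a time (bf16_r16): the main loop handles two
+// 16-row tiles (32 rows) per iteration with _mm512_dpbf16_ps, followed by a 16-row tail.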
+template <int nrc_y>
+static void mul_mat_bf16_r16_bf16(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ GGML_ASSERT(nrc_x%16 == 0);
+ const ggml_bf16_t * y[nrc_y];
+ for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const ggml_bf16_t *)info.src1_row(iy);
+ for (int ix = 0; ix < nrc_x/32; ++ix) {
+ __m512 acc[2*nrc_y] = {};
+ __m512bh qx[8];
+ const ggml_bf16_t * b8_1 = (const ggml_bf16_t *)((const char *)vx + (32*ix+ 0)*bx);
+ const ggml_bf16_t * b8_2 = (const ggml_bf16_t *)((const char *)vx + (32*ix+16)*bx);
+ for (int ib = 0; ib < n/8; ++ib) {
+ qx[0] = (__m512bh)_mm512_loadu_si512((const __m512i *)b8_1+4*ib+0);
+ qx[1] = (__m512bh)_mm512_loadu_si512((const __m512i *)b8_1+4*ib+1);
+ qx[2] = (__m512bh)_mm512_loadu_si512((const __m512i *)b8_1+4*ib+2);
+ qx[3] = (__m512bh)_mm512_loadu_si512((const __m512i *)b8_1+4*ib+3);
+ qx[4] = (__m512bh)_mm512_loadu_si512((const __m512i *)b8_2+4*ib+0);
+ qx[5] = (__m512bh)_mm512_loadu_si512((const __m512i *)b8_2+4*ib+1);
+ qx[6] = (__m512bh)_mm512_loadu_si512((const __m512i *)b8_2+4*ib+2);
+ qx[7] = (__m512bh)_mm512_loadu_si512((const __m512i *)b8_2+4*ib+3);
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto y128 = _mm_loadu_si128((const __m128i*)y[iy]+ib);
+ //auto y = _mm512_broadcast_i32x4(y128);
+ auto y256 = MM256_SET_M128I(y128, y128);
+ auto y = _mm512_inserti32x8(_mm512_castsi256_si512(y256), y256, 1);
+ acc[2*iy+0] = _mm512_dpbf16_ps(acc[2*iy+0], qx[0], (__m512bh)_mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00)));
+ acc[2*iy+0] = _mm512_dpbf16_ps(acc[2*iy+0], qx[1], (__m512bh)_mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55)));
+ acc[2*iy+0] = _mm512_dpbf16_ps(acc[2*iy+0], qx[2], (__m512bh)_mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa)));
+ acc[2*iy+0] = _mm512_dpbf16_ps(acc[2*iy+0], qx[3], (__m512bh)_mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff)));
+ acc[2*iy+1] = _mm512_dpbf16_ps(acc[2*iy+1], qx[4], (__m512bh)_mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00)));
+ acc[2*iy+1] = _mm512_dpbf16_ps(acc[2*iy+1], qx[5], (__m512bh)_mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55)));
+ acc[2*iy+1] = _mm512_dpbf16_ps(acc[2*iy+1], qx[6], (__m512bh)_mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa)));
+ acc[2*iy+1] = _mm512_dpbf16_ps(acc[2*iy+1], qx[7], (__m512bh)_mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff)));
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ info.store(32*ix+ 0, iy, acc[2*iy+0]);
+ info.store(32*ix+16, iy, acc[2*iy+1]);
+ }
+ }
+ for (int ix = 32*(nrc_x/32); ix < nrc_x; ix += 16) {
+ __m512 acc[nrc_y] = {};
+ __m512bh qx[4];
+ const ggml_bf16_t * b8 = (const ggml_bf16_t *)((const char *)vx + (ix+0)*bx);
+ for (int ib = 0; ib < n/8; ++ib) {
+ qx[0] = (__m512bh)_mm512_loadu_si512((const __m512i *)b8+4*ib+0);
+ qx[1] = (__m512bh)_mm512_loadu_si512((const __m512i *)b8+4*ib+1);
+ qx[2] = (__m512bh)_mm512_loadu_si512((const __m512i *)b8+4*ib+2);
+ qx[3] = (__m512bh)_mm512_loadu_si512((const __m512i *)b8+4*ib+3);
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto y128 = _mm_loadu_si128((const __m128i*)y[iy]+ib);
+ auto y256 = MM256_SET_M128I(y128, y128);
+ auto y = _mm512_inserti32x8(_mm512_castsi256_si512(y256), y256, 1);
+ acc[iy] = _mm512_dpbf16_ps(acc[iy], qx[0], (__m512bh)_mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00)));
+ acc[iy] = _mm512_dpbf16_ps(acc[iy], qx[1], (__m512bh)_mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55)));
+ acc[iy] = _mm512_dpbf16_ps(acc[iy], qx[2], (__m512bh)_mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa)));
+ acc[iy] = _mm512_dpbf16_ps(acc[iy], qx[3], (__m512bh)_mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff)));
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ info.store(ix, iy, acc[iy]);
+ }
+ }
+}
+
+struct QFBaseBF16 {
+ constexpr static int k_step = 32;
+ using Data = __m512bh;
+ using Acc = __m512;
+ static inline Data load(const ggml_bf16_t * x) { return __m512bh(_mm512_loadu_si512((const __m512i *)x)); }
+ static inline Acc acc(Acc prev, Data y, Data x) {
+ return _mm512_dpbf16_ps(prev, y, x);
+ }
+ static inline Acc acc_first(const Data& y, const Data& x) {
+ return _mm512_dpbf16_ps(_mm512_setzero_ps(), y, x);
+ }
+ static inline float hsum(Acc acc) {
+ return _mm512_reduce_add_ps(acc);
+ }
+};
+template <int nrc_in> struct QFTBF16 final : public QFBaseBF16 {
+ constexpr static int nrc = nrc_in;
+ QFTBF16(const DataInfo& info) {
+ for (int iy = 0; iy < nrc; ++iy) y[iy] = (const ggml_bf16_t *)info.src1_row(iy);
+ }
+ QFTBF16(const char * cx, size_t bx) {
+ for (int iy = 0; iy < nrc; ++iy) y[iy] = (const ggml_bf16_t *)(cx + iy*bx);
+ }
+ IQK_ALWAYS_INLINE Data load1(int iy, int i) const { return load(y[iy] + k_step*i); }
+ const ggml_bf16_t * y[nrc];
+};
+
+template <int nrc_y, int nrc_x>
+IQK_NOINLINE void mul_mat_Qx_Qy_MxN(int n, const char * cx, size_t bx, int ix0, const DataInfo& info) {
+ int nb = n/QFBaseBF16::k_step;
+ QFTBF16<nrc_y> y(info);
+ QFTBF16<nrc_x> x(cx + ix0*bx, bx);
+ QFBaseBF16::Data xv[nrc_x];
+ QFBaseBF16::Acc acc[nrc_x*nrc_y];
+ auto yv = y.load1(0, 0);
+ for (int ix = 0; ix < nrc_x; ++ix) {
+ xv[ix] = x.load1(ix, 0);
+ acc[ix] = QFBaseBF16::acc_first(yv, xv[ix]);
+ }
+ for (int iy = 1; iy < nrc_y; ++iy) {
+ yv = y.load1(iy, 0);
+ for (int ix = 0; ix < nrc_x; ++ix) acc[nrc_x*iy + ix] = QFBaseBF16::acc_first(yv, xv[ix]);
+ }
+ for (int i = 1; i < nb; ++i) {
+ yv = y.load1(0, i);
+ for (int ix = 0; ix < nrc_x; ++ix) {
+ xv[ix] = x.load1(ix, i);
+ acc[ix] = QFBaseBF16::acc(acc[ix], yv, xv[ix]);
+ }
+ for (int iy = 1; iy < nrc_y; ++iy) {
+ yv = y.load1(iy, i);
+ for (int ix = 0; ix < nrc_x; ++ix) acc[nrc_x*iy + ix] = QFBaseBF16::acc(acc[nrc_x*iy + ix], yv, xv[ix]);
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) for (int ix = 0; ix < nrc_x; ++ix) info.store(ix0+ix, iy, QFBaseBF16::hsum(acc[nrc_x*iy+ix]));
+}
+
+template <int nrc_y>
+void mul_mat_fX_fY_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ constexpr int k_nx = nrc_y <= 2 ? 8 : 5;
+ const char * cx = (const char *)vx;
+ for (int ix = 0; ix < nrc_x/k_nx; ++ix) {
+ mul_mat_Qx_Qy_MxN<nrc_y, k_nx>(n, cx, bx, ix*k_nx, info);
+ }
+ int last_x = k_nx*(nrc_x/k_nx);
+ if (last_x == nrc_x) return;
+ int nx = nrc_x - last_x;
+ if constexpr (nrc_y <= 2) {
+ if (nx >= 4) {
+ mul_mat_Qx_Qy_MxN<nrc_y, 4>(n, cx, bx, last_x, info);
+ last_x += 4;
+ if (last_x == nrc_x) return;
+ nx = nrc_x - last_x;
+ }
+ }
+ switch (nx) {
+ case 1: mul_mat_Qx_Qy_MxN<nrc_y, 1>(n, cx, bx, last_x, info); break;
+ case 2: mul_mat_Qx_Qy_MxN<nrc_y, 2>(n, cx, bx, last_x, info); break;
+ case 3: mul_mat_Qx_Qy_MxN<nrc_y, 3>(n, cx, bx, last_x, info); break;
+ case 4: mul_mat_Qx_Qy_MxN<nrc_y, 4>(n, cx, bx, last_x, info); break;
+ }
+}
+#endif
+
+
+template <typename FloatX, typename FloatY>
+void set_mul_mat_f(std::array<mul_mat_t, IQK_MAX_NY>& funcs) {
+ for (auto& f : funcs) f = nullptr;
+ funcs[0] = mul_mat_fX_fY_T<1, FloatX, FloatY>;
+ funcs[1] = mul_mat_fX_fY_T<2, FloatX, FloatY>;
+ funcs[2] = mul_mat_fX_fY_T<3, FloatX, FloatY>;
+ funcs[3] = mul_mat_fX_fY_T<4, FloatX, FloatY>;
+ funcs[4] = mul_mat_fX_fY_T<5, FloatX, FloatY>;
+#ifndef __AVX512F__
+ funcs[5] = mul_mat_fX_fY_T<6, FloatX, FloatY>;
+#endif
+}
+
+#ifdef __AVX512BF16__
+void set_mul_mat_bf16(std::array<mul_mat_t, IQK_MAX_NY>& funcs) {
+ for (auto& f : funcs) f = nullptr;
+ funcs[0] = mul_mat_fX_fY_T<1>;
+ funcs[1] = mul_mat_fX_fY_T<2>;
+ funcs[2] = mul_mat_fX_fY_T<3>;
+ funcs[3] = mul_mat_fX_fY_T<4>;
+ funcs[4] = mul_mat_fX_fY_T<5>;
+}
+void set_mul_mat_bf16_r16(std::array<mul_mat_t, IQK_MAX_NY>& funcs) {
+ for (auto& f : funcs) f = nullptr;
+ funcs[0] = mul_mat_bf16_r16_bf16<1>;
+ funcs[1] = mul_mat_bf16_r16_bf16<2>;
+ funcs[2] = mul_mat_bf16_r16_bf16<3>;
+ funcs[3] = mul_mat_bf16_r16_bf16<4>;
+ funcs[4] = mul_mat_bf16_r16_bf16<5>;
+ funcs[5] = mul_mat_bf16_r16_bf16<6>;
+ funcs[6] = mul_mat_bf16_r16_bf16<7>;
+ funcs[7] = mul_mat_bf16_r16_bf16<8>;
+}
+#endif
+
+} // namespace
+
+bool iqk_set_kernels_float(int ne00, int typeA, int typeB, std::array<mul_mat_t, IQK_MAX_NY>& kernels) {
+
+ if (typeA == GGML_TYPE_BF16) {
+ if (ne00 % 32) return false;
+ switch (typeB) {
+#ifdef __AVX512BF16__
+ case GGML_TYPE_BF16: set_mul_mat_bf16(kernels); break;
+#else
+ case GGML_TYPE_BF16: set_mul_mat_f<ggml_bf16_t, ggml_bf16_t>(kernels); break;
+ case GGML_TYPE_F32: set_mul_mat_f<ggml_bf16_t, float>(kernels); break;
+#endif
+ default: return false;
+ }
+ return true;
+ }
+
+ if (typeA == GGML_TYPE_BF16_R16) {
+ if (ne00 % 16) return false;
+ switch (typeB) {
+#ifdef __AVX512BF16__
+ case GGML_TYPE_BF16: set_mul_mat_bf16_r16(kernels); break;
+#endif
+ default: return false;
+ }
+ return true;
+ }
+
+ if (typeA == GGML_TYPE_F16 || typeA == GGML_TYPE_F32) {
+ if (ne00 % 4) return false;
+ }
+ if (typeA == GGML_TYPE_F16) {
+ switch (typeB) {
+ case GGML_TYPE_F16: set_mul_mat_f<ggml_half, ggml_half>(kernels); break;
+ case GGML_TYPE_F32: set_mul_mat_f<ggml_half, float>(kernels); break;
+ default: return false;
+ }
+ return true;
+ }
+ if (typeA == GGML_TYPE_F32) {
+ switch (typeB) {
+ case GGML_TYPE_F16: set_mul_mat_f<float, ggml_half>(kernels); break;
+ case GGML_TYPE_F32: set_mul_mat_f<float, float>(kernels); break;
+ default: return false;
+ }
+ return true;
+ }
+
+ return false;
+
+}
+
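+// Default float GEMM used by the flash-attention path: the query rows (coming from info) are loaded
+// as f32 and the cached rows as f16, and the nq x k_step product is tiled into nrc_q x nrc_k blocks
+// (8x8 with HAVE_FANCY_SIMD, 4x8 otherwise) handled by mul_mat_Qx_Qy_MxN_fa4; the switch below takes
+// care of the remaining nq % nrc_q rows.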
+void iqk_gemm_default_floats(int D, int nq, const char * cx, size_t bx, DataInfo& info, int k_step) {
+ using q_float = float;
+#ifdef HAVE_FANCY_SIMD
+ constexpr int nrc_q = 8;
+ constexpr int nrc_k = 8;
+#else
+ // somewhat surprisingly, nrc_q = 4, nrc_k = 8 is better than nrc_q = 8, nrc_k = 4
+ constexpr int nrc_q = 4;
+ constexpr int nrc_k = 8;
+#endif
+ GGML_ASSERT(k_step%nrc_k == 0);
+ int qrem = nq - nrc_q*(nq/nrc_q);
+ for (int iq = 0; iq < nq/nrc_q; ++iq) {
+ for (int ik = 0; ik < k_step/nrc_k; ++ik) {
+ mul_mat_Qx_Qy_MxN_fa4<QFT<float, nrc_q>, QFT<ggml_half, nrc_k>>(D, cx, bx, ik*nrc_k, info);
+ }
+ info.cur_y += nrc_q;
+ }
+ if (qrem > 0) {
+ switch (qrem) {
+ case 1: {
+ for (int ik = 0; ik < k_step/nrc_k; ++ik) {
+ mul_mat_Qx_Qy_MxN_fa4<QFT<q_float, 1>, QFT<ggml_half, nrc_k>>(D, cx, bx, ik*nrc_k, info);
+ }
+ } break;
+ case 2: {
+ for (int ik = 0; ik < k_step/nrc_k; ++ik) {
+ mul_mat_Qx_Qy_MxN_fa4<QFT<q_float, 2>, QFT<ggml_half, nrc_k>>(D, cx, bx, ik*nrc_k, info);
+ }
+ } break;
+ case 3: {
+ for (int ik = 0; ik < k_step/nrc_k; ++ik) {
+ mul_mat_Qx_Qy_MxN_fa4<QFT<q_float, 3>, QFT<ggml_half, nrc_k>>(D, cx, bx, ik*nrc_k, info);
+ }
+ } break;
+#ifdef HAVE_FANCY_SIMD
+ case 4: {
+ for (int ik = 0; ik < k_step/nrc_k; ++ik) {
+ mul_mat_Qx_Qy_MxN_fa4<QFT<q_float, 4>, QFT<ggml_half, nrc_k>>(D, cx, bx, ik*nrc_k, info);
+ }
+ } break;
+ case 5: {
+ for (int ik = 0; ik < k_step/nrc_k; ++ik) {
+ mul_mat_Qx_Qy_MxN_fa4<QFT<q_float, 5>, QFT<ggml_half, nrc_k>>(D, cx, bx, ik*nrc_k, info);
+ }
+ } break;
+ case 6: {
+ for (int ik = 0; ik < k_step/nrc_k; ++ik) {
+ mul_mat_Qx_Qy_MxN_fa4<QFT<q_float, 6>, QFT<ggml_half, nrc_k>>(D, cx, bx, ik*nrc_k, info);
+ }
+ } break;
+ case 7: {
+ for (int ik = 0; ik < k_step/nrc_k; ++ik) {
+ mul_mat_Qx_Qy_MxN_fa4<QFT<q_float, 7>, QFT<ggml_half, nrc_k>>(D, cx, bx, ik*nrc_k, info);
+ }
+ } break;
+#endif
+ }
+ }
+}
+
+#else
+// ----------------------------------- __aarch64__ -----------------------------------------------
+
+namespace {
+
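+// NEON f16 helper: accumulation is done in fp16 with vfmaq_f16 (k_step = 8 values per register) and
+// the result is widened to f32 only in the final horizontal sum.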
+struct QF16Base {
+ constexpr static int k_step = 8;
+ using Data = float16x8_t;
+ using Acc = float16x8_t;
+ static inline Data load(const __fp16 * x) { return vld1q_f16(x); }
+ static inline Data load4(const __fp16 * x) { return vcombine_f16(vld1_f16(x), vdup_n_f16(0)); }
+ static inline Acc acc(Acc prev, const Data& y, const Data& x) {
+ return vfmaq_f16(prev, y, x);
+ }
+ static inline Acc acc_first(const Data& y, const Data& x) {
+ return vmulq_f16(y, x);
+ }
+ //constexpr static int k_step = 16;
+ //using Data = float16x8x2_t;
+ //static inline Data load(const __fp16 * x) { return vld1q_f16_x2(x); }
+ //static inline Acc acc(Acc prev, const Data& y, const Data& x) {
+ // return vfmaq_f16(vfmaq_f16(prev, y.val[0], x.val[0]), y.val[1], x.val[1]);
+ //}
+ //static inline Acc acc_first(const Data& y, const Data& x) {
+ // return vfmaq_f16(vmulq_f16(y.val[0], x.val[0]), y.val[1], x.val[1]);
+ //}
+ static inline float hsum(Acc acc) {
+ float32x4_t sum = vcvt_f32_f16(vadd_f16(vget_low_f16(acc), vget_high_f16(acc)));
+ return vaddvq_f32(sum);
+ }
+};
+template <int nrc> struct QF16 final : public QF16Base {
+ using Base = QF16Base;
+ constexpr static int nrc_y = nrc;
+ QF16(const DataInfo& info) {
+ for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const __fp16 *)info.src1_row(iy);
+ }
+ QF16(const char * cx, size_t bx) {
+ for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const __fp16 *)(cx + iy*bx);
+ }
+ IQK_ALWAYS_INLINE Data load1(int iy, int i) const { return load(y[iy] + k_step*i); }
+ IQK_ALWAYS_INLINE Data load_tail(int iy, int i) const { return load4(y[iy] + 4*i); }
+ IQK_ALWAYS_INLINE float16x8x4_t loadx(int iy, int i) const { return vld1q_f16_x4(y[iy] + 4*k_step*i); }
+ const __fp16 * y[nrc_y];
+};
+
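+// NEON bf16 helper: load() widens each bf16 value to 32 bits and shifts it into the upper half of a
+// f32 lane, which is an exact bf16 -> f32 conversion, so accumulation can use ordinary f32 FMAs.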
+struct QBF16Base {
+ constexpr static int k_step = 4;
+ using Data = float32x4_t;
+ using Acc = float32x4_t;
+ static inline Data load(const uint16_t * x) { return vreinterpretq_f32_u32(vshlq_n_u32(vmovl_u16(vld1_u16(x)), 16)); }
+ static inline Data load4(const uint16_t * x) { return load(x); }
+ static inline Acc acc(Acc prev, const Data& y, const Data& x) {
+ return vfmaq_f32(prev, y, x);
+ }
+ static inline Acc acc_first(const Data& y, const Data& x) {
+ return vmulq_f32(y, x);
+ }
+ static inline float hsum(Acc acc) { return vaddvq_f32(acc); }
+};
+template <int nrc> struct QBF16 final : public QBF16Base {
+ using Base = QBF16Base;
+ constexpr static int nrc_y = nrc;
+ QBF16(const DataInfo& info) {
+ for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const uint16_t *)info.src1_row(iy);
+ }
+ QBF16(const char * cx, size_t bx) {
+ for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const uint16_t *)(cx + iy*bx);
+ }
+ IQK_ALWAYS_INLINE Data load1(int iy, int i) const { return load(y[iy] + k_step*i); }
+ IQK_ALWAYS_INLINE Data load_tail(int iy, int i) const { return load(y[iy] + 4*i); }
+ const uint16_t * y[nrc_y];
+};
+
+struct QF32Base {
+ constexpr static int k_step = 4;
+ using Data = float32x4_t;
+ using Acc = float32x4_t;
+ static inline Data load(const float * x) { return vld1q_f32(x); }
+ static inline Data load4(const float * x) { return load(x); }
+ static inline Acc acc(Acc prev, const Data& y, const Data& x) { return vfmaq_f32(prev, y, x); }
+ static inline Acc acc_first(const Data& y, const Data& x) { return vmulq_f32(y, x); }
+ static inline float hsum(Acc acc) { return vaddvq_f32(acc); }
+};
+template <int nrc> struct QF32 final : public QF32Base {
+ using Base = QF32Base;
+ constexpr static int nrc_y = nrc;
+ QF32(const DataInfo& info) {
+ for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const float *)info.src1_row(iy);
+ }
+ QF32(const char * cx, size_t bx) {
+ for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const float *)(cx + iy*bx);
+ }
+ IQK_ALWAYS_INLINE Data load1(int iy, int i) const { return load(y[iy] + k_step*i); }
+ IQK_ALWAYS_INLINE Data load_tail(int iy, int i) const { return load(y[iy] + 4*i); }
+ const float * y[nrc_y];
+};
+
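+// Generic NEON NxN kernel templated on the loaders of the two operands. When n is not a multiple of
+// k_step, the is_multiple_of_k_step = false instantiation finishes the tail in groups of 4 values via
+// load_tail().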
+template <typename Qy, typename Qx, bool is_multiple_of_k_step>
+IQK_NOINLINE void mul_mat_Qx_Qy_NxN(int n, const char * cx, size_t bx, int ix0, const DataInfo& info) {
+ GGML_ASSERT(Qx::Base::k_step == Qy::Base::k_step);
+ int nb = n/Qx::Base::k_step;
+ Qy y(info);
+ Qx x(cx + ix0*bx, bx);
+ typename Qx::Base::Data xv[Qx::nrc_y];
+ typename Qx::Base::Acc acc[Qx::nrc_y*Qy::nrc_y];
+ auto yv = y.load1(0, 0);
+ for (int ix = 0; ix < Qx::nrc_y; ++ix) {
+ xv[ix] = x.load1(ix, 0);
+ acc[ix] = Qx::Base::acc_first(yv, xv[ix]);
+ }
+ for (int iy = 1; iy < Qy::nrc_y; ++iy) {
+ yv = y.load1(iy, 0);
+ for (int ix = 0; ix < Qx::nrc_y; ++ix) acc[Qx::nrc_y*iy + ix] = Qx::Base::acc_first(yv, xv[ix]);
+ }
+ for (int i = 1; i < nb; ++i) {
+ yv = y.load1(0, i);
+ for (int ix = 0; ix < Qx::nrc_y; ++ix) {
+ xv[ix] = x.load1(ix, i);
+ acc[ix] = Qx::Base::acc(acc[ix], yv, xv[ix]);
+ }
+ for (int iy = 1; iy < Qy::nrc_y; ++iy) {
+ yv = y.load1(iy, i);
+ for (int ix = 0; ix < Qx::nrc_y; ++ix) acc[Qx::nrc_y*iy + ix] = Qx::Base::acc(acc[Qx::nrc_y*iy + ix], yv, xv[ix]);
+ }
+ }
+ if constexpr (Qx::Base::k_step > 4 && !is_multiple_of_k_step) {
+ int nb4 = n/4;
+ for (int i = (Qx::Base::k_step/4)*nb; i < nb4; ++i) {
+ yv = y.load_tail(0, i);
+ for (int ix = 0; ix < Qx::nrc_y; ++ix) {
+ xv[ix] = x.load_tail(ix, i);
+ acc[ix] = Qx::Base::acc(acc[ix], yv, xv[ix]);
+ }
+ for (int iy = 1; iy < Qy::nrc_y; ++iy) {
+ yv = y.load_tail(iy, i);
+ for (int ix = 0; ix < Qx::nrc_y; ++ix) acc[Qx::nrc_y*iy + ix] = Qx::Base::acc(acc[Qx::nrc_y*iy + ix], yv, xv[ix]);
+ }
+ }
+ }
+ for (int iy = 0; iy < Qy::nrc_y; ++iy) for (int ix = 0; ix < Qx::nrc_y; ++ix) info.store(ix0+ix, iy, Qx::Base::hsum(acc[Qx::nrc_y*iy+ix]));
+}
+
+template <int nrc_y, int nrc_x, bool is_multiple_of_k_step>
+IQK_NOINLINE void mul_mat_f16_f16_NxN(int n, const char * cx, size_t bx, int ix0, const DataInfo& info) {
+ assert(n%QF16Base::k_step == 0);
+ int nb = n/QF16Base::k_step;
+ QF16<nrc_y> y(info);
+ QF16<nrc_x> x(cx + ix0*bx, bx);
+ QF16Base::Data xv[nrc_x];
+ QF16Base::Acc acc[nrc_x*nrc_y];
+ auto yv = y.load1(0, 0);
+ for (int ix = 0; ix < nrc_x; ++ix) {
+ xv[ix] = x.load1(ix, 0);
+ acc[ix] = QF16Base::acc_first(yv, xv[ix]);
+ }
+ for (int iy = 1; iy < nrc_y; ++iy) {
+ yv = y.load1(iy, 0);
+ for (int ix = 0; ix < nrc_x; ++ix) acc[nrc_x*iy + ix] = QF16Base::acc_first(yv, xv[ix]);
+ }
+ for (int i = 1; i < nb; ++i) {
+ yv = y.load1(0, i);
+ for (int ix = 0; ix < nrc_x; ++ix) {
+ xv[ix] = x.load1(ix, i);
+ acc[ix] = QF16Base::acc(acc[ix], yv, xv[ix]);
+ }
+ for (int iy = 1; iy < nrc_y; ++iy) {
+ yv = y.load1(iy, i);
+ for (int ix = 0; ix < nrc_x; ++ix) acc[nrc_x*iy + ix] = QF16Base::acc(acc[nrc_x*iy + ix], yv, xv[ix]);
+ }
+ }
+ if constexpr (!is_multiple_of_k_step) {
+ int nb4 = n/4;
+ for (int i = (QF16Base::k_step/4)*nb; i < nb4; ++i) {
+ yv = y.load_tail(0, i);
+ for (int ix = 0; ix < nrc_x; ++ix) {
+ xv[ix] = x.load_tail(ix, i);
+ acc[ix] = QF16Base::acc(acc[ix], yv, xv[ix]);
+ }
+ for (int iy = 1; iy < nrc_y; ++iy) {
+ yv = y.load_tail(iy, i);
+ for (int ix = 0; ix < nrc_x; ++ix) acc[nrc_x*iy + ix] = QF16Base::acc(acc[nrc_x*iy + ix], yv, xv[ix]);
+ }
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) for (int ix = 0; ix < nrc_x; ++ix) info.store(ix0+ix, iy, QF16Base::hsum(acc[nrc_x*iy+ix]));
+}
+
+template <typename Qy, template<int> typename Qx>
+void mul_mat_Qx_Qy_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ GGML_ASSERT(n%4 == 0);
+ constexpr int k_nx = 5;
+ const char * cx = (const char *)vx;
+ if (n%Qx<k_nx>::Base::k_step == 0) {
+ for (int ix = 0; ix < nrc_x/k_nx; ++ix) {
+ mul_mat_Qx_Qy_NxN<Qy, Qx<k_nx>, true>(n, cx, bx, ix*k_nx, info);
+ }
+ int last_x = k_nx*(nrc_x/k_nx);
+ if (last_x == nrc_x) return;
+ int nx = nrc_x - last_x;
+ switch (nx) {
+ case 1: mul_mat_Qx_Qy_NxN<Qy, Qx<1>, true>(n, cx, bx, last_x, info); break;
+ case 2: mul_mat_Qx_Qy_NxN<Qy, Qx<2>, true>(n, cx, bx, last_x, info); break;
+ case 3: mul_mat_Qx_Qy_NxN<Qy, Qx<3>, true>(n, cx, bx, last_x, info); break;
+ case 4: mul_mat_Qx_Qy_NxN<Qy, Qx<4>, true>(n, cx, bx, last_x, info); break;
+ }
+ } else {
+ for (int ix = 0; ix < nrc_x/k_nx; ++ix) {
+ mul_mat_Qx_Qy_NxN<Qy, Qx<k_nx>, false>(n, cx, bx, ix*k_nx, info);
+ }
+ int last_x = k_nx*(nrc_x/k_nx);
+ if (last_x == nrc_x) return;
+ int nx = nrc_x - last_x;
+ switch (nx) {
+ case 1: mul_mat_Qx_Qy_NxN<Qy, Qx<1>, false>(n, cx, bx, last_x, info); break;
+ case 2: mul_mat_Qx_Qy_NxN<Qy, Qx<2>, false>(n, cx, bx, last_x, info); break;
+ case 3: mul_mat_Qx_Qy_NxN<Qy, Qx<3>, false>(n, cx, bx, last_x, info); break;
+ case 4: mul_mat_Qx_Qy_NxN<Qy, Qx<4>, false>(n, cx, bx, last_x, info); break;
+ }
+ }
+}
+
+template <int nrc_y>
+void mul_mat_f16_f16_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ GGML_ASSERT(n%4 == 0);
+ constexpr int k_nx = 5;
+ const char * cx = (const char *)vx;
+ if (n%QF16Base::k_step == 0) {
+ for (int ix = 0; ix < nrc_x/k_nx; ++ix) {
+ mul_mat_f16_f16_NxN<nrc_y, k_nx, true>(n, cx, bx, ix*k_nx, info);
+ }
+ int last_x = k_nx*(nrc_x/k_nx);
+ if (last_x == nrc_x) return;
+ int nx = nrc_x - last_x;
+ switch (nx) {
+ case 1: mul_mat_f16_f16_NxN<nrc_y, 1, true>(n, cx, bx, last_x, info); break;
+ case 2: mul_mat_f16_f16_NxN<nrc_y, 2, true>(n, cx, bx, last_x, info); break;
+ case 3: mul_mat_f16_f16_NxN<nrc_y, 3, true>(n, cx, bx, last_x, info); break;
+ case 4: mul_mat_f16_f16_NxN<nrc_y, 4, true>(n, cx, bx, last_x, info); break;
+ }
+ } else {
+ for (int ix = 0; ix < nrc_x/k_nx; ++ix) {
+ mul_mat_f16_f16_NxN<nrc_y, k_nx, false>(n, cx, bx, ix*k_nx, info);
+ }
+ int last_x = k_nx*(nrc_x/k_nx);
+ if (last_x == nrc_x) return;
+ int nx = nrc_x - last_x;
+ switch (nx) {
+ case 1: mul_mat_f16_f16_NxN<nrc_y, 1, false>(n, cx, bx, last_x, info); break;
+ case 2: mul_mat_f16_f16_NxN<nrc_y, 2, false>(n, cx, bx, last_x, info); break;
+ case 3: mul_mat_f16_f16_NxN<nrc_y, 3, false>(n, cx, bx, last_x, info); break;
+ case 4: mul_mat_f16_f16_NxN<nrc_y, 4, false>(n, cx, bx, last_x, info); break;
+ }
+ }
+}
+
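+// Single right-hand-row kernel: the right-hand row is consumed four k_step slices at a time via
+// vld1q_f16_x4, and four fp16 accumulators per left-hand row (presumably to hide the FMA latency)
+// are summed only at the end.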
+template <int nrc_x, bool is_multiple_of_k_step>
+IQK_NOINLINE void mul_mat_f16_f16_Nx1(int n, const char * cx, size_t bx, int ix0, const DataInfo& info) {
+ assert(n%QF16Base::k_step == 0);
+ int nb = n/QF16Base::k_step;
+ QF16<1> y(info);
+ QF16<nrc_x> x(cx + ix0*bx, bx);
+ QF16Base::Acc acc[4*nrc_x];
+ auto yv = y.loadx(0, 0);
+ for (int ix = 0; ix < nrc_x; ++ix) {
+ for (int k = 0; k < 4; ++k) {
+ auto xv = x.load1(ix, k);
+ acc[4*ix+k] = QF16Base::acc_first(yv.val[k], xv);
+ }
+ }
+ for (int i = 1; i < nb/4; ++i) {
+ yv = y.loadx(0, i);
+ for (int ix = 0; ix < nrc_x; ++ix) {
+ for (int k = 0; k < 4; ++k) {
+ auto xv = x.load1(ix, 4*i+k);
+ acc[4*ix+k] = QF16Base::acc(acc[4*ix+k], yv.val[k], xv);
+ }
+ }
+ }
+ for (int i = 4*(nb/4); i < nb; ++i) {
+ auto yv1 = y.load1(0, i);
+ for (int ix = 0; ix < nrc_x; ++ix) {
+ auto xv1 = x.load1(ix, i);
+ acc[4*ix] = QF16Base::acc(acc[4*ix], yv1, xv1);
+ }
+ }
+ if constexpr (!is_multiple_of_k_step) {
+ int nb4 = n/4;
+ for (int i = (QF16Base::k_step/4)*nb; i < nb4; ++i) {
+ auto yv1 = y.load_tail(0, i);
+ for (int ix = 0; ix < nrc_x; ++ix) {
+ auto xv1 = x.load_tail(ix, i);
+ acc[4*ix] = QF16Base::acc(acc[4*ix], yv1, xv1);
+ }
+ }
+ }
+ for (int ix = 0; ix < nrc_x; ++ix) {
+ auto v1 = vaddq_f16(acc[4*ix+0], acc[4*ix+1]);
+ auto v2 = vaddq_f16(acc[4*ix+2], acc[4*ix+3]);
+ info.store(ix0+ix, 0, QF16Base::hsum(vaddq_f16(v1, v2)));
+ }
+}
+
+// At least on my M2-Max the version further below, which does the multiplication row-by-row, is faster
+// than this variant that processes two rows at a time, so the latter is kept commented out for now.
+//void mul_mat_f16_f16_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+// GGML_ASSERT(n%4 == 0);
+// constexpr int k_nx = 2;
+// const char * cx = (const char *)vx;
+// if (n%QF16Base::k_step == 0) {
+// for (int ix = 0; ix < nrc_x/k_nx; ++ix) {
+// mul_mat_f16_f16_Nx1<k_nx, true>(n, cx, bx, ix*k_nx, info);
+// }
+// int last_x = k_nx*(nrc_x/k_nx);
+// if (last_x == nrc_x) return;
+// int nx = nrc_x - last_x;
+// switch (nx) {
+// case 1: mul_mat_f16_f16_Nx1<1, true>(n, cx, bx, last_x, info); break;
+// //case 2: mul_mat_f16_f16_Nx1<2, true>(n, cx, bx, last_x, info); break;
+// //case 3: mul_mat_f16_f16_Nx1<3, true>(n, cx, bx, last_x, info); break;
+// }
+// } else {
+// for (int ix = 0; ix < nrc_x/k_nx; ++ix) {
+// mul_mat_f16_f16_Nx1<k_nx, false>(n, cx, bx, ix*k_nx, info);
+// }
+// int last_x = k_nx*(nrc_x/k_nx);
+// if (last_x == nrc_x) return;
+// int nx = nrc_x - last_x;
+// switch (nx) {
+// case 1: mul_mat_f16_f16_Nx1<1, false>(n, cx, bx, last_x, info); break;
+// //case 2: mul_mat_f16_f16_Nx1<2, false>(n, cx, bx, last_x, info); break;
+// //case 3: mul_mat_f16_f16_Nx1<3, false>(n, cx, bx, last_x, info); break;
+// }
+// }
+//}
+
+void mul_mat_f16_f16_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ GGML_ASSERT(n%4 == 0);
+ const char * cx = (const char *)vx;
+ if (n%QF16Base::k_step == 0) {
+ for (int ix = 0; ix < nrc_x; ++ix) {
+ mul_mat_f16_f16_Nx1<1, true>(n, cx, bx, ix, info);
+ }
+ } else {
+ for (int ix = 0; ix < nrc_x; ++ix) {
+ mul_mat_f16_f16_Nx1<1, false>(n, cx, bx, ix, info);
+ }
+ }
+}
+
+} // namespace
+
+bool iqk_set_kernels_float(int ne00, int typeA, int typeB, std::array<mul_mat_t, IQK_MAX_NY>& kernels) {
+
+ if (ne00%4 == 0) {
+
+ if (typeA == GGML_TYPE_F16 && typeB == GGML_TYPE_F16) {
+ for (auto& f : kernels) f = nullptr;
+ kernels[0] = mul_mat_f16_f16_1;
+ kernels[1] = mul_mat_f16_f16_T<2>;
+ kernels[2] = mul_mat_f16_f16_T<3>;
+ kernels[3] = mul_mat_f16_f16_T<4>;
+ kernels[4] = mul_mat_f16_f16_T<5>;
+ return true;
+ }
+ else if (typeA == GGML_TYPE_BF16 && typeB == GGML_TYPE_F32) {
+ for (auto& f : kernels) f = nullptr;
+ kernels[0] = mul_mat_Qx_Qy_T<QF32<1>, QBF16>;
+ kernels[1] = mul_mat_Qx_Qy_T<QF32<2>, QBF16>;
+ kernels[2] = mul_mat_Qx_Qy_T<QF32<3>, QBF16>;
+ kernels[3] = mul_mat_Qx_Qy_T<QF32<4>, QBF16>;
+ kernels[4] = mul_mat_Qx_Qy_T<QF32<5>, QBF16>;
+ return true;
+ }
+
+ }
+
+ return false;
+
+}
+
+namespace {
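+// Flash-attention helper: processes nrc_q query rows at a time with the f16 NxN kernel, tiling the
+// k_step dimension in groups of nrc_k = 6 and letting the switch handle the remainder; info.cur_y is
+// advanced after each group of query rows.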
+template <int nrc_q>
+inline void mm_helper(int D, int nq, const char * cx, size_t bx, DataInfo& info, int k_step) {
+ constexpr int nrc_k = 6;
+ int krem = k_step - nrc_k*(k_step/nrc_k);
+ for (int iq = 0; iq < nq/nrc_q; ++iq) {
+ for (int ik = 0; ik < k_step/nrc_k; ++ik) {
+ mul_mat_f16_f16_NxN<nrc_q, nrc_k, true>(D, cx, bx, ik*nrc_k, info);
+ }
+ if (krem > 0) {
+ switch (krem) {
+ case 1: mul_mat_f16_f16_NxN<nrc_q, 1, true>(D, cx, bx, k_step - krem, info); break;
+ case 2: mul_mat_f16_f16_NxN<nrc_q, 2, true>(D, cx, bx, k_step - krem, info); break;
+ case 3: mul_mat_f16_f16_NxN<nrc_q, 3, true>(D, cx, bx, k_step - krem, info); break;
+ case 4: mul_mat_f16_f16_NxN<nrc_q, 4, true>(D, cx, bx, k_step - krem, info); break;
+ default: mul_mat_f16_f16_NxN<nrc_q, 5, true>(D, cx, bx, k_step - krem, info); break;
+ }
+ }
+ info.cur_y += nrc_q;
+ }
+}
+} // namespace
+
+void iqk_gemm_default_floats(int D, int nq, const char * cx, size_t bx, DataInfo& info, int k_step) {
+ constexpr int nrc_q = 4;
+ mm_helper<nrc_q>(D, nq, cx, bx, info, k_step);
+ if (int qrem = nq - nrc_q*(nq/nrc_q); qrem > 0) {
+ switch (qrem) {
+            case 1: mm_helper<1>(D, qrem, cx, bx, info, k_step); break;
+            case 2: mm_helper<2>(D, qrem, cx, bx, info, k_step); break;
+            default: mm_helper<3>(D, qrem, cx, bx, info, k_step); break;
+ }
+ }
+}
+
+#endif
+
+#endif
diff --git a/ggml/src/iqk/iqk_gemm_floats.h b/ggml/src/iqk/iqk_gemm_floats.h
new file mode 100644
index 00000000..aba514f6
--- /dev/null
+++ b/ggml/src/iqk/iqk_gemm_floats.h
@@ -0,0 +1,13 @@
+#pragma once
+
+#include "iqk_common.h"
+
+#ifdef IQK_IMPLEMENT
+
+#include <array>
+
+bool iqk_set_kernels_float(int ne00, int typeA, int typeB, std::array<mul_mat_t, IQK_MAX_NY>& kernels);
+
+void iqk_gemm_default_floats(int D, int nq, const char * vx, size_t bx, DataInfo& info, int k_step);
+
+#endif
diff --git a/ggml/src/iqk/iqk_gemm_iqk_quants.cpp b/ggml/src/iqk/iqk_gemm_iqk_quants.cpp
new file mode 100644
index 00000000..15c963ca
--- /dev/null
+++ b/ggml/src/iqk/iqk_gemm_iqk_quants.cpp
@@ -0,0 +1,3289 @@
+#include "iqk_gemm_iqk_quants.h"
+
+#ifdef IQK_IMPLEMENT
+
+#include "ggml-impl.h"
+
+#define GGML_COMMON_IMPL_C
+#include "ggml-common.h"
+
+#ifdef __x86_64__
+
+namespace {
+
+#ifdef HAVE_FANCY_SIMD
+
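+// Scale handling shared by the IQK dequantizers: the per-block 'extra' bits select whether eshift is
+// added to the minimum value, the resulting scale*min products are folded into the accm accumulators
+// via the q8 block sums, and the 8-bit scales are expanded into the 16-bit layout expected by the
+// _mm512_dpwssd_epi32 path.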
+struct IQXKScales {
+ IQXKScales(uint8_t shift, int8_t min_val) : eshift(_mm256_set1_epi16(shift)), min(_mm256_set1_epi16(min_val)) {}
+ template <typename Q8>
+ inline void process(int i, float d, uint16_t extra, __m128i scales8, const Q8& q8, __m256 * accm, __m512i * scales) const {
+ auto scales16 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales8, scale_shuffle));
+ scales16 = _mm256_mullo_epi16(scales16, _mm256_mask_add_epi16(min, extra, min, eshift));
+ for (int iy = 0; iy < Q8::nrc_y; ++iy) {
+ const __m256i prod = _mm256_madd_epi16(scales16, q8.load_bsums(iy, i));
+ accm[iy] = _mm256_fmadd_ps(_mm256_set1_ps(d * q8.scale(iy, i)), _mm256_cvtepi32_ps(prod), accm[iy]);
+ }
+ scales16 = MM256_SET_M128I(scales8, scales8);
+ scales[0] = _mm512_cvtepi8_epi16(_mm256_shuffle_epi8(scales16, shuffle1));
+ scales[1] = _mm512_cvtepi8_epi16(_mm256_shuffle_epi8(scales16, shuffle2));
+ }
+ const __m256i eshift;
+ const __m256i min;
+ const __m128i scale_shuffle = _mm_set_epi32(0x0f070e06, 0x0d050c04, 0x0b030a02, 0x09010800);
+ const __m128i emask = _mm_set_epi32(0x80804040, 0x20201010, 0x08080404, 0x02020101);
+ const __m128i eshuffle = _mm_set_epi32(0x0f0d0b09, 0x07050301, 0x0e0c0a08, 0x06040200);
+ const __m256i shuffle1 = _mm256_set_epi64x(0x0b0b0b0b09090909, 0x0303030301010101, 0x0a0a0a0a08080808, 0x0202020200000000);
+ const __m256i shuffle2 = _mm256_set_epi64x(0x0f0f0f0f0d0d0d0d, 0x0707070705050505, 0x0e0e0e0e0c0c0c0c, 0x0606060604040404);
+};
+
+struct IQXKScales2 {
+ IQXKScales2(uint8_t shift, int8_t min_val) : eshift(_mm256_set1_epi16(shift)), min(_mm256_set1_epi16(min_val)) {}
+ template <typename Q8>
+ inline void process(int i, float d, uint16_t extra, __m128i scales8, const Q8& q8, __m256 * accm, __m512i * scales) const {
+ process(i, d, extra, _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales8, scale_shuffle)), q8, accm, scales);
+ }
+ template <typename Q8>
+ inline void process(int i, float d, uint16_t extra, __m256i scales16, const Q8& q8, __m256 * accm, __m512i * scales) const {
+ auto scales_s = _mm256_mullo_epi16(scales16, _mm256_mask_add_epi16(min, extra, min, eshift));
+ for (int iy = 0; iy < Q8::nrc_y; ++iy) {
+ const __m256i prod = _mm256_madd_epi16(scales_s, q8.load_bsums(iy, i));
+ accm[iy] = _mm256_fmadd_ps(_mm256_set1_ps(d * q8.scale(iy, i)), _mm256_cvtepi32_ps(prod), accm[iy]);
+ }
+ auto aux_1 = MM256_SET_M128I(_mm256_castsi256_si128(scales16), _mm256_castsi256_si128(scales16));
+ auto aux_2 = MM256_SET_M128I(_mm256_extracti128_si256(scales16, 1), _mm256_extracti128_si256(scales16, 1));
+ auto scales16_1 = _mm512_inserti32x8(_mm512_castsi256_si512(aux_1), aux_1, 1);
+ auto scales16_2 = _mm512_inserti32x8(_mm512_castsi256_si512(aux_2), aux_2, 1);
+ scales[0] = _mm512_shuffle_epi8(scales16_1, shuffles[0]);
+ scales[1] = _mm512_shuffle_epi8(scales16_1, shuffles[1]);
+ scales[2] = _mm512_shuffle_epi8(scales16_2, shuffles[0]);
+ scales[3] = _mm512_shuffle_epi8(scales16_2, shuffles[1]);
+ }
+ const __m256i eshift;
+ const __m256i min;
+ const __m128i scale_shuffle = _mm_set_epi32(0x0f070e06, 0x0d050c04, 0x0b030a02, 0x09010800);
+ const __m128i emask = _mm_set_epi32(0x80804040, 0x20201010, 0x08080404, 0x02020101);
+ const __m128i eshuffle = _mm_set_epi32(0x0f0d0b09, 0x07050301, 0x0e0c0a08, 0x06040200);
+ const __m512i shuffles[2] = {
+ _mm512_inserti32x4(_mm512_inserti32x4(_mm512_inserti32x4(_mm512_inserti32x4(_mm512_setzero_si512(),
+ _mm_set1_epi16(0x0100), 0), _mm_set1_epi16(0x0302), 1), _mm_set1_epi16(0x0504), 2), _mm_set1_epi16(0x0706), 3),
+ _mm512_inserti32x4(_mm512_inserti32x4(_mm512_inserti32x4(_mm512_inserti32x4(_mm512_setzero_si512(),
+ _mm_set1_epi16(0x0908), 0), _mm_set1_epi16(0x0b0a), 1), _mm_set1_epi16(0x0d0c), 2), _mm_set1_epi16(0x0f0e), 3)
+ };
+};
+
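+// IQ2_KS dequantizer: the 2-bit quants are unpacked by Q2Bits and mapped onto the non-linear grid
+// kvalues_iq2nl with byte shuffles; the block scales are reconstructed from the packed nibbles plus
+// offset bits taken from 'extra', and the per-block shift (also signalled by 'extra') is folded into
+// the mins.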
+struct DequantizerIQ2KS final : public BaseDequantizer<block_iq2_ks, true, true> {
+ DequantizerIQ2KS(const void * vx, size_t bx) : BaseDequantizer(vx, bx), values(load_values()) {}
+ template <typename Q8>
+ inline void compute_block(int i, const Q8& q8, __m512 * acc) {
+ prepare(x[i].qs);
+ auto scales128 = make_scales(x[i].scales, x[i].extra >> 8);
+ auto shifts = _mm_and_si128(_mm_cmpeq_epi8(_mm_and_si128(_mm_set1_epi8(x[i].extra), hmask), hmask), m5);
+ auto mins128 = _mm_mullo_epi16(scales128, _mm_cvtepi8_epi16(_mm_add_epi8(m32, shifts)));
+ auto mins = MM256_SET_M128I(_mm_shuffle_epi8(mins128, s8k.shuffles[1]), _mm_shuffle_epi8(mins128, s8k.shuffles[0]));
+ auto scales256 = MM256_SET_M128I(scales128, scales128);
+ auto all_scales = _mm512_inserti32x8(_mm512_castsi256_si512(scales256), scales256, 1);
+ __m512i scales[4];
+ for (int k = 0; k < 4; ++k) scales[k] = _mm512_shuffle_epi8(all_scales, shuffles[k]);
+ for (int iy = 0; iy < Q8::nrc_y; ++iy) {
+ auto q8s = q8.load_bsums(iy, i);
+ auto prod = _mm256_madd_epi16(mins, q8s);
+ auto sumi = _mm512_inserti32x8(_mm512_setzero_si512(), prod, 0);
+ for (int k = 0; k < 4; ++k) {
+ auto p = _mm512_maddubs_epi16(bits.values[k], q8.load_quants64(iy, i, k));
+ sumi = _mm512_dpwssd_epi32(sumi, p, scales[k]);
+ }
+ acc[iy] = _mm512_fmadd_ps(_mm512_set1_ps(d*q8.scale(iy, i)), _mm512_cvtepi32_ps(sumi), acc[iy]);
+ }
+ }
+ inline void prepare(const uint8_t * q2) {
+ bits.prepare(q2);
+ bits.values[0] = _mm512_shuffle_epi8(values, bits.values[0]);
+ bits.values[1] = _mm512_shuffle_epi8(values, bits.values[1]);
+ bits.values[2] = _mm512_shuffle_epi8(values, bits.values[2]);
+ bits.values[3] = _mm512_shuffle_epi8(values, bits.values[3]);
+ }
+ static inline __m512i load_values() {
+ static const uint8_t kvalues_iq2nl[16] = {1, 19, 33, 49, 0, 0, 0, 0, 6, 24, 38, 54, 0, 0, 0, 0};
+ auto val128 = _mm_loadu_si128((const __m128i *)kvalues_iq2nl);
+ auto val256 = MM256_SET_M128I(val128, val128);
+ return _mm512_inserti32x8(_mm512_castsi256_si512(val256), val256, 1);
+ }
+ inline __m128i make_scales(const uint8_t * scales_l, uint8_t scales_h) const {
+ const uint16_t * scales = (const uint16_t *)scales_l;
+ uint32_t aux32 = scales[0] | (uint32_t(scales[1]) << 16);
+ auto scl = _mm_srlv_epi32(_mm_set1_epi32(aux32), shift);
+ scl = _mm_and_si128(_mm_shuffle_epi8(scl, shuffle), _mm_set1_epi8(0xf));
+ auto sch = _mm_set1_epi8(scales_h);
+ sch = _mm_and_si128(_mm_cmpeq_epi8(_mm_and_si128(sch, hmask), _mm_setzero_si128()), m16);
+ return _mm_cvtepi8_epi16(_mm_add_epi8(scl, sch));
+ }
+ Q2Bits bits;
+ Scales8KBase s8k;
+
+ const __m512i values;
+ const __m128i m16 = _mm_set1_epi8(-16);
+ const __m128i m5 = _mm_set1_epi8(5);
+ const __m128i m32 = _mm_set1_epi8(-32);
+ const __m128i hmask = _mm_set1_epi64x(0x8040201008040201);
+ const __m128i shuffle = _mm_set1_epi64x(0x0703060205010400);
+ const __m128i shift = _mm_set_epi32(0, 0, 4, 0);
+ const __m512i shuffles[4] = {
+ _mm512_inserti32x8(_mm512_set1_epi16(0x0100), _mm256_set1_epi16(0x0302), 1),
+ _mm512_inserti32x8(_mm512_set1_epi16(0x0504), _mm256_set1_epi16(0x0706), 1),
+ _mm512_inserti32x8(_mm512_set1_epi16(0x0908), _mm256_set1_epi16(0x0b0a), 1),
+ _mm512_inserti32x8(_mm512_set1_epi16(0x0d0c), _mm256_set1_epi16(0x0f0e), 1),
+ };
+};
+
+struct DequantizerIQ2K final : public BaseDequantizer<block_iq2_k> {
+ DequantizerIQ2K(const void * vx, size_t bx) : BaseDequantizer(vx, bx), iqxk(IQXKScales(5, -32)), values(load_values()) {}
+ template <typename Q8>
+ inline void new_block(int i, const Q8& q8, __m256 * accm, __m512i * scales) {
+ d = GGML_FP16_TO_FP32(x[i].d);
+ prepare(x[i].qs);
+ iqxk.process(i, d, x[i].extra, make_scales(x[i].scales), q8, accm, scales);
+ }
+ inline void prepare(const uint8_t * q2) {
+ bits.prepare(q2);
+ bits.values[0] = _mm512_shuffle_epi8(values, bits.values[0]);
+ bits.values[1] = _mm512_shuffle_epi8(values, bits.values[1]);
+ bits.values[2] = _mm512_shuffle_epi8(values, bits.values[2]);
+ bits.values[3] = _mm512_shuffle_epi8(values, bits.values[3]);
+ }
+ static inline __m512i load_values() {
+ static const uint8_t kvalues_iq2nl[16] = {1, 19, 33, 49, 0, 0, 0, 0, 6, 24, 38, 54, 0, 0, 0, 0};
+ auto val128 = _mm_loadu_si128((const __m128i *)kvalues_iq2nl);
+ auto val256 = MM256_SET_M128I(val128, val128);
+ return _mm512_inserti32x8(_mm512_castsi256_si512(val256), val256, 1);
+ }
+ inline __m128i make_scales(const uint8_t * scales_l) const {
+ uint64_t aux64; std::memcpy(&aux64, scales_l, 8);
+ auto scl = _mm_and_si128(_mm_set_epi64x(aux64 >> 4, aux64), _mm_set1_epi8(0xf));
+ return _mm_add_epi8(scl, m8);
+ }
+ Q2Bits bits;
+ const IQXKScales iqxk;
+
+ const __m512i values;
+ const __m128i m8 = _mm_set1_epi8(-8);
+};
+
+struct DequantizerIQ3K final : public BaseDequantizer<block_iq3_k> {
+ DequantizerIQ3K(const void * vx, size_t bx) : BaseDequantizer(vx, bx), iqxk(4, -64), values(load_values()) {}
+ template <typename Q8>
+ inline void new_block(int i, const Q8& q8, __m256 * accm, __m512i * scales) {
+ d = GGML_FP16_TO_FP32(x[i].d);
+ prepare(x[i].qs, x[i].qh);
+ iqxk.process(i, d, x[i].extra, make_scales(x[i].scales_h, x[i].scales_l), q8, accm, scales);
+ }
+ inline void prepare(const uint8_t * q2, const uint8_t * qh) {
+ bits.prepare(q2);
+ auto h256 = _mm256_loadu_si256((const __m256i *)qh);
+ auto hbits = _mm512_inserti32x8(_mm512_castsi256_si512(h256), _mm256_srli_epi16(h256, 1), 1);
+ bits.values[0] = _mm512_or_si512(bits.values[0], _mm512_and_si512(_mm512_slli_epi16(hbits, 2), hmask));
+ bits.values[1] = _mm512_or_si512(bits.values[1], _mm512_and_si512(hbits, hmask));
+ bits.values[2] = _mm512_or_si512(bits.values[2], _mm512_and_si512(_mm512_srli_epi16(hbits, 2), hmask));
+ bits.values[3] = _mm512_or_si512(bits.values[3], _mm512_and_si512(_mm512_srli_epi16(hbits, 4), hmask));
+ bits.values[0] = _mm512_shuffle_epi8(values, bits.values[0]);
+ bits.values[1] = _mm512_shuffle_epi8(values, bits.values[1]);
+ bits.values[2] = _mm512_shuffle_epi8(values, bits.values[2]);
+ bits.values[3] = _mm512_shuffle_epi8(values, bits.values[3]);
+ }
+ static inline __m512i load_values() {
+ static const uint8_t kvalues_iq3nl[16] = {1, 24, 41, 54, 65, 77, 92, 111, 5, 28, 45, 58, 69, 81, 96, 115};
+ auto val128 = _mm_loadu_si128((const __m128i *)kvalues_iq3nl);
+ auto val256 = MM256_SET_M128I(val128, val128);
+ return _mm512_inserti32x8(_mm512_castsi256_si512(val256), val256, 1);
+ }
+ inline __m128i make_scales(uint16_t signs, const uint8_t * scales_l) const {
+ uint64_t aux64; std::memcpy(&aux64, scales_l, 8);
+ auto scl = _mm_and_si128(_mm_set_epi64x(aux64 >> 4, aux64), _mm_set1_epi8(0xf));
+ scl = _mm_add_epi8(_mm_slli_epi16(scl, 1), m1);
+ const __m128i sc_signs = _mm_cmpeq_epi8(_mm_and_si128(_mm_set1_epi16(signs), sign_mask), sign_mask);
+ const __m128i sch = _mm_shuffle_epi8(_mm_or_si128(sc_signs, _mm_set1_epi8(1)), hshuff);
+ return _mm_sign_epi8(scl, sch);
+ }
+ Q2Bits bits;
+ const IQXKScales2 iqxk;
+
+ const __m512i values;
+ const __m512i hmask = _mm512_set1_epi8(4);
+ const __m128i m1 = _mm_set1_epi8(1);
+ const __m128i sign_mask = _mm_set_epi64x(0x8080404020201010, 0x0808040402020101);
+ const __m128i hshuff = _mm_loadu_si128((const __m128i*)k_shuff);
+ constexpr static uint8_t k_shuff[16] = {0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15};
+};
+
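+// IQ4_KSS dequantizer: the block scales are packed into the low bit of every 16-bit group of quant
+// data, so new_block() first collects those bits with _mm512_cmpeq_epi16_mask, decodes the remaining
+// bits (the xor with the value shifted right by 1), and then remaps the 4-bit quants through the
+// iq4nl lookup table.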
+struct DequantizerIQ4KSS final : public BaseDequantizer<block_iq4_kss, true> {
+ DequantizerIQ4KSS(const void * vx, size_t bx) : BaseDequantizer(vx, bx), values(load_iq4nl_values_512()) {}
+ template <typename Q8>
+ inline void new_block(int i, const Q8& q8, __m256 * accm, __m512i * scales) {
+ uint32_t aux32[2];
+ auto b1 = _mm512_loadu_si512((const __m512i *)x[i].qs + 0);
+ auto b2 = _mm512_loadu_si512((const __m512i *)x[i].qs + 1);
+ auto bs1 = _mm512_and_si512(b1, mask15);
+ bs1 = _mm512_xor_si512(bs1, _mm512_srli_epi16(bs1, 1));
+ auto bs2 = _mm512_and_si512(b2, mask15);
+ bs2 = _mm512_xor_si512(bs2, _mm512_srli_epi16(bs2, 1));
+ bits.values[0] = _mm512_and_si512(bs1, bits.ml);
+ bits.values[1] = _mm512_and_si512(_mm512_srli_epi16(bs1, 4), bits.ml);
+ bits.values[2] = _mm512_and_si512(bs2, bits.ml);
+ bits.values[3] = _mm512_and_si512(_mm512_srli_epi16(bs2, 4), bits.ml);
+ auto tmp = _mm512_permutex2var_epi64(bits.values[0], permute1, bits.values[1]);
+ bits.values[1] = _mm512_shuffle_epi8(values, _mm512_permutex2var_epi64(bits.values[0], permute2, bits.values[1]));
+ bits.values[0] = _mm512_shuffle_epi8(values, tmp);
+ tmp = _mm512_permutex2var_epi64(bits.values[2], permute1, bits.values[3]);
+ bits.values[3] = _mm512_shuffle_epi8(values, _mm512_permutex2var_epi64(bits.values[2], permute2, bits.values[3]));
+ bits.values[2] = _mm512_shuffle_epi8(values, tmp);
+ //
+ // Now the more difficult part - prepare the scales
+ //
+ aux32[0] = _mm512_cmpeq_epi16_mask(_mm512_and_si512(b1, mask1), mask1);
+ aux32[1] = _mm512_cmpeq_epi16_mask(_mm512_and_si512(b2, mask1), mask1);
+
+ auto scales128 = _mm_cvtepu8_epi16(_mm_loadl_epi64((const __m128i *)aux32));
+ auto m1 = _mm512_castsi512_si128(mask1);
+ auto shifts = _mm_and_si128(_mm_cmpeq_epi16(_mm_and_si128(scales128, m1), m1), m4);
+ scales128 = _mm_add_epi16(_mm_and_si128(scales128, mask), m127);
+ auto scales_s = _mm_mullo_epi16(scales128, _mm_add_epi16(m128, shifts));
+ s8k.accum_mins(scales_s, q8, i, d, accm);
+ auto scales256 = MM256_SET_M128I(scales128, scales128);
+ auto all_scales = _mm512_inserti32x8(_mm512_castsi256_si512(scales256), scales256, 1);
+ scales[0] = _mm512_shuffle_epi8(all_scales, shuffles[0]);
+ scales[1] = _mm512_shuffle_epi8(all_scales, shuffles[1]);
+ scales[2] = _mm512_shuffle_epi8(all_scales, shuffles[2]);
+ scales[3] = _mm512_shuffle_epi8(all_scales, shuffles[3]);
+ }
+
+ Q4Bits bits;
+ Scales8KBase s8k;
+ const __m512i values;
+ const __m512i mask15 = _mm512_set1_epi16(-2); // value is 0xfffe, but to shut up the stupid compiler warning we use the signed value
+ const __m512i mask1 = _mm512_set1_epi16(1);
+ const __m512i permute1 = _mm512_set_epi64(11, 10, 3, 2, 9, 8, 1, 0);
+ const __m512i permute2 = _mm512_set_epi64(15, 14, 7, 6, 13, 12, 5, 4);
+ const __m128i mask = _mm_set1_epi16(254);
+ const __m128i m127 = _mm_set1_epi16(-127);
+ const __m128i m128 = _mm_set1_epi16(-128);
+ const __m128i m4 = _mm_set1_epi16(4);
+ const __m512i shuffles[4] = {
+ _mm512_inserti32x8(_mm512_set1_epi16(0x0100), _mm256_set1_epi16(0x0302), 1),
+ _mm512_inserti32x8(_mm512_set1_epi16(0x0504), _mm256_set1_epi16(0x0706), 1),
+ _mm512_inserti32x8(_mm512_set1_epi16(0x0908), _mm256_set1_epi16(0x0b0a), 1),
+ _mm512_inserti32x8(_mm512_set1_epi16(0x0d0c), _mm256_set1_epi16(0x0f0e), 1),
+ };
+};
+
+struct DequantizerIQ4KS final : public BaseDequantizer<block_iq4_ks, true> {
+ DequantizerIQ4KS(const void * vx, size_t bx) : BaseDequantizer(vx, bx), values(load_iq4nl_values_512()) {}
+ template <typename Q8>
+ inline void new_block(int i, const Q8& q8, __m256 * accm, __m512i * scales) {
+ auto scales128 = _mm_cvtepu8_epi16(_mm_loadl_epi64((const __m128i *)x[i].scales));
+ auto shifts = _mm_and_si128(_mm_cmpeq_epi16(_mm_and_si128(scales128, m1), m1), m4);
+ scales128 = _mm_add_epi16(_mm_and_si128(scales128, mask), m127);
+ auto scales_s = _mm_mullo_epi16(scales128, _mm_add_epi16(m128, shifts));
+ s8k.accum_mins(scales_s, q8, i, d, accm);
+ auto scales256 = MM256_SET_M128I(scales128, scales128);
+ auto all_scales = _mm512_inserti32x8(_mm512_castsi256_si512(scales256), scales256, 1);
+ scales[0] = _mm512_shuffle_epi8(all_scales, shuffles[0]);
+ scales[1] = _mm512_shuffle_epi8(all_scales, shuffles[1]);
+ scales[2] = _mm512_shuffle_epi8(all_scales, shuffles[2]);
+ scales[3] = _mm512_shuffle_epi8(all_scales, shuffles[3]);
+ prepare(x[i].qs);
+ }
+ template <typename Q8>
+ inline void compute_block(int i, const Q8& q8, __m512 * acc) {
+ auto scales128 = _mm_cvtepu8_epi16(_mm_loadl_epi64((const __m128i *)x[i].scales));
+ auto shifts = _mm_and_si128(_mm_cmpeq_epi16(_mm_and_si128(scales128, m1), m1), m4);
+ scales128 = _mm_add_epi16(_mm_and_si128(scales128, mask), m127);
+ auto mins128 = _mm_mullo_epi16(scales128, _mm_add_epi16(m128, shifts));
+ auto mins = MM256_SET_M128I(_mm_shuffle_epi8(mins128, s8k.shuffles[1]), _mm_shuffle_epi8(mins128, s8k.shuffles[0]));
+ auto scales256 = MM256_SET_M128I(scales128, scales128);
+ auto all_scales = _mm512_inserti32x8(_mm512_castsi256_si512(scales256), scales256, 1);
+ __m512i scales[4];
+ for (int k = 0; k < 4; ++k) scales[k] = _mm512_shuffle_epi8(all_scales, shuffles[k]);
+ prepare(x[i].qs);
+ for (int iy = 0; iy < Q8::nrc_y; ++iy) {
+ auto q8s = q8.load_bsums(iy, i);
+ auto prod = _mm256_madd_epi16(mins, q8s);
+ auto sumi = _mm512_inserti32x8(_mm512_setzero_si512(), prod, 0);
+ for (int k = 0; k < 4; ++k) {
+ auto p = _mm512_maddubs_epi16(bits.values[k], q8.load_quants64(iy, i, k));
+ sumi = _mm512_dpwssd_epi32(sumi, p, scales[k]);
+ }
+ acc[iy] = _mm512_fmadd_ps(_mm512_set1_ps(d*q8.scale(iy, i)), _mm512_cvtepi32_ps(sumi), acc[iy]);
+ }
+ }
+ inline void prepare(const uint8_t * q4) {
+ bits.prepare64(q4);
+        // We now have in bits.values[0]: 0...15, 32...47, 64...79, 96...111
+        //                bits.values[1]: 16...31, 48...63, 80...95, 112...127
+ // etc.
+ auto tmp = _mm512_permutex2var_epi64(bits.values[0], permute1, bits.values[1]);
+ bits.values[1] = _mm512_shuffle_epi8(values, _mm512_permutex2var_epi64(bits.values[0], permute2, bits.values[1]));
+ bits.values[0] = _mm512_shuffle_epi8(values, tmp);
+ tmp = _mm512_permutex2var_epi64(bits.values[2], permute1, bits.values[3]);
+ bits.values[3] = _mm512_shuffle_epi8(values, _mm512_permutex2var_epi64(bits.values[2], permute2, bits.values[3]));
+ bits.values[2] = _mm512_shuffle_epi8(values, tmp);
+ }
+
+ Q4Bits bits;
+ Scales8KBase s8k;
+ const __m512i values;
+ const __m512i permute1 = _mm512_set_epi64(11, 10, 3, 2, 9, 8, 1, 0);
+ const __m512i permute2 = _mm512_set_epi64(15, 14, 7, 6, 13, 12, 5, 4);
+ const __m128i mask = _mm_set1_epi16(254);
+ const __m128i m127 = _mm_set1_epi16(-127);
+ const __m128i m128 = _mm_set1_epi16(-128);
+ const __m128i m1 = _mm_set1_epi16(1);
+ const __m128i m4 = _mm_set1_epi16(4);
+ const __m512i shuffles[4] = {
+ _mm512_inserti32x8(_mm512_set1_epi16(0x0100), _mm256_set1_epi16(0x0302), 1),
+ _mm512_inserti32x8(_mm512_set1_epi16(0x0504), _mm256_set1_epi16(0x0706), 1),
+ _mm512_inserti32x8(_mm512_set1_epi16(0x0908), _mm256_set1_epi16(0x0b0a), 1),
+ _mm512_inserti32x8(_mm512_set1_epi16(0x0d0c), _mm256_set1_epi16(0x0f0e), 1),
+ };
+};
+
+struct DequantizerIQ4K final : public BaseDequantizer<block_iq4_k> {
+ DequantizerIQ4K(const void * vx, size_t bx) : BaseDequantizer(vx, bx), iqxk(4, -128), values(load_iq4nl_values_512()) {}
+ template <typename Q8>
+ inline void new_block(int i, const Q8& q8, __m256 * accm, __m512i * scales) {
+ d = GGML_FP16_TO_FP32(x[i].d);
+ prepare(x[i].qs);
+ iqxk.process(i, d, x[i].extra, make_scales(x[i].scales_l, (const uint16_t *)x[i].scales_h), q8, accm, scales);
+ }
+ inline void prepare(const uint8_t * q4) {
+ bits.prepare64(q4);
+        // We now have in bits.values[0]: 0...15, 32...47, 64...79, 96...111
+        //                bits.values[1]: 16...31, 48...63, 80...95, 112...127
+ // etc.
+ auto tmp = _mm512_permutex2var_epi64(bits.values[0], permute1, bits.values[1]);
+ bits.values[1] = _mm512_shuffle_epi8(values, _mm512_permutex2var_epi64(bits.values[0], permute2, bits.values[1]));
+ bits.values[0] = _mm512_shuffle_epi8(values, tmp);
+ tmp = _mm512_permutex2var_epi64(bits.values[2], permute1, bits.values[3]);
+ bits.values[3] = _mm512_shuffle_epi8(values, _mm512_permutex2var_epi64(bits.values[2], permute2, bits.values[3]));
+ bits.values[2] = _mm512_shuffle_epi8(values, tmp);
+ }
+ __m128i make_scales(const uint8_t * scales_l, const uint16_t * scales_h) const {
+ uint64_t aux64;
+ memcpy(&aux64, scales_l, 8);
+ auto scl = _mm_and_si128(_mm_set_epi64x(aux64 >> 4, aux64), maskl);
+ const uint32_t aux32 = scales_h[0] | (scales_h[1] << 16);
+ auto aux = _mm_and_si128(_mm_set_epi32(aux32 >> 2, aux32, aux32 << 2, aux32 << 4), maskh);
+ auto sch = _mm_shuffle_epi8(aux, iqxk.scale_shuffle);
+ return _mm_add_epi8(_mm_or_si128(scl, sch), m32);
+ }
+
+ Q4Bits bits;
+ const IQXKScales2 iqxk;
+ const __m512i values;
+ const __m512i permute1 = _mm512_set_epi64(11, 10, 3, 2, 9, 8, 1, 0);
+ const __m512i permute2 = _mm512_set_epi64(15, 14, 7, 6, 13, 12, 5, 4);
+ const __m128i maskl = _mm_set1_epi8(0xf);
+ const __m128i maskh = _mm_set1_epi8(0x30);
+ const __m128i m32 = _mm_set1_epi8(-32);
+};
+
+struct DequantizerIQ5KS final : public BaseDequantizer<block_iq5_ks, true> {
+ DequantizerIQ5KS(const void * vx, size_t bx) : BaseDequantizer(vx, bx) { load_values(values); }
+ template <typename Q8>
+ inline void new_block(int i, const Q8& q8, __m256 * accm, __m512i * scales) {
+ auto scales128 = _mm_cvtepu8_epi16(_mm_loadl_epi64((const __m128i *)x[i].scales));
+ auto shifts = _mm_and_si128(_mm_cmpeq_epi16(_mm_and_si128(scales128, m1), m1), m2);
+ scales128 = _mm_add_epi16(_mm_and_si128(scales128, mask), m127);
+ auto scales_s = _mm_mullo_epi16(scales128, _mm_add_epi16(m128, shifts));
+ s8k.accum_mins(scales_s, q8, i, d, accm);
+ auto scales256 = MM256_SET_M128I(scales128, scales128);
+ auto all_scales = _mm512_inserti32x8(_mm512_castsi256_si512(scales256), scales256, 1);
+ scales[0] = _mm512_shuffle_epi8(all_scales, shuffles[0]);
+ scales[1] = _mm512_shuffle_epi8(all_scales, shuffles[1]);
+ scales[2] = _mm512_shuffle_epi8(all_scales, shuffles[2]);
+ scales[3] = _mm512_shuffle_epi8(all_scales, shuffles[3]);
+ prepare(x[i].qs, x[i].qh);
+ }
+ template <typename Q8>
+ inline void compute_block(int i, const Q8& q8, __m512 * acc) {
+ auto scales128 = _mm_cvtepu8_epi16(_mm_loadl_epi64((const __m128i *)x[i].scales));
+ auto shifts = _mm_and_si128(_mm_cmpeq_epi16(_mm_and_si128(scales128, m1), m1), m2);
+ scales128 = _mm_add_epi16(_mm_and_si128(scales128, mask), m127);
+ auto mins128 = _mm_mullo_epi16(scales128, _mm_add_epi16(m128, shifts));
+ auto mins = MM256_SET_M128I(_mm_shuffle_epi8(mins128, s8k.shuffles[1]), _mm_shuffle_epi8(mins128, s8k.shuffles[0]));
+ auto scales256 = MM256_SET_M128I(scales128, scales128);
+ auto all_scales = _mm512_inserti32x8(_mm512_castsi256_si512(scales256), scales256, 1);
+ __m512i scales[4];
+ for (int k = 0; k < 4; ++k) scales[k] = _mm512_shuffle_epi8(all_scales, shuffles[k]);
+ prepare(x[i].qs, x[i].qh);
+ for (int iy = 0; iy < Q8::nrc_y; ++iy) {
+ auto q8s = q8.load_bsums(iy, i);
+ auto prod = _mm256_madd_epi16(mins, q8s);
+ auto sumi = _mm512_inserti32x8(_mm512_setzero_si512(), prod, 0);
+ for (int k = 0; k < 4; ++k) {
+ auto p = _mm512_maddubs_epi16(bits.values[k], q8.load_quants64(iy, i, k));
+ sumi = _mm512_dpwssd_epi32(sumi, p, scales[k]);
+ }
+ acc[iy] = _mm512_fmadd_ps(_mm512_set1_ps(d*q8.scale(iy, i)), _mm512_cvtepi32_ps(sumi), acc[iy]);
+ }
+ }
+ inline void prepare(const uint8_t * q4, const uint8_t * qh) {
+ bits.prepare64a(q4);
+ auto h256 = _mm256_loadu_si256((const __m256i *)qh);
+ auto hbits = _mm512_inserti32x8(_mm512_castsi256_si512(h256), _mm256_srli_epi16(h256, 1), 1);
+ auto m1 = _mm512_cmpeq_epi8_mask(_mm512_and_si512(hbits, hmask1), hmask1);
+ auto m2 = _mm512_cmpeq_epi8_mask(_mm512_and_si512(hbits, hmask2), hmask2);
+ bits.values[0] = _mm512_mask_shuffle_epi8(_mm512_maskz_shuffle_epi8(_knot_mask64(m1), values[0], bits.values[0]), m1, values[1], bits.values[0]);
+ bits.values[1] = _mm512_mask_shuffle_epi8(_mm512_maskz_shuffle_epi8(_knot_mask64(m2), values[0], bits.values[1]), m2, values[1], bits.values[1]);
+ hbits = _mm512_srli_epi16(hbits, 4);
+ m1 = _mm512_cmpeq_epi8_mask(_mm512_and_si512(hbits, hmask1), hmask1);
+ m2 = _mm512_cmpeq_epi8_mask(_mm512_and_si512(hbits, hmask2), hmask2);
+ bits.values[2] = _mm512_mask_shuffle_epi8(_mm512_maskz_shuffle_epi8(_knot_mask64(m1), values[0], bits.values[2]), m1, values[1], bits.values[2]);
+ bits.values[3] = _mm512_mask_shuffle_epi8(_mm512_maskz_shuffle_epi8(_knot_mask64(m2), values[0], bits.values[3]), m2, values[1], bits.values[3]);
+ }
+ static void load_values(__m512i * values) {
+ static const uint8_t kvalues_iq5nl[32] = {
+ 2, 14, 25, 36, 45, 54, 63, 71, 78, 85, 92, 98, 104, 110, 116, 122, 127,
+ 133, 139, 145, 151, 157, 164, 171, 179, 187, 196, 205, 215, 225, 237, 249,
+ };
+ auto values128_1 = _mm_loadu_si128((const __m128i *)kvalues_iq5nl + 0);
+ auto values128_2 = _mm_loadu_si128((const __m128i *)kvalues_iq5nl + 1);
+ auto values256_1 = MM256_SET_M128I(values128_1, values128_1);
+ auto values256_2 = MM256_SET_M128I(values128_2, values128_2);
+ values[0] = _mm512_inserti32x8(_mm512_castsi256_si512(values256_1), values256_1, 1);
+ values[1] = _mm512_inserti32x8(_mm512_castsi256_si512(values256_2), values256_2, 1);
+ }
+
+ Q4Bits bits;
+ Scales8KBase s8k;
+ __m512i values[2];
+ const __m512i hmask1 = _mm512_set1_epi8(1);
+ const __m512i hmask2 = _mm512_set1_epi8(4);
+ const __m128i m127 = _mm_set1_epi16(-127);
+ const __m128i m128 = _mm_set1_epi16(-128);
+ const __m128i mask = _mm_set1_epi16(254);
+ const __m128i m1 = _mm_set1_epi16(1);
+ const __m128i m2 = _mm_set1_epi16(2);
+ const __m512i shuffles[4] = {
+ _mm512_inserti32x8(_mm512_set1_epi16(0x0100), _mm256_set1_epi16(0x0302), 1),
+ _mm512_inserti32x8(_mm512_set1_epi16(0x0504), _mm256_set1_epi16(0x0706), 1),
+ _mm512_inserti32x8(_mm512_set1_epi16(0x0908), _mm256_set1_epi16(0x0b0a), 1),
+ _mm512_inserti32x8(_mm512_set1_epi16(0x0d0c), _mm256_set1_epi16(0x0f0e), 1),
+ };
+};
+
+struct DequantizerIQ5K final : public BaseDequantizer<block_iq5_k> {
+ DequantizerIQ5K(const void * vx, size_t bx) : BaseDequantizer(vx, bx), iqxk(2, -128) { load_values(values); }
+ template <typename Q8>
+ inline void new_block(int i, const Q8& q8, __m256 * accm, __m512i * scales) {
+ d = GGML_FP16_TO_FP32(x[i].d);
+ prepare(x[i].qs, x[i].qh);
+ iqxk.process(i, d, x[i].extra, make_scales(x[i].scales_l, (const uint16_t *)x[i].scales_h), q8, accm, scales);
+ }
+ inline void prepare(const uint8_t * q4, const uint8_t * qh) {
+ bits.prepare64(q4);
+ auto h256 = _mm256_loadu_si256((const __m256i *)qh);
+ auto hbits = _mm512_inserti32x8(_mm512_castsi256_si512(h256), _mm256_srli_epi16(h256, 2), 1);
+ auto m1 = _mm512_cmpeq_epi8_mask(_mm512_and_si512(hbits, hmask1), hmask1);
+ auto m2 = _mm512_cmpeq_epi8_mask(_mm512_and_si512(hbits, hmask2), hmask2);
+ bits.values[0] = _mm512_mask_shuffle_epi8(_mm512_maskz_shuffle_epi8(_knot_mask64(m1), values[0], bits.values[0]), m1, values[1], bits.values[0]);
+ bits.values[1] = _mm512_mask_shuffle_epi8(_mm512_maskz_shuffle_epi8(_knot_mask64(m2), values[0], bits.values[1]), m2, values[1], bits.values[1]);
+ hbits = _mm512_srli_epi16(hbits, 4);
+ m1 = _mm512_cmpeq_epi8_mask(_mm512_and_si512(hbits, hmask1), hmask1);
+ m2 = _mm512_cmpeq_epi8_mask(_mm512_and_si512(hbits, hmask2), hmask2);
+ bits.values[2] = _mm512_mask_shuffle_epi8(_mm512_maskz_shuffle_epi8(_knot_mask64(m1), values[0], bits.values[2]), m1, values[1], bits.values[2]);
+ bits.values[3] = _mm512_mask_shuffle_epi8(_mm512_maskz_shuffle_epi8(_knot_mask64(m2), values[0], bits.values[3]), m2, values[1], bits.values[3]);
+        // We now have in bits.values[0]: 0...31, 64...95
+        //                bits.values[1]: 32...63, 96...127
+ // etc.
+ auto tmp = _mm512_permutex2var_epi64(bits.values[0], permute1, bits.values[1]);
+ bits.values[1] = _mm512_permutex2var_epi64(bits.values[0], permute2, bits.values[1]);
+ bits.values[0] = tmp;
+ tmp = _mm512_permutex2var_epi64(bits.values[2], permute1, bits.values[3]);
+ bits.values[3] = _mm512_permutex2var_epi64(bits.values[2], permute2, bits.values[3]);
+ bits.values[2] = tmp;
+ }
+ __m128i make_scales(const uint8_t * scales_l, const uint16_t * scales_h) const {
+ uint64_t aux64;
+ memcpy(&aux64, scales_l, 8);
+ auto scl = _mm_and_si128(_mm_set_epi64x(aux64 >> 4, aux64), maskl);
+ const uint32_t aux32 = scales_h[0] | (scales_h[1] << 16);
+ auto aux = _mm_and_si128(_mm_set_epi32(aux32 >> 2, aux32, aux32 << 2, aux32 << 4), maskh);
+ auto sch = _mm_shuffle_epi8(aux, iqxk.scale_shuffle);
+ return _mm_add_epi8(_mm_or_si128(scl, sch), m32);
+ }
+ static void load_values(__m512i * values) {
+ static const uint8_t kvalues_iq5nl[32] = {
+ 2, 14, 25, 36, 45, 54, 63, 71, 78, 85, 92, 98, 104, 110, 116, 122, 127,
+ 133, 139, 145, 151, 157, 164, 171, 179, 187, 196, 205, 215, 225, 237, 249,
+ };
+ auto values128_1 = _mm_loadu_si128((const __m128i *)kvalues_iq5nl + 0);
+ auto values128_2 = _mm_loadu_si128((const __m128i *)kvalues_iq5nl + 1);
+ auto values256_1 = MM256_SET_M128I(values128_1, values128_1);
+ auto values256_2 = MM256_SET_M128I(values128_2, values128_2);
+ values[0] = _mm512_inserti32x8(_mm512_castsi256_si512(values256_1), values256_1, 1);
+ values[1] = _mm512_inserti32x8(_mm512_castsi256_si512(values256_2), values256_2, 1);
+ }
+
+ Q4Bits bits;
+ const IQXKScales2 iqxk;
+ __m512i values[2];
+ const __m512i hmask1 = _mm512_set1_epi8(1);
+ const __m512i hmask2 = _mm512_set1_epi8(2);
+ const __m512i permute1 = _mm512_set_epi64(11, 10, 9, 8, 3, 2, 1, 0);
+ const __m512i permute2 = _mm512_set_epi64(15, 14, 13, 12, 7, 6, 5, 4);
+ const __m128i maskl = _mm_set1_epi8(0xf);
+ const __m128i maskh = _mm_set1_epi8(0x30);
+ const __m128i m32 = _mm_set1_epi8(-32);
+};
+
+struct DequantizerIQ6K final : public BaseDequantizer<block_iq6_k> {
+ DequantizerIQ6K(const void * vx, size_t bx) : BaseDequantizer(vx, bx), iqxk(1, -128) { load_values(values); }
+ template <typename Q8>
+ inline void new_block(int i, const Q8& q8, __m256 * accm, __m512i * scales) {
+ d = GGML_FP16_TO_FP32(x[i].d);
+ prepare(x[i].qs, x[i].qh);
+ auto scales8 = _mm_loadu_si128((const __m128i*)x[i].scales);
+ iqxk.process(i, d, x[i].extra, _mm256_cvtepi8_epi16(scales8), q8, accm, scales);
+ }
+ inline __m512i make_one(__m512i l, __m512i h) const {
+ auto p = _mm512_shuffle_epi8(values[0], l);
+ p = _mm512_mask_shuffle_epi8(p, _mm512_cmpeq_epi8_mask(_mm512_and_si512(h, masks[0]), masks[0]), values[1], l);
+ p = _mm512_mask_shuffle_epi8(p, _mm512_cmpeq_epi8_mask(_mm512_and_si512(h, masks[1]), masks[1]), values[2], l);
+ p = _mm512_mask_shuffle_epi8(p, _mm512_cmpeq_epi8_mask(_mm512_and_si512(h, masks[2]), masks[2]), values[3], l);
+ return p;
+ }
+ inline void prepare(const uint8_t * q4, const uint8_t * qh) {
+ bits.prepare64(q4);
+ auto h256_1 = _mm256_loadu_si256((const __m256i *)qh + 0);
+ auto h256_2 = _mm256_loadu_si256((const __m256i *)qh + 1);
+ auto h1 = _mm512_inserti32x8(_mm512_castsi256_si512(h256_1), _mm256_srli_epi16(h256_1, 4), 1);
+ auto h2 = _mm512_inserti32x8(_mm512_castsi256_si512(h256_2), _mm256_srli_epi16(h256_2, 4), 1);
+ bits.values[0] = make_one(bits.values[0], h1);
+ bits.values[1] = make_one(bits.values[1], _mm512_srli_epi16(h1, 2));
+ bits.values[2] = make_one(bits.values[2], h2);
+ bits.values[3] = make_one(bits.values[3], _mm512_srli_epi16(h2, 2));
+        // We now have in bits.values[0]: 0...31, 64...95
+        //                bits.values[1]: 32...63, 96...127
+ // etc.
+ auto tmp = _mm512_permutex2var_epi64(bits.values[0], permute1, bits.values[1]);
+ bits.values[1] = _mm512_permutex2var_epi64(bits.values[0], permute2, bits.values[1]);
+ bits.values[0] = tmp;
+ tmp = _mm512_permutex2var_epi64(bits.values[2], permute1, bits.values[3]);
+ bits.values[3] = _mm512_permutex2var_epi64(bits.values[2], permute2, bits.values[3]);
+ bits.values[2] = tmp;
+ }
+ static void load_values(__m512i * values) {
+ static const uint8_t kvalues_iq6nl[64] = {
+ 1, 7, 13, 19, 24, 30, 35, 40, 44, 49, 54, 58, 62, 66, 70, 74,
+ 77, 81, 84, 88, 91, 94, 97, 100, 103, 106, 109, 112, 115, 117, 120, 123,
+ 126, 128, 131, 134, 137, 140, 142, 145, 148, 151, 155, 158, 161, 164, 168, 172,
+ 175, 179, 183, 187, 191, 196, 200, 205, 210, 215, 220, 226, 231, 237, 243, 249,
+ };
+ for (int k = 0; k < 4; ++k) {
+ auto values128 = _mm_loadu_si128((const __m128i *)kvalues_iq6nl + k);
+ auto values256 = MM256_SET_M128I(values128, values128);
+ values[k] = _mm512_inserti32x8(_mm512_castsi256_si512(values256), values256, 1);
+ }
+ }
+
+ Q4Bits bits;
+ IQXKScales2 iqxk;
+ __m512i values[4];
+ __m512i masks[3] = { _mm512_set1_epi8(0x01), _mm512_set1_epi8(0x02), _mm512_set1_epi8(0x03) };
+ const __m512i permute1 = _mm512_set_epi64(11, 10, 9, 8, 3, 2, 1, 0);
+ const __m512i permute2 = _mm512_set_epi64(15, 14, 13, 12, 7, 6, 5, 4);
+};
+
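+// Main AVX512 GEMM loop for the IQK quants: for every super-block the dequantizer fills
+// deq.bits.values[] and the four 16-bit scale registers; quant x q8 products are accumulated with
+// _mm512_maddubs_epi16 + _mm512_dpwssd_epi32, and the scale*min corrections gathered in accm[] are
+// added back in the final horizontal reduction.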
+template <typename Dequantizer, int nrc_y>
+static void mul_mat_iqX_k_q8_K_AVX512(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ assert(n % QK_K == 0);
+ const int nb = n / QK_K;
+
+ Q8<nrc_y> q8(info);
+
+ Dequantizer deq(vx, bx);
+
+ __m256 accm[nrc_y];
+ __m512 accd[nrc_y];
+ __m512i scales[4];
+
+ for (int ix = 0; ix < nrc_x; ++ix) {
+
+ for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm512_setzero_ps();
+ for (int iy = 0; iy < nrc_y; ++iy) accm[iy] = _mm256_setzero_ps();
+
+ deq.new_row(ix);
+
+ for (int i = 0; i < nb; ++i) {
+
+ deq.new_block(i, q8, accm, scales);
+
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ const __m512i p1 = _mm512_maddubs_epi16(deq.bits.values[0], q8.load_quants64(iy, i, 0));
+ const __m512i p2 = _mm512_maddubs_epi16(deq.bits.values[1], q8.load_quants64(iy, i, 1));
+ const __m512i p3 = _mm512_maddubs_epi16(deq.bits.values[2], q8.load_quants64(iy, i, 2));
+ const __m512i p4 = _mm512_maddubs_epi16(deq.bits.values[3], q8.load_quants64(iy, i, 3));
+ auto sumi = _mm512_dpwssd_epi32(_mm512_dpwssd_epi32(_mm512_dpwssd_epi32(_mm512_dpwssd_epi32(_mm512_setzero_si512(),
+ p1, scales[0]), p2, scales[1]), p3, scales[2]), p4, scales[3]);
+ accd[iy] = _mm512_fmadd_ps(_mm512_set1_ps(deq.d*q8.scale(iy, i)), _mm512_cvtepi32_ps(sumi), accd[iy]);
+ }
+
+ }
+
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto sum256 = _mm256_add_ps(_mm512_castps512_ps256(accd[iy]), _mm512_extractf32x8_ps(accd[iy], 1));
+ info.store(ix, iy, hsum_float_8(_mm256_add_ps(accm[iy], sum256)));
+ }
+
+ }
+}
+
+template <typename Dequantizer, int nrc_y>
+static void mul_mat_iqX_k_q8_K_AVX512_new(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ assert(n % QK_K == 0);
+ const int nb = n / QK_K;
+
+ Q8<nrc_y> q8(info);
+
+ Dequantizer deq(vx, bx);
+
+ __m512 accd[nrc_y];
+
+ for (int ix = 0; ix < nrc_x; ++ix) {
+
+ for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm512_setzero_ps();
+
+ deq.new_row(ix);
+
+ for (int i = 0; i < nb; ++i) {
+ deq.compute_block(i, q8, accd);
+ }
+
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ info.store(ix, iy, _mm512_reduce_add_ps(accd[iy]));
+ }
+
+ }
+}
+
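+// Helper for the single-row kernel below (the multi-row variant inlines the same computation): the
+// per-64-byte products are computed with _mm512_dpbusd_epi32, pairs of 32-bit sums are packed back to
+// 16 bits, and the 16-bit scales are applied with _mm512_dpwssd_epi32.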
+template <typename Q8>
+inline void compute_block(int iy, int i, float d, const Q8& q8, const __m512i * values, const __m512i * scales, __m512 * accd) {
+ const __m512i p1 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), values[0], q8.load_quants64(iy, i, 0));
+ const __m512i p2 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), values[1], q8.load_quants64(iy, i, 1));
+ const __m512i p3 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), values[2], q8.load_quants64(iy, i, 2));
+ const __m512i p4 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), values[3], q8.load_quants64(iy, i, 3));
+ auto sumi = _mm512_dpwssd_epi32(_mm512_setzero_si512(), scales[0], _mm512_packs_epi32(p1, p2));
+ sumi = _mm512_dpwssd_epi32(sumi, scales[1], _mm512_packs_epi32(p3, p4));
+ accd[iy] = _mm512_fmadd_ps(_mm512_set1_ps(d*q8.scale(iy, i)), _mm512_cvtepi32_ps(sumi), accd[iy]);
+}
+
+template <typename Dequantizer>
+static void mul_mat_qX_K_q8_K_AVX512_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ assert(n % QK_K == 0);
+ const int nb = n / QK_K;
+
+ constexpr int k_nx = 2;
+
+ Q8<1> q8(info);
+
+ Dequantizer deq1(vx, bx);
+ Dequantizer deq2(vx, bx);
+
+ Dequantizer * deq[k_nx];
+ deq[0] = &deq1;
+ deq[1] = &deq2;
+
+ __m512i scales[2*k_nx];
+
+ for (int ix = 0; ix < nrc_x; ++ix) {
+
+ auto accd = _mm512_setzero_ps();
+ auto accm = _mm256_setzero_ps();
+
+ for (int kx = 0; kx < k_nx; ++kx) deq[kx]->new_row(ix);
+
+ for (int i = 0; i < nb/k_nx; ++i) {
+
+ for (int kx = 0; kx < k_nx; ++kx) deq[kx]->new_block(k_nx*i+kx, q8, &accm, scales+2*kx);
+
+ for (int kx = 0; kx < k_nx; ++kx) {
+ compute_block(0, k_nx*i+kx, deq[kx]->d, q8, deq[kx]->bits.values, scales+2*kx, &accd);
+ }
+
+ }
+        if (k_nx*(nb/k_nx) < nb) {
+            int i0 = k_nx*(nb/k_nx);
+            deq[0]->new_block(i0, q8, &accm, scales);
+            compute_block(0, i0, deq[0]->d, q8, deq[0]->bits.values, scales, &accd);
+        }
+
+ auto sum256 = _mm256_add_ps(_mm512_castps512_ps256(accd), _mm512_extractf32x8_ps(accd, 1));
+ info.store(ix, 0, hsum_float_8(_mm256_add_ps(accm, sum256)));
+ }
+}
+
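+// Multi-column AVX512 kernel: per super-block, new_block() extracts the scales and
+// folds the per-block minimum correction (via the q8 block sums) into the 256-bit
+// accm[] accumulators, while the quant dot products go through dpbusd + dpwssd into
+// the 512-bit accd[] accumulators; both are combined in the final horizontal sum.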
+template <typename Dequantizer, int nrc_y>
+static void mul_mat_qX_K_q8_K_AVX512(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ assert(n % QK_K == 0);
+ const int nb = n / QK_K;
+
+ Q8<nrc_y> q8(info);
+
+ Dequantizer deq(vx, bx);
+
+ __m256 accm[nrc_y];
+ __m512 accd[nrc_y];
+ __m512i scales[2];
+
+ for (int ix = 0; ix < nrc_x; ++ix) {
+
+ for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm512_setzero_ps();
+ for (int iy = 0; iy < nrc_y; ++iy) accm[iy] = _mm256_setzero_ps();
+
+ deq.new_row(ix);
+
+ for (int i = 0; i < nb; ++i) {
+
+ deq.new_block(i, q8, accm, scales);
+
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ const __m512i p1 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), deq.bits.values[0], q8.load_quants64(iy, i, 0));
+ const __m512i p2 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), deq.bits.values[1], q8.load_quants64(iy, i, 1));
+ const __m512i p3 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), deq.bits.values[2], q8.load_quants64(iy, i, 2));
+ const __m512i p4 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), deq.bits.values[3], q8.load_quants64(iy, i, 3));
+ auto sumi = _mm512_dpwssd_epi32(_mm512_setzero_si512(), scales[0], _mm512_packs_epi32(p1, p2));
+ sumi = _mm512_dpwssd_epi32(sumi, scales[1], _mm512_packs_epi32(p3, p4));
+ accd[iy] = _mm512_fmadd_ps(_mm512_set1_ps(deq.d*q8.scale(iy, i)), _mm512_cvtepi32_ps(sumi), accd[iy]);
+ }
+
+ }
+
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto sum256 = _mm256_add_ps(_mm512_castps512_ps256(accd[iy]), _mm512_extractf32x8_ps(accd[iy], 1));
+ info.store(ix, iy, hsum_float_8(_mm256_add_ps(accm[iy], sum256)));
+ }
+
+ }
+}
+
+#else
+
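+// The code below is the non-AVX512 (plain AVX2) counterpart of the kernels above.
+// prepare_scales_16 duplicates the low/high 128-bit halves of a vector of 16 16-bit
+// block scales; the result is later expanded per 32-quant block by set_scales_16.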
+inline void prepare_scales_16(const __m256i& all_scales, __m256i * scales) {
+ const __m128i l_scales = _mm256_extracti128_si256(all_scales, 0);
+ const __m128i h_scales = _mm256_extracti128_si256(all_scales, 1);
+ scales[0] = MM256_SET_M128I(l_scales, l_scales);
+ scales[1] = MM256_SET_M128I(h_scales, h_scales);
+}
+
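+// Shared scale handling for the IQ*_K dequantizers: converts the packed block scales
+// to 16 bits, turns the per-block "extra" bits into a shift that is added to the
+// minimum value, accumulates the resulting correction (scaled by the q8 block sums)
+// into accm[], and expands the unmodified scales with prepare_scales_16.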
+struct IQXKScales {
+ IQXKScales(int8_t shift, int8_t min_val) : min(_mm256_set1_epi16(min_val)), eshift(_mm_set1_epi8(shift)) {}
+ template <typename Q8>
+ inline void process(int i, float d, uint16_t extra, __m128i scales8, const Q8& q8, __m256 * accm, __m256i * scales) const {
+ auto scales16 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales8, hshuff));
+ process(i, d, extra, scales16, q8, accm, scales);
+ }
+ template <typename Q8>
+ inline void process(int i, float d, uint16_t extra, __m256i scales16, const Q8& q8, __m256 * accm, __m256i * scales) const {
+ auto extra128 = _mm_set1_epi16(extra);
+ extra128 = _mm_cmpeq_epi8(_mm_and_si128(extra128, emask), emask);
+ extra128 = _mm_and_si128(extra128, eshift);
+ extra128 = _mm_shuffle_epi8(extra128, eshuffle);
+ auto scales_s = _mm256_mullo_epi16(scales16, _mm256_add_epi16(min, _mm256_cvtepi8_epi16(extra128)));
+ for (int iy = 0; iy < Q8::nrc_y; ++iy) {
+ const __m256i prod = _mm256_madd_epi16(scales_s, q8.load_bsums(iy, i));
+ accm[iy] = _mm256_fmadd_ps(_mm256_set1_ps(d * q8.scale(iy, i)), _mm256_cvtepi32_ps(prod), accm[iy]);
+ }
+ prepare_scales_16(scales16, scales);
+ }
+
+ const __m256i min;
+ const __m128i eshift;
+ const __m128i hshuff = _mm_set_epi32(0x0f070e06, 0x0d050c04, 0x0b030a02, 0x09010800);
+ const __m128i emask = _mm_set_epi32(0x80804040, 0x20201010, 0x08080404, 0x02020101);
+ const __m128i eshuffle = _mm_set_epi32(0x0f0d0b09, 0x07050301, 0x0e0c0a08, 0x06040200);
+};
+
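+// Dequantizers for the IQ*_K / IQ*_KS families. Each one exposes new_block(), which
+// extracts the block scales (and, where applicable, the min/extra corrections), and
+// prepare(), which expands the packed quants into bits.values[] via table lookups
+// with _mm256_shuffle_epi8.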
+struct DequantizerIQ2KS final : public BaseDequantizer<block_iq2_ks, true, true> {
+ DequantizerIQ2KS(const void * vx, size_t bx) : BaseDequantizer(vx, bx), values(load_values()) {}
+ template <typename Q8>
+ inline __m256i new_block(int i, const Q8& q8, __m256 * accm) {
+ auto scales128 = make_scales(x[i].scales, x[i].extra >> 8);
+ auto shifts = _mm_and_si128(_mm_cmpeq_epi8(_mm_and_si128(_mm_set1_epi8(x[i].extra), hmask), hmask), m5);
+ auto scales_s = _mm_mullo_epi16(scales128, _mm_cvtepi8_epi16(_mm_add_epi8(m32, shifts)));
+ s8k.accum_mins(scales_s, q8, i, d, accm);
+ return MM256_SET_M128I(scales128, scales128);
+ }
+ inline void prepare(int i, int j) {
+ bits.prepare(x[i].qs, j);
+ bits.values[0] = _mm256_shuffle_epi8(values, bits.values[0]);
+ bits.values[1] = _mm256_shuffle_epi8(values, bits.values[1]);
+ bits.values[2] = _mm256_shuffle_epi8(values, bits.values[2]);
+ bits.values[3] = _mm256_shuffle_epi8(values, bits.values[3]);
+ }
+ static inline __m256i load_values() {
+ static const uint8_t kvalues_iq2nl[16] = {1, 19, 33, 49, 0, 0, 0, 0, 6, 24, 38, 54, 0, 0, 0, 0};
+ auto val128 = _mm_loadu_si128((const __m128i *)kvalues_iq2nl);
+ return MM256_SET_M128I(val128, val128);
+ }
+ inline __m128i make_scales(const uint8_t * scales_l, uint8_t scales_h) const {
+ const uint16_t * scales = (const uint16_t *)scales_l;
+ uint32_t aux32 = scales[0] | (uint32_t(scales[1]) << 16);
+ auto scl = _mm_srlv_epi32(_mm_set1_epi32(aux32), shift);
+ scl = _mm_and_si128(_mm_shuffle_epi8(scl, shuffle), _mm_set1_epi8(0xf));
+ auto sch = _mm_set1_epi8(scales_h);
+ sch = _mm_and_si128(_mm_cmpeq_epi8(_mm_and_si128(sch, hmask), _mm_setzero_si128()), m16);
+ return _mm_cvtepi8_epi16(_mm_add_epi8(scl, sch));
+ }
+ Q2Bits bits;
+ Scales8KBase s8k;
+
+ const __m256i values;
+ const __m128i m16 = _mm_set1_epi8(-16);
+ const __m128i m5 = _mm_set1_epi8(5);
+ const __m128i m32 = _mm_set1_epi8(-32);
+ const __m128i hmask = _mm_set1_epi64x(0x8040201008040201);
+ const __m128i shuffle = _mm_set1_epi64x(0x0703060205010400);
+ const __m128i shift = _mm_set_epi32(0, 0, 4, 0);
+};
+
+struct DequantizerIQ2K final : public BaseDequantizer<block_iq2_k> {
+ DequantizerIQ2K(const void * vx, size_t bx) : BaseDequantizer(vx, bx), iqxk(5, -32), values(load_values()) {}
+ template <typename Q8>
+ inline void new_block(int i, const Q8& q8, __m256 * accm, __m256i * scales) {
+ d = GGML_FP16_TO_FP32(x[i].d);
+ iqxk.process(i, d, x[i].extra, make_scales(x[i].scales), q8, accm, scales);
+ }
+ inline void prepare(int i, int j) {
+ bits.prepare(x[i].qs, j);
+ bits.values[0] = _mm256_shuffle_epi8(values, bits.values[0]);
+ bits.values[1] = _mm256_shuffle_epi8(values, bits.values[1]);
+ bits.values[2] = _mm256_shuffle_epi8(values, bits.values[2]);
+ bits.values[3] = _mm256_shuffle_epi8(values, bits.values[3]);
+ }
+ static inline __m256i load_values() {
+ static const uint8_t kvalues_iq2nl[16] = {1, 19, 33, 49, 0, 0, 0, 0, 6, 24, 38, 54, 0, 0, 0, 0};
+ auto val128 = _mm_loadu_si128((const __m128i *)kvalues_iq2nl);
+ return MM256_SET_M128I(val128, val128);
+ }
+ inline __m128i make_scales(const uint8_t * scales_l) const {
+ uint64_t aux64; std::memcpy(&aux64, scales_l, 8);
+ auto scl = _mm_and_si128(_mm_set_epi64x(aux64 >> 4, aux64), maskl);
+ return _mm_add_epi8(scl, m8);
+ }
+
+ Q2Bits bits;
+ const IQXKScales iqxk;
+ const __m256i values;
+ const __m128i m8 = _mm_set1_epi8(-8);
+ const __m128i maskl = _mm_set1_epi8(0xf);
+};
+
+struct DequantizerIQ3K final : public BaseDequantizer<block_iq3_k> {
+ DequantizerIQ3K(const void * vx, size_t bx) : BaseDequantizer(vx, bx), iqxk(4, -64), values(load_values()) {}
+ template <typename Q8>
+ inline void new_block(int i, const Q8& q8, __m256 * accm, __m256i * scales) {
+ d = GGML_FP16_TO_FP32(x[i].d);
+ iqxk.process(i, d, x[i].extra, make_scales(x[i].scales_h, x[i].scales_l), q8, accm, scales);
+ hbits = _mm256_loadu_si256((const __m256i *)x[i].qh);
+ }
+ inline void prepare(int i, int j) {
+ bits.prepare(x[i].qs, j);
+ auto h256 = j == 0 ? hbits : _mm256_srli_epi16(hbits, 4);
+ bits.values[0] = _mm256_or_si256(bits.values[0], _mm256_and_si256(_mm256_slli_epi16(h256, 2), hmask));
+ bits.values[1] = _mm256_or_si256(bits.values[1], _mm256_and_si256(_mm256_slli_epi16(h256, 1), hmask));
+ bits.values[2] = _mm256_or_si256(bits.values[2], _mm256_and_si256(h256, hmask));
+ bits.values[3] = _mm256_or_si256(bits.values[3], _mm256_and_si256(_mm256_srli_epi16(h256, 1), hmask));
+ bits.values[0] = _mm256_shuffle_epi8(values, bits.values[0]);
+ bits.values[1] = _mm256_shuffle_epi8(values, bits.values[1]);
+ bits.values[2] = _mm256_shuffle_epi8(values, bits.values[2]);
+ bits.values[3] = _mm256_shuffle_epi8(values, bits.values[3]);
+ }
+ static inline __m256i load_values() {
+ static const uint8_t kvalues_iq3nl[16] = {1, 24, 41, 54, 65, 77, 92, 111, 5, 28, 45, 58, 69, 81, 96, 115};
+ auto val128 = _mm_loadu_si128((const __m128i *)kvalues_iq3nl);
+ return MM256_SET_M128I(val128, val128);
+ }
+ inline __m128i make_scales(uint16_t signs, const uint8_t * scales_l) const {
+ uint64_t aux64; std::memcpy(&aux64, scales_l, 8);
+ auto scl = _mm_and_si128(_mm_set_epi64x(aux64 >> 4, aux64), _mm_set1_epi8(0xf));
+ scl = _mm_add_epi8(_mm_slli_epi16(scl, 1), m1);
+ const __m128i sc_signs = _mm_cmpeq_epi8(_mm_and_si128(_mm_set1_epi16(signs), sign_mask), sign_mask);
+ const __m128i sch = _mm_shuffle_epi8(_mm_or_si128(sc_signs, _mm_set1_epi8(1)), hshuff);
+ return _mm_sign_epi8(scl, sch);
+ }
+
+ Q2Bits bits;
+ const IQXKScales iqxk;
+ const __m256i values;
+ __m256i hbits;
+ const __m256i hmask = _mm256_set1_epi8(4);
+ const __m128i m1 = _mm_set1_epi8(1);
+ const __m128i sign_mask = _mm_set_epi64x(0x8080404020201010, 0x0808040402020101);
+ const __m128i hshuff = _mm_loadu_si128((const __m128i*)k_shuff);
+ constexpr static uint8_t k_shuff[16] = {0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15};
+};
+
+struct DequantizerIQ4KSS final : public BaseDequantizer<block_iq4_kss, true> {
+ DequantizerIQ4KSS(const void * vx, size_t bx) : BaseDequantizer(vx, bx), values(load_iq4nl_values_256()) {}
+ template <typename Q8>
+ inline __m256i new_block(int i, const Q8& q8, __m256 * accd) {
+ union { __m256i vec; uint16_t val[16]; } helper;
+ for (int k = 0; k < 4; ++k) {
+ data[k] = _mm256_loadu_si256((const __m256i *)x[i].qs + k);
+ auto p = _mm256_and_si256(_mm256_cmpeq_epi16(_mm256_and_si256(data[k], m1), m1), smask);
+ p = _mm256_add_epi32(_mm256_unpackhi_epi64(p, p), p);
+ p = _mm256_add_epi32(_mm256_shuffle_epi32(p, _MM_SHUFFLE(2, 3, 0, 1)), p);
+ helper.vec = _mm256_hadd_epi16(p, p);
+ aux[2*k+0] = helper.val[0];
+ aux[2*k+1] = helper.val[8];
+ data[k] = _mm256_and_si256(data[k], bmask);
+ data[k] = _mm256_xor_si256(data[k], _mm256_srli_epi16(data[k], 1));
+ }
+ auto scales128 = _mm_loadu_si128((const __m128i *)aux);
+ auto shifts = _mm_and_si128(_mm_cmpeq_epi16(_mm_and_si128(scales128, _mm256_castsi256_si128(m1)), _mm256_castsi256_si128(m1)), m4);
+ scales128 = _mm_add_epi16(_mm_and_si128(scales128, mask), m127);
+ auto scales_s = _mm_mullo_epi16(scales128, _mm_add_epi16(m128, shifts));
+ s8k.accum_mins(scales_s, q8, i, d, accd);
+ return MM256_SET_M128I(scales128, scales128);
+ }
+ inline void prepare(int, int j) {
+ for (int k = 0; k < 2; ++k) {
+ auto p1 = _mm256_castsi256_si128(data[2*j+k]);
+ auto p2 = _mm256_extractf128_si256(data[2*j+k], 1);
+ bits.values[2*k+0] = _mm256_and_si256(MM256_SET_M128I(_mm_srli_epi16(p1, 4), p1), bits.ml);
+ bits.values[2*k+0] = _mm256_shuffle_epi8(values, bits.values[2*k+0]);
+ bits.values[2*k+1] = _mm256_and_si256(MM256_SET_M128I(_mm_srli_epi16(p2, 4), p2), bits.ml);
+ bits.values[2*k+1] = _mm256_shuffle_epi8(values, bits.values[2*k+1]);
+ }
+ }
+
+ Q4Bits bits;
+ Scales8KBase s8k;
+ const __m256i values;
+ __m256i data[4];
+ const __m256i smask = _mm256_set_epi64x(0x0080004000200010, 0x0008000400020001, 0x0080004000200010, 0x0008000400020001);
+ const __m256i bmask = _mm256_set1_epi16(-2); // 0xfffe;
+ const __m128i mask = _mm_set1_epi16(254);
+ const __m128i m127 = _mm_set1_epi16(-127);
+ const __m128i m128 = _mm_set1_epi16(-128);
+ const __m256i m1 = _mm256_set1_epi16(1);
+ const __m128i m4 = _mm_set1_epi16(4);
+ uint16_t aux[8];
+};
+
+struct DequantizerIQ4KS final : public BaseDequantizer<block_iq4_ks, true> {
+ DequantizerIQ4KS(const void * vx, size_t bx) : BaseDequantizer(vx, bx) { load_values(); }
+ template <typename Q8>
+ inline __m256i new_block(int i, [[maybe_unused]] const Q8& q8, [[maybe_unused]] __m256 * accd) {
+ auto scales128 = _mm_cvtepu8_epi16(_mm_loadl_epi64((const __m128i *)x[i].scales));
+ scales128 = _mm_add_epi16(_mm_and_si128(scales128, mask), m127);
+ return MM256_SET_M128I(scales128, scales128);
+ }
+ inline void prepare(int i, int j) {
+ bits.prepare16(x[i].qs, j);
+ bits.values[0] = _mm256_shuffle_epi8(values[x[i].scales[4*j+0] & 1], bits.values[0]);
+ bits.values[1] = _mm256_shuffle_epi8(values[x[i].scales[4*j+1] & 1], bits.values[1]);
+ bits.values[2] = _mm256_shuffle_epi8(values[x[i].scales[4*j+2] & 1], bits.values[2]);
+ bits.values[3] = _mm256_shuffle_epi8(values[x[i].scales[4*j+3] & 1], bits.values[3]);
+ }
+ void load_values() {
+ auto v1 = _mm_loadu_si128((const __m128i *)iq4k_values+0);
+ auto v2 = _mm_loadu_si128((const __m128i *)iq4k_values+1);
+ values[0] = MM256_SET_M128I(v1, v1);
+ values[1] = MM256_SET_M128I(v2, v2);
+ }
+
+ Q4Bits bits;
+ __m256i values[2];
+ const __m128i mask = _mm_set1_epi16(254);
+ const __m128i m127 = _mm_set1_epi16(-127);
+};
+
+struct DequantizerIQ4K final : public BaseDequantizer<block_iq4_k> {
+ DequantizerIQ4K(const void * vx, size_t bx) : BaseDequantizer(vx, bx) { load_values(); }
+ template <typename Q8>
+ inline void new_block(int i, [[maybe_unused]] const Q8& q8, [[maybe_unused]] __m256 * accm, __m256i * scales) {
+ d = GGML_FP16_TO_FP32(x[i].d);
+ auto scales8 = make_scales(x[i].scales_l, (const uint16_t *)x[i].scales_h);
+ auto scales16 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales8, hshuff));
+ prepare_scales_16(scales16, scales);
+ }
+ inline void prepare(int i, int j) {
+ bits.prepare16(x[i].qs, j);
+ auto extra = x[i].extra >> 8*j;
+ bits.values[0] = _mm256_shuffle_epi8(values[extra & 3], bits.values[0]); extra >>= 2;
+ bits.values[1] = _mm256_shuffle_epi8(values[extra & 3], bits.values[1]); extra >>= 2;
+ bits.values[2] = _mm256_shuffle_epi8(values[extra & 3], bits.values[2]); extra >>= 2;
+ bits.values[3] = _mm256_shuffle_epi8(values[extra & 3], bits.values[3]);
+ }
+ __m128i make_scales(const uint8_t * scales_l, const uint16_t * scales_h) const {
+ uint64_t aux64;
+ memcpy(&aux64, scales_l, 8);
+ auto scl = _mm_and_si128(_mm_set_epi64x(aux64 >> 4, aux64), maskl);
+ const uint32_t aux32 = scales_h[0] | (scales_h[1] << 16);
+ auto aux = _mm_and_si128(_mm_set_epi32(aux32 >> 2, aux32, aux32 << 2, aux32 << 4), maskh);
+ auto sch = _mm_shuffle_epi8(aux, hshuff);
+ return _mm_add_epi8(_mm_or_si128(scl, sch), m32);
+ }
+ void load_values() {
+ auto v1 = _mm_loadu_si128((const __m128i *)iq4k_values+0);
+ auto v2 = _mm_loadu_si128((const __m128i *)iq4k_values+1);
+ values[0] = MM256_SET_M128I(v1, v1);
+ values[1] = MM256_SET_M128I(v1, v2);
+ values[2] = MM256_SET_M128I(v2, v1);
+ values[3] = MM256_SET_M128I(v2, v2);
+ }
+
+ Q4Bits bits;
+ const __m128i maskl = _mm_set1_epi8(0xf);
+ const __m128i maskh = _mm_set1_epi8(0x30);
+ const __m128i m32 = _mm_set1_epi8(-32);
+ const __m128i hshuff = _mm_set_epi32(0x0f070e06, 0x0d050c04, 0x0b030a02, 0x09010800);
+ __m256i values[4];
+};
+
+struct DequantizerIQ5KS final : public BaseDequantizer<block_iq5_ks, true> {
+ DequantizerIQ5KS(const void * vx, size_t bx) : BaseDequantizer(vx, bx) { load_values(values); }
+ template <typename Q8>
+ inline __m256i new_block(int i, const Q8& q8, __m256 * accd) {
+ hbits = _mm256_loadu_si256((const __m256i *)x[i].qh);
+ auto scales128 = _mm_cvtepu8_epi16(_mm_loadl_epi64((const __m128i *)x[i].scales));
+ auto shifts = _mm_and_si128(_mm_cmpeq_epi16(_mm_and_si128(scales128, m1), m1), m2);
+ scales128 = _mm_add_epi16(_mm_and_si128(scales128, mask), m127);
+ auto scales_s = _mm_mullo_epi16(scales128, _mm_add_epi16(m128, shifts));
+ s8k.accum_mins(scales_s, q8, i, d, accd);
+ return MM256_SET_M128I(scales128, scales128);
+ }
+ inline void prepare(int i, int j) {
+ bits.prepare(x[i].qs, j);
+ auto h = j == 0 ? hbits : _mm256_srli_epi16(hbits, 4);
+ for (int k = 0; k < 4; ++k) {
+ auto qh = _mm256_and_si256(_mm256_slli_epi16(h, 7-k), mh);
+ auto q5vl = _mm256_or_si256(bits.values[k], qh);
+ auto q5vh = _mm256_or_si256(bits.values[k], _mm256_xor_si256(qh, mh));
+ bits.values[k] = _mm256_or_si256(_mm256_shuffle_epi8(values[0], q5vl), _mm256_shuffle_epi8(values[1], q5vh));
+ }
+ }
+ static void load_values(__m256i * values) {
+        static const uint8_t kvalues_iq5nl[32] = {
+              2,  14,  25,  36,  45,  54,  63,  71,  78,  85,  92,  98, 104, 110, 116, 122,
+            127, 133, 139, 145, 151, 157, 164, 171, 179, 187, 196, 205, 215, 225, 237, 249,
+        };
+ auto values128_1 = _mm_loadu_si128((const __m128i *)kvalues_iq5nl + 0);
+ auto values128_2 = _mm_loadu_si128((const __m128i *)kvalues_iq5nl + 1);
+ values[0] = MM256_SET_M128I(values128_1, values128_1);
+ values[1] = MM256_SET_M128I(values128_2, values128_2);
+ }
+
+ Q4Bits bits;
+ Scales8KBase s8k;
+ __m256i hbits;
+ __m256i values[2];
+ const __m128i maskl = _mm_set1_epi8(0xf);
+ const __m128i maskh = _mm_set1_epi8(0x30);
+    const __m256i mh = _mm256_set1_epi8(-128); // -128 == 0x80; written this way to avoid the compiler warning about 0x80 overflowing int8_t
+ const __m128i mask = _mm_set1_epi16(254);
+ const __m128i m127 = _mm_set1_epi16(-127);
+ const __m128i m128 = _mm_set1_epi16(-128);
+ const __m128i m1 = _mm_set1_epi16(1);
+ const __m128i m2 = _mm_set1_epi16(2);
+};
+
+struct DequantizerIQ5K final : public BaseDequantizer<block_iq5_k> {
+ DequantizerIQ5K(const void * vx, size_t bx) : BaseDequantizer(vx, bx), iqxk(2, 0) { load_values(values); }
+ template <typename Q8>
+ inline void new_block(int i, const Q8& q8, __m256 * accm, __m256i * scales) {
+ d = GGML_FP16_TO_FP32(x[i].d);
+ iqxk.process(i, d, x[i].extra, make_scales(x[i].scales_l, (const uint16_t *)x[i].scales_h), q8, accm, scales);
+ hbits = _mm256_loadu_si256((const __m256i *)x[i].qh);
+ }
+ inline void prepare(int i, int j) {
+ bits.prepare(x[i].qs, j);
+ auto h = j == 0 ? hbits : _mm256_srli_epi16(hbits, 4);
+ for (int k = 0; k < 4; ++k) {
+ auto qh = _mm256_and_si256(_mm256_slli_epi16(h, 7-k), mh);
+ auto q5vl = _mm256_or_si256(bits.values[k], qh);
+ auto q5vh = _mm256_or_si256(bits.values[k], _mm256_xor_si256(qh, mh));
+ bits.values[k] = _mm256_or_si256(_mm256_shuffle_epi8(values[0], q5vl), _mm256_shuffle_epi8(values[1], q5vh));
+ }
+ }
+ __m128i make_scales(const uint8_t * scales_l, const uint16_t * scales_h) const {
+ uint64_t aux64;
+ memcpy(&aux64, scales_l, 8);
+ auto scl = _mm_and_si128(_mm_set_epi64x(aux64 >> 4, aux64), maskl);
+ const uint32_t aux32 = scales_h[0] | (scales_h[1] << 16);
+ auto aux = _mm_and_si128(_mm_set_epi32(aux32 >> 2, aux32, aux32 << 2, aux32 << 4), maskh);
+ auto sch = _mm_shuffle_epi8(aux, iqxk.hshuff);
+ return _mm_add_epi8(_mm_or_si128(scl, sch), m32);
+ }
+ static void load_values(__m256i * values) {
+ auto values128_1 = _mm_loadu_si128((const __m128i *)iq5nl_values + 0);
+ auto values128_2 = _mm_loadu_si128((const __m128i *)iq5nl_values + 1);
+ values[0] = MM256_SET_M128I(values128_1, values128_1);
+ values[1] = MM256_SET_M128I(values128_2, values128_2);
+ }
+
+ Q4Bits bits;
+ const IQXKScales iqxk;
+ __m256i hbits;
+ __m256i values[2];
+ const __m128i maskl = _mm_set1_epi8(0xf);
+ const __m128i maskh = _mm_set1_epi8(0x30);
+ const __m128i m32 = _mm_set1_epi8(-32);
+    const __m256i mh = _mm256_set1_epi8(-128); // -128 == 0x80; written this way to avoid the compiler warning about 0x80 overflowing int8_t
+};
+
+struct DequantizerIQ6K final : public BaseDequantizer<block_iq6_k> {
+ DequantizerIQ6K(const void * vx, size_t bx) : BaseDequantizer(vx, bx), iqxk(1, 0) { load_values(values); }
+ template <typename Q8>
+ inline void new_block(int i, const Q8& q8, __m256 * accm, __m256i * scales) {
+ d = GGML_FP16_TO_FP32(x[i].d);
+ auto scales8 = _mm_loadu_si128((const __m128i*)x[i].scales);
+ auto scales16 = _mm256_cvtepi8_epi16(scales8);
+ iqxk.process(i, d, x[i].extra, scales16, q8, accm, scales);
+ }
+ inline void prepare(int i, int j) {
+ bits.prepare(x[i].qs, j);
+ auto hbits = _mm256_loadu_si256((const __m256i *)x[i].qh + j);
+ for (int k = 0; k < 4; ++k) {
+ bits.values[k] = make_one(bits.values[k], hbits);
+ hbits = _mm256_srli_epi16(hbits, 2);
+ }
+ }
+ inline __m256i make_one(__m256i l, __m256i hbits) const {
+ auto mask4 = _mm256_cmpeq_epi8(_mm256_and_si256(hbits, mh3), mh3);
+ auto h1 = _mm256_andnot_si256(mask4, hbits);
+ auto mask2 = _mm256_cmpeq_epi8(_mm256_and_si256(h1, mh1), mh1);
+ auto mask3 = _mm256_cmpeq_epi8(_mm256_and_si256(h1, mh2), mh2);
+ auto mask1 = _mm256_andnot_si256(_mm256_or_si256(mask4, _mm256_or_si256(mask2, mask3)), _mm256_set1_epi8(-1)); // 0xff;
+ return _mm256_or_si256(_mm256_or_si256(_mm256_and_si256(mask1, _mm256_shuffle_epi8(values[0], l)),
+ _mm256_and_si256(mask2, _mm256_shuffle_epi8(values[1], l))),
+ _mm256_or_si256(_mm256_and_si256(mask3, _mm256_shuffle_epi8(values[2], l)),
+ _mm256_and_si256(mask4, _mm256_shuffle_epi8(values[3], l))));
+ }
+ static void load_values(__m256i * values) {
+ for (int k = 0; k < 4; ++k) {
+ auto values128 = _mm_loadu_si128((const __m128i *)iq6nl_values + k);
+ values[k] = MM256_SET_M128I(values128, values128);
+ }
+ }
+
+ Q4Bits bits;
+ const IQXKScales iqxk;
+ __m256i values[4];
+ const __m256i mh1 = _mm256_set1_epi8(1);
+ const __m256i mh2 = _mm256_set1_epi8(2);
+ const __m256i mh3 = _mm256_set1_epi8(3);
+    const __m256i mh = _mm256_set1_epi8(-128); // -128 == 0x80; written this way to avoid the compiler warning about 0x80 overflowing int8_t
+};
+
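+// set_scales_16 / set_scales_8 broadcast each 16-bit block scale across the lanes
+// that will multiply the corresponding group of quants in multiply_add.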
+inline __m256i get_scale_shuffle_16(int i) {
+ static const uint8_t k_shuffle[128] = {
+ 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3,
+ 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7,
+ 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,
+ 12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13, 14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,
+ };
+ return _mm256_loadu_si256((const __m256i*)k_shuffle + i);
+}
+
+inline void set_scales_16(const __m256i& all_scales, __m256i * scales) {
+ scales[0] = _mm256_shuffle_epi8(all_scales, get_scale_shuffle_16(0));
+ scales[1] = _mm256_shuffle_epi8(all_scales, get_scale_shuffle_16(1));
+ scales[2] = _mm256_shuffle_epi8(all_scales, get_scale_shuffle_16(2));
+ scales[3] = _mm256_shuffle_epi8(all_scales, get_scale_shuffle_16(3));
+}
+
+inline __m256i get_scale_shuffle_8(int i) {
+ return _mm256_set1_epi16((2*i) | ((2*i+1) << 8));
+}
+
+inline void set_scales_8(const __m256i& all_scales, int j, __m256i * scales) {
+ scales[0] = _mm256_shuffle_epi8(all_scales, get_scale_shuffle_8(4*j+0));
+ scales[1] = _mm256_shuffle_epi8(all_scales, get_scale_shuffle_8(4*j+1));
+ scales[2] = _mm256_shuffle_epi8(all_scales, get_scale_shuffle_8(4*j+2));
+ scales[3] = _mm256_shuffle_epi8(all_scales, get_scale_shuffle_8(4*j+3));
+}
+
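+// AVX2 kernel for dequantizers whose new_block() produces 16 16-bit scales per
+// super-block (all_scales[2]). IQ4_K / IQ5_K / IQ6_K are routed through
+// multiply_add_avx2, the others through multiply_add.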
+template <typename Dequantizer, int nrc_y>
+static void mul_mat_qY_K_q8_K_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ assert(n%QK_K == 0);
+ const int nb = n/QK_K;
+
+ Q8<nrc_y> q8(info);
+
+ __m256i all_scales[2];
+ __m256i scales[4];
+ __m256 accd[nrc_y];
+
+ Dequantizer deq(vx, bx);
+
+ for (int ix = 0; ix < nrc_x; ++ix) {
+
+ deq.new_row(ix);
+
+ for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm256_setzero_ps();
+
+ for (int i = 0; i < nb; ++i) {
+
+ deq.new_block(i, q8, accd, all_scales);
+
+ __m256i sumi[nrc_y];
+
+ for (int j = 0; j < QK_K/128; ++j) {
+ deq.prepare(i, j);
+ set_scales_16(all_scales[j], scales);
+ if constexpr (std::is_same_v<Dequantizer, DequantizerIQ4K> ||
+ std::is_same_v<Dequantizer, DequantizerIQ5K> ||
+ std::is_same_v<Dequantizer, DequantizerIQ6K>) {
+ multiply_add_avx2(deq.bits, scales, j, i, q8, sumi);
+ } else {
+ multiply_add(deq.bits, scales, j, i, q8, sumi);
+ }
+ }
+
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ accd[iy] = _mm256_fmadd_ps(_mm256_set1_ps(deq.d*q8.scale(iy, i)), _mm256_cvtepi32_ps(sumi[iy]), accd[iy]);
+ }
+
+ }
+
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ info.store(ix, iy, hsum_float_8(accd[iy]));
+ }
+
+ }
+
+}
+
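+// AVX2 kernel for dequantizers whose new_block() returns a single scale vector
+// (8 block scales per super-block, expanded per 128-quant half by set_scales_8).
+// DequantizerIQ4KS goes through multiply_add_avx2, the others through multiply_add.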
+template <typename Dequantizer, int nrc_y>
+static void mul_mat_qX_K_q8_K_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ assert(n % QK_K == 0);
+ const int nb = n / QK_K;
+
+ Q8<nrc_y> q8(info);
+
+ Dequantizer deq(vx, bx);
+
+ __m256 accd[nrc_y];
+ __m256i scales[4];
+
+ for (int ix = 0; ix < nrc_x; ++ix) {
+
+ for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm256_setzero_ps();
+
+ deq.new_row(ix);
+
+ for (int i = 0; i < nb; ++i) {
+
+ auto all_scales = deq.new_block(i, q8, accd);
+
+ __m256i sumi[nrc_y];
+
+ for (int j = 0; j < QK_K/128; ++j) {
+
+ deq.prepare(i, j);
+
+ set_scales_8(all_scales, j, scales);
+
+ if constexpr (std::is_same_v<Dequantizer, DequantizerIQ4KS>) {
+ multiply_add_avx2(deq.bits, scales, j, i, q8, sumi);
+ } else {
+ multiply_add(deq.bits, scales, j, i, q8, sumi);
+ }
+
+ }
+
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ const __m256 vd = _mm256_set1_ps(deq.d*q8.scale(iy, i));
+ accd[iy] = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(sumi[iy]), accd[iy]);
+ }
+
+ }
+
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ info.store(ix, iy, hsum_float_8(accd[iy]));
+ }
+
+ }
+}
+
+#endif
+
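+// Accumulates the constant (minimum) part of the repacked (_r4) IQ*_K dot products:
+// the two 32x8-bit scale vectors are widened to 16 bits, matched up with the q8 block
+// sums, and min * scale * bsum is added to isum[]. For nrc_y == 1 the multiplication
+// by min is deferred to one 32-bit multiply at the end.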
+template <int nrc_y>
+inline void iq234_k_accum_mins(int ibl, __m256i i8scales1, __m256i i8scales2, const Q8<nrc_y, block_q8_K>& q8, __m256i shuff,
+ __m256i * isum, int16_t min) {
+ auto t1 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(i8scales1, 0)), shuff); // blocks 0, 1, 2, 3 for each row
+ auto t2 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(i8scales1, 1)), shuff); // blocks 4, 5, 6, 7 for each row
+ auto t3 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(i8scales2, 0)), shuff); // blocks 8, 9, 10, 11 for each row
+ auto t4 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(i8scales2, 1)), shuff); // blocks 12, 13, 14, 15 for each row
+ if constexpr (nrc_y == 1) {
+ auto s1 = MM256_SET_M128I(_mm256_extracti128_si256(t3, 0), _mm256_extracti128_si256(t1, 0)); // blocks 0, 1, 8, 9
+ auto s2 = MM256_SET_M128I(_mm256_extracti128_si256(t3, 1), _mm256_extracti128_si256(t1, 1)); // blocks 2, 3, 10, 11
+ auto s3 = MM256_SET_M128I(_mm256_extracti128_si256(t4, 0), _mm256_extracti128_si256(t2, 0)); // blocks 4, 5, 12, 13
+ auto s4 = MM256_SET_M128I(_mm256_extracti128_si256(t4, 1), _mm256_extracti128_si256(t2, 1)); // blocks 6, 7, 14, 15
+ auto sumi = _mm256_setzero_si256();
+ auto bsums = q8.load_bsums(0, ibl);
+#ifdef HAVE_FANCY_SIMD
+ sumi = _mm256_dpwssd_epi32(sumi, s1, _mm256_shuffle_epi32(bsums, 0x00));
+ sumi = _mm256_dpwssd_epi32(sumi, s2, _mm256_shuffle_epi32(bsums, 0x55));
+ sumi = _mm256_dpwssd_epi32(sumi, s3, _mm256_shuffle_epi32(bsums, 0xaa));
+ sumi = _mm256_dpwssd_epi32(sumi, s4, _mm256_shuffle_epi32(bsums, 0xff));
+#else
+ sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(s1, _mm256_shuffle_epi32(bsums, 0x00)));
+ sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(s2, _mm256_shuffle_epi32(bsums, 0x55)));
+ sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(s3, _mm256_shuffle_epi32(bsums, 0xaa)));
+ sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(s4, _mm256_shuffle_epi32(bsums, 0xff)));
+#endif
+ isum[0] = _mm256_mullo_epi32(sumi, _mm256_set1_epi32(min));
+
+ } else {
+ auto s1 = _mm256_mullo_epi16(_mm256_set1_epi16(min), MM256_SET_M128I(_mm256_extracti128_si256(t3, 0), _mm256_extracti128_si256(t1, 0))); // blocks 0, 1, 8, 9
+ auto s2 = _mm256_mullo_epi16(_mm256_set1_epi16(min), MM256_SET_M128I(_mm256_extracti128_si256(t3, 1), _mm256_extracti128_si256(t1, 1))); // blocks 2, 3, 10, 11
+ auto s3 = _mm256_mullo_epi16(_mm256_set1_epi16(min), MM256_SET_M128I(_mm256_extracti128_si256(t4, 0), _mm256_extracti128_si256(t2, 0))); // blocks 4, 5, 12, 13
+ auto s4 = _mm256_mullo_epi16(_mm256_set1_epi16(min), MM256_SET_M128I(_mm256_extracti128_si256(t4, 1), _mm256_extracti128_si256(t2, 1))); // blocks 6, 7, 14, 15
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto bsums = q8.load_bsums(iy, ibl);
+#ifdef HAVE_FANCY_SIMD
+ isum[iy] = _mm256_dpwssd_epi32(isum[iy], s1, _mm256_shuffle_epi32(bsums, 0x00));
+ isum[iy] = _mm256_dpwssd_epi32(isum[iy], s2, _mm256_shuffle_epi32(bsums, 0x55));
+ isum[iy] = _mm256_dpwssd_epi32(isum[iy], s3, _mm256_shuffle_epi32(bsums, 0xaa));
+ isum[iy] = _mm256_dpwssd_epi32(isum[iy], s4, _mm256_shuffle_epi32(bsums, 0xff));
+#else
+ isum[iy] = _mm256_add_epi32(isum[iy], _mm256_madd_epi16(s1, _mm256_shuffle_epi32(bsums, 0x00)));
+ isum[iy] = _mm256_add_epi32(isum[iy], _mm256_madd_epi16(s2, _mm256_shuffle_epi32(bsums, 0x55)));
+ isum[iy] = _mm256_add_epi32(isum[iy], _mm256_madd_epi16(s3, _mm256_shuffle_epi32(bsums, 0xaa)));
+ isum[iy] = _mm256_add_epi32(isum[iy], _mm256_madd_epi16(s4, _mm256_shuffle_epi32(bsums, 0xff)));
+#endif
+ }
+ }
+}
+
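+// Variant of iq234_k_accum_mins where the per-block minimum also depends on the
+// "extra" bits: blocks with the corresponding extra bit set use min + delta.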
+template <int nrc_y>
+inline void iq2345_k_accum_mins(int ibl, __m256i i8scales1, __m256i i8scales2, const Q8<nrc_y, block_q8_K>& q8, __m256i shuff,
+ __m256i extra, __m256i * isum, int8_t min, int8_t delta) {
+ auto mask = _mm256_set_epi64x(0x0808080808080808, 0x0404040404040404, 0x0202020202020202, 0x0101010101010101);
+ auto vdelta = _mm256_set1_epi8(delta);
+ auto vmin = _mm256_set1_epi8(min);
+ auto min1 = _mm256_add_epi8(vmin, _mm256_and_si256(vdelta, _mm256_cmpeq_epi8(_mm256_and_si256(extra, mask), mask)));
+ auto min2 = _mm256_add_epi8(vmin, _mm256_and_si256(vdelta, _mm256_cmpeq_epi8(_mm256_and_si256(_mm256_srli_epi16(extra, 4), mask), mask)));
+ auto t1 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(i8scales1, 0)), shuff); // blocks 0, 1, 2, 3 for each row
+ auto t2 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(i8scales1, 1)), shuff); // blocks 4, 5, 6, 7 for each row
+ auto t3 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(i8scales2, 0)), shuff); // blocks 8, 9, 10, 11 for each row
+ auto t4 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(i8scales2, 1)), shuff); // blocks 12, 13, 14, 15 for each row
+ auto m1 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(min1, 0)), shuff); // blocks 0, 1, 2, 3 for each row
+ auto m2 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(min1, 1)), shuff); // blocks 4, 5, 6, 7 for each row
+ auto m3 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(min2, 0)), shuff); // blocks 8, 9, 10, 11 for each row
+ auto m4 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(min2, 1)), shuff); // blocks 12, 13, 14, 15 for each row
+ auto s1 = _mm256_mullo_epi16(MM256_SET_M128I(_mm256_extracti128_si256(m3, 0), _mm256_extracti128_si256(m1, 0)),
+ MM256_SET_M128I(_mm256_extracti128_si256(t3, 0), _mm256_extracti128_si256(t1, 0))); // blocks 0, 1, 8, 9
+ auto s2 = _mm256_mullo_epi16(MM256_SET_M128I(_mm256_extracti128_si256(m3, 1), _mm256_extracti128_si256(m1, 1)),
+ MM256_SET_M128I(_mm256_extracti128_si256(t3, 1), _mm256_extracti128_si256(t1, 1))); // blocks 2, 3, 10, 11
+ auto s3 = _mm256_mullo_epi16(MM256_SET_M128I(_mm256_extracti128_si256(m4, 0), _mm256_extracti128_si256(m2, 0)),
+ MM256_SET_M128I(_mm256_extracti128_si256(t4, 0), _mm256_extracti128_si256(t2, 0))); // blocks 4, 5, 12, 13
+ auto s4 = _mm256_mullo_epi16(MM256_SET_M128I(_mm256_extracti128_si256(m4, 1), _mm256_extracti128_si256(m2, 1)),
+ MM256_SET_M128I(_mm256_extracti128_si256(t4, 1), _mm256_extracti128_si256(t2, 1))); // blocks 6, 7, 14, 15
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto bsums = q8.load_bsums(iy, ibl);
+#ifdef HAVE_FANCY_SIMD
+ isum[iy] = _mm256_dpwssd_epi32(isum[iy], s1, _mm256_shuffle_epi32(bsums, 0x00));
+ isum[iy] = _mm256_dpwssd_epi32(isum[iy], s2, _mm256_shuffle_epi32(bsums, 0x55));
+ isum[iy] = _mm256_dpwssd_epi32(isum[iy], s3, _mm256_shuffle_epi32(bsums, 0xaa));
+ isum[iy] = _mm256_dpwssd_epi32(isum[iy], s4, _mm256_shuffle_epi32(bsums, 0xff));
+#else
+ isum[iy] = _mm256_add_epi32(isum[iy], _mm256_madd_epi16(s1, _mm256_shuffle_epi32(bsums, 0x00)));
+ isum[iy] = _mm256_add_epi32(isum[iy], _mm256_madd_epi16(s2, _mm256_shuffle_epi32(bsums, 0x55)));
+ isum[iy] = _mm256_add_epi32(isum[iy], _mm256_madd_epi16(s3, _mm256_shuffle_epi32(bsums, 0xaa)));
+ isum[iy] = _mm256_add_epi32(isum[iy], _mm256_madd_epi16(s4, _mm256_shuffle_epi32(bsums, 0xff)));
+#endif
+ }
+}
+
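+// Row-interleaved (_r4) kernels: 4 rows are processed together. For each 32-quant
+// block the packed quants are expanded via table shuffles, the per-row scales are
+// re-broadcast from stored_scales[], and the dot products with q8 are taken with
+// dpbusd (when HAVE_FANCY_SIMD is defined) or maddubs/madd (plain AVX2).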
+template <int nrc_y>
+static void mul_mat_iq2_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ GGML_ASSERT(nrc_x%4 == 0);
+ Q8<nrc_y, block_q8_K> q8(info);
+ auto m4 = _mm256_set1_epi8(0xf);
+ auto ms = _mm256_set1_epi8(4);
+ auto m03 = _mm256_set1_epi8(0x03);
+ auto shift_shuffle = _mm256_set_epi64x(0x0707070706060606, 0x0505050504040404, 0x0303030302020202, 0x0101010100000000);
+ static const uint8_t kvalues_iq2nl[32] = {1, 19, 33, 49, 6, 24, 38, 54, 1, 19, 33, 49, 6, 24, 38, 54, 1, 19, 33, 49, 6, 24, 38, 54, 1, 19, 33, 49, 6, 24, 38, 54};
+ auto values = _mm256_loadu_si256((const __m256i*)kvalues_iq2nl);
+ static const uint8_t k_shuff[32] = {0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15, 0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15};
+ auto shuff = _mm256_loadu_si256((const __m256i *)k_shuff);
+#ifndef HAVE_FANCY_SIMD
+ auto s_shuffle = _mm256_set_epi64x(0x0f0e0f0e0d0c0d0c, 0x0b0a0b0a09080908, 0x0706070605040504, 0x0302030201000100);
+#endif
+ int nbl = n / QK_K;
+ __m256 acc[nrc_y] = {};
+ __m256i qx[4];
+ uint64_t stored_scales[8];
+ for (int ix = 0; ix < nrc_x; ix += 4) {
+ const block_iq2_k_r4 * iq2 = (const block_iq2_k_r4 *)((const char *)vx + (ix+0)*bx);
+ for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256
+ auto dl = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq2[ibl].d));
+ auto d4 = _mm256_set_m128(dl, dl);
+ auto extra = _mm256_set1_epi64x(*(const uint64_t *)iq2[ibl].extra);
+ auto slbits = _mm256_loadu_si256((const __m256i *)iq2[ibl].scales);
+ auto i8scales1 = _mm256_add_epi8(_mm256_and_si256(slbits, m4), _mm256_set1_epi8(-8));
+ auto i8scales2 = _mm256_add_epi8(_mm256_and_si256(_mm256_srli_epi16(slbits, 4), m4), _mm256_set1_epi8(-8));
+ _mm256_storeu_si256((__m256i *)stored_scales+0, i8scales1);
+ _mm256_storeu_si256((__m256i *)stored_scales+1, i8scales2);
+ __m256i isum[nrc_y] = {};
+ iq234_k_accum_mins(ibl, i8scales1, i8scales2, q8, shuff, isum, -32);
+ for (int ib = 0; ib < QK_K/32; ++ib) {
+#ifdef HAVE_FANCY_SIMD
+ auto scales = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(stored_scales + ib)));
+#else
+ auto scales = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm_set1_epi64x(stored_scales[ib])), s_shuffle);
+#endif
+ auto lb = _mm256_loadu_si256((const __m256i *)iq2[ibl].qs+ib);
+ auto shift = _mm256_and_si256(ms, _mm256_slli_epi16(extra, 2)); extra = _mm256_srli_epi16(extra, 1);
+ shift = _mm256_shuffle_epi8(shift, shift_shuffle);
+ qx[0] = _mm256_and_si256(lb, m03);
+ qx[1] = _mm256_and_si256(_mm256_srli_epi16(lb, 2), m03);
+ qx[2] = _mm256_and_si256(_mm256_srli_epi16(lb, 4), m03);
+ qx[3] = _mm256_and_si256(_mm256_srli_epi16(lb, 6), m03);
+ qx[0] = _mm256_shuffle_epi8(values, _mm256_add_epi8(qx[0], shift));
+ qx[1] = _mm256_shuffle_epi8(values, _mm256_add_epi8(qx[1], shift));
+ qx[2] = _mm256_shuffle_epi8(values, _mm256_add_epi8(qx[2], shift));
+ qx[3] = _mm256_shuffle_epi8(values, _mm256_add_epi8(qx[3], shift));
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto y = _mm256_loadu_si256((const __m256i*)q8.y[iy][ibl].qs+ib);
+#ifdef HAVE_FANCY_SIMD
+ auto sumi = _mm256_setzero_si256();
+ sumi = _mm256_dpbusd_epi32(sumi, qx[0], _mm256_shuffle_epi32(y, 0x00));
+ sumi = _mm256_dpbusd_epi32(sumi, qx[1], _mm256_shuffle_epi32(y, 0x55));
+ sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(y, 0xaa));
+ sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(y, 0xff));
+ isum[iy] = _mm256_add_epi32(isum[iy], _mm256_mullo_epi32(scales, sumi));
+#else
+ auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[0], _mm256_shuffle_epi32(y, 0x00)),
+ _mm256_maddubs_epi16(qx[1], _mm256_shuffle_epi32(y, 0x55)));
+ auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[2], _mm256_shuffle_epi32(y, 0xaa)),
+ _mm256_maddubs_epi16(qx[3], _mm256_shuffle_epi32(y, 0xff)));
+ isum[iy] = _mm256_add_epi32(isum[iy], _mm256_add_epi32(_mm256_madd_epi16(scales, sumi1), _mm256_madd_epi16(scales, sumi2)));
+#endif
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(isum[iy]), acc[iy]);
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1));
+ acc[iy] = _mm256_setzero_ps();
+ info.store(ix+0, iy, sum);
+ }
+ }
+}
+
+template <int nrc_y>
+static void mul_mat_iq3_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ GGML_ASSERT(nrc_x%4 == 0);
+ Q8<nrc_y, block_q8_K> q8(info);
+ auto m4 = _mm256_set1_epi8(0xf);
+ auto ms = _mm256_set1_epi8(8);
+ auto m03 = _mm256_set1_epi8(0x03);
+ auto m04 = _mm256_set1_epi8(0x04);
+ auto smask = _mm256_set_epi64x(0x0808080808080808, 0x0404040404040404, 0x0202020202020202, 0x0101010101010101);
+ auto shift_shuffle = _mm256_set_epi64x(0x0707070706060606, 0x0505050504040404, 0x0303030302020202, 0x0101010100000000);
+ auto values128 = _mm_loadu_si128((const __m128i *)iq3nl_values);
+ auto values = MM256_SET_M128I(values128, values128);
+ values = _mm256_add_epi8(values, _mm256_set1_epi8(64));
+ static const uint8_t k_shuff[32] = {0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15, 0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15};
+ auto shuff = _mm256_loadu_si256((const __m256i *)k_shuff);
+#ifndef HAVE_FANCY_SIMD
+ auto s_shuffle = _mm256_set_epi64x(0x0f0e0f0e0d0c0d0c, 0x0b0a0b0a09080908, 0x0706070605040504, 0x0302030201000100);
+#endif
+ int nbl = n / QK_K;
+ __m256 acc[nrc_y] = {};
+ __m256i qx[4];
+ uint64_t stored_scales[8];
+ for (int ix = 0; ix < nrc_x; ix += 4) {
+ const block_iq3_k_r4 * iq3 = (const block_iq3_k_r4 *)((const char *)vx + (ix+0)*bx);
+ for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256
+ auto dl = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq3[ibl].d));
+ auto d4 = _mm256_set_m128(dl, dl);
+ auto extra = _mm256_set1_epi64x(*(const uint64_t *)iq3[ibl].extra);
+ auto slbits = _mm256_loadu_si256((const __m256i *)iq3[ibl].scales_l);
+ auto sl1 = _mm256_add_epi8(_mm256_slli_epi16(_mm256_and_si256(slbits, m4), 1), _mm256_set1_epi8(1));
+ auto sl2 = _mm256_add_epi8(_mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(slbits, 4), m4), 1), _mm256_set1_epi8(1));
+ auto sh = _mm256_set1_epi64x(((const uint64_t *)iq3[ibl].scales_h)[0]);
+ auto sh1 = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(sh, smask), smask), _mm256_set1_epi8(1));
+ auto sh2 = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(_mm256_srli_epi16(sh, 4), smask), smask), _mm256_set1_epi8(1));
+ auto i8scales1 = _mm256_sign_epi8(sl1, sh1);
+ auto i8scales2 = _mm256_sign_epi8(sl2, sh2);
+ _mm256_storeu_si256((__m256i *)stored_scales+0, i8scales1);
+ _mm256_storeu_si256((__m256i *)stored_scales+1, i8scales2);
+ __m256i isum[nrc_y] = {};
+ iq234_k_accum_mins(ibl, i8scales1, i8scales2, q8, shuff, isum, -64);
+ for (int ib = 0; ib < QK_K/32; ++ib) {
+#ifdef HAVE_FANCY_SIMD
+ auto scales = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(stored_scales + ib)));
+#else
+ auto scales = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm_set1_epi64x(stored_scales[ib])), s_shuffle);
+#endif
+ auto lb = _mm256_loadu_si256((const __m256i *)iq3[ibl].qs+ib);
+ auto hbits = _mm_loadu_si128((const __m128i *)iq3[ibl].qh+ib);
+ auto hb = MM256_SET_M128I(hbits, _mm_slli_epi16(hbits, 4));
+ auto shift = _mm256_and_si256(ms, _mm256_slli_epi16(extra, 3)); extra = _mm256_srli_epi16(extra, 1);
+ shift = _mm256_shuffle_epi8(shift, shift_shuffle);
+ qx[0] = _mm256_or_si256(_mm256_and_si256(lb, m03), _mm256_and_si256(m04, _mm256_srli_epi16(hb, 2)));
+ qx[1] = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(lb, 2), m03), _mm256_and_si256(m04, _mm256_srli_epi16(hb, 3)));
+ qx[2] = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(lb, 4), m03), _mm256_and_si256(m04, _mm256_srli_epi16(hb, 4)));
+ qx[3] = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(lb, 6), m03), _mm256_and_si256(m04, _mm256_srli_epi16(hb, 5)));
+ qx[0] = _mm256_shuffle_epi8(values, _mm256_add_epi8(qx[0], shift));
+ qx[1] = _mm256_shuffle_epi8(values, _mm256_add_epi8(qx[1], shift));
+ qx[2] = _mm256_shuffle_epi8(values, _mm256_add_epi8(qx[2], shift));
+ qx[3] = _mm256_shuffle_epi8(values, _mm256_add_epi8(qx[3], shift));
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto y = _mm256_loadu_si256((const __m256i*)q8.y[iy][ibl].qs+ib);
+#ifdef HAVE_FANCY_SIMD
+ auto sumi = _mm256_setzero_si256();
+ sumi = _mm256_dpbusd_epi32(sumi, qx[0], _mm256_shuffle_epi32(y, 0x00));
+ sumi = _mm256_dpbusd_epi32(sumi, qx[1], _mm256_shuffle_epi32(y, 0x55));
+ sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(y, 0xaa));
+ sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(y, 0xff));
+ isum[iy] = _mm256_add_epi32(isum[iy], _mm256_mullo_epi32(scales, sumi));
+#else
+ auto sumi1 = _mm256_maddubs_epi16(qx[0], _mm256_shuffle_epi32(y, 0x00));
+ auto sumi2 = _mm256_maddubs_epi16(qx[1], _mm256_shuffle_epi32(y, 0x55));
+ auto sumi3 = _mm256_maddubs_epi16(qx[2], _mm256_shuffle_epi32(y, 0xaa));
+ auto sumi4 = _mm256_maddubs_epi16(qx[3], _mm256_shuffle_epi32(y, 0xff));
+ isum[iy] = _mm256_add_epi32(isum[iy], _mm256_add_epi32(_mm256_madd_epi16(scales, sumi1), _mm256_madd_epi16(scales, sumi2)));
+ isum[iy] = _mm256_add_epi32(isum[iy], _mm256_add_epi32(_mm256_madd_epi16(scales, sumi3), _mm256_madd_epi16(scales, sumi4)));
+#endif
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(isum[iy]), acc[iy]);
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1));
+ acc[iy] = _mm256_setzero_ps();
+ info.store(ix+0, iy, sum);
+ }
+ }
+}
+
+template <int nrc_y>
+static void mul_mat_iq4_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ GGML_ASSERT(nrc_x%4 == 0);
+ Q8<nrc_y, block_q8_K> q8(info);
+ auto m4 = _mm256_set1_epi8(0xf);
+ auto m30 = _mm256_set1_epi8(0x30);
+ auto m32 = _mm256_set1_epi8(32);
+ auto ms = _mm256_set1_epi8(4);
+ auto shift_shuffle = _mm256_set_epi64x(0x0707070706060606, 0x0505050504040404, 0x0303030302020202, 0x0101010100000000);
+#ifdef HAVE_FANCY_SIMD
+ auto values = load_iq4nl_values_256();
+ static const uint8_t k_shuff[32] = {0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15, 0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15};
+ auto shuff = _mm256_loadu_si256((const __m256i *)k_shuff);
+#else
+ auto s_shuffle = _mm256_set_epi64x(0x0f0e0f0e0d0c0d0c, 0x0b0a0b0a09080908, 0x0706070605040504, 0x0302030201000100);
+ auto values128 = _mm_loadu_si128((const __m128i *)iq4k_values);
+ auto values = MM256_SET_M128I(values128, values128);
+#endif
+ int nbl = n / QK_K;
+ __m256 acc[nrc_y] = {};
+ __m256i qx[4];
+ uint64_t stored_scales[8];
+ for (int ix = 0; ix < nrc_x; ix += 4) {
+ const block_iq4_k_r4 * iq4 = (const block_iq4_k_r4 *)((const char *)vx + (ix+0)*bx);
+ for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256
+ auto dl = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq4[ibl].d));
+ auto d4 = _mm256_set_m128(dl, dl);
+ auto extra = _mm256_set1_epi64x(*(const uint64_t *)iq4[ibl].extra);
+ auto slbits = _mm256_loadu_si256((const __m256i *)iq4[ibl].scales_l);
+ auto sl1 = _mm256_and_si256(slbits, m4);
+ auto sl2 = _mm256_and_si256(_mm256_srli_epi16(slbits, 4), m4);
+ auto shbits = _mm_loadu_si128((const __m128i*)iq4[ibl].scales_h);
+ auto sh = MM256_SET_M128I(_mm_srli_epi16(shbits, 2), shbits);
+ auto i8scales1 = _mm256_sub_epi8(_mm256_or_si256(sl1, _mm256_and_si256(m30, _mm256_slli_epi16(sh, 4))), m32);
+ auto i8scales2 = _mm256_sub_epi8(_mm256_or_si256(sl2, _mm256_and_si256(m30, sh)), m32);
+ _mm256_storeu_si256((__m256i *)stored_scales+0, i8scales1);
+ _mm256_storeu_si256((__m256i *)stored_scales+1, i8scales2);
+ __m256i isum[nrc_y] = {};
+#ifdef HAVE_FANCY_SIMD
+ iq234_k_accum_mins(ibl, i8scales1, i8scales2, q8, shuff, isum, -128);
+#endif
+ for (int ib = 0; ib < QK_K/32; ++ib) {
+#ifdef HAVE_FANCY_SIMD
+ auto scales = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(stored_scales + ib)));
+#else
+ auto scales = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm_set1_epi64x(stored_scales[ib])), s_shuffle);
+#endif
+ auto bits1 = _mm256_loadu_si256((const __m256i *)iq4[ibl].qs+2*ib+0);
+ auto bits2 = _mm256_loadu_si256((const __m256i *)iq4[ibl].qs+2*ib+1);
+ auto shift = _mm256_and_si256(ms, _mm256_slli_epi16(extra, 2)); extra = _mm256_srli_epi16(extra, 1);
+ shift = _mm256_shuffle_epi8(shift, shift_shuffle);
+ qx[0] = _mm256_add_epi8(shift, _mm256_shuffle_epi8(values, _mm256_and_si256(bits1, m4)));
+ qx[1] = _mm256_add_epi8(shift, _mm256_shuffle_epi8(values, _mm256_and_si256(bits2, m4)));
+ qx[2] = _mm256_add_epi8(shift, _mm256_shuffle_epi8(values, _mm256_and_si256(_mm256_srli_epi16(bits1, 4), m4)));
+ qx[3] = _mm256_add_epi8(shift, _mm256_shuffle_epi8(values, _mm256_and_si256(_mm256_srli_epi16(bits2, 4), m4)));
+#ifndef HAVE_FANCY_SIMD
+ auto s1 = _mm256_sign_epi8(qx[0], qx[0]);
+ auto s2 = _mm256_sign_epi8(qx[1], qx[1]);
+ auto s3 = _mm256_sign_epi8(qx[2], qx[2]);
+ auto s4 = _mm256_sign_epi8(qx[3], qx[3]);
+#endif
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto y = _mm256_loadu_si256((const __m256i*)q8.y[iy][ibl].qs+ib);
+#ifdef HAVE_FANCY_SIMD
+ auto sumi = _mm256_setzero_si256();
+ sumi = _mm256_dpbusd_epi32(sumi, qx[0], _mm256_shuffle_epi32(y, 0x00));
+ sumi = _mm256_dpbusd_epi32(sumi, qx[1], _mm256_shuffle_epi32(y, 0x55));
+ sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(y, 0xaa));
+ sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(y, 0xff));
+ isum[iy] = _mm256_add_epi32(isum[iy], _mm256_mullo_epi32(scales, sumi));
+#else
+ auto sumi1 = _mm256_maddubs_epi16(s1, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), qx[0]));
+ auto sumi2 = _mm256_maddubs_epi16(s2, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), qx[1]));
+ auto sumi3 = _mm256_maddubs_epi16(s3, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), qx[2]));
+ auto sumi4 = _mm256_maddubs_epi16(s4, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), qx[3]));
+ isum[iy] = _mm256_add_epi32(isum[iy], _mm256_add_epi32(_mm256_madd_epi16(scales, sumi1), _mm256_madd_epi16(scales, sumi2)));
+ isum[iy] = _mm256_add_epi32(isum[iy], _mm256_add_epi32(_mm256_madd_epi16(scales, sumi3), _mm256_madd_epi16(scales, sumi4)));
+#endif
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(isum[iy]), acc[iy]);
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1));
+ acc[iy] = _mm256_setzero_ps();
+ info.store(ix+0, iy, sum);
+ }
+ }
+}
+
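+// 5-bit quant expansion: ql holds the low 4 bits, and the bit selected by `mask` in
+// qh chooses between the low and high half of the non-linear value table
+// (mask blend on AVX512, blendv on AVX2).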
+static inline __m256i prepare_5bit_quants(const __m256i * values, __m256i ql, __m256i qh, __m256i mask) {
+ auto q5vl = _mm256_shuffle_epi8(values[0], ql);
+ auto q5vh = _mm256_shuffle_epi8(values[1], ql);
+#ifdef HAVE_FANCY_SIMD
+ return _mm256_mask_blend_epi8(_mm256_cmpeq_epi8_mask(_mm256_and_si256(qh, mask), mask), q5vl, q5vh);
+#else
+ return _mm256_blendv_epi8(q5vl, q5vh, _mm256_cmpeq_epi8(_mm256_and_si256(qh, mask), mask));
+#endif
+}
+
+template <int nrc_y>
+static void mul_mat_iq5_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ GGML_ASSERT(nrc_x%4 == 0);
+ Q8<nrc_y, block_q8_K> q8(info);
+ auto m4 = _mm256_set1_epi8(0xf);
+ auto m30 = _mm256_set1_epi8(0x30);
+ auto m32 = _mm256_set1_epi8(32);
+ auto ms = _mm256_set1_epi8(2);
+ auto shift_shuffle = _mm256_set_epi64x(0x0707070706060606, 0x0505050504040404, 0x0303030302020202, 0x0101010100000000);
+ __m256i values[2];
+ {
+ auto val1 = _mm_loadu_si128((const __m128i *)iq5nl_values+0);
+ auto val2 = _mm_loadu_si128((const __m128i *)iq5nl_values+1);
+ values[0] = MM256_SET_M128I(val1, val1);
+ values[1] = MM256_SET_M128I(val2, val2);
+#ifdef HAVE_FANCY_SIMD
+ values[0] = _mm256_sub_epi8(values[0], _mm256_set1_epi8(-128));
+ values[1] = _mm256_sub_epi8(values[1], _mm256_set1_epi8(-128));
+#endif
+ }
+#ifdef HAVE_FANCY_SIMD
+ static const uint8_t k_shuff[32] = {0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15, 0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15};
+ auto shuff = _mm256_loadu_si256((const __m256i *)k_shuff);
+#else
+ auto s_shuffle = _mm256_set_epi64x(0x0f0e0f0e0d0c0d0c, 0x0b0a0b0a09080908, 0x0706070605040504, 0x0302030201000100);
+#endif
+ int nbl = n / QK_K;
+ __m256 acc[nrc_y] = {};
+ __m256i qx[4];
+ uint64_t stored_scales[8];
+ for (int ix = 0; ix < nrc_x; ix += 4) {
+ const block_iq5_k_r4 * iq5 = (const block_iq5_k_r4 *)((const char *)vx + (ix+0)*bx);
+ for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256
+ auto dl = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq5[ibl].d));
+ auto d4 = _mm256_set_m128(dl, dl);
+ auto extra = _mm256_set1_epi64x(*(const uint64_t *)iq5[ibl].extra);
+ auto slbits = _mm256_loadu_si256((const __m256i *)iq5[ibl].scales_l);
+ auto sl1 = _mm256_and_si256(slbits, m4);
+ auto sl2 = _mm256_and_si256(_mm256_srli_epi16(slbits, 4), m4);
+ auto shbits = _mm_loadu_si128((const __m128i*)iq5[ibl].scales_h);
+ auto sh = MM256_SET_M128I(_mm_srli_epi16(shbits, 2), shbits);
+ auto i8scales1 = _mm256_sub_epi8(_mm256_or_si256(sl1, _mm256_and_si256(m30, _mm256_slli_epi16(sh, 4))), m32);
+ auto i8scales2 = _mm256_sub_epi8(_mm256_or_si256(sl2, _mm256_and_si256(m30, sh)), m32);
+ _mm256_storeu_si256((__m256i *)stored_scales+0, i8scales1);
+ _mm256_storeu_si256((__m256i *)stored_scales+1, i8scales2);
+ __m256i isum[nrc_y] = {};
+#ifdef HAVE_FANCY_SIMD
+ if constexpr (nrc_y == 1) {
+ iq234_k_accum_mins(ibl, i8scales1, i8scales2, q8, shuff, isum, -128);
+ } else {
+ iq2345_k_accum_mins(ibl, i8scales1, i8scales2, q8, shuff, extra, isum, -128, 2);
+ }
+#endif
+ for (int ib = 0; ib < QK_K/32; ++ib) {
+#ifdef HAVE_FANCY_SIMD
+ auto scales = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(stored_scales + ib)));
+#else
+ auto scales = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm_set1_epi64x(stored_scales[ib])), s_shuffle);
+#endif
+ auto lbits1 = _mm256_loadu_si256((const __m256i *)iq5[ibl].qs+2*ib+0);
+ auto lbits2 = _mm256_loadu_si256((const __m256i *)iq5[ibl].qs+2*ib+1);
+ auto hbits = _mm_loadu_si128((const __m128i *)iq5[ibl].qh+ib);
+ auto hb = MM256_SET_M128I(_mm_srli_epi16(hbits, 2), hbits);
+ qx[0] = _mm256_and_si256(lbits1, m4);
+ qx[1] = _mm256_and_si256(lbits2, m4);
+ qx[2] = _mm256_and_si256(_mm256_srli_epi16(lbits1, 4), m4);
+ qx[3] = _mm256_and_si256(_mm256_srli_epi16(lbits2, 4), m4);
+
+ qx[0] = prepare_5bit_quants(values, qx[0], hb, _mm256_set1_epi8(0x01));
+ qx[1] = prepare_5bit_quants(values, qx[1], hb, _mm256_set1_epi8(0x10));
+ qx[2] = prepare_5bit_quants(values, qx[2], hb, _mm256_set1_epi8(0x02));
+ qx[3] = prepare_5bit_quants(values, qx[3], hb, _mm256_set1_epi8(0x20));
+#ifdef HAVE_FANCY_SIMD
+ if constexpr (nrc_y == 1) {
+ auto shift = _mm256_and_si256(ms, _mm256_slli_epi16(extra, 1)); extra = _mm256_srli_epi16(extra, 1);
+ shift = _mm256_shuffle_epi8(shift, shift_shuffle);
+ qx[0] = _mm256_add_epi8(qx[0], shift);
+ qx[1] = _mm256_add_epi8(qx[1], shift);
+ qx[2] = _mm256_add_epi8(qx[2], shift);
+ qx[3] = _mm256_add_epi8(qx[3], shift);
+ }
+#else
+ auto shift = _mm256_and_si256(ms, _mm256_slli_epi16(extra, 1)); extra = _mm256_srli_epi16(extra, 1);
+ shift = _mm256_shuffle_epi8(shift, shift_shuffle);
+ qx[0] = _mm256_add_epi8(qx[0], shift);
+ qx[1] = _mm256_add_epi8(qx[1], shift);
+ qx[2] = _mm256_add_epi8(qx[2], shift);
+ qx[3] = _mm256_add_epi8(qx[3], shift);
+ auto s1 = _mm256_sign_epi8(qx[0], qx[0]);
+ auto s2 = _mm256_sign_epi8(qx[1], qx[1]);
+ auto s3 = _mm256_sign_epi8(qx[2], qx[2]);
+ auto s4 = _mm256_sign_epi8(qx[3], qx[3]);
+#endif
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto y = _mm256_loadu_si256((const __m256i*)q8.y[iy][ibl].qs+ib);
+#ifdef HAVE_FANCY_SIMD
+ auto sumi = _mm256_setzero_si256();
+ sumi = _mm256_dpbusd_epi32(sumi, qx[0], _mm256_shuffle_epi32(y, 0x00));
+ sumi = _mm256_dpbusd_epi32(sumi, qx[1], _mm256_shuffle_epi32(y, 0x55));
+ sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(y, 0xaa));
+ sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(y, 0xff));
+ isum[iy] = _mm256_add_epi32(isum[iy], _mm256_mullo_epi32(scales, sumi));
+#else
+ auto sumi1 = _mm256_maddubs_epi16(s1, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), qx[0]));
+ auto sumi2 = _mm256_maddubs_epi16(s2, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), qx[1]));
+ auto sumi3 = _mm256_maddubs_epi16(s3, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), qx[2]));
+ auto sumi4 = _mm256_maddubs_epi16(s4, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), qx[3]));
+ isum[iy] = _mm256_add_epi32(isum[iy], _mm256_add_epi32(_mm256_madd_epi16(scales, sumi1), _mm256_madd_epi16(scales, sumi2)));
+ isum[iy] = _mm256_add_epi32(isum[iy], _mm256_add_epi32(_mm256_madd_epi16(scales, sumi3), _mm256_madd_epi16(scales, sumi4)));
+#endif
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(isum[iy]), acc[iy]);
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1));
+ acc[iy] = _mm256_setzero_ps();
+ info.store(ix+0, iy, sum);
+ }
+ }
+}
+
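+// IQ4_KS_R4: per-row float scales are stored in front of the quantized data (dptr);
+// each 8-bit block scale uses its low bit to select a shifted variant of the value
+// table, and the resulting constant offset is accumulated into acc[] separately,
+// before the quant dot products.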
+template <int nrc_y>
+static void mul_mat_iq4_ks_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ GGML_ASSERT(nrc_x%4 == 0);
+ Q8<nrc_y, block_q8_K> q8(info);
+ auto m4 = _mm256_set1_epi8(0xf);
+#ifndef HAVE_FANCY_SIMD
+ auto s_shuffle = _mm256_set_epi64x(0x0f0e0f0e0d0c0d0c, 0x0b0a0b0a09080908, 0x0706070605040504, 0x0302030201000100);
+ auto values128 = _mm_loadu_si128((const __m128i *)iq4k_values);
+ auto values = MM256_SET_M128I(values128, values128);
+#else
+ auto values = load_iq4nl_values_256();
+#endif
+ int nbl = n / QK_K;
+ using helper_t = union { __m256i vec; uint32_t val[8]; };
+#ifndef HAVE_FANCY_SIMD
+ helper_t h, h_shift;
+#else
+ using helper512_t = union { __m512i vec; uint64_t val[8]; };
+ helper_t h;
+ helper512_t h_shift;
+#endif
+ __m256 acc[nrc_y] = {};
+ __m256i isum[nrc_y] = {};
+ __m256i qx[4];
+ for (int ix = 0; ix < nrc_x; ix += 4) {
+ auto dptr = (const float *)((const char *)vx + (ix+0)*bx);
+ const block_iq4_ks_r4 * iq4 = (const block_iq4_ks_r4 *)(dptr + 4);
+ auto d4 = _mm_loadu_ps(dptr);
+ for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256
+ auto scales = _mm256_loadu_si256((const __m256i *)iq4[ibl].scales);
+ h.vec = _mm256_sub_epi8(_mm256_and_si256(scales, _mm256_set1_epi8(-2)), _mm256_set1_epi8(127));
+#ifndef HAVE_FANCY_SIMD
+ h_shift.vec = _mm256_slli_epi16(_mm256_and_si256(scales, _mm256_set1_epi8(1)), 2);
+ {
+ __m256 v1 = _mm256_mul_ps(_mm256_cvtepi32_ps(MM256_SET_M128I(_mm_cvtepi8_epi32(_mm_set1_epi32(h.val[4])), _mm_cvtepi8_epi32(_mm_set1_epi32(h.val[0])))),
+ _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_cvtepi8_epi32(_mm_set1_epi32(h_shift.val[4])), _mm_cvtepi8_epi32(_mm_set1_epi32(h_shift.val[0])))));
+ __m256 v2 = _mm256_mul_ps(_mm256_cvtepi32_ps(MM256_SET_M128I(_mm_cvtepi8_epi32(_mm_set1_epi32(h.val[5])), _mm_cvtepi8_epi32(_mm_set1_epi32(h.val[1])))),
+ _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_cvtepi8_epi32(_mm_set1_epi32(h_shift.val[5])), _mm_cvtepi8_epi32(_mm_set1_epi32(h_shift.val[1])))));
+ __m256 v3 = _mm256_mul_ps(_mm256_cvtepi32_ps(MM256_SET_M128I(_mm_cvtepi8_epi32(_mm_set1_epi32(h.val[6])), _mm_cvtepi8_epi32(_mm_set1_epi32(h.val[2])))),
+ _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_cvtepi8_epi32(_mm_set1_epi32(h_shift.val[6])), _mm_cvtepi8_epi32(_mm_set1_epi32(h_shift.val[2])))));
+ __m256 v4 = _mm256_mul_ps(_mm256_cvtepi32_ps(MM256_SET_M128I(_mm_cvtepi8_epi32(_mm_set1_epi32(h.val[7])), _mm_cvtepi8_epi32(_mm_set1_epi32(h.val[3])))),
+ _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_cvtepi8_epi32(_mm_set1_epi32(h_shift.val[7])), _mm_cvtepi8_epi32(_mm_set1_epi32(h_shift.val[3])))));
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto m8 = _mm256_loadu_ps((const float *)q8.y[iy][ibl].bsums);
+ acc[iy] = _mm256_fmadd_ps(v1, _mm256_shuffle_ps(m8, m8, 0x00), acc[iy]);
+ acc[iy] = _mm256_fmadd_ps(v2, _mm256_shuffle_ps(m8, m8, 0x55), acc[iy]);
+ acc[iy] = _mm256_fmadd_ps(v3, _mm256_shuffle_ps(m8, m8, 0xaa), acc[iy]);
+ acc[iy] = _mm256_fmadd_ps(v4, _mm256_shuffle_ps(m8, m8, 0xff), acc[iy]);
+ }
+ }
+#else
+ auto shift = _mm256_add_epi8(_mm256_set1_epi8(-64), _mm256_slli_epi16(_mm256_and_si256(scales, _mm256_set1_epi8(1)), 1));
+ h_shift.vec = _mm512_mullo_epi16(_mm512_cvtepi8_epi16(shift), _mm512_cvtepi8_epi16(h.vec));
+#endif
+ for (int ib = 0; ib < QK_K/32; ++ib) {
+#ifdef HAVE_FANCY_SIMD
+ auto iscales = _mm256_cvtepi8_epi32(_mm_set1_epi32(h.val[ib]));
+ auto ishifts = _mm256_cvtepi16_epi32(_mm_set1_epi64x(h_shift.val[ib]));
+ auto scales_m = _mm256_cvtepi32_ps(ishifts);
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ float m8 = ((const float *)q8.y[iy][ibl].bsums)[ib];
+ acc[iy] = _mm256_fmadd_ps(scales_m, _mm256_set1_ps(m8), acc[iy]);
+ }
+#endif
+ auto bits1 = _mm256_loadu_si256((const __m256i *)iq4[ibl].qs+2*ib+0);
+ auto bits2 = _mm256_loadu_si256((const __m256i *)iq4[ibl].qs+2*ib+1);
+ qx[0] = _mm256_shuffle_epi8(values, _mm256_and_si256(bits1, m4));
+ qx[1] = _mm256_shuffle_epi8(values, _mm256_and_si256(bits2, m4));
+ qx[2] = _mm256_shuffle_epi8(values, _mm256_and_si256(_mm256_srli_epi16(bits1, 4), m4));
+ qx[3] = _mm256_shuffle_epi8(values, _mm256_and_si256(_mm256_srli_epi16(bits2, 4), m4));
+#ifndef HAVE_FANCY_SIMD
+ auto iscales = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm_set1_epi32(h.val[ib])), s_shuffle);
+ auto s1 = _mm256_sign_epi8(qx[0], qx[0]);
+ auto s2 = _mm256_sign_epi8(qx[1], qx[1]);
+ auto s3 = _mm256_sign_epi8(qx[2], qx[2]);
+ auto s4 = _mm256_sign_epi8(qx[3], qx[3]);
+#endif
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto y = _mm256_loadu_si256((const __m256i*)q8.y[iy][ibl].qs+ib);
+#ifdef HAVE_FANCY_SIMD
+ auto sumi = _mm256_setzero_si256();
+ sumi = _mm256_dpbusd_epi32(sumi, qx[0], _mm256_shuffle_epi32(y, 0x00));
+ sumi = _mm256_dpbusd_epi32(sumi, qx[1], _mm256_shuffle_epi32(y, 0x55));
+ sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(y, 0xaa));
+ sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(y, 0xff));
+ isum[iy] = _mm256_add_epi32(isum[iy], _mm256_mullo_epi32(iscales, sumi));
+#else
+ auto sumi1 = _mm256_maddubs_epi16(s1, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), qx[0]));
+ auto sumi2 = _mm256_maddubs_epi16(s2, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), qx[1]));
+ auto sumi3 = _mm256_maddubs_epi16(s3, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), qx[2]));
+ auto sumi4 = _mm256_maddubs_epi16(s4, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), qx[3]));
+ isum[iy] = _mm256_add_epi32(isum[iy], _mm256_add_epi32(_mm256_madd_epi16(iscales, sumi1), _mm256_madd_epi16(iscales, sumi2)));
+ isum[iy] = _mm256_add_epi32(isum[iy], _mm256_add_epi32(_mm256_madd_epi16(iscales, sumi3), _mm256_madd_epi16(iscales, sumi4)));
+#endif
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ acc[iy] = _mm256_fmadd_ps(_mm256_set1_ps(q8.scale(iy, ibl)), _mm256_cvtepi32_ps(isum[iy]), acc[iy]);
+ isum[iy] = _mm256_setzero_si256();
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1));
+ acc[iy] = _mm256_setzero_ps();
+ info.store(ix+0, iy, _mm_mul_ps(d4, sum));
+ }
+ }
+}
+
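+// Same as above, but for 4-row-interleaved iq5_ks_r4 (5-bit non-linear values with separately packed high bits).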
+template <int nrc_y>
+static void mul_mat_iq5_ks_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ GGML_ASSERT(nrc_x%4 == 0);
+ Q8<nrc_y, block_q8_K> q8(info);
+ auto m4 = _mm256_set1_epi8(0xf);
+ __m256i values[2];
+ {
+ auto val1 = _mm_loadu_si128((const __m128i *)iq5nl_values+0);
+ auto val2 = _mm_loadu_si128((const __m128i *)iq5nl_values+1);
+ values[0] = MM256_SET_M128I(val1, val1);
+ values[1] = MM256_SET_M128I(val2, val2);
+#ifdef HAVE_FANCY_SIMD
+ values[0] = _mm256_sub_epi8(values[0], _mm256_set1_epi8(-128));
+ values[1] = _mm256_sub_epi8(values[1], _mm256_set1_epi8(-128));
+#endif
+ }
+ int nbl = n / QK_K;
+ using helper_t = union { __m256i vec; uint32_t val[8]; };
+#ifndef HAVE_FANCY_SIMD
+ helper_t h, h_shift;
+ auto s_shuffle = _mm256_set_epi64x(0x0f0e0f0e0d0c0d0c, 0x0b0a0b0a09080908, 0x0706070605040504, 0x0302030201000100);
+#else
+ using helper512_t = union { __m512i vec; uint64_t val[8]; };
+ helper_t h;
+ helper512_t h_shift;
+#endif
+ __m256 acc[nrc_y] = {};
+ __m256i isum[nrc_y] = {};
+ __m256i qx[4];
+ for (int ix = 0; ix < nrc_x; ix += 4) {
+ auto dptr = (const float *)((const char *)vx + (ix+0)*bx);
+ const block_iq5_ks_r4 * iq5 = (const block_iq5_ks_r4 *)(dptr + 4);
+ auto d4 = _mm_loadu_ps(dptr);
+ for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256
+ auto scales = _mm256_loadu_si256((const __m256i *)iq5[ibl].scales);
+ h.vec = _mm256_sub_epi8(_mm256_and_si256(scales, _mm256_set1_epi8(-2)), _mm256_set1_epi8(127));
+#ifndef HAVE_FANCY_SIMD
+ h_shift.vec = _mm256_slli_epi16(_mm256_and_si256(scales, _mm256_set1_epi8(1)), 1);
+ {
+ __m256 v1 = _mm256_mul_ps(_mm256_cvtepi32_ps(MM256_SET_M128I(_mm_cvtepi8_epi32(_mm_set1_epi32(h.val[4])), _mm_cvtepi8_epi32(_mm_set1_epi32(h.val[0])))),
+ _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_cvtepi8_epi32(_mm_set1_epi32(h_shift.val[4])), _mm_cvtepi8_epi32(_mm_set1_epi32(h_shift.val[0])))));
+ __m256 v2 = _mm256_mul_ps(_mm256_cvtepi32_ps(MM256_SET_M128I(_mm_cvtepi8_epi32(_mm_set1_epi32(h.val[5])), _mm_cvtepi8_epi32(_mm_set1_epi32(h.val[1])))),
+ _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_cvtepi8_epi32(_mm_set1_epi32(h_shift.val[5])), _mm_cvtepi8_epi32(_mm_set1_epi32(h_shift.val[1])))));
+ __m256 v3 = _mm256_mul_ps(_mm256_cvtepi32_ps(MM256_SET_M128I(_mm_cvtepi8_epi32(_mm_set1_epi32(h.val[6])), _mm_cvtepi8_epi32(_mm_set1_epi32(h.val[2])))),
+ _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_cvtepi8_epi32(_mm_set1_epi32(h_shift.val[6])), _mm_cvtepi8_epi32(_mm_set1_epi32(h_shift.val[2])))));
+ __m256 v4 = _mm256_mul_ps(_mm256_cvtepi32_ps(MM256_SET_M128I(_mm_cvtepi8_epi32(_mm_set1_epi32(h.val[7])), _mm_cvtepi8_epi32(_mm_set1_epi32(h.val[3])))),
+ _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_cvtepi8_epi32(_mm_set1_epi32(h_shift.val[7])), _mm_cvtepi8_epi32(_mm_set1_epi32(h_shift.val[3])))));
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto m8 = _mm256_loadu_ps((const float *)q8.y[iy][ibl].bsums);
+ acc[iy] = _mm256_fmadd_ps(v1, _mm256_shuffle_ps(m8, m8, 0x00), acc[iy]);
+ acc[iy] = _mm256_fmadd_ps(v2, _mm256_shuffle_ps(m8, m8, 0x55), acc[iy]);
+ acc[iy] = _mm256_fmadd_ps(v3, _mm256_shuffle_ps(m8, m8, 0xaa), acc[iy]);
+ acc[iy] = _mm256_fmadd_ps(v4, _mm256_shuffle_ps(m8, m8, 0xff), acc[iy]);
+ }
+ }
+#else
+ auto shift = _mm256_add_epi8(_mm256_set1_epi8(-64), _mm256_and_si256(scales, _mm256_set1_epi8(1)));
+ h_shift.vec = _mm512_mullo_epi16(_mm512_cvtepi8_epi16(shift), _mm512_cvtepi8_epi16(h.vec));
+#endif
+ for (int ib = 0; ib < QK_K/32; ++ib) {
+#ifdef HAVE_FANCY_SIMD
+ auto iscales = _mm256_cvtepi8_epi32(_mm_set1_epi32(h.val[ib]));
+ auto ishifts = _mm256_cvtepi16_epi32(_mm_set1_epi64x(h_shift.val[ib]));
+ auto scales_m = _mm256_cvtepi32_ps(ishifts);
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ float m8 = ((const float *)q8.y[iy][ibl].bsums)[ib];
+ acc[iy] = _mm256_fmadd_ps(scales_m, _mm256_set1_ps(m8), acc[iy]);
+ }
+#endif
+ auto lbits1 = _mm256_loadu_si256((const __m256i *)iq5[ibl].qs+2*ib+0);
+ auto lbits2 = _mm256_loadu_si256((const __m256i *)iq5[ibl].qs+2*ib+1);
+ auto hbits = _mm_loadu_si128((const __m128i *)iq5[ibl].qh+ib);
+ auto hb = MM256_SET_M128I(_mm_srli_epi16(hbits, 2), hbits);
+ qx[0] = _mm256_and_si256(lbits1, m4);
+ qx[1] = _mm256_and_si256(lbits2, m4);
+ qx[2] = _mm256_and_si256(_mm256_srli_epi16(lbits1, 4), m4);
+ qx[3] = _mm256_and_si256(_mm256_srli_epi16(lbits2, 4), m4);
+
+ qx[0] = prepare_5bit_quants(values, qx[0], hb, _mm256_set1_epi8(0x01));
+ qx[1] = prepare_5bit_quants(values, qx[1], hb, _mm256_set1_epi8(0x10));
+ qx[2] = prepare_5bit_quants(values, qx[2], hb, _mm256_set1_epi8(0x02));
+ qx[3] = prepare_5bit_quants(values, qx[3], hb, _mm256_set1_epi8(0x20));
+
+#ifndef HAVE_FANCY_SIMD
+ auto iscales = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm_set1_epi32(h.val[ib])), s_shuffle);
+ auto s1 = _mm256_sign_epi8(qx[0], qx[0]);
+ auto s2 = _mm256_sign_epi8(qx[1], qx[1]);
+ auto s3 = _mm256_sign_epi8(qx[2], qx[2]);
+ auto s4 = _mm256_sign_epi8(qx[3], qx[3]);
+#endif
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto y = _mm256_loadu_si256((const __m256i*)q8.y[iy][ibl].qs+ib);
+#ifdef HAVE_FANCY_SIMD
+ auto sumi = _mm256_setzero_si256();
+ sumi = _mm256_dpbusd_epi32(sumi, qx[0], _mm256_shuffle_epi32(y, 0x00));
+ sumi = _mm256_dpbusd_epi32(sumi, qx[1], _mm256_shuffle_epi32(y, 0x55));
+ sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(y, 0xaa));
+ sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(y, 0xff));
+ isum[iy] = _mm256_add_epi32(isum[iy], _mm256_mullo_epi32(iscales, sumi));
+#else
+ auto sumi1 = _mm256_maddubs_epi16(s1, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), qx[0]));
+ auto sumi2 = _mm256_maddubs_epi16(s2, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), qx[1]));
+ auto sumi3 = _mm256_maddubs_epi16(s3, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), qx[2]));
+ auto sumi4 = _mm256_maddubs_epi16(s4, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), qx[3]));
+ isum[iy] = _mm256_add_epi32(isum[iy], _mm256_add_epi32(_mm256_madd_epi16(iscales, sumi1), _mm256_madd_epi16(iscales, sumi2)));
+ isum[iy] = _mm256_add_epi32(isum[iy], _mm256_add_epi32(_mm256_madd_epi16(iscales, sumi3), _mm256_madd_epi16(iscales, sumi4)));
+#endif
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ acc[iy] = _mm256_fmadd_ps(_mm256_set1_ps(q8.scale(iy, ibl)), _mm256_cvtepi32_ps(isum[iy]), acc[iy]);
+ isum[iy] = _mm256_setzero_si256();
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1));
+ acc[iy] = _mm256_setzero_ps();
+ info.store(ix+0, iy, _mm_mul_ps(d4, sum));
+ }
+ }
+}
+
+
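+// Populates the kernel table (one entry per supported nrc_y) with the GEMM templates for the given dequantizer.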
+template <typename Dequantizer> void set_functions(std::array<mul_mat_t, IQK_MAX_NY>& funcs) {
+#ifdef HAVE_FANCY_SIMD
+ if constexpr (std::is_same_v<Dequantizer, DequantizerIQ2KS> ||
+ std::is_same_v<Dequantizer, DequantizerIQ4KS> ||
+ std::is_same_v<Dequantizer, DequantizerIQ5KS>) {
+ IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_iqX_k_q8_K_AVX512_new, Dequantizer, funcs)
+ } else if constexpr (std::is_same_v<Dequantizer, DequantizerIQ2K>) {
+ IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_K_q8_K_AVX512, Dequantizer, funcs);
+ funcs[0] = mul_mat_qX_K_q8_K_AVX512_1<Dequantizer>;
+ } else {
+ IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_iqX_k_q8_K_AVX512, Dequantizer, funcs);
+ }
+#else
+ if constexpr (std::is_same_v<Dequantizer, DequantizerIQ2K>||
+ std::is_same_v<Dequantizer, DequantizerIQ3K>||
+ std::is_same_v<Dequantizer, DequantizerIQ4K>||
+ std::is_same_v<Dequantizer, DequantizerIQ5K>||
+ std::is_same_v<Dequantizer, DequantizerIQ6K>) {
+ IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qY_K_q8_K_T, Dequantizer, funcs);
+ } else {
+ IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_K_q8_K_T, Dequantizer, funcs);
+ }
+
+#endif
+}
+
+} // namespace
+
+bool iqk_set_kernels_iqk_quants(int ne00, int typeA, int typeB, std::array<mul_mat_t, IQK_MAX_NY>& kernels, mul_mat_t& func16) {
+
+ auto etypeA = ggml_type(typeA);
+ auto expected_type_B = etypeA == GGML_TYPE_IQ4_KS_R4 || etypeA == GGML_TYPE_IQ5_KS_R4 ? GGML_TYPE_Q8_K32 : GGML_TYPE_Q8_K;
+ if (ne00%QK_K != 0 || ggml_type(typeB) != expected_type_B) {
+ return false;
+ }
+
+ func16 = nullptr;
+
+ switch (typeA) {
+ case GGML_TYPE_IQ2_KS:
+ set_functions<DequantizerIQ2KS>(kernels);
+ break;
+ case GGML_TYPE_IQ2_K:
+ set_functions<DequantizerIQ2K>(kernels);
+ break;
+ case GGML_TYPE_IQ3_K:
+ set_functions<DequantizerIQ3K>(kernels);
+ break;
+ case GGML_TYPE_IQ4_KSS:
+ set_functions<DequantizerIQ4KSS>(kernels);
+ break;
+ case GGML_TYPE_IQ4_KS:
+ set_functions<DequantizerIQ4KS>(kernels);
+ break;
+ case GGML_TYPE_IQ4_K:
+ set_functions<DequantizerIQ4K>(kernels);
+ break;
+ case GGML_TYPE_IQ5_KS:
+ set_functions<DequantizerIQ5KS>(kernels);
+ break;
+ case GGML_TYPE_IQ5_K:
+ set_functions<DequantizerIQ5K>(kernels);
+ break;
+ case GGML_TYPE_IQ6_K:
+ set_functions<DequantizerIQ6K>(kernels);
+ break;
+ case GGML_TYPE_IQ2_K_R4:
+ IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq2_k_r4_q8_k, kernels);
+ break;
+ case GGML_TYPE_IQ3_K_R4:
+ IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq3_k_r4_q8_k, kernels);
+#ifdef HAVE_FANCY_SIMD
+ func16 = mul_mat_iq3_k_r4_q8_k<16>;
+#endif
+ break;
+ case GGML_TYPE_IQ4_K_R4:
+ IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq4_k_r4_q8_k, kernels);
+ func16 = mul_mat_iq4_k_r4_q8_k<16>;
+ break;
+ case GGML_TYPE_IQ4_KS_R4:
+ IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq4_ks_r4_q8_k, kernels);
+#ifndef HAVE_FANCY_SIMD
+ // For some reason Zen4 does not like this particular function
+ func16 = mul_mat_iq4_ks_r4_q8_k<16>;
+#endif
+ break;
+ case GGML_TYPE_IQ5_KS_R4:
+ IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq5_ks_r4_q8_k, kernels);
+#ifndef HAVE_FANCY_SIMD
+ // For some reason Zen4 does not like this particular function
+ func16 = mul_mat_iq5_ks_r4_q8_k<16>;
+#endif
+ break;
+ case GGML_TYPE_IQ5_K_R4:
+ IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq5_k_r4_q8_k, kernels);
+ func16 = mul_mat_iq5_k_r4_q8_k<16>;
+ break;
+ default:
+ return false;
+ }
+
+ return true;
+
+}
+
+#else
+// ----------------------------------------- __aarch64__ ---------------------------------------------
+
+namespace {
+
+inline int32x4x4_t make_wider(const int16x8x2_t& scales16) {
+ int32x4x4_t scales = {
+ vmovl_s16(vget_low_s16 (scales16.val[0])),
+ vmovl_s16(vget_high_s16(scales16.val[0])),
+ vmovl_s16(vget_low_s16 (scales16.val[1])),
+ vmovl_s16(vget_high_s16(scales16.val[1])),
+ };
+ return scales;
+}
+
+inline int32x4x4_t make_wider_8(const int8x16_t& scales8) {
+ int16x8x2_t scales16{vmovl_s8(vget_low_s8(scales8)), vmovl_s8(vget_high_s8(scales8))};
+ return make_wider(scales16);
+}
+
+template <typename Q8>
+inline void accum_mins_16(const int16x8x2_t& mins, const Q8& q8, float32x4_t * acc, int i, float c) {
+ for (int iy = 0; iy < Q8::nrc_y; ++iy) {
+ auto q8s = q8.load_bsums(iy, i);
+ int32x4_t b1 = vmull_s16(vget_low_s16 (mins.val[0]), vget_low_s16 (q8s.val[0]));
+ int32x4_t b2 = vmull_s16(vget_high_s16(mins.val[0]), vget_high_s16(q8s.val[0]));
+ int32x4_t b3 = vmull_s16(vget_low_s16 (mins.val[1]), vget_low_s16 (q8s.val[1]));
+ int32x4_t b4 = vmull_s16(vget_high_s16(mins.val[1]), vget_high_s16(q8s.val[1]));
+ float32x4_t prod = vcvtq_f32_s32(vaddq_s32(vaddq_s32(b1, b2), vaddq_s32(b3, b4)));
+ acc[iy] = vmlaq_f32(acc[iy], prod, vdupq_n_f32(c*q8.scale(iy, i)));
+ }
+}
+
+struct Scale16Extra {
+ template <typename Q8>
+ static inline int32x4x4_t new_block(int i, float d, uint16_t extra, uint8_t val,
+ const int8x16_t& scales8, const Q8& q8, float32x4_t * acc) {
+ uint8x16_t e8 = vreinterpretq_u8_u16(vdupq_n_u16(extra));
+ e8 = vceqq_u8(vandq_u8(e8, emask), emask);
+ e8 = vqtbl1q_u8(vandq_u8(e8, vdupq_n_u8(val)), eshuff);
+ int16x8x2_t extra16 = {vmull_s8(vget_low_s8 (e8), vget_low_s8 (scales8)),
+ vmull_s8(vget_high_s8(e8), vget_high_s8(scales8))};
+ accum_mins_16(extra16, q8, acc, i, d);
+ return make_wider_8(scales8);
+ }
+
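+    // emask tests one bit of 'extra' per byte (in the order 0,8,1,9,...);
+    // eshuff then restores block order 0...15.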
+ constexpr static uint32x4_t emask = {0x02020101, 0x08080404, 0x20201010, 0x80804040};
+ constexpr static uint32x4_t eshuff = {0x06040200, 0x0e0c0a08, 0x07050301, 0x0f0d0b09};
+};
+
+// Note: on ARM_NEON we cannot use the values shifted into the uint8_t range because
+// ARM_NEON only has vdotq_s32 and vdotq_u32, where both operands need to be
+// signed or both unsigned. As the Q8_K quants are signed, the iq4_k quants need
+// to be signed as well. We can use unsigned values for k-quants only because
+// they are all within the valid int8_t range.
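+// (vdotq_s32/vdotq_u32 require matching signedness of both operands, unlike
+// AVX512-VNNI's _mm256_dpbusd_epi32, which takes an unsigned and a signed operand.)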
+struct DequantizerIQ4K final : public BaseDequantizer<block_iq4_k> {
+ DequantizerIQ4K(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc), values(vld1q_s8(iq4k_values)) {}
+
+ constexpr static int num_blocks() { return 16; }
+ constexpr static bool should_scale_quants() { return false; }
+
+ template <typename Q8>
+ inline int32x4x4_t new_block(int i, const Q8& q8, float32x4_t * acc) {
+ d = GGML_FP16_TO_FP32(x[i].d);
+ return Scale16Extra::new_block(i, d, x[i].extra, 4, make_scales(x[i].scales_l, x[i].scales_h), q8, acc);
+ }
+ inline void prepare(int i, int j) {
+ bits.prepare16(x[i].qs+64*j);
+ for (int k = 0; k < 4; ++k) {
+ bits.b1.val[k] = vqtbl1q_s8(values, bits.b1.val[k]);
+ bits.b2.val[k] = vqtbl1q_s8(values, bits.b2.val[k]);
+ }
+ }
+ inline int8x16_t make_scales(const uint8_t * scales_l, const uint8_t * scales_h) const {
+ uint8x8_t aux = vld1_u8(scales_l);
+ uint8x16_t scl8 = vandq_u8(vcombine_u8(aux, vshr_n_u8(aux, 4)), vdupq_n_u8(0xf));
+ const uint32_t * aux32 = (const uint32_t *)scales_h;
+ uint32x4_t sch_32 = {aux32[0] << 4, aux32[0] << 2, aux32[0], aux32[0] >> 2};
+ uint8x16_t sch8 = vandq_u8(vreinterpretq_u8_u32(sch_32), vdupq_n_u8(0x30));
+ int8x16_t scales8 = vorrq_u8(scl8, vqtbl1q_u8(sch8, hshuff));
+ return vaddq_s8(vqtbl1q_s8(scales8, hshuff), vdupq_n_s8(-32));
+ }
+
+ Q4bits bits;
+ const int8x16_t values;
+ const uint8x16_t hshuff = vreinterpretq_u8_u32(uint32x4_t{0x09010800, 0x0b030a02, 0x0d050c04, 0x0f070e06});
+
+};
+
+struct DequantizerIQ5K final : public BaseDequantizer<block_iq5_k> {
+ DequantizerIQ5K(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc), values(vld1q_s8_x2(iq5nl_values)) {}
+
+ constexpr static int num_blocks() { return 16; }
+ constexpr static bool should_scale_quants() { return false; }
+
+ template <typename Q8>
+ inline int32x4x4_t new_block(int i, const Q8& q8, float32x4_t * acc) {
+ d = GGML_FP16_TO_FP32(x[i].d);
+        hbits = vld1q_u8_x2(x[i].qh); // hbits.val[0] holds 0...15, 32...47, 64...79, 96...111, 128...143, 160...175, 192...207, 224...239
+ // hbits.val[1] holds 16...31, 48...63, 80...95, 112..127, 144...159, 176...191, 208...223, 240...255
+ return Scale16Extra::new_block(i, d, x[i].extra, 2, make_scales(x[i].scales_l, x[i].scales_h), q8, acc);
+ }
+ inline void prepare(int i, int j) {
+ bits.prepare(x[i].qs+64*j);
+ if (j == 1) {
+ for (int k = 0; k < 2; ++k) hbits.val[k] = vshrq_n_u8(hbits.val[k], 4);
+ }
+ bits.b1.val[0] = vorrq_u8(bits.b1.val[0], vandq_u8(vshlq_n_u8(hbits.val[0], 4), hm));
+ bits.b1.val[1] = vorrq_u8(bits.b1.val[1], vandq_u8(vshlq_n_u8(hbits.val[1], 4), hm));
+ bits.b1.val[2] = vorrq_u8(bits.b1.val[2], vandq_u8(vshlq_n_u8(hbits.val[0], 3), hm));
+ bits.b1.val[3] = vorrq_u8(bits.b1.val[3], vandq_u8(vshlq_n_u8(hbits.val[1], 3), hm));
+ bits.b2.val[0] = vorrq_u8(bits.b2.val[0], vandq_u8(vshlq_n_u8(hbits.val[0], 2), hm));
+ bits.b2.val[1] = vorrq_u8(bits.b2.val[1], vandq_u8(vshlq_n_u8(hbits.val[1], 2), hm));
+ bits.b2.val[2] = vorrq_u8(bits.b2.val[2], vandq_u8(vshlq_n_u8(hbits.val[0], 1), hm));
+ bits.b2.val[3] = vorrq_u8(bits.b2.val[3], vandq_u8(vshlq_n_u8(hbits.val[1], 1), hm));
+ for (int k = 0; k < 4; ++k) {
+ bits.b1.val[k] = vqtbl2q_s8(values, bits.b1.val[k]);
+ bits.b2.val[k] = vqtbl2q_s8(values, bits.b2.val[k]);
+ }
+ }
+ inline int8x16_t make_scales(const uint8_t * scales_l, const uint8_t * scales_h) const {
+ uint8x8_t aux = vld1_u8(scales_l);
+ uint8x16_t scl8 = vandq_u8(vcombine_u8(aux, vshr_n_u8(aux, 4)), vdupq_n_u8(0xf));
+ const uint32_t * aux32 = (const uint32_t *)scales_h;
+ uint32x4_t sch_32 = {aux32[0] << 4, aux32[0] << 2, aux32[0], aux32[0] >> 2};
+ uint8x16_t sch8 = vandq_u8(vreinterpretq_u8_u32(sch_32), vdupq_n_u8(0x30));
+ int8x16_t scales8 = vorrq_u8(scl8, vqtbl1q_u8(sch8, hshuff));
+ return vaddq_s8(vqtbl1q_s8(scales8, hshuff), vdupq_n_s8(-32));
+ }
+
+ Q4bits bits;
+ const int8x16x2_t values;
+ const uint8x16_t hshuff = vreinterpretq_u8_u32(uint32x4_t{0x09010800, 0x0b030a02, 0x0d050c04, 0x0f070e06});
+ const uint8x16_t hm = vdupq_n_u8(0x10);
+ uint8x16x2_t hbits;
+
+};
+
+struct DequantizerIQ6K final : public BaseDequantizer<block_iq6_k> {
+ DequantizerIQ6K(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc), values(vld1q_s8_x4(iq6nl_values)) {}
+
+ constexpr static int num_blocks() { return 16; }
+ constexpr static bool should_scale_quants() { return false; }
+
+ template <typename Q8>
+ inline int32x4x4_t new_block(int i, const Q8& q8, float32x4_t * acc) {
+ d = GGML_FP16_TO_FP32(x[i].d);
+ return Scale16Extra::new_block(i, d, x[i].extra, 1, vld1q_s8(x[i].scales), q8, acc);
+ }
+ inline void prepare(int i, int j) {
+ bits.prepare(x[i].qs+64*j);
+ auto hbits = vld1q_u8_x2(x[i].qh + 32*j);
+ bits.b1.val[0] = vorrq_u8(bits.b1.val[0], vandq_u8(vshlq_n_u8(hbits.val[0], 4), hm));
+ bits.b1.val[1] = vorrq_u8(bits.b1.val[1], vandq_u8(vshlq_n_u8(hbits.val[1], 4), hm));
+ bits.b1.val[2] = vorrq_u8(bits.b1.val[2], vandq_u8(vshlq_n_u8(hbits.val[0], 2), hm));
+ bits.b1.val[3] = vorrq_u8(bits.b1.val[3], vandq_u8(vshlq_n_u8(hbits.val[1], 2), hm));
+ bits.b2.val[0] = vorrq_u8(bits.b2.val[0], vandq_u8(hbits.val[0], hm));
+ bits.b2.val[1] = vorrq_u8(bits.b2.val[1], vandq_u8(hbits.val[1], hm));
+ bits.b2.val[2] = vorrq_u8(bits.b2.val[2], vandq_u8(vshrq_n_u8(hbits.val[0], 2), hm));
+ bits.b2.val[3] = vorrq_u8(bits.b2.val[3], vandq_u8(vshrq_n_u8(hbits.val[1], 2), hm));
+ for (int k = 0; k < 4; ++k) {
+ bits.b1.val[k] = vqtbl4q_s8(values, bits.b1.val[k]);
+ bits.b2.val[k] = vqtbl4q_s8(values, bits.b2.val[k]);
+ }
+ }
+
+ Q4bits bits;
+ const int8x16x4_t values;
+ const uint8x16_t hm = vdupq_n_u8(0x30);
+
+};
+
+struct DequantizerIQ2K final : public BaseDequantizer<block_iq2_k> {
+ DequantizerIQ2K(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {}
+
+ constexpr static int num_blocks() { return 16; }
+ constexpr static bool should_scale_quants() { return false; }
+
+ template <typename Q8>
+ inline int32x4x4_t new_block(int i, const Q8& q8, float32x4_t * acc) {
+ d = GGML_FP16_TO_FP32(x[i].d);
+ return Scale16Extra::new_block(i, d, x[i].extra, 5, make_scales(x[i].scales), q8, acc);
+ }
+ inline void prepare(int i, int j) {
+ bits.prepare(x[i].qs+32*j);
+ for (int k = 0; k < 4; ++k) {
+ bits.b1.val[k] = vqtbl1q_s8(values, bits.b1.val[k]);
+ bits.b2.val[k] = vqtbl1q_s8(values, bits.b2.val[k]);
+ }
+ }
+ inline int8x16_t make_scales(const uint8_t * scales_l) const {
+ uint8x8_t aux = vld1_u8(scales_l);
+ uint8x16_t scl8 = vandq_u8(vcombine_u8(aux, vshr_n_u8(aux, 4)), vdupq_n_u8(0xf));
+ int8x16_t scales = vaddq_s8(vreinterpretq_s8_u8(scl8), vdupq_n_s8(-8));
+ return vqtbl1q_s8(scales, hshuff);
+ }
+
+ Q2bits bits;
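+    // Non-linear IQ2_K lookup table {-31, -13, 1, 17}, packed into the low 4 bytes of each 64-bit lane.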
+ const int8x16_t values = vreinterpretq_s8_u64(vdupq_n_u64(0x000000001101f3e1));
+ const uint8x16_t hshuff = vreinterpretq_u8_u32(uint32x4_t{0x09010800, 0x0b030a02, 0x0d050c04, 0x0f070e06});
+
+};
+
+struct DequantizerIQ3K final : public BaseDequantizer<block_iq3_k> {
+ DequantizerIQ3K(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {}
+
+ constexpr static int num_blocks() { return 16; }
+ constexpr static bool should_scale_quants() { return false; }
+
+ template <typename Q8>
+ inline int32x4x4_t new_block(int i, const Q8& q8, float32x4_t * acc) {
+ d = GGML_FP16_TO_FP32(x[i].d);
+ return Scale16Extra::new_block(i, d, x[i].extra, 4, make_scales(x[i].scales_h, x[i].scales_l), q8, acc);
+ }
+ inline void prepare(int i, int j) {
+ bits.prepare(x[i].qs+32*j);
+ if (j == 0) {
+ hbits = vld1q_u8_x2(x[i].qh);
+ }
+ else {
+ hbits.val[0] = vshrq_n_u8(hbits.val[0], 4);
+ hbits.val[1] = vshrq_n_u8(hbits.val[1], 4);
+ }
+ bits.b1.val[0] = vorrq_u8(bits.b1.val[0], vandq_u8(vshlq_n_u8(hbits.val[0], 2), hmask));
+ bits.b1.val[1] = vorrq_u8(bits.b1.val[1], vandq_u8(vshlq_n_u8(hbits.val[1], 2), hmask));
+ bits.b1.val[2] = vorrq_u8(bits.b1.val[2], vandq_u8(vshlq_n_u8(hbits.val[0], 1), hmask));
+ bits.b1.val[3] = vorrq_u8(bits.b1.val[3], vandq_u8(vshlq_n_u8(hbits.val[1], 1), hmask));
+ bits.b2.val[0] = vorrq_u8(bits.b2.val[0], vandq_u8(hbits.val[0], hmask));
+ bits.b2.val[1] = vorrq_u8(bits.b2.val[1], vandq_u8(hbits.val[1], hmask));
+ bits.b2.val[2] = vorrq_u8(bits.b2.val[2], vandq_u8(vshrq_n_u8(hbits.val[0], 1), hmask));
+ bits.b2.val[3] = vorrq_u8(bits.b2.val[3], vandq_u8(vshrq_n_u8(hbits.val[1], 1), hmask));
+ for (int k = 0; k < 4; ++k) {
+ bits.b1.val[k] = vqtbl1q_s8(values, bits.b1.val[k]);
+ bits.b2.val[k] = vqtbl1q_s8(values, bits.b2.val[k]);
+ }
+ }
+ inline int8x16_t make_scales(uint16_t sign_bits, const uint8_t * scales_l) const {
+ uint8x8_t aux = vld1_u8(scales_l);
+ uint8x16_t scl8 = vandq_u8(vcombine_u8(aux, vshr_n_u8(aux, 4)), vdupq_n_u8(0xf));
+ int8x16_t scales = vaddq_s8(vreinterpretq_s8_u8(vshlq_n_u8(scl8, 1)), vdupq_n_s8(1));
+ uint8x16_t signs = vceqq_u8(vandq_u8(vreinterpretq_u8_u16(vdupq_n_u16(sign_bits)), sign_mask), sign_mask);
+ signs = vorrq_u8(signs, vdupq_n_u8(1));
+ // scales are 0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7, 9, 11, 13, 15
+ // signs are 0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15
+ scales = vmulq_s8(scales, vreinterpretq_s8_u8(vqtbl1q_u8(signs, sign_shuffle)));
+ return vqtbl1q_s8(scales, hshuff);
+ }
+ inline static uint8x16_t load_sign_shuffle() {
+ static uint8_t k_shuff[16] = {0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15};
+ return vld1q_u8(k_shuff);
+ }
+
+ Q2bits bits;
+ uint8x16x2_t hbits;
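+    // Non-linear IQ3_K lookup table {-63, -40, -23, -10, 1, 13, 28, 47}, packed as bytes.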
+ const int8x16_t values = vreinterpretq_s8_u64(vdupq_n_u64(0x2f1c0d01f6e9d8c1));
+ const uint8x16_t hshuff = vreinterpretq_u8_u32(uint32x4_t{0x09010800, 0x0b030a02, 0x0d050c04, 0x0f070e06});
+ const uint8x16_t hmask = vdupq_n_u8(4);
+ const uint8x16_t sign_mask = vreinterpretq_u8_u64(uint64x2_t{0x0808040402020101, 0x8080404020201010});
+ const uint8x16_t sign_shuffle = load_sign_shuffle();
+
+};
+
+struct DequantizerIQ4KS final : public BaseDequantizer<block_iq4_ks, true> {
+
+ DequantizerIQ4KS(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc), values(vld1q_s8_x2(iq4k_values)) {}
+
+ constexpr static int num_blocks() { return 8; }
+ constexpr static bool should_scale_quants() { return false; }
+
+ template <typename Q8>
+ inline int32x4x2_t new_block(int i, const Q8& q8, float32x4_t * acc) {
+ (void)q8;
+ (void)acc;
+ auto scales16 = vaddq_s16(vreinterpretq_s16_u16(vandq_u16(vmovl_u8(vld1_u8(x[i].scales)), mask)), m127);
+ int32x4x2_t scales = {vmovl_s16(vget_low_s16(scales16)), vmovl_s16(vget_high_s16(scales16))};
+ return scales;
+ }
+ inline void prepare(int i, int j) {
+ bits.prepare16(x[i].qs+64*j);
+ const uint32_t * scales32 = (const uint32_t *)x[i].scales;
+ uint32_t aux32 = scales32[j] & 0x01010101;
+ const uint8_t * aux8 = (const uint8_t *)&aux32;
+ for (int k = 0; k < 4; ++k) {
+ bits.b1.val[k] = vreinterpretq_u8_s8(vqtbl1q_s8(values.val[aux8[k/2+0]], bits.b1.val[k]));
+ bits.b2.val[k] = vreinterpretq_u8_s8(vqtbl1q_s8(values.val[aux8[k/2+2]], bits.b2.val[k]));
+ }
+ }
+
+ Q4bits bits;
+ const int8x16x2_t values;
+ const uint16x8_t mask = vdupq_n_u16(254);
+ const int16x8_t m127 = vdupq_n_s16(-127);
+};
+
+struct DequantizerIQ5KS final : public BaseDequantizer<block_iq5_ks, true> {
+ DequantizerIQ5KS(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc),
+ values(vld1q_s8_x4(iq5nl_values)) {}
+
+ constexpr static int num_blocks() { return 8; }
+ constexpr static bool should_scale_quants() { return false; }
+
+ template <typename Q8>
+ inline int32x4x2_t new_block(int i, const Q8& q8, float32x4_t * acc) {
+ (void)q8;
+ (void)acc;
+ auto sas8 = vld1_u8(x[i].scales);
+ auto scales16 = vaddq_s16(vreinterpretq_s16_u16(vandq_u16(vmovl_u8(sas8), mask)), m127);
+ hbits = vld1q_u8_x2(x[i].qh);
+ sas = vcombine_u8(sas8, sas8);
+ sas = vshlq_n_u8(vandq_u8(sas, vdupq_n_u8(1)), 5);
+ int32x4x2_t scales = {vmovl_s16(vget_low_s16(scales16)), vmovl_s16(vget_high_s16(scales16))};
+ return scales;
+ }
+
+ inline void prepare(int i, int j) {
+ bits.prepare(x[i].qs+64*j);
+ if (j == 1) {
+ for (int k = 0; k < 2; ++k) hbits.val[k] = vshrq_n_u8(hbits.val[k], 4);
+ }
+ auto shift = vdupq_n_u8((x[i].scales[4*j+0] & 1) << 5);
+ bits.b1.val[0] = vaddq_u8(shift, vorrq_u8(bits.b1.val[0], vandq_u8(vshlq_n_u8(hbits.val[0], 4), hm)));
+ bits.b1.val[1] = vaddq_u8(shift, vorrq_u8(bits.b1.val[1], vandq_u8(vshlq_n_u8(hbits.val[1], 4), hm)));
+ shift = vdupq_n_u8((x[i].scales[4*j+1] & 1) << 5);
+ bits.b1.val[2] = vaddq_u8(shift, vorrq_u8(bits.b1.val[2], vandq_u8(vshlq_n_u8(hbits.val[0], 3), hm)));
+ bits.b1.val[3] = vaddq_u8(shift, vorrq_u8(bits.b1.val[3], vandq_u8(vshlq_n_u8(hbits.val[1], 3), hm)));
+ for (int k = 0; k < 4; ++k) bits.b1.val[k] = vqtbl4q_s8(values, bits.b1.val[k]);
+ shift = vdupq_n_u8((x[i].scales[4*j+2] & 1) << 5);
+ bits.b2.val[0] = vaddq_u8(shift, vorrq_u8(bits.b2.val[0], vandq_u8(vshlq_n_u8(hbits.val[0], 2), hm)));
+ bits.b2.val[1] = vaddq_u8(shift, vorrq_u8(bits.b2.val[1], vandq_u8(vshlq_n_u8(hbits.val[1], 2), hm)));
+ shift = vdupq_n_u8((x[i].scales[4*j+3] & 1) << 5);
+ bits.b2.val[2] = vaddq_u8(shift, vorrq_u8(bits.b2.val[2], vandq_u8(vshlq_n_u8(hbits.val[0], 1), hm)));
+ bits.b2.val[3] = vaddq_u8(shift, vorrq_u8(bits.b2.val[3], vandq_u8(vshlq_n_u8(hbits.val[1], 1), hm)));
+ for (int k = 0; k < 4; ++k) bits.b2.val[k] = vqtbl4q_s8(values, bits.b2.val[k]);
+ }
+
+ Q4bits bits;
+ const int8x16x4_t values;
+ const uint8x16_t hm = vdupq_n_u8(0x10);
+ const uint16x8_t mask = vdupq_n_u16(254);
+ const int16x8_t m127 = vdupq_n_s16(-127);
+ uint8x16x2_t hbits;
+ uint8x16_t sas;
+
+};
+
+struct DequantizerIQ4KSS final : public BaseDequantizer<block_iq4_kss, true> {
+
+ DequantizerIQ4KSS(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc), values(vld1q_s8_x2(iq4k_values)) {}
+
+ constexpr static int num_blocks() { return 8; }
+ constexpr static bool should_scale_quants() { return false; }
+
+ template <typename Q8>
+ inline int32x4x2_t new_block(int i, const Q8& q8, float32x4_t * acc) {
+ (void)q8;
+ (void)acc;
+ auto q4bits_1 = vld1q_u16_x4((const uint16_t *)x[i].qs);
+ q4bits_2 = vld1q_u16_x4((const uint16_t *)x[i].qs + 32);
+ for (int k = 0; k < 4; ++k) {
+ aux[k+0] = vaddvq_s16(vshlq_s16(vandq_u16(q4bits_1.val[k], m1), shift));
+ aux[k+4] = vaddvq_s16(vshlq_s16(vandq_u16(q4bits_2.val[k], m1), shift));
+ q4bits_1.val[k] = vandq_u16(q4bits_1.val[k], bmask);
+ q4bits_1.val[k] = veorq_u16(q4bits_1.val[k], vshrq_n_u16(q4bits_1.val[k], 1));
+ q4bits_2.val[k] = vandq_u16(q4bits_2.val[k], bmask);
+ q4bits_2.val[k] = veorq_u16(q4bits_2.val[k], vshrq_n_u16(q4bits_2.val[k], 1));
+ }
+ make_quants(q4bits_1, bits, aux);
+ auto scales16 = vld1q_s16(aux);
+ scales16 = vaddq_s16(vandq_s16(scales16, mask), m127);
+ int32x4x2_t scales = {vmovl_s16(vget_low_s16(scales16)), vmovl_s16(vget_high_s16(scales16))};
+ return scales;
+ }
+ inline void make_quants(uint16x8x4_t& q4bits, Q4bits& bits, const int16_t * aux) const {
+ bits.b1.val[0] = vqtbl1q_s8(values.val[aux[0] & 1], vandq_u8(q4bits.val[0], bits.m4b));
+ bits.b1.val[1] = vqtbl1q_s8(values.val[aux[0] & 1], vshrq_n_u8(q4bits.val[0], 4));
+ bits.b1.val[2] = vqtbl1q_s8(values.val[aux[1] & 1], vandq_u8(q4bits.val[1], bits.m4b));
+ bits.b1.val[3] = vqtbl1q_s8(values.val[aux[1] & 1], vshrq_n_u8(q4bits.val[1], 4));
+ bits.b2.val[0] = vqtbl1q_s8(values.val[aux[2] & 1], vandq_u8(q4bits.val[2], bits.m4b));
+ bits.b2.val[1] = vqtbl1q_s8(values.val[aux[2] & 1], vshrq_n_u8(q4bits.val[2], 4));
+ bits.b2.val[2] = vqtbl1q_s8(values.val[aux[3] & 1], vandq_u8(q4bits.val[3], bits.m4b));
+ bits.b2.val[3] = vqtbl1q_s8(values.val[aux[3] & 1], vshrq_n_u8(q4bits.val[3], 4));
+ }
+ inline void prepare([[maybe_unused]] int i, int j) {
+ if (j == 0) return;
+ make_quants(q4bits_2, bits, aux+4);
+ }
+ static int16x8_t load_shift() {
+ static const int16_t k_shift[8] = {0, 1, 2, 3, 4, 5, 6, 7};
+ return vld1q_s16(k_shift);
+ }
+
+ Q4bits bits;
+ const int8x16x2_t values;
+ const uint16x8_t mask = vdupq_n_s16(254);
+ const uint16x8_t bmask = vdupq_n_u16(0xfffe);
+ const uint16x8_t m1 = vdupq_n_u16(1);
+ const int16x8_t shift = load_shift();
+ const int16x8_t m127 = vdupq_n_s16(-127);
+ uint16x8x4_t q4bits_2;
+ int16_t aux[8];
+};
+
+struct DequantizerIQ2KS final : public BaseDequantizer<block_iq2_ks, true, true> {
+ DequantizerIQ2KS(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {}
+
+ constexpr static int num_blocks() { return 8; }
+ constexpr static bool should_scale_quants() { return false; }
+
+ template <typename Q8>
+ inline int32x4x2_t new_block(int i, [[maybe_unused]] const Q8& q8, [[maybe_unused]] float32x4_t * acc) {
+ const uint16_t * sc16 = (const uint16_t *)x[i].scales;
+ uint32_t aux32 = sc16[0] | (sc16[1] << 16);
+ uint8x8_t scales8 = vreinterpret_u8_u32(vdup_n_u32(aux32));
+ scales8 = vand_u8(vzip1_u8(scales8, vshr_n_u8(scales8, 4)), vdup_n_u8(0xf));
+ uint8x8_t sh = vand_u8(vceq_u8(vand_u8(vdup_n_u8(x[i].extra >> 8), hmask), vdup_n_u8(0)), vdup_n_u8(16));
+ int16x8_t scales16 = vmovl_s8(vsub_s8(vreinterpret_s8_u8(scales8), vreinterpret_s8_u8(sh)));
+ int32x4x2_t scales = {vmovl_s16(vget_low_s16(scales16)), vmovl_s16(vget_high_s16(scales16))};
+ return scales;
+ }
+ inline void prepare(int i, int j) {
+ uint8_t extra = x[i].extra >> 4*j;
+ bits.prepare(x[i].qs+32*j);
+ bits.b1.val[0] = vqtbl1q_s8(values.val[extra & 1], bits.b1.val[0]);
+ bits.b1.val[1] = vqtbl1q_s8(values.val[extra & 1], bits.b1.val[1]); extra >>= 1;
+ bits.b1.val[2] = vqtbl1q_s8(values.val[extra & 1], bits.b1.val[2]);
+ bits.b1.val[3] = vqtbl1q_s8(values.val[extra & 1], bits.b1.val[3]); extra >>= 1;
+ bits.b2.val[0] = vqtbl1q_s8(values.val[extra & 1], bits.b2.val[0]);
+ bits.b2.val[1] = vqtbl1q_s8(values.val[extra & 1], bits.b2.val[1]); extra >>= 1;
+ bits.b2.val[2] = vqtbl1q_s8(values.val[extra & 1], bits.b2.val[2]);
+ bits.b2.val[3] = vqtbl1q_s8(values.val[extra & 1], bits.b2.val[3]);
+ }
+
+ Q2bits bits;
+ const uint8x8_t hmask = vreinterpret_u8_u64(vdup_n_u64(0x8040201008040201));
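+    // The two IQ2_KS lookup tables: {-31, -13, 1, 17} and, with the extra shift applied, {-26, -8, 6, 22}.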
+ const int8x16x2_t values = { vreinterpretq_s8_u64(vdupq_n_u64(0x1101f3e1)), vreinterpretq_s8_u64(vdupq_n_u64(0x1606f8e6)) };
+
+};
+
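+// NEON version: GEMM for 4-row-interleaved iq4_ks_r4 quants with Q8_K activations.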
+template <int nrc_y>
+void mul_mat_iq4_ks_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ GGML_ASSERT(nrc_x%4 == 0);
+ Q8<nrc_y, block_q8_K> q8(info);
+ auto m4 = vdupq_n_u8(0xf);
+ auto values = vld1q_s8(iq4k_values);
+ int nbl = n / QK_K;
+ int8x16_t qx[8];
+ int16x8x4_t iscales;
+ int32x4x4_t scales;
+ float32x4_t acc[nrc_y] = {};
+ int32x4_t isum[nrc_y] = {};
+ for (int ix = 0; ix < nrc_x; ix += 4) {
+ auto dptr = (const float *)((const char *)vx + ix*bx);
+ auto d4 = vld1q_f32(dptr);
+ const block_iq4_ks_r4 * iq4 = (const block_iq4_ks_r4 *)(dptr + 4);
+ for (int ibl = 0; ibl < nbl; ++ibl) {
+ auto sas = vld1q_u8_x2(iq4[ibl].scales);
+ auto scale = vandq_u8(sas.val[0], vdupq_n_u8(254));
+ iscales.val[0] = vaddq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_low_u8 (scale))), vdupq_n_s16(-127));
+ iscales.val[1] = vaddq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(scale))), vdupq_n_s16(-127));
+ scale = vandq_u8(sas.val[1], vdupq_n_u8(254));
+ iscales.val[2] = vaddq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_low_u8 (scale))), vdupq_n_s16(-127));
+ iscales.val[3] = vaddq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(scale))), vdupq_n_s16(-127));
+            // Adding the block shifts costs us a ~9% drop in performance.
+            // Is there a better way?
+ sas.val[0] = vshlq_n_u8(vandq_u8(sas.val[0], vdupq_n_u8(1)), 2);
+ sas.val[1] = vshlq_n_u8(vandq_u8(sas.val[1], vdupq_n_u8(1)), 2);
+ {
+ auto s16_1 = vmulq_s16(iscales.val[0], vmovl_u8(vget_low_u8 (sas.val[0])));
+ auto s16_2 = vmulq_s16(iscales.val[1], vmovl_u8(vget_high_u8(sas.val[0])));
+ auto s16_3 = vmulq_s16(iscales.val[2], vmovl_u8(vget_low_u8 (sas.val[1])));
+ auto s16_4 = vmulq_s16(iscales.val[3], vmovl_u8(vget_high_u8(sas.val[1])));
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto bsums = vld1q_s16_x2(q8.y[iy][ibl].bsums);
+ auto bs = vpaddq_s16(bsums.val[0], bsums.val[1]);
+ auto b8 = vget_low_s16(bs);
+ isum[iy] = vmlal_lane_s16(isum[iy], vget_low_s16 (s16_1), b8, 0);
+ isum[iy] = vmlal_lane_s16(isum[iy], vget_high_s16(s16_1), b8, 1);
+ isum[iy] = vmlal_lane_s16(isum[iy], vget_low_s16 (s16_2), b8, 2);
+ isum[iy] = vmlal_lane_s16(isum[iy], vget_high_s16(s16_2), b8, 3);
+ b8 = vget_high_s16(bs);
+ isum[iy] = vmlal_lane_s16(isum[iy], vget_low_s16 (s16_3), b8, 0);
+ isum[iy] = vmlal_lane_s16(isum[iy], vget_high_s16(s16_3), b8, 1);
+ isum[iy] = vmlal_lane_s16(isum[iy], vget_low_s16 (s16_4), b8, 2);
+ isum[iy] = vmlal_lane_s16(isum[iy], vget_high_s16(s16_4), b8, 3);
+ }
+ }
+ for (int is = 0; is < 2; ++is) {
+ scales.val[0] = vmovl_s16(vget_low_s16 (iscales.val[2*is+0]));
+ scales.val[1] = vmovl_s16(vget_high_s16(iscales.val[2*is+0]));
+ scales.val[2] = vmovl_s16(vget_low_s16 (iscales.val[2*is+1]));
+ scales.val[3] = vmovl_s16(vget_high_s16(iscales.val[2*is+1]));
+ for (int ib = 0; ib < 4; ++ib) {
+ auto bits = vld1q_u8_x4(iq4[ibl].qs + 256*is + 64*ib);
+ prepare_iq4_nl_quants(values, m4, bits, qx);
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto y = vld1q_s8_x2(q8.y[iy][ibl].qs+128*is+32*ib);
+ auto sumi = interleaved_dotq(qx, y);
+ isum[iy] = vmlaq_s32(isum[iy], scales.val[ib], sumi);
+ }
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ acc[iy] = vfmaq_f32(acc[iy], vdupq_n_f32(q8.scale(iy, ibl)), vcvtq_f32_s32(isum[iy]));
+ isum[iy] = vdupq_n_s32(0);
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ info.store(ix, iy, vmulq_f32(d4, acc[iy]));
+ acc[iy] = vdupq_n_f32(0.f);
+ }
+ }
+}
+
+template <int nrc_y>
+void mul_mat_iq5_ks_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ GGML_ASSERT(nrc_x%4 == 0);
+ Q8<nrc_y, block_q8_K> q8(info);
+ auto m4 = vdupq_n_u8(0xf);
+ auto m10 = vdupq_n_u8(0x10);
+ auto values = vld1q_s8_x2(iq5nl_values);
+ int nbl = n / QK_K;
+ int8x16_t qx[8];
+ int16x8x4_t iscales;
+ int32x4x4_t scales;
+ float32x4_t acc[nrc_y] = {};
+ int32x4_t isum[nrc_y] = {};
+ for (int ix = 0; ix < nrc_x; ix += 4) {
+ auto dptr = (const float *)((const char *)vx + ix*bx);
+ auto d4 = vld1q_f32(dptr);
+ const block_iq5_ks_r4 * iq5 = (const block_iq5_ks_r4 *)(dptr + 4);
+ for (int ibl = 0; ibl < nbl; ++ibl) {
+ auto sas = vld1q_u8_x2(iq5[ibl].scales);
+ auto scale = vandq_u8(sas.val[0], vdupq_n_u8(254));
+ iscales.val[0] = vaddq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_low_u8 (scale))), vdupq_n_s16(-127));
+ iscales.val[1] = vaddq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(scale))), vdupq_n_s16(-127));
+ scale = vandq_u8(sas.val[1], vdupq_n_u8(254));
+ iscales.val[2] = vaddq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_low_u8 (scale))), vdupq_n_s16(-127));
+ iscales.val[3] = vaddq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(scale))), vdupq_n_s16(-127));
+            // Adding the block shifts costs us a ~9% drop in performance.
+            // Is there a better way?
+ sas.val[0] = vshlq_n_u8(vandq_u8(sas.val[0], vdupq_n_u8(1)), 1);
+ sas.val[1] = vshlq_n_u8(vandq_u8(sas.val[1], vdupq_n_u8(1)), 1);
+ {
+ auto s16_1 = vmulq_s16(iscales.val[0], vmovl_u8(vget_low_u8 (sas.val[0])));
+ auto s16_2 = vmulq_s16(iscales.val[1], vmovl_u8(vget_high_u8(sas.val[0])));
+ auto s16_3 = vmulq_s16(iscales.val[2], vmovl_u8(vget_low_u8 (sas.val[1])));
+ auto s16_4 = vmulq_s16(iscales.val[3], vmovl_u8(vget_high_u8(sas.val[1])));
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto bsums = vld1q_s16_x2(q8.y[iy][ibl].bsums);
+ auto bs = vpaddq_s16(bsums.val[0], bsums.val[1]);
+ auto b8 = vget_low_s16(bs);
+ isum[iy] = vmlal_lane_s16(isum[iy], vget_low_s16 (s16_1), b8, 0);
+ isum[iy] = vmlal_lane_s16(isum[iy], vget_high_s16(s16_1), b8, 1);
+ isum[iy] = vmlal_lane_s16(isum[iy], vget_low_s16 (s16_2), b8, 2);
+ isum[iy] = vmlal_lane_s16(isum[iy], vget_high_s16(s16_2), b8, 3);
+ b8 = vget_high_s16(bs);
+ isum[iy] = vmlal_lane_s16(isum[iy], vget_low_s16 (s16_3), b8, 0);
+ isum[iy] = vmlal_lane_s16(isum[iy], vget_high_s16(s16_3), b8, 1);
+ isum[iy] = vmlal_lane_s16(isum[iy], vget_low_s16 (s16_4), b8, 2);
+ isum[iy] = vmlal_lane_s16(isum[iy], vget_high_s16(s16_4), b8, 3);
+ }
+ }
+ for (int is = 0; is < 2; ++is) {
+ scales.val[0] = vmovl_s16(vget_low_s16 (iscales.val[2*is+0]));
+ scales.val[1] = vmovl_s16(vget_high_s16(iscales.val[2*is+0]));
+ scales.val[2] = vmovl_s16(vget_low_s16 (iscales.val[2*is+1]));
+ scales.val[3] = vmovl_s16(vget_high_s16(iscales.val[2*is+1]));
+ for (int ib = 0; ib < 4; ++ib) {
+ auto lbits = vld1q_u8_x4(iq5[ibl].qs + 256*is + 64*ib);
+ auto hbits = vld1q_u8(iq5[ibl].qh + 64*is + 16*ib);
+ qx[0] = vorrq_u8(vandq_u8(lbits.val[0], m4), vandq_u8(m10, vshlq_n_u8(hbits, 4)));
+ qx[1] = vorrq_u8(vandq_u8(lbits.val[1], m4), vandq_u8(m10, vshlq_n_u8(hbits, 2)));
+ qx[2] = vorrq_u8(vandq_u8(lbits.val[2], m4), vandq_u8(m10, hbits));
+ qx[3] = vorrq_u8(vandq_u8(lbits.val[3], m4), vandq_u8(m10, vshrq_n_u8(hbits, 2)));
+ qx[4] = vorrq_u8(vshrq_n_u8(lbits.val[0], 4), vandq_u8(m10, vshlq_n_u8(hbits, 3)));
+ qx[5] = vorrq_u8(vshrq_n_u8(lbits.val[1], 4), vandq_u8(m10, vshlq_n_u8(hbits, 1)));
+ qx[6] = vorrq_u8(vshrq_n_u8(lbits.val[2], 4), vandq_u8(m10, vshrq_n_u8(hbits, 1)));
+ qx[7] = vorrq_u8(vshrq_n_u8(lbits.val[3], 4), vandq_u8(m10, vshrq_n_u8(hbits, 3)));
+ for (int l = 0; l < 8; ++l) qx[l] = vqtbl2q_s8(values, qx[l]);
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto y = vld1q_s8_x2(q8.y[iy][ibl].qs+128*is+32*ib);
+ auto sumi = interleaved_dotq(qx, y);
+ isum[iy] = vmlaq_s32(isum[iy], scales.val[ib], sumi);
+ }
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ acc[iy] = vfmaq_f32(acc[iy], vdupq_n_f32(q8.scale(iy, ibl)), vcvtq_f32_s32(isum[iy]));
+ isum[iy] = vdupq_n_s32(0);
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ info.store(ix, iy, vmulq_f32(d4, acc[iy]));
+ acc[iy] = vdupq_n_f32(0.f);
+ }
+ }
+}
+
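+// Folds the per-block 'extra' value shifts into the integer accumulators via the Q8 block sums (bsums);
+// k_shift is the shift magnitude used by the given quant type (5 for iq2_k, 4 for iq3_k/iq4_k, 2 for iq5_k).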
+template <int nrc_y, int k_shift>
+inline void iq3_4_add_shift(int ibl, const Q8<nrc_y, block_q8_K>& q8, const int8x16x4_t& i8scales, uint8x16_t extra,
+ int32x4_t * isum) {
+ auto ms = vdupq_n_s8(k_shift);
+ int8x16_t s8_1, s8_2;
+ if constexpr (k_shift == 5) {
+ auto m1 = vdupq_n_u8(1);
+ s8_1 = vmulq_s8(i8scales.val[0], vandq_s8(ms, vceqq_u8(vandq_u8(extra, m1), m1))); extra = vshrq_n_u8(extra, 2);
+ s8_2 = vmulq_s8(i8scales.val[1], vandq_s8(ms, vceqq_u8(vandq_u8(extra, m1), m1))); extra = vshrq_n_u8(extra, 2);
+ } else {
+ if constexpr (k_shift == 4) {
+ s8_1 = vmulq_s8(i8scales.val[0], vandq_u8(ms, vshlq_n_u8(extra, 2)));
+ s8_2 = vmulq_s8(i8scales.val[1], vandq_u8(ms, extra));
+ } else {
+ s8_1 = vmulq_s8(i8scales.val[0], vandq_u8(ms, vshlq_n_u8(extra, 1)));
+ s8_2 = vmulq_s8(i8scales.val[1], vandq_u8(ms, vshrq_n_u8(extra, 1)));
+ }
+ }
+ auto s16_1 = vmovl_s8(vget_low_s8 (s8_1));
+ auto s16_2 = vmovl_s8(vget_high_s8(s8_1));
+ auto s16_3 = vmovl_s8(vget_low_s8 (s8_2));
+ auto s16_4 = vmovl_s8(vget_high_s8(s8_2));
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto b8 = vld1_s16(q8.y[iy][ibl].bsums);
+ isum[iy] = vmlal_lane_s16(isum[iy], vget_low_s16 (s16_1), b8, 0);
+ isum[iy] = vmlal_lane_s16(isum[iy], vget_high_s16(s16_1), b8, 1);
+ isum[iy] = vmlal_lane_s16(isum[iy], vget_low_s16 (s16_2), b8, 2);
+ isum[iy] = vmlal_lane_s16(isum[iy], vget_high_s16(s16_2), b8, 3);
+ b8 = vld1_s16(q8.y[iy][ibl].bsums+4);
+ isum[iy] = vmlal_lane_s16(isum[iy], vget_low_s16 (s16_3), b8, 0);
+ isum[iy] = vmlal_lane_s16(isum[iy], vget_high_s16(s16_3), b8, 1);
+ isum[iy] = vmlal_lane_s16(isum[iy], vget_low_s16 (s16_4), b8, 2);
+ isum[iy] = vmlal_lane_s16(isum[iy], vget_high_s16(s16_4), b8, 3);
+ }
+ if constexpr (k_shift == 5) {
+ auto m1 = vdupq_n_u8(1);
+ s8_1 = vmulq_s8(i8scales.val[2], vandq_s8(ms, vceqq_u8(vandq_u8(extra, m1), m1))); extra = vshrq_n_u8(extra, 2);
+ s8_2 = vmulq_s8(i8scales.val[3], vandq_s8(ms, vceqq_u8(vandq_u8(extra, m1), m1))); extra = vshrq_n_u8(extra, 2);
+ } else {
+ if constexpr (k_shift == 4) {
+ s8_1 = vmulq_s8(i8scales.val[2], vandq_u8(ms, vshrq_n_u8(extra, 2)));
+ s8_2 = vmulq_s8(i8scales.val[3], vandq_u8(ms, vshrq_n_u8(extra, 4)));
+ } else {
+ s8_1 = vmulq_s8(i8scales.val[2], vandq_u8(ms, vshrq_n_u8(extra, 3)));
+ s8_2 = vmulq_s8(i8scales.val[3], vandq_u8(ms, vshrq_n_u8(extra, 5)));
+ }
+ }
+ s16_1 = vmovl_s8(vget_low_s8 (s8_1));
+ s16_2 = vmovl_s8(vget_high_s8(s8_1));
+ s16_3 = vmovl_s8(vget_low_s8 (s8_2));
+ s16_4 = vmovl_s8(vget_high_s8(s8_2));
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto b8 = vld1_s16(q8.y[iy][ibl].bsums+8);
+ isum[iy] = vmlal_lane_s16(isum[iy], vget_low_s16 (s16_1), b8, 0);
+ isum[iy] = vmlal_lane_s16(isum[iy], vget_high_s16(s16_1), b8, 1);
+ isum[iy] = vmlal_lane_s16(isum[iy], vget_low_s16 (s16_2), b8, 2);
+ isum[iy] = vmlal_lane_s16(isum[iy], vget_high_s16(s16_2), b8, 3);
+ b8 = vld1_s16(q8.y[iy][ibl].bsums+12);
+ isum[iy] = vmlal_lane_s16(isum[iy], vget_low_s16 (s16_3), b8, 0);
+ isum[iy] = vmlal_lane_s16(isum[iy], vget_high_s16(s16_3), b8, 1);
+ isum[iy] = vmlal_lane_s16(isum[iy], vget_low_s16 (s16_4), b8, 2);
+ isum[iy] = vmlal_lane_s16(isum[iy], vget_high_s16(s16_4), b8, 3);
+ }
+}
+
+template <int nrc_y>
+void mul_mat_iq2_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ GGML_ASSERT(nrc_x%4 == 0);
+ Q8<nrc_y, block_q8_K> q8(info);
+ auto m4 = vdupq_n_u8(0xf);
+ auto m03 = vdupq_n_u8(0x03);
+ auto ms = vdupq_n_u8(4);
+ uint8x16x2_t shift_shuffle = {
+ vreinterpretq_u8_u64(uint64x2_t{0x0101010100000000, 0x0303030302020202}),
+ vreinterpretq_u8_u64(uint64x2_t{0x0505050504040404, 0x0707070706060606})
+ };
+ auto values8 = vld1_s8(iq2nl_values);
+ auto values = vcombine_s8(values8, values8);
+ int nbl = n / QK_K;
+ int8x16_t qx[4];
+ int8x16x4_t i8scales;
+ int16x8x4_t i16scales;
+ float32x4_t acc[nrc_y] = {};
+ for (int ix = 0; ix < nrc_x; ix += 4) {
+ const block_iq2_k_r4 * iq2 = (const block_iq2_k_r4 *)((const char *)vx + ix*bx);
+ for (int ibl = 0; ibl < nbl; ++ibl) {
+ auto d4 = vcvt_f32_f16(vld1_f16((const float16_t *)iq2[ibl].d));
+ auto extra8 = vld1_u8(iq2[ibl].extra);
+ uint8x16_t extra;
+ if constexpr (nrc_y == 1) {
+ extra = vcombine_u8(extra8, vshr_n_u8(extra8,1));
+ } else {
+ extra = vcombine_u8(extra8, extra8);
+ }
+ auto sl = vld1q_u8_x2(iq2[ibl].scales);
+ i8scales.val[0] = vaddq_s8(vandq_u8(sl.val[0], m4), vdupq_n_s8(-8));
+ i8scales.val[1] = vaddq_s8(vandq_u8(sl.val[1], m4), vdupq_n_s8(-8));
+ i8scales.val[2] = vaddq_s8(vshrq_n_u8(sl.val[0], 4), vdupq_n_s8(-8));
+ i8scales.val[3] = vaddq_s8(vshrq_n_u8(sl.val[1], 4), vdupq_n_s8(-8));
+ int32x4_t isum[nrc_y] = {};
+ if constexpr (nrc_y == 1) {
+ iq3_4_add_shift<nrc_y, 5>(ibl, q8, i8scales, extra, isum);
+ }
+ for (int is = 0; is < 2; ++is) {
+ i16scales.val[0] = vmovl_s8(vget_low_s8 (i8scales.val[2*is+0]));
+ i16scales.val[1] = vmovl_s8(vget_high_s8(i8scales.val[2*is+0]));
+ i16scales.val[2] = vmovl_s8(vget_low_s8 (i8scales.val[2*is+1]));
+ i16scales.val[3] = vmovl_s8(vget_high_s8(i8scales.val[2*is+1]));
+ for (int ib = 0; ib < 4; ++ib) {
+ auto scales = vmovl_s16(vget_low_s16 (i16scales.val[ib]));
+ auto bits = vld1q_u8_x2(iq2[ibl].qs + 128*is + 32*ib);
+ qx[0] = vandq_u8( bits.val[0], m03);
+ qx[1] = vandq_u8(vshrq_n_u8(bits.val[0], 2), m03);
+ qx[2] = vandq_u8(vshrq_n_u8(bits.val[0], 4), m03);
+ qx[3] = vandq_u8(vshrq_n_u8(bits.val[0], 6), m03);
+ uint8x16_t shifts;
+ if constexpr (nrc_y == 1) {
+ qx[0] = vqtbl1q_s8(values, qx[0]); // 0...3 from the 4 rows
+ qx[1] = vqtbl1q_s8(values, qx[1]); // 4...7
+ qx[2] = vqtbl1q_s8(values, qx[2]); // 8..11
+ qx[3] = vqtbl1q_s8(values, qx[3]); // 12..15
+ } else {
+ shifts = vandq_u8(ms, vshlq_n_u8(extra, 2));
+ auto shift = vqtbl1q_u8(shifts, shift_shuffle.val[0]);
+ extra = vshrq_n_u8(extra, 1);
+ qx[0] = vqtbl1q_s8(values, vaddq_u8(shift, qx[0])); // 0...3 from the 4 rows
+ qx[1] = vqtbl1q_s8(values, vaddq_u8(shift, qx[1])); // 4...7
+ qx[2] = vqtbl1q_s8(values, vaddq_u8(shift, qx[2])); // 8..11
+ qx[3] = vqtbl1q_s8(values, vaddq_u8(shift, qx[3])); // 12..15
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto y = vld1q_s8(q8.y[iy][ibl].qs+128*is+32*ib);
+ auto sumi = interleaved_dotq(qx, y);
+ isum[iy] = vmlaq_s32(isum[iy], scales, sumi);
+ }
+ qx[0] = vandq_u8( bits.val[1], m03);
+ qx[1] = vandq_u8(vshrq_n_u8(bits.val[1], 2), m03);
+ qx[2] = vandq_u8(vshrq_n_u8(bits.val[1], 4), m03);
+ qx[3] = vandq_u8(vshrq_n_u8(bits.val[1], 6), m03);
+ if constexpr (nrc_y == 1) {
+ qx[0] = vqtbl1q_s8(values, qx[0]); // 0...3 from the 4 rows
+ qx[1] = vqtbl1q_s8(values, qx[1]); // 4...7
+ qx[2] = vqtbl1q_s8(values, qx[2]); // 8..11
+ qx[3] = vqtbl1q_s8(values, qx[3]); // 12..15
+ } else {
+ auto shift = vqtbl1q_u8(shifts, shift_shuffle.val[1]);
+ qx[0] = vqtbl1q_s8(values, vaddq_u8(shift, qx[0])); // 0...3 from the 4 rows
+ qx[1] = vqtbl1q_s8(values, vaddq_u8(shift, qx[1])); // 4...7
+ qx[2] = vqtbl1q_s8(values, vaddq_u8(shift, qx[2])); // 8..11
+ qx[3] = vqtbl1q_s8(values, vaddq_u8(shift, qx[3])); // 12..15
+ }
+ scales = vmovl_s16(vget_high_s16(i16scales.val[ib]));
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto y = vld1q_s8(q8.y[iy][ibl].qs+128*is+32*ib+16);
+ auto sumi = interleaved_dotq(qx, y);
+ isum[iy] = vmlaq_s32(isum[iy], scales, sumi);
+ }
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ acc[iy] = vfmaq_f32(acc[iy], vmulq_f32(d4, vdupq_n_f32(q8.scale(iy, ibl))), vcvtq_f32_s32(isum[iy]));
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ info.store(ix, iy, acc[iy]);
+ acc[iy] = vdupq_n_f32(0.f);
+ }
+ }
+}
+
+template <int nrc_y>
+void mul_mat_iq3_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ GGML_ASSERT(nrc_x%4 == 0);
+ Q8<nrc_y, block_q8_K> q8(info);
+ auto m4 = vdupq_n_u8(0xf);
+ auto ms = nrc_y == 1 ? vdupq_n_u8(4) : vdupq_n_u8(8);
+ auto m03 = vdupq_n_u8(0x03);
+ auto m04 = vdupq_n_u8(0x04);
+ uint8x16x2_t shift_shuffle = {
+ vreinterpretq_u8_u64(uint64x2_t{0x0101010100000000, 0x0303030302020202}),
+ vreinterpretq_u8_u64(uint64x2_t{0x0505050504040404, 0x0707070706060606})
+ };
+ uint8x16x2_t smask = { vcombine_u8(vdup_n_u8(1), vdup_n_u8(2)), vcombine_u8(vdup_n_u8(4), vdup_n_u8(8)) };
+ auto values = vld1q_s8(iq3nl_values);
+ int nbl = n / QK_K;
+ int8x16_t qx[4];
+ int8x16x4_t i8scales;
+ int16x8x4_t i16scales;
+ float32x4_t acc[nrc_y] = {};
+ for (int ix = 0; ix < nrc_x; ix += 4) {
+ const block_iq3_k_r4 * iq3 = (const block_iq3_k_r4 *)((const char *)vx + ix*bx);
+ for (int ibl = 0; ibl < nbl; ++ibl) {
+ auto d4 = vcvt_f32_f16(vld1_f16((const float16_t *)iq3[ibl].d));
+ auto extra8 = vld1_u8(iq3[ibl].extra);
+ uint8x16_t extra;
+ if constexpr (nrc_y == 1) {
+ extra = vcombine_u8(extra8, vshr_n_u8(extra8,1));
+ } else {
+ extra = vcombine_u8(extra8, extra8);
+ }
+ auto sl = vld1q_u8_x2(iq3[ibl].scales_l);
+ auto sh8 = vld1_u8(iq3[ibl].scales_h);
+ auto sh = vcombine_u8(sh8, sh8);
+ i8scales.val[0] = vaddq_s8(vshlq_n_u8(vandq_u8(sl.val[0], m4), 1), vdupq_n_s8(1));
+ i8scales.val[1] = vaddq_s8(vshlq_n_u8(vandq_u8(sl.val[1], m4), 1), vdupq_n_s8(1));
+ i8scales.val[2] = vaddq_s8(vshlq_n_u8(vshrq_n_u8(sl.val[0], 4), 1), vdupq_n_s8(1));
+ i8scales.val[3] = vaddq_s8(vshlq_n_u8(vshrq_n_u8(sl.val[1], 4), 1), vdupq_n_s8(1));
+ i8scales.val[0] = vmulq_s8(i8scales.val[0], vorrq_u8(vceqq_u8(vandq_u8(sh, smask.val[0]), smask.val[0]), vdupq_n_u8(1)));
+ i8scales.val[1] = vmulq_s8(i8scales.val[1], vorrq_u8(vceqq_u8(vandq_u8(sh, smask.val[1]), smask.val[1]), vdupq_n_u8(1)));
+ sh = vshrq_n_u8(sh, 4);
+ i8scales.val[2] = vmulq_s8(i8scales.val[2], vorrq_u8(vceqq_u8(vandq_u8(sh, smask.val[0]), smask.val[0]), vdupq_n_u8(1)));
+ i8scales.val[3] = vmulq_s8(i8scales.val[3], vorrq_u8(vceqq_u8(vandq_u8(sh, smask.val[1]), smask.val[1]), vdupq_n_u8(1)));
+ int32x4_t isum[nrc_y] = {};
+ if constexpr (nrc_y == 1) {
+ iq3_4_add_shift<nrc_y, 4>(ibl, q8, i8scales, extra, isum);
+ }
+ for (int is = 0; is < 2; ++is) {
+ i16scales.val[0] = vmovl_s8(vget_low_s8 (i8scales.val[2*is+0]));
+ i16scales.val[1] = vmovl_s8(vget_high_s8(i8scales.val[2*is+0]));
+ i16scales.val[2] = vmovl_s8(vget_low_s8 (i8scales.val[2*is+1]));
+ i16scales.val[3] = vmovl_s8(vget_high_s8(i8scales.val[2*is+1]));
+ for (int ib = 0; ib < 4; ++ib) {
+ auto scales = vmovl_s16(vget_low_s16 (i16scales.val[ib]));
+ auto lbits = vld1q_u8_x2(iq3[ibl].qs + 128*is + 32*ib);
+ auto hbits = vld1q_u8(iq3[ibl].qh + 64*is + 16*ib);
+ qx[0] = vorrq_u8(vandq_u8( lbits.val[0], m03), vandq_u8(m04, vshlq_n_u8(hbits, 2)));
+ qx[1] = vorrq_u8(vandq_u8(vshrq_n_u8(lbits.val[0], 2), m03), vandq_u8(m04, vshlq_n_u8(hbits, 1)));
+ qx[2] = vorrq_u8(vandq_u8(vshrq_n_u8(lbits.val[0], 4), m03), vandq_u8(m04, hbits));
+ qx[3] = vorrq_u8(vandq_u8(vshrq_n_u8(lbits.val[0], 6), m03), vandq_u8(m04, vshrq_n_u8(hbits, 1)));
+ uint8x16_t shifts;
+ if constexpr (nrc_y == 1) {
+ qx[0] = vqtbl1q_s8(values, qx[0]); // 0...3 from the 4 rows
+ qx[1] = vqtbl1q_s8(values, qx[1]); // 4...7
+ qx[2] = vqtbl1q_s8(values, qx[2]); // 8..11
+ qx[3] = vqtbl1q_s8(values, qx[3]); // 12..15
+ } else {
+ shifts = vandq_u8(ms, vshlq_n_u8(extra, 3));
+ auto shift = vqtbl1q_u8(shifts, shift_shuffle.val[0]);
+ extra = vshrq_n_u8(extra, 1);
+ qx[0] = vqtbl1q_s8(values, vaddq_u8(shift, qx[0])); // 0...3 from the 4 rows
+ qx[1] = vqtbl1q_s8(values, vaddq_u8(shift, qx[1])); // 4...7
+ qx[2] = vqtbl1q_s8(values, vaddq_u8(shift, qx[2])); // 8..11
+ qx[3] = vqtbl1q_s8(values, vaddq_u8(shift, qx[3])); // 12..15
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto y = vld1q_s8(q8.y[iy][ibl].qs+128*is+32*ib);
+ auto sumi = interleaved_dotq(qx, y);
+ isum[iy] = vmlaq_s32(isum[iy], scales, sumi);
+ }
+ qx[0] = vorrq_u8(vandq_u8( lbits.val[1], m03), vandq_u8(m04, vshrq_n_u8(hbits, 2)));
+ qx[1] = vorrq_u8(vandq_u8(vshrq_n_u8(lbits.val[1], 2), m03), vandq_u8(m04, vshrq_n_u8(hbits, 3)));
+ qx[2] = vorrq_u8(vandq_u8(vshrq_n_u8(lbits.val[1], 4), m03), vandq_u8(m04, vshrq_n_u8(hbits, 4)));
+ qx[3] = vorrq_u8(vandq_u8(vshrq_n_u8(lbits.val[1], 6), m03), vandq_u8(m04, vshrq_n_u8(hbits, 5)));
+ if constexpr (nrc_y == 1) {
+ qx[0] = vqtbl1q_s8(values, qx[0]); // 0...3 from the 4 rows
+ qx[1] = vqtbl1q_s8(values, qx[1]); // 4...7
+ qx[2] = vqtbl1q_s8(values, qx[2]); // 8..11
+ qx[3] = vqtbl1q_s8(values, qx[3]); // 12..15
+ } else {
+ auto shift = vqtbl1q_u8(shifts, shift_shuffle.val[1]);
+ qx[0] = vqtbl1q_s8(values, vaddq_u8(shift, qx[0])); // 0...3 from the 4 rows
+ qx[1] = vqtbl1q_s8(values, vaddq_u8(shift, qx[1])); // 4...7
+ qx[2] = vqtbl1q_s8(values, vaddq_u8(shift, qx[2])); // 8..11
+ qx[3] = vqtbl1q_s8(values, vaddq_u8(shift, qx[3])); // 12..15
+ }
+ scales = vmovl_s16(vget_high_s16(i16scales.val[ib]));
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto y = vld1q_s8(q8.y[iy][ibl].qs+128*is+32*ib+16);
+ auto sumi = interleaved_dotq(qx, y);
+ isum[iy] = vmlaq_s32(isum[iy], scales, sumi);
+ }
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ acc[iy] = vfmaq_f32(acc[iy], vmulq_f32(d4, vdupq_n_f32(q8.scale(iy, ibl))), vcvtq_f32_s32(isum[iy]));
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ info.store(ix, iy, acc[iy]);
+ acc[iy] = vdupq_n_f32(0.f);
+ }
+ }
+}
+
+template <int nrc_y>
+void mul_mat_iq4_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ GGML_ASSERT(nrc_x%4 == 0);
+ Q8<nrc_y, block_q8_K> q8(info);
+ auto m4 = vdupq_n_u8(0xf);
+ auto m3 = vdupq_n_u8(0x30);
+ auto ms = vdupq_n_u8(4);
+ auto m32 = vdupq_n_s8(-32);
+ uint8x16x2_t shift_shuffle = {
+ vreinterpretq_u8_u64(uint64x2_t{0x0101010100000000, 0x0303030302020202}),
+ vreinterpretq_u8_u64(uint64x2_t{0x0505050504040404, 0x0707070706060606})
+ };
+ auto values = vld1q_s8(iq4k_values);
+ int nbl = n / QK_K;
+ int8x16_t qx[4];
+ int8x16x4_t i8scales;
+ int16x8x4_t i16scales;
+ float32x4_t acc[nrc_y] = {};
+ for (int ix = 0; ix < nrc_x; ix += 4) {
+ const block_iq4_k_r4 * iq4 = (const block_iq4_k_r4 *)((const char *)vx + ix*bx);
+ for (int ibl = 0; ibl < nbl; ++ibl) {
+ auto d4 = vcvt_f32_f16(vld1_f16((const float16_t *)iq4[ibl].d));
+ auto extra8 = vld1_u8(iq4[ibl].extra);
+ uint8x16_t extra;
+ if constexpr (nrc_y == 1) {
+ extra = vcombine_u8(extra8, vshr_n_u8(extra8,1));
+ } else {
+ extra = vcombine_u8(extra8, extra8);
+ }
+ auto sl = vld1q_u8_x2(iq4[ibl].scales_l);
+ auto sh = vld1q_u8(iq4[ibl].scales_h);
+ i8scales.val[0] = vaddq_s8(vorrq_u8(vandq_u8(sl.val[0], m4), vandq_u8(vshlq_n_u8(sh, 4), m3)), m32);
+ i8scales.val[1] = vaddq_s8(vorrq_u8(vandq_u8(sl.val[1], m4), vandq_u8(vshlq_n_u8(sh, 2), m3)), m32);
+ i8scales.val[2] = vaddq_s8(vorrq_u8(vshrq_n_u8(sl.val[0], 4), vandq_u8(sh, m3)), m32);
+ i8scales.val[3] = vaddq_s8(vorrq_u8(vshrq_n_u8(sl.val[1], 4), vandq_u8(vshrq_n_u8(sh, 2), m3)), m32);
+ int32x4_t isum[nrc_y] = {};
+ if constexpr (nrc_y == 1) {
+ iq3_4_add_shift<nrc_y, 4>(ibl, q8, i8scales, extra, isum);
+ }
+ for (int is = 0; is < 2; ++is) {
+ i16scales.val[0] = vmovl_s8(vget_low_s8 (i8scales.val[2*is+0]));
+ i16scales.val[1] = vmovl_s8(vget_high_s8(i8scales.val[2*is+0]));
+ i16scales.val[2] = vmovl_s8(vget_low_s8 (i8scales.val[2*is+1]));
+ i16scales.val[3] = vmovl_s8(vget_high_s8(i8scales.val[2*is+1]));
+ for (int ib = 0; ib < 4; ++ib) {
+ auto bits = vld1q_u8_x4(iq4[ibl].qs + 256*is + 64*ib);
+ uint8x16_t shifts;
+ if constexpr (nrc_y == 1) {
+ qx[0] = vqtbl1q_s8(values, vandq_u8(bits.val[0], m4)); // 0...3 from the 4 rows
+ qx[1] = vqtbl1q_s8(values, vandq_u8(bits.val[2], m4)); // 4...7
+ qx[2] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[0], 4)); // 8..11
+ qx[3] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[2], 4)); // 12..15
+ } else {
+ shifts = vandq_u8(ms, vshlq_n_u8(extra, 2));
+ auto shift = vqtbl1q_u8(shifts, shift_shuffle.val[0]);
+ extra = vshrq_n_u8(extra, 1);
+ qx[0] = vaddq_s8(shift, vqtbl1q_s8(values, vandq_u8(bits.val[0], m4))); // 0...3 from the 4 rows
+ qx[1] = vaddq_s8(shift, vqtbl1q_s8(values, vandq_u8(bits.val[2], m4))); // 4...7
+ qx[2] = vaddq_s8(shift, vqtbl1q_s8(values, vshrq_n_u8(bits.val[0], 4))); // 8..11
+ qx[3] = vaddq_s8(shift, vqtbl1q_s8(values, vshrq_n_u8(bits.val[2], 4))); // 12..15
+ }
+ auto scales = vmovl_s16(vget_low_s16 (i16scales.val[ib]));
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto y = vld1q_s8(q8.y[iy][ibl].qs+128*is+32*ib);
+ auto sumi = interleaved_dotq(qx, y);
+ isum[iy] = vmlaq_s32(isum[iy], scales, sumi);
+ }
+ if constexpr (nrc_y == 1) {
+ qx[0] = vqtbl1q_s8(values, vandq_u8(bits.val[1], m4)); // 16..19
+ qx[1] = vqtbl1q_s8(values, vandq_u8(bits.val[3], m4)); // 20..23
+ qx[2] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[1], 4)); // 24..27
+ qx[3] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[3], 4)); // 28..31
+ } else {
+ auto shift = vqtbl1q_u8(shifts, shift_shuffle.val[1]);
+ qx[0] = vaddq_s8(shift, vqtbl1q_s8(values, vandq_u8(bits.val[1], m4))); // 16..19
+ qx[1] = vaddq_s8(shift, vqtbl1q_s8(values, vandq_u8(bits.val[3], m4))); // 20..23
+ qx[2] = vaddq_s8(shift, vqtbl1q_s8(values, vshrq_n_u8(bits.val[1], 4))); // 24..27
+ qx[3] = vaddq_s8(shift, vqtbl1q_s8(values, vshrq_n_u8(bits.val[3], 4))); // 28..31
+ }
+ scales = vmovl_s16(vget_high_s16(i16scales.val[ib]));
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto y = vld1q_s8(q8.y[iy][ibl].qs+128*is+32*ib+16);
+ auto sumi = interleaved_dotq(qx, y);
+ isum[iy] = vmlaq_s32(isum[iy], scales, sumi);
+ }
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ acc[iy] = vfmaq_f32(acc[iy], vmulq_f32(d4, vdupq_n_f32(q8.scale(iy, ibl))), vcvtq_f32_s32(isum[iy]));
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ info.store(ix, iy, acc[iy]);
+ acc[iy] = vdupq_n_f32(0.f);
+ }
+ }
+}
+
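+// mul_mat_iq5_k_r4_q8_k: same structure as the iq4_k_r4 kernel above, but for 5-bit quants: the low 4 bits come
+// from qs, the 5th bit from qh, and the combined index selects from the 32-entry iq5nl_values table (vqtbl2q).
+// The "extra" shift is 2 instead of 4.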
+template <int nrc_y>
+void mul_mat_iq5_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ GGML_ASSERT(nrc_x%4 == 0);
+ Q8<nrc_y, block_q8_K> q8(info);
+ auto m4 = vdupq_n_u8(0xf);
+ auto m3 = vdupq_n_u8(0x30);
+ auto ms = vdupq_n_u8(2);
+ auto m32 = vdupq_n_s8(-32);
+ auto m10 = vdupq_n_u8(0x10);
+ uint8x16x2_t shift_shuffle = {
+ vreinterpretq_u8_u64(uint64x2_t{0x0101010100000000, 0x0303030302020202}),
+ vreinterpretq_u8_u64(uint64x2_t{0x0505050504040404, 0x0707070706060606})
+ };
+ auto values = vld1q_s8_x2(iq5nl_values);
+ int nbl = n / QK_K;
+ int8x16_t qx[4];
+ int8x16x4_t i8scales;
+ int16x8x4_t i16scales;
+ float32x4_t acc[nrc_y] = {};
+ for (int ix = 0; ix < nrc_x; ix += 4) {
+ const block_iq5_k_r4 * iq5 = (const block_iq5_k_r4 *)((const char *)vx + ix*bx);
+ for (int ibl = 0; ibl < nbl; ++ibl) {
+ auto d4 = vcvt_f32_f16(vld1_f16((const float16_t *)iq5[ibl].d));
+ auto extra8 = vld1_u8(iq5[ibl].extra);
+ uint8x16_t extra;
+ if constexpr (nrc_y == 1) {
+ extra = vcombine_u8(extra8, vshr_n_u8(extra8,1));
+ } else {
+ extra = vcombine_u8(extra8, extra8);
+ }
+ auto sl = vld1q_u8_x2(iq5[ibl].scales_l);
+ auto sh = vld1q_u8(iq5[ibl].scales_h);
+ i8scales.val[0] = vaddq_s8(vorrq_u8(vandq_u8(sl.val[0], m4), vandq_u8(vshlq_n_u8(sh, 4), m3)), m32);
+ i8scales.val[1] = vaddq_s8(vorrq_u8(vandq_u8(sl.val[1], m4), vandq_u8(vshlq_n_u8(sh, 2), m3)), m32);
+ i8scales.val[2] = vaddq_s8(vorrq_u8(vshrq_n_u8(sl.val[0], 4), vandq_u8(sh, m3)), m32);
+ i8scales.val[3] = vaddq_s8(vorrq_u8(vshrq_n_u8(sl.val[1], 4), vandq_u8(vshrq_n_u8(sh, 2), m3)), m32);
+ int32x4_t isum[nrc_y] = {};
+ if constexpr (nrc_y == 1) {
+ iq3_4_add_shift<nrc_y, 2>(ibl, q8, i8scales, extra, isum);
+ }
+ for (int is = 0; is < 2; ++is) {
+ i16scales.val[0] = vmovl_s8(vget_low_s8 (i8scales.val[2*is+0]));
+ i16scales.val[1] = vmovl_s8(vget_high_s8(i8scales.val[2*is+0]));
+ i16scales.val[2] = vmovl_s8(vget_low_s8 (i8scales.val[2*is+1]));
+ i16scales.val[3] = vmovl_s8(vget_high_s8(i8scales.val[2*is+1]));
+ for (int ib = 0; ib < 4; ++ib) {
+ auto lbits = vld1q_u8_x4(iq5[ibl].qs + 256*is + 64*ib);
+ auto hbits = vld1q_u8(iq5[ibl].qh + 64*is + 16*ib);
+ qx[0] = vorrq_u8(vandq_u8(lbits.val[0], m4), vandq_u8(m10, vshlq_n_u8(hbits, 4))); // aligns with 1st half of qx[0] in AVX2
+ qx[1] = vorrq_u8(vandq_u8(lbits.val[2], m4), vandq_u8(m10, hbits)); // aligns with 1st half of qx[1] in AVX2
+ qx[2] = vorrq_u8(vshrq_n_u8(lbits.val[0], 4), vandq_u8(m10, vshlq_n_u8(hbits, 3))); // aligns with 1st half of qx[2] in AVX2
+ qx[3] = vorrq_u8(vshrq_n_u8(lbits.val[2], 4), vandq_u8(m10, vshrq_n_u8(hbits, 1))); // aligns with 1st half of qx[3] in AVX2
+ uint8x16_t shifts;
+ if constexpr (nrc_y == 1) {
+ qx[0] = vqtbl2q_s8(values, qx[0]); // 0...3 from the 4 rows
+ qx[1] = vqtbl2q_s8(values, qx[1]); // 4...7
+ qx[2] = vqtbl2q_s8(values, qx[2]); // 8..11
+ qx[3] = vqtbl2q_s8(values, qx[3]); // 12..15
+ } else {
+ shifts = vandq_u8(ms, vshlq_n_u8(extra, 1));
+ auto shift = vqtbl1q_u8(shifts, shift_shuffle.val[0]);
+ extra = vshrq_n_u8(extra, 1);
+ qx[0] = vaddq_s8(shift, vqtbl2q_s8(values, qx[0])); // 0...3 from the 4 rows
+ qx[1] = vaddq_s8(shift, vqtbl2q_s8(values, qx[1])); // 4...7
+ qx[2] = vaddq_s8(shift, vqtbl2q_s8(values, qx[2])); // 8..11
+ qx[3] = vaddq_s8(shift, vqtbl2q_s8(values, qx[3])); // 12..15
+ }
+ auto scales = vmovl_s16(vget_low_s16 (i16scales.val[ib]));
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto y = vld1q_s8(q8.y[iy][ibl].qs+128*is+32*ib);
+ auto sumi = interleaved_dotq(qx, y);
+ isum[iy] = vmlaq_s32(isum[iy], scales, sumi);
+ }
+ qx[0] = vorrq_u8(vandq_u8(lbits.val[1], m4), vandq_u8(m10, vshlq_n_u8(hbits, 2))); // aligns with 2nd half of qx[0] in AVX2
+ qx[1] = vorrq_u8(vandq_u8(lbits.val[3], m4), vandq_u8(m10, vshrq_n_u8(hbits, 2))); // aligns with 2nd half of qx[1] in AVX2
+ qx[2] = vorrq_u8(vshrq_n_u8(lbits.val[1], 4), vandq_u8(m10, vshlq_n_u8(hbits, 1))); // aligns with 2nd half of qx[2] in AVX2
+ qx[3] = vorrq_u8(vshrq_n_u8(lbits.val[3], 4), vandq_u8(m10, vshrq_n_u8(hbits, 3))); // aligns with 2nd half of qx[3] in AVX2
+ if constexpr (nrc_y == 1) {
+ qx[0] = vqtbl2q_s8(values, qx[0]); // 16..19 from the 4 rows
+ qx[1] = vqtbl2q_s8(values, qx[1]); // 20..23
+ qx[2] = vqtbl2q_s8(values, qx[2]); // 24..27
+ qx[3] = vqtbl2q_s8(values, qx[3]); // 28..31
+ } else {
+ auto shift = vqtbl1q_u8(shifts, shift_shuffle.val[1]);
+ qx[0] = vaddq_s8(shift, vqtbl2q_s8(values, qx[0])); // 16..19 from the 4 rows
+ qx[1] = vaddq_s8(shift, vqtbl2q_s8(values, qx[1])); // 20..23
+ qx[2] = vaddq_s8(shift, vqtbl2q_s8(values, qx[2])); // 24..27
+ qx[3] = vaddq_s8(shift, vqtbl2q_s8(values, qx[3])); // 28..31
+ }
+ scales = vmovl_s16(vget_high_s16(i16scales.val[ib]));
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto y = vld1q_s8(q8.y[iy][ibl].qs+128*is+32*ib+16);
+ auto sumi = interleaved_dotq(qx, y);
+ isum[iy] = vmlaq_s32(isum[iy], scales, sumi);
+ }
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ acc[iy] = vfmaq_f32(acc[iy], vmulq_f32(d4, vdupq_n_f32(q8.scale(iy, ibl))), vcvtq_f32_s32(isum[iy]));
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ info.store(ix, iy, acc[iy]);
+ acc[iy] = vdupq_n_f32(0.f);
+ }
+ }
+}
+
+}
+
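+// Kernel dispatch for the iqk quant family. Only q8_K activations are supported, the row size must be a multiple
+// of QK_K, and func16 is left unset (nullptr) on this path.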
+bool iqk_set_kernels_iqk_quants(int ne00, int typeA, int typeB, std::array<mul_mat_t, IQK_MAX_NY>& kernels, [[maybe_unused]] mul_mat_t& func16) {
+
+ if (ne00%QK_K != 0 || ggml_type(typeB) != GGML_TYPE_Q8_K) {
+ return false;
+ }
+
+ func16 = nullptr;
+
+ switch (typeA) {
+ case GGML_TYPE_IQ2_KS:
+ IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_K_q8_K_T, DequantizerIQ2KS, kernels);
+ break;
+ case GGML_TYPE_IQ2_K:
+ IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_K_q8_K_T, DequantizerIQ2K, kernels);
+ break;
+ case GGML_TYPE_IQ3_K:
+ IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_K_q8_K_T, DequantizerIQ3K, kernels);
+ break;
+ case GGML_TYPE_IQ4_KSS:
+ IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_K_q8_K_T, DequantizerIQ4KSS, kernels);
+ break;
+ case GGML_TYPE_IQ4_KS:
+ IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_K_q8_K_T, DequantizerIQ4KS, kernels);
+ break;
+ case GGML_TYPE_IQ4_K:
+ IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_K_q8_K_T, DequantizerIQ4K, kernels);
+ break;
+ case GGML_TYPE_IQ5_KS:
+ IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_K_q8_K_T, DequantizerIQ5KS, kernels);
+ break;
+ case GGML_TYPE_IQ5_K:
+ IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_K_q8_K_T, DequantizerIQ5K, kernels);
+ break;
+ case GGML_TYPE_IQ6_K:
+ IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_K_q8_K_T, DequantizerIQ6K, kernels);
+ break;
+ case GGML_TYPE_IQ2_K_R4:
+ IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq2_k_r4_q8_k, kernels);
+ break;
+ case GGML_TYPE_IQ3_K_R4:
+ IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq3_k_r4_q8_k, kernels);
+ break;
+ case GGML_TYPE_IQ4_K_R4:
+ IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq4_k_r4_q8_k, kernels);
+ break;
+ case GGML_TYPE_IQ4_KS_R4:
+ IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq4_ks_r4_q8_k, kernels);
+ break;
+ case GGML_TYPE_IQ5_KS_R4:
+ IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq5_ks_r4_q8_k, kernels);
+ break;
+ case GGML_TYPE_IQ5_K_R4:
+ IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq5_k_r4_q8_k, kernels);
+ break;
+ default:
+ return false;
+ }
+
+ return true;
+
+}
+
+#endif
+
+#endif
diff --git a/ggml/src/iqk/iqk_gemm_iqk_quants.h b/ggml/src/iqk/iqk_gemm_iqk_quants.h
new file mode 100644
index 00000000..cd076ff7
--- /dev/null
+++ b/ggml/src/iqk/iqk_gemm_iqk_quants.h
@@ -0,0 +1,11 @@
+#pragma once
+
+#include "iqk_common.h"
+
+#ifdef IQK_IMPLEMENT
+
+#include <array>
+
+bool iqk_set_kernels_iqk_quants(int ne00, int typeA, int typeB, std::array<mul_mat_t, IQK_MAX_NY>& kernels, mul_mat_t& func16);
+
+#endif
diff --git a/ggml/src/iqk/iqk_gemm_iquants.cpp b/ggml/src/iqk/iqk_gemm_iquants.cpp
new file mode 100644
index 00000000..782e48d8
--- /dev/null
+++ b/ggml/src/iqk/iqk_gemm_iquants.cpp
@@ -0,0 +1,2252 @@
+#include "iqk_gemm_iquants.h"
+
+#ifdef IQK_IMPLEMENT
+
+#include "ggml-impl.h"
+
+#define GGML_COMMON_IMPL_C
+#include "ggml-common.h"
+
+#ifdef __x86_64__
+
+namespace {
+
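+// Scale-broadcast helpers for the AVX2/AVX512 dot products below: set_scales_8 and set_scales_16 replicate pairs
+// of block scales from all_scales into every 16-bit lane, in the layout expected by the _mm256_maddubs_epi16 /
+// _mm256_madd_epi16 sequences further down.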
+inline __m256i get_scale_shuffle_8(int i) {
+ return _mm256_set1_epi16((2*i) | ((2*i+1) << 8));
+}
+
+inline void set_scales_8(const __m256i& all_scales, int j, __m256i * scales) {
+ scales[0] = _mm256_shuffle_epi8(all_scales, get_scale_shuffle_8(4*j+0));
+ scales[1] = _mm256_shuffle_epi8(all_scales, get_scale_shuffle_8(4*j+1));
+ scales[2] = _mm256_shuffle_epi8(all_scales, get_scale_shuffle_8(4*j+2));
+ scales[3] = _mm256_shuffle_epi8(all_scales, get_scale_shuffle_8(4*j+3));
+}
+
+inline __m256i get_scale_shuffle_16(int i) {
+ static const uint8_t k_shuffle[128] = {
+ 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3,
+ 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7,
+ 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,
+ 12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13, 14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,
+ };
+ return _mm256_loadu_si256((const __m256i*)k_shuffle + i);
+}
+
+inline void set_scales_16(const __m256i& all_scales, __m256i * scales) {
+ scales[0] = _mm256_shuffle_epi8(all_scales, get_scale_shuffle_16(0));
+ scales[1] = _mm256_shuffle_epi8(all_scales, get_scale_shuffle_16(1));
+ scales[2] = _mm256_shuffle_epi8(all_scales, get_scale_shuffle_16(2));
+ scales[3] = _mm256_shuffle_epi8(all_scales, get_scale_shuffle_16(3));
+}
+
+// TODO: find the bug that causes this to be called without HAVE_FANCY_SIMD, which triggers
+// writing 4 values into scales, which only has size 2.
+inline void set_scales_8_iq(int j, const __m256i& all_scales, __m256i * scales) {
+//#ifdef HAVE_FANCY_SIMD
+ auto shuffle = j == 0 ? _mm256_set_epi64x(0x0302030203020302, 0x0100010001000100, 0x0302030203020302, 0x0100010001000100)
+ : _mm256_set_epi64x(0x0b0a0b0a0b0a0b0a, 0x0908090809080908, 0x0b0a0b0a0b0a0b0a, 0x0908090809080908);
+ scales[0] = _mm256_shuffle_epi8(all_scales, shuffle);
+ scales[1] = _mm256_shuffle_epi8(all_scales, _mm256_add_epi8(shuffle, _mm256_set1_epi8(4)));
+//#else
+// set_scales_8(all_scales, j, scales);
+//#endif
+}
+
+inline void set_scales_16_iq(const __m256i& all_scales, __m256i * scales) {
+#ifdef HAVE_FANCY_SIMD
+ auto shuffle = _mm256_set_epi64x(0x0706070607060706, 0x0302030203020302, 0x0504050405040504, 0x0100010001000100);
+ scales[0] = _mm256_shuffle_epi8(all_scales, shuffle);
+ scales[1] = _mm256_shuffle_epi8(all_scales, _mm256_add_epi8(shuffle, _mm256_set1_epi8(8)));
+#else
+ set_scales_16(all_scales, scales);
+#endif
+}
+
+struct SimpleBits {
+ __m256i values[4];
+};
+
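+// EvenSignHelper: applies the iq2_xxs/iq3_xxs sign packing, where each 7-bit field encodes the signs of 8 values
+// and the 8th sign is implied by (even) parity. The AVX512 path reconstructs that 8th bit with a popcount and
+// flips the signed bytes via _mm256_mask_sub_epi8; the fallback path looks the full pattern up in keven_signs.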
+struct EvenSignHelper {
+#if defined HAVE_FANCY_SIMD && defined __AVX512VPOPCNTDQ__
+ union sbits_t {
+ __m128i vec;
+ __mmask32 mask[4];
+ };
+ IQK_ALWAYS_INLINE void sign_2_values(__m256i aux, __m256i * values) const {
+ aux = _mm256_and_si256(_mm256_srlv_epi32(aux, shifts), mask);
+ auto pcnt = _mm256_popcnt_epi32(aux);
+ sbits_t sbits;
+ sbits.vec = _mm256_cvtepi32_epi8(_mm256_or_si256(aux, _mm256_slli_epi32(_mm256_and_si256(pcnt, mone), 7)));
+ values[0] = _mm256_mask_sub_epi8(values[0], sbits.mask[0], _mm256_setzero_si256(), values[0]);
+ values[1] = _mm256_mask_sub_epi8(values[1], sbits.mask[1], _mm256_setzero_si256(), values[1]);
+ //auto sign_bits = _mm256_cvtepi32_epi8(_mm256_or_si256(aux, _mm256_slli_epi32(_mm256_and_si256(pcnt, mone), 7)));
+ //const __mmask32 * m32 = (const __mmask32 *)&sign_bits;
+ //values[0] = _mm256_mask_sub_epi8(values[0], m32[0], _mm256_setzero_si256(), values[0]);
+ //values[1] = _mm256_mask_sub_epi8(values[1], m32[1], _mm256_setzero_si256(), values[1]);
+ }
+ const __m256i shifts = _mm256_set_epi32(21, 14, 7, 0, 21, 14, 7, 0);
+ const __m256i mask = _mm256_set1_epi32(127);
+ const __m256i mone = _mm256_set1_epi32(1);
+#else
+ inline void sign_value(uint32_t aux32, __m256i& value) const {
+ auto signs = _mm256_set_epi64x(keven_signs[(aux32 >> 21) & 127], keven_signs[(aux32 >> 14) & 127],
+ keven_signs[(aux32 >> 7) & 127], keven_signs[(aux32 >> 0) & 127]);
+ value = _mm256_sign_epi8(value, signs);
+ }
+#endif
+};
+
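+// SignHelper: applies explicitly stored sign bits (iq2_s / iq3_s), one bit per value. With HAVE_FANCY_SIMD the 32
+// sign bits are used directly as a __mmask32 for _mm256_mask_sub_epi8; otherwise they are expanded to a +1/-1 byte
+// vector with a shuffle + compare and applied through _mm256_sign_epi8.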
+struct SignHelper {
+ inline __m256i make_signs(uint32_t sign_bits) const {
+ auto aux256 = _mm256_set1_epi32(sign_bits);
+ aux256 = _mm256_and_si256(_mm256_shuffle_epi8(aux256, mask1), mask2);
+ return _mm256_or_si256(_mm256_cmpeq_epi8(aux256, mask2), mone);
+ }
+// inline __m256i make_signs(const uint16_t * sign_bits) const {
+//#ifdef HAVE_FANCY_SIMD
+//#else
+// return make_signs(sign_bits[0] | (sign_bits[1] << 16));
+//#endif
+// }
+ inline __m256i sign_value(const uint16_t * sign_bits, const __m256i& value) const {
+#ifdef HAVE_FANCY_SIMD
+ const __mmask32 * mask = (const __mmask32 *)sign_bits;
+ return _mm256_mask_sub_epi8(value, mask[0], _mm256_setzero_si256(), value);
+#else
+ return _mm256_sign_epi8(value, make_signs(sign_bits[0] | (sign_bits[1] << 16)));
+#endif
+ }
+ inline void sign_4_values(const uint16_t * sign_bits, __m256i * values) const {
+#ifdef HAVE_FANCY_SIMD
+ const __mmask32 * mask = (const __mmask32 *)sign_bits;
+ values[0] = _mm256_mask_sub_epi8(values[0], mask[0], _mm256_setzero_si256(), values[0]);
+ values[1] = _mm256_mask_sub_epi8(values[1], mask[1], _mm256_setzero_si256(), values[1]);
+ values[2] = _mm256_mask_sub_epi8(values[2], mask[2], _mm256_setzero_si256(), values[2]);
+ values[3] = _mm256_mask_sub_epi8(values[3], mask[3], _mm256_setzero_si256(), values[3]);
+#else
+ auto s128 = _mm_loadu_si128((const __m128i *)sign_bits);
+ auto s256 = MM256_SET_M128I(s128, s128);
+ __m256i aux256;
+ auto shuffle = mask1;
+ auto step = _mm256_set1_epi8(4);
+ aux256 = _mm256_and_si256(_mm256_shuffle_epi8(s256, shuffle), mask2); shuffle = _mm256_add_epi8(shuffle, step);
+ values[0] = _mm256_sign_epi8(values[0], _mm256_or_si256(_mm256_cmpeq_epi8(aux256, mask2), mone));
+ aux256 = _mm256_and_si256(_mm256_shuffle_epi8(s256, shuffle), mask2); shuffle = _mm256_add_epi8(shuffle, step);
+ values[1] = _mm256_sign_epi8(values[1], _mm256_or_si256(_mm256_cmpeq_epi8(aux256, mask2), mone));
+ aux256 = _mm256_and_si256(_mm256_shuffle_epi8(s256, shuffle), mask2); shuffle = _mm256_add_epi8(shuffle, step);
+ values[2] = _mm256_sign_epi8(values[2], _mm256_or_si256(_mm256_cmpeq_epi8(aux256, mask2), mone));
+ aux256 = _mm256_and_si256(_mm256_shuffle_epi8(s256, shuffle), mask2); shuffle = _mm256_add_epi8(shuffle, step);
+ values[3] = _mm256_sign_epi8(values[3], _mm256_or_si256(_mm256_cmpeq_epi8(aux256, mask2), mone));
+#endif
+ }
+ const __m256i mask1 = _mm256_set_epi64x(0x0303030303030303, 0x0202020202020202, 0x0101010101010101, 0x0000000000000000);
+ const __m256i mask2 = _mm256_set1_epi64x(0x8040201008040201ull);
+ const __m256i mone = _mm256_set1_epi8(1);
+};
+
+struct DequantizerIQ2XXS final : public BaseDequantizer<block_iq2_xxs> {
+ DequantizerIQ2XXS(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}
+
+ constexpr static int num_blocks = 8;
+
+ union Data {
+ __m256i vec;
+ uint32_t val[8];
+ };
+
+ inline __m128i load_scales(int i) {
+ d = 0.125f * GGML_FP16_TO_FP32(x[i].d);
+ const uint16_t * a16 = (const uint16_t *)x[i].qs;
+ auto scales = _mm_srli_epi16(_mm_set_epi16(a16[31], a16[27], a16[23], a16[19], a16[15], a16[11], a16[7], a16[3]), 12);
+ return _mm_or_si128(_mm_slli_epi16(scales, 1), _mm_set1_epi16(1));
+ }
+
+ inline void new_block(int i, __m256i * scales) {
+ auto sc16 = load_scales(i);
+ scales[0] = MM256_SET_M128I(sc16, sc16);
+ }
+ inline float new_block(int i, __m256i * scales, __m256i& mins) {
+ auto sc16 = load_scales(i);
+ mins = scb.shuffle(sc16);
+ scales[0] = MM256_SET_M128I(sc16, sc16);
+ return -d*minv;
+ }
+
+ inline static void make4(const uint32_t * aux32, __m256i * values) {
+ const uint8_t * aux8 = (const uint8_t *)aux32;
+ values[0] = _mm256_set_epi64x(iq2xxs_grid[aux8[ 3]], iq2xxs_grid[aux8[ 2]], iq2xxs_grid[aux8[ 1]], iq2xxs_grid[aux8[ 0]]);
+ values[1] = _mm256_set_epi64x(iq2xxs_grid[aux8[11]], iq2xxs_grid[aux8[10]], iq2xxs_grid[aux8[ 9]], iq2xxs_grid[aux8[ 8]]);
+ values[2] = _mm256_set_epi64x(iq2xxs_grid[aux8[19]], iq2xxs_grid[aux8[18]], iq2xxs_grid[aux8[17]], iq2xxs_grid[aux8[16]]);
+ values[3] = _mm256_set_epi64x(iq2xxs_grid[aux8[27]], iq2xxs_grid[aux8[26]], iq2xxs_grid[aux8[25]], iq2xxs_grid[aux8[24]]);
+ }
+
+ IQK_ALWAYS_INLINE void sign_values(const uint32_t * aux32, __m256i * values) const {
+#if defined HAVE_FANCY_SIMD && defined __AVX512VPOPCNTDQ__
+ esh.sign_2_values(MM256_SET_M128I(_mm_set1_epi32(aux32[3]), _mm_set1_epi32(aux32[1])), values+0);
+ esh.sign_2_values(MM256_SET_M128I(_mm_set1_epi32(aux32[7]), _mm_set1_epi32(aux32[5])), values+2);
+#else
+ esh.sign_value(aux32[1], values[0]);
+ esh.sign_value(aux32[3], values[1]);
+ esh.sign_value(aux32[5], values[2]);
+ esh.sign_value(aux32[7], values[3]);
+#endif
+ }
+ inline void make4_signed(const uint32_t * aux32, const __m256i& min_value, __m256i * values) const {
+ make4(aux32, values);
+ sign_values(aux32, values);
+ for (int k = 0; k < 4; ++k) values[k] = _mm256_add_epi8(values[k], min_value);
+ }
+ inline void make4(const uint32_t * aux32, __m256i * values, __m256i * q8) const {
+ make4(aux32, values);
+ sign_values(aux32, q8);
+ }
+ inline void prepare(int i, int j) {
+ Data data; data.vec = _mm256_loadu_si256((const __m256i *)x[i].qs + j);
+ make4_signed(data.val, min_value, bits.values);
+ }
+ inline void prepare(int i, int j, const Q8<1>& q8, __m256i * q8_quants) {
+ for (int k = 0; k < 4; ++k) q8_quants[k] = q8.load_quants(0, i, 4*j+k);
+ Data data; data.vec = _mm256_loadu_si256((const __m256i *)x[i].qs + j);
+ make4(data.val, bits.values, q8_quants);
+ }
+
+ constexpr static int minv = 43;
+ SimpleBits bits;
+ Scales8KBase scb;
+ EvenSignHelper esh;
+ const __m256i min_value = _mm256_set1_epi8(minv);
+ const __m256i shuffle = _mm256_set_epi32(7, 5, 3, 1, 7, 5, 3, 1);
+};
+
+struct DequantizerIQ2XS final : public BaseDequantizer<block_iq2_xs> {
+ DequantizerIQ2XS(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}
+
+ constexpr static int num_blocks = 16;
+
+ inline __m256i load_scales(int i) {
+ d = 0.125f * GGML_FP16_TO_FP32(x[i].d);
+ auto tmp = _mm_loadl_epi64((const __m128i *)x[i].scales);
+ auto all = _mm_and_si128(_mm_unpacklo_epi8(tmp, _mm_srli_epi16(tmp, 4)), _mm_set1_epi8(0xf));
+ auto scales8 = _mm_or_si128(_mm_slli_epi16(all, 1), _mm_set1_epi8(1));
+ return _mm256_cvtepi8_epi16(scales8);
+ }
+ inline static void prepare_scales(const __m256i& all, __m256i * scales) {
+ auto scales_l = _mm256_castsi256_si128(all);
+ auto scales_h = _mm256_extractf128_si256(all, 1);
+ scales[0] = MM256_SET_M128I(scales_l, scales_l);
+ scales[1] = MM256_SET_M128I(scales_h, scales_h);
+ }
+
+ inline void new_block(int i, __m256i * scales) {
+ prepare_scales(load_scales(i), scales);
+ }
+ inline float new_block(int i, __m256i * scales, __m256i& mins) {
+ mins = load_scales(i);
+ prepare_scales(mins, scales);
+ return -d*minv;
+ }
+
+ struct Helper {
+ const __m256i mone = _mm256_set1_epi8(1);
+ const __m256i mask = _mm256_set1_epi64x(0x8040201008040201);
+ //const __m256i bhelper = _mm256_set_epi64x(0x8000008000808000, 0x0080800080000080, 0x8000008000808000, 0x0080800080000080);
+ const __m256i bhelper = load_bhelper();
+ const __m256i shuff1 = _mm256_set_epi64x(0x0606060606060606, 0x0404040404040404, 0x0202020202020202, 0x0000000000000000);
+ const __m256i shuff2 = _mm256_set_epi64x(0x0e0e0e0e0e0e0e0e, 0x0c0c0c0c0c0c0c0c, 0x0a0a0a0a0a0a0a0a, 0x0808080808080808);
+ static __m256i load_bhelper() {
+ static const uint8_t k_bit_helper[32] = {
+ 0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00,
+ 0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00,
+ };
+ return _mm256_loadu_si256((const __m256i*)k_bit_helper);
+ }
+ };
+
+ union index_t {
+ __m256i vec;
+ uint16_t val[8];
+ };
+
+ inline static void make4(const __m256i& data, const __m256i& mask, __m256i * values) {
+ index_t idx;
+ idx.vec = _mm256_and_si256(data, mask);
+ values[0] = _mm256_set_epi64x(iq2xs_grid[idx.val[ 3]], iq2xs_grid[idx.val[ 2]], iq2xs_grid[idx.val[ 1]], iq2xs_grid[idx.val[ 0]]);
+ values[1] = _mm256_set_epi64x(iq2xs_grid[idx.val[ 7]], iq2xs_grid[idx.val[ 6]], iq2xs_grid[idx.val[ 5]], iq2xs_grid[idx.val[ 4]]);
+ values[2] = _mm256_set_epi64x(iq2xs_grid[idx.val[11]], iq2xs_grid[idx.val[10]], iq2xs_grid[idx.val[ 9]], iq2xs_grid[idx.val[ 8]]);
+ values[3] = _mm256_set_epi64x(iq2xs_grid[idx.val[15]], iq2xs_grid[idx.val[14]], iq2xs_grid[idx.val[13]], iq2xs_grid[idx.val[12]]);
+ }
+ inline static void sign_value(const __m256i& sign_bits, const __m256i& shuffle, const __m256i& mask,
+ const __m256i& mone, __m256i& value) {
+ auto signs = _mm256_shuffle_epi8(sign_bits, shuffle);
+ signs = _mm256_cmpeq_epi8(_mm256_and_si256(signs, mask), mask);
+ value = _mm256_sign_epi8(value, _mm256_or_si256(signs, mone));
+ }
+ inline void sign_values(const __m256i& data, __m256i * values) const {
+#if defined HAVE_FANCY_SIMD && defined __AVX512VPOPCNTDQ__
+ auto partial_bits = _mm256_cvtepi16_epi8(_mm256_srli_epi16(data, 9));
+ auto pcnt = _mm_popcnt_epi8(partial_bits);
+ auto full_bits = _mm_or_si128(partial_bits, _mm_slli_epi16(_mm_and_si128(pcnt, _mm_set1_epi8(1)), 7));
+ const __mmask32 * m32 = (const __mmask32 *)&full_bits;
+ auto zero = _mm256_setzero_si256();
+ values[0] = _mm256_mask_sub_epi8(values[0], m32[0], zero, values[0]);
+ values[1] = _mm256_mask_sub_epi8(values[1], m32[1], zero, values[1]);
+ values[2] = _mm256_mask_sub_epi8(values[2], m32[2], zero, values[2]);
+ values[3] = _mm256_mask_sub_epi8(values[3], m32[3], zero, values[3]);
+#else
+ auto psb1 = _mm256_srli_epi16(data, 9);
+ auto psb2 = _mm256_srli_epi16(data, 13);
+ auto psbc = _mm256_xor_si256(psb1, psb2);
+ auto oddb = _mm256_shuffle_epi8(helper.bhelper, psbc);
+ auto full = _mm256_or_si256(psb1, oddb);
+ auto full_l = _mm256_castsi256_si128(full);
+ auto full_h = _mm256_extractf128_si256(full, 1);
+ auto full_1 = MM256_SET_M128I(full_l, full_l);
+ auto full_2 = MM256_SET_M128I(full_h, full_h);
+ sign_value(full_1, helper.shuff1, helper.mask, helper.mone, values[0]);
+ sign_value(full_1, helper.shuff2, helper.mask, helper.mone, values[1]);
+ sign_value(full_2, helper.shuff1, helper.mask, helper.mone, values[2]);
+ sign_value(full_2, helper.shuff2, helper.mask, helper.mone, values[3]);
+#endif
+ }
+ inline void make4_signed(const uint16_t * qs, const __m256i& m511,
+ const __m256i& min_value, __m256i * values) const {
+ auto q2 = _mm256_loadu_si256((const __m256i *)qs);
+ make4(q2, m511, values);
+ sign_values(q2, values);
+ for (int k = 0; k < 4; ++k) values[k] = _mm256_add_epi8(values[k], min_value);
+ }
+ inline void make4(const uint16_t * qs, const __m256i& m511, __m256i * values, __m256i * q8) const {
+ auto q2 = _mm256_loadu_si256((const __m256i *)qs);
+ make4(q2, m511, values);
+ sign_values(q2, q8);
+ }
+
+ inline void prepare(int i, int j) {
+ make4_signed(x[i].qs + 16*j, idx_mask, min_value, bits.values);
+ }
+ inline void prepare(int i, int j, const Q8<1>& q8, __m256i * q8_quants) {
+ for (int k = 0; k < 4; ++k) q8_quants[k] = q8.load_quants(0, i, 4*j+k);
+ make4(x[i].qs + 16*j, idx_mask, bits.values, q8_quants);
+ }
+
+ constexpr static int minv = 43;
+
+ SimpleBits bits;
+#if !(defined HAVE_FANCY_SIMD && defined __AVX512VPOPCNTDQ__)
+ Helper helper;
+#endif
+ const __m256i idx_mask = _mm256_set1_epi16(511);
+ const __m256i min_value = _mm256_set1_epi8(minv);
+
+};
+
+struct DequantizerIQ2S final : public BaseDequantizer<block_iq2_s> {
+ DequantizerIQ2S(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}
+
+ constexpr static int num_blocks = 16;
+
+ inline __m256i load_scales(int i) {
+ d = 0.125f * GGML_FP16_TO_FP32(x[i].d);
+ auto tmp = _mm_loadl_epi64((const __m128i *)x[i].scales);
+ auto all = _mm_and_si128(_mm_unpacklo_epi8(tmp, _mm_srli_epi16(tmp, 4)), _mm_set1_epi8(0xf));
+ auto scales8 = _mm_or_si128(_mm_slli_epi16(all, 1), _mm_set1_epi8(1));
+ return _mm256_cvtepi8_epi16(scales8);
+ }
+ inline static void prepare_scales(const __m256i& all, __m256i * scales) {
+ auto scales_l = _mm256_castsi256_si128(all);
+ auto scales_h = _mm256_extractf128_si256(all, 1);
+ scales[0] = MM256_SET_M128I(scales_l, scales_l);
+ scales[1] = MM256_SET_M128I(scales_h, scales_h);
+ }
+
+ inline void new_block(int i, __m256i * scales) {
+ prepare_scales(load_scales(i), scales);
+ }
+ inline float new_block(int i, __m256i * scales, __m256i& mins) {
+ mins = load_scales(i);
+ prepare_scales(mins, scales);
+ return -d*minv;
+ }
+
+ union index_t {
+ __m256i vec;
+ uint32_t val[8];
+ };
+
+ inline static void make2(const uint8_t * qs, const uint8_t * qh, const __m256i& idx_shift, const __m256i& idx_mask, __m256i * values) {
+ auto idx_l = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i *)qs));
+ auto idx_h = MM256_SET_M128I(_mm_set1_epi32(qh[1]), _mm_set1_epi32(qh[0]));
+ index_t idx;
+ idx.vec = _mm256_or_si256(idx_l, _mm256_and_si256(_mm256_sllv_epi32(idx_h, idx_shift), idx_mask));
+ values[0] = _mm256_set_epi64x(iq2s_grid[idx.val[3]], iq2s_grid[idx.val[2]], iq2s_grid[idx.val[1]], iq2s_grid[idx.val[0]]);
+ values[1] = _mm256_set_epi64x(iq2s_grid[idx.val[7]], iq2s_grid[idx.val[6]], iq2s_grid[idx.val[5]], iq2s_grid[idx.val[4]]);
+ }
+ inline static void make2_signed(const SignHelper& sh, const uint8_t * qs, const uint8_t * qh, const uint16_t * sidx,
+ const __m256i& idx_shift, const __m256i& idx_mask, const __m256i& min_value, __m256i * values) {
+ make2(qs, qh, idx_shift, idx_mask, values);
+ values[0] = _mm256_add_epi8(sh.sign_value(sidx+0, values[0]), min_value);
+ values[1] = _mm256_add_epi8(sh.sign_value(sidx+2, values[1]), min_value);
+ }
+
+ inline void prepare(int i, int j) {
+ auto qs = x[i].qs + 16*j;
+ auto qh = x[i].qh + 4*j;
+ const uint16_t * signs = (const uint16_t *)(x[i].qs + QK_K/8) + 8*j;
+ make2_signed(sh, qs+0, qh+0, signs+0, idx_shift, idx_mask, min_value, bits.values+0);
+ make2_signed(sh, qs+8, qh+2, signs+4, idx_shift, idx_mask, min_value, bits.values+2);
+ }
+ inline void prepare(int i, int j, const Q8<1>& q8, __m256i * q8_quants) {
+ auto qs = x[i].qs + 16*j;
+ auto qh = x[i].qh + 4*j;
+ const uint16_t * signs = (const uint16_t *)(x[i].qs + QK_K/8) + 8*j;
+ make2(qs+0, qh+0, idx_shift, idx_mask, bits.values+0);
+ make2(qs+8, qh+2, idx_shift, idx_mask, bits.values+2);
+ q8_quants[0] = _mm256_sign_epi8(q8.load_quants(0, i, 4*j+0), sh.make_signs(signs[0] | (signs[1] << 16)));
+ q8_quants[1] = _mm256_sign_epi8(q8.load_quants(0, i, 4*j+1), sh.make_signs(signs[2] | (signs[3] << 16)));
+ q8_quants[2] = _mm256_sign_epi8(q8.load_quants(0, i, 4*j+2), sh.make_signs(signs[4] | (signs[5] << 16)));
+ q8_quants[3] = _mm256_sign_epi8(q8.load_quants(0, i, 4*j+3), sh.make_signs(signs[6] | (signs[7] << 16)));
+ }
+
+ constexpr static int minv = 43;
+
+ SimpleBits bits;
+ SignHelper sh;
+ const __m256i idx_shift = _mm256_set_epi32(2, 4, 6, 8, 2, 4, 6, 8);
+ const __m256i idx_mask = _mm256_set1_epi32(0x300);
+ const __m256i min_value = _mm256_set1_epi8(minv);
+
+};
+
+struct DequantizerIQ3XXS final : public BaseDequantizer<block_iq3_xxs> {
+ DequantizerIQ3XXS(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}
+
+ constexpr static int num_blocks = 8;
+
+ inline __m128i prepare_scales(int i) {
+ d = 0.25f * GGML_FP16_TO_FP32(x[i].d);
+ auto tmp = _mm256_loadu_si256((const __m256i *)(x[i].qs + QK_K/4));
+ auto scales32 = _mm256_srli_epi32(tmp, 28);
+ scales32 = _mm256_or_si256(_mm256_slli_epi32(scales32, 1), _mm256_set1_epi32(1));
+ return _mm_packs_epi32(_mm256_castsi256_si128(scales32), _mm256_extractf128_si256(scales32, 1));
+ }
+
+ inline void new_block(int i, __m256i * scales) {
+ auto scales16 = prepare_scales(i);
+ scales[0] = MM256_SET_M128I(scales16, scales16);
+ }
+ inline float new_block(int i, __m256i * scales, __m256i& mins) {
+ auto scales16 = prepare_scales(i);
+ mins = scb.shuffle(scales16);
+ scales[0] = MM256_SET_M128I(scales16, scales16);
+ return -d*minv;
+ }
+
+ inline static __m256i make_quants(const uint8_t * qs) {
+ return _mm256_set_epi32(iq3xxs_grid[qs[7]], iq3xxs_grid[qs[6]], iq3xxs_grid[qs[5]], iq3xxs_grid[qs[4]],
+ iq3xxs_grid[qs[3]], iq3xxs_grid[qs[2]], iq3xxs_grid[qs[1]], iq3xxs_grid[qs[0]]);
+ }
+ inline static void make4_unsigned(const uint8_t * qs, __m256i * values) {
+ values[0] = make_quants(qs+ 0);
+ values[1] = make_quants(qs+ 8);
+ values[2] = make_quants(qs+16);
+ values[3] = make_quants(qs+24);
+ }
+
+ IQK_ALWAYS_INLINE void sign_2_values(const uint16_t * signs, __m256i * values) const {
+#if defined HAVE_FANCY_SIMD && defined __AVX512VPOPCNTDQ__
+ esh.sign_2_values(MM256_SET_M128I(_mm_set1_epi32(signs[2] | (signs[3] << 16)), _mm_set1_epi32(signs[0] | (signs[1] << 16))), values);
+#else
+ esh.sign_value(signs[0] | (signs[1] << 16), values[0]);
+ esh.sign_value(signs[2] | (signs[3] << 16), values[1]);
+#endif
+ }
+
+ inline void prepare(int i, int j) {
+ auto qs = x[i].qs + 32*j;
+ const uint16_t * signs = (const uint16_t *)(x[i].qs + QK_K/4) + 8*j;
+ make4_unsigned(qs, bits.values);
+ sign_2_values(signs+0, bits.values+0);
+ sign_2_values(signs+4, bits.values+2);
+ for (int k = 0; k < 4; ++k) bits.values[k] = _mm256_add_epi32(bits.values[k], min_value);
+ }
+ inline void prepare(int i, int j, const Q8<1>& q8, __m256i * q8_quants) {
+ for (int k = 0; k < 4; ++k) q8_quants[k] = q8.load_quants(0, i, 4*j+k);
+ auto qs = x[i].qs + 32*j;
+ const uint16_t * signs = (const uint16_t *)(x[i].qs + QK_K/4) + 8*j;
+ make4_unsigned(qs, bits.values);
+ sign_2_values(signs+0, q8_quants+0);
+ sign_2_values(signs+4, q8_quants+2);
+ }
+
+ constexpr static int minv = 64;
+
+ SimpleBits bits;
+ Scales8KBase scb;
+ EvenSignHelper esh;
+ const __m256i min_value = _mm256_set1_epi8(minv);
+
+};
+
+#ifdef HAVE_FANCY_SIMD
+// Strangely enough, the following implementation makes PP ~6% slower and TG ~6% faster
+// compared to the vanilla AVX2 version below.
+struct IndexHelperIQ3S {
+ union index_t {
+ __m256i vec;
+ uint16_t val[16];
+ };
+ inline void make2(const uint8_t * qs, const uint8_t * qh, __m256i * values) const {
+ auto idx_l = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)qs));
+ const __mmask16 * m16 = (const __mmask16 *)qh;
+ index_t idx;
+ idx.vec = _mm256_mask_add_epi16(idx_l, m16[0], idx_l, offset);
+ values[0] = _mm256_set_epi32(iq3s_grid[idx.val[ 7]], iq3s_grid[idx.val[ 6]], iq3s_grid[idx.val[ 5]], iq3s_grid[idx.val[ 4]],
+ iq3s_grid[idx.val[ 3]], iq3s_grid[idx.val[ 2]], iq3s_grid[idx.val[ 1]], iq3s_grid[idx.val[ 0]]);
+ values[1] = _mm256_set_epi32(iq3s_grid[idx.val[15]], iq3s_grid[idx.val[14]], iq3s_grid[idx.val[13]], iq3s_grid[idx.val[12]],
+ iq3s_grid[idx.val[11]], iq3s_grid[idx.val[10]], iq3s_grid[idx.val[ 9]], iq3s_grid[idx.val[ 8]]);
+ }
+ const __m256i offset = _mm256_set1_epi16(256);
+};
+#else
+struct IndexHelperIQ3S {
+ union index_t {
+ __m256i vec;
+ uint32_t val[8];
+ };
+ inline void make2(const uint8_t * qs, const uint8_t * qh, __m256i * values) const {
+ index_t idx;
+ auto idx_l = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i *)qs));
+ auto idx_h = _mm256_and_si256(_mm256_sllv_epi32(_mm256_set1_epi32(qh[0]), idx_shift), idx_mask);
+ idx.vec = _mm256_or_si256(idx_h, idx_l);
+ values[0] = _mm256_set_epi32(iq3s_grid[idx.val[7]], iq3s_grid[idx.val[6]], iq3s_grid[idx.val[5]], iq3s_grid[idx.val[4]],
+ iq3s_grid[idx.val[3]], iq3s_grid[idx.val[2]], iq3s_grid[idx.val[1]], iq3s_grid[idx.val[0]]);
+ idx_l = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i *)(qs+8)));
+ idx_h = _mm256_and_si256(_mm256_sllv_epi32(_mm256_set1_epi32(qh[1]), idx_shift), idx_mask);
+ idx.vec = _mm256_or_si256(idx_h, idx_l);
+ values[1] = _mm256_set_epi32(iq3s_grid[idx.val[7]], iq3s_grid[idx.val[6]], iq3s_grid[idx.val[5]], iq3s_grid[idx.val[4]],
+ iq3s_grid[idx.val[3]], iq3s_grid[idx.val[2]], iq3s_grid[idx.val[1]], iq3s_grid[idx.val[0]]);
+ }
+ const __m256i idx_mask = _mm256_set1_epi32(256);
+ const __m256i idx_shift = _mm256_set_epi32(1, 2, 3, 4, 5, 6, 7, 8);
+};
+#endif
+
+struct DequantizerIQ3S final : public BaseDequantizer<block_iq3_s> {
+ DequantizerIQ3S(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}
+
+ constexpr static int num_blocks = 8;
+
+ inline __m128i make_scales(int i, float& dd) const {
+ dd = GGML_FP16_TO_FP32(x[i].d);
+ uint32_t aux32[2];
+ std::memcpy(aux32, x[i].scales, 4);
+ aux32[1] = (aux32[0] >> 4) & 0x0f0f0f0f;
+ aux32[0] &= 0x0f0f0f0f;
+ auto scales8 = _mm_shuffle_epi8(_mm_loadl_epi64((const __m128i *)aux32), _mm_set1_epi64x(0x0703060205010400));
+ auto scales16 = _mm256_castsi256_si128(_mm256_cvtepi8_epi16(scales8));
+ return _mm_or_si128(_mm_slli_epi16(scales16, 1), _mm_set1_epi16(1));
+ }
+ inline void new_block(int i, __m256i * scales) {
+ auto scales16 = make_scales(i, d);
+ scales[0] = MM256_SET_M128I(scales16, scales16);
+ }
+ inline float new_block(int i, __m256i * scales, __m256i& mins) {
+ auto scales16 = make_scales(i, d);
+ mins = scb.shuffle(scales16);
+ scales[0] = MM256_SET_M128I(scales16, scales16);
+ return -minv*d;
+ }
+
+ inline void prepare(int i, int j) {
+ prepare_unsigned(i, j);
+ sh.sign_4_values((const uint16_t *)x[i].signs + 8*j, bits.values);
+ for (int k = 0; k < 4; ++k) bits.values[k] = _mm256_add_epi8(bits.values[k], min_value);
+ }
+ inline void prepare(int i, int j, const Q8<1>& q8, __m256i * q8_quants) {
+ prepare_unsigned(i, j);
+ for (int k = 0; k < 4; ++k) q8_quants[k] = q8.load_quants(0, i, 4*j+k);
+ sh.sign_4_values((const uint16_t *)x[i].signs + 8*j, q8_quants);
+ }
+
+ inline void prepare_unsigned(int i, int j) {
+ auto qs = x[i].qs + 32*j;
+ auto qh = x[i].qh + 4*j;
+ helper.make2(qs+ 0, qh+0, bits.values+0);
+ helper.make2(qs+16, qh+2, bits.values+2);
+ }
+
+ constexpr static int minv = 16;
+
+ SimpleBits bits;
+ SignHelper sh;
+ Scales8KBase scb;
+ IndexHelperIQ3S helper;
+ const __m256i min_value = _mm256_set1_epi8(minv);
+
+};
+
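+// multiply_add_1: dot-product accumulation for the single-column (nrc_y == 1) path. With HAVE_FANCY_SIMD the
+// unsigned x signed products go through _mm256_dpbusd_epi32, the 32-bit partial sums are packed back to 16 bits,
+// and the block scales are applied with _mm256_dpwssd_epi32; the plain AVX2 path uses maddubs + madd with the
+// pre-broadcast scales. j == 0 initializes the two accumulators, otherwise they are accumulated into.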
+template <typename Bits>
+inline void multiply_add_1(int j, const Bits& bits, const __m256i * scales, const __m256i * q8, __m256i * sumi) {
+ if (j == 0) {
+#ifdef HAVE_FANCY_SIMD
+ auto p1 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), bits.values[0], q8[0]);
+ auto p2 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), bits.values[1], q8[1]);
+ auto p3 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), bits.values[2], q8[2]);
+ auto p4 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), bits.values[3], q8[3]);
+ sumi[0] = _mm256_dpwssd_epi32(_mm256_setzero_si256(), scales[0], _mm256_packs_epi32(p1, p2));
+ sumi[1] = _mm256_dpwssd_epi32(_mm256_setzero_si256(), scales[1], _mm256_packs_epi32(p3, p4));
+#else
+ const __m256i p1 = _mm256_madd_epi16(scales[0], _mm256_maddubs_epi16(bits.values[0], q8[0]));
+ const __m256i p2 = _mm256_madd_epi16(scales[1], _mm256_maddubs_epi16(bits.values[1], q8[1]));
+ const __m256i p3 = _mm256_madd_epi16(scales[2], _mm256_maddubs_epi16(bits.values[2], q8[2]));
+ const __m256i p4 = _mm256_madd_epi16(scales[3], _mm256_maddubs_epi16(bits.values[3], q8[3]));
+ sumi[0] = _mm256_add_epi32(p1, p3);
+ sumi[1] = _mm256_add_epi32(p2, p4);
+#endif
+ } else {
+#ifdef HAVE_FANCY_SIMD
+ auto p1 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), bits.values[0], q8[0]);
+ auto p2 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), bits.values[1], q8[1]);
+ auto p3 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), bits.values[2], q8[2]);
+ auto p4 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), bits.values[3], q8[3]);
+ sumi[0] = _mm256_dpwssd_epi32(sumi[0], scales[0], _mm256_packs_epi32(p1, p2));
+ sumi[1] = _mm256_dpwssd_epi32(sumi[1], scales[1], _mm256_packs_epi32(p3, p4));
+#else
+ const __m256i p1 = _mm256_madd_epi16(scales[0], _mm256_maddubs_epi16(bits.values[0], q8[0]));
+ const __m256i p2 = _mm256_madd_epi16(scales[1], _mm256_maddubs_epi16(bits.values[1], q8[1]));
+ const __m256i p3 = _mm256_madd_epi16(scales[2], _mm256_maddubs_epi16(bits.values[2], q8[2]));
+ const __m256i p4 = _mm256_madd_epi16(scales[3], _mm256_maddubs_epi16(bits.values[3], q8[3]));
+ sumi[0] = _mm256_add_epi32(sumi[0], _mm256_add_epi32(p1, p3));
+ sumi[1] = _mm256_add_epi32(sumi[1], _mm256_add_epi32(p2, p4));
+#endif
+ }
+}
+
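+// mul_mat_qX_K_q8_K_IQ_1: single right-hand-side column path. The dequantizers' prepare(i, j, q8, q8_quants)
+// overload applies the stored signs to the q8 quants instead of the x values, so the x side can stay unsigned for
+// dpbusd/maddubs.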
+template <typename Dequantizer>
+static void mul_mat_qX_K_q8_K_IQ_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ const int nb = n / QK_K;
+ Q8<1> q8(info);
+ Dequantizer deq(vx, bx);
+ __m256i scales[2];
+ __m256i q8_quants[4];
+ for (int ix = 0; ix < nrc_x; ++ix) {
+
+ __m256 accd = _mm256_setzero_ps();
+ deq.new_row(ix);
+
+ for (int i = 0; i < nb; ++i) {
+
+ __m256i sumi[2], all_scales[Dequantizer::num_blocks/8];
+ deq.new_block(i, all_scales);
+
+ for (int j = 0; j < QK_K/128; ++j) {
+ deq.prepare(i, j, q8, q8_quants);
+ if constexpr (Dequantizer::num_blocks == 8) {
+ set_scales_8_iq(j, all_scales[0], scales);
+ } else {
+ set_scales_16_iq(all_scales[j], scales);
+ }
+ multiply_add_1(j, deq.bits, scales, q8_quants, sumi);
+ }
+ accd = _mm256_fmadd_ps(_mm256_set1_ps(deq.d*q8.scale(0, i)), _mm256_cvtepi32_ps(_mm256_add_epi32(sumi[0], sumi[1])), accd);
+ }
+
+ info.store(ix, 0, hsum_float_8(accd));
+ }
+}
+
+// So, if I uncomment this function and the call to it in mul_mat_qX_K_q8_K_IQ_N() below,
+// PP performance improves by ~2-3% (when we have __AVX512VNNI__ and __AVX512VL__).
+// But TG performance for iq3_xs drops by 35%. Seriously? I mean, c'mon,
+// what does the compilation of mul_mat_qX_K_q8_K_IQ_1 (which gets invoked during TG)
+// have to do with the compilation of mul_mat_qX_K_q8_K_IQ_N (invoked during PP)?
+//template <typename Q8, typename Bits>
+//inline void multiply_add_iq(const Bits& bits, const __m256i * scales, int j, int i, const Q8& q8, __m256i * sumi) {
+//#if defined(__AVX512VNNI__) && defined(__AVX512VL__)
+// for (int iy = 0; iy < Q8::nrc_y; ++iy) {
+// sumi[iy] = _mm256_dpwssd_epi32(sumi[iy], scales[0], _mm256_maddubs_epi16(bits.values[0], q8.load_quants(iy, i, 4*j+0)));
+// sumi[iy] = _mm256_dpwssd_epi32(sumi[iy], scales[1], _mm256_maddubs_epi16(bits.values[1], q8.load_quants(iy, i, 4*j+1)));
+// sumi[iy] = _mm256_dpwssd_epi32(sumi[iy], scales[2], _mm256_maddubs_epi16(bits.values[2], q8.load_quants(iy, i, 4*j+2)));
+// sumi[iy] = _mm256_dpwssd_epi32(sumi[iy], scales[3], _mm256_maddubs_epi16(bits.values[3], q8.load_quants(iy, i, 4*j+3)));
+// }
+//#else
+// for (int iy = 0; iy < Q8::nrc_y; ++iy) {
+// const __m256i p1 = _mm256_madd_epi16(scales[0], _mm256_maddubs_epi16(bits.values[0], q8.load_quants(iy, i, 4*j+0)));
+// const __m256i p2 = _mm256_madd_epi16(scales[1], _mm256_maddubs_epi16(bits.values[1], q8.load_quants(iy, i, 4*j+1)));
+// const __m256i p3 = _mm256_madd_epi16(scales[2], _mm256_maddubs_epi16(bits.values[2], q8.load_quants(iy, i, 4*j+2)));
+// const __m256i p4 = _mm256_madd_epi16(scales[3], _mm256_maddubs_epi16(bits.values[3], q8.load_quants(iy, i, 4*j+3)));
+// sumi[iy] = _mm256_add_epi32(sumi[iy], _mm256_add_epi32(p1, p3));
+// sumi[iy] = _mm256_add_epi32(sumi[iy], _mm256_add_epi32(p2, p4));
+// }
+//#endif
+//}
+
+template <typename Dequantizer, int nrc_y>
+static void mul_mat_qX_K_q8_K_IQ_N(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ const int nb = n / QK_K;
+ Q8<nrc_y> q8(info);
+ Dequantizer deq(vx, bx);
+ __m256i scales[4];
+ __m256 accd[nrc_y];
+
+ for (int ix = 0; ix < nrc_x; ++ix) {
+
+ for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm256_setzero_ps();
+
+ deq.new_row(ix);
+
+ for (int i = 0; i < nb; ++i) {
+
+ __m256i sumi[nrc_y], all_scales[Dequantizer::num_blocks/8];
+ //for (int iy = 0; iy < nrc_y; ++iy) sumi[iy] = _mm256_setzero_si256();
+ __m256i mins;
+ float dmin = deq.new_block(i, all_scales, mins);
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto bsums = q8.load_bsums(iy, i);
+ auto prod = _mm256_madd_epi16(mins, bsums);
+ accd[iy] = _mm256_fmadd_ps(_mm256_set1_ps(dmin*q8.scale(iy, i)), _mm256_cvtepi32_ps(prod), accd[iy]);
+ }
+
+ for (int j = 0; j < QK_K/128; ++j) {
+ deq.prepare(i, j);
+ if constexpr (Dequantizer::num_blocks == 8) {
+ set_scales_8(all_scales[0], j, scales);
+ } else {
+ set_scales_16(all_scales[j], scales);
+ }
+ //multiply_add_iq(deq.bits, scales, j, i, q8, sumi);
+ multiply_add(deq.bits, scales, j, i, q8, sumi);
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ const __m256 vd = _mm256_set1_ps(deq.d*q8.scale(iy, i));
+ accd[iy] = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(sumi[iy]), accd[iy]);
+ }
+ }
+
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ info.store(ix, iy, hsum_float_8(accd[iy]));
+ }
+ }
+}
+
+template <typename Dequantizer, int nrc_y>
+static void mul_mat_qX_K_q8_K_IQ(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ assert(n % QK_K == 0);
+#ifdef HAVE_FANCY_SIMD
+ if constexpr (nrc_y == 1) {
+ mul_mat_qX_K_q8_K_IQ_1<Dequantizer>(n, vx, bx, info, nrc_x);
+ } else {
+ mul_mat_qX_K_q8_K_IQ_N<Dequantizer, nrc_y>(n, vx, bx, info, nrc_x);
+ }
+#else
+ mul_mat_qX_K_q8_K_IQ_N<Dequantizer, nrc_y>(n, vx, bx, info, nrc_x);
+#endif
+}
+
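+// mul_mat_iq2_xxs_r4_q8_k: row-interleaved (r4) iq2_xxs kernel. The 4-bit block scale is spread over the low bit
+// of four consecutive sas bytes (reassembled here as 2*s+1), the upper 7 bits of each sas byte carry the sign data,
+// and the common 0.125 factor of the iq2_xxs scale is applied once when the four row sums are stored.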
+template <int nrc_y>
+static void mul_mat_iq2_xxs_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ GGML_ASSERT(nrc_x%4 == 0);
+ Q8<nrc_y, block_q8_K> q8(info);
+ int nbl = n / QK_K;
+#ifndef HAVE_FANCY_SIMD
+ auto smask = _mm256_set1_epi64x(0x8040201008040201);
+ auto sign_shuffle = _mm256_set_epi64x(0x0303030303030303, 0x0202020202020202, 0x0101010101010101, 0x0000000000000000);
+ auto m4 = _mm256_set1_epi8(4);
+ auto m1 = _mm256_set1_epi16(1);
+#endif
+ __m256 acc[nrc_y] = {};
+ __m256i isum[nrc_y] = {};
+ __m256i qx[4];
+ for (int ix = 0; ix < nrc_x; ix += 4) {
+ auto iq2 = (const block_iq2_xxs_r4 *)((const char *)vx + (ix+0)*bx);
+ for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256
+ auto dl = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq2[ibl].d));
+ auto d4 = _mm256_set_m128(dl, dl);
+ auto qs = iq2[ibl].qs;
+ for (int ib = 0; ib < QK_K/32; ++ib) {
+ qx[0] = _mm256_set_epi64x(iq2xxs_grid[qs[ 3]], iq2xxs_grid[qs[ 2]], iq2xxs_grid[qs[ 1]], iq2xxs_grid[qs[ 0]]);
+ qx[1] = _mm256_set_epi64x(iq2xxs_grid[qs[ 7]], iq2xxs_grid[qs[ 6]], iq2xxs_grid[qs[ 5]], iq2xxs_grid[qs[ 4]]);
+ qx[2] = _mm256_set_epi64x(iq2xxs_grid[qs[11]], iq2xxs_grid[qs[10]], iq2xxs_grid[qs[ 9]], iq2xxs_grid[qs[ 8]]);
+ qx[3] = _mm256_set_epi64x(iq2xxs_grid[qs[15]], iq2xxs_grid[qs[14]], iq2xxs_grid[qs[13]], iq2xxs_grid[qs[12]]);
+ qs += 16;
+ auto sas = _mm_loadu_si128((const __m128i *)iq2[ibl].sas + ib);
+ auto scales = _mm_and_si128(sas, _mm_set1_epi8(1));
+#ifdef HAVE_FANCY_SIMD
+ scales = _mm_dpbusd_epi32(_mm_set1_epi32(1), scales, _mm_set1_epi32(0x10080402));
+#else
+ scales = _mm_maddubs_epi16(scales, _mm_set1_epi32(0x10080402));
+ scales = _mm_add_epi32(_mm_madd_epi16(_mm_set1_epi16(1), scales), _mm_set1_epi32(1));
+#endif
+ auto scales32 = MM256_SET_M128I(scales, scales);
+ auto signs128 = _mm_and_si128(sas, _mm_set1_epi8(-2)); // 0xfe = -2 as signed. Needed to shut up a compiler warning.
+ signs128 = _mm_xor_si128(signs128, _mm_srli_epi16(signs128, 1));
+#ifdef HAVE_FANCY_SIMD
+ auto mask = (const __mmask32 *)&signs128;
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto y = _mm256_loadu_si256((const __m256i *)q8.y[iy][ibl].qs + ib);
+ auto sumi1 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), qx[0], _mm256_mask_sub_epi8(y, mask[0], _mm256_setzero_si256(), y));
+ auto sumi2 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), qx[1], _mm256_mask_sub_epi8(y, mask[1], _mm256_setzero_si256(), y));
+ auto sumi3 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), qx[2], _mm256_mask_sub_epi8(y, mask[2], _mm256_setzero_si256(), y));
+ auto sumi4 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), qx[3], _mm256_mask_sub_epi8(y, mask[3], _mm256_setzero_si256(), y));
+ auto s12 = _mm256_add_epi32(_mm256_unpacklo_epi32(sumi1, sumi2), _mm256_unpackhi_epi32(sumi1, sumi2)); // 0,1, 0,1, 0,1, 0,1
+ auto s34 = _mm256_add_epi32(_mm256_unpacklo_epi32(sumi3, sumi4), _mm256_unpackhi_epi32(sumi3, sumi4)); // 2,3, 2,3, 2,3, 2,3
+ auto sumi = _mm256_add_epi32(_mm256_unpacklo_epi64(s12, s34), _mm256_unpackhi_epi64(s12, s34)); // 0,1,2,3, 0,1,2,3
+ isum[iy] = _mm256_add_epi32(isum[iy], _mm256_mullo_epi32(scales32, sumi));
+ }
+#else
+ auto signs = MM256_SET_M128I(signs128, signs128);
+ auto shuffle = sign_shuffle;
+ auto s1 = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(signs, shuffle), smask), smask), _mm256_set1_epi8(1));
+ shuffle = _mm256_add_epi8(shuffle, m4);
+ auto s2 = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(signs, shuffle), smask), smask), _mm256_set1_epi8(1));
+ shuffle = _mm256_add_epi8(shuffle, m4);
+ auto s3 = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(signs, shuffle), smask), smask), _mm256_set1_epi8(1));
+ shuffle = _mm256_add_epi8(shuffle, m4);
+ auto s4 = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(signs, shuffle), smask), smask), _mm256_set1_epi8(1));
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto y = _mm256_loadu_si256((const __m256i *)q8.y[iy][ibl].qs + ib);
+ auto sumi1 = _mm256_madd_epi16(m1, _mm256_maddubs_epi16(qx[0], _mm256_sign_epi8(y, s1)));
+ auto sumi2 = _mm256_madd_epi16(m1, _mm256_maddubs_epi16(qx[1], _mm256_sign_epi8(y, s2)));
+ auto sumi3 = _mm256_madd_epi16(m1, _mm256_maddubs_epi16(qx[2], _mm256_sign_epi8(y, s3)));
+ auto sumi4 = _mm256_madd_epi16(m1, _mm256_maddubs_epi16(qx[3], _mm256_sign_epi8(y, s4)));
+ auto s12 = _mm256_add_epi32(_mm256_unpacklo_epi32(sumi1, sumi2), _mm256_unpackhi_epi32(sumi1, sumi2)); // 0,1, 0,1, 0,1, 0,1
+ auto s34 = _mm256_add_epi32(_mm256_unpacklo_epi32(sumi3, sumi4), _mm256_unpackhi_epi32(sumi3, sumi4)); // 2,3, 2,3, 2,3, 2,3
+ auto sumi = _mm256_add_epi32(_mm256_unpacklo_epi64(s12, s34), _mm256_unpackhi_epi64(s12, s34)); // 0,1,2,3, 0,1,2,3
+ isum[iy] = _mm256_add_epi32(isum[iy], _mm256_mullo_epi32(scales32, sumi));
+ }
+#endif
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(isum[iy]), acc[iy]);
+ isum[iy] = _mm256_setzero_si256();
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1));
+ info.store(ix, iy, _mm_mul_ps(_mm_set1_ps(0.125f), sum));
+ acc[iy] = _mm256_setzero_ps();
+ }
+ }
+}
+
+template <int nrc_y>
+static void mul_mat_iq2_xs_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ GGML_ASSERT(nrc_x%4 == 0);
+ Q8<nrc_y, block_q8_K> q8(info);
+ int nbl = n / QK_K;
+#ifndef HAVE_FANCY_SIMD
+ auto smask = _mm256_set1_epi64x(0x8040201008040201);
+ auto sign_shuffle = _mm256_set_epi64x(0x0303030303030303, 0x0202020202020202, 0x0101010101010101, 0x0000000000000000);
+ auto m4 = _mm256_set1_epi8(4);
+#endif
+ __m256 acc[nrc_y] = {};
+#ifdef HAVE_FANCY_SIMD
+ __m256i shuffles[2] = {
+ _mm256_set_epi64x(0x0706070607060706, 0x0302030203020302, 0x0504050405040504, 0x0100010001000100),
+ _mm256_set_epi64x(0x0f0e0f0e0f0e0f0e, 0x0b0a0b0a0b0a0b0a, 0x0d0c0d0c0d0c0d0c, 0x0908090809080908)
+ };
+ __m256i isum[2*nrc_y] = {};
+#else
+ __m256i shuffles[4] = {
+ MM256_SET_M128I(_mm_set1_epi16(0x0302), _mm_set1_epi16(0x0100)),
+ MM256_SET_M128I(_mm_set1_epi16(0x0706), _mm_set1_epi16(0x0504)),
+ MM256_SET_M128I(_mm_set1_epi16(0x0b0a), _mm_set1_epi16(0x0908)),
+ MM256_SET_M128I(_mm_set1_epi16(0x0f0e), _mm_set1_epi16(0x0d0c)),
+ };
+ __m256i isum[nrc_y == 1 ? 4 : nrc_y] = {};
+#endif
+ auto s_shuffle = _mm_set_epi64x(0x0f0d0b0907050301, 0x0e0c0a0806040200);
+ __m256i qx[4];
+ union { __m256i vec; uint16_t val[16]; } helper;
+ for (int ix = 0; ix < nrc_x; ix += 4) {
+ auto iq2 = (const block_iq2_xs_r4 *)((const char *)vx + (ix+0)*bx);
+ for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256
+ auto dl = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq2[ibl].d));
+ auto d4 = _mm256_set_m128(dl, dl);
+ auto s32 = (const uint32_t *)iq2[ibl].scales;
+ for (int ib = 0; ib < QK_K/32; ++ib) {
+ auto val = _mm256_loadu_si256((const __m256i *)iq2[ibl].qs + ib);
+ helper.vec = _mm256_and_si256(val, _mm256_set1_epi16(511));
+ qx[0] = _mm256_set_epi64x(iq2xs_grid[helper.val[ 3]], iq2xs_grid[helper.val[ 2]], iq2xs_grid[helper.val[ 1]], iq2xs_grid[helper.val[ 0]]);
+ qx[1] = _mm256_set_epi64x(iq2xs_grid[helper.val[ 7]], iq2xs_grid[helper.val[ 6]], iq2xs_grid[helper.val[ 5]], iq2xs_grid[helper.val[ 4]]);
+ qx[2] = _mm256_set_epi64x(iq2xs_grid[helper.val[11]], iq2xs_grid[helper.val[10]], iq2xs_grid[helper.val[ 9]], iq2xs_grid[helper.val[ 8]]);
+ qx[3] = _mm256_set_epi64x(iq2xs_grid[helper.val[15]], iq2xs_grid[helper.val[14]], iq2xs_grid[helper.val[13]], iq2xs_grid[helper.val[12]]);
+ auto signs16 = _mm256_srli_epi16(val, 9);
+ signs16 = _mm256_xor_si256(signs16, _mm256_slli_epi16(signs16, 1));
+ auto signs128 = _mm_or_si128(_mm256_castsi256_si128(signs16), _mm_slli_epi16(_mm256_extracti128_si256(signs16, 1), 8));
+ signs128 = _mm_shuffle_epi8(signs128, s_shuffle);
+ auto scales = _mm_set1_epi32(s32[ib]);
+ scales = _mm_and_si128(_mm_unpacklo_epi8(scales, _mm_srli_epi16(scales, 4)), _mm_set1_epi8(0xf));
+ scales = _mm_or_si128(_mm_slli_epi16(scales, 1), _mm_set1_epi8(1));
+ auto scales16 = _mm256_cvtepi8_epi16(scales); // 0...7, 0...7
+#ifdef HAVE_FANCY_SIMD
+ __m256i scs[2] = { _mm256_shuffle_epi8(scales16, shuffles[0]), _mm256_shuffle_epi8(scales16, shuffles[1]) };
+ auto mask = (const __mmask32 *)&signs128;
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto y = _mm256_loadu_si256((const __m256i *)q8.y[iy][ibl].qs + ib);
+ auto sumi1 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), qx[0], _mm256_mask_sub_epi8(y, mask[0], _mm256_setzero_si256(), y)); // blocks: 0,0,0,0, 1,1,1,1, row 0
+ auto sumi2 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), qx[1], _mm256_mask_sub_epi8(y, mask[1], _mm256_setzero_si256(), y)); // blocks: 2,2,2,2, 3,3,3,3, row 1
+ auto sumi3 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), qx[2], _mm256_mask_sub_epi8(y, mask[2], _mm256_setzero_si256(), y)); // blocks: 4,4,4,4, 5,5,5,5, row 2
+ auto sumi4 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), qx[3], _mm256_mask_sub_epi8(y, mask[3], _mm256_setzero_si256(), y)); // blocks: 6,6,6,6, 7,7,7,7, row 3
+ auto s12 = _mm256_packs_epi32(sumi1, sumi2); // 0,0,0,0, 2,2,2,2, 1,1,1,1, 3,3,3,3
+ auto s34 = _mm256_packs_epi32(sumi3, sumi4); // 4,4,4,4, 6,6,6,6, 5,5,5,5, 7,7,7,7
+ isum[2*iy+0] = _mm256_add_epi32(isum[2*iy+0], _mm256_madd_epi16(scs[0], s12));
+ isum[2*iy+1] = _mm256_add_epi32(isum[2*iy+1], _mm256_madd_epi16(scs[1], s34));
+ }
+#else
+ auto signs = MM256_SET_M128I(signs128, signs128);
+ auto shuffle = sign_shuffle;
+ auto s1 = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(signs, shuffle), smask), smask), _mm256_set1_epi8(1));
+ shuffle = _mm256_add_epi8(shuffle, m4);
+ auto s2 = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(signs, shuffle), smask), smask), _mm256_set1_epi8(1));
+ shuffle = _mm256_add_epi8(shuffle, m4);
+ auto s3 = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(signs, shuffle), smask), smask), _mm256_set1_epi8(1));
+ shuffle = _mm256_add_epi8(shuffle, m4);
+ auto s4 = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(signs, shuffle), smask), smask), _mm256_set1_epi8(1));
+ __m256i scs[4] = {
+ _mm256_shuffle_epi8(scales16, shuffles[0]), _mm256_shuffle_epi8(scales16, shuffles[1]),
+ _mm256_shuffle_epi8(scales16, shuffles[2]), _mm256_shuffle_epi8(scales16, shuffles[3]),
+ };
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto y = _mm256_loadu_si256((const __m256i *)q8.y[iy][ibl].qs + ib);
+ if constexpr (nrc_y == 1) {
+ isum[0] = _mm256_add_epi32(isum[0], _mm256_madd_epi16(scs[0], _mm256_maddubs_epi16(qx[0], _mm256_sign_epi8(y, s1))));
+ isum[1] = _mm256_add_epi32(isum[1], _mm256_madd_epi16(scs[1], _mm256_maddubs_epi16(qx[1], _mm256_sign_epi8(y, s2))));
+ isum[2] = _mm256_add_epi32(isum[2], _mm256_madd_epi16(scs[2], _mm256_maddubs_epi16(qx[2], _mm256_sign_epi8(y, s3))));
+ isum[3] = _mm256_add_epi32(isum[3], _mm256_madd_epi16(scs[3], _mm256_maddubs_epi16(qx[3], _mm256_sign_epi8(y, s4))));
+ } else {
+ auto sumi1 = _mm256_madd_epi16(scs[0], _mm256_maddubs_epi16(qx[0], _mm256_sign_epi8(y, s1))); // blocks 4x0, 4x1, row 0
+ auto sumi2 = _mm256_madd_epi16(scs[1], _mm256_maddubs_epi16(qx[1], _mm256_sign_epi8(y, s2))); // blocks 4x2, 4x3, row 1
+ auto sumi3 = _mm256_madd_epi16(scs[2], _mm256_maddubs_epi16(qx[2], _mm256_sign_epi8(y, s3))); // blocks 4x4, 4x5, row 2
+ auto sumi4 = _mm256_madd_epi16(scs[3], _mm256_maddubs_epi16(qx[3], _mm256_sign_epi8(y, s4))); // blocks 4x6, 4x7, row 3
+ auto s12 = _mm256_add_epi32(_mm256_unpacklo_epi32(sumi1, sumi2), _mm256_unpackhi_epi32(sumi1, sumi2)); // 0,1, 0,1, 0,1, 0,1
+ auto s34 = _mm256_add_epi32(_mm256_unpacklo_epi32(sumi3, sumi4), _mm256_unpackhi_epi32(sumi3, sumi4)); // 2,3, 2,3, 2,3, 2,3
+ auto sumi = _mm256_add_epi32(_mm256_unpacklo_epi64(s12, s34), _mm256_unpackhi_epi64(s12, s34)); // 0,1,2,3, 0,1,2,3
+ isum[iy] = _mm256_add_epi32(isum[iy], sumi);
+ }
+ }
+#endif
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+#ifdef HAVE_FANCY_SIMD
+ auto sumi = _mm256_hadd_epi32(isum[2*iy+0], isum[2*iy+1]);
+ acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(sumi), acc[iy]);
+ isum[2*iy+0] = isum[2*iy+1] = _mm256_setzero_si256();
+#else
+ if constexpr (nrc_y == 1) {
+ auto s12 = _mm256_add_epi32(_mm256_unpacklo_epi32(isum[0], isum[1]), _mm256_unpackhi_epi32(isum[0], isum[1]));
+ auto s34 = _mm256_add_epi32(_mm256_unpacklo_epi32(isum[2], isum[3]), _mm256_unpackhi_epi32(isum[2], isum[3]));
+ auto sumi = _mm256_add_epi32(_mm256_unpacklo_epi64(s12, s34), _mm256_unpackhi_epi64(s12, s34));
+ acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(sumi), acc[iy]);
+ isum[0] = isum[1] = isum[2] = isum[3] = _mm256_setzero_si256();
+ } else {
+ acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(isum[iy]), acc[iy]);
+ isum[iy] = _mm256_setzero_si256();
+ }
+#endif
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1));
+ info.store(ix, iy, _mm_mul_ps(_mm_set1_ps(0.125f), sum));
+ acc[iy] = _mm256_setzero_ps();
+ }
+ }
+}
+
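+// mul_mat_iq2_xs_r4_q8_k_16: nrc_y = 16 specialization of the kernel above. It additionally folds a constant
+// per-super-block term, -64 * (block scales dotted with the q8 block sums), into the accumulators before the loop
+// over the 32-quant blocks.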
+static void mul_mat_iq2_xs_r4_q8_k_16(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ GGML_ASSERT(nrc_x%4 == 0);
+ constexpr int nrc_y = 16;
+ Q8<nrc_y, block_q8_K> q8(info);
+ int nbl = n / QK_K;
+#ifndef HAVE_FANCY_SIMD
+ auto smask = _mm256_set1_epi64x(0x8040201008040201);
+ auto sign_shuffle = _mm256_set_epi64x(0x0303030303030303, 0x0202020202020202, 0x0101010101010101, 0x0000000000000000);
+ auto m4 = _mm256_set1_epi8(4);
+#endif
+ __m256 acc[nrc_y] = {};
+#ifdef HAVE_FANCY_SIMD
+ __m256i shuffles[2] = {
+ _mm256_set_epi64x(0x0706070607060706, 0x0302030203020302, 0x0504050405040504, 0x0100010001000100),
+ _mm256_set_epi64x(0x0f0e0f0e0f0e0f0e, 0x0b0a0b0a0b0a0b0a, 0x0d0c0d0c0d0c0d0c, 0x0908090809080908)
+ };
+ __m256i isum[2*nrc_y] = {};
+#else
+ __m256i shuffles[4] = {
+ MM256_SET_M128I(_mm_set1_epi16(0x0302), _mm_set1_epi16(0x0100)),
+ MM256_SET_M128I(_mm_set1_epi16(0x0706), _mm_set1_epi16(0x0504)),
+ MM256_SET_M128I(_mm_set1_epi16(0x0b0a), _mm_set1_epi16(0x0908)),
+ MM256_SET_M128I(_mm_set1_epi16(0x0f0e), _mm_set1_epi16(0x0d0c)),
+ };
+ __m256i isum[nrc_y == 1 ? 4 : nrc_y] = {};
+#endif
+ auto s_shuffle = _mm_set_epi64x(0x0f0d0b0907050301, 0x0e0c0a0806040200);
+ __m256i qx[4];
+ union { __m256i vec; uint16_t val[16]; } helper;
+ for (int ix = 0; ix < nrc_x; ix += 4) {
+ auto iq2 = (const block_iq2_xs_r4 *)((const char *)vx + (ix+0)*bx);
+ for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256
+ auto dl = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq2[ibl].d));
+ auto d4 = _mm256_set_m128(dl, dl);
+ auto s32 = (const uint32_t *)iq2[ibl].scales;
+ {
+ auto scale_bits = _mm256_loadu_si256((const __m256i *)iq2[ibl].scales);
+ auto scales1 = _mm256_and_si256(scale_bits, _mm256_set1_epi8(0xf));
+ auto scales2 = _mm256_and_si256(_mm256_srli_epi16(scale_bits, 4), _mm256_set1_epi8(0xf));
+ scales1 = _mm256_or_si256(_mm256_slli_epi16(scales1, 1), _mm256_set1_epi8(1));
+ scales2 = _mm256_or_si256(_mm256_slli_epi16(scales2, 1), _mm256_set1_epi8(1));
+ auto s1_8 = _mm256_unpacklo_epi8(scales1, scales2); // blocks 0...15, 32...47 (0...3, 8...11 from each row)
+                auto s2_8 = _mm256_unpackhi_epi8(scales1, scales2); // blocks 16...31, 48...63 (4...7, 12...15 from each row)
+                auto s1_16 = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(s1_8)); // 0...15 (0...3 from each row)
+                auto s2_16 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(s1_8, 1)); // 32...47 (8...11 from each row)
+                auto s3_16 = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(s2_8)); // 16...31 (4...7 from each row)
+                auto s4_16 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(s2_8, 1)); // 48...63 (12...15 from each row)
+ auto t1 = MM256_SET_M128I(_mm256_castsi256_si128(s2_16), _mm256_castsi256_si128(s1_16)); // 0,1 and 8,9 from each row
+ auto t2 = MM256_SET_M128I(_mm256_extracti128_si256(s2_16, 1), _mm256_extracti128_si256(s1_16, 1)); // 2,3 and 10,11 from each row
+ auto t3 = MM256_SET_M128I(_mm256_castsi256_si128(s4_16), _mm256_castsi256_si128(s3_16)); // 4,5 and 12,13 from each row
+ auto t4 = MM256_SET_M128I(_mm256_extracti128_si256(s4_16, 1), _mm256_extracti128_si256(s3_16, 1)); // 6,7 and 14,15 from each row
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto bsums = q8.load_bsums(iy, ibl);
+ auto sumi = _mm256_setzero_si256();
+#ifdef HAVE_FANCY_SIMD
+ sumi = _mm256_dpwssd_epi32(sumi, t1, _mm256_shuffle_epi32(bsums, 0x00));
+ sumi = _mm256_dpwssd_epi32(sumi, t2, _mm256_shuffle_epi32(bsums, 0x55));
+ sumi = _mm256_dpwssd_epi32(sumi, t3, _mm256_shuffle_epi32(bsums, 0xaa));
+ sumi = _mm256_dpwssd_epi32(sumi, t4, _mm256_shuffle_epi32(bsums, 0xff));
+#else
+ sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(t1, _mm256_shuffle_epi32(bsums, 0x00)));
+ sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(t2, _mm256_shuffle_epi32(bsums, 0x55)));
+ sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(t3, _mm256_shuffle_epi32(bsums, 0xaa)));
+ sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(t4, _mm256_shuffle_epi32(bsums, 0xff)));
+#endif
+ acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(d4, _mm256_set1_ps(-64.f*q8.scale(iy, ibl))), _mm256_cvtepi32_ps(sumi), acc[iy]);
+ }
+ }
+ for (int ib = 0; ib < QK_K/32; ++ib) {
+ auto val = _mm256_loadu_si256((const __m256i *)iq2[ibl].qs + ib);
+ helper.vec = _mm256_and_si256(val, _mm256_set1_epi16(511));
+ qx[0] = _mm256_set_epi64x(iq2xs_grid[helper.val[ 3]], iq2xs_grid[helper.val[ 2]], iq2xs_grid[helper.val[ 1]], iq2xs_grid[helper.val[ 0]]);
+ qx[1] = _mm256_set_epi64x(iq2xs_grid[helper.val[ 7]], iq2xs_grid[helper.val[ 6]], iq2xs_grid[helper.val[ 5]], iq2xs_grid[helper.val[ 4]]);
+ qx[2] = _mm256_set_epi64x(iq2xs_grid[helper.val[11]], iq2xs_grid[helper.val[10]], iq2xs_grid[helper.val[ 9]], iq2xs_grid[helper.val[ 8]]);
+ qx[3] = _mm256_set_epi64x(iq2xs_grid[helper.val[15]], iq2xs_grid[helper.val[14]], iq2xs_grid[helper.val[13]], iq2xs_grid[helper.val[12]]);
+ auto signs16 = _mm256_srli_epi16(val, 9);
+ signs16 = _mm256_xor_si256(signs16, _mm256_slli_epi16(signs16, 1));
+ auto signs128 = _mm_or_si128(_mm256_castsi256_si128(signs16), _mm_slli_epi16(_mm256_extracti128_si256(signs16, 1), 8));
+ signs128 = _mm_shuffle_epi8(signs128, s_shuffle);
+ auto scales = _mm_set1_epi32(s32[ib]);
+ scales = _mm_and_si128(_mm_unpacklo_epi8(scales, _mm_srli_epi16(scales, 4)), _mm_set1_epi8(0xf));
+ scales = _mm_or_si128(_mm_slli_epi16(scales, 1), _mm_set1_epi8(1));
+ auto scales16 = _mm256_cvtepi8_epi16(scales); // 0...7, 0...7
+#ifdef HAVE_FANCY_SIMD
+ __m256i scs[2] = { _mm256_shuffle_epi8(scales16, shuffles[0]), _mm256_shuffle_epi8(scales16, shuffles[1]) };
+ auto mask = (const __mmask32 *)&signs128;
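+                // Conditionally negate the grid values (the 128 sign bits viewed as four 32-bit
+                // masks), then add 64 so the signed values become non-negative and can be used as
+                // the unsigned operand of _mm256_dpbusd_epi32; the bias is removed by the block-sum
+                // correction above.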
+ qx[0] = _mm256_add_epi8(_mm256_set1_epi8(64), _mm256_mask_sub_epi8(qx[0], mask[0], _mm256_setzero_si256(), qx[0]));
+ qx[1] = _mm256_add_epi8(_mm256_set1_epi8(64), _mm256_mask_sub_epi8(qx[1], mask[1], _mm256_setzero_si256(), qx[1]));
+ qx[2] = _mm256_add_epi8(_mm256_set1_epi8(64), _mm256_mask_sub_epi8(qx[2], mask[2], _mm256_setzero_si256(), qx[2]));
+ qx[3] = _mm256_add_epi8(_mm256_set1_epi8(64), _mm256_mask_sub_epi8(qx[3], mask[3], _mm256_setzero_si256(), qx[3]));
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto y = _mm256_loadu_si256((const __m256i *)q8.y[iy][ibl].qs + ib);
+ auto sumi1 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), qx[0], y); // blocks: 0,0,0,0, 1,1,1,1, row 0
+ auto sumi2 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), qx[1], y); // blocks: 2,2,2,2, 3,3,3,3, row 1
+ auto sumi3 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), qx[2], y); // blocks: 4,4,4,4, 5,5,5,5, row 2
+ auto sumi4 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), qx[3], y); // blocks: 6,6,6,6, 7,7,7,7, row 3
+ auto s12 = _mm256_packs_epi32(sumi1, sumi2); // 0,0,0,0, 2,2,2,2, 1,1,1,1, 3,3,3,3
+ auto s34 = _mm256_packs_epi32(sumi3, sumi4); // 4,4,4,4, 6,6,6,6, 5,5,5,5, 7,7,7,7
+ isum[2*iy+0] = _mm256_add_epi32(isum[2*iy+0], _mm256_madd_epi16(scs[0], s12));
+ isum[2*iy+1] = _mm256_add_epi32(isum[2*iy+1], _mm256_madd_epi16(scs[1], s34));
+ }
+#else
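+                // AVX2 fallback: broadcast each sign byte to 8 lanes (sign_shuffle), test the
+                // per-lane bit with smask, and turn the 0x00/0xff compare result into a +1/-1
+                // multiplier for _mm256_sign_epi8 by OR-ing it with 1.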
+ auto signs = MM256_SET_M128I(signs128, signs128);
+ auto shuffle = sign_shuffle;
+ auto s = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(signs, shuffle), smask), smask), _mm256_set1_epi8(1));
+ shuffle = _mm256_add_epi8(shuffle, m4);
+ qx[0] = _mm256_add_epi8(_mm256_set1_epi8(64), _mm256_sign_epi8(qx[0], s));
+ s = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(signs, shuffle), smask), smask), _mm256_set1_epi8(1));
+ shuffle = _mm256_add_epi8(shuffle, m4);
+ qx[1] = _mm256_add_epi8(_mm256_set1_epi8(64), _mm256_sign_epi8(qx[1], s));
+ s = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(signs, shuffle), smask), smask), _mm256_set1_epi8(1));
+ shuffle = _mm256_add_epi8(shuffle, m4);
+ qx[2] = _mm256_add_epi8(_mm256_set1_epi8(64), _mm256_sign_epi8(qx[2], s));
+ s = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(signs, shuffle), smask), smask), _mm256_set1_epi8(1));
+ qx[3] = _mm256_add_epi8(_mm256_set1_epi8(64), _mm256_sign_epi8(qx[3], s));
+ __m256i scs[4] = {
+ _mm256_shuffle_epi8(scales16, shuffles[0]), _mm256_shuffle_epi8(scales16, shuffles[1]),
+ _mm256_shuffle_epi8(scales16, shuffles[2]), _mm256_shuffle_epi8(scales16, shuffles[3]),
+ };
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto y = _mm256_loadu_si256((const __m256i *)q8.y[iy][ibl].qs + ib);
+ auto sumi1 = _mm256_madd_epi16(scs[0], _mm256_maddubs_epi16(qx[0], y)); // blocks 4x0, 4x1, row 0
+ auto sumi2 = _mm256_madd_epi16(scs[1], _mm256_maddubs_epi16(qx[1], y)); // blocks 4x2, 4x3, row 1
+ auto sumi3 = _mm256_madd_epi16(scs[2], _mm256_maddubs_epi16(qx[2], y)); // blocks 4x4, 4x5, row 2
+ auto sumi4 = _mm256_madd_epi16(scs[3], _mm256_maddubs_epi16(qx[3], y)); // blocks 4x6, 4x7, row 3
+ auto s12 = _mm256_add_epi32(_mm256_unpacklo_epi32(sumi1, sumi2), _mm256_unpackhi_epi32(sumi1, sumi2)); // 0,1, 0,1, 0,1, 0,1
+ auto s34 = _mm256_add_epi32(_mm256_unpacklo_epi32(sumi3, sumi4), _mm256_unpackhi_epi32(sumi3, sumi4)); // 2,3, 2,3, 2,3, 2,3
+ auto sumi = _mm256_add_epi32(_mm256_unpacklo_epi64(s12, s34), _mm256_unpackhi_epi64(s12, s34)); // 0,1,2,3, 0,1,2,3
+ isum[iy] = _mm256_add_epi32(isum[iy], sumi);
+ }
+#endif
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+#ifdef HAVE_FANCY_SIMD
+ auto sumi = _mm256_hadd_epi32(isum[2*iy+0], isum[2*iy+1]);
+ acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(sumi), acc[iy]);
+ isum[2*iy+0] = isum[2*iy+1] = _mm256_setzero_si256();
+#else
+ if constexpr (nrc_y == 1) {
+ auto s12 = _mm256_add_epi32(_mm256_unpacklo_epi32(isum[0], isum[1]), _mm256_unpackhi_epi32(isum[0], isum[1]));
+ auto s34 = _mm256_add_epi32(_mm256_unpacklo_epi32(isum[2], isum[3]), _mm256_unpackhi_epi32(isum[2], isum[3]));
+ auto sumi = _mm256_add_epi32(_mm256_unpacklo_epi64(s12, s34), _mm256_unpackhi_epi64(s12, s34));
+ acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(sumi), acc[iy]);
+ isum[0] = isum[1] = isum[2] = isum[3] = _mm256_setzero_si256();
+ } else {
+ acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(isum[iy]), acc[iy]);
+ isum[iy] = _mm256_setzero_si256();
+ }
+#endif
+ }
+ }
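+        // The 0.125f matches the reference iq2_xs dequantization, where the effective block scale
+        // is d * 0.25f * (0.5f + s): with the integer scales computed as 2*s+1 this becomes
+        // d * 0.125f * (2*s+1).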
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1));
+ info.store(ix, iy, _mm_mul_ps(_mm_set1_ps(0.125f), sum));
+ acc[iy] = _mm256_setzero_ps();
+ }
+ }
+}
+
+template <int nrc_y>
+static void mul_mat_iq2_s_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ GGML_ASSERT(nrc_x%4 == 0);
+ Q8<nrc_y, block_q8_K> q8(info);
+ int nbl = n / QK_K;
+#ifndef HAVE_FANCY_SIMD
+ auto smask = _mm256_set1_epi64x(0x8040201008040201);
+ auto sign_shuffle = _mm256_set_epi64x(0x0303030303030303, 0x0202020202020202, 0x0101010101010101, 0x0000000000000000);
+ auto m4 = _mm256_set1_epi8(4);
+#endif
+ __m256 acc[nrc_y] = {};
+#ifdef HAVE_FANCY_SIMD
+ __m256i shuffles[2] = {
+ _mm256_set_epi64x(0x0706070607060706, 0x0302030203020302, 0x0504050405040504, 0x0100010001000100),
+ _mm256_set_epi64x(0x0f0e0f0e0f0e0f0e, 0x0b0a0b0a0b0a0b0a, 0x0d0c0d0c0d0c0d0c, 0x0908090809080908)
+ };
+ __m256i isum[2*nrc_y] = {};
+#else
+ __m256i shuffles[4] = {
+ MM256_SET_M128I(_mm_set1_epi16(0x0302), _mm_set1_epi16(0x0100)),
+ MM256_SET_M128I(_mm_set1_epi16(0x0706), _mm_set1_epi16(0x0504)),
+ MM256_SET_M128I(_mm_set1_epi16(0x0b0a), _mm_set1_epi16(0x0908)),
+ MM256_SET_M128I(_mm_set1_epi16(0x0f0e), _mm_set1_epi16(0x0d0c)),
+ };
+ __m256i isum[nrc_y == 1 ? 4 : nrc_y] = {};
+#endif
+ __m256i qx[4];
+ auto grid = iq2s_grid;
+ for (int ix = 0; ix < nrc_x; ix += 4) {
+ auto iq2 = (const block_iq2_s_r4 *)((const char *)vx + (ix+0)*bx);
+ for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256
+ auto dl = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq2[ibl].d));
+ auto d4 = _mm256_set_m128(dl, dl);
+ auto s32 = (const uint32_t *)iq2[ibl].scales;
+ auto ql = iq2[ibl].qs;
+ auto qh = iq2[ibl].qh;
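+            // Each 8-bit ql value is extended with two high bits taken from the matching qh byte
+            // (shifted into bits 8..9) to form the 10-bit index into iq2s_grid.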
+ for (int ib = 0; ib < QK_K/32; ++ib) {
+ qx[0] = _mm256_set_epi64x(grid[ql[ 3] | ((qh[0] << 2) & 0x300)], grid[ql[ 2] | ((qh[0] << 4) & 0x300)], grid[ql[ 1] | ((qh[0] << 6) & 0x300)], grid[ql[ 0] | ((qh[0] << 8) & 0x300)]);
+ qx[1] = _mm256_set_epi64x(grid[ql[ 7] | ((qh[1] << 2) & 0x300)], grid[ql[ 6] | ((qh[1] << 4) & 0x300)], grid[ql[ 5] | ((qh[1] << 6) & 0x300)], grid[ql[ 4] | ((qh[1] << 8) & 0x300)]);
+ qx[2] = _mm256_set_epi64x(grid[ql[11] | ((qh[2] << 2) & 0x300)], grid[ql[10] | ((qh[2] << 4) & 0x300)], grid[ql[ 9] | ((qh[2] << 6) & 0x300)], grid[ql[ 8] | ((qh[2] << 8) & 0x300)]);
+ qx[3] = _mm256_set_epi64x(grid[ql[15] | ((qh[3] << 2) & 0x300)], grid[ql[14] | ((qh[3] << 4) & 0x300)], grid[ql[13] | ((qh[3] << 6) & 0x300)], grid[ql[12] | ((qh[3] << 8) & 0x300)]);
+ ql += 16; qh += 4;
+ auto signs128 = _mm_loadu_si128((const __m128i*)iq2[ibl].signs + ib);
+ auto scales = _mm_set1_epi32(s32[ib]);
+ scales = _mm_and_si128(_mm_unpacklo_epi8(scales, _mm_srli_epi16(scales, 4)), _mm_set1_epi8(0xf));
+ scales = _mm_or_si128(_mm_slli_epi16(scales, 1), _mm_set1_epi8(1));
+ auto scales16 = _mm256_cvtepi8_epi16(scales); // 0...7, 0...7
+#ifdef HAVE_FANCY_SIMD
+ __m256i scs[2] = { _mm256_shuffle_epi8(scales16, shuffles[0]), _mm256_shuffle_epi8(scales16, shuffles[1]) };
+ auto mask = (const __mmask32 *)&signs128;
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto y = _mm256_loadu_si256((const __m256i *)q8.y[iy][ibl].qs + ib);
+ auto sumi1 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), qx[0], _mm256_mask_sub_epi8(y, mask[0], _mm256_setzero_si256(), y)); // blocks: 0,0,0,0, 1,1,1,1, row 0
+ auto sumi2 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), qx[1], _mm256_mask_sub_epi8(y, mask[1], _mm256_setzero_si256(), y)); // blocks: 2,2,2,2, 3,3,3,3, row 1
+ auto sumi3 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), qx[2], _mm256_mask_sub_epi8(y, mask[2], _mm256_setzero_si256(), y)); // blocks: 4,4,4,4, 5,5,5,5, row 2
+ auto sumi4 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), qx[3], _mm256_mask_sub_epi8(y, mask[3], _mm256_setzero_si256(), y)); // blocks: 6,6,6,6, 7,7,7,7, row 3
+ auto s12 = _mm256_packs_epi32(sumi1, sumi2); // 0,0,0,0, 2,2,2,2, 1,1,1,1, 3,3,3,3
+ auto s34 = _mm256_packs_epi32(sumi3, sumi4); // 4,4,4,4, 6,6,6,6, 5,5,5,5, 7,7,7,7
+ isum[2*iy+0] = _mm256_add_epi32(isum[2*iy+0], _mm256_madd_epi16(scs[0], s12));
+ isum[2*iy+1] = _mm256_add_epi32(isum[2*iy+1], _mm256_madd_epi16(scs[1], s34));
+ }
+#else
+ auto signs = MM256_SET_M128I(signs128, signs128);
+ auto shuffle = sign_shuffle;
+ auto s1 = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(signs, shuffle), smask), smask), _mm256_set1_epi8(1));
+ shuffle = _mm256_add_epi8(shuffle, m4);
+ auto s2 = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(signs, shuffle), smask), smask), _mm256_set1_epi8(1));
+ shuffle = _mm256_add_epi8(shuffle, m4);
+ auto s3 = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(signs, shuffle), smask), smask), _mm256_set1_epi8(1));
+ shuffle = _mm256_add_epi8(shuffle, m4);
+ auto s4 = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(signs, shuffle), smask), smask), _mm256_set1_epi8(1));
+ __m256i scs[4] = {
+ _mm256_shuffle_epi8(scales16, shuffles[0]), _mm256_shuffle_epi8(scales16, shuffles[1]),
+ _mm256_shuffle_epi8(scales16, shuffles[2]), _mm256_shuffle_epi8(scales16, shuffles[3]),
+ };
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto y = _mm256_loadu_si256((const __m256i *)q8.y[iy][ibl].qs + ib);
+ if constexpr (nrc_y == 1) {
+ isum[0] = _mm256_add_epi32(isum[0], _mm256_madd_epi16(scs[0], _mm256_maddubs_epi16(qx[0], _mm256_sign_epi8(y, s1))));
+ isum[1] = _mm256_add_epi32(isum[1], _mm256_madd_epi16(scs[1], _mm256_maddubs_epi16(qx[1], _mm256_sign_epi8(y, s2))));
+ isum[2] = _mm256_add_epi32(isum[2], _mm256_madd_epi16(scs[2], _mm256_maddubs_epi16(qx[2], _mm256_sign_epi8(y, s3))));
+ isum[3] = _mm256_add_epi32(isum[3], _mm256_madd_epi16(scs[3], _mm256_maddubs_epi16(qx[3], _mm256_sign_epi8(y, s4))));
+ } else {
+ auto sumi1 = _mm256_madd_epi16(scs[0], _mm256_maddubs_epi16(qx[0], _mm256_sign_epi8(y, s1))); // blocks 4x0, 4x1, row 0
+ auto sumi2 = _mm256_madd_epi16(scs[1], _mm256_maddubs_epi16(qx[1], _mm256_sign_epi8(y, s2))); // blocks 4x2, 4x3, row 1
+ auto sumi3 = _mm256_madd_epi16(scs[2], _mm256_maddubs_epi16(qx[2], _mm256_sign_epi8(y, s3))); // blocks 4x4, 4x5, row 2
+ auto sumi4 = _mm256_madd_epi16(scs[3], _mm256_maddubs_epi16(qx[3], _mm256_sign_epi8(y, s4))); // blocks 4x6, 4x7, row 3
+ auto s12 = _mm256_add_epi32(_mm256_unpacklo_epi32(sumi1, sumi2), _mm256_unpackhi_epi32(sumi1, sumi2)); // 0,1, 0,1, 0,1, 0,1
+ auto s34 = _mm256_add_epi32(_mm256_unpacklo_epi32(sumi3, sumi4), _mm256_unpackhi_epi32(sumi3, sumi4)); // 2,3, 2,3, 2,3, 2,3
+ auto sumi = _mm256_add_epi32(_mm256_unpacklo_epi64(s12, s34), _mm256_unpackhi_epi64(s12, s34)); // 0,1,2,3, 0,1,2,3
+ isum[iy] = _mm256_add_epi32(isum[iy], sumi);
+ }
+ }
+#endif
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+#ifdef HAVE_FANCY_SIMD
+ auto sumi = _mm256_hadd_epi32(isum[2*iy+0], isum[2*iy+1]);
+ acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(sumi), acc[iy]);
+ isum[2*iy+0] = isum[2*iy+1] = _mm256_setzero_si256();
+#else
+ if constexpr (nrc_y == 1) {
+ auto s12 = _mm256_add_epi32(_mm256_unpacklo_epi32(isum[0], isum[1]), _mm256_unpackhi_epi32(isum[0], isum[1]));
+ auto s34 = _mm256_add_epi32(_mm256_unpacklo_epi32(isum[2], isum[3]), _mm256_unpackhi_epi32(isum[2], isum[3]));
+ auto sumi = _mm256_add_epi32(_mm256_unpacklo_epi64(s12, s34), _mm256_unpackhi_epi64(s12, s34));
+ acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(sumi), acc[iy]);
+ isum[0] = isum[1] = isum[2] = isum[3] = _mm256_setzero_si256();
+ } else {
+ acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(isum[iy]), acc[iy]);
+ isum[iy] = _mm256_setzero_si256();
+ }
+#endif
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1));
+ info.store(ix, iy, _mm_mul_ps(_mm_set1_ps(0.125f), sum));
+ acc[iy] = _mm256_setzero_ps();
+ }
+ }
+}
+
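+// 16-column specialization of the iq2_s_r4 x q8_K kernel above. As in the iq2_xs case, the signs
+// are applied to the packed x quants, which are biased by +64 for the unsigned*signed dot products;
+// the bias is removed once per super-block via the q8 block sums (the -64.f term).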
+static void mul_mat_iq2_s_r4_q8_k_16(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ GGML_ASSERT(nrc_x%4 == 0);
+ constexpr int nrc_y = 16;
+ Q8<nrc_y, block_q8_K> q8(info);
+ int nbl = n / QK_K;
+#ifndef HAVE_FANCY_SIMD
+ auto smask = _mm256_set1_epi64x(0x8040201008040201);
+ auto sign_shuffle = _mm256_set_epi64x(0x0303030303030303, 0x0202020202020202, 0x0101010101010101, 0x0000000000000000);
+ auto m4 = _mm256_set1_epi8(4);
+#endif
+ __m256 acc[nrc_y] = {};
+#ifdef HAVE_FANCY_SIMD
+ __m256i shuffles[2] = {
+ _mm256_set_epi64x(0x0706070607060706, 0x0302030203020302, 0x0504050405040504, 0x0100010001000100),
+ _mm256_set_epi64x(0x0f0e0f0e0f0e0f0e, 0x0b0a0b0a0b0a0b0a, 0x0d0c0d0c0d0c0d0c, 0x0908090809080908)
+ };
+ __m256i isum[2*nrc_y] = {};
+#else
+ __m256i shuffles[4] = {
+ MM256_SET_M128I(_mm_set1_epi16(0x0302), _mm_set1_epi16(0x0100)),
+ MM256_SET_M128I(_mm_set1_epi16(0x0706), _mm_set1_epi16(0x0504)),
+ MM256_SET_M128I(_mm_set1_epi16(0x0b0a), _mm_set1_epi16(0x0908)),
+ MM256_SET_M128I(_mm_set1_epi16(0x0f0e), _mm_set1_epi16(0x0d0c)),
+ };
+ __m256i isum[nrc_y == 1 ? 4 : nrc_y] = {};
+#endif
+ __m256i qx[4];
+ auto grid = iq2s_grid;
+ for (int ix = 0; ix < nrc_x; ix += 4) {
+ auto iq2 = (const block_iq2_s_r4 *)((const char *)vx + (ix+0)*bx);
+ for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256
+ auto dl = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq2[ibl].d));
+ auto d4 = _mm256_set_m128(dl, dl);
+ auto s32 = (const uint32_t *)iq2[ibl].scales;
+ auto ql = iq2[ibl].qs;
+ auto qh = iq2[ibl].qh;
+ {
+ auto scale_bits = _mm256_loadu_si256((const __m256i *)iq2[ibl].scales);
+ auto scales1 = _mm256_and_si256(scale_bits, _mm256_set1_epi8(0xf));
+ auto scales2 = _mm256_and_si256(_mm256_srli_epi16(scale_bits, 4), _mm256_set1_epi8(0xf));
+ scales1 = _mm256_or_si256(_mm256_slli_epi16(scales1, 1), _mm256_set1_epi8(1));
+ scales2 = _mm256_or_si256(_mm256_slli_epi16(scales2, 1), _mm256_set1_epi8(1));
+ auto s1_8 = _mm256_unpacklo_epi8(scales1, scales2); // blocks 0...15, 32...47 (0...3, 8...11 from each row)
+                auto s2_8 = _mm256_unpackhi_epi8(scales1, scales2); // blocks 16...31, 48...63 (4...7, 12...15 from each row)
+                auto s1_16 = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(s1_8)); // 0...15 (0...3 from each row)
+                auto s2_16 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(s1_8, 1)); // 32...47 (8...11 from each row)
+                auto s3_16 = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(s2_8)); // 16...31 (4...7 from each row)
+                auto s4_16 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(s2_8, 1)); // 48...63 (12...15 from each row)
+ auto t1 = MM256_SET_M128I(_mm256_castsi256_si128(s2_16), _mm256_castsi256_si128(s1_16)); // 0,1 and 8,9 from each row
+ auto t2 = MM256_SET_M128I(_mm256_extracti128_si256(s2_16, 1), _mm256_extracti128_si256(s1_16, 1)); // 2,3 and 10,11 from each row
+ auto t3 = MM256_SET_M128I(_mm256_castsi256_si128(s4_16), _mm256_castsi256_si128(s3_16)); // 4,5 and 12,13 from each row
+ auto t4 = MM256_SET_M128I(_mm256_extracti128_si256(s4_16, 1), _mm256_extracti128_si256(s3_16, 1)); // 6,7 and 14,15 from each row
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto bsums = q8.load_bsums(iy, ibl);
+ auto sumi = _mm256_setzero_si256();
+#ifdef HAVE_FANCY_SIMD
+ sumi = _mm256_dpwssd_epi32(sumi, t1, _mm256_shuffle_epi32(bsums, 0x00));
+ sumi = _mm256_dpwssd_epi32(sumi, t2, _mm256_shuffle_epi32(bsums, 0x55));
+ sumi = _mm256_dpwssd_epi32(sumi, t3, _mm256_shuffle_epi32(bsums, 0xaa));
+ sumi = _mm256_dpwssd_epi32(sumi, t4, _mm256_shuffle_epi32(bsums, 0xff));
+#else
+ sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(t1, _mm256_shuffle_epi32(bsums, 0x00)));
+ sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(t2, _mm256_shuffle_epi32(bsums, 0x55)));
+ sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(t3, _mm256_shuffle_epi32(bsums, 0xaa)));
+ sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(t4, _mm256_shuffle_epi32(bsums, 0xff)));
+#endif
+ acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(d4, _mm256_set1_ps(-64.f*q8.scale(iy, ibl))), _mm256_cvtepi32_ps(sumi), acc[iy]);
+ }
+ }
+ for (int ib = 0; ib < QK_K/32; ++ib) {
+ qx[0] = _mm256_set_epi64x(grid[ql[ 3] | ((qh[0] << 2) & 0x300)], grid[ql[ 2] | ((qh[0] << 4) & 0x300)], grid[ql[ 1] | ((qh[0] << 6) & 0x300)], grid[ql[ 0] | ((qh[0] << 8) & 0x300)]);
+ qx[1] = _mm256_set_epi64x(grid[ql[ 7] | ((qh[1] << 2) & 0x300)], grid[ql[ 6] | ((qh[1] << 4) & 0x300)], grid[ql[ 5] | ((qh[1] << 6) & 0x300)], grid[ql[ 4] | ((qh[1] << 8) & 0x300)]);
+ qx[2] = _mm256_set_epi64x(grid[ql[11] | ((qh[2] << 2) & 0x300)], grid[ql[10] | ((qh[2] << 4) & 0x300)], grid[ql[ 9] | ((qh[2] << 6) & 0x300)], grid[ql[ 8] | ((qh[2] << 8) & 0x300)]);
+ qx[3] = _mm256_set_epi64x(grid[ql[15] | ((qh[3] << 2) & 0x300)], grid[ql[14] | ((qh[3] << 4) & 0x300)], grid[ql[13] | ((qh[3] << 6) & 0x300)], grid[ql[12] | ((qh[3] << 8) & 0x300)]);
+ ql += 16; qh += 4;
+ auto signs128 = _mm_loadu_si128((const __m128i*)iq2[ibl].signs + ib);
+ auto scales = _mm_set1_epi32(s32[ib]);
+ scales = _mm_and_si128(_mm_unpacklo_epi8(scales, _mm_srli_epi16(scales, 4)), _mm_set1_epi8(0xf));
+ scales = _mm_or_si128(_mm_slli_epi16(scales, 1), _mm_set1_epi8(1));
+ auto scales16 = _mm256_cvtepi8_epi16(scales); // 0...7, 0...7
+#ifdef HAVE_FANCY_SIMD
+ __m256i scs[2] = { _mm256_shuffle_epi8(scales16, shuffles[0]), _mm256_shuffle_epi8(scales16, shuffles[1]) };
+ auto mask = (const __mmask32 *)&signs128;
+ qx[0] = _mm256_add_epi8(_mm256_set1_epi8(64), _mm256_mask_sub_epi8(qx[0], mask[0], _mm256_setzero_si256(), qx[0]));
+ qx[1] = _mm256_add_epi8(_mm256_set1_epi8(64), _mm256_mask_sub_epi8(qx[1], mask[1], _mm256_setzero_si256(), qx[1]));
+ qx[2] = _mm256_add_epi8(_mm256_set1_epi8(64), _mm256_mask_sub_epi8(qx[2], mask[2], _mm256_setzero_si256(), qx[2]));
+ qx[3] = _mm256_add_epi8(_mm256_set1_epi8(64), _mm256_mask_sub_epi8(qx[3], mask[3], _mm256_setzero_si256(), qx[3]));
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto y = _mm256_loadu_si256((const __m256i *)q8.y[iy][ibl].qs + ib);
+ auto sumi1 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), qx[0], y); // blocks: 0,0,0,0, 1,1,1,1, row 0
+ auto sumi2 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), qx[1], y); // blocks: 2,2,2,2, 3,3,3,3, row 1
+ auto sumi3 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), qx[2], y); // blocks: 4,4,4,4, 5,5,5,5, row 2
+ auto sumi4 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), qx[3], y); // blocks: 6,6,6,6, 7,7,7,7, row 3
+ auto s12 = _mm256_packs_epi32(sumi1, sumi2); // 0,0,0,0, 2,2,2,2, 1,1,1,1, 3,3,3,3
+ auto s34 = _mm256_packs_epi32(sumi3, sumi4); // 4,4,4,4, 6,6,6,6, 5,5,5,5, 7,7,7,7
+ isum[2*iy+0] = _mm256_add_epi32(isum[2*iy+0], _mm256_madd_epi16(scs[0], s12));
+ isum[2*iy+1] = _mm256_add_epi32(isum[2*iy+1], _mm256_madd_epi16(scs[1], s34));
+ }
+#else
+ auto signs = MM256_SET_M128I(signs128, signs128);
+ auto shuffle = sign_shuffle;
+ auto s = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(signs, shuffle), smask), smask), _mm256_set1_epi8(1));
+ shuffle = _mm256_add_epi8(shuffle, m4);
+ qx[0] = _mm256_add_epi8(_mm256_set1_epi8(64), _mm256_sign_epi8(qx[0], s));
+ s = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(signs, shuffle), smask), smask), _mm256_set1_epi8(1));
+ shuffle = _mm256_add_epi8(shuffle, m4);
+ qx[1] = _mm256_add_epi8(_mm256_set1_epi8(64), _mm256_sign_epi8(qx[1], s));
+ s = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(signs, shuffle), smask), smask), _mm256_set1_epi8(1));
+ shuffle = _mm256_add_epi8(shuffle, m4);
+ qx[2] = _mm256_add_epi8(_mm256_set1_epi8(64), _mm256_sign_epi8(qx[2], s));
+ s = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(signs, shuffle), smask), smask), _mm256_set1_epi8(1));
+ qx[3] = _mm256_add_epi8(_mm256_set1_epi8(64), _mm256_sign_epi8(qx[3], s));
+ __m256i scs[4] = {
+ _mm256_shuffle_epi8(scales16, shuffles[0]), _mm256_shuffle_epi8(scales16, shuffles[1]),
+ _mm256_shuffle_epi8(scales16, shuffles[2]), _mm256_shuffle_epi8(scales16, shuffles[3]),
+ };
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto y = _mm256_loadu_si256((const __m256i *)q8.y[iy][ibl].qs + ib);
+ auto sumi1 = _mm256_madd_epi16(scs[0], _mm256_maddubs_epi16(qx[0], y)); // blocks 4x0, 4x1, row 0
+ auto sumi2 = _mm256_madd_epi16(scs[1], _mm256_maddubs_epi16(qx[1], y)); // blocks 4x2, 4x3, row 1
+ auto sumi3 = _mm256_madd_epi16(scs[2], _mm256_maddubs_epi16(qx[2], y)); // blocks 4x4, 4x5, row 2
+ auto sumi4 = _mm256_madd_epi16(scs[3], _mm256_maddubs_epi16(qx[3], y)); // blocks 4x6, 4x7, row 3
+ auto s12 = _mm256_add_epi32(_mm256_unpacklo_epi32(sumi1, sumi2), _mm256_unpackhi_epi32(sumi1, sumi2)); // 0,1, 0,1, 0,1, 0,1
+ auto s34 = _mm256_add_epi32(_mm256_unpacklo_epi32(sumi3, sumi4), _mm256_unpackhi_epi32(sumi3, sumi4)); // 2,3, 2,3, 2,3, 2,3
+ auto sumi = _mm256_add_epi32(_mm256_unpacklo_epi64(s12, s34), _mm256_unpackhi_epi64(s12, s34)); // 0,1,2,3, 0,1,2,3
+ isum[iy] = _mm256_add_epi32(isum[iy], sumi);
+ }
+#endif
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+#ifdef HAVE_FANCY_SIMD
+ auto sumi = _mm256_hadd_epi32(isum[2*iy+0], isum[2*iy+1]);
+ acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(sumi), acc[iy]);
+ isum[2*iy+0] = isum[2*iy+1] = _mm256_setzero_si256();
+#else
+ acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(isum[iy]), acc[iy]);
+ isum[iy] = _mm256_setzero_si256();
+#endif
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1));
+ info.store(ix, iy, _mm_mul_ps(_mm_set1_ps(0.125f), sum));
+ acc[iy] = _mm256_setzero_ps();
+ }
+ }
+}
+
+template <int nrc_y>
+static void mul_mat_iq3_xxs_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ GGML_ASSERT(nrc_x%4 == 0);
+ Q8<nrc_y, block_q8_K> q8(info);
+ int nbl = n / QK_K;
+#ifndef HAVE_FANCY_SIMD
+ auto smask = _mm256_set1_epi64x(0x8040201008040201);
+ auto sign_shuffle = _mm256_set_epi64x(0x0303030303030303, 0x0202020202020202, 0x0101010101010101, 0x0000000000000000);
+ auto m4 = _mm256_set1_epi8(4);
+ auto m1 = _mm256_set1_epi16(1);
+#endif
+ __m256 acc[nrc_y] = {};
+ __m256i isum[nrc_y] = {};
+ __m256i qx[4];
+ for (int ix = 0; ix < nrc_x; ix += 4) {
+ auto iq3 = (const block_iq3_xxs_r4 *)((const char *)vx + (ix+0)*bx);
+ for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256
+ auto dl = _mm_mul_ps(_mm_set1_ps(0.25f), _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq3[ibl].d))); // TODO: absorb the 0.25 factor into d when quantizing/repacking
+ auto d4 = _mm256_set_m128(dl, dl);
+ for (int ib = 0; ib < QK_K/32; ++ib) {
+ qx[0] = _mm256_set_epi32(iq3xxs_grid[iq3[ibl].qs[32*ib+ 7]], iq3xxs_grid[iq3[ibl].qs[32*ib+ 6]], iq3xxs_grid[iq3[ibl].qs[32*ib+ 5]], iq3xxs_grid[iq3[ibl].qs[32*ib+ 4]],
+ iq3xxs_grid[iq3[ibl].qs[32*ib+ 3]], iq3xxs_grid[iq3[ibl].qs[32*ib+ 2]], iq3xxs_grid[iq3[ibl].qs[32*ib+ 1]], iq3xxs_grid[iq3[ibl].qs[32*ib+ 0]]);
+ qx[1] = _mm256_set_epi32(iq3xxs_grid[iq3[ibl].qs[32*ib+15]], iq3xxs_grid[iq3[ibl].qs[32*ib+14]], iq3xxs_grid[iq3[ibl].qs[32*ib+13]], iq3xxs_grid[iq3[ibl].qs[32*ib+12]],
+ iq3xxs_grid[iq3[ibl].qs[32*ib+11]], iq3xxs_grid[iq3[ibl].qs[32*ib+10]], iq3xxs_grid[iq3[ibl].qs[32*ib+ 9]], iq3xxs_grid[iq3[ibl].qs[32*ib+ 8]]);
+ qx[2] = _mm256_set_epi32(iq3xxs_grid[iq3[ibl].qs[32*ib+23]], iq3xxs_grid[iq3[ibl].qs[32*ib+22]], iq3xxs_grid[iq3[ibl].qs[32*ib+21]], iq3xxs_grid[iq3[ibl].qs[32*ib+20]],
+ iq3xxs_grid[iq3[ibl].qs[32*ib+19]], iq3xxs_grid[iq3[ibl].qs[32*ib+18]], iq3xxs_grid[iq3[ibl].qs[32*ib+17]], iq3xxs_grid[iq3[ibl].qs[32*ib+16]]);
+ qx[3] = _mm256_set_epi32(iq3xxs_grid[iq3[ibl].qs[32*ib+31]], iq3xxs_grid[iq3[ibl].qs[32*ib+30]], iq3xxs_grid[iq3[ibl].qs[32*ib+29]], iq3xxs_grid[iq3[ibl].qs[32*ib+28]],
+ iq3xxs_grid[iq3[ibl].qs[32*ib+27]], iq3xxs_grid[iq3[ibl].qs[32*ib+26]], iq3xxs_grid[iq3[ibl].qs[32*ib+25]], iq3xxs_grid[iq3[ibl].qs[32*ib+24]]);
+ auto sas = _mm_loadu_si128((const __m128i *)iq3[ibl].sas + ib);
+ auto scales = _mm_and_si128(sas, _mm_set1_epi8(1));
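+                // The 4-bit block scale is stored bit-by-bit in the low bit of four consecutive
+                // sas bytes; the dot product with (2,4,8,16) plus 1 reassembles it directly as 2*s+1.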
+#ifdef HAVE_FANCY_SIMD
+ scales = _mm_dpbusd_epi32(_mm_set1_epi32(1), scales, _mm_set1_epi32(0x10080402));
+#else
+ scales = _mm_maddubs_epi16(scales, _mm_set1_epi32(0x10080402));
+ scales = _mm_add_epi32(_mm_madd_epi16(_mm_set1_epi16(1), scales), _mm_set1_epi32(1));
+ //auto t1 = _mm_or_si128(_mm_and_si128(scales, _mm_set1_epi32(0x00000001)), _mm_srli_epi32(_mm_and_si128(scales, _mm_set1_epi32(0x00000100)), 7));
+ //auto t2 = _mm_or_si128(_mm_srli_epi32(_mm_and_si128(scales, _mm_set1_epi32(0x00010000)), 14), _mm_srli_epi32(_mm_and_si128(scales, _mm_set1_epi32(0x01000000)), 21));
+ //scales = _mm_or_si128(_mm_slli_epi32(_mm_or_si128(t1, t2), 1), _mm_set1_epi32(1));
+#endif
+ auto scales32 = MM256_SET_M128I(scales, scales);
+                auto signs128 = _mm_and_si128(sas, _mm_set1_epi8(-2)); // 0xfe written as -2 (signed) to silence a compiler warning.
+ signs128 = _mm_xor_si128(signs128, _mm_srli_epi16(signs128, 1));
+#ifdef HAVE_FANCY_SIMD
+ auto mask = (const __mmask32 *)&signs128;
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto y = _mm256_loadu_si256((const __m256i *)q8.y[iy][ibl].qs + ib);
+ auto sumi1 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), qx[0], _mm256_mask_sub_epi8(y, mask[0], _mm256_setzero_si256(), y));
+ auto sumi2 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), qx[1], _mm256_mask_sub_epi8(y, mask[1], _mm256_setzero_si256(), y));
+ auto sumi3 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), qx[2], _mm256_mask_sub_epi8(y, mask[2], _mm256_setzero_si256(), y));
+ auto sumi4 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), qx[3], _mm256_mask_sub_epi8(y, mask[3], _mm256_setzero_si256(), y));
+ auto s12 = _mm256_add_epi32(_mm256_unpacklo_epi32(sumi1, sumi2), _mm256_unpackhi_epi32(sumi1, sumi2)); // 0,1, 0,1, 0,1, 0,1
+ auto s34 = _mm256_add_epi32(_mm256_unpacklo_epi32(sumi3, sumi4), _mm256_unpackhi_epi32(sumi3, sumi4)); // 2,3, 2,3, 2,3, 2,3
+ auto sumi = _mm256_add_epi32(_mm256_unpacklo_epi64(s12, s34), _mm256_unpackhi_epi64(s12, s34)); // 0,1,2,3, 0,1,2,3
+ isum[iy] = _mm256_add_epi32(isum[iy], _mm256_mullo_epi32(scales32, sumi));
+ }
+#else
+ auto signs = MM256_SET_M128I(signs128, signs128);
+ auto shuffle = sign_shuffle;
+ auto s1 = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(signs, shuffle), smask), smask), _mm256_set1_epi8(1));
+ shuffle = _mm256_add_epi8(shuffle, m4);
+ auto s2 = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(signs, shuffle), smask), smask), _mm256_set1_epi8(1));
+ shuffle = _mm256_add_epi8(shuffle, m4);
+ auto s3 = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(signs, shuffle), smask), smask), _mm256_set1_epi8(1));
+ shuffle = _mm256_add_epi8(shuffle, m4);
+ auto s4 = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(signs, shuffle), smask), smask), _mm256_set1_epi8(1));
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto y = _mm256_loadu_si256((const __m256i *)q8.y[iy][ibl].qs + ib);
+ auto sumi1 = _mm256_madd_epi16(m1, _mm256_maddubs_epi16(qx[0], _mm256_sign_epi8(y, s1)));
+ auto sumi2 = _mm256_madd_epi16(m1, _mm256_maddubs_epi16(qx[1], _mm256_sign_epi8(y, s2)));
+ auto sumi3 = _mm256_madd_epi16(m1, _mm256_maddubs_epi16(qx[2], _mm256_sign_epi8(y, s3)));
+ auto sumi4 = _mm256_madd_epi16(m1, _mm256_maddubs_epi16(qx[3], _mm256_sign_epi8(y, s4)));
+ auto s12 = _mm256_add_epi32(_mm256_unpacklo_epi32(sumi1, sumi2), _mm256_unpackhi_epi32(sumi1, sumi2)); // 0,1, 0,1, 0,1, 0,1
+ auto s34 = _mm256_add_epi32(_mm256_unpacklo_epi32(sumi3, sumi4), _mm256_unpackhi_epi32(sumi3, sumi4)); // 2,3, 2,3, 2,3, 2,3
+ auto sumi = _mm256_add_epi32(_mm256_unpacklo_epi64(s12, s34), _mm256_unpackhi_epi64(s12, s34)); // 0,1,2,3, 0,1,2,3
+ isum[iy] = _mm256_add_epi32(isum[iy], _mm256_mullo_epi32(scales32, sumi));
+ }
+#endif
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(isum[iy]), acc[iy]);
+ isum[iy] = _mm256_setzero_si256();
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1));
+ info.store(ix, iy, sum);
+ acc[iy] = _mm256_setzero_ps();
+ }
+ }
+}
+
+template <int nrc_y>
+static void mul_mat_iq3_s_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ GGML_ASSERT(nrc_x%4 == 0);
+ Q8<nrc_y, block_q8_K> q8(info);
+ int nbl = n / QK_K;
+ auto smask = _mm256_set1_epi8(1);
+ union { __m256i vec; uint32_t val[8]; } helper;
+ union { __m128i vec; uint16_t val[8]; } hidx;
+ __m256 acc[nrc_y] = {};
+ __m256i isum[nrc_y] = {};
+ __m256i qx[4];
+#ifdef HAVE_FANCY_SIMD
+ __mmask32 mask[4];
+#endif
+ for (int ix = 0; ix < nrc_x; ix += 4) {
+ auto iq3 = (const block_iq3_s_r4 *)((const char *)vx + (ix+0)*bx);
+ for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256
+ auto dl = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq3[ibl].d));
+ auto d4 = _mm256_set_m128(dl, dl);
+ auto qs = iq3[ibl].qs;
+ auto qh = iq3[ibl].qh;
+ auto scale_bits = _mm_loadu_si128((const __m128i *)iq3[ibl].scales);
+ auto scales8 = MM256_SET_M128I(_mm_srli_epi16(scale_bits, 4), scale_bits);
+ helper.vec = _mm256_or_si256(_mm256_slli_epi16(_mm256_and_si256(scales8, _mm256_set1_epi8(0xf)), 1), _mm256_set1_epi8(1));
+ for (int ib = 0; ib < QK_K/32; ++ib) {
+ auto qh32 = (const uint32_t *)qh;
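+                // The 9th bit of each grid index comes from qh: the four qh bytes are broadcast to
+                // 16-bit lanes and pre-shifted left by 8 (low half) and 4 (high half), so each of
+                // the four iterations below can pick its high bit at position 8 and then shift
+                // right by one.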
+ auto idx_h = _mm_sllv_epi64(_mm_cvtepu8_epi16(_mm_set1_epi32(qh32[0])), _mm_set_epi64x(4, 8));
+ for (int i = 0; i < 4; ++i) {
+ auto idx_l = _mm_cvtepu8_epi16(_mm_loadl_epi64((const __m128i *)(qs + 8*i)));
+ hidx.vec = _mm_or_si128(idx_l, _mm_and_si128(idx_h, _mm_set1_epi16(0x100))); idx_h = _mm_srli_epi16(idx_h, 1);
+ qx[i] = _mm256_set_epi32(iq3s_grid[hidx.val[7]], iq3s_grid[hidx.val[6]], iq3s_grid[hidx.val[5]], iq3s_grid[hidx.val[4]],
+ iq3s_grid[hidx.val[3]], iq3s_grid[hidx.val[2]], iq3s_grid[hidx.val[1]], iq3s_grid[hidx.val[0]]);
+ }
+ qs += 32; qh += 4;
+ auto signs128 = _mm_loadu_si128((const __m128i*)iq3[ibl].signs + ib);
+ auto signs = MM256_SET_M128I(_mm_srli_epi16(signs128, 4), signs128);
+#ifdef HAVE_FANCY_SIMD
+ auto scales = _mm256_cvtepi8_epi32(_mm_set1_epi32(helper.val[ib]));
+ mask[0] = _mm256_cmpeq_epi8_mask(_mm256_and_si256(signs, smask), smask); signs = _mm256_srli_epi16(signs, 1);
+ mask[1] = _mm256_cmpeq_epi8_mask(_mm256_and_si256(signs, smask), smask); signs = _mm256_srli_epi16(signs, 1);
+ mask[2] = _mm256_cmpeq_epi8_mask(_mm256_and_si256(signs, smask), smask); signs = _mm256_srli_epi16(signs, 1);
+ mask[3] = _mm256_cmpeq_epi8_mask(_mm256_and_si256(signs, smask), smask);
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto y = _mm256_loadu_si256((const __m256i *)q8.y[iy][ibl].qs + ib);
+ auto sumi = _mm256_setzero_si256();
+ auto ys = _mm256_shuffle_epi32(y, 0x00);
+ sumi = _mm256_dpbusd_epi32(sumi, qx[0], _mm256_mask_sub_epi8(ys, mask[0], _mm256_setzero_si256(), ys));
+ ys = _mm256_shuffle_epi32(y, 0x55);
+ sumi = _mm256_dpbusd_epi32(sumi, qx[1], _mm256_mask_sub_epi8(ys, mask[1], _mm256_setzero_si256(), ys));
+ ys = _mm256_shuffle_epi32(y, 0xaa);
+ sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_mask_sub_epi8(ys, mask[2], _mm256_setzero_si256(), ys));
+ ys = _mm256_shuffle_epi32(y, 0xff);
+ sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_mask_sub_epi8(ys, mask[3], _mm256_setzero_si256(), ys));
+ isum[iy] = _mm256_add_epi32(isum[iy], _mm256_mullo_epi32(sumi, scales));
+ }
+#else
+ auto scales16 = _mm256_cvtepi8_epi16(_mm_set1_epi32(helper.val[ib]));
+ auto scales = _mm256_unpacklo_epi16(scales16, scales16);
+ auto s1 = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(signs, smask), smask), smask); signs = _mm256_srli_epi16(signs, 1);
+ auto s2 = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(signs, smask), smask), smask); signs = _mm256_srli_epi16(signs, 1);
+ auto s3 = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(signs, smask), smask), smask); signs = _mm256_srli_epi16(signs, 1);
+ auto s4 = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(signs, smask), smask), smask);
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto y = _mm256_loadu_si256((const __m256i *)q8.y[iy][ibl].qs + ib);
+ auto sumi = _mm256_setzero_si256();
+ sumi = _mm256_add_epi16(sumi, _mm256_maddubs_epi16(qx[0], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), s1)));
+ sumi = _mm256_add_epi16(sumi, _mm256_maddubs_epi16(qx[1], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), s2)));
+ sumi = _mm256_add_epi16(sumi, _mm256_maddubs_epi16(qx[2], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), s3)));
+ sumi = _mm256_add_epi16(sumi, _mm256_maddubs_epi16(qx[3], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), s4)));
+ isum[iy] = _mm256_add_epi32(isum[iy], _mm256_madd_epi16(scales, sumi));
+ }
+#endif
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(isum[iy]), acc[iy]);
+ isum[iy] = _mm256_setzero_si256();
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1));
+ info.store(ix, iy, sum);
+ acc[iy] = _mm256_setzero_ps();
+ }
+ }
+}
+
+template <typename Dequantizer> void set_functions(std::array<mul_mat_t, IQK_MAX_NY>& funcs) {
+ funcs[0] = mul_mat_qX_K_q8_K_IQ<Dequantizer, 1>;
+ funcs[1] = mul_mat_qX_K_q8_K_IQ<Dequantizer, 2>;
+ funcs[2] = mul_mat_qX_K_q8_K_IQ<Dequantizer, 3>;
+ funcs[3] = mul_mat_qX_K_q8_K_IQ<Dequantizer, 4>;
+ funcs[4] = mul_mat_qX_K_q8_K_IQ<Dequantizer, 5>;
+ funcs[5] = mul_mat_qX_K_q8_K_IQ<Dequantizer, 6>;
+ funcs[6] = mul_mat_qX_K_q8_K_IQ<Dequantizer, 7>;
+ funcs[7] = mul_mat_qX_K_q8_K_IQ<Dequantizer, 8>;
+}
+
+} // namespace
+
+bool iqk_set_kernels_iquants(int ne00, int typeA, int typeB, std::array<mul_mat_t, IQK_MAX_NY>& kernels, mul_mat_t& func16) {
+
+ if (ne00%QK_K != 0 || ggml_type(typeB) != GGML_TYPE_Q8_K) {
+ return false;
+ }
+
+ func16 = nullptr;
+
+ switch (typeA) {
+ case GGML_TYPE_IQ2_XXS:
+ set_functions<DequantizerIQ2XXS>(kernels);
+ break;
+ case GGML_TYPE_IQ2_XS:
+ set_functions<DequantizerIQ2XS>(kernels);
+ break;
+ case GGML_TYPE_IQ2_S:
+ set_functions<DequantizerIQ2S>(kernels);
+ break;
+ case GGML_TYPE_IQ3_XXS:
+ set_functions<DequantizerIQ3XXS>(kernels);
+ break;
+ case GGML_TYPE_IQ3_S:
+ set_functions<DequantizerIQ3S>(kernels);
+ break;
+ case GGML_TYPE_IQ2_XXS_R4:
+ IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq2_xxs_r4_q8_k, kernels);
+ func16 = mul_mat_iq2_xxs_r4_q8_k<16>;
+ break;
+ case GGML_TYPE_IQ2_XS_R4:
+ assert (ne00 % QK_K == 0);
+ IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq2_xs_r4_q8_k, kernels);
+#ifndef HAVE_FANCY_SIMD
+ // For some reason Zen4 does not like this particular function
+ func16 = mul_mat_iq2_xs_r4_q8_k_16;
+#endif
+ break;
+ case GGML_TYPE_IQ2_S_R4:
+ IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq2_s_r4_q8_k, kernels);
+ func16 = mul_mat_iq2_s_r4_q8_k_16;
+ break;
+ case GGML_TYPE_IQ3_XXS_R4:
+ IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq3_xxs_r4_q8_k, kernels);
+ func16 = mul_mat_iq3_xxs_r4_q8_k<16>;
+ break;
+ case GGML_TYPE_IQ3_S_R4:
+ IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq3_s_r4_q8_k, kernels);
+ func16 = mul_mat_iq3_s_r4_q8_k<16>;
+ break;
+ default:
+ return false;
+ }
+
+ return true;
+
+}
+
+#else
+// --------------------------------------- __aarch64__ ---------------------------------------------
+
+namespace {
+
+inline int32x4x4_t make_wider(const int16x8x2_t& scales16) {
+ int32x4x4_t scales = {
+ vmovl_s16(vget_low_s16 (scales16.val[0])),
+ vmovl_s16(vget_high_s16(scales16.val[0])),
+ vmovl_s16(vget_low_s16 (scales16.val[1])),
+ vmovl_s16(vget_high_s16(scales16.val[1])),
+ };
+ return scales;
+}
+
+struct SimpleBits {
+ uint8x16x4_t b1;
+ uint8x16x4_t b2;
+};
+
+inline int32x4x2_t prepare_scales_8(const uint32x4_t& v1, const uint32x4_t& v2) {
+ int32x4x2_t scales;
+ scales.val[0] = vreinterpretq_s32_u32(vorrq_u32(vshlq_n_u32(vshrq_n_u32(v1, 28), 1), vdupq_n_u32(1)));
+ scales.val[1] = vreinterpretq_s32_u32(vorrq_u32(vshlq_n_u32(vshrq_n_u32(v2, 28), 1), vdupq_n_u32(1)));
+ return scales;
+}
+
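+// Multiplies two 16-byte vectors by sign masks taken from a 128-entry sign table (keven_signs at
+// the call sites); sidx packs four 7-bit table indices at bit offsets 0, 7, 14 and 21, one per
+// group of 8 values.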
+inline void apply_signs_2(uint8x16_t * b, const uint64_t * signs, uint32_t sidx) {
+ auto s1 = vcombine_s8(vld1_s8((const int8_t *)(signs + ((sidx >> 0) & 127))), vld1_s8((const int8_t *)(signs + ((sidx >> 7) & 127))));
+ auto s2 = vcombine_s8(vld1_s8((const int8_t *)(signs + ((sidx >>14) & 127))), vld1_s8((const int8_t *)(signs + ((sidx >>21) & 127))));
+ b[0] = vreinterpretq_u8_s8(vmulq_s8(vreinterpretq_s8_u8(b[0]), s1));
+ b[1] = vreinterpretq_u8_s8(vmulq_s8(vreinterpretq_s8_u8(b[1]), s2));
+}
+
+struct DequantizerIQ2XXS final : public BaseDequantizer<block_iq2_xxs> {
+ DequantizerIQ2XXS(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {}
+
+ constexpr static int num_blocks() { return 8; }
+ constexpr static bool should_scale_quants() { return false; }
+
+ template <typename Q8>
+ inline int32x4x2_t new_block(int i, const Q8& /*q8*/, float32x4_t * /*acc*/) {
+ d = 0.125f * GGML_FP16_TO_FP32(x[i].d);
+
+ auto tmp = vld1q_u32_x4((const uint32_t *)x[i].qs);
+ data.val[0] = vuzp1q_u32(tmp.val[0], tmp.val[1]); // codebook indices for blocks 0...3
+ data.val[1] = vuzp2q_u32(tmp.val[0], tmp.val[1]); // scales and signs for blocks 0...3
+ data.val[2] = vuzp1q_u32(tmp.val[2], tmp.val[3]); // codebook indices for blocks 4...7
+ data.val[3] = vuzp2q_u32(tmp.val[2], tmp.val[3]); // scales and signs for blocks 4...7
+
+ return prepare_scales_8(data.val[1], data.val[3]);
+ }
+
+ static inline void prepare2(uint8x16_t * b, const uint8_t * idx, const uint64_t * signs, uint32_t sidx) {
+ b[0] = vreinterpretq_u8_u64(uint64x2_t{iq2xxs_grid[idx[0]], iq2xxs_grid[idx[1]]});
+ b[1] = vreinterpretq_u8_u64(uint64x2_t{iq2xxs_grid[idx[2]], iq2xxs_grid[idx[3]]});
+ apply_signs_2(b, signs, sidx);
+ }
+
+ inline void prepare(int /*i*/, int j) {
+ const uint8_t * idx = (const uint8_t *)(data.val + 2*j);
+ const uint32_t * sidx = (const uint32_t *)(data.val + 2*j+1);
+ prepare2(bits.b1.val + 0, idx, keven_signs, sidx[0]); idx += 4;
+ prepare2(bits.b1.val + 2, idx, keven_signs, sidx[1]); idx += 4;
+ prepare2(bits.b2.val + 0, idx, keven_signs, sidx[2]); idx += 4;
+ prepare2(bits.b2.val + 2, idx, keven_signs, sidx[3]);
+ }
+
+ uint32x4x4_t data;
+ SimpleBits bits;
+
+};
+
+inline int32x4x4_t prepare_4bit_scales16(const uint8_t * sc) {
+ auto aux = vld1_u8(sc);
+ auto scales_l = vand_u8(aux, vdup_n_u8(0xf));
+ auto scales_h = vshr_n_u8(aux, 4);
+ auto aux1 = vcombine_u8(vzip1_u8(scales_l, scales_h), vzip2_u8(scales_l, scales_h));
+
+ auto scales8 = vreinterpretq_s8_u8(vorrq_u8(vshlq_n_u8(aux1, 1), vdupq_n_u8(1)));
+ int16x8x2_t scales16 = { vmovl_s8(vget_low_s8(scales8)), vmovl_s8(vget_high_s8(scales8)) };
+ return make_wider(scales16);
+}
+
+struct DequantizerIQ2XS final : public BaseDequantizer<block_iq2_xs> {
+ DequantizerIQ2XS(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {}
+
+ constexpr static int num_blocks() { return 16; }
+ constexpr static bool should_scale_quants() { return false; }
+
+ template <typename Q8>
+ inline int32x4x4_t new_block(int i, const Q8& /*q8*/, float32x4_t * /*acc*/) {
+ d = 0.125f * GGML_FP16_TO_FP32(x[i].d);
+ return prepare_4bit_scales16(x[i].scales);
+ }
+
+ inline static uint8x16_t make1(const uint16_t * qs) {
+ auto b = vcombine_u8(vld1_u8((const uint8_t *)(iq2xs_grid + (qs[0] & 511))), vld1_u8((const uint8_t *)(iq2xs_grid + (qs[1] & 511))));
+ auto s = vcombine_s8(vld1_s8((const int8_t *)(keven_signs + (qs[0] >> 9))), vld1_s8((const int8_t *)(keven_signs + (qs[1] >> 9))));
+ return vreinterpretq_u8_s8(vmulq_s8(vreinterpretq_s8_u8(b), s));
+ }
+
+ inline static void make4(const uint16_t * qs, uint8x16_t * b) {
+ b[0] = make1(qs + 0);
+ b[1] = make1(qs + 2);
+ b[2] = make1(qs + 4);
+ b[3] = make1(qs + 6);
+ }
+
+ inline void prepare(int i, int j) {
+ make4(x[i].qs + 16*j + 0, bits.b1.val);
+ make4(x[i].qs + 16*j + 8, bits.b2.val);
+ }
+
+ SimpleBits bits;
+
+};
+
+struct DequantizerIQ2S final : public BaseDequantizer<block_iq2_s> {
+ DequantizerIQ2S(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {}
+
+ constexpr static int num_blocks() { return 16; }
+ constexpr static bool should_scale_quants() { return false; }
+
+ template <typename Q8>
+ inline int32x4x4_t new_block(int i, const Q8& /*q8*/, float32x4_t * /*acc*/) {
+ d = 0.125f * GGML_FP16_TO_FP32(x[i].d);
+ return prepare_4bit_scales16(x[i].scales);
+ }
+
+ static inline void make4(SignHelper& sh, const uint8x16_t& signs16, const uint8_t * qs, const uint8_t * qh, uint8x16_t * b) {
+ uint32_t aux32[2];
+ const uint16_t * aux16 = (const uint16_t *)aux32;
+ for (int k = 0; k < 2; ++k) {
+ aux32[1] = (qh[k] << 4) | (qh[k] << 18);
+ aux32[0] = (aux32[1] << 4) & 0x03000300;
+ aux32[1] &= 0x03000300;
+ b[2*k+0] = vcombine_u8(vld1_u8((const uint8_t *)(iq2s_grid + (qs[4*k+0] | aux16[0]))),
+ vld1_u8((const uint8_t *)(iq2s_grid + (qs[4*k+1] | aux16[1]))));
+ sh.apply_signs_1(b+2*k+0, signs16);
+
+ b[2*k+1] = vcombine_u8(vld1_u8((const uint8_t *)(iq2s_grid + (qs[4*k+2] | aux16[2]))),
+ vld1_u8((const uint8_t *)(iq2s_grid + (qs[4*k+3] | aux16[3]))));
+ sh.apply_signs_1(b+2*k+1, signs16);
+ }
+ }
+
+ inline void prepare(int i, int j) {
+
+ const auto * qs = x[i].qs + 16*j;
+ const auto * qh = x[i].qh + 4*j;
+ const auto signs16 = vld1q_u8(qs + QK_K/8);
+
+ sh.init();
+ make4(sh, signs16, qs+0, qh+0, bits.b1.val);
+ make4(sh, signs16, qs+8, qh+2, bits.b2.val);
+ }
+
+ SimpleBits bits;
+ SignHelper sh;
+
+};
+
+struct DequantizerIQ3XXS final : public BaseDequantizer<block_iq3_xxs> {
+ DequantizerIQ3XXS(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {}
+
+ constexpr static int num_blocks() { return 8; }
+ constexpr static bool should_scale_quants() { return false; }
+
+ template <typename Q8>
+ inline int32x4x2_t new_block(int i, const Q8& /*q8*/, float32x4_t * /*acc*/) {
+ d = 0.25f * GGML_FP16_TO_FP32(x[i].d);
+ gas = vld1q_u32_x2((const uint32_t *)(x[i].qs + QK_K/4));
+ return prepare_scales_8(gas.val[0], gas.val[1]);
+ }
+
+ inline static void make2(const uint8_t * q3, uint32_t sidx, uint8x16_t * b) {
+ b[0] = vreinterpretq_u8_u32(uint32x4_t{iq3xxs_grid[q3[0]], iq3xxs_grid[q3[1]], iq3xxs_grid[q3[2]], iq3xxs_grid[q3[3]]});
+ b[1] = vreinterpretq_u8_u32(uint32x4_t{iq3xxs_grid[q3[4]], iq3xxs_grid[q3[5]], iq3xxs_grid[q3[6]], iq3xxs_grid[q3[7]]});
+ apply_signs_2(b, keven_signs, sidx);
+ }
+ inline void prepare(int i, int j) {
+ const auto * q3 = x[i].qs + 32*j;
+ const auto * signs = (const uint32_t *)(gas.val + j);
+ make2(q3, signs[0], bits.b1.val + 0); q3 += 8;
+ make2(q3, signs[1], bits.b1.val + 2); q3 += 8;
+ make2(q3, signs[2], bits.b2.val + 0); q3 += 8;
+ make2(q3, signs[3], bits.b2.val + 2);
+ }
+
+ SimpleBits bits;
+ uint32x4x2_t gas;
+
+};
+
+struct DequantizerIQ3S final : public BaseDequantizer<block_iq3_s> {
+ DequantizerIQ3S(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {}
+
+ constexpr static int num_blocks() { return 8; }
+ constexpr static bool should_scale_quants() { return false; }
+
+ template <typename Q8>
+ inline int32x4x2_t new_block(int i, const Q8& /*q8*/, float32x4_t * /*acc*/) {
+ d = GGML_FP16_TO_FP32(x[i].d);
+ uint32_t scales32[2];
+ std::memcpy(scales32, x[i].scales, 4);
+ scales32[1] = (((scales32[0] >> 4) & 0x0f0f0f0f) << 1) | 0x01010101;
+ scales32[0] = ((scales32[0] & 0x0f0f0f0f) << 1) | 0x01010101;
+ auto scales8 = vld1_u8((const uint8_t *)scales32); // 0, 2, 4, 6, 1, 3, 5, 7
+ scales8 = vtbl1_u8(scales8, vreinterpret_u8_u64(vdup_n_u64(0x0703060205010400)));
+ auto scales16 = vreinterpretq_s16_u16(vmovl_u8(scales8));
+ int32x4x2_t scales;
+ scales.val[0] = vmovl_s16(vget_low_s16(scales16));
+ scales.val[1] = vmovl_s16(vget_high_s16(scales16));
+ return scales;
+ }
+
+ static inline void make2(SignHelper& sh, const uint8x16_t& signs16, const uint16x8_t& idx_l, uint8_t qh,
+ const int8x16_t& hshift, uint8x16_t * b) {
+ auto vindex = vorrq_u16(idx_l, vandq_u16(vshlq_u16(vdupq_n_u16(qh), hshift), vdupq_n_u16(256)));
+ const uint16_t * idx = (const uint16_t *)&vindex;
+ b[0] = vreinterpretq_u8_u32(uint32x4_t{iq3s_grid[idx[0]], iq3s_grid[idx[1]], iq3s_grid[idx[2]], iq3s_grid[idx[3]]});
+ b[1] = vreinterpretq_u8_u32(uint32x4_t{iq3s_grid[idx[4]], iq3s_grid[idx[5]], iq3s_grid[idx[6]], iq3s_grid[idx[7]]});
+ sh.apply_signs_1(b+0, signs16);
+ sh.apply_signs_1(b+1, signs16);
+ }
+ static inline void make4(SignHelper& sh, const uint8x16_t& signs16, const uint8_t * qs, const uint8_t * qh,
+ const int8x16_t& hshift, uint8x16_t * b) {
+ auto idx_l = vld1q_u8(qs);
+ make2(sh, signs16, vmovl_u8(vget_low_u8 (idx_l)), qh[0], hshift, b+0);
+ make2(sh, signs16, vmovl_u8(vget_high_u8(idx_l)), qh[1], hshift, b+2);
+ }
+
+ inline void prepare(int i, int j) {
+
+ static const int16_t k_shift[8] = {8, 7, 6, 5, 4, 3, 2, 1};
+ const auto hshift = vld1q_s16(k_shift);
+
+ const auto * qs = x[i].qs + 32*j;
+ const auto * qh = x[i].qh + 4*j;
+ const auto signs16 = vld1q_u8(x[i].signs + 16*j);
+
+ sh.init();
+ make4(sh, signs16, qs+ 0, qh+0, hshift, bits.b1.val);
+ make4(sh, signs16, qs+16, qh+2, hshift, bits.b2.val);
+ }
+
+ SimpleBits bits;
+ SignHelper sh;
+ uint32x4x2_t gas;
+
+};
+
+template <int nrc_y>
+static void mul_mat_iq2_xxs_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ GGML_ASSERT(nrc_x%4 == 0);
+ Q8<nrc_y, block_q8_K> q8(info);
+ int nbl = n / QK_K;
+ float32x4_t acc[nrc_y] = {};
+ int32x4_t isum[nrc_y] = {};
+ int8x16_t qx[8];
+ SignHelper sh;
+ for (int ix = 0; ix < nrc_x; ix += 4) {
+ auto iq2 = (const block_iq2_xxs_r4 *)((const char *)vx + (ix+0)*bx);
+ for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256
+ auto d4 = vcvt_f32_f16(vld1_f16((const float16_t *)iq2[ibl].d));
+ auto qs = iq2[ibl].qs;
+ for (int ib = 0; ib < QK_K/32; ++ib) {
+ auto sas = vld1q_u8(iq2[ibl].sas + 16*ib);
+ auto scale_bits = vandq_u8(sas, vdupq_n_u8(1));
+ auto scales = ggml_vdotq_s32(vdupq_n_s32(1), scale_bits, vreinterpretq_s8_u32(vdupq_n_u32(0x10080402)));
+ auto signs128 = vandq_u8(sas, vdupq_n_u8(254));
+ signs128 = veorq_u8(signs128, vshrq_n_u8(signs128, 1));
+ sh.init();
+ for (int i = 0; i < 8; ++i) {
+ qx[i] = vreinterpretq_s8_u64(uint64x2_t{iq2xxs_grid[qs[2*i+0]], iq2xxs_grid[qs[2*i+1]]});
+ sh.apply_signs_1((uint8x16_t *)qx+i, signs128);
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto y = vld1q_s8_x2(q8.y[iy][ibl].qs + 32*ib);
+ auto sumi1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[0], y.val[0]), qx[1], y.val[1]);
+ auto sumi2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[2], y.val[0]), qx[3], y.val[1]);
+ auto sumi3 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[4], y.val[0]), qx[5], y.val[1]);
+ auto sumi4 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[6], y.val[0]), qx[7], y.val[1]);
+ auto sumi12 = vpaddq_s32(sumi1, sumi2);
+ auto sumi34 = vpaddq_s32(sumi3, sumi4);
+ auto sumi = vpaddq_s32(sumi12, sumi34);
+ isum[iy] = vmlaq_s32(isum[iy], scales, sumi);
+ }
+ qs += 16;
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ acc[iy] = vfmaq_f32(acc[iy], vmulq_f32(d4, vdupq_n_f32(q8.scale(iy, ibl))), vcvtq_f32_s32(isum[iy]));
+ isum[iy] = vdupq_n_s32(0);
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ info.store(ix, iy, vmulq_f32(vdupq_n_f32(0.125f), acc[iy]));
+ acc[iy] = vdupq_n_f32(0.f);
+ }
+ }
+}
+
+template <int nrc_y>
+static void mul_mat_iq2_xs_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ GGML_ASSERT(nrc_x%4 == 0);
+ Q8<nrc_y, block_q8_K> q8(info);
+ int nbl = n / QK_K;
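+    // k_shuff gathers the high byte of each of the 16 uint16 q values; bit 0 of that byte is the
+    // 9th grid-index bit (masked off with vdupq_n_u8(254) below), bits 1..7 hold the packed signs.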
+ static const uint8_t k_shuff[16] = {1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29, 31};
+ auto shuff = vld1q_u8(k_shuff);
+ float32x4_t acc[nrc_y] = {};
+ int32x4_t isum[2*nrc_y] = {};
+ int8x16_t qx[8];
+ uint16x8x4_t scales16;
+ SignHelper sh;
+ for (int ix = 0; ix < nrc_x; ix += 4) {
+ auto iq2 = (const block_iq2_xs_r4 *)((const char *)vx + (ix+0)*bx);
+ for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256
+ auto d4 = vcvt_f32_f16(vld1_f16((const float16_t *)iq2[ibl].d));
+ auto qs = iq2[ibl].qs;
+ for (int is = 0; is < 2; ++is) {
+ auto scale_bits = vld1q_u8(iq2[ibl].scales + 16*is);
+ auto scales1 = vandq_u8(scale_bits, vdupq_n_u8(0xf));
+ auto scales2 = vshrq_n_u8(scale_bits, 4);
+ scales1 = vorrq_u8(vshlq_n_u8(scales1, 1), vdupq_n_u8(1));
+ scales2 = vorrq_u8(vshlq_n_u8(scales2, 1), vdupq_n_u8(1));
+ auto s1 = vzip1q_u8(scales1, scales2);
+ auto s2 = vzip2q_u8(scales1, scales2);
+ scales16.val[0] = vmovl_u8(vget_low_u8 (s1));
+ scales16.val[1] = vmovl_u8(vget_high_u8(s1));
+ scales16.val[2] = vmovl_u8(vget_low_u8 (s2));
+ scales16.val[3] = vmovl_u8(vget_high_u8(s2));
+ for (int ib = 0; ib < QK_K/64; ++ib) {
+ auto v = vld1q_u8_x2((const uint8_t *)qs);
+ auto signs128 = vandq_u8(vqtbl2q_u8(v, shuff), vdupq_n_u8(254));
+ signs128 = veorq_u8(signs128, vshrq_n_u8(signs128, 1));
+ sh.init();
+ for (int i = 0; i < 8; ++i) {
+ qx[i] = vreinterpretq_s8_u64(uint64x2_t{iq2xs_grid[qs[2*i+0] & 511], iq2xs_grid[qs[2*i+1] & 511]});
+ sh.apply_signs_1((uint8x16_t *)qx+i, signs128);
+ }
+ auto s32_1 = vmovl_u16(vget_low_u16 (scales16.val[ib]));
+ auto s32_2 = vmovl_u16(vget_high_u16(scales16.val[ib]));
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto y = vld1q_s8_x2(q8.y[iy][ibl].qs + 128*is + 32*ib);
+ auto sumi1 = vpaddq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[0], y.val[0]), ggml_vdotq_s32(vdupq_n_s32(0), qx[1], y.val[1]));
+ auto sumi2 = vpaddq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[2], y.val[0]), ggml_vdotq_s32(vdupq_n_s32(0), qx[3], y.val[1]));
+ auto sumi3 = vpaddq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[4], y.val[0]), ggml_vdotq_s32(vdupq_n_s32(0), qx[5], y.val[1]));
+ auto sumi4 = vpaddq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[6], y.val[0]), ggml_vdotq_s32(vdupq_n_s32(0), qx[7], y.val[1]));
+ auto sumi12 = vpaddq_s32(sumi1, sumi2); // blocks 0,1,2,3 in rows 0,1
+ auto sumi34 = vpaddq_s32(sumi3, sumi4); // blocks 4,5,6,7 in rows 2,3
+ isum[2*iy+0] = vmlaq_s32(isum[2*iy+0], s32_1, sumi12);
+ isum[2*iy+1] = vmlaq_s32(isum[2*iy+1], s32_2, sumi34);
+ }
+ qs += 16;
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto sumi = vpaddq_s32(isum[2*iy+0], isum[2*iy+1]);
+ acc[iy] = vfmaq_f32(acc[iy], vmulq_f32(d4, vdupq_n_f32(q8.scale(iy, ibl))), vcvtq_f32_s32(sumi));
+ isum[2*iy] = isum[2*iy+1] = vdupq_n_s32(0);
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ info.store(ix, iy, vmulq_f32(vdupq_n_f32(0.125f), acc[iy]));
+ acc[iy] = vdupq_n_f32(0.f);
+ }
+ }
+}
+
+template <int nrc_y>
+static void mul_mat_iq2_s_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ GGML_ASSERT(nrc_x%4 == 0);
+ Q8<nrc_y, block_q8_K> q8(info);
+ int nbl = n / QK_K;
+ float32x4_t acc[nrc_y] = {};
+ int32x4_t isum[2*nrc_y] = {};
+ int8x16_t qx[8];
+ uint16x8x4_t scales16;
+ SignHelper sh;
+ for (int ix = 0; ix < nrc_x; ix += 4) {
+ auto iq2 = (const block_iq2_s_r4 *)((const char *)vx + (ix+0)*bx);
+ for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256
+ auto d4 = vcvt_f32_f16(vld1_f16((const float16_t *)iq2[ibl].d));
+ auto qs = iq2[ibl].qs;
+ auto qh = iq2[ibl].qh;
+ for (int is = 0; is < 2; ++is) {
+ auto scale_bits = vld1q_u8(iq2[ibl].scales + 16*is);
+ auto scales1 = vandq_u8(scale_bits, vdupq_n_u8(0xf));
+ auto scales2 = vshrq_n_u8(scale_bits, 4);
+ scales1 = vorrq_u8(vshlq_n_u8(scales1, 1), vdupq_n_u8(1));
+ scales2 = vorrq_u8(vshlq_n_u8(scales2, 1), vdupq_n_u8(1));
+ auto s1 = vzip1q_u8(scales1, scales2);
+ auto s2 = vzip2q_u8(scales1, scales2);
+ scales16.val[0] = vmovl_u8(vget_low_u8 (s1));
+ scales16.val[1] = vmovl_u8(vget_high_u8(s1));
+ scales16.val[2] = vmovl_u8(vget_low_u8 (s2));
+ scales16.val[3] = vmovl_u8(vget_high_u8(s2));
+ for (int ib = 0; ib < QK_K/64; ++ib) {
+ auto signs128 = vld1q_u8(iq2[ibl].signs + 64*is + 16*ib);
+ sh.init();
+ for (int i = 0; i < 4; ++i) {
+ qx[2*i+0] = vreinterpretq_s8_u64(uint64x2_t{iq2s_grid[qs[4*i+0] | ((qh[i] << 8) & 0x300)], iq2s_grid[qs[4*i+1] | ((qh[i] << 6) & 0x300)]});
+ sh.apply_signs_1((uint8x16_t *)qx+2*i+0, signs128);
+ qx[2*i+1] = vreinterpretq_s8_u64(uint64x2_t{iq2s_grid[qs[4*i+2] | ((qh[i] << 4) & 0x300)], iq2s_grid[qs[4*i+3] | ((qh[i] << 2) & 0x300)]});
+ sh.apply_signs_1((uint8x16_t *)qx+2*i+1, signs128);
+ }
+ qs += 16; qh += 4;
+ auto s32_1 = vmovl_u16(vget_low_u16 (scales16.val[ib]));
+ auto s32_2 = vmovl_u16(vget_high_u16(scales16.val[ib]));
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto y = vld1q_s8_x2(q8.y[iy][ibl].qs + 128*is + 32*ib);
+ auto sumi1 = vpaddq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[0], y.val[0]), ggml_vdotq_s32(vdupq_n_s32(0), qx[1], y.val[1]));
+ auto sumi2 = vpaddq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[2], y.val[0]), ggml_vdotq_s32(vdupq_n_s32(0), qx[3], y.val[1]));
+ auto sumi3 = vpaddq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[4], y.val[0]), ggml_vdotq_s32(vdupq_n_s32(0), qx[5], y.val[1]));
+ auto sumi4 = vpaddq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[6], y.val[0]), ggml_vdotq_s32(vdupq_n_s32(0), qx[7], y.val[1]));
+ auto sumi12 = vpaddq_s32(sumi1, sumi2); // blocks 0,1,2,3 in rows 0,1
+ auto sumi34 = vpaddq_s32(sumi3, sumi4); // blocks 4,5,6,7 in rows 2,3
+ isum[2*iy+0] = vmlaq_s32(isum[2*iy+0], s32_1, sumi12);
+ isum[2*iy+1] = vmlaq_s32(isum[2*iy+1], s32_2, sumi34);
+ }
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto sumi = vpaddq_s32(isum[2*iy+0], isum[2*iy+1]);
+ acc[iy] = vfmaq_f32(acc[iy], vmulq_f32(d4, vdupq_n_f32(q8.scale(iy, ibl))), vcvtq_f32_s32(sumi));
+ isum[2*iy] = isum[2*iy+1] = vdupq_n_s32(0);
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ info.store(ix, iy, vmulq_f32(vdupq_n_f32(0.125f), acc[iy]));
+ acc[iy] = vdupq_n_f32(0.f);
+ }
+ }
+}
+
+template <int nrc_y>
+static void mul_mat_iq3_xxs_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ GGML_ASSERT(nrc_x%4 == 0);
+ Q8<nrc_y, block_q8_K> q8(info);
+ int nbl = n / QK_K;
+ float32x4_t acc[nrc_y] = {};
+ int32x4_t isum[nrc_y] = {};
+ int8x16_t qx[8];
+ SignHelper sh;
+ for (int ix = 0; ix < nrc_x; ix += 4) {
+ auto iq3 = (const block_iq3_xxs_r4 *)((const char *)vx + (ix+0)*bx);
+ for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256
+ auto d4 = vmulq_f32(vdupq_n_f32(0.25f), vcvt_f32_f16(vld1_f16((const float16_t *)iq3[ibl].d)));
+ auto qs = iq3[ibl].qs;
+ for (int ib = 0; ib < QK_K/32; ++ib) {
+ auto sas = vld1q_u8(iq3[ibl].sas + 16*ib);
+ auto scale_bits = vandq_u8(sas, vdupq_n_u8(1));
+ auto scales = ggml_vdotq_s32(vdupq_n_s32(1), scale_bits, vreinterpretq_s8_u32(vdupq_n_u32(0x10080402)));
+ auto signs128 = vandq_u8(sas, vdupq_n_u8(254));
+ signs128 = veorq_u8(signs128, vshrq_n_u8(signs128, 1));
+ sh.init();
+ for (int i = 0; i < 8; ++i) {
+ qx[i] = vreinterpretq_s8_u32(uint32x4_t{iq3xxs_grid[qs[4*i+0]], iq3xxs_grid[qs[4*i+1]], iq3xxs_grid[qs[4*i+2]], iq3xxs_grid[qs[4*i+3]]});
+ sh.apply_signs_1((uint8x16_t *)qx+i, signs128);
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto y = vld1q_s8_x2(q8.y[iy][ibl].qs + 32*ib);
+ auto sumi1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[0], y.val[0]), qx[1], y.val[1]);
+ auto sumi2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[2], y.val[0]), qx[3], y.val[1]);
+ auto sumi3 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[4], y.val[0]), qx[5], y.val[1]);
+ auto sumi4 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[6], y.val[0]), qx[7], y.val[1]);
+ auto sumi12 = vpaddq_s32(sumi1, sumi2);
+ auto sumi34 = vpaddq_s32(sumi3, sumi4);
+ auto sumi = vpaddq_s32(sumi12, sumi34);
+ isum[iy] = vmlaq_s32(isum[iy], scales, sumi);
+ }
+ qs += 32;
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ acc[iy] = vfmaq_f32(acc[iy], vmulq_f32(d4, vdupq_n_f32(q8.scale(iy, ibl))), vcvtq_f32_s32(isum[iy]));
+ isum[iy] = vdupq_n_s32(0);
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ info.store(ix, iy, acc[iy]);
+ acc[iy] = vdupq_n_f32(0.f);
+ }
+ }
+}
+
+template <int nrc_y>
+static void mul_mat_iq3_s_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ GGML_ASSERT(nrc_x%4 == 0);
+ Q8<nrc_y, block_q8_K> q8(info);
+ int nbl = n / QK_K;
+ float32x4_t acc[nrc_y] = {};
+ int32x4_t isum[nrc_y] = {};
+ int8x16_t qx[8];
+ auto m1 = vdupq_n_u8(1);
+ auto shuff = vreinterpretq_u8_u32(uint32x4_t{0xffffff00, 0xffffff01, 0xffffff02, 0xffffff03});
+ uint32_t stored_scales[8];
+ for (int ix = 0; ix < nrc_x; ix += 4) {
+ auto iq3 = (const block_iq3_s_r4 *)((const char *)vx + (ix+0)*bx);
+ for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256
+ auto d4 = vcvt_f32_f16(vld1_f16((const float16_t *)iq3[ibl].d));
+ auto qs = iq3[ibl].qs;
+ auto qh = iq3[ibl].qh;
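+            // 4-bit block scales, mapped to 2*s+1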
+ auto scale_bits = vld1q_u8(iq3[ibl].scales);
+ uint8x16x2_t scales8 = { vandq_u8(scale_bits, vdupq_n_u8(0xf)), vshrq_n_u8(scale_bits, 4) };
+ scales8.val[0] = vorrq_u8(vshlq_n_u8(scales8.val[0], 1), m1);
+ scales8.val[1] = vorrq_u8(vshlq_n_u8(scales8.val[1], 1), m1);
+ vst1q_u8_x2((uint8_t *)stored_scales, scales8);
+ for (int ib = 0; ib < QK_K/32; ++ib) {
+ auto signs128 = vld1q_u8(iq3[ibl].signs+16*ib);
+ if constexpr (nrc_y == 1) {
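+                    // Single right-hand column: build the 9-bit iq3s_grid indices with NEON shifts/ORs instead of the scalar index arithmetic used in the multi-column branch below.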
+ auto qh32 = (const uint32_t *)qh;
+ auto idx_h = vreinterpretq_u16_u64(vshlq_u64(vreinterpretq_u64_u16(vmovl_u8(vreinterpret_u8_u32(vdup_n_u32(qh32[0])))), int64x2_t{8, 4}));
+ union { uint16x8_t vec; uint16_t val[8]; } hidx;
+ for (int i = 0; i < 4; ++i) {
+ auto idx_l = vmovl_u8(vld1_u8(qs));
+ hidx.vec = vorrq_u16(idx_l, vandq_u16(idx_h, vdupq_n_u16(0x100))); idx_h = vshrq_n_u16(idx_h, 1);
+ qx[2*i+0] = vreinterpretq_s8_u32(uint32x4_t{iq3s_grid[hidx.val[0]], iq3s_grid[hidx.val[1]], iq3s_grid[hidx.val[2]], iq3s_grid[hidx.val[3]]});
+ auto signs = vreinterpretq_s8_u8(vorrq_u8(vceqq_u8(vandq_u8(signs128, m1), m1), m1));
+ qx[2*i+0] = vmulq_s8(qx[2*i+0], signs);
+ qx[2*i+1] = vreinterpretq_s8_u32(uint32x4_t{iq3s_grid[hidx.val[4]], iq3s_grid[hidx.val[5]], iq3s_grid[hidx.val[6]], iq3s_grid[hidx.val[7]]});
+ signs = vreinterpretq_s8_u8(vorrq_u8(vceqq_u8(vandq_u8(vshrq_n_u8(signs128, 4), m1), m1), m1));
+ qx[2*i+1] = vmulq_s8(qx[2*i+1], signs);
+ signs128 = vshrq_n_u8(signs128, 1);
+ qs += 8;
+ }
+ } else {
+ for (int i = 0; i < 4; ++i) {
+ qx[2*i+0] = vreinterpretq_s8_u32(uint32x4_t{iq3s_grid[qs[0] | ((qh[0] << (8-i)) & 0x100)], iq3s_grid[qs[1] | ((qh[1] << (8-i)) & 0x100)],
+ iq3s_grid[qs[2] | ((qh[2] << (8-i)) & 0x100)], iq3s_grid[qs[3] | ((qh[3] << (8-i)) & 0x100)]});
+ auto signs = vreinterpretq_s8_u8(vorrq_u8(vceqq_u8(vandq_u8(signs128, m1), m1), m1));
+ qx[2*i+0] = vmulq_s8(qx[2*i+0], signs);
+
+ qx[2*i+1] = vreinterpretq_s8_u32(uint32x4_t{iq3s_grid[qs[4] | ((qh[0] << (4-i)) & 0x100)], iq3s_grid[qs[5] | ((qh[1] << (4-i)) & 0x100)],
+ iq3s_grid[qs[6] | ((qh[2] << (4-i)) & 0x100)], iq3s_grid[qs[7] | ((qh[3] << (4-i)) & 0x100)]});
+ signs = vreinterpretq_s8_u8(vorrq_u8(vceqq_u8(vandq_u8(vshrq_n_u8(signs128, 4), m1), m1), m1));
+ qx[2*i+1] = vmulq_s8(qx[2*i+1], signs);
+
+ qs += 8;
+ signs128 = vshrq_n_u8(signs128, 1);
+ }
+ }
+ auto scales = vreinterpretq_s32_u8(vqtbl1q_u8(vreinterpretq_u8_u32(vdupq_n_u32(stored_scales[ib])), shuff));
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto y = vld1q_s8_x2(q8.y[iy][ibl].qs + 32*ib);
+ auto sumi = interleaved_dotq(qx, y);
+ isum[iy] = vmlaq_s32(isum[iy], scales, sumi);
+ }
+ qh += 4;
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ acc[iy] = vfmaq_f32(acc[iy], vmulq_f32(d4, vdupq_n_f32(q8.scale(iy, ibl))), vcvtq_f32_s32(isum[iy]));
+ isum[iy] = vdupq_n_s32(0);
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ info.store(ix, iy, acc[iy]);
+ acc[iy] = vdupq_n_f32(0.f);
+ }
+ }
+}
+
+}
+
+bool iqk_set_kernels_iquants(int ne00, int typeA, int typeB, std::array<mul_mat_t, IQK_MAX_NY>& kernels, mul_mat_t& func16) {
+
+ if (ne00%QK_K != 0 || ggml_type(typeB) != GGML_TYPE_Q8_K) {
+ return false;
+ }
+
+ func16 = nullptr;
+
+ switch (typeA) {
+ case GGML_TYPE_IQ2_XXS:
+ IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_K_q8_K_T, DequantizerIQ2XXS, kernels);
+ break;
+ case GGML_TYPE_IQ2_XS:
+ IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_K_q8_K_T, DequantizerIQ2XS, kernels);
+ break;
+ case GGML_TYPE_IQ2_S:
+ IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_K_q8_K_T, DequantizerIQ2S, kernels);
+ break;
+ case GGML_TYPE_IQ3_XXS:
+ IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_K_q8_K_T, DequantizerIQ3XXS, kernels);
+ break;
+ case GGML_TYPE_IQ3_S:
+ IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_K_q8_K_T, DequantizerIQ3S, kernels);
+ break;
+ case GGML_TYPE_IQ2_XXS_R4:
+ IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq2_xxs_r4_q8_k, kernels);
+ func16 = mul_mat_iq2_xxs_r4_q8_k<16>;
+ break;
+ case GGML_TYPE_IQ2_XS_R4:
+ IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq2_xs_r4_q8_k, kernels);
+ func16 = mul_mat_iq2_xs_r4_q8_k<16>;
+ break;
+ case GGML_TYPE_IQ2_S_R4:
+ IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq2_s_r4_q8_k, kernels);
+ func16 = mul_mat_iq2_s_r4_q8_k<16>;
+ break;
+ case GGML_TYPE_IQ3_XXS_R4:
+ IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq3_xxs_r4_q8_k, kernels);
+ func16 = mul_mat_iq3_xxs_r4_q8_k<16>;
+ break;
+ case GGML_TYPE_IQ3_S_R4:
+ IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq3_s_r4_q8_k, kernels);
+ func16 = mul_mat_iq3_s_r4_q8_k<16>;
+ break;
+ default:
+ return false;
+ }
+
+ return true;
+
+}
+
+#endif
+
+#endif
diff --git a/ggml/src/iqk/iqk_gemm_iquants.h b/ggml/src/iqk/iqk_gemm_iquants.h
new file mode 100644
index 00000000..4182526a
--- /dev/null
+++ b/ggml/src/iqk/iqk_gemm_iquants.h
@@ -0,0 +1,11 @@
+#pragma once
+
+#include "iqk_common.h"
+
+#ifdef IQK_IMPLEMENT
+
+#include <array>
+
+bool iqk_set_kernels_iquants(int ne00, int typeA, int typeB, std::array<mul_mat_t, IQK_MAX_NY>& kernels, mul_mat_t& func16);
+
+#endif
diff --git a/ggml/src/iqk/iqk_gemm_kquants.cpp b/ggml/src/iqk/iqk_gemm_kquants.cpp
new file mode 100644
index 00000000..dfbff710
--- /dev/null
+++ b/ggml/src/iqk/iqk_gemm_kquants.cpp
@@ -0,0 +1,3121 @@
+#include "iqk_gemm_kquants.h"
+
+#ifdef IQK_IMPLEMENT
+
+#include "ggml-impl.h"
+
+#define GGML_COMMON_IMPL_C
+#include "ggml-common.h"
+
+#ifdef __x86_64__
+
+namespace {
+
+// Handles q4_K and q5_K scales/mins
+struct Scales8K {
+ template <typename Q8>
+ inline __m256i process_mins_and_scales(const uint8_t * data, float c, int i, const Q8& q8, __m256 * accd) {
+ make_q4_scales(data, utmp);
+ const __m256i mins_and_scales = _mm256_cvtepu8_epi16(_mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0]));
+ const __m128i mins128 = _mm256_extracti128_si256(mins_and_scales, 1);
+ accum_mins(mins128, q8, i, c, accd);
+ const __m128i sc128 = _mm256_extracti128_si256(mins_and_scales, 0);
+ return MM256_SET_M128I(sc128, sc128);
+ }
+#ifdef HAVE_FANCY_SIMD
+ template <typename Q8>
+ inline __m512i process_mins_and_scales_64(const uint8_t * data, float c, int i, const Q8& q8, __m256 * accd) {
+ auto scales = process_mins_and_scales(data, c, i, q8, accd);
+ return _mm512_inserti32x8(_mm512_castsi256_si512(scales), scales, 1);
+ }
+#endif
+ template <typename Q8>
+ inline void accum_mins(const __m128i& mins128, const Q8& q8, int i, float c, __m256 * accd) const {
+ base.accum_mins(mins128, q8, i, c, accd);
+ }
+#ifdef HAVE_FANCY_SIMD
+ const __m512i shuffles512[2] = {
+ _mm512_set_epi64(0x0706070607060706, 0x0302030203020302, 0x0706070607060706, 0x0302030203020302,
+ 0x0504050405040504, 0x0100010001000100, 0x0504050405040504, 0x0100010001000100),
+ _mm512_set_epi64(0x0f0e0f0e0f0e0f0e, 0x0b0a0b0a0b0a0b0a, 0x0f0e0f0e0f0e0f0e, 0x0b0a0b0a0b0a0b0a,
+ 0x0d0c0d0c0d0c0d0c, 0x0908090809080908, 0x0d0c0d0c0d0c0d0c, 0x0908090809080908)
+ };
+#endif
+ Scales8KBase base;
+
+ uint32_t utmp[4];
+};
+
+template <typename Q8>
+inline void process_mins_16(const __m256i& all_scales, const Q8& q8, int i, float d, __m256 * accm) {
+ for (int iy = 0; iy < Q8::nrc_y; ++iy) {
+ const __m256i prod = _mm256_madd_epi16(all_scales, q8.load_bsums(iy, i));
+ accm[iy] = _mm256_fmadd_ps(_mm256_set1_ps(d * q8.scale(iy, i)), _mm256_cvtepi32_ps(prod), accm[iy]);
+ }
+}
+inline void prepare_scales_16(const __m256i& all_scales, __m256i * scales) {
+ const __m128i l_scales = _mm256_extracti128_si256(all_scales, 0);
+ const __m128i h_scales = _mm256_extracti128_si256(all_scales, 1);
+ scales[0] = MM256_SET_M128I(l_scales, l_scales);
+ scales[1] = MM256_SET_M128I(h_scales, h_scales);
+}
+
+// Handles q3_K scales
+struct ScaleQ3 {
+ inline __m128i make_scales(const uint16_t * s8) const {
+ const uint16_t * scales16 = (const uint16_t *)s8;
+ uint32_t aux0 = scales16[0] | (scales16[1] << 16);
+ uint32_t aux1 = scales16[2] | (scales16[3] << 16);
+ uint32_t aux2 = scales16[4] | (scales16[5] << 16);
+ __m128i scales128 = _mm_set_epi32(
+ ((aux1 >> 4) & 0x0f0f0f0f) | ((aux2 >> 2) & 0x30303030),
+ ((aux0 >> 4) & 0x0f0f0f0f) | ((aux2 >> 0) & 0x30303030),
+ (aux1 & 0x0f0f0f0f) | ((aux2 << 2) & 0x30303030),
+ (aux0 & 0x0f0f0f0f) | ((aux2 << 4) & 0x30303030));
+ return _mm_add_epi8(scales128, m32);
+ }
+ const __m128i m32 = _mm_set1_epi8(-32);
+};
+
+struct Scale16 {
+ inline void make_scales(const __m128i& scales8, __m512i * scales) const {
+ auto all_scales8 = MM256_SET_M128I(scales8, scales8);
+ auto scales1 = _mm256_shuffle_epi8(all_scales8, shuffle1);
+ auto scales2 = _mm256_shuffle_epi8(all_scales8, shuffle2);
+ scales[0] = _mm512_cvtepi8_epi16(scales1);
+ scales[1] = _mm512_cvtepi8_epi16(scales2);
+ }
+ template <typename Q8>
+ inline void process_mins_and_scales(int i, float c, const __m128i& mins8, const __m128i& scales8,
+ const Q8& q8, __m256 * accm, __m512i * scales) const {
+ process_mins_16(_mm256_cvtepi8_epi16(mins8), q8, i, c, accm);
+ make_scales(scales8, scales);
+ }
+ const __m256i shuffle1 = _mm256_set_epi32(0x07070707, 0x03030303, 0x06060606, 0x02020202,
+ 0x05050505, 0x01010101, 0x04040404, 0x00000000);
+ const __m256i shuffle2 = _mm256_set_epi32(0x0f0f0f0f, 0x0b0b0b0b, 0x0e0e0e0e, 0x0a0a0a0a,
+ 0x0d0d0d0d, 0x09090909, 0x0c0c0c0c, 0x08080808);
+};
+
+template <typename Q8>
+inline void process_mins_and_scales_16(const __m128i& scales128, const Q8& q8, int i, float d,
+ __m256 * accm, __m256i * scales) {
+ const __m256i all_scales = _mm256_cvtepi8_epi16(scales128);
+ process_mins_16(all_scales, q8, i, d, accm);
+ prepare_scales_16(all_scales, scales);
+}
+
+inline __m256i get_scale_shuffle_8(int i) {
+ return _mm256_set1_epi16((2*i) | ((2*i+1) << 8));
+}
+
+inline void set_scales_8(const __m256i& all_scales, int j, __m256i * scales) {
+ scales[0] = _mm256_shuffle_epi8(all_scales, get_scale_shuffle_8(4*j+0));
+ scales[1] = _mm256_shuffle_epi8(all_scales, get_scale_shuffle_8(4*j+1));
+ scales[2] = _mm256_shuffle_epi8(all_scales, get_scale_shuffle_8(4*j+2));
+ scales[3] = _mm256_shuffle_epi8(all_scales, get_scale_shuffle_8(4*j+3));
+}
+
+inline __m256i get_scale_shuffle_16(int i) {
+ static const uint8_t k_shuffle[128] = {
+ 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3,
+ 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7,
+ 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,
+ 12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13, 14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,
+ };
+ return _mm256_loadu_si256((const __m256i*)k_shuffle + i);
+}
+
+inline void set_scales_16(const __m256i& all_scales, __m256i * scales) {
+ scales[0] = _mm256_shuffle_epi8(all_scales, get_scale_shuffle_16(0));
+ scales[1] = _mm256_shuffle_epi8(all_scales, get_scale_shuffle_16(1));
+ scales[2] = _mm256_shuffle_epi8(all_scales, get_scale_shuffle_16(2));
+ scales[3] = _mm256_shuffle_epi8(all_scales, get_scale_shuffle_16(3));
+}
+
+struct ScaleIQ4XS {
+ inline __m128i make_scales(const uint32_t scales_l, const uint16_t scales_h) {
+ uint32_t tmp32 = scales_h | (scales_h << 14);
+ const __m128i sh = _mm_slli_epi16(_mm_and_si128(_mm_srlv_epi32(_mm_set1_epi32(tmp32), hshift), hmask), 4);
+ const __m128i sl = _mm_and_si128(_mm_srlv_epi32(_mm_set1_epi32(scales_l), lshift), lmask);
+ return _mm_add_epi16(_mm_or_si128(sh, _mm_cvtepi8_epi16(_mm_shuffle_epi8(sl, lshuffle))), m32);
+ }
+ const __m128i hshift = _mm_set_epi32(12, 8, 4, 0);
+ const __m128i lshift = _mm_set_epi32(4, 0, 4, 0);
+ const __m128i hmask = _mm_set1_epi16(0x03);
+ const __m128i lmask = _mm_set1_epi8(0xf);
+ const __m128i lshuffle = _mm_set_epi32(0x07030602, 0x05010400, 0x07030602, 0x05010400);
+ const __m128i m32 = _mm_set1_epi16(-32);
+};
+
+#ifdef HAVE_FANCY_SIMD
+//====================================== Zen4 ==================================================
+
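+// Injects the 5th bit from qh into the 4-bit q5_K quants.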
+struct HighBit5 {
+ inline void apply(const uint8_t * h, Q4Bits& bits) {
+ auto hbits256 = _mm256_loadu_si256((const __m256i *)h);
+ auto hbits = _mm512_inserti32x8(_mm512_castsi256_si512(hbits256), _mm256_srli_epi16(hbits256, 1), 1);
+ bits.values[0] = _mm512_or_si512(bits.values[0], _mm512_and_si512(_mm512_slli_epi16(hbits, 4), mh));
+ bits.values[1] = _mm512_or_si512(bits.values[1], _mm512_and_si512(_mm512_slli_epi16(hbits, 2), mh));
+ bits.values[2] = _mm512_or_si512(bits.values[2], _mm512_and_si512(hbits, mh));
+ bits.values[3] = _mm512_or_si512(bits.values[3], _mm512_and_si512(_mm512_srli_epi16(hbits, 2), mh));
+ }
+ const __m512i mh = _mm512_set1_epi8(0x10);
+};
+
+struct HighBit3 {
+ inline void apply(const uint8_t * h, Q2Bits& bits) {
+ auto hbits256 = _mm256_loadu_si256((const __m256i *)h);
+ auto hbits = _mm512_inserti32x8(_mm512_castsi256_si512(hbits256), _mm256_srli_epi16(hbits256, 1), 1);
+ bits.values[0] = _mm512_or_si512(bits.values[0], _mm512_and_si512(_mm512_slli_epi16(hbits, 2), mh));
+ bits.values[1] = _mm512_or_si512(bits.values[1], _mm512_and_si512(hbits, mh));
+ bits.values[2] = _mm512_or_si512(bits.values[2], _mm512_and_si512(_mm512_srli_epi16(hbits, 2), mh));
+ bits.values[3] = _mm512_or_si512(bits.values[3], _mm512_and_si512(_mm512_srli_epi16(hbits, 4), mh));
+ }
+ const __m512i mh = _mm512_set1_epi8(0x04);
+};
+
+
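+// Processes one super-block: four 64-byte dpbusd dot products are packed to 16 bits, folded with the per-sub-block scales via dpwssd, and accumulated in float.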
+template <typename Q8>
+inline void compute_block(int iy, int i, float d, const Q8& q8, const __m512i * values, const __m512i * scales, __m512 * accd) {
+ const __m512i p1 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), values[0], q8.load_quants64(iy, i, 0));
+ const __m512i p2 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), values[1], q8.load_quants64(iy, i, 1));
+ const __m512i p3 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), values[2], q8.load_quants64(iy, i, 2));
+ const __m512i p4 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), values[3], q8.load_quants64(iy, i, 3));
+ auto sumi = _mm512_dpwssd_epi32(_mm512_setzero_si512(), scales[0], _mm512_packs_epi32(p1, p2));
+ sumi = _mm512_dpwssd_epi32(sumi, scales[1], _mm512_packs_epi32(p3, p4));
+ accd[iy] = _mm512_fmadd_ps(_mm512_set1_ps(d*q8.scale(iy, i)), _mm512_cvtepi32_ps(sumi), accd[iy]);
+}
+
+struct DequantizerQ2K final : public BaseDequantizer<block_q2_K> {
+ DequantizerQ2K(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}
+ template <typename Q8>
+ inline void new_block(int i, const Q8& q8, __m256 * accm, __m512i * scales) {
+ d = GGML_FP16_TO_FP32(x[i].d);
+ bits.prepare(x[i].qs);
+ const __m128i mins_and_scales = _mm_loadu_si128((const __m128i*)x[i].scales);
+ const __m128i scales8 = _mm_and_si128(mins_and_scales, m4);
+ const __m128i mins8 = _mm_and_si128(_mm_srli_epi16(mins_and_scales, 4), m4);
+ sc16.process_mins_and_scales(i, -GGML_FP16_TO_FP32(x[i].dmin), mins8, scales8, q8, accm, scales);
+ }
+
+ Q2Bits bits;
+ Scale16 sc16;
+ const __m128i m4 = _mm_set1_epi8(0xf);
+
+};
+
+struct DequantizerQ3K final : public BaseDequantizer<block_q3_K> {
+ DequantizerQ3K(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}
+ template <typename Q8>
+ inline void new_block(int i, const Q8& q8, __m256 * accm, __m512i * scales) {
+ d = GGML_FP16_TO_FP32(x[i].d);
+ bits.prepare(x[i].qs);
+ hbits.apply(x[i].hmask, bits);
+ auto scales128 = sc3.make_scales((const uint16_t *)x[i].scales);
+ sc16.process_mins_and_scales(i, -4.f*d, scales128, scales128, q8, accm, scales);
+ }
+
+ Q2Bits bits;
+ HighBit3 hbits;
+ ScaleQ3 sc3;
+ Scale16 sc16;
+ const __m128i m4 = _mm_set1_epi8(0xf);
+ const __m128i m32 = _mm_set1_epi8(-32);
+};
+
+struct DequantizerQ4K final : public BaseDequantizer<block_q4_K> {
+ DequantizerQ4K(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}
+ template <typename Q8>
+ inline void new_block(int i, const Q8& q8, __m256 * accd, __m512i * scales) {
+ d = GGML_FP16_TO_FP32(x[i].d);
+ bits.prepare(x[i].qs);
+ auto all_scales = s8k.process_mins_and_scales_64(x[i].scales, -GGML_FP16_TO_FP32(x[i].dmin), i, q8, accd);
+ scales[0] = _mm512_shuffle_epi8(all_scales, s8k.shuffles512[0]);
+ scales[1] = _mm512_shuffle_epi8(all_scales, s8k.shuffles512[1]);
+ }
+
+ Q4Bits bits;
+ Scales8K s8k;
+};
+
+struct DequantizerQ5K final : public BaseDequantizer<block_q5_K> {
+ DequantizerQ5K(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}
+ template <typename Q8>
+ inline void new_block(int i, const Q8& q8, __m256 * accd, __m512i * scales) {
+ d = GGML_FP16_TO_FP32(x[i].d);
+ bits.prepare(x[i].qs);
+ hbits.apply(x[i].qh, bits);
+ auto all_scales = s8k.process_mins_and_scales_64(x[i].scales, -GGML_FP16_TO_FP32(x[i].dmin), i, q8, accd);
+ scales[0] = _mm512_shuffle_epi8(all_scales, s8k.shuffles512[0]);
+ scales[1] = _mm512_shuffle_epi8(all_scales, s8k.shuffles512[1]);
+ }
+
+ Q4Bits bits;
+ HighBit5 hbits;
+ Scales8K s8k;
+};
+
+struct DequantizerQ6K final : public BaseDequantizer<block_q6_K> {
+ DequantizerQ6K(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}
+ template <typename Q8>
+ inline void new_block(int i, const Q8& q8, __m256 * accm, __m512i * scales) {
+ d = GGML_FP16_TO_FP32(x[i].d);
+ bits.prepare64(x[i].ql);
+ add_high_bits(x[i].qh, bits);
+ auto scales128 = _mm_loadu_si128((const __m128i *)x[i].scales);
+ sc16.process_mins_and_scales(i, -32.f*d, scales128, scales128, q8, accm, scales);
+ }
+
+ inline void add_high_bits(const uint8_t * qh, Q4Bits& bits) const {
+ auto hbits = _mm512_loadu_si512((const __m512i *)qh);
+ auto tmp1 = _mm512_and_si512(_mm512_slli_epi16(hbits, 4), mh);
+ auto tmp2 = _mm512_and_si512(_mm512_slli_epi16(hbits, 2), mh);
+ bits.values[0] = _mm512_or_si512(bits.values[0], _mm512_permutex2var_epi64(tmp1, bits.perm.permute1, tmp2));
+ bits.values[2] = _mm512_or_si512(bits.values[2], _mm512_permutex2var_epi64(tmp1, bits.perm.permute2, tmp2));
+ tmp1 = _mm512_and_si512(hbits, mh);
+ tmp2 = _mm512_and_si512(_mm512_srli_epi16(hbits, 2), mh);
+ bits.values[1] = _mm512_or_si512(bits.values[1], _mm512_permutex2var_epi64(tmp1, bits.perm.permute1, tmp2));
+ bits.values[3] = _mm512_or_si512(bits.values[3], _mm512_permutex2var_epi64(tmp1, bits.perm.permute2, tmp2));
+ }
+
+ Q4Bits bits;
+ HighBit3 hbits;
+ Scale16 sc16;
+
+ const __m512i mh = _mm512_set1_epi8(0x30);
+
+};
+
+struct DequantizerIQ4XS final : public BaseDequantizer<block_iq4_xs> {
+ DequantizerIQ4XS(const void * vx, size_t bx) : BaseDequantizer(vx, bx), values(load_iq4nl_values_512()) {}
+ template <typename Q8>
+ inline void new_block(int i, const Q8& q8, __m256 * accd, __m512i * scales) {
+ d = GGML_FP16_TO_FP32(x[i].d);
+ prepare(x[i].qs);
+ auto scales128 = siq4.make_scales(*(const uint32_t *)x[i].scales_l, x[i].scales_h);
+ s8k.accum_mins(scales128, q8, i, -128.f*d, accd);
+ auto scales256 = MM256_SET_M128I(scales128, scales128);
+ auto all_scales = _mm512_inserti32x8(_mm512_castsi256_si512(scales256), scales256, 1);
+ scales[0] = _mm512_shuffle_epi8(all_scales, shuffles[0]);
+ scales[1] = _mm512_shuffle_epi8(all_scales, shuffles[1]);
+ scales[2] = _mm512_shuffle_epi8(all_scales, shuffles[2]);
+ scales[3] = _mm512_shuffle_epi8(all_scales, shuffles[3]);
+ }
+ inline void prepare(const uint8_t * q4) {
+ bits.prepare64(q4);
+        // We now have in bits.values[0]: 0...15, 32...47, 64...79, 96...111
+        //                bits.values[1]: 16..31, 48...63, 80...95, 112..127
+ // etc.
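+        // The permutes below regroup the 64-bit lanes so that values[0] holds quants 0...63 and values[1] quants 64...127 in natural order (values[2]/values[3] likewise for the second half) before the iq4nl lookup.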
+ auto tmp = _mm512_permutex2var_epi64(bits.values[0], permute1, bits.values[1]);
+ bits.values[1] = _mm512_shuffle_epi8(values, _mm512_permutex2var_epi64(bits.values[0], permute2, bits.values[1]));
+ bits.values[0] = _mm512_shuffle_epi8(values, tmp);
+ tmp = _mm512_permutex2var_epi64(bits.values[2], permute1, bits.values[3]);
+ bits.values[3] = _mm512_shuffle_epi8(values, _mm512_permutex2var_epi64(bits.values[2], permute2, bits.values[3]));
+ bits.values[2] = _mm512_shuffle_epi8(values, tmp);
+ }
+
+ Q4Bits bits;
+ Scales8KBase s8k;
+ ScaleIQ4XS siq4;
+ const __m512i values;
+ const __m512i permute1 = _mm512_set_epi64(11, 10, 3, 2, 9, 8, 1, 0);
+ const __m512i permute2 = _mm512_set_epi64(15, 14, 7, 6, 13, 12, 5, 4);
+ const __m512i shuffles[4] = {
+ _mm512_inserti32x8(_mm512_set1_epi16(0x0100), _mm256_set1_epi16(0x0302), 1),
+ _mm512_inserti32x8(_mm512_set1_epi16(0x0504), _mm256_set1_epi16(0x0706), 1),
+ _mm512_inserti32x8(_mm512_set1_epi16(0x0908), _mm256_set1_epi16(0x0b0a), 1),
+ _mm512_inserti32x8(_mm512_set1_epi16(0x0d0c), _mm256_set1_epi16(0x0f0e), 1),
+ };
+};
+
+template <typename Dequantizer>
+static void mul_mat_qX_K_q8_K_AVX512_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ assert(n % QK_K == 0);
+ const int nb = n / QK_K;
+
+ constexpr int k_nx = 2;
+
+ Q8<1> q8(info);
+
+ Dequantizer deq1(vx, bx);
+ Dequantizer deq2(vx, bx);
+
+ Dequantizer * deq[k_nx];
+ deq[0] = &deq1;
+ deq[1] = &deq2;
+
+ __m512i scales[2*k_nx];
+
+ for (int ix = 0; ix < nrc_x; ++ix) {
+
+ auto accd = _mm512_setzero_ps();
+ auto accm = _mm256_setzero_ps();
+
+ for (int kx = 0; kx < k_nx; ++kx) deq[kx]->new_row(ix);
+
+ for (int i = 0; i < nb/k_nx; ++i) {
+
+ for (int kx = 0; kx < k_nx; ++kx) deq[kx]->new_block(k_nx*i+kx, q8, &accm, scales+2*kx);
+
+ for (int kx = 0; kx < k_nx; ++kx) {
+ compute_block(0, k_nx*i+kx, deq[kx]->d, q8, deq[kx]->bits.values, scales+2*kx, &accd);
+ }
+
+ }
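+        // nb is odd: process the leftover super-block with the first dequantizer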
+ if (2*(nb/2) < nb) {
+ int i0 = 2*(nb/2);
+ deq[0]->new_block(i0, q8, &accm, scales);
+ compute_block(0, i0, deq[0]->d, q8, deq[0]->bits.values, scales, &accd);
+ }
+
+ auto sum256 = _mm256_add_ps(_mm512_castps512_ps256(accd), _mm512_extractf32x8_ps(accd, 1));
+ info.store(ix, 0, hsum_float_8(_mm256_add_ps(accm, sum256)));
+ }
+}
+
+template <typename Dequantizer, int nrc_y>
+static void mul_mat_qX_K_q8_K_AVX512(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ assert(n % QK_K == 0);
+ const int nb = n / QK_K;
+
+ Q8<nrc_y> q8(info);
+
+ Dequantizer deq(vx, bx);
+
+ __m256 accm[nrc_y];
+ __m512 accd[nrc_y];
+ __m512i scales[2];
+
+ for (int ix = 0; ix < nrc_x; ++ix) {
+
+ for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm512_setzero_ps();
+ for (int iy = 0; iy < nrc_y; ++iy) accm[iy] = _mm256_setzero_ps();
+
+ deq.new_row(ix);
+
+ for (int i = 0; i < nb; ++i) {
+
+ deq.new_block(i, q8, accm, scales);
+
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ const __m512i p1 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), deq.bits.values[0], q8.load_quants64(iy, i, 0));
+ const __m512i p2 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), deq.bits.values[1], q8.load_quants64(iy, i, 1));
+ const __m512i p3 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), deq.bits.values[2], q8.load_quants64(iy, i, 2));
+ const __m512i p4 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), deq.bits.values[3], q8.load_quants64(iy, i, 3));
+ auto sumi = _mm512_dpwssd_epi32(_mm512_setzero_si512(), scales[0], _mm512_packs_epi32(p1, p2));
+ sumi = _mm512_dpwssd_epi32(sumi, scales[1], _mm512_packs_epi32(p3, p4));
+ accd[iy] = _mm512_fmadd_ps(_mm512_set1_ps(deq.d*q8.scale(iy, i)), _mm512_cvtepi32_ps(sumi), accd[iy]);
+ }
+
+ }
+
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto sum256 = _mm256_add_ps(_mm512_castps512_ps256(accd[iy]), _mm512_extractf32x8_ps(accd[iy], 1));
+ info.store(ix, iy, hsum_float_8(_mm256_add_ps(accm[iy], sum256)));
+ }
+
+ }
+}
+
+template <typename Dequantizer, int nrc_y>
+static void mul_mat_iqX_k_q8_K_AVX512(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ assert(n % QK_K == 0);
+ const int nb = n / QK_K;
+
+ Q8<nrc_y> q8(info);
+
+ Dequantizer deq(vx, bx);
+
+ __m256 accm[nrc_y];
+ __m512 accd[nrc_y];
+ __m512i scales[4];
+
+ for (int ix = 0; ix < nrc_x; ++ix) {
+
+ for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm512_setzero_ps();
+ for (int iy = 0; iy < nrc_y; ++iy) accm[iy] = _mm256_setzero_ps();
+
+ deq.new_row(ix);
+
+ for (int i = 0; i < nb; ++i) {
+
+ deq.new_block(i, q8, accm, scales);
+
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ const __m512i p1 = _mm512_maddubs_epi16(deq.bits.values[0], q8.load_quants64(iy, i, 0));
+ const __m512i p2 = _mm512_maddubs_epi16(deq.bits.values[1], q8.load_quants64(iy, i, 1));
+ const __m512i p3 = _mm512_maddubs_epi16(deq.bits.values[2], q8.load_quants64(iy, i, 2));
+ const __m512i p4 = _mm512_maddubs_epi16(deq.bits.values[3], q8.load_quants64(iy, i, 3));
+ auto sumi = _mm512_dpwssd_epi32(_mm512_dpwssd_epi32(_mm512_dpwssd_epi32(_mm512_dpwssd_epi32(_mm512_setzero_si512(),
+ p1, scales[0]), p2, scales[1]), p3, scales[2]), p4, scales[3]);
+ accd[iy] = _mm512_fmadd_ps(_mm512_set1_ps(deq.d*q8.scale(iy, i)), _mm512_cvtepi32_ps(sumi), accd[iy]);
+ }
+
+ }
+
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto sum256 = _mm256_add_ps(_mm512_castps512_ps256(accd[iy]), _mm512_extractf32x8_ps(accd[iy], 1));
+ info.store(ix, iy, hsum_float_8(_mm256_add_ps(accm[iy], sum256)));
+ }
+
+ }
+}
+
+#else
+//====================================== AVX2 ==================================================
+
+struct HighBit5 {
+ inline void load(const uint8_t * h) { hbits = _mm256_loadu_si256((const __m256i *)h); }
+ inline void apply(Q4Bits& bits, bool do_shift) {
+ bits.values[0] = _mm256_or_si256(bits.values[0], _mm256_and_si256(_mm256_slli_epi16(hbits, 4), mh));
+ bits.values[1] = _mm256_or_si256(bits.values[1], _mm256_and_si256(_mm256_slli_epi16(hbits, 3), mh));
+ bits.values[2] = _mm256_or_si256(bits.values[2], _mm256_and_si256(_mm256_slli_epi16(hbits, 2), mh));
+ bits.values[3] = _mm256_or_si256(bits.values[3], _mm256_and_si256(_mm256_slli_epi16(hbits, 1), mh));
+ if (do_shift) {
+ hbits = _mm256_srli_epi16(hbits, 4);
+ }
+ }
+ const __m256i mh = _mm256_set1_epi8(0x10);
+ __m256i hbits;
+};
+
+struct HighBit3 {
+ inline void load(const uint8_t * h) { hbits = _mm256_loadu_si256((const __m256i *)h); }
+ inline void apply(Q2Bits& bits, bool do_shift) {
+ bits.values[0] = _mm256_or_si256(bits.values[0], _mm256_and_si256(_mm256_slli_epi16(hbits, 2), mh));
+ bits.values[1] = _mm256_or_si256(bits.values[1], _mm256_and_si256(_mm256_slli_epi16(hbits, 1), mh));
+ bits.values[2] = _mm256_or_si256(bits.values[2], _mm256_and_si256(hbits, mh));
+ bits.values[3] = _mm256_or_si256(bits.values[3], _mm256_and_si256(_mm256_srli_epi16(hbits, 1), mh));
+ if (do_shift) {
+ hbits = _mm256_srli_epi16(hbits, 4);
+ }
+ }
+ const __m256i mh = _mm256_set1_epi8(0x04);
+ __m256i hbits;
+};
+
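+// Note: this appears to be a duplicate of the Zen4 compute_block helper above; it uses AVX512 intrinsics and is presumably not instantiated on the AVX2-only path.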
+template <typename Q8>
+inline void compute_block(int iy, int i, float d, const Q8& q8, const __m512i * values, const __m512i * scales, __m512 * accd) {
+ const __m512i p1 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), values[0], q8.load_quants64(iy, i, 0));
+ const __m512i p2 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), values[1], q8.load_quants64(iy, i, 1));
+ const __m512i p3 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), values[2], q8.load_quants64(iy, i, 2));
+ const __m512i p4 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), values[3], q8.load_quants64(iy, i, 3));
+ auto sumi = _mm512_dpwssd_epi32(_mm512_setzero_si512(), scales[0], _mm512_packs_epi32(p1, p2));
+ sumi = _mm512_dpwssd_epi32(sumi, scales[1], _mm512_packs_epi32(p3, p4));
+ accd[iy] = _mm512_fmadd_ps(_mm512_set1_ps(d*q8.scale(iy, i)), _mm512_cvtepi32_ps(sumi), accd[iy]);
+}
+
+struct DequantizerQ2K final : public BaseDequantizer<block_q2_K> {
+ DequantizerQ2K(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}
+
+ template <typename Q8>
+ inline void new_block(int i, const Q8& q8, __m256 * accm, __m256i * scales) {
+ d = GGML_FP16_TO_FP32(x[i].d);
+ const __m128i mins_and_scales = _mm_loadu_si128((const __m128i*)x[i].scales);
+ const __m128i scales8 = _mm_and_si128(mins_and_scales, m4);
+ const __m128i mins8 = _mm_and_si128(_mm_srli_epi16(mins_and_scales, 4), m4);
+ process_mins_16(_mm256_cvtepi8_epi16(mins8), q8, i, -GGML_FP16_TO_FP32(x[i].dmin), accm);
+ prepare_scales_16(_mm256_cvtepi8_epi16(scales8), scales);
+ }
+ inline void prepare(int i, int j) {
+ bits.prepare(x[i].qs, j);
+ }
+
+ Q2Bits bits;
+
+ const __m128i m4 = _mm_set1_epi8(0xf);
+};
+
+struct DequantizerQ3K final : public BaseDequantizer<block_q3_K> {
+ DequantizerQ3K(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}
+
+ template <typename Q8>
+ inline void new_block(int i, const Q8& q8, __m256 * accm, __m256i * scales) {
+ d = GGML_FP16_TO_FP32(x[i].d);
+ hbits.load(x[i].hmask);
+ process_mins_and_scales_16(sc3.make_scales((const uint16_t *)x[i].scales), q8, i, -4.f*d, accm, scales);
+ }
+ inline void prepare(int i, int j) {
+ bits.prepare(x[i].qs, j);
+ hbits.apply(bits, j == 0);
+ }
+
+ Q2Bits bits;
+ HighBit3 hbits;
+ ScaleQ3 sc3;
+
+ const __m128i m32 = _mm_set1_epi8(-32);
+};
+
+struct DequantizerQ4K final : public BaseDequantizer<block_q4_K> {
+ DequantizerQ4K(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}
+ template <typename Q8>
+ inline __m256i new_block(int i, const Q8& q8, __m256 * accd) {
+ d = GGML_FP16_TO_FP32(x[i].d);
+ return s8k.process_mins_and_scales(x[i].scales, -GGML_FP16_TO_FP32(x[i].dmin), i, q8, accd);
+ }
+ inline void prepare(int i, int j) {
+ bits.prepare(x[i].qs, j);
+ }
+
+ Q4Bits bits;
+ Scales8K s8k;
+};
+
+struct DequantizerQ5K final : public BaseDequantizer<block_q5_K> {
+ DequantizerQ5K(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}
+ template <typename Q8>
+ inline __m256i new_block(int i, const Q8& q8, __m256 * accd) {
+ d = GGML_FP16_TO_FP32(x[i].d);
+ hbits.load(x[i].qh);
+ return s8k.process_mins_and_scales(x[i].scales, -GGML_FP16_TO_FP32(x[i].dmin), i, q8, accd);
+ }
+ inline void prepare(int i, int j) {
+ bits.prepare(x[i].qs, j);
+ hbits.apply(bits, j == 0);
+ }
+
+ Q4Bits bits;
+ HighBit5 hbits;
+ Scales8K s8k;
+};
+
+struct DequantizerQ6K final : public BaseDequantizer<block_q6_K> {
+ DequantizerQ6K(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}
+ template <typename Q8>
+ inline void new_block(int i, const Q8& q8, __m256 * accm, __m256i * scales) {
+ d = GGML_FP16_TO_FP32(x[i].d);
+ process_mins_and_scales_16(_mm_loadu_si128((const __m128i *)x[i].scales), q8, i, -32.f*d, accm, scales);
+ }
+ inline void prepare(int i, int j) {
+ bits.prepare64(x[i].ql, j);
+ auto hbits = _mm256_loadu_si256((const __m256i *)x[i].qh + j);
+ bits.values[0] = _mm256_or_si256(bits.values[0], _mm256_and_si256(_mm256_slli_epi16(hbits, 4), mh));
+ bits.values[1] = _mm256_or_si256(bits.values[1], _mm256_and_si256(_mm256_slli_epi16(hbits, 2), mh));
+ bits.values[2] = _mm256_or_si256(bits.values[2], _mm256_and_si256(hbits, mh));
+ bits.values[3] = _mm256_or_si256(bits.values[3], _mm256_and_si256(_mm256_srli_epi16(hbits, 2), mh));
+ }
+
+ Q4Bits bits;
+ const __m256i mh = _mm256_set1_epi8(0x30);
+};
+
+struct DequantizerIQ4XS final : public BaseDequantizer<block_iq4_xs> {
+ DequantizerIQ4XS(const void * vx, size_t bx) : BaseDequantizer(vx, bx), values(load_iq4nl_values_256()) {}
+ template <typename Q8>
+ inline __m256i new_block(int i, const Q8& q8, __m256 * accd) {
+ d = GGML_FP16_TO_FP32(x[i].d);
+ auto scales128 = siq4.make_scales(*(const uint32_t *)x[i].scales_l, x[i].scales_h);
+ s8k.accum_mins(scales128, q8, i, -128.f*d, accd);
+ return MM256_SET_M128I(scales128, scales128);
+ }
+ inline void prepare(int i, int j) {
+ bits.prepare16(x[i].qs, j);
+ bits.values[0] = _mm256_shuffle_epi8(values, bits.values[0]);
+ bits.values[1] = _mm256_shuffle_epi8(values, bits.values[1]);
+ bits.values[2] = _mm256_shuffle_epi8(values, bits.values[2]);
+ bits.values[3] = _mm256_shuffle_epi8(values, bits.values[3]);
+ }
+
+ Q4Bits bits;
+ Scales8K s8k;
+ ScaleIQ4XS siq4;
+ const __m256i values;
+};
+
+template <typename Dequantizer, int nrc_y>
+static void mul_mat_qX_K_q8_K_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ assert(n % QK_K == 0);
+ const int nb = n / QK_K;
+
+ Q8<nrc_y> q8(info);
+
+ Dequantizer deq(vx, bx);
+
+ __m256 accd[nrc_y];
+ __m256i scales[4];
+
+ for (int ix = 0; ix < nrc_x; ++ix) {
+
+ for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm256_setzero_ps();
+
+ deq.new_row(ix);
+
+ for (int i = 0; i < nb; ++i) {
+
+ auto all_scales = deq.new_block(i, q8, accd);
+
+ __m256i sumi[nrc_y];
+
+ for (int j = 0; j < QK_K/128; ++j) {
+
+ deq.prepare(i, j);
+
+ set_scales_8(all_scales, j, scales);
+
+ multiply_add(deq.bits, scales, j, i, q8, sumi);
+
+ }
+
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ const __m256 vd = _mm256_set1_ps(deq.d*q8.scale(iy, i));
+ accd[iy] = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(sumi[iy]), accd[iy]);
+ }
+
+ }
+
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ info.store(ix, iy, hsum_float_8(accd[iy]));
+ }
+
+ }
+}
+
+template <typename Dequantizer, int nrc_y>
+static void mul_mat_qY_K_q8_K_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ assert(n%QK_K == 0);
+ const int nb = n/QK_K;
+
+ Q8<nrc_y> q8(info);
+
+ __m256i all_scales[2];
+ __m256i scales[4];
+ __m256 accd[nrc_y];
+
+ Dequantizer deq(vx, bx);
+
+ for (int ix = 0; ix < nrc_x; ++ix) {
+
+ deq.new_row(ix);
+
+ for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm256_setzero_ps();
+
+ for (int i = 0; i < nb; ++i) {
+
+ deq.new_block(i, q8, accd, all_scales);
+
+ __m256i sumi[nrc_y];
+
+ for (int j = 0; j < QK_K/128; ++j) {
+ deq.prepare(i, j);
+ set_scales_16(all_scales[j], scales);
+ multiply_add(deq.bits, scales, j, i, q8, sumi);
+ }
+
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ accd[iy] = _mm256_fmadd_ps(_mm256_set1_ps(deq.d*q8.scale(iy, i)), _mm256_cvtepi32_ps(sumi[iy]), accd[iy]);
+ }
+
+ }
+
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ info.store(ix, iy, hsum_float_8(accd[iy]));
+ }
+
+ }
+
+}
+
+#endif
+
+template <int nrc_y>
+static void mul_mat_iq4_xs_r8_q8_k_avx2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ GGML_ASSERT(nrc_x%8 == 0);
+ Q8<nrc_y, block_q8_K> q8(info);
+ auto m4 = _mm256_set1_epi8(0xf);
+ auto m30 = _mm256_set1_epi8(0x30);
+ auto m32 = _mm256_set1_epi8(32);
+#ifndef HAVE_FANCY_SIMD
+ auto s_shuffle = _mm256_set_epi64x(0x0f0e0f0e0d0c0d0c, 0x0b0a0b0a09080908, 0x0706070605040504, 0x0302030201000100);
+ auto values128 = _mm_loadu_si128((const __m128i *)iq4k_values);
+ auto values = MM256_SET_M128I(values128, values128);
+#else
+ auto values = load_iq4nl_values_256();
+#endif
+ int nbl = n / QK_K;
+ using helper_t = union { __m256i vec[2]; uint64_t val[8]; };
+ helper_t h;
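+    // h.vec[] holds the 64 unpacked 6-bit block scales; h.val[ib] then provides the 8 per-row scales of sub-block ib as a single 64-bit value.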
+ __m256 acc[nrc_y] = {};
+ __m256i qx[4];
+ for (int ix = 0; ix < nrc_x; ix += 8) {
+ const block_iq4_xs_r8 * iq4 = (const block_iq4_xs_r8 *)((const char *)vx + (ix+0)*bx);
+ for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256
+ auto d4 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq4[ibl].d));
+ auto slbits = _mm256_loadu_si256((const __m256i *)iq4[ibl].scales_l);
+ auto sl1 = _mm256_and_si256(slbits, m4);
+ auto sl2 = _mm256_and_si256(_mm256_srli_epi16(slbits, 4), m4);
+ auto shbits = _mm_loadu_si128((const __m128i*)iq4[ibl].scales_h);
+ auto sh = MM256_SET_M128I(_mm_srli_epi16(shbits, 2), shbits);
+ h.vec[0] = _mm256_sub_epi8(_mm256_or_si256(sl1, _mm256_and_si256(_mm256_slli_epi16(sh, 4), m30)), m32);
+ h.vec[1] = _mm256_sub_epi8(_mm256_or_si256(sl2, _mm256_and_si256(sh, m30)), m32);
+ __m256i isum[nrc_y] = {};
+ for (int ib = 0; ib < QK_K/32; ++ib) {
+#ifdef HAVE_FANCY_SIMD
+ auto iscales = _mm256_cvtepi8_epi32(_mm_set1_epi64x(h.val[ib]));
+ auto scales = _mm256_mul_ps(d4, _mm256_cvtepi32_ps(iscales));
+ auto scales_m = _mm256_mul_ps(scales, _mm256_set1_ps(-128.f));
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ float m8 = ((const float *)q8.y[iy][ibl].bsums)[ib];
+ acc[iy] = _mm256_fmadd_ps(scales_m, _mm256_set1_ps(m8), acc[iy]);
+ }
+#else
+ auto iscales = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm_set1_epi64x(h.val[ib])), s_shuffle);
+#endif
+ auto bits1 = _mm256_loadu_si256((const __m256i *)iq4[ibl].qs+4*ib+0);
+ auto bits2 = _mm256_loadu_si256((const __m256i *)iq4[ibl].qs+4*ib+1);
+ qx[0] = _mm256_shuffle_epi8(values, _mm256_and_si256(m4, bits1));
+ qx[1] = _mm256_shuffle_epi8(values, _mm256_and_si256(m4, _mm256_srli_epi16(bits1, 4)));
+ qx[2] = _mm256_shuffle_epi8(values, _mm256_and_si256(m4, bits2));
+ qx[3] = _mm256_shuffle_epi8(values, _mm256_and_si256(m4, _mm256_srli_epi16(bits2, 4)));
+#ifndef HAVE_FANCY_SIMD
+ auto s1 = _mm256_sign_epi8(qx[0], qx[0]);
+ auto s2 = _mm256_sign_epi8(qx[1], qx[1]);
+ auto s3 = _mm256_sign_epi8(qx[2], qx[2]);
+ auto s4 = _mm256_sign_epi8(qx[3], qx[3]);
+#endif
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto y128 = _mm_loadu_si128((const __m128i*)q8.y[iy][ibl].qs+2*ib+0);
+ auto y = MM256_SET_M128I(y128, y128);
+#ifdef HAVE_FANCY_SIMD
+ auto sumi = _mm256_setzero_si256();
+ sumi = _mm256_dpbusd_epi32(sumi, qx[0], _mm256_shuffle_epi32(y, 0x00));
+ sumi = _mm256_dpbusd_epi32(sumi, qx[1], _mm256_shuffle_epi32(y, 0x55));
+ sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(y, 0xaa));
+ sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(y, 0xff));
+ isum[iy] = _mm256_add_epi32(isum[iy], _mm256_mullo_epi32(iscales, sumi));
+#else
+ auto sumi1 = _mm256_maddubs_epi16(s1, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), qx[0]));
+ auto sumi2 = _mm256_maddubs_epi16(s2, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), qx[1]));
+ auto sumi3 = _mm256_maddubs_epi16(s3, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), qx[2]));
+ auto sumi4 = _mm256_maddubs_epi16(s4, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), qx[3]));
+ auto sumi = _mm256_add_epi32(_mm256_add_epi32(_mm256_madd_epi16(iscales, sumi1), _mm256_madd_epi16(iscales, sumi2)),
+ _mm256_add_epi32(_mm256_madd_epi16(iscales, sumi3), _mm256_madd_epi16(iscales, sumi4)));
+ isum[iy] = _mm256_add_epi32(isum[iy], sumi);
+#endif
+ }
+ bits1 = _mm256_loadu_si256((const __m256i *)iq4[ibl].qs+4*ib+2);
+ bits2 = _mm256_loadu_si256((const __m256i *)iq4[ibl].qs+4*ib+3);
+ qx[0] = _mm256_shuffle_epi8(values, _mm256_and_si256(m4, bits1));
+ qx[1] = _mm256_shuffle_epi8(values, _mm256_and_si256(m4, _mm256_srli_epi16(bits1, 4)));
+ qx[2] = _mm256_shuffle_epi8(values, _mm256_and_si256(m4, bits2));
+ qx[3] = _mm256_shuffle_epi8(values, _mm256_and_si256(m4, _mm256_srli_epi16(bits2, 4)));
+#ifndef HAVE_FANCY_SIMD
+ s1 = _mm256_sign_epi8(qx[0], qx[0]);
+ s2 = _mm256_sign_epi8(qx[1], qx[1]);
+ s3 = _mm256_sign_epi8(qx[2], qx[2]);
+ s4 = _mm256_sign_epi8(qx[3], qx[3]);
+#endif
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto y128 = _mm_loadu_si128((const __m128i*)q8.y[iy][ibl].qs+2*ib+1);
+ auto y = MM256_SET_M128I(y128, y128);
+#ifdef HAVE_FANCY_SIMD
+ auto sumi = _mm256_setzero_si256();
+ sumi = _mm256_dpbusd_epi32(sumi, qx[0], _mm256_shuffle_epi32(y, 0x00));
+ sumi = _mm256_dpbusd_epi32(sumi, qx[1], _mm256_shuffle_epi32(y, 0x55));
+ sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(y, 0xaa));
+ sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(y, 0xff));
+ isum[iy] = _mm256_add_epi32(isum[iy], _mm256_mullo_epi32(iscales, sumi));
+#else
+ auto sumi1 = _mm256_maddubs_epi16(s1, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), qx[0]));
+ auto sumi2 = _mm256_maddubs_epi16(s2, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), qx[1]));
+ auto sumi3 = _mm256_maddubs_epi16(s3, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), qx[2]));
+ auto sumi4 = _mm256_maddubs_epi16(s4, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), qx[3]));
+ auto sumi = _mm256_add_epi32(_mm256_add_epi32(_mm256_madd_epi16(iscales, sumi1), _mm256_madd_epi16(iscales, sumi2)),
+ _mm256_add_epi32(_mm256_madd_epi16(iscales, sumi3), _mm256_madd_epi16(iscales, sumi4)));
+ isum[iy] = _mm256_add_epi32(isum[iy], sumi);
+#endif
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(isum[iy]), acc[iy]);
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ info.store(ix, iy, acc[iy]);
+ acc[iy] = _mm256_setzero_ps();
+ }
+ }
+}
+
+#ifdef HAVE_FANCY_SIMD
+template <int nrc_y>
+static void mul_mat_iq4_xs_r8_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ mul_mat_iq4_xs_r8_q8_k_avx2<nrc_y>(n, vx, bx, info, nrc_x);
+ return;
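+    // Note: the early return above routes everything through the AVX2 kernel; the AVX512 implementation below is currently unreachable.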
+ if constexpr (nrc_y == 1){
+ mul_mat_iq4_xs_r8_q8_k_avx2<1>(n, vx, bx, info, nrc_x);
+ } else {
+ GGML_ASSERT(nrc_x%8 == 0);
+ Q8<nrc_y, block_q8_K> q8(info);
+ auto m4 = _mm512_set1_epi8(0xf);
+ auto values = load_iq4nl_values_512();
+ int nbl = n / QK_K;
+ using helper_t = union { __m512i vec; uint32_t val[16]; };
+ helper_t h;
+ __m512 acc[nrc_y] = {};
+ __m512i isum[nrc_y] = {};
+ __m512i qx[4];
+ for (int ix = 0; ix < nrc_x; ix += 8) {
+ const block_iq4_xs_r8 * iq4l = (const block_iq4_xs_r8 *)((const char *)vx + (ix+0)*bx);
+ const block_iq4_xs_r8 * iq4h = (const block_iq4_xs_r8 *)((const char *)vx + (ix+4)*bx);
+ for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256
+ auto dl = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq4l[ibl].d));
+ auto dh = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq4h[ibl].d));
+ auto d4 = _mm512_insertf32x8(_mm512_castps256_ps512(_mm256_set_m128(dl, dl)), _mm256_set_m128(dh, dh), 1);
+ auto d4x64 = _mm512_mul_ps(d4, _mm512_set1_ps(-64.f));
+ auto slbits_l = _mm_loadu_si128((const __m128i *)iq4l[ibl].scales_l);
+ auto shbits_l = _mm_loadu_si128((const __m128i *)iq4h[ibl].scales_l);
+ auto sl_l = MM256_SET_M128I(_mm_srli_epi16(slbits_l, 4), slbits_l);
+ auto sh_l = MM256_SET_M128I(_mm_srli_epi16(shbits_l, 4), shbits_l);
+ auto slb = _mm512_and_si512(_mm512_inserti32x8(_mm512_castsi256_si512(sl_l), sh_l, 1), m4);
+ auto aux64 = (const uint64_t *)iq4l[ibl].scales_h;
+ auto slbits_h = _mm_set_epi64x(aux64[0] >> 2, aux64[0]);
+ aux64 = (const uint64_t *)iq4h[ibl].scales_h;
+ auto shbits_h = _mm_set_epi64x(aux64[0] >> 2, aux64[0]);
+ auto sl_h = MM256_SET_M128I(slbits_h, _mm_slli_epi16(slbits_h, 4));
+ auto sh_h = MM256_SET_M128I(shbits_h, _mm_slli_epi16(shbits_h, 4));
+ auto shb = _mm512_and_si512(_mm512_inserti32x8(_mm512_castsi256_si512(sl_h), sh_h, 1), _mm512_set1_epi8(0x30));
+ h.vec = _mm512_sub_epi8(_mm512_or_si512(slb, shb), _mm512_set1_epi8(32));
+ for (int ib = 0; ib < QK_K/32; ++ib) {
+ auto iscales = _mm512_cvtepi8_epi32(_mm_blend_epi32(_mm_set1_epi32(h.val[ib+0]), _mm_set1_epi32(h.val[ib+8]), 0x0c));
+ auto scales = _mm512_cvtepi32_ps(iscales);
+ auto scales_m = _mm512_mul_ps(scales, d4x64);
+ auto bits1 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq4l[ibl].qs+2*ib+0)),
+ _mm256_loadu_si256((const __m256i *)iq4h[ibl].qs+2*ib+0), 1);
+ auto bits2 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq4l[ibl].qs+2*ib+1)),
+ _mm256_loadu_si256((const __m256i *)iq4h[ibl].qs+2*ib+1), 1);
+ qx[0] = _mm512_shuffle_epi8(values, _mm512_and_si512(bits1, m4));
+ qx[1] = _mm512_shuffle_epi8(values, _mm512_and_si512(bits2, m4));
+ qx[2] = _mm512_shuffle_epi8(values, _mm512_and_si512(_mm512_srli_epi16(bits1, 4), m4));
+ qx[3] = _mm512_shuffle_epi8(values, _mm512_and_si512(_mm512_srli_epi16(bits2, 4), m4));
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto y8 = _mm256_loadu_si256((const __m256i*)q8.y[iy][ibl].qs+ib);
+ auto y = _mm512_inserti32x8(_mm512_castsi256_si512(y8), y8, 1);
+ auto sumi = _mm512_setzero_si512();
+ sumi = _mm512_dpbusd_epi32(sumi, qx[0], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00)));
+ sumi = _mm512_dpbusd_epi32(sumi, qx[1], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55)));
+ sumi = _mm512_dpbusd_epi32(sumi, qx[2], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa)));
+ sumi = _mm512_dpbusd_epi32(sumi, qx[3], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff)));
+ isum[iy] = _mm512_add_epi32(isum[iy], _mm512_mullo_epi32(iscales, sumi));
+ float m8 = ((const float *)q8.y[iy][ibl].bsums)[ib];
+ acc[iy] = _mm512_fmadd_ps(scales_m, _mm512_set1_ps(m8), acc[iy]);
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ acc[iy] = _mm512_fmadd_ps(_mm512_mul_ps(d4, _mm512_set1_ps(q8.scale(iy, ibl))), _mm512_cvtepi32_ps(isum[iy]), acc[iy]);
+ isum[iy] = _mm512_setzero_si512();
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto sum1 = _mm_add_ps(_mm512_extractf32x4_ps(acc[iy], 0), _mm512_extractf32x4_ps(acc[iy], 1));
+ auto sum2 = _mm_add_ps(_mm512_extractf32x4_ps(acc[iy], 2), _mm512_extractf32x4_ps(acc[iy], 3));
+ info.store(ix+0, iy, sum1);
+ info.store(ix+4, iy, sum2);
+ acc[iy] = _mm512_setzero_ps();
+ }
+ }
+ }
+}
+#else
+template <int nrc_y>
+static void mul_mat_iq4_xs_r8_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ mul_mat_iq4_xs_r8_q8_k_avx2<nrc_y>(n, vx, bx, info, nrc_x);
+}
+#endif
+
+template <int nrc_y>
+static void mul_mat_q2_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ GGML_ASSERT(nrc_x%4 == 0);
+ Q8<nrc_y, block_q8_K> q8(info);
+ auto mxf = _mm256_set1_epi8(0xf);
+ auto m03 = _mm256_set1_epi8(0x03);
+ static const uint8_t k_shuff[32] = {0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15, 0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15};
+ auto shuff = _mm256_loadu_si256((const __m256i *)k_shuff);
+#ifdef HAVE_FANCY_SIMD
+ __m256i isum[nrc_y] = {};
+#else
+ auto m1 = _mm256_set1_epi16(1);
+#endif
+ int nbl = n / QK_K;
+ __m256 acc[nrc_y] = {};
+ __m256i qx[4];
+ int8_t scales[64];
+ for (int ix = 0; ix < nrc_x; ix += 4) {
+ const block_q2_k_r4 * iq2 = (const block_q2_k_r4 *)((const char *)vx + (ix+0)*bx);
+ for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256
+ auto dm = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq2[ibl].d));
+ auto d4 = _mm256_set_m128(_mm256_castps256_ps128(dm), _mm256_castps256_ps128(dm));
+ auto m4 = _mm256_set_m128(_mm256_extractf128_ps(dm, 1), _mm256_extractf128_ps(dm, 1));
+ m4 = _mm256_mul_ps(m4, _mm256_set1_ps(-1.f));
+ auto all_scales1 = _mm256_loadu_si256((const __m256i *)iq2[ibl].scales+0);
+ auto all_scales2 = _mm256_loadu_si256((const __m256i *)iq2[ibl].scales+1);
+ auto scales1 = _mm256_and_si256(_mm256_srli_epi16(all_scales1, 4), mxf);
+ auto scales2 = _mm256_and_si256(_mm256_srli_epi16(all_scales2, 4), mxf);
+ {
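+                // q2_K mins: fold the -dmin * sum(mins * bsums) correction into the accumulators before processing the quants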
+ auto t1 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(scales1, 0)), shuff); // blocks 0, 1, 2, 3 for each row
+ auto t2 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(scales1, 1)), shuff); // blocks 4, 5, 6, 7 for each row
+ auto t3 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(scales2, 0)), shuff); // blocks 8, 9, 10, 11 for each row
+ auto t4 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(scales2, 1)), shuff); // blocks 12, 13, 14, 15 for each row
+ auto s1 = MM256_SET_M128I(_mm256_extracti128_si256(t3, 0), _mm256_extracti128_si256(t1, 0)); // blocks 0, 1, 8, 9
+ auto s2 = MM256_SET_M128I(_mm256_extracti128_si256(t3, 1), _mm256_extracti128_si256(t1, 1)); // blocks 2, 3, 10, 11
+ auto s3 = MM256_SET_M128I(_mm256_extracti128_si256(t4, 0), _mm256_extracti128_si256(t2, 0)); // blocks 4, 5, 12, 13
+ auto s4 = MM256_SET_M128I(_mm256_extracti128_si256(t4, 1), _mm256_extracti128_si256(t2, 1)); // blocks 6, 7, 14, 15
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto bsums = q8.load_bsums(iy, ibl);
+ auto sumi = _mm256_setzero_si256();
+#ifdef HAVE_FANCY_SIMD
+ sumi = _mm256_dpwssd_epi32(sumi, s1, _mm256_shuffle_epi32(bsums, 0x00));
+ sumi = _mm256_dpwssd_epi32(sumi, s2, _mm256_shuffle_epi32(bsums, 0x55));
+ sumi = _mm256_dpwssd_epi32(sumi, s3, _mm256_shuffle_epi32(bsums, 0xaa));
+ sumi = _mm256_dpwssd_epi32(sumi, s4, _mm256_shuffle_epi32(bsums, 0xff));
+ auto d8 = _mm256_set1_ps(q8.scale(iy, ibl));
+ acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(m4, d8), _mm256_cvtepi32_ps(sumi), acc[iy]);
+#else
+ sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(s1, _mm256_shuffle_epi32(bsums, 0x00)));
+ sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(s2, _mm256_shuffle_epi32(bsums, 0x55)));
+ sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(s3, _mm256_shuffle_epi32(bsums, 0xaa)));
+ sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(s4, _mm256_shuffle_epi32(bsums, 0xff)));
+ auto d8 = _mm256_set1_ps(q8.scale(iy, ibl));
+ acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(m4, d8), _mm256_cvtepi32_ps(sumi), acc[iy]);
+ if constexpr (nrc_y == 1) {
+ d4 = _mm256_mul_ps(d4, d8);
+ }
+#endif
+ }
+ }
+ all_scales1 = _mm256_and_si256(all_scales1, mxf);
+ all_scales2 = _mm256_and_si256(all_scales2, mxf);
+ _mm256_storeu_si256((__m256i *)scales+0, all_scales1);
+ _mm256_storeu_si256((__m256i *)scales+1, all_scales2);
+ for (int ib = 0; ib < QK_K/32; ++ib) {
+ auto iscales = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(scales + 8*ib)));
+#ifndef HAVE_FANCY_SIMD
+ auto scales = _mm256_mul_ps(d4, _mm256_cvtepi32_ps(iscales));
+#endif
+ auto lb = _mm256_loadu_si256((const __m256i *)iq2[ibl].qs+ib);
+ qx[0] = _mm256_and_si256(lb, m03);
+ qx[1] = _mm256_and_si256(_mm256_srli_epi16(lb, 2), m03);
+ qx[2] = _mm256_and_si256(_mm256_srli_epi16(lb, 4), m03);
+ qx[3] = _mm256_and_si256(_mm256_srli_epi16(lb, 6), m03);
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto y = _mm256_loadu_si256((const __m256i*)q8.y[iy][ibl].qs+ib);
+#ifdef HAVE_FANCY_SIMD
+ auto sumi = _mm256_setzero_si256();
+ sumi = _mm256_dpbusd_epi32(sumi, qx[0], _mm256_shuffle_epi32(y, 0x00));
+ sumi = _mm256_dpbusd_epi32(sumi, qx[1], _mm256_shuffle_epi32(y, 0x55));
+ sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(y, 0xaa));
+ sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(y, 0xff));
+ isum[iy] = _mm256_add_epi32(isum[iy], _mm256_mullo_epi32(iscales, sumi));
+#else
+ auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[0], _mm256_shuffle_epi32(y, 0x00)),
+ _mm256_maddubs_epi16(qx[1], _mm256_shuffle_epi32(y, 0x55)));
+ auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[2], _mm256_shuffle_epi32(y, 0xaa)),
+ _mm256_maddubs_epi16(qx[3], _mm256_shuffle_epi32(y, 0xff)));
+                    // Quants are in 0...3, so we can add up all of them as int16_t without overflowing
+ auto sumi = _mm256_madd_epi16(m1, _mm256_add_epi16(sumi1, sumi2));
+ if constexpr (nrc_y == 1) {
+ acc[iy] = _mm256_fmadd_ps(scales, _mm256_cvtepi32_ps(sumi), acc[iy]);
+ } else {
+ acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(scales, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(sumi), acc[iy]);
+ }
+#endif
+ }
+ }
+#ifdef HAVE_FANCY_SIMD
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto d4y = _mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(iy, ibl)));
+ acc[iy] = _mm256_fmadd_ps(d4y, _mm256_cvtepi32_ps(isum[iy]), acc[iy]);
+ isum[iy] = _mm256_setzero_si256();
+ }
+#endif
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1));
+ acc[iy] = _mm256_setzero_ps();
+ info.store(ix+0, iy, sum);
+ }
+ }
+}
+
+template <int nrc_y>
+static void mul_mat_q3_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ GGML_ASSERT(nrc_x%4 == 0);
+ Q8<nrc_y, block_q8_K> q8(info);
+ auto m4 = _mm256_set1_epi8(0xf);
+ auto m30 = _mm256_set1_epi8(0x30);
+ auto m32 = _mm256_set1_epi8(32);
+ auto m03 = _mm256_set1_epi8(0x03);
+ auto m04 = _mm256_set1_epi8(0x04);
+ static const uint8_t k_shuff[32] = {0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15, 0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15};
+ auto shuff = _mm256_loadu_si256((const __m256i *)k_shuff);
+#ifdef HAVE_FANCY_SIMD
+ __m256i isum[nrc_y];
+#else
+ auto m1 = _mm256_set1_epi16(1);
+#endif
+ int nbl = n / QK_K;
+ __m256 acc[nrc_y] = {};
+ __m256i qx[4];
+ int8_t scales[64];
+ for (int ix = 0; ix < nrc_x; ix += 4) {
+ const block_q3_k_r4 * iq3 = (const block_q3_k_r4 *)((const char *)vx + (ix+0)*bx);
+ for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256
+ auto dl = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq3[ibl].d));
+ auto d4 = _mm256_set_m128(dl, dl);
+#ifndef HAVE_FANCY_SIMD
+ if constexpr (nrc_y == 1) {
+ d4 = _mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(0, ibl)));
+ }
+#endif
+ auto slb = _mm256_loadu_si256((const __m256i *)iq3[ibl].scales_l);
+ auto shbits = _mm_loadu_si128((const __m128i *)iq3[ibl].scales_h);
+ auto shb = MM256_SET_M128I(_mm_srli_epi16(shbits, 2), shbits);
+ auto scales1 = _mm256_sub_epi8(_mm256_or_si256(_mm256_and_si256(slb, m4), _mm256_and_si256(_mm256_slli_epi16(shb, 4), m30)), m32);
+ auto scales2 = _mm256_sub_epi8(_mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(slb, 4), m4), _mm256_and_si256(shb, m30)), m32);
+ _mm256_storeu_si256((__m256i *)scales+0, scales1);
+ _mm256_storeu_si256((__m256i *)scales+1, scales2);
+ {
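+                // q3_K quants are stored as q+4 (0...7); the -4 offset is compensated here via the q8 block sums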
+#ifndef HAVE_FANCY_SIMD
+ auto min = _mm256_mul_ps(d4, _mm256_set1_ps(-4.f));
+#endif
+ auto t1 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(scales1, 0)), shuff); // blocks 0, 1, 2, 3 for each row
+ auto t2 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(scales1, 1)), shuff); // blocks 4, 5, 6, 7 for each row
+ auto t3 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(scales2, 0)), shuff); // blocks 8, 9, 10, 11 for each row
+ auto t4 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(scales2, 1)), shuff); // blocks 12, 13, 14, 15 for each row
+ auto s1 = MM256_SET_M128I(_mm256_extracti128_si256(t3, 0), _mm256_extracti128_si256(t1, 0)); // blocks 0, 1, 8, 9
+ auto s2 = MM256_SET_M128I(_mm256_extracti128_si256(t3, 1), _mm256_extracti128_si256(t1, 1)); // blocks 2, 3, 10, 11
+ auto s3 = MM256_SET_M128I(_mm256_extracti128_si256(t4, 0), _mm256_extracti128_si256(t2, 0)); // blocks 4, 5, 12, 13
+ auto s4 = MM256_SET_M128I(_mm256_extracti128_si256(t4, 1), _mm256_extracti128_si256(t2, 1)); // blocks 6, 7, 14, 15
+#ifdef HAVE_FANCY_SIMD
+ s1 = _mm256_mullo_epi16(s1, _mm256_set1_epi16(-4));
+ s2 = _mm256_mullo_epi16(s2, _mm256_set1_epi16(-4));
+ s3 = _mm256_mullo_epi16(s3, _mm256_set1_epi16(-4));
+ s4 = _mm256_mullo_epi16(s4, _mm256_set1_epi16(-4));
+#endif
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto bsums = q8.load_bsums(iy, ibl);
+ auto sumi = _mm256_setzero_si256();
+#ifdef HAVE_FANCY_SIMD
+ sumi = _mm256_dpwssd_epi32(sumi, s1, _mm256_shuffle_epi32(bsums, 0x00));
+ sumi = _mm256_dpwssd_epi32(sumi, s2, _mm256_shuffle_epi32(bsums, 0x55));
+ sumi = _mm256_dpwssd_epi32(sumi, s3, _mm256_shuffle_epi32(bsums, 0xaa));
+ sumi = _mm256_dpwssd_epi32(sumi, s4, _mm256_shuffle_epi32(bsums, 0xff));
+ isum[iy] = sumi;
+#else
+ sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(s1, _mm256_shuffle_epi32(bsums, 0x00)));
+ sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(s2, _mm256_shuffle_epi32(bsums, 0x55)));
+ sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(s3, _mm256_shuffle_epi32(bsums, 0xaa)));
+ sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(s4, _mm256_shuffle_epi32(bsums, 0xff)));
+ if constexpr (nrc_y == 1) {
+ acc[iy] = _mm256_fmadd_ps(min, _mm256_cvtepi32_ps(sumi), acc[iy]);
+ } else {
+ acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(min, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(sumi), acc[iy]);
+ }
+#endif
+ }
+ }
+ for (int ib = 0; ib < QK_K/32; ++ib) {
+ auto iscales = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(scales + 8*ib)));
+#ifndef HAVE_FANCY_SIMD
+ auto scales = _mm256_mul_ps(d4, _mm256_cvtepi32_ps(iscales));
+#endif
+ auto lb = _mm256_loadu_si256((const __m256i *)iq3[ibl].qs+ib);
+ auto hbits = _mm_loadu_si128((const __m128i *)iq3[ibl].qh+ib);
+ auto hb = MM256_SET_M128I(hbits, _mm_slli_epi16(hbits, 4));
+ qx[0] = _mm256_or_si256(_mm256_and_si256(lb, m03), _mm256_and_si256(m04, _mm256_srli_epi16(hb, 2)));
+ qx[1] = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(lb, 2), m03), _mm256_and_si256(m04, _mm256_srli_epi16(hb, 3)));
+ qx[2] = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(lb, 4), m03), _mm256_and_si256(m04, _mm256_srli_epi16(hb, 4)));
+ qx[3] = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(lb, 6), m03), _mm256_and_si256(m04, _mm256_srli_epi16(hb, 5)));
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto y = _mm256_loadu_si256((const __m256i*)q8.y[iy][ibl].qs+ib);
+#ifdef HAVE_FANCY_SIMD
+ auto sumi = _mm256_setzero_si256();
+ sumi = _mm256_dpbusd_epi32(sumi, qx[0], _mm256_shuffle_epi32(y, 0x00));
+ sumi = _mm256_dpbusd_epi32(sumi, qx[1], _mm256_shuffle_epi32(y, 0x55));
+ sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(y, 0xaa));
+ sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(y, 0xff));
+ isum[iy] = _mm256_add_epi32(isum[iy], _mm256_mullo_epi32(iscales, sumi));
+#else
+ auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[0], _mm256_shuffle_epi32(y, 0x00)),
+ _mm256_maddubs_epi16(qx[1], _mm256_shuffle_epi32(y, 0x55)));
+ auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[2], _mm256_shuffle_epi32(y, 0xaa)),
+ _mm256_maddubs_epi16(qx[3], _mm256_shuffle_epi32(y, 0xff)));
+ // Quants are in 0...7, so we can add up all of them as int16_t without overflowing
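+ // (worst case: 8 products of at most 7*127 = 889 each, i.e. 7112 per int16_t lane)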
+ auto sumi = _mm256_madd_epi16(m1, _mm256_add_epi16(sumi1, sumi2));
+ if constexpr (nrc_y == 1) {
+ acc[iy] = _mm256_fmadd_ps(scales, _mm256_cvtepi32_ps(sumi), acc[iy]);
+ } else {
+ acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(scales, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(sumi), acc[iy]);
+ }
+#endif
+
+ }
+ }
+#ifdef HAVE_FANCY_SIMD
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto d4y = _mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(iy, ibl)));
+ acc[iy] = _mm256_fmadd_ps(d4y, _mm256_cvtepi32_ps(isum[iy]), acc[iy]);
+ }
+#endif
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1));
+ acc[iy] = _mm256_setzero_ps();
+ info.store(ix+0, iy, sum);
+ }
+ }
+}
+
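+// Applies the per-block minimums for 4 interleaved rows: the packed mins are widened to
+// int32, scaled by m4 (the negated dmin) and multiplied with the q8 block sums (loaded as
+// floats here). For nrc_y == 1 the m4 factor is applied once at the end instead of per term.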
+template <int nrc_y>
+inline void process_min_r4_b32(int ibl, __m256 m4, __m256i mins, const Q8<nrc_y, block_q8_K>& q8, __m256 * acc) {
+ auto mins_l = _mm256_castsi256_si128(mins);
+ auto mins_h = _mm256_extracti128_si256(mins, 1);
+ auto aux1 = _mm_unpacklo_epi32(mins_l, mins_h);
+ auto aux2 = _mm_unpackhi_epi32(mins_l, mins_h);
+ auto ic1 = _mm256_cvtepi8_epi32(aux1);
+ auto ic2 = _mm256_cvtepi8_epi32(_mm_shuffle_epi32(aux1, 0xee));
+ auto ic3 = _mm256_cvtepi8_epi32(aux2);
+ auto ic4 = _mm256_cvtepi8_epi32(_mm_shuffle_epi32(aux2, 0xee));
+ if constexpr (nrc_y == 1) {
+ auto bs = _mm256_loadu_ps((const float *)q8.y[0][ibl].bsums);
+ auto sumf = _mm256_mul_ps(_mm256_cvtepi32_ps(ic1), _mm256_shuffle_ps(bs, bs, 0x00));
+ sumf = _mm256_fmadd_ps(_mm256_cvtepi32_ps(ic2), _mm256_shuffle_ps(bs, bs, 0x55), sumf);
+ sumf = _mm256_fmadd_ps(_mm256_cvtepi32_ps(ic3), _mm256_shuffle_ps(bs, bs, 0xaa), sumf);
+ sumf = _mm256_fmadd_ps(_mm256_cvtepi32_ps(ic4), _mm256_shuffle_ps(bs, bs, 0xff), sumf);
+ acc[0] = _mm256_fmadd_ps(m4, sumf, acc[0]);
+ } else {
+ auto c1 = _mm256_mul_ps(m4, _mm256_cvtepi32_ps(ic1));
+ auto c2 = _mm256_mul_ps(m4, _mm256_cvtepi32_ps(ic2));
+ auto c3 = _mm256_mul_ps(m4, _mm256_cvtepi32_ps(ic3));
+ auto c4 = _mm256_mul_ps(m4, _mm256_cvtepi32_ps(ic4));
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto bs = _mm256_loadu_ps((const float *)q8.y[iy][ibl].bsums);
+ acc[iy] = _mm256_fmadd_ps(c1, _mm256_shuffle_ps(bs, bs, 0x00), acc[iy]);
+ acc[iy] = _mm256_fmadd_ps(c2, _mm256_shuffle_ps(bs, bs, 0x55), acc[iy]);
+ acc[iy] = _mm256_fmadd_ps(c3, _mm256_shuffle_ps(bs, bs, 0xaa), acc[iy]);
+ acc[iy] = _mm256_fmadd_ps(c4, _mm256_shuffle_ps(bs, bs, 0xff), acc[iy]);
+ }
+ }
+}
+
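+// GEMM for q4_k_r4 (4 interleaved rows) x Q8_K. The 6-bit block scales and mins are unpacked
+// from scales_l/scales_h; the mins go through process_min_r4_b32, the 4-bit quants through
+// dpbusd (AVX512-VNNI) or maddubs+madd (AVX2).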
+template <int nrc_y>
+static void mul_mat_q4_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ GGML_ASSERT(nrc_x%4 == 0);
+ Q8<nrc_y, block_q8_K> q8(info);
+ auto mf = _mm256_set1_epi8(0xf);
+ auto m3 = _mm256_set1_epi8(0x30);
+ int nbl = n / QK_K;
+ union { __m256i vec; uint32_t val[8]; } hd;
+ __m256 acc[nrc_y] = {};
+ __m256i isum[nrc_y] = {};
+ __m256i qx[4];
+ for (int ix = 0; ix < nrc_x; ix += 4) {
+ const block_q4_k_r4 * iq4 = (const block_q4_k_r4 *)((const char *)vx + (ix+0)*bx);
+ for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256
+ auto dl = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq4[ibl].d));
+ auto d4 = _mm256_set_m128(_mm256_castps256_ps128(dl), _mm256_castps256_ps128(dl));
+ auto m4 = _mm256_mul_ps(_mm256_set1_ps(-1.0f), _mm256_set_m128(_mm256_extractf128_ps(dl, 1), _mm256_extractf128_ps(dl, 1)));
+ auto lbits = _mm256_loadu_si256((const __m256i *)iq4[ibl].scales_l);
+ auto hbits128 = _mm_loadu_si128((const __m128i *)iq4[ibl].scales_h);
+ auto hbits = MM256_SET_M128I(hbits128, _mm_slli_epi16(hbits128, 4));
+ hd.vec = _mm256_or_si256(_mm256_and_si256(lbits, mf), _mm256_and_si256(hbits, m3));
+ auto mins = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(lbits, 4), mf), _mm256_and_si256(_mm256_srli_epi16(hbits, 2), m3));
+ process_min_r4_b32(ibl, m4, mins, q8, acc);
+ for (int ib = 0; ib < QK_K/32; ++ib) {
+#ifdef HAVE_FANCY_SIMD
+ auto scales_d = _mm256_cvtepi8_epi32(_mm_set1_epi32(hd.val[ib]));
+#else
+ auto aux = _mm_set1_epi32(hd.val[ib]);
+ aux = _mm_cvtepu8_epi16(_mm_unpacklo_epi8(aux, aux));
+ auto scales_d = MM256_SET_M128I(aux, aux);
+#endif
+ auto bits1 = _mm256_loadu_si256((const __m256i *)iq4[ibl].qs+2*ib+0);
+ auto bits2 = _mm256_loadu_si256((const __m256i *)iq4[ibl].qs+2*ib+1);
+ qx[0] = _mm256_and_si256(bits1, mf);
+ qx[1] = _mm256_and_si256(bits2, mf);
+ qx[2] = _mm256_and_si256(_mm256_srli_epi16(bits1, 4), mf);
+ qx[3] = _mm256_and_si256(_mm256_srli_epi16(bits2, 4), mf);
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto y = _mm256_loadu_si256((const __m256i*)q8.y[iy][ibl].qs+ib);
+#ifdef HAVE_FANCY_SIMD
+ auto sumi = _mm256_setzero_si256();
+ sumi = _mm256_dpbusd_epi32(sumi, qx[0], _mm256_shuffle_epi32(y, 0x00));
+ sumi = _mm256_dpbusd_epi32(sumi, qx[1], _mm256_shuffle_epi32(y, 0x55));
+ sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(y, 0xaa));
+ sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(y, 0xff));
+ isum[iy] = _mm256_add_epi32(isum[iy], _mm256_mullo_epi32(scales_d, sumi));
+#else
+ auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[0], _mm256_shuffle_epi32(y, 0x00)),
+ _mm256_maddubs_epi16(qx[1], _mm256_shuffle_epi32(y, 0x55)));
+ auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[2], _mm256_shuffle_epi32(y, 0xaa)),
+ _mm256_maddubs_epi16(qx[3], _mm256_shuffle_epi32(y, 0xff)));
+ isum[iy] = _mm256_add_epi32(isum[iy], _mm256_madd_epi16(scales_d, _mm256_add_epi16(sumi1, sumi2)));
+#endif
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(isum[iy]), acc[iy]);
+ isum[iy] = _mm256_setzero_si256();
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1));
+ acc[iy] = _mm256_setzero_ps();
+ info.store(ix+0, iy, sum);
+ }
+ }
+}
+
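+// GEMM for q5_k_r4 x Q8_K: same structure as the q4_k_r4 kernel, with the 5th quant bit
+// merged in from qh via the 0x10 mask.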
+template <int nrc_y>
+static void mul_mat_q5_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ GGML_ASSERT(nrc_x%4 == 0);
+ Q8<nrc_y, block_q8_K> q8(info);
+ auto mf = _mm256_set1_epi8(0xf);
+ auto m10 = _mm256_set1_epi8(0x10);
+ auto m30 = _mm256_set1_epi8(0x30);
+ int nbl = n / QK_K;
+ union { __m256i vec; uint32_t val[8]; } hd;
+ __m256 acc[nrc_y] = {};
+ __m256i isum[nrc_y] = {};
+ __m256i qx[4];
+ for (int ix = 0; ix < nrc_x; ix += 4) {
+ const block_q5_k_r4 * iq5 = (const block_q5_k_r4 *)((const char *)vx + (ix+0)*bx);
+ for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256
+ auto dl = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq5[ibl].d));
+ auto d4 = _mm256_set_m128(_mm256_castps256_ps128(dl), _mm256_castps256_ps128(dl));
+ auto m4 = _mm256_mul_ps(_mm256_set1_ps(-1.0f), _mm256_set_m128(_mm256_extractf128_ps(dl, 1), _mm256_extractf128_ps(dl, 1)));
+ auto lbits = _mm256_loadu_si256((const __m256i *)iq5[ibl].scales_l);
+ auto hbits128 = _mm_loadu_si128((const __m128i *)iq5[ibl].scales_h);
+ auto hbits = MM256_SET_M128I(hbits128, _mm_slli_epi16(hbits128, 4));
+ hd.vec = _mm256_or_si256(_mm256_and_si256(lbits, mf), _mm256_and_si256(hbits, m30));
+ auto mins = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(lbits, 4), mf), _mm256_and_si256(_mm256_srli_epi16(hbits, 2), m30));
+ process_min_r4_b32(ibl, m4, mins, q8, acc);
+ for (int ib = 0; ib < QK_K/32; ++ib) {
+#ifdef HAVE_FANCY_SIMD
+ auto scales_d = _mm256_cvtepi8_epi32(_mm_set1_epi32(hd.val[ib]));
+#else
+ auto aux = _mm_set1_epi32(hd.val[ib]);
+ aux = _mm_cvtepu8_epi16(_mm_unpacklo_epi8(aux, aux));
+ auto scales_d = MM256_SET_M128I(aux, aux);
+#endif
+ auto lbits1 = _mm256_loadu_si256((const __m256i *)iq5[ibl].qs+2*ib+0);
+ auto lbits2 = _mm256_loadu_si256((const __m256i *)iq5[ibl].qs+2*ib+1);
+ auto hbits128 = _mm_loadu_si128((const __m128i*)iq5[ibl].qh + ib);
+ auto hbits = MM256_SET_M128I(hbits128, _mm_slli_epi16(hbits128, 4));
+ qx[0] = _mm256_or_si256(_mm256_and_si256(lbits1, mf), _mm256_and_si256(m10, hbits));
+ qx[1] = _mm256_or_si256(_mm256_and_si256(lbits2, mf), _mm256_and_si256(m10, _mm256_srli_epi16(hbits, 2)));
+ qx[2] = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(lbits1, 4), mf), _mm256_and_si256(m10, _mm256_srli_epi16(hbits, 1)));
+ qx[3] = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(lbits2, 4), mf), _mm256_and_si256(m10, _mm256_srli_epi16(hbits, 3)));
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto y = _mm256_loadu_si256((const __m256i*)q8.y[iy][ibl].qs+ib);
+#ifdef HAVE_FANCY_SIMD
+ auto sumi = _mm256_setzero_si256();
+ sumi = _mm256_dpbusd_epi32(sumi, qx[0], _mm256_shuffle_epi32(y, 0x00));
+ sumi = _mm256_dpbusd_epi32(sumi, qx[1], _mm256_shuffle_epi32(y, 0x55));
+ sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(y, 0xaa));
+ sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(y, 0xff));
+ isum[iy] = _mm256_add_epi32(isum[iy], _mm256_mullo_epi32(scales_d, sumi));
+#else
+ auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[0], _mm256_shuffle_epi32(y, 0x00)),
+ _mm256_maddubs_epi16(qx[1], _mm256_shuffle_epi32(y, 0x55)));
+ auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[2], _mm256_shuffle_epi32(y, 0xaa)),
+ _mm256_maddubs_epi16(qx[3], _mm256_shuffle_epi32(y, 0xff)));
+ // To avoid overflow, we can only add up to 4 q5 x q8 products.
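+ // (4 products of at most 31*127 = 3937 each, i.e. 15748 per int16_t lane)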
+ auto sumi = _mm256_add_epi32(_mm256_madd_epi16(scales_d, sumi1), _mm256_madd_epi16(scales_d, sumi2));
+ isum[iy] = _mm256_add_epi32(isum[iy], sumi);
+#endif
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(isum[iy]), acc[iy]);
+ isum[iy] = _mm256_setzero_si256();
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1));
+ acc[iy] = _mm256_setzero_ps();
+ info.store(ix+0, iy, sum);
+ }
+ }
+}
+
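+// GEMM for q6_k_r4 x Q8_K. Quants are four low bits (ql) | two high bits (qh, mask 0x30),
+// i.e. unsigned 0...63; the -32 offset is folded in through the bsums path, analogous to
+// the -4 offset in the q3_k_r4 kernel above.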
+template <int nrc_y>
+static void mul_mat_q6_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ GGML_ASSERT(nrc_x%4 == 0);
+ Q8<nrc_y, block_q8_K> q8(info);
+ auto m4 = _mm256_set1_epi8(0xf);
+ auto m3 = _mm256_set1_epi8(0x30);
+ static const uint8_t k_shuff[32] = {0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15, 0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15};
+ auto shuff = _mm256_loadu_si256((const __m256i *)k_shuff);
+#ifdef HAVE_FANCY_SIMD
+ __m256i isum[nrc_y];
+#else
+ auto m1 = _mm256_set1_epi16(1);
+#endif
+ int nbl = n / QK_K;
+ __m256 acc[nrc_y] = {};
+ __m256i qx[4];
+ for (int ix = 0; ix < nrc_x; ix += 4) {
+ const block_q6_k_r4 * iq6 = (const block_q6_k_r4 *)((const char *)vx + (ix+0)*bx);
+ for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256
+ auto dl = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq6[ibl].d));
+ auto d4 = _mm256_set_m128(dl, dl);
+#ifndef HAVE_FANCY_SIMD
+ if constexpr (nrc_y == 1) {
+ d4 = _mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(0, ibl)));
+ }
+#endif
+ {
+#ifndef HAVE_FANCY_SIMD
+ auto min = _mm256_mul_ps(d4, _mm256_set1_ps(-32.f));
+#endif
+ auto t1 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)iq6[ibl].scales+0)), shuff); // blocks 0, 1, 2, 3 for each row
+ auto t2 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)iq6[ibl].scales+1)), shuff); // blocks 4, 5, 6, 7 for each row
+ auto t3 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)iq6[ibl].scales+2)), shuff); // blocks 8, 9, 10, 11 for each row
+ auto t4 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)iq6[ibl].scales+3)), shuff); // blocks 12, 13, 14, 15 for each row
+ auto s1 = MM256_SET_M128I(_mm256_extracti128_si256(t3, 0), _mm256_extracti128_si256(t1, 0)); // blocks 0, 1, 8, 9
+ auto s2 = MM256_SET_M128I(_mm256_extracti128_si256(t3, 1), _mm256_extracti128_si256(t1, 1)); // blocks 2, 3, 10, 11
+ auto s3 = MM256_SET_M128I(_mm256_extracti128_si256(t4, 0), _mm256_extracti128_si256(t2, 0)); // blocks 4, 5, 12, 13
+ auto s4 = MM256_SET_M128I(_mm256_extracti128_si256(t4, 1), _mm256_extracti128_si256(t2, 1)); // blocks 6, 7, 14, 15
+#ifdef HAVE_FANCY_SIMD
+ s1 = _mm256_mullo_epi16(s1, _mm256_set1_epi16(-32));
+ s2 = _mm256_mullo_epi16(s2, _mm256_set1_epi16(-32));
+ s3 = _mm256_mullo_epi16(s3, _mm256_set1_epi16(-32));
+ s4 = _mm256_mullo_epi16(s4, _mm256_set1_epi16(-32));
+#endif
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto bsums = q8.load_bsums(iy, ibl);
+ auto sumi = _mm256_setzero_si256();
+#ifdef HAVE_FANCY_SIMD
+ sumi = _mm256_dpwssd_epi32(sumi, s1, _mm256_shuffle_epi32(bsums, 0x00));
+ sumi = _mm256_dpwssd_epi32(sumi, s2, _mm256_shuffle_epi32(bsums, 0x55));
+ sumi = _mm256_dpwssd_epi32(sumi, s3, _mm256_shuffle_epi32(bsums, 0xaa));
+ sumi = _mm256_dpwssd_epi32(sumi, s4, _mm256_shuffle_epi32(bsums, 0xff));
+ isum[iy] = sumi;
+#else
+ sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(s1, _mm256_shuffle_epi32(bsums, 0x00)));
+ sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(s2, _mm256_shuffle_epi32(bsums, 0x55)));
+ sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(s3, _mm256_shuffle_epi32(bsums, 0xaa)));
+ sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(s4, _mm256_shuffle_epi32(bsums, 0xff)));
+ if constexpr (nrc_y == 1) {
+ acc[iy] = _mm256_fmadd_ps(min, _mm256_cvtepi32_ps(sumi), acc[iy]);
+ } else {
+ acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(min, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(sumi), acc[iy]);
+ }
+#endif
+ }
+ }
+ const uint32_t * scales = (const uint32_t *)iq6[ibl].scales;
+ for (int ib = 0; ib < QK_K/32; ++ib) {
+ auto iscales = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(scales + 2*ib)));
+#ifndef HAVE_FANCY_SIMD
+ auto scales = _mm256_mul_ps(d4, _mm256_cvtepi32_ps(iscales));
+#endif
+ auto lbits1 = _mm256_loadu_si256((const __m256i *)iq6[ibl].ql+2*ib+0);
+ auto lbits2 = _mm256_loadu_si256((const __m256i *)iq6[ibl].ql+2*ib+1);
+ auto hbits = _mm256_loadu_si256((const __m256i *)iq6[ibl].qh+ib);
+ qx[0] = _mm256_or_si256(_mm256_and_si256(lbits1, m4), _mm256_and_si256(m3, _mm256_slli_epi16(hbits, 4)));
+ qx[1] = _mm256_or_si256(_mm256_and_si256(lbits2, m4), _mm256_and_si256(m3, hbits));
+ qx[2] = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(lbits1, 4), m4), _mm256_and_si256(m3, _mm256_slli_epi16(hbits, 2)));
+ qx[3] = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(lbits2, 4), m4), _mm256_and_si256(m3, _mm256_srli_epi16(hbits, 2)));
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto y = _mm256_loadu_si256((const __m256i*)q8.y[iy][ibl].qs+ib);
+#ifdef HAVE_FANCY_SIMD
+ auto sumi = _mm256_setzero_si256();
+ sumi = _mm256_dpbusd_epi32(sumi, qx[0], _mm256_shuffle_epi32(y, 0x00));
+ sumi = _mm256_dpbusd_epi32(sumi, qx[1], _mm256_shuffle_epi32(y, 0x55));
+ sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(y, 0xaa));
+ sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(y, 0xff));
+ isum[iy] = _mm256_add_epi32(isum[iy], _mm256_mullo_epi32(iscales, sumi));
+#else
+ auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[0], _mm256_shuffle_epi32(y, 0x00)),
+ _mm256_maddubs_epi16(qx[1], _mm256_shuffle_epi32(y, 0x55)));
+ auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[2], _mm256_shuffle_epi32(y, 0xaa)),
+ _mm256_maddubs_epi16(qx[3], _mm256_shuffle_epi32(y, 0xff)));
+ // Quants are in 0...63, so we can sum at most 4 products as int16_t without risking overflow
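+ // (4 products of at most 63*127 = 8001 each, i.e. 32004 per int16_t lane, just below 32767)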
+ auto sumi = _mm256_add_epi32(_mm256_madd_epi16(m1, sumi1), _mm256_madd_epi16(m1, sumi2));
+ if constexpr (nrc_y == 1) {
+ acc[iy] = _mm256_fmadd_ps(scales, _mm256_cvtepi32_ps(sumi), acc[iy]);
+ } else {
+ acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(scales, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(sumi), acc[iy]);
+ }
+#endif
+ }
+ }
+#ifdef HAVE_FANCY_SIMD
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto d4y = _mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(iy, ibl)));
+ acc[iy] = _mm256_fmadd_ps(d4y, _mm256_cvtepi32_ps(isum[iy]), acc[iy]);
+ }
+#endif
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1));
+ acc[iy] = _mm256_setzero_ps();
+ info.store(ix+0, iy, sum);
+ }
+ }
+}
+
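+// Fills the per-nrc_y kernel table for a given dequantizer, picking the AVX512 or AVX2
+// variants depending on HAVE_FANCY_SIMD.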
+template <typename Dequantizer> void set_functions(std::array<mul_mat_t, IQK_MAX_NY>& funcs) {
+#ifdef HAVE_FANCY_SIMD
+ if constexpr (std::is_same_v<Dequantizer, DequantizerIQ4XS>) {
+ IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_iqX_k_q8_K_AVX512, Dequantizer, funcs)
+ } else {
+ IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_K_q8_K_AVX512, Dequantizer, funcs)
+ funcs[0] = mul_mat_qX_K_q8_K_AVX512_1<Dequantizer>;
+ }
+#else
+ if constexpr (std::is_same_v<Dequantizer, DequantizerQ2K> ||
+ std::is_same_v<Dequantizer, DequantizerQ3K> ||
+ std::is_same_v<Dequantizer, DequantizerQ6K>) {
+ IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qY_K_q8_K_T, Dequantizer, funcs)
+ } else {
+ IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_K_q8_K_T, Dequantizer, funcs)
+ }
+#endif
+}
+
+// Note: HAVE_FANCY_SIMD should only be defined as #if defined(__AVX512VNNI__) && defined(__AVX512VL__)
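+// GEMM for q8_k_r8 (8 interleaved rows) x Q8_K. On the VNNI path the signed x quants are
+// shifted by +127 so that dpbusd can be used; the shift is compensated with the activation
+// sums stored in bsums.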
+template <int nrc_y>
+static void mul_mat_q8_k_r8_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ GGML_ASSERT(nrc_x%8 == 0);
+ Q8<nrc_y, block_q8_K> q8(info);
+#ifndef HAVE_FANCY_SIMD
+ auto m1 = _mm256_set1_epi16(1);
+#endif
+ int nbl = n / QK_K;
+ __m256 acc[nrc_y] = {};
+ __m256i isum[nrc_y] = {};
+ __m256i qx[4];
+ for (int ix = 0; ix < nrc_x; ix += 8) {
+ const block_q8_k_r8 * iq8 = (const block_q8_k_r8 *)((const char *)vx + (ix+0)*bx);
+ for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256
+ auto d4 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq8[ibl].d));
+ for (int ib = 0; ib < QK_K/16; ++ib) {
+ qx[0] = _mm256_loadu_si256((const __m256i *)iq8[ibl].qs+4*ib+0);
+ qx[1] = _mm256_loadu_si256((const __m256i *)iq8[ibl].qs+4*ib+1);
+ qx[2] = _mm256_loadu_si256((const __m256i *)iq8[ibl].qs+4*ib+2);
+ qx[3] = _mm256_loadu_si256((const __m256i *)iq8[ibl].qs+4*ib+3);
+#ifndef HAVE_FANCY_SIMD
+ auto s0 = _mm256_sign_epi8(qx[0], qx[0]);
+ auto s1 = _mm256_sign_epi8(qx[1], qx[1]);
+ auto s2 = _mm256_sign_epi8(qx[2], qx[2]);
+ auto s3 = _mm256_sign_epi8(qx[3], qx[3]);
+#else
+ qx[0] = _mm256_add_epi8(qx[0], _mm256_set1_epi8(127));
+ qx[1] = _mm256_add_epi8(qx[1], _mm256_set1_epi8(127));
+ qx[2] = _mm256_add_epi8(qx[2], _mm256_set1_epi8(127));
+ qx[3] = _mm256_add_epi8(qx[3], _mm256_set1_epi8(127));
+#endif
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto y128 = _mm_loadu_si128((const __m128i*)q8.y[iy][ibl].qs+ib);
+ auto y = MM256_SET_M128I(y128, y128);
+#ifdef HAVE_FANCY_SIMD
+ isum[iy] = _mm256_dpbusd_epi32(isum[iy], qx[0], _mm256_shuffle_epi32(y, 0x00));
+ isum[iy] = _mm256_dpbusd_epi32(isum[iy], qx[1], _mm256_shuffle_epi32(y, 0x55));
+ isum[iy] = _mm256_dpbusd_epi32(isum[iy], qx[2], _mm256_shuffle_epi32(y, 0xaa));
+ isum[iy] = _mm256_dpbusd_epi32(isum[iy], qx[3], _mm256_shuffle_epi32(y, 0xff));
+#else
+ auto sumi1 = _mm256_madd_epi16(m1, _mm256_maddubs_epi16(s0, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), qx[0])));
+ auto sumi2 = _mm256_madd_epi16(m1, _mm256_maddubs_epi16(s1, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), qx[1])));
+ auto sumi3 = _mm256_madd_epi16(m1, _mm256_maddubs_epi16(s2, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), qx[2])));
+ auto sumi4 = _mm256_madd_epi16(m1, _mm256_maddubs_epi16(s3, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), qx[3])));
+ isum[iy] = _mm256_add_epi32(isum[iy], _mm256_add_epi32(sumi1, sumi2));
+ isum[iy] = _mm256_add_epi32(isum[iy], _mm256_add_epi32(sumi3, sumi4));
+#endif
+ }
+ }
+#ifdef HAVE_FANCY_SIMD
+ auto m4 = _mm256_mul_ps(d4, _mm256_set1_ps(-128.f));
+#endif
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto d4y = _mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(iy, ibl)));
+ acc[iy] = _mm256_fmadd_ps(d4y, _mm256_cvtepi32_ps(isum[iy]), acc[iy]);
+#ifdef HAVE_FANCY_SIMD
+ auto bsums = (const float *)q8.y[iy][ibl].bsums;
+ acc[iy] = _mm256_fmadd_ps(m4, _mm256_set1_ps(bsums[0]), acc[iy]);
+#endif
+ isum[iy] = _mm256_setzero_si256();
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ info.store(ix, iy, acc[iy]);
+ acc[iy] = _mm256_setzero_ps();
+ }
+ }
+}
+
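+// GEMM for q8_KV x q8_KV: each row stores a float scale (and, on the VNNI path, an int32
+// row sum) in front of the int8 quants. Four x rows are transposed in 4-byte groups so each
+// dot-product lane belongs to one row; the +127 bias needed for dpbusd is undone via sy.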
+template <int nrc_y>
+static void mul_mat_q8_KV_q8_KV(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ GGML_ASSERT(nrc_x%4 == 0);
+ GGML_ASSERT(n%32 == 0);
+ __m256i qx[4];
+#ifndef HAVE_FANCY_SIMD
+ __m256i sx[4];
+ auto m1 = _mm256_set1_epi16(1);
+#endif
+ __m256i acc[nrc_y] = {};
+ float dy[nrc_y];
+#ifdef HAVE_FANCY_SIMD
+ int32_t sy[nrc_y];
+#endif
+ const int8_t * q8y[nrc_y];
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto dptr = (const float *)info.src1_row(iy);
+ dy[iy] = dptr[0];
+#ifdef HAVE_FANCY_SIMD
+ auto iptr = (const int32_t *)(dptr + 1);
+ sy[iy] = -127*iptr[0];
+#endif
+ q8y[iy] = (const int8_t *)(dptr + 2);
+ }
+ const int8_t * q8x[4];
+ float dx[4];
+ for (int ix = 0; ix < nrc_x; ix += 4) {
+ for (int kx = 0; kx < 4; ++kx) {
+ auto dptr = (const float *)((const char *)vx + (ix+kx)*bx);
+ dx[kx] = dptr[0];
+ q8x[kx] = (const int8_t *)(dptr + 2);
+ }
+ for (int i = 0; i < n/32; ++i) {
+ for (int kx = 0; kx < 4; ++kx) qx[kx] = _mm256_loadu_si256((const __m256i *)q8x[kx] + i);
+ auto t0 = _mm256_unpacklo_epi32(qx[0], qx[1]);
+ auto t1 = _mm256_unpacklo_epi32(qx[2], qx[3]);
+ auto t2 = _mm256_unpackhi_epi32(qx[0], qx[1]);
+ auto t3 = _mm256_unpackhi_epi32(qx[2], qx[3]);
+#ifdef HAVE_FANCY_SIMD
+ qx[0] = _mm256_add_epi8(_mm256_unpacklo_epi64(t0, t1), _mm256_set1_epi8(127));
+ qx[1] = _mm256_add_epi8(_mm256_unpackhi_epi64(t0, t1), _mm256_set1_epi8(127));
+ qx[2] = _mm256_add_epi8(_mm256_unpacklo_epi64(t2, t3), _mm256_set1_epi8(127));
+ qx[3] = _mm256_add_epi8(_mm256_unpackhi_epi64(t2, t3), _mm256_set1_epi8(127));
+#else
+ qx[0] = _mm256_unpacklo_epi64(t0, t1); sx[0] = _mm256_sign_epi8(qx[0], qx[0]);
+ qx[1] = _mm256_unpackhi_epi64(t0, t1); sx[1] = _mm256_sign_epi8(qx[1], qx[1]);
+ qx[2] = _mm256_unpacklo_epi64(t2, t3); sx[2] = _mm256_sign_epi8(qx[2], qx[2]);
+ qx[3] = _mm256_unpackhi_epi64(t2, t3); sx[3] = _mm256_sign_epi8(qx[3], qx[3]);
+#endif
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto y = _mm256_loadu_si256((const __m256i *)q8y[iy] + i);
+#ifdef HAVE_FANCY_SIMD
+ acc[iy] = _mm256_dpbusd_epi32(acc[iy], qx[0], _mm256_shuffle_epi32(y, 0x00));
+ acc[iy] = _mm256_dpbusd_epi32(acc[iy], qx[1], _mm256_shuffle_epi32(y, 0x55));
+ acc[iy] = _mm256_dpbusd_epi32(acc[iy], qx[2], _mm256_shuffle_epi32(y, 0xaa));
+ acc[iy] = _mm256_dpbusd_epi32(acc[iy], qx[3], _mm256_shuffle_epi32(y, 0xff));
+#else
+ auto dot1 = _mm256_maddubs_epi16(sx[0], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), qx[0]));
+ auto dot2 = _mm256_maddubs_epi16(sx[1], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), qx[1]));
+ auto dot3 = _mm256_maddubs_epi16(sx[2], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), qx[2]));
+ auto dot4 = _mm256_maddubs_epi16(sx[3], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), qx[3]));
+ auto dot12 = _mm256_add_epi32(_mm256_madd_epi16(m1, dot1), _mm256_madd_epi16(m1, dot2));
+ auto dot34 = _mm256_add_epi32(_mm256_madd_epi16(m1, dot3), _mm256_madd_epi16(m1, dot4));
+ acc[iy] = _mm256_add_epi32(acc[iy], _mm256_add_epi32(dot12, dot34));
+#endif
+ }
+ }
+ auto scales_x = _mm_loadu_ps(dx);
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto sumi = _mm_add_epi32(_mm256_castsi256_si128(acc[iy]), _mm256_extracti128_si256(acc[iy], 1));
+#ifdef HAVE_FANCY_SIMD
+ sumi = _mm_add_epi32(sumi, _mm_set1_epi32(sy[iy]));
+#endif
+ auto scale = _mm_mul_ps(scales_x, _mm_set1_ps(dy[iy]));
+ info.store(ix, iy, _mm_mul_ps(scale, _mm_cvtepi32_ps(sumi)));
+ acc[iy] = _mm256_setzero_si256();
+ }
+ }
+}
+
+// Note: HAVE_FANCY_SIMD should only be defined as #if defined(__AVX512VNNI__) && defined(__AVX512VL__)
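+// GEMM for q8_KV_r8 (8 interleaved rows) x q8_KV, using the same +127 bias / row-sum
+// correction on the VNNI path as above.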
+template <int nrc_y>
+static void mul_mat_q8_KV_r8_q8_KV(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ GGML_ASSERT(n%32 == 0);
+ GGML_ASSERT(nrc_x%8 == 0);
+#ifndef HAVE_FANCY_SIMD
+ auto m1 = _mm256_set1_epi16(1);
+#endif
+ int nb = n / 16;
+ __m256i acc[nrc_y] = {};
+ __m256i qx[4];
+ float dy[nrc_y];
+#ifdef HAVE_FANCY_SIMD
+ int32_t sy[nrc_y];
+#endif
+ const int8_t * q8y[nrc_y];
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto dptr = (const float *)info.src1_row(iy);
+ dy[iy] = dptr[0];
+#ifdef HAVE_FANCY_SIMD
+ auto iptr = (const int32_t *)(dptr + 1);
+ sy[iy] = -127*iptr[0];
+#endif
+ q8y[iy] = (const int8_t *)(dptr + 2);
+ }
+ for (int ix = 0; ix < nrc_x; ix += 8) {
+ auto dptr = (const float *)((const char *)vx + ix*bx);
+ auto dx = _mm256_loadu_ps(dptr);
+ auto q8x = (const int8_t *)(dptr + 8);
+ for (int ib = 0; ib < nb; ++ib) { // Blocks of 16 for 8 interleaved rows
+ qx[0] = _mm256_loadu_si256((const __m256i *)q8x+4*ib+0);
+ qx[1] = _mm256_loadu_si256((const __m256i *)q8x+4*ib+1);
+ qx[2] = _mm256_loadu_si256((const __m256i *)q8x+4*ib+2);
+ qx[3] = _mm256_loadu_si256((const __m256i *)q8x+4*ib+3);
+#ifndef HAVE_FANCY_SIMD
+ auto s0 = _mm256_sign_epi8(qx[0], qx[0]);
+ auto s1 = _mm256_sign_epi8(qx[1], qx[1]);
+ auto s2 = _mm256_sign_epi8(qx[2], qx[2]);
+ auto s3 = _mm256_sign_epi8(qx[3], qx[3]);
+#else
+ qx[0] = _mm256_add_epi8(qx[0], _mm256_set1_epi8(127));
+ qx[1] = _mm256_add_epi8(qx[1], _mm256_set1_epi8(127));
+ qx[2] = _mm256_add_epi8(qx[2], _mm256_set1_epi8(127));
+ qx[3] = _mm256_add_epi8(qx[3], _mm256_set1_epi8(127));
+#endif
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto y128 = _mm_loadu_si128((const __m128i*)q8y[iy]+ib);
+ auto y = MM256_SET_M128I(y128, y128);
+#ifdef HAVE_FANCY_SIMD
+ acc[iy] = _mm256_dpbusd_epi32(acc[iy], qx[0], _mm256_shuffle_epi32(y, 0x00));
+ acc[iy] = _mm256_dpbusd_epi32(acc[iy], qx[1], _mm256_shuffle_epi32(y, 0x55));
+ acc[iy] = _mm256_dpbusd_epi32(acc[iy], qx[2], _mm256_shuffle_epi32(y, 0xaa));
+ acc[iy] = _mm256_dpbusd_epi32(acc[iy], qx[3], _mm256_shuffle_epi32(y, 0xff));
+#else
+ auto sumi1 = _mm256_maddubs_epi16(s0, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), qx[0]));
+ auto sumi2 = _mm256_maddubs_epi16(s1, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), qx[1]));
+ auto sumi3 = _mm256_maddubs_epi16(s2, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), qx[2]));
+ auto sumi4 = _mm256_maddubs_epi16(s3, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), qx[3]));
+ auto sumi12 = _mm256_add_epi32(_mm256_madd_epi16(m1, sumi1), _mm256_madd_epi16(m1, sumi2));
+ auto sumi34 = _mm256_add_epi32(_mm256_madd_epi16(m1, sumi3), _mm256_madd_epi16(m1, sumi4));
+ acc[iy] = _mm256_add_epi32(acc[iy], _mm256_add_epi32(sumi12, sumi34));
+#endif
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto scale = _mm256_mul_ps(dx, _mm256_set1_ps(dy[iy]));
+#ifdef HAVE_FANCY_SIMD
+ acc[iy] = _mm256_add_epi32(acc[iy], _mm256_set1_epi32(sy[iy]));
+#endif
+ info.store(ix, iy, _mm256_mul_ps(scale, _mm256_cvtepi32_ps(acc[iy])));
+ acc[iy] = _mm256_setzero_si256();
+ }
+ }
+}
+
+} // namespace
+
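+// Entry point: checks that the activation type matches what the given typeA expects and
+// installs the corresponding GEMM kernels (plus an optional 16-column kernel in func16).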
+bool iqk_set_kernels_kquants(int ne00, int typeA, int typeB, std::array<mul_mat_t, IQK_MAX_NY>& kernels, mul_mat_t& func16) {
+
+ auto etypeA = ggml_type(typeA);
+ auto expected_type_B = etypeA == GGML_TYPE_IQ4_XS_R8 || etypeA == GGML_TYPE_Q4_K_R4 || etypeA == GGML_TYPE_Q5_K_R4 ? GGML_TYPE_Q8_K32
+ : etypeA == GGML_TYPE_Q8_K_R8 ? GGML_TYPE_Q8_KR8
+ : etypeA == GGML_TYPE_Q8_KV || etypeA == GGML_TYPE_Q8_KV_R8 ? GGML_TYPE_Q8_KV
+ : GGML_TYPE_Q8_K;
+
+ if (ne00%QK_K != 0 || ggml_type(typeB) != expected_type_B) {
+ return false;
+ }
+
+ func16 = nullptr;
+
+ switch (typeA) {
+ case GGML_TYPE_Q2_K:
+ set_functions<DequantizerQ2K>(kernels);
+ break;
+ case GGML_TYPE_Q3_K:
+ set_functions<DequantizerQ3K>(kernels);
+ break;
+ case GGML_TYPE_Q4_K:
+ set_functions<DequantizerQ4K>(kernels);
+ break;
+ case GGML_TYPE_Q5_K:
+ set_functions<DequantizerQ5K>(kernels);
+ break;
+ case GGML_TYPE_Q6_K:
+ set_functions<DequantizerQ6K>(kernels);
+ break;
+ case GGML_TYPE_IQ4_XS:
+ set_functions<DequantizerIQ4XS>(kernels);
+ break;
+ case GGML_TYPE_Q2_K_R4:
+ IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_q2_k_r4_q8_k, kernels)
+ break;
+ case GGML_TYPE_Q3_K_R4:
+ IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_q3_k_r4_q8_k, kernels)
+ break;
+ case GGML_TYPE_Q4_K_R4:
+ IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_q4_k_r4_q8_k, kernels)
+ break;
+ case GGML_TYPE_Q5_K_R4:
+ IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_q5_k_r4_q8_k, kernels)
+ break;
+ case GGML_TYPE_Q6_K_R4:
+ IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_q6_k_r4_q8_k, kernels)
+ break;
+ case GGML_TYPE_IQ4_XS_R8:
+ IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq4_xs_r8_q8_k_avx2, kernels)
+ break;
+ case GGML_TYPE_Q8_K_R8:
+ IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_q8_k_r8_q8_k, kernels)
+#ifdef HAVE_FANCY_SIMD
+ func16 = mul_mat_q8_k_r8_q8_k<16>;
+#endif
+ break;
+ case GGML_TYPE_Q8_KV:
+ IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_q8_KV_q8_KV, kernels)
+#ifdef HAVE_FANCY_SIMD
+ func16 = mul_mat_q8_KV_q8_KV<16>;
+#endif
+ break;
+ case GGML_TYPE_Q8_KV_R8:
+ IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_q8_KV_r8_q8_KV, kernels);
+ break;
+ default:
+ return false;
+ }
+
+ return true;
+
+}
+
+#else
+// --------------------------------- __aarch64__ --------------------------------------
+
+namespace {
+
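+// NEON helpers: accum_mins_8/accum_mins_16 accumulate the block-minimum corrections
+// (mins x q8 block sums, scaled by the constant c passed by the caller).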
+template <typename Q8>
+inline void accum_mins_8(const int16x8_t& mins, const Q8& q8, float32x4_t * acc, int i, float c) {
+ for (int iy = 0; iy < Q8::nrc_y; ++iy) {
+ auto q8s = q8.load_bsums8(iy, i);
+ int32x4_t b1 = vmull_s16(vget_low_s16(mins), vget_low_s16(q8s));
+ int32x4_t b2 = vmull_s16(vget_high_s16(mins), vget_high_s16(q8s));
+ float32x4_t prod = vcvtq_f32_s32(vaddq_s32(b1, b2));
+ acc[iy] = vmlaq_f32(acc[iy], prod, vdupq_n_f32(c*q8.scale(iy, i)));
+ }
+}
+template <typename Q8>
+inline void accum_mins_16(const int16x8x2_t& mins, const Q8& q8, float32x4_t * acc, int i, float c) {
+ for (int iy = 0; iy < Q8::nrc_y; ++iy) {
+ auto q8s = q8.load_bsums(iy, i);
+ int32x4_t b1 = vmull_s16(vget_low_s16 (mins.val[0]), vget_low_s16 (q8s.val[0]));
+ int32x4_t b2 = vmull_s16(vget_high_s16(mins.val[0]), vget_high_s16(q8s.val[0]));
+ int32x4_t b3 = vmull_s16(vget_low_s16 (mins.val[1]), vget_low_s16 (q8s.val[1]));
+ int32x4_t b4 = vmull_s16(vget_high_s16(mins.val[1]), vget_high_s16(q8s.val[1]));
+ float32x4_t prod = vcvtq_f32_s32(vaddq_s32(vaddq_s32(b1, b2), vaddq_s32(b3, b4)));
+ acc[iy] = vmlaq_f32(acc[iy], prod, vdupq_n_f32(c*q8.scale(iy, i)));
+ }
+}
+
+struct Scales8 {
+ uint32_t utmp[4];
+ const uint8_t * sc8 = (const uint8_t *)utmp;
+ template <typename Q8, typename Qx>
+ inline int32x4x2_t process_scales_mins(const Qx& x, const Q8& q8, int i, float32x4_t * acc) {
+ make_q4_scales(x.scales, utmp);
+ int16x8_t mins = vmovl_s8(vld1_s8((const int8_t *)sc8 + 8));
+ accum_mins_8(mins, q8, acc, i, -GGML_FP16_TO_FP32(x.dmin));
+
+ uint8x8_t scales8 = vld1_u8(sc8);
+ uint16x8_t scales16 = vmovl_u8(scales8);
+ int32x4x2_t scales = {vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(scales16))),
+ vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(scales16)))};
+ return scales;
+ }
+};
+
+struct DequantizerQ4K final : public BaseDequantizer<block_q4_K> {
+ DequantizerQ4K(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {}
+
+ constexpr static int num_blocks() { return 8; }
+ constexpr static bool should_scale_quants() { return false; }
+
+ template <typename Q8>
+ inline int32x4x2_t new_block(int i, const Q8& q8, float32x4_t * acc) {
+ d = GGML_FP16_TO_FP32(x[i].d);
+ return s8.process_scales_mins(x[i], q8, i, acc);
+ }
+ inline void prepare(int i, int j) {
+ if (nrc == 1) bits.prepare_v2(x[i].qs+64*j);
+ else bits.prepare(x[i].qs+64*j);
+ }
+
+ Q4bits bits;
+ Scales8 s8;
+
+};
+
+struct HighBit5 {
+ const uint8x16_t mhb = vdupq_n_u8(0x10);
+ uint8x16x2_t bits;
+ inline void apply(uint8x16x4_t& b1, uint8x16x4_t& b2, bool do_shift) {
+ b1.val[0] = vorrq_u8(b1.val[0], vandq_u8(vshlq_n_u8(bits.val[0], 4), mhb));
+ b1.val[1] = vorrq_u8(b1.val[1], vandq_u8(vshlq_n_u8(bits.val[1], 4), mhb));
+ b1.val[2] = vorrq_u8(b1.val[2], vandq_u8(vshlq_n_u8(bits.val[0], 3), mhb));
+ b1.val[3] = vorrq_u8(b1.val[3], vandq_u8(vshlq_n_u8(bits.val[1], 3), mhb));
+
+ b2.val[0] = vorrq_u8(b2.val[0], vandq_u8(vshlq_n_u8(bits.val[0], 2), mhb));
+ b2.val[1] = vorrq_u8(b2.val[1], vandq_u8(vshlq_n_u8(bits.val[1], 2), mhb));
+ b2.val[2] = vorrq_u8(b2.val[2], vandq_u8(vshlq_n_u8(bits.val[0], 1), mhb));
+ b2.val[3] = vorrq_u8(b2.val[3], vandq_u8(vshlq_n_u8(bits.val[1], 1), mhb));
+
+ if (do_shift) {
+ bits.val[0] = vshrq_n_u8(bits.val[0], 4);
+ bits.val[1] = vshrq_n_u8(bits.val[1], 4);
+ }
+ }
+};
+
+struct HighBit3 {
+ const uint8x16_t mhb = vdupq_n_u8(0x04);
+ uint8x16x2_t bits;
+ inline void apply(uint8x16x4_t& b1, uint8x16x4_t& b2, bool do_shift) {
+ b1.val[0] = vorrq_u8(b1.val[0], vandq_u8(vshlq_n_u8(bits.val[0], 2), mhb));
+ b1.val[1] = vorrq_u8(b1.val[1], vandq_u8(vshlq_n_u8(bits.val[1], 2), mhb));
+ b1.val[2] = vorrq_u8(b1.val[2], vandq_u8(vshlq_n_u8(bits.val[0], 1), mhb));
+ b1.val[3] = vorrq_u8(b1.val[3], vandq_u8(vshlq_n_u8(bits.val[1], 1), mhb));
+
+ b2.val[0] = vorrq_u8(b2.val[0], vandq_u8(bits.val[0], mhb));
+ b2.val[1] = vorrq_u8(b2.val[1], vandq_u8(bits.val[1], mhb));
+ b2.val[2] = vorrq_u8(b2.val[2], vandq_u8(vshrq_n_u8(bits.val[0], 1), mhb));
+ b2.val[3] = vorrq_u8(b2.val[3], vandq_u8(vshrq_n_u8(bits.val[1], 1), mhb));
+
+ if (do_shift) {
+ bits.val[0] = vshrq_n_u8(bits.val[0], 4);
+ bits.val[1] = vshrq_n_u8(bits.val[1], 4);
+ }
+ }
+};
+
+struct DequantizerQ5K final : public BaseDequantizer<block_q5_K> {
+ DequantizerQ5K(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {}
+
+ constexpr static int num_blocks() { return 8; }
+ constexpr static bool should_scale_quants() { return false; }
+
+ template <typename Q8>
+ inline int32x4x2_t new_block(int i, const Q8& q8, float32x4_t * acc) {
+ d = GGML_FP16_TO_FP32(x[i].d);
+ h.bits = vld1q_u8_x2(x[i].qh);
+ return s8.process_scales_mins(x[i], q8, i, acc);
+ }
+ inline void prepare(int i, int j) {
+ if (nrc == 1) bits.prepare_v2(x[i].qs+64*j);
+ else bits.prepare(x[i].qs+64*j);
+ h.apply(bits.b1, bits.b2, j == 0);
+ }
+
+ Q4bits bits;
+ HighBit5 h;
+ Scales8 s8;
+
+ uint8x16x2_t hbits;
+
+};
+
+inline int32x4x4_t make_wider(const int16x8x2_t& scales16) {
+ int32x4x4_t scales = {
+ vmovl_s16(vget_low_s16 (scales16.val[0])),
+ vmovl_s16(vget_high_s16(scales16.val[0])),
+ vmovl_s16(vget_low_s16 (scales16.val[1])),
+ vmovl_s16(vget_high_s16(scales16.val[1])),
+ };
+ return scales;
+}
+
+template <typename Q8>
+inline int32x4x4_t process_scales_mins_16(const int8x16_t& scales8, const Q8& q8, float32x4_t * acc, int i, float c) {
+ int16x8x2_t scales16;
+ scales16.val[0] = vmovl_s8(vget_low_s8(scales8));
+ scales16.val[1] = vmovl_s8(vget_high_s8(scales8));
+ accum_mins_16(scales16, q8, acc, i, c);
+ return make_wider(scales16);
+}
+
+struct DequantizerQ6K final : public BaseDequantizer<block_q6_K> {
+ DequantizerQ6K(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {}
+
+ constexpr static int num_blocks() { return 16; }
+ constexpr static bool should_scale_quants() { return false; }
+
+ template <typename Q8>
+ inline int32x4x4_t new_block(int i, const Q8& q8, float32x4_t * acc) {
+ d = GGML_FP16_TO_FP32(x[i].d);
+ return process_scales_mins_16(vld1q_s8(x[i].scales), q8, acc, i, -32.f*d);
+ }
+ inline void prepare(int i, int j) {
+
+ auto hbits = vld1q_u8_x2(x[i].qh + 32*j);
+
+ bits.prepare64(x[i].ql+64*j);
+ bits.b1.val[0] = vorrq_u8(bits.b1.val[0], vandq_u8(vshlq_n_u8(hbits.val[0], 4), mhb));
+ bits.b1.val[1] = vorrq_u8(bits.b1.val[1], vandq_u8(vshlq_n_u8(hbits.val[1], 4), mhb));
+ bits.b1.val[2] = vorrq_u8(bits.b1.val[2], vandq_u8(vshlq_n_u8(hbits.val[0], 2), mhb));
+ bits.b1.val[3] = vorrq_u8(bits.b1.val[3], vandq_u8(vshlq_n_u8(hbits.val[1], 2), mhb));
+
+ bits.b2.val[0] = vorrq_u8(bits.b2.val[0], vandq_u8(hbits.val[0], mhb));
+ bits.b2.val[1] = vorrq_u8(bits.b2.val[1], vandq_u8(hbits.val[1], mhb));
+ bits.b2.val[2] = vorrq_u8(bits.b2.val[2], vandq_u8(vshrq_n_u8(hbits.val[0], 2), mhb));
+ bits.b2.val[3] = vorrq_u8(bits.b2.val[3], vandq_u8(vshrq_n_u8(hbits.val[1], 2), mhb));
+
+ }
+
+ Q4bits bits;
+
+ const uint8x16_t mhb = vdupq_n_u8(0x30);
+
+};
+
+struct DequantizerQ3K final : public BaseDequantizer<block_q3_K> {
+ DequantizerQ3K(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {}
+
+ constexpr static int num_blocks() { return 16; }
+ constexpr static bool should_scale_quants() { return false; }
+
+ template <typename Q8>
+ inline int32x4x4_t new_block(int i, const Q8& q8, float32x4_t * acc) {
+ d = GGML_FP16_TO_FP32(x[i].d);
+ h.bits = vld1q_u8_x2(x[i].hmask);
+ mask = vdupq_n_u8(0x01);
+ const uint16_t * sc16 = (const uint16_t *)x[i].scales;
+ uint32_t aux0 = sc16[0] | (sc16[1] << 16);
+ uint32_t aux1 = sc16[2] | (sc16[3] << 16);
+ uint32_t aux2 = sc16[4] | (sc16[5] << 16);
+ aux32[0] = (aux0 & 0x0f0f0f0f) | ((aux2 << 4) & 0x30303030);
+ aux32[1] = (aux1 & 0x0f0f0f0f) | ((aux2 << 2) & 0x30303030);
+ aux32[2] = ((aux0 >> 4) & 0x0f0f0f0f) | ((aux2 >> 0) & 0x30303030);
+ aux32[3] = ((aux1 >> 4) & 0x0f0f0f0f) | ((aux2 >> 2) & 0x30303030);
+ auto scales8 = vaddq_s8(vld1q_s8((const int8_t *)aux32), vdupq_n_s8(-32));
+ if (nrc > 1) {
+ return process_scales_mins_16(scales8, q8, acc, i, -4.f*d);
+ }
+ int16x8x2_t scales16;
+ scales16.val[0] = vmovl_s8(vget_low_s8(scales8));
+ scales16.val[1] = vmovl_s8(vget_high_s8(scales8));
+ return make_wider(scales16);
+ }
+
+ inline void prepare(int i, int j) {
+ bits.prepare(x[i].qs+32*j);
+ if (nrc > 1) {
+ h.apply(bits.b1, bits.b2, j == 0);
+ } else {
+ auto minus4 = vdupq_n_u8(0xfc);
+ auto zero = vdupq_n_u8(0);
+ bits.b1.val[0] = vorrq_u8(bits.b1.val[0], vandq_u8(minus4, vceqq_u8(vandq_u8(h.bits.val[0], mask), zero)));
+ bits.b1.val[1] = vorrq_u8(bits.b1.val[1], vandq_u8(minus4, vceqq_u8(vandq_u8(h.bits.val[1], mask), zero)));
+ mask = vshlq_n_u8(mask, 1);
+ bits.b1.val[2] = vorrq_u8(bits.b1.val[2], vandq_u8(minus4, vceqq_u8(vandq_u8(h.bits.val[0], mask), zero)));
+ bits.b1.val[3] = vorrq_u8(bits.b1.val[3], vandq_u8(minus4, vceqq_u8(vandq_u8(h.bits.val[1], mask), zero)));
+ mask = vshlq_n_u8(mask, 1);
+ bits.b2.val[0] = vorrq_u8(bits.b2.val[0], vandq_u8(minus4, vceqq_u8(vandq_u8(h.bits.val[0], mask), zero)));
+ bits.b2.val[1] = vorrq_u8(bits.b2.val[1], vandq_u8(minus4, vceqq_u8(vandq_u8(h.bits.val[1], mask), zero)));
+ mask = vshlq_n_u8(mask, 1);
+ bits.b2.val[2] = vorrq_u8(bits.b2.val[2], vandq_u8(minus4, vceqq_u8(vandq_u8(h.bits.val[0], mask), zero)));
+ bits.b2.val[3] = vorrq_u8(bits.b2.val[3], vandq_u8(minus4, vceqq_u8(vandq_u8(h.bits.val[1], mask), zero)));
+ mask = vshlq_n_u8(mask, 1);
+ }
+ }
+
+ uint32_t aux32[4];
+
+ Q2bits bits;
+
+ uint8x16_t mask;
+ HighBit3 h;
+
+};
+
+struct DequantizerQ2K final : public BaseDequantizer<block_q2_K> {
+ DequantizerQ2K(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {}
+
+ constexpr static int num_blocks() { return 16; }
+ constexpr static bool should_scale_quants() { return true; }
+
+ template <typename Q8>
+ inline void process_scales(int i, const Q8& q8, float32x4_t * acc) {
+ d = GGML_FP16_TO_FP32(x[i].d);
+ auto scales_and_mins = vld1q_u8(x[i].scales);
+ auto mins8 = vreinterpretq_s8_u8(vshrq_n_u8(scales_and_mins, 4));
+ int16x8x2_t scales16;
+ scales16.val[0] = vmovl_s8(vget_low_s8(mins8));
+ scales16.val[1] = vmovl_s8(vget_high_s8(mins8));
+ accum_mins_16(scales16, q8, acc, i, -GGML_FP16_TO_FP32(x[i].dmin));
+
+ scales8 = vandq_u8(scales_and_mins, vdupq_n_u8(0xf));
+ }
+
+ template <typename Q8>
+ inline int32x4x4_t new_block(int i, const Q8& q8, float32x4_t * acc) {
+ process_scales(i, q8, acc);
+ int16x8x2_t scales16;
+ scales16.val[0] = vmovl_s8(vget_low_s8(vreinterpretq_s8_u8(scales8)));
+ scales16.val[1] = vmovl_s8(vget_high_s8(vreinterpretq_s8_u8(scales8)));
+ return make_wider(scales16);
+ }
+
+ template <typename Q8>
+ inline void compute(const Q8& q8, int i, int j, int32x4_t * sumi) {
+ auto m1 = vdupq_n_u8(1);
+ auto shuffle = vdupq_n_u8(8*j);
+ bits.b1.val[0] = vmulq_u8(bits.b1.val[0], vqtbl1q_u8(scales8, shuffle)); shuffle = vaddq_u8(shuffle, m1);
+ bits.b1.val[1] = vmulq_u8(bits.b1.val[1], vqtbl1q_u8(scales8, shuffle)); shuffle = vaddq_u8(shuffle, m1);
+ bits.b1.val[2] = vmulq_u8(bits.b1.val[2], vqtbl1q_u8(scales8, shuffle)); shuffle = vaddq_u8(shuffle, m1);
+ bits.b1.val[3] = vmulq_u8(bits.b1.val[3], vqtbl1q_u8(scales8, shuffle)); shuffle = vaddq_u8(shuffle, m1);
+ bits.b2.val[0] = vmulq_u8(bits.b2.val[0], vqtbl1q_u8(scales8, shuffle)); shuffle = vaddq_u8(shuffle, m1);
+ bits.b2.val[1] = vmulq_u8(bits.b2.val[1], vqtbl1q_u8(scales8, shuffle)); shuffle = vaddq_u8(shuffle, m1);
+ bits.b2.val[2] = vmulq_u8(bits.b2.val[2], vqtbl1q_u8(scales8, shuffle)); shuffle = vaddq_u8(shuffle, m1);
+ bits.b2.val[3] = vmulq_u8(bits.b2.val[3], vqtbl1q_u8(scales8, shuffle)); shuffle = vaddq_u8(shuffle, m1);
+ for (int iy = 0; iy < Q8::nrc_y; ++iy) {
+ auto q8b_1 = q8.load_quants(iy, i, 4*j+0);
+ sumi[iy] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[iy], vreinterpretq_s8_u8(bits.b1.val[0]), q8b_1.val[0]),
+ vreinterpretq_s8_u8(bits.b1.val[1]), q8b_1.val[1]);
+
+ auto q8b_2 = q8.load_quants(iy, i, 4*j+1);
+ sumi[iy] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[iy], vreinterpretq_s8_u8(bits.b1.val[2]), q8b_2.val[0]),
+ vreinterpretq_s8_u8(bits.b1.val[3]), q8b_2.val[1]);
+
+ auto q8b_3 = q8.load_quants(iy, i, 4*j+2);
+ sumi[iy] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[iy], vreinterpretq_s8_u8(bits.b2.val[0]), q8b_3.val[0]),
+ vreinterpretq_s8_u8(bits.b2.val[1]), q8b_3.val[1]);
+
+ auto q8b_4 = q8.load_quants(iy, i, 4*j+3);
+ sumi[iy] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[iy], vreinterpretq_s8_u8(bits.b2.val[2]), q8b_4.val[0]),
+ vreinterpretq_s8_u8(bits.b2.val[3]), q8b_4.val[1]);
+ }
+ }
+
+ inline void prepare(int i, int j) {
+ bits.prepare(x[i].qs+32*j);
+ }
+
+ uint32_t aux32[4];
+
+ uint8x16_t scales8;
+
+ Q2bits bits;
+
+};
+
+struct DequantizerIQ4XS final : public BaseDequantizer<block_iq4_xs> {
+
+ static int8x16_t load_values() {
+ static const int8_t iq4nl_values[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};
+ return vld1q_s8(iq4nl_values);
+ }
+
+ DequantizerIQ4XS(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc), values(load_values()) {}
+
+ constexpr static int num_blocks() { return 8; }
+ constexpr static bool should_scale_quants() { return false; }
+
+ inline void new_row(int ix) { x = (const block_iq4_xs *)((const char *)vx + bx*ix); }
+
+ template <typename Q8>
+ inline int32x4x2_t new_block(int i, const Q8& q8, float32x4_t * acc) {
+ (void)q8;
+ (void)acc;
+ d = GGML_FP16_TO_FP32(x[i].d);
+ const uint16_t scales_h = x[i].scales_h;
+ const uint16_t * scales_l = (const uint16_t *)x[i].scales_l;
+ aux32[0] = scales_l[0] | (scales_l[1] << 16);
+ aux32[1] = aux32[0] >> 4;
+ // scl is ordered as 0, 2, 4, 6, 1, 3, 5, 7
+ uint8x8_t scl8 = vand_u8(vld1_u8((const uint8_t *)aux32), vdup_n_u8(0xf));
+ uint16_t * aux16 = (uint16_t *)aux32;
+ aux16[0] = scales_h << 4; aux16[1] = scales_h << 2; aux16[2] = scales_h; aux16[3] = scales_h >> 2;
+ // sch is ordered as 0, 4, 1, 5, 2, 6, 3, 7
+ uint8x8_t sch8 = vand_u8(vld1_u8((const uint8_t *)aux16), vdup_n_u8(0x30));
+ int8x8_t scales8 = vadd_s8(vreinterpret_s8_u8(vorr_u8(scl8, vtbl1_u8(sch8, vreinterpret_u8_u32(hshuff)))), vdup_n_s8(-32));
+ // shuffle 0, 2, 4, 6, 1, 3, 5, 7 -> 0, 1, 2, 3, 4, 5, 6, 7
+ scales8 = vtbl1_s8(scales8, vreinterpret_s8_u32(hshuff));
+ int16x8_t scales16 = vmovl_s8(scales8);
+ int32x4x2_t scales = {vmovl_s16(vget_low_s16(scales16)), vmovl_s16(vget_high_s16(scales16))};
+ return scales;
+ }
+ inline void prepare(int i, int j) {
+ bits.prepare16(x[i].qs+64*j);
+ //if (nrc == 1) {
+ // bits.prepare16_v2(x[i].qs+64*j);
+ //} else {
+ // bits.prepare16(x[i].qs+64*j);
+ //}
+ for (int k = 0; k < 4; ++k) {
+ bits.b1.val[k] = vreinterpretq_u8_s8(vqtbl1q_s8(values, bits.b1.val[k]));
+ bits.b2.val[k] = vreinterpretq_u8_s8(vqtbl1q_s8(values, bits.b2.val[k]));
+ }
+ }
+
+ Q4bits bits;
+ const int8x16_t values;
+ uint32_t aux32[2];
+
+ constexpr static uint32x2_t hshuff = {0x05010400, 0x07030602};
+
+};
+
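+// Splits the low and high nibbles of 4 interleaved q4_k rows into eight int8x16_t registers.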
+IQK_ALWAYS_INLINE void prepare_q4_k_quants(const uint8x16_t& m4, const uint8x16x4_t& bits, int8x16_t * qx) {
+ qx[0] = vandq_u8(bits.val[0], m4); // 0...3 from the 4 rows
+ qx[1] = vandq_u8(bits.val[1], m4); // 16..19
+ qx[2] = vandq_u8(bits.val[2], m4); // 4...7
+ qx[3] = vandq_u8(bits.val[3], m4); // 20..23
+ qx[4] = vshrq_n_u8(bits.val[0], 4); // 8..11
+ qx[5] = vshrq_n_u8(bits.val[1], 4); // 24..27
+ qx[6] = vshrq_n_u8(bits.val[2], 4); // 12..15
+ qx[7] = vshrq_n_u8(bits.val[3], 4); // 28..31
+}
+
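+// NEON GEMM for q2_k_r4 x Q8_K: the upper nibble of each scale byte is the block minimum
+// (applied through the bsums), the lower nibble the block scale for the 2-bit quants.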
+template <int nrc_y>
+void mul_mat_q2_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ GGML_ASSERT(nrc_x%4 == 0);
+ Q8<nrc_y, block_q8_K> q8(info);
+ auto mf = vdupq_n_u8(0x0f);
+ auto m03 = vdupq_n_u8(0x03);
+ int nbl = n / QK_K;
+ int8x16_t qx[4];
+ float32x4_t acc[nrc_y] = {};
+ int16x8x4_t i16scales;
+ for (int ix = 0; ix < nrc_x; ix += 4) {
+ const block_q2_k_r4 * iq2 = (const block_q2_k_r4 *)((const char *)vx + ix*bx);
+ for (int ibl = 0; ibl < nbl; ++ibl) {
+ int32x4_t isum[nrc_y] = {};
+ auto d4 = vcvt_f32_f16(vld1_f16((const float16_t *)iq2[ibl].d));
+ auto m4 = vmulq_f32(vdupq_n_f32(-1.f), vcvt_f32_f16(vld1_f16((const float16_t *)iq2[ibl].d+4)));
+ for (int is = 0; is < 2; ++is) {
+ auto sl = vld1q_u8_x2(iq2[ibl].scales + 32*is);
+ auto m = vshrq_n_u8(sl.val[0], 4);
+ i16scales.val[0] = vmovl_u8(vget_low_u8 (m));
+ i16scales.val[1] = vmovl_u8(vget_high_u8(m));
+ m = vshrq_n_u8(sl.val[1], 4);
+ i16scales.val[2] = vmovl_u8(vget_low_u8 (m));
+ i16scales.val[3] = vmovl_u8(vget_high_u8(m));
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto sumi = vdupq_n_s32(0);
+ auto bsums = vld1q_s16(q8.y[iy][ibl].bsums + 8*is);
+ auto b8 = vget_low_s16(bsums);
+ //auto bsums = q8.load_bsums(iy, ibl);
+ //auto b8 = vget_low_s16(bsums.val[0]);
+ sumi = vmlal_lane_s16(sumi, vget_low_s16 (i16scales.val[0]), b8, 0);
+ sumi = vmlal_lane_s16(sumi, vget_high_s16(i16scales.val[0]), b8, 1);
+ sumi = vmlal_lane_s16(sumi, vget_low_s16 (i16scales.val[1]), b8, 2);
+ sumi = vmlal_lane_s16(sumi, vget_high_s16(i16scales.val[1]), b8, 3);
+ b8 = vget_high_s16(bsums);
+ sumi = vmlal_lane_s16(sumi, vget_low_s16 (i16scales.val[2]), b8, 0);
+ sumi = vmlal_lane_s16(sumi, vget_high_s16(i16scales.val[2]), b8, 1);
+ sumi = vmlal_lane_s16(sumi, vget_low_s16 (i16scales.val[3]), b8, 2);
+ sumi = vmlal_lane_s16(sumi, vget_high_s16(i16scales.val[3]), b8, 3);
+ acc[iy] = vfmaq_f32(acc[iy], vmulq_f32(m4, vdupq_n_f32(q8.scale(iy, ibl))), vcvtq_f32_s32(sumi));
+ }
+ m = vandq_u8(sl.val[0], mf);
+ i16scales.val[0] = vmovl_u8(vget_low_u8 (m));
+ i16scales.val[1] = vmovl_u8(vget_high_u8(m));
+ m = vandq_u8(sl.val[1], mf);
+ i16scales.val[2] = vmovl_u8(vget_low_u8 (m));
+ i16scales.val[3] = vmovl_u8(vget_high_u8(m));
+ for (int ib = 0; ib < 4; ++ib) {
+ auto bits = vld1q_u8_x2(iq2[ibl].qs + 128*is + 32*ib);
+ auto scales = vmovl_s16(vget_low_s16 (i16scales.val[ib]));
+ qx[0] = vreinterpretq_s8_u8(vandq_u8( bits.val[0], m03));
+ qx[1] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(bits.val[0], 2), m03));
+ qx[2] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(bits.val[0], 4), m03));
+ qx[3] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(bits.val[0], 6), m03));
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto y = vld1q_s8(q8.y[iy][ibl].qs+128*is+32*ib);
+ auto sumi = interleaved_dotq(qx, y);
+ isum[iy] = vmlaq_s32(isum[iy], scales, sumi);
+ }
+ scales = vmovl_s16(vget_high_s16(i16scales.val[ib]));
+ qx[0] = vreinterpretq_s8_u8(vandq_u8( bits.val[1], m03));
+ qx[1] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(bits.val[1], 2), m03));
+ qx[2] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(bits.val[1], 4), m03));
+ qx[3] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(bits.val[1], 6), m03));
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto y = vld1q_s8(q8.y[iy][ibl].qs+128*is+32*ib+16);
+ auto sumi = interleaved_dotq(qx, y);
+ isum[iy] = vmlaq_s32(isum[iy], scales, sumi);
+ }
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ acc[iy] = vfmaq_f32(acc[iy], vmulq_f32(d4, vdupq_n_f32(q8.scale(iy, ibl))), vcvtq_f32_s32(isum[iy]));
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ info.store(ix, iy, acc[iy]);
+ acc[iy] = vdupq_n_f32(0.f);
+ }
+ }
+}
+
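+// NEON GEMM for q3_k_r4 x Q8_K: the inverted high bits subtract 4 from the 2-bit quants,
+// giving signed values in -4...3 before the interleaved dot products.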
+template <int nrc_y>
+void mul_mat_q3_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ GGML_ASSERT(nrc_x%4 == 0);
+ Q8<nrc_y, block_q8_K> q8(info);
+ auto mf = vdupq_n_u8(0x0f);
+ auto m30 = vdupq_n_u8(0x30);
+ auto m32 = vdupq_n_s8(-32);
+ auto m03 = vdupq_n_u8(0x03);
+ auto m04 = vdupq_n_u8(0x04);
+ int nbl = n / QK_K;
+ int8x16_t qx[4];
+ float32x4_t acc[nrc_y] = {};
+ int8x16x4_t i8scales;
+ int16x8x4_t i16scales;
+ for (int ix = 0; ix < nrc_x; ix += 4) {
+ const block_q3_k_r4 * iq3 = (const block_q3_k_r4 *)((const char *)vx + ix*bx);
+ for (int ibl = 0; ibl < nbl; ++ibl) {
+ int32x4_t isum[nrc_y] = {};
+ auto d4 = vcvt_f32_f16(vld1_f16((const float16_t *)iq3[ibl].d));
+ auto sl = vld1q_u8_x2(iq3[ibl].scales_l);
+ auto sh = vld1q_u8(iq3[ibl].scales_h);
+ i8scales.val[0] = vaddq_s8(m32, vorrq_u8(vandq_u8(sl.val[0], mf), vandq_u8(vshlq_n_u8(sh, 4), m30)));
+ i8scales.val[1] = vaddq_s8(m32, vorrq_u8(vandq_u8(sl.val[1], mf), vandq_u8(vshlq_n_u8(sh, 2), m30)));
+ i8scales.val[2] = vaddq_s8(m32, vorrq_u8(vshrq_n_u8(sl.val[0], 4), vandq_u8(sh, m30)));
+ i8scales.val[3] = vaddq_s8(m32, vorrq_u8(vshrq_n_u8(sl.val[1], 4), vandq_u8(vshrq_n_u8(sh, 2), m30)));
+ for (int is = 0; is < 2; ++is) {
+ i16scales.val[0] = vmovl_s8(vget_low_s8 (i8scales.val[2*is+0]));
+ i16scales.val[1] = vmovl_s8(vget_high_s8(i8scales.val[2*is+0]));
+ i16scales.val[2] = vmovl_s8(vget_low_s8 (i8scales.val[2*is+1]));
+ i16scales.val[3] = vmovl_s8(vget_high_s8(i8scales.val[2*is+1]));
+ for (int ib = 0; ib < 4; ++ib) {
+ auto lbits = vld1q_u8_x2(iq3[ibl].qs + 128*is + 32*ib);
+ auto hbits = vld1q_u8(iq3[ibl].qh + 64*is + 16*ib);
+ hbits = veorq_u8(hbits, vdupq_n_u8(0xff));
+ auto scales = vmovl_s16(vget_low_s16 (i16scales.val[ib]));
+ qx[0] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8( lbits.val[0], m03)), vreinterpretq_s8_u8(vandq_u8(m04, vshlq_n_u8(hbits, 2))));
+ qx[1] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(lbits.val[0], 2), m03)), vreinterpretq_s8_u8(vandq_u8(m04, vshlq_n_u8(hbits, 1))));
+ qx[2] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(lbits.val[0], 4), m03)), vreinterpretq_s8_u8(vandq_u8(m04, hbits)));
+ qx[3] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(lbits.val[0], 6), m03)), vreinterpretq_s8_u8(vandq_u8(m04, vshrq_n_u8(hbits, 1))));
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto y = vld1q_s8(q8.y[iy][ibl].qs+128*is+32*ib);
+ auto sumi = interleaved_dotq(qx, y);
+ isum[iy] = vmlaq_s32(isum[iy], scales, sumi);
+ }
+ scales = vmovl_s16(vget_high_s16(i16scales.val[ib]));
+ qx[0] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8( lbits.val[1], m03)), vreinterpretq_s8_u8(vandq_u8(m04, vshrq_n_u8(hbits, 2))));
+ qx[1] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(lbits.val[1], 2), m03)), vreinterpretq_s8_u8(vandq_u8(m04, vshrq_n_u8(hbits, 3))));
+ qx[2] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(lbits.val[1], 4), m03)), vreinterpretq_s8_u8(vandq_u8(m04, vshrq_n_u8(hbits, 4))));
+ qx[3] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(lbits.val[1], 6), m03)), vreinterpretq_s8_u8(vandq_u8(m04, vshrq_n_u8(hbits, 5))));
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto y = vld1q_s8(q8.y[iy][ibl].qs+128*is+32*ib+16);
+ auto sumi = interleaved_dotq(qx, y);
+ isum[iy] = vmlaq_s32(isum[iy], scales, sumi);
+ }
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ acc[iy] = vfmaq_f32(acc[iy], vmulq_f32(d4, vdupq_n_f32(q8.scale(iy, ibl))), vcvtq_f32_s32(isum[iy]));
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ info.store(ix, iy, acc[iy]);
+ acc[iy] = vdupq_n_f32(0.f);
+ }
+ }
+}
+
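+// NEON GEMM for q4_k_r4 x Q8_K: the 6-bit mins are applied first via the float block sums,
+// then the 6-bit scales weight the interleaved nibble dot products.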
+template <int nrc_y>
+void mul_mat_q4_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ GGML_ASSERT(nrc_x%4 == 0);
+ Q8<nrc_y, block_q8_K> q8(info);
+ auto mf = vdupq_n_u8(0xf);
+ auto m3 = vdupq_n_u8(0x30);
+ int nbl = n / QK_K;
+ int8x16_t qx[8];
+ int8x16x2_t iscales;
+ int32x4x4_t scales;
+ float32x4_t acc[nrc_y] = {};
+ for (int ix = 0; ix < nrc_x; ix += 4) {
+ const block_q4_k_r4 * iq4 = (const block_q4_k_r4 *)((const char *)vx + ix*bx);
+ for (int ibl = 0; ibl < nbl; ++ibl) {
+ auto d4 = vcvt_f32_f16(vld1_f16((const float16_t *)iq4[ibl].d));
+ auto m4 = vcvt_f32_f16(vld1_f16((const float16_t *)iq4[ibl].d+4));
+ m4 = vmulq_f32(m4, vdupq_n_f32(-1.f));
+ auto sl = vld1q_u8_x2(iq4[ibl].scales_l);
+ auto sh = vld1q_u8(iq4[ibl].scales_h);
+ iscales.val[0] = vorrq_u8(vshrq_n_u8(sl.val[0], 4), vandq_u8(vshlq_n_u8(sh, 2), m3));
+ iscales.val[1] = vorrq_u8(vshrq_n_u8(sl.val[1], 4), vandq_u8(vshrq_n_u8(sh, 2), m3));
+ for (int is = 0; is < 2; ++is) {
+ auto iscales16_1 = vmovl_s8(vget_low_s8(iscales.val[is]));
+ auto iscales16_2 = vmovl_s8(vget_high_s8(iscales.val[is]));
+ float32x4x4_t fscales;
+ fscales.val[0] = vmulq_f32(m4, vcvtq_f32_s32(vmovl_s16(vget_low_s16(iscales16_1))));
+ fscales.val[1] = vmulq_f32(m4, vcvtq_f32_s32(vmovl_s16(vget_high_s16(iscales16_1))));
+ fscales.val[2] = vmulq_f32(m4, vcvtq_f32_s32(vmovl_s16(vget_low_s16(iscales16_2))));
+ fscales.val[3] = vmulq_f32(m4, vcvtq_f32_s32(vmovl_s16(vget_high_s16(iscales16_2))));
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto m8 = vld1q_f32((const float *)q8.y[iy][ibl].bsums + 4*is);
+ acc[iy] = vmlaq_laneq_f32(acc[iy], fscales.val[0], m8, 0);
+ acc[iy] = vmlaq_laneq_f32(acc[iy], fscales.val[1], m8, 1);
+ acc[iy] = vmlaq_laneq_f32(acc[iy], fscales.val[2], m8, 2);
+ acc[iy] = vmlaq_laneq_f32(acc[iy], fscales.val[3], m8, 3);
+ }
+ }
+ iscales.val[0] = vorrq_u8(vandq_u8(sl.val[0], mf), vandq_u8(vshlq_n_u8(sh, 4), m3));
+ iscales.val[1] = vorrq_u8(vandq_u8(sl.val[1], mf), vandq_u8(sh, m3));
+ int32x4_t isum[nrc_y] = {};
+ for (int is = 0; is < 2; ++is) {
+ auto iscales16_1 = vmovl_s8(vget_low_s8(iscales.val[is]));
+ auto iscales16_2 = vmovl_s8(vget_high_s8(iscales.val[is]));
+ scales.val[0] = vmovl_s16(vget_low_s16(iscales16_1));
+ scales.val[1] = vmovl_s16(vget_high_s16(iscales16_1));
+ scales.val[2] = vmovl_s16(vget_low_s16(iscales16_2));
+ scales.val[3] = vmovl_s16(vget_high_s16(iscales16_2));
+ for (int ib = 0; ib < 4; ++ib) {
+ auto bits = vld1q_u8_x4(iq4[ibl].qs + 256*is + 64*ib);
+ prepare_q4_k_quants(mf, bits, qx);
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto y = vld1q_s8_x2(q8.y[iy][ibl].qs+128*is+32*ib);
+ auto sumi = interleaved_dotq(qx, y);
+ isum[iy] = vmlaq_s32(isum[iy], scales.val[ib], sumi);
+ }
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ acc[iy] = vfmaq_f32(acc[iy], vmulq_f32(d4, vdupq_n_f32(q8.scale(iy, ibl))), vcvtq_f32_s32(isum[iy]));
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ info.store(ix, iy, acc[iy]);
+ acc[iy] = vdupq_n_f32(0.f);
+ }
+ }
+}
+
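+// Q5_K_R4 x Q8_K GEMM (NEON): same structure as the Q4_K_R4 kernel above, with the 5th bit
+// merged in from qh before the dot products.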
+template <int nrc_y>
+void mul_mat_q5_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ GGML_ASSERT(nrc_x%4 == 0);
+ Q8<nrc_y, block_q8_K> q8(info);
+ auto mf = vdupq_n_u8(0xf);
+ auto m30 = vdupq_n_u8(0x30);
+ auto m10 = vdupq_n_u8(0x10);
+ int nbl = n / QK_K;
+ int8x16_t qx[8];
+ int8x16x2_t iscales;
+ int32x4x4_t scales;
+ float32x4_t acc[nrc_y] = {};
+ for (int ix = 0; ix < nrc_x; ix += 4) {
+ const block_q5_k_r4 * iq5 = (const block_q5_k_r4 *)((const char *)vx + ix*bx);
+ for (int ibl = 0; ibl < nbl; ++ibl) {
+ auto d4 = vcvt_f32_f16(vld1_f16((const float16_t *)iq5[ibl].d));
+ auto m4 = vcvt_f32_f16(vld1_f16((const float16_t *)iq5[ibl].d+4));
+ m4 = vmulq_f32(m4, vdupq_n_f32(-1.f));
+ auto sl = vld1q_u8_x2(iq5[ibl].scales_l);
+ auto sh = vld1q_u8(iq5[ibl].scales_h);
+ iscales.val[0] = vorrq_u8(vshrq_n_u8(sl.val[0], 4), vandq_u8(vshlq_n_u8(sh, 2), m30));
+ iscales.val[1] = vorrq_u8(vshrq_n_u8(sl.val[1], 4), vandq_u8(vshrq_n_u8(sh, 2), m30));
+ for (int is = 0; is < 2; ++is) {
+ auto iscales16_1 = vmovl_s8(vget_low_s8(iscales.val[is]));
+ auto iscales16_2 = vmovl_s8(vget_high_s8(iscales.val[is]));
+ float32x4x4_t fscales;
+ fscales.val[0] = vmulq_f32(m4, vcvtq_f32_s32(vmovl_s16(vget_low_s16(iscales16_1))));
+ fscales.val[1] = vmulq_f32(m4, vcvtq_f32_s32(vmovl_s16(vget_high_s16(iscales16_1))));
+ fscales.val[2] = vmulq_f32(m4, vcvtq_f32_s32(vmovl_s16(vget_low_s16(iscales16_2))));
+ fscales.val[3] = vmulq_f32(m4, vcvtq_f32_s32(vmovl_s16(vget_high_s16(iscales16_2))));
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto m8 = vld1q_f32((const float *)q8.y[iy][ibl].bsums + 4*is);
+ acc[iy] = vmlaq_laneq_f32(acc[iy], fscales.val[0], m8, 0);
+ acc[iy] = vmlaq_laneq_f32(acc[iy], fscales.val[1], m8, 1);
+ acc[iy] = vmlaq_laneq_f32(acc[iy], fscales.val[2], m8, 2);
+ acc[iy] = vmlaq_laneq_f32(acc[iy], fscales.val[3], m8, 3);
+ }
+ }
+ iscales.val[0] = vorrq_u8(vandq_u8(sl.val[0], mf), vandq_u8(vshlq_n_u8(sh, 4), m30));
+ iscales.val[1] = vorrq_u8(vandq_u8(sl.val[1], mf), vandq_u8(sh, m30));
+ int32x4_t isum[nrc_y] = {};
+ for (int is = 0; is < 2; ++is) {
+ auto iscales16_1 = vmovl_s8(vget_low_s8(iscales.val[is]));
+ auto iscales16_2 = vmovl_s8(vget_high_s8(iscales.val[is]));
+ scales.val[0] = vmovl_s16(vget_low_s16(iscales16_1));
+ scales.val[1] = vmovl_s16(vget_high_s16(iscales16_1));
+ scales.val[2] = vmovl_s16(vget_low_s16(iscales16_2));
+ scales.val[3] = vmovl_s16(vget_high_s16(iscales16_2));
+ for (int ib = 0; ib < 4; ++ib) {
+ auto lbits = vld1q_u8_x4(iq5[ibl].qs + 256*is + 64*ib);
+ auto hbits2 = vld1q_u8(iq5[ibl].qh + 64*is + 16*ib);
+ auto hbits1 = vshlq_n_u8(hbits2, 4);
+ prepare_q4_k_quants(mf, lbits, qx);
+ qx[0] = vorrq_u8(qx[0], vandq_u8(m10, hbits1));
+ qx[1] = vorrq_u8(qx[1], vandq_u8(m10, hbits2));
+ qx[2] = vorrq_u8(qx[2], vandq_u8(m10, vshrq_n_u8(hbits1, 2)));
+ qx[3] = vorrq_u8(qx[3], vandq_u8(m10, vshrq_n_u8(hbits2, 2)));
+ qx[4] = vorrq_u8(qx[4], vandq_u8(m10, vshrq_n_u8(hbits1, 1)));
+ qx[5] = vorrq_u8(qx[5], vandq_u8(m10, vshrq_n_u8(hbits2, 1)));
+ qx[6] = vorrq_u8(qx[6], vandq_u8(m10, vshrq_n_u8(hbits1, 3)));
+ qx[7] = vorrq_u8(qx[7], vandq_u8(m10, vshrq_n_u8(hbits2, 3)));
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto y = vld1q_s8_x2(q8.y[iy][ibl].qs+128*is+32*ib);
+ auto sumi = interleaved_dotq(qx, y);
+ isum[iy] = vmlaq_s32(isum[iy], scales.val[ib], sumi);
+ }
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ acc[iy] = vfmaq_f32(acc[iy], vmulq_f32(d4, vdupq_n_f32(q8.scale(iy, ibl))), vcvtq_f32_s32(isum[iy]));
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ info.store(ix, iy, acc[iy]);
+ acc[iy] = vdupq_n_f32(0.f);
+ }
+ }
+}
+
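+// Q6_K_R4 x Q8_K GEMM (NEON): low 4 bits from ql, high 2 bits from qh, shifted to the signed
+// range by adding -32 before the dot products.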
+template <int nrc_y>
+void mul_mat_q6_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ GGML_ASSERT(nrc_x%4 == 0);
+ Q8<nrc_y, block_q8_K> q8(info);
+ auto mf = vdupq_n_u8(0x0f);
+ auto m3 = vdupq_n_u8(0x30);
+ auto m32 = vdupq_n_s8(-32);
+ int nbl = n / QK_K;
+ int8x16_t qx[4];
+ float32x4_t acc[nrc_y] = {};
+ for (int ix = 0; ix < nrc_x; ix += 4) {
+ const block_q6_k_r4 * iq6 = (const block_q6_k_r4 *)((const char *)vx + ix*bx);
+ for (int ibl = 0; ibl < nbl; ++ibl) {
+ auto d4 = vcvt_f32_f16(vld1_f16((const float16_t *)iq6[ibl].d));
+ int32x4_t isum[nrc_y] = {};
+ for (int is = 0; is < 2; ++is) {
+ for (int ib = 0; ib < 4; ++ib) {
+ auto lbits = vld1q_u8_x4(iq6[ibl].ql + 256*is + 64*ib);
+ auto hbits = vld1q_u8(iq6[ibl].qh + 128*is + 32*ib);
+ auto iscales = vmovl_s8(vld1_s8(iq6[ibl].scales + 32*is + 8*ib));
+ auto scales = vmovl_s16(vget_low_s16(iscales));
+ qx[0] = vaddq_s8(m32, vorrq_u8(vandq_u8 (lbits.val[0], mf), vandq_u8(m3, vshlq_n_u8(hbits, 4))));
+ qx[1] = vaddq_s8(m32, vorrq_u8(vandq_u8 (lbits.val[2], mf), vandq_u8(m3, hbits)));
+ qx[2] = vaddq_s8(m32, vorrq_u8(vshrq_n_u8(lbits.val[0], 4), vandq_u8(m3, vshlq_n_u8(hbits, 2))));
+ qx[3] = vaddq_s8(m32, vorrq_u8(vshrq_n_u8(lbits.val[2], 4), vandq_u8(m3, vshrq_n_u8(hbits, 2))));
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto y = vld1q_s8(q8.y[iy][ibl].qs+128*is+32*ib);
+ auto sumi = interleaved_dotq(qx, y);
+ isum[iy] = vmlaq_s32(isum[iy], scales, sumi);
+ }
+ scales = vmovl_s16(vget_high_s16(iscales));
+ hbits = vld1q_u8(iq6[ibl].qh + 128*is + 32*ib + 16);
+ qx[0] = vaddq_s8(m32, vorrq_u8(vandq_u8 (lbits.val[1], mf), vandq_u8(m3, vshlq_n_u8(hbits, 4))));
+ qx[1] = vaddq_s8(m32, vorrq_u8(vandq_u8 (lbits.val[3], mf), vandq_u8(m3, hbits)));
+ qx[2] = vaddq_s8(m32, vorrq_u8(vshrq_n_u8(lbits.val[1], 4), vandq_u8(m3, vshlq_n_u8(hbits, 2))));
+ qx[3] = vaddq_s8(m32, vorrq_u8(vshrq_n_u8(lbits.val[3], 4), vandq_u8(m3, vshrq_n_u8(hbits, 2))));
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto y = vld1q_s8(q8.y[iy][ibl].qs+128*is+32*ib+16);
+ auto sumi = interleaved_dotq(qx, y);
+ isum[iy] = vmlaq_s32(isum[iy], scales, sumi);
+ }
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ acc[iy] = vfmaq_f32(acc[iy], vmulq_f32(d4, vdupq_n_f32(q8.scale(iy, ibl))), vcvtq_f32_s32(isum[iy]));
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ info.store(ix, iy, acc[iy]);
+ acc[iy] = vdupq_n_f32(0.f);
+ }
+ }
+}
+
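+// Q8_K_R8 x Q8_K GEMM (NEON): 8 interleaved rows; the trailing d*(-128*bsum) term compensates
+// for the offset applied to the quants when the rows were repacked.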
+template <int nrc_y>
+void mul_mat_q8_k_r8_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ GGML_ASSERT(nrc_x%8 == 0);
+ Q8<nrc_y, block_q8_K> q8(info);
+ int nbl = n / QK_K;
+ float32x4_t acc[2*nrc_y] = {};
+ for (int ix = 0; ix < nrc_x; ix += 8) {
+ const block_q8_k_r8 * iq8 = (const block_q8_k_r8 *)((const char *)vx + ix*bx);
+ for (int ibl = 0; ibl < nbl; ++ibl) {
+ auto d4l = vcvt_f32_f16(vld1_f16((const float16_t *)iq8[ibl].d+0));
+ auto d4h = vcvt_f32_f16(vld1_f16((const float16_t *)iq8[ibl].d+4));
+ int32x4_t isum[2*nrc_y] = {};
+ for (int ib = 0; ib < QK_K/16; ++ib) {
+ auto q1 = vld1q_s8_x4(iq8[ibl].qs + 128*ib + 0);
+ auto q2 = vld1q_s8_x4(iq8[ibl].qs + 128*ib + 64);
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto y = vld1q_s8(q8.y[iy][ibl].qs+16*ib);
+ isum[2*iy+0] = vdotq_laneq_s32(isum[2*iy+0], q1.val[0], y, 0);
+ isum[2*iy+1] = vdotq_laneq_s32(isum[2*iy+1], q1.val[1], y, 0);
+ isum[2*iy+0] = vdotq_laneq_s32(isum[2*iy+0], q1.val[2], y, 1);
+ isum[2*iy+1] = vdotq_laneq_s32(isum[2*iy+1], q1.val[3], y, 1);
+ isum[2*iy+0] = vdotq_laneq_s32(isum[2*iy+0], q2.val[0], y, 2);
+ isum[2*iy+1] = vdotq_laneq_s32(isum[2*iy+1], q2.val[1], y, 2);
+ isum[2*iy+0] = vdotq_laneq_s32(isum[2*iy+0], q2.val[2], y, 3);
+ isum[2*iy+1] = vdotq_laneq_s32(isum[2*iy+1], q2.val[3], y, 3);
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto d8 = vdupq_n_f32(q8.scale(iy, ibl));
+ const float * bsum = (const float *)q8.y[iy][ibl].bsums;
+ auto m8 = vdupq_n_f32(-128.f*bsum[0]);
+ acc[2*iy+0] = vfmaq_f32(acc[2*iy+0], vmulq_f32(d4l, d8), vcvtq_f32_s32(isum[2*iy+0]));
+ acc[2*iy+1] = vfmaq_f32(acc[2*iy+1], vmulq_f32(d4h, d8), vcvtq_f32_s32(isum[2*iy+1]));
+ acc[2*iy+0] = vfmaq_f32(acc[2*iy+0], d4l, m8);
+                acc[2*iy+1] = vfmaq_f32(acc[2*iy+1], d4h, m8);
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ info.store(ix+0, iy, acc[2*iy+0]);
+ info.store(ix+4, iy, acc[2*iy+1]);
+ acc[2*iy+0] = acc[2*iy+1] = vdupq_n_f32(0.f);
+ }
+ }
+}
+
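+// IQ4_XS_R8 x Q8_K GEMM (NEON): 8 interleaved rows; nibbles are mapped through the non-linear
+// iq4 lookup table, and the lower/upper 4 rows are accumulated into acc[2*iy+0]/acc[2*iy+1].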
+template <int nrc_y>
+void mul_mat_iq4_xs_r8_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ GGML_ASSERT(nrc_x%4 == 0);
+ Q8<nrc_y, block_q8_K> q8(info);
+ auto m4 = vdupq_n_u8(0xf);
+ auto m3 = vdupq_n_u8(0x30);
+ auto m32 = vdupq_n_s8(-32);
+ auto values = vld1q_s8(iq4k_values);
+ int nbl = n / QK_K;
+ int8x16_t qx[8];
+ int8x16x4_t iscales;
+ int32x4x2_t scales;
+ float32x4_t acc[2*nrc_y] = {};
+ for (int ix = 0; ix < nrc_x; ix += 8) {
+ const block_iq4_xs_r8 * iq4 = (const block_iq4_xs_r8 *)((const char *)vx + ix*bx);
+ for (int ibl = 0; ibl < nbl; ++ibl) {
+ auto d4_f16 = vld1q_f16((const float16_t *)iq4[ibl].d);
+ auto d4l = vcvt_f32_f16(vget_low_f16 (d4_f16));
+ auto d4h = vcvt_f32_f16(vget_high_f16(d4_f16));
+ auto sl = vld1q_u8_x2(iq4[ibl].scales_l);
+ auto sh = vld1q_u8(iq4[ibl].scales_h);
+ iscales.val[0] = vaddq_s8(vorrq_u8(vandq_u8(sl.val[0], m4), vandq_u8(vshlq_n_u8(sh, 4), m3)), m32);
+ iscales.val[1] = vaddq_s8(vorrq_u8(vandq_u8(sl.val[1], m4), vandq_u8(vshlq_n_u8(sh, 2), m3)), m32);
+ iscales.val[2] = vaddq_s8(vorrq_u8(vshrq_n_u8(sl.val[0], 4), vandq_u8(sh, m3)), m32);
+ iscales.val[3] = vaddq_s8(vorrq_u8(vshrq_n_u8(sl.val[1], 4), vandq_u8(vshrq_n_u8(sh, 2), m3)), m32);
+ int32x4_t isum[nrc_y] = {};
+ for (int ib64 = 0; ib64 < QK_K/64; ++ib64) {
+ auto iscales16_1 = vmovl_s8(vget_low_s8(iscales.val[ib64]));
+ auto iscales16_2 = vmovl_s8(vget_high_s8(iscales.val[ib64]));
+ scales.val[0] = vmovl_s16(vget_low_s16(iscales16_1));
+ scales.val[1] = vmovl_s16(vget_low_s16(iscales16_2));
+ for (int l = 0; l < 2; ++l) {
+ uint8x16x2_t bits;
+ bits.val[0] = vld1q_u8(iq4[ibl].qs + 256*ib64 + 128*l);
+ bits.val[1] = vld1q_u8(iq4[ibl].qs + 256*ib64 + 128*l + 32);
+ prepare_iq4_nl_quants_r8(values, m4, bits, qx+0);
+ bits.val[0] = vld1q_u8(iq4[ibl].qs + 256*ib64 + 128*l + 64);
+ bits.val[1] = vld1q_u8(iq4[ibl].qs + 256*ib64 + 128*l + 96);
+ prepare_iq4_nl_quants_r8(values, m4, bits, qx+4);
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto y = vld1q_s8_x2(q8.y[iy][ibl].qs+64*ib64+32*l);
+ auto sumi = vdupq_n_s32(0);
+ sumi = vdotq_laneq_s32(sumi, qx[0], y.val[0], 0);
+ sumi = vdotq_laneq_s32(sumi, qx[1], y.val[0], 1);
+ sumi = vdotq_laneq_s32(sumi, qx[2], y.val[0], 2);
+ sumi = vdotq_laneq_s32(sumi, qx[3], y.val[0], 3);
+ sumi = vdotq_laneq_s32(sumi, qx[4], y.val[1], 0);
+ sumi = vdotq_laneq_s32(sumi, qx[5], y.val[1], 1);
+ sumi = vdotq_laneq_s32(sumi, qx[6], y.val[1], 2);
+ sumi = vdotq_laneq_s32(sumi, qx[7], y.val[1], 3);
+ isum[iy] = vmlaq_s32(isum[iy], sumi, scales.val[l]);
+ }
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto d8 = vdupq_n_f32(q8.scale(iy, ibl));
+ acc[2*iy+0] = vfmaq_f32(acc[2*iy+0], vmulq_f32(d4l, d8), vcvtq_f32_s32(isum[iy]));
+ isum[iy] = vdupq_n_s32(0);
+ }
+ for (int ib64 = 0; ib64 < QK_K/64; ++ib64) {
+ auto iscales16_1 = vmovl_s8(vget_low_s8(iscales.val[ib64]));
+ auto iscales16_2 = vmovl_s8(vget_high_s8(iscales.val[ib64]));
+ scales.val[0] = vmovl_s16(vget_high_s16(iscales16_1));
+ scales.val[1] = vmovl_s16(vget_high_s16(iscales16_2));
+ for (int l = 0; l < 2; ++l) {
+ uint8x16x2_t bits;
+ bits.val[0] = vld1q_u8(iq4[ibl].qs + 256*ib64 + 128*l + 16);
+ bits.val[1] = vld1q_u8(iq4[ibl].qs + 256*ib64 + 128*l + 48);
+ prepare_iq4_nl_quants_r8(values, m4, bits, qx+0);
+ bits.val[0] = vld1q_u8(iq4[ibl].qs + 256*ib64 + 128*l + 80);
+ bits.val[1] = vld1q_u8(iq4[ibl].qs + 256*ib64 + 128*l +112);
+ prepare_iq4_nl_quants_r8(values, m4, bits, qx+4);
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto y = vld1q_s8_x2(q8.y[iy][ibl].qs+64*ib64+32*l);
+ auto sumi = vdupq_n_s32(0);
+ sumi = vdotq_laneq_s32(sumi, qx[0], y.val[0], 0);
+ sumi = vdotq_laneq_s32(sumi, qx[1], y.val[0], 1);
+ sumi = vdotq_laneq_s32(sumi, qx[2], y.val[0], 2);
+ sumi = vdotq_laneq_s32(sumi, qx[3], y.val[0], 3);
+ sumi = vdotq_laneq_s32(sumi, qx[4], y.val[1], 0);
+ sumi = vdotq_laneq_s32(sumi, qx[5], y.val[1], 1);
+ sumi = vdotq_laneq_s32(sumi, qx[6], y.val[1], 2);
+ sumi = vdotq_laneq_s32(sumi, qx[7], y.val[1], 3);
+ isum[iy] = vmlaq_s32(isum[iy], sumi, scales.val[l]);
+ }
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto d8 = vdupq_n_f32(q8.scale(iy, ibl));
+ acc[2*iy+1] = vfmaq_f32(acc[2*iy+1], vmulq_f32(d4h, d8), vcvtq_f32_s32(isum[iy]));
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ info.store(ix+0, iy, acc[2*iy+0]);
+ info.store(ix+4, iy, acc[2*iy+1]);
+ acc[2*iy+0] = acc[2*iy+1] = vdupq_n_f32(0.f);
+ }
+ }
+}
+
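+// Q8_KV x Q8_KV dot product for a single activation row: both operands store a float scale in
+// front of the int8 quants (the quants start at dptr + 2).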
+static void mul_mat_q8_KV_q8_KV_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ GGML_ASSERT(n%32 == 0);
+ int32x4_t acc[4] = {};
+ auto dptr = (const float *)info.src1_row(0);
+ const float dy = dptr[0];
+ auto q8y = (const int8_t *)(dptr + 2);
+ for (int ix = 0; ix < nrc_x; ++ix) {
+ auto dx = (const float *)((const char *)vx + ix*bx);
+ auto q8x = (const int8_t *)(dx + 2);
+ for (int i = 0; i < n/64; ++i) {
+ auto qx = vld1q_s8_x4(q8x + 64*i);
+ for (int j = 0; j < 4; ++j) {
+ acc[j] = ggml_vdotq_s32(acc[j], qx.val[j], vld1q_s8(q8y + 64*i + 16*j));
+ }
+ }
+ if (int i = 2*(n/64); i < n/32) {
+ auto qx = vld1q_s8_x2(q8x + 32*i);
+ for (int j = 0; j < 2; ++j) {
+ acc[j] = ggml_vdotq_s32(acc[j], qx.val[j], vld1q_s8(q8y + 32*i + 16*j));
+ }
+ }
+ acc[0] = vaddq_s32(acc[0], acc[1]);
+ acc[2] = vaddq_s32(acc[2], acc[3]);
+ acc[0] = vaddq_s32(acc[0], acc[2]);
+ info.store(ix, 0, dx[0]*dy*vaddvq_s32(acc[0]));
+ acc[0] = acc[1] = acc[2] = acc[3] = vdupq_n_s32(0);
+ }
+}
+
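+// Q8_KV x Q8_KV GEMM (NEON): 4 rows at a time; the 4 row vectors are transposed in 4-byte chunks
+// so that vdotq_laneq_s32 multiplies all 4 rows by the same 4 bytes of y.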
+template <int nrc_y>
+static void mul_mat_q8_KV_q8_KV(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ GGML_ASSERT(nrc_x%4 == 0);
+ GGML_ASSERT(n%16 == 0);
+ int8x16_t qx[4];
+ int32x4_t acc[nrc_y] = {};
+ float dy[nrc_y];
+ const int8_t * q8y[nrc_y];
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto dptr = (const float *)info.src1_row(iy);
+ dy[iy] = dptr[0];
+ q8y[iy] = (const int8_t *)(dptr + 2);
+ }
+ const int8_t * q8x[4];
+ float dx[4];
+ for (int ix = 0; ix < nrc_x; ix += 4) {
+ for (int kx = 0; kx < 4; ++kx) {
+ auto dptr = (const float *)((const char *)vx + (ix+kx)*bx);
+ dx[kx] = dptr[0];
+ q8x[kx] = (const int8_t *)(dptr + 2);
+ }
+ for (int i = 0; i < n/16; ++i) {
+ for (int kx = 0; kx < 4; ++kx) qx[kx] = vld1q_s8(q8x[kx] + 16*i);
+ auto row01 = vtrnq_s32(qx[0], qx[1]);
+ auto row23 = vtrnq_s32(qx[2], qx[3]);
+ qx[0] = vtrn1q_s64(row01.val[0], row23.val[0]);
+ qx[1] = vtrn1q_s64(row01.val[1], row23.val[1]);
+ qx[2] = vtrn2q_s64(row01.val[0], row23.val[0]);
+ qx[3] = vtrn2q_s64(row01.val[1], row23.val[1]);
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto y = vld1q_s8(q8y[iy] + 16*i);
+ acc[iy] = vdotq_laneq_s32(acc[iy], qx[0], y, 0);
+ acc[iy] = vdotq_laneq_s32(acc[iy], qx[1], y, 1);
+ acc[iy] = vdotq_laneq_s32(acc[iy], qx[2], y, 2);
+ acc[iy] = vdotq_laneq_s32(acc[iy], qx[3], y, 3);
+ }
+ }
+ auto scales_x = vld1q_f32(dx);
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto scale = vmulq_f32(scales_x, vdupq_n_f32(dy[iy]));
+ info.store(ix, iy, vmulq_f32(scale, vcvtq_f32_s32(acc[iy])));
+ acc[iy] = vdupq_n_s32(0);
+ }
+ }
+}
+
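+// Q8_KV_R8 x Q8_KV GEMM (NEON): 8 interleaved rows per block, with the 8 float row scales stored
+// in front of the interleaved quants.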
+template <int nrc_y>
+void mul_mat_q8_KV_r8_q8_KV(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ GGML_ASSERT(nrc_x%8 == 0);
+ int32x4_t acc[2*nrc_y] = {};
+ float dy[nrc_y];
+ const int8_t * q8y[nrc_y];
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto dptr = (const float *)info.src1_row(iy);
+ dy[iy] = dptr[0];
+ q8y[iy] = (const int8_t *)(dptr + 2);
+ }
+ for (int ix = 0; ix < nrc_x; ix += 8) {
+ const float * dptr = (const float *)((const char *)vx + ix*bx);
+ auto q8x = (const int8_t *)(dptr + 8);
+ for (int ib = 0; ib < n/16; ++ib) {
+ auto q1 = vld1q_s8_x4(q8x + 128*ib + 0);
+ auto q2 = vld1q_s8_x4(q8x + 128*ib + 64);
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto y = vld1q_s8(q8y[iy]+16*ib);
+ acc[2*iy+0] = vdotq_laneq_s32(acc[2*iy+0], q1.val[0], y, 0);
+ acc[2*iy+1] = vdotq_laneq_s32(acc[2*iy+1], q1.val[1], y, 0);
+ acc[2*iy+0] = vdotq_laneq_s32(acc[2*iy+0], q1.val[2], y, 1);
+ acc[2*iy+1] = vdotq_laneq_s32(acc[2*iy+1], q1.val[3], y, 1);
+ acc[2*iy+0] = vdotq_laneq_s32(acc[2*iy+0], q2.val[0], y, 2);
+ acc[2*iy+1] = vdotq_laneq_s32(acc[2*iy+1], q2.val[1], y, 2);
+ acc[2*iy+0] = vdotq_laneq_s32(acc[2*iy+0], q2.val[2], y, 3);
+ acc[2*iy+1] = vdotq_laneq_s32(acc[2*iy+1], q2.val[3], y, 3);
+ }
+ }
+ auto scale1_x = vld1q_f32(dptr+0);
+ auto scale2_x = vld1q_f32(dptr+4);
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto scale_y = vdupq_n_f32(dy[iy]);
+ auto scale1 = vmulq_f32(scale1_x, scale_y);
+ auto scale2 = vmulq_f32(scale2_x, scale_y);
+ info.store(ix+0, iy, vmulq_f32(scale1, vcvtq_f32_s32(acc[2*iy+0])));
+ info.store(ix+4, iy, vmulq_f32(scale2, vcvtq_f32_s32(acc[2*iy+1])));
+            acc[2*iy+0] = acc[2*iy+1] = vdupq_n_s32(0);
+ }
+ }
+}
+
+}
+
+bool iqk_set_kernels_kquants(int ne00, int typeA, int typeB, std::array<mul_mat_t, IQK_MAX_NY>& kernels, [[maybe_unused]] mul_mat_t& func16) {
+
+ auto etypeA = ggml_type(typeA);
+ auto expected_type_B = etypeA == GGML_TYPE_IQ4_XS_R8 || etypeA == GGML_TYPE_Q4_K_R4 || etypeA == GGML_TYPE_Q5_K_R4 ? GGML_TYPE_Q8_K32
+ : etypeA == GGML_TYPE_Q8_K_R8 ? GGML_TYPE_Q8_KR8
+ : etypeA == GGML_TYPE_Q8_KV || etypeA == GGML_TYPE_Q8_KV_R8 ? GGML_TYPE_Q8_KV
+ : GGML_TYPE_Q8_K;
+
+ if (ne00%QK_K != 0 || ggml_type(typeB) != expected_type_B) {
+ return false;
+ }
+
+ func16 = nullptr;
+
+ switch (typeA) {
+ case GGML_TYPE_Q2_K:
+ IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_K_q8_K_T, DequantizerQ2K, kernels)
+ break;
+ case GGML_TYPE_Q3_K:
+ IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_K_q8_K_T, DequantizerQ3K, kernels)
+ break;
+ case GGML_TYPE_Q4_K:
+ IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_K_q8_K_T, DequantizerQ4K, kernels)
+ break;
+ case GGML_TYPE_Q5_K:
+ IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_K_q8_K_T, DequantizerQ5K, kernels)
+ break;
+ case GGML_TYPE_Q6_K:
+ IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_K_q8_K_T, DequantizerQ6K, kernels)
+ break;
+ case GGML_TYPE_IQ4_XS:
+ IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_K_q8_K_T, DequantizerIQ4XS, kernels)
+ break;
+ case GGML_TYPE_Q2_K_R4:
+ IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_q2_k_r4_q8_k, kernels)
+ break;
+ case GGML_TYPE_Q3_K_R4:
+ IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_q3_k_r4_q8_k, kernels)
+ break;
+ case GGML_TYPE_Q4_K_R4:
+ IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_q4_k_r4_q8_k, kernels)
+ break;
+ case GGML_TYPE_Q5_K_R4:
+ IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_q5_k_r4_q8_k, kernels)
+ break;
+ case GGML_TYPE_Q6_K_R4:
+ IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_q6_k_r4_q8_k, kernels)
+ break;
+ case GGML_TYPE_IQ4_XS_R8:
+ IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq4_xs_r8_q8_k, kernels)
+ break;
+ case GGML_TYPE_Q8_K_R8:
+ IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_q8_k_r8_q8_k, kernels)
+ break;
+ case GGML_TYPE_Q8_KV:
+ IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_q8_KV_q8_KV, kernels)
+ kernels[0] = mul_mat_q8_KV_q8_KV_1;
+ func16 = mul_mat_q8_KV_q8_KV<16>;
+ break;
+ case GGML_TYPE_Q8_KV_R8:
+ IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_q8_KV_r8_q8_KV, kernels);
+ break;
+ default:
+ return false;
+ }
+
+ return true;
+
+}
+
+#endif
+
+namespace {
+
+#ifdef __AVX2__
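+// Q8_KV x Q8_KV kernels (AVX2/AVX512). With VNNI the x quants are shifted into unsigned range
+// (+127) for _mm256_dpbusd_epi32 and the bias is removed again via the row sum stored next to
+// the row scale; without VNNI the usual sign/maddubs trick is used.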
+template <int nrc_y>
+static void mul_mat_q8_KV_q8_KV_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ GGML_ASSERT(n%32 == 0);
+ if (nrc_y == 1 && nrc_x == 1) {
+ auto dx = (const float *)vx;
+ auto dy = (const float *)info.src1_row(0);
+#ifdef HAVE_FANCY_SIMD
+ auto sy = (const int32_t *)(dy + 1);
+ auto x = (const int8_t *)(dx + 2);
+ auto y = (const int8_t *)(dy + 2);
+ auto isum = _mm512_setzero_si512();
+ for (int i = 0; i < n/64; ++i) {
+ auto qx = _mm512_loadu_si512((const __m512i *)x + i);
+ auto qy = _mm512_loadu_si512((const __m512i *)y + i);
+ isum = _mm512_dpbusd_epi32(isum, _mm512_add_epi8(qx, _mm512_set1_epi8(127)), qy);
+ }
+ auto isum256 = _mm256_add_epi32(_mm512_castsi512_si256(isum), _mm512_extracti32x8_epi32(isum, 1));
+ for (int i = 2*(n/64); i < n/32; ++i) {
+ auto qx = _mm256_loadu_si256((const __m256i *)x + i);
+ auto qy = _mm256_loadu_si256((const __m256i *)y + i);
+ isum256 = _mm256_dpbusd_epi32(isum256, _mm256_add_epi8(qx, _mm256_set1_epi8(127)), qy);
+ }
+ info.store(0, 0, dx[0]*dy[0]*(hsum_i32_8(isum256) - 127*sy[0]));
+#else
+ auto x = (const int8_t *)(dx + 2);
+ auto y = (const int8_t *)(dy + 2);
+ auto isum = _mm256_setzero_si256();
+ for (int i = 0; i < n/32; ++i) {
+ auto qx = _mm256_loadu_si256((const __m256i *)x + i);
+ auto qy = _mm256_loadu_si256((const __m256i *)y + i);
+ auto dot = _mm256_maddubs_epi16(_mm256_sign_epi8(qx, qx), _mm256_sign_epi8(qy, qx));
+ isum = _mm256_add_epi32(isum, _mm256_madd_epi16(_mm256_set1_epi16(1), dot));
+ }
+ info.store(0, 0, dx[0]*dy[0]*hsum_i32_8(isum));
+#endif
+ return;
+ }
+ __m256i qx[2];
+ __m256i acc[2*nrc_y] = {};
+ float dy[nrc_y];
+#ifdef HAVE_FANCY_SIMD
+ int32_t sy[nrc_y];
+#else
+ __m256i sx[2];
+ auto m1 = _mm256_set1_epi16(1);
+#endif
+ const int8_t * q8y[nrc_y];
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto dptr = (const float *)info.src1_row(iy);
+ dy[iy] = dptr[0];
+#ifdef HAVE_FANCY_SIMD
+ auto iptr = (const int32_t *)(dptr+1);
+ sy[iy] = -127*iptr[0];
+#endif
+ q8y[iy] = (const int8_t *)(dptr + 2);
+ }
+ for (int ix = 0; ix < nrc_x; ++ix) {
+ auto dx = (const float *)((const char *)vx + ix*bx);
+ auto q8x = (const int8_t *)(dx + 2);
+ for (int i = 0; i < n/64; ++i) {
+ for (int j = 0; j < 2; ++j) {
+#ifdef HAVE_FANCY_SIMD
+ qx[j] = _mm256_add_epi8(_mm256_loadu_si256((const __m256i *)q8x + 2*i + j), _mm256_set1_epi8(127));
+#else
+ qx[j] = _mm256_loadu_si256((const __m256i *)q8x + 2*i + j);
+ sx[j] = _mm256_sign_epi8(qx[j], qx[j]);
+#endif
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ for (int j = 0; j < 2; ++j) {
+#ifdef HAVE_FANCY_SIMD
+ acc[2*iy+j] = _mm256_dpbusd_epi32(acc[2*iy+j], qx[j], _mm256_loadu_si256((const __m256i *)q8y[iy] + 2*i + j));
+#else
+ auto dot = _mm256_maddubs_epi16(sx[j], _mm256_sign_epi8(_mm256_loadu_si256((const __m256i *)q8y[iy] + 2*i + j), qx[j]));
+ acc[2*iy+j] = _mm256_add_epi32(acc[2*iy+j], _mm256_madd_epi16(m1, dot));
+#endif
+ }
+ }
+ }
+ if (int i = 2*(n/64); i < n/32) {
+#ifdef HAVE_FANCY_SIMD
+ qx[0] = _mm256_add_epi8(_mm256_loadu_si256((const __m256i *)q8x + i), _mm256_set1_epi8(127));
+#else
+ qx[0] = _mm256_loadu_si256((const __m256i *)q8x + i);
+ sx[0] = _mm256_sign_epi8(qx[0], qx[0]);
+#endif
+ for (int iy = 0; iy < nrc_y; ++iy) {
+#ifdef HAVE_FANCY_SIMD
+ acc[2*iy] = _mm256_dpbusd_epi32(acc[2*iy], qx[0], _mm256_loadu_si256((const __m256i *)q8y[iy] + i));
+#else
+ auto dot = _mm256_maddubs_epi16(sx[0], _mm256_sign_epi8(_mm256_loadu_si256((const __m256i *)q8y[iy] + i), qx[0]));
+ acc[2*iy] = _mm256_add_epi32(acc[2*iy], _mm256_madd_epi16(m1, dot));
+#endif
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto sumi = hsum_i32_8(_mm256_add_epi32(acc[2*iy], acc[2*iy+1]));
+#ifdef HAVE_FANCY_SIMD
+ info.store(ix, iy, dx[0]*dy[iy]*(sumi+sy[iy]));
+#else
+ info.store(ix, iy, dx[0]*dy[iy]*sumi);
+#endif
+ acc[2*iy] = acc[2*iy+1] = _mm256_setzero_si256();
+ }
+ }
+}
+
+#ifdef HAVE_FANCY_SIMD
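+// AVX512 path processing 8 x-rows at once: the rows are transposed in 4-byte chunks via
+// unpacklo/unpackhi and offset by +128 (xor with 0x80) for dpbusd. Each output ends up split
+// across two 128-bit lanes that are summed at the end, which is why the bias correction uses
+// -64 per lane instead of -128.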
+template <int nrc_y>
+static void mul_mat_q8_KV_q8_KV_8(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ GGML_ASSERT(nrc_x%8 == 0);
+ GGML_ASSERT(n%32 == 0);
+ __m512i qx[4];
+ __m512i acc[nrc_y <= 4 ? 2*nrc_y : nrc_y] = {};
+ float dy[nrc_y];
+ int32_t sy[nrc_y];
+ const int8_t * q8y[nrc_y];
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto dptr = (const float *)info.src1_row(iy);
+ dy[iy] = dptr[0];
+ auto iptr = (const int32_t *)(dptr + 1);
+ sy[iy] = -64*iptr[0];
+ q8y[iy] = (const int8_t *)(dptr + 2);
+ }
+ const int8_t * q8x[8];
+ float dx[8];
+ for (int ix = 0; ix < nrc_x; ix += 8) {
+ for (int kx = 0; kx < 8; ++kx) {
+ auto dptr = (const float *)((const char *)vx + (ix+kx)*bx);
+ dx[kx] = dptr[0];
+ q8x[kx] = (const int8_t *)(dptr + 2);
+ }
+ for (int i = 0; i < n/32; ++i) {
+ for (int kx = 0; kx < 4; ++kx) {
+ qx[kx] = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)q8x[kx+0] + i)),
+ _mm256_loadu_si256((const __m256i *)q8x[kx+4] + i), 1);
+ }
+ auto t0 = _mm512_unpacklo_epi32(qx[0], qx[1]);
+ auto t1 = _mm512_unpacklo_epi32(qx[2], qx[3]);
+ auto t2 = _mm512_unpackhi_epi32(qx[0], qx[1]);
+ auto t3 = _mm512_unpackhi_epi32(qx[2], qx[3]);
+ qx[0] = _mm512_xor_si512(_mm512_unpacklo_epi64(t0, t1), _mm512_set1_epi8(-128));
+ qx[1] = _mm512_xor_si512(_mm512_unpackhi_epi64(t0, t1), _mm512_set1_epi8(-128));
+ qx[2] = _mm512_xor_si512(_mm512_unpacklo_epi64(t2, t3), _mm512_set1_epi8(-128));
+ qx[3] = _mm512_xor_si512(_mm512_unpackhi_epi64(t2, t3), _mm512_set1_epi8(-128));
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto y256 = _mm256_loadu_si256((const __m256i *)q8y[iy] + i);
+ auto y = _mm512_inserti32x8(_mm512_castsi256_si512(y256), y256, 1);
+ if constexpr (nrc_y <= 4) {
+ acc[2*iy+0] = _mm512_dpbusd_epi32(acc[2*iy+0], qx[0], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00)));
+ acc[2*iy+1] = _mm512_dpbusd_epi32(acc[2*iy+1], qx[1], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55)));
+ acc[2*iy+0] = _mm512_dpbusd_epi32(acc[2*iy+0], qx[2], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa)));
+ acc[2*iy+1] = _mm512_dpbusd_epi32(acc[2*iy+1], qx[3], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff)));
+ } else {
+ acc[iy] = _mm512_dpbusd_epi32(acc[iy], qx[0], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00)));
+ acc[iy] = _mm512_dpbusd_epi32(acc[iy], qx[1], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55)));
+ acc[iy] = _mm512_dpbusd_epi32(acc[iy], qx[2], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa)));
+ acc[iy] = _mm512_dpbusd_epi32(acc[iy], qx[3], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff)));
+ }
+ }
+ }
+ auto scales_x = _mm256_loadu_ps(dx);
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ if constexpr (nrc_y <= 4) {
+ auto ss = _mm512_add_epi32(_mm512_add_epi32(acc[2*iy+0], acc[2*iy+1]), _mm512_set1_epi32(sy[iy]));
+ auto sum1 = _mm_add_epi32(_mm512_extracti32x4_epi32(ss, 0), _mm512_extracti32x4_epi32(ss, 1));
+ auto sum2 = _mm_add_epi32(_mm512_extracti32x4_epi32(ss, 2), _mm512_extracti32x4_epi32(ss, 3));
+ auto scale = _mm256_mul_ps(scales_x, _mm256_set1_ps(dy[iy]));
+ info.store(ix+0, iy, _mm_mul_ps(_mm256_castps256_ps128(scale), _mm_cvtepi32_ps(sum1)));
+ info.store(ix+4, iy, _mm_mul_ps(_mm256_extractf128_ps(scale, 1), _mm_cvtepi32_ps(sum2)));
+ acc[2*iy+0] = acc[2*iy+1] = _mm512_setzero_si512();
+ } else {
+ acc[iy] = _mm512_add_epi32(acc[iy], _mm512_set1_epi32(sy[iy]));
+ auto sum1 = _mm_add_epi32(_mm512_extracti32x4_epi32(acc[iy], 0), _mm512_extracti32x4_epi32(acc[iy], 1));
+ auto sum2 = _mm_add_epi32(_mm512_extracti32x4_epi32(acc[iy], 2), _mm512_extracti32x4_epi32(acc[iy], 3));
+ auto scale = _mm256_mul_ps(scales_x, _mm256_set1_ps(dy[iy]));
+ info.store(ix+0, iy, _mm_mul_ps(_mm256_castps256_ps128(scale), _mm_cvtepi32_ps(sum1)));
+ info.store(ix+4, iy, _mm_mul_ps(_mm256_extractf128_ps(scale, 1), _mm_cvtepi32_ps(sum2)));
+ acc[iy] = _mm512_setzero_si512();
+ }
+ }
+ }
+}
+#endif
+#endif
+
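+// Picks the Q8_KV GEMM kernel used by the flash attention code, together with the number of
+// rows it consumes per call, based on the K-cache type, the number of rows nq and (on AVX512)
+// the head size D and k_step.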
+template <int k_step>
+inline std::pair<mul_mat_t, int> mul_mat_kernel([[maybe_unused]] int D, int int_typeA, int nq) {
+ auto typeA = ggml_type(int_typeA);
+ constexpr int kMaxQ = 8;
+#define MAKE_FUNCS(mul_mat, n) \
+ if (n >= kMaxQ) return std::make_pair(mul_mat, kMaxQ>, kMaxQ);\
+ else {\
+ switch (n) {\
+ case 1: return std::make_pair(mul_mat, 1>, 1);\
+ case 2: return std::make_pair(mul_mat, 2>, 2);\
+ case 3: return std::make_pair(mul_mat, 3>, 3);\
+ case 4: return std::make_pair(mul_mat, 4>, 4);\
+ case 5: return std::make_pair(mul_mat, 5>, 5);\
+ case 6: return std::make_pair(mul_mat, 6>, 6);\
+ case 7: return std::make_pair(mul_mat, 7>, 7);\
+ }\
+ }
+#define MAKE_FUNCS_ONLY_NRC(mul_mat, n) \
+ if (n >= kMaxQ) return std::make_pair(mul_mat<kMaxQ>, kMaxQ);\
+ else {\
+ switch (n) {\
+ case 1: return std::make_pair(mul_mat<1>, 1);\
+ case 2: return std::make_pair(mul_mat<2>, 2);\
+ case 3: return std::make_pair(mul_mat<3>, 3);\
+ case 4: return std::make_pair(mul_mat<4>, 4);\
+ case 5: return std::make_pair(mul_mat<5>, 5);\
+ case 6: return std::make_pair(mul_mat<6>, 6);\
+ case 7: return std::make_pair(mul_mat<7>, 7);\
+ }\
+ }
+ if (typeA == GGML_TYPE_Q8_KV) {
+#ifdef __aarch64__
+ if (nq%16 == 0) return std::make_pair(mul_mat_q8_KV_q8_KV<16>, 16);
+ if (nq == 1) return std::make_pair(mul_mat_q8_KV_q8_KV_1, 1);
+ MAKE_FUNCS_ONLY_NRC(mul_mat_q8_KV_q8_KV, nq);
+#else
+ if (nq == 1) return std::make_pair(mul_mat_q8_KV_q8_KV_1<1>, 1);
+#ifdef HAVE_FANCY_SIMD
+ if (D%32 == 0 && k_step%8 == 0) {
+ if (nq%16 == 0) return std::make_pair(mul_mat_q8_KV_q8_KV_8<16>, 16);
+ MAKE_FUNCS_ONLY_NRC(mul_mat_q8_KV_q8_KV_8, nq);
+ } else {
+ if (nq%16 == 0) return std::make_pair(mul_mat_q8_KV_q8_KV<16>, 16);
+ }
+#endif
+ MAKE_FUNCS_ONLY_NRC(mul_mat_q8_KV_q8_KV, nq);
+#endif
+ }
+ else if (typeA == GGML_TYPE_Q8_KV_R8) {
+ MAKE_FUNCS_ONLY_NRC(mul_mat_q8_KV_r8_q8_KV, nq);
+ }
+ GGML_ABORT("Fatal error");
+}
+
+inline std::pair<mul_mat_t, int> mul_mat_kernel(int D, int int_typeA, int nq, int k_step) {
+ switch (k_step) {
+ case 32: return mul_mat_kernel< 32>(D, int_typeA, nq);
+ case 64: return mul_mat_kernel< 64>(D, int_typeA, nq);
+ case 128: return mul_mat_kernel<128>(D, int_typeA, nq);
+ default: GGML_ABORT("Fatal error");
+ }
+}
+
+}
+
+void iqk_gemm_q8kv_fa(int D, int nq, int type_k, const char * k, size_t stride_k, DataInfo& info, int k_step) {
+ auto [mul_mat, nrc_q] = mul_mat_kernel(D, type_k, nq, k_step);
+ for (int iq = 0; iq < nq/nrc_q; ++iq) {
+ mul_mat(D, k, stride_k, info, k_step);
+ info.cur_y += nrc_q;
+ }
+ int iq = nrc_q*(nq/nrc_q);
+ if (iq < nq) {
+ auto [mul_mat1, nrc_q1] = mul_mat_kernel(D, type_k, nq - iq, k_step);
+ GGML_ASSERT(nrc_q1 == nq - iq);
+ mul_mat1(D, k, stride_k, info, k_step);
+ }
+}
+
+#endif
diff --git a/ggml/src/iqk/iqk_gemm_kquants.h b/ggml/src/iqk/iqk_gemm_kquants.h
new file mode 100644
index 00000000..071d2e50
--- /dev/null
+++ b/ggml/src/iqk/iqk_gemm_kquants.h
@@ -0,0 +1,13 @@
+#pragma once
+
+#include "iqk_common.h"
+
+#ifdef IQK_IMPLEMENT
+
+#include <array>
+
+bool iqk_set_kernels_kquants(int ne00, int typeA, int typeB, std::array<mul_mat_t, IQK_MAX_NY>& kernels, mul_mat_t& func16);
+
+void iqk_gemm_q8kv_fa(int D, int nq, int type_k, const char * k, size_t stride_k, DataInfo& info, int k_step);
+
+#endif
diff --git a/ggml/src/iqk/iqk_gemm_legacy_quants.cpp b/ggml/src/iqk/iqk_gemm_legacy_quants.cpp
new file mode 100644
index 00000000..6e262aab
--- /dev/null
+++ b/ggml/src/iqk/iqk_gemm_legacy_quants.cpp
@@ -0,0 +1,2763 @@
+#include "iqk_gemm_legacy_quants.h"
+
+#ifdef IQK_IMPLEMENT
+
+#include "ggml-impl.h"
+
+#define GGML_COMMON_IMPL_C
+#include "ggml-common.h"
+
+//
+// ============================== Legacy quants
+//
+
+#ifdef __x86_64__
+
+namespace {
+
+struct DotHelper {
+ const __m256i m1 = _mm256_set1_epi16(1);
+#if defined(__AVX512VNNI__) && defined(__AVX512VL__)
+ inline __m256i dot(__m256i x, __m256i y) const {
+ return _mm256_dpbusd_epi32(_mm256_setzero_si256(), x, y);
+ }
+#else
+ inline __m256i dot(__m256i x, __m256i y) const {
+ return _mm256_madd_epi16(m1, _mm256_maddubs_epi16(x, y));
+ }
+#endif
+};
+
+struct SignedDot {
+ DotHelper helper;
+ inline __m256i compute(__m256i x, __m256i y) const {
+ return helper.dot(_mm256_sign_epi8(x, x), _mm256_sign_epi8(y, x));
+ }
+};
+struct UnsignedDot {
+ DotHelper helper;
+ inline __m256i compute(__m256i x, __m256i y) const {
+ return helper.dot(x, y);
+ }
+};
+
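+// Reduces the dot products of 4 consecutive x/y block pairs into a single __m256i ordered
+// {0,1,2,3, 0,1,2,3} (see the lane comments below).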
+template <typename Q8, typename Q8x4, typename Dot, bool can_pack = true> struct Sum4 {
+ Dot dot;
+ inline __m256i compute(const __m256i * qx, const Q8 * y) const {
+ const Q8x4 * y4 = (const Q8x4 *)y;
+ const __m256i p0 = dot.compute(qx[0], _mm256_loadu_si256((const __m256i *)y4->qs+0)); // 8x block 0
+ const __m256i p1 = dot.compute(qx[1], _mm256_loadu_si256((const __m256i *)y4->qs+1)); // 8x block 1
+ const __m256i p2 = dot.compute(qx[2], _mm256_loadu_si256((const __m256i *)y4->qs+2)); // 8x block 2
+ const __m256i p3 = dot.compute(qx[3], _mm256_loadu_si256((const __m256i *)y4->qs+3)); // 8x block 3
+ if constexpr (can_pack) {
+ const __m256i p01 = _mm256_madd_epi16(dot.helper.m1, _mm256_packs_epi32(p0, p1)); // 0,0, 1,1, 0,0, 1,1
+ const __m256i p23 = _mm256_madd_epi16(dot.helper.m1, _mm256_packs_epi32(p2, p3)); // 2,2, 3,3, 2,2, 3,3
+ return _mm256_madd_epi16(dot.helper.m1, _mm256_packs_epi32(p01, p23)); // 0,1,2,3, 0,1,2,3
+ } else {
+ // Note to myself: this is much faster than using _mm256_hadd_epi32()
+ auto p01 = _mm256_add_epi32(_mm256_unpacklo_epi32(p0, p1), _mm256_unpackhi_epi32(p0, p1)); // 0,1, 0,1, 0,1, 0,1
+ auto p23 = _mm256_add_epi32(_mm256_unpacklo_epi32(p2, p3), _mm256_unpackhi_epi32(p2, p3)); // 2,3, 2,3, 2,3, 2,3
+ return _mm256_add_epi32(_mm256_unpacklo_epi64(p01, p23), _mm256_unpackhi_epi64(p01, p23)); // 0,1,2,3, 0,1,2,3
+ }
+ }
+ inline __m256i compute(__m256i x, __m256i y) const { return dot.compute(x, y); }
+};
+
+template <typename Q8, typename Q8x4> struct Sum4q4 {
+ inline __m256i compute(const __m256i * qx, const Q8 * y) const {
+ const Q8x4 * y4 = (const Q8x4 *)y;
+ auto p0 = _mm256_maddubs_epi16(qx[0], _mm256_loadu_si256((const __m256i *)y4->qs+0)); // 16x block 0
+ auto p1 = _mm256_maddubs_epi16(qx[1], _mm256_loadu_si256((const __m256i *)y4->qs+1)); // 16x block 1
+ auto p2 = _mm256_maddubs_epi16(qx[2], _mm256_loadu_si256((const __m256i *)y4->qs+2)); // 16x block 2
+ auto p3 = _mm256_maddubs_epi16(qx[3], _mm256_loadu_si256((const __m256i *)y4->qs+3)); // 16x block 3
+ auto p01 = _mm256_add_epi16(_mm256_unpacklo_epi32(p0, p1), _mm256_unpackhi_epi32(p0, p1)); // 0,0, 1,1, 0,0, 1,1, 0,0, 1,1, 0,0, 1,1
+ auto p23 = _mm256_add_epi16(_mm256_unpacklo_epi32(p2, p3), _mm256_unpackhi_epi32(p2, p3)); // 2,2, 3,3, 2,2, 3,3, 2,2, 3,3, 2,2, 3,3
+ auto p0123 = _mm256_add_epi16(_mm256_unpacklo_epi64(p01, p23), _mm256_unpackhi_epi64(p01, p23)); // 0,0, 1,1, 2,2, 3,3, 0,0, 1,1, 2,2, 3,3
+ return _mm256_madd_epi16(_mm256_set1_epi16(1), p0123);
+ }
+ inline __m256i compute(__m256i x, __m256i y) const { return _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_maddubs_epi16(x, y)); }
+};
+
+struct ScaleHelperQ8_0 {
+ inline __m128 prepare4(const block_q8_0 * y) {
+ const block_q8_0_x4 * y4 = (const block_q8_0_x4 *)y;
+ return _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)y4->d));
+ }
+ inline __m128 prepare4(__m128 other_scales, const block_q8_0 * y) {
+ return _mm_mul_ps(other_scales, prepare4(y));
+ }
+ template <typename Q> inline float prepare1(const Q * y) const { return GGML_FP16_TO_FP32(y->d); }
+ template <typename Q> inline float prepare1(float d, const Q * y) const { return d*prepare1(y); }
+};
+
+struct ScaleHelperQ_0 {
+ ggml_half scales8[4];
+ template <typename Q>
+ inline __m128 prepare4(const Q * y) {
+ for (int j = 0; j < 4; ++j) scales8[j] = y[j].d;
+ return _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)scales8));
+ }
+ template <typename Q>
+ inline __m128 prepare4(__m128 other_scales, const Q * y) {
+ return _mm_mul_ps(other_scales, prepare4<Q>(y));
+ }
+ template <typename Q> inline float prepare1(const Q * y) const { return GGML_FP16_TO_FP32(y->d); }
+ template <typename Q> inline float prepare1(float d, const Q * y) const { return d*prepare1(y); }
+};
+
+template <int min_value>
+struct ScaleHelperQ_0_1 {
+ ggml_half scales8[4];
+ template <typename Q>
+ inline __m256 prepare4(const Q * y) {
+ for (int j = 0; j < 4; ++j) scales8[j] = y[j].d;
+ auto s4 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)scales8));
+ return _mm256_set_m128(_mm_mul_ps(s4, min), s4);
+ }
+ template <typename Q>
+ inline __m256 prepare4(__m256 other_scales, const Q * y) {
+        return _mm256_mul_ps(other_scales, prepare4<Q>(y));
+ }
+ template <typename Q> inline std::pair<float, float> prepare1(const Q * y) const {
+ float d = GGML_FP16_TO_FP32(y->d);
+ return std::make_pair(d, -d*float(min_value));
+ }
+ std::pair<float, float> inline prepare1(const std::pair<float, float>& dm, const block_q8_1 * y) const {
+ return std::make_pair(dm.first*GGML_FP16_TO_FP32(y->d), dm.second*GGML_FP16_TO_FP32(y->s));
+ }
+ const __m128 min = _mm_set1_ps(float(-min_value));
+};
+
+//template <int min_value>
+//struct ScaleHelperQ_0_2 {
+// ggml_bf16_t scales8[4];
+// template <typename Q>
+// inline __m256 prepare4(const Q * y) {
+// for (int j = 0; j < 4; ++j) scales8[j] = y[j].d;
+// auto s4 = _mm_castsi128_ps(_mm_slli_epi16(_mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)scales8)), 16));
+// return _mm256_set_m128(_mm_mul_ps(s4, min), s4);
+// }
+// template <typename Q>
+// inline __m256 prepare4(__m256 other_scales, const Q * y) {
+// return _mm_mul256_ps(other_scales, prepare4<Q>(y));
+// }
+// template <typename Q> inline std::pair<float, float> prepare1(const Q * y) const {
+// float d = GGML_BF16_TO_FP32(y->d);
+// return std::make_pair(d, -d*float(min_value));
+// }
+// std::pair<float, float> inline prepare1(const std::pair<float, float>& dm, const block_q8_1 * y) const {
+// return std::make_pair(dm.first*GGML_FP16_TO_FP32(y->d), dm.second*GGML_FP16_TO_FP32(y->s));
+// }
+// const __m128 min = _mm_set1_ps(float(-min_value));
+//};
+
+struct ScaleHelperQ8_1 {
+ template <typename Q>
+ inline __m256 prepare4(const Q * y) {
+ const block_q8_1_x4 * y4 = (const block_q8_1_x4 *)y;
+ return _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)y4->d));
+ }
+ template <typename Q>
+ inline __m256 prepare4(__m256 other_scales, const Q * y) {
+ return _mm256_mul_ps(other_scales, prepare4<Q>(y));
+ }
+ template <typename Q> inline std::pair<float, float> prepare1(const Q * y) const {
+ return std::make_pair(GGML_FP16_TO_FP32(y->d), GGML_FP16_TO_FP32(y->m));
+ }
+ template <typename Q> inline std::pair<float, float> prepare1(const std::pair<float, float>& dm, const Q * y) const {
+ return std::make_pair(dm.first*GGML_FP16_TO_FP32(y->d), dm.second*GGML_FP16_TO_FP32(y->m));
+ }
+ std::pair<float, float> inline prepare1(const std::pair<float, float>& dm, const block_q8_1 * y) const {
+ return std::make_pair(dm.first*GGML_FP16_TO_FP32(y->d), dm.second*GGML_FP16_TO_FP32(y->s));
+ }
+};
+
+struct ScaleHelperQ8_2 {
+ template <typename Q>
+ inline __m256 prepare4(const Q * y) {
+ const block_q8_2_x4 * y4 = (const block_q8_2_x4 *)y;
+ auto aux = _mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)y4->d));
+ return _mm256_castsi256_ps(_mm256_slli_epi32(aux, 16));
+ }
+ template <typename Q>
+ inline __m256 prepare4(__m256 other_scales, const Q * y) {
+ return _mm256_mul_ps(other_scales, prepare4<Q>(y));
+ }
+ template <typename Q> inline std::pair<float, float> prepare1(const Q * y) const {
+ return std::make_pair(GGML_BF16_TO_FP32(y->d), GGML_BF16_TO_FP32(y->m));
+ }
+ template <typename Q> inline std::pair<float, float> prepare1(const std::pair<float, float>& dm, const Q * y) const {
+ ggml_bf16_t d, s; d.bits = y->d; s.bits = y->s;
+ return std::make_pair(dm.first*GGML_BF16_TO_FP32(d), dm.second*GGML_BF16_TO_FP32(s));
+ }
+ std::pair<float, float> inline prepare1(const std::pair<float, float>& dm, const block_q8_2 * y) const {
+ ggml_bf16_t d, s; d.bits = y->d; s.bits = y->s;
+ return std::make_pair(dm.first*GGML_BF16_TO_FP32(d), dm.second*GGML_BF16_TO_FP32(s));
+ }
+};
+
+struct ScaleHelperQ_1 {
+ uint32_t scales8[4];
+ const __m128i shuffle = _mm_set_epi16(0x0f0e, 0x0b0a, 0x0706, 0x0302, 0x0d0c, 0x0908, 0x0504, 0x0100);
+
+ template <typename Q>
+ inline __m256 prepare4(const Q * y) {
+ for (int j = 0; j < 4; ++j) {
+ // it is slightly faster to directly dereference (const uint32 *)&y[j].d, but some compilers
+ // complain that this breaks strict-aliasing rules.
+ memcpy(scales8 + j, &y[j].d, sizeof(uint32_t));
+ }
+ return _mm256_cvtph_ps(_mm_shuffle_epi8(_mm_loadu_si128((const __m128i *)scales8), shuffle));
+ }
+
+ template <typename Q>
+ inline __m256 prepare4(__m256 other_scales, const Q * y) {
+ return _mm256_mul_ps(other_scales, prepare4<Q>(y));
+ }
+
+ template <typename Q> inline std::pair<float, float> prepare1(const Q * y) const {
+ return std::make_pair(GGML_FP16_TO_FP32(y->d), GGML_FP16_TO_FP32(y->m));
+ }
+ template <typename Q> inline std::pair<float, float> prepare1(const std::pair<float, float>& dm, const Q * y) const {
+ return std::make_pair(dm.first*GGML_FP16_TO_FP32(y->d), dm.second*GGML_FP16_TO_FP32(y->m));
+ }
+ std::pair<float, float> inline prepare1(const std::pair<float, float>& dm, const block_q8_1 * y) const {
+ return std::make_pair(dm.first*GGML_FP16_TO_FP32(y->d), dm.second*GGML_FP16_TO_FP32(y->s));
+ }
+};
+
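+// MinusType0 handles type-0 quants (scale only); MinusType1 additionally accumulates the m*s
+// part of the combined block scales in accm[] and folds it in via result()/vresult().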
+struct MinusType0 {
+ inline __m256 compute(__m128 d, int) const { return _mm256_set_m128(d, d); }
+ inline float compute(float d, int) const { return d; }
+ inline float result(__m256 acc, int) const { return hsum_float_8(acc); }
+ inline __m256 vresult(__m256 acc, int) const { return acc; }
+};
+
+template <int nrc_y> struct MinusType1 {
+ __m128 accm[nrc_y];
+ MinusType1() { for (int iy = 0; iy < nrc_y; ++iy) accm[iy] = _mm_setzero_ps(); }
+ inline __m256 compute(__m256 dm, int iy) {
+ const __m128 d = _mm256_castps256_ps128(dm);
+ const __m128 m = _mm256_extractf128_ps(dm, 1);
+ accm[iy] = _mm_add_ps(accm[iy], m);
+ return _mm256_set_m128(d, d);
+ }
+ inline float compute(const std::pair<float, float>& dm, int iy) {
+ accm[iy] = _mm_add_ps(accm[iy], _mm_set1_ps(dm.second*0.25f));
+ return dm.first;
+ }
+ inline float result(__m256 acc, int iy) const {
+ const __m128 sum = _mm_add_ps(_mm256_castps256_ps128(acc), _mm256_extractf128_ps(acc, 1));
+ return hsum_float_4(_mm_add_ps(sum, accm[iy]));
+ }
+ inline __m256 vresult(__m256 acc, int iy) const {
+ return _mm256_add_ps(acc, _mm256_insertf128_ps(_mm256_setzero_ps(), accm[iy], 0));
+ }
+};
+
+template <typename Minus, int nrc_y, bool is_multiple_of_4> struct AccumT {
+ __m256 acc[nrc_y];
+ Minus accm;
+ AccumT() { for (int iy = 0; iy < nrc_y; ++iy) acc[iy] = _mm256_setzero_ps(); }
+ template <typename Unpacker, typename Scales, typename Sum, typename Q8>
+ inline void compute(int nb, Unpacker& unp, Scales& scales, Sum& sum, const Q8 ** y, const DataInfo& info, int ix) {
+ auto qx = unp.quants();
+ __m256 dall[nrc_y];
+ for (int i = 0; i < nb/4; ++i) {
+ auto other_scales = unp.set_block_4(i);
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto s12 = scales.prepare4(other_scales, y[iy] + 4*i);
+ dall[iy] = accm.compute(s12, iy);
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto pall = sum.compute(qx, y[iy] + 4*i);
+ acc[iy] = _mm256_fmadd_ps(dall[iy], _mm256_cvtepi32_ps(pall), acc[iy]);
+ }
+ }
+ if (!is_multiple_of_4) {
+ for (int i = 4*(nb/4); i < nb; ++i) {
+ auto other_scales = unp.set_block(i);
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto s12 = scales.prepare1(other_scales, y[iy] + i);
+ auto d = accm.compute(s12, iy);
+ const __m256i p0 = sum.compute(qx[0], _mm256_loadu_si256((const __m256i *)y[iy][i].qs));
+ acc[iy] = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(p0), acc[iy]);
+ }
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ info.store(ix, iy, accm.result(acc[iy], iy));
+ }
+ }
+ template <typename Unpacker, typename Scales, typename Sum, typename Q8>
+ inline void compute(int nb, Unpacker& unp, Scales& scales, Sum& sum, const Q8 ** y, __m256 * result) {
+ auto qx = unp.quants();
+ __m256 dall[nrc_y];
+ for (int i = 0; i < nb/4; ++i) {
+ auto other_scales = unp.set_block_4(i);
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto s12 = scales.prepare4(other_scales, y[iy] + 4*i);
+ dall[iy] = accm.compute(s12, iy);
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto pall = sum.compute(qx, y[iy] + 4*i);
+ acc[iy] = _mm256_fmadd_ps(dall[iy], _mm256_cvtepi32_ps(pall), acc[iy]);
+ }
+ }
+ if (!is_multiple_of_4) {
+ for (int i = 4*(nb/4); i < nb; ++i) {
+ auto other_scales = unp.set_block(i);
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto s12 = scales.prepare1(other_scales, y[iy] + i);
+ auto d = accm.compute(s12, iy);
+ const __m256i p0 = sum.compute(qx[0], _mm256_loadu_si256((const __m256i *)y[iy][i].qs));
+ acc[iy] = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(p0), acc[iy]);
+ }
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ result[iy] = accm.vresult(acc[iy], iy);
+ }
+ }
+};
+
+template <int nrc_y, bool is_multiple_of_4>
+using AccumType0 = AccumT<MinusType0, nrc_y, is_multiple_of_4>;
+
+template <int nrc_y, bool is_multiple_of_4>
+using AccumType1 = AccumT<MinusType1<nrc_y>, nrc_y, is_multiple_of_4>;
+
+using Sum4TypeQ80 = Sum4<block_q8_0, block_q8_0_x4, SignedDot, false>;
+using Sum4TypeQ82 = Sum4<block_q8_2, block_q8_2_x4, UnsignedDot, false>;
+
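+// Generic row loop: for each row of the left matrix the unpacker dequantizes 4 blocks at a time
+// into qx[] and AccumType accumulates them against the nrc_y activation rows.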
+template <typename Unpacker, typename AccumType, typename Scales, typename Q8, int nrc_y>
+void mul_mat_qX_q8_Helper(int nb, const void * vx, size_t bx, const DataInfo& info, const Q8 ** y, int nrc_x) {
+ Unpacker unp(vx, bx);
+ typename Unpacker::Sum4T sum4;
+ Scales scales;
+ for (int ix = 0; ix < nrc_x; ++ix) {
+ unp.set_row(ix);
+ AccumType accum;
+ accum.compute(nb, unp, scales, sum4, y, info, ix);
+ }
+}
+
+template <typename Unpacker, typename AccumType, typename Scales, typename Q8, int nrc_y>
+void mul_mat_qX_q8_Helper_x2(int nb, const void * vx, size_t bx, const DataInfo& info, const Q8 ** y, int nrc_x) {
+ GGML_ASSERT(nrc_x%2 == 0);
+ Unpacker unp(vx, bx);
+ typename Unpacker::Sum4T sum4;
+ Scales scales;
+ for (int ix = 0; ix < nrc_x; ix += 2) {
+ unp.set_row(ix);
+ AccumType accum;
+ accum.compute(nb, unp, scales, sum4, y, info, ix);
+ }
+}
+
+template <typename Unpacker, int nrc_y>
+void mul_mat_qX_0_q8_0_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ assert(n%Unpacker::block_size() == 0);
+ Q8<nrc_y, block_q8_0> q8(info);
+ int nb = n/Unpacker::block_size();
+ if (nb%4 == 0) {
+ mul_mat_qX_q8_Helper<Unpacker, AccumType0<nrc_y, true>, ScaleHelperQ8_0, block_q8_0, nrc_y>(
+ nb, vx, bx, info, q8.y, nrc_x
+ );
+ } else {
+ mul_mat_qX_q8_Helper<Unpacker, AccumType0<nrc_y, false>, ScaleHelperQ8_0, block_q8_0, nrc_y>(
+ nb, vx, bx, info, q8.y, nrc_x
+ );
+ }
+}
+
+template <typename Unpacker, int nrc_y, int nrc_x>
+void mul_mat_qX_0_q8_0_Tx(int n, const void * vx, size_t bx, const DataInfo& info, int) {
+ static_assert(8%nrc_y == 0);
+ Q8<nrc_y, block_q8_0> q8(info);
+ int nb = n/Unpacker::block_size();
+ Unpacker unp(vx, bx);
+ typename Unpacker::Sum4T sum4;
+ ScaleHelperQ8_0 scales;
+ __m256 result[8];
+ auto store = [&info, &result] (int ix0) {
+ if constexpr (nrc_y == 1) {
+ info.store(ix0, 0, hsum_float_8x8(result));
+ }
+ else if constexpr (nrc_y == 2) {
+ auto value = hsum_float_8x8(result);
+ auto value1 = _mm256_extractf128_ps(value, 1);
+ info.store(ix0, 0, _mm_shuffle_ps(_mm256_castps256_ps128(value), value1, 0x88));
+ info.store(ix0, 1, _mm_shuffle_ps(_mm256_castps256_ps128(value), value1, 0xdd));
+ }
+ else {
+ float val[8];
+ _mm256_storeu_ps(val, hsum_float_8x8(result));
+ for (int iy = 0; iy < nrc_y; ++iy) for (int ix = 0; ix < 8/nrc_y; ++ix) info.store(ix0+ix, iy, val[nrc_y*ix+iy]);
+ }
+ };
+ if (nb%4 == 0) {
+ for (int ix0 = 0; ix0 < nrc_x; ix0 += 8/nrc_y) {
+ for (int ix = 0; ix < 8/nrc_y; ++ix) {
+ unp.set_row(ix0 + ix);
+ AccumType0<nrc_y, true> accum;
+ accum.compute(nb, unp, scales, sum4, q8.y, result + nrc_y*ix);
+ }
+ store(ix0);
+ }
+ } else {
+ for (int ix0 = 0; ix0 < nrc_x; ix0 += 8/nrc_y) {
+ for (int ix = 0; ix < 8/nrc_y; ++ix) {
+ unp.set_row(ix0 + ix);
+ AccumType0<nrc_y, false> accum;
+ accum.compute(nb, unp, scales, sum4, q8.y, result + nrc_y*ix);
+ }
+ store(ix0);
+ }
+ }
+}
+
+template <typename Unpacker, int nrc_y>
+void mul_mat_qX_1_q8_1_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ assert(n%Unpacker::block_size() == 0);
+ Q8<nrc_y, block_q8_1> q8(info);
+ int nb = n/Unpacker::block_size();
+ if (nb%4 == 0) {
+ mul_mat_qX_q8_Helper<Unpacker, AccumType1<nrc_y, true>, ScaleHelperQ8_1, block_q8_1, nrc_y>(
+ nb, vx, bx, info, q8.y, nrc_x
+ );
+ } else {
+ mul_mat_qX_q8_Helper<Unpacker, AccumType1<nrc_y, false>, ScaleHelperQ8_1, block_q8_1, nrc_y>(
+ nb, vx, bx, info, q8.y, nrc_x
+ );
+ }
+}
+
+template <typename Unpacker, int nrc_y>
+void mul_mat_qX_1_q8_2_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ assert(n%Unpacker::block_size() == 0);
+ Q8<nrc_y, block_q8_2> q8(info);
+ int nb = n/Unpacker::block_size();
+ if (nb%4 == 0) {
+ mul_mat_qX_q8_Helper<Unpacker, AccumType1<nrc_y, true>, ScaleHelperQ8_2, block_q8_2, nrc_y>(
+ nb, vx, bx, info, q8.y, nrc_x
+ );
+ } else {
+ mul_mat_qX_q8_Helper<Unpacker, AccumType1<nrc_y, false>, ScaleHelperQ8_2, block_q8_2, nrc_y>(
+ nb, vx, bx, info, q8.y, nrc_x
+ );
+ }
+}
+
+template <typename Unpacker, int nrc_y, int nrc_x>
+void mul_mat_qX_0_q8_2_Tx(int n, const void * vx, size_t bx, const DataInfo& info, int) {
+ static_assert(8%nrc_y == 0);
+ Q8<nrc_y, block_q8_2> q8(info);
+ int nb = n/Unpacker::block_size();
+ Unpacker unp(vx, bx);
+ typename Unpacker::Sum4T sum4;
+ ScaleHelperQ8_2 scales;
+ __m256 result[8];
+ auto store = [&info, &result] (int ix0) {
+ if constexpr (nrc_y == 1) {
+ info.store(ix0, 0, hsum_float_8x8(result));
+ }
+ else if constexpr (nrc_y == 2) {
+ auto value = hsum_float_8x8(result);
+ auto value1 = _mm256_extractf128_ps(value, 1);
+ info.store(ix0, 0, _mm_shuffle_ps(_mm256_castps256_ps128(value), value1, 0x88));
+ info.store(ix0, 1, _mm_shuffle_ps(_mm256_castps256_ps128(value), value1, 0xdd));
+ }
+ else {
+ float val[8];
+ _mm256_storeu_ps(val, hsum_float_8x8(result));
+ for (int iy = 0; iy < nrc_y; ++iy) for (int ix = 0; ix < 8/nrc_y; ++ix) info.store(ix0+ix, iy, val[nrc_y*ix+iy]);
+ }
+ };
+ if (nb%4 == 0) {
+ for (int ix0 = 0; ix0 < nrc_x; ix0 += 8/nrc_y) {
+ for (int ix = 0; ix < 8/nrc_y; ++ix) {
+ unp.set_row(ix0 + ix);
+ AccumType1<nrc_y, true> accum;
+ accum.compute(nb, unp, scales, sum4, q8.y, result + nrc_y*ix);
+ }
+ store(ix0);
+ }
+ } else {
+ for (int ix0 = 0; ix0 < nrc_x; ix0 += 8/nrc_y) {
+ for (int ix = 0; ix < 8/nrc_y; ++ix) {
+ unp.set_row(ix0 + ix);
+ AccumType1<nrc_y, false> accum;
+ accum.compute(nb, unp, scales, sum4, q8.y, result + nrc_y*ix);
+ }
+ store(ix0);
+ }
+ }
+}
+
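+// Block dequantizers: Dequantizer4bit expands 16 bytes into 32 nibbles; the per-type dequantizers
+// below turn one block into 32 (possibly offset) 8-bit values in a __m256i.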
+struct Dequantizer4bit {
+ const __m256i m4 = _mm256_set1_epi8(0xf);
+ inline __m256i dequant(const uint8_t * qs) const {
+ const __m128i aux128 = _mm_loadu_si128((const __m128i *)qs);
+ return _mm256_and_si256(MM256_SET_M128I(_mm_srli_epi16(aux128, 4), aux128), m4);
+ }
+};
+
+struct Q8_0_Dequantizer {
+ inline __m256i dequant(const block_q8_0 * x) const {
+ return _mm256_loadu_si256((const __m256i *)x->qs);
+ }
+};
+
+struct Q8_0_1_Dequantizer {
+ inline __m256i dequant(const block_q8_0 * x) const {
+ return _mm256_add_epi8(_mm256_set1_epi8(127), _mm256_loadu_si256((const __m256i *)x->qs));
+ }
+};
+
+struct Q4_0_Dequantizer {
+ Dequantizer4bit b4;
+ const __m256i m8 = _mm256_set1_epi8(-8);
+ inline __m256i dequant(const block_q4_0 * x) const {
+ return _mm256_add_epi8(b4.dequant(x->qs), m8);
+ }
+};
+
+struct Q4_0_1_Dequantizer {
+ Dequantizer4bit b4;
+ inline __m256i dequant(const block_q4_0 * x) const {
+ return b4.dequant(x->qs);
+ }
+};
+
+struct IQ4_NL_Dequantizer {
+ Dequantizer4bit b4;
+#ifdef HAVE_FANCY_SIMD
+ const __m256i values = load_iq4nl_values_256();
+#else
+ const __m256i values = load_iq4k_values_256();
+#endif
+ inline __m256i dequant(const block_iq4_nl * x) const {
+ return _mm256_shuffle_epi8(values, b4.dequant(x->qs));
+ }
+};
+
+struct Q4_1_Dequantizer {
+ Dequantizer4bit b4;
+ inline __m256i dequant(const block_q4_1 * x) const {
+ return b4.dequant(x->qs);
+ }
+};
+
+struct HBitDequantizer {
+ const __m256i shuffle = _mm256_set_epi64x(0x0303030303030303, 0x0202020202020202, 0x0101010101010101, 0x0000000000000000);
+ const __m256i mask = _mm256_set1_epi64x(0x7fbfdfeff7fbfdfe);
+ const __m256i minus1 = _mm256_set1_epi64x(-1);
+ inline __m256i to_bytes(const uint8_t * bits) const {
+ // Note: Data in all ggml quants is at least 2-byte aligned.
+ // => we can cast to uint16_t and use or on two consecutive entries
+ // which is faster than memcpy
+ const uint16_t * aux16 = (const uint16_t *)bits;
+ const uint32_t aux32 = aux16[0] | (aux16[1] << 16);
+ //uint32_t aux32; memcpy(&aux32, bits, sizeof(uint32_t));
+ __m256i bytes = _mm256_shuffle_epi8(_mm256_set1_epi32(aux32), shuffle);
+ bytes = _mm256_or_si256(bytes, mask);
+ return _mm256_cmpeq_epi8(bytes, minus1);
+ }
+};
+
+struct Q5_0_Dequantizer {
+ Dequantizer4bit b4;
+ HBitDequantizer hbit;
+ const __m256i mh = _mm256_set1_epi8((char)0xF0);
+ inline __m256i dequant(const block_q5_0 * x) const {
+ const __m256i vqh = _mm256_andnot_si256(hbit.to_bytes(x->qh), mh);
+ return _mm256_or_si256(b4.dequant(x->qs), vqh);
+ }
+};
+
+template <typename Q5>
+struct Q5_1_Dequantizer {
+ Dequantizer4bit b4;
+ HBitDequantizer hbit;
+ const __m256i mh = _mm256_set1_epi8(0x10);
+ inline __m256i dequant(const Q5 * x) const {
+ const __m256i vqh = _mm256_and_si256(hbit.to_bytes(x->qh), mh);
+ return _mm256_or_si256(b4.dequant(x->qs), vqh);
+ }
+};
+struct Q6_0_1_Dequantizer {
+ Dequantizer4bit b4;
+ const __m256i mh = _mm256_set1_epi8(0x30);
+ const __m256i shift1 = _mm256_set_epi64x(0, 2, 0, 4);
+ const __m256i shift2 = _mm256_set_epi64x(2, 0, 0, 0);
+ inline __m256i dequant(const block_q6_0 * x) const {
+ uint64_t aux64; std::memcpy(&aux64, x->qh, 8);
+ auto h256 = _mm256_sllv_epi64(_mm256_set1_epi64x(aux64), shift1);
+ return _mm256_or_si256(b4.dequant(x->qs), _mm256_and_si256(_mm256_srlv_epi64(h256, shift2), mh));
+ }
+};
+
+template <typename Q, typename Scales, typename Dequantizer>
+struct Q_Unpacker {
+ Q_Unpacker(const void * vx, size_t bx) : cx_0((const char *)vx), x((const Q*)cx_0), bx(bx) {}
+
+ const char * cx_0;
+ const Q * x;
+ size_t bx;
+
+ Scales scales;
+ Dequantizer deq;
+
+ __m256i qx[4];
+
+ inline const __m256i* quants() const { return qx; }
+
+ inline void set_row(int ix) { x = (const Q*)(cx_0 + ix*bx); }
+
+ inline auto set_block_4(int i) {
+ for (int j = 0; j < 4; ++j) {
+ qx[j] = deq.dequant(x + 4*i + j);
+ }
+ return scales.prepare4(x + 4*i);
+ }
+ inline auto set_block(int i) {
+ qx[0] = deq.dequant(x + i);
+ return scales.prepare1(x + i);
+ }
+};
+
+struct Q8_0_Unpacker final : public Q_Unpacker<block_q8_0, ScaleHelperQ_0, Q8_0_Dequantizer> {
+ Q8_0_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {}
+ using Sum4T = Sum4TypeQ80;
+ inline static int block_size() { return QK8_0; }
+};
+struct Q8_0_1_Unpacker final : public Q_Unpacker<block_q8_0, ScaleHelperQ_0_1<127>, Q8_0_1_Dequantizer> {
+ Q8_0_1_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {}
+ using Sum4T = Sum4TypeQ82;
+ inline static int block_size() { return QK8_0; }
+};
+struct Q4_0_Unpacker final : public Q_Unpacker<block_q4_0, ScaleHelperQ_0, Q4_0_Dequantizer> {
+ Q4_0_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {}
+ using Sum4T = Sum4TypeQ80;
+ inline static int block_size() { return QK4_0; }
+};
+struct Q4_0_1_Unpacker final : public Q_Unpacker<block_q4_0, ScaleHelperQ_0_1<8>, Q4_0_1_Dequantizer> {
+ Q4_0_1_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {}
+ //using Sum4T = Sum4TypeQ82;
+ using Sum4T = Sum4q4<block_q8_2, block_q8_2_x4>;
+ inline static int block_size() { return QK4_0; }
+};
+#ifdef HAVE_FANCY_SIMD
+struct IQ4_NL_Unpacker final : public Q_Unpacker<block_iq4_nl, ScaleHelperQ_0_1<128>, IQ4_NL_Dequantizer> {
+ IQ4_NL_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {}
+ using Sum4T = Sum4TypeQ82;
+ inline static int block_size() { return QK4_NL; }
+};
+#else
+struct IQ4_NL_Unpacker final : public Q_Unpacker<block_iq4_nl, ScaleHelperQ_0, IQ4_NL_Dequantizer> {
+ IQ4_NL_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {}
+ using Sum4T = Sum4TypeQ80;
+ inline static int block_size() { return QK4_NL; }
+};
+#endif
+struct Q5_0_Unpacker final : public Q_Unpacker<block_q5_0, ScaleHelperQ_0, Q5_0_Dequantizer> {
+ Q5_0_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {}
+ using Sum4T = Sum4TypeQ80;
+ inline static int block_size() { return QK5_0; }
+};
+struct Q5_0_1_Unpacker final : public Q_Unpacker<block_q5_0, ScaleHelperQ_0_1<16>, Q5_1_Dequantizer<block_q5_0>> {
+ Q5_0_1_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {}
+ using Sum4T = Sum4TypeQ82;
+ inline static int block_size() { return QK5_0; }
+};
+struct Q4_1_Unpacker final : public Q_Unpacker<block_q4_1, ScaleHelperQ_1, Q4_1_Dequantizer> {
+ Q4_1_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {}
+ using Sum4T = Sum4TypeQ82;
+ inline static int block_size() { return QK4_1; }
+};
+struct Q5_1_Unpacker final : public Q_Unpacker<block_q5_1, ScaleHelperQ_1, Q5_1_Dequantizer<block_q5_1>> {
+ Q5_1_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {}
+ using Sum4T = Sum4TypeQ82;
+ inline static int block_size() { return QK5_1; }
+};
+struct Q6_0_1_Unpacker final : public Q_Unpacker<block_q6_0, ScaleHelperQ_0_1<32>, Q6_0_1_Dequantizer> {
+ Q6_0_1_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {}
+ using Sum4T = Sum4TypeQ82;
+ inline static int block_size() { return QK6_0; }
+};
+
+#ifdef HAVE_FANCY_SIMD
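+// IQ4_NL_R4 x Q8_2 GEMM (AVX512): two neighbouring groups of 4 interleaved rows are handled in a
+// single 512-bit register; acc[2*iy+1] collects the per-block scale*s terms that are combined
+// with a factor of -64 before the results are stored.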
+template <int nrc_y>
+static void mul_mat_iq4_nl_r4_q8_2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ GGML_ASSERT(nrc_x%8 == 0);
+ Q8<nrc_y, block_q8_2_x4> q8(info);
+ auto m4 = _mm512_set1_epi8(0xf);
+ auto values = load_iq4nl_values_512();
+ int nb = n / QK4_NL;
+ __m512 acc[2*nrc_y] = {};
+ __m512i qx[4];
+ float d8[8*nrc_y];
+ auto prepare = [&qx, &m4, &values] (const block_iq4_nl_r4& iq4l, const block_iq4_nl_r4& iq4h) {
+ auto scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq4l.d));
+ auto scales1 = _mm256_set_m128(scales128, scales128);
+ scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq4h.d));
+ auto scales2 = _mm256_set_m128(scales128, scales128);
+ auto scales = _mm512_insertf32x8(_mm512_castps256_ps512(scales1), scales2, 1);
+ auto bits1 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq4l.qs+0)),
+ _mm256_loadu_si256((const __m256i *)iq4h.qs+0), 1);
+ auto bits2 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq4l.qs+1)),
+ _mm256_loadu_si256((const __m256i *)iq4h.qs+1), 1);
+ qx[0] = _mm512_shuffle_epi8(values, _mm512_and_si512(bits1, m4));
+ qx[1] = _mm512_shuffle_epi8(values, _mm512_and_si512(bits2, m4));
+ qx[2] = _mm512_shuffle_epi8(values, _mm512_and_si512(_mm512_srli_epi16(bits1, 4), m4));
+ qx[3] = _mm512_shuffle_epi8(values, _mm512_and_si512(_mm512_srli_epi16(bits2, 4), m4));
+ return scales;
+ };
+ auto dot = [&qx] (__m256i y8) {
+ auto y = _mm512_inserti32x8(_mm512_castsi256_si512(y8), y8, 1);
+ auto sumi = _mm512_setzero_si512();
+ sumi = _mm512_dpbusd_epi32(sumi, qx[0], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00)));
+ sumi = _mm512_dpbusd_epi32(sumi, qx[1], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55)));
+ sumi = _mm512_dpbusd_epi32(sumi, qx[2], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa)));
+ sumi = _mm512_dpbusd_epi32(sumi, qx[3], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff)));
+ return sumi;
+ };
+ for (int ix = 0; ix < nrc_x; ix += 8) {
+ const block_iq4_nl_r4 * iq4l = (const block_iq4_nl_r4 *)((const char *)vx + (ix+0)*bx);
+ const block_iq4_nl_r4 * iq4h = (const block_iq4_nl_r4 *)((const char *)vx + (ix+4)*bx);
+ for (int ib4 = 0; ib4 < nb/4; ++ib4) {
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto aux = _mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)q8.y[iy][ib4].d)), 16);
+ _mm256_storeu_ps(d8+8*iy, _mm256_castsi256_ps(aux));
+ }
+ for (int k = 0; k < 4; ++k) {
+ auto scales = prepare(iq4l[4*ib4+k], iq4h[4*ib4+k]);
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto sumi = dot(_mm256_loadu_si256((const __m256i*)q8.y[iy][ib4].qs+k));
+ auto dy = _mm512_set1_ps(d8[8*iy+k]);
+ acc[2*iy+0] = _mm512_fmadd_ps(_mm512_mul_ps(scales, dy), _mm512_cvtepi32_ps(sumi), acc[2*iy+0]);
+ acc[2*iy+1] = _mm512_fmadd_ps(scales, _mm512_set1_ps(d8[8*iy+k+4]), acc[2*iy+1]);
+ }
+ }
+ }
+ for (int ib = 4*(nb/4); ib < nb; ++ib) {
+ auto scales = prepare(iq4l[ib], iq4h[ib]);
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto qy = (const block_q8_1 *)q8.y[iy];
+ auto sumi = dot(_mm256_loadu_si256((const __m256i*)qy[ib].qs));
+ ggml_bf16_t d, s; d.bits = qy[ib].d; s.bits = qy[ib].s;
+ auto dy = _mm512_set1_ps(GGML_BF16_TO_FP32(d));
+ acc[2*iy+0] = _mm512_fmadd_ps(_mm512_mul_ps(scales, dy), _mm512_cvtepi32_ps(sumi), acc[2*iy+0]);
+ acc[2*iy+1] = _mm512_fmadd_ps(scales, _mm512_set1_ps(GGML_BF16_TO_FP32(s)), acc[2*iy+1]);
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto sum512 = _mm512_fmadd_ps(_mm512_set1_ps(-64.f), acc[2*iy+1], acc[2*iy+0]);
+ acc[2*iy+0] = acc[2*iy+1] = _mm512_setzero_ps();
+ auto sum1 = _mm_add_ps(_mm512_extractf32x4_ps(sum512, 0), _mm512_extractf32x4_ps(sum512, 1));
+ auto sum2 = _mm_add_ps(_mm512_extractf32x4_ps(sum512, 2), _mm512_extractf32x4_ps(sum512, 3));
+ info.store(ix+0, iy, sum1);
+ info.store(ix+4, iy, sum2);
+ }
+ }
+}
+#else
+template <int nrc_y>
+static void mul_mat_iq4_nl_r4_q8_2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ GGML_ASSERT(nrc_x%4 == 0);
+ Q8<nrc_y, block_q8_2_x4> q8(info);
+ auto m4 = _mm256_set1_epi8(0xf);
+ auto m1 = _mm256_set1_epi16(1);
+ auto values128 = _mm_loadu_si128((const __m128i *)iq4k_values);
+ auto values = MM256_SET_M128I(values128, values128);
+ int nb = n / QK4_NL;
+ __m256 acc[nrc_y] = {};
+ __m256i qs[4];
+ float d8[4*nrc_y];
+ auto prepare = [&qs, &values, &m4] (const block_iq4_nl_r4& iq4) {
+ auto scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq4.d));
+ auto scales = _mm256_set_m128(scales128, scales128);
+ auto bits1 = _mm256_loadu_si256((const __m256i *)iq4.qs+0);
+ auto bits2 = _mm256_loadu_si256((const __m256i *)iq4.qs+1);
+ qs[0] = _mm256_shuffle_epi8(values, _mm256_and_si256(bits1, m4));
+ qs[1] = _mm256_shuffle_epi8(values, _mm256_and_si256(bits2, m4));
+ qs[2] = _mm256_shuffle_epi8(values, _mm256_and_si256(_mm256_srli_epi16(bits1, 4), m4));
+ qs[3] = _mm256_shuffle_epi8(values, _mm256_and_si256(_mm256_srli_epi16(bits2, 4), m4));
+ return scales;
+ };
+ auto dot = [&qs, &m1] (__m256i y) {
+ auto u1 = _mm256_sign_epi8(qs[0], qs[0]);
+ auto u2 = _mm256_sign_epi8(qs[1], qs[1]);
+ auto sumi1 = _mm256_add_epi32(
+ _mm256_madd_epi16(m1, _mm256_maddubs_epi16(u1, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), qs[0]))),
+ _mm256_madd_epi16(m1, _mm256_maddubs_epi16(u2, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), qs[1]))));
+ u1 = _mm256_sign_epi8(qs[2], qs[2]);
+ u2 = _mm256_sign_epi8(qs[3], qs[3]);
+ auto sumi2 = _mm256_add_epi32(
+ _mm256_madd_epi16(m1, _mm256_maddubs_epi16(u1, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), qs[2]))),
+ _mm256_madd_epi16(m1, _mm256_maddubs_epi16(u2, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), qs[3]))));
+ return _mm256_add_epi32(sumi1, sumi2);
+ };
+ for (int ix = 0; ix < nrc_x; ix += 4) {
+ const block_iq4_nl_r4 * iq4 = (const block_iq4_nl_r4 *)((const char *)vx + ix*bx);
+ for (int ib4 = 0; ib4 < nb/4; ++ib4) {
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto aux = _mm_slli_epi32(_mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)q8.y[iy][ib4].d)), 16);
+ _mm_storeu_ps(d8+4*iy, _mm_castsi128_ps(aux));
+ }
+ for (int k = 0; k < 4; ++k) {
+ auto scales = prepare(iq4[4*ib4+k]);
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto sumi = dot(_mm256_loadu_si256((const __m256i*)q8.y[iy][ib4].qs+k));
+ auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(d8[4*iy+k]));
+ acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[iy]);
+ }
+ }
+ }
+ for (int ib = 4*(nb/4); ib < nb; ++ib) {
+ auto scales = prepare(iq4[ib]);
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto qy = (const block_q8_1 *)q8.y[iy];
+ auto sumi = dot(_mm256_loadu_si256((const __m256i*)qy[ib].qs));
+ ggml_bf16_t d{qy[ib].d};
+ auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(GGML_BF16_TO_FP32(d)));
+ acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[iy]);
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1));
+ info.store(ix, iy, sum);
+ acc[iy] = _mm256_setzero_ps();
+ }
+ }
+}
+#endif
+
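+// Unpack 128 bytes of packed nibbles (8 interleaved rows, 32 values each) into 8 vectors of
+// unsigned 4-bit values: low nibbles in v[0..3], high nibbles in v[4..7].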
+inline void prepare_q4_0_quants_avx2(const uint8_t * qs, __m256i * v, const __m256i& m4) {
+ auto bits1 = _mm256_loadu_si256((const __m256i *)qs+0);
+ auto bits2 = _mm256_loadu_si256((const __m256i *)qs+1);
+ auto bits3 = _mm256_loadu_si256((const __m256i *)qs+2);
+ auto bits4 = _mm256_loadu_si256((const __m256i *)qs+3);
+ v[0] = _mm256_and_si256(bits1, m4);
+ v[1] = _mm256_and_si256(bits2, m4);
+ v[2] = _mm256_and_si256(bits3, m4);
+ v[3] = _mm256_and_si256(bits4, m4);
+ v[4] = _mm256_and_si256(_mm256_srli_epi16(bits1, 4), m4);
+ v[5] = _mm256_and_si256(_mm256_srli_epi16(bits2, 4), m4);
+ v[6] = _mm256_and_si256(_mm256_srli_epi16(bits3, 4), m4);
+ v[7] = _mm256_and_si256(_mm256_srli_epi16(bits4, 4), m4);
+}
+
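+// Dot product of the 8 unpacked row vectors with one 32-byte q8 block (broadcast to both
+// 128-bit lanes): _mm256_dpbusd_epi32 when HAVE_FANCY_SIMD is defined, otherwise
+// maddubs_epi16 partial sums combined with a final madd_epi16.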
+inline __m256i accum_q4_0_quants(const __m256i * v, const int8_t * qs) {
+ auto y4l = _mm_loadu_si128((const __m128i*)qs+0);
+ auto y4h = _mm_loadu_si128((const __m128i*)qs+1);
+ auto yl = MM256_SET_M128I(y4l, y4l);
+ auto yh = MM256_SET_M128I(y4h, y4h);
+#ifdef HAVE_FANCY_SIMD
+ auto sumi = _mm256_setzero_si256();
+ sumi = _mm256_dpbusd_epi32(sumi, v[0], _mm256_shuffle_epi32(yl, 0x00));
+ sumi = _mm256_dpbusd_epi32(sumi, v[1], _mm256_shuffle_epi32(yl, 0x55));
+ sumi = _mm256_dpbusd_epi32(sumi, v[2], _mm256_shuffle_epi32(yl, 0xaa));
+ sumi = _mm256_dpbusd_epi32(sumi, v[3], _mm256_shuffle_epi32(yl, 0xff));
+ sumi = _mm256_dpbusd_epi32(sumi, v[4], _mm256_shuffle_epi32(yh, 0x00));
+ sumi = _mm256_dpbusd_epi32(sumi, v[5], _mm256_shuffle_epi32(yh, 0x55));
+ sumi = _mm256_dpbusd_epi32(sumi, v[6], _mm256_shuffle_epi32(yh, 0xaa));
+ sumi = _mm256_dpbusd_epi32(sumi, v[7], _mm256_shuffle_epi32(yh, 0xff));
+#else
+ auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(v[0], _mm256_shuffle_epi32(yl, 0x00)),
+ _mm256_maddubs_epi16(v[1], _mm256_shuffle_epi32(yl, 0x55)));
+ auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(v[2], _mm256_shuffle_epi32(yl, 0xaa)),
+ _mm256_maddubs_epi16(v[3], _mm256_shuffle_epi32(yl, 0xff)));
+ auto sumi3 = _mm256_add_epi16(_mm256_maddubs_epi16(v[4], _mm256_shuffle_epi32(yh, 0x00)),
+ _mm256_maddubs_epi16(v[5], _mm256_shuffle_epi32(yh, 0x55)));
+ auto sumi4 = _mm256_add_epi16(_mm256_maddubs_epi16(v[6], _mm256_shuffle_epi32(yh, 0xaa)),
+ _mm256_maddubs_epi16(v[7], _mm256_shuffle_epi32(yh, 0xff)));
+ auto sumi = _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_add_epi16(_mm256_add_epi16(sumi1, sumi2), _mm256_add_epi16(sumi3, sumi4)));
+#endif
+ return sumi;
+}
+
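+// AVX2 GEMM for Q4_0 rows repacked in groups of 8 (r8). The quants are kept unsigned (0..15);
+// the -8.f correction term is folded in via the q8 block sum entries ('s').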
+template <int nrc_y>
+static void mul_mat_q4_0_r8_q8_2_avx2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ GGML_ASSERT(nrc_x%8 == 0);
+ Q8<nrc_y, block_q8_1_x4> q8(info);
+ auto m4 = _mm256_set1_epi8(0xf);
+ int nb = n / QK4_NL;
+ __m256i v[8];
+ GGML_ASSERT(nb%4 == 0);
+ if constexpr (nrc_y == 1) {
+ union { __m256 vec; float val[8]; } helper;
+ for (int ix = 0; ix < nrc_x; ix += 8) {
+ const block_iq4_nl_r8 * iq4 = (const block_iq4_nl_r8 *)((const char *)vx + ix*bx);
+ auto acc1 = _mm256_setzero_ps();
+ auto acc2 = _mm256_setzero_ps();
+ for (int ib4 = 0; ib4 < nb/4; ++ib4) {
+ helper.vec = _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)q8.y[0][ib4].d)), 16));
+ for (int k = 0; k < 4; ++k) {
+ auto scales = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq4[4*ib4+k].d));
+ prepare_q4_0_quants_avx2(iq4[4*ib4+k].qs, v, m4);
+ auto sumi = accum_q4_0_quants(v, q8.y[0][ib4].qs+32*k);
+ auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(helper.val[k]));
+ acc1 = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc1);
+ acc2 = _mm256_fmadd_ps(scales, _mm256_set1_ps(helper.val[k+4]), acc2);
+ }
+ }
+ for (int ib = 4*(nb/4); ib < nb; ++ib) {
+ auto qy = (const block_q8_1 *)q8.y[0];
+ auto scales = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq4[ib].d));
+ prepare_q4_0_quants_avx2(iq4[ib].qs, v, m4);
+ auto sumi = accum_q4_0_quants(v, qy[ib].qs);
+ ggml_bf16_t d{qy[ib].d}, s{qy[ib].s};
+ auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(GGML_BF16_TO_FP32(d)));
+ acc1 = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc1);
+ acc2 = _mm256_fmadd_ps(scales, _mm256_set1_ps(GGML_BF16_TO_FP32(s)), acc2);
+ }
+ acc1 = _mm256_fmadd_ps(acc2, _mm256_set1_ps(-8.f), acc1);
+ info.store(ix, 0, acc1);
+ }
+ }
+ else {
+ __m256 acc[nrc_y] = {};
+ float d8[8*nrc_y];
+ for (int ix = 0; ix < nrc_x; ix += 8) {
+ const block_iq4_nl_r8 * iq4 = (const block_iq4_nl_r8 *)((const char *)vx + ix*bx);
+ for (int ib4 = 0; ib4 < nb/4; ++ib4) {
+ {
+ __m256 d4[4];
+ for (int k = 0; k < 4; ++k) {
+ d4[k] = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq4[4*ib4+k].d));
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto scales = _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)q8.y[iy][ib4].d)), 16));
+ _mm256_storeu_ps(d8 + 8*iy, scales);
+                        auto sy = _mm256_extractf128_ps(scales, 1); // the four 's' entries of this q8_2_x4 block
+                        auto m8 = _mm256_set_m128(sy, sy);
+ auto sumf = _mm256_mul_ps(d4[0], _mm256_shuffle_ps(m8, m8, 0x00));
+ sumf = _mm256_fmadd_ps(d4[1], _mm256_shuffle_ps(m8, m8, 0x55), sumf);
+ sumf = _mm256_fmadd_ps(d4[2], _mm256_shuffle_ps(m8, m8, 0xaa), sumf);
+ sumf = _mm256_fmadd_ps(d4[3], _mm256_shuffle_ps(m8, m8, 0xff), sumf);
+ acc[iy] = _mm256_fmadd_ps(sumf, _mm256_set1_ps(-8.f), acc[iy]);
+ }
+ }
+ for (int k = 0; k < 4; ++k) {
+ auto scales = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq4[4*ib4+k].d));
+ prepare_q4_0_quants_avx2(iq4[4*ib4+k].qs, v, m4);
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto sumi = accum_q4_0_quants(v, q8.y[iy][ib4].qs+32*k);
+ auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(d8[8*iy+k]));
+ acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[iy]);
+ }
+ }
+ }
+ for (int ib = 4*(nb/4); ib < nb; ++ib) {
+ auto scales = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq4[ib].d));
+ auto scales_m = _mm256_mul_ps(scales, _mm256_set1_ps(-8.f));
+ prepare_q4_0_quants_avx2(iq4[ib].qs, v, m4);
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto qy = (const block_q8_1 *)q8.y[iy];
+ auto sumi = accum_q4_0_quants(v, qy[ib].qs);
+ ggml_bf16_t d{qy[ib].d}, s{qy[ib].s};
+ auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(GGML_BF16_TO_FP32(d)));
+ acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[iy]);
+ acc[iy] = _mm256_fmadd_ps(scales_m, _mm256_set1_ps(GGML_BF16_TO_FP32(s)), acc[iy]);
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ info.store(ix, iy, acc[iy]);
+ acc[iy] = _mm256_setzero_ps();
+ }
+ }
+ }
+}
+
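+// With HAVE_FANCY_SIMD the Q4_0_R8 GEMM pairs two r8 blocks so that 16 rows are processed per
+// step (the single-activation-row case reuses the AVX2 kernel above); without it, the wrapper
+// simply forwards to the AVX2 kernel.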
+#ifdef HAVE_FANCY_SIMD
+template <int nrc_y>
+static void mul_mat_q4_0_r8_q8_2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ if constexpr (nrc_y == 1) {
+ mul_mat_q4_0_r8_q8_2_avx2<1>(n, vx, bx, info, nrc_x);
+ return;
+ }
+ GGML_ASSERT(nrc_x%16 == 0);
+ Q8<nrc_y, block_q8_1_x4> q8(info);
+ auto m4 = _mm512_set1_epi8(0xf);
+ int nb = n / QK4_NL;
+ __m512 acc[2*nrc_y] = {};
+ __m512i qx[8];
+ auto prepare = [&qx, &m4] (const block_iq4_nl_r8& iq4l, const block_iq4_nl_r8& iq4h) {
+ auto scales1 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq4l.d));
+ auto scales2 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq4h.d));
+ auto scales = _mm512_insertf32x8(_mm512_castps256_ps512(scales1), scales2, 1);
+ for (int j = 0; j < 4; ++j) {
+ auto bits = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq4l.qs+j)),
+ _mm256_loadu_si256((const __m256i *)iq4h.qs+j), 1);
+ qx[j+0] = _mm512_and_si512(bits, m4);
+ qx[j+4] = _mm512_and_si512(_mm512_srli_epi16(bits, 4), m4);
+ }
+ return scales;
+ };
+ auto dot = [&qx] (const int8_t * qy) {
+ auto y4l = _mm_loadu_si128((const __m128i*)qy+0);
+ auto y4h = _mm_loadu_si128((const __m128i*)qy+1);
+ auto y8l = MM256_SET_M128I(y4l, y4l);
+ auto y8h = MM256_SET_M128I(y4h, y4h);
+ auto yl = _mm512_inserti32x8(_mm512_castsi256_si512(y8l), y8l, 1);
+ auto yh = _mm512_inserti32x8(_mm512_castsi256_si512(y8h), y8h, 1);
+ auto sumi = _mm512_setzero_si512();
+ sumi = _mm512_dpbusd_epi32(sumi, qx[0], _mm512_shuffle_epi32(yl, _MM_PERM_ENUM(0x00)));
+ sumi = _mm512_dpbusd_epi32(sumi, qx[1], _mm512_shuffle_epi32(yl, _MM_PERM_ENUM(0x55)));
+ sumi = _mm512_dpbusd_epi32(sumi, qx[2], _mm512_shuffle_epi32(yl, _MM_PERM_ENUM(0xaa)));
+ sumi = _mm512_dpbusd_epi32(sumi, qx[3], _mm512_shuffle_epi32(yl, _MM_PERM_ENUM(0xff)));
+ sumi = _mm512_dpbusd_epi32(sumi, qx[4], _mm512_shuffle_epi32(yh, _MM_PERM_ENUM(0x00)));
+ sumi = _mm512_dpbusd_epi32(sumi, qx[5], _mm512_shuffle_epi32(yh, _MM_PERM_ENUM(0x55)));
+ sumi = _mm512_dpbusd_epi32(sumi, qx[6], _mm512_shuffle_epi32(yh, _MM_PERM_ENUM(0xaa)));
+ sumi = _mm512_dpbusd_epi32(sumi, qx[7], _mm512_shuffle_epi32(yh, _MM_PERM_ENUM(0xff)));
+ return sumi;
+ };
+ float d8[8*nrc_y];
+ for (int ix = 0; ix < nrc_x; ix += 16) {
+ const block_iq4_nl_r8 * iq4l = (const block_iq4_nl_r8 *)((const char *)vx + (ix+0)*bx);
+ const block_iq4_nl_r8 * iq4h = (const block_iq4_nl_r8 *)((const char *)vx + (ix+8)*bx);
+ for (int ib4 = 0; ib4 < nb/4; ++ib4) {
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto aux = _mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)q8.y[iy][ib4].d)), 16);
+ _mm256_storeu_ps(d8+8*iy, _mm256_castsi256_ps(aux));
+ }
+ for (int k = 0; k < 4; ++k) {
+ auto scales = prepare(iq4l[4*ib4+k], iq4h[4*ib4+k]);
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto sumi = dot(q8.y[iy][ib4].qs+32*k);
+ auto dy = _mm512_set1_ps(d8[8*iy+k]);
+ acc[2*iy+0] = _mm512_fmadd_ps(_mm512_mul_ps(scales, dy), _mm512_cvtepi32_ps(sumi), acc[2*iy+0]);
+ acc[2*iy+1] = _mm512_fmadd_ps(scales, _mm512_set1_ps(d8[8*iy+k+4]), acc[2*iy+1]);
+ }
+ }
+ }
+ for (int ib = 4*(nb/4); ib < nb; ++ib) {
+ auto scales = prepare(iq4l[ib], iq4h[ib]);
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto qy = (const block_q8_1 *)q8.y[iy];
+ auto sumi = dot(qy[ib].qs);
+ ggml_bf16_t d{qy[ib].d}, s{qy[ib].s};
+ auto dy = _mm512_set1_ps(GGML_BF16_TO_FP32(d));
+ acc[2*iy+0] = _mm512_fmadd_ps(_mm512_mul_ps(scales, dy), _mm512_cvtepi32_ps(sumi), acc[2*iy+0]);
+ acc[2*iy+1] = _mm512_fmadd_ps(scales, _mm512_set1_ps(GGML_BF16_TO_FP32(s)), acc[2*iy+1]);
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto sum = _mm512_fmadd_ps(_mm512_set1_ps(-8.f), acc[2*iy+1], acc[2*iy+0]);
+ acc[2*iy+0] = acc[2*iy+1] = _mm512_setzero_ps();
+ info.store(ix, iy, sum);
+ }
+ }
+}
+#else
+template <int nrc_y>
+static void mul_mat_q4_0_r8_q8_2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ mul_mat_q4_0_r8_q8_2_avx2<nrc_y>(n, vx, bx, info, nrc_x);
+}
+#endif
+
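+// AVX2 GEMM for Q5_0 rows repacked in groups of 4 (r4): the low nibbles from qs are combined
+// with the high bit from qh into unsigned 5-bit values; the constant offset correction is
+// folded in via the q8 block sum entries scaled by mscale.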
+template <int nrc_y>
+static void mul_mat_q5_0_r4_q8_2_avx2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ GGML_ASSERT(nrc_x%4 == 0);
+ Q8<nrc_y, block_q8_2_x4> q8(info);
+ auto m4 = _mm256_set1_epi8(0xf);
+ auto m5 = _mm256_set1_epi8(0x10);
+#ifndef HAVE_FANCY_SIMD
+ auto m1 = _mm256_set1_epi16(1);
+#endif
+ auto mscale = _mm256_set_m128(_mm_set1_ps(-8.f), _mm_set1_ps(1.f));
+ int nb = n / QK5_0;
+ __m256 acc[nrc_y] = {};
+ __m256i qx[4];
+ float d8[8*nrc_y];
+ auto prepare = [&qx, &m4, &m5] (const block_q5_0_r4& iq5) {
+ auto scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq5.d));
+ auto scales = _mm256_set_m128(scales128, scales128);
+ auto bits1 = _mm256_loadu_si256((const __m256i *)iq5.qs+0);
+ auto bits2 = _mm256_loadu_si256((const __m256i *)iq5.qs+1);
+ auto hbits = _mm_loadu_si128((const __m128i *)iq5.qh);
+ auto hb = MM256_SET_M128I(_mm_srli_epi16(hbits, 1), hbits);
+ qx[0] = _mm256_or_si256(_mm256_and_si256(bits1, m4), _mm256_and_si256(_mm256_slli_epi16(hb, 4), m5));
+ qx[1] = _mm256_or_si256(_mm256_and_si256(bits2, m4), _mm256_and_si256(_mm256_slli_epi16(hb, 2), m5));
+ qx[2] = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(bits1, 4), m4), _mm256_and_si256(hb, m5));
+        qx[3] = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(bits2, 4), m4), _mm256_and_si256(_mm256_srli_epi16(hb, 2), m5));
+ return scales;
+ };
+#ifdef HAVE_FANCY_SIMD
+ auto dot = [&qx] (__m256i y) {
+ auto sumi = _mm256_setzero_si256();
+ sumi = _mm256_dpbusd_epi32(sumi, qx[0], _mm256_shuffle_epi32(y, 0x00));
+ sumi = _mm256_dpbusd_epi32(sumi, qx[1], _mm256_shuffle_epi32(y, 0x55));
+ sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(y, 0xaa));
+ sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(y, 0xff));
+ return sumi;
+ };
+#else
+ auto dot = [&qx, &m1] (__m256i y) {
+ auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[0], _mm256_shuffle_epi32(y, 0x00)),
+ _mm256_maddubs_epi16(qx[1], _mm256_shuffle_epi32(y, 0x55)));
+ auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[2], _mm256_shuffle_epi32(y, 0xaa)),
+ _mm256_maddubs_epi16(qx[3], _mm256_shuffle_epi32(y, 0xff)));
+ auto sumi = _mm256_madd_epi16(m1, _mm256_add_epi16(sumi1, sumi2));
+ return sumi;
+ };
+#endif
+ for (int ix = 0; ix < nrc_x; ix += 4) {
+ const block_q5_0_r4 * iq5 = (const block_q5_0_r4 *)((const char *)vx + ix*bx);
+ for (int ib4 = 0; ib4 < nb/4; ++ib4) {
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto scales = _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)q8.y[iy][ib4].d)), 16));
+ _mm256_storeu_ps(d8 + 8*iy, _mm256_mul_ps(mscale, scales));
+ }
+ for (int k = 0; k < 4; ++k) {
+ auto scales = prepare(iq5[4*ib4+k]);
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto sumi = dot(_mm256_loadu_si256((const __m256i*)q8.y[iy][ib4].qs+k));
+ auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(d8[8*iy+k]));
+ acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[iy]);
+ acc[iy] = _mm256_fmadd_ps(scales, _mm256_set1_ps(d8[8*iy+k+4]), acc[iy]);
+ }
+ }
+ }
+ for (int ib = 4*(nb/4); ib < nb; ++ib) {
+ auto scales = prepare(iq5[ib]);
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto qy = (const block_q8_1 *)q8.y[iy];
+ auto sumi = dot(_mm256_loadu_si256((const __m256i*)qy[ib].qs));
+ ggml_bf16_t d{qy[ib].d}, s{qy[ib].s};
+ auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(GGML_BF16_TO_FP32(d)));
+ acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[iy]);
+ acc[iy] = _mm256_fmadd_ps(scales, _mm256_set1_ps(-8.f*GGML_BF16_TO_FP32(s)), acc[iy]);
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1));
+ info.store(ix, iy, sum);
+ acc[iy] = _mm256_setzero_ps();
+ }
+ }
+}
+
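+// AVX512 variant of the Q5_0_R4 GEMM: two r4 blocks are interleaved into 512-bit registers so
+// that 8 rows are processed per step; the single-activation-row case falls back to the AVX2
+// kernel above.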
+#ifdef HAVE_FANCY_SIMD
+template <int nrc_y>
+static void mul_mat_q5_0_r4_q8_2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ if constexpr (nrc_y == 1) {
+ mul_mat_q5_0_r4_q8_2_avx2<1>(n, vx, bx, info, nrc_x);
+ } else {
+ GGML_ASSERT(nrc_x%8 == 0);
+ Q8<nrc_y, block_q8_2_x4> q8(info);
+ auto m4 = _mm512_set1_epi8(0xf);
+ auto m5 = _mm512_set1_epi8(0x10);
+ int nb = n / QK5_0;
+ __m512 acc[2*nrc_y] = {};
+ __m512i qx[4];
+ float d8[8*nrc_y];
+ auto prepare = [&qx, &m4, &m5] (const block_q5_0_r4& iq5l, const block_q5_0_r4& iq5h) {
+ auto scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq5l.d));
+ auto scales1 = _mm256_set_m128(scales128, scales128);
+ scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq5h.d));
+ auto scales2 = _mm256_set_m128(scales128, scales128);
+ auto scales = _mm512_insertf32x8(_mm512_castps256_ps512(scales1), scales2, 1);
+ auto bits1 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq5l.qs+0)),
+ _mm256_loadu_si256((const __m256i *)iq5h.qs+0), 1);
+ auto bits2 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq5l.qs+1)),
+ _mm256_loadu_si256((const __m256i *)iq5h.qs+1), 1);
+ auto hbits1 = _mm_loadu_si128((const __m128i *)iq5l.qh);
+ auto hbits2 = _mm_loadu_si128((const __m128i *)iq5h.qh);
+ auto hb1 = MM256_SET_M128I(_mm_srli_epi16(hbits1, 1), hbits1);
+ auto hb2 = MM256_SET_M128I(_mm_srli_epi16(hbits2, 1), hbits2);
+ auto hb = _mm512_inserti32x8(_mm512_castsi256_si512(hb1), hb2, 1);
+ qx[0] = _mm512_or_si512(_mm512_and_si512(bits1, m4), _mm512_and_si512(_mm512_slli_epi16(hb, 4), m5));
+ qx[1] = _mm512_or_si512(_mm512_and_si512(bits2, m4), _mm512_and_si512(_mm512_slli_epi16(hb, 2), m5));
+ qx[2] = _mm512_or_si512(_mm512_and_si512(_mm512_srli_epi16(bits1, 4), m4), _mm512_and_si512(hb, m5));
+ qx[3] = _mm512_or_si512(_mm512_and_si512(_mm512_srli_epi16(bits2, 4), m4), _mm512_and_si512(_mm512_srli_epi16(hb, 2), m5));
+ return scales;
+ };
+ auto dot = [&qx] (__m256i y8) {
+ auto y = _mm512_inserti32x8(_mm512_castsi256_si512(y8), y8, 1);
+ auto sumi = _mm512_setzero_si512();
+ sumi = _mm512_dpbusd_epi32(sumi, qx[0], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00)));
+ sumi = _mm512_dpbusd_epi32(sumi, qx[1], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55)));
+ sumi = _mm512_dpbusd_epi32(sumi, qx[2], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa)));
+ sumi = _mm512_dpbusd_epi32(sumi, qx[3], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff)));
+ return sumi;
+ };
+ for (int ix = 0; ix < nrc_x; ix += 8) {
+ const block_q5_0_r4 * iq5l = (const block_q5_0_r4 *)((const char *)vx + (ix+0)*bx);
+ const block_q5_0_r4 * iq5h = (const block_q5_0_r4 *)((const char *)vx + (ix+4)*bx);
+ for (int ib4 = 0; ib4 < nb/4; ++ib4) {
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ _mm256_storeu_ps(d8+8*iy, _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)q8.y[iy][ib4].d)), 16)));
+ }
+ for (int k = 0; k < 4; ++k) {
+ auto scales = prepare(iq5l[4*ib4+k], iq5h[4*ib4+k]);
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto sumi = dot(_mm256_loadu_si256((const __m256i*)q8.y[iy][ib4].qs+k));
+ auto dy = _mm512_set1_ps(d8[8*iy+k]);
+ acc[2*iy+0] = _mm512_fmadd_ps(_mm512_mul_ps(scales, dy), _mm512_cvtepi32_ps(sumi), acc[2*iy+0]);
+ acc[2*iy+1] = _mm512_fmadd_ps(scales, _mm512_set1_ps(d8[8*iy+k+4]), acc[2*iy+1]);
+ }
+ }
+ }
+ for (int ib = 4*(nb/4); ib < nb; ++ib) {
+ auto scales = prepare(iq5l[ib], iq5h[ib]);
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto qy = (const block_q8_1 *)q8.y[iy];
+ auto sumi = dot(_mm256_loadu_si256((const __m256i*)qy[ib].qs));
+ ggml_bf16_t d{qy[ib].d}, s{qy[ib].s};
+ auto dy = _mm512_set1_ps(GGML_BF16_TO_FP32(d));
+ acc[2*iy+0] = _mm512_fmadd_ps(_mm512_mul_ps(scales, dy), _mm512_cvtepi32_ps(sumi), acc[2*iy+0]);
+ acc[2*iy+1] = _mm512_fmadd_ps(scales, _mm512_set1_ps(GGML_BF16_TO_FP32(s)), acc[2*iy+1]);
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto sum512 = _mm512_fmadd_ps(_mm512_set1_ps(-8.f), acc[2*iy+1], acc[2*iy+0]);
+ acc[2*iy+0] = acc[2*iy+1] = _mm512_setzero_ps();
+ auto sum1 = _mm_add_ps(_mm512_extractf32x4_ps(sum512, 0), _mm512_extractf32x4_ps(sum512, 1));
+ auto sum2 = _mm_add_ps(_mm512_extractf32x4_ps(sum512, 2), _mm512_extractf32x4_ps(sum512, 3));
+ info.store(ix+0, iy, sum1);
+ info.store(ix+4, iy, sum2);
+ }
+ }
+ }
+}
+#else
+template <int nrc_y>
+static void mul_mat_q5_0_r4_q8_2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ mul_mat_q5_0_r4_q8_2_avx2<nrc_y>(n, vx, bx, info, nrc_x);
+}
+#endif
+
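+// AVX2 GEMM for Q6_0 rows repacked in groups of 4 (r4): 4 low bits from qs plus 2 high bits
+// from qh give unsigned 6-bit values; the offset correction again goes through the q8 block
+// sum entries.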
+template <int nrc_y>
+static void mul_mat_q6_0_r4_q8_2_avx2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ GGML_ASSERT(nrc_x%4 == 0);
+ Q8<nrc_y, block_q8_2_x4> q8(info);
+ auto m4 = _mm256_set1_epi8(0xf);
+ auto m6 = _mm256_set1_epi8(0x30);
+ auto mscale = _mm256_set_m128(_mm_set1_ps(-16.f), _mm_set1_ps(1.f));
+#ifndef HAVE_FANCY_SIMD
+ auto m1 = _mm256_set1_epi16(1);
+#endif
+ int nb = n / QK6_0;
+ __m256 acc[nrc_y] = {};
+ float d8[8*nrc_y];
+ __m256i qx[4];
+ auto prepare = [&qx, &m4, &m6] (const block_q6_0_r4& iq6) {
+ auto scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq6.d));
+ auto scales = _mm256_set_m128(scales128, scales128);
+ auto bits1 = _mm256_loadu_si256((const __m256i *)iq6.qs+0);
+ auto bits2 = _mm256_loadu_si256((const __m256i *)iq6.qs+1);
+ auto hbits = _mm256_loadu_si256((const __m256i *)iq6.qh);
+ qx[0] = _mm256_or_si256(_mm256_and_si256(bits1, m4), _mm256_and_si256(_mm256_slli_epi16(hbits, 4), m6));
+ qx[1] = _mm256_or_si256(_mm256_and_si256(bits2, m4), _mm256_and_si256(_mm256_slli_epi16(hbits, 2), m6));
+ qx[2] = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(bits1, 4), m4), _mm256_and_si256(hbits, m6));
+ qx[3] = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(bits2, 4), m4), _mm256_and_si256(_mm256_srli_epi16(hbits, 2), m6));
+ return scales;
+ };
+#ifdef HAVE_FANCY_SIMD
+ auto dot = [&qx] (__m256i y) {
+ auto sumi = _mm256_dpbusd_epi32(_mm256_setzero_si256(), qx[0], _mm256_shuffle_epi32(y, 0x00));
+ sumi = _mm256_dpbusd_epi32(sumi, qx[1], _mm256_shuffle_epi32(y, 0x55));
+ sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(y, 0xaa));
+ sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(y, 0xff));
+ return sumi;
+ };
+#else
+ auto dot = [&qx, &m1] (__m256i y) {
+ auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[0], _mm256_shuffle_epi32(y, 0x00)),
+ _mm256_maddubs_epi16(qx[1], _mm256_shuffle_epi32(y, 0x55)));
+ auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[2], _mm256_shuffle_epi32(y, 0xaa)),
+ _mm256_maddubs_epi16(qx[3], _mm256_shuffle_epi32(y, 0xff)));
+ auto sumi = _mm256_add_epi32(_mm256_madd_epi16(m1, sumi1), _mm256_madd_epi16(m1, sumi2));
+ return sumi;
+ };
+#endif
+ for (int ix = 0; ix < nrc_x; ix += 4) {
+ const block_q6_0_r4 * iq6 = (const block_q6_0_r4 *)((const char *)vx + ix*bx);
+ for (int ib4 = 0; ib4 < nb/4; ++ib4) {
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto scales = _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)q8.y[iy][ib4].d)), 16));
+ _mm256_storeu_ps(d8 + 8*iy, _mm256_mul_ps(scales, mscale));
+ }
+ for (int k = 0; k < 4; ++k) {
+ auto scales = prepare(iq6[4*ib4+k]);
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto sumi = dot(_mm256_loadu_si256((const __m256i*)q8.y[iy][ib4].qs+k));
+ auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(d8[8*iy+k]));
+ acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[iy]);
+ acc[iy] = _mm256_fmadd_ps(scales, _mm256_set1_ps(d8[8*iy+k+4]), acc[iy]);
+ }
+ }
+ }
+ for (int ib = 4*(nb/4); ib < nb; ++ib) {
+ auto scales = prepare(iq6[ib]);
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto qy = (const block_q8_1 *)q8.y[iy];
+ auto sumi = dot(_mm256_loadu_si256((const __m256i*)qy[ib].qs));
+ ggml_bf16_t d{qy[ib].d}, s{qy[ib].s};
+ auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(GGML_BF16_TO_FP32(d)));
+ acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[iy]);
+ acc[iy] = _mm256_fmadd_ps(scales, _mm256_set1_ps(-16.f*GGML_BF16_TO_FP32(s)), acc[iy]);
+ }
+ }
+
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1));
+ info.store(ix, iy, sum);
+ acc[iy] = _mm256_setzero_ps();
+ }
+ }
+}
+
+#ifdef HAVE_FANCY_SIMD
+template <int nrc_y>
+static void mul_mat_q6_0_r4_q8_2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ if constexpr (nrc_y == 1) {
+ mul_mat_q6_0_r4_q8_2_avx2<1>(n, vx, bx, info, nrc_x);
+ } else {
+ GGML_ASSERT(nrc_x%8 == 0);
+ Q8<nrc_y, block_q8_2_x4> q8(info);
+ auto m4 = _mm512_set1_epi8(0xf);
+ auto m6 = _mm512_set1_epi8(0x30);
+ int nb = n / QK6_0;
+ __m512 acc[2*nrc_y] = {};
+ __m512i qx[4];
+ float d8[8*nrc_y];
+ auto prepare = [&qx, &m4, &m6] (const block_q6_0_r4& iq6l, const block_q6_0_r4& iq6h) {
+ auto scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq6l.d));
+ auto scales1 = _mm256_set_m128(scales128, scales128);
+ scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq6h.d));
+ auto scales2 = _mm256_set_m128(scales128, scales128);
+ auto scales = _mm512_insertf32x8(_mm512_castps256_ps512(scales1), scales2, 1);
+ auto bits1 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq6l.qs+0)),
+ _mm256_loadu_si256((const __m256i *)iq6h.qs+0), 1);
+ auto bits2 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq6l.qs+1)),
+ _mm256_loadu_si256((const __m256i *)iq6h.qs+1), 1);
+ auto hbits1 = _mm256_loadu_si256((const __m256i *)iq6l.qh);
+ auto hbits2 = _mm256_loadu_si256((const __m256i *)iq6h.qh);
+ auto hb = _mm512_inserti32x8(_mm512_castsi256_si512(hbits1), hbits2, 1);
+            qx[0] = _mm512_or_si512(_mm512_and_si512(bits1, m4), _mm512_and_si512(_mm512_slli_epi16(hb, 4), m6));
+            qx[1] = _mm512_or_si512(_mm512_and_si512(bits2, m4), _mm512_and_si512(_mm512_slli_epi16(hb, 2), m6));
+            qx[2] = _mm512_or_si512(_mm512_and_si512(_mm512_srli_epi16(bits1, 4), m4), _mm512_and_si512(hb, m6));
+            qx[3] = _mm512_or_si512(_mm512_and_si512(_mm512_srli_epi16(bits2, 4), m4), _mm512_and_si512(_mm512_srli_epi16(hb, 2), m6));
+ return scales;
+ };
+ auto dot = [&qx] (__m256i y8) {
+ auto y = _mm512_inserti32x8(_mm512_castsi256_si512(y8), y8, 1);
+ auto sumi = _mm512_setzero_si512();
+ sumi = _mm512_dpbusd_epi32(sumi, qx[0], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00)));
+ sumi = _mm512_dpbusd_epi32(sumi, qx[1], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55)));
+ sumi = _mm512_dpbusd_epi32(sumi, qx[2], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa)));
+ sumi = _mm512_dpbusd_epi32(sumi, qx[3], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff)));
+ return sumi;
+ };
+ for (int ix = 0; ix < nrc_x; ix += 8) {
+ const block_q6_0_r4 * iq6l = (const block_q6_0_r4 *)((const char *)vx + (ix+0)*bx);
+ const block_q6_0_r4 * iq6h = (const block_q6_0_r4 *)((const char *)vx + (ix+4)*bx);
+ for (int ib4 = 0; ib4 < nb/4; ++ib4) {
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto scales = _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)q8.y[iy][ib4].d)), 16));
+ _mm256_storeu_ps(d8 + 8*iy, scales);
+ }
+ for (int k = 0; k < 4; ++k) {
+ auto scales = prepare(iq6l[4*ib4+k], iq6h[4*ib4+k]);
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto sumi = dot(_mm256_loadu_si256((const __m256i*)q8.y[iy][ib4].qs+k));
+ auto dy = _mm512_set1_ps(d8[8*iy+k]);
+ acc[2*iy+0] = _mm512_fmadd_ps(_mm512_mul_ps(scales, dy), _mm512_cvtepi32_ps(sumi), acc[2*iy+0]);
+ acc[2*iy+1] = _mm512_fmadd_ps(scales, _mm512_set1_ps(d8[8*iy+k+4]), acc[2*iy+1]);
+ }
+ }
+ }
+ for (int ib = 4*(nb/4); ib < nb; ++ib) {
+ auto scales = prepare(iq6l[ib], iq6h[ib]);
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto qy = (const block_q8_1 *)q8.y[iy];
+ auto sumi = dot(_mm256_loadu_si256((const __m256i*)qy[ib].qs));
+ ggml_bf16_t d{qy[ib].d}, s{qy[ib].s};
+ auto dy = _mm512_set1_ps(GGML_BF16_TO_FP32(d));
+ acc[2*iy+0] = _mm512_fmadd_ps(_mm512_mul_ps(scales, dy), _mm512_cvtepi32_ps(sumi), acc[2*iy+0]);
+ acc[2*iy+1] = _mm512_fmadd_ps(scales, _mm512_set1_ps(GGML_BF16_TO_FP32(s)), acc[2*iy+1]);
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto sum512 = _mm512_fmadd_ps(_mm512_set1_ps(-16.f), acc[2*iy+1], acc[2*iy+0]);
+ acc[2*iy+0] = acc[2*iy+1] = _mm512_setzero_ps();
+ auto sum1 = _mm_add_ps(_mm512_extractf32x4_ps(sum512, 0), _mm512_extractf32x4_ps(sum512, 1));
+ auto sum2 = _mm_add_ps(_mm512_extractf32x4_ps(sum512, 2), _mm512_extractf32x4_ps(sum512, 3));
+ info.store(ix+0, iy, sum1);
+ info.store(ix+4, iy, sum2);
+ }
+ }
+ }
+}
+#else
+template <int nrc_y>
+static void mul_mat_q6_0_r4_q8_2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ mul_mat_q6_0_r4_q8_2_avx2<nrc_y>(n, vx, bx, info, nrc_x);
+}
+#endif
+
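+// Q8_0_R8 GEMM. On the HAVE_FANCY_SIMD path the signed row quants are shifted by +127 so they
+// can be used as the unsigned operand of dpbusd, and the shift is undone with a -127 * block-sum
+// term; the plain AVX2 path below uses the sign/maddubs trick instead.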
+#ifdef HAVE_FANCY_SIMD
+inline __m512i qx_r8_q8_dot_product(const __m512i * qx, const int8_t * y) {
+ auto y4l = _mm_loadu_si128((const __m128i*)y+0);
+ auto y4h = _mm_loadu_si128((const __m128i*)y+1);
+ auto y8l = MM256_SET_M128I(y4l, y4l);
+ auto y8h = MM256_SET_M128I(y4h, y4h);
+ auto yl = _mm512_inserti32x8(_mm512_castsi256_si512(y8l), y8l, 1);
+ auto yh = _mm512_inserti32x8(_mm512_castsi256_si512(y8h), y8h, 1);
+ auto sumi = _mm512_setzero_si512();
+ sumi = _mm512_dpbusd_epi32(sumi, qx[0], _mm512_shuffle_epi32(yl, _MM_PERM_ENUM(0x00)));
+ sumi = _mm512_dpbusd_epi32(sumi, qx[1], _mm512_shuffle_epi32(yl, _MM_PERM_ENUM(0x55)));
+ sumi = _mm512_dpbusd_epi32(sumi, qx[2], _mm512_shuffle_epi32(yl, _MM_PERM_ENUM(0xaa)));
+ sumi = _mm512_dpbusd_epi32(sumi, qx[3], _mm512_shuffle_epi32(yl, _MM_PERM_ENUM(0xff)));
+ sumi = _mm512_dpbusd_epi32(sumi, qx[4], _mm512_shuffle_epi32(yh, _MM_PERM_ENUM(0x00)));
+ sumi = _mm512_dpbusd_epi32(sumi, qx[5], _mm512_shuffle_epi32(yh, _MM_PERM_ENUM(0x55)));
+ sumi = _mm512_dpbusd_epi32(sumi, qx[6], _mm512_shuffle_epi32(yh, _MM_PERM_ENUM(0xaa)));
+ sumi = _mm512_dpbusd_epi32(sumi, qx[7], _mm512_shuffle_epi32(yh, _MM_PERM_ENUM(0xff)));
+ return sumi;
+}
+inline __m256i qx_r8_q8_dot_product(const __m256i * qx, const int8_t * y) {
+ auto y4l = _mm_loadu_si128((const __m128i*)y+0);
+ auto y4h = _mm_loadu_si128((const __m128i*)y+1);
+ auto yl = MM256_SET_M128I(y4l, y4l);
+ auto yh = MM256_SET_M128I(y4h, y4h);
+ auto sumi = _mm256_setzero_si256();
+ sumi = _mm256_dpbusd_epi32(sumi, qx[0], _mm256_shuffle_epi32(yl, 0x00));
+ sumi = _mm256_dpbusd_epi32(sumi, qx[1], _mm256_shuffle_epi32(yl, 0x55));
+ sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(yl, 0xaa));
+ sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(yl, 0xff));
+ sumi = _mm256_dpbusd_epi32(sumi, qx[4], _mm256_shuffle_epi32(yh, 0x00));
+ sumi = _mm256_dpbusd_epi32(sumi, qx[5], _mm256_shuffle_epi32(yh, 0x55));
+ sumi = _mm256_dpbusd_epi32(sumi, qx[6], _mm256_shuffle_epi32(yh, 0xaa));
+ sumi = _mm256_dpbusd_epi32(sumi, qx[7], _mm256_shuffle_epi32(yh, 0xff));
+ return sumi;
+}
+inline __m256i q8_0_r8_dot_product(const uint8_t * x, const int8_t * y, __m256i * qx) {
+ for (int i = 0; i < 8; ++i) {
+ qx[i] = _mm256_add_epi8(_mm256_loadu_si256((const __m256i *)x+i), _mm256_set1_epi8(127));
+ }
+ return qx_r8_q8_dot_product(qx, y);
+}
+template <int nrc_y>
+static void mul_mat_q8_0_r8_q8_2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ GGML_ASSERT(nrc_x%16 == 0);
+ Q8<nrc_y, block_q8_2_x4> q8(info);
+ int nb = n / QK8_0;
+ if constexpr (nrc_y == 1) {
+ __m256 acc[2] = {};
+ __m256i qx[8];
+ float d8[8];
+ for (int ix = 0; ix < nrc_x; ix += 8) {
+ const block_q8_0_r8 * iq8 = (const block_q8_0_r8 *)((const char *)vx + ix*bx);
+ for (int ib4 = 0; ib4 < nb/4; ++ib4) {
+ auto aux = _mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)q8.y[0][ib4].d)), 16);
+ _mm256_storeu_ps(d8, _mm256_castsi256_ps(aux));
+ for (int k = 0; k < 4; ++k) {
+ auto scales = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq8[4*ib4+k].d));
+ auto sumi = q8_0_r8_dot_product((const uint8_t *)iq8[4*ib4+k].qs, q8.y[0][ib4].qs+32*k, qx);
+ auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(d8[k]));
+ acc[0] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[0]);
+ acc[1] = _mm256_fmadd_ps(scales, _mm256_set1_ps(d8[k+4]), acc[1]);
+ }
+ }
+ if (4*(nb/4) < nb) {
+ auto qy = (const block_q8_1 *)q8.y[0];
+ for (int ib = 4*(nb/4); ib < nb; ++ib) {
+ auto scales = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq8[ib].d));
+ auto sumi = q8_0_r8_dot_product((const uint8_t *)iq8[ib].qs, qy[ib].qs, qx);
+ ggml_bf16_t d, s; d.bits = qy[ib].d; s.bits = qy[ib].s;
+ auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(GGML_BF16_TO_FP32(d)));
+ acc[0] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[0]);
+ acc[1] = _mm256_fmadd_ps(scales, _mm256_set1_ps(GGML_BF16_TO_FP32(s)), acc[1]);
+ }
+ }
+ info.store(ix, 0, _mm256_fmadd_ps(_mm256_set1_ps(-127.f), acc[1], acc[0]));
+ acc[0] = acc[1] = _mm256_setzero_ps();
+ }
+ } else {
+ __m512 acc[2*nrc_y] = {};
+ __m512i qx[8];
+ float d8[8*nrc_y];
+ for (int ix = 0; ix < nrc_x; ix += 16) {
+ const block_q8_0_r8 * q8l = (const block_q8_0_r8 *)((const char *)vx + (ix+0)*bx);
+ const block_q8_0_r8 * q8h = (const block_q8_0_r8 *)((const char *)vx + (ix+8)*bx);
+ for (int ib4 = 0; ib4 < nb/4; ++ib4) {
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto aux = _mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)q8.y[iy][ib4].d)), 16);
+ _mm256_storeu_ps(d8+8*iy, _mm256_castsi256_ps(aux));
+ }
+ for (int k = 0; k < 4; ++k) {
+ auto scales1 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)q8l[4*ib4+k].d));
+ auto scales2 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)q8h[4*ib4+k].d));
+ auto scales = _mm512_insertf32x8(_mm512_castps256_ps512(scales1), scales2, 1);
+ for (int j = 0; j < 8; ++j) {
+ qx[j] = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)q8l[4*ib4+k].qs+j)),
+ _mm256_loadu_si256((const __m256i *)q8h[4*ib4+k].qs+j), 1);
+ qx[j] = _mm512_add_epi8(qx[j], _mm512_set1_epi8(127));
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto sumi = qx_r8_q8_dot_product(qx, q8.y[iy][ib4].qs+32*k);
+ auto dy = _mm512_set1_ps(d8[8*iy+k]);
+ acc[2*iy+0] = _mm512_fmadd_ps(_mm512_mul_ps(scales, dy), _mm512_cvtepi32_ps(sumi), acc[2*iy+0]);
+ acc[2*iy+1] = _mm512_fmadd_ps(scales, _mm512_set1_ps(d8[8*iy+k+4]), acc[2*iy+1]);
+ }
+ }
+ }
+ for (int ib = 4*(nb/4); ib < nb; ++ib) {
+ auto scales1 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)q8l[ib].d));
+ auto scales2 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)q8h[ib].d));
+ auto scales = _mm512_insertf32x8(_mm512_castps256_ps512(scales1), scales2, 1);
+ for (int j = 0; j < 8; ++j) {
+ qx[j] = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)q8l[ib].qs+j)),
+ _mm256_loadu_si256((const __m256i *)q8h[ib].qs+j), 1);
+ qx[j] = _mm512_add_epi8(qx[j], _mm512_set1_epi8(127));
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto qy = (const block_q8_1 *)q8.y[iy];
+ auto sumi = qx_r8_q8_dot_product(qx, qy[ib].qs);
+ ggml_bf16_t d, s; d.bits = qy[ib].d; s.bits = qy[ib].s;
+ auto dy = _mm512_set1_ps(GGML_BF16_TO_FP32(d));
+ acc[2*iy+0] = _mm512_fmadd_ps(_mm512_mul_ps(scales, dy), _mm512_cvtepi32_ps(sumi), acc[2*iy+0]);
+ acc[2*iy+1] = _mm512_fmadd_ps(scales, _mm512_set1_ps(GGML_BF16_TO_FP32(s)), acc[2*iy+1]);
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto sum512 = _mm512_fmadd_ps(_mm512_set1_ps(-127.f), acc[2*iy+1], acc[2*iy+0]);
+ info.store(ix, iy, sum512);
+ acc[2*iy+0] = acc[2*iy+1] = _mm512_setzero_ps();
+ }
+ }
+ }
+}
+#else
+template <int nrc_y>
+static void mul_mat_q8_0_r8_q8_2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ GGML_ASSERT(nrc_x%8 == 0);
+ Q8<nrc_y, block_q8_2_x4> q8(info);
+ auto m1 = _mm256_set1_epi16(1);
+ int nb = n / QK8_0;
+ __m256 acc[nrc_y] = {};
+ float d8[4*nrc_y];
+ __m256i qx[4], sx[4];
+ auto dot = [&qx, &sx, &m1] (const int8_t * qy) {
+ auto y128 = _mm_loadu_si128((const __m128i*)qy);
+ auto y = MM256_SET_M128I(y128, y128);
+ auto sumi1 = _mm256_add_epi32(
+ _mm256_madd_epi16(m1, _mm256_maddubs_epi16(sx[0], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), qx[0]))),
+ _mm256_madd_epi16(m1, _mm256_maddubs_epi16(sx[1], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), qx[1])))
+ );
+ auto sumi2 = _mm256_add_epi32(
+ _mm256_madd_epi16(m1, _mm256_maddubs_epi16(sx[2], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), qx[2]))),
+ _mm256_madd_epi16(m1, _mm256_maddubs_epi16(sx[3], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), qx[3])))
+ );
+ return _mm256_add_epi32(sumi1, sumi2);
+ };
+ for (int ix = 0; ix < nrc_x; ix += 8) {
+ const block_q8_0_r8 * iq8 = (const block_q8_0_r8 *)((const char *)vx + ix*bx);
+ for (int ib4 = 0; ib4 < nb/4; ++ib4) {
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto scales = _mm_castsi128_ps(_mm_slli_epi32(_mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)q8.y[iy][ib4].d)), 16));
+ _mm_storeu_ps(d8 + 4*iy, scales);
+ }
+ for (int k = 0; k < 4; ++k) {
+ auto scales = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq8[4*ib4+k].d));
+ for (int j = 0; j < 4; ++j) {
+ qx[j] = _mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+j);
+ sx[j] = _mm256_sign_epi8(qx[j], qx[j]);
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto sumi = dot(q8.y[iy][ib4].qs+32*k);
+ auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(d8[4*iy+k]));
+ acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[iy]);
+ }
+ for (int j = 0; j < 4; ++j) {
+ qx[j] = _mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+4+j);
+ sx[j] = _mm256_sign_epi8(qx[j], qx[j]);
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto sumi = dot(q8.y[iy][ib4].qs+32*k+16);
+ auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(d8[4*iy+k]));
+ acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[iy]);
+ }
+ }
+ }
+ for (int ib = 4*(nb/4); ib < nb; ++ib) {
+ auto scales = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq8[ib].d));
+ for (int j = 0; j < 4; ++j) {
+ qx[j] = _mm256_loadu_si256((const __m256i *)iq8[ib].qs+j);
+ sx[j] = _mm256_sign_epi8(qx[j], qx[j]);
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto qy = (const block_q8_2 *)q8.y[iy];
+ auto sumi = dot(qy[ib].qs);
+ auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(GGML_BF16_TO_FP32(ggml_bf16_t{qy[ib].d})));
+ acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[iy]);
+ }
+ for (int j = 0; j < 4; ++j) {
+ qx[j] = _mm256_loadu_si256((const __m256i *)iq8[ib].qs+4+j);
+ sx[j] = _mm256_sign_epi8(qx[j], qx[j]);
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto qy = (const block_q8_2 *)q8.y[iy];
+ auto sumi = dot(qy[ib].qs+16);
+ auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(GGML_BF16_TO_FP32(ggml_bf16_t{qy[ib].d})));
+ acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[iy]);
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ info.store(ix, iy, acc[iy]);
+ acc[iy] = _mm256_setzero_ps();
+ }
+ }
+}
+#endif
+
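+// Fill the kernel table for the template-based unpackers, selecting either the q8_0 or the q8_2
+// flavour of the generic kernels depending on the unpacker type.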
+template <typename Dequantizer> void set_functions(std::array<mul_mat_t, IQK_MAX_NY>& funcs) {
+ if constexpr (std::is_same_v<Dequantizer, Q4_0_Unpacker> || std::is_same_v<Dequantizer, Q5_0_Unpacker> ||
+ std::is_same_v<Dequantizer, Q8_0_Unpacker>) {
+ IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_0_q8_0_T, Dequantizer, funcs)
+ }
+ else if constexpr (std::is_same_v<Dequantizer, Q4_1_Unpacker> || std::is_same_v<Dequantizer, Q5_1_Unpacker>) {
+ IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_1_q8_2_T, Dequantizer, funcs)
+ }
+ else if constexpr (std::is_same_v<Dequantizer, IQ4_NL_Unpacker>) {
+#ifdef HAVE_FANCY_SIMD
+ IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_1_q8_2_T, Dequantizer, funcs)
+#else
+ IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_0_q8_0_T, Dequantizer, funcs)
+#endif
+ }
+ else if constexpr (std::is_same_v<Dequantizer, Q8_0_1_Unpacker> || std::is_same_v<Dequantizer, Q4_0_1_Unpacker> ||
+ std::is_same_v<Dequantizer, Q5_0_1_Unpacker> || std::is_same_v<Dequantizer, Q6_0_1_Unpacker>) {
+ IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_1_q8_2_T, Dequantizer, funcs)
+ }
+}
+
+} // namespace
+
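+// Entry point for the legacy quants on x86: picks the kernel set for typeA and reports whether
+// typeB matches the expected activation type (Q8_2_X4, or Q8_0_X4 for Q8_0 and IQ4_NL when
+// HAVE_FANCY_SIMD is not available).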
+bool iqk_set_kernels_legacy_quants(int ne00, int typeA, int typeB, std::array<mul_mat_t, IQK_MAX_NY>& kernels, mul_mat_t& func16) {
+
+ if (ne00%QK8_0 != 0) return false;
+
+ auto expected_typeB = GGML_TYPE_Q8_2_X4;
+
+ func16 = nullptr;
+
+ switch (typeA) {
+ case GGML_TYPE_Q4_0:
+ set_functions<Q4_0_1_Unpacker>(kernels);
+ break;
+ case GGML_TYPE_Q4_1:
+ set_functions<Q4_1_Unpacker>(kernels);
+ break;
+ case GGML_TYPE_Q5_0:
+ set_functions<Q5_0_1_Unpacker>(kernels);
+ break;
+ case GGML_TYPE_Q5_1:
+ set_functions<Q5_1_Unpacker>(kernels);
+ break;
+ case GGML_TYPE_Q6_0:
+ set_functions<Q6_0_1_Unpacker>(kernels);
+ break;
+ case GGML_TYPE_Q8_0:
+#ifdef HAVE_FANCY_SIMD
+ set_functions<Q8_0_1_Unpacker>(kernels);
+#else
+ set_functions<Q8_0_Unpacker>(kernels);
+ expected_typeB = GGML_TYPE_Q8_0_X4;
+#endif
+ break;
+ case GGML_TYPE_IQ4_NL:
+ set_functions<IQ4_NL_Unpacker>(kernels);
+#ifndef HAVE_FANCY_SIMD
+ expected_typeB = GGML_TYPE_Q8_0_X4;
+#endif
+ break;
+ case GGML_TYPE_Q4_0_R8:
+ IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_q4_0_r8_q8_2, kernels)
+#ifdef HAVE_FANCY_SIMD
+ func16 = mul_mat_q4_0_r8_q8_2<16>;
+#endif
+ break;
+ case GGML_TYPE_Q5_0_R4:
+ IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_q5_0_r4_q8_2, kernels)
+ break;
+ case GGML_TYPE_Q6_0_R4:
+ IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_q6_0_r4_q8_2, kernels)
+ break;
+ case GGML_TYPE_Q8_0_R8:
+ IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_q8_0_r8_q8_2, kernels)
+ break;
+ case GGML_TYPE_IQ4_NL_R4:
+ IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq4_nl_r4_q8_2, kernels)
+ break;
+ default:
+ return false;
+ }
+
+ return ggml_type(typeB) == expected_typeB;
+}
+
+#else
+// ---------------------------- __aarch64__ ----------------------------------------------
+
+namespace {
+
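+// Gather the per-block scales of four consecutive legacy blocks into one f16 vector;
+// load_scales_q1 also picks up the q8_1 sums (or the block mins for the other block types).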
+template <typename Block>
+inline float16x4_t load_scales_q0(const Block * x, ggml_half * aux) {
+ for (int k = 0; k < 4; ++k) aux[k] = x[k].d;
+ return vld1_f16((const float16_t *)aux);
+}
+
+template <typename Block>
+inline float16x8_t load_scales_q1(const Block * x, ggml_half * aux) {
+ if constexpr (std::is_same_v<Block, block_q8_1>) {
+ for (int k = 0; k < 4; ++k) { aux[k] = x[k].d; aux[k+4] = x[k].s; }
+ } else {
+ for (int k = 0; k < 4; ++k) { aux[k] = x[k].d; aux[k+4] = x[k].m; }
+ }
+ return vld1q_f16((const float16_t *)aux);
+}
+
+struct Q4LegacyBits {
+ template <typename Block>
+ inline void prepare(const Block * x) {
+ for (int i = 0; i < 4; ++i) {
+ auto q4bits = vld1q_u8(x[i].qs);
+ b[2*i+0] = vreinterpretq_s8_u8(vandq_u8(q4bits, m4b));
+ b[2*i+1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits, 4));
+ }
+ }
+ inline void prepare1(const uint8_t * qs, int8x16_t * q) const {
+ auto q4bits = vld1q_u8(qs);
+ q[0] = vreinterpretq_s8_u8(vandq_u8(q4bits, m4b));
+ q[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits, 4));
+ }
+ inline void prepare1(const uint8_t * qs) {
+ prepare1(qs, b);
+ }
+ const uint8x16_t m4b = vdupq_n_u8(0xf);
+ int8x16_t b[8];
+};
+
+// One would think this commented-out version would do better than the one below
+// because it offers more opportunities to execute instructions in parallel.
+// Instead, it runs significantly slower. Why? If the compiler is running out of vector
+// registers, can it not just fall back to the sequential version below on its own?
+//inline int32x4_t sum_4_blocks(const int8x16_t * b, const int8_t * qs) {
+// const auto q8b_1 = vld1q_s8_x2(qs + 0);
+// auto p12 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b[0], q8b_1.val[0]), b[1], q8b_1.val[1]);
+// const auto q8b_2 = vld1q_s8_x2(qs + 32);
+// auto p34 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b[2], q8b_2.val[0]), b[3], q8b_2.val[1]);
+// auto p1234 = vpaddq_s32(p12, p34);
+// const auto q8b_3 = vld1q_s8_x2(qs + 64);
+// auto p56 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b[4], q8b_3.val[0]), b[5], q8b_3.val[1]);
+// const auto q8b_4 = vld1q_s8_x2(qs + 96);
+// auto p78 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b[6], q8b_4.val[0]), b[7], q8b_4.val[1]);
+// return vpaddq_s32(p1234, vpaddq_s32(p56, p78));
+//}
+
+inline int32x4_t sum_4_blocks(const int8x16_t * b, const int8_t * qs) {
+ auto q8b = vld1q_s8_x2(qs + 0);
+ auto p12 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b[0], q8b.val[0]), b[1], q8b.val[1]);
+ q8b = vld1q_s8_x2(qs + 32);
+ auto p34 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b[2], q8b.val[0]), b[3], q8b.val[1]);
+ auto p1234 = vpaddq_s32(p12, p34);
+ q8b = vld1q_s8_x2(qs + 64);
+ auto p56 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b[4], q8b.val[0]), b[5], q8b.val[1]);
+ q8b = vld1q_s8_x2(qs + 96);
+ auto p78 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b[6], q8b.val[0]), b[7], q8b.val[1]);
+ return vpaddq_s32(p1234, vpaddq_s32(p56, p78));
+}
+
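+// Same as sum_4_blocks above, but runs two quantized rows against the same q8 data in a
+// single pass.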
+inline int32x4x2_t sum_4_blocks(const int8x16_t * b1, const int8x16_t * b2, const int8_t * qs) {
+ auto q8b = vld1q_s8_x2(qs + 0);
+ auto p12_1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b1[0], q8b.val[0]), b1[1], q8b.val[1]);
+ auto p12_2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b2[0], q8b.val[0]), b2[1], q8b.val[1]);
+ q8b = vld1q_s8_x2(qs + 32);
+ auto p34_1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b1[2], q8b.val[0]), b1[3], q8b.val[1]);
+ auto p34_2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b2[2], q8b.val[0]), b2[3], q8b.val[1]);
+ auto p1234_1 = vpaddq_s32(p12_1, p34_1);
+ auto p1234_2 = vpaddq_s32(p12_2, p34_2);
+ q8b = vld1q_s8_x2(qs + 64);
+ auto p56_1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b1[4], q8b.val[0]), b1[5], q8b.val[1]);
+ auto p56_2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b2[4], q8b.val[0]), b2[5], q8b.val[1]);
+ q8b = vld1q_s8_x2(qs + 96);
+ auto p78_1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b1[6], q8b.val[0]), b1[7], q8b.val[1]);
+ auto p78_2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b2[6], q8b.val[0]), b2[7], q8b.val[1]);
+ auto p5678_1 = vpaddq_s32(p56_1, p78_1);
+ auto p5678_2 = vpaddq_s32(p56_2, p78_2);
+ return { vpaddq_s32(p1234_1, p5678_1), vpaddq_s32(p1234_2, p5678_2)};
+}
+
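+// Q80/Q81 wrap the q8_0 / q8_1 activation rows and provide the scale handling shared by the
+// NEON legacy-quant kernels; Q81 additionally accumulates the product of the second-half
+// scales (x mins times q8 block sums) directly into the float accumulators.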
+template <int nrc> struct Q80 {
+
+ constexpr static int nrc_y = nrc;
+
+ Q80(const DataInfo& info) {
+ for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const block_q8_0 *)info.src1_row(iy);
+ }
+
+ inline const int8_t * quant_data(int iy, int i) const {
+ const block_q8_0_x4 * y4 = (const block_q8_0_x4 *)y[iy] + i;
+ return y4->qs;
+ }
+
+ inline float16x4_t load_scales(int iy, int i) const {
+ const block_q8_0_x4 * y4 = (const block_q8_0_x4 *)y[iy] + i;
+ return vld1_f16((const float16_t *)y4->d);
+ }
+
+ template <typename Dequantizer>
+ inline void process_scales(int i, Dequantizer& deq, float16x4_t * sc16, float32x4_t * /*acc*/) const {
+ auto qx_scales = deq.new_block(i);
+ for (int iy = 0; iy < nrc; ++iy) {
+ auto q8_scales = load_scales(iy, i);
+ sc16[iy] = vmul_f16(qx_scales, q8_scales);
+ }
+ }
+
+ template <typename Dequantizer>
+ inline void process_scales(int i, Dequantizer& deq1, Dequantizer& deq2, float16x4_t * sc16, float32x4_t * /*acc*/) const {
+ auto qx_scales_1 = deq1.new_block(i);
+ auto qx_scales_2 = deq2.new_block(i);
+ for (int iy = 0; iy < nrc; ++iy) {
+ auto q8_scales = load_scales(iy, i);
+ sc16[iy ] = vmul_f16(qx_scales_1, q8_scales);
+ sc16[iy+nrc_y] = vmul_f16(qx_scales_2, q8_scales);
+ }
+ }
+
+ template <typename Dequantizer>
+ inline void process_1_block(int i, Dequantizer& deq, float32x4_t * acc) const {
+ deq.prepare1(i);
+ float d = GGML_FP16_TO_FP32(deq.x[i].d);
+ for (int iy = 0; iy < nrc; ++iy) {
+ auto q8b = vld1q_s8_x2(y[iy][i].qs);
+ auto p = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), deq.bits.b[0], q8b.val[0]), deq.bits.b[1], q8b.val[1]);
+ acc[iy] = vmlaq_f32(acc[iy], vdupq_n_f32(d*GGML_FP16_TO_FP32(y[iy][i].d)), vcvtq_f32_s32(p));
+ }
+ }
+
+ const block_q8_0 * y[nrc_y];
+};
+
+template <int nrc> struct Q81 {
+
+ constexpr static int nrc_y = nrc;
+
+ Q81(const DataInfo& info) {
+ for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const block_q8_1 *)info.src1_row(iy);
+ }
+
+ inline const int8_t * quant_data(int iy, int i) const {
+ const block_q8_1_x4 * y4 = (const block_q8_1_x4 *)y[iy] + i;
+ return y4->qs;
+ }
+
+ inline float16x8_t load_scales(int iy, int i) const {
+ const block_q8_1_x4 * y4 = (const block_q8_1_x4 *)y[iy] + i;
+ return vld1q_f16((const float16_t *)y4->d);
+ }
+
+ template <typename Dequantizer>
+ inline void process_scales(int i, Dequantizer& deq, float16x4_t * sc16, float32x4_t * acc) const {
+ auto qx_scales = deq.new_block(i);
+ for (int iy = 0; iy < nrc; ++iy) {
+ auto q8_scales = load_scales(iy, i);
+ auto m = vmul_f16(vget_high_f16(qx_scales), vget_high_f16(q8_scales));
+ acc[iy] = vaddq_f32(acc[iy], vcvt_f32_f16(m));
+ sc16[iy] = vmul_f16(vget_low_f16(qx_scales), vget_low_f16(q8_scales));
+ }
+ }
+
+ template <typename Dequantizer>
+ inline void process_scales(int i, Dequantizer& deq1, Dequantizer& deq2, float16x4_t * sc16, float32x4_t * acc) const {
+ auto qx_scales_1 = deq1.new_block(i);
+ auto qx_scales_2 = deq2.new_block(i);
+ for (int iy = 0; iy < nrc; ++iy) {
+ auto q8_scales = load_scales(iy, i);
+ auto q8_scales_l = vget_low_f16(q8_scales);
+ auto q8_scales_h = vget_high_f16(q8_scales);
+ auto m1 = vmul_f16(vget_high_f16(qx_scales_1), q8_scales_h);
+ auto m2 = vmul_f16(vget_high_f16(qx_scales_2), q8_scales_h);
+ acc[iy ] = vaddq_f32(acc[iy ], vcvt_f32_f16(m1));
+ acc[iy+nrc_y ] = vaddq_f32(acc[iy+nrc_y], vcvt_f32_f16(m2));
+ sc16[iy ] = vmul_f16(vget_low_f16(qx_scales_1), q8_scales_l);
+ sc16[iy+nrc_y] = vmul_f16(vget_low_f16(qx_scales_2), q8_scales_l);
+ }
+ }
+
+ template <typename Dequantizer>
+ inline void process_1_block(int i, Dequantizer& deq, float32x4_t * acc) const {
+ deq.prepare1(i);
+ float d = GGML_FP16_TO_FP32(deq.x[i].d), m = 0.25f*GGML_FP16_TO_FP32(deq.x[i].m);
+ for (int iy = 0; iy < nrc; ++iy) {
+ auto q8b = vld1q_s8_x2(y[iy][i].qs);
+ auto p = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), deq.bits.b[0], q8b.val[0]), deq.bits.b[1], q8b.val[1]);
+ acc[iy] = vmlaq_f32(acc[iy], vdupq_n_f32(d*GGML_FP16_TO_FP32(y[iy][i].d)), vcvtq_f32_s32(p));
+ acc[iy] = vaddq_f32(acc[iy], vdupq_n_f32(m*GGML_FP16_TO_FP32(y[iy][i].s)));
+ }
+ }
+
+ const block_q8_1 * y[nrc_y];
+};
+
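+// Base for the NEON legacy-quant dequantizers: holds the row pointer and the 4-block bit
+// buffer; the concrete dequantizers below unpack four blocks at a time in new_block().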
+template <typename block_q>
+struct BaseLegacyDequantizer {
+
+ BaseLegacyDequantizer(const void * vx, size_t bx) : vx(vx), x(nullptr), bx(bx) {}
+
+ inline void new_row(int ix) { x = (const block_q *)((const char *)vx + bx*ix); }
+
+ Q4LegacyBits bits;
+
+ const void * vx;
+ const block_q * x;
+ size_t bx;
+};
+
+struct DequantizerQ40 final : public BaseLegacyDequantizer<block_q4_0> {
+
+ DequantizerQ40(const void * vx, size_t bx) : BaseLegacyDequantizer(vx, bx) {}
+
+ inline void prepare1(int i, int8x16_t * q) const {
+ bits.prepare1(x[i].qs, q);
+ q[0] = vaddq_s8(q[0], m8);
+ q[1] = vaddq_s8(q[1], m8);
+ }
+ inline void prepare1(int i) {
+ prepare1(i, bits.b);
+ }
+
+ inline float16x4_t new_block(int i) {
+ ggml_half aux[4];
+ for (int k = 0; k < 4; ++k) {
+ aux[k] = x[4*i+k].d;
+ prepare1(4*i+k, bits.b + 2*k);
+ }
+ return vld1_f16((const float16_t *)aux);
+ }
+
+ const int8x16_t m8 = vdupq_n_s8(-8);
+ //ggml_half aux[4];
+};
+
+struct DequantizerQ60 final : public BaseLegacyDequantizer<block_q6_0> {
+
+ DequantizerQ60(const void * vx, size_t bx) : BaseLegacyDequantizer(vx, bx) {}
+
+ inline void prepare1(int i, int8x16_t * q) const {
+ bits.prepare1(x[i].qs, q);
+ auto qh8 = vld1_u8(x[i].qh);
+ auto qh = vcombine_u8(vshl_n_u8(qh8, 4), qh8);
+ q[0] = vaddq_s8(vorrq_u8(q[0], vandq_u8(qh, hmask)), m32);
+ q[1] = vaddq_s8(vorrq_u8(q[1], vandq_u8(vshrq_n_u8(qh, 2), hmask)), m32);
+ }
+ inline void prepare1(int i) {
+ prepare1(i, bits.b);
+ }
+
+ inline float16x4_t new_block(int i) {
+ ggml_half aux[4];
+ for (int k = 0; k < 4; ++k) {
+ aux[k] = x[4*i+k].d;
+ prepare1(4*i+k, bits.b + 2*k);
+ }
+ return vld1_f16((const float16_t *)aux);
+ }
+
+ const int8x16_t m32 = vdupq_n_s8(-32);
+ const uint8x16_t hmask = vdupq_n_u8(0x30);
+};
+
+struct DequantizerIQ4NL final : public BaseLegacyDequantizer<block_iq4_nl> {
+
+ DequantizerIQ4NL(const void * vx, size_t bx) : BaseLegacyDequantizer(vx, bx) {}
+
+ inline void prepare1(int i, int8x16_t * q) const {
+ bits.prepare1(x[i].qs, q);
+ q[0] = vqtbl1q_s8(values, q[0]);
+ q[1] = vqtbl1q_s8(values, q[1]);
+ }
+ inline void prepare1(int i) {
+ prepare1(i, bits.b);
+ }
+
+ inline float16x4_t new_block(int i) {
+ ggml_half aux[4];
+ for (int k = 0; k < 4; ++k) {
+ aux[k] = x[4*i+k].d;
+ prepare1(4*i+k, bits.b + 2*k);
+ }
+ return vld1_f16((const float16_t *)aux);
+ }
+ static int8x16_t load_values() {
+ static const int8_t iq4nl_values[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};
+ return vld1q_s8(iq4nl_values);
+ }
+
+ const int8x16_t values = load_values();
+};
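+// A scalar sketch of the lookup above: vqtbl1q_s8 acts as a 16-entry table, so each 4-bit
+// index q produced by prepare1 is replaced by the non-linear value iq4nl_values[q], i.e.
+//
+//   for (int j = 0; j < 16; ++j) out[j] = iq4nl_values[idx[j]];   // idx[j] in 0..15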
+
+struct DequantizerQ41 : public BaseLegacyDequantizer<block_q4_1> {
+
+ DequantizerQ41(const void * vx, size_t bx) : BaseLegacyDequantizer(vx, bx) {}
+
+ inline void prepare1(int i) {
+ bits.prepare1(x[i].qs);
+ }
+
+ inline float16x8_t new_block(int i) {
+ uint32_t aux32[4];
+ const uint32_t * s32 = (const uint32_t *)&x[4*i].d;
+ for (int k = 0; k < 4; ++k) {
+ aux32[k] = *s32; s32 += sizeof(block_q4_1)/4;
+ bits.prepare1(x[4*i+k].qs, bits.b + 2*k);
+ }
+ return vreinterpretq_f16_u8(vqtbl1q_u8(vld1q_u8((const uint8_t *)aux32), vreinterpretq_u8_u64(shuffle)));
+ }
+    // Leaving this commented-out attempt here as a reminder that I already tried it.
+ // It has basically the same performance as the version above.
+ //inline float16x8_t new_block(int i) {
+ // uint32x4_t scales = {};
+ // const block_q4_1 * xi = x + 4*i;
+ // const uint32_t * s32 = (const uint32_t *)&xi->d;
+ // scales = vsetq_lane_u32(*s32, scales, 0); s32 += sizeof(block_q4_1)/4;
+ // bits.prepare1(xi[0].qs, bits.b + 0);
+ // scales = vsetq_lane_u32(*s32, scales, 1); s32 += sizeof(block_q4_1)/4;
+ // bits.prepare1(xi[1].qs, bits.b + 2);
+ // scales = vsetq_lane_u32(*s32, scales, 2); s32 += sizeof(block_q4_1)/4;
+ // bits.prepare1(xi[2].qs, bits.b + 4);
+ // scales = vsetq_lane_u32(*s32, scales, 3);
+ // bits.prepare1(xi[3].qs, bits.b + 6);
+ // return vreinterpretq_f16_u8(vqtbl1q_u8(vreinterpretq_u8_u32(scales), vreinterpretq_u8_u64(shuffle)));
+ //}
+
+ const uint64x2_t shuffle = {0x0d0c090805040100, 0x0f0e0b0a07060302};
+};
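+// The shuffle above regroups the four (d, m) fp16 pairs gathered into aux32 so that the
+// returned float16x8_t holds the block scales d0..d3 in its low half and the mins m0..m3
+// in its high half, which is the layout the process_scales helpers above expect.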
+
+struct HighBit5Legacy {
+ inline uint8x16_t to_bytes(const uint8_t * qh) const {
+ uint8x16_t h = vqtbl1q_u8(vreinterpretq_u8_u16(vdupq_n_u16(*(const uint16_t *)qh)), shuffle);
+ return vceqq_u8(vandq_u8(h, vreinterpretq_u8_u64(mask)), vreinterpretq_u8_u64(mask));
+ }
+ inline uint8x16_t to_negated_bytes(const uint8_t * qh) const {
+ uint8x16_t h = vqtbl1q_u8(vreinterpretq_u8_u16(vdupq_n_u16(*(const uint16_t *)qh)), shuffle);
+ return vceqq_u8(vandq_u8(h, vreinterpretq_u8_u64(mask)), vdupq_n_u8(0));
+ }
+ const uint64x2_t mask = vdupq_n_u64(0x8040201008040201);
+ const uint8x16_t shuffle = vcombine_u8(vdup_n_u8(0), vdup_n_u8(1));
+};
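+// A scalar sketch of HighBit5Legacy (assuming the same bit order as the 0x8040201008040201
+// mask): the 16 high bits in qh are broadcast to 16 bytes and turned into per-quant
+// 0x00/0xff masks that the q5_0/q5_1 dequantizers below AND with mh.
+//
+//   uint16_t h = qh[0] | (qh[1] << 8);
+//   for (int j = 0; j < 16; ++j) mask[j]  = (h >> j) & 1 ? 0xff : 0x00;   // to_bytes
+//   for (int j = 0; j < 16; ++j) nmask[j] = (h >> j) & 1 ? 0x00 : 0xff;   // to_negated_bytes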
+
+struct DequantizerQ50 final : public BaseLegacyDequantizer<block_q5_0> {
+
+ DequantizerQ50(const void * vx, size_t bx) : BaseLegacyDequantizer(vx, bx) {}
+
+ inline void prepare1(int i, int8x16_t * q) const {
+ bits.prepare1(x[i].qs, q);
+ auto qh = x[i].qh;
+ q[0] = vreinterpretq_s8_u8(vorrq_u8(vreinterpretq_u8_s8(q[0]), vandq_u8(mh, hbits.to_negated_bytes(qh+0))));
+ q[1] = vreinterpretq_s8_u8(vorrq_u8(vreinterpretq_u8_s8(q[1]), vandq_u8(mh, hbits.to_negated_bytes(qh+2))));
+ }
+ inline void prepare1(int i) {
+ prepare1(i, bits.b);
+ }
+
+ inline float16x4_t new_block(int i) {
+ ggml_half aux[4];
+ for (int k = 0; k < 4; ++k) {
+ aux[k] = x[4*i+k].d;
+ prepare1(4*i+k, bits.b + 2*k);
+ }
+ return vld1_f16((const float16_t *)aux);
+ }
+
+ HighBit5Legacy hbits;
+
+ const uint8x16_t mh = vdupq_n_u8(0xf0);
+
+};
+
+struct DequantizerQ80 final : public BaseLegacyDequantizer<block_q8_0> {
+
+ DequantizerQ80(const void * vx, size_t bx) : BaseLegacyDequantizer(vx, bx) {}
+
+ inline void prepare1(int i) {
+ bits.b[0] = vld1q_s8(x[i].qs);
+ bits.b[1] = vld1q_s8(x[i].qs+16);
+ }
+
+ inline float16x4_t new_block(int i) {
+ ggml_half aux[4];
+ for (int k = 0; k < 4; ++k) {
+ aux[k] = x[4*i+k].d;
+ bits.b[2*k+0] = vld1q_s8(x[4*i+k].qs);
+ bits.b[2*k+1] = vld1q_s8(x[4*i+k].qs+16);
+ }
+ return vld1_f16((const float16_t *)aux);
+ }
+
+};
+
+// TODO: handle case where row size is not a multiple of 128
+struct DequantizerQ80_x4 final : public BaseLegacyDequantizer<block_q8_0_x4> {
+
+ DequantizerQ80_x4(const void * vx, size_t bx) : BaseLegacyDequantizer(vx, bx) {}
+
+ inline void prepare1(int i) {
+ bits.b[0] = vld1q_s8(x[i].qs);
+ bits.b[1] = vld1q_s8(x[i].qs+16);
+ }
+
+ inline float16x4_t new_block(int i) {
+ auto scale = vld1_f16((const float16_t *)x[i].d);
+ for (int k = 0; k < 4; ++k) {
+ bits.b[2*k+0] = vld1q_s8(x[i].qs+32*k);
+ bits.b[2*k+1] = vld1q_s8(x[i].qs+32*k+16);
+ }
+ return scale;
+ }
+
+};
+
+struct DequantizerQ51 final : public BaseLegacyDequantizer<block_q5_1> {
+
+ DequantizerQ51(const void * vx, size_t bx) : BaseLegacyDequantizer(vx, bx) {}
+
+ inline void prepare1(int i, int8x16_t * q) const {
+ bits.prepare1(x[i].qs, q);
+ auto qh = x[i].qh;
+ q[0] = vreinterpretq_s8_u8(vorrq_u8(vreinterpretq_u8_s8(q[0]), vandq_u8(mh, hbits.to_bytes(qh+0))));
+ q[1] = vreinterpretq_s8_u8(vorrq_u8(vreinterpretq_u8_s8(q[1]), vandq_u8(mh, hbits.to_bytes(qh+2))));
+ }
+ inline void prepare1(int i) {
+ bits.prepare1(x[i].qs, bits.b);
+ }
+
+ inline float16x8_t new_block(int i) {
+ uint32_t aux32[4];
+ const uint32_t * s32 = (const uint32_t *)&x[4*i].d;
+ for (int k = 0; k < 4; ++k) {
+ aux32[k] = *s32; s32 += sizeof(block_q5_1)/4;
+ prepare1(4*i+k, bits.b + 2*k);
+ }
+ return vreinterpretq_f16_u8(vqtbl1q_u8(vld1q_u8((const uint8_t *)aux32), vreinterpretq_u8_u64(shuffle)));
+ }
+
+ HighBit5Legacy hbits;
+
+ const uint8x16_t mh = vdupq_n_u8(0x10);
+ const uint64x2_t shuffle = {0x0d0c090805040100, 0x0f0e0b0a07060302};
+
+};
+
+template <typename Dequantizer, typename Q8>
+inline void sum_4(int i, Dequantizer& deq, const Q8& q8, const float16x4_t * sc16, float32x4_t * acc) {
+ for (int iy = 0; iy < Q8::nrc_y; ++iy) {
+ auto pall = sum_4_blocks(deq.bits.b, q8.quant_data(iy, i));
+ auto scale = vcvt_f32_f16(sc16[iy]);
+ acc[iy] = vmlaq_f32(acc[iy], scale, vcvtq_f32_s32(pall));
+ }
+}
+
+template <typename Dequantizer, typename Q8>
+inline void sum_4(int i, Dequantizer& deq1, Dequantizer& deq2, const Q8& q8, const float16x4_t * sc16, float32x4_t * acc) {
+ for (int iy = 0; iy < Q8::nrc_y; ++iy) {
+ auto pall = sum_4_blocks(deq1.bits.b, deq2.bits.b, q8.quant_data(iy, i));
+ auto scale1 = vcvt_f32_f16(sc16[iy]);
+ auto scale2 = vcvt_f32_f16(sc16[iy+Q8::nrc_y]);
+ acc[iy] = vmlaq_f32(acc[iy], scale1, vcvtq_f32_s32(pall.val[0]));
+ acc[iy+Q8::nrc_y] = vmlaq_f32(acc[iy+Q8::nrc_y], scale2, vcvtq_f32_s32(pall.val[1]));
+ }
+}
+
+template <typename Dequantizer, typename Q8>
+inline void mul_mat_qX_Y_q8_Y(int n, Dequantizer& deq, Q8& q8, const DataInfo& info, int nrc_x) {
+ const int nb = n / QK4_1;
+
+ float16x4_t sc16[Q8::nrc_y];
+
+ for (int ix = 0; ix < nrc_x; ++ix) {
+
+ deq.new_row(ix);
+
+ float32x4_t acc[Q8::nrc_y];
+ for (int iy = 0; iy < Q8::nrc_y; ++iy) acc[iy] = vdupq_n_f32(0.f);
+
+ for (int i = 0; i < nb/4; ++i) {
+ q8.process_scales(i, deq, sc16, acc);
+ sum_4(i, deq, q8, sc16, acc);
+ }
+ for (int i = 4*(nb/4); i < nb; ++i) {
+ q8.process_1_block(i, deq, acc);
+ }
+
+ for (int iy = 0; iy < Q8::nrc_y; ++iy) {
+ info.store(ix, iy, vaddvq_f32(acc[iy]));
+ }
+ }
+}
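+// The driver above works in groups of four blocks: q8.process_scales() prepares the combined
+// per-block scales (and, for the q8_1 path, folds the min/sum corrections into the
+// accumulators), sum_4() does the integer dot products, and any leftover blocks when nb is
+// not a multiple of 4 go through q8.process_1_block().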
+
+template <typename Dequantizer, typename Q8>
+inline void mul_mat_qX_Y_q8_Y_IK(int n, Dequantizer& deq1, Dequantizer& deq2, Q8& q8, const DataInfo& info, int nrc_x) {
+ const int nb = n / QK4_1;
+
+ float16x4_t sc16[2*Q8::nrc_y];
+ float32x4_t acc[2*Q8::nrc_y];
+
+ for (int ix = 0; ix < nrc_x; ix += 2) {
+
+ deq1.new_row(ix+0);
+ deq2.new_row(ix+1);
+
+ for (int iy = 0; iy < 2*Q8::nrc_y; ++iy) acc[iy] = vdupq_n_f32(0.f);
+
+ for (int i = 0; i < nb/4; ++i) {
+ q8.process_scales(i, deq1, deq2, sc16, acc);
+ sum_4(i, deq1, deq2, q8, sc16, acc);
+ }
+        // No remainder loop is needed here: this kernel is only selected when n%128 == 0
+        // (see mul_mat_qX_0_q8_0 / mul_mat_qX_1_q8_1 below), so nb is always a multiple of 4.
+        //for (int i = 4*(nb/4); i < nb; ++i) {
+        //    q8.process_1_block(i, deq, acc);
+        //}
+
+ for (int iy = 0; iy < Q8::nrc_y; ++iy) {
+ info.store(ix+0, iy, vaddvq_f32(acc[iy]));
+ info.store(ix+1, iy, vaddvq_f32(acc[iy+Q8::nrc_y]));
+ }
+ }
+}
+
+template <typename Dequantizer, typename Q8>
+inline void mul_mat_qX_Y_q8_Y_1(int n, Dequantizer& deq1, Dequantizer& deq2, Q8& q8, const DataInfo& info, int nrc_x) {
+ const int nb = n / QK4_1;
+
+ float16x4_t sc16[2];
+
+ for (int ix = 0; ix < nrc_x; ++ix) {
+
+ deq1.new_row(ix);
+ deq2.new_row(ix);
+
+ float32x4_t acc[2] = { vdupq_n_f32(0.f), vdupq_n_f32(0.f) };
+
+ for (int i = 0; i < nb/8; ++i) {
+ q8.process_scales(2*i+0, deq1, sc16+0, acc+0);
+ q8.process_scales(2*i+1, deq2, sc16+1, acc+1);
+ sum_4(2*i+0, deq1, q8, sc16+0, acc+0);
+ sum_4(2*i+1, deq2, q8, sc16+1, acc+1);
+ }
+ for (int i = 2*(nb/8); i < nb/4; ++i) {
+ q8.process_scales(i, deq1, sc16, acc);
+ sum_4(i, deq1, q8, sc16, acc);
+ }
+ //for (int i = 4*(nb/4); i < nb; ++i) {
+ // q8.process_1_block(i, deq1, acc);
+ //}
+
+ info.store(ix, 0, vaddvq_f32(vaddq_f32(acc[0], acc[1])));
+ }
+}
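+// For a single activation row the variant above runs two dequantizers over the same row
+// (both get new_row(ix)) and accumulates even/odd groups of four blocks into acc[0] and
+// acc[1], presumably to keep two independent dependency chains in flight; the two partial
+// sums are added just before the final horizontal reduction.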
+
+template <typename Dequantizer, int nrc_y>
+static void mul_mat_qX_1_q8_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ Q81<nrc_y> q8(info);
+ if constexpr (nrc_y == 1) {
+ Dequantizer deq1(vx, bx), deq2(vx, bx);
+ mul_mat_qX_Y_q8_Y_1(n, deq1, deq2, q8, info, nrc_x);
+ } else {
+ if (nrc_x%2 == 0 && n%128 == 0) {
+ Dequantizer deq1(vx, bx), deq2(vx, bx);
+ mul_mat_qX_Y_q8_Y_IK(n, deq1, deq2, q8, info, nrc_x);
+ } else {
+ Dequantizer deq(vx, bx);
+ mul_mat_qX_Y_q8_Y(n, deq, q8, info, nrc_x);
+ }
+ }
+}
+
+template <typename Dequantizer, int nrc_y>
+static void mul_mat_qX_0_q8_0(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ Q80<nrc_y> q8(info);
+ if constexpr (nrc_y == 1) {
+ Dequantizer deq1(vx, bx), deq2(vx, bx);
+ mul_mat_qX_Y_q8_Y_1(n, deq1, deq2, q8, info, nrc_x);
+ } else {
+ if (nrc_x%2 == 0 && n%128 == 0) {
+ Dequantizer deq1(vx, bx), deq2(vx, bx);
+ mul_mat_qX_Y_q8_Y_IK(n, deq1, deq2, q8, info, nrc_x);
+ } else {
+ Dequantizer deq(vx, bx);
+ mul_mat_qX_Y_q8_Y(n, deq, q8, info, nrc_x);
+ }
+ }
+}
+
+template <typename Dequantizer>
+static void mul_mat_qX_1_q8_1_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ Dequantizer deq1(vx, bx), deq2(vx, bx);
+ Q81<1> q8(info);
+ mul_mat_qX_Y_q8_Y_1(n, deq1, deq2, q8, info, nrc_x);
+}
+
+template <typename Dequantizer>
+static void mul_mat_qX_0_q8_0_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ Dequantizer deq1(vx, bx), deq2(vx, bx);
+ Q80<1> q8(info);
+    mul_mat_qX_Y_q8_Y_1(n, deq1, deq2, q8, info, nrc_x);
+}
+
+template <typename Dequantizer, int nrc_y>
+void mul_mat_qx_r4_q8_0(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ GGML_ASSERT(nrc_x%4 == 0);
+ Q8<nrc_y, block_q8_0_x4> q8(info);
+ Dequantizer deq(vx, bx);
+ int nb = n / QK4_NL;
+ int8x16_t qx[8];
+ float d8[4*nrc_y];
+ float32x4_t acc[nrc_y] = {};
+ for (int ix = 0; ix < nrc_x; ix += 4) {
+ deq.new_row(ix);
+ for (int ib4 = 0; ib4 < nb/4; ++ib4) {
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ vst1q_f32(d8+4*iy, vcvt_f32_f16(vld1_f16((const float16_t *)q8.y[iy][ib4].d)));
+ }
+ for (int k = 0; k < 4; ++k) {
+ auto scales = deq.prepare(4*ib4+k, qx);
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto y = vld1q_s8_x2(q8.y[iy][ib4].qs+32*k);
+ auto sumi = interleaved_dotq(qx, y);
+ auto d4d8 = vmulq_f32(scales, vdupq_n_f32(d8[4*iy+k]));
+ acc[iy] = vfmaq_f32(acc[iy], d4d8, vcvtq_f32_s32(sumi));
+ }
+ }
+ }
+ for (int ib = 4*(nb/4); ib < nb; ++ib) {
+ auto scales = deq.prepare(ib, qx);
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto qy = (const block_q8_0 *)q8.y[iy];
+ auto y = vld1q_s8_x2(qy[ib].qs);
+ auto sumi = interleaved_dotq(qx, y);
+ auto d4d8 = vmulq_f32(scales, vdupq_n_f32(GGML_FP16_TO_FP32(qy[ib].d)));
+ acc[iy] = vfmaq_f32(acc[iy], d4d8, vcvtq_f32_s32(sumi));
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ info.store(ix, iy, deq.result(acc[iy]));
+ acc[iy] = vdupq_n_f32(0.f);
+ }
+ }
+}
+
+template <typename Dequantizer, int nrc_y>
+void mul_mat_qx_r8_q8_0(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ GGML_ASSERT(nrc_x%8 == 0);
+ Q8<nrc_y, block_q8_0_x4> q8(info);
+ Dequantizer deq(vx, bx);
+ int nb = n / QK4_NL;
+ int8x16_t qx[16];
+ float d8[4*nrc_y];
+ float32x4_t acc[2*nrc_y] = {};
+ for (int ix = 0; ix < nrc_x; ix += 8) {
+ deq.new_row(ix);
+ for (int ib4 = 0; ib4 < nb/4; ++ib4) {
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ vst1q_f32(d8+4*iy, vcvt_f32_f16(vld1_f16((const float16_t *)q8.y[iy][ib4].d)));
+ }
+ for (int k = 0; k < 4; ++k) {
+ auto scales = deq.prepare(ib4, k, qx);
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto y = vld1q_s8_x2(q8.y[iy][ib4].qs+32*k);
+ auto sumi1 = interleaved_dotq(qx+0, y);
+ auto sumi2 = interleaved_dotq(qx+8, y);
+ auto dy = vdupq_n_f32(d8[4*iy+k]);
+ acc[2*iy+0] = vfmaq_f32(acc[2*iy+0], vmulq_f32(scales.val[0], dy), vcvtq_f32_s32(sumi1));
+ acc[2*iy+1] = vfmaq_f32(acc[2*iy+1], vmulq_f32(scales.val[1], dy), vcvtq_f32_s32(sumi2));
+ }
+ }
+ }
+ for (int ib = 4*(nb/4); ib < nb; ++ib) {
+ auto scales = deq.prepare(ib, 0, qx);
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto qy = (const block_q8_0 *)q8.y[iy];
+ auto y = vld1q_s8_x2(qy[ib].qs);
+ auto sumi1 = interleaved_dotq(qx+0, y);
+ auto sumi2 = interleaved_dotq(qx+8, y);
+ auto dy = vdupq_n_f32(GGML_FP16_TO_FP32(qy[ib].d));
+ acc[2*iy+0] = vfmaq_f32(acc[2*iy+0], vmulq_f32(scales.val[0], dy), vcvtq_f32_s32(sumi1));
+ acc[2*iy+1] = vfmaq_f32(acc[2*iy+1], vmulq_f32(scales.val[1], dy), vcvtq_f32_s32(sumi2));
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ info.store(ix+0, iy, deq.result(acc[2*iy+0]));
+ info.store(ix+4, iy, deq.result(acc[2*iy+1]));
+ acc[2*iy] = acc[2*iy+1] = vdupq_n_f32(0.f);
+ }
+ }
+}
+
+struct IQ4_NL_R4_Dequantizer {
+ IQ4_NL_R4_Dequantizer(const void * vx, size_t bx) : cx((const char *)vx), bx(bx), values(vld1q_s8(iq4k_values)) {}
+ inline void new_row(int ix) { iq4 = (const block_iq4_nl_r4 *)(cx + ix*bx); }
+ inline float32x4_t prepare(int ib, int8x16_t * qx) const {
+ auto scales = vcvt_f32_f16(vld1_f16((const float16_t *)iq4[ib].d));
+ auto bits = vld1q_u8_x4(iq4[ib].qs);
+ prepare_iq4_nl_quants(values, m4, bits, qx);
+ return scales;
+ }
+ inline float32x4_t result(float32x4_t acc) const {
+ return acc;
+ }
+
+ const char * cx;
+ const size_t bx;
+ const block_iq4_nl_r4 * iq4;
+ const uint8x16_t m4 = vdupq_n_u8(0x0f);
+ const int8x16_t values;
+};
+
+struct Q4_0_R8_Dequantizer {
+ Q4_0_R8_Dequantizer(const void * vx, size_t bx) : cx((const char *)vx), bx(bx) {}
+ inline void new_row(int ix) { iq4 = (const block_iq4_nl_r8 *)(cx + ix*bx); }
+ inline float32x4x2_t prepare(int ib4, int k, int8x16_t * qx) const {
+ auto scales16 = vld1q_f16((const float16_t *)iq4[4*ib4+k].d);
+ float32x4x2_t scales = { vcvt_f32_f16(vget_low_f16(scales16)), vcvt_f32_f16(vget_high_f16(scales16)) };
+ for (int j = 0; j < 4; ++j) {
+ auto bits = vld1q_u8_x2(iq4[4*ib4+k].qs + 32*j);
+ bits.val[0] = veorq_u8(m88, bits.val[0]);
+ bits.val[1] = veorq_u8(m88, bits.val[1]);
+ qx[2*j+0] = vshlq_n_u8(bits.val[0], 4);
+ qx[2*j+1] = vandq_u8(bits.val[0], m4);
+ qx[2*j+8] = vshlq_n_u8(bits.val[1], 4);
+ qx[2*j+9] = vandq_u8(bits.val[1], m4);
+ }
+ return scales;
+ }
+ inline float32x4_t result(float32x4_t acc) const {
+ return vmulq_f32(norm, acc);
+ }
+
+ const char * cx;
+ const size_t bx;
+ const block_iq4_nl_r8 * iq4;
+ const uint8x16_t m4 = vdupq_n_u8(0xf0);
+ const uint8x16_t m88 = vdupq_n_u8(0x88);
+ const float32x4_t norm = vdupq_n_f32(1.f/16);
+};
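+// The 0x88 trick above, in scalar terms: a q4_0 nibble q in [0,15] encodes the value q - 8,
+// and for such q
+//
+//   (int8_t)((q ^ 8) << 4) == 16*(q - 8)
+//
+// so both nibble halves come out scaled by 16 and result() undoes that once via norm = 1/16.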
+
+struct Q5_0_R4_Dequantizer {
+ Q5_0_R4_Dequantizer(const void * vx, size_t bx) : cx((const char *)vx), bx(bx) {}
+ inline void new_row(int ix) { iq5 = (const block_q5_0_r4 *)(cx + ix*bx); }
+ inline float32x4_t prepare(int ib, int8x16_t * qx) const {
+ auto scales = vcvt_f32_f16(vld1_f16((const float16_t *)iq5[ib].d));
+ auto lbits = vld1q_u8_x4(iq5[ib].qs);
+ auto hbits = vld1q_u8(iq5[ib].qh);
+ qx[0] = vaddq_s8(vandq_u8(lbits.val[0], m4) | vandq_u8(vshlq_n_u8(hbits, 4), m5), m16); // 0...3
+ qx[1] = vaddq_s8(vandq_u8(lbits.val[1], m4) | vandq_u8(vshlq_n_u8(hbits, 3), m5), m16); // 16..19
+ qx[2] = vaddq_s8(vandq_u8(lbits.val[2], m4) | vandq_u8(vshlq_n_u8(hbits, 2), m5), m16); // 4...7
+ qx[3] = vaddq_s8(vandq_u8(lbits.val[3], m4) | vandq_u8(vshlq_n_u8(hbits, 1), m5), m16); // 20..23
+ qx[4] = vaddq_s8(vshrq_n_u8(lbits.val[0], 4)| vandq_u8(hbits, m5), m16); // 8..11
+ qx[5] = vaddq_s8(vshrq_n_u8(lbits.val[1], 4)| vandq_u8(vshrq_n_u8(hbits, 1), m5), m16); // 24..27
+ qx[6] = vaddq_s8(vshrq_n_u8(lbits.val[2], 4)| vandq_u8(vshrq_n_u8(hbits, 2), m5), m16); // 12..15
+ qx[7] = vaddq_s8(vshrq_n_u8(lbits.val[3], 4)| vandq_u8(vshrq_n_u8(hbits, 3), m5), m16); // 28..31
+ return scales;
+ }
+ inline float32x4_t result(float32x4_t acc) const {
+ return acc;
+ }
+
+ const char * cx;
+ const size_t bx;
+ const block_q5_0_r4 * iq5;
+ const uint8x16_t m4 = vdupq_n_u8(0x0f);
+ const uint8x16_t m5 = vdupq_n_u8(0x10);
+ const int8x16_t m16 = vdupq_n_s8(-16);
+};
+
+struct Q6_0_R4_Dequantizer {
+ Q6_0_R4_Dequantizer(const void * vx, size_t bx) : cx((const char *)vx), bx(bx) {}
+ inline void new_row(int ix) { iq6 = (const block_q6_0_r4 *)(cx + ix*bx); }
+ inline float32x4_t prepare(int ib, int8x16_t * qx) const {
+ auto scales = vcvt_f32_f16(vld1_f16((const float16_t *)iq6[ib].d));
+ auto lbits = vld1q_u8_x4(iq6[ib].qs);
+ auto hbits = vld1q_u8_x2(iq6[ib].qh);
+ qx[0] = vaddq_s8(vandq_u8(lbits.val[0], m4) | vandq_u8(vshlq_n_u8(hbits.val[0], 4), m6), m32); // 0...3
+ qx[1] = vaddq_s8(vandq_u8(lbits.val[1], m4) | vandq_u8(vshlq_n_u8(hbits.val[1], 4), m6), m32); // 16..19
+ qx[2] = vaddq_s8(vandq_u8(lbits.val[2], m4) | vandq_u8(vshlq_n_u8(hbits.val[0], 2), m6), m32); // 4...7
+ qx[3] = vaddq_s8(vandq_u8(lbits.val[3], m4) | vandq_u8(vshlq_n_u8(hbits.val[1], 2), m6), m32); // 20..23
+ qx[4] = vaddq_s8(vshrq_n_u8(lbits.val[0], 4)| vandq_u8(hbits.val[0], m6), m32); // 8..11
+ qx[5] = vaddq_s8(vshrq_n_u8(lbits.val[1], 4)| vandq_u8(hbits.val[1], m6), m32); // 24..27
+ qx[6] = vaddq_s8(vshrq_n_u8(lbits.val[2], 4)| vandq_u8(vshrq_n_u8(hbits.val[0], 2), m6), m32); // 12..15
+ qx[7] = vaddq_s8(vshrq_n_u8(lbits.val[3], 4)| vandq_u8(vshrq_n_u8(hbits.val[1], 2), m6), m32); // 28..31
+ return scales;
+ }
+ inline float32x4_t result(float32x4_t acc) const {
+ return acc;
+ }
+
+ const char * cx;
+ const size_t bx;
+ const block_q6_0_r4 * iq6;
+ const uint8x16_t m4 = vdupq_n_u8(0x0f);
+ const uint8x16_t m6 = vdupq_n_u8(0x30);
+ const int8x16_t m32 = vdupq_n_s8(-32);
+};
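+// Q5_0_R4 and Q6_0_R4 above rebuild each quant from its low nibble in qs plus the high
+// bit(s) shifted out of qh, then re-center it: a 5-bit value v in [0,31] becomes v - 16
+// (m16) and a 6-bit value v in [0,63] becomes v - 32 (m32), matching the scalar q5_0/q6_0
+// dequantization.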
+
+inline void qx_0_q8_0_dot(const int8x16_t * qx, const int8_t * qy, int32x4_t& sumi1, int32x4_t& sumi2) {
+ auto y = vld1q_s8_x2(qy);
+ sumi1 = sumi2 = vdupq_n_s32(0);
+ sumi1 = vdotq_laneq_s32(sumi1, qx[0], y.val[0], 0);
+ sumi2 = vdotq_laneq_s32(sumi2, qx[1], y.val[0], 0);
+ sumi1 = vdotq_laneq_s32(sumi1, qx[2], y.val[0], 1);
+ sumi2 = vdotq_laneq_s32(sumi2, qx[3], y.val[0], 1);
+ sumi1 = vdotq_laneq_s32(sumi1, qx[4], y.val[0], 2);
+ sumi2 = vdotq_laneq_s32(sumi2, qx[5], y.val[0], 2);
+ sumi1 = vdotq_laneq_s32(sumi1, qx[6], y.val[0], 3);
+ sumi2 = vdotq_laneq_s32(sumi2, qx[7], y.val[0], 3);
+ sumi1 = vdotq_laneq_s32(sumi1, qx[8+0], y.val[1], 0);
+ sumi2 = vdotq_laneq_s32(sumi2, qx[8+1], y.val[1], 0);
+ sumi1 = vdotq_laneq_s32(sumi1, qx[8+2], y.val[1], 1);
+ sumi2 = vdotq_laneq_s32(sumi2, qx[8+3], y.val[1], 1);
+ sumi1 = vdotq_laneq_s32(sumi1, qx[8+4], y.val[1], 2);
+ sumi2 = vdotq_laneq_s32(sumi2, qx[8+5], y.val[1], 2);
+ sumi1 = vdotq_laneq_s32(sumi1, qx[8+6], y.val[1], 3);
+ sumi2 = vdotq_laneq_s32(sumi2, qx[8+7], y.val[1], 3);
+}
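+// In the helper above, qx[2*j+0] appears to hold four consecutive quants for rows 0..3 of the
+// interleaved q8_0_r8 block and qx[2*j+1] the same quants for rows 4..7, while
+// vdotq_laneq_s32 broadcasts the matching four activation bytes from lane j of y. After all
+// 16 calls, sumi1 carries the 32-element dot products for rows 0..3 and sumi2 for rows 4..7,
+// which is why the caller stores the results at column offsets ix+0 and ix+4.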
+
+template <int nrc_y>
+void mul_mat_q8_0_r8_q8_0(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ GGML_ASSERT(nrc_x%8 == 0);
+ Q8<nrc_y, block_q8_0_x4> q8(info);
+ int nb = n / QK8_0;
+ float32x4_t acc[2*nrc_y] = {};
+ int8x16_t qx[16];
+ float d8[4*nrc_y];
+ for (int ix = 0; ix < nrc_x; ix += 8) {
+ const block_q8_0_r8 * iq8 = (const block_q8_0_r8 *)((const char *)vx + ix*bx);
+ for (int ib4 = 0; ib4 < nb/4; ++ib4) {
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ vst1q_f32(d8+4*iy, vcvt_f32_f16(vld1_f16((const float16_t *)q8.y[iy][ib4].d)));
+ }
+ for (int k = 0; k < 4; ++k) {
+ auto scales16 = vld1q_f16((const float16_t *)iq8[4*ib4+k].d);
+ auto scales1 = vcvt_f32_f16(vget_low_f16 (scales16));
+ auto scales2 = vcvt_f32_f16(vget_high_f16(scales16));
+ for (int j = 0; j < 16; ++j) qx[j] = vld1q_s8(iq8[4*ib4+k].qs + 16*j);
+ int32x4_t sumi1, sumi2;
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ qx_0_q8_0_dot(qx, q8.y[iy][ib4].qs+32*k, sumi1, sumi2);
+ auto dy = vdupq_n_f32(d8[4*iy+k]);
+ acc[2*iy+0] = vfmaq_f32(acc[2*iy+0], vmulq_f32(scales1, dy), vcvtq_f32_s32(sumi1));
+ acc[2*iy+1] = vfmaq_f32(acc[2*iy+1], vmulq_f32(scales2, dy), vcvtq_f32_s32(sumi2));
+ }
+ }
+ }
+ for (int ib = 4*(nb/4); ib < nb; ++ib) {
+ auto scales16 = vld1q_f16((const float16_t *)iq8[ib].d);
+ auto scales1 = vcvt_f32_f16(vget_low_f16 (scales16));
+ auto scales2 = vcvt_f32_f16(vget_high_f16(scales16));
+ for (int j = 0; j < 16; ++j) qx[j] = vld1q_s8(iq8[ib].qs + 16*j);
+ int32x4_t sumi1, sumi2;
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto qy = (const block_q8_0 *)q8.y[iy];
+ qx_0_q8_0_dot(qx, qy[ib].qs, sumi1, sumi2);
+ auto dy = vdupq_n_f32(GGML_FP16_TO_FP32(qy[ib].d));
+ acc[2*iy+0] = vfmaq_f32(acc[2*iy+0], vmulq_f32(scales1, dy), vcvtq_f32_s32(sumi1));
+ acc[2*iy+1] = vfmaq_f32(acc[2*iy+1], vmulq_f32(scales2, dy), vcvtq_f32_s32(sumi2));
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ info.store(ix+0, iy, acc[2*iy+0]);
+ info.store(ix+4, iy, acc[2*iy+1]);
+ acc[2*iy] = acc[2*iy+1] = vdupq_n_f32(0.f);
+ }
+ }
+}
+
+}
+
+bool iqk_set_kernels_legacy_quants(int ne00, int typeA, int typeB, std::array<mul_mat_t, IQK_MAX_NY>& kernels, mul_mat_t& func16) {
+
+ if (ne00%QK8_0 != 0) return false;
+
+ auto etypeA = ggml_type(typeA);
+ auto expected_typeB = etypeA == GGML_TYPE_Q4_1 || etypeA == GGML_TYPE_Q5_1 ? GGML_TYPE_Q8_1_X4 : GGML_TYPE_Q8_0_X4;
+ if (ggml_type(typeB) != expected_typeB) return false;
+
+ func16 = nullptr;
+
+ switch (typeA) {
+ case GGML_TYPE_Q4_0:
+ IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_0_q8_0, DequantizerQ40, kernels);
+ break;
+ case GGML_TYPE_Q4_1:
+ IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_1_q8_1, DequantizerQ41, kernels);
+ break;
+ case GGML_TYPE_Q5_0:
+ IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_0_q8_0, DequantizerQ50, kernels);
+ break;
+ case GGML_TYPE_Q5_1:
+ IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_1_q8_1, DequantizerQ51, kernels);
+ break;
+ case GGML_TYPE_Q6_0:
+ IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_0_q8_0, DequantizerQ60, kernels);
+ break;
+ case GGML_TYPE_Q8_0:
+ IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_0_q8_0, DequantizerQ80, kernels);
+ break;
+ case GGML_TYPE_IQ4_NL:
+ IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_0_q8_0, DequantizerIQ4NL, kernels);
+ break;
+ case GGML_TYPE_Q4_0_R8:
+ IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qx_r8_q8_0, Q4_0_R8_Dequantizer, kernels);
+ break;
+ case GGML_TYPE_Q5_0_R4:
+ IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qx_r4_q8_0, Q5_0_R4_Dequantizer, kernels);
+ break;
+ case GGML_TYPE_Q6_0_R4:
+ IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qx_r4_q8_0, Q6_0_R4_Dequantizer, kernels);
+ break;
+ case GGML_TYPE_Q8_0_R8:
+ IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_q8_0_r8_q8_0, kernels);
+ break;
+ case GGML_TYPE_IQ4_NL_R4:
+ IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qx_r4_q8_0, IQ4_NL_R4_Dequantizer, kernels);
+ break;
+ default:
+ return false;
+ }
+
+ return true;
+}
+
+#endif
+
+namespace {
+template <int k_step>
+inline std::pair<mul_mat_t, int> mul_mat_kernel(int int_typeA, int nq) {
+ auto typeA = ggml_type(int_typeA);
+ constexpr int kMaxQ = 8;
+#define MAKE_FUNCS(mul_mat, n) \
+ if (n >= kMaxQ) return std::make_pair(mul_mat, kMaxQ>, kMaxQ);\
+ else {\
+ switch (n) {\
+ case 1: return std::make_pair(mul_mat, 1>, 1);\
+ case 2: return std::make_pair(mul_mat, 2>, 2);\
+ case 3: return std::make_pair(mul_mat, 3>, 3);\
+ case 4: return std::make_pair(mul_mat, 4>, 4);\
+ case 5: return std::make_pair(mul_mat, 5>, 5);\
+ case 6: return std::make_pair(mul_mat, 6>, 6);\
+ case 7: return std::make_pair(mul_mat, 7>, 7);\
+ }\
+ }
+#define MAKE_FUNCS_ONLY_NRC(mul_mat, n) \
+ if (n >= kMaxQ) return std::make_pair(mul_mat<kMaxQ>, kMaxQ);\
+ else {\
+ switch (n) {\
+ case 1: return std::make_pair(mul_mat<1>, 1);\
+ case 2: return std::make_pair(mul_mat<2>, 2);\
+ case 3: return std::make_pair(mul_mat<3>, 3);\
+ case 4: return std::make_pair(mul_mat<4>, 4);\
+ case 5: return std::make_pair(mul_mat<5>, 5);\
+ case 6: return std::make_pair(mul_mat<6>, 6);\
+ case 7: return std::make_pair(mul_mat<7>, 7);\
+ }\
+ }
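+    // Note that MAKE_FUNCS is invoked with an unterminated template argument list, e.g.
+    // MAKE_FUNCS(mul_mat_qX_0_q8_0<DequantizerQ80, nq), so "mul_mat, kMaxQ>" inside the
+    // macro expands to the complete instantiation mul_mat_qX_0_q8_0<DequantizerQ80, kMaxQ>.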
+ if (typeA == GGML_TYPE_Q8_0) {
+#ifdef __aarch64__
+ MAKE_FUNCS(mul_mat_qX_0_q8_0<DequantizerQ80, nq);
+#else
+#ifdef HAVE_FANCY_SIMD
+ if (nq == 1) return std::make_pair(mul_mat_qX_0_q8_2_Tx<Q8_0_1_Unpacker, 1, k_step>, 1);
+ if (nq == 2) return std::make_pair(mul_mat_qX_0_q8_2_Tx<Q8_0_1_Unpacker, 2, k_step>, 2);
+ if (nq == 4) return std::make_pair(mul_mat_qX_0_q8_2_Tx<Q8_0_1_Unpacker, 4, k_step>, 4);
+ MAKE_FUNCS(mul_mat_qX_1_q8_2_T<Q8_0_1_Unpacker, nq);
+#else
+ if (nq == 1) return std::make_pair(mul_mat_qX_0_q8_0_Tx<Q8_0_Unpacker, 1, k_step>, 1);
+ if (nq == 2) return std::make_pair(mul_mat_qX_0_q8_0_Tx<Q8_0_Unpacker, 2, k_step>, 2);
+ if (nq == 4) return std::make_pair(mul_mat_qX_0_q8_0_Tx<Q8_0_Unpacker, 4, k_step>, 4);
+ MAKE_FUNCS(mul_mat_qX_0_q8_0_T<Q8_0_Unpacker, nq);
+#endif
+#endif
+ }
+ else if (typeA == GGML_TYPE_Q8_0_R8) {
+#ifdef __aarch64__
+ MAKE_FUNCS_ONLY_NRC(mul_mat_q8_0_r8_q8_0, nq);
+#else
+ MAKE_FUNCS_ONLY_NRC(mul_mat_q8_0_r8_q8_2, nq);
+#endif
+ }
+ else if (typeA == GGML_TYPE_Q6_0) {
+#ifdef __aarch64__
+ MAKE_FUNCS(mul_mat_qX_0_q8_0<DequantizerQ60, nq);
+#else
+ if (nq == 1) return std::make_pair(mul_mat_qX_0_q8_2_Tx<Q6_0_1_Unpacker, 1, k_step>, 1);
+ if (nq == 2) return std::make_pair(mul_mat_qX_0_q8_2_Tx<Q6_0_1_Unpacker, 2, k_step>, 2);
+ if (nq == 4) return std::make_pair(mul_mat_qX_0_q8_2_Tx<Q6_0_1_Unpacker, 4, k_step>, 4);
+ MAKE_FUNCS(mul_mat_qX_1_q8_2_T<Q6_0_1_Unpacker, nq);
+#endif
+ }
+ else if (typeA == GGML_TYPE_Q4_0) {
+#ifdef __aarch64__
+ MAKE_FUNCS(mul_mat_qX_0_q8_0<DequantizerQ40, nq);
+#else
+ if (nq == 1) return std::make_pair(mul_mat_qX_0_q8_2_Tx<Q4_0_1_Unpacker, 1, k_step>, 1);
+ if (nq == 2) return std::make_pair(mul_mat_qX_0_q8_2_Tx<Q4_0_1_Unpacker, 2, k_step>, 2);
+ if (nq == 4) return std::make_pair(mul_mat_qX_0_q8_2_Tx<Q4_0_1_Unpacker, 4, k_step>, 4);
+ MAKE_FUNCS(mul_mat_qX_1_q8_2_T<Q4_0_1_Unpacker, nq);
+#endif
+ }
+#if GGML_IQK_FA_ALL_QUANTS
+ else if (typeA == GGML_TYPE_Q4_1) {
+#ifdef __aarch64__
+ MAKE_FUNCS(mul_mat_qX_1_q8_1<DequantizerQ41, nq);
+#else
+ MAKE_FUNCS(mul_mat_qX_1_q8_2_T<Q4_1_Unpacker, nq);
+#endif
+ }
+ else if (typeA == GGML_TYPE_IQ4_NL) {
+#ifdef __aarch64__
+ MAKE_FUNCS(mul_mat_qX_0_q8_0<DequantizerIQ4NL, nq);
+#else
+#ifdef HAVE_FANCY_SIMD
+ MAKE_FUNCS(mul_mat_qX_1_q8_2_T<IQ4_NL_Unpacker, nq);
+#else
+ MAKE_FUNCS(mul_mat_qX_0_q8_0_T<IQ4_NL_Unpacker, nq);
+#endif
+#endif
+ }
+#endif
+ else {
+ GGML_ASSERT(false);
+ }
+ return std::make_pair<mul_mat_t, int>(nullptr, 0);
+}
+
+inline std::pair<mul_mat_t, int> mul_mat_kernel(int int_typeA, int nq, int k_step) {
+ switch (k_step) {
+ case 32: return mul_mat_kernel< 32>(int_typeA, nq);
+ case 64: return mul_mat_kernel< 64>(int_typeA, nq);
+ case 128: return mul_mat_kernel<128>(int_typeA, nq);
+ default: GGML_ABORT("Fatal error");
+ }
+}
+}
+
+void iqk_gemm_legacy_fa(int D, int nq, int type_k, const char * k, size_t stride_k, DataInfo& info, int k_step) {
+ auto [mul_mat, nrc_q] = mul_mat_kernel(type_k, nq, k_step);
+ for (int iq = 0; iq < nq/nrc_q; ++iq) {
+ mul_mat(D, k, stride_k, info, k_step);
+ info.cur_y += nrc_q;
+ }
+ int iq = nrc_q*(nq/nrc_q);
+ if (iq < nq) {
+ auto [mul_mat1, nrc_q1] = mul_mat_kernel(type_k, nq - iq, k_step);
+ GGML_ASSERT(nrc_q1 == nq - iq);
+ mul_mat1(D, k, stride_k, info, k_step);
+ }
+}
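+// The FA helper above walks the query rows in chunks: mul_mat_kernel() caps a chunk at
+// kMaxQ = 8 rows, so e.g. nq = 11 results in one call covering 8 rows (advancing info.cur_y
+// by 8) followed by a second kernel sized exactly for the remaining 3 rows.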
+
+#endif
diff --git a/ggml/src/iqk/iqk_gemm_legacy_quants.h b/ggml/src/iqk/iqk_gemm_legacy_quants.h
new file mode 100644
index 00000000..a472d9bb
--- /dev/null
+++ b/ggml/src/iqk/iqk_gemm_legacy_quants.h
@@ -0,0 +1,14 @@
+#pragma once
+
+#include "iqk_common.h"
+
+#ifdef IQK_IMPLEMENT
+
+#include <array>
+#include <utility>
+
+bool iqk_set_kernels_legacy_quants(int ne00, int typeA, int typeB, std::array<mul_mat_t, IQK_MAX_NY>& kernels, mul_mat_t& func16);
+
+void iqk_gemm_legacy_fa(int D, int nq, int type_k, const char * k, size_t stride_k, DataInfo& info, int k_step);
+
+#endif
diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp
index 311554f4..abf14ed0 100644
--- a/ggml/src/iqk/iqk_mul_mat.cpp
+++ b/ggml/src/iqk/iqk_mul_mat.cpp
@@ -20,6 +20,13 @@
#include "iqk_mul_mat.h"
#include "iqk_quantize.h"
#include "iqk_flash_impl.h"
+#include "iqk_gemm_floats.h"
+#include "iqk_gemm_kquants.h"
+#include "iqk_gemm_iquants.h"
+#include "iqk_gemm_iqk_quants.h"
+#include "iqk_gemm_1bit.h"
+#include "iqk_gemm_legacy_quants.h"
+#include "iqk_utils.h"
#define GGML_COMMON_IMPL_C
#include "ggml-common.h"
@@ -43,116 +50,10 @@
 // For fp16/fp32 matrix multiplications, tiling is used to improve
// performance.
-#define FA_TIMING 0
-
-#include <utility>
-#include <array>
-#if FA_TIMING
-#include <chrono>
-#include <mutex>
-struct Perf {
- using TimePoint = std::chrono::time_point<std::chrono::high_resolution_clock>;
- std::array<double, 5> times = {};
- std::mutex mutex;
- bool report;
- static auto cur_time() { return std::chrono::high_resolution_clock::now(); }
- inline void accum(int what, const TimePoint& t1) {
- auto t2 = cur_time();
- auto dt = delta(t1, t2);
- std::lock_guard<std::mutex> lock(mutex);
- times[what] += dt;
- }
- inline void accum_nolock(int what, const TimePoint& t1) {
- auto t2 = cur_time();
- auto dt = delta(t1, t2);
- times[what] += dt;
- }
- inline void add(const Perf& other) {
- std::lock_guard<std::mutex> lock(mutex);
- for (int i = 0; i < int(times.size()); ++i) times[i] += other.times[i];
- }
- Perf(bool r) : report(r) {}
- ~Perf() {
- if (report) {
- double tot = 0;
- for (auto& t : times) tot += t;
- if (!tot) return;
- printf("======================= Timing: %g ms in total\n", tot);
- for (int i = 0; i < int(times.size()); ++i) {
- if (times[i]) {
- printf("%d: %g ms -> %g%c\n", i, times[i], 100*times[i]/tot, '%');
- }
- }
- }
- }
- static Perf& instance() {
- static Perf p(true);
- return p;
- }
- static double delta(const TimePoint& t1, const TimePoint& t2) {
- return 1e-6*std::chrono::duration_cast<std::chrono::nanoseconds>(t2-t1).count();
- }
-};
-#endif
-
namespace {
-typedef struct {
- int32_t i1;
- int32_t i2;
-} mmid_row_mapping;
-
-struct DataInfo {
- float * s;
- const char * cy;
- size_t bs;
- size_t by;
- int cur_y = 0;
- int ne11;
- const mmid_row_mapping * row_mapping = nullptr;
- size_t bs2 = 0;
-
- inline const char * src1_row(int iy) const {
- if (!row_mapping) return cy + (cur_y + iy)*by;
- int i11 = row_mapping[cur_y + iy].i1 % ne11;
- int i12 = row_mapping[cur_y + iy].i2;
- return cy + (i11 + i12*ne11)*by;
- }
-
- inline void store(int ix, int iy, float result) const {
- *(dst_row(iy) + ix) = result;
- }
-#ifdef __AVX__
- inline void store(int ix, int iy, __m128 result) const {
- _mm_storeu_ps(dst_row(iy) + ix, result);
- }
- inline void store(int ix, int iy, __m256 result) const {
- _mm256_storeu_ps(dst_row(iy) + ix, result);
- }
-#endif
-#ifdef __AVX512F__
- inline void store(int ix, int iy, __m512 result) const {
- _mm512_storeu_ps(dst_row(iy) + ix, result);
- }
-#endif
-#ifdef __ARM_NEON
- inline void store(int ix, int iy, float32x4_t result) const {
- vst1q_f32(dst_row(iy) + ix, result);
- }
-#endif
- inline float * dst_row(int iy) const {
- if (!row_mapping) return s + (cur_y + iy)*bs;
- int i12 = row_mapping[cur_y + iy].i2;
- int i1 = row_mapping[cur_y + iy].i1;
- int i2 = i12;
- return s + i1*bs + i2*bs2;
- }
-};
-
-typedef void (*mul_mat_t)(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x);
-
struct MulMat {
- std::array<mul_mat_t, 8> funcs = {};
+ std::array<mul_mat_t, IQK_MAX_NY> funcs = {};
mul_mat_t func16 = nullptr;
inline void mul_mat_NxM(int n, const void * vx, size_t bx, DataInfo& info, int nrc_x, int nrc_y) {
#ifdef __aarch64__
@@ -400,8 +301,6 @@ struct MulMat {
}
#endif
}
-private:
- template <typename Dequantizer> static void set_functions(MulMat& m);
};
}
@@ -457,7 +356,7 @@ extern "C" IQK_API bool iqk_mul_mat_4d(long Nx, long Ny, long ne00,
if (Nx >= 256 && Nx%32 == 0) {
int nx32 = Nx/32;
int nchunk = nx32*ne02;
- if (r2 <= 8) {
+ if (r2 <= IQK_MAX_NY) {
MulMat mm;
if (!MulMat::prepare(typeA, typeB, ne00, mm, r2)) return false;
int ny = mm.funcs.size();
@@ -585,9912 +484,89 @@ extern "C" IQK_API bool iqk_moe_fused_up_gate(long Nx, long Ny, long ne00, int n
return true;
}
-
-namespace {
-
-inline void make_q4_scales(const uint8_t * scales8, uint32_t * aux32) {
- const uint16_t * scales = (const uint16_t *)scales8;
- const uint32_t a0 = scales[0] | (scales[1] << 16);
- const uint32_t a1 = scales[2] | (scales[3] << 16);
- const uint32_t a2 = scales[4] | (scales[5] << 16);
- aux32[3] = ((a2 >> 4) & 0x0f0f0f0f) | ((a1 >> 2) & 0x30303030);
- aux32[1] = ((a2 >> 0) & 0x0f0f0f0f) | ((a0 >> 2) & 0x30303030);
- aux32[2] = a1 & 0x3f3f3f3f;
- aux32[0] = a0 & 0x3f3f3f3f;
-}
-
-#ifdef __AVX2__
-static const uint64_t iq1s_grid_us[2048] = {
- 0x0000000000000000, 0x0000000000000002, 0x0000000000000101, 0x0000000000000200,
- 0x0000000000000202, 0x0000000000010001, 0x0000000000010101, 0x0000000000020000,
- 0x0000000000020002, 0x0000000000020200, 0x0000000000020202, 0x0000000001000101,
- 0x0000000001010001, 0x0000000001010100, 0x0000000001010102, 0x0000000001020101,
- 0x0000000002000000, 0x0000000002000002, 0x0000000002000200, 0x0000000002000202,
- 0x0000000002010101, 0x0000000002020000, 0x0000000002020002, 0x0000000002020200,
- 0x0000000002020202, 0x0000000100000100, 0x0000000100000101, 0x0000000100010001,
- 0x0000000100010100, 0x0000000100010102, 0x0000000100010201, 0x0000000100010202,
- 0x0000000100020101, 0x0000000101000001, 0x0000000101000102, 0x0000000101000201,
- 0x0000000101010002, 0x0000000101010101, 0x0000000101010202, 0x0000000101020001,
- 0x0000000101020100, 0x0000000101020102, 0x0000000101020200, 0x0000000102000101,
- 0x0000000102010001, 0x0000000102010100, 0x0000000102010102, 0x0000000102020101,
- 0x0000000200000000, 0x0000000200000002, 0x0000000200000200, 0x0000000200000202,
- 0x0000000200010101, 0x0000000200020000, 0x0000000200020002, 0x0000000200020200,
- 0x0000000200020202, 0x0000000201000101, 0x0000000201010001, 0x0000000201010201,
- 0x0000000201020100, 0x0000000201020201, 0x0000000202000000, 0x0000000202000002,
- 0x0000000202000200, 0x0000000202000202, 0x0000000202010001, 0x0000000202010101,
- 0x0000000202010201, 0x0000000202020000, 0x0000000202020002, 0x0000000202020200,
- 0x0000000202020202, 0x0000010000010001, 0x0000010000010100, 0x0000010000010102,
- 0x0000010000020101, 0x0000010001000001, 0x0000010001000201, 0x0000010001010101,
- 0x0000010001010202, 0x0000010001020100, 0x0000010001020101, 0x0000010002010001,
- 0x0000010002010201, 0x0000010002020101, 0x0000010100000001, 0x0000010100000100,
- 0x0000010100000101, 0x0000010100000102, 0x0000010100010101, 0x0000010100010200,
- 0x0000010100010202, 0x0000010100020201, 0x0000010101000000, 0x0000010101000101,
- 0x0000010101000202, 0x0000010101010000, 0x0000010101010001, 0x0000010101010100,
- 0x0000010101010101, 0x0000010101010102, 0x0000010101010201, 0x0000010101020000,
- 0x0000010101020002, 0x0000010101020101, 0x0000010101020200, 0x0000010101020202,
- 0x0000010102000001, 0x0000010102010001, 0x0000010102010101, 0x0000010102010200,
- 0x0000010102010202, 0x0000010102020001, 0x0000010102020100, 0x0000010102020101,
- 0x0000010102020102, 0x0000010102020201, 0x0000010200010100, 0x0000010200010201,
- 0x0000010201000001, 0x0000010201000100, 0x0000010201010000, 0x0000010201010002,
- 0x0000010201010101, 0x0000010201010200, 0x0000010201020000, 0x0000010201020001,
- 0x0000010201020102, 0x0000010201020201, 0x0000010202000101, 0x0000010202010001,
- 0x0000010202010100, 0x0000010202010201, 0x0000020000000000, 0x0000020000000002,
- 0x0000020000000200, 0x0000020000000202, 0x0000020000010101, 0x0000020000020000,
- 0x0000020000020002, 0x0000020000020200, 0x0000020000020202, 0x0000020001000101,
- 0x0000020001010001, 0x0000020001010102, 0x0000020001020101, 0x0000020002000000,
- 0x0000020002000002, 0x0000020002000200, 0x0000020002000202, 0x0000020002010101,
- 0x0000020002020000, 0x0000020002020002, 0x0000020002020200, 0x0000020002020202,
- 0x0000020100000101, 0x0000020100010001, 0x0000020100010100, 0x0000020100010201,
- 0x0000020100020100, 0x0000020100020101, 0x0000020101000001, 0x0000020101010000,
- 0x0000020101010001, 0x0000020101010101, 0x0000020101020001, 0x0000020101020100,
- 0x0000020101020201, 0x0000020102010001, 0x0000020102010100, 0x0000020102010102,
- 0x0000020102010201, 0x0000020102020101, 0x0000020200000000, 0x0000020200000002,
- 0x0000020200000200, 0x0000020200000202, 0x0000020200010101, 0x0000020200020000,
- 0x0000020200020002, 0x0000020200020200, 0x0000020200020202, 0x0000020201000101,
- 0x0000020201010001, 0x0000020201010201, 0x0000020201020001, 0x0000020201020101,
- 0x0000020202000000, 0x0000020202000002, 0x0000020202000101, 0x0000020202000200,
- 0x0000020202000202, 0x0000020202010101, 0x0000020202020000, 0x0000020202020002,
- 0x0000020202020200, 0x0000020202020202, 0x0001000000010000, 0x0001000000010001,
- 0x0001000000010100, 0x0001000000010201, 0x0001000000020100, 0x0001000000020101,
- 0x0001000001000001, 0x0001000001000100, 0x0001000001010000, 0x0001000001010101,
- 0x0001000001010200, 0x0001000001020001, 0x0001000001020100, 0x0001000001020101,
- 0x0001000001020201, 0x0001000002010001, 0x0001000002010100, 0x0001000002010102,
- 0x0001000002020001, 0x0001000002020101, 0x0001000100000001, 0x0001000100000100,
- 0x0001000100000102, 0x0001000100000201, 0x0001000100010000, 0x0001000100010002,
- 0x0001000100010101, 0x0001000100010200, 0x0001000100020001, 0x0001000100020100,
- 0x0001000100020201, 0x0001000101000101, 0x0001000101000202, 0x0001000101010000,
- 0x0001000101010001, 0x0001000101010002, 0x0001000101010100, 0x0001000101010101,
- 0x0001000101010102, 0x0001000101010201, 0x0001000101020000, 0x0001000101020101,
- 0x0001000102000100, 0x0001000102010002, 0x0001000102010101, 0x0001000102020001,
- 0x0001000102020100, 0x0001000200010001, 0x0001000200010100, 0x0001000200010102,
- 0x0001000200020101, 0x0001000201000000, 0x0001000201000102, 0x0001000201000201,
- 0x0001000201010002, 0x0001000201010101, 0x0001000201010200, 0x0001000201010202,
- 0x0001000201020100, 0x0001000201020102, 0x0001000202000101, 0x0001000202010001,
- 0x0001000202010100, 0x0001000202010102, 0x0001000202020101, 0x0001010000000001,
- 0x0001010000000102, 0x0001010000000201, 0x0001010000010100, 0x0001010000010101,
- 0x0001010000010200, 0x0001010000010201, 0x0001010000020001, 0x0001010000020102,
- 0x0001010001000001, 0x0001010001000101, 0x0001010001000102, 0x0001010001000200,
- 0x0001010001000202, 0x0001010001010001, 0x0001010001010100, 0x0001010001010101,
- 0x0001010001010102, 0x0001010001010201, 0x0001010001020002, 0x0001010001020101,
- 0x0001010001020200, 0x0001010002000100, 0x0001010002000201, 0x0001010002010000,
- 0x0001010002010100, 0x0001010002010101, 0x0001010002010200, 0x0001010002010201,
- 0x0001010002010202, 0x0001010002020001, 0x0001010002020100, 0x0001010002020101,
- 0x0001010002020201, 0x0001010100000002, 0x0001010100000101, 0x0001010100000202,
- 0x0001010100010001, 0x0001010100010100, 0x0001010100010101, 0x0001010100010102,
- 0x0001010100010201, 0x0001010100020000, 0x0001010100020002, 0x0001010100020101,
- 0x0001010100020200, 0x0001010100020202, 0x0001010101000001, 0x0001010101000100,
- 0x0001010101000101, 0x0001010101000102, 0x0001010101010001, 0x0001010101010002,
- 0x0001010101010100, 0x0001010101010101, 0x0001010101010102, 0x0001010101010201,
- 0x0001010101010202, 0x0001010101020001, 0x0001010101020100, 0x0001010101020101,
- 0x0001010101020102, 0x0001010101020201, 0x0001010102000000, 0x0001010102000002,
- 0x0001010102000100, 0x0001010102000101, 0x0001010102000200, 0x0001010102000202,
- 0x0001010102010000, 0x0001010102010001, 0x0001010102010100, 0x0001010102010101,
- 0x0001010102010102, 0x0001010102010201, 0x0001010102010202, 0x0001010102020000,
- 0x0001010102020002, 0x0001010102020101, 0x0001010200000001, 0x0001010200000100,
- 0x0001010200000101, 0x0001010200000102, 0x0001010200010101, 0x0001010200010102,
- 0x0001010200010200, 0x0001010200010202, 0x0001010200020001, 0x0001010200020102,
- 0x0001010201000000, 0x0001010201000002, 0x0001010201000100, 0x0001010201000101,
- 0x0001010201000200, 0x0001010201000202, 0x0001010201010001, 0x0001010201010101,
- 0x0001010201010102, 0x0001010201010200, 0x0001010201010201, 0x0001010201020001,
- 0x0001010201020100, 0x0001010201020101, 0x0001010201020200, 0x0001010201020201,
- 0x0001010201020202, 0x0001010202000102, 0x0001010202000202, 0x0001010202010002,
- 0x0001010202010101, 0x0001010202020100, 0x0001010202020201, 0x0001020000010001,
- 0x0001020000010102, 0x0001020000020101, 0x0001020001000001, 0x0001020001000100,
- 0x0001020001000102, 0x0001020001000201, 0x0001020001010000, 0x0001020001010101,
- 0x0001020001010200, 0x0001020001010202, 0x0001020001020000, 0x0001020001020001,
- 0x0001020001020100, 0x0001020001020102, 0x0001020001020201, 0x0001020002000101,
- 0x0001020002010001, 0x0001020002010100, 0x0001020002020101, 0x0001020100010000,
- 0x0001020100010002, 0x0001020100010101, 0x0001020100010202, 0x0001020100020001,
- 0x0001020100020101, 0x0001020101000002, 0x0001020101000100, 0x0001020101000101,
- 0x0001020101000200, 0x0001020101010001, 0x0001020101010100, 0x0001020101010101,
- 0x0001020101010102, 0x0001020101010201, 0x0001020101010202, 0x0001020101020000,
- 0x0001020101020101, 0x0001020101020202, 0x0001020102000201, 0x0001020102010001,
- 0x0001020102010002, 0x0001020102010101, 0x0001020102010200, 0x0001020102020001,
- 0x0001020102020102, 0x0001020102020201, 0x0001020200000201, 0x0001020200010102,
- 0x0001020200020100, 0x0001020200020102, 0x0001020201000100, 0x0001020201000102,
- 0x0001020201000201, 0x0001020201010000, 0x0001020201010002, 0x0001020201010101,
- 0x0001020201010200, 0x0001020201020001, 0x0001020201020102, 0x0001020201020201,
- 0x0001020202000101, 0x0001020202010001, 0x0001020202010102, 0x0001020202010202,
- 0x0002000000000000, 0x0002000000000002, 0x0002000000000200, 0x0002000000000202,
- 0x0002000000010101, 0x0002000000020000, 0x0002000000020002, 0x0002000000020101,
- 0x0002000000020200, 0x0002000000020202, 0x0002000001000101, 0x0002000001010001,
- 0x0002000001010201, 0x0002000001020001, 0x0002000001020101, 0x0002000002000000,
- 0x0002000002000002, 0x0002000002000200, 0x0002000002000202, 0x0002000002010101,
- 0x0002000002020000, 0x0002000002020002, 0x0002000002020101, 0x0002000002020200,
- 0x0002000002020202, 0x0002000100000101, 0x0002000100010001, 0x0002000100010100,
- 0x0002000100010201, 0x0002000100020101, 0x0002000101000002, 0x0002000101000100,
- 0x0002000101000201, 0x0002000101010101, 0x0002000101010200, 0x0002000101010202,
- 0x0002000101020001, 0x0002000101020100, 0x0002000101020101, 0x0002000101020102,
- 0x0002000102000101, 0x0002000102010000, 0x0002000102010102, 0x0002000102010201,
- 0x0002000102020101, 0x0002000200000001, 0x0002000200000200, 0x0002000200000202,
- 0x0002000200010001, 0x0002000200010101, 0x0002000200020000, 0x0002000200020002,
- 0x0002000200020200, 0x0002000200020202, 0x0002000201000101, 0x0002000201010001,
- 0x0002000201010102, 0x0002000201010201, 0x0002000201020101, 0x0002000202000001,
- 0x0002000202000200, 0x0002000202000202, 0x0002000202010001, 0x0002000202010101,
- 0x0002000202020000, 0x0002000202020002, 0x0002000202020200, 0x0002000202020202,
- 0x0002010000000101, 0x0002010000010100, 0x0002010000010102, 0x0002010000010201,
- 0x0002010000020101, 0x0002010001000100, 0x0002010001000101, 0x0002010001000102,
- 0x0002010001000201, 0x0002010001010002, 0x0002010001010101, 0x0002010001010200,
- 0x0002010001010202, 0x0002010001020102, 0x0002010002000101, 0x0002010002010001,
- 0x0002010002010100, 0x0002010002010201, 0x0002010002020001, 0x0002010002020101,
- 0x0002010100000201, 0x0002010100010101, 0x0002010100020001, 0x0002010100020201,
- 0x0002010101000000, 0x0002010101000101, 0x0002010101000200, 0x0002010101010001,
- 0x0002010101010100, 0x0002010101010101, 0x0002010101010201, 0x0002010101020002,
- 0x0002010101020101, 0x0002010101020200, 0x0002010102000201, 0x0002010102010000,
- 0x0002010102010100, 0x0002010102010101, 0x0002010102010200, 0x0002010102010202,
- 0x0002010102020001, 0x0002010102020100, 0x0002010102020102, 0x0002010102020201,
- 0x0002010200000101, 0x0002010200010000, 0x0002010200010002, 0x0002010200010201,
- 0x0002010200020101, 0x0002010201000001, 0x0002010201000201, 0x0002010201010101,
- 0x0002010201020000, 0x0002010201020001, 0x0002010201020201, 0x0002010202000100,
- 0x0002010202000102, 0x0002010202010000, 0x0002010202010202, 0x0002020000000000,
- 0x0002020000000002, 0x0002020000000200, 0x0002020000000202, 0x0002020000010101,
- 0x0002020000020000, 0x0002020000020002, 0x0002020000020200, 0x0002020000020202,
- 0x0002020001000101, 0x0002020001010001, 0x0002020001010100, 0x0002020001020101,
- 0x0002020002000000, 0x0002020002000002, 0x0002020002000200, 0x0002020002000202,
- 0x0002020002020000, 0x0002020002020002, 0x0002020002020200, 0x0002020002020202,
- 0x0002020100000201, 0x0002020100010001, 0x0002020100010100, 0x0002020100010201,
- 0x0002020100020101, 0x0002020101000102, 0x0002020101000201, 0x0002020101010002,
- 0x0002020101010101, 0x0002020101020001, 0x0002020101020100, 0x0002020101020102,
- 0x0002020101020201, 0x0002020102000101, 0x0002020102010000, 0x0002020102010102,
- 0x0002020102010201, 0x0002020102020100, 0x0002020102020101, 0x0002020200000000,
- 0x0002020200000002, 0x0002020200000200, 0x0002020200000202, 0x0002020200020000,
- 0x0002020200020002, 0x0002020200020200, 0x0002020200020202, 0x0002020201000101,
- 0x0002020201010001, 0x0002020201010102, 0x0002020201010201, 0x0002020201020101,
- 0x0002020202000000, 0x0002020202000002, 0x0002020202000200, 0x0002020202000202,
- 0x0002020202010101, 0x0002020202020000, 0x0002020202020002, 0x0002020202020200,
- 0x0002020202020202, 0x0100000000000101, 0x0100000000010001, 0x0100000000010102,
- 0x0100000000020101, 0x0100000001000201, 0x0100000001010002, 0x0100000001010101,
- 0x0100000001010200, 0x0100000001010202, 0x0100000001020001, 0x0100000001020100,
- 0x0100000001020102, 0x0100000002010100, 0x0100000002010201, 0x0100000002020001,
- 0x0100000002020102, 0x0100000100000000, 0x0100000100000001, 0x0100000100000100,
- 0x0100000100000102, 0x0100000100000201, 0x0100000100010002, 0x0100000100010101,
- 0x0100000100010102, 0x0100000100010200, 0x0100000100010202, 0x0100000100020001,
- 0x0100000100020102, 0x0100000100020201, 0x0100000101000101, 0x0100000101000200,
- 0x0100000101000202, 0x0100000101010001, 0x0100000101010100, 0x0100000101010101,
- 0x0100000101010102, 0x0100000101010201, 0x0100000101010202, 0x0100000101020101,
- 0x0100000101020200, 0x0100000101020202, 0x0100000102000001, 0x0100000102000100,
- 0x0100000102000102, 0x0100000102010000, 0x0100000102010002, 0x0100000102010101,
- 0x0100000102020000, 0x0100000102020001, 0x0100000102020002, 0x0100000200000101,
- 0x0100000200010001, 0x0100000200010100, 0x0100000200010102, 0x0100000200020101,
- 0x0100000201000001, 0x0100000201010002, 0x0100000201010101, 0x0100000201010202,
- 0x0100000201020100, 0x0100000201020201, 0x0100000202000201, 0x0100000202010100,
- 0x0100000202020101, 0x0100010000000001, 0x0100010000010101, 0x0100010000010201,
- 0x0100010000020201, 0x0100010001000101, 0x0100010001000200, 0x0100010001000202,
- 0x0100010001010001, 0x0100010001010100, 0x0100010001010101, 0x0100010001010102,
- 0x0100010001020001, 0x0100010001020002, 0x0100010001020101, 0x0100010001020200,
- 0x0100010001020202, 0x0100010002000001, 0x0100010002000102, 0x0100010002000201,
- 0x0100010002010000, 0x0100010002010002, 0x0100010002010101, 0x0100010002020000,
- 0x0100010002020001, 0x0100010002020201, 0x0100010100000001, 0x0100010100000002,
- 0x0100010100000101, 0x0100010100000202, 0x0100010100010001, 0x0100010100010100,
- 0x0100010100010101, 0x0100010100010102, 0x0100010100010201, 0x0100010100020000,
- 0x0100010100020101, 0x0100010100020202, 0x0100010101000001, 0x0100010101000100,
- 0x0100010101000101, 0x0100010101000102, 0x0100010101000201, 0x0100010101010000,
- 0x0100010101010001, 0x0100010101010100, 0x0100010101010101, 0x0100010101010102,
- 0x0100010101010200, 0x0100010101010201, 0x0100010101020001, 0x0100010101020100,
- 0x0100010101020101, 0x0100010101020102, 0x0100010101020201, 0x0100010102000002,
- 0x0100010102000100, 0x0100010102000101, 0x0100010102000200, 0x0100010102010001,
- 0x0100010102010100, 0x0100010102010101, 0x0100010102010102, 0x0100010102010201,
- 0x0100010102010202, 0x0100010102020101, 0x0100010102020200, 0x0100010102020202,
- 0x0100010200000001, 0x0100010200000101, 0x0100010200000201, 0x0100010200010100,
- 0x0100010200010101, 0x0100010200010200, 0x0100010200010202, 0x0100010200020001,
- 0x0100010200020100, 0x0100010200020201, 0x0100010201000000, 0x0100010201000002,
- 0x0100010201000101, 0x0100010201000200, 0x0100010201010000, 0x0100010201010001,
- 0x0100010201010002, 0x0100010201010101, 0x0100010201010102, 0x0100010201010201,
- 0x0100010201020002, 0x0100010201020101, 0x0100010201020200, 0x0100010202000001,
- 0x0100010202000101, 0x0100010202000202, 0x0100010202010100, 0x0100010202010101,
- 0x0100010202020001, 0x0100010202020100, 0x0100010202020102, 0x0100020000000101,
- 0x0100020000010001, 0x0100020000010101, 0x0100020000010202, 0x0100020000020101,
- 0x0100020001000002, 0x0100020001000201, 0x0100020001010000, 0x0100020001010101,
- 0x0100020001010200, 0x0100020001020001, 0x0100020001020100, 0x0100020001020102,
- 0x0100020001020201, 0x0100020002000101, 0x0100020002010001, 0x0100020002010100,
- 0x0100020002010102, 0x0100020002010201, 0x0100020002020101, 0x0100020100000001,
- 0x0100020100000101, 0x0100020100000102, 0x0100020100000202, 0x0100020100010000,
- 0x0100020100010100, 0x0100020100010101, 0x0100020100010200, 0x0100020100020001,
- 0x0100020100020100, 0x0100020100020102, 0x0100020101000000, 0x0100020101000101,
- 0x0100020101000202, 0x0100020101010001, 0x0100020101010002, 0x0100020101010100,
- 0x0100020101010101, 0x0100020101010102, 0x0100020101010201, 0x0100020101020000,
- 0x0100020101020002, 0x0100020101020101, 0x0100020101020102, 0x0100020101020202,
- 0x0100020102000102, 0x0100020102000201, 0x0100020102010002, 0x0100020102010101,
- 0x0100020102010102, 0x0100020102010200, 0x0100020102020001, 0x0100020102020100,
- 0x0100020102020102, 0x0100020102020201, 0x0100020200010102, 0x0100020201000100,
- 0x0100020201000102, 0x0100020201000201, 0x0100020201010101, 0x0100020201010200,
- 0x0100020201010202, 0x0100020201020100, 0x0100020201020201, 0x0100020202010100,
- 0x0100020202020101, 0x0101000000000001, 0x0101000000000100, 0x0101000000000101,
- 0x0101000000000102, 0x0101000000000201, 0x0101000000010002, 0x0101000000010101,
- 0x0101000000010202, 0x0101000000020001, 0x0101000000020100, 0x0101000000020201,
- 0x0101000001000000, 0x0101000001000101, 0x0101000001000200, 0x0101000001010001,
- 0x0101000001010100, 0x0101000001010101, 0x0101000001010102, 0x0101000001010201,
- 0x0101000001020101, 0x0101000001020200, 0x0101000002000102, 0x0101000002000201,
- 0x0101000002010101, 0x0101000002010200, 0x0101000002020000, 0x0101000002020001,
- 0x0101000002020102, 0x0101000002020201, 0x0101000100000101, 0x0101000100000200,
- 0x0101000100000201, 0x0101000100000202, 0x0101000100010001, 0x0101000100010100,
- 0x0101000100010101, 0x0101000100010102, 0x0101000100010200, 0x0101000100010201,
- 0x0101000100020000, 0x0101000100020101, 0x0101000100020102, 0x0101000100020200,
- 0x0101000100020202, 0x0101000101000001, 0x0101000101000100, 0x0101000101000101,
- 0x0101000101000102, 0x0101000101000201, 0x0101000101010000, 0x0101000101010001,
- 0x0101000101010002, 0x0101000101010100, 0x0101000101010101, 0x0101000101010102,
- 0x0101000101010200, 0x0101000101010201, 0x0101000101010202, 0x0101000101020001,
- 0x0101000101020100, 0x0101000101020101, 0x0101000101020102, 0x0101000101020201,
- 0x0101000102000002, 0x0101000102000101, 0x0101000102010001, 0x0101000102010100,
- 0x0101000102010101, 0x0101000102010102, 0x0101000102010201, 0x0101000102020000,
- 0x0101000102020101, 0x0101000102020202, 0x0101000200000001, 0x0101000200000102,
- 0x0101000200010002, 0x0101000200010101, 0x0101000200010202, 0x0101000200020001,
- 0x0101000200020100, 0x0101000201000002, 0x0101000201000101, 0x0101000201000202,
- 0x0101000201010001, 0x0101000201010100, 0x0101000201010101, 0x0101000201010102,
- 0x0101000201010201, 0x0101000201020002, 0x0101000201020101, 0x0101000202000101,
- 0x0101000202010000, 0x0101000202010002, 0x0101000202010101, 0x0101000202010201,
- 0x0101000202010202, 0x0101000202020100, 0x0101010000000100, 0x0101010000000101,
- 0x0101010000010001, 0x0101010000010100, 0x0101010000010101, 0x0101010000010102,
- 0x0101010000010200, 0x0101010000010201, 0x0101010000020001, 0x0101010000020101,
- 0x0101010000020200, 0x0101010000020202, 0x0101010001000001, 0x0101010001000100,
- 0x0101010001000101, 0x0101010001000102, 0x0101010001000201, 0x0101010001000202,
- 0x0101010001010000, 0x0101010001010001, 0x0101010001010100, 0x0101010001010101,
- 0x0101010001010102, 0x0101010001010200, 0x0101010001010201, 0x0101010001010202,
- 0x0101010001020001, 0x0101010001020002, 0x0101010001020100, 0x0101010001020101,
- 0x0101010001020102, 0x0101010001020201, 0x0101010002000000, 0x0101010002000200,
- 0x0101010002000202, 0x0101010002010001, 0x0101010002010100, 0x0101010002010101,
- 0x0101010002010102, 0x0101010002010201, 0x0101010002020001, 0x0101010002020100,
- 0x0101010002020101, 0x0101010002020202, 0x0101010100000001, 0x0101010100000002,
- 0x0101010100000100, 0x0101010100000101, 0x0101010100000102, 0x0101010100000201,
- 0x0101010100010000, 0x0101010100010001, 0x0101010100010002, 0x0101010100010100,
- 0x0101010100010101, 0x0101010100010102, 0x0101010100010201, 0x0101010100010202,
- 0x0101010100020001, 0x0101010100020100, 0x0101010100020101, 0x0101010100020102,
- 0x0101010100020201, 0x0101010101000000, 0x0101010101000001, 0x0101010101000002,
- 0x0101010101000100, 0x0101010101000101, 0x0101010101000102, 0x0101010101000200,
- 0x0101010101000201, 0x0101010101010000, 0x0101010101010001, 0x0101010101010002,
- 0x0101010101010100, 0x0101010101010101, 0x0101010101010102, 0x0101010101010200,
- 0x0101010101010201, 0x0101010101010202, 0x0101010101020000, 0x0101010101020001,
- 0x0101010101020100, 0x0101010101020101, 0x0101010101020102, 0x0101010101020200,
- 0x0101010101020201, 0x0101010101020202, 0x0101010102000001, 0x0101010102000100,
- 0x0101010102000101, 0x0101010102000201, 0x0101010102000202, 0x0101010102010000,
- 0x0101010102010001, 0x0101010102010100, 0x0101010102010101, 0x0101010102010102,
- 0x0101010102010200, 0x0101010102010201, 0x0101010102020001, 0x0101010102020100,
- 0x0101010102020101, 0x0101010102020102, 0x0101010102020201, 0x0101010200000000,
- 0x0101010200000001, 0x0101010200000002, 0x0101010200000100, 0x0101010200000102,
- 0x0101010200000200, 0x0101010200000201, 0x0101010200010001, 0x0101010200010100,
- 0x0101010200010101, 0x0101010200010200, 0x0101010200010201, 0x0101010200020000,
- 0x0101010200020001, 0x0101010200020002, 0x0101010200020100, 0x0101010200020101,
- 0x0101010200020102, 0x0101010200020200, 0x0101010200020201, 0x0101010201000001,
- 0x0101010201000101, 0x0101010201000102, 0x0101010201000200, 0x0101010201000201,
- 0x0101010201000202, 0x0101010201010000, 0x0101010201010001, 0x0101010201010002,
- 0x0101010201010100, 0x0101010201010101, 0x0101010201010102, 0x0101010201010200,
- 0x0101010201010201, 0x0101010201010202, 0x0101010201020001, 0x0101010201020100,
- 0x0101010201020101, 0x0101010201020201, 0x0101010202000002, 0x0101010202000101,
- 0x0101010202000102, 0x0101010202000200, 0x0101010202000201, 0x0101010202000202,
- 0x0101010202010001, 0x0101010202010101, 0x0101010202010202, 0x0101010202020002,
- 0x0101010202020101, 0x0101010202020102, 0x0101010202020200, 0x0101010202020201,
- 0x0101020000000100, 0x0101020000000101, 0x0101020000000102, 0x0101020000000201,
- 0x0101020000010000, 0x0101020000010101, 0x0101020000010200, 0x0101020000020001,
- 0x0101020000020202, 0x0101020001000101, 0x0101020001000200, 0x0101020001000202,
- 0x0101020001010001, 0x0101020001010100, 0x0101020001010101, 0x0101020001010102,
- 0x0101020001010200, 0x0101020001010201, 0x0101020001020000, 0x0101020001020002,
- 0x0101020001020100, 0x0101020001020101, 0x0101020002000002, 0x0101020002000201,
- 0x0101020002010000, 0x0101020002010002, 0x0101020002010101, 0x0101020002010200,
- 0x0101020002020001, 0x0101020002020201, 0x0101020100000001, 0x0101020100000002,
- 0x0101020100000101, 0x0101020100000202, 0x0101020100010001, 0x0101020100010100,
- 0x0101020100010101, 0x0101020100010102, 0x0101020100010201, 0x0101020100020101,
- 0x0101020101000001, 0x0101020101000100, 0x0101020101000101, 0x0101020101000102,
- 0x0101020101000201, 0x0101020101010000, 0x0101020101010001, 0x0101020101010002,
- 0x0101020101010100, 0x0101020101010101, 0x0101020101010102, 0x0101020101010200,
- 0x0101020101010201, 0x0101020101010202, 0x0101020101020001, 0x0101020101020100,
- 0x0101020101020101, 0x0101020101020102, 0x0101020101020201, 0x0101020102000001,
- 0x0101020102000101, 0x0101020102000201, 0x0101020102010001, 0x0101020102010100,
- 0x0101020102010101, 0x0101020102010102, 0x0101020102010200, 0x0101020102010201,
- 0x0101020102020101, 0x0101020200000100, 0x0101020200000200, 0x0101020200010101,
- 0x0101020200010202, 0x0101020200020000, 0x0101020200020101, 0x0101020200020102,
- 0x0101020200020201, 0x0101020201000101, 0x0101020201000200, 0x0101020201000201,
- 0x0101020201010001, 0x0101020201010101, 0x0101020201010102, 0x0101020201010200,
- 0x0101020201010201, 0x0101020201020002, 0x0101020201020101, 0x0101020201020200,
- 0x0101020201020202, 0x0101020202000001, 0x0101020202000202, 0x0101020202010002,
- 0x0101020202010101, 0x0101020202010102, 0x0101020202010200, 0x0101020202010202,
- 0x0101020202020001, 0x0102000000000101, 0x0102000000010100, 0x0102000000010102,
- 0x0102000000010201, 0x0102000000020101, 0x0102000001000100, 0x0102000001010000,
- 0x0102000001010101, 0x0102000001010102, 0x0102000001010200, 0x0102000001010202,
- 0x0102000001020001, 0x0102000001020100, 0x0102000001020102, 0x0102000001020201,
- 0x0102000002000001, 0x0102000002010102, 0x0102000002020101, 0x0102000100000001,
- 0x0102000100000100, 0x0102000100000102, 0x0102000100000201, 0x0102000100010002,
- 0x0102000100010101, 0x0102000100020001, 0x0102000100020002, 0x0102000100020102,
- 0x0102000100020201, 0x0102000101000101, 0x0102000101000201, 0x0102000101010001,
- 0x0102000101010101, 0x0102000101010102, 0x0102000101010201, 0x0102000101020101,
- 0x0102000101020102, 0x0102000101020202, 0x0102000102000100, 0x0102000102000202,
- 0x0102000102010002, 0x0102000102010101, 0x0102000102020001, 0x0102000102020102,
- 0x0102000102020201, 0x0102000200010001, 0x0102000200010102, 0x0102000200010201,
- 0x0102000201000000, 0x0102000201000001, 0x0102000201000102, 0x0102000201010101,
- 0x0102000201010102, 0x0102000201010200, 0x0102000201020000, 0x0102000202000101,
- 0x0102000202010001, 0x0102000202010102, 0x0102000202020101, 0x0102010000010001,
- 0x0102010000010002, 0x0102010000010101, 0x0102010000010102, 0x0102010000010202,
- 0x0102010000020001, 0x0102010000020102, 0x0102010000020201, 0x0102010001000000,
- 0x0102010001000002, 0x0102010001000101, 0x0102010001000200, 0x0102010001000202,
- 0x0102010001010001, 0x0102010001010100, 0x0102010001010101, 0x0102010001010102,
- 0x0102010001010201, 0x0102010001010202, 0x0102010001020000, 0x0102010001020002,
- 0x0102010001020101, 0x0102010002000100, 0x0102010002000101, 0x0102010002000201,
- 0x0102010002010000, 0x0102010002010002, 0x0102010002010100, 0x0102010002010101,
- 0x0102010002010102, 0x0102010002010200, 0x0102010002010202, 0x0102010002020001,
- 0x0102010002020100, 0x0102010002020201, 0x0102010100000101, 0x0102010100000200,
- 0x0102010100000202, 0x0102010100010001, 0x0102010100010101, 0x0102010100010102,
- 0x0102010100010201, 0x0102010101000100, 0x0102010101000101, 0x0102010101000102,
- 0x0102010101000201, 0x0102010101010000, 0x0102010101010001, 0x0102010101010100,
- 0x0102010101010101, 0x0102010101010102, 0x0102010101010201, 0x0102010101020001,
- 0x0102010101020100, 0x0102010101020101, 0x0102010101020102, 0x0102010101020201,
- 0x0102010102000102, 0x0102010102000201, 0x0102010102000202, 0x0102010102010001,
- 0x0102010102010101, 0x0102010102010102, 0x0102010102010201, 0x0102010102010202,
- 0x0102010102020002, 0x0102010102020101, 0x0102010102020102, 0x0102010102020200,
- 0x0102010200000002, 0x0102010200000201, 0x0102010200010101, 0x0102010200020000,
- 0x0102010200020102, 0x0102010200020200, 0x0102010200020201, 0x0102010201000000,
- 0x0102010201000101, 0x0102010201000200, 0x0102010201000202, 0x0102010201010001,
- 0x0102010201010100, 0x0102010201010101, 0x0102010201010102, 0x0102010201010200,
- 0x0102010201010202, 0x0102010201020000, 0x0102010201020101, 0x0102010201020200,
- 0x0102010202000000, 0x0102010202000002, 0x0102010202000101, 0x0102010202000202,
- 0x0102010202010100, 0x0102010202010102, 0x0102010202010200, 0x0102010202010201,
- 0x0102010202020000, 0x0102010202020100, 0x0102010202020102, 0x0102010202020202,
- 0x0102020000010102, 0x0102020000010201, 0x0102020000020101, 0x0102020001000001,
- 0x0102020001010002, 0x0102020001010101, 0x0102020001010202, 0x0102020001020001,
- 0x0102020001020201, 0x0102020002000101, 0x0102020002010001, 0x0102020002010200,
- 0x0102020002020102, 0x0102020100000001, 0x0102020100000100, 0x0102020100010000,
- 0x0102020100010101, 0x0102020100020001, 0x0102020100020100, 0x0102020100020102,
- 0x0102020100020201, 0x0102020101000000, 0x0102020101000001, 0x0102020101000101,
- 0x0102020101000102, 0x0102020101000200, 0x0102020101010001, 0x0102020101010100,
- 0x0102020101010101, 0x0102020101010102, 0x0102020101010201, 0x0102020101020000,
- 0x0102020101020101, 0x0102020101020202, 0x0102020102000002, 0x0102020102000100,
- 0x0102020102000202, 0x0102020102010101, 0x0102020102020001, 0x0102020102020100,
- 0x0102020102020101, 0x0102020102020201, 0x0102020200010001, 0x0102020200010102,
- 0x0102020200010200, 0x0102020201000001, 0x0102020201000100, 0x0102020201000201,
- 0x0102020201010000, 0x0102020201010101, 0x0102020201010200, 0x0102020201010202,
- 0x0102020201020100, 0x0102020201020101, 0x0102020201020201, 0x0102020202000102,
- 0x0102020202010100, 0x0102020202010200, 0x0102020202010202, 0x0102020202020102,
- 0x0200000000000000, 0x0200000000000002, 0x0200000000000200, 0x0200000000000202,
- 0x0200000000020000, 0x0200000000020002, 0x0200000000020200, 0x0200000000020202,
- 0x0200000001000101, 0x0200000001010000, 0x0200000001010001, 0x0200000001010100,
- 0x0200000001010102, 0x0200000001010201, 0x0200000001020101, 0x0200000002000000,
- 0x0200000002000002, 0x0200000002000200, 0x0200000002000202, 0x0200000002010101,
- 0x0200000002020000, 0x0200000002020002, 0x0200000002020200, 0x0200000002020202,
- 0x0200000100000101, 0x0200000100010001, 0x0200000100010100, 0x0200000100010102,
- 0x0200000100010201, 0x0200000100020101, 0x0200000101000001, 0x0200000101000100,
- 0x0200000101000201, 0x0200000101010000, 0x0200000101010002, 0x0200000101010101,
- 0x0200000101010102, 0x0200000101010200, 0x0200000101010201, 0x0200000101020100,
- 0x0200000101020102, 0x0200000101020201, 0x0200000102000101, 0x0200000102000201,
- 0x0200000102010100, 0x0200000102010102, 0x0200000102010201, 0x0200000102020101,
- 0x0200000200000000, 0x0200000200000002, 0x0200000200000200, 0x0200000200000202,
- 0x0200000200010101, 0x0200000200020000, 0x0200000200020002, 0x0200000200020200,
- 0x0200000200020202, 0x0200000201010001, 0x0200000201010100, 0x0200000201010201,
- 0x0200000201020101, 0x0200000202000000, 0x0200000202000002, 0x0200000202000200,
- 0x0200000202000202, 0x0200000202010101, 0x0200000202020000, 0x0200000202020002,
- 0x0200000202020200, 0x0200000202020202, 0x0200010000010100, 0x0200010000010201,
- 0x0200010001000001, 0x0200010001000100, 0x0200010001010001, 0x0200010001010101,
- 0x0200010001010202, 0x0200010001020001, 0x0200010001020100, 0x0200010001020201,
- 0x0200010002010100, 0x0200010002010201, 0x0200010100000001, 0x0200010100000201,
- 0x0200010100010002, 0x0200010100010101, 0x0200010100010202, 0x0200010100020102,
- 0x0200010100020201, 0x0200010101000000, 0x0200010101000001, 0x0200010101000101,
- 0x0200010101000200, 0x0200010101010001, 0x0200010101010100, 0x0200010101010101,
- 0x0200010101010102, 0x0200010101010201, 0x0200010101010202, 0x0200010101020101,
- 0x0200010101020102, 0x0200010101020200, 0x0200010101020202, 0x0200010102000001,
- 0x0200010102000100, 0x0200010102000102, 0x0200010102000201, 0x0200010102010000,
- 0x0200010102010002, 0x0200010102010101, 0x0200010102010200, 0x0200010102020102,
- 0x0200010200010001, 0x0200010200010102, 0x0200010200010201, 0x0200010200020101,
- 0x0200010201000001, 0x0200010201000100, 0x0200010201000201, 0x0200010201000202,
- 0x0200010201010000, 0x0200010201010101, 0x0200010201010201, 0x0200010201010202,
- 0x0200010201020001, 0x0200010201020102, 0x0200010201020202, 0x0200010202000101,
- 0x0200010202010001, 0x0200010202010202, 0x0200010202020100, 0x0200020000000000,
- 0x0200020000000002, 0x0200020000000200, 0x0200020000000202, 0x0200020000010101,
- 0x0200020000020000, 0x0200020000020002, 0x0200020000020200, 0x0200020000020202,
- 0x0200020001000001, 0x0200020001000101, 0x0200020001010001, 0x0200020001010100,
- 0x0200020001010201, 0x0200020001020101, 0x0200020001020201, 0x0200020002000000,
- 0x0200020002000002, 0x0200020002000200, 0x0200020002000202, 0x0200020002010101,
- 0x0200020002020000, 0x0200020002020002, 0x0200020002020200, 0x0200020002020202,
- 0x0200020100000101, 0x0200020100000102, 0x0200020100010001, 0x0200020100010100,
- 0x0200020100010102, 0x0200020100020101, 0x0200020101000001, 0x0200020101000100,
- 0x0200020101000102, 0x0200020101000201, 0x0200020101010000, 0x0200020101010002,
- 0x0200020101010101, 0x0200020101010202, 0x0200020101020001, 0x0200020101020100,
- 0x0200020102000101, 0x0200020102010102, 0x0200020102010201, 0x0200020102020101,
- 0x0200020200000000, 0x0200020200000002, 0x0200020200000200, 0x0200020200000202,
- 0x0200020200010101, 0x0200020200020000, 0x0200020200020002, 0x0200020200020200,
- 0x0200020200020202, 0x0200020201000101, 0x0200020201010001, 0x0200020201010100,
- 0x0200020201010102, 0x0200020202000000, 0x0200020202000002, 0x0200020202000200,
- 0x0200020202000202, 0x0200020202010101, 0x0200020202020000, 0x0200020202020002,
- 0x0200020202020200, 0x0200020202020202, 0x0201000000000101, 0x0201000000010001,
- 0x0201000000010102, 0x0201000000010200, 0x0201000000010201, 0x0201000000020101,
- 0x0201000001000001, 0x0201000001000102, 0x0201000001000201, 0x0201000001010101,
- 0x0201000001010200, 0x0201000001010202, 0x0201000001020201, 0x0201000001020202,
- 0x0201000002000101, 0x0201000002010001, 0x0201000002010100, 0x0201000002010102,
- 0x0201000002010201, 0x0201000002020101, 0x0201000100000001, 0x0201000100000100,
- 0x0201000100000102, 0x0201000100000201, 0x0201000100010000, 0x0201000100010101,
- 0x0201000100010200, 0x0201000100010202, 0x0201000100020001, 0x0201000100020100,
- 0x0201000100020102, 0x0201000100020201, 0x0201000101000000, 0x0201000101000101,
- 0x0201000101010000, 0x0201000101010001, 0x0201000101010100, 0x0201000101010101,
- 0x0201000101010102, 0x0201000101010201, 0x0201000101020002, 0x0201000101020101,
- 0x0201000102000100, 0x0201000102000102, 0x0201000102010002, 0x0201000102010101,
- 0x0201000102010200, 0x0201000102020001, 0x0201000102020100, 0x0201000102020102,
- 0x0201000102020201, 0x0201000200000101, 0x0201000200010001, 0x0201000200010100,
- 0x0201000200010201, 0x0201000200020101, 0x0201000201000100, 0x0201000201000102,
- 0x0201000201000201, 0x0201000201010000, 0x0201000201010002, 0x0201000201010101,
- 0x0201000201010200, 0x0201000201020102, 0x0201000201020201, 0x0201000202000101,
- 0x0201000202010100, 0x0201000202010102, 0x0201000202020201, 0x0201010000000001,
- 0x0201010000000100, 0x0201010000000102, 0x0201010000010000, 0x0201010000010101,
- 0x0201010000010200, 0x0201010000020102, 0x0201010001000000, 0x0201010001000202,
- 0x0201010001010001, 0x0201010001010100, 0x0201010001010101, 0x0201010001010102,
- 0x0201010001010200, 0x0201010001010201, 0x0201010001020000, 0x0201010001020001,
- 0x0201010001020002, 0x0201010001020101, 0x0201010002000100, 0x0201010002000102,
- 0x0201010002010002, 0x0201010002010100, 0x0201010002010101, 0x0201010002010200,
- 0x0201010002020001, 0x0201010002020201, 0x0201010100000000, 0x0201010100000101,
- 0x0201010100000200, 0x0201010100000202, 0x0201010100010000, 0x0201010100010001,
- 0x0201010100010100, 0x0201010100010101, 0x0201010100010102, 0x0201010100010201,
- 0x0201010100020001, 0x0201010100020101, 0x0201010100020201, 0x0201010100020202,
- 0x0201010101000001, 0x0201010101000100, 0x0201010101000101, 0x0201010101000102,
- 0x0201010101000201, 0x0201010101010000, 0x0201010101010001, 0x0201010101010002,
- 0x0201010101010100, 0x0201010101010101, 0x0201010101010102, 0x0201010101010200,
- 0x0201010101010201, 0x0201010101010202, 0x0201010101020001, 0x0201010101020100,
- 0x0201010101020101, 0x0201010101020102, 0x0201010101020201, 0x0201010102000001,
- 0x0201010102000101, 0x0201010102000200, 0x0201010102010001, 0x0201010102010002,
- 0x0201010102010100, 0x0201010102010101, 0x0201010102010102, 0x0201010102010201,
- 0x0201010102010202, 0x0201010102020000, 0x0201010102020002, 0x0201010102020101,
- 0x0201010102020200, 0x0201010102020202, 0x0201010200000001, 0x0201010200000100,
- 0x0201010200010000, 0x0201010200010101, 0x0201010200010201, 0x0201010200020000,
- 0x0201010200020102, 0x0201010200020201, 0x0201010201000101, 0x0201010201000200,
- 0x0201010201000201, 0x0201010201010001, 0x0201010201010002, 0x0201010201010101,
- 0x0201010201010102, 0x0201010201010201, 0x0201010201020101, 0x0201010201020200,
- 0x0201010202000002, 0x0201010202000100, 0x0201010202000201, 0x0201010202000202,
- 0x0201010202010002, 0x0201010202010100, 0x0201010202010101, 0x0201010202020100,
- 0x0201010202020102, 0x0201010202020201, 0x0201020000000101, 0x0201020000010102,
- 0x0201020000010201, 0x0201020000020101, 0x0201020001000001, 0x0201020001000102,
- 0x0201020001010000, 0x0201020001010002, 0x0201020001010101, 0x0201020001010102,
- 0x0201020001010202, 0x0201020001020100, 0x0201020001020101, 0x0201020002000101,
- 0x0201020002010001, 0x0201020002010102, 0x0201020002010201, 0x0201020002020101,
- 0x0201020100000100, 0x0201020100000102, 0x0201020100000201, 0x0201020100010000,
- 0x0201020100010002, 0x0201020100010101, 0x0201020100010200, 0x0201020100010202,
- 0x0201020100020000, 0x0201020100020001, 0x0201020100020100, 0x0201020100020102,
- 0x0201020101000000, 0x0201020101000002, 0x0201020101000101, 0x0201020101000200,
- 0x0201020101000202, 0x0201020101010001, 0x0201020101010100, 0x0201020101010101,
- 0x0201020101010102, 0x0201020101010201, 0x0201020101020002, 0x0201020101020101,
- 0x0201020101020102, 0x0201020101020202, 0x0201020102000001, 0x0201020102000100,
- 0x0201020102010000, 0x0201020102010002, 0x0201020102010101, 0x0201020102010202,
- 0x0201020102020001, 0x0201020102020102, 0x0201020200000101, 0x0201020200010101,
- 0x0201020200020101, 0x0201020201000100, 0x0201020201000102, 0x0201020201000201,
- 0x0201020201010000, 0x0201020201010101, 0x0201020201010200, 0x0201020201020001,
- 0x0201020202000101, 0x0201020202010001, 0x0201020202010100, 0x0201020202010101,
- 0x0201020202010102, 0x0202000000000000, 0x0202000000000002, 0x0202000000000200,
- 0x0202000000000202, 0x0202000000010101, 0x0202000000020000, 0x0202000000020002,
- 0x0202000000020200, 0x0202000000020202, 0x0202000001000101, 0x0202000001010001,
- 0x0202000001010100, 0x0202000001010102, 0x0202000001010201, 0x0202000002000000,
- 0x0202000002000002, 0x0202000002000200, 0x0202000002000202, 0x0202000002010101,
- 0x0202000002020000, 0x0202000002020002, 0x0202000002020200, 0x0202000002020202,
- 0x0202000100000101, 0x0202000100000201, 0x0202000100010001, 0x0202000100010100,
- 0x0202000100010102, 0x0202000100010201, 0x0202000100010202, 0x0202000101000102,
- 0x0202000101000201, 0x0202000101010001, 0x0202000101010101, 0x0202000101010200,
- 0x0202000101010202, 0x0202000101020001, 0x0202000101020100, 0x0202000102000101,
- 0x0202000102010000, 0x0202000102010002, 0x0202000102010102, 0x0202000102010201,
- 0x0202000200000002, 0x0202000200000200, 0x0202000200000202, 0x0202000200010000,
- 0x0202000200010201, 0x0202000200020002, 0x0202000200020200, 0x0202000200020202,
- 0x0202000201000101, 0x0202000201010001, 0x0202000201010102, 0x0202000201010201,
- 0x0202000201020101, 0x0202000202000000, 0x0202000202000002, 0x0202000202000200,
- 0x0202000202000202, 0x0202000202010101, 0x0202000202020000, 0x0202000202020002,
- 0x0202000202020200, 0x0202000202020202, 0x0202010000010201, 0x0202010000020101,
- 0x0202010001000001, 0x0202010001000100, 0x0202010001010000, 0x0202010001010100,
- 0x0202010001010101, 0x0202010001010200, 0x0202010001010202, 0x0202010001020001,
- 0x0202010001020101, 0x0202010001020102, 0x0202010001020200, 0x0202010001020201,
- 0x0202010002000101, 0x0202010100000102, 0x0202010100000201, 0x0202010100010000,
- 0x0202010100010002, 0x0202010100010101, 0x0202010100010200, 0x0202010100020102,
- 0x0202010100020201, 0x0202010101000002, 0x0202010101000101, 0x0202010101010001,
- 0x0202010101010100, 0x0202010101010101, 0x0202010101010102, 0x0202010101010201,
- 0x0202010101020101, 0x0202010101020202, 0x0202010102000001, 0x0202010102000100,
- 0x0202010102000101, 0x0202010102000102, 0x0202010102000201, 0x0202010102010002,
- 0x0202010102010101, 0x0202010102010200, 0x0202010200000101, 0x0202010200010001,
- 0x0202010200010102, 0x0202010200010202, 0x0202010200020001, 0x0202010200020101,
- 0x0202010201000100, 0x0202010201000102, 0x0202010201000202, 0x0202010201010002,
- 0x0202010201010101, 0x0202010201010102, 0x0202010201010200, 0x0202010201020000,
- 0x0202010201020002, 0x0202010202000102, 0x0202010202010000, 0x0202010202010101,
- 0x0202010202010102, 0x0202010202010201, 0x0202010202020001, 0x0202010202020100,
- 0x0202010202020102, 0x0202020000000000, 0x0202020000000002, 0x0202020000000200,
- 0x0202020000000202, 0x0202020000020000, 0x0202020000020002, 0x0202020000020200,
- 0x0202020000020202, 0x0202020001010001, 0x0202020001010100, 0x0202020001010102,
- 0x0202020001010201, 0x0202020002000000, 0x0202020002000002, 0x0202020002000200,
- 0x0202020002000202, 0x0202020002010101, 0x0202020002020000, 0x0202020002020002,
- 0x0202020002020200, 0x0202020002020202, 0x0202020100000101, 0x0202020100010100,
- 0x0202020100010201, 0x0202020100020001, 0x0202020100020101, 0x0202020101000001,
- 0x0202020101010000, 0x0202020101010101, 0x0202020101010202, 0x0202020101020001,
- 0x0202020101020102, 0x0202020101020201, 0x0202020102010000, 0x0202020102010102,
- 0x0202020200000000, 0x0202020200000002, 0x0202020200000200, 0x0202020200000202,
- 0x0202020200020000, 0x0202020200020002, 0x0202020200020200, 0x0202020200020202,
- 0x0202020201010001, 0x0202020201010100, 0x0202020201010102, 0x0202020202000000,
- 0x0202020202000002, 0x0202020202000200, 0x0202020202000202, 0x0202020202010101,
- 0x0202020202020000, 0x0202020202020002, 0x0202020202020200, 0x0202020202020202,
-};
-#else
-static const uint32_t iq1s_grid_us[2048] = {
- 0x00000000, 0x00000002, 0x00000101, 0x00000200, 0x00000202, 0x00010001, 0x00010101, 0x00020000,
- 0x00020002, 0x00020200, 0x00020202, 0x01000101, 0x01010001, 0x01010100, 0x01010102, 0x01020101,
- 0x02000000, 0x02000002, 0x02000200, 0x02000202, 0x02010101, 0x02020000, 0x02020002, 0x02020200,
- 0x02020202, 0x00000110, 0x00000111, 0x00010011, 0x00010110, 0x00010112, 0x00010211, 0x00010212,
- 0x00020111, 0x01000011, 0x01000112, 0x01000211, 0x01010012, 0x01010111, 0x01010212, 0x01020011,
- 0x01020110, 0x01020112, 0x01020210, 0x02000111, 0x02010011, 0x02010110, 0x02010112, 0x02020111,
- 0x00000020, 0x00000022, 0x00000220, 0x00000222, 0x00010121, 0x00020020, 0x00020022, 0x00020220,
- 0x00020222, 0x01000121, 0x01010021, 0x01010221, 0x01020120, 0x01020221, 0x02000020, 0x02000022,
- 0x02000220, 0x02000222, 0x02010021, 0x02010121, 0x02010221, 0x02020020, 0x02020022, 0x02020220,
- 0x02020222, 0x00011001, 0x00011100, 0x00011102, 0x00021101, 0x01001001, 0x01001201, 0x01011101,
- 0x01011202, 0x01021100, 0x01021101, 0x02011001, 0x02011201, 0x02021101, 0x00001011, 0x00001110,
- 0x00001111, 0x00001112, 0x00011111, 0x00011210, 0x00011212, 0x00021211, 0x01001010, 0x01001111,
- 0x01001212, 0x01011010, 0x01011011, 0x01011110, 0x01011111, 0x01011112, 0x01011211, 0x01021010,
- 0x01021012, 0x01021111, 0x01021210, 0x01021212, 0x02001011, 0x02011011, 0x02011111, 0x02011210,
- 0x02011212, 0x02021011, 0x02021110, 0x02021111, 0x02021112, 0x02021211, 0x00011120, 0x00011221,
- 0x01001021, 0x01001120, 0x01011020, 0x01011022, 0x01011121, 0x01011220, 0x01021020, 0x01021021,
- 0x01021122, 0x01021221, 0x02001121, 0x02011021, 0x02011120, 0x02011221, 0x00002000, 0x00002002,
- 0x00002200, 0x00002202, 0x00012101, 0x00022000, 0x00022002, 0x00022200, 0x00022202, 0x01002101,
- 0x01012001, 0x01012102, 0x01022101, 0x02002000, 0x02002002, 0x02002200, 0x02002202, 0x02012101,
- 0x02022000, 0x02022002, 0x02022200, 0x02022202, 0x00002111, 0x00012011, 0x00012110, 0x00012211,
- 0x00022110, 0x00022111, 0x01002011, 0x01012010, 0x01012011, 0x01012111, 0x01022011, 0x01022110,
- 0x01022211, 0x02012011, 0x02012110, 0x02012112, 0x02012211, 0x02022111, 0x00002020, 0x00002022,
- 0x00002220, 0x00002222, 0x00012121, 0x00022020, 0x00022022, 0x00022220, 0x00022222, 0x01002121,
- 0x01012021, 0x01012221, 0x01022021, 0x01022121, 0x02002020, 0x02002022, 0x02002121, 0x02002220,
- 0x02002222, 0x02012121, 0x02022020, 0x02022022, 0x02022220, 0x02022222, 0x00110000, 0x00110001,
- 0x00110100, 0x00110201, 0x00120100, 0x00120101, 0x01100001, 0x01100100, 0x01110000, 0x01110101,
- 0x01110200, 0x01120001, 0x01120100, 0x01120101, 0x01120201, 0x02110001, 0x02110100, 0x02110102,
- 0x02120001, 0x02120101, 0x00100011, 0x00100110, 0x00100112, 0x00100211, 0x00110010, 0x00110012,
- 0x00110111, 0x00110210, 0x00120011, 0x00120110, 0x00120211, 0x01100111, 0x01100212, 0x01110010,
- 0x01110011, 0x01110012, 0x01110110, 0x01110111, 0x01110112, 0x01110211, 0x01120010, 0x01120111,
- 0x02100110, 0x02110012, 0x02110111, 0x02120011, 0x02120110, 0x00110021, 0x00110120, 0x00110122,
- 0x00120121, 0x01100020, 0x01100122, 0x01100221, 0x01110022, 0x01110121, 0x01110220, 0x01110222,
- 0x01120120, 0x01120122, 0x02100121, 0x02110021, 0x02110120, 0x02110122, 0x02120121, 0x00101001,
- 0x00101102, 0x00101201, 0x00111100, 0x00111101, 0x00111200, 0x00111201, 0x00121001, 0x00121102,
- 0x01101001, 0x01101101, 0x01101102, 0x01101200, 0x01101202, 0x01111001, 0x01111100, 0x01111101,
- 0x01111102, 0x01111201, 0x01121002, 0x01121101, 0x01121200, 0x02101100, 0x02101201, 0x02111000,
- 0x02111100, 0x02111101, 0x02111200, 0x02111201, 0x02111202, 0x02121001, 0x02121100, 0x02121101,
- 0x02121201, 0x00101012, 0x00101111, 0x00101212, 0x00111011, 0x00111110, 0x00111111, 0x00111112,
- 0x00111211, 0x00121010, 0x00121012, 0x00121111, 0x00121210, 0x00121212, 0x01101011, 0x01101110,
- 0x01101111, 0x01101112, 0x01111011, 0x01111012, 0x01111110, 0x01111111, 0x01111112, 0x01111211,
- 0x01111212, 0x01121011, 0x01121110, 0x01121111, 0x01121112, 0x01121211, 0x02101010, 0x02101012,
- 0x02101110, 0x02101111, 0x02101210, 0x02101212, 0x02111010, 0x02111011, 0x02111110, 0x02111111,
- 0x02111112, 0x02111211, 0x02111212, 0x02121010, 0x02121012, 0x02121111, 0x00101021, 0x00101120,
- 0x00101121, 0x00101122, 0x00111121, 0x00111122, 0x00111220, 0x00111222, 0x00121021, 0x00121122,
- 0x01101020, 0x01101022, 0x01101120, 0x01101121, 0x01101220, 0x01101222, 0x01111021, 0x01111121,
- 0x01111122, 0x01111220, 0x01111221, 0x01121021, 0x01121120, 0x01121121, 0x01121220, 0x01121221,
- 0x01121222, 0x02101122, 0x02101222, 0x02111022, 0x02111121, 0x02121120, 0x02121221, 0x00112001,
- 0x00112102, 0x00122101, 0x01102001, 0x01102100, 0x01102102, 0x01102201, 0x01112000, 0x01112101,
- 0x01112200, 0x01112202, 0x01122000, 0x01122001, 0x01122100, 0x01122102, 0x01122201, 0x02102101,
- 0x02112001, 0x02112100, 0x02122101, 0x00112010, 0x00112012, 0x00112111, 0x00112212, 0x00122011,
- 0x00122111, 0x01102012, 0x01102110, 0x01102111, 0x01102210, 0x01112011, 0x01112110, 0x01112111,
- 0x01112112, 0x01112211, 0x01112212, 0x01122010, 0x01122111, 0x01122212, 0x02102211, 0x02112011,
- 0x02112012, 0x02112111, 0x02112210, 0x02122011, 0x02122112, 0x02122211, 0x00102221, 0x00112122,
- 0x00122120, 0x00122122, 0x01102120, 0x01102122, 0x01102221, 0x01112020, 0x01112022, 0x01112121,
- 0x01112220, 0x01122021, 0x01122122, 0x01122221, 0x02102121, 0x02112021, 0x02112122, 0x02112222,
- 0x00200000, 0x00200002, 0x00200200, 0x00200202, 0x00210101, 0x00220000, 0x00220002, 0x00220101,
- 0x00220200, 0x00220202, 0x01200101, 0x01210001, 0x01210201, 0x01220001, 0x01220101, 0x02200000,
- 0x02200002, 0x02200200, 0x02200202, 0x02210101, 0x02220000, 0x02220002, 0x02220101, 0x02220200,
- 0x02220202, 0x00200111, 0x00210011, 0x00210110, 0x00210211, 0x00220111, 0x01200012, 0x01200110,
- 0x01200211, 0x01210111, 0x01210210, 0x01210212, 0x01220011, 0x01220110, 0x01220111, 0x01220112,
- 0x02200111, 0x02210010, 0x02210112, 0x02210211, 0x02220111, 0x00200021, 0x00200220, 0x00200222,
- 0x00210021, 0x00210121, 0x00220020, 0x00220022, 0x00220220, 0x00220222, 0x01200121, 0x01210021,
- 0x01210122, 0x01210221, 0x01220121, 0x02200021, 0x02200220, 0x02200222, 0x02210021, 0x02210121,
- 0x02220020, 0x02220022, 0x02220220, 0x02220222, 0x00201101, 0x00211100, 0x00211102, 0x00211201,
- 0x00221101, 0x01201100, 0x01201101, 0x01201102, 0x01201201, 0x01211002, 0x01211101, 0x01211200,
- 0x01211202, 0x01221102, 0x02201101, 0x02211001, 0x02211100, 0x02211201, 0x02221001, 0x02221101,
- 0x00201211, 0x00211111, 0x00221011, 0x00221211, 0x01201010, 0x01201111, 0x01201210, 0x01211011,
- 0x01211110, 0x01211111, 0x01211211, 0x01221012, 0x01221111, 0x01221210, 0x02201211, 0x02211010,
- 0x02211110, 0x02211111, 0x02211210, 0x02211212, 0x02221011, 0x02221110, 0x02221112, 0x02221211,
- 0x00201121, 0x00211020, 0x00211022, 0x00211221, 0x00221121, 0x01201021, 0x01201221, 0x01211121,
- 0x01221020, 0x01221021, 0x01221221, 0x02201120, 0x02201122, 0x02211020, 0x02211222, 0x00202000,
- 0x00202002, 0x00202200, 0x00202202, 0x00212101, 0x00222000, 0x00222002, 0x00222200, 0x00222202,
- 0x01202101, 0x01212001, 0x01212100, 0x01222101, 0x02202000, 0x02202002, 0x02202200, 0x02202202,
- 0x02222000, 0x02222002, 0x02222200, 0x02222202, 0x00202211, 0x00212011, 0x00212110, 0x00212211,
- 0x00222111, 0x01202112, 0x01202211, 0x01212012, 0x01212111, 0x01222011, 0x01222110, 0x01222112,
- 0x01222211, 0x02202111, 0x02212010, 0x02212112, 0x02212211, 0x02222110, 0x02222111, 0x00202020,
- 0x00202022, 0x00202220, 0x00202222, 0x00222020, 0x00222022, 0x00222220, 0x00222222, 0x01202121,
- 0x01212021, 0x01212122, 0x01212221, 0x01222121, 0x02202020, 0x02202022, 0x02202220, 0x02202222,
- 0x02212121, 0x02222020, 0x02222022, 0x02222220, 0x02222222, 0x10000101, 0x10010001, 0x10010102,
- 0x10020101, 0x11000201, 0x11010002, 0x11010101, 0x11010200, 0x11010202, 0x11020001, 0x11020100,
- 0x11020102, 0x12010100, 0x12010201, 0x12020001, 0x12020102, 0x10000010, 0x10000011, 0x10000110,
- 0x10000112, 0x10000211, 0x10010012, 0x10010111, 0x10010112, 0x10010210, 0x10010212, 0x10020011,
- 0x10020112, 0x10020211, 0x11000111, 0x11000210, 0x11000212, 0x11010011, 0x11010110, 0x11010111,
- 0x11010112, 0x11010211, 0x11010212, 0x11020111, 0x11020210, 0x11020212, 0x12000011, 0x12000110,
- 0x12000112, 0x12010010, 0x12010012, 0x12010111, 0x12020010, 0x12020011, 0x12020012, 0x10000121,
- 0x10010021, 0x10010120, 0x10010122, 0x10020121, 0x11000021, 0x11010022, 0x11010121, 0x11010222,
- 0x11020120, 0x11020221, 0x12000221, 0x12010120, 0x12020121, 0x10001001, 0x10011101, 0x10011201,
- 0x10021201, 0x11001101, 0x11001200, 0x11001202, 0x11011001, 0x11011100, 0x11011101, 0x11011102,
- 0x11021001, 0x11021002, 0x11021101, 0x11021200, 0x11021202, 0x12001001, 0x12001102, 0x12001201,
- 0x12011000, 0x12011002, 0x12011101, 0x12021000, 0x12021001, 0x12021201, 0x10001011, 0x10001012,
- 0x10001111, 0x10001212, 0x10011011, 0x10011110, 0x10011111, 0x10011112, 0x10011211, 0x10021010,
- 0x10021111, 0x10021212, 0x11001011, 0x11001110, 0x11001111, 0x11001112, 0x11001211, 0x11011010,
- 0x11011011, 0x11011110, 0x11011111, 0x11011112, 0x11011210, 0x11011211, 0x11021011, 0x11021110,
- 0x11021111, 0x11021112, 0x11021211, 0x12001012, 0x12001110, 0x12001111, 0x12001210, 0x12011011,
- 0x12011110, 0x12011111, 0x12011112, 0x12011211, 0x12011212, 0x12021111, 0x12021210, 0x12021212,
- 0x10001021, 0x10001121, 0x10001221, 0x10011120, 0x10011121, 0x10011220, 0x10011222, 0x10021021,
- 0x10021120, 0x10021221, 0x11001020, 0x11001022, 0x11001121, 0x11001220, 0x11011020, 0x11011021,
- 0x11011022, 0x11011121, 0x11011122, 0x11011221, 0x11021022, 0x11021121, 0x11021220, 0x12001021,
- 0x12001121, 0x12001222, 0x12011120, 0x12011121, 0x12021021, 0x12021120, 0x12021122, 0x10002101,
- 0x10012001, 0x10012101, 0x10012202, 0x10022101, 0x11002002, 0x11002201, 0x11012000, 0x11012101,
- 0x11012200, 0x11022001, 0x11022100, 0x11022102, 0x11022201, 0x12002101, 0x12012001, 0x12012100,
- 0x12012102, 0x12012201, 0x12022101, 0x10002011, 0x10002111, 0x10002112, 0x10002212, 0x10012010,
- 0x10012110, 0x10012111, 0x10012210, 0x10022011, 0x10022110, 0x10022112, 0x11002010, 0x11002111,
- 0x11002212, 0x11012011, 0x11012012, 0x11012110, 0x11012111, 0x11012112, 0x11012211, 0x11022010,
- 0x11022012, 0x11022111, 0x11022112, 0x11022212, 0x12002112, 0x12002211, 0x12012012, 0x12012111,
- 0x12012112, 0x12012210, 0x12022011, 0x12022110, 0x12022112, 0x12022211, 0x10012122, 0x11002120,
- 0x11002122, 0x11002221, 0x11012121, 0x11012220, 0x11012222, 0x11022120, 0x11022221, 0x12012120,
- 0x12022121, 0x10100001, 0x10100100, 0x10100101, 0x10100102, 0x10100201, 0x10110002, 0x10110101,
- 0x10110202, 0x10120001, 0x10120100, 0x10120201, 0x11100000, 0x11100101, 0x11100200, 0x11110001,
- 0x11110100, 0x11110101, 0x11110102, 0x11110201, 0x11120101, 0x11120200, 0x12100102, 0x12100201,
- 0x12110101, 0x12110200, 0x12120000, 0x12120001, 0x12120102, 0x12120201, 0x10100111, 0x10100210,
- 0x10100211, 0x10100212, 0x10110011, 0x10110110, 0x10110111, 0x10110112, 0x10110210, 0x10110211,
- 0x10120010, 0x10120111, 0x10120112, 0x10120210, 0x10120212, 0x11100011, 0x11100110, 0x11100111,
- 0x11100112, 0x11100211, 0x11110010, 0x11110011, 0x11110012, 0x11110110, 0x11110111, 0x11110112,
- 0x11110210, 0x11110211, 0x11110212, 0x11120011, 0x11120110, 0x11120111, 0x11120112, 0x11120211,
- 0x12100012, 0x12100111, 0x12110011, 0x12110110, 0x12110111, 0x12110112, 0x12110211, 0x12120010,
- 0x12120111, 0x12120212, 0x10100021, 0x10100122, 0x10110022, 0x10110121, 0x10110222, 0x10120021,
- 0x10120120, 0x11100022, 0x11100121, 0x11100222, 0x11110021, 0x11110120, 0x11110121, 0x11110122,
- 0x11110221, 0x11120022, 0x11120121, 0x12100121, 0x12110020, 0x12110022, 0x12110121, 0x12110221,
- 0x12110222, 0x12120120, 0x10101100, 0x10101101, 0x10111001, 0x10111100, 0x10111101, 0x10111102,
- 0x10111200, 0x10111201, 0x10121001, 0x10121101, 0x10121200, 0x10121202, 0x11101001, 0x11101100,
- 0x11101101, 0x11101102, 0x11101201, 0x11101202, 0x11111000, 0x11111001, 0x11111100, 0x11111101,
- 0x11111102, 0x11111200, 0x11111201, 0x11111202, 0x11121001, 0x11121002, 0x11121100, 0x11121101,
- 0x11121102, 0x11121201, 0x12101000, 0x12101200, 0x12101202, 0x12111001, 0x12111100, 0x12111101,
- 0x12111102, 0x12111201, 0x12121001, 0x12121100, 0x12121101, 0x12121202, 0x10101011, 0x10101012,
- 0x10101110, 0x10101111, 0x10101112, 0x10101211, 0x10111010, 0x10111011, 0x10111012, 0x10111110,
- 0x10111111, 0x10111112, 0x10111211, 0x10111212, 0x10121011, 0x10121110, 0x10121111, 0x10121112,
- 0x10121211, 0x11101010, 0x11101011, 0x11101012, 0x11101110, 0x11101111, 0x11101112, 0x11101210,
- 0x11101211, 0x11111010, 0x11111011, 0x11111012, 0x11111110, 0x11111111, 0x11111112, 0x11111210,
- 0x11111211, 0x11111212, 0x11121010, 0x11121011, 0x11121110, 0x11121111, 0x11121112, 0x11121210,
- 0x11121211, 0x11121212, 0x12101011, 0x12101110, 0x12101111, 0x12101211, 0x12101212, 0x12111010,
- 0x12111011, 0x12111110, 0x12111111, 0x12111112, 0x12111210, 0x12111211, 0x12121011, 0x12121110,
- 0x12121111, 0x12121112, 0x12121211, 0x10101020, 0x10101021, 0x10101022, 0x10101120, 0x10101122,
- 0x10101220, 0x10101221, 0x10111021, 0x10111120, 0x10111121, 0x10111220, 0x10111221, 0x10121020,
- 0x10121021, 0x10121022, 0x10121120, 0x10121121, 0x10121122, 0x10121220, 0x10121221, 0x11101021,
- 0x11101121, 0x11101122, 0x11101220, 0x11101221, 0x11101222, 0x11111020, 0x11111021, 0x11111022,
- 0x11111120, 0x11111121, 0x11111122, 0x11111220, 0x11111221, 0x11111222, 0x11121021, 0x11121120,
- 0x11121121, 0x11121221, 0x12101022, 0x12101121, 0x12101122, 0x12101220, 0x12101221, 0x12101222,
- 0x12111021, 0x12111121, 0x12111222, 0x12121022, 0x12121121, 0x12121122, 0x12121220, 0x12121221,
- 0x10102100, 0x10102101, 0x10102102, 0x10102201, 0x10112000, 0x10112101, 0x10112200, 0x10122001,
- 0x10122202, 0x11102101, 0x11102200, 0x11102202, 0x11112001, 0x11112100, 0x11112101, 0x11112102,
- 0x11112200, 0x11112201, 0x11122000, 0x11122002, 0x11122100, 0x11122101, 0x12102002, 0x12102201,
- 0x12112000, 0x12112002, 0x12112101, 0x12112200, 0x12122001, 0x12122201, 0x10102011, 0x10102012,
- 0x10102111, 0x10102212, 0x10112011, 0x10112110, 0x10112111, 0x10112112, 0x10112211, 0x10122111,
- 0x11102011, 0x11102110, 0x11102111, 0x11102112, 0x11102211, 0x11112010, 0x11112011, 0x11112012,
- 0x11112110, 0x11112111, 0x11112112, 0x11112210, 0x11112211, 0x11112212, 0x11122011, 0x11122110,
- 0x11122111, 0x11122112, 0x11122211, 0x12102011, 0x12102111, 0x12102211, 0x12112011, 0x12112110,
- 0x12112111, 0x12112112, 0x12112210, 0x12112211, 0x12122111, 0x10102120, 0x10102220, 0x10112121,
- 0x10112222, 0x10122020, 0x10122121, 0x10122122, 0x10122221, 0x11102121, 0x11102220, 0x11102221,
- 0x11112021, 0x11112121, 0x11112122, 0x11112220, 0x11112221, 0x11122022, 0x11122121, 0x11122220,
- 0x11122222, 0x12102021, 0x12102222, 0x12112022, 0x12112121, 0x12112122, 0x12112220, 0x12112222,
- 0x12122021, 0x10200101, 0x10210100, 0x10210102, 0x10210201, 0x10220101, 0x11200100, 0x11210000,
- 0x11210101, 0x11210102, 0x11210200, 0x11210202, 0x11220001, 0x11220100, 0x11220102, 0x11220201,
- 0x12200001, 0x12210102, 0x12220101, 0x10200011, 0x10200110, 0x10200112, 0x10200211, 0x10210012,
- 0x10210111, 0x10220011, 0x10220012, 0x10220112, 0x10220211, 0x11200111, 0x11200211, 0x11210011,
- 0x11210111, 0x11210112, 0x11210211, 0x11220111, 0x11220112, 0x11220212, 0x12200110, 0x12200212,
- 0x12210012, 0x12210111, 0x12220011, 0x12220112, 0x12220211, 0x10210021, 0x10210122, 0x10210221,
- 0x11200020, 0x11200021, 0x11200122, 0x11210121, 0x11210122, 0x11210220, 0x11220020, 0x12200121,
- 0x12210021, 0x12210122, 0x12220121, 0x10211001, 0x10211002, 0x10211101, 0x10211102, 0x10211202,
- 0x10221001, 0x10221102, 0x10221201, 0x11201000, 0x11201002, 0x11201101, 0x11201200, 0x11201202,
- 0x11211001, 0x11211100, 0x11211101, 0x11211102, 0x11211201, 0x11211202, 0x11221000, 0x11221002,
- 0x11221101, 0x12201100, 0x12201101, 0x12201201, 0x12211000, 0x12211002, 0x12211100, 0x12211101,
- 0x12211102, 0x12211200, 0x12211202, 0x12221001, 0x12221100, 0x12221201, 0x10201111, 0x10201210,
- 0x10201212, 0x10211011, 0x10211111, 0x10211112, 0x10211211, 0x11201110, 0x11201111, 0x11201112,
- 0x11201211, 0x11211010, 0x11211011, 0x11211110, 0x11211111, 0x11211112, 0x11211211, 0x11221011,
- 0x11221110, 0x11221111, 0x11221112, 0x11221211, 0x12201112, 0x12201211, 0x12201212, 0x12211011,
- 0x12211111, 0x12211112, 0x12211211, 0x12211212, 0x12221012, 0x12221111, 0x12221112, 0x12221210,
- 0x10201022, 0x10201221, 0x10211121, 0x10221020, 0x10221122, 0x10221220, 0x10221221, 0x11201020,
- 0x11201121, 0x11201220, 0x11201222, 0x11211021, 0x11211120, 0x11211121, 0x11211122, 0x11211220,
- 0x11211222, 0x11221020, 0x11221121, 0x11221220, 0x12201020, 0x12201022, 0x12201121, 0x12201222,
- 0x12211120, 0x12211122, 0x12211220, 0x12211221, 0x12221020, 0x12221120, 0x12221122, 0x12221222,
- 0x10212102, 0x10212201, 0x10222101, 0x11202001, 0x11212002, 0x11212101, 0x11212202, 0x11222001,
- 0x11222201, 0x12202101, 0x12212001, 0x12212200, 0x12222102, 0x10202011, 0x10202110, 0x10212010,
- 0x10212111, 0x10222011, 0x10222110, 0x10222112, 0x10222211, 0x11202010, 0x11202011, 0x11202111,
- 0x11202112, 0x11202210, 0x11212011, 0x11212110, 0x11212111, 0x11212112, 0x11212211, 0x11222010,
- 0x11222111, 0x11222212, 0x12202012, 0x12202110, 0x12202212, 0x12212111, 0x12222011, 0x12222110,
- 0x12222111, 0x12222211, 0x10212021, 0x10212122, 0x10212220, 0x11202021, 0x11202120, 0x11202221,
- 0x11212020, 0x11212121, 0x11212220, 0x11212222, 0x11222120, 0x11222121, 0x11222221, 0x12202122,
- 0x12212120, 0x12212220, 0x12212222, 0x12222122, 0x20000000, 0x20000002, 0x20000200, 0x20000202,
- 0x20020000, 0x20020002, 0x20020200, 0x20020202, 0x21000101, 0x21010000, 0x21010001, 0x21010100,
- 0x21010102, 0x21010201, 0x21020101, 0x22000000, 0x22000002, 0x22000200, 0x22000202, 0x22010101,
- 0x22020000, 0x22020002, 0x22020200, 0x22020202, 0x20000111, 0x20010011, 0x20010110, 0x20010112,
- 0x20010211, 0x20020111, 0x21000011, 0x21000110, 0x21000211, 0x21010010, 0x21010012, 0x21010111,
- 0x21010112, 0x21010210, 0x21010211, 0x21020110, 0x21020112, 0x21020211, 0x22000111, 0x22000211,
- 0x22010110, 0x22010112, 0x22010211, 0x22020111, 0x20000020, 0x20000022, 0x20000220, 0x20000222,
- 0x20010121, 0x20020020, 0x20020022, 0x20020220, 0x20020222, 0x21010021, 0x21010120, 0x21010221,
- 0x21020121, 0x22000020, 0x22000022, 0x22000220, 0x22000222, 0x22010121, 0x22020020, 0x22020022,
- 0x22020220, 0x22020222, 0x20011100, 0x20011201, 0x21001001, 0x21001100, 0x21011001, 0x21011101,
- 0x21011202, 0x21021001, 0x21021100, 0x21021201, 0x22011100, 0x22011201, 0x20001011, 0x20001211,
- 0x20011012, 0x20011111, 0x20011212, 0x20021112, 0x20021211, 0x21001010, 0x21001011, 0x21001111,
- 0x21001210, 0x21011011, 0x21011110, 0x21011111, 0x21011112, 0x21011211, 0x21011212, 0x21021111,
- 0x21021112, 0x21021210, 0x21021212, 0x22001011, 0x22001110, 0x22001112, 0x22001211, 0x22011010,
- 0x22011012, 0x22011111, 0x22011210, 0x22021112, 0x20011021, 0x20011122, 0x20011221, 0x20021121,
- 0x21001021, 0x21001120, 0x21001221, 0x21001222, 0x21011020, 0x21011121, 0x21011221, 0x21011222,
- 0x21021021, 0x21021122, 0x21021222, 0x22001121, 0x22011021, 0x22011222, 0x22021120, 0x20002000,
- 0x20002002, 0x20002200, 0x20002202, 0x20012101, 0x20022000, 0x20022002, 0x20022200, 0x20022202,
- 0x21002001, 0x21002101, 0x21012001, 0x21012100, 0x21012201, 0x21022101, 0x21022201, 0x22002000,
- 0x22002002, 0x22002200, 0x22002202, 0x22012101, 0x22022000, 0x22022002, 0x22022200, 0x22022202,
- 0x20002111, 0x20002112, 0x20012011, 0x20012110, 0x20012112, 0x20022111, 0x21002011, 0x21002110,
- 0x21002112, 0x21002211, 0x21012010, 0x21012012, 0x21012111, 0x21012212, 0x21022011, 0x21022110,
- 0x22002111, 0x22012112, 0x22012211, 0x22022111, 0x20002020, 0x20002022, 0x20002220, 0x20002222,
- 0x20012121, 0x20022020, 0x20022022, 0x20022220, 0x20022222, 0x21002121, 0x21012021, 0x21012120,
- 0x21012122, 0x22002020, 0x22002022, 0x22002220, 0x22002222, 0x22012121, 0x22022020, 0x22022022,
- 0x22022220, 0x22022222, 0x20100101, 0x20110001, 0x20110102, 0x20110200, 0x20110201, 0x20120101,
- 0x21100001, 0x21100102, 0x21100201, 0x21110101, 0x21110200, 0x21110202, 0x21120201, 0x21120202,
- 0x22100101, 0x22110001, 0x22110100, 0x22110102, 0x22110201, 0x22120101, 0x20100011, 0x20100110,
- 0x20100112, 0x20100211, 0x20110010, 0x20110111, 0x20110210, 0x20110212, 0x20120011, 0x20120110,
- 0x20120112, 0x20120211, 0x21100010, 0x21100111, 0x21110010, 0x21110011, 0x21110110, 0x21110111,
- 0x21110112, 0x21110211, 0x21120012, 0x21120111, 0x22100110, 0x22100112, 0x22110012, 0x22110111,
- 0x22110210, 0x22120011, 0x22120110, 0x22120112, 0x22120211, 0x20100121, 0x20110021, 0x20110120,
- 0x20110221, 0x20120121, 0x21100120, 0x21100122, 0x21100221, 0x21110020, 0x21110022, 0x21110121,
- 0x21110220, 0x21120122, 0x21120221, 0x22100121, 0x22110120, 0x22110122, 0x22120221, 0x20101001,
- 0x20101100, 0x20101102, 0x20111000, 0x20111101, 0x20111200, 0x20121102, 0x21101000, 0x21101202,
- 0x21111001, 0x21111100, 0x21111101, 0x21111102, 0x21111200, 0x21111201, 0x21121000, 0x21121001,
- 0x21121002, 0x21121101, 0x22101100, 0x22101102, 0x22111002, 0x22111100, 0x22111101, 0x22111200,
- 0x22121001, 0x22121201, 0x20101010, 0x20101111, 0x20101210, 0x20101212, 0x20111010, 0x20111011,
- 0x20111110, 0x20111111, 0x20111112, 0x20111211, 0x20121011, 0x20121111, 0x20121211, 0x20121212,
- 0x21101011, 0x21101110, 0x21101111, 0x21101112, 0x21101211, 0x21111010, 0x21111011, 0x21111012,
- 0x21111110, 0x21111111, 0x21111112, 0x21111210, 0x21111211, 0x21111212, 0x21121011, 0x21121110,
- 0x21121111, 0x21121112, 0x21121211, 0x22101011, 0x22101111, 0x22101210, 0x22111011, 0x22111012,
- 0x22111110, 0x22111111, 0x22111112, 0x22111211, 0x22111212, 0x22121010, 0x22121012, 0x22121111,
- 0x22121210, 0x22121212, 0x20101021, 0x20101120, 0x20111020, 0x20111121, 0x20111221, 0x20121020,
- 0x20121122, 0x20121221, 0x21101121, 0x21101220, 0x21101221, 0x21111021, 0x21111022, 0x21111121,
- 0x21111122, 0x21111221, 0x21121121, 0x21121220, 0x22101022, 0x22101120, 0x22101221, 0x22101222,
- 0x22111022, 0x22111120, 0x22111121, 0x22121120, 0x22121122, 0x22121221, 0x20102101, 0x20112102,
- 0x20112201, 0x20122101, 0x21102001, 0x21102102, 0x21112000, 0x21112002, 0x21112101, 0x21112102,
- 0x21112202, 0x21122100, 0x21122101, 0x22102101, 0x22112001, 0x22112102, 0x22112201, 0x22122101,
- 0x20102110, 0x20102112, 0x20102211, 0x20112010, 0x20112012, 0x20112111, 0x20112210, 0x20112212,
- 0x20122010, 0x20122011, 0x20122110, 0x20122112, 0x21102010, 0x21102012, 0x21102111, 0x21102210,
- 0x21102212, 0x21112011, 0x21112110, 0x21112111, 0x21112112, 0x21112211, 0x21122012, 0x21122111,
- 0x21122112, 0x21122212, 0x22102011, 0x22102110, 0x22112010, 0x22112012, 0x22112111, 0x22112212,
- 0x22122011, 0x22122112, 0x20102121, 0x20112121, 0x20122121, 0x21102120, 0x21102122, 0x21102221,
- 0x21112020, 0x21112121, 0x21112220, 0x21122021, 0x22102121, 0x22112021, 0x22112120, 0x22112121,
- 0x22112122, 0x20200000, 0x20200002, 0x20200200, 0x20200202, 0x20210101, 0x20220000, 0x20220002,
- 0x20220200, 0x20220202, 0x21200101, 0x21210001, 0x21210100, 0x21210102, 0x21210201, 0x22200000,
- 0x22200002, 0x22200200, 0x22200202, 0x22210101, 0x22220000, 0x22220002, 0x22220200, 0x22220202,
- 0x20200111, 0x20200211, 0x20210011, 0x20210110, 0x20210112, 0x20210211, 0x20210212, 0x21200112,
- 0x21200211, 0x21210011, 0x21210111, 0x21210210, 0x21210212, 0x21220011, 0x21220110, 0x22200111,
- 0x22210010, 0x22210012, 0x22210112, 0x22210211, 0x20200022, 0x20200220, 0x20200222, 0x20210020,
- 0x20210221, 0x20220022, 0x20220220, 0x20220222, 0x21200121, 0x21210021, 0x21210122, 0x21210221,
- 0x21220121, 0x22200020, 0x22200022, 0x22200220, 0x22200222, 0x22210121, 0x22220020, 0x22220022,
- 0x22220220, 0x22220222, 0x20211201, 0x20221101, 0x21201001, 0x21201100, 0x21211000, 0x21211100,
- 0x21211101, 0x21211200, 0x21211202, 0x21221001, 0x21221101, 0x21221102, 0x21221200, 0x21221201,
- 0x22201101, 0x20201112, 0x20201211, 0x20211010, 0x20211012, 0x20211111, 0x20211210, 0x20221112,
- 0x20221211, 0x21201012, 0x21201111, 0x21211011, 0x21211110, 0x21211111, 0x21211112, 0x21211211,
- 0x21221111, 0x21221212, 0x22201011, 0x22201110, 0x22201111, 0x22201112, 0x22201211, 0x22211012,
- 0x22211111, 0x22211210, 0x20201121, 0x20211021, 0x20211122, 0x20211222, 0x20221021, 0x20221121,
- 0x21201120, 0x21201122, 0x21201222, 0x21211022, 0x21211121, 0x21211122, 0x21211220, 0x21221020,
- 0x21221022, 0x22201122, 0x22211020, 0x22211121, 0x22211122, 0x22211221, 0x22221021, 0x22221120,
- 0x22221122, 0x20202000, 0x20202002, 0x20202200, 0x20202202, 0x20222000, 0x20222002, 0x20222200,
- 0x20222202, 0x21212001, 0x21212100, 0x21212102, 0x21212201, 0x22202000, 0x22202002, 0x22202200,
- 0x22202202, 0x22212101, 0x22222000, 0x22222002, 0x22222200, 0x22222202, 0x20202111, 0x20212110,
- 0x20212211, 0x20222011, 0x20222111, 0x21202011, 0x21212010, 0x21212111, 0x21212212, 0x21222011,
- 0x21222112, 0x21222211, 0x22212010, 0x22212112, 0x20202020, 0x20202022, 0x20202220, 0x20202222,
- 0x20222020, 0x20222022, 0x20222220, 0x20222222, 0x21212021, 0x21212120, 0x21212122, 0x22202020,
- 0x22202022, 0x22202220, 0x22202222, 0x22212121, 0x22222020, 0x22222022, 0x22222220, 0x22222222,
-};
-#endif
-
-#if !(defined HAVE_FANCY_SIMD && defined __AVX512VPOPCNTDQ__)
-const uint64_t keven_signs[128] = {
- 0x0101010101010101, 0xff010101010101ff, 0xff0101010101ff01, 0x010101010101ffff,
- 0xff01010101ff0101, 0x0101010101ff01ff, 0x0101010101ffff01, 0xff01010101ffffff,
- 0xff010101ff010101, 0x01010101ff0101ff, 0x01010101ff01ff01, 0xff010101ff01ffff,
- 0x01010101ffff0101, 0xff010101ffff01ff, 0xff010101ffffff01, 0x01010101ffffffff,
- 0xff0101ff01010101, 0x010101ff010101ff, 0x010101ff0101ff01, 0xff0101ff0101ffff,
- 0x010101ff01ff0101, 0xff0101ff01ff01ff, 0xff0101ff01ffff01, 0x010101ff01ffffff,
- 0x010101ffff010101, 0xff0101ffff0101ff, 0xff0101ffff01ff01, 0x010101ffff01ffff,
- 0xff0101ffffff0101, 0x010101ffffff01ff, 0x010101ffffffff01, 0xff0101ffffffffff,
- 0xff01ff0101010101, 0x0101ff01010101ff, 0x0101ff010101ff01, 0xff01ff010101ffff,
- 0x0101ff0101ff0101, 0xff01ff0101ff01ff, 0xff01ff0101ffff01, 0x0101ff0101ffffff,
- 0x0101ff01ff010101, 0xff01ff01ff0101ff, 0xff01ff01ff01ff01, 0x0101ff01ff01ffff,
- 0xff01ff01ffff0101, 0x0101ff01ffff01ff, 0x0101ff01ffffff01, 0xff01ff01ffffffff,
- 0x0101ffff01010101, 0xff01ffff010101ff, 0xff01ffff0101ff01, 0x0101ffff0101ffff,
- 0xff01ffff01ff0101, 0x0101ffff01ff01ff, 0x0101ffff01ffff01, 0xff01ffff01ffffff,
- 0xff01ffffff010101, 0x0101ffffff0101ff, 0x0101ffffff01ff01, 0xff01ffffff01ffff,
- 0x0101ffffffff0101, 0xff01ffffffff01ff, 0xff01ffffffffff01, 0x0101ffffffffffff,
- 0xffff010101010101, 0x01ff0101010101ff, 0x01ff01010101ff01, 0xffff01010101ffff,
- 0x01ff010101ff0101, 0xffff010101ff01ff, 0xffff010101ffff01, 0x01ff010101ffffff,
- 0x01ff0101ff010101, 0xffff0101ff0101ff, 0xffff0101ff01ff01, 0x01ff0101ff01ffff,
- 0xffff0101ffff0101, 0x01ff0101ffff01ff, 0x01ff0101ffffff01, 0xffff0101ffffffff,
- 0x01ff01ff01010101, 0xffff01ff010101ff, 0xffff01ff0101ff01, 0x01ff01ff0101ffff,
- 0xffff01ff01ff0101, 0x01ff01ff01ff01ff, 0x01ff01ff01ffff01, 0xffff01ff01ffffff,
- 0xffff01ffff010101, 0x01ff01ffff0101ff, 0x01ff01ffff01ff01, 0xffff01ffff01ffff,
- 0x01ff01ffffff0101, 0xffff01ffffff01ff, 0xffff01ffffffff01, 0x01ff01ffffffffff,
- 0x01ffff0101010101, 0xffffff01010101ff, 0xffffff010101ff01, 0x01ffff010101ffff,
- 0xffffff0101ff0101, 0x01ffff0101ff01ff, 0x01ffff0101ffff01, 0xffffff0101ffffff,
- 0xffffff01ff010101, 0x01ffff01ff0101ff, 0x01ffff01ff01ff01, 0xffffff01ff01ffff,
- 0x01ffff01ffff0101, 0xffffff01ffff01ff, 0xffffff01ffffff01, 0x01ffff01ffffffff,
- 0xffffffff01010101, 0x01ffffff010101ff, 0x01ffffff0101ff01, 0xffffffff0101ffff,
- 0x01ffffff01ff0101, 0xffffffff01ff01ff, 0xffffffff01ffff01, 0x01ffffff01ffffff,
- 0x01ffffffff010101, 0xffffffffff0101ff, 0xffffffffff01ff01, 0x01ffffffff01ffff,
- 0xffffffffffff0101, 0x01ffffffffff01ff, 0x01ffffffffffff01, 0xffffffffffffffff,
-};
-#endif
-
-}
-
#if defined __x86_64__
namespace {
-inline float hsum_float_4(__m128 x) {
- x = _mm_add_ps(x, _mm_movehl_ps(x, x));
- x = _mm_add_ss(x, _mm_movehdup_ps(x));
- return _mm_cvtss_f32(x);
-}
-inline float hsum_float_8(__m256 x) {
- return hsum_float_4(_mm_add_ps(_mm256_castps256_ps128(x), _mm256_extractf128_ps(x, 1)));
-}
-inline int hsum_i32_8(const __m256i a) {
- const __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(a), _mm256_extractf128_si256(a, 1));
- const __m128i hi64 = _mm_unpackhi_epi64(sum128, sum128);
- const __m128i sum64 = _mm_add_epi32(hi64, sum128);
- const __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1));
- return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32));
-}
-inline float hmax_float_8(__m256 x) {
- __m128 max4 = _mm_max_ps(_mm256_extractf128_ps(x, 1), _mm256_castps256_ps128(x));
- max4 = _mm_max_ps( max4, _mm_movehl_ps(max4, max4));
- max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4));
- return _mm_cvtss_f32(max4);
-}
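Editor's note, not part of the original diff: the helpers above are plain horizontal reductions used to collapse a wide accumulator into one scalar per output element. A minimal standalone sketch of that usage pattern, assuming an AVX2+FMA build (e.g. g++ -O2 -mavx2 -mfma); the function name dot_f32_avx2 is illustrative and does not appear in this code base.

#include <immintrin.h>
#include <cstdio>

static float dot_f32_avx2(const float * a, const float * b, int n) {
    __m256 acc = _mm256_setzero_ps();
    for (int i = 0; i < n; i += 8) {                       // n assumed to be a multiple of 8
        acc = _mm256_fmadd_ps(_mm256_loadu_ps(a + i), _mm256_loadu_ps(b + i), acc);
    }
    // same reduction steps as hsum_float_8/hsum_float_4 above
    __m128 s = _mm_add_ps(_mm256_castps256_ps128(acc), _mm256_extractf128_ps(acc, 1));
    s = _mm_add_ps(s, _mm_movehl_ps(s, s));                // 4 partial sums -> 2
    s = _mm_add_ss(s, _mm_movehdup_ps(s));                 // 2 partial sums -> 1
    return _mm_cvtss_f32(s);
}

int main() {
    float a[8], b[8];
    for (int i = 0; i < 8; ++i) { a[i] = 1.0f; b[i] = float(i); }
    std::printf("%g\n", dot_f32_avx2(a, b, 8));            // prints 28
    return 0;
}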
-
-#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)
-
-template <int nrc, typename block_q8 = block_q8_K> struct Q8 {
-
- constexpr static int nrc_y = nrc;
-
- Q8(const DataInfo& info) {
- for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const block_q8 *)info.src1_row(iy);
- }
-
-#ifdef HAVE_FANCY_SIMD
- inline __m512i load_quants64(int iy, int i, int j) const { return _mm512_loadu_si512((const __m512i*)y[iy][i].qs + j); }
-#endif
- inline __m256i load_quants(int iy, int i, int j) const { return _mm256_loadu_si256((const __m256i*)y[iy][i].qs + j); }
- inline __m256i load_bsums(int iy, int i) const { return _mm256_loadu_si256((const __m256i*)y[iy][i].bsums); }
- inline float scale(int iy, int i) const { return y[iy][i].d; }
-
- const block_q8 * y[nrc_y];
-};
-
-template <int nrc> struct Q8_16 {
-
- constexpr static int nrc_y = nrc;
-
- Q8_16(const DataInfo& info) {
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto ptr = (const float *)info.src1_row(iy);
- std::memcpy(d + 5*iy, ptr, 5*sizeof(float));
- y[iy] = (const int8_t *)(ptr + 5);
- }
- }
-
-#ifdef HAVE_FANCY_SIMD
- inline __m512i load_quants64(int iy, int i) const { return _mm512_loadu_si512((const __m512i*)y[iy] + i); }
-#endif
- inline __m256i load_quants(int iy, int i) const { return _mm256_loadu_si256((const __m256i*)y[iy] + i); }
- inline float scale(int iy, int k) const { return d[5*iy+k]; }
- inline float sum_row(int iy) const { return d[5*iy + 4]; }
- inline __m128 scale(int iy) const { return _mm_loadu_ps(d + 5*iy); }
-
- float d[5*nrc_y];
- const int8_t * y[nrc_y];
-};
-
-struct Scales8KBase {
- template <typename Q8>
- inline void accum_mins(const __m128i& mins128, const Q8& q8, int i, float c, __m256 * accd) const {
- const __m256i mins = MM256_SET_M128I(_mm_shuffle_epi8(mins128, shuffles[1]), _mm_shuffle_epi8(mins128, shuffles[0]));
- for (int iy = 0; iy < Q8::nrc_y; ++iy) {
- const __m256i q8s = q8.load_bsums(iy, i);
- const __m256i prod = _mm256_madd_epi16(mins, q8s);
- accd[iy] = _mm256_fmadd_ps(_mm256_set1_ps(c*q8.scale(iy, i)), _mm256_cvtepi32_ps(prod), accd[iy]);
- }
- }
- inline __m256i shuffle(__m128i mins) const {
- return MM256_SET_M128I(_mm_shuffle_epi8(mins, shuffles[1]), _mm_shuffle_epi8(mins, shuffles[0]));
- }
- const __m128i shuffles[2] = {_mm_set_epi32(0x07060706, 0x05040504, 0x03020302, 0x01000100),
- _mm_set_epi32(0x0f0e0f0e, 0x0d0c0d0c, 0x0b0a0b0a, 0x09080908)};
-};
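Editor's note, not part of the original diff: the accum_mins/load_bsums pairing above is, in essence, the usual k-quant bookkeeping. When a weight is reconstructed as scale*u - min with unsigned quant u, the min term only ever multiplies the sum of the activation quants, so it can be folded in once per block using the precomputed bsums instead of being subtracted from every weight. A scalar sketch of that identity, with made-up toy numbers purely for illustration:

#include <cstdio>

int main() {
    // toy block: 4 activation quants q, 4 unsigned weight quants u, one scale s and min m
    const int   q[4] = { 3, -1, 2, 5 };
    const int   u[4] = { 7,  4, 0, 2 };
    const float s = 0.5f, m = 1.25f;

    float direct = 0.0f;     // sum_i q_i * (s*u_i - m)
    int   dot_qu = 0;        // sum_i q_i * u_i   (the integer dot product the kernels compute)
    int   sum_q  = 0;        // sum_i q_i         (the role played by bsums)
    for (int i = 0; i < 4; ++i) {
        direct += q[i] * (s * u[i] - m);
        dot_qu += q[i] * u[i];
        sum_q  += q[i];
    }
    // the min is applied once per block instead of once per weight
    std::printf("%g == %g\n", direct, s * dot_qu - m * sum_q);   // prints 2.25 == 2.25
    return 0;
}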
-
-// Handles q4_K and q5_K scales/mins
-struct Scales8K {
- template <typename Q8>
- inline __m256i process_mins_and_scales(const uint8_t * data, float c, int i, const Q8& q8, __m256 * accd) {
- make_q4_scales(data, utmp);
- const __m256i mins_and_scales = _mm256_cvtepu8_epi16(_mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0]));
- const __m128i mins128 = _mm256_extracti128_si256(mins_and_scales, 1);
- accum_mins(mins128, q8, i, c, accd);
- const __m128i sc128 = _mm256_extracti128_si256(mins_and_scales, 0);
- return MM256_SET_M128I(sc128, sc128);
- }
-#ifdef HAVE_FANCY_SIMD
- template <typename Q8>
- inline __m512i process_mins_and_scales_64(const uint8_t * data, float c, int i, const Q8& q8, __m256 * accd) {
- auto scales = process_mins_and_scales(data, c, i, q8, accd);
- return _mm512_inserti32x8(_mm512_castsi256_si512(scales), scales, 1);
- }
-#endif
- template <typename Q8>
- inline void accum_mins(const __m128i& mins128, const Q8& q8, int i, float c, __m256 * accd) const {
- base.accum_mins(mins128, q8, i, c, accd);
- }
-#ifdef HAVE_FANCY_SIMD
- const __m512i shuffles512[2] = {
- _mm512_set_epi64(0x0706070607060706, 0x0302030203020302, 0x0706070607060706, 0x0302030203020302,
- 0x0504050405040504, 0x0100010001000100, 0x0504050405040504, 0x0100010001000100),
- _mm512_set_epi64(0x0f0e0f0e0f0e0f0e, 0x0b0a0b0a0b0a0b0a, 0x0f0e0f0e0f0e0f0e, 0x0b0a0b0a0b0a0b0a,
- 0x0d0c0d0c0d0c0d0c, 0x0908090809080908, 0x0d0c0d0c0d0c0d0c, 0x0908090809080908)
- };
-#endif
- Scales8KBase base;
-
- uint32_t utmp[4];
-};
-
-template <typename Q8>
-inline void process_mins_16(const __m256i& all_scales, const Q8& q8, int i, float d, __m256 * accm) {
- for (int iy = 0; iy < Q8::nrc_y; ++iy) {
- const __m256i prod = _mm256_madd_epi16(all_scales, q8.load_bsums(iy, i));
- accm[iy] = _mm256_fmadd_ps(_mm256_set1_ps(d * q8.scale(iy, i)), _mm256_cvtepi32_ps(prod), accm[iy]);
- }
-}
-inline void prepare_scales_16(const __m256i& all_scales, __m256i * scales) {
- const __m128i l_scales = _mm256_extracti128_si256(all_scales, 0);
- const __m128i h_scales = _mm256_extracti128_si256(all_scales, 1);
- scales[0] = MM256_SET_M128I(l_scales, l_scales);
- scales[1] = MM256_SET_M128I(h_scales, h_scales);
-}
-
-struct ScaleQ3 {
- inline __m128i make_scales(const uint16_t * s8) const {
- const uint16_t * scales16 = (const uint16_t *)s8;
- uint32_t aux0 = scales16[0] | (scales16[1] << 16);
- uint32_t aux1 = scales16[2] | (scales16[3] << 16);
- uint32_t aux2 = scales16[4] | (scales16[5] << 16);
- __m128i scales128 = _mm_set_epi32(
- ((aux1 >> 4) & 0x0f0f0f0f) | ((aux2 >> 2) & 0x30303030),
- ((aux0 >> 4) & 0x0f0f0f0f) | ((aux2 >> 0) & 0x30303030),
- (aux1 & 0x0f0f0f0f) | ((aux2 << 2) & 0x30303030),
- (aux0 & 0x0f0f0f0f) | ((aux2 << 4) & 0x30303030));
- return _mm_add_epi8(scales128, m32);
- }
- const __m128i m32 = _mm_set1_epi8(-32);
-};
-
-struct ScaleIQ4XS {
- inline __m128i make_scales(const uint32_t scales_l, const uint16_t scales_h) {
- uint32_t tmp32 = scales_h | (scales_h << 14);
- const __m128i sh = _mm_slli_epi16(_mm_and_si128(_mm_srlv_epi32(_mm_set1_epi32(tmp32), hshift), hmask), 4);
- const __m128i sl = _mm_and_si128(_mm_srlv_epi32(_mm_set1_epi32(scales_l), lshift), lmask);
- return _mm_add_epi16(_mm_or_si128(sh, _mm_cvtepi8_epi16(_mm_shuffle_epi8(sl, lshuffle))), m32);
- }
- const __m128i hshift = _mm_set_epi32(12, 8, 4, 0);
- const __m128i lshift = _mm_set_epi32(4, 0, 4, 0);
- const __m128i hmask = _mm_set1_epi16(0x03);
- const __m128i lmask = _mm_set1_epi8(0xf);
- const __m128i lshuffle = _mm_set_epi32(0x07030602, 0x05010400, 0x07030602, 0x05010400);
- const __m128i m32 = _mm_set1_epi16(-32);
-};
-
-template <typename Block, bool per_row_scale = false, bool is_f16 = false>
-struct BaseDequantizer {
- BaseDequantizer(const void * vx, size_t bx) : vx(vx), bx(bx) {}
- inline void new_row(int ix) {
- if constexpr (per_row_scale) {
- if constexpr (is_f16) {
- const ggml_half * dptr = (const ggml_half *)((const char *)vx + bx*ix);
- d = GGML_FP16_TO_FP32(*dptr);
- x = (const Block *)(dptr + 1);
- } else {
- const float * dptr = (const float *)((const char *)vx + bx*ix);
- d = *dptr;
- x = (const Block *)(dptr + 1);
- }
- } else {
- x = (const Block *)((const char *)vx + bx*ix);
- }
- }
-
- const void * vx;
- const size_t bx;
- const Block * x;
-
- float d;
-};
-
-inline __m256i get_scale_shuffle_8(int i) {
- return _mm256_set1_epi16((2*i) | ((2*i+1) << 8));
-}
-
-inline void set_scales_8(const __m256i& all_scales, int j, __m256i * scales) {
- scales[0] = _mm256_shuffle_epi8(all_scales, get_scale_shuffle_8(4*j+0));
- scales[1] = _mm256_shuffle_epi8(all_scales, get_scale_shuffle_8(4*j+1));
- scales[2] = _mm256_shuffle_epi8(all_scales, get_scale_shuffle_8(4*j+2));
- scales[3] = _mm256_shuffle_epi8(all_scales, get_scale_shuffle_8(4*j+3));
-}
-
-inline __m256i get_scale_shuffle_16(int i) {
- static const uint8_t k_shuffle[128] = {
- 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3,
- 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7,
- 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,
- 12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13, 14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,
- };
- return _mm256_loadu_si256((const __m256i*)k_shuffle + i);
-}
-
-inline void set_scales_16(const __m256i& all_scales, __m256i * scales) {
- scales[0] = _mm256_shuffle_epi8(all_scales, get_scale_shuffle_16(0));
- scales[1] = _mm256_shuffle_epi8(all_scales, get_scale_shuffle_16(1));
- scales[2] = _mm256_shuffle_epi8(all_scales, get_scale_shuffle_16(2));
- scales[3] = _mm256_shuffle_epi8(all_scales, get_scale_shuffle_16(3));
-}
-
-template <typename Q8, typename Bits>
-inline void multiply_add(const Bits& bits, const __m256i * scales, int j, int i, const Q8& q8, __m256i * sumi) {
- if (j == 0) {
-#ifdef HAVE_FANCY_SIMD
- for (int iy = 0; iy < Q8::nrc_y; ++iy) {
- sumi[iy] = _mm256_dpwssd_epi32(_mm256_setzero_si256(), scales[0], _mm256_maddubs_epi16(bits.values[0], q8.load_quants(iy, i, 0)));
- sumi[iy] = _mm256_dpwssd_epi32(sumi[iy], scales[1], _mm256_maddubs_epi16(bits.values[1], q8.load_quants(iy, i, 1)));
- sumi[iy] = _mm256_dpwssd_epi32(sumi[iy], scales[2], _mm256_maddubs_epi16(bits.values[2], q8.load_quants(iy, i, 2)));
- sumi[iy] = _mm256_dpwssd_epi32(sumi[iy], scales[3], _mm256_maddubs_epi16(bits.values[3], q8.load_quants(iy, i, 3)));
- }
-#else
- for (int iy = 0; iy < Q8::nrc_y; ++iy) {
- const __m256i p1 = _mm256_madd_epi16(scales[0], _mm256_maddubs_epi16(bits.values[0], q8.load_quants(iy, i, 0)));
- const __m256i p2 = _mm256_madd_epi16(scales[1], _mm256_maddubs_epi16(bits.values[1], q8.load_quants(iy, i, 1)));
- const __m256i p3 = _mm256_madd_epi16(scales[2], _mm256_maddubs_epi16(bits.values[2], q8.load_quants(iy, i, 2)));
- const __m256i p4 = _mm256_madd_epi16(scales[3], _mm256_maddubs_epi16(bits.values[3], q8.load_quants(iy, i, 3)));
- sumi[iy] = _mm256_add_epi32(_mm256_add_epi32(p1, p3), _mm256_add_epi32(p2, p4));
- }
-#endif
- } else {
-#ifdef HAVE_FANCY_SIMD
- for (int iy = 0; iy < Q8::nrc_y; ++iy) {
- sumi[iy] = _mm256_dpwssd_epi32(sumi[iy], scales[0], _mm256_maddubs_epi16(bits.values[0], q8.load_quants(iy, i, 4)));
- sumi[iy] = _mm256_dpwssd_epi32(sumi[iy], scales[1], _mm256_maddubs_epi16(bits.values[1], q8.load_quants(iy, i, 5)));
- sumi[iy] = _mm256_dpwssd_epi32(sumi[iy], scales[2], _mm256_maddubs_epi16(bits.values[2], q8.load_quants(iy, i, 6)));
- sumi[iy] = _mm256_dpwssd_epi32(sumi[iy], scales[3], _mm256_maddubs_epi16(bits.values[3], q8.load_quants(iy, i, 7)));
- }
-#else
- for (int iy = 0; iy < Q8::nrc_y; ++iy) {
- const __m256i p1 = _mm256_madd_epi16(scales[0], _mm256_maddubs_epi16(bits.values[0], q8.load_quants(iy, i, 4)));
- const __m256i p2 = _mm256_madd_epi16(scales[1], _mm256_maddubs_epi16(bits.values[1], q8.load_quants(iy, i, 5)));
- const __m256i p3 = _mm256_madd_epi16(scales[2], _mm256_maddubs_epi16(bits.values[2], q8.load_quants(iy, i, 6)));
- const __m256i p4 = _mm256_madd_epi16(scales[3], _mm256_maddubs_epi16(bits.values[3], q8.load_quants(iy, i, 7)));
- sumi[iy] = _mm256_add_epi32(sumi[iy], _mm256_add_epi32(p1, p3));
- sumi[iy] = _mm256_add_epi32(sumi[iy], _mm256_add_epi32(p2, p4));
- }
-#endif
- }
-}
-
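-// Variant for quant types whose unpacked values may be negative: _mm256_maddubs_epi16
-// needs an unsigned first operand, so |bits.values[k]| is used and the sign is moved
-// onto the q8 quants with _mm256_sign_epi8.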
-template <typename Q8, typename Bits>
-inline void multiply_add_avx2(const Bits& bits, const __m256i * scales, int j, int i, const Q8& q8, __m256i * sumi) {
- __m256i p[4];
- if (j == 0) {
- for (int iy = 0; iy < Q8::nrc_y; ++iy) {
- for (int k = 0; k < 4; ++k) {
- auto s = _mm256_sign_epi8(bits.values[k], bits.values[k]);
- p[k] = _mm256_madd_epi16(scales[k], _mm256_maddubs_epi16(s, _mm256_sign_epi8(q8.load_quants(iy, i, k), bits.values[k])));
- }
- sumi[iy] = _mm256_add_epi32(_mm256_add_epi32(p[0], p[1]), _mm256_add_epi32(p[2], p[3]));
- }
- } else {
- for (int iy = 0; iy < Q8::nrc_y; ++iy) {
- for (int k = 0; k < 4; ++k) {
- auto s = _mm256_sign_epi8(bits.values[k], bits.values[k]);
- p[k] = _mm256_madd_epi16(scales[k], _mm256_maddubs_epi16(s, _mm256_sign_epi8(q8.load_quants(iy, i, 4+k), bits.values[k])));
- }
- sumi[iy] = _mm256_add_epi32(sumi[iy], _mm256_add_epi32(p[0], p[2]));
- sumi[iy] = _mm256_add_epi32(sumi[iy], _mm256_add_epi32(p[1], p[3]));
- }
- }
-}
-
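-// Turns the packed sign bits used by the i-quants into per-byte signs: on AVX512 the
-// bits are used directly as a write mask for _mm256_mask_sub_epi8, on AVX2 they are
-// expanded into a +1/-1 vector and applied with _mm256_sign_epi8.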
-struct SignHelper {
- inline __m256i make_signs(uint32_t sign_bits) const {
- auto aux256 = _mm256_set1_epi32(sign_bits);
- aux256 = _mm256_and_si256(_mm256_shuffle_epi8(aux256, mask1), mask2);
- return _mm256_or_si256(_mm256_cmpeq_epi8(aux256, mask2), mone);
- }
-// inline __m256i make_signs(const uint16_t * sign_bits) const {
-//#ifdef HAVE_FANCY_SIMD
-//#else
-// return make_signs(sign_bits[0] | (sign_bits[1] << 16));
-//#endif
-// }
- inline __m256i sign_value(const uint16_t * sign_bits, const __m256i& value) const {
-#ifdef HAVE_FANCY_SIMD
- const __mmask32 * mask = (const __mmask32 *)sign_bits;
- return _mm256_mask_sub_epi8(value, mask[0], _mm256_setzero_si256(), value);
-#else
- return _mm256_sign_epi8(value, make_signs(sign_bits[0] | (sign_bits[1] << 16)));
-#endif
- }
- inline void sign_4_values(const uint16_t * sign_bits, __m256i * values) const {
-#ifdef HAVE_FANCY_SIMD
- const __mmask32 * mask = (const __mmask32 *)sign_bits;
- values[0] = _mm256_mask_sub_epi8(values[0], mask[0], _mm256_setzero_si256(), values[0]);
- values[1] = _mm256_mask_sub_epi8(values[1], mask[1], _mm256_setzero_si256(), values[1]);
- values[2] = _mm256_mask_sub_epi8(values[2], mask[2], _mm256_setzero_si256(), values[2]);
- values[3] = _mm256_mask_sub_epi8(values[3], mask[3], _mm256_setzero_si256(), values[3]);
-#else
- auto s128 = _mm_loadu_si128((const __m128i *)sign_bits);
- auto s256 = MM256_SET_M128I(s128, s128);
- __m256i aux256;
- auto shuffle = mask1;
- auto step = _mm256_set1_epi8(4);
- aux256 = _mm256_and_si256(_mm256_shuffle_epi8(s256, shuffle), mask2); shuffle = _mm256_add_epi8(shuffle, step);
- values[0] = _mm256_sign_epi8(values[0], _mm256_or_si256(_mm256_cmpeq_epi8(aux256, mask2), mone));
- aux256 = _mm256_and_si256(_mm256_shuffle_epi8(s256, shuffle), mask2); shuffle = _mm256_add_epi8(shuffle, step);
- values[1] = _mm256_sign_epi8(values[1], _mm256_or_si256(_mm256_cmpeq_epi8(aux256, mask2), mone));
- aux256 = _mm256_and_si256(_mm256_shuffle_epi8(s256, shuffle), mask2); shuffle = _mm256_add_epi8(shuffle, step);
- values[2] = _mm256_sign_epi8(values[2], _mm256_or_si256(_mm256_cmpeq_epi8(aux256, mask2), mone));
- aux256 = _mm256_and_si256(_mm256_shuffle_epi8(s256, shuffle), mask2); shuffle = _mm256_add_epi8(shuffle, step);
- values[3] = _mm256_sign_epi8(values[3], _mm256_or_si256(_mm256_cmpeq_epi8(aux256, mask2), mone));
-#endif
- }
- const __m256i mask1 = _mm256_set_epi64x(0x0303030303030303, 0x0202020202020202, 0x0101010101010101, 0x0000000000000000);
- const __m256i mask2 = _mm256_set1_epi64x(0x8040201008040201ull);
- const __m256i mone = _mm256_set1_epi8(1);
-};
-
-struct SimpleBits {
- __m256i values[4];
-};
-
-inline __m128i load_iq4nl_values_128() {
- static const uint8_t kvalues_iq4nl[16] = {1, 24, 45, 63, 79, 93, 106, 118, 129, 141, 153, 166, 181, 197, 217, 241};
- return _mm_loadu_si128((const __m128i *)kvalues_iq4nl);
-}
-
-inline __m256i load_iq4nl_values_256() {
- auto val128 = load_iq4nl_values_128();
- return MM256_SET_M128I(val128, val128);
-}
-
-inline __m128i load_iq4k_values_128() {
- return _mm_loadu_si128((const __m128i *)iq4k_values);
-}
-
-inline __m256i load_iq4k_values_256() {
- auto val128 = load_iq4k_values_128();
- return MM256_SET_M128I(val128, val128);
-}
-
-#ifdef HAVE_FANCY_SIMD
-//====================================== Zen4 ==================================================
-
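-// 512-bit unpacking helpers. Q4Bits/Q2Bits extract the 4-bit (2-bit) quants of one
-// 256-quant super-block into values[0..3]; the qword permutes rearrange the extracted
-// low/high parts so the quants line up with the ordering of the q8_K data.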
-struct BlockPermuter {
- const __m512i permute1 = _mm512_set_epi64(11, 10, 9, 8, 3, 2, 1, 0);
- const __m512i permute2 = _mm512_set_epi64(15, 14, 13, 12, 7, 6, 5, 4);
-};
-
-struct Q4Bits {
- inline void prepare(const uint8_t * q4) {
- auto q4bits = _mm512_loadu_si512((const __m512i*)q4 + 0);
- auto tmp1 = _mm512_and_si512(q4bits, ml);
- auto tmp2 = _mm512_and_si512(_mm512_srli_epi16(q4bits, 4), ml);
- values[0] = _mm512_permutex2var_epi64(tmp1, perm.permute1, tmp2);
- values[1] = _mm512_permutex2var_epi64(tmp1, perm.permute2, tmp2);
- q4bits = _mm512_loadu_si512((const __m512i*)q4 + 1);
- tmp1 = _mm512_and_si512(q4bits, ml);
- tmp2 = _mm512_and_si512(_mm512_srli_epi16(q4bits, 4), ml);
- values[2] = _mm512_permutex2var_epi64(tmp1, perm.permute1, tmp2);
- values[3] = _mm512_permutex2var_epi64(tmp1, perm.permute2, tmp2);
- }
- inline void prepare64(const uint8_t * q4) {
- auto q4bits = _mm512_loadu_si512((const __m512i*)q4 + 0);
- values[0] = _mm512_and_si512(q4bits, ml);
- values[1] = _mm512_and_si512(_mm512_srli_epi16(q4bits, 4), ml);
- q4bits = _mm512_loadu_si512((const __m512i*)q4 + 1);
- values[2] = _mm512_and_si512(q4bits, ml);
- values[3] = _mm512_and_si512(_mm512_srli_epi16(q4bits, 4), ml);
- }
- inline void prepare64a(const uint8_t * q4) {
- for (int k = 0; k < 4; ++k) {
- auto q4bits = _mm256_loadu_si256((const __m256i*)q4 + k);
- values[k] = _mm512_inserti32x8(_mm512_castsi256_si512(q4bits), _mm256_srli_epi16(q4bits, 4), 1);
- values[k] = _mm512_and_si512(values[k], ml);
- }
- }
- __m512i values[4];
- const __m512i ml = _mm512_set1_epi8(0xf);
- BlockPermuter perm;
-};
-
-struct Q2Bits {
- inline void prepare(const uint8_t * q2) {
-
- auto q2bits = _mm512_loadu_si512((const __m512i*)q2);
- auto tmp = _mm512_srli_epi16(q2bits, 2);
-
- values[0] = _mm512_permutex2var_epi64(q2bits, perm.permute1, tmp);
- values[2] = _mm512_permutex2var_epi64(q2bits, perm.permute2, tmp);
- values[1] = _mm512_and_si512(_mm512_srli_epi16(values[0], 4), ml);
- values[3] = _mm512_and_si512(_mm512_srli_epi16(values[2], 4), ml);
- values[0] = _mm512_and_si512(values[0], ml);
- values[2] = _mm512_and_si512(values[2], ml);
- }
- __m512i values[4];
- const __m512i ml = _mm512_set1_epi8(0x03);
- BlockPermuter perm;
-};
-
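-// One Dequantizer per quant type. new_block() unpacks super-block i into bits.values[]
-// and prepares the per-sub-block scales; the mins (and any constant offset of the
-// unpacked values) are folded into the caller's __m256 accumulators via the q8 block sums.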
-struct DequantizerQ4K final : public BaseDequantizer<block_q4_K> {
- DequantizerQ4K(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}
- template <typename Q8>
- inline void new_block(int i, const Q8& q8, __m256 * accd, __m512i * scales) {
- d = GGML_FP16_TO_FP32(x[i].d);
- bits.prepare(x[i].qs);
- auto all_scales = s8k.process_mins_and_scales_64(x[i].scales, -GGML_FP16_TO_FP32(x[i].dmin), i, q8, accd);
- scales[0] = _mm512_shuffle_epi8(all_scales, s8k.shuffles512[0]);
- scales[1] = _mm512_shuffle_epi8(all_scales, s8k.shuffles512[1]);
- }
-
- Q4Bits bits;
- Scales8K s8k;
-};
-
-inline __m512i load_iq4nl_values_512() {
- auto val256 = load_iq4nl_values_256();
- return _mm512_inserti32x8(_mm512_castsi256_si512(val256), val256, 1);
-}
-
-
-struct DequantizerIQ4XS final : public BaseDequantizer<block_iq4_xs> {
- DequantizerIQ4XS(const void * vx, size_t bx) : BaseDequantizer(vx, bx), values(load_iq4nl_values_512()) {}
- template <typename Q8>
- inline void new_block(int i, const Q8& q8, __m256 * accd, __m512i * scales) {
- d = GGML_FP16_TO_FP32(x[i].d);
- prepare(x[i].qs);
- auto scales128 = siq4.make_scales(*(const uint32_t *)x[i].scales_l, x[i].scales_h);
- s8k.accum_mins(scales128, q8, i, -128.f*d, accd);
- auto scales256 = MM256_SET_M128I(scales128, scales128);
- auto all_scales = _mm512_inserti32x8(_mm512_castsi256_si512(scales256), scales256, 1);
- scales[0] = _mm512_shuffle_epi8(all_scales, shuffles[0]);
- scales[1] = _mm512_shuffle_epi8(all_scales, shuffles[1]);
- scales[2] = _mm512_shuffle_epi8(all_scales, shuffles[2]);
- scales[3] = _mm512_shuffle_epi8(all_scales, shuffles[3]);
- }
- inline void prepare(const uint8_t * q4) {
- bits.prepare64(q4);
-        // We now have in bits.values[0]: 0...15, 32...47, 64...79, 96...111
-        //                bits.values[1]: 16...31, 48...63, 80...95, 112...127
-        //                etc.
- auto tmp = _mm512_permutex2var_epi64(bits.values[0], permute1, bits.values[1]);
- bits.values[1] = _mm512_shuffle_epi8(values, _mm512_permutex2var_epi64(bits.values[0], permute2, bits.values[1]));
- bits.values[0] = _mm512_shuffle_epi8(values, tmp);
- tmp = _mm512_permutex2var_epi64(bits.values[2], permute1, bits.values[3]);
- bits.values[3] = _mm512_shuffle_epi8(values, _mm512_permutex2var_epi64(bits.values[2], permute2, bits.values[3]));
- bits.values[2] = _mm512_shuffle_epi8(values, tmp);
- }
-
- Q4Bits bits;
- Scales8KBase s8k;
- ScaleIQ4XS siq4;
- const __m512i values;
- const __m512i permute1 = _mm512_set_epi64(11, 10, 3, 2, 9, 8, 1, 0);
- const __m512i permute2 = _mm512_set_epi64(15, 14, 7, 6, 13, 12, 5, 4);
- const __m512i shuffles[4] = {
- _mm512_inserti32x8(_mm512_set1_epi16(0x0100), _mm256_set1_epi16(0x0302), 1),
- _mm512_inserti32x8(_mm512_set1_epi16(0x0504), _mm256_set1_epi16(0x0706), 1),
- _mm512_inserti32x8(_mm512_set1_epi16(0x0908), _mm256_set1_epi16(0x0b0a), 1),
- _mm512_inserti32x8(_mm512_set1_epi16(0x0d0c), _mm256_set1_epi16(0x0f0e), 1),
- };
-};
-
-struct HighBit5 {
- inline void apply(const uint8_t * h, Q4Bits& bits) {
- auto hbits256 = _mm256_loadu_si256((const __m256i *)h);
- auto hbits = _mm512_inserti32x8(_mm512_castsi256_si512(hbits256), _mm256_srli_epi16(hbits256, 1), 1);
- bits.values[0] = _mm512_or_si512(bits.values[0], _mm512_and_si512(_mm512_slli_epi16(hbits, 4), mh));
- bits.values[1] = _mm512_or_si512(bits.values[1], _mm512_and_si512(_mm512_slli_epi16(hbits, 2), mh));
- bits.values[2] = _mm512_or_si512(bits.values[2], _mm512_and_si512(hbits, mh));
- bits.values[3] = _mm512_or_si512(bits.values[3], _mm512_and_si512(_mm512_srli_epi16(hbits, 2), mh));
- }
- const __m512i mh = _mm512_set1_epi8(0x10);
-};
-
-struct HighBit3 {
- inline void apply(const uint8_t * h, Q2Bits& bits) {
- auto hbits256 = _mm256_loadu_si256((const __m256i *)h);
- auto hbits = _mm512_inserti32x8(_mm512_castsi256_si512(hbits256), _mm256_srli_epi16(hbits256, 1), 1);
- bits.values[0] = _mm512_or_si512(bits.values[0], _mm512_and_si512(_mm512_slli_epi16(hbits, 2), mh));
- bits.values[1] = _mm512_or_si512(bits.values[1], _mm512_and_si512(hbits, mh));
- bits.values[2] = _mm512_or_si512(bits.values[2], _mm512_and_si512(_mm512_srli_epi16(hbits, 2), mh));
- bits.values[3] = _mm512_or_si512(bits.values[3], _mm512_and_si512(_mm512_srli_epi16(hbits, 4), mh));
- }
- const __m512i mh = _mm512_set1_epi8(0x04);
-};
-
-struct DequantizerQ5K final : public BaseDequantizer<block_q5_K> {
- DequantizerQ5K(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}
- template <typename Q8>
- inline void new_block(int i, const Q8& q8, __m256 * accd, __m512i * scales) {
- d = GGML_FP16_TO_FP32(x[i].d);
- bits.prepare(x[i].qs);
- hbits.apply(x[i].qh, bits);
- auto all_scales = s8k.process_mins_and_scales_64(x[i].scales, -GGML_FP16_TO_FP32(x[i].dmin), i, q8, accd);
- scales[0] = _mm512_shuffle_epi8(all_scales, s8k.shuffles512[0]);
- scales[1] = _mm512_shuffle_epi8(all_scales, s8k.shuffles512[1]);
- }
-
- Q4Bits bits;
- HighBit5 hbits;
- Scales8K s8k;
-};
-
-struct Scale16 {
- inline void make_scales(const __m128i& scales8, __m512i * scales) const {
- auto all_scales8 = MM256_SET_M128I(scales8, scales8);
- auto scales1 = _mm256_shuffle_epi8(all_scales8, shuffle1);
- auto scales2 = _mm256_shuffle_epi8(all_scales8, shuffle2);
- scales[0] = _mm512_cvtepi8_epi16(scales1);
- scales[1] = _mm512_cvtepi8_epi16(scales2);
- }
- template <typename Q8>
- inline void process_mins_and_scales(int i, float c, const __m128i& mins8, const __m128i& scales8,
- const Q8& q8, __m256 * accm, __m512i * scales) const {
- process_mins_16(_mm256_cvtepi8_epi16(mins8), q8, i, c, accm);
- make_scales(scales8, scales);
- }
- const __m256i shuffle1 = _mm256_set_epi32(0x07070707, 0x03030303, 0x06060606, 0x02020202,
- 0x05050505, 0x01010101, 0x04040404, 0x00000000);
- const __m256i shuffle2 = _mm256_set_epi32(0x0f0f0f0f, 0x0b0b0b0b, 0x0e0e0e0e, 0x0a0a0a0a,
- 0x0d0d0d0d, 0x09090909, 0x0c0c0c0c, 0x08080808);
-};
-
-struct DequantizerQ2K final : public BaseDequantizer<block_q2_K> {
- DequantizerQ2K(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}
- template <typename Q8>
- inline void new_block(int i, const Q8& q8, __m256 * accm, __m512i * scales) {
- d = GGML_FP16_TO_FP32(x[i].d);
- bits.prepare(x[i].qs);
- const __m128i mins_and_scales = _mm_loadu_si128((const __m128i*)x[i].scales);
- const __m128i scales8 = _mm_and_si128(mins_and_scales, m4);
- const __m128i mins8 = _mm_and_si128(_mm_srli_epi16(mins_and_scales, 4), m4);
- sc16.process_mins_and_scales(i, -GGML_FP16_TO_FP32(x[i].dmin), mins8, scales8, q8, accm, scales);
- }
-
- Q2Bits bits;
- Scale16 sc16;
- const __m128i m4 = _mm_set1_epi8(0xf);
-
-};
-
-struct DequantizerQ3K final : public BaseDequantizer<block_q3_K> {
- DequantizerQ3K(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}
- template <typename Q8>
- inline void new_block(int i, const Q8& q8, __m256 * accm, __m512i * scales) {
- d = GGML_FP16_TO_FP32(x[i].d);
- bits.prepare(x[i].qs);
- hbits.apply(x[i].hmask, bits);
- auto scales128 = sc3.make_scales((const uint16_t *)x[i].scales);
- sc16.process_mins_and_scales(i, -4.f*d, scales128, scales128, q8, accm, scales);
- }
-
- Q2Bits bits;
- HighBit3 hbits;
- ScaleQ3 sc3;
- Scale16 sc16;
- const __m128i m4 = _mm_set1_epi8(0xf);
- const __m128i m32 = _mm_set1_epi8(-32);
-};
-
-struct DequantizerQ6K final : public BaseDequantizer<block_q6_K> {
- DequantizerQ6K(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}
- template <typename Q8>
- inline void new_block(int i, const Q8& q8, __m256 * accm, __m512i * scales) {
- d = GGML_FP16_TO_FP32(x[i].d);
- bits.prepare64(x[i].ql);
- add_high_bits(x[i].qh, bits);
- auto scales128 = _mm_loadu_si128((const __m128i *)x[i].scales);
- sc16.process_mins_and_scales(i, -32.f*d, scales128, scales128, q8, accm, scales);
- }
-
- inline void add_high_bits(const uint8_t * qh, Q4Bits& bits) const {
- auto hbits = _mm512_loadu_si512((const __m512i *)qh);
- auto tmp1 = _mm512_and_si512(_mm512_slli_epi16(hbits, 4), mh);
- auto tmp2 = _mm512_and_si512(_mm512_slli_epi16(hbits, 2), mh);
- bits.values[0] = _mm512_or_si512(bits.values[0], _mm512_permutex2var_epi64(tmp1, bits.perm.permute1, tmp2));
- bits.values[2] = _mm512_or_si512(bits.values[2], _mm512_permutex2var_epi64(tmp1, bits.perm.permute2, tmp2));
- tmp1 = _mm512_and_si512(hbits, mh);
- tmp2 = _mm512_and_si512(_mm512_srli_epi16(hbits, 2), mh);
- bits.values[1] = _mm512_or_si512(bits.values[1], _mm512_permutex2var_epi64(tmp1, bits.perm.permute1, tmp2));
- bits.values[3] = _mm512_or_si512(bits.values[3], _mm512_permutex2var_epi64(tmp1, bits.perm.permute2, tmp2));
- }
-
- Q4Bits bits;
- HighBit3 hbits;
- Scale16 sc16;
-
- const __m512i mh = _mm512_set1_epi8(0x30);
-
-};
-
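-// Scale handling shared by the iqN_k quants: the unpacked table values are stored with
-// a constant offset and the per-sub-block "extra" bit adds a further shift; both are
-// applied here through the q8 block sums (min_val, eshift), so the 8-bit dot products
-// below can use the table values as-is.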
-struct IQXKScales {
- IQXKScales(uint8_t shift, int8_t min_val) : eshift(_mm256_set1_epi16(shift)), min(_mm256_set1_epi16(min_val)) {}
- template <typename Q8>
- inline void process(int i, float d, uint16_t extra, __m128i scales8, const Q8& q8, __m256 * accm, __m512i * scales) const {
- auto scales16 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales8, scale_shuffle));
- scales16 = _mm256_mullo_epi16(scales16, _mm256_mask_add_epi16(min, extra, min, eshift));
- for (int iy = 0; iy < Q8::nrc_y; ++iy) {
- const __m256i prod = _mm256_madd_epi16(scales16, q8.load_bsums(iy, i));
- accm[iy] = _mm256_fmadd_ps(_mm256_set1_ps(d * q8.scale(iy, i)), _mm256_cvtepi32_ps(prod), accm[iy]);
- }
- scales16 = MM256_SET_M128I(scales8, scales8);
- scales[0] = _mm512_cvtepi8_epi16(_mm256_shuffle_epi8(scales16, shuffle1));
- scales[1] = _mm512_cvtepi8_epi16(_mm256_shuffle_epi8(scales16, shuffle2));
- }
- const __m256i eshift;
- const __m256i min;
- const __m128i scale_shuffle = _mm_set_epi32(0x0f070e06, 0x0d050c04, 0x0b030a02, 0x09010800);
- const __m128i emask = _mm_set_epi32(0x80804040, 0x20201010, 0x08080404, 0x02020101);
- const __m128i eshuffle = _mm_set_epi32(0x0f0d0b09, 0x07050301, 0x0e0c0a08, 0x06040200);
- const __m256i shuffle1 = _mm256_set_epi64x(0x0b0b0b0b09090909, 0x0303030301010101, 0x0a0a0a0a08080808, 0x0202020200000000);
- const __m256i shuffle2 = _mm256_set_epi64x(0x0f0f0f0f0d0d0d0d, 0x0707070705050505, 0x0e0e0e0e0c0c0c0c, 0x0606060604040404);
-};
-struct IQXKScales2 {
- IQXKScales2(uint8_t shift, int8_t min_val) : eshift(_mm256_set1_epi16(shift)), min(_mm256_set1_epi16(min_val)) {}
- template <typename Q8>
- inline void process(int i, float d, uint16_t extra, __m128i scales8, const Q8& q8, __m256 * accm, __m512i * scales) const {
- process(i, d, extra, _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales8, scale_shuffle)), q8, accm, scales);
- }
- template <typename Q8>
- inline void process(int i, float d, uint16_t extra, __m256i scales16, const Q8& q8, __m256 * accm, __m512i * scales) const {
- auto scales_s = _mm256_mullo_epi16(scales16, _mm256_mask_add_epi16(min, extra, min, eshift));
- for (int iy = 0; iy < Q8::nrc_y; ++iy) {
- const __m256i prod = _mm256_madd_epi16(scales_s, q8.load_bsums(iy, i));
- accm[iy] = _mm256_fmadd_ps(_mm256_set1_ps(d * q8.scale(iy, i)), _mm256_cvtepi32_ps(prod), accm[iy]);
- }
- auto aux_1 = MM256_SET_M128I(_mm256_castsi256_si128(scales16), _mm256_castsi256_si128(scales16));
- auto aux_2 = MM256_SET_M128I(_mm256_extracti128_si256(scales16, 1), _mm256_extracti128_si256(scales16, 1));
- auto scales16_1 = _mm512_inserti32x8(_mm512_castsi256_si512(aux_1), aux_1, 1);
- auto scales16_2 = _mm512_inserti32x8(_mm512_castsi256_si512(aux_2), aux_2, 1);
- scales[0] = _mm512_shuffle_epi8(scales16_1, shuffles[0]);
- scales[1] = _mm512_shuffle_epi8(scales16_1, shuffles[1]);
- scales[2] = _mm512_shuffle_epi8(scales16_2, shuffles[0]);
- scales[3] = _mm512_shuffle_epi8(scales16_2, shuffles[1]);
- }
- const __m256i eshift;
- const __m256i min;
- const __m128i scale_shuffle = _mm_set_epi32(0x0f070e06, 0x0d050c04, 0x0b030a02, 0x09010800);
- const __m128i emask = _mm_set_epi32(0x80804040, 0x20201010, 0x08080404, 0x02020101);
- const __m128i eshuffle = _mm_set_epi32(0x0f0d0b09, 0x07050301, 0x0e0c0a08, 0x06040200);
- const __m512i shuffles[2] = {
- _mm512_inserti32x4(_mm512_inserti32x4(_mm512_inserti32x4(_mm512_inserti32x4(_mm512_setzero_si512(),
- _mm_set1_epi16(0x0100), 0), _mm_set1_epi16(0x0302), 1), _mm_set1_epi16(0x0504), 2), _mm_set1_epi16(0x0706), 3),
- _mm512_inserti32x4(_mm512_inserti32x4(_mm512_inserti32x4(_mm512_inserti32x4(_mm512_setzero_si512(),
- _mm_set1_epi16(0x0908), 0), _mm_set1_epi16(0x0b0a), 1), _mm_set1_epi16(0x0d0c), 2), _mm_set1_epi16(0x0f0e), 3)
- };
-};
-
-struct DequantizerIQ2K final : public BaseDequantizer<block_iq2_k> {
- DequantizerIQ2K(const void * vx, size_t bx) : BaseDequantizer(vx, bx), iqxk(IQXKScales(5, -32)), values(load_values()) {}
- template <typename Q8>
- inline void new_block(int i, const Q8& q8, __m256 * accm, __m512i * scales) {
- d = GGML_FP16_TO_FP32(x[i].d);
- prepare(x[i].qs);
- iqxk.process(i, d, x[i].extra, make_scales(x[i].scales), q8, accm, scales);
- }
- inline void prepare(const uint8_t * q2) {
- bits.prepare(q2);
- bits.values[0] = _mm512_shuffle_epi8(values, bits.values[0]);
- bits.values[1] = _mm512_shuffle_epi8(values, bits.values[1]);
- bits.values[2] = _mm512_shuffle_epi8(values, bits.values[2]);
- bits.values[3] = _mm512_shuffle_epi8(values, bits.values[3]);
- }
- static inline __m512i load_values() {
- static const uint8_t kvalues_iq2nl[16] = {1, 19, 33, 49, 0, 0, 0, 0, 6, 24, 38, 54, 0, 0, 0, 0};
- auto val128 = _mm_loadu_si128((const __m128i *)kvalues_iq2nl);
- auto val256 = MM256_SET_M128I(val128, val128);
- return _mm512_inserti32x8(_mm512_castsi256_si512(val256), val256, 1);
- }
- inline __m128i make_scales(const uint8_t * scales_l) const {
- uint64_t aux64; std::memcpy(&aux64, scales_l, 8);
- auto scl = _mm_and_si128(_mm_set_epi64x(aux64 >> 4, aux64), _mm_set1_epi8(0xf));
- return _mm_add_epi8(scl, m8);
- }
- Q2Bits bits;
- const IQXKScales iqxk;
-
- const __m512i values;
- const __m128i m8 = _mm_set1_epi8(-8);
-};
-
-struct DequantizerIQ2KS final : public BaseDequantizer<block_iq2_ks, true, true> {
- DequantizerIQ2KS(const void * vx, size_t bx) : BaseDequantizer(vx, bx), values(load_values()) {}
- template <typename Q8>
- inline void compute_block(int i, const Q8& q8, __m512 * acc) {
- prepare(x[i].qs);
- auto scales128 = make_scales(x[i].scales, x[i].extra >> 8);
- auto shifts = _mm_and_si128(_mm_cmpeq_epi8(_mm_and_si128(_mm_set1_epi8(x[i].extra), hmask), hmask), m5);
- auto mins128 = _mm_mullo_epi16(scales128, _mm_cvtepi8_epi16(_mm_add_epi8(m32, shifts)));
- auto mins = MM256_SET_M128I(_mm_shuffle_epi8(mins128, s8k.shuffles[1]), _mm_shuffle_epi8(mins128, s8k.shuffles[0]));
- auto scales256 = MM256_SET_M128I(scales128, scales128);
- auto all_scales = _mm512_inserti32x8(_mm512_castsi256_si512(scales256), scales256, 1);
- __m512i scales[4];
- for (int k = 0; k < 4; ++k) scales[k] = _mm512_shuffle_epi8(all_scales, shuffles[k]);
- for (int iy = 0; iy < Q8::nrc_y; ++iy) {
- auto q8s = q8.load_bsums(iy, i);
- auto prod = _mm256_madd_epi16(mins, q8s);
- auto sumi = _mm512_inserti32x8(_mm512_setzero_si512(), prod, 0);
- for (int k = 0; k < 4; ++k) {
- auto p = _mm512_maddubs_epi16(bits.values[k], q8.load_quants64(iy, i, k));
- sumi = _mm512_dpwssd_epi32(sumi, p, scales[k]);
- }
- acc[iy] = _mm512_fmadd_ps(_mm512_set1_ps(d*q8.scale(iy, i)), _mm512_cvtepi32_ps(sumi), acc[iy]);
- }
- }
- inline void prepare(const uint8_t * q2) {
- bits.prepare(q2);
- bits.values[0] = _mm512_shuffle_epi8(values, bits.values[0]);
- bits.values[1] = _mm512_shuffle_epi8(values, bits.values[1]);
- bits.values[2] = _mm512_shuffle_epi8(values, bits.values[2]);
- bits.values[3] = _mm512_shuffle_epi8(values, bits.values[3]);
- }
- static inline __m512i load_values() {
- static const uint8_t kvalues_iq2nl[16] = {1, 19, 33, 49, 0, 0, 0, 0, 6, 24, 38, 54, 0, 0, 0, 0};
- auto val128 = _mm_loadu_si128((const __m128i *)kvalues_iq2nl);
- auto val256 = MM256_SET_M128I(val128, val128);
- return _mm512_inserti32x8(_mm512_castsi256_si512(val256), val256, 1);
- }
- inline __m128i make_scales(const uint8_t * scales_l, uint8_t scales_h) const {
- const uint16_t * scales = (const uint16_t *)scales_l;
- uint32_t aux32 = scales[0] | (uint32_t(scales[1]) << 16);
- auto scl = _mm_srlv_epi32(_mm_set1_epi32(aux32), shift);
- scl = _mm_and_si128(_mm_shuffle_epi8(scl, shuffle), _mm_set1_epi8(0xf));
- auto sch = _mm_set1_epi8(scales_h);
- sch = _mm_and_si128(_mm_cmpeq_epi8(_mm_and_si128(sch, hmask), _mm_setzero_si128()), m16);
- return _mm_cvtepi8_epi16(_mm_add_epi8(scl, sch));
- }
- Q2Bits bits;
- Scales8KBase s8k;
-
- const __m512i values;
- const __m128i m16 = _mm_set1_epi8(-16);
- const __m128i m5 = _mm_set1_epi8(5);
- const __m128i m32 = _mm_set1_epi8(-32);
- const __m128i hmask = _mm_set1_epi64x(0x8040201008040201);
- const __m128i shuffle = _mm_set1_epi64x(0x0703060205010400);
- const __m128i shift = _mm_set_epi32(0, 0, 4, 0);
- const __m512i shuffles[4] = {
- _mm512_inserti32x8(_mm512_set1_epi16(0x0100), _mm256_set1_epi16(0x0302), 1),
- _mm512_inserti32x8(_mm512_set1_epi16(0x0504), _mm256_set1_epi16(0x0706), 1),
- _mm512_inserti32x8(_mm512_set1_epi16(0x0908), _mm256_set1_epi16(0x0b0a), 1),
- _mm512_inserti32x8(_mm512_set1_epi16(0x0d0c), _mm256_set1_epi16(0x0f0e), 1),
- };
-};
-
-struct DequantizerIQ3K final : public BaseDequantizer<block_iq3_k> {
- DequantizerIQ3K(const void * vx, size_t bx) : BaseDequantizer(vx, bx), iqxk(4, -64), values(load_values()) {}
- template <typename Q8>
- inline void new_block(int i, const Q8& q8, __m256 * accm, __m512i * scales) {
- d = GGML_FP16_TO_FP32(x[i].d);
- prepare(x[i].qs, x[i].qh);
- iqxk.process(i, d, x[i].extra, make_scales(x[i].scales_h, x[i].scales_l), q8, accm, scales);
- }
- inline void prepare(const uint8_t * q2, const uint8_t * qh) {
- bits.prepare(q2);
- auto h256 = _mm256_loadu_si256((const __m256i *)qh);
- auto hbits = _mm512_inserti32x8(_mm512_castsi256_si512(h256), _mm256_srli_epi16(h256, 1), 1);
- bits.values[0] = _mm512_or_si512(bits.values[0], _mm512_and_si512(_mm512_slli_epi16(hbits, 2), hmask));
- bits.values[1] = _mm512_or_si512(bits.values[1], _mm512_and_si512(hbits, hmask));
- bits.values[2] = _mm512_or_si512(bits.values[2], _mm512_and_si512(_mm512_srli_epi16(hbits, 2), hmask));
- bits.values[3] = _mm512_or_si512(bits.values[3], _mm512_and_si512(_mm512_srli_epi16(hbits, 4), hmask));
- bits.values[0] = _mm512_shuffle_epi8(values, bits.values[0]);
- bits.values[1] = _mm512_shuffle_epi8(values, bits.values[1]);
- bits.values[2] = _mm512_shuffle_epi8(values, bits.values[2]);
- bits.values[3] = _mm512_shuffle_epi8(values, bits.values[3]);
- }
- static inline __m512i load_values() {
- static const uint8_t kvalues_iq3nl[16] = {1, 24, 41, 54, 65, 77, 92, 111, 5, 28, 45, 58, 69, 81, 96, 115};
- auto val128 = _mm_loadu_si128((const __m128i *)kvalues_iq3nl);
- auto val256 = MM256_SET_M128I(val128, val128);
- return _mm512_inserti32x8(_mm512_castsi256_si512(val256), val256, 1);
- }
- inline __m128i make_scales(uint16_t signs, const uint8_t * scales_l) const {
- uint64_t aux64; std::memcpy(&aux64, scales_l, 8);
- auto scl = _mm_and_si128(_mm_set_epi64x(aux64 >> 4, aux64), _mm_set1_epi8(0xf));
- scl = _mm_add_epi8(_mm_slli_epi16(scl, 1), m1);
- const __m128i sc_signs = _mm_cmpeq_epi8(_mm_and_si128(_mm_set1_epi16(signs), sign_mask), sign_mask);
- const __m128i sch = _mm_shuffle_epi8(_mm_or_si128(sc_signs, _mm_set1_epi8(1)), hshuff);
- return _mm_sign_epi8(scl, sch);
- }
- Q2Bits bits;
- const IQXKScales2 iqxk;
-
- const __m512i values;
- const __m512i hmask = _mm512_set1_epi8(4);
- const __m128i m1 = _mm_set1_epi8(1);
- const __m128i sign_mask = _mm_set_epi64x(0x8080404020201010, 0x0808040402020101);
- const __m128i hshuff = _mm_loadu_si128((const __m128i*)k_shuff);
- constexpr static uint8_t k_shuff[16] = {0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15};
-};
-
-struct DequantizerIQ4K final : public BaseDequantizer<block_iq4_k> {
- DequantizerIQ4K(const void * vx, size_t bx) : BaseDequantizer(vx, bx), iqxk(4, -128), values(load_iq4nl_values_512()) {}
- template <typename Q8>
- inline void new_block(int i, const Q8& q8, __m256 * accm, __m512i * scales) {
- d = GGML_FP16_TO_FP32(x[i].d);
- prepare(x[i].qs);
- iqxk.process(i, d, x[i].extra, make_scales(x[i].scales_l, (const uint16_t *)x[i].scales_h), q8, accm, scales);
- }
- inline void prepare(const uint8_t * q4) {
- bits.prepare64(q4);
-        // We now have in bits.values[0]: 0...15, 32...47, 64...79, 96...111
-        //                bits.values[1]: 16...31, 48...63, 80...95, 112...127
-        //                etc.
- auto tmp = _mm512_permutex2var_epi64(bits.values[0], permute1, bits.values[1]);
- bits.values[1] = _mm512_shuffle_epi8(values, _mm512_permutex2var_epi64(bits.values[0], permute2, bits.values[1]));
- bits.values[0] = _mm512_shuffle_epi8(values, tmp);
- tmp = _mm512_permutex2var_epi64(bits.values[2], permute1, bits.values[3]);
- bits.values[3] = _mm512_shuffle_epi8(values, _mm512_permutex2var_epi64(bits.values[2], permute2, bits.values[3]));
- bits.values[2] = _mm512_shuffle_epi8(values, tmp);
- }
- __m128i make_scales(const uint8_t * scales_l, const uint16_t * scales_h) const {
- uint64_t aux64;
- memcpy(&aux64, scales_l, 8);
- auto scl = _mm_and_si128(_mm_set_epi64x(aux64 >> 4, aux64), maskl);
- const uint32_t aux32 = scales_h[0] | (scales_h[1] << 16);
- auto aux = _mm_and_si128(_mm_set_epi32(aux32 >> 2, aux32, aux32 << 2, aux32 << 4), maskh);
- auto sch = _mm_shuffle_epi8(aux, iqxk.scale_shuffle);
- return _mm_add_epi8(_mm_or_si128(scl, sch), m32);
- }
-
- Q4Bits bits;
- const IQXKScales2 iqxk;
- const __m512i values;
- const __m512i permute1 = _mm512_set_epi64(11, 10, 3, 2, 9, 8, 1, 0);
- const __m512i permute2 = _mm512_set_epi64(15, 14, 7, 6, 13, 12, 5, 4);
- const __m128i maskl = _mm_set1_epi8(0xf);
- const __m128i maskh = _mm_set1_epi8(0x30);
- const __m128i m32 = _mm_set1_epi8(-32);
-};
-
-struct DequantizerIQ5K final : public BaseDequantizer<block_iq5_k> {
- DequantizerIQ5K(const void * vx, size_t bx) : BaseDequantizer(vx, bx), iqxk(2, -128) { load_values(values); }
- template <typename Q8>
- inline void new_block(int i, const Q8& q8, __m256 * accm, __m512i * scales) {
- d = GGML_FP16_TO_FP32(x[i].d);
- prepare(x[i].qs, x[i].qh);
- iqxk.process(i, d, x[i].extra, make_scales(x[i].scales_l, (const uint16_t *)x[i].scales_h), q8, accm, scales);
- }
- inline void prepare(const uint8_t * q4, const uint8_t * qh) {
- bits.prepare64(q4);
- auto h256 = _mm256_loadu_si256((const __m256i *)qh);
- auto hbits = _mm512_inserti32x8(_mm512_castsi256_si512(h256), _mm256_srli_epi16(h256, 2), 1);
- auto m1 = _mm512_cmpeq_epi8_mask(_mm512_and_si512(hbits, hmask1), hmask1);
- auto m2 = _mm512_cmpeq_epi8_mask(_mm512_and_si512(hbits, hmask2), hmask2);
- bits.values[0] = _mm512_mask_shuffle_epi8(_mm512_maskz_shuffle_epi8(_knot_mask64(m1), values[0], bits.values[0]), m1, values[1], bits.values[0]);
- bits.values[1] = _mm512_mask_shuffle_epi8(_mm512_maskz_shuffle_epi8(_knot_mask64(m2), values[0], bits.values[1]), m2, values[1], bits.values[1]);
- hbits = _mm512_srli_epi16(hbits, 4);
- m1 = _mm512_cmpeq_epi8_mask(_mm512_and_si512(hbits, hmask1), hmask1);
- m2 = _mm512_cmpeq_epi8_mask(_mm512_and_si512(hbits, hmask2), hmask2);
- bits.values[2] = _mm512_mask_shuffle_epi8(_mm512_maskz_shuffle_epi8(_knot_mask64(m1), values[0], bits.values[2]), m1, values[1], bits.values[2]);
- bits.values[3] = _mm512_mask_shuffle_epi8(_mm512_maskz_shuffle_epi8(_knot_mask64(m2), values[0], bits.values[3]), m2, values[1], bits.values[3]);
-        // We now have in bits.values[0]: 0...31, 64...95
-        //                bits.values[1]: 32...63, 96...127
-        //                etc.
- auto tmp = _mm512_permutex2var_epi64(bits.values[0], permute1, bits.values[1]);
- bits.values[1] = _mm512_permutex2var_epi64(bits.values[0], permute2, bits.values[1]);
- bits.values[0] = tmp;
- tmp = _mm512_permutex2var_epi64(bits.values[2], permute1, bits.values[3]);
- bits.values[3] = _mm512_permutex2var_epi64(bits.values[2], permute2, bits.values[3]);
- bits.values[2] = tmp;
- }
- __m128i make_scales(const uint8_t * scales_l, const uint16_t * scales_h) const {
- uint64_t aux64;
- memcpy(&aux64, scales_l, 8);
- auto scl = _mm_and_si128(_mm_set_epi64x(aux64 >> 4, aux64), maskl);
- const uint32_t aux32 = scales_h[0] | (scales_h[1] << 16);
- auto aux = _mm_and_si128(_mm_set_epi32(aux32 >> 2, aux32, aux32 << 2, aux32 << 4), maskh);
- auto sch = _mm_shuffle_epi8(aux, iqxk.scale_shuffle);
- return _mm_add_epi8(_mm_or_si128(scl, sch), m32);
- }
- static void load_values(__m512i * values) {
- static const uint8_t kvalues_iq5nl[32] = {
- 2, 14, 25, 36, 45, 54, 63, 71, 78, 85, 92, 98, 104, 110, 116, 122, 127,
- 133, 139, 145, 151, 157, 164, 171, 179, 187, 196, 205, 215, 225, 237, 249,
- };
- auto values128_1 = _mm_loadu_si128((const __m128i *)kvalues_iq5nl + 0);
- auto values128_2 = _mm_loadu_si128((const __m128i *)kvalues_iq5nl + 1);
- auto values256_1 = MM256_SET_M128I(values128_1, values128_1);
- auto values256_2 = MM256_SET_M128I(values128_2, values128_2);
- values[0] = _mm512_inserti32x8(_mm512_castsi256_si512(values256_1), values256_1, 1);
- values[1] = _mm512_inserti32x8(_mm512_castsi256_si512(values256_2), values256_2, 1);
- }
-
- Q4Bits bits;
- const IQXKScales2 iqxk;
- __m512i values[2];
- const __m512i hmask1 = _mm512_set1_epi8(1);
- const __m512i hmask2 = _mm512_set1_epi8(2);
- const __m512i permute1 = _mm512_set_epi64(11, 10, 9, 8, 3, 2, 1, 0);
- const __m512i permute2 = _mm512_set_epi64(15, 14, 13, 12, 7, 6, 5, 4);
- const __m128i maskl = _mm_set1_epi8(0xf);
- const __m128i maskh = _mm_set1_epi8(0x30);
- const __m128i m32 = _mm_set1_epi8(-32);
-};
-
-struct DequantizerIQ6K final : public BaseDequantizer<block_iq6_k> {
- DequantizerIQ6K(const void * vx, size_t bx) : BaseDequantizer(vx, bx), iqxk(1, -128) { load_values(values); }
- template <typename Q8>
- inline void new_block(int i, const Q8& q8, __m256 * accm, __m512i * scales) {
- d = GGML_FP16_TO_FP32(x[i].d);
- prepare(x[i].qs, x[i].qh);
- auto scales8 = _mm_loadu_si128((const __m128i*)x[i].scales);
- iqxk.process(i, d, x[i].extra, _mm256_cvtepi8_epi16(scales8), q8, accm, scales);
- }
- inline __m512i make_one(__m512i l, __m512i h) const {
- auto p = _mm512_shuffle_epi8(values[0], l);
- p = _mm512_mask_shuffle_epi8(p, _mm512_cmpeq_epi8_mask(_mm512_and_si512(h, masks[0]), masks[0]), values[1], l);
- p = _mm512_mask_shuffle_epi8(p, _mm512_cmpeq_epi8_mask(_mm512_and_si512(h, masks[1]), masks[1]), values[2], l);
- p = _mm512_mask_shuffle_epi8(p, _mm512_cmpeq_epi8_mask(_mm512_and_si512(h, masks[2]), masks[2]), values[3], l);
- return p;
- }
- inline void prepare(const uint8_t * q4, const uint8_t * qh) {
- bits.prepare64(q4);
- auto h256_1 = _mm256_loadu_si256((const __m256i *)qh + 0);
- auto h256_2 = _mm256_loadu_si256((const __m256i *)qh + 1);
- auto h1 = _mm512_inserti32x8(_mm512_castsi256_si512(h256_1), _mm256_srli_epi16(h256_1, 4), 1);
- auto h2 = _mm512_inserti32x8(_mm512_castsi256_si512(h256_2), _mm256_srli_epi16(h256_2, 4), 1);
- bits.values[0] = make_one(bits.values[0], h1);
- bits.values[1] = make_one(bits.values[1], _mm512_srli_epi16(h1, 2));
- bits.values[2] = make_one(bits.values[2], h2);
- bits.values[3] = make_one(bits.values[3], _mm512_srli_epi16(h2, 2));
-        // We now have in bits.values[0]: 0...31, 64...95
-        //                bits.values[1]: 32...63, 96...127
-        //                etc.
- auto tmp = _mm512_permutex2var_epi64(bits.values[0], permute1, bits.values[1]);
- bits.values[1] = _mm512_permutex2var_epi64(bits.values[0], permute2, bits.values[1]);
- bits.values[0] = tmp;
- tmp = _mm512_permutex2var_epi64(bits.values[2], permute1, bits.values[3]);
- bits.values[3] = _mm512_permutex2var_epi64(bits.values[2], permute2, bits.values[3]);
- bits.values[2] = tmp;
- }
- static void load_values(__m512i * values) {
- static const uint8_t kvalues_iq6nl[64] = {
- 1, 7, 13, 19, 24, 30, 35, 40, 44, 49, 54, 58, 62, 66, 70, 74,
- 77, 81, 84, 88, 91, 94, 97, 100, 103, 106, 109, 112, 115, 117, 120, 123,
- 126, 128, 131, 134, 137, 140, 142, 145, 148, 151, 155, 158, 161, 164, 168, 172,
- 175, 179, 183, 187, 191, 196, 200, 205, 210, 215, 220, 226, 231, 237, 243, 249,
- };
- for (int k = 0; k < 4; ++k) {
- auto values128 = _mm_loadu_si128((const __m128i *)kvalues_iq6nl + k);
- auto values256 = MM256_SET_M128I(values128, values128);
- values[k] = _mm512_inserti32x8(_mm512_castsi256_si512(values256), values256, 1);
- }
- }
-
- Q4Bits bits;
- IQXKScales2 iqxk;
- __m512i values[4];
- __m512i masks[3] = { _mm512_set1_epi8(0x01), _mm512_set1_epi8(0x02), _mm512_set1_epi8(0x03) };
- const __m512i permute1 = _mm512_set_epi64(11, 10, 9, 8, 3, 2, 1, 0);
- const __m512i permute2 = _mm512_set_epi64(15, 14, 13, 12, 7, 6, 5, 4);
-};
-
-struct DequantizerIQ4KS final : public BaseDequantizer<block_iq4_ks, true> {
- DequantizerIQ4KS(const void * vx, size_t bx) : BaseDequantizer(vx, bx), values(load_iq4nl_values_512()) {}
- template <typename Q8>
- inline void new_block(int i, const Q8& q8, __m256 * accm, __m512i * scales) {
- auto scales128 = _mm_cvtepu8_epi16(_mm_loadl_epi64((const __m128i *)x[i].scales));
- auto shifts = _mm_and_si128(_mm_cmpeq_epi16(_mm_and_si128(scales128, m1), m1), m4);
- scales128 = _mm_add_epi16(_mm_and_si128(scales128, mask), m127);
- auto scales_s = _mm_mullo_epi16(scales128, _mm_add_epi16(m128, shifts));
- s8k.accum_mins(scales_s, q8, i, d, accm);
- auto scales256 = MM256_SET_M128I(scales128, scales128);
- auto all_scales = _mm512_inserti32x8(_mm512_castsi256_si512(scales256), scales256, 1);
- scales[0] = _mm512_shuffle_epi8(all_scales, shuffles[0]);
- scales[1] = _mm512_shuffle_epi8(all_scales, shuffles[1]);
- scales[2] = _mm512_shuffle_epi8(all_scales, shuffles[2]);
- scales[3] = _mm512_shuffle_epi8(all_scales, shuffles[3]);
- prepare(x[i].qs);
- }
- template <typename Q8>
- inline void compute_block(int i, const Q8& q8, __m512 * acc) {
- auto scales128 = _mm_cvtepu8_epi16(_mm_loadl_epi64((const __m128i *)x[i].scales));
- auto shifts = _mm_and_si128(_mm_cmpeq_epi16(_mm_and_si128(scales128, m1), m1), m4);
- scales128 = _mm_add_epi16(_mm_and_si128(scales128, mask), m127);
- auto mins128 = _mm_mullo_epi16(scales128, _mm_add_epi16(m128, shifts));
- auto mins = MM256_SET_M128I(_mm_shuffle_epi8(mins128, s8k.shuffles[1]), _mm_shuffle_epi8(mins128, s8k.shuffles[0]));
- auto scales256 = MM256_SET_M128I(scales128, scales128);
- auto all_scales = _mm512_inserti32x8(_mm512_castsi256_si512(scales256), scales256, 1);
- __m512i scales[4];
- for (int k = 0; k < 4; ++k) scales[k] = _mm512_shuffle_epi8(all_scales, shuffles[k]);
- prepare(x[i].qs);
- for (int iy = 0; iy < Q8::nrc_y; ++iy) {
- auto q8s = q8.load_bsums(iy, i);
- auto prod = _mm256_madd_epi16(mins, q8s);
- auto sumi = _mm512_inserti32x8(_mm512_setzero_si512(), prod, 0);
- for (int k = 0; k < 4; ++k) {
- auto p = _mm512_maddubs_epi16(bits.values[k], q8.load_quants64(iy, i, k));
- sumi = _mm512_dpwssd_epi32(sumi, p, scales[k]);
- }
- acc[iy] = _mm512_fmadd_ps(_mm512_set1_ps(d*q8.scale(iy, i)), _mm512_cvtepi32_ps(sumi), acc[iy]);
- }
- }
- inline void prepare(const uint8_t * q4) {
- bits.prepare64(q4);
-        // We now have in bits.values[0]: 0...15, 32...47, 64...79, 96...111
-        //                bits.values[1]: 16...31, 48...63, 80...95, 112...127
-        //                etc.
- auto tmp = _mm512_permutex2var_epi64(bits.values[0], permute1, bits.values[1]);
- bits.values[1] = _mm512_shuffle_epi8(values, _mm512_permutex2var_epi64(bits.values[0], permute2, bits.values[1]));
- bits.values[0] = _mm512_shuffle_epi8(values, tmp);
- tmp = _mm512_permutex2var_epi64(bits.values[2], permute1, bits.values[3]);
- bits.values[3] = _mm512_shuffle_epi8(values, _mm512_permutex2var_epi64(bits.values[2], permute2, bits.values[3]));
- bits.values[2] = _mm512_shuffle_epi8(values, tmp);
- }
-
- Q4Bits bits;
- Scales8KBase s8k;
- const __m512i values;
- const __m512i permute1 = _mm512_set_epi64(11, 10, 3, 2, 9, 8, 1, 0);
- const __m512i permute2 = _mm512_set_epi64(15, 14, 7, 6, 13, 12, 5, 4);
- const __m128i mask = _mm_set1_epi16(254);
- const __m128i m127 = _mm_set1_epi16(-127);
- const __m128i m128 = _mm_set1_epi16(-128);
- const __m128i m1 = _mm_set1_epi16(1);
- const __m128i m4 = _mm_set1_epi16(4);
- const __m512i shuffles[4] = {
- _mm512_inserti32x8(_mm512_set1_epi16(0x0100), _mm256_set1_epi16(0x0302), 1),
- _mm512_inserti32x8(_mm512_set1_epi16(0x0504), _mm256_set1_epi16(0x0706), 1),
- _mm512_inserti32x8(_mm512_set1_epi16(0x0908), _mm256_set1_epi16(0x0b0a), 1),
- _mm512_inserti32x8(_mm512_set1_epi16(0x0d0c), _mm256_set1_epi16(0x0f0e), 1),
- };
-};
-
-struct DequantizerIQ5KS final : public BaseDequantizer<block_iq5_ks, true> {
- DequantizerIQ5KS(const void * vx, size_t bx) : BaseDequantizer(vx, bx) { load_values(values); }
- template <typename Q8>
- inline void new_block(int i, const Q8& q8, __m256 * accm, __m512i * scales) {
- auto scales128 = _mm_cvtepu8_epi16(_mm_loadl_epi64((const __m128i *)x[i].scales));
- auto shifts = _mm_and_si128(_mm_cmpeq_epi16(_mm_and_si128(scales128, m1), m1), m2);
- scales128 = _mm_add_epi16(_mm_and_si128(scales128, mask), m127);
- auto scales_s = _mm_mullo_epi16(scales128, _mm_add_epi16(m128, shifts));
- s8k.accum_mins(scales_s, q8, i, d, accm);
- auto scales256 = MM256_SET_M128I(scales128, scales128);
- auto all_scales = _mm512_inserti32x8(_mm512_castsi256_si512(scales256), scales256, 1);
- scales[0] = _mm512_shuffle_epi8(all_scales, shuffles[0]);
- scales[1] = _mm512_shuffle_epi8(all_scales, shuffles[1]);
- scales[2] = _mm512_shuffle_epi8(all_scales, shuffles[2]);
- scales[3] = _mm512_shuffle_epi8(all_scales, shuffles[3]);
- prepare(x[i].qs, x[i].qh);
- }
- template <typename Q8>
- inline void compute_block(int i, const Q8& q8, __m512 * acc) {
- auto scales128 = _mm_cvtepu8_epi16(_mm_loadl_epi64((const __m128i *)x[i].scales));
- auto shifts = _mm_and_si128(_mm_cmpeq_epi16(_mm_and_si128(scales128, m1), m1), m2);
- scales128 = _mm_add_epi16(_mm_and_si128(scales128, mask), m127);
- auto mins128 = _mm_mullo_epi16(scales128, _mm_add_epi16(m128, shifts));
- auto mins = MM256_SET_M128I(_mm_shuffle_epi8(mins128, s8k.shuffles[1]), _mm_shuffle_epi8(mins128, s8k.shuffles[0]));
- auto scales256 = MM256_SET_M128I(scales128, scales128);
- auto all_scales = _mm512_inserti32x8(_mm512_castsi256_si512(scales256), scales256, 1);
- __m512i scales[4];
- for (int k = 0; k < 4; ++k) scales[k] = _mm512_shuffle_epi8(all_scales, shuffles[k]);
- prepare(x[i].qs, x[i].qh);
- for (int iy = 0; iy < Q8::nrc_y; ++iy) {
- auto q8s = q8.load_bsums(iy, i);
- auto prod = _mm256_madd_epi16(mins, q8s);
- auto sumi = _mm512_inserti32x8(_mm512_setzero_si512(), prod, 0);
- for (int k = 0; k < 4; ++k) {
- auto p = _mm512_maddubs_epi16(bits.values[k], q8.load_quants64(iy, i, k));
- sumi = _mm512_dpwssd_epi32(sumi, p, scales[k]);
- }
- acc[iy] = _mm512_fmadd_ps(_mm512_set1_ps(d*q8.scale(iy, i)), _mm512_cvtepi32_ps(sumi), acc[iy]);
- }
- }
- inline void prepare(const uint8_t * q4, const uint8_t * qh) {
- bits.prepare64a(q4);
- auto h256 = _mm256_loadu_si256((const __m256i *)qh);
- auto hbits = _mm512_inserti32x8(_mm512_castsi256_si512(h256), _mm256_srli_epi16(h256, 1), 1);
- auto m1 = _mm512_cmpeq_epi8_mask(_mm512_and_si512(hbits, hmask1), hmask1);
- auto m2 = _mm512_cmpeq_epi8_mask(_mm512_and_si512(hbits, hmask2), hmask2);
- bits.values[0] = _mm512_mask_shuffle_epi8(_mm512_maskz_shuffle_epi8(_knot_mask64(m1), values[0], bits.values[0]), m1, values[1], bits.values[0]);
- bits.values[1] = _mm512_mask_shuffle_epi8(_mm512_maskz_shuffle_epi8(_knot_mask64(m2), values[0], bits.values[1]), m2, values[1], bits.values[1]);
- hbits = _mm512_srli_epi16(hbits, 4);
- m1 = _mm512_cmpeq_epi8_mask(_mm512_and_si512(hbits, hmask1), hmask1);
- m2 = _mm512_cmpeq_epi8_mask(_mm512_and_si512(hbits, hmask2), hmask2);
- bits.values[2] = _mm512_mask_shuffle_epi8(_mm512_maskz_shuffle_epi8(_knot_mask64(m1), values[0], bits.values[2]), m1, values[1], bits.values[2]);
- bits.values[3] = _mm512_mask_shuffle_epi8(_mm512_maskz_shuffle_epi8(_knot_mask64(m2), values[0], bits.values[3]), m2, values[1], bits.values[3]);
- }
- static void load_values(__m512i * values) {
- static const uint8_t kvalues_iq5nl[32] = {
- 2, 14, 25, 36, 45, 54, 63, 71, 78, 85, 92, 98, 104, 110, 116, 122, 127,
- 133, 139, 145, 151, 157, 164, 171, 179, 187, 196, 205, 215, 225, 237, 249,
- };
- auto values128_1 = _mm_loadu_si128((const __m128i *)kvalues_iq5nl + 0);
- auto values128_2 = _mm_loadu_si128((const __m128i *)kvalues_iq5nl + 1);
- auto values256_1 = MM256_SET_M128I(values128_1, values128_1);
- auto values256_2 = MM256_SET_M128I(values128_2, values128_2);
- values[0] = _mm512_inserti32x8(_mm512_castsi256_si512(values256_1), values256_1, 1);
- values[1] = _mm512_inserti32x8(_mm512_castsi256_si512(values256_2), values256_2, 1);
- }
-
- Q4Bits bits;
- Scales8KBase s8k;
- __m512i values[2];
- const __m512i hmask1 = _mm512_set1_epi8(1);
- const __m512i hmask2 = _mm512_set1_epi8(4);
- const __m128i m127 = _mm_set1_epi16(-127);
- const __m128i m128 = _mm_set1_epi16(-128);
- const __m128i mask = _mm_set1_epi16(254);
- const __m128i m1 = _mm_set1_epi16(1);
- const __m128i m2 = _mm_set1_epi16(2);
- const __m512i shuffles[4] = {
- _mm512_inserti32x8(_mm512_set1_epi16(0x0100), _mm256_set1_epi16(0x0302), 1),
- _mm512_inserti32x8(_mm512_set1_epi16(0x0504), _mm256_set1_epi16(0x0706), 1),
- _mm512_inserti32x8(_mm512_set1_epi16(0x0908), _mm256_set1_epi16(0x0b0a), 1),
- _mm512_inserti32x8(_mm512_set1_epi16(0x0d0c), _mm256_set1_epi16(0x0f0e), 1),
- };
-};
-
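-// iq4_kss packs one bit of the block scales into bit 0 of every 16-bit word of qs;
-// the remaining bits are encoded such that (b & 0xfffe) ^ ((b & 0xfffe) >> 1) recovers
-// the 4-bit quants. new_block() below undoes both parts.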
-struct DequantizerIQ4KSS final : public BaseDequantizer<block_iq4_kss, true> {
- DequantizerIQ4KSS(const void * vx, size_t bx) : BaseDequantizer(vx, bx), values(load_iq4nl_values_512()) {}
- template <typename Q8>
- inline void new_block(int i, const Q8& q8, __m256 * accm, __m512i * scales) {
- uint32_t aux32[2];
- auto b1 = _mm512_loadu_si512((const __m512i *)x[i].qs + 0);
- auto b2 = _mm512_loadu_si512((const __m512i *)x[i].qs + 1);
- auto bs1 = _mm512_and_si512(b1, mask15);
- bs1 = _mm512_xor_si512(bs1, _mm512_srli_epi16(bs1, 1));
- auto bs2 = _mm512_and_si512(b2, mask15);
- bs2 = _mm512_xor_si512(bs2, _mm512_srli_epi16(bs2, 1));
- bits.values[0] = _mm512_and_si512(bs1, bits.ml);
- bits.values[1] = _mm512_and_si512(_mm512_srli_epi16(bs1, 4), bits.ml);
- bits.values[2] = _mm512_and_si512(bs2, bits.ml);
- bits.values[3] = _mm512_and_si512(_mm512_srli_epi16(bs2, 4), bits.ml);
- auto tmp = _mm512_permutex2var_epi64(bits.values[0], permute1, bits.values[1]);
- bits.values[1] = _mm512_shuffle_epi8(values, _mm512_permutex2var_epi64(bits.values[0], permute2, bits.values[1]));
- bits.values[0] = _mm512_shuffle_epi8(values, tmp);
- tmp = _mm512_permutex2var_epi64(bits.values[2], permute1, bits.values[3]);
- bits.values[3] = _mm512_shuffle_epi8(values, _mm512_permutex2var_epi64(bits.values[2], permute2, bits.values[3]));
- bits.values[2] = _mm512_shuffle_epi8(values, tmp);
- //
- // Now the more difficult part - prepare the scales
- //
- aux32[0] = _mm512_cmpeq_epi16_mask(_mm512_and_si512(b1, mask1), mask1);
- aux32[1] = _mm512_cmpeq_epi16_mask(_mm512_and_si512(b2, mask1), mask1);
-
- auto scales128 = _mm_cvtepu8_epi16(_mm_loadl_epi64((const __m128i *)aux32));
- auto m1 = _mm512_castsi512_si128(mask1);
- auto shifts = _mm_and_si128(_mm_cmpeq_epi16(_mm_and_si128(scales128, m1), m1), m4);
- scales128 = _mm_add_epi16(_mm_and_si128(scales128, mask), m127);
- auto scales_s = _mm_mullo_epi16(scales128, _mm_add_epi16(m128, shifts));
- s8k.accum_mins(scales_s, q8, i, d, accm);
- auto scales256 = MM256_SET_M128I(scales128, scales128);
- auto all_scales = _mm512_inserti32x8(_mm512_castsi256_si512(scales256), scales256, 1);
- scales[0] = _mm512_shuffle_epi8(all_scales, shuffles[0]);
- scales[1] = _mm512_shuffle_epi8(all_scales, shuffles[1]);
- scales[2] = _mm512_shuffle_epi8(all_scales, shuffles[2]);
- scales[3] = _mm512_shuffle_epi8(all_scales, shuffles[3]);
- }
-
- Q4Bits bits;
- Scales8KBase s8k;
- const __m512i values;
-    const __m512i mask15   = _mm512_set1_epi16(-2); // 0xfffe; written as the signed value -2 to avoid a compiler warning about an out-of-range 16-bit constant
- const __m512i mask1 = _mm512_set1_epi16(1);
- const __m512i permute1 = _mm512_set_epi64(11, 10, 3, 2, 9, 8, 1, 0);
- const __m512i permute2 = _mm512_set_epi64(15, 14, 7, 6, 13, 12, 5, 4);
- const __m128i mask = _mm_set1_epi16(254);
- const __m128i m127 = _mm_set1_epi16(-127);
- const __m128i m128 = _mm_set1_epi16(-128);
- const __m128i m4 = _mm_set1_epi16(4);
- const __m512i shuffles[4] = {
- _mm512_inserti32x8(_mm512_set1_epi16(0x0100), _mm256_set1_epi16(0x0302), 1),
- _mm512_inserti32x8(_mm512_set1_epi16(0x0504), _mm256_set1_epi16(0x0706), 1),
- _mm512_inserti32x8(_mm512_set1_epi16(0x0908), _mm256_set1_epi16(0x0b0a), 1),
- _mm512_inserti32x8(_mm512_set1_epi16(0x0d0c), _mm256_set1_epi16(0x0f0e), 1),
- };
-};
-
-
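-// Block-level dot product helpers used by the kernels below: the four 64-quant chunks
-// are multiplied with _mm512_dpbusd_epi32, the partial sums are packed to 16 bits and
-// scaled with _mm512_dpwssd_epi32, and the result is accumulated in float.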
-template <typename Q8>
-inline void compute_block(int iy, int i, float d, const Q8& q8, const __m512i * values, const __m512i * scales, __m512 * accd) {
- const __m512i p1 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), values[0], q8.load_quants64(iy, i, 0));
- const __m512i p2 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), values[1], q8.load_quants64(iy, i, 1));
- const __m512i p3 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), values[2], q8.load_quants64(iy, i, 2));
- const __m512i p4 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), values[3], q8.load_quants64(iy, i, 3));
- auto sumi = _mm512_dpwssd_epi32(_mm512_setzero_si512(), scales[0], _mm512_packs_epi32(p1, p2));
- sumi = _mm512_dpwssd_epi32(sumi, scales[1], _mm512_packs_epi32(p3, p4));
- accd[iy] = _mm512_fmadd_ps(_mm512_set1_ps(d*q8.scale(iy, i)), _mm512_cvtepi32_ps(sumi), accd[iy]);
-}
-
-template <typename Q8>
-inline void compute_block_iq2tn(int iy, int i, float d, const Q8& q8, const __m512i * values, __m512 * accd) {
- auto sumi_scales = _mm256_madd_epi16(_mm256_set1_epi16(-1), q8.load_bsums(iy, i));
- auto sumi = _mm512_dpbusd_epi32(_mm512_dpbusd_epi32(_mm512_dpbusd_epi32(_mm512_dpbusd_epi32(
- _mm512_inserti32x8(_mm512_setzero_si512(), sumi_scales, 0),
- values[0], q8.load_quants64(iy, i, 0)), values[1], q8.load_quants64(iy, i, 1)),
- values[2], q8.load_quants64(iy, i, 2)), values[3], q8.load_quants64(iy, i, 3));
- accd[iy] = _mm512_fmadd_ps(_mm512_set1_ps(d*q8.scale(iy, i)), _mm512_cvtepi32_ps(sumi), accd[iy]);
-}
-
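-// Generic AVX512 GEMM driver for the k-quants: for each row of the quantized matrix the
-// Dequantizer unpacks one super-block at a time and the unsigned quants are multiplied
-// against nrc_y q8_K rows with _mm512_dpbusd_epi32; accm collects the mins correction
-// and is added to the reduced 512-bit accumulator at the end.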
-template <typename Dequantizer, int nrc_y>
-static void mul_mat_qX_K_q8_K_AVX512(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
- assert(n % QK_K == 0);
- const int nb = n / QK_K;
-
- Q8<nrc_y> q8(info);
-
- Dequantizer deq(vx, bx);
-
- __m256 accm[nrc_y];
- __m512 accd[nrc_y];
- __m512i scales[2];
-
- for (int ix = 0; ix < nrc_x; ++ix) {
-
- for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm512_setzero_ps();
- for (int iy = 0; iy < nrc_y; ++iy) accm[iy] = _mm256_setzero_ps();
-
- deq.new_row(ix);
-
- for (int i = 0; i < nb; ++i) {
-
- deq.new_block(i, q8, accm, scales);
-
- for (int iy = 0; iy < nrc_y; ++iy) {
- const __m512i p1 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), deq.bits.values[0], q8.load_quants64(iy, i, 0));
- const __m512i p2 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), deq.bits.values[1], q8.load_quants64(iy, i, 1));
- const __m512i p3 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), deq.bits.values[2], q8.load_quants64(iy, i, 2));
- const __m512i p4 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), deq.bits.values[3], q8.load_quants64(iy, i, 3));
- auto sumi = _mm512_dpwssd_epi32(_mm512_setzero_si512(), scales[0], _mm512_packs_epi32(p1, p2));
- sumi = _mm512_dpwssd_epi32(sumi, scales[1], _mm512_packs_epi32(p3, p4));
- accd[iy] = _mm512_fmadd_ps(_mm512_set1_ps(deq.d*q8.scale(iy, i)), _mm512_cvtepi32_ps(sumi), accd[iy]);
- }
-
- }
-
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto sum256 = _mm256_add_ps(_mm512_castps512_ps256(accd[iy]), _mm512_extractf32x8_ps(accd[iy], 1));
- info.store(ix, iy, hsum_float_8(_mm256_add_ps(accm[iy], sum256)));
- }
-
- }
-}
-
-template <typename Dequantizer, int nrc_y>
-static void mul_mat_iqX_k_q8_K_AVX512(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
- assert(n % QK_K == 0);
- const int nb = n / QK_K;
-
- Q8<nrc_y> q8(info);
-
- Dequantizer deq(vx, bx);
-
- __m256 accm[nrc_y];
- __m512 accd[nrc_y];
- __m512i scales[4];
-
- for (int ix = 0; ix < nrc_x; ++ix) {
-
- for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm512_setzero_ps();
- for (int iy = 0; iy < nrc_y; ++iy) accm[iy] = _mm256_setzero_ps();
-
- deq.new_row(ix);
-
- for (int i = 0; i < nb; ++i) {
-
- deq.new_block(i, q8, accm, scales);
-
- for (int iy = 0; iy < nrc_y; ++iy) {
- const __m512i p1 = _mm512_maddubs_epi16(deq.bits.values[0], q8.load_quants64(iy, i, 0));
- const __m512i p2 = _mm512_maddubs_epi16(deq.bits.values[1], q8.load_quants64(iy, i, 1));
- const __m512i p3 = _mm512_maddubs_epi16(deq.bits.values[2], q8.load_quants64(iy, i, 2));
- const __m512i p4 = _mm512_maddubs_epi16(deq.bits.values[3], q8.load_quants64(iy, i, 3));
- auto sumi = _mm512_dpwssd_epi32(_mm512_dpwssd_epi32(_mm512_dpwssd_epi32(_mm512_dpwssd_epi32(_mm512_setzero_si512(),
- p1, scales[0]), p2, scales[1]), p3, scales[2]), p4, scales[3]);
- accd[iy] = _mm512_fmadd_ps(_mm512_set1_ps(deq.d*q8.scale(iy, i)), _mm512_cvtepi32_ps(sumi), accd[iy]);
- }
-
- }
-
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto sum256 = _mm256_add_ps(_mm512_castps512_ps256(accd[iy]), _mm512_extractf32x8_ps(accd[iy], 1));
- info.store(ix, iy, hsum_float_8(_mm256_add_ps(accm[iy], sum256)));
- }
-
- }
-}
-
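-// Driver for dequantizers that provide compute_block() and handle scales and mins
-// themselves (iq2_ks, iq4_ks, iq5_ks above); only the float accumulators live here.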
-template <typename Dequantizer, int nrc_y>
-static void mul_mat_iqX_k_q8_K_AVX512_new(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
- assert(n % QK_K == 0);
- const int nb = n / QK_K;
-
- Q8<nrc_y> q8(info);
-
- Dequantizer deq(vx, bx);
-
- __m512 accd[nrc_y];
-
- for (int ix = 0; ix < nrc_x; ++ix) {
-
- for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm512_setzero_ps();
-
- deq.new_row(ix);
-
- for (int i = 0; i < nb; ++i) {
- deq.compute_block(i, q8, accd);
- }
-
- for (int iy = 0; iy < nrc_y; ++iy) {
- info.store(ix, iy, _mm512_reduce_add_ps(accd[iy]));
- }
-
- }
-}
-
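-// Single-column (nrc_y == 1) variant: two Dequantizer instances work on alternating
-// blocks in each loop iteration to provide more independent work, with an odd trailing
-// block handled after the main loop.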
-template <typename Dequantizer>
-static void mul_mat_qX_K_q8_K_AVX512_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
- assert(n % QK_K == 0);
- const int nb = n / QK_K;
-
- constexpr int k_nx = 2;
-
- Q8<1> q8(info);
-
- Dequantizer deq1(vx, bx);
- Dequantizer deq2(vx, bx);
-
- Dequantizer * deq[k_nx];
- deq[0] = &deq1;
- deq[1] = &deq2;
-
- __m512i scales[2*k_nx];
-
- for (int ix = 0; ix < nrc_x; ++ix) {
-
- auto accd = _mm512_setzero_ps();
- auto accm = _mm256_setzero_ps();
-
- for (int kx = 0; kx < k_nx; ++kx) deq[kx]->new_row(ix);
-
- for (int i = 0; i < nb/k_nx; ++i) {
-
- for (int kx = 0; kx < k_nx; ++kx) deq[kx]->new_block(k_nx*i+kx, q8, &accm, scales+2*kx);
-
- for (int kx = 0; kx < k_nx; ++kx) {
- compute_block(0, k_nx*i+kx, deq[kx]->d, q8, deq[kx]->bits.values, scales+2*kx, &accd);
- }
-
- }
- if (2*(nb/2) < nb) {
- int i0 = 2*(nb/2);
- deq[0]->new_block(i0, q8, &accm, scales);
- compute_block(0, i0, deq[0]->d, q8, deq[0]->bits.values, scales, &accd);
- }
-
- auto sum256 = _mm256_add_ps(_mm512_castps512_ps256(accd), _mm512_extractf32x8_ps(accd, 1));
- info.store(ix, 0, hsum_float_8(_mm256_add_ps(accm, sum256)));
- }
-}
-
-#else
-// ===================================== Vanilla AVX2 =====================================
-
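-// AVX2 counterparts of the unpacking helpers above: with only 256-bit registers a
-// super-block is processed in two halves, selected by the j argument.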
-struct Q4Bits {
- inline void prepare(const uint8_t * q4, int j) {
- auto q4bits = _mm256_loadu_si256((const __m256i*)q4 + 2*j+0);
- values[0] = _mm256_and_si256(q4bits, ml);
- values[1] = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), ml);
- q4bits = _mm256_loadu_si256((const __m256i*)q4 + 2*j+1);
- values[2] = _mm256_and_si256(q4bits, ml);
- values[3] = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), ml);
- }
- inline void prepare64(const uint8_t * q4, int j) {
- auto q4bits = _mm256_loadu_si256((const __m256i*)q4 + 2*j+0);
- values[0] = _mm256_and_si256(q4bits, ml);
- values[2] = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), ml);
- q4bits = _mm256_loadu_si256((const __m256i*)q4 + 2*j+1);
- values[1] = _mm256_and_si256(q4bits, ml);
- values[3] = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), ml);
- }
- inline void prepare16(const uint8_t * q4, int j) {
- values[0] = dequant16(q4 + 64*j + 0);
- values[1] = dequant16(q4 + 64*j + 16);
- values[2] = dequant16(q4 + 64*j + 32);
- values[3] = dequant16(q4 + 64*j + 48);
- }
- inline __m256i dequant16(const uint8_t * qs) const {
- const __m128i aux128 = _mm_loadu_si128((const __m128i *)qs);
- const __m256i aux256 = MM256_SET_M128I(_mm_srli_epi16(aux128, 4), aux128);
- return _mm256_and_si256(ml, aux256);
- }
- __m256i values[4];
- const __m256i ml = _mm256_set1_epi8(0xf);
-};
-
-struct Q2Bits {
- inline void prepare(const uint8_t * q2, int j) {
- auto q2bits = _mm256_loadu_si256((const __m256i *)q2 + j);
- values[0] = _mm256_and_si256(q2bits, ml);
- values[1] = _mm256_and_si256(_mm256_srli_epi16(q2bits, 2), ml);
- values[2] = _mm256_and_si256(_mm256_srli_epi16(q2bits, 4), ml);
- values[3] = _mm256_and_si256(_mm256_srli_epi16(q2bits, 6), ml);
- }
- __m256i values[4];
- const __m256i ml = _mm256_set1_epi8(0x03);
-};
-
-struct HighBit5 {
- inline void load(const uint8_t * h) { hbits = _mm256_loadu_si256((const __m256i *)h); }
- inline void apply(Q4Bits& bits, bool do_shift) {
- bits.values[0] = _mm256_or_si256(bits.values[0], _mm256_and_si256(_mm256_slli_epi16(hbits, 4), mh));
- bits.values[1] = _mm256_or_si256(bits.values[1], _mm256_and_si256(_mm256_slli_epi16(hbits, 3), mh));
- bits.values[2] = _mm256_or_si256(bits.values[2], _mm256_and_si256(_mm256_slli_epi16(hbits, 2), mh));
- bits.values[3] = _mm256_or_si256(bits.values[3], _mm256_and_si256(_mm256_slli_epi16(hbits, 1), mh));
- if (do_shift) {
- hbits = _mm256_srli_epi16(hbits, 4);
- }
- }
- const __m256i mh = _mm256_set1_epi8(0x10);
- __m256i hbits;
-};
-
-struct HighBit3 {
- inline void load(const uint8_t * h) { hbits = _mm256_loadu_si256((const __m256i *)h); }
- inline void apply(Q2Bits& bits, bool do_shift) {
- bits.values[0] = _mm256_or_si256(bits.values[0], _mm256_and_si256(_mm256_slli_epi16(hbits, 2), mh));
- bits.values[1] = _mm256_or_si256(bits.values[1], _mm256_and_si256(_mm256_slli_epi16(hbits, 1), mh));
- bits.values[2] = _mm256_or_si256(bits.values[2], _mm256_and_si256(hbits, mh));
- bits.values[3] = _mm256_or_si256(bits.values[3], _mm256_and_si256(_mm256_srli_epi16(hbits, 1), mh));
- if (do_shift) {
- hbits = _mm256_srli_epi16(hbits, 4);
- }
- }
- const __m256i mh = _mm256_set1_epi8(0x04);
- __m256i hbits;
-};
-
-struct DequantizerQ4K final : public BaseDequantizer<block_q4_K> {
- DequantizerQ4K(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}
- template <typename Q8>
- inline __m256i new_block(int i, const Q8& q8, __m256 * accd) {
- d = GGML_FP16_TO_FP32(x[i].d);
- return s8k.process_mins_and_scales(x[i].scales, -GGML_FP16_TO_FP32(x[i].dmin), i, q8, accd);
- }
- inline void prepare(int i, int j) {
- bits.prepare(x[i].qs, j);
- }
-
- Q4Bits bits;
- Scales8K s8k;
-};
-
-struct DequantizerIQ4XS final : public BaseDequantizer<block_iq4_xs> {
- DequantizerIQ4XS(const void * vx, size_t bx) : BaseDequantizer(vx, bx), values(load_iq4nl_values_256()) {}
- template <typename Q8>
- inline __m256i new_block(int i, const Q8& q8, __m256 * accd) {
- d = GGML_FP16_TO_FP32(x[i].d);
- auto scales128 = siq4.make_scales(*(const uint32_t *)x[i].scales_l, x[i].scales_h);
- s8k.accum_mins(scales128, q8, i, -128.f*d, accd);
- return MM256_SET_M128I(scales128, scales128);
- }
- inline void prepare(int i, int j) {
- bits.prepare16(x[i].qs, j);
- bits.values[0] = _mm256_shuffle_epi8(values, bits.values[0]);
- bits.values[1] = _mm256_shuffle_epi8(values, bits.values[1]);
- bits.values[2] = _mm256_shuffle_epi8(values, bits.values[2]);
- bits.values[3] = _mm256_shuffle_epi8(values, bits.values[3]);
- }
-
- Q4Bits bits;
- Scales8K s8k;
- ScaleIQ4XS siq4;
- const __m256i values;
-};
-
-struct IQXKScales {
- IQXKScales(int8_t shift, int8_t min_val) : min(_mm256_set1_epi16(min_val)), eshift(_mm_set1_epi8(shift)) {}
- template <typename Q8>
- inline void process(int i, float d, uint16_t extra, __m128i scales8, const Q8& q8, __m256 * accm, __m256i * scales) const {
- auto scales16 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales8, hshuff));
- process(i, d, extra, scales16, q8, accm, scales);
- }
- template <typename Q8>
- inline void process(int i, float d, uint16_t extra, __m256i scales16, const Q8& q8, __m256 * accm, __m256i * scales) const {
- auto extra128 = _mm_set1_epi16(extra);
- extra128 = _mm_cmpeq_epi8(_mm_and_si128(extra128, emask), emask);
- extra128 = _mm_and_si128(extra128, eshift);
- extra128 = _mm_shuffle_epi8(extra128, eshuffle);
- auto scales_s = _mm256_mullo_epi16(scales16, _mm256_add_epi16(min, _mm256_cvtepi8_epi16(extra128)));
- for (int iy = 0; iy < Q8::nrc_y; ++iy) {
- const __m256i prod = _mm256_madd_epi16(scales_s, q8.load_bsums(iy, i));
- accm[iy] = _mm256_fmadd_ps(_mm256_set1_ps(d * q8.scale(iy, i)), _mm256_cvtepi32_ps(prod), accm[iy]);
- }
- prepare_scales_16(scales16, scales);
- }
-
- const __m256i min;
- const __m128i eshift;
- const __m128i hshuff = _mm_set_epi32(0x0f070e06, 0x0d050c04, 0x0b030a02, 0x09010800);
- const __m128i emask = _mm_set_epi32(0x80804040, 0x20201010, 0x08080404, 0x02020101);
- const __m128i eshuffle = _mm_set_epi32(0x0f0d0b09, 0x07050301, 0x0e0c0a08, 0x06040200);
-};
-
-struct DequantizerIQ2K final : public BaseDequantizer<block_iq2_k> {
- DequantizerIQ2K(const void * vx, size_t bx) : BaseDequantizer(vx, bx), iqxk(5, -32), values(load_values()) {}
- template <typename Q8>
- inline void new_block(int i, const Q8& q8, __m256 * accm, __m256i * scales) {
- d = GGML_FP16_TO_FP32(x[i].d);
- iqxk.process(i, d, x[i].extra, make_scales(x[i].scales), q8, accm, scales);
- }
- inline void prepare(int i, int j) {
- bits.prepare(x[i].qs, j);
- bits.values[0] = _mm256_shuffle_epi8(values, bits.values[0]);
- bits.values[1] = _mm256_shuffle_epi8(values, bits.values[1]);
- bits.values[2] = _mm256_shuffle_epi8(values, bits.values[2]);
- bits.values[3] = _mm256_shuffle_epi8(values, bits.values[3]);
- }
- static inline __m256i load_values() {
- static const uint8_t kvalues_iq2nl[16] = {1, 19, 33, 49, 0, 0, 0, 0, 6, 24, 38, 54, 0, 0, 0, 0};
- auto val128 = _mm_loadu_si128((const __m128i *)kvalues_iq2nl);
- return MM256_SET_M128I(val128, val128);
- }
- inline __m128i make_scales(const uint8_t * scales_l) const {
- uint64_t aux64; std::memcpy(&aux64, scales_l, 8);
- auto scl = _mm_and_si128(_mm_set_epi64x(aux64 >> 4, aux64), maskl);
- return _mm_add_epi8(scl, m8);
- }
-
- Q2Bits bits;
- const IQXKScales iqxk;
- const __m256i values;
- const __m128i m8 = _mm_set1_epi8(-8);
- const __m128i maskl = _mm_set1_epi8(0xf);
-};
-
-struct DequantizerIQ3K final : public BaseDequantizer<block_iq3_k> {
- DequantizerIQ3K(const void * vx, size_t bx) : BaseDequantizer(vx, bx), iqxk(4, -64), values(load_values()) {}
- template <typename Q8>
- inline void new_block(int i, const Q8& q8, __m256 * accm, __m256i * scales) {
- d = GGML_FP16_TO_FP32(x[i].d);
- iqxk.process(i, d, x[i].extra, make_scales(x[i].scales_h, x[i].scales_l), q8, accm, scales);
- hbits = _mm256_loadu_si256((const __m256i *)x[i].qh);
- }
- inline void prepare(int i, int j) {
- bits.prepare(x[i].qs, j);
- auto h256 = j == 0 ? hbits : _mm256_srli_epi16(hbits, 4);
- bits.values[0] = _mm256_or_si256(bits.values[0], _mm256_and_si256(_mm256_slli_epi16(h256, 2), hmask));
- bits.values[1] = _mm256_or_si256(bits.values[1], _mm256_and_si256(_mm256_slli_epi16(h256, 1), hmask));
- bits.values[2] = _mm256_or_si256(bits.values[2], _mm256_and_si256(h256, hmask));
- bits.values[3] = _mm256_or_si256(bits.values[3], _mm256_and_si256(_mm256_srli_epi16(h256, 1), hmask));
- bits.values[0] = _mm256_shuffle_epi8(values, bits.values[0]);
- bits.values[1] = _mm256_shuffle_epi8(values, bits.values[1]);
- bits.values[2] = _mm256_shuffle_epi8(values, bits.values[2]);
- bits.values[3] = _mm256_shuffle_epi8(values, bits.values[3]);
- }
- static inline __m256i load_values() {
- static const uint8_t kvalues_iq3nl[16] = {1, 24, 41, 54, 65, 77, 92, 111, 5, 28, 45, 58, 69, 81, 96, 115};
- auto val128 = _mm_loadu_si128((const __m128i *)kvalues_iq3nl);
- return MM256_SET_M128I(val128, val128);
- }
- inline __m128i make_scales(uint16_t signs, const uint8_t * scales_l) const {
- uint64_t aux64; std::memcpy(&aux64, scales_l, 8);
- auto scl = _mm_and_si128(_mm_set_epi64x(aux64 >> 4, aux64), _mm_set1_epi8(0xf));
- scl = _mm_add_epi8(_mm_slli_epi16(scl, 1), m1);
- const __m128i sc_signs = _mm_cmpeq_epi8(_mm_and_si128(_mm_set1_epi16(signs), sign_mask), sign_mask);
- const __m128i sch = _mm_shuffle_epi8(_mm_or_si128(sc_signs, _mm_set1_epi8(1)), hshuff);
- return _mm_sign_epi8(scl, sch);
- }
-
- Q2Bits bits;
- const IQXKScales iqxk;
- const __m256i values;
- __m256i hbits;
- const __m256i hmask = _mm256_set1_epi8(4);
- const __m128i m1 = _mm_set1_epi8(1);
- const __m128i sign_mask = _mm_set_epi64x(0x8080404020201010, 0x0808040402020101);
- const __m128i hshuff = _mm_loadu_si128((const __m128i*)k_shuff);
- constexpr static uint8_t k_shuff[16] = {0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15};
-};
-
-struct DequantizerIQ4K final : public BaseDequantizer<block_iq4_k> {
- DequantizerIQ4K(const void * vx, size_t bx) : BaseDequantizer(vx, bx) { load_values(); }
- template <typename Q8>
- inline void new_block(int i, [[maybe_unused]] const Q8& q8, [[maybe_unused]] __m256 * accm, __m256i * scales) {
- d = GGML_FP16_TO_FP32(x[i].d);
- auto scales8 = make_scales(x[i].scales_l, (const uint16_t *)x[i].scales_h);
- auto scales16 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales8, hshuff));
- prepare_scales_16(scales16, scales);
- }
- inline void prepare(int i, int j) {
- bits.prepare16(x[i].qs, j);
- auto extra = x[i].extra >> 8*j;
- bits.values[0] = _mm256_shuffle_epi8(values[extra & 3], bits.values[0]); extra >>= 2;
- bits.values[1] = _mm256_shuffle_epi8(values[extra & 3], bits.values[1]); extra >>= 2;
- bits.values[2] = _mm256_shuffle_epi8(values[extra & 3], bits.values[2]); extra >>= 2;
- bits.values[3] = _mm256_shuffle_epi8(values[extra & 3], bits.values[3]);
- }
- __m128i make_scales(const uint8_t * scales_l, const uint16_t * scales_h) const {
- uint64_t aux64;
- memcpy(&aux64, scales_l, 8);
- auto scl = _mm_and_si128(_mm_set_epi64x(aux64 >> 4, aux64), maskl);
- const uint32_t aux32 = scales_h[0] | (scales_h[1] << 16);
- auto aux = _mm_and_si128(_mm_set_epi32(aux32 >> 2, aux32, aux32 << 2, aux32 << 4), maskh);
- auto sch = _mm_shuffle_epi8(aux, hshuff);
- return _mm_add_epi8(_mm_or_si128(scl, sch), m32);
- }
- void load_values() {
- auto v1 = _mm_loadu_si128((const __m128i *)iq4k_values+0);
- auto v2 = _mm_loadu_si128((const __m128i *)iq4k_values+1);
- values[0] = MM256_SET_M128I(v1, v1);
- values[1] = MM256_SET_M128I(v1, v2);
- values[2] = MM256_SET_M128I(v2, v1);
- values[3] = MM256_SET_M128I(v2, v2);
- }
-
- Q4Bits bits;
- const __m128i maskl = _mm_set1_epi8(0xf);
- const __m128i maskh = _mm_set1_epi8(0x30);
- const __m128i m32 = _mm_set1_epi8(-32);
- const __m128i hshuff = _mm_set_epi32(0x0f070e06, 0x0d050c04, 0x0b030a02, 0x09010800);
- __m256i values[4];
-};
-
-struct DequantizerIQ5K final : public BaseDequantizer<block_iq5_k> {
- DequantizerIQ5K(const void * vx, size_t bx) : BaseDequantizer(vx, bx), iqxk(2, 0) { load_values(values); }
- template <typename Q8>
- inline void new_block(int i, const Q8& q8, __m256 * accm, __m256i * scales) {
- d = GGML_FP16_TO_FP32(x[i].d);
- iqxk.process(i, d, x[i].extra, make_scales(x[i].scales_l, (const uint16_t *)x[i].scales_h), q8, accm, scales);
- hbits = _mm256_loadu_si256((const __m256i *)x[i].qh);
- }
- inline void prepare(int i, int j) {
- bits.prepare(x[i].qs, j);
- auto h = j == 0 ? hbits : _mm256_srli_epi16(hbits, 4);
- for (int k = 0; k < 4; ++k) {
- auto qh = _mm256_and_si256(_mm256_slli_epi16(h, 7-k), mh);
- auto q5vl = _mm256_or_si256(bits.values[k], qh);
- auto q5vh = _mm256_or_si256(bits.values[k], _mm256_xor_si256(qh, mh));
- bits.values[k] = _mm256_or_si256(_mm256_shuffle_epi8(values[0], q5vl), _mm256_shuffle_epi8(values[1], q5vh));
- }
- }
- __m128i make_scales(const uint8_t * scales_l, const uint16_t * scales_h) const {
- uint64_t aux64;
- memcpy(&aux64, scales_l, 8);
- auto scl = _mm_and_si128(_mm_set_epi64x(aux64 >> 4, aux64), maskl);
- const uint32_t aux32 = scales_h[0] | (scales_h[1] << 16);
- auto aux = _mm_and_si128(_mm_set_epi32(aux32 >> 2, aux32, aux32 << 2, aux32 << 4), maskh);
- auto sch = _mm_shuffle_epi8(aux, iqxk.hshuff);
- return _mm_add_epi8(_mm_or_si128(scl, sch), m32);
- }
- static void load_values(__m256i * values) {
- auto values128_1 = _mm_loadu_si128((const __m128i *)iq5nl_values + 0);
- auto values128_2 = _mm_loadu_si128((const __m128i *)iq5nl_values + 1);
- values[0] = MM256_SET_M128I(values128_1, values128_1);
- values[1] = MM256_SET_M128I(values128_2, values128_2);
- }
-
- Q4Bits bits;
- const IQXKScales iqxk;
- __m256i hbits;
- __m256i values[2];
- const __m128i maskl = _mm_set1_epi8(0xf);
- const __m128i maskh = _mm_set1_epi8(0x30);
- const __m128i m32 = _mm_set1_epi8(-32);
- const __m256i mh = _mm256_set1_epi8(-128); // to avoid stupid warning about 0x80 overflowing
-};
-
-struct DequantizerIQ6K final : public BaseDequantizer<block_iq6_k> {
- DequantizerIQ6K(const void * vx, size_t bx) : BaseDequantizer(vx, bx), iqxk(1, 0) { load_values(values); }
- template <typename Q8>
- inline void new_block(int i, const Q8& q8, __m256 * accm, __m256i * scales) {
- d = GGML_FP16_TO_FP32(x[i].d);
- auto scales8 = _mm_loadu_si128((const __m128i*)x[i].scales);
- auto scales16 = _mm256_cvtepi8_epi16(scales8);
- iqxk.process(i, d, x[i].extra, scales16, q8, accm, scales);
- }
- inline void prepare(int i, int j) {
- bits.prepare(x[i].qs, j);
- auto hbits = _mm256_loadu_si256((const __m256i *)x[i].qh + j);
- for (int k = 0; k < 4; ++k) {
- bits.values[k] = make_one(bits.values[k], hbits);
- hbits = _mm256_srli_epi16(hbits, 2);
- }
- }
- inline __m256i make_one(__m256i l, __m256i hbits) const {
- auto mask4 = _mm256_cmpeq_epi8(_mm256_and_si256(hbits, mh3), mh3);
- auto h1 = _mm256_andnot_si256(mask4, hbits);
- auto mask2 = _mm256_cmpeq_epi8(_mm256_and_si256(h1, mh1), mh1);
- auto mask3 = _mm256_cmpeq_epi8(_mm256_and_si256(h1, mh2), mh2);
- auto mask1 = _mm256_andnot_si256(_mm256_or_si256(mask4, _mm256_or_si256(mask2, mask3)), _mm256_set1_epi8(-1)); // 0xff;
- return _mm256_or_si256(_mm256_or_si256(_mm256_and_si256(mask1, _mm256_shuffle_epi8(values[0], l)),
- _mm256_and_si256(mask2, _mm256_shuffle_epi8(values[1], l))),
- _mm256_or_si256(_mm256_and_si256(mask3, _mm256_shuffle_epi8(values[2], l)),
- _mm256_and_si256(mask4, _mm256_shuffle_epi8(values[3], l))));
- }
- static void load_values(__m256i * values) {
- for (int k = 0; k < 4; ++k) {
- auto values128 = _mm_loadu_si128((const __m128i *)iq6nl_values + k);
- values[k] = MM256_SET_M128I(values128, values128);
- }
- }
-
- Q4Bits bits;
- const IQXKScales iqxk;
- __m256i values[4];
- const __m256i mh1 = _mm256_set1_epi8(1);
- const __m256i mh2 = _mm256_set1_epi8(2);
- const __m256i mh3 = _mm256_set1_epi8(3);
- const __m256i mh = _mm256_set1_epi8(-128); // to avoid stupid warning about 0x80 overflowing
-};
-
-struct DequantizerIQ4KS final : public BaseDequantizer<block_iq4_ks, true> {
- DequantizerIQ4KS(const void * vx, size_t bx) : BaseDequantizer(vx, bx) { load_values(); }
- template <typename Q8>
- inline __m256i new_block(int i, [[maybe_unused]] const Q8& q8, [[maybe_unused]] __m256 * accd) {
- auto scales128 = _mm_cvtepu8_epi16(_mm_loadl_epi64((const __m128i *)x[i].scales));
- scales128 = _mm_add_epi16(_mm_and_si128(scales128, mask), m127);
- return MM256_SET_M128I(scales128, scales128);
- }
- inline void prepare(int i, int j) {
- bits.prepare16(x[i].qs, j);
- bits.values[0] = _mm256_shuffle_epi8(values[x[i].scales[4*j+0] & 1], bits.values[0]);
- bits.values[1] = _mm256_shuffle_epi8(values[x[i].scales[4*j+1] & 1], bits.values[1]);
- bits.values[2] = _mm256_shuffle_epi8(values[x[i].scales[4*j+2] & 1], bits.values[2]);
- bits.values[3] = _mm256_shuffle_epi8(values[x[i].scales[4*j+3] & 1], bits.values[3]);
- }
- void load_values() {
- auto v1 = _mm_loadu_si128((const __m128i *)iq4k_values+0);
- auto v2 = _mm_loadu_si128((const __m128i *)iq4k_values+1);
- values[0] = MM256_SET_M128I(v1, v1);
- values[1] = MM256_SET_M128I(v2, v2);
- }
-
- Q4Bits bits;
- __m256i values[2];
- const __m128i mask = _mm_set1_epi16(254);
- const __m128i m127 = _mm_set1_epi16(-127);
-};
-
-struct DequantizerIQ5KS final : public BaseDequantizer<block_iq5_ks, true> {
- DequantizerIQ5KS(const void * vx, size_t bx) : BaseDequantizer(vx, bx) { load_values(values); }
- template <typename Q8>
- inline __m256i new_block(int i, const Q8& q8, __m256 * accd) {
- hbits = _mm256_loadu_si256((const __m256i *)x[i].qh);
- auto scales128 = _mm_cvtepu8_epi16(_mm_loadl_epi64((const __m128i *)x[i].scales));
- auto shifts = _mm_and_si128(_mm_cmpeq_epi16(_mm_and_si128(scales128, m1), m1), m2);
- scales128 = _mm_add_epi16(_mm_and_si128(scales128, mask), m127);
- auto scales_s = _mm_mullo_epi16(scales128, _mm_add_epi16(m128, shifts));
- s8k.accum_mins(scales_s, q8, i, d, accd);
- return MM256_SET_M128I(scales128, scales128);
- }
- inline void prepare(int i, int j) {
- bits.prepare(x[i].qs, j);
- auto h = j == 0 ? hbits : _mm256_srli_epi16(hbits, 4);
- for (int k = 0; k < 4; ++k) {
- auto qh = _mm256_and_si256(_mm256_slli_epi16(h, 7-k), mh);
- auto q5vl = _mm256_or_si256(bits.values[k], qh);
- auto q5vh = _mm256_or_si256(bits.values[k], _mm256_xor_si256(qh, mh));
- bits.values[k] = _mm256_or_si256(_mm256_shuffle_epi8(values[0], q5vl), _mm256_shuffle_epi8(values[1], q5vh));
- }
- }
- static void load_values(__m256i * values) {
- static const uint8_t kvalues_iq5nl[32] = {
- 2, 14, 25, 36, 45, 54, 63, 71, 78, 85, 92, 98, 104, 110, 116, 122, 127,
- 133, 139, 145, 151, 157, 164, 171, 179, 187, 196, 205, 215, 225, 237, 249,
- };
- auto values128_1 = _mm_loadu_si128((const __m128i *)kvalues_iq5nl + 0);
- auto values128_2 = _mm_loadu_si128((const __m128i *)kvalues_iq5nl + 1);
- values[0] = MM256_SET_M128I(values128_1, values128_1);
- values[1] = MM256_SET_M128I(values128_2, values128_2);
- }
-
- Q4Bits bits;
- Scales8KBase s8k;
- __m256i hbits;
- __m256i values[2];
- const __m128i maskl = _mm_set1_epi8(0xf);
- const __m128i maskh = _mm_set1_epi8(0x30);
- const __m256i mh = _mm256_set1_epi8(-128); // to avoid stupid warning about 0x80 overflowing
- const __m128i mask = _mm_set1_epi16(254);
- const __m128i m127 = _mm_set1_epi16(-127);
- const __m128i m128 = _mm_set1_epi16(-128);
- const __m128i m1 = _mm_set1_epi16(1);
- const __m128i m2 = _mm_set1_epi16(2);
-};
-
-struct DequantizerIQ4KSS final : public BaseDequantizer<block_iq4_kss, true> {
- DequantizerIQ4KSS(const void * vx, size_t bx) : BaseDequantizer(vx, bx), values(load_iq4nl_values_256()) {}
- template <typename Q8>
- inline __m256i new_block(int i, const Q8& q8, __m256 * accd) {
- union { __m256i vec; uint16_t val[16]; } helper;
- for (int k = 0; k < 4; ++k) {
- data[k] = _mm256_loadu_si256((const __m256i *)x[i].qs + k);
- auto p = _mm256_and_si256(_mm256_cmpeq_epi16(_mm256_and_si256(data[k], m1), m1), smask);
- p = _mm256_add_epi32(_mm256_unpackhi_epi64(p, p), p);
- p = _mm256_add_epi32(_mm256_shuffle_epi32(p, _MM_SHUFFLE(2, 3, 0, 1)), p);
- helper.vec = _mm256_hadd_epi16(p, p);
- aux[2*k+0] = helper.val[0];
- aux[2*k+1] = helper.val[8];
- data[k] = _mm256_and_si256(data[k], bmask);
- data[k] = _mm256_xor_si256(data[k], _mm256_srli_epi16(data[k], 1));
- }
- auto scales128 = _mm_loadu_si128((const __m128i *)aux);
- auto shifts = _mm_and_si128(_mm_cmpeq_epi16(_mm_and_si128(scales128, _mm256_castsi256_si128(m1)), _mm256_castsi256_si128(m1)), m4);
- scales128 = _mm_add_epi16(_mm_and_si128(scales128, mask), m127);
- auto scales_s = _mm_mullo_epi16(scales128, _mm_add_epi16(m128, shifts));
- s8k.accum_mins(scales_s, q8, i, d, accd);
- return MM256_SET_M128I(scales128, scales128);
- }
- inline void prepare(int, int j) {
- for (int k = 0; k < 2; ++k) {
- auto p1 = _mm256_castsi256_si128(data[2*j+k]);
- auto p2 = _mm256_extractf128_si256(data[2*j+k], 1);
- bits.values[2*k+0] = _mm256_and_si256(MM256_SET_M128I(_mm_srli_epi16(p1, 4), p1), bits.ml);
- bits.values[2*k+0] = _mm256_shuffle_epi8(values, bits.values[2*k+0]);
- bits.values[2*k+1] = _mm256_and_si256(MM256_SET_M128I(_mm_srli_epi16(p2, 4), p2), bits.ml);
- bits.values[2*k+1] = _mm256_shuffle_epi8(values, bits.values[2*k+1]);
- }
- }
-
- Q4Bits bits;
- Scales8KBase s8k;
- const __m256i values;
- __m256i data[4];
- const __m256i smask = _mm256_set_epi64x(0x0080004000200010, 0x0008000400020001, 0x0080004000200010, 0x0008000400020001);
- const __m256i bmask = _mm256_set1_epi16(-2); // 0xfffe;
- const __m128i mask = _mm_set1_epi16(254);
- const __m128i m127 = _mm_set1_epi16(-127);
- const __m128i m128 = _mm_set1_epi16(-128);
- const __m256i m1 = _mm256_set1_epi16(1);
- const __m128i m4 = _mm_set1_epi16(4);
- uint16_t aux[8];
-};
-
-struct DequantizerIQ2KS final : public BaseDequantizer<block_iq2_ks, true, true> {
- DequantizerIQ2KS(const void * vx, size_t bx) : BaseDequantizer(vx, bx), values(load_values()) {}
- template <typename Q8>
- inline __m256i new_block(int i, const Q8& q8, __m256 * accm) {
- auto scales128 = make_scales(x[i].scales, x[i].extra >> 8);
- auto shifts = _mm_and_si128(_mm_cmpeq_epi8(_mm_and_si128(_mm_set1_epi8(x[i].extra), hmask), hmask), m5);
- auto scales_s = _mm_mullo_epi16(scales128, _mm_cvtepi8_epi16(_mm_add_epi8(m32, shifts)));
- s8k.accum_mins(scales_s, q8, i, d, accm);
- return MM256_SET_M128I(scales128, scales128);
- }
- inline void prepare(int i, int j) {
- bits.prepare(x[i].qs, j);
- bits.values[0] = _mm256_shuffle_epi8(values, bits.values[0]);
- bits.values[1] = _mm256_shuffle_epi8(values, bits.values[1]);
- bits.values[2] = _mm256_shuffle_epi8(values, bits.values[2]);
- bits.values[3] = _mm256_shuffle_epi8(values, bits.values[3]);
- }
- static inline __m256i load_values() {
- static const uint8_t kvalues_iq2nl[16] = {1, 19, 33, 49, 0, 0, 0, 0, 6, 24, 38, 54, 0, 0, 0, 0};
- auto val128 = _mm_loadu_si128((const __m128i *)kvalues_iq2nl);
- return MM256_SET_M128I(val128, val128);
- }
- inline __m128i make_scales(const uint8_t * scales_l, uint8_t scales_h) const {
- const uint16_t * scales = (const uint16_t *)scales_l;
- uint32_t aux32 = scales[0] | (uint32_t(scales[1]) << 16);
- auto scl = _mm_srlv_epi32(_mm_set1_epi32(aux32), shift);
- scl = _mm_and_si128(_mm_shuffle_epi8(scl, shuffle), _mm_set1_epi8(0xf));
- auto sch = _mm_set1_epi8(scales_h);
- sch = _mm_and_si128(_mm_cmpeq_epi8(_mm_and_si128(sch, hmask), _mm_setzero_si128()), m16);
- return _mm_cvtepi8_epi16(_mm_add_epi8(scl, sch));
- }
- Q2Bits bits;
- Scales8KBase s8k;
-
- const __m256i values;
- const __m128i m16 = _mm_set1_epi8(-16);
- const __m128i m5 = _mm_set1_epi8(5);
- const __m128i m32 = _mm_set1_epi8(-32);
- const __m128i hmask = _mm_set1_epi64x(0x8040201008040201);
- const __m128i shuffle = _mm_set1_epi64x(0x0703060205010400);
- const __m128i shift = _mm_set_epi32(0, 0, 4, 0);
-};
-
-struct DequantizerQ5K final : public BaseDequantizer<block_q5_K> {
- DequantizerQ5K(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}
- template <typename Q8>
- inline __m256i new_block(int i, const Q8& q8, __m256 * accd) {
- d = GGML_FP16_TO_FP32(x[i].d);
- hbits.load(x[i].qh);
- return s8k.process_mins_and_scales(x[i].scales, -GGML_FP16_TO_FP32(x[i].dmin), i, q8, accd);
- }
- inline void prepare(int i, int j) {
- bits.prepare(x[i].qs, j);
- hbits.apply(bits, j == 0);
- }
-
- Q4Bits bits;
- HighBit5 hbits;
- Scales8K s8k;
-};
-
-template <typename Q8>
-inline void process_mins_and_scales_16(const __m128i& scales128, const Q8& q8, int i, float d,
- __m256 * accm, __m256i * scales) {
- const __m256i all_scales = _mm256_cvtepi8_epi16(scales128);
- process_mins_16(all_scales, q8, i, d, accm);
- prepare_scales_16(all_scales, scales);
-}
-
-struct DequantizerQ3K final : public BaseDequantizer<block_q3_K> {
- DequantizerQ3K(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}
-
- template <typename Q8>
- inline void new_block(int i, const Q8& q8, __m256 * accm, __m256i * scales) {
- d = GGML_FP16_TO_FP32(x[i].d);
- hbits.load(x[i].hmask);
- process_mins_and_scales_16(sc3.make_scales((const uint16_t *)x[i].scales), q8, i, -4.f*d, accm, scales);
- }
- inline void prepare(int i, int j) {
- bits.prepare(x[i].qs, j);
- hbits.apply(bits, j == 0);
- }
-
- Q2Bits bits;
- HighBit3 hbits;
- ScaleQ3 sc3;
-
- const __m128i m32 = _mm_set1_epi8(-32);
-};
-
-struct DequantizerQ2K final : public BaseDequantizer<block_q2_K> {
- DequantizerQ2K(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}
-
- template <typename Q8>
- inline void new_block(int i, const Q8& q8, __m256 * accm, __m256i * scales) {
- d = GGML_FP16_TO_FP32(x[i].d);
- const __m128i mins_and_scales = _mm_loadu_si128((const __m128i*)x[i].scales);
- const __m128i scales8 = _mm_and_si128(mins_and_scales, m4);
- const __m128i mins8 = _mm_and_si128(_mm_srli_epi16(mins_and_scales, 4), m4);
- process_mins_16(_mm256_cvtepi8_epi16(mins8), q8, i, -GGML_FP16_TO_FP32(x[i].dmin), accm);
- prepare_scales_16(_mm256_cvtepi8_epi16(scales8), scales);
- }
- inline void prepare(int i, int j) {
- bits.prepare(x[i].qs, j);
- }
-
- Q2Bits bits;
-
- const __m128i m4 = _mm_set1_epi8(0xf);
-};
-
-struct DequantizerQ6K final : public BaseDequantizer<block_q6_K> {
- DequantizerQ6K(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}
- template <typename Q8>
- inline void new_block(int i, const Q8& q8, __m256 * accm, __m256i * scales) {
- d = GGML_FP16_TO_FP32(x[i].d);
- process_mins_and_scales_16(_mm_loadu_si128((const __m128i *)x[i].scales), q8, i, -32.f*d, accm, scales);
- }
- inline void prepare(int i, int j) {
- bits.prepare64(x[i].ql, j);
- auto hbits = _mm256_loadu_si256((const __m256i *)x[i].qh + j);
- bits.values[0] = _mm256_or_si256(bits.values[0], _mm256_and_si256(_mm256_slli_epi16(hbits, 4), mh));
- bits.values[1] = _mm256_or_si256(bits.values[1], _mm256_and_si256(_mm256_slli_epi16(hbits, 2), mh));
- bits.values[2] = _mm256_or_si256(bits.values[2], _mm256_and_si256(hbits, mh));
- bits.values[3] = _mm256_or_si256(bits.values[3], _mm256_and_si256(_mm256_srli_epi16(hbits, 2), mh));
- }
-
- Q4Bits bits;
- const __m256i mh = _mm256_set1_epi8(0x30);
-};
-
-template <typename Dequantizer, int nrc_y>
-static void mul_mat_qY_K_q8_K_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
- assert(n%QK_K == 0);
- const int nb = n/QK_K;
-
- Q8<nrc_y> q8(info);
-
- __m256i all_scales[2];
- __m256i scales[4];
- __m256 accd[nrc_y];
-
- Dequantizer deq(vx, bx);
-
- for (int ix = 0; ix < nrc_x; ++ix) {
-
- deq.new_row(ix);
-
- for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm256_setzero_ps();
-
- for (int i = 0; i < nb; ++i) {
-
- deq.new_block(i, q8, accd, all_scales);
-
- __m256i sumi[nrc_y];
-
- for (int j = 0; j < QK_K/128; ++j) {
- deq.prepare(i, j);
- set_scales_16(all_scales[j], scales);
- if constexpr (std::is_same_v<Dequantizer, DequantizerIQ4K> ||
- std::is_same_v<Dequantizer, DequantizerIQ5K> ||
- std::is_same_v<Dequantizer, DequantizerIQ6K>) {
- multiply_add_avx2(deq.bits, scales, j, i, q8, sumi);
- } else {
- multiply_add(deq.bits, scales, j, i, q8, sumi);
- }
- }
-
- for (int iy = 0; iy < nrc_y; ++iy) {
- accd[iy] = _mm256_fmadd_ps(_mm256_set1_ps(deq.d*q8.scale(iy, i)), _mm256_cvtepi32_ps(sumi[iy]), accd[iy]);
- }
-
- }
-
- for (int iy = 0; iy < nrc_y; ++iy) {
- info.store(ix, iy, hsum_float_8(accd[iy]));
- }
-
- }
-
-}
-
-template <typename Dequantizer, int nrc_y>
-static void mul_mat_qX_K_q8_K_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
- assert(n % QK_K == 0);
- const int nb = n / QK_K;
-
- Q8<nrc_y> q8(info);
-
- Dequantizer deq(vx, bx);
-
- __m256 accd[nrc_y];
- __m256i scales[4];
-
- for (int ix = 0; ix < nrc_x; ++ix) {
-
- for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm256_setzero_ps();
-
- deq.new_row(ix);
-
- for (int i = 0; i < nb; ++i) {
-
- auto all_scales = deq.new_block(i, q8, accd);
-
- __m256i sumi[nrc_y];
-
- for (int j = 0; j < QK_K/128; ++j) {
-
- deq.prepare(i, j);
-
- set_scales_8(all_scales, j, scales);
-
- if constexpr (std::is_same_v<Dequantizer, DequantizerIQ4KS>) {
- multiply_add_avx2(deq.bits, scales, j, i, q8, sumi);
- } else {
- multiply_add(deq.bits, scales, j, i, q8, sumi);
- }
-
- }
-
- for (int iy = 0; iy < nrc_y; ++iy) {
- const __m256 vd = _mm256_set1_ps(deq.d*q8.scale(iy, i));
- accd[iy] = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(sumi[iy]), accd[iy]);
- }
-
- }
-
- for (int iy = 0; iy < nrc_y; ++iy) {
- info.store(ix, iy, hsum_float_8(accd[iy]));
- }
-
- }
-}
-
-#endif // Zen4 or vanilla AVX2
-
-template <int nrc_y>
-static void mul_mat_iq2_bn_r4_q8_k16_avx2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
- if (nrc_x%4) {
- printf("%s: %d is not a multiple of 4\n", __func__, nrc_x);
- GGML_ABORT("fatal error");
- }
- Q8_16<nrc_y> q8(info);
- auto m3 = _mm256_set1_epi8(0x3);
- auto m1 = _mm256_set1_epi16(1);
- int nb = n / QK_IQ1BN;
- __m256i qx[4];
- if constexpr (nrc_y > 4) {
- __m256i acc[nrc_y] = {};
- __m128 sum4[nrc_y];
- for (int ix = 0; ix < nrc_x; ix += 4) {
- const float * dptr = (const float *)((const char *)vx + ix*bx);
- auto dl = _mm_loadu_ps(dptr);
- const uint8_t * iq2l = (const uint8_t *)(dptr + 4);
- for (int ib = 0; ib < nb; ++ib) {
- auto bits = _mm256_loadu_si256((const __m256i *)iq2l + 2*ib+0);
- qx[0] = _mm256_and_si256(bits, m3);
- qx[1] = _mm256_and_si256(_mm256_srli_epi16(bits, 2), m3);
- qx[2] = _mm256_and_si256(_mm256_srli_epi16(bits, 4), m3);
- qx[3] = _mm256_and_si256(_mm256_srli_epi16(bits, 6), m3);
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto y = q8.load_quants(iy, 2*ib+0);
- auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[0], _mm256_shuffle_epi32(y, 0x00)),
- _mm256_maddubs_epi16(qx[1], _mm256_shuffle_epi32(y, 0x55)));
- auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[2], _mm256_shuffle_epi32(y, 0xaa)),
- _mm256_maddubs_epi16(qx[3], _mm256_shuffle_epi32(y, 0xff)));
- acc[iy] = _mm256_add_epi32(acc[iy], _mm256_madd_epi16(m1, _mm256_add_epi16(sumi1, sumi2)));
- }
- }
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto dy = q8.scale(iy);
- auto sumf1 = _mm256_cvtepi32_ps(acc[iy]);
- auto s4 = _mm_mul_ps(_mm256_extractf128_ps(sumf1, 0), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0x00)));
- s4 = _mm_fmadd_ps(_mm256_extractf128_ps(sumf1, 1), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0x55)), s4);
- sum4[iy] = _mm_fmadd_ps(dl, _mm_set1_ps(-q8.sum_row(iy)), s4);
- acc[iy] = _mm256_setzero_si256();
- }
- for (int ib = 0; ib < nb; ++ib) {
- auto bits = _mm256_loadu_si256((const __m256i *)iq2l + 2*ib+1);
- qx[0] = _mm256_and_si256(bits, m3);
- qx[1] = _mm256_and_si256(_mm256_srli_epi16(bits, 2), m3);
- qx[2] = _mm256_and_si256(_mm256_srli_epi16(bits, 4), m3);
- qx[3] = _mm256_and_si256(_mm256_srli_epi16(bits, 6), m3);
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto y = q8.load_quants(iy, 2*ib+1);
- auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[0], _mm256_shuffle_epi32(y, 0x00)),
- _mm256_maddubs_epi16(qx[1], _mm256_shuffle_epi32(y, 0x55)));
- auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[2], _mm256_shuffle_epi32(y, 0xaa)),
- _mm256_maddubs_epi16(qx[3], _mm256_shuffle_epi32(y, 0xff)));
- acc[iy] = _mm256_add_epi32(acc[iy], _mm256_madd_epi16(m1, _mm256_add_epi16(sumi1, sumi2)));
- }
- }
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto dy = q8.scale(iy);
- auto sumf1 = _mm256_cvtepi32_ps(acc[iy]);
- auto s4 = _mm_fmadd_ps(_mm256_extractf128_ps(sumf1, 0), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0xaa)), sum4[iy]);
- s4 = _mm_fmadd_ps(_mm256_extractf128_ps(sumf1, 1), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0xff)), s4);
- info.store(ix, iy, s4);
- acc[iy] = _mm256_setzero_si256();
- }
- }
- } else {
- __m256i acc[2*nrc_y] = {};
- for (int ix = 0; ix < nrc_x; ix += 4) {
- const float * dptr = (const float *)((const char *)vx + ix*bx);
- auto dl = _mm_loadu_ps(dptr);
- const uint8_t * iq2l = (const uint8_t *)(dptr + 4);
- for (int ib = 0; ib < nb; ++ib) {
- auto bits = _mm256_loadu_si256((const __m256i *)iq2l + 2*ib+0);
- qx[0] = _mm256_and_si256(bits, m3);
- qx[1] = _mm256_and_si256(_mm256_srli_epi16(bits, 2), m3);
- qx[2] = _mm256_and_si256(_mm256_srli_epi16(bits, 4), m3);
- qx[3] = _mm256_and_si256(_mm256_srli_epi16(bits, 6), m3);
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto y = q8.load_quants(iy, 2*ib+0);
- auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[0], _mm256_shuffle_epi32(y, 0x00)),
- _mm256_maddubs_epi16(qx[1], _mm256_shuffle_epi32(y, 0x55)));
- auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[2], _mm256_shuffle_epi32(y, 0xaa)),
- _mm256_maddubs_epi16(qx[3], _mm256_shuffle_epi32(y, 0xff)));
- acc[2*iy+0] = _mm256_add_epi32(acc[2*iy+0], _mm256_madd_epi16(m1, _mm256_add_epi16(sumi1, sumi2)));
- }
- bits = _mm256_loadu_si256((const __m256i *)iq2l + 2*ib+1);
- qx[0] = _mm256_and_si256(bits, m3);
- qx[1] = _mm256_and_si256(_mm256_srli_epi16(bits, 2), m3);
- qx[2] = _mm256_and_si256(_mm256_srli_epi16(bits, 4), m3);
- qx[3] = _mm256_and_si256(_mm256_srli_epi16(bits, 6), m3);
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto y = q8.load_quants(iy, 2*ib+1);
- auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[0], _mm256_shuffle_epi32(y, 0x00)),
- _mm256_maddubs_epi16(qx[1], _mm256_shuffle_epi32(y, 0x55)));
- auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[2], _mm256_shuffle_epi32(y, 0xaa)),
- _mm256_maddubs_epi16(qx[3], _mm256_shuffle_epi32(y, 0xff)));
- acc[2*iy+1] = _mm256_add_epi32(acc[2*iy+1], _mm256_madd_epi16(m1, _mm256_add_epi16(sumi1, sumi2)));
- }
- }
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto dy = q8.scale(iy);
- auto sumf1 = _mm256_cvtepi32_ps(acc[2*iy+0]);
- auto sumf2 = _mm256_cvtepi32_ps(acc[2*iy+1]);
- auto sum4 = _mm_mul_ps(_mm256_extractf128_ps(sumf1, 0), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0x00)));
- sum4 = _mm_fmadd_ps(_mm256_extractf128_ps(sumf1, 1), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0x55)), sum4);
- sum4 = _mm_fmadd_ps(_mm256_extractf128_ps(sumf2, 0), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0xaa)), sum4);
- sum4 = _mm_fmadd_ps(_mm256_extractf128_ps(sumf2, 1), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0xff)), sum4);
- sum4 = _mm_fmadd_ps(dl, _mm_set1_ps(-q8.sum_row(iy)), sum4);
- info.store(ix, iy, sum4);
- acc[2*iy+0] = acc[2*iy+1] = _mm256_setzero_si256();
- }
- }
- }
-}
-
-#ifdef HAVE_FANCY_SIMD
-template <int nrc_y>
-static void mul_mat_iq2_bn_r4_q8_k16(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
- if (nrc_x%4) {
- printf("%s: %d is not a multiple of 4\n", __func__, nrc_x);
- GGML_ABORT("fatal error");
- }
- if constexpr (nrc_y == 1) {
- mul_mat_iq2_bn_r4_q8_k16_avx2<1>(n, vx, bx, info, nrc_x);
- } else {
- Q8_16<nrc_y> q8(info);
- auto m3 = _mm512_set1_epi8(0x3);
- int nb = n / QK_IQ1BN;
- __m512i acc[2*nrc_y] = {};
- __m512i qx[8];
- for (int ix = 0; ix < nrc_x/8; ++ix) {
- const float * dptr1 = (const float *)((const char *)vx + (8*ix+0)*bx);
- const float * dptr2 = (const float *)((const char *)vx + (8*ix+4)*bx);
- auto dl = _mm_loadu_ps(dptr1);
- auto dh = _mm_loadu_ps(dptr2);
- const uint8_t * iq2l = (const uint8_t *)(dptr1 + 4);
- const uint8_t * iq2h = (const uint8_t *)(dptr2 + 4);
- for (int ib = 0; ib < nb; ++ib) {
- auto bits_l = _mm512_loadu_si512((const __m512i *)iq2l + ib);
- auto bits_h = _mm512_loadu_si512((const __m512i *)iq2h + ib);
- qx[0] = _mm512_and_si512(bits_l, m3);
- qx[1] = _mm512_and_si512(bits_h, m3);
- qx[2] = _mm512_and_si512(_mm512_srli_epi16(bits_l, 2), m3);
- qx[3] = _mm512_and_si512(_mm512_srli_epi16(bits_h, 2), m3);
- qx[4] = _mm512_and_si512(_mm512_srli_epi16(bits_l, 4), m3);
- qx[5] = _mm512_and_si512(_mm512_srli_epi16(bits_h, 4), m3);
- qx[6] = _mm512_and_si512(_mm512_srli_epi16(bits_l, 6), m3);
- qx[7] = _mm512_and_si512(_mm512_srli_epi16(bits_h, 6), m3);
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto y = q8.load_quants64(iy, ib);
- auto sy = _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00));
- acc[2*iy+0] = _mm512_dpbusd_epi32(acc[2*iy+0], qx[0], sy);
- acc[2*iy+1] = _mm512_dpbusd_epi32(acc[2*iy+1], qx[1], sy);
- sy = _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55));
- acc[2*iy+0] = _mm512_dpbusd_epi32(acc[2*iy+0], qx[2], sy);
- acc[2*iy+1] = _mm512_dpbusd_epi32(acc[2*iy+1], qx[3], sy);
- sy = _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa));
- acc[2*iy+0] = _mm512_dpbusd_epi32(acc[2*iy+0], qx[4], sy);
- acc[2*iy+1] = _mm512_dpbusd_epi32(acc[2*iy+1], qx[5], sy);
- sy = _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff));
- acc[2*iy+0] = _mm512_dpbusd_epi32(acc[2*iy+0], qx[6], sy);
- acc[2*iy+1] = _mm512_dpbusd_epi32(acc[2*iy+1], qx[7], sy);
- }
- }
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto dy = q8.scale(iy);
- __m128 sum4;
- for (int k = 0; k < 2; ++k) {
- const auto& dx = k == 0 ? dl : dh;
- auto sumf = _mm512_cvtepi32_ps(acc[2*iy+k]);
- sum4 = _mm_mul_ps (_mm512_extractf32x4_ps(sumf, 0), _mm_mul_ps(dx, _mm_shuffle_ps(dy, dy, 0x00)));
- sum4 = _mm_fmadd_ps(_mm512_extractf32x4_ps(sumf, 1), _mm_mul_ps(dx, _mm_shuffle_ps(dy, dy, 0x55)), sum4);
- sum4 = _mm_fmadd_ps(_mm512_extractf32x4_ps(sumf, 2), _mm_mul_ps(dx, _mm_shuffle_ps(dy, dy, 0xaa)), sum4);
- sum4 = _mm_fmadd_ps(_mm512_extractf32x4_ps(sumf, 3), _mm_mul_ps(dx, _mm_shuffle_ps(dy, dy, 0xff)), sum4);
- sum4 = _mm_fmadd_ps(dx, _mm_set1_ps(-q8.sum_row(iy)), sum4);
- info.store(8*ix + 4*k, iy, sum4);
- }
- acc[2*iy+0] = acc[2*iy+1] = _mm512_setzero_si512();
- }
- }
- if (int ix = 8*(nrc_x/8); ix < nrc_x) {
- const float * dptr = (const float *)((const char *)vx + ix*bx);
- auto dl = _mm_loadu_ps(dptr);
- const uint8_t * iq2l = (const uint8_t *)(dptr + 4);
- for (int ib = 0; ib < nb; ++ib) {
- auto bits_l = _mm512_loadu_si512((const __m512i *)iq2l + ib);
- qx[0] = _mm512_and_si512(bits_l, m3);
- qx[1] = _mm512_and_si512(_mm512_srli_epi16(bits_l, 2), m3);
- qx[2] = _mm512_and_si512(_mm512_srli_epi16(bits_l, 4), m3);
- qx[3] = _mm512_and_si512(_mm512_srli_epi16(bits_l, 6), m3);
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto y = q8.load_quants64(iy, ib);
- acc[iy] = _mm512_dpbusd_epi32(acc[iy], qx[0], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00)));
- acc[iy] = _mm512_dpbusd_epi32(acc[iy], qx[1], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55)));
- acc[iy] = _mm512_dpbusd_epi32(acc[iy], qx[2], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa)));
- acc[iy] = _mm512_dpbusd_epi32(acc[iy], qx[3], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff)));
- }
- }
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto dy = q8.scale(iy);
- auto sumf = _mm512_cvtepi32_ps(acc[iy]);
- auto sum4 = _mm_mul_ps(_mm512_extractf32x4_ps(sumf, 0), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0x00)));
- sum4 = _mm_fmadd_ps(_mm512_extractf32x4_ps(sumf, 1), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0x55)), sum4);
- sum4 = _mm_fmadd_ps(_mm512_extractf32x4_ps(sumf, 2), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0xaa)), sum4);
- sum4 = _mm_fmadd_ps(_mm512_extractf32x4_ps(sumf, 3), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0xff)), sum4);
- sum4 = _mm_fmadd_ps(dl, _mm_set1_ps(-q8.sum_row(iy)), sum4);
- info.store(ix, iy, sum4);
- }
- }
- }
-}
-#else
-template <int nrc_y>
-static void mul_mat_iq2_bn_r4_q8_k16(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
- if (nrc_x%4) {
- printf("%s: %d is not a multiple of 4\n", __func__, nrc_x);
- GGML_ABORT("fatal error");
- }
- mul_mat_iq2_bn_r4_q8_k16_avx2<nrc_y>(n, vx, bx, info, nrc_x);
-}
-#endif
-
-#ifdef HAVE_FANCY_SIMD
-template <int nrc_y>
-static void mul_mat_iq4_nl_r4_q8_2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
- GGML_ASSERT(nrc_x%8 == 0);
- Q8<nrc_y, block_q8_2_x4> q8(info);
- auto m4 = _mm512_set1_epi8(0xf);
- auto values = load_iq4nl_values_512();
- int nb = n / QK4_NL;
- __m512 acc[2*nrc_y] = {};
- __m512i qx[4];
- float d8[8*nrc_y];
- auto prepare = [&qx, &m4, &values] (const block_iq4_nl_r4& iq4l, const block_iq4_nl_r4& iq4h) {
- auto scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq4l.d));
- auto scales1 = _mm256_set_m128(scales128, scales128);
- scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq4h.d));
- auto scales2 = _mm256_set_m128(scales128, scales128);
- auto scales = _mm512_insertf32x8(_mm512_castps256_ps512(scales1), scales2, 1);
- auto bits1 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq4l.qs+0)),
- _mm256_loadu_si256((const __m256i *)iq4h.qs+0), 1);
- auto bits2 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq4l.qs+1)),
- _mm256_loadu_si256((const __m256i *)iq4h.qs+1), 1);
- qx[0] = _mm512_shuffle_epi8(values, _mm512_and_si512(bits1, m4));
- qx[1] = _mm512_shuffle_epi8(values, _mm512_and_si512(bits2, m4));
- qx[2] = _mm512_shuffle_epi8(values, _mm512_and_si512(_mm512_srli_epi16(bits1, 4), m4));
- qx[3] = _mm512_shuffle_epi8(values, _mm512_and_si512(_mm512_srli_epi16(bits2, 4), m4));
- return scales;
- };
- auto dot = [&qx] (__m256i y8) {
- auto y = _mm512_inserti32x8(_mm512_castsi256_si512(y8), y8, 1);
- auto sumi = _mm512_setzero_si512();
- sumi = _mm512_dpbusd_epi32(sumi, qx[0], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00)));
- sumi = _mm512_dpbusd_epi32(sumi, qx[1], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55)));
- sumi = _mm512_dpbusd_epi32(sumi, qx[2], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa)));
- sumi = _mm512_dpbusd_epi32(sumi, qx[3], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff)));
- return sumi;
- };
- for (int ix = 0; ix < nrc_x; ix += 8) {
- const block_iq4_nl_r4 * iq4l = (const block_iq4_nl_r4 *)((const char *)vx + (ix+0)*bx);
- const block_iq4_nl_r4 * iq4h = (const block_iq4_nl_r4 *)((const char *)vx + (ix+4)*bx);
- for (int ib4 = 0; ib4 < nb/4; ++ib4) {
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto aux = _mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)q8.y[iy][ib4].d)), 16);
- _mm256_storeu_ps(d8+8*iy, _mm256_castsi256_ps(aux));
- }
- for (int k = 0; k < 4; ++k) {
- auto scales = prepare(iq4l[4*ib4+k], iq4h[4*ib4+k]);
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto sumi = dot(_mm256_loadu_si256((const __m256i*)q8.y[iy][ib4].qs+k));
- auto dy = _mm512_set1_ps(d8[8*iy+k]);
- acc[2*iy+0] = _mm512_fmadd_ps(_mm512_mul_ps(scales, dy), _mm512_cvtepi32_ps(sumi), acc[2*iy+0]);
- acc[2*iy+1] = _mm512_fmadd_ps(scales, _mm512_set1_ps(d8[8*iy+k+4]), acc[2*iy+1]);
- }
- }
- }
- for (int ib = 4*(nb/4); ib < nb; ++ib) {
- auto scales = prepare(iq4l[ib], iq4h[ib]);
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto qy = (const block_q8_1 *)q8.y[iy];
- auto sumi = dot(_mm256_loadu_si256((const __m256i*)qy[ib].qs));
- ggml_bf16_t d, s; d.bits = qy[ib].d; s.bits = qy[ib].s;
- auto dy = _mm512_set1_ps(GGML_BF16_TO_FP32(d));
- acc[2*iy+0] = _mm512_fmadd_ps(_mm512_mul_ps(scales, dy), _mm512_cvtepi32_ps(sumi), acc[2*iy+0]);
- acc[2*iy+1] = _mm512_fmadd_ps(scales, _mm512_set1_ps(GGML_BF16_TO_FP32(s)), acc[2*iy+1]);
- }
- }
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto sum512 = _mm512_fmadd_ps(_mm512_set1_ps(-64.f), acc[2*iy+1], acc[2*iy+0]);
- acc[2*iy+0] = acc[2*iy+1] = _mm512_setzero_ps();
- auto sum1 = _mm_add_ps(_mm512_extractf32x4_ps(sum512, 0), _mm512_extractf32x4_ps(sum512, 1));
- auto sum2 = _mm_add_ps(_mm512_extractf32x4_ps(sum512, 2), _mm512_extractf32x4_ps(sum512, 3));
- info.store(ix+0, iy, sum1);
- info.store(ix+4, iy, sum2);
- }
- }
-}
-#else
-template <int nrc_y>
-static void mul_mat_iq4_nl_r4_q8_2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
- GGML_ASSERT(nrc_x%4 == 0);
- Q8<nrc_y, block_q8_2_x4> q8(info);
- auto m4 = _mm256_set1_epi8(0xf);
- auto m1 = _mm256_set1_epi16(1);
- auto values128 = _mm_loadu_si128((const __m128i *)iq4k_values);
- auto values = MM256_SET_M128I(values128, values128);
- int nb = n / QK4_NL;
- __m256 acc[nrc_y] = {};
- __m256i qs[4];
- float d8[4*nrc_y];
- auto prepare = [&qs, &values, &m4] (const block_iq4_nl_r4& iq4) {
- auto scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq4.d));
- auto scales = _mm256_set_m128(scales128, scales128);
- auto bits1 = _mm256_loadu_si256((const __m256i *)iq4.qs+0);
- auto bits2 = _mm256_loadu_si256((const __m256i *)iq4.qs+1);
- qs[0] = _mm256_shuffle_epi8(values, _mm256_and_si256(bits1, m4));
- qs[1] = _mm256_shuffle_epi8(values, _mm256_and_si256(bits2, m4));
- qs[2] = _mm256_shuffle_epi8(values, _mm256_and_si256(_mm256_srli_epi16(bits1, 4), m4));
- qs[3] = _mm256_shuffle_epi8(values, _mm256_and_si256(_mm256_srli_epi16(bits2, 4), m4));
- return scales;
- };
- auto dot = [&qs, &m1] (__m256i y) {
- auto u1 = _mm256_sign_epi8(qs[0], qs[0]);
- auto u2 = _mm256_sign_epi8(qs[1], qs[1]);
- auto sumi1 = _mm256_add_epi32(
- _mm256_madd_epi16(m1, _mm256_maddubs_epi16(u1, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), qs[0]))),
- _mm256_madd_epi16(m1, _mm256_maddubs_epi16(u2, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), qs[1]))));
- u1 = _mm256_sign_epi8(qs[2], qs[2]);
- u2 = _mm256_sign_epi8(qs[3], qs[3]);
- auto sumi2 = _mm256_add_epi32(
- _mm256_madd_epi16(m1, _mm256_maddubs_epi16(u1, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), qs[2]))),
- _mm256_madd_epi16(m1, _mm256_maddubs_epi16(u2, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), qs[3]))));
- return _mm256_add_epi32(sumi1, sumi2);
- };
- for (int ix = 0; ix < nrc_x; ix += 4) {
- const block_iq4_nl_r4 * iq4 = (const block_iq4_nl_r4 *)((const char *)vx + ix*bx);
- for (int ib4 = 0; ib4 < nb/4; ++ib4) {
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto aux = _mm_slli_epi32(_mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)q8.y[iy][ib4].d)), 16);
- _mm_storeu_ps(d8+4*iy, _mm_castsi128_ps(aux));
- }
- for (int k = 0; k < 4; ++k) {
- auto scales = prepare(iq4[4*ib4+k]);
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto sumi = dot(_mm256_loadu_si256((const __m256i*)q8.y[iy][ib4].qs+k));
- auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(d8[4*iy+k]));
- acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[iy]);
- }
- }
- }
- for (int ib = 4*(nb/4); ib < nb; ++ib) {
- auto scales = prepare(iq4[ib]);
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto qy = (const block_q8_1 *)q8.y[iy];
- auto sumi = dot(_mm256_loadu_si256((const __m256i*)qy[ib].qs));
- ggml_bf16_t d{qy[ib].d};
- auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(GGML_BF16_TO_FP32(d)));
- acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[iy]);
- }
- }
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1));
- info.store(ix, iy, sum);
- acc[iy] = _mm256_setzero_ps();
- }
- }
-}
-#endif
-
-inline void prepare_q4_0_quants_avx2(const uint8_t * qs, __m256i * v, const __m256i& m4) {
- auto bits1 = _mm256_loadu_si256((const __m256i *)qs+0);
- auto bits2 = _mm256_loadu_si256((const __m256i *)qs+1);
- auto bits3 = _mm256_loadu_si256((const __m256i *)qs+2);
- auto bits4 = _mm256_loadu_si256((const __m256i *)qs+3);
- v[0] = _mm256_and_si256(bits1, m4);
- v[1] = _mm256_and_si256(bits2, m4);
- v[2] = _mm256_and_si256(bits3, m4);
- v[3] = _mm256_and_si256(bits4, m4);
- v[4] = _mm256_and_si256(_mm256_srli_epi16(bits1, 4), m4);
- v[5] = _mm256_and_si256(_mm256_srli_epi16(bits2, 4), m4);
- v[6] = _mm256_and_si256(_mm256_srli_epi16(bits3, 4), m4);
- v[7] = _mm256_and_si256(_mm256_srli_epi16(bits4, 4), m4);
-}
-
-inline __m256i accum_q4_0_quants(const __m256i * v, const int8_t * qs) {
- auto y4l = _mm_loadu_si128((const __m128i*)qs+0);
- auto y4h = _mm_loadu_si128((const __m128i*)qs+1);
- auto yl = MM256_SET_M128I(y4l, y4l);
- auto yh = MM256_SET_M128I(y4h, y4h);
-#ifdef HAVE_FANCY_SIMD
- auto sumi = _mm256_setzero_si256();
- sumi = _mm256_dpbusd_epi32(sumi, v[0], _mm256_shuffle_epi32(yl, 0x00));
- sumi = _mm256_dpbusd_epi32(sumi, v[1], _mm256_shuffle_epi32(yl, 0x55));
- sumi = _mm256_dpbusd_epi32(sumi, v[2], _mm256_shuffle_epi32(yl, 0xaa));
- sumi = _mm256_dpbusd_epi32(sumi, v[3], _mm256_shuffle_epi32(yl, 0xff));
- sumi = _mm256_dpbusd_epi32(sumi, v[4], _mm256_shuffle_epi32(yh, 0x00));
- sumi = _mm256_dpbusd_epi32(sumi, v[5], _mm256_shuffle_epi32(yh, 0x55));
- sumi = _mm256_dpbusd_epi32(sumi, v[6], _mm256_shuffle_epi32(yh, 0xaa));
- sumi = _mm256_dpbusd_epi32(sumi, v[7], _mm256_shuffle_epi32(yh, 0xff));
-#else
- auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(v[0], _mm256_shuffle_epi32(yl, 0x00)),
- _mm256_maddubs_epi16(v[1], _mm256_shuffle_epi32(yl, 0x55)));
- auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(v[2], _mm256_shuffle_epi32(yl, 0xaa)),
- _mm256_maddubs_epi16(v[3], _mm256_shuffle_epi32(yl, 0xff)));
- auto sumi3 = _mm256_add_epi16(_mm256_maddubs_epi16(v[4], _mm256_shuffle_epi32(yh, 0x00)),
- _mm256_maddubs_epi16(v[5], _mm256_shuffle_epi32(yh, 0x55)));
- auto sumi4 = _mm256_add_epi16(_mm256_maddubs_epi16(v[6], _mm256_shuffle_epi32(yh, 0xaa)),
- _mm256_maddubs_epi16(v[7], _mm256_shuffle_epi32(yh, 0xff)));
- auto sumi = _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_add_epi16(_mm256_add_epi16(sumi1, sumi2), _mm256_add_epi16(sumi3, sumi4)));
-#endif
- return sumi;
-}
-
-template <int nrc_y>
-static void mul_mat_q4_0_r8_q8_2_avx2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
- GGML_ASSERT(nrc_x%8 == 0);
- Q8<nrc_y, block_q8_1_x4> q8(info);
- auto m4 = _mm256_set1_epi8(0xf);
- int nb = n / QK4_NL;
- __m256i v[8];
- GGML_ASSERT(nb%4 == 0);
- if constexpr (nrc_y == 1) {
- union { __m256 vec; float val[8]; } helper;
- for (int ix = 0; ix < nrc_x; ix += 8) {
- const block_iq4_nl_r8 * iq4 = (const block_iq4_nl_r8 *)((const char *)vx + ix*bx);
- auto acc1 = _mm256_setzero_ps();
- auto acc2 = _mm256_setzero_ps();
- for (int ib4 = 0; ib4 < nb/4; ++ib4) {
- helper.vec = _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)q8.y[0][ib4].d)), 16));
- for (int k = 0; k < 4; ++k) {
- auto scales = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq4[4*ib4+k].d));
- prepare_q4_0_quants_avx2(iq4[4*ib4+k].qs, v, m4);
- auto sumi = accum_q4_0_quants(v, q8.y[0][ib4].qs+32*k);
- auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(helper.val[k]));
- acc1 = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc1);
- acc2 = _mm256_fmadd_ps(scales, _mm256_set1_ps(helper.val[k+4]), acc2);
- }
- }
- for (int ib = 4*(nb/4); ib < nb; ++ib) {
- auto qy = (const block_q8_1 *)q8.y[0];
- auto scales = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq4[ib].d));
- prepare_q4_0_quants_avx2(iq4[ib].qs, v, m4);
- auto sumi = accum_q4_0_quants(v, qy[ib].qs);
- ggml_bf16_t d{qy[ib].d}, s{qy[ib].s};
- auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(GGML_BF16_TO_FP32(d)));
- acc1 = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc1);
- acc2 = _mm256_fmadd_ps(scales, _mm256_set1_ps(GGML_BF16_TO_FP32(s)), acc2);
- }
- acc1 = _mm256_fmadd_ps(acc2, _mm256_set1_ps(-8.f), acc1);
- info.store(ix, 0, acc1);
- }
- }
- else {
- __m256 acc[nrc_y] = {};
- float d8[8*nrc_y];
- for (int ix = 0; ix < nrc_x; ix += 8) {
- const block_iq4_nl_r8 * iq4 = (const block_iq4_nl_r8 *)((const char *)vx + ix*bx);
- for (int ib4 = 0; ib4 < nb/4; ++ib4) {
- {
- __m256 d4[4];
- for (int k = 0; k < 4; ++k) {
- d4[k] = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq4[4*ib4+k].d));
- }
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto scales = _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)q8.y[iy][ib4].d)), 16));
- _mm256_storeu_ps(d8 + 8*iy, scales);
- auto m4 = _mm256_extractf128_ps(scales, 1);
- auto m8 = _mm256_set_m128(m4, m4);
- auto sumf = _mm256_mul_ps(d4[0], _mm256_shuffle_ps(m8, m8, 0x00));
- sumf = _mm256_fmadd_ps(d4[1], _mm256_shuffle_ps(m8, m8, 0x55), sumf);
- sumf = _mm256_fmadd_ps(d4[2], _mm256_shuffle_ps(m8, m8, 0xaa), sumf);
- sumf = _mm256_fmadd_ps(d4[3], _mm256_shuffle_ps(m8, m8, 0xff), sumf);
- acc[iy] = _mm256_fmadd_ps(sumf, _mm256_set1_ps(-8.f), acc[iy]);
- }
- }
- for (int k = 0; k < 4; ++k) {
- auto scales = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq4[4*ib4+k].d));
- prepare_q4_0_quants_avx2(iq4[4*ib4+k].qs, v, m4);
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto sumi = accum_q4_0_quants(v, q8.y[iy][ib4].qs+32*k);
- auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(d8[8*iy+k]));
- acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[iy]);
- }
- }
- }
- for (int ib = 4*(nb/4); ib < nb; ++ib) {
- auto scales = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq4[ib].d));
- auto scales_m = _mm256_mul_ps(scales, _mm256_set1_ps(-8.f));
- prepare_q4_0_quants_avx2(iq4[ib].qs, v, m4);
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto qy = (const block_q8_1 *)q8.y[iy];
- auto sumi = accum_q4_0_quants(v, qy[ib].qs);
- ggml_bf16_t d{qy[ib].d}, s{qy[ib].s};
- auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(GGML_BF16_TO_FP32(d)));
- acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[iy]);
- acc[iy] = _mm256_fmadd_ps(scales_m, _mm256_set1_ps(GGML_BF16_TO_FP32(s)), acc[iy]);
- }
- }
- for (int iy = 0; iy < nrc_y; ++iy) {
- info.store(ix, iy, acc[iy]);
- acc[iy] = _mm256_setzero_ps();
- }
- }
- }
-}
-
-template <int nrc_y>
-static void mul_mat_iq1_s_r4_q8_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
- GGML_ASSERT(nrc_x%4 == 0);
- Q8<nrc_y, block_q8_K128> q8(info);
- int nb = n / 32;
- GGML_ASSERT(nb%4 == 0);
- __m256i qx[4];
- __m256 acc[nrc_y] = {};
- auto m1 = _mm256_set1_epi16(1);
- auto ms = _mm_set1_epi16(-32768);
- float d8[4*nrc_y];
- union { __m256i vec; uint16_t val[16]; } helper;
- struct aux_iq1_s_r4 {
- uint8_t qs[16];
- uint64_t qh;
- };
- for (int ix = 0; ix < nrc_x; ix += 4) {
- auto dptr = (const ggml_half *)((const char *)vx + ix*bx);
- auto d1 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)dptr));
- auto x = (const aux_iq1_s_r4 *)(dptr + 4);
- for (int ib = 0; ib < nb/4; ++ib) {
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto bsums = _mm_cvtepi16_epi32(_mm_loadl_epi64((const __m128i *)q8.y[iy][ib].bsums));
- _mm_storeu_ps(d8 + 4*iy, _mm_mul_ps(_mm_set1_ps(q8.y[iy][ib].d), _mm_cvtepi32_ps(bsums)));
- }
- for (int k = 0; k < 4; ++k) {
- auto idxh = _mm256_set1_epi64x(x[4*ib+k].qh);
- auto sas = _mm256_castsi256_si128(idxh);
- auto scales4 = _mm_and_si128(_mm_srli_epi16(sas, 12), _mm_set1_epi16(7));
- scales4 = _mm_or_si128(_mm_slli_epi16(scales4, 1), _mm_set1_epi16(1));
- auto signs = _mm_or_si128(_mm_cmpeq_epi16(_mm_and_si128(sas, ms), ms), _mm256_castsi256_si128(m1));
- signs = _mm_add_epi16(_mm_set1_epi16(-8), signs);
- signs = _mm_mullo_epi16(signs, scales4);
- auto delta4 = _mm_mul_ps(_mm_set1_ps(0.0625f), _mm_cvtepi32_ps(_mm_cvtepi16_epi32(signs)));
- auto delta = _mm256_set_m128(delta4, delta4);
- scales4 = _mm_unpacklo_epi16(scales4, scales4); // 0,0, 1,1, 2,2, 3,3
- auto scales = MM256_SET_M128I(scales4, scales4);
- auto idxl = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)x[4*ib+k].qs));
- idxh = _mm256_sllv_epi64(idxh, _mm256_set_epi64x(0, 2, 5, 8));
- idxh = _mm256_srlv_epi64(idxh, _mm256_set_epi64x(1, 0, 0, 0));
- helper.vec = _mm256_or_si256(idxl, _mm256_and_si256(_mm256_set1_epi16(0x0700), idxh));
- qx[0] = _mm256_set_epi64x(iq1s_grid_us[helper.val[ 9]], iq1s_grid_us[helper.val[ 8]],
- iq1s_grid_us[helper.val[ 1]], iq1s_grid_us[helper.val[ 0]]);
- qx[1] = _mm256_set_epi64x(iq1s_grid_us[helper.val[13]], iq1s_grid_us[helper.val[12]],
- iq1s_grid_us[helper.val[ 5]], iq1s_grid_us[helper.val[ 4]]);
- qx[2] = _mm256_set_epi64x(iq1s_grid_us[helper.val[11]], iq1s_grid_us[helper.val[10]],
- iq1s_grid_us[helper.val[ 3]], iq1s_grid_us[helper.val[ 2]]);
- qx[3] = _mm256_set_epi64x(iq1s_grid_us[helper.val[15]], iq1s_grid_us[helper.val[14]],
- iq1s_grid_us[helper.val[ 7]], iq1s_grid_us[helper.val[ 6]]);
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto y = _mm256_loadu_si256((const __m256i *)q8.y[iy][ib].qs + k);
-#ifdef HAVE_FANCY_SIMD
- // 0,0, 1,1, 0,0, 1,1 as int32_t
- auto sumi1 = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_setzero_si256(),
- qx[0], _mm256_shuffle_epi32(y, 0x44)), qx[1], _mm256_shuffle_epi32(y, 0xee));
- // 2,2, 3,3, 2,2, 3,3 as int32_t
- auto sumi2 = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_setzero_si256(),
- qx[2], _mm256_shuffle_epi32(y, 0x44)), qx[3], _mm256_shuffle_epi32(y, 0xee));
- auto sumi = _mm256_packs_epi32(sumi1, sumi2);
-#else
- // 4 x row 0, 4 x row 1, 4 x row 0, 4 x row 1
- auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[0], _mm256_shuffle_epi32(y, 0x44)),
- _mm256_maddubs_epi16(qx[1], _mm256_shuffle_epi32(y, 0xee)));
- // 4 x row 2, 4 x row 3, 4 x row 2, 4 x row 3
- auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[2], _mm256_shuffle_epi32(y, 0x44)),
- _mm256_maddubs_epi16(qx[3], _mm256_shuffle_epi32(y, 0xee)));
- // 0,0, 1,1, 0,0, 1,1 as int32_t
- sumi1 = _mm256_madd_epi16(m1, sumi1);
- // 2,2, 3,3, 2,2, 3,3 as int32_t
- sumi2 = _mm256_madd_epi16(m1, sumi2);
- // 0,0, 1,1, 2,2, 3,3, 0,0, 1,1, 2,2, 3,3 as int16_t
- auto sumi = _mm256_packs_epi32(sumi1, sumi2);
-#endif
- sumi = _mm256_madd_epi16(scales, sumi);
- acc[iy] = _mm256_fmadd_ps(_mm256_set1_ps(q8.y[iy][ib].d), _mm256_cvtepi32_ps(sumi), acc[iy]);
- acc[iy] = _mm256_fmadd_ps(_mm256_set1_ps(d8[4*iy+k]), delta, acc[iy]);
- }
- }
- }
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto sumf = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1));
- info.store(ix, iy, _mm_mul_ps(d1, sumf));
- acc[iy] = _mm256_setzero_ps();
- }
- }
-}
-
-// sum[ qy_i * ls_k * (qx_i - 1 +/- delta_k)]
-// = sum[qy_i * qx_i * ls_k] - 1/8*sum[qy_i * ls_k * (8+/-o_k)]
-// = 1/8 * ( sum[qy_i * qx_i * 8*ls_k] - sum[qy_i * ls_k * (8+/-o_k)] )
-
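// Illustrative scalar reference for the identity above (a minimal sketch; the helper name
// iq1s_block_dot_ref is hypothetical and not part of this file). It assumes a 32-value block
// with unsigned grid values g_i, block scale ls = 2*s + 1, and delta = sign/8 with sign in {+1, -1}.
static inline float iq1s_block_dot_ref(const uint8_t * g, const int8_t * qy, int ls, int sign) {
    int sum_qyqx = 0, sum_qy = 0;
    for (int i = 0; i < 32; ++i) { sum_qyqx += qy[i] * g[i]; sum_qy += qy[i]; }
    // qy . ls*(g - 1 + sign/8) = 1/8 * ( 8*ls*sum(qy*g) - ls*(8 - sign)*sum(qy) )
    return 0.125f * float(8 * ls * sum_qyqx - ls * (8 - sign) * sum_qy);
}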
-template <int nrc_y>
-static void mul_mat_iq1_s_q8_K(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
- GGML_ASSERT(n%QK_K == 0);
- Q8<nrc_y, block_q8_K> q8(info);
- __m256i qx[8];
- __m256i scales[4];
- __m256 acc[nrc_y] = {};
- auto delta_mask = _mm_set1_epi16(-32768); // to avoid stupid overflow warnings when using 0x8000
- __m256i shuffle0 = _mm256_set_epi64x(0x0302030203020302, 0x0100010001000100, 0x0302030203020302, 0x0100010001000100);
- for (int ix = 0; ix < nrc_x; ++ix) {
- auto iq1s = (const block_iq1_s *)((const char *)vx + ix*bx);
- for (int ibl = 0; ibl < n/QK_K; ++ibl) {
- float d = GGML_FP16_TO_FP32(iq1s[ibl].d);
- auto qhb = _mm_loadu_si128((const __m128i *)iq1s[ibl].qh);
- auto scales128 = _mm_and_si128(_mm_srli_epi16(qhb, 12), _mm_set1_epi16(7));
- scales128 = _mm_add_epi16(_mm_slli_epi16(scales128, 1), _mm_set1_epi16(1));
-#ifdef HAVE_FANCY_SIMD
- auto mask = _mm_cmpeq_epi16_mask(_mm_and_si128(qhb, delta_mask), delta_mask);
- auto deltas128 = _mm_mask_blend_epi16(mask, _mm_set1_epi16(-7), _mm_set1_epi16(-9));
-#else
- auto mask = _mm_cmpeq_epi16(_mm_and_si128(qhb, delta_mask), delta_mask);
- auto deltas128 = _mm_or_si128(_mm_and_si128(mask, _mm_set1_epi16(-9)), _mm_andnot_si128(mask, _mm_set1_epi16(-7)));
-#endif
- deltas128 = _mm_mullo_epi16(scales128, deltas128);
- scales128 = _mm_slli_epi16(scales128, 3);
- auto deltas_l = _mm_unpacklo_epi16(deltas128, deltas128);
- auto deltas_h = _mm_unpackhi_epi16(deltas128, deltas128);
- auto deltas = MM256_SET_M128I(deltas_h, deltas_l); // blocks 0,0, 1,1, 2,2, ..., 7,7
- auto all_scales = MM256_SET_M128I(scales128, scales128);
- auto shuffle = shuffle0;
- for (int ib64 = 0; ib64 < QK_K/64; ++ib64) {
- scales[ib64] = _mm256_shuffle_epi8(all_scales, shuffle);
- shuffle = _mm256_add_epi8(shuffle, _mm256_set1_epi8(4));
- }
- const uint8_t * qs = iq1s[ibl].qs;
- const uint16_t * qh = iq1s[ibl].qh;
- for (int ib = 0; ib < QK_K/32; ib += 2) {
- qx[ib+0] = _mm256_set_epi64x(iq1s_grid_us[qs[3] | ((qh[ib+0] >> 1) & 0x700)], iq1s_grid_us[qs[2] | ((qh[ib+0] << 2) & 0x700)],
- iq1s_grid_us[qs[1] | ((qh[ib+0] << 5) & 0x700)], iq1s_grid_us[qs[0] | ((qh[ib+0] << 8) & 0x700)]);
- qx[ib+1] = _mm256_set_epi64x(iq1s_grid_us[qs[7] | ((qh[ib+1] >> 1) & 0x700)], iq1s_grid_us[qs[6] | ((qh[ib+1] << 2) & 0x700)],
- iq1s_grid_us[qs[5] | ((qh[ib+1] << 5) & 0x700)], iq1s_grid_us[qs[4] | ((qh[ib+1] << 8) & 0x700)]);
- qs += 8;
- }
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto bsums = q8.load_bsums(iy, ibl);
- auto sumi = _mm256_setzero_si256();
- for (int ib64 = 0; ib64 < QK_K/64; ++ib64) {
- auto qy1 = q8.load_quants(iy, ibl, 2*ib64+0);
- auto qy2 = q8.load_quants(iy, ibl, 2*ib64+1);
-#ifdef HAVE_FANCY_SIMD
- auto dot1 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), qx[2*ib64+0], qy1);
- auto dot2 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), qx[2*ib64+1], qy2);
- sumi = _mm256_dpwssd_epi32(sumi, scales[ib64], _mm256_packs_epi32(dot1, dot2));
-#else
- auto dot1 = _mm256_maddubs_epi16(qx[2*ib64+0], qy1);
- auto dot2 = _mm256_maddubs_epi16(qx[2*ib64+1], qy2);
- auto dot = _mm256_add_epi16(_mm256_unpacklo_epi64(dot1, dot2), _mm256_unpackhi_epi64(dot1, dot2));
- sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(scales[ib64], dot));
-#endif
- }
-#ifdef HAVE_FANCY_SIMD
- sumi = _mm256_dpwssd_epi32(sumi, bsums, deltas);
-#else
- sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(bsums, deltas));
-#endif
- acc[iy] = _mm256_fmadd_ps(_mm256_set1_ps(d*q8.scale(iy, ibl)), _mm256_cvtepi32_ps(sumi), acc[iy]);
- }
- }
- for (int iy = 0; iy < nrc_y; ++iy) {
- info.store(ix, iy, 0.125f*hsum_float_8(acc[iy]));
- acc[iy] = _mm256_setzero_ps();
- }
- }
-}
-
-template <int nrc_y>
-static void mul_mat_iq1_m_r4_q8_0(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
- GGML_ASSERT(nrc_x%4 == 0);
- Q8<nrc_y, block_q8_K128> q8(info);
- int nb = n / 32;
- GGML_ASSERT(nb%4 == 0);
- auto shuffle0 = _mm256_set_epi64x(0x0909090909090909, 0x0808080808080808, 0x0101010101010101, 0x0000000000000000);
- auto step = _mm256_set1_epi8(2);
-#ifndef HAVE_FANCY_SIMD
- auto m1 = _mm256_set1_epi16(1);
-#endif
- __m256i qx[4];
- __m256 acc[nrc_y] = {};
- __m256i isum[nrc_y] = {};
- auto ms = _mm_set1_epi8(0x08);
- union { __m256i vec; uint16_t val[16]; } helper;
- for (int ix= 0; ix < nrc_x; ix += 4) {
- auto dptr = (const ggml_half *)((const char *)vx + ix*bx);
- auto d1 = _mm_mul_ps(_mm_set1_ps(0.125f), _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)dptr)));
- auto x = (const block_iq1_m_r4 *)(dptr + 4);
- for (int ib = 0; ib < nb/4; ++ib) {
- for (int k = 0; k < 4; ++k) {
- auto qh = (const uint32_t *)x[4*ib+k].qh;
- auto idxh = _mm_set_epi32(qh[1] >> 4, qh[1], qh[0] >> 4, qh[0]);
- auto scales4 = _mm_set1_epi32(((const uint32_t *)x[4*ib+k].scales)[0]);
- scales4 = _mm_and_si128(_mm_srlv_epi32(scales4, _mm_set_epi32(4, 0, 4, 0)), _mm_set1_epi8(0xf));
- scales4 = _mm_cvtepu8_epi16(scales4);
- auto scales = MM256_SET_M128I(_mm_unpackhi_epi16(scales4, scales4), _mm_unpacklo_epi16(scales4, scales4));
-
- auto signs128 = _mm_or_si128(_mm_cmpeq_epi8(_mm_and_si128(idxh, ms), ms), _mm_set1_epi8(1));
- signs128 = _mm_add_epi8(_mm_set1_epi8(-8), signs128);
- auto signs = MM256_SET_M128I(signs128, signs128);
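-                // signs128 holds -7 or -9, selected by the 0x08 bit of the high index nibble;
-                // below, each grid byte is multiplied by 8 and this offset is added, while the
-                // remaining 1/8 factor is already part of d1 (the 0.125f above).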
- auto idxl = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)x[4*ib+k].qs));
- idxh = _mm_and_si128(idxh, _mm_set1_epi8(0x07));
- helper.vec = _mm256_or_si256(idxl, _mm256_slli_epi16(_mm256_cvtepu8_epi16(idxh), 8));
- qx[0] = _mm256_set_epi64x(iq1s_grid_us[helper.val[ 9]], iq1s_grid_us[helper.val[ 8]],
- iq1s_grid_us[helper.val[ 1]], iq1s_grid_us[helper.val[ 0]]);
- qx[1] = _mm256_set_epi64x(iq1s_grid_us[helper.val[13]], iq1s_grid_us[helper.val[12]],
- iq1s_grid_us[helper.val[ 5]], iq1s_grid_us[helper.val[ 4]]);
- qx[2] = _mm256_set_epi64x(iq1s_grid_us[helper.val[11]], iq1s_grid_us[helper.val[10]],
- iq1s_grid_us[helper.val[ 3]], iq1s_grid_us[helper.val[ 2]]);
- qx[3] = _mm256_set_epi64x(iq1s_grid_us[helper.val[15]], iq1s_grid_us[helper.val[14]],
- iq1s_grid_us[helper.val[ 7]], iq1s_grid_us[helper.val[ 6]]);
- qx[0] = _mm256_add_epi8(_mm256_slli_epi16(qx[0], 3), _mm256_shuffle_epi8(signs, shuffle0));
- auto shuffle = _mm256_add_epi8(shuffle0, step);
- qx[2] = _mm256_add_epi8(_mm256_slli_epi16(qx[2], 3), _mm256_shuffle_epi8(signs, shuffle));
- shuffle = _mm256_add_epi8(shuffle, step);
- qx[1] = _mm256_add_epi8(_mm256_slli_epi16(qx[1], 3), _mm256_shuffle_epi8(signs, shuffle));
- shuffle = _mm256_add_epi8(shuffle, step);
- qx[3] = _mm256_add_epi8(_mm256_slli_epi16(qx[3], 3), _mm256_shuffle_epi8(signs, shuffle));
- auto s0 = _mm256_sign_epi8(qx[0], qx[0]);
- auto s1 = _mm256_sign_epi8(qx[1], qx[1]);
- auto s2 = _mm256_sign_epi8(qx[2], qx[2]);
- auto s3 = _mm256_sign_epi8(qx[3], qx[3]);
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto y = _mm256_loadu_si256((const __m256i *)q8.y[iy][ib].qs + k);
- auto y1 = _mm256_shuffle_epi32(y, 0x44);
- auto y2 = _mm256_shuffle_epi32(y, 0xee);
-#ifdef HAVE_FANCY_SIMD
- // 0,0, 1,1, 0,0, 1,1 as int32_t
- auto sumi1 = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_setzero_si256(),
- s0, _mm256_sign_epi8(y1, qx[0])), s1, _mm256_sign_epi8(y2, qx[1]));
- // 2,2, 3,3, 2,2, 3,3 as int32_t
- auto sumi2 = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_setzero_si256(),
- s2, _mm256_sign_epi8(y1, qx[2])), s3, _mm256_sign_epi8(y2, qx[3]));
- auto sumi = _mm256_packs_epi32(sumi1, sumi2);
-#else
- // 4 x row 0, 4 x row 1, 4 x row 0, 4 x row 1
- auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(s0, _mm256_sign_epi8(y1, qx[0])),
- _mm256_maddubs_epi16(s1, _mm256_sign_epi8(y2, qx[1])));
- // 4 x row 2, 4 x row 3, 4 x row 2, 4 x row 3
- auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(s2, _mm256_sign_epi8(y1, qx[2])),
- _mm256_maddubs_epi16(s3, _mm256_sign_epi8(y2, qx[3])));
- // 0,0, 1,1, 0,0, 1,1 as int32_t
- sumi1 = _mm256_madd_epi16(m1, sumi1);
- // 2,2, 3,3, 2,2, 3,3 as int32_t
- sumi2 = _mm256_madd_epi16(m1, sumi2);
- // 0,0, 1,1, 2,2, 3,3, 0,0, 1,1, 2,2, 3,3 as int16_t
- auto sumi = _mm256_packs_epi32(sumi1, sumi2);
-#endif
- isum[iy] = _mm256_add_epi32(isum[iy], _mm256_madd_epi16(scales, sumi));
- }
- }
- for (int iy = 0; iy < nrc_y; ++iy) {
- acc[iy] = _mm256_fmadd_ps(_mm256_set1_ps(q8.y[iy][ib].d), _mm256_cvtepi32_ps(isum[iy]), acc[iy]);
- isum[iy] = _mm256_setzero_si256();
- }
- }
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto sumf = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1));
- info.store(ix, iy, _mm_mul_ps(d1, sumf));
- acc[iy] = _mm256_setzero_ps();
- }
- }
-}
-
-#ifdef HAVE_FANCY_SIMD
-template <int nrc_y>
-static void mul_mat_q4_0_r8_q8_2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
- if constexpr (nrc_y == 1) {
- mul_mat_q4_0_r8_q8_2_avx2<1>(n, vx, bx, info, nrc_x);
- return;
- }
- GGML_ASSERT(nrc_x%16 == 0);
- Q8<nrc_y, block_q8_1_x4> q8(info);
- auto m4 = _mm512_set1_epi8(0xf);
- int nb = n / QK4_NL;
- __m512 acc[2*nrc_y] = {};
- __m512i qx[8];
- auto prepare = [&qx, &m4] (const block_iq4_nl_r8& iq4l, const block_iq4_nl_r8& iq4h) {
- auto scales1 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq4l.d));
- auto scales2 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq4h.d));
- auto scales = _mm512_insertf32x8(_mm512_castps256_ps512(scales1), scales2, 1);
- for (int j = 0; j < 4; ++j) {
- auto bits = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq4l.qs+j)),
- _mm256_loadu_si256((const __m256i *)iq4h.qs+j), 1);
- qx[j+0] = _mm512_and_si512(bits, m4);
- qx[j+4] = _mm512_and_si512(_mm512_srli_epi16(bits, 4), m4);
- }
- return scales;
- };
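-    // prepare(): loads the fp16 scales of the 8 low and 8 high rows into one __m512 and splits
-    // the packed 4-bit quants into low nibbles (qx[0..3]) and high nibbles (qx[4..7]).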
- auto dot = [&qx] (const int8_t * qy) {
- auto y4l = _mm_loadu_si128((const __m128i*)qy+0);
- auto y4h = _mm_loadu_si128((const __m128i*)qy+1);
- auto y8l = MM256_SET_M128I(y4l, y4l);
- auto y8h = MM256_SET_M128I(y4h, y4h);
- auto yl = _mm512_inserti32x8(_mm512_castsi256_si512(y8l), y8l, 1);
- auto yh = _mm512_inserti32x8(_mm512_castsi256_si512(y8h), y8h, 1);
- auto sumi = _mm512_setzero_si512();
- sumi = _mm512_dpbusd_epi32(sumi, qx[0], _mm512_shuffle_epi32(yl, _MM_PERM_ENUM(0x00)));
- sumi = _mm512_dpbusd_epi32(sumi, qx[1], _mm512_shuffle_epi32(yl, _MM_PERM_ENUM(0x55)));
- sumi = _mm512_dpbusd_epi32(sumi, qx[2], _mm512_shuffle_epi32(yl, _MM_PERM_ENUM(0xaa)));
- sumi = _mm512_dpbusd_epi32(sumi, qx[3], _mm512_shuffle_epi32(yl, _MM_PERM_ENUM(0xff)));
- sumi = _mm512_dpbusd_epi32(sumi, qx[4], _mm512_shuffle_epi32(yh, _MM_PERM_ENUM(0x00)));
- sumi = _mm512_dpbusd_epi32(sumi, qx[5], _mm512_shuffle_epi32(yh, _MM_PERM_ENUM(0x55)));
- sumi = _mm512_dpbusd_epi32(sumi, qx[6], _mm512_shuffle_epi32(yh, _MM_PERM_ENUM(0xaa)));
- sumi = _mm512_dpbusd_epi32(sumi, qx[7], _mm512_shuffle_epi32(yh, _MM_PERM_ENUM(0xff)));
- return sumi;
- };
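-    // dot(): broadcasts the two 16-byte halves of the q8 block across all 128-bit lanes and
-    // accumulates eight dpbusd products, one per group of 4 consecutive y values against qx[0..7].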
- float d8[8*nrc_y];
- for (int ix = 0; ix < nrc_x; ix += 16) {
- const block_iq4_nl_r8 * iq4l = (const block_iq4_nl_r8 *)((const char *)vx + (ix+0)*bx);
- const block_iq4_nl_r8 * iq4h = (const block_iq4_nl_r8 *)((const char *)vx + (ix+8)*bx);
- for (int ib4 = 0; ib4 < nb/4; ++ib4) {
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto aux = _mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)q8.y[iy][ib4].d)), 16);
- _mm256_storeu_ps(d8+8*iy, _mm256_castsi256_ps(aux));
- }
- for (int k = 0; k < 4; ++k) {
- auto scales = prepare(iq4l[4*ib4+k], iq4h[4*ib4+k]);
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto sumi = dot(q8.y[iy][ib4].qs+32*k);
- auto dy = _mm512_set1_ps(d8[8*iy+k]);
- acc[2*iy+0] = _mm512_fmadd_ps(_mm512_mul_ps(scales, dy), _mm512_cvtepi32_ps(sumi), acc[2*iy+0]);
- acc[2*iy+1] = _mm512_fmadd_ps(scales, _mm512_set1_ps(d8[8*iy+k+4]), acc[2*iy+1]);
- }
- }
- }
- for (int ib = 4*(nb/4); ib < nb; ++ib) {
- auto scales = prepare(iq4l[ib], iq4h[ib]);
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto qy = (const block_q8_1 *)q8.y[iy];
- auto sumi = dot(qy[ib].qs);
- ggml_bf16_t d{qy[ib].d}, s{qy[ib].s};
- auto dy = _mm512_set1_ps(GGML_BF16_TO_FP32(d));
- acc[2*iy+0] = _mm512_fmadd_ps(_mm512_mul_ps(scales, dy), _mm512_cvtepi32_ps(sumi), acc[2*iy+0]);
- acc[2*iy+1] = _mm512_fmadd_ps(scales, _mm512_set1_ps(GGML_BF16_TO_FP32(s)), acc[2*iy+1]);
- }
- }
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto sum = _mm512_fmadd_ps(_mm512_set1_ps(-8.f), acc[2*iy+1], acc[2*iy+0]);
- acc[2*iy+0] = acc[2*iy+1] = _mm512_setzero_ps();
- info.store(ix, iy, sum);
- }
- }
-}
-#else
-template <int nrc_y>
-static void mul_mat_q4_0_r8_q8_2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
- mul_mat_q4_0_r8_q8_2_avx2<nrc_y>(n, vx, bx, info, nrc_x);
-}
-#endif
-
-template <int nrc_y>
-static void mul_mat_q5_0_r4_q8_2_avx2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
- GGML_ASSERT(nrc_x%4 == 0);
- Q8<nrc_y, block_q8_2_x4> q8(info);
- auto m4 = _mm256_set1_epi8(0xf);
- auto m5 = _mm256_set1_epi8(0x10);
-#ifndef HAVE_FANCY_SIMD
- auto m1 = _mm256_set1_epi16(1);
-#endif
- auto mscale = _mm256_set_m128(_mm_set1_ps(-8.f), _mm_set1_ps(1.f));
- int nb = n / QK5_0;
- __m256 acc[nrc_y] = {};
- __m256i qx[4];
- float d8[8*nrc_y];
- auto prepare = [&qx, &m4, &m5] (const block_q5_0_r4& iq5) {
- auto scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq5.d));
- auto scales = _mm256_set_m128(scales128, scales128);
- auto bits1 = _mm256_loadu_si256((const __m256i *)iq5.qs+0);
- auto bits2 = _mm256_loadu_si256((const __m256i *)iq5.qs+1);
- auto hbits = _mm_loadu_si128((const __m128i *)iq5.qh);
- auto hb = MM256_SET_M128I(_mm_srli_epi16(hbits, 1), hbits);
- qx[0] = _mm256_or_si256(_mm256_and_si256(bits1, m4), _mm256_and_si256(_mm256_slli_epi16(hb, 4), m5));
- qx[1] = _mm256_or_si256(_mm256_and_si256(bits2, m4), _mm256_and_si256(_mm256_slli_epi16(hb, 2), m5));
- qx[2] = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(bits1, 4), m4), _mm256_and_si256(hb, m5));
-        qx[3] = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(bits2, 4), m4), _mm256_and_si256(_mm256_srli_epi16(hb, 2), m5));
- return scales;
- };
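-    // prepare(): reconstructs the 5-bit quants (0..31) from the low nibble in qs and the 5th bit
-    // in qh (m5 = 0x10), and broadcasts the 4 fp16 row scales to both 128-bit halves.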
-#ifdef HAVE_FANCY_SIMD
- auto dot = [&qx] (__m256i y) {
- auto sumi = _mm256_setzero_si256();
- sumi = _mm256_dpbusd_epi32(sumi, qx[0], _mm256_shuffle_epi32(y, 0x00));
- sumi = _mm256_dpbusd_epi32(sumi, qx[1], _mm256_shuffle_epi32(y, 0x55));
- sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(y, 0xaa));
- sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(y, 0xff));
- return sumi;
- };
-#else
- auto dot = [&qx, &m1] (__m256i y) {
- auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[0], _mm256_shuffle_epi32(y, 0x00)),
- _mm256_maddubs_epi16(qx[1], _mm256_shuffle_epi32(y, 0x55)));
- auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[2], _mm256_shuffle_epi32(y, 0xaa)),
- _mm256_maddubs_epi16(qx[3], _mm256_shuffle_epi32(y, 0xff)));
- auto sumi = _mm256_madd_epi16(m1, _mm256_add_epi16(sumi1, sumi2));
- return sumi;
- };
-#endif
- for (int ix = 0; ix < nrc_x; ix += 4) {
- const block_q5_0_r4 * iq5 = (const block_q5_0_r4 *)((const char *)vx + ix*bx);
- for (int ib4 = 0; ib4 < nb/4; ++ib4) {
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto scales = _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)q8.y[iy][ib4].d)), 16));
- _mm256_storeu_ps(d8 + 8*iy, _mm256_mul_ps(mscale, scales));
- }
- for (int k = 0; k < 4; ++k) {
- auto scales = prepare(iq5[4*ib4+k]);
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto sumi = dot(_mm256_loadu_si256((const __m256i*)q8.y[iy][ib4].qs+k));
- auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(d8[8*iy+k]));
- acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[iy]);
- acc[iy] = _mm256_fmadd_ps(scales, _mm256_set1_ps(d8[8*iy+k+4]), acc[iy]);
- }
- }
- }
- for (int ib = 4*(nb/4); ib < nb; ++ib) {
- auto scales = prepare(iq5[ib]);
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto qy = (const block_q8_1 *)q8.y[iy];
- auto sumi = dot(_mm256_loadu_si256((const __m256i*)qy[ib].qs));
- ggml_bf16_t d{qy[ib].d}, s{qy[ib].s};
- auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(GGML_BF16_TO_FP32(d)));
- acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[iy]);
- acc[iy] = _mm256_fmadd_ps(scales, _mm256_set1_ps(-8.f*GGML_BF16_TO_FP32(s)), acc[iy]);
- }
- }
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1));
- info.store(ix, iy, sum);
- acc[iy] = _mm256_setzero_ps();
- }
- }
-}
-
-#ifdef HAVE_FANCY_SIMD
-template <int nrc_y>
-static void mul_mat_q5_0_r4_q8_2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
- if constexpr (nrc_y == 1) {
- mul_mat_q5_0_r4_q8_2_avx2<1>(n, vx, bx, info, nrc_x);
- } else {
- GGML_ASSERT(nrc_x%8 == 0);
- Q8<nrc_y, block_q8_2_x4> q8(info);
- auto m4 = _mm512_set1_epi8(0xf);
- auto m5 = _mm512_set1_epi8(0x10);
- int nb = n / QK5_0;
- __m512 acc[2*nrc_y] = {};
- __m512i qx[4];
- float d8[8*nrc_y];
- auto prepare = [&qx, &m4, &m5] (const block_q5_0_r4& iq5l, const block_q5_0_r4& iq5h) {
- auto scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq5l.d));
- auto scales1 = _mm256_set_m128(scales128, scales128);
- scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq5h.d));
- auto scales2 = _mm256_set_m128(scales128, scales128);
- auto scales = _mm512_insertf32x8(_mm512_castps256_ps512(scales1), scales2, 1);
- auto bits1 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq5l.qs+0)),
- _mm256_loadu_si256((const __m256i *)iq5h.qs+0), 1);
- auto bits2 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq5l.qs+1)),
- _mm256_loadu_si256((const __m256i *)iq5h.qs+1), 1);
- auto hbits1 = _mm_loadu_si128((const __m128i *)iq5l.qh);
- auto hbits2 = _mm_loadu_si128((const __m128i *)iq5h.qh);
- auto hb1 = MM256_SET_M128I(_mm_srli_epi16(hbits1, 1), hbits1);
- auto hb2 = MM256_SET_M128I(_mm_srli_epi16(hbits2, 1), hbits2);
- auto hb = _mm512_inserti32x8(_mm512_castsi256_si512(hb1), hb2, 1);
- qx[0] = _mm512_or_si512(_mm512_and_si512(bits1, m4), _mm512_and_si512(_mm512_slli_epi16(hb, 4), m5));
- qx[1] = _mm512_or_si512(_mm512_and_si512(bits2, m4), _mm512_and_si512(_mm512_slli_epi16(hb, 2), m5));
- qx[2] = _mm512_or_si512(_mm512_and_si512(_mm512_srli_epi16(bits1, 4), m4), _mm512_and_si512(hb, m5));
- qx[3] = _mm512_or_si512(_mm512_and_si512(_mm512_srli_epi16(bits2, 4), m4), _mm512_and_si512(_mm512_srli_epi16(hb, 2), m5));
- return scales;
- };
- auto dot = [&qx] (__m256i y8) {
- auto y = _mm512_inserti32x8(_mm512_castsi256_si512(y8), y8, 1);
- auto sumi = _mm512_setzero_si512();
- sumi = _mm512_dpbusd_epi32(sumi, qx[0], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00)));
- sumi = _mm512_dpbusd_epi32(sumi, qx[1], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55)));
- sumi = _mm512_dpbusd_epi32(sumi, qx[2], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa)));
- sumi = _mm512_dpbusd_epi32(sumi, qx[3], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff)));
- return sumi;
- };
- for (int ix = 0; ix < nrc_x; ix += 8) {
- const block_q5_0_r4 * iq5l = (const block_q5_0_r4 *)((const char *)vx + (ix+0)*bx);
- const block_q5_0_r4 * iq5h = (const block_q5_0_r4 *)((const char *)vx + (ix+4)*bx);
- for (int ib4 = 0; ib4 < nb/4; ++ib4) {
- for (int iy = 0; iy < nrc_y; ++iy) {
- _mm256_storeu_ps(d8+8*iy, _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)q8.y[iy][ib4].d)), 16)));
- }
- for (int k = 0; k < 4; ++k) {
- auto scales = prepare(iq5l[4*ib4+k], iq5h[4*ib4+k]);
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto sumi = dot(_mm256_loadu_si256((const __m256i*)q8.y[iy][ib4].qs+k));
- auto dy = _mm512_set1_ps(d8[8*iy+k]);
- acc[2*iy+0] = _mm512_fmadd_ps(_mm512_mul_ps(scales, dy), _mm512_cvtepi32_ps(sumi), acc[2*iy+0]);
- acc[2*iy+1] = _mm512_fmadd_ps(scales, _mm512_set1_ps(d8[8*iy+k+4]), acc[2*iy+1]);
- }
- }
- }
- for (int ib = 4*(nb/4); ib < nb; ++ib) {
- auto scales = prepare(iq5l[ib], iq5h[ib]);
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto qy = (const block_q8_1 *)q8.y[iy];
- auto sumi = dot(_mm256_loadu_si256((const __m256i*)qy[ib].qs));
- ggml_bf16_t d{qy[ib].d}, s{qy[ib].s};
- auto dy = _mm512_set1_ps(GGML_BF16_TO_FP32(d));
- acc[2*iy+0] = _mm512_fmadd_ps(_mm512_mul_ps(scales, dy), _mm512_cvtepi32_ps(sumi), acc[2*iy+0]);
- acc[2*iy+1] = _mm512_fmadd_ps(scales, _mm512_set1_ps(GGML_BF16_TO_FP32(s)), acc[2*iy+1]);
- }
- }
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto sum512 = _mm512_fmadd_ps(_mm512_set1_ps(-8.f), acc[2*iy+1], acc[2*iy+0]);
- acc[2*iy+0] = acc[2*iy+1] = _mm512_setzero_ps();
- auto sum1 = _mm_add_ps(_mm512_extractf32x4_ps(sum512, 0), _mm512_extractf32x4_ps(sum512, 1));
- auto sum2 = _mm_add_ps(_mm512_extractf32x4_ps(sum512, 2), _mm512_extractf32x4_ps(sum512, 3));
- info.store(ix+0, iy, sum1);
- info.store(ix+4, iy, sum2);
- }
- }
- }
-}
-#else
-template <int nrc_y>
-static void mul_mat_q5_0_r4_q8_2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
- mul_mat_q5_0_r4_q8_2_avx2<nrc_y>(n, vx, bx, info, nrc_x);
-}
-#endif
-
-template <int nrc_y>
-static void mul_mat_q6_0_r4_q8_2_avx2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
- GGML_ASSERT(nrc_x%4 == 0);
- Q8<nrc_y, block_q8_2_x4> q8(info);
- auto m4 = _mm256_set1_epi8(0xf);
- auto m6 = _mm256_set1_epi8(0x30);
- auto mscale = _mm256_set_m128(_mm_set1_ps(-16.f), _mm_set1_ps(1.f));
-#ifndef HAVE_FANCY_SIMD
- auto m1 = _mm256_set1_epi16(1);
-#endif
- int nb = n / QK6_0;
- __m256 acc[nrc_y] = {};
- float d8[8*nrc_y];
- __m256i qx[4];
- auto prepare = [&qx, &m4, &m6] (const block_q6_0_r4& iq6) {
- auto scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq6.d));
- auto scales = _mm256_set_m128(scales128, scales128);
- auto bits1 = _mm256_loadu_si256((const __m256i *)iq6.qs+0);
- auto bits2 = _mm256_loadu_si256((const __m256i *)iq6.qs+1);
- auto hbits = _mm256_loadu_si256((const __m256i *)iq6.qh);
- qx[0] = _mm256_or_si256(_mm256_and_si256(bits1, m4), _mm256_and_si256(_mm256_slli_epi16(hbits, 4), m6));
- qx[1] = _mm256_or_si256(_mm256_and_si256(bits2, m4), _mm256_and_si256(_mm256_slli_epi16(hbits, 2), m6));
- qx[2] = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(bits1, 4), m4), _mm256_and_si256(hbits, m6));
- qx[3] = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(bits2, 4), m4), _mm256_and_si256(_mm256_srli_epi16(hbits, 2), m6));
- return scales;
- };
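-    // prepare(): reconstructs the 6-bit quants (0..63) from the low nibble in qs and the two high
-    // bits in qh (m6 = 0x30), and broadcasts the 4 fp16 row scales to both 128-bit halves.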
-#ifdef HAVE_FANCY_SIMD
- auto dot = [&qx] (__m256i y) {
- auto sumi = _mm256_dpbusd_epi32(_mm256_setzero_si256(), qx[0], _mm256_shuffle_epi32(y, 0x00));
- sumi = _mm256_dpbusd_epi32(sumi, qx[1], _mm256_shuffle_epi32(y, 0x55));
- sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(y, 0xaa));
- sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(y, 0xff));
- return sumi;
- };
-#else
- auto dot = [&qx, &m1] (__m256i y) {
- auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[0], _mm256_shuffle_epi32(y, 0x00)),
- _mm256_maddubs_epi16(qx[1], _mm256_shuffle_epi32(y, 0x55)));
- auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[2], _mm256_shuffle_epi32(y, 0xaa)),
- _mm256_maddubs_epi16(qx[3], _mm256_shuffle_epi32(y, 0xff)));
- auto sumi = _mm256_add_epi32(_mm256_madd_epi16(m1, sumi1), _mm256_madd_epi16(m1, sumi2));
- return sumi;
- };
-#endif
- for (int ix = 0; ix < nrc_x; ix += 4) {
- const block_q6_0_r4 * iq6 = (const block_q6_0_r4 *)((const char *)vx + ix*bx);
- for (int ib4 = 0; ib4 < nb/4; ++ib4) {
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto scales = _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)q8.y[iy][ib4].d)), 16));
- _mm256_storeu_ps(d8 + 8*iy, _mm256_mul_ps(scales, mscale));
- }
- for (int k = 0; k < 4; ++k) {
- auto scales = prepare(iq6[4*ib4+k]);
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto sumi = dot(_mm256_loadu_si256((const __m256i*)q8.y[iy][ib4].qs+k));
- auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(d8[8*iy+k]));
- acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[iy]);
- acc[iy] = _mm256_fmadd_ps(scales, _mm256_set1_ps(d8[8*iy+k+4]), acc[iy]);
- }
- }
- }
- for (int ib = 4*(nb/4); ib < nb; ++ib) {
- auto scales = prepare(iq6[ib]);
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto qy = (const block_q8_1 *)q8.y[iy];
- auto sumi = dot(_mm256_loadu_si256((const __m256i*)qy[ib].qs));
- ggml_bf16_t d{qy[ib].d}, s{qy[ib].s};
- auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(GGML_BF16_TO_FP32(d)));
- acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[iy]);
- acc[iy] = _mm256_fmadd_ps(scales, _mm256_set1_ps(-16.f*GGML_BF16_TO_FP32(s)), acc[iy]);
- }
- }
-
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1));
- info.store(ix, iy, sum);
- acc[iy] = _mm256_setzero_ps();
- }
- }
-}
-
-#ifdef HAVE_FANCY_SIMD
-template <int nrc_y>
-static void mul_mat_q6_0_r4_q8_2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
- if constexpr (nrc_y == 1) {
- mul_mat_q6_0_r4_q8_2_avx2<1>(n, vx, bx, info, nrc_x);
- } else {
- GGML_ASSERT(nrc_x%8 == 0);
- Q8<nrc_y, block_q8_2_x4> q8(info);
- auto m4 = _mm512_set1_epi8(0xf);
- auto m6 = _mm512_set1_epi8(0x30);
- int nb = n / QK6_0;
- __m512 acc[2*nrc_y] = {};
- __m512i qx[4];
- float d8[8*nrc_y];
- auto prepare = [&qx, &m4, &m6] (const block_q6_0_r4& iq6l, const block_q6_0_r4& iq6h) {
- auto scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq6l.d));
- auto scales1 = _mm256_set_m128(scales128, scales128);
- scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq6h.d));
- auto scales2 = _mm256_set_m128(scales128, scales128);
- auto scales = _mm512_insertf32x8(_mm512_castps256_ps512(scales1), scales2, 1);
- auto bits1 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq6l.qs+0)),
- _mm256_loadu_si256((const __m256i *)iq6h.qs+0), 1);
- auto bits2 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq6l.qs+1)),
- _mm256_loadu_si256((const __m256i *)iq6h.qs+1), 1);
- auto hbits1 = _mm256_loadu_si256((const __m256i *)iq6l.qh);
- auto hbits2 = _mm256_loadu_si256((const __m256i *)iq6h.qh);
- auto hb = _mm512_inserti32x8(_mm512_castsi256_si512(hbits1), hbits2, 1);
-            qx[0] = _mm512_or_si512(_mm512_and_si512(bits1, m4), _mm512_and_si512(_mm512_slli_epi16(hb, 4), m6));
-            qx[1] = _mm512_or_si512(_mm512_and_si512(bits2, m4), _mm512_and_si512(_mm512_slli_epi16(hb, 2), m6));
-            qx[2] = _mm512_or_si512(_mm512_and_si512(_mm512_srli_epi16(bits1, 4), m4), _mm512_and_si512(hb, m6));
-            qx[3] = _mm512_or_si512(_mm512_and_si512(_mm512_srli_epi16(bits2, 4), m4), _mm512_and_si512(_mm512_srli_epi16(hb, 2), m6));
- return scales;
- };
- auto dot = [&qx] (__m256i y8) {
- auto y = _mm512_inserti32x8(_mm512_castsi256_si512(y8), y8, 1);
- auto sumi = _mm512_setzero_si512();
- sumi = _mm512_dpbusd_epi32(sumi, qx[0], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00)));
- sumi = _mm512_dpbusd_epi32(sumi, qx[1], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55)));
- sumi = _mm512_dpbusd_epi32(sumi, qx[2], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa)));
- sumi = _mm512_dpbusd_epi32(sumi, qx[3], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff)));
- return sumi;
- };
- for (int ix = 0; ix < nrc_x; ix += 8) {
- const block_q6_0_r4 * iq6l = (const block_q6_0_r4 *)((const char *)vx + (ix+0)*bx);
- const block_q6_0_r4 * iq6h = (const block_q6_0_r4 *)((const char *)vx + (ix+4)*bx);
- for (int ib4 = 0; ib4 < nb/4; ++ib4) {
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto scales = _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)q8.y[iy][ib4].d)), 16));
- _mm256_storeu_ps(d8 + 8*iy, scales);
- }
- for (int k = 0; k < 4; ++k) {
- auto scales = prepare(iq6l[4*ib4+k], iq6h[4*ib4+k]);
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto sumi = dot(_mm256_loadu_si256((const __m256i*)q8.y[iy][ib4].qs+k));
- auto dy = _mm512_set1_ps(d8[8*iy+k]);
- acc[2*iy+0] = _mm512_fmadd_ps(_mm512_mul_ps(scales, dy), _mm512_cvtepi32_ps(sumi), acc[2*iy+0]);
- acc[2*iy+1] = _mm512_fmadd_ps(scales, _mm512_set1_ps(d8[8*iy+k+4]), acc[2*iy+1]);
- }
- }
- }
- for (int ib = 4*(nb/4); ib < nb; ++ib) {
- auto scales = prepare(iq6l[ib], iq6h[ib]);
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto qy = (const block_q8_1 *)q8.y[iy];
- auto sumi = dot(_mm256_loadu_si256((const __m256i*)qy[ib].qs));
- ggml_bf16_t d{qy[ib].d}, s{qy[ib].s};
- auto dy = _mm512_set1_ps(GGML_BF16_TO_FP32(d));
- acc[2*iy+0] = _mm512_fmadd_ps(_mm512_mul_ps(scales, dy), _mm512_cvtepi32_ps(sumi), acc[2*iy+0]);
- acc[2*iy+1] = _mm512_fmadd_ps(scales, _mm512_set1_ps(GGML_BF16_TO_FP32(s)), acc[2*iy+1]);
- }
- }
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto sum512 = _mm512_fmadd_ps(_mm512_set1_ps(-16.f), acc[2*iy+1], acc[2*iy+0]);
- acc[2*iy+0] = acc[2*iy+1] = _mm512_setzero_ps();
- auto sum1 = _mm_add_ps(_mm512_extractf32x4_ps(sum512, 0), _mm512_extractf32x4_ps(sum512, 1));
- auto sum2 = _mm_add_ps(_mm512_extractf32x4_ps(sum512, 2), _mm512_extractf32x4_ps(sum512, 3));
- info.store(ix+0, iy, sum1);
- info.store(ix+4, iy, sum2);
- }
- }
- }
-}
-#else
-template <int nrc_y>
-static void mul_mat_q6_0_r4_q8_2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
- mul_mat_q6_0_r4_q8_2_avx2<nrc_y>(n, vx, bx, info, nrc_x);
-}
-#endif
-
-#ifdef HAVE_FANCY_SIMD
-inline __m512i qx_r8_q8_dot_product(const __m512i * qx, const int8_t * y) {
- auto y4l = _mm_loadu_si128((const __m128i*)y+0);
- auto y4h = _mm_loadu_si128((const __m128i*)y+1);
- auto y8l = MM256_SET_M128I(y4l, y4l);
- auto y8h = MM256_SET_M128I(y4h, y4h);
- auto yl = _mm512_inserti32x8(_mm512_castsi256_si512(y8l), y8l, 1);
- auto yh = _mm512_inserti32x8(_mm512_castsi256_si512(y8h), y8h, 1);
- auto sumi = _mm512_setzero_si512();
- sumi = _mm512_dpbusd_epi32(sumi, qx[0], _mm512_shuffle_epi32(yl, _MM_PERM_ENUM(0x00)));
- sumi = _mm512_dpbusd_epi32(sumi, qx[1], _mm512_shuffle_epi32(yl, _MM_PERM_ENUM(0x55)));
- sumi = _mm512_dpbusd_epi32(sumi, qx[2], _mm512_shuffle_epi32(yl, _MM_PERM_ENUM(0xaa)));
- sumi = _mm512_dpbusd_epi32(sumi, qx[3], _mm512_shuffle_epi32(yl, _MM_PERM_ENUM(0xff)));
- sumi = _mm512_dpbusd_epi32(sumi, qx[4], _mm512_shuffle_epi32(yh, _MM_PERM_ENUM(0x00)));
- sumi = _mm512_dpbusd_epi32(sumi, qx[5], _mm512_shuffle_epi32(yh, _MM_PERM_ENUM(0x55)));
- sumi = _mm512_dpbusd_epi32(sumi, qx[6], _mm512_shuffle_epi32(yh, _MM_PERM_ENUM(0xaa)));
- sumi = _mm512_dpbusd_epi32(sumi, qx[7], _mm512_shuffle_epi32(yh, _MM_PERM_ENUM(0xff)));
- return sumi;
-}
-inline __m256i qx_r8_q8_dot_product(const __m256i * qx, const int8_t * y) {
- auto y4l = _mm_loadu_si128((const __m128i*)y+0);
- auto y4h = _mm_loadu_si128((const __m128i*)y+1);
- auto yl = MM256_SET_M128I(y4l, y4l);
- auto yh = MM256_SET_M128I(y4h, y4h);
- auto sumi = _mm256_setzero_si256();
- sumi = _mm256_dpbusd_epi32(sumi, qx[0], _mm256_shuffle_epi32(yl, 0x00));
- sumi = _mm256_dpbusd_epi32(sumi, qx[1], _mm256_shuffle_epi32(yl, 0x55));
- sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(yl, 0xaa));
- sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(yl, 0xff));
- sumi = _mm256_dpbusd_epi32(sumi, qx[4], _mm256_shuffle_epi32(yh, 0x00));
- sumi = _mm256_dpbusd_epi32(sumi, qx[5], _mm256_shuffle_epi32(yh, 0x55));
- sumi = _mm256_dpbusd_epi32(sumi, qx[6], _mm256_shuffle_epi32(yh, 0xaa));
- sumi = _mm256_dpbusd_epi32(sumi, qx[7], _mm256_shuffle_epi32(yh, 0xff));
- return sumi;
-}
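-// q8_0_r8_dot_product() biases the x quants by +127 so they can be used as the unsigned operand
-// of dpbusd; the callers accumulate the resulting 127*sum(y) term separately (via the q8 block
-// sums) and remove it at the end with the -127.f fmadd.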
-inline __m256i q8_0_r8_dot_product(const uint8_t * x, const int8_t * y, __m256i * qx) {
- for (int i = 0; i < 8; ++i) {
- qx[i] = _mm256_add_epi8(_mm256_loadu_si256((const __m256i *)x+i), _mm256_set1_epi8(127));
- }
- return qx_r8_q8_dot_product(qx, y);
-}
-template <int nrc_y>
-static void mul_mat_q8_0_r8_q8_2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
- GGML_ASSERT(nrc_x%16 == 0);
- Q8<nrc_y, block_q8_2_x4> q8(info);
- int nb = n / QK8_0;
- if constexpr (nrc_y == 1) {
- __m256 acc[2] = {};
- __m256i qx[8];
- float d8[8];
- for (int ix = 0; ix < nrc_x; ix += 8) {
- const block_q8_0_r8 * iq8 = (const block_q8_0_r8 *)((const char *)vx + ix*bx);
- for (int ib4 = 0; ib4 < nb/4; ++ib4) {
- auto aux = _mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)q8.y[0][ib4].d)), 16);
- _mm256_storeu_ps(d8, _mm256_castsi256_ps(aux));
- for (int k = 0; k < 4; ++k) {
- auto scales = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq8[4*ib4+k].d));
- auto sumi = q8_0_r8_dot_product((const uint8_t *)iq8[4*ib4+k].qs, q8.y[0][ib4].qs+32*k, qx);
- auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(d8[k]));
- acc[0] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[0]);
- acc[1] = _mm256_fmadd_ps(scales, _mm256_set1_ps(d8[k+4]), acc[1]);
- }
- }
- if (4*(nb/4) < nb) {
- auto qy = (const block_q8_1 *)q8.y[0];
- for (int ib = 4*(nb/4); ib < nb; ++ib) {
- auto scales = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq8[ib].d));
- auto sumi = q8_0_r8_dot_product((const uint8_t *)iq8[ib].qs, qy[ib].qs, qx);
- ggml_bf16_t d, s; d.bits = qy[ib].d; s.bits = qy[ib].s;
- auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(GGML_BF16_TO_FP32(d)));
- acc[0] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[0]);
- acc[1] = _mm256_fmadd_ps(scales, _mm256_set1_ps(GGML_BF16_TO_FP32(s)), acc[1]);
- }
- }
- info.store(ix, 0, _mm256_fmadd_ps(_mm256_set1_ps(-127.f), acc[1], acc[0]));
- acc[0] = acc[1] = _mm256_setzero_ps();
- }
- } else {
- __m512 acc[2*nrc_y] = {};
- __m512i qx[8];
- float d8[8*nrc_y];
- for (int ix = 0; ix < nrc_x; ix += 16) {
- const block_q8_0_r8 * q8l = (const block_q8_0_r8 *)((const char *)vx + (ix+0)*bx);
- const block_q8_0_r8 * q8h = (const block_q8_0_r8 *)((const char *)vx + (ix+8)*bx);
- for (int ib4 = 0; ib4 < nb/4; ++ib4) {
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto aux = _mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)q8.y[iy][ib4].d)), 16);
- _mm256_storeu_ps(d8+8*iy, _mm256_castsi256_ps(aux));
- }
- for (int k = 0; k < 4; ++k) {
- auto scales1 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)q8l[4*ib4+k].d));
- auto scales2 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)q8h[4*ib4+k].d));
- auto scales = _mm512_insertf32x8(_mm512_castps256_ps512(scales1), scales2, 1);
- for (int j = 0; j < 8; ++j) {
- qx[j] = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)q8l[4*ib4+k].qs+j)),
- _mm256_loadu_si256((const __m256i *)q8h[4*ib4+k].qs+j), 1);
- qx[j] = _mm512_add_epi8(qx[j], _mm512_set1_epi8(127));
- }
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto sumi = qx_r8_q8_dot_product(qx, q8.y[iy][ib4].qs+32*k);
- auto dy = _mm512_set1_ps(d8[8*iy+k]);
- acc[2*iy+0] = _mm512_fmadd_ps(_mm512_mul_ps(scales, dy), _mm512_cvtepi32_ps(sumi), acc[2*iy+0]);
- acc[2*iy+1] = _mm512_fmadd_ps(scales, _mm512_set1_ps(d8[8*iy+k+4]), acc[2*iy+1]);
- }
- }
- }
- for (int ib = 4*(nb/4); ib < nb; ++ib) {
- auto scales1 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)q8l[ib].d));
- auto scales2 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)q8h[ib].d));
- auto scales = _mm512_insertf32x8(_mm512_castps256_ps512(scales1), scales2, 1);
- for (int j = 0; j < 8; ++j) {
- qx[j] = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)q8l[ib].qs+j)),
- _mm256_loadu_si256((const __m256i *)q8h[ib].qs+j), 1);
- qx[j] = _mm512_add_epi8(qx[j], _mm512_set1_epi8(127));
- }
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto qy = (const block_q8_1 *)q8.y[iy];
- auto sumi = qx_r8_q8_dot_product(qx, qy[ib].qs);
- ggml_bf16_t d, s; d.bits = qy[ib].d; s.bits = qy[ib].s;
- auto dy = _mm512_set1_ps(GGML_BF16_TO_FP32(d));
- acc[2*iy+0] = _mm512_fmadd_ps(_mm512_mul_ps(scales, dy), _mm512_cvtepi32_ps(sumi), acc[2*iy+0]);
- acc[2*iy+1] = _mm512_fmadd_ps(scales, _mm512_set1_ps(GGML_BF16_TO_FP32(s)), acc[2*iy+1]);
- }
- }
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto sum512 = _mm512_fmadd_ps(_mm512_set1_ps(-127.f), acc[2*iy+1], acc[2*iy+0]);
- info.store(ix, iy, sum512);
- acc[2*iy+0] = acc[2*iy+1] = _mm512_setzero_ps();
- }
- }
- }
-}
-#else
-template <int nrc_y>
-static void mul_mat_q8_0_r8_q8_2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
- GGML_ASSERT(nrc_x%8 == 0);
- Q8<nrc_y, block_q8_2_x4> q8(info);
- auto m1 = _mm256_set1_epi16(1);
- int nb = n / QK8_0;
- __m256 acc[nrc_y] = {};
- float d8[4*nrc_y];
- __m256i qx[4], sx[4];
- auto dot = [&qx, &sx, &m1] (const int8_t * qy) {
- auto y128 = _mm_loadu_si128((const __m128i*)qy);
- auto y = MM256_SET_M128I(y128, y128);
- auto sumi1 = _mm256_add_epi32(
- _mm256_madd_epi16(m1, _mm256_maddubs_epi16(sx[0], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), qx[0]))),
- _mm256_madd_epi16(m1, _mm256_maddubs_epi16(sx[1], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), qx[1])))
- );
- auto sumi2 = _mm256_add_epi32(
- _mm256_madd_epi16(m1, _mm256_maddubs_epi16(sx[2], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), qx[2]))),
- _mm256_madd_epi16(m1, _mm256_maddubs_epi16(sx[3], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), qx[3])))
- );
- return _mm256_add_epi32(sumi1, sumi2);
- };
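-    // Without dpbusd, _mm256_maddubs_epi16 needs an unsigned first operand: sx[] holds |x| and
-    // the sign of x is transferred onto the broadcast y values with _mm256_sign_epi8.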
- for (int ix = 0; ix < nrc_x; ix += 8) {
- const block_q8_0_r8 * iq8 = (const block_q8_0_r8 *)((const char *)vx + ix*bx);
- for (int ib4 = 0; ib4 < nb/4; ++ib4) {
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto scales = _mm_castsi128_ps(_mm_slli_epi32(_mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)q8.y[iy][ib4].d)), 16));
- _mm_storeu_ps(d8 + 4*iy, scales);
- }
- for (int k = 0; k < 4; ++k) {
- auto scales = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq8[4*ib4+k].d));
- for (int j = 0; j < 4; ++j) {
- qx[j] = _mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+j);
- sx[j] = _mm256_sign_epi8(qx[j], qx[j]);
- }
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto sumi = dot(q8.y[iy][ib4].qs+32*k);
- auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(d8[4*iy+k]));
- acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[iy]);
- }
- for (int j = 0; j < 4; ++j) {
- qx[j] = _mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+4+j);
- sx[j] = _mm256_sign_epi8(qx[j], qx[j]);
- }
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto sumi = dot(q8.y[iy][ib4].qs+32*k+16);
- auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(d8[4*iy+k]));
- acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[iy]);
- }
- }
- }
- for (int ib = 4*(nb/4); ib < nb; ++ib) {
- auto scales = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq8[ib].d));
- for (int j = 0; j < 4; ++j) {
- qx[j] = _mm256_loadu_si256((const __m256i *)iq8[ib].qs+j);
- sx[j] = _mm256_sign_epi8(qx[j], qx[j]);
- }
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto qy = (const block_q8_2 *)q8.y[iy];
- auto sumi = dot(qy[ib].qs);
- auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(GGML_BF16_TO_FP32(ggml_bf16_t{qy[ib].d})));
- acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[iy]);
- }
- for (int j = 0; j < 4; ++j) {
- qx[j] = _mm256_loadu_si256((const __m256i *)iq8[ib].qs+4+j);
- sx[j] = _mm256_sign_epi8(qx[j], qx[j]);
- }
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto qy = (const block_q8_2 *)q8.y[iy];
- auto sumi = dot(qy[ib].qs+16);
- auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(GGML_BF16_TO_FP32(ggml_bf16_t{qy[ib].d})));
- acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[iy]);
- }
- }
- for (int iy = 0; iy < nrc_y; ++iy) {
- info.store(ix, iy, acc[iy]);
- acc[iy] = _mm256_setzero_ps();
- }
- }
-}
-#endif
-
-template <int nrc_y>
-static void mul_mat_iq4_xs_r8_q8_k_avx2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
- GGML_ASSERT(nrc_x%8 == 0);
- Q8<nrc_y, block_q8_K> q8(info);
- auto m4 = _mm256_set1_epi8(0xf);
- auto m30 = _mm256_set1_epi8(0x30);
- auto m32 = _mm256_set1_epi8(32);
-#ifndef HAVE_FANCY_SIMD
- auto s_shuffle = _mm256_set_epi64x(0x0f0e0f0e0d0c0d0c, 0x0b0a0b0a09080908, 0x0706070605040504, 0x0302030201000100);
- auto values128 = _mm_loadu_si128((const __m128i *)iq4k_values);
- auto values = MM256_SET_M128I(values128, values128);
-#else
- auto values = load_iq4nl_values_256();
-#endif
- int nbl = n / QK_K;
- using helper_t = union { __m256i vec[2]; uint64_t val[8]; };
- helper_t h;
- __m256 acc[nrc_y] = {};
- __m256i qx[4];
- for (int ix = 0; ix < nrc_x; ix += 8) {
- const block_iq4_xs_r8 * iq4 = (const block_iq4_xs_r8 *)((const char *)vx + (ix+0)*bx);
- for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256
- auto d4 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq4[ibl].d));
- auto slbits = _mm256_loadu_si256((const __m256i *)iq4[ibl].scales_l);
- auto sl1 = _mm256_and_si256(slbits, m4);
- auto sl2 = _mm256_and_si256(_mm256_srli_epi16(slbits, 4), m4);
- auto shbits = _mm_loadu_si128((const __m128i*)iq4[ibl].scales_h);
- auto sh = MM256_SET_M128I(_mm_srli_epi16(shbits, 2), shbits);
- h.vec[0] = _mm256_sub_epi8(_mm256_or_si256(sl1, _mm256_and_si256(_mm256_slli_epi16(sh, 4), m30)), m32);
- h.vec[1] = _mm256_sub_epi8(_mm256_or_si256(sl2, _mm256_and_si256(sh, m30)), m32);
- __m256i isum[nrc_y] = {};
- for (int ib = 0; ib < QK_K/32; ++ib) {
-#ifdef HAVE_FANCY_SIMD
- auto iscales = _mm256_cvtepi8_epi32(_mm_set1_epi64x(h.val[ib]));
- auto scales = _mm256_mul_ps(d4, _mm256_cvtepi32_ps(iscales));
- auto scales_m = _mm256_mul_ps(scales, _mm256_set1_ps(-128.f));
- for (int iy = 0; iy < nrc_y; ++iy) {
- float m8 = ((const float *)q8.y[iy][ibl].bsums)[ib];
- acc[iy] = _mm256_fmadd_ps(scales_m, _mm256_set1_ps(m8), acc[iy]);
- }
-#else
- auto iscales = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm_set1_epi64x(h.val[ib])), s_shuffle);
-#endif
- auto bits1 = _mm256_loadu_si256((const __m256i *)iq4[ibl].qs+4*ib+0);
- auto bits2 = _mm256_loadu_si256((const __m256i *)iq4[ibl].qs+4*ib+1);
- qx[0] = _mm256_shuffle_epi8(values, _mm256_and_si256(m4, bits1));
- qx[1] = _mm256_shuffle_epi8(values, _mm256_and_si256(m4, _mm256_srli_epi16(bits1, 4)));
- qx[2] = _mm256_shuffle_epi8(values, _mm256_and_si256(m4, bits2));
- qx[3] = _mm256_shuffle_epi8(values, _mm256_and_si256(m4, _mm256_srli_epi16(bits2, 4)));
-#ifndef HAVE_FANCY_SIMD
- auto s1 = _mm256_sign_epi8(qx[0], qx[0]);
- auto s2 = _mm256_sign_epi8(qx[1], qx[1]);
- auto s3 = _mm256_sign_epi8(qx[2], qx[2]);
- auto s4 = _mm256_sign_epi8(qx[3], qx[3]);
-#endif
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto y128 = _mm_loadu_si128((const __m128i*)q8.y[iy][ibl].qs+2*ib+0);
- auto y = MM256_SET_M128I(y128, y128);
-#ifdef HAVE_FANCY_SIMD
- auto sumi = _mm256_setzero_si256();
- sumi = _mm256_dpbusd_epi32(sumi, qx[0], _mm256_shuffle_epi32(y, 0x00));
- sumi = _mm256_dpbusd_epi32(sumi, qx[1], _mm256_shuffle_epi32(y, 0x55));
- sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(y, 0xaa));
- sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(y, 0xff));
- isum[iy] = _mm256_add_epi32(isum[iy], _mm256_mullo_epi32(iscales, sumi));
-#else
- auto sumi1 = _mm256_maddubs_epi16(s1, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), qx[0]));
- auto sumi2 = _mm256_maddubs_epi16(s2, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), qx[1]));
- auto sumi3 = _mm256_maddubs_epi16(s3, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), qx[2]));
- auto sumi4 = _mm256_maddubs_epi16(s4, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), qx[3]));
- auto sumi = _mm256_add_epi32(_mm256_add_epi32(_mm256_madd_epi16(iscales, sumi1), _mm256_madd_epi16(iscales, sumi2)),
- _mm256_add_epi32(_mm256_madd_epi16(iscales, sumi3), _mm256_madd_epi16(iscales, sumi4)));
- isum[iy] = _mm256_add_epi32(isum[iy], sumi);
-#endif
- }
- bits1 = _mm256_loadu_si256((const __m256i *)iq4[ibl].qs+4*ib+2);
- bits2 = _mm256_loadu_si256((const __m256i *)iq4[ibl].qs+4*ib+3);
- qx[0] = _mm256_shuffle_epi8(values, _mm256_and_si256(m4, bits1));
- qx[1] = _mm256_shuffle_epi8(values, _mm256_and_si256(m4, _mm256_srli_epi16(bits1, 4)));
- qx[2] = _mm256_shuffle_epi8(values, _mm256_and_si256(m4, bits2));
- qx[3] = _mm256_shuffle_epi8(values, _mm256_and_si256(m4, _mm256_srli_epi16(bits2, 4)));
-#ifndef HAVE_FANCY_SIMD
- s1 = _mm256_sign_epi8(qx[0], qx[0]);
- s2 = _mm256_sign_epi8(qx[1], qx[1]);
- s3 = _mm256_sign_epi8(qx[2], qx[2]);
- s4 = _mm256_sign_epi8(qx[3], qx[3]);
-#endif
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto y128 = _mm_loadu_si128((const __m128i*)q8.y[iy][ibl].qs+2*ib+1);
- auto y = MM256_SET_M128I(y128, y128);
-#ifdef HAVE_FANCY_SIMD
- auto sumi = _mm256_setzero_si256();
- sumi = _mm256_dpbusd_epi32(sumi, qx[0], _mm256_shuffle_epi32(y, 0x00));
- sumi = _mm256_dpbusd_epi32(sumi, qx[1], _mm256_shuffle_epi32(y, 0x55));
- sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(y, 0xaa));
- sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(y, 0xff));
- isum[iy] = _mm256_add_epi32(isum[iy], _mm256_mullo_epi32(iscales, sumi));
-#else
- auto sumi1 = _mm256_maddubs_epi16(s1, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), qx[0]));
- auto sumi2 = _mm256_maddubs_epi16(s2, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), qx[1]));
- auto sumi3 = _mm256_maddubs_epi16(s3, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), qx[2]));
- auto sumi4 = _mm256_maddubs_epi16(s4, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), qx[3]));
- auto sumi = _mm256_add_epi32(_mm256_add_epi32(_mm256_madd_epi16(iscales, sumi1), _mm256_madd_epi16(iscales, sumi2)),
- _mm256_add_epi32(_mm256_madd_epi16(iscales, sumi3), _mm256_madd_epi16(iscales, sumi4)));
- isum[iy] = _mm256_add_epi32(isum[iy], sumi);
-#endif
- }
- }
- for (int iy = 0; iy < nrc_y; ++iy) {
- acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(isum[iy]), acc[iy]);
- }
- }
- for (int iy = 0; iy < nrc_y; ++iy) {
- info.store(ix, iy, acc[iy]);
- acc[iy] = _mm256_setzero_ps();
- }
- }
-}
-
-#ifdef HAVE_FANCY_SIMD
-template <int nrc_y>
-static void mul_mat_iq4_xs_r8_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
- mul_mat_iq4_xs_r8_q8_k_avx2<nrc_y>(n, vx, bx, info, nrc_x);
- return;
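-    // Early return: the AVX2 kernel above is used unconditionally; the AVX512 variant below is
-    // currently unreachable and appears to be kept only for reference.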
- if constexpr (nrc_y == 1){
- mul_mat_iq4_xs_r8_q8_k_avx2<1>(n, vx, bx, info, nrc_x);
- } else {
- GGML_ASSERT(nrc_x%8 == 0);
- Q8<nrc_y, block_q8_K> q8(info);
- auto m4 = _mm512_set1_epi8(0xf);
- auto values = load_iq4nl_values_512();
- int nbl = n / QK_K;
- using helper_t = union { __m512i vec; uint32_t val[16]; };
- helper_t h;
- __m512 acc[nrc_y] = {};
- __m512i isum[nrc_y] = {};
- __m512i qx[4];
- for (int ix = 0; ix < nrc_x; ix += 8) {
- const block_iq4_xs_r8 * iq4l = (const block_iq4_xs_r8 *)((const char *)vx + (ix+0)*bx);
- const block_iq4_xs_r8 * iq4h = (const block_iq4_xs_r8 *)((const char *)vx + (ix+4)*bx);
- for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256
- auto dl = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq4l[ibl].d));
- auto dh = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq4h[ibl].d));
- auto d4 = _mm512_insertf32x8(_mm512_castps256_ps512(_mm256_set_m128(dl, dl)), _mm256_set_m128(dh, dh), 1);
- auto d4x64 = _mm512_mul_ps(d4, _mm512_set1_ps(-64.f));
- auto slbits_l = _mm_loadu_si128((const __m128i *)iq4l[ibl].scales_l);
- auto shbits_l = _mm_loadu_si128((const __m128i *)iq4h[ibl].scales_l);
- auto sl_l = MM256_SET_M128I(_mm_srli_epi16(slbits_l, 4), slbits_l);
- auto sh_l = MM256_SET_M128I(_mm_srli_epi16(shbits_l, 4), shbits_l);
- auto slb = _mm512_and_si512(_mm512_inserti32x8(_mm512_castsi256_si512(sl_l), sh_l, 1), m4);
- auto aux64 = (const uint64_t *)iq4l[ibl].scales_h;
- auto slbits_h = _mm_set_epi64x(aux64[0] >> 2, aux64[0]);
- aux64 = (const uint64_t *)iq4h[ibl].scales_h;
- auto shbits_h = _mm_set_epi64x(aux64[0] >> 2, aux64[0]);
- auto sl_h = MM256_SET_M128I(slbits_h, _mm_slli_epi16(slbits_h, 4));
- auto sh_h = MM256_SET_M128I(shbits_h, _mm_slli_epi16(shbits_h, 4));
- auto shb = _mm512_and_si512(_mm512_inserti32x8(_mm512_castsi256_si512(sl_h), sh_h, 1), _mm512_set1_epi8(0x30));
- h.vec = _mm512_sub_epi8(_mm512_or_si512(slb, shb), _mm512_set1_epi8(32));
- for (int ib = 0; ib < QK_K/32; ++ib) {
- auto iscales = _mm512_cvtepi8_epi32(_mm_blend_epi32(_mm_set1_epi32(h.val[ib+0]), _mm_set1_epi32(h.val[ib+8]), 0x0c));
- auto scales = _mm512_cvtepi32_ps(iscales);
- auto scales_m = _mm512_mul_ps(scales, d4x64);
- auto bits1 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq4l[ibl].qs+2*ib+0)),
- _mm256_loadu_si256((const __m256i *)iq4h[ibl].qs+2*ib+0), 1);
- auto bits2 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq4l[ibl].qs+2*ib+1)),
- _mm256_loadu_si256((const __m256i *)iq4h[ibl].qs+2*ib+1), 1);
- qx[0] = _mm512_shuffle_epi8(values, _mm512_and_si512(bits1, m4));
- qx[1] = _mm512_shuffle_epi8(values, _mm512_and_si512(bits2, m4));
- qx[2] = _mm512_shuffle_epi8(values, _mm512_and_si512(_mm512_srli_epi16(bits1, 4), m4));
- qx[3] = _mm512_shuffle_epi8(values, _mm512_and_si512(_mm512_srli_epi16(bits2, 4), m4));
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto y8 = _mm256_loadu_si256((const __m256i*)q8.y[iy][ibl].qs+ib);
- auto y = _mm512_inserti32x8(_mm512_castsi256_si512(y8), y8, 1);
- auto sumi = _mm512_setzero_si512();
- sumi = _mm512_dpbusd_epi32(sumi, qx[0], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00)));
- sumi = _mm512_dpbusd_epi32(sumi, qx[1], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55)));
- sumi = _mm512_dpbusd_epi32(sumi, qx[2], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa)));
- sumi = _mm512_dpbusd_epi32(sumi, qx[3], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff)));
- isum[iy] = _mm512_add_epi32(isum[iy], _mm512_mullo_epi32(iscales, sumi));
- float m8 = ((const float *)q8.y[iy][ibl].bsums)[ib];
- acc[iy] = _mm512_fmadd_ps(scales_m, _mm512_set1_ps(m8), acc[iy]);
- }
- }
- for (int iy = 0; iy < nrc_y; ++iy) {
- acc[iy] = _mm512_fmadd_ps(_mm512_mul_ps(d4, _mm512_set1_ps(q8.scale(iy, ibl))), _mm512_cvtepi32_ps(isum[iy]), acc[iy]);
- isum[iy] = _mm512_setzero_si512();
- }
- }
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto sum1 = _mm_add_ps(_mm512_extractf32x4_ps(acc[iy], 0), _mm512_extractf32x4_ps(acc[iy], 1));
- auto sum2 = _mm_add_ps(_mm512_extractf32x4_ps(acc[iy], 2), _mm512_extractf32x4_ps(acc[iy], 3));
- info.store(ix+0, iy, sum1);
- info.store(ix+4, iy, sum2);
- acc[iy] = _mm512_setzero_ps();
- }
- }
- }
-}
-#else
-template <int nrc_y>
-static void mul_mat_iq4_xs_r8_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
- mul_mat_iq4_xs_r8_q8_k_avx2<nrc_y>(n, vx, bx, info, nrc_x);
-}
-#endif
-
-template <int nrc_y>
-static void mul_mat_iq4_ks_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
- GGML_ASSERT(nrc_x%4 == 0);
- Q8<nrc_y, block_q8_K> q8(info);
- auto m4 = _mm256_set1_epi8(0xf);
-#ifndef HAVE_FANCY_SIMD
- auto s_shuffle = _mm256_set_epi64x(0x0f0e0f0e0d0c0d0c, 0x0b0a0b0a09080908, 0x0706070605040504, 0x0302030201000100);
- auto values128 = _mm_loadu_si128((const __m128i *)iq4k_values);
- auto values = MM256_SET_M128I(values128, values128);
-#else
- auto values = load_iq4nl_values_256();
-#endif
- int nbl = n / QK_K;
- using helper_t = union { __m256i vec; uint32_t val[8]; };
-#ifndef HAVE_FANCY_SIMD
- helper_t h, h_shift;
-#else
- using helper512_t = union { __m512i vec; uint64_t val[8]; };
- helper_t h;
- helper512_t h_shift;
-#endif
- __m256 acc[nrc_y] = {};
- __m256i isum[nrc_y] = {};
- __m256i qx[4];
- for (int ix = 0; ix < nrc_x; ix += 4) {
- auto dptr = (const float *)((const char *)vx + (ix+0)*bx);
- const block_iq4_ks_r4 * iq4 = (const block_iq4_ks_r4 *)(dptr + 4);
- auto d4 = _mm_loadu_ps(dptr);
- for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256
- auto scales = _mm256_loadu_si256((const __m256i *)iq4[ibl].scales);
- h.vec = _mm256_sub_epi8(_mm256_and_si256(scales, _mm256_set1_epi8(-2)), _mm256_set1_epi8(127));
-#ifndef HAVE_FANCY_SIMD
- h_shift.vec = _mm256_slli_epi16(_mm256_and_si256(scales, _mm256_set1_epi8(1)), 2);
- {
- __m256 v1 = _mm256_mul_ps(_mm256_cvtepi32_ps(MM256_SET_M128I(_mm_cvtepi8_epi32(_mm_set1_epi32(h.val[4])), _mm_cvtepi8_epi32(_mm_set1_epi32(h.val[0])))),
- _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_cvtepi8_epi32(_mm_set1_epi32(h_shift.val[4])), _mm_cvtepi8_epi32(_mm_set1_epi32(h_shift.val[0])))));
- __m256 v2 = _mm256_mul_ps(_mm256_cvtepi32_ps(MM256_SET_M128I(_mm_cvtepi8_epi32(_mm_set1_epi32(h.val[5])), _mm_cvtepi8_epi32(_mm_set1_epi32(h.val[1])))),
- _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_cvtepi8_epi32(_mm_set1_epi32(h_shift.val[5])), _mm_cvtepi8_epi32(_mm_set1_epi32(h_shift.val[1])))));
- __m256 v3 = _mm256_mul_ps(_mm256_cvtepi32_ps(MM256_SET_M128I(_mm_cvtepi8_epi32(_mm_set1_epi32(h.val[6])), _mm_cvtepi8_epi32(_mm_set1_epi32(h.val[2])))),
- _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_cvtepi8_epi32(_mm_set1_epi32(h_shift.val[6])), _mm_cvtepi8_epi32(_mm_set1_epi32(h_shift.val[2])))));
- __m256 v4 = _mm256_mul_ps(_mm256_cvtepi32_ps(MM256_SET_M128I(_mm_cvtepi8_epi32(_mm_set1_epi32(h.val[7])), _mm_cvtepi8_epi32(_mm_set1_epi32(h.val[3])))),
- _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_cvtepi8_epi32(_mm_set1_epi32(h_shift.val[7])), _mm_cvtepi8_epi32(_mm_set1_epi32(h_shift.val[3])))));
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto m8 = _mm256_loadu_ps((const float *)q8.y[iy][ibl].bsums);
- acc[iy] = _mm256_fmadd_ps(v1, _mm256_shuffle_ps(m8, m8, 0x00), acc[iy]);
- acc[iy] = _mm256_fmadd_ps(v2, _mm256_shuffle_ps(m8, m8, 0x55), acc[iy]);
- acc[iy] = _mm256_fmadd_ps(v3, _mm256_shuffle_ps(m8, m8, 0xaa), acc[iy]);
- acc[iy] = _mm256_fmadd_ps(v4, _mm256_shuffle_ps(m8, m8, 0xff), acc[iy]);
- }
- }
-#else
- auto shift = _mm256_add_epi8(_mm256_set1_epi8(-64), _mm256_slli_epi16(_mm256_and_si256(scales, _mm256_set1_epi8(1)), 1));
- h_shift.vec = _mm512_mullo_epi16(_mm512_cvtepi8_epi16(shift), _mm512_cvtepi8_epi16(h.vec));
-#endif
- for (int ib = 0; ib < QK_K/32; ++ib) {
-#ifdef HAVE_FANCY_SIMD
- auto iscales = _mm256_cvtepi8_epi32(_mm_set1_epi32(h.val[ib]));
- auto ishifts = _mm256_cvtepi16_epi32(_mm_set1_epi64x(h_shift.val[ib]));
- auto scales_m = _mm256_cvtepi32_ps(ishifts);
- for (int iy = 0; iy < nrc_y; ++iy) {
- float m8 = ((const float *)q8.y[iy][ibl].bsums)[ib];
- acc[iy] = _mm256_fmadd_ps(scales_m, _mm256_set1_ps(m8), acc[iy]);
- }
-#endif
- auto bits1 = _mm256_loadu_si256((const __m256i *)iq4[ibl].qs+2*ib+0);
- auto bits2 = _mm256_loadu_si256((const __m256i *)iq4[ibl].qs+2*ib+1);
- qx[0] = _mm256_shuffle_epi8(values, _mm256_and_si256(bits1, m4));
- qx[1] = _mm256_shuffle_epi8(values, _mm256_and_si256(bits2, m4));
- qx[2] = _mm256_shuffle_epi8(values, _mm256_and_si256(_mm256_srli_epi16(bits1, 4), m4));
- qx[3] = _mm256_shuffle_epi8(values, _mm256_and_si256(_mm256_srli_epi16(bits2, 4), m4));
-#ifndef HAVE_FANCY_SIMD
- auto iscales = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm_set1_epi32(h.val[ib])), s_shuffle);
- auto s1 = _mm256_sign_epi8(qx[0], qx[0]);
- auto s2 = _mm256_sign_epi8(qx[1], qx[1]);
- auto s3 = _mm256_sign_epi8(qx[2], qx[2]);
- auto s4 = _mm256_sign_epi8(qx[3], qx[3]);
-#endif
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto y = _mm256_loadu_si256((const __m256i*)q8.y[iy][ibl].qs+ib);
-#ifdef HAVE_FANCY_SIMD
- auto sumi = _mm256_setzero_si256();
- sumi = _mm256_dpbusd_epi32(sumi, qx[0], _mm256_shuffle_epi32(y, 0x00));
- sumi = _mm256_dpbusd_epi32(sumi, qx[1], _mm256_shuffle_epi32(y, 0x55));
- sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(y, 0xaa));
- sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(y, 0xff));
- isum[iy] = _mm256_add_epi32(isum[iy], _mm256_mullo_epi32(iscales, sumi));
-#else
- auto sumi1 = _mm256_maddubs_epi16(s1, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), qx[0]));
- auto sumi2 = _mm256_maddubs_epi16(s2, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), qx[1]));
- auto sumi3 = _mm256_maddubs_epi16(s3, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), qx[2]));
- auto sumi4 = _mm256_maddubs_epi16(s4, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), qx[3]));
- isum[iy] = _mm256_add_epi32(isum[iy], _mm256_add_epi32(_mm256_madd_epi16(iscales, sumi1), _mm256_madd_epi16(iscales, sumi2)));
- isum[iy] = _mm256_add_epi32(isum[iy], _mm256_add_epi32(_mm256_madd_epi16(iscales, sumi3), _mm256_madd_epi16(iscales, sumi4)));
-#endif
- }
- }
- for (int iy = 0; iy < nrc_y; ++iy) {
- acc[iy] = _mm256_fmadd_ps(_mm256_set1_ps(q8.scale(iy, ibl)), _mm256_cvtepi32_ps(isum[iy]), acc[iy]);
- isum[iy] = _mm256_setzero_si256();
- }
- }
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1));
- acc[iy] = _mm256_setzero_ps();
- info.store(ix+0, iy, _mm_mul_ps(d4, sum));
- }
- }
-}
-
-template <int nrc_y>
-static void mul_mat_iq2_xxs_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
- GGML_ASSERT(nrc_x%4 == 0);
- Q8<nrc_y, block_q8_K> q8(info);
- int nbl = n / QK_K;
-#ifndef HAVE_FANCY_SIMD
- auto smask = _mm256_set1_epi64x(0x8040201008040201);
- auto sign_shuffle = _mm256_set_epi64x(0x0303030303030303, 0x0202020202020202, 0x0101010101010101, 0x0000000000000000);
- auto m4 = _mm256_set1_epi8(4);
- auto m1 = _mm256_set1_epi16(1);
-#endif
- __m256 acc[nrc_y] = {};
- __m256i isum[nrc_y] = {};
- __m256i qx[4];
- for (int ix = 0; ix < nrc_x; ix += 4) {
- auto iq2 = (const block_iq2_xxs_r4 *)((const char *)vx + (ix+0)*bx);
- for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256
- auto dl = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq2[ibl].d));
- auto d4 = _mm256_set_m128(dl, dl);
- auto qs = iq2[ibl].qs;
- for (int ib = 0; ib < QK_K/32; ++ib) {
- qx[0] = _mm256_set_epi64x(iq2xxs_grid[qs[ 3]], iq2xxs_grid[qs[ 2]], iq2xxs_grid[qs[ 1]], iq2xxs_grid[qs[ 0]]);
- qx[1] = _mm256_set_epi64x(iq2xxs_grid[qs[ 7]], iq2xxs_grid[qs[ 6]], iq2xxs_grid[qs[ 5]], iq2xxs_grid[qs[ 4]]);
- qx[2] = _mm256_set_epi64x(iq2xxs_grid[qs[11]], iq2xxs_grid[qs[10]], iq2xxs_grid[qs[ 9]], iq2xxs_grid[qs[ 8]]);
- qx[3] = _mm256_set_epi64x(iq2xxs_grid[qs[15]], iq2xxs_grid[qs[14]], iq2xxs_grid[qs[13]], iq2xxs_grid[qs[12]]);
- qs += 16;
- auto sas = _mm_loadu_si128((const __m128i *)iq2[ibl].sas + ib);
- auto scales = _mm_and_si128(sas, _mm_set1_epi8(1));
-#ifdef HAVE_FANCY_SIMD
- scales = _mm_dpbusd_epi32(_mm_set1_epi32(1), scales, _mm_set1_epi32(0x10080402));
-#else
- scales = _mm_maddubs_epi16(scales, _mm_set1_epi32(0x10080402));
- scales = _mm_add_epi32(_mm_madd_epi16(_mm_set1_epi16(1), scales), _mm_set1_epi32(1));
-#endif
- auto scales32 = MM256_SET_M128I(scales, scales);
-            auto signs128 = _mm_and_si128(sas, _mm_set1_epi8(-2)); // 0xfe = -2 as signed; needed to silence a compiler warning.
- signs128 = _mm_xor_si128(signs128, _mm_srli_epi16(signs128, 1));
-#ifdef HAVE_FANCY_SIMD
- auto mask = (const __mmask32 *)&signs128;
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto y = _mm256_loadu_si256((const __m256i *)q8.y[iy][ibl].qs + ib);
- auto sumi1 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), qx[0], _mm256_mask_sub_epi8(y, mask[0], _mm256_setzero_si256(), y));
- auto sumi2 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), qx[1], _mm256_mask_sub_epi8(y, mask[1], _mm256_setzero_si256(), y));
- auto sumi3 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), qx[2], _mm256_mask_sub_epi8(y, mask[2], _mm256_setzero_si256(), y));
- auto sumi4 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), qx[3], _mm256_mask_sub_epi8(y, mask[3], _mm256_setzero_si256(), y));
- auto s12 = _mm256_add_epi32(_mm256_unpacklo_epi32(sumi1, sumi2), _mm256_unpackhi_epi32(sumi1, sumi2)); // 0,1, 0,1, 0,1, 0,1
- auto s34 = _mm256_add_epi32(_mm256_unpacklo_epi32(sumi3, sumi4), _mm256_unpackhi_epi32(sumi3, sumi4)); // 2,3, 2,3, 2,3, 2,3
- auto sumi = _mm256_add_epi32(_mm256_unpacklo_epi64(s12, s34), _mm256_unpackhi_epi64(s12, s34)); // 0,1,2,3, 0,1,2,3
- isum[iy] = _mm256_add_epi32(isum[iy], _mm256_mullo_epi32(scales32, sumi));
- }
-#else
- auto signs = MM256_SET_M128I(signs128, signs128);
- auto shuffle = sign_shuffle;
- auto s1 = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(signs, shuffle), smask), smask), _mm256_set1_epi8(1));
- shuffle = _mm256_add_epi8(shuffle, m4);
- auto s2 = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(signs, shuffle), smask), smask), _mm256_set1_epi8(1));
- shuffle = _mm256_add_epi8(shuffle, m4);
- auto s3 = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(signs, shuffle), smask), smask), _mm256_set1_epi8(1));
- shuffle = _mm256_add_epi8(shuffle, m4);
- auto s4 = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(signs, shuffle), smask), smask), _mm256_set1_epi8(1));
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto y = _mm256_loadu_si256((const __m256i *)q8.y[iy][ibl].qs + ib);
- auto sumi1 = _mm256_madd_epi16(m1, _mm256_maddubs_epi16(qx[0], _mm256_sign_epi8(y, s1)));
- auto sumi2 = _mm256_madd_epi16(m1, _mm256_maddubs_epi16(qx[1], _mm256_sign_epi8(y, s2)));
- auto sumi3 = _mm256_madd_epi16(m1, _mm256_maddubs_epi16(qx[2], _mm256_sign_epi8(y, s3)));
- auto sumi4 = _mm256_madd_epi16(m1, _mm256_maddubs_epi16(qx[3], _mm256_sign_epi8(y, s4)));
- auto s12 = _mm256_add_epi32(_mm256_unpacklo_epi32(sumi1, sumi2), _mm256_unpackhi_epi32(sumi1, sumi2)); // 0,1, 0,1, 0,1, 0,1
- auto s34 = _mm256_add_epi32(_mm256_unpacklo_epi32(sumi3, sumi4), _mm256_unpackhi_epi32(sumi3, sumi4)); // 2,3, 2,3, 2,3, 2,3
- auto sumi = _mm256_add_epi32(_mm256_unpacklo_epi64(s12, s34), _mm256_unpackhi_epi64(s12, s34)); // 0,1,2,3, 0,1,2,3
- isum[iy] = _mm256_add_epi32(isum[iy], _mm256_mullo_epi32(scales32, sumi));
- }
-#endif
- }
- for (int iy = 0; iy < nrc_y; ++iy) {
- acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(isum[iy]), acc[iy]);
- isum[iy] = _mm256_setzero_si256();
- }
- }
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1));
- info.store(ix, iy, _mm_mul_ps(_mm_set1_ps(0.125f), sum));
- acc[iy] = _mm256_setzero_ps();
- }
- }
-}
-
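-// Same layout for iq2_xs_r4: each 16-bit value in qs packs a 9-bit index into iq2xs_grid in the
-// low bits and 7 sign bits in the high bits; the 4-bit block scales are again used as 2*s+1 and
-// the result is scaled by 0.125f at the end.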
-template <int nrc_y>
-static void mul_mat_iq2_xs_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
- GGML_ASSERT(nrc_x%4 == 0);
- Q8<nrc_y, block_q8_K> q8(info);
- int nbl = n / QK_K;
-#ifndef HAVE_FANCY_SIMD
- auto smask = _mm256_set1_epi64x(0x8040201008040201);
- auto sign_shuffle = _mm256_set_epi64x(0x0303030303030303, 0x0202020202020202, 0x0101010101010101, 0x0000000000000000);
- auto m4 = _mm256_set1_epi8(4);
-#endif
- __m256 acc[nrc_y] = {};
-#ifdef HAVE_FANCY_SIMD
- __m256i shuffles[2] = {
- _mm256_set_epi64x(0x0706070607060706, 0x0302030203020302, 0x0504050405040504, 0x0100010001000100),
- _mm256_set_epi64x(0x0f0e0f0e0f0e0f0e, 0x0b0a0b0a0b0a0b0a, 0x0d0c0d0c0d0c0d0c, 0x0908090809080908)
- };
- __m256i isum[2*nrc_y] = {};
-#else
- __m256i shuffles[4] = {
- MM256_SET_M128I(_mm_set1_epi16(0x0302), _mm_set1_epi16(0x0100)),
- MM256_SET_M128I(_mm_set1_epi16(0x0706), _mm_set1_epi16(0x0504)),
- MM256_SET_M128I(_mm_set1_epi16(0x0b0a), _mm_set1_epi16(0x0908)),
- MM256_SET_M128I(_mm_set1_epi16(0x0f0e), _mm_set1_epi16(0x0d0c)),
- };
- __m256i isum[nrc_y == 1 ? 4 : nrc_y] = {};
-#endif
- auto s_shuffle = _mm_set_epi64x(0x0f0d0b0907050301, 0x0e0c0a0806040200);
- __m256i qx[4];
- union { __m256i vec; uint16_t val[16]; } helper;
- for (int ix = 0; ix < nrc_x; ix += 4) {
- auto iq2 = (const block_iq2_xs_r4 *)((const char *)vx + (ix+0)*bx);
- for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256
- auto dl = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq2[ibl].d));
- auto d4 = _mm256_set_m128(dl, dl);
- auto s32 = (const uint32_t *)iq2[ibl].scales;
- for (int ib = 0; ib < QK_K/32; ++ib) {
- auto val = _mm256_loadu_si256((const __m256i *)iq2[ibl].qs + ib);
- helper.vec = _mm256_and_si256(val, _mm256_set1_epi16(511));
- qx[0] = _mm256_set_epi64x(iq2xs_grid[helper.val[ 3]], iq2xs_grid[helper.val[ 2]], iq2xs_grid[helper.val[ 1]], iq2xs_grid[helper.val[ 0]]);
- qx[1] = _mm256_set_epi64x(iq2xs_grid[helper.val[ 7]], iq2xs_grid[helper.val[ 6]], iq2xs_grid[helper.val[ 5]], iq2xs_grid[helper.val[ 4]]);
- qx[2] = _mm256_set_epi64x(iq2xs_grid[helper.val[11]], iq2xs_grid[helper.val[10]], iq2xs_grid[helper.val[ 9]], iq2xs_grid[helper.val[ 8]]);
- qx[3] = _mm256_set_epi64x(iq2xs_grid[helper.val[15]], iq2xs_grid[helper.val[14]], iq2xs_grid[helper.val[13]], iq2xs_grid[helper.val[12]]);
- auto signs16 = _mm256_srli_epi16(val, 9);
- signs16 = _mm256_xor_si256(signs16, _mm256_slli_epi16(signs16, 1));
- auto signs128 = _mm_or_si128(_mm256_castsi256_si128(signs16), _mm_slli_epi16(_mm256_extracti128_si256(signs16, 1), 8));
- signs128 = _mm_shuffle_epi8(signs128, s_shuffle);
- auto scales = _mm_set1_epi32(s32[ib]);
- scales = _mm_and_si128(_mm_unpacklo_epi8(scales, _mm_srli_epi16(scales, 4)), _mm_set1_epi8(0xf));
- scales = _mm_or_si128(_mm_slli_epi16(scales, 1), _mm_set1_epi8(1));
- auto scales16 = _mm256_cvtepi8_epi16(scales); // 0...7, 0...7
-#ifdef HAVE_FANCY_SIMD
- __m256i scs[2] = { _mm256_shuffle_epi8(scales16, shuffles[0]), _mm256_shuffle_epi8(scales16, shuffles[1]) };
- auto mask = (const __mmask32 *)&signs128;
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto y = _mm256_loadu_si256((const __m256i *)q8.y[iy][ibl].qs + ib);
- auto sumi1 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), qx[0], _mm256_mask_sub_epi8(y, mask[0], _mm256_setzero_si256(), y)); // blocks: 0,0,0,0, 1,1,1,1, row 0
- auto sumi2 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), qx[1], _mm256_mask_sub_epi8(y, mask[1], _mm256_setzero_si256(), y)); // blocks: 2,2,2,2, 3,3,3,3, row 1
- auto sumi3 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), qx[2], _mm256_mask_sub_epi8(y, mask[2], _mm256_setzero_si256(), y)); // blocks: 4,4,4,4, 5,5,5,5, row 2
- auto sumi4 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), qx[3], _mm256_mask_sub_epi8(y, mask[3], _mm256_setzero_si256(), y)); // blocks: 6,6,6,6, 7,7,7,7, row 3
- auto s12 = _mm256_packs_epi32(sumi1, sumi2); // 0,0,0,0, 2,2,2,2, 1,1,1,1, 3,3,3,3
- auto s34 = _mm256_packs_epi32(sumi3, sumi4); // 4,4,4,4, 6,6,6,6, 5,5,5,5, 7,7,7,7
- isum[2*iy+0] = _mm256_add_epi32(isum[2*iy+0], _mm256_madd_epi16(scs[0], s12));
- isum[2*iy+1] = _mm256_add_epi32(isum[2*iy+1], _mm256_madd_epi16(scs[1], s34));
- }
-#else
- auto signs = MM256_SET_M128I(signs128, signs128);
- auto shuffle = sign_shuffle;
- auto s1 = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(signs, shuffle), smask), smask), _mm256_set1_epi8(1));
- shuffle = _mm256_add_epi8(shuffle, m4);
- auto s2 = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(signs, shuffle), smask), smask), _mm256_set1_epi8(1));
- shuffle = _mm256_add_epi8(shuffle, m4);
- auto s3 = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(signs, shuffle), smask), smask), _mm256_set1_epi8(1));
- shuffle = _mm256_add_epi8(shuffle, m4);
- auto s4 = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(signs, shuffle), smask), smask), _mm256_set1_epi8(1));
- __m256i scs[4] = {
- _mm256_shuffle_epi8(scales16, shuffles[0]), _mm256_shuffle_epi8(scales16, shuffles[1]),
- _mm256_shuffle_epi8(scales16, shuffles[2]), _mm256_shuffle_epi8(scales16, shuffles[3]),
- };
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto y = _mm256_loadu_si256((const __m256i *)q8.y[iy][ibl].qs + ib);
- if constexpr (nrc_y == 1) {
- isum[0] = _mm256_add_epi32(isum[0], _mm256_madd_epi16(scs[0], _mm256_maddubs_epi16(qx[0], _mm256_sign_epi8(y, s1))));
- isum[1] = _mm256_add_epi32(isum[1], _mm256_madd_epi16(scs[1], _mm256_maddubs_epi16(qx[1], _mm256_sign_epi8(y, s2))));
- isum[2] = _mm256_add_epi32(isum[2], _mm256_madd_epi16(scs[2], _mm256_maddubs_epi16(qx[2], _mm256_sign_epi8(y, s3))));
- isum[3] = _mm256_add_epi32(isum[3], _mm256_madd_epi16(scs[3], _mm256_maddubs_epi16(qx[3], _mm256_sign_epi8(y, s4))));
- } else {
- auto sumi1 = _mm256_madd_epi16(scs[0], _mm256_maddubs_epi16(qx[0], _mm256_sign_epi8(y, s1))); // blocks 4x0, 4x1, row 0
- auto sumi2 = _mm256_madd_epi16(scs[1], _mm256_maddubs_epi16(qx[1], _mm256_sign_epi8(y, s2))); // blocks 4x2, 4x3, row 1
- auto sumi3 = _mm256_madd_epi16(scs[2], _mm256_maddubs_epi16(qx[2], _mm256_sign_epi8(y, s3))); // blocks 4x4, 4x5, row 2
- auto sumi4 = _mm256_madd_epi16(scs[3], _mm256_maddubs_epi16(qx[3], _mm256_sign_epi8(y, s4))); // blocks 4x6, 4x7, row 3
- auto s12 = _mm256_add_epi32(_mm256_unpacklo_epi32(sumi1, sumi2), _mm256_unpackhi_epi32(sumi1, sumi2)); // 0,1, 0,1, 0,1, 0,1
- auto s34 = _mm256_add_epi32(_mm256_unpacklo_epi32(sumi3, sumi4), _mm256_unpackhi_epi32(sumi3, sumi4)); // 2,3, 2,3, 2,3, 2,3
- auto sumi = _mm256_add_epi32(_mm256_unpacklo_epi64(s12, s34), _mm256_unpackhi_epi64(s12, s34)); // 0,1,2,3, 0,1,2,3
- isum[iy] = _mm256_add_epi32(isum[iy], sumi);
- }
- }
-#endif
- }
- for (int iy = 0; iy < nrc_y; ++iy) {
-#ifdef HAVE_FANCY_SIMD
- auto sumi = _mm256_hadd_epi32(isum[2*iy+0], isum[2*iy+1]);
- acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(sumi), acc[iy]);
- isum[2*iy+0] = isum[2*iy+1] = _mm256_setzero_si256();
-#else
- if constexpr (nrc_y == 1) {
- auto s12 = _mm256_add_epi32(_mm256_unpacklo_epi32(isum[0], isum[1]), _mm256_unpackhi_epi32(isum[0], isum[1]));
- auto s34 = _mm256_add_epi32(_mm256_unpacklo_epi32(isum[2], isum[3]), _mm256_unpackhi_epi32(isum[2], isum[3]));
- auto sumi = _mm256_add_epi32(_mm256_unpacklo_epi64(s12, s34), _mm256_unpackhi_epi64(s12, s34));
- acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(sumi), acc[iy]);
- isum[0] = isum[1] = isum[2] = isum[3] = _mm256_setzero_si256();
- } else {
- acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(isum[iy]), acc[iy]);
- isum[iy] = _mm256_setzero_si256();
- }
-#endif
- }
- }
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1));
- info.store(ix, iy, _mm_mul_ps(_mm_set1_ps(0.125f), sum));
- acc[iy] = _mm256_setzero_ps();
- }
- }
-}
-
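-// Specialization of mul_mat_iq2_xs_r4_q8_k for nrc_y = 16: the signs are applied once to the
-// quants (instead of to every q8 column) and the now signed grid values are biased by +64 so
-// that the unsigned*signed dot products still work; the bias is removed up front through the
-// q8 block sums (the -64.f*q8.scale(iy, ibl) term).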
-static void mul_mat_iq2_xs_r4_q8_k_16(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
- GGML_ASSERT(nrc_x%4 == 0);
- constexpr int nrc_y = 16;
- Q8<nrc_y, block_q8_K> q8(info);
- int nbl = n / QK_K;
-#ifndef HAVE_FANCY_SIMD
- auto smask = _mm256_set1_epi64x(0x8040201008040201);
- auto sign_shuffle = _mm256_set_epi64x(0x0303030303030303, 0x0202020202020202, 0x0101010101010101, 0x0000000000000000);
- auto m4 = _mm256_set1_epi8(4);
-#endif
- __m256 acc[nrc_y] = {};
-#ifdef HAVE_FANCY_SIMD
- __m256i shuffles[2] = {
- _mm256_set_epi64x(0x0706070607060706, 0x0302030203020302, 0x0504050405040504, 0x0100010001000100),
- _mm256_set_epi64x(0x0f0e0f0e0f0e0f0e, 0x0b0a0b0a0b0a0b0a, 0x0d0c0d0c0d0c0d0c, 0x0908090809080908)
- };
- __m256i isum[2*nrc_y] = {};
-#else
- __m256i shuffles[4] = {
- MM256_SET_M128I(_mm_set1_epi16(0x0302), _mm_set1_epi16(0x0100)),
- MM256_SET_M128I(_mm_set1_epi16(0x0706), _mm_set1_epi16(0x0504)),
- MM256_SET_M128I(_mm_set1_epi16(0x0b0a), _mm_set1_epi16(0x0908)),
- MM256_SET_M128I(_mm_set1_epi16(0x0f0e), _mm_set1_epi16(0x0d0c)),
- };
- __m256i isum[nrc_y == 1 ? 4 : nrc_y] = {};
-#endif
- auto s_shuffle = _mm_set_epi64x(0x0f0d0b0907050301, 0x0e0c0a0806040200);
- __m256i qx[4];
- union { __m256i vec; uint16_t val[16]; } helper;
- for (int ix = 0; ix < nrc_x; ix += 4) {
- auto iq2 = (const block_iq2_xs_r4 *)((const char *)vx + (ix+0)*bx);
- for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256
- auto dl = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq2[ibl].d));
- auto d4 = _mm256_set_m128(dl, dl);
- auto s32 = (const uint32_t *)iq2[ibl].scales;
- {
- auto scale_bits = _mm256_loadu_si256((const __m256i *)iq2[ibl].scales);
- auto scales1 = _mm256_and_si256(scale_bits, _mm256_set1_epi8(0xf));
- auto scales2 = _mm256_and_si256(_mm256_srli_epi16(scale_bits, 4), _mm256_set1_epi8(0xf));
- scales1 = _mm256_or_si256(_mm256_slli_epi16(scales1, 1), _mm256_set1_epi8(1));
- scales2 = _mm256_or_si256(_mm256_slli_epi16(scales2, 1), _mm256_set1_epi8(1));
- auto s1_8 = _mm256_unpacklo_epi8(scales1, scales2); // blocks 0...15, 32...47 (0...3, 8...11 from each row)
- auto s2_8 = _mm256_unpackhi_epi8(scales1, scales2); // blocks 16...31, 48...63 (4...7, 12...15 from each row)
- auto s1_16 = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(s1_8)); // 0...15 (0...3 from each row)
- auto s2_16 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(s1_8, 1)); // 32...47 (8...11 from each row)
- auto s3_16 = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(s2_8)); // 16...31 (4...7 from each row)
- auto s4_16 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(s2_8, 1)); // 48...63 (12...15 from each row)
- auto t1 = MM256_SET_M128I(_mm256_castsi256_si128(s2_16), _mm256_castsi256_si128(s1_16)); // 0,1 and 8,9 from each row
- auto t2 = MM256_SET_M128I(_mm256_extracti128_si256(s2_16, 1), _mm256_extracti128_si256(s1_16, 1)); // 2,3 and 10,11 from each row
- auto t3 = MM256_SET_M128I(_mm256_castsi256_si128(s4_16), _mm256_castsi256_si128(s3_16)); // 4,5 and 12,13 from each row
- auto t4 = MM256_SET_M128I(_mm256_extracti128_si256(s4_16, 1), _mm256_extracti128_si256(s3_16, 1)); // 6,7 and 14,15 from each row
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto bsums = q8.load_bsums(iy, ibl);
- auto sumi = _mm256_setzero_si256();
-#ifdef HAVE_FANCY_SIMD
- sumi = _mm256_dpwssd_epi32(sumi, t1, _mm256_shuffle_epi32(bsums, 0x00));
- sumi = _mm256_dpwssd_epi32(sumi, t2, _mm256_shuffle_epi32(bsums, 0x55));
- sumi = _mm256_dpwssd_epi32(sumi, t3, _mm256_shuffle_epi32(bsums, 0xaa));
- sumi = _mm256_dpwssd_epi32(sumi, t4, _mm256_shuffle_epi32(bsums, 0xff));
-#else
- sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(t1, _mm256_shuffle_epi32(bsums, 0x00)));
- sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(t2, _mm256_shuffle_epi32(bsums, 0x55)));
- sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(t3, _mm256_shuffle_epi32(bsums, 0xaa)));
- sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(t4, _mm256_shuffle_epi32(bsums, 0xff)));
-#endif
- acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(d4, _mm256_set1_ps(-64.f*q8.scale(iy, ibl))), _mm256_cvtepi32_ps(sumi), acc[iy]);
- }
- }
- for (int ib = 0; ib < QK_K/32; ++ib) {
- auto val = _mm256_loadu_si256((const __m256i *)iq2[ibl].qs + ib);
- helper.vec = _mm256_and_si256(val, _mm256_set1_epi16(511));
- qx[0] = _mm256_set_epi64x(iq2xs_grid[helper.val[ 3]], iq2xs_grid[helper.val[ 2]], iq2xs_grid[helper.val[ 1]], iq2xs_grid[helper.val[ 0]]);
- qx[1] = _mm256_set_epi64x(iq2xs_grid[helper.val[ 7]], iq2xs_grid[helper.val[ 6]], iq2xs_grid[helper.val[ 5]], iq2xs_grid[helper.val[ 4]]);
- qx[2] = _mm256_set_epi64x(iq2xs_grid[helper.val[11]], iq2xs_grid[helper.val[10]], iq2xs_grid[helper.val[ 9]], iq2xs_grid[helper.val[ 8]]);
- qx[3] = _mm256_set_epi64x(iq2xs_grid[helper.val[15]], iq2xs_grid[helper.val[14]], iq2xs_grid[helper.val[13]], iq2xs_grid[helper.val[12]]);
- auto signs16 = _mm256_srli_epi16(val, 9);
- signs16 = _mm256_xor_si256(signs16, _mm256_slli_epi16(signs16, 1));
- auto signs128 = _mm_or_si128(_mm256_castsi256_si128(signs16), _mm_slli_epi16(_mm256_extracti128_si256(signs16, 1), 8));
- signs128 = _mm_shuffle_epi8(signs128, s_shuffle);
- auto scales = _mm_set1_epi32(s32[ib]);
- scales = _mm_and_si128(_mm_unpacklo_epi8(scales, _mm_srli_epi16(scales, 4)), _mm_set1_epi8(0xf));
- scales = _mm_or_si128(_mm_slli_epi16(scales, 1), _mm_set1_epi8(1));
- auto scales16 = _mm256_cvtepi8_epi16(scales); // 0...7, 0...7
-#ifdef HAVE_FANCY_SIMD
- __m256i scs[2] = { _mm256_shuffle_epi8(scales16, shuffles[0]), _mm256_shuffle_epi8(scales16, shuffles[1]) };
- auto mask = (const __mmask32 *)&signs128;
- qx[0] = _mm256_add_epi8(_mm256_set1_epi8(64), _mm256_mask_sub_epi8(qx[0], mask[0], _mm256_setzero_si256(), qx[0]));
- qx[1] = _mm256_add_epi8(_mm256_set1_epi8(64), _mm256_mask_sub_epi8(qx[1], mask[1], _mm256_setzero_si256(), qx[1]));
- qx[2] = _mm256_add_epi8(_mm256_set1_epi8(64), _mm256_mask_sub_epi8(qx[2], mask[2], _mm256_setzero_si256(), qx[2]));
- qx[3] = _mm256_add_epi8(_mm256_set1_epi8(64), _mm256_mask_sub_epi8(qx[3], mask[3], _mm256_setzero_si256(), qx[3]));
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto y = _mm256_loadu_si256((const __m256i *)q8.y[iy][ibl].qs + ib);
- auto sumi1 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), qx[0], y); // blocks: 0,0,0,0, 1,1,1,1, row 0
- auto sumi2 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), qx[1], y); // blocks: 2,2,2,2, 3,3,3,3, row 1
- auto sumi3 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), qx[2], y); // blocks: 4,4,4,4, 5,5,5,5, row 2
- auto sumi4 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), qx[3], y); // blocks: 6,6,6,6, 7,7,7,7, row 3
- auto s12 = _mm256_packs_epi32(sumi1, sumi2); // 0,0,0,0, 2,2,2,2, 1,1,1,1, 3,3,3,3
- auto s34 = _mm256_packs_epi32(sumi3, sumi4); // 4,4,4,4, 6,6,6,6, 5,5,5,5, 7,7,7,7
- isum[2*iy+0] = _mm256_add_epi32(isum[2*iy+0], _mm256_madd_epi16(scs[0], s12));
- isum[2*iy+1] = _mm256_add_epi32(isum[2*iy+1], _mm256_madd_epi16(scs[1], s34));
- }
-#else
- auto signs = MM256_SET_M128I(signs128, signs128);
- auto shuffle = sign_shuffle;
- auto s = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(signs, shuffle), smask), smask), _mm256_set1_epi8(1));
- shuffle = _mm256_add_epi8(shuffle, m4);
- qx[0] = _mm256_add_epi8(_mm256_set1_epi8(64), _mm256_sign_epi8(qx[0], s));
- s = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(signs, shuffle), smask), smask), _mm256_set1_epi8(1));
- shuffle = _mm256_add_epi8(shuffle, m4);
- qx[1] = _mm256_add_epi8(_mm256_set1_epi8(64), _mm256_sign_epi8(qx[1], s));
- s = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(signs, shuffle), smask), smask), _mm256_set1_epi8(1));
- shuffle = _mm256_add_epi8(shuffle, m4);
- qx[2] = _mm256_add_epi8(_mm256_set1_epi8(64), _mm256_sign_epi8(qx[2], s));
- s = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(signs, shuffle), smask), smask), _mm256_set1_epi8(1));
- qx[3] = _mm256_add_epi8(_mm256_set1_epi8(64), _mm256_sign_epi8(qx[3], s));
- __m256i scs[4] = {
- _mm256_shuffle_epi8(scales16, shuffles[0]), _mm256_shuffle_epi8(scales16, shuffles[1]),
- _mm256_shuffle_epi8(scales16, shuffles[2]), _mm256_shuffle_epi8(scales16, shuffles[3]),
- };
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto y = _mm256_loadu_si256((const __m256i *)q8.y[iy][ibl].qs + ib);
- auto sumi1 = _mm256_madd_epi16(scs[0], _mm256_maddubs_epi16(qx[0], y)); // blocks 4x0, 4x1, row 0
- auto sumi2 = _mm256_madd_epi16(scs[1], _mm256_maddubs_epi16(qx[1], y)); // blocks 4x2, 4x3, row 1
- auto sumi3 = _mm256_madd_epi16(scs[2], _mm256_maddubs_epi16(qx[2], y)); // blocks 4x4, 4x5, row 2
- auto sumi4 = _mm256_madd_epi16(scs[3], _mm256_maddubs_epi16(qx[3], y)); // blocks 4x6, 4x7, row 3
- auto s12 = _mm256_add_epi32(_mm256_unpacklo_epi32(sumi1, sumi2), _mm256_unpackhi_epi32(sumi1, sumi2)); // 0,1, 0,1, 0,1, 0,1
- auto s34 = _mm256_add_epi32(_mm256_unpacklo_epi32(sumi3, sumi4), _mm256_unpackhi_epi32(sumi3, sumi4)); // 2,3, 2,3, 2,3, 2,3
- auto sumi = _mm256_add_epi32(_mm256_unpacklo_epi64(s12, s34), _mm256_unpackhi_epi64(s12, s34)); // 0,1,2,3, 0,1,2,3
- isum[iy] = _mm256_add_epi32(isum[iy], sumi);
- }
-#endif
- }
- for (int iy = 0; iy < nrc_y; ++iy) {
-#ifdef HAVE_FANCY_SIMD
- auto sumi = _mm256_hadd_epi32(isum[2*iy+0], isum[2*iy+1]);
- acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(sumi), acc[iy]);
- isum[2*iy+0] = isum[2*iy+1] = _mm256_setzero_si256();
-#else
- if constexpr (nrc_y == 1) {
- auto s12 = _mm256_add_epi32(_mm256_unpacklo_epi32(isum[0], isum[1]), _mm256_unpackhi_epi32(isum[0], isum[1]));
- auto s34 = _mm256_add_epi32(_mm256_unpacklo_epi32(isum[2], isum[3]), _mm256_unpackhi_epi32(isum[2], isum[3]));
- auto sumi = _mm256_add_epi32(_mm256_unpacklo_epi64(s12, s34), _mm256_unpackhi_epi64(s12, s34));
- acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(sumi), acc[iy]);
- isum[0] = isum[1] = isum[2] = isum[3] = _mm256_setzero_si256();
- } else {
- acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(isum[iy]), acc[iy]);
- isum[iy] = _mm256_setzero_si256();
- }
-#endif
- }
- }
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1));
- info.store(ix, iy, _mm_mul_ps(_mm_set1_ps(0.125f), sum));
- acc[iy] = _mm256_setzero_ps();
- }
- }
-}
-
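-// iq2_s_r4 kernel: grid indices are built from 8 low bits in qs plus 2 high bits taken from qh,
-// the signs are stored explicitly in signs[], and the 4-bit block scales are used as 2*s+1 with
-// a final 0.125f factor.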
-template <int nrc_y>
-static void mul_mat_iq2_s_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
- GGML_ASSERT(nrc_x%4 == 0);
- Q8<nrc_y, block_q8_K> q8(info);
- int nbl = n / QK_K;
-#ifndef HAVE_FANCY_SIMD
- auto smask = _mm256_set1_epi64x(0x8040201008040201);
- auto sign_shuffle = _mm256_set_epi64x(0x0303030303030303, 0x0202020202020202, 0x0101010101010101, 0x0000000000000000);
- auto m4 = _mm256_set1_epi8(4);
-#endif
- __m256 acc[nrc_y] = {};
-#ifdef HAVE_FANCY_SIMD
- __m256i shuffles[2] = {
- _mm256_set_epi64x(0x0706070607060706, 0x0302030203020302, 0x0504050405040504, 0x0100010001000100),
- _mm256_set_epi64x(0x0f0e0f0e0f0e0f0e, 0x0b0a0b0a0b0a0b0a, 0x0d0c0d0c0d0c0d0c, 0x0908090809080908)
- };
- __m256i isum[2*nrc_y] = {};
-#else
- __m256i shuffles[4] = {
- MM256_SET_M128I(_mm_set1_epi16(0x0302), _mm_set1_epi16(0x0100)),
- MM256_SET_M128I(_mm_set1_epi16(0x0706), _mm_set1_epi16(0x0504)),
- MM256_SET_M128I(_mm_set1_epi16(0x0b0a), _mm_set1_epi16(0x0908)),
- MM256_SET_M128I(_mm_set1_epi16(0x0f0e), _mm_set1_epi16(0x0d0c)),
- };
- __m256i isum[nrc_y == 1 ? 4 : nrc_y] = {};
-#endif
- __m256i qx[4];
- auto grid = iq2s_grid;
- for (int ix = 0; ix < nrc_x; ix += 4) {
- auto iq2 = (const block_iq2_s_r4 *)((const char *)vx + (ix+0)*bx);
- for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256
- auto dl = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq2[ibl].d));
- auto d4 = _mm256_set_m128(dl, dl);
- auto s32 = (const uint32_t *)iq2[ibl].scales;
- auto ql = iq2[ibl].qs;
- auto qh = iq2[ibl].qh;
- for (int ib = 0; ib < QK_K/32; ++ib) {
- qx[0] = _mm256_set_epi64x(grid[ql[ 3] | ((qh[0] << 2) & 0x300)], grid[ql[ 2] | ((qh[0] << 4) & 0x300)], grid[ql[ 1] | ((qh[0] << 6) & 0x300)], grid[ql[ 0] | ((qh[0] << 8) & 0x300)]);
- qx[1] = _mm256_set_epi64x(grid[ql[ 7] | ((qh[1] << 2) & 0x300)], grid[ql[ 6] | ((qh[1] << 4) & 0x300)], grid[ql[ 5] | ((qh[1] << 6) & 0x300)], grid[ql[ 4] | ((qh[1] << 8) & 0x300)]);
- qx[2] = _mm256_set_epi64x(grid[ql[11] | ((qh[2] << 2) & 0x300)], grid[ql[10] | ((qh[2] << 4) & 0x300)], grid[ql[ 9] | ((qh[2] << 6) & 0x300)], grid[ql[ 8] | ((qh[2] << 8) & 0x300)]);
- qx[3] = _mm256_set_epi64x(grid[ql[15] | ((qh[3] << 2) & 0x300)], grid[ql[14] | ((qh[3] << 4) & 0x300)], grid[ql[13] | ((qh[3] << 6) & 0x300)], grid[ql[12] | ((qh[3] << 8) & 0x300)]);
- ql += 16; qh += 4;
- auto signs128 = _mm_loadu_si128((const __m128i*)iq2[ibl].signs + ib);
- auto scales = _mm_set1_epi32(s32[ib]);
- scales = _mm_and_si128(_mm_unpacklo_epi8(scales, _mm_srli_epi16(scales, 4)), _mm_set1_epi8(0xf));
- scales = _mm_or_si128(_mm_slli_epi16(scales, 1), _mm_set1_epi8(1));
- auto scales16 = _mm256_cvtepi8_epi16(scales); // 0...7, 0...7
-#ifdef HAVE_FANCY_SIMD
- __m256i scs[2] = { _mm256_shuffle_epi8(scales16, shuffles[0]), _mm256_shuffle_epi8(scales16, shuffles[1]) };
- auto mask = (const __mmask32 *)&signs128;
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto y = _mm256_loadu_si256((const __m256i *)q8.y[iy][ibl].qs + ib);
- auto sumi1 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), qx[0], _mm256_mask_sub_epi8(y, mask[0], _mm256_setzero_si256(), y)); // blocks: 0,0,0,0, 1,1,1,1, row 0
- auto sumi2 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), qx[1], _mm256_mask_sub_epi8(y, mask[1], _mm256_setzero_si256(), y)); // blocks: 2,2,2,2, 3,3,3,3, row 1
- auto sumi3 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), qx[2], _mm256_mask_sub_epi8(y, mask[2], _mm256_setzero_si256(), y)); // blocks: 4,4,4,4, 5,5,5,5, row 2
- auto sumi4 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), qx[3], _mm256_mask_sub_epi8(y, mask[3], _mm256_setzero_si256(), y)); // blocks: 6,6,6,6, 7,7,7,7, row 3
- auto s12 = _mm256_packs_epi32(sumi1, sumi2); // 0,0,0,0, 2,2,2,2, 1,1,1,1, 3,3,3,3
- auto s34 = _mm256_packs_epi32(sumi3, sumi4); // 4,4,4,4, 6,6,6,6, 5,5,5,5, 7,7,7,7
- isum[2*iy+0] = _mm256_add_epi32(isum[2*iy+0], _mm256_madd_epi16(scs[0], s12));
- isum[2*iy+1] = _mm256_add_epi32(isum[2*iy+1], _mm256_madd_epi16(scs[1], s34));
- }
-#else
- auto signs = MM256_SET_M128I(signs128, signs128);
- auto shuffle = sign_shuffle;
- auto s1 = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(signs, shuffle), smask), smask), _mm256_set1_epi8(1));
- shuffle = _mm256_add_epi8(shuffle, m4);
- auto s2 = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(signs, shuffle), smask), smask), _mm256_set1_epi8(1));
- shuffle = _mm256_add_epi8(shuffle, m4);
- auto s3 = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(signs, shuffle), smask), smask), _mm256_set1_epi8(1));
- shuffle = _mm256_add_epi8(shuffle, m4);
- auto s4 = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(signs, shuffle), smask), smask), _mm256_set1_epi8(1));
- __m256i scs[4] = {
- _mm256_shuffle_epi8(scales16, shuffles[0]), _mm256_shuffle_epi8(scales16, shuffles[1]),
- _mm256_shuffle_epi8(scales16, shuffles[2]), _mm256_shuffle_epi8(scales16, shuffles[3]),
- };
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto y = _mm256_loadu_si256((const __m256i *)q8.y[iy][ibl].qs + ib);
- if constexpr (nrc_y == 1) {
- isum[0] = _mm256_add_epi32(isum[0], _mm256_madd_epi16(scs[0], _mm256_maddubs_epi16(qx[0], _mm256_sign_epi8(y, s1))));
- isum[1] = _mm256_add_epi32(isum[1], _mm256_madd_epi16(scs[1], _mm256_maddubs_epi16(qx[1], _mm256_sign_epi8(y, s2))));
- isum[2] = _mm256_add_epi32(isum[2], _mm256_madd_epi16(scs[2], _mm256_maddubs_epi16(qx[2], _mm256_sign_epi8(y, s3))));
- isum[3] = _mm256_add_epi32(isum[3], _mm256_madd_epi16(scs[3], _mm256_maddubs_epi16(qx[3], _mm256_sign_epi8(y, s4))));
- } else {
- auto sumi1 = _mm256_madd_epi16(scs[0], _mm256_maddubs_epi16(qx[0], _mm256_sign_epi8(y, s1))); // blocks 4x0, 4x1, row 0
- auto sumi2 = _mm256_madd_epi16(scs[1], _mm256_maddubs_epi16(qx[1], _mm256_sign_epi8(y, s2))); // blocks 4x2, 4x3, row 1
- auto sumi3 = _mm256_madd_epi16(scs[2], _mm256_maddubs_epi16(qx[2], _mm256_sign_epi8(y, s3))); // blocks 4x4, 4x5, row 2
- auto sumi4 = _mm256_madd_epi16(scs[3], _mm256_maddubs_epi16(qx[3], _mm256_sign_epi8(y, s4))); // blocks 4x6, 4x7, row 3
- auto s12 = _mm256_add_epi32(_mm256_unpacklo_epi32(sumi1, sumi2), _mm256_unpackhi_epi32(sumi1, sumi2)); // 0,1, 0,1, 0,1, 0,1
- auto s34 = _mm256_add_epi32(_mm256_unpacklo_epi32(sumi3, sumi4), _mm256_unpackhi_epi32(sumi3, sumi4)); // 2,3, 2,3, 2,3, 2,3
- auto sumi = _mm256_add_epi32(_mm256_unpacklo_epi64(s12, s34), _mm256_unpackhi_epi64(s12, s34)); // 0,1,2,3, 0,1,2,3
- isum[iy] = _mm256_add_epi32(isum[iy], sumi);
- }
- }
-#endif
- }
- for (int iy = 0; iy < nrc_y; ++iy) {
-#ifdef HAVE_FANCY_SIMD
- auto sumi = _mm256_hadd_epi32(isum[2*iy+0], isum[2*iy+1]);
- acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(sumi), acc[iy]);
- isum[2*iy+0] = isum[2*iy+1] = _mm256_setzero_si256();
-#else
- if constexpr (nrc_y == 1) {
- auto s12 = _mm256_add_epi32(_mm256_unpacklo_epi32(isum[0], isum[1]), _mm256_unpackhi_epi32(isum[0], isum[1]));
- auto s34 = _mm256_add_epi32(_mm256_unpacklo_epi32(isum[2], isum[3]), _mm256_unpackhi_epi32(isum[2], isum[3]));
- auto sumi = _mm256_add_epi32(_mm256_unpacklo_epi64(s12, s34), _mm256_unpackhi_epi64(s12, s34));
- acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(sumi), acc[iy]);
- isum[0] = isum[1] = isum[2] = isum[3] = _mm256_setzero_si256();
- } else {
- acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(isum[iy]), acc[iy]);
- isum[iy] = _mm256_setzero_si256();
- }
-#endif
- }
- }
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1));
- info.store(ix, iy, _mm_mul_ps(_mm_set1_ps(0.125f), sum));
- acc[iy] = _mm256_setzero_ps();
- }
- }
-}
-
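-// nrc_y = 16 specialization of the iq2_s_r4 kernel, using the same +64 bias and block-sum
-// correction as mul_mat_iq2_xs_r4_q8_k_16 above.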
-static void mul_mat_iq2_s_r4_q8_k_16(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
- GGML_ASSERT(nrc_x%4 == 0);
- constexpr int nrc_y = 16;
- Q8<nrc_y, block_q8_K> q8(info);
- int nbl = n / QK_K;
-#ifndef HAVE_FANCY_SIMD
- auto smask = _mm256_set1_epi64x(0x8040201008040201);
- auto sign_shuffle = _mm256_set_epi64x(0x0303030303030303, 0x0202020202020202, 0x0101010101010101, 0x0000000000000000);
- auto m4 = _mm256_set1_epi8(4);
-#endif
- __m256 acc[nrc_y] = {};
-#ifdef HAVE_FANCY_SIMD
- __m256i shuffles[2] = {
- _mm256_set_epi64x(0x0706070607060706, 0x0302030203020302, 0x0504050405040504, 0x0100010001000100),
- _mm256_set_epi64x(0x0f0e0f0e0f0e0f0e, 0x0b0a0b0a0b0a0b0a, 0x0d0c0d0c0d0c0d0c, 0x0908090809080908)
- };
- __m256i isum[2*nrc_y] = {};
-#else
- __m256i shuffles[4] = {
- MM256_SET_M128I(_mm_set1_epi16(0x0302), _mm_set1_epi16(0x0100)),
- MM256_SET_M128I(_mm_set1_epi16(0x0706), _mm_set1_epi16(0x0504)),
- MM256_SET_M128I(_mm_set1_epi16(0x0b0a), _mm_set1_epi16(0x0908)),
- MM256_SET_M128I(_mm_set1_epi16(0x0f0e), _mm_set1_epi16(0x0d0c)),
- };
- __m256i isum[nrc_y == 1 ? 4 : nrc_y] = {};
-#endif
- __m256i qx[4];
- auto grid = iq2s_grid;
- for (int ix = 0; ix < nrc_x; ix += 4) {
- auto iq2 = (const block_iq2_s_r4 *)((const char *)vx + (ix+0)*bx);
- for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256
- auto dl = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq2[ibl].d));
- auto d4 = _mm256_set_m128(dl, dl);
- auto s32 = (const uint32_t *)iq2[ibl].scales;
- auto ql = iq2[ibl].qs;
- auto qh = iq2[ibl].qh;
- {
- auto scale_bits = _mm256_loadu_si256((const __m256i *)iq2[ibl].scales);
- auto scales1 = _mm256_and_si256(scale_bits, _mm256_set1_epi8(0xf));
- auto scales2 = _mm256_and_si256(_mm256_srli_epi16(scale_bits, 4), _mm256_set1_epi8(0xf));
- scales1 = _mm256_or_si256(_mm256_slli_epi16(scales1, 1), _mm256_set1_epi8(1));
- scales2 = _mm256_or_si256(_mm256_slli_epi16(scales2, 1), _mm256_set1_epi8(1));
- auto s1_8 = _mm256_unpacklo_epi8(scales1, scales2); // blocks 0...15, 32...47 (0...3, 8...11 from each row)
- auto s2_8 = _mm256_unpackhi_epi8(scales1, scales2); // blocks 16...31, 48...63 (4...7, 12...15 from each row)
- auto s1_16 = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(s1_8)); // 0...15 (0...3 from each row)
- auto s2_16 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(s1_8, 1)); // 32...47 (8...11 from each row)
- auto s3_16 = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(s2_8)); // 16...31 (4...7 from each row)
- auto s4_16 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(s2_8, 1)); // 48...63 (12...15 from each row)
- auto t1 = MM256_SET_M128I(_mm256_castsi256_si128(s2_16), _mm256_castsi256_si128(s1_16)); // 0,1 and 8,9 from each row
- auto t2 = MM256_SET_M128I(_mm256_extracti128_si256(s2_16, 1), _mm256_extracti128_si256(s1_16, 1)); // 2,3 and 10,11 from each row
- auto t3 = MM256_SET_M128I(_mm256_castsi256_si128(s4_16), _mm256_castsi256_si128(s3_16)); // 4,5 and 12,13 from each row
- auto t4 = MM256_SET_M128I(_mm256_extracti128_si256(s4_16, 1), _mm256_extracti128_si256(s3_16, 1)); // 6,7 and 14,15 from each row
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto bsums = q8.load_bsums(iy, ibl);
- auto sumi = _mm256_setzero_si256();
-#ifdef HAVE_FANCY_SIMD
- sumi = _mm256_dpwssd_epi32(sumi, t1, _mm256_shuffle_epi32(bsums, 0x00));
- sumi = _mm256_dpwssd_epi32(sumi, t2, _mm256_shuffle_epi32(bsums, 0x55));
- sumi = _mm256_dpwssd_epi32(sumi, t3, _mm256_shuffle_epi32(bsums, 0xaa));
- sumi = _mm256_dpwssd_epi32(sumi, t4, _mm256_shuffle_epi32(bsums, 0xff));
-#else
- sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(t1, _mm256_shuffle_epi32(bsums, 0x00)));
- sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(t2, _mm256_shuffle_epi32(bsums, 0x55)));
- sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(t3, _mm256_shuffle_epi32(bsums, 0xaa)));
- sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(t4, _mm256_shuffle_epi32(bsums, 0xff)));
-#endif
- acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(d4, _mm256_set1_ps(-64.f*q8.scale(iy, ibl))), _mm256_cvtepi32_ps(sumi), acc[iy]);
- }
- }
- for (int ib = 0; ib < QK_K/32; ++ib) {
- qx[0] = _mm256_set_epi64x(grid[ql[ 3] | ((qh[0] << 2) & 0x300)], grid[ql[ 2] | ((qh[0] << 4) & 0x300)], grid[ql[ 1] | ((qh[0] << 6) & 0x300)], grid[ql[ 0] | ((qh[0] << 8) & 0x300)]);
- qx[1] = _mm256_set_epi64x(grid[ql[ 7] | ((qh[1] << 2) & 0x300)], grid[ql[ 6] | ((qh[1] << 4) & 0x300)], grid[ql[ 5] | ((qh[1] << 6) & 0x300)], grid[ql[ 4] | ((qh[1] << 8) & 0x300)]);
- qx[2] = _mm256_set_epi64x(grid[ql[11] | ((qh[2] << 2) & 0x300)], grid[ql[10] | ((qh[2] << 4) & 0x300)], grid[ql[ 9] | ((qh[2] << 6) & 0x300)], grid[ql[ 8] | ((qh[2] << 8) & 0x300)]);
- qx[3] = _mm256_set_epi64x(grid[ql[15] | ((qh[3] << 2) & 0x300)], grid[ql[14] | ((qh[3] << 4) & 0x300)], grid[ql[13] | ((qh[3] << 6) & 0x300)], grid[ql[12] | ((qh[3] << 8) & 0x300)]);
- ql += 16; qh += 4;
- auto signs128 = _mm_loadu_si128((const __m128i*)iq2[ibl].signs + ib);
- auto scales = _mm_set1_epi32(s32[ib]);
- scales = _mm_and_si128(_mm_unpacklo_epi8(scales, _mm_srli_epi16(scales, 4)), _mm_set1_epi8(0xf));
- scales = _mm_or_si128(_mm_slli_epi16(scales, 1), _mm_set1_epi8(1));
- auto scales16 = _mm256_cvtepi8_epi16(scales); // 0...7, 0...7
-#ifdef HAVE_FANCY_SIMD
- __m256i scs[2] = { _mm256_shuffle_epi8(scales16, shuffles[0]), _mm256_shuffle_epi8(scales16, shuffles[1]) };
- auto mask = (const __mmask32 *)&signs128;
- qx[0] = _mm256_add_epi8(_mm256_set1_epi8(64), _mm256_mask_sub_epi8(qx[0], mask[0], _mm256_setzero_si256(), qx[0]));
- qx[1] = _mm256_add_epi8(_mm256_set1_epi8(64), _mm256_mask_sub_epi8(qx[1], mask[1], _mm256_setzero_si256(), qx[1]));
- qx[2] = _mm256_add_epi8(_mm256_set1_epi8(64), _mm256_mask_sub_epi8(qx[2], mask[2], _mm256_setzero_si256(), qx[2]));
- qx[3] = _mm256_add_epi8(_mm256_set1_epi8(64), _mm256_mask_sub_epi8(qx[3], mask[3], _mm256_setzero_si256(), qx[3]));
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto y = _mm256_loadu_si256((const __m256i *)q8.y[iy][ibl].qs + ib);
- auto sumi1 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), qx[0], y); // blocks: 0,0,0,0, 1,1,1,1, row 0
- auto sumi2 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), qx[1], y); // blocks: 2,2,2,2, 3,3,3,3, row 1
- auto sumi3 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), qx[2], y); // blocks: 4,4,4,4, 5,5,5,5, row 2
- auto sumi4 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), qx[3], y); // blocks: 6,6,6,6, 7,7,7,7, row 3
- auto s12 = _mm256_packs_epi32(sumi1, sumi2); // 0,0,0,0, 2,2,2,2, 1,1,1,1, 3,3,3,3
- auto s34 = _mm256_packs_epi32(sumi3, sumi4); // 4,4,4,4, 6,6,6,6, 5,5,5,5, 7,7,7,7
- isum[2*iy+0] = _mm256_add_epi32(isum[2*iy+0], _mm256_madd_epi16(scs[0], s12));
- isum[2*iy+1] = _mm256_add_epi32(isum[2*iy+1], _mm256_madd_epi16(scs[1], s34));
- }
-#else
- auto signs = MM256_SET_M128I(signs128, signs128);
- auto shuffle = sign_shuffle;
- auto s = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(signs, shuffle), smask), smask), _mm256_set1_epi8(1));
- shuffle = _mm256_add_epi8(shuffle, m4);
- qx[0] = _mm256_add_epi8(_mm256_set1_epi8(64), _mm256_sign_epi8(qx[0], s));
- s = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(signs, shuffle), smask), smask), _mm256_set1_epi8(1));
- shuffle = _mm256_add_epi8(shuffle, m4);
- qx[1] = _mm256_add_epi8(_mm256_set1_epi8(64), _mm256_sign_epi8(qx[1], s));
- s = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(signs, shuffle), smask), smask), _mm256_set1_epi8(1));
- shuffle = _mm256_add_epi8(shuffle, m4);
- qx[2] = _mm256_add_epi8(_mm256_set1_epi8(64), _mm256_sign_epi8(qx[2], s));
- s = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(signs, shuffle), smask), smask), _mm256_set1_epi8(1));
- qx[3] = _mm256_add_epi8(_mm256_set1_epi8(64), _mm256_sign_epi8(qx[3], s));
- __m256i scs[4] = {
- _mm256_shuffle_epi8(scales16, shuffles[0]), _mm256_shuffle_epi8(scales16, shuffles[1]),
- _mm256_shuffle_epi8(scales16, shuffles[2]), _mm256_shuffle_epi8(scales16, shuffles[3]),
- };
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto y = _mm256_loadu_si256((const __m256i *)q8.y[iy][ibl].qs + ib);
- auto sumi1 = _mm256_madd_epi16(scs[0], _mm256_maddubs_epi16(qx[0], y)); // blocks 4x0, 4x1, row 0
- auto sumi2 = _mm256_madd_epi16(scs[1], _mm256_maddubs_epi16(qx[1], y)); // blocks 4x2, 4x3, row 1
- auto sumi3 = _mm256_madd_epi16(scs[2], _mm256_maddubs_epi16(qx[2], y)); // blocks 4x4, 4x5, row 2
- auto sumi4 = _mm256_madd_epi16(scs[3], _mm256_maddubs_epi16(qx[3], y)); // blocks 4x6, 4x7, row 3
- auto s12 = _mm256_add_epi32(_mm256_unpacklo_epi32(sumi1, sumi2), _mm256_unpackhi_epi32(sumi1, sumi2)); // 0,1, 0,1, 0,1, 0,1
- auto s34 = _mm256_add_epi32(_mm256_unpacklo_epi32(sumi3, sumi4), _mm256_unpackhi_epi32(sumi3, sumi4)); // 2,3, 2,3, 2,3, 2,3
- auto sumi = _mm256_add_epi32(_mm256_unpacklo_epi64(s12, s34), _mm256_unpackhi_epi64(s12, s34)); // 0,1,2,3, 0,1,2,3
- isum[iy] = _mm256_add_epi32(isum[iy], sumi);
- }
-#endif
- }
- for (int iy = 0; iy < nrc_y; ++iy) {
-#ifdef HAVE_FANCY_SIMD
- auto sumi = _mm256_hadd_epi32(isum[2*iy+0], isum[2*iy+1]);
- acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(sumi), acc[iy]);
- isum[2*iy+0] = isum[2*iy+1] = _mm256_setzero_si256();
-#else
- acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(isum[iy]), acc[iy]);
- isum[iy] = _mm256_setzero_si256();
-#endif
- }
- }
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1));
- info.store(ix, iy, _mm_mul_ps(_mm_set1_ps(0.125f), sum));
- acc[iy] = _mm256_setzero_ps();
- }
- }
-}
-
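-// iq3_xxs_r4 kernel: each byte of qs selects a 32-bit iq3xxs_grid entry (4 quants), the sas
-// bytes provide the block scales (bit 0, used as 2*s+1) and the signs, and the 0.25f applied
-// to d below accounts for the iq3_xxs scale convention (see the TODO).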
-template <int nrc_y>
-static void mul_mat_iq3_xxs_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
- GGML_ASSERT(nrc_x%4 == 0);
- Q8<nrc_y, block_q8_K> q8(info);
- int nbl = n / QK_K;
-#ifndef HAVE_FANCY_SIMD
- auto smask = _mm256_set1_epi64x(0x8040201008040201);
- auto sign_shuffle = _mm256_set_epi64x(0x0303030303030303, 0x0202020202020202, 0x0101010101010101, 0x0000000000000000);
- auto m4 = _mm256_set1_epi8(4);
- auto m1 = _mm256_set1_epi16(1);
-#endif
- __m256 acc[nrc_y] = {};
- __m256i isum[nrc_y] = {};
- __m256i qx[4];
- for (int ix = 0; ix < nrc_x; ix += 4) {
- auto iq3 = (const block_iq3_xxs_r4 *)((const char *)vx + (ix+0)*bx);
- for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256
- auto dl = _mm_mul_ps(_mm_set1_ps(0.25f), _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq3[ibl].d))); // TODO: absorb the 0.25 factor into d when quantizing/repacking
- auto d4 = _mm256_set_m128(dl, dl);
- for (int ib = 0; ib < QK_K/32; ++ib) {
- qx[0] = _mm256_set_epi32(iq3xxs_grid[iq3[ibl].qs[32*ib+ 7]], iq3xxs_grid[iq3[ibl].qs[32*ib+ 6]], iq3xxs_grid[iq3[ibl].qs[32*ib+ 5]], iq3xxs_grid[iq3[ibl].qs[32*ib+ 4]],
- iq3xxs_grid[iq3[ibl].qs[32*ib+ 3]], iq3xxs_grid[iq3[ibl].qs[32*ib+ 2]], iq3xxs_grid[iq3[ibl].qs[32*ib+ 1]], iq3xxs_grid[iq3[ibl].qs[32*ib+ 0]]);
- qx[1] = _mm256_set_epi32(iq3xxs_grid[iq3[ibl].qs[32*ib+15]], iq3xxs_grid[iq3[ibl].qs[32*ib+14]], iq3xxs_grid[iq3[ibl].qs[32*ib+13]], iq3xxs_grid[iq3[ibl].qs[32*ib+12]],
- iq3xxs_grid[iq3[ibl].qs[32*ib+11]], iq3xxs_grid[iq3[ibl].qs[32*ib+10]], iq3xxs_grid[iq3[ibl].qs[32*ib+ 9]], iq3xxs_grid[iq3[ibl].qs[32*ib+ 8]]);
- qx[2] = _mm256_set_epi32(iq3xxs_grid[iq3[ibl].qs[32*ib+23]], iq3xxs_grid[iq3[ibl].qs[32*ib+22]], iq3xxs_grid[iq3[ibl].qs[32*ib+21]], iq3xxs_grid[iq3[ibl].qs[32*ib+20]],
- iq3xxs_grid[iq3[ibl].qs[32*ib+19]], iq3xxs_grid[iq3[ibl].qs[32*ib+18]], iq3xxs_grid[iq3[ibl].qs[32*ib+17]], iq3xxs_grid[iq3[ibl].qs[32*ib+16]]);
- qx[3] = _mm256_set_epi32(iq3xxs_grid[iq3[ibl].qs[32*ib+31]], iq3xxs_grid[iq3[ibl].qs[32*ib+30]], iq3xxs_grid[iq3[ibl].qs[32*ib+29]], iq3xxs_grid[iq3[ibl].qs[32*ib+28]],
- iq3xxs_grid[iq3[ibl].qs[32*ib+27]], iq3xxs_grid[iq3[ibl].qs[32*ib+26]], iq3xxs_grid[iq3[ibl].qs[32*ib+25]], iq3xxs_grid[iq3[ibl].qs[32*ib+24]]);
- auto sas = _mm_loadu_si128((const __m128i *)iq3[ibl].sas + ib);
- auto scales = _mm_and_si128(sas, _mm_set1_epi8(1));
-#ifdef HAVE_FANCY_SIMD
- scales = _mm_dpbusd_epi32(_mm_set1_epi32(1), scales, _mm_set1_epi32(0x10080402));
-#else
- scales = _mm_maddubs_epi16(scales, _mm_set1_epi32(0x10080402));
- scales = _mm_add_epi32(_mm_madd_epi16(_mm_set1_epi16(1), scales), _mm_set1_epi32(1));
- //auto t1 = _mm_or_si128(_mm_and_si128(scales, _mm_set1_epi32(0x00000001)), _mm_srli_epi32(_mm_and_si128(scales, _mm_set1_epi32(0x00000100)), 7));
- //auto t2 = _mm_or_si128(_mm_srli_epi32(_mm_and_si128(scales, _mm_set1_epi32(0x00010000)), 14), _mm_srli_epi32(_mm_and_si128(scales, _mm_set1_epi32(0x01000000)), 21));
- //scales = _mm_or_si128(_mm_slli_epi32(_mm_or_si128(t1, t2), 1), _mm_set1_epi32(1));
-#endif
- auto scales32 = MM256_SET_M128I(scales, scales);
- auto signs128 = _mm_and_si128(sas, _mm_set1_epi8(-2)); // 0xfe = -2 as signed. Needed to silence a compiler warning.
- signs128 = _mm_xor_si128(signs128, _mm_srli_epi16(signs128, 1));
-#ifdef HAVE_FANCY_SIMD
- auto mask = (const __mmask32 *)&signs128;
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto y = _mm256_loadu_si256((const __m256i *)q8.y[iy][ibl].qs + ib);
- auto sumi1 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), qx[0], _mm256_mask_sub_epi8(y, mask[0], _mm256_setzero_si256(), y));
- auto sumi2 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), qx[1], _mm256_mask_sub_epi8(y, mask[1], _mm256_setzero_si256(), y));
- auto sumi3 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), qx[2], _mm256_mask_sub_epi8(y, mask[2], _mm256_setzero_si256(), y));
- auto sumi4 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), qx[3], _mm256_mask_sub_epi8(y, mask[3], _mm256_setzero_si256(), y));
- auto s12 = _mm256_add_epi32(_mm256_unpacklo_epi32(sumi1, sumi2), _mm256_unpackhi_epi32(sumi1, sumi2)); // 0,1, 0,1, 0,1, 0,1
- auto s34 = _mm256_add_epi32(_mm256_unpacklo_epi32(sumi3, sumi4), _mm256_unpackhi_epi32(sumi3, sumi4)); // 2,3, 2,3, 2,3, 2,3
- auto sumi = _mm256_add_epi32(_mm256_unpacklo_epi64(s12, s34), _mm256_unpackhi_epi64(s12, s34)); // 0,1,2,3, 0,1,2,3
- isum[iy] = _mm256_add_epi32(isum[iy], _mm256_mullo_epi32(scales32, sumi));
- }
-#else
- auto signs = MM256_SET_M128I(signs128, signs128);
- auto shuffle = sign_shuffle;
- auto s1 = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(signs, shuffle), smask), smask), _mm256_set1_epi8(1));
- shuffle = _mm256_add_epi8(shuffle, m4);
- auto s2 = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(signs, shuffle), smask), smask), _mm256_set1_epi8(1));
- shuffle = _mm256_add_epi8(shuffle, m4);
- auto s3 = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(signs, shuffle), smask), smask), _mm256_set1_epi8(1));
- shuffle = _mm256_add_epi8(shuffle, m4);
- auto s4 = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(signs, shuffle), smask), smask), _mm256_set1_epi8(1));
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto y = _mm256_loadu_si256((const __m256i *)q8.y[iy][ibl].qs + ib);
- auto sumi1 = _mm256_madd_epi16(m1, _mm256_maddubs_epi16(qx[0], _mm256_sign_epi8(y, s1)));
- auto sumi2 = _mm256_madd_epi16(m1, _mm256_maddubs_epi16(qx[1], _mm256_sign_epi8(y, s2)));
- auto sumi3 = _mm256_madd_epi16(m1, _mm256_maddubs_epi16(qx[2], _mm256_sign_epi8(y, s3)));
- auto sumi4 = _mm256_madd_epi16(m1, _mm256_maddubs_epi16(qx[3], _mm256_sign_epi8(y, s4)));
- auto s12 = _mm256_add_epi32(_mm256_unpacklo_epi32(sumi1, sumi2), _mm256_unpackhi_epi32(sumi1, sumi2)); // 0,1, 0,1, 0,1, 0,1
- auto s34 = _mm256_add_epi32(_mm256_unpacklo_epi32(sumi3, sumi4), _mm256_unpackhi_epi32(sumi3, sumi4)); // 2,3, 2,3, 2,3, 2,3
- auto sumi = _mm256_add_epi32(_mm256_unpacklo_epi64(s12, s34), _mm256_unpackhi_epi64(s12, s34)); // 0,1,2,3, 0,1,2,3
- isum[iy] = _mm256_add_epi32(isum[iy], _mm256_mullo_epi32(scales32, sumi));
- }
-#endif
- }
- for (int iy = 0; iy < nrc_y; ++iy) {
- acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(isum[iy]), acc[iy]);
- isum[iy] = _mm256_setzero_si256();
- }
- }
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1));
- info.store(ix, iy, sum);
- acc[iy] = _mm256_setzero_ps();
- }
- }
-}
-
-#ifdef HAVE_FANCY_SIMD
-// Strangely enough, the following implementation makes prompt processing (PP) ~6% slower and
-// token generation (TG) ~6% faster compared to the vanilla AVX2 version below.
-struct IndexHelperIQ3S {
- union index_t {
- __m256i vec;
- uint16_t val[16];
- };
- inline void make2(const uint8_t * qs, const uint8_t * qh, __m256i * values) const {
- auto idx_l = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)qs));
- const __mmask16 * m16 = (const __mmask16 *)qh;
- index_t idx;
- idx.vec = _mm256_mask_add_epi16(idx_l, m16[0], idx_l, offset);
- values[0] = _mm256_set_epi32(iq3s_grid[idx.val[ 7]], iq3s_grid[idx.val[ 6]], iq3s_grid[idx.val[ 5]], iq3s_grid[idx.val[ 4]],
- iq3s_grid[idx.val[ 3]], iq3s_grid[idx.val[ 2]], iq3s_grid[idx.val[ 1]], iq3s_grid[idx.val[ 0]]);
- values[1] = _mm256_set_epi32(iq3s_grid[idx.val[15]], iq3s_grid[idx.val[14]], iq3s_grid[idx.val[13]], iq3s_grid[idx.val[12]],
- iq3s_grid[idx.val[11]], iq3s_grid[idx.val[10]], iq3s_grid[idx.val[ 9]], iq3s_grid[idx.val[ 8]]);
- }
- const __m256i offset = _mm256_set1_epi16(256);
-};
-#else
-struct IndexHelperIQ3S {
- union index_t {
- __m256i vec;
- uint32_t val[8];
- };
- inline void make2(const uint8_t * qs, const uint8_t * qh, __m256i * values) const {
- index_t idx;
- auto idx_l = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i *)qs));
- auto idx_h = _mm256_and_si256(_mm256_sllv_epi32(_mm256_set1_epi32(qh[0]), idx_shift), idx_mask);
- idx.vec = _mm256_or_si256(idx_h, idx_l);
- values[0] = _mm256_set_epi32(iq3s_grid[idx.val[7]], iq3s_grid[idx.val[6]], iq3s_grid[idx.val[5]], iq3s_grid[idx.val[4]],
- iq3s_grid[idx.val[3]], iq3s_grid[idx.val[2]], iq3s_grid[idx.val[1]], iq3s_grid[idx.val[0]]);
- idx_l = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i *)(qs+8)));
- idx_h = _mm256_and_si256(_mm256_sllv_epi32(_mm256_set1_epi32(qh[1]), idx_shift), idx_mask);
- idx.vec = _mm256_or_si256(idx_h, idx_l);
- values[1] = _mm256_set_epi32(iq3s_grid[idx.val[7]], iq3s_grid[idx.val[6]], iq3s_grid[idx.val[5]], iq3s_grid[idx.val[4]],
- iq3s_grid[idx.val[3]], iq3s_grid[idx.val[2]], iq3s_grid[idx.val[1]], iq3s_grid[idx.val[0]]);
- }
- const __m256i idx_mask = _mm256_set1_epi32(256);
- const __m256i idx_shift = _mm256_set_epi32(1, 2, 3, 4, 5, 6, 7, 8);
-};
-#endif
-
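-// iq3_s_r4 kernel: grid indices are 8 low bits from qs plus one high bit from qh, looked up in
-// iq3s_grid; the signs are stored explicitly and the 4-bit block scales are used as 2*s+1.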
-template <int nrc_y>
-static void mul_mat_iq3_s_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
- GGML_ASSERT(nrc_x%4 == 0);
- Q8<nrc_y, block_q8_K> q8(info);
- int nbl = n / QK_K;
- auto smask = _mm256_set1_epi8(1);
- union { __m256i vec; uint32_t val[8]; } helper;
- union { __m128i vec; uint16_t val[8]; } hidx;
- __m256 acc[nrc_y] = {};
- __m256i isum[nrc_y] = {};
- __m256i qx[4];
-#ifdef HAVE_FANCY_SIMD
- __mmask32 mask[4];
-#endif
- for (int ix = 0; ix < nrc_x; ix += 4) {
- auto iq3 = (const block_iq3_s_r4 *)((const char *)vx + (ix+0)*bx);
- for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256
- auto dl = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq3[ibl].d));
- auto d4 = _mm256_set_m128(dl, dl);
- auto qs = iq3[ibl].qs;
- auto qh = iq3[ibl].qh;
- auto scale_bits = _mm_loadu_si128((const __m128i *)iq3[ibl].scales);
- auto scales8 = MM256_SET_M128I(_mm_srli_epi16(scale_bits, 4), scale_bits);
- helper.vec = _mm256_or_si256(_mm256_slli_epi16(_mm256_and_si256(scales8, _mm256_set1_epi8(0xf)), 1), _mm256_set1_epi8(1));
- for (int ib = 0; ib < QK_K/32; ++ib) {
- auto qh32 = (const uint32_t *)qh;
- auto idx_h = _mm_sllv_epi64(_mm_cvtepu8_epi16(_mm_set1_epi32(qh32[0])), _mm_set_epi64x(4, 8));
- for (int i = 0; i < 4; ++i) {
- auto idx_l = _mm_cvtepu8_epi16(_mm_loadl_epi64((const __m128i *)(qs + 8*i)));
- hidx.vec = _mm_or_si128(idx_l, _mm_and_si128(idx_h, _mm_set1_epi16(0x100))); idx_h = _mm_srli_epi16(idx_h, 1);
- qx[i] = _mm256_set_epi32(iq3s_grid[hidx.val[7]], iq3s_grid[hidx.val[6]], iq3s_grid[hidx.val[5]], iq3s_grid[hidx.val[4]],
- iq3s_grid[hidx.val[3]], iq3s_grid[hidx.val[2]], iq3s_grid[hidx.val[1]], iq3s_grid[hidx.val[0]]);
- }
- qs += 32; qh += 4;
- auto signs128 = _mm_loadu_si128((const __m128i*)iq3[ibl].signs + ib);
- auto signs = MM256_SET_M128I(_mm_srli_epi16(signs128, 4), signs128);
-#ifdef HAVE_FANCY_SIMD
- auto scales = _mm256_cvtepi8_epi32(_mm_set1_epi32(helper.val[ib]));
- mask[0] = _mm256_cmpeq_epi8_mask(_mm256_and_si256(signs, smask), smask); signs = _mm256_srli_epi16(signs, 1);
- mask[1] = _mm256_cmpeq_epi8_mask(_mm256_and_si256(signs, smask), smask); signs = _mm256_srli_epi16(signs, 1);
- mask[2] = _mm256_cmpeq_epi8_mask(_mm256_and_si256(signs, smask), smask); signs = _mm256_srli_epi16(signs, 1);
- mask[3] = _mm256_cmpeq_epi8_mask(_mm256_and_si256(signs, smask), smask);
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto y = _mm256_loadu_si256((const __m256i *)q8.y[iy][ibl].qs + ib);
- auto sumi = _mm256_setzero_si256();
- auto ys = _mm256_shuffle_epi32(y, 0x00);
- sumi = _mm256_dpbusd_epi32(sumi, qx[0], _mm256_mask_sub_epi8(ys, mask[0], _mm256_setzero_si256(), ys));
- ys = _mm256_shuffle_epi32(y, 0x55);
- sumi = _mm256_dpbusd_epi32(sumi, qx[1], _mm256_mask_sub_epi8(ys, mask[1], _mm256_setzero_si256(), ys));
- ys = _mm256_shuffle_epi32(y, 0xaa);
- sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_mask_sub_epi8(ys, mask[2], _mm256_setzero_si256(), ys));
- ys = _mm256_shuffle_epi32(y, 0xff);
- sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_mask_sub_epi8(ys, mask[3], _mm256_setzero_si256(), ys));
- isum[iy] = _mm256_add_epi32(isum[iy], _mm256_mullo_epi32(sumi, scales));
- }
-#else
- auto scales16 = _mm256_cvtepi8_epi16(_mm_set1_epi32(helper.val[ib]));
- auto scales = _mm256_unpacklo_epi16(scales16, scales16);
- auto s1 = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(signs, smask), smask), smask); signs = _mm256_srli_epi16(signs, 1);
- auto s2 = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(signs, smask), smask), smask); signs = _mm256_srli_epi16(signs, 1);
- auto s3 = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(signs, smask), smask), smask); signs = _mm256_srli_epi16(signs, 1);
- auto s4 = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(signs, smask), smask), smask);
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto y = _mm256_loadu_si256((const __m256i *)q8.y[iy][ibl].qs + ib);
- auto sumi = _mm256_setzero_si256();
- sumi = _mm256_add_epi16(sumi, _mm256_maddubs_epi16(qx[0], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), s1)));
- sumi = _mm256_add_epi16(sumi, _mm256_maddubs_epi16(qx[1], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), s2)));
- sumi = _mm256_add_epi16(sumi, _mm256_maddubs_epi16(qx[2], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), s3)));
- sumi = _mm256_add_epi16(sumi, _mm256_maddubs_epi16(qx[3], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), s4)));
- isum[iy] = _mm256_add_epi32(isum[iy], _mm256_madd_epi16(scales, sumi));
- }
-#endif
- }
- for (int iy = 0; iy < nrc_y; ++iy) {
- acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(isum[iy]), acc[iy]);
- isum[iy] = _mm256_setzero_si256();
- }
- }
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1));
- info.store(ix, iy, sum);
- acc[iy] = _mm256_setzero_ps();
- }
- }
-}
-
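-// Folds the per-block mins of the 4 interleaved rows into the float accumulators using the q8
-// block sums (bsums), so the main quant loops of q4_k_r4/q5_k_r4 below only have to handle the
-// non-negative quant part. m4 is the (negated) min scale; mins holds 8 blocks x 4 rows of 8-bit mins.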
-template <int nrc_y>
-inline void process_min_r4_b32(int ibl, __m256 m4, __m256i mins, const Q8<nrc_y, block_q8_K>& q8, __m256 * acc) {
- auto mins_l = _mm256_castsi256_si128(mins);
- auto mins_h = _mm256_extracti128_si256(mins, 1);
- auto aux1 = _mm_unpacklo_epi32(mins_l, mins_h);
- auto aux2 = _mm_unpackhi_epi32(mins_l, mins_h);
- auto ic1 = _mm256_cvtepi8_epi32(aux1);
- auto ic2 = _mm256_cvtepi8_epi32(_mm_shuffle_epi32(aux1, 0xee));
- auto ic3 = _mm256_cvtepi8_epi32(aux2);
- auto ic4 = _mm256_cvtepi8_epi32(_mm_shuffle_epi32(aux2, 0xee));
- if constexpr (nrc_y == 1) {
- auto bs = _mm256_loadu_ps((const float *)q8.y[0][ibl].bsums);
- auto sumf = _mm256_mul_ps(_mm256_cvtepi32_ps(ic1), _mm256_shuffle_ps(bs, bs, 0x00));
- sumf = _mm256_fmadd_ps(_mm256_cvtepi32_ps(ic2), _mm256_shuffle_ps(bs, bs, 0x55), sumf);
- sumf = _mm256_fmadd_ps(_mm256_cvtepi32_ps(ic3), _mm256_shuffle_ps(bs, bs, 0xaa), sumf);
- sumf = _mm256_fmadd_ps(_mm256_cvtepi32_ps(ic4), _mm256_shuffle_ps(bs, bs, 0xff), sumf);
- acc[0] = _mm256_fmadd_ps(m4, sumf, acc[0]);
- } else {
- auto c1 = _mm256_mul_ps(m4, _mm256_cvtepi32_ps(ic1));
- auto c2 = _mm256_mul_ps(m4, _mm256_cvtepi32_ps(ic2));
- auto c3 = _mm256_mul_ps(m4, _mm256_cvtepi32_ps(ic3));
- auto c4 = _mm256_mul_ps(m4, _mm256_cvtepi32_ps(ic4));
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto bs = _mm256_loadu_ps((const float *)q8.y[iy][ibl].bsums);
- acc[iy] = _mm256_fmadd_ps(c1, _mm256_shuffle_ps(bs, bs, 0x00), acc[iy]);
- acc[iy] = _mm256_fmadd_ps(c2, _mm256_shuffle_ps(bs, bs, 0x55), acc[iy]);
- acc[iy] = _mm256_fmadd_ps(c3, _mm256_shuffle_ps(bs, bs, 0xaa), acc[iy]);
- acc[iy] = _mm256_fmadd_ps(c4, _mm256_shuffle_ps(bs, bs, 0xff), acc[iy]);
- }
- }
-}
-
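-// q4_K repacked-by-4 kernel: the 6-bit scales and mins are reassembled from scales_l (low 4 bits)
-// and scales_h (2 high bits), the mins are folded in via process_min_r4_b32() above, and the
-// 4-bit quants are unpacked from two 32-byte loads per block of 32.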
-template <int nrc_y>
-static void mul_mat_q4_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
- GGML_ASSERT(nrc_x%4 == 0);
- Q8<nrc_y, block_q8_K> q8(info);
- auto mf = _mm256_set1_epi8(0xf);
- auto m3 = _mm256_set1_epi8(0x30);
- int nbl = n / QK_K;
- union { __m256i vec; uint32_t val[8]; } hd;
- __m256 acc[nrc_y] = {};
- __m256i isum[nrc_y] = {};
- __m256i qx[4];
- for (int ix = 0; ix < nrc_x; ix += 4) {
- const block_q4_k_r4 * iq4 = (const block_q4_k_r4 *)((const char *)vx + (ix+0)*bx);
- for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256
- auto dl = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq4[ibl].d));
- auto d4 = _mm256_set_m128(_mm256_castps256_ps128(dl), _mm256_castps256_ps128(dl));
- auto m4 = _mm256_mul_ps(_mm256_set1_ps(-1.0f), _mm256_set_m128(_mm256_extractf128_ps(dl, 1), _mm256_extractf128_ps(dl, 1)));
- auto lbits = _mm256_loadu_si256((const __m256i *)iq4[ibl].scales_l);
- auto hbits128 = _mm_loadu_si128((const __m128i *)iq4[ibl].scales_h);
- auto hbits = MM256_SET_M128I(hbits128, _mm_slli_epi16(hbits128, 4));
- hd.vec = _mm256_or_si256(_mm256_and_si256(lbits, mf), _mm256_and_si256(hbits, m3));
- auto mins = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(lbits, 4), mf), _mm256_and_si256(_mm256_srli_epi16(hbits, 2), m3));
- process_min_r4_b32(ibl, m4, mins, q8, acc);
- for (int ib = 0; ib < QK_K/32; ++ib) {
-#ifdef HAVE_FANCY_SIMD
- auto scales_d = _mm256_cvtepi8_epi32(_mm_set1_epi32(hd.val[ib]));
-#else
- auto aux = _mm_set1_epi32(hd.val[ib]);
- aux = _mm_cvtepu8_epi16(_mm_unpacklo_epi8(aux, aux));
- auto scales_d = MM256_SET_M128I(aux, aux);
-#endif
- auto bits1 = _mm256_loadu_si256((const __m256i *)iq4[ibl].qs+2*ib+0);
- auto bits2 = _mm256_loadu_si256((const __m256i *)iq4[ibl].qs+2*ib+1);
- qx[0] = _mm256_and_si256(bits1, mf);
- qx[1] = _mm256_and_si256(bits2, mf);
- qx[2] = _mm256_and_si256(_mm256_srli_epi16(bits1, 4), mf);
- qx[3] = _mm256_and_si256(_mm256_srli_epi16(bits2, 4), mf);
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto y = _mm256_loadu_si256((const __m256i*)q8.y[iy][ibl].qs+ib);
-#ifdef HAVE_FANCY_SIMD
- auto sumi = _mm256_setzero_si256();
- sumi = _mm256_dpbusd_epi32(sumi, qx[0], _mm256_shuffle_epi32(y, 0x00));
- sumi = _mm256_dpbusd_epi32(sumi, qx[1], _mm256_shuffle_epi32(y, 0x55));
- sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(y, 0xaa));
- sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(y, 0xff));
- isum[iy] = _mm256_add_epi32(isum[iy], _mm256_mullo_epi32(scales_d, sumi));
-#else
- auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[0], _mm256_shuffle_epi32(y, 0x00)),
- _mm256_maddubs_epi16(qx[1], _mm256_shuffle_epi32(y, 0x55)));
- auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[2], _mm256_shuffle_epi32(y, 0xaa)),
- _mm256_maddubs_epi16(qx[3], _mm256_shuffle_epi32(y, 0xff)));
- isum[iy] = _mm256_add_epi32(isum[iy], _mm256_madd_epi16(scales_d, _mm256_add_epi16(sumi1, sumi2)));
-#endif
- }
- }
- for (int iy = 0; iy < nrc_y; ++iy) {
- acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(isum[iy]), acc[iy]);
- isum[iy] = _mm256_setzero_si256();
- }
- }
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1));
- acc[iy] = _mm256_setzero_ps();
- info.store(ix+0, iy, sum);
- }
- }
-}
-
-template <int nrc_y>
-static void mul_mat_q5_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
- GGML_ASSERT(nrc_x%4 == 0);
- Q8<nrc_y, block_q8_K> q8(info);
- auto mf = _mm256_set1_epi8(0xf);
- auto m10 = _mm256_set1_epi8(0x10);
- auto m30 = _mm256_set1_epi8(0x30);
- int nbl = n / QK_K;
- union { __m256i vec; uint32_t val[8]; } hd;
- __m256 acc[nrc_y] = {};
- __m256i isum[nrc_y] = {};
- __m256i qx[4];
- for (int ix = 0; ix < nrc_x; ix += 4) {
- const block_q5_k_r4 * iq5 = (const block_q5_k_r4 *)((const char *)vx + (ix+0)*bx);
- for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256
- auto dl = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq5[ibl].d));
- auto d4 = _mm256_set_m128(_mm256_castps256_ps128(dl), _mm256_castps256_ps128(dl));
- auto m4 = _mm256_mul_ps(_mm256_set1_ps(-1.0f), _mm256_set_m128(_mm256_extractf128_ps(dl, 1), _mm256_extractf128_ps(dl, 1)));
- auto lbits = _mm256_loadu_si256((const __m256i *)iq5[ibl].scales_l);
- auto hbits128 = _mm_loadu_si128((const __m128i *)iq5[ibl].scales_h);
- auto hbits = MM256_SET_M128I(hbits128, _mm_slli_epi16(hbits128, 4));
- hd.vec = _mm256_or_si256(_mm256_and_si256(lbits, mf), _mm256_and_si256(hbits, m30));
- auto mins = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(lbits, 4), mf), _mm256_and_si256(_mm256_srli_epi16(hbits, 2), m30));
- process_min_r4_b32(ibl, m4, mins, q8, acc);
- for (int ib = 0; ib < QK_K/32; ++ib) {
-#ifdef HAVE_FANCY_SIMD
- auto scales_d = _mm256_cvtepi8_epi32(_mm_set1_epi32(hd.val[ib]));
-#else
- auto aux = _mm_set1_epi32(hd.val[ib]);
- aux = _mm_cvtepu8_epi16(_mm_unpacklo_epi8(aux, aux));
- auto scales_d = MM256_SET_M128I(aux, aux);
-#endif
- auto lbits1 = _mm256_loadu_si256((const __m256i *)iq5[ibl].qs+2*ib+0);
- auto lbits2 = _mm256_loadu_si256((const __m256i *)iq5[ibl].qs+2*ib+1);
- auto hbits128 = _mm_loadu_si128((const __m128i*)iq5[ibl].qh + ib);
- auto hbits = MM256_SET_M128I(hbits128, _mm_slli_epi16(hbits128, 4));
- qx[0] = _mm256_or_si256(_mm256_and_si256(lbits1, mf), _mm256_and_si256(m10, hbits));
- qx[1] = _mm256_or_si256(_mm256_and_si256(lbits2, mf), _mm256_and_si256(m10, _mm256_srli_epi16(hbits, 2)));
- qx[2] = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(lbits1, 4), mf), _mm256_and_si256(m10, _mm256_srli_epi16(hbits, 1)));
- qx[3] = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(lbits2, 4), mf), _mm256_and_si256(m10, _mm256_srli_epi16(hbits, 3)));
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto y = _mm256_loadu_si256((const __m256i*)q8.y[iy][ibl].qs+ib);
-#ifdef HAVE_FANCY_SIMD
- auto sumi = _mm256_setzero_si256();
- sumi = _mm256_dpbusd_epi32(sumi, qx[0], _mm256_shuffle_epi32(y, 0x00));
- sumi = _mm256_dpbusd_epi32(sumi, qx[1], _mm256_shuffle_epi32(y, 0x55));
- sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(y, 0xaa));
- sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(y, 0xff));
- isum[iy] = _mm256_add_epi32(isum[iy], _mm256_mullo_epi32(scales_d, sumi));
-#else
- auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[0], _mm256_shuffle_epi32(y, 0x00)),
- _mm256_maddubs_epi16(qx[1], _mm256_shuffle_epi32(y, 0x55)));
- auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[2], _mm256_shuffle_epi32(y, 0xaa)),
- _mm256_maddubs_epi16(qx[3], _mm256_shuffle_epi32(y, 0xff)));
- // To avoid overflow, we can only add up to 4 q5 x q8 products.
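-                    // (each maddubs sums 2 products and the add_epi16 makes it 4; with q5 <= 31 and
-                    //  |q8| <= 127 that is at most 4*31*127 = 15748, well below the int16_t limit)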
- auto sumi = _mm256_add_epi32(_mm256_madd_epi16(scales_d, sumi1), _mm256_madd_epi16(scales_d, sumi2));
- isum[iy] = _mm256_add_epi32(isum[iy], sumi);
-#endif
- }
- }
- for (int iy = 0; iy < nrc_y; ++iy) {
- acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(isum[iy]), acc[iy]);
- isum[iy] = _mm256_setzero_si256();
- }
- }
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1));
- acc[iy] = _mm256_setzero_ps();
- info.store(ix+0, iy, sum);
- }
- }
-}
-
-template <int nrc_y>
-static void mul_mat_q2_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
- GGML_ASSERT(nrc_x%4 == 0);
- Q8<nrc_y, block_q8_K> q8(info);
- auto mxf = _mm256_set1_epi8(0xf);
- auto m03 = _mm256_set1_epi8(0x03);
- static const uint8_t k_shuff[32] = {0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15, 0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15};
- auto shuff = _mm256_loadu_si256((const __m256i *)k_shuff);
-#ifdef HAVE_FANCY_SIMD
- __m256i isum[nrc_y] = {};
-#else
- auto m1 = _mm256_set1_epi16(1);
-#endif
- int nbl = n / QK_K;
- __m256 acc[nrc_y] = {};
- __m256i qx[4];
- int8_t scales[64];
- for (int ix = 0; ix < nrc_x; ix += 4) {
- const block_q2_k_r4 * iq2 = (const block_q2_k_r4 *)((const char *)vx + (ix+0)*bx);
- for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256
- auto dm = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq2[ibl].d));
- auto d4 = _mm256_set_m128(_mm256_castps256_ps128(dm), _mm256_castps256_ps128(dm));
- auto m4 = _mm256_set_m128(_mm256_extractf128_ps(dm, 1), _mm256_extractf128_ps(dm, 1));
- m4 = _mm256_mul_ps(m4, _mm256_set1_ps(-1.f));
- auto all_scales1 = _mm256_loadu_si256((const __m256i *)iq2[ibl].scales+0);
- auto all_scales2 = _mm256_loadu_si256((const __m256i *)iq2[ibl].scales+1);
- auto scales1 = _mm256_and_si256(_mm256_srli_epi16(all_scales1, 4), mxf);
- auto scales2 = _mm256_and_si256(_mm256_srli_epi16(all_scales2, 4), mxf);
- {
- auto t1 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(scales1, 0)), shuff); // blocks 0, 1, 2, 3 for each row
- auto t2 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(scales1, 1)), shuff); // blocks 4, 5, 6, 7 for each row
- auto t3 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(scales2, 0)), shuff); // blocks 8, 9, 10, 11 for each row
- auto t4 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(scales2, 1)), shuff); // blocks 12, 13, 14, 15 for each row
- auto s1 = MM256_SET_M128I(_mm256_extracti128_si256(t3, 0), _mm256_extracti128_si256(t1, 0)); // blocks 0, 1, 8, 9
- auto s2 = MM256_SET_M128I(_mm256_extracti128_si256(t3, 1), _mm256_extracti128_si256(t1, 1)); // blocks 2, 3, 10, 11
- auto s3 = MM256_SET_M128I(_mm256_extracti128_si256(t4, 0), _mm256_extracti128_si256(t2, 0)); // blocks 4, 5, 12, 13
- auto s4 = MM256_SET_M128I(_mm256_extracti128_si256(t4, 1), _mm256_extracti128_si256(t2, 1)); // blocks 6, 7, 14, 15
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto bsums = q8.load_bsums(iy, ibl);
- auto sumi = _mm256_setzero_si256();
-#ifdef HAVE_FANCY_SIMD
- sumi = _mm256_dpwssd_epi32(sumi, s1, _mm256_shuffle_epi32(bsums, 0x00));
- sumi = _mm256_dpwssd_epi32(sumi, s2, _mm256_shuffle_epi32(bsums, 0x55));
- sumi = _mm256_dpwssd_epi32(sumi, s3, _mm256_shuffle_epi32(bsums, 0xaa));
- sumi = _mm256_dpwssd_epi32(sumi, s4, _mm256_shuffle_epi32(bsums, 0xff));
- auto d8 = _mm256_set1_ps(q8.scale(iy, ibl));
- acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(m4, d8), _mm256_cvtepi32_ps(sumi), acc[iy]);
-#else
- sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(s1, _mm256_shuffle_epi32(bsums, 0x00)));
- sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(s2, _mm256_shuffle_epi32(bsums, 0x55)));
- sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(s3, _mm256_shuffle_epi32(bsums, 0xaa)));
- sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(s4, _mm256_shuffle_epi32(bsums, 0xff)));
- auto d8 = _mm256_set1_ps(q8.scale(iy, ibl));
- acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(m4, d8), _mm256_cvtepi32_ps(sumi), acc[iy]);
- if constexpr (nrc_y == 1) {
- d4 = _mm256_mul_ps(d4, d8);
- }
-#endif
- }
- }
- all_scales1 = _mm256_and_si256(all_scales1, mxf);
- all_scales2 = _mm256_and_si256(all_scales2, mxf);
- _mm256_storeu_si256((__m256i *)scales+0, all_scales1);
- _mm256_storeu_si256((__m256i *)scales+1, all_scales2);
- for (int ib = 0; ib < QK_K/32; ++ib) {
- auto iscales = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(scales + 8*ib)));
-#ifndef HAVE_FANCY_SIMD
- auto scales = _mm256_mul_ps(d4, _mm256_cvtepi32_ps(iscales));
-#endif
- auto lb = _mm256_loadu_si256((const __m256i *)iq2[ibl].qs+ib);
- qx[0] = _mm256_and_si256(lb, m03);
- qx[1] = _mm256_and_si256(_mm256_srli_epi16(lb, 2), m03);
- qx[2] = _mm256_and_si256(_mm256_srli_epi16(lb, 4), m03);
- qx[3] = _mm256_and_si256(_mm256_srli_epi16(lb, 6), m03);
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto y = _mm256_loadu_si256((const __m256i*)q8.y[iy][ibl].qs+ib);
-#ifdef HAVE_FANCY_SIMD
- auto sumi = _mm256_setzero_si256();
- sumi = _mm256_dpbusd_epi32(sumi, qx[0], _mm256_shuffle_epi32(y, 0x00));
- sumi = _mm256_dpbusd_epi32(sumi, qx[1], _mm256_shuffle_epi32(y, 0x55));
- sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(y, 0xaa));
- sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(y, 0xff));
- isum[iy] = _mm256_add_epi32(isum[iy], _mm256_mullo_epi32(iscales, sumi));
-#else
- auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[0], _mm256_shuffle_epi32(y, 0x00)),
- _mm256_maddubs_epi16(qx[1], _mm256_shuffle_epi32(y, 0x55)));
- auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[2], _mm256_shuffle_epi32(y, 0xaa)),
- _mm256_maddubs_epi16(qx[3], _mm256_shuffle_epi32(y, 0xff)));
-                    // Quants are in 0...3, so we can add up all of them as int16_t without overflowing
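-                    // (8 products of at most 3*127 each add up to no more than 3048)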
- auto sumi = _mm256_madd_epi16(m1, _mm256_add_epi16(sumi1, sumi2));
- if constexpr (nrc_y == 1) {
- acc[iy] = _mm256_fmadd_ps(scales, _mm256_cvtepi32_ps(sumi), acc[iy]);
- } else {
- acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(scales, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(sumi), acc[iy]);
- }
-#endif
- }
- }
-#ifdef HAVE_FANCY_SIMD
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto d4y = _mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(iy, ibl)));
- acc[iy] = _mm256_fmadd_ps(d4y, _mm256_cvtepi32_ps(isum[iy]), acc[iy]);
- isum[iy] = _mm256_setzero_si256();
- }
-#endif
- }
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1));
- acc[iy] = _mm256_setzero_ps();
- info.store(ix+0, iy, sum);
- }
- }
-}
-
-template <int nrc_y>
-static void mul_mat_q3_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
- GGML_ASSERT(nrc_x%4 == 0);
- Q8<nrc_y, block_q8_K> q8(info);
- auto m4 = _mm256_set1_epi8(0xf);
- auto m30 = _mm256_set1_epi8(0x30);
- auto m32 = _mm256_set1_epi8(32);
- auto m03 = _mm256_set1_epi8(0x03);
- auto m04 = _mm256_set1_epi8(0x04);
- static const uint8_t k_shuff[32] = {0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15, 0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15};
- auto shuff = _mm256_loadu_si256((const __m256i *)k_shuff);
-#ifdef HAVE_FANCY_SIMD
- __m256i isum[nrc_y];
-#else
- auto m1 = _mm256_set1_epi16(1);
-#endif
- int nbl = n / QK_K;
- __m256 acc[nrc_y] = {};
- __m256i qx[4];
- int8_t scales[64];
- for (int ix = 0; ix < nrc_x; ix += 4) {
- const block_q3_k_r4 * iq3 = (const block_q3_k_r4 *)((const char *)vx + (ix+0)*bx);
- for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256
- auto dl = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq3[ibl].d));
- auto d4 = _mm256_set_m128(dl, dl);
-#ifndef HAVE_FANCY_SIMD
- if constexpr (nrc_y == 1) {
- d4 = _mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(0, ibl)));
- }
-#endif
- auto slb = _mm256_loadu_si256((const __m256i *)iq3[ibl].scales_l);
- auto shbits = _mm_loadu_si128((const __m128i *)iq3[ibl].scales_h);
- auto shb = MM256_SET_M128I(_mm_srli_epi16(shbits, 2), shbits);
- auto scales1 = _mm256_sub_epi8(_mm256_or_si256(_mm256_and_si256(slb, m4), _mm256_and_si256(_mm256_slli_epi16(shb, 4), m30)), m32);
- auto scales2 = _mm256_sub_epi8(_mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(slb, 4), m4), _mm256_and_si256(shb, m30)), m32);
- _mm256_storeu_si256((__m256i *)scales+0, scales1);
- _mm256_storeu_si256((__m256i *)scales+1, scales2);
- {
-#ifndef HAVE_FANCY_SIMD
- auto min = _mm256_mul_ps(d4, _mm256_set1_ps(-4.f));
-#endif
- auto t1 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(scales1, 0)), shuff); // blocks 0, 1, 2, 3 for each row
- auto t2 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(scales1, 1)), shuff); // blocks 4, 5, 6, 7 for each row
- auto t3 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(scales2, 0)), shuff); // blocks 8, 9, 10, 11 for each row
- auto t4 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(scales2, 1)), shuff); // blocks 12, 13, 14, 15 for each row
- auto s1 = MM256_SET_M128I(_mm256_extracti128_si256(t3, 0), _mm256_extracti128_si256(t1, 0)); // blocks 0, 1, 8, 9
- auto s2 = MM256_SET_M128I(_mm256_extracti128_si256(t3, 1), _mm256_extracti128_si256(t1, 1)); // blocks 2, 3, 10, 11
- auto s3 = MM256_SET_M128I(_mm256_extracti128_si256(t4, 0), _mm256_extracti128_si256(t2, 0)); // blocks 4, 5, 12, 13
- auto s4 = MM256_SET_M128I(_mm256_extracti128_si256(t4, 1), _mm256_extracti128_si256(t2, 1)); // blocks 6, 7, 14, 15
-#ifdef HAVE_FANCY_SIMD
- s1 = _mm256_mullo_epi16(s1, _mm256_set1_epi16(-4));
- s2 = _mm256_mullo_epi16(s2, _mm256_set1_epi16(-4));
- s3 = _mm256_mullo_epi16(s3, _mm256_set1_epi16(-4));
- s4 = _mm256_mullo_epi16(s4, _mm256_set1_epi16(-4));
-#endif
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto bsums = q8.load_bsums(iy, ibl);
- auto sumi = _mm256_setzero_si256();
-#ifdef HAVE_FANCY_SIMD
- sumi = _mm256_dpwssd_epi32(sumi, s1, _mm256_shuffle_epi32(bsums, 0x00));
- sumi = _mm256_dpwssd_epi32(sumi, s2, _mm256_shuffle_epi32(bsums, 0x55));
- sumi = _mm256_dpwssd_epi32(sumi, s3, _mm256_shuffle_epi32(bsums, 0xaa));
- sumi = _mm256_dpwssd_epi32(sumi, s4, _mm256_shuffle_epi32(bsums, 0xff));
- isum[iy] = sumi;
-#else
- sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(s1, _mm256_shuffle_epi32(bsums, 0x00)));
- sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(s2, _mm256_shuffle_epi32(bsums, 0x55)));
- sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(s3, _mm256_shuffle_epi32(bsums, 0xaa)));
- sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(s4, _mm256_shuffle_epi32(bsums, 0xff)));
- if constexpr (nrc_y == 1) {
- acc[iy] = _mm256_fmadd_ps(min, _mm256_cvtepi32_ps(sumi), acc[iy]);
- } else {
- acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(min, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(sumi), acc[iy]);
- }
-#endif
- }
- }
- for (int ib = 0; ib < QK_K/32; ++ib) {
- auto iscales = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(scales + 8*ib)));
-#ifndef HAVE_FANCY_SIMD
- auto scales = _mm256_mul_ps(d4, _mm256_cvtepi32_ps(iscales));
-#endif
- auto lb = _mm256_loadu_si256((const __m256i *)iq3[ibl].qs+ib);
- auto hbits = _mm_loadu_si128((const __m128i *)iq3[ibl].qh+ib);
- auto hb = MM256_SET_M128I(hbits, _mm_slli_epi16(hbits, 4));
- qx[0] = _mm256_or_si256(_mm256_and_si256(lb, m03), _mm256_and_si256(m04, _mm256_srli_epi16(hb, 2)));
- qx[1] = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(lb, 2), m03), _mm256_and_si256(m04, _mm256_srli_epi16(hb, 3)));
- qx[2] = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(lb, 4), m03), _mm256_and_si256(m04, _mm256_srli_epi16(hb, 4)));
- qx[3] = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(lb, 6), m03), _mm256_and_si256(m04, _mm256_srli_epi16(hb, 5)));
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto y = _mm256_loadu_si256((const __m256i*)q8.y[iy][ibl].qs+ib);
-#ifdef HAVE_FANCY_SIMD
- auto sumi = _mm256_setzero_si256();
- sumi = _mm256_dpbusd_epi32(sumi, qx[0], _mm256_shuffle_epi32(y, 0x00));
- sumi = _mm256_dpbusd_epi32(sumi, qx[1], _mm256_shuffle_epi32(y, 0x55));
- sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(y, 0xaa));
- sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(y, 0xff));
- isum[iy] = _mm256_add_epi32(isum[iy], _mm256_mullo_epi32(iscales, sumi));
-#else
- auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[0], _mm256_shuffle_epi32(y, 0x00)),
- _mm256_maddubs_epi16(qx[1], _mm256_shuffle_epi32(y, 0x55)));
- auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[2], _mm256_shuffle_epi32(y, 0xaa)),
- _mm256_maddubs_epi16(qx[3], _mm256_shuffle_epi32(y, 0xff)));
-                    // Quants are in 0...7, so we can add up all of them as int16_t without overflowing
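-                    // (8 products of at most 7*127 each add up to no more than 7112)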
- auto sumi = _mm256_madd_epi16(m1, _mm256_add_epi16(sumi1, sumi2));
- if constexpr (nrc_y == 1) {
- acc[iy] = _mm256_fmadd_ps(scales, _mm256_cvtepi32_ps(sumi), acc[iy]);
- } else {
- acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(scales, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(sumi), acc[iy]);
- }
-#endif
-
- }
- }
-#ifdef HAVE_FANCY_SIMD
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto d4y = _mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(iy, ibl)));
- acc[iy] = _mm256_fmadd_ps(d4y, _mm256_cvtepi32_ps(isum[iy]), acc[iy]);
- }
-#endif
- }
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1));
- acc[iy] = _mm256_setzero_ps();
- info.store(ix+0, iy, sum);
- }
- }
-}
-
-template <int nrc_y>
-static void mul_mat_q6_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
- GGML_ASSERT(nrc_x%4 == 0);
- Q8<nrc_y, block_q8_K> q8(info);
- auto m4 = _mm256_set1_epi8(0xf);
- auto m3 = _mm256_set1_epi8(0x30);
- static const uint8_t k_shuff[32] = {0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15, 0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15};
- auto shuff = _mm256_loadu_si256((const __m256i *)k_shuff);
-#ifdef HAVE_FANCY_SIMD
- __m256i isum[nrc_y];
-#else
- auto m1 = _mm256_set1_epi16(1);
-#endif
- int nbl = n / QK_K;
- __m256 acc[nrc_y] = {};
- __m256i qx[4];
- for (int ix = 0; ix < nrc_x; ix += 4) {
- const block_q6_k_r4 * iq6 = (const block_q6_k_r4 *)((const char *)vx + (ix+0)*bx);
- for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256
- auto dl = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq6[ibl].d));
- auto d4 = _mm256_set_m128(dl, dl);
-#ifndef HAVE_FANCY_SIMD
- if constexpr (nrc_y == 1) {
- d4 = _mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(0, ibl)));
- }
-#endif
- {
-#ifndef HAVE_FANCY_SIMD
- auto min = _mm256_mul_ps(d4, _mm256_set1_ps(-32.f));
-#endif
- auto t1 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)iq6[ibl].scales+0)), shuff); // blocks 0, 1, 2, 3 for each row
- auto t2 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)iq6[ibl].scales+1)), shuff); // blocks 4, 5, 6, 7 for each row
- auto t3 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)iq6[ibl].scales+2)), shuff); // blocks 8, 9, 10, 11 for each row
- auto t4 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)iq6[ibl].scales+3)), shuff); // blocks 12, 13, 14, 15 for each row
- auto s1 = MM256_SET_M128I(_mm256_extracti128_si256(t3, 0), _mm256_extracti128_si256(t1, 0)); // blocks 0, 1, 8, 9
- auto s2 = MM256_SET_M128I(_mm256_extracti128_si256(t3, 1), _mm256_extracti128_si256(t1, 1)); // blocks 2, 3, 10, 11
- auto s3 = MM256_SET_M128I(_mm256_extracti128_si256(t4, 0), _mm256_extracti128_si256(t2, 0)); // blocks 4, 5, 12, 13
- auto s4 = MM256_SET_M128I(_mm256_extracti128_si256(t4, 1), _mm256_extracti128_si256(t2, 1)); // blocks 6, 7, 14, 15
-#ifdef HAVE_FANCY_SIMD
- s1 = _mm256_mullo_epi16(s1, _mm256_set1_epi16(-32));
- s2 = _mm256_mullo_epi16(s2, _mm256_set1_epi16(-32));
- s3 = _mm256_mullo_epi16(s3, _mm256_set1_epi16(-32));
- s4 = _mm256_mullo_epi16(s4, _mm256_set1_epi16(-32));
-#endif
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto bsums = q8.load_bsums(iy, ibl);
- auto sumi = _mm256_setzero_si256();
-#ifdef HAVE_FANCY_SIMD
- sumi = _mm256_dpwssd_epi32(sumi, s1, _mm256_shuffle_epi32(bsums, 0x00));
- sumi = _mm256_dpwssd_epi32(sumi, s2, _mm256_shuffle_epi32(bsums, 0x55));
- sumi = _mm256_dpwssd_epi32(sumi, s3, _mm256_shuffle_epi32(bsums, 0xaa));
- sumi = _mm256_dpwssd_epi32(sumi, s4, _mm256_shuffle_epi32(bsums, 0xff));
- isum[iy] = sumi;
-#else
- sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(s1, _mm256_shuffle_epi32(bsums, 0x00)));
- sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(s2, _mm256_shuffle_epi32(bsums, 0x55)));
- sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(s3, _mm256_shuffle_epi32(bsums, 0xaa)));
- sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(s4, _mm256_shuffle_epi32(bsums, 0xff)));
- if constexpr (nrc_y == 1) {
- acc[iy] = _mm256_fmadd_ps(min, _mm256_cvtepi32_ps(sumi), acc[iy]);
- } else {
- acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(min, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(sumi), acc[iy]);
- }
-#endif
- }
- }
- const uint32_t * scales = (const uint32_t *)iq6[ibl].scales;
- for (int ib = 0; ib < QK_K/32; ++ib) {
- auto iscales = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(scales + 2*ib)));
-#ifndef HAVE_FANCY_SIMD
- auto scales = _mm256_mul_ps(d4, _mm256_cvtepi32_ps(iscales));
-#endif
- auto lbits1 = _mm256_loadu_si256((const __m256i *)iq6[ibl].ql+2*ib+0);
- auto lbits2 = _mm256_loadu_si256((const __m256i *)iq6[ibl].ql+2*ib+1);
- auto hbits = _mm256_loadu_si256((const __m256i *)iq6[ibl].qh+ib);
- qx[0] = _mm256_or_si256(_mm256_and_si256(lbits1, m4), _mm256_and_si256(m3, _mm256_slli_epi16(hbits, 4)));
- qx[1] = _mm256_or_si256(_mm256_and_si256(lbits2, m4), _mm256_and_si256(m3, hbits));
- qx[2] = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(lbits1, 4), m4), _mm256_and_si256(m3, _mm256_slli_epi16(hbits, 2)));
- qx[3] = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(lbits2, 4), m4), _mm256_and_si256(m3, _mm256_srli_epi16(hbits, 2)));
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto y = _mm256_loadu_si256((const __m256i*)q8.y[iy][ibl].qs+ib);
-#ifdef HAVE_FANCY_SIMD
- auto sumi = _mm256_setzero_si256();
- sumi = _mm256_dpbusd_epi32(sumi, qx[0], _mm256_shuffle_epi32(y, 0x00));
- sumi = _mm256_dpbusd_epi32(sumi, qx[1], _mm256_shuffle_epi32(y, 0x55));
- sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(y, 0xaa));
- sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(y, 0xff));
- isum[iy] = _mm256_add_epi32(isum[iy], _mm256_mullo_epi32(iscales, sumi));
-#else
- auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[0], _mm256_shuffle_epi32(y, 0x00)),
- _mm256_maddubs_epi16(qx[1], _mm256_shuffle_epi32(y, 0x55)));
- auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[2], _mm256_shuffle_epi32(y, 0xaa)),
- _mm256_maddubs_epi16(qx[3], _mm256_shuffle_epi32(y, 0xff)));
-                    // Quants are in 0...63, so at most 4 products can be added as int16_t without risking overflow
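-                    // (4*63*127 = 32004 still fits, while adding all 8 products could overflow)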
- auto sumi = _mm256_add_epi32(_mm256_madd_epi16(m1, sumi1), _mm256_madd_epi16(m1, sumi2));
- if constexpr (nrc_y == 1) {
- acc[iy] = _mm256_fmadd_ps(scales, _mm256_cvtepi32_ps(sumi), acc[iy]);
- } else {
- acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(scales, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(sumi), acc[iy]);
- }
-#endif
- }
- }
-#ifdef HAVE_FANCY_SIMD
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto d4y = _mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(iy, ibl)));
- acc[iy] = _mm256_fmadd_ps(d4y, _mm256_cvtepi32_ps(isum[iy]), acc[iy]);
- }
-#endif
- }
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1));
- acc[iy] = _mm256_setzero_ps();
- info.store(ix+0, iy, sum);
- }
- }
-}
-
-// The HAVE_FANCY_SIMD guard here should really be #if defined(__AVX512VNNI__) && defined(__AVX512VL__)
-template <int nrc_y>
-static void mul_mat_q8_k_r8_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
- GGML_ASSERT(nrc_x%8 == 0);
- Q8<nrc_y, block_q8_K> q8(info);
-#ifndef HAVE_FANCY_SIMD
- auto m1 = _mm256_set1_epi16(1);
-#endif
- int nbl = n / QK_K;
- __m256 acc[nrc_y] = {};
- __m256i isum[nrc_y] = {};
- __m256i qx[4];
- for (int ix = 0; ix < nrc_x; ix += 8) {
- const block_q8_k_r8 * iq8 = (const block_q8_k_r8 *)((const char *)vx + (ix+0)*bx);
- for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256
- auto d4 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq8[ibl].d));
- for (int ib = 0; ib < QK_K/16; ++ib) {
- qx[0] = _mm256_loadu_si256((const __m256i *)iq8[ibl].qs+4*ib+0);
- qx[1] = _mm256_loadu_si256((const __m256i *)iq8[ibl].qs+4*ib+1);
- qx[2] = _mm256_loadu_si256((const __m256i *)iq8[ibl].qs+4*ib+2);
- qx[3] = _mm256_loadu_si256((const __m256i *)iq8[ibl].qs+4*ib+3);
-#ifndef HAVE_FANCY_SIMD
- auto s0 = _mm256_sign_epi8(qx[0], qx[0]);
- auto s1 = _mm256_sign_epi8(qx[1], qx[1]);
- auto s2 = _mm256_sign_epi8(qx[2], qx[2]);
- auto s3 = _mm256_sign_epi8(qx[3], qx[3]);
-#else
- qx[0] = _mm256_add_epi8(qx[0], _mm256_set1_epi8(127));
- qx[1] = _mm256_add_epi8(qx[1], _mm256_set1_epi8(127));
- qx[2] = _mm256_add_epi8(qx[2], _mm256_set1_epi8(127));
- qx[3] = _mm256_add_epi8(qx[3], _mm256_set1_epi8(127));
-#endif
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto y128 = _mm_loadu_si128((const __m128i*)q8.y[iy][ibl].qs+ib);
- auto y = MM256_SET_M128I(y128, y128);
-#ifdef HAVE_FANCY_SIMD
- isum[iy] = _mm256_dpbusd_epi32(isum[iy], qx[0], _mm256_shuffle_epi32(y, 0x00));
- isum[iy] = _mm256_dpbusd_epi32(isum[iy], qx[1], _mm256_shuffle_epi32(y, 0x55));
- isum[iy] = _mm256_dpbusd_epi32(isum[iy], qx[2], _mm256_shuffle_epi32(y, 0xaa));
- isum[iy] = _mm256_dpbusd_epi32(isum[iy], qx[3], _mm256_shuffle_epi32(y, 0xff));
-#else
- auto sumi1 = _mm256_madd_epi16(m1, _mm256_maddubs_epi16(s0, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), qx[0])));
- auto sumi2 = _mm256_madd_epi16(m1, _mm256_maddubs_epi16(s1, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), qx[1])));
- auto sumi3 = _mm256_madd_epi16(m1, _mm256_maddubs_epi16(s2, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), qx[2])));
- auto sumi4 = _mm256_madd_epi16(m1, _mm256_maddubs_epi16(s3, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), qx[3])));
- isum[iy] = _mm256_add_epi32(isum[iy], _mm256_add_epi32(sumi1, sumi2));
- isum[iy] = _mm256_add_epi32(isum[iy], _mm256_add_epi32(sumi3, sumi4));
-#endif
- }
- }
-#ifdef HAVE_FANCY_SIMD
- auto m4 = _mm256_mul_ps(d4, _mm256_set1_ps(-128.f));
-#endif
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto d4y = _mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(iy, ibl)));
- acc[iy] = _mm256_fmadd_ps(d4y, _mm256_cvtepi32_ps(isum[iy]), acc[iy]);
-#ifdef HAVE_FANCY_SIMD
- auto bsums = (const float *)q8.y[iy][ibl].bsums;
- acc[iy] = _mm256_fmadd_ps(m4, _mm256_set1_ps(bsums[0]), acc[iy]);
-#endif
- isum[iy] = _mm256_setzero_si256();
- }
- }
- for (int iy = 0; iy < nrc_y; ++iy) {
- info.store(ix, iy, acc[iy]);
- acc[iy] = _mm256_setzero_ps();
- }
- }
-}
-
-// The HAVE_FANCY_SIMD guard here should really be #if defined(__AVX512VNNI__) && defined(__AVX512VL__)
-template <int nrc_y>
-static void mul_mat_q8_KV_r8_q8_KV(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
- GGML_ASSERT(n%32 == 0);
- GGML_ASSERT(nrc_x%8 == 0);
-#ifndef HAVE_FANCY_SIMD
- auto m1 = _mm256_set1_epi16(1);
-#endif
- int nb = n / 16;
- __m256i acc[nrc_y] = {};
- __m256i qx[4];
- float dy[nrc_y];
-#ifdef HAVE_FANCY_SIMD
- float sy[nrc_y];
-#endif
- const int8_t * q8y[nrc_y];
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto dptr = (const float *)info.src1_row(iy);
- dy[iy] = dptr[0];
-#ifdef HAVE_FANCY_SIMD
- auto iptr = (const int32_t *)(dptr + 1);
- sy[iy] = -127*iptr[0];
-#endif
- q8y[iy] = (const int8_t *)(dptr + 2);
- }
- for (int ix = 0; ix < nrc_x; ix += 8) {
- auto dptr = (const float *)((const char *)vx + ix*bx);
- auto dx = _mm256_loadu_ps(dptr);
- auto q8x = (const int8_t *)(dptr + 8);
- for (int ib = 0; ib < nb; ++ib) { // Blocks of 16 for 8 interleaved rows
- qx[0] = _mm256_loadu_si256((const __m256i *)q8x+4*ib+0);
- qx[1] = _mm256_loadu_si256((const __m256i *)q8x+4*ib+1);
- qx[2] = _mm256_loadu_si256((const __m256i *)q8x+4*ib+2);
- qx[3] = _mm256_loadu_si256((const __m256i *)q8x+4*ib+3);
-#ifndef HAVE_FANCY_SIMD
- auto s0 = _mm256_sign_epi8(qx[0], qx[0]);
- auto s1 = _mm256_sign_epi8(qx[1], qx[1]);
- auto s2 = _mm256_sign_epi8(qx[2], qx[2]);
- auto s3 = _mm256_sign_epi8(qx[3], qx[3]);
-#else
- qx[0] = _mm256_add_epi8(qx[0], _mm256_set1_epi8(127));
- qx[1] = _mm256_add_epi8(qx[1], _mm256_set1_epi8(127));
- qx[2] = _mm256_add_epi8(qx[2], _mm256_set1_epi8(127));
- qx[3] = _mm256_add_epi8(qx[3], _mm256_set1_epi8(127));
-#endif
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto y128 = _mm_loadu_si128((const __m128i*)q8y[iy]+ib);
- auto y = MM256_SET_M128I(y128, y128);
-#ifdef HAVE_FANCY_SIMD
- acc[iy] = _mm256_dpbusd_epi32(acc[iy], qx[0], _mm256_shuffle_epi32(y, 0x00));
- acc[iy] = _mm256_dpbusd_epi32(acc[iy], qx[1], _mm256_shuffle_epi32(y, 0x55));
- acc[iy] = _mm256_dpbusd_epi32(acc[iy], qx[2], _mm256_shuffle_epi32(y, 0xaa));
- acc[iy] = _mm256_dpbusd_epi32(acc[iy], qx[3], _mm256_shuffle_epi32(y, 0xff));
-#else
- auto sumi1 = _mm256_maddubs_epi16(s0, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), qx[0]));
- auto sumi2 = _mm256_maddubs_epi16(s1, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), qx[1]));
- auto sumi3 = _mm256_maddubs_epi16(s2, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), qx[2]));
- auto sumi4 = _mm256_maddubs_epi16(s3, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), qx[3]));
- auto sumi12 = _mm256_add_epi32(_mm256_madd_epi16(m1, sumi1), _mm256_madd_epi16(m1, sumi2));
- auto sumi34 = _mm256_add_epi32(_mm256_madd_epi16(m1, sumi3), _mm256_madd_epi16(m1, sumi4));
- acc[iy] = _mm256_add_epi32(acc[iy], _mm256_add_epi32(sumi12, sumi34));
-#endif
- }
- }
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto scale = _mm256_mul_ps(dx, _mm256_set1_ps(dy[iy]));
-#ifdef HAVE_FANCY_SIMD
- acc[iy] = _mm256_add_epi32(acc[iy], _mm256_set1_epi32(sy[iy]));
-#endif
- info.store(ix, iy, _mm256_mul_ps(scale, _mm256_cvtepi32_ps(acc[iy])));
- acc[iy] = _mm256_setzero_si256();
- }
- }
-}
-
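-// q8_KV rows as read below: a float scale at byte offset 0, an int32 sum of the quants at byte
-// offset 4 (used on the VNNI paths to correct for shifting the x quants to unsigned), and the
-// int8 quants starting at byte offset 8.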
-template <int nrc_y>
-static void mul_mat_q8_KV_q8_KV_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
- GGML_ASSERT(n%32 == 0);
- if (nrc_y == 1 && nrc_x == 1) {
- auto dx = (const float *)vx;
- auto dy = (const float *)info.src1_row(0);
-#ifdef HAVE_FANCY_SIMD
- auto sy = (const int32_t *)(dy + 1);
- auto x = (const int8_t *)(dx + 2);
- auto y = (const int8_t *)(dy + 2);
- auto isum = _mm512_setzero_si512();
- for (int i = 0; i < n/64; ++i) {
- auto qx = _mm512_loadu_si512((const __m512i *)x + i);
- auto qy = _mm512_loadu_si512((const __m512i *)y + i);
- isum = _mm512_dpbusd_epi32(isum, _mm512_add_epi8(qx, _mm512_set1_epi8(127)), qy);
- }
- auto isum256 = _mm256_add_epi32(_mm512_castsi512_si256(isum), _mm512_extracti32x8_epi32(isum, 1));
- for (int i = 2*(n/64); i < n/32; ++i) {
- auto qx = _mm256_loadu_si256((const __m256i *)x + i);
- auto qy = _mm256_loadu_si256((const __m256i *)y + i);
- isum256 = _mm256_dpbusd_epi32(isum256, _mm256_add_epi8(qx, _mm256_set1_epi8(127)), qy);
- }
- info.store(0, 0, dx[0]*dy[0]*(hsum_i32_8(isum256) - 127*sy[0]));
-#else
- auto x = (const int8_t *)(dx + 2);
- auto y = (const int8_t *)(dy + 2);
- auto isum = _mm256_setzero_si256();
- for (int i = 0; i < n/32; ++i) {
- auto qx = _mm256_loadu_si256((const __m256i *)x + i);
- auto qy = _mm256_loadu_si256((const __m256i *)y + i);
- auto dot = _mm256_maddubs_epi16(_mm256_sign_epi8(qx, qx), _mm256_sign_epi8(qy, qx));
- isum = _mm256_add_epi32(isum, _mm256_madd_epi16(_mm256_set1_epi16(1), dot));
- }
- info.store(0, 0, dx[0]*dy[0]*hsum_i32_8(isum));
-#endif
- return;
- }
- __m256i qx[2];
- __m256i acc[2*nrc_y] = {};
- float dy[nrc_y];
-#ifdef HAVE_FANCY_SIMD
- int32_t sy[nrc_y];
-#else
- __m256i sx[2];
- auto m1 = _mm256_set1_epi16(1);
-#endif
- const int8_t * q8y[nrc_y];
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto dptr = (const float *)info.src1_row(iy);
- dy[iy] = dptr[0];
-#ifdef HAVE_FANCY_SIMD
- auto iptr = (const int32_t *)(dptr+1);
- sy[iy] = -127*iptr[0];
-#endif
- q8y[iy] = (const int8_t *)(dptr + 2);
- }
- for (int ix = 0; ix < nrc_x; ++ix) {
- auto dx = (const float *)((const char *)vx + ix*bx);
- auto q8x = (const int8_t *)(dx + 2);
- for (int i = 0; i < n/64; ++i) {
- for (int j = 0; j < 2; ++j) {
-#ifdef HAVE_FANCY_SIMD
- qx[j] = _mm256_add_epi8(_mm256_loadu_si256((const __m256i *)q8x + 2*i + j), _mm256_set1_epi8(127));
-#else
- qx[j] = _mm256_loadu_si256((const __m256i *)q8x + 2*i + j);
- sx[j] = _mm256_sign_epi8(qx[j], qx[j]);
-#endif
- }
- for (int iy = 0; iy < nrc_y; ++iy) {
- for (int j = 0; j < 2; ++j) {
-#ifdef HAVE_FANCY_SIMD
- acc[2*iy+j] = _mm256_dpbusd_epi32(acc[2*iy+j], qx[j], _mm256_loadu_si256((const __m256i *)q8y[iy] + 2*i + j));
-#else
- auto dot = _mm256_maddubs_epi16(sx[j], _mm256_sign_epi8(_mm256_loadu_si256((const __m256i *)q8y[iy] + 2*i + j), qx[j]));
- acc[2*iy+j] = _mm256_add_epi32(acc[2*iy+j], _mm256_madd_epi16(m1, dot));
-#endif
- }
- }
- }
- if (int i = 2*(n/64); i < n/32) {
-#ifdef HAVE_FANCY_SIMD
- qx[0] = _mm256_add_epi8(_mm256_loadu_si256((const __m256i *)q8x + i), _mm256_set1_epi8(127));
-#else
- qx[0] = _mm256_loadu_si256((const __m256i *)q8x + i);
- sx[0] = _mm256_sign_epi8(qx[0], qx[0]);
-#endif
- for (int iy = 0; iy < nrc_y; ++iy) {
-#ifdef HAVE_FANCY_SIMD
- acc[2*iy] = _mm256_dpbusd_epi32(acc[2*iy], qx[0], _mm256_loadu_si256((const __m256i *)q8y[iy] + i));
-#else
- auto dot = _mm256_maddubs_epi16(sx[0], _mm256_sign_epi8(_mm256_loadu_si256((const __m256i *)q8y[iy] + i), qx[0]));
- acc[2*iy] = _mm256_add_epi32(acc[2*iy], _mm256_madd_epi16(m1, dot));
-#endif
- }
- }
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto sumi = hsum_i32_8(_mm256_add_epi32(acc[2*iy], acc[2*iy+1]));
-#ifdef HAVE_FANCY_SIMD
- info.store(ix, iy, dx[0]*dy[iy]*(sumi+sy[iy]));
-#else
- info.store(ix, iy, dx[0]*dy[iy]*sumi);
-#endif
- acc[2*iy] = acc[2*iy+1] = _mm256_setzero_si256();
- }
- }
-}
-
-#ifdef HAVE_FANCY_SIMD
-template <int nrc_y>
-static void mul_mat_q8_KV_q8_KV_8(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
- GGML_ASSERT(nrc_x%8 == 0);
- GGML_ASSERT(n%32 == 0);
- __m512i qx[4];
- __m512i acc[nrc_y <= 4 ? 2*nrc_y : nrc_y] = {};
- float dy[nrc_y];
- int32_t sy[nrc_y];
- const int8_t * q8y[nrc_y];
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto dptr = (const float *)info.src1_row(iy);
- dy[iy] = dptr[0];
- auto iptr = (const int32_t *)(dptr + 1);
- sy[iy] = -64*iptr[0];
- q8y[iy] = (const int8_t *)(dptr + 2);
- }
- const int8_t * q8x[8];
- float dx[8];
- for (int ix = 0; ix < nrc_x; ix += 8) {
- for (int kx = 0; kx < 8; ++kx) {
- auto dptr = (const float *)((const char *)vx + (ix+kx)*bx);
- dx[kx] = dptr[0];
- q8x[kx] = (const int8_t *)(dptr + 2);
- }
- for (int i = 0; i < n/32; ++i) {
- for (int kx = 0; kx < 4; ++kx) {
- qx[kx] = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)q8x[kx+0] + i)),
- _mm256_loadu_si256((const __m256i *)q8x[kx+4] + i), 1);
- }
- auto t0 = _mm512_unpacklo_epi32(qx[0], qx[1]);
- auto t1 = _mm512_unpacklo_epi32(qx[2], qx[3]);
- auto t2 = _mm512_unpackhi_epi32(qx[0], qx[1]);
- auto t3 = _mm512_unpackhi_epi32(qx[2], qx[3]);
- qx[0] = _mm512_xor_si512(_mm512_unpacklo_epi64(t0, t1), _mm512_set1_epi8(-128));
- qx[1] = _mm512_xor_si512(_mm512_unpackhi_epi64(t0, t1), _mm512_set1_epi8(-128));
- qx[2] = _mm512_xor_si512(_mm512_unpacklo_epi64(t2, t3), _mm512_set1_epi8(-128));
- qx[3] = _mm512_xor_si512(_mm512_unpackhi_epi64(t2, t3), _mm512_set1_epi8(-128));
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto y256 = _mm256_loadu_si256((const __m256i *)q8y[iy] + i);
- auto y = _mm512_inserti32x8(_mm512_castsi256_si512(y256), y256, 1);
- if constexpr (nrc_y <= 4) {
- acc[2*iy+0] = _mm512_dpbusd_epi32(acc[2*iy+0], qx[0], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00)));
- acc[2*iy+1] = _mm512_dpbusd_epi32(acc[2*iy+1], qx[1], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55)));
- acc[2*iy+0] = _mm512_dpbusd_epi32(acc[2*iy+0], qx[2], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa)));
- acc[2*iy+1] = _mm512_dpbusd_epi32(acc[2*iy+1], qx[3], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff)));
- } else {
- acc[iy] = _mm512_dpbusd_epi32(acc[iy], qx[0], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00)));
- acc[iy] = _mm512_dpbusd_epi32(acc[iy], qx[1], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55)));
- acc[iy] = _mm512_dpbusd_epi32(acc[iy], qx[2], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa)));
- acc[iy] = _mm512_dpbusd_epi32(acc[iy], qx[3], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff)));
- }
- }
- }
- auto scales_x = _mm256_loadu_ps(dx);
- for (int iy = 0; iy < nrc_y; ++iy) {
- if constexpr (nrc_y <= 4) {
- auto ss = _mm512_add_epi32(_mm512_add_epi32(acc[2*iy+0], acc[2*iy+1]), _mm512_set1_epi32(sy[iy]));
- auto sum1 = _mm_add_epi32(_mm512_extracti32x4_epi32(ss, 0), _mm512_extracti32x4_epi32(ss, 1));
- auto sum2 = _mm_add_epi32(_mm512_extracti32x4_epi32(ss, 2), _mm512_extracti32x4_epi32(ss, 3));
- auto scale = _mm256_mul_ps(scales_x, _mm256_set1_ps(dy[iy]));
- info.store(ix+0, iy, _mm_mul_ps(_mm256_castps256_ps128(scale), _mm_cvtepi32_ps(sum1)));
- info.store(ix+4, iy, _mm_mul_ps(_mm256_extractf128_ps(scale, 1), _mm_cvtepi32_ps(sum2)));
- acc[2*iy+0] = acc[2*iy+1] = _mm512_setzero_si512();
- } else {
- acc[iy] = _mm512_add_epi32(acc[iy], _mm512_set1_epi32(sy[iy]));
- auto sum1 = _mm_add_epi32(_mm512_extracti32x4_epi32(acc[iy], 0), _mm512_extracti32x4_epi32(acc[iy], 1));
- auto sum2 = _mm_add_epi32(_mm512_extracti32x4_epi32(acc[iy], 2), _mm512_extracti32x4_epi32(acc[iy], 3));
- auto scale = _mm256_mul_ps(scales_x, _mm256_set1_ps(dy[iy]));
- info.store(ix+0, iy, _mm_mul_ps(_mm256_castps256_ps128(scale), _mm_cvtepi32_ps(sum1)));
- info.store(ix+4, iy, _mm_mul_ps(_mm256_extractf128_ps(scale, 1), _mm_cvtepi32_ps(sum2)));
- acc[iy] = _mm512_setzero_si512();
- }
- }
- }
-}
-#endif
-
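-// Same idea as mul_mat_q8_KV_q8_KV_8 above, but for 4 rows: 32-byte chunks from 4 x rows are
-// transposed with unpack{lo,hi}_epi32/epi64 so each dpbusd/maddubs sees 4-byte groups from all 4 rows.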
-template <int nrc_y>
-static void mul_mat_q8_KV_q8_KV(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
- GGML_ASSERT(nrc_x%4 == 0);
- GGML_ASSERT(n%32 == 0);
- __m256i qx[4];
-#ifndef HAVE_FANCY_SIMD
- __m256i sx[4];
- auto m1 = _mm256_set1_epi16(1);
-#endif
- __m256i acc[nrc_y] = {};
- float dy[nrc_y];
-#ifdef HAVE_FANCY_SIMD
- int32_t sy[nrc_y];
-#endif
- const int8_t * q8y[nrc_y];
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto dptr = (const float *)info.src1_row(iy);
- dy[iy] = dptr[0];
-#ifdef HAVE_FANCY_SIMD
- auto iptr = (const int32_t *)(dptr + 1);
- sy[iy] = -127*iptr[0];
-#endif
- q8y[iy] = (const int8_t *)(dptr + 2);
- }
- const int8_t * q8x[4];
- float dx[4];
- for (int ix = 0; ix < nrc_x; ix += 4) {
- for (int kx = 0; kx < 4; ++kx) {
- auto dptr = (const float *)((const char *)vx + (ix+kx)*bx);
- dx[kx] = dptr[0];
- q8x[kx] = (const int8_t *)(dptr + 2);
- }
- for (int i = 0; i < n/32; ++i) {
- for (int kx = 0; kx < 4; ++kx) qx[kx] = _mm256_loadu_si256((const __m256i *)q8x[kx] + i);
- auto t0 = _mm256_unpacklo_epi32(qx[0], qx[1]);
- auto t1 = _mm256_unpacklo_epi32(qx[2], qx[3]);
- auto t2 = _mm256_unpackhi_epi32(qx[0], qx[1]);
- auto t3 = _mm256_unpackhi_epi32(qx[2], qx[3]);
-#ifdef HAVE_FANCY_SIMD
- qx[0] = _mm256_add_epi8(_mm256_unpacklo_epi64(t0, t1), _mm256_set1_epi8(127));
- qx[1] = _mm256_add_epi8(_mm256_unpackhi_epi64(t0, t1), _mm256_set1_epi8(127));
- qx[2] = _mm256_add_epi8(_mm256_unpacklo_epi64(t2, t3), _mm256_set1_epi8(127));
- qx[3] = _mm256_add_epi8(_mm256_unpackhi_epi64(t2, t3), _mm256_set1_epi8(127));
-#else
- qx[0] = _mm256_unpacklo_epi64(t0, t1); sx[0] = _mm256_sign_epi8(qx[0], qx[0]);
- qx[1] = _mm256_unpackhi_epi64(t0, t1); sx[1] = _mm256_sign_epi8(qx[1], qx[1]);
- qx[2] = _mm256_unpacklo_epi64(t2, t3); sx[2] = _mm256_sign_epi8(qx[2], qx[2]);
- qx[3] = _mm256_unpackhi_epi64(t2, t3); sx[3] = _mm256_sign_epi8(qx[3], qx[3]);
-#endif
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto y = _mm256_loadu_si256((const __m256i *)q8y[iy] + i);
-#ifdef HAVE_FANCY_SIMD
- acc[iy] = _mm256_dpbusd_epi32(acc[iy], qx[0], _mm256_shuffle_epi32(y, 0x00));
- acc[iy] = _mm256_dpbusd_epi32(acc[iy], qx[1], _mm256_shuffle_epi32(y, 0x55));
- acc[iy] = _mm256_dpbusd_epi32(acc[iy], qx[2], _mm256_shuffle_epi32(y, 0xaa));
- acc[iy] = _mm256_dpbusd_epi32(acc[iy], qx[3], _mm256_shuffle_epi32(y, 0xff));
-#else
- auto dot1 = _mm256_maddubs_epi16(sx[0], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), qx[0]));
- auto dot2 = _mm256_maddubs_epi16(sx[1], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), qx[1]));
- auto dot3 = _mm256_maddubs_epi16(sx[2], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), qx[2]));
- auto dot4 = _mm256_maddubs_epi16(sx[3], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), qx[3]));
- auto dot12 = _mm256_add_epi32(_mm256_madd_epi16(m1, dot1), _mm256_madd_epi16(m1, dot2));
- auto dot34 = _mm256_add_epi32(_mm256_madd_epi16(m1, dot3), _mm256_madd_epi16(m1, dot4));
- acc[iy] = _mm256_add_epi32(acc[iy], _mm256_add_epi32(dot12, dot34));
-#endif
- }
- }
- auto scales_x = _mm_loadu_ps(dx);
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto sumi = _mm_add_epi32(_mm256_castsi256_si128(acc[iy]), _mm256_extracti128_si256(acc[iy], 1));
-#ifdef HAVE_FANCY_SIMD
- sumi = _mm_add_epi32(sumi, _mm_set1_epi32(sy[iy]));
-#endif
- auto scale = _mm_mul_ps(scales_x, _mm_set1_ps(dy[iy]));
- info.store(ix, iy, _mm_mul_ps(scale, _mm_cvtepi32_ps(sumi)));
- acc[iy] = _mm256_setzero_si256();
- }
- }
-}
-
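-// bf16_r16: 16 interleaved bf16 rows per tile. The main loop processes two tiles (32 rows) at a
-// time with _mm512_dpbf16_ps; the second loop handles a leftover 16-row tile.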
-#ifdef __AVX512BF16__
-template <int nrc_y>
-static void mul_mat_bf16_r16_bf16(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
- GGML_ASSERT(nrc_x%16 == 0);
- const ggml_bf16_t * y[nrc_y];
- for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const ggml_bf16_t *)info.src1_row(iy);
- for (int ix = 0; ix < nrc_x/32; ++ix) {
- __m512 acc[2*nrc_y] = {};
- __m512bh qx[8];
- const ggml_bf16_t * b8_1 = (const ggml_bf16_t *)((const char *)vx + (32*ix+ 0)*bx);
- const ggml_bf16_t * b8_2 = (const ggml_bf16_t *)((const char *)vx + (32*ix+16)*bx);
- for (int ib = 0; ib < n/8; ++ib) {
- qx[0] = (__m512bh)_mm512_loadu_si512((const __m512i *)b8_1+4*ib+0);
- qx[1] = (__m512bh)_mm512_loadu_si512((const __m512i *)b8_1+4*ib+1);
- qx[2] = (__m512bh)_mm512_loadu_si512((const __m512i *)b8_1+4*ib+2);
- qx[3] = (__m512bh)_mm512_loadu_si512((const __m512i *)b8_1+4*ib+3);
- qx[4] = (__m512bh)_mm512_loadu_si512((const __m512i *)b8_2+4*ib+0);
- qx[5] = (__m512bh)_mm512_loadu_si512((const __m512i *)b8_2+4*ib+1);
- qx[6] = (__m512bh)_mm512_loadu_si512((const __m512i *)b8_2+4*ib+2);
- qx[7] = (__m512bh)_mm512_loadu_si512((const __m512i *)b8_2+4*ib+3);
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto y128 = _mm_loadu_si128((const __m128i*)y[iy]+ib);
- //auto y = _mm512_broadcast_i32x4(y128);
- auto y256 = MM256_SET_M128I(y128, y128);
- auto y = _mm512_inserti32x8(_mm512_castsi256_si512(y256), y256, 1);
- acc[2*iy+0] = _mm512_dpbf16_ps(acc[2*iy+0], qx[0], (__m512bh)_mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00)));
- acc[2*iy+0] = _mm512_dpbf16_ps(acc[2*iy+0], qx[1], (__m512bh)_mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55)));
- acc[2*iy+0] = _mm512_dpbf16_ps(acc[2*iy+0], qx[2], (__m512bh)_mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa)));
- acc[2*iy+0] = _mm512_dpbf16_ps(acc[2*iy+0], qx[3], (__m512bh)_mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff)));
- acc[2*iy+1] = _mm512_dpbf16_ps(acc[2*iy+1], qx[4], (__m512bh)_mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00)));
- acc[2*iy+1] = _mm512_dpbf16_ps(acc[2*iy+1], qx[5], (__m512bh)_mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55)));
- acc[2*iy+1] = _mm512_dpbf16_ps(acc[2*iy+1], qx[6], (__m512bh)_mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa)));
- acc[2*iy+1] = _mm512_dpbf16_ps(acc[2*iy+1], qx[7], (__m512bh)_mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff)));
- }
- }
- for (int iy = 0; iy < nrc_y; ++iy) {
- info.store(32*ix+ 0, iy, acc[2*iy+0]);
- info.store(32*ix+16, iy, acc[2*iy+1]);
- }
- }
- for (int ix = 32*(nrc_x/32); ix < nrc_x; ix += 16) {
- __m512 acc[nrc_y] = {};
- __m512bh qx[4];
- const ggml_bf16_t * b8 = (const ggml_bf16_t *)((const char *)vx + (ix+0)*bx);
- for (int ib = 0; ib < n/8; ++ib) {
- qx[0] = (__m512bh)_mm512_loadu_si512((const __m512i *)b8+4*ib+0);
- qx[1] = (__m512bh)_mm512_loadu_si512((const __m512i *)b8+4*ib+1);
- qx[2] = (__m512bh)_mm512_loadu_si512((const __m512i *)b8+4*ib+2);
- qx[3] = (__m512bh)_mm512_loadu_si512((const __m512i *)b8+4*ib+3);
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto y128 = _mm_loadu_si128((const __m128i*)y[iy]+ib);
- auto y256 = MM256_SET_M128I(y128, y128);
- auto y = _mm512_inserti32x8(_mm512_castsi256_si512(y256), y256, 1);
- acc[iy] = _mm512_dpbf16_ps(acc[iy], qx[0], (__m512bh)_mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00)));
- acc[iy] = _mm512_dpbf16_ps(acc[iy], qx[1], (__m512bh)_mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55)));
- acc[iy] = _mm512_dpbf16_ps(acc[iy], qx[2], (__m512bh)_mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa)));
- acc[iy] = _mm512_dpbf16_ps(acc[iy], qx[3], (__m512bh)_mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff)));
- }
- }
- for (int iy = 0; iy < nrc_y; ++iy) {
- info.store(ix, iy, acc[iy]);
- }
- }
-}
-#endif
-
-template <int nrc_y>
-//IQK_ALWAYS_INLINE void iq234_k_accum_mins(int ibl, __m256i i8scales1, __m256i i8scales2, const Q8<nrc_y, block_q8_K>& q8, __m256i shuff,
-inline void iq234_k_accum_mins(int ibl, __m256i i8scales1, __m256i i8scales2, const Q8<nrc_y, block_q8_K>& q8, __m256i shuff,
- __m256i * isum, int16_t min) {
- auto t1 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(i8scales1, 0)), shuff); // blocks 0, 1, 2, 3 for each row
- auto t2 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(i8scales1, 1)), shuff); // blocks 4, 5, 6, 7 for each row
- auto t3 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(i8scales2, 0)), shuff); // blocks 8, 9, 10, 11 for each row
- auto t4 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(i8scales2, 1)), shuff); // blocks 12, 13, 14, 15 for each row
- if constexpr (nrc_y == 1) {
- auto s1 = MM256_SET_M128I(_mm256_extracti128_si256(t3, 0), _mm256_extracti128_si256(t1, 0)); // blocks 0, 1, 8, 9
- auto s2 = MM256_SET_M128I(_mm256_extracti128_si256(t3, 1), _mm256_extracti128_si256(t1, 1)); // blocks 2, 3, 10, 11
- auto s3 = MM256_SET_M128I(_mm256_extracti128_si256(t4, 0), _mm256_extracti128_si256(t2, 0)); // blocks 4, 5, 12, 13
- auto s4 = MM256_SET_M128I(_mm256_extracti128_si256(t4, 1), _mm256_extracti128_si256(t2, 1)); // blocks 6, 7, 14, 15
- auto sumi = _mm256_setzero_si256();
- auto bsums = q8.load_bsums(0, ibl);
-#ifdef HAVE_FANCY_SIMD
- sumi = _mm256_dpwssd_epi32(sumi, s1, _mm256_shuffle_epi32(bsums, 0x00));
- sumi = _mm256_dpwssd_epi32(sumi, s2, _mm256_shuffle_epi32(bsums, 0x55));
- sumi = _mm256_dpwssd_epi32(sumi, s3, _mm256_shuffle_epi32(bsums, 0xaa));
- sumi = _mm256_dpwssd_epi32(sumi, s4, _mm256_shuffle_epi32(bsums, 0xff));
-#else
- sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(s1, _mm256_shuffle_epi32(bsums, 0x00)));
- sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(s2, _mm256_shuffle_epi32(bsums, 0x55)));
- sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(s3, _mm256_shuffle_epi32(bsums, 0xaa)));
- sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(s4, _mm256_shuffle_epi32(bsums, 0xff)));
-#endif
- isum[0] = _mm256_mullo_epi32(sumi, _mm256_set1_epi32(min));
-
- } else {
- auto s1 = _mm256_mullo_epi16(_mm256_set1_epi16(min), MM256_SET_M128I(_mm256_extracti128_si256(t3, 0), _mm256_extracti128_si256(t1, 0))); // blocks 0, 1, 8, 9
- auto s2 = _mm256_mullo_epi16(_mm256_set1_epi16(min), MM256_SET_M128I(_mm256_extracti128_si256(t3, 1), _mm256_extracti128_si256(t1, 1))); // blocks 2, 3, 10, 11
- auto s3 = _mm256_mullo_epi16(_mm256_set1_epi16(min), MM256_SET_M128I(_mm256_extracti128_si256(t4, 0), _mm256_extracti128_si256(t2, 0))); // blocks 4, 5, 12, 13
- auto s4 = _mm256_mullo_epi16(_mm256_set1_epi16(min), MM256_SET_M128I(_mm256_extracti128_si256(t4, 1), _mm256_extracti128_si256(t2, 1))); // blocks 6, 7, 14, 15
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto bsums = q8.load_bsums(iy, ibl);
-#ifdef HAVE_FANCY_SIMD
- isum[iy] = _mm256_dpwssd_epi32(isum[iy], s1, _mm256_shuffle_epi32(bsums, 0x00));
- isum[iy] = _mm256_dpwssd_epi32(isum[iy], s2, _mm256_shuffle_epi32(bsums, 0x55));
- isum[iy] = _mm256_dpwssd_epi32(isum[iy], s3, _mm256_shuffle_epi32(bsums, 0xaa));
- isum[iy] = _mm256_dpwssd_epi32(isum[iy], s4, _mm256_shuffle_epi32(bsums, 0xff));
-#else
- isum[iy] = _mm256_add_epi32(isum[iy], _mm256_madd_epi16(s1, _mm256_shuffle_epi32(bsums, 0x00)));
- isum[iy] = _mm256_add_epi32(isum[iy], _mm256_madd_epi16(s2, _mm256_shuffle_epi32(bsums, 0x55)));
- isum[iy] = _mm256_add_epi32(isum[iy], _mm256_madd_epi16(s3, _mm256_shuffle_epi32(bsums, 0xaa)));
- isum[iy] = _mm256_add_epi32(isum[iy], _mm256_madd_epi16(s4, _mm256_shuffle_epi32(bsums, 0xff)));
-#endif
- }
- }
-}
-
-template <int nrc_y>
-inline void iq2345_k_accum_mins(int ibl, __m256i i8scales1, __m256i i8scales2, const Q8<nrc_y, block_q8_K>& q8, __m256i shuff,
- __m256i extra, __m256i * isum, int8_t min, int8_t delta) {
- auto mask = _mm256_set_epi64x(0x0808080808080808, 0x0404040404040404, 0x0202020202020202, 0x0101010101010101);
- auto vdelta = _mm256_set1_epi8(delta);
- auto vmin = _mm256_set1_epi8(min);
- auto min1 = _mm256_add_epi8(vmin, _mm256_and_si256(vdelta, _mm256_cmpeq_epi8(_mm256_and_si256(extra, mask), mask)));
- auto min2 = _mm256_add_epi8(vmin, _mm256_and_si256(vdelta, _mm256_cmpeq_epi8(_mm256_and_si256(_mm256_srli_epi16(extra, 4), mask), mask)));
- auto t1 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(i8scales1, 0)), shuff); // blocks 0, 1, 2, 3 for each row
- auto t2 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(i8scales1, 1)), shuff); // blocks 4, 5, 6, 7 for each row
- auto t3 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(i8scales2, 0)), shuff); // blocks 8, 9, 10, 11 for each row
- auto t4 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(i8scales2, 1)), shuff); // blocks 12, 13, 14, 15 for each row
- auto m1 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(min1, 0)), shuff); // blocks 0, 1, 2, 3 for each row
- auto m2 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(min1, 1)), shuff); // blocks 4, 5, 6, 7 for each row
- auto m3 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(min2, 0)), shuff); // blocks 8, 9, 10, 11 for each row
- auto m4 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(min2, 1)), shuff); // blocks 12, 13, 14, 15 for each row
- auto s1 = _mm256_mullo_epi16(MM256_SET_M128I(_mm256_extracti128_si256(m3, 0), _mm256_extracti128_si256(m1, 0)),
- MM256_SET_M128I(_mm256_extracti128_si256(t3, 0), _mm256_extracti128_si256(t1, 0))); // blocks 0, 1, 8, 9
- auto s2 = _mm256_mullo_epi16(MM256_SET_M128I(_mm256_extracti128_si256(m3, 1), _mm256_extracti128_si256(m1, 1)),
- MM256_SET_M128I(_mm256_extracti128_si256(t3, 1), _mm256_extracti128_si256(t1, 1))); // blocks 2, 3, 10, 11
- auto s3 = _mm256_mullo_epi16(MM256_SET_M128I(_mm256_extracti128_si256(m4, 0), _mm256_extracti128_si256(m2, 0)),
- MM256_SET_M128I(_mm256_extracti128_si256(t4, 0), _mm256_extracti128_si256(t2, 0))); // blocks 4, 5, 12, 13
- auto s4 = _mm256_mullo_epi16(MM256_SET_M128I(_mm256_extracti128_si256(m4, 1), _mm256_extracti128_si256(m2, 1)),
- MM256_SET_M128I(_mm256_extracti128_si256(t4, 1), _mm256_extracti128_si256(t2, 1))); // blocks 6, 7, 14, 15
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto bsums = q8.load_bsums(iy, ibl);
-#ifdef HAVE_FANCY_SIMD
- isum[iy] = _mm256_dpwssd_epi32(isum[iy], s1, _mm256_shuffle_epi32(bsums, 0x00));
- isum[iy] = _mm256_dpwssd_epi32(isum[iy], s2, _mm256_shuffle_epi32(bsums, 0x55));
- isum[iy] = _mm256_dpwssd_epi32(isum[iy], s3, _mm256_shuffle_epi32(bsums, 0xaa));
- isum[iy] = _mm256_dpwssd_epi32(isum[iy], s4, _mm256_shuffle_epi32(bsums, 0xff));
-#else
- isum[iy] = _mm256_add_epi32(isum[iy], _mm256_madd_epi16(s1, _mm256_shuffle_epi32(bsums, 0x00)));
- isum[iy] = _mm256_add_epi32(isum[iy], _mm256_madd_epi16(s2, _mm256_shuffle_epi32(bsums, 0x55)));
- isum[iy] = _mm256_add_epi32(isum[iy], _mm256_madd_epi16(s3, _mm256_shuffle_epi32(bsums, 0xaa)));
- isum[iy] = _mm256_add_epi32(isum[iy], _mm256_madd_epi16(s4, _mm256_shuffle_epi32(bsums, 0xff)));
-#endif
- }
-}
-
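-// iq2_k_r4: 2-bit indices into the non-linear iq2nl value table; the "extra" bits add 4 to the
-// index, selecting the shifted value set {6, 24, 38, 54} instead of {1, 19, 33, 49}.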
-template <int nrc_y>
-static void mul_mat_iq2_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
- GGML_ASSERT(nrc_x%4 == 0);
- Q8<nrc_y, block_q8_K> q8(info);
- auto m4 = _mm256_set1_epi8(0xf);
- auto ms = _mm256_set1_epi8(4);
- auto m03 = _mm256_set1_epi8(0x03);
- auto shift_shuffle = _mm256_set_epi64x(0x0707070706060606, 0x0505050504040404, 0x0303030302020202, 0x0101010100000000);
- static const uint8_t kvalues_iq2nl[32] = {1, 19, 33, 49, 6, 24, 38, 54, 1, 19, 33, 49, 6, 24, 38, 54, 1, 19, 33, 49, 6, 24, 38, 54, 1, 19, 33, 49, 6, 24, 38, 54};
- auto values = _mm256_loadu_si256((const __m256i*)kvalues_iq2nl);
- static const uint8_t k_shuff[32] = {0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15, 0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15};
- auto shuff = _mm256_loadu_si256((const __m256i *)k_shuff);
-#ifndef HAVE_FANCY_SIMD
- auto s_shuffle = _mm256_set_epi64x(0x0f0e0f0e0d0c0d0c, 0x0b0a0b0a09080908, 0x0706070605040504, 0x0302030201000100);
-#endif
- int nbl = n / QK_K;
- __m256 acc[nrc_y] = {};
- __m256i qx[4];
- uint64_t stored_scales[8];
- for (int ix = 0; ix < nrc_x; ix += 4) {
- const block_iq2_k_r4 * iq2 = (const block_iq2_k_r4 *)((const char *)vx + (ix+0)*bx);
- for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256
- auto dl = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq2[ibl].d));
- auto d4 = _mm256_set_m128(dl, dl);
- auto extra = _mm256_set1_epi64x(*(const uint64_t *)iq2[ibl].extra);
- auto slbits = _mm256_loadu_si256((const __m256i *)iq2[ibl].scales);
- auto i8scales1 = _mm256_add_epi8(_mm256_and_si256(slbits, m4), _mm256_set1_epi8(-8));
- auto i8scales2 = _mm256_add_epi8(_mm256_and_si256(_mm256_srli_epi16(slbits, 4), m4), _mm256_set1_epi8(-8));
- _mm256_storeu_si256((__m256i *)stored_scales+0, i8scales1);
- _mm256_storeu_si256((__m256i *)stored_scales+1, i8scales2);
- __m256i isum[nrc_y] = {};
- iq234_k_accum_mins(ibl, i8scales1, i8scales2, q8, shuff, isum, -32);
- for (int ib = 0; ib < QK_K/32; ++ib) {
-#ifdef HAVE_FANCY_SIMD
- auto scales = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(stored_scales + ib)));
-#else
- auto scales = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm_set1_epi64x(stored_scales[ib])), s_shuffle);
-#endif
- auto lb = _mm256_loadu_si256((const __m256i *)iq2[ibl].qs+ib);
- auto shift = _mm256_and_si256(ms, _mm256_slli_epi16(extra, 2)); extra = _mm256_srli_epi16(extra, 1);
- shift = _mm256_shuffle_epi8(shift, shift_shuffle);
- qx[0] = _mm256_and_si256(lb, m03);
- qx[1] = _mm256_and_si256(_mm256_srli_epi16(lb, 2), m03);
- qx[2] = _mm256_and_si256(_mm256_srli_epi16(lb, 4), m03);
- qx[3] = _mm256_and_si256(_mm256_srli_epi16(lb, 6), m03);
- qx[0] = _mm256_shuffle_epi8(values, _mm256_add_epi8(qx[0], shift));
- qx[1] = _mm256_shuffle_epi8(values, _mm256_add_epi8(qx[1], shift));
- qx[2] = _mm256_shuffle_epi8(values, _mm256_add_epi8(qx[2], shift));
- qx[3] = _mm256_shuffle_epi8(values, _mm256_add_epi8(qx[3], shift));
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto y = _mm256_loadu_si256((const __m256i*)q8.y[iy][ibl].qs+ib);
-#ifdef HAVE_FANCY_SIMD
- auto sumi = _mm256_setzero_si256();
- sumi = _mm256_dpbusd_epi32(sumi, qx[0], _mm256_shuffle_epi32(y, 0x00));
- sumi = _mm256_dpbusd_epi32(sumi, qx[1], _mm256_shuffle_epi32(y, 0x55));
- sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(y, 0xaa));
- sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(y, 0xff));
- isum[iy] = _mm256_add_epi32(isum[iy], _mm256_mullo_epi32(scales, sumi));
-#else
- auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[0], _mm256_shuffle_epi32(y, 0x00)),
- _mm256_maddubs_epi16(qx[1], _mm256_shuffle_epi32(y, 0x55)));
- auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[2], _mm256_shuffle_epi32(y, 0xaa)),
- _mm256_maddubs_epi16(qx[3], _mm256_shuffle_epi32(y, 0xff)));
- isum[iy] = _mm256_add_epi32(isum[iy], _mm256_add_epi32(_mm256_madd_epi16(scales, sumi1), _mm256_madd_epi16(scales, sumi2)));
-#endif
- }
- }
- for (int iy = 0; iy < nrc_y; ++iy) {
- acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(isum[iy]), acc[iy]);
- }
- }
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1));
- acc[iy] = _mm256_setzero_ps();
- info.store(ix+0, iy, sum);
- }
- }
-}
-
-template <int nrc_y>
-static void mul_mat_iq3_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
- GGML_ASSERT(nrc_x%4 == 0);
- Q8<nrc_y, block_q8_K> q8(info);
- auto m4 = _mm256_set1_epi8(0xf);
- auto ms = _mm256_set1_epi8(8);
- auto m03 = _mm256_set1_epi8(0x03);
- auto m04 = _mm256_set1_epi8(0x04);
- auto smask = _mm256_set_epi64x(0x0808080808080808, 0x0404040404040404, 0x0202020202020202, 0x0101010101010101);
- auto shift_shuffle = _mm256_set_epi64x(0x0707070706060606, 0x0505050504040404, 0x0303030302020202, 0x0101010100000000);
- auto values128 = _mm_loadu_si128((const __m128i *)iq3nl_values);
- auto values = MM256_SET_M128I(values128, values128);
- values = _mm256_add_epi8(values, _mm256_set1_epi8(64));
- static const uint8_t k_shuff[32] = {0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15, 0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15};
- auto shuff = _mm256_loadu_si256((const __m256i *)k_shuff);
-#ifndef HAVE_FANCY_SIMD
- auto s_shuffle = _mm256_set_epi64x(0x0f0e0f0e0d0c0d0c, 0x0b0a0b0a09080908, 0x0706070605040504, 0x0302030201000100);
-#endif
- int nbl = n / QK_K;
- __m256 acc[nrc_y] = {};
- __m256i qx[4];
- uint64_t stored_scales[8];
- for (int ix = 0; ix < nrc_x; ix += 4) {
- const block_iq3_k_r4 * iq3 = (const block_iq3_k_r4 *)((const char *)vx + (ix+0)*bx);
- for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256
- auto dl = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq3[ibl].d));
- auto d4 = _mm256_set_m128(dl, dl);
- auto extra = _mm256_set1_epi64x(*(const uint64_t *)iq3[ibl].extra);
- auto slbits = _mm256_loadu_si256((const __m256i *)iq3[ibl].scales_l);
- auto sl1 = _mm256_add_epi8(_mm256_slli_epi16(_mm256_and_si256(slbits, m4), 1), _mm256_set1_epi8(1));
- auto sl2 = _mm256_add_epi8(_mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(slbits, 4), m4), 1), _mm256_set1_epi8(1));
- auto sh = _mm256_set1_epi64x(((const uint64_t *)iq3[ibl].scales_h)[0]);
- auto sh1 = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(sh, smask), smask), _mm256_set1_epi8(1));
- auto sh2 = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(_mm256_srli_epi16(sh, 4), smask), smask), _mm256_set1_epi8(1));
- auto i8scales1 = _mm256_sign_epi8(sl1, sh1);
- auto i8scales2 = _mm256_sign_epi8(sl2, sh2);
- _mm256_storeu_si256((__m256i *)stored_scales+0, i8scales1);
- _mm256_storeu_si256((__m256i *)stored_scales+1, i8scales2);
- __m256i isum[nrc_y] = {};
- iq234_k_accum_mins(ibl, i8scales1, i8scales2, q8, shuff, isum, -64);
- for (int ib = 0; ib < QK_K/32; ++ib) {
-#ifdef HAVE_FANCY_SIMD
- auto scales = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(stored_scales + ib)));
-#else
- auto scales = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm_set1_epi64x(stored_scales[ib])), s_shuffle);
-#endif
- auto lb = _mm256_loadu_si256((const __m256i *)iq3[ibl].qs+ib);
- auto hbits = _mm_loadu_si128((const __m128i *)iq3[ibl].qh+ib);
- auto hb = MM256_SET_M128I(hbits, _mm_slli_epi16(hbits, 4));
- auto shift = _mm256_and_si256(ms, _mm256_slli_epi16(extra, 3)); extra = _mm256_srli_epi16(extra, 1);
- shift = _mm256_shuffle_epi8(shift, shift_shuffle);
- qx[0] = _mm256_or_si256(_mm256_and_si256(lb, m03), _mm256_and_si256(m04, _mm256_srli_epi16(hb, 2)));
- qx[1] = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(lb, 2), m03), _mm256_and_si256(m04, _mm256_srli_epi16(hb, 3)));
- qx[2] = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(lb, 4), m03), _mm256_and_si256(m04, _mm256_srli_epi16(hb, 4)));
- qx[3] = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(lb, 6), m03), _mm256_and_si256(m04, _mm256_srli_epi16(hb, 5)));
- qx[0] = _mm256_shuffle_epi8(values, _mm256_add_epi8(qx[0], shift));
- qx[1] = _mm256_shuffle_epi8(values, _mm256_add_epi8(qx[1], shift));
- qx[2] = _mm256_shuffle_epi8(values, _mm256_add_epi8(qx[2], shift));
- qx[3] = _mm256_shuffle_epi8(values, _mm256_add_epi8(qx[3], shift));
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto y = _mm256_loadu_si256((const __m256i*)q8.y[iy][ibl].qs+ib);
-#ifdef HAVE_FANCY_SIMD
- auto sumi = _mm256_setzero_si256();
- sumi = _mm256_dpbusd_epi32(sumi, qx[0], _mm256_shuffle_epi32(y, 0x00));
- sumi = _mm256_dpbusd_epi32(sumi, qx[1], _mm256_shuffle_epi32(y, 0x55));
- sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(y, 0xaa));
- sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(y, 0xff));
- isum[iy] = _mm256_add_epi32(isum[iy], _mm256_mullo_epi32(scales, sumi));
-#else
- auto sumi1 = _mm256_maddubs_epi16(qx[0], _mm256_shuffle_epi32(y, 0x00));
- auto sumi2 = _mm256_maddubs_epi16(qx[1], _mm256_shuffle_epi32(y, 0x55));
- auto sumi3 = _mm256_maddubs_epi16(qx[2], _mm256_shuffle_epi32(y, 0xaa));
- auto sumi4 = _mm256_maddubs_epi16(qx[3], _mm256_shuffle_epi32(y, 0xff));
- isum[iy] = _mm256_add_epi32(isum[iy], _mm256_add_epi32(_mm256_madd_epi16(scales, sumi1), _mm256_madd_epi16(scales, sumi2)));
- isum[iy] = _mm256_add_epi32(isum[iy], _mm256_add_epi32(_mm256_madd_epi16(scales, sumi3), _mm256_madd_epi16(scales, sumi4)));
-#endif
- }
- }
- for (int iy = 0; iy < nrc_y; ++iy) {
- acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(isum[iy]), acc[iy]);
- }
- }
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1));
- acc[iy] = _mm256_setzero_ps();
- info.store(ix+0, iy, sum);
- }
- }
-}
-
-template <int nrc_y>
-static void mul_mat_iq4_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
- GGML_ASSERT(nrc_x%4 == 0);
- Q8<nrc_y, block_q8_K> q8(info);
- auto m4 = _mm256_set1_epi8(0xf);
- auto m30 = _mm256_set1_epi8(0x30);
- auto m32 = _mm256_set1_epi8(32);
- auto ms = _mm256_set1_epi8(4);
- auto shift_shuffle = _mm256_set_epi64x(0x0707070706060606, 0x0505050504040404, 0x0303030302020202, 0x0101010100000000);
-#ifdef HAVE_FANCY_SIMD
- auto values = load_iq4nl_values_256();
- static const uint8_t k_shuff[32] = {0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15, 0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15};
- auto shuff = _mm256_loadu_si256((const __m256i *)k_shuff);
-#else
- auto s_shuffle = _mm256_set_epi64x(0x0f0e0f0e0d0c0d0c, 0x0b0a0b0a09080908, 0x0706070605040504, 0x0302030201000100);
- auto values128 = _mm_loadu_si128((const __m128i *)iq4k_values);
- auto values = MM256_SET_M128I(values128, values128);
-#endif
- int nbl = n / QK_K;
- __m256 acc[nrc_y] = {};
- __m256i qx[4];
- uint64_t stored_scales[8];
- for (int ix = 0; ix < nrc_x; ix += 4) {
- const block_iq4_k_r4 * iq4 = (const block_iq4_k_r4 *)((const char *)vx + (ix+0)*bx);
- for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256
- auto dl = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq4[ibl].d));
- auto d4 = _mm256_set_m128(dl, dl);
- auto extra = _mm256_set1_epi64x(*(const uint64_t *)iq4[ibl].extra);
- auto slbits = _mm256_loadu_si256((const __m256i *)iq4[ibl].scales_l);
- auto sl1 = _mm256_and_si256(slbits, m4);
- auto sl2 = _mm256_and_si256(_mm256_srli_epi16(slbits, 4), m4);
- auto shbits = _mm_loadu_si128((const __m128i*)iq4[ibl].scales_h);
- auto sh = MM256_SET_M128I(_mm_srli_epi16(shbits, 2), shbits);
- auto i8scales1 = _mm256_sub_epi8(_mm256_or_si256(sl1, _mm256_and_si256(m30, _mm256_slli_epi16(sh, 4))), m32);
- auto i8scales2 = _mm256_sub_epi8(_mm256_or_si256(sl2, _mm256_and_si256(m30, sh)), m32);
- _mm256_storeu_si256((__m256i *)stored_scales+0, i8scales1);
- _mm256_storeu_si256((__m256i *)stored_scales+1, i8scales2);
- __m256i isum[nrc_y] = {};
-#ifdef HAVE_FANCY_SIMD
- iq234_k_accum_mins(ibl, i8scales1, i8scales2, q8, shuff, isum, -128);
-#endif
- for (int ib = 0; ib < QK_K/32; ++ib) {
-#ifdef HAVE_FANCY_SIMD
- auto scales = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(stored_scales + ib)));
-#else
- auto scales = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm_set1_epi64x(stored_scales[ib])), s_shuffle);
-#endif
- auto bits1 = _mm256_loadu_si256((const __m256i *)iq4[ibl].qs+2*ib+0);
- auto bits2 = _mm256_loadu_si256((const __m256i *)iq4[ibl].qs+2*ib+1);
- auto shift = _mm256_and_si256(ms, _mm256_slli_epi16(extra, 2)); extra = _mm256_srli_epi16(extra, 1);
- shift = _mm256_shuffle_epi8(shift, shift_shuffle);
- qx[0] = _mm256_add_epi8(shift, _mm256_shuffle_epi8(values, _mm256_and_si256(bits1, m4)));
- qx[1] = _mm256_add_epi8(shift, _mm256_shuffle_epi8(values, _mm256_and_si256(bits2, m4)));
- qx[2] = _mm256_add_epi8(shift, _mm256_shuffle_epi8(values, _mm256_and_si256(_mm256_srli_epi16(bits1, 4), m4)));
- qx[3] = _mm256_add_epi8(shift, _mm256_shuffle_epi8(values, _mm256_and_si256(_mm256_srli_epi16(bits2, 4), m4)));
-#ifndef HAVE_FANCY_SIMD
- auto s1 = _mm256_sign_epi8(qx[0], qx[0]);
- auto s2 = _mm256_sign_epi8(qx[1], qx[1]);
- auto s3 = _mm256_sign_epi8(qx[2], qx[2]);
- auto s4 = _mm256_sign_epi8(qx[3], qx[3]);
-#endif
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto y = _mm256_loadu_si256((const __m256i*)q8.y[iy][ibl].qs+ib);
-#ifdef HAVE_FANCY_SIMD
- auto sumi = _mm256_setzero_si256();
- sumi = _mm256_dpbusd_epi32(sumi, qx[0], _mm256_shuffle_epi32(y, 0x00));
- sumi = _mm256_dpbusd_epi32(sumi, qx[1], _mm256_shuffle_epi32(y, 0x55));
- sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(y, 0xaa));
- sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(y, 0xff));
- isum[iy] = _mm256_add_epi32(isum[iy], _mm256_mullo_epi32(scales, sumi));
-#else
- auto sumi1 = _mm256_maddubs_epi16(s1, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), qx[0]));
- auto sumi2 = _mm256_maddubs_epi16(s2, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), qx[1]));
- auto sumi3 = _mm256_maddubs_epi16(s3, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), qx[2]));
- auto sumi4 = _mm256_maddubs_epi16(s4, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), qx[3]));
- isum[iy] = _mm256_add_epi32(isum[iy], _mm256_add_epi32(_mm256_madd_epi16(scales, sumi1), _mm256_madd_epi16(scales, sumi2)));
- isum[iy] = _mm256_add_epi32(isum[iy], _mm256_add_epi32(_mm256_madd_epi16(scales, sumi3), _mm256_madd_epi16(scales, sumi4)));
-#endif
- }
- }
- for (int iy = 0; iy < nrc_y; ++iy) {
- acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(isum[iy]), acc[iy]);
- }
- }
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1));
- acc[iy] = _mm256_setzero_ps();
- info.store(ix+0, iy, sum);
- }
- }
-}
-
-static inline __m256i prepare_5bit_quants(const __m256i * values, __m256i ql, __m256i qh, __m256i mask) {
- auto q5vl = _mm256_shuffle_epi8(values[0], ql);
- auto q5vh = _mm256_shuffle_epi8(values[1], ql);
-#ifdef HAVE_FANCY_SIMD
- return _mm256_mask_blend_epi8(_mm256_cmpeq_epi8_mask(_mm256_and_si256(qh, mask), mask), q5vl, q5vh);
-#else
- return _mm256_blendv_epi8(q5vl, q5vh, _mm256_cmpeq_epi8(_mm256_and_si256(qh, mask), mask));
-#endif
-}
-
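Conceptually, prepare_5bit_quants() above does a 32-entry table lookup in which the four low index bits come from ql and the fifth bit comes from qh: the index is looked up in both 16-entry table halves and the qh mask bit selects between them. A scalar sketch of the same lookup, with a hypothetical helper name and table argument (not part of the source):

    #include <cstdint>

    // Scalar sketch of the blend performed by prepare_5bit_quants(): look the
    // 4-bit index up in both halves of a 32-entry value table and let the
    // high bit (taken from qh) pick the half.
    static inline uint8_t lookup_5bit_value(const uint8_t table[32], uint8_t ql4, bool qh_bit) {
        uint8_t lo = table[ql4 & 0xf];         // corresponds to values[0]
        uint8_t hi = table[16 + (ql4 & 0xf)];  // corresponds to values[1]
        return qh_bit ? hi : lo;               // the blend on the qh mask bit
    }
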
-template <int nrc_y>
-static void mul_mat_iq5_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
- GGML_ASSERT(nrc_x%4 == 0);
- Q8<nrc_y, block_q8_K> q8(info);
- auto m4 = _mm256_set1_epi8(0xf);
- auto m30 = _mm256_set1_epi8(0x30);
- auto m32 = _mm256_set1_epi8(32);
- auto ms = _mm256_set1_epi8(2);
- auto shift_shuffle = _mm256_set_epi64x(0x0707070706060606, 0x0505050504040404, 0x0303030302020202, 0x0101010100000000);
- __m256i values[2];
- {
- auto val1 = _mm_loadu_si128((const __m128i *)iq5nl_values+0);
- auto val2 = _mm_loadu_si128((const __m128i *)iq5nl_values+1);
- values[0] = MM256_SET_M128I(val1, val1);
- values[1] = MM256_SET_M128I(val2, val2);
-#ifdef HAVE_FANCY_SIMD
- values[0] = _mm256_sub_epi8(values[0], _mm256_set1_epi8(-128));
- values[1] = _mm256_sub_epi8(values[1], _mm256_set1_epi8(-128));
-#endif
- }
-#ifdef HAVE_FANCY_SIMD
- static const uint8_t k_shuff[32] = {0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15, 0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15};
- auto shuff = _mm256_loadu_si256((const __m256i *)k_shuff);
-#else
- auto s_shuffle = _mm256_set_epi64x(0x0f0e0f0e0d0c0d0c, 0x0b0a0b0a09080908, 0x0706070605040504, 0x0302030201000100);
-#endif
- int nbl = n / QK_K;
- __m256 acc[nrc_y] = {};
- __m256i qx[4];
- uint64_t stored_scales[8];
- for (int ix = 0; ix < nrc_x; ix += 4) {
- const block_iq5_k_r4 * iq5 = (const block_iq5_k_r4 *)((const char *)vx + (ix+0)*bx);
- for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256
- auto dl = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq5[ibl].d));
- auto d4 = _mm256_set_m128(dl, dl);
- auto extra = _mm256_set1_epi64x(*(const uint64_t *)iq5[ibl].extra);
- auto slbits = _mm256_loadu_si256((const __m256i *)iq5[ibl].scales_l);
- auto sl1 = _mm256_and_si256(slbits, m4);
- auto sl2 = _mm256_and_si256(_mm256_srli_epi16(slbits, 4), m4);
- auto shbits = _mm_loadu_si128((const __m128i*)iq5[ibl].scales_h);
- auto sh = MM256_SET_M128I(_mm_srli_epi16(shbits, 2), shbits);
- auto i8scales1 = _mm256_sub_epi8(_mm256_or_si256(sl1, _mm256_and_si256(m30, _mm256_slli_epi16(sh, 4))), m32);
- auto i8scales2 = _mm256_sub_epi8(_mm256_or_si256(sl2, _mm256_and_si256(m30, sh)), m32);
- _mm256_storeu_si256((__m256i *)stored_scales+0, i8scales1);
- _mm256_storeu_si256((__m256i *)stored_scales+1, i8scales2);
- __m256i isum[nrc_y] = {};
-#ifdef HAVE_FANCY_SIMD
- if constexpr (nrc_y == 1) {
- iq234_k_accum_mins(ibl, i8scales1, i8scales2, q8, shuff, isum, -128);
- } else {
- iq2345_k_accum_mins(ibl, i8scales1, i8scales2, q8, shuff, extra, isum, -128, 2);
- }
-#endif
- for (int ib = 0; ib < QK_K/32; ++ib) {
-#ifdef HAVE_FANCY_SIMD
- auto scales = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(stored_scales + ib)));
-#else
- auto scales = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm_set1_epi64x(stored_scales[ib])), s_shuffle);
-#endif
- auto lbits1 = _mm256_loadu_si256((const __m256i *)iq5[ibl].qs+2*ib+0);
- auto lbits2 = _mm256_loadu_si256((const __m256i *)iq5[ibl].qs+2*ib+1);
- auto hbits = _mm_loadu_si128((const __m128i *)iq5[ibl].qh+ib);
- auto hb = MM256_SET_M128I(_mm_srli_epi16(hbits, 2), hbits);
- qx[0] = _mm256_and_si256(lbits1, m4);
- qx[1] = _mm256_and_si256(lbits2, m4);
- qx[2] = _mm256_and_si256(_mm256_srli_epi16(lbits1, 4), m4);
- qx[3] = _mm256_and_si256(_mm256_srli_epi16(lbits2, 4), m4);
-
- qx[0] = prepare_5bit_quants(values, qx[0], hb, _mm256_set1_epi8(0x01));
- qx[1] = prepare_5bit_quants(values, qx[1], hb, _mm256_set1_epi8(0x10));
- qx[2] = prepare_5bit_quants(values, qx[2], hb, _mm256_set1_epi8(0x02));
- qx[3] = prepare_5bit_quants(values, qx[3], hb, _mm256_set1_epi8(0x20));
-#ifdef HAVE_FANCY_SIMD
- if constexpr (nrc_y == 1) {
- auto shift = _mm256_and_si256(ms, _mm256_slli_epi16(extra, 1)); extra = _mm256_srli_epi16(extra, 1);
- shift = _mm256_shuffle_epi8(shift, shift_shuffle);
- qx[0] = _mm256_add_epi8(qx[0], shift);
- qx[1] = _mm256_add_epi8(qx[1], shift);
- qx[2] = _mm256_add_epi8(qx[2], shift);
- qx[3] = _mm256_add_epi8(qx[3], shift);
- }
-#else
- auto shift = _mm256_and_si256(ms, _mm256_slli_epi16(extra, 1)); extra = _mm256_srli_epi16(extra, 1);
- shift = _mm256_shuffle_epi8(shift, shift_shuffle);
- qx[0] = _mm256_add_epi8(qx[0], shift);
- qx[1] = _mm256_add_epi8(qx[1], shift);
- qx[2] = _mm256_add_epi8(qx[2], shift);
- qx[3] = _mm256_add_epi8(qx[3], shift);
- auto s1 = _mm256_sign_epi8(qx[0], qx[0]);
- auto s2 = _mm256_sign_epi8(qx[1], qx[1]);
- auto s3 = _mm256_sign_epi8(qx[2], qx[2]);
- auto s4 = _mm256_sign_epi8(qx[3], qx[3]);
-#endif
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto y = _mm256_loadu_si256((const __m256i*)q8.y[iy][ibl].qs+ib);
-#ifdef HAVE_FANCY_SIMD
- auto sumi = _mm256_setzero_si256();
- sumi = _mm256_dpbusd_epi32(sumi, qx[0], _mm256_shuffle_epi32(y, 0x00));
- sumi = _mm256_dpbusd_epi32(sumi, qx[1], _mm256_shuffle_epi32(y, 0x55));
- sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(y, 0xaa));
- sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(y, 0xff));
- isum[iy] = _mm256_add_epi32(isum[iy], _mm256_mullo_epi32(scales, sumi));
-#else
- auto sumi1 = _mm256_maddubs_epi16(s1, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), qx[0]));
- auto sumi2 = _mm256_maddubs_epi16(s2, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), qx[1]));
- auto sumi3 = _mm256_maddubs_epi16(s3, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), qx[2]));
- auto sumi4 = _mm256_maddubs_epi16(s4, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), qx[3]));
- isum[iy] = _mm256_add_epi32(isum[iy], _mm256_add_epi32(_mm256_madd_epi16(scales, sumi1), _mm256_madd_epi16(scales, sumi2)));
- isum[iy] = _mm256_add_epi32(isum[iy], _mm256_add_epi32(_mm256_madd_epi16(scales, sumi3), _mm256_madd_epi16(scales, sumi4)));
-#endif
- }
- }
- for (int iy = 0; iy < nrc_y; ++iy) {
- acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(isum[iy]), acc[iy]);
- }
- }
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1));
- acc[iy] = _mm256_setzero_ps();
- info.store(ix+0, iy, sum);
- }
- }
-}
-
-template <int nrc_y>
-static void mul_mat_iq5_ks_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
- GGML_ASSERT(nrc_x%4 == 0);
- Q8<nrc_y, block_q8_K> q8(info);
- auto m4 = _mm256_set1_epi8(0xf);
- __m256i values[2];
- {
- auto val1 = _mm_loadu_si128((const __m128i *)iq5nl_values+0);
- auto val2 = _mm_loadu_si128((const __m128i *)iq5nl_values+1);
- values[0] = MM256_SET_M128I(val1, val1);
- values[1] = MM256_SET_M128I(val2, val2);
-#ifdef HAVE_FANCY_SIMD
- values[0] = _mm256_sub_epi8(values[0], _mm256_set1_epi8(-128));
- values[1] = _mm256_sub_epi8(values[1], _mm256_set1_epi8(-128));
-#endif
- }
- int nbl = n / QK_K;
- using helper_t = union { __m256i vec; uint32_t val[8]; };
-#ifndef HAVE_FANCY_SIMD
- helper_t h, h_shift;
- auto s_shuffle = _mm256_set_epi64x(0x0f0e0f0e0d0c0d0c, 0x0b0a0b0a09080908, 0x0706070605040504, 0x0302030201000100);
-#else
- using helper512_t = union { __m512i vec; uint64_t val[8]; };
- helper_t h;
- helper512_t h_shift;
-#endif
- __m256 acc[nrc_y] = {};
- __m256i isum[nrc_y] = {};
- __m256i qx[4];
- for (int ix = 0; ix < nrc_x; ix += 4) {
- auto dptr = (const float *)((const char *)vx + (ix+0)*bx);
- const block_iq5_ks_r4 * iq5 = (const block_iq5_ks_r4 *)(dptr + 4);
- auto d4 = _mm_loadu_ps(dptr);
- for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256
- auto scales = _mm256_loadu_si256((const __m256i *)iq5[ibl].scales);
- h.vec = _mm256_sub_epi8(_mm256_and_si256(scales, _mm256_set1_epi8(-2)), _mm256_set1_epi8(127));
-#ifndef HAVE_FANCY_SIMD
- h_shift.vec = _mm256_slli_epi16(_mm256_and_si256(scales, _mm256_set1_epi8(1)), 1);
- {
- __m256 v1 = _mm256_mul_ps(_mm256_cvtepi32_ps(MM256_SET_M128I(_mm_cvtepi8_epi32(_mm_set1_epi32(h.val[4])), _mm_cvtepi8_epi32(_mm_set1_epi32(h.val[0])))),
- _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_cvtepi8_epi32(_mm_set1_epi32(h_shift.val[4])), _mm_cvtepi8_epi32(_mm_set1_epi32(h_shift.val[0])))));
- __m256 v2 = _mm256_mul_ps(_mm256_cvtepi32_ps(MM256_SET_M128I(_mm_cvtepi8_epi32(_mm_set1_epi32(h.val[5])), _mm_cvtepi8_epi32(_mm_set1_epi32(h.val[1])))),
- _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_cvtepi8_epi32(_mm_set1_epi32(h_shift.val[5])), _mm_cvtepi8_epi32(_mm_set1_epi32(h_shift.val[1])))));
- __m256 v3 = _mm256_mul_ps(_mm256_cvtepi32_ps(MM256_SET_M128I(_mm_cvtepi8_epi32(_mm_set1_epi32(h.val[6])), _mm_cvtepi8_epi32(_mm_set1_epi32(h.val[2])))),
- _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_cvtepi8_epi32(_mm_set1_epi32(h_shift.val[6])), _mm_cvtepi8_epi32(_mm_set1_epi32(h_shift.val[2])))));
- __m256 v4 = _mm256_mul_ps(_mm256_cvtepi32_ps(MM256_SET_M128I(_mm_cvtepi8_epi32(_mm_set1_epi32(h.val[7])), _mm_cvtepi8_epi32(_mm_set1_epi32(h.val[3])))),
- _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_cvtepi8_epi32(_mm_set1_epi32(h_shift.val[7])), _mm_cvtepi8_epi32(_mm_set1_epi32(h_shift.val[3])))));
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto m8 = _mm256_loadu_ps((const float *)q8.y[iy][ibl].bsums);
- acc[iy] = _mm256_fmadd_ps(v1, _mm256_shuffle_ps(m8, m8, 0x00), acc[iy]);
- acc[iy] = _mm256_fmadd_ps(v2, _mm256_shuffle_ps(m8, m8, 0x55), acc[iy]);
- acc[iy] = _mm256_fmadd_ps(v3, _mm256_shuffle_ps(m8, m8, 0xaa), acc[iy]);
- acc[iy] = _mm256_fmadd_ps(v4, _mm256_shuffle_ps(m8, m8, 0xff), acc[iy]);
- }
- }
-#else
- auto shift = _mm256_add_epi8(_mm256_set1_epi8(-64), _mm256_and_si256(scales, _mm256_set1_epi8(1)));
- h_shift.vec = _mm512_mullo_epi16(_mm512_cvtepi8_epi16(shift), _mm512_cvtepi8_epi16(h.vec));
-#endif
- for (int ib = 0; ib < QK_K/32; ++ib) {
-#ifdef HAVE_FANCY_SIMD
- auto iscales = _mm256_cvtepi8_epi32(_mm_set1_epi32(h.val[ib]));
- auto ishifts = _mm256_cvtepi16_epi32(_mm_set1_epi64x(h_shift.val[ib]));
- auto scales_m = _mm256_cvtepi32_ps(ishifts);
- for (int iy = 0; iy < nrc_y; ++iy) {
- float m8 = ((const float *)q8.y[iy][ibl].bsums)[ib];
- acc[iy] = _mm256_fmadd_ps(scales_m, _mm256_set1_ps(m8), acc[iy]);
- }
-#endif
- auto lbits1 = _mm256_loadu_si256((const __m256i *)iq5[ibl].qs+2*ib+0);
- auto lbits2 = _mm256_loadu_si256((const __m256i *)iq5[ibl].qs+2*ib+1);
- auto hbits = _mm_loadu_si128((const __m128i *)iq5[ibl].qh+ib);
- auto hb = MM256_SET_M128I(_mm_srli_epi16(hbits, 2), hbits);
- qx[0] = _mm256_and_si256(lbits1, m4);
- qx[1] = _mm256_and_si256(lbits2, m4);
- qx[2] = _mm256_and_si256(_mm256_srli_epi16(lbits1, 4), m4);
- qx[3] = _mm256_and_si256(_mm256_srli_epi16(lbits2, 4), m4);
-
- qx[0] = prepare_5bit_quants(values, qx[0], hb, _mm256_set1_epi8(0x01));
- qx[1] = prepare_5bit_quants(values, qx[1], hb, _mm256_set1_epi8(0x10));
- qx[2] = prepare_5bit_quants(values, qx[2], hb, _mm256_set1_epi8(0x02));
- qx[3] = prepare_5bit_quants(values, qx[3], hb, _mm256_set1_epi8(0x20));
-
-#ifndef HAVE_FANCY_SIMD
- auto iscales = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm_set1_epi32(h.val[ib])), s_shuffle);
- auto s1 = _mm256_sign_epi8(qx[0], qx[0]);
- auto s2 = _mm256_sign_epi8(qx[1], qx[1]);
- auto s3 = _mm256_sign_epi8(qx[2], qx[2]);
- auto s4 = _mm256_sign_epi8(qx[3], qx[3]);
-#endif
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto y = _mm256_loadu_si256((const __m256i*)q8.y[iy][ibl].qs+ib);
-#ifdef HAVE_FANCY_SIMD
- auto sumi = _mm256_setzero_si256();
- sumi = _mm256_dpbusd_epi32(sumi, qx[0], _mm256_shuffle_epi32(y, 0x00));
- sumi = _mm256_dpbusd_epi32(sumi, qx[1], _mm256_shuffle_epi32(y, 0x55));
- sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(y, 0xaa));
- sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(y, 0xff));
- isum[iy] = _mm256_add_epi32(isum[iy], _mm256_mullo_epi32(iscales, sumi));
-#else
- auto sumi1 = _mm256_maddubs_epi16(s1, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), qx[0]));
- auto sumi2 = _mm256_maddubs_epi16(s2, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), qx[1]));
- auto sumi3 = _mm256_maddubs_epi16(s3, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), qx[2]));
- auto sumi4 = _mm256_maddubs_epi16(s4, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), qx[3]));
- isum[iy] = _mm256_add_epi32(isum[iy], _mm256_add_epi32(_mm256_madd_epi16(iscales, sumi1), _mm256_madd_epi16(iscales, sumi2)));
- isum[iy] = _mm256_add_epi32(isum[iy], _mm256_add_epi32(_mm256_madd_epi16(iscales, sumi3), _mm256_madd_epi16(iscales, sumi4)));
-#endif
- }
- }
- for (int iy = 0; iy < nrc_y; ++iy) {
- acc[iy] = _mm256_fmadd_ps(_mm256_set1_ps(q8.scale(iy, ibl)), _mm256_cvtepi32_ps(isum[iy]), acc[iy]);
- isum[iy] = _mm256_setzero_si256();
- }
- }
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1));
- acc[iy] = _mm256_setzero_ps();
- info.store(ix+0, iy, _mm_mul_ps(d4, sum));
- }
- }
-}
-
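The _r4 kernels above share the same core accumulation: for each 32-element sub-block an integer dot product is weighted by an int8 sub-block scale, and per 256-element super-block the integer sum is converted to float and multiplied by the weight scale and the q8 activation scale (the ks variant applies its per-row float scale at the very end, and the min/shift corrections are folded in separately). A scalar sketch of that two-level scaling, with hypothetical names rather than the source API:

    #include <cstdint>

    // acc += d_weights * d_activations * sum_k( sub_scale[k] * sub_dot[k] ),
    // where sub_dot[k] is the int8 dot product of the k-th 32-element sub-block.
    static inline float accumulate_superblock(float d_weights, float d_activations,
                                              const int8_t  * sub_scales,
                                              const int32_t * sub_dots, int n_sub) {
        int32_t isum = 0;
        for (int k = 0; k < n_sub; ++k) isum += int32_t(sub_scales[k]) * sub_dots[k];
        return d_weights * d_activations * float(isum);
    }
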
-template <typename Bits>
-inline void multiply_add_1(int j, const Bits& bits, const __m256i * scales, const __m256i * q8, __m256i * sumi) {
- if (j == 0) {
-#ifdef HAVE_FANCY_SIMD
- auto p1 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), bits.values[0], q8[0]);
- auto p2 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), bits.values[1], q8[1]);
- auto p3 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), bits.values[2], q8[2]);
- auto p4 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), bits.values[3], q8[3]);
- sumi[0] = _mm256_dpwssd_epi32(_mm256_setzero_si256(), scales[0], _mm256_packs_epi32(p1, p2));
- sumi[1] = _mm256_dpwssd_epi32(_mm256_setzero_si256(), scales[1], _mm256_packs_epi32(p3, p4));
-#else
- const __m256i p1 = _mm256_madd_epi16(scales[0], _mm256_maddubs_epi16(bits.values[0], q8[0]));
- const __m256i p2 = _mm256_madd_epi16(scales[1], _mm256_maddubs_epi16(bits.values[1], q8[1]));
- const __m256i p3 = _mm256_madd_epi16(scales[2], _mm256_maddubs_epi16(bits.values[2], q8[2]));
- const __m256i p4 = _mm256_madd_epi16(scales[3], _mm256_maddubs_epi16(bits.values[3], q8[3]));
- sumi[0] = _mm256_add_epi32(p1, p3);
- sumi[1] = _mm256_add_epi32(p2, p4);
-#endif
- } else {
-#ifdef HAVE_FANCY_SIMD
- auto p1 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), bits.values[0], q8[0]);
- auto p2 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), bits.values[1], q8[1]);
- auto p3 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), bits.values[2], q8[2]);
- auto p4 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), bits.values[3], q8[3]);
- sumi[0] = _mm256_dpwssd_epi32(sumi[0], scales[0], _mm256_packs_epi32(p1, p2));
- sumi[1] = _mm256_dpwssd_epi32(sumi[1], scales[1], _mm256_packs_epi32(p3, p4));
-#else
- const __m256i p1 = _mm256_madd_epi16(scales[0], _mm256_maddubs_epi16(bits.values[0], q8[0]));
- const __m256i p2 = _mm256_madd_epi16(scales[1], _mm256_maddubs_epi16(bits.values[1], q8[1]));
- const __m256i p3 = _mm256_madd_epi16(scales[2], _mm256_maddubs_epi16(bits.values[2], q8[2]));
- const __m256i p4 = _mm256_madd_epi16(scales[3], _mm256_maddubs_epi16(bits.values[3], q8[3]));
- sumi[0] = _mm256_add_epi32(sumi[0], _mm256_add_epi32(p1, p3));
- sumi[1] = _mm256_add_epi32(sumi[1], _mm256_add_epi32(p2, p4));
-#endif
- }
-}
-
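multiply_add_1() has two code paths: with AVX512-VNNI (HAVE_FANCY_SIMD) the unsigned-by-signed byte products go straight into 32-bit lanes via _mm256_dpbusd_epi32; without it, _mm256_maddubs_epi16 produces 16-bit pair sums that are then weighted by the scales with _mm256_madd_epi16. A scalar model of a single 32-bit lane of the VNNI path, offered as a sketch rather than the instruction's formal definition:

    #include <cstdint>

    // One 32-bit lane of _mm256_dpbusd_epi32: multiply 4 unsigned bytes of x
    // with the 4 corresponding signed bytes of y and add the products to the
    // 32-bit accumulator.
    static inline int32_t dpbusd_lane(int32_t acc, const uint8_t x[4], const int8_t y[4]) {
        for (int k = 0; k < 4; ++k) acc += int32_t(x[k]) * int32_t(y[k]);
        return acc;
    }
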
-// TODO: find the bug that causes this to be called without HAVE_FANCY_SIMD, which triggers
-// writing 4 values into scales, which is of size 2.
-inline void set_scales_8_iq(int j, const __m256i& all_scales, __m256i * scales) {
-//#ifdef HAVE_FANCY_SIMD
- auto shuffle = j == 0 ? _mm256_set_epi64x(0x0302030203020302, 0x0100010001000100, 0x0302030203020302, 0x0100010001000100)
- : _mm256_set_epi64x(0x0b0a0b0a0b0a0b0a, 0x0908090809080908, 0x0b0a0b0a0b0a0b0a, 0x0908090809080908);
- scales[0] = _mm256_shuffle_epi8(all_scales, shuffle);
- scales[1] = _mm256_shuffle_epi8(all_scales, _mm256_add_epi8(shuffle, _mm256_set1_epi8(4)));
-//#else
-// set_scales_8(all_scales, j, scales);
-//#endif
-}
-
-inline void set_scales_16_iq(const __m256i& all_scales, __m256i * scales) {
-#ifdef HAVE_FANCY_SIMD
- auto shuffle = _mm256_set_epi64x(0x0706070607060706, 0x0302030203020302, 0x0504050405040504, 0x0100010001000100);
- scales[0] = _mm256_shuffle_epi8(all_scales, shuffle);
- scales[1] = _mm256_shuffle_epi8(all_scales, _mm256_add_epi8(shuffle, _mm256_set1_epi8(8)));
-#else
- set_scales_16(all_scales, scales);
-#endif
-}
-
-template <typename Dequantizer>
-static void mul_mat_qX_K_q8_K_IQ_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
- const int nb = n / QK_K;
- Q8<1> q8(info);
- Dequantizer deq(vx, bx);
- __m256i scales[2];
- __m256i q8_quants[4];
- for (int ix = 0; ix < nrc_x; ++ix) {
-
- __m256 accd = _mm256_setzero_ps();
- deq.new_row(ix);
-
- for (int i = 0; i < nb; ++i) {
-
- __m256i sumi[2], all_scales[Dequantizer::num_blocks/8];
- deq.new_block(i, all_scales);
-
- for (int j = 0; j < QK_K/128; ++j) {
- deq.prepare(i, j, q8, q8_quants);
- if constexpr (Dequantizer::num_blocks == 8) {
- set_scales_8_iq(j, all_scales[0], scales);
- } else {
- set_scales_16_iq(all_scales[j], scales);
- }
- multiply_add_1(j, deq.bits, scales, q8_quants, sumi);
- }
- accd = _mm256_fmadd_ps(_mm256_set1_ps(deq.d*q8.scale(0, i)), _mm256_cvtepi32_ps(_mm256_add_epi32(sumi[0], sumi[1])), accd);
- }
-
- info.store(ix, 0, hsum_float_8(accd));
- }
-}
-
-// So, if I uncomment this function and the call to it in mul_mat_qX_K_q8_K_IQ_N() below,
-// PP performance improves by ~2-3% (when we have __AVX512VNNI__ and __AVX512VL__).
-// But TG performance for iq3_xs drops by 35%. Seriously? I mean, c'mon,
-// what does the compilation of mul_mat_qX_K_q8_K_IQ_1 (which gets invoked during TG)
-// have to do with the compilation of mul_mat_qX_K_q8_K_IQ_N (invoked during PP)?
-//template <typename Q8, typename Bits>
-//inline void multiply_add_iq(const Bits& bits, const __m256i * scales, int j, int i, const Q8& q8, __m256i * sumi) {
-//#if defined(__AVX512VNNI__) && defined(__AVX512VL__)
-// for (int iy = 0; iy < Q8::nrc_y; ++iy) {
-// sumi[iy] = _mm256_dpwssd_epi32(sumi[iy], scales[0], _mm256_maddubs_epi16(bits.values[0], q8.load_quants(iy, i, 4*j+0)));
-// sumi[iy] = _mm256_dpwssd_epi32(sumi[iy], scales[1], _mm256_maddubs_epi16(bits.values[1], q8.load_quants(iy, i, 4*j+1)));
-// sumi[iy] = _mm256_dpwssd_epi32(sumi[iy], scales[2], _mm256_maddubs_epi16(bits.values[2], q8.load_quants(iy, i, 4*j+2)));
-// sumi[iy] = _mm256_dpwssd_epi32(sumi[iy], scales[3], _mm256_maddubs_epi16(bits.values[3], q8.load_quants(iy, i, 4*j+3)));
-// }
-//#else
-// for (int iy = 0; iy < Q8::nrc_y; ++iy) {
-// const __m256i p1 = _mm256_madd_epi16(scales[0], _mm256_maddubs_epi16(bits.values[0], q8.load_quants(iy, i, 4*j+0)));
-// const __m256i p2 = _mm256_madd_epi16(scales[1], _mm256_maddubs_epi16(bits.values[1], q8.load_quants(iy, i, 4*j+1)));
-// const __m256i p3 = _mm256_madd_epi16(scales[2], _mm256_maddubs_epi16(bits.values[2], q8.load_quants(iy, i, 4*j+2)));
-// const __m256i p4 = _mm256_madd_epi16(scales[3], _mm256_maddubs_epi16(bits.values[3], q8.load_quants(iy, i, 4*j+3)));
-// sumi[iy] = _mm256_add_epi32(sumi[iy], _mm256_add_epi32(p1, p3));
-// sumi[iy] = _mm256_add_epi32(sumi[iy], _mm256_add_epi32(p2, p4));
-// }
-//#endif
-//}
-
-template <typename Dequantizer, int nrc_y>
-static void mul_mat_qX_K_q8_K_IQ_N(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
- const int nb = n / QK_K;
- Q8<nrc_y> q8(info);
- Dequantizer deq(vx, bx);
- __m256i scales[4];
- __m256 accd[nrc_y];
-
- for (int ix = 0; ix < nrc_x; ++ix) {
-
- for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm256_setzero_ps();
-
- deq.new_row(ix);
-
- for (int i = 0; i < nb; ++i) {
-
- __m256i sumi[nrc_y], all_scales[Dequantizer::num_blocks/8];
- //for (int iy = 0; iy < nrc_y; ++iy) sumi[iy] = _mm256_setzero_si256();
- __m256i mins;
- float dmin = deq.new_block(i, all_scales, mins);
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto bsums = q8.load_bsums(iy, i);
- auto prod = _mm256_madd_epi16(mins, bsums);
- accd[iy] = _mm256_fmadd_ps(_mm256_set1_ps(dmin*q8.scale(iy, i)), _mm256_cvtepi32_ps(prod), accd[iy]);
- }
-
- for (int j = 0; j < QK_K/128; ++j) {
- deq.prepare(i, j);
- if constexpr (Dequantizer::num_blocks == 8) {
- set_scales_8(all_scales[0], j, scales);
- } else {
- set_scales_16(all_scales[j], scales);
- }
- //multiply_add_iq(deq.bits, scales, j, i, q8, sumi);
- multiply_add(deq.bits, scales, j, i, q8, sumi);
- }
- for (int iy = 0; iy < nrc_y; ++iy) {
- const __m256 vd = _mm256_set1_ps(deq.d*q8.scale(iy, i));
- accd[iy] = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(sumi[iy]), accd[iy]);
- }
- }
-
- for (int iy = 0; iy < nrc_y; ++iy) {
- info.store(ix, iy, hsum_float_8(accd[iy]));
- }
- }
-}
-
-template <int nrc> struct Q8_K64 {
-
- constexpr static int nrc_y = nrc;
-
- Q8_K64(const DataInfo& info) {
- for (int iy = 0; iy < nrc_y; ++iy) {
- const float * dptr = (const float *)info.src1_row(iy);
- std::memcpy(d + 8*iy, dptr, 8*sizeof(float));
- y[iy] = (const int8_t *)(dptr + 8);
- }
- }
-
- inline __m256i load_quants(int iy, int i, int j) const { return _mm256_loadu_si256((const __m256i*)y[iy] + 4*i + j); }
- inline __m128 scale(int iy) const { return _mm_loadu_ps(d + 8*iy); }
- inline __m128 minus(int iy) const { return _mm_loadu_ps(d + 8*iy + 4); }
-
- float d[8*nrc_y];
- const int8_t * y[nrc_y];
-};
-
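Q8_K64 reads the activation rows consumed by the iq1_bn/iq2_bn paths: each row starts with 8 floats (4 per-lane scales returned by scale() and 4 correction terms returned by minus() and subtracted in the final reduction), followed immediately by the int8 quants. A minimal layout sketch assuming exactly that packing; the view type and helper name are hypothetical, not part of the source:

    #include <cstdint>

    // Hypothetical row view mirroring what the Q8_K64 constructor does:
    // 8 floats (d[0..3] used by scale(), d[4..7] used by minus()),
    // then the int8 quantized activations.
    struct Q8K64RowView {
        const float  * d;   // 8 floats at the start of the row
        const int8_t * q;   // quants right after the floats
    };

    static inline Q8K64RowView view_q8_k64_row(const void * src1_row) {
        const float * dptr = (const float *)src1_row;
        return { dptr, (const int8_t *)(dptr + 8) };
    }
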
-struct DequantizerIQ1BN {
- const __m256i m1_8 = _mm256_set1_epi8(1);
- static __m256i load_shuffle(int i) {
- static const uint8_t data[128] = {
- 0, 255, 0, 255, 0, 255, 0, 255, 0, 255, 1, 255, 1, 255, 1, 255, 1, 255, 1, 255, 2, 255, 2, 255, 2, 255, 2, 255, 2, 255, 12, 255,
- 3, 255, 3, 255, 3, 255, 3, 255, 3, 255, 4, 255, 4, 255, 4, 255, 4, 255, 4, 255, 5, 255, 5, 255, 5, 255, 5, 255, 5, 255, 12, 255,
- 6, 255, 6, 255, 6, 255, 6, 255, 6, 255, 7, 255, 7, 255, 7, 255, 7, 255, 7, 255, 8, 255, 8, 255, 8, 255, 8, 255, 8, 255, 12, 255,
- 9, 255, 9, 255, 9, 255, 9, 255, 9, 255, 10, 255, 10, 255, 10, 255, 10, 255, 10, 255, 11, 255, 11, 255, 11, 255, 11, 255, 11, 255, 12, 255,
- };
- return _mm256_loadu_si256((const __m256i*)data + i);
- }
- const __m256i shuff[4] = { load_shuffle(0), load_shuffle(1), load_shuffle(2), load_shuffle(3) };
- const __m256i mult[4] = {
- _mm256_set_epi64x(0x5100010003000900, 0x1b00510001000300, 0x09001b0051000100, 0x030009001b005100),
- _mm256_set_epi64x(0x1b00010003000900, 0x1b00510001000300, 0x09001b0051000100, 0x030009001b005100),
- _mm256_set_epi64x(0x0900010003000900, 0x1b00510001000300, 0x09001b0051000100, 0x030009001b005100),
- _mm256_set_epi64x(0x0300010003000900, 0x1b00510001000300, 0x09001b0051000100, 0x030009001b005100),
- };
- const __m256i m3 = _mm256_set1_epi16(3);
-#if defined HAVE_FANCY_SIMD && defined __AVX512VBMI__
- const __m256i bmask = _mm256_set_epi8(62, 60, 58, 56, 54, 52, 50, 48, 46, 44, 42, 40, 38, 36, 34, 32, 30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0);
-#endif
-
- IQK_ALWAYS_INLINE void prepare_iq1bn_quants(const block_iq1_bn * x, __m256i& v1, __m256i& v2) const {
- auto data128 = _mm_loadu_si128((const __m128i *)x); // Note: we load 16 instead of 13 bytes!
- auto data = MM256_SET_M128I(data128, data128);
- auto val1 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_shuffle_epi8(data, shuff[0]), mult[0]), m3);
- auto val2 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_shuffle_epi8(data, shuff[1]), mult[1]), m3);
- auto val3 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_shuffle_epi8(data, shuff[2]), mult[2]), m3);
- auto val4 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_shuffle_epi8(data, shuff[3]), mult[3]), m3);
-#if defined HAVE_FANCY_SIMD && defined __AVX512VBMI__
- v1 = _mm256_permutex2var_epi8(val1, bmask, val2);
- v2 = _mm256_permutex2var_epi8(val3, bmask, val4);
-#else
- v1 = _mm256_permute4x64_epi64(_mm256_packs_epi16(val1, val2), 216);
- v2 = _mm256_permute4x64_epi64(_mm256_packs_epi16(val3, val4), 216);
-#endif
- }
-
-};
-
-template <int nrc_y>
-IQK_NOINLINE void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
- const int nb = n / QK_IQ1BN;
- Q8_K64<nrc_y> q8(info);
- DequantizerIQ1BN deq;
- __m256i accd[nrc_y];
- __m256i val[4];
-
-#ifndef HAVE_FANCY_SIMD
- const auto m1_16 = _mm256_set1_epi16(1);
-#endif
-
- const block_iq1_bn * x;
- const char * cx0 = (const char *)vx;
- float scale;
- ggml_half d16;
-
- for (int ix = 0; ix < nrc_x; ++ix) {
-
- const char * cx = cx0 + ix*bx;
- std::memcpy(&d16, cx, sizeof(d16));
- scale = GGML_FP16_TO_FP32(d16);
- cx += sizeof(d16);
- x = (const block_iq1_bn *)cx;
-
- if constexpr (nrc_y == 1) {
- __m256i acc1 = _mm256_setzero_si256(), acc2 = _mm256_setzero_si256();
- for (int i = 0; i < nb/2; ++i) {
- deq.prepare_iq1bn_quants(x + 2*i + 0, val[0], val[1]);
- deq.prepare_iq1bn_quants(x + 2*i + 1, val[2], val[3]);
-#ifdef HAVE_FANCY_SIMD
- acc1 = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(acc1, val[0], q8.load_quants(0, i, 0)), val[1], q8.load_quants(0, i, 1));
- acc2 = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(acc2, val[2], q8.load_quants(0, i, 2)), val[3], q8.load_quants(0, i, 3));
-#else
- auto dot1 = _mm256_add_epi16(_mm256_maddubs_epi16(val[0], q8.load_quants(0, i, 0)),
- _mm256_maddubs_epi16(val[1], q8.load_quants(0, i, 1)));
- auto dot2 = _mm256_add_epi16(_mm256_maddubs_epi16(val[2], q8.load_quants(0, i, 2)),
- _mm256_maddubs_epi16(val[3], q8.load_quants(0, i, 3)));
- acc1 = _mm256_add_epi32(acc1, _mm256_madd_epi16(m1_16, dot1));
- acc2 = _mm256_add_epi32(acc2, _mm256_madd_epi16(m1_16, dot2));
-#endif
- }
- accd[0] = _mm256_add_epi32(acc1, acc2);
- }
- else {
-
- for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm256_setzero_si256();
-
- for (int i = 0; i < nb/2; ++i) {
-
- deq.prepare_iq1bn_quants(x + 2*i + 0, val[0], val[1]);
- deq.prepare_iq1bn_quants(x + 2*i + 1, val[2], val[3]);
-
- for (int iy = 0; iy < nrc_y; ++iy) {
-#ifdef HAVE_FANCY_SIMD
- accd[iy] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_dpbusd_epi32(accd[iy],
- val[0], q8.load_quants(iy, i, 0)),
- val[1], q8.load_quants(iy, i, 1)),
- val[2], q8.load_quants(iy, i, 2)),
- val[3], q8.load_quants(iy, i, 3));
-#else
- auto dot1 = _mm256_add_epi16(_mm256_maddubs_epi16(val[0], q8.load_quants(iy, i, 0)),
- _mm256_maddubs_epi16(val[1], q8.load_quants(iy, i, 1)));
- auto dot2 = _mm256_add_epi16(_mm256_maddubs_epi16(val[2], q8.load_quants(iy, i, 2)),
- _mm256_maddubs_epi16(val[3], q8.load_quants(iy, i, 3)));
- dot1 = _mm256_madd_epi16(m1_16, _mm256_add_epi16(dot1, dot2));
- accd[iy] = _mm256_add_epi32(dot1, accd[iy]);
-#endif
- }
- }
- }
- int i = 2*(nb/2);
- if (i < nb) {
- deq.prepare_iq1bn_quants(x + i, val[0], val[1]);
- for (int iy = 0; iy < nrc_y; ++iy) {
-#ifdef HAVE_FANCY_SIMD
- accd[iy] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(accd[iy],
- val[0], q8.load_quants(iy, i/2, 0)), val[1], q8.load_quants(iy, i/2, 1));
-#else
- auto dot = _mm256_madd_epi16(m1_16, _mm256_add_epi16(_mm256_maddubs_epi16(val[0], q8.load_quants(iy, i/2, 0)),
- _mm256_maddubs_epi16(val[1], q8.load_quants(iy, i/2, 1))));
- accd[iy] = _mm256_add_epi32(dot, accd[iy]);
-#endif
- }
- }
-
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto vd = q8.scale(iy);
- auto sumi = _mm_add_epi32(_mm256_castsi256_si128(accd[iy]), _mm256_extractf128_si256(accd[iy], 1));
- auto sumf = _mm_fmsub_ps(vd, _mm_cvtepi32_ps(sumi), q8.minus(iy));
- info.store(ix, iy, scale*hsum_float_4(sumf));
- }
-
- }
-}
-
-struct DequantizeIQ2BN final : public BaseDequantizer<block_iq2_bn, true> {
- DequantizeIQ2BN(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}
-
- IQK_ALWAYS_INLINE void prepare4(int i, __m256i * val) const {
- auto q2bits_1 = _mm256_loadu_si256((const __m256i *)x[2*i].qs);
- auto q2bits_2 = _mm256_srli_epi16(q2bits_1, 2);
- make2(_mm256_permute2x128_si256(q2bits_1, q2bits_2, 0x20), val+0);
- make2(_mm256_permute2x128_si256(q2bits_1, q2bits_2, 0x31), val+2);
- }
- IQK_ALWAYS_INLINE void make2(__m256i q2_1, __m256i * val) const {
- val[0] = _mm256_and_si256(q2_1, mask2);
- val[1] = _mm256_and_si256(_mm256_srli_epi16(q2_1, 4), mask2);
- }
- IQK_ALWAYS_INLINE void prepare2(int i, __m256i * val) const {
- auto q2bits_1 = _mm_loadu_si128((const __m128i *)x[i].qs);
- make2(MM256_SET_M128I(_mm_srli_epi16(q2bits_1, 2), q2bits_1), val);
- }
- const __m256i m1_8 = _mm256_set1_epi8(1);
- const __m256i mf_8 = _mm256_set1_epi8(16);
- const __m256i mask2 = _mm256_set1_epi8(0x03);
- const __m256i mask3 = _mm256_set1_epi8(0x30);
-};
-
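prepare4()/make2() above only ever shift and mask with 0x03: every byte of qs packs four 2-bit quants at bit positions 0, 2, 4 and 6 (the SIMD code extracts them in a different lane order across the 256-bit registers). The scalar equivalent of one extraction, with a hypothetical helper name:

    #include <cstdint>

    // Extract the k-th 2-bit quant (k = 0..3) from one packed iq2_bn byte.
    static inline uint8_t unpack_2bit(uint8_t packed, int k) {
        return (uint8_t)((packed >> (2 * k)) & 0x03);
    }
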
-template <int nrc_y>
-IQK_NOINLINE void mul_mat_iq2bn_q8_K64(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
- const int nb = n / QK_IQ1BN;
- Q8_K64<nrc_y> q8(info);
- DequantizeIQ2BN deq(vx, bx);
- __m256i accd[nrc_y];
- __m256i val[4];
-
-#ifndef HAVE_FANCY_SIMD
- const auto m1_16 = _mm256_set1_epi16(1);
-#endif
-
- for (int ix = 0; ix < nrc_x; ++ix) {
-
- deq.new_row(ix);
-
- if constexpr (nrc_y == 1) {
- __m256i acc[2] = {};
- for (int i = 0; i < nb/2; ++i) {
- deq.prepare4(i, val);
-#ifdef HAVE_FANCY_SIMD
- acc[0] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(acc[0], val[0], q8.load_quants(0, i, 0)),
- val[1], q8.load_quants(0, i, 1));
- acc[1] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(acc[1], val[2], q8.load_quants(0, i, 2)),
- val[3], q8.load_quants(0, i, 3));
-#else
- auto dot1 = _mm256_add_epi16(_mm256_maddubs_epi16(val[0], q8.load_quants(0, i, 0)),
- _mm256_maddubs_epi16(val[1], q8.load_quants(0, i, 1)));
- auto dot2 = _mm256_add_epi16(_mm256_maddubs_epi16(val[2], q8.load_quants(0, i, 2)),
- _mm256_maddubs_epi16(val[3], q8.load_quants(0, i, 3)));
- acc[0] = _mm256_add_epi32(acc[0], _mm256_madd_epi16(m1_16, dot1));
- acc[1] = _mm256_add_epi32(acc[1], _mm256_madd_epi16(m1_16, dot2));
-#endif
- }
- accd[0] = _mm256_add_epi32(acc[0], acc[1]);
- }
- else {
-
- for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm256_setzero_si256();
-
- for (int i = 0; i < nb/2; ++i) {
- deq.prepare4(i, val);
- for (int iy = 0; iy < nrc_y; ++iy) {
-#ifdef HAVE_FANCY_SIMD
- accd[iy] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_dpbusd_epi32(accd[iy],
- val[0], q8.load_quants(iy, i, 0)), val[1], q8.load_quants(iy, i, 1)),
- val[2], q8.load_quants(iy, i, 2)), val[3], q8.load_quants(iy, i, 3));
-#else
- auto dot = _mm256_madd_epi16(m1_16, _mm256_add_epi16(
- _mm256_add_epi16(_mm256_maddubs_epi16(val[0], q8.load_quants(iy, i, 0)),
- _mm256_maddubs_epi16(val[1], q8.load_quants(iy, i, 1))),
- _mm256_add_epi16(_mm256_maddubs_epi16(val[2], q8.load_quants(iy, i, 2)),
- _mm256_maddubs_epi16(val[3], q8.load_quants(iy, i, 3)))));
- accd[iy] = _mm256_add_epi32(dot, accd[iy]);
-#endif
- }
- }
- }
- int i = 2*(nb/2);
- if (i < nb) {
- deq.prepare2(i, val);
- for (int iy = 0; iy < nrc_y; ++iy) {
-#ifdef HAVE_FANCY_SIMD
- accd[iy] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(accd[iy], val[0], q8.load_quants(iy, i/2, 0)),
- val[1], q8.load_quants(iy, i/2, 1));
-#else
- auto dot = _mm256_madd_epi16(m1_16, _mm256_add_epi16(_mm256_maddubs_epi16(val[0], q8.load_quants(iy, i/2, 0)),
-                                                                         _mm256_maddubs_epi16(val[1], q8.load_quants(iy, i/2, 1))));
- accd[iy] = _mm256_add_epi32(dot, accd[iy]);
-#endif
- }
- }
-
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto vd = q8.scale(iy);
- auto sumi = _mm_add_epi32(_mm256_castsi256_si128(accd[iy]), _mm256_extractf128_si256(accd[iy], 1));
- auto sumf = _mm_fmsub_ps(vd, _mm_cvtepi32_ps(sumi), q8.minus(iy));
- info.store(ix, iy, deq.d*hsum_float_4(sumf));
- }
- }
-}
-
-template <typename Dequantizer, int nrc_y>
-static void mul_mat_qX_K_q8_K_IQ(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
- assert(n % QK_K == 0);
-#ifdef HAVE_FANCY_SIMD
- if constexpr (nrc_y == 1) {
- mul_mat_qX_K_q8_K_IQ_1<Dequantizer>(n, vx, bx, info, nrc_x);
- } else {
- mul_mat_qX_K_q8_K_IQ_N<Dequantizer, nrc_y>(n, vx, bx, info, nrc_x);
- }
-#else
- mul_mat_qX_K_q8_K_IQ_N<Dequantizer, nrc_y>(n, vx, bx, info, nrc_x);
-#endif
-}
-
-struct DequantizerIQ3S final : public BaseDequantizer<block_iq3_s> {
- DequantizerIQ3S(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}
-
- constexpr static int num_blocks = 8;
-
- inline __m128i make_scales(int i, float& dd) const {
- dd = GGML_FP16_TO_FP32(x[i].d);
- uint32_t aux32[2];
- std::memcpy(aux32, x[i].scales, 4);
- aux32[1] = (aux32[0] >> 4) & 0x0f0f0f0f;
- aux32[0] &= 0x0f0f0f0f;
- auto scales8 = _mm_shuffle_epi8(_mm_loadl_epi64((const __m128i *)aux32), _mm_set1_epi64x(0x0703060205010400));
- auto scales16 = _mm256_castsi256_si128(_mm256_cvtepi8_epi16(scales8));
- return _mm_or_si128(_mm_slli_epi16(scales16, 1), _mm_set1_epi16(1));
- }
- inline void new_block(int i, __m256i * scales) {
- auto scales16 = make_scales(i, d);
- scales[0] = MM256_SET_M128I(scales16, scales16);
- }
- inline float new_block(int i, __m256i * scales, __m256i& mins) {
- auto scales16 = make_scales(i, d);
- mins = scb.shuffle(scales16);
- scales[0] = MM256_SET_M128I(scales16, scales16);
- return -minv*d;
- }
-
- inline void prepare(int i, int j) {
- prepare_unsigned(i, j);
- sh.sign_4_values((const uint16_t *)x[i].signs + 8*j, bits.values);
- for (int k = 0; k < 4; ++k) bits.values[k] = _mm256_add_epi8(bits.values[k], min_value);
- }
- inline void prepare(int i, int j, const Q8<1>& q8, __m256i * q8_quants) {
- prepare_unsigned(i, j);
- for (int k = 0; k < 4; ++k) q8_quants[k] = q8.load_quants(0, i, 4*j+k);
- sh.sign_4_values((const uint16_t *)x[i].signs + 8*j, q8_quants);
- }
-
- inline void prepare_unsigned(int i, int j) {
- auto qs = x[i].qs + 32*j;
- auto qh = x[i].qh + 4*j;
- helper.make2(qs+ 0, qh+0, bits.values+0);
- helper.make2(qs+16, qh+2, bits.values+2);
- }
-
- constexpr static int minv = 16;
-
- SimpleBits bits;
- SignHelper sh;
- Scales8KBase scb;
- IndexHelperIQ3S helper;
- const __m256i min_value = _mm256_set1_epi8(minv);
-
-};
-
-struct EvenSignHelper {
-#if defined HAVE_FANCY_SIMD && defined __AVX512VPOPCNTDQ__
- union sbits_t {
- __m128i vec;
- __mmask32 mask[4];
- };
- IQK_ALWAYS_INLINE void sign_2_values(__m256i aux, __m256i * values) const {
- aux = _mm256_and_si256(_mm256_srlv_epi32(aux, shifts), mask);
- auto pcnt = _mm256_popcnt_epi32(aux);
- sbits_t sbits;
- sbits.vec = _mm256_cvtepi32_epi8(_mm256_or_si256(aux, _mm256_slli_epi32(_mm256_and_si256(pcnt, mone), 7)));
- values[0] = _mm256_mask_sub_epi8(values[0], sbits.mask[0], _mm256_setzero_si256(), values[0]);
- values[1] = _mm256_mask_sub_epi8(values[1], sbits.mask[1], _mm256_setzero_si256(), values[1]);
- //auto sign_bits = _mm256_cvtepi32_epi8(_mm256_or_si256(aux, _mm256_slli_epi32(_mm256_and_si256(pcnt, mone), 7)));
- //const __mmask32 * m32 = (const __mmask32 *)&sign_bits;
- //values[0] = _mm256_mask_sub_epi8(values[0], m32[0], _mm256_setzero_si256(), values[0]);
- //values[1] = _mm256_mask_sub_epi8(values[1], m32[1], _mm256_setzero_si256(), values[1]);
- }
- const __m256i shifts = _mm256_set_epi32(21, 14, 7, 0, 21, 14, 7, 0);
- const __m256i mask = _mm256_set1_epi32(127);
- const __m256i mone = _mm256_set1_epi32(1);
-#else
- inline void sign_value(uint32_t aux32, __m256i& value) const {
- auto signs = _mm256_set_epi64x(keven_signs[(aux32 >> 21) & 127], keven_signs[(aux32 >> 14) & 127],
- keven_signs[(aux32 >> 7) & 127], keven_signs[(aux32 >> 0) & 127]);
- value = _mm256_sign_epi8(value, signs);
- }
-#endif
-};
-
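EvenSignHelper decodes the "even signs" packing used by the iq2_xxs/iq3_xxs style dequantizers below: only 7 sign bits are stored per group of 8 values, and the 8th bit equals the parity of the stored 7, so the full 8-bit mask always has an even popcount; this is exactly what the popcount in the AVX512 path reconstructs. A scalar sketch of that expansion (hypothetical function; popcount via the GCC/Clang builtin):

    #include <cstdint>

    // Recover all 8 sign bits from the 7 stored ones: the missing bit is the
    // parity of the stored bits, so the full mask has an even number of set bits.
    static inline uint8_t expand_even_signs(uint8_t stored7) {
        uint8_t b = stored7 & 0x7f;
        uint8_t parity = (uint8_t)(__builtin_popcount(b) & 1); // GCC/Clang builtin
        return (uint8_t)(b | (parity << 7));
    }
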
-struct DequantizerIQ3XXS final : public BaseDequantizer<block_iq3_xxs> {
- DequantizerIQ3XXS(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}
-
- constexpr static int num_blocks = 8;
-
- inline __m128i prepare_scales(int i) {
- d = 0.25f * GGML_FP16_TO_FP32(x[i].d);
- auto tmp = _mm256_loadu_si256((const __m256i *)(x[i].qs + QK_K/4));
- auto scales32 = _mm256_srli_epi32(tmp, 28);
- scales32 = _mm256_or_si256(_mm256_slli_epi32(scales32, 1), _mm256_set1_epi32(1));
- return _mm_packs_epi32(_mm256_castsi256_si128(scales32), _mm256_extractf128_si256(scales32, 1));
- }
-
- inline void new_block(int i, __m256i * scales) {
- auto scales16 = prepare_scales(i);
- scales[0] = MM256_SET_M128I(scales16, scales16);
- }
- inline float new_block(int i, __m256i * scales, __m256i& mins) {
- auto scales16 = prepare_scales(i);
- mins = scb.shuffle(scales16);
- scales[0] = MM256_SET_M128I(scales16, scales16);
- return -d*minv;
- }
-
- inline static __m256i make_quants(const uint8_t * qs) {
- return _mm256_set_epi32(iq3xxs_grid[qs[7]], iq3xxs_grid[qs[6]], iq3xxs_grid[qs[5]], iq3xxs_grid[qs[4]],
- iq3xxs_grid[qs[3]], iq3xxs_grid[qs[2]], iq3xxs_grid[qs[1]], iq3xxs_grid[qs[0]]);
- }
- inline static void make4_unsigned(const uint8_t * qs, __m256i * values) {
- values[0] = make_quants(qs+ 0);
- values[1] = make_quants(qs+ 8);
- values[2] = make_quants(qs+16);
- values[3] = make_quants(qs+24);
- }
-
- IQK_ALWAYS_INLINE void sign_2_values(const uint16_t * signs, __m256i * values) const {
-#if defined HAVE_FANCY_SIMD && defined __AVX512VPOPCNTDQ__
- esh.sign_2_values(MM256_SET_M128I(_mm_set1_epi32(signs[2] | (signs[3] << 16)), _mm_set1_epi32(signs[0] | (signs[1] << 16))), values);
-#else
- esh.sign_value(signs[0] | (signs[1] << 16), values[0]);
- esh.sign_value(signs[2] | (signs[3] << 16), values[1]);
-#endif
- }
-
- inline void prepare(int i, int j) {
- auto qs = x[i].qs + 32*j;
- const uint16_t * signs = (const uint16_t *)(x[i].qs + QK_K/4) + 8*j;
- make4_unsigned(qs, bits.values);
- sign_2_values(signs+0, bits.values+0);
- sign_2_values(signs+4, bits.values+2);
- for (int k = 0; k < 4; ++k) bits.values[k] = _mm256_add_epi32(bits.values[k], min_value);
- }
- inline void prepare(int i, int j, const Q8<1>& q8, __m256i * q8_quants) {
- for (int k = 0; k < 4; ++k) q8_quants[k] = q8.load_quants(0, i, 4*j+k);
- auto qs = x[i].qs + 32*j;
- const uint16_t * signs = (const uint16_t *)(x[i].qs + QK_K/4) + 8*j;
- make4_unsigned(qs, bits.values);
- sign_2_values(signs+0, q8_quants+0);
- sign_2_values(signs+4, q8_quants+2);
- }
-
- constexpr static int minv = 64;
-
- SimpleBits bits;
- Scales8KBase scb;
- EvenSignHelper esh;
- const __m256i min_value = _mm256_set1_epi8(minv);
-
-};
-
-struct DequantizerIQ2S final : public BaseDequantizer<block_iq2_s> {
- DequantizerIQ2S(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}
-
- constexpr static int num_blocks = 16;
-
- inline __m256i load_scales(int i) {
- d = 0.125f * GGML_FP16_TO_FP32(x[i].d);
- auto tmp = _mm_loadl_epi64((const __m128i *)x[i].scales);
- auto all = _mm_and_si128(_mm_unpacklo_epi8(tmp, _mm_srli_epi16(tmp, 4)), _mm_set1_epi8(0xf));
- auto scales8 = _mm_or_si128(_mm_slli_epi16(all, 1), _mm_set1_epi8(1));
- return _mm256_cvtepi8_epi16(scales8);
- }
- inline static void prepare_scales(const __m256i& all, __m256i * scales) {
- auto scales_l = _mm256_castsi256_si128(all);
- auto scales_h = _mm256_extractf128_si256(all, 1);
- scales[0] = MM256_SET_M128I(scales_l, scales_l);
- scales[1] = MM256_SET_M128I(scales_h, scales_h);
- }
-
- inline void new_block(int i, __m256i * scales) {
- prepare_scales(load_scales(i), scales);
- }
- inline float new_block(int i, __m256i * scales, __m256i& mins) {
- mins = load_scales(i);
- prepare_scales(mins, scales);
- return -d*minv;
- }
-
- union index_t {
- __m256i vec;
- uint32_t val[8];
- };
-
- inline static void make2(const uint8_t * qs, const uint8_t * qh, const __m256i& idx_shift, const __m256i& idx_mask, __m256i * values) {
- auto idx_l = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i *)qs));
- auto idx_h = MM256_SET_M128I(_mm_set1_epi32(qh[1]), _mm_set1_epi32(qh[0]));
- index_t idx;
- idx.vec = _mm256_or_si256(idx_l, _mm256_and_si256(_mm256_sllv_epi32(idx_h, idx_shift), idx_mask));
- values[0] = _mm256_set_epi64x(iq2s_grid[idx.val[3]], iq2s_grid[idx.val[2]], iq2s_grid[idx.val[1]], iq2s_grid[idx.val[0]]);
- values[1] = _mm256_set_epi64x(iq2s_grid[idx.val[7]], iq2s_grid[idx.val[6]], iq2s_grid[idx.val[5]], iq2s_grid[idx.val[4]]);
- }
- inline static void make2_signed(const SignHelper& sh, const uint8_t * qs, const uint8_t * qh, const uint16_t * sidx,
- const __m256i& idx_shift, const __m256i& idx_mask, const __m256i& min_value, __m256i * values) {
- make2(qs, qh, idx_shift, idx_mask, values);
- values[0] = _mm256_add_epi8(sh.sign_value(sidx+0, values[0]), min_value);
- values[1] = _mm256_add_epi8(sh.sign_value(sidx+2, values[1]), min_value);
- }
-
- inline void prepare(int i, int j) {
- auto qs = x[i].qs + 16*j;
- auto qh = x[i].qh + 4*j;
- const uint16_t * signs = (const uint16_t *)(x[i].qs + QK_K/8) + 8*j;
- make2_signed(sh, qs+0, qh+0, signs+0, idx_shift, idx_mask, min_value, bits.values+0);
- make2_signed(sh, qs+8, qh+2, signs+4, idx_shift, idx_mask, min_value, bits.values+2);
- }
- inline void prepare(int i, int j, const Q8<1>& q8, __m256i * q8_quants) {
- auto qs = x[i].qs + 16*j;
- auto qh = x[i].qh + 4*j;
- const uint16_t * signs = (const uint16_t *)(x[i].qs + QK_K/8) + 8*j;
- make2(qs+0, qh+0, idx_shift, idx_mask, bits.values+0);
- make2(qs+8, qh+2, idx_shift, idx_mask, bits.values+2);
- q8_quants[0] = _mm256_sign_epi8(q8.load_quants(0, i, 4*j+0), sh.make_signs(signs[0] | (signs[1] << 16)));
- q8_quants[1] = _mm256_sign_epi8(q8.load_quants(0, i, 4*j+1), sh.make_signs(signs[2] | (signs[3] << 16)));
- q8_quants[2] = _mm256_sign_epi8(q8.load_quants(0, i, 4*j+2), sh.make_signs(signs[4] | (signs[5] << 16)));
- q8_quants[3] = _mm256_sign_epi8(q8.load_quants(0, i, 4*j+3), sh.make_signs(signs[6] | (signs[7] << 16)));
- }
-
- constexpr static int minv = 43;
-
- SimpleBits bits;
- SignHelper sh;
- const __m256i idx_shift = _mm256_set_epi32(2, 4, 6, 8, 2, 4, 6, 8);
- const __m256i idx_mask = _mm256_set1_epi32(0x300);
- const __m256i min_value = _mm256_set1_epi8(minv);
-
-};
-
-struct DequantizerIQ2XS final : public BaseDequantizer<block_iq2_xs> {
- DequantizerIQ2XS(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}
-
- constexpr static int num_blocks = 16;
-
- inline __m256i load_scales(int i) {
- d = 0.125f * GGML_FP16_TO_FP32(x[i].d);
- auto tmp = _mm_loadl_epi64((const __m128i *)x[i].scales);
- auto all = _mm_and_si128(_mm_unpacklo_epi8(tmp, _mm_srli_epi16(tmp, 4)), _mm_set1_epi8(0xf));
- auto scales8 = _mm_or_si128(_mm_slli_epi16(all, 1), _mm_set1_epi8(1));
- return _mm256_cvtepi8_epi16(scales8);
- }
- inline static void prepare_scales(const __m256i& all, __m256i * scales) {
- auto scales_l = _mm256_castsi256_si128(all);
- auto scales_h = _mm256_extractf128_si256(all, 1);
- scales[0] = MM256_SET_M128I(scales_l, scales_l);
- scales[1] = MM256_SET_M128I(scales_h, scales_h);
- }
-
- inline void new_block(int i, __m256i * scales) {
- prepare_scales(load_scales(i), scales);
- }
- inline float new_block(int i, __m256i * scales, __m256i& mins) {
- mins = load_scales(i);
- prepare_scales(mins, scales);
- return -d*minv;
- }
-
- struct Helper {
- const __m256i mone = _mm256_set1_epi8(1);
- const __m256i mask = _mm256_set1_epi64x(0x8040201008040201);
- //const __m256i bhelper = _mm256_set_epi64x(0x8000008000808000, 0x0080800080000080, 0x8000008000808000, 0x0080800080000080);
- const __m256i bhelper = load_bhelper();
- const __m256i shuff1 = _mm256_set_epi64x(0x0606060606060606, 0x0404040404040404, 0x0202020202020202, 0x0000000000000000);
- const __m256i shuff2 = _mm256_set_epi64x(0x0e0e0e0e0e0e0e0e, 0x0c0c0c0c0c0c0c0c, 0x0a0a0a0a0a0a0a0a, 0x0808080808080808);
- static __m256i load_bhelper() {
- static const uint8_t k_bit_helper[32] = {
- 0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00,
- 0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00,
- };
- return _mm256_loadu_si256((const __m256i*)k_bit_helper);
- }
- };
-
- union index_t {
- __m256i vec;
- uint16_t val[8];
- };
-
- inline static void make4(const __m256i& data, const __m256i& mask, __m256i * values) {
- index_t idx;
- idx.vec = _mm256_and_si256(data, mask);
- values[0] = _mm256_set_epi64x(iq2xs_grid[idx.val[ 3]], iq2xs_grid[idx.val[ 2]], iq2xs_grid[idx.val[ 1]], iq2xs_grid[idx.val[ 0]]);
- values[1] = _mm256_set_epi64x(iq2xs_grid[idx.val[ 7]], iq2xs_grid[idx.val[ 6]], iq2xs_grid[idx.val[ 5]], iq2xs_grid[idx.val[ 4]]);
- values[2] = _mm256_set_epi64x(iq2xs_grid[idx.val[11]], iq2xs_grid[idx.val[10]], iq2xs_grid[idx.val[ 9]], iq2xs_grid[idx.val[ 8]]);
- values[3] = _mm256_set_epi64x(iq2xs_grid[idx.val[15]], iq2xs_grid[idx.val[14]], iq2xs_grid[idx.val[13]], iq2xs_grid[idx.val[12]]);
- }
- inline static void sign_value(const __m256i& sign_bits, const __m256i& shuffle, const __m256i& mask,
- const __m256i& mone, __m256i& value) {
- auto signs = _mm256_shuffle_epi8(sign_bits, shuffle);
- signs = _mm256_cmpeq_epi8(_mm256_and_si256(signs, mask), mask);
- value = _mm256_sign_epi8(value, _mm256_or_si256(signs, mone));
- }
- inline void sign_values(const __m256i& data, __m256i * values) const {
-#if defined HAVE_FANCY_SIMD && defined __AVX512VPOPCNTDQ__
- auto partial_bits = _mm256_cvtepi16_epi8(_mm256_srli_epi16(data, 9));
- auto pcnt = _mm_popcnt_epi8(partial_bits);
- auto full_bits = _mm_or_si128(partial_bits, _mm_slli_epi16(_mm_and_si128(pcnt, _mm_set1_epi8(1)), 7));
- const __mmask32 * m32 = (const __mmask32 *)&full_bits;
- auto zero = _mm256_setzero_si256();
- values[0] = _mm256_mask_sub_epi8(values[0], m32[0], zero, values[0]);
- values[1] = _mm256_mask_sub_epi8(values[1], m32[1], zero, values[1]);
- values[2] = _mm256_mask_sub_epi8(values[2], m32[2], zero, values[2]);
- values[3] = _mm256_mask_sub_epi8(values[3], m32[3], zero, values[3]);
-#else
- auto psb1 = _mm256_srli_epi16(data, 9);
- auto psb2 = _mm256_srli_epi16(data, 13);
- auto psbc = _mm256_xor_si256(psb1, psb2);
- auto oddb = _mm256_shuffle_epi8(helper.bhelper, psbc);
- auto full = _mm256_or_si256(psb1, oddb);
- auto full_l = _mm256_castsi256_si128(full);
- auto full_h = _mm256_extractf128_si256(full, 1);
- auto full_1 = MM256_SET_M128I(full_l, full_l);
- auto full_2 = MM256_SET_M128I(full_h, full_h);
- sign_value(full_1, helper.shuff1, helper.mask, helper.mone, values[0]);
- sign_value(full_1, helper.shuff2, helper.mask, helper.mone, values[1]);
- sign_value(full_2, helper.shuff1, helper.mask, helper.mone, values[2]);
- sign_value(full_2, helper.shuff2, helper.mask, helper.mone, values[3]);
-#endif
- }
- inline void make4_signed(const uint16_t * qs, const __m256i& m511,
- const __m256i& min_value, __m256i * values) const {
- auto q2 = _mm256_loadu_si256((const __m256i *)qs);
- make4(q2, m511, values);
- sign_values(q2, values);
- for (int k = 0; k < 4; ++k) values[k] = _mm256_add_epi8(values[k], min_value);
- }
- inline void make4(const uint16_t * qs, const __m256i& m511, __m256i * values, __m256i * q8) const {
- auto q2 = _mm256_loadu_si256((const __m256i *)qs);
- make4(q2, m511, values);
- sign_values(q2, q8);
- }
-
- inline void prepare(int i, int j) {
- make4_signed(x[i].qs + 16*j, idx_mask, min_value, bits.values);
- }
- inline void prepare(int i, int j, const Q8<1>& q8, __m256i * q8_quants) {
- for (int k = 0; k < 4; ++k) q8_quants[k] = q8.load_quants(0, i, 4*j+k);
- make4(x[i].qs + 16*j, idx_mask, bits.values, q8_quants);
- }
-
- constexpr static int minv = 43;
-
- SimpleBits bits;
-#if !(defined HAVE_FANCY_SIMD && defined __AVX512VPOPCNTDQ__)
- Helper helper;
-#endif
- const __m256i idx_mask = _mm256_set1_epi16(511);
- const __m256i min_value = _mm256_set1_epi8(minv);
-
-};
-
-struct DequantizerIQ2XXS final : public BaseDequantizer<block_iq2_xxs> {
- DequantizerIQ2XXS(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}
-
- constexpr static int num_blocks = 8;
-
- union Data {
- __m256i vec;
- uint32_t val[8];
- };
-
- inline __m128i load_scales(int i) {
- d = 0.125f * GGML_FP16_TO_FP32(x[i].d);
- const uint16_t * a16 = (const uint16_t *)x[i].qs;
- auto scales = _mm_srli_epi16(_mm_set_epi16(a16[31], a16[27], a16[23], a16[19], a16[15], a16[11], a16[7], a16[3]), 12);
- return _mm_or_si128(_mm_slli_epi16(scales, 1), _mm_set1_epi16(1));
- }
-
- inline void new_block(int i, __m256i * scales) {
- auto sc16 = load_scales(i);
- scales[0] = MM256_SET_M128I(sc16, sc16);
- }
- inline float new_block(int i, __m256i * scales, __m256i& mins) {
- auto sc16 = load_scales(i);
- mins = scb.shuffle(sc16);
- scales[0] = MM256_SET_M128I(sc16, sc16);
- return -d*minv;
- }
-
- inline static void make4(const uint32_t * aux32, __m256i * values) {
- const uint8_t * aux8 = (const uint8_t *)aux32;
- values[0] = _mm256_set_epi64x(iq2xxs_grid[aux8[ 3]], iq2xxs_grid[aux8[ 2]], iq2xxs_grid[aux8[ 1]], iq2xxs_grid[aux8[ 0]]);
- values[1] = _mm256_set_epi64x(iq2xxs_grid[aux8[11]], iq2xxs_grid[aux8[10]], iq2xxs_grid[aux8[ 9]], iq2xxs_grid[aux8[ 8]]);
- values[2] = _mm256_set_epi64x(iq2xxs_grid[aux8[19]], iq2xxs_grid[aux8[18]], iq2xxs_grid[aux8[17]], iq2xxs_grid[aux8[16]]);
- values[3] = _mm256_set_epi64x(iq2xxs_grid[aux8[27]], iq2xxs_grid[aux8[26]], iq2xxs_grid[aux8[25]], iq2xxs_grid[aux8[24]]);
- }
-
- IQK_ALWAYS_INLINE void sign_values(const uint32_t * aux32, __m256i * values) const {
-#if defined HAVE_FANCY_SIMD && defined __AVX512VPOPCNTDQ__
- esh.sign_2_values(MM256_SET_M128I(_mm_set1_epi32(aux32[3]), _mm_set1_epi32(aux32[1])), values+0);
- esh.sign_2_values(MM256_SET_M128I(_mm_set1_epi32(aux32[7]), _mm_set1_epi32(aux32[5])), values+2);
-#else
- esh.sign_value(aux32[1], values[0]);
- esh.sign_value(aux32[3], values[1]);
- esh.sign_value(aux32[5], values[2]);
- esh.sign_value(aux32[7], values[3]);
-#endif
- }
- inline void make4_signed(const uint32_t * aux32, const __m256i& min_value, __m256i * values) const {
- make4(aux32, values);
- sign_values(aux32, values);
- for (int k = 0; k < 4; ++k) values[k] = _mm256_add_epi8(values[k], min_value);
- }
- inline void make4(const uint32_t * aux32, __m256i * values, __m256i * q8) const {
- make4(aux32, values);
- sign_values(aux32, q8);
- }
- inline void prepare(int i, int j) {
- Data data; data.vec = _mm256_loadu_si256((const __m256i *)x[i].qs + j);
- make4_signed(data.val, min_value, bits.values);
- }
- inline void prepare(int i, int j, const Q8<1>& q8, __m256i * q8_quants) {
- for (int k = 0; k < 4; ++k) q8_quants[k] = q8.load_quants(0, i, 4*j+k);
- Data data; data.vec = _mm256_loadu_si256((const __m256i *)x[i].qs + j);
- make4(data.val, bits.values, q8_quants);
- }
-
- constexpr static int minv = 43;
- SimpleBits bits;
- Scales8KBase scb;
- EvenSignHelper esh;
- const __m256i min_value = _mm256_set1_epi8(minv);
- const __m256i shuffle = _mm256_set_epi32(7, 5, 3, 1, 7, 5, 3, 1);
-};
-
-//
-// ============================== Legacy quants
-//
-
-struct DotHelper {
- const __m256i m1 = _mm256_set1_epi16(1);
-#if defined(__AVX512VNNI__) && defined(__AVX512VL__)
- inline __m256i dot(__m256i x, __m256i y) const {
- return _mm256_dpbusd_epi32(_mm256_setzero_si256(), x, y);
- }
-#else
- inline __m256i dot(__m256i x, __m256i y) const {
- return _mm256_madd_epi16(m1, _mm256_maddubs_epi16(x, y));
- }
-#endif
-};
-
-struct SignedDot {
- DotHelper helper;
- inline __m256i compute(__m256i x, __m256i y) const {
- return helper.dot(_mm256_sign_epi8(x, x), _mm256_sign_epi8(y, x));
- }
-};
-struct UnsignedDot {
- DotHelper helper;
- inline __m256i compute(__m256i x, __m256i y) const {
- return helper.dot(x, y);
- }
-};
-
-template <typename Q8, typename Q8x4, typename Dot, bool can_pack = true> struct Sum4 {
- Dot dot;
- inline __m256i compute(const __m256i * qx, const Q8 * y) const {
- const Q8x4 * y4 = (const Q8x4 *)y;
- const __m256i p0 = dot.compute(qx[0], _mm256_loadu_si256((const __m256i *)y4->qs+0)); // 8x block 0
- const __m256i p1 = dot.compute(qx[1], _mm256_loadu_si256((const __m256i *)y4->qs+1)); // 8x block 1
- const __m256i p2 = dot.compute(qx[2], _mm256_loadu_si256((const __m256i *)y4->qs+2)); // 8x block 2
- const __m256i p3 = dot.compute(qx[3], _mm256_loadu_si256((const __m256i *)y4->qs+3)); // 8x block 3
- if constexpr (can_pack) {
- const __m256i p01 = _mm256_madd_epi16(dot.helper.m1, _mm256_packs_epi32(p0, p1)); // 0,0, 1,1, 0,0, 1,1
- const __m256i p23 = _mm256_madd_epi16(dot.helper.m1, _mm256_packs_epi32(p2, p3)); // 2,2, 3,3, 2,2, 3,3
- return _mm256_madd_epi16(dot.helper.m1, _mm256_packs_epi32(p01, p23)); // 0,1,2,3, 0,1,2,3
- } else {
- // Note to myself: this is much faster than using _mm256_hadd_epi32()
- auto p01 = _mm256_add_epi32(_mm256_unpacklo_epi32(p0, p1), _mm256_unpackhi_epi32(p0, p1)); // 0,1, 0,1, 0,1, 0,1
- auto p23 = _mm256_add_epi32(_mm256_unpacklo_epi32(p2, p3), _mm256_unpackhi_epi32(p2, p3)); // 2,3, 2,3, 2,3, 2,3
- return _mm256_add_epi32(_mm256_unpacklo_epi64(p01, p23), _mm256_unpackhi_epi64(p01, p23)); // 0,1,2,3, 0,1,2,3
- }
- }
- inline __m256i compute(__m256i x, __m256i y) const { return dot.compute(x, y); }
-};
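
For reference, a minimal scalar sketch of the quantity the Sum4 helper above contributes per group of
four 32-element blocks, assuming type-0 quants where each block carries a single combined scale s[k]
(s[k] = d_x[k]*d_y[k]); the SIMD code leaves the four block dot products spread across the int32 lanes
so they line up with the scale layout and defers the horizontal reduction. Illustration only, not part
of the patch; sum4_scalar_reference is a hypothetical name.

    // hypothetical scalar reference, for illustration only (not in the source)
    static float sum4_scalar_reference(const int8_t x[4][32], const int8_t y[4][32], const float s[4]) {
        float acc = 0.0f;
        for (int k = 0; k < 4; ++k) {
            int32_t dot = 0;                                      // integer dot product of block k
            for (int j = 0; j < 32; ++j) dot += int32_t(x[k][j]) * int32_t(y[k][j]);
            acc += s[k] * float(dot);                             // scale each block by its combined scale
        }
        return acc;
    }
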
-
-template <typename Q8, typename Q8x4> struct Sum4q4 {
- inline __m256i compute(const __m256i * qx, const Q8 * y) const {
- const Q8x4 * y4 = (const Q8x4 *)y;
- auto p0 = _mm256_maddubs_epi16(qx[0], _mm256_loadu_si256((const __m256i *)y4->qs+0)); // 16x block 0
- auto p1 = _mm256_maddubs_epi16(qx[1], _mm256_loadu_si256((const __m256i *)y4->qs+1)); // 16x block 1
- auto p2 = _mm256_maddubs_epi16(qx[2], _mm256_loadu_si256((const __m256i *)y4->qs+2)); // 16x block 2
- auto p3 = _mm256_maddubs_epi16(qx[3], _mm256_loadu_si256((const __m256i *)y4->qs+3)); // 16x block 3
- auto p01 = _mm256_add_epi16(_mm256_unpacklo_epi32(p0, p1), _mm256_unpackhi_epi32(p0, p1)); // 0,0, 1,1, 0,0, 1,1, 0,0, 1,1, 0,0, 1,1
- auto p23 = _mm256_add_epi16(_mm256_unpacklo_epi32(p2, p3), _mm256_unpackhi_epi32(p2, p3)); // 2,2, 3,3, 2,2, 3,3, 2,2, 3,3, 2,2, 3,3
- auto p0123 = _mm256_add_epi16(_mm256_unpacklo_epi64(p01, p23), _mm256_unpackhi_epi64(p01, p23)); // 0,0, 1,1, 2,2, 3,3, 0,0, 1,1, 2,2, 3,3
- return _mm256_madd_epi16(_mm256_set1_epi16(1), p0123);
- }
- inline __m256i compute(__m256i x, __m256i y) const { return _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_maddubs_epi16(x, y)); }
-};
-
-struct ScaleHelperQ8_0 {
- inline __m128 prepare4(const block_q8_0 * y) {
- const block_q8_0_x4 * y4 = (const block_q8_0_x4 *)y;
- return _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)y4->d));
- }
- inline __m128 prepare4(__m128 other_scales, const block_q8_0 * y) {
- return _mm_mul_ps(other_scales, prepare4(y));
- }
- template <typename Q> inline float prepare1(const Q * y) const { return GGML_FP16_TO_FP32(y->d); }
- template <typename Q> inline float prepare1(float d, const Q * y) const { return d*prepare1(y); }
-};
-
-struct ScaleHelperQ_0 {
- ggml_half scales8[4];
- template <typename Q>
- inline __m128 prepare4(const Q * y) {
- for (int j = 0; j < 4; ++j) scales8[j] = y[j].d;
- return _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)scales8));
- }
- template <typename Q>
- inline __m128 prepare4(__m128 other_scales, const Q * y) {
- return _mm_mul_ps(other_scales, prepare4<Q>(y));
- }
- template <typename Q> inline float prepare1(const Q * y) const { return GGML_FP16_TO_FP32(y->d); }
- template <typename Q> inline float prepare1(float d, const Q * y) const { return d*prepare1(y); }
-};
-
-template <int min_value>
-struct ScaleHelperQ_0_1 {
- ggml_half scales8[4];
- template <typename Q>
- inline __m256 prepare4(const Q * y) {
- for (int j = 0; j < 4; ++j) scales8[j] = y[j].d;
- auto s4 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)scales8));
- return _mm256_set_m128(_mm_mul_ps(s4, min), s4);
- }
- template <typename Q>
- inline __m256 prepare4(__m256 other_scales, const Q * y) {
-        return _mm256_mul_ps(other_scales, prepare4<Q>(y));
- }
- template <typename Q> inline std::pair<float, float> prepare1(const Q * y) const {
- float d = GGML_FP16_TO_FP32(y->d);
- return std::make_pair(d, -d*float(min_value));
- }
- std::pair<float, float> inline prepare1(const std::pair<float, float>& dm, const block_q8_1 * y) const {
- return std::make_pair(dm.first*GGML_FP16_TO_FP32(y->d), dm.second*GGML_FP16_TO_FP32(y->s));
- }
- const __m128 min = _mm_set1_ps(float(-min_value));
-};
-
-//template <int min_value>
-//struct ScaleHelperQ_0_2 {
-// ggml_bf16_t scales8[4];
-// template <typename Q>
-// inline __m256 prepare4(const Q * y) {
-// for (int j = 0; j < 4; ++j) scales8[j] = y[j].d;
-// auto s4 = _mm_castsi128_ps(_mm_slli_epi16(_mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)scales8)), 16));
-// return _mm256_set_m128(_mm_mul_ps(s4, min), s4);
-// }
-// template <typename Q>
-// inline __m256 prepare4(__m256 other_scales, const Q * y) {
-//        return _mm256_mul_ps(other_scales, prepare4<Q>(y));
-// }
-// template <typename Q> inline std::pair<float, float> prepare1(const Q * y) const {
-// float d = GGML_BF16_TO_FP32(y->d);
-// return std::make_pair(d, -d*float(min_value));
-// }
-// std::pair<float, float> inline prepare1(const std::pair<float, float>& dm, const block_q8_1 * y) const {
-// return std::make_pair(dm.first*GGML_FP16_TO_FP32(y->d), dm.second*GGML_FP16_TO_FP32(y->s));
-// }
-// const __m128 min = _mm_set1_ps(float(-min_value));
-//};
-
-struct ScaleHelperQ8_1 {
- template <typename Q>
- inline __m256 prepare4(const Q * y) {
- const block_q8_1_x4 * y4 = (const block_q8_1_x4 *)y;
- return _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)y4->d));
- }
- template <typename Q>
- inline __m256 prepare4(__m256 other_scales, const Q * y) {
- return _mm256_mul_ps(other_scales, prepare4<Q>(y));
- }
- template <typename Q> inline std::pair<float, float> prepare1(const Q * y) const {
- return std::make_pair(GGML_FP16_TO_FP32(y->d), GGML_FP16_TO_FP32(y->m));
- }
- template <typename Q> inline std::pair<float, float> prepare1(const std::pair<float, float>& dm, const Q * y) const {
- return std::make_pair(dm.first*GGML_FP16_TO_FP32(y->d), dm.second*GGML_FP16_TO_FP32(y->m));
- }
- std::pair<float, float> inline prepare1(const std::pair<float, float>& dm, const block_q8_1 * y) const {
- return std::make_pair(dm.first*GGML_FP16_TO_FP32(y->d), dm.second*GGML_FP16_TO_FP32(y->s));
- }
-};
-
-struct ScaleHelperQ8_2 {
- template <typename Q>
- inline __m256 prepare4(const Q * y) {
- const block_q8_2_x4 * y4 = (const block_q8_2_x4 *)y;
- auto aux = _mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)y4->d));
- return _mm256_castsi256_ps(_mm256_slli_epi32(aux, 16));
- }
- template <typename Q>
- inline __m256 prepare4(__m256 other_scales, const Q * y) {
- return _mm256_mul_ps(other_scales, prepare4<Q>(y));
- }
- template <typename Q> inline std::pair<float, float> prepare1(const Q * y) const {
- return std::make_pair(GGML_BF16_TO_FP32(y->d), GGML_BF16_TO_FP32(y->m));
- }
- template <typename Q> inline std::pair<float, float> prepare1(const std::pair<float, float>& dm, const Q * y) const {
- ggml_bf16_t d, s; d.bits = y->d; s.bits = y->s;
- return std::make_pair(dm.first*GGML_BF16_TO_FP32(d), dm.second*GGML_BF16_TO_FP32(s));
- }
- std::pair<float, float> inline prepare1(const std::pair<float, float>& dm, const block_q8_2 * y) const {
- ggml_bf16_t d, s; d.bits = y->d; s.bits = y->s;
- return std::make_pair(dm.first*GGML_BF16_TO_FP32(d), dm.second*GGML_BF16_TO_FP32(s));
- }
-};
-
-struct ScaleHelperQ_1 {
- uint32_t scales8[4];
- const __m128i shuffle = _mm_set_epi16(0x0f0e, 0x0b0a, 0x0706, 0x0302, 0x0d0c, 0x0908, 0x0504, 0x0100);
-
- template <typename Q>
- inline __m256 prepare4(const Q * y) {
- for (int j = 0; j < 4; ++j) {
-            // It is slightly faster to directly dereference (const uint32_t *)&y[j].d, but some compilers
-            // complain that this breaks strict-aliasing rules, so we copy via memcpy instead.
- memcpy(scales8 + j, &y[j].d, sizeof(uint32_t));
- }
- return _mm256_cvtph_ps(_mm_shuffle_epi8(_mm_loadu_si128((const __m128i *)scales8), shuffle));
- }
-
- template <typename Q>
- inline __m256 prepare4(__m256 other_scales, const Q * y) {
- return _mm256_mul_ps(other_scales, prepare4<Q>(y));
- }
-
- template <typename Q> inline std::pair<float, float> prepare1(const Q * y) const {
- return std::make_pair(GGML_FP16_TO_FP32(y->d), GGML_FP16_TO_FP32(y->m));
- }
- template <typename Q> inline std::pair<float, float> prepare1(const std::pair<float, float>& dm, const Q * y) const {
- return std::make_pair(dm.first*GGML_FP16_TO_FP32(y->d), dm.second*GGML_FP16_TO_FP32(y->m));
- }
- std::pair<float, float> inline prepare1(const std::pair<float, float>& dm, const block_q8_1 * y) const {
- return std::make_pair(dm.first*GGML_FP16_TO_FP32(y->d), dm.second*GGML_FP16_TO_FP32(y->s));
- }
-};
-
-struct MinusType0 {
- inline __m256 compute(__m128 d, int) const { return _mm256_set_m128(d, d); }
- inline float compute(float d, int) const { return d; }
- inline float result(__m256 acc, int) const { return hsum_float_8(acc); }
- inline __m256 vresult(__m256 acc, int) const { return acc; }
-};
-
-template <int nrc_y> struct MinusType1 {
- __m128 accm[nrc_y];
- MinusType1() { for (int iy = 0; iy < nrc_y; ++iy) accm[iy] = _mm_setzero_ps(); }
- inline __m256 compute(__m256 dm, int iy) {
- const __m128 d = _mm256_castps256_ps128(dm);
- const __m128 m = _mm256_extractf128_ps(dm, 1);
- accm[iy] = _mm_add_ps(accm[iy], m);
- return _mm256_set_m128(d, d);
- }
- inline float compute(const std::pair<float, float>& dm, int iy) {
- accm[iy] = _mm_add_ps(accm[iy], _mm_set1_ps(dm.second*0.25f));
- return dm.first;
- }
- inline float result(__m256 acc, int iy) const {
- const __m128 sum = _mm_add_ps(_mm256_castps256_ps128(acc), _mm256_extractf128_ps(acc, 1));
- return hsum_float_4(_mm_add_ps(sum, accm[iy]));
- }
- inline __m256 vresult(__m256 acc, int iy) const {
- return _mm256_add_ps(acc, _mm256_insertf128_ps(_mm256_setzero_ps(), accm[iy], 0));
- }
-};
-
-template <typename Minus, int nrc_y, bool is_multiple_of_4> struct AccumT {
- __m256 acc[nrc_y];
- Minus accm;
- AccumT() { for (int iy = 0; iy < nrc_y; ++iy) acc[iy] = _mm256_setzero_ps(); }
- template <typename Unpacker, typename Scales, typename Sum, typename Q8>
- inline void compute(int nb, Unpacker& unp, Scales& scales, Sum& sum, const Q8 ** y, const DataInfo& info, int ix) {
- auto qx = unp.quants();
- __m256 dall[nrc_y];
- for (int i = 0; i < nb/4; ++i) {
- auto other_scales = unp.set_block_4(i);
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto s12 = scales.prepare4(other_scales, y[iy] + 4*i);
- dall[iy] = accm.compute(s12, iy);
- }
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto pall = sum.compute(qx, y[iy] + 4*i);
- acc[iy] = _mm256_fmadd_ps(dall[iy], _mm256_cvtepi32_ps(pall), acc[iy]);
- }
- }
- if (!is_multiple_of_4) {
- for (int i = 4*(nb/4); i < nb; ++i) {
- auto other_scales = unp.set_block(i);
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto s12 = scales.prepare1(other_scales, y[iy] + i);
- auto d = accm.compute(s12, iy);
- const __m256i p0 = sum.compute(qx[0], _mm256_loadu_si256((const __m256i *)y[iy][i].qs));
- acc[iy] = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(p0), acc[iy]);
- }
- }
- }
- for (int iy = 0; iy < nrc_y; ++iy) {
- info.store(ix, iy, accm.result(acc[iy], iy));
- }
- }
- template <typename Unpacker, typename Scales, typename Sum, typename Q8>
- inline void compute(int nb, Unpacker& unp, Scales& scales, Sum& sum, const Q8 ** y, __m256 * result) {
- auto qx = unp.quants();
- __m256 dall[nrc_y];
- for (int i = 0; i < nb/4; ++i) {
- auto other_scales = unp.set_block_4(i);
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto s12 = scales.prepare4(other_scales, y[iy] + 4*i);
- dall[iy] = accm.compute(s12, iy);
- }
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto pall = sum.compute(qx, y[iy] + 4*i);
- acc[iy] = _mm256_fmadd_ps(dall[iy], _mm256_cvtepi32_ps(pall), acc[iy]);
- }
- }
- if (!is_multiple_of_4) {
- for (int i = 4*(nb/4); i < nb; ++i) {
- auto other_scales = unp.set_block(i);
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto s12 = scales.prepare1(other_scales, y[iy] + i);
- auto d = accm.compute(s12, iy);
- const __m256i p0 = sum.compute(qx[0], _mm256_loadu_si256((const __m256i *)y[iy][i].qs));
- acc[iy] = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(p0), acc[iy]);
- }
- }
- }
- for (int iy = 0; iy < nrc_y; ++iy) {
- result[iy] = accm.vresult(acc[iy], iy);
- }
- }
-};
-
-template <int nrc_y, bool is_multiple_of_4>
-using AccumType0 = AccumT<MinusType0, nrc_y, is_multiple_of_4>;
-
-template <int nrc_y, bool is_multiple_of_4>
-using AccumType1 = AccumT<MinusType1<nrc_y>, nrc_y, is_multiple_of_4>;
-
-using Sum4TypeQ80 = Sum4<block_q8_0, block_q8_0_x4, SignedDot, false>;
-using Sum4TypeQ82 = Sum4<block_q8_2, block_q8_2_x4, UnsignedDot, false>;
-
-template <typename Unpacker, typename AccumType, typename Scales, typename Q8, int nrc_y>
-void mul_mat_qX_q8_Helper(int nb, const void * vx, size_t bx, const DataInfo& info, const Q8 ** y, int nrc_x) {
- Unpacker unp(vx, bx);
- typename Unpacker::Sum4T sum4;
- Scales scales;
- for (int ix = 0; ix < nrc_x; ++ix) {
- unp.set_row(ix);
- AccumType accum;
- accum.compute(nb, unp, scales, sum4, y, info, ix);
- }
-}
-
-template <typename Unpacker, typename AccumType, typename Scales, typename Q8, int nrc_y>
-void mul_mat_qX_q8_Helper_x2(int nb, const void * vx, size_t bx, const DataInfo& info, const Q8 ** y, int nrc_x) {
- GGML_ASSERT(nrc_x%2 == 0);
- Unpacker unp(vx, bx);
- typename Unpacker::Sum4T sum4;
- Scales scales;
- for (int ix = 0; ix < nrc_x; ix += 2) {
- unp.set_row(ix);
- AccumType accum;
- accum.compute(nb, unp, scales, sum4, y, info, ix);
- }
-}
-
-template <typename Unpacker, int nrc_y>
-void mul_mat_qX_0_q8_0_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
- assert(n%Unpacker::block_size() == 0);
- Q8<nrc_y, block_q8_0> q8(info);
- int nb = n/Unpacker::block_size();
- if (nb%4 == 0) {
- mul_mat_qX_q8_Helper<Unpacker, AccumType0<nrc_y, true>, ScaleHelperQ8_0, block_q8_0, nrc_y>(
- nb, vx, bx, info, q8.y, nrc_x
- );
- } else {
- mul_mat_qX_q8_Helper<Unpacker, AccumType0<nrc_y, false>, ScaleHelperQ8_0, block_q8_0, nrc_y>(
- nb, vx, bx, info, q8.y, nrc_x
- );
- }
-}
-
-inline __m256 hsum_float_8x8(__m256 * accm) {
- for (int i = 0; i < 4; ++i) {
- accm[i] = _mm256_add_ps(_mm256_permute2f128_ps(accm[i], accm[i+4], 0x20), _mm256_permute2f128_ps(accm[i], accm[i+4], 0x31));
- //accm[i] = _mm256_set_m128(_mm_add_ps(_mm256_castps256_ps128(accm[i+4]), _mm256_extractf128_ps(accm[i+4], 1)),
- // _mm_add_ps(_mm256_castps256_ps128(accm[i+0]), _mm256_extractf128_ps(accm[i+0], 1)));
- }
- for (int i = 0; i < 2; ++i) accm[i] = _mm256_add_ps(_mm256_unpacklo_ps(accm[i], accm[i+2]), _mm256_unpackhi_ps(accm[i], accm[i+2]));
- return _mm256_add_ps(_mm256_unpacklo_ps(accm[0], accm[1]), _mm256_unpackhi_ps(accm[0], accm[1]));
-}
-
-template <typename Unpacker, int nrc_y, int nrc_x>
-void mul_mat_qX_0_q8_0_Tx(int n, const void * vx, size_t bx, const DataInfo& info, int) {
- static_assert(8%nrc_y == 0);
- Q8<nrc_y, block_q8_0> q8(info);
- int nb = n/Unpacker::block_size();
- Unpacker unp(vx, bx);
- typename Unpacker::Sum4T sum4;
- ScaleHelperQ8_0 scales;
- __m256 result[8];
- auto store = [&info, &result] (int ix0) {
- if constexpr (nrc_y == 1) {
- info.store(ix0, 0, hsum_float_8x8(result));
- }
- else if constexpr (nrc_y == 2) {
- auto value = hsum_float_8x8(result);
- auto value1 = _mm256_extractf128_ps(value, 1);
- info.store(ix0, 0, _mm_shuffle_ps(_mm256_castps256_ps128(value), value1, 0x88));
- info.store(ix0, 1, _mm_shuffle_ps(_mm256_castps256_ps128(value), value1, 0xdd));
- }
- else {
- float val[8];
- _mm256_storeu_ps(val, hsum_float_8x8(result));
- for (int iy = 0; iy < nrc_y; ++iy) for (int ix = 0; ix < 8/nrc_y; ++ix) info.store(ix0+ix, iy, val[nrc_y*ix+iy]);
- }
- };
- if (nb%4 == 0) {
- for (int ix0 = 0; ix0 < nrc_x; ix0 += 8/nrc_y) {
- for (int ix = 0; ix < 8/nrc_y; ++ix) {
- unp.set_row(ix0 + ix);
- AccumType0<nrc_y, true> accum;
- accum.compute(nb, unp, scales, sum4, q8.y, result + nrc_y*ix);
- }
- store(ix0);
- }
- } else {
- for (int ix0 = 0; ix0 < nrc_x; ix0 += 8/nrc_y) {
- for (int ix = 0; ix < 8/nrc_y; ++ix) {
- unp.set_row(ix0 + ix);
- AccumType0<nrc_y, false> accum;
- accum.compute(nb, unp, scales, sum4, q8.y, result + nrc_y*ix);
- }
- store(ix0);
- }
- }
-}
-
-
-template <typename Unpacker, int nrc_y>
-void mul_mat_qX_1_q8_1_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
- assert(n%Unpacker::block_size() == 0);
- Q8<nrc_y, block_q8_1> q8(info);
- int nb = n/Unpacker::block_size();
- if (nb%4 == 0) {
- mul_mat_qX_q8_Helper<Unpacker, AccumType1<nrc_y, true>, ScaleHelperQ8_1, block_q8_1, nrc_y>(
- nb, vx, bx, info, q8.y, nrc_x
- );
- } else {
- mul_mat_qX_q8_Helper<Unpacker, AccumType1<nrc_y, false>, ScaleHelperQ8_1, block_q8_1, nrc_y>(
- nb, vx, bx, info, q8.y, nrc_x
- );
- }
-}
-
-template <typename Unpacker, int nrc_y>
-void mul_mat_qX_1_q8_2_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
- assert(n%Unpacker::block_size() == 0);
- Q8<nrc_y, block_q8_2> q8(info);
- int nb = n/Unpacker::block_size();
- if (nb%4 == 0) {
- mul_mat_qX_q8_Helper<Unpacker, AccumType1<nrc_y, true>, ScaleHelperQ8_2, block_q8_2, nrc_y>(
- nb, vx, bx, info, q8.y, nrc_x
- );
- } else {
- mul_mat_qX_q8_Helper<Unpacker, AccumType1<nrc_y, false>, ScaleHelperQ8_2, block_q8_2, nrc_y>(
- nb, vx, bx, info, q8.y, nrc_x
- );
- }
-}
-
-template <typename Unpacker, int nrc_y, int nrc_x>
-void mul_mat_qX_0_q8_2_Tx(int n, const void * vx, size_t bx, const DataInfo& info, int) {
- static_assert(8%nrc_y == 0);
- Q8<nrc_y, block_q8_2> q8(info);
- int nb = n/Unpacker::block_size();
- Unpacker unp(vx, bx);
- typename Unpacker::Sum4T sum4;
- ScaleHelperQ8_2 scales;
- __m256 result[8];
- auto store = [&info, &result] (int ix0) {
- if constexpr (nrc_y == 1) {
- info.store(ix0, 0, hsum_float_8x8(result));
- }
- else if constexpr (nrc_y == 2) {
- auto value = hsum_float_8x8(result);
- auto value1 = _mm256_extractf128_ps(value, 1);
- info.store(ix0, 0, _mm_shuffle_ps(_mm256_castps256_ps128(value), value1, 0x88));
- info.store(ix0, 1, _mm_shuffle_ps(_mm256_castps256_ps128(value), value1, 0xdd));
- }
- else {
- float val[8];
- _mm256_storeu_ps(val, hsum_float_8x8(result));
- for (int iy = 0; iy < nrc_y; ++iy) for (int ix = 0; ix < 8/nrc_y; ++ix) info.store(ix0+ix, iy, val[nrc_y*ix+iy]);
- }
- };
- if (nb%4 == 0) {
- for (int ix0 = 0; ix0 < nrc_x; ix0 += 8/nrc_y) {
- for (int ix = 0; ix < 8/nrc_y; ++ix) {
- unp.set_row(ix0 + ix);
- AccumType1<nrc_y, true> accum;
- accum.compute(nb, unp, scales, sum4, q8.y, result + nrc_y*ix);
- }
- store(ix0);
- }
- } else {
- for (int ix0 = 0; ix0 < nrc_x; ix0 += 8/nrc_y) {
- for (int ix = 0; ix < 8/nrc_y; ++ix) {
- unp.set_row(ix0 + ix);
- AccumType1<nrc_y, false> accum;
- accum.compute(nb, unp, scales, sum4, q8.y, result + nrc_y*ix);
- }
- store(ix0);
- }
- }
-}
-
-struct Dequantizer4bit {
- const __m256i m4 = _mm256_set1_epi8(0xf);
- inline __m256i dequant(const uint8_t * qs) const {
- const __m128i aux128 = _mm_loadu_si128((const __m128i *)qs);
- return _mm256_and_si256(MM256_SET_M128I(_mm_srli_epi16(aux128, 4), aux128), m4);
- }
-};
-
-struct Q8_0_Dequantizer {
- inline __m256i dequant(const block_q8_0 * x) const {
- return _mm256_loadu_si256((const __m256i *)x->qs);
- }
-};
-
-struct Q8_0_1_Dequantizer {
- inline __m256i dequant(const block_q8_0 * x) const {
- return _mm256_add_epi8(_mm256_set1_epi8(127), _mm256_loadu_si256((const __m256i *)x->qs));
- }
-};
-
-struct Q4_0_Dequantizer {
- Dequantizer4bit b4;
- const __m256i m8 = _mm256_set1_epi8(-8);
- inline __m256i dequant(const block_q4_0 * x) const {
- return _mm256_add_epi8(b4.dequant(x->qs), m8);
- }
-};
-
-struct Q4_0_1_Dequantizer {
- Dequantizer4bit b4;
- inline __m256i dequant(const block_q4_0 * x) const {
- return b4.dequant(x->qs);
- }
-};
-
-struct IQ4_NL_Dequantizer {
- Dequantizer4bit b4;
-#ifdef HAVE_FANCY_SIMD
- const __m256i values = load_iq4nl_values_256();
-#else
- const __m256i values = load_iq4k_values_256();
-#endif
- inline __m256i dequant(const block_iq4_nl * x) const {
- return _mm256_shuffle_epi8(values, b4.dequant(x->qs));
- }
-};
-
-struct Q4_1_Dequantizer {
- Dequantizer4bit b4;
- inline __m256i dequant(const block_q4_1 * x) const {
- return b4.dequant(x->qs);
- }
-};
-
-struct HBitDequantizer {
- const __m256i shuffle = _mm256_set_epi64x(0x0303030303030303, 0x0202020202020202, 0x0101010101010101, 0x0000000000000000);
- const __m256i mask = _mm256_set1_epi64x(0x7fbfdfeff7fbfdfe);
- const __m256i minus1 = _mm256_set1_epi64x(-1);
- inline __m256i to_bytes(const uint8_t * bits) const {
-        // Note: data in all ggml quants is at least 2-byte aligned,
-        // so we can cast to uint16_t and OR two consecutive entries together,
-        // which is faster than memcpy.
- const uint16_t * aux16 = (const uint16_t *)bits;
- const uint32_t aux32 = aux16[0] | (aux16[1] << 16);
- //uint32_t aux32; memcpy(&aux32, bits, sizeof(uint32_t));
- __m256i bytes = _mm256_shuffle_epi8(_mm256_set1_epi32(aux32), shuffle);
- bytes = _mm256_or_si256(bytes, mask);
- return _mm256_cmpeq_epi8(bytes, minus1);
- }
-};
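
A scalar sketch of the transform HBitDequantizer::to_bytes() performs, for reference: byte n of the
32-byte result is 0xFF when bit n of the 32-bit high-bit field is set and 0x00 otherwise. The SIMD
version reaches the same result with a byte shuffle, an OR against a mask that leaves exactly one bit
per byte clear, and a compare against -1. Illustration only, not part of the patch;
hbit_to_bytes_reference is a hypothetical name.

    // hypothetical scalar reference, for illustration only (not in the source)
    static void hbit_to_bytes_reference(const uint8_t * bits, uint8_t out[32]) {
        uint32_t aux32;
        std::memcpy(&aux32, bits, sizeof(uint32_t));     // the qh field of a q5_0/q5_1 block
        for (int n = 0; n < 32; ++n) out[n] = ((aux32 >> n) & 1) ? 0xFF : 0x00;
    }
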
-
-struct Q5_0_Dequantizer {
- Dequantizer4bit b4;
- HBitDequantizer hbit;
- const __m256i mh = _mm256_set1_epi8((char)0xF0);
- inline __m256i dequant(const block_q5_0 * x) const {
- const __m256i vqh = _mm256_andnot_si256(hbit.to_bytes(x->qh), mh);
- return _mm256_or_si256(b4.dequant(x->qs), vqh);
- }
-};
-
-template <typename Q5>
-struct Q5_1_Dequantizer {
- Dequantizer4bit b4;
- HBitDequantizer hbit;
- const __m256i mh = _mm256_set1_epi8(0x10);
- inline __m256i dequant(const Q5 * x) const {
- const __m256i vqh = _mm256_and_si256(hbit.to_bytes(x->qh), mh);
- return _mm256_or_si256(b4.dequant(x->qs), vqh);
- }
-};
-struct Q6_0_1_Dequantizer {
- Dequantizer4bit b4;
- const __m256i mh = _mm256_set1_epi8(0x30);
- const __m256i shift1 = _mm256_set_epi64x(0, 2, 0, 4);
- const __m256i shift2 = _mm256_set_epi64x(2, 0, 0, 0);
- inline __m256i dequant(const block_q6_0 * x) const {
- uint64_t aux64; std::memcpy(&aux64, x->qh, 8);
- auto h256 = _mm256_sllv_epi64(_mm256_set1_epi64x(aux64), shift1);
- return _mm256_or_si256(b4.dequant(x->qs), _mm256_and_si256(_mm256_srlv_epi64(h256, shift2), mh));
- }
-};
-
-template <typename Q, typename Scales, typename Dequantizer>
-struct Q_Unpacker {
- Q_Unpacker(const void * vx, size_t bx) : cx_0((const char *)vx), x((const Q*)cx_0), bx(bx) {}
-
- const char * cx_0;
- const Q * x;
- size_t bx;
-
- Scales scales;
- Dequantizer deq;
-
- __m256i qx[4];
-
- inline const __m256i* quants() const { return qx; }
-
- inline void set_row(int ix) { x = (const Q*)(cx_0 + ix*bx); }
-
- inline auto set_block_4(int i) {
- for (int j = 0; j < 4; ++j) {
- qx[j] = deq.dequant(x + 4*i + j);
- }
- return scales.prepare4(x + 4*i);
- }
- inline auto set_block(int i) {
- qx[0] = deq.dequant(x + i);
- return scales.prepare1(x + i);
- }
-};
-
-struct Q8_0_Unpacker final : public Q_Unpacker<block_q8_0, ScaleHelperQ_0, Q8_0_Dequantizer> {
- Q8_0_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {}
- using Sum4T = Sum4TypeQ80;
- inline static int block_size() { return QK8_0; }
-};
-struct Q8_0_1_Unpacker final : public Q_Unpacker<block_q8_0, ScaleHelperQ_0_1<127>, Q8_0_1_Dequantizer> {
- Q8_0_1_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {}
- using Sum4T = Sum4TypeQ82;
- inline static int block_size() { return QK8_0; }
-};
-struct Q4_0_Unpacker final : public Q_Unpacker<block_q4_0, ScaleHelperQ_0, Q4_0_Dequantizer> {
- Q4_0_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {}
- using Sum4T = Sum4TypeQ80;
- inline static int block_size() { return QK4_0; }
-};
-struct Q4_0_1_Unpacker final : public Q_Unpacker<block_q4_0, ScaleHelperQ_0_1<8>, Q4_0_1_Dequantizer> {
- Q4_0_1_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {}
- //using Sum4T = Sum4TypeQ82;
- using Sum4T = Sum4q4<block_q8_2, block_q8_2_x4>;
- inline static int block_size() { return QK4_0; }
-};
-#ifdef HAVE_FANCY_SIMD
-struct IQ4_NL_Unpacker final : public Q_Unpacker<block_iq4_nl, ScaleHelperQ_0_1<128>, IQ4_NL_Dequantizer> {
- IQ4_NL_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {}
- using Sum4T = Sum4TypeQ82;
- inline static int block_size() { return QK4_NL; }
-};
-#else
-struct IQ4_NL_Unpacker final : public Q_Unpacker<block_iq4_nl, ScaleHelperQ_0, IQ4_NL_Dequantizer> {
- IQ4_NL_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {}
- using Sum4T = Sum4TypeQ80;
- inline static int block_size() { return QK4_NL; }
-};
-#endif
-struct Q5_0_Unpacker final : public Q_Unpacker<block_q5_0, ScaleHelperQ_0, Q5_0_Dequantizer> {
- Q5_0_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {}
- using Sum4T = Sum4TypeQ80;
- inline static int block_size() { return QK5_0; }
-};
-struct Q5_0_1_Unpacker final : public Q_Unpacker<block_q5_0, ScaleHelperQ_0_1<16>, Q5_1_Dequantizer<block_q5_0>> {
- Q5_0_1_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {}
- using Sum4T = Sum4TypeQ82;
- inline static int block_size() { return QK5_0; }
-};
-struct Q4_1_Unpacker final : public Q_Unpacker<block_q4_1, ScaleHelperQ_1, Q4_1_Dequantizer> {
- Q4_1_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {}
- using Sum4T = Sum4TypeQ82;
- inline static int block_size() { return QK4_1; }
-};
-struct Q5_1_Unpacker final : public Q_Unpacker<block_q5_1, ScaleHelperQ_1, Q5_1_Dequantizer<block_q5_1>> {
- Q5_1_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {}
- using Sum4T = Sum4TypeQ82;
- inline static int block_size() { return QK5_1; }
-};
-struct Q6_0_1_Unpacker final : public Q_Unpacker<block_q6_0, ScaleHelperQ_0_1<32>, Q6_0_1_Dequantizer> {
- Q6_0_1_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {}
- using Sum4T = Sum4TypeQ82;
- inline static int block_size() { return QK6_0; }
-};
-
-// float matrices - we handle f16, bf16 (if native bf16 support is available) and f32, but always with f32 results
-
-struct QFBase {
-#ifdef __AVX512F__
- constexpr static int k_step = 16;
- using Data = __m512;
- using Acc = __m512;
- static inline Data load(const ggml_half * x) { return _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)x)); }
- static inline Data load(const float * x) { return _mm512_loadu_ps(x); }
- static inline Data load(const ggml_bf16_t * x) {
- return _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(_mm256_loadu_si256((const __m256i*)x)), 16));
- }
- static inline Acc acc(Acc prev, const Data& y, const Data& x) {
- return _mm512_fmadd_ps(y, x, prev);
- }
- static inline Acc acc_first(const Data& y, const Data& x) {
- return _mm512_mul_ps(y, x);
- }
- static inline Acc add(Acc x, Acc y) { return _mm512_add_ps(x, y); }
- static inline float hsum(Acc acc) {
- return _mm512_reduce_add_ps(acc);
- }
- template <typename Float>
- static inline Data load4Floats(const Float * x) {
- return _mm512_insertf32x4(_mm512_setzero_ps(), load128(x), 0);
- }
- static inline Acc acc_r4(Acc acc, const Data * xv, const Data& yv) {
- acc = _mm512_fmadd_ps(xv[0], _mm512_shuffle_ps(yv, yv, 0x00), acc);
- acc = _mm512_fmadd_ps(xv[1], _mm512_shuffle_ps(yv, yv, 0x55), acc);
- acc = _mm512_fmadd_ps(xv[2], _mm512_shuffle_ps(yv, yv, 0xaa), acc);
- acc = _mm512_fmadd_ps(xv[3], _mm512_shuffle_ps(yv, yv, 0xff), acc);
- return acc;
- }
- static inline Acc acc_r4_first(const Data * xv, const Data& yv) {
- auto acc = _mm512_mul_ps(xv[0], _mm512_shuffle_ps(yv, yv, 0x00));
- acc = _mm512_fmadd_ps(xv[1], _mm512_shuffle_ps(yv, yv, 0x55), acc);
- acc = _mm512_fmadd_ps(xv[2], _mm512_shuffle_ps(yv, yv, 0xaa), acc);
- acc = _mm512_fmadd_ps(xv[3], _mm512_shuffle_ps(yv, yv, 0xff), acc);
- return acc;
- }
- static inline __m128 hsum_r4(Acc acc) {
- auto sum1 = _mm_add_ps(_mm512_extractf32x4_ps(acc, 0), _mm512_extractf32x4_ps(acc, 1));
- auto sum2 = _mm_add_ps(_mm512_extractf32x4_ps(acc, 2), _mm512_extractf32x4_ps(acc, 3));
- return _mm_add_ps(sum1, sum2);
- }
-#else
- constexpr static int k_step = 8;
- using Data = __m256;
- using Acc = __m256;
- static inline Data load(const ggml_half * x) { return _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)x)); }
- static inline Data load(const float * x) { return _mm256_loadu_ps(x); }
- static inline Data load(const ggml_bf16_t * x) {
- return _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i*)x)), 16));
- }
- static inline Acc acc(Acc prev, const Data& y, const Data& x) {
- return _mm256_fmadd_ps(y, x, prev);
- }
- static inline Acc add(Acc x, Acc y) { return _mm256_add_ps(x, y); }
- static inline Acc acc_r4(Acc acc, const Data * xv, const Data& yv) {
- acc = _mm256_fmadd_ps(xv[0], _mm256_shuffle_ps(yv, yv, 0x00), acc);
- acc = _mm256_fmadd_ps(xv[1], _mm256_shuffle_ps(yv, yv, 0x55), acc);
- acc = _mm256_fmadd_ps(xv[2], _mm256_shuffle_ps(yv, yv, 0xaa), acc);
- acc = _mm256_fmadd_ps(xv[3], _mm256_shuffle_ps(yv, yv, 0xff), acc);
- return acc;
- }
- static inline Acc acc_r4_first(const Data * xv, const Data& yv) {
- auto acc = _mm256_mul_ps(xv[0], _mm256_shuffle_ps(yv, yv, 0x00));
- acc = _mm256_fmadd_ps(xv[1], _mm256_shuffle_ps(yv, yv, 0x55), acc);
- acc = _mm256_fmadd_ps(xv[2], _mm256_shuffle_ps(yv, yv, 0xaa), acc);
- acc = _mm256_fmadd_ps(xv[3], _mm256_shuffle_ps(yv, yv, 0xff), acc);
- return acc;
- }
- static inline Acc acc_first(const Data& y, const Data& x) {
- return _mm256_mul_ps(y, x);
- }
- static inline float hsum(Acc acc) {
- return hsum_float_8(acc);
- }
- static inline __m128 hsum_r4(Acc acc) {
- return _mm_add_ps(_mm256_castps256_ps128(acc), _mm256_extractf128_ps(acc, 1));
- }
- template <typename Float>
- static inline Data load4Floats(const Float * x) {
- return _mm256_insertf128_ps(_mm256_setzero_ps(), load128(x), 0);
- }
-#endif
- static inline __m128 load128(const ggml_half * x) { return _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)x)); }
- static inline __m128 load128(const float * x) { return _mm_loadu_ps(x); }
- static inline __m128 load128(const ggml_bf16_t * x) {
- return _mm_castsi128_ps(_mm_slli_epi32(_mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i*)x)), 16));
- }
-};
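
The ggml_bf16_t load() overloads above rely on bf16 being the upper 16 bits of an IEEE-754 binary32
value, so widening to f32 is just a 16-bit left shift. A scalar sketch of the same conversion, for
reference; illustration only, not part of the patch, and bf16_to_f32_reference is a hypothetical name.

    // hypothetical scalar reference, for illustration only (not in the source)
    static float bf16_to_f32_reference(uint16_t bits) {
        uint32_t u = uint32_t(bits) << 16;               // place the bf16 payload in the high half
        float f;
        std::memcpy(&f, &u, sizeof(f));                  // reinterpret the bit pattern as f32
        return f;
    }
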
-template <typename Float, int nrc_in> struct QFT final : public QFBase {
- constexpr static int nrc = nrc_in;
- QFT(const DataInfo& info) {
- for (int iy = 0; iy < nrc; ++iy) y[iy] = (const Float *)info.src1_row(iy);
- }
- QFT(const char * cx, size_t bx) {
- for (int iy = 0; iy < nrc; ++iy) y[iy] = (const Float *)(cx + iy*bx);
- }
- IQK_ALWAYS_INLINE Data load1(int iy, int i) const { return load(y[iy] + k_step*i); }
- IQK_ALWAYS_INLINE Data load_tail(int iy, int i) const { return load4Floats(y[iy] + 4*i); }
- IQK_ALWAYS_INLINE void load_r4(int ix, int i, Data * xv) const {
- xv[0] = load1(ix+0, i);
- xv[1] = load1(ix+1, i);
- xv[2] = load1(ix+2, i);
- xv[3] = load1(ix+3, i);
-#ifdef __AVX512F__
- auto t0 = _mm512_unpacklo_ps(xv[0], xv[1]);
- auto t1 = _mm512_unpacklo_ps(xv[2], xv[3]);
- auto t2 = _mm512_unpackhi_ps(xv[0], xv[1]);
- auto t3 = _mm512_unpackhi_ps(xv[2], xv[3]);
- xv[0] = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(t0), _mm512_castps_pd(t1)));
- xv[1] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(t0), _mm512_castps_pd(t1)));
- xv[2] = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(t2), _mm512_castps_pd(t3)));
- xv[3] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(t2), _mm512_castps_pd(t3)));
-#else
- auto t0 = _mm256_unpacklo_ps(xv[0], xv[1]);
- auto t1 = _mm256_unpacklo_ps(xv[2], xv[3]);
- auto t2 = _mm256_unpackhi_ps(xv[0], xv[1]);
- auto t3 = _mm256_unpackhi_ps(xv[2], xv[3]);
- xv[0] = _mm256_castpd_ps(_mm256_unpacklo_pd(_mm256_castps_pd(t0), _mm256_castps_pd(t1)));
- xv[1] = _mm256_castpd_ps(_mm256_unpackhi_pd(_mm256_castps_pd(t0), _mm256_castps_pd(t1)));
- xv[2] = _mm256_castpd_ps(_mm256_unpacklo_pd(_mm256_castps_pd(t2), _mm256_castps_pd(t3)));
- xv[3] = _mm256_castpd_ps(_mm256_unpackhi_pd(_mm256_castps_pd(t2), _mm256_castps_pd(t3)));
-#endif
- }
- const Float * y[nrc];
-};
-
-// TBD if we want this
-//template <typename Qy, typename Qx>
-//IQK_NOINLINE void mul_mat_Qx_Qy_Mx1(int n, const char * cx, size_t bx, int ix0, const DataInfo& info) {
-// static_assert(Qy::nrc == 1);
-// int nb = n/QFBase::k_step;
-// int nb4 = n/4;
-// Qy y(info);
-// Qx x(cx + ix0*bx, bx);
-// QFBase::Data xv[2*Qx::nrc];
-// QFBase::Acc acc[2*Qx::nrc];
-// auto yv1 = y.load1(0, 0);
-// auto yv2 = y.load1(0, 1);
-// for (int ix = 0; ix < Qx::nrc; ++ix) {
-// xv[2*ix+0] = x.load1(ix, 0);
-// xv[2*ix+1] = x.load1(ix, 1);
-// acc[2*ix+0] = QFBase::acc_first(yv1, xv[2*ix+0]);
-// acc[2*ix+1] = QFBase::acc_first(yv2, xv[2*ix+1]);
-// }
-// for (int i = 1; i < nb/2; ++i) {
-// yv1 = y.load1(0, 2*i+0);
-// yv2 = y.load1(0, 2*i+1);
-// for (int ix = 0; ix < Qx::nrc; ++ix) {
-// xv[2*ix+0] = x.load1(ix, 2*i+0);
-// xv[2*ix+1] = x.load1(ix, 2*i+1);
-// acc[2*ix+0] = QFBase::acc(acc[2*ix+0], yv1, xv[2*ix+0]);
-// acc[2*ix+1] = QFBase::acc(acc[2*ix+1], yv2, xv[2*ix+1]);
-// }
-// }
-// for (int i = (QFBase::k_step/4)*nb; i < nb4; ++i) {
-// yv1 = y.load_tail(0, i);
-// for (int ix = 0; ix < Qx::nrc; ++ix) {
-// xv[ix] = x.load_tail(ix, i);
-// acc[2*ix+0] = QFBase::acc(acc[2*ix+0], yv1, xv[ix]);
-// }
-// }
-// for (int ix = 0; ix < Qx::nrc; ++ix) info.store(ix0+ix, 0, QFBase::hsum(QFBase::add(acc[2*ix+0], acc[2*ix+1])));
-//}
-
-template <typename Qy, typename Qx>
-IQK_NOINLINE void mul_mat_Qx_Qy_MxN(int n, const char * cx, size_t bx, int ix0, const DataInfo& info) {
- int nb = n/QFBase::k_step;
- int nb4 = n/4;
- Qy y(info);
- Qx x(cx + ix0*bx, bx);
- QFBase::Data xv[Qx::nrc];
- QFBase::Acc acc[Qx::nrc*Qy::nrc];
- auto yv = y.load1(0, 0);
- for (int ix = 0; ix < Qx::nrc; ++ix) {
- xv[ix] = x.load1(ix, 0);
- acc[ix] = QFBase::acc_first(yv, xv[ix]);
- }
- for (int iy = 1; iy < Qy::nrc; ++iy) {
- yv = y.load1(iy, 0);
- for (int ix = 0; ix < Qx::nrc; ++ix) acc[Qx::nrc*iy + ix] = QFBase::acc_first(yv, xv[ix]);
- }
- for (int i = 1; i < nb; ++i) {
- yv = y.load1(0, i);
- for (int ix = 0; ix < Qx::nrc; ++ix) {
- xv[ix] = x.load1(ix, i);
- acc[ix] = QFBase::acc(acc[ix], yv, xv[ix]);
- }
- for (int iy = 1; iy < Qy::nrc; ++iy) {
- yv = y.load1(iy, i);
- for (int ix = 0; ix < Qx::nrc; ++ix) acc[Qx::nrc*iy + ix] = QFBase::acc(acc[Qx::nrc*iy + ix], yv, xv[ix]);
- }
- }
- for (int i = (QFBase::k_step/4)*nb; i < nb4; ++i) {
- yv = y.load_tail(0, i);
- for (int ix = 0; ix < Qx::nrc; ++ix) {
- xv[ix] = x.load_tail(ix, i);
- acc[ix] = QFBase::acc(acc[ix], yv, xv[ix]);
- }
- for (int iy = 1; iy < Qy::nrc; ++iy) {
- yv = y.load_tail(iy, i);
- for (int ix = 0; ix < Qx::nrc; ++ix) acc[Qx::nrc*iy + ix] = QFBase::acc(acc[Qx::nrc*iy + ix], yv, xv[ix]);
- }
- }
- for (int iy = 0; iy < Qy::nrc; ++iy) for (int ix = 0; ix < Qx::nrc; ++ix) info.store(ix0+ix, iy, QFBase::hsum(acc[Qx::nrc*iy+ix]));
-}
-
-template <typename Qy, typename Qx>
-inline void mul_mat_Qx_Qy_MxN_fa(int n, const char * cx, size_t bx, int ix0, const DataInfo& info) {
- int nb = n/QFBase::k_step;
- Qy y(info);
- Qx x(cx + ix0*bx, bx);
- QFBase::Data xv[Qx::nrc];
- QFBase::Acc acc[Qx::nrc*Qy::nrc];
- auto yv = y.load1(0, 0);
- for (int ix = 0; ix < Qx::nrc; ++ix) {
- xv[ix] = x.load1(ix, 0);
- acc[ix] = QFBase::acc_first(yv, xv[ix]);
- }
- for (int iy = 1; iy < Qy::nrc; ++iy) {
- yv = y.load1(iy, 0);
- for (int ix = 0; ix < Qx::nrc; ++ix) acc[Qx::nrc*iy + ix] = QFBase::acc_first(yv, xv[ix]);
- }
- for (int i = 1; i < nb; ++i) {
- yv = y.load1(0, i);
- for (int ix = 0; ix < Qx::nrc; ++ix) {
- xv[ix] = x.load1(ix, i);
- acc[ix] = QFBase::acc(acc[ix], yv, xv[ix]);
- }
- for (int iy = 1; iy < Qy::nrc; ++iy) {
- yv = y.load1(iy, i);
- for (int ix = 0; ix < Qx::nrc; ++ix) acc[Qx::nrc*iy + ix] = QFBase::acc(acc[Qx::nrc*iy + ix], yv, xv[ix]);
- }
- }
- for (int iy = 0; iy < Qy::nrc; ++iy) for (int ix = 0; ix < Qx::nrc; ++ix) info.store(ix0+ix, iy, QFBase::hsum(acc[Qx::nrc*iy+ix]));
-}
-
-template <typename Qy, typename Qx>
-inline void mul_mat_Qx_Qy_MxN_fa4(int D, const char * cx, size_t bx, int ix0, const DataInfo& info) {
- static_assert(Qx::nrc%4 == 0);
- int nb = D/QFBase::k_step;
- Qy y(info);
- Qx x(cx + ix0*bx, bx);
- QFBase::Data xv[Qx::nrc];
- QFBase::Acc acc[Qx::nrc*Qy::nrc/4] = {};
- for (int i = 0; i < nb; ++i) {
- for (int ix = 0; ix < Qx::nrc/4; ++ix) x.load_r4(4*ix, i, xv + 4*ix);
- for (int iy = 0; iy < Qy::nrc; ++iy) {
- auto yv = y.load1(iy, i);
- for (int ix = 0; ix < Qx::nrc/4; ++ix) acc[ix*Qy::nrc + iy] = QFBase::acc_r4(acc[ix*Qy::nrc + iy], xv + 4*ix, yv);
- }
- }
- for (int iy = 0; iy < Qy::nrc; ++iy) {
- for (int ix = 0; ix < Qx::nrc/4; ++ix) info.store(ix0+4*ix, iy, QFBase::hsum_r4(acc[ix*Qy::nrc + iy]));
- }
-}
-
-// This will handle any of f16 x f32, f32 x f16, f16 x f16, f32 x f32, with computations done
-// in f32 (i.e., f16 is first converted to f32). It would be easy to extend to computations done in
-// f16, but I don't have a CPU capable of f16 vector arithmetic, so I'm not doing it for now.
-template <int nrc_y, typename FloatX, typename FloatY>
-void mul_mat_fX_fY_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
- const char * cx = (const char *)vx;
- // TBD if we want this
- //if constexpr (nrc_y == 1) {
- // constexpr int k_nx = 2;
- // for (int ix = 0; ix < nrc_x/k_nx; ++ix) {
- // mul_mat_Qx_Qy_Mx1<QFT<FloatY, nrc_y>, QFT<FloatX, k_nx>>(n, cx, bx, ix*k_nx, info);
- // }
- // if (int lastx = k_nx*(nrc_x/k_nx); lastx < nrc_x) {
- // int nx = nrc_x - lastx;
- // switch (nx) {
- // case 1: mul_mat_Qx_Qy_Mx1<QFT<FloatY, nrc_y>, QFT<FloatX, 1>>(n, cx, bx, lastx, info); break;
- // case 2: mul_mat_Qx_Qy_Mx1<QFT<FloatY, nrc_y>, QFT<FloatX, 2>>(n, cx, bx, lastx, info); break;
- // case 3: mul_mat_Qx_Qy_Mx1<QFT<FloatY, nrc_y>, QFT<FloatX, 3>>(n, cx, bx, lastx, info); break;
- // }
- // //mul_mat_Qx_Qy_Mx1<QFT<FloatY, nrc_y>, QFT<FloatX, 1>>(n, cx, bx, lastx, info);
- // }
- // return;
- //}
-#ifdef __AVX512F__
- constexpr int k_nx = 5;
-#else
- constexpr int k_nx = nrc_y == 1 ? 4 : 2;
-#endif
- for (int ix = 0; ix < nrc_x/k_nx; ++ix) {
- mul_mat_Qx_Qy_MxN<QFT<FloatY, nrc_y>, QFT<FloatX, k_nx>>(n, cx, bx, ix*k_nx, info);
- }
- int last_x = k_nx*(nrc_x/k_nx);
- if (last_x == nrc_x) return;
- int nx = nrc_x - last_x;
-#ifdef __AVX512F__
- switch (nx) {
- case 1: mul_mat_Qx_Qy_MxN<QFT<FloatY, nrc_y>, QFT<FloatX, 1>>(n, cx, bx, last_x, info); break;
- case 2: mul_mat_Qx_Qy_MxN<QFT<FloatY, nrc_y>, QFT<FloatX, 2>>(n, cx, bx, last_x, info); break;
- case 3: mul_mat_Qx_Qy_MxN<QFT<FloatY, nrc_y>, QFT<FloatX, 3>>(n, cx, bx, last_x, info); break;
- case 4: mul_mat_Qx_Qy_MxN<QFT<FloatY, nrc_y>, QFT<FloatX, 4>>(n, cx, bx, last_x, info); break;
- }
-#else
- if constexpr (nrc_y == 1) {
- switch (nx) {
- case 1: mul_mat_Qx_Qy_MxN<QFT<FloatY, nrc_y>, QFT<FloatX, 1>>(n, cx, bx, last_x, info); break;
- case 2: mul_mat_Qx_Qy_MxN<QFT<FloatY, nrc_y>, QFT<FloatX, 2>>(n, cx, bx, last_x, info); break;
- case 3: mul_mat_Qx_Qy_MxN<QFT<FloatY, nrc_y>, QFT<FloatX, 3>>(n, cx, bx, last_x, info); break;
- }
- } else {
- switch (nx) {
- case 1: mul_mat_Qx_Qy_MxN<QFT<FloatY, nrc_y>, QFT<FloatX, 1>>(n, cx, bx, last_x, info); break;
- }
- }
-#endif
-}
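
A scalar sketch of what the mixed-type float GEMM above computes for a single output element, shown
for the f16 x f32 combination: the f16 operand is widened with GGML_FP16_TO_FP32 and all arithmetic
is done in f32, as the comment before the template says. Illustration only, not part of the patch;
dot_f16_f32_reference is a hypothetical name.

    // hypothetical scalar reference, for illustration only (not in the source)
    static float dot_f16_f32_reference(const ggml_half * x, const float * y, int n) {
        float acc = 0.0f;
        for (int i = 0; i < n; ++i) acc += GGML_FP16_TO_FP32(x[i]) * y[i];   // widen, then accumulate in f32
        return acc;
    }
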
-
-#ifdef __AVX512BF16__
-struct QFBaseBF16 {
- constexpr static int k_step = 32;
- using Data = __m512bh;
- using Acc = __m512;
- static inline Data load(const ggml_bf16_t * x) { return __m512bh(_mm512_loadu_si512((const __m512i *)x)); }
- //static inline Acc acc(Acc prev, const Data& y, const Data& x) {
- static inline Acc acc(Acc prev, Data y, Data x) {
- return _mm512_dpbf16_ps(prev, y, x);
- }
- static inline Acc acc_first(const Data& y, const Data& x) {
- return _mm512_dpbf16_ps(_mm512_setzero_ps(), y, x);
- }
- static inline float hsum(Acc acc) {
- return _mm512_reduce_add_ps(acc);
- }
-};
-template <int nrc_in> struct QFTBF16 final : public QFBaseBF16 {
- constexpr static int nrc = nrc_in;
- QFTBF16(const DataInfo& info) {
- for (int iy = 0; iy < nrc; ++iy) y[iy] = (const ggml_bf16_t *)info.src1_row(iy);
- }
- QFTBF16(const char * cx, size_t bx) {
- for (int iy = 0; iy < nrc; ++iy) y[iy] = (const ggml_bf16_t *)(cx + iy*bx);
- }
- IQK_ALWAYS_INLINE Data load1(int iy, int i) const { return load(y[iy] + k_step*i); }
- const ggml_bf16_t * y[nrc];
-};
-
-template <int nrc_y, int nrc_x>
-IQK_NOINLINE void mul_mat_Qx_Qy_MxN(int n, const char * cx, size_t bx, int ix0, const DataInfo& info) {
- int nb = n/QFBaseBF16::k_step;
- QFTBF16<nrc_y> y(info);
- QFTBF16<nrc_x> x(cx + ix0*bx, bx);
- QFBaseBF16::Data xv[nrc_x];
- QFBaseBF16::Acc acc[nrc_x*nrc_y];
- auto yv = y.load1(0, 0);
- for (int ix = 0; ix < nrc_x; ++ix) {
- xv[ix] = x.load1(ix, 0);
- acc[ix] = QFBaseBF16::acc_first(yv, xv[ix]);
- }
- for (int iy = 1; iy < nrc_y; ++iy) {
- yv = y.load1(iy, 0);
- for (int ix = 0; ix < nrc_x; ++ix) acc[nrc_x*iy + ix] = QFBaseBF16::acc_first(yv, xv[ix]);
- }
- for (int i = 1; i < nb; ++i) {
- yv = y.load1(0, i);
- for (int ix = 0; ix < nrc_x; ++ix) {
- xv[ix] = x.load1(ix, i);
- acc[ix] = QFBaseBF16::acc(acc[ix], yv, xv[ix]);
- }
- for (int iy = 1; iy < nrc_y; ++iy) {
- yv = y.load1(iy, i);
- for (int ix = 0; ix < nrc_x; ++ix) acc[nrc_x*iy + ix] = QFBaseBF16::acc(acc[nrc_x*iy + ix], yv, xv[ix]);
- }
- }
- for (int iy = 0; iy < nrc_y; ++iy) for (int ix = 0; ix < nrc_x; ++ix) info.store(ix0+ix, iy, QFBaseBF16::hsum(acc[nrc_x*iy+ix]));
-}
-
-template <int nrc_y>
-void mul_mat_fX_fY_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
- constexpr int k_nx = nrc_y <= 2 ? 8 : 5;
- const char * cx = (const char *)vx;
- for (int ix = 0; ix < nrc_x/k_nx; ++ix) {
- mul_mat_Qx_Qy_MxN<nrc_y, k_nx>(n, cx, bx, ix*k_nx, info);
- }
- int last_x = k_nx*(nrc_x/k_nx);
- if (last_x == nrc_x) return;
- int nx = nrc_x - last_x;
- if constexpr (nrc_y <= 2) {
- if (nx >= 4) {
- mul_mat_Qx_Qy_MxN<nrc_y, 4>(n, cx, bx, last_x, info);
- last_x += 4;
- if (last_x == nrc_x) return;
- nx = nrc_x - last_x;
- }
- }
- switch (nx) {
- case 1: mul_mat_Qx_Qy_MxN<nrc_y, 1>(n, cx, bx, last_x, info); break;
- case 2: mul_mat_Qx_Qy_MxN<nrc_y, 2>(n, cx, bx, last_x, info); break;
- case 3: mul_mat_Qx_Qy_MxN<nrc_y, 3>(n, cx, bx, last_x, info); break;
- case 4: mul_mat_Qx_Qy_MxN<nrc_y, 4>(n, cx, bx, last_x, info); break;
- }
-}
-#endif
-
-//
-// Tiled Q8_0 x Q8_0 implementation. Not used, as the templated legacy quant implementation
-// above is faster. Left behind so we remember we tried.
-//
-template <int nrc> struct Q80 {
- constexpr static int nrc_y = nrc;
- Q80(const DataInfo& info) {
- for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const block_q8_0 *)info.src1_row(iy);
- }
- IQK_ALWAYS_INLINE __m256i load1(int iy, int i) const { return _mm256_loadu_si256((const __m256i *)y[iy][i].qs); }
- IQK_ALWAYS_INLINE float scale(int iy, int i) const { return GGML_FP16_TO_FP32(y[iy][i].d); }
-
- const block_q8_0 * y[nrc_y];
-};
-inline __m256i mul_q80(__m256i x, __m256i y) {
- auto ux = _mm256_sign_epi8(x, x);
-#ifdef HAVE_FANCY_SIMD
- return _mm256_dpbusd_epi32(_mm256_setzero_si256(), ux, _mm256_sign_epi8(y, x));
-#else
- return _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_maddubs_epi16(ux, _mm256_sign_epi8(y, x)));
-#endif
-}
-template <int nrc_y>
-void mul_mat_q80_q80_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
- assert(n%QK8_0 == 0);
- constexpr int k_nx = 4;
- int nb = n/QK8_0;
- Q80<nrc_y> q8(info);
- const block_q8_0 * x[k_nx];
- float ds[k_nx];
- __m256 acc[k_nx*nrc_y];
- __m256i xv[k_nx];
- for (int ix = 0; ix < nrc_x/k_nx; ++ix) {
- int ix0 = k_nx*ix;
- for (int kx = 0; kx < k_nx; ++kx) {
- x[kx] = (const block_q8_0 *)((const char *)vx + (ix0 + kx)*bx);
- ds[kx] = GGML_FP16_TO_FP32(x[kx][0].d);
- xv[kx] = _mm256_loadu_si256((const __m256i *)x[kx][0].qs);
- }
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto yv = q8.load1(iy, 0);
- float d = q8.scale(iy, 0);
- for (int kx = 0; kx < k_nx; ++kx) {
- auto dot = mul_q80(yv, xv[kx]);
- acc[k_nx*iy + kx] = _mm256_mul_ps(_mm256_set1_ps(ds[kx]*d), _mm256_cvtepi32_ps(dot));
- }
- }
- for (int i = 1; i < nb; ++i) {
- for (int kx = 0; kx < k_nx; ++kx) {
- ds[kx] = GGML_FP16_TO_FP32(x[kx][i].d);
- xv[kx] = _mm256_loadu_si256((const __m256i *)x[kx][i].qs);
- }
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto yv = q8.load1(iy, i);
- float d = q8.scale(iy, i);
- for (int kx = 0; kx < k_nx; ++kx) {
- auto dot = mul_q80(yv, xv[kx]);
- acc[k_nx*iy + kx] = _mm256_fmadd_ps(_mm256_set1_ps(ds[kx]*d), _mm256_cvtepi32_ps(dot), acc[k_nx*iy + kx]);
- }
- }
- }
- for (int iy = 0; iy < nrc_y; ++iy) {
- for (int kx = 0; kx < k_nx; ++kx) info.store(ix0+kx, iy, hsum_float_8(acc[k_nx*iy+kx]));
- }
- }
- int last_x = k_nx*(nrc_x/k_nx);
- if (last_x == nrc_x) return;
- // TODO: handle remaining rows
-}
-
-template <typename Dequantizer> void MulMat::set_functions(MulMat& m) {
- if constexpr (std::is_same_v<Dequantizer, Q4_0_Unpacker> || std::is_same_v<Dequantizer, Q5_0_Unpacker> ||
- std::is_same_v<Dequantizer, Q8_0_Unpacker>) {
- m.funcs[0] = mul_mat_qX_0_q8_0_T<Dequantizer, 1>;
- m.funcs[1] = mul_mat_qX_0_q8_0_T<Dequantizer, 2>;
- m.funcs[2] = mul_mat_qX_0_q8_0_T<Dequantizer, 3>;
- m.funcs[3] = mul_mat_qX_0_q8_0_T<Dequantizer, 4>;
- m.funcs[4] = mul_mat_qX_0_q8_0_T<Dequantizer, 5>;
- m.funcs[5] = mul_mat_qX_0_q8_0_T<Dequantizer, 6>;
- m.funcs[6] = mul_mat_qX_0_q8_0_T<Dequantizer, 7>;
- m.funcs[7] = mul_mat_qX_0_q8_0_T<Dequantizer, 8>;
- }
- else if constexpr (std::is_same_v<Dequantizer, Q4_1_Unpacker> || std::is_same_v<Dequantizer, Q5_1_Unpacker>) {
- m.funcs[0] = mul_mat_qX_1_q8_2_T<Dequantizer, 1>;
- m.funcs[1] = mul_mat_qX_1_q8_2_T<Dequantizer, 2>;
- m.funcs[2] = mul_mat_qX_1_q8_2_T<Dequantizer, 3>;
- m.funcs[3] = mul_mat_qX_1_q8_2_T<Dequantizer, 4>;
- m.funcs[4] = mul_mat_qX_1_q8_2_T<Dequantizer, 5>;
- m.funcs[5] = mul_mat_qX_1_q8_2_T<Dequantizer, 6>;
- m.funcs[6] = mul_mat_qX_1_q8_2_T<Dequantizer, 7>;
- m.funcs[7] = mul_mat_qX_1_q8_2_T<Dequantizer, 8>;
- }
- else if constexpr (std::is_same_v<Dequantizer, IQ4_NL_Unpacker>) {
-#ifdef HAVE_FANCY_SIMD
- m.funcs[0] = mul_mat_qX_1_q8_2_T<Dequantizer, 1>;
- m.funcs[1] = mul_mat_qX_1_q8_2_T<Dequantizer, 2>;
- m.funcs[2] = mul_mat_qX_1_q8_2_T<Dequantizer, 3>;
- m.funcs[3] = mul_mat_qX_1_q8_2_T<Dequantizer, 4>;
- m.funcs[4] = mul_mat_qX_1_q8_2_T<Dequantizer, 5>;
- m.funcs[5] = mul_mat_qX_1_q8_2_T<Dequantizer, 6>;
- m.funcs[6] = mul_mat_qX_1_q8_2_T<Dequantizer, 7>;
- m.funcs[7] = mul_mat_qX_1_q8_2_T<Dequantizer, 8>;
-#else
- m.funcs[0] = mul_mat_qX_0_q8_0_T<Dequantizer, 1>;
- m.funcs[1] = mul_mat_qX_0_q8_0_T<Dequantizer, 2>;
- m.funcs[2] = mul_mat_qX_0_q8_0_T<Dequantizer, 3>;
- m.funcs[3] = mul_mat_qX_0_q8_0_T<Dequantizer, 4>;
- m.funcs[4] = mul_mat_qX_0_q8_0_T<Dequantizer, 5>;
- m.funcs[5] = mul_mat_qX_0_q8_0_T<Dequantizer, 6>;
- m.funcs[6] = mul_mat_qX_0_q8_0_T<Dequantizer, 7>;
- m.funcs[7] = mul_mat_qX_0_q8_0_T<Dequantizer, 8>;
-#endif
- }
- else if constexpr (std::is_same_v<Dequantizer, Q8_0_1_Unpacker> || std::is_same_v<Dequantizer, Q4_0_1_Unpacker> ||
- std::is_same_v<Dequantizer, Q5_0_1_Unpacker> || std::is_same_v<Dequantizer, Q6_0_1_Unpacker>) {
- m.funcs[0] = mul_mat_qX_1_q8_2_T<Dequantizer, 1>;
- m.funcs[1] = mul_mat_qX_1_q8_2_T<Dequantizer, 2>;
- m.funcs[2] = mul_mat_qX_1_q8_2_T<Dequantizer, 3>;
- m.funcs[3] = mul_mat_qX_1_q8_2_T<Dequantizer, 4>;
- m.funcs[4] = mul_mat_qX_1_q8_2_T<Dequantizer, 5>;
- m.funcs[5] = mul_mat_qX_1_q8_2_T<Dequantizer, 6>;
- m.funcs[6] = mul_mat_qX_1_q8_2_T<Dequantizer, 7>;
- m.funcs[7] = mul_mat_qX_1_q8_2_T<Dequantizer, 8>;
- }
- else if constexpr (std::is_same_v<Dequantizer, DequantizerIQ3S> || std::is_same_v<Dequantizer, DequantizerIQ3XXS> ||
- std::is_same_v<Dequantizer, DequantizerIQ2S> || std::is_same_v<Dequantizer, DequantizerIQ2XS> ||
- std::is_same_v<Dequantizer, DequantizerIQ2XXS>) {
- m.funcs[0] = mul_mat_qX_K_q8_K_IQ<Dequantizer, 1>;
- m.funcs[1] = mul_mat_qX_K_q8_K_IQ<Dequantizer, 2>;
- m.funcs[2] = mul_mat_qX_K_q8_K_IQ<Dequantizer, 3>;
- m.funcs[3] = mul_mat_qX_K_q8_K_IQ<Dequantizer, 4>;
- m.funcs[4] = mul_mat_qX_K_q8_K_IQ<Dequantizer, 5>;
- m.funcs[5] = mul_mat_qX_K_q8_K_IQ<Dequantizer, 6>;
- m.funcs[6] = mul_mat_qX_K_q8_K_IQ<Dequantizer, 7>;
- m.funcs[7] = mul_mat_qX_K_q8_K_IQ<Dequantizer, 8>;
- }
- else {
-#ifdef HAVE_FANCY_SIMD
- if constexpr (std::is_same_v<Dequantizer, DequantizerIQ6K> ||
- std::is_same_v<Dequantizer, DequantizerIQ5K> ||
- std::is_same_v<Dequantizer, DequantizerIQ4K> ||
- std::is_same_v<Dequantizer, DequantizerIQ3K> ||
- std::is_same_v<Dequantizer, DequantizerIQ4XS>||
- //std::is_same_v<Dequantizer, DequantizerIQ4KS>||
- //std::is_same_v<Dequantizer, DequantizerIQ5KS>||
- std::is_same_v<Dequantizer, DequantizerIQ4KSS>) {
- m.funcs[0] = mul_mat_iqX_k_q8_K_AVX512<Dequantizer, 1>;
- m.funcs[1] = mul_mat_iqX_k_q8_K_AVX512<Dequantizer, 2>;
- m.funcs[2] = mul_mat_iqX_k_q8_K_AVX512<Dequantizer, 3>;
- m.funcs[3] = mul_mat_iqX_k_q8_K_AVX512<Dequantizer, 4>;
- m.funcs[4] = mul_mat_iqX_k_q8_K_AVX512<Dequantizer, 5>;
- m.funcs[5] = mul_mat_iqX_k_q8_K_AVX512<Dequantizer, 6>;
- m.funcs[6] = mul_mat_iqX_k_q8_K_AVX512<Dequantizer, 7>;
- m.funcs[7] = mul_mat_iqX_k_q8_K_AVX512<Dequantizer, 8>;
- } else if constexpr (std::is_same_v<Dequantizer, DequantizerIQ2KS> ||
- std::is_same_v<Dequantizer, DequantizerIQ4KS> ||
- std::is_same_v<Dequantizer, DequantizerIQ5KS>) {
- m.funcs[0] = mul_mat_iqX_k_q8_K_AVX512_new<Dequantizer, 1>;
- m.funcs[1] = mul_mat_iqX_k_q8_K_AVX512_new<Dequantizer, 2>;
- m.funcs[2] = mul_mat_iqX_k_q8_K_AVX512_new<Dequantizer, 3>;
- m.funcs[3] = mul_mat_iqX_k_q8_K_AVX512_new<Dequantizer, 4>;
- m.funcs[4] = mul_mat_iqX_k_q8_K_AVX512_new<Dequantizer, 5>;
- m.funcs[5] = mul_mat_iqX_k_q8_K_AVX512_new<Dequantizer, 6>;
- m.funcs[6] = mul_mat_iqX_k_q8_K_AVX512_new<Dequantizer, 7>;
- m.funcs[7] = mul_mat_iqX_k_q8_K_AVX512_new<Dequantizer, 8>;
- } else {
- m.funcs[0] = mul_mat_qX_K_q8_K_AVX512_1<Dequantizer>;
- m.funcs[1] = mul_mat_qX_K_q8_K_AVX512<Dequantizer, 2>;
- m.funcs[2] = mul_mat_qX_K_q8_K_AVX512<Dequantizer, 3>;
- m.funcs[3] = mul_mat_qX_K_q8_K_AVX512<Dequantizer, 4>;
- m.funcs[4] = mul_mat_qX_K_q8_K_AVX512<Dequantizer, 5>;
- m.funcs[5] = mul_mat_qX_K_q8_K_AVX512<Dequantizer, 6>;
- m.funcs[6] = mul_mat_qX_K_q8_K_AVX512<Dequantizer, 7>;
- m.funcs[7] = mul_mat_qX_K_q8_K_AVX512<Dequantizer, 8>;
- }
-#else
- if constexpr (std::is_same_v<Dequantizer, DequantizerQ2K> ||
- std::is_same_v<Dequantizer, DequantizerQ3K> ||
- std::is_same_v<Dequantizer, DequantizerQ6K> ||
- std::is_same_v<Dequantizer, DequantizerIQ2K>||
- std::is_same_v<Dequantizer, DequantizerIQ3K>||
- std::is_same_v<Dequantizer, DequantizerIQ4K>||
- std::is_same_v<Dequantizer, DequantizerIQ5K>||
- std::is_same_v<Dequantizer, DequantizerIQ6K>) {
- m.funcs[0] = mul_mat_qY_K_q8_K_T<Dequantizer, 1>;
- m.funcs[1] = mul_mat_qY_K_q8_K_T<Dequantizer, 2>;
- m.funcs[2] = mul_mat_qY_K_q8_K_T<Dequantizer, 3>;
- m.funcs[3] = mul_mat_qY_K_q8_K_T<Dequantizer, 4>;
- m.funcs[4] = mul_mat_qY_K_q8_K_T<Dequantizer, 5>;
- m.funcs[5] = mul_mat_qY_K_q8_K_T<Dequantizer, 6>;
- m.funcs[6] = mul_mat_qY_K_q8_K_T<Dequantizer, 7>;
- m.funcs[7] = mul_mat_qY_K_q8_K_T<Dequantizer, 8>;
- } else {
- m.funcs[0] = mul_mat_qX_K_q8_K_T<Dequantizer, 1>;
- m.funcs[1] = mul_mat_qX_K_q8_K_T<Dequantizer, 2>;
- m.funcs[2] = mul_mat_qX_K_q8_K_T<Dequantizer, 3>;
- m.funcs[3] = mul_mat_qX_K_q8_K_T<Dequantizer, 4>;
- m.funcs[4] = mul_mat_qX_K_q8_K_T<Dequantizer, 5>;
- m.funcs[5] = mul_mat_qX_K_q8_K_T<Dequantizer, 6>;
- m.funcs[6] = mul_mat_qX_K_q8_K_T<Dequantizer, 7>;
- m.funcs[7] = mul_mat_qX_K_q8_K_T<Dequantizer, 8>;
- }
-#endif
- }
-}
-
-template <typename FloatX, typename FloatY>
-void set_mul_mat_f(MulMat& mm) {
- for (auto& f : mm.funcs) f = nullptr;
- mm.funcs[0] = mul_mat_fX_fY_T<1, FloatX, FloatY>;
- mm.funcs[1] = mul_mat_fX_fY_T<2, FloatX, FloatY>;
- mm.funcs[2] = mul_mat_fX_fY_T<3, FloatX, FloatY>;
- mm.funcs[3] = mul_mat_fX_fY_T<4, FloatX, FloatY>;
- mm.funcs[4] = mul_mat_fX_fY_T<5, FloatX, FloatY>;
-#ifndef __AVX512F__
- mm.funcs[5] = mul_mat_fX_fY_T<6, FloatX, FloatY>;
-#endif
-}
-
-#ifdef __AVX512BF16__
-void set_mul_mat_bf16(MulMat& mm) {
- for (auto& f : mm.funcs) f = nullptr;
- mm.funcs[0] = mul_mat_fX_fY_T<1>;
- mm.funcs[1] = mul_mat_fX_fY_T<2>;
- mm.funcs[2] = mul_mat_fX_fY_T<3>;
- mm.funcs[3] = mul_mat_fX_fY_T<4>;
- mm.funcs[4] = mul_mat_fX_fY_T<5>;
-}
-void set_mul_mat_bf16_r16(MulMat& mm) {
- for (auto& f : mm.funcs) f = nullptr;
- mm.funcs[0] = mul_mat_bf16_r16_bf16<1>;
- mm.funcs[1] = mul_mat_bf16_r16_bf16<2>;
- mm.funcs[2] = mul_mat_bf16_r16_bf16<3>;
- mm.funcs[3] = mul_mat_bf16_r16_bf16<4>;
- mm.funcs[4] = mul_mat_bf16_r16_bf16<5>;
- mm.funcs[5] = mul_mat_bf16_r16_bf16<6>;
- mm.funcs[6] = mul_mat_bf16_r16_bf16<7>;
- mm.funcs[7] = mul_mat_bf16_r16_bf16<8>;
-}
-#endif
-
bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) {
(void)Ny;
- if (typeA == GGML_TYPE_BF16) {
- if (ne00 % 32) return false;
- switch (typeB) {
-#ifdef __AVX512BF16__
- case GGML_TYPE_BF16: set_mul_mat_bf16(mm); break;
-#else
- case GGML_TYPE_BF16: set_mul_mat_f<ggml_bf16_t, ggml_bf16_t>(mm); break;
- case GGML_TYPE_F32: set_mul_mat_f<ggml_bf16_t, float>(mm); break;
-#endif
- default: return false;
- }
- return true;
- }
-
- if (typeA == GGML_TYPE_BF16_R16) {
- if (ne00 % 16) return false;
- switch (typeB) {
-#ifdef __AVX512BF16__
- case GGML_TYPE_BF16: set_mul_mat_bf16_r16(mm); break;
-#endif
- default: return false;
- }
- return true;
- }
-
- if (typeA == GGML_TYPE_F16 || typeA == GGML_TYPE_F32) {
- if (ne00 % 4) return false;
- }
- if (typeA == GGML_TYPE_F16) {
- switch (typeB) {
- case GGML_TYPE_F16: set_mul_mat_f<ggml_half, ggml_half>(mm); break;
- case GGML_TYPE_F32: set_mul_mat_f<ggml_half, float>(mm); break;
- default: return false;
- }
- return true;
- }
- if (typeA == GGML_TYPE_F32) {
- switch (typeB) {
- case GGML_TYPE_F16: set_mul_mat_f<float, ggml_half>(mm); break;
- case GGML_TYPE_F32: set_mul_mat_f<float, float>(mm); break;
- default: return false;
- }
- return true;
- }
-
- auto expected_typeB = GGML_TYPE_Q8_K;
-
switch (typeA) {
+ case GGML_TYPE_F16:
+ case GGML_TYPE_F32:
+ case GGML_TYPE_BF16:
+ case GGML_TYPE_BF16_R16:
+ return iqk_set_kernels_float(ne00, typeA, typeB, mm.funcs);
case GGML_TYPE_Q2_K:
- assert (ne00 % QK_K == 0);
- MulMat::set_functions<DequantizerQ2K>(mm);
- break;
case GGML_TYPE_Q3_K:
- assert (ne00 % QK_K == 0);
- MulMat::set_functions<DequantizerQ3K>(mm);
- break;
case GGML_TYPE_Q4_K:
- assert (ne00 % QK_K == 0);
- MulMat::set_functions<DequantizerQ4K>(mm);
- break;
case GGML_TYPE_Q5_K:
- assert (ne00 % QK_K == 0);
- MulMat::set_functions<DequantizerQ5K>(mm);
- break;
case GGML_TYPE_Q6_K:
- assert (ne00 % QK_K == 0);
- MulMat::set_functions<DequantizerQ6K>(mm);
- break;
case GGML_TYPE_IQ4_XS:
- assert (ne00 % QK_K == 0);
- MulMat::set_functions<DequantizerIQ4XS>(mm);
- break;
+ case GGML_TYPE_Q2_K_R4:
+ case GGML_TYPE_Q3_K_R4:
+ case GGML_TYPE_Q4_K_R4:
+ case GGML_TYPE_Q5_K_R4:
+ case GGML_TYPE_Q6_K_R4:
+ case GGML_TYPE_IQ4_XS_R8:
+ case GGML_TYPE_Q8_K_R8:
+ case GGML_TYPE_Q8_KV:
+ case GGML_TYPE_Q8_KV_R8:
+ return iqk_set_kernels_kquants(ne00, typeA, typeB, mm.funcs, mm.func16);
+ case GGML_TYPE_IQ2_XXS:
+ case GGML_TYPE_IQ2_XS:
+ case GGML_TYPE_IQ2_S:
+ case GGML_TYPE_IQ3_XXS:
+ case GGML_TYPE_IQ3_S:
+ case GGML_TYPE_IQ2_XXS_R4:
+ case GGML_TYPE_IQ2_XS_R4:
+ case GGML_TYPE_IQ2_S_R4:
+ case GGML_TYPE_IQ3_XXS_R4:
+ case GGML_TYPE_IQ3_S_R4:
+ return ggml_type(typeB) == GGML_TYPE_Q8_K ? iqk_set_kernels_iquants(ne00, typeA, typeB, mm.funcs, mm.func16) : false;
case GGML_TYPE_IQ4_KS:
- assert (ne00 % QK_K == 0);
- MulMat::set_functions<DequantizerIQ4KS>(mm);
- break;
case GGML_TYPE_IQ5_KS:
- assert (ne00 % QK_K == 0);
- MulMat::set_functions<DequantizerIQ5KS>(mm);
- break;
case GGML_TYPE_IQ4_KSS:
- assert (ne00 % QK_K == 0);
- MulMat::set_functions<DequantizerIQ4KSS>(mm);
- break;
case GGML_TYPE_IQ2_K:
- assert (ne00 % QK_K == 0);
- MulMat::set_functions<DequantizerIQ2K>(mm);
- break;
case GGML_TYPE_IQ2_KS:
- assert (ne00 % QK_K == 0);
- MulMat::set_functions<DequantizerIQ2KS>(mm);
- break;
case GGML_TYPE_IQ3_K:
- assert (ne00 % QK_K == 0);
- MulMat::set_functions<DequantizerIQ3K>(mm);
- break;
case GGML_TYPE_IQ4_K:
- assert (ne00 % QK_K == 0);
- MulMat::set_functions<DequantizerIQ4K>(mm);
- break;
case GGML_TYPE_IQ5_K:
- assert (ne00 % QK_K == 0);
- MulMat::set_functions<DequantizerIQ5K>(mm);
- break;
case GGML_TYPE_IQ6_K:
- assert (ne00 % QK_K == 0);
- MulMat::set_functions<DequantizerIQ6K>(mm);
- break;
- case GGML_TYPE_IQ3_S:
- assert (ne00 % QK_K == 0);
- MulMat::set_functions<DequantizerIQ3S>(mm);
- break;
- case GGML_TYPE_IQ3_XXS:
- assert (ne00 % QK_K == 0);
- MulMat::set_functions<DequantizerIQ3XXS>(mm);
- break;
- case GGML_TYPE_IQ2_S:
- assert (ne00 % QK_K == 0);
- MulMat::set_functions<DequantizerIQ2S>(mm);
- break;
- case GGML_TYPE_IQ2_XS:
- assert (ne00 % QK_K == 0);
- MulMat::set_functions<DequantizerIQ2XS>(mm);
- break;
- case GGML_TYPE_IQ2_XXS:
- assert (ne00 % QK_K == 0);
- MulMat::set_functions<DequantizerIQ2XXS>(mm);
- break;
- case GGML_TYPE_IQ1_BN:
- assert (ne00 % QK_IQ1BN == 0);
- mm.funcs[0] = mul_mat_iq1bn_q8_K64<1>;
- mm.funcs[1] = mul_mat_iq1bn_q8_K64<2>;
- mm.funcs[2] = mul_mat_iq1bn_q8_K64<3>;
- mm.funcs[3] = mul_mat_iq1bn_q8_K64<4>;
- mm.funcs[4] = mul_mat_iq1bn_q8_K64<5>;
- mm.funcs[5] = mul_mat_iq1bn_q8_K64<6>;
- mm.funcs[6] = mul_mat_iq1bn_q8_K64<7>;
- mm.funcs[7] = mul_mat_iq1bn_q8_K64<8>;
- expected_typeB = GGML_TYPE_Q8_K64;
- break;
- case GGML_TYPE_IQ2_BN:
- assert (ne00 % QK_IQ1BN == 0);
- mm.funcs[0] = mul_mat_iq2bn_q8_K64<1>;
- mm.funcs[1] = mul_mat_iq2bn_q8_K64<2>;
- mm.funcs[2] = mul_mat_iq2bn_q8_K64<3>;
- mm.funcs[3] = mul_mat_iq2bn_q8_K64<4>;
- mm.funcs[4] = mul_mat_iq2bn_q8_K64<5>;
- mm.funcs[5] = mul_mat_iq2bn_q8_K64<6>;
- mm.funcs[6] = mul_mat_iq2bn_q8_K64<7>;
- mm.funcs[7] = mul_mat_iq2bn_q8_K64<8>;
- expected_typeB = GGML_TYPE_Q8_K64;
- break;
- case GGML_TYPE_IQ2_BN_R4:
- assert (ne00 % QK_IQ1BN == 0);
- mm.funcs[0] = mul_mat_iq2_bn_r4_q8_k16<1>;
- mm.funcs[1] = mul_mat_iq2_bn_r4_q8_k16<2>;
- mm.funcs[2] = mul_mat_iq2_bn_r4_q8_k16<3>;
- mm.funcs[3] = mul_mat_iq2_bn_r4_q8_k16<4>;
- mm.funcs[4] = mul_mat_iq2_bn_r4_q8_k16<5>;
- mm.funcs[5] = mul_mat_iq2_bn_r4_q8_k16<6>;
-//#ifdef HAVE_FANCY_SIMD
- mm.funcs[6] = mul_mat_iq2_bn_r4_q8_k16<7>;
- mm.funcs[7] = mul_mat_iq2_bn_r4_q8_k16<8>;
-//#endif
- expected_typeB = GGML_TYPE_Q8_K16;
- break;
+ case GGML_TYPE_IQ2_K_R4:
+ case GGML_TYPE_IQ3_K_R4:
+ case GGML_TYPE_IQ4_K_R4:
+ case GGML_TYPE_IQ5_K_R4:
+ case GGML_TYPE_IQ4_KS_R4:
+ case GGML_TYPE_IQ5_KS_R4:
+ return iqk_set_kernels_iqk_quants(ne00, typeA, typeB, mm.funcs, mm.func16);
case GGML_TYPE_Q4_0:
- assert (ne00 % QK4_0 == 0);
- MulMat::set_functions<Q4_0_1_Unpacker>(mm);
- expected_typeB = GGML_TYPE_Q8_2_X4;
- break;
case GGML_TYPE_Q4_1:
- assert (ne00 % QK4_1 == 0);
- MulMat::set_functions<Q4_1_Unpacker>(mm);
- expected_typeB = GGML_TYPE_Q8_2_X4;
- break;
case GGML_TYPE_Q5_0:
- assert (ne00 % QK5_0 == 0);
- MulMat::set_functions<Q5_0_1_Unpacker>(mm);
- expected_typeB = GGML_TYPE_Q8_2_X4;
- break;
case GGML_TYPE_Q5_1:
- assert (ne00 % QK5_1 == 0);
- MulMat::set_functions<Q5_1_Unpacker>(mm);
- expected_typeB = GGML_TYPE_Q8_2_X4;
- break;
case GGML_TYPE_Q6_0:
- assert (ne00 % QK6_0 == 0);
- MulMat::set_functions<Q6_0_1_Unpacker>(mm);
- expected_typeB = GGML_TYPE_Q8_2_X4;
- break;
case GGML_TYPE_Q8_0:
- assert (ne00 % QK8_0 == 0);
-#ifdef HAVE_FANCY_SIMD
- MulMat::set_functions<Q8_0_1_Unpacker>(mm);
- expected_typeB = GGML_TYPE_Q8_2_X4;
-#else
- MulMat::set_functions<Q8_0_Unpacker>(mm);
- expected_typeB = GGML_TYPE_Q8_0_X4;
-#endif
- break;
case GGML_TYPE_IQ4_NL:
- assert (ne00 % QK4_NL == 0);
- MulMat::set_functions<IQ4_NL_Unpacker>(mm);
-#ifdef HAVE_FANCY_SIMD
- expected_typeB = GGML_TYPE_Q8_2_X4;
-#else
- expected_typeB = GGML_TYPE_Q8_0_X4;
-#endif
- break;
- case GGML_TYPE_IQ4_NL_R4:
- assert (ne00 % QK4_NL == 0);
- mm.funcs[0] = mul_mat_iq4_nl_r4_q8_2<1>;
- mm.funcs[1] = mul_mat_iq4_nl_r4_q8_2<2>;
- mm.funcs[2] = mul_mat_iq4_nl_r4_q8_2<3>;
- mm.funcs[3] = mul_mat_iq4_nl_r4_q8_2<4>;
- mm.funcs[4] = mul_mat_iq4_nl_r4_q8_2<5>;
- mm.funcs[5] = mul_mat_iq4_nl_r4_q8_2<6>;
- mm.funcs[6] = mul_mat_iq4_nl_r4_q8_2<7>;
- mm.funcs[7] = mul_mat_iq4_nl_r4_q8_2<8>;
- expected_typeB = GGML_TYPE_Q8_2_X4;
- break;
- case GGML_TYPE_IQ4_XS_R8:
- assert (ne00 % QK_K == 0);
- mm.funcs[0] = mul_mat_iq4_xs_r8_q8_k<1>;
- mm.funcs[1] = mul_mat_iq4_xs_r8_q8_k<2>;
- mm.funcs[2] = mul_mat_iq4_xs_r8_q8_k<3>;
- mm.funcs[3] = mul_mat_iq4_xs_r8_q8_k<4>;
- mm.funcs[4] = mul_mat_iq4_xs_r8_q8_k<5>;
- mm.funcs[5] = mul_mat_iq4_xs_r8_q8_k<6>;
- mm.funcs[6] = mul_mat_iq4_xs_r8_q8_k<7>;
- mm.funcs[7] = mul_mat_iq4_xs_r8_q8_k<8>;
- expected_typeB = GGML_TYPE_Q8_K32;
- break;
- case GGML_TYPE_IQ4_KS_R4:
- assert (ne00 % QK_K == 0);
- mm.funcs[0] = mul_mat_iq4_ks_r4_q8_k<1>;
- mm.funcs[1] = mul_mat_iq4_ks_r4_q8_k<2>;
- mm.funcs[2] = mul_mat_iq4_ks_r4_q8_k<3>;
- mm.funcs[3] = mul_mat_iq4_ks_r4_q8_k<4>;
- mm.funcs[4] = mul_mat_iq4_ks_r4_q8_k<5>;
- mm.funcs[5] = mul_mat_iq4_ks_r4_q8_k<6>;
- mm.funcs[6] = mul_mat_iq4_ks_r4_q8_k<7>;
- mm.funcs[7] = mul_mat_iq4_ks_r4_q8_k<8>;
-#ifndef HAVE_FANCY_SIMD
- // For some reason Zen4 does not like this particular function
- mm.func16 = mul_mat_iq4_ks_r4_q8_k<16>;
-#endif
- expected_typeB = GGML_TYPE_Q8_K32;
- break;
- case GGML_TYPE_IQ5_KS_R4:
- assert (ne00 % QK_K == 0);
- mm.funcs[0] = mul_mat_iq5_ks_r4_q8_k<1>;
- mm.funcs[1] = mul_mat_iq5_ks_r4_q8_k<2>;
- mm.funcs[2] = mul_mat_iq5_ks_r4_q8_k<3>;
- mm.funcs[3] = mul_mat_iq5_ks_r4_q8_k<4>;
- mm.funcs[4] = mul_mat_iq5_ks_r4_q8_k<5>;
- mm.funcs[5] = mul_mat_iq5_ks_r4_q8_k<6>;
- mm.funcs[6] = mul_mat_iq5_ks_r4_q8_k<7>;
- mm.funcs[7] = mul_mat_iq5_ks_r4_q8_k<8>;
-#ifndef HAVE_FANCY_SIMD
- // For some reason Zen4 does not like this particular function
- mm.func16 = mul_mat_iq5_ks_r4_q8_k<16>;
-#endif
- expected_typeB = GGML_TYPE_Q8_K32;
- break;
- case GGML_TYPE_IQ2_XXS_R4:
- assert (ne00 % QK_K == 0);
- mm.funcs[0] = mul_mat_iq2_xxs_r4_q8_k<1>;
- mm.funcs[1] = mul_mat_iq2_xxs_r4_q8_k<2>;
- mm.funcs[2] = mul_mat_iq2_xxs_r4_q8_k<3>;
- mm.funcs[3] = mul_mat_iq2_xxs_r4_q8_k<4>;
- mm.funcs[4] = mul_mat_iq2_xxs_r4_q8_k<5>;
- mm.funcs[5] = mul_mat_iq2_xxs_r4_q8_k<6>;
- mm.funcs[6] = mul_mat_iq2_xxs_r4_q8_k<7>;
- mm.funcs[7] = mul_mat_iq2_xxs_r4_q8_k<8>;
- mm.func16 = mul_mat_iq2_xxs_r4_q8_k<16>;
- expected_typeB = GGML_TYPE_Q8_K;
- break;
- case GGML_TYPE_IQ2_XS_R4:
- assert (ne00 % QK_K == 0);
- mm.funcs[0] = mul_mat_iq2_xs_r4_q8_k<1>;
- mm.funcs[1] = mul_mat_iq2_xs_r4_q8_k<2>;
- mm.funcs[2] = mul_mat_iq2_xs_r4_q8_k<3>;
- mm.funcs[3] = mul_mat_iq2_xs_r4_q8_k<4>;
- mm.funcs[4] = mul_mat_iq2_xs_r4_q8_k<5>;
- mm.funcs[5] = mul_mat_iq2_xs_r4_q8_k<6>;
- mm.funcs[6] = mul_mat_iq2_xs_r4_q8_k<7>;
- mm.funcs[7] = mul_mat_iq2_xs_r4_q8_k<8>;
-#ifndef HAVE_FANCY_SIMD
- // For some reason Zen4 does not like this particular function
- mm.func16 = mul_mat_iq2_xs_r4_q8_k_16;
-#endif
- expected_typeB = GGML_TYPE_Q8_K;
- break;
- case GGML_TYPE_IQ2_S_R4:
- assert (ne00 % QK_K == 0);
- mm.funcs[0] = mul_mat_iq2_s_r4_q8_k<1>;
- mm.funcs[1] = mul_mat_iq2_s_r4_q8_k<2>;
- mm.funcs[2] = mul_mat_iq2_s_r4_q8_k<3>;
- mm.funcs[3] = mul_mat_iq2_s_r4_q8_k<4>;
- mm.funcs[4] = mul_mat_iq2_s_r4_q8_k<5>;
- mm.funcs[5] = mul_mat_iq2_s_r4_q8_k<6>;
- mm.funcs[6] = mul_mat_iq2_s_r4_q8_k<7>;
- mm.funcs[7] = mul_mat_iq2_s_r4_q8_k<8>;
- mm.func16 = mul_mat_iq2_s_r4_q8_k_16;
- expected_typeB = GGML_TYPE_Q8_K;
- break;
- case GGML_TYPE_IQ3_XXS_R4:
- assert (ne00 % QK_K == 0);
- mm.funcs[0] = mul_mat_iq3_xxs_r4_q8_k<1>;
- mm.funcs[1] = mul_mat_iq3_xxs_r4_q8_k<2>;
- mm.funcs[2] = mul_mat_iq3_xxs_r4_q8_k<3>;
- mm.funcs[3] = mul_mat_iq3_xxs_r4_q8_k<4>;
- mm.funcs[4] = mul_mat_iq3_xxs_r4_q8_k<5>;
- mm.funcs[5] = mul_mat_iq3_xxs_r4_q8_k<6>;
- mm.funcs[6] = mul_mat_iq3_xxs_r4_q8_k<7>;
- mm.funcs[7] = mul_mat_iq3_xxs_r4_q8_k<8>;
- mm.func16 = mul_mat_iq3_xxs_r4_q8_k<16>;
- expected_typeB = GGML_TYPE_Q8_K;
- break;
- case GGML_TYPE_IQ3_S_R4:
- assert (ne00 % QK_K == 0);
- mm.funcs[0] = mul_mat_iq3_s_r4_q8_k<1>;
- mm.funcs[1] = mul_mat_iq3_s_r4_q8_k<2>;
- mm.funcs[2] = mul_mat_iq3_s_r4_q8_k<3>;
- mm.funcs[3] = mul_mat_iq3_s_r4_q8_k<4>;
- mm.funcs[4] = mul_mat_iq3_s_r4_q8_k<5>;
- mm.funcs[5] = mul_mat_iq3_s_r4_q8_k<6>;
- mm.funcs[6] = mul_mat_iq3_s_r4_q8_k<7>;
- mm.funcs[7] = mul_mat_iq3_s_r4_q8_k<8>;
- mm.func16 = mul_mat_iq3_s_r4_q8_k<16>;
- expected_typeB = GGML_TYPE_Q8_K;
- break;
- case GGML_TYPE_Q2_K_R4:
- assert (ne00 % QK_K == 0);
- mm.funcs[0] = mul_mat_q2_k_r4_q8_k<1>;
- mm.funcs[1] = mul_mat_q2_k_r4_q8_k<2>;
- mm.funcs[2] = mul_mat_q2_k_r4_q8_k<3>;
- mm.funcs[3] = mul_mat_q2_k_r4_q8_k<4>;
- mm.funcs[4] = mul_mat_q2_k_r4_q8_k<5>;
- mm.funcs[5] = mul_mat_q2_k_r4_q8_k<6>;
- mm.funcs[6] = mul_mat_q2_k_r4_q8_k<7>;
- mm.funcs[7] = mul_mat_q2_k_r4_q8_k<8>;
- expected_typeB = GGML_TYPE_Q8_K;
- break;
- case GGML_TYPE_Q3_K_R4:
- assert (ne00 % QK_K == 0);
- mm.funcs[0] = mul_mat_q3_k_r4_q8_k<1>;
- mm.funcs[1] = mul_mat_q3_k_r4_q8_k<2>;
- mm.funcs[2] = mul_mat_q3_k_r4_q8_k<3>;
- mm.funcs[3] = mul_mat_q3_k_r4_q8_k<4>;
- mm.funcs[4] = mul_mat_q3_k_r4_q8_k<5>;
- mm.funcs[5] = mul_mat_q3_k_r4_q8_k<6>;
- mm.funcs[6] = mul_mat_q3_k_r4_q8_k<7>;
- mm.funcs[7] = mul_mat_q3_k_r4_q8_k<8>;
- expected_typeB = GGML_TYPE_Q8_K;
- break;
- case GGML_TYPE_Q4_K_R4:
- assert (ne00 % QK_K == 0);
- mm.funcs[0] = mul_mat_q4_k_r4_q8_k<1>;
- mm.funcs[1] = mul_mat_q4_k_r4_q8_k<2>;
- mm.funcs[2] = mul_mat_q4_k_r4_q8_k<3>;
- mm.funcs[3] = mul_mat_q4_k_r4_q8_k<4>;
- mm.funcs[4] = mul_mat_q4_k_r4_q8_k<5>;
- mm.funcs[5] = mul_mat_q4_k_r4_q8_k<6>;
- mm.funcs[6] = mul_mat_q4_k_r4_q8_k<7>;
- mm.funcs[7] = mul_mat_q4_k_r4_q8_k<8>;
- expected_typeB = GGML_TYPE_Q8_K32;
- break;
- case GGML_TYPE_Q5_K_R4:
- assert (ne00 % QK_K == 0);
- mm.funcs[0] = mul_mat_q5_k_r4_q8_k<1>;
- mm.funcs[1] = mul_mat_q5_k_r4_q8_k<2>;
- mm.funcs[2] = mul_mat_q5_k_r4_q8_k<3>;
- mm.funcs[3] = mul_mat_q5_k_r4_q8_k<4>;
- mm.funcs[4] = mul_mat_q5_k_r4_q8_k<5>;
- mm.funcs[5] = mul_mat_q5_k_r4_q8_k<6>;
- mm.funcs[6] = mul_mat_q5_k_r4_q8_k<7>;
- mm.funcs[7] = mul_mat_q5_k_r4_q8_k<8>;
- expected_typeB = GGML_TYPE_Q8_K32;
- break;
- case GGML_TYPE_Q6_K_R4:
- assert (ne00 % QK_K == 0);
- mm.funcs[0] = mul_mat_q6_k_r4_q8_k<1>;
- mm.funcs[1] = mul_mat_q6_k_r4_q8_k<2>;
- mm.funcs[2] = mul_mat_q6_k_r4_q8_k<3>;
- mm.funcs[3] = mul_mat_q6_k_r4_q8_k<4>;
- mm.funcs[4] = mul_mat_q6_k_r4_q8_k<5>;
- mm.funcs[5] = mul_mat_q6_k_r4_q8_k<6>;
- mm.funcs[6] = mul_mat_q6_k_r4_q8_k<7>;
- mm.funcs[7] = mul_mat_q6_k_r4_q8_k<8>;
- expected_typeB = GGML_TYPE_Q8_K;
- break;
- case GGML_TYPE_Q8_K_R8:
- assert (ne00 % QK_K == 0);
- mm.funcs[0] = mul_mat_q8_k_r8_q8_k<1>;
- mm.funcs[1] = mul_mat_q8_k_r8_q8_k<2>;
- mm.funcs[2] = mul_mat_q8_k_r8_q8_k<3>;
- mm.funcs[3] = mul_mat_q8_k_r8_q8_k<4>;
- mm.funcs[4] = mul_mat_q8_k_r8_q8_k<5>;
- mm.funcs[5] = mul_mat_q8_k_r8_q8_k<6>;
- mm.funcs[6] = mul_mat_q8_k_r8_q8_k<7>;
- mm.funcs[7] = mul_mat_q8_k_r8_q8_k<8>;
-#ifdef HAVE_FANCY_SIMD
- mm.func16 = mul_mat_q8_k_r8_q8_k<16>;
-#endif
- expected_typeB = GGML_TYPE_Q8_KR8;
- break;
- case GGML_TYPE_Q8_KV:
- assert (ne00 % 32 == 0);
- mm.funcs[0] = mul_mat_q8_KV_q8_KV_1<1>;
- mm.funcs[1] = mul_mat_q8_KV_q8_KV<2>;
- mm.funcs[2] = mul_mat_q8_KV_q8_KV<3>;
- mm.funcs[3] = mul_mat_q8_KV_q8_KV<4>;
- mm.funcs[4] = mul_mat_q8_KV_q8_KV<5>;
- mm.funcs[5] = mul_mat_q8_KV_q8_KV<6>;
- mm.funcs[6] = mul_mat_q8_KV_q8_KV<7>;
- mm.funcs[7] = mul_mat_q8_KV_q8_KV<8>;
-#ifdef HAVE_FANCY_SIMD
- mm.func16 = mul_mat_q8_KV_q8_KV<16>;
-#endif
- expected_typeB = GGML_TYPE_Q8_KV;
- break;
- case GGML_TYPE_Q8_KV_R8:
- assert (ne00 % 32 == 0);
- mm.funcs[0] = mul_mat_q8_KV_r8_q8_KV<1>;
- mm.funcs[1] = mul_mat_q8_KV_r8_q8_KV<2>;
- mm.funcs[2] = mul_mat_q8_KV_r8_q8_KV<3>;
- mm.funcs[3] = mul_mat_q8_KV_r8_q8_KV<4>;
- mm.funcs[4] = mul_mat_q8_KV_r8_q8_KV<5>;
- mm.funcs[5] = mul_mat_q8_KV_r8_q8_KV<6>;
- mm.funcs[6] = mul_mat_q8_KV_r8_q8_KV<7>;
- mm.funcs[7] = mul_mat_q8_KV_r8_q8_KV<8>;
- expected_typeB = GGML_TYPE_Q8_KV;
- break;
- case GGML_TYPE_IQ4_K_R4:
- assert (ne00 % QK_K == 0);
- mm.funcs[0] = mul_mat_iq4_k_r4_q8_k<1>;
- mm.funcs[1] = mul_mat_iq4_k_r4_q8_k<2>;
- mm.funcs[2] = mul_mat_iq4_k_r4_q8_k<3>;
- mm.funcs[3] = mul_mat_iq4_k_r4_q8_k<4>;
- mm.funcs[4] = mul_mat_iq4_k_r4_q8_k<5>;
- mm.funcs[5] = mul_mat_iq4_k_r4_q8_k<6>;
- mm.funcs[6] = mul_mat_iq4_k_r4_q8_k<7>;
- mm.funcs[7] = mul_mat_iq4_k_r4_q8_k<8>;
- mm.func16 = mul_mat_iq4_k_r4_q8_k<16>;
- expected_typeB = GGML_TYPE_Q8_K;
- break;
- case GGML_TYPE_IQ5_K_R4:
- assert (ne00 % QK_K == 0);
- mm.funcs[0] = mul_mat_iq5_k_r4_q8_k<1>;
- mm.funcs[1] = mul_mat_iq5_k_r4_q8_k<2>;
- mm.funcs[2] = mul_mat_iq5_k_r4_q8_k<3>;
- mm.funcs[3] = mul_mat_iq5_k_r4_q8_k<4>;
- mm.funcs[4] = mul_mat_iq5_k_r4_q8_k<5>;
- mm.funcs[5] = mul_mat_iq5_k_r4_q8_k<6>;
- mm.funcs[6] = mul_mat_iq5_k_r4_q8_k<7>;
- mm.funcs[7] = mul_mat_iq5_k_r4_q8_k<8>;
- mm.func16 = mul_mat_iq5_k_r4_q8_k<16>;
- expected_typeB = GGML_TYPE_Q8_K;
- break;
- case GGML_TYPE_IQ2_K_R4:
- assert (ne00 % QK_K == 0);
- mm.funcs[0] = mul_mat_iq2_k_r4_q8_k<1>;
- mm.funcs[1] = mul_mat_iq2_k_r4_q8_k<2>;
- mm.funcs[2] = mul_mat_iq2_k_r4_q8_k<3>;
- mm.funcs[3] = mul_mat_iq2_k_r4_q8_k<4>;
- mm.funcs[4] = mul_mat_iq2_k_r4_q8_k<5>;
- mm.funcs[5] = mul_mat_iq2_k_r4_q8_k<6>;
- mm.funcs[6] = mul_mat_iq2_k_r4_q8_k<7>;
- mm.funcs[7] = mul_mat_iq2_k_r4_q8_k<8>;
- expected_typeB = GGML_TYPE_Q8_K;
- break;
- case GGML_TYPE_IQ3_K_R4:
- assert (ne00 % QK_K == 0);
- mm.funcs[0] = mul_mat_iq3_k_r4_q8_k<1>;
- mm.funcs[1] = mul_mat_iq3_k_r4_q8_k<2>;
- mm.funcs[2] = mul_mat_iq3_k_r4_q8_k<3>;
- mm.funcs[3] = mul_mat_iq3_k_r4_q8_k<4>;
- mm.funcs[4] = mul_mat_iq3_k_r4_q8_k<5>;
- mm.funcs[5] = mul_mat_iq3_k_r4_q8_k<6>;
- mm.funcs[6] = mul_mat_iq3_k_r4_q8_k<7>;
- mm.funcs[7] = mul_mat_iq3_k_r4_q8_k<8>;
-#ifdef HAVE_FANCY_SIMD
- mm.func16 = mul_mat_iq3_k_r4_q8_k<16>;
-#endif
- expected_typeB = GGML_TYPE_Q8_K;
- break;
case GGML_TYPE_Q4_0_R8:
- assert (ne00 % QK4_NL == 0);
- mm.funcs[0] = mul_mat_q4_0_r8_q8_2<1>;
- mm.funcs[1] = mul_mat_q4_0_r8_q8_2<2>;
- mm.funcs[2] = mul_mat_q4_0_r8_q8_2<3>;
- mm.funcs[3] = mul_mat_q4_0_r8_q8_2<4>;
- mm.funcs[4] = mul_mat_q4_0_r8_q8_2<5>;
- mm.funcs[5] = mul_mat_q4_0_r8_q8_2<6>;
- mm.funcs[6] = mul_mat_q4_0_r8_q8_2<7>;
- mm.funcs[7] = mul_mat_q4_0_r8_q8_2<8>;
-#ifdef HAVE_FANCY_SIMD
- mm.func16 = mul_mat_q4_0_r8_q8_2<16>;
-#endif
- expected_typeB = GGML_TYPE_Q8_2_X4;
- break;
case GGML_TYPE_Q5_0_R4:
- assert (ne00 % QK4_NL == 0);
- mm.funcs[0] = mul_mat_q5_0_r4_q8_2<1>;
- mm.funcs[1] = mul_mat_q5_0_r4_q8_2<2>;
- mm.funcs[2] = mul_mat_q5_0_r4_q8_2<3>;
- mm.funcs[3] = mul_mat_q5_0_r4_q8_2<4>;
- mm.funcs[4] = mul_mat_q5_0_r4_q8_2<5>;
- mm.funcs[5] = mul_mat_q5_0_r4_q8_2<6>;
- mm.funcs[6] = mul_mat_q5_0_r4_q8_2<7>;
- mm.funcs[7] = mul_mat_q5_0_r4_q8_2<8>;
- expected_typeB = GGML_TYPE_Q8_2_X4;
- break;
case GGML_TYPE_Q6_0_R4:
- assert (ne00 % QK4_NL == 0);
- mm.funcs[0] = mul_mat_q6_0_r4_q8_2<1>;
- mm.funcs[1] = mul_mat_q6_0_r4_q8_2<2>;
- mm.funcs[2] = mul_mat_q6_0_r4_q8_2<3>;
- mm.funcs[3] = mul_mat_q6_0_r4_q8_2<4>;
- mm.funcs[4] = mul_mat_q6_0_r4_q8_2<5>;
- mm.funcs[5] = mul_mat_q6_0_r4_q8_2<6>;
- mm.funcs[6] = mul_mat_q6_0_r4_q8_2<7>;
- mm.funcs[7] = mul_mat_q6_0_r4_q8_2<8>;
- expected_typeB = GGML_TYPE_Q8_2_X4;
- break;
case GGML_TYPE_Q8_0_R8:
- assert (ne00 % QK4_NL == 0);
- mm.funcs[0] = mul_mat_q8_0_r8_q8_2<1>;
- mm.funcs[1] = mul_mat_q8_0_r8_q8_2<2>;
- mm.funcs[2] = mul_mat_q8_0_r8_q8_2<3>;
- mm.funcs[3] = mul_mat_q8_0_r8_q8_2<4>;
- mm.funcs[4] = mul_mat_q8_0_r8_q8_2<5>;
- mm.funcs[5] = mul_mat_q8_0_r8_q8_2<6>;
- mm.funcs[6] = mul_mat_q8_0_r8_q8_2<7>;
- mm.funcs[7] = mul_mat_q8_0_r8_q8_2<8>;
- expected_typeB = GGML_TYPE_Q8_2_X4;
- break;
+ case GGML_TYPE_IQ4_NL_R4:
+ return iqk_set_kernels_legacy_quants(ne00, typeA, typeB, mm.funcs, mm.func16);
case GGML_TYPE_IQ1_S:
- mm.funcs[0] = mul_mat_iq1_s_q8_K<1>;
- mm.funcs[1] = mul_mat_iq1_s_q8_K<2>;
- mm.funcs[2] = mul_mat_iq1_s_q8_K<3>;
- mm.funcs[3] = mul_mat_iq1_s_q8_K<4>;
- mm.funcs[4] = mul_mat_iq1_s_q8_K<5>;
- mm.funcs[5] = mul_mat_iq1_s_q8_K<6>;
- mm.funcs[6] = mul_mat_iq1_s_q8_K<7>;
- mm.funcs[7] = mul_mat_iq1_s_q8_K<8>;
-#ifdef HAVE_FANCY_SIMD
- mm.func16 = mul_mat_iq1_s_q8_K<16>;
-#endif
- expected_typeB = GGML_TYPE_Q8_K;
- break;
case GGML_TYPE_IQ1_S_R4:
- assert (ne00 % QK4_NL == 0);
- mm.funcs[0] = mul_mat_iq1_s_r4_q8_1<1>;
- mm.funcs[1] = mul_mat_iq1_s_r4_q8_1<2>;
- mm.funcs[2] = mul_mat_iq1_s_r4_q8_1<3>;
- mm.funcs[3] = mul_mat_iq1_s_r4_q8_1<4>;
- mm.funcs[4] = mul_mat_iq1_s_r4_q8_1<5>;
- mm.funcs[5] = mul_mat_iq1_s_r4_q8_1<6>;
- mm.funcs[6] = mul_mat_iq1_s_r4_q8_1<7>;
- mm.funcs[7] = mul_mat_iq1_s_r4_q8_1<8>;
-#ifdef HAVE_FANCY_SIMD
- mm.func16 = mul_mat_iq1_s_r4_q8_1<16>;
-#endif
- expected_typeB = GGML_TYPE_Q8_K128;
- break;
case GGML_TYPE_IQ1_M_R4:
- assert (ne00 % QK4_NL == 0);
- mm.funcs[0] = mul_mat_iq1_m_r4_q8_0<1>;
- mm.funcs[1] = mul_mat_iq1_m_r4_q8_0<2>;
- mm.funcs[2] = mul_mat_iq1_m_r4_q8_0<3>;
- mm.funcs[3] = mul_mat_iq1_m_r4_q8_0<4>;
- mm.funcs[4] = mul_mat_iq1_m_r4_q8_0<5>;
- mm.funcs[5] = mul_mat_iq1_m_r4_q8_0<6>;
- mm.funcs[6] = mul_mat_iq1_m_r4_q8_0<7>;
- mm.funcs[7] = mul_mat_iq1_m_r4_q8_0<8>;
-#ifdef HAVE_FANCY_SIMD
- mm.func16 = mul_mat_iq1_m_r4_q8_0<16>;
-#endif
- expected_typeB = GGML_TYPE_Q8_K128;
- break;
+ case GGML_TYPE_IQ1_BN:
+ case GGML_TYPE_IQ2_BN:
+ case GGML_TYPE_IQ2_BN_R4:
+ return iqk_set_kernels_1bit(ne00, typeA, typeB, mm.funcs, mm.func16);
default:
return false;
}
- return ggml_type(typeB) == expected_typeB;
+ return false;
}
} // namespace
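The rewritten MulMat::prepare above no longer fills mm.funcs inline for every quant type; it forwards each type family to one of the per-family entry points (iqk_set_kernels_float, _kquants, _iquants, _iqk_quants, _legacy_quants, _1bit) that now live in the split iqk_gemm_* translation units, and those helpers both validate typeB and populate the kernel table. As a rough, self-contained sketch of this table-of-kernels pattern (every name below — KernelTable, gemm_demo, set_kernels_demo — is a hypothetical stand-in, not the actual iqk API):

#include <array>
#include <cstdio>

using kernel_t = void (*)(int n);                        // stand-in for the real mul_mat function type
struct KernelTable { std::array<kernel_t, 8> funcs{}; kernel_t func16 = nullptr; };

template <int nrc_y> void gemm_demo(int n) {             // stand-in for a templated mul_mat_*<nrc_y> kernel
    std::printf("kernel for %d RHS column(s), n = %d\n", nrc_y, n);
}

// Plays the role of one iqk_set_kernels_* helper: populate the table with the
// 1..8 column instantiations of a single templated kernel and report success.
bool set_kernels_demo(KernelTable& t) {
    t.funcs = { gemm_demo<1>, gemm_demo<2>, gemm_demo<3>, gemm_demo<4>,
                gemm_demo<5>, gemm_demo<6>, gemm_demo<7>, gemm_demo<8> };
    return true;
}

int main() {
    KernelTable t;
    if (!set_kernels_demo(t)) return 1;
    t.funcs[2](512);                                     // dispatch the 3-column kernel
    return 0;
}

The real helpers presumably also select between ISA-specific template instantiations and set func16 only where a 16-column kernel exists, as the removed per-type code above did behind HAVE_FANCY_SIMD and the AVX512 guards.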
@@ -10500,5115 +576,80 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) {
namespace {
-template <int nrc, typename block_q8 = block_q8_K> struct Q8 {
-
- constexpr static int nrc_y = nrc;
-
- Q8(const DataInfo& info) {
- for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const block_q8 *)info.src1_row(iy);
- }
-
- inline int8x16x2_t load_quants(int iy, int i, int j) const { return vld1q_s8_x2(y[iy][i].qs + 32*j); }
- inline int8x16x4_t load_quants_64(int iy, int i, int j) const { return vld1q_s8_x4(y[iy][i].qs + 64*j); }
- inline int16x8x2_t load_bsums(int iy, int i) const { return vld1q_s16_x2(y[iy][i].bsums); }
- inline int16x8_t load_bsums8(int iy, int i) const {
- auto q8s = vld1q_s16_x2(y[iy][i].bsums);
- return vpaddq_s16(q8s.val[0], q8s.val[1]);
- }
- inline float scale(int iy, int i) const { return y[iy][i].d; }
-
- const block_q8 * y[nrc_y];
-};
-
-template <typename Q8>
-inline void compute_8_blocks(const uint8x16x4_t& qx_1, const uint8x16x4_t& qx_2, const Q8& q8,
- const int32x4x2_t& scales, int iy, int i, int j, int32x4_t& sumi) {
- auto mzero = vdupq_n_s32(0);
- auto q8b_1 = q8.load_quants(iy, i, 4*j+0);
- auto p1 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_1.val[0]), q8b_1.val[0]),
- vreinterpretq_s8_u8(qx_1.val[1]), q8b_1.val[1]); // block 1
- auto q8b_2 = q8.load_quants(iy, i, 4*j+1);
- auto p2 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_1.val[2]), q8b_2.val[0]),
- vreinterpretq_s8_u8(qx_1.val[3]), q8b_2.val[1]); // block 2
- auto p12 = vpaddq_s32(p1, p2);
-
- auto q8b_3 = q8.load_quants(iy, i, 4*j+2);
- auto p3 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_2.val[0]), q8b_3.val[0]),
- vreinterpretq_s8_u8(qx_2.val[1]), q8b_3.val[1]); // block 1
- auto q8b_4 = q8.load_quants(iy, i, 4*j+3);
- auto p4 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_2.val[2]), q8b_4.val[0]),
- vreinterpretq_s8_u8(qx_2.val[3]), q8b_4.val[1]); // block 2
- auto p34 = vpaddq_s32(p3, p4);
-
- auto pall = vpaddq_s32(p12, p34);
- sumi = vmlaq_s32(sumi, scales.val[j], pall);
-}
-
-template <typename Q8>
-inline void compute_16_blocks(const uint8x16x4_t& qx_1, const uint8x16x4_t& qx_2, const Q8& q8,
- const int32x4x4_t& scales, int iy, int i, int j, int32x4_t& sumi) {
-
- auto mzero = vdupq_n_s32(0);
- auto q8b_1 = q8.load_quants(iy, i, 4*j+0);
- auto p1 = vpaddq_s32(ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_1.val[0]), q8b_1.val[0]),
- ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_1.val[1]), q8b_1.val[1])); // blocks 0, 0, 1, 1,
- auto q8b_2 = q8.load_quants(iy, i, 4*j+1);
-    auto p2 = vpaddq_s32(ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_1.val[2]), q8b_2.val[0]),
-                               ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_1.val[3]), q8b_2.val[1])); // blocks 2, 2, 3, 3,
- auto p12 = vpaddq_s32(p1, p2); // blocks 0, 1, 2, 3
- sumi = vmlaq_s32(sumi, scales.val[2*j+0], p12);
-
- auto q8b_3 = q8.load_quants(iy, i, 4*j+2);
- auto p3 = vpaddq_s32(ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_2.val[0]), q8b_3.val[0]),
- ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_2.val[1]), q8b_3.val[1])); // block 4, 4, 5, 5,
- auto q8b_4 = q8.load_quants(iy, i, 4*j+3);
- auto p4 = vpaddq_s32(ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_2.val[2]), q8b_4.val[0]),
- ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_2.val[3]), q8b_4.val[1])); // block 6, 6, 7, 7,
- auto p34 = vpaddq_s32(p3, p4); // blocks 4, 5, 6, 7
- sumi = vmlaq_s32(sumi, scales.val[2*j+1], p34);
-}
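The lane bookkeeping in the comments of compute_8_blocks/compute_16_blocks above comes from how vpaddq folds the per-sub-block dot products: each ggml_vdotq_s32 produces four 32-bit partial sums for one 16-byte chunk, and the pairwise adds collapse them so that one result lane ends up per sub-block, aligned with the corresponding lane of scales.val[...]. A minimal stand-alone illustration of that folding (nothing here is part of the iqk code; it only builds on AArch64):

#include <arm_neon.h>
#include <cstdio>

int main() {
#if defined(__aarch64__)
    // a and b each hold four partial sums; after vpaddq_s32 every output lane
    // is the sum of one adjacent pair, which is exactly how two sub-block dot
    // products get folded into one lane per sub-block.
    const int32_t a_lanes[4] = {1, 2, 3, 4};
    const int32_t b_lanes[4] = {5, 6, 7, 8};
    int32x4_t a = vld1q_s32(a_lanes);
    int32x4_t b = vld1q_s32(b_lanes);
    int32x4_t p = vpaddq_s32(a, b);        // -> {a0+a1, a2+a3, b0+b1, b2+b3} = {3, 7, 11, 15}
    int32_t out[4];
    vst1q_s32(out, p);
    std::printf("%d %d %d %d\n", out[0], out[1], out[2], out[3]);
#endif
    return 0;
}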
-
-template <typename Q8>
-inline void accum_mins_8(const int16x8_t& mins, const Q8& q8, float32x4_t * acc, int i, float c) {
- for (int iy = 0; iy < Q8::nrc_y; ++iy) {
- auto q8s = q8.load_bsums8(iy, i);
- int32x4_t b1 = vmull_s16(vget_low_s16(mins), vget_low_s16(q8s));
- int32x4_t b2 = vmull_s16(vget_high_s16(mins), vget_high_s16(q8s));
- float32x4_t prod = vcvtq_f32_s32(vaddq_s32(b1, b2));
- acc[iy] = vmlaq_f32(acc[iy], prod, vdupq_n_f32(c*q8.scale(iy, i)));
- }
-}
-template <typename Q8>
-inline void accum_mins_16(const int16x8x2_t& mins, const Q8& q8, float32x4_t * acc, int i, float c) {
- for (int iy = 0; iy < Q8::nrc_y; ++iy) {
- auto q8s = q8.load_bsums(iy, i);
- int32x4_t b1 = vmull_s16(vget_low_s16 (mins.val[0]), vget_low_s16 (q8s.val[0]));
- int32x4_t b2 = vmull_s16(vget_high_s16(mins.val[0]), vget_high_s16(q8s.val[0]));
- int32x4_t b3 = vmull_s16(vget_low_s16 (mins.val[1]), vget_low_s16 (q8s.val[1]));
- int32x4_t b4 = vmull_s16(vget_high_s16(mins.val[1]), vget_high_s16(q8s.val[1]));
- float32x4_t prod = vcvtq_f32_s32(vaddq_s32(vaddq_s32(b1, b2), vaddq_s32(b3, b4)));
- acc[iy] = vmlaq_f32(acc[iy], prod, vdupq_n_f32(c*q8.scale(iy, i)));
- }
-}
-
-struct Scales8 {
- uint32_t utmp[4];
- const uint8_t * sc8 = (const uint8_t *)utmp;
- template <typename Q8, typename Qx>
- inline int32x4x2_t process_scales_mins(const Qx& x, const Q8& q8, int i, float32x4_t * acc) {
- make_q4_scales(x.scales, utmp);
- int16x8_t mins = vmovl_s8(vld1_s8((const int8_t *)sc8 + 8));
- accum_mins_8(mins, q8, acc, i, -GGML_FP16_TO_FP32(x.dmin));
-
- uint8x8_t scales8 = vld1_u8(sc8);
- uint16x8_t scales16 = vmovl_u8(scales8);
- int32x4x2_t scales = {vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(scales16))),
- vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(scales16)))};
- return scales;
- }
-};
-
-struct Q4bits {
- const uint8x16_t m4b = vdupq_n_u8(0xf);
- uint8x16x4_t b1, b2;
- inline void prepare4(uint8x16x4_t& b, const uint8x16_t * val) const {
- b.val[0] = vandq_u8(val[0], m4b);
- b.val[2] = vshrq_n_u8(val[0], 4);
- b.val[1] = vandq_u8(val[1], m4b);
- b.val[3] = vshrq_n_u8(val[1], 4);
- }
- inline void prepare4_16(uint8x16x4_t& b, const uint8x16_t * val) const {
- b.val[0] = vandq_u8(val[0], m4b);
- b.val[1] = vshrq_n_u8(val[0], 4);
- b.val[2] = vandq_u8(val[1], m4b);
- b.val[3] = vshrq_n_u8(val[1], 4);
- }
- inline void prepare(const uint8_t * qs) {
- auto q4bits = vld1q_u8_x2(qs);
- prepare4(b1, q4bits.val);
- q4bits = vld1q_u8_x2(qs+32);
- prepare4(b2, q4bits.val);
- }
- inline void prepare_v2(const uint8_t * qs) {
- auto q4bits = vld1q_u8_x4(qs);
- prepare4(b1, q4bits.val+0);
- prepare4(b2, q4bits.val+2);
- }
- inline void prepare64(const uint8_t * qs) {
- auto q4bits = vld1q_u8_x4(qs);
- b1.val[0] = vandq_u8(q4bits.val[0], m4b);
- b1.val[1] = vandq_u8(q4bits.val[1], m4b);
- b1.val[2] = vandq_u8(q4bits.val[2], m4b);
- b1.val[3] = vandq_u8(q4bits.val[3], m4b);
- b2.val[0] = vshrq_n_u8(q4bits.val[0], 4);
- b2.val[1] = vshrq_n_u8(q4bits.val[1], 4);
- b2.val[2] = vshrq_n_u8(q4bits.val[2], 4);
- b2.val[3] = vshrq_n_u8(q4bits.val[3], 4);
- }
- inline void prepare16(const uint8_t * qs) {
- auto q4bits = vld1q_u8_x2(qs);
- prepare4_16(b1, q4bits.val);
- q4bits = vld1q_u8_x2(qs+32);
- prepare4_16(b2, q4bits.val);
- }
- inline void prepare16_v2(const uint8_t * qs) {
- auto q4bits = vld1q_u8_x4(qs);
- prepare4_16(b1, q4bits.val+0);
- prepare4_16(b2, q4bits.val+2);
- }
-};
-
-struct Q2bits {
- const uint8x16_t m4b = vdupq_n_u8(0x03);
- uint8x16x4_t b1, b2;
- inline void prepare(const uint8_t * qs) {
- auto q2bits = vld1q_u8_x2(qs);
- b1.val[0] = vandq_u8(q2bits.val[0], m4b);
- b1.val[1] = vandq_u8(q2bits.val[1], m4b);
-
- q2bits.val[0] = vshrq_n_u8(q2bits.val[0], 2);
- q2bits.val[1] = vshrq_n_u8(q2bits.val[1], 2);
- b1.val[2] = vandq_u8(q2bits.val[0], m4b);
- b1.val[3] = vandq_u8(q2bits.val[1], m4b);
-
- q2bits.val[0] = vshrq_n_u8(q2bits.val[0], 2);
- q2bits.val[1] = vshrq_n_u8(q2bits.val[1], 2);
- b2.val[0] = vandq_u8(q2bits.val[0], m4b);
- b2.val[1] = vandq_u8(q2bits.val[1], m4b);
-
- q2bits.val[0] = vshrq_n_u8(q2bits.val[0], 2);
- q2bits.val[1] = vshrq_n_u8(q2bits.val[1], 2);
- b2.val[2] = vandq_u8(q2bits.val[0], m4b);
- b2.val[3] = vandq_u8(q2bits.val[1], m4b);
- }
-};
-
-template <typename block_q, bool has_row_scale = false, bool scale_is_f16 = false>
-struct BaseDequantizer {
- BaseDequantizer(const void * vx, size_t bx, int nrc) : vx(vx), x(nullptr), bx(bx), nrc(nrc) {}
- inline void new_row(int ix) {
- if constexpr (has_row_scale) {
- if constexpr (scale_is_f16) {
- const ggml_half * dptr = (const ggml_half *)((const char *)vx + ix*bx);
- d = GGML_FP16_TO_FP32(*dptr);
- x = (const block_q *)(dptr + 1);
- } else {
- const float * dptr = (const float *)((const char *)vx + ix*bx);
- d = *dptr;
- x = (const block_q *)(dptr + 1);
- }
- } else {
- x = (const block_q *)((const char *)vx + ix*bx);
- }
- }
- const void * vx;
- const block_q * x;
- const size_t bx;
- const int nrc;
- float d;
-};
-
-struct DequantizerQ4K final : public BaseDequantizer<block_q4_K> {
- DequantizerQ4K(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {}
-
- constexpr static int num_blocks() { return 8; }
- constexpr static bool should_scale_quants() { return false; }
-
- template <typename Q8>
- inline int32x4x2_t new_block(int i, const Q8& q8, float32x4_t * acc) {
- d = GGML_FP16_TO_FP32(x[i].d);
- return s8.process_scales_mins(x[i], q8, i, acc);
- }
- inline void prepare(int i, int j) {
- if (nrc == 1) bits.prepare_v2(x[i].qs+64*j);
- else bits.prepare(x[i].qs+64*j);
- }
-
- Q4bits bits;
- Scales8 s8;
-
-};
-
-struct HighBit5 {
- const uint8x16_t mhb = vdupq_n_u8(0x10);
- uint8x16x2_t bits;
- inline void apply(uint8x16x4_t& b1, uint8x16x4_t& b2, bool do_shift) {
- b1.val[0] = vorrq_u8(b1.val[0], vandq_u8(vshlq_n_u8(bits.val[0], 4), mhb));
- b1.val[1] = vorrq_u8(b1.val[1], vandq_u8(vshlq_n_u8(bits.val[1], 4), mhb));
- b1.val[2] = vorrq_u8(b1.val[2], vandq_u8(vshlq_n_u8(bits.val[0], 3), mhb));
- b1.val[3] = vorrq_u8(b1.val[3], vandq_u8(vshlq_n_u8(bits.val[1], 3), mhb));
-
- b2.val[0] = vorrq_u8(b2.val[0], vandq_u8(vshlq_n_u8(bits.val[0], 2), mhb));
- b2.val[1] = vorrq_u8(b2.val[1], vandq_u8(vshlq_n_u8(bits.val[1], 2), mhb));
- b2.val[2] = vorrq_u8(b2.val[2], vandq_u8(vshlq_n_u8(bits.val[0], 1), mhb));
- b2.val[3] = vorrq_u8(b2.val[3], vandq_u8(vshlq_n_u8(bits.val[1], 1), mhb));
-
- if (do_shift) {
- bits.val[0] = vshrq_n_u8(bits.val[0], 4);
- bits.val[1] = vshrq_n_u8(bits.val[1], 4);
- }
- }
-};
-
-struct HighBit3 {
- const uint8x16_t mhb = vdupq_n_u8(0x04);
- uint8x16x2_t bits;
- inline void apply(uint8x16x4_t& b1, uint8x16x4_t& b2, bool do_shift) {
- b1.val[0] = vorrq_u8(b1.val[0], vandq_u8(vshlq_n_u8(bits.val[0], 2), mhb));
- b1.val[1] = vorrq_u8(b1.val[1], vandq_u8(vshlq_n_u8(bits.val[1], 2), mhb));
- b1.val[2] = vorrq_u8(b1.val[2], vandq_u8(vshlq_n_u8(bits.val[0], 1), mhb));
- b1.val[3] = vorrq_u8(b1.val[3], vandq_u8(vshlq_n_u8(bits.val[1], 1), mhb));
-
- b2.val[0] = vorrq_u8(b2.val[0], vandq_u8(bits.val[0], mhb));
- b2.val[1] = vorrq_u8(b2.val[1], vandq_u8(bits.val[1], mhb));
- b2.val[2] = vorrq_u8(b2.val[2], vandq_u8(vshrq_n_u8(bits.val[0], 1), mhb));
- b2.val[3] = vorrq_u8(b2.val[3], vandq_u8(vshrq_n_u8(bits.val[1], 1), mhb));
-
- if (do_shift) {
- bits.val[0] = vshrq_n_u8(bits.val[0], 4);
- bits.val[1] = vshrq_n_u8(bits.val[1], 4);
- }
- }
-};
-
-struct DequantizerQ5K final : public BaseDequantizer<block_q5_K> {
- DequantizerQ5K(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {}
-
- constexpr static int num_blocks() { return 8; }
- constexpr static bool should_scale_quants() { return false; }
-
- template <typename Q8>
- inline int32x4x2_t new_block(int i, const Q8& q8, float32x4_t * acc) {
- d = GGML_FP16_TO_FP32(x[i].d);
- h.bits = vld1q_u8_x2(x[i].qh);
- return s8.process_scales_mins(x[i], q8, i, acc);
- }
- inline void prepare(int i, int j) {
- if (nrc == 1) bits.prepare_v2(x[i].qs+64*j);
- else bits.prepare(x[i].qs+64*j);
- h.apply(bits.b1, bits.b2, j == 0);
- }
-
- Q4bits bits;
- HighBit5 h;
- Scales8 s8;
-
- uint8x16x2_t hbits;
-
-};
-
-inline int32x4x4_t make_wider(const int16x8x2_t& scales16) {
- int32x4x4_t scales = {
- vmovl_s16(vget_low_s16 (scales16.val[0])),
- vmovl_s16(vget_high_s16(scales16.val[0])),
- vmovl_s16(vget_low_s16 (scales16.val[1])),
- vmovl_s16(vget_high_s16(scales16.val[1])),
- };
- return scales;
-}
-
-template <typename Q8>
-inline int32x4x4_t process_scales_mins_16(const int8x16_t& scales8, const Q8& q8, float32x4_t * acc, int i, float c) {
- int16x8x2_t scales16;
- scales16.val[0] = vmovl_s8(vget_low_s8(scales8));
- scales16.val[1] = vmovl_s8(vget_high_s8(scales8));
- accum_mins_16(scales16, q8, acc, i, c);
- return make_wider(scales16);
-}
-
-struct DequantizerQ6K final : public BaseDequantizer<block_q6_K> {
- DequantizerQ6K(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {}
-
- constexpr static int num_blocks() { return 16; }
- constexpr static bool should_scale_quants() { return false; }
-
- template <typename Q8>
- inline int32x4x4_t new_block(int i, const Q8& q8, float32x4_t * acc) {
- d = GGML_FP16_TO_FP32(x[i].d);
- return process_scales_mins_16(vld1q_s8(x[i].scales), q8, acc, i, -32.f*d);
- }
- inline void prepare(int i, int j) {
-
- auto hbits = vld1q_u8_x2(x[i].qh + 32*j);
-
- bits.prepare64(x[i].ql+64*j);
- bits.b1.val[0] = vorrq_u8(bits.b1.val[0], vandq_u8(vshlq_n_u8(hbits.val[0], 4), mhb));
- bits.b1.val[1] = vorrq_u8(bits.b1.val[1], vandq_u8(vshlq_n_u8(hbits.val[1], 4), mhb));
- bits.b1.val[2] = vorrq_u8(bits.b1.val[2], vandq_u8(vshlq_n_u8(hbits.val[0], 2), mhb));
- bits.b1.val[3] = vorrq_u8(bits.b1.val[3], vandq_u8(vshlq_n_u8(hbits.val[1], 2), mhb));
-
- bits.b2.val[0] = vorrq_u8(bits.b2.val[0], vandq_u8(hbits.val[0], mhb));
- bits.b2.val[1] = vorrq_u8(bits.b2.val[1], vandq_u8(hbits.val[1], mhb));
- bits.b2.val[2] = vorrq_u8(bits.b2.val[2], vandq_u8(vshrq_n_u8(hbits.val[0], 2), mhb));
- bits.b2.val[3] = vorrq_u8(bits.b2.val[3], vandq_u8(vshrq_n_u8(hbits.val[1], 2), mhb));
-
- }
-
- Q4bits bits;
-
- const uint8x16_t mhb = vdupq_n_u8(0x30);
-
-};
-
-struct DequantizerQ3K final : public BaseDequantizer<block_q3_K> {
- DequantizerQ3K(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {}
-
- constexpr static int num_blocks() { return 16; }
- constexpr static bool should_scale_quants() { return false; }
-
- template <typename Q8>
- inline int32x4x4_t new_block(int i, const Q8& q8, float32x4_t * acc) {
- d = GGML_FP16_TO_FP32(x[i].d);
- h.bits = vld1q_u8_x2(x[i].hmask);
- mask = vdupq_n_u8(0x01);
- const uint16_t * sc16 = (const uint16_t *)x[i].scales;
- uint32_t aux0 = sc16[0] | (sc16[1] << 16);
- uint32_t aux1 = sc16[2] | (sc16[3] << 16);
- uint32_t aux2 = sc16[4] | (sc16[5] << 16);
- aux32[0] = (aux0 & 0x0f0f0f0f) | ((aux2 << 4) & 0x30303030);
- aux32[1] = (aux1 & 0x0f0f0f0f) | ((aux2 << 2) & 0x30303030);
- aux32[2] = ((aux0 >> 4) & 0x0f0f0f0f) | ((aux2 >> 0) & 0x30303030);
- aux32[3] = ((aux1 >> 4) & 0x0f0f0f0f) | ((aux2 >> 2) & 0x30303030);
- auto scales8 = vaddq_s8(vld1q_s8((const int8_t *)aux32), vdupq_n_s8(-32));
- if (nrc > 1) {
- return process_scales_mins_16(scales8, q8, acc, i, -4.f*d);
- }
- int16x8x2_t scales16;
- scales16.val[0] = vmovl_s8(vget_low_s8(scales8));
- scales16.val[1] = vmovl_s8(vget_high_s8(scales8));
- return make_wider(scales16);
- }
-
- inline void prepare(int i, int j) {
- bits.prepare(x[i].qs+32*j);
- if (nrc > 1) {
- h.apply(bits.b1, bits.b2, j == 0);
- } else {
- auto minus4 = vdupq_n_u8(0xfc);
- auto zero = vdupq_n_u8(0);
- bits.b1.val[0] = vorrq_u8(bits.b1.val[0], vandq_u8(minus4, vceqq_u8(vandq_u8(h.bits.val[0], mask), zero)));
- bits.b1.val[1] = vorrq_u8(bits.b1.val[1], vandq_u8(minus4, vceqq_u8(vandq_u8(h.bits.val[1], mask), zero)));
- mask = vshlq_n_u8(mask, 1);
- bits.b1.val[2] = vorrq_u8(bits.b1.val[2], vandq_u8(minus4, vceqq_u8(vandq_u8(h.bits.val[0], mask), zero)));
- bits.b1.val[3] = vorrq_u8(bits.b1.val[3], vandq_u8(minus4, vceqq_u8(vandq_u8(h.bits.val[1], mask), zero)));
- mask = vshlq_n_u8(mask, 1);
- bits.b2.val[0] = vorrq_u8(bits.b2.val[0], vandq_u8(minus4, vceqq_u8(vandq_u8(h.bits.val[0], mask), zero)));
- bits.b2.val[1] = vorrq_u8(bits.b2.val[1], vandq_u8(minus4, vceqq_u8(vandq_u8(h.bits.val[1], mask), zero)));
- mask = vshlq_n_u8(mask, 1);
- bits.b2.val[2] = vorrq_u8(bits.b2.val[2], vandq_u8(minus4, vceqq_u8(vandq_u8(h.bits.val[0], mask), zero)));
- bits.b2.val[3] = vorrq_u8(bits.b2.val[3], vandq_u8(minus4, vceqq_u8(vandq_u8(h.bits.val[1], mask), zero)));
- mask = vshlq_n_u8(mask, 1);
- }
- }
-
- uint32_t aux32[4];
-
- Q2bits bits;
-
- uint8x16_t mask;
- HighBit3 h;
-
-};
-
-struct DequantizerQ2K final : public BaseDequantizer<block_q2_K> {
- DequantizerQ2K(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {}
-
- constexpr static int num_blocks() { return 16; }
- constexpr static bool should_scale_quants() { return true; }
-
- template <typename Q8>
- inline void process_scales(int i, const Q8& q8, float32x4_t * acc) {
- d = GGML_FP16_TO_FP32(x[i].d);
- auto scales_and_mins = vld1q_u8(x[i].scales);
- auto mins8 = vreinterpretq_s8_u8(vshrq_n_u8(scales_and_mins, 4));
- int16x8x2_t scales16;
- scales16.val[0] = vmovl_s8(vget_low_s8(mins8));
- scales16.val[1] = vmovl_s8(vget_high_s8(mins8));
- accum_mins_16(scales16, q8, acc, i, -GGML_FP16_TO_FP32(x[i].dmin));
-
- scales8 = vandq_u8(scales_and_mins, vdupq_n_u8(0xf));
- }
-
- template <typename Q8>
- inline int32x4x4_t new_block(int i, const Q8& q8, float32x4_t * acc) {
- process_scales(i, q8, acc);
- int16x8x2_t scales16;
- scales16.val[0] = vmovl_s8(vget_low_s8(vreinterpretq_s8_u8(scales8)));
- scales16.val[1] = vmovl_s8(vget_high_s8(vreinterpretq_s8_u8(scales8)));
- return make_wider(scales16);
- }
-
- template <typename Q8>
- inline void compute(const Q8& q8, int i, int j, int32x4_t * sumi) {
- auto m1 = vdupq_n_u8(1);
- auto shuffle = vdupq_n_u8(8*j);
- bits.b1.val[0] = vmulq_u8(bits.b1.val[0], vqtbl1q_u8(scales8, shuffle)); shuffle = vaddq_u8(shuffle, m1);
- bits.b1.val[1] = vmulq_u8(bits.b1.val[1], vqtbl1q_u8(scales8, shuffle)); shuffle = vaddq_u8(shuffle, m1);
- bits.b1.val[2] = vmulq_u8(bits.b1.val[2], vqtbl1q_u8(scales8, shuffle)); shuffle = vaddq_u8(shuffle, m1);
- bits.b1.val[3] = vmulq_u8(bits.b1.val[3], vqtbl1q_u8(scales8, shuffle)); shuffle = vaddq_u8(shuffle, m1);
- bits.b2.val[0] = vmulq_u8(bits.b2.val[0], vqtbl1q_u8(scales8, shuffle)); shuffle = vaddq_u8(shuffle, m1);
- bits.b2.val[1] = vmulq_u8(bits.b2.val[1], vqtbl1q_u8(scales8, shuffle)); shuffle = vaddq_u8(shuffle, m1);
- bits.b2.val[2] = vmulq_u8(bits.b2.val[2], vqtbl1q_u8(scales8, shuffle)); shuffle = vaddq_u8(shuffle, m1);
- bits.b2.val[3] = vmulq_u8(bits.b2.val[3], vqtbl1q_u8(scales8, shuffle)); shuffle = vaddq_u8(shuffle, m1);
- for (int iy = 0; iy < Q8::nrc_y; ++iy) {
- auto q8b_1 = q8.load_quants(iy, i, 4*j+0);
- sumi[iy] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[iy], vreinterpretq_s8_u8(bits.b1.val[0]), q8b_1.val[0]),
- vreinterpretq_s8_u8(bits.b1.val[1]), q8b_1.val[1]);
-
- auto q8b_2 = q8.load_quants(iy, i, 4*j+1);
- sumi[iy] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[iy], vreinterpretq_s8_u8(bits.b1.val[2]), q8b_2.val[0]),
- vreinterpretq_s8_u8(bits.b1.val[3]), q8b_2.val[1]);
-
- auto q8b_3 = q8.load_quants(iy, i, 4*j+2);
- sumi[iy] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[iy], vreinterpretq_s8_u8(bits.b2.val[0]), q8b_3.val[0]),
- vreinterpretq_s8_u8(bits.b2.val[1]), q8b_3.val[1]);
-
- auto q8b_4 = q8.load_quants(iy, i, 4*j+3);
- sumi[iy] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[iy], vreinterpretq_s8_u8(bits.b2.val[2]), q8b_4.val[0]),
- vreinterpretq_s8_u8(bits.b2.val[3]), q8b_4.val[1]);
- }
- }
-
- inline void prepare(int i, int j) {
- bits.prepare(x[i].qs+32*j);
- }
-
- uint32_t aux32[4];
-
- uint8x16_t scales8;
-
- Q2bits bits;
-
-};
-
-// ============================= i-quants
-
-inline int32x4x4_t make_wider_8(const int8x16_t& scales8) {
- int16x8x2_t scales16{vmovl_s8(vget_low_s8(scales8)), vmovl_s8(vget_high_s8(scales8))};
- return make_wider(scales16);
-}
-
-struct Scale16Extra {
- template <typename Q8>
- static inline int32x4x4_t new_block(int i, float d, uint16_t extra, uint8_t val,
- const int8x16_t& scales8, const Q8& q8, float32x4_t * acc) {
- uint8x16_t e8 = vreinterpretq_u8_u16(vdupq_n_u16(extra));
- e8 = vceqq_u8(vandq_u8(e8, emask), emask);
- e8 = vqtbl1q_u8(vandq_u8(e8, vdupq_n_u8(val)), eshuff);
- int16x8x2_t extra16 = {vmull_s8(vget_low_s8 (e8), vget_low_s8 (scales8)),
- vmull_s8(vget_high_s8(e8), vget_high_s8(scales8))};
- accum_mins_16(extra16, q8, acc, i, d);
- return make_wider_8(scales8);
- }
-
- constexpr static uint32x4_t emask = {0x02020101, 0x08080404, 0x20201010, 0x80804040};
- constexpr static uint32x4_t eshuff = {0x06040200, 0x0e0c0a08, 0x07050301, 0x0f0d0b09};
-};
-
-// Note: on ARM_NEON we cannot use values shifted into the uint8_t range because
-// NEON only provides vdotq_s32 and vdotq_u32, where both operands must be signed
-// or both unsigned. As the Q8_K quants are signed, the iq4_s quants must also be
-// signed. We can get away with unsigned values for the k-quants only because
-// those all fit within the valid int8_t range.
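The constraint this note describes is the signedness of the NEON dot-product intrinsics: vdotq_s32 takes two int8x16_t operands and vdotq_u32 two uint8x16_t operands, so signed Q8_K activations force a signed codebook on the x side. A minimal sketch of the signed path, assuming the ARMv8.2 dot-product extension (__ARM_FEATURE_DOTPROD) is available; the codebook values are the iq4nl_values table that appears further below, and everything else is illustrative rather than the actual iqk code path:

#include <arm_neon.h>
#include <cstdio>

int main() {
#if defined(__ARM_FEATURE_DOTPROD)
    // A signed non-linear 4-bit codebook, looked up with vqtbl1q_s8 and then
    // dotted against (signed) 8-bit activations: both vdotq_s32 operands are
    // int8x16_t. An unsigned codebook would require vdotq_u32, which the
    // signed q8 side could not feed.
    const int8_t codebook[16] = {-127, -104, -83, -65, -49, -35, -22, -10,
                                    1,   13,  25,  38,  53,  69,  89, 113};
    uint8_t nibbles[16], q8[16];
    for (int i = 0; i < 16; ++i) { nibbles[i] = (uint8_t)i; q8[i] = (uint8_t)(i + 1); }
    int8x16_t x = vqtbl1q_s8(vld1q_s8(codebook), vld1q_u8(nibbles));   // dequantized int8 values
    int8x16_t y = vreinterpretq_s8_u8(vld1q_u8(q8));                   // stand-in q8 quants (all positive)
    int32x4_t sum = vdotq_s32(vdupq_n_s32(0), x, y);
    std::printf("dot = %d\n", vaddvq_s32(sum));
#endif
    return 0;
}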
-struct DequantizerIQ4K final : public BaseDequantizer<block_iq4_k> {
- DequantizerIQ4K(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc), values(vld1q_s8(iq4k_values)) {}
-
- constexpr static int num_blocks() { return 16; }
- constexpr static bool should_scale_quants() { return false; }
-
- template <typename Q8>
- inline int32x4x4_t new_block(int i, const Q8& q8, float32x4_t * acc) {
- d = GGML_FP16_TO_FP32(x[i].d);
- return Scale16Extra::new_block(i, d, x[i].extra, 4, make_scales(x[i].scales_l, x[i].scales_h), q8, acc);
- }
- inline void prepare(int i, int j) {
- bits.prepare16(x[i].qs+64*j);
- for (int k = 0; k < 4; ++k) {
- bits.b1.val[k] = vqtbl1q_s8(values, bits.b1.val[k]);
- bits.b2.val[k] = vqtbl1q_s8(values, bits.b2.val[k]);
- }
- }
- inline int8x16_t make_scales(const uint8_t * scales_l, const uint8_t * scales_h) const {
- uint8x8_t aux = vld1_u8(scales_l);
- uint8x16_t scl8 = vandq_u8(vcombine_u8(aux, vshr_n_u8(aux, 4)), vdupq_n_u8(0xf));
- const uint32_t * aux32 = (const uint32_t *)scales_h;
- uint32x4_t sch_32 = {aux32[0] << 4, aux32[0] << 2, aux32[0], aux32[0] >> 2};
- uint8x16_t sch8 = vandq_u8(vreinterpretq_u8_u32(sch_32), vdupq_n_u8(0x30));
- int8x16_t scales8 = vorrq_u8(scl8, vqtbl1q_u8(sch8, hshuff));
- return vaddq_s8(vqtbl1q_s8(scales8, hshuff), vdupq_n_s8(-32));
- }
-
- Q4bits bits;
- const int8x16_t values;
- const uint8x16_t hshuff = vreinterpretq_u8_u32(uint32x4_t{0x09010800, 0x0b030a02, 0x0d050c04, 0x0f070e06});
-
-};
-
-struct DequantizerIQ5K final : public BaseDequantizer<block_iq5_k> {
- DequantizerIQ5K(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc), values(vld1q_s8_x2(iq5nl_values)) {}
-
- constexpr static int num_blocks() { return 16; }
- constexpr static bool should_scale_quants() { return false; }
-
- template <typename Q8>
- inline int32x4x4_t new_block(int i, const Q8& q8, float32x4_t * acc) {
- d = GGML_FP16_TO_FP32(x[i].d);
- hbits = vld1q_u8_x2(x[i].qh); // hbits.val[0] holds 0....15, 32...47, 64...79, 96...111, 128...143, 160...175, 192...207, 224...239
- // hbits.val[1] holds 16...31, 48...63, 80...95, 112..127, 144...159, 176...191, 208...223, 240...255
- return Scale16Extra::new_block(i, d, x[i].extra, 2, make_scales(x[i].scales_l, x[i].scales_h), q8, acc);
- }
- inline void prepare(int i, int j) {
- bits.prepare(x[i].qs+64*j);
- if (j == 1) {
- for (int k = 0; k < 2; ++k) hbits.val[k] = vshrq_n_u8(hbits.val[k], 4);
- }
- bits.b1.val[0] = vorrq_u8(bits.b1.val[0], vandq_u8(vshlq_n_u8(hbits.val[0], 4), hm));
- bits.b1.val[1] = vorrq_u8(bits.b1.val[1], vandq_u8(vshlq_n_u8(hbits.val[1], 4), hm));
- bits.b1.val[2] = vorrq_u8(bits.b1.val[2], vandq_u8(vshlq_n_u8(hbits.val[0], 3), hm));
- bits.b1.val[3] = vorrq_u8(bits.b1.val[3], vandq_u8(vshlq_n_u8(hbits.val[1], 3), hm));
- bits.b2.val[0] = vorrq_u8(bits.b2.val[0], vandq_u8(vshlq_n_u8(hbits.val[0], 2), hm));
- bits.b2.val[1] = vorrq_u8(bits.b2.val[1], vandq_u8(vshlq_n_u8(hbits.val[1], 2), hm));
- bits.b2.val[2] = vorrq_u8(bits.b2.val[2], vandq_u8(vshlq_n_u8(hbits.val[0], 1), hm));
- bits.b2.val[3] = vorrq_u8(bits.b2.val[3], vandq_u8(vshlq_n_u8(hbits.val[1], 1), hm));
- for (int k = 0; k < 4; ++k) {
- bits.b1.val[k] = vqtbl2q_s8(values, bits.b1.val[k]);
- bits.b2.val[k] = vqtbl2q_s8(values, bits.b2.val[k]);
- }
- }
- inline int8x16_t make_scales(const uint8_t * scales_l, const uint8_t * scales_h) const {
- uint8x8_t aux = vld1_u8(scales_l);
- uint8x16_t scl8 = vandq_u8(vcombine_u8(aux, vshr_n_u8(aux, 4)), vdupq_n_u8(0xf));
- const uint32_t * aux32 = (const uint32_t *)scales_h;
- uint32x4_t sch_32 = {aux32[0] << 4, aux32[0] << 2, aux32[0], aux32[0] >> 2};
- uint8x16_t sch8 = vandq_u8(vreinterpretq_u8_u32(sch_32), vdupq_n_u8(0x30));
- int8x16_t scales8 = vorrq_u8(scl8, vqtbl1q_u8(sch8, hshuff));
- return vaddq_s8(vqtbl1q_s8(scales8, hshuff), vdupq_n_s8(-32));
- }
-
- Q4bits bits;
- const int8x16x2_t values;
- const uint8x16_t hshuff = vreinterpretq_u8_u32(uint32x4_t{0x09010800, 0x0b030a02, 0x0d050c04, 0x0f070e06});
- const uint8x16_t hm = vdupq_n_u8(0x10);
- uint8x16x2_t hbits;
-
-};
-
-struct DequantizerIQ6K final : public BaseDequantizer<block_iq6_k> {
- DequantizerIQ6K(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc), values(vld1q_s8_x4(iq6nl_values)) {}
-
- constexpr static int num_blocks() { return 16; }
- constexpr static bool should_scale_quants() { return false; }
-
- template <typename Q8>
- inline int32x4x4_t new_block(int i, const Q8& q8, float32x4_t * acc) {
- d = GGML_FP16_TO_FP32(x[i].d);
- return Scale16Extra::new_block(i, d, x[i].extra, 1, vld1q_s8(x[i].scales), q8, acc);
- }
- inline void prepare(int i, int j) {
- bits.prepare(x[i].qs+64*j);
- auto hbits = vld1q_u8_x2(x[i].qh + 32*j);
- bits.b1.val[0] = vorrq_u8(bits.b1.val[0], vandq_u8(vshlq_n_u8(hbits.val[0], 4), hm));
- bits.b1.val[1] = vorrq_u8(bits.b1.val[1], vandq_u8(vshlq_n_u8(hbits.val[1], 4), hm));
- bits.b1.val[2] = vorrq_u8(bits.b1.val[2], vandq_u8(vshlq_n_u8(hbits.val[0], 2), hm));
- bits.b1.val[3] = vorrq_u8(bits.b1.val[3], vandq_u8(vshlq_n_u8(hbits.val[1], 2), hm));
- bits.b2.val[0] = vorrq_u8(bits.b2.val[0], vandq_u8(hbits.val[0], hm));
- bits.b2.val[1] = vorrq_u8(bits.b2.val[1], vandq_u8(hbits.val[1], hm));
- bits.b2.val[2] = vorrq_u8(bits.b2.val[2], vandq_u8(vshrq_n_u8(hbits.val[0], 2), hm));
- bits.b2.val[3] = vorrq_u8(bits.b2.val[3], vandq_u8(vshrq_n_u8(hbits.val[1], 2), hm));
- for (int k = 0; k < 4; ++k) {
- bits.b1.val[k] = vqtbl4q_s8(values, bits.b1.val[k]);
- bits.b2.val[k] = vqtbl4q_s8(values, bits.b2.val[k]);
- }
- }
-
- Q4bits bits;
- const int8x16x4_t values;
- const uint8x16_t hm = vdupq_n_u8(0x30);
-
-};
-
-struct DequantizerIQ2K final : public BaseDequantizer<block_iq2_k> {
- DequantizerIQ2K(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {}
-
- constexpr static int num_blocks() { return 16; }
- constexpr static bool should_scale_quants() { return false; }
-
- template <typename Q8>
- inline int32x4x4_t new_block(int i, const Q8& q8, float32x4_t * acc) {
- d = GGML_FP16_TO_FP32(x[i].d);
- return Scale16Extra::new_block(i, d, x[i].extra, 5, make_scales(x[i].scales), q8, acc);
- }
- inline void prepare(int i, int j) {
- bits.prepare(x[i].qs+32*j);
- for (int k = 0; k < 4; ++k) {
- bits.b1.val[k] = vqtbl1q_s8(values, bits.b1.val[k]);
- bits.b2.val[k] = vqtbl1q_s8(values, bits.b2.val[k]);
- }
- }
- inline int8x16_t make_scales(const uint8_t * scales_l) const {
- uint8x8_t aux = vld1_u8(scales_l);
- uint8x16_t scl8 = vandq_u8(vcombine_u8(aux, vshr_n_u8(aux, 4)), vdupq_n_u8(0xf));
- int8x16_t scales = vaddq_s8(vreinterpretq_s8_u8(scl8), vdupq_n_s8(-8));
- return vqtbl1q_s8(scales, hshuff);
- }
-
- Q2bits bits;
- const int8x16_t values = vreinterpretq_s8_u64(vdupq_n_u64(0x000000001101f3e1));
- const uint8x16_t hshuff = vreinterpretq_u8_u32(uint32x4_t{0x09010800, 0x0b030a02, 0x0d050c04, 0x0f070e06});
-
-};
-
-struct DequantizerIQ3K final : public BaseDequantizer<block_iq3_k> {
- DequantizerIQ3K(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {}
-
- constexpr static int num_blocks() { return 16; }
- constexpr static bool should_scale_quants() { return false; }
-
- template <typename Q8>
- inline int32x4x4_t new_block(int i, const Q8& q8, float32x4_t * acc) {
- d = GGML_FP16_TO_FP32(x[i].d);
- return Scale16Extra::new_block(i, d, x[i].extra, 4, make_scales(x[i].scales_h, x[i].scales_l), q8, acc);
- }
- inline void prepare(int i, int j) {
- bits.prepare(x[i].qs+32*j);
- if (j == 0) {
- hbits = vld1q_u8_x2(x[i].qh);
- }
- else {
- hbits.val[0] = vshrq_n_u8(hbits.val[0], 4);
- hbits.val[1] = vshrq_n_u8(hbits.val[1], 4);
- }
- bits.b1.val[0] = vorrq_u8(bits.b1.val[0], vandq_u8(vshlq_n_u8(hbits.val[0], 2), hmask));
- bits.b1.val[1] = vorrq_u8(bits.b1.val[1], vandq_u8(vshlq_n_u8(hbits.val[1], 2), hmask));
- bits.b1.val[2] = vorrq_u8(bits.b1.val[2], vandq_u8(vshlq_n_u8(hbits.val[0], 1), hmask));
- bits.b1.val[3] = vorrq_u8(bits.b1.val[3], vandq_u8(vshlq_n_u8(hbits.val[1], 1), hmask));
- bits.b2.val[0] = vorrq_u8(bits.b2.val[0], vandq_u8(hbits.val[0], hmask));
- bits.b2.val[1] = vorrq_u8(bits.b2.val[1], vandq_u8(hbits.val[1], hmask));
- bits.b2.val[2] = vorrq_u8(bits.b2.val[2], vandq_u8(vshrq_n_u8(hbits.val[0], 1), hmask));
- bits.b2.val[3] = vorrq_u8(bits.b2.val[3], vandq_u8(vshrq_n_u8(hbits.val[1], 1), hmask));
- for (int k = 0; k < 4; ++k) {
- bits.b1.val[k] = vqtbl1q_s8(values, bits.b1.val[k]);
- bits.b2.val[k] = vqtbl1q_s8(values, bits.b2.val[k]);
- }
- }
- inline int8x16_t make_scales(uint16_t sign_bits, const uint8_t * scales_l) const {
- uint8x8_t aux = vld1_u8(scales_l);
- uint8x16_t scl8 = vandq_u8(vcombine_u8(aux, vshr_n_u8(aux, 4)), vdupq_n_u8(0xf));
- int8x16_t scales = vaddq_s8(vreinterpretq_s8_u8(vshlq_n_u8(scl8, 1)), vdupq_n_s8(1));
- uint8x16_t signs = vceqq_u8(vandq_u8(vreinterpretq_u8_u16(vdupq_n_u16(sign_bits)), sign_mask), sign_mask);
- signs = vorrq_u8(signs, vdupq_n_u8(1));
- // scales are 0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7, 9, 11, 13, 15
- // signs are 0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15
- scales = vmulq_s8(scales, vreinterpretq_s8_u8(vqtbl1q_u8(signs, sign_shuffle)));
- return vqtbl1q_s8(scales, hshuff);
- }
- inline static uint8x16_t load_sign_shuffle() {
- static uint8_t k_shuff[16] = {0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15};
- return vld1q_u8(k_shuff);
- }
-
- Q2bits bits;
- uint8x16x2_t hbits;
- const int8x16_t values = vreinterpretq_s8_u64(vdupq_n_u64(0x2f1c0d01f6e9d8c1));
- const uint8x16_t hshuff = vreinterpretq_u8_u32(uint32x4_t{0x09010800, 0x0b030a02, 0x0d050c04, 0x0f070e06});
- const uint8x16_t hmask = vdupq_n_u8(4);
- const uint8x16_t sign_mask = vreinterpretq_u8_u64(uint64x2_t{0x0808040402020101, 0x8080404020201010});
- const uint8x16_t sign_shuffle = load_sign_shuffle();
-
-};
-
-struct DequantizerIQ4XS final : public BaseDequantizer<block_iq4_xs> {
-
- static int8x16_t load_values() {
- static const int8_t iq4nl_values[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};
- return vld1q_s8(iq4nl_values);
- }
-
- DequantizerIQ4XS(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc), values(load_values()) {}
-
- constexpr static int num_blocks() { return 8; }
- constexpr static bool should_scale_quants() { return false; }
-
- inline void new_row(int ix) { x = (const block_iq4_xs *)((const char *)vx + bx*ix); }
-
- template <typename Q8>
- inline int32x4x2_t new_block(int i, const Q8& q8, float32x4_t * acc) {
- (void)q8;
- (void)acc;
- d = GGML_FP16_TO_FP32(x[i].d);
- const uint16_t scales_h = x[i].scales_h;
- const uint16_t * scales_l = (const uint16_t *)x[i].scales_l;
- aux32[0] = scales_l[0] | (scales_l[1] << 16);
- aux32[1] = aux32[0] >> 4;
- // scl is ordered as 0, 2, 4, 6, 1, 3, 5, 7
- uint8x8_t scl8 = vand_u8(vld1_u8((const uint8_t *)aux32), vdup_n_u8(0xf));
- uint16_t * aux16 = (uint16_t *)aux32;
- aux16[0] = scales_h << 4; aux16[1] = scales_h << 2; aux16[2] = scales_h; aux16[3] = scales_h >> 2;
- // sch is ordered as 0, 4, 1, 5, 2, 6, 3, 7
- uint8x8_t sch8 = vand_u8(vld1_u8((const uint8_t *)aux16), vdup_n_u8(0x30));
- int8x8_t scales8 = vadd_s8(vreinterpret_s8_u8(vorr_u8(scl8, vtbl1_u8(sch8, vreinterpret_u8_u32(hshuff)))), vdup_n_s8(-32));
- // shuffle 0, 2, 4, 6, 1, 3, 5, 7 -> 0, 1, 2, 3, 4, 5, 6, 7
- scales8 = vtbl1_s8(scales8, vreinterpret_s8_u32(hshuff));
- int16x8_t scales16 = vmovl_s8(scales8);
- int32x4x2_t scales = {vmovl_s16(vget_low_s16(scales16)), vmovl_s16(vget_high_s16(scales16))};
- return scales;
- }
- inline void prepare(int i, int j) {
- bits.prepare16(x[i].qs+64*j);
- //if (nrc == 1) {
- // bits.prepare16_v2(x[i].qs+64*j);
- //} else {
- // bits.prepare16(x[i].qs+64*j);
- //}
- for (int k = 0; k < 4; ++k) {
- bits.b1.val[k] = vreinterpretq_u8_s8(vqtbl1q_s8(values, bits.b1.val[k]));
- bits.b2.val[k] = vreinterpretq_u8_s8(vqtbl1q_s8(values, bits.b2.val[k]));
- }
- }
-
- Q4bits bits;
- const int8x16_t values;
- uint32_t aux32[2];
-
- constexpr static uint32x2_t hshuff = {0x05010400, 0x07030602};
-
-};
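The ordering comments in DequantizerIQ4XS::new_block describe how the interleaved 4-bit low and 2-bit high scale fields get reassembled by the table lookups. A scalar view of the same per-block scale computation may make the shuffles easier to follow; the layout assumed below (eight 4-bit fields in scales_l, eight 2-bit fields in scales_h, final scale = 6-bit value minus 32) is read off this code rather than taken from an authoritative reference, and the helper name is hypothetical:

#include <cstdint>
#include <cstdio>

// Hypothetical scalar helper: reconstruct the signed 6-bit scale of sub-block
// ib (0..7) from the packed iq4_xs scale fields, one sub-block at a time,
// where the NEON code above produces all eight at once.
static int iq4_xs_scale(const uint8_t scales_l[4], uint16_t scales_h, int ib) {
    int lo = (scales_l[ib / 2] >> (4 * (ib % 2))) & 0xf;   // 4 low bits
    int hi = (scales_h >> (2 * ib)) & 0x3;                 // 2 high bits
    return (lo | (hi << 4)) - 32;                          // signed 6-bit scale
}

int main() {
    const uint8_t  scales_l[4] = {0x21, 0x43, 0x65, 0x87};
    const uint16_t scales_h    = 0x1b1b;
    for (int ib = 0; ib < 8; ++ib)
        std::printf("sub-block %d: scale %d\n", ib, iq4_xs_scale(scales_l, scales_h, ib));
    return 0;
}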
-
-struct DequantizerIQ4KS final : public BaseDequantizer<block_iq4_ks, true> {
-
- DequantizerIQ4KS(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc), values(vld1q_s8_x2(iq4k_values)) {}
-
- constexpr static int num_blocks() { return 8; }
- constexpr static bool should_scale_quants() { return false; }
-
- template <typename Q8>
- inline int32x4x2_t new_block(int i, const Q8& q8, float32x4_t * acc) {
- (void)q8;
- (void)acc;
- auto scales16 = vaddq_s16(vreinterpretq_s16_u16(vandq_u16(vmovl_u8(vld1_u8(x[i].scales)), mask)), m127);
- int32x4x2_t scales = {vmovl_s16(vget_low_s16(scales16)), vmovl_s16(vget_high_s16(scales16))};
- return scales;
- }
- inline void prepare(int i, int j) {
- bits.prepare16(x[i].qs+64*j);
- for (int k = 0; k < 4; ++k) {
- bits.b1.val[k] = vreinterpretq_u8_s8(vqtbl1q_s8(values.val[x[i].scales[4*j+k] & 1], bits.b1.val[k]));
- bits.b2.val[k] = vreinterpretq_u8_s8(vqtbl1q_s8(values.val[x[i].scales[4*j+k] & 1], bits.b2.val[k]));
- }
- }
-
- Q4bits bits;
- const int8x16x2_t values;
- const uint16x8_t mask = vdupq_n_u16(254);
- const int16x8_t m127 = vdupq_n_s16(-127);
-};
-
-struct DequantizerIQ5KS final : public BaseDequantizer<block_iq5_ks, true> {
- DequantizerIQ5KS(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc),
- values(vld1q_s8_x4(iq5nl_values)) {}
-
- constexpr static int num_blocks() { return 8; }
- constexpr static bool should_scale_quants() { return false; }
-
- template <typename Q8>
- inline int32x4x2_t new_block(int i, const Q8& q8, float32x4_t * acc) {
- (void)q8;
- (void)acc;
- auto sas8 = vld1_u8(x[i].scales);
- auto scales16 = vaddq_s16(vreinterpretq_s16_u16(vandq_u16(vmovl_u8(sas8), mask)), m127);
- hbits = vld1q_u8_x2(x[i].qh);
- sas = vcombine_u8(sas8, sas8);
- sas = vshlq_n_u8(vandq_u8(sas, vdupq_n_u8(1)), 5);
- int32x4x2_t scales = {vmovl_s16(vget_low_s16(scales16)), vmovl_s16(vget_high_s16(scales16))};
- return scales;
- }
-
- inline void prepare(int i, int j) {
- bits.prepare(x[i].qs+64*j);
- if (j == 1) {
- for (int k = 0; k < 2; ++k) hbits.val[k] = vshrq_n_u8(hbits.val[k], 4);
- }
- auto shift = vdupq_n_u8((x[i].scales[4*j+0] & 1) << 5);
- bits.b1.val[0] = vaddq_u8(shift, vorrq_u8(bits.b1.val[0], vandq_u8(vshlq_n_u8(hbits.val[0], 4), hm)));
- bits.b1.val[1] = vaddq_u8(shift, vorrq_u8(bits.b1.val[1], vandq_u8(vshlq_n_u8(hbits.val[1], 4), hm)));
- shift = vdupq_n_u8((x[i].scales[4*j+1] & 1) << 5);
- bits.b1.val[2] = vaddq_u8(shift, vorrq_u8(bits.b1.val[2], vandq_u8(vshlq_n_u8(hbits.val[0], 3), hm)));
- bits.b1.val[3] = vaddq_u8(shift, vorrq_u8(bits.b1.val[3], vandq_u8(vshlq_n_u8(hbits.val[1], 3), hm)));
- for (int k = 0; k < 4; ++k) bits.b1.val[k] = vqtbl4q_s8(values, bits.b1.val[k]);
- shift = vdupq_n_u8((x[i].scales[4*j+2] & 1) << 5);
- bits.b2.val[0] = vaddq_u8(shift, vorrq_u8(bits.b2.val[0], vandq_u8(vshlq_n_u8(hbits.val[0], 2), hm)));
- bits.b2.val[1] = vaddq_u8(shift, vorrq_u8(bits.b2.val[1], vandq_u8(vshlq_n_u8(hbits.val[1], 2), hm)));
- shift = vdupq_n_u8((x[i].scales[4*j+3] & 1) << 5);
- bits.b2.val[2] = vaddq_u8(shift, vorrq_u8(bits.b2.val[2], vandq_u8(vshlq_n_u8(hbits.val[0], 1), hm)));
- bits.b2.val[3] = vaddq_u8(shift, vorrq_u8(bits.b2.val[3], vandq_u8(vshlq_n_u8(hbits.val[1], 1), hm)));
- for (int k = 0; k < 4; ++k) bits.b2.val[k] = vqtbl4q_s8(values, bits.b2.val[k]);
- }
-
- Q4bits bits;
- const int8x16x4_t values;
- const uint8x16_t hm = vdupq_n_u8(0x10);
- const uint16x8_t mask = vdupq_n_u16(254);
- const int16x8_t m127 = vdupq_n_s16(-127);
- uint8x16x2_t hbits;
- uint8x16_t sas;
-
-};
-
-struct DequantizerIQ4KSS final : public BaseDequantizer<block_iq4_kss, true> {
-
- DequantizerIQ4KSS(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc), values(vld1q_s8_x2(iq4k_values)) {}
-
- constexpr static int num_blocks() { return 8; }
- constexpr static bool should_scale_quants() { return false; }
-
- template <typename Q8>
- inline int32x4x2_t new_block(int i, const Q8& q8, float32x4_t * acc) {
- (void)q8;
- (void)acc;
- auto q4bits_1 = vld1q_u16_x4((const uint16_t *)x[i].qs);
- q4bits_2 = vld1q_u16_x4((const uint16_t *)x[i].qs + 32);
- for (int k = 0; k < 4; ++k) {
- aux[k+0] = vaddvq_s16(vshlq_s16(vandq_u16(q4bits_1.val[k], m1), shift));
- aux[k+4] = vaddvq_s16(vshlq_s16(vandq_u16(q4bits_2.val[k], m1), shift));
- q4bits_1.val[k] = vandq_u16(q4bits_1.val[k], bmask);
- q4bits_1.val[k] = veorq_u16(q4bits_1.val[k], vshrq_n_u16(q4bits_1.val[k], 1));
- q4bits_2.val[k] = vandq_u16(q4bits_2.val[k], bmask);
- q4bits_2.val[k] = veorq_u16(q4bits_2.val[k], vshrq_n_u16(q4bits_2.val[k], 1));
- }
- make_quants(q4bits_1, bits, aux);
- auto scales16 = vld1q_s16(aux);
- scales16 = vaddq_s16(vandq_s16(scales16, mask), m127);
- int32x4x2_t scales = {vmovl_s16(vget_low_s16(scales16)), vmovl_s16(vget_high_s16(scales16))};
- return scales;
- }
- inline void make_quants(uint16x8x4_t& q4bits, Q4bits& bits, const int16_t * aux) const {
- bits.b1.val[0] = vqtbl1q_s8(values.val[aux[0] & 1], vandq_u8(q4bits.val[0], bits.m4b));
- bits.b1.val[1] = vqtbl1q_s8(values.val[aux[0] & 1], vshrq_n_u8(q4bits.val[0], 4));
- bits.b1.val[2] = vqtbl1q_s8(values.val[aux[1] & 1], vandq_u8(q4bits.val[1], bits.m4b));
- bits.b1.val[3] = vqtbl1q_s8(values.val[aux[1] & 1], vshrq_n_u8(q4bits.val[1], 4));
- bits.b2.val[0] = vqtbl1q_s8(values.val[aux[2] & 1], vandq_u8(q4bits.val[2], bits.m4b));
- bits.b2.val[1] = vqtbl1q_s8(values.val[aux[2] & 1], vshrq_n_u8(q4bits.val[2], 4));
- bits.b2.val[2] = vqtbl1q_s8(values.val[aux[3] & 1], vandq_u8(q4bits.val[3], bits.m4b));
- bits.b2.val[3] = vqtbl1q_s8(values.val[aux[3] & 1], vshrq_n_u8(q4bits.val[3], 4));
- }
- inline void prepare([[maybe_unused]] int i, int j) {
- if (j == 0) return;
- make_quants(q4bits_2, bits, aux+4);
- }
- static int16x8_t load_shift() {
- static const int16_t k_shift[8] = {0, 1, 2, 3, 4, 5, 6, 7};
- return vld1q_s16(k_shift);
- }
-
- Q4bits bits;
- const int8x16x2_t values;
- const int16x8_t mask = vdupq_n_s16(254);
- const uint16x8_t bmask = vdupq_n_u16(0xfffe);
- const uint16x8_t m1 = vdupq_n_u16(1);
- const int16x8_t shift = load_shift();
- const int16x8_t m127 = vdupq_n_s16(-127);
- uint16x8x4_t q4bits_2;
- int16_t aux[8];
-};
-
-struct DequantizerIQ2KS final : public BaseDequantizer<block_iq2_ks, true, true> {
- DequantizerIQ2KS(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {}
-
- constexpr static int num_blocks() { return 8; }
- constexpr static bool should_scale_quants() { return false; }
-
- template <typename Q8>
- inline int32x4x2_t new_block(int i, [[maybe_unused]] const Q8& q8, [[maybe_unused]] float32x4_t * acc) {
- const uint16_t * sc16 = (const uint16_t *)x[i].scales;
- uint32_t aux32 = sc16[0] | (sc16[1] << 16);
- uint8x8_t scales8 = vreinterpret_u8_u32(vdup_n_u32(aux32));
- scales8 = vand_u8(vzip1_u8(scales8, vshr_n_u8(scales8, 4)), vdup_n_u8(0xf));
- uint8x8_t sh = vand_u8(vceq_u8(vand_u8(vdup_n_u8(x[i].extra >> 8), hmask), vdup_n_u8(0)), vdup_n_u8(16));
- int16x8_t scales16 = vmovl_s8(vsub_s8(vreinterpret_s8_u8(scales8), vreinterpret_s8_u8(sh)));
- int32x4x2_t scales = {vmovl_s16(vget_low_s16(scales16)), vmovl_s16(vget_high_s16(scales16))};
- return scales;
- }
- inline void prepare(int i, int j) {
- uint8_t extra = x[i].extra >> 4*j;
- bits.prepare(x[i].qs+32*j);
- bits.b1.val[0] = vqtbl1q_s8(values.val[extra & 1], bits.b1.val[0]);
- bits.b1.val[1] = vqtbl1q_s8(values.val[extra & 1], bits.b1.val[1]); extra >>= 1;
- bits.b1.val[2] = vqtbl1q_s8(values.val[extra & 1], bits.b1.val[2]);
- bits.b1.val[3] = vqtbl1q_s8(values.val[extra & 1], bits.b1.val[3]); extra >>= 1;
- bits.b2.val[0] = vqtbl1q_s8(values.val[extra & 1], bits.b2.val[0]);
- bits.b2.val[1] = vqtbl1q_s8(values.val[extra & 1], bits.b2.val[1]); extra >>= 1;
- bits.b2.val[2] = vqtbl1q_s8(values.val[extra & 1], bits.b2.val[2]);
- bits.b2.val[3] = vqtbl1q_s8(values.val[extra & 1], bits.b2.val[3]);
- }
-
- Q2bits bits;
- const uint8x8_t hmask = vreinterpret_u8_u64(vdup_n_u64(0x8040201008040201));
- const int8x16x2_t values = { vreinterpretq_s8_u64(vdupq_n_u64(0x1101f3e1)), vreinterpretq_s8_u64(vdupq_n_u64(0x1606f8e6)) };
-
-};
-
-struct SimpleBits {
- uint8x16x4_t b1;
- uint8x16x4_t b2;
-};
-
-inline int32x4x2_t prepare_scales_8(const uint32x4_t& v1, const uint32x4_t& v2) {
- int32x4x2_t scales;
- scales.val[0] = vreinterpretq_s32_u32(vorrq_u32(vshlq_n_u32(vshrq_n_u32(v1, 28), 1), vdupq_n_u32(1)));
- scales.val[1] = vreinterpretq_s32_u32(vorrq_u32(vshlq_n_u32(vshrq_n_u32(v2, 28), 1), vdupq_n_u32(1)));
- return scales;
-}
-
-inline void apply_signs_2(uint8x16_t * b, const uint64_t * signs, uint32_t sidx) {
- auto s1 = vcombine_s8(vld1_s8((const int8_t *)(signs + ((sidx >> 0) & 127))), vld1_s8((const int8_t *)(signs + ((sidx >> 7) & 127))));
- auto s2 = vcombine_s8(vld1_s8((const int8_t *)(signs + ((sidx >>14) & 127))), vld1_s8((const int8_t *)(signs + ((sidx >>21) & 127))));
- b[0] = vreinterpretq_u8_s8(vmulq_s8(vreinterpretq_s8_u8(b[0]), s1));
- b[1] = vreinterpretq_u8_s8(vmulq_s8(vreinterpretq_s8_u8(b[1]), s2));
-}
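
For reference, a scalar sketch of what each 7-bit field of `sidx` selects from the `keven_signs` table, assuming the usual iq2_xxs convention that the eighth sign is chosen so the total number of minus signs is even. The function name is illustrative, not part of this file:

    #include <stdint.h>

    // Expand one 7-bit sign index into eight +1/-1 multipliers.
    // The eighth sign is the parity bit that keeps the count of -1s even.
    static inline void expand_even_signs(uint32_t idx7, int8_t s[8]) {
        int parity = 0;
        for (int k = 0; k < 7; ++k) {
            int bit = (idx7 >> k) & 1;
            s[k] = bit ? -1 : 1;
            parity ^= bit;
        }
        s[7] = parity ? -1 : 1;
    }
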
-
-struct DequantizerIQ2XXS final : public BaseDequantizer<block_iq2_xxs> {
- DequantizerIQ2XXS(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {}
-
- constexpr static int num_blocks() { return 8; }
- constexpr static bool should_scale_quants() { return false; }
-
- template <typename Q8>
- inline int32x4x2_t new_block(int i, const Q8& /*q8*/, float32x4_t * /*acc*/) {
- d = 0.125f * GGML_FP16_TO_FP32(x[i].d);
-
- auto tmp = vld1q_u32_x4((const uint32_t *)x[i].qs);
- data.val[0] = vuzp1q_u32(tmp.val[0], tmp.val[1]); // codebook indices for blocks 0...3
- data.val[1] = vuzp2q_u32(tmp.val[0], tmp.val[1]); // scales and signs for blocks 0...3
- data.val[2] = vuzp1q_u32(tmp.val[2], tmp.val[3]); // codebook indices for blocks 4...7
- data.val[3] = vuzp2q_u32(tmp.val[2], tmp.val[3]); // scales and signs for blocks 4...7
-
- return prepare_scales_8(data.val[1], data.val[3]);
- }
-
- static inline void prepare2(uint8x16_t * b, const uint8_t * idx, const uint64_t * signs, uint32_t sidx) {
- b[0] = vreinterpretq_u8_u64(uint64x2_t{iq2xxs_grid[idx[0]], iq2xxs_grid[idx[1]]});
- b[1] = vreinterpretq_u8_u64(uint64x2_t{iq2xxs_grid[idx[2]], iq2xxs_grid[idx[3]]});
- apply_signs_2(b, signs, sidx);
- }
-
- inline void prepare(int /*i*/, int j) {
- const uint8_t * idx = (const uint8_t *)(data.val + 2*j);
- const uint32_t * sidx = (const uint32_t *)(data.val + 2*j+1);
- prepare2(bits.b1.val + 0, idx, keven_signs, sidx[0]); idx += 4;
- prepare2(bits.b1.val + 2, idx, keven_signs, sidx[1]); idx += 4;
- prepare2(bits.b2.val + 0, idx, keven_signs, sidx[2]); idx += 4;
- prepare2(bits.b2.val + 2, idx, keven_signs, sidx[3]);
- }
-
- uint32x4x4_t data;
- SimpleBits bits;
-
-};
-
-inline int32x4x4_t prepare_4bit_scales16(const uint8_t * sc) {
- auto aux = vld1_u8(sc);
- auto scales_l = vand_u8(aux, vdup_n_u8(0xf));
- auto scales_h = vshr_n_u8(aux, 4);
- auto aux1 = vcombine_u8(vzip1_u8(scales_l, scales_h), vzip2_u8(scales_l, scales_h));
-
- auto scales8 = vreinterpretq_s8_u8(vorrq_u8(vshlq_n_u8(aux1, 1), vdupq_n_u8(1)));
- int16x8x2_t scales16 = { vmovl_s8(vget_low_s8(scales8)), vmovl_s8(vget_high_s8(scales8)) };
- return make_wider(scales16);
-}
-
-struct DequantizerIQ2XS final : public BaseDequantizer<block_iq2_xs> {
- DequantizerIQ2XS(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {}
-
- constexpr static int num_blocks() { return 16; }
- constexpr static bool should_scale_quants() { return false; }
-
- template <typename Q8>
- inline int32x4x4_t new_block(int i, const Q8& /*q8*/, float32x4_t * /*acc*/) {
- d = 0.125f * GGML_FP16_TO_FP32(x[i].d);
- return prepare_4bit_scales16(x[i].scales);
- }
-
- inline static uint8x16_t make1(const uint16_t * qs) {
- auto b = vcombine_u8(vld1_u8((const uint8_t *)(iq2xs_grid + (qs[0] & 511))), vld1_u8((const uint8_t *)(iq2xs_grid + (qs[1] & 511))));
- auto s = vcombine_s8(vld1_s8((const int8_t *)(keven_signs + (qs[0] >> 9))), vld1_s8((const int8_t *)(keven_signs + (qs[1] >> 9))));
- return vreinterpretq_u8_s8(vmulq_s8(vreinterpretq_s8_u8(b), s));
- }
-
- inline static void make4(const uint16_t * qs, uint8x16_t * b) {
- b[0] = make1(qs + 0);
- b[1] = make1(qs + 2);
- b[2] = make1(qs + 4);
- b[3] = make1(qs + 6);
- }
-
- inline void prepare(int i, int j) {
- make4(x[i].qs + 16*j + 0, bits.b1.val);
- make4(x[i].qs + 16*j + 8, bits.b2.val);
- }
-
- SimpleBits bits;
-
-
-};
-
-struct SignHelper {
-
- inline void init() { shuffle = vcombine_u8(vdup_n_u8(0), vdup_n_u8(1)); }
-
- inline void apply_signs_1(uint8x16_t * b, const uint8x16_t& signs16) {
- auto aux = vqtbl1q_u8(signs16, shuffle);
- auto s = vreinterpretq_s8_u8(vorrq_u8(vceqq_u8(vandq_u8(aux, smask), smask), m1));
- b[0] = vreinterpretq_u8_s8(vmulq_s8(vreinterpretq_s8_u8(b[0]), s));
- shuffle = vaddq_u8(shuffle, step);
- }
-
- const uint8x16_t smask = vreinterpretq_u8_u64(vdupq_n_u64(0x8040201008040201));
- const uint8x16_t m1 = vdupq_n_u8(1);
- const uint8x16_t step = vdupq_n_u8(2);
- uint8x16_t shuffle;
-};
-
-struct DequantizerIQ2S final : public BaseDequantizer<block_iq2_s> {
- DequantizerIQ2S(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {}
-
- constexpr static int num_blocks() { return 16; }
- constexpr static bool should_scale_quants() { return false; }
-
- template <typename Q8>
- inline int32x4x4_t new_block(int i, const Q8& /*q8*/, float32x4_t * /*acc*/) {
- d = 0.125f * GGML_FP16_TO_FP32(x[i].d);
- return prepare_4bit_scales16(x[i].scales);
- }
-
- static inline void make4(SignHelper& sh, const uint8x16_t& signs16, const uint8_t * qs, const uint8_t * qh, uint8x16_t * b) {
- uint32_t aux32[2];
- const uint16_t * aux16 = (const uint16_t *)aux32;
- for (int k = 0; k < 2; ++k) {
- aux32[1] = (qh[k] << 4) | (qh[k] << 18);
- aux32[0] = (aux32[1] << 4) & 0x03000300;
- aux32[1] &= 0x03000300;
- b[2*k+0] = vcombine_u8(vld1_u8((const uint8_t *)(iq2s_grid + (qs[4*k+0] | aux16[0]))),
- vld1_u8((const uint8_t *)(iq2s_grid + (qs[4*k+1] | aux16[1]))));
- sh.apply_signs_1(b+2*k+0, signs16);
-
- b[2*k+1] = vcombine_u8(vld1_u8((const uint8_t *)(iq2s_grid + (qs[4*k+2] | aux16[2]))),
- vld1_u8((const uint8_t *)(iq2s_grid + (qs[4*k+3] | aux16[3]))));
- sh.apply_signs_1(b+2*k+1, signs16);
- }
- }
-
- inline void prepare(int i, int j) {
-
- const auto * qs = x[i].qs + 16*j;
- const auto * qh = x[i].qh + 4*j;
- const auto signs16 = vld1q_u8(qs + QK_K/8);
-
- sh.init();
- make4(sh, signs16, qs+0, qh+0, bits.b1.val);
- make4(sh, signs16, qs+8, qh+2, bits.b2.val);
- }
-
- SimpleBits bits;
- SignHelper sh;
-
-
-};
-
-struct DequantizerIQ3XXS final : public BaseDequantizer<block_iq3_xxs> {
- DequantizerIQ3XXS(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {}
-
- constexpr static int num_blocks() { return 8; }
- constexpr static bool should_scale_quants() { return false; }
-
- template <typename Q8>
- inline int32x4x2_t new_block(int i, const Q8& /*q8*/, float32x4_t * /*acc*/) {
- d = 0.25f * GGML_FP16_TO_FP32(x[i].d);
- gas = vld1q_u32_x2((const uint32_t *)(x[i].qs + QK_K/4));
- return prepare_scales_8(gas.val[0], gas.val[1]);
- }
-
- inline static void make2(const uint8_t * q3, uint32_t sidx, uint8x16_t * b) {
- b[0] = vreinterpretq_u8_u32(uint32x4_t{iq3xxs_grid[q3[0]], iq3xxs_grid[q3[1]], iq3xxs_grid[q3[2]], iq3xxs_grid[q3[3]]});
- b[1] = vreinterpretq_u8_u32(uint32x4_t{iq3xxs_grid[q3[4]], iq3xxs_grid[q3[5]], iq3xxs_grid[q3[6]], iq3xxs_grid[q3[7]]});
- apply_signs_2(b, keven_signs, sidx);
- }
- inline void prepare(int i, int j) {
- const auto * q3 = x[i].qs + 32*j;
- const auto * signs = (const uint32_t *)(gas.val + j);
- make2(q3, signs[0], bits.b1.val + 0); q3 += 8;
- make2(q3, signs[1], bits.b1.val + 2); q3 += 8;
- make2(q3, signs[2], bits.b2.val + 0); q3 += 8;
- make2(q3, signs[3], bits.b2.val + 2);
- }
-
- SimpleBits bits;
- uint32x4x2_t gas;
-
-};
-
-struct DequantizerIQ3S final : public BaseDequantizer<block_iq3_s> {
- DequantizerIQ3S(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {}
-
- constexpr static int num_blocks() { return 8; }
- constexpr static bool should_scale_quants() { return false; }
-
- template <typename Q8>
- inline int32x4x2_t new_block(int i, const Q8& /*q8*/, float32x4_t * /*acc*/) {
- d = GGML_FP16_TO_FP32(x[i].d);
- uint32_t scales32[2];
- std::memcpy(scales32, x[i].scales, 4);
- scales32[1] = (((scales32[0] >> 4) & 0x0f0f0f0f) << 1) | 0x01010101;
- scales32[0] = ((scales32[0] & 0x0f0f0f0f) << 1) | 0x01010101;
- auto scales8 = vld1_u8((const uint8_t *)scales32); // 0, 2, 4, 6, 1, 3, 5, 7
- scales8 = vtbl1_u8(scales8, vreinterpret_u8_u64(vdup_n_u64(0x0703060205010400)));
- auto scales16 = vreinterpretq_s16_u16(vmovl_u8(scales8));
- int32x4x2_t scales;
- scales.val[0] = vmovl_s16(vget_low_s16(scales16));
- scales.val[1] = vmovl_s16(vget_high_s16(scales16));
- return scales;
- }
-
- static inline void make2(SignHelper& sh, const uint8x16_t& signs16, const uint16x8_t& idx_l, uint8_t qh,
- const int8x16_t& hshift, uint8x16_t * b) {
- auto vindex = vorrq_u16(idx_l, vandq_u16(vshlq_u16(vdupq_n_u16(qh), hshift), vdupq_n_u16(256)));
- const uint16_t * idx = (const uint16_t *)&vindex;
- b[0] = vreinterpretq_u8_u32(uint32x4_t{iq3s_grid[idx[0]], iq3s_grid[idx[1]], iq3s_grid[idx[2]], iq3s_grid[idx[3]]});
- b[1] = vreinterpretq_u8_u32(uint32x4_t{iq3s_grid[idx[4]], iq3s_grid[idx[5]], iq3s_grid[idx[6]], iq3s_grid[idx[7]]});
- sh.apply_signs_1(b+0, signs16);
- sh.apply_signs_1(b+1, signs16);
- }
- static inline void make4(SignHelper& sh, const uint8x16_t& signs16, const uint8_t * qs, const uint8_t * qh,
- const int8x16_t& hshift, uint8x16_t * b) {
- auto idx_l = vld1q_u8(qs);
- make2(sh, signs16, vmovl_u8(vget_low_u8 (idx_l)), qh[0], hshift, b+0);
- make2(sh, signs16, vmovl_u8(vget_high_u8(idx_l)), qh[1], hshift, b+2);
- }
-
- inline void prepare(int i, int j) {
-
- static const int16_t k_shift[8] = {8, 7, 6, 5, 4, 3, 2, 1};
- const auto hshift = vld1q_s16(k_shift);
-
- const auto * qs = x[i].qs + 32*j;
- const auto * qh = x[i].qh + 4*j;
- const auto signs16 = vld1q_u8(x[i].signs + 16*j);
-
- sh.init();
- make4(sh, signs16, qs+ 0, qh+0, hshift, bits.b1.val);
- make4(sh, signs16, qs+16, qh+2, hshift, bits.b2.val);
- }
-
- SimpleBits bits;
- SignHelper sh;
- uint32x4x2_t gas;
-
-};
-
-template <typename Dequantizer, int nrc_y>
-void mul_mat_qX_K_q8_K_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
- assert(n % QK_K == 0);
- const int nb = n / QK_K;
-
- Q8<nrc_y, block_q8_K> q8(info);
-
- Dequantizer deq(vx, bx, nrc_y);
-
- for (int ix = 0; ix < nrc_x; ++ix) {
-
- deq.new_row(ix);
-
- float32x4_t acc[nrc_y];
- for (int iy = 0; iy < nrc_y; ++iy) acc[iy] = vdupq_n_f32(0.f);
-
- for (int i = 0; i < nb; ++i) {
-
- int32x4_t sumi[nrc_y];
- for (int iy = 0; iy < nrc_y; ++iy) sumi[iy] = vdupq_n_s32(0);
-
- if constexpr (nrc_y > 1 && Dequantizer::should_scale_quants()) {
- deq.process_scales(i, q8, acc);
- deq.prepare(i, 0);
- deq.compute(q8, i, 0, sumi);
- deq.prepare(i, 1);
- deq.compute(q8, i, 1, sumi);
- } else {
- if constexpr (Dequantizer::num_blocks() == 8) {
- auto scales = deq.new_block(i, q8, acc);
- deq.prepare(i, 0);
- for (int iy = 0; iy < nrc_y; ++iy) compute_8_blocks(deq.bits.b1, deq.bits.b2, q8, scales, iy, i, 0, sumi[iy]);
- deq.prepare(i, 1);
- for (int iy = 0; iy < nrc_y; ++iy) compute_8_blocks(deq.bits.b1, deq.bits.b2, q8, scales, iy, i, 1, sumi[iy]);
- }
- else if constexpr (Dequantizer::num_blocks() == 16) {
- auto scales = deq.new_block(i, q8, acc);
- deq.prepare(i, 0);
- for (int iy = 0; iy < nrc_y; ++iy) compute_16_blocks(deq.bits.b1, deq.bits.b2, q8, scales, iy, i, 0, sumi[iy]);
- deq.prepare(i, 1);
- for (int iy = 0; iy < nrc_y; ++iy) compute_16_blocks(deq.bits.b1, deq.bits.b2, q8, scales, iy, i, 1, sumi[iy]);
- }
- else {
- GGML_ASSERT(false);
- }
- }
-
- for (int iy = 0; iy < nrc_y; ++iy) {
- acc[iy] = vmlaq_f32(acc[iy], vcvtq_f32_s32(sumi[iy]), vdupq_n_f32(deq.d*q8.scale(iy, i)));
- }
- }
-
- for (int iy = 0; iy < nrc_y; ++iy) {
- info.store(ix, iy, vaddvq_f32(acc[iy]));
- }
- }
-}
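
A hypothetical sketch of how one of the dequantizers above gets bound to this driver. The table layout and helper name are illustrative only; the actual dispatch code lives elsewhere in this file and may look different:

    // Illustrative only: fill a per-nrc_y kernel table with instantiations of the
    // template above for a given dequantizer (e.g. DequantizerIQ2XS).
    typedef void (*mul_mat_fn)(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x);

    template <typename Dequantizer>
    static void set_qX_K_kernels(mul_mat_fn * funcs) {   // hypothetical helper
        funcs[0] = mul_mat_qX_K_q8_K_T<Dequantizer, 1>;
        funcs[1] = mul_mat_qX_K_q8_K_T<Dequantizer, 2>;
        funcs[2] = mul_mat_qX_K_q8_K_T<Dequantizer, 3>;
        funcs[3] = mul_mat_qX_K_q8_K_T<Dequantizer, 4>;
    }
    // e.g. set_qX_K_kernels<DequantizerIQ2XS>(funcs);
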
-
-// =========================================== Legacy quants
-
-template <typename Block>
-inline float16x4_t load_scales_q0(const Block * x, ggml_half * aux) {
- for (int k = 0; k < 4; ++k) aux[k] = x[k].d;
- return vld1_f16((const float16_t *)aux);
-}
-
-template <typename Block>
-inline float16x8_t load_scales_q1(const Block * x, ggml_half * aux) {
- if constexpr (std::is_same_v<Block, block_q8_1>) {
- for (int k = 0; k < 4; ++k) { aux[k] = x[k].d; aux[k+4] = x[k].s; }
- } else {
- for (int k = 0; k < 4; ++k) { aux[k] = x[k].d; aux[k+4] = x[k].m; }
- }
- return vld1q_f16((const float16_t *)aux);
-}
-
-struct Q4LegacyBits {
- template <typename Block>
- inline void prepare(const Block * x) {
- for (int i = 0; i < 4; ++i) {
- auto q4bits = vld1q_u8(x[i].qs);
- b[2*i+0] = vreinterpretq_s8_u8(vandq_u8(q4bits, m4b));
- b[2*i+1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits, 4));
- }
- }
- inline void prepare1(const uint8_t * qs, int8x16_t * q) const {
- auto q4bits = vld1q_u8(qs);
- q[0] = vreinterpretq_s8_u8(vandq_u8(q4bits, m4b));
- q[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits, 4));
- }
- inline void prepare1(const uint8_t * qs) {
- prepare1(qs, b);
- }
- const uint8x16_t m4b = vdupq_n_u8(0xf);
- int8x16_t b[8];
-};
-
-// One would think this commented-out version would do better than the one below
-// because it offers more opportunities to execute instructions in parallel.
-// Instead, it runs significantly slower. Why? If the compiler is running out of vector registers,
-// can it not just fall back to the sequential version below on its own?
-//inline int32x4_t sum_4_blocks(const int8x16_t * b, const int8_t * qs) {
-// const auto q8b_1 = vld1q_s8_x2(qs + 0);
-// auto p12 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b[0], q8b_1.val[0]), b[1], q8b_1.val[1]);
-// const auto q8b_2 = vld1q_s8_x2(qs + 32);
-// auto p34 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b[2], q8b_2.val[0]), b[3], q8b_2.val[1]);
-// auto p1234 = vpaddq_s32(p12, p34);
-// const auto q8b_3 = vld1q_s8_x2(qs + 64);
-// auto p56 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b[4], q8b_3.val[0]), b[5], q8b_3.val[1]);
-// const auto q8b_4 = vld1q_s8_x2(qs + 96);
-// auto p78 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b[6], q8b_4.val[0]), b[7], q8b_4.val[1]);
-// return vpaddq_s32(p1234, vpaddq_s32(p56, p78));
-//}
-
-inline int32x4_t sum_4_blocks(const int8x16_t * b, const int8_t * qs) {
- auto q8b = vld1q_s8_x2(qs + 0);
- auto p12 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b[0], q8b.val[0]), b[1], q8b.val[1]);
- q8b = vld1q_s8_x2(qs + 32);
- auto p34 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b[2], q8b.val[0]), b[3], q8b.val[1]);
- auto p1234 = vpaddq_s32(p12, p34);
- q8b = vld1q_s8_x2(qs + 64);
- auto p56 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b[4], q8b.val[0]), b[5], q8b.val[1]);
- q8b = vld1q_s8_x2(qs + 96);
- auto p78 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b[6], q8b.val[0]), b[7], q8b.val[1]);
- return vpaddq_s32(p1234, vpaddq_s32(p56, p78));
-}
-
-inline int32x4x2_t sum_4_blocks(const int8x16_t * b1, const int8x16_t * b2, const int8_t * qs) {
- auto q8b = vld1q_s8_x2(qs + 0);
- auto p12_1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b1[0], q8b.val[0]), b1[1], q8b.val[1]);
- auto p12_2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b2[0], q8b.val[0]), b2[1], q8b.val[1]);
- q8b = vld1q_s8_x2(qs + 32);
- auto p34_1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b1[2], q8b.val[0]), b1[3], q8b.val[1]);
- auto p34_2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b2[2], q8b.val[0]), b2[3], q8b.val[1]);
- auto p1234_1 = vpaddq_s32(p12_1, p34_1);
- auto p1234_2 = vpaddq_s32(p12_2, p34_2);
- q8b = vld1q_s8_x2(qs + 64);
- auto p56_1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b1[4], q8b.val[0]), b1[5], q8b.val[1]);
- auto p56_2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b2[4], q8b.val[0]), b2[5], q8b.val[1]);
- q8b = vld1q_s8_x2(qs + 96);
- auto p78_1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b1[6], q8b.val[0]), b1[7], q8b.val[1]);
- auto p78_2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b2[6], q8b.val[0]), b2[7], q8b.val[1]);
- auto p5678_1 = vpaddq_s32(p56_1, p78_1);
- auto p5678_2 = vpaddq_s32(p56_2, p78_2);
- return { vpaddq_s32(p1234_1, p5678_1), vpaddq_s32(p1234_2, p5678_2)};
-}
-
-template <int nrc> struct Q80 {
-
- constexpr static int nrc_y = nrc;
-
- Q80(const DataInfo& info) {
- for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const block_q8_0 *)info.src1_row(iy);
- }
-
- inline const int8_t * quant_data(int iy, int i) const {
- const block_q8_0_x4 * y4 = (const block_q8_0_x4 *)y[iy] + i;
- return y4->qs;
- }
-
- inline float16x4_t load_scales(int iy, int i) const {
- const block_q8_0_x4 * y4 = (const block_q8_0_x4 *)y[iy] + i;
- return vld1_f16((const float16_t *)y4->d);
- }
-
- template <typename Dequantizer>
- inline void process_scales(int i, Dequantizer& deq, float16x4_t * sc16, float32x4_t * /*acc*/) const {
- auto qx_scales = deq.new_block(i);
- for (int iy = 0; iy < nrc; ++iy) {
- auto q8_scales = load_scales(iy, i);
- sc16[iy] = vmul_f16(qx_scales, q8_scales);
- }
- }
-
- template <typename Dequantizer>
- inline void process_scales(int i, Dequantizer& deq1, Dequantizer& deq2, float16x4_t * sc16, float32x4_t * /*acc*/) const {
- auto qx_scales_1 = deq1.new_block(i);
- auto qx_scales_2 = deq2.new_block(i);
- for (int iy = 0; iy < nrc; ++iy) {
- auto q8_scales = load_scales(iy, i);
- sc16[iy ] = vmul_f16(qx_scales_1, q8_scales);
- sc16[iy+nrc_y] = vmul_f16(qx_scales_2, q8_scales);
- }
- }
-
- template <typename Dequantizer>
- inline void process_1_block(int i, Dequantizer& deq, float32x4_t * acc) const {
- deq.prepare1(i);
- float d = GGML_FP16_TO_FP32(deq.x[i].d);
- for (int iy = 0; iy < nrc; ++iy) {
- auto q8b = vld1q_s8_x2(y[iy][i].qs);
- auto p = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), deq.bits.b[0], q8b.val[0]), deq.bits.b[1], q8b.val[1]);
- acc[iy] = vmlaq_f32(acc[iy], vdupq_n_f32(d*GGML_FP16_TO_FP32(y[iy][i].d)), vcvtq_f32_s32(p));
- }
- }
-
- const block_q8_0 * y[nrc_y];
-};
-
-template <int nrc> struct Q81 {
-
- constexpr static int nrc_y = nrc;
-
- Q81(const DataInfo& info) {
- for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const block_q8_1 *)info.src1_row(iy);
- }
-
- inline const int8_t * quant_data(int iy, int i) const {
- const block_q8_1_x4 * y4 = (const block_q8_1_x4 *)y[iy] + i;
- return y4->qs;
- }
-
- inline float16x8_t load_scales(int iy, int i) const {
- const block_q8_1_x4 * y4 = (const block_q8_1_x4 *)y[iy] + i;
- return vld1q_f16((const float16_t *)y4->d);
- }
-
- template <typename Dequantizer>
- inline void process_scales(int i, Dequantizer& deq, float16x4_t * sc16, float32x4_t * acc) const {
- auto qx_scales = deq.new_block(i);
- for (int iy = 0; iy < nrc; ++iy) {
- auto q8_scales = load_scales(iy, i);
- auto m = vmul_f16(vget_high_f16(qx_scales), vget_high_f16(q8_scales));
- acc[iy] = vaddq_f32(acc[iy], vcvt_f32_f16(m));
- sc16[iy] = vmul_f16(vget_low_f16(qx_scales), vget_low_f16(q8_scales));
- }
- }
-
- template <typename Dequantizer>
- inline void process_scales(int i, Dequantizer& deq1, Dequantizer& deq2, float16x4_t * sc16, float32x4_t * acc) const {
- auto qx_scales_1 = deq1.new_block(i);
- auto qx_scales_2 = deq2.new_block(i);
- for (int iy = 0; iy < nrc; ++iy) {
- auto q8_scales = load_scales(iy, i);
- auto q8_scales_l = vget_low_f16(q8_scales);
- auto q8_scales_h = vget_high_f16(q8_scales);
- auto m1 = vmul_f16(vget_high_f16(qx_scales_1), q8_scales_h);
- auto m2 = vmul_f16(vget_high_f16(qx_scales_2), q8_scales_h);
- acc[iy ] = vaddq_f32(acc[iy ], vcvt_f32_f16(m1));
- acc[iy+nrc_y ] = vaddq_f32(acc[iy+nrc_y], vcvt_f32_f16(m2));
- sc16[iy ] = vmul_f16(vget_low_f16(qx_scales_1), q8_scales_l);
- sc16[iy+nrc_y] = vmul_f16(vget_low_f16(qx_scales_2), q8_scales_l);
- }
- }
-
- template <typename Dequantizer>
- inline void process_1_block(int i, Dequantizer& deq, float32x4_t * acc) const {
- deq.prepare1(i);
- float d = GGML_FP16_TO_FP32(deq.x[i].d), m = 0.25f*GGML_FP16_TO_FP32(deq.x[i].m);
- for (int iy = 0; iy < nrc; ++iy) {
- auto q8b = vld1q_s8_x2(y[iy][i].qs);
- auto p = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), deq.bits.b[0], q8b.val[0]), deq.bits.b[1], q8b.val[1]);
- acc[iy] = vmlaq_f32(acc[iy], vdupq_n_f32(d*GGML_FP16_TO_FP32(y[iy][i].d)), vcvtq_f32_s32(p));
- acc[iy] = vaddq_f32(acc[iy], vdupq_n_f32(m*GGML_FP16_TO_FP32(y[iy][i].s)));
- }
- }
-
- const block_q8_1 * y[nrc_y];
-};
-
-template <typename block_q>
-struct BaseLegacyDequantizer {
-
- BaseLegacyDequantizer(const void * vx, size_t bx) : vx(vx), x(nullptr), bx(bx) {}
-
- inline void new_row(int ix) { x = (const block_q *)((const char *)vx + bx*ix); }
-
- Q4LegacyBits bits;
-
- const void * vx;
- const block_q * x;
- size_t bx;
-};
-
-struct DequantizerQ40 final : public BaseLegacyDequantizer<block_q4_0> {
-
- DequantizerQ40(const void * vx, size_t bx) : BaseLegacyDequantizer(vx, bx) {}
-
- inline void prepare1(int i, int8x16_t * q) const {
- bits.prepare1(x[i].qs, q);
- q[0] = vaddq_s8(q[0], m8);
- q[1] = vaddq_s8(q[1], m8);
- }
- inline void prepare1(int i) {
- prepare1(i, bits.b);
- }
-
- inline float16x4_t new_block(int i) {
- ggml_half aux[4];
- for (int k = 0; k < 4; ++k) {
- aux[k] = x[4*i+k].d;
- prepare1(4*i+k, bits.b + 2*k);
- }
- return vld1_f16((const float16_t *)aux);
- }
-
- const int8x16_t m8 = vdupq_n_s8(-8);
- //ggml_half aux[4];
-};
-
-struct DequantizerQ60 final : public BaseLegacyDequantizer<block_q6_0> {
-
- DequantizerQ60(const void * vx, size_t bx) : BaseLegacyDequantizer(vx, bx) {}
-
- inline void prepare1(int i, int8x16_t * q) const {
- bits.prepare1(x[i].qs, q);
- auto qh8 = vld1_u8(x[i].qh);
- auto qh = vcombine_u8(vshl_n_u8(qh8, 4), qh8);
- q[0] = vaddq_s8(vorrq_u8(q[0], vandq_u8(qh, hmask)), m32);
- q[1] = vaddq_s8(vorrq_u8(q[1], vandq_u8(vshrq_n_u8(qh, 2), hmask)), m32);
- }
- inline void prepare1(int i) {
- prepare1(i, bits.b);
- }
-
- inline float16x4_t new_block(int i) {
- ggml_half aux[4];
- for (int k = 0; k < 4; ++k) {
- aux[k] = x[4*i+k].d;
- prepare1(4*i+k, bits.b + 2*k);
- }
- return vld1_f16((const float16_t *)aux);
- }
-
- const int8x16_t m32 = vdupq_n_s8(-32);
- const uint8x16_t hmask = vdupq_n_u8(0x30);
-};
-
-struct DequantizerIQ4NL final : public BaseLegacyDequantizer<block_iq4_nl> {
-
- DequantizerIQ4NL(const void * vx, size_t bx) : BaseLegacyDequantizer(vx, bx) {}
-
- inline void prepare1(int i, int8x16_t * q) const {
- bits.prepare1(x[i].qs, q);
- q[0] = vqtbl1q_s8(values, q[0]);
- q[1] = vqtbl1q_s8(values, q[1]);
- }
- inline void prepare1(int i) {
- prepare1(i, bits.b);
- }
-
- inline float16x4_t new_block(int i) {
- ggml_half aux[4];
- for (int k = 0; k < 4; ++k) {
- aux[k] = x[4*i+k].d;
- prepare1(4*i+k, bits.b + 2*k);
- }
- return vld1_f16((const float16_t *)aux);
- }
- static int8x16_t load_values() {
- static const int8_t iq4nl_values[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};
- return vld1q_s8(iq4nl_values);
- }
-
- const int8x16_t values = load_values();
-};
-
-struct DequantizerQ41 : public BaseLegacyDequantizer<block_q4_1> {
-
- DequantizerQ41(const void * vx, size_t bx) : BaseLegacyDequantizer(vx, bx) {}
-
- inline void prepare1(int i) {
- bits.prepare1(x[i].qs);
- }
-
- inline float16x8_t new_block(int i) {
- uint32_t aux32[4];
- const uint32_t * s32 = (const uint32_t *)&x[4*i].d;
- for (int k = 0; k < 4; ++k) {
- aux32[k] = *s32; s32 += sizeof(block_q4_1)/4;
- bits.prepare1(x[4*i+k].qs, bits.b + 2*k);
- }
- return vreinterpretq_f16_u8(vqtbl1q_u8(vld1q_u8((const uint8_t *)aux32), vreinterpretq_u8_u64(shuffle)));
- }
- // Leaving this commented-out attempt here as a reminder that I already tried this.
- // It has basically the same performance as the version above.
- //inline float16x8_t new_block(int i) {
- // uint32x4_t scales = {};
- // const block_q4_1 * xi = x + 4*i;
- // const uint32_t * s32 = (const uint32_t *)&xi->d;
- // scales = vsetq_lane_u32(*s32, scales, 0); s32 += sizeof(block_q4_1)/4;
- // bits.prepare1(xi[0].qs, bits.b + 0);
- // scales = vsetq_lane_u32(*s32, scales, 1); s32 += sizeof(block_q4_1)/4;
- // bits.prepare1(xi[1].qs, bits.b + 2);
- // scales = vsetq_lane_u32(*s32, scales, 2); s32 += sizeof(block_q4_1)/4;
- // bits.prepare1(xi[2].qs, bits.b + 4);
- // scales = vsetq_lane_u32(*s32, scales, 3);
- // bits.prepare1(xi[3].qs, bits.b + 6);
- // return vreinterpretq_f16_u8(vqtbl1q_u8(vreinterpretq_u8_u32(scales), vreinterpretq_u8_u64(shuffle)));
- //}
-
- const uint64x2_t shuffle = {0x0d0c090805040100, 0x0f0e0b0a07060302};
-};
-
-struct HighBit5Legacy {
- inline uint8x16_t to_bytes(const uint8_t * qh) const {
- uint8x16_t h = vqtbl1q_u8(vreinterpretq_u8_u16(vdupq_n_u16(*(const uint16_t *)qh)), shuffle);
- return vceqq_u8(vandq_u8(h, vreinterpretq_u8_u64(mask)), vreinterpretq_u8_u64(mask));
- }
- inline uint8x16_t to_negated_bytes(const uint8_t * qh) const {
- uint8x16_t h = vqtbl1q_u8(vreinterpretq_u8_u16(vdupq_n_u16(*(const uint16_t *)qh)), shuffle);
- return vceqq_u8(vandq_u8(h, vreinterpretq_u8_u64(mask)), vdupq_n_u8(0));
- }
- const uint64x2_t mask = vdupq_n_u64(0x8040201008040201);
- const uint8x16_t shuffle = vcombine_u8(vdup_n_u8(0), vdup_n_u8(1));
-};
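
A scalar sketch of the bit expansion `to_bytes` performs, assuming the usual little-endian byte order; the function name is illustrative:

    #include <stdint.h>

    // Expand 16 packed high bits into a 16-byte mask: 0xFF where the bit is set,
    // 0x00 where it is clear (to_negated_bytes is the same with the test inverted).
    static inline void qh_to_byte_mask(const uint8_t * qh, uint8_t mask[16]) {
        uint16_t h = (uint16_t)(qh[0] | (qh[1] << 8));
        for (int k = 0; k < 16; ++k) mask[k] = (h >> k) & 1 ? 0xFF : 0x00;
    }
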
-
-struct DequantizerQ50 final : public BaseLegacyDequantizer<block_q5_0> {
-
- DequantizerQ50(const void * vx, size_t bx) : BaseLegacyDequantizer(vx, bx) {}
-
- inline void prepare1(int i, int8x16_t * q) const {
- bits.prepare1(x[i].qs, q);
- auto qh = x[i].qh;
- q[0] = vreinterpretq_s8_u8(vorrq_u8(vreinterpretq_u8_s8(q[0]), vandq_u8(mh, hbits.to_negated_bytes(qh+0))));
- q[1] = vreinterpretq_s8_u8(vorrq_u8(vreinterpretq_u8_s8(q[1]), vandq_u8(mh, hbits.to_negated_bytes(qh+2))));
- }
- inline void prepare1(int i) {
- prepare1(i, bits.b);
- }
-
- inline float16x4_t new_block(int i) {
- ggml_half aux[4];
- for (int k = 0; k < 4; ++k) {
- aux[k] = x[4*i+k].d;
- prepare1(4*i+k, bits.b + 2*k);
- }
- return vld1_f16((const float16_t *)aux);
- }
-
- HighBit5Legacy hbits;
-
- const uint8x16_t mh = vdupq_n_u8(0xf0);
-
-};
-
-struct DequantizerQ80 final : public BaseLegacyDequantizer<block_q8_0> {
-
- DequantizerQ80(const void * vx, size_t bx) : BaseLegacyDequantizer(vx, bx) {}
-
- inline void prepare1(int i) {
- bits.b[0] = vld1q_s8(x[i].qs);
- bits.b[1] = vld1q_s8(x[i].qs+16);
- }
-
- inline float16x4_t new_block(int i) {
- ggml_half aux[4];
- for (int k = 0; k < 4; ++k) {
- aux[k] = x[4*i+k].d;
- bits.b[2*k+0] = vld1q_s8(x[4*i+k].qs);
- bits.b[2*k+1] = vld1q_s8(x[4*i+k].qs+16);
- }
- return vld1_f16((const float16_t *)aux);
- }
-
-};
-
-// TODO: handle case where row size is not a multiple of 128
-struct DequantizerQ80_x4 final : public BaseLegacyDequantizer<block_q8_0_x4> {
-
- DequantizerQ80_x4(const void * vx, size_t bx) : BaseLegacyDequantizer(vx, bx) {}
-
- inline void prepare1(int i) {
- bits.b[0] = vld1q_s8(x[i].qs);
- bits.b[1] = vld1q_s8(x[i].qs+16);
- }
-
- inline float16x4_t new_block(int i) {
- auto scale = vld1_f16((const float16_t *)x[i].d);
- for (int k = 0; k < 4; ++k) {
- bits.b[2*k+0] = vld1q_s8(x[i].qs+32*k);
- bits.b[2*k+1] = vld1q_s8(x[i].qs+32*k+16);
- }
- return scale;
- }
-
-};
-
-struct DequantizerQ51 final : public BaseLegacyDequantizer<block_q5_1> {
-
- DequantizerQ51(const void * vx, size_t bx) : BaseLegacyDequantizer(vx, bx) {}
-
- inline void prepare1(int i, int8x16_t * q) const {
- bits.prepare1(x[i].qs, q);
- auto qh = x[i].qh;
- q[0] = vreinterpretq_s8_u8(vorrq_u8(vreinterpretq_u8_s8(q[0]), vandq_u8(mh, hbits.to_bytes(qh+0))));
- q[1] = vreinterpretq_s8_u8(vorrq_u8(vreinterpretq_u8_s8(q[1]), vandq_u8(mh, hbits.to_bytes(qh+2))));
- }
- inline void prepare1(int i) {
- bits.prepare1(x[i].qs, bits.b);
- }
-
- inline float16x8_t new_block(int i) {
- uint32_t aux32[4];
- const uint32_t * s32 = (const uint32_t *)&x[4*i].d;
- for (int k = 0; k < 4; ++k) {
- aux32[k] = *s32; s32 += sizeof(block_q5_1)/4;
- prepare1(4*i+k, bits.b + 2*k);
- }
- return vreinterpretq_f16_u8(vqtbl1q_u8(vld1q_u8((const uint8_t *)aux32), vreinterpretq_u8_u64(shuffle)));
- }
-
- HighBit5Legacy hbits;
-
- const uint8x16_t mh = vdupq_n_u8(0x10);
- const uint64x2_t shuffle = {0x0d0c090805040100, 0x0f0e0b0a07060302};
-
-};
-
-template <typename Dequantizer, typename Q8>
-inline void sum_4(int i, Dequantizer& deq, const Q8& q8, const float16x4_t * sc16, float32x4_t * acc) {
- for (int iy = 0; iy < Q8::nrc_y; ++iy) {
- auto pall = sum_4_blocks(deq.bits.b, q8.quant_data(iy, i));
- auto scale = vcvt_f32_f16(sc16[iy]);
- acc[iy] = vmlaq_f32(acc[iy], scale, vcvtq_f32_s32(pall));
- }
-}
-
-template <typename Dequantizer, typename Q8>
-inline void sum_4(int i, Dequantizer& deq1, Dequantizer& deq2, const Q8& q8, const float16x4_t * sc16, float32x4_t * acc) {
- for (int iy = 0; iy < Q8::nrc_y; ++iy) {
- auto pall = sum_4_blocks(deq1.bits.b, deq2.bits.b, q8.quant_data(iy, i));
- auto scale1 = vcvt_f32_f16(sc16[iy]);
- auto scale2 = vcvt_f32_f16(sc16[iy+Q8::nrc_y]);
- acc[iy] = vmlaq_f32(acc[iy], scale1, vcvtq_f32_s32(pall.val[0]));
- acc[iy+Q8::nrc_y] = vmlaq_f32(acc[iy+Q8::nrc_y], scale2, vcvtq_f32_s32(pall.val[1]));
- }
-}
-
-template <typename Dequantizer, typename Q8>
-inline void mul_mat_qX_Y_q8_Y(int n, Dequantizer& deq, Q8& q8, const DataInfo& info, int nrc_x) {
- const int nb = n / QK4_1;
-
- float16x4_t sc16[Q8::nrc_y];
-
- for (int ix = 0; ix < nrc_x; ++ix) {
-
- deq.new_row(ix);
-
- float32x4_t acc[Q8::nrc_y];
- for (int iy = 0; iy < Q8::nrc_y; ++iy) acc[iy] = vdupq_n_f32(0.f);
-
- for (int i = 0; i < nb/4; ++i) {
- q8.process_scales(i, deq, sc16, acc);
- sum_4(i, deq, q8, sc16, acc);
- }
- for (int i = 4*(nb/4); i < nb; ++i) {
- q8.process_1_block(i, deq, acc);
- }
-
- for (int iy = 0; iy < Q8::nrc_y; ++iy) {
- info.store(ix, iy, vaddvq_f32(acc[iy]));
- }
- }
-}
-
-template <typename Dequantizer, typename Q8>
-inline void mul_mat_qX_Y_q8_Y_IK(int n, Dequantizer& deq1, Dequantizer& deq2, Q8& q8, const DataInfo& info, int nrc_x) {
- const int nb = n / QK4_1;
-
- float16x4_t sc16[2*Q8::nrc_y];
- float32x4_t acc[2*Q8::nrc_y];
-
- for (int ix = 0; ix < nrc_x; ix += 2) {
-
- deq1.new_row(ix+0);
- deq2.new_row(ix+1);
-
- for (int iy = 0; iy < 2*Q8::nrc_y; ++iy) acc[iy] = vdupq_n_f32(0.f);
-
- for (int i = 0; i < nb/4; ++i) {
- q8.process_scales(i, deq1, deq2, sc16, acc);
- sum_4(i, deq1, deq2, q8, sc16, acc);
- }
- //for (int i = 4*(nb/4); i < nb; ++i) {
- // q8.process_1_block(i, deq, acc);
- //}
-
- for (int iy = 0; iy < Q8::nrc_y; ++iy) {
- info.store(ix+0, iy, vaddvq_f32(acc[iy]));
- info.store(ix+1, iy, vaddvq_f32(acc[iy+Q8::nrc_y]));
- }
- }
-}
-
-template <typename Dequantizer, typename Q8>
-inline void mul_mat_qX_Y_q8_Y_1(int n, Dequantizer& deq1, Dequantizer& deq2, Q8& q8, const DataInfo& info, int nrc_x) {
- const int nb = n / QK4_1;
-
- float16x4_t sc16[2];
-
- for (int ix = 0; ix < nrc_x; ++ix) {
-
- deq1.new_row(ix);
- deq2.new_row(ix);
-
- float32x4_t acc[2] = { vdupq_n_f32(0.f), vdupq_n_f32(0.f) };
-
- for (int i = 0; i < nb/8; ++i) {
- q8.process_scales(2*i+0, deq1, sc16+0, acc+0);
- q8.process_scales(2*i+1, deq2, sc16+1, acc+1);
- sum_4(2*i+0, deq1, q8, sc16+0, acc+0);
- sum_4(2*i+1, deq2, q8, sc16+1, acc+1);
- }
- for (int i = 2*(nb/8); i < nb/4; ++i) {
- q8.process_scales(i, deq1, sc16, acc);
- sum_4(i, deq1, q8, sc16, acc);
- }
- //for (int i = 4*(nb/4); i < nb; ++i) {
- // q8.process_1_block(i, deq1, acc);
- //}
-
- info.store(ix, 0, vaddvq_f32(vaddq_f32(acc[0], acc[1])));
- }
-}
-
-template <typename Dequantizer, int nrc_y>
-static void mul_mat_qX_1_q8_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
- Q81<nrc_y> q8(info);
- if constexpr (nrc_y == 1) {
- Dequantizer deq1(vx, bx), deq2(vx, bx);
- mul_mat_qX_Y_q8_Y_1(n, deq1, deq2, q8, info, nrc_x);
- } else {
- if (nrc_x%2 == 0 && n%128 == 0) {
- Dequantizer deq1(vx, bx), deq2(vx, bx);
- mul_mat_qX_Y_q8_Y_IK(n, deq1, deq2, q8, info, nrc_x);
- } else {
- Dequantizer deq(vx, bx);
- mul_mat_qX_Y_q8_Y(n, deq, q8, info, nrc_x);
- }
- }
-}
-
-template <typename Dequantizer, int nrc_y>
-static void mul_mat_qX_0_q8_0(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
- Q80<nrc_y> q8(info);
- if constexpr (nrc_y == 1) {
- Dequantizer deq1(vx, bx), deq2(vx, bx);
- mul_mat_qX_Y_q8_Y_1(n, deq1, deq2, q8, info, nrc_x);
- } else {
- if (nrc_x%2 == 0 && n%128 == 0) {
- Dequantizer deq1(vx, bx), deq2(vx, bx);
- mul_mat_qX_Y_q8_Y_IK(n, deq1, deq2, q8, info, nrc_x);
- } else {
- Dequantizer deq(vx, bx);
- mul_mat_qX_Y_q8_Y(n, deq, q8, info, nrc_x);
- }
- }
-}
-
-template <typename Dequantizer>
-static void mul_mat_qX_1_q8_1_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
- Dequantizer deq1(vx, bx), deq2(vx, bx);
- Q81<1> q8(info);
- mul_mat_qX_Y_q8_Y_1(n, deq1, deq2, q8, info, nrc_x);
-}
-
-template <typename Dequantizer>
-static void mul_mat_qX_0_q8_0_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
- Dequantizer deq1(vx, bx), deq2(vx, bx);
- Q80<1> q8(info);
- mul_mat_qX_Y_q8_Y_1(n, deq1, deq2, q8, info, nrc_x);
-}
-
-struct QF16Base {
- constexpr static int k_step = 8;
- using Data = float16x8_t;
- using Acc = float16x8_t;
- static inline Data load(const __fp16 * x) { return vld1q_f16(x); }
- static inline Data load4(const __fp16 * x) { return vcombine_f16(vld1_f16(x), vdup_n_f16(0)); }
- static inline Acc acc(Acc prev, const Data& y, const Data& x) {
- return vfmaq_f16(prev, y, x);
- }
- static inline Acc acc_first(const Data& y, const Data& x) {
- return vmulq_f16(y, x);
- }
- //constexpr static int k_step = 16;
- //using Data = float16x8x2_t;
- //static inline Data load(const __fp16 * x) { return vld1q_f16_x2(x); }
- //static inline Acc acc(Acc prev, const Data& y, const Data& x) {
- // return vfmaq_f16(vfmaq_f16(prev, y.val[0], x.val[0]), y.val[1], x.val[1]);
- //}
- //static inline Acc acc_first(const Data& y, const Data& x) {
- // return vfmaq_f16(vmulq_f16(y.val[0], x.val[0]), y.val[1], x.val[1]);
- //}
- static inline float hsum(Acc acc) {
- float32x4_t sum = vcvt_f32_f16(vadd_f16(vget_low_f16(acc), vget_high_f16(acc)));
- return vaddvq_f32(sum);
- }
-};
-template <int nrc> struct QF16 final : public QF16Base {
- using Base = QF16Base;
- constexpr static int nrc_y = nrc;
- QF16(const DataInfo& info) {
- for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const __fp16 *)info.src1_row(iy);
- }
- QF16(const char * cx, size_t bx) {
- for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const __fp16 *)(cx + iy*bx);
- }
- IQK_ALWAYS_INLINE Data load1(int iy, int i) const { return load(y[iy] + k_step*i); }
- IQK_ALWAYS_INLINE Data load_tail(int iy, int i) const { return load4(y[iy] + 4*i); }
- IQK_ALWAYS_INLINE float16x8x4_t loadx(int iy, int i) const { return vld1q_f16_x4(y[iy] + 4*k_step*i); }
- const __fp16 * y[nrc_y];
-};
-
-struct QBF16Base {
- constexpr static int k_step = 4;
- using Data = float32x4_t;
- using Acc = float32x4_t;
- static inline Data load(const uint16_t * x) { return vreinterpretq_f32_u32(vshlq_n_u32(vmovl_u16(vld1_u16(x)), 16)); }
- static inline Data load4(const uint16_t * x) { return load(x); }
- static inline Acc acc(Acc prev, const Data& y, const Data& x) {
- return vfmaq_f32(prev, y, x);
- }
- static inline Acc acc_first(const Data& y, const Data& x) {
- return vmulq_f32(y, x);
- }
- static inline float hsum(Acc acc) { return vaddvq_f32(acc); }
-};
-template <int nrc> struct QBF16 final : public QBF16Base {
- using Base = QBF16Base;
- constexpr static int nrc_y = nrc;
- QBF16(const DataInfo& info) {
- for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const uint16_t *)info.src1_row(iy);
- }
- QBF16(const char * cx, size_t bx) {
- for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const uint16_t *)(cx + iy*bx);
- }
- IQK_ALWAYS_INLINE Data load1(int iy, int i) const { return load(y[iy] + k_step*i); }
- IQK_ALWAYS_INLINE Data load_tail(int iy, int i) const { return load(y[iy] + 4*i); }
- const uint16_t * y[nrc_y];
-};
-
-struct QF32Base {
- constexpr static int k_step = 4;
- using Data = float32x4_t;
- using Acc = float32x4_t;
- static inline Data load(const float * x) { return vld1q_f32(x); }
- static inline Data load4(const float * x) { return load(x); }
- static inline Acc acc(Acc prev, const Data& y, const Data& x) { return vfmaq_f32(prev, y, x); }
- static inline Acc acc_first(const Data& y, const Data& x) { return vmulq_f32(y, x); }
- static inline float hsum(Acc acc) { return vaddvq_f32(acc); }
-};
-template <int nrc> struct QF32 final : public QF32Base {
- using Base = QF32Base;
- constexpr static int nrc_y = nrc;
- QF32(const DataInfo& info) {
- for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const float *)info.src1_row(iy);
- }
- QF32(const char * cx, size_t bx) {
- for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const float *)(cx + iy*bx);
- }
- IQK_ALWAYS_INLINE Data load1(int iy, int i) const { return load(y[iy] + k_step*i); }
- IQK_ALWAYS_INLINE Data load_tail(int iy, int i) const { return load(y[iy] + 4*i); }
- const float * y[nrc_y];
-};
-
-template <typename Qy, typename Qx, bool is_multiple_of_k_step>
-IQK_NOINLINE void mul_mat_Qx_Qy_NxN(int n, const char * cx, size_t bx, int ix0, const DataInfo& info) {
- GGML_ASSERT(Qx::Base::k_step == Qy::Base::k_step);
- int nb = n/Qx::Base::k_step;
- Qy y(info);
- Qx x(cx + ix0*bx, bx);
- typename Qx::Base::Data xv[Qx::nrc_y];
- typename Qx::Base::Acc acc[Qx::nrc_y*Qy::nrc_y];
- auto yv = y.load1(0, 0);
- for (int ix = 0; ix < Qx::nrc_y; ++ix) {
- xv[ix] = x.load1(ix, 0);
- acc[ix] = Qx::Base::acc_first(yv, xv[ix]);
- }
- for (int iy = 1; iy < Qy::nrc_y; ++iy) {
- yv = y.load1(iy, 0);
- for (int ix = 0; ix < Qx::nrc_y; ++ix) acc[Qx::nrc_y*iy + ix] = Qx::Base::acc_first(yv, xv[ix]);
- }
- for (int i = 1; i < nb; ++i) {
- yv = y.load1(0, i);
- for (int ix = 0; ix < Qx::nrc_y; ++ix) {
- xv[ix] = x.load1(ix, i);
- acc[ix] = Qx::Base::acc(acc[ix], yv, xv[ix]);
- }
- for (int iy = 1; iy < Qy::nrc_y; ++iy) {
- yv = y.load1(iy, i);
- for (int ix = 0; ix < Qx::nrc_y; ++ix) acc[Qx::nrc_y*iy + ix] = Qx::Base::acc(acc[Qx::nrc_y*iy + ix], yv, xv[ix]);
- }
- }
- if constexpr (Qx::Base::k_step > 4 && !is_multiple_of_k_step) {
- int nb4 = n/4;
- for (int i = (Qx::Base::k_step/4)*nb; i < nb4; ++i) {
- yv = y.load_tail(0, i);
- for (int ix = 0; ix < Qx::nrc_y; ++ix) {
- xv[ix] = x.load_tail(ix, i);
- acc[ix] = Qx::Base::acc(acc[ix], yv, xv[ix]);
- }
- for (int iy = 1; iy < Qy::nrc_y; ++iy) {
- yv = y.load_tail(iy, i);
- for (int ix = 0; ix < Qx::nrc_y; ++ix) acc[Qx::nrc_y*iy + ix] = Qx::Base::acc(acc[Qx::nrc_y*iy + ix], yv, xv[ix]);
- }
- }
- }
- for (int iy = 0; iy < Qy::nrc_y; ++iy) for (int ix = 0; ix < Qx::nrc_y; ++ix) info.store(ix0+ix, iy, Qx::Base::hsum(acc[Qx::nrc_y*iy+ix]));
-}
-
-template <int nrc_y, int nrc_x, bool is_multiple_of_k_step>
-IQK_NOINLINE void mul_mat_f16_f16_NxN(int n, const char * cx, size_t bx, int ix0, const DataInfo& info) {
- assert(n%QF16Base::k_step == 0);
- int nb = n/QF16Base::k_step;
- QF16<nrc_y> y(info);
- QF16<nrc_x> x(cx + ix0*bx, bx);
- QF16Base::Data xv[nrc_x];
- QF16Base::Acc acc[nrc_x*nrc_y];
- auto yv = y.load1(0, 0);
- for (int ix = 0; ix < nrc_x; ++ix) {
- xv[ix] = x.load1(ix, 0);
- acc[ix] = QF16Base::acc_first(yv, xv[ix]);
- }
- for (int iy = 1; iy < nrc_y; ++iy) {
- yv = y.load1(iy, 0);
- for (int ix = 0; ix < nrc_x; ++ix) acc[nrc_x*iy + ix] = QF16Base::acc_first(yv, xv[ix]);
- }
- for (int i = 1; i < nb; ++i) {
- yv = y.load1(0, i);
- for (int ix = 0; ix < nrc_x; ++ix) {
- xv[ix] = x.load1(ix, i);
- acc[ix] = QF16Base::acc(acc[ix], yv, xv[ix]);
- }
- for (int iy = 1; iy < nrc_y; ++iy) {
- yv = y.load1(iy, i);
- for (int ix = 0; ix < nrc_x; ++ix) acc[nrc_x*iy + ix] = QF16Base::acc(acc[nrc_x*iy + ix], yv, xv[ix]);
- }
- }
- if constexpr (!is_multiple_of_k_step) {
- int nb4 = n/4;
- for (int i = (QF16Base::k_step/4)*nb; i < nb4; ++i) {
- yv = y.load_tail(0, i);
- for (int ix = 0; ix < nrc_x; ++ix) {
- xv[ix] = x.load_tail(ix, i);
- acc[ix] = QF16Base::acc(acc[ix], yv, xv[ix]);
- }
- for (int iy = 1; iy < nrc_y; ++iy) {
- yv = y.load_tail(iy, i);
- for (int ix = 0; ix < nrc_x; ++ix) acc[nrc_x*iy + ix] = QF16Base::acc(acc[nrc_x*iy + ix], yv, xv[ix]);
- }
- }
- }
- for (int iy = 0; iy < nrc_y; ++iy) for (int ix = 0; ix < nrc_x; ++ix) info.store(ix0+ix, iy, QF16Base::hsum(acc[nrc_x*iy+ix]));
-}
-
-template <typename Qy, template<int> typename Qx>
-void mul_mat_Qx_Qy_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
- GGML_ASSERT(n%4 == 0);
- constexpr int k_nx = 5;
- const char * cx = (const char *)vx;
- if (n%Qx<k_nx>::Base::k_step == 0) {
- for (int ix = 0; ix < nrc_x/k_nx; ++ix) {
- mul_mat_Qx_Qy_NxN<Qy, Qx<k_nx>, true>(n, cx, bx, ix*k_nx, info);
- }
- int last_x = k_nx*(nrc_x/k_nx);
- if (last_x == nrc_x) return;
- int nx = nrc_x - last_x;
- switch (nx) {
- case 1: mul_mat_Qx_Qy_NxN<Qy, Qx<1>, true>(n, cx, bx, last_x, info); break;
- case 2: mul_mat_Qx_Qy_NxN<Qy, Qx<2>, true>(n, cx, bx, last_x, info); break;
- case 3: mul_mat_Qx_Qy_NxN<Qy, Qx<3>, true>(n, cx, bx, last_x, info); break;
- case 4: mul_mat_Qx_Qy_NxN<Qy, Qx<4>, true>(n, cx, bx, last_x, info); break;
- }
- } else {
- for (int ix = 0; ix < nrc_x/k_nx; ++ix) {
- mul_mat_Qx_Qy_NxN<Qy, Qx<k_nx>, false>(n, cx, bx, ix*k_nx, info);
- }
- int last_x = k_nx*(nrc_x/k_nx);
- if (last_x == nrc_x) return;
- int nx = nrc_x - last_x;
- switch (nx) {
- case 1: mul_mat_Qx_Qy_NxN<Qy, Qx<1>, false>(n, cx, bx, last_x, info); break;
- case 2: mul_mat_Qx_Qy_NxN<Qy, Qx<2>, false>(n, cx, bx, last_x, info); break;
- case 3: mul_mat_Qx_Qy_NxN<Qy, Qx<3>, false>(n, cx, bx, last_x, info); break;
- case 4: mul_mat_Qx_Qy_NxN<Qy, Qx<4>, false>(n, cx, bx, last_x, info); break;
- }
- }
-}
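
As a usage sketch (the wrapper name is hypothetical; the real bindings may differ), this driver can be instantiated with any Qy/Qx pair whose k_step values match, for example bf16 weights against f32 activations:

    // Hypothetical wrapper: bf16 row data (Qx = QBF16) multiplied by f32
    // activations (Qy = QF32<nrc_y>); both have k_step == 4, as required by
    // the assert in mul_mat_Qx_Qy_NxN.
    template <int nrc_y>
    static void mul_mat_bf16_f32_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
        mul_mat_Qx_Qy_T<QF32<nrc_y>, QBF16>(n, vx, bx, info, nrc_x);
    }
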
-
-template <int nrc_y>
-void mul_mat_f16_f16_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
- GGML_ASSERT(n%4 == 0);
- constexpr int k_nx = 5;
- const char * cx = (const char *)vx;
- if (n%QF16Base::k_step == 0) {
- for (int ix = 0; ix < nrc_x/k_nx; ++ix) {
- mul_mat_f16_f16_NxN<nrc_y, k_nx, true>(n, cx, bx, ix*k_nx, info);
- }
- int last_x = k_nx*(nrc_x/k_nx);
- if (last_x == nrc_x) return;
- int nx = nrc_x - last_x;
- switch (nx) {
- case 1: mul_mat_f16_f16_NxN<nrc_y, 1, true>(n, cx, bx, last_x, info); break;
- case 2: mul_mat_f16_f16_NxN<nrc_y, 2, true>(n, cx, bx, last_x, info); break;
- case 3: mul_mat_f16_f16_NxN<nrc_y, 3, true>(n, cx, bx, last_x, info); break;
- case 4: mul_mat_f16_f16_NxN<nrc_y, 4, true>(n, cx, bx, last_x, info); break;
- }
- } else {
- for (int ix = 0; ix < nrc_x/k_nx; ++ix) {
- mul_mat_f16_f16_NxN<nrc_y, k_nx, false>(n, cx, bx, ix*k_nx, info);
- }
- int last_x = k_nx*(nrc_x/k_nx);
- if (last_x == nrc_x) return;
- int nx = nrc_x - last_x;
- switch (nx) {
- case 1: mul_mat_f16_f16_NxN<nrc_y, 1, false>(n, cx, bx, last_x, info); break;
- case 2: mul_mat_f16_f16_NxN<nrc_y, 2, false>(n, cx, bx, last_x, info); break;
- case 3: mul_mat_f16_f16_NxN<nrc_y, 3, false>(n, cx, bx, last_x, info); break;
- case 4: mul_mat_f16_f16_NxN<nrc_y, 4, false>(n, cx, bx, last_x, info); break;
- }
- }
-}
-
-template <int nrc_x, bool is_multiple_of_k_step>
-IQK_NOINLINE void mul_mat_f16_f16_Nx1(int n, const char * cx, size_t bx, int ix0, const DataInfo& info) {
- assert(n%QF16Base::k_step == 0);
- int nb = n/QF16Base::k_step;
- QF16<1> y(info);
- QF16<nrc_x> x(cx + ix0*bx, bx);
- QF16Base::Acc acc[4*nrc_x];
- auto yv = y.loadx(0, 0);
- for (int ix = 0; ix < nrc_x; ++ix) {
- for (int k = 0; k < 4; ++k) {
- auto xv = x.load1(ix, k);
- acc[4*ix+k] = QF16Base::acc_first(yv.val[k], xv);
- }
- }
- for (int i = 1; i < nb/4; ++i) {
- yv = y.loadx(0, i);
- for (int ix = 0; ix < nrc_x; ++ix) {
- for (int k = 0; k < 4; ++k) {
- auto xv = x.load1(ix, 4*i+k);
- acc[4*ix+k] = QF16Base::acc(acc[4*ix+k], yv.val[k], xv);
- }
- }
- }
- for (int i = 4*(nb/4); i < nb; ++i) {
- auto yv1 = y.load1(0, i);
- for (int ix = 0; ix < nrc_x; ++ix) {
- auto xv1 = x.load1(ix, i);
- acc[4*ix] = QF16Base::acc(acc[4*ix], yv1, xv1);
- }
- }
- if constexpr (!is_multiple_of_k_step) {
- int nb4 = n/4;
- for (int i = (QF16Base::k_step/4)*nb; i < nb4; ++i) {
- auto yv1 = y.load_tail(0, i);
- for (int ix = 0; ix < nrc_x; ++ix) {
- auto xv1 = x.load_tail(ix, i);
- acc[4*ix] = QF16Base::acc(acc[4*ix], yv1, xv1);
- }
- }
- }
- for (int ix = 0; ix < nrc_x; ++ix) {
- auto v1 = vaddq_f16(acc[4*ix+0], acc[4*ix+1]);
- auto v2 = vaddq_f16(acc[4*ix+2], acc[4*ix+3]);
- info.store(ix0+ix, 0, QF16Base::hsum(vaddq_f16(v1, v2)));
- }
-}
-
-// At least on my M2-Max, the uncommented version further below, which does the multiplication row-by-row, is faster.
-// So let's keep this blocked version commented out for now.
-//void mul_mat_f16_f16_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
-// GGML_ASSERT(n%4 == 0);
-// constexpr int k_nx = 2;
-// const char * cx = (const char *)vx;
-// if (n%QF16Base::k_step == 0) {
-// for (int ix = 0; ix < nrc_x/k_nx; ++ix) {
-// mul_mat_f16_f16_Nx1<k_nx, true>(n, cx, bx, ix*k_nx, info);
-// }
-// int last_x = k_nx*(nrc_x/k_nx);
-// if (last_x == nrc_x) return;
-// int nx = nrc_x - last_x;
-// switch (nx) {
-// case 1: mul_mat_f16_f16_Nx1<1, true>(n, cx, bx, last_x, info); break;
-// //case 2: mul_mat_f16_f16_Nx1<2, true>(n, cx, bx, last_x, info); break;
-// //case 3: mul_mat_f16_f16_Nx1<3, true>(n, cx, bx, last_x, info); break;
-// }
-// } else {
-// for (int ix = 0; ix < nrc_x/k_nx; ++ix) {
-// mul_mat_f16_f16_Nx1<k_nx, false>(n, cx, bx, ix*k_nx, info);
-// }
-// int last_x = k_nx*(nrc_x/k_nx);
-// if (last_x == nrc_x) return;
-// int nx = nrc_x - last_x;
-// switch (nx) {
-// case 1: mul_mat_f16_f16_Nx1<1, false>(n, cx, bx, last_x, info); break;
-// //case 2: mul_mat_f16_f16_Nx1<2, false>(n, cx, bx, last_x, info); break;
-// //case 3: mul_mat_f16_f16_Nx1<3, false>(n, cx, bx, last_x, info); break;
-// }
-// }
-//}
-
-void mul_mat_f16_f16_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
- GGML_ASSERT(n%4 == 0);
- const char * cx = (const char *)vx;
- if (n%QF16Base::k_step == 0) {
- for (int ix = 0; ix < nrc_x; ++ix) {
- mul_mat_f16_f16_Nx1<1, true>(n, cx, bx, ix, info);
- }
- } else {
- for (int ix = 0; ix < nrc_x; ++ix) {
- mul_mat_f16_f16_Nx1<1, false>(n, cx, bx, ix, info);
- }
- }
-}
-
-template <int nrc> struct Q8_K64 {
-
- constexpr static int nrc_y = nrc;
-
- Q8_K64(const DataInfo& info) {
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto dptr = (const float *)info.src1_row(iy);
- std::memcpy(d + 8*iy, dptr, 8*sizeof(float));
- y[iy] = (const int8_t *)(dptr + 8);
- }
- }
-
- inline int8x16x4_t load_quants64(int iy, int i, int j) const { return vld1q_s8_x4(y[iy] + 128*i + 64*j); }
- inline int8x16x2_t load_quants(int iy, int i, int j) const { return vld1q_s8_x2(y[iy] + 128*i + 32*j); }
- inline float32x4_t scale(int iy) const { return vld1q_f32(d + 8*iy); }
- inline float32x4_t minus(int iy) const { return vld1q_f32(d + 8*iy + 4); }
-
- float d[8*nrc_y];
- const int8_t * y[nrc_y];
-};
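
Reading the constructor together with the `scale()`/`minus()` accessors, the per-row layout this struct expects appears to be eight floats followed by the int8 quants. A sketch of that assumption (names hypothetical, for illustration only):

    #include <stdint.h>

    // Assumed layout of one quantized activation row consumed by Q8_K64:
    //   float  d[4];   // per-row scales, read by scale(iy)
    //   float  m[4];   // per-row subtract terms, read by minus(iy)
    //   int8_t qs[];   // the quants, 64 per block, loaded 64 or 128 bytes at a time
    struct q8_K64_row_view {
        const float  * d;    // d[0..3] scales, d[4..7] minus terms
        const int8_t * qs;
    };
    static inline q8_K64_row_view make_q8_K64_view(const void * row) {
        const float * dptr = (const float *)row;
        return { dptr, (const int8_t *)(dptr + 8) };
    }
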
-
-struct DequantizerIQ1BN {
- const uint8x16_t m1 = vdupq_n_u8(1);
-
- static inline uint8x16x4_t load_shuffles() {
- static const uint8_t data[64] = {0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 12,
- 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 12,
- 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 8, 8, 8, 8, 8, 12,
- 9, 9, 9, 9, 9, 10, 10, 10, 10, 10, 11, 11, 11, 11, 11, 12};
- return vld1q_u8_x4(data);
- }
- static inline uint8x16x4_t load_mult() {
- static const uint8_t data[64] = {81, 27, 9, 3, 1, 81, 27, 9, 3, 1, 81, 27, 9, 3, 1, 81,
- 81, 27, 9, 3, 1, 81, 27, 9, 3, 1, 81, 27, 9, 3, 1, 27,
- 81, 27, 9, 3, 1, 81, 27, 9, 3, 1, 81, 27, 9, 3, 1, 9,
- 81, 27, 9, 3, 1, 81, 27, 9, 3, 1, 81, 27, 9, 3, 1, 3};
- return vld1q_u8_x4(data);
- }
- const uint8x16x4_t shuff = load_shuffles();
- const uint8x16x4_t mult = load_mult();
-
- IQK_ALWAYS_INLINE void prepare_iq1bn_quants(const block_iq1_bn * x, int8x16x4_t& v) const {
- auto data = vld1q_u8((const uint8_t *)x);
- for (int k = 0; k < 4; ++k) {
- auto val = vmulq_u8(vqtbl1q_u8(data, shuff.val[k]), mult.val[k]);
- val = vshrq_n_u8(vhaddq_u8(val, vshrq_n_u8(val, 1)), 6);
- v.val[k] = vsubq_s8(vreinterpretq_s8_u8(val), m1);
- }
- }
-
- IQK_ALWAYS_INLINE void prepare_iq1bn_quants_nosub(const block_iq1_bn * x, int8x16x4_t& v) const {
- auto data = vld1q_u8((const uint8_t *)x);
- for (int k = 0; k < 4; ++k) {
- auto val = vmulq_u8(vqtbl1q_u8(data, shuff.val[k]), mult.val[k]);
- v.val[k] = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(val, vshrq_n_u8(val, 1)), 6));
- }
- }
-};
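
The multiply/halving-add/shift sequence above is the iq1_bn trick for pulling up to five ternary digits out of one packed byte. A scalar sketch of the same arithmetic (function name illustrative), mapping each digit to {-1, 0, +1} exactly as prepare_iq1bn_quants does:

    #include <stdint.h>

    // Extract digit k (0..4) from one packed iq1_bn byte. vhaddq_u8(v, v >> 1)
    // followed by a right shift by 6 is the vector form of ((v + (v >> 1)) >> 7).
    static inline int8_t iq1bn_digit(uint8_t q, int k) {
        static const uint8_t k_mult[5] = {81, 27, 9, 3, 1};
        uint8_t v = (uint8_t)(q * k_mult[k]);        // low byte of the product
        return (int8_t)(((v + (v >> 1)) >> 7) - 1);  // {0,1,2} -> {-1,0,+1}
    }
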
-
-template <int nrc_y>
-static void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
- const int nb = n / QK_IQ1BN;
-
- Q8_K64<nrc_y> q8(info);
- DequantizerIQ1BN deq;
-
- int32x4_t accd[nrc_y];
- int8x16x4_t v1, v2;
-
- float scale;
- ggml_half d16;
- char * c16 = (char *)&d16;
-
- for (int ix = 0; ix < nrc_x; ++ix) {
-
- const char * cx = ((const char *)vx + ix*bx);
- c16[0] = cx[0]; c16[1] = cx[1];
- //std::memcpy(&d16, cx, sizeof(d16));
- cx += sizeof(d16);
- scale = GGML_FP16_TO_FP32(d16);
-
- const block_iq1_bn * x = (const block_iq1_bn *)cx;
-
- if constexpr (nrc_y == 1) {
- int32x4_t acc[4] = {};
- for (int i = 0; i < nb/2; ++i) {
- deq.prepare_iq1bn_quants_nosub(x+2*i+0, v1);
- auto q = q8.load_quants64(0, i, 0);
- for (int j = 0; j < 4; ++j) acc[j] = ggml_vdotq_s32(acc[j], q.val[j], v1.val[j]);
- deq.prepare_iq1bn_quants_nosub(x+2*i+1, v2);
- q = q8.load_quants64(0, i, 1);
- for (int j = 0; j < 4; ++j) acc[j] = ggml_vdotq_s32(acc[j], q.val[j], v2.val[j]);
- }
- accd[0] = vaddq_s32(vaddq_s32(acc[0], acc[1]), vaddq_s32(acc[2], acc[3]));
- }
- else {
-
- for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = vdupq_n_s32(0);
-
- for (int i = 0; i < nb/2; ++i) {
-
- deq.prepare_iq1bn_quants_nosub(x+2*i+0, v1);
- deq.prepare_iq1bn_quants_nosub(x+2*i+1, v2);
-
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto q = q8.load_quants(iy, i, 0);
- accd[iy] = ggml_vdotq_s32(ggml_vdotq_s32(accd[iy], q.val[0], v1.val[0]), q.val[1], v1.val[1]);
- q = q8.load_quants(iy, i, 1);
- accd[iy] = ggml_vdotq_s32(ggml_vdotq_s32(accd[iy], q.val[0], v1.val[2]), q.val[1], v1.val[3]);
- q = q8.load_quants(iy, i, 2);
- accd[iy] = ggml_vdotq_s32(ggml_vdotq_s32(accd[iy], q.val[0], v2.val[0]), q.val[1], v2.val[1]);
- q = q8.load_quants(iy, i, 3);
- accd[iy] = ggml_vdotq_s32(ggml_vdotq_s32(accd[iy], q.val[0], v2.val[2]), q.val[1], v2.val[3]);
- }
- }
- }
- int i = 2*(nb/2);
- if (i < nb) {
- deq.prepare_iq1bn_quants_nosub(x+i, v1);
- if constexpr (nrc_y == 1) {
- auto q = q8.load_quants(0, i/2, 0);
- for (int j = 0; j < 4; ++j) {
- accd[0] = ggml_vdotq_s32(accd[0], q.val[j], v1.val[j]);
- }
- } else {
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto q = q8.load_quants(iy, i/2, 0);
- accd[iy] = ggml_vdotq_s32(ggml_vdotq_s32(accd[iy], q.val[0], v1.val[0]), q.val[1], v1.val[1]);
- q = q8.load_quants(iy, i/2, 1);
- accd[iy] = ggml_vdotq_s32(ggml_vdotq_s32(accd[iy], q.val[0], v1.val[2]), q.val[1], v1.val[3]);
- }
- }
- }
-
- for (int iy = 0; iy < nrc_y; ++iy) {
- info.store(ix, iy, -scale * vaddvq_f32(vfmsq_f32(q8.minus(iy), q8.scale(iy), vcvtq_f32_s32(accd[iy]))));
- }
-
- }
-}
-
-template <int nrc> struct Q8_16 {
-
- constexpr static int nrc_y = nrc;
-
- Q8_16(const DataInfo& info) {
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto ptr = (const float *)info.src1_row(iy);
- std::memcpy(d + 5*iy, ptr, 5*sizeof(float));
- y[iy] = (const int8_t *)(ptr + 5);
- }
- }
-
- inline int8x16x4_t load_quants(int iy, int i) const { return vld1q_s8_x4(y[iy] + 64*i); }
- inline int8x16x2_t load_quants_32(int iy, int i) const { return vld1q_s8_x2(y[iy] + 32*i); }
- inline float scale(int iy, int k) const { return d[5*iy+k]; }
- inline float sum_row(int iy) const { return d[5*iy + 4]; }
- inline float32x4_t scale(int iy) const { return vld1q_f32(d + 5*iy); }
-
- float d[5*nrc_y];
- const int8_t * y[nrc_y];
-};
-
-template <int nrc_y>
-static IQK_NOINLINE void mul_mat_iq2_bn_r4_q8_k16(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
- if (nrc_x%4) {
- printf("%s: %d is not a multiple of 4\n", __func__, nrc_x);
- GGML_ABORT("fatal error");
- }
- Q8_16<nrc_y> q8(info);
- auto m3 = vdupq_n_u8(0x3);
- int nb = n / QK_IQ1BN;
- if constexpr (nrc_y == 1) {
- auto mc = vdupq_n_u8(0xc);
- int32x4_t acc[8];
- for (int ix = 0; ix < nrc_x; ix += 4) {
- for (int k = 0; k < 8; ++k) acc[k] = vdupq_n_s32(0);
- const float * dptr = (const float *)((const char *)vx + ix*bx);
- auto dl = vld1q_f32(dptr);
- const uint8_t * iq2 = (const uint8_t *)(dptr + 4);
- for (int ib = 0; ib < nb; ++ib) {
- auto y = q8.load_quants(0, ib);
- for (int j = 0; j < 4; ++j) {
- auto bits1 = vld1q_u8(iq2 + 64*ib + 16*j);
- auto bits2 = vshrq_n_u8(bits1, 4);
- acc[2*j+0] = vdotq_laneq_s32(acc[2*j+0], vandq_u8(bits1, m3), y.val[j], 0);
- acc[2*j+1] = vdotq_laneq_s32(acc[2*j+1], vandq_u8(bits1, mc), y.val[j], 1);
- acc[2*j+0] = vdotq_laneq_s32(acc[2*j+0], vandq_u8(bits2, m3), y.val[j], 2);
- acc[2*j+1] = vdotq_laneq_s32(acc[2*j+1], vandq_u8(bits2, mc), y.val[j], 3);
- }
- }
- auto dy = vmulq_f32(dl, vdupq_n_f32(q8.scale(0, 0)));
- auto sumf1 = vmulq_f32( vcvtq_f32_s32(acc[0]), dy);
- auto sumf2 = vmulq_f32( vcvtq_f32_s32(acc[1]), dy);
- dy = vmulq_f32(dl, vdupq_n_f32(q8.scale(0, 1)));
- sumf1 = vfmaq_f32(sumf1, vcvtq_f32_s32(acc[2]), dy);
- sumf2 = vfmaq_f32(sumf2, vcvtq_f32_s32(acc[3]), dy);
- dy = vmulq_f32(dl, vdupq_n_f32(q8.scale(0, 2)));
- sumf1 = vfmaq_f32(sumf1, vcvtq_f32_s32(acc[4]), dy);
- sumf2 = vfmaq_f32(sumf2, vcvtq_f32_s32(acc[5]), dy);
- dy = vmulq_f32(dl, vdupq_n_f32(q8.scale(0, 3)));
- sumf1 = vfmaq_f32(sumf1, vcvtq_f32_s32(acc[6]), dy);
- sumf2 = vfmaq_f32(sumf2, vcvtq_f32_s32(acc[7]), dy);
- auto sumf = vfmaq_f32(sumf1, vdupq_n_f32(0.25f), sumf2);
- sumf = vfmaq_f32(sumf, dl, vdupq_n_f32(-q8.sum_row(0)));
- info.store(ix, 0, sumf);
- }
- } else {
- int32x4_t acc[4*nrc_y] = {};
- uint8x16_t qx[8];
- for (int ix = 0; ix < nrc_x; ix += 4) {
- const float * dptr = (const float *)((const char *)vx + ix*bx);
- auto dl = vld1q_f32(dptr);
- const uint8_t * iq2 = (const uint8_t *)(dptr + 4);
- for (int ib = 0; ib < nb; ++ib) {
- auto bits = vld1q_u8_x2(iq2 + 64*ib);
- qx[0] = vandq_u8(bits.val[0], m3);
- qx[1] = vandq_u8(vshrq_n_u8(bits.val[0], 2), m3);
- qx[2] = vandq_u8(vshrq_n_u8(bits.val[0], 4), m3);
- qx[3] = vshrq_n_u8(bits.val[0], 6);
- qx[4] = vandq_u8(bits.val[1], m3);
- qx[5] = vandq_u8(vshrq_n_u8(bits.val[1], 2), m3);
- qx[6] = vandq_u8(vshrq_n_u8(bits.val[1], 4), m3);
- qx[7] = vshrq_n_u8(bits.val[1], 6);
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto y = q8.load_quants_32(iy, 2*ib+0);
- acc[4*iy + 0] = vdotq_laneq_s32(acc[4*iy + 0], qx[0], y.val[0], 0);
- acc[4*iy + 0] = vdotq_laneq_s32(acc[4*iy + 0], qx[1], y.val[0], 1);
- acc[4*iy + 0] = vdotq_laneq_s32(acc[4*iy + 0], qx[2], y.val[0], 2);
- acc[4*iy + 0] = vdotq_laneq_s32(acc[4*iy + 0], qx[3], y.val[0], 3);
- acc[4*iy + 1] = vdotq_laneq_s32(acc[4*iy + 1], qx[4], y.val[1], 0);
- acc[4*iy + 1] = vdotq_laneq_s32(acc[4*iy + 1], qx[5], y.val[1], 1);
- acc[4*iy + 1] = vdotq_laneq_s32(acc[4*iy + 1], qx[6], y.val[1], 2);
- acc[4*iy + 1] = vdotq_laneq_s32(acc[4*iy + 1], qx[7], y.val[1], 3);
- }
- bits = vld1q_u8_x2(iq2 + 64*ib + 32);
- qx[0] = vandq_u8(bits.val[0], m3);
- qx[1] = vandq_u8(vshrq_n_u8(bits.val[0], 2), m3);
- qx[2] = vandq_u8(vshrq_n_u8(bits.val[0], 4), m3);
- qx[3] = vshrq_n_u8(bits.val[0], 6);
- qx[4] = vandq_u8(bits.val[1], m3);
- qx[5] = vandq_u8(vshrq_n_u8(bits.val[1], 2), m3);
- qx[6] = vandq_u8(vshrq_n_u8(bits.val[1], 4), m3);
- qx[7] = vshrq_n_u8(bits.val[1], 6);
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto y = q8.load_quants_32(iy, 2*ib+1);
- acc[4*iy + 2] = vdotq_laneq_s32(acc[4*iy + 2], qx[0], y.val[0], 0);
- acc[4*iy + 2] = vdotq_laneq_s32(acc[4*iy + 2], qx[1], y.val[0], 1);
- acc[4*iy + 2] = vdotq_laneq_s32(acc[4*iy + 2], qx[2], y.val[0], 2);
- acc[4*iy + 2] = vdotq_laneq_s32(acc[4*iy + 2], qx[3], y.val[0], 3);
- acc[4*iy + 3] = vdotq_laneq_s32(acc[4*iy + 3], qx[4], y.val[1], 0);
- acc[4*iy + 3] = vdotq_laneq_s32(acc[4*iy + 3], qx[5], y.val[1], 1);
- acc[4*iy + 3] = vdotq_laneq_s32(acc[4*iy + 3], qx[6], y.val[1], 2);
- acc[4*iy + 3] = vdotq_laneq_s32(acc[4*iy + 3], qx[7], y.val[1], 3);
- }
- }
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto dy = q8.scale(iy);
- float32x4_t sumf = vmulq_f32(vcvtq_f32_s32(acc[4*iy+0]), vmulq_laneq_f32(dl, dy, 0));
- sumf = vfmaq_f32(sumf, vcvtq_f32_s32(acc[4*iy+1]), vmulq_laneq_f32(dl, dy, 1));
- sumf = vfmaq_f32(sumf, vcvtq_f32_s32(acc[4*iy+2]), vmulq_laneq_f32(dl, dy, 2));
- sumf = vfmaq_f32(sumf, vcvtq_f32_s32(acc[4*iy+3]), vmulq_laneq_f32(dl, dy, 3));
- sumf = vfmaq_f32(sumf, dl, vdupq_n_f32(-q8.sum_row(iy)));
- info.store(ix, iy, sumf);
- acc[4*iy+0] = acc[4*iy+1] = acc[4*iy+2] = acc[4*iy+3] = vdupq_n_s32(0);
- }
- }
- }
-}
-
-template <int nrc_y>
-static void mul_mat_iq2bn_q8_K64(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
- const int nb = n / QK_IQ1BN;
-
- Q8_K64<nrc_y> q8(info);
-
- int32x4_t accd[nrc_y];
-
- const auto mask2 = vdupq_n_s8(3);
-
- for (int ix = 0; ix < nrc_x; ++ix) {
-
- const float * dptr = (const float *)((const char *)vx + ix*bx);
- const float d = *dptr;
- const block_iq2_bn * x = (const block_iq2_bn *)(dptr + 1);
-
- if constexpr (nrc_y == 1) {
- int8x16x4_t v1;
- int32x4_t acc[4] = {};
- for (int i = 0; i < nb/2; ++i) {
- for (int j = 0; j < 2; ++j) {
- auto q = q8.load_quants64(0, i, j);
- auto q2bits = vld1q_u8(x[2*i+j].qs);
- v1.val[0] = vandq_s8(q2bits, mask2);
- v1.val[1] = vandq_s8(vshrq_n_u8(q2bits, 2), mask2);
- v1.val[2] = vandq_s8(vshrq_n_u8(q2bits, 4), mask2);
- v1.val[3] = vshrq_n_u8(q2bits, 6);
- acc[0] = ggml_vdotq_s32(acc[0], q.val[0], v1.val[0]);
- acc[1] = ggml_vdotq_s32(acc[1], q.val[1], v1.val[1]);
- acc[2] = ggml_vdotq_s32(acc[2], q.val[2], v1.val[2]);
- acc[3] = ggml_vdotq_s32(acc[3], q.val[3], v1.val[3]);
- }
- }
- accd[0] = vaddq_s32(vaddq_s32(acc[0], acc[1]), vaddq_s32(acc[2], acc[3]));
- } else {
- int8x16x4_t v1, v2;
- for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = vdupq_n_s32(0);
- for (int i = 0; i < nb/2; ++i) {
- auto q2bits = vld1q_u8(x[2*i+0].qs);
- v1.val[0] = vandq_s8(q2bits, mask2);
- v1.val[1] = vandq_s8(vshrq_n_u8(q2bits, 2), mask2);
- v1.val[2] = vandq_s8(vshrq_n_u8(q2bits, 4), mask2);
- v1.val[3] = vshrq_n_u8(q2bits, 6);
- q2bits = vld1q_u8(x[2*i+1].qs);
- v2.val[0] = vandq_s8(q2bits, mask2);
- v2.val[1] = vandq_s8(vshrq_n_u8(q2bits, 2), mask2);
- v2.val[2] = vandq_s8(vshrq_n_u8(q2bits, 4), mask2);
- v2.val[3] = vshrq_n_u8(q2bits, 6);
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto q = q8.load_quants(iy, i, 0);
- accd[iy] = ggml_vdotq_s32(ggml_vdotq_s32(accd[iy], q.val[0], v1.val[0]), q.val[1], v1.val[1]);
- q = q8.load_quants(iy, i, 1);
- accd[iy] = ggml_vdotq_s32(ggml_vdotq_s32(accd[iy], q.val[0], v1.val[2]), q.val[1], v1.val[3]);
- q = q8.load_quants(iy, i, 2);
- accd[iy] = ggml_vdotq_s32(ggml_vdotq_s32(accd[iy], q.val[0], v2.val[0]), q.val[1], v2.val[1]);
- q = q8.load_quants(iy, i, 3);
- accd[iy] = ggml_vdotq_s32(ggml_vdotq_s32(accd[iy], q.val[0], v2.val[2]), q.val[1], v2.val[3]);
- }
- }
- }
- int i = 2*(nb/2);
- if (i < nb) {
- auto q2bits = vld1q_u8(x[i].qs);
- int8x16x4_t v1;
- v1.val[0] = vandq_s8(q2bits, mask2);
- v1.val[1] = vandq_s8(vshrq_n_u8(q2bits, 2), mask2);
- v1.val[2] = vandq_s8(vshrq_n_u8(q2bits, 4), mask2);
- v1.val[3] = vshrq_n_u8(q2bits, 6);
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto q = q8.load_quants(iy, i/2, 0);
- accd[iy] = ggml_vdotq_s32(ggml_vdotq_s32(accd[iy], q.val[0], v1.val[0]), q.val[1], v1.val[1]);
- q = q8.load_quants(iy, i/2, 1);
- accd[iy] = ggml_vdotq_s32(ggml_vdotq_s32(accd[iy], q.val[0], v1.val[2]), q.val[1], v1.val[3]);
- }
- }
-
- for (int iy = 0; iy < nrc_y; ++iy) {
- info.store(ix, iy, -d*vaddvq_f32(vfmsq_f32(q8.minus(iy), q8.scale(iy), vcvtq_f32_s32(accd[iy]))));
- }
- }
-}
-
-IQK_ALWAYS_INLINE int32x4_t interleaved_dotq(const int8x16_t * qx, const int8x16x2_t& y) {
- auto sumi = vdupq_n_s32(0);
- sumi = vdotq_laneq_s32(sumi, qx[0], y.val[0], 0);
- sumi = vdotq_laneq_s32(sumi, qx[1], y.val[1], 0);
- sumi = vdotq_laneq_s32(sumi, qx[2], y.val[0], 1);
- sumi = vdotq_laneq_s32(sumi, qx[3], y.val[1], 1);
- sumi = vdotq_laneq_s32(sumi, qx[4], y.val[0], 2);
- sumi = vdotq_laneq_s32(sumi, qx[5], y.val[1], 2);
- sumi = vdotq_laneq_s32(sumi, qx[6], y.val[0], 3);
- sumi = vdotq_laneq_s32(sumi, qx[7], y.val[1], 3);
- return sumi;
-}
-
-IQK_ALWAYS_INLINE int32x4x2_t interleaved_dotq_b16(const int8x16_t * qx, const int8x16x2_t& y) {
- int32x4x2_t sumi = { vdupq_n_s32(0), vdupq_n_s32(0) };
- sumi.val[0] = vdotq_laneq_s32(sumi.val[0], qx[0], y.val[0], 0);
- sumi.val[1] = vdotq_laneq_s32(sumi.val[1], qx[1], y.val[1], 0);
- sumi.val[0] = vdotq_laneq_s32(sumi.val[0], qx[2], y.val[0], 1);
- sumi.val[1] = vdotq_laneq_s32(sumi.val[1], qx[3], y.val[1], 1);
- sumi.val[0] = vdotq_laneq_s32(sumi.val[0], qx[4], y.val[0], 2);
- sumi.val[1] = vdotq_laneq_s32(sumi.val[1], qx[5], y.val[1], 2);
- sumi.val[0] = vdotq_laneq_s32(sumi.val[0], qx[6], y.val[0], 3);
- sumi.val[1] = vdotq_laneq_s32(sumi.val[1], qx[7], y.val[1], 3);
- return sumi;
-}
-
-IQK_ALWAYS_INLINE int32x4_t interleaved_dotq(const int8x16_t * qx, const int8x16_t& y) {
- auto sumi = vdupq_n_s32(0);
- sumi = vdotq_laneq_s32(sumi, qx[0], y, 0);
- sumi = vdotq_laneq_s32(sumi, qx[1], y, 1);
- sumi = vdotq_laneq_s32(sumi, qx[2], y, 2);
- sumi = vdotq_laneq_s32(sumi, qx[3], y, 3);
- return sumi;
-}
-
-IQK_ALWAYS_INLINE void prepare_iq4_nl_quants(const int8x16_t& values, const uint8x16_t& m4, const uint8x16x4_t& bits, int8x16_t * qx) {
- qx[0] = vqtbl1q_s8(values, vandq_u8(bits.val[0], m4)); // 0...3 from the 4 rows
- qx[1] = vqtbl1q_s8(values, vandq_u8(bits.val[1], m4)); // 16..19
- qx[2] = vqtbl1q_s8(values, vandq_u8(bits.val[2], m4)); // 4...7
- qx[3] = vqtbl1q_s8(values, vandq_u8(bits.val[3], m4)); // 20..23
- qx[4] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[0], 4)); // 8..11
- qx[5] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[1], 4)); // 24..27
- qx[6] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[2], 4)); // 12..15
- qx[7] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[3], 4)); // 28..31
-}
-
-IQK_ALWAYS_INLINE void prepare_iq4_nl_quants_r8(const int8x16_t& values, const uint8x16_t& m4, const uint8x16x2_t& bits, int8x16_t * qx) {
- qx[0] = vqtbl1q_s8(values, vandq_u8( bits.val[0], m4));
- qx[1] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[0], 4));
- qx[2] = vqtbl1q_s8(values, vandq_u8( bits.val[1], m4));
- qx[3] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[1], 4));
-}
-
-template <int nrc_y>
-void mul_mat_iq4_xs_r8_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
- GGML_ASSERT(nrc_x%4 == 0);
- Q8<nrc_y, block_q8_K> q8(info);
- auto m4 = vdupq_n_u8(0xf);
- auto m3 = vdupq_n_u8(0x30);
- auto m32 = vdupq_n_s8(-32);
- auto values = vld1q_s8(iq4k_values);
- int nbl = n / QK_K;
- int8x16_t qx[8];
- int8x16x4_t iscales;
- int32x4x2_t scales;
- float32x4_t acc[2*nrc_y] = {};
- for (int ix = 0; ix < nrc_x; ix += 8) {
- const block_iq4_xs_r8 * iq4 = (const block_iq4_xs_r8 *)((const char *)vx + ix*bx);
- for (int ibl = 0; ibl < nbl; ++ibl) {
- auto d4_f16 = vld1q_f16((const float16_t *)iq4[ibl].d);
- auto d4l = vcvt_f32_f16(vget_low_f16 (d4_f16));
- auto d4h = vcvt_f32_f16(vget_high_f16(d4_f16));
- auto sl = vld1q_u8_x2(iq4[ibl].scales_l);
- auto sh = vld1q_u8(iq4[ibl].scales_h);
- iscales.val[0] = vaddq_s8(vorrq_u8(vandq_u8(sl.val[0], m4), vandq_u8(vshlq_n_u8(sh, 4), m3)), m32);
- iscales.val[1] = vaddq_s8(vorrq_u8(vandq_u8(sl.val[1], m4), vandq_u8(vshlq_n_u8(sh, 2), m3)), m32);
- iscales.val[2] = vaddq_s8(vorrq_u8(vshrq_n_u8(sl.val[0], 4), vandq_u8(sh, m3)), m32);
- iscales.val[3] = vaddq_s8(vorrq_u8(vshrq_n_u8(sl.val[1], 4), vandq_u8(vshrq_n_u8(sh, 2), m3)), m32);
- int32x4_t isum[nrc_y] = {};
- for (int ib64 = 0; ib64 < QK_K/64; ++ib64) {
- auto iscales16_1 = vmovl_s8(vget_low_s8(iscales.val[ib64]));
- auto iscales16_2 = vmovl_s8(vget_high_s8(iscales.val[ib64]));
- scales.val[0] = vmovl_s16(vget_low_s16(iscales16_1));
- scales.val[1] = vmovl_s16(vget_low_s16(iscales16_2));
- for (int l = 0; l < 2; ++l) {
- uint8x16x2_t bits;
- bits.val[0] = vld1q_u8(iq4[ibl].qs + 256*ib64 + 128*l);
- bits.val[1] = vld1q_u8(iq4[ibl].qs + 256*ib64 + 128*l + 32);
- prepare_iq4_nl_quants_r8(values, m4, bits, qx+0);
- bits.val[0] = vld1q_u8(iq4[ibl].qs + 256*ib64 + 128*l + 64);
- bits.val[1] = vld1q_u8(iq4[ibl].qs + 256*ib64 + 128*l + 96);
- prepare_iq4_nl_quants_r8(values, m4, bits, qx+4);
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto y = vld1q_s8_x2(q8.y[iy][ibl].qs+64*ib64+32*l);
- auto sumi = vdupq_n_s32(0);
- sumi = vdotq_laneq_s32(sumi, qx[0], y.val[0], 0);
- sumi = vdotq_laneq_s32(sumi, qx[1], y.val[0], 1);
- sumi = vdotq_laneq_s32(sumi, qx[2], y.val[0], 2);
- sumi = vdotq_laneq_s32(sumi, qx[3], y.val[0], 3);
- sumi = vdotq_laneq_s32(sumi, qx[4], y.val[1], 0);
- sumi = vdotq_laneq_s32(sumi, qx[5], y.val[1], 1);
- sumi = vdotq_laneq_s32(sumi, qx[6], y.val[1], 2);
- sumi = vdotq_laneq_s32(sumi, qx[7], y.val[1], 3);
- isum[iy] = vmlaq_s32(isum[iy], sumi, scales.val[l]);
- }
- }
- }
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto d8 = vdupq_n_f32(q8.scale(iy, ibl));
- acc[2*iy+0] = vfmaq_f32(acc[2*iy+0], vmulq_f32(d4l, d8), vcvtq_f32_s32(isum[iy]));
- isum[iy] = vdupq_n_s32(0);
- }
- for (int ib64 = 0; ib64 < QK_K/64; ++ib64) {
- auto iscales16_1 = vmovl_s8(vget_low_s8(iscales.val[ib64]));
- auto iscales16_2 = vmovl_s8(vget_high_s8(iscales.val[ib64]));
- scales.val[0] = vmovl_s16(vget_high_s16(iscales16_1));
- scales.val[1] = vmovl_s16(vget_high_s16(iscales16_2));
- for (int l = 0; l < 2; ++l) {
- uint8x16x2_t bits;
- bits.val[0] = vld1q_u8(iq4[ibl].qs + 256*ib64 + 128*l + 16);
- bits.val[1] = vld1q_u8(iq4[ibl].qs + 256*ib64 + 128*l + 48);
- prepare_iq4_nl_quants_r8(values, m4, bits, qx+0);
- bits.val[0] = vld1q_u8(iq4[ibl].qs + 256*ib64 + 128*l + 80);
- bits.val[1] = vld1q_u8(iq4[ibl].qs + 256*ib64 + 128*l +112);
- prepare_iq4_nl_quants_r8(values, m4, bits, qx+4);
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto y = vld1q_s8_x2(q8.y[iy][ibl].qs+64*ib64+32*l);
- auto sumi = vdupq_n_s32(0);
- sumi = vdotq_laneq_s32(sumi, qx[0], y.val[0], 0);
- sumi = vdotq_laneq_s32(sumi, qx[1], y.val[0], 1);
- sumi = vdotq_laneq_s32(sumi, qx[2], y.val[0], 2);
- sumi = vdotq_laneq_s32(sumi, qx[3], y.val[0], 3);
- sumi = vdotq_laneq_s32(sumi, qx[4], y.val[1], 0);
- sumi = vdotq_laneq_s32(sumi, qx[5], y.val[1], 1);
- sumi = vdotq_laneq_s32(sumi, qx[6], y.val[1], 2);
- sumi = vdotq_laneq_s32(sumi, qx[7], y.val[1], 3);
- isum[iy] = vmlaq_s32(isum[iy], sumi, scales.val[l]);
- }
- }
- }
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto d8 = vdupq_n_f32(q8.scale(iy, ibl));
- acc[2*iy+1] = vfmaq_f32(acc[2*iy+1], vmulq_f32(d4h, d8), vcvtq_f32_s32(isum[iy]));
- }
- }
- for (int iy = 0; iy < nrc_y; ++iy) {
- info.store(ix+0, iy, acc[2*iy+0]);
- info.store(ix+4, iy, acc[2*iy+1]);
- acc[2*iy+0] = acc[2*iy+1] = vdupq_n_f32(0.f);
- }
- }
-}
-
-template <int nrc_y>
-void mul_mat_iq4_ks_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
- GGML_ASSERT(nrc_x%4 == 0);
- Q8<nrc_y, block_q8_K> q8(info);
- auto m4 = vdupq_n_u8(0xf);
- auto values = vld1q_s8(iq4k_values);
- int nbl = n / QK_K;
- int8x16_t qx[8];
- int16x8x4_t iscales;
- int32x4x4_t scales;
- float32x4_t acc[nrc_y] = {};
- int32x4_t isum[nrc_y] = {};
- for (int ix = 0; ix < nrc_x; ix += 4) {
- auto dptr = (const float *)((const char *)vx + ix*bx);
- auto d4 = vld1q_f32(dptr);
- const block_iq4_ks_r4 * iq4 = (const block_iq4_ks_r4 *)(dptr + 4);
- for (int ibl = 0; ibl < nbl; ++ibl) {
- auto sas = vld1q_u8_x2(iq4[ibl].scales);
- auto scale = vandq_u8(sas.val[0], vdupq_n_u8(254));
- iscales.val[0] = vaddq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_low_u8 (scale))), vdupq_n_s16(-127));
- iscales.val[1] = vaddq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(scale))), vdupq_n_s16(-127));
- scale = vandq_u8(sas.val[1], vdupq_n_u8(254));
- iscales.val[2] = vaddq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_low_u8 (scale))), vdupq_n_s16(-127));
- iscales.val[3] = vaddq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(scale))), vdupq_n_s16(-127));
-            // Adding the block shifts costs us a ~9% drop in performance.
-            // Is there a better way? (A scalar sketch of the folding idea follows this function.)
- sas.val[0] = vshlq_n_u8(vandq_u8(sas.val[0], vdupq_n_u8(1)), 2);
- sas.val[1] = vshlq_n_u8(vandq_u8(sas.val[1], vdupq_n_u8(1)), 2);
- {
- auto s16_1 = vmulq_s16(iscales.val[0], vmovl_u8(vget_low_u8 (sas.val[0])));
- auto s16_2 = vmulq_s16(iscales.val[1], vmovl_u8(vget_high_u8(sas.val[0])));
- auto s16_3 = vmulq_s16(iscales.val[2], vmovl_u8(vget_low_u8 (sas.val[1])));
- auto s16_4 = vmulq_s16(iscales.val[3], vmovl_u8(vget_high_u8(sas.val[1])));
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto bsums = vld1q_s16_x2(q8.y[iy][ibl].bsums);
- auto bs = vpaddq_s16(bsums.val[0], bsums.val[1]);
- auto b8 = vget_low_s16(bs);
- isum[iy] = vmlal_lane_s16(isum[iy], vget_low_s16 (s16_1), b8, 0);
- isum[iy] = vmlal_lane_s16(isum[iy], vget_high_s16(s16_1), b8, 1);
- isum[iy] = vmlal_lane_s16(isum[iy], vget_low_s16 (s16_2), b8, 2);
- isum[iy] = vmlal_lane_s16(isum[iy], vget_high_s16(s16_2), b8, 3);
- b8 = vget_high_s16(bs);
- isum[iy] = vmlal_lane_s16(isum[iy], vget_low_s16 (s16_3), b8, 0);
- isum[iy] = vmlal_lane_s16(isum[iy], vget_high_s16(s16_3), b8, 1);
- isum[iy] = vmlal_lane_s16(isum[iy], vget_low_s16 (s16_4), b8, 2);
- isum[iy] = vmlal_lane_s16(isum[iy], vget_high_s16(s16_4), b8, 3);
- }
- }
- for (int is = 0; is < 2; ++is) {
- scales.val[0] = vmovl_s16(vget_low_s16 (iscales.val[2*is+0]));
- scales.val[1] = vmovl_s16(vget_high_s16(iscales.val[2*is+0]));
- scales.val[2] = vmovl_s16(vget_low_s16 (iscales.val[2*is+1]));
- scales.val[3] = vmovl_s16(vget_high_s16(iscales.val[2*is+1]));
- for (int ib = 0; ib < 4; ++ib) {
- auto bits = vld1q_u8_x4(iq4[ibl].qs + 256*is + 64*ib);
- prepare_iq4_nl_quants(values, m4, bits, qx);
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto y = vld1q_s8_x2(q8.y[iy][ibl].qs+128*is+32*ib);
- auto sumi = interleaved_dotq(qx, y);
- isum[iy] = vmlaq_s32(isum[iy], scales.val[ib], sumi);
- }
- }
- }
- for (int iy = 0; iy < nrc_y; ++iy) {
- acc[iy] = vfmaq_f32(acc[iy], vdupq_n_f32(q8.scale(iy, ibl)), vcvtq_f32_s32(isum[iy]));
- isum[iy] = vdupq_n_s32(0);
- }
- }
- for (int iy = 0; iy < nrc_y; ++iy) {
- info.store(ix, iy, vmulq_f32(d4, acc[iy]));
- acc[iy] = vdupq_n_f32(0.f);
- }
- }
-}
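The ~9% note in mul_mat_iq4_ks_r4_q8_k above refers to folding the optional per-block shift into the q8 block sums instead of adding it to every unpacked quant. Below is a minimal scalar sketch of that identity, assuming q[] already holds the looked-up iq4k values and ysum stands in for the bsums entry; the function name and signature are illustrative, not part of the kernel.

    #include <cstdint>

    // Sketch only: dot product of one block with the shift folded in via the block sum.
    // packed_scale: low bit = shift flag, remaining bits = scale biased by 127 (as decoded above).
    static int folded_block_dot(const int8_t *q, const int8_t *y, int n, uint8_t packed_scale) {
        int flag  = packed_scale & 1;           // 1 -> add 4 to every value of this block
        int scale = (packed_scale & 254) - 127;
        int dot = 0, ysum = 0;                  // ysum is what the kernel reads from the q8 bsums
        for (int j = 0; j < n; ++j) { dot += q[j] * y[j]; ysum += y[j]; }
        // sum_j (q_j + 4*flag) * y_j == dot + 4*flag*ysum, so the shift never touches q[]
        return scale * (dot + 4 * flag * ysum);
    }

The iq5_ks_r4 kernel below uses the same folding with a shift of 2, which is why its flag bits are shifted left by 1 instead of 2.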
-
-template <int nrc_y>
-void mul_mat_iq5_ks_r4_q8_k_neon(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
- GGML_ASSERT(nrc_x%4 == 0);
- Q8<nrc_y, block_q8_K> q8(info);
- auto m4 = vdupq_n_u8(0xf);
- auto m10 = vdupq_n_u8(0x10);
- auto values = vld1q_s8_x2(iq5nl_values);
- int nbl = n / QK_K;
- int8x16_t qx[8];
- int16x8x4_t iscales;
- int32x4x4_t scales;
- float32x4_t acc[nrc_y] = {};
- int32x4_t isum[nrc_y] = {};
- for (int ix = 0; ix < nrc_x; ix += 4) {
- auto dptr = (const float *)((const char *)vx + ix*bx);
- auto d4 = vld1q_f32(dptr);
- const block_iq5_ks_r4 * iq5 = (const block_iq5_ks_r4 *)(dptr + 4);
- for (int ibl = 0; ibl < nbl; ++ibl) {
- auto sas = vld1q_u8_x2(iq5[ibl].scales);
- auto scale = vandq_u8(sas.val[0], vdupq_n_u8(254));
- iscales.val[0] = vaddq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_low_u8 (scale))), vdupq_n_s16(-127));
- iscales.val[1] = vaddq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(scale))), vdupq_n_s16(-127));
- scale = vandq_u8(sas.val[1], vdupq_n_u8(254));
- iscales.val[2] = vaddq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_low_u8 (scale))), vdupq_n_s16(-127));
- iscales.val[3] = vaddq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(scale))), vdupq_n_s16(-127));
-            // Adding the block shifts costs us a ~9% drop in performance.
-            // Is there a better way? (This is the same folding as in iq4_ks_r4 above, just with a shift of 2 instead of 4.)
- sas.val[0] = vshlq_n_u8(vandq_u8(sas.val[0], vdupq_n_u8(1)), 1);
- sas.val[1] = vshlq_n_u8(vandq_u8(sas.val[1], vdupq_n_u8(1)), 1);
- {
- auto s16_1 = vmulq_s16(iscales.val[0], vmovl_u8(vget_low_u8 (sas.val[0])));
- auto s16_2 = vmulq_s16(iscales.val[1], vmovl_u8(vget_high_u8(sas.val[0])));
- auto s16_3 = vmulq_s16(iscales.val[2], vmovl_u8(vget_low_u8 (sas.val[1])));
- auto s16_4 = vmulq_s16(iscales.val[3], vmovl_u8(vget_high_u8(sas.val[1])));
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto bsums = vld1q_s16_x2(q8.y[iy][ibl].bsums);
- auto bs = vpaddq_s16(bsums.val[0], bsums.val[1]);
- auto b8 = vget_low_s16(bs);
- isum[iy] = vmlal_lane_s16(isum[iy], vget_low_s16 (s16_1), b8, 0);
- isum[iy] = vmlal_lane_s16(isum[iy], vget_high_s16(s16_1), b8, 1);
- isum[iy] = vmlal_lane_s16(isum[iy], vget_low_s16 (s16_2), b8, 2);
- isum[iy] = vmlal_lane_s16(isum[iy], vget_high_s16(s16_2), b8, 3);
- b8 = vget_high_s16(bs);
- isum[iy] = vmlal_lane_s16(isum[iy], vget_low_s16 (s16_3), b8, 0);
- isum[iy] = vmlal_lane_s16(isum[iy], vget_high_s16(s16_3), b8, 1);
- isum[iy] = vmlal_lane_s16(isum[iy], vget_low_s16 (s16_4), b8, 2);
- isum[iy] = vmlal_lane_s16(isum[iy], vget_high_s16(s16_4), b8, 3);
- }
- }
- for (int is = 0; is < 2; ++is) {
- scales.val[0] = vmovl_s16(vget_low_s16 (iscales.val[2*is+0]));
- scales.val[1] = vmovl_s16(vget_high_s16(iscales.val[2*is+0]));
- scales.val[2] = vmovl_s16(vget_low_s16 (iscales.val[2*is+1]));
- scales.val[3] = vmovl_s16(vget_high_s16(iscales.val[2*is+1]));
- for (int ib = 0; ib < 4; ++ib) {
- auto lbits = vld1q_u8_x4(iq5[ibl].qs + 256*is + 64*ib);
- auto hbits = vld1q_u8(iq5[ibl].qh + 64*is + 16*ib);
- qx[0] = vorrq_u8(vandq_u8(lbits.val[0], m4), vandq_u8(m10, vshlq_n_u8(hbits, 4)));
- qx[1] = vorrq_u8(vandq_u8(lbits.val[1], m4), vandq_u8(m10, vshlq_n_u8(hbits, 2)));
- qx[2] = vorrq_u8(vandq_u8(lbits.val[2], m4), vandq_u8(m10, hbits));
- qx[3] = vorrq_u8(vandq_u8(lbits.val[3], m4), vandq_u8(m10, vshrq_n_u8(hbits, 2)));
- qx[4] = vorrq_u8(vshrq_n_u8(lbits.val[0], 4), vandq_u8(m10, vshlq_n_u8(hbits, 3)));
- qx[5] = vorrq_u8(vshrq_n_u8(lbits.val[1], 4), vandq_u8(m10, vshlq_n_u8(hbits, 1)));
- qx[6] = vorrq_u8(vshrq_n_u8(lbits.val[2], 4), vandq_u8(m10, vshrq_n_u8(hbits, 1)));
- qx[7] = vorrq_u8(vshrq_n_u8(lbits.val[3], 4), vandq_u8(m10, vshrq_n_u8(hbits, 3)));
- for (int l = 0; l < 8; ++l) qx[l] = vqtbl2q_s8(values, qx[l]);
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto y = vld1q_s8_x2(q8.y[iy][ibl].qs+128*is+32*ib);
- auto sumi = interleaved_dotq(qx, y);
- isum[iy] = vmlaq_s32(isum[iy], scales.val[ib], sumi);
- }
- }
- }
- for (int iy = 0; iy < nrc_y; ++iy) {
- acc[iy] = vfmaq_f32(acc[iy], vdupq_n_f32(q8.scale(iy, ibl)), vcvtq_f32_s32(isum[iy]));
- isum[iy] = vdupq_n_s32(0);
- }
- }
- for (int iy = 0; iy < nrc_y; ++iy) {
- info.store(ix, iy, vmulq_f32(d4, acc[iy]));
- acc[iy] = vdupq_n_f32(0.f);
- }
- }
-}
-
-template <int nrc_y>
-static void mul_mat_iq2_xxs_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
- GGML_ASSERT(nrc_x%4 == 0);
- Q8<nrc_y, block_q8_K> q8(info);
- int nbl = n / QK_K;
- float32x4_t acc[nrc_y] = {};
- int32x4_t isum[nrc_y] = {};
- int8x16_t qx[8];
- SignHelper sh;
- for (int ix = 0; ix < nrc_x; ix += 4) {
- auto iq2 = (const block_iq2_xxs_r4 *)((const char *)vx + (ix+0)*bx);
- for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256
- auto d4 = vcvt_f32_f16(vld1_f16((const float16_t *)iq2[ibl].d));
- auto qs = iq2[ibl].qs;
- for (int ib = 0; ib < QK_K/32; ++ib) {
- auto sas = vld1q_u8(iq2[ibl].sas + 16*ib);
- auto scale_bits = vandq_u8(sas, vdupq_n_u8(1));
- auto scales = ggml_vdotq_s32(vdupq_n_s32(1), scale_bits, vreinterpretq_s8_u32(vdupq_n_u32(0x10080402)));
- auto signs128 = vandq_u8(sas, vdupq_n_u8(254));
- signs128 = veorq_u8(signs128, vshrq_n_u8(signs128, 1));
- sh.init();
- for (int i = 0; i < 8; ++i) {
- qx[i] = vreinterpretq_s8_u64(uint64x2_t{iq2xxs_grid[qs[2*i+0]], iq2xxs_grid[qs[2*i+1]]});
- sh.apply_signs_1((uint8x16_t *)qx+i, signs128);
- }
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto y = vld1q_s8_x2(q8.y[iy][ibl].qs + 32*ib);
- auto sumi1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[0], y.val[0]), qx[1], y.val[1]);
- auto sumi2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[2], y.val[0]), qx[3], y.val[1]);
- auto sumi3 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[4], y.val[0]), qx[5], y.val[1]);
- auto sumi4 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[6], y.val[0]), qx[7], y.val[1]);
- auto sumi12 = vpaddq_s32(sumi1, sumi2);
- auto sumi34 = vpaddq_s32(sumi3, sumi4);
- auto sumi = vpaddq_s32(sumi12, sumi34);
- isum[iy] = vmlaq_s32(isum[iy], scales, sumi);
- }
- qs += 16;
- }
- for (int iy = 0; iy < nrc_y; ++iy) {
- acc[iy] = vfmaq_f32(acc[iy], vmulq_f32(d4, vdupq_n_f32(q8.scale(iy, ibl))), vcvtq_f32_s32(isum[iy]));
- isum[iy] = vdupq_n_s32(0);
- }
- }
- for (int iy = 0; iy < nrc_y; ++iy) {
- info.store(ix, iy, vmulq_f32(vdupq_n_f32(0.125f), acc[iy]));
- acc[iy] = vdupq_n_f32(0.f);
- }
- }
-}
-
-template <int nrc_y>
-static void mul_mat_iq2_xs_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
- GGML_ASSERT(nrc_x%4 == 0);
- Q8<nrc_y, block_q8_K> q8(info);
- int nbl = n / QK_K;
- static const uint8_t k_shuff[16] = {1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29, 31};
- auto shuff = vld1q_u8(k_shuff);
- float32x4_t acc[nrc_y] = {};
- int32x4_t isum[2*nrc_y] = {};
- int8x16_t qx[8];
- uint16x8x4_t scales16;
- SignHelper sh;
- for (int ix = 0; ix < nrc_x; ix += 4) {
- auto iq2 = (const block_iq2_xs_r4 *)((const char *)vx + (ix+0)*bx);
- for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256
- auto d4 = vcvt_f32_f16(vld1_f16((const float16_t *)iq2[ibl].d));
- auto qs = iq2[ibl].qs;
- for (int is = 0; is < 2; ++is) {
- auto scale_bits = vld1q_u8(iq2[ibl].scales + 16*is);
- auto scales1 = vandq_u8(scale_bits, vdupq_n_u8(0xf));
- auto scales2 = vshrq_n_u8(scale_bits, 4);
- scales1 = vorrq_u8(vshlq_n_u8(scales1, 1), vdupq_n_u8(1));
- scales2 = vorrq_u8(vshlq_n_u8(scales2, 1), vdupq_n_u8(1));
- auto s1 = vzip1q_u8(scales1, scales2);
- auto s2 = vzip2q_u8(scales1, scales2);
- scales16.val[0] = vmovl_u8(vget_low_u8 (s1));
- scales16.val[1] = vmovl_u8(vget_high_u8(s1));
- scales16.val[2] = vmovl_u8(vget_low_u8 (s2));
- scales16.val[3] = vmovl_u8(vget_high_u8(s2));
- for (int ib = 0; ib < QK_K/64; ++ib) {
- auto v = vld1q_u8_x2((const uint8_t *)qs);
- auto signs128 = vandq_u8(vqtbl2q_u8(v, shuff), vdupq_n_u8(254));
- signs128 = veorq_u8(signs128, vshrq_n_u8(signs128, 1));
- sh.init();
- for (int i = 0; i < 8; ++i) {
- qx[i] = vreinterpretq_s8_u64(uint64x2_t{iq2xs_grid[qs[2*i+0] & 511], iq2xs_grid[qs[2*i+1] & 511]});
- sh.apply_signs_1((uint8x16_t *)qx+i, signs128);
- }
- auto s32_1 = vmovl_u16(vget_low_u16 (scales16.val[ib]));
- auto s32_2 = vmovl_u16(vget_high_u16(scales16.val[ib]));
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto y = vld1q_s8_x2(q8.y[iy][ibl].qs + 128*is + 32*ib);
- auto sumi1 = vpaddq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[0], y.val[0]), ggml_vdotq_s32(vdupq_n_s32(0), qx[1], y.val[1]));
- auto sumi2 = vpaddq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[2], y.val[0]), ggml_vdotq_s32(vdupq_n_s32(0), qx[3], y.val[1]));
- auto sumi3 = vpaddq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[4], y.val[0]), ggml_vdotq_s32(vdupq_n_s32(0), qx[5], y.val[1]));
- auto sumi4 = vpaddq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[6], y.val[0]), ggml_vdotq_s32(vdupq_n_s32(0), qx[7], y.val[1]));
- auto sumi12 = vpaddq_s32(sumi1, sumi2); // blocks 0,1,2,3 in rows 0,1
- auto sumi34 = vpaddq_s32(sumi3, sumi4); // blocks 4,5,6,7 in rows 2,3
- isum[2*iy+0] = vmlaq_s32(isum[2*iy+0], s32_1, sumi12);
- isum[2*iy+1] = vmlaq_s32(isum[2*iy+1], s32_2, sumi34);
- }
- qs += 16;
- }
- }
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto sumi = vpaddq_s32(isum[2*iy+0], isum[2*iy+1]);
- acc[iy] = vfmaq_f32(acc[iy], vmulq_f32(d4, vdupq_n_f32(q8.scale(iy, ibl))), vcvtq_f32_s32(sumi));
- isum[2*iy] = isum[2*iy+1] = vdupq_n_s32(0);
- }
- }
- for (int iy = 0; iy < nrc_y; ++iy) {
- info.store(ix, iy, vmulq_f32(vdupq_n_f32(0.125f), acc[iy]));
- acc[iy] = vdupq_n_f32(0.f);
- }
- }
-}
-
-static void mul_mat_iq1_s_r4_q8_1_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
- GGML_ASSERT(nrc_x%4 == 0);
- Q8<1, block_q8_K128> q8(info);
- int nb = n / 32;
- GGML_ASSERT(nb%4 == 0);
- int8x16_t qx[8];
- float32x4_t acc[2] = {};
- int32x4_t isum[8];
- auto ms = vdup_n_u16(0x8000);
-    for (int ix = 0; ix < nrc_x; ix += 4) {
- auto dptr = (const ggml_half *)((const char *)vx + ix*bx);
- auto d1 = vcvt_f32_f16(vld1_f16((const float16_t *)dptr));
- auto x = (const block_iq1_s_r4 *)(dptr + 4);
- for (int ib = 0; ib < nb/4; ++ib) {
- auto scale_yd = vdupq_n_f32(q8.y[0][ib].d);
- auto scale_ym = vmulq_f32(scale_yd, vcvtq_f32_s32(vmovl_s16(vld1_s16(q8.y[0][ib].bsums))));
- for (int k = 0; k < 4; ++k) {
- auto sas = vld1_u16(x[4*ib+k].qh);
- auto scales4 = vand_u16(vshr_n_u16(sas, 12), vdup_n_u16(7));
- scales4 = vorr_u16(vshl_n_u16(scales4, 1), vdup_n_u16(1));
- auto signs = vreinterpret_s16_u16(vorr_u16(vceq_u16(vand_u16(sas, ms), ms), vdup_n_u16(1)));
- isum[k+4] = vmull_s16(signs, scales4);
- qx[0] = vreinterpretq_s8_u64(uint64x2_t{iq1s_grid[x[4*ib+k].qs[ 0] | ((x[4*ib+k].qh[0] << 8) & 0x0700)],
- iq1s_grid[x[4*ib+k].qs[ 4] | ((x[4*ib+k].qh[0] << 5) & 0x0700)]});
- qx[1] = vreinterpretq_s8_u64(uint64x2_t{iq1s_grid[x[4*ib+k].qs[ 8] | ((x[4*ib+k].qh[0] << 2) & 0x0700)],
- iq1s_grid[x[4*ib+k].qs[12] | ((x[4*ib+k].qh[0] >> 1) & 0x0700)]});
- qx[2] = vreinterpretq_s8_u64(uint64x2_t{iq1s_grid[x[4*ib+k].qs[ 1] | ((x[4*ib+k].qh[1] << 8) & 0x0700)],
- iq1s_grid[x[4*ib+k].qs[ 5] | ((x[4*ib+k].qh[1] << 5) & 0x0700)]});
- qx[3] = vreinterpretq_s8_u64(uint64x2_t{iq1s_grid[x[4*ib+k].qs[ 9] | ((x[4*ib+k].qh[1] << 2) & 0x0700)],
- iq1s_grid[x[4*ib+k].qs[13] | ((x[4*ib+k].qh[1] >> 1) & 0x0700)]});
- qx[4] = vreinterpretq_s8_u64(uint64x2_t{iq1s_grid[x[4*ib+k].qs[ 2] | ((x[4*ib+k].qh[2] << 8) & 0x0700)],
- iq1s_grid[x[4*ib+k].qs[ 6] | ((x[4*ib+k].qh[2] << 5) & 0x0700)]});
- qx[5] = vreinterpretq_s8_u64(uint64x2_t{iq1s_grid[x[4*ib+k].qs[10] | ((x[4*ib+k].qh[2] << 2) & 0x0700)],
- iq1s_grid[x[4*ib+k].qs[14] | ((x[4*ib+k].qh[2] >> 1) & 0x0700)]});
- qx[6] = vreinterpretq_s8_u64(uint64x2_t{iq1s_grid[x[4*ib+k].qs[ 3] | ((x[4*ib+k].qh[3] << 8) & 0x0700)],
- iq1s_grid[x[4*ib+k].qs[ 7] | ((x[4*ib+k].qh[3] << 5) & 0x0700)]});
- qx[7] = vreinterpretq_s8_u64(uint64x2_t{iq1s_grid[x[4*ib+k].qs[11] | ((x[4*ib+k].qh[3] << 2) & 0x0700)],
- iq1s_grid[x[4*ib+k].qs[15] | ((x[4*ib+k].qh[3] >> 1) & 0x0700)]});
- auto scales = vmovl_u16(scales4);
- auto y = vld1q_s8_x2(q8.y[0][ib].qs + 32*k);
- auto sumi1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[0], y.val[0]), qx[1], y.val[1]);
- auto sumi2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[2], y.val[0]), qx[3], y.val[1]);
- auto sumi3 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[4], y.val[0]), qx[5], y.val[1]);
- auto sumi4 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[6], y.val[0]), qx[7], y.val[1]);
- sumi1 = vpaddq_s32(sumi1, sumi2);
- sumi3 = vpaddq_s32(sumi3, sumi4);
- isum[k] = vmulq_s32(scales, vpaddq_s32(sumi1, sumi3));
- }
- acc[0] = vfmaq_laneq_f32(acc[0], vcvtq_f32_s32(isum[0]), scale_yd, 0);
- acc[0] = vfmaq_laneq_f32(acc[0], vcvtq_f32_s32(isum[1]), scale_yd, 1);
- acc[0] = vfmaq_laneq_f32(acc[0], vcvtq_f32_s32(isum[2]), scale_yd, 2);
- acc[0] = vfmaq_laneq_f32(acc[0], vcvtq_f32_s32(isum[3]), scale_yd, 3);
- acc[1] = vfmaq_laneq_f32(acc[1], vcvtq_f32_s32(isum[4]), scale_ym, 0);
- acc[1] = vfmaq_laneq_f32(acc[1], vcvtq_f32_s32(isum[5]), scale_ym, 1);
- acc[1] = vfmaq_laneq_f32(acc[1], vcvtq_f32_s32(isum[6]), scale_ym, 2);
- acc[1] = vfmaq_laneq_f32(acc[1], vcvtq_f32_s32(isum[7]), scale_ym, 3);
- }
- info.store(ix, 0, vmulq_f32(d1, vfmaq_f32(acc[0], acc[1], vdupq_n_f32(IQ1S_DELTA))));
- acc[0] = acc[1] = vdupq_n_f32(0.f);
- }
-}
-
-template <int nrc_y>
-static void mul_mat_iq1_s_r4_q8_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
- GGML_ASSERT(nrc_x%4 == 0);
- Q8<nrc_y, block_q8_K128> q8(info);
- int nb = n / 32;
- GGML_ASSERT(nb%4 == 0);
- uint8x16_t qx[8];
- float32x4_t acc[nrc_y] = {};
- auto ms = vdup_n_u16(0x8000);
- auto mask = vdupq_n_s8(0x03);
- float d8[4*nrc_y];
-    for (int ix = 0; ix < nrc_x; ix += 4) {
- auto dptr = (const ggml_half *)((const char *)vx + ix*bx);
- auto d1 = vcvt_f32_f16(vld1_f16((const float16_t *)dptr));
- auto x = (const block_iq1_s_r4 *)(dptr + 4);
- for (int ib = 0; ib < nb/4; ++ib) {
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto scales = vcvtq_f32_s32(vmovl_s16(vld1_s16(q8.y[iy][ib].bsums)));
- vst1q_f32(d8+4*iy, vmulq_f32(vdupq_n_f32(q8.y[iy][ib].d), scales));
- }
- for (int k = 0; k < 4; ++k) {
- auto sas = vld1_u16(x[4*ib+k].qh);
- auto scales4 = vand_u16(vshr_n_u16(sas, 12), vdup_n_u16(7));
- scales4 = vorr_u16(vshl_n_u16(scales4, 1), vdup_n_u16(1));
- auto signs = vreinterpret_s16_u16(vorr_u16(vceq_u16(vand_u16(sas, ms), ms), vdup_n_u16(1)));
- signs = vadd_s16(vdup_n_s16(-8), signs);
- auto delta4 = vmulq_f32(vdupq_n_f32(0.125f), vcvtq_f32_s32(vmull_s16(signs, scales4)));
- qx[0] = vreinterpretq_u8_u32(uint32x4_t{iq1s_grid_us[x[4*ib+k].qs[ 0] | ((x[4*ib+k].qh[0] << 8) & 0x0700)],
- iq1s_grid_us[x[4*ib+k].qs[ 1] | ((x[4*ib+k].qh[1] << 8) & 0x0700)],
- iq1s_grid_us[x[4*ib+k].qs[ 2] | ((x[4*ib+k].qh[2] << 8) & 0x0700)],
- iq1s_grid_us[x[4*ib+k].qs[ 3] | ((x[4*ib+k].qh[3] << 8) & 0x0700)]});
- qx[2] = vreinterpretq_u8_u32(uint32x4_t{iq1s_grid_us[x[4*ib+k].qs[ 4] | ((x[4*ib+k].qh[0] << 5) & 0x0700)],
- iq1s_grid_us[x[4*ib+k].qs[ 5] | ((x[4*ib+k].qh[1] << 5) & 0x0700)],
- iq1s_grid_us[x[4*ib+k].qs[ 6] | ((x[4*ib+k].qh[2] << 5) & 0x0700)],
- iq1s_grid_us[x[4*ib+k].qs[ 7] | ((x[4*ib+k].qh[3] << 5) & 0x0700)]});
- qx[4] = vreinterpretq_u8_u32(uint32x4_t{iq1s_grid_us[x[4*ib+k].qs[ 8] | ((x[4*ib+k].qh[0] << 2) & 0x0700)],
- iq1s_grid_us[x[4*ib+k].qs[ 9] | ((x[4*ib+k].qh[1] << 2) & 0x0700)],
- iq1s_grid_us[x[4*ib+k].qs[10] | ((x[4*ib+k].qh[2] << 2) & 0x0700)],
- iq1s_grid_us[x[4*ib+k].qs[11] | ((x[4*ib+k].qh[3] << 2) & 0x0700)]});
- qx[6] = vreinterpretq_u8_u32(uint32x4_t{iq1s_grid_us[x[4*ib+k].qs[12] | ((x[4*ib+k].qh[0] >> 1) & 0x0700)],
- iq1s_grid_us[x[4*ib+k].qs[13] | ((x[4*ib+k].qh[1] >> 1) & 0x0700)],
- iq1s_grid_us[x[4*ib+k].qs[14] | ((x[4*ib+k].qh[2] >> 1) & 0x0700)],
- iq1s_grid_us[x[4*ib+k].qs[15] | ((x[4*ib+k].qh[3] >> 1) & 0x0700)]});
- qx[1] = vandq_u8(vshrq_n_u8(qx[0], 4), mask); qx[0] = vandq_u8(qx[0], mask);
- qx[3] = vandq_u8(vshrq_n_u8(qx[2], 4), mask); qx[2] = vandq_u8(qx[2], mask);
- qx[5] = vandq_u8(vshrq_n_u8(qx[4], 4), mask); qx[4] = vandq_u8(qx[4], mask);
- qx[7] = vandq_u8(vshrq_n_u8(qx[6], 4), mask); qx[6] = vandq_u8(qx[6], mask);
- auto scales = vmovl_u16(scales4);
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto y = vld1q_s8_x2(q8.y[iy][ib].qs + 32*k);
- auto sumi = vdupq_n_s32(0);
- sumi = vdotq_laneq_s32(sumi, vreinterpretq_s8_u8(qx[0]), y.val[0], 0);
- sumi = vdotq_laneq_s32(sumi, vreinterpretq_s8_u8(qx[1]), y.val[0], 1);
- sumi = vdotq_laneq_s32(sumi, vreinterpretq_s8_u8(qx[2]), y.val[0], 2);
- sumi = vdotq_laneq_s32(sumi, vreinterpretq_s8_u8(qx[3]), y.val[0], 3);
- sumi = vdotq_laneq_s32(sumi, vreinterpretq_s8_u8(qx[4]), y.val[1], 0);
- sumi = vdotq_laneq_s32(sumi, vreinterpretq_s8_u8(qx[5]), y.val[1], 1);
- sumi = vdotq_laneq_s32(sumi, vreinterpretq_s8_u8(qx[6]), y.val[1], 2);
- sumi = vdotq_laneq_s32(sumi, vreinterpretq_s8_u8(qx[7]), y.val[1], 3);
- sumi = vmulq_s32(scales, sumi);
- acc[iy] = vfmaq_f32(acc[iy], vdupq_n_f32(q8.y[iy][ib].d), vcvtq_f32_s32(sumi));
- acc[iy] = vfmaq_f32(acc[iy], vdupq_n_f32(d8[4*iy+k]), delta4);
- }
- }
- }
- for (int iy = 0; iy < nrc_y; ++iy) {
- info.store(ix, iy, vmulq_f32(d1, acc[iy]));
- acc[iy] = vdupq_n_f32(0.f);
- }
- }
-}
-
-template <int nrc_y>
-static void mul_mat_iq1_s_q8_K(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
- GGML_ASSERT(n%QK_K == 0);
- Q8<nrc_y, block_q8_K> q8(info);
- int8x16_t qx[16];
- int32x4_t scales[2];
- int16x4_t deltas[2];
- float32x4_t acc[nrc_y] = {};
- auto delta_mask = vdupq_n_u16(0x8000);
- for (int ix = 0; ix < nrc_x; ++ix) {
- auto iq1s = (const block_iq1_s *)((const char *)vx + ix*bx);
- for (int ibl = 0; ibl < n/QK_K; ++ibl) {
- float d = GGML_FP16_TO_FP32(iq1s[ibl].d);
- auto qhb = vld1q_u16(iq1s[ibl].qh);
- auto scales128 = vandq_u16(vshrq_n_u16(qhb, 12), vdupq_n_u16(7));
- scales128 = vaddq_u16(vshlq_n_u16(scales128, 1), vdupq_n_u16(1));
- auto mask = vceqq_u16(vandq_u16(qhb, delta_mask), delta_mask);
-            // Note: we explicitly assume IQ1S_DELTA = 0.125 (see the sketch after this function).
- auto deltas128 = vsubq_s16(vbicq_s16(scales128, mask), vandq_s16(scales128, mask));
- //auto deltas128 = vorrq_s16(vandq_s16(vdupq_n_s16(-1), mask), vbicq_s16(vdupq_n_s16(1), mask));
- //deltas128 = vmulq_s16(scales128, deltas128);
- scales128 = vshlq_n_u16(scales128, 3);
- auto qs = iq1s[ibl].qs;
- auto qh = iq1s[ibl].qh;
- for (int ib64 = 0; ib64 < QK_K/64; ++ib64) {
- qx[4*ib64+0] = vreinterpretq_s8_u64(uint64x2_t{iq1s_grid[qs[0] | ((qh[2*ib64+0] << 8) & 0x700)], iq1s_grid[qs[1] | ((qh[2*ib64+0] << 5) & 0x700)]});
- qx[4*ib64+1] = vreinterpretq_s8_u64(uint64x2_t{iq1s_grid[qs[2] | ((qh[2*ib64+0] << 2) & 0x700)], iq1s_grid[qs[3] | ((qh[2*ib64+0] >> 1) & 0x700)]});
- qx[4*ib64+2] = vreinterpretq_s8_u64(uint64x2_t{iq1s_grid[qs[4] | ((qh[2*ib64+1] << 8) & 0x700)], iq1s_grid[qs[5] | ((qh[2*ib64+1] << 5) & 0x700)]});
- qx[4*ib64+3] = vreinterpretq_s8_u64(uint64x2_t{iq1s_grid[qs[6] | ((qh[2*ib64+1] << 2) & 0x700)], iq1s_grid[qs[7] | ((qh[2*ib64+1] >> 1) & 0x700)]});
- qs += 8;
- }
- scales[0] = vreinterpretq_s32_u32(vmovl_u16(vget_low_u16 (scales128)));
- scales[1] = vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(scales128)));
- deltas[0] = vget_low_s16 (deltas128);
- deltas[1] = vget_high_s16(deltas128);
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto bsums = q8.load_bsums8(iy, ibl);
- auto sumi = vdupq_n_s32(0);
- sumi = vmlal_s16(sumi, deltas[0], vget_low_s16 (bsums));
- sumi = vmlal_s16(sumi, deltas[1], vget_high_s16(bsums));
- for (int k = 0; k < QK_K/128; ++k) {
- auto qy = q8.load_quants_64(iy, ibl, 2*k+0);
- auto dot1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[8*k+0], qy.val[0]), qx[8*k+1], qy.val[1]);
- auto dot2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[8*k+2], qy.val[2]), qx[8*k+3], qy.val[3]);
- auto dot12 = vpaddq_s32(dot1, dot2);
- qy = q8.load_quants_64(iy, ibl, 2*k+1);
- auto dot3 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[8*k+4], qy.val[0]), qx[8*k+5], qy.val[1]);
- auto dot4 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[8*k+6], qy.val[2]), qx[8*k+7], qy.val[3]);
- auto dot34 = vpaddq_s32(dot3, dot4);
- auto dot = vpaddq_s32(dot12, dot34);
- sumi = vmlaq_s32(sumi, dot, scales[k]);
- }
- acc[iy] = vfmaq_f32(acc[iy], vdupq_n_f32(d*q8.scale(iy, ibl)), vcvtq_f32_s32(sumi));
- }
- }
- for (int iy = 0; iy < nrc_y; ++iy) {
- info.store(ix, iy, 0.125f*vaddvq_f32(acc[iy]));
- acc[iy] = vdupq_n_f32(0);
- }
- }
-}
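The IQ1S_DELTA note above works because, with a delta of 0.125, pre-shifting the block scales left by 3 for the grid dot products and applying a single final 0.125 factor reproduces the ±delta contribution exactly from the integer ±scale accumulated against the block sums. A standalone check of that identity (a sketch, not kernel code; the sign argument plays the role of the qh high bit):

    #include <cassert>
    #include <cmath>

    static void check_iq1s_delta_folding(int scale, int dot, int ysum, int sign /* +1 or -1 */) {
        const float delta = 0.125f;   // IQ1S_DELTA
        float exact  = scale * ((float)dot + sign * delta * (float)ysum);
        float folded = 0.125f * (float)((scale << 3) * dot + sign * scale * ysum);
        assert(std::fabs(exact - folded) < 1e-4f);   // identical up to float rounding
    }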
-
-template <int nrc_y>
-static void mul_mat_iq1_m_r4_q8_0(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
- GGML_ASSERT(nrc_x%4 == 0);
- Q8<nrc_y, block_q8_K128> q8(info);
- int nb = n / 32;
- GGML_ASSERT(nb%4 == 0);
- int8x16_t qx[8];
- float32x4_t acc[nrc_y] = {};
- int32x4_t isum[nrc_y] = {};
- auto shuffle0 = uint32x4_t{0x00000000, 0x01010101, 0x02020202, 0x03030303};
- auto step = vdupq_n_u8(4);
- auto ms = vdupq_n_u8(0x08);
- auto mask = vdupq_n_s8(0x18);
-    for (int ix = 0; ix < nrc_x; ix += 4) {
- auto dptr = (const ggml_half *)((const char *)vx + ix*bx);
- auto d1 = vmulq_f32(vdupq_n_f32(0.125f), vcvt_f32_f16(vld1_f16((const float16_t *)dptr)));
- auto x = (const block_iq1_m_r4 *)(dptr + 4);
- for (int ib = 0; ib < nb/4; ++ib) {
- for (int k = 0; k < 4; ++k) {
- auto scales4 = vdup_n_u32(((const uint32_t *)x[4*ib+k].scales)[0]);
- scales4 = vand_u8(vshl_u32(scales4, int32x2_t{0, -4}), vdup_n_u8(0xf));
- auto scales16 = vmovl_u8(scales4);
- auto scales1 = vmovl_u16(vget_low_u16(scales16));
- auto scales2 = vmovl_u16(vget_high_u16(scales16));
- auto qh = (const uint32_t *)x[4*ib+k].qh;
- auto idxh = uint32x4_t{qh[0], qh[0] >> 4, qh[1], qh[1] >> 4};
- auto signs = vreinterpretq_s8_u8(vorrq_u8(vceqq_u8(vandq_u8(idxh, ms), ms), vdupq_n_u8(1)));
- signs = vaddq_s8(signs, vdupq_n_s8(-8));
- qx[0] = vreinterpretq_s8_u32(uint32x4_t{iq1s_grid_us[x[4*ib+k].qs[ 0] | ((x[4*ib+k].qh[0] << 8) & 0x0700)],
- iq1s_grid_us[x[4*ib+k].qs[ 1] | ((x[4*ib+k].qh[1] << 8) & 0x0700)],
- iq1s_grid_us[x[4*ib+k].qs[ 2] | ((x[4*ib+k].qh[2] << 8) & 0x0700)],
- iq1s_grid_us[x[4*ib+k].qs[ 3] | ((x[4*ib+k].qh[3] << 8) & 0x0700)]});
- qx[2] = vreinterpretq_s8_u32(uint32x4_t{iq1s_grid_us[x[4*ib+k].qs[ 4] | ((x[4*ib+k].qh[0] << 4) & 0x0700)],
- iq1s_grid_us[x[4*ib+k].qs[ 5] | ((x[4*ib+k].qh[1] << 4) & 0x0700)],
- iq1s_grid_us[x[4*ib+k].qs[ 6] | ((x[4*ib+k].qh[2] << 4) & 0x0700)],
- iq1s_grid_us[x[4*ib+k].qs[ 7] | ((x[4*ib+k].qh[3] << 4) & 0x0700)]});
- qx[4] = vreinterpretq_s8_u32(uint32x4_t{iq1s_grid_us[x[4*ib+k].qs[ 8] | ((x[4*ib+k].qh[4] << 8) & 0x0700)],
- iq1s_grid_us[x[4*ib+k].qs[ 9] | ((x[4*ib+k].qh[5] << 8) & 0x0700)],
- iq1s_grid_us[x[4*ib+k].qs[10] | ((x[4*ib+k].qh[6] << 8) & 0x0700)],
- iq1s_grid_us[x[4*ib+k].qs[11] | ((x[4*ib+k].qh[7] << 8) & 0x0700)]});
- qx[6] = vreinterpretq_s8_u32(uint32x4_t{iq1s_grid_us[x[4*ib+k].qs[12] | ((x[4*ib+k].qh[4] << 4) & 0x0700)],
- iq1s_grid_us[x[4*ib+k].qs[13] | ((x[4*ib+k].qh[5] << 4) & 0x0700)],
- iq1s_grid_us[x[4*ib+k].qs[14] | ((x[4*ib+k].qh[6] << 4) & 0x0700)],
- iq1s_grid_us[x[4*ib+k].qs[15] | ((x[4*ib+k].qh[7] << 4) & 0x0700)]});
- auto shuffle = shuffle0;
- for (int j = 0; j < 4; ++j) {
- auto s = vqtbl1q_s8(signs, shuffle);
- qx[2*j+1] = vaddq_s8(s, vandq_s8(vshrq_n_s8(qx[2*j+0], 1), mask));
- qx[2*j+0] = vaddq_s8(s, vandq_s8(vshlq_n_s8(qx[2*j+0], 3), mask));
- shuffle = vaddq_u8(shuffle, step);
- }
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto y = vld1q_s8_x2(q8.y[iy][ib].qs + 32*k);
- auto sumi1 = vdupq_n_s32(0);
- auto sumi2 = vdupq_n_s32(0);
- sumi1 = vdotq_laneq_s32(sumi1, vreinterpretq_s8_u8(qx[0]), y.val[0], 0);
- sumi1 = vdotq_laneq_s32(sumi1, vreinterpretq_s8_u8(qx[1]), y.val[0], 1);
- sumi1 = vdotq_laneq_s32(sumi1, vreinterpretq_s8_u8(qx[2]), y.val[0], 2);
- sumi1 = vdotq_laneq_s32(sumi1, vreinterpretq_s8_u8(qx[3]), y.val[0], 3);
- sumi2 = vdotq_laneq_s32(sumi2, vreinterpretq_s8_u8(qx[4]), y.val[1], 0);
- sumi2 = vdotq_laneq_s32(sumi2, vreinterpretq_s8_u8(qx[5]), y.val[1], 1);
- sumi2 = vdotq_laneq_s32(sumi2, vreinterpretq_s8_u8(qx[6]), y.val[1], 2);
- sumi2 = vdotq_laneq_s32(sumi2, vreinterpretq_s8_u8(qx[7]), y.val[1], 3);
- isum[iy] = vmlaq_s32(vmlaq_s32(isum[iy], sumi1, scales1), sumi2, scales2);
- }
- }
- for (int iy = 0; iy < nrc_y; ++iy) {
- acc[iy] = vfmaq_f32(acc[iy], vdupq_n_f32(q8.y[iy][ib].d), vcvtq_f32_s32(isum[iy]));
- isum[iy] = vdupq_n_s32(0);
- }
- }
- for (int iy = 0; iy < nrc_y; ++iy) {
- info.store(ix, iy, vmulq_f32(d1, acc[iy]));
- acc[iy] = vdupq_n_f32(0.f);
- }
- }
-}
-
-template <int nrc_y>
-static void mul_mat_iq2_s_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
- GGML_ASSERT(nrc_x%4 == 0);
- Q8<nrc_y, block_q8_K> q8(info);
- int nbl = n / QK_K;
- float32x4_t acc[nrc_y] = {};
- int32x4_t isum[2*nrc_y] = {};
- int8x16_t qx[8];
- uint16x8x4_t scales16;
- SignHelper sh;
- for (int ix = 0; ix < nrc_x; ix += 4) {
- auto iq2 = (const block_iq2_s_r4 *)((const char *)vx + (ix+0)*bx);
- for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256
- auto d4 = vcvt_f32_f16(vld1_f16((const float16_t *)iq2[ibl].d));
- auto qs = iq2[ibl].qs;
- auto qh = iq2[ibl].qh;
- for (int is = 0; is < 2; ++is) {
- auto scale_bits = vld1q_u8(iq2[ibl].scales + 16*is);
- auto scales1 = vandq_u8(scale_bits, vdupq_n_u8(0xf));
- auto scales2 = vshrq_n_u8(scale_bits, 4);
- scales1 = vorrq_u8(vshlq_n_u8(scales1, 1), vdupq_n_u8(1));
- scales2 = vorrq_u8(vshlq_n_u8(scales2, 1), vdupq_n_u8(1));
- auto s1 = vzip1q_u8(scales1, scales2);
- auto s2 = vzip2q_u8(scales1, scales2);
- scales16.val[0] = vmovl_u8(vget_low_u8 (s1));
- scales16.val[1] = vmovl_u8(vget_high_u8(s1));
- scales16.val[2] = vmovl_u8(vget_low_u8 (s2));
- scales16.val[3] = vmovl_u8(vget_high_u8(s2));
- for (int ib = 0; ib < QK_K/64; ++ib) {
- auto signs128 = vld1q_u8(iq2[ibl].signs + 64*is + 16*ib);
- sh.init();
- for (int i = 0; i < 4; ++i) {
- qx[2*i+0] = vreinterpretq_s8_u64(uint64x2_t{iq2s_grid[qs[4*i+0] | ((qh[i] << 8) & 0x300)], iq2s_grid[qs[4*i+1] | ((qh[i] << 6) & 0x300)]});
- sh.apply_signs_1((uint8x16_t *)qx+2*i+0, signs128);
- qx[2*i+1] = vreinterpretq_s8_u64(uint64x2_t{iq2s_grid[qs[4*i+2] | ((qh[i] << 4) & 0x300)], iq2s_grid[qs[4*i+3] | ((qh[i] << 2) & 0x300)]});
- sh.apply_signs_1((uint8x16_t *)qx+2*i+1, signs128);
- }
- qs += 16; qh += 4;
- auto s32_1 = vmovl_u16(vget_low_u16 (scales16.val[ib]));
- auto s32_2 = vmovl_u16(vget_high_u16(scales16.val[ib]));
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto y = vld1q_s8_x2(q8.y[iy][ibl].qs + 128*is + 32*ib);
- auto sumi1 = vpaddq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[0], y.val[0]), ggml_vdotq_s32(vdupq_n_s32(0), qx[1], y.val[1]));
- auto sumi2 = vpaddq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[2], y.val[0]), ggml_vdotq_s32(vdupq_n_s32(0), qx[3], y.val[1]));
- auto sumi3 = vpaddq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[4], y.val[0]), ggml_vdotq_s32(vdupq_n_s32(0), qx[5], y.val[1]));
- auto sumi4 = vpaddq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[6], y.val[0]), ggml_vdotq_s32(vdupq_n_s32(0), qx[7], y.val[1]));
- auto sumi12 = vpaddq_s32(sumi1, sumi2); // blocks 0,1,2,3 in rows 0,1
- auto sumi34 = vpaddq_s32(sumi3, sumi4); // blocks 4,5,6,7 in rows 2,3
- isum[2*iy+0] = vmlaq_s32(isum[2*iy+0], s32_1, sumi12);
- isum[2*iy+1] = vmlaq_s32(isum[2*iy+1], s32_2, sumi34);
- }
- }
- }
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto sumi = vpaddq_s32(isum[2*iy+0], isum[2*iy+1]);
- acc[iy] = vfmaq_f32(acc[iy], vmulq_f32(d4, vdupq_n_f32(q8.scale(iy, ibl))), vcvtq_f32_s32(sumi));
- isum[2*iy] = isum[2*iy+1] = vdupq_n_s32(0);
- }
- }
- for (int iy = 0; iy < nrc_y; ++iy) {
- info.store(ix, iy, vmulq_f32(vdupq_n_f32(0.125f), acc[iy]));
- acc[iy] = vdupq_n_f32(0.f);
- }
- }
-}
-
-template <int nrc_y>
-static void mul_mat_iq3_xxs_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
- GGML_ASSERT(nrc_x%4 == 0);
- Q8<nrc_y, block_q8_K> q8(info);
- int nbl = n / QK_K;
- float32x4_t acc[nrc_y] = {};
- int32x4_t isum[nrc_y] = {};
- int8x16_t qx[8];
- SignHelper sh;
- for (int ix = 0; ix < nrc_x; ix += 4) {
- auto iq3 = (const block_iq3_xxs_r4 *)((const char *)vx + (ix+0)*bx);
- for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256
- auto d4 = vmulq_f32(vdupq_n_f32(0.25f), vcvt_f32_f16(vld1_f16((const float16_t *)iq3[ibl].d)));
- auto qs = iq3[ibl].qs;
- for (int ib = 0; ib < QK_K/32; ++ib) {
- auto sas = vld1q_u8(iq3[ibl].sas + 16*ib);
- auto scale_bits = vandq_u8(sas, vdupq_n_u8(1));
- auto scales = ggml_vdotq_s32(vdupq_n_s32(1), scale_bits, vreinterpretq_s8_u32(vdupq_n_u32(0x10080402)));
- auto signs128 = vandq_u8(sas, vdupq_n_u8(254));
- signs128 = veorq_u8(signs128, vshrq_n_u8(signs128, 1));
- sh.init();
- for (int i = 0; i < 8; ++i) {
- qx[i] = vreinterpretq_s8_u32(uint32x4_t{iq3xxs_grid[qs[4*i+0]], iq3xxs_grid[qs[4*i+1]], iq3xxs_grid[qs[4*i+2]], iq3xxs_grid[qs[4*i+3]]});
- sh.apply_signs_1((uint8x16_t *)qx+i, signs128);
- }
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto y = vld1q_s8_x2(q8.y[iy][ibl].qs + 32*ib);
- auto sumi1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[0], y.val[0]), qx[1], y.val[1]);
- auto sumi2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[2], y.val[0]), qx[3], y.val[1]);
- auto sumi3 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[4], y.val[0]), qx[5], y.val[1]);
- auto sumi4 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[6], y.val[0]), qx[7], y.val[1]);
- auto sumi12 = vpaddq_s32(sumi1, sumi2);
- auto sumi34 = vpaddq_s32(sumi3, sumi4);
- auto sumi = vpaddq_s32(sumi12, sumi34);
- isum[iy] = vmlaq_s32(isum[iy], scales, sumi);
- }
- qs += 32;
- }
- for (int iy = 0; iy < nrc_y; ++iy) {
- acc[iy] = vfmaq_f32(acc[iy], vmulq_f32(d4, vdupq_n_f32(q8.scale(iy, ibl))), vcvtq_f32_s32(isum[iy]));
- isum[iy] = vdupq_n_s32(0);
- }
- }
- for (int iy = 0; iy < nrc_y; ++iy) {
- info.store(ix, iy, acc[iy]);
- acc[iy] = vdupq_n_f32(0.f);
- }
- }
-}
-
-template <int nrc_y>
-static void mul_mat_iq3_s_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
- GGML_ASSERT(nrc_x%4 == 0);
- Q8<nrc_y, block_q8_K> q8(info);
- int nbl = n / QK_K;
- float32x4_t acc[nrc_y] = {};
- int32x4_t isum[nrc_y] = {};
- int8x16_t qx[8];
- auto m1 = vdupq_n_u8(1);
- auto shuff = vreinterpretq_u8_u32(uint32x4_t{0xffffff00, 0xffffff01, 0xffffff02, 0xffffff03});
- uint32_t stored_scales[8];
- for (int ix = 0; ix < nrc_x; ix += 4) {
- auto iq3 = (const block_iq3_s_r4 *)((const char *)vx + (ix+0)*bx);
- for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256
- auto d4 = vcvt_f32_f16(vld1_f16((const float16_t *)iq3[ibl].d));
- auto qs = iq3[ibl].qs;
- auto qh = iq3[ibl].qh;
- auto scale_bits = vld1q_u8(iq3[ibl].scales);
- uint8x16x2_t scales8 = { vandq_u8(scale_bits, vdupq_n_u8(0xf)), vshrq_n_u8(scale_bits, 4) };
- scales8.val[0] = vorrq_u8(vshlq_n_u8(scales8.val[0], 1), m1);
- scales8.val[1] = vorrq_u8(vshlq_n_u8(scales8.val[1], 1), m1);
- vst1q_u8_x2((uint8_t *)stored_scales, scales8);
- for (int ib = 0; ib < QK_K/32; ++ib) {
- auto signs128 = vld1q_u8(iq3[ibl].signs+16*ib);
- if constexpr (nrc_y == 1) {
- auto qh32 = (const uint32_t *)qh;
- auto idx_h = vreinterpretq_u16_u64(vshlq_u64(vreinterpretq_u64_u16(vmovl_u8(vreinterpret_u8_u32(vdup_n_u32(qh32[0])))), int64x2_t{8, 4}));
- union { uint16x8_t vec; uint16_t val[8]; } hidx;
- for (int i = 0; i < 4; ++i) {
- auto idx_l = vmovl_u8(vld1_u8(qs));
- hidx.vec = vorrq_u16(idx_l, vandq_u16(idx_h, vdupq_n_u16(0x100))); idx_h = vshrq_n_u16(idx_h, 1);
- qx[2*i+0] = vreinterpretq_s8_u32(uint32x4_t{iq3s_grid[hidx.val[0]], iq3s_grid[hidx.val[1]], iq3s_grid[hidx.val[2]], iq3s_grid[hidx.val[3]]});
- auto signs = vreinterpretq_s8_u8(vorrq_u8(vceqq_u8(vandq_u8(signs128, m1), m1), m1));
- qx[2*i+0] = vmulq_s8(qx[2*i+0], signs);
- qx[2*i+1] = vreinterpretq_s8_u32(uint32x4_t{iq3s_grid[hidx.val[4]], iq3s_grid[hidx.val[5]], iq3s_grid[hidx.val[6]], iq3s_grid[hidx.val[7]]});
- signs = vreinterpretq_s8_u8(vorrq_u8(vceqq_u8(vandq_u8(vshrq_n_u8(signs128, 4), m1), m1), m1));
- qx[2*i+1] = vmulq_s8(qx[2*i+1], signs);
- signs128 = vshrq_n_u8(signs128, 1);
- qs += 8;
- }
- } else {
- for (int i = 0; i < 4; ++i) {
- qx[2*i+0] = vreinterpretq_s8_u32(uint32x4_t{iq3s_grid[qs[0] | ((qh[0] << (8-i)) & 0x100)], iq3s_grid[qs[1] | ((qh[1] << (8-i)) & 0x100)],
- iq3s_grid[qs[2] | ((qh[2] << (8-i)) & 0x100)], iq3s_grid[qs[3] | ((qh[3] << (8-i)) & 0x100)]});
- auto signs = vreinterpretq_s8_u8(vorrq_u8(vceqq_u8(vandq_u8(signs128, m1), m1), m1));
- qx[2*i+0] = vmulq_s8(qx[2*i+0], signs);
-
- qx[2*i+1] = vreinterpretq_s8_u32(uint32x4_t{iq3s_grid[qs[4] | ((qh[0] << (4-i)) & 0x100)], iq3s_grid[qs[5] | ((qh[1] << (4-i)) & 0x100)],
- iq3s_grid[qs[6] | ((qh[2] << (4-i)) & 0x100)], iq3s_grid[qs[7] | ((qh[3] << (4-i)) & 0x100)]});
- signs = vreinterpretq_s8_u8(vorrq_u8(vceqq_u8(vandq_u8(vshrq_n_u8(signs128, 4), m1), m1), m1));
- qx[2*i+1] = vmulq_s8(qx[2*i+1], signs);
-
- qs += 8;
- signs128 = vshrq_n_u8(signs128, 1);
- }
- }
- auto scales = vreinterpretq_s32_u8(vqtbl1q_u8(vreinterpretq_u8_u32(vdupq_n_u32(stored_scales[ib])), shuff));
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto y = vld1q_s8_x2(q8.y[iy][ibl].qs + 32*ib);
- auto sumi = interleaved_dotq(qx, y);
- isum[iy] = vmlaq_s32(isum[iy], scales, sumi);
- }
- qh += 4;
- }
- for (int iy = 0; iy < nrc_y; ++iy) {
- acc[iy] = vfmaq_f32(acc[iy], vmulq_f32(d4, vdupq_n_f32(q8.scale(iy, ibl))), vcvtq_f32_s32(isum[iy]));
- isum[iy] = vdupq_n_s32(0);
- }
- }
- for (int iy = 0; iy < nrc_y; ++iy) {
- info.store(ix, iy, acc[iy]);
- acc[iy] = vdupq_n_f32(0.f);
- }
- }
-}
-
-template <int nrc_y, int k_shift>
-inline void iq3_4_add_shift(int ibl, const Q8<nrc_y, block_q8_K>& q8, const int8x16x4_t& i8scales, uint8x16_t extra,
- int32x4_t * isum) {
- auto ms = vdupq_n_s8(k_shift);
- int8x16_t s8_1, s8_2;
- if constexpr (k_shift == 5) {
- auto m1 = vdupq_n_u8(1);
- s8_1 = vmulq_s8(i8scales.val[0], vandq_s8(ms, vceqq_u8(vandq_u8(extra, m1), m1))); extra = vshrq_n_u8(extra, 2);
- s8_2 = vmulq_s8(i8scales.val[1], vandq_s8(ms, vceqq_u8(vandq_u8(extra, m1), m1))); extra = vshrq_n_u8(extra, 2);
- } else {
- if constexpr (k_shift == 4) {
- s8_1 = vmulq_s8(i8scales.val[0], vandq_u8(ms, vshlq_n_u8(extra, 2)));
- s8_2 = vmulq_s8(i8scales.val[1], vandq_u8(ms, extra));
- } else {
- s8_1 = vmulq_s8(i8scales.val[0], vandq_u8(ms, vshlq_n_u8(extra, 1)));
- s8_2 = vmulq_s8(i8scales.val[1], vandq_u8(ms, vshrq_n_u8(extra, 1)));
- }
- }
- auto s16_1 = vmovl_s8(vget_low_s8 (s8_1));
- auto s16_2 = vmovl_s8(vget_high_s8(s8_1));
- auto s16_3 = vmovl_s8(vget_low_s8 (s8_2));
- auto s16_4 = vmovl_s8(vget_high_s8(s8_2));
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto b8 = vld1_s16(q8.y[iy][ibl].bsums);
- isum[iy] = vmlal_lane_s16(isum[iy], vget_low_s16 (s16_1), b8, 0);
- isum[iy] = vmlal_lane_s16(isum[iy], vget_high_s16(s16_1), b8, 1);
- isum[iy] = vmlal_lane_s16(isum[iy], vget_low_s16 (s16_2), b8, 2);
- isum[iy] = vmlal_lane_s16(isum[iy], vget_high_s16(s16_2), b8, 3);
- b8 = vld1_s16(q8.y[iy][ibl].bsums+4);
- isum[iy] = vmlal_lane_s16(isum[iy], vget_low_s16 (s16_3), b8, 0);
- isum[iy] = vmlal_lane_s16(isum[iy], vget_high_s16(s16_3), b8, 1);
- isum[iy] = vmlal_lane_s16(isum[iy], vget_low_s16 (s16_4), b8, 2);
- isum[iy] = vmlal_lane_s16(isum[iy], vget_high_s16(s16_4), b8, 3);
- }
- if constexpr (k_shift == 5) {
- auto m1 = vdupq_n_u8(1);
- s8_1 = vmulq_s8(i8scales.val[2], vandq_s8(ms, vceqq_u8(vandq_u8(extra, m1), m1))); extra = vshrq_n_u8(extra, 2);
- s8_2 = vmulq_s8(i8scales.val[3], vandq_s8(ms, vceqq_u8(vandq_u8(extra, m1), m1))); extra = vshrq_n_u8(extra, 2);
- } else {
- if constexpr (k_shift == 4) {
- s8_1 = vmulq_s8(i8scales.val[2], vandq_u8(ms, vshrq_n_u8(extra, 2)));
- s8_2 = vmulq_s8(i8scales.val[3], vandq_u8(ms, vshrq_n_u8(extra, 4)));
- } else {
- s8_1 = vmulq_s8(i8scales.val[2], vandq_u8(ms, vshrq_n_u8(extra, 3)));
- s8_2 = vmulq_s8(i8scales.val[3], vandq_u8(ms, vshrq_n_u8(extra, 5)));
- }
- }
- s16_1 = vmovl_s8(vget_low_s8 (s8_1));
- s16_2 = vmovl_s8(vget_high_s8(s8_1));
- s16_3 = vmovl_s8(vget_low_s8 (s8_2));
- s16_4 = vmovl_s8(vget_high_s8(s8_2));
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto b8 = vld1_s16(q8.y[iy][ibl].bsums+8);
- isum[iy] = vmlal_lane_s16(isum[iy], vget_low_s16 (s16_1), b8, 0);
- isum[iy] = vmlal_lane_s16(isum[iy], vget_high_s16(s16_1), b8, 1);
- isum[iy] = vmlal_lane_s16(isum[iy], vget_low_s16 (s16_2), b8, 2);
- isum[iy] = vmlal_lane_s16(isum[iy], vget_high_s16(s16_2), b8, 3);
- b8 = vld1_s16(q8.y[iy][ibl].bsums+12);
- isum[iy] = vmlal_lane_s16(isum[iy], vget_low_s16 (s16_3), b8, 0);
- isum[iy] = vmlal_lane_s16(isum[iy], vget_high_s16(s16_3), b8, 1);
- isum[iy] = vmlal_lane_s16(isum[iy], vget_low_s16 (s16_4), b8, 2);
- isum[iy] = vmlal_lane_s16(isum[iy], vget_high_s16(s16_4), b8, 3);
- }
-}
-
-template <int nrc_y>
-void mul_mat_iq2_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
- GGML_ASSERT(nrc_x%4 == 0);
- Q8<nrc_y, block_q8_K> q8(info);
- auto m4 = vdupq_n_u8(0xf);
- auto m03 = vdupq_n_u8(0x03);
- auto ms = vdupq_n_u8(4);
- uint8x16x2_t shift_shuffle = {
- vreinterpretq_u8_u64(uint64x2_t{0x0101010100000000, 0x0303030302020202}),
- vreinterpretq_u8_u64(uint64x2_t{0x0505050504040404, 0x0707070706060606})
- };
- auto values8 = vld1_s8(iq2nl_values);
- auto values = vcombine_s8(values8, values8);
- int nbl = n / QK_K;
- int8x16_t qx[4];
- int8x16x4_t i8scales;
- int16x8x4_t i16scales;
- float32x4_t acc[nrc_y] = {};
- for (int ix = 0; ix < nrc_x; ix += 4) {
- const block_iq2_k_r4 * iq2 = (const block_iq2_k_r4 *)((const char *)vx + ix*bx);
- for (int ibl = 0; ibl < nbl; ++ibl) {
- auto d4 = vcvt_f32_f16(vld1_f16((const float16_t *)iq2[ibl].d));
- auto extra8 = vld1_u8(iq2[ibl].extra);
- uint8x16_t extra;
- if constexpr (nrc_y == 1) {
- extra = vcombine_u8(extra8, vshr_n_u8(extra8,1));
- } else {
- extra = vcombine_u8(extra8, extra8);
- }
- auto sl = vld1q_u8_x2(iq2[ibl].scales);
- i8scales.val[0] = vaddq_s8(vandq_u8(sl.val[0], m4), vdupq_n_s8(-8));
- i8scales.val[1] = vaddq_s8(vandq_u8(sl.val[1], m4), vdupq_n_s8(-8));
- i8scales.val[2] = vaddq_s8(vshrq_n_u8(sl.val[0], 4), vdupq_n_s8(-8));
- i8scales.val[3] = vaddq_s8(vshrq_n_u8(sl.val[1], 4), vdupq_n_s8(-8));
- int32x4_t isum[nrc_y] = {};
- if constexpr (nrc_y == 1) {
- iq3_4_add_shift<nrc_y, 5>(ibl, q8, i8scales, extra, isum);
- }
- for (int is = 0; is < 2; ++is) {
- i16scales.val[0] = vmovl_s8(vget_low_s8 (i8scales.val[2*is+0]));
- i16scales.val[1] = vmovl_s8(vget_high_s8(i8scales.val[2*is+0]));
- i16scales.val[2] = vmovl_s8(vget_low_s8 (i8scales.val[2*is+1]));
- i16scales.val[3] = vmovl_s8(vget_high_s8(i8scales.val[2*is+1]));
- for (int ib = 0; ib < 4; ++ib) {
- auto scales = vmovl_s16(vget_low_s16 (i16scales.val[ib]));
- auto bits = vld1q_u8_x2(iq2[ibl].qs + 128*is + 32*ib);
- qx[0] = vandq_u8( bits.val[0], m03);
- qx[1] = vandq_u8(vshrq_n_u8(bits.val[0], 2), m03);
- qx[2] = vandq_u8(vshrq_n_u8(bits.val[0], 4), m03);
- qx[3] = vandq_u8(vshrq_n_u8(bits.val[0], 6), m03);
- uint8x16_t shifts;
- if constexpr (nrc_y == 1) {
- qx[0] = vqtbl1q_s8(values, qx[0]); // 0...3 from the 4 rows
- qx[1] = vqtbl1q_s8(values, qx[1]); // 4...7
- qx[2] = vqtbl1q_s8(values, qx[2]); // 8..11
- qx[3] = vqtbl1q_s8(values, qx[3]); // 12..15
- } else {
- shifts = vandq_u8(ms, vshlq_n_u8(extra, 2));
- auto shift = vqtbl1q_u8(shifts, shift_shuffle.val[0]);
- extra = vshrq_n_u8(extra, 1);
- qx[0] = vqtbl1q_s8(values, vaddq_u8(shift, qx[0])); // 0...3 from the 4 rows
- qx[1] = vqtbl1q_s8(values, vaddq_u8(shift, qx[1])); // 4...7
- qx[2] = vqtbl1q_s8(values, vaddq_u8(shift, qx[2])); // 8..11
- qx[3] = vqtbl1q_s8(values, vaddq_u8(shift, qx[3])); // 12..15
- }
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto y = vld1q_s8(q8.y[iy][ibl].qs+128*is+32*ib);
- auto sumi = interleaved_dotq(qx, y);
- isum[iy] = vmlaq_s32(isum[iy], scales, sumi);
- }
- qx[0] = vandq_u8( bits.val[1], m03);
- qx[1] = vandq_u8(vshrq_n_u8(bits.val[1], 2), m03);
- qx[2] = vandq_u8(vshrq_n_u8(bits.val[1], 4), m03);
- qx[3] = vandq_u8(vshrq_n_u8(bits.val[1], 6), m03);
- if constexpr (nrc_y == 1) {
- qx[0] = vqtbl1q_s8(values, qx[0]); // 0...3 from the 4 rows
- qx[1] = vqtbl1q_s8(values, qx[1]); // 4...7
- qx[2] = vqtbl1q_s8(values, qx[2]); // 8..11
- qx[3] = vqtbl1q_s8(values, qx[3]); // 12..15
- } else {
- auto shift = vqtbl1q_u8(shifts, shift_shuffle.val[1]);
- qx[0] = vqtbl1q_s8(values, vaddq_u8(shift, qx[0])); // 0...3 from the 4 rows
- qx[1] = vqtbl1q_s8(values, vaddq_u8(shift, qx[1])); // 4...7
- qx[2] = vqtbl1q_s8(values, vaddq_u8(shift, qx[2])); // 8..11
- qx[3] = vqtbl1q_s8(values, vaddq_u8(shift, qx[3])); // 12..15
- }
- scales = vmovl_s16(vget_high_s16(i16scales.val[ib]));
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto y = vld1q_s8(q8.y[iy][ibl].qs+128*is+32*ib+16);
- auto sumi = interleaved_dotq(qx, y);
- isum[iy] = vmlaq_s32(isum[iy], scales, sumi);
- }
- }
- }
- for (int iy = 0; iy < nrc_y; ++iy) {
- acc[iy] = vfmaq_f32(acc[iy], vmulq_f32(d4, vdupq_n_f32(q8.scale(iy, ibl))), vcvtq_f32_s32(isum[iy]));
- }
- }
- for (int iy = 0; iy < nrc_y; ++iy) {
- info.store(ix, iy, acc[iy]);
- acc[iy] = vdupq_n_f32(0.f);
- }
- }
-}
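
For reference, a minimal scalar sketch of the per-value decode that mul_mat_iq2_k_r4_q8_k (and the iq3/iq4/iq5 variants below) vectorize: a low-bit index, pushed into the second half of the non-linear value table when the per-block "extra" bit is set, then weighted by a 4-bit block scale and the row scale. The 4-row interleaving and the exact bit layout of block_iq2_k_r4 are ignored, and the helper name is illustrative; iq2nl_values is the table loaded above.

static inline float iq2_k_dequant_one(uint8_t q2, bool extra_bit, int scale_nibble, float d_row) {
    // The vector code adds ms = 4 to the table index when the extra bit is set;
    // for nrc_y == 1 it instead folds in the constant table gap via iq3_4_add_shift<., 5>.
    int idx = (q2 & 0x03) + (extra_bit ? 4 : 0);
    return d_row * (scale_nibble - 8) * iq2nl_values[idx];  // the 4-bit scales carry a -8 offset
}

The result is later multiplied by the per-block activation scale q8.scale(iy, ibl), as in the accumulation loop above.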
-
-template <int nrc_y>
-void mul_mat_iq3_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
- GGML_ASSERT(nrc_x%4 == 0);
- Q8<nrc_y, block_q8_K> q8(info);
- auto m4 = vdupq_n_u8(0xf);
- auto ms = nrc_y == 1 ? vdupq_n_u8(4) : vdupq_n_u8(8);
- auto m03 = vdupq_n_u8(0x03);
- auto m04 = vdupq_n_u8(0x04);
- uint8x16x2_t shift_shuffle = {
- vreinterpretq_u8_u64(uint64x2_t{0x0101010100000000, 0x0303030302020202}),
- vreinterpretq_u8_u64(uint64x2_t{0x0505050504040404, 0x0707070706060606})
- };
- uint8x16x2_t smask = { vcombine_u8(vdup_n_u8(1), vdup_n_u8(2)), vcombine_u8(vdup_n_u8(4), vdup_n_u8(8)) };
- auto values = vld1q_s8(iq3nl_values);
- int nbl = n / QK_K;
- int8x16_t qx[4];
- int8x16x4_t i8scales;
- int16x8x4_t i16scales;
- float32x4_t acc[nrc_y] = {};
- for (int ix = 0; ix < nrc_x; ix += 4) {
- const block_iq3_k_r4 * iq3 = (const block_iq3_k_r4 *)((const char *)vx + ix*bx);
- for (int ibl = 0; ibl < nbl; ++ibl) {
- auto d4 = vcvt_f32_f16(vld1_f16((const float16_t *)iq3[ibl].d));
- auto extra8 = vld1_u8(iq3[ibl].extra);
- uint8x16_t extra;
- if constexpr (nrc_y == 1) {
- extra = vcombine_u8(extra8, vshr_n_u8(extra8,1));
- } else {
- extra = vcombine_u8(extra8, extra8);
- }
- auto sl = vld1q_u8_x2(iq3[ibl].scales_l);
- auto sh8 = vld1_u8(iq3[ibl].scales_h);
- auto sh = vcombine_u8(sh8, sh8);
- i8scales.val[0] = vaddq_s8(vshlq_n_u8(vandq_u8(sl.val[0], m4), 1), vdupq_n_s8(1));
- i8scales.val[1] = vaddq_s8(vshlq_n_u8(vandq_u8(sl.val[1], m4), 1), vdupq_n_s8(1));
- i8scales.val[2] = vaddq_s8(vshlq_n_u8(vshrq_n_u8(sl.val[0], 4), 1), vdupq_n_s8(1));
- i8scales.val[3] = vaddq_s8(vshlq_n_u8(vshrq_n_u8(sl.val[1], 4), 1), vdupq_n_s8(1));
- i8scales.val[0] = vmulq_s8(i8scales.val[0], vorrq_u8(vceqq_u8(vandq_u8(sh, smask.val[0]), smask.val[0]), vdupq_n_u8(1)));
- i8scales.val[1] = vmulq_s8(i8scales.val[1], vorrq_u8(vceqq_u8(vandq_u8(sh, smask.val[1]), smask.val[1]), vdupq_n_u8(1)));
- sh = vshrq_n_u8(sh, 4);
- i8scales.val[2] = vmulq_s8(i8scales.val[2], vorrq_u8(vceqq_u8(vandq_u8(sh, smask.val[0]), smask.val[0]), vdupq_n_u8(1)));
- i8scales.val[3] = vmulq_s8(i8scales.val[3], vorrq_u8(vceqq_u8(vandq_u8(sh, smask.val[1]), smask.val[1]), vdupq_n_u8(1)));
- int32x4_t isum[nrc_y] = {};
- if constexpr (nrc_y == 1) {
- iq3_4_add_shift<nrc_y, 4>(ibl, q8, i8scales, extra, isum);
- }
- for (int is = 0; is < 2; ++is) {
- i16scales.val[0] = vmovl_s8(vget_low_s8 (i8scales.val[2*is+0]));
- i16scales.val[1] = vmovl_s8(vget_high_s8(i8scales.val[2*is+0]));
- i16scales.val[2] = vmovl_s8(vget_low_s8 (i8scales.val[2*is+1]));
- i16scales.val[3] = vmovl_s8(vget_high_s8(i8scales.val[2*is+1]));
- for (int ib = 0; ib < 4; ++ib) {
- auto scales = vmovl_s16(vget_low_s16 (i16scales.val[ib]));
- auto lbits = vld1q_u8_x2(iq3[ibl].qs + 128*is + 32*ib);
- auto hbits = vld1q_u8(iq3[ibl].qh + 64*is + 16*ib);
- qx[0] = vorrq_u8(vandq_u8( lbits.val[0], m03), vandq_u8(m04, vshlq_n_u8(hbits, 2)));
- qx[1] = vorrq_u8(vandq_u8(vshrq_n_u8(lbits.val[0], 2), m03), vandq_u8(m04, vshlq_n_u8(hbits, 1)));
- qx[2] = vorrq_u8(vandq_u8(vshrq_n_u8(lbits.val[0], 4), m03), vandq_u8(m04, hbits));
- qx[3] = vorrq_u8(vandq_u8(vshrq_n_u8(lbits.val[0], 6), m03), vandq_u8(m04, vshrq_n_u8(hbits, 1)));
- uint8x16_t shifts;
- if constexpr (nrc_y == 1) {
- qx[0] = vqtbl1q_s8(values, qx[0]); // 0...3 from the 4 rows
- qx[1] = vqtbl1q_s8(values, qx[1]); // 4...7
- qx[2] = vqtbl1q_s8(values, qx[2]); // 8..11
- qx[3] = vqtbl1q_s8(values, qx[3]); // 12..15
- } else {
- shifts = vandq_u8(ms, vshlq_n_u8(extra, 3));
- auto shift = vqtbl1q_u8(shifts, shift_shuffle.val[0]);
- extra = vshrq_n_u8(extra, 1);
- qx[0] = vqtbl1q_s8(values, vaddq_u8(shift, qx[0])); // 0...3 from the 4 rows
- qx[1] = vqtbl1q_s8(values, vaddq_u8(shift, qx[1])); // 4...7
- qx[2] = vqtbl1q_s8(values, vaddq_u8(shift, qx[2])); // 8..11
- qx[3] = vqtbl1q_s8(values, vaddq_u8(shift, qx[3])); // 12..15
- }
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto y = vld1q_s8(q8.y[iy][ibl].qs+128*is+32*ib);
- auto sumi = interleaved_dotq(qx, y);
- isum[iy] = vmlaq_s32(isum[iy], scales, sumi);
- }
- qx[0] = vorrq_u8(vandq_u8( lbits.val[1], m03), vandq_u8(m04, vshrq_n_u8(hbits, 2)));
- qx[1] = vorrq_u8(vandq_u8(vshrq_n_u8(lbits.val[1], 2), m03), vandq_u8(m04, vshrq_n_u8(hbits, 3)));
- qx[2] = vorrq_u8(vandq_u8(vshrq_n_u8(lbits.val[1], 4), m03), vandq_u8(m04, vshrq_n_u8(hbits, 4)));
- qx[3] = vorrq_u8(vandq_u8(vshrq_n_u8(lbits.val[1], 6), m03), vandq_u8(m04, vshrq_n_u8(hbits, 5)));
- if constexpr (nrc_y == 1) {
- qx[0] = vqtbl1q_s8(values, qx[0]); // 0...3 from the 4 rows
- qx[1] = vqtbl1q_s8(values, qx[1]); // 4...7
- qx[2] = vqtbl1q_s8(values, qx[2]); // 8..11
- qx[3] = vqtbl1q_s8(values, qx[3]); // 12..15
- } else {
- auto shift = vqtbl1q_u8(shifts, shift_shuffle.val[1]);
- qx[0] = vqtbl1q_s8(values, vaddq_u8(shift, qx[0])); // 0...3 from the 4 rows
- qx[1] = vqtbl1q_s8(values, vaddq_u8(shift, qx[1])); // 4...7
- qx[2] = vqtbl1q_s8(values, vaddq_u8(shift, qx[2])); // 8..11
- qx[3] = vqtbl1q_s8(values, vaddq_u8(shift, qx[3])); // 12..15
- }
- scales = vmovl_s16(vget_high_s16(i16scales.val[ib]));
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto y = vld1q_s8(q8.y[iy][ibl].qs+128*is+32*ib+16);
- auto sumi = interleaved_dotq(qx, y);
- isum[iy] = vmlaq_s32(isum[iy], scales, sumi);
- }
- }
- }
- for (int iy = 0; iy < nrc_y; ++iy) {
- acc[iy] = vfmaq_f32(acc[iy], vmulq_f32(d4, vdupq_n_f32(q8.scale(iy, ibl))), vcvtq_f32_s32(isum[iy]));
- }
- }
- for (int iy = 0; iy < nrc_y; ++iy) {
- info.store(ix, iy, acc[iy]);
- acc[iy] = vdupq_n_f32(0.f);
- }
- }
-}
-
-template <int nrc_y>
-void mul_mat_iq4_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
- GGML_ASSERT(nrc_x%4 == 0);
- Q8<nrc_y, block_q8_K> q8(info);
- auto m4 = vdupq_n_u8(0xf);
- auto m3 = vdupq_n_u8(0x30);
- auto ms = vdupq_n_u8(4);
- auto m32 = vdupq_n_s8(-32);
- uint8x16x2_t shift_shuffle = {
- vreinterpretq_u8_u64(uint64x2_t{0x0101010100000000, 0x0303030302020202}),
- vreinterpretq_u8_u64(uint64x2_t{0x0505050504040404, 0x0707070706060606})
- };
- auto values = vld1q_s8(iq4k_values);
- int nbl = n / QK_K;
- int8x16_t qx[4];
- int8x16x4_t i8scales;
- int16x8x4_t i16scales;
- float32x4_t acc[nrc_y] = {};
- for (int ix = 0; ix < nrc_x; ix += 4) {
- const block_iq4_k_r4 * iq4 = (const block_iq4_k_r4 *)((const char *)vx + ix*bx);
- for (int ibl = 0; ibl < nbl; ++ibl) {
- auto d4 = vcvt_f32_f16(vld1_f16((const float16_t *)iq4[ibl].d));
- auto extra8 = vld1_u8(iq4[ibl].extra);
- uint8x16_t extra;
- if constexpr (nrc_y == 1) {
- extra = vcombine_u8(extra8, vshr_n_u8(extra8,1));
- } else {
- extra = vcombine_u8(extra8, extra8);
- }
- auto sl = vld1q_u8_x2(iq4[ibl].scales_l);
- auto sh = vld1q_u8(iq4[ibl].scales_h);
- i8scales.val[0] = vaddq_s8(vorrq_u8(vandq_u8(sl.val[0], m4), vandq_u8(vshlq_n_u8(sh, 4), m3)), m32);
- i8scales.val[1] = vaddq_s8(vorrq_u8(vandq_u8(sl.val[1], m4), vandq_u8(vshlq_n_u8(sh, 2), m3)), m32);
- i8scales.val[2] = vaddq_s8(vorrq_u8(vshrq_n_u8(sl.val[0], 4), vandq_u8(sh, m3)), m32);
- i8scales.val[3] = vaddq_s8(vorrq_u8(vshrq_n_u8(sl.val[1], 4), vandq_u8(vshrq_n_u8(sh, 2), m3)), m32);
- int32x4_t isum[nrc_y] = {};
- if constexpr (nrc_y == 1) {
- iq3_4_add_shift<nrc_y, 4>(ibl, q8, i8scales, extra, isum);
- }
- for (int is = 0; is < 2; ++is) {
- i16scales.val[0] = vmovl_s8(vget_low_s8 (i8scales.val[2*is+0]));
- i16scales.val[1] = vmovl_s8(vget_high_s8(i8scales.val[2*is+0]));
- i16scales.val[2] = vmovl_s8(vget_low_s8 (i8scales.val[2*is+1]));
- i16scales.val[3] = vmovl_s8(vget_high_s8(i8scales.val[2*is+1]));
- for (int ib = 0; ib < 4; ++ib) {
- auto bits = vld1q_u8_x4(iq4[ibl].qs + 256*is + 64*ib);
- uint8x16_t shifts;
- if constexpr (nrc_y == 1) {
- qx[0] = vqtbl1q_s8(values, vandq_u8(bits.val[0], m4)); // 0...3 from the 4 rows
- qx[1] = vqtbl1q_s8(values, vandq_u8(bits.val[2], m4)); // 4...7
- qx[2] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[0], 4)); // 8..11
- qx[3] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[2], 4)); // 12..15
- } else {
- shifts = vandq_u8(ms, vshlq_n_u8(extra, 2));
- auto shift = vqtbl1q_u8(shifts, shift_shuffle.val[0]);
- extra = vshrq_n_u8(extra, 1);
- qx[0] = vaddq_s8(shift, vqtbl1q_s8(values, vandq_u8(bits.val[0], m4))); // 0...3 from the 4 rows
- qx[1] = vaddq_s8(shift, vqtbl1q_s8(values, vandq_u8(bits.val[2], m4))); // 4...7
- qx[2] = vaddq_s8(shift, vqtbl1q_s8(values, vshrq_n_u8(bits.val[0], 4))); // 8..11
- qx[3] = vaddq_s8(shift, vqtbl1q_s8(values, vshrq_n_u8(bits.val[2], 4))); // 12..15
- }
- auto scales = vmovl_s16(vget_low_s16 (i16scales.val[ib]));
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto y = vld1q_s8(q8.y[iy][ibl].qs+128*is+32*ib);
- auto sumi = interleaved_dotq(qx, y);
- isum[iy] = vmlaq_s32(isum[iy], scales, sumi);
- }
- if constexpr (nrc_y == 1) {
- qx[0] = vqtbl1q_s8(values, vandq_u8(bits.val[1], m4)); // 16..19
- qx[1] = vqtbl1q_s8(values, vandq_u8(bits.val[3], m4)); // 20..23
- qx[2] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[1], 4)); // 24..27
- qx[3] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[3], 4)); // 28..31
- } else {
- auto shift = vqtbl1q_u8(shifts, shift_shuffle.val[1]);
- qx[0] = vaddq_s8(shift, vqtbl1q_s8(values, vandq_u8(bits.val[1], m4))); // 16..19
- qx[1] = vaddq_s8(shift, vqtbl1q_s8(values, vandq_u8(bits.val[3], m4))); // 20..23
- qx[2] = vaddq_s8(shift, vqtbl1q_s8(values, vshrq_n_u8(bits.val[1], 4))); // 24..27
- qx[3] = vaddq_s8(shift, vqtbl1q_s8(values, vshrq_n_u8(bits.val[3], 4))); // 28..31
- }
- scales = vmovl_s16(vget_high_s16(i16scales.val[ib]));
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto y = vld1q_s8(q8.y[iy][ibl].qs+128*is+32*ib+16);
- auto sumi = interleaved_dotq(qx, y);
- isum[iy] = vmlaq_s32(isum[iy], scales, sumi);
- }
- }
- }
- for (int iy = 0; iy < nrc_y; ++iy) {
- acc[iy] = vfmaq_f32(acc[iy], vmulq_f32(d4, vdupq_n_f32(q8.scale(iy, ibl))), vcvtq_f32_s32(isum[iy]));
- }
- }
- for (int iy = 0; iy < nrc_y; ++iy) {
- info.store(ix, iy, acc[iy]);
- acc[iy] = vdupq_n_f32(0.f);
- }
- }
-}
-
-template <int nrc_y>
-void mul_mat_iq5_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
- GGML_ASSERT(nrc_x%4 == 0);
- Q8<nrc_y, block_q8_K> q8(info);
- auto m4 = vdupq_n_u8(0xf);
- auto m3 = vdupq_n_u8(0x30);
- auto ms = vdupq_n_u8(2);
- auto m32 = vdupq_n_s8(-32);
- auto m10 = vdupq_n_u8(0x10);
- uint8x16x2_t shift_shuffle = {
- vreinterpretq_u8_u64(uint64x2_t{0x0101010100000000, 0x0303030302020202}),
- vreinterpretq_u8_u64(uint64x2_t{0x0505050504040404, 0x0707070706060606})
- };
- auto values = vld1q_s8_x2(iq5nl_values);
- int nbl = n / QK_K;
- int8x16_t qx[4];
- int8x16x4_t i8scales;
- int16x8x4_t i16scales;
- float32x4_t acc[nrc_y] = {};
- for (int ix = 0; ix < nrc_x; ix += 4) {
- const block_iq5_k_r4 * iq5 = (const block_iq5_k_r4 *)((const char *)vx + ix*bx);
- for (int ibl = 0; ibl < nbl; ++ibl) {
- auto d4 = vcvt_f32_f16(vld1_f16((const float16_t *)iq5[ibl].d));
- auto extra8 = vld1_u8(iq5[ibl].extra);
- uint8x16_t extra;
- if constexpr (nrc_y == 1) {
- extra = vcombine_u8(extra8, vshr_n_u8(extra8,1));
- } else {
- extra = vcombine_u8(extra8, extra8);
- }
- auto sl = vld1q_u8_x2(iq5[ibl].scales_l);
- auto sh = vld1q_u8(iq5[ibl].scales_h);
- i8scales.val[0] = vaddq_s8(vorrq_u8(vandq_u8(sl.val[0], m4), vandq_u8(vshlq_n_u8(sh, 4), m3)), m32);
- i8scales.val[1] = vaddq_s8(vorrq_u8(vandq_u8(sl.val[1], m4), vandq_u8(vshlq_n_u8(sh, 2), m3)), m32);
- i8scales.val[2] = vaddq_s8(vorrq_u8(vshrq_n_u8(sl.val[0], 4), vandq_u8(sh, m3)), m32);
- i8scales.val[3] = vaddq_s8(vorrq_u8(vshrq_n_u8(sl.val[1], 4), vandq_u8(vshrq_n_u8(sh, 2), m3)), m32);
- int32x4_t isum[nrc_y] = {};
- if constexpr (nrc_y == 1) {
- iq3_4_add_shift<nrc_y, 2>(ibl, q8, i8scales, extra, isum);
- }
- for (int is = 0; is < 2; ++is) {
- i16scales.val[0] = vmovl_s8(vget_low_s8 (i8scales.val[2*is+0]));
- i16scales.val[1] = vmovl_s8(vget_high_s8(i8scales.val[2*is+0]));
- i16scales.val[2] = vmovl_s8(vget_low_s8 (i8scales.val[2*is+1]));
- i16scales.val[3] = vmovl_s8(vget_high_s8(i8scales.val[2*is+1]));
- for (int ib = 0; ib < 4; ++ib) {
- auto lbits = vld1q_u8_x4(iq5[ibl].qs + 256*is + 64*ib);
- auto hbits = vld1q_u8(iq5[ibl].qh + 64*is + 16*ib);
- qx[0] = vorrq_u8(vandq_u8(lbits.val[0], m4), vandq_u8(m10, vshlq_n_u8(hbits, 4))); // aligns with 1st half of qx[0] in AVX2
- qx[1] = vorrq_u8(vandq_u8(lbits.val[2], m4), vandq_u8(m10, hbits)); // aligns with 1st half of qx[1] in AVX2
- qx[2] = vorrq_u8(vshrq_n_u8(lbits.val[0], 4), vandq_u8(m10, vshlq_n_u8(hbits, 3))); // aligns with 1st half of qx[2] in AVX2
- qx[3] = vorrq_u8(vshrq_n_u8(lbits.val[2], 4), vandq_u8(m10, vshrq_n_u8(hbits, 1))); // aligns with 1st half of qx[3] in AVX2
- uint8x16_t shifts;
- if constexpr (nrc_y == 1) {
- qx[0] = vqtbl2q_s8(values, qx[0]); // 0...3 from the 4 rows
- qx[1] = vqtbl2q_s8(values, qx[1]); // 4...7
- qx[2] = vqtbl2q_s8(values, qx[2]); // 8..11
- qx[3] = vqtbl2q_s8(values, qx[3]); // 12..15
- } else {
- shifts = vandq_u8(ms, vshlq_n_u8(extra, 1));
- auto shift = vqtbl1q_u8(shifts, shift_shuffle.val[0]);
- extra = vshrq_n_u8(extra, 1);
- qx[0] = vaddq_s8(shift, vqtbl2q_s8(values, qx[0])); // 0...3 from the 4 rows
- qx[1] = vaddq_s8(shift, vqtbl2q_s8(values, qx[1])); // 4...7
- qx[2] = vaddq_s8(shift, vqtbl2q_s8(values, qx[2])); // 8..11
- qx[3] = vaddq_s8(shift, vqtbl2q_s8(values, qx[3])); // 12..15
- }
- auto scales = vmovl_s16(vget_low_s16 (i16scales.val[ib]));
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto y = vld1q_s8(q8.y[iy][ibl].qs+128*is+32*ib);
- auto sumi = interleaved_dotq(qx, y);
- isum[iy] = vmlaq_s32(isum[iy], scales, sumi);
- }
- qx[0] = vorrq_u8(vandq_u8(lbits.val[1], m4), vandq_u8(m10, vshlq_n_u8(hbits, 2))); // aligns with 2nd half of qx[0] in AVX2
- qx[1] = vorrq_u8(vandq_u8(lbits.val[3], m4), vandq_u8(m10, vshrq_n_u8(hbits, 2))); // aligns with 2nd half of qx[1] in AVX2
- qx[2] = vorrq_u8(vshrq_n_u8(lbits.val[1], 4), vandq_u8(m10, vshlq_n_u8(hbits, 1))); // aligns with 2nd half of qx[2] in AVX2
- qx[3] = vorrq_u8(vshrq_n_u8(lbits.val[3], 4), vandq_u8(m10, vshrq_n_u8(hbits, 3))); // aligns with 2nd half of qx[3] in AVX2
- if constexpr (nrc_y == 1) {
- qx[0] = vqtbl2q_s8(values, qx[0]); // 0...3 from the 4 rows
- qx[1] = vqtbl2q_s8(values, qx[1]); // 4...7
- qx[2] = vqtbl2q_s8(values, qx[2]); // 8..11
- qx[3] = vqtbl2q_s8(values, qx[3]); // 12..15
- } else {
- auto shift = vqtbl1q_u8(shifts, shift_shuffle.val[1]);
- qx[0] = vaddq_s8(shift, vqtbl2q_s8(values, qx[0])); // 0...3 from the 4 rows
- qx[1] = vaddq_s8(shift, vqtbl2q_s8(values, qx[1])); // 4...7
- qx[2] = vaddq_s8(shift, vqtbl2q_s8(values, qx[2])); // 8..11
- qx[3] = vaddq_s8(shift, vqtbl2q_s8(values, qx[3])); // 12..15
- }
- scales = vmovl_s16(vget_high_s16(i16scales.val[ib]));
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto y = vld1q_s8(q8.y[iy][ibl].qs+128*is+32*ib+16);
- auto sumi = interleaved_dotq(qx, y);
- isum[iy] = vmlaq_s32(isum[iy], scales, sumi);
- }
- }
- }
- for (int iy = 0; iy < nrc_y; ++iy) {
- acc[iy] = vfmaq_f32(acc[iy], vmulq_f32(d4, vdupq_n_f32(q8.scale(iy, ibl))), vcvtq_f32_s32(isum[iy]));
- }
- }
- for (int iy = 0; iy < nrc_y; ++iy) {
- info.store(ix, iy, acc[iy]);
- acc[iy] = vdupq_n_f32(0.f);
- }
- }
-}
-
-IQK_ALWAYS_INLINE void prepare_q4_k_quants(const uint8x16_t& m4, const uint8x16x4_t& bits, int8x16_t * qx) {
- qx[0] = vandq_u8(bits.val[0], m4); // 0...3 from the 4 rows
- qx[1] = vandq_u8(bits.val[1], m4); // 16..19
- qx[2] = vandq_u8(bits.val[2], m4); // 4...7
- qx[3] = vandq_u8(bits.val[3], m4); // 20..23
- qx[4] = vshrq_n_u8(bits.val[0], 4); // 8..11
- qx[5] = vshrq_n_u8(bits.val[1], 4); // 24..27
- qx[6] = vshrq_n_u8(bits.val[2], 4); // 12..15
- qx[7] = vshrq_n_u8(bits.val[3], 4); // 28..31
-}
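
The helper above only splits the 64 packed bytes of four interleaved rows into low and high nibbles; the index comments state which quant positions each output register ends up holding. A scalar equivalent (illustrative name), for reference:

static inline void prepare_q4_k_quants_ref(const uint8_t * bits /*64 bytes*/, uint8_t * qx /*8 x 16*/) {
    for (int j = 0; j < 4; ++j) {
        for (int i = 0; i < 16; ++i) {
            qx[16*j       + i] = bits[16*j + i] & 0x0f;  // low nibbles  -> qx[0..3]
            qx[16*(j + 4) + i] = bits[16*j + i] >> 4;    // high nibbles -> qx[4..7]
        }
    }
}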
-
-template <int nrc_y>
-void mul_mat_q2_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
- GGML_ASSERT(nrc_x%4 == 0);
- Q8<nrc_y, block_q8_K> q8(info);
- auto mf = vdupq_n_u8(0x0f);
- auto m03 = vdupq_n_u8(0x03);
- int nbl = n / QK_K;
- int8x16_t qx[4];
- float32x4_t acc[nrc_y] = {};
- int16x8x4_t i16scales;
- for (int ix = 0; ix < nrc_x; ix += 4) {
- const block_q2_k_r4 * iq2 = (const block_q2_k_r4 *)((const char *)vx + ix*bx);
- for (int ibl = 0; ibl < nbl; ++ibl) {
- int32x4_t isum[nrc_y] = {};
- auto d4 = vcvt_f32_f16(vld1_f16((const float16_t *)iq2[ibl].d));
- auto m4 = vmulq_f32(vdupq_n_f32(-1.f), vcvt_f32_f16(vld1_f16((const float16_t *)iq2[ibl].d+4)));
- for (int is = 0; is < 2; ++is) {
- auto sl = vld1q_u8_x2(iq2[ibl].scales + 32*is);
- auto m = vshrq_n_u8(sl.val[0], 4);
- i16scales.val[0] = vmovl_u8(vget_low_u8 (m));
- i16scales.val[1] = vmovl_u8(vget_high_u8(m));
- m = vshrq_n_u8(sl.val[1], 4);
- i16scales.val[2] = vmovl_u8(vget_low_u8 (m));
- i16scales.val[3] = vmovl_u8(vget_high_u8(m));
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto sumi = vdupq_n_s32(0);
- auto bsums = vld1q_s16(q8.y[iy][ibl].bsums + 8*is);
- auto b8 = vget_low_s16(bsums);
- //auto bsums = q8.load_bsums(iy, ibl);
- //auto b8 = vget_low_s16(bsums.val[0]);
- sumi = vmlal_lane_s16(sumi, vget_low_s16 (i16scales.val[0]), b8, 0);
- sumi = vmlal_lane_s16(sumi, vget_high_s16(i16scales.val[0]), b8, 1);
- sumi = vmlal_lane_s16(sumi, vget_low_s16 (i16scales.val[1]), b8, 2);
- sumi = vmlal_lane_s16(sumi, vget_high_s16(i16scales.val[1]), b8, 3);
- b8 = vget_high_s16(bsums);
- sumi = vmlal_lane_s16(sumi, vget_low_s16 (i16scales.val[2]), b8, 0);
- sumi = vmlal_lane_s16(sumi, vget_high_s16(i16scales.val[2]), b8, 1);
- sumi = vmlal_lane_s16(sumi, vget_low_s16 (i16scales.val[3]), b8, 2);
- sumi = vmlal_lane_s16(sumi, vget_high_s16(i16scales.val[3]), b8, 3);
- acc[iy] = vfmaq_f32(acc[iy], vmulq_f32(m4, vdupq_n_f32(q8.scale(iy, ibl))), vcvtq_f32_s32(sumi));
- }
- m = vandq_u8(sl.val[0], mf);
- i16scales.val[0] = vmovl_u8(vget_low_u8 (m));
- i16scales.val[1] = vmovl_u8(vget_high_u8(m));
- m = vandq_u8(sl.val[1], mf);
- i16scales.val[2] = vmovl_u8(vget_low_u8 (m));
- i16scales.val[3] = vmovl_u8(vget_high_u8(m));
- for (int ib = 0; ib < 4; ++ib) {
- auto bits = vld1q_u8_x2(iq2[ibl].qs + 128*is + 32*ib);
- auto scales = vmovl_s16(vget_low_s16 (i16scales.val[ib]));
- qx[0] = vreinterpretq_s8_u8(vandq_u8( bits.val[0], m03));
- qx[1] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(bits.val[0], 2), m03));
- qx[2] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(bits.val[0], 4), m03));
- qx[3] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(bits.val[0], 6), m03));
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto y = vld1q_s8(q8.y[iy][ibl].qs+128*is+32*ib);
- auto sumi = interleaved_dotq(qx, y);
- isum[iy] = vmlaq_s32(isum[iy], scales, sumi);
- }
- scales = vmovl_s16(vget_high_s16(i16scales.val[ib]));
- qx[0] = vreinterpretq_s8_u8(vandq_u8( bits.val[1], m03));
- qx[1] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(bits.val[1], 2), m03));
- qx[2] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(bits.val[1], 4), m03));
- qx[3] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(bits.val[1], 6), m03));
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto y = vld1q_s8(q8.y[iy][ibl].qs+128*is+32*ib+16);
- auto sumi = interleaved_dotq(qx, y);
- isum[iy] = vmlaq_s32(isum[iy], scales, sumi);
- }
- }
- }
- for (int iy = 0; iy < nrc_y; ++iy) {
- acc[iy] = vfmaq_f32(acc[iy], vmulq_f32(d4, vdupq_n_f32(q8.scale(iy, ibl))), vcvtq_f32_s32(isum[iy]));
- }
- }
- for (int iy = 0; iy < nrc_y; ++iy) {
- info.store(ix, iy, acc[iy]);
- acc[iy] = vdupq_n_f32(0.f);
- }
- }
-}
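
The kernel above accumulates the two standard q2_K contributions separately: the block minima (high nibbles of scales) folded against the activation block sums with the negated second half of d, and the 2-bit quants against the low nibbles. Per row and superblock this reduces to the sketch below; the helper name and argument packaging are illustrative.

static inline float q2_k_superblock_ref(float d, float dmin, float d8,
                                        const uint8_t * sc,   // 16 packed scale/min bytes for this row
                                        const int * dot,      // 16 sub-block dots of the 2-bit quants with y
                                        const int * bsum) {   // 16 sub-block sums of y
    int isum = 0, msum = 0;
    for (int j = 0; j < 16; ++j) {
        isum += (sc[j] & 0x0f) * dot[j];   // low nibble: scale of the 2-bit quants
        msum += (sc[j] >>   4) * bsum[j];  // high nibble: block minimum
    }
    return d8 * (d * isum - dmin * msum);  // the code negates dmin up front (m4)
}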
-
-template <int nrc_y>
-void mul_mat_q3_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
- GGML_ASSERT(nrc_x%4 == 0);
- Q8<nrc_y, block_q8_K> q8(info);
- auto mf = vdupq_n_u8(0x0f);
- auto m30 = vdupq_n_u8(0x30);
- auto m32 = vdupq_n_s8(-32);
- auto m03 = vdupq_n_u8(0x03);
- auto m04 = vdupq_n_u8(0x04);
- int nbl = n / QK_K;
- int8x16_t qx[4];
- float32x4_t acc[nrc_y] = {};
- int8x16x4_t i8scales;
- int16x8x4_t i16scales;
- for (int ix = 0; ix < nrc_x; ix += 4) {
- const block_q3_k_r4 * iq3 = (const block_q3_k_r4 *)((const char *)vx + ix*bx);
- for (int ibl = 0; ibl < nbl; ++ibl) {
- int32x4_t isum[nrc_y] = {};
- auto d4 = vcvt_f32_f16(vld1_f16((const float16_t *)iq3[ibl].d));
- auto sl = vld1q_u8_x2(iq3[ibl].scales_l);
- auto sh = vld1q_u8(iq3[ibl].scales_h);
- i8scales.val[0] = vaddq_s8(m32, vorrq_u8(vandq_u8(sl.val[0], mf), vandq_u8(vshlq_n_u8(sh, 4), m30)));
- i8scales.val[1] = vaddq_s8(m32, vorrq_u8(vandq_u8(sl.val[1], mf), vandq_u8(vshlq_n_u8(sh, 2), m30)));
- i8scales.val[2] = vaddq_s8(m32, vorrq_u8(vshrq_n_u8(sl.val[0], 4), vandq_u8(sh, m30)));
- i8scales.val[3] = vaddq_s8(m32, vorrq_u8(vshrq_n_u8(sl.val[1], 4), vandq_u8(vshrq_n_u8(sh, 2), m30)));
- for (int is = 0; is < 2; ++is) {
- i16scales.val[0] = vmovl_s8(vget_low_s8 (i8scales.val[2*is+0]));
- i16scales.val[1] = vmovl_s8(vget_high_s8(i8scales.val[2*is+0]));
- i16scales.val[2] = vmovl_s8(vget_low_s8 (i8scales.val[2*is+1]));
- i16scales.val[3] = vmovl_s8(vget_high_s8(i8scales.val[2*is+1]));
- for (int ib = 0; ib < 4; ++ib) {
- auto lbits = vld1q_u8_x2(iq3[ibl].qs + 128*is + 32*ib);
- auto hbits = vld1q_u8(iq3[ibl].qh + 64*is + 16*ib);
- hbits = veorq_u8(hbits, vdupq_n_u8(0xff));
- auto scales = vmovl_s16(vget_low_s16 (i16scales.val[ib]));
- qx[0] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8( lbits.val[0], m03)), vreinterpretq_s8_u8(vandq_u8(m04, vshlq_n_u8(hbits, 2))));
- qx[1] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(lbits.val[0], 2), m03)), vreinterpretq_s8_u8(vandq_u8(m04, vshlq_n_u8(hbits, 1))));
- qx[2] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(lbits.val[0], 4), m03)), vreinterpretq_s8_u8(vandq_u8(m04, hbits)));
- qx[3] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(lbits.val[0], 6), m03)), vreinterpretq_s8_u8(vandq_u8(m04, vshrq_n_u8(hbits, 1))));
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto y = vld1q_s8(q8.y[iy][ibl].qs+128*is+32*ib);
- auto sumi = interleaved_dotq(qx, y);
- isum[iy] = vmlaq_s32(isum[iy], scales, sumi);
- }
- scales = vmovl_s16(vget_high_s16(i16scales.val[ib]));
- qx[0] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8( lbits.val[1], m03)), vreinterpretq_s8_u8(vandq_u8(m04, vshrq_n_u8(hbits, 2))));
- qx[1] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(lbits.val[1], 2), m03)), vreinterpretq_s8_u8(vandq_u8(m04, vshrq_n_u8(hbits, 3))));
- qx[2] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(lbits.val[1], 4), m03)), vreinterpretq_s8_u8(vandq_u8(m04, vshrq_n_u8(hbits, 4))));
- qx[3] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(lbits.val[1], 6), m03)), vreinterpretq_s8_u8(vandq_u8(m04, vshrq_n_u8(hbits, 5))));
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto y = vld1q_s8(q8.y[iy][ibl].qs+128*is+32*ib+16);
- auto sumi = interleaved_dotq(qx, y);
- isum[iy] = vmlaq_s32(isum[iy], scales, sumi);
- }
- }
- }
- for (int iy = 0; iy < nrc_y; ++iy) {
- acc[iy] = vfmaq_f32(acc[iy], vmulq_f32(d4, vdupq_n_f32(q8.scale(iy, ibl))), vcvtq_f32_s32(isum[iy]));
- }
- }
- for (int iy = 0; iy < nrc_y; ++iy) {
- info.store(ix, iy, acc[iy]);
- acc[iy] = vdupq_n_f32(0.f);
- }
- }
-}
-
-template <int nrc_y>
-void mul_mat_q4_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
- GGML_ASSERT(nrc_x%4 == 0);
- Q8<nrc_y, block_q8_K> q8(info);
- auto mf = vdupq_n_u8(0xf);
- auto m3 = vdupq_n_u8(0x30);
- int nbl = n / QK_K;
- int8x16_t qx[8];
- int8x16x2_t iscales;
- int32x4x4_t scales;
- float32x4_t acc[nrc_y] = {};
- for (int ix = 0; ix < nrc_x; ix += 4) {
- const block_q4_k_r4 * iq4 = (const block_q4_k_r4 *)((const char *)vx + ix*bx);
- for (int ibl = 0; ibl < nbl; ++ibl) {
- auto d4 = vcvt_f32_f16(vld1_f16((const float16_t *)iq4[ibl].d));
- auto m4 = vcvt_f32_f16(vld1_f16((const float16_t *)iq4[ibl].d+4));
- m4 = vmulq_f32(m4, vdupq_n_f32(-1.f));
- auto sl = vld1q_u8_x2(iq4[ibl].scales_l);
- auto sh = vld1q_u8(iq4[ibl].scales_h);
- iscales.val[0] = vorrq_u8(vshrq_n_u8(sl.val[0], 4), vandq_u8(vshlq_n_u8(sh, 2), m3));
- iscales.val[1] = vorrq_u8(vshrq_n_u8(sl.val[1], 4), vandq_u8(vshrq_n_u8(sh, 2), m3));
- for (int is = 0; is < 2; ++is) {
- auto iscales16_1 = vmovl_s8(vget_low_s8(iscales.val[is]));
- auto iscales16_2 = vmovl_s8(vget_high_s8(iscales.val[is]));
- float32x4x4_t fscales;
- fscales.val[0] = vmulq_f32(m4, vcvtq_f32_s32(vmovl_s16(vget_low_s16(iscales16_1))));
- fscales.val[1] = vmulq_f32(m4, vcvtq_f32_s32(vmovl_s16(vget_high_s16(iscales16_1))));
- fscales.val[2] = vmulq_f32(m4, vcvtq_f32_s32(vmovl_s16(vget_low_s16(iscales16_2))));
- fscales.val[3] = vmulq_f32(m4, vcvtq_f32_s32(vmovl_s16(vget_high_s16(iscales16_2))));
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto m8 = vld1q_f32((const float *)q8.y[iy][ibl].bsums + 4*is);
- acc[iy] = vmlaq_laneq_f32(acc[iy], fscales.val[0], m8, 0);
- acc[iy] = vmlaq_laneq_f32(acc[iy], fscales.val[1], m8, 1);
- acc[iy] = vmlaq_laneq_f32(acc[iy], fscales.val[2], m8, 2);
- acc[iy] = vmlaq_laneq_f32(acc[iy], fscales.val[3], m8, 3);
- }
- }
- iscales.val[0] = vorrq_u8(vandq_u8(sl.val[0], mf), vandq_u8(vshlq_n_u8(sh, 4), m3));
- iscales.val[1] = vorrq_u8(vandq_u8(sl.val[1], mf), vandq_u8(sh, m3));
- int32x4_t isum[nrc_y] = {};
- for (int is = 0; is < 2; ++is) {
- auto iscales16_1 = vmovl_s8(vget_low_s8(iscales.val[is]));
- auto iscales16_2 = vmovl_s8(vget_high_s8(iscales.val[is]));
- scales.val[0] = vmovl_s16(vget_low_s16(iscales16_1));
- scales.val[1] = vmovl_s16(vget_high_s16(iscales16_1));
- scales.val[2] = vmovl_s16(vget_low_s16(iscales16_2));
- scales.val[3] = vmovl_s16(vget_high_s16(iscales16_2));
- for (int ib = 0; ib < 4; ++ib) {
- auto bits = vld1q_u8_x4(iq4[ibl].qs + 256*is + 64*ib);
- prepare_q4_k_quants(mf, bits, qx);
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto y = vld1q_s8_x2(q8.y[iy][ibl].qs+128*is+32*ib);
- auto sumi = interleaved_dotq(qx, y);
- isum[iy] = vmlaq_s32(isum[iy], scales.val[ib], sumi);
- }
- }
- }
- for (int iy = 0; iy < nrc_y; ++iy) {
- acc[iy] = vfmaq_f32(acc[iy], vmulq_f32(d4, vdupq_n_f32(q8.scale(iy, ibl))), vcvtq_f32_s32(isum[iy]));
- }
- }
- for (int iy = 0; iy < nrc_y; ++iy) {
- info.store(ix, iy, acc[iy]);
- acc[iy] = vdupq_n_f32(0.f);
- }
- }
-}
-
-template <int nrc_y>
-void mul_mat_q5_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
- GGML_ASSERT(nrc_x%4 == 0);
- Q8<nrc_y, block_q8_K> q8(info);
- auto mf = vdupq_n_u8(0xf);
- auto m30 = vdupq_n_u8(0x30);
- auto m10 = vdupq_n_u8(0x10);
- int nbl = n / QK_K;
- int8x16_t qx[8];
- int8x16x2_t iscales;
- int32x4x4_t scales;
- float32x4_t acc[nrc_y] = {};
- for (int ix = 0; ix < nrc_x; ix += 4) {
- const block_q5_k_r4 * iq5 = (const block_q5_k_r4 *)((const char *)vx + ix*bx);
- for (int ibl = 0; ibl < nbl; ++ibl) {
- auto d4 = vcvt_f32_f16(vld1_f16((const float16_t *)iq5[ibl].d));
- auto m4 = vcvt_f32_f16(vld1_f16((const float16_t *)iq5[ibl].d+4));
- m4 = vmulq_f32(m4, vdupq_n_f32(-1.f));
- auto sl = vld1q_u8_x2(iq5[ibl].scales_l);
- auto sh = vld1q_u8(iq5[ibl].scales_h);
- iscales.val[0] = vorrq_u8(vshrq_n_u8(sl.val[0], 4), vandq_u8(vshlq_n_u8(sh, 2), m30));
- iscales.val[1] = vorrq_u8(vshrq_n_u8(sl.val[1], 4), vandq_u8(vshrq_n_u8(sh, 2), m30));
- for (int is = 0; is < 2; ++is) {
- auto iscales16_1 = vmovl_s8(vget_low_s8(iscales.val[is]));
- auto iscales16_2 = vmovl_s8(vget_high_s8(iscales.val[is]));
- float32x4x4_t fscales;
- fscales.val[0] = vmulq_f32(m4, vcvtq_f32_s32(vmovl_s16(vget_low_s16(iscales16_1))));
- fscales.val[1] = vmulq_f32(m4, vcvtq_f32_s32(vmovl_s16(vget_high_s16(iscales16_1))));
- fscales.val[2] = vmulq_f32(m4, vcvtq_f32_s32(vmovl_s16(vget_low_s16(iscales16_2))));
- fscales.val[3] = vmulq_f32(m4, vcvtq_f32_s32(vmovl_s16(vget_high_s16(iscales16_2))));
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto m8 = vld1q_f32((const float *)q8.y[iy][ibl].bsums + 4*is);
- acc[iy] = vmlaq_laneq_f32(acc[iy], fscales.val[0], m8, 0);
- acc[iy] = vmlaq_laneq_f32(acc[iy], fscales.val[1], m8, 1);
- acc[iy] = vmlaq_laneq_f32(acc[iy], fscales.val[2], m8, 2);
- acc[iy] = vmlaq_laneq_f32(acc[iy], fscales.val[3], m8, 3);
- }
- }
- iscales.val[0] = vorrq_u8(vandq_u8(sl.val[0], mf), vandq_u8(vshlq_n_u8(sh, 4), m30));
- iscales.val[1] = vorrq_u8(vandq_u8(sl.val[1], mf), vandq_u8(sh, m30));
- int32x4_t isum[nrc_y] = {};
- for (int is = 0; is < 2; ++is) {
- auto iscales16_1 = vmovl_s8(vget_low_s8(iscales.val[is]));
- auto iscales16_2 = vmovl_s8(vget_high_s8(iscales.val[is]));
- scales.val[0] = vmovl_s16(vget_low_s16(iscales16_1));
- scales.val[1] = vmovl_s16(vget_high_s16(iscales16_1));
- scales.val[2] = vmovl_s16(vget_low_s16(iscales16_2));
- scales.val[3] = vmovl_s16(vget_high_s16(iscales16_2));
- for (int ib = 0; ib < 4; ++ib) {
- auto lbits = vld1q_u8_x4(iq5[ibl].qs + 256*is + 64*ib);
- auto hbits2 = vld1q_u8(iq5[ibl].qh + 64*is + 16*ib);
- auto hbits1 = vshlq_n_u8(hbits2, 4);
- prepare_q4_k_quants(mf, lbits, qx);
- qx[0] = vorrq_u8(qx[0], vandq_u8(m10, hbits1));
- qx[1] = vorrq_u8(qx[1], vandq_u8(m10, hbits2));
- qx[2] = vorrq_u8(qx[2], vandq_u8(m10, vshrq_n_u8(hbits1, 2)));
- qx[3] = vorrq_u8(qx[3], vandq_u8(m10, vshrq_n_u8(hbits2, 2)));
- qx[4] = vorrq_u8(qx[4], vandq_u8(m10, vshrq_n_u8(hbits1, 1)));
- qx[5] = vorrq_u8(qx[5], vandq_u8(m10, vshrq_n_u8(hbits2, 1)));
- qx[6] = vorrq_u8(qx[6], vandq_u8(m10, vshrq_n_u8(hbits1, 3)));
- qx[7] = vorrq_u8(qx[7], vandq_u8(m10, vshrq_n_u8(hbits2, 3)));
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto y = vld1q_s8_x2(q8.y[iy][ibl].qs+128*is+32*ib);
- auto sumi = interleaved_dotq(qx, y);
- isum[iy] = vmlaq_s32(isum[iy], scales.val[ib], sumi);
- }
- }
- }
- for (int iy = 0; iy < nrc_y; ++iy) {
- acc[iy] = vfmaq_f32(acc[iy], vmulq_f32(d4, vdupq_n_f32(q8.scale(iy, ibl))), vcvtq_f32_s32(isum[iy]));
- }
- }
- for (int iy = 0; iy < nrc_y; ++iy) {
- info.store(ix, iy, acc[iy]);
- acc[iy] = vdupq_n_f32(0.f);
- }
- }
-}
-
-template <int nrc_y>
-void mul_mat_q6_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
- GGML_ASSERT(nrc_x%4 == 0);
- Q8<nrc_y, block_q8_K> q8(info);
- auto mf = vdupq_n_u8(0x0f);
- auto m3 = vdupq_n_u8(0x30);
- auto m32 = vdupq_n_s8(-32);
- int nbl = n / QK_K;
- int8x16_t qx[4];
- float32x4_t acc[nrc_y] = {};
- for (int ix = 0; ix < nrc_x; ix += 4) {
- const block_q6_k_r4 * iq6 = (const block_q6_k_r4 *)((const char *)vx + ix*bx);
- for (int ibl = 0; ibl < nbl; ++ibl) {
- auto d4 = vcvt_f32_f16(vld1_f16((const float16_t *)iq6[ibl].d));
- int32x4_t isum[nrc_y] = {};
- for (int is = 0; is < 2; ++is) {
- for (int ib = 0; ib < 4; ++ib) {
- auto lbits = vld1q_u8_x4(iq6[ibl].ql + 256*is + 64*ib);
- auto hbits = vld1q_u8(iq6[ibl].qh + 128*is + 32*ib);
- auto iscales = vmovl_s8(vld1_s8(iq6[ibl].scales + 32*is + 8*ib));
- auto scales = vmovl_s16(vget_low_s16(iscales));
- qx[0] = vaddq_s8(m32, vorrq_u8(vandq_u8 (lbits.val[0], mf), vandq_u8(m3, vshlq_n_u8(hbits, 4))));
- qx[1] = vaddq_s8(m32, vorrq_u8(vandq_u8 (lbits.val[2], mf), vandq_u8(m3, hbits)));
- qx[2] = vaddq_s8(m32, vorrq_u8(vshrq_n_u8(lbits.val[0], 4), vandq_u8(m3, vshlq_n_u8(hbits, 2))));
- qx[3] = vaddq_s8(m32, vorrq_u8(vshrq_n_u8(lbits.val[2], 4), vandq_u8(m3, vshrq_n_u8(hbits, 2))));
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto y = vld1q_s8(q8.y[iy][ibl].qs+128*is+32*ib);
- auto sumi = interleaved_dotq(qx, y);
- isum[iy] = vmlaq_s32(isum[iy], scales, sumi);
- }
- scales = vmovl_s16(vget_high_s16(iscales));
- hbits = vld1q_u8(iq6[ibl].qh + 128*is + 32*ib + 16);
- qx[0] = vaddq_s8(m32, vorrq_u8(vandq_u8 (lbits.val[1], mf), vandq_u8(m3, vshlq_n_u8(hbits, 4))));
- qx[1] = vaddq_s8(m32, vorrq_u8(vandq_u8 (lbits.val[3], mf), vandq_u8(m3, hbits)));
- qx[2] = vaddq_s8(m32, vorrq_u8(vshrq_n_u8(lbits.val[1], 4), vandq_u8(m3, vshlq_n_u8(hbits, 2))));
- qx[3] = vaddq_s8(m32, vorrq_u8(vshrq_n_u8(lbits.val[3], 4), vandq_u8(m3, vshrq_n_u8(hbits, 2))));
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto y = vld1q_s8(q8.y[iy][ibl].qs+128*is+32*ib+16);
- auto sumi = interleaved_dotq(qx, y);
- isum[iy] = vmlaq_s32(isum[iy], scales, sumi);
- }
- }
- }
- for (int iy = 0; iy < nrc_y; ++iy) {
- acc[iy] = vfmaq_f32(acc[iy], vmulq_f32(d4, vdupq_n_f32(q8.scale(iy, ibl))), vcvtq_f32_s32(isum[iy]));
- }
- }
- for (int iy = 0; iy < nrc_y; ++iy) {
- info.store(ix, iy, acc[iy]);
- acc[iy] = vdupq_n_f32(0.f);
- }
- }
-}
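
q6_k_r4 reassembles each 6-bit quant from a low nibble in ql and two bits from qh, recentres it by -32, and weighs it with a signed 8-bit block scale. Per value (illustrative helper):

static inline int8_t q6_k_dequant_one(uint8_t low_nibble, uint8_t high_2bits) {
    return (int8_t)(((low_nibble & 0x0f) | ((high_2bits & 0x03) << 4)) - 32);  // matches the mf/m3/m32 masks above
}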
-
-template <int nrc_y>
-void mul_mat_q8_k_r8_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
- GGML_ASSERT(nrc_x%8 == 0);
- Q8<nrc_y, block_q8_K> q8(info);
- int nbl = n / QK_K;
- float32x4_t acc[2*nrc_y] = {};
- for (int ix = 0; ix < nrc_x; ix += 8) {
- const block_q8_k_r8 * iq8 = (const block_q8_k_r8 *)((const char *)vx + ix*bx);
- for (int ibl = 0; ibl < nbl; ++ibl) {
- auto d4l = vcvt_f32_f16(vld1_f16((const float16_t *)iq8[ibl].d+0));
- auto d4h = vcvt_f32_f16(vld1_f16((const float16_t *)iq8[ibl].d+4));
- int32x4_t isum[2*nrc_y] = {};
- for (int ib = 0; ib < QK_K/16; ++ib) {
- auto q1 = vld1q_s8_x4(iq8[ibl].qs + 128*ib + 0);
- auto q2 = vld1q_s8_x4(iq8[ibl].qs + 128*ib + 64);
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto y = vld1q_s8(q8.y[iy][ibl].qs+16*ib);
- isum[2*iy+0] = vdotq_laneq_s32(isum[2*iy+0], q1.val[0], y, 0);
- isum[2*iy+1] = vdotq_laneq_s32(isum[2*iy+1], q1.val[1], y, 0);
- isum[2*iy+0] = vdotq_laneq_s32(isum[2*iy+0], q1.val[2], y, 1);
- isum[2*iy+1] = vdotq_laneq_s32(isum[2*iy+1], q1.val[3], y, 1);
- isum[2*iy+0] = vdotq_laneq_s32(isum[2*iy+0], q2.val[0], y, 2);
- isum[2*iy+1] = vdotq_laneq_s32(isum[2*iy+1], q2.val[1], y, 2);
- isum[2*iy+0] = vdotq_laneq_s32(isum[2*iy+0], q2.val[2], y, 3);
- isum[2*iy+1] = vdotq_laneq_s32(isum[2*iy+1], q2.val[3], y, 3);
- }
- }
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto d8 = vdupq_n_f32(q8.scale(iy, ibl));
- const float * bsum = (const float *)q8.y[iy][ibl].bsums;
- auto m8 = vdupq_n_f32(-128.f*bsum[0]);
- acc[2*iy+0] = vfmaq_f32(acc[2*iy+0], vmulq_f32(d4l, d8), vcvtq_f32_s32(isum[2*iy+0]));
- acc[2*iy+1] = vfmaq_f32(acc[2*iy+1], vmulq_f32(d4h, d8), vcvtq_f32_s32(isum[2*iy+1]));
- acc[2*iy+0] = vfmaq_f32(acc[2*iy+0], d4l, m8);
-                acc[2*iy+1] = vfmaq_f32(acc[2*iy+1], d4h, m8);
- }
- }
- for (int iy = 0; iy < nrc_y; ++iy) {
- info.store(ix+0, iy, acc[2*iy+0]);
- info.store(ix+4, iy, acc[2*iy+1]);
- acc[2*iy+0] = acc[2*iy+1] = vdupq_n_f32(0.f);
- }
- }
-}
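
The second pair of vfmaq_f32 calls above is a bias correction: isum is accumulated from the repacked q8_k_r8 quants, which appear to be stored with a +128 offset, and the bsums of this activation layout appear to hold d8*sum(y) as a float (hence the float cast). Under those assumptions, per block and per group of four rows:

static inline float q8_k_r8_row_ref(float d4, float d8, int isum_biased, float d8_sum_y) {
    // isum_biased = sum_i (q_i + 128) * y_i ;  d8_sum_y = d8 * sum_i y_i
    return d4 * d8 * isum_biased - d4 * 128.0f * d8_sum_y;  // == d4 * d8 * sum_i q_i * y_i
}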
-
-static void mul_mat_q8_KV_q8_KV_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
- GGML_ASSERT(n%32 == 0);
- int32x4_t acc[4] = {};
- auto dptr = (const float *)info.src1_row(0);
- const float dy = dptr[0];
- auto q8y = (const int8_t *)(dptr + 2);
- for (int ix = 0; ix < nrc_x; ++ix) {
- auto dx = (const float *)((const char *)vx + ix*bx);
- auto q8x = (const int8_t *)(dx + 2);
- for (int i = 0; i < n/64; ++i) {
- auto qx = vld1q_s8_x4(q8x + 64*i);
- for (int j = 0; j < 4; ++j) {
- acc[j] = ggml_vdotq_s32(acc[j], qx.val[j], vld1q_s8(q8y + 64*i + 16*j));
- }
- }
- if (int i = 2*(n/64); i < n/32) {
- auto qx = vld1q_s8_x2(q8x + 32*i);
- for (int j = 0; j < 2; ++j) {
- acc[j] = ggml_vdotq_s32(acc[j], qx.val[j], vld1q_s8(q8y + 32*i + 16*j));
- }
- }
- acc[0] = vaddq_s32(acc[0], acc[1]);
- acc[2] = vaddq_s32(acc[2], acc[3]);
- acc[0] = vaddq_s32(acc[0], acc[2]);
- info.store(ix, 0, dx[0]*dy*vaddvq_s32(acc[0]));
- acc[0] = acc[1] = acc[2] = acc[3] = vdupq_n_s32(0);
- }
-}
-
-template <int nrc_y>
-static void mul_mat_q8_KV_q8_KV(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
- GGML_ASSERT(nrc_x%4 == 0);
- GGML_ASSERT(n%16 == 0);
- int8x16_t qx[4];
- int32x4_t acc[nrc_y] = {};
- float dy[nrc_y];
- const int8_t * q8y[nrc_y];
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto dptr = (const float *)info.src1_row(iy);
- dy[iy] = dptr[0];
- q8y[iy] = (const int8_t *)(dptr + 2);
- }
- const int8_t * q8x[4];
- float dx[4];
- for (int ix = 0; ix < nrc_x; ix += 4) {
- for (int kx = 0; kx < 4; ++kx) {
- auto dptr = (const float *)((const char *)vx + (ix+kx)*bx);
- dx[kx] = dptr[0];
- q8x[kx] = (const int8_t *)(dptr + 2);
- }
- for (int i = 0; i < n/16; ++i) {
- for (int kx = 0; kx < 4; ++kx) qx[kx] = vld1q_s8(q8x[kx] + 16*i);
- auto row01 = vtrnq_s32(qx[0], qx[1]);
- auto row23 = vtrnq_s32(qx[2], qx[3]);
- qx[0] = vtrn1q_s64(row01.val[0], row23.val[0]);
- qx[1] = vtrn1q_s64(row01.val[1], row23.val[1]);
- qx[2] = vtrn2q_s64(row01.val[0], row23.val[0]);
- qx[3] = vtrn2q_s64(row01.val[1], row23.val[1]);
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto y = vld1q_s8(q8y[iy] + 16*i);
- acc[iy] = vdotq_laneq_s32(acc[iy], qx[0], y, 0);
- acc[iy] = vdotq_laneq_s32(acc[iy], qx[1], y, 1);
- acc[iy] = vdotq_laneq_s32(acc[iy], qx[2], y, 2);
- acc[iy] = vdotq_laneq_s32(acc[iy], qx[3], y, 3);
- }
- }
- auto scales_x = vld1q_f32(dx);
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto scale = vmulq_f32(scales_x, vdupq_n_f32(dy[iy]));
- info.store(ix, iy, vmulq_f32(scale, vcvtq_f32_s32(acc[iy])));
- acc[iy] = vdupq_n_s32(0);
- }
- }
-}
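
Both q8_KV kernels read a row as a float scale followed by the int8 quants starting 8 bytes in (dptr + 2); what occupies the second 4-byte slot is not visible in this hunk. A scalar reference for one x-row / y-row pair (illustrative name):

static inline float q8_KV_dot_ref(const void * xrow, const void * yrow, int n) {
    const float  * dx  = (const float  *)xrow;
    const float  * dy  = (const float  *)yrow;
    const int8_t * q8x = (const int8_t *)(dx + 2);
    const int8_t * q8y = (const int8_t *)(dy + 2);
    int sum = 0;
    for (int i = 0; i < n; ++i) sum += q8x[i] * q8y[i];
    return dx[0] * dy[0] * sum;
}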
-
-template <int nrc_y>
-void mul_mat_q8_KV_r8_q8_KV(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
- GGML_ASSERT(nrc_x%8 == 0);
- int32x4_t acc[2*nrc_y] = {};
- float dy[nrc_y];
- const int8_t * q8y[nrc_y];
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto dptr = (const float *)info.src1_row(iy);
- dy[iy] = dptr[0];
- q8y[iy] = (const int8_t *)(dptr + 2);
- }
- for (int ix = 0; ix < nrc_x; ix += 8) {
- const float * dptr = (const float *)((const char *)vx + ix*bx);
- auto q8x = (const int8_t *)(dptr + 8);
- for (int ib = 0; ib < n/16; ++ib) {
- auto q1 = vld1q_s8_x4(q8x + 128*ib + 0);
- auto q2 = vld1q_s8_x4(q8x + 128*ib + 64);
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto y = vld1q_s8(q8y[iy]+16*ib);
- acc[2*iy+0] = vdotq_laneq_s32(acc[2*iy+0], q1.val[0], y, 0);
- acc[2*iy+1] = vdotq_laneq_s32(acc[2*iy+1], q1.val[1], y, 0);
- acc[2*iy+0] = vdotq_laneq_s32(acc[2*iy+0], q1.val[2], y, 1);
- acc[2*iy+1] = vdotq_laneq_s32(acc[2*iy+1], q1.val[3], y, 1);
- acc[2*iy+0] = vdotq_laneq_s32(acc[2*iy+0], q2.val[0], y, 2);
- acc[2*iy+1] = vdotq_laneq_s32(acc[2*iy+1], q2.val[1], y, 2);
- acc[2*iy+0] = vdotq_laneq_s32(acc[2*iy+0], q2.val[2], y, 3);
- acc[2*iy+1] = vdotq_laneq_s32(acc[2*iy+1], q2.val[3], y, 3);
- }
- }
- auto scale1_x = vld1q_f32(dptr+0);
- auto scale2_x = vld1q_f32(dptr+4);
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto scale_y = vdupq_n_f32(dy[iy]);
- auto scale1 = vmulq_f32(scale1_x, scale_y);
- auto scale2 = vmulq_f32(scale2_x, scale_y);
- info.store(ix+0, iy, vmulq_f32(scale1, vcvtq_f32_s32(acc[2*iy+0])));
- info.store(ix+4, iy, vmulq_f32(scale2, vcvtq_f32_s32(acc[2*iy+1])));
-            acc[2*iy+0] = acc[2*iy+1] = vdupq_n_s32(0);
- }
- }
-}
-
-void mul_mat_iq4_nl_r4_q8_0_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
- GGML_ASSERT(nrc_x%4 == 0);
- Q8<1, block_q8_0_x4> q8(info);
- auto m4 = vdupq_n_u8(0xf);
- auto values = vld1q_s8(iq4k_values);
- int nb = n / QK4_NL;
- GGML_ASSERT(nb%4 == 0);
- int8x16_t qx[8];
- for (int ix = 0; ix < nrc_x; ix += 4) {
- auto acc = vdupq_n_f32(0.f);
- const block_iq4_nl_r4 * iq4 = (const block_iq4_nl_r4 *)((const char *)vx + ix*bx);
- for (int ib4 = 0; ib4 < nb/4; ++ib4) {
- auto y1 = vld1q_s8_x4(q8.y[0][ib4].qs);
- auto y2 = vld1q_s8_x4(q8.y[0][ib4].qs+64);
- for (int k = 0; k < 4; ++k) {
- auto scales = vcvt_f32_f16(vld1_f16((const float16_t *)iq4[4*ib4+k].d));
- auto d4d8 = vmulq_f32(scales, vdupq_n_f32(GGML_FP16_TO_FP32(q8.y[0][ib4].d[k])));
- auto sumi = vdupq_n_s32(0);
- const auto yval = k < 2 ? y1.val + 2*k : y2.val + 2*(k-2);
- auto bits = vld1q_u8_x4(iq4[4*ib4+k].qs);
- qx[0] = vqtbl1q_s8(values, vandq_u8(bits.val[0], m4)); // 0...3 from the 4 rows
- qx[1] = vqtbl1q_s8(values, vandq_u8(bits.val[1], m4)); // 16..19
- sumi = vdotq_laneq_s32(sumi, qx[0], yval[0], 0);
- sumi = vdotq_laneq_s32(sumi, qx[1], yval[1], 0);
- qx[2] = vqtbl1q_s8(values, vandq_u8(bits.val[2], m4)); // 4...7
- qx[3] = vqtbl1q_s8(values, vandq_u8(bits.val[3], m4)); // 20..23
- sumi = vdotq_laneq_s32(sumi, qx[2], yval[0], 1);
- sumi = vdotq_laneq_s32(sumi, qx[3], yval[1], 1);
- qx[4] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[0], 4)); // 8..11
- qx[5] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[1], 4)); // 24..27
- sumi = vdotq_laneq_s32(sumi, qx[4], yval[0], 2);
- sumi = vdotq_laneq_s32(sumi, qx[5], yval[1], 2);
- qx[6] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[2], 4)); // 12..15
- qx[7] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[3], 4)); // 28..31
- sumi = vdotq_laneq_s32(sumi, qx[6], yval[0], 3);
- sumi = vdotq_laneq_s32(sumi, qx[7], yval[1], 3);
- acc = vfmaq_f32(acc, d4d8, vcvtq_f32_s32(sumi));
- }
- }
- info.store(ix, 0, acc);
- }
-}
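
iq4_nl stores plain 4-bit indices into the non-linear table loaded above via vld1q_s8(iq4k_values), with one f16 scale per row and 32-quant block; per value the decode is simply (illustrative helper):

static inline float iq4_nl_dequant_one(uint8_t nibble, float d_row) {
    return d_row * iq4k_values[nibble & 0x0f];
}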
-
-template <typename Dequantizer, int nrc_y>
-void mul_mat_qx_r4_q8_0(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
- GGML_ASSERT(nrc_x%4 == 0);
- Q8<nrc_y, block_q8_0_x4> q8(info);
- Dequantizer deq(vx, bx);
- int nb = n / QK4_NL;
- int8x16_t qx[8];
- float d8[4*nrc_y];
- float32x4_t acc[nrc_y] = {};
- for (int ix = 0; ix < nrc_x; ix += 4) {
- deq.new_row(ix);
- for (int ib4 = 0; ib4 < nb/4; ++ib4) {
- for (int iy = 0; iy < nrc_y; ++iy) {
- vst1q_f32(d8+4*iy, vcvt_f32_f16(vld1_f16((const float16_t *)q8.y[iy][ib4].d)));
- }
- for (int k = 0; k < 4; ++k) {
- auto scales = deq.prepare(4*ib4+k, qx);
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto y = vld1q_s8_x2(q8.y[iy][ib4].qs+32*k);
- auto sumi = interleaved_dotq(qx, y);
- auto d4d8 = vmulq_f32(scales, vdupq_n_f32(d8[4*iy+k]));
- acc[iy] = vfmaq_f32(acc[iy], d4d8, vcvtq_f32_s32(sumi));
- }
- }
- }
- for (int ib = 4*(nb/4); ib < nb; ++ib) {
- auto scales = deq.prepare(ib, qx);
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto qy = (const block_q8_0 *)q8.y[iy];
- auto y = vld1q_s8_x2(qy[ib].qs);
- auto sumi = interleaved_dotq(qx, y);
- auto d4d8 = vmulq_f32(scales, vdupq_n_f32(GGML_FP16_TO_FP32(qy[ib].d)));
- acc[iy] = vfmaq_f32(acc[iy], d4d8, vcvtq_f32_s32(sumi));
- }
- }
- for (int iy = 0; iy < nrc_y; ++iy) {
- info.store(ix, iy, deq.result(acc[iy]));
- acc[iy] = vdupq_n_f32(0.f);
- }
- }
-}
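
mul_mat_qx_r4_q8_0 (and the r8 variant below) only need three operations from their Dequantizer parameter: new_row() to position on a group of four interleaved rows, prepare() to fill qx[8] and return the four row scales, and result() for an optional final fix-up. A do-nothing skeleton that satisfies the r4 contract, with the block decoding deliberately left out (names are illustrative):

struct DummyR4Dequantizer {
    DummyR4Dequantizer(const void * vx, size_t bx) : cx((const char *)vx), bx(bx) {}
    inline void new_row(int ix) { row = cx + ix*bx; }                 // the 4 interleaved rows start here
    inline float32x4_t prepare(int /*ib*/, int8x16_t * qx) const {    // fill qx[0..7], return the 4 row scales
        for (int j = 0; j < 8; ++j) qx[j] = vdupq_n_s8(0);
        return vdupq_n_f32(1.f);
    }
    inline float32x4_t result(float32x4_t acc) const { return acc; }  // e.g. the 1/16 factor in Q4_0_R4 below
    const char * cx;
    const size_t bx;
    const char * row = nullptr;
};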
-
-template <typename Dequantizer, int nrc_y>
-void mul_mat_qx_r8_q8_0(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
- GGML_ASSERT(nrc_x%8 == 0);
- Q8<nrc_y, block_q8_0_x4> q8(info);
- Dequantizer deq(vx, bx);
- int nb = n / QK4_NL;
- int8x16_t qx[16];
- float d8[4*nrc_y];
- float32x4_t acc[2*nrc_y] = {};
- for (int ix = 0; ix < nrc_x; ix += 8) {
- deq.new_row(ix);
- for (int ib4 = 0; ib4 < nb/4; ++ib4) {
- for (int iy = 0; iy < nrc_y; ++iy) {
- vst1q_f32(d8+4*iy, vcvt_f32_f16(vld1_f16((const float16_t *)q8.y[iy][ib4].d)));
- }
- for (int k = 0; k < 4; ++k) {
- auto scales = deq.prepare(ib4, k, qx);
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto y = vld1q_s8_x2(q8.y[iy][ib4].qs+32*k);
- auto sumi1 = interleaved_dotq(qx+0, y);
- auto sumi2 = interleaved_dotq(qx+8, y);
- auto dy = vdupq_n_f32(d8[4*iy+k]);
- acc[2*iy+0] = vfmaq_f32(acc[2*iy+0], vmulq_f32(scales.val[0], dy), vcvtq_f32_s32(sumi1));
- acc[2*iy+1] = vfmaq_f32(acc[2*iy+1], vmulq_f32(scales.val[1], dy), vcvtq_f32_s32(sumi2));
- }
- }
- }
- for (int ib = 4*(nb/4); ib < nb; ++ib) {
- auto scales = deq.prepare(ib, 0, qx);
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto qy = (const block_q8_0 *)q8.y[iy];
- auto y = vld1q_s8_x2(qy[ib].qs);
- auto sumi1 = interleaved_dotq(qx+0, y);
- auto sumi2 = interleaved_dotq(qx+8, y);
- auto dy = vdupq_n_f32(GGML_FP16_TO_FP32(qy[ib].d));
- acc[2*iy+0] = vfmaq_f32(acc[2*iy+0], vmulq_f32(scales.val[0], dy), vcvtq_f32_s32(sumi1));
- acc[2*iy+1] = vfmaq_f32(acc[2*iy+1], vmulq_f32(scales.val[1], dy), vcvtq_f32_s32(sumi2));
- }
- }
- for (int iy = 0; iy < nrc_y; ++iy) {
- info.store(ix+0, iy, deq.result(acc[2*iy+0]));
- info.store(ix+4, iy, deq.result(acc[2*iy+1]));
- acc[2*iy] = acc[2*iy+1] = vdupq_n_f32(0.f);
- }
- }
-}
-
-struct IQ4_NL_R4_Dequantizer {
- IQ4_NL_R4_Dequantizer(const void * vx, size_t bx) : cx((const char *)vx), bx(bx), values(vld1q_s8(iq4k_values)) {}
- inline void new_row(int ix) { iq4 = (const block_iq4_nl_r4 *)(cx + ix*bx); }
- inline float32x4_t prepare(int ib, int8x16_t * qx) const {
- auto scales = vcvt_f32_f16(vld1_f16((const float16_t *)iq4[ib].d));
- auto bits = vld1q_u8_x4(iq4[ib].qs);
- prepare_iq4_nl_quants(values, m4, bits, qx);
- return scales;
- }
- inline float32x4_t result(float32x4_t acc) const {
- return acc;
- }
-
- const char * cx;
- const size_t bx;
- const block_iq4_nl_r4 * iq4;
- const uint8x16_t m4 = vdupq_n_u8(0x0f);
- const int8x16_t values;
-};
-
-struct Q4_0_R4_Dequantizer {
- Q4_0_R4_Dequantizer(const void * vx, size_t bx) : cx((const char *)vx), bx(bx) {}
- inline void new_row(int ix) { iq4 = (const block_iq4_nl_r4 *)(cx + ix*bx); }
- inline float32x4_t prepare(int ib4, int k, int8x16_t * qx) const {
- auto scales = vcvt_f32_f16(vld1_f16((const float16_t *)iq4[4*ib4+k].d));
- auto bits = vld1q_u8_x4(iq4[4*ib4+k].qs);
- for (int j = 0; j < 4; ++j) bits.val[j] = veorq_u8(m88, bits.val[j]);
- qx[0] = vshlq_n_u8(bits.val[0], 4); // 0...3 from the 4 rows
- qx[1] = vshlq_n_u8(bits.val[1], 4); // 16..19
- qx[2] = vshlq_n_u8(bits.val[2], 4); // 4...7
- qx[3] = vshlq_n_u8(bits.val[3], 4); // 20..23
- qx[4] = vandq_u8(bits.val[0], m4); // 8..11
- qx[5] = vandq_u8(bits.val[1], m4); // 24..27
- qx[6] = vandq_u8(bits.val[2], m4); // 12..15
- qx[7] = vandq_u8(bits.val[3], m4); // 28..31
- return scales;
- }
- inline float32x4_t result(float32x4_t acc) const {
- return vmulq_f32(norm, acc);
- }
-
- const char * cx;
- const size_t bx;
- const block_iq4_nl_r4 * iq4;
- const uint8x16_t m4 = vdupq_n_u8(0xf0);
- const uint8x16_t m88 = vdupq_n_u8(0x88);
- const float32x4_t norm = vdupq_n_f32(1.f/16);
-};
-
-struct Q4_0_R8_Dequantizer {
- Q4_0_R8_Dequantizer(const void * vx, size_t bx) : cx((const char *)vx), bx(bx) {}
- inline void new_row(int ix) { iq4 = (const block_iq4_nl_r8 *)(cx + ix*bx); }
- inline float32x4x2_t prepare(int ib4, int k, int8x16_t * qx) const {
- auto scales16 = vld1q_f16((const float16_t *)iq4[4*ib4+k].d);
- float32x4x2_t scales = { vcvt_f32_f16(vget_low_f16(scales16)), vcvt_f32_f16(vget_high_f16(scales16)) };
- for (int j = 0; j < 4; ++j) {
- auto bits = vld1q_u8_x2(iq4[4*ib4+k].qs + 32*j);
- bits.val[0] = veorq_u8(m88, bits.val[0]);
- bits.val[1] = veorq_u8(m88, bits.val[1]);
- qx[2*j+0] = vshlq_n_u8(bits.val[0], 4);
- qx[2*j+1] = vandq_u8(bits.val[0], m4);
- qx[2*j+8] = vshlq_n_u8(bits.val[1], 4);
- qx[2*j+9] = vandq_u8(bits.val[1], m4);
- }
- return scales;
- }
- inline float32x4_t result(float32x4_t acc) const {
- return vmulq_f32(norm, acc);
- }
-
- const char * cx;
- const size_t bx;
- const block_iq4_nl_r8 * iq4;
- const uint8x16_t m4 = vdupq_n_u8(0xf0);
- const uint8x16_t m88 = vdupq_n_u8(0x88);
- const float32x4_t norm = vdupq_n_f32(1.f/16);
-};
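
Both Q4_0 dequantizers above rely on the same trick: XOR-ing the packed bytes with 0x88 flips the sign bit of each nibble, so a nibble moved to the high half of a byte reads, as int8, exactly 16 times the centred q4_0 value q - 8; the factor of 16 is then removed by norm = 1/16 in result(). In scalar form:

static inline int8_t q4_0_times16_ref(uint8_t q /*4-bit quant, 0..15*/) {
    return (int8_t)((q ^ 0x8) << 4);   // == 16 * (q - 8)
}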
-
-struct Q5_0_R4_Dequantizer {
- Q5_0_R4_Dequantizer(const void * vx, size_t bx) : cx((const char *)vx), bx(bx) {}
- inline void new_row(int ix) { iq5 = (const block_q5_0_r4 *)(cx + ix*bx); }
- inline float32x4_t prepare(int ib, int8x16_t * qx) const {
- auto scales = vcvt_f32_f16(vld1_f16((const float16_t *)iq5[ib].d));
- auto lbits = vld1q_u8_x4(iq5[ib].qs);
- auto hbits = vld1q_u8(iq5[ib].qh);
- qx[0] = vaddq_s8(vandq_u8(lbits.val[0], m4) | vandq_u8(vshlq_n_u8(hbits, 4), m5), m16); // 0...3
- qx[1] = vaddq_s8(vandq_u8(lbits.val[1], m4) | vandq_u8(vshlq_n_u8(hbits, 3), m5), m16); // 16..19
- qx[2] = vaddq_s8(vandq_u8(lbits.val[2], m4) | vandq_u8(vshlq_n_u8(hbits, 2), m5), m16); // 4...7
- qx[3] = vaddq_s8(vandq_u8(lbits.val[3], m4) | vandq_u8(vshlq_n_u8(hbits, 1), m5), m16); // 20..23
- qx[4] = vaddq_s8(vshrq_n_u8(lbits.val[0], 4)| vandq_u8(hbits, m5), m16); // 8..11
- qx[5] = vaddq_s8(vshrq_n_u8(lbits.val[1], 4)| vandq_u8(vshrq_n_u8(hbits, 1), m5), m16); // 24..27
- qx[6] = vaddq_s8(vshrq_n_u8(lbits.val[2], 4)| vandq_u8(vshrq_n_u8(hbits, 2), m5), m16); // 12..15
- qx[7] = vaddq_s8(vshrq_n_u8(lbits.val[3], 4)| vandq_u8(vshrq_n_u8(hbits, 3), m5), m16); // 28..31
- return scales;
- }
- inline float32x4_t result(float32x4_t acc) const {
- return acc;
- }
-
- const char * cx;
- const size_t bx;
- const block_q5_0_r4 * iq5;
- const uint8x16_t m4 = vdupq_n_u8(0x0f);
- const uint8x16_t m5 = vdupq_n_u8(0x10);
- const int8x16_t m16 = vdupq_n_s8(-16);
-};
-
-struct Q6_0_R4_Dequantizer {
- Q6_0_R4_Dequantizer(const void * vx, size_t bx) : cx((const char *)vx), bx(bx) {}
- inline void new_row(int ix) { iq6 = (const block_q6_0_r4 *)(cx + ix*bx); }
- inline float32x4_t prepare(int ib, int8x16_t * qx) const {
- auto scales = vcvt_f32_f16(vld1_f16((const float16_t *)iq6[ib].d));
- auto lbits = vld1q_u8_x4(iq6[ib].qs);
- auto hbits = vld1q_u8_x2(iq6[ib].qh);
- qx[0] = vaddq_s8(vandq_u8(lbits.val[0], m4) | vandq_u8(vshlq_n_u8(hbits.val[0], 4), m6), m32); // 0...3
- qx[1] = vaddq_s8(vandq_u8(lbits.val[1], m4) | vandq_u8(vshlq_n_u8(hbits.val[1], 4), m6), m32); // 16..19
- qx[2] = vaddq_s8(vandq_u8(lbits.val[2], m4) | vandq_u8(vshlq_n_u8(hbits.val[0], 2), m6), m32); // 4...7
- qx[3] = vaddq_s8(vandq_u8(lbits.val[3], m4) | vandq_u8(vshlq_n_u8(hbits.val[1], 2), m6), m32); // 20..23
- qx[4] = vaddq_s8(vshrq_n_u8(lbits.val[0], 4)| vandq_u8(hbits.val[0], m6), m32); // 8..11
- qx[5] = vaddq_s8(vshrq_n_u8(lbits.val[1], 4)| vandq_u8(hbits.val[1], m6), m32); // 24..27
- qx[6] = vaddq_s8(vshrq_n_u8(lbits.val[2], 4)| vandq_u8(vshrq_n_u8(hbits.val[0], 2), m6), m32); // 12..15
- qx[7] = vaddq_s8(vshrq_n_u8(lbits.val[3], 4)| vandq_u8(vshrq_n_u8(hbits.val[1], 2), m6), m32); // 28..31
- return scales;
- }
- inline float32x4_t result(float32x4_t acc) const {
- return acc;
- }
-
- const char * cx;
- const size_t bx;
- const block_q6_0_r4 * iq6;
- const uint8x16_t m4 = vdupq_n_u8(0x0f);
- const uint8x16_t m6 = vdupq_n_u8(0x30);
- const int8x16_t m32 = vdupq_n_s8(-32);
-};
-
-inline void qx_0_q8_0_dot(const int8x16_t * qx, const int8_t * qy, int32x4_t& sumi1, int32x4_t& sumi2) {
- auto y = vld1q_s8_x2(qy);
- sumi1 = sumi2 = vdupq_n_s32(0);
- sumi1 = vdotq_laneq_s32(sumi1, qx[0], y.val[0], 0);
- sumi2 = vdotq_laneq_s32(sumi2, qx[1], y.val[0], 0);
- sumi1 = vdotq_laneq_s32(sumi1, qx[2], y.val[0], 1);
- sumi2 = vdotq_laneq_s32(sumi2, qx[3], y.val[0], 1);
- sumi1 = vdotq_laneq_s32(sumi1, qx[4], y.val[0], 2);
- sumi2 = vdotq_laneq_s32(sumi2, qx[5], y.val[0], 2);
- sumi1 = vdotq_laneq_s32(sumi1, qx[6], y.val[0], 3);
- sumi2 = vdotq_laneq_s32(sumi2, qx[7], y.val[0], 3);
- sumi1 = vdotq_laneq_s32(sumi1, qx[8+0], y.val[1], 0);
- sumi2 = vdotq_laneq_s32(sumi2, qx[8+1], y.val[1], 0);
- sumi1 = vdotq_laneq_s32(sumi1, qx[8+2], y.val[1], 1);
- sumi2 = vdotq_laneq_s32(sumi2, qx[8+3], y.val[1], 1);
- sumi1 = vdotq_laneq_s32(sumi1, qx[8+4], y.val[1], 2);
- sumi2 = vdotq_laneq_s32(sumi2, qx[8+5], y.val[1], 2);
- sumi1 = vdotq_laneq_s32(sumi1, qx[8+6], y.val[1], 3);
- sumi2 = vdotq_laneq_s32(sumi2, qx[8+7], y.val[1], 3);
-}
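
Every vdotq_laneq_s32 step in these kernels dots one group of four consecutive quants from each of the four interleaved rows against the four activation bytes selected by the y lane (the index comments in the dequantizers above say which group each qx register holds). One such step, in scalar terms (illustrative helper):

static inline void vdot_lane_ref(int32_t sumi[4], const int8_t qx_k[16], const int8_t y[16], int lane) {
    for (int row = 0; row < 4; ++row)
        for (int j = 0; j < 4; ++j)
            sumi[row] += qx_k[4*row + j] * y[4*lane + j];
}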
-
-template <int nrc_y>
-void mul_mat_q8_0_r8_q8_0(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
- GGML_ASSERT(nrc_x%8 == 0);
- Q8<nrc_y, block_q8_0_x4> q8(info);
- int nb = n / QK8_0;
- float32x4_t acc[2*nrc_y] = {};
- int8x16_t qx[16];
- float d8[4*nrc_y];
- for (int ix = 0; ix < nrc_x; ix += 8) {
- const block_q8_0_r8 * iq8 = (const block_q8_0_r8 *)((const char *)vx + ix*bx);
- for (int ib4 = 0; ib4 < nb/4; ++ib4) {
- for (int iy = 0; iy < nrc_y; ++iy) {
- vst1q_f32(d8+4*iy, vcvt_f32_f16(vld1_f16((const float16_t *)q8.y[iy][ib4].d)));
- }
- for (int k = 0; k < 4; ++k) {
- auto scales16 = vld1q_f16((const float16_t *)iq8[4*ib4+k].d);
- auto scales1 = vcvt_f32_f16(vget_low_f16 (scales16));
- auto scales2 = vcvt_f32_f16(vget_high_f16(scales16));
- for (int j = 0; j < 16; ++j) qx[j] = vld1q_s8(iq8[4*ib4+k].qs + 16*j);
- int32x4_t sumi1, sumi2;
- for (int iy = 0; iy < nrc_y; ++iy) {
- qx_0_q8_0_dot(qx, q8.y[iy][ib4].qs+32*k, sumi1, sumi2);
- auto dy = vdupq_n_f32(d8[4*iy+k]);
- acc[2*iy+0] = vfmaq_f32(acc[2*iy+0], vmulq_f32(scales1, dy), vcvtq_f32_s32(sumi1));
- acc[2*iy+1] = vfmaq_f32(acc[2*iy+1], vmulq_f32(scales2, dy), vcvtq_f32_s32(sumi2));
- }
- }
- }
- for (int ib = 4*(nb/4); ib < nb; ++ib) {
- auto scales16 = vld1q_f16((const float16_t *)iq8[ib].d);
- auto scales1 = vcvt_f32_f16(vget_low_f16 (scales16));
- auto scales2 = vcvt_f32_f16(vget_high_f16(scales16));
- for (int j = 0; j < 16; ++j) qx[j] = vld1q_s8(iq8[ib].qs + 16*j);
- int32x4_t sumi1, sumi2;
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto qy = (const block_q8_0 *)q8.y[iy];
- qx_0_q8_0_dot(qx, qy[ib].qs, sumi1, sumi2);
- auto dy = vdupq_n_f32(GGML_FP16_TO_FP32(qy[ib].d));
- acc[2*iy+0] = vfmaq_f32(acc[2*iy+0], vmulq_f32(scales1, dy), vcvtq_f32_s32(sumi1));
- acc[2*iy+1] = vfmaq_f32(acc[2*iy+1], vmulq_f32(scales2, dy), vcvtq_f32_s32(sumi2));
- }
- }
- for (int iy = 0; iy < nrc_y; ++iy) {
- info.store(ix+0, iy, acc[2*iy+0]);
- info.store(ix+4, iy, acc[2*iy+1]);
- acc[2*iy] = acc[2*iy+1] = vdupq_n_f32(0.f);
- }
- }
-}
-
-#define SET_MUL_MAT_FUNCTIONS_T(m, func, Dequantizer) \
- m.funcs[0] = func<Dequantizer, 1>;\
- m.funcs[1] = func<Dequantizer, 2>;\
- m.funcs[2] = func<Dequantizer, 3>;\
- m.funcs[3] = func<Dequantizer, 4>;\
- m.funcs[4] = func<Dequantizer, 5>;\
- m.funcs[5] = func<Dequantizer, 6>;\
- m.funcs[6] = func<Dequantizer, 7>;\
- m.funcs[7] = func<Dequantizer, 8>;\
-
-#define SET_MUL_MAT_FUNCTIONS(m, func) \
- m.funcs[0] = func<1>;\
- m.funcs[1] = func<2>;\
- m.funcs[2] = func<3>;\
- m.funcs[3] = func<4>;\
- m.funcs[4] = func<5>;\
- m.funcs[5] = func<6>;\
- m.funcs[6] = func<7>;\
- m.funcs[7] = func<8>;\
-
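
funcs[i] is indexed by the number of right-hand-side (activation) columns a kernel call handles, so these macros simply instantiate a kernel template for nrc_y = 1..8. For example, with one of the kernels above:

// SET_MUL_MAT_FUNCTIONS(m, mul_mat_q6_k_r4_q8_k) expands to
//   m.funcs[0] = mul_mat_q6_k_r4_q8_k<1>;
//   m.funcs[1] = mul_mat_q6_k_r4_q8_k<2>;
//   ...
//   m.funcs[7] = mul_mat_q6_k_r4_q8_k<8>;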
-template <typename Dequantizer> void MulMat::set_functions(MulMat& m) {
- if constexpr (std::is_same_v<Dequantizer, DequantizerQ40> || std::is_same_v<Dequantizer, DequantizerQ50> ||
- std::is_same_v<Dequantizer, DequantizerQ80> || std::is_same_v<Dequantizer, DequantizerIQ4NL> ||
- std::is_same_v<Dequantizer, DequantizerQ60>) {
- SET_MUL_MAT_FUNCTIONS_T(m, mul_mat_qX_0_q8_0, Dequantizer);
- }
- else if constexpr (std::is_same_v<Dequantizer, DequantizerQ41> || std::is_same_v<Dequantizer, DequantizerQ51>) {
- SET_MUL_MAT_FUNCTIONS_T(m, mul_mat_qX_1_q8_1, Dequantizer);
- }
- else {
- SET_MUL_MAT_FUNCTIONS_T(m, mul_mat_qX_K_q8_K_T, Dequantizer);
- }
-}
-
bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) {
- if (typeA == GGML_TYPE_F16 && typeB == GGML_TYPE_F16) {
- if (ne00%4) return false;
- for (auto& f : m.funcs) f = nullptr;
- m.funcs[0] = mul_mat_f16_f16_1;
- m.funcs[1] = mul_mat_f16_f16_T<2>;
- m.funcs[2] = mul_mat_f16_f16_T<3>;
- m.funcs[3] = mul_mat_f16_f16_T<4>;
- m.funcs[4] = mul_mat_f16_f16_T<5>;
- return true;
- }
-
- if (typeA == GGML_TYPE_BF16 && typeB == GGML_TYPE_F32) {
- if (ne00%4) return false;
- for (auto& f : m.funcs) f = nullptr;
- m.funcs[0] = mul_mat_Qx_Qy_T<QF32<1>, QBF16>;
- m.funcs[1] = mul_mat_Qx_Qy_T<QF32<2>, QBF16>;
- m.funcs[2] = mul_mat_Qx_Qy_T<QF32<3>, QBF16>;
- m.funcs[3] = mul_mat_Qx_Qy_T<QF32<4>, QBF16>;
- m.funcs[4] = mul_mat_Qx_Qy_T<QF32<5>, QBF16>;
- return true;
- }
-
- auto expected_Btype = GGML_TYPE_Q8_K;
-
switch (typeA) {
+ case GGML_TYPE_F16:
+ case GGML_TYPE_BF16:
+ case GGML_TYPE_F32:
+ return iqk_set_kernels_float(ne00, typeA, typeB, m.funcs);
case GGML_TYPE_Q2_K:
- MulMat::set_functions<DequantizerQ2K>(m);
- break;
case GGML_TYPE_Q3_K:
- MulMat::set_functions<DequantizerQ3K>(m);
- break;
case GGML_TYPE_Q4_K:
- MulMat::set_functions<DequantizerQ4K>(m);
- break;
case GGML_TYPE_Q5_K:
- MulMat::set_functions<DequantizerQ5K>(m);
- break;
case GGML_TYPE_Q6_K:
- MulMat::set_functions<DequantizerQ6K>(m);
- break;
case GGML_TYPE_IQ4_XS:
- MulMat::set_functions<DequantizerIQ4XS>(m);
- break;
- case GGML_TYPE_IQ4_KS:
- MulMat::set_functions<DequantizerIQ4KS>(m);
- break;
- case GGML_TYPE_IQ4_KSS:
- MulMat::set_functions<DequantizerIQ4KSS>(m);
- break;
+ case GGML_TYPE_Q2_K_R4:
+ case GGML_TYPE_Q3_K_R4:
+ case GGML_TYPE_Q4_K_R4:
+ case GGML_TYPE_Q5_K_R4:
+ case GGML_TYPE_Q6_K_R4:
+ case GGML_TYPE_IQ4_XS_R8:
+ case GGML_TYPE_Q8_K_R8:
+ case GGML_TYPE_Q8_KV:
+ case GGML_TYPE_Q8_KV_R8:
+ return iqk_set_kernels_kquants(ne00, typeA, typeB, m.funcs, m.func16);
case GGML_TYPE_IQ2_KS:
- MulMat::set_functions<DequantizerIQ2KS>(m);
- break;
+ case GGML_TYPE_IQ2_K:
+ case GGML_TYPE_IQ3_K:
+ case GGML_TYPE_IQ4_KSS:
+ case GGML_TYPE_IQ4_KS:
case GGML_TYPE_IQ4_K:
- MulMat::set_functions<DequantizerIQ4K>(m);
- break;
- case GGML_TYPE_IQ5_K:
- MulMat::set_functions<DequantizerIQ5K>(m);
- break;
case GGML_TYPE_IQ5_KS:
- MulMat::set_functions<DequantizerIQ5KS>(m);
- break;
+ case GGML_TYPE_IQ5_K:
case GGML_TYPE_IQ6_K:
- MulMat::set_functions<DequantizerIQ6K>(m);
- break;
- case GGML_TYPE_IQ2_K:
- MulMat::set_functions<DequantizerIQ2K>(m);
- break;
- case GGML_TYPE_IQ3_K:
- MulMat::set_functions<DequantizerIQ3K>(m);
- break;
+ case GGML_TYPE_IQ2_K_R4:
+ case GGML_TYPE_IQ3_K_R4:
+ case GGML_TYPE_IQ4_K_R4:
+ case GGML_TYPE_IQ5_K_R4:
+ case GGML_TYPE_IQ4_KS_R4:
+ case GGML_TYPE_IQ5_KS_R4:
+ return iqk_set_kernels_iqk_quants(ne00, typeA, typeB, m.funcs, m.func16);
case GGML_TYPE_IQ2_XXS:
- MulMat::set_functions<DequantizerIQ2XXS>(m);
- break;
case GGML_TYPE_IQ2_XS:
- MulMat::set_functions<DequantizerIQ2XS>(m);
- break;
case GGML_TYPE_IQ2_S:
- MulMat::set_functions<DequantizerIQ2S>(m);
- break;
case GGML_TYPE_IQ3_XXS:
- MulMat::set_functions<DequantizerIQ3XXS>(m);
- break;
case GGML_TYPE_IQ3_S:
- MulMat::set_functions<DequantizerIQ3S>(m);
- break;
- case GGML_TYPE_IQ1_BN:
- SET_MUL_MAT_FUNCTIONS(m, mul_mat_iq1bn_q8_K64);
- expected_Btype = GGML_TYPE_Q8_K64;
- break;
- case GGML_TYPE_IQ2_BN:
- SET_MUL_MAT_FUNCTIONS(m, mul_mat_iq2bn_q8_K64);
- expected_Btype = GGML_TYPE_Q8_K64;
- break;
- case GGML_TYPE_IQ2_BN_R4:
- m.funcs[0] = mul_mat_iq2_bn_r4_q8_k16<1>;
- m.funcs[1] = mul_mat_iq2_bn_r4_q8_k16<2>;
- m.funcs[2] = mul_mat_iq2_bn_r4_q8_k16<3>;
- m.funcs[3] = mul_mat_iq2_bn_r4_q8_k16<4>;
- m.funcs[4] = mul_mat_iq2_bn_r4_q8_k16<5>;
- //m.funcs[5] = mul_mat_iq2_bn_r4_q8_k16<6>;
- //m.funcs[6] = mul_mat_iq2_bn_r4_q8_k16<7>;
- //m.funcs[7] = mul_mat_iq2_bn_r4_q8_k16<8>;
- expected_Btype = GGML_TYPE_Q8_K16;
- break;
+ case GGML_TYPE_IQ2_XXS_R4:
+ case GGML_TYPE_IQ2_XS_R4:
+ case GGML_TYPE_IQ2_S_R4:
+ case GGML_TYPE_IQ3_XXS_R4:
+ case GGML_TYPE_IQ3_S_R4:
+ return iqk_set_kernels_iquants(ne00, typeA, typeB, m.funcs, m.func16);
case GGML_TYPE_Q4_0:
- MulMat::set_functions<DequantizerQ40>(m);
- expected_Btype = GGML_TYPE_Q8_0_X4;
- break;
case GGML_TYPE_Q4_1:
- MulMat::set_functions<DequantizerQ41>(m);
- expected_Btype = GGML_TYPE_Q8_1_X4;
- break;
case GGML_TYPE_Q5_0:
- MulMat::set_functions<DequantizerQ50>(m);
- expected_Btype = GGML_TYPE_Q8_0_X4;
- break;
case GGML_TYPE_Q5_1:
- MulMat::set_functions<DequantizerQ51>(m);
- expected_Btype = GGML_TYPE_Q8_1_X4;
- break;
case GGML_TYPE_Q6_0:
- MulMat::set_functions<DequantizerQ60>(m);
- expected_Btype = GGML_TYPE_Q8_0_X4;
- break;
case GGML_TYPE_Q8_0:
- MulMat::set_functions<DequantizerQ80>(m);
- expected_Btype = GGML_TYPE_Q8_0_X4;
- break;
case GGML_TYPE_IQ4_NL:
- MulMat::set_functions<DequantizerIQ4NL>(m);
- expected_Btype = GGML_TYPE_Q8_0_X4;
- break;
- case GGML_TYPE_IQ4_NL_R4:
- SET_MUL_MAT_FUNCTIONS_T(m, mul_mat_qx_r4_q8_0, IQ4_NL_R4_Dequantizer);
- expected_Btype = GGML_TYPE_Q8_0_X4;
- break;
- case GGML_TYPE_IQ4_XS_R8:
- SET_MUL_MAT_FUNCTIONS(m, mul_mat_iq4_xs_r8_q8_k);
- expected_Btype = GGML_TYPE_Q8_K32;
- break;
- case GGML_TYPE_IQ4_KS_R4:
- SET_MUL_MAT_FUNCTIONS(m, mul_mat_iq4_ks_r4_q8_k);
- expected_Btype = GGML_TYPE_Q8_K;
- break;
- case GGML_TYPE_IQ2_XXS_R4:
- SET_MUL_MAT_FUNCTIONS(m, mul_mat_iq2_xxs_r4_q8_k);
- m.func16 = mul_mat_iq2_xxs_r4_q8_k<16>;
- expected_Btype = GGML_TYPE_Q8_K;
- break;
- case GGML_TYPE_IQ2_XS_R4:
- SET_MUL_MAT_FUNCTIONS(m, mul_mat_iq2_xs_r4_q8_k);
- m.func16 = mul_mat_iq2_xs_r4_q8_k<16>;
- expected_Btype = GGML_TYPE_Q8_K;
- break;
- case GGML_TYPE_IQ2_S_R4:
- SET_MUL_MAT_FUNCTIONS(m, mul_mat_iq2_s_r4_q8_k);
- m.func16 = mul_mat_iq2_s_r4_q8_k<16>;
- expected_Btype = GGML_TYPE_Q8_K;
- break;
- case GGML_TYPE_IQ1_S:
- SET_MUL_MAT_FUNCTIONS(m, mul_mat_iq1_s_q8_K);
- m.func16 = mul_mat_iq1_s_q8_K<16>;
- expected_Btype = GGML_TYPE_Q8_K;
- break;
- case GGML_TYPE_IQ1_S_R4:
- SET_MUL_MAT_FUNCTIONS(m, mul_mat_iq1_s_r4_q8_1);
- m.funcs[0] = mul_mat_iq1_s_r4_q8_1_1;
- m.func16 = mul_mat_iq1_s_r4_q8_1<16>;
- expected_Btype = GGML_TYPE_Q8_K128;
- break;
- case GGML_TYPE_IQ1_M_R4:
- SET_MUL_MAT_FUNCTIONS(m, mul_mat_iq1_m_r4_q8_0);
- m.func16 = mul_mat_iq1_m_r4_q8_0<16>;
- expected_Btype = GGML_TYPE_Q8_K128;
- break;
- case GGML_TYPE_IQ3_XXS_R4:
- SET_MUL_MAT_FUNCTIONS(m, mul_mat_iq3_xxs_r4_q8_k);
- m.func16 = mul_mat_iq3_xxs_r4_q8_k<16>;
- expected_Btype = GGML_TYPE_Q8_K;
- break;
- case GGML_TYPE_IQ3_S_R4:
- SET_MUL_MAT_FUNCTIONS(m, mul_mat_iq3_s_r4_q8_k);
- m.func16 = mul_mat_iq3_s_r4_q8_k<16>;
- expected_Btype = GGML_TYPE_Q8_K;
- break;
- case GGML_TYPE_Q2_K_R4:
- SET_MUL_MAT_FUNCTIONS(m, mul_mat_q2_k_r4_q8_k);
- expected_Btype = GGML_TYPE_Q8_K;
- break;
- case GGML_TYPE_Q3_K_R4:
- SET_MUL_MAT_FUNCTIONS(m, mul_mat_q3_k_r4_q8_k);
- expected_Btype = GGML_TYPE_Q8_K;
- break;
- case GGML_TYPE_Q4_K_R4:
- SET_MUL_MAT_FUNCTIONS(m, mul_mat_q4_k_r4_q8_k);
- expected_Btype = GGML_TYPE_Q8_K32;
- break;
- case GGML_TYPE_Q5_K_R4:
- SET_MUL_MAT_FUNCTIONS(m, mul_mat_q5_k_r4_q8_k);
- expected_Btype = GGML_TYPE_Q8_K32;
- break;
- case GGML_TYPE_Q6_K_R4:
- SET_MUL_MAT_FUNCTIONS(m, mul_mat_q6_k_r4_q8_k);
- expected_Btype = GGML_TYPE_Q8_K;
- break;
- case GGML_TYPE_Q8_K_R8:
- SET_MUL_MAT_FUNCTIONS(m, mul_mat_q8_k_r8_q8_k);
- expected_Btype = GGML_TYPE_Q8_KR8;
- break;
- case GGML_TYPE_Q8_KV:
- SET_MUL_MAT_FUNCTIONS(m, mul_mat_q8_KV_q8_KV);
- m.funcs[0] = mul_mat_q8_KV_q8_KV_1;
- m.func16 = mul_mat_q8_KV_q8_KV<16>;
- expected_Btype = GGML_TYPE_Q8_KV;
- break;
- case GGML_TYPE_Q8_KV_R8:
- SET_MUL_MAT_FUNCTIONS(m, mul_mat_q8_KV_r8_q8_KV);
- expected_Btype = GGML_TYPE_Q8_KV;
- break;
- case GGML_TYPE_IQ2_K_R4:
- SET_MUL_MAT_FUNCTIONS(m, mul_mat_iq2_k_r4_q8_k);
- expected_Btype = GGML_TYPE_Q8_K;
- break;
- case GGML_TYPE_IQ3_K_R4:
- SET_MUL_MAT_FUNCTIONS(m, mul_mat_iq3_k_r4_q8_k);
- expected_Btype = GGML_TYPE_Q8_K;
- break;
- case GGML_TYPE_IQ4_K_R4:
- SET_MUL_MAT_FUNCTIONS(m, mul_mat_iq4_k_r4_q8_k);
- expected_Btype = GGML_TYPE_Q8_K;
- break;
- case GGML_TYPE_IQ5_K_R4:
- SET_MUL_MAT_FUNCTIONS(m, mul_mat_iq5_k_r4_q8_k);
- expected_Btype = GGML_TYPE_Q8_K;
- break;
- case GGML_TYPE_IQ5_KS_R4:
- SET_MUL_MAT_FUNCTIONS(m, mul_mat_iq5_ks_r4_q8_k_neon);
- expected_Btype = GGML_TYPE_Q8_K;
- break;
case GGML_TYPE_Q4_0_R8:
- SET_MUL_MAT_FUNCTIONS_T(m, mul_mat_qx_r8_q8_0, Q4_0_R8_Dequantizer);
- expected_Btype = GGML_TYPE_Q8_0_X4;
- break;
case GGML_TYPE_Q5_0_R4:
- SET_MUL_MAT_FUNCTIONS_T(m, mul_mat_qx_r4_q8_0, Q5_0_R4_Dequantizer);
- expected_Btype = GGML_TYPE_Q8_0_X4;
- break;
case GGML_TYPE_Q6_0_R4:
- SET_MUL_MAT_FUNCTIONS_T(m, mul_mat_qx_r4_q8_0, Q6_0_R4_Dequantizer);
- expected_Btype = GGML_TYPE_Q8_0_X4;
- break;
case GGML_TYPE_Q8_0_R8:
- SET_MUL_MAT_FUNCTIONS(m, mul_mat_q8_0_r8_q8_0);
- expected_Btype = GGML_TYPE_Q8_0_X4;
- break;
+ case GGML_TYPE_IQ4_NL_R4:
+ return iqk_set_kernels_legacy_quants(ne00, typeA, typeB, m.funcs, m.func16);
+ case GGML_TYPE_IQ1_BN:
+ case GGML_TYPE_IQ2_BN:
+ case GGML_TYPE_IQ2_BN_R4:
+ case GGML_TYPE_IQ1_S:
+ case GGML_TYPE_IQ1_S_R4:
+ case GGML_TYPE_IQ1_M_R4:
+ return iqk_set_kernels_1bit(ne00, typeA, typeB, m.funcs, m.func16);
default:
return false;
}
- return typeB == expected_Btype;
}
}
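For reference, a minimal sketch (not part of the change) of how the refactored dispatch is driven: prepare() now forwards to the per-family iqk_set_kernels_* setters, which fill m.funcs (one kernel per 1..8 right-hand-side rows, as in the macro fragment above) and optionally m.func16, and are assumed to take over the typeB validation that the removed expected_Btype check used to do. The caller-side names here are illustrative assumptions.

// Hedged sketch only; ne00 is the row length (number of columns of A).
static bool pick_kernels(int ne00, MulMat & mm) {
    // Q4_K weights with Q8_K-quantized activations, matching the default
    // expected_Btype of the removed code above.
    if (!MulMat::prepare(GGML_TYPE_Q4_K, GGML_TYPE_Q8_K, ne00, mm, /*Ny*/ 1)) {
        return false;  // unsupported combination: the caller presumably falls back to plain ggml
    }
    // mm.funcs[ny-1] now holds the GEMM kernel for ny (1..8) rows of activations;
    // mm.func16, when set, is the 16-row variant some repacked quants provide.
    return true;
}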
@@ -15618,70 +659,6 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) {
namespace {
#if defined(__ARM_NEON) && defined(__aarch64__)
-// copy-pasted from Justine Tunney's contribution to llama.cpp
-// adapted from arm limited optimized routine
-// the maximum error is 1.45358 plus 0.5 ulps
-// numbers above 88.38 will flush to infinity
-// numbers beneath -103.97 will flush to zero
-inline float32x4_t v_expf(float32x4_t x) {
- const float32x4_t r = vdupq_n_f32(0x1.8p23f);
- const float32x4_t z = vfmaq_f32(r, x, vdupq_n_f32(0x1.715476p+0f));
- const float32x4_t n = vsubq_f32(z, r);
- const float32x4_t b = vfmsq_f32(vfmsq_f32(x, n, vdupq_n_f32(0x1.62e4p-1f)), n,
- vdupq_n_f32(0x1.7f7d1cp-20f));
- const uint32x4_t e = vshlq_n_u32(vreinterpretq_u32_f32(z), 23);
- const float32x4_t k = vreinterpretq_f32_u32(vaddq_u32(e, vreinterpretq_u32_f32(vdupq_n_f32(1))));
- const uint32x4_t c = vcagtq_f32(n, vdupq_n_f32(126));
- const float32x4_t u = vmulq_f32(b, b);
- const float32x4_t j = vfmaq_f32(
- vmulq_f32(vdupq_n_f32(0x1.ffffecp-1f), b),
- vfmaq_f32(vfmaq_f32(vdupq_n_f32(0x1.fffdb6p-2f), vdupq_n_f32(0x1.555e66p-3f), b),
- vfmaq_f32(vdupq_n_f32(0x1.573e2ep-5f), vdupq_n_f32(0x1.0e4020p-7f), b), u), u);
- if (!vpaddd_u64(vreinterpretq_u64_u32(c)))
- return vfmaq_f32(k, j, k);
- const uint32x4_t d = vandq_u32(vclezq_f32(n), vdupq_n_u32(0x82000000));
- const float32x4_t s1 = vreinterpretq_f32_u32(vaddq_u32(d, vdupq_n_u32(0x7f000000)));
- const float32x4_t s2 = vreinterpretq_f32_u32(vsubq_u32(e, d));
- return vbslq_f32(vcagtq_f32(n, vdupq_n_f32(192)), vmulq_f32(s1, s1),
- vbslq_f32(c, vmulq_f32(vfmaq_f32(s2, s2, j), s1), vfmaq_f32(k, k, j)));
-}
-inline float16x8_t v_expf(float16x8_t x) {
- auto val1 = v_expf(vcvt_f32_f16(vget_low_f16(x)));
- auto val2 = v_expf(vcvt_f32_f16(vget_high_f16(x)));
- return vcombine_f16(vcvt_f16_f32(val1), vcvt_f16_f32(val2));
-}
-inline float32x4_t v_tanh(float32x4_t x) {
- const float32x4_t one = vdupq_n_f32(1.0f);
- const float32x4_t two_x = vmulq_f32(x, vdupq_n_f32(2.f));
- const float32x4_t exp_two_x = v_expf(two_x);
- const uint32x4_t mask = vcgtq_f32(x, vdupq_n_f32(10.f));
- const float32x4_t res = vdivq_f32(vsubq_f32(exp_two_x, one), vaddq_f32(exp_two_x, one));
- return vreinterpretq_f32_u32(vorrq_u32(vandq_u32(vreinterpretq_u32_f32(one), mask), vbicq_u32(vreinterpretq_u32_f32(res), mask)));
- //return vdivq_f32(vsubq_f32(exp_two_x, one), vaddq_f32(exp_two_x, one));
-}
-//inline float32x4_t v_tanh(float16x8_t x) {
-// auto val1 = v_tanh(vcvt_f32_f16(vget_low_f16(x)));
-// auto val2 = v_tanh(vcvt_f32_f16(vget_high_f16(x)));
-// return vcombine_f16(vcvt_f16_f32(val1), vcvt_f16_f32(val2));
-//}
-inline float32x4_t v_silu(float32x4_t x) {
- const float32x4_t one = vdupq_n_f32(1.0f);
- const float32x4_t zero = vdupq_n_f32(0.0f);
- const float32x4_t neg_x = vsubq_f32(zero, x);
- const float32x4_t exp_neg_x = v_expf(neg_x);
- const float32x4_t one_plus_exp_neg_x = vaddq_f32(one, exp_neg_x);
- return vdivq_f32(x, one_plus_exp_neg_x);
-}
-inline float32x4_t v_gelu(float32x4_t x, float32x4_t c1, float32x4_t c2) {
- const float32x4_t one = vdupq_n_f32(1.0f);
- float32x4_t arg = vfmaq_f32(one, c1, vmulq_f32(x, x));
- arg = vmulq_f32(arg, vmulq_f32(x, c2));
- float32x4_t exp_arg = v_expf(arg);
- float32x4_t gelu = vmulq_f32(x, vdivq_f32(exp_arg, vaddq_f32(exp_arg, one)));
- uint32x4_t mask = vcgtq_f32(x, vdupq_n_f32(10.f));
- return vbslq_f32(mask, x, gelu);
-}
-
void MulMat::gelu(int n, const float * x, float * y) {
constexpr float GELU_COEF_A = 0.044715f;
constexpr float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
@@ -15693,147 +670,18 @@ void MulMat::gelu(int n, const float * x, float * y) {
}
for (; i < n; ++i) y[i] = 0.5f*x[i]*(1.0f + tanhf(SQRT_2_OVER_PI*x[i]*(1.0f + GELU_COEF_A*x[i]*x[i])));
}
-
void MulMat::silu(int n, const float * x, float * y) {
int i = 0;
for (; i + 3 < n; i += 4) vst1q_f32(y + i, v_silu(vld1q_f32(x + i)));
for (; i < n; ++i) y[i] = x[i]/(1.0f + expf(-x[i]));
}
-
void MulMat::relu(int n, const float * x, float * y) {
for (int j = 0; j < n; ++j) y[j] = x[j] > 0 ? x[j] : 0;
}
#endif
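The removed expf helper above quotes a concrete accuracy contract (about 1.45 + 0.5 ulp, flush to infinity above 88.38, flush to zero below -103.97). A small, hedged probe of that claim, assuming it is compiled on AArch64 in a translation unit that still sees a v_expf(float32x4_t) definition like the one shown in this diff:

// Hedged sketch: compare v_expf against libm exp over the normal fp32 range.
#include <arm_neon.h>
#include <cmath>
#include <cstdio>

static void probe_v_expf() {
    float max_rel = 0.0f;
    for (float x = -87.0f; x <= 88.0f; x += 0.37f) {
        float approx = vgetq_lane_f32(v_expf(vdupq_n_f32(x)), 0);
        float ref    = std::exp(x);
        max_rel = std::fmax(max_rel, std::fabs(approx - ref)/ref);
    }
    // ~1.95 ulp corresponds to a relative error on the order of 2e-7 for fp32.
    std::printf("v_expf max relative error: %g\n", max_rel);
}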
-#if defined(__AVX512F__) && defined(__AVX512DQ__)
-
-// copy-pasted from Justine Tunney's contribution to llama.cpp
-// adapted from arm limited optimized routine
-// the maximum error is 1.45358 plus 0.5 ulps
-// numbers above 88.38 will flush to infinity
-// numbers beneath -103.97 will flush to zero
-inline __m512 v_expf(__m512 x) {
- const __m512 r = _mm512_set1_ps(0x1.8p23f);
- const __m512 z = _mm512_fmadd_ps(x, _mm512_set1_ps(0x1.715476p+0f), r);
- const __m512 n = _mm512_sub_ps(z, r);
- const __m512 b =
- _mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.7f7d1cp-20f),
- _mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.62e4p-1f), x));
- const __mmask16 d =
- _mm512_cmp_ps_mask(_mm512_abs_ps(n), _mm512_set1_ps(192), _CMP_GT_OQ);
- const __m512 u = _mm512_mul_ps(b, b);
- const __m512 j = _mm512_fmadd_ps(
- _mm512_fmadd_ps(_mm512_fmadd_ps(_mm512_set1_ps(0x1.0e4020p-7f), b,
- _mm512_set1_ps(0x1.573e2ep-5f)),
- u,
- _mm512_fmadd_ps(_mm512_set1_ps(0x1.555e66p-3f), b,
- _mm512_set1_ps(0x1.fffdb6p-2f))),
- u,
- _mm512_fmadd_ps(_mm512_set1_ps(0x1.ffffecp-1f), b, _mm512_set1_ps(1.0F)));
- const __m512 res = _mm512_scalef_ps(j, n);
- if (_mm512_kortestz(d, d))
- return res;
- const __m512 zero = _mm512_setzero_ps();
- const __m512 alt = _mm512_mask_blend_ps(
- _mm512_cmp_ps_mask(n, zero, _CMP_LE_OQ), _mm512_set1_ps(INFINITY), zero);
- return _mm512_mask_blend_ps(d, res, alt);
-}
-inline __m512 v_tanh(__m512 x) {
- const __m512 one = _mm512_set1_ps(1.0f);
- const __m512 exp_two_x = v_expf(_mm512_mul_ps(x, _mm512_set1_ps(2.f)));
- const __mmask16 mask = _mm512_cmp_ps_mask(x, _mm512_set1_ps(10.f), _CMP_GT_OQ);
- const __m512 res = _mm512_div_ps(_mm512_sub_ps(exp_two_x, one), _mm512_add_ps(exp_two_x, one));
- return _mm512_mask_blend_ps(mask, res, one);
-}
-inline __m512 v_gelu(__m512 x, __m512 c1, __m512 c2) {
- const __m512 one = _mm512_set1_ps(1.0f);
- __m512 arg = _mm512_fmadd_ps(x, _mm512_mul_ps(c1, x), one);
- //__m512 arg = _mm512_add_ps(one, _mm512_mul_ps(_mm512_mul_ps(x, x), c1));
- arg = _mm512_mul_ps(arg, _mm512_mul_ps(c2, x));
- const __mmask16 mask = _mm512_cmp_ps_mask(arg, _mm512_set1_ps(30.f), _CMP_GT_OQ);
- const __m512 exp_arg = v_expf(arg);
- const __m512 ratio = _mm512_div_ps(exp_arg, _mm512_add_ps(exp_arg, one));
- return _mm512_mul_ps(x, _mm512_mask_blend_ps(mask, ratio, one));
-}
-inline static __m512 v_silu(__m512 x) {
- const __m512 one = _mm512_set1_ps(1);
- const __m512 zero = _mm512_setzero_ps();
- const __m512 neg_x = _mm512_sub_ps(zero, x);
- const __m512 exp_neg_x = v_expf(neg_x);
- const __m512 one_plus_exp_neg_x = _mm512_add_ps(one, exp_neg_x);
- return _mm512_div_ps(x, one_plus_exp_neg_x);
-}
-#endif
-
#if defined(__AVX2__) && defined(__FMA__)
-// adapted from arm limited optimized routine
-// the maximum error is 1.45358 plus 0.5 ulps
-// numbers above 88.38 will flush to infinity
-// numbers beneath -103.97 will flush to zero
-inline __m256 v_expf(__m256 x) {
- const __m256 r = _mm256_set1_ps(0x1.8p23f);
- const __m256 z = _mm256_fmadd_ps(x, _mm256_set1_ps(0x1.715476p+0f), r);
- const __m256 n = _mm256_sub_ps(z, r);
- const __m256 b = _mm256_fnmadd_ps(n, _mm256_set1_ps(0x1.7f7d1cp-20f),
- _mm256_fnmadd_ps(n, _mm256_set1_ps(0x1.62e4p-1f), x));
- const __m256i e = _mm256_slli_epi32(_mm256_castps_si256(z), 23);
- const __m256 k = _mm256_castsi256_ps(
- _mm256_add_epi32(e, _mm256_castps_si256(_mm256_set1_ps(1))));
- const __m256i c = _mm256_castps_si256(
- _mm256_cmp_ps(_mm256_andnot_ps(_mm256_set1_ps(-0.f), n),
- _mm256_set1_ps(126), _CMP_GT_OQ));
- const __m256 u = _mm256_mul_ps(b, b);
- const __m256 j = _mm256_fmadd_ps(_mm256_fmadd_ps(_mm256_fmadd_ps(_mm256_set1_ps(0x1.0e4020p-7f), b,
- _mm256_set1_ps(0x1.573e2ep-5f)), u,
- _mm256_fmadd_ps(_mm256_set1_ps(0x1.555e66p-3f), b,
- _mm256_set1_ps(0x1.fffdb6p-2f))),
- u, _mm256_mul_ps(_mm256_set1_ps(0x1.ffffecp-1f), b));
- if (!_mm256_movemask_ps(_mm256_castsi256_ps(c)))
- return _mm256_fmadd_ps(j, k, k);
- const __m256i g = _mm256_and_si256(
- _mm256_castps_si256(_mm256_cmp_ps(n, _mm256_setzero_ps(), _CMP_LE_OQ)),
- _mm256_set1_epi32(0x82000000u));
- const __m256 s1 =
- _mm256_castsi256_ps(_mm256_add_epi32(g, _mm256_set1_epi32(0x7f000000u)));
- const __m256 s2 = _mm256_castsi256_ps(_mm256_sub_epi32(e, g));
- const __m256i d = _mm256_castps_si256(
- _mm256_cmp_ps(_mm256_andnot_ps(_mm256_set1_ps(-0.f), n),
- _mm256_set1_ps(192), _CMP_GT_OQ));
- return _mm256_or_ps(
- _mm256_and_ps(_mm256_castsi256_ps(d), _mm256_mul_ps(s1, s1)),
- _mm256_andnot_ps(
- _mm256_castsi256_ps(d),
- _mm256_or_ps(
- _mm256_and_ps(_mm256_castsi256_ps(c),
- _mm256_mul_ps(_mm256_fmadd_ps(s2, j, s2), s1)),
- _mm256_andnot_ps(_mm256_castsi256_ps(c), _mm256_fmadd_ps(k, j, k)))));
-}
-inline __m256 v_tanh(__m256 x) {
- const __m256 one = _mm256_set1_ps(1.0f);
- const __m256 exp_two_x = v_expf(_mm256_mul_ps(x, _mm256_set1_ps(2.f)));
- const __m256 res = _mm256_div_ps(_mm256_sub_ps(exp_two_x, one), _mm256_add_ps(exp_two_x, one));
- const __m256 mask = _mm256_cmp_ps(x, _mm256_set1_ps(10.f), _CMP_GT_OQ);
- return _mm256_or_ps(_mm256_and_ps(mask, one), _mm256_andnot_ps(mask, res));
-}
-inline static __m256 v_gelu(__m256 x, __m256 c1, __m256 c2) {
- const __m256 one = _mm256_set1_ps(1.0f);
- const __m256 mask = _mm256_cmp_ps(x, _mm256_set1_ps(10.f), _CMP_GT_OQ);
- __m256 arg = _mm256_add_ps(one, _mm256_mul_ps(_mm256_mul_ps(x, x), c1));
- arg = _mm256_mul_ps(arg, _mm256_mul_ps(x, c2));
- __m256 exp_arg = v_expf(arg);
- __m256 gelu = _mm256_mul_ps(x, _mm256_div_ps(exp_arg, _mm256_add_ps(exp_arg, one)));
- return _mm256_or_ps(_mm256_and_ps(mask, x), _mm256_andnot_ps(mask, gelu));
-}
-inline static __m256 v_silu(__m256 x) {
- const __m256 one = _mm256_set1_ps(1);
- const __m256 zero = _mm256_setzero_ps();
- const __m256 neg_x = _mm256_sub_ps(zero, x);
- const __m256 exp_neg_x = v_expf(neg_x);
- const __m256 one_plus_exp_neg_x = _mm256_add_ps(one, exp_neg_x);
- return _mm256_div_ps(x, one_plus_exp_neg_x);
-}
-
void MulMat::gelu(int n, const float * x, float * y) {
constexpr float GELU_COEF_A = 0.044715f;
constexpr float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
@@ -15876,281 +724,6 @@ void MulMat::relu(int n, const float * x, float * y) {
} // namespace
#ifdef GGML_IQK_FLASH_ATTENTION
-namespace {
-
-template <int k_step>
-struct BaseHelper {
- BaseHelper(const char * data, int stride) : data(data), block(data), stride(stride) {}
-
- //inline void set_block(int k1) { block = data + k1*k_step*stride; }
- inline void reset_block() { block = data; }
- inline void next_block() { block += k_step*stride; }
- inline const char * lblock(int l1) const { return block + l1*stride; }
-
- const char * data;
- const char * block;
- int stride;
-
-};
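BaseHelper is the thin cursor every K/V-cache helper below derives from: data points at the start of the cache for this head, block at the current k_step-row chunk, and stride is the byte distance between rows. A hedged sketch of the access pattern the FA kernels drive it with (loop bounds illustrative; n_kv is assumed to be a multiple of k_step):

// Hedged sketch: walk the cache in chunks of k_step rows.
template <int k_step>
static void walk_kv(BaseHelper<k_step> & h, int n_kv) {
    h.reset_block();
    for (int k1 = 0; k1 < n_kv/k_step; ++k1) {
        for (int l1 = 0; l1 < k_step; ++l1) {
            const char * row = h.lblock(l1);  // raw (possibly quantized) K or V row
            (void)row;                        // ... process one row ...
        }
        h.next_block();
    }
}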
-
-struct F16 {
-#ifdef __AVX512F__
- using Data = __m512;
- constexpr static int block_size = 16;
- constexpr static int num_registers = 32;
- constexpr static int q_step = 8;
- static inline Data zero() { return _mm512_setzero_ps(); }
- static inline Data load(const char * ptr, int i) { return _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)ptr + i)); }
- static inline Data set1(float val) { return _mm512_set1_ps(val); }
- static inline Data mul(Data v1, Data v2) { return _mm512_mul_ps(v1, v2); }
- static inline Data sub(Data v1, Data v2) { return _mm512_sub_ps(v1, v2); }
- static inline Data load(const float * ptr) { return _mm512_loadu_ps(ptr); }
- static inline void store(float * ptr, Data data) { _mm512_storeu_ps(ptr, data); }
- static inline Data fmadd(Data prev, Data v1, Data v2) { return _mm512_fmadd_ps(v1, v2, prev); }
- static inline float reduce_max(Data data) { return _mm512_reduce_max_ps(data); }
- static inline float reduce_add(Data data) { return _mm512_reduce_add_ps(data); }
- static inline Data max(Data v1, Data v2) { return _mm512_max_ps(v1, v2); }
- static inline Data add(Data v1, Data v2) { return _mm512_add_ps(v1, v2); }
- static inline Data set4(const float * ptr) {
- auto v128 = _mm_loadu_ps(ptr);
- auto v256 = _mm256_set_m128(v128, v128);
- return _mm512_insertf32x8(_mm512_castps256_ps512(v256), v256, 1);
- }
- static inline void set4(const float * ptr, Data * vs) {
- auto v = set4(ptr);
- vs[0] = _mm512_shuffle_ps(v, v, 0x00);
- vs[1] = _mm512_shuffle_ps(v, v, 0x55);
- vs[2] = _mm512_shuffle_ps(v, v, 0xaa);
- vs[3] = _mm512_shuffle_ps(v, v, 0xff);
- }
- static inline Data fmadd_lane0(Data prev, Data v1, Data v2) { return _mm512_fmadd_ps(v1, _mm512_shuffle_ps(v2, v2, 0x00), prev); }
- static inline Data fmadd_lane1(Data prev, Data v1, Data v2) { return _mm512_fmadd_ps(v1, _mm512_shuffle_ps(v2, v2, 0x55), prev); }
- static inline Data fmadd_lane2(Data prev, Data v1, Data v2) { return _mm512_fmadd_ps(v1, _mm512_shuffle_ps(v2, v2, 0xaa), prev); }
- static inline Data fmadd_lane3(Data prev, Data v1, Data v2) { return _mm512_fmadd_ps(v1, _mm512_shuffle_ps(v2, v2, 0xff), prev); }
-#elif defined __AVX2__
- using Data = __m256;
- constexpr static int block_size = 8;
- constexpr static int num_registers = 16;
- constexpr static int q_step = 8;
- static inline Data zero() { return _mm256_setzero_ps(); }
- static inline Data load(const char * ptr, int i) { return _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)ptr + i)); }
- static inline Data set1(float val) { return _mm256_set1_ps(val); }
- static inline Data mul(Data v1, Data v2) { return _mm256_mul_ps(v1, v2); }
- static inline Data load(const float * ptr) { return _mm256_loadu_ps(ptr); }
- static inline Data sub(Data v1, Data v2) { return _mm256_sub_ps(v1, v2); }
- static inline void store(float * ptr, Data data) { _mm256_storeu_ps(ptr, data); }
- static inline Data fmadd(Data prev, Data v1, Data v2) { return _mm256_fmadd_ps(v1, v2, prev); }
- static inline float reduce_max(Data data) { return hmax_float_8(data); }
- static inline float reduce_add(Data data) { return hsum_float_8(data); }
- static inline Data max(Data v1, Data v2) { return _mm256_max_ps(v1, v2); }
- static inline Data add(Data v1, Data v2) { return _mm256_add_ps(v1, v2); }
- static inline Data set4(const float * ptr) {
- auto v128 = _mm_loadu_ps(ptr);
- return _mm256_set_m128(v128, v128);
- }
- static inline void set4(const float * ptr, Data * vs) {
- auto v = set4(ptr);
- vs[0] = _mm256_shuffle_ps(v, v, 0x00);
- vs[1] = _mm256_shuffle_ps(v, v, 0x55);
- vs[2] = _mm256_shuffle_ps(v, v, 0xaa);
- vs[3] = _mm256_shuffle_ps(v, v, 0xff);
- }
- static inline Data fmadd_lane0(Data prev, Data v1, Data v2) { return _mm256_fmadd_ps(v1, _mm256_shuffle_ps(v2, v2, 0x00), prev); }
- static inline Data fmadd_lane1(Data prev, Data v1, Data v2) { return _mm256_fmadd_ps(v1, _mm256_shuffle_ps(v2, v2, 0x55), prev); }
- static inline Data fmadd_lane2(Data prev, Data v1, Data v2) { return _mm256_fmadd_ps(v1, _mm256_shuffle_ps(v2, v2, 0xaa), prev); }
- static inline Data fmadd_lane3(Data prev, Data v1, Data v2) { return _mm256_fmadd_ps(v1, _mm256_shuffle_ps(v2, v2, 0xff), prev); }
-#else
- using Data = float16x8_t;
- constexpr static int block_size = 8;
- //constexpr static int num_registers = 32;
- //constexpr static int q_step = 8;
- static inline Data zero() { return vdupq_n_f16(0); }
- static inline Data load(const char * ptr, int i) { return vld1q_f16((const float16_t *)ptr + block_size*i); }
- static inline Data load(const float16_t * ptr, int i) { return vld1q_f16(ptr + block_size*i); }
- static inline Data load(const float16_t * ptr) { return vld1q_f16(ptr); }
- static inline Data load(const float * ptr) {
- auto val1 = vld1q_f32(ptr);
- auto val2 = vld1q_f32(ptr+4);
- return vcombine_f16(vcvt_f16_f32(val1), vcvt_f16_f32(val2));
- }
- static inline Data set1(float val) { return vdupq_n_f16(val); }
- static inline Data mul(Data v1, Data v2) { return vmulq_f16(v1, v2); }
- static inline Data sub(Data v1, Data v2) { return vsubq_f16(v1, v2); }
- static inline void store(float * ptr, Data data) {
- vst1q_f32(ptr+0, vcvt_f32_f16(vget_low_f16(data)));
- vst1q_f32(ptr+4, vcvt_f32_f16(vget_high_f16(data)));
- }
- static inline void store(float16_t * ptr, Data data) { vst1q_f16(ptr, data); }
- static inline void store(float * ptr, float32x4_t data) { vst1q_f32(ptr, data); }
- static inline Data fmadd(Data prev, Data v1, Data v2) { return vfmaq_f16(prev, v1, v2); }
- static inline float reduce_max(Data data) { return vmaxvq_f16(data); }
- static inline float reduce_add(Data data) {
- auto sum = vadd_f16(vget_low_f16(data), vget_high_f16(data));
- return vaddvq_f32(vcvt_f32_f16(sum));
- }
- static inline Data max(Data v1, Data v2) { return vmaxq_f16(v1, v2); }
- static inline Data add(Data v1, Data v2) { return vaddq_f16(v1, v2); }
- static inline float16x4_t set4(const float * ptr) {
- auto val32 = vld1q_f32(ptr);
- return vcvt_f16_f32(val32);
- }
- static inline Data fmadd_lane0(Data prev, Data v1, float16x4_t v2) { return vfmaq_lane_f16(prev, v1, v2, 0); }
- static inline Data fmadd_lane1(Data prev, Data v1, float16x4_t v2) { return vfmaq_lane_f16(prev, v1, v2, 1); }
- static inline Data fmadd_lane2(Data prev, Data v1, float16x4_t v2) { return vfmaq_lane_f16(prev, v1, v2, 2); }
- static inline Data fmadd_lane3(Data prev, Data v1, float16x4_t v2) { return vfmaq_lane_f16(prev, v1, v2, 3); }
-#endif
- template <int k_step> static inline float reduce_max(const Data * data) {
- return reduce_T<k_step, &F16::max, &F16::reduce_max>(data);
- }
- template <int k_step> static inline float reduce_add(const Data * data) {
- return reduce_T<k_step, &F16::add, &F16::reduce_add>(data);
- }
- template <int k_step, Data (*Op_combine)(Data, Data), float (*Op)(Data)>
- static float reduce_T(const Data * data) {
- float result;
- if constexpr (k_step/block_size == 1) {
- result = Op(data[0]);
- }
- else if constexpr (k_step/block_size == 2) {
- result = Op(Op_combine(data[0], data[1]));
- }
- else {
- auto vmax = Op_combine(data[0], data[1]);
- for (int l = 2; l < k_step/block_size; ++l) vmax = Op_combine(vmax, data[l]);
- result = Op(vmax);
- }
- return result;
- }
-};
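The F16 wrapper above hides the SIMD width (16 fp32 lanes on AVX-512, 8 on AVX2, 8 fp16 lanes on NEON) so the same head-dimension loops compile for all three targets. A hedged sketch of the idiom, assuming it sits next to the F16 definition shown in the diff; note that the NEON branch multiplies and reduces in fp16:

// Hedged sketch: one D-long dot product written against the F16 abstraction.
template <int D>
static float dot_row(const float * a, const float * b) {
    static_assert(D % F16::block_size == 0);
    F16::Data prod[D/F16::block_size];
    for (int i = 0; i < D/F16::block_size; ++i) {
        prod[i] = F16::mul(F16::load(a + F16::block_size*i), F16::load(b + F16::block_size*i));
    }
    // reduce_add<D> pairwise-combines the D/block_size partial products and then
    // does the final horizontal add.
    return F16::reduce_add<D>(prod);
}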
-
-template <int D, int step>
-struct HelperF16 final : public BaseHelper<step> {
- using Base = BaseHelper<step>;
- HelperF16(const char * data, int stride) : Base(data, stride) {}
-
- inline void load(int l1, F16::Data * vk) const {
- auto dr = Base::lblock(l1);
- for (int i = 0; i < D/F16::block_size; ++i) vk[i] = F16::load(dr, i);
- }
-
- inline void load(int l1, int i, F16::Data& v1, F16::Data& v2) const {
- //auto dr = (const ggml_half *)Base::lblock(l1);
- auto dr = Base::lblock(l1);
- v1 = F16::load(dr, i + 0);
- v2 = F16::load(dr, i + 1);
- }
-
- inline void load_2(int l1, F16::Data* vk) const {
- load(l1+0, vk+0);
- load(l1+1, vk+D/16);
- }
-};
-
-template <int D> struct block_q8_KV {
- float d;
- int s;
- int8_t qs[D];
-};
-
-template <int D, int step>
-struct HelperQ8KV final : public BaseHelper<step> {
- using Base = BaseHelper<step>;
- using block_q8 = block_q8_KV<D>;
- constexpr static int block_size_q = D;
- HelperQ8KV(const char * data, int stride) : Base(data, stride) {}
-
- // Needed for v * softmax(k * q)
- inline void load(int l1, int i, F16::Data& v1, F16::Data& v2) const {
- auto q8 = (const block_q8_KV<D> *)Base::lblock(l1);
-#ifdef __aarch64__
- auto vd = F16::set1(q8->d);
- auto qs = vld1_s8_x2(q8->qs + 8*i);
- v1 = vmulq_f16(vd, vcvtq_f16_s16(vmovl_s8(qs.val[0])));
- v2 = vmulq_f16(vd, vcvtq_f16_s16(vmovl_s8(qs.val[1])));
-#else
- auto vd = F16::set1(q8->d);
-#ifdef __AVX512F__
- v1 = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)q8->qs+i+0))));
- v2 = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)q8->qs+i+1))));
-#else
- v1 = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(q8->qs+8*i+0)))));
- v2 = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(q8->qs+8*i+8)))));
-#endif
-#endif
- }
-};
-
-template <int D, int step>
-struct HelperQ80 final : public BaseHelper<step> {
- using Base = BaseHelper<step>;
-#ifdef HAVE_FANCY_SIMD
- using block_q8 = block_q8_2;
- constexpr static int block_size_q = QK8_2;
-#else
- using block_q8 = block_q8_0;
- constexpr static int block_size_q = QK8_0;
-#endif
- HelperQ80(const char * data, int stride) : Base(data, stride) {}
-
- // Needed for v * softmax(k * q)
- inline void load(int l1, int i, F16::Data& v1, F16::Data& v2) const {
- int j = F16::block_size*i;
- auto dl = (const block_q8_0 *)Base::lblock(l1) + j/QK8_0;
-#ifdef __aarch64__
- auto vd = F16::set1(GGML_FP16_TO_FP32(dl->d));
- int ii = j%QK8_0;
- auto qs = vld1_s8_x2(dl->qs + ii);
- v1 = vmulq_f16(vd, vcvtq_f16_s16(vmovl_s8(qs.val[0])));
- v2 = vmulq_f16(vd, vcvtq_f16_s16(vmovl_s8(qs.val[1])));
-#else
- auto vd = F16::set1(GGML_FP16_TO_FP32(dl->d));
-#ifdef __AVX512F__
- v1 = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)dl->qs+0))));
- v2 = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)dl->qs+1))));
-#else
- int ii = j%QK8_0;
- v1 = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(dl->qs+ii+0)))));
- v2 = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(dl->qs+ii+8)))));
-#endif
-#endif
- }
-
- static inline void convert(int nq, int stride_q, const float * q, block_q8_0 * y) {
- //GGML_ASSERT(nq <= step); Why did I have this assert?
- for (int i = 0; i < nq; ++i) {
- quantize_row_q8_0_x4(q, y, D);
- q += stride_q;
- y += D/QK8_0;
- }
- }
-
- static inline void convert(int nq, int stride_q, const float * q, block_q8_1 * y) {
- //GGML_ASSERT(nq <= step); Why did I have this assert?
- for (int i = 0; i < nq; ++i) {
- quantize_row_q8_1_x4(q, y, D);
- q += stride_q;
- y += D/QK8_1;
- }
- }
-
- static inline void convert(int nq, int stride_q, const float * q, block_q8_2 * y) {
- //GGML_ASSERT(nq <= step); Why did I have this assert?
- for (int i = 0; i < nq; ++i) {
- quantize_row_q8_2_x4(q, y, D);
- q += stride_q;
- y += D/QK8_2;
- }
- }
-
- static inline void convert(int nq, int stride_q, const float * q, block_q8_KV<D> * y) {
- for (int i = 0; i < nq; ++i) {
- quantize_row_q8_KV(q, y, D);
- q += stride_q;
- ++y;
- }
- }
-};
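The convert() overloads above quantize the fp32 query rows into whichever block_q8 flavor the selected K-cache helper declares, ahead of the K*Q GEMM. A hedged usage sketch; the head size D = 128, the step value 32, and the nq/stride_q/q names are illustrative assumptions:

#include <vector>
static void quantize_queries(int nq, int stride_q, const float * q) {
    // Hedged sketch: nq query rows, stride_q floats apart, quantized row by row
    // into the block_q8 flavor HelperQ80 selects for this build (q8_0 or q8_2).
    constexpr int D = 128;                  // example head size
    using KHelper = HelperQ80<D, /*step*/ 32>;
    std::vector<KHelper::block_q8> q8(size_t(nq) * (D / KHelper::block_size_q));
    KHelper::convert(nq, stride_q, q, q8.data());
}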
-}
void * iqk_repack_k(int int_type_k, int nek0, int nek1, int nek2, int nek3, long nbk1, long nbk2, long nbk3,
const void * data, void * work, int ith, int nth, int& repacked_type, uint64_t& row_size) {
@@ -16262,2261 +835,8 @@ void * iqk_repack_k(int int_type_k, int nek0, int nek1, int nek2, int nek3, long
return result;
}
-namespace {
-template <int D, int step>
-struct HelperQ80R8 : public BaseHelper<step> {
- using Base = BaseHelper<step>;
-#ifdef __AVX2__
- constexpr static int block_size_q = QK8_2;
- using block_q8 = block_q8_2;
-#else
- constexpr static int block_size_q = QK8_0;
- using block_q8 = block_q8_0;
-#endif
- HelperQ80R8(const char * data, int stride) : Base(data, stride) {}
- HelperQ80R8(int nk, const HelperQ80<D, step>& q8) : Base(q8.data, q8.stride) {
- r4 = repack(nk, q8);
- Base::data = (const char *)r4.data();
- Base::stride = (D/QK8_0)*sizeof(block_q8_0);
- }
-
- static void repack(int nk, const char * q8_data, int q8_stride, block_q8_0_r8 * y) {
- constexpr int nblock = D/QK8_0;
- const block_q8_0 * x8[8];
-#ifdef __ARM_NEON
- int8x16x2_t m0, m1, m2, m3;
-#endif
- for (int row = 0; row < nk; row += 8) {
- for (int k = 0; k < 8; ++k) x8[k] = (const block_q8_0 *)(q8_data + (row + k)*q8_stride);
- for (int ib = 0; ib < nblock; ++ib) {
- for (int k = 0; k < 8; ++k) y[ib].d[k] = x8[k][ib].d;
-#ifdef __AVX2__
- auto m0 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[4][ib].qs), _mm_loadu_si128((const __m128i *)x8[0][ib].qs));
- auto m1 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[5][ib].qs), _mm_loadu_si128((const __m128i *)x8[1][ib].qs));
- auto m2 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[6][ib].qs), _mm_loadu_si128((const __m128i *)x8[2][ib].qs));
- auto m3 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[7][ib].qs), _mm_loadu_si128((const __m128i *)x8[3][ib].qs));
- auto t0 = _mm256_unpacklo_epi32(m0, m1);
- auto t1 = _mm256_unpacklo_epi32(m2, m3);
- auto t2 = _mm256_unpackhi_epi32(m0, m1);
- auto t3 = _mm256_unpackhi_epi32(m2, m3);
- m0 = _mm256_unpacklo_epi64(t0, t1);
- m1 = _mm256_unpackhi_epi64(t0, t1);
- m2 = _mm256_unpacklo_epi64(t2, t3);
- m3 = _mm256_unpackhi_epi64(t2, t3);
-//#ifdef HAVE_FANCY_SIMD
-// m0 = _mm256_add_epi8(m0, _mm256_set1_epi8(127));
-// m1 = _mm256_add_epi8(m1, _mm256_set1_epi8(127));
-// m2 = _mm256_add_epi8(m2, _mm256_set1_epi8(127));
-// m3 = _mm256_add_epi8(m3, _mm256_set1_epi8(127));
-//#endif
- _mm256_storeu_si256((__m256i *)y[ib].qs + 0, m0);
- _mm256_storeu_si256((__m256i *)y[ib].qs + 1, m1);
- _mm256_storeu_si256((__m256i *)y[ib].qs + 2, m2);
- _mm256_storeu_si256((__m256i *)y[ib].qs + 3, m3);
- m0 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[4][ib].qs+1), _mm_loadu_si128((const __m128i *)x8[0][ib].qs+1));
- m1 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[5][ib].qs+1), _mm_loadu_si128((const __m128i *)x8[1][ib].qs+1));
- m2 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[6][ib].qs+1), _mm_loadu_si128((const __m128i *)x8[2][ib].qs+1));
- m3 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[7][ib].qs+1), _mm_loadu_si128((const __m128i *)x8[3][ib].qs+1));
- t0 = _mm256_unpacklo_epi32(m0, m1);
- t1 = _mm256_unpacklo_epi32(m2, m3);
- t2 = _mm256_unpackhi_epi32(m0, m1);
- t3 = _mm256_unpackhi_epi32(m2, m3);
- m0 = _mm256_unpacklo_epi64(t0, t1);
- m1 = _mm256_unpackhi_epi64(t0, t1);
- m2 = _mm256_unpacklo_epi64(t2, t3);
- m3 = _mm256_unpackhi_epi64(t2, t3);
-//#ifdef HAVE_FANCY_SIMD
-// m0 = _mm256_add_epi8(m0, _mm256_set1_epi8(127));
-// m1 = _mm256_add_epi8(m1, _mm256_set1_epi8(127));
-// m2 = _mm256_add_epi8(m2, _mm256_set1_epi8(127));
-// m3 = _mm256_add_epi8(m3, _mm256_set1_epi8(127));
-//#endif
- _mm256_storeu_si256((__m256i *)y[ib].qs + 4, m0);
- _mm256_storeu_si256((__m256i *)y[ib].qs + 5, m1);
- _mm256_storeu_si256((__m256i *)y[ib].qs + 6, m2);
- _mm256_storeu_si256((__m256i *)y[ib].qs + 7, m3);
-#elif defined __ARM_NEON
- for (int l = 0; l < 2; ++l) {
- m0.val[0] = vld1q_s8(x8[0][ib].qs+16*l); m0.val[1] = vld1q_s8(x8[4][ib].qs+16*l);
- m1.val[0] = vld1q_s8(x8[1][ib].qs+16*l); m1.val[1] = vld1q_s8(x8[5][ib].qs+16*l);
- m2.val[0] = vld1q_s8(x8[2][ib].qs+16*l); m2.val[1] = vld1q_s8(x8[6][ib].qs+16*l);
- m3.val[0] = vld1q_s8(x8[3][ib].qs+16*l); m3.val[1] = vld1q_s8(x8[7][ib].qs+16*l);
- auto row01 = vtrnq_s32(vreinterpretq_s32_s8(m0.val[0]), vreinterpretq_s32_s8(m1.val[0]));
- auto row23 = vtrnq_s32(vreinterpretq_s32_s8(m2.val[0]), vreinterpretq_s32_s8(m3.val[0]));
- m0.val[0] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0])));
- m1.val[0] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1])));
- m2.val[0] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0])));
- m3.val[0] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1])));
- row01 = vtrnq_s32(vreinterpretq_s32_s8(m0.val[1]), vreinterpretq_s32_s8(m1.val[1]));
- row23 = vtrnq_s32(vreinterpretq_s32_s8(m2.val[1]), vreinterpretq_s32_s8(m3.val[1]));
- m0.val[1] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0])));
- m1.val[1] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1])));
- m2.val[1] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0])));
- m3.val[1] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1])));
- vst1q_s8_x2(y[ib].qs + 0 + 128*l, m0);
- vst1q_s8_x2(y[ib].qs + 32 + 128*l, m1);
- vst1q_s8_x2(y[ib].qs + 64 + 128*l, m2);
- vst1q_s8_x2(y[ib].qs + 96 + 128*l, m3);
- }
-#else
- for (int l = 0; l < 4; ++l) {
- for (int k = 0; k < 8; ++k) for (int i = 0; i < 4; ++i) {
- y[ib].qs[32*l+4*k+i+ 0] = x8[k][ib].qs[i+4*l+ 0];
- y[ib].qs[32*l+4*k+i+128] = x8[k][ib].qs[i+4*l+16];
- }
- }
-#endif
- }
- y += nblock;
- }
- }
-
- static std::vector<block_q8_0_r8> repack(int nk, const HelperQ80<D, step>& q8) {
- static_assert(D%QK8_0 == 0);
- GGML_ASSERT(nk%8 == 0);
- constexpr int nblock = D/QK8_0;
- std::vector<block_q8_0_r8> result(nblock * nk/8);
- auto y = result.data();
- repack(nk, q8.data, q8.stride, y);
- return result;
- }
-
- std::vector<block_q8_0_r8> r4;
-};
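HelperQ80R8::repack interleaves eight consecutive K-cache rows into one block_q8_0_r8 per 32-quant block: the eight per-row scales first, then the quants in 4-byte groups interleaved across rows, with the second 16 quants of each source block starting at byte 128 (the portable #else path above spells this out). A hedged lookup sketch for where a given quant ends up:

// Hedged sketch: position of quant j (0..31) of interleaved row k (0..7) of block ib
// in the repacked layout; the per-row scale is simply y[ib].d[k].
static int8_t q8_0_r8_quant(const block_q8_0_r8 * y, int ib, int k, int j) {
    const int half = j / 16;        // first or second 16 quants of the source block
    const int l    = (j % 16) / 4;  // 4-byte group inside that half
    const int i    = j % 4;         // byte inside the group
    return y[ib].qs[128*half + 32*l + 4*k + i];
}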
-
-// TODO: unite this with the above
-template <int D, int step>
-struct HelperQ8KVR8 : public BaseHelper<step> {
- using Base = BaseHelper<step>;
- constexpr static int block_size_q = D;
- using block_q8 = block_q8_KV<D>;
-
- struct block_q8_KV_r8 {
- float d[8];
- int8_t qs[8*D];
- };
-
- HelperQ8KVR8(int nk, const HelperQ8KV<D, step>& q8) : Base(q8.data, q8.stride) {
- r4 = repack(nk, q8);
- Base::data = (const char *)r4.data();
- Base::stride = sizeof(block_q8_KV_r8)/8;
- }
-
- static std::vector<block_q8_KV_r8> repack(int nk, const HelperQ8KV<D, step>& q8) {
- static_assert(D%32 == 0);
- GGML_ASSERT(nk%8 == 0);
- std::vector<block_q8_KV_r8> result(nk/8);
- auto y = result.data();
-#ifdef __ARM_NEON
- int8x16x2_t m0, m1, m2, m3;
-#endif
- const int8_t * x8[8];
- for (int ix = 0; ix < nk/8; ++ix) {
- for (int k = 0; k < 8; ++k) {
- auto dptr = (const float *)(q8.data + (8*ix + k)*q8.stride);
- y[ix].d[k] = dptr[0];
- x8[k] = (const int8_t *)(dptr + 2);
- }
- for (int ib = 0; ib < D/16; ++ib) {
-#ifdef __AVX2__
- auto m0 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[4]+ib), _mm_loadu_si128((const __m128i *)x8[0]+ib));
- auto m1 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[5]+ib), _mm_loadu_si128((const __m128i *)x8[1]+ib));
- auto m2 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[6]+ib), _mm_loadu_si128((const __m128i *)x8[2]+ib));
- auto m3 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[7]+ib), _mm_loadu_si128((const __m128i *)x8[3]+ib));
- auto t0 = _mm256_unpacklo_epi32(m0, m1);
- auto t1 = _mm256_unpacklo_epi32(m2, m3);
- auto t2 = _mm256_unpackhi_epi32(m0, m1);
- auto t3 = _mm256_unpackhi_epi32(m2, m3);
- m0 = _mm256_unpacklo_epi64(t0, t1);
- m1 = _mm256_unpackhi_epi64(t0, t1);
- m2 = _mm256_unpacklo_epi64(t2, t3);
- m3 = _mm256_unpackhi_epi64(t2, t3);
-//#ifdef HAVE_FANCY_SIMD
-// m0 = _mm256_add_epi8(m0, _mm256_set1_epi8(127));
-// m1 = _mm256_add_epi8(m1, _mm256_set1_epi8(127));
-// m2 = _mm256_add_epi8(m2, _mm256_set1_epi8(127));
-// m3 = _mm256_add_epi8(m3, _mm256_set1_epi8(127));
-//#endif
- _mm256_storeu_si256((__m256i *)y[ix].qs + 4*ib+0, m0);
- _mm256_storeu_si256((__m256i *)y[ix].qs + 4*ib+1, m1);
- _mm256_storeu_si256((__m256i *)y[ix].qs + 4*ib+2, m2);
- _mm256_storeu_si256((__m256i *)y[ix].qs + 4*ib+3, m3);
-#elif defined __ARM_NEON
- // TODO
- m0.val[0] = vld1q_s8(x8[0]+16*ib); m0.val[1] = vld1q_s8(x8[4]+16*ib);
- m1.val[0] = vld1q_s8(x8[1]+16*ib); m1.val[1] = vld1q_s8(x8[5]+16*ib);
- m2.val[0] = vld1q_s8(x8[2]+16*ib); m2.val[1] = vld1q_s8(x8[6]+16*ib);
- m3.val[0] = vld1q_s8(x8[3]+16*ib); m3.val[1] = vld1q_s8(x8[7]+16*ib);
- auto row01 = vtrnq_s32(vreinterpretq_s32_s8(m0.val[0]), vreinterpretq_s32_s8(m1.val[0]));
- auto row23 = vtrnq_s32(vreinterpretq_s32_s8(m2.val[0]), vreinterpretq_s32_s8(m3.val[0]));
- m0.val[0] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0])));
- m1.val[0] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1])));
- m2.val[0] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0])));
- m3.val[0] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1])));
- row01 = vtrnq_s32(vreinterpretq_s32_s8(m0.val[1]), vreinterpretq_s32_s8(m1.val[1]));
- row23 = vtrnq_s32(vreinterpretq_s32_s8(m2.val[1]), vreinterpretq_s32_s8(m3.val[1]));
- m0.val[1] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0])));
- m1.val[1] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1])));
- m2.val[1] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0])));
- m3.val[1] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1])));
- vst1q_s8_x2(y[ix].qs + 0 + 128*ib, m0);
- vst1q_s8_x2(y[ix].qs + 32 + 128*ib, m1);
- vst1q_s8_x2(y[ix].qs + 64 + 128*ib, m2);
- vst1q_s8_x2(y[ix].qs + 96 + 128*ib, m3);
-#else
- // TODO
- for (int l = 0; l < 4; ++l) {
- for (int k = 0; k < 8; ++k) for (int i = 0; i < 4; ++i) {
- y[ib].qs[32*l+4*k+i+ 0] = x8[k][ib].qs[i+4*l+ 0];
- y[ib].qs[32*l+4*k+i+128] = x8[k][ib].qs[i+4*l+16];
- }
- }
-#endif
- }
- }
- return result;
- }
-
- std::vector<block_q8_KV_r8> r4;
-};
-
-template <int D, int step>
-struct HelperQ40 final : public BaseHelper<step> {
- using Base = BaseHelper<step>;
-#if defined __AVX2__
- using block_q8 = block_q8_2;
- constexpr static int block_size_q = QK8_2;
-#else
- using block_q8 = block_q8_0;
- constexpr static int block_size_q = QK8_0;
-#endif
- HelperQ40(const char * data, int stride) : Base(data, stride) {}
-
- // Needed for v * softmax(k * q)
- inline void load(int l1, int i, F16::Data& v1, F16::Data& v2) const {
- int j = F16::block_size*i;
- auto dl = (const block_q4_0 *)Base::lblock(l1) + j/QK4_0;
-#ifdef __aarch64__
- auto vd = F16::set1(*(const float16_t *)&dl->d);
- auto q = vld1q_u8(dl->qs);
- q = j%QK4_0 ? vshrq_n_u8(q, 4) : vandq_u8(q, mask);
- q = vaddq_s8(q, m8);
- v1 = vmulq_f16(vd, vcvtq_f16_s16(vmovl_s8(vget_low_s8(q))));
- v2 = vmulq_f16(vd, vcvtq_f16_s16(vmovl_s8(vget_high_s8(q))));
-#else
- auto vd = F16::set1(GGML_FP16_TO_FP32(dl->d));
- auto q = _mm_loadu_si128((const __m128i *)dl->qs);
-#ifdef __AVX512F__
- auto ql = _mm_add_epi8(_mm_and_si128(q, mask), m8);
- auto qh = _mm_add_epi8(_mm_and_si128(_mm_srli_epi16(q, 4), mask), m8);
- v1 = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(ql)));
- v2 = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(qh)));
-#else
- if (j%QK4_0) q = _mm_srli_epi16(q, 4);
- auto q16 = _mm256_cvtepi8_epi16(_mm_add_epi8(_mm_and_si128(q, mask), m8));
- v1 = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_castsi256_si128(q16))));
- v2 = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(q16, 1))));
-#endif
-#endif
- }
-
-#ifdef __AVX2__
- const __m128i mask = _mm_set1_epi8(0xf);
- const __m128i m8 = _mm_set1_epi8(-8);
-#else
- const uint8x16_t mask = vdupq_n_u8(0xf);
- const int8x16_t m8 = vdupq_n_s8(-8);
-#endif
-};
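HelperQ40::load above is the SIMD form of the standard q4_0 decode: each nibble maps to (nibble - 8) * d, the low nibbles giving the first half of the 32-value block and the high nibbles the second half. A scalar reference sketch using the block_q4_0/QK4_0/GGML_FP16_TO_FP32 definitions from the ggml headers:

// Hedged sketch: scalar equivalent of one HelperQ40 block decode.
static void dequant_q4_0_block(const block_q4_0 * b, float * out /* QK4_0 floats */) {
    const float d = GGML_FP16_TO_FP32(b->d);
    for (int i = 0; i < QK4_0/2; ++i) {
        out[i]           = d * (float)((b->qs[i] & 0x0f) - 8);  // low nibble  -> values 0..15
        out[i + QK4_0/2] = d * (float)((b->qs[i] >>   4) - 8);  // high nibble -> values 16..31
    }
}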
-
-template <int D, int step>
-struct HelperQ41 final : public BaseHelper<step> {
- using Base = BaseHelper<step>;
- using block_q8 = block_q8_2;
- constexpr static int block_size_q = QK8_2;
- HelperQ41(const char * data, int stride) : Base(data, stride) {}
-
- // Needed for v * softmax(k * q)
- inline void load(int l1, int i, F16::Data& v1, F16::Data& v2) const {
- int j = F16::block_size*i;
- auto dl = (const block_q4_1 *)Base::lblock(l1) + j/QK4_1;
-#ifdef __aarch64__
- auto vd = F16::set1(*(const float16_t *)&dl->d);
- auto vm = F16::set1(*(const float16_t *)&dl->m);
- auto q = vld1q_u8(dl->qs);
- q = (j%QK4_1) ? vshrq_n_u8(q, 4) : vandq_u8(q, mask);
- v1 = vfmaq_f16(vm, vd, vcvtq_f16_u16(vmovl_u8(vget_low_u8(q))));
- v2 = vfmaq_f16(vm, vd, vcvtq_f16_u16(vmovl_u8(vget_high_u8(q))));
-#else
- auto vd = F16::set1(GGML_FP16_TO_FP32(dl->d));
- auto vm = F16::set1(GGML_FP16_TO_FP32(dl->m));
- auto q = _mm_loadu_si128((const __m128i *)dl->qs);
-#ifdef __AVX512F__
- auto ql = _mm_and_si128(q, mask);
- auto qh = _mm_and_si128(_mm_srli_epi16(q, 4), mask);
- v1 = _mm512_fmadd_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(ql)), vm);
- v2 = _mm512_fmadd_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(qh)), vm);
-#else
- if (j%QK4_1) q = _mm_srli_epi16(q, 4);
- auto q16 = _mm256_cvtepi8_epi16(_mm_and_si128(q, mask));
- v1 = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_castsi256_si128(q16))), vm);
- v2 = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(q16, 1))), vm);
-#endif
-#endif
- }
-
-#ifdef __aarch64__
- const uint8x16_t mask = vdupq_n_u8(0xf);
-#else
- const __m128i mask = _mm_set1_epi8(0xf);
-#endif
-};
-
-template <int D, int step>
-struct HelperIQ4nl final : public BaseHelper<step> {
- using Base = BaseHelper<step>;
-#ifdef __aarch64__
- using block_q8 = block_q8_0;
- HelperIQ4nl(const char * data, int stride) : Base(data, stride), values(vld1q_s8(iq4k_values)) {}
- constexpr static int block_size_q = QK8_0;
-#else
- HelperIQ4nl(const char * data, int stride) : Base(data, stride) {}
-#ifdef HAVE_FANCY_SIMD
- using block_q8 = block_q8_2;
- constexpr static int block_size_q = QK8_2;
-#else
- using block_q8 = block_q8_0;
- constexpr static int block_size_q = QK8_0;
-#endif
-#endif
-
- // Needed for v * softmax(k * q)
- inline void load(int l1, int i, F16::Data& v1, F16::Data& v2) const {
- int j = F16::block_size*i;
- auto dl = (const block_iq4_nl *)Base::lblock(l1) + j/QK4_0;
-#ifdef __aarch64__
- auto vd = F16::set1(*(const float16_t *)&dl->d);
- auto q = vld1q_u8(dl->qs);
- q = j%QK4_0 ? vshrq_n_u8(q, 4) : vandq_u8(q, mask);
- q = vqtbl1q_s8(values, q);
- v1 = vmulq_f16(vd, vcvtq_f16_s16(vmovl_s8(vget_low_s8(q))));
- v2 = vmulq_f16(vd, vcvtq_f16_s16(vmovl_s8(vget_high_s8(q))));
-#else
- auto vd = F16::set1(GGML_FP16_TO_FP32(dl->d));
- auto q = _mm_loadu_si128((const __m128i *)dl->qs);
-#ifdef __AVX512F__
- auto ql = _mm_shuffle_epi8(values, _mm_and_si128(q, mask));
- auto qh = _mm_shuffle_epi8(values, _mm_and_si128(_mm_srli_epi16(q, 4), mask));
- v1 = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(ql)));
- v2 = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(qh)));
-#else
- if (j%QK4_0) q = _mm_srli_epi16(q, 4);
- auto q16 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(values, _mm_and_si128(q, mask)));
- v1 = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_castsi256_si128(q16))));
- v2 = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(q16, 1))));
-#endif
-#endif
- }
-
-#ifdef __aarch64__
- const uint8x16_t mask = vdupq_n_u8(0xf);
- const int8x16_t values;
-#else
- const __m128i mask = _mm_set1_epi8(0xf);
- const __m128i values = _mm_loadu_si128((const __m128i *)iq4k_values);
-#endif
-};
-
-template <int D, int step>
-struct HelperQ60 final : public BaseHelper<step> {
-#ifdef __aarch64__
- using block_q8 = block_q8_0;
- constexpr static int block_size_q = QK8_0;
-#else
- using block_q8 = block_q8_2;
- constexpr static int block_size_q = QK8_2;
-#endif
- using Base = BaseHelper<step>;
- HelperQ60(const char * data, int stride) : Base(data, stride) {}
-
- // Needed for v * softmax(k * q)
- inline void load(int l1, int i, F16::Data& v1, F16::Data& v2) const {
- int j = F16::block_size*i;
- auto dl = (const block_q6_0 *)Base::lblock(l1) + j/QK6_0;
-#ifdef __aarch64__
- // TODO
- const float16_t * d16 = (const float16_t *)&dl->d;
- auto vd = F16::set1(d16[0]);
- //auto vd = F16::set1(*(const float16_t *)&dl->d);
- auto qh8 = vld1_u8(dl->qh);
- auto qh = vcombine_u8(vshl_n_u8(qh8, 4), qh8);
- auto qs = vld1q_u8(dl->qs);
- qs = j%QK4_0 ? vshrq_n_u8(qs, 4) : vandq_u8(qs, mask_l);
- qs = vorrq_u8(qs, vandq_u8(mask_h, j%QK4_0 ? vshrq_n_u8(qh, 2) : qh));
- qs = vaddq_s8(qs, m32);
- v1 = vmulq_f16(vd, vcvtq_f16_s16(vmovl_s8(vget_low_s8(qs))));
- v2 = vmulq_f16(vd, vcvtq_f16_s16(vmovl_s8(vget_high_s8(qs))));
-#else
- auto vd = F16::set1(GGML_FP16_TO_FP32(dl->d));
- auto bl = _mm_loadu_si128((const __m128i *)dl->qs);
- uint64_t aux64; std::memcpy(&aux64, dl->qh, 8);
- auto bh = _mm_set_epi64x(aux64, aux64 << 4);
-#ifdef __AVX512F__
- auto ql = _mm_add_epi8(_mm_or_si128(_mm_and_si128(bl, mask_l), _mm_and_si128(bh, mask_h)), m32);
- auto qh = _mm_add_epi8(_mm_or_si128(_mm_and_si128(_mm_srli_epi16(bl, 4), mask_l), _mm_and_si128(_mm_srli_epi16(bh, 2), mask_h)), m32);
- v1 = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(ql)));
- v2 = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(qh)));
-#else
- if (j%QK4_0) {
- bl = _mm_srli_epi16(bl, 4);
- bh = _mm_srli_epi16(bh, 2);
- }
- auto q16 = _mm256_cvtepi8_epi16(_mm_add_epi8(_mm_or_si128(_mm_and_si128(bl, mask_l), _mm_and_si128(bh, mask_h)), m32));
- v1 = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_castsi256_si128(q16))));
- v2 = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(q16, 1))));
-#endif
-#endif
- }
-
-#ifdef __AVX2__
- const __m128i mask_l = _mm_set1_epi8(0x0f);
- const __m128i mask_h = _mm_set1_epi8(0x30);
- const __m128i m32 = _mm_set1_epi8(-32);
-#else
- const uint8x16_t mask_l = vdupq_n_u8(0x0f);
- const uint8x16_t mask_h = vdupq_n_u8(0x30);
- const int8x16_t m32 = vdupq_n_s8(-32);
-#endif
-};
-
-template <int q_step, int k_step>
-struct FlashMS {
-// Something goes wrong when storing and manipulating K*Q as fp16.
-// It works for some models (e.g., Gemma-2), but not for others (e.g., LLaMA-3.1-8B).
-// As I wasn't able to find where we lose precision, let's comment this out
-// for now and do the K*Q part in fp32.
-//#ifdef __aarch64__
-// using cache_t = float16_t;
-//#else
-// using cache_t = float;
-//#endif
- using cache_t = float;
-
- FlashMS(float scale, float softcap) : vscale(F16::set1(scale)), softcap(softcap), h_inf(GGML_FP32_TO_FP16(-INFINITY)) {}
-
- inline void init_qstep() {
- for (int j = 0; j < q_step; ++j) {
- S[j] = 0; M[j] = -INFINITY;
- }
- }
-
- inline void update_M(int j, float smax) {
- if (smax == -INFINITY) {
- std::memset(cache + k_step*j, 0, k_step*sizeof(float));
- need_scaling[j] = M[j] == -INFINITY ? 2 : 0;
- return;
- }
- need_scaling[j] = 0;
- if (smax > M[j]) {
- if (M[j] > -INFINITY) {
- float m = expf(M[j] - smax);
- vms[j] = m;
- need_scaling[j] = 1;
- S[j] *= m;
- } else {
- need_scaling[j] = 2;
- S[j] = 0;
- }
- M[j] = smax;
- }
- }
-
-#ifdef __aarch64__
- inline void update_S(int j, float32x4_t * vk) {
- auto vm = vdupq_n_f32(M[j]);
- auto vsum = vdupq_n_f32(0);
- for (int l = 0; l < k_step/4; ++l) {
- vk[l] = v_expf(vsubq_f32(vk[l], vm));
- vsum = vaddq_f32(vsum, vk[l]);
- F16::store(cache + k_step*j + 4*l, vk[l]);
- }
- S[j] += vaddvq_f32(vsum);
- }
-#else
- inline void update_S(int j, F16::Data * vk) {
- auto vm = F16::set1(M[j]);
- for (int l = 0; l < k_step/F16::block_size; ++l) {
- vk[l] = v_expf(F16::sub(vk[l], vm));
- F16::store(cache + k_step*j + F16::block_size*l, vk[l]);
- }
- S[j] += F16::reduce_add<k_step>(vk);
- }
-#endif
-
-#ifdef __aarch64__
- inline float load_and_scale(int j, float32x4_t * vk) {
- float32x4_t vmax = vdupq_n_f32(-INFINITY);
- // Something goes wrong when storing and manipulating K*Q as fp16.
- // It works for some models (e.g., Gemma-2), but not for others (e.g., LLaMA-3.1-8B).
- // As I wasn't able to find where we lose precision, let's comment this out
- // for now and do the K*Q part in fp32.
- //if (softcap <= 0.0f) {
- // for (int l = 0; l < k_step/F16::block_size; ++l) {
- // auto val = F16::mul(vscale, F16::load(cache + k_step*j + F16::block_size*l));
- // vk[2*l+0] = vcvt_f32_f16(vget_low_f16(val));
- // vk[2*l+1] = vcvt_f32_f16(vget_high_f16(val));
- // vmax = vmaxq_f32(vmax, vmaxq_f32(vk[2*l+0], vk[2*l+1]));
- // }
- //} else {
- // auto v_softcap = vdupq_n_f32(softcap);
- // for (int l = 0; l < k_step/F16::block_size; ++l) {
- // auto val = F16::mul(vscale, F16::load(cache + k_step*j + F16::block_size*l));
- // vk[2*l+0] = vcvt_f32_f16(vget_low_f16(val));
- // vk[2*l+1] = vcvt_f32_f16(vget_high_f16(val));
- // vk[2*l+0] = vmulq_f32(v_softcap, v_tanh(vk[2*l+0]));
- // vk[2*l+1] = vmulq_f32(v_softcap, v_tanh(vk[2*l+1]));
- // vmax = vmaxq_f32(vmax, vmaxq_f32(vk[2*l+0], vk[2*l+1]));
- // }
- //}
- auto vscale32 = vcvt_f32_f16(vget_low_f16(vscale));
- if (softcap <= 0.0f) {
- for (int l = 0; l < k_step/4; ++l) {
- vk[l] = vmulq_f32(vscale32, vld1q_f32(cache + k_step*j + 4*l));
- vmax = vmaxq_f32(vmax, vk[l]);
- }
- } else {
- auto v_softcap = vdupq_n_f32(softcap);
- for (int l = 0; l < k_step/4; ++l) {
- vk[l] = vmulq_f32(vscale32, vld1q_f32(cache + k_step*j + 4*l));
- vk[l] = vmulq_f32(v_softcap, v_tanh(vk[l]));
- vmax = vmaxq_f32(vmax, vk[l]);
- }
- }
- return vmaxvq_f32(vmax);
- }
- inline float load_apply_mask_and_scale(int j, float32x4_t * vk, const char * mask) {
- auto vzero = vdupq_n_f16(0);
- auto vinf = vdupq_n_f32(-INFINITY);
- for (int l = 0; l < k_step/8; ++l) {
- auto vm = vceqq_f16(vzero, vld1q_f16((const float16_t *)mask + 8*l));
- auto vm1 = vzip1q_u16(vm, vm);
- auto vm2 = vzip2q_u16(vm, vm);
- auto kq = vld1q_f32_x2(cache + k_step*j + 8*l);
- vk[2*l+0] = vreinterpretq_f32_u32(vorrq_u32(vandq_u32(vreinterpretq_u32_f32(kq.val[0]), vm1),
- vbicq_u32(vreinterpretq_u32_f32(vinf), vm1)));
- vk[2*l+1] = vreinterpretq_f32_u32(vorrq_u32(vandq_u32(vreinterpretq_u32_f32(kq.val[1]), vm2),
- vbicq_u32(vreinterpretq_u32_f32(vinf), vm2)));
- }
- float32x4_t vmax = vdupq_n_f32(-INFINITY);
- auto vscale32 = vcvt_f32_f16(vget_low_f16(vscale));
- if (softcap <= 0.0f) {
- for (int l = 0; l < k_step/4; ++l) {
- vk[l] = vmulq_f32(vscale32, vk[l]);
- vmax = vmaxq_f32(vmax, vk[l]);
- }
- } else {
- auto v_softcap = vdupq_n_f32(softcap);
- for (int l = 0; l < k_step/4; ++l) {
- vk[l] = vmulq_f32(vscale32, vk[l]);
- vk[l] = vmulq_f32(v_softcap, v_tanh(vk[l]));
- vmax = vmaxq_f32(vmax, vk[l]);
- }
- }
- return vmaxvq_f32(vmax);
- }
-#else
- inline float load_and_scale(int j, F16::Data * vk) {
- if (softcap <= 0.0f) {
- for (int l = 0; l < k_step/F16::block_size; ++l) vk[l] = F16::mul(vscale, F16::load(cache + k_step*j + F16::block_size*l));
- } else {
- auto v_softcap = F16::set1(softcap);
- for (int l = 0; l < k_step/F16::block_size; ++l) {
- auto val = F16::load(cache + k_step*j + F16::block_size*l);
- vk[l] = F16::mul(v_softcap, v_tanh(F16::mul(vscale, val)));
- }
- }
- return F16::reduce_max<k_step>(vk);
- }
- static inline __m256 apply_mask(int l, const char * mask, __m256 val, [[maybe_unused]] __m256 vinf) {
- return _mm256_add_ps(val, _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)mask+l)));
- //auto m128 = _mm_loadu_si128((const __m128i *)mask+l);
- //m128 = _mm_cmpeq_epi16(m128, _mm_setzero_si128());
- //auto m256 = _mm256_cvtepi16_epi32(m128);
- //auto mf = _mm256_castsi256_ps(_mm256_or_si256(m256, _mm256_slli_epi32(m256, 16)));
- //return _mm256_or_ps(_mm256_and_ps(mf, val), _mm256_andnot_ps(mf, vinf));
- }
-#ifdef __AVX512F__
- static inline __m512 apply_mask(int l, const char * mask, __m512 val, __m512 vinf) {
- auto m256 = _mm256_loadu_si256((const __m256i *)mask+l);
- m256 = _mm256_cmpeq_epi16(m256, _mm256_setzero_si256());
- auto m512 = _mm512_cvtepi16_epi32(m256);
- auto mf = _mm512_castsi512_ps(_mm512_or_si512(m512, _mm512_slli_epi32(m512, 16)));
- return _mm512_or_ps(_mm512_and_ps(mf, val), _mm512_andnot_ps(mf, vinf));
- }
-#endif
- inline float load_apply_mask_and_scale(int j, F16::Data * vk, const char * mask) {
-#ifdef HAVE_FANCY_SIMD
- auto vzero = _mm256_set1_epi16(0);
- auto vinf = _mm512_set1_ps(-INFINITY);
- if (softcap <= 0) {
- for (int l = 0; l < k_step/F16::block_size; ++l) {
- auto m16 = _mm256_cmpeq_epi16_mask(_mm256_loadu_si256((const __m256i *)mask + l), vzero);
- vk[l] = _mm512_mask_mul_ps(vinf, m16, vscale, F16::load(cache + k_step*j + F16::block_size*l));
- }
- } else {
- auto v_softcap = F16::set1(softcap);
- for (int l = 0; l < k_step/F16::block_size; ++l) {
- auto m16 = _mm256_cmpeq_epi16_mask(_mm256_loadu_si256((const __m256i *)mask + l), vzero);
- vk[l] = _mm512_mask_mul_ps(vinf, m16, v_softcap, v_tanh(F16::mul(vscale, F16::load(cache + k_step*j + F16::block_size*l))));
- }
- }
-#else
- auto vinf = F16::set1(-INFINITY);
- for (int l = 0; l < k_step/F16::block_size; ++l) {
- vk[l] = apply_mask(l, mask, F16::load(cache + k_step*j + F16::block_size*l), vinf);
- }
- if (softcap <= 0) {
- for (int l = 0; l < k_step/F16::block_size; ++l) vk[l] = F16::mul(vscale, vk[l]);
- } else {
- auto v_softcap = F16::set1(softcap);
- for (int l = 0; l < k_step/F16::block_size; ++l) vk[l] = F16::mul(v_softcap, v_tanh(F16::mul(vscale, vk[l])));
- }
-#endif
- return F16::reduce_max<k_step>(vk);
- }
-#endif
-
-#ifdef __aarch64__
- inline void update_M_S(int j, float32x4_t * vk) {
- float smax = load_and_scale(j, vk);
- update_M(j, smax);
- if (M[j] > -INFINITY) update_S(j, vk);
- }
- inline void update_M_S(int j, float32x4_t * vk, const char * mask) {
- float smax = load_apply_mask_and_scale(j, vk, mask);
- update_M(j, smax);
- if (M[j] > -INFINITY) update_S(j, vk);
- }
-#else
- inline void update_M_S(int j, F16::Data * vk) {
- float smax = load_and_scale(j, vk);
- update_M(j, smax);
- if (M[j] > -INFINITY) update_S(j, vk);
- }
- inline void update_M_S(int j, F16::Data * vk, const char * mask) {
- float smax = load_apply_mask_and_scale(j, vk, mask);
- update_M(j, smax);
- if (M[j] > -INFINITY) update_S(j, vk);
- }
-#endif
-
- cache_t cache[q_step*k_step];
- float S[q_step], M[q_step];
- int need_scaling[q_step];
- float vms[q_step];
- const F16::Data vscale;
- const float softcap;
- const ggml_half h_inf;
-
-};
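FlashMS keeps, per query row, the running maximum M, the running softmax denominator S, and a per-step rescale factor vms; FlashQKV later uses need_scaling to either rescale (1) or clear (2) its V accumulator. A scalar sketch of the same recurrence, with names local to the sketch; cache holds one already scaled (and optionally soft-capped) row of K*Q scores:

// Hedged sketch: one online-softmax step over a k_step-wide block of scores.
#include <cmath>
#include <cstring>

static void online_softmax_step(float * cache, int k_step, float & M, float & S,
                                float & vms, int & need_scaling) {
    float smax = -INFINITY;
    for (int l = 0; l < k_step; ++l) smax = std::fmax(smax, cache[l]);
    if (smax == -INFINITY) {                    // fully masked block
        std::memset(cache, 0, k_step*sizeof(float));
        need_scaling = M == -INFINITY ? 2 : 0;  // 2: the V accumulator must be cleared
        return;
    }
    need_scaling = 0;
    if (smax > M) {
        if (M > -INFINITY) {
            vms = std::exp(M - smax);           // rescale factor for the V accumulator
            S  *= vms;
            need_scaling = 1;
        } else {
            S = 0;
            need_scaling = 2;
        }
        M = smax;
    }
    float sum = 0;
    for (int l = 0; l < k_step; ++l) {
        cache[l] = std::exp(cache[l] - M);      // un-normalized probabilities
        sum += cache[l];
    }
    S += sum;
}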
-
-template <int D, int q_step, int k_step>
-struct FlashQKV {
-
-#ifdef __aarch64__
- using qkv_cache_t = float16_t;
-#else
- using qkv_cache_t = float;
-#endif
-
- template <typename VHelper>
- inline void accumulate_qkv_1(const VHelper& vh, const FlashMS<q_step, k_step>& fms) {
- F16::Data vq[D/F16::block_size];
- if (fms.need_scaling[0] == 2) {
- for (int i = 0; i < D/F16::block_size; ++i) vq[i] = F16::zero();
- } else {
- for (int i = 0; i < D/F16::block_size; ++i) vq[i] = F16::load(qkv_cache + F16::block_size*i);
- if (fms.need_scaling[0] == 1) {
- auto vms = F16::set1(fms.vms[0]);
- for (int i = 0; i < D/F16::block_size; ++i) vq[i] = F16::mul(vms, vq[i]);
- }
- }
- F16::Data v0, v1;
- for (int l = 0; l < k_step; l += 4) {
- auto vs0 = F16::set1(fms.cache[l + 0]);
- auto vs1 = F16::set1(fms.cache[l + 1]);
- auto vs2 = F16::set1(fms.cache[l + 2]);
- auto vs3 = F16::set1(fms.cache[l + 3]);
- for (int i = 0; i < D/F16::block_size; i += 2) {
- vh.load(l+0, i, v0, v1);
- vq[i+0] = F16::fmadd(vq[i+0], v0, vs0);
- vq[i+1] = F16::fmadd(vq[i+1], v1, vs0);
- vh.load(l+1, i, v0, v1);
- vq[i+0] = F16::fmadd(vq[i+0], v0, vs1);
- vq[i+1] = F16::fmadd(vq[i+1], v1, vs1);
- vh.load(l+2, i, v0, v1);
- vq[i+0] = F16::fmadd(vq[i+0], v0, vs2);
- vq[i+1] = F16::fmadd(vq[i+1], v1, vs2);
- vh.load(l+3, i, v0, v1);
- vq[i+0] = F16::fmadd(vq[i+0], v0, vs3);
- vq[i+1] = F16::fmadd(vq[i+1], v1, vs3);
- }
- }
- for (int i = 0; i < D/F16::block_size; ++i) F16::store(qkv_cache + F16::block_size*i, vq[i]);
- }
-
-    // This fails for head sizes of 80 and 112 because D/16 is odd, so we cannot step by 2.
-    // Hence, for now, we do not handle head sizes of 80 and 112.
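-    // (With F16::block_size = 16 that is 80/16 = 5 and 112/16 = 7 blocks per row, both odd,
-    //  so the i += 2 loops below cannot cover the last block.)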
- template <typename VHelper>
- inline void accumulate_qkv(const VHelper& vh, const FlashMS<q_step, k_step>& fms) {
- if constexpr (q_step == 1) {
- accumulate_qkv_1(vh, fms);
- return;
- }
- for (int j = 0; j < q_step; ++j) {
- auto R = qkv_cache + D*j;
- if (fms.need_scaling[j] == 2) {
- std::memset(R, 0, D*sizeof(qkv_cache_t));
- }
- else if (fms.need_scaling[j] == 1) {
- auto vms = F16::set1(fms.vms[j]);
- for (int i = 0; i < D/F16::block_size; ++i) {
- F16::store(R + F16::block_size*i, F16::mul(vms, F16::load(R + F16::block_size*i)));
- }
- }
- }
-#ifdef __AVX512F__
- if constexpr ((D/F16::block_size)%4 == 0) {
- F16::Data v[16];
- F16::Data vs[4];
- for (int i = 0; i < D/F16::block_size; i += 4) {
- for (int l = 0; l < k_step; l += 4) {
- for (int k = 0; k < 4; ++k) {
- vh.load(l+k, i+0, v[4*k+0], v[4*k+1]);
- vh.load(l+k, i+2, v[4*k+2], v[4*k+3]);
- }
- for (int j = 0; j < q_step; ++j) {
- auto R = qkv_cache + D*j;
- auto s1 = F16::load(R + F16::block_size*(i+0));
- auto s2 = F16::load(R + F16::block_size*(i+1));
- auto s3 = F16::load(R + F16::block_size*(i+2));
- auto s4 = F16::load(R + F16::block_size*(i+3));
- F16::set4(fms.cache + k_step*j + l, vs);
- for (int k = 0; k < 4; ++k) {
- s1 = F16::fmadd(s1, v[4*k+0], vs[k]);
- s2 = F16::fmadd(s2, v[4*k+1], vs[k]);
- s3 = F16::fmadd(s3, v[4*k+2], vs[k]);
- s4 = F16::fmadd(s4, v[4*k+3], vs[k]);
- }
- F16::store(R + F16::block_size*(i+0), s1);
- F16::store(R + F16::block_size*(i+1), s2);
- F16::store(R + F16::block_size*(i+2), s3);
- F16::store(R + F16::block_size*(i+3), s4);
- }
- }
- }
- return;
- }
-#endif
- F16::Data v[8];
-#ifdef __AVX2__
- F16::Data vs[4];
-#endif
- for (int i = 0; i < D/F16::block_size; i += 2) {
- for (int l = 0; l < k_step; l += 4) {
- vh.load(l+0, i, v[0], v[4]);
- vh.load(l+1, i, v[1], v[5]);
- vh.load(l+2, i, v[2], v[6]);
- vh.load(l+3, i, v[3], v[7]);
- for (int j = 0; j < q_step; ++j) {
- auto R = qkv_cache + D*j;
- auto s1 = F16::load(R + F16::block_size*(i+0));
- auto s2 = F16::load(R + F16::block_size*(i+1));
-#ifdef __AVX2__
- F16::set4(fms.cache + k_step*j + l, vs);
- for (int k = 0; k < 4; ++k) {
- s1 = F16::fmadd(s1, v[k+0], vs[k]);
- s2 = F16::fmadd(s2, v[k+4], vs[k]);
- }
-#else
- auto vs = F16::set4(fms.cache + k_step*j + l);
- s1 = F16::fmadd_lane0(s1, v[0], vs);
- s2 = F16::fmadd_lane0(s2, v[4], vs);
- s1 = F16::fmadd_lane1(s1, v[1], vs);
- s2 = F16::fmadd_lane1(s2, v[5], vs);
- s1 = F16::fmadd_lane2(s1, v[2], vs);
- s2 = F16::fmadd_lane2(s2, v[6], vs);
- s1 = F16::fmadd_lane3(s1, v[3], vs);
- s2 = F16::fmadd_lane3(s2, v[7], vs);
-#endif
- F16::store(R + F16::block_size*(i+0), s1);
- F16::store(R + F16::block_size*(i+1), s2);
- }
- }
- }
- }
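-    // Scalar view of the accumulation above (illustrative): for each query row j and output
-    // dimension d,
-    //     qkv_cache[D*j + d] += sum over l in [0, k_step) of fms.cache[k_step*j + l] * V[l][d]
-    // where the accumulator row has first been zeroed (need_scaling[j] == 2) or rescaled by
-    // fms.vms[j] (need_scaling[j] == 1).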
-
- template <typename VHelper>
- inline void accumulate_qkv(int nq1, const VHelper& vh, const FlashMS<q_step, k_step>& fms) {
- if (nq1 == 1) {
- accumulate_qkv_1(vh, fms);
- return;
- }
- F16::Data v[8];
- for (int j = 0; j < nq1; ++j) {
- auto R = qkv_cache + D*j;
- if (fms.need_scaling[j] == 2) {
- std::memset(R, 0, D*sizeof(qkv_cache_t));
- }
- else if (fms.need_scaling[j] == 1) {
- auto vms = F16::set1(fms.vms[j]);
- for (int i = 0; i < D/F16::block_size; ++i) {
- F16::store(R + F16::block_size*i, F16::mul(vms, F16::load(R + F16::block_size*i)));
- }
- }
- }
- for (int i = 0; i < D/F16::block_size; i += 2) {
- for (int l = 0; l < k_step; l += 4) {
- vh.load(l+0, i, v[0], v[4]);
- vh.load(l+1, i, v[1], v[5]);
- vh.load(l+2, i, v[2], v[6]);
- vh.load(l+3, i, v[3], v[7]);
- for (int j = 0; j < nq1; ++j) {
- auto R = qkv_cache + D*j;
- auto s1 = F16::load(R + F16::block_size*(i+0));
- auto s2 = F16::load(R + F16::block_size*(i+1));
- auto vs = F16::set4(fms.cache + k_step*j + l);
- s1 = F16::fmadd_lane0(s1, v[0], vs);
- s2 = F16::fmadd_lane0(s2, v[4], vs);
- s1 = F16::fmadd_lane1(s1, v[1], vs);
- s2 = F16::fmadd_lane1(s2, v[5], vs);
- s1 = F16::fmadd_lane2(s1, v[2], vs);
- s2 = F16::fmadd_lane2(s2, v[6], vs);
- s1 = F16::fmadd_lane3(s1, v[3], vs);
- s2 = F16::fmadd_lane3(s2, v[7], vs);
- F16::store(R + F16::block_size*(i+0), s1);
- F16::store(R + F16::block_size*(i+1), s2);
- }
- }
- }
- }
-
- inline void normalize_and_store_1row(const FlashMS<q_step, k_step>& fms, int j, const qkv_cache_t * R, float * qkv) const {
- GGML_ASSERT(fms.S[j] > 0);
- auto norm = F16::set1(1/fms.S[j]);
- //auto norm = F16::set1(fms.S[j] > 0 ? 1/fms.S[j] : 0.f);
- for (int i = 0; i < D/F16::block_size; ++i) {
- auto r = F16::load(R + F16::block_size*i);
- F16::store(qkv + F16::block_size*i, F16::mul(norm, r));
- }
- }
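-    // i.e. each stored element is the accumulated sum divided by the softmax normalizer:
-    //     qkv[d] = R[d] / fms.S[j]   for d = 0, ..., D-1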
-
- inline void normalize_and_store(const FlashMS<q_step, k_step>& fms, int nq1, int stride_qkv, float * qkv, float * M, float * S) const {
- if (M && S) {
- std::memcpy(M, fms.M, nq1*sizeof(float));
- std::memcpy(S, fms.S, nq1*sizeof(float));
- auto R = qkv_cache;
- for (int j = 0; j < nq1; ++j) {
-#ifdef __aarch64__
- for (int i = 0; i < D/F16::block_size; ++i) {
- F16::store(qkv + F16::block_size*i, F16::load(R + F16::block_size*i));
- }
-#else
- std::memcpy(qkv, R, D*sizeof(float));
-#endif
- qkv += stride_qkv;
- R += D;
- }
- } else {
- auto R = qkv_cache;
- for (int j = 0; j < nq1; ++j) {
- normalize_and_store_1row(fms, j, R, qkv);
- qkv += stride_qkv;
- R += D;
- }
- }
- }
-
- inline void normalize_and_store(const FlashMS<q_step, k_step>& fms, int stride_qkv, float * qkv, float * M, float * S) const {
- if (M && S) {
- std::memcpy(M, fms.M, q_step*sizeof(float));
- std::memcpy(S, fms.S, q_step*sizeof(float));
- auto R = qkv_cache;
- for (int j = 0; j < q_step; ++j) {
-#ifdef __aarch64__
- for (int i = 0; i < D/F16::block_size; ++i) {
- F16::store(qkv + F16::block_size*i, F16::load(R + F16::block_size*i));
- }
-#else
- std::memcpy(qkv, R, D*sizeof(float));
-#endif
- qkv += stride_qkv;
- R += D;
- }
- } else {
- auto R = qkv_cache;
- for (int j = 0; j < q_step; ++j) {
- normalize_and_store_1row(fms, j, R, qkv);
- qkv += stride_qkv;
- R += D;
- }
- }
- }
-
- // qkv_cache_t qkv_cache[D*q_step];
- // The initializer is not actually required. But the compiler cannot figure out that when qkv_cache is
- // first used for q_step rows, fms.need_scaling[j] is always 2, which zeroes the content of qkv_cache.
- // As a result, we get an infinite stream of warnings about uninitialized variable use (one for each
- // combination of D, q_step, k_step), which is extremely annoying. Hence, I succumb to the trend of
- // constantly being saved by others (the compiler in this case), and add this 100% unnecessary initialization.
- qkv_cache_t qkv_cache[D*q_step]; // = {};
- //qkv_cache_t * qkv_cache;
-};
-
-template <int D, int q_step, int k_step>
-struct FlashQKfp32 {
- static_assert(D%F16::block_size == 0 && D <= 576);
- static_assert(k_step%F16::block_size == 0);
- static_assert(q_step <= 4 || q_step%4 == 0);
-
-#ifdef __AVX2__
- template <typename KHelper, typename q_float>
- static inline void multiply_mask_kq(const KHelper& kh, int stride_q, int stride_m, const q_float * q, const char * mask,
- FlashMS<q_step, k_step>& fms) {
-#ifdef HAVE_FANCY_SIMD
- constexpr int nrc_q = 8;
- constexpr int nrc_k = 8;
-#else
- // somewhat surprisingly, nrc_q = 4, nrc_k = 8 is better than nrc_q = 8, nrc_k = 4
- constexpr int nrc_q = 4;
- constexpr int nrc_k = 8;
-#endif
- constexpr int qrem = q_step - nrc_q*(q_step/nrc_q);
- constexpr int krem = k_step - nrc_k*(k_step/nrc_k);
- static_assert(krem == 0);
- DataInfo info{fms.cache, (const char *)q, k_step, stride_q*sizeof(q_float), 0, 1, nullptr};
- for (int iq = 0; iq < q_step/nrc_q; ++iq) {
- for (int ik = 0; ik < k_step/nrc_k; ++ik) {
- mul_mat_Qx_Qy_MxN_fa4<QFT<q_float, nrc_q>, QFT<ggml_half, nrc_k>>(D, kh.block, kh.stride, ik*nrc_k, info);
- }
- info.cur_y += nrc_q;
- }
- if constexpr (qrem > 0) {
- for (int ik = 0; ik < k_step/nrc_k; ++ik) {
- mul_mat_Qx_Qy_MxN_fa4<QFT<q_float, qrem>, QFT<ggml_half, nrc_k>>(D, kh.block, kh.stride, ik*nrc_k, info);
- }
- }
- F16::Data vk[k_step/F16::block_size];
- for (int j = 0; j < q_step; ++j) {
- fms.update_M_S(j, vk, mask + stride_m*j);
- }
- }
-#else
- template <typename KHelper, typename q_float>
- static inline void multiply_mask_kq(const KHelper& kh, int stride_q, int stride_m, const q_float * q, const char * mask,
- FlashMS<q_step, k_step>& fms) {
- constexpr int nrc_q = 4;
- constexpr int nrc_k = 6;
- constexpr int qrem = q_step - nrc_q*(q_step/nrc_q);
- constexpr int krem = k_step - nrc_k*(k_step/nrc_k);
- DataInfo info{fms.cache, (const char *)q, k_step, stride_q*sizeof(q_float), 0, 1, nullptr};
- for (int iq = 0; iq < q_step/nrc_q; ++iq) {
- for (int ik = 0; ik < k_step/nrc_k; ++ik) {
- mul_mat_f16_f16_NxN<nrc_q, nrc_k, true>(D, kh.block, kh.stride, ik*nrc_k, info);
- }
- if constexpr (krem > 0) {
- mul_mat_f16_f16_NxN<nrc_q, krem, true>(D, kh.block, kh.stride, k_step - krem, info);
- }
- info.cur_y += nrc_q;
- }
- if constexpr (qrem > 0) {
- for (int ik = 0; ik < k_step/nrc_k; ++ik) {
- mul_mat_f16_f16_NxN<qrem, nrc_k, true>(D, kh.block, kh.stride, ik*nrc_k, info);
- }
- if constexpr (krem > 0) {
- mul_mat_f16_f16_NxN<qrem, krem, true>(D, kh.block, kh.stride, k_step - krem, info);
- }
- }
- float32x4_t vk[k_step/4];
- for (int j = 0; j < q_step; ++j) {
- fms.update_M_S(j, vk, mask + stride_m*j);
- }
- }
-#endif
-
-#ifdef __AVX2__
- template <typename KHelper, typename q_float>
- static inline void multiply_mask_kq(int nq, const KHelper& kh, int stride_q, int stride_m, const q_float * q, const char * mask,
- FlashMS<q_step, k_step>& fms) {
-#ifdef HAVE_FANCY_SIMD
- constexpr int nrc_q = 8;
- constexpr int nrc_k = 8;
-#else
- // somewhat surprisingly, nrc_q = 4, nrc_k = 8 is better than nrc_q = 8, nrc_k = 4
- constexpr int nrc_q = 4;
- constexpr int nrc_k = 8;
-#endif
- static_assert(k_step%nrc_k == 0);
- int qrem = nq - nrc_q*(nq/nrc_q);
- DataInfo info{fms.cache, (const char *)q, k_step, stride_q*sizeof(q_float), 0, 1, nullptr};
- for (int iq = 0; iq < nq/nrc_q; ++iq) {
- for (int ik = 0; ik < k_step/nrc_k; ++ik) {
- mul_mat_Qx_Qy_MxN_fa4<QFT<q_float, nrc_q>, QFT<ggml_half, nrc_k>>(D, kh.block, kh.stride, ik*nrc_k, info);
- }
- info.cur_y += nrc_q;
- }
- if (qrem > 0) {
- switch (qrem) {
- case 1: {
- for (int ik = 0; ik < k_step/nrc_k; ++ik) {
- mul_mat_Qx_Qy_MxN_fa4<QFT<q_float, 1>, QFT<ggml_half, nrc_k>>(D, kh.block, kh.stride, ik*nrc_k, info);
- }
- } break;
- case 2: {
- for (int ik = 0; ik < k_step/nrc_k; ++ik) {
- mul_mat_Qx_Qy_MxN_fa4<QFT<q_float, 2>, QFT<ggml_half, nrc_k>>(D, kh.block, kh.stride, ik*nrc_k, info);
- }
- } break;
- case 3: {
- for (int ik = 0; ik < k_step/nrc_k; ++ik) {
- mul_mat_Qx_Qy_MxN_fa4<QFT<q_float, 3>, QFT<ggml_half, nrc_k>>(D, kh.block, kh.stride, ik*nrc_k, info);
- }
- } break;
-#ifdef HAVE_FANCY_SIMD
- case 4: {
- for (int ik = 0; ik < k_step/nrc_k; ++ik) {
- mul_mat_Qx_Qy_MxN_fa4<QFT<q_float, 4>, QFT<ggml_half, nrc_k>>(D, kh.block, kh.stride, ik*nrc_k, info);
- }
- } break;
- case 5: {
- for (int ik = 0; ik < k_step/nrc_k; ++ik) {
- mul_mat_Qx_Qy_MxN_fa4<QFT<q_float, 5>, QFT<ggml_half, nrc_k>>(D, kh.block, kh.stride, ik*nrc_k, info);
- }
- } break;
- case 6: {
- for (int ik = 0; ik < k_step/nrc_k; ++ik) {
- mul_mat_Qx_Qy_MxN_fa4<QFT<q_float, 6>, QFT<ggml_half, nrc_k>>(D, kh.block, kh.stride, ik*nrc_k, info);
- }
- } break;
- case 7: {
- for (int ik = 0; ik < k_step/nrc_k; ++ik) {
- mul_mat_Qx_Qy_MxN_fa4<QFT<q_float, 7>, QFT<ggml_half, nrc_k>>(D, kh.block, kh.stride, ik*nrc_k, info);
- }
- } break;
-#endif
- }
- }
- F16::Data vk[k_step/F16::block_size];
- for (int j = 0; j < nq; ++j) {
- fms.update_M_S(j, vk, mask + stride_m*j);
- }
- }
-#else
- template <typename KHelper, typename q_float>
- static inline void multiply_mask_kq(int nq, const KHelper& kh, int stride_q, int stride_m, const q_float * q, const char * mask,
- FlashMS<q_step, k_step>& fms) {
- constexpr int nrc_q = 4;
- constexpr int nrc_k = 6;
- constexpr int krem = k_step - nrc_k*(k_step/nrc_k);
- const int qrem = q_step - nrc_q*(q_step/nrc_q);
- DataInfo info{fms.cache, (const char *)q, k_step, stride_q*sizeof(q_float), 0, 1, nullptr};
- for (int iq = 0; iq < nq/nrc_q; ++iq) {
- for (int ik = 0; ik < k_step/nrc_k; ++ik) {
- mul_mat_f16_f16_NxN<nrc_q, nrc_k, true>(D, kh.block, kh.stride, ik*nrc_k, info);
- }
- if constexpr (krem > 0) {
- mul_mat_f16_f16_NxN<nrc_q, krem, true>(D, kh.block, kh.stride, k_step - krem, info);
- }
- info.cur_y += nrc_q;
- }
- switch (qrem) {
- case 0: break;
- case 1: {
- for (int ik = 0; ik < k_step/nrc_k; ++ik) {
- mul_mat_f16_f16_NxN<1, nrc_k, true>(D, kh.block, kh.stride, ik*nrc_k, info);
- }
- if constexpr (krem > 0) {
- mul_mat_f16_f16_NxN<1, krem, true>(D, kh.block, kh.stride, k_step - krem, info);
- }
- } break;
- case 2: {
- for (int ik = 0; ik < k_step/nrc_k; ++ik) {
- mul_mat_f16_f16_NxN<2, nrc_k, true>(D, kh.block, kh.stride, ik*nrc_k, info);
- }
- if constexpr (krem > 0) {
- mul_mat_f16_f16_NxN<2, krem, true>(D, kh.block, kh.stride, k_step - krem, info);
- }
- } break;
- case 3: {
- for (int ik = 0; ik < k_step/nrc_k; ++ik) {
- mul_mat_f16_f16_NxN<3, nrc_k, true>(D, kh.block, kh.stride, ik*nrc_k, info);
- }
- if constexpr (krem > 0) {
- mul_mat_f16_f16_NxN<3, krem, true>(D, kh.block, kh.stride, k_step - krem, info);
- }
- } break;
- }
- float32x4_t vk[k_step/4];
- for (int j = 0; j < q_step; ++j) {
- fms.update_M_S(j, vk, mask + stride_m*j);
- }
- }
-#endif
-
-#ifdef __aarch64__
- static inline void convert(int nq, int stride_q, const float * q, float16_t * q_f16) {
- for (int i = 0; i < nq; ++i) {
- for (int j = 0; j < D; j += 8) {
- auto val1_f32 = vld1q_f32(q + j + 0);
- auto val2_f32 = vld1q_f32(q + j + 4);
- auto val_f16 = vcombine_f16(vcvt_f16_f32(val1_f32), vcvt_f16_f32(val2_f32));
- vst1q_f16(q_f16 + j, val_f16);
- }
- q += stride_q;
- q_f16 += D;
- }
- }
-#endif
-
- template <typename KHelper>
- static inline std::pair<mul_mat_t, int> mul_mat_kernel(int nq) {
- constexpr int kMaxQ = 8;
-#define MAKE_FUNCS(mul_mat, n) \
- if (n >= kMaxQ) return std::make_pair(mul_mat, kMaxQ>, kMaxQ);\
- else {\
- switch (n) {\
- case 1: return std::make_pair(mul_mat, 1>, 1);\
- case 2: return std::make_pair(mul_mat, 2>, 2);\
- case 3: return std::make_pair(mul_mat, 3>, 3);\
- case 4: return std::make_pair(mul_mat, 4>, 4);\
- case 5: return std::make_pair(mul_mat, 5>, 5);\
- case 6: return std::make_pair(mul_mat, 6>, 6);\
- case 7: return std::make_pair(mul_mat, 7>, 7);\
- }\
- }
-#define MAKE_FUNCS_ONLY_NRC(mul_mat, n) \
- if (n >= kMaxQ) return std::make_pair(mul_mat<kMaxQ>, kMaxQ);\
- else {\
- switch (n) {\
- case 1: return std::make_pair(mul_mat<1>, 1);\
- case 2: return std::make_pair(mul_mat<2>, 2);\
- case 3: return std::make_pair(mul_mat<3>, 3);\
- case 4: return std::make_pair(mul_mat<4>, 4);\
- case 5: return std::make_pair(mul_mat<5>, 5);\
- case 6: return std::make_pair(mul_mat<6>, 6);\
- case 7: return std::make_pair(mul_mat<7>, 7);\
- }\
- }
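-        // Note: MAKE_FUNCS is invoked with an unterminated template argument list, e.g.
-        //     MAKE_FUNCS(mul_mat_qX_0_q8_0<DequantizerQ80, nq);
-        // The closing '>' of the last template parameter comes from the macro body, so the
-        // expansion yields mul_mat_qX_0_q8_0<DequantizerQ80, kMaxQ>, ..., mul_mat_qX_0_q8_0<DequantizerQ80, 1>.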
- if constexpr (std::is_same_v<KHelper, HelperQ80<D, k_step>>) {
-#ifdef __aarch64__
- MAKE_FUNCS(mul_mat_qX_0_q8_0<DequantizerQ80, nq);
-#else
-#ifdef HAVE_FANCY_SIMD
- if (nq == 1) return std::make_pair(mul_mat_qX_0_q8_2_Tx<Q8_0_1_Unpacker, 1, k_step>, 1);
- if (nq == 2) return std::make_pair(mul_mat_qX_0_q8_2_Tx<Q8_0_1_Unpacker, 2, k_step>, 2);
- if (nq == 4) return std::make_pair(mul_mat_qX_0_q8_2_Tx<Q8_0_1_Unpacker, 4, k_step>, 4);
- MAKE_FUNCS(mul_mat_qX_1_q8_2_T<Q8_0_1_Unpacker, nq);
-#else
- if (nq == 1) return std::make_pair(mul_mat_qX_0_q8_0_Tx<Q8_0_Unpacker, 1, k_step>, 1);
- if (nq == 2) return std::make_pair(mul_mat_qX_0_q8_0_Tx<Q8_0_Unpacker, 2, k_step>, 2);
- if (nq == 4) return std::make_pair(mul_mat_qX_0_q8_0_Tx<Q8_0_Unpacker, 4, k_step>, 4);
- MAKE_FUNCS(mul_mat_qX_0_q8_0_T<Q8_0_Unpacker, nq);
-#endif
-#endif
- }
- else if constexpr (std::is_same_v<KHelper, HelperQ8KV<D, k_step>>) {
-#ifdef __aarch64__
- if (nq%16 == 0) return std::make_pair(mul_mat_q8_KV_q8_KV<16>, 16);
- if (nq == 1) return std::make_pair(mul_mat_q8_KV_q8_KV_1, 1);
- MAKE_FUNCS_ONLY_NRC(mul_mat_q8_KV_q8_KV, nq);
-#else
- if (nq == 1) return std::make_pair(mul_mat_q8_KV_q8_KV_1<1>, 1);
-#ifdef HAVE_FANCY_SIMD
- if constexpr (D%32 == 0 && k_step%8 == 0) {
- if (nq%16 == 0) return std::make_pair(mul_mat_q8_KV_q8_KV_8<16>, 16);
- MAKE_FUNCS_ONLY_NRC(mul_mat_q8_KV_q8_KV_8, nq);
- } else {
- if (nq%16 == 0) return std::make_pair(mul_mat_q8_KV_q8_KV<16>, 16);
- }
-#endif
- MAKE_FUNCS_ONLY_NRC(mul_mat_q8_KV_q8_KV, nq);
-#endif
- }
- else if constexpr (std::is_same_v<KHelper, HelperQ80R8<D, k_step>>) {
-#ifdef __aarch64__
- MAKE_FUNCS_ONLY_NRC(mul_mat_q8_0_r8_q8_0, nq);
-#else
- MAKE_FUNCS_ONLY_NRC(mul_mat_q8_0_r8_q8_2, nq);
-#endif
- }
- else if constexpr (std::is_same_v<KHelper, HelperQ8KVR8<D, k_step>>) {
- MAKE_FUNCS_ONLY_NRC(mul_mat_q8_KV_r8_q8_KV, nq);
- }
- else if constexpr (std::is_same_v<KHelper, HelperQ60<D, k_step>>) {
-#ifdef __aarch64__
- MAKE_FUNCS(mul_mat_qX_0_q8_0<DequantizerQ60, nq);
-#else
- if (nq == 1) return std::make_pair(mul_mat_qX_0_q8_2_Tx<Q6_0_1_Unpacker, 1, k_step>, 1);
- if (nq == 2) return std::make_pair(mul_mat_qX_0_q8_2_Tx<Q6_0_1_Unpacker, 2, k_step>, 2);
- if (nq == 4) return std::make_pair(mul_mat_qX_0_q8_2_Tx<Q6_0_1_Unpacker, 4, k_step>, 4);
- MAKE_FUNCS(mul_mat_qX_1_q8_2_T<Q6_0_1_Unpacker, nq);
-#endif
- }
- else if constexpr (std::is_same_v<KHelper, HelperQ40<D, k_step>>) {
-#ifdef __aarch64__
- MAKE_FUNCS(mul_mat_qX_0_q8_0<DequantizerQ40, nq);
-#else
- if (nq == 1) return std::make_pair(mul_mat_qX_0_q8_2_Tx<Q4_0_1_Unpacker, 1, k_step>, 1);
- if (nq == 2) return std::make_pair(mul_mat_qX_0_q8_2_Tx<Q4_0_1_Unpacker, 2, k_step>, 2);
- if (nq == 4) return std::make_pair(mul_mat_qX_0_q8_2_Tx<Q4_0_1_Unpacker, 4, k_step>, 4);
- MAKE_FUNCS(mul_mat_qX_1_q8_2_T<Q4_0_1_Unpacker, nq);
-#endif
- }
-#if GGML_IQK_FA_ALL_QUANTS
- else if constexpr (std::is_same_v<KHelper, HelperQ41<D, k_step>>) {
-#ifdef __aarch64__
- MAKE_FUNCS(mul_mat_qX_1_q8_1<DequantizerQ41, nq);
-#else
- MAKE_FUNCS(mul_mat_qX_1_q8_2_T<Q4_1_Unpacker, nq);
-#endif
- }
- else if constexpr (std::is_same_v<KHelper, HelperIQ4nl<D, k_step>>) {
-#ifdef __aarch64__
- MAKE_FUNCS(mul_mat_qX_0_q8_0<DequantizerIQ4NL, nq);
-#else
-#ifdef HAVE_FANCY_SIMD
- MAKE_FUNCS(mul_mat_qX_1_q8_2_T<IQ4_NL_Unpacker, nq);
-#else
- MAKE_FUNCS(mul_mat_qX_0_q8_0_T<IQ4_NL_Unpacker, nq);
-#endif
-#endif
- }
-#endif
- else {
- GGML_ASSERT(false);
- }
- return std::make_pair<mul_mat_t, int>(nullptr, 0);
- }
-
- template <typename KHelper, typename block_q8>
- static inline void mul_mask_kq(const KHelper& kh, int stride_m,
- const block_q8 * q, const char * mask, FlashMS<q_step, k_step>& fms) {
- constexpr int kMaxQ = 8;
- static_assert(q_step < kMaxQ || q_step%kMaxQ == 0);
- auto [mul_mat, nrc_q] = mul_mat_kernel<KHelper>(q_step);
- DataInfo info{fms.cache, (const char *)q, k_step, (D/KHelper::block_size_q)*sizeof(block_q8), 0, 1, nullptr};
- for (int iq = 0; iq < q_step/nrc_q; ++iq) {
- mul_mat(D, kh.block, kh.stride, info, k_step);
- info.cur_y += nrc_q;
- }
-#ifdef __aarch64__
- float32x4_t vk[k_step/4];
- for (int j = 0; j < q_step; ++j) {
- fms.update_M_S(j, vk, mask + stride_m*j);
- }
-#else
- F16::Data vk[k_step/F16::block_size];
- for (int j = 0; j < q_step; ++j) {
- fms.update_M_S(j, vk, mask + stride_m*j);
- }
-#endif
- }
-
- template <typename KHelper, typename block_q8>
- static inline void mul_mask_kq(int nq, const KHelper& kh, int stride_m,
- const block_q8 * q, const char * mask, FlashMS<q_step, k_step>& fms) {
- auto [mul_mat, nrc_q] = mul_mat_kernel<KHelper>(nq);
- DataInfo info{fms.cache, (const char *)q, k_step, (D/KHelper::block_size_q)*sizeof(block_q8), 0, 1, nullptr};
- for (int iq = 0; iq < nq/nrc_q; ++iq) {
- mul_mat(D, kh.block, kh.stride, info, k_step);
- info.cur_y += nrc_q;
- }
- int iq = nrc_q*(nq/nrc_q);
- if (iq < nq) {
- auto [mul_mat1, nrc_q1] = mul_mat_kernel<KHelper>(nq - iq);
- GGML_ASSERT(nrc_q1 == nq - iq);
- mul_mat1(D, kh.block, kh.stride, info, k_step);
- }
-#ifdef __aarch64__
- float32x4_t vk[k_step/4];
- for (int j = 0; j < nq; ++j) {
- fms.update_M_S(j, vk, mask + stride_m*j);
- }
-#else
- F16::Data vk[k_step/F16::block_size];
- for (int j = 0; j < nq; ++j) {
- fms.update_M_S(j, vk, mask + stride_m*j);
- }
-#endif
- }
-};
-
-template <int Dk, int Dv, int q_step, int k_step, typename KHelper, typename VHelper, typename KQHelper>
-void compute_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv,
- FlashMS<q_step, k_step>& fms,
- FlashQKV<Dv, q_step, k_step>& fqkv,
- const float * q, const char * mask, float * qkv,
- float * M, float * S) {
-#ifdef __aarch64__
- float16_t q_f16[Dk*q_step];
-#endif
-
- for (int i1 = 0; i1 < nq1/q_step; ++i1) {
- fms.init_qstep();
- kh.reset_block();
- vh.reset_block();
-#ifdef __aarch64__
- KQHelper::convert(q_step, stride_q, q, q_f16);
-#endif
- auto mr = mask;
- for (int k1 = 0; k1 < nk1/k_step; ++k1) {
-#ifdef __aarch64__
- KQHelper::multiply_mask_kq(kh, Dk, stride_m, q_f16, mr, fms);
-#else
- KQHelper::multiply_mask_kq(kh, stride_q, stride_m, q, mr, fms);
-#endif
- fqkv.accumulate_qkv(vh, fms);
- kh.next_block();
- vh.next_block();
- mr += k_step*sizeof(ggml_half);
- }
- fqkv.normalize_and_store(fms, stride_qkv, qkv, M, S);
-
- q += q_step*stride_q;
- mask += q_step*stride_m;
- qkv += q_step*stride_qkv;
- if (M && S) { M += q_step; S += q_step; }
- }
- int n_left = nq1 - q_step*(nq1/q_step);
- if (n_left > 0) {
- fms.init_qstep();
- kh.reset_block();
- vh.reset_block();
-#ifdef __aarch64__
- KQHelper::convert(n_left, stride_q, q, q_f16);
-#endif
- auto mr = mask;
- for (int k1 = 0; k1 < nk1/k_step; ++k1) {
-#ifdef __aarch64__
- KQHelper::multiply_mask_kq(n_left, kh, Dk, stride_m, q_f16, mr, fms);
-#else
- KQHelper::multiply_mask_kq(n_left, kh, stride_q, stride_m, q, mr, fms);
-#endif
- fqkv.accumulate_qkv(n_left, vh, fms);
- kh.next_block();
- vh.next_block();
- mr += k_step*sizeof(ggml_half);
- }
- fqkv.normalize_and_store(fms, n_left, stride_qkv, qkv, M, S);
- }
-}
-
-template <int Dk, int Dv, int q_step, int k_step, typename KHelper, typename VHelper, typename KQHelper>
-void compute_helper_q(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv,
- FlashMS<q_step, k_step>& fms,
- FlashQKV<Dv, q_step, k_step>& fqkv,
- const float * q, const char * mask, float * qkv,
- float * M, float * S, char * qptr) {
- auto q8 = (typename KHelper::block_q8 *)qptr;
- if constexpr (q_step > 1 && std::is_same_v<KHelper, HelperQ80<Dk, k_step>>) {
- if (nq1 == q_step) {
- fms.init_qstep();
- kh.reset_block();
- vh.reset_block();
- block_q8_0_r8 q8r8[Dk/QK8_0 * k_step/8];
- HelperQ80R8<Dk, k_step> khr8((const char *)q8r8, Dk/QK8_0*sizeof(block_q8_0));
- auto q8r = (typename HelperQ80R8<Dk, k_step>::block_q8 *)qptr;
- HelperQ80<Dk, QK8_0>::convert(q_step, stride_q, q, q8r);
- auto mr = mask;
- for (int k1 = 0; k1 < nk1/k_step; ++k1) {
- HelperQ80R8<Dk, k_step>::repack(k_step, kh.block, kh.stride, q8r8);
- KQHelper::mul_mask_kq(khr8, stride_m, q8r, mr, fms);
- fqkv.accumulate_qkv(vh, fms);
- kh.next_block();
- vh.next_block();
- mr += k_step*sizeof(ggml_half);
- }
- fqkv.normalize_and_store(fms, stride_qkv, qkv, M, S);
- return;
- }
- }
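-    // (The branch above covers the common nq1 == q_step case with a q8_0 K cache: every K block is
-    //  repacked on the fly into the 8-row-interleaved q8_0_r8 layout so that the repacked GEMM
-    //  kernel can be used. That HelperQ80R8::repack() produces this layout is inferred from its
-    //  use here rather than stated explicitly in this file.)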
-#if FA_TIMING
- Perf perf(false);
-#endif
- for (int i1 = 0; i1 < nq1/q_step; ++i1) {
-#if FA_TIMING
- auto t1 = Perf::cur_time();
-#endif
- fms.init_qstep();
- kh.reset_block();
- vh.reset_block();
- HelperQ80<Dk, QK8_0>::convert(q_step, stride_q, q, q8);
-#if FA_TIMING
- perf.accum_nolock(0, t1);
-#endif
- auto mr = mask;
- for (int k1 = 0; k1 < nk1/k_step; ++k1) {
-#if FA_TIMING
- t1 = Perf::cur_time();
- KQHelper::mul_mask_kq(kh, stride_m, q8, mr, fms);
- perf.accum_nolock(1, t1);
- t1 = Perf::cur_time();
- fqkv.accumulate_qkv(vh, fms);
- perf.accum_nolock(2, t1);
-#else
- KQHelper::mul_mask_kq(kh, stride_m, q8, mr, fms);
- fqkv.accumulate_qkv(vh, fms);
-#endif
- kh.next_block();
- vh.next_block();
- mr += k_step*sizeof(ggml_half);
- }
-#if FA_TIMING
- t1 = Perf::cur_time();
- fqkv.normalize_and_store(fms, stride_qkv, qkv, M, S);
- perf.accum_nolock(3, t1);
-#else
- fqkv.normalize_and_store(fms, stride_qkv, qkv, M, S);
-#endif
-
- q += q_step*stride_q;
- mask += q_step*stride_m;
- qkv += q_step*stride_qkv;
- if (M && S) { M += q_step; S += q_step; }
- }
- int n_left = nq1 - q_step*(nq1/q_step);
- if (n_left > 0) {
- fms.init_qstep();
- kh.reset_block();
- vh.reset_block();
- HelperQ80<Dk, QK8_0>::convert(n_left, stride_q, q, q8);
- auto mr = mask;
- for (int k1 = 0; k1 < nk1/k_step; ++k1) {
- KQHelper::mul_mask_kq(n_left, kh, stride_m, q8, mr, fms);
- fqkv.accumulate_qkv(n_left, vh, fms);
- kh.next_block();
- vh.next_block();
- mr += k_step*sizeof(ggml_half);
- }
- fqkv.normalize_and_store(fms, n_left, stride_qkv, qkv, M, S);
- }
-#if FA_TIMING
- Perf::instance().add(perf);
-#endif
-}
-
-char * get_q_storage(size_t size) {
- thread_local std::vector<char> q_storage;
- if (q_storage.size() < size) q_storage.resize(size);
- return q_storage.data();
-}
-
-// Some of the methods in FlashAttn have two nearly identical implementations that differ only in
-// that one version uses a loop over the template parameter q_step, while the other uses a loop
-// over an input parameter nq (these are loops over the rows of q^T). I dislike this a lot,
-// but performance drops significantly if I remove the version with fixed q_step iterations.
-// We only instantiate FlashAttn with q_step = 1 and q_step = 4 or 8 (depending on head size D),
-// so when we have to process Nq rows, we process q_step*(Nq/q_step) of them using fixed q_step loops,
-// and use the variable nq version (with lower performance) only for the remaining 1...q_step-1
-// rows (if Nq is not a multiple of q_step). One could have made the number of q^T rows to
-// process a template parameter of such functions, but this would result in the compiler generating
-// q_step-1 versions of these functions for us, which I thought was too much with q_step = 8.
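-// As a concrete illustration (numbers only, nothing from this file): with q_step = 8 and Nq = 21,
-// the fixed-q_step path handles 8*(21/8) = 16 rows in two blocks of 8, and the variable-nq
-// fallback handles the remaining 21 - 16 = 5 rows in a single call.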
-template <int Dk, int Dv, int q_step, int k_step>
-struct FlashAttn {
- static_assert(Dk%F16::block_size == 0 && Dk <= 576);
- static_assert(Dv%F16::block_size == 0 && Dv <= 512);
- static_assert(k_step%F16::block_size == 0);
- static_assert(q_step <= 4 || q_step%4 == 0);
-
- FlashAttn(float scale, float softcap) : fms(scale, softcap) {}
-
- template <typename KHelper, typename VHelper>
- void compute(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv,
- const float * q, const char * mask, float * qkv, [[maybe_unused]] float * M, [[maybe_unused]] float * S) {
- if constexpr (std::is_same_v<KHelper, HelperQ40<Dk, k_step>> ||
- std::is_same_v<KHelper, HelperQ41<Dk, k_step>> ||
- std::is_same_v<KHelper, HelperIQ4nl<Dk, k_step>> ||
- std::is_same_v<KHelper, HelperQ60<Dk, k_step>> ||
- std::is_same_v<KHelper, HelperQ80R8<Dk, k_step>> ||
- std::is_same_v<KHelper, HelperQ80<Dk, k_step>> ||
- std::is_same_v<KHelper, HelperQ8KV<Dk, k_step>> ||
- std::is_same_v<KHelper, HelperQ8KVR8<Dk, k_step>>) {
- constexpr size_t kMaxOnStackSize = 576;
- //auto q_size = q_step*(Dk/KHelper::block_size_q)*sizeof(typename KHelper::block_q8);
- auto q_size = q_step*(Dk/QK8_2*sizeof(block_q8_2));
- q_size = GGML_PAD(q_size, 64);
- if (q_size > kMaxOnStackSize) {
- auto qptr = get_q_storage(q_size);
- if (false && nq1 >= 8) {
- if constexpr (std::is_same_v<KHelper, HelperQ80<Dk, k_step>>) {
-#if FA_TIMING
- auto t1 = Perf::cur_time();
- HelperQ80R8<Dk, k_step> khr4(nk1, kh);
- Perf::instance().accum(4, t1);
-#else
- HelperQ80R8<Dk, k_step> khr4(nk1, kh);
-#endif
- compute_helper_q<Dk, Dv, q_step, k_step, HelperQ80R8<Dk, k_step>, VHelper, FlashQKfp32<Dk, q_step, k_step>>(
- khr4, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, M, S, qptr);
- return;
-
- }
- if constexpr (std::is_same_v<KHelper, HelperQ8KV<Dk, k_step>>) {
-#if FA_TIMING
- auto t1 = Perf::cur_time();
- HelperQ8KVR8<Dk, k_step> khr4(nk1, kh);
- Perf::instance().accum(4, t1);
-#else
- HelperQ8KVR8<Dk, k_step> khr4(nk1, kh);
-#endif
- compute_helper_q<Dk, Dv, q_step, k_step, HelperQ8KVR8<Dk, k_step>, VHelper, FlashQKfp32<Dk, q_step, k_step>>(
- khr4, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, M, S, qptr);
- return;
- }
- }
- compute_helper_q<Dk, Dv, q_step, k_step, KHelper, VHelper, FlashQKfp32<Dk, q_step, k_step>>(
- kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, M, S, qptr);
-
- }
- else {
- typename KHelper::block_q8 q8[q_step*(Dk/KHelper::block_size_q)];
- compute_helper_q<Dk, Dv, q_step, k_step, KHelper, VHelper, FlashQKfp32<Dk, q_step, k_step>>(
- kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, M, S, (char *)q8);
- }
- }
- else {
- compute_helper<Dk, Dv, q_step, k_step, KHelper, VHelper, FlashQKfp32<Dk, q_step, k_step>>(
- kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, M, S);
- }
- }
-
- FlashMS<q_step, k_step> fms;
- FlashQKV<Dv, q_step, k_step> fqkv;
-
-};
-
-#ifdef __AVX512BF16__
-
-template <int D, int step>
-struct HelperBF16 final : public BaseHelper<step> {
- using Base = BaseHelper<step>;
- HelperBF16(const char * data, int stride) : Base(data, stride) {}
- inline void load(int l1, __m512bh * vk) const {
- auto dr = Base::lblock(l1);
- for (int i = 0; i < D/32; ++i) vk[i] = __m512bh(_mm512_loadu_si512((const __m512i*)dr + i));
- }
-
- inline void load(int l1, int i, __m512& v1, __m512& v2) const {
- auto dr = Base::lblock(l1);
- v1 = _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(_mm256_loadu_si256((const __m256i *)dr + i + 0)), 16));
- v2 = _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(_mm256_loadu_si256((const __m256i *)dr + i + 1)), 16));
- }
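-    // (Zero-extending each bf16 value to 32 bits and shifting left by 16 reinterprets it as the
-    //  corresponding float, since bf16 is the upper half of an IEEE-754 binary32.)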
-
- inline void load_2(int l1, __m512bh * vk) const {
- load(l1+0, vk+0);
- load(l1+1, vk+D/32);
- }
-
- inline void load_4(int l1, __m512bh * vk) const {
- load(l1+0, vk+0);
- load(l1+1, vk+1*D/32);
- load(l1+2, vk+2*D/32);
- load(l1+3, vk+3*D/32);
- }
-
- inline void load_8(int l1, __m512bh * vk) const {
- for (int k = 0; k < 8; ++k) load(l1 + k, vk + k*D/32);
- }
-};
-
-template <int D, int q_step, int k_step>
-struct FlashQKbf16 {
- //static_assert(D%32 == 0 && D <= 256);
- static_assert(D%32 == 0 && D <= 576);
- static_assert(k_step%32 == 0);
- static_assert(q_step <= 4 || q_step%4 == 0);
-
- static inline void mult_mask_kq_one(int l1, int m1, int stride_q, int stride_m, const float * q, const char * mask,
- __m512bh * qv, const __m512bh * vkh, FlashMS<q_step, k_step>& fms) {
- // q index is q_step*i1 + m1
- // k index is k_step*k1 + l1
- const ggml_half * mp = (const ggml_half *)(mask + stride_m*m1);
- fms.cache[k_step*m1 + l1 + 0] = fms.cache[k_step*m1 + l1 + 1] = -INFINITY;
- if (mp[l1+0] == fms.h_inf && mp[l1+1] == fms.h_inf) {
- return;
- }
- auto qr = q + m1*stride_q;
- for (int i = 0; i < D/32; ++i) {
- auto val1 = _mm512_loadu_ps(qr + 32*i);
- auto val2 = _mm512_loadu_ps(qr + 32*i + 16);
- qv[i] = _mm512_cvtne2ps_pbh(val2, val1);
- }
- if (mp[l1+0] != fms.h_inf) {
- auto vsum = _mm512_setzero_ps();
- for (int i = 0; i < D/32; ++i) vsum = _mm512_dpbf16_ps(vsum, vkh[i], qv[i]);
- fms.cache[k_step*m1 + l1 + 0] = _mm512_reduce_add_ps(vsum);
- }
- if (mp[l1+1] != fms.h_inf) {
- auto vsum = _mm512_setzero_ps();
- for (int i = 0; i < D/32; ++i) vsum = _mm512_dpbf16_ps(vsum, vkh[i+D/32], qv[i]);
- fms.cache[k_step*m1 + l1 + 1] = _mm512_reduce_add_ps(vsum);
- }
- }
-
- static inline void mult_mask_kq_one(int l1, int m1, int stride_m, const ggml_bf16_t * q, const char * mask,
- __m512bh * qv, const __m512bh * vkh, FlashMS<q_step, k_step>& fms) {
- // q index is q_step*i1 + m1
- // k index is k_step*k1 + l1
- const ggml_half * mp = (const ggml_half *)(mask + stride_m*m1);
- fms.cache[k_step*m1 + l1 + 0] = fms.cache[k_step*m1 + l1 + 1] = -INFINITY;
- if (mp[l1+0] == fms.h_inf && mp[l1+1] == fms.h_inf) {
- return;
- }
- auto qr = q + m1*D;
- for (int i = 0; i < D/32; ++i) qv[i] = __m512bh(_mm512_loadu_si512((const __m512i*)qr + i));
- if (mp[l1+0] != fms.h_inf) {
- auto vsum = _mm512_setzero_ps();
- for (int i = 0; i < D/32; ++i) vsum = _mm512_dpbf16_ps(vsum, vkh[i], qv[i]);
- fms.cache[k_step*m1 + l1 + 0] = _mm512_reduce_add_ps(vsum);
- }
- if (mp[l1+1] != fms.h_inf) {
- auto vsum = _mm512_setzero_ps();
- for (int i = 0; i < D/32; ++i) vsum = _mm512_dpbf16_ps(vsum, vkh[i+D/32], qv[i]);
- fms.cache[k_step*m1 + l1 + 1] = _mm512_reduce_add_ps(vsum);
- }
- }
-
- static inline void mult_mask_kq_4(int l1, int m1, int stride_q, int stride_m, const float * q, const char * mask,
- __m512bh * qv, const __m512bh * vkh, FlashMS<q_step, k_step>& fms) {
- // q index is q_step*i1 + m1
- // k index is k_step*k1 + l1
- const ggml_half * mp = (const ggml_half *)(mask + stride_m*m1);
- fms.cache[k_step*m1 + l1 + 0] = fms.cache[k_step*m1 + l1 + 1] =
- fms.cache[k_step*m1 + l1 + 2] = fms.cache[k_step*m1 + l1 + 3] = -INFINITY;
- if (mp[l1+0] == fms.h_inf && mp[l1+1] == fms.h_inf && mp[l1+2] == fms.h_inf && mp[l1+3] == fms.h_inf) {
- return;
- }
- auto qr = q + m1*stride_q;
- for (int i = 0; i < D/32; ++i) {
- auto val1 = _mm512_loadu_ps(qr + 32*i);
- auto val2 = _mm512_loadu_ps(qr + 32*i + 16);
- qv[i] = _mm512_cvtne2ps_pbh(val2, val1);
- }
- for (int k = 0; k < 4; ++k) {
- if (mp[l1+k] == fms.h_inf) continue;
- auto vsum = _mm512_setzero_ps();
- for (int i = 0; i < D/32; ++i) vsum = _mm512_dpbf16_ps(vsum, vkh[i+k*(D/32)], qv[i]);
- fms.cache[k_step*m1 + l1 + k] = _mm512_reduce_add_ps(vsum);
- }
- }
-
- static inline void mult_mask_kq_4(int l1, int m1, int stride_m, const ggml_bf16_t * q, const char * mask,
- __m512bh * qv, const __m512bh * vkh, FlashMS<q_step, k_step>& fms) {
- // q index is q_step*i1 + m1
- // k index is k_step*k1 + l1
- const ggml_half * mp = (const ggml_half *)(mask + stride_m*m1);
- fms.cache[k_step*m1 + l1 + 0] = fms.cache[k_step*m1 + l1 + 1] =
- fms.cache[k_step*m1 + l1 + 2] = fms.cache[k_step*m1 + l1 + 3] = -INFINITY;
- if (mp[l1+0] == fms.h_inf && mp[l1+1] == fms.h_inf && mp[l1+2] == fms.h_inf && mp[l1+3] == fms.h_inf) {
- return;
- }
- auto qr = q + m1*D;
- for (int i = 0; i < D/32; ++i) qv[i] = __m512bh(_mm512_loadu_si512((const __m512i *)qr + i));
- for (int k = 0; k < 4; ++k) {
- if (mp[l1+k] == fms.h_inf) continue;
- auto vsum = _mm512_setzero_ps();
- for (int i = 0; i < D/32; ++i) vsum = _mm512_dpbf16_ps(vsum, vkh[i+k*(D/32)], qv[i]);
- fms.cache[k_step*m1 + l1 + k] = _mm512_reduce_add_ps(vsum);
- }
- }
-
- static inline __m128 hsum_float_4x4(__m128 * a) {
- for (int i = 0; i < 2; ++i) a[i] = _mm_add_ps(_mm_unpacklo_ps(a[i], a[i+2]), _mm_unpackhi_ps(a[i], a[i+2]));
- return _mm_add_ps(_mm_unpacklo_ps(a[0], a[1]), _mm_unpackhi_ps(a[0], a[1]));
- }
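-    // hsum_float_4x4() transposes and adds four __m128 accumulators so that lane k of the result
-    // is the horizontal sum of a[k]; hsum_float_8x8() below applies the same idea to eight __m256
-    // accumulators.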
-
- template <typename KHelper>
- static inline void multiply_mask_kq(const KHelper& kh, int stride_q, int stride_m, const float * q,
- const char * mask, FlashMS<q_step, k_step>& fms) {
- {
- __m512bh qv[D/32];
- if constexpr (D <= 128) {
- __m512bh vkh[D/8];
- for (int l1 = 0; l1 < k_step; l1 += 4) {
- kh.load_4(l1, vkh);
- for (int j = 0; j < q_step; ++j) {
- mult_mask_kq_4(l1, j, stride_q, stride_m, q, mask, qv, vkh, fms);
- }
- }
- } else {
- __m512bh vkh[D/16];
- for (int l1 = 0; l1 < k_step; l1 += 2) {
- kh.load_2(l1, vkh);
- for (int j = 0; j < q_step; ++j) {
- mult_mask_kq_one(l1, j, stride_q, stride_m, q, mask, qv, vkh, fms);
- }
- }
- }
- }
- __m512 vk[k_step/16];
- for (int j = 0; j < q_step; ++j) {
- fms.update_M_S(j, vk);
- }
- }
-
- static inline void mult_mask_kq_4(int l1, int m1, const ggml_bf16_t * q,
- __m512bh * qv, const __m512bh * vkh, FlashMS<q_step, k_step>& fms) {
- auto qr = q + m1*D;
- for (int i = 0; i < D/32; ++i) qv[i] = __m512bh(_mm512_loadu_si512((const __m512i *)qr + i));
- __m128 sum[4];
- for (int k = 0; k < 4; ++k) {
- auto vsum = _mm512_setzero_ps();
- for (int i = 0; i < D/32; ++i) vsum = _mm512_dpbf16_ps(vsum, vkh[i+k*(D/32)], qv[i]);
- auto aux = _mm256_add_ps(_mm512_castps512_ps256(vsum), _mm512_extractf32x8_ps(vsum, 1));
- sum[k] = _mm_add_ps(_mm256_castps256_ps128(aux), _mm256_extractf128_ps(aux, 1));
- }
- //auto sum4 = _mm_mask_blend_ps(m8, hsum_float_4x4(sum), _mm_set1_ps(-INFINITY));
- //_mm_storeu_ps(fms.cache + k_step*m1 + l1, sum4);
- _mm_storeu_ps(fms.cache + k_step*m1 + l1, hsum_float_4x4(sum));
- }
-
- static IQK_ALWAYS_INLINE __m256 hsum_float_8x8(__m256 * accm) {
- for (int i = 0; i < 4; ++i) {
- accm[i] = _mm256_add_ps(_mm256_permute2f128_ps(accm[i], accm[i+4], 0x20), _mm256_permute2f128_ps(accm[i], accm[i+4], 0x31));
- //accm[i] = _mm256_set_m128(_mm_add_ps(_mm256_castps256_ps128(accm[i+4]), _mm256_extractf128_ps(accm[i+4], 1)),
- // _mm_add_ps(_mm256_castps256_ps128(accm[i+0]), _mm256_extractf128_ps(accm[i+0], 1)));
- }
- for (int i = 0; i < 2; ++i) accm[i] = _mm256_add_ps(_mm256_unpacklo_ps(accm[i], accm[i+2]), _mm256_unpackhi_ps(accm[i], accm[i+2]));
- return _mm256_add_ps(_mm256_unpacklo_ps(accm[0], accm[1]), _mm256_unpackhi_ps(accm[0], accm[1]));
- }
-
- static inline void mult_mask_kq_8(int l1, int m1, const ggml_bf16_t * q,
- __m512bh * qv, const __m512bh * vkh, FlashMS<q_step, k_step>& fms) {
- auto qr = q + m1*D;
- for (int i = 0; i < D/32; ++i) qv[i] = __m512bh(_mm512_loadu_si512((const __m512i *)qr + i));
- __m256 sum[8];
- for (int k = 0; k < 8; ++k) {
- auto vsum = _mm512_setzero_ps();
- for (int i = 0; i < D/32; ++i) vsum = _mm512_dpbf16_ps(vsum, vkh[i+k*(D/32)], qv[i]);
- sum[k] = _mm256_add_ps(_mm512_castps512_ps256(vsum), _mm512_extractf32x8_ps(vsum, 1));
- }
- _mm256_storeu_ps(fms.cache + k_step*m1 + l1, hsum_float_8x8(sum));
- }
-
- static inline void mult_mask_kq_one(int l1, int m1, const ggml_bf16_t * q,
- __m512bh * qv, const __m512bh * vkh, FlashMS<q_step, k_step>& fms) {
- auto qr = q + m1*D;
- for (int i = 0; i < D/32; ++i) qv[i] = __m512bh(_mm512_loadu_si512((const __m512i*)qr + i));
- auto vsum = _mm512_setzero_ps();
- for (int i = 0; i < D/32; ++i) vsum = _mm512_dpbf16_ps(vsum, vkh[i], qv[i]);
- fms.cache[k_step*m1 + l1 + 0] = _mm512_reduce_add_ps(vsum);
- vsum = _mm512_setzero_ps();
- for (int i = 0; i < D/32; ++i) vsum = _mm512_dpbf16_ps(vsum, vkh[i+D/32], qv[i]);
- fms.cache[k_step*m1 + l1 + 1] = _mm512_reduce_add_ps(vsum);
- }
-
-#if FA_TIMING
- template <typename KHelper>
- static inline void multiply_mask_kq(const KHelper& kh, int stride_m, const ggml_bf16_t * q,
- const char * mask, FlashMS<q_step, k_step>& fms, Perf& perf) {
- auto t1 = Perf::cur_time();
-#else
- template <typename KHelper>
- static inline void multiply_mask_kq(const KHelper& kh, int stride_m, const ggml_bf16_t * q,
- const char * mask, FlashMS<q_step, k_step>& fms) {
-#endif
- if constexpr (q_step == 1) {
- __m512bh vq[D/32];
- __m512bh vk[D/32];
- __m256 sum[8];
- for (int i = 0; i < D/32; ++i) vq[i] = __m512bh(_mm512_loadu_si512((const __m512i *)q + i));
- for (int l = 0; l < k_step; l += 8) {
- for (int k = 0; k < 8; ++k) {
- kh.load(l+k, vk);
- auto vsum = _mm512_setzero_ps();
- for (int i = 0; i < D/32; ++i) vsum = _mm512_dpbf16_ps(vsum, vk[i], vq[i]);
- sum[k] = _mm256_add_ps(_mm512_castps512_ps256(vsum), _mm512_extractf32x8_ps(vsum, 1));
- }
- _mm256_storeu_ps(fms.cache + l, hsum_float_8x8(sum));
- }
- }
- else {
- __m512bh qv[D/32];
- if constexpr (D <= 128) {
- __m512bh vkh[D/4];
- for (int l1 = 0; l1 < k_step; l1 += 8) {
- kh.load_8(l1, vkh);
- for (int j = 0; j < q_step; ++j) mult_mask_kq_8(l1, j, q, qv, vkh, fms);
- }
- } else {
- __m512bh vkh[D/16];
- for (int l1 = 0; l1 < k_step; l1 += 2) {
- kh.load_2(l1, vkh);
- for (int j = 0; j < q_step; ++j) mult_mask_kq_one(l1, j, q, qv, vkh, fms);
- }
- }
- }
-#if FA_TIMING
- perf.accum_nolock(1, t1);
- t1 = Perf::cur_time();
-#endif
- F16::Data vk[k_step/16];
- for (int j = 0; j < q_step; ++j) {
- fms.update_M_S(j, vk, mask + stride_m*j);
- }
-#if FA_TIMING
- perf.accum_nolock(2, t1);
-#endif
- }
-
- template <typename KHelper>
- static inline void multiply_mask_kq(int nq, const KHelper& kh, int stride_m, const ggml_bf16_t * q,
- const char * mask, FlashMS<q_step, k_step>& fms) {
- {
- __m512bh qv[D/32];
- if constexpr (D <= 128) {
- __m512bh vkh[D/8];
- for (int l1 = 0; l1 < k_step; l1 += 4) {
- kh.load_4(l1, vkh);
- for (int j = 0; j < nq; ++j) mult_mask_kq_4(l1, j, q, qv, vkh, fms);
- }
- } else {
- __m512bh vkh[D/16];
- for (int l1 = 0; l1 < k_step; l1 += 2) {
- kh.load_2(l1, vkh);
- for (int j = 0; j < nq; ++j) mult_mask_kq_one(l1, j, q, qv, vkh, fms);
- }
- }
- }
- F16::Data vk[k_step/16];
- for (int j = 0; j < nq; ++j) {
- fms.update_M_S(j, vk, mask + stride_m*j);
- }
- }
-
- template <typename KHelper>
- static inline void multiply_mask_kq(int nq, const KHelper& kh, int stride_q, int stride_m, const float * q,
- const char * mask, FlashMS<q_step, k_step>& fms) {
- {
- __m512bh qv[D/32];
- __m512bh vkh[D/16];
- for (int l1 = 0; l1 < k_step; l1 += 2) {
- kh.load_2(l1, vkh);
- for (int m1 = 0; m1 < nq; ++m1) {
- mult_mask_kq_one(l1, m1, stride_q, stride_m, q, mask, qv, vkh, fms);
- }
- }
- }
- __m512 vk[k_step/16];
- for (int j = 0; j < nq; ++j) {
- fms.update_M_S(j, vk);
- }
- }
-
- static inline void convert(int stride_q, const float * q, ggml_bf16_t * bf16) {
- auto qr = q;
- for (int j = 0; j < q_step; ++j) {
- for (int i = 0; i < D/32; ++i) {
- auto val1 = _mm512_loadu_ps(qr + 32*i);
- auto val2 = _mm512_loadu_ps(qr + 32*i + 16);
- _mm512_storeu_si512((__m512i *)bf16 + i, (__m512i)_mm512_cvtne2ps_pbh(val2, val1));
- }
- qr += stride_q;
- bf16 += D;
- }
- }
-
- static inline void convert(int nq, int stride_q, const float * q, ggml_bf16_t * bf16) {
- auto qr = q;
- for (int j = 0; j < nq; ++j) {
- for (int i = 0; i < D/32; ++i) {
- auto val1 = _mm512_loadu_ps(qr + 32*i);
- auto val2 = _mm512_loadu_ps(qr + 32*i + 16);
- _mm512_storeu_si512((__m512i *)bf16 + i, (__m512i)_mm512_cvtne2ps_pbh(val2, val1));
- }
- qr += stride_q;
- bf16 += D;
- }
- }
-};
-
-template <int Dk, int Dv, int q_step, int k_step>
-struct FlashAttnBF16 {
- //static_assert(Dk%32 == 0 && Dk <= 256);
- //static_assert(Dv%32 == 0 && Dv <= 256);
- static_assert(Dk%32 == 0 && Dk <= 576);
- static_assert(Dv%32 == 0 && Dv <= 512);
- static_assert(k_step%32 == 0);
- static_assert(q_step <= 4 || q_step%4 == 0);
-
- FlashAttnBF16(float scale, float softcap) : fms(scale, softcap) {}
-
- template <typename KHelper, typename VHelper>
- void compute(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv,
- const float * q, const char * mask, float * qkv, [[maybe_unused]] float * M, [[maybe_unused]] float * S) {
- ggml_bf16_t q_bf16[q_step*Dk];
-#if FA_TIMING
- Perf perf(false);
-#endif
- for (int i1 = 0; i1 < nq1/q_step; ++i1) {
-#if FA_TIMING
- auto t1 = Perf::cur_time();
-#endif
- fms.init_qstep();
- kh.reset_block();
- vh.reset_block();
- FlashQKbf16<Dk, q_step, k_step>::convert(stride_q, q, q_bf16);
-#if FA_TIMING
- perf.accum_nolock(0, t1);
-#endif
- auto mr = mask;
- for (int k1 = 0; k1 < nk1/k_step; ++k1) {
-#if FA_TIMING
- //t1 = Perf::cur_time();
- FlashQKbf16<Dk, q_step, k_step>::multiply_mask_kq(kh, stride_m, q_bf16, mr, fms, perf);
- //perf.accum_nolock(1, t1);
- t1 = Perf::cur_time();
- fqkv.accumulate_qkv(vh, fms);
- perf.accum_nolock(3, t1);
-#else
- FlashQKbf16<Dk, q_step, k_step>::multiply_mask_kq(kh, stride_m, q_bf16, mr, fms);
- fqkv.accumulate_qkv(vh, fms);
-#endif
- kh.next_block();
- vh.next_block();
- mr += k_step*sizeof(ggml_half);
- }
-#if FA_TIMING
- t1 = Perf::cur_time();
-#endif
- fqkv.normalize_and_store(fms, stride_qkv, qkv, M, S);
-#if FA_TIMING
- perf.accum_nolock(4, t1);
-#endif
-
- q += q_step*stride_q;
- mask += q_step*stride_m;
- qkv += q_step*stride_qkv;
- }
- int n_left = nq1 - q_step*(nq1/q_step);
- if (n_left > 0) {
- fms.init_qstep();
- kh.reset_block();
- vh.reset_block();
- FlashQKbf16<Dk, q_step, k_step>::convert(n_left, stride_q, q, q_bf16);
- auto mr = mask;
- for (int k1 = 0; k1 < nk1/k_step; ++k1) {
- FlashQKbf16<Dk, q_step, k_step>::multiply_mask_kq(n_left, kh, stride_m, q_bf16, mr, fms);
- fqkv.accumulate_qkv(n_left, vh, fms);
- kh.next_block();
- vh.next_block();
- mr += k_step*sizeof(ggml_half);
- }
- fqkv.normalize_and_store(fms, n_left, stride_qkv, qkv, M, S);
- }
-#if FA_TIMING
- Perf::instance().add(perf);
-#endif
- }
-
- FlashMS<q_step, k_step> fms;
- FlashQKV<Dv, q_step, k_step> fqkv;
-};
-#endif
-
-template <int Dk, int Dv, int k_step, typename KHelper, typename VHelper>
-inline void iqk_flash_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv,
- const float * q, const char * mask, float scale, float softcap, float * qkv, float * M, float * S) {
-
- auto update = [&nq1, &mask, &q, &qkv, &M, &S, stride_q, stride_m, stride_qkv] (int n) {
- nq1 -= n;
- if (nq1 == 0) return true;
- q += n*stride_q;
- mask += n*stride_m;
- qkv += n*stride_qkv;
- if (M && S) { M += n; S += n; }
- return false;
- };
- if (nk1 >= 512) {
- if (nq1 >= 128) {
- int n_step = nq1/128;
- FlashAttn<Dk, Dv, 64, k_step> fa(scale, softcap);
- fa.compute(kh, vh, 128*n_step, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
- if (update(128*n_step)) return;
- }
- if (nq1 >= 64) {
- int n_step = nq1/64;
- FlashAttn<Dk, Dv, 64, k_step> fa(scale, softcap);
- fa.compute(kh, vh, 64*n_step, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
- if (update(64*n_step)) return;
- }
- if (nq1 >= 32) {
- int n_step = nq1/32;
- FlashAttn<Dk, Dv, 32, k_step> fa(scale, softcap);
- fa.compute(kh, vh, 32*n_step, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
- if (update(32*n_step)) return;
- }
- if (nq1 >= 16) {
- int n_step = nq1/16;
- FlashAttn<Dk, Dv, 16, k_step> fa(scale, softcap);
- fa.compute(kh, vh, 16*n_step, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
- if (update(16*n_step)) return;
- }
- }
- if (nq1 >= 8) {
- int n_step = nq1/8;
- FlashAttn<Dk, Dv, 8, k_step> fa(scale, softcap);
- fa.compute(kh, vh, 8*n_step, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
- if (update(8*n_step)) return;
- }
- else if (nq1 >= 4) {
- int n_step = nq1/4;
- FlashAttn<Dk, Dv, 4, k_step> fa(scale, softcap);
- fa.compute(kh, vh, 4*n_step, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
- if (update(4*n_step)) return;
- }
- else if (nq1 >= 2) {
- int n_step = nq1/2;
- FlashAttn<Dk, Dv, 2, k_step> fa(scale, softcap);
- fa.compute(kh, vh, 2*n_step, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
- if (update(2*n_step)) return;
- }
- FlashAttn<Dk, Dv, 1, k_step> fa(scale, softcap);
- fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
-}
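-// Example of the dispatch above: for nq1 = 100 with nk1 >= 512, 64 rows go through q_step = 64,
-// then 32 rows through q_step = 32, and the remaining 4 rows through q_step = 4.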
-
-#ifdef __AVX512BF16__
-template <int Dk, int Dv, int k_step>
-inline void iqk_flash_helper_T(int nq1, int nk1, int stride_q, int stride_k, int stride_v, int stride_m, int stride_qkv,
- const float * q, const char * k, const char * v, const char * mask,
- float scale, float softcap, float * qkv, float * M, float * S) {
- HelperBF16<Dk, k_step> kh(k, stride_k);
- HelperBF16<Dv, k_step> vh(v, stride_v);
- if (nk1 >= 4096) {
- if (nq1 >= 64) {
- FlashAttnBF16<Dk, Dv, 64, k_step> fa(scale, softcap);
- fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
- return;
- }
- else if (nq1 >= 16) {
- FlashAttnBF16<Dk, Dv, 16, k_step> fa(scale, softcap);
- fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
- return;
- }
- }
- if (nq1 >= 8) {
- FlashAttnBF16<Dk, Dv, 8, k_step> fa(scale, softcap);
- fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
- } else {
- FlashAttnBF16<Dk, Dv, 1, k_step> fa(scale, softcap);
- fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
- }
-}
-#endif
-
-template <int Dk, int Dv, int k_step, typename KHelper>
-inline void iqk_flash_helper_T(KHelper& kh, ggml_type type_v,
- int nq1, int nk1, int stride_q, int stride_v, int stride_m, int stride_qkv,
- const float * q, const char * v, const char * mask,
- float scale, float softcap, float * qkv, float * M, float * S) {
-
- switch (type_v) {
- case GGML_TYPE_F16: {
- HelperF16<Dv, k_step> vh(v, stride_v);
- iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S);
- } break;
-#ifdef __AVX512BF16__
- case GGML_TYPE_BF16: {
- HelperBF16<Dv, k_step> vh(v, stride_v);
- iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S);
- } break;
-#endif
- case GGML_TYPE_Q8_0: {
- HelperQ80<Dv, k_step> vh(v, stride_v);
- iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S);
- } break;
- case GGML_TYPE_Q8_KV: {
- HelperQ8KV<Dv, k_step> vh(v, stride_v);
- iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S);
- } break;
- case GGML_TYPE_Q6_0: {
- HelperQ60<Dv, k_step> vh(v, stride_v);
- iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S);
- } break;
- case GGML_TYPE_Q4_0: {
- HelperQ40<Dv, k_step> vh(v, stride_v);
- iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S);
- } break;
-#if GGML_IQK_FA_ALL_QUANTS
- case GGML_TYPE_Q4_1: {
- HelperQ41<Dv, k_step> vh(v, stride_v);
- iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S);
- } break;
- case GGML_TYPE_IQ4_NL: {
- HelperIQ4nl<Dv, k_step> vh(v, stride_v);
- iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S);
- } break;
-#endif
- default: break;
- }
-}
-
-template <int Dk, int Dv, int k_step>
-inline void iqk_flash_helper_T(ggml_type type_k, ggml_type type_v,
- int nq1, int nk1, int stride_q, int stride_k, int stride_v, int stride_m, int stride_qkv,
- const float * q, const char * k, const char * v, const char * mask,
- float scale, float softcap, float * qkv, float * M, float * S) {
-
- switch (type_k) {
- case GGML_TYPE_F16: {
- HelperF16<Dk, k_step> kh(k, stride_k);
- iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, M, S);
- } break;
- case GGML_TYPE_Q8_0: {
- HelperQ80<Dk, k_step> kh(k, stride_k);
- iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, M, S);
- } break;
- case GGML_TYPE_Q8_0_R8: {
- HelperQ80R8<Dk, k_step> kh(k, stride_k);
- iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, M, S);
- } break;
- case GGML_TYPE_Q8_KV: {
- HelperQ8KV<Dk, k_step> kh(k, stride_k);
- iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, M, S);
- } break;
- case GGML_TYPE_Q6_0: {
- HelperQ60<Dk, k_step> kh(k, stride_k);
- iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, M, S);
- } break;
- case GGML_TYPE_Q4_0: {
- HelperQ40<Dk, k_step> kh(k, stride_k);
- iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, M, S);
- } break;
-#if GGML_IQK_FA_ALL_QUANTS
- case GGML_TYPE_Q4_1: {
- HelperQ41<Dk, k_step> kh(k, stride_k);
- iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, M, S);
- } break;
- case GGML_TYPE_IQ4_NL: {
- HelperIQ4nl<Dk, k_step> kh(k, stride_k);
- iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, M, S);
- } break;
-#endif
- default: break;
- }
-
-}
-
-inline bool flash_attn_is_supported(ggml_type type) {
-#ifdef __AVX512BF16__
- if (type == GGML_TYPE_BF16) return true;
-#endif
-#if GGML_IQK_FA_ALL_QUANTS
- if (type == GGML_TYPE_F16 || type == GGML_TYPE_Q8_0 || type == GGML_TYPE_Q4_0 || type == GGML_TYPE_Q4_1 ||
- type == GGML_TYPE_Q6_0 || type == GGML_TYPE_IQ4_NL || type == GGML_TYPE_Q8_0_R8) return true;
-#else
- if (type == GGML_TYPE_F16 || type == GGML_TYPE_Q8_0 || type == GGML_TYPE_Q6_0 || type == GGML_TYPE_Q8_KV || type == GGML_TYPE_Q8_0_R8
- || type == GGML_TYPE_Q4_0) return true;
-#endif
- return false;
-}
-
-template <int step_k, typename KHelper, typename VHelper>
-inline void iqk_deepseek_helper(KHelper& kh, VHelper& vh,
- int nq1, int nk1, int stride_q, int stride_m, int stride_qkv,
- const float * q, const char * mask, float scale, float softcap, float * qkv, float * M, float * S) {
- auto update = [&nq1, &mask, &q, &qkv, &M, &S, stride_q, stride_m, stride_qkv] (int n) {
- nq1 -= n;
- if (nq1 == 0) return true;
- q += n*stride_q;
- mask += n*stride_m;
- qkv += n*stride_qkv;
- if (M && S) { M += n; S += n; }
- return false;
- };
- if (nq1 >= 16) {
- int n_step = nq1/16;
- FlashAttn<576, 512, 16, step_k> fa(scale, softcap);
- fa.compute(kh, vh, 16*n_step, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S);
- if (update(16*n_step)) return;
- }
- if (nq1 >= 8) {
- int n_step = nq1/8;
- FlashAttn<576, 512, 8, step_k> fa(scale, softcap);
- fa.compute(kh, vh, 8*n_step, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S);
- if (update(8*n_step)) return;
- }
- if (nq1 >= 4) {
- int n_step = nq1/4;
- FlashAttn<576, 512, 4, step_k> fa(scale, softcap);
- fa.compute(kh, vh, 4*n_step, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S);
- if (update(4*n_step)) return;
- }
- if (nq1 >= 2) {
- int n_step = nq1/2;
- FlashAttn<576, 512, 2, step_k> fa(scale, softcap);
- fa.compute(kh, vh, 2*n_step, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S);
- if (update(2*n_step)) return;
- }
- FlashAttn<576, 512, 1, step_k> fa(scale, softcap);
- fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S);
-}
-
-template <int step_k>
-inline bool iqk_deepseek_helper(ggml_type type_k,
- int nq1, int nk1, int stride_q, int stride_k, int stride_v, int stride_m, int stride_qkv,
- const float * q, const char * k, const char * v, const char * mask,
- float scale, float softcap, float * qkv, float * M, float * S) {
- if (type_k == GGML_TYPE_Q8_0) {
- HelperQ80<576, step_k> kh((const char *)k, stride_k);
- HelperQ80<512, step_k> vh((const char *)v, stride_v);
- iqk_deepseek_helper<step_k>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S);
- return true;
- }
- if (type_k == GGML_TYPE_Q8_0_R8) {
- HelperQ80R8<576, step_k> kh((const char *)k, stride_k);
- HelperQ80<512, step_k> vh((const char *)v, stride_v);
- iqk_deepseek_helper<step_k>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S);
- return true;
- }
- if (type_k == GGML_TYPE_Q6_0) {
- HelperQ60<576, step_k> kh((const char *)k, stride_k);
- HelperQ60<512, step_k> vh((const char *)v, stride_v);
- iqk_deepseek_helper<step_k>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S);
- return true;
- }
- if (type_k == GGML_TYPE_Q8_KV) {
- HelperQ8KV<576, step_k> kh((const char *)k, stride_k);
- HelperQ8KV<512, step_k> vh((const char *)v, stride_v);
- iqk_deepseek_helper<step_k>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S);
- return true;
- }
- if (type_k == GGML_TYPE_F16) {
- HelperF16<576, step_k> kh((const char *)k, stride_k);
- HelperF16<512, step_k> vh((const char *)v, stride_v);
- iqk_deepseek_helper<step_k>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S);
- return true;
- }
-#ifdef __AVX512BF16__
- if (type_k == GGML_TYPE_BF16) {
- HelperBF16<576, step_k> kh((const char *)k, stride_k);
- HelperBF16<512, step_k> vh((const char *)v, stride_v);
- if (nq1 % 8 == 0) {
- FlashAttnBF16<576, 512, 8, step_k> fa(scale, softcap);
- fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
- } else {
- FlashAttnBF16<576, 512, 1, step_k> fa(scale, softcap);
- fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
- }
- return true;
- }
-#endif
- return false;
-}
-
-}
-
#include "iqk_flash_impl.h"
+#include "fa/iqk_fa_templates.h"
bool iqk_flash_attn_impl(int int_type_k, // type of k
int int_type_v, // type of v
@@ -18540,129 +860,37 @@ bool iqk_flash_attn_impl(int int_type_k, // type of k
if (!mask || nk1%32 != 0) return false; // the implementation assumes mask is not null and nk is a multiple of 32
- auto type_k = ggml_type(int_type_k);
- auto type_v = ggml_type(int_type_v);
-
if (Dk == 576 && Dv == 512) {
- GGML_ASSERT(type_k == type_v || (type_k == GGML_TYPE_Q8_0_R8 && type_v == GGML_TYPE_Q8_0));
- stride_q /= sizeof(float); // q stride as float
- return iqk_deepseek_helper<32>(type_k, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv,
- q, (const char *)k, (const char *)v, (const char *)mask, scale, softcap, qkv, M, S);
+ return iqk_fa_576_512(int_type_k, int_type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv,
+ q, k, v, mask, scale, softcap, qkv, M, S);
}
- if (!flash_attn_is_supported(type_k) || !flash_attn_is_supported(type_v)) return false;
- if (Dk != Dv && Dk != 192 && Dv != 128) return false;
- if (Dv != 64 && Dv != 96 && Dv != 128 && Dv != 256) return false;
- if (Dk != 64 && Dk != 96 && Dk != 128 && Dk != 192 && Dk != 256) return false;
-
- auto ck = (const char *)k;
- auto cv = (const char *)v;
- auto cm = (const char *)mask;
-
- stride_q /= sizeof(float); // q stride as float
-
-#ifdef __AVX512BF16__
- if (type_k == GGML_TYPE_BF16) {
- if (nk1%64 == 0) {
- if (type_v != GGML_TYPE_BF16) return false; // we do not support mixing bf16 k-cache with other types
- switch (Dk) {
- case 64:
- iqk_flash_helper_T< 64, 64, 64>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv, M, S); break;
- case 96:
- iqk_flash_helper_T< 96, 96, 64>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv, M, S); break;
- case 128:
- iqk_flash_helper_T<128, 128, 64>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv, M, S); break;
- case 192:
- iqk_flash_helper_T<192, 128, 64>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv, M, S); break;
- case 256:
- iqk_flash_helper_T<256, 256, 64>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv, M, S); break;
- default:
- return false;
- }
- return true;
- }
- if (type_v != GGML_TYPE_BF16) return false; // we do not support mixing bf16 k-cache with other types
- switch (Dk) {
- case 64:
- iqk_flash_helper_T< 64, 64, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv, M, S); break;
- case 96:
- iqk_flash_helper_T< 96, 96, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv, M, S); break;
- case 128:
- iqk_flash_helper_T<128, 128, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv, M, S); break;
- case 192:
- iqk_flash_helper_T<192, 128, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv, M, S); break;
- case 256:
- iqk_flash_helper_T<256, 256, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv, M, S); break;
- default:
- return false;
- }
+ if (Dk == 192 && Dv == 128) {
+ return iqk_fa_192_128(int_type_k, int_type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv,
+ q, k, v, mask, scale, softcap, qkv, M, S);
+ }
- return true;
+ if (Dk == 256 && Dv == 256) {
+ return iqk_fa_256_256(int_type_k, int_type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv,
+ q, k, v, mask, scale, softcap, qkv, M, S);
}
-#endif
- if (nk1%128 == 0) {
- switch (Dk) {
- case 64:
- iqk_flash_helper_T< 64, 64, 128>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv, M, S); break;
- case 96:
- iqk_flash_helper_T< 96, 96, 128>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv, M, S); break;
- case 128:
- iqk_flash_helper_T<128, 128, 128>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv, M, S); break;
- case 192:
- iqk_flash_helper_T<192, 128, 128>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv, M, S); break;
- case 256:
- iqk_flash_helper_T<256, 256, 128>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv, M, S); break;
- default:
- return false;
- }
- return true;
+ if (Dk == 128 && Dv == 128) {
+ return iqk_fa_128_128(int_type_k, int_type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv,
+ q, k, v, mask, scale, softcap, qkv, M, S);
}
- if (nk1%64 == 0) {
- switch (Dk) {
- case 64:
- iqk_flash_helper_T< 64, 64, 64>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv, M, S); break;
- // Disable until we fix accumulate_qkv for odd D/16
- //case 80:
- // iqk_flash_helper_T< 80, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
- case 96:
- iqk_flash_helper_T< 96, 96, 64>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv, M, S); break;
- // Disable until we fix accumulate_qkv for odd D/16
- //case 112:
- // iqk_flash_helper_T<112, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
- case 128:
- iqk_flash_helper_T<128, 128, 64>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv, M, S); break;
- case 192:
- iqk_flash_helper_T<192, 128, 64>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv, M, S); break;
- case 256:
- iqk_flash_helper_T<256, 256, 64>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv, M, S); break;
- default:
- return false;
- }
- return true;
+
+ if (Dk == 96 && Dv == 96) {
+ return iqk_fa_96_96(int_type_k, int_type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv,
+ q, k, v, mask, scale, softcap, qkv, M, S);
}
- switch (Dk) {
- case 64:
- iqk_flash_helper_T< 64, 64, 32>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv, M, S); break;
- // Disable until we fix accumulate_qkv for odd D/16
- //case 80:
- // iqk_flash_helper_T< 80, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
- case 96:
- iqk_flash_helper_T< 96, 96, 32>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv, M, S); break;
- // Disable until we fix accumulate_qkv for odd D/16
- //case 112:
- // iqk_flash_helper_T<112, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
- case 128:
- iqk_flash_helper_T<128, 128, 32>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv, M, S); break;
- case 192:
- iqk_flash_helper_T<192, 128, 32>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv, M, S); break;
- case 256:
- iqk_flash_helper_T<256, 256, 32>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv, M, S); break;
- default:
- return false;
+
+ if (Dk == 64 && Dv == 64) {
+ return iqk_fa_64_64(int_type_k, int_type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv,
+ q, k, v, mask, scale, softcap, qkv, M, S);
}
- return true;
+ return false;
}
#endif
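The monolithic head-size dispatch above is replaced by one entry point per supported (Dk, Dv) pair, each compiled in its own translation unit under ggml/src/iqk/fa/ and declared in fa/iqk_fa_templates.h. A minimal sketch of what those declarations presumably look like, inferred only from the call sites in this hunk (the parameter names and the void pointer types for k/v/mask are assumptions, not copied from the header):

// Sketch only: signature inferred from the calls in iqk_flash_attn_impl().
bool iqk_fa_576_512(int type_k, int type_v, int nq1, int nk1,
                    int stride_q, int stride_k, int stride_v, int stride_m, int stride_qkv,
                    const float * q, const void * k, const void * v, const void * mask,
                    float scale, float softcap, float * qkv, float * M, float * S);
// iqk_fa_192_128, iqk_fa_256_256, iqk_fa_128_128, iqk_fa_96_96 and iqk_fa_64_64
// are assumed to share the same signature.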
diff --git a/ggml/src/iqk/iqk_utils.h b/ggml/src/iqk/iqk_utils.h
new file mode 100644
index 00000000..194bf9b8
--- /dev/null
+++ b/ggml/src/iqk/iqk_utils.h
@@ -0,0 +1,207 @@
+#pragma once
+
+#include "iqk_config.h"
+
+#if defined IQK_IMPLEMENT
+
+#include "ggml-impl.h"
+
+#if defined(__ARM_NEON) && defined(__aarch64__)
+// copy-pasted from Justine Tunney's contribution to llama.cpp
+// adapted from arm limited optimized routine
+// the maximum error is 1.45358 plus 0.5 ulps
+// numbers above 88.38 will flush to infinity
+// numbers beneath -103.97 will flush to zero
+static inline float32x4_t v_expf(float32x4_t x) {
+ const float32x4_t r = vdupq_n_f32(0x1.8p23f);
+ const float32x4_t z = vfmaq_f32(r, x, vdupq_n_f32(0x1.715476p+0f));
+ const float32x4_t n = vsubq_f32(z, r);
+ const float32x4_t b = vfmsq_f32(vfmsq_f32(x, n, vdupq_n_f32(0x1.62e4p-1f)), n,
+ vdupq_n_f32(0x1.7f7d1cp-20f));
+ const uint32x4_t e = vshlq_n_u32(vreinterpretq_u32_f32(z), 23);
+ const float32x4_t k = vreinterpretq_f32_u32(vaddq_u32(e, vreinterpretq_u32_f32(vdupq_n_f32(1))));
+ const uint32x4_t c = vcagtq_f32(n, vdupq_n_f32(126));
+ const float32x4_t u = vmulq_f32(b, b);
+ const float32x4_t j = vfmaq_f32(
+ vmulq_f32(vdupq_n_f32(0x1.ffffecp-1f), b),
+ vfmaq_f32(vfmaq_f32(vdupq_n_f32(0x1.fffdb6p-2f), vdupq_n_f32(0x1.555e66p-3f), b),
+ vfmaq_f32(vdupq_n_f32(0x1.573e2ep-5f), vdupq_n_f32(0x1.0e4020p-7f), b), u), u);
+ if (!vpaddd_u64(vreinterpretq_u64_u32(c)))
+ return vfmaq_f32(k, j, k);
+ const uint32x4_t d = vandq_u32(vclezq_f32(n), vdupq_n_u32(0x82000000));
+ const float32x4_t s1 = vreinterpretq_f32_u32(vaddq_u32(d, vdupq_n_u32(0x7f000000)));
+ const float32x4_t s2 = vreinterpretq_f32_u32(vsubq_u32(e, d));
+ return vbslq_f32(vcagtq_f32(n, vdupq_n_f32(192)), vmulq_f32(s1, s1),
+ vbslq_f32(c, vmulq_f32(vfmaq_f32(s2, s2, j), s1), vfmaq_f32(k, k, j)));
+}
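+// fp16 overload: split into two fp32 halves, exponentiate each half, and repack to fp16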
+static inline float16x8_t v_expf(float16x8_t x) {
+ auto val1 = v_expf(vcvt_f32_f16(vget_low_f16(x)));
+ auto val2 = v_expf(vcvt_f32_f16(vget_high_f16(x)));
+ return vcombine_f16(vcvt_f16_f32(val1), vcvt_f16_f32(val2));
+}
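+// tanh(x) = (e^{2x} - 1) / (e^{2x} + 1); inputs with x > 10 are clamped to 1
+// (tanh(10) is already 1 to fp32 precision), which also avoids inf/inf for very large x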
+static inline float32x4_t v_tanh(float32x4_t x) {
+ const float32x4_t one = vdupq_n_f32(1.0f);
+ const float32x4_t two_x = vmulq_f32(x, vdupq_n_f32(2.f));
+ const float32x4_t exp_two_x = v_expf(two_x);
+ const uint32x4_t mask = vcgtq_f32(x, vdupq_n_f32(10.f));
+ const float32x4_t res = vdivq_f32(vsubq_f32(exp_two_x, one), vaddq_f32(exp_two_x, one));
+ return vreinterpretq_f32_u32(vorrq_u32(vandq_u32(vreinterpretq_u32_f32(one), mask), vbicq_u32(vreinterpretq_u32_f32(res), mask)));
+ //return vdivq_f32(vsubq_f32(exp_two_x, one), vaddq_f32(exp_two_x, one));
+}
+//inline float16x8_t v_tanh(float16x8_t x) {
+// auto val1 = v_tanh(vcvt_f32_f16(vget_low_f16(x)));
+// auto val2 = v_tanh(vcvt_f32_f16(vget_high_f16(x)));
+// return vcombine_f16(vcvt_f16_f32(val1), vcvt_f16_f32(val2));
+//}
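+// SiLU(x) = x * sigmoid(x) = x / (1 + e^{-x})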
+static inline float32x4_t v_silu(float32x4_t x) {
+ const float32x4_t one = vdupq_n_f32(1.0f);
+ const float32x4_t zero = vdupq_n_f32(0.0f);
+ const float32x4_t neg_x = vsubq_f32(zero, x);
+ const float32x4_t exp_neg_x = v_expf(neg_x);
+ const float32x4_t one_plus_exp_neg_x = vaddq_f32(one, exp_neg_x);
+ return vdivq_f32(x, one_plus_exp_neg_x);
+}
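+// tanh-style GELU approximation written as x * sigmoid(c2*x*(1 + c1*x^2));
+// the caller supplies c1 and c2 (presumably 0.044715 and 2*sqrt(2/pi)), and
+// inputs with x > 10 are passed through unchanged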
+static inline float32x4_t v_gelu(float32x4_t x, float32x4_t c1, float32x4_t c2) {
+ const float32x4_t one = vdupq_n_f32(1.0f);
+ float32x4_t arg = vfmaq_f32(one, c1, vmulq_f32(x, x));
+ arg = vmulq_f32(arg, vmulq_f32(x, c2));
+ float32x4_t exp_arg = v_expf(arg);
+ float32x4_t gelu = vmulq_f32(x, vdivq_f32(exp_arg, vaddq_f32(exp_arg, one)));
+ uint32x4_t mask = vcgtq_f32(x, vdupq_n_f32(10.f));
+ return vbslq_f32(mask, x, gelu);
+}
+
+#endif // __ARM_NEON
+
+#if defined(__AVX512F__) && defined(__AVX512DQ__)
+
+// copy-pasted from Justine Tunney's contribution to llama.cpp
+// adapted from arm limited optimized routine
+// the maximum error is 1.45358 plus 0.5 ulps
+// numbers above 88.38 will flush to infinity
+// numbers beneath -103.97 will flush to zero
+static inline __m512 v_expf(__m512 x) {
+ const __m512 r = _mm512_set1_ps(0x1.8p23f);
+ const __m512 z = _mm512_fmadd_ps(x, _mm512_set1_ps(0x1.715476p+0f), r);
+ const __m512 n = _mm512_sub_ps(z, r);
+ const __m512 b =
+ _mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.7f7d1cp-20f),
+ _mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.62e4p-1f), x));
+ const __mmask16 d =
+ _mm512_cmp_ps_mask(_mm512_abs_ps(n), _mm512_set1_ps(192), _CMP_GT_OQ);
+ const __m512 u = _mm512_mul_ps(b, b);
+ const __m512 j = _mm512_fmadd_ps(
+ _mm512_fmadd_ps(_mm512_fmadd_ps(_mm512_set1_ps(0x1.0e4020p-7f), b,
+ _mm512_set1_ps(0x1.573e2ep-5f)),
+ u,
+ _mm512_fmadd_ps(_mm512_set1_ps(0x1.555e66p-3f), b,
+ _mm512_set1_ps(0x1.fffdb6p-2f))),
+ u,
+ _mm512_fmadd_ps(_mm512_set1_ps(0x1.ffffecp-1f), b, _mm512_set1_ps(1.0F)));
+ const __m512 res = _mm512_scalef_ps(j, n);
+ if (_mm512_kortestz(d, d))
+ return res;
+ const __m512 zero = _mm512_setzero_ps();
+ const __m512 alt = _mm512_mask_blend_ps(
+ _mm512_cmp_ps_mask(n, zero, _CMP_LE_OQ), _mm512_set1_ps(INFINITY), zero);
+ return _mm512_mask_blend_ps(d, res, alt);
+}
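+// same tanh as the NEON helper above: (e^{2x} - 1)/(e^{2x} + 1), clamped to 1 for x > 10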
+static inline __m512 v_tanh(__m512 x) {
+ const __m512 one = _mm512_set1_ps(1.0f);
+ const __m512 exp_two_x = v_expf(_mm512_mul_ps(x, _mm512_set1_ps(2.f)));
+ const __mmask16 mask = _mm512_cmp_ps_mask(x, _mm512_set1_ps(10.f), _CMP_GT_OQ);
+ const __m512 res = _mm512_div_ps(_mm512_sub_ps(exp_two_x, one), _mm512_add_ps(exp_two_x, one));
+ return _mm512_mask_blend_ps(mask, res, one);
+}
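+// GELU approximation as in the NEON version; here the sigmoid argument (not x) is tested,
+// and for arg > 30 the factor e^arg/(e^arg + 1) is replaced by 1 so the result is just x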
+static inline __m512 v_gelu(__m512 x, __m512 c1, __m512 c2) {
+ const __m512 one = _mm512_set1_ps(1.0f);
+ __m512 arg = _mm512_fmadd_ps(x, _mm512_mul_ps(c1, x), one);
+ //__m512 arg = _mm512_add_ps(one, _mm512_mul_ps(_mm512_mul_ps(x, x), c1));
+ arg = _mm512_mul_ps(arg, _mm512_mul_ps(c2, x));
+ const __mmask16 mask = _mm512_cmp_ps_mask(arg, _mm512_set1_ps(30.f), _CMP_GT_OQ);
+ const __m512 exp_arg = v_expf(arg);
+ const __m512 ratio = _mm512_div_ps(exp_arg, _mm512_add_ps(exp_arg, one));
+ return _mm512_mul_ps(x, _mm512_mask_blend_ps(mask, ratio, one));
+}
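+// SiLU(x) = x * sigmoid(x), computed as x / (1 + e^{-x})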
+static inline __m512 v_silu(__m512 x) {
+ const __m512 one = _mm512_set1_ps(1);
+ const __m512 zero = _mm512_setzero_ps();
+ const __m512 neg_x = _mm512_sub_ps(zero, x);
+ const __m512 exp_neg_x = v_expf(neg_x);
+ const __m512 one_plus_exp_neg_x = _mm512_add_ps(one, exp_neg_x);
+ return _mm512_div_ps(x, one_plus_exp_neg_x);
+}
+#endif // __AVX512F__ && __AVX512DQ__
+
+#if defined(__AVX2__) && defined(__FMA__)
+
+// adapted from arm limited optimized routine
+// the maximum error is 1.45358 plus 0.5 ulps
+// numbers above 88.38 will flush to infinity
+// numbers beneath -103.97 will flush to zero
+static inline __m256 v_expf(__m256 x) {
+ const __m256 r = _mm256_set1_ps(0x1.8p23f);
+ const __m256 z = _mm256_fmadd_ps(x, _mm256_set1_ps(0x1.715476p+0f), r);
+ const __m256 n = _mm256_sub_ps(z, r);
+ const __m256 b = _mm256_fnmadd_ps(n, _mm256_set1_ps(0x1.7f7d1cp-20f),
+ _mm256_fnmadd_ps(n, _mm256_set1_ps(0x1.62e4p-1f), x));
+ const __m256i e = _mm256_slli_epi32(_mm256_castps_si256(z), 23);
+ const __m256 k = _mm256_castsi256_ps(
+ _mm256_add_epi32(e, _mm256_castps_si256(_mm256_set1_ps(1))));
+ const __m256i c = _mm256_castps_si256(
+ _mm256_cmp_ps(_mm256_andnot_ps(_mm256_set1_ps(-0.f), n),
+ _mm256_set1_ps(126), _CMP_GT_OQ));
+ const __m256 u = _mm256_mul_ps(b, b);
+ const __m256 j = _mm256_fmadd_ps(_mm256_fmadd_ps(_mm256_fmadd_ps(_mm256_set1_ps(0x1.0e4020p-7f), b,
+ _mm256_set1_ps(0x1.573e2ep-5f)), u,
+ _mm256_fmadd_ps(_mm256_set1_ps(0x1.555e66p-3f), b,
+ _mm256_set1_ps(0x1.fffdb6p-2f))),
+ u, _mm256_mul_ps(_mm256_set1_ps(0x1.ffffecp-1f), b));
+ if (!_mm256_movemask_ps(_mm256_castsi256_ps(c)))
+ return _mm256_fmadd_ps(j, k, k);
+ const __m256i g = _mm256_and_si256(
+ _mm256_castps_si256(_mm256_cmp_ps(n, _mm256_setzero_ps(), _CMP_LE_OQ)),
+ _mm256_set1_epi32(0x82000000u));
+ const __m256 s1 =
+ _mm256_castsi256_ps(_mm256_add_epi32(g, _mm256_set1_epi32(0x7f000000u)));
+ const __m256 s2 = _mm256_castsi256_ps(_mm256_sub_epi32(e, g));
+ const __m256i d = _mm256_castps_si256(
+ _mm256_cmp_ps(_mm256_andnot_ps(_mm256_set1_ps(-0.f), n),
+ _mm256_set1_ps(192), _CMP_GT_OQ));
+ return _mm256_or_ps(
+ _mm256_and_ps(_mm256_castsi256_ps(d), _mm256_mul_ps(s1, s1)),
+ _mm256_andnot_ps(
+ _mm256_castsi256_ps(d),
+ _mm256_or_ps(
+ _mm256_and_ps(_mm256_castsi256_ps(c),
+ _mm256_mul_ps(_mm256_fmadd_ps(s2, j, s2), s1)),
+ _mm256_andnot_ps(_mm256_castsi256_ps(c), _mm256_fmadd_ps(k, j, k)))));
+}
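+// AVX2 counterpart of the tanh helpers above, with the same clamp to 1 for x > 10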
+static inline __m256 v_tanh(__m256 x) {
+ const __m256 one = _mm256_set1_ps(1.0f);
+ const __m256 exp_two_x = v_expf(_mm256_mul_ps(x, _mm256_set1_ps(2.f)));
+ const __m256 res = _mm256_div_ps(_mm256_sub_ps(exp_two_x, one), _mm256_add_ps(exp_two_x, one));
+ const __m256 mask = _mm256_cmp_ps(x, _mm256_set1_ps(10.f), _CMP_GT_OQ);
+ return _mm256_or_ps(_mm256_and_ps(mask, one), _mm256_andnot_ps(mask, res));
+}
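+// AVX2 GELU: same x * sigmoid(c2*x*(1 + c1*x^2)) formulation, passing x through for x > 10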
+static inline __m256 v_gelu(__m256 x, __m256 c1, __m256 c2) {
+ const __m256 one = _mm256_set1_ps(1.0f);
+ const __m256 mask = _mm256_cmp_ps(x, _mm256_set1_ps(10.f), _CMP_GT_OQ);
+ __m256 arg = _mm256_add_ps(one, _mm256_mul_ps(_mm256_mul_ps(x, x), c1));
+ arg = _mm256_mul_ps(arg, _mm256_mul_ps(x, c2));
+ __m256 exp_arg = v_expf(arg);
+ __m256 gelu = _mm256_mul_ps(x, _mm256_div_ps(exp_arg, _mm256_add_ps(exp_arg, one)));
+ return _mm256_or_ps(_mm256_and_ps(mask, x), _mm256_andnot_ps(mask, gelu));
+}
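+// AVX2 SiLU: x / (1 + e^{-x})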
+static inline __m256 v_silu(__m256 x) {
+ const __m256 one = _mm256_set1_ps(1);
+ const __m256 zero = _mm256_setzero_ps();
+ const __m256 neg_x = _mm256_sub_ps(zero, x);
+ const __m256 exp_neg_x = v_expf(neg_x);
+ const __m256 one_plus_exp_neg_x = _mm256_add_ps(one, exp_neg_x);
+ return _mm256_div_ps(x, one_plus_exp_neg_x);
+}
+
+#endif // __AVX2__ && __FMA__
+
+#endif // IQK_IMPLEMENT
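The helpers in this new header are shared by the GEMM and flash-attention translation units split out above. As a rough usage illustration (a minimal sketch, not the kernel's actual code, assuming AVX2/FMA, that iqk_utils.h has already been included, and the usual softcap * tanh(x/softcap) logit-softcapping convention), a row of attention logits could be soft-capped and exponentiated for the softmax like this:

#include <immintrin.h>

// Hypothetical helper: n is assumed to be a multiple of 8, row_max the row maximum.
static void softcap_and_exp(float * x, int n, float softcap, float row_max) {
    const __m256 vcap = _mm256_set1_ps(softcap);
    const __m256 vinv = _mm256_set1_ps(1.0f/softcap);
    const __m256 vmax = _mm256_set1_ps(row_max);
    for (int i = 0; i < n; i += 8) {
        __m256 v = _mm256_loadu_ps(x + i);
        v = _mm256_mul_ps(vcap, v_tanh(_mm256_mul_ps(v, vinv))); // softcap * tanh(x/softcap)
        v = v_expf(_mm256_sub_ps(v, vmax));                      // exp(x - row_max)
        _mm256_storeu_ps(x + i, v);
    }
}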