summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--iqk_mul_mat.cpp113
1 files changed, 112 insertions, 1 deletions
diff --git a/iqk_mul_mat.cpp b/iqk_mul_mat.cpp
index aa364900..02878372 100644
--- a/iqk_mul_mat.cpp
+++ b/iqk_mul_mat.cpp
@@ -22,6 +22,9 @@
#include "ggml-quants.h"
#include "sgemm.h"
+#define GGML_COMMON_IMPL_C
+#include "ggml-common.h"
+
// clang-format off
// This matrix - vector and matrix - matrix multiplication implementation
@@ -1944,8 +1947,113 @@ struct DequantizerIQ4XS final : public BaseDequantizer<block_iq4_xs> {
float d;
};
+static const int8_t keven_signs_q2xs[1024] = {
+ 1, 1, 1, 1, 1, 1, 1, 1, -1, 1, 1, 1, 1, 1, 1, -1, 1, -1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, 1, 1,
+ 1, 1, -1, 1, 1, 1, 1, -1, -1, 1, -1, 1, 1, 1, 1, 1, 1, -1, -1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, -1,
+ 1, 1, 1, -1, 1, 1, 1, -1, -1, 1, 1, -1, 1, 1, 1, 1, 1, -1, 1, -1, 1, 1, 1, 1, -1, -1, 1, -1, 1, 1, 1, -1,
+ 1, 1, -1, -1, 1, 1, 1, 1, -1, 1, -1, -1, 1, 1, 1, -1, 1, -1, -1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, 1,
+ 1, 1, 1, 1, -1, 1, 1, -1, -1, 1, 1, 1, -1, 1, 1, 1, 1, -1, 1, 1, -1, 1, 1, 1, -1, -1, 1, 1, -1, 1, 1, -1,
+ 1, 1, -1, 1, -1, 1, 1, 1, -1, 1, -1, 1, -1, 1, 1, -1, 1, -1, -1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, 1,
+ 1, 1, 1, -1, -1, 1, 1, 1, -1, 1, 1, -1, -1, 1, 1, -1, 1, -1, 1, -1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, 1,
+ 1, 1, -1, -1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, 1, 1, -1, -1, -1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, -1,
+ 1, 1, 1, 1, 1, -1, 1, -1, -1, 1, 1, 1, 1, -1, 1, 1, 1, -1, 1, 1, 1, -1, 1, 1, -1, -1, 1, 1, 1, -1, 1, -1,
+ 1, 1, -1, 1, 1, -1, 1, 1, -1, 1, -1, 1, 1, -1, 1, -1, 1, -1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, 1, 1,
+ 1, 1, 1, -1, 1, -1, 1, 1, -1, 1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, 1, 1,
+ 1, 1, -1, -1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, 1, 1, 1, -1, -1, -1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, -1,
+ 1, 1, 1, 1, -1, -1, 1, 1, -1, 1, 1, 1, -1, -1, 1, -1, 1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, -1, -1, 1, 1,
+ 1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, 1, -1, -1, 1, 1, 1, -1, -1, 1, -1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, -1,
+ 1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, -1, -1, -1, 1, 1, 1, -1, 1, -1, -1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, -1,
+ 1, 1, -1, -1, -1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, -1, 1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, 1,
+ 1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, 1, -1, 1, 1, -1, 1, 1, 1, 1, -1, 1, -1, -1, 1, 1, 1, 1, -1, -1,
+ 1, 1, -1, 1, 1, 1, -1, 1, -1, 1, -1, 1, 1, 1, -1, -1, 1, -1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, -1, 1,
+ 1, 1, 1, -1, 1, 1, -1, 1, -1, 1, 1, -1, 1, 1, -1, -1, 1, -1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, -1, 1,
+ 1, 1, -1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, -1, 1, 1, -1, -1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, -1,
+ 1, 1, 1, 1, -1, 1, -1, 1, -1, 1, 1, 1, -1, 1, -1, -1, 1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, 1, -1, 1,
+ 1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, 1, -1, 1, 1, -1, -1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, -1,
+ 1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, -1, -1, 1, -1, 1, 1, -1, 1, -1, -1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, -1,
+ 1, 1, -1, -1, -1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, -1, 1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, 1,
+ 1, 1, 1, 1, 1, -1, -1, 1, -1, 1, 1, 1, 1, -1, -1, -1, 1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, -1, -1, 1,
+ 1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, -1, -1, 1, 1, -1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, -1, -1, -1,
+ 1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, 1, -1, -1, 1, 1, -1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, 1, -1, -1, -1,
+ 1, 1, -1, -1, 1, -1, -1, 1, -1, 1, -1, -1, 1, -1, -1, -1, 1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, 1,
+ 1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, -1, -1, -1, 1, 1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, -1, -1, -1, -1,
+ 1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, 1, -1, -1, -1, -1, 1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, 1,
+ 1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, -1, -1, -1, -1, -1, 1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, 1,
+ 1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, 1, 1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, -1,
+};
+
+struct DequantizerIQ2XXS final : public BaseDequantizer<block_iq2_xxs> {
+ DequantizerIQ2XXS(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {}
+
+ constexpr static int num_blocks() { return 8; }
+ constexpr static bool should_scale_quants() { return false; }
+
+ template <typename Q8>
+ inline int32x4x2_t new_block(int i, const Q8& /*q8*/, float32x4_t * /*acc*/) {
+ d = 0.125f * GGML_FP16_TO_FP32(x[i].d);
+ auto tmp = vld1q_u32_x4((const uint32_t *)x[i].qs);
+ data.val[0] = vuzp1q_u32(tmp.val[0], tmp.val[1]); // codebook indices for blocks 0...3
+ data.val[1] = vuzp1q_u32(tmp.val[2], tmp.val[3]); // codebook indices for blocks 4...7
+ data.val[2] = vuzp2q_u32(tmp.val[0], tmp.val[1]); // scales and signs for blocks 0...3
+ data.val[3] = vuzp2q_u32(tmp.val[2], tmp.val[3]); // scales and signs for blocks 4...7
+ int32x4x2_t scales;
+ scales.val[0] = vreinterpretq_s32_u32(vorrq_u32(vshlq_n_u32(vshrq_n_u32(data.val[2], 28), 1), vdupq_n_u32(1)));
+ scales.val[1] = vreinterpretq_s32_u32(vorrq_u32(vshlq_n_u32(vshrq_n_u32(data.val[3], 28), 1), vdupq_n_u32(1)));
+ return scales;
+ }
+
+ inline void prepare(int /*i*/, int j) {
+ const uint8_t * idx = (const uint8_t *)(data.val + j);
+ bits.b1.val[0] = vcombine_s8(vld1_s8((const int8_t *)(iq2xxs_grid + idx[ 0])), vld1_s8((const int8_t *)(iq2xxs_grid + idx[ 1])));
+ bits.b1.val[1] = vcombine_s8(vld1_s8((const int8_t *)(iq2xxs_grid + idx[ 2])), vld1_s8((const int8_t *)(iq2xxs_grid + idx[ 3])));
+ bits.b1.val[2] = vcombine_s8(vld1_s8((const int8_t *)(iq2xxs_grid + idx[ 4])), vld1_s8((const int8_t *)(iq2xxs_grid + idx[ 5])));
+ bits.b1.val[3] = vcombine_s8(vld1_s8((const int8_t *)(iq2xxs_grid + idx[ 6])), vld1_s8((const int8_t *)(iq2xxs_grid + idx[ 7])));
+ bits.b2.val[0] = vcombine_s8(vld1_s8((const int8_t *)(iq2xxs_grid + idx[ 8])), vld1_s8((const int8_t *)(iq2xxs_grid + idx[ 9])));
+ bits.b2.val[1] = vcombine_s8(vld1_s8((const int8_t *)(iq2xxs_grid + idx[10])), vld1_s8((const int8_t *)(iq2xxs_grid + idx[11])));
+ bits.b2.val[2] = vcombine_s8(vld1_s8((const int8_t *)(iq2xxs_grid + idx[12])), vld1_s8((const int8_t *)(iq2xxs_grid + idx[13])));
+ bits.b2.val[3] = vcombine_s8(vld1_s8((const int8_t *)(iq2xxs_grid + idx[14])), vld1_s8((const int8_t *)(iq2xxs_grid + idx[15])));
+ const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
+ const uint32_t * sidx = (const uint32_t *)(data.val + 2 + j);
+ bits.b1.val[0] = vmulq_s8(bits.b1.val[0], vcombine_s8(vld1_s8((const int8_t *)(signs64 + ((sidx[0] >> 0) & 127))), vld1_s8((const int8_t *)(signs64 + ((sidx[0] >> 7) & 127)))));
+ bits.b1.val[1] = vmulq_s8(bits.b1.val[1], vcombine_s8(vld1_s8((const int8_t *)(signs64 + ((sidx[0] >> 14) & 127))), vld1_s8((const int8_t *)(signs64 + ((sidx[0] >> 21) & 127)))));
+ bits.b1.val[2] = vmulq_s8(bits.b1.val[2], vcombine_s8(vld1_s8((const int8_t *)(signs64 + ((sidx[1] >> 0) & 127))), vld1_s8((const int8_t *)(signs64 + ((sidx[1] >> 7) & 127)))));
+ bits.b1.val[3] = vmulq_s8(bits.b1.val[3], vcombine_s8(vld1_s8((const int8_t *)(signs64 + ((sidx[1] >> 14) & 127))), vld1_s8((const int8_t *)(signs64 + ((sidx[1] >> 21) & 127)))));
+ bits.b2.val[0] = vmulq_s8(bits.b2.val[0], vcombine_s8(vld1_s8((const int8_t *)(signs64 + ((sidx[2] >> 0) & 127))), vld1_s8((const int8_t *)(signs64 + ((sidx[2] >> 7) & 127)))));
+ bits.b2.val[1] = vmulq_s8(bits.b2.val[1], vcombine_s8(vld1_s8((const int8_t *)(signs64 + ((sidx[2] >> 14) & 127))), vld1_s8((const int8_t *)(signs64 + ((sidx[2] >> 21) & 127)))));
+ bits.b2.val[2] = vmulq_s8(bits.b2.val[2], vcombine_s8(vld1_s8((const int8_t *)(signs64 + ((sidx[3] >> 0) & 127))), vld1_s8((const int8_t *)(signs64 + ((sidx[3] >> 7) & 127)))));
+ bits.b2.val[3] = vmulq_s8(bits.b2.val[3], vcombine_s8(vld1_s8((const int8_t *)(signs64 + ((sidx[3] >> 14) & 127))), vld1_s8((const int8_t *)(signs64 + ((sidx[3] >> 21) & 127)))));
+ //auto mask = vdupq_n_u32(127);
+ //uint32x4_t sindex;
+ //sindex = vandq_u32(data.val[2+j], mask);
+ //mask = vshlq_n_u32(mask, 7);
+ //sindex = vorrq_u32(sindex, vshlq_n_u32(vandq_u32(data.val[2+j], mask), 1));
+ //mask = vshlq_n_u32(mask, 7);
+ //sindex = vorrq_u32(sindex, vshlq_n_u32(vandq_u32(data.val[2+j], mask), 2));
+ //mask = vshlq_n_u32(mask, 7);
+ //sindex = vorrq_u32(sindex, vshlq_n_u32(vandq_u32(data.val[2+j], mask), 3));
+ //const uint8_t * sidx = (const uint8_t *)&sindex;
+ //bits.b1.val[0] = vmulq_s8(bits.b1.val[0], vcombine_s8(vld1_s8((const int8_t *)(signs64 + sidx[ 0])), vld1_s8((const int8_t *)(signs64 + sidx[ 1]))));
+ //bits.b1.val[1] = vmulq_s8(bits.b1.val[1], vcombine_s8(vld1_s8((const int8_t *)(signs64 + sidx[ 2])), vld1_s8((const int8_t *)(signs64 + sidx[ 3]))));
+ //bits.b1.val[2] = vmulq_s8(bits.b1.val[2], vcombine_s8(vld1_s8((const int8_t *)(signs64 + sidx[ 4])), vld1_s8((const int8_t *)(signs64 + sidx[ 5]))));
+ //bits.b1.val[3] = vmulq_s8(bits.b1.val[3], vcombine_s8(vld1_s8((const int8_t *)(signs64 + sidx[ 6])), vld1_s8((const int8_t *)(signs64 + sidx[ 7]))));
+ //bits.b2.val[0] = vmulq_s8(bits.b2.val[0], vcombine_s8(vld1_s8((const int8_t *)(signs64 + sidx[ 8])), vld1_s8((const int8_t *)(signs64 + sidx[ 9]))));
+ //bits.b2.val[1] = vmulq_s8(bits.b2.val[1], vcombine_s8(vld1_s8((const int8_t *)(signs64 + sidx[10])), vld1_s8((const int8_t *)(signs64 + sidx[11]))));
+ //bits.b2.val[2] = vmulq_s8(bits.b2.val[2], vcombine_s8(vld1_s8((const int8_t *)(signs64 + sidx[12])), vld1_s8((const int8_t *)(signs64 + sidx[13]))));
+ //bits.b2.val[3] = vmulq_s8(bits.b2.val[3], vcombine_s8(vld1_s8((const int8_t *)(signs64 + sidx[14])), vld1_s8((const int8_t *)(signs64 + sidx[15]))));
+ }
+
+ uint32x4x4_t data;
+ struct Bits {
+ uint8x16x4_t b1;
+ uint8x16x4_t b2;
+ };
+ Bits bits;
+
+ float d;
+};
+
template <int nrc_y, typename Dequantizer>
-static void mul_mat_qX_K_q8_K_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+void mul_mat_qX_K_q8_K_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
assert(n % QK_K == 0);
const int nb = n / QK_K;
@@ -2466,6 +2574,9 @@ bool MulMat::set_mul_mat(int typeA, int ne00, MulMat& m, int& row_size_q8, int /
case GGML_TYPE_IQ4_XS:
MulMat::set_functions<DequantizerIQ4XS>(m);
break;
+ case GGML_TYPE_IQ2_XXS:
+ MulMat::set_functions<DequantizerIQ2XXS>(m);
+ break;
case GGML_TYPE_Q4_0:
MulMat::set_functions<DequantizerQ40>(m);
row_size_q8 = ggml_row_size(GGML_TYPE_Q8_0, ne00);