author     Iwan Kawrakow <iwan.kawrakow@gmail.com>   2024-07-17 08:54:11 +0300
committer  Iwan Kawrakow <iwan.kawrakow@gmail.com>   2024-07-17 08:54:11 +0300
commit     873a790ee22538d1d9d7205db7210c70955ab1e1 (patch)
tree       426f0ce20cf45c325be28ae8bceb51d42c072452
parent     52a25e307c3af8686436d977c60e9975b0900e2b (diff)
iq1bn(no lookup): better version
We have 4 groups of 16 in a block of 64 quants. For each group of 16 we have 3 groups of 5, each encoded in 8 bits. The remaining 16th quants of the 4 groups are encoded together with 8 bits (the extra byte), using the same encoding as the groups of 5. The only kernel where this gets complicated is the CUDA dequantize kernel, because it dequantizes 8 quants at a time and the 1st and 2nd group of 8 within a group of 16 use different encodings.

This achieves better performance on all tested platforms than any previous 1.625 bpw attempt. We have:

| model            | size       | params | backend | threads |  test |             t/s |
| ---------------- | ---------: | -----: | ------- | ------: | ----: | --------------: |
| 1.625 bpw Bitnet | 729.64 MiB | 3.32 B | CUDA    |       8 | pp512 | 9613.02 ± 24.54 |
| 1.625 bpw Bitnet | 729.64 MiB | 3.32 B | CUDA    |       8 | tg128 |   229.85 ± 0.33 |
| 1.625 bpw Bitnet | 729.64 MiB | 3.32 B | AVX2    |      16 | pp512 |   322.59 ± 1.00 |
| 1.625 bpw Bitnet | 729.64 MiB | 3.32 B | AVX2    |      16 | tg128 |    59.79 ± 0.03 |
| 1.625 bpw Bitnet | 729.64 MiB | 3.32 B | AVX2    |       8 | tg128 |    57.62 ± 0.21 |
| 1.625 bpw Bitnet | 729.64 MiB | 3.32 B | AVX2    |       4 | tg128 |    33.66 ± 0.29 |
| 1.625 bpw Bitnet | 729.64 MiB | 3.32 B | AVX2    |       2 | tg128 |    18.30 ± 0.01 |
| 1.625 bpw Bitnet | 729.64 MiB | 3.32 B | Metal   |       8 | pp512 |   698.13 ± 0.21 |
| 1.625 bpw Bitnet | 729.64 MiB | 3.32 B | Metal   |       8 | tg128 |    68.88 ± 0.24 |
| 1.625 bpw Bitnet | 729.64 MiB | 3.32 B | NEON    |       8 | pp512 |   196.80 ± 0.50 |
| 1.625 bpw Bitnet | 729.64 MiB | 3.32 B | NEON    |       8 | tg128 |    51.58 ± 0.41 |
| 1.625 bpw Bitnet | 729.64 MiB | 3.32 B | NEON    |       4 | tg128 |    30.80 ± 0.03 |
| 1.625 bpw Bitnet | 729.64 MiB | 3.32 B | NEON    |       2 | tg128 |    16.89 ± 0.01 |

It is still slower than 2 bpw Bitnet, but the difference now is not as dramatic.
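For reference, a minimal standalone sketch of the encode/decode round trip described above (not part of the patch; `pack5` and `unpack1` are illustrative names). It mirrors `quantize_one_row_1bn` and the dequantize kernels in the diff: 5 ternary quants go into one byte as ceil(256*idx/243) with idx the base-3 index, and digit j comes back via the multipliers {81, 27, 9, 3, 1} and the `(v + (v >> 1)) >> 7` (equivalently `3*v >> 8`) trick.

```cpp
// Standalone sketch of the iq1_bn per-block packing: 64 ternary quants in 13 bytes.
// ql[12] holds 4 groups of 16 as 3 bytes per group (5 quants per byte); 'extra'
// holds the four remaining 16th quants, base-3 packed the same way.
#include <cstdint>
#include <cstring>
#include <cstdio>

static const uint8_t k_mult[5] = {81, 27, 9, 3, 1};      // 3^4 ... 3^0

// Pack n <= 5 quants in {-1, 0, 1} into one byte: ceil(256*idx/243), idx in base 3.
static uint8_t pack5(const int8_t * q, int n) {
    int idx = 0, pow3 = 1;
    for (int j = 0; j < n; ++j) { idx += (q[j] + 1)*pow3; pow3 *= 3; }
    return uint8_t((256*idx + 242)/243);                  // <= 255, fits in a byte
}

// Recover digit j of a packed byte and map it back to {-1, 0, 1}.
static int8_t unpack1(uint8_t b, int j) {
    uint8_t v = uint8_t(k_mult[j]*b);                     // (3^(4-j) * b) mod 256
    return int8_t(((v + (v >> 1)) >> 7) - 1);             // floor(3*v/256) - 1
}

int main() {
    int8_t src[64], dst[64];
    for (int i = 0; i < 64; ++i) src[i] = int8_t(i % 3 - 1);   // some ternary data

    // Encode: 12 ql bytes plus the extra byte for the four 16th quants.
    uint8_t ql[12], extra;
    int8_t sixteenth[4];
    for (int i16 = 0; i16 < 4; ++i16) {
        for (int k = 0; k < 3; ++k) ql[3*i16 + k] = pack5(src + 16*i16 + 5*k, 5);
        sixteenth[i16] = src[16*i16 + 15];
    }
    extra = pack5(sixteenth, 4);

    // Decode and check the round trip.
    for (int i16 = 0; i16 < 4; ++i16) {
        for (int k = 0; k < 3; ++k)
            for (int j = 0; j < 5; ++j) dst[16*i16 + 5*k + j] = unpack1(ql[3*i16 + k], j);
        dst[16*i16 + 15] = unpack1(extra, i16);
    }
    printf("round trip %s\n", memcmp(src, dst, 64) ? "failed" : "ok");
}
```

Since 3^5 = 243 <= 256, five ternary quants fit losslessly in one byte, which is why no lookup table is needed.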
-rw-r--r--   ggml-common.h           5
-rw-r--r--   ggml-cuda/convert.cu   26
-rw-r--r--   ggml-cuda/vecdotq.cuh  76
-rw-r--r--   ggml-metal.metal      117
-rw-r--r--   iqk-quantize.cpp      120
-rw-r--r--   iqk_mul_mat.cpp       150
6 files changed, 242 insertions, 252 deletions
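As a quick sanity check of the layout (illustrative only; the real definition is the ggml-common.h hunk below), the asserts spell out where the 1.625 bpw figure comes from:

```cpp
// Restatement of the new block layout with the bpw arithmetic made explicit.
#include <cstdint>

#define QK_IQ1BN 64
typedef struct {
    uint8_t ql[12];  // 4 groups of 16 quants: 3 bytes per group, 5 quants per byte
    uint8_t extra;   // the four 16th quants, base-3 packed the same way
} block_iq1_bn;

static_assert(sizeof(block_iq1_bn) == 13, "wrong iq1_bn block size/padding");
static_assert(8.0 * sizeof(block_iq1_bn) / QK_IQ1BN == 1.625, "13 bytes per 64 quants = 1.625 bpw");
```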
diff --git a/ggml-common.h b/ggml-common.h
index bf95da2a..f515e95c 100644
--- a/ggml-common.h
+++ b/ggml-common.h
@@ -380,11 +380,10 @@ static_assert(sizeof(block_iq1_m) == QK_K/8 + QK_K/16 + QK_K/32, "wrong iq1_m bl
//
#define QK_IQ1BN 64
typedef struct {
+ uint8_t ql[12];
uint8_t extra;
- uint8_t ql[QK_IQ1BN/8];
- uint8_t qh[QK_IQ1BN/16];
} block_iq1_bn;
-static_assert(sizeof(block_iq1_bn) == sizeof(uint8_t) + QK_IQ1BN/8 + QK_IQ1BN/16, "wrong iq1_bn block size/padding");
+static_assert(sizeof(block_iq1_bn) == 13, "wrong iq1_bn block size/padding");
//
// Bitnet - implemented as 2.25 bpw
//
diff --git a/ggml-cuda/convert.cu b/ggml-cuda/convert.cu
index a0a826e9..4a67c498 100644
--- a/ggml-cuda/convert.cu
+++ b/ggml-cuda/convert.cu
@@ -425,7 +425,10 @@ static __global__ void dequantize_block_iq1_bn(const void * __restrict__ vx, dst
const int64_t ii = blockIdx.x;
const block_iq1_bn * x = (const block_iq1_bn *) vx;
- static const uint16_t k_mult[8] = {2187, 729, 243, 81, 27, 9, 3, 1};
+ static const uint8_t k_mult[5] = {81, 27, 9, 3, 1};
+
+//#define COMPUTE_VS(v) 3*v >> 8
+#define COMPUTE_VS(v) (v + (v >> 1)) >> 7
const int tid = threadIdx.x;
const int il = tid/4; // 0...7
@@ -433,11 +436,24 @@ static __global__ void dequantize_block_iq1_bn(const void * __restrict__ vx, dst
dst_t * y = yy + ii*QK_K + 64*ib + 8*il;
int64_t i = QK_K/QK_IQ1BN * ii + ib;
if (i >= nb64) return;
- uint16_t val = x[i].ql[il] | ((x[i].qh[il%4] << (8 - 4*(il/4))) & 0x0f00) | ((x[i].extra << (12 - il)) & 4096);
- for (int j = 0; j < 8; ++j) {
- uint16_t v = (val*k_mult[j] & 0x1fff)*3 >> 13;
- y[j] = v - 1;
+ const int i16 = il/2;
+ uint8_t q = x[i].ql[3*i16+2*(il%2)];
+ for (int j = 0; j < 5; ++j) {
+ uint8_t v = k_mult[j]*q;
+ int8_t vs = COMPUTE_VS(v);
+ y[2*(il%2)+j] = vs - 1;
}
+ q = x[i].ql[3*i16+1];
+ for (int j = 0; j < 2; ++j) {
+ uint8_t v = k_mult[3*(il%2)+j]*q;
+ int8_t vs = COMPUTE_VS(v);
+ y[5*(1-(il%2))+j] = vs-1;
+ }
+ uint8_t v = (il%2) ? k_mult[i16]*x[i].extra : k_mult[2]*q;
+ int8_t vs = COMPUTE_VS(v);
+ y[7] = vs - 1;
+
+#undef COMPUTE_VS
}
template<typename dst_t>
diff --git a/ggml-cuda/vecdotq.cuh b/ggml-cuda/vecdotq.cuh
index d0d2e923..f0133e07 100644
--- a/ggml-cuda/vecdotq.cuh
+++ b/ggml-cuda/vecdotq.cuh
@@ -1078,67 +1078,47 @@ static __device__ __forceinline__ float vec_dot_iq1_bn_q8_1(
const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
const block_iq1_bn * bq1 = (const block_iq1_bn *) vbq + kbx;
- static const int16_t k_mult[8] = {2187, 729, 243, 81, 27, 9, 3, 1};
+ static const uint8_t k_mult[5] = {81, 27, 9, 3, 1};
// iqs is 0 or 1
- uint8_t extra = bq1->extra >> 4*iqs;
int sumi = 0;
#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
const int * q8 = (const int *)bq8_1[iqs].qs;
- //int v[2];
- //int8_t * a = (int8_t *)v;
- //for (int l = 0; l < 2; ++l) {
- // int16_t val1 = bq1->ql[4*iqs + 2*l+0] | ((bq1->qh[2*l+0] << (8-4*iqs)) & 0x0f00) | ((extra << 12) & 4096);
- // int16_t val2 = bq1->ql[4*iqs + 2*l+1] | ((bq1->qh[2*l+1] << (8-4*iqs)) & 0x0f00) | ((extra << 11) & 4096);
- // for (int k = 0; k < 8; ++k) a[k] = ((val1*k_mult[k] & 0x1fff)*3 >> 13) - 1;
- // sumi = __dp4a(v[0], q8[4*l+0], __dp4a(v[1], q8[4*l+1], sumi));
- // for (int k = 0; k < 8; ++k) a[k] = ((val2*k_mult[k] & 0x1fff)*3 >> 13) - 1;
- // sumi = __dp4a(v[0], q8[4*l+2], __dp4a(v[1], q8[4*l+3], sumi));
- // extra >>= 2;
- //}
-
- int v[4];
- int8_t * a = (int8_t *)v;
+ int val[4];
for (int l = 0; l < 2; ++l) {
- int16_t val1 = bq1->ql[4*iqs + 2*l+0] | ((bq1->qh[2*l+0] << (8-4*iqs)) & 0x0f00) | ((extra << 12) & 4096);
- int16_t val2 = bq1->ql[4*iqs + 2*l+1] | ((bq1->qh[2*l+1] << (8-4*iqs)) & 0x0f00) | ((extra << 11) & 4096);
- for (int k = 0; k < 8; ++k) {
- a[k+0] = ((val1*k_mult[k] & 0x1fff)*3 >> 13) - 1;
- a[k+8] = ((val2*k_mult[k] & 0x1fff)*3 >> 13) - 1;
+ int8_t * a = (int8_t *)val;
+ const int i16 = 2*iqs + l;
+ for (int k = 0; k < 3; ++k) {
+ uint8_t q = bq1->ql[3*i16+k];
+ for (int j = 0; j < 5; ++j) {
+ uint8_t v = k_mult[j]*q;
+ int8_t vs = 3*v >> 8; //(v + (v >> 1)) >> 7;
+ *a++ = vs-1;
+ }
}
- sumi = __dp4a(v[0], q8[4*l+0], __dp4a(v[1], q8[4*l+1], __dp4a(v[2], q8[4*l+2], __dp4a(v[3], q8[4*l+3], sumi))));
- extra >>= 2;
+ uint8_t v = k_mult[i16]*bq1->extra;
+ int8_t vs = 3*v >> 8; //(v + (v >> 1)) >> 7;
+ *a++ = vs-1;
+ sumi = __dp4a(val[0], q8[4*l+0], __dp4a(val[1], q8[4*l+1], __dp4a(val[2], q8[4*l+2], __dp4a(val[3], q8[4*l+3], sumi))));
}
-
- //int v[8];
- //int8_t * a = (int8_t *)v;
- //int16_t val1 = bq1->ql[4*iqs + 0] | ((bq1->qh[0] << (8-4*iqs)) & 0x0f00) | ((extra << 12) & 4096);
- //int16_t val2 = bq1->ql[4*iqs + 1] | ((bq1->qh[1] << (8-4*iqs)) & 0x0f00) | ((extra << 11) & 4096);
- //int16_t val3 = bq1->ql[4*iqs + 2] | ((bq1->qh[2] << (8-4*iqs)) & 0x0f00) | ((extra << 10) & 4096);
- //int16_t val4 = bq1->ql[4*iqs + 3] | ((bq1->qh[3] << (8-4*iqs)) & 0x0f00) | ((extra << 9) & 4096);
- //for (int k = 0; k < 8; ++k) {
- // a[k+ 0] = ((val1*k_mult[k] & 0x1fff)*3 >> 13) - 1;
- // a[k+ 8] = ((val2*k_mult[k] & 0x1fff)*3 >> 13) - 1;
- // a[k+16] = ((val3*k_mult[k] & 0x1fff)*3 >> 13) - 1;
- // a[k+24] = ((val4*k_mult[k] & 0x1fff)*3 >> 13) - 1;
- //}
- //sumi = __dp4a(v[0], q8[0], __dp4a(v[1], q8[1], __dp4a(v[2], q8[2], __dp4a(v[3], q8[3], sumi))));
- //sumi = __dp4a(v[4], q8[4], __dp4a(v[5], q8[5], __dp4a(v[6], q8[6], __dp4a(v[7], q8[7], sumi))));
#else
const int8_t * q8 = bq8_1[iqs].qs;
for (int l = 0; l < 2; ++l) {
- int val1 = bq1->ql[4*iqs + 2*l+0] | ((bq1->qh[2*l+0] << (8-4*iqs)) & 0x0f00) | ((extra << 12) & 4096);
- int val2 = bq1->ql[4*iqs + 2*l+1] | ((bq1->qh[2*l+1] << (8-4*iqs)) & 0x0f00) | ((extra << 11) & 4096);
- for (int k = 0; k < 4; ++k) {
- int v1 = (val1*k_mult[k+0] & 0x1fff)*3 >> 13;
- int v2 = (val1*k_mult[k+4] & 0x1fff)*3 >> 13;
- int v3 = (val2*k_mult[k+0] & 0x1fff)*3 >> 13;
- int v4 = (val2*k_mult[k+4] & 0x1fff)*3 >> 13;
- sumi += (v1 - 1)*q8[k+0] + (v2 - 1)*q8[k+4] + (v3 - 1)*q8[k+8] + (v4 - 1)*q8[k+12];
+ const int i16 = 2*iqs + l;
+ for (int k = 0; k < 3; ++k) {
+ uint8_t q = bq1->ql[3*i16+k];
+ for (int j = 0; j < 5; ++j) {
+ uint8_t v = k_mult[j]*q;
+ int8_t vs = (v + (v >> 1)) >> 7;
+ sumi += q8[j]*(vs - 1);
+ }
+ q8 += 5;
}
- q8 += 16;
- extra >>= 2;
+ uint8_t v = k_mult[i16]*bq1->extra;
+ int8_t vs = (v + (v >> 1)) >> 7;
+ sumi += q8[0]*(vs - 1);
+ q8++;
}
#endif
return __low2float(bq8_1[iqs].ds) * sumi;
diff --git a/ggml-metal.metal b/ggml-metal.metal
index cca4b0e9..2bf6894a 100644
--- a/ggml-metal.metal
+++ b/ggml-metal.metal
@@ -5054,6 +5054,49 @@ static inline float iq1bn_fp8_to_float(uint8_t fp8) {
return s.f;
}
+//static constant int8_t iq1bn_values[256*5] = {
+// -1, -1, -1, -1, -1, 0, 0, 0, 0, 0, 0, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, 0, -1, -1, -1, 0, 0, -1, -1, -1, 1, 0,
+// -1, -1, -1, -1, 1, -1, -1, -1, 0, 1, -1, -1, -1, 1, 1, -1, -1, -1, -1, -1, 0, -1, -1, 0, -1, 0, -1, -1, 1, -1, 0, -1,
+// -1, -1, 0, 0, -1, -1, 0, 0, 0, -1, -1, 1, 0, 0, -1, -1, -1, 1, 0, -1, -1, 0, 1, 0, -1, -1, 1, 1, 0, -1, -1, -1,
+// -1, 1, -1, -1, 0, 0, 0, 0, 0, 0, -1, 1, -1, -1, 1, -1, 1, -1, -1, -1, 0, 1, -1, -1, 0, 0, 1, -1, -1, 1, 0, 1,
+// -1, -1, -1, 1, 1, -1, -1, 0, 1, 1, -1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 0, -1, 0, -1, -1, 0, -1, 1, -1, -1, 0, -1,
+// -1, 0, -1, 0, -1, 0, 0, -1, 0, -1, 1, 0, -1, 0, -1, -1, 1, -1, 0, -1, 0, 1, -1, 0, -1, 1, 1, -1, 0, -1, -1, -1,
+// 0, 0, -1, 0, -1, 0, 0, -1, 0, 0, 0, 0, 0, 1, -1, 0, 0, -1, -1, 0, 0, 0, -1, 0, 0, 0, 0, -1, 1, 0, 0, 0,
+// -1, -1, 1, 0, 0, -1, 0, 1, 0, 0, -1, 1, 1, 0, 0, -1, -1, -1, 1, 0, -1, 0, -1, 1, 0, -1, 1, -1, 1, 0, -1, -1,
+// 0, 1, 0, -1, 0, 0, 1, 0, -1, 1, 0, 1, 0, -1, -1, 1, 1, 0, -1, 0, 1, 1, 0, -1, 1, 1, 1, 0, -1, -1, -1, -1,
+// 1, -1, 0, -1, -1, 1, -1, 1, -1, -1, 1, -1, 0, 0, 0, 0, 0, -1, 0, -1, 1, -1, 0, 0, -1, 1, -1, 1, 0, -1, 1, -1,
+// -1, 1, -1, 1, -1, 0, 1, -1, 1, -1, 1, 1, -1, 1, -1, -1, -1, 0, 1, -1, 0, -1, 0, 1, -1, 1, -1, 0, 1, -1, -1, 0,
+// 0, 1, -1, 0, 0, 0, 1, -1, 1, 0, 0, 1, -1, -1, 1, 0, 1, -1, 0, 1, 0, 1, -1, 1, 1, 0, 1, -1, -1, -1, 1, 1,
+// -1, 0, -1, 1, 1, -1, 1, -1, 1, 1, -1, 0, 0, 0, 0, 0, -1, 0, 1, 1, -1, 0, 0, 1, 1, -1, 1, 0, 1, 1, -1, -1,
+// 1, 1, 1, -1, 0, 1, 1, 1, -1, 1, 1, 1, 1, -1, -1, -1, -1, -1, 0, 0, -1, -1, -1, 0, 1, -1, -1, -1, 0, -1, 0, -1,
+// -1, 0, 0, 0, -1, -1, 0, 1, 0, -1, -1, 0, -1, 1, -1, -1, 0, 0, 1, -1, -1, 0, 1, 1, -1, -1, 0, -1, -1, 0, -1, 0,
+// 0, -1, 0, -1, 0, 1, -1, 0, -1, 0, -1, 0, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, 0, 1, 0, 0, -1, 0, -1, 1,
+// 0, -1, 0, 0, 1, 0, -1, 0, 1, 1, 0, -1, 0, -1, -1, 1, -1, 0, 0, -1, 1, -1, 0, 1, -1, 1, -1, 0, -1, 0, 1, -1,
+// 0, 0, 0, 1, -1, 0, 1, 0, 1, -1, 0, -1, 1, 1, -1, 0, 0, 1, 1, -1, 0, 1, 1, 1, -1, 0, -1, -1, -1, 0, 0, 0,
+// -1, -1, 0, 0, 1, -1, -1, 0, 0, -1, 0, -1, 0, 0, 0, 0, -1, 0, 0, 0, 0, 0, 0, 0, 1, 0, -1, 0, 0, -1, 1, -1,
+// 0, 0, 0, 1, -1, 0, 0, 1, 1, -1, 0, 0, -1, -1, 0, 0, 0, 0, -1, 0, 0, 0, 1, -1, 0, 0, 0, -1, 0, 0, 0, 0,
+// 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, -1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, -1, -1, 1, 0, 0, 0, -1,
+// 1, 0, 0, 1, -1, 1, 0, 0, -1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, -1, 1, 1, 0,
+// 0, 0, 1, 1, 0, 0, 1, 1, 1, 0, 0, -1, -1, -1, 1, 0, 0, -1, -1, 1, 0, 1, -1, -1, 1, 0, -1, 0, -1, 1, 0, 0,
+// 0, -1, 1, 0, 1, 0, -1, 1, 0, -1, 1, -1, 1, 0, 0, 1, -1, 1, 0, 1, 1, -1, 1, 0, -1, -1, 0, 1, 0, 0, -1, 0,
+// 1, 0, 1, -1, 0, 1, 0, -1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, -1, 1, 0, 1, 0,
+// 0, 1, 0, 1, 0, 1, 1, 0, 1, 0, -1, -1, 1, 1, 0, 0, -1, 1, 1, 0, 1, -1, 1, 1, 0, -1, 0, 1, 1, 0, 0, 0,
+// 1, 1, 0, 1, 0, 1, 1, 0, -1, 1, 1, 1, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1, 0, -1, -1, -1, -1, 1, 0, -1, -1, -1,
+// 1, 1, -1, -1, -1, 1, -1, 0, -1, -1, 1, 0, 0, -1, -1, 1, 1, 0, -1, -1, 1, -1, 1, -1, -1, 1, 0, 0, 0, 0, 0, 0,
+// 1, -1, -1, 1, 1, 1, -1, -1, 1, -1, -1, 0, -1, 1, 0, -1, 0, -1, 1, 1, -1, 0, -1, 1, -1, 0, 0, -1, 1, 0, 0, 0,
+// -1, 1, 1, 0, 0, -1, 1, -1, 1, 0, -1, 1, 0, 1, 0, -1, 1, 1, 1, 0, -1, 1, -1, -1, 1, -1, 1, 0, -1, 1, -1, 1,
+// 1, -1, 1, -1, 1, -1, 0, 1, -1, 1, 0, 0, 1, -1, 1, 1, 0, 1, -1, 1, -1, 1, 1, -1, 1, 0, 0, 0, 0, 0, 0, 1,
+// 1, -1, 1, 1, 1, 1, -1, 1, -1, -1, -1, 0, 1, 0, -1, -1, 0, 1, 1, -1, -1, 0, 1, -1, 0, -1, 0, 1, 0, 0, -1, 0,
+// 1, 1, 0, -1, 0, 1, -1, 1, -1, 0, 1, 0, 1, -1, 0, 1, 1, 1, -1, 0, 1, -1, -1, 0, 0, 1, 0, -1, 0, 0, 1, 1,
+// -1, 0, 0, 1, -1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 1, -1, 1, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0,
+// 0, 0, 1, 1, 0, 0, 1, -1, -1, 1, 0, 1, 0, -1, 1, 0, 1, 1, -1, 1, 0, 1, -1, 0, 1, 0, 1, 0, 0, 1, 0, 1,
+// 1, 0, 1, 0, 1, -1, 1, 1, 0, 1, 0, 1, 1, 0, 1, 1, 1, 1, 0, 1, -1, -1, -1, 1, 1, 0, -1, -1, 1, 1, 1, -1,
+// -1, 1, 1, -1, 0, -1, 1, 1, 0, 0, -1, 1, 1, 1, 0, -1, 1, 1, -1, 1, -1, 1, 1, 0, 1, -1, 1, 1, 1, 1, -1, 1,
+// 1, 0, 0, 0, 0, 0, -1, -1, 0, 1, 1, 0, -1, 0, 1, 1, 1, -1, 0, 1, 1, -1, 0, 0, 1, 1, 0, 0, 0, 1, 1, 1,
+// 0, 0, 1, 1, -1, 1, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1, 0, 1, 1, -1, -1, 1, 1, 1, 0, -1, 1, 1, 1, 1, -1, 1,
+// 1, 1, -1, 0, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 1, 1, 1, -1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1,
+//};
+
void kernel_mul_mv_iq1_bn_f32_impl(
device const void * src0,
device const float * src1,
@@ -5087,53 +5130,62 @@ void kernel_mul_mv_iq1_bn_f32_impl(
device const block_iq1_bn * x = (device const block_iq1_bn *) src0 + ib_row + offset0;
device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
- float4 yl[2];
- float sumf[N_DST]={0.f}, all_sum;
+ float yl[16];
+ float sumf[N_DST]={0.f};
const int nb32 = nb * (QK_IQ1BN / 32);
- const int ix = tiisg/4;
- const int ir = tiisg%4;
+ const int ix = tiisg/2;
+ const int ir = tiisg%2;
- device const float4 * y4 = (device const float4 *)y + 8 * ix + 2 * ir;
+ device const float * y4 = (device const float *)y + 32 * ix + 16 * ir;
uint32_t aux32[2];
thread const uint8_t * aux8 = (thread const uint8_t *)aux32;
const float values[3] = {-1.f, 0.f, 1.f};
- constexpr int16_t k_mult[8] = {2187, 729, 243, 81, 27, 9, 3, 1};
+ constexpr uint8_t k_mult[5] = {81, 27, 9, 3, 1};
- for (int ib32 = ix; ib32 < nb32; ib32 += 8) {
+ for (int ib32 = ix; ib32 < nb32; ib32 += 16) {
- yl[0] = y4[0]; yl[1] = y4[1];
+ for (int j = 0; j < 16; ++j) yl[j] = y4[j];
const int ibl = ib32 / (QK_IQ1BN / 32);
const int ib = ib32 % (QK_IQ1BN / 32);
- const int il = 4*ib + ir;
+ const int i16 = 2*ib + ir;
device const block_iq1_bn * xr = x + ibl;
+ device const uint8_t * ql = xr->ql + 3*i16;
device const uint8_t * extra = (device const uint8_t *)&xr->extra;
- device const uint8_t * ql = xr->ql + il;
- device const uint8_t * qh = xr->qh + il%4;
for (int row = 0; row < N_DST; row++) {
- uint8_t h = extra[0] >> il;
+ float acc = 0;
+ int i = 0;
+ for (int k = 0; k < 3; ++k) {
+ //constant int8_t * vs = iq1bn_values + 5*ql[k];
+ //for (int j = 0; j < 5; ++j) acc += yl[5*k+j]*vs[j];
+ uint8_t q = ql[k];
+ for (int j = 0; j < 5; ++j) {
+ uint8_t v = k_mult[j]*q;
+ v = 3*v >> 8; //(v + (v >> 1)) >> 7;
+ acc += yl[i++] * values[v];
+ }
+ }
+ //constant int8_t * vs = iq1bn_values + 5*extra[0];
+ //acc += yl[15] * vs[i16];
+ uint8_t v = k_mult[i16]*extra[0];
+ v = 3*v >> 8; //(v + (v >> 1)) >> 7;
+ acc += yl[15] * values[v];
- int16_t val = ql[0] | ((qh[0] << (8 - 4*(il/4))) & 0x0f00) | ((extra[0] << (12 - il)) & 4096);
- float4 acc4 = yl[0] * float4{values[(val*k_mult[0] & 0x1fff)*3 >> 13], values[(val*k_mult[1] & 0x1fff)*3 >> 13],
- values[(val*k_mult[2] & 0x1fff)*3 >> 13], values[(val*k_mult[3] & 0x1fff)*3 >> 13]}
- + yl[1] * float4{values[(val*k_mult[4] & 0x1fff)*3 >> 13], values[(val*k_mult[5] & 0x1fff)*3 >> 13],
- values[(val*k_mult[6] & 0x1fff)*3 >> 13], values[(val*k_mult[7] & 0x1fff)*3 >> 13]};
- sumf[row] += acc4[0] + acc4[1] + acc4[2] + acc4[3];
+ sumf[row] += acc;
extra += nb*sizeof(block_iq1_bn);
ql += nb*sizeof(block_iq1_bn);
- qh += nb*sizeof(block_iq1_bn);
}
- y4 += 32 * 2;
+ y4 += 32 * 16;
}
for (int row = 0; row < N_DST; row += 2) {
@@ -5990,18 +6042,23 @@ void dequantize_iq1_m(device const block_iq1_m * xb, short il, thread type4x4 &
template <typename type4x4>
void dequantize_iq1_bn(device const block_iq1_bn * xb, short il, thread type4x4 & reg) {
// il is in 0...3
- uint8_t gs = xb->extra >> 2*il;
-
- constexpr int16_t k_mult[8] = {2187, 729, 243, 81, 27, 9, 3, 1};
-
- short il1 = 2*il+0, il2 = 2*il+1;
- int16_t v1 = xb->ql[il1] | ((xb->qh[il1%4] << (8 - 4*(il1/4))) & 0x0f00) | ((gs << 12) & 4096);
- int16_t v2 = xb->ql[il2] | ((xb->qh[il2%4] << (8 - 4*(il2/4))) & 0x0f00) | ((gs << 11) & 4096);
- for (int i = 0; i < 8; ++i) {
- reg[i/4+0][i%4] = ((v1*k_mult[i] & 0x1fff)*3 >> 13) - 1;
- reg[i/4+2][i%4] = ((v2*k_mult[i] & 0x1fff)*3 >> 13) - 1;
+ constexpr uint8_t k_mult[5] = {81, 27, 9, 3, 1};
+
+ int i = 0;
+ for (int k = 0; k < 3; ++k) {
+ uint8_t q = xb->ql[3*il + k];
+ for (int j = 0; j < 5; ++j) {
+ uint8_t v = k_mult[j]*q;
+ int8_t vs = 3*v >> 8;
+ //int8_t vs = (v + (v >> 1)) >> 7;
+ reg[i/4][i%4] = vs - 1;
+ ++i;
+ }
}
+ uint8_t v = k_mult[il]*xb->extra;
+ int8_t vs = 3*v >> 8; //(v + (v >> 1)) >> 7;
+ reg[3][3] = vs - 1;
}
template <typename type4x4>
diff --git a/iqk-quantize.cpp b/iqk-quantize.cpp
index 03d73ff0..f25f2031 100644
--- a/iqk-quantize.cpp
+++ b/iqk-quantize.cpp
@@ -118,7 +118,7 @@ uint16_t IQ1BNQuantizer::quantize_one_block_1bn(const IQ1BNData& iq1bn, const fl
void IQ1BNQuantizer::quantize_one_row_1bn(const float * src, block_iq1_bn * y, int n_per_row, const float * imatrix) {
- static const int k_nb[8] = {1, 3, 9, 27, 81, 243, 729, 2187};
+ static const int k_nb[6] = {1, 3, 9, 27, 81, 243};
(void)imatrix;
const int nblock = n_per_row/QK_IQ1BN;
@@ -126,21 +126,24 @@ void IQ1BNQuantizer::quantize_one_row_1bn(const float * src, block_iq1_bn * y, i
for (int ib = 0; ib < nblock; ++ib) {
std::memset(&y[ib], 0, sizeof(block_iq1_bn));
auto xb = src + ib*QK_IQ1BN;
- for (int i = 0; i < QK_IQ1BN/8; ++i) {
- int idx = 0;
- for (int j = 0; j < 8; ++j) {
- float v = xb[8*i + j];
- int q = fabsf(v) < 1e-6f ? 1 : v < 0 ? 0 : 2;
- idx += k_nb[j]*q;
+ int v13 = 0;
+ for (int i16 = 0; i16 < QK_IQ1BN/16; ++i16) {
+ for (int k = 0; k < 3; ++k) {
+ int idx = 0;
+ for (int j = 0; j < 5; ++j) {
+ float v = xb[16*i16 + 5*k + j];
+ int q = fabsf(v) < 1e-6f ? 1 : v < 0 ? 0 : 2;
+ idx += k_nb[j]*q;
+ }
+ idx = (256*idx + k_nb[5] - 1)/k_nb[5];
+ y[ib].ql[3*i16 + k] = idx;
}
- idx = (8192*idx + 6560)/6561;
- y[ib].ql[i] = idx & 255;
- y[ib].qh[i%4] |= ((idx >> 8) & 0xf) << 4*(i/4);
- y[ib].extra |= (idx >> 12) << i;
-
+ float v = xb[16*i16 + 15];
+ int q = fabsf(v) < 1e-6f ? 1 : v < 0 ? 0 : 2;
+ v13 += k_nb[i16]*q;
}
+ y[ib].extra = (256*v13 + k_nb[5] - 1)/k_nb[5];
}
-
}
void IQ1BNQuantizer::quantize_one_row_2bn(const float * src, block_iq2_bn * y, int n_per_row, const float * imatrix) {
@@ -194,18 +197,23 @@ void dequantize_row_iq1_bn(const block_iq1_bn * x, float * y, int64_t k) {
assert(k%QK_IQ1BN == 0);
int nblock = k / QK_IQ1BN;
- static const int k_mult[8] = {17496, 5832, 1944, 648, 216, 72, 24, 8};
+ static const uint8_t k_mult[5] = {81, 27, 9, 3, 1};
for (int i = 0; i < nblock; ++i) {
uint8_t extra = x[i].extra;
- auto qh = x[i].qh;
auto ql = x[i].ql;
- for (int k = 0; k < QK_IQ1BN/8; ++k) {
- uint16_t idx = ql[k] | ((qh[k%4] << (8 - 4*(k/4))) & 0x0f00) | ((extra << (12 - k)) & 4096);
- for (int j = 0; j < 8; ++j) {
- int v = (idx*k_mult[j] & 0xffff)*3 >> 16;
- *y++ = v - 1;
+ for (int i16 = 0; i16 < QK_IQ1BN/16; ++i16) {
+ for (int k = 0; k < 3; ++k) {
+ for (int j = 0; j < 5; ++j) {
+ uint8_t v = ql[k]*k_mult[j];
+ int8_t vs = ((v + (v >> 1)) >> 7);
+ *y++ = vs - 1;
+ }
}
+ ql += 3;
+ uint8_t v = extra*k_mult[i16];
+ int8_t vs = ((v + (v >> 1)) >> 7);
+ *y++ = vs - 1;
}
}
}
@@ -260,42 +268,44 @@ void ggml_vec_dot_iq1_bn_q8_K64(int n, float * s, size_t bs, const void * vx, si
return;
}
- constexpr uint16_t k_magic = 0xaaaa;
-
- const block_iq1_bn * x = (const block_iq1_bn *)vx;
-
- const float * d8 = (const float *)vy;
- const int8_t * q8 = (const int8_t *)(d8 + 4);
- int nblock = n / QK_IQ1BN;
-
- int sumi[8] = {};
- uint32_t aux32[2];
- const int8_t * aux8 = (const int8_t *)aux32;
-
- for (int i = 0; i < nblock; ++i) {
- auto qh = x[i].qh;
- auto ql = x[i].ql;
- auto extra = x[i].extra;
- for (int j = 0; j < QK_IQ1BN/16; ++j) {
- uint16_t idx1 = ql[2*j+0] | ((qh[j] << 8) & 0x0f00);
- uint16_t idx2 = ql[2*j+1] | ((qh[j] << 4) & 0x0f00);
- uint16_t val1 = extra & 1 ? k_magic - iq1bn_grid_u16[idx1] : iq1bn_grid_u16[idx1];
- uint16_t val2 = extra & 2 ? k_magic - iq1bn_grid_u16[idx2] : iq1bn_grid_u16[idx2];
- extra >>= 2;
- aux32[0] = val1 | (val1 << 14);
- aux32[1] = (aux32[0] >> 4) & 0x03030303;
- aux32[0] &= 0x03030303;
- for (int k = 0; k < 8; ++k) sumi[k] += q8[k] * (aux8[k] - 1);
- q8 += 8;
- aux32[0] = val2 | (val2 << 14);
- aux32[1] = (aux32[0] >> 4) & 0x03030303;
- aux32[0] &= 0x03030303;
- for (int k = 0; k < 8; ++k) sumi[k] += q8[k] * (aux8[k] - 1);
- q8 += 8;
- }
- }
+ // TODO
+
+ //constexpr uint16_t k_magic = 0xaaaa;
+
+ //const block_iq1_bn * x = (const block_iq1_bn *)vx;
+
+ //const float * d8 = (const float *)vy;
+ //const int8_t * q8 = (const int8_t *)(d8 + 4);
+ //int nblock = n / QK_IQ1BN;
+
+ //int sumi[8] = {};
+ //uint32_t aux32[2];
+ //const int8_t * aux8 = (const int8_t *)aux32;
+
+ //for (int i = 0; i < nblock; ++i) {
+ // auto qh = x[i].qh;
+ // auto ql = x[i].ql;
+ // auto extra = x[i].extra;
+ // for (int j = 0; j < QK_IQ1BN/16; ++j) {
+ // uint16_t idx1 = ql[2*j+0] | ((qh[j] << 8) & 0x0f00);
+ // uint16_t idx2 = ql[2*j+1] | ((qh[j] << 4) & 0x0f00);
+ // uint16_t val1 = extra & 1 ? k_magic - iq1bn_grid_u16[idx1] : iq1bn_grid_u16[idx1];
+ // uint16_t val2 = extra & 2 ? k_magic - iq1bn_grid_u16[idx2] : iq1bn_grid_u16[idx2];
+ // extra >>= 2;
+ // aux32[0] = val1 | (val1 << 14);
+ // aux32[1] = (aux32[0] >> 4) & 0x03030303;
+ // aux32[0] &= 0x03030303;
+ // for (int k = 0; k < 8; ++k) sumi[k] += q8[k] * (aux8[k] - 1);
+ // q8 += 8;
+ // aux32[0] = val2 | (val2 << 14);
+ // aux32[1] = (aux32[0] >> 4) & 0x03030303;
+ // aux32[0] &= 0x03030303;
+ // for (int k = 0; k < 8; ++k) sumi[k] += q8[k] * (aux8[k] - 1);
+ // q8 += 8;
+ // }
+ //}
- *s = d8[0] * (sumi[0] + sumi[4]) + d8[1] * (sumi[1] + sumi[5]) + d8[2] * (sumi[2] + sumi[6]) + d8[3] * (sumi[3] + sumi[7]);
+ //*s = d8[0] * (sumi[0] + sumi[4]) + d8[1] * (sumi[1] + sumi[5]) + d8[2] * (sumi[2] + sumi[6]) + d8[3] * (sumi[3] + sumi[7]);
}
void ggml_vec_dot_iq2_bn_q8_K64(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) {
diff --git a/iqk_mul_mat.cpp b/iqk_mul_mat.cpp
index d9aa074e..4d34f17b 100644
--- a/iqk_mul_mat.cpp
+++ b/iqk_mul_mat.cpp
@@ -1342,44 +1342,31 @@ template <int nrc> struct Q8_K64 {
struct DequantizerIQ1BN {
const __m256i m1_8 = _mm256_set1_epi8(1);
-#ifdef HAVE_FANCY_SIMD
- const __m128i shifthh = _mm_set_epi16(5, 6, 7, 8, 9, 10, 11, 12);
-#else
- const __m128i mulhh = _mm_set_epi16(32, 64, 128, 256, 512, 1024, 2048, 4096);
-#endif
- const __m128i maskhh = _mm_set1_epi16(4096);
- const __m256i shuffles[4] = {
- _mm256_set_epi64x(0x0302030203020302, 0x0302030203020302, 0x0100010001000100, 0x0100010001000100),
- _mm256_set_epi64x(0x0706070607060706, 0x0706070607060706, 0x0504050405040504, 0x0504050405040504),
- _mm256_set_epi64x(0x0b0a0b0a0b0a0b0a, 0x0b0a0b0a0b0a0b0a, 0x0908090809080908, 0x0908090809080908),
- _mm256_set_epi64x(0x0f0e0f0e0f0e0f0e, 0x0f0e0f0e0f0e0f0e, 0x0d0c0d0c0d0c0d0c, 0x0d0c0d0c0d0c0d0c),
+ static __m128i load_shuffle(int i) {
+ static const uint8_t data[64] = {0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 12,
+ 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 12,
+ 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 8, 8, 8, 8, 8, 12,
+ 9, 9, 9, 9, 9, 10, 10, 10, 10, 10, 11, 11, 11, 11, 11, 12};
+ return _mm_loadu_si128((const __m128i*)data + i);
+ }
+ const __m128i shuff[4] = { load_shuffle(0), load_shuffle(1), load_shuffle(2), load_shuffle(3) };
+ const __m256i mult[4] = {
+ _mm256_set_epi64x(0x5100010003000900, 0x1b00510001000300, 0x09001b0051000100, 0x030009001b005100),
+ _mm256_set_epi64x(0x1b00010003000900, 0x1b00510001000300, 0x09001b0051000100, 0x030009001b005100),
+ _mm256_set_epi64x(0x0900010003000900, 0x1b00510001000300, 0x09001b0051000100, 0x030009001b005100),
+ _mm256_set_epi64x(0x0300010003000900, 0x1b00510001000300, 0x09001b0051000100, 0x030009001b005100),
};
- const __m256i mult = _mm256_set_epi16(8, 24, 72, 216, 648, 1944, 5832, 17496, 8, 24, 72, 216, 648, 1944, 5832, 17496);
const __m256i m3 = _mm256_set1_epi16(3);
- const __m128i shuff_l = _mm_set_epi8(-128, 8, -128, 7, -128, 6, -128, 5, -128, 4, -128, 3, -128, 2, -128, 1);
- const __m128i shuff_h = _mm_set_epi8(12, -128, 11, -128, 10, -128, 9, -128, 12, -128, 11, -128, 10, -128, 9, -128);
- const __m128i shift_h = _mm_set_epi32(4, 4, 0, 0);
- const __m128i mask_h = _mm_set1_epi16(0x0f00);
- const __m128i shuff_hh = _mm_set_epi8(-128, 0, -128, 0, -128, 0, -128, 0, -128, 0, -128, 0, -128, 0, -128, 0);
#ifdef HAVE_FANCY_SIMD
const __m256i bmask = _mm256_set_epi8(62, 60, 58, 56, 54, 52, 50, 48, 46, 44, 42, 40, 38, 36, 34, 32, 30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0);
#endif
- IQK_ALWAYS_INLINE void prepare_iq1bn_quants(const block_iq1_bn * x, __m256i& v1, __m256i& v2) {
+ IQK_ALWAYS_INLINE void prepare_iq1bn_quants(const block_iq1_bn * x, __m256i& v1, __m256i& v2) const {
auto data = _mm_loadu_si128((const __m128i *)x); // Note: we load 16 instead of 13 bytes!
- auto aux1 = _mm_shuffle_epi8(data, shuff_l);
- auto aux2 = _mm_and_si128(_mm_srlv_epi32(_mm_shuffle_epi8(data, shuff_h), shift_h), mask_h);
-#ifdef HAVE_FANCY_SIMD
- auto aux3 = _mm_and_si128(_mm_sllv_epi16(_mm_shuffle_epi8(data, shuff_hh), shifthh), maskhh);
-#else
- auto aux3 = _mm_and_si128(_mm_mullo_epi16(_mm_shuffle_epi8(data, shuff_hh), mulhh), maskhh);
-#endif
- auto all128 = _mm_or_si128(_mm_or_si128(aux1, aux2), aux3);
- auto all = MM256_SET_M128I(all128, all128);
- auto val1 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_shuffle_epi8(all, shuffles[0]), mult), m3);
- auto val2 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_shuffle_epi8(all, shuffles[1]), mult), m3);
- auto val3 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_shuffle_epi8(all, shuffles[2]), mult), m3);
- auto val4 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_shuffle_epi8(all, shuffles[3]), mult), m3);
+ auto val1 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_cvtepu8_epi16(_mm_shuffle_epi8(data, shuff[0])), mult[0]), m3);
+ auto val2 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_cvtepu8_epi16(_mm_shuffle_epi8(data, shuff[1])), mult[1]), m3);
+ auto val3 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_cvtepu8_epi16(_mm_shuffle_epi8(data, shuff[2])), mult[2]), m3);
+ auto val4 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_cvtepu8_epi16(_mm_shuffle_epi8(data, shuff[3])), mult[3]), m3);
#ifdef HAVE_FANCY_SIMD
v1 = _mm256_sub_epi8(_mm256_permutex2var_epi8(val1, bmask, val2), m1_8);
v2 = _mm256_sub_epi8(_mm256_permutex2var_epi8(val3, bmask, val4), m1_8);
@@ -1389,21 +1376,6 @@ struct DequantizerIQ1BN {
#endif
}
- //IQK_ALWAYS_INLINE void prepare_iq1bn_quants(uint8_t extra, const uint8_t * ql, const uint8_t * qh, __m256i& v1, __m256i& v2) {
-
- // auto aux1 = _mm_cvtepu8_epi16(_mm_loadl_epi64((const __m128i *)ql));
- // uint32_t aux32; std::memcpy(&aux32, qh, 4);
- // auto aux2 = _mm_cvtepu8_epi16(_mm_and_si128(_mm_set_epi32(aux32, aux32, aux32, aux32 << 4), mask1));
- // auto aux3 = _mm_and_si128(_mm_mullo_epi16(_mm_set1_epi16(extra), mulhh), maskhh);
- // auto all128 = _mm_or_si128(_mm_slli_epi16(aux2, 4), _mm_or_si128(aux1, aux3));
- // auto all = MM256_SET_M128I(all128, all128);
- // auto val1 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_shuffle_epi8(all, shuffles[0]), mult), m3);
- // auto val2 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_shuffle_epi8(all, shuffles[1]), mult), m3);
- // auto val3 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_shuffle_epi8(all, shuffles[2]), mult), m3);
- // auto val4 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_shuffle_epi8(all, shuffles[3]), mult), m3);
- // v1 = _mm256_sub_epi8(_mm256_permute4x64_epi64(_mm256_packs_epi16(val1, val2), 216), m1_8);
- // v2 = _mm256_sub_epi8(_mm256_permute4x64_epi64(_mm256_packs_epi16(val3, val4), 216), m1_8);
- //}
};
template <int nrc_y>
@@ -1466,9 +1438,9 @@ IQK_NOINLINE void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const
accd[iy], deq.m1_8, dot1), deq.m1_8, dot2), deq.m1_8, dot3), deq.m1_8, dot4);
#else
auto dot1 = _mm256_add_epi16(_mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(iy, i, 0), val[0])),
- _mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(iy, i, 1), val[1])));
+ _mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(iy, i, 1), val[1])));
auto dot2 = _mm256_add_epi16(_mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(iy, i, 2), val[2])),
- _mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(iy, i, 3), val[3])));
+ _mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(iy, i, 3), val[3])));
dot1 = _mm256_madd_epi16(m1_16, _mm256_add_epi16(dot1, dot2));
accd[iy] = _mm256_add_epi32(dot1, accd[iy]);
#endif
@@ -4376,73 +4348,29 @@ static const uint64_t kall_signs[257] = {
struct DequantizerIQ1BN {
const uint8x16_t m1 = vdupq_n_u8(1);
- static inline uint8x16_t load_shuffle_l() {
- static const uint8_t data[16] = {1, 255, 2, 255, 3, 255, 4, 255, 5, 255, 6, 255, 7, 255, 8, 255};
- return vld1q_u8(data);
- }
- static inline uint8x16_t load_shuffle_h() {
- static const uint8_t data[16] = {9, 255, 10, 255, 11, 255, 12, 255, 9, 255, 10, 255, 11, 255, 12, 255};
- return vld1q_u8(data);
- }
- static inline uint8x16_t load_shuffle_hh() {
- static const uint8_t data[16] = {0, 255, 0, 255, 0, 255, 0, 255, 0, 255, 0, 255, 0, 255, 0, 255};
- return vld1q_u8(data);
- }
- static inline int16x8_t load_shift_hh() {
- static const int16_t data[8] = {12, 11, 10, 9, 8, 7, 6, 5};
- return vld1q_s16(data);
- }
- static inline uint16x8_t load_mult() {
- //static const uint16_t data[8] = {2187, 729, 243, 81, 27, 9, 3, 1};
- static const uint16_t data[8] = {2187*8, 729*8, 243*8, 81*8, 27*8, 9*8, 3*8, 1*8};
- return vld1q_u16(data);
- }
- //static inline uint8x16x4_t load_shuffles(uint16_t s0) {
- // uint8x16x4_t r;
- // auto step = vdupq_n_u8(4);
- // r.val[0] = vreinterpretq_u8_u16(vdupq_n_u16(s0));
- // r.val[1] = vaddq_u8(r.val[0], step);
- // r.val[2] = vaddq_u8(r.val[1], step);
- // r.val[3] = vaddq_u8(r.val[2], step);
- // return r;
- //}
-
- const uint8x16_t shuff_l = load_shuffle_l();
- const uint8x16_t shuff_h = load_shuffle_h();
- const int32x4_t shift_h = {8, 8, 4, 4};
- const uint16x8_t mask_h = vdupq_n_u16(0x0f00);
- const uint8x16_t shuff_hh = load_shuffle_hh();
- const uint16x8_t mask_hh = vdupq_n_u16(4096);
- const int16x8_t shift_hh = load_shift_hh();
- const uint16x8_t mult = load_mult();
- const uint8x16_t step = vdupq_n_u8(2);
- const uint8x16_t shuff0 = vreinterpretq_u8_u16(vdupq_n_u16(0x0100));
- //const uint8x16x4_t shuff1 = load_shuffles(0x0100);
- //const uint8x16x4_t shuff2 = load_shuffles(0x0302);
- //const uint16x8_t mask = vdupq_n_u16(0x1fff);
- //const uint16x8_t m3 = vdupq_n_u16(3);
+ static inline uint8x16x4_t load_shuffles() {
+ static const uint8_t data[64] = {0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 12,
+ 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 12,
+ 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 8, 8, 8, 8, 8, 12,
+ 9, 9, 9, 9, 9, 10, 10, 10, 10, 10, 11, 11, 11, 11, 11, 12};
+ return vld1q_u8_x4(data);
+ }
+ static inline uint8x16x4_t load_mult() {
+ static const uint8_t data[64] = {81, 27, 9, 3, 1, 81, 27, 9, 3, 1, 81, 27, 9, 3, 1, 81,
+ 81, 27, 9, 3, 1, 81, 27, 9, 3, 1, 81, 27, 9, 3, 1, 27,
+ 81, 27, 9, 3, 1, 81, 27, 9, 3, 1, 81, 27, 9, 3, 1, 9,
+ 81, 27, 9, 3, 1, 81, 27, 9, 3, 1, 81, 27, 9, 3, 1, 3};
+ return vld1q_u8_x4(data);
+ }
+ const uint8x16x4_t shuff = load_shuffles();
+ const uint8x16x4_t mult = load_mult();
IQK_ALWAYS_INLINE void prepare_iq1bn_quants(const block_iq1_bn * x, int8x16x4_t& v) const {
auto data = vld1q_u8((const uint8_t *)x);
- auto aux1 = vqtbl1q_u8(data, shuff_l);
- auto aux2 = vandq_u16(vshlq_u32(vqtbl1q_u8(data, shuff_h), shift_h), mask_h);
- auto aux3 = vandq_u16(vshlq_u16(vqtbl1q_u8(data, shuff_hh), shift_hh), mask_hh);
- auto all = vorrq_u16(vorrq_u16(aux1, aux2), aux3);
- auto shuffle = shuff0;
- //auto shuffle = vreinterpretq_u8_u16(vdupq_n_u16(0x0100));
- //auto step = vdupq_n_u8(2);
for (int k = 0; k < 4; ++k) {
- auto v1 = vreinterpretq_u16_u8(vqtbl1q_u8(all, shuffle)); shuffle = vaddq_u8(shuffle, step);
- auto v2 = vreinterpretq_u16_u8(vqtbl1q_u8(all, shuffle)); shuffle = vaddq_u8(shuffle, step);
- //auto v1 = vreinterpretq_u16_u8(vqtbl1q_u8(all, shuff1.val[k]));
- //auto v2 = vreinterpretq_u16_u8(vqtbl1q_u8(all, shuff2.val[k]));
- v1 = vmulq_u16(v1, mult);
- v2 = vmulq_u16(v2, mult);
- v1 = vshrq_n_u16(vhaddq_u16(v1, vshrq_n_u16(v1, 1)), 14);
- v2 = vshrq_n_u16(vhaddq_u16(v2, vshrq_n_u16(v2, 1)), 14);
- //v1 = vshrq_n_u16(vmulq_u16(vandq_u16(vmulq_u16(v1, mult), mask), m3), 13);
- //v2 = vshrq_n_u16(vmulq_u16(vandq_u16(vmulq_u16(v2, mult), mask), m3), 13);
- v.val[k] = vsubq_s8(vreinterpretq_s8_u8(vcombine_u8(vmovn_u16(v1), vmovn_u16(v2))), m1);
+ auto val = vmulq_u8(vqtbl1q_u8(data, shuff.val[k]), mult.val[k]);
+ val = vshrq_n_u8(vhaddq_u8(val, vshrq_n_u8(val, 1)), 6);
+ v.val[k] = vsubq_s8(vreinterpretq_s8_u8(val), m1);
}
}
};