author     Iwan Kawrakow <iwan.kawrakow@gmail.com>  2024-07-17 08:54:11 +0300
committer  Iwan Kawrakow <iwan.kawrakow@gmail.com>  2024-07-17 08:54:11 +0300
commit     873a790ee22538d1d9d7205db7210c70955ab1e1 (patch)
tree       426f0ce20cf45c325be28ae8bceb51d42c072452
parent     52a25e307c3af8686436d977c60e9975b0900e2b (diff)
iq1bn(no lookup): better version
We have 4 groups of 16 quants in a block of 64. In each group of 16,
the first 15 quants are stored as 3 groups of 5 ternary values, with
each group of 5 packed into a single byte (3^5 = 243 fits in 8 bits).
The remaining 16th quant of each of the 4 groups is packed into the
extra byte using the same encoding, so a block takes 12 + 1 = 13 bytes,
i.e., 1.625 bpw. The only kernel with complications is the CUDA
dequantize kernel, because there we dequantize 8 quants at a time, and
the 1st and 2nd group of 8 within a group of 16 are encoded differently.
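For reference, here is the decode in scalar form (a minimal sketch
assuming the layout just described; the helper name is illustrative and
not part of this commit, but the arithmetic matches the kernels in the
diff below):

    #include <stdint.h>

    // Decode one 13-byte iq1_bn block (12 ql bytes + 1 extra byte) into 64
    // values in {-1, 0, 1}. Multiplying q by 81/27/9/3/1 rotates the wanted
    // base-3 digit into the top of the byte; (v + (v >> 1)) >> 7, i.e.
    // roughly 3*v/256, then maps it to {0, 1, 2}.
    static void iq1bn_decode_block_ref(const uint8_t ql[12], uint8_t extra, int8_t out[64]) {
        static const uint8_t k_mult[5] = {81, 27, 9, 3, 1};
        for (int i16 = 0; i16 < 4; ++i16) {            // 4 groups of 16 quants
            for (int k = 0; k < 3; ++k) {              // 3 bytes -> 3 groups of 5
                uint8_t q = ql[3*i16 + k];
                for (int j = 0; j < 5; ++j) {
                    uint8_t v = k_mult[j]*q;
                    int8_t vs = (v + (v >> 1)) >> 7;   // base-3 digit in {0, 1, 2}
                    out[16*i16 + 5*k + j] = vs - 1;    // map to {-1, 0, 1}
                }
            }
            // the 16th quant of each group is one digit of the shared extra byte
            uint8_t v = k_mult[i16]*extra;
            int8_t vs = (v + (v >> 1)) >> 7;
            out[16*i16 + 15] = vs - 1;
        }
    }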
This achieves better performance on all tested platforms than any
previous 1.625 bpw attempt. We have:
| model | size | params | backend | threads | test | t/s |
| ---------------- | ---------: | ---------: | ---------- | ------: | ------------: | ---------------: |
| 1.625 bpw Bitnet | 729.64 MiB | 3.32 B | CUDA | 8 | pp512 | 9613.02 ± 24.54 |
| 1.625 bpw Bitnet | 729.64 MiB | 3.32 B | CUDA | 8 | tg128 | 229.85 ± 0.33 |
| 1.625 bpw Bitnet | 729.64 MiB | 3.32 B | AVX2 | 16 | pp512 | 322.59 ± 1.00 |
| 1.625 bpw Bitnet | 729.64 MiB | 3.32 B | AVX2 | 16 | tg128 | 59.79 ± 0.03 |
| 1.625 bpw Bitnet | 729.64 MiB | 3.32 B | AVX2 | 8 | tg128 | 57.62 ± 0.21 |
| 1.625 bpw Bitnet | 729.64 MiB | 3.32 B | AVX2 | 4 | tg128 | 33.66 ± 0.29 |
| 1.625 bpw Bitnet | 729.64 MiB | 3.32 B | AVX2 | 2 | tg128 | 18.30 ± 0.01 |
| 1.625 bpw Bitnet | 729.64 MiB | 3.32 B | Metal | 8 | pp512 | 698.13 ± 0.21 |
| 1.625 bpw Bitnet | 729.64 MiB | 3.32 B | Metal | 8 | tg128 | 68.88 ± 0.24 |
| 1.625 bpw Bitnet | 729.64 MiB | 3.32 B | NEON | 8 | pp512 | 196.80 ± 0.50 |
| 1.625 bpw Bitnet | 729.64 MiB | 3.32 B | NEON | 8 | tg128 | 51.58 ± 0.41 |
| 1.625 bpw Bitnet | 729.64 MiB | 3.32 B | NEON | 4 | tg128 | 30.80 ± 0.03 |
| 1.625 bpw Bitnet | 729.64 MiB | 3.32 B | NEON | 2 | tg128 | 16.89 ± 0.01 |
It is still slower than 2 bpw Bitnet, but the difference now is not as
dramatic.
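The packing side scales the base-3 index of each group of 5 to 8 bits
with rounding up, which is what lets the multiply-and-shift decode above
recover each digit exactly. A minimal sketch mirroring
quantize_one_row_1bn in the diff below (the helper name is again
illustrative):

    #include <stdint.h>

    // Pack 5 ternary digits t[j] in {0, 1, 2} into one byte.
    // idx is at most 3^5 - 1 = 242, so ceil(256*idx/243) still fits in 8 bits.
    static uint8_t iq1bn_pack5_ref(const int t[5]) {
        int idx = t[0] + 3*t[1] + 9*t[2] + 27*t[3] + 81*t[4];
        return (uint8_t)((256*idx + 242)/243);  // ceil(256*idx/243)
    }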
 ggml-common.h         |   5
 ggml-cuda/convert.cu  |  26
 ggml-cuda/vecdotq.cuh |  76
 ggml-metal.metal      | 117
 iqk-quantize.cpp      | 120
 iqk_mul_mat.cpp       | 150
6 files changed, 242 insertions(+), 252 deletions(-)
diff --git a/ggml-common.h b/ggml-common.h
index bf95da2a..f515e95c 100644
--- a/ggml-common.h
+++ b/ggml-common.h
@@ -380,11 +380,10 @@ static_assert(sizeof(block_iq1_m) == QK_K/8 + QK_K/16 + QK_K/32, "wrong iq1_m bl
 //
 #define QK_IQ1BN 64
 typedef struct {
+    uint8_t ql[12];
     uint8_t extra;
-    uint8_t ql[QK_IQ1BN/8];
-    uint8_t qh[QK_IQ1BN/16];
 } block_iq1_bn;
-static_assert(sizeof(block_iq1_bn) == sizeof(uint8_t) + QK_IQ1BN/8 + QK_IQ1BN/16, "wrong iq1_bn block size/padding");
+static_assert(sizeof(block_iq1_bn) == 13, "wrong iq1_bn block size/padding");
 //
 // Bitnet - implemented as 2.25 bpw
 //
diff --git a/ggml-cuda/convert.cu b/ggml-cuda/convert.cu
index a0a826e9..4a67c498 100644
--- a/ggml-cuda/convert.cu
+++ b/ggml-cuda/convert.cu
@@ -425,7 +425,10 @@ static __global__ void dequantize_block_iq1_bn(const void * __restrict__ vx, dst
     const int64_t ii  = blockIdx.x;
     const block_iq1_bn * x = (const block_iq1_bn *) vx;
 
-    static const uint16_t k_mult[8] = {2187, 729, 243, 81, 27, 9, 3, 1};
+    static const uint8_t k_mult[5] = {81, 27, 9, 3, 1};
+
+//#define COMPUTE_VS(v) 3*v >> 8
+#define COMPUTE_VS(v) (v + (v >> 1)) >> 7
 
     const int tid = threadIdx.x;
     const int il = tid/4; // 0...7
@@ -433,11 +436,24 @@ static __global__ void dequantize_block_iq1_bn(const void * __restrict__ vx, dst
     dst_t * y = yy + ii*QK_K + 64*ib + 8*il;
     int64_t i = QK_K/QK_IQ1BN * ii + ib;
     if (i >= nb64) return;
-    uint16_t val = x[i].ql[il] | ((x[i].qh[il%4] << (8 - 4*(il/4))) & 0x0f00) | ((x[i].extra << (12 - il)) & 4096);
-    for (int j = 0; j < 8; ++j) {
-        uint16_t v = (val*k_mult[j] & 0x1fff)*3 >> 13;
-        y[j] = v - 1;
+    const int i16 = il/2;
+    uint8_t q = x[i].ql[3*i16+2*(il%2)];
+    for (int j = 0; j < 5; ++j) {
+        uint8_t v = k_mult[j]*q;
+        int8_t vs = COMPUTE_VS(v);
+        y[2*(il%2)+j] = vs - 1;
     }
+    q = x[i].ql[3*i16+1];
+    for (int j = 0; j < 2; ++j) {
+        uint8_t v = k_mult[3*(il%2)+j]*q;
+        int8_t vs = COMPUTE_VS(v);
+        y[5*(1-(il%2))+j] = vs-1;
+    }
+    uint8_t v = (il%2) ? k_mult[i16]*x[i].extra : k_mult[2]*q;
+    int8_t vs = COMPUTE_VS(v);
+    y[7] = vs - 1;
+
+#undef COMPUTE_VS
 }
 
 template<typename dst_t>
diff --git a/ggml-cuda/vecdotq.cuh b/ggml-cuda/vecdotq.cuh
index d0d2e923..f0133e07 100644
--- a/ggml-cuda/vecdotq.cuh
+++ b/ggml-cuda/vecdotq.cuh
@@ -1078,67 +1078,47 @@ static __device__ __forceinline__ float vec_dot_iq1_bn_q8_1(
     const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
     const block_iq1_bn * bq1 = (const block_iq1_bn *) vbq + kbx;
 
-    static const int16_t k_mult[8] = {2187, 729, 243, 81, 27, 9, 3, 1};
+    static const uint8_t k_mult[5] = {81, 27, 9, 3, 1};
 
     // iqs is 0 or 1
 
-    uint8_t extra = bq1->extra >> 4*iqs;
     int sumi = 0;
 #if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
     const int * q8 = (const int *)bq8_1[iqs].qs;
-    //int v[2];
-    //int8_t * a = (int8_t *)v;
-    //for (int l = 0; l < 2; ++l) {
-    //    int16_t val1 = bq1->ql[4*iqs + 2*l+0] | ((bq1->qh[2*l+0] << (8-4*iqs)) & 0x0f00) | ((extra << 12) & 4096);
-    //    int16_t val2 = bq1->ql[4*iqs + 2*l+1] | ((bq1->qh[2*l+1] << (8-4*iqs)) & 0x0f00) | ((extra << 11) & 4096);
-    //    for (int k = 0; k < 8; ++k) a[k] = ((val1*k_mult[k] & 0x1fff)*3 >> 13) - 1;
-    //    sumi = __dp4a(v[0], q8[4*l+0], __dp4a(v[1], q8[4*l+1], sumi));
-    //    for (int k = 0; k < 8; ++k) a[k] = ((val2*k_mult[k] & 0x1fff)*3 >> 13) - 1;
-    //    sumi = __dp4a(v[0], q8[4*l+2], __dp4a(v[1], q8[4*l+3], sumi));
-    //    extra >>= 2;
-    //}
-
-    int v[4];
-    int8_t * a = (int8_t *)v;
+    int val[4];
     for (int l = 0; l < 2; ++l) {
-        int16_t val1 = bq1->ql[4*iqs + 2*l+0] | ((bq1->qh[2*l+0] << (8-4*iqs)) & 0x0f00) | ((extra << 12) & 4096);
-        int16_t val2 = bq1->ql[4*iqs + 2*l+1] | ((bq1->qh[2*l+1] << (8-4*iqs)) & 0x0f00) | ((extra << 11) & 4096);
-        for (int k = 0; k < 8; ++k) {
-            a[k+0] = ((val1*k_mult[k] & 0x1fff)*3 >> 13) - 1;
-            a[k+8] = ((val2*k_mult[k] & 0x1fff)*3 >> 13) - 1;
+        int8_t * a = (int8_t *)val;
+        const int i16 = 2*iqs + l;
+        for (int k = 0; k < 3; ++k) {
+            uint8_t q = bq1->ql[3*i16+k];
+            for (int j = 0; j < 5; ++j) {
+                uint8_t v = k_mult[j]*q;
+                int8_t vs = 3*v >> 8; //(v + (v >> 1)) >> 7;
+                *a++ = vs-1;
+            }
         }
-        sumi = __dp4a(v[0], q8[4*l+0], __dp4a(v[1], q8[4*l+1], __dp4a(v[2], q8[4*l+2], __dp4a(v[3], q8[4*l+3], sumi))));
-        extra >>= 2;
+        uint8_t v = k_mult[i16]*bq1->extra;
+        int8_t vs = 3*v >> 8; //(v + (v >> 1)) >> 7;
+        *a++ = vs-1;
+        sumi = __dp4a(val[0], q8[4*l+0], __dp4a(val[1], q8[4*l+1], __dp4a(val[2], q8[4*l+2], __dp4a(val[3], q8[4*l+3], sumi))));
     }
-
-    //int v[8];
-    //int8_t * a = (int8_t *)v;
-    //int16_t val1 = bq1->ql[4*iqs + 0] | ((bq1->qh[0] << (8-4*iqs)) & 0x0f00) | ((extra << 12) & 4096);
-    //int16_t val2 = bq1->ql[4*iqs + 1] | ((bq1->qh[1] << (8-4*iqs)) & 0x0f00) | ((extra << 11) & 4096);
-    //int16_t val3 = bq1->ql[4*iqs + 2] | ((bq1->qh[2] << (8-4*iqs)) & 0x0f00) | ((extra << 10) & 4096);
-    //int16_t val4 = bq1->ql[4*iqs + 3] | ((bq1->qh[3] << (8-4*iqs)) & 0x0f00) | ((extra << 9) & 4096);
-    //for (int k = 0; k < 8; ++k) {
-    //    a[k+ 0] = ((val1*k_mult[k] & 0x1fff)*3 >> 13) - 1;
-    //    a[k+ 8] = ((val2*k_mult[k] & 0x1fff)*3 >> 13) - 1;
-    //    a[k+16] = ((val3*k_mult[k] & 0x1fff)*3 >> 13) - 1;
-    //    a[k+24] = ((val4*k_mult[k] & 0x1fff)*3 >> 13) - 1;
-    //}
-    //sumi = __dp4a(v[0], q8[0], __dp4a(v[1], q8[1], __dp4a(v[2], q8[2], __dp4a(v[3], q8[3], sumi))));
-    //sumi = __dp4a(v[4], q8[4], __dp4a(v[5], q8[5], __dp4a(v[6], q8[6], __dp4a(v[7], q8[7], sumi))));
#else
     const int8_t * q8 = bq8_1[iqs].qs;
     for (int l = 0; l < 2; ++l) {
-        int val1 = bq1->ql[4*iqs + 2*l+0] | ((bq1->qh[2*l+0] << (8-4*iqs)) & 0x0f00) | ((extra << 12) & 4096);
-        int val2 = bq1->ql[4*iqs + 2*l+1] | ((bq1->qh[2*l+1] << (8-4*iqs)) & 0x0f00) | ((extra << 11) & 4096);
-        for (int k = 0; k < 4; ++k) {
-            int v1 = (val1*k_mult[k+0] & 0x1fff)*3 >> 13;
-            int v2 = (val1*k_mult[k+4] & 0x1fff)*3 >> 13;
-            int v3 = (val2*k_mult[k+0] & 0x1fff)*3 >> 13;
-            int v4 = (val2*k_mult[k+4] & 0x1fff)*3 >> 13;
-            sumi += (v1 - 1)*q8[k+0] + (v2 - 1)*q8[k+4] + (v3 - 1)*q8[k+8] + (v4 - 1)*q8[k+12];
+        const int i16 = 2*iqs + l;
+        for (int k = 0; k < 3; ++k) {
+            uint8_t q = bq1->ql[3*i16+k];
+            for (int j = 0; j < 5; ++j) {
+                uint8_t v = k_mult[j]*q;
+                int8_t vs = (v + (v >> 1)) >> 7;
+                sumi += q8[j]*(vs - 1);
+            }
+            q8 += 5;
         }
-        q8 += 16;
-        extra >>= 2;
+        uint8_t v = k_mult[i16]*bq1->extra;
+        int8_t vs = (v + (v >> 1)) >> 7;
+        sumi += q8[0]*(vs - 1);
+        q8++;
     }
 #endif
     return __low2float(bq8_1[iqs].ds) * sumi;
diff --git a/ggml-metal.metal b/ggml-metal.metal
index cca4b0e9..2bf6894a 100644
--- a/ggml-metal.metal
+++ b/ggml-metal.metal
@@ -5054,6 +5054,49 @@ static inline float iq1bn_fp8_to_float(uint8_t fp8) {
     return s.f;
 }
 
+//static constant int8_t iq1bn_values[256*5] = {
+//    -1, -1, -1, -1, -1, 0, 0, 0, 0, 0, 0, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, 0, -1, -1, -1, 0, 0, -1, -1, -1, 1, 0,
+//    -1, -1, -1, -1, 1, -1, -1, -1, 0, 1, -1, -1, -1, 1, 1, -1, -1, -1, -1, -1, 0, -1, -1, 0, -1, 0, -1, -1, 1, -1, 0, -1,
+//    -1, -1, 0, 0, -1, -1, 0, 0, 0, -1, -1, 1, 0, 0, -1, -1, -1, 1, 0, -1, -1, 0, 1, 0, -1, -1, 1, 1, 0, -1, -1, -1,
+//    -1, 1, -1, -1, 0, 0, 0, 0, 0, 0, -1, 1, -1, -1, 1, -1, 1, -1, -1, -1, 0, 1, -1, -1, 0, 0, 1, -1, -1, 1, 0, 1,
+//    -1, -1, -1, 1, 1, -1, -1, 0, 1, 1, -1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 0, -1, 0, -1, -1, 0, -1, 1, -1, -1, 0, -1,
+//    -1, 0, -1, 0, -1, 0, 0, -1, 0, -1, 1, 0, -1, 0, -1, -1, 1, -1, 0, -1, 0, 1, -1, 0, -1, 1, 1, -1, 0, -1, -1, -1,
+//    0, 0, -1, 0, -1, 0, 0, -1, 0, 0, 0, 0, 0, 1, -1, 0, 0, -1, -1, 0, 0, 0, -1, 0, 0, 0, 0, -1, 1, 0, 0, 0,
+//    -1, -1, 1, 0, 0, -1, 0, 1, 0, 0, -1, 1, 1, 0, 0, -1, -1, -1, 1, 0, -1, 0, -1, 1, 0, -1, 1, -1, 1, 0, -1, -1,
+//    0, 1, 0, -1, 0, 0, 1, 0, -1, 1, 0, 1, 0, -1, -1, 1, 1, 0, -1, 0, 1, 1, 0, -1, 1, 1, 1, 0, -1, -1, -1, -1,
+//    1, -1, 0, -1, -1, 1, -1, 1, -1, -1, 1, -1, 0, 0, 0, 0, 0, -1, 0, -1, 1, -1, 0, 0, -1, 1, -1, 1, 0, -1, 1, -1,
+//    -1, 1, -1, 1, -1, 0, 1, -1, 1, -1, 1, 1, -1, 1, -1, -1, -1, 0, 1, -1, 0, -1, 0, 1, -1, 1, -1, 0, 1, -1, -1, 0,
+//    0, 1, -1, 0, 0, 0, 1, -1, 1, 0, 0, 1, -1, -1, 1, 0, 1, -1, 0, 1, 0, 1, -1, 1, 1, 0, 1, -1, -1, -1, 1, 1,
+//    -1, 0, -1, 1, 1, -1, 1, -1, 1, 1, -1, 0, 0, 0, 0, 0, -1, 0, 1, 1, -1, 0, 0, 1, 1, -1, 1, 0, 1, 1, -1, -1,
+//    1, 1, 1, -1, 0, 1, 1, 1, -1, 1, 1, 1, 1, -1, -1, -1, -1, -1, 0, 0, -1, -1, -1, 0, 1, -1, -1, -1, 0, -1, 0, -1,
+//    -1, 0, 0, 0, -1, -1, 0, 1, 0, -1, -1, 0, -1, 1, -1, -1, 0, 0, 1, -1, -1, 0, 1, 1, -1, -1, 0, -1, -1, 0, -1, 0,
+//    0, -1, 0, -1, 0, 1, -1, 0, -1, 0, -1, 0, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, 0, 1, 0, 0, -1, 0, -1, 1,
+//    0, -1, 0, 0, 1, 0, -1, 0, 1, 1, 0, -1, 0, -1, -1, 1, -1, 0, 0, -1, 1, -1, 0, 1, -1, 1, -1, 0, -1, 0, 1, -1,
+//    0, 0, 0, 1, -1, 0, 1, 0, 1, -1, 0, -1, 1, 1, -1, 0, 0, 1, 1, -1, 0, 1, 1, 1, -1, 0, -1, -1, -1, 0, 0, 0,
+//    -1, -1, 0, 0, 1, -1, -1, 0, 0, -1, 0, -1, 0, 0, 0, 0, -1, 0, 0, 0, 0, 0, 0, 0, 1, 0, -1, 0, 0, -1, 1, -1,
+//    0, 0, 0, 1, -1, 0, 0, 1, 1, -1, 0, 0, -1, -1, 0, 0, 0, 0, -1, 0, 0, 0, 1, -1, 0, 0, 0, -1, 0, 0, 0, 0,
+//    0, 0, 0, 0, 0, 1, 0, 0, 0, 0, -1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, -1, -1, 1, 0, 0, 0, -1,
+//    1, 0, 0, 1, -1, 1, 0, 0, -1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, -1, 1, 1, 0,
+//    0, 0, 1, 1, 0, 0, 1, 1, 1, 0, 0, -1, -1, -1, 1, 0, 0, -1, -1, 1, 0, 1, -1, -1, 1, 0, -1, 0, -1, 1, 0, 0,
+//    0, -1, 1, 0, 1, 0, -1, 1, 0, -1, 1, -1, 1, 0, 0, 1, -1, 1, 0, 1, 1, -1, 1, 0, -1, -1, 0, 1, 0, 0, -1, 0,
+//    1, 0, 1, -1, 0, 1, 0, -1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, -1, 1, 0, 1, 0,
+//    0, 1, 0, 1, 0, 1, 1, 0, 1, 0, -1, -1, 1, 1, 0, 0, -1, 1, 1, 0, 1, -1, 1, 1, 0, -1, 0, 1, 1, 0, 0, 0,
+//    1, 1, 0, 1, 0, 1, 1, 0, -1, 1, 1, 1, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1, 0, -1, -1, -1, -1, 1, 0, -1, -1, -1,
+//    1, 1, -1, -1, -1, 1, -1, 0, -1, -1, 1, 0, 0, -1, -1, 1, 1, 0, -1, -1, 1, -1, 1, -1, -1, 1, 0, 0, 0, 0, 0, 0,
+//    1, -1, -1, 1, 1, 1, -1, -1, 1, -1, -1, 0, -1, 1, 0, -1, 0, -1, 1, 1, -1, 0, -1, 1, -1, 0, 0, -1, 1, 0, 0, 0,
+//    -1, 1, 1, 0, 0, -1, 1, -1, 1, 0, -1, 1, 0, 1, 0, -1, 1, 1, 1, 0, -1, 1, -1, -1, 1, -1, 1, 0, -1, 1, -1, 1,
+//    1, -1, 1, -1, 1, -1, 0, 1, -1, 1, 0, 0, 1, -1, 1, 1, 0, 1, -1, 1, -1, 1, 1, -1, 1, 0, 0, 0, 0, 0, 0, 1,
+//    1, -1, 1, 1, 1, 1, -1, 1, -1, -1, -1, 0, 1, 0, -1, -1, 0, 1, 1, -1, -1, 0, 1, -1, 0, -1, 0, 1, 0, 0, -1, 0,
+//    1, 1, 0, -1, 0, 1, -1, 1, -1, 0, 1, 0, 1, -1, 0, 1, 1, 1, -1, 0, 1, -1, -1, 0, 0, 1, 0, -1, 0, 0, 1, 1,
+//    -1, 0, 0, 1, -1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 1, -1, 1, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0,
+//    0, 0, 1, 1, 0, 0, 1, -1, -1, 1, 0, 1, 0, -1, 1, 0, 1, 1, -1, 1, 0, 1, -1, 0, 1, 0, 1, 0, 0, 1, 0, 1,
+//    1, 0, 1, 0, 1, -1, 1, 1, 0, 1, 0, 1, 1, 0, 1, 1, 1, 1, 0, 1, -1, -1, -1, 1, 1, 0, -1, -1, 1, 1, 1, -1,
+//    -1, 1, 1, -1, 0, -1, 1, 1, 0, 0, -1, 1, 1, 1, 0, -1, 1, 1, -1, 1, -1, 1, 1, 0, 1, -1, 1, 1, 1, 1, -1, 1,
+//    1, 0, 0, 0, 0, 0, -1, -1, 0, 1, 1, 0, -1, 0, 1, 1, 1, -1, 0, 1, 1, -1, 0, 0, 1, 1, 0, 0, 0, 1, 1, 1,
+//    0, 0, 1, 1, -1, 1, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1, 0, 1, 1, -1, -1, 1, 1, 1, 0, -1, 1, 1, 1, 1, -1, 1,
+//    1, 1, -1, 0, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 1, 1, 1, -1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1,
+//};
+
 void kernel_mul_mv_iq1_bn_f32_impl(
         device const void * src0,
         device const float * src1,
@@ -5087,53 +5130,62 @@ void kernel_mul_mv_iq1_bn_f32_impl(
     device const block_iq1_bn * x = (device const block_iq1_bn *) src0 + ib_row + offset0;
     device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
 
-    float4 yl[2];
-    float sumf[N_DST]={0.f}, all_sum;
+    float yl[16];
+    float sumf[N_DST]={0.f};
 
     const int nb32 = nb * (QK_IQ1BN / 32);
 
-    const int ix = tiisg/4;
-    const int ir = tiisg%4;
+    const int ix = tiisg/2;
+    const int ir = tiisg%2;
 
-    device const float4 * y4 = (device const float4 *)y + 8 * ix + 2 * ir;
+    device const float * y4 = (device const float *)y + 32 * ix + 16 * ir;
 
     uint32_t aux32[2];
     thread const uint8_t * aux8 = (thread const uint8_t *)aux32;
 
     const float values[3] = {-1.f, 0.f, 1.f};
 
-    constexpr int16_t k_mult[8] = {2187, 729, 243, 81, 27, 9, 3, 1};
+    constexpr uint8_t k_mult[5] = {81, 27, 9, 3, 1};
 
-    for (int ib32 = ix; ib32 < nb32; ib32 += 8) {
+    for (int ib32 = ix; ib32 < nb32; ib32 += 16) {
 
-        yl[0] = y4[0]; yl[1] = y4[1];
+        for (int j = 0; j < 16; ++j) yl[j] = y4[j];
 
         const int ibl = ib32 / (QK_IQ1BN / 32);
         const int ib  = ib32 % (QK_IQ1BN / 32);
-        const int il  = 4*ib + ir;
+        const int i16 = 2*ib + ir;
 
         device const block_iq1_bn * xr = x + ibl;
+        device const uint8_t * ql = xr->ql + 3*i16;
         device const uint8_t * extra = (device const uint8_t *)&xr->extra;
-        device const uint8_t * ql = xr->ql + il;
-        device const uint8_t * qh = xr->qh + il%4;
 
         for (int row = 0; row < N_DST; row++) {
 
-            uint8_t h = extra[0] >> il;
+            float acc = 0;
+            int i = 0;
+            for (int k = 0; k < 3; ++k) {
+                //constant int8_t * vs = iq1bn_values + 5*ql[k];
+                //for (int j = 0; j < 5; ++j) acc += yl[5*k+j]*vs[j];
+                uint8_t q = ql[k];
+                for (int j = 0; j < 5; ++j) {
+                    uint8_t v = k_mult[j]*q;
+                    v = 3*v >> 8; //(v + (v >> 1)) >> 7;
+                    acc += yl[i++] * values[v];
+                }
+            }
+            //constant int8_t * vs = iq1bn_values + 5*extra[0];
+            //acc += yl[15] * vs[i16];
+            uint8_t v = k_mult[i16]*extra[0];
+            v = 3*v >> 8; //(v + (v >> 1)) >> 7;
+            acc += yl[15] * values[v];
 
-            int16_t val = ql[0] | ((qh[0] << (8 - 4*(il/4))) & 0x0f00) | ((extra[0] << (12 - il)) & 4096);
-            float4 acc4 = yl[0] * float4{values[(val*k_mult[0] & 0x1fff)*3 >> 13], values[(val*k_mult[1] & 0x1fff)*3 >> 13],
-                                         values[(val*k_mult[2] & 0x1fff)*3 >> 13], values[(val*k_mult[3] & 0x1fff)*3 >> 13]}
-                        + yl[1] * float4{values[(val*k_mult[4] & 0x1fff)*3 >> 13], values[(val*k_mult[5] & 0x1fff)*3 >> 13],
-                                         values[(val*k_mult[6] & 0x1fff)*3 >> 13], values[(val*k_mult[7] & 0x1fff)*3 >> 13]};
-            sumf[row] += acc4[0] + acc4[1] + acc4[2] + acc4[3];
+            sumf[row] += acc;
 
             extra += nb*sizeof(block_iq1_bn);
             ql    += nb*sizeof(block_iq1_bn);
-            qh    += nb*sizeof(block_iq1_bn);
         }
 
-        y4 += 32 * 2;
+        y4 += 32 * 16;
     }
 
     for (int row = 0; row < N_DST; row += 2) {
@@ -5990,18 +6042,23 @@ void dequantize_iq1_m(device const block_iq1_m * xb, short il, thread type4x4 &
 
 template <typename type4x4>
 void dequantize_iq1_bn(device const block_iq1_bn * xb, short il, thread type4x4 & reg) {
     // il is in 0...3
-    uint8_t gs = xb->extra >> 2*il;
-
-    constexpr int16_t k_mult[8] = {2187, 729, 243, 81, 27, 9, 3, 1};
-
-    short il1 = 2*il+0, il2 = 2*il+1;
-    int16_t v1 = xb->ql[il1] | ((xb->qh[il1%4] << (8 - 4*(il1/4))) & 0x0f00) | ((gs << 12) & 4096);
-    int16_t v2 = xb->ql[il2] | ((xb->qh[il2%4] << (8 - 4*(il2/4))) & 0x0f00) | ((gs << 11) & 4096);
-    for (int i = 0; i < 8; ++i) {
-        reg[i/4+0][i%4] = ((v1*k_mult[i] & 0x1fff)*3 >> 13) - 1;
-        reg[i/4+2][i%4] = ((v2*k_mult[i] & 0x1fff)*3 >> 13) - 1;
+    constexpr uint8_t k_mult[5] = {81, 27, 9, 3, 1};
+
+    int i = 0;
+    for (int k = 0; k < 3; ++k) {
+        uint8_t q = xb->ql[3*il + k];
+        for (int j = 0; j < 5; ++j) {
+            uint8_t v = k_mult[j]*q;
+            int8_t vs = 3*v >> 8;
+            //int8_t vs = (v + (v >> 1)) >> 7;
+            reg[i/4][i%4] = vs - 1;
+            ++i;
+        }
     }
+    uint8_t v = k_mult[il]*xb->extra;
+    int8_t vs = 3*v >> 8; //(v + (v >> 1)) >> 7;
+    reg[3][3] = vs - 1;
 }
 
 template <typename type4x4>
diff --git a/iqk-quantize.cpp b/iqk-quantize.cpp
index 03d73ff0..f25f2031 100644
--- a/iqk-quantize.cpp
+++ b/iqk-quantize.cpp
@@ -118,7 +118,7 @@ uint16_t IQ1BNQuantizer::quantize_one_block_1bn(const IQ1BNData& iq1bn, const fl
 void IQ1BNQuantizer::quantize_one_row_1bn(const float * src, block_iq1_bn * y, int n_per_row, const float * imatrix) {
 
-    static const int k_nb[8] = {1, 3, 9, 27, 81, 243, 729, 2187};
+    static const int k_nb[6] = {1, 3, 9, 27, 81, 243};
 
     (void)imatrix;
 
     const int nblock = n_per_row/QK_IQ1BN;
@@ -126,21 +126,24 @@ void IQ1BNQuantizer::quantize_one_row_1bn(const float * src, block_iq1_bn * y, i
     for (int ib = 0; ib < nblock; ++ib) {
         std::memset(&y[ib], 0, sizeof(block_iq1_bn));
         auto xb = src + ib*QK_IQ1BN;
-        for (int i = 0; i < QK_IQ1BN/8; ++i) {
-            int idx = 0;
-            for (int j = 0; j < 8; ++j) {
-                float v = xb[8*i + j];
-                int q = fabsf(v) < 1e-6f ? 1 : v < 0 ? 0 : 2;
-                idx += k_nb[j]*q;
+        int v13 = 0;
+        for (int i16 = 0; i16 < QK_IQ1BN/16; ++i16) {
+            for (int k = 0; k < 3; ++k) {
+                int idx = 0;
+                for (int j = 0; j < 5; ++j) {
+                    float v = xb[16*i16 + 5*k + j];
+                    int q = fabsf(v) < 1e-6f ? 1 : v < 0 ? 0 : 2;
+                    idx += k_nb[j]*q;
+                }
+                idx = (256*idx + k_nb[5] - 1)/k_nb[5];
+                y[ib].ql[3*i16 + k] = idx;
             }
-            idx = (8192*idx + 6560)/6561;
-            y[ib].ql[i] = idx & 255;
-            y[ib].qh[i%4] |= ((idx >> 8) & 0xf) << 4*(i/4);
-            y[ib].extra |= (idx >> 12) << i;
-
+            float v = xb[16*i16 + 15];
+            int q = fabsf(v) < 1e-6f ? 1 : v < 0 ? 0 : 2;
+            v13 += k_nb[i16]*q;
         }
+        y[ib].extra = (256*v13 + k_nb[5] - 1)/k_nb[5];
     }
-
 }
 
 void IQ1BNQuantizer::quantize_one_row_2bn(const float * src, block_iq2_bn * y, int n_per_row, const float * imatrix) {
@@ -194,18 +197,23 @@ void dequantize_row_iq1_bn(const block_iq1_bn * x, float * y, int64_t k) {
     assert(k%QK_IQ1BN == 0);
     int nblock = k / QK_IQ1BN;
 
-    static const int k_mult[8] = {17496, 5832, 1944, 648, 216, 72, 24, 8};
+    static const uint8_t k_mult[5] = {81, 27, 9, 3, 1};
 
     for (int i = 0; i < nblock; ++i) {
         uint8_t extra = x[i].extra;
-        auto qh = x[i].qh;
         auto ql = x[i].ql;
-        for (int k = 0; k < QK_IQ1BN/8; ++k) {
-            uint16_t idx = ql[k] | ((qh[k%4] << (8 - 4*(k/4))) & 0x0f00) | ((extra << (12 - k)) & 4096);
-            for (int j = 0; j < 8; ++j) {
-                int v = (idx*k_mult[j] & 0xffff)*3 >> 16;
-                *y++ = v - 1;
+        for (int i16 = 0; i16 < QK_IQ1BN/16; ++i16) {
+            for (int k = 0; k < 3; ++k) {
+                for (int j = 0; j < 5; ++j) {
+                    uint8_t v = ql[k]*k_mult[j];
+                    int8_t vs = ((v + (v >> 1)) >> 7);
+                    *y++ = vs - 1;
+                }
             }
+            ql += 3;
+            uint8_t v = extra*k_mult[i16];
+            int8_t vs = ((v + (v >> 1)) >> 7);
+            *y++ = vs - 1;
         }
     }
 }
@@ -260,42 +268,44 @@ void ggml_vec_dot_iq1_bn_q8_K64(int n, float * s, size_t bs, const void * vx, si
         return;
     }
 
-    constexpr uint16_t k_magic = 0xaaaa;
-
-    const block_iq1_bn * x = (const block_iq1_bn *)vx;
-
-    const float * d8 = (const float *)vy;
-    const int8_t * q8 = (const int8_t *)(d8 + 4);
-    int nblock = n / QK_IQ1BN;
-
-    int sumi[8] = {};
-    uint32_t aux32[2];
-    const int8_t * aux8 = (const int8_t *)aux32;
-
-    for (int i = 0; i < nblock; ++i) {
-        auto qh = x[i].qh;
-        auto ql = x[i].ql;
-        auto extra = x[i].extra;
-        for (int j = 0; j < QK_IQ1BN/16; ++j) {
-            uint16_t idx1 = ql[2*j+0] | ((qh[j] << 8) & 0x0f00);
-            uint16_t idx2 = ql[2*j+1] | ((qh[j] << 4) & 0x0f00);
-            uint16_t val1 = extra & 1 ? k_magic - iq1bn_grid_u16[idx1] : iq1bn_grid_u16[idx1];
-            uint16_t val2 = extra & 2 ? k_magic - iq1bn_grid_u16[idx2] : iq1bn_grid_u16[idx2];
-            extra >>= 2;
-            aux32[0] = val1 | (val1 << 14);
-            aux32[1] = (aux32[0] >> 4) & 0x03030303;
-            aux32[0] &= 0x03030303;
-            for (int k = 0; k < 8; ++k) sumi[k] += q8[k] * (aux8[k] - 1);
-            q8 += 8;
-            aux32[0] = val2 | (val2 << 14);
-            aux32[1] = (aux32[0] >> 4) & 0x03030303;
-            aux32[0] &= 0x03030303;
-            for (int k = 0; k < 8; ++k) sumi[k] += q8[k] * (aux8[k] - 1);
-            q8 += 8;
-        }
-    }
+    // TODO
+
+    //constexpr uint16_t k_magic = 0xaaaa;
+
+    //const block_iq1_bn * x = (const block_iq1_bn *)vx;
+
+    //const float * d8 = (const float *)vy;
+    //const int8_t * q8 = (const int8_t *)(d8 + 4);
+    //int nblock = n / QK_IQ1BN;
+
+    //int sumi[8] = {};
+    //uint32_t aux32[2];
+    //const int8_t * aux8 = (const int8_t *)aux32;
+
+    //for (int i = 0; i < nblock; ++i) {
+    //    auto qh = x[i].qh;
+    //    auto ql = x[i].ql;
+    //    auto extra = x[i].extra;
+    //    for (int j = 0; j < QK_IQ1BN/16; ++j) {
+    //        uint16_t idx1 = ql[2*j+0] | ((qh[j] << 8) & 0x0f00);
+    //        uint16_t idx2 = ql[2*j+1] | ((qh[j] << 4) & 0x0f00);
+    //        uint16_t val1 = extra & 1 ? k_magic - iq1bn_grid_u16[idx1] : iq1bn_grid_u16[idx1];
+    //        uint16_t val2 = extra & 2 ? k_magic - iq1bn_grid_u16[idx2] : iq1bn_grid_u16[idx2];
+    //        extra >>= 2;
+    //        aux32[0] = val1 | (val1 << 14);
+    //        aux32[1] = (aux32[0] >> 4) & 0x03030303;
+    //        aux32[0] &= 0x03030303;
+    //        for (int k = 0; k < 8; ++k) sumi[k] += q8[k] * (aux8[k] - 1);
+    //        q8 += 8;
+    //        aux32[0] = val2 | (val2 << 14);
+    //        aux32[1] = (aux32[0] >> 4) & 0x03030303;
+    //        aux32[0] &= 0x03030303;
+    //        for (int k = 0; k < 8; ++k) sumi[k] += q8[k] * (aux8[k] - 1);
+    //        q8 += 8;
+    //    }
+    //}
 
-    *s = d8[0] * (sumi[0] + sumi[4]) + d8[1] * (sumi[1] + sumi[5]) + d8[2] * (sumi[2] + sumi[6]) + d8[3] * (sumi[3] + sumi[7]);
+    //*s = d8[0] * (sumi[0] + sumi[4]) + d8[1] * (sumi[1] + sumi[5]) + d8[2] * (sumi[2] + sumi[6]) + d8[3] * (sumi[3] + sumi[7]);
 }
 
 void ggml_vec_dot_iq2_bn_q8_K64(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) {
diff --git a/iqk_mul_mat.cpp b/iqk_mul_mat.cpp
index d9aa074e..4d34f17b 100644
--- a/iqk_mul_mat.cpp
+++ b/iqk_mul_mat.cpp
@@ -1342,44 +1342,31 @@ template <int nrc> struct Q8_K64 {
 
 struct DequantizerIQ1BN {
     const __m256i m1_8 = _mm256_set1_epi8(1);
-#ifdef HAVE_FANCY_SIMD
-    const __m128i shifthh = _mm_set_epi16(5, 6, 7, 8, 9, 10, 11, 12);
-#else
-    const __m128i mulhh = _mm_set_epi16(32, 64, 128, 256, 512, 1024, 2048, 4096);
-#endif
-    const __m128i maskhh = _mm_set1_epi16(4096);
-    const __m256i shuffles[4] = {
-        _mm256_set_epi64x(0x0302030203020302, 0x0302030203020302, 0x0100010001000100, 0x0100010001000100),
-        _mm256_set_epi64x(0x0706070607060706, 0x0706070607060706, 0x0504050405040504, 0x0504050405040504),
-        _mm256_set_epi64x(0x0b0a0b0a0b0a0b0a, 0x0b0a0b0a0b0a0b0a, 0x0908090809080908, 0x0908090809080908),
-        _mm256_set_epi64x(0x0f0e0f0e0f0e0f0e, 0x0f0e0f0e0f0e0f0e, 0x0d0c0d0c0d0c0d0c, 0x0d0c0d0c0d0c0d0c),
+    static __m128i load_shuffle(int i) {
+        static const uint8_t data[64] = {0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 12,
+                                         3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 12,
+                                         6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 8, 8, 8, 8, 8, 12,
+                                         9, 9, 9, 9, 9, 10, 10, 10, 10, 10, 11, 11, 11, 11, 11, 12};
+        return _mm_loadu_si128((const __m128i*)data + i);
+    }
+    const __m128i shuff[4] = { load_shuffle(0), load_shuffle(1), load_shuffle(2), load_shuffle(3) };
+    const __m256i mult[4] = {
+        _mm256_set_epi64x(0x5100010003000900, 0x1b00510001000300, 0x09001b0051000100, 0x030009001b005100),
+        _mm256_set_epi64x(0x1b00010003000900, 0x1b00510001000300, 0x09001b0051000100, 0x030009001b005100),
+        _mm256_set_epi64x(0x0900010003000900, 0x1b00510001000300, 0x09001b0051000100, 0x030009001b005100),
+        _mm256_set_epi64x(0x0300010003000900, 0x1b00510001000300, 0x09001b0051000100, 0x030009001b005100),
     };
-    const __m256i mult = _mm256_set_epi16(8, 24, 72, 216, 648, 1944, 5832, 17496, 8, 24, 72, 216, 648, 1944, 5832, 17496);
     const __m256i m3 = _mm256_set1_epi16(3);
-    const __m128i shuff_l = _mm_set_epi8(-128, 8, -128, 7, -128, 6, -128, 5, -128, 4, -128, 3, -128, 2, -128, 1);
-    const __m128i shuff_h = _mm_set_epi8(12, -128, 11, -128, 10, -128, 9, -128, 12, -128, 11, -128, 10, -128, 9, -128);
-    const __m128i shift_h = _mm_set_epi32(4, 4, 0, 0);
-    const __m128i mask_h = _mm_set1_epi16(0x0f00);
-    const __m128i shuff_hh = _mm_set_epi8(-128, 0, -128, 0, -128, 0, -128, 0, -128, 0, -128, 0, -128, 0, -128, 0);
 #ifdef HAVE_FANCY_SIMD
     const __m256i bmask = _mm256_set_epi8(62, 60, 58, 56, 54, 52, 50, 48, 46, 44, 42, 40, 38, 36, 34, 32, 30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0);
 #endif
 
-    IQK_ALWAYS_INLINE void prepare_iq1bn_quants(const block_iq1_bn * x, __m256i& v1, __m256i& v2) {
+    IQK_ALWAYS_INLINE void prepare_iq1bn_quants(const block_iq1_bn * x, __m256i& v1, __m256i& v2) const {
         auto data = _mm_loadu_si128((const __m128i *)x); // Note: we load 16 instead of 13 bytes!
-        auto aux1 = _mm_shuffle_epi8(data, shuff_l);
-        auto aux2 = _mm_and_si128(_mm_srlv_epi32(_mm_shuffle_epi8(data, shuff_h), shift_h), mask_h);
-#ifdef HAVE_FANCY_SIMD
-        auto aux3 = _mm_and_si128(_mm_sllv_epi16(_mm_shuffle_epi8(data, shuff_hh), shifthh), maskhh);
-#else
-        auto aux3 = _mm_and_si128(_mm_mullo_epi16(_mm_shuffle_epi8(data, shuff_hh), mulhh), maskhh);
-#endif
-        auto all128 = _mm_or_si128(_mm_or_si128(aux1, aux2), aux3);
-        auto all = MM256_SET_M128I(all128, all128);
-        auto val1 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_shuffle_epi8(all, shuffles[0]), mult), m3);
-        auto val2 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_shuffle_epi8(all, shuffles[1]), mult), m3);
-        auto val3 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_shuffle_epi8(all, shuffles[2]), mult), m3);
-        auto val4 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_shuffle_epi8(all, shuffles[3]), mult), m3);
+        auto val1 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_cvtepu8_epi16(_mm_shuffle_epi8(data, shuff[0])), mult[0]), m3);
+        auto val2 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_cvtepu8_epi16(_mm_shuffle_epi8(data, shuff[1])), mult[1]), m3);
+        auto val3 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_cvtepu8_epi16(_mm_shuffle_epi8(data, shuff[2])), mult[2]), m3);
+        auto val4 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_cvtepu8_epi16(_mm_shuffle_epi8(data, shuff[3])), mult[3]), m3);
 #ifdef HAVE_FANCY_SIMD
         v1 = _mm256_sub_epi8(_mm256_permutex2var_epi8(val1, bmask, val2), m1_8);
         v2 = _mm256_sub_epi8(_mm256_permutex2var_epi8(val3, bmask, val4), m1_8);
@@ -1389,21 +1376,6 @@ struct DequantizerIQ1BN {
 #endif
     }
 
-    //IQK_ALWAYS_INLINE void prepare_iq1bn_quants(uint8_t extra, const uint8_t * ql, const uint8_t * qh, __m256i& v1, __m256i& v2) {
-
-    //    auto aux1 = _mm_cvtepu8_epi16(_mm_loadl_epi64((const __m128i *)ql));
-    //    uint32_t aux32; std::memcpy(&aux32, qh, 4);
-    //    auto aux2 = _mm_cvtepu8_epi16(_mm_and_si128(_mm_set_epi32(aux32, aux32, aux32, aux32 << 4), mask1));
-    //    auto aux3 = _mm_and_si128(_mm_mullo_epi16(_mm_set1_epi16(extra), mulhh), maskhh);
-    //    auto all128 = _mm_or_si128(_mm_slli_epi16(aux2, 4), _mm_or_si128(aux1, aux3));
-    //    auto all = MM256_SET_M128I(all128, all128);
-    //    auto val1 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_shuffle_epi8(all, shuffles[0]), mult), m3);
-    //    auto val2 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_shuffle_epi8(all, shuffles[1]), mult), m3);
-    //    auto val3 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_shuffle_epi8(all, shuffles[2]), mult), m3);
-    //    auto val4 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_shuffle_epi8(all, shuffles[3]), mult), m3);
-    //    v1 = _mm256_sub_epi8(_mm256_permute4x64_epi64(_mm256_packs_epi16(val1, val2), 216), m1_8);
-    //    v2 = _mm256_sub_epi8(_mm256_permute4x64_epi64(_mm256_packs_epi16(val3, val4), 216), m1_8);
-    //}
 };
 
 template <int nrc_y>
@@ -1466,9 +1438,9 @@ IQK_NOINLINE void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const
                         accd[iy], deq.m1_8, dot1), deq.m1_8, dot2), deq.m1_8, dot3), deq.m1_8, dot4);
 #else
                 auto dot1 = _mm256_add_epi16(_mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(iy, i, 0), val[0])),
-                        _mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(iy, i, 1), val[1])));
+                                             _mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(iy, i, 1), val[1])));
                 auto dot2 = _mm256_add_epi16(_mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(iy, i, 2), val[2])),
-                        _mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(iy, i, 3), val[3])));
+                                             _mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(iy, i, 3), val[3])));
                 dot1 = _mm256_madd_epi16(m1_16, _mm256_add_epi16(dot1, dot2));
                 accd[iy] = _mm256_add_epi32(dot1, accd[iy]);
 #endif
@@ -4376,73 +4348,29 @@ static const uint64_t kall_signs[257] = {
 
 struct DequantizerIQ1BN {
     const uint8x16_t m1 = vdupq_n_u8(1);
-    static inline uint8x16_t load_shuffle_l() {
-        static const uint8_t data[16] = {1, 255, 2, 255, 3, 255, 4, 255, 5, 255, 6, 255, 7, 255, 8, 255};
-        return vld1q_u8(data);
-    }
-    static inline uint8x16_t load_shuffle_h() {
-        static const uint8_t data[16] = {9, 255, 10, 255, 11, 255, 12, 255, 9, 255, 10, 255, 11, 255, 12, 255};
-        return vld1q_u8(data);
-    }
-    static inline uint8x16_t load_shuffle_hh() {
-        static const uint8_t data[16] = {0, 255, 0, 255, 0, 255, 0, 255, 0, 255, 0, 255, 0, 255, 0, 255};
-        return vld1q_u8(data);
-    }
-    static inline int16x8_t load_shift_hh() {
-        static const int16_t data[8] = {12, 11, 10, 9, 8, 7, 6, 5};
-        return vld1q_s16(data);
-    }
-    static inline uint16x8_t load_mult() {
-        //static const uint16_t data[8] = {2187, 729, 243, 81, 27, 9, 3, 1};
-        static const uint16_t data[8] = {2187*8, 729*8, 243*8, 81*8, 27*8, 9*8, 3*8, 1*8};
-        return vld1q_u16(data);
-    }
-    //static inline uint8x16x4_t load_shuffles(uint16_t s0) {
-    //    uint8x16x4_t r;
-    //    auto step = vdupq_n_u8(4);
-    //    r.val[0] = vreinterpretq_u8_u16(vdupq_n_u16(s0));
-    //    r.val[1] = vaddq_u8(r.val[0], step);
-    //    r.val[2] = vaddq_u8(r.val[1], step);
-    //    r.val[3] = vaddq_u8(r.val[2], step);
-    //    return r;
-    //}
-
-    const uint8x16_t shuff_l = load_shuffle_l();
-    const uint8x16_t shuff_h = load_shuffle_h();
-    const int32x4_t shift_h = {8, 8, 4, 4};
-    const uint16x8_t mask_h = vdupq_n_u16(0x0f00);
-    const uint8x16_t shuff_hh = load_shuffle_hh();
-    const uint16x8_t mask_hh = vdupq_n_u16(4096);
-    const int16x8_t shift_hh = load_shift_hh();
-    const uint16x8_t mult = load_mult();
-    const uint8x16_t step = vdupq_n_u8(2);
-    const uint8x16_t shuff0 = vreinterpretq_u8_u16(vdupq_n_u16(0x0100));
-    //const uint8x16x4_t shuff1 = load_shuffles(0x0100);
-    //const uint8x16x4_t shuff2 = load_shuffles(0x0302);
-    //const uint16x8_t mask = vdupq_n_u16(0x1fff);
-    //const uint16x8_t m3 = vdupq_n_u16(3);
+    static inline uint8x16x4_t load_shuffles() {
+        static const uint8_t data[64] = {0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 12,
+                                         3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 12,
+                                         6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 8, 8, 8, 8, 8, 12,
+                                         9, 9, 9, 9, 9, 10, 10, 10, 10, 10, 11, 11, 11, 11, 11, 12};
+        return vld1q_u8_x4(data);
+    }
+    static inline uint8x16x4_t load_mult() {
+        static const uint8_t data[64] = {81, 27, 9, 3, 1, 81, 27, 9, 3, 1, 81, 27, 9, 3, 1, 81,
+                                         81, 27, 9, 3, 1, 81, 27, 9, 3, 1, 81, 27, 9, 3, 1, 27,
+                                         81, 27, 9, 3, 1, 81, 27, 9, 3, 1, 81, 27, 9, 3, 1, 9,
+                                         81, 27, 9, 3, 1, 81, 27, 9, 3, 1, 81, 27, 9, 3, 1, 3};
+        return vld1q_u8_x4(data);
+    }
+    const uint8x16x4_t shuff = load_shuffles();
+    const uint8x16x4_t mult = load_mult();
 
     IQK_ALWAYS_INLINE void prepare_iq1bn_quants(const block_iq1_bn * x, int8x16x4_t& v) const {
         auto data = vld1q_u8((const uint8_t *)x);
-        auto aux1 = vqtbl1q_u8(data, shuff_l);
-        auto aux2 = vandq_u16(vshlq_u32(vqtbl1q_u8(data, shuff_h), shift_h), mask_h);
-        auto aux3 = vandq_u16(vshlq_u16(vqtbl1q_u8(data, shuff_hh), shift_hh), mask_hh);
-        auto all = vorrq_u16(vorrq_u16(aux1, aux2), aux3);
-        auto shuffle = shuff0;
-        //auto shuffle = vreinterpretq_u8_u16(vdupq_n_u16(0x0100));
-        //auto step = vdupq_n_u8(2);
         for (int k = 0; k < 4; ++k) {
-            auto v1 = vreinterpretq_u16_u8(vqtbl1q_u8(all, shuffle)); shuffle = vaddq_u8(shuffle, step);
-            auto v2 = vreinterpretq_u16_u8(vqtbl1q_u8(all, shuffle)); shuffle = vaddq_u8(shuffle, step);
-            //auto v1 = vreinterpretq_u16_u8(vqtbl1q_u8(all, shuff1.val[k]));
-            //auto v2 = vreinterpretq_u16_u8(vqtbl1q_u8(all, shuff2.val[k]));
-            v1 = vmulq_u16(v1, mult);
-            v2 = vmulq_u16(v2, mult);
-            v1 = vshrq_n_u16(vhaddq_u16(v1, vshrq_n_u16(v1, 1)), 14);
-            v2 = vshrq_n_u16(vhaddq_u16(v2, vshrq_n_u16(v2, 1)), 14);
-            //v1 = vshrq_n_u16(vmulq_u16(vandq_u16(vmulq_u16(v1, mult), mask), m3), 13);
-            //v2 = vshrq_n_u16(vmulq_u16(vandq_u16(vmulq_u16(v2, mult), mask), m3), 13);
-            v.val[k] = vsubq_s8(vreinterpretq_s8_u8(vcombine_u8(vmovn_u16(v1), vmovn_u16(v2))), m1);
+            auto val = vmulq_u8(vqtbl1q_u8(data, shuff.val[k]), mult.val[k]);
+            val = vshrq_n_u8(vhaddq_u8(val, vshrq_n_u8(val, 1)), 6);
+            v.val[k] = vsubq_s8(vreinterpretq_s8_u8(val), m1);
         }
     }
 };