summaryrefslogtreecommitdiff
path: root/ggml-cuda
diff options
context:
space:
mode:
authorIwan Kawrakow <iwan.kawrakow@gmail.com>2024-07-17 08:54:11 +0300
committerIwan Kawrakow <iwan.kawrakow@gmail.com>2024-07-17 08:54:11 +0300
commit873a790ee22538d1d9d7205db7210c70955ab1e1 (patch)
tree426f0ce20cf45c325be28ae8bceb51d42c072452 /ggml-cuda
parent52a25e307c3af8686436d977c60e9975b0900e2b (diff)
iq1bn(no lookup): better version
We have 4 groups of 16 in a block of 64 quants. For each group of 16 we have 3 groups of 5, each using 8 bits. The remaining 16'th quants of the 4 groups of 16 are encoded with 8 bits using the same encoding as the groups of 5. The only kernel where we have complications is the CUDA dequantize kernel (because we are dequantizing 8 quants there, and we have different encoding for the 1st and 2nd group of 8 in a group of 16). Ths achieves better performance on all tested platforms than any previous 1.625 bpw attempt. We have: | model | size | params | backend | threads | test | t/s | | ---------------- | ---------: | ---------: | ---------- | ------: | ------------: | ---------------: | | 1.625 bpw Bitnet | 729.64 MiB | 3.32 B | CUDA | 8 | pp512 | 9613.02 ± 24.54 | | 1.625 bpw Bitnet | 729.64 MiB | 3.32 B | CUDA | 8 | tg128 | 229.85 ± 0.33 | | 1.625 bpw Bitnet | 729.64 MiB | 3.32 B | AVX2 | 16 | pp512 | 322.59 ± 1.00 | | 1.625 bpw Bitnet | 729.64 MiB | 3.32 B | AVX2 | 16 | tg128 | 59.79 ± 0.03 | | 1.625 bpw Bitnet | 729.64 MiB | 3.32 B | AVX2 | 8 | tg128 | 57.62 ± 0.21 | | 1.625 bpw Bitnet | 729.64 MiB | 3.32 B | AVX2 | 4 | tg128 | 33.66 ± 0.29 | | 1.625 bpw Bitnet | 729.64 MiB | 3.32 B | AVX2 | 2 | tg128 | 18.30 ± 0.01 | | 1.625 bpw Bitnet | 729.64 MiB | 3.32 B | Metal | 8 | pp512 | 698.13 ± 0.21 | | 1.625 bpw Bitnet | 729.64 MiB | 3.32 B | Metal | 8 | tg128 | 68.88 ± 0.24 | | 1.625 bpw Bitnet | 729.64 MiB | 3.32 B | NEON | 8 | pp512 | 196.80 ± 0.50 | | 1.625 bpw Bitnet | 729.64 MiB | 3.32 B | NEON | 8 | tg128 | 51.58 ± 0.41 | | 1.625 bpw Bitnet | 729.64 MiB | 3.32 B | NEON | 4 | tg128 | 30.80 ± 0.03 | | 1.625 bpw Bitnet | 729.64 MiB | 3.32 B | NEON | 2 | tg128 | 16.89 ± 0.01 | It is still slower than 2 bpw Bitnet, but the difference now is not as dramatic.
Diffstat (limited to 'ggml-cuda')
-rw-r--r--ggml-cuda/convert.cu26
-rw-r--r--ggml-cuda/vecdotq.cuh76
2 files changed, 49 insertions, 53 deletions
diff --git a/ggml-cuda/convert.cu b/ggml-cuda/convert.cu
index a0a826e9..4a67c498 100644
--- a/ggml-cuda/convert.cu
+++ b/ggml-cuda/convert.cu
@@ -425,7 +425,10 @@ static __global__ void dequantize_block_iq1_bn(const void * __restrict__ vx, dst
const int64_t ii = blockIdx.x;
const block_iq1_bn * x = (const block_iq1_bn *) vx;
- static const uint16_t k_mult[8] = {2187, 729, 243, 81, 27, 9, 3, 1};
+ static const uint8_t k_mult[5] = {81, 27, 9, 3, 1};
+
+//#define COMPUTE_VS(v) 3*v >> 8
+#define COMPUTE_VS(v) (v + (v >> 1)) >> 7
const int tid = threadIdx.x;
const int il = tid/4; // 0...7
@@ -433,11 +436,24 @@ static __global__ void dequantize_block_iq1_bn(const void * __restrict__ vx, dst
dst_t * y = yy + ii*QK_K + 64*ib + 8*il;
int64_t i = QK_K/QK_IQ1BN * ii + ib;
if (i >= nb64) return;
- uint16_t val = x[i].ql[il] | ((x[i].qh[il%4] << (8 - 4*(il/4))) & 0x0f00) | ((x[i].extra << (12 - il)) & 4096);
- for (int j = 0; j < 8; ++j) {
- uint16_t v = (val*k_mult[j] & 0x1fff)*3 >> 13;
- y[j] = v - 1;
+ const int i16 = il/2;
+ uint8_t q = x[i].ql[3*i16+2*(il%2)];
+ for (int j = 0; j < 5; ++j) {
+ uint8_t v = k_mult[j]*q;
+ int8_t vs = COMPUTE_VS(v);
+ y[2*(il%2)+j] = vs - 1;
}
+ q = x[i].ql[3*i16+1];
+ for (int j = 0; j < 2; ++j) {
+ uint8_t v = k_mult[3*(il%2)+j]*q;
+ int8_t vs = COMPUTE_VS(v);
+ y[5*(1-(il%2))+j] = vs-1;
+ }
+ uint8_t v = (il%2) ? k_mult[i16]*x[i].extra : k_mult[2]*q;
+ int8_t vs = COMPUTE_VS(v);
+ y[7] = vs - 1;
+
+#undef COMPUTE_VS
}
template<typename dst_t>
diff --git a/ggml-cuda/vecdotq.cuh b/ggml-cuda/vecdotq.cuh
index d0d2e923..f0133e07 100644
--- a/ggml-cuda/vecdotq.cuh
+++ b/ggml-cuda/vecdotq.cuh
@@ -1078,67 +1078,47 @@ static __device__ __forceinline__ float vec_dot_iq1_bn_q8_1(
const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
const block_iq1_bn * bq1 = (const block_iq1_bn *) vbq + kbx;
- static const int16_t k_mult[8] = {2187, 729, 243, 81, 27, 9, 3, 1};
+ static const uint8_t k_mult[5] = {81, 27, 9, 3, 1};
// iqs is 0 or 1
- uint8_t extra = bq1->extra >> 4*iqs;
int sumi = 0;
#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
const int * q8 = (const int *)bq8_1[iqs].qs;
- //int v[2];
- //int8_t * a = (int8_t *)v;
- //for (int l = 0; l < 2; ++l) {
- // int16_t val1 = bq1->ql[4*iqs + 2*l+0] | ((bq1->qh[2*l+0] << (8-4*iqs)) & 0x0f00) | ((extra << 12) & 4096);
- // int16_t val2 = bq1->ql[4*iqs + 2*l+1] | ((bq1->qh[2*l+1] << (8-4*iqs)) & 0x0f00) | ((extra << 11) & 4096);
- // for (int k = 0; k < 8; ++k) a[k] = ((val1*k_mult[k] & 0x1fff)*3 >> 13) - 1;
- // sumi = __dp4a(v[0], q8[4*l+0], __dp4a(v[1], q8[4*l+1], sumi));
- // for (int k = 0; k < 8; ++k) a[k] = ((val2*k_mult[k] & 0x1fff)*3 >> 13) - 1;
- // sumi = __dp4a(v[0], q8[4*l+2], __dp4a(v[1], q8[4*l+3], sumi));
- // extra >>= 2;
- //}
-
- int v[4];
- int8_t * a = (int8_t *)v;
+ int val[4];
for (int l = 0; l < 2; ++l) {
- int16_t val1 = bq1->ql[4*iqs + 2*l+0] | ((bq1->qh[2*l+0] << (8-4*iqs)) & 0x0f00) | ((extra << 12) & 4096);
- int16_t val2 = bq1->ql[4*iqs + 2*l+1] | ((bq1->qh[2*l+1] << (8-4*iqs)) & 0x0f00) | ((extra << 11) & 4096);
- for (int k = 0; k < 8; ++k) {
- a[k+0] = ((val1*k_mult[k] & 0x1fff)*3 >> 13) - 1;
- a[k+8] = ((val2*k_mult[k] & 0x1fff)*3 >> 13) - 1;
+ int8_t * a = (int8_t *)val;
+ const int i16 = 2*iqs + l;
+ for (int k = 0; k < 3; ++k) {
+ uint8_t q = bq1->ql[3*i16+k];
+ for (int j = 0; j < 5; ++j) {
+ uint8_t v = k_mult[j]*q;
+ int8_t vs = 3*v >> 8; //(v + (v >> 1)) >> 7;
+ *a++ = vs-1;
+ }
}
- sumi = __dp4a(v[0], q8[4*l+0], __dp4a(v[1], q8[4*l+1], __dp4a(v[2], q8[4*l+2], __dp4a(v[3], q8[4*l+3], sumi))));
- extra >>= 2;
+ uint8_t v = k_mult[i16]*bq1->extra;
+ int8_t vs = 3*v >> 8; //(v + (v >> 1)) >> 7;
+ *a++ = vs-1;
+ sumi = __dp4a(val[0], q8[4*l+0], __dp4a(val[1], q8[4*l+1], __dp4a(val[2], q8[4*l+2], __dp4a(val[3], q8[4*l+3], sumi))));
}
-
- //int v[8];
- //int8_t * a = (int8_t *)v;
- //int16_t val1 = bq1->ql[4*iqs + 0] | ((bq1->qh[0] << (8-4*iqs)) & 0x0f00) | ((extra << 12) & 4096);
- //int16_t val2 = bq1->ql[4*iqs + 1] | ((bq1->qh[1] << (8-4*iqs)) & 0x0f00) | ((extra << 11) & 4096);
- //int16_t val3 = bq1->ql[4*iqs + 2] | ((bq1->qh[2] << (8-4*iqs)) & 0x0f00) | ((extra << 10) & 4096);
- //int16_t val4 = bq1->ql[4*iqs + 3] | ((bq1->qh[3] << (8-4*iqs)) & 0x0f00) | ((extra << 9) & 4096);
- //for (int k = 0; k < 8; ++k) {
- // a[k+ 0] = ((val1*k_mult[k] & 0x1fff)*3 >> 13) - 1;
- // a[k+ 8] = ((val2*k_mult[k] & 0x1fff)*3 >> 13) - 1;
- // a[k+16] = ((val3*k_mult[k] & 0x1fff)*3 >> 13) - 1;
- // a[k+24] = ((val4*k_mult[k] & 0x1fff)*3 >> 13) - 1;
- //}
- //sumi = __dp4a(v[0], q8[0], __dp4a(v[1], q8[1], __dp4a(v[2], q8[2], __dp4a(v[3], q8[3], sumi))));
- //sumi = __dp4a(v[4], q8[4], __dp4a(v[5], q8[5], __dp4a(v[6], q8[6], __dp4a(v[7], q8[7], sumi))));
#else
const int8_t * q8 = bq8_1[iqs].qs;
for (int l = 0; l < 2; ++l) {
- int val1 = bq1->ql[4*iqs + 2*l+0] | ((bq1->qh[2*l+0] << (8-4*iqs)) & 0x0f00) | ((extra << 12) & 4096);
- int val2 = bq1->ql[4*iqs + 2*l+1] | ((bq1->qh[2*l+1] << (8-4*iqs)) & 0x0f00) | ((extra << 11) & 4096);
- for (int k = 0; k < 4; ++k) {
- int v1 = (val1*k_mult[k+0] & 0x1fff)*3 >> 13;
- int v2 = (val1*k_mult[k+4] & 0x1fff)*3 >> 13;
- int v3 = (val2*k_mult[k+0] & 0x1fff)*3 >> 13;
- int v4 = (val2*k_mult[k+4] & 0x1fff)*3 >> 13;
- sumi += (v1 - 1)*q8[k+0] + (v2 - 1)*q8[k+4] + (v3 - 1)*q8[k+8] + (v4 - 1)*q8[k+12];
+ const int i16 = 2*iqs + l;
+ for (int k = 0; k < 3; ++k) {
+ uint8_t q = bq1->ql[3*i16+k];
+ for (int j = 0; j < 5; ++j) {
+ uint8_t v = k_mult[j]*q;
+ int8_t vs = (v + (v >> 1)) >> 7;
+ sumi += q8[j]*(vs - 1);
+ }
+ q8 += 5;
}
- q8 += 16;
- extra >>= 2;
+ uint8_t v = k_mult[i16]*bq1->extra;
+ int8_t vs = (v + (v >> 1)) >> 7;
+ sumi += q8[0]*(vs - 1);
+ q8++;
}
#endif
return __low2float(bq8_1[iqs].ds) * sumi;