summaryrefslogtreecommitdiff
path: root/ggml-cuda
diff options
context:
space:
mode:
authorIwan Kawrakow <iwan.kawrakow@gmail.com>2024-06-25 18:19:11 +0300
committerIwan Kawrakow <iwan.kawrakow@gmail.com>2024-06-25 18:19:11 +0300
commit753dbaeeb0be5fb3d0d4337d7854dcf4f3a30fe1 (patch)
treeafedc73d7d8b8032f5c2057aec8bdff95e6601df /ggml-cuda
parent8b436a84c53de4c5a8eaf9be72cdd82324da2eeb (diff)
bitnet: remove iq1_bn lookup table storing +/- signs
The AVX2 implementation was the only one left using it, so I decided to see if we can get a performant implementation using the 0,1,2 lookup table. Turns out we can, and it is even slightly faster than the sign based table. We now get PP-512 = 275 t/s and TG-128 = 57.7 t/s with 16 threads on the Ryzen-7950X. With only one lookup table left for iq1_bn, I renamed it to iq1bn_grid_u16.
Diffstat (limited to 'ggml-cuda')
-rw-r--r--ggml-cuda/convert.cu2
-rw-r--r--ggml-cuda/vecdotq.cuh6
2 files changed, 4 insertions, 4 deletions
diff --git a/ggml-cuda/convert.cu b/ggml-cuda/convert.cu
index 888c8452..0e1cde9b 100644
--- a/ggml-cuda/convert.cu
+++ b/ggml-cuda/convert.cu
@@ -433,7 +433,7 @@ static __global__ void dequantize_block_iq1_bn(const void * __restrict__ vx, dst
if (i >= nb64) return;
ib = ib%(QK_IQ1BN/32);
uint16_t idx = x[i].ql[4*ib + il] | ((x[i].qh[2*ib + il/2] << (8 - 4*(il%2))) & 0x0f00);
- uint16_t val = x[i].extra & (1 << (4*ib + il)) ? 0xaaaa - iq1bn_grid_zzz[idx] : iq1bn_grid_zzz[idx];
+ uint16_t val = x[i].extra & (1 << (4*ib + il)) ? 0xaaaa - iq1bn_grid_u16[idx] : iq1bn_grid_u16[idx];
uint32_t aux32[2];
const int8_t * aux8 = (const int8_t *)aux32;
aux32[0] = val | (val << 14);
diff --git a/ggml-cuda/vecdotq.cuh b/ggml-cuda/vecdotq.cuh
index bce2c154..1e2b4b7a 100644
--- a/ggml-cuda/vecdotq.cuh
+++ b/ggml-cuda/vecdotq.cuh
@@ -1086,8 +1086,8 @@ static __device__ __forceinline__ float vec_dot_iq1_bn_q8_1(
for (int l = 0; l < 2; ++l) {
uint16_t idx1 = bq1->ql[4*iqs + 2*l+0] | ((bq1->qh[2*iqs + l] << 8) & 0x0f00);
uint16_t idx2 = bq1->ql[4*iqs + 2*l+1] | ((bq1->qh[2*iqs + l] << 4) & 0x0f00);
- uint16_t val1 = extra & 1 ? 0xaaaa - iq1bn_grid_zzz[idx1] : iq1bn_grid_zzz[idx1];
- uint16_t val2 = extra & 2 ? 0xaaaa - iq1bn_grid_zzz[idx2] : iq1bn_grid_zzz[idx2];
+ uint16_t val1 = extra & 1 ? 0xaaaa - iq1bn_grid_u16[idx1] : iq1bn_grid_u16[idx1];
+ uint16_t val2 = extra & 2 ? 0xaaaa - iq1bn_grid_u16[idx2] : iq1bn_grid_u16[idx2];
val32 = val1 | (val1 << 14);
v1 = __vsub4(val32 & 0x03030303, 0x01010101);
v2 = __vsub4((val32 >> 4) & 0x03030303, 0x01010101);
@@ -1104,7 +1104,7 @@ static __device__ __forceinline__ float vec_dot_iq1_bn_q8_1(
const int8_t * q8 = bq8_1[iqs].qs;
for (int l = 0; l < 4; ++l) {
uint16_t idx = bq1->ql[4*iqs + l] | ((bq1->qh[2*iqs + l/2] << (8 - 4*(l%2))) & 0x0f00);
- uint16_t val = extra & 1 ? 0xaaaa - iq1bn_grid_zzz[idx] : iq1bn_grid_zzz[idx];
+ uint16_t val = extra & 1 ? 0xaaaa - iq1bn_grid_u16[idx] : iq1bn_grid_u16[idx];
aux32[0] = val | (val << 14);
aux32[1] = (aux32[0] >> 4) & 0x03030303;
aux32[0] &= 0x03030303;