summaryrefslogtreecommitdiff
path: root/ggml-cuda
diff options
context:
space:
mode:
authorIwan Kawrakow <iwan.kawrakow@gmail.com>2024-06-25 11:32:48 +0300
committerIwan Kawrakow <iwan.kawrakow@gmail.com>2024-06-25 11:32:48 +0300
commitaa14a06b44ff12be7e4461a6e169a657275a5b20 (patch)
treec0ab2e1cd51a778594f0dd226d3e54c102c81b39 /ggml-cuda
parentcc44d4a5c3368801f1de0d68096619a6746d47a4 (diff)
Bitnet: trying an alternative iq1_bn grid
Faster on CUDA. The scalar version is faster too. The issue with CUDA is that now I see wild performance fluctuations. Running llama-bench I can get 220 t/s for TG-128 one time, and 190 t/s another time, with uncertaintiers of 1-2 t/s. Same for PP, results are jumping back-and-fort between ~9500 t/s and ~8900 t/s. So, basically no reliable measurement at this point, but for sure faster than the previous version, which was at around 170-180 t/s.
Diffstat (limited to 'ggml-cuda')
-rw-r--r--ggml-cuda/convert.cu19
-rw-r--r--ggml-cuda/vecdotq.cuh38
2 files changed, 36 insertions, 21 deletions
diff --git a/ggml-cuda/convert.cu b/ggml-cuda/convert.cu
index 2be03a3e..888c8452 100644
--- a/ggml-cuda/convert.cu
+++ b/ggml-cuda/convert.cu
@@ -432,13 +432,20 @@ static __global__ void dequantize_block_iq1_bn(const void * __restrict__ vx, dst
int64_t i = QK_K/QK_IQ1BN * ii + ib/(QK_IQ1BN/32);
if (i >= nb64) return;
ib = ib%(QK_IQ1BN/32);
- const float dl = x[i].extra & (1 << (4*ib + il)) ? -1 : 1;
- const float ml = -dl;
uint16_t idx = x[i].ql[4*ib + il] | ((x[i].qh[2*ib + il/2] << (8 - 4*(il%2))) & 0x0f00);
- const uint16_t gp = iq1bn_grid_u16[idx];
- for (int j = 0; j < 8; ++j) {
- y[j] = dl * ((gp >> 2*j) & 3) + ml;
- }
+ uint16_t val = x[i].extra & (1 << (4*ib + il)) ? 0xaaaa - iq1bn_grid_zzz[idx] : iq1bn_grid_zzz[idx];
+ uint32_t aux32[2];
+ const int8_t * aux8 = (const int8_t *)aux32;
+ aux32[0] = val | (val << 14);
+//#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
+// aux32[1] = __vsub4((aux32[0] >> 4) & 0x03030303, 0x01010101);
+// aux32[0] = __vsub4(aux32[0] & 0x03030303, 0x01010101);
+// for (int j = 0; j < 8; ++j) y[j] = aux8[j];
+//#else
+ aux32[1] = (aux32[0] >> 4) & 0x03030303;
+ aux32[0] &= 0x03030303;
+ for (int j = 0; j < 8; ++j) y[j] = aux8[j] - 1;
+//#endif
}
template<typename dst_t>
diff --git a/ggml-cuda/vecdotq.cuh b/ggml-cuda/vecdotq.cuh
index 6b831cf6..bce2c154 100644
--- a/ggml-cuda/vecdotq.cuh
+++ b/ggml-cuda/vecdotq.cuh
@@ -1082,27 +1082,35 @@ static __device__ __forceinline__ float vec_dot_iq1_bn_q8_1(
int sumi = 0;
#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
const int * q8 = (const int *)bq8_1[iqs].qs;
- for (int l = 0; l < 4; ++l) {
- uint16_t val = iq1bn_grid_xxx[bq1->ql[4*iqs + l] | ((bq1->qh[2*iqs + l/2] << (8 - 4*(l%2))) & 0x0f00)];
- uint8_t vp = val & 0xff, vm = val >> 8;
- int32_t vp1 = __vcmpeq4(((vp & 0xf) * 0x01010101) & 0x08040201, 0x08040201);
- int32_t vp2 = __vcmpeq4(((vp >> 4) * 0x01010101) & 0x08040201, 0x08040201);
- int32_t vm1 = __vcmpeq4(((vm & 0xf) * 0x01010101) & 0x08040201, 0x08040201);
- int32_t vm2 = __vcmpeq4(((vm >> 4) * 0x01010101) & 0x08040201, 0x08040201);
- int32_t pm = __dp4a(q8[2*l+0], vm1, __dp4a(q8[2*l+1], vm2, 0));
- int32_t pp = __dp4a(q8[2*l+0], vp1, __dp4a(q8[2*l+1], vp2, 0));
- sumi += extra & (1 << l) ? pp - pm : pm - pp;
+ int val32, v1, v2;
+ for (int l = 0; l < 2; ++l) {
+ uint16_t idx1 = bq1->ql[4*iqs + 2*l+0] | ((bq1->qh[2*iqs + l] << 8) & 0x0f00);
+ uint16_t idx2 = bq1->ql[4*iqs + 2*l+1] | ((bq1->qh[2*iqs + l] << 4) & 0x0f00);
+ uint16_t val1 = extra & 1 ? 0xaaaa - iq1bn_grid_zzz[idx1] : iq1bn_grid_zzz[idx1];
+ uint16_t val2 = extra & 2 ? 0xaaaa - iq1bn_grid_zzz[idx2] : iq1bn_grid_zzz[idx2];
+ val32 = val1 | (val1 << 14);
+ v1 = __vsub4(val32 & 0x03030303, 0x01010101);
+ v2 = __vsub4((val32 >> 4) & 0x03030303, 0x01010101);
+ sumi = __dp4a(v1, q8[4*l+0], __dp4a(v2, q8[4*l+1], sumi));
+ val32 = val2 | (val2 << 14);
+ v1 = __vsub4(val32 & 0x03030303, 0x01010101);
+ v2 = __vsub4((val32 >> 4) & 0x03030303, 0x01010101);
+ sumi = __dp4a(v1, q8[4*l+2], __dp4a(v2, q8[4*l+3], sumi));
+ extra >>= 2;
}
#else
+ uint32_t aux32[2];
+ const int8_t * aux8 = (const int8_t *)aux32;
const int8_t * q8 = bq8_1[iqs].qs;
for (int l = 0; l < 4; ++l) {
- uint16_t val = iq1bn_grid_u16[bq1->ql[4*iqs + l] | ((bq1->qh[2*iqs + l/2] << (8 - 4*(l%2))) & 0x0f00)];
- int s1 = 0, s2 = 0;
+ uint16_t idx = bq1->ql[4*iqs + l] | ((bq1->qh[2*iqs + l/2] << (8 - 4*(l%2))) & 0x0f00);
+ uint16_t val = extra & 1 ? 0xaaaa - iq1bn_grid_zzz[idx] : iq1bn_grid_zzz[idx];
+ aux32[0] = val | (val << 14);
+ aux32[1] = (aux32[0] >> 4) & 0x03030303;
+ aux32[0] &= 0x03030303;
for (int j = 0; j < 8; ++j) {
- s1 += q8[j] * ((val >> 2*j) & 3);
- s2 += q8[j];
+ sumi += q8[j] * (aux8[j] - 1);
}
- sumi += extra & (1 << l) ? s2 - s1 : s1 - s2;
q8 += 8;
}
#endif