diff options
author | Iwan Kawrakow <iwan.kawrakow@gmail.com> | 2024-07-30 17:18:31 +0300 |
---|---|---|
committer | Kawrakow <48489457+ikawrakow@users.noreply.github.com> | 2024-08-01 09:38:06 +0200 |
commit | fd1ae85a329e8148d1de20dc6ef5302110d53b73 (patch) | |
tree | 8bc375e60041fde1e5d96954a170de86ebfaea8d | |
parent | 0d19d19af88a508ee8987abe5fc4f8fcaaa1dc2d (diff) |
iq3_k: faster CUDA dot product
138 t/s for LLaMA-3.1-8B, which is almost on par with iq3_s.
-rw-r--r-- | ggml/src/ggml-cuda/iqk_mmvq.cu | 46 |
1 files changed, 28 insertions, 18 deletions
diff --git a/ggml/src/ggml-cuda/iqk_mmvq.cu b/ggml/src/ggml-cuda/iqk_mmvq.cu index e6223f65..071d55ca 100644 --- a/ggml/src/ggml-cuda/iqk_mmvq.cu +++ b/ggml/src/ggml-cuda/iqk_mmvq.cu @@ -381,7 +381,21 @@ __device__ __forceinline__ float vec_dot_iq2_k_q8_1( #define VDR_IQ3_K_Q8_1_MMVQ 4 #define VDR_IQ3_K_Q8_1_MMQ 4 -// TODO +static const __device__ uint16_t iq3k_table[128] = { + 0xc1c1, 0xc1d8, 0xc1e9, 0xc1f6, 0xc101, 0xc10d, 0xc11c, 0xc12f, 0xd8c1, 0xd8d8, 0xd8e9, 0xd8f6, 0xd801, 0xd80d, 0xd81c, 0xd82f, + 0xe9c1, 0xe9d8, 0xe9e9, 0xe9f6, 0xe901, 0xe90d, 0xe91c, 0xe92f, 0xf6c1, 0xf6d8, 0xf6e9, 0xf6f6, 0xf601, 0xf60d, 0xf61c, 0xf62f, + 0x01c1, 0x01d8, 0x01e9, 0x01f6, 0x0101, 0x010d, 0x011c, 0x012f, 0x0dc1, 0x0dd8, 0x0de9, 0x0df6, 0x0d01, 0x0d0d, 0x0d1c, 0x0d2f, + 0x1cc1, 0x1cd8, 0x1ce9, 0x1cf6, 0x1c01, 0x1c0d, 0x1c1c, 0x1c2f, 0x2fc1, 0x2fd8, 0x2fe9, 0x2ff6, 0x2f01, 0x2f0d, 0x2f1c, 0x2f2f, + 0xc5c5, 0xc5dc, 0xc5ed, 0xc5fa, 0xc505, 0xc511, 0xc520, 0xc533, 0xdcc5, 0xdcdc, 0xdced, 0xdcfa, 0xdc05, 0xdc11, 0xdc20, 0xdc33, + 0xedc5, 0xeddc, 0xeded, 0xedfa, 0xed05, 0xed11, 0xed20, 0xed33, 0xfac5, 0xfadc, 0xfaed, 0xfafa, 0xfa05, 0xfa11, 0xfa20, 0xfa33, + 0x05c5, 0x05dc, 0x05ed, 0x05fa, 0x0505, 0x0511, 0x0520, 0x0533, 0x11c5, 0x11dc, 0x11ed, 0x11fa, 0x1105, 0x1111, 0x1120, 0x1133, + 0x20c5, 0x20dc, 0x20ed, 0x20fa, 0x2005, 0x2011, 0x2020, 0x2033, 0x33c5, 0x33dc, 0x33ed, 0x33fa, 0x3305, 0x3311, 0x3320, 0x3333, +}; + +__device__ __forceinline__ int int_from_table_2(const uint8_t * a8, const uint16_t * values) { + return values[a8[0] | (a8[1] << 3)] | (values[a8[2] | (a8[3] << 3)] << 16); +} + __device__ __forceinline__ float vec_dot_iq3_k_q8_1( const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iiqs) { const block_iq3_k * bq3 = (const block_iq3_k *) vbq + kbx; @@ -402,14 +416,10 @@ __device__ __forceinline__ float vec_dot_iq3_k_q8_1( const uint16_t sh = bq3->scales_h >> (8*ib128 + il8/2); const uint8_t extra = bq3->extra >> (8*ib128 + il8/2); - - uint32_t indx[4]; - indx[0] = ((extra << 3) & 8) * 0x01010101; - indx[1] = ((extra << 2) & 8) * 0x01010101; - indx[2] = ((extra << 1) & 8) * 0x01010101; - indx[3] = ((extra << 0) & 8) * 0x01010101; - - const uint8_t * values = (const uint8_t *)iq3nl_values; + const uint16_t * values1 = iq3k_table + ((extra << 6) & 0x40); + const uint16_t * values2 = iq3k_table + ((extra << 5) & 0x40); + const uint16_t * values3 = iq3k_table + ((extra << 4) & 0x40); + const uint16_t * values4 = iq3k_table + ((extra << 3) & 0x40); const int * q8; int sumi[4] = {0, 0, 0, 0}; @@ -420,30 +430,30 @@ __device__ __forceinline__ float vec_dot_iq3_k_q8_1( uint32_t vh = ((qh[2*i+0] | (qh[2*i+1] << 16)) << hshift) >> 2; q8 = (const int *)bq8_1[4*ib128+0].qs + 2*il8; - aux32 = (vl & 0x03030303) | (vh & 0x04040404) | indx[0]; - v1 = values[aux8[0]] | (values[aux8[1]] << 8); v2 = values[aux8[2]] | values[aux8[3]] << 8; v = v1 | (v2 << 16); + aux32 = (vl & 0x03030303) | (vh & 0x04040404); + v = int_from_table_2(aux8, values1); sumi[0] = ggml_cuda_dp4a(v, q8[i], sumi[0]); vl >>= 2; vh >>= 1; q8 += sizeof(block_q8_1)/4; - aux32 = (vl & 0x03030303) | (vh & 0x04040404) | indx[1]; - v1 = values[aux8[0]] | (values[aux8[1]] << 8); v2 = values[aux8[2]] | values[aux8[3]] << 8; v = v1 | (v2 << 16); + aux32 = (vl & 0x03030303) | (vh & 0x04040404); + v = int_from_table_2(aux8, values2); sumi[1] = ggml_cuda_dp4a(v, q8[i], sumi[1]); vl >>= 2; vh >>= 1; q8 += sizeof(block_q8_1)/4; - aux32 = (vl & 0x03030303) | (vh & 0x04040404) | indx[2]; - v1 = values[aux8[0]] | (values[aux8[1]] << 8); v2 = values[aux8[2]] | values[aux8[3]] << 8; v = v1 | (v2 << 16); + aux32 = (vl & 0x03030303) | (vh & 0x04040404); + v = int_from_table_2(aux8, values3); sumi[2] = ggml_cuda_dp4a(v, q8[i], sumi[2]); vl >>= 2; vh >>= 1; q8 += sizeof(block_q8_1)/4; - aux32 = (vl & 0x03030303) | (vh & 0x04040404) | indx[3]; - v1 = values[aux8[0]] | (values[aux8[1]] << 8); v2 = values[aux8[2]] | values[aux8[3]] << 8; v = v1 | (v2 << 16); + aux32 = (vl & 0x03030303) | (vh & 0x04040404); + v = int_from_table_2(aux8, values4); sumi[3] = ggml_cuda_dp4a(v, q8[i], sumi[3]); } - const float d = (float)bq3->d; + const float d = __half2float(bq3->d); const uint16_t * sl16 = (const uint16_t *)bq3->scales_l + 2*ib128; aux32 = ((((sl16[0] | (sl16[1] << 16)) >> shift) & 0x0f0f0f0f) << 1) | 0x01010101; return d * (__low2float(bq8_1[4*ib128+0].ds) * aux8[0] * (sh & 0x01 ? -1 : 1) * sumi[0] + |