author    Iwan Kawrakow <iwan.kawrakow@gmail.com>  2024-07-30 12:33:48 +0300
committer Kawrakow <48489457+ikawrakow@users.noreply.github.com>  2024-08-01 09:38:06 +0200
commit    ab4f9e1fdb7441f8250364b248b5e709ec66771f (patch)
tree      003a54ba333fb25ad5cf764610c72f1b6c0a8946
parent    69842c6ad805c7de8f0416e52a1f12d3357023d9 (diff)
iq2_k: CUDA dot product finally works
Performance is pathetic: 140 t/s for LLaMA-3.1-8B vs 172 t/s for iq2_xs.
-rw-r--r--  ggml/src/ggml-cuda/iqk_mmvq.cu | 79
1 file changed, 55 insertions(+), 24 deletions(-)
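
The kernel added below accumulates its integer products with ggml_cuda_dp4a, which treats each 32-bit operand as four packed int8 lanes and returns the accumulator plus the sum of the lane products (on NVIDIA hardware it maps to the __dp4a instruction where available). As a minimal scalar reference of that arithmetic only, not the repository's implementation:

#include <cstdint>

// dp4a_ref(a, b, acc) = acc + sum over the four int8 lanes of a and b.
static int dp4a_ref(uint32_t a, uint32_t b, int acc) {
    for (int i = 0; i < 4; ++i) {
        const int8_t ai = int8_t(a >> 8*i);  // extract lane i of a
        const int8_t bi = int8_t(b >> 8*i);  // extract lane i of b
        acc += int(ai) * int(bi);
    }
    return acc;
}
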
diff --git a/ggml/src/ggml-cuda/iqk_mmvq.cu b/ggml/src/ggml-cuda/iqk_mmvq.cu
index 7d54fdd2..b3603697 100644
--- a/ggml/src/ggml-cuda/iqk_mmvq.cu
+++ b/ggml/src/ggml-cuda/iqk_mmvq.cu
@@ -251,30 +251,61 @@ __device__ __forceinline__ float vec_dot_iq5_k_q8_1(
// TODO
__device__ __forceinline__ float vec_dot_iq2_k_q8_1(
const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
- return 0;
-//
-// const block_iq2_k * bq4 = (const block_iq2_k *) vbq + kbx;
-// const uint8_t * all_values = (const uint8_t *)iq4k_values;
-//
-// // iqs is 0...28
-// const int ib32 = iqs/4;
-// // Why iqs/4 ?
-// const int32_t * q8 = (const int *)bq8_1[ib32].qs;
-// const uint16_t * q4 = (const uint16_t *)bq4->qs + 8*ib32;
-// const uint16_t extra = bq4->extra >> 2*ib32;
-// int v1, v2;
-// int sumi1 = 0, sumi2 = 0;
-// for (int j = 0; j < 4; ++j) {
-// const uint32_t aux32 = q4[2*j+0] | (q4[2*j+1] << 16);
-// get_int_from_table_16_shift(aux32, extra, all_values, v1, v2);
-// sumi1 = ggml_cuda_dp4a(v1, q8[j+0], sumi1);
-// sumi2 = ggml_cuda_dp4a(v2, q8[j+4], sumi2);
-// }
-// const float d = __half2float(bq4->d) * __low2float(bq8_1[ib32].ds);
-// const uint8_t sh = bq4->scales_h[ib32/2] >> 4*(ib32%2);
-// const int ls1 = ((bq4->scales_l[ib32] & 0xf) | ((sh << 4) & 0x30)) - 32;
-// const int ls2 = ((bq4->scales_l[ib32] >> 4) | ((sh << 2) & 0x30)) - 32;
-// return d * (sumi1 * ls1 + sumi2 * ls2);
+
+ // iqs is 0, 4, 8, 12, 16, 20, 24, 28
+ // when cast to int, each 32-bit word packs 16 2-bit quants
+
+ int i4 = iqs/4; // 0...7. We will process q8 blocks 4*(i4/4), 4*(i4/4)+1, 4*(i4/4)+2, 4*(i4/4)+3
+ const int32_t * q8_1 = (const int *)bq8_1[4*(i4/4)+0].qs + 2*(i4%4);
+ const int32_t * q8_2 = (const int *)bq8_1[4*(i4/4)+1].qs + 2*(i4%4);
+ const int32_t * q8_3 = (const int *)bq8_1[4*(i4/4)+2].qs + 2*(i4%4);
+ const int32_t * q8_4 = (const int *)bq8_1[4*(i4/4)+3].qs + 2*(i4%4);
+
+ const block_iq2_k * bq2 = (const block_iq2_k *) vbq + kbx;
+ const uint32_t * q2 = (const uint32_t *)bq2->qs + 8*(i4/4) + 2*(i4%4);
+ const uint16_t extra = bq2->extra >> (8*(i4/4) + (i4%4)/2);
+
+ const uint8_t * all_values = (const uint8_t *)iq2nl_values;
+ const uint8_t * values;
+
+ uint32_t val1 = q2[0], val2 = q2[1];
+
+ uint32_t aux32[2];
+ const uint8_t * a8 = (const uint8_t *)&aux32;
+ int v1, v2, ls;
+
+ // Block of 16: (32*(4*(i4/4)+k)+8*(i4%4))/16 = 8*(i4/4) + 2*k + (i4%4)/2
+ //    -> scales[4*(i4/4) + k] >> 4*(((i4%4)/2)%2)
+
+ ls = (bq2->scales[4*(i4/4) + 0] >> 4*(((i4%4)/2)%2)) & 0xf;
+ aux32[0] = ((val1 >> 0) & 0x03030303); aux32[1] = ((val2 >> 0) & 0x03030303); values = all_values + ((extra & 0x01) << 2);
+ v1 = int_from_table(a8 + 0, values);
+ v2 = int_from_table(a8 + 4, values);
+ int sumi1 = ggml_cuda_dp4a(v2, q8_1[1], ggml_cuda_dp4a(v1, q8_1[0], 0)) * (2*ls - 15);
+
+ ls = (bq2->scales[4*(i4/4) + 1] >> 4*(((i4%4)/2)%2)) & 0xf;
+ aux32[0] = ((val1 >> 2) & 0x03030303); aux32[1] = ((val2 >> 2) & 0x03030303); values = all_values + ((extra & 0x04) << 0);
+ v1 = int_from_table(a8 + 0, values);
+ v2 = int_from_table(a8 + 4, values);
+ int sumi2 = ggml_cuda_dp4a(v2, q8_2[1], ggml_cuda_dp4a(v1, q8_2[0], 0)) * (2*ls - 15);
+
+ ls = (bq2->scales[4*(i4/4) + 2] >> 4*(((i4%4)/2)%2)) & 0xf;
+ aux32[0] = ((val1 >> 4) & 0x03030303); aux32[1] = ((val2 >> 4) & 0x03030303); values = all_values + ((extra & 0x10) >> 2);
+ v1 = int_from_table(a8 + 0, values);
+ v2 = int_from_table(a8 + 4, values);
+ int sumi3 = ggml_cuda_dp4a(v2, q8_3[1], ggml_cuda_dp4a(v1, q8_3[0], 0)) * (2*ls - 15);
+
+ ls = (bq2->scales[4*(i4/4) + 3] >> 4*(((i4%4)/2)%2)) & 0xf;
+ aux32[0] = ((val1 >> 6) & 0x03030303); aux32[1] = ((val2 >> 6) & 0x03030303); values = all_values + ((extra & 0x40) >> 4);
+ v1 = int_from_table(a8 + 0, values);
+ v2 = int_from_table(a8 + 4, values);
+ int sumi4 = ggml_cuda_dp4a(v2, q8_4[1], ggml_cuda_dp4a(v1, q8_4[0], 0)) * (2*ls - 15);
+
+ return __half2float(bq2->d) * (__low2float(bq8_1[4*(i4/4)+0].ds) * sumi1
+ + __low2float(bq8_1[4*(i4/4)+1].ds) * sumi2
+ + __low2float(bq8_1[4*(i4/4)+2].ds) * sumi3
+ + __low2float(bq8_1[4*(i4/4)+3].ds) * sumi4);
+
}
}
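
For reference, the index arithmetic in the new code path can be checked on the host with a plain scalar loop. The sketch below is not code from this repository: the block layout it assumes (64 bytes of packed 2-bit quants, a 16-bit extra mask with one bit per block of 16, and 8 scale bytes holding two 4-bit scales each per 256-value super-block) is inferred from the pointer arithmetic in the hunk above, and the 8-entry value table is a placeholder rather than the real iq2nl_values. It reproduces only the four integer partial sums; the float factors bq2->d and the per-block q8_1 scales are applied afterwards, exactly as in the kernel's return statement.

#include <cstdint>
#include <cstdio>

// Placeholder non-linear value table: 8 int8 entries, the second half selected
// when a block's "extra" bit is set. The real table is iq2nl_values in the repo;
// these numbers are illustrative only.
static const int8_t kValues[8] = { -31, -13, 1, 17, -26, -8, 6, 22 };

// Scalar re-derivation of the per-thread work of vec_dot_iq2_k_q8_1 for one
// iqs value (0, 4, ..., 28): four sub-blocks of 8 quants each, dotted against
// the matching 8 int8 activations of four consecutive q8_1 blocks.
static void iq2_k_partial_sums(const uint8_t *qs,      // 64 bytes of packed 2-bit quants
                               uint16_t       extra,   // one bit per block of 16
                               const uint8_t *scales,  // 8 bytes, two 4-bit scales each
                               const int8_t  *q8,      // 256 int8 activations (8 q8_1 blocks)
                               int iqs, int sums[4]) {
    const int i4 = iqs/4;                                    // 0..7, as in the kernel
    const uint8_t  *q2 = qs + 32*(i4/4) + 8*(i4%4);          // the 8 bytes behind val1/val2
    const uint16_t  ex = extra >> (8*(i4/4) + (i4%4)/2);     // same shift as the kernel
    for (int k = 0; k < 4; ++k) {
        const int ls  = (scales[4*(i4/4) + k] >> 4*((i4%4)/2)) & 0xf;
        const int off = (ex >> 2*k) & 1 ? 4 : 0;             // which half of the value table
        const int8_t *y = q8 + 32*(4*(i4/4) + k) + 8*(i4%4); // bytes read via q8_1 ... q8_4
        int sumi = 0;
        for (int j = 0; j < 8; ++j) {
            const int q = (q2[j] >> 2*k) & 3;                // 2-bit quant of sub-block k
            sumi += kValues[q + off] * y[j];
        }
        sums[k] = sumi * (2*ls - 15);                        // same integer scale as the kernel
    }
}

int main() {
    uint8_t qs[64], scales[8];
    int8_t  q8[256];
    for (int i = 0; i < 64;  ++i) qs[i]     = uint8_t(i*37); // arbitrary test pattern
    for (int i = 0; i < 8;   ++i) scales[i] = 0x5a;
    for (int i = 0; i < 256; ++i) q8[i]     = int8_t(i%7 - 3);
    for (int iqs = 0; iqs < 32; iqs += 4) {
        int s[4];
        iq2_k_partial_sums(qs, /*extra=*/0x1234, scales, q8, iqs, s);
        printf("iqs=%2d: %d %d %d %d\n", iqs, s[0], s[1], s[2], s[3]);
    }
    return 0;
}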