summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorKawrakow <iwankawrakow@gmail.com>2025-07-10 09:27:28 +0200
committerGitHub <noreply@github.com>2025-07-10 09:27:28 +0200
commit283753cabcabd30eb2cfb93739d9c1679200bf1f (patch)
tree86de891461ece7de11c98f7fb3eb494203b36cbb
parent5446ccc8ac87037484ba63f91941de35e0bd58ca (diff)
CUDA: Faster prompt processing for several quantization types (#595)
* cuda: slightly faster MMQ for iq3_k, iq3_k_r4 * cuda: slightly faster MMQ for iq4_k, iq4_k_r4 * cuda: slightly faster MMQ for iq4_ks_r4 * cuda: slightly faster MMQ for iq4_ks * cuda: slightly faster MMQ for iq4_xs --------- Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
-rw-r--r--ggml/src/ggml-cuda/iqk_cuda_common.h54
-rw-r--r--ggml/src/ggml-cuda/iqk_mmvq.cu15
-rw-r--r--ggml/src/ggml-cuda/mmq.cuh128
-rw-r--r--ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_k_r4.cu26
-rw-r--r--ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_k_r4.cu23
5 files changed, 139 insertions, 107 deletions
diff --git a/ggml/src/ggml-cuda/iqk_cuda_common.h b/ggml/src/ggml-cuda/iqk_cuda_common.h
index 95d9b40e..fbe655c4 100644
--- a/ggml/src/ggml-cuda/iqk_cuda_common.h
+++ b/ggml/src/ggml-cuda/iqk_cuda_common.h
@@ -73,3 +73,57 @@ __device__ __forceinline__ int int_from_table_4(const uint32_t idx, const int *
return values[ggml_cuda_dp4a(idx, 0x40100401, 0)];
}
+static const __device__ uint16_t iq3k_table[128] = {
+ 0xc1c1, 0xc1d8, 0xc1e9, 0xc1f6, 0xc101, 0xc10d, 0xc11c, 0xc12f, 0xd8c1, 0xd8d8, 0xd8e9, 0xd8f6, 0xd801, 0xd80d, 0xd81c, 0xd82f,
+ 0xe9c1, 0xe9d8, 0xe9e9, 0xe9f6, 0xe901, 0xe90d, 0xe91c, 0xe92f, 0xf6c1, 0xf6d8, 0xf6e9, 0xf6f6, 0xf601, 0xf60d, 0xf61c, 0xf62f,
+ 0x01c1, 0x01d8, 0x01e9, 0x01f6, 0x0101, 0x010d, 0x011c, 0x012f, 0x0dc1, 0x0dd8, 0x0de9, 0x0df6, 0x0d01, 0x0d0d, 0x0d1c, 0x0d2f,
+ 0x1cc1, 0x1cd8, 0x1ce9, 0x1cf6, 0x1c01, 0x1c0d, 0x1c1c, 0x1c2f, 0x2fc1, 0x2fd8, 0x2fe9, 0x2ff6, 0x2f01, 0x2f0d, 0x2f1c, 0x2f2f,
+ 0xc5c5, 0xc5dc, 0xc5ed, 0xc5fa, 0xc505, 0xc511, 0xc520, 0xc533, 0xdcc5, 0xdcdc, 0xdced, 0xdcfa, 0xdc05, 0xdc11, 0xdc20, 0xdc33,
+ 0xedc5, 0xeddc, 0xeded, 0xedfa, 0xed05, 0xed11, 0xed20, 0xed33, 0xfac5, 0xfadc, 0xfaed, 0xfafa, 0xfa05, 0xfa11, 0xfa20, 0xfa33,
+ 0x05c5, 0x05dc, 0x05ed, 0x05fa, 0x0505, 0x0511, 0x0520, 0x0533, 0x11c5, 0x11dc, 0x11ed, 0x11fa, 0x1105, 0x1111, 0x1120, 0x1133,
+ 0x20c5, 0x20dc, 0x20ed, 0x20fa, 0x2005, 0x2011, 0x2020, 0x2033, 0x33c5, 0x33dc, 0x33ed, 0x33fa, 0x3305, 0x3311, 0x3320, 0x3333,
+};
+
+__device__ __forceinline__ int int_from_table_2(const uint8_t * a8, const uint16_t * values) {
+ return values[a8[0] | (a8[1] << 3)] | (values[a8[2] | (a8[3] << 3)] << 16);
+}
+
+static const __device__ uint16_t iq4k_table[512] = {
+ 0x8181, 0x8198, 0x81ad, 0x81bf, 0x81cf, 0x81dd, 0x81ea, 0x81f6, 0x8101, 0x810d, 0x8119, 0x8126, 0x8135, 0x8145, 0x8159, 0x8171,
+ 0x9881, 0x9898, 0x98ad, 0x98bf, 0x98cf, 0x98dd, 0x98ea, 0x98f6, 0x9801, 0x980d, 0x9819, 0x9826, 0x9835, 0x9845, 0x9859, 0x9871,
+ 0xad81, 0xad98, 0xadad, 0xadbf, 0xadcf, 0xaddd, 0xadea, 0xadf6, 0xad01, 0xad0d, 0xad19, 0xad26, 0xad35, 0xad45, 0xad59, 0xad71,
+ 0xbf81, 0xbf98, 0xbfad, 0xbfbf, 0xbfcf, 0xbfdd, 0xbfea, 0xbff6, 0xbf01, 0xbf0d, 0xbf19, 0xbf26, 0xbf35, 0xbf45, 0xbf59, 0xbf71,
+ 0xcf81, 0xcf98, 0xcfad, 0xcfbf, 0xcfcf, 0xcfdd, 0xcfea, 0xcff6, 0xcf01, 0xcf0d, 0xcf19, 0xcf26, 0xcf35, 0xcf45, 0xcf59, 0xcf71,
+ 0xdd81, 0xdd98, 0xddad, 0xddbf, 0xddcf, 0xdddd, 0xddea, 0xddf6, 0xdd01, 0xdd0d, 0xdd19, 0xdd26, 0xdd35, 0xdd45, 0xdd59, 0xdd71,
+ 0xea81, 0xea98, 0xeaad, 0xeabf, 0xeacf, 0xeadd, 0xeaea, 0xeaf6, 0xea01, 0xea0d, 0xea19, 0xea26, 0xea35, 0xea45, 0xea59, 0xea71,
+ 0xf681, 0xf698, 0xf6ad, 0xf6bf, 0xf6cf, 0xf6dd, 0xf6ea, 0xf6f6, 0xf601, 0xf60d, 0xf619, 0xf626, 0xf635, 0xf645, 0xf659, 0xf671,
+ 0x0181, 0x0198, 0x01ad, 0x01bf, 0x01cf, 0x01dd, 0x01ea, 0x01f6, 0x0101, 0x010d, 0x0119, 0x0126, 0x0135, 0x0145, 0x0159, 0x0171,
+ 0x0d81, 0x0d98, 0x0dad, 0x0dbf, 0x0dcf, 0x0ddd, 0x0dea, 0x0df6, 0x0d01, 0x0d0d, 0x0d19, 0x0d26, 0x0d35, 0x0d45, 0x0d59, 0x0d71,
+ 0x1981, 0x1998, 0x19ad, 0x19bf, 0x19cf, 0x19dd, 0x19ea, 0x19f6, 0x1901, 0x190d, 0x1919, 0x1926, 0x1935, 0x1945, 0x1959, 0x1971,
+ 0x2681, 0x2698, 0x26ad, 0x26bf, 0x26cf, 0x26dd, 0x26ea, 0x26f6, 0x2601, 0x260d, 0x2619, 0x2626, 0x2635, 0x2645, 0x2659, 0x2671,
+ 0x3581, 0x3598, 0x35ad, 0x35bf, 0x35cf, 0x35dd, 0x35ea, 0x35f6, 0x3501, 0x350d, 0x3519, 0x3526, 0x3535, 0x3545, 0x3559, 0x3571,
+ 0x4581, 0x4598, 0x45ad, 0x45bf, 0x45cf, 0x45dd, 0x45ea, 0x45f6, 0x4501, 0x450d, 0x4519, 0x4526, 0x4535, 0x4545, 0x4559, 0x4571,
+ 0x5981, 0x5998, 0x59ad, 0x59bf, 0x59cf, 0x59dd, 0x59ea, 0x59f6, 0x5901, 0x590d, 0x5919, 0x5926, 0x5935, 0x5945, 0x5959, 0x5971,
+ 0x7181, 0x7198, 0x71ad, 0x71bf, 0x71cf, 0x71dd, 0x71ea, 0x71f6, 0x7101, 0x710d, 0x7119, 0x7126, 0x7135, 0x7145, 0x7159, 0x7171,
+ 0x8585, 0x859c, 0x85b1, 0x85c3, 0x85d3, 0x85e1, 0x85ee, 0x85fa, 0x8505, 0x8511, 0x851d, 0x852a, 0x8539, 0x8549, 0x855d, 0x8575,
+ 0x9c85, 0x9c9c, 0x9cb1, 0x9cc3, 0x9cd3, 0x9ce1, 0x9cee, 0x9cfa, 0x9c05, 0x9c11, 0x9c1d, 0x9c2a, 0x9c39, 0x9c49, 0x9c5d, 0x9c75,
+ 0xb185, 0xb19c, 0xb1b1, 0xb1c3, 0xb1d3, 0xb1e1, 0xb1ee, 0xb1fa, 0xb105, 0xb111, 0xb11d, 0xb12a, 0xb139, 0xb149, 0xb15d, 0xb175,
+ 0xc385, 0xc39c, 0xc3b1, 0xc3c3, 0xc3d3, 0xc3e1, 0xc3ee, 0xc3fa, 0xc305, 0xc311, 0xc31d, 0xc32a, 0xc339, 0xc349, 0xc35d, 0xc375,
+ 0xd385, 0xd39c, 0xd3b1, 0xd3c3, 0xd3d3, 0xd3e1, 0xd3ee, 0xd3fa, 0xd305, 0xd311, 0xd31d, 0xd32a, 0xd339, 0xd349, 0xd35d, 0xd375,
+ 0xe185, 0xe19c, 0xe1b1, 0xe1c3, 0xe1d3, 0xe1e1, 0xe1ee, 0xe1fa, 0xe105, 0xe111, 0xe11d, 0xe12a, 0xe139, 0xe149, 0xe15d, 0xe175,
+ 0xee85, 0xee9c, 0xeeb1, 0xeec3, 0xeed3, 0xeee1, 0xeeee, 0xeefa, 0xee05, 0xee11, 0xee1d, 0xee2a, 0xee39, 0xee49, 0xee5d, 0xee75,
+ 0xfa85, 0xfa9c, 0xfab1, 0xfac3, 0xfad3, 0xfae1, 0xfaee, 0xfafa, 0xfa05, 0xfa11, 0xfa1d, 0xfa2a, 0xfa39, 0xfa49, 0xfa5d, 0xfa75,
+ 0x0585, 0x059c, 0x05b1, 0x05c3, 0x05d3, 0x05e1, 0x05ee, 0x05fa, 0x0505, 0x0511, 0x051d, 0x052a, 0x0539, 0x0549, 0x055d, 0x0575,
+ 0x1185, 0x119c, 0x11b1, 0x11c3, 0x11d3, 0x11e1, 0x11ee, 0x11fa, 0x1105, 0x1111, 0x111d, 0x112a, 0x1139, 0x1149, 0x115d, 0x1175,
+ 0x1d85, 0x1d9c, 0x1db1, 0x1dc3, 0x1dd3, 0x1de1, 0x1dee, 0x1dfa, 0x1d05, 0x1d11, 0x1d1d, 0x1d2a, 0x1d39, 0x1d49, 0x1d5d, 0x1d75,
+ 0x2a85, 0x2a9c, 0x2ab1, 0x2ac3, 0x2ad3, 0x2ae1, 0x2aee, 0x2afa, 0x2a05, 0x2a11, 0x2a1d, 0x2a2a, 0x2a39, 0x2a49, 0x2a5d, 0x2a75,
+ 0x3985, 0x399c, 0x39b1, 0x39c3, 0x39d3, 0x39e1, 0x39ee, 0x39fa, 0x3905, 0x3911, 0x391d, 0x392a, 0x3939, 0x3949, 0x395d, 0x3975,
+ 0x4985, 0x499c, 0x49b1, 0x49c3, 0x49d3, 0x49e1, 0x49ee, 0x49fa, 0x4905, 0x4911, 0x491d, 0x492a, 0x4939, 0x4949, 0x495d, 0x4975,
+ 0x5d85, 0x5d9c, 0x5db1, 0x5dc3, 0x5dd3, 0x5de1, 0x5dee, 0x5dfa, 0x5d05, 0x5d11, 0x5d1d, 0x5d2a, 0x5d39, 0x5d49, 0x5d5d, 0x5d75,
+ 0x7585, 0x759c, 0x75b1, 0x75c3, 0x75d3, 0x75e1, 0x75ee, 0x75fa, 0x7505, 0x7511, 0x751d, 0x752a, 0x7539, 0x7549, 0x755d, 0x7575,
+};
+
+__device__ __forceinline__ int int_from_table_x(const uint8_t * a8, const uint16_t * values) {
+ return values[a8[0] | (a8[1] << 4)] | (values[a8[2] | (a8[3] << 4)] << 16);
+}
+
diff --git a/ggml/src/ggml-cuda/iqk_mmvq.cu b/ggml/src/ggml-cuda/iqk_mmvq.cu
index 54d03f78..d897063f 100644
--- a/ggml/src/ggml-cuda/iqk_mmvq.cu
+++ b/ggml/src/ggml-cuda/iqk_mmvq.cu
@@ -950,21 +950,6 @@ __device__ __forceinline__ void vec_dot_iq2_k_r4_q8_1(
#define VDR_IQ3_K_Q8_1_MMVQ 4
#define VDR_IQ3_K_Q8_1_MMQ 4
-static const __device__ uint16_t iq3k_table[128] = {
- 0xc1c1, 0xc1d8, 0xc1e9, 0xc1f6, 0xc101, 0xc10d, 0xc11c, 0xc12f, 0xd8c1, 0xd8d8, 0xd8e9, 0xd8f6, 0xd801, 0xd80d, 0xd81c, 0xd82f,
- 0xe9c1, 0xe9d8, 0xe9e9, 0xe9f6, 0xe901, 0xe90d, 0xe91c, 0xe92f, 0xf6c1, 0xf6d8, 0xf6e9, 0xf6f6, 0xf601, 0xf60d, 0xf61c, 0xf62f,
- 0x01c1, 0x01d8, 0x01e9, 0x01f6, 0x0101, 0x010d, 0x011c, 0x012f, 0x0dc1, 0x0dd8, 0x0de9, 0x0df6, 0x0d01, 0x0d0d, 0x0d1c, 0x0d2f,
- 0x1cc1, 0x1cd8, 0x1ce9, 0x1cf6, 0x1c01, 0x1c0d, 0x1c1c, 0x1c2f, 0x2fc1, 0x2fd8, 0x2fe9, 0x2ff6, 0x2f01, 0x2f0d, 0x2f1c, 0x2f2f,
- 0xc5c5, 0xc5dc, 0xc5ed, 0xc5fa, 0xc505, 0xc511, 0xc520, 0xc533, 0xdcc5, 0xdcdc, 0xdced, 0xdcfa, 0xdc05, 0xdc11, 0xdc20, 0xdc33,
- 0xedc5, 0xeddc, 0xeded, 0xedfa, 0xed05, 0xed11, 0xed20, 0xed33, 0xfac5, 0xfadc, 0xfaed, 0xfafa, 0xfa05, 0xfa11, 0xfa20, 0xfa33,
- 0x05c5, 0x05dc, 0x05ed, 0x05fa, 0x0505, 0x0511, 0x0520, 0x0533, 0x11c5, 0x11dc, 0x11ed, 0x11fa, 0x1105, 0x1111, 0x1120, 0x1133,
- 0x20c5, 0x20dc, 0x20ed, 0x20fa, 0x2005, 0x2011, 0x2020, 0x2033, 0x33c5, 0x33dc, 0x33ed, 0x33fa, 0x3305, 0x3311, 0x3320, 0x3333,
-};
-
-__device__ __forceinline__ int int_from_table_2(const uint8_t * a8, const uint16_t * values) {
- return values[a8[0] | (a8[1] << 3)] | (values[a8[2] | (a8[3] << 3)] << 16);
-}
-
__device__ __forceinline__ void vec_dot_iq3_k_q8_1(
const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iiqs, float * result) {
const block_iq3_k * bq3 = (const block_iq3_k *) vbq + kbx;
diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh
index 8a87e3e8..ee34452a 100644
--- a/ggml/src/ggml-cuda/mmq.cuh
+++ b/ggml/src/ggml-cuda/mmq.cuh
@@ -2436,6 +2436,9 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
const int kbx = 0; // threadIdx.x / QI4_XS
const int kqsx = threadIdx.x; // threadIdx.x % QI4_XS
+ uint32_t aux32[2];
+ auto a8 = (const uint8_t *)aux32;
+
#pragma unroll
for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
int i = i0 + threadIdx.y;
@@ -2446,15 +2449,16 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
const block_iq4_xs * bxi = (const block_iq4_xs *)(x + i*stride) + kbx0 + kbx;
- const int aux_q4 = get_int_b4(bxi->qs, kqsx);
- const int2 v = get_int_from_table_16(aux_q4);
+ const int q4 = get_int_b4(bxi->qs, kqsx);
+ aux32[0] = (q4 >> 0) & 0x0f0f0f0f;
+ aux32[1] = (q4 >> 4) & 0x0f0f0f0f;
const int k0 = 8 * (threadIdx.x / 4) + threadIdx.x % 4;
#ifdef INT8_MMA_AVAILABLE
- x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x;
- x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 4] = v.y;
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = int_from_table_x(a8+0, iq4k_table);
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 4] = int_from_table_x(a8+4, iq4k_table);
#else
- x_qs[i*(2*WARP_SIZE + 1) + k0 + 0] = v.x;
- x_qs[i*(2*WARP_SIZE + 1) + k0 + 4] = v.y;
+ x_qs[i*(2*WARP_SIZE + 1) + k0 + 0] = int_from_table_x(a8+0, iq4k_table);
+ x_qs[i*(2*WARP_SIZE + 1) + k0 + 4] = int_from_table_x(a8+4, iq4k_table);
#endif // INT8_MMA_AVAILABLE
}
@@ -2623,8 +2627,6 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
constexpr int qstep = 8;
const int kqsx = threadIdx.x % qstep;
- auto values = iq3nl_values;
-
uint32_t aux32[4];
const uint8_t * aux8 = (const uint8_t *)aux32;
#pragma unroll
@@ -2646,57 +2648,43 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
for (int l = 0; l < qstep/4; ++l) {
const int ql = get_int_b2(bxi->qs, kqsx + qstep*l);
- aux32[0] = ((ql >> 0) & 0x03030303) | ((qh << 2) & 0x04040404) | (((extra << 3) & 8) * 0x01010101);
- aux32[1] = ((ql >> 2) & 0x03030303) | ((qh << 1) & 0x04040404) | (((extra << 1) & 8) * 0x01010101);
- aux32[2] = ((ql >> 4) & 0x03030303) | ((qh >> 0) & 0x04040404) | (((extra >> 1) & 8) * 0x01010101);
- aux32[3] = ((ql >> 6) & 0x03030303) | ((qh >> 1) & 0x04040404) | (((extra >> 3) & 8) * 0x01010101);
+ aux32[0] = ((ql >> 0) & 0x03030303) | ((qh << 2) & 0x04040404);
+ aux32[1] = ((ql >> 2) & 0x03030303) | ((qh << 1) & 0x04040404);
+ aux32[2] = ((ql >> 4) & 0x03030303) | ((qh >> 0) & 0x04040404);
+ aux32[3] = ((ql >> 6) & 0x03030303) | ((qh >> 1) & 0x04040404);
+
+ const int val0 = int_from_table_2(aux8+ 0, iq3k_table + ((extra << 6) & 0x40));
+ const int val1 = int_from_table_2(aux8+ 4, iq3k_table + ((extra << 4) & 0x40));
+ const int val2 = int_from_table_2(aux8+ 8, iq3k_table + ((extra << 2) & 0x40));
+ const int val3 = int_from_table_2(aux8+12, iq3k_table + ((extra << 0) & 0x40));
+
extra >>= 8;
qh >>= 4;
- const char4 val0 = make_char4(values[aux8[ 0]], values[aux8[ 1]], values[aux8[ 2]], values[aux8[ 3]]);
- const char4 val1 = make_char4(values[aux8[ 4]], values[aux8[ 5]], values[aux8[ 6]], values[aux8[ 7]]);
- const char4 val2 = make_char4(values[aux8[ 8]], values[aux8[ 9]], values[aux8[10]], values[aux8[11]]);
- const char4 val3 = make_char4(values[aux8[12]], values[aux8[13]], values[aux8[14]], values[aux8[15]]);
-
#ifdef INT8_MMA_AVAILABLE
- x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + kqsx + 32*l + 0] = *(const int *)&val0;
- x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + kqsx + 32*l + 8] = *(const int *)&val1;
- x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + kqsx + 32*l + 16] = *(const int *)&val2;
- x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + kqsx + 32*l + 24] = *(const int *)&val3;
+ x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + kqsx + 32*l + 0] = val0;
+ x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + kqsx + 32*l + 8] = val1;
+ x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + kqsx + 32*l + 16] = val2;
+ x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + kqsx + 32*l + 24] = val3;
#else
- x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 0] = *(const int *)&val0;
- x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 8] = *(const int *)&val1;
- x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 16] = *(const int *)&val2;
- x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 24] = *(const int *)&val3;
+ x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 0] = val0;
+ x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 8] = val1;
+ x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 16] = val2;
+ x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 24] = val3;
#endif // INT8_MMA_AVAILABLE
}
uint16_t sh = bxi->scales_h >> 2*kqsx;
#ifdef INT8_MMA_AVAILABLE
- x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+0] = d * ((2*((bxi->scales_l[kqsx] >> 0) & 0xf) + 1) * (sh & 1 ? -1 : 1));
- x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+1] = d * ((2*((bxi->scales_l[kqsx] >> 4) & 0xf) + 1) * (sh & 2 ? -1 : 1));
+ x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+0] = d * ((2*(bxi->scales_l[kqsx] & 0xf) + 1) * (sh & 1 ? -1 : 1));
+ x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+1] = d * ((2*(bxi->scales_l[kqsx] >> 4) + 1) * (sh & 2 ? -1 : 1));
#else
- x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = d * ((2*((bxi->scales_l[kqsx] >> 0) & 0xf) + 1) * (sh & 1 ? -1 : 1));
- x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = d * ((2*((bxi->scales_l[kqsx] >> 4) & 0xf) + 1) * (sh & 2 ? -1 : 1));
+ x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = d * ((2*(bxi->scales_l[kqsx] & 0xf) + 1) * (sh & 1 ? -1 : 1));
+ x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = d * ((2*(bxi->scales_l[kqsx] >> 4) + 1) * (sh & 2 ? -1 : 1));
#endif // INT8_MMA_AVAILABLE
}
}
-static const __device__ uint16_t iq3k_table[128] = {
- 0xc1c1, 0xc1d8, 0xc1e9, 0xc1f6, 0xc101, 0xc10d, 0xc11c, 0xc12f, 0xd8c1, 0xd8d8, 0xd8e9, 0xd8f6, 0xd801, 0xd80d, 0xd81c, 0xd82f,
- 0xe9c1, 0xe9d8, 0xe9e9, 0xe9f6, 0xe901, 0xe90d, 0xe91c, 0xe92f, 0xf6c1, 0xf6d8, 0xf6e9, 0xf6f6, 0xf601, 0xf60d, 0xf61c, 0xf62f,
- 0x01c1, 0x01d8, 0x01e9, 0x01f6, 0x0101, 0x010d, 0x011c, 0x012f, 0x0dc1, 0x0dd8, 0x0de9, 0x0df6, 0x0d01, 0x0d0d, 0x0d1c, 0x0d2f,
- 0x1cc1, 0x1cd8, 0x1ce9, 0x1cf6, 0x1c01, 0x1c0d, 0x1c1c, 0x1c2f, 0x2fc1, 0x2fd8, 0x2fe9, 0x2ff6, 0x2f01, 0x2f0d, 0x2f1c, 0x2f2f,
- 0xc5c5, 0xc5dc, 0xc5ed, 0xc5fa, 0xc505, 0xc511, 0xc520, 0xc533, 0xdcc5, 0xdcdc, 0xdced, 0xdcfa, 0xdc05, 0xdc11, 0xdc20, 0xdc33,
- 0xedc5, 0xeddc, 0xeded, 0xedfa, 0xed05, 0xed11, 0xed20, 0xed33, 0xfac5, 0xfadc, 0xfaed, 0xfafa, 0xfa05, 0xfa11, 0xfa20, 0xfa33,
- 0x05c5, 0x05dc, 0x05ed, 0x05fa, 0x0505, 0x0511, 0x0520, 0x0533, 0x11c5, 0x11dc, 0x11ed, 0x11fa, 0x1105, 0x1111, 0x1120, 0x1133,
- 0x20c5, 0x20dc, 0x20ed, 0x20fa, 0x2005, 0x2011, 0x2020, 0x2033, 0x33c5, 0x33dc, 0x33ed, 0x33fa, 0x3305, 0x3311, 0x3320, 0x3333,
-};
-
-__device__ __forceinline__ int int_from_table_2(const uint8_t * a8, const uint16_t * values) {
- return values[a8[0] | (a8[1] << 3)] | (values[a8[2] | (a8[3] << 3)] << 16);
-}
-
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq3_ks(
const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
@@ -2781,6 +2769,9 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
const int kqsx = threadIdx.x / 4;
+ uint32_t aux32[2];
+ auto a8 = (const uint8_t *)aux32;
+
#pragma unroll
for (int i0 = 0; i0 < mmq_y; i0 += 4*nwarps) {
int i = i0 + 4*threadIdx.y + threadIdx.x%4;
@@ -2792,18 +2783,20 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
const float * dptr = (const float *)(x + i*stride);
const block_iq4_ks * bxi = (const block_iq4_ks *)(dptr + 1) + kbx0;
const int ls = (bxi->scales[kqsx] & 254) - 127;
- auto values = iq4k_values + ((bxi->scales[kqsx] & 1) << 4);
+
+ auto values = iq4k_table + ((bxi->scales[kqsx] & 1) << 8);
#pragma unroll
for (int j = 0; j < 4; ++j) {
- const int aux_q4 = get_int_b4(bxi->qs, 4*kqsx+j);
- const int2 v = get_int_from_table_16(aux_q4, values);
+ const int q4 = get_int_b4(bxi->qs, 4*kqsx+j);
+ aux32[0] = (q4 >> 0) & 0x0f0f0f0f;
+ aux32[1] = (q4 >> 4) & 0x0f0f0f0f;
#ifdef INT8_MMA_AVAILABLE
- x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + j + 0] = v.x;
- x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + j + 4] = v.y;
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + j + 0] = int_from_table_x(a8+0, values);
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + j + 4] = int_from_table_x(a8+4, values);
#else
- x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + j + 0] = v.x;
- x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + j + 4] = v.y;
+ x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + j + 0] = int_from_table_x(a8+0, values);
+ x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + j + 4] = int_from_table_x(a8+4, values);
#endif // INT8_MMA_AVAILABLE
}
#ifdef INT8_MMA_AVAILABLE
@@ -2830,6 +2823,9 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
const int kqsx = threadIdx.x/4;
+ uint32_t aux32[2];
+ const uint8_t * a8 = (const uint8_t *)aux32;
+
#pragma unroll
for (int i0 = 0; i0 < mmq_y; i0 += 4*nwarps) {
int i = i0 + 4*threadIdx.y + threadIdx.x%4;
@@ -2844,18 +2840,19 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
const block_iq4_ks_r4 * bxi = (const block_iq4_ks_r4 *)(dptr + 4) + kbx0;
const int ls = (bxi->scales[4*kqsx + ir] & 254) - 127;
- auto values = iq4k_values + ((bxi->scales[4*kqsx+ir] & 1) << 4);
+ auto values = iq4k_table + ((bxi->scales[4*kqsx+ir] & 1) << 8);
#pragma unroll
for (int j = 0; j < 4; ++j) {
const int q4 = get_int_b4(bxi->qs, 16*kqsx+4*j+ir);
- const int2 v = get_int_from_table_16(q4, values);
+ aux32[0] = (q4 >> 0) & 0x0f0f0f0f;
+ aux32[1] = (q4 >> 4) & 0x0f0f0f0f;
const int k0 = 8*kqsx + 4*(j%2) + j/2;
#ifdef INT8_MMA_AVAILABLE
- x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x;
- x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 2] = v.y;
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = int_from_table_x(a8+0, values);
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 2] = int_from_table_x(a8+4, values);
#else
- x_qs[i*(2*WARP_SIZE + 1) + k0 + 0] = v.x;
- x_qs[i*(2*WARP_SIZE + 1) + k0 + 2] = v.y;
+ x_qs[i*(2*WARP_SIZE + 1) + k0 + 0] = int_from_table_x(a8+0, values);
+ x_qs[i*(2*WARP_SIZE + 1) + k0 + 2] = int_from_table_x(a8+4, values);
#endif // INT8_MMA_AVAILABLE
}
#ifdef INT8_MMA_AVAILABLE
@@ -3177,25 +3174,26 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
const block_iq4_k * bxi = (const block_iq4_k *)(x + i*stride) + kbx0;
const uint16_t extra = bxi->extra >> 2*kqsx;
- auto values_l = iq4k_values + ((extra & 1) << 4);
- auto values_h = iq4k_values + ((extra & 2) << 3);
+ auto values_l = iq4k_table + ((extra & 1) << 8);
+ auto values_h = iq4k_table + ((extra & 2) << 7);
#pragma unroll
for (int l = 0; l < qstep/2; ++l) {
const int q4 = get_int_b4(bxi->qs, (qstep/2)*kqsx + l);
+
aux32[0] = (q4 >> 0) & 0x0f0f0f0f;
aux32[1] = (q4 >> 4) & 0x0f0f0f0f;
- const char4 val0 = make_char4(values_l[aux8[0]], values_l[aux8[1]], values_l[aux8[2]], values_l[aux8[3]]);
- const char4 val1 = make_char4(values_h[aux8[4]], values_h[aux8[5]], values_h[aux8[6]], values_h[aux8[7]]);
+ int val0 = int_from_table_x(aux8+0, values_l);
+ int val1 = int_from_table_x(aux8+4, values_h);
#ifdef INT8_MMA_AVAILABLE
- x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + l + 0] = *(const int *)&val0;
- x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + l + 4] = *(const int *)&val1;
+ x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + l + 0] = val0;
+ x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + l + 4] = val1;
#else
- x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + l + 0] = *(const int *)&val0;
- x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + l + 4] = *(const int *)&val1;
+ x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + l + 0] = val0;
+ x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + l + 4] = val1;
#endif // INT8_MMA_AVAILABLE
}
diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_k_r4.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_k_r4.cu
index 7b6096a3..a588969f 100644
--- a/ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_k_r4.cu
+++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_k_r4.cu
@@ -37,7 +37,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
#pragma unroll
for (int l = 0; l < 2; ++l) {
- auto values_l = iq3nl_values + (((bxi->extra[ir+4*l] >> kqsx) & 1) << 3);
+ auto values_l = iq3k_table + (((bxi->extra[ir+4*l] >> kqsx) & 1) << 6);
const int ql = get_int_b4(bxi->qs, 8*kqsx + ir + 4*l);
aux32[0] = ((ql >> 0) & 0x03030303) | ((qh << 2) & 0x04040404);
@@ -45,21 +45,21 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
aux32[2] = ((ql >> 4) & 0x03030303) | ((qh >> 0) & 0x04040404);
aux32[3] = ((ql >> 6) & 0x03030303) | ((qh >> 1) & 0x04040404);
- const char4 val0 = make_char4(values_l[aux8[ 0]], values_l[aux8[ 1]], values_l[aux8[ 2]], values_l[aux8[ 3]]);
- const char4 val1 = make_char4(values_l[aux8[ 4]], values_l[aux8[ 5]], values_l[aux8[ 6]], values_l[aux8[ 7]]);
- const char4 val2 = make_char4(values_l[aux8[ 8]], values_l[aux8[ 9]], values_l[aux8[10]], values_l[aux8[11]]);
- const char4 val3 = make_char4(values_l[aux8[12]], values_l[aux8[13]], values_l[aux8[14]], values_l[aux8[15]]);
+ int val0 = int_from_table_2(aux8+ 0, values_l);
+ int val1 = int_from_table_2(aux8+ 4, values_l);
+ int val2 = int_from_table_2(aux8+ 8, values_l);
+ int val3 = int_from_table_2(aux8+12, values_l);
#ifdef INT8_MMA_AVAILABLE
- x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 0] = *(const int *)&val0;
- x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 1] = *(const int *)&val1;
- x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 2] = *(const int *)&val2;
- x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 3] = *(const int *)&val3;
+ x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 0] = val0;
+ x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 1] = val1;
+ x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 2] = val2;
+ x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 3] = val3;
#else
- x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 0] = *(const int *)&val0;
- x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 1] = *(const int *)&val1;
- x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 2] = *(const int *)&val2;
- x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 3] = *(const int *)&val3;
+ x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 0] = val0;
+ x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 1] = val1;
+ x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 2] = val2;
+ x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 3] = val3;
#endif // INT8_MMA_AVAILABLE
qh >>= 4;
diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_k_r4.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_k_r4.cu
index 2fc314e3..a9dd85d4 100644
--- a/ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_k_r4.cu
+++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_k_r4.cu
@@ -35,7 +35,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
#pragma unroll
for (int l = 0; l < 2; ++l) {
- auto values_l = iq4k_values + (((bxi->extra[ir+4*l] >> kqsx) & 1) << 4);
+ auto values_l = iq4k_table + ((bxi->extra[ir+4*l] << (8 - kqsx)) & 0x100);
const int ql1 = get_int_b4(bxi->qs, 16*kqsx + ir + 4*l + 0);
const int ql2 = get_int_b4(bxi->qs, 16*kqsx + ir + 4*l + 8);
@@ -44,21 +44,16 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
aux32[2] = (ql2 >> 0) & 0x0f0f0f0f;
aux32[3] = (ql2 >> 4) & 0x0f0f0f0f;
- const char4 val0 = make_char4(values_l[aux8[ 0]], values_l[aux8[ 1]], values_l[aux8[ 2]], values_l[aux8[ 3]]);
- const char4 val1 = make_char4(values_l[aux8[ 4]], values_l[aux8[ 5]], values_l[aux8[ 6]], values_l[aux8[ 7]]);
- const char4 val2 = make_char4(values_l[aux8[ 8]], values_l[aux8[ 9]], values_l[aux8[10]], values_l[aux8[11]]);
- const char4 val3 = make_char4(values_l[aux8[12]], values_l[aux8[13]], values_l[aux8[14]], values_l[aux8[15]]);
-
#ifdef INT8_MMA_AVAILABLE
- x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 0] = *(const int *)&val0;
- x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 2] = *(const int *)&val1;
- x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 1] = *(const int *)&val2;
- x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 3] = *(const int *)&val3;
+ x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 0] = int_from_table_x(aux8+ 0, values_l);
+ x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 2] = int_from_table_x(aux8+ 4, values_l);
+ x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 1] = int_from_table_x(aux8+ 8, values_l);
+ x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 3] = int_from_table_x(aux8+12, values_l);
#else
- x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 0] = *(const int *)&val0;
- x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 2] = *(const int *)&val1;
- x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 1] = *(const int *)&val2;
- x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 3] = *(const int *)&val3;
+ x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 0] = int_from_table_x(aux8+ 0, values_l);
+ x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 2] = int_from_table_x(aux8+ 4, values_l);
+ x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 1] = int_from_table_x(aux8+ 8, values_l);
+ x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 3] = int_from_table_x(aux8+12, values_l);
#endif // INT8_MMA_AVAILABLE
}