diff options
author | Kawrakow <iwankawrakow@gmail.com> | 2025-07-08 19:44:48 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2025-07-08 19:44:48 +0200 |
commit | 97c34f4056067e167ed4508366f74b49e60202f7 (patch) | |
tree | 4ccdcc9ab35a3544c53aae6de72355d3c4603c08 | |
parent | 4c0b66026619cf51f45249181bf2cc1de8cd6884 (diff) |
Faster prompt processing for IQ2_KS, IQ2_K, IQ2_K_R4 (#593)
* cuda: faster MMQ for iq2_ks, iq2_k, iq2_k_r4
* Lookup is still beter for MMQ if we get 4 values at once
* Minor
---------
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
-rw-r--r-- | ggml/src/ggml-cuda/iqk_cuda_common.h | 75 | ||||
-rw-r--r-- | ggml/src/ggml-cuda/iqk_mmvq.cu | 116 | ||||
-rw-r--r-- | ggml/src/ggml-cuda/mmq.cu | 4 | ||||
-rw-r--r-- | ggml/src/ggml-cuda/mmq.cuh | 67 | ||||
-rw-r--r-- | ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_k_r4.cu | 31 |
5 files changed, 133 insertions, 160 deletions
diff --git a/ggml/src/ggml-cuda/iqk_cuda_common.h b/ggml/src/ggml-cuda/iqk_cuda_common.h new file mode 100644 index 00000000..95d9b40e --- /dev/null +++ b/ggml/src/ggml-cuda/iqk_cuda_common.h @@ -0,0 +1,75 @@ +#pragma once + +#include "common.cuh" + +static const __device__ uint32_t iq2k_table[512] = { + 0xe1e1e1e1, 0xe1e1e1f3, 0xe1e1e101, 0xe1e1e111, 0xe1e1f3e1, 0xe1e1f3f3, 0xe1e1f301, 0xe1e1f311, + 0xe1e101e1, 0xe1e101f3, 0xe1e10101, 0xe1e10111, 0xe1e111e1, 0xe1e111f3, 0xe1e11101, 0xe1e11111, + 0xe1f3e1e1, 0xe1f3e1f3, 0xe1f3e101, 0xe1f3e111, 0xe1f3f3e1, 0xe1f3f3f3, 0xe1f3f301, 0xe1f3f311, + 0xe1f301e1, 0xe1f301f3, 0xe1f30101, 0xe1f30111, 0xe1f311e1, 0xe1f311f3, 0xe1f31101, 0xe1f31111, + 0xe101e1e1, 0xe101e1f3, 0xe101e101, 0xe101e111, 0xe101f3e1, 0xe101f3f3, 0xe101f301, 0xe101f311, + 0xe10101e1, 0xe10101f3, 0xe1010101, 0xe1010111, 0xe10111e1, 0xe10111f3, 0xe1011101, 0xe1011111, + 0xe111e1e1, 0xe111e1f3, 0xe111e101, 0xe111e111, 0xe111f3e1, 0xe111f3f3, 0xe111f301, 0xe111f311, + 0xe11101e1, 0xe11101f3, 0xe1110101, 0xe1110111, 0xe11111e1, 0xe11111f3, 0xe1111101, 0xe1111111, + 0xf3e1e1e1, 0xf3e1e1f3, 0xf3e1e101, 0xf3e1e111, 0xf3e1f3e1, 0xf3e1f3f3, 0xf3e1f301, 0xf3e1f311, + 0xf3e101e1, 0xf3e101f3, 0xf3e10101, 0xf3e10111, 0xf3e111e1, 0xf3e111f3, 0xf3e11101, 0xf3e11111, + 0xf3f3e1e1, 0xf3f3e1f3, 0xf3f3e101, 0xf3f3e111, 0xf3f3f3e1, 0xf3f3f3f3, 0xf3f3f301, 0xf3f3f311, + 0xf3f301e1, 0xf3f301f3, 0xf3f30101, 0xf3f30111, 0xf3f311e1, 0xf3f311f3, 0xf3f31101, 0xf3f31111, + 0xf301e1e1, 0xf301e1f3, 0xf301e101, 0xf301e111, 0xf301f3e1, 0xf301f3f3, 0xf301f301, 0xf301f311, + 0xf30101e1, 0xf30101f3, 0xf3010101, 0xf3010111, 0xf30111e1, 0xf30111f3, 0xf3011101, 0xf3011111, + 0xf311e1e1, 0xf311e1f3, 0xf311e101, 0xf311e111, 0xf311f3e1, 0xf311f3f3, 0xf311f301, 0xf311f311, + 0xf31101e1, 0xf31101f3, 0xf3110101, 0xf3110111, 0xf31111e1, 0xf31111f3, 0xf3111101, 0xf3111111, + 0x01e1e1e1, 0x01e1e1f3, 0x01e1e101, 0x01e1e111, 0x01e1f3e1, 0x01e1f3f3, 0x01e1f301, 0x01e1f311, + 0x01e101e1, 0x01e101f3, 0x01e10101, 0x01e10111, 0x01e111e1, 0x01e111f3, 0x01e11101, 0x01e11111, + 0x01f3e1e1, 0x01f3e1f3, 0x01f3e101, 0x01f3e111, 0x01f3f3e1, 0x01f3f3f3, 0x01f3f301, 0x01f3f311, + 0x01f301e1, 0x01f301f3, 0x01f30101, 0x01f30111, 0x01f311e1, 0x01f311f3, 0x01f31101, 0x01f31111, + 0x0101e1e1, 0x0101e1f3, 0x0101e101, 0x0101e111, 0x0101f3e1, 0x0101f3f3, 0x0101f301, 0x0101f311, + 0x010101e1, 0x010101f3, 0x01010101, 0x01010111, 0x010111e1, 0x010111f3, 0x01011101, 0x01011111, + 0x0111e1e1, 0x0111e1f3, 0x0111e101, 0x0111e111, 0x0111f3e1, 0x0111f3f3, 0x0111f301, 0x0111f311, + 0x011101e1, 0x011101f3, 0x01110101, 0x01110111, 0x011111e1, 0x011111f3, 0x01111101, 0x01111111, + 0x11e1e1e1, 0x11e1e1f3, 0x11e1e101, 0x11e1e111, 0x11e1f3e1, 0x11e1f3f3, 0x11e1f301, 0x11e1f311, + 0x11e101e1, 0x11e101f3, 0x11e10101, 0x11e10111, 0x11e111e1, 0x11e111f3, 0x11e11101, 0x11e11111, + 0x11f3e1e1, 0x11f3e1f3, 0x11f3e101, 0x11f3e111, 0x11f3f3e1, 0x11f3f3f3, 0x11f3f301, 0x11f3f311, + 0x11f301e1, 0x11f301f3, 0x11f30101, 0x11f30111, 0x11f311e1, 0x11f311f3, 0x11f31101, 0x11f31111, + 0x1101e1e1, 0x1101e1f3, 0x1101e101, 0x1101e111, 0x1101f3e1, 0x1101f3f3, 0x1101f301, 0x1101f311, + 0x110101e1, 0x110101f3, 0x11010101, 0x11010111, 0x110111e1, 0x110111f3, 0x11011101, 0x11011111, + 0x1111e1e1, 0x1111e1f3, 0x1111e101, 0x1111e111, 0x1111f3e1, 0x1111f3f3, 0x1111f301, 0x1111f311, + 0x111101e1, 0x111101f3, 0x11110101, 0x11110111, 0x111111e1, 0x111111f3, 0x11111101, 0x11111111, + 0xe6e6e6e6, 0xe6e6e6f8, 0xe6e6e606, 0xe6e6e616, 0xe6e6f8e6, 0xe6e6f8f8, 0xe6e6f806, 0xe6e6f816, + 0xe6e606e6, 0xe6e606f8, 0xe6e60606, 0xe6e60616, 0xe6e616e6, 0xe6e616f8, 0xe6e61606, 0xe6e61616, + 0xe6f8e6e6, 0xe6f8e6f8, 0xe6f8e606, 0xe6f8e616, 0xe6f8f8e6, 0xe6f8f8f8, 0xe6f8f806, 0xe6f8f816, + 0xe6f806e6, 0xe6f806f8, 0xe6f80606, 0xe6f80616, 0xe6f816e6, 0xe6f816f8, 0xe6f81606, 0xe6f81616, + 0xe606e6e6, 0xe606e6f8, 0xe606e606, 0xe606e616, 0xe606f8e6, 0xe606f8f8, 0xe606f806, 0xe606f816, + 0xe60606e6, 0xe60606f8, 0xe6060606, 0xe6060616, 0xe60616e6, 0xe60616f8, 0xe6061606, 0xe6061616, + 0xe616e6e6, 0xe616e6f8, 0xe616e606, 0xe616e616, 0xe616f8e6, 0xe616f8f8, 0xe616f806, 0xe616f816, + 0xe61606e6, 0xe61606f8, 0xe6160606, 0xe6160616, 0xe61616e6, 0xe61616f8, 0xe6161606, 0xe6161616, + 0xf8e6e6e6, 0xf8e6e6f8, 0xf8e6e606, 0xf8e6e616, 0xf8e6f8e6, 0xf8e6f8f8, 0xf8e6f806, 0xf8e6f816, + 0xf8e606e6, 0xf8e606f8, 0xf8e60606, 0xf8e60616, 0xf8e616e6, 0xf8e616f8, 0xf8e61606, 0xf8e61616, + 0xf8f8e6e6, 0xf8f8e6f8, 0xf8f8e606, 0xf8f8e616, 0xf8f8f8e6, 0xf8f8f8f8, 0xf8f8f806, 0xf8f8f816, + 0xf8f806e6, 0xf8f806f8, 0xf8f80606, 0xf8f80616, 0xf8f816e6, 0xf8f816f8, 0xf8f81606, 0xf8f81616, + 0xf806e6e6, 0xf806e6f8, 0xf806e606, 0xf806e616, 0xf806f8e6, 0xf806f8f8, 0xf806f806, 0xf806f816, + 0xf80606e6, 0xf80606f8, 0xf8060606, 0xf8060616, 0xf80616e6, 0xf80616f8, 0xf8061606, 0xf8061616, + 0xf816e6e6, 0xf816e6f8, 0xf816e606, 0xf816e616, 0xf816f8e6, 0xf816f8f8, 0xf816f806, 0xf816f816, + 0xf81606e6, 0xf81606f8, 0xf8160606, 0xf8160616, 0xf81616e6, 0xf81616f8, 0xf8161606, 0xf8161616, + 0x06e6e6e6, 0x06e6e6f8, 0x06e6e606, 0x06e6e616, 0x06e6f8e6, 0x06e6f8f8, 0x06e6f806, 0x06e6f816, + 0x06e606e6, 0x06e606f8, 0x06e60606, 0x06e60616, 0x06e616e6, 0x06e616f8, 0x06e61606, 0x06e61616, + 0x06f8e6e6, 0x06f8e6f8, 0x06f8e606, 0x06f8e616, 0x06f8f8e6, 0x06f8f8f8, 0x06f8f806, 0x06f8f816, + 0x06f806e6, 0x06f806f8, 0x06f80606, 0x06f80616, 0x06f816e6, 0x06f816f8, 0x06f81606, 0x06f81616, + 0x0606e6e6, 0x0606e6f8, 0x0606e606, 0x0606e616, 0x0606f8e6, 0x0606f8f8, 0x0606f806, 0x0606f816, + 0x060606e6, 0x060606f8, 0x06060606, 0x06060616, 0x060616e6, 0x060616f8, 0x06061606, 0x06061616, + 0x0616e6e6, 0x0616e6f8, 0x0616e606, 0x0616e616, 0x0616f8e6, 0x0616f8f8, 0x0616f806, 0x0616f816, + 0x061606e6, 0x061606f8, 0x06160606, 0x06160616, 0x061616e6, 0x061616f8, 0x06161606, 0x06161616, + 0x16e6e6e6, 0x16e6e6f8, 0x16e6e606, 0x16e6e616, 0x16e6f8e6, 0x16e6f8f8, 0x16e6f806, 0x16e6f816, + 0x16e606e6, 0x16e606f8, 0x16e60606, 0x16e60616, 0x16e616e6, 0x16e616f8, 0x16e61606, 0x16e61616, + 0x16f8e6e6, 0x16f8e6f8, 0x16f8e606, 0x16f8e616, 0x16f8f8e6, 0x16f8f8f8, 0x16f8f806, 0x16f8f816, + 0x16f806e6, 0x16f806f8, 0x16f80606, 0x16f80616, 0x16f816e6, 0x16f816f8, 0x16f81606, 0x16f81616, + 0x1606e6e6, 0x1606e6f8, 0x1606e606, 0x1606e616, 0x1606f8e6, 0x1606f8f8, 0x1606f806, 0x1606f816, + 0x160606e6, 0x160606f8, 0x16060606, 0x16060616, 0x160616e6, 0x160616f8, 0x16061606, 0x16061616, + 0x1616e6e6, 0x1616e6f8, 0x1616e606, 0x1616e616, 0x1616f8e6, 0x1616f8f8, 0x1616f806, 0x1616f816, + 0x161606e6, 0x161606f8, 0x16160606, 0x16160616, 0x161616e6, 0x161616f8, 0x16161606, 0x16161616, +}; + +__device__ __forceinline__ int int_from_table_4(const uint32_t idx, const int * values) { + return values[ggml_cuda_dp4a(idx, 0x40100401, 0)]; +} + diff --git a/ggml/src/ggml-cuda/iqk_mmvq.cu b/ggml/src/ggml-cuda/iqk_mmvq.cu index 3b1f6acb..54d03f78 100644 --- a/ggml/src/ggml-cuda/iqk_mmvq.cu +++ b/ggml/src/ggml-cuda/iqk_mmvq.cu @@ -5,6 +5,7 @@ // #include "iqk_mmvq.cuh" +#include "iqk_cuda_common.h" typedef void (*vec_dot_q_cuda_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs, float *); @@ -785,77 +786,6 @@ __device__ __forceinline__ void vec_dot_iq6_k_q8_1( *result += d6 * (__low2float(bq8_1[2*(i4/2)+0].ds) * sumi1 * bq6->scales[4*(i4/2)+(i4%2)] + __low2float(bq8_1[2*(i4/2)+1].ds) * sumi2 * bq6->scales[4*(i4/2)+(i4%2)+2]); } -static const __device__ uint32_t iq2k_table[512] = { - 0xe1e1e1e1, 0xe1e1e1f3, 0xe1e1e101, 0xe1e1e111, 0xe1e1f3e1, 0xe1e1f3f3, 0xe1e1f301, 0xe1e1f311, - 0xe1e101e1, 0xe1e101f3, 0xe1e10101, 0xe1e10111, 0xe1e111e1, 0xe1e111f3, 0xe1e11101, 0xe1e11111, - 0xe1f3e1e1, 0xe1f3e1f3, 0xe1f3e101, 0xe1f3e111, 0xe1f3f3e1, 0xe1f3f3f3, 0xe1f3f301, 0xe1f3f311, - 0xe1f301e1, 0xe1f301f3, 0xe1f30101, 0xe1f30111, 0xe1f311e1, 0xe1f311f3, 0xe1f31101, 0xe1f31111, - 0xe101e1e1, 0xe101e1f3, 0xe101e101, 0xe101e111, 0xe101f3e1, 0xe101f3f3, 0xe101f301, 0xe101f311, - 0xe10101e1, 0xe10101f3, 0xe1010101, 0xe1010111, 0xe10111e1, 0xe10111f3, 0xe1011101, 0xe1011111, - 0xe111e1e1, 0xe111e1f3, 0xe111e101, 0xe111e111, 0xe111f3e1, 0xe111f3f3, 0xe111f301, 0xe111f311, - 0xe11101e1, 0xe11101f3, 0xe1110101, 0xe1110111, 0xe11111e1, 0xe11111f3, 0xe1111101, 0xe1111111, - 0xf3e1e1e1, 0xf3e1e1f3, 0xf3e1e101, 0xf3e1e111, 0xf3e1f3e1, 0xf3e1f3f3, 0xf3e1f301, 0xf3e1f311, - 0xf3e101e1, 0xf3e101f3, 0xf3e10101, 0xf3e10111, 0xf3e111e1, 0xf3e111f3, 0xf3e11101, 0xf3e11111, - 0xf3f3e1e1, 0xf3f3e1f3, 0xf3f3e101, 0xf3f3e111, 0xf3f3f3e1, 0xf3f3f3f3, 0xf3f3f301, 0xf3f3f311, - 0xf3f301e1, 0xf3f301f3, 0xf3f30101, 0xf3f30111, 0xf3f311e1, 0xf3f311f3, 0xf3f31101, 0xf3f31111, - 0xf301e1e1, 0xf301e1f3, 0xf301e101, 0xf301e111, 0xf301f3e1, 0xf301f3f3, 0xf301f301, 0xf301f311, - 0xf30101e1, 0xf30101f3, 0xf3010101, 0xf3010111, 0xf30111e1, 0xf30111f3, 0xf3011101, 0xf3011111, - 0xf311e1e1, 0xf311e1f3, 0xf311e101, 0xf311e111, 0xf311f3e1, 0xf311f3f3, 0xf311f301, 0xf311f311, - 0xf31101e1, 0xf31101f3, 0xf3110101, 0xf3110111, 0xf31111e1, 0xf31111f3, 0xf3111101, 0xf3111111, - 0x01e1e1e1, 0x01e1e1f3, 0x01e1e101, 0x01e1e111, 0x01e1f3e1, 0x01e1f3f3, 0x01e1f301, 0x01e1f311, - 0x01e101e1, 0x01e101f3, 0x01e10101, 0x01e10111, 0x01e111e1, 0x01e111f3, 0x01e11101, 0x01e11111, - 0x01f3e1e1, 0x01f3e1f3, 0x01f3e101, 0x01f3e111, 0x01f3f3e1, 0x01f3f3f3, 0x01f3f301, 0x01f3f311, - 0x01f301e1, 0x01f301f3, 0x01f30101, 0x01f30111, 0x01f311e1, 0x01f311f3, 0x01f31101, 0x01f31111, - 0x0101e1e1, 0x0101e1f3, 0x0101e101, 0x0101e111, 0x0101f3e1, 0x0101f3f3, 0x0101f301, 0x0101f311, - 0x010101e1, 0x010101f3, 0x01010101, 0x01010111, 0x010111e1, 0x010111f3, 0x01011101, 0x01011111, - 0x0111e1e1, 0x0111e1f3, 0x0111e101, 0x0111e111, 0x0111f3e1, 0x0111f3f3, 0x0111f301, 0x0111f311, - 0x011101e1, 0x011101f3, 0x01110101, 0x01110111, 0x011111e1, 0x011111f3, 0x01111101, 0x01111111, - 0x11e1e1e1, 0x11e1e1f3, 0x11e1e101, 0x11e1e111, 0x11e1f3e1, 0x11e1f3f3, 0x11e1f301, 0x11e1f311, - 0x11e101e1, 0x11e101f3, 0x11e10101, 0x11e10111, 0x11e111e1, 0x11e111f3, 0x11e11101, 0x11e11111, - 0x11f3e1e1, 0x11f3e1f3, 0x11f3e101, 0x11f3e111, 0x11f3f3e1, 0x11f3f3f3, 0x11f3f301, 0x11f3f311, - 0x11f301e1, 0x11f301f3, 0x11f30101, 0x11f30111, 0x11f311e1, 0x11f311f3, 0x11f31101, 0x11f31111, - 0x1101e1e1, 0x1101e1f3, 0x1101e101, 0x1101e111, 0x1101f3e1, 0x1101f3f3, 0x1101f301, 0x1101f311, - 0x110101e1, 0x110101f3, 0x11010101, 0x11010111, 0x110111e1, 0x110111f3, 0x11011101, 0x11011111, - 0x1111e1e1, 0x1111e1f3, 0x1111e101, 0x1111e111, 0x1111f3e1, 0x1111f3f3, 0x1111f301, 0x1111f311, - 0x111101e1, 0x111101f3, 0x11110101, 0x11110111, 0x111111e1, 0x111111f3, 0x11111101, 0x11111111, - 0xe6e6e6e6, 0xe6e6e6f8, 0xe6e6e606, 0xe6e6e616, 0xe6e6f8e6, 0xe6e6f8f8, 0xe6e6f806, 0xe6e6f816, - 0xe6e606e6, 0xe6e606f8, 0xe6e60606, 0xe6e60616, 0xe6e616e6, 0xe6e616f8, 0xe6e61606, 0xe6e61616, - 0xe6f8e6e6, 0xe6f8e6f8, 0xe6f8e606, 0xe6f8e616, 0xe6f8f8e6, 0xe6f8f8f8, 0xe6f8f806, 0xe6f8f816, - 0xe6f806e6, 0xe6f806f8, 0xe6f80606, 0xe6f80616, 0xe6f816e6, 0xe6f816f8, 0xe6f81606, 0xe6f81616, - 0xe606e6e6, 0xe606e6f8, 0xe606e606, 0xe606e616, 0xe606f8e6, 0xe606f8f8, 0xe606f806, 0xe606f816, - 0xe60606e6, 0xe60606f8, 0xe6060606, 0xe6060616, 0xe60616e6, 0xe60616f8, 0xe6061606, 0xe6061616, - 0xe616e6e6, 0xe616e6f8, 0xe616e606, 0xe616e616, 0xe616f8e6, 0xe616f8f8, 0xe616f806, 0xe616f816, - 0xe61606e6, 0xe61606f8, 0xe6160606, 0xe6160616, 0xe61616e6, 0xe61616f8, 0xe6161606, 0xe6161616, - 0xf8e6e6e6, 0xf8e6e6f8, 0xf8e6e606, 0xf8e6e616, 0xf8e6f8e6, 0xf8e6f8f8, 0xf8e6f806, 0xf8e6f816, - 0xf8e606e6, 0xf8e606f8, 0xf8e60606, 0xf8e60616, 0xf8e616e6, 0xf8e616f8, 0xf8e61606, 0xf8e61616, - 0xf8f8e6e6, 0xf8f8e6f8, 0xf8f8e606, 0xf8f8e616, 0xf8f8f8e6, 0xf8f8f8f8, 0xf8f8f806, 0xf8f8f816, - 0xf8f806e6, 0xf8f806f8, 0xf8f80606, 0xf8f80616, 0xf8f816e6, 0xf8f816f8, 0xf8f81606, 0xf8f81616, - 0xf806e6e6, 0xf806e6f8, 0xf806e606, 0xf806e616, 0xf806f8e6, 0xf806f8f8, 0xf806f806, 0xf806f816, - 0xf80606e6, 0xf80606f8, 0xf8060606, 0xf8060616, 0xf80616e6, 0xf80616f8, 0xf8061606, 0xf8061616, - 0xf816e6e6, 0xf816e6f8, 0xf816e606, 0xf816e616, 0xf816f8e6, 0xf816f8f8, 0xf816f806, 0xf816f816, - 0xf81606e6, 0xf81606f8, 0xf8160606, 0xf8160616, 0xf81616e6, 0xf81616f8, 0xf8161606, 0xf8161616, - 0x06e6e6e6, 0x06e6e6f8, 0x06e6e606, 0x06e6e616, 0x06e6f8e6, 0x06e6f8f8, 0x06e6f806, 0x06e6f816, - 0x06e606e6, 0x06e606f8, 0x06e60606, 0x06e60616, 0x06e616e6, 0x06e616f8, 0x06e61606, 0x06e61616, - 0x06f8e6e6, 0x06f8e6f8, 0x06f8e606, 0x06f8e616, 0x06f8f8e6, 0x06f8f8f8, 0x06f8f806, 0x06f8f816, - 0x06f806e6, 0x06f806f8, 0x06f80606, 0x06f80616, 0x06f816e6, 0x06f816f8, 0x06f81606, 0x06f81616, - 0x0606e6e6, 0x0606e6f8, 0x0606e606, 0x0606e616, 0x0606f8e6, 0x0606f8f8, 0x0606f806, 0x0606f816, - 0x060606e6, 0x060606f8, 0x06060606, 0x06060616, 0x060616e6, 0x060616f8, 0x06061606, 0x06061616, - 0x0616e6e6, 0x0616e6f8, 0x0616e606, 0x0616e616, 0x0616f8e6, 0x0616f8f8, 0x0616f806, 0x0616f816, - 0x061606e6, 0x061606f8, 0x06160606, 0x06160616, 0x061616e6, 0x061616f8, 0x06161606, 0x06161616, - 0x16e6e6e6, 0x16e6e6f8, 0x16e6e606, 0x16e6e616, 0x16e6f8e6, 0x16e6f8f8, 0x16e6f806, 0x16e6f816, - 0x16e606e6, 0x16e606f8, 0x16e60606, 0x16e60616, 0x16e616e6, 0x16e616f8, 0x16e61606, 0x16e61616, - 0x16f8e6e6, 0x16f8e6f8, 0x16f8e606, 0x16f8e616, 0x16f8f8e6, 0x16f8f8f8, 0x16f8f806, 0x16f8f816, - 0x16f806e6, 0x16f806f8, 0x16f80606, 0x16f80616, 0x16f816e6, 0x16f816f8, 0x16f81606, 0x16f81616, - 0x1606e6e6, 0x1606e6f8, 0x1606e606, 0x1606e616, 0x1606f8e6, 0x1606f8f8, 0x1606f806, 0x1606f816, - 0x160606e6, 0x160606f8, 0x16060606, 0x16060616, 0x160616e6, 0x160616f8, 0x16061606, 0x16061616, - 0x1616e6e6, 0x1616e6f8, 0x1616e606, 0x1616e616, 0x1616f8e6, 0x1616f8f8, 0x1616f806, 0x1616f816, - 0x161606e6, 0x161606f8, 0x16160606, 0x16160616, 0x161616e6, 0x161616f8, 0x16161606, 0x16161616, -}; - -__device__ __forceinline__ int int_from_table_4(const uint8_t * a8, const int * values) { - return values[a8[0] | (a8[1] << 2) | (a8[2] << 4) | (a8[3] << 6)]; -} - #define VDR_IQ2_K_Q8_1_MMVQ 4 #define VDR_IQ2_K_Q8_1_MMQ 4 @@ -881,7 +811,6 @@ __device__ __forceinline__ void vec_dot_iq2_k_q8_1( uint32_t val1 = q2[0], val2 = q2[1]; uint32_t aux32[2]; - const uint8_t * a8 = (const uint8_t *)&aux32; int v1, v2; // Block of 16: (32*(4*(i4/4)+k)+8*(i4%4))/16 = 8*(i4/4) + 2*k + (i4%4)/2 @@ -892,23 +821,23 @@ __device__ __forceinline__ void vec_dot_iq2_k_q8_1( const int8_t * s8 = (const int8_t *)&s32; aux32[0] = ((val1 >> 0) & 0x03030303); aux32[1] = ((val2 >> 0) & 0x03030303); values = all_values + ((extra & 0x01) << 8); - v1 = int_from_table_4(a8 + 0, values); - v2 = int_from_table_4(a8 + 4, values); + v1 = int_from_table_4(aux32[0], values); + v2 = int_from_table_4(aux32[1], values); int sumi1 = ggml_cuda_dp4a(v2, q8_1[1], ggml_cuda_dp4a(v1, q8_1[0], 0)) * s8[0]; aux32[0] = ((val1 >> 2) & 0x03030303); aux32[1] = ((val2 >> 2) & 0x03030303); values = all_values + ((extra & 0x04) << 6); - v1 = int_from_table_4(a8 + 0, values); - v2 = int_from_table_4(a8 + 4, values); + v1 = int_from_table_4(aux32[0], values); + v2 = int_from_table_4(aux32[1], values); int sumi2 = ggml_cuda_dp4a(v2, q8_2[1], ggml_cuda_dp4a(v1, q8_2[0], 0)) * s8[1]; aux32[0] = ((val1 >> 4) & 0x03030303); aux32[1] = ((val2 >> 4) & 0x03030303); values = all_values + ((extra & 0x10) << 4); - v1 = int_from_table_4(a8 + 0, values); - v2 = int_from_table_4(a8 + 4, values); + v1 = int_from_table_4(aux32[0], values); + v2 = int_from_table_4(aux32[1], values); int sumi3 = ggml_cuda_dp4a(v2, q8_3[1], ggml_cuda_dp4a(v1, q8_3[0], 0)) * s8[2]; aux32[0] = ((val1 >> 6) & 0x03030303); aux32[1] = ((val2 >> 6) & 0x03030303); values = all_values + ((extra & 0x40) << 2); - v1 = int_from_table_4(a8 + 0, values); - v2 = int_from_table_4(a8 + 4, values); + v1 = int_from_table_4(aux32[0], values); + v2 = int_from_table_4(aux32[1], values); int sumi4 = ggml_cuda_dp4a(v2, q8_4[1], ggml_cuda_dp4a(v1, q8_4[0], 0)) * s8[3]; *result += __half2float(bq2->d) * (__low2float(bq8_1[4*(i4/4)+0].ds) * sumi1 @@ -941,7 +870,6 @@ __device__ __forceinline__ void vec_dot_iq2_ks_q8_1( uint32_t val1 = q2[0] | (q2[1] << 16), val2 = q2[2] | (q2[3] << 16); uint32_t aux32[2]; - const uint8_t * a8 = (const uint8_t *)&aux32; int v1, v2; int32_t scales32; @@ -954,23 +882,23 @@ __device__ __forceinline__ void vec_dot_iq2_ks_q8_1( s8[3] += ((extra >> 7) & 0x10); aux32[0] = ((val1 >> 0) & 0x03030303); aux32[1] = ((val2 >> 0) & 0x03030303); values = all_values + ((extra & 0x01) << 8); - v1 = int_from_table_4(a8 + 0, values); - v2 = int_from_table_4(a8 + 4, values); + v1 = int_from_table_4(aux32[0], values); + v2 = int_from_table_4(aux32[1], values); int sumi1 = ggml_cuda_dp4a(v2, q8_1[1], ggml_cuda_dp4a(v1, q8_1[0], 0)) * s8[0]; aux32[0] = ((val1 >> 2) & 0x03030303); aux32[1] = ((val2 >> 2) & 0x03030303); values = all_values + ((extra & 0x02) << 7); - v1 = int_from_table_4(a8 + 0, values); - v2 = int_from_table_4(a8 + 4, values); + v1 = int_from_table_4(aux32[0], values); + v2 = int_from_table_4(aux32[1], values); int sumi2 = ggml_cuda_dp4a(v2, q8_2[1], ggml_cuda_dp4a(v1, q8_2[0], 0)) * s8[2]; aux32[0] = ((val1 >> 4) & 0x03030303); aux32[1] = ((val2 >> 4) & 0x03030303); values = all_values + ((extra & 0x04) << 6); - v1 = int_from_table_4(a8 + 0, values); - v2 = int_from_table_4(a8 + 4, values); + v1 = int_from_table_4(aux32[0], values); + v2 = int_from_table_4(aux32[1], values); int sumi3 = ggml_cuda_dp4a(v2, q8_3[1], ggml_cuda_dp4a(v1, q8_3[0], 0)) * s8[1]; aux32[0] = ((val1 >> 6) & 0x03030303); aux32[1] = ((val2 >> 6) & 0x03030303); values = all_values + ((extra & 0x08) << 5); - v1 = int_from_table_4(a8 + 0, values); - v2 = int_from_table_4(a8 + 4, values); + v1 = int_from_table_4(aux32[0], values); + v2 = int_from_table_4(aux32[1], values); int sumi4 = ggml_cuda_dp4a(v2, q8_4[1], ggml_cuda_dp4a(v1, q8_4[0], 0)) * s8[3]; *result += scale * (__low2float(bq8_1[4*(i4/4)+0].ds) * sumi1 @@ -1000,20 +928,19 @@ __device__ __forceinline__ void vec_dot_iq2_k_r4_q8_1( int2 val1; const int * q2 = (const int *)bq2->qs + 8*ib32 + 4*is; int aux32[2]; - const uint8_t * aux8 = (const uint8_t *)aux32; #pragma unroll for (int i = 0; i < 4; ++i) { auto values1 = all_values + (((bq2->extra[i+4*is] >> ib32) & 1) << 8); int sumi1 = 0; aux32[0] = ((q2[i] >> 0) & 0x03030303); aux32[1] = ((q2[i] >> 2) & 0x03030303); - val1.x = int_from_table_4(aux8+0, values1); - val1.y = int_from_table_4(aux8+4, values1); + val1.x = int_from_table_4(aux32[0], values1); + val1.y = int_from_table_4(aux32[1], values1); sumi1 = ggml_cuda_dp4a(val1.x, q8[0], ggml_cuda_dp4a(val1.y, q8[1], sumi1)); aux32[0] = ((q2[i] >> 4) & 0x03030303); aux32[1] = ((q2[i] >> 6) & 0x03030303); - val1.x = int_from_table_4(aux8+0, values1); - val1.y = int_from_table_4(aux8+4, values1); + val1.x = int_from_table_4(aux32[0], values1); + val1.y = int_from_table_4(aux32[1], values1); sumi1 = ggml_cuda_dp4a(val1.x, q8[2], ggml_cuda_dp4a(val1.y, q8[3], sumi1)); const float d = __half2float(bq2->d[i]) * d8; result[i] += d * sumi1 * s8[i]; @@ -1114,7 +1041,6 @@ __device__ __forceinline__ void vec_dot_iq3_ks_q8_1( const int ib128 = iqs/4; // 0 or 1. 0 works on quants 0...127, 1 on quants 128...255 // Each thread processes 8 quants in each of the 4 32-blocks const int il8 = iqs%4; // 0...3. 0 works on quants 0...7, 1 on quants 8...15, 2 on 16...23, 3 on 24...31 - const int shift = 4*(il8/2); const uint16_t * ql = (const uint16_t *)bq3->qs + 16*ib128 + 4*il8; const uint16_t * qh = (const uint16_t *)bq3->qh + 4*il8; diff --git a/ggml/src/ggml-cuda/mmq.cu b/ggml/src/ggml-cuda/mmq.cu index 9c206e50..231c4a41 100644 --- a/ggml/src/ggml-cuda/mmq.cu +++ b/ggml/src/ggml-cuda/mmq.cu @@ -174,11 +174,13 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) { mmq_supported = ne11 < 1536; break; case GGML_TYPE_IQ2_K: + case GGML_TYPE_IQ2_K_R4: + mmq_supported = ne11 < 2048; + break; case GGML_TYPE_IQ3_K: case GGML_TYPE_IQ4_K: case GGML_TYPE_IQ5_K: case GGML_TYPE_IQ6_K: - case GGML_TYPE_IQ2_K_R4: case GGML_TYPE_IQ3_K_R4: case GGML_TYPE_IQ4_K_R4: case GGML_TYPE_IQ5_K_R4: diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh index d6f4cf3a..8a87e3e8 100644 --- a/ggml/src/ggml-cuda/mmq.cuh +++ b/ggml/src/ggml-cuda/mmq.cuh @@ -10,6 +10,7 @@ #include "common.cuh" #include "vecdotq.cuh" #include "mma.cuh" +#include "iqk_cuda_common.h" #include <climits> #include <cstdint> @@ -2492,12 +2493,10 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin float * x_df = (float *) (x_qs + txs.qs); #endif // INT8_MMA_AVAILABLE - const int kqsx = threadIdx.x%16; + const int * all_values = (const int *)iq2k_table; - auto values = iq2nl_values; + const int kqsx = threadIdx.x%16; - uint32_t aux32[4]; - const uint8_t * aux8 = (const uint8_t *)aux32; #pragma unroll for (int i0 = 0; i0 < mmq_y; i0 += 2*nwarps) { int i = i0 + 2*threadIdx.y + threadIdx.x/16; @@ -2511,26 +2510,16 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin uint16_t extra = bxi->extra >> 4*(kqsx/8); int q2 = get_int_b2(bxi->qs, kqsx); - aux32[0] = ((q2 >> 0) & 0x03030303) | (((extra << 2) & 4) * 0x01010101); - aux32[1] = ((q2 >> 2) & 0x03030303) | (((extra << 1) & 4) * 0x01010101); - aux32[2] = ((q2 >> 4) & 0x03030303) | (((extra >> 0) & 4) * 0x01010101); - aux32[3] = ((q2 >> 6) & 0x03030303) | (((extra >> 1) & 4) * 0x01010101); - - const char4 val0 = make_char4(values[aux8[ 0]], values[aux8[ 1]], values[aux8[ 2]], values[aux8[ 3]]); - const char4 val1 = make_char4(values[aux8[ 4]], values[aux8[ 5]], values[aux8[ 6]], values[aux8[ 7]]); - const char4 val2 = make_char4(values[aux8[ 8]], values[aux8[ 9]], values[aux8[10]], values[aux8[11]]); - const char4 val3 = make_char4(values[aux8[12]], values[aux8[13]], values[aux8[14]], values[aux8[15]]); - #ifdef INT8_MMA_AVAILABLE - x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx%8 + 32*(kqsx/8) + 0] = *(const int *)&val0; - x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx%8 + 32*(kqsx/8) + 8] = *(const int *)&val1; - x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx%8 + 32*(kqsx/8) + 16] = *(const int *)&val2; - x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx%8 + 32*(kqsx/8) + 24] = *(const int *)&val3; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx%8 + 32*(kqsx/8) + 0] = int_from_table_4((q2 >> 0) & 0x03030303, all_values + ((extra & 1) << 8)); + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx%8 + 32*(kqsx/8) + 8] = int_from_table_4((q2 >> 2) & 0x03030303, all_values + ((extra & 2) << 7)); + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx%8 + 32*(kqsx/8) + 16] = int_from_table_4((q2 >> 4) & 0x03030303, all_values + ((extra & 4) << 6)); + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx%8 + 32*(kqsx/8) + 24] = int_from_table_4((q2 >> 6) & 0x03030303, all_values + ((extra & 8) << 5)); #else - x_qs[i*(2*WARP_SIZE + 1) + kqsx%8 + 32*(kqsx/8) + 0] = *(const int *)&val0; - x_qs[i*(2*WARP_SIZE + 1) + kqsx%8 + 32*(kqsx/8) + 8] = *(const int *)&val1; - x_qs[i*(2*WARP_SIZE + 1) + kqsx%8 + 32*(kqsx/8) + 16] = *(const int *)&val2; - x_qs[i*(2*WARP_SIZE + 1) + kqsx%8 + 32*(kqsx/8) + 24] = *(const int *)&val3; + x_qs[i*(2*WARP_SIZE + 1) + kqsx%8 + 32*(kqsx/8) + 0] = int_from_table_4((q2 >> 0) & 0x03030303, all_values + ((extra & 1) << 8)); + x_qs[i*(2*WARP_SIZE + 1) + kqsx%8 + 32*(kqsx/8) + 8] = int_from_table_4((q2 >> 2) & 0x03030303, all_values + ((extra & 2) << 7)); + x_qs[i*(2*WARP_SIZE + 1) + kqsx%8 + 32*(kqsx/8) + 16] = int_from_table_4((q2 >> 4) & 0x03030303, all_values + ((extra & 4) << 6)); + x_qs[i*(2*WARP_SIZE + 1) + kqsx%8 + 32*(kqsx/8) + 24] = int_from_table_4((q2 >> 6) & 0x03030303, all_values + ((extra & 8) << 5)); #endif // INT8_MMA_AVAILABLE } @@ -2573,10 +2562,6 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin constexpr int qstep = 8; const int kqsx = threadIdx.x % qstep; - auto values = iq2nl_values; - - uint32_t aux32[4]; - const uint8_t * aux8 = (const uint8_t *)aux32; #pragma unroll for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/qstep) { int i = i0 + threadIdx.y*(WARP_SIZE/qstep) + threadIdx.x/qstep; @@ -2587,6 +2572,8 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin const block_iq2_k * bxi = (const block_iq2_k *)(x + i*stride) + kbx0; + auto all_values = (const int *)iq2k_table; + const float d = bxi->d; uint16_t extra = bxi->extra >> (kqsx/4); @@ -2595,28 +2582,20 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin for (int l = 0; l < qstep/4; ++l) { const int ql = get_int_b4(bxi->qs, kqsx + qstep*l); - aux32[0] = ((ql >> 0) & 0x03030303) | (((extra << 2) & 4) * 0x01010101); - aux32[1] = ((ql >> 2) & 0x03030303) | (((extra << 0) & 4) * 0x01010101); - aux32[2] = ((ql >> 4) & 0x03030303) | (((extra >> 2) & 4) * 0x01010101); - aux32[3] = ((ql >> 6) & 0x03030303) | (((extra >> 4) & 4) * 0x01010101); - extra >>= 8; - - const char4 val0 = make_char4(values[aux8[ 0]], values[aux8[ 1]], values[aux8[ 2]], values[aux8[ 3]]); - const char4 val1 = make_char4(values[aux8[ 4]], values[aux8[ 5]], values[aux8[ 6]], values[aux8[ 7]]); - const char4 val2 = make_char4(values[aux8[ 8]], values[aux8[ 9]], values[aux8[10]], values[aux8[11]]); - const char4 val3 = make_char4(values[aux8[12]], values[aux8[13]], values[aux8[14]], values[aux8[15]]); #ifdef INT8_MMA_AVAILABLE - x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + kqsx + 32*l + 0] = *(const int *)&val0; - x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + kqsx + 32*l + 8] = *(const int *)&val1; - x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + kqsx + 32*l + 16] = *(const int *)&val2; - x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + kqsx + 32*l + 24] = *(const int *)&val3; + x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + kqsx + 32*l + 0] = int_from_table_4((ql >> 0) & 0x03030303, all_values + ((extra & 0x01) << 8)); + x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + kqsx + 32*l + 8] = int_from_table_4((ql >> 2) & 0x03030303, all_values + ((extra & 0x04) << 6)); + x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + kqsx + 32*l + 16] = int_from_table_4((ql >> 4) & 0x03030303, all_values + ((extra & 0x10) << 4)); + x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + kqsx + 32*l + 24] = int_from_table_4((ql >> 6) & 0x03030303, all_values + ((extra & 0x40) << 2)); #else - x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 0] = *(const int *)&val0; - x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 8] = *(const int *)&val1; - x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 16] = *(const int *)&val2; - x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 24] = *(const int *)&val3; + x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 0] = int_from_table_4((ql >> 0) & 0x03030303, all_values + ((extra & 0x01) << 8)); + x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 8] = int_from_table_4((ql >> 2) & 0x03030303, all_values + ((extra & 0x04) << 6)); + x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 16] = int_from_table_4((ql >> 4) & 0x03030303, all_values + ((extra & 0x10) << 4)); + x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 24] = int_from_table_4((ql >> 6) & 0x03030303, all_values + ((extra & 0x40) << 2)); #endif // INT8_MMA_AVAILABLE + + extra >>= 8; } #ifdef INT8_MMA_AVAILABLE diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_k_r4.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_k_r4.cu index d7b5a18e..e40d55a0 100644 --- a/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_k_r4.cu +++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_k_r4.cu @@ -14,10 +14,10 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin float * x_df = (float *) (x_qs + txs.qs); #endif // INT8_MMA_AVAILABLE + const int * all_values = (const int *)iq2k_table; + const int kqsx = threadIdx.x/4; // 0...7 -> block of 32 - uint32_t aux32[4]; - const uint8_t * aux8 = (const uint8_t *)aux32; #pragma unroll for (int i0 = 0; i0 < mmq_y; i0 += 4*nwarps) { int i = i0 + 4*threadIdx.y + threadIdx.x%4; @@ -35,29 +35,20 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin #pragma unroll for (int l = 0; l < 2; ++l) { - auto values_l = iq2nl_values + (((bxi->extra[ir+4*l] >> kqsx) & 1) << 2); + auto values_l = all_values + (((bxi->extra[ir+4*l] >> kqsx) & 1) << 8); const int ql = get_int_b4(bxi->qs, 8*kqsx + ir + 4*l); - aux32[0] = (ql >> 0) & 0x03030303; - aux32[1] = (ql >> 2) & 0x03030303; - aux32[2] = (ql >> 4) & 0x03030303; - aux32[3] = (ql >> 6) & 0x03030303; - - const char4 val0 = make_char4(values_l[aux8[ 0]], values_l[aux8[ 1]], values_l[aux8[ 2]], values_l[aux8[ 3]]); - const char4 val1 = make_char4(values_l[aux8[ 4]], values_l[aux8[ 5]], values_l[aux8[ 6]], values_l[aux8[ 7]]); - const char4 val2 = make_char4(values_l[aux8[ 8]], values_l[aux8[ 9]], values_l[aux8[10]], values_l[aux8[11]]); - const char4 val3 = make_char4(values_l[aux8[12]], values_l[aux8[13]], values_l[aux8[14]], values_l[aux8[15]]); #ifdef INT8_MMA_AVAILABLE - x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 0] = *(const int *)&val0; - x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 1] = *(const int *)&val1; - x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 2] = *(const int *)&val2; - x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 3] = *(const int *)&val3; + x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 0] = int_from_table_4((ql >> 0) & 0x03030303, values_l); + x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 1] = int_from_table_4((ql >> 2) & 0x03030303, values_l); + x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 2] = int_from_table_4((ql >> 4) & 0x03030303, values_l); + x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 3] = int_from_table_4((ql >> 6) & 0x03030303, values_l); #else - x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 0] = *(const int *)&val0; - x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 1] = *(const int *)&val1; - x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 2] = *(const int *)&val2; - x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 3] = *(const int *)&val3; + x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 0] = int_from_table_4((ql >> 0) & 0x03030303, values_l); + x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 1] = int_from_table_4((ql >> 2) & 0x03030303, values_l); + x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 2] = int_from_table_4((ql >> 4) & 0x03030303, values_l); + x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 3] = int_from_table_4((ql >> 6) & 0x03030303, values_l); #endif // INT8_MMA_AVAILABLE } |