path: root/ggml/src/ggml-cuda
author    Kawrakow <iwankawrakow@gmail.com>  2025-07-14 18:55:08 +0200
committer GitHub <noreply@github.com>        2025-07-14 18:55:08 +0200
commit  45fae1a14444622478774f9a417e1d417af1ca46 (patch)
tree    2609ef06be5640749834d4fc691446771ab29f42 /ggml/src/ggml-cuda
parent  f5353047ef461e6fc9d527e09a06c9802c699929 (diff)
Adding IQ2_KL (#602)
* Experiments for 2.6875 bpw quants

  At least according to rmse, this is significantly better than q2_K, while using only 1/16 more bits per weight.

* iq2_kl: basics

* iq2_kl: CUDA dequantize

* iq2_kl: small improvement in PPL

  Also check the two neighbouring values for the block scale and use the one that minimizes RMSE.

* iq2_kl: MMQ

  Quite good: PP-512(L3-8B) = 8472 t/s.

* iq2_kl: MMVQ

  We get TG-128(L3-8B) = 162 t/s. This is not quite as good as it should be, as q2_K, at (almost) the same bpw, is at 170 t/s.

* iq2_kl: Zen4 GEMM/GEMV

  Not particularly fast. I may need to think about rearranging the bits.

* iq2_kl: better Zen4

* iq2_kl: convert/repack to q8_k_r8 (AVX2)

* iq2_kl: AVX2 GEMM/GEMV

* iq2_kl: WIP NEON

  The compiler started crashing!!!

* iq2_kl: NEON

  Had to work around a compiler crash when using vzip2q_u8 by using vqtbl2q_u8 instead.

* iq2_kl: convert/repack to q8_k_r8 (NEON)

* iq2_kl: Metal dequantize

* iq2_kl: Metal GEMV - pretty slow

* iq2_kl: Metal GEMV - slightly better (40 t/s -> 44.5 t/s)

* iq2_kl: Metal GEMV - slightly better (44.5 t/s -> 46.5 t/s)

* iq2_kl: Metal GEMV - slightly better (46.5 t/s -> 47.2 t/s)

* iq2_kl: slightly better Metal dequantize

  PP-512 goes to 476 t/s up from 466 t/s.

* iq2_kl: slightly better Metal dequantize

  PP-512 goes to 492 t/s up from 476 t/s.

* Add iq2_kl to constants.py

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
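For orientation before the diff: the 2.6875 bpw figure follows directly from the block layout the kernels below decode. Each pair of quants is addressed by a 5-bit index into the iq2kl_values pair table (4 bits in qs, 1 bit in qh), and each 32-quant sub-block carries a 6-bit scale (4 bits in scales_l, 2 bits in scales_h). A minimal sketch, assuming the field sizes implied by the index arithmetic in dequantize_block_iq2_kl; the authoritative struct lives in the shared quant headers, not in this diff:

    #include <cstdint>
    // Hypothetical mirror of block_iq2_kl; sizes inferred from the kernels below.
    struct block_iq2_kl_sketch {
        uint16_t scales_h;     // 8 x 2-bit high scale parts, one per 32-quant sub-block
        uint8_t  scales_l[4];  // 8 x 4-bit low scale parts
        uint8_t  qs[64];       // 128 x 4-bit low index bits: one 5-bit code per quant pair
        uint8_t  qh[16];       // 128 x 1-bit high index bits
    };
    static_assert(sizeof(block_iq2_kl_sketch) == 86, "86 bytes per QK_K = 256 quants");
    // 86 * 8 / 256 = 2.6875 bpw. q2_K sits at 2.625 bpw, so the overhead is exactly
    // 0.0625 = 1/16 bit per weight; the per-row fp16 super-scale amortizes to ~0.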
Diffstat (limited to 'ggml/src/ggml-cuda')
-rw-r--r--  ggml/src/ggml-cuda/common.cuh                                   7
-rw-r--r--  ggml/src/ggml-cuda/convert.cu                                  56
-rw-r--r--  ggml/src/ggml-cuda/iqk_mmvq.cu                                 54
-rw-r--r--  ggml/src/ggml-cuda/iqk_mmvq.cuh                                 5
-rw-r--r--  ggml/src/ggml-cuda/mmq.cu                                       4
-rw-r--r--  ggml/src/ggml-cuda/mmq.cuh                                      4
-rw-r--r--  ggml/src/ggml-cuda/mmvq.cu                                      4
-rw-r--r--  ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_kl.cu  70
8 files changed, 204 insertions, 0 deletions
diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh
index 973af2b8..38b52fd0 100644
--- a/ggml/src/ggml-cuda/common.cuh
+++ b/ggml/src/ggml-cuda/common.cuh
@@ -600,6 +600,13 @@ struct ggml_cuda_type_traits<GGML_TYPE_IQ3_K> {
};
template<>
+struct ggml_cuda_type_traits<GGML_TYPE_IQ2_KL> {
+ static constexpr int qk = QK_K;
+ static constexpr int qr = QR4_XS;
+ static constexpr int qi = QI4_XS;
+};
+
+template<>
struct ggml_cuda_type_traits<GGML_TYPE_IQ3_KS> {
static constexpr int qk = QK_K;
static constexpr int qr = QR4_XS;
diff --git a/ggml/src/ggml-cuda/convert.cu b/ggml/src/ggml-cuda/convert.cu
index 61c09481..c8e02a83 100644
--- a/ggml/src/ggml-cuda/convert.cu
+++ b/ggml/src/ggml-cuda/convert.cu
@@ -1334,6 +1334,48 @@ static __global__ void dequantize_block_iq3_k(const void * __restrict__ vx, dst_
}
template<typename dst_t>
+static __global__ void dequantize_block_iq2_kl(const void * __restrict__ vx, dst_t * __restrict__ yy, int64_t n_per_row, int64_t row_size) {
+
+ int64_t ii = blockIdx.x;
+ int64_t row = (QK_K * ii) / n_per_row;
+ const char * cx = (const char *)vx + row * row_size;
+ float scale = (float)*(const ggml_half *)cx;
+ const block_iq2_kl * x = (const block_iq2_kl *)(cx + sizeof(ggml_half));
+ const int64_t i = ii - (row*n_per_row)/QK_K;
+
+ const int64_t tid = threadIdx.x;
+ const int64_t ib64 = tid/8;
+ const int64_t il = tid%8;
+ dst_t * y = yy + ii*QK_K + 64*ib64 + 4*il;
+ const uint8_t * qs = x[i].qs + 16*ib64 + 2*il;
+ const uint8_t * qh = x[i].qh + 2*il;
+ auto sh = x[i].scales_h >> 4*ib64;
+ const float d1 = scale * (int(((x[i].scales_l[(2*ib64+0)%4] >> 4*(ib64/2)) & 0xf) | ((sh << 4) & 0x30)) - 32);
+ const float d2 = scale * (int(((x[i].scales_l[(2*ib64+1)%4] >> 4*(ib64/2)) & 0xf) | ((sh << 2) & 0x30)) - 32);
+ if constexpr (std::is_same_v<dst_t, nv_bfloat16>) {
+ for (int j = 0; j < 2; ++j) {
+ uint8_t h = qh[j] >> 2*ib64;
+ auto val1 = (const int8_t *)(iq2kl_values + ((qs[j] & 0xf) | ((h & 1) << 4)));
+ auto val2 = (const int8_t *)(iq2kl_values + ((qs[j] >> 4) | ((h & 2) << 3)));
+ y[2*j+ 0] = __float2bfloat16(d1 * val1[0]);
+ y[2*j+ 1] = __float2bfloat16(d1 * val1[1]);
+ y[2*j+32] = __float2bfloat16(d2 * val2[0]);
+ y[2*j+33] = __float2bfloat16(d2 * val2[1]);
+ }
+ } else {
+ for (int j = 0; j < 2; ++j) {
+ uint8_t h = qh[j] >> 2*ib64;
+ auto val1 = (const int8_t *)(iq2kl_values + ((qs[j] & 0xf) | ((h & 1) << 4)));
+ auto val2 = (const int8_t *)(iq2kl_values + ((qs[j] >> 4) | ((h & 2) << 3)));
+ y[2*j+ 0] = d1 * val1[0];
+ y[2*j+ 1] = d1 * val1[1];
+ y[2*j+32] = d2 * val2[0];
+ y[2*j+33] = d2 * val2[1];
+ }
+ }
+}
+
+template<typename dst_t>
static __global__ void dequantize_block_iq3_ks(const void * __restrict__ vx, dst_t * __restrict__ yy, int64_t n_per_row, int64_t row_size) {
int64_t ii = blockIdx.x;
@@ -1619,6 +1661,14 @@ static void dequantize_row_iq3_k_cuda(const void * vx, dst_t * y, const int64_t
}
template<typename dst_t>
+static void dequantize_row_iq2_kl_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
+ const int64_t k = nrows * n_per_row;
+ const int64_t row_size = ggml_row_size(GGML_TYPE_IQ2_KL, n_per_row);
+ const int nb = (k + QK_K - 1) / QK_K;
+ dequantize_block_iq2_kl<<<nb, 32, 0, stream>>>(vx, y, n_per_row, row_size);
+}
+
+template<typename dst_t>
static void dequantize_row_iq3_ks_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
const int64_t k = nrows * n_per_row;
const int64_t row_size = ggml_row_size(GGML_TYPE_IQ3_KS, n_per_row);
@@ -1772,6 +1822,8 @@ to_bf16_cuda_t ggml_get_to_bf16_cuda(ggml_type type) {
return dequantize_row_iq2_k_cuda<nv_bfloat16>;
case GGML_TYPE_IQ3_K:
return dequantize_row_iq3_k_cuda<nv_bfloat16>;
+ case GGML_TYPE_IQ2_KL:
+ return dequantize_row_iq2_kl_cuda<nv_bfloat16>;
case GGML_TYPE_IQ3_KS:
return dequantize_row_iq3_ks_cuda<nv_bfloat16>;
case GGML_TYPE_IQ4_KSS:
@@ -1876,6 +1928,8 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
return dequantize_row_iq2_k_cuda;
case GGML_TYPE_IQ3_K:
return dequantize_row_iq3_k_cuda;
+ case GGML_TYPE_IQ2_KL:
+ return dequantize_row_iq2_kl_cuda;
case GGML_TYPE_IQ3_KS:
return dequantize_row_iq3_ks_cuda;
case GGML_TYPE_IQ4_K:
@@ -1973,6 +2027,8 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
return dequantize_row_iq2_k_cuda;
case GGML_TYPE_IQ3_K:
return dequantize_row_iq3_k_cuda;
+ case GGML_TYPE_IQ2_KL:
+ return dequantize_row_iq2_kl_cuda;
case GGML_TYPE_IQ3_KS:
return dequantize_row_iq3_ks_cuda;
case GGML_TYPE_IQ4_K:
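The same scale and codebook decode recurs in every kernel in this patch, so a scalar reference is handy when auditing the bit fiddling. A host-side sketch of one 64-quant group, under the layout assumptions sketched above (values[] stands in for the 32 int8 pairs behind iq2kl_values; dst holds 64 floats):

    // Hypothetical scalar reference for group ib64 (0..3) of one block,
    // mirroring dequantize_block_iq2_kl.
    static void dequant_group64_ref(const block_iq2_kl_sketch &b, int ib64,
                                    const int8_t values[32][2], float row_scale,
                                    float *dst) {
        const uint32_t sh = b.scales_h >> (4 * ib64);
        // Two 6-bit sub-block scales: low 4 bits from scales_l, high 2 bits from scales_h.
        const float d1 = row_scale * (int(((b.scales_l[(2 * ib64 + 0) % 4] >> (4 * (ib64 / 2))) & 0xf) | ((sh << 4) & 0x30)) - 32);
        const float d2 = row_scale * (int(((b.scales_l[(2 * ib64 + 1) % 4] >> (4 * (ib64 / 2))) & 0xf) | ((sh << 2) & 0x30)) - 32);
        for (int j = 0; j < 16; ++j) {               // 16 qs bytes per 64-quant group
            const uint8_t q = b.qs[16 * ib64 + j];
            const uint8_t h = b.qh[j] >> (2 * ib64);
            const int8_t *p1 = values[(q & 0xf) | ((h & 1) << 4)]; // pair in first 32-block
            const int8_t *p2 = values[(q >> 4)  | ((h & 2) << 3)]; // pair in second 32-block
            dst[2 * j + 0]  = d1 * p1[0];
            dst[2 * j + 1]  = d1 * p1[1];
            dst[2 * j + 32] = d2 * p2[0];
            dst[2 * j + 33] = d2 * p2[1];
        }
    }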
diff --git a/ggml/src/ggml-cuda/iqk_mmvq.cu b/ggml/src/ggml-cuda/iqk_mmvq.cu
index d897063f..a669390d 100644
--- a/ggml/src/ggml-cuda/iqk_mmvq.cu
+++ b/ggml/src/ggml-cuda/iqk_mmvq.cu
@@ -1016,6 +1016,52 @@ __device__ __forceinline__ void vec_dot_iq3_k_q8_1(
}
+// TODO
+__device__ __forceinline__ void vec_dot_iq2_kl_q8_1(
+ const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iiqs, float * result) {
+
+ float d = __half2float(*(const half *)vbq);
+ const block_iq2_kl * bq2 = (const block_iq2_kl *)((const char *)vbq + sizeof(half)) + kbx;
+
+ int iqs = iiqs/4;
+ const int ib64 = iqs/2; // 0...3. 0 works on quants 0...63, 1 on quants 64...127, etc.
+ // Each thread processes 16 quants in each of the 2 32-blocks
+    const int il16 = iqs%2;  // 0...1. 0 works on quants 0...15, 1 on quants 16...31 of each 32-block
+
+ const uint16_t * ql = (const uint16_t *)bq2->qs + 8*ib64 + 4*il16;
+ const uint16_t * qh = (const uint16_t *)bq2->qh + 4*il16;
+
+ int32_t aux32;
+ const uint8_t * aux8 = (const uint8_t *)&aux32;
+
+ const int * q8l = (const int *)bq8_1[2*ib64+0].qs + 4*il16;
+ const int * q8h = (const int *)bq8_1[2*ib64+1].qs + 4*il16;
+
+ int sumi1 = 0, sumi2 = 0;
+ int v1, v2;
+ for (int i = 0; i < 2; ++i) {
+ uint32_t vl = ql[2*i+0] | (ql[2*i+1] << 16);
+ uint32_t vh = (qh[2*i+0] | (qh[2*i+1] << 16)) >> 2*ib64;
+
+ aux32 = (vl & 0x0f0f0f0f) | ((vh << 4) & 0x10101010);
+ v1 = iq2kl_values[aux8[0]] | (iq2kl_values[aux8[1]] << 16);
+ v2 = iq2kl_values[aux8[2]] | (iq2kl_values[aux8[3]] << 16);
+ sumi1 = ggml_cuda_dp4a(v1, q8l[2*i+0], ggml_cuda_dp4a(v2, q8l[2*i+1], sumi1));
+
+ aux32 = ((vl >> 4) & 0x0f0f0f0f) | ((vh << 3) & 0x10101010);
+ v1 = iq2kl_values[aux8[0]] | (iq2kl_values[aux8[1]] << 16);
+ v2 = iq2kl_values[aux8[2]] | (iq2kl_values[aux8[3]] << 16);
+ sumi2 = ggml_cuda_dp4a(v1, q8h[2*i+0], ggml_cuda_dp4a(v2, q8h[2*i+1], sumi2));
+ }
+
+ auto sh = bq2->scales_h >> 4*ib64;
+ int ls1 = int(((bq2->scales_l[(2*ib64+0)%4] >> 4*(ib64/2)) & 0xf) | ((sh << 4) & 0x30)) - 32;
+ int ls2 = int(((bq2->scales_l[(2*ib64+1)%4] >> 4*(ib64/2)) & 0xf) | ((sh << 2) & 0x30)) - 32;
+
+ *result += d * (__low2float(bq8_1[2*ib64+0].ds) * ls1 * sumi1 + __low2float(bq8_1[2*ib64+1].ds) * ls2 * sumi2);
+
+}
+
__device__ __forceinline__ void vec_dot_iq3_ks_q8_1(
const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iiqs, float * result) {
@@ -1280,6 +1326,14 @@ void mul_mat_vec_iq4_ks_q8_1_cuda(
iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ4_KS, VDR_IQ4_KS_Q8_1_MMVQ, vec_dot_iq4_ks_q8_1>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream);
}
+void mul_mat_vec_iq2_kl_q8_1_cuda(
+ const void * vx, const void * vy, float * dst, const char * ids_data,
+ const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
+ const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) {
+
+ iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ2_KL, VDR_IQ3_K_Q8_1_MMVQ, vec_dot_iq2_kl_q8_1>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream);
+}
+
void mul_mat_vec_iq3_ks_q8_1_cuda(
const void * vx, const void * vy, float * dst, const char * ids_data,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
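One detail worth calling out in the vec_dot above: OR-ing four low nibbles with four shifted high bits lands each complete 5-bit codebook index in its own byte lane, so two table lookups per 32-bit word can feed dp4a directly. A standalone sketch of that merge, with a scalar stand-in for ggml_cuda_dp4a (dp4a_ref, merge_lo, and merge_hi are illustrative names, not part of the patch):

    #include <cstdint>
    // Scalar stand-in for dp4a: four int8 x int8 products added to acc.
    static int32_t dp4a_ref(int32_t a, int32_t b, int32_t acc) {
        for (int k = 0; k < 4; ++k)
            acc += int8_t(a >> 8 * k) * int8_t(b >> 8 * k);
        return acc;
    }
    // The merges from vec_dot_iq2_kl_q8_1: vl packs eight 4-bit low parts,
    // vh the matching high bits (bit 0 / bit 1 of each byte select the
    // first / second 32-block). After the OR, byte lane k of the result is
    // a full 5-bit index into the 32-entry pair table.
    static uint32_t merge_lo(uint32_t vl, uint32_t vh) {
        return (vl & 0x0f0f0f0f) | ((vh << 4) & 0x10101010);   // first 32-block
    }
    static uint32_t merge_hi(uint32_t vl, uint32_t vh) {
        return ((vl >> 4) & 0x0f0f0f0f) | ((vh << 3) & 0x10101010);  // second 32-block
    }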
diff --git a/ggml/src/ggml-cuda/iqk_mmvq.cuh b/ggml/src/ggml-cuda/iqk_mmvq.cuh
index c2416b1e..d14c3541 100644
--- a/ggml/src/ggml-cuda/iqk_mmvq.cuh
+++ b/ggml/src/ggml-cuda/iqk_mmvq.cuh
@@ -16,6 +16,11 @@ void mul_mat_vec_iq3_k_q8_1_cuda(
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0, cudaStream_t stream);
+void mul_mat_vec_iq2_kl_q8_1_cuda(
+ const void * vx, const void * vy, float * dst, const char * ids_data,
+ const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
+ const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0, cudaStream_t stream);
+
void mul_mat_vec_iq3_ks_q8_1_cuda(
const void * vx, const void * vy, float * dst, const char * ids_data,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
diff --git a/ggml/src/ggml-cuda/mmq.cu b/ggml/src/ggml-cuda/mmq.cu
index 231c4a41..cde5d044 100644
--- a/ggml/src/ggml-cuda/mmq.cu
+++ b/ggml/src/ggml-cuda/mmq.cu
@@ -94,6 +94,9 @@ void ggml_cuda_op_mul_mat_q(
case GGML_TYPE_IQ4_NL:
mul_mat_q_case<GGML_TYPE_IQ4_NL>(ctx, args, stream);
break;
+ case GGML_TYPE_IQ2_KL:
+ mul_mat_q_case<GGML_TYPE_IQ2_KL>(ctx, args, stream);
+ break;
case GGML_TYPE_IQ3_KS:
mul_mat_q_case<GGML_TYPE_IQ3_KS>(ctx, args, stream);
break;
@@ -201,6 +204,7 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
case GGML_TYPE_IQ1_S_R4:
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ4_NL:
+ case GGML_TYPE_IQ2_KL:
case GGML_TYPE_IQ3_KS:
case GGML_TYPE_IQ4_KS:
case GGML_TYPE_IQ4_KS_R4:
diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh
index ee34452a..21b50082 100644
--- a/ggml/src/ggml-cuda/mmq.cuh
+++ b/ggml/src/ggml-cuda/mmq.cuh
@@ -88,6 +88,7 @@ static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) {
case GGML_TYPE_IQ2_K:
case GGML_TYPE_IQ2_K_R4:
case GGML_TYPE_IQ3_K:
+ case GGML_TYPE_IQ2_KL:
case GGML_TYPE_IQ3_KS:
case GGML_TYPE_IQ3_K_R4:
case GGML_TYPE_IQ4_KS:
@@ -201,6 +202,7 @@ static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml
case GGML_TYPE_IQ1_S_R4: return MMQ_DP4A_TXS_Q8_0;
case GGML_TYPE_IQ4_XS : return MMQ_DP4A_TXS_Q8_0;
case GGML_TYPE_IQ4_NL : return MMQ_DP4A_TXS_Q8_0;
+ case GGML_TYPE_IQ2_KL : return MMQ_DP4A_TXS_Q8_0;
case GGML_TYPE_IQ3_KS : return MMQ_DP4A_TXS_Q8_0;
case GGML_TYPE_IQ4_KS : return MMQ_DP4A_TXS_Q8_0;
case GGML_TYPE_IQ4_KS_R4 : return MMQ_DP4A_TXS_Q8_0;
@@ -257,6 +259,7 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
case GGML_TYPE_IQ1_S_R4: return MMQ_MMA_TILE_X_K_Q8_0;
case GGML_TYPE_IQ4_XS : return MMQ_MMA_TILE_X_K_Q8_0;
case GGML_TYPE_IQ4_NL : return MMQ_MMA_TILE_X_K_Q8_0;
+ case GGML_TYPE_IQ2_KL : return MMQ_MMA_TILE_X_K_Q8_0;
case GGML_TYPE_IQ3_KS : return MMQ_MMA_TILE_X_K_Q8_0;
case GGML_TYPE_IQ4_KS : return MMQ_MMA_TILE_X_K_Q8_0;
case GGML_TYPE_IQ4_KS_R4 : return MMQ_MMA_TILE_X_K_Q8_0;
@@ -4156,6 +4159,7 @@ extern DECL_MMQ_CASE(GGML_TYPE_IQ3_S);
extern DECL_MMQ_CASE(GGML_TYPE_IQ1_S);
extern DECL_MMQ_CASE(GGML_TYPE_IQ4_NL);
extern DECL_MMQ_CASE(GGML_TYPE_IQ4_XS);
+extern DECL_MMQ_CASE(GGML_TYPE_IQ2_KL);
extern DECL_MMQ_CASE(GGML_TYPE_IQ3_KS);
extern DECL_MMQ_CASE(GGML_TYPE_IQ4_KS);
extern DECL_MMQ_CASE(GGML_TYPE_IQ4_KS_R4);
diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu
index 2b619f67..d0746031 100644
--- a/ggml/src/ggml-cuda/mmvq.cu
+++ b/ggml/src/ggml-cuda/mmvq.cu
@@ -518,6 +518,9 @@ static void ggml_cuda_op_mul_mat_vec_q_impl(ggml_backend_cuda_context & ctx, ggm
case GGML_TYPE_IQ3_K:
mul_mat_vec_iq3_k_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream);
break;
+ case GGML_TYPE_IQ2_KL:
+ mul_mat_vec_iq2_kl_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream);
+ break;
case GGML_TYPE_IQ3_KS:
mul_mat_vec_iq3_ks_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream);
break;
@@ -682,6 +685,7 @@ bool ggml_cuda_mmvq_type_supported(ggml_type src0_type) {
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ2_K:
+ case GGML_TYPE_IQ2_KL:
case GGML_TYPE_IQ3_KS:
case GGML_TYPE_IQ3_K:
case GGML_TYPE_IQ4_K:
diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_kl.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_kl.cu
new file mode 100644
index 00000000..a5c22879
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_kl.cu
@@ -0,0 +1,70 @@
+#include "../mmq.cuh"
+
+template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq2_kl(
+ const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+
+#ifdef INT8_MMA_AVAILABLE
+ int * x_qs = (int *) x_tile;
+ float * x_df = (float *) (x_qs + WARP_SIZE*2);
+#else
+ constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_XS, mmq_y);
+ int * x_qs = (int *) x_tile;
+ float * x_df = (float *) (x_qs + txs.qs);
+#endif // INT8_MMA_AVAILABLE
+
+ const int kqsx = threadIdx.x/4;
+
+ uint32_t aux32[2];
+ const uint8_t * a8 = (const uint8_t *)aux32;
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += 4*nwarps) {
+ int i = i0 + 4*threadIdx.y + threadIdx.x%4;
+
+ if (need_check) {
+ i = min(i, i_max);
+ }
+
+ const half * dptr = (const half *)(x + i*stride);
+ const float d = *dptr;
+ const block_iq2_kl * bxi = (const block_iq2_kl *)(dptr + 1) + kbx0;
+
+ #pragma unroll
+ for (int j = 0; j < 2; ++j) {
+ auto ql = get_int_b2(bxi->qs, 4*(kqsx/2) + 2*(kqsx%2) + j);
+ auto qh = get_int_b2(bxi->qh, 2*(kqsx%2) + j) >> 2*(kqsx/2);
+ aux32[0] = ((ql >> 0) & 0x0f0f0f0f) | ((qh << 4) & 0x10101010);
+ aux32[1] = ((ql >> 4) & 0x0f0f0f0f) | ((qh << 3) & 0x10101010);
+ #pragma unroll
+ for (int l = 0; l < 2; ++l) {
+ int val1 = iq2kl_values[a8[2*l+0]] | (iq2kl_values[a8[2*l+1]] << 16);
+ int val2 = iq2kl_values[a8[2*l+4]] | (iq2kl_values[a8[2*l+5]] << 16);
+#ifdef INT8_MMA_AVAILABLE
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 16*(kqsx/2) + 4*(kqsx%2) + 2*j + l + 0] = val1;
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 16*(kqsx/2) + 4*(kqsx%2) + 2*j + l + 8] = val2;
+#else
+ x_qs[i*(2*WARP_SIZE + 1) + 16*(kqsx/2) + 4*(kqsx%2) + 2*j + l + 0] = val1;
+ x_qs[i*(2*WARP_SIZE + 1) + 16*(kqsx/2) + 4*(kqsx%2) + 2*j + l + 8] = val2;
+#endif
+ }
+ }
+
+ int ls = int(((bxi->scales_l[kqsx%4] >> 4*(kqsx/4)) & 0xf) | (((bxi->scales_h >> 2*kqsx) & 3) << 4)) - 32;
+
+#ifdef INT8_MMA_AVAILABLE
+ x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = d * ls;
+#else
+ x_df[i*(WARP_SIZE/4) + i/4 + kqsx] = d * ls;
+#endif
+ }
+
+}
+
+template <int mmq_x, int mmq_y, int nwarps, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ2_KL> {
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq2_kl<mmq_y, nwarps, need_check>;
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_D4>;
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
+};
+
+DECL_MMQ_CASE(GGML_TYPE_IQ2_KL);