author    Kawrakow <iwankawrakow@gmail.com>  2025-05-04 12:45:00 +0300
committer GitHub <noreply@github.com>        2025-05-04 12:45:00 +0300
commit    f7c9a0f036951fecab32e056df954ebc54f8688f
tree      277a7c5ee63fda3841488e38a1dda9d2a43e0094
parent    13281282986fb6783d0d7d64b3610bfb7085e749
CUDA: MMQ for IQ4_KS (#374)
* WIP

* WIP: still getting illegal memory access

* CUDA: MMQ for iq4_ks now works: ~25% faster than dequantize+cuBLAS, ~10% slower than Q4_0 MMQ.

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
-rw-r--r--  ggml/src/ggml-cuda/mmq.cu                                        8
-rw-r--r--  ggml/src/ggml-cuda/mmq.cuh                                     148
-rw-r--r--  ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_ks.cu    5
-rw-r--r--  ggml/src/ggml-cuda/vecdotq.cuh                                  12
4 files changed, 133 insertions, 40 deletions
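
Background for the changes below (a reference sketch, not code from this commit): the new load_tiles_iq4_ks in mmq.cuh assumes the IQ4_KS row layout, in which each row begins with a single float row scale followed by 256-weight blocks. Each block carries eight scale bytes; bit 0 of a scale byte selects which 16-entry half of the iq4k_values table that sub-block uses, and the remaining seven bits store the sub-block scale biased by 127. Under those assumptions (only the field names and the scale decoding appear in the diff; the struct shape itself is inferred), a scalar dequantization looks like:

#include <stdint.h>

typedef struct {
    uint8_t scales[256/32]; // bit 0: iq4k_values half selector; bits 1..7: sub-block scale (ls + 127)
    uint8_t qs[256/2];      // 4-bit quants, two per byte; low nibble = weights 0..15 of a sub-block
} block_iq4_ks_sketch;

static void dequantize_row_iq4_ks_sketch(const char * row, const int8_t * iq4k_values,
                                         float * y, int n) {
    const float d = *(const float *) row; // per-row scale stored ahead of the blocks
    const block_iq4_ks_sketch * x = (const block_iq4_ks_sketch *)(row + sizeof(float));
    for (int ib = 0; ib < n/256; ++ib) {
        for (int j = 0; j < 8; ++j) { // 8 sub-blocks of 32 weights
            const float dl = d * (float)((x[ib].scales[j] & 254) - 127);
            const int8_t * values = iq4k_values + ((x[ib].scales[j] & 1) << 4);
            const uint8_t * qs = x[ib].qs + 16*j;
            for (int k = 0; k < 16; ++k) {
                y[256*ib + 32*j +  0 + k] = dl * values[qs[k] & 0xf];
                y[256*ib + 32*j + 16 + k] = dl * values[qs[k] >> 4];
            }
        }
    }
}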
diff --git a/ggml/src/ggml-cuda/mmq.cu b/ggml/src/ggml-cuda/mmq.cu
index 3b959182..67897a83 100644
--- a/ggml/src/ggml-cuda/mmq.cu
+++ b/ggml/src/ggml-cuda/mmq.cu
@@ -14,6 +14,7 @@ void ggml_cuda_op_mul_mat_q(
const int64_t src1_padded_row_size, cudaStream_t stream) {
const int64_t ne00 = src0->ne[0];
+ const int64_t nb01 = src0->nb[1];
const int64_t ne10 = src1->ne[0];
const int64_t ne11 = src1->ne[1];
@@ -22,7 +23,6 @@ void ggml_cuda_op_mul_mat_q(
const int64_t ne0 = dst->ne[0];
const int64_t row_diff = row_high - row_low;
- const int64_t stride00 = ne00 / ggml_blck_size(src0->type);
int id = ggml_cuda_get_device();
const int compute_capability = ggml_cuda_info().devices[id].cc;
@@ -31,7 +31,7 @@ void ggml_cuda_op_mul_mat_q(
// nrows_dst == nrows of the matrix that the kernel writes into
const int64_t nrows_dst = id == ctx.device ? ne0 : row_diff;
- const mmq_args args = {src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stride00, src1_padded_row_size, src1_ncols, ne11, nrows_dst};
+ const mmq_args args = {src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, nb01, src1_padded_row_size, src1_ncols, ne11, nrows_dst};
switch (src0->type) {
case GGML_TYPE_Q4_0:
@@ -91,6 +91,9 @@ void ggml_cuda_op_mul_mat_q(
case GGML_TYPE_IQ4_NL:
mul_mat_q_case<GGML_TYPE_IQ4_NL>(ctx, args, stream);
break;
+ case GGML_TYPE_IQ4_KS:
+ mul_mat_q_case<GGML_TYPE_IQ4_KS>(ctx, args, stream);
+ break;
default:
GGML_ABORT("fatal error");
break;
@@ -128,6 +131,7 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
case GGML_TYPE_IQ1_S:
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ4_NL:
+ case GGML_TYPE_IQ4_KS:
mmq_supported = true;
break;
default:
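
The replacement of stride00 (a stride counted in quant blocks) above is what every rewritten pointer computation in mmq.cuh below relies on: nb01 is the row stride in bytes, so the row base is formed with byte arithmetic first and only then cast to a block pointer. That permits a row to carry a prefix that is not a whole block, such as the leading float scale of iq4_ks. A minimal sketch of the new addressing, with hypothetical names:

#include <stdint.h>

// Old scheme: (const block_q4_0 *) x + kbx0 + i*stride + kbx, with stride in
// blocks; valid only when each row is an integral number of blocks.
// New scheme: the byte stride nb01 locates the row first, then the cast.
typedef struct { uint16_t d; uint8_t qs[16]; } block_sketch; // hypothetical block

static inline const block_sketch * get_block(const char * x, int64_t i,
                                             int64_t nb01, int kbx0, int kbx) {
    return (const block_sketch *)(x + i*nb01) + kbx0 + kbx;
}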
diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh
index 753848d7..148697e2 100644
--- a/ggml/src/ggml-cuda/mmq.cuh
+++ b/ggml/src/ggml-cuda/mmq.cuh
@@ -82,6 +82,7 @@ static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) {
return MMQ_Q8_1_DS_LAYOUT_DS4;
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ4_NL:
+ case GGML_TYPE_IQ4_KS:
return MMQ_Q8_1_DS_LAYOUT_D4;
default:
GGML_ABORT("fatal error");
@@ -179,6 +180,7 @@ static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml
case GGML_TYPE_IQ1_S : return MMQ_DP4A_TXS_Q8_0;
case GGML_TYPE_IQ4_XS : return MMQ_DP4A_TXS_Q8_0;
case GGML_TYPE_IQ4_NL : return MMQ_DP4A_TXS_Q8_0;
+ case GGML_TYPE_IQ4_KS : return MMQ_DP4A_TXS_Q8_0;
default : return tile_x_sizes{0, 0, 0};
}
}
@@ -216,6 +218,7 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
case GGML_TYPE_IQ1_S : return MMQ_MMA_TILE_X_K_Q8_0;
case GGML_TYPE_IQ4_XS : return MMQ_MMA_TILE_X_K_Q8_0;
case GGML_TYPE_IQ4_NL : return MMQ_MMA_TILE_X_K_Q8_0;
+ case GGML_TYPE_IQ4_KS : return MMQ_MMA_TILE_X_K_Q8_0;
default : return 0;
}
}
@@ -261,7 +264,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
i = min(i, i_max);
}
- const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbx;
+ const block_q4_0 * bxi = (const block_q4_0 *)(x + i*stride) + kbx0 + kbx;
const int qs0 = get_int_b2(bxi->qs, kqsx);
#ifdef INT8_MMA_AVAILABLE
@@ -283,7 +286,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
i = min(i, i_max);
}
- const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbxd;
+ const block_q4_0 * bxi = (const block_q4_0 *)(x + i*stride) + kbx0 + kbxd;
#ifdef INT8_MMA_AVAILABLE
x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d;
@@ -356,7 +359,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
i = min(i, i_max);
}
- const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbx;
+ const block_q4_1 * bxi = (const block_q4_1 *)(x + i*stride) + kbx0 + kbx;
const int qs0 = get_int_b4(bxi->qs, kqsx);
#ifdef INT8_MMA_AVAILABLE
@@ -378,7 +381,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
i = min(i, i_max);
}
- const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbxd;
+ const block_q4_1 * bxi = (const block_q4_1 *)(x + i*stride) + kbx0 + kbxd;
#ifdef INT8_MMA_AVAILABLE
x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = bxi->dm;
@@ -451,7 +454,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
i = min(i, i_max);
}
- const block_q5_0 * bxi = (const block_q5_0 *) x + kbx0 + i*stride + kbx;
+ const block_q5_0 * bxi = (const block_q5_0 *)(x + i*stride) + kbx0 + kbx;
const int ql = get_int_b2(bxi->qs, kqsx);
const int qh = get_int_b2(bxi->qh, 0) >> (4 * (threadIdx.x % QI5_0));
@@ -490,7 +493,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
i = min(i, i_max);
}
- const block_q5_0 * bxi = (const block_q5_0 *) x + kbx0 + i*stride + kbxd;
+ const block_q5_0 * bxi = (const block_q5_0 *)(x + i*stride) + kbx0 + kbxd;
#ifdef INT8_MMA_AVAILABLE
x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d;
@@ -523,7 +526,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
i = min(i, i_max);
}
- const block_q5_1 * bxi = (const block_q5_1 *) x + kbx0 + i*stride + kbx;
+ const block_q5_1 * bxi = (const block_q5_1 *)(x + i*stride) + kbx0 + kbx;
const int ql = get_int_b4(bxi->qs, kqsx);
const int qh = get_int_b4(bxi->qh, 0) >> (4 * (threadIdx.x % QI5_1));
@@ -560,7 +563,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
i = min(i, i_max);
}
- const block_q5_1 * bxi = (const block_q5_1 *) x + kbx0 + i*stride + kbxd;
+ const block_q5_1 * bxi = (const block_q5_1 *)(x + i*stride) + kbx0 + kbxd;
#ifdef INT8_MMA_AVAILABLE
x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = bxi->dm;
@@ -593,7 +596,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
i = min(i, i_max);
}
- const block_q6_0 * bxi = (const block_q6_0 *) x + kbx0 + i*stride + kbx;
+ const block_q6_0 * bxi = (const block_q6_0 *)(x + i*stride) + kbx0 + kbx;
const int ql = get_int_b2(bxi->qs, kqsx);
const int qh = get_int_b2(bxi->qh, kqsx%2) >> 4*(kqsx/2);
@@ -623,7 +626,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
i = min(i, i_max);
}
- const block_q6_0 * bxi = (const block_q6_0 *) x + kbx0 + i*stride + kbxd;
+ const block_q6_0 * bxi = (const block_q6_0 *)(x + i*stride) + kbx0 + kbxd;
#ifdef INT8_MMA_AVAILABLE
x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d;
@@ -656,7 +659,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
i = min(i, i_max);
}
- const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbx;
+ const block_q8_0 * bxi = (const block_q8_0 *)(x + i*stride) + kbx0 + kbx;
#ifdef INT8_MMA_AVAILABLE
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 0 + threadIdx.x] = get_int_b2(bxi[0].qs, kqsx);
@@ -678,7 +681,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
i = min(i, i_max);
}
- const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbxd;
+ const block_q8_0 * bxi = (const block_q8_0 *)(x + i*stride) + kbx0 + kbxd;
#ifdef INT8_MMA_AVAILABLE
x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d;
@@ -1044,7 +1047,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
i = min(i, i_max);
}
- const block_q2_K * bxi = (const block_q2_K *) x + kbx0 + i*stride;
+ const block_q2_K * bxi = (const block_q2_K *)(x + i*stride) + kbx0;
const int x_ql_0 = get_int_b2(bxi->qs, kqsx);
@@ -1275,7 +1278,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
i = min(i, i_max);
}
- const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride;
+ const block_q3_K * bxi = (const block_q3_K *)(x + i*stride) + kbx0;
const int x_ql_0 = get_int_b2(bxi->qs, kqsx);
const int x_qh_0 = get_int_b2(bxi->hmask, kqsx % (QI3_K/2)) >> (4 * (kqsx / (QI3_K/2)));
@@ -1305,7 +1308,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
i = min(i, i_max);
}
- const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride;
+ const block_q3_K * bxi = (const block_q3_K *)(x + i*stride) + kbx0;
const int ksc = threadIdx.x % (WARP_SIZE/8);
@@ -1341,7 +1344,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
i = min(i, i_max);
}
- const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride;
+ const block_q3_K * bxi = (const block_q3_K *)(x + i*stride) + kbx0;
x_df[i] = bxi->d;
}
@@ -1412,7 +1415,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
i = min(i, i_max);
}
- const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride;
+ const block_q4_K * bxi = (const block_q4_K *)(x + i*stride) + kbx0;
const int qs0 = get_int_b4(bxi->qs, threadIdx.x);
#ifdef INT8_MMA_AVAILABLE
@@ -1433,7 +1436,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
i = min(i, i_max);
}
- const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride;
+ const block_q4_K * bxi = (const block_q4_K *)(x + i*stride) + kbx0;
const int * scales = (const int *) bxi->scales;
const int ksc = threadIdx.x % (WARP_SIZE/16);
@@ -1462,7 +1465,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
i = min(i, i_max);
}
- const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride;
+ const block_q4_K * bxi = (const block_q4_K *)(x + i*stride) + kbx0;
x_dm[i] = bxi->dm;
}
@@ -1475,7 +1478,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
i = min(i, i_max);
}
- const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride + (threadIdx.x % (WARP_SIZE/8)) / (QI4_K/8);
+ const block_q4_K * bxi = (const block_q4_K *)(x + i*stride) + kbx0 + (threadIdx.x % (WARP_SIZE/8)) / (QI4_K/8);
const int * scales = (const int *) bxi->scales;
@@ -1541,7 +1544,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
i = min(i, i_max);
}
- const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride;
+ const block_q5_K * bxi = (const block_q5_K *)(x + i*stride) + kbx0;
const int ky = QR5_K*threadIdx.x;
const int ql = get_int_b4(bxi->qs, threadIdx.x);
@@ -1574,7 +1577,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
i = min(i, i_max);
}
- const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride;
+ const block_q5_K * bxi = (const block_q5_K *)(x + i*stride) + kbx0;
const int * scales = (const int *) bxi->scales;
const int ksc = threadIdx.x % (WARP_SIZE/16);
@@ -1603,7 +1606,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
i = min(i, i_max);
}
- const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride;
+ const block_q5_K * bxi = (const block_q5_K *)(x + i*stride) + kbx0;
x_dm[i] = bxi->dm;
}
@@ -1616,7 +1619,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
i = min(i, i_max);
}
- const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride;
+ const block_q5_K * bxi = (const block_q5_K *)(x + i*stride) + kbx0;
const int * scales = (const int *) bxi->scales;
@@ -1683,7 +1686,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
i = min(i, i_max);
}
- const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride;
+ const block_q6_K * bxi = (const block_q6_K *)(x + i*stride) + kbx0;
const int ql = get_int_b2(bxi->ql, threadIdx.x);
const int ql0 = (ql >> 0) & 0x0F0F0F0F;
@@ -1716,7 +1719,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
i = min(i, i_max);
}
- const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + kbxd;
+ const block_q6_K * bxi = (const block_q6_K *)(x + i*stride) + kbx0 + kbxd;
#ifdef INT8_MMA_AVAILABLE
x_df[i*MMQ_MMA_TILE_X_K_Q6_K + kbxd] = bxi->d;
@@ -1733,7 +1736,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
i = min(i, i_max);
}
- const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + (threadIdx.x % (WARP_SIZE/8)) / 4;
+ const block_q6_K * bxi = (const block_q6_K *)(x + i*stride) + kbx0 + (threadIdx.x % (WARP_SIZE/8)) / 4;
#ifdef INT8_MMA_AVAILABLE
x_sc[i*MMQ_MMA_TILE_X_K_Q6_K + threadIdx.x % (WARP_SIZE/8)] = get_int_b2(bxi->scales, threadIdx.x % (QI6_K/8));
@@ -1908,7 +1911,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
i = min(i, i_max);
}
- const block_iq4_nl * bxi = (const block_iq4_nl *) x + kbx0 + i*stride + kbx;
+ const block_iq4_nl * bxi = (const block_iq4_nl *)(x + i*stride) + kbx0 + kbx;
const int aux_q4 = get_int_b2(bxi->qs, kqsx);
const int2 v = get_int_from_table_16(aux_q4);
@@ -1933,7 +1936,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
i = min(i, i_max);
}
- const block_iq4_nl * bxi = (const block_iq4_nl *) x + kbx0 + i*stride + kbxd;
+ const block_iq4_nl * bxi = (const block_iq4_nl *)(x + i*stride) + kbx0 + kbxd;
#ifdef INT8_MMA_AVAILABLE
x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = __half2float(bxi->d);
@@ -1965,7 +1968,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
i = min(i, i_max);
}
- const block_iq2_xxs * bxi = (const block_iq2_xxs *) x + kbx0 + i*stride;
+ const block_iq2_xxs * bxi = (const block_iq2_xxs *)(x + i*stride) + kbx0;
const int q2 = get_int_b2(bxi->qs, 2*kqsx+0);
const uint8_t * aux8 = (const uint8_t *) &q2;
@@ -2023,7 +2026,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
i = min(i, i_max);
}
- const block_iq2_xs * bxi = (const block_iq2_xs *) x + kbx0 + i*stride;
+ const block_iq2_xs * bxi = (const block_iq2_xs *)(x + i*stride) + kbx0;
const int2 q2_packed = make_int2(get_int_b2(bxi->qs, 2*kqsx+0), get_int_b2(bxi->qs, 2*kqsx+1));
const uint16_t * q2 = (const uint16_t *) &q2_packed;
@@ -2079,7 +2082,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
i = min(i, i_max);
}
- const block_iq2_s * bxi = (const block_iq2_s *) x + kbx0 + i*stride;
+ const block_iq2_s * bxi = (const block_iq2_s *)(x + i*stride) + kbx0;
const int qs_packed = get_int_b2(bxi->qs, kqsx);
const uint8_t * qs = (const uint8_t *) &qs_packed;
@@ -2142,7 +2145,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
i = min(i, i_max);
}
- const block_iq3_xxs * bxi = (const block_iq3_xxs *) x + kbx0 + i*stride;
+ const block_iq3_xxs * bxi = (const block_iq3_xxs *)(x + i*stride) + kbx0;
const int2 q3_packed = make_int2(get_int_b2(bxi->qs, 2*kqsx+0), get_int_b2(bxi->qs, 2*kqsx+1));
const uint8_t * q3 = (const uint8_t *) &q3_packed;
@@ -2198,7 +2201,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
i = min(i, i_max);
}
- const block_iq3_s * bxi = (const block_iq3_s *) x + kbx0 + i*stride;
+ const block_iq3_s * bxi = (const block_iq3_s *)(x + i*stride) + kbx0;
const int2 qs_packed = make_int2(get_int_b2(bxi->qs, 2*kqsx+0), get_int_b2(bxi->qs, 2*kqsx+1));
const uint8_t * qs = (const uint8_t *) &qs_packed;
@@ -2261,7 +2264,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
i = min(i, i_max);
}
- const block_iq1_s * bxi = (const block_iq1_s *) x + kbx0 + i*stride;
+ const block_iq1_s * bxi = (const block_iq1_s *)(x + i*stride) + kbx0;
const int qs_packed = get_int_b2(bxi->qs, kqsx);
const uint8_t * qs = (const uint8_t *) &qs_packed;
@@ -2318,7 +2321,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
i = min(i, i_max);
}
- const block_iq4_xs * bxi = (const block_iq4_xs *) x + kbx0 + i*stride + kbx;
+ const block_iq4_xs * bxi = (const block_iq4_xs *)(x + i*stride) + kbx0 + kbx;
const int aux_q4 = get_int_b4(bxi->qs, kqsx);
const int2 v = get_int_from_table_16(aux_q4);
@@ -2340,7 +2343,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
i = min(i, i_max);
}
- const block_iq4_xs * bxi = (const block_iq4_xs *) x + kbx0 + i*stride;
+ const block_iq4_xs * bxi = (const block_iq4_xs *)(x + i*stride) + kbx0;
const float d = __half2float(bxi->d);
@@ -2355,6 +2358,64 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
}
}
+template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq4_ks(
+ const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+
+#ifdef INT8_MMA_AVAILABLE
+ int * x_qs = (int *) x_tile;
+ float * x_df = (float *) (x_qs + WARP_SIZE*2);
+#else
+ constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_XS, mmq_y);
+ int * x_qs = (int *) x_tile;
+ float * x_df = (float *) (x_qs + txs.qs);
+#endif // INT8_MMA_AVAILABLE
+
+ const int kbx = 0; // threadIdx.x / QI4_XS
+ const int kqsx = threadIdx.x; // threadIdx.x % QI4_XS
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
+ int i = i0 + threadIdx.y;
+
+ if (need_check) {
+ i = min(i, i_max);
+ }
+
+ const block_iq4_ks * bxi = (const block_iq4_ks *)(x + i*stride + sizeof(float)) + kbx0 + kbx;
+
+ auto values = iq4k_values + ((bxi->scales[kqsx/4] & 1) << 4);
+ const int aux_q4 = get_int_b4(bxi->qs, kqsx);
+ const int2 v = get_int_from_table_16(aux_q4, values);
+ const int k0 = 8 * (threadIdx.x / 4) + threadIdx.x % 4;
+#ifdef INT8_MMA_AVAILABLE
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x;
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 4] = v.y;
+#else
+ x_qs[i*(2*WARP_SIZE + 1) + k0 + 0] = v.x;
+ x_qs[i*(2*WARP_SIZE + 1) + k0 + 4] = v.y;
+#endif // INT8_MMA_AVAILABLE
+ }
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) {
+ int i = i0 + threadIdx.y * 4 + threadIdx.x / (WARP_SIZE/4);
+
+ if (need_check) {
+ i = min(i, i_max);
+ }
+
+ const float * dptr = (const float *)(x + i*stride);
+ const block_iq4_ks * bxi = (const block_iq4_ks *)(dptr + 1) + kbx0;
+ const int ls = (bxi->scales[threadIdx.x % 8] & 254) - 127;
+
+#ifdef INT8_MMA_AVAILABLE
+ x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + threadIdx.x % 8] = dptr[0] * ls;
+#else
+ x_df[i*(WARP_SIZE/4) + i/4 + threadIdx.x % 8] = dptr[0] * ls;
+#endif // INT8_MMA_AVAILABLE
+ }
+}
+
template<int mmq_x, int mmq_y, int nwarps, bool need_check>
static __device__ __forceinline__ void mmq_write_back_dp4a(
const float * __restrict__ sum, float * __restrict__ dst, const int & stride, const int & i_max, const int & j_max) {
@@ -2576,6 +2637,14 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ4_XS> {
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
};
+template <int mmq_x, int mmq_y, int nwarps, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ4_KS> {
+ static constexpr int vdr = VDR_IQ4_XS_Q8_1_MMQ;
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq4_ks<mmq_y, nwarps, need_check>;
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_D4>;
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
+};
+
template <ggml_type type, int mmq_x, int nwarps, bool need_check, bool fixup>
static __device__ void mul_mat_q_process_tile(
const char * __restrict__ x, const char * __restrict__ yc, float * __restrict__ dst, float * __restrict__ tmp_fixup,
@@ -2608,7 +2677,7 @@ static __device__ void mul_mat_q_process_tile(
const int * y = (const int *) yc + jt*(mmq_x*sizeof(block_q8_1_mmq)/sizeof(int));
for (int kb0 = kb0_start; kb0 < kb0_stop; kb0 += blocks_per_iter) {
- load_tiles(x, tile_x, stride01*it*mmq_y + kb0, tile_x_max_i, stride01);
+ load_tiles(x + stride01*it*mmq_y, tile_x, kb0, tile_x_max_i, stride01);
{
const int * by0 = y + stride11*(kb0*(qk*sizeof(block_q8_1_mmq) / (4*QK8_1*sizeof(int))) + 0*sizeof(block_q8_1_mmq)/sizeof(int));
@@ -2889,6 +2958,7 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
mul_mat_q_stream_k_fixup<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_xy_tiling, block_dims, 0, stream>>>
(args.dst, tmp_fixup.ptr, args.ne00, args.ne01, args.ne11, args.ne0, block_nums_mmq.x);
+
} else {
constexpr bool need_check = true;
@@ -2897,6 +2967,7 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
mul_mat_q_stream_k_fixup<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_xy_tiling, block_dims, 0, stream>>>
(args.dst, tmp_fixup.ptr, args.ne00, args.ne01, args.ne11, args.ne0, block_nums_mmq.x);
+
}
}
@@ -3010,6 +3081,7 @@ extern DECL_MMQ_CASE(GGML_TYPE_IQ3_S);
extern DECL_MMQ_CASE(GGML_TYPE_IQ1_S);
extern DECL_MMQ_CASE(GGML_TYPE_IQ4_NL);
extern DECL_MMQ_CASE(GGML_TYPE_IQ4_XS);
+extern DECL_MMQ_CASE(GGML_TYPE_IQ4_KS);
// -------------------------------------------------------------------------------------------------------------------------
diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_ks.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_ks.cu
new file mode 100644
index 00000000..940c2da8
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_ks.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmq.cuh"
+
+DECL_MMQ_CASE(GGML_TYPE_IQ4_KS);
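
For context, this file follows the pattern of the other mmq-instance files: the header declares extern DECL_MMQ_CASE(GGML_TYPE_IQ4_KS) and this translation unit provides the matching explicit instantiation, so each quant type's kernels compile in their own file and in parallel. Assuming the macro matches that pattern, the line above expands to roughly:

// Assumed expansion (sketch); the real macro is defined in mmq.cuh.
template void mul_mat_q_case<GGML_TYPE_IQ4_KS>(
    ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream);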
diff --git a/ggml/src/ggml-cuda/vecdotq.cuh b/ggml/src/ggml-cuda/vecdotq.cuh
index e9af29b9..cae5e04f 100644
--- a/ggml/src/ggml-cuda/vecdotq.cuh
+++ b/ggml/src/ggml-cuda/vecdotq.cuh
@@ -1131,6 +1131,18 @@ static __device__ __forceinline__ int2 get_int_from_table_16(const int & q4) {
return make_int2(*((const int *) &val0_8), *((const int *) &val1_8));
}
+static __device__ __forceinline__ int2 get_int_from_table_16(const int & q4, const int8_t * values) {
+ const int q0_32 = (q4 >> 0) & 0x0F0F0F0F;
+ const int8_t * q0_8 = (const int8_t *) &q0_32;
+ const char4 val0_8 = make_char4(values[q0_8[0]], values[q0_8[1]], values[q0_8[2]], values[q0_8[3]]);
+
+ const int q1_32 = (q4 >> 4) & 0x0F0F0F0F;
+ const int8_t * q1_8 = (const int8_t *) &q1_32;
+ const char4 val1_8 = make_char4(values[q1_8[0]], values[q1_8[1]], values[q1_8[2]], values[q1_8[3]]);
+
+ return make_int2(*((const int *) &val0_8), *((const int *) &val1_8));
+}
+
#define VDR_IQ4_NL_Q8_1_MMVQ 2
#define VDR_IQ4_NL_Q8_1_MMQ 4
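
The new overload above is the table-parameterized counterpart of the fixed-table get_int_from_table_16: each of the eight 4-bit quants packed into q4 is looked up in a caller-supplied 16-entry int8 table (for iq4_ks, one half of iq4k_values, chosen per sub-block by bit 0 of the scale byte). A scalar restatement of what it computes, as a reference sketch:

// Byte k (k = 0..3, little-endian) of the result:
//   out.x = values[(q4 >> (8*k + 0)) & 0xF]  -- low nibbles
//   out.y = values[(q4 >> (8*k + 4)) & 0xF]  -- high nibbles
static __device__ __forceinline__ int2 get_int_from_table_16_ref(const int q4, const int8_t * values) {
    int2 out;
    int8_t * x = (int8_t *) &out.x;
    int8_t * y = (int8_t *) &out.y;
    for (int k = 0; k < 4; ++k) {
        x[k] = values[(q4 >> (8*k + 0)) & 0xF];
        y[k] = values[(q4 >> (8*k + 4)) & 0xF];
    }
    return out;
}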