Diffstat (limited to 'ggml/src/ggml-cuda')
-rw-r--r--  ggml/src/ggml-cuda/common.cuh                                 |  14
-rw-r--r--  ggml/src/ggml-cuda/convert.cu                                 |  23
-rw-r--r--  ggml/src/ggml-cuda/iqk_mmvq.cu                                | 137
-rw-r--r--  ggml/src/ggml-cuda/iqk_mmvq.cuh                               |  15
-rw-r--r--  ggml/src/ggml-cuda/mmq.cu                                     |  12
-rw-r--r--  ggml/src/ggml-cuda/mmq.cuh                                    | 253
-rw-r--r--  ggml/src/ggml-cuda/mmvq.cu                                    |  12
-rw-r--r--  ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_kt.cu  |   5
-rw-r--r--  ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_kt.cu  |   5
-rw-r--r--  ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_kt.cu  |   5
10 files changed, 472 insertions(+), 9 deletions(-)
diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh
index 8f3d2a26..a0cdab28 100644
--- a/ggml/src/ggml-cuda/common.cuh
+++ b/ggml/src/ggml-cuda/common.cuh
@@ -579,6 +579,20 @@ struct ggml_cuda_type_traits<GGML_TYPE_IQ2_KT> {
};
template<>
+struct ggml_cuda_type_traits<GGML_TYPE_IQ3_KT> {
+ static constexpr int qk = QK_K;
+ static constexpr int qr = QR4_XS;
+ static constexpr int qi = QI4_XS;
+};
+
+template<>
+struct ggml_cuda_type_traits<GGML_TYPE_IQ4_KT> {
+ static constexpr int qk = QK_K;
+ static constexpr int qr = QR4_XS;
+ static constexpr int qi = QI4_XS;
+};
+
+template<>
struct ggml_cuda_type_traits<GGML_TYPE_IQ3_K> {
static constexpr int qk = QK_K;
static constexpr int qr = QR4_XS;
diff --git a/ggml/src/ggml-cuda/convert.cu b/ggml/src/ggml-cuda/convert.cu
index 01b7250e..b40079a3 100644
--- a/ggml/src/ggml-cuda/convert.cu
+++ b/ggml/src/ggml-cuda/convert.cu
@@ -340,6 +340,12 @@ inline __device__ int nearest_int(float fval) {
return (i & 0x007fffff) - 0x00400000;
}
+int __device__ __forceinline__ trellis_next_int(uint32_t& val) {
+ constexpr uint32_t ka = 0xCBAC1FED;
+ val = ka*val;
+ return ggml_cuda_dp4a(val & 0x3f3f3f3f, 0x01010101, -126);
+}
+
float __device__ __forceinline__ trellis_next(uint32_t& val) {
constexpr uint32_t ka = 89226354;
constexpr uint32_t kb = 64248484;
@@ -367,9 +373,9 @@ static __global__ void dequantize_block_iq2_kt(const void * __restrict__ vx, dst
dst_t * y = yy + ii*QK_K + 8*ib;
const uint16_t * ql = (const uint16_t *)x[i].ql;
uint32_t idx = ql[ib] + 4096;
- const float dl = scale * iq4k_values[((x[i].scales[(ib/4)%4] >> 4*(ib/16)) & 0xf)] * 31.75f * 1.05f;
+ const float dl = scale * iq4k_values[((x[i].scales[(ib/4)%4] >> 4*(ib/16)) & 0xf)] * 1.05f;
for (int j = 0; j < 8; ++j) {
- y[j] = dl * trellis_next(idx);
+ y[j] = dl * trellis_next_int(idx);
}
}
@@ -388,10 +394,10 @@ static __global__ void dequantize_block_iq3_kt(const void * __restrict__ vx, dst
dst_t * y = yy + ii*QK_K + 8*ib;
const uint16_t * ql = (const uint16_t *)x[i].ql;
uint32_t idx = ql[ib] + 4096;
- const float dl = scale * ((x[i].scales[(ib/4)%4] >> 4*(ib/16)) & 0xf) * 31.75f * 1.01f; //1.015f;
+ const float dl = scale * ((x[i].scales[(ib/4)%4] >> 4*(ib/16)) & 0xf) * 1.01f; //1.015f;
uint8_t mask = 1 << (ib/4);
for (int j = 0; j < 8; ++j) {
- y[j] = dl * std::abs(trellis_next(idx)) * (x[i].qh[(8*ib+j)%32] & mask ? -1.f : 1.f);
+ y[j] = dl * std::abs(trellis_next_int(idx)) * (x[i].qh[(8*ib+j)%32] & mask ? -1.f : 1.f);
}
}
@@ -401,9 +407,8 @@ static __global__ void dequantize_block_iq4_kt(const void * __restrict__ vx, dst
int64_t ii = blockIdx.x;
int64_t row = (QK_K * ii) / n_per_row;
const float * dptr = (const float *)((const char *)vx + row * row_size);
- float scale = dptr[0] * 31.75f * 1.01f;
- float row_av = dptr[1];
- const block_iq4_kt * x = (const block_iq4_kt *)(dptr + 2);
+ float scale = dptr[0] * 1.00f;
+ const block_iq4_kt * x = (const block_iq4_kt *)(dptr + 1);
const int64_t i = ii - (row*n_per_row)/QK_K;
constexpr int kNumGroups = 64;
@@ -423,8 +428,8 @@ static __global__ void dequantize_block_iq4_kt(const void * __restrict__ vx, dst
int ls = ((shb[ib32] & 0xff) >> 1) - 64;
const float dl = scale * ls;
for (int j = 0; j < 4; ++j) {
- y[j+0] = dl * trellis_next(idx1) + row_av;
- y[j+4] = dl * trellis_next(idx2) + row_av;
+ y[j+0] = dl * trellis_next_int(idx1);
+ y[j+4] = dl * trellis_next_int(idx2);
}
}
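
Note: the convert.cu changes above switch the KT dequantization from the old float generator (trellis_next, whose output carried the 31.75f factor that was folded into the removed scales) to an integer generator, and drop the separate row_av offset for IQ4_KT so the per-row header is a single float. For reference, a minimal host-side sketch of what trellis_next_int() computes (illustration only, not part of the patch):

    #include <cstdint>
    // Scalar equivalent of trellis_next_int(): advance the 32-bit trellis state
    // with the 0xCBAC1FED multiplier, then sum the four 6-bit byte fields and
    // re-center -- exactly what ggml_cuda_dp4a(val & 0x3f3f3f3f, 0x01010101, -126) does.
    static inline int trellis_next_int_ref(uint32_t & val) {
        val *= 0xCBAC1FEDu;
        const uint32_t m = val & 0x3f3f3f3fu;
        const int sum = (m & 0xff) + ((m >> 8) & 0xff) + ((m >> 16) & 0xff) + ((m >> 24) & 0xff);
        return sum - 126;   // result is in [-126, 126]
    }
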
diff --git a/ggml/src/ggml-cuda/iqk_mmvq.cu b/ggml/src/ggml-cuda/iqk_mmvq.cu
index e5a224b4..c19215de 100644
--- a/ggml/src/ggml-cuda/iqk_mmvq.cu
+++ b/ggml/src/ggml-cuda/iqk_mmvq.cu
@@ -433,6 +433,119 @@ __device__ __forceinline__ void vec_dot_iq4_ks_q8_1(
*result += dl * __low2float(bq8_1[ib32].ds) * sumi;
}
+__device__ __forceinline__ void vec_dot_iq4_kt_q8_1(
+ const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs, float * result) {
+
+ constexpr uint32_t ka = 0xCBAC1FED;
+ constexpr uint32_t km = 0x3f3f3f3f;
+
+ float scale = *(const float *)vbq;
+ const block_iq4_kt * bq4 = (const block_iq4_kt *)((const char *)vbq + sizeof(float)) + kbx;
+
+ // iqs is 0...28
+ const int ib32 = iqs/4; // Why iqs/4 ?
+ const int32_t * q8 = (const int *)bq8_1[ib32].qs;
+ //const int8_t * q8 = bq8_1[ib32].qs;
+ const int ls = (bq4->qs[ib32] & 0xff) >> 1;
+ const float dl = scale * (ls - 64);
+ const uint32_t idx0 = ((bq4->qs[ib32] & 1) << 15) + 4096;
+ auto ql = (const uint8_t *)(bq4->qs + 8);
+ auto qh = ql + 64;
+ ql += 8*ib32;
+ qh += 8*(ib32%4);
+ const int shift1 = 8 - 4*(ib32/4);
+ int sumi = 0;
+ for (int j = 0; j < 8; ++j) {
+ const uint32_t sh = bq4->qs[ib32] >> (8 + 3*j);
+ uint32_t val = ql[j] + ((qh[j] << shift1) & 0xf00) + ((sh & 7) << 12) + idx0;
+ int v4 = 0;
+ for (int k = 0; k < 4; ++k) {
+ val *= ka;
+ //int s = val & km;
+ //sumi += q8[4*j+k] * ggml_cuda_dp4a(s, 0x01010101, -126);
+ v4 |= (ggml_cuda_dp4a(val & km, 0x01010101, -126) & 0xff) << 8*k;
+ }
+ sumi = ggml_cuda_dp4a(v4, q8[j], sumi);
+ }
+ *result += dl * __low2float(bq8_1[ib32].ds) * sumi;
+}
+
+__device__ __forceinline__ void vec_dot_iq2_kt_q8_1(
+ const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs, float * result) {
+
+ constexpr uint32_t ka = 0xCBAC1FED;
+ constexpr uint32_t km = 0x3f3f3f3f;
+
+ float scale = *(const float *)vbq;
+ const block_iq2_kt * bq2 = (const block_iq2_kt *)((const char *)vbq + sizeof(float)) + kbx;
+
+ // iqs is 0...28
+ const int ib32 = iqs/4;
+ const int32_t * q8 = (const int *)bq8_1[ib32].qs;
+ const int ls = iq4k_values[(bq2->scales[ib32%4] >> 4*(ib32/4)) & 0xf];
+ const float dl = scale * ls * 1.05f;
+ auto ql = (const uint16_t *)bq2->ql;
+ int sumi = 0;
+ for (int j = 0; j < 4; ++j) {
+ uint32_t val = ql[4*ib32+j] + 4096;
+ int v4 = 0;
+ for (int k = 0; k < 4; ++k) {
+ val *= ka;
+ v4 |= (ggml_cuda_dp4a(val & km, 0x01010101, -126) & 0xff) << 8*k;
+ }
+ sumi = ggml_cuda_dp4a(v4, q8[2*j+0], sumi);
+ v4 = 0;
+ for (int k = 0; k < 4; ++k) {
+ val *= ka;
+ v4 |= (ggml_cuda_dp4a(val & km, 0x01010101, -126) & 0xff) << 8*k;
+ }
+ sumi = ggml_cuda_dp4a(v4, q8[2*j+1], sumi);
+ }
+ *result += dl * __low2float(bq8_1[ib32].ds) * sumi;
+}
+
+__device__ __forceinline__ void vec_dot_iq3_kt_q8_1(
+ const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs, float * result) {
+
+ constexpr uint32_t ka = 0xCBAC1FED;
+ constexpr uint32_t km = 0x3f3f3f3f;
+
+ float scale = *(const float *)vbq;
+ const block_iq3_kt * bq3 = (const block_iq3_kt *)((const char *)vbq + sizeof(float)) + kbx;
+
+ // iqs is 0...28
+ const int ib32 = iqs/4;
+ const int32_t * q8 = (const int *)bq8_1[ib32].qs;
+ const int ls = (bq3->scales[ib32%4] >> 4*(ib32/4)) & 0xf;
+ const float dl = scale * ls * 1.015f;
+ auto ql = (const uint16_t *)bq3->ql;
+ uint32_t mask = 0x01010101 << ib32;
+ const uint32_t * qh = (const uint32_t *)bq3->qh;
+ int sumi = 0;
+ for (int j = 0; j < 4; ++j) {
+ uint32_t val = ql[4*ib32+j] + 4096;
+ int v4 = 0;
+ for (int k = 0; k < 4; ++k) {
+ val *= ka;
+ int8_t q = std::abs(ggml_cuda_dp4a(val & km, 0x01010101, -126));
+ v4 |= q << 8*k;
+ }
+ uint32_t signs = __vcmpne4(qh[2*j+0] & mask, 0);
+ v4 = __vsub4(v4 ^ signs, signs);
+ sumi = ggml_cuda_dp4a(v4, q8[2*j+0], sumi);
+ v4 = 0;
+ for (int k = 0; k < 4; ++k) {
+ val *= ka;
+ int8_t q = std::abs(ggml_cuda_dp4a(val & km, 0x01010101, -126));
+ v4 |= q << 8*k;
+ }
+ signs = __vcmpne4(qh[2*j+1] & mask, 0);
+ v4 = __vsub4(v4 ^ signs, signs);
+ sumi = ggml_cuda_dp4a(v4, q8[2*j+1], sumi);
+ }
+ *result += dl * __low2float(bq8_1[ib32].ds) * sumi;
+}
+
#define VDR_IQ4_KSS_Q8_1_MMVQ 4
#define VDR_IQ4_KSS_Q8_1_MMQ 4
@@ -1217,6 +1330,30 @@ void mul_mat_vec_iq4_ks_q8_1_cuda(
iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ4_KS, VDR_IQ4_KS_Q8_1_MMVQ, vec_dot_iq4_ks_q8_1>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream);
}
+void mul_mat_vec_iq4_kt_q8_1_cuda(
+ const void * vx, const void * vy, float * dst, const char * ids_data,
+ const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
+ const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) {
+
+ iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ4_KT, VDR_IQ4_KS_Q8_1_MMVQ, vec_dot_iq4_kt_q8_1>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream);
+}
+
+void mul_mat_vec_iq2_kt_q8_1_cuda(
+ const void * vx, const void * vy, float * dst, const char * ids_data,
+ const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
+ const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) {
+
+ iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ2_KT, VDR_IQ4_KS_Q8_1_MMVQ, vec_dot_iq2_kt_q8_1>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream);
+}
+
+void mul_mat_vec_iq3_kt_q8_1_cuda(
+ const void * vx, const void * vy, float * dst, const char * ids_data,
+ const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
+ const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) {
+
+ iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ3_KT, VDR_IQ4_KS_Q8_1_MMVQ, vec_dot_iq3_kt_q8_1>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream);
+}
+
void mul_mat_vec_iq4_kss_q8_1_cuda(
const void * vx, const void * vy, float * dst, const char * ids_data,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
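
The vec_dot_*_kt_q8_1 kernels added above keep the inner loop in integer arithmetic: each group of four consecutive trellis weights is generated, truncated to its low byte, and packed into one 32-bit word, so a single dp4a accumulates it against four int8 activations from the Q8_1 block. A scalar sketch of that packed accumulation (illustration only; dot4_i8 is a hypothetical helper, not part of the patch):

    // Scalar equivalent of sumi = ggml_cuda_dp4a(packed_w, packed_q, sumi):
    // each byte of packed_w is an int8 trellis weight in [-126, 126],
    // each byte of packed_q is an int8 activation from the Q8_1 block.
    static inline int dot4_i8(int packed_w, int packed_q, int acc) {
        for (int k = 0; k < 4; ++k) {
            const int8_t w = (int8_t)(packed_w >> 8*k);
            const int8_t q = (int8_t)(packed_q >> 8*k);
            acc += (int)w * (int)q;
        }
        return acc;
    }

The floating-point work is deferred to the single *result += dl * __low2float(bq8_1[ib32].ds) * sumi at the end of each kernel.
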
diff --git a/ggml/src/ggml-cuda/iqk_mmvq.cuh b/ggml/src/ggml-cuda/iqk_mmvq.cuh
index 17bf5ad2..e7c6e1d2 100644
--- a/ggml/src/ggml-cuda/iqk_mmvq.cuh
+++ b/ggml/src/ggml-cuda/iqk_mmvq.cuh
@@ -100,3 +100,18 @@ void mul_mat_vec_iq1_m_r4_q8_1_cuda(
const void * vx, const void * vy, float * dst, const char * ids_data,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0, cudaStream_t stream);
+
+void mul_mat_vec_iq2_kt_q8_1_cuda(
+ const void * vx, const void * vy, float * dst, const char * ids_data,
+ const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
+ const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0, cudaStream_t stream);
+
+void mul_mat_vec_iq3_kt_q8_1_cuda(
+ const void * vx, const void * vy, float * dst, const char * ids_data,
+ const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
+ const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0, cudaStream_t stream);
+
+void mul_mat_vec_iq4_kt_q8_1_cuda(
+ const void * vx, const void * vy, float * dst, const char * ids_data,
+ const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
+ const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0, cudaStream_t stream);
diff --git a/ggml/src/ggml-cuda/mmq.cu b/ggml/src/ggml-cuda/mmq.cu
index a13be11b..9103e3f1 100644
--- a/ggml/src/ggml-cuda/mmq.cu
+++ b/ggml/src/ggml-cuda/mmq.cu
@@ -100,6 +100,15 @@ void ggml_cuda_op_mul_mat_q(
case GGML_TYPE_IQ4_KS_R4:
mul_mat_q_case<GGML_TYPE_IQ4_KS_R4>(ctx, args, stream);
break;
+ case GGML_TYPE_IQ4_KT:
+ mul_mat_q_case<GGML_TYPE_IQ4_KT>(ctx, args, stream);
+ break;
+ case GGML_TYPE_IQ2_KT:
+ mul_mat_q_case<GGML_TYPE_IQ2_KT>(ctx, args, stream);
+ break;
+ case GGML_TYPE_IQ3_KT:
+ mul_mat_q_case<GGML_TYPE_IQ3_KT>(ctx, args, stream);
+ break;
case GGML_TYPE_IQ5_KS:
mul_mat_q_case<GGML_TYPE_IQ5_KS>(ctx, args, stream);
break;
@@ -172,6 +181,9 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
case GGML_TYPE_IQ4_K:
case GGML_TYPE_IQ5_K:
case GGML_TYPE_IQ6_K:
+ case GGML_TYPE_IQ2_KT:
+ case GGML_TYPE_IQ3_KT:
+ case GGML_TYPE_IQ4_KT:
mmq_supported = true;
break;
default:
diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh
index 608de8f0..c6e6a365 100644
--- a/ggml/src/ggml-cuda/mmq.cuh
+++ b/ggml/src/ggml-cuda/mmq.cuh
@@ -93,6 +93,9 @@ static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) {
case GGML_TYPE_IQ5_KS:
case GGML_TYPE_IQ5_KS_R4:
case GGML_TYPE_IQ6_K:
+ case GGML_TYPE_IQ2_KT:
+ case GGML_TYPE_IQ3_KT:
+ case GGML_TYPE_IQ4_KT:
return MMQ_Q8_1_DS_LAYOUT_D4;
default:
GGML_ABORT("fatal error");
@@ -202,6 +205,9 @@ static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml
case GGML_TYPE_IQ4_K : return MMQ_DP4A_TXS_Q8_0_16;
case GGML_TYPE_IQ5_K : return MMQ_DP4A_TXS_Q8_0_16;
case GGML_TYPE_IQ6_K : return MMQ_DP4A_TXS_Q8_0_16;
+ case GGML_TYPE_IQ2_KT : return MMQ_DP4A_TXS_Q8_0;
+ case GGML_TYPE_IQ3_KT : return MMQ_DP4A_TXS_Q8_0;
+ case GGML_TYPE_IQ4_KT : return MMQ_DP4A_TXS_Q8_0;
default : return tile_x_sizes{0, 0, 0};
}
}
@@ -250,6 +256,9 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
case GGML_TYPE_IQ4_K : return MMQ_MMA_TILE_X_K_Q3_K;
case GGML_TYPE_IQ5_K : return MMQ_MMA_TILE_X_K_Q3_K;
case GGML_TYPE_IQ6_K : return MMQ_MMA_TILE_X_K_Q3_K;
+ case GGML_TYPE_IQ2_KT : return MMQ_MMA_TILE_X_K_Q8_0;
+ case GGML_TYPE_IQ3_KT : return MMQ_MMA_TILE_X_K_Q8_0;
+ case GGML_TYPE_IQ4_KT : return MMQ_MMA_TILE_X_K_Q8_0;
default : return 0;
}
}
@@ -2790,6 +2799,226 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
}
+template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq4_kt(
+ const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+
+ constexpr uint32_t ka = 0xCBAC1FED;
+ constexpr uint32_t km = 0x3f3f3f3f;
+
+#ifdef INT8_MMA_AVAILABLE
+ int * x_qs = (int *) x_tile;
+ float * x_df = (float *) (x_qs + WARP_SIZE*2);
+#else
+ constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_XS, mmq_y);
+ int * x_qs = (int *) x_tile;
+ float * x_df = (float *) (x_qs + txs.qs);
+#endif // INT8_MMA_AVAILABLE
+
+ const int kqsx = threadIdx.x;
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
+ int i = i0 + threadIdx.y;
+
+ if (need_check) {
+ i = min(i, i_max);
+ }
+
+ const block_iq4_kt * bxi = (const block_iq4_kt *)(x + i*stride + sizeof(float)) + kbx0;
+
+ int ib32 = kqsx/4;
+ int j = kqsx%4;
+ const auto shb = bxi->qs;
+ const auto ql = (const uint8_t *)(shb + 8);
+ const auto qh = ql + 64;
+ const uint32_t sh = shb[ib32] >> (8 + 6*j);
+ uint32_t offset = 4096 + ((shb[ib32] & 1) << 15);
+ uint32_t val1 = offset + ql[8*ib32+2*j+0] + ((qh[8*(ib32%4)+2*j+0] << (8 - 4*(ib32/4))) & 0xf00) + ((sh & 7) << 12);
+ uint32_t val2 = offset + ql[8*ib32+2*j+1] + ((qh[8*(ib32%4)+2*j+1] << (8 - 4*(ib32/4))) & 0xf00) + ((sh & 56) << 9);
+ int2 v = {0, 0};
+ for (int k = 0; k < 4; ++k) {
+ val1 *= ka;
+ val2 *= ka;
+ v.x |= (ggml_cuda_dp4a(val1 & km, 0x01010101, -126) & 0xff) << 8*k;
+ v.y |= (ggml_cuda_dp4a(val2 & km, 0x01010101, -126) & 0xff) << 8*k;
+ }
+#ifdef INT8_MMA_AVAILABLE
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*ib32 + 2*j + 0] = v.x;
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*ib32 + 2*j + 1] = v.y;
+#else
+ x_qs[i*(2*WARP_SIZE + 1) + 8*ib32 + 2*j + 0] = v.x;
+ x_qs[i*(2*WARP_SIZE + 1) + 8*ib32 + 2*j + 1] = v.y;
+#endif // INT8_MMA_AVAILABLE
+ }
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) {
+ int i = i0 + threadIdx.y * 4 + threadIdx.x / (WARP_SIZE/4);
+
+ if (need_check) {
+ i = min(i, i_max);
+ }
+
+ const float * dptr = (const float *)(x + i*stride);
+ const block_iq4_kt * bxi = (const block_iq4_kt *)(dptr + 1) + kbx0;
+ const int ls = (bxi->qs[threadIdx.x % 8] & 0xff) >> 1;
+
+#ifdef INT8_MMA_AVAILABLE
+ x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + threadIdx.x % 8] = dptr[0] * (ls - 64);
+#else
+ x_df[i*(WARP_SIZE/4) + i/4 + threadIdx.x % 8] = dptr[0] * (ls - 64);
+#endif // INT8_MMA_AVAILABLE
+ }
+}
+
+template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq2_kt(
+ const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+
+ constexpr uint32_t ka = 0xCBAC1FED;
+ constexpr uint32_t km = 0x3f3f3f3f;
+
+#ifdef INT8_MMA_AVAILABLE
+ int * x_qs = (int *) x_tile;
+ float * x_df = (float *) (x_qs + WARP_SIZE*2);
+#else
+ constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_XS, mmq_y);
+ int * x_qs = (int *) x_tile;
+ float * x_df = (float *) (x_qs + txs.qs);
+#endif // INT8_MMA_AVAILABLE
+
+ const int kqsx = threadIdx.x;
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
+ int i = i0 + threadIdx.y;
+
+ if (need_check) {
+ i = min(i, i_max);
+ }
+
+ const block_iq2_kt * bxi = (const block_iq2_kt *)(x + i*stride + sizeof(float)) + kbx0;
+
+ int ib32 = kqsx/4;
+ int j = kqsx%4;
+ const auto ql = (const uint16_t *)bxi->ql;
+ uint32_t val = ql[4*ib32+j] + 4096;
+ int2 v = {0, 0};
+ for (int k = 0; k < 4; ++k) {
+ val *= ka;
+ v.x |= (ggml_cuda_dp4a(val & km, 0x01010101, -126) & 0xff) << 8*k;
+ }
+ for (int k = 0; k < 4; ++k) {
+ val *= ka;
+ v.y |= (ggml_cuda_dp4a(val & km, 0x01010101, -126) & 0xff) << 8*k;
+ }
+#ifdef INT8_MMA_AVAILABLE
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*ib32 + 2*j + 0] = v.x;
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*ib32 + 2*j + 1] = v.y;
+#else
+ x_qs[i*(2*WARP_SIZE + 1) + 8*ib32 + 2*j + 0] = v.x;
+ x_qs[i*(2*WARP_SIZE + 1) + 8*ib32 + 2*j + 1] = v.y;
+#endif // INT8_MMA_AVAILABLE
+ }
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) {
+ int i = i0 + threadIdx.y * 4 + threadIdx.x / (WARP_SIZE/4);
+
+ if (need_check) {
+ i = min(i, i_max);
+ }
+
+ const float * dptr = (const float *)(x + i*stride);
+ const float d = dptr[0] * 1.05f;
+ const block_iq2_kt * bxi = (const block_iq2_kt *)(dptr + 1) + kbx0;
+ int ib32 = threadIdx.x % 8;
+ const int ls = iq4k_values[(bxi->scales[ib32%4] >> 4*(ib32/4)) & 0xf];
+
+#ifdef INT8_MMA_AVAILABLE
+ x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + threadIdx.x % 8] = d * ls;
+#else
+ x_df[i*(WARP_SIZE/4) + i/4 + threadIdx.x % 8] = d * ls;
+#endif // INT8_MMA_AVAILABLE
+ }
+}
+
+template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq3_kt(
+ const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+
+ constexpr uint32_t ka = 0xCBAC1FED;
+ constexpr uint32_t km = 0x3f3f3f3f;
+
+#ifdef INT8_MMA_AVAILABLE
+ int * x_qs = (int *) x_tile;
+ float * x_df = (float *) (x_qs + WARP_SIZE*2);
+#else
+ constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_XS, mmq_y);
+ int * x_qs = (int *) x_tile;
+ float * x_df = (float *) (x_qs + txs.qs);
+#endif // INT8_MMA_AVAILABLE
+
+ const int kqsx = threadIdx.x;
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
+ int i = i0 + threadIdx.y;
+
+ if (need_check) {
+ i = min(i, i_max);
+ }
+
+ const block_iq3_kt * bxi = (const block_iq3_kt *)(x + i*stride + sizeof(float)) + kbx0;
+
+ int ib32 = kqsx/4;
+ int j = kqsx%4;
+ const auto ql = (const uint16_t *)bxi->ql;
+ const auto qh = (const uint32_t *)bxi->qh;
+ uint32_t mask = 0x01010101 << ib32;
+ uint32_t val = ql[4*ib32+j] + 4096;
+ int2 v = {0, 0};
+ for (int k = 0; k < 4; ++k) {
+ val *= ka;
+ v.x |= std::abs(ggml_cuda_dp4a(val & km, 0x01010101, -126)) << 8*k;
+ }
+ auto signs = __vcmpne4(qh[2*j+0] & mask, 0);
+ v.x = __vsub4(v.x ^ signs, signs);
+ for (int k = 0; k < 4; ++k) {
+ val *= ka;
+ v.y |= std::abs(ggml_cuda_dp4a(val & km, 0x01010101, -126)) << 8*k;
+ }
+ signs = __vcmpne4(qh[2*j+1] & mask, 0);
+ v.y = __vsub4(v.y ^ signs, signs);
+#ifdef INT8_MMA_AVAILABLE
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*ib32 + 2*j + 0] = v.x;
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*ib32 + 2*j + 1] = v.y;
+#else
+ x_qs[i*(2*WARP_SIZE + 1) + 8*ib32 + 2*j + 0] = v.x;
+ x_qs[i*(2*WARP_SIZE + 1) + 8*ib32 + 2*j + 1] = v.y;
+#endif // INT8_MMA_AVAILABLE
+ }
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) {
+ int i = i0 + threadIdx.y * 4 + threadIdx.x / (WARP_SIZE/4);
+
+ if (need_check) {
+ i = min(i, i_max);
+ }
+
+ const float * dptr = (const float *)(x + i*stride);
+ const float d = dptr[0] * 1.01f;
+ const block_iq3_kt * bxi = (const block_iq3_kt *)(dptr + 1) + kbx0;
+ int ib32 = threadIdx.x % 8;
+ const int ls = (bxi->scales[ib32%4] >> 4*(ib32/4)) & 0xf;
+
+#ifdef INT8_MMA_AVAILABLE
+ x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + threadIdx.x % 8] = d * ls;
+#else
+ x_df[i*(WARP_SIZE/4) + i/4 + threadIdx.x % 8] = d * ls;
+#endif // INT8_MMA_AVAILABLE
+ }
+}
+
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq5_ks_r4(
const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
@@ -3383,6 +3612,27 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ4_KS_R4> {
};
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ4_KT> {
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq4_kt<mmq_y, nwarps, need_check>;
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_D4>;
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
+};
+
+template <int mmq_x, int mmq_y, int nwarps, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ2_KT> {
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq2_kt<mmq_y, nwarps, need_check>;
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_D4>;
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
+};
+
+template <int mmq_x, int mmq_y, int nwarps, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ3_KT> {
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq3_kt<mmq_y, nwarps, need_check>;
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_D4>;
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
+};
+
+template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ5_KS> {
static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq5_ks<mmq_y, nwarps, need_check>;
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_D4>;
@@ -3843,6 +4093,9 @@ extern DECL_MMQ_CASE(GGML_TYPE_IQ5_K);
extern DECL_MMQ_CASE(GGML_TYPE_IQ5_KS);
extern DECL_MMQ_CASE(GGML_TYPE_IQ6_K);
extern DECL_MMQ_CASE(GGML_TYPE_IQ1_S_R4);
+extern DECL_MMQ_CASE(GGML_TYPE_IQ4_KT);
+extern DECL_MMQ_CASE(GGML_TYPE_IQ2_KT);
+extern DECL_MMQ_CASE(GGML_TYPE_IQ3_KT);
// -------------------------------------------------------------------------------------------------------------------------
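
Each load_tiles_*_kt loader above follows the same two-step pattern: every lane decodes one group of eight trellis weights into two packed 32-bit words of the shared x_qs tile, and a second, coarser loop writes one float per 32-weight sub-block into x_df. As a small example of the second step, the IQ2_KT sub-block scale is reconstructed as below (illustration only; the free function is a hypothetical restatement of the in-kernel code):

    // Mirrors the scale computation in load_tiles_iq2_kt:
    // dptr[0] is the per-row scale stored ahead of the blocks,
    // iq4k_values is the signed 4-bit lookup table shared with the IQ*_K types.
    __device__ float iq2_kt_subblock_scale(const float * dptr, const block_iq2_kt * blk, int ib32) {
        const float d  = dptr[0] * 1.05f;
        const int   ls = iq4k_values[(blk->scales[ib32 % 4] >> 4*(ib32/4)) & 0xf];
        return d * ls;
    }
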
diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu
index 73caabab..6412be30 100644
--- a/ggml/src/ggml-cuda/mmvq.cu
+++ b/ggml/src/ggml-cuda/mmvq.cu
@@ -527,6 +527,15 @@ static void ggml_cuda_op_mul_mat_vec_q_impl(ggml_backend_cuda_context & ctx, ggm
case GGML_TYPE_IQ4_KSS:
mul_mat_vec_iq4_kss_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream);
break;
+ case GGML_TYPE_IQ2_KT:
+ mul_mat_vec_iq2_kt_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream);
+ break;
+ case GGML_TYPE_IQ3_KT:
+ mul_mat_vec_iq3_kt_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream);
+ break;
+ case GGML_TYPE_IQ4_KT:
+ mul_mat_vec_iq4_kt_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream);
+ break;
case GGML_TYPE_IQ2_KS:
mul_mat_vec_iq2_ks_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream);
break;
@@ -687,6 +696,9 @@ bool ggml_cuda_mmvq_type_supported(ggml_type src0_type) {
case GGML_TYPE_IQ5_KS_R4:
case GGML_TYPE_IQ1_S_R4:
case GGML_TYPE_IQ1_M_R4:
+ case GGML_TYPE_IQ2_KT:
+ case GGML_TYPE_IQ3_KT:
+ case GGML_TYPE_IQ4_KT:
return true;
default:
return false;
diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_kt.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_kt.cu
new file mode 100644
index 00000000..2d48f077
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_kt.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmq.cuh"
+
+DECL_MMQ_CASE(GGML_TYPE_IQ2_KT);
diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_kt.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_kt.cu
new file mode 100644
index 00000000..978bc6ca
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_kt.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmq.cuh"
+
+DECL_MMQ_CASE(GGML_TYPE_IQ3_KT);
diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_kt.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_kt.cu
new file mode 100644
index 00000000..4590f919
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_kt.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmq.cuh"
+
+DECL_MMQ_CASE(GGML_TYPE_IQ4_KT);