summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorKawrakow <iwankawrakow@gmail.com>2025-06-05 08:31:20 +0300
committerGitHub <noreply@github.com>2025-06-05 08:31:20 +0300
commit8ffad187abbb93b74db8ef813b6fdceec80e02b0 (patch)
tree2c078dfcbe2dd36b46675651ae5d91abd41641f4
parent0b10f7418f7315ef90e35da49e0c053b395fd528 (diff)
MMQ implementation for IQ4_KS_R4 and IQ5_KS_R4 (#493)
* MMQ for iq4_ks_r4 * MMQ for iq5_ks_r4 * Add forgotten file * Another forgotten file --------- Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
-rw-r--r--ggml/src/ggml-cuda/common.cuh14
-rw-r--r--ggml/src/ggml-cuda/iqk_mmvq.cu15
-rw-r--r--ggml/src/ggml-cuda/mmq.cu8
-rw-r--r--ggml/src/ggml-cuda/mmq.cuh163
-rw-r--r--ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_ks_r4.cu5
-rw-r--r--ggml/src/ggml-cuda/template-instances/mmq-instance-iq5_ks_r4.cu5
6 files changed, 167 insertions, 43 deletions
diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh
index bc7fadb0..8f3d2a26 100644
--- a/ggml/src/ggml-cuda/common.cuh
+++ b/ggml/src/ggml-cuda/common.cuh
@@ -600,6 +600,13 @@ struct ggml_cuda_type_traits<GGML_TYPE_IQ4_KS> {
};
template<>
+struct ggml_cuda_type_traits<GGML_TYPE_IQ4_KS_R4> {
+ static constexpr int qk = QK_K;
+ static constexpr int qr = QR4_XS;
+ static constexpr int qi = QI4_XS;
+};
+
+template<>
struct ggml_cuda_type_traits<GGML_TYPE_IQ4_KSS> {
static constexpr int qk = QK_K;
static constexpr int qr = QR4_XS;
@@ -621,6 +628,13 @@ struct ggml_cuda_type_traits<GGML_TYPE_IQ5_KS> {
};
template<>
+struct ggml_cuda_type_traits<GGML_TYPE_IQ5_KS_R4> {
+ static constexpr int qk = QK_K;
+ static constexpr int qr = QR5_XS;
+ static constexpr int qi = QI5_XS;
+};
+
+template<>
struct ggml_cuda_type_traits<GGML_TYPE_IQ6_K> {
static constexpr int qk = QK_K;
static constexpr int qr = QR6_XS;
diff --git a/ggml/src/ggml-cuda/iqk_mmvq.cu b/ggml/src/ggml-cuda/iqk_mmvq.cu
index 2340b54a..747af5a7 100644
--- a/ggml/src/ggml-cuda/iqk_mmvq.cu
+++ b/ggml/src/ggml-cuda/iqk_mmvq.cu
@@ -36,21 +36,6 @@ struct ggml_cuda_type_traits<GGML_TYPE_IQ5_K_R4> {
static constexpr int qi = QI5_XS;
};
-template<>
-struct ggml_cuda_type_traits<GGML_TYPE_IQ4_KS_R4> {
- static constexpr int qk = QK_K;
- static constexpr int qr = QR4_XS;
- static constexpr int qi = QI4_XS;
-};
-
-template<>
-struct ggml_cuda_type_traits<GGML_TYPE_IQ5_KS_R4> {
- static constexpr int qk = QK_K;
- static constexpr int qr = QR5_XS;
- static constexpr int qi = QI5_XS;
-};
-
-
// Reminder:
// constexpr int qk = ggml_cuda_type_traits<type>::qk;
// constexpr int qi = ggml_cuda_type_traits<type>::qi;
diff --git a/ggml/src/ggml-cuda/mmq.cu b/ggml/src/ggml-cuda/mmq.cu
index 60c2037f..a13be11b 100644
--- a/ggml/src/ggml-cuda/mmq.cu
+++ b/ggml/src/ggml-cuda/mmq.cu
@@ -97,9 +97,15 @@ void ggml_cuda_op_mul_mat_q(
case GGML_TYPE_IQ4_KS:
mul_mat_q_case<GGML_TYPE_IQ4_KS>(ctx, args, stream);
break;
+ case GGML_TYPE_IQ4_KS_R4:
+ mul_mat_q_case<GGML_TYPE_IQ4_KS_R4>(ctx, args, stream);
+ break;
case GGML_TYPE_IQ5_KS:
mul_mat_q_case<GGML_TYPE_IQ5_KS>(ctx, args, stream);
break;
+ case GGML_TYPE_IQ5_KS_R4:
+ mul_mat_q_case<GGML_TYPE_IQ5_KS_R4>(ctx, args, stream);
+ break;
case GGML_TYPE_IQ2_KS:
mul_mat_q_case<GGML_TYPE_IQ2_KS>(ctx, args, stream);
break;
@@ -157,7 +163,9 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_IQ4_KS:
+ case GGML_TYPE_IQ4_KS_R4:
case GGML_TYPE_IQ5_KS:
+ case GGML_TYPE_IQ5_KS_R4:
case GGML_TYPE_IQ2_KS:
case GGML_TYPE_IQ2_K:
case GGML_TYPE_IQ3_K:
diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh
index c98fa561..608de8f0 100644
--- a/ggml/src/ggml-cuda/mmq.cuh
+++ b/ggml/src/ggml-cuda/mmq.cuh
@@ -87,9 +87,11 @@ static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) {
case GGML_TYPE_IQ2_K:
case GGML_TYPE_IQ3_K:
case GGML_TYPE_IQ4_KS:
+ case GGML_TYPE_IQ4_KS_R4:
case GGML_TYPE_IQ4_K:
case GGML_TYPE_IQ5_K:
case GGML_TYPE_IQ5_KS:
+ case GGML_TYPE_IQ5_KS_R4:
case GGML_TYPE_IQ6_K:
return MMQ_Q8_1_DS_LAYOUT_D4;
default:
@@ -191,7 +193,9 @@ static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml
case GGML_TYPE_IQ4_XS : return MMQ_DP4A_TXS_Q8_0;
case GGML_TYPE_IQ4_NL : return MMQ_DP4A_TXS_Q8_0;
case GGML_TYPE_IQ4_KS : return MMQ_DP4A_TXS_Q8_0;
+ case GGML_TYPE_IQ4_KS_R4 : return MMQ_DP4A_TXS_Q8_0;
case GGML_TYPE_IQ5_KS : return MMQ_DP4A_TXS_Q8_0;
+ case GGML_TYPE_IQ5_KS_R4 : return MMQ_DP4A_TXS_Q8_0;
case GGML_TYPE_IQ2_KS : return MMQ_DP4A_TXS_Q8_0;
case GGML_TYPE_IQ2_K : return MMQ_DP4A_TXS_Q8_0_16;
case GGML_TYPE_IQ3_K : return MMQ_DP4A_TXS_Q8_0_16;
@@ -237,7 +241,9 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
case GGML_TYPE_IQ4_XS : return MMQ_MMA_TILE_X_K_Q8_0;
case GGML_TYPE_IQ4_NL : return MMQ_MMA_TILE_X_K_Q8_0;
case GGML_TYPE_IQ4_KS : return MMQ_MMA_TILE_X_K_Q8_0;
+ case GGML_TYPE_IQ4_KS_R4 : return MMQ_MMA_TILE_X_K_Q8_0;
case GGML_TYPE_IQ5_KS : return MMQ_MMA_TILE_X_K_Q8_0;
+ case GGML_TYPE_IQ5_KS_R4 : return MMQ_MMA_TILE_X_K_Q8_0;
case GGML_TYPE_IQ2_KS : return MMQ_MMA_TILE_X_K_Q8_0;
case GGML_TYPE_IQ2_K : return MMQ_MMA_TILE_X_K_Q3_K;
case GGML_TYPE_IQ3_K : return MMQ_MMA_TILE_X_K_Q3_K;
@@ -2732,6 +2738,119 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
}
}
+template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq4_ks_r4(
+ const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+
+#ifdef INT8_MMA_AVAILABLE
+ int * x_qs = (int *) x_tile;
+ float * x_df = (float *) (x_qs + WARP_SIZE*2);
+#else
+ constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_KS_R4, mmq_y);
+ int * x_qs = (int *) x_tile;
+ float * x_df = (float *) (x_qs + txs.qs);
+#endif // INT8_MMA_AVAILABLE
+
+ const int kqsx = threadIdx.x/4;
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += 4*nwarps) {
+ int i = i0 + 4*threadIdx.y + threadIdx.x%4;
+
+ if (need_check) {
+ i = min(i, i_max);
+ }
+ int i4 = i/4;
+ int ir = i%4;
+
+ const float * dptr = (const float *)(x + 4*i4*stride);
+ const block_iq4_ks_r4 * bxi = (const block_iq4_ks_r4 *)(dptr + 4) + kbx0;
+
+ const int ls = (bxi->scales[4*kqsx + ir] & 254) - 127;
+ auto values = iq4k_values + ((bxi->scales[4*kqsx+ir] & 1) << 4);
+#pragma unroll
+ for (int j = 0; j < 4; ++j) {
+ const int q4 = get_int_b4(bxi->qs, 16*kqsx+4*j+ir);
+ const int2 v = get_int_from_table_16(q4, values);
+ const int k0 = 8*kqsx + 4*(j%2) + j/2;
+#ifdef INT8_MMA_AVAILABLE
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x;
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 2] = v.y;
+#else
+ x_qs[i*(2*WARP_SIZE + 1) + k0 + 0] = v.x;
+ x_qs[i*(2*WARP_SIZE + 1) + k0 + 2] = v.y;
+#endif // INT8_MMA_AVAILABLE
+ }
+#ifdef INT8_MMA_AVAILABLE
+ x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = dptr[ir] * ls;
+#else
+ x_df[i*(WARP_SIZE/4) + i/4 + kqsx] = dptr[ir] * ls;
+#endif // INT8_MMA_AVAILABLE
+
+ }
+
+}
+
+template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq5_ks_r4(
+ const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+
+#ifdef INT8_MMA_AVAILABLE
+ int * x_qs = (int *) x_tile;
+ float * x_df = (float *) (x_qs + WARP_SIZE*2);
+#else
+ constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ5_KS_R4, mmq_y);
+ int * x_qs = (int *) x_tile;
+ float * x_df = (float *) (x_qs + txs.qs);
+#endif // INT8_MMA_AVAILABLE
+
+ const int kqsx = threadIdx.x/4;
+
+ uint32_t aux32[2];
+ const uint8_t * aux8 = (const uint8_t *)aux32;
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += 4*nwarps) {
+ int i = i0 + 4*threadIdx.y + threadIdx.x%4;
+
+ if (need_check) {
+ i = min(i, i_max);
+ }
+ int i4 = i/4;
+ int ir = i%4;
+
+ const float * dptr = (const float *)(x + 4*i4*stride);
+ const block_iq5_ks_r4 * bxi = (const block_iq5_ks_r4 *)(dptr + 4) + kbx0;
+
+ const int ls = (bxi->scales[4*kqsx + ir] & 254) - 127;
+ auto values = iq5nl_values + ((bxi->scales[4*kqsx+ir] & 1) << 5);
+
+ int qh = *((const int *)bxi->qh + 4*kqsx + ir);
+ const int * ql = (const int *)bxi->qs + 16*kqsx + ir;
+#pragma unroll
+ for (int j = 0; j < 4; ++j) {
+ aux32[0] = ((ql[4*j] >> 0) & 0x0f0f0f0f) | ((qh << 4) & 0x10101010);
+ aux32[1] = ((ql[4*j] >> 4) & 0x0f0f0f0f) | ((qh << 3) & 0x10101010);
+ qh >>= 2;
+ const char4 val0 = make_char4(values[aux8[0]], values[aux8[1]], values[aux8[2]], values[aux8[3]]);
+ const char4 val1 = make_char4(values[aux8[4]], values[aux8[5]], values[aux8[6]], values[aux8[7]]);
+ const int k0 = 8*kqsx + 4*(j%2) + j/2;
+#ifdef INT8_MMA_AVAILABLE
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = *(const int *)&val0;
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 2] = *(const int *)&val1;
+#else
+ x_qs[i*(2*WARP_SIZE + 1) + k0 + 0] = *(const int *)&val0;
+ x_qs[i*(2*WARP_SIZE + 1) + k0 + 2] = *(const int *)&val1;
+#endif // INT8_MMA_AVAILABLE
+ }
+#ifdef INT8_MMA_AVAILABLE
+ x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = dptr[ir] * ls;
+#else
+ x_df[i*(WARP_SIZE/4) + i/4 + kqsx] = dptr[ir] * ls;
+#endif // INT8_MMA_AVAILABLE
+
+ }
+
+}
+
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq4_k(
const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
@@ -3069,7 +3188,6 @@ struct mmq_type_traits;
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_0> {
- static constexpr int vdr = VDR_Q4_0_Q8_1_MMQ;
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_0<mmq_y, nwarps, need_check>;
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_DS4>;
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q4_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
@@ -3077,7 +3195,6 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_0> {
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_1> {
- static constexpr int vdr = VDR_Q4_1_Q8_1_MMQ;
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_1<mmq_y, nwarps, need_check>;
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y, nwarps>;
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q4_1_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
@@ -3085,7 +3202,6 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_1> {
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_0> {
- static constexpr int vdr = VDR_Q5_0_Q8_1_MMQ;
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_0<mmq_y, nwarps, need_check>;
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_D4>;
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
@@ -3093,7 +3209,6 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_0> {
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_1> {
- static constexpr int vdr = VDR_Q5_1_Q8_1_MMQ;
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_1<mmq_y, nwarps, need_check>;
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y, nwarps>;
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_1_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
@@ -3101,7 +3216,6 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_1> {
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q6_0> {
- static constexpr int vdr = VDR_Q6_0_Q8_1_MMQ;
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q6_0<mmq_y, nwarps, need_check>;
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_D4>;
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
@@ -3109,7 +3223,6 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q6_0> {
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q8_0> {
- static constexpr int vdr = VDR_Q8_0_Q8_1_MMQ;
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q8_0<mmq_y, nwarps, need_check>;
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_D4>;
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
@@ -3117,7 +3230,6 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q8_0> {
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q2_K> {
- static constexpr int vdr = VDR_Q2_K_Q8_1_MMQ;
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q2_K<mmq_y, nwarps, need_check>;
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q2_K_q8_1_mma<mmq_x, mmq_y, nwarps>;
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q2_K_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
@@ -3125,7 +3237,6 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q2_K> {
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q3_K> {
- static constexpr int vdr = VDR_Q3_K_Q8_1_MMQ;
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q3_K<mmq_y, nwarps, need_check>;
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma<mmq_x, mmq_y, nwarps>;
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q3_K_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
@@ -3133,7 +3244,6 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q3_K> {
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_K> {
- static constexpr int vdr = VDR_Q4_K_Q8_1_MMQ;
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_K<mmq_y, nwarps, need_check>;
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y, nwarps>;
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q4_K_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
@@ -3141,7 +3251,6 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_K> {
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_K> {
- static constexpr int vdr = VDR_Q5_K_Q8_1_MMQ;
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_K<mmq_y, nwarps, need_check>;
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y, nwarps>;
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q5_K_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
@@ -3149,7 +3258,6 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_K> {
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q6_K> {
- static constexpr int vdr = VDR_Q6_K_Q8_1_MMQ;
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q6_K<mmq_y, nwarps, need_check>;
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q6_K_q8_1_mma<mmq_x, mmq_y, nwarps>;
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q6_K_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
@@ -3157,7 +3265,6 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q6_K> {
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ2_XXS> {
- static constexpr int vdr = VDR_IQ2_XXS_Q8_1_MMQ;
static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq2_xxs<mmq_y, nwarps, need_check>;
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_D4>;
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
@@ -3165,7 +3272,6 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ2_XXS> {
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ2_XS> {
- static constexpr int vdr = VDR_IQ2_XS_Q8_1_MMQ;
static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq2_xs<mmq_y, nwarps, need_check>;
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma<mmq_x, mmq_y, nwarps>;
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
@@ -3173,7 +3279,6 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ2_XS> {
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ2_S> {
- static constexpr int vdr = VDR_IQ2_S_Q8_1_MMQ;
static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq2_s<mmq_y, nwarps, need_check>;
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma<mmq_x, mmq_y, nwarps>;
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
@@ -3181,7 +3286,6 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ2_S> {
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ3_XXS> {
- static constexpr int vdr = VDR_IQ3_XXS_Q8_1_MMQ;
static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq3_xxs<mmq_y, nwarps, need_check>;
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_D4>;
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
@@ -3189,7 +3293,6 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ3_XXS> {
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ3_S> {
- static constexpr int vdr = VDR_IQ3_S_Q8_1_MMQ;
static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq3_s<mmq_y, nwarps, need_check>;
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_D4>;
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
@@ -3197,7 +3300,6 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ3_S> {
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ1_S> {
- static constexpr int vdr = VDR_IQ1_S_Q8_1_MMQ;
static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq1_s<mmq_y, nwarps, need_check>;
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y, nwarps>;
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_1_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
@@ -3205,7 +3307,6 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ1_S> {
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ1_S_R4> {
- static constexpr int vdr = VDR_Q4_0_Q8_1_MMQ;
static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq1_s_r4<mmq_y, nwarps, need_check>;
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_DS4>;
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q4_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
@@ -3213,7 +3314,6 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ1_S_R4> {
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ4_NL> {
- static constexpr int vdr = VDR_IQ4_NL_Q8_1_MMQ;
static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq4_nl<mmq_y, nwarps, need_check>;
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_D4>;
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
@@ -3221,7 +3321,6 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ4_NL> {
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ4_XS> {
- static constexpr int vdr = VDR_IQ4_XS_Q8_1_MMQ;
static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq4_xs<mmq_y, nwarps, need_check>;
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_D4>;
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
@@ -3229,7 +3328,6 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ4_XS> {
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ2_K> {
- static constexpr int vdr = VDR_IQ2_XS_Q8_1_MMQ;
static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq2_k<mmq_y, nwarps, need_check>;
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma<mmq_x, mmq_y, nwarps>;
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
@@ -3237,7 +3335,6 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ2_K> {
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ3_K> {
- static constexpr int vdr = VDR_IQ2_XS_Q8_1_MMQ;
static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq3_k<mmq_y, nwarps, need_check>;
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma<mmq_x, mmq_y, nwarps>;
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
@@ -3245,7 +3342,6 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ3_K> {
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ4_K> {
- static constexpr int vdr = VDR_IQ2_XS_Q8_1_MMQ;
static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq4_k<mmq_y, nwarps, need_check>;
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma<mmq_x, mmq_y, nwarps>;
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
@@ -3253,7 +3349,6 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ4_K> {
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ5_K> {
- static constexpr int vdr = VDR_IQ2_XS_Q8_1_MMQ;
static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq5_k<mmq_y, nwarps, need_check>;
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma<mmq_x, mmq_y, nwarps>;
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
@@ -3261,7 +3356,6 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ5_K> {
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ6_K> {
- static constexpr int vdr = VDR_IQ2_XS_Q8_1_MMQ;
static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq6_k<mmq_y, nwarps, need_check>;
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma<mmq_x, mmq_y, nwarps>;
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
@@ -3269,7 +3363,6 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ6_K> {
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ2_KS> {
- static constexpr int vdr = VDR_IQ4_XS_Q8_1_MMQ;
static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq2_ks<mmq_y, nwarps, need_check>;
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_D4>;
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
@@ -3277,20 +3370,32 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ2_KS> {
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ4_KS> {
- static constexpr int vdr = VDR_IQ4_XS_Q8_1_MMQ;
static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq4_ks<mmq_y, nwarps, need_check>;
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_D4>;
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
};
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ4_KS_R4> {
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq4_ks_r4<mmq_y, nwarps, need_check>;
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_D4>;
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
+};
+
+template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ5_KS> {
- static constexpr int vdr = VDR_IQ4_XS_Q8_1_MMQ;
static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq5_ks<mmq_y, nwarps, need_check>;
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_D4>;
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
};
+template <int mmq_x, int mmq_y, int nwarps, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ5_KS_R4> {
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq5_ks_r4<mmq_y, nwarps, need_check>;
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_D4>;
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
+};
+
template <ggml_type type, int mmq_x, int nwarps, bool need_check, bool fixup>
static __device__ void mul_mat_q_process_tile(
const char * __restrict__ x, const char * __restrict__ yc, float * __restrict__ dst, float * __restrict__ tmp_fixup,
@@ -3728,6 +3833,8 @@ extern DECL_MMQ_CASE(GGML_TYPE_IQ1_S);
extern DECL_MMQ_CASE(GGML_TYPE_IQ4_NL);
extern DECL_MMQ_CASE(GGML_TYPE_IQ4_XS);
extern DECL_MMQ_CASE(GGML_TYPE_IQ4_KS);
+extern DECL_MMQ_CASE(GGML_TYPE_IQ4_KS_R4);
+extern DECL_MMQ_CASE(GGML_TYPE_IQ5_KS_R4);
extern DECL_MMQ_CASE(GGML_TYPE_IQ2_KS);
extern DECL_MMQ_CASE(GGML_TYPE_IQ2_K);
extern DECL_MMQ_CASE(GGML_TYPE_IQ3_K);
diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_ks_r4.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_ks_r4.cu
new file mode 100644
index 00000000..d0f08ce8
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_ks_r4.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmq.cuh"
+
+DECL_MMQ_CASE(GGML_TYPE_IQ4_KS_R4);
diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-iq5_ks_r4.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq5_ks_r4.cu
new file mode 100644
index 00000000..0cc77dc0
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq5_ks_r4.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmq.cuh"
+
+DECL_MMQ_CASE(GGML_TYPE_IQ5_KS_R4);