diff options
author | Kawrakow <iwankawrakow@gmail.com> | 2025-06-05 08:31:20 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2025-06-05 08:31:20 +0300 |
commit | 8ffad187abbb93b74db8ef813b6fdceec80e02b0 (patch) | |
tree | 2c078dfcbe2dd36b46675651ae5d91abd41641f4 | |
parent | 0b10f7418f7315ef90e35da49e0c053b395fd528 (diff) |
MMQ implementation for IQ4_KS_R4 and IQ5_KS_R4 (#493)
* MMQ for iq4_ks_r4
* MMQ for iq5_ks_r4
* Add forgotten file
* Another forgotten file
---------
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
-rw-r--r-- | ggml/src/ggml-cuda/common.cuh | 14 | ||||
-rw-r--r-- | ggml/src/ggml-cuda/iqk_mmvq.cu | 15 | ||||
-rw-r--r-- | ggml/src/ggml-cuda/mmq.cu | 8 | ||||
-rw-r--r-- | ggml/src/ggml-cuda/mmq.cuh | 163 | ||||
-rw-r--r-- | ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_ks_r4.cu | 5 | ||||
-rw-r--r-- | ggml/src/ggml-cuda/template-instances/mmq-instance-iq5_ks_r4.cu | 5 |
6 files changed, 167 insertions, 43 deletions
diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index bc7fadb0..8f3d2a26 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -600,6 +600,13 @@ struct ggml_cuda_type_traits<GGML_TYPE_IQ4_KS> { }; template<> +struct ggml_cuda_type_traits<GGML_TYPE_IQ4_KS_R4> { + static constexpr int qk = QK_K; + static constexpr int qr = QR4_XS; + static constexpr int qi = QI4_XS; +}; + +template<> struct ggml_cuda_type_traits<GGML_TYPE_IQ4_KSS> { static constexpr int qk = QK_K; static constexpr int qr = QR4_XS; @@ -621,6 +628,13 @@ struct ggml_cuda_type_traits<GGML_TYPE_IQ5_KS> { }; template<> +struct ggml_cuda_type_traits<GGML_TYPE_IQ5_KS_R4> { + static constexpr int qk = QK_K; + static constexpr int qr = QR5_XS; + static constexpr int qi = QI5_XS; +}; + +template<> struct ggml_cuda_type_traits<GGML_TYPE_IQ6_K> { static constexpr int qk = QK_K; static constexpr int qr = QR6_XS; diff --git a/ggml/src/ggml-cuda/iqk_mmvq.cu b/ggml/src/ggml-cuda/iqk_mmvq.cu index 2340b54a..747af5a7 100644 --- a/ggml/src/ggml-cuda/iqk_mmvq.cu +++ b/ggml/src/ggml-cuda/iqk_mmvq.cu @@ -36,21 +36,6 @@ struct ggml_cuda_type_traits<GGML_TYPE_IQ5_K_R4> { static constexpr int qi = QI5_XS; }; -template<> -struct ggml_cuda_type_traits<GGML_TYPE_IQ4_KS_R4> { - static constexpr int qk = QK_K; - static constexpr int qr = QR4_XS; - static constexpr int qi = QI4_XS; -}; - -template<> -struct ggml_cuda_type_traits<GGML_TYPE_IQ5_KS_R4> { - static constexpr int qk = QK_K; - static constexpr int qr = QR5_XS; - static constexpr int qi = QI5_XS; -}; - - // Reminder: // constexpr int qk = ggml_cuda_type_traits<type>::qk; // constexpr int qi = ggml_cuda_type_traits<type>::qi; diff --git a/ggml/src/ggml-cuda/mmq.cu b/ggml/src/ggml-cuda/mmq.cu index 60c2037f..a13be11b 100644 --- a/ggml/src/ggml-cuda/mmq.cu +++ b/ggml/src/ggml-cuda/mmq.cu @@ -97,9 +97,15 @@ void ggml_cuda_op_mul_mat_q( case GGML_TYPE_IQ4_KS: mul_mat_q_case<GGML_TYPE_IQ4_KS>(ctx, args, stream); break; + case GGML_TYPE_IQ4_KS_R4: + mul_mat_q_case<GGML_TYPE_IQ4_KS_R4>(ctx, args, stream); + break; case GGML_TYPE_IQ5_KS: mul_mat_q_case<GGML_TYPE_IQ5_KS>(ctx, args, stream); break; + case GGML_TYPE_IQ5_KS_R4: + mul_mat_q_case<GGML_TYPE_IQ5_KS_R4>(ctx, args, stream); + break; case GGML_TYPE_IQ2_KS: mul_mat_q_case<GGML_TYPE_IQ2_KS>(ctx, args, stream); break; @@ -157,7 +163,9 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) { case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_NL: case GGML_TYPE_IQ4_KS: + case GGML_TYPE_IQ4_KS_R4: case GGML_TYPE_IQ5_KS: + case GGML_TYPE_IQ5_KS_R4: case GGML_TYPE_IQ2_KS: case GGML_TYPE_IQ2_K: case GGML_TYPE_IQ3_K: diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh index c98fa561..608de8f0 100644 --- a/ggml/src/ggml-cuda/mmq.cuh +++ b/ggml/src/ggml-cuda/mmq.cuh @@ -87,9 +87,11 @@ static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) { case GGML_TYPE_IQ2_K: case GGML_TYPE_IQ3_K: case GGML_TYPE_IQ4_KS: + case GGML_TYPE_IQ4_KS_R4: case GGML_TYPE_IQ4_K: case GGML_TYPE_IQ5_K: case GGML_TYPE_IQ5_KS: + case GGML_TYPE_IQ5_KS_R4: case GGML_TYPE_IQ6_K: return MMQ_Q8_1_DS_LAYOUT_D4; default: @@ -191,7 +193,9 @@ static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml case GGML_TYPE_IQ4_XS : return MMQ_DP4A_TXS_Q8_0; case GGML_TYPE_IQ4_NL : return MMQ_DP4A_TXS_Q8_0; case GGML_TYPE_IQ4_KS : return MMQ_DP4A_TXS_Q8_0; + case GGML_TYPE_IQ4_KS_R4 : return MMQ_DP4A_TXS_Q8_0; case GGML_TYPE_IQ5_KS : return MMQ_DP4A_TXS_Q8_0; + case GGML_TYPE_IQ5_KS_R4 : return MMQ_DP4A_TXS_Q8_0; case GGML_TYPE_IQ2_KS : return MMQ_DP4A_TXS_Q8_0; case GGML_TYPE_IQ2_K : return MMQ_DP4A_TXS_Q8_0_16; case GGML_TYPE_IQ3_K : return MMQ_DP4A_TXS_Q8_0_16; @@ -237,7 +241,9 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) { case GGML_TYPE_IQ4_XS : return MMQ_MMA_TILE_X_K_Q8_0; case GGML_TYPE_IQ4_NL : return MMQ_MMA_TILE_X_K_Q8_0; case GGML_TYPE_IQ4_KS : return MMQ_MMA_TILE_X_K_Q8_0; + case GGML_TYPE_IQ4_KS_R4 : return MMQ_MMA_TILE_X_K_Q8_0; case GGML_TYPE_IQ5_KS : return MMQ_MMA_TILE_X_K_Q8_0; + case GGML_TYPE_IQ5_KS_R4 : return MMQ_MMA_TILE_X_K_Q8_0; case GGML_TYPE_IQ2_KS : return MMQ_MMA_TILE_X_K_Q8_0; case GGML_TYPE_IQ2_K : return MMQ_MMA_TILE_X_K_Q3_K; case GGML_TYPE_IQ3_K : return MMQ_MMA_TILE_X_K_Q3_K; @@ -2732,6 +2738,119 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin } } +template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq4_ks_r4( + const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { + +#ifdef INT8_MMA_AVAILABLE + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + WARP_SIZE*2); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_KS_R4, mmq_y); + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); +#endif // INT8_MMA_AVAILABLE + + const int kqsx = threadIdx.x/4; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += 4*nwarps) { + int i = i0 + 4*threadIdx.y + threadIdx.x%4; + + if (need_check) { + i = min(i, i_max); + } + int i4 = i/4; + int ir = i%4; + + const float * dptr = (const float *)(x + 4*i4*stride); + const block_iq4_ks_r4 * bxi = (const block_iq4_ks_r4 *)(dptr + 4) + kbx0; + + const int ls = (bxi->scales[4*kqsx + ir] & 254) - 127; + auto values = iq4k_values + ((bxi->scales[4*kqsx+ir] & 1) << 4); +#pragma unroll + for (int j = 0; j < 4; ++j) { + const int q4 = get_int_b4(bxi->qs, 16*kqsx+4*j+ir); + const int2 v = get_int_from_table_16(q4, values); + const int k0 = 8*kqsx + 4*(j%2) + j/2; +#ifdef INT8_MMA_AVAILABLE + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 2] = v.y; +#else + x_qs[i*(2*WARP_SIZE + 1) + k0 + 0] = v.x; + x_qs[i*(2*WARP_SIZE + 1) + k0 + 2] = v.y; +#endif // INT8_MMA_AVAILABLE + } +#ifdef INT8_MMA_AVAILABLE + x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = dptr[ir] * ls; +#else + x_df[i*(WARP_SIZE/4) + i/4 + kqsx] = dptr[ir] * ls; +#endif // INT8_MMA_AVAILABLE + + } + +} + +template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq5_ks_r4( + const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { + +#ifdef INT8_MMA_AVAILABLE + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + WARP_SIZE*2); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ5_KS_R4, mmq_y); + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); +#endif // INT8_MMA_AVAILABLE + + const int kqsx = threadIdx.x/4; + + uint32_t aux32[2]; + const uint8_t * aux8 = (const uint8_t *)aux32; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += 4*nwarps) { + int i = i0 + 4*threadIdx.y + threadIdx.x%4; + + if (need_check) { + i = min(i, i_max); + } + int i4 = i/4; + int ir = i%4; + + const float * dptr = (const float *)(x + 4*i4*stride); + const block_iq5_ks_r4 * bxi = (const block_iq5_ks_r4 *)(dptr + 4) + kbx0; + + const int ls = (bxi->scales[4*kqsx + ir] & 254) - 127; + auto values = iq5nl_values + ((bxi->scales[4*kqsx+ir] & 1) << 5); + + int qh = *((const int *)bxi->qh + 4*kqsx + ir); + const int * ql = (const int *)bxi->qs + 16*kqsx + ir; +#pragma unroll + for (int j = 0; j < 4; ++j) { + aux32[0] = ((ql[4*j] >> 0) & 0x0f0f0f0f) | ((qh << 4) & 0x10101010); + aux32[1] = ((ql[4*j] >> 4) & 0x0f0f0f0f) | ((qh << 3) & 0x10101010); + qh >>= 2; + const char4 val0 = make_char4(values[aux8[0]], values[aux8[1]], values[aux8[2]], values[aux8[3]]); + const char4 val1 = make_char4(values[aux8[4]], values[aux8[5]], values[aux8[6]], values[aux8[7]]); + const int k0 = 8*kqsx + 4*(j%2) + j/2; +#ifdef INT8_MMA_AVAILABLE + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = *(const int *)&val0; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 2] = *(const int *)&val1; +#else + x_qs[i*(2*WARP_SIZE + 1) + k0 + 0] = *(const int *)&val0; + x_qs[i*(2*WARP_SIZE + 1) + k0 + 2] = *(const int *)&val1; +#endif // INT8_MMA_AVAILABLE + } +#ifdef INT8_MMA_AVAILABLE + x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = dptr[ir] * ls; +#else + x_df[i*(WARP_SIZE/4) + i/4 + kqsx] = dptr[ir] * ls; +#endif // INT8_MMA_AVAILABLE + + } + +} + template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq4_k( const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { @@ -3069,7 +3188,6 @@ struct mmq_type_traits; template <int mmq_x, int mmq_y, int nwarps, bool need_check> struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_0> { - static constexpr int vdr = VDR_Q4_0_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_0<mmq_y, nwarps, need_check>; static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_DS4>; static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q4_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>; @@ -3077,7 +3195,6 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_0> { template <int mmq_x, int mmq_y, int nwarps, bool need_check> struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_1> { - static constexpr int vdr = VDR_Q4_1_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_1<mmq_y, nwarps, need_check>; static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y, nwarps>; static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q4_1_q8_1_dp4a<mmq_x, mmq_y, nwarps>; @@ -3085,7 +3202,6 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_1> { template <int mmq_x, int mmq_y, int nwarps, bool need_check> struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_0> { - static constexpr int vdr = VDR_Q5_0_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_0<mmq_y, nwarps, need_check>; static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_D4>; static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>; @@ -3093,7 +3209,6 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_0> { template <int mmq_x, int mmq_y, int nwarps, bool need_check> struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_1> { - static constexpr int vdr = VDR_Q5_1_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_1<mmq_y, nwarps, need_check>; static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y, nwarps>; static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_1_q8_1_dp4a<mmq_x, mmq_y, nwarps>; @@ -3101,7 +3216,6 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_1> { template <int mmq_x, int mmq_y, int nwarps, bool need_check> struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q6_0> { - static constexpr int vdr = VDR_Q6_0_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_q6_0<mmq_y, nwarps, need_check>; static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_D4>; static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>; @@ -3109,7 +3223,6 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q6_0> { template <int mmq_x, int mmq_y, int nwarps, bool need_check> struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q8_0> { - static constexpr int vdr = VDR_Q8_0_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_q8_0<mmq_y, nwarps, need_check>; static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_D4>; static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>; @@ -3117,7 +3230,6 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q8_0> { template <int mmq_x, int mmq_y, int nwarps, bool need_check> struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q2_K> { - static constexpr int vdr = VDR_Q2_K_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_q2_K<mmq_y, nwarps, need_check>; static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q2_K_q8_1_mma<mmq_x, mmq_y, nwarps>; static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q2_K_q8_1_dp4a<mmq_x, mmq_y, nwarps>; @@ -3125,7 +3237,6 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q2_K> { template <int mmq_x, int mmq_y, int nwarps, bool need_check> struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q3_K> { - static constexpr int vdr = VDR_Q3_K_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_q3_K<mmq_y, nwarps, need_check>; static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma<mmq_x, mmq_y, nwarps>; static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q3_K_q8_1_dp4a<mmq_x, mmq_y, nwarps>; @@ -3133,7 +3244,6 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q3_K> { template <int mmq_x, int mmq_y, int nwarps, bool need_check> struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_K> { - static constexpr int vdr = VDR_Q4_K_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_K<mmq_y, nwarps, need_check>; static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y, nwarps>; static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q4_K_q8_1_dp4a<mmq_x, mmq_y, nwarps>; @@ -3141,7 +3251,6 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_K> { template <int mmq_x, int mmq_y, int nwarps, bool need_check> struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_K> { - static constexpr int vdr = VDR_Q5_K_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_K<mmq_y, nwarps, need_check>; static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y, nwarps>; static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q5_K_q8_1_dp4a<mmq_x, mmq_y, nwarps>; @@ -3149,7 +3258,6 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_K> { template <int mmq_x, int mmq_y, int nwarps, bool need_check> struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q6_K> { - static constexpr int vdr = VDR_Q6_K_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_q6_K<mmq_y, nwarps, need_check>; static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q6_K_q8_1_mma<mmq_x, mmq_y, nwarps>; static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q6_K_q8_1_dp4a<mmq_x, mmq_y, nwarps>; @@ -3157,7 +3265,6 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q6_K> { template <int mmq_x, int mmq_y, int nwarps, bool need_check> struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ2_XXS> { - static constexpr int vdr = VDR_IQ2_XXS_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq2_xxs<mmq_y, nwarps, need_check>; static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_D4>; static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>; @@ -3165,7 +3272,6 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ2_XXS> { template <int mmq_x, int mmq_y, int nwarps, bool need_check> struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ2_XS> { - static constexpr int vdr = VDR_IQ2_XS_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq2_xs<mmq_y, nwarps, need_check>; static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma<mmq_x, mmq_y, nwarps>; static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a<mmq_x, mmq_y, nwarps>; @@ -3173,7 +3279,6 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ2_XS> { template <int mmq_x, int mmq_y, int nwarps, bool need_check> struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ2_S> { - static constexpr int vdr = VDR_IQ2_S_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq2_s<mmq_y, nwarps, need_check>; static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma<mmq_x, mmq_y, nwarps>; static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a<mmq_x, mmq_y, nwarps>; @@ -3181,7 +3286,6 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ2_S> { template <int mmq_x, int mmq_y, int nwarps, bool need_check> struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ3_XXS> { - static constexpr int vdr = VDR_IQ3_XXS_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq3_xxs<mmq_y, nwarps, need_check>; static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_D4>; static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>; @@ -3189,7 +3293,6 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ3_XXS> { template <int mmq_x, int mmq_y, int nwarps, bool need_check> struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ3_S> { - static constexpr int vdr = VDR_IQ3_S_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq3_s<mmq_y, nwarps, need_check>; static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_D4>; static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>; @@ -3197,7 +3300,6 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ3_S> { template <int mmq_x, int mmq_y, int nwarps, bool need_check> struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ1_S> { - static constexpr int vdr = VDR_IQ1_S_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq1_s<mmq_y, nwarps, need_check>; static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y, nwarps>; static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_1_q8_1_dp4a<mmq_x, mmq_y, nwarps>; @@ -3205,7 +3307,6 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ1_S> { template <int mmq_x, int mmq_y, int nwarps, bool need_check> struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ1_S_R4> { - static constexpr int vdr = VDR_Q4_0_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq1_s_r4<mmq_y, nwarps, need_check>; static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_DS4>; static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q4_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>; @@ -3213,7 +3314,6 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ1_S_R4> { template <int mmq_x, int mmq_y, int nwarps, bool need_check> struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ4_NL> { - static constexpr int vdr = VDR_IQ4_NL_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq4_nl<mmq_y, nwarps, need_check>; static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_D4>; static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>; @@ -3221,7 +3321,6 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ4_NL> { template <int mmq_x, int mmq_y, int nwarps, bool need_check> struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ4_XS> { - static constexpr int vdr = VDR_IQ4_XS_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq4_xs<mmq_y, nwarps, need_check>; static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_D4>; static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>; @@ -3229,7 +3328,6 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ4_XS> { template <int mmq_x, int mmq_y, int nwarps, bool need_check> struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ2_K> { - static constexpr int vdr = VDR_IQ2_XS_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq2_k<mmq_y, nwarps, need_check>; static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma<mmq_x, mmq_y, nwarps>; static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a<mmq_x, mmq_y, nwarps>; @@ -3237,7 +3335,6 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ2_K> { template <int mmq_x, int mmq_y, int nwarps, bool need_check> struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ3_K> { - static constexpr int vdr = VDR_IQ2_XS_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq3_k<mmq_y, nwarps, need_check>; static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma<mmq_x, mmq_y, nwarps>; static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a<mmq_x, mmq_y, nwarps>; @@ -3245,7 +3342,6 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ3_K> { template <int mmq_x, int mmq_y, int nwarps, bool need_check> struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ4_K> { - static constexpr int vdr = VDR_IQ2_XS_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq4_k<mmq_y, nwarps, need_check>; static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma<mmq_x, mmq_y, nwarps>; static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a<mmq_x, mmq_y, nwarps>; @@ -3253,7 +3349,6 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ4_K> { template <int mmq_x, int mmq_y, int nwarps, bool need_check> struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ5_K> { - static constexpr int vdr = VDR_IQ2_XS_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq5_k<mmq_y, nwarps, need_check>; static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma<mmq_x, mmq_y, nwarps>; static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a<mmq_x, mmq_y, nwarps>; @@ -3261,7 +3356,6 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ5_K> { template <int mmq_x, int mmq_y, int nwarps, bool need_check> struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ6_K> { - static constexpr int vdr = VDR_IQ2_XS_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq6_k<mmq_y, nwarps, need_check>; static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma<mmq_x, mmq_y, nwarps>; static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a<mmq_x, mmq_y, nwarps>; @@ -3269,7 +3363,6 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ6_K> { template <int mmq_x, int mmq_y, int nwarps, bool need_check> struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ2_KS> { - static constexpr int vdr = VDR_IQ4_XS_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq2_ks<mmq_y, nwarps, need_check>; static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_D4>; static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>; @@ -3277,20 +3370,32 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ2_KS> { template <int mmq_x, int mmq_y, int nwarps, bool need_check> struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ4_KS> { - static constexpr int vdr = VDR_IQ4_XS_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq4_ks<mmq_y, nwarps, need_check>; static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_D4>; static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>; }; template <int mmq_x, int mmq_y, int nwarps, bool need_check> +struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ4_KS_R4> { + static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq4_ks_r4<mmq_y, nwarps, need_check>; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_D4>; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>; +}; + +template <int mmq_x, int mmq_y, int nwarps, bool need_check> struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ5_KS> { - static constexpr int vdr = VDR_IQ4_XS_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq5_ks<mmq_y, nwarps, need_check>; static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_D4>; static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>; }; +template <int mmq_x, int mmq_y, int nwarps, bool need_check> +struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ5_KS_R4> { + static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq5_ks_r4<mmq_y, nwarps, need_check>; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_D4>; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>; +}; + template <ggml_type type, int mmq_x, int nwarps, bool need_check, bool fixup> static __device__ void mul_mat_q_process_tile( const char * __restrict__ x, const char * __restrict__ yc, float * __restrict__ dst, float * __restrict__ tmp_fixup, @@ -3728,6 +3833,8 @@ extern DECL_MMQ_CASE(GGML_TYPE_IQ1_S); extern DECL_MMQ_CASE(GGML_TYPE_IQ4_NL); extern DECL_MMQ_CASE(GGML_TYPE_IQ4_XS); extern DECL_MMQ_CASE(GGML_TYPE_IQ4_KS); +extern DECL_MMQ_CASE(GGML_TYPE_IQ4_KS_R4); +extern DECL_MMQ_CASE(GGML_TYPE_IQ5_KS_R4); extern DECL_MMQ_CASE(GGML_TYPE_IQ2_KS); extern DECL_MMQ_CASE(GGML_TYPE_IQ2_K); extern DECL_MMQ_CASE(GGML_TYPE_IQ3_K); diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_ks_r4.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_ks_r4.cu new file mode 100644 index 00000000..d0f08ce8 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_ks_r4.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmq.cuh" + +DECL_MMQ_CASE(GGML_TYPE_IQ4_KS_R4); diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-iq5_ks_r4.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq5_ks_r4.cu new file mode 100644 index 00000000..0cc77dc0 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq5_ks_r4.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmq.cuh" + +DECL_MMQ_CASE(GGML_TYPE_IQ5_KS_R4); |