From 8ffad187abbb93b74db8ef813b6fdceec80e02b0 Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Thu, 5 Jun 2025 08:31:20 +0300 Subject: MMQ implementation for IQ4_KS_R4 and IQ5_KS_R4 (#493) * MMQ for iq4_ks_r4 * MMQ for iq5_ks_r4 * Add forgotten file * Another forgotten file --------- Co-authored-by: Iwan Kawrakow --- ggml/src/ggml-cuda/common.cuh | 14 ++ ggml/src/ggml-cuda/iqk_mmvq.cu | 15 -- ggml/src/ggml-cuda/mmq.cu | 8 + ggml/src/ggml-cuda/mmq.cuh | 163 +++++++++++++++++---- .../template-instances/mmq-instance-iq4_ks_r4.cu | 5 + .../template-instances/mmq-instance-iq5_ks_r4.cu | 5 + 6 files changed, 167 insertions(+), 43 deletions(-) create mode 100644 ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_ks_r4.cu create mode 100644 ggml/src/ggml-cuda/template-instances/mmq-instance-iq5_ks_r4.cu diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index bc7fadb0..8f3d2a26 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -599,6 +599,13 @@ struct ggml_cuda_type_traits { static constexpr int qi = QI4_XS; }; +template<> +struct ggml_cuda_type_traits { + static constexpr int qk = QK_K; + static constexpr int qr = QR4_XS; + static constexpr int qi = QI4_XS; +}; + template<> struct ggml_cuda_type_traits { static constexpr int qk = QK_K; @@ -620,6 +627,13 @@ struct ggml_cuda_type_traits { static constexpr int qi = QI5_XS; }; +template<> +struct ggml_cuda_type_traits { + static constexpr int qk = QK_K; + static constexpr int qr = QR5_XS; + static constexpr int qi = QI5_XS; +}; + template<> struct ggml_cuda_type_traits { static constexpr int qk = QK_K; diff --git a/ggml/src/ggml-cuda/iqk_mmvq.cu b/ggml/src/ggml-cuda/iqk_mmvq.cu index 2340b54a..747af5a7 100644 --- a/ggml/src/ggml-cuda/iqk_mmvq.cu +++ b/ggml/src/ggml-cuda/iqk_mmvq.cu @@ -36,21 +36,6 @@ struct ggml_cuda_type_traits { static constexpr int qi = QI5_XS; }; -template<> -struct ggml_cuda_type_traits { - static constexpr int qk = QK_K; - static constexpr int qr = QR4_XS; - static constexpr int qi = QI4_XS; -}; - -template<> -struct ggml_cuda_type_traits { - static constexpr int qk = QK_K; - static constexpr int qr = QR5_XS; - static constexpr int qi = QI5_XS; -}; - - // Reminder: // constexpr int qk = ggml_cuda_type_traits::qk; // constexpr int qi = ggml_cuda_type_traits::qi; diff --git a/ggml/src/ggml-cuda/mmq.cu b/ggml/src/ggml-cuda/mmq.cu index 60c2037f..a13be11b 100644 --- a/ggml/src/ggml-cuda/mmq.cu +++ b/ggml/src/ggml-cuda/mmq.cu @@ -97,9 +97,15 @@ void ggml_cuda_op_mul_mat_q( case GGML_TYPE_IQ4_KS: mul_mat_q_case(ctx, args, stream); break; + case GGML_TYPE_IQ4_KS_R4: + mul_mat_q_case(ctx, args, stream); + break; case GGML_TYPE_IQ5_KS: mul_mat_q_case(ctx, args, stream); break; + case GGML_TYPE_IQ5_KS_R4: + mul_mat_q_case(ctx, args, stream); + break; case GGML_TYPE_IQ2_KS: mul_mat_q_case(ctx, args, stream); break; @@ -157,7 +163,9 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) { case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_NL: case GGML_TYPE_IQ4_KS: + case GGML_TYPE_IQ4_KS_R4: case GGML_TYPE_IQ5_KS: + case GGML_TYPE_IQ5_KS_R4: case GGML_TYPE_IQ2_KS: case GGML_TYPE_IQ2_K: case GGML_TYPE_IQ3_K: diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh index c98fa561..608de8f0 100644 --- a/ggml/src/ggml-cuda/mmq.cuh +++ b/ggml/src/ggml-cuda/mmq.cuh @@ -87,9 +87,11 @@ static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) { case GGML_TYPE_IQ2_K: case GGML_TYPE_IQ3_K: case GGML_TYPE_IQ4_KS: + case GGML_TYPE_IQ4_KS_R4: case GGML_TYPE_IQ4_K: case GGML_TYPE_IQ5_K: case GGML_TYPE_IQ5_KS: + case GGML_TYPE_IQ5_KS_R4: case GGML_TYPE_IQ6_K: return MMQ_Q8_1_DS_LAYOUT_D4; default: @@ -191,7 +193,9 @@ static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml case GGML_TYPE_IQ4_XS : return MMQ_DP4A_TXS_Q8_0; case GGML_TYPE_IQ4_NL : return MMQ_DP4A_TXS_Q8_0; case GGML_TYPE_IQ4_KS : return MMQ_DP4A_TXS_Q8_0; + case GGML_TYPE_IQ4_KS_R4 : return MMQ_DP4A_TXS_Q8_0; case GGML_TYPE_IQ5_KS : return MMQ_DP4A_TXS_Q8_0; + case GGML_TYPE_IQ5_KS_R4 : return MMQ_DP4A_TXS_Q8_0; case GGML_TYPE_IQ2_KS : return MMQ_DP4A_TXS_Q8_0; case GGML_TYPE_IQ2_K : return MMQ_DP4A_TXS_Q8_0_16; case GGML_TYPE_IQ3_K : return MMQ_DP4A_TXS_Q8_0_16; @@ -237,7 +241,9 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) { case GGML_TYPE_IQ4_XS : return MMQ_MMA_TILE_X_K_Q8_0; case GGML_TYPE_IQ4_NL : return MMQ_MMA_TILE_X_K_Q8_0; case GGML_TYPE_IQ4_KS : return MMQ_MMA_TILE_X_K_Q8_0; + case GGML_TYPE_IQ4_KS_R4 : return MMQ_MMA_TILE_X_K_Q8_0; case GGML_TYPE_IQ5_KS : return MMQ_MMA_TILE_X_K_Q8_0; + case GGML_TYPE_IQ5_KS_R4 : return MMQ_MMA_TILE_X_K_Q8_0; case GGML_TYPE_IQ2_KS : return MMQ_MMA_TILE_X_K_Q8_0; case GGML_TYPE_IQ2_K : return MMQ_MMA_TILE_X_K_Q3_K; case GGML_TYPE_IQ3_K : return MMQ_MMA_TILE_X_K_Q3_K; @@ -2732,6 +2738,119 @@ template static __device__ __forceinlin } } +template static __device__ __forceinline__ void load_tiles_iq4_ks_r4( + const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { + +#ifdef INT8_MMA_AVAILABLE + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + WARP_SIZE*2); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_KS_R4, mmq_y); + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); +#endif // INT8_MMA_AVAILABLE + + const int kqsx = threadIdx.x/4; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += 4*nwarps) { + int i = i0 + 4*threadIdx.y + threadIdx.x%4; + + if (need_check) { + i = min(i, i_max); + } + int i4 = i/4; + int ir = i%4; + + const float * dptr = (const float *)(x + 4*i4*stride); + const block_iq4_ks_r4 * bxi = (const block_iq4_ks_r4 *)(dptr + 4) + kbx0; + + const int ls = (bxi->scales[4*kqsx + ir] & 254) - 127; + auto values = iq4k_values + ((bxi->scales[4*kqsx+ir] & 1) << 4); +#pragma unroll + for (int j = 0; j < 4; ++j) { + const int q4 = get_int_b4(bxi->qs, 16*kqsx+4*j+ir); + const int2 v = get_int_from_table_16(q4, values); + const int k0 = 8*kqsx + 4*(j%2) + j/2; +#ifdef INT8_MMA_AVAILABLE + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 2] = v.y; +#else + x_qs[i*(2*WARP_SIZE + 1) + k0 + 0] = v.x; + x_qs[i*(2*WARP_SIZE + 1) + k0 + 2] = v.y; +#endif // INT8_MMA_AVAILABLE + } +#ifdef INT8_MMA_AVAILABLE + x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = dptr[ir] * ls; +#else + x_df[i*(WARP_SIZE/4) + i/4 + kqsx] = dptr[ir] * ls; +#endif // INT8_MMA_AVAILABLE + + } + +} + +template static __device__ __forceinline__ void load_tiles_iq5_ks_r4( + const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { + +#ifdef INT8_MMA_AVAILABLE + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + WARP_SIZE*2); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ5_KS_R4, mmq_y); + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); +#endif // INT8_MMA_AVAILABLE + + const int kqsx = threadIdx.x/4; + + uint32_t aux32[2]; + const uint8_t * aux8 = (const uint8_t *)aux32; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += 4*nwarps) { + int i = i0 + 4*threadIdx.y + threadIdx.x%4; + + if (need_check) { + i = min(i, i_max); + } + int i4 = i/4; + int ir = i%4; + + const float * dptr = (const float *)(x + 4*i4*stride); + const block_iq5_ks_r4 * bxi = (const block_iq5_ks_r4 *)(dptr + 4) + kbx0; + + const int ls = (bxi->scales[4*kqsx + ir] & 254) - 127; + auto values = iq5nl_values + ((bxi->scales[4*kqsx+ir] & 1) << 5); + + int qh = *((const int *)bxi->qh + 4*kqsx + ir); + const int * ql = (const int *)bxi->qs + 16*kqsx + ir; +#pragma unroll + for (int j = 0; j < 4; ++j) { + aux32[0] = ((ql[4*j] >> 0) & 0x0f0f0f0f) | ((qh << 4) & 0x10101010); + aux32[1] = ((ql[4*j] >> 4) & 0x0f0f0f0f) | ((qh << 3) & 0x10101010); + qh >>= 2; + const char4 val0 = make_char4(values[aux8[0]], values[aux8[1]], values[aux8[2]], values[aux8[3]]); + const char4 val1 = make_char4(values[aux8[4]], values[aux8[5]], values[aux8[6]], values[aux8[7]]); + const int k0 = 8*kqsx + 4*(j%2) + j/2; +#ifdef INT8_MMA_AVAILABLE + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = *(const int *)&val0; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 2] = *(const int *)&val1; +#else + x_qs[i*(2*WARP_SIZE + 1) + k0 + 0] = *(const int *)&val0; + x_qs[i*(2*WARP_SIZE + 1) + k0 + 2] = *(const int *)&val1; +#endif // INT8_MMA_AVAILABLE + } +#ifdef INT8_MMA_AVAILABLE + x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = dptr[ir] * ls; +#else + x_df[i*(WARP_SIZE/4) + i/4 + kqsx] = dptr[ir] * ls; +#endif // INT8_MMA_AVAILABLE + + } + +} + template static __device__ __forceinline__ void load_tiles_iq4_k( const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { @@ -3069,7 +3188,6 @@ struct mmq_type_traits; template struct mmq_type_traits { - static constexpr int vdr = VDR_Q4_0_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_0; static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q4_0_q8_1_dp4a; @@ -3077,7 +3195,6 @@ struct mmq_type_traits { template struct mmq_type_traits { - static constexpr int vdr = VDR_Q4_1_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_1; static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma; static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q4_1_q8_1_dp4a; @@ -3085,7 +3202,6 @@ struct mmq_type_traits { template struct mmq_type_traits { - static constexpr int vdr = VDR_Q5_0_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_0; static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; @@ -3093,7 +3209,6 @@ struct mmq_type_traits { template struct mmq_type_traits { - static constexpr int vdr = VDR_Q5_1_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_1; static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma; static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_1_q8_1_dp4a; @@ -3101,7 +3216,6 @@ struct mmq_type_traits { template struct mmq_type_traits { - static constexpr int vdr = VDR_Q6_0_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_q6_0; static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; @@ -3109,7 +3223,6 @@ struct mmq_type_traits { template struct mmq_type_traits { - static constexpr int vdr = VDR_Q8_0_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_q8_0; static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; @@ -3117,7 +3230,6 @@ struct mmq_type_traits { template struct mmq_type_traits { - static constexpr int vdr = VDR_Q2_K_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_q2_K; static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q2_K_q8_1_mma; static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q2_K_q8_1_dp4a; @@ -3125,7 +3237,6 @@ struct mmq_type_traits { template struct mmq_type_traits { - static constexpr int vdr = VDR_Q3_K_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_q3_K; static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma; static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q3_K_q8_1_dp4a; @@ -3133,7 +3244,6 @@ struct mmq_type_traits { template struct mmq_type_traits { - static constexpr int vdr = VDR_Q4_K_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_K; static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma; static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q4_K_q8_1_dp4a; @@ -3141,7 +3251,6 @@ struct mmq_type_traits { template struct mmq_type_traits { - static constexpr int vdr = VDR_Q5_K_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_K; static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma; static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q5_K_q8_1_dp4a; @@ -3149,7 +3258,6 @@ struct mmq_type_traits { template struct mmq_type_traits { - static constexpr int vdr = VDR_Q6_K_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_q6_K; static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q6_K_q8_1_mma; static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q6_K_q8_1_dp4a; @@ -3157,7 +3265,6 @@ struct mmq_type_traits { template struct mmq_type_traits { - static constexpr int vdr = VDR_IQ2_XXS_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq2_xxs; static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; @@ -3165,7 +3272,6 @@ struct mmq_type_traits { template struct mmq_type_traits { - static constexpr int vdr = VDR_IQ2_XS_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq2_xs; static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma; static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a; @@ -3173,7 +3279,6 @@ struct mmq_type_traits { template struct mmq_type_traits { - static constexpr int vdr = VDR_IQ2_S_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq2_s; static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma; static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a; @@ -3181,7 +3286,6 @@ struct mmq_type_traits { template struct mmq_type_traits { - static constexpr int vdr = VDR_IQ3_XXS_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq3_xxs; static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; @@ -3189,7 +3293,6 @@ struct mmq_type_traits { template struct mmq_type_traits { - static constexpr int vdr = VDR_IQ3_S_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq3_s; static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; @@ -3197,7 +3300,6 @@ struct mmq_type_traits { template struct mmq_type_traits { - static constexpr int vdr = VDR_IQ1_S_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq1_s; static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma; static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_1_q8_1_dp4a; @@ -3205,7 +3307,6 @@ struct mmq_type_traits { template struct mmq_type_traits { - static constexpr int vdr = VDR_Q4_0_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq1_s_r4; static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q4_0_q8_1_dp4a; @@ -3213,7 +3314,6 @@ struct mmq_type_traits { template struct mmq_type_traits { - static constexpr int vdr = VDR_IQ4_NL_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq4_nl; static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; @@ -3221,7 +3321,6 @@ struct mmq_type_traits { template struct mmq_type_traits { - static constexpr int vdr = VDR_IQ4_XS_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq4_xs; static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; @@ -3229,7 +3328,6 @@ struct mmq_type_traits { template struct mmq_type_traits { - static constexpr int vdr = VDR_IQ2_XS_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq2_k; static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma; static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a; @@ -3237,7 +3335,6 @@ struct mmq_type_traits { template struct mmq_type_traits { - static constexpr int vdr = VDR_IQ2_XS_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq3_k; static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma; static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a; @@ -3245,7 +3342,6 @@ struct mmq_type_traits { template struct mmq_type_traits { - static constexpr int vdr = VDR_IQ2_XS_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq4_k; static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma; static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a; @@ -3253,7 +3349,6 @@ struct mmq_type_traits { template struct mmq_type_traits { - static constexpr int vdr = VDR_IQ2_XS_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq5_k; static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma; static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a; @@ -3261,7 +3356,6 @@ struct mmq_type_traits { template struct mmq_type_traits { - static constexpr int vdr = VDR_IQ2_XS_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq6_k; static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma; static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a; @@ -3269,7 +3363,6 @@ struct mmq_type_traits { template struct mmq_type_traits { - static constexpr int vdr = VDR_IQ4_XS_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq2_ks; static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; @@ -3277,20 +3370,32 @@ struct mmq_type_traits { template struct mmq_type_traits { - static constexpr int vdr = VDR_IQ4_XS_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq4_ks; static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; }; +template +struct mmq_type_traits { + static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq4_ks_r4; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; +}; + template struct mmq_type_traits { - static constexpr int vdr = VDR_IQ4_XS_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq5_ks; static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; }; +template +struct mmq_type_traits { + static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq5_ks_r4; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; +}; + template static __device__ void mul_mat_q_process_tile( const char * __restrict__ x, const char * __restrict__ yc, float * __restrict__ dst, float * __restrict__ tmp_fixup, @@ -3728,6 +3833,8 @@ extern DECL_MMQ_CASE(GGML_TYPE_IQ1_S); extern DECL_MMQ_CASE(GGML_TYPE_IQ4_NL); extern DECL_MMQ_CASE(GGML_TYPE_IQ4_XS); extern DECL_MMQ_CASE(GGML_TYPE_IQ4_KS); +extern DECL_MMQ_CASE(GGML_TYPE_IQ4_KS_R4); +extern DECL_MMQ_CASE(GGML_TYPE_IQ5_KS_R4); extern DECL_MMQ_CASE(GGML_TYPE_IQ2_KS); extern DECL_MMQ_CASE(GGML_TYPE_IQ2_K); extern DECL_MMQ_CASE(GGML_TYPE_IQ3_K); diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_ks_r4.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_ks_r4.cu new file mode 100644 index 00000000..d0f08ce8 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_ks_r4.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmq.cuh" + +DECL_MMQ_CASE(GGML_TYPE_IQ4_KS_R4); diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-iq5_ks_r4.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq5_ks_r4.cu new file mode 100644 index 00000000..0cc77dc0 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq5_ks_r4.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmq.cuh" + +DECL_MMQ_CASE(GGML_TYPE_IQ5_KS_R4); -- cgit v1.2.3