From 7d1a378b8fb266782d9248538a661405aad80768 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Wed, 5 Jun 2024 16:53:00 +0200 Subject: CUDA: refactor mmq, dmmv, mmvq (#7716) * CUDA: refactor mmq, dmmv, mmvq * fix out-of-bounds write * struct for qk, qr, qi * fix cmake build * mmq_type_traits --- ggml-cuda/common.cuh | 157 ++++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 156 insertions(+), 1 deletion(-) (limited to 'ggml-cuda/common.cuh') diff --git a/ggml-cuda/common.cuh b/ggml-cuda/common.cuh index 22872ca5..90a0a81e 100644 --- a/ggml-cuda/common.cuh +++ b/ggml-cuda/common.cuh @@ -160,7 +160,7 @@ #endif #define MMVQ_MAX_BATCH_SIZE 8 // max batch size to use MMVQ kernels -#define MMQ_MAX_BATCH_SIZE 32 // max batch size to use MMQ kernels when tensor cores are available +#define MMQ_MAX_BATCH_SIZE 64 // max batch size to use MMQ kernels when tensor cores are available #define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses @@ -484,6 +484,161 @@ static __device__ __forceinline__ float get_alibi_slope( return powf(base, exph); } +template +struct ggml_cuda_type_traits; + +template<> +struct ggml_cuda_type_traits { + static constexpr int qk = 1; + static constexpr int qr = 1; +}; + +template<> +struct ggml_cuda_type_traits { + static constexpr int qk = QK4_0; + static constexpr int qr = QR4_0; + static constexpr int qi = QI4_0; +}; + +template<> +struct ggml_cuda_type_traits { + static constexpr int qk = QK4_1; + static constexpr int qr = QR4_1; + static constexpr int qi = QI4_1; +}; + +template<> +struct ggml_cuda_type_traits { + static constexpr int qk = QK5_0; + static constexpr int qr = QR5_0; + static constexpr int qi = QI5_0; +}; + +template<> +struct ggml_cuda_type_traits { + static constexpr int qk = QK5_1; + static constexpr int qr = QR5_1; + static constexpr int qi = QI5_1; +}; + +template<> +struct ggml_cuda_type_traits { + static constexpr int qk = QK8_0; + static constexpr int qr = QR8_0; + static constexpr int qi = QI8_0; +}; + +template<> +struct ggml_cuda_type_traits { + static constexpr int qk = QK_K; + static constexpr int qr = QR2_K; + static constexpr int qi = QI2_K; +}; + +template<> +struct ggml_cuda_type_traits { + static constexpr int qk = QK_K; + static constexpr int qr = QR3_K; + static constexpr int qi = QI3_K; +}; + +template<> +struct ggml_cuda_type_traits { + static constexpr int qk = QK_K; + static constexpr int qr = QR4_K; + static constexpr int qi = QI4_K; +}; + +template<> +struct ggml_cuda_type_traits { + static constexpr int qk = QK_K; + static constexpr int qr = QR5_K; + static constexpr int qi = QI5_K; +}; + +template<> +struct ggml_cuda_type_traits { + static constexpr int qk = QK_K; + static constexpr int qr = QR6_K; + static constexpr int qi = QI6_K; +}; + +template<> +struct ggml_cuda_type_traits { + static constexpr int qk = QK_K; + static constexpr int qr = QR2_XXS; + static constexpr int qi = QI2_XXS; +}; + +template<> +struct ggml_cuda_type_traits { + static constexpr int qk = QK_K; + static constexpr int qr = QR2_XS; + static constexpr int qi = QI2_XS; +}; + +template<> +struct ggml_cuda_type_traits { + static constexpr int qk = QK_K; + static constexpr int qr = QR2_S; + static constexpr int qi = QI2_S; +}; + +template<> +struct ggml_cuda_type_traits { + static constexpr int qk = QK_K; + static constexpr int qr = QR3_XXS; + static constexpr int qi = QI3_XXS; +}; + +template<> +struct ggml_cuda_type_traits { + static constexpr int qk = QK_K; + static constexpr int qr = QR1_S; + static constexpr int qi = QI1_S; +}; + +template<> +struct ggml_cuda_type_traits { + static constexpr int qk = QK_K; + static constexpr int qr = QR1_M; + static constexpr int qi = QI1_M; +}; + +template<> +struct ggml_cuda_type_traits { + static constexpr int qk = QK4_NL; + static constexpr int qr = QR4_NL; + static constexpr int qi = QI4_NL; +}; + +template<> +struct ggml_cuda_type_traits { + static constexpr int qk = QK_K; + static constexpr int qr = QR4_XS; + static constexpr int qi = QI4_XS; +}; + +template<> +struct ggml_cuda_type_traits { + static constexpr int qk = QK_K; + static constexpr int qr = QR3_S; + static constexpr int qi = QI3_S; +}; + +static int get_mmq_x_max_host(const int cc) { +#ifdef CUDA_USE_TENSOR_CORES + return cc >= CC_VOLTA && cc < CC_OFFSET_AMD ? MMQ_MAX_BATCH_SIZE : 64; +#else + return cc >= CC_VOLTA && cc < CC_OFFSET_AMD ? 128 : 64; +#endif // CUDA_USE_TENSOR_CORES +} + +// Round rows to this value for --split-mode row: +static int get_mmq_y_host(const int cc, const int mmq_x) { + return cc >= CC_VOLTA && mmq_x >= 32 ? 128 : 64; +} + ////////////////////// struct ggml_cuda_device_info { -- cgit v1.2.3