summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author: Kawrakow <iwankawrakow@gmail.com> 2025-04-24 17:37:12 +0200
committer: GitHub <noreply@github.com> 2025-04-24 17:37:12 +0200
commit: c9eec1729fe95a5fcfd4ce47df440c2445abb17e (patch)
tree: ba01e4e34ba1bd42b885fb99a4b3af9528114f0b
parent: 222a1957430cf531a13c358f7ed54b7a4d96c26a (diff)
cuda: use switch in constexpr funcs (#343)
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
-rw-r--r-- ggml/src/ggml-cuda/mmq.cuh 83
-rw-r--r-- ggml/src/ggml-cuda/mmvq.cu 84
2 files changed, 87 insertions, 80 deletions
diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh
index 88e023a1..753848d7 100644
--- a/ggml/src/ggml-cuda/mmq.cuh
+++ b/ggml/src/ggml-cuda/mmq.cuh
@@ -160,26 +160,27 @@ static constexpr __device__ int get_mmq_y_device() {
#define MMQ_DP4A_TXS_Q6_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI6_K + mmq_y/QI6_K, mmq_y*WARP_SIZE/8 + mmq_y/8}
static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml_type type, int mmq_y) {
- return type == GGML_TYPE_Q4_0 ? MMQ_DP4A_TXS_Q4_0 :
- type == GGML_TYPE_Q4_1 ? MMQ_DP4A_TXS_Q4_1 :
- type == GGML_TYPE_Q5_0 ? MMQ_DP4A_TXS_Q8_0 :
- type == GGML_TYPE_Q5_1 ? MMQ_DP4A_TXS_Q8_1 :
- type == GGML_TYPE_Q6_0 ? MMQ_DP4A_TXS_Q8_0 :
- type == GGML_TYPE_Q8_0 ? MMQ_DP4A_TXS_Q8_0 :
- type == GGML_TYPE_Q2_K ? MMQ_DP4A_TXS_Q2_K :
- type == GGML_TYPE_Q3_K ? MMQ_DP4A_TXS_Q3_K :
- type == GGML_TYPE_Q4_K ? MMQ_DP4A_TXS_Q4_K :
- type == GGML_TYPE_Q5_K ? MMQ_DP4A_TXS_Q5_K :
- type == GGML_TYPE_Q6_K ? MMQ_DP4A_TXS_Q6_K :
- type == GGML_TYPE_IQ2_XXS ? MMQ_DP4A_TXS_Q8_0 :
- type == GGML_TYPE_IQ2_XS ? MMQ_DP4A_TXS_Q8_0_16 :
- type == GGML_TYPE_IQ2_S ? MMQ_DP4A_TXS_Q8_0_16 :
- type == GGML_TYPE_IQ3_XXS ? MMQ_DP4A_TXS_Q8_0 :
- type == GGML_TYPE_IQ3_S ? MMQ_DP4A_TXS_Q8_0 :
- type == GGML_TYPE_IQ1_S ? MMQ_DP4A_TXS_Q8_0 :
- type == GGML_TYPE_IQ4_XS ? MMQ_DP4A_TXS_Q8_0 :
- type == GGML_TYPE_IQ4_NL ? MMQ_DP4A_TXS_Q8_0 :
- tile_x_sizes{0, 0, 0};
+ switch (type) { // pick the MMQ_DP4A_TXS_* tile sizes for this type; the macros (see MMQ_DP4A_TXS_Q6_K) read the mmq_y parameter
+ case GGML_TYPE_Q4_0 : return MMQ_DP4A_TXS_Q4_0; // restored: the ternary chain handled Q4_0, the switch conversion dropped it
+ case GGML_TYPE_Q4_1 : return MMQ_DP4A_TXS_Q4_1;
+ case GGML_TYPE_Q5_0 : return MMQ_DP4A_TXS_Q8_0;
+ case GGML_TYPE_Q5_1 : return MMQ_DP4A_TXS_Q8_1;
+ case GGML_TYPE_Q6_0 : return MMQ_DP4A_TXS_Q8_0;
+ case GGML_TYPE_Q8_0 : return MMQ_DP4A_TXS_Q8_0;
+ case GGML_TYPE_Q2_K : return MMQ_DP4A_TXS_Q2_K;
+ case GGML_TYPE_Q3_K : return MMQ_DP4A_TXS_Q3_K;
+ case GGML_TYPE_Q4_K : return MMQ_DP4A_TXS_Q4_K;
+ case GGML_TYPE_Q5_K : return MMQ_DP4A_TXS_Q5_K;
+ case GGML_TYPE_Q6_K : return MMQ_DP4A_TXS_Q6_K;
+ case GGML_TYPE_IQ2_XXS : return MMQ_DP4A_TXS_Q8_0;
+ case GGML_TYPE_IQ2_XS : return MMQ_DP4A_TXS_Q8_0_16;
+ case GGML_TYPE_IQ2_S : return MMQ_DP4A_TXS_Q8_0_16;
+ case GGML_TYPE_IQ3_XXS : return MMQ_DP4A_TXS_Q8_0;
+ case GGML_TYPE_IQ3_S : return MMQ_DP4A_TXS_Q8_0;
+ case GGML_TYPE_IQ1_S : return MMQ_DP4A_TXS_Q8_0;
+ case GGML_TYPE_IQ4_XS : return MMQ_DP4A_TXS_Q8_0;
+ case GGML_TYPE_IQ4_NL : return MMQ_DP4A_TXS_Q8_0;
+ default : return tile_x_sizes{0, 0, 0}; // any type without a case above yields all-zero sizes
+ }
}
#define MMQ_MMA_TILE_X_K_Q8_0 (2*WARP_SIZE + 2*WARP_SIZE/QI8_0 + 4)
@@ -195,26 +196,28 @@ static_assert(MMQ_MMA_TILE_X_K_Q3_K % 8 == 4, "Wrong padding.");
static_assert(MMQ_MMA_TILE_X_K_Q6_K % 8 == 4, "Wrong padding.");
static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
- return type == GGML_TYPE_Q4_0 ? MMQ_MMA_TILE_X_K_Q8_0 :
- type == GGML_TYPE_Q4_1 ? MMQ_MMA_TILE_X_K_Q8_1 :
- type == GGML_TYPE_Q5_0 ? MMQ_MMA_TILE_X_K_Q8_0 :
- type == GGML_TYPE_Q5_1 ? MMQ_MMA_TILE_X_K_Q8_1 :
- type == GGML_TYPE_Q6_0 ? MMQ_MMA_TILE_X_K_Q8_0 :
- type == GGML_TYPE_Q8_0 ? MMQ_MMA_TILE_X_K_Q8_0 :
- type == GGML_TYPE_Q2_K ? MMQ_MMA_TILE_X_K_Q2_K :
- type == GGML_TYPE_Q3_K ? MMQ_MMA_TILE_X_K_Q3_K :
- type == GGML_TYPE_Q4_K ? MMQ_MMA_TILE_X_K_Q8_1 :
- type == GGML_TYPE_Q5_K ? MMQ_MMA_TILE_X_K_Q8_1 :
- type == GGML_TYPE_Q6_K ? MMQ_MMA_TILE_X_K_Q6_K :
- type == GGML_TYPE_IQ2_XXS ? MMQ_MMA_TILE_X_K_Q8_0 :
- type == GGML_TYPE_IQ2_XS ? MMQ_MMA_TILE_X_K_Q3_K :
- type == GGML_TYPE_IQ2_S ? MMQ_MMA_TILE_X_K_Q3_K :
- type == GGML_TYPE_IQ3_XXS ? MMQ_MMA_TILE_X_K_Q8_0 :
- type == GGML_TYPE_IQ3_S ? MMQ_MMA_TILE_X_K_Q8_0 :
- type == GGML_TYPE_IQ1_S ? MMQ_MMA_TILE_X_K_Q8_0 :
- type == GGML_TYPE_IQ4_XS ? MMQ_MMA_TILE_X_K_Q8_0 :
- type == GGML_TYPE_IQ4_NL ? MMQ_MMA_TILE_X_K_Q8_0 :
- 0;
+ switch (type) { // pick the MMQ_MMA_TILE_X_K_* constant for this type; same type-to-constant mapping as the removed ternary chain
+ case GGML_TYPE_Q4_0 : return MMQ_MMA_TILE_X_K_Q8_0;
+ case GGML_TYPE_Q4_1 : return MMQ_MMA_TILE_X_K_Q8_1;
+ case GGML_TYPE_Q5_0 : return MMQ_MMA_TILE_X_K_Q8_0;
+ case GGML_TYPE_Q5_1 : return MMQ_MMA_TILE_X_K_Q8_1;
+ case GGML_TYPE_Q6_0 : return MMQ_MMA_TILE_X_K_Q8_0;
+ case GGML_TYPE_Q8_0 : return MMQ_MMA_TILE_X_K_Q8_0;
+ case GGML_TYPE_Q2_K : return MMQ_MMA_TILE_X_K_Q2_K;
+ case GGML_TYPE_Q3_K : return MMQ_MMA_TILE_X_K_Q3_K;
+ case GGML_TYPE_Q4_K : return MMQ_MMA_TILE_X_K_Q8_1;
+ case GGML_TYPE_Q5_K : return MMQ_MMA_TILE_X_K_Q8_1;
+ case GGML_TYPE_Q6_K : return MMQ_MMA_TILE_X_K_Q6_K;
+ case GGML_TYPE_IQ2_XXS : return MMQ_MMA_TILE_X_K_Q8_0;
+ case GGML_TYPE_IQ2_XS : return MMQ_MMA_TILE_X_K_Q3_K;
+ case GGML_TYPE_IQ2_S : return MMQ_MMA_TILE_X_K_Q3_K;
+ case GGML_TYPE_IQ3_XXS : return MMQ_MMA_TILE_X_K_Q8_0;
+ case GGML_TYPE_IQ3_S : return MMQ_MMA_TILE_X_K_Q8_0;
+ case GGML_TYPE_IQ1_S : return MMQ_MMA_TILE_X_K_Q8_0;
+ case GGML_TYPE_IQ4_XS : return MMQ_MMA_TILE_X_K_Q8_0;
+ case GGML_TYPE_IQ4_NL : return MMQ_MMA_TILE_X_K_Q8_0;
+ default : return 0; // any type without a case above yields 0
+ }
}
#define MMQ_TILE_Y_K (WARP_SIZE + WARP_SIZE/QI8_1)
diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu
index 69ccef1d..3d991b4d 100644
--- a/ggml/src/ggml-cuda/mmvq.cu
+++ b/ggml/src/ggml-cuda/mmvq.cu
@@ -12,49 +12,53 @@
typedef float (*vec_dot_q_cuda_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs);
static constexpr __device__ vec_dot_q_cuda_t get_vec_dot_q_cuda(ggml_type type) {
- return type == GGML_TYPE_Q4_0 ? vec_dot_q4_0_q8_1 :
- type == GGML_TYPE_Q4_1 ? vec_dot_q4_1_q8_1 :
- type == GGML_TYPE_Q5_0 ? vec_dot_q5_0_q8_1 :
- type == GGML_TYPE_Q5_1 ? vec_dot_q5_1_q8_1 :
- type == GGML_TYPE_Q6_0 ? vec_dot_q6_0_q8_1 :
- type == GGML_TYPE_Q8_0 ? vec_dot_q8_0_q8_1 :
- type == GGML_TYPE_Q2_K ? vec_dot_q2_K_q8_1 :
- type == GGML_TYPE_Q3_K ? vec_dot_q3_K_q8_1 :
- type == GGML_TYPE_Q4_K ? vec_dot_q4_K_q8_1 :
- type == GGML_TYPE_Q5_K ? vec_dot_q5_K_q8_1 :
- type == GGML_TYPE_Q6_K ? vec_dot_q6_K_q8_1 :
- type == GGML_TYPE_IQ2_XXS ? vec_dot_iq2_xxs_q8_1 :
- type == GGML_TYPE_IQ2_XS ? vec_dot_iq2_xs_q8_1 :
- type == GGML_TYPE_IQ2_S ? vec_dot_iq2_s_q8_1 :
- type == GGML_TYPE_IQ3_XXS ? vec_dot_iq3_xxs_q8_1 :
- type == GGML_TYPE_IQ1_S ? vec_dot_iq1_s_q8_1 :
- type == GGML_TYPE_IQ1_M ? vec_dot_iq1_m_q8_1 :
- type == GGML_TYPE_IQ4_NL ? vec_dot_iq4_nl_q8_1 :
- type == GGML_TYPE_IQ4_XS ? vec_dot_iq4_xs_q8_1 :
- type == GGML_TYPE_IQ3_S ? vec_dot_iq3_s_q8_1 :
- nullptr;
+ switch (type) { // return the vec_dot_*_q8_1 function (a vec_dot_q_cuda_t) registered for this type; same mapping as the removed ternary chain
+ case GGML_TYPE_Q4_0 : return vec_dot_q4_0_q8_1;
+ case GGML_TYPE_Q4_1 : return vec_dot_q4_1_q8_1;
+ case GGML_TYPE_Q5_0 : return vec_dot_q5_0_q8_1;
+ case GGML_TYPE_Q5_1 : return vec_dot_q5_1_q8_1;
+ case GGML_TYPE_Q6_0 : return vec_dot_q6_0_q8_1;
+ case GGML_TYPE_Q8_0 : return vec_dot_q8_0_q8_1;
+ case GGML_TYPE_Q2_K : return vec_dot_q2_K_q8_1;
+ case GGML_TYPE_Q3_K : return vec_dot_q3_K_q8_1;
+ case GGML_TYPE_Q4_K : return vec_dot_q4_K_q8_1;
+ case GGML_TYPE_Q5_K : return vec_dot_q5_K_q8_1;
+ case GGML_TYPE_Q6_K : return vec_dot_q6_K_q8_1;
+ case GGML_TYPE_IQ2_XXS: return vec_dot_iq2_xxs_q8_1;
+ case GGML_TYPE_IQ2_XS : return vec_dot_iq2_xs_q8_1;
+ case GGML_TYPE_IQ2_S : return vec_dot_iq2_s_q8_1;
+ case GGML_TYPE_IQ3_XXS: return vec_dot_iq3_xxs_q8_1;
+ case GGML_TYPE_IQ1_S : return vec_dot_iq1_s_q8_1;
+ case GGML_TYPE_IQ1_M : return vec_dot_iq1_m_q8_1;
+ case GGML_TYPE_IQ4_NL : return vec_dot_iq4_nl_q8_1;
+ case GGML_TYPE_IQ4_XS : return vec_dot_iq4_xs_q8_1;
+ case GGML_TYPE_IQ3_S : return vec_dot_iq3_s_q8_1;
+ default : return nullptr; // any type without a case above yields a null function pointer
+ }
}
static constexpr __device__ int get_vdr_mmvq(ggml_type type) {
- return type == GGML_TYPE_Q4_0 ? VDR_Q4_0_Q8_1_MMVQ :
- type == GGML_TYPE_Q4_1 ? VDR_Q4_1_Q8_1_MMVQ :
- type == GGML_TYPE_Q5_0 ? VDR_Q5_0_Q8_1_MMVQ :
- type == GGML_TYPE_Q5_1 ? VDR_Q5_1_Q8_1_MMVQ :
- type == GGML_TYPE_Q6_0 ? VDR_Q6_0_Q8_1_MMVQ :
- type == GGML_TYPE_Q8_0 ? VDR_Q8_0_Q8_1_MMVQ :
- type == GGML_TYPE_Q2_K ? VDR_Q2_K_Q8_1_MMVQ :
- type == GGML_TYPE_Q3_K ? VDR_Q3_K_Q8_1_MMVQ :
- type == GGML_TYPE_Q4_K ? VDR_Q4_K_Q8_1_MMVQ :
- type == GGML_TYPE_Q5_K ? VDR_Q5_K_Q8_1_MMVQ :
- type == GGML_TYPE_Q6_K ? VDR_Q6_K_Q8_1_MMVQ :
- type == GGML_TYPE_IQ2_XXS ? VDR_IQ2_XXS_Q8_1_MMVQ :
- type == GGML_TYPE_IQ2_XS ? VDR_IQ2_XS_Q8_1_MMVQ :
- type == GGML_TYPE_IQ2_S ? VDR_IQ2_S_Q8_1_MMVQ :
- type == GGML_TYPE_IQ3_XXS ? VDR_IQ3_XXS_Q8_1_MMVQ :
- type == GGML_TYPE_IQ3_S ? VDR_IQ3_S_Q8_1_MMVQ :
- type == GGML_TYPE_IQ4_NL ? VDR_IQ4_NL_Q8_1_MMVQ :
- type == GGML_TYPE_IQ4_XS ? VDR_IQ4_XS_Q8_1_MMVQ :
- 1;
+ switch (type) { // return the VDR_*_Q8_1_MMVQ constant for this type; same mapping as the removed ternary chain
+ case GGML_TYPE_Q4_0 : return VDR_Q4_0_Q8_1_MMVQ;
+ case GGML_TYPE_Q4_1 : return VDR_Q4_1_Q8_1_MMVQ;
+ case GGML_TYPE_Q5_0 : return VDR_Q5_0_Q8_1_MMVQ;
+ case GGML_TYPE_Q5_1 : return VDR_Q5_1_Q8_1_MMVQ;
+ case GGML_TYPE_Q6_0 : return VDR_Q6_0_Q8_1_MMVQ;
+ case GGML_TYPE_Q8_0 : return VDR_Q8_0_Q8_1_MMVQ;
+ case GGML_TYPE_Q2_K : return VDR_Q2_K_Q8_1_MMVQ;
+ case GGML_TYPE_Q3_K : return VDR_Q3_K_Q8_1_MMVQ;
+ case GGML_TYPE_Q4_K : return VDR_Q4_K_Q8_1_MMVQ;
+ case GGML_TYPE_Q5_K : return VDR_Q5_K_Q8_1_MMVQ;
+ case GGML_TYPE_Q6_K : return VDR_Q6_K_Q8_1_MMVQ;
+ case GGML_TYPE_IQ2_XXS : return VDR_IQ2_XXS_Q8_1_MMVQ;
+ case GGML_TYPE_IQ2_XS : return VDR_IQ2_XS_Q8_1_MMVQ;
+ case GGML_TYPE_IQ2_S : return VDR_IQ2_S_Q8_1_MMVQ;
+ case GGML_TYPE_IQ3_XXS : return VDR_IQ3_XXS_Q8_1_MMVQ;
+ case GGML_TYPE_IQ3_S : return VDR_IQ3_S_Q8_1_MMVQ;
+ case GGML_TYPE_IQ4_NL : return VDR_IQ4_NL_Q8_1_MMVQ;
+ case GGML_TYPE_IQ4_XS : return VDR_IQ4_XS_Q8_1_MMVQ;
+ default : return 1; // types without a case yield 1; this includes IQ1_S and IQ1_M, which get_vec_dot_q_cuda does handle. NOTE(review): presumably intentional (unchanged from the ternary chain) - confirm
+ }
}
template <ggml_type type, int ncols_y>