diff options
author | Kawrakow <iwankawrakow@gmail.com> | 2025-02-09 18:59:33 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2025-02-09 18:59:33 +0200 |
commit | cae2b81155fdad75b7beab3a835c438120412969 (patch) | |
tree | e5b84d2744af15e1218db1ac935b4bfc1c499cb0 | |
parent | 33390c4b74fa52875d6028c5c9aaf84f17288c25 (diff) |
FA: Add option to build all FA kernels (#197)
Similar to the CUDA situation.
It is OFF by default.
If OFF, only F16, Q8_0, Q6_0, and, if the CPU provides native
BF16 support, BF16 FA kernels will be included.
To enable all, cmake -DGGML_IQK_FA_ALL_QUANTS=1 ...
This cuts compilation time for iqk_mul_mat.cpp by almost half
(45 seconds vs 81 seconds on my Ryzen-7950X).
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
-rw-r--r-- | ggml/CMakeLists.txt | 2 | ||||
-rw-r--r-- | ggml/src/CMakeLists.txt | 4 | ||||
-rw-r--r-- | ggml/src/iqk/iqk_mul_mat.cpp | 66 |
3 files changed, 39 insertions, 33 deletions
diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index 90b37d5b..6775fdcb 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ -130,6 +130,8 @@ option(GGML_CUDA_NO_VMM "ggml: do not try to use CUDA VMM" option(GGML_CUDA_FA_ALL_QUANTS "ggml: compile all quants for FlashAttention" OFF) option(GGML_CUDA_USE_GRAPHS "ggml: use CUDA graphs (llama.cpp only)" OFF) +option(GGML_IQK_FA_ALL_QUANTS "ggml: compile all quants for IQK FlashAttention" OFF) + option(GGML_CURL "ggml: use libcurl to download model from an URL" OFF) option(GGML_HIPBLAS "ggml: use hipBLAS" OFF) option(GGML_HIP_UMA "ggml: use HIP unified memory architecture" OFF) diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt index da1746c8..3d1a2970 100644 --- a/ggml/src/CMakeLists.txt +++ b/ggml/src/CMakeLists.txt @@ -259,6 +259,10 @@ if (GGML_IQK_MUL_MAT) add_compile_definitions(GGML_USE_IQK_MULMAT) set(GGML_SOURCES_IQK_MM iqk/iqk_mul_mat.cpp) set(GGML_HEADERS_IQK_MM iqk/iqk_mul_mat.h) + if (GGML_IQK_FA_ALL_QUANTS) + message(STATUS "Including all IQK FA kernels") + add_compile_definitions(GGML_IQK_FA_ALL_QUANTS) + endif() endif() if (GGML_LLAMAFILE) diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index aeba2c59..ee0af7e9 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -15239,14 +15239,7 @@ struct FlashQKfp32 { case 7: return std::make_pair(mul_mat<7>, 7);\ }\ } - if constexpr (std::is_same_v<KHelper, HelperQ40<D, k_step>>) { -#ifdef __aarch64__ - MAKE_FUNCS(mul_mat_qX_0_q8_0<DequantizerQ40, nq); -#else - MAKE_FUNCS(mul_mat_qX_0_q8_0_T<Q4_0_Unpacker, nq); -#endif - } - else if constexpr (std::is_same_v<KHelper, HelperQ80<D, k_step>>) { + if constexpr (std::is_same_v<KHelper, HelperQ80<D, k_step>>) { #ifdef __aarch64__ MAKE_FUNCS(mul_mat_qX_0_q8_0<DequantizerQ80, nq); #else @@ -15264,6 +15257,21 @@ struct FlashQKfp32 { MAKE_FUNCS_ONLY_NRC(mul_mat_q8_0_r8_q8_1, nq); #endif } + else if constexpr (std::is_same_v<KHelper, HelperQ60<D, k_step>>) { +#ifdef __aarch64__ + MAKE_FUNCS(mul_mat_qX_0_q8_0<DequantizerQ60, nq); +#else + MAKE_FUNCS(mul_mat_qX_1_q8_1_T<Q6_0_1_Unpacker, nq); +#endif + } +#if GGML_IQK_FA_ALL_QUANTS + else if constexpr (std::is_same_v<KHelper, HelperQ40<D, k_step>>) { +#ifdef __aarch64__ + MAKE_FUNCS(mul_mat_qX_0_q8_0<DequantizerQ40, nq); +#else + MAKE_FUNCS(mul_mat_qX_0_q8_0_T<Q4_0_Unpacker, nq); +#endif + } else if constexpr (std::is_same_v<KHelper, HelperQ41<D, k_step>>) { #ifdef __aarch64__ MAKE_FUNCS(mul_mat_qX_1_q8_1<DequantizerQ41, nq); @@ -15278,13 +15286,7 @@ struct FlashQKfp32 { MAKE_FUNCS(mul_mat_qX_1_q8_1_T<IQ4_NL_Unpacker, nq); #endif } - else if constexpr (std::is_same_v<KHelper, HelperQ60<D, k_step>>) { -#ifdef __aarch64__ - MAKE_FUNCS(mul_mat_qX_0_q8_0<DequantizerQ60, nq); -#else - MAKE_FUNCS(mul_mat_qX_1_q8_1_T<Q6_0_1_Unpacker, nq); #endif - } else { GGML_ASSERT(false); } @@ -15493,17 +15495,6 @@ struct FlashAttn { template <typename KHelper, typename VHelper> void compute(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv, const float * q, const char * mask, float * qkv) { -// if constexpr (std::is_same_v<KHelper, HelperQ40<D, k_step>> || std::is_same_v<KHelper, HelperQ41<D, k_step>> || -// std::is_same_v<KHelper, HelperIQ4nl<D, k_step>> || -// std::is_same_v<KHelper, HelperQ80<D, k_step>> || -// std::is_same_v<KHelper, HelperQ80R4<D, k_step>> || -// std::is_same_v<KHelper, HelperQ60<D, k_step>>) { -// compute_helper_q<D, q_step, k_step, KHelper, VHelper, FlashQKfp32<D, q_step, k_step>>( -// kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv); -// } else { -// compute_helper<D, q_step, k_step, KHelper, VHelper, FlashQKfp32<D, q_step, k_step>>( -// kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv); -// } if constexpr (std::is_same_v<KHelper, HelperQ40<D, k_step>> || std::is_same_v<KHelper, HelperQ41<D, k_step>> || std::is_same_v<KHelper, HelperIQ4nl<D, k_step>> || std::is_same_v<KHelper, HelperQ60<D, k_step>>) { @@ -16027,6 +16018,11 @@ inline void iqk_flash_helper_T(KHelper& kh, ggml_type type_v, HelperQ80<D, k_step> vh(v, stride_v); iqk_flash_helper<D, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv); } break; + case GGML_TYPE_Q6_0: { + HelperQ60<D, k_step> vh(v, stride_v); + iqk_flash_helper<D, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv); + } break; +#if GGML_IQK_FA_ALL_QUANTS case GGML_TYPE_Q4_0: { HelperQ40<D, k_step> vh(v, stride_v); iqk_flash_helper<D, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv); @@ -16039,10 +16035,7 @@ inline void iqk_flash_helper_T(KHelper& kh, ggml_type type_v, HelperIQ4nl<D, k_step> vh(v, stride_v); iqk_flash_helper<D, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv); } break; - case GGML_TYPE_Q6_0: { - HelperQ60<D, k_step> vh(v, stride_v); - iqk_flash_helper<D, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv); - } break; +#endif default: break; } } @@ -16062,6 +16055,11 @@ inline void iqk_flash_helper_T(ggml_type type_k, ggml_type type_v, HelperQ80<D, k_step> kh(k, stride_k); iqk_flash_helper_T<D, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv); } break; + case GGML_TYPE_Q6_0: { + HelperQ60<D, k_step> kh(k, stride_k); + iqk_flash_helper_T<D, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv); + } break; +#if GGML_IQK_FA_ALL_QUANTS case GGML_TYPE_Q4_0: { HelperQ40<D, k_step> kh(k, stride_k); iqk_flash_helper_T<D, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv); @@ -16074,10 +16072,7 @@ inline void iqk_flash_helper_T(ggml_type type_k, ggml_type type_v, HelperIQ4nl<D, k_step> kh(k, stride_k); iqk_flash_helper_T<D, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv); } break; - case GGML_TYPE_Q6_0: { - HelperQ60<D, k_step> kh(k, stride_k); - iqk_flash_helper_T<D, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv); - } break; +#endif default: break; } @@ -16087,8 +16082,12 @@ inline bool flash_attn_is_supported(ggml_type type) { #ifdef __AVX512BF16__ if (type == GGML_TYPE_BF16) return true; #endif +#if GGML_IQK_FA_ALL_QUANTS if (type == GGML_TYPE_F16 || type == GGML_TYPE_Q8_0 || type == GGML_TYPE_Q4_0 || type == GGML_TYPE_Q4_1 || type == GGML_TYPE_Q6_0 || type == GGML_TYPE_IQ4_NL) return true; +#else + if (type == GGML_TYPE_F16 || type == GGML_TYPE_Q8_0 || type == GGML_TYPE_Q6_0) return true; +#endif return false; } } @@ -16115,6 +16114,7 @@ bool iqk_flash_attn_noalibi(int int_type_k, // type of k auto type_v = ggml_type(int_type_v); if (!flash_attn_is_supported(type_k) || !flash_attn_is_supported(type_v)) return false; if (!mask || nk1%32 != 0) return false; // the implementation assumes mask is not null and nk is a multiple of 32 + if (D != 64 && D != 96 && D != 128 && D != 256) return false; auto ck = (const char *)k; auto cv = (const char *)v; |