From cae2b81155fdad75b7beab3a835c438120412969 Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Sun, 9 Feb 2025 18:59:33 +0200 Subject: FA: Add option to build all FA kernels (#197) Similar to the CUDA situation. It is OFF by default. If OFF, only F16, Q8_0, Q6_0, and, if the CPU provides native BF16 support, BF16 FA kernels will be included. To enable all, cmake -DGGML_IQK_FA_ALL_QUANTS=1 ... This cuts compilation time for iqk_mul_mat.cpp by almost half (45 seconds vs 81 seconds on my Ryzen-7950X). Co-authored-by: Iwan Kawrakow --- ggml/src/CMakeLists.txt | 4 +++ ggml/src/iqk/iqk_mul_mat.cpp | 66 ++++++++++++++++++++++---------------------- 2 files changed, 37 insertions(+), 33 deletions(-) (limited to 'ggml/src') diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt index da1746c8..3d1a2970 100644 --- a/ggml/src/CMakeLists.txt +++ b/ggml/src/CMakeLists.txt @@ -259,6 +259,10 @@ if (GGML_IQK_MUL_MAT) add_compile_definitions(GGML_USE_IQK_MULMAT) set(GGML_SOURCES_IQK_MM iqk/iqk_mul_mat.cpp) set(GGML_HEADERS_IQK_MM iqk/iqk_mul_mat.h) + if (GGML_IQK_FA_ALL_QUANTS) + message(STATUS "Including all IQK FA kernels") + add_compile_definitions(GGML_IQK_FA_ALL_QUANTS) + endif() endif() if (GGML_LLAMAFILE) diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index aeba2c59..ee0af7e9 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -15239,14 +15239,7 @@ struct FlashQKfp32 { case 7: return std::make_pair(mul_mat<7>, 7);\ }\ } - if constexpr (std::is_same_v>) { -#ifdef __aarch64__ - MAKE_FUNCS(mul_mat_qX_0_q8_0>) { + if constexpr (std::is_same_v>) { #ifdef __aarch64__ MAKE_FUNCS(mul_mat_qX_0_q8_0>) { +#ifdef __aarch64__ + MAKE_FUNCS(mul_mat_qX_0_q8_0>) { +#ifdef __aarch64__ + MAKE_FUNCS(mul_mat_qX_0_q8_0>) { @@ -15278,13 +15286,7 @@ struct FlashQKfp32 { MAKE_FUNCS(mul_mat_qX_1_q8_1_T>) { -#ifdef __aarch64__ - MAKE_FUNCS(mul_mat_qX_0_q8_0 void compute(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv, const float * q, const char * mask, float * qkv) { -// if constexpr (std::is_same_v> || std::is_same_v> || -// std::is_same_v> || -// std::is_same_v> || -// std::is_same_v> || -// std::is_same_v>) { -// compute_helper_q>( -// kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv); -// } else { -// compute_helper>( -// kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv); -// } if constexpr (std::is_same_v> || std::is_same_v> || std::is_same_v> || std::is_same_v>) { @@ -16027,6 +16018,11 @@ inline void iqk_flash_helper_T(KHelper& kh, ggml_type type_v, HelperQ80 vh(v, stride_v); iqk_flash_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv); } break; + case GGML_TYPE_Q6_0: { + HelperQ60 vh(v, stride_v); + iqk_flash_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv); + } break; +#if GGML_IQK_FA_ALL_QUANTS case GGML_TYPE_Q4_0: { HelperQ40 vh(v, stride_v); iqk_flash_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv); @@ -16039,10 +16035,7 @@ inline void iqk_flash_helper_T(KHelper& kh, ggml_type type_v, HelperIQ4nl vh(v, stride_v); iqk_flash_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv); } break; - case GGML_TYPE_Q6_0: { - HelperQ60 vh(v, stride_v); - iqk_flash_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv); - } break; +#endif default: break; } } @@ -16062,6 +16055,11 @@ inline void iqk_flash_helper_T(ggml_type type_k, ggml_type type_v, HelperQ80 kh(k, stride_k); iqk_flash_helper_T(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv); } break; + case GGML_TYPE_Q6_0: { + HelperQ60 kh(k, stride_k); + iqk_flash_helper_T(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv); + } break; +#if GGML_IQK_FA_ALL_QUANTS case GGML_TYPE_Q4_0: { HelperQ40 kh(k, stride_k); iqk_flash_helper_T(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv); @@ -16074,10 +16072,7 @@ inline void iqk_flash_helper_T(ggml_type type_k, ggml_type type_v, HelperIQ4nl kh(k, stride_k); iqk_flash_helper_T(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv); } break; - case GGML_TYPE_Q6_0: { - HelperQ60 kh(k, stride_k); - iqk_flash_helper_T(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv); - } break; +#endif default: break; } @@ -16087,8 +16082,12 @@ inline bool flash_attn_is_supported(ggml_type type) { #ifdef __AVX512BF16__ if (type == GGML_TYPE_BF16) return true; #endif +#if GGML_IQK_FA_ALL_QUANTS if (type == GGML_TYPE_F16 || type == GGML_TYPE_Q8_0 || type == GGML_TYPE_Q4_0 || type == GGML_TYPE_Q4_1 || type == GGML_TYPE_Q6_0 || type == GGML_TYPE_IQ4_NL) return true; +#else + if (type == GGML_TYPE_F16 || type == GGML_TYPE_Q8_0 || type == GGML_TYPE_Q6_0) return true; +#endif return false; } } @@ -16115,6 +16114,7 @@ bool iqk_flash_attn_noalibi(int int_type_k, // type of k auto type_v = ggml_type(int_type_v); if (!flash_attn_is_supported(type_k) || !flash_attn_is_supported(type_v)) return false; if (!mask || nk1%32 != 0) return false; // the implementation assumes mask is not null and nk is a multiple of 32 + if (D != 64 && D != 96 && D != 128 && D != 256) return false; auto ck = (const char *)k; auto cv = (const char *)v; -- cgit v1.2.3