summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorKawrakow <iwankawrakow@gmail.com>2025-02-09 18:59:33 +0200
committerGitHub <noreply@github.com>2025-02-09 18:59:33 +0200
commitcae2b81155fdad75b7beab3a835c438120412969 (patch)
treee5b84d2744af15e1218db1ac935b4bfc1c499cb0
parent33390c4b74fa52875d6028c5c9aaf84f17288c25 (diff)
FA: Add option to build all FA kernels (#197)
Similar to the CUDA situation. It is OFF by default. If OFF, only F16, Q8_0, Q6_0, and, if the CPU provides native BF16 support, BF16 FA kernels will be included. To enable all, cmake -DGGML_IQK_FA_ALL_QUANTS=1 ... This cuts compilation time for iqk_mul_mat.cpp by almost half (45 seconds vs 81 seconds on my Ryzen-7950X). Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
-rw-r--r--ggml/CMakeLists.txt2
-rw-r--r--ggml/src/CMakeLists.txt4
-rw-r--r--ggml/src/iqk/iqk_mul_mat.cpp66
3 files changed, 39 insertions, 33 deletions
diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt
index 90b37d5b..6775fdcb 100644
--- a/ggml/CMakeLists.txt
+++ b/ggml/CMakeLists.txt
@@ -130,6 +130,8 @@ option(GGML_CUDA_NO_VMM "ggml: do not try to use CUDA VMM"
option(GGML_CUDA_FA_ALL_QUANTS "ggml: compile all quants for FlashAttention" OFF)
option(GGML_CUDA_USE_GRAPHS "ggml: use CUDA graphs (llama.cpp only)" OFF)
+option(GGML_IQK_FA_ALL_QUANTS "ggml: compile all quants for IQK FlashAttention" OFF)
+
option(GGML_CURL "ggml: use libcurl to download model from an URL" OFF)
option(GGML_HIPBLAS "ggml: use hipBLAS" OFF)
option(GGML_HIP_UMA "ggml: use HIP unified memory architecture" OFF)
diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt
index da1746c8..3d1a2970 100644
--- a/ggml/src/CMakeLists.txt
+++ b/ggml/src/CMakeLists.txt
@@ -259,6 +259,10 @@ if (GGML_IQK_MUL_MAT)
add_compile_definitions(GGML_USE_IQK_MULMAT)
set(GGML_SOURCES_IQK_MM iqk/iqk_mul_mat.cpp)
set(GGML_HEADERS_IQK_MM iqk/iqk_mul_mat.h)
+ if (GGML_IQK_FA_ALL_QUANTS)
+ message(STATUS "Including all IQK FA kernels")
+ add_compile_definitions(GGML_IQK_FA_ALL_QUANTS)
+ endif()
endif()
if (GGML_LLAMAFILE)
diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp
index aeba2c59..ee0af7e9 100644
--- a/ggml/src/iqk/iqk_mul_mat.cpp
+++ b/ggml/src/iqk/iqk_mul_mat.cpp
@@ -15239,14 +15239,7 @@ struct FlashQKfp32 {
case 7: return std::make_pair(mul_mat<7>, 7);\
}\
}
- if constexpr (std::is_same_v<KHelper, HelperQ40<D, k_step>>) {
-#ifdef __aarch64__
- MAKE_FUNCS(mul_mat_qX_0_q8_0<DequantizerQ40, nq);
-#else
- MAKE_FUNCS(mul_mat_qX_0_q8_0_T<Q4_0_Unpacker, nq);
-#endif
- }
- else if constexpr (std::is_same_v<KHelper, HelperQ80<D, k_step>>) {
+ if constexpr (std::is_same_v<KHelper, HelperQ80<D, k_step>>) {
#ifdef __aarch64__
MAKE_FUNCS(mul_mat_qX_0_q8_0<DequantizerQ80, nq);
#else
@@ -15264,6 +15257,21 @@ struct FlashQKfp32 {
MAKE_FUNCS_ONLY_NRC(mul_mat_q8_0_r8_q8_1, nq);
#endif
}
+ else if constexpr (std::is_same_v<KHelper, HelperQ60<D, k_step>>) {
+#ifdef __aarch64__
+ MAKE_FUNCS(mul_mat_qX_0_q8_0<DequantizerQ60, nq);
+#else
+ MAKE_FUNCS(mul_mat_qX_1_q8_1_T<Q6_0_1_Unpacker, nq);
+#endif
+ }
+#if GGML_IQK_FA_ALL_QUANTS
+ else if constexpr (std::is_same_v<KHelper, HelperQ40<D, k_step>>) {
+#ifdef __aarch64__
+ MAKE_FUNCS(mul_mat_qX_0_q8_0<DequantizerQ40, nq);
+#else
+ MAKE_FUNCS(mul_mat_qX_0_q8_0_T<Q4_0_Unpacker, nq);
+#endif
+ }
else if constexpr (std::is_same_v<KHelper, HelperQ41<D, k_step>>) {
#ifdef __aarch64__
MAKE_FUNCS(mul_mat_qX_1_q8_1<DequantizerQ41, nq);
@@ -15278,13 +15286,7 @@ struct FlashQKfp32 {
MAKE_FUNCS(mul_mat_qX_1_q8_1_T<IQ4_NL_Unpacker, nq);
#endif
}
- else if constexpr (std::is_same_v<KHelper, HelperQ60<D, k_step>>) {
-#ifdef __aarch64__
- MAKE_FUNCS(mul_mat_qX_0_q8_0<DequantizerQ60, nq);
-#else
- MAKE_FUNCS(mul_mat_qX_1_q8_1_T<Q6_0_1_Unpacker, nq);
#endif
- }
else {
GGML_ASSERT(false);
}
@@ -15493,17 +15495,6 @@ struct FlashAttn {
template <typename KHelper, typename VHelper>
void compute(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv,
const float * q, const char * mask, float * qkv) {
-// if constexpr (std::is_same_v<KHelper, HelperQ40<D, k_step>> || std::is_same_v<KHelper, HelperQ41<D, k_step>> ||
-// std::is_same_v<KHelper, HelperIQ4nl<D, k_step>> ||
-// std::is_same_v<KHelper, HelperQ80<D, k_step>> ||
-// std::is_same_v<KHelper, HelperQ80R4<D, k_step>> ||
-// std::is_same_v<KHelper, HelperQ60<D, k_step>>) {
-// compute_helper_q<D, q_step, k_step, KHelper, VHelper, FlashQKfp32<D, q_step, k_step>>(
-// kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv);
-// } else {
-// compute_helper<D, q_step, k_step, KHelper, VHelper, FlashQKfp32<D, q_step, k_step>>(
-// kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv);
-// }
if constexpr (std::is_same_v<KHelper, HelperQ40<D, k_step>> || std::is_same_v<KHelper, HelperQ41<D, k_step>> ||
std::is_same_v<KHelper, HelperIQ4nl<D, k_step>> ||
std::is_same_v<KHelper, HelperQ60<D, k_step>>) {
@@ -16027,6 +16018,11 @@ inline void iqk_flash_helper_T(KHelper& kh, ggml_type type_v,
HelperQ80<D, k_step> vh(v, stride_v);
iqk_flash_helper<D, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
} break;
+ case GGML_TYPE_Q6_0: {
+ HelperQ60<D, k_step> vh(v, stride_v);
+ iqk_flash_helper<D, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
+ } break;
+#if GGML_IQK_FA_ALL_QUANTS
case GGML_TYPE_Q4_0: {
HelperQ40<D, k_step> vh(v, stride_v);
iqk_flash_helper<D, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
@@ -16039,10 +16035,7 @@ inline void iqk_flash_helper_T(KHelper& kh, ggml_type type_v,
HelperIQ4nl<D, k_step> vh(v, stride_v);
iqk_flash_helper<D, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
} break;
- case GGML_TYPE_Q6_0: {
- HelperQ60<D, k_step> vh(v, stride_v);
- iqk_flash_helper<D, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
- } break;
+#endif
default: break;
}
}
@@ -16062,6 +16055,11 @@ inline void iqk_flash_helper_T(ggml_type type_k, ggml_type type_v,
HelperQ80<D, k_step> kh(k, stride_k);
iqk_flash_helper_T<D, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv);
} break;
+ case GGML_TYPE_Q6_0: {
+ HelperQ60<D, k_step> kh(k, stride_k);
+ iqk_flash_helper_T<D, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv);
+ } break;
+#if GGML_IQK_FA_ALL_QUANTS
case GGML_TYPE_Q4_0: {
HelperQ40<D, k_step> kh(k, stride_k);
iqk_flash_helper_T<D, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv);
@@ -16074,10 +16072,7 @@ inline void iqk_flash_helper_T(ggml_type type_k, ggml_type type_v,
HelperIQ4nl<D, k_step> kh(k, stride_k);
iqk_flash_helper_T<D, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv);
} break;
- case GGML_TYPE_Q6_0: {
- HelperQ60<D, k_step> kh(k, stride_k);
- iqk_flash_helper_T<D, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv);
- } break;
+#endif
default: break;
}
@@ -16087,8 +16082,12 @@ inline bool flash_attn_is_supported(ggml_type type) {
#ifdef __AVX512BF16__
if (type == GGML_TYPE_BF16) return true;
#endif
+#if GGML_IQK_FA_ALL_QUANTS
if (type == GGML_TYPE_F16 || type == GGML_TYPE_Q8_0 || type == GGML_TYPE_Q4_0 || type == GGML_TYPE_Q4_1 ||
type == GGML_TYPE_Q6_0 || type == GGML_TYPE_IQ4_NL) return true;
+#else
+ if (type == GGML_TYPE_F16 || type == GGML_TYPE_Q8_0 || type == GGML_TYPE_Q6_0) return true;
+#endif
return false;
}
}
@@ -16115,6 +16114,7 @@ bool iqk_flash_attn_noalibi(int int_type_k, // type of k
auto type_v = ggml_type(int_type_v);
if (!flash_attn_is_supported(type_k) || !flash_attn_is_supported(type_v)) return false;
if (!mask || nk1%32 != 0) return false; // the implementation assumes mask is not null and nk is a multiple of 32
+ if (D != 64 && D != 96 && D != 128 && D != 256) return false;
auto ck = (const char *)k;
auto cv = (const char *)v;