author     Kawrakow <48489457+ikawrakow@users.noreply.github.com>    2024-09-05 07:46:47 +0300
committer  GitHub <noreply@github.com>                               2024-09-05 07:46:47 +0300
commit     7b1b2b2c06c1729139135c9e47611af7161de6f7 (patch)
tree       ab79924dbb9f2ff780dd669fa65f826aae74d0b7
parent     f17d0d72f565bf24d6eb8aa67d6618cdc143961d (diff)
Zen4 Flash Attention - bf16 support (#38)
* Zen4 Flash Attention: WIP bf16
* Zen4 Flash Attention: bf16 seems to be working
* Zen4 Flash Attention: improving bf16
* Zen4 Flash Attention: improving bf16

  It is better (slightly faster) to first convert Q to bf16 before processing each block of q_step rows. This requires D*q_step*sizeof(bf16) bytes, which is at most 4 kB for the head sizes we support, so we can simply allocate it on the stack instead of reserving and passing a work buffer in ggml.

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
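A quick sanity check of the size bound mentioned above (a minimal sketch, assuming the largest supported head size D = 256 and the q_step = 8 used in the dispatch below; uint16_t stands in for ggml_bf16_t):

#include <cstdio>
#include <cstdint>

int main() {
    constexpr int    D      = 256;  // largest supported head size (static_assert D <= 256)
    constexpr int    q_step = 8;    // rows of Q^T converted per block
    constexpr size_t bytes  = (size_t)D * q_step * sizeof(uint16_t);  // bf16 is 2 bytes
    std::printf("bf16 Q buffer: %zu bytes\n", bytes);  // prints 4096 bytes, i.e. at most 4 kB,
                                                       // small enough to live on the stack
    return 0;
}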
-rw-r--r--  common/common.cpp                       3
-rw-r--r--  examples/llama-bench/llama-bench.cpp    3
-rw-r--r--  ggml/src/iqk/iqk_mul_mat.cpp          712
3 files changed, 538 insertions, 180 deletions
diff --git a/common/common.cpp b/common/common.cpp
index c86d364f..6c298d2d 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -2221,6 +2221,9 @@ static ggml_type kv_cache_type_from_str(const std::string & s) {
if (s == "f16") {
return GGML_TYPE_F16;
}
+ if (s == "bf16") {
+ return GGML_TYPE_BF16;
+ }
if (s == "q8_0") {
return GGML_TYPE_Q8_0;
}
diff --git a/examples/llama-bench/llama-bench.cpp b/examples/llama-bench/llama-bench.cpp
index 813d7bae..fc77be50 100644
--- a/examples/llama-bench/llama-bench.cpp
+++ b/examples/llama-bench/llama-bench.cpp
@@ -306,6 +306,9 @@ static ggml_type ggml_type_from_name(const std::string & s) {
if (s == "f16") {
return GGML_TYPE_F16;
}
+ if (s == "bf16") {
+ return GGML_TYPE_BF16;
+ }
if (s == "q8_0") {
return GGML_TYPE_Q8_0;
}
diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp
index 511eea01..84514ddc 100644
--- a/ggml/src/iqk/iqk_mul_mat.cpp
+++ b/ggml/src/iqk/iqk_mul_mat.cpp
@@ -6192,7 +6192,6 @@ struct HelperF16 final : public BaseHelper<step> {
load(l1+0, vk+0);
load(l1+1, vk+D/16);
}
-
};
template <int D, int step>
@@ -6356,29 +6355,9 @@ struct HelperQ41 final : public BaseHelper<step> {
const __m128i mask = _mm_set1_epi8(0xf);
};
-
-// Some of the methods in FlashAttn have two identical implementations that only differ by
-// one version using a loop over the template parameter q_step, while the other using a loop
-// over an input parameter nq (these are loops over the rows of q^T). I dislike this a lot,
-but performance drops significantly if I remove the version with fixed q_step iterations.
-// We only instantiate FlashAttn with q_step = 1 and q_step = 4 or 8 (depending on head size D),
-// so when we have to process Nq rows, we process q_step*(Nq/q_step) using fixed q_step loops,
-// and use the variable nq version (with lower performance) only for the remaining i1...q_step-1
-// rows (if Nq is not a multiple of q_step). One could have made the number of q^T rows to
-process a template parameter of such functions, but this would result in the compiler generating
-q_step-1 versions of these functions for us, which I thought was too much with q_step = 8.
-template <int D, int q_step, int k_step>
-struct FlashAttn {
- static_assert(D%16 == 0 && D <= 256);
- static_assert(k_step%16 == 0);
- static_assert(q_step <= 4 || q_step%4 == 0);
-
- constexpr static bool is_small_head = D <= 128;
-
- constexpr static int vk_size = is_small_head ? D/8 : D/16;
- static_assert(2*q_step <= vk_size);
-
- FlashAttn(float scale, float softcap) : vscale(_mm512_set1_ps(scale)), softcap(softcap), h_inf(GGML_FP32_TO_FP16(-INFINITY)) {}
+template <int q_step, int k_step>
+struct FlashMS {
+ FlashMS(float scale, float softcap) : vscale(_mm512_set1_ps(scale)), softcap(softcap), h_inf(GGML_FP32_TO_FP16(-INFINITY)) {}
inline void init_qstep() {
for (int j = 0; j < q_step; ++j) {
@@ -6386,47 +6365,7 @@ struct FlashAttn {
}
}
- template <bool small = is_small_head, class = std::enable_if<small>>
- inline void mult_mask_kq_one(int l1, int m1, int stride_q, int stride_m, const float * q, const char * mask, __m512 * qv) {
- // q index is q_step*i1 + m1
- // k index is k_step*k1 + l1
- const ggml_half * mp = (const ggml_half *)(mask + stride_m*m1);
- cache[k_step*m1 + l1 + 0] = cache[k_step*m1 + l1 + 1] = -INFINITY;
- if (mp[l1+0] == h_inf && mp[l1+1] == h_inf) {
- return;
- }
- auto qr = q + m1*stride_q;
- for (int i = 0; i < D/16; ++i) qv[i] = _mm512_loadu_ps(qr + 16*i);
- if (mp[l1+0] != h_inf) {
- auto vsum = _mm512_setzero_ps();
- for (int i = 0; i < D/16; ++i) vsum = _mm512_fmadd_ps(vk[i], qv[i], vsum);
- cache[k_step*m1 + l1 + 0] = _mm512_reduce_add_ps(vsum);
- }
- if (mp[l1+1] != h_inf) {
- auto vsum = _mm512_setzero_ps();
- for (int i = 0; i < D/16; ++i) vsum = _mm512_fmadd_ps(vk[i+D/16], qv[i], vsum);
- cache[k_step*m1 + l1 + 1] = _mm512_reduce_add_ps(vsum);
- }
- }
-
- template <bool small = is_small_head, class = std::enable_if<!small>>
- inline void mult_mask_kq_one(int l1, int m1, int stride_q, int stride_m, const float * q, const char * mask) {
- // q index is q_step*i1 + m1
- // k index is k_step*k1 + l1
- const ggml_half * mp = (const ggml_half *)(mask + stride_m*m1);
- if (mp[l1] == h_inf) {
- cache[k_step*m1 + l1] = -INFINITY;
- return;
- }
- auto qr = q + m1*stride_q;
- auto vsum = _mm512_setzero_ps();
- for (int i = 0; i < D/16; ++i) {
- vsum = _mm512_fmadd_ps(vk[i], _mm512_loadu_ps(qr + 16*i), vsum);
- }
- cache[k_step*m1 + l1] = _mm512_reduce_add_ps(vsum);
- }
-
- inline void update_M_S(int j) {
+ inline void update_M_S(int j, __m512 * vk) {
if (softcap <= 0.0f) {
for (int l = 0; l < k_step/16; ++l) vk[l] = _mm512_mul_ps(vscale, _mm512_loadu_ps(cache + k_step*j + 16*l));
} else {
@@ -6464,187 +6403,585 @@ struct FlashAttn {
S[j] += reduce_T<_mm512_reduce_add_ps, _mm512_add_ps>(vk);
}
- inline void normalize_and_store(int j, const float * R, float * qkv) const {
- GGML_ASSERT(S[j] > 0);
- auto norm = _mm512_set1_ps(1/S[j]);
+ float cache[q_step*k_step];
+ float S[q_step], M[q_step];
+ int need_scaling[q_step];
+ __m512 vms[q_step];
+ const __m512 vscale;
+ const float softcap;
+ const ggml_half h_inf;
+
+ typedef __m512 (*combine_t)(__m512, __m512);
+ typedef float (*reduce_t)(__m512);
+ template <reduce_t Op, combine_t Op_combine>
+ static inline float reduce_T(const __m512 * vals) {
+ float result;
+ if constexpr (k_step/16 == 1) {
+ result = Op(vals[0]);
+ }
+ else if constexpr (k_step/16 == 2) {
+ result = Op(Op_combine(vals[0], vals[1]));
+ }
+ else {
+ auto vmax = Op_combine(vals[0], vals[1]);
+ for (int l = 2; l < k_step/16; ++l) vmax = Op_combine(vmax, vals[l]);
+ result = Op(vmax);
+ }
+ return result;
+ }
+};
+
+template <int D, int q_step, int k_step>
+struct FlashQKV {
+
+ // This fails for head sizes of 80 and 112 as D/16 is odd, so we cannot do steps of 2
+ // Hence, for now, we will not handle head sizes of 80 and 112
+ template <typename VHelper>
+ inline void accumulate_qkv(const VHelper& vh, const FlashMS<q_step, k_step>& fms) {
+ __m512 vk[2*q_step];
+ for (int i = 0; i < D/16; i += 2) {
+ for (int j = 0; j < q_step; ++j) {
+ if (fms.need_scaling[j] == 2) {
+ vk[2*j+0] = vk[2*j+1] = _mm512_setzero_ps();
+ } else {
+ auto R = qkv_cache + D*j;
+ vk[2*j+0] = _mm512_loadu_ps(R + 16*i);
+ vk[2*j+1] = _mm512_loadu_ps(R + 16*i + 16);
+ if (fms.need_scaling[j] == 1) {
+ vk[2*j+0] = _mm512_mul_ps(vk[2*j+0], fms.vms[j]);
+ vk[2*j+1] = _mm512_mul_ps(vk[2*j+1], fms.vms[j]);
+ }
+ }
+ }
+ __m512 v1, v2;
+ for (int l1 = 0; l1 < k_step; ++l1) {
+ vh.load(l1, i, v1, v2);
+ for (int j = 0; j < q_step; ++j) {
+ auto vs = _mm512_set1_ps(fms.cache[k_step*j + l1]);
+ vk[2*j+0] = _mm512_fmadd_ps(v1, vs, vk[2*j+0]);
+ vk[2*j+1] = _mm512_fmadd_ps(v2, vs, vk[2*j+1]);
+ }
+ }
+ for (int j = 0; j < q_step; ++j) {
+ auto R = qkv_cache + D*j;
+ _mm512_storeu_ps(R + 16*i, vk[2*j+0]);
+ _mm512_storeu_ps(R + 16*i + 16, vk[2*j+1]);
+ }
+ }
+ }
+
+ template <typename VHelper, int Nq = q_step, class = std::enable_if<Nq >= 2>>
+ inline void accumulate_qkv(int nq1, const VHelper& vh, const FlashMS<q_step, k_step>& fms) {
+ __m512 vk[2*q_step];
+ for (int i = 0; i < D/16; i += 2) {
+ for (int j = 0; j < nq1; ++j) {
+ if (fms.need_scaling[j] == 2) {
+ vk[2*j+0] = vk[2*j+1] = _mm512_setzero_ps();
+ } else {
+ auto R = qkv_cache + D*j;
+ vk[2*j+0] = _mm512_loadu_ps(R + 16*i);
+ vk[2*j+1] = _mm512_loadu_ps(R + 16*i + 16);
+ if (fms.need_scaling[j] == 1) {
+ vk[2*j+0] = _mm512_mul_ps(vk[2*j+0], fms.vms[j]);
+ vk[2*j+1] = _mm512_mul_ps(vk[2*j+1], fms.vms[j]);
+ }
+ }
+ }
+ __m512 v1, v2;
+ for (int l1 = 0; l1 < k_step; ++l1) {
+ vh.load(l1, i, v1, v2);
+ for (int j = 0; j < nq1; ++j) {
+ auto vs = _mm512_set1_ps(fms.cache[k_step*j + l1]);
+ vk[2*j+0] = _mm512_fmadd_ps(v1, vs, vk[2*j+0]);
+ vk[2*j+1] = _mm512_fmadd_ps(v2, vs, vk[2*j+1]);
+ }
+ }
+ for (int j = 0; j < nq1; ++j) {
+ auto R = qkv_cache + D*j;
+ _mm512_storeu_ps(R + 16*i, vk[2*j+0]);
+ _mm512_storeu_ps(R + 16*i + 16, vk[2*j+1]);
+ }
+ }
+ }
+
+ inline void normalize_and_store(const FlashMS<q_step, k_step>& fms, int j, const float * R, float * qkv) const {
+ GGML_ASSERT(fms.S[j] > 0);
+ auto norm = _mm512_set1_ps(1/fms.S[j]);
for (int i = 0; i < D/16; ++i) {
auto r = _mm512_loadu_ps(R + 16*i);
_mm512_storeu_ps(qkv + 16*i, _mm512_mul_ps(norm, r));
}
}
- inline void normalize_and_store(int nq1, int stride_qkv, float * qkv) const {
+ inline void normalize_and_store(const FlashMS<q_step, k_step>& fms, int nq1, int stride_qkv, float * qkv) const {
auto R = qkv_cache;
for (int j = 0; j < nq1; ++j) {
- normalize_and_store(j, R, qkv);
+ normalize_and_store(fms, j, R, qkv);
qkv += stride_qkv;
R += D;
}
}
- inline void normalize_and_store(int stride_qkv, float * qkv) const {
+ inline void normalize_and_store(const FlashMS<q_step, k_step>& fms, int stride_qkv, float * qkv) const {
auto R = qkv_cache;
for (int j = 0; j < q_step; ++j) {
- normalize_and_store(j, R, qkv);
+ normalize_and_store(fms, j, R, qkv);
qkv += stride_qkv;
R += D;
}
}
+ float qkv_cache[D*q_step];
+};
+
+template <int D, int q_step, int k_step>
+struct FlashQKfp32 {
+ static_assert(D%16 == 0 && D <= 256);
+ static_assert(k_step%16 == 0);
+ static_assert(q_step <= 4 || q_step%4 == 0);
+
+ constexpr static bool is_small_head = D <= 128;
+
+ template <bool small = is_small_head, class = std::enable_if<small>>
+ static inline void mult_mask_kq_one(int l1, int m1, int stride_q, int stride_m, const float * q, const char * mask,
+ __m512 * qv, __m512 * vk, FlashMS<q_step, k_step>& fms) {
+ // q index is q_step*i1 + m1
+ // k index is k_step*k1 + l1
+ const ggml_half * mp = (const ggml_half *)(mask + stride_m*m1);
+ fms.cache[k_step*m1 + l1 + 0] = fms.cache[k_step*m1 + l1 + 1] = -INFINITY;
+ if (mp[l1+0] == fms.h_inf && mp[l1+1] == fms.h_inf) {
+ return;
+ }
+ auto qr = q + m1*stride_q;
+ for (int i = 0; i < D/16; ++i) qv[i] = _mm512_loadu_ps(qr + 16*i);
+ if (mp[l1+0] != fms.h_inf) {
+ auto vsum = _mm512_setzero_ps();
+ for (int i = 0; i < D/16; ++i) vsum = _mm512_fmadd_ps(vk[i], qv[i], vsum);
+ fms.cache[k_step*m1 + l1 + 0] = _mm512_reduce_add_ps(vsum);
+ }
+ if (mp[l1+1] != fms.h_inf) {
+ auto vsum = _mm512_setzero_ps();
+ for (int i = 0; i < D/16; ++i) vsum = _mm512_fmadd_ps(vk[i+D/16], qv[i], vsum);
+ fms.cache[k_step*m1 + l1 + 1] = _mm512_reduce_add_ps(vsum);
+ }
+ }
+
+ template <bool small = is_small_head, class = std::enable_if<!small>>
+ static inline void mult_mask_kq_one(int l1, int m1, int stride_q, int stride_m, const float * q, const char * mask,
+ __m512 * vk, FlashMS<q_step, k_step>& fms) {
+ // q index is q_step*i1 + m1
+ // k index is k_step*k1 + l1
+ const ggml_half * mp = (const ggml_half *)(mask + stride_m*m1);
+ if (mp[l1] == fms.h_inf) {
+ fms.cache[k_step*m1 + l1] = -INFINITY;
+ return;
+ }
+ auto qr = q + m1*stride_q;
+ auto vsum = _mm512_setzero_ps();
+ for (int i = 0; i < D/16; ++i) {
+ vsum = _mm512_fmadd_ps(vk[i], _mm512_loadu_ps(qr + 16*i), vsum);
+ }
+ fms.cache[k_step*m1 + l1] = _mm512_reduce_add_ps(vsum);
+ }
+
template <typename KHelper, bool small = is_small_head, class = std::enable_if<small>>
- inline void mult_mask_kq(const KHelper& kh, int stride_q, int stride_m, const float * q, const char * mask) {
+ static inline void mult_mask_kq(const KHelper& kh, int stride_q, int stride_m, const float * q, const char * mask,
+ FlashMS<q_step, k_step>& fms) {
__m512 qv[D/16];
+ __m512 vk[D/8];
for (int l1 = 0; l1 < k_step; l1 += 2) {
kh.load_2(l1, vk);
for (int m1 = 0; m1 < q_step; ++m1) {
- mult_mask_kq_one(l1, m1, stride_q, stride_m, q, mask, qv);
+ mult_mask_kq_one(l1, m1, stride_q, stride_m, q, mask, qv, vk, fms);
}
}
}
template <typename KHelper, bool small = is_small_head, class = std::enable_if<!small>>
- inline void mult_mask_kq_l(const KHelper& kh, int stride_q, int stride_m,
- const float * q, const char * mask) {
+ static inline void mult_mask_kq_l(const KHelper& kh, int stride_q, int stride_m,
+ const float * q, const char * mask, FlashMS<q_step, k_step>& fms) {
+ __m512 vk[D/16];
for (int l1 = 0; l1 < k_step; ++l1) {
kh.load(l1, vk);
for (int m1 = 0; m1 < q_step; ++m1) {
- mult_mask_kq_one(l1, m1, stride_q, stride_m, q, mask);
+ mult_mask_kq_one(l1, m1, stride_q, stride_m, q, mask, vk, fms);
}
}
}
template <typename KHelper, bool small = is_small_head, class = std::enable_if<small>>
- inline void mult_mask_kq(int nq, const KHelper& kh, int stride_q, int stride_m, const float * q, const char * mask) {
+ static inline void mult_mask_kq(int nq, const KHelper& kh, int stride_q, int stride_m, const float * q, const char * mask,
+ FlashMS<q_step, k_step>& fms) {
__m512 qv[D/16];
+ __m512 vk[D/8];
for (int l1 = 0; l1 < k_step; l1 += 2) {
kh.load_2(l1, vk);
for (int m1 = 0; m1 < nq; ++m1) {
- mult_mask_kq_one(l1, m1, stride_q, stride_m, q, mask, qv);
+ mult_mask_kq_one(l1, m1, stride_q, stride_m, q, mask, qv, vk, fms);
}
}
}
template <typename KHelper, bool small = is_small_head, class = std::enable_if<!small>>
- inline void mult_mask_kq_l(int nq, const KHelper& kh, int stride_q, int stride_m,
- const float * q, const char * mask) {
+ static inline void mult_mask_kq_l(int nq, const KHelper& kh, int stride_q, int stride_m,
+ const float * q, const char * mask, FlashMS<q_step, k_step>& fms) {
+ __m512 vk[D/16];
for (int l1 = 0; l1 < k_step; ++l1) {
kh.load(l1, vk);
for (int m1 = 0; m1 < nq; ++m1) {
- mult_mask_kq_one(l1, m1, stride_q, stride_m, q, mask);
+ mult_mask_kq_one(l1, m1, stride_q, stride_m, q, mask, vk, fms);
}
}
}
template <typename KHelper>
- inline void multiply_mask_kq(const KHelper& kh, int stride_q, int stride_m, const float * q, const char * mask) {
+ static inline void multiply_mask_kq(const KHelper& kh, int stride_q, int stride_m, const float * q, const char * mask,
+ FlashMS<q_step, k_step>& fms) {
if constexpr (is_small_head) {
- mult_mask_kq(kh, stride_q, stride_m, q, mask);
+ mult_mask_kq(kh, stride_q, stride_m, q, mask, fms);
}
else {
- mult_mask_kq_l(kh, stride_q, stride_m, q, mask);
+ mult_mask_kq_l(kh, stride_q, stride_m, q, mask, fms);
}
+ __m512 vk[k_step/16];
for (int j = 0; j < q_step; ++j) {
- update_M_S(j);
+ fms.update_M_S(j, vk);
}
}
template <typename KHelper>
- inline void multiply_mask_kq(int nq, const KHelper& kh, int stride_q, int stride_m, const float * q, const char * mask) {
+ static inline void multiply_mask_kq(int nq, const KHelper& kh, int stride_q, int stride_m, const float * q, const char * mask,
+ FlashMS<q_step, k_step>& fms) {
if constexpr (is_small_head) {
- mult_mask_kq(nq, kh, stride_q, stride_m, q, mask);
+ mult_mask_kq(nq, kh, stride_q, stride_m, q, mask, fms);
}
else {
- mult_mask_kq_l(nq, kh, stride_q, stride_m, q, mask);
+ mult_mask_kq_l(nq, kh, stride_q, stride_m, q, mask, fms);
}
+ __m512 vk[k_step/16];
for (int j = 0; j < nq; ++j) {
- update_M_S(j);
+ fms.update_M_S(j, vk);
+ }
+ }
+};
+
+template <int D, int q_step, int k_step, typename KHelper, typename VHelper, typename KQHelper>
+void compute_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv,
+ FlashMS<q_step, k_step>& fms,
+ FlashQKV<D, q_step, k_step>& fqkv,
+ const float * q, const char * mask, float * qkv) {
+ for (int i1 = 0; i1 < nq1/q_step; ++i1) {
+ fms.init_qstep();
+ kh.reset_block();
+ vh.reset_block();
+ auto mr = mask;
+ for (int k1 = 0; k1 < nk1/k_step; ++k1) {
+ KQHelper::multiply_mask_kq(kh, stride_q, stride_m, q, mr, fms);
+ fqkv.accumulate_qkv(vh, fms);
+ kh.next_block();
+ vh.next_block();
+ mr += k_step*sizeof(ggml_half);
+ }
+ fqkv.normalize_and_store(fms, stride_qkv, qkv);
+
+ q += q_step*stride_q;
+ mask += q_step*stride_m;
+ qkv += q_step*stride_qkv;
+ }
+ int n_left = nq1 - q_step*(nq1/q_step);
+ if (n_left > 0) {
+ fms.init_qstep();
+ kh.reset_block();
+ vh.reset_block();
+ auto mr = mask;
+ for (int k1 = 0; k1 < nk1/k_step; ++k1) {
+ KQHelper::multiply_mask_kq(n_left, kh, stride_q, stride_m, q, mr, fms);
+ fqkv.accumulate_qkv(n_left, vh, fms);
+ kh.next_block();
+ vh.next_block();
+ mr += k_step*sizeof(ggml_half);
+ }
+ fqkv.normalize_and_store(fms, n_left, stride_qkv, qkv);
+ }
+}
+
+// Some of the methods in FlashAttn have two identical implementations that only differ by
+// one version using a loop over the template parameter q_step, while the other using a loop
+// over an input parameter nq (these are loops over the rows of q^T). I dislike this a lot,
+but performance drops significantly if I remove the version with fixed q_step iterations.
+// We only instantiate FlashAttn with q_step = 1 and q_step = 4 or 8 (depending on head size D),
+// so when we have to process Nq rows, we process q_step*(Nq/q_step) using fixed q_step loops,
+// and use the variable nq version (with lower performance) only for the remaining i1...q_step-1
+// rows (if Nq is not a multiple of q_step). One could have made the number of q^T rows to
+process a template parameter of such functions, but this would result in the compiler generating
+q_step-1 versions of these functions for us, which I thought was too much with q_step = 8.
+template <int D, int q_step, int k_step>
+struct FlashAttn {
+ static_assert(D%16 == 0 && D <= 256);
+ static_assert(k_step%16 == 0);
+ static_assert(q_step <= 4 || q_step%4 == 0);
+
+ FlashAttn(float scale, float softcap) : fms(scale, softcap) {}
+
+ template <typename KHelper, typename VHelper>
+ void compute(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv,
+ const float * q, const char * mask, float * qkv) {
+ compute_helper<D, q_step, k_step, KHelper, VHelper, FlashQKfp32<D, q_step, k_step>>(
+ kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv);
+ }
+
+ FlashMS<q_step, k_step> fms;
+ FlashQKV<D, q_step, k_step> fqkv;
+
+};
+
+#ifdef __AVX512BF16__
+
+template <int D, int step>
+struct HelperBF16 final : public BaseHelper<step> {
+ using Base = BaseHelper<step>;
+ HelperBF16(const char * data, int stride) : Base(data, stride) {}
+ inline void load(int l1, __m512bh * vk) const {
+ auto dr = Base::lblock(l1);
+ for (int i = 0; i < D/32; ++i) vk[i] = __m512bh(_mm512_loadu_si512((const __m512i*)dr + i));
+ }
+
+ inline void load(int l1, int i, __m512& v1, __m512& v2) const {
+ auto dr = Base::lblock(l1);
+ v1 = _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(_mm256_loadu_si256((const __m256i *)dr + i + 0)), 16));
+ v2 = _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(_mm256_loadu_si256((const __m256i *)dr + i + 1)), 16));
+ }
+
+ inline void load_2(int l1, __m512bh * vk) const {
+ load(l1+0, vk+0);
+ load(l1+1, vk+D/32);
+ }
+
+ inline void load_4(int l1, __m512bh * vk) const {
+ load(l1+0, vk+0);
+ load(l1+1, vk+1*D/32);
+ load(l1+2, vk+2*D/32);
+ load(l1+3, vk+3*D/32);
+ }
+};
+
+template <int D, int q_step, int k_step>
+struct FlashQKbf16 {
+ static_assert(D%32 == 0 && D <= 256);
+ static_assert(k_step%32 == 0);
+ static_assert(q_step <= 4 || q_step%4 == 0);
+
+ static inline void mult_mask_kq_one(int l1, int m1, int stride_q, int stride_m, const float * q, const char * mask,
+ __m512bh * qv, const __m512bh * vkh, FlashMS<q_step, k_step>& fms) {
+ // q index is q_step*i1 + m1
+ // k index is k_step*k1 + l1
+ const ggml_half * mp = (const ggml_half *)(mask + stride_m*m1);
+ fms.cache[k_step*m1 + l1 + 0] = fms.cache[k_step*m1 + l1 + 1] = -INFINITY;
+ if (mp[l1+0] == fms.h_inf && mp[l1+1] == fms.h_inf) {
+ return;
+ }
+ auto qr = q + m1*stride_q;
+ for (int i = 0; i < D/32; ++i) {
+ auto val1 = _mm512_loadu_ps(qr + 32*i);
+ auto val2 = _mm512_loadu_ps(qr + 32*i + 16);
+ qv[i] = _mm512_cvtne2ps_pbh(val2, val1);
+ }
+ if (mp[l1+0] != fms.h_inf) {
+ auto vsum = _mm512_setzero_ps();
+ for (int i = 0; i < D/32; ++i) vsum = _mm512_dpbf16_ps(vsum, vkh[i], qv[i]);
+ fms.cache[k_step*m1 + l1 + 0] = _mm512_reduce_add_ps(vsum);
+ }
+ if (mp[l1+1] != fms.h_inf) {
+ auto vsum = _mm512_setzero_ps();
+ for (int i = 0; i < D/32; ++i) vsum = _mm512_dpbf16_ps(vsum, vkh[i+D/32], qv[i]);
+ fms.cache[k_step*m1 + l1 + 1] = _mm512_reduce_add_ps(vsum);
}
}
- // This fails for head sizes of 80 and 112 as D/16 is odd, so we cannot do steps of 2
- // Hence, for now, we will not handle head sizes of 80 and 112
- template <typename VHelper>
- inline void accumulate_qkv(const VHelper& vh) {
- for (int i = 0; i < D/16; i += 2) {
- for (int j = 0; j < q_step; ++j) {
- if (need_scaling[j] == 2) {
- vk[2*j+0] = vk[2*j+1] = _mm512_setzero_ps();
- } else {
- auto R = qkv_cache + D*j;
- vk[2*j+0] = _mm512_loadu_ps(R + 16*i);
- vk[2*j+1] = _mm512_loadu_ps(R + 16*i + 16);
- if (need_scaling[j] == 1) {
- vk[2*j+0] = _mm512_mul_ps(vk[2*j+0], vms[j]);
- vk[2*j+1] = _mm512_mul_ps(vk[2*j+1], vms[j]);
+ static inline void mult_mask_kq_one(int l1, int m1, int stride_m, const ggml_bf16_t * q, const char * mask,
+ __m512bh * qv, const __m512bh * vkh, FlashMS<q_step, k_step>& fms) {
+ // q index is q_step*i1 + m1
+ // k index is k_step*k1 + l1
+ const ggml_half * mp = (const ggml_half *)(mask + stride_m*m1);
+ fms.cache[k_step*m1 + l1 + 0] = fms.cache[k_step*m1 + l1 + 1] = -INFINITY;
+ if (mp[l1+0] == fms.h_inf && mp[l1+1] == fms.h_inf) {
+ return;
+ }
+ auto qr = q + m1*D;
+ for (int i = 0; i < D/32; ++i) qv[i] = __m512bh(_mm512_loadu_si512((const __m512i*)qr + i));
+ if (mp[l1+0] != fms.h_inf) {
+ auto vsum = _mm512_setzero_ps();
+ for (int i = 0; i < D/32; ++i) vsum = _mm512_dpbf16_ps(vsum, vkh[i], qv[i]);
+ fms.cache[k_step*m1 + l1 + 0] = _mm512_reduce_add_ps(vsum);
+ }
+ if (mp[l1+1] != fms.h_inf) {
+ auto vsum = _mm512_setzero_ps();
+ for (int i = 0; i < D/32; ++i) vsum = _mm512_dpbf16_ps(vsum, vkh[i+D/32], qv[i]);
+ fms.cache[k_step*m1 + l1 + 1] = _mm512_reduce_add_ps(vsum);
+ }
+ }
+
+ static inline void mult_mask_kq_4(int l1, int m1, int stride_q, int stride_m, const float * q, const char * mask,
+ __m512bh * qv, const __m512bh * vkh, FlashMS<q_step, k_step>& fms) {
+ // q index is q_step*i1 + m1
+ // k index is k_step*k1 + l1
+ const ggml_half * mp = (const ggml_half *)(mask + stride_m*m1);
+ fms.cache[k_step*m1 + l1 + 0] = fms.cache[k_step*m1 + l1 + 1] =
+ fms.cache[k_step*m1 + l1 + 2] = fms.cache[k_step*m1 + l1 + 3] = -INFINITY;
+ if (mp[l1+0] == fms.h_inf && mp[l1+1] == fms.h_inf && mp[l1+2] == fms.h_inf && mp[l1+3] == fms.h_inf) {
+ return;
+ }
+ auto qr = q + m1*stride_q;
+ for (int i = 0; i < D/32; ++i) {
+ auto val1 = _mm512_loadu_ps(qr + 32*i);
+ auto val2 = _mm512_loadu_ps(qr + 32*i + 16);
+ qv[i] = _mm512_cvtne2ps_pbh(val2, val1);
+ }
+ for (int k = 0; k < 4; ++k) {
+ if (mp[l1+k] == fms.h_inf) continue;
+ auto vsum = _mm512_setzero_ps();
+ for (int i = 0; i < D/32; ++i) vsum = _mm512_dpbf16_ps(vsum, vkh[i+k*(D/32)], qv[i]);
+ fms.cache[k_step*m1 + l1 + k] = _mm512_reduce_add_ps(vsum);
+ }
+ }
+
+ static inline void mult_mask_kq_4(int l1, int m1, int stride_m, const ggml_bf16_t * q, const char * mask,
+ __m512bh * qv, const __m512bh * vkh, FlashMS<q_step, k_step>& fms) {
+ // q index is q_step*i1 + m1
+ // k index is k_step*k1 + l1
+ const ggml_half * mp = (const ggml_half *)(mask + stride_m*m1);
+ fms.cache[k_step*m1 + l1 + 0] = fms.cache[k_step*m1 + l1 + 1] =
+ fms.cache[k_step*m1 + l1 + 2] = fms.cache[k_step*m1 + l1 + 3] = -INFINITY;
+ if (mp[l1+0] == fms.h_inf && mp[l1+1] == fms.h_inf && mp[l1+2] == fms.h_inf && mp[l1+3] == fms.h_inf) {
+ return;
+ }
+ auto qr = q + m1*D;
+ for (int i = 0; i < D/32; ++i) qv[i] = __m512bh(_mm512_loadu_si512((const __m512i *)qr + i));
+ for (int k = 0; k < 4; ++k) {
+ if (mp[l1+k] == fms.h_inf) continue;
+ auto vsum = _mm512_setzero_ps();
+ for (int i = 0; i < D/32; ++i) vsum = _mm512_dpbf16_ps(vsum, vkh[i+k*(D/32)], qv[i]);
+ fms.cache[k_step*m1 + l1 + k] = _mm512_reduce_add_ps(vsum);
+ }
+ }
+
+ template <typename KHelper>
+ static inline void multiply_mask_kq(const KHelper& kh, int stride_q, int stride_m, const float * q,
+ const char * mask, FlashMS<q_step, k_step>& fms) {
+ {
+ __m512bh qv[D/32];
+ if constexpr (D <= 128) {
+ __m512bh vkh[D/8];
+ for (int l1 = 0; l1 < k_step; l1 += 4) {
+ kh.load_4(l1, vkh);
+ for (int j = 0; j < q_step; ++j) {
+ mult_mask_kq_4(l1, j, stride_q, stride_m, q, mask, qv, vkh, fms);
}
}
- }
- __m512 v1, v2;
- for (int l1 = 0; l1 < k_step; ++l1) {
- vh.load(l1, i, v1, v2);
- for (int j = 0; j < q_step; ++j) {
- auto vs = _mm512_set1_ps(cache[k_step*j + l1]);
- vk[2*j+0] = _mm512_fmadd_ps(v1, vs, vk[2*j+0]);
- vk[2*j+1] = _mm512_fmadd_ps(v2, vs, vk[2*j+1]);
+ } else {
+ __m512bh vkh[D/16];
+ for (int l1 = 0; l1 < k_step; l1 += 2) {
+ kh.load_2(l1, vkh);
+ for (int j = 0; j < q_step; ++j) {
+ mult_mask_kq_one(l1, j, stride_q, stride_m, q, mask, qv, vkh, fms);
+ }
}
}
- for (int j = 0; j < q_step; ++j) {
- auto R = qkv_cache + D*j;
- _mm512_storeu_ps(R + 16*i, vk[2*j+0]);
- _mm512_storeu_ps(R + 16*i + 16, vk[2*j+1]);
- }
+ }
+ __m512 vk[k_step/16];
+ for (int j = 0; j < q_step; ++j) {
+ fms.update_M_S(j, vk);
}
}
- template <typename VHelper, int Nq = q_step, class = std::enable_if<Nq >= 2>>
- inline void accumulate_qkv(int nq1, const VHelper& vh) {
- for (int i = 0; i < D/16; i += 2) {
- for (int j = 0; j < nq1; ++j) {
- if (need_scaling[j] == 2) {
- vk[2*j+0] = vk[2*j+1] = _mm512_setzero_ps();
- } else {
- auto R = qkv_cache + D*j;
- vk[2*j+0] = _mm512_loadu_ps(R + 16*i);
- vk[2*j+1] = _mm512_loadu_ps(R + 16*i + 16);
- if (need_scaling[j] == 1) {
- vk[2*j+0] = _mm512_mul_ps(vk[2*j+0], vms[j]);
- vk[2*j+1] = _mm512_mul_ps(vk[2*j+1], vms[j]);
+ template <typename KHelper>
+ static inline void multiply_mask_kq(const KHelper& kh, int stride_m, const ggml_bf16_t * q,
+ const char * mask, FlashMS<q_step, k_step>& fms) {
+ {
+ __m512bh qv[D/32];
+ if constexpr (D <= 128) {
+ __m512bh vkh[D/8];
+ for (int l1 = 0; l1 < k_step; l1 += 4) {
+ kh.load_4(l1, vkh);
+ for (int j = 0; j < q_step; ++j) {
+ mult_mask_kq_4(l1, j, stride_m, q, mask, qv, vkh, fms);
+ }
+ }
+ } else {
+ __m512bh vkh[D/16];
+ for (int l1 = 0; l1 < k_step; l1 += 2) {
+ kh.load_2(l1, vkh);
+ for (int j = 0; j < q_step; ++j) {
+ mult_mask_kq_one(l1, j, stride_m, q, mask, qv, vkh, fms);
}
}
}
- __m512 v1, v2;
- for (int l1 = 0; l1 < k_step; ++l1) {
- vh.load(l1, i, v1, v2);
- for (int j = 0; j < nq1; ++j) {
- auto vs = _mm512_set1_ps(cache[k_step*j + l1]);
- vk[2*j+0] = _mm512_fmadd_ps(v1, vs, vk[2*j+0]);
- vk[2*j+1] = _mm512_fmadd_ps(v2, vs, vk[2*j+1]);
+ }
+ __m512 vk[k_step/16];
+ for (int j = 0; j < q_step; ++j) {
+ fms.update_M_S(j, vk);
+ }
+ }
+
+ template <typename KHelper>
+ static inline void multiply_mask_kq(int nq, const KHelper& kh, int stride_q, int stride_m, const float * q,
+ const char * mask, FlashMS<q_step, k_step>& fms) {
+ {
+ __m512bh qv[D/32];
+ __m512bh vkh[D/16];
+ for (int l1 = 0; l1 < k_step; l1 += 2) {
+ kh.load_2(l1, vkh);
+ for (int m1 = 0; m1 < nq; ++m1) {
+ mult_mask_kq_one(l1, m1, stride_q, stride_m, q, mask, qv, vkh, fms);
}
}
- for (int j = 0; j < nq1; ++j) {
- auto R = qkv_cache + D*j;
- _mm512_storeu_ps(R + 16*i, vk[2*j+0]);
- _mm512_storeu_ps(R + 16*i + 16, vk[2*j+1]);
+ }
+ __m512 vk[k_step/16];
+ for (int j = 0; j < nq; ++j) {
+ fms.update_M_S(j, vk);
+ }
+ }
+
+ static inline void convert(int stride_q, const float * q, ggml_bf16_t * bf16) {
+ auto qr = q;
+ for (int j = 0; j < q_step; ++j) {
+ for (int i = 0; i < D/32; ++i) {
+ auto val1 = _mm512_loadu_ps(qr + 32*i);
+ auto val2 = _mm512_loadu_ps(qr + 32*i + 16);
+ _mm512_storeu_si512((__m512i *)bf16 + i, (__m512i)_mm512_cvtne2ps_pbh(val2, val1));
}
+ qr += stride_q;
+ bf16 += D;
}
}
+};
+
+template <int D, int q_step, int k_step>
+struct FlashAttnBF16 {
+ static_assert(D%32 == 0 && D <= 256);
+ static_assert(k_step%32 == 0);
+ static_assert(q_step <= 4 || q_step%4 == 0);
+
+ FlashAttnBF16(float scale, float softcap) : fms(scale, softcap) {}
template <typename KHelper, typename VHelper>
void compute(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv,
const float * q, const char * mask, float * qkv) {
+ ggml_bf16_t q_bf16[q_step*D];
for (int i1 = 0; i1 < nq1/q_step; ++i1) {
- init_qstep();
+ fms.init_qstep();
kh.reset_block();
vh.reset_block();
+ FlashQKbf16<D, q_step, k_step>::convert(stride_q, q, q_bf16);
auto mr = mask;
for (int k1 = 0; k1 < nk1/k_step; ++k1) {
- multiply_mask_kq(kh, stride_q, stride_m, q, mr);
- accumulate_qkv(vh);
+ FlashQKbf16<D, q_step, k_step>::multiply_mask_kq(kh, stride_m, q_bf16, mr, fms);
+ fqkv.accumulate_qkv(vh, fms);
kh.next_block();
vh.next_block();
mr += k_step*sizeof(ggml_half);
}
- normalize_and_store(stride_qkv, qkv);
+ fqkv.normalize_and_store(fms, stride_qkv, qkv);
q += q_step*stride_q;
mask += q_step*stride_m;
@@ -6652,50 +6989,25 @@ struct FlashAttn {
}
int n_left = nq1 - q_step*(nq1/q_step);
if (n_left > 0) {
- init_qstep();
+ fms.init_qstep();
kh.reset_block();
vh.reset_block();
auto mr = mask;
for (int k1 = 0; k1 < nk1/k_step; ++k1) {
- multiply_mask_kq(n_left, kh, stride_q, stride_m, q, mr);
- accumulate_qkv(n_left, vh);
+ FlashQKbf16<D, q_step, k_step>::multiply_mask_kq(n_left, kh, stride_q, stride_m, q, mr, fms);
+ fqkv.accumulate_qkv(n_left, vh, fms);
kh.next_block();
vh.next_block();
mr += k_step*sizeof(ggml_half);
}
- normalize_and_store(n_left, stride_qkv, qkv);
+ fqkv.normalize_and_store(fms, n_left, stride_qkv, qkv);
}
}
- float cache[q_step*k_step];
- float qkv_cache[D*q_step];
- float S[q_step], M[q_step];
- int need_scaling[q_step];
- __m512 vms[q_step];
- __m512 vk[vk_size];
- const __m512 vscale;
- const float softcap;
- const ggml_half h_inf;
-
- typedef __m512 (*combine_t)(__m512, __m512);
- typedef float (*reduce_t)(__m512);
- template <reduce_t Op, combine_t Op_combine>
- static inline float reduce_T(const __m512 * vals) {
- float result;
- if constexpr (k_step/16 == 1) {
- result = Op(vals[0]);
- }
- else if constexpr (k_step/16 == 2) {
- result = Op(Op_combine(vals[0], vals[1]));
- }
- else {
- auto vmax = Op_combine(vals[0], vals[1]);
- for (int l = 2; l < k_step/16; ++l) vmax = Op_combine(vmax, vals[l]);
- result = Op(vmax);
- }
- return result;
- }
+ FlashMS<q_step, k_step> fms;
+ FlashQKV<D, q_step, k_step> fqkv;
};
+#endif
template <int D, int q_step, int k_step, typename KHelper, typename VHelper>
inline void iqk_flash_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv,
@@ -6710,6 +7022,23 @@ inline void iqk_flash_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int str
}
}
+#ifdef __AVX512BF16__
+template <int D, int q_step, int k_step>
+inline void iqk_flash_helper_T(int nq1, int nk1, int stride_q, int stride_k, int stride_v, int stride_m, int stride_qkv,
+ const float * q, const char * k, const char * v, const char * mask,
+ float scale, float softcap, float * qkv) {
+ HelperBF16<D, k_step> kh(k, stride_k);
+ HelperBF16<D, k_step> vh(v, stride_v);
+ if (nq1 >= q_step) {
+ FlashAttnBF16<D, q_step, k_step> fa(scale, softcap);
+ fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
+ } else {
+ FlashAttnBF16<D, 1, k_step> fa(scale, softcap);
+ fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
+ }
+}
+#endif
+
template <int D, int q_step, int k_step, typename KHelper>
inline void iqk_flash_helper_T(KHelper& kh, ggml_type type_v,
int nq1, int nk1, int stride_q, int stride_v, int stride_m, int stride_qkv,
@@ -6766,7 +7095,11 @@ inline void iqk_flash_helper_T(ggml_type type_k, ggml_type type_v,
}
inline bool flash_attn_is_supported(ggml_type type) {
+#ifdef __AVX512BF16__
+ return type == GGML_TYPE_F16 || type == GGML_TYPE_BF16 || type == GGML_TYPE_Q8_0 || type == GGML_TYPE_Q4_0 || type == GGML_TYPE_Q4_1;
+#else
return type == GGML_TYPE_F16 || type == GGML_TYPE_Q8_0 || type == GGML_TYPE_Q4_0 || type == GGML_TYPE_Q4_1;
+#endif
}
}
@@ -6799,14 +7132,33 @@ bool iqk_flash_attn_noalibi(int int_type_k, // type of k
stride_q /= sizeof(float); // q stride as float
+#ifdef __AVX512BF16__
+ if (type_k == GGML_TYPE_BF16 && type_v == GGML_TYPE_BF16) {
+ switch (D) {
+ case 64:
+ iqk_flash_helper_T< 64, 8, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
+ case 96:
+ iqk_flash_helper_T< 96, 8, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
+ case 128:
+ iqk_flash_helper_T<128, 8, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
+ case 256:
+ iqk_flash_helper_T<256, 8, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
+ default:
+ return false;
+ }
+
+ return true;
+ }
+#endif
+
switch (D) {
case 64:
- iqk_flash_helper_T< 64, 4, 32>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
+ iqk_flash_helper_T< 64, 8, 32>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
// Disable until we fix accumulate_qkv for odd D/16
//case 80:
// iqk_flash_helper_T< 80, 4, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
case 96:
- iqk_flash_helper_T< 96, 4, 32>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
+ iqk_flash_helper_T< 96, 8, 32>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
// Disable until we fix accumulate_qkv for odd D/16
//case 112:
// iqk_flash_helper_T<112, 4, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
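For readers skimming the refactor: the fixed-q_step vs. variable-nq split that the comment above FlashAttn (and the new compute_helper) describes boils down to the following pattern. This is a simplified, self-contained sketch with illustrative helper names (one float per "row" for brevity), not code from the commit:

#include <cstdio>

// Fixed-count version: the row count is a compile-time template parameter, so the
// compiler can fully unroll the loop (the fast path the comment describes).
template <int q_step>
static void process_rows_fixed(const float *q) {
    for (int j = 0; j < q_step; ++j) std::printf("fixed row %d -> %g\n", j, q[j]);
}

// Variable-count version: the row count is a runtime argument (the slower path,
// used only for the Nq % q_step leftover rows).
static void process_rows_var(int nq, const float *q) {
    for (int j = 0; j < nq; ++j) std::printf("var   row %d -> %g\n", j, q[j]);
}

template <int q_step>
static void process_all(int Nq, const float *q) {
    int i = 0;
    for (; i + q_step <= Nq; i += q_step) process_rows_fixed<q_step>(q + i);
    if (Nq - i > 0) process_rows_var(Nq - i, q + i);
}

int main() {
    float q[11] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
    process_all<4>(11, q);  // 8 rows through the fixed path, 3 through the variable path
    return 0;
}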