author    | Kawrakow <48489457+ikawrakow@users.noreply.github.com> | 2024-09-05 07:46:47 +0300
committer | GitHub <noreply@github.com>                            | 2024-09-05 07:46:47 +0300
commit    | 7b1b2b2c06c1729139135c9e47611af7161de6f7 (patch)
tree      | ab79924dbb9f2ff780dd669fa65f826aae74d0b7
parent    | f17d0d72f565bf24d6eb8aa67d6618cdc143961d (diff)
Zen4 Flash Attention - bf16 support (#38)
* Zen4 Flash Attention: WIP bf16
* Zen4 Flash Attention: bf16 seems to be working
* Zen4 Flash Attention: improving bf16
* Zen4 Flash Attention: improving bf16
It is better (slightly faster) to convert Q to bf16
once before processing each block of q_step rows.
The conversion buffer needs D*q_step*sizeof(bf16)
bytes, which is at most 4 kB for the head sizes we
support, so we can simply allocate it on the stack
instead of reserving and passing a work buffer in ggml.
---------
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
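
For readers skimming the patch, the conversion described in the commit message is just a per-block repack of Q into a small bf16 scratch tile that lives on the stack. The standalone sketch below is illustrative only: the fixed D = 128, q_step = 8 and the plain `uint16_t` element type are assumptions for the example, whereas in the patch these are template parameters and `ggml_bf16_t`. It uses the same `_mm512_cvtne2ps_pbh` packing that `FlashQKbf16::convert` performs in the diff further down.

```cpp
// Minimal sketch: convert q_step rows of fp32 Q into a bf16 tile on the stack.
// Assumptions for illustration: D = 128, q_step = 8, bf16 stored as raw uint16_t bits.
// Build with AVX512BF16 support (e.g. -mavx512f -mavx512bf16).
#include <immintrin.h>
#include <cstdint>

constexpr int D      = 128;  // head size (a multiple of 32 here)
constexpr int q_step = 8;    // rows of q^T handled per block

static void convert_q_to_bf16(const float * q, int stride_q, uint16_t * bf16) {
    for (int j = 0; j < q_step; ++j) {
        const float * qr = q + j*stride_q;
        for (int i = 0; i < D/32; ++i) {
            __m512 lo = _mm512_loadu_ps(qr + 32*i);       // first 16 fp32 values
            __m512 hi = _mm512_loadu_ps(qr + 32*i + 16);  // next 16 fp32 values
            // Pack 32 fp32 into 32 bf16 (round-to-nearest-even).
            __m512bh packed = _mm512_cvtne2ps_pbh(hi, lo);
            _mm512_storeu_si512((__m512i *)(bf16 + j*D) + i, (__m512i)packed);
        }
    }
}

// Caller side: the whole tile is D*q_step*2 bytes (2 KiB here, at most 4 KiB for
// D = 256), so a plain stack array suffices instead of a ggml work buffer:
//   uint16_t q_bf16[q_step*D];
//   convert_q_to_bf16(q, stride_q, q_bf16);
```

Doing this once per q_step block lets every K block in the inner loop consume the already-converted Q, which is where the small speedup over converting on the fly comes from.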
-rw-r--r-- | common/common.cpp                    |   3
-rw-r--r-- | examples/llama-bench/llama-bench.cpp |   3
-rw-r--r-- | ggml/src/iqk/iqk_mul_mat.cpp         | 712
3 files changed, 538 insertions, 180 deletions
diff --git a/common/common.cpp b/common/common.cpp index c86d364f..6c298d2d 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -2221,6 +2221,9 @@ static ggml_type kv_cache_type_from_str(const std::string & s) { if (s == "f16") { return GGML_TYPE_F16; } + if (s == "bf16") { + return GGML_TYPE_BF16; + } if (s == "q8_0") { return GGML_TYPE_Q8_0; } diff --git a/examples/llama-bench/llama-bench.cpp b/examples/llama-bench/llama-bench.cpp index 813d7bae..fc77be50 100644 --- a/examples/llama-bench/llama-bench.cpp +++ b/examples/llama-bench/llama-bench.cpp @@ -306,6 +306,9 @@ static ggml_type ggml_type_from_name(const std::string & s) { if (s == "f16") { return GGML_TYPE_F16; } + if (s == "bf16") { + return GGML_TYPE_BF16; + } if (s == "q8_0") { return GGML_TYPE_Q8_0; } diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index 511eea01..84514ddc 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -6192,7 +6192,6 @@ struct HelperF16 final : public BaseHelper<step> { load(l1+0, vk+0); load(l1+1, vk+D/16); } - }; template <int D, int step> @@ -6356,29 +6355,9 @@ struct HelperQ41 final : public BaseHelper<step> { const __m128i mask = _mm_set1_epi8(0xf); }; - -// Some of the methods in FlashAttn have two identical implementations that only differ by -// one version using a loop over the template parameter q_step, while the other using a loop -// over an input parameter nq (these are loops over the rows of q^T). I dislike this a lot, -// but performance drops signficantly if I remove the version with fixed q_step iterations. -// We only instantiate FlashAttn with q_step = 1 and q_step = 4 or 8 (depending on head size D), -// so when we have to process Nq rows, we process q_step*(Nq/q_step) using fixed q_step loops, -// and use the variable nq version (with lower performance) only for the remaining i1...q_step-1 -// rows (if Nq is not a multiple of q_step). One could have made the number of q^T rows to -// process template parameter of such functions, but this would result in the compiler generating -// q_step-1 versions of these functions for us, which I though was too much with q_step = 8. -template <int D, int q_step, int k_step> -struct FlashAttn { - static_assert(D%16 == 0 && D <= 256); - static_assert(k_step%16 == 0); - static_assert(q_step <= 4 || q_step%4 == 0); - - constexpr static bool is_small_head = D <= 128; - - constexpr static int vk_size = is_small_head ? 
D/8 : D/16; - static_assert(2*q_step <= vk_size); - - FlashAttn(float scale, float softcap) : vscale(_mm512_set1_ps(scale)), softcap(softcap), h_inf(GGML_FP32_TO_FP16(-INFINITY)) {} +template <int q_step, int k_step> +struct FlashMS { + FlashMS(float scale, float softcap) : vscale(_mm512_set1_ps(scale)), softcap(softcap), h_inf(GGML_FP32_TO_FP16(-INFINITY)) {} inline void init_qstep() { for (int j = 0; j < q_step; ++j) { @@ -6386,47 +6365,7 @@ struct FlashAttn { } } - template <bool small = is_small_head, class = std::enable_if<small>> - inline void mult_mask_kq_one(int l1, int m1, int stride_q, int stride_m, const float * q, const char * mask, __m512 * qv) { - // q index is q_step*i1 + m1 - // k index is k_step*k1 + l1 - const ggml_half * mp = (const ggml_half *)(mask + stride_m*m1); - cache[k_step*m1 + l1 + 0] = cache[k_step*m1 + l1 + 1] = -INFINITY; - if (mp[l1+0] == h_inf && mp[l1+1] == h_inf) { - return; - } - auto qr = q + m1*stride_q; - for (int i = 0; i < D/16; ++i) qv[i] = _mm512_loadu_ps(qr + 16*i); - if (mp[l1+0] != h_inf) { - auto vsum = _mm512_setzero_ps(); - for (int i = 0; i < D/16; ++i) vsum = _mm512_fmadd_ps(vk[i], qv[i], vsum); - cache[k_step*m1 + l1 + 0] = _mm512_reduce_add_ps(vsum); - } - if (mp[l1+1] != h_inf) { - auto vsum = _mm512_setzero_ps(); - for (int i = 0; i < D/16; ++i) vsum = _mm512_fmadd_ps(vk[i+D/16], qv[i], vsum); - cache[k_step*m1 + l1 + 1] = _mm512_reduce_add_ps(vsum); - } - } - - template <bool small = is_small_head, class = std::enable_if<!small>> - inline void mult_mask_kq_one(int l1, int m1, int stride_q, int stride_m, const float * q, const char * mask) { - // q index is q_step*i1 + m1 - // k index is k_step*k1 + l1 - const ggml_half * mp = (const ggml_half *)(mask + stride_m*m1); - if (mp[l1] == h_inf) { - cache[k_step*m1 + l1] = -INFINITY; - return; - } - auto qr = q + m1*stride_q; - auto vsum = _mm512_setzero_ps(); - for (int i = 0; i < D/16; ++i) { - vsum = _mm512_fmadd_ps(vk[i], _mm512_loadu_ps(qr + 16*i), vsum); - } - cache[k_step*m1 + l1] = _mm512_reduce_add_ps(vsum); - } - - inline void update_M_S(int j) { + inline void update_M_S(int j, __m512 * vk) { if (softcap <= 0.0f) { for (int l = 0; l < k_step/16; ++l) vk[l] = _mm512_mul_ps(vscale, _mm512_loadu_ps(cache + k_step*j + 16*l)); } else { @@ -6464,187 +6403,585 @@ struct FlashAttn { S[j] += reduce_T<_mm512_reduce_add_ps, _mm512_add_ps>(vk); } - inline void normalize_and_store(int j, const float * R, float * qkv) const { - GGML_ASSERT(S[j] > 0); - auto norm = _mm512_set1_ps(1/S[j]); + float cache[q_step*k_step]; + float S[q_step], M[q_step]; + int need_scaling[q_step]; + __m512 vms[q_step]; + const __m512 vscale; + const float softcap; + const ggml_half h_inf; + + typedef __m512 (*combine_t)(__m512, __m512); + typedef float (*reduce_t)(__m512); + template <reduce_t Op, combine_t Op_combine> + static inline float reduce_T(const __m512 * vals) { + float result; + if constexpr (k_step/16 == 1) { + result = Op(vals[0]); + } + else if constexpr (k_step/16 == 2) { + result = Op(Op_combine(vals[0], vals[1])); + } + else { + auto vmax = Op_combine(vals[0], vals[1]); + for (int l = 2; l < k_step/16; ++l) vmax = Op_combine(vmax, vals[l]); + result = Op(vmax); + } + return result; + } +}; + +template <int D, int q_step, int k_step> +struct FlashQKV { + + // This fails for head sizes of 80 and 112 as D/16 is odd, so we cannot do steps of 2 + // Hence, for now, we will not handle head sizes of 80 and 112 + template <typename VHelper> + inline void accumulate_qkv(const VHelper& vh, const FlashMS<q_step, 
k_step>& fms) { + __m512 vk[2*q_step]; + for (int i = 0; i < D/16; i += 2) { + for (int j = 0; j < q_step; ++j) { + if (fms.need_scaling[j] == 2) { + vk[2*j+0] = vk[2*j+1] = _mm512_setzero_ps(); + } else { + auto R = qkv_cache + D*j; + vk[2*j+0] = _mm512_loadu_ps(R + 16*i); + vk[2*j+1] = _mm512_loadu_ps(R + 16*i + 16); + if (fms.need_scaling[j] == 1) { + vk[2*j+0] = _mm512_mul_ps(vk[2*j+0], fms.vms[j]); + vk[2*j+1] = _mm512_mul_ps(vk[2*j+1], fms.vms[j]); + } + } + } + __m512 v1, v2; + for (int l1 = 0; l1 < k_step; ++l1) { + vh.load(l1, i, v1, v2); + for (int j = 0; j < q_step; ++j) { + auto vs = _mm512_set1_ps(fms.cache[k_step*j + l1]); + vk[2*j+0] = _mm512_fmadd_ps(v1, vs, vk[2*j+0]); + vk[2*j+1] = _mm512_fmadd_ps(v2, vs, vk[2*j+1]); + } + } + for (int j = 0; j < q_step; ++j) { + auto R = qkv_cache + D*j; + _mm512_storeu_ps(R + 16*i, vk[2*j+0]); + _mm512_storeu_ps(R + 16*i + 16, vk[2*j+1]); + } + } + } + + template <typename VHelper, int Nq = q_step, class = std::enable_if<Nq >= 2>> + inline void accumulate_qkv(int nq1, const VHelper& vh, const FlashMS<q_step, k_step>& fms) { + __m512 vk[2*q_step]; + for (int i = 0; i < D/16; i += 2) { + for (int j = 0; j < nq1; ++j) { + if (fms.need_scaling[j] == 2) { + vk[2*j+0] = vk[2*j+1] = _mm512_setzero_ps(); + } else { + auto R = qkv_cache + D*j; + vk[2*j+0] = _mm512_loadu_ps(R + 16*i); + vk[2*j+1] = _mm512_loadu_ps(R + 16*i + 16); + if (fms.need_scaling[j] == 1) { + vk[2*j+0] = _mm512_mul_ps(vk[2*j+0], fms.vms[j]); + vk[2*j+1] = _mm512_mul_ps(vk[2*j+1], fms.vms[j]); + } + } + } + __m512 v1, v2; + for (int l1 = 0; l1 < k_step; ++l1) { + vh.load(l1, i, v1, v2); + for (int j = 0; j < nq1; ++j) { + auto vs = _mm512_set1_ps(fms.cache[k_step*j + l1]); + vk[2*j+0] = _mm512_fmadd_ps(v1, vs, vk[2*j+0]); + vk[2*j+1] = _mm512_fmadd_ps(v2, vs, vk[2*j+1]); + } + } + for (int j = 0; j < nq1; ++j) { + auto R = qkv_cache + D*j; + _mm512_storeu_ps(R + 16*i, vk[2*j+0]); + _mm512_storeu_ps(R + 16*i + 16, vk[2*j+1]); + } + } + } + + inline void normalize_and_store(const FlashMS<q_step, k_step>& fms, int j, const float * R, float * qkv) const { + GGML_ASSERT(fms.S[j] > 0); + auto norm = _mm512_set1_ps(1/fms.S[j]); for (int i = 0; i < D/16; ++i) { auto r = _mm512_loadu_ps(R + 16*i); _mm512_storeu_ps(qkv + 16*i, _mm512_mul_ps(norm, r)); } } - inline void normalize_and_store(int nq1, int stride_qkv, float * qkv) const { + inline void normalize_and_store(const FlashMS<q_step, k_step>& fms, int nq1, int stride_qkv, float * qkv) const { auto R = qkv_cache; for (int j = 0; j < nq1; ++j) { - normalize_and_store(j, R, qkv); + normalize_and_store(fms, j, R, qkv); qkv += stride_qkv; R += D; } } - inline void normalize_and_store(int stride_qkv, float * qkv) const { + inline void normalize_and_store(const FlashMS<q_step, k_step>& fms, int stride_qkv, float * qkv) const { auto R = qkv_cache; for (int j = 0; j < q_step; ++j) { - normalize_and_store(j, R, qkv); + normalize_and_store(fms, j, R, qkv); qkv += stride_qkv; R += D; } } + float qkv_cache[D*q_step]; +}; + +template <int D, int q_step, int k_step> +struct FlashQKfp32 { + static_assert(D%16 == 0 && D <= 256); + static_assert(k_step%16 == 0); + static_assert(q_step <= 4 || q_step%4 == 0); + + constexpr static bool is_small_head = D <= 128; + + template <bool small = is_small_head, class = std::enable_if<small>> + static inline void mult_mask_kq_one(int l1, int m1, int stride_q, int stride_m, const float * q, const char * mask, + __m512 * qv, __m512 * vk, FlashMS<q_step, k_step>& fms) { + // q index is q_step*i1 + m1 + // k 
index is k_step*k1 + l1 + const ggml_half * mp = (const ggml_half *)(mask + stride_m*m1); + fms.cache[k_step*m1 + l1 + 0] = fms.cache[k_step*m1 + l1 + 1] = -INFINITY; + if (mp[l1+0] == fms.h_inf && mp[l1+1] == fms.h_inf) { + return; + } + auto qr = q + m1*stride_q; + for (int i = 0; i < D/16; ++i) qv[i] = _mm512_loadu_ps(qr + 16*i); + if (mp[l1+0] != fms.h_inf) { + auto vsum = _mm512_setzero_ps(); + for (int i = 0; i < D/16; ++i) vsum = _mm512_fmadd_ps(vk[i], qv[i], vsum); + fms.cache[k_step*m1 + l1 + 0] = _mm512_reduce_add_ps(vsum); + } + if (mp[l1+1] != fms.h_inf) { + auto vsum = _mm512_setzero_ps(); + for (int i = 0; i < D/16; ++i) vsum = _mm512_fmadd_ps(vk[i+D/16], qv[i], vsum); + fms.cache[k_step*m1 + l1 + 1] = _mm512_reduce_add_ps(vsum); + } + } + + template <bool small = is_small_head, class = std::enable_if<!small>> + static inline void mult_mask_kq_one(int l1, int m1, int stride_q, int stride_m, const float * q, const char * mask, + __m512 * vk, FlashMS<q_step, k_step>& fms) { + // q index is q_step*i1 + m1 + // k index is k_step*k1 + l1 + const ggml_half * mp = (const ggml_half *)(mask + stride_m*m1); + if (mp[l1] == fms.h_inf) { + fms.cache[k_step*m1 + l1] = -INFINITY; + return; + } + auto qr = q + m1*stride_q; + auto vsum = _mm512_setzero_ps(); + for (int i = 0; i < D/16; ++i) { + vsum = _mm512_fmadd_ps(vk[i], _mm512_loadu_ps(qr + 16*i), vsum); + } + fms.cache[k_step*m1 + l1] = _mm512_reduce_add_ps(vsum); + } + template <typename KHelper, bool small = is_small_head, class = std::enable_if<small>> - inline void mult_mask_kq(const KHelper& kh, int stride_q, int stride_m, const float * q, const char * mask) { + static inline void mult_mask_kq(const KHelper& kh, int stride_q, int stride_m, const float * q, const char * mask, + FlashMS<q_step, k_step>& fms) { __m512 qv[D/16]; + __m512 vk[D/8]; for (int l1 = 0; l1 < k_step; l1 += 2) { kh.load_2(l1, vk); for (int m1 = 0; m1 < q_step; ++m1) { - mult_mask_kq_one(l1, m1, stride_q, stride_m, q, mask, qv); + mult_mask_kq_one(l1, m1, stride_q, stride_m, q, mask, qv, vk, fms); } } } template <typename KHelper, bool small = is_small_head, class = std::enable_if<!small>> - inline void mult_mask_kq_l(const KHelper& kh, int stride_q, int stride_m, - const float * q, const char * mask) { + static inline void mult_mask_kq_l(const KHelper& kh, int stride_q, int stride_m, + const float * q, const char * mask, FlashMS<q_step, k_step>& fms) { + __m512 vk[D/16]; for (int l1 = 0; l1 < k_step; ++l1) { kh.load(l1, vk); for (int m1 = 0; m1 < q_step; ++m1) { - mult_mask_kq_one(l1, m1, stride_q, stride_m, q, mask); + mult_mask_kq_one(l1, m1, stride_q, stride_m, q, mask, vk, fms); } } } template <typename KHelper, bool small = is_small_head, class = std::enable_if<small>> - inline void mult_mask_kq(int nq, const KHelper& kh, int stride_q, int stride_m, const float * q, const char * mask) { + static inline void mult_mask_kq(int nq, const KHelper& kh, int stride_q, int stride_m, const float * q, const char * mask, + FlashMS<q_step, k_step>& fms) { __m512 qv[D/16]; + __m512 vk[D/8]; for (int l1 = 0; l1 < k_step; l1 += 2) { kh.load_2(l1, vk); for (int m1 = 0; m1 < nq; ++m1) { - mult_mask_kq_one(l1, m1, stride_q, stride_m, q, mask, qv); + mult_mask_kq_one(l1, m1, stride_q, stride_m, q, mask, qv, vk, fms); } } } template <typename KHelper, bool small = is_small_head, class = std::enable_if<!small>> - inline void mult_mask_kq_l(int nq, const KHelper& kh, int stride_q, int stride_m, - const float * q, const char * mask) { + static inline void mult_mask_kq_l(int nq, 
const KHelper& kh, int stride_q, int stride_m, + const float * q, const char * mask, FlashMS<q_step, k_step>& fms) { + __m512 vk[D/16]; for (int l1 = 0; l1 < k_step; ++l1) { kh.load(l1, vk); for (int m1 = 0; m1 < nq; ++m1) { - mult_mask_kq_one(l1, m1, stride_q, stride_m, q, mask); + mult_mask_kq_one(l1, m1, stride_q, stride_m, q, mask, vk, fms); } } } template <typename KHelper> - inline void multiply_mask_kq(const KHelper& kh, int stride_q, int stride_m, const float * q, const char * mask) { + static inline void multiply_mask_kq(const KHelper& kh, int stride_q, int stride_m, const float * q, const char * mask, + FlashMS<q_step, k_step>& fms) { if constexpr (is_small_head) { - mult_mask_kq(kh, stride_q, stride_m, q, mask); + mult_mask_kq(kh, stride_q, stride_m, q, mask, fms); } else { - mult_mask_kq_l(kh, stride_q, stride_m, q, mask); + mult_mask_kq_l(kh, stride_q, stride_m, q, mask, fms); } + __m512 vk[k_step/16]; for (int j = 0; j < q_step; ++j) { - update_M_S(j); + fms.update_M_S(j, vk); } } template <typename KHelper> - inline void multiply_mask_kq(int nq, const KHelper& kh, int stride_q, int stride_m, const float * q, const char * mask) { + static inline void multiply_mask_kq(int nq, const KHelper& kh, int stride_q, int stride_m, const float * q, const char * mask, + FlashMS<q_step, k_step>& fms) { if constexpr (is_small_head) { - mult_mask_kq(nq, kh, stride_q, stride_m, q, mask); + mult_mask_kq(nq, kh, stride_q, stride_m, q, mask, fms); } else { - mult_mask_kq_l(nq, kh, stride_q, stride_m, q, mask); + mult_mask_kq_l(nq, kh, stride_q, stride_m, q, mask, fms); } + __m512 vk[k_step/16]; for (int j = 0; j < nq; ++j) { - update_M_S(j); + fms.update_M_S(j, vk); + } + } +}; + +template <int D, int q_step, int k_step, typename KHelper, typename VHelper, typename KQHelper> +void compute_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv, + FlashMS<q_step, k_step>& fms, + FlashQKV<D, q_step, k_step>& fqkv, + const float * q, const char * mask, float * qkv) { + for (int i1 = 0; i1 < nq1/q_step; ++i1) { + fms.init_qstep(); + kh.reset_block(); + vh.reset_block(); + auto mr = mask; + for (int k1 = 0; k1 < nk1/k_step; ++k1) { + KQHelper::multiply_mask_kq(kh, stride_q, stride_m, q, mr, fms); + fqkv.accumulate_qkv(vh, fms); + kh.next_block(); + vh.next_block(); + mr += k_step*sizeof(ggml_half); + } + fqkv.normalize_and_store(fms, stride_qkv, qkv); + + q += q_step*stride_q; + mask += q_step*stride_m; + qkv += q_step*stride_qkv; + } + int n_left = nq1 - q_step*(nq1/q_step); + if (n_left > 0) { + fms.init_qstep(); + kh.reset_block(); + vh.reset_block(); + auto mr = mask; + for (int k1 = 0; k1 < nk1/k_step; ++k1) { + KQHelper::multiply_mask_kq(n_left, kh, stride_q, stride_m, q, mr, fms); + fqkv.accumulate_qkv(n_left, vh, fms); + kh.next_block(); + vh.next_block(); + mr += k_step*sizeof(ggml_half); + } + fqkv.normalize_and_store(fms, n_left, stride_qkv, qkv); + } +} + +// Some of the methods in FlashAttn have two identical implementations that only differ by +// one version using a loop over the template parameter q_step, while the other using a loop +// over an input parameter nq (these are loops over the rows of q^T). I dislike this a lot, +// but performance drops signficantly if I remove the version with fixed q_step iterations. 
+// We only instantiate FlashAttn with q_step = 1 and q_step = 4 or 8 (depending on head size D), +// so when we have to process Nq rows, we process q_step*(Nq/q_step) using fixed q_step loops, +// and use the variable nq version (with lower performance) only for the remaining i1...q_step-1 +// rows (if Nq is not a multiple of q_step). One could have made the number of q^T rows to +// process template parameter of such functions, but this would result in the compiler generating +// q_step-1 versions of these functions for us, which I though was too much with q_step = 8. +template <int D, int q_step, int k_step> +struct FlashAttn { + static_assert(D%16 == 0 && D <= 256); + static_assert(k_step%16 == 0); + static_assert(q_step <= 4 || q_step%4 == 0); + + FlashAttn(float scale, float softcap) : fms(scale, softcap) {} + + template <typename KHelper, typename VHelper> + void compute(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv, + const float * q, const char * mask, float * qkv) { + compute_helper<D, q_step, k_step, KHelper, VHelper, FlashQKfp32<D, q_step, k_step>>( + kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv); + } + + FlashMS<q_step, k_step> fms; + FlashQKV<D, q_step, k_step> fqkv; + +}; + +#ifdef __AVX512BF16__ + +template <int D, int step> +struct HelperBF16 final : public BaseHelper<step> { + using Base = BaseHelper<step>; + HelperBF16(const char * data, int stride) : Base(data, stride) {} + inline void load(int l1, __m512bh * vk) const { + auto dr = Base::lblock(l1); + for (int i = 0; i < D/32; ++i) vk[i] = __m512bh(_mm512_loadu_si512((const __m512i*)dr + i)); + } + + inline void load(int l1, int i, __m512& v1, __m512& v2) const { + auto dr = Base::lblock(l1); + v1 = _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(_mm256_loadu_si256((const __m256i *)dr + i + 0)), 16)); + v2 = _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(_mm256_loadu_si256((const __m256i *)dr + i + 1)), 16)); + } + + inline void load_2(int l1, __m512bh * vk) const { + load(l1+0, vk+0); + load(l1+1, vk+D/32); + } + + inline void load_4(int l1, __m512bh * vk) const { + load(l1+0, vk+0); + load(l1+1, vk+1*D/32); + load(l1+2, vk+2*D/32); + load(l1+3, vk+3*D/32); + } +}; + +template <int D, int q_step, int k_step> +struct FlashQKbf16 { + static_assert(D%32 == 0 && D <= 256); + static_assert(k_step%32 == 0); + static_assert(q_step <= 4 || q_step%4 == 0); + + static inline void mult_mask_kq_one(int l1, int m1, int stride_q, int stride_m, const float * q, const char * mask, + __m512bh * qv, const __m512bh * vkh, FlashMS<q_step, k_step>& fms) { + // q index is q_step*i1 + m1 + // k index is k_step*k1 + l1 + const ggml_half * mp = (const ggml_half *)(mask + stride_m*m1); + fms.cache[k_step*m1 + l1 + 0] = fms.cache[k_step*m1 + l1 + 1] = -INFINITY; + if (mp[l1+0] == fms.h_inf && mp[l1+1] == fms.h_inf) { + return; + } + auto qr = q + m1*stride_q; + for (int i = 0; i < D/32; ++i) { + auto val1 = _mm512_loadu_ps(qr + 32*i); + auto val2 = _mm512_loadu_ps(qr + 32*i + 16); + qv[i] = _mm512_cvtne2ps_pbh(val2, val1); + } + if (mp[l1+0] != fms.h_inf) { + auto vsum = _mm512_setzero_ps(); + for (int i = 0; i < D/32; ++i) vsum = _mm512_dpbf16_ps(vsum, vkh[i], qv[i]); + fms.cache[k_step*m1 + l1 + 0] = _mm512_reduce_add_ps(vsum); + } + if (mp[l1+1] != fms.h_inf) { + auto vsum = _mm512_setzero_ps(); + for (int i = 0; i < D/32; ++i) vsum = _mm512_dpbf16_ps(vsum, vkh[i+D/32], qv[i]); + fms.cache[k_step*m1 + l1 + 1] = _mm512_reduce_add_ps(vsum); } } - // 
This fails for head sizes of 80 and 112 as D/16 is odd, so we cannot do steps of 2 - // Hence, for now, we will not handle head sizes of 80 and 112 - template <typename VHelper> - inline void accumulate_qkv(const VHelper& vh) { - for (int i = 0; i < D/16; i += 2) { - for (int j = 0; j < q_step; ++j) { - if (need_scaling[j] == 2) { - vk[2*j+0] = vk[2*j+1] = _mm512_setzero_ps(); - } else { - auto R = qkv_cache + D*j; - vk[2*j+0] = _mm512_loadu_ps(R + 16*i); - vk[2*j+1] = _mm512_loadu_ps(R + 16*i + 16); - if (need_scaling[j] == 1) { - vk[2*j+0] = _mm512_mul_ps(vk[2*j+0], vms[j]); - vk[2*j+1] = _mm512_mul_ps(vk[2*j+1], vms[j]); + static inline void mult_mask_kq_one(int l1, int m1, int stride_m, const ggml_bf16_t * q, const char * mask, + __m512bh * qv, const __m512bh * vkh, FlashMS<q_step, k_step>& fms) { + // q index is q_step*i1 + m1 + // k index is k_step*k1 + l1 + const ggml_half * mp = (const ggml_half *)(mask + stride_m*m1); + fms.cache[k_step*m1 + l1 + 0] = fms.cache[k_step*m1 + l1 + 1] = -INFINITY; + if (mp[l1+0] == fms.h_inf && mp[l1+1] == fms.h_inf) { + return; + } + auto qr = q + m1*D; + for (int i = 0; i < D/32; ++i) qv[i] = __m512bh(_mm512_loadu_si512((const __m512i*)qr + i)); + if (mp[l1+0] != fms.h_inf) { + auto vsum = _mm512_setzero_ps(); + for (int i = 0; i < D/32; ++i) vsum = _mm512_dpbf16_ps(vsum, vkh[i], qv[i]); + fms.cache[k_step*m1 + l1 + 0] = _mm512_reduce_add_ps(vsum); + } + if (mp[l1+1] != fms.h_inf) { + auto vsum = _mm512_setzero_ps(); + for (int i = 0; i < D/32; ++i) vsum = _mm512_dpbf16_ps(vsum, vkh[i+D/32], qv[i]); + fms.cache[k_step*m1 + l1 + 1] = _mm512_reduce_add_ps(vsum); + } + } + + static inline void mult_mask_kq_4(int l1, int m1, int stride_q, int stride_m, const float * q, const char * mask, + __m512bh * qv, const __m512bh * vkh, FlashMS<q_step, k_step>& fms) { + // q index is q_step*i1 + m1 + // k index is k_step*k1 + l1 + const ggml_half * mp = (const ggml_half *)(mask + stride_m*m1); + fms.cache[k_step*m1 + l1 + 0] = fms.cache[k_step*m1 + l1 + 1] = + fms.cache[k_step*m1 + l1 + 2] = fms.cache[k_step*m1 + l1 + 3] = -INFINITY; + if (mp[l1+0] == fms.h_inf && mp[l1+1] == fms.h_inf && mp[l1+2] == fms.h_inf && mp[l1+3] == fms.h_inf) { + return; + } + auto qr = q + m1*stride_q; + for (int i = 0; i < D/32; ++i) { + auto val1 = _mm512_loadu_ps(qr + 32*i); + auto val2 = _mm512_loadu_ps(qr + 32*i + 16); + qv[i] = _mm512_cvtne2ps_pbh(val2, val1); + } + for (int k = 0; k < 4; ++k) { + if (mp[l1+k] == fms.h_inf) continue; + auto vsum = _mm512_setzero_ps(); + for (int i = 0; i < D/32; ++i) vsum = _mm512_dpbf16_ps(vsum, vkh[i+k*(D/32)], qv[i]); + fms.cache[k_step*m1 + l1 + k] = _mm512_reduce_add_ps(vsum); + } + } + + static inline void mult_mask_kq_4(int l1, int m1, int stride_m, const ggml_bf16_t * q, const char * mask, + __m512bh * qv, const __m512bh * vkh, FlashMS<q_step, k_step>& fms) { + // q index is q_step*i1 + m1 + // k index is k_step*k1 + l1 + const ggml_half * mp = (const ggml_half *)(mask + stride_m*m1); + fms.cache[k_step*m1 + l1 + 0] = fms.cache[k_step*m1 + l1 + 1] = + fms.cache[k_step*m1 + l1 + 2] = fms.cache[k_step*m1 + l1 + 3] = -INFINITY; + if (mp[l1+0] == fms.h_inf && mp[l1+1] == fms.h_inf && mp[l1+2] == fms.h_inf && mp[l1+3] == fms.h_inf) { + return; + } + auto qr = q + m1*D; + for (int i = 0; i < D/32; ++i) qv[i] = __m512bh(_mm512_loadu_si512((const __m512i *)qr + i)); + for (int k = 0; k < 4; ++k) { + if (mp[l1+k] == fms.h_inf) continue; + auto vsum = _mm512_setzero_ps(); + for (int i = 0; i < D/32; ++i) vsum = _mm512_dpbf16_ps(vsum, 
vkh[i+k*(D/32)], qv[i]); + fms.cache[k_step*m1 + l1 + k] = _mm512_reduce_add_ps(vsum); + } + } + + template <typename KHelper> + static inline void multiply_mask_kq(const KHelper& kh, int stride_q, int stride_m, const float * q, + const char * mask, FlashMS<q_step, k_step>& fms) { + { + __m512bh qv[D/32]; + if constexpr (D <= 128) { + __m512bh vkh[D/8]; + for (int l1 = 0; l1 < k_step; l1 += 4) { + kh.load_4(l1, vkh); + for (int j = 0; j < q_step; ++j) { + mult_mask_kq_4(l1, j, stride_q, stride_m, q, mask, qv, vkh, fms); } } - } - __m512 v1, v2; - for (int l1 = 0; l1 < k_step; ++l1) { - vh.load(l1, i, v1, v2); - for (int j = 0; j < q_step; ++j) { - auto vs = _mm512_set1_ps(cache[k_step*j + l1]); - vk[2*j+0] = _mm512_fmadd_ps(v1, vs, vk[2*j+0]); - vk[2*j+1] = _mm512_fmadd_ps(v2, vs, vk[2*j+1]); + } else { + __m512bh vkh[D/16]; + for (int l1 = 0; l1 < k_step; l1 += 2) { + kh.load_2(l1, vkh); + for (int j = 0; j < q_step; ++j) { + mult_mask_kq_one(l1, j, stride_q, stride_m, q, mask, qv, vkh, fms); + } } } - for (int j = 0; j < q_step; ++j) { - auto R = qkv_cache + D*j; - _mm512_storeu_ps(R + 16*i, vk[2*j+0]); - _mm512_storeu_ps(R + 16*i + 16, vk[2*j+1]); - } + } + __m512 vk[k_step/16]; + for (int j = 0; j < q_step; ++j) { + fms.update_M_S(j, vk); } } - template <typename VHelper, int Nq = q_step, class = std::enable_if<Nq >= 2>> - inline void accumulate_qkv(int nq1, const VHelper& vh) { - for (int i = 0; i < D/16; i += 2) { - for (int j = 0; j < nq1; ++j) { - if (need_scaling[j] == 2) { - vk[2*j+0] = vk[2*j+1] = _mm512_setzero_ps(); - } else { - auto R = qkv_cache + D*j; - vk[2*j+0] = _mm512_loadu_ps(R + 16*i); - vk[2*j+1] = _mm512_loadu_ps(R + 16*i + 16); - if (need_scaling[j] == 1) { - vk[2*j+0] = _mm512_mul_ps(vk[2*j+0], vms[j]); - vk[2*j+1] = _mm512_mul_ps(vk[2*j+1], vms[j]); + template <typename KHelper> + static inline void multiply_mask_kq(const KHelper& kh, int stride_m, const ggml_bf16_t * q, + const char * mask, FlashMS<q_step, k_step>& fms) { + { + __m512bh qv[D/32]; + if constexpr (D <= 128) { + __m512bh vkh[D/8]; + for (int l1 = 0; l1 < k_step; l1 += 4) { + kh.load_4(l1, vkh); + for (int j = 0; j < q_step; ++j) { + mult_mask_kq_4(l1, j, stride_m, q, mask, qv, vkh, fms); + } + } + } else { + __m512bh vkh[D/16]; + for (int l1 = 0; l1 < k_step; l1 += 2) { + kh.load_2(l1, vkh); + for (int j = 0; j < q_step; ++j) { + mult_mask_kq_one(l1, j, stride_m, q, mask, qv, vkh, fms); } } } - __m512 v1, v2; - for (int l1 = 0; l1 < k_step; ++l1) { - vh.load(l1, i, v1, v2); - for (int j = 0; j < nq1; ++j) { - auto vs = _mm512_set1_ps(cache[k_step*j + l1]); - vk[2*j+0] = _mm512_fmadd_ps(v1, vs, vk[2*j+0]); - vk[2*j+1] = _mm512_fmadd_ps(v2, vs, vk[2*j+1]); + } + __m512 vk[k_step/16]; + for (int j = 0; j < q_step; ++j) { + fms.update_M_S(j, vk); + } + } + + template <typename KHelper> + static inline void multiply_mask_kq(int nq, const KHelper& kh, int stride_q, int stride_m, const float * q, + const char * mask, FlashMS<q_step, k_step>& fms) { + { + __m512bh qv[D/32]; + __m512bh vkh[D/16]; + for (int l1 = 0; l1 < k_step; l1 += 2) { + kh.load_2(l1, vkh); + for (int m1 = 0; m1 < nq; ++m1) { + mult_mask_kq_one(l1, m1, stride_q, stride_m, q, mask, qv, vkh, fms); } } - for (int j = 0; j < nq1; ++j) { - auto R = qkv_cache + D*j; - _mm512_storeu_ps(R + 16*i, vk[2*j+0]); - _mm512_storeu_ps(R + 16*i + 16, vk[2*j+1]); + } + __m512 vk[k_step/16]; + for (int j = 0; j < nq; ++j) { + fms.update_M_S(j, vk); + } + } + + static inline void convert(int stride_q, const float * q, ggml_bf16_t * bf16) { + auto qr = q; + 
for (int j = 0; j < q_step; ++j) { + for (int i = 0; i < D/32; ++i) { + auto val1 = _mm512_loadu_ps(qr + 32*i); + auto val2 = _mm512_loadu_ps(qr + 32*i + 16); + _mm512_storeu_si512((__m512i *)bf16 + i, (__m512i)_mm512_cvtne2ps_pbh(val2, val1)); } + qr += stride_q; + bf16 += D; } } +}; + +template <int D, int q_step, int k_step> +struct FlashAttnBF16 { + static_assert(D%32 == 0 && D <= 256); + static_assert(k_step%32 == 0); + static_assert(q_step <= 4 || q_step%4 == 0); + + FlashAttnBF16(float scale, float softcap) : fms(scale, softcap) {} template <typename KHelper, typename VHelper> void compute(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv, const float * q, const char * mask, float * qkv) { + ggml_bf16_t q_bf16[q_step*D]; for (int i1 = 0; i1 < nq1/q_step; ++i1) { - init_qstep(); + fms.init_qstep(); kh.reset_block(); vh.reset_block(); + FlashQKbf16<D, q_step, k_step>::convert(stride_q, q, q_bf16); auto mr = mask; for (int k1 = 0; k1 < nk1/k_step; ++k1) { - multiply_mask_kq(kh, stride_q, stride_m, q, mr); - accumulate_qkv(vh); + FlashQKbf16<D, q_step, k_step>::multiply_mask_kq(kh, stride_m, q_bf16, mr, fms); + fqkv.accumulate_qkv(vh, fms); kh.next_block(); vh.next_block(); mr += k_step*sizeof(ggml_half); } - normalize_and_store(stride_qkv, qkv); + fqkv.normalize_and_store(fms, stride_qkv, qkv); q += q_step*stride_q; mask += q_step*stride_m; @@ -6652,50 +6989,25 @@ struct FlashAttn { } int n_left = nq1 - q_step*(nq1/q_step); if (n_left > 0) { - init_qstep(); + fms.init_qstep(); kh.reset_block(); vh.reset_block(); auto mr = mask; for (int k1 = 0; k1 < nk1/k_step; ++k1) { - multiply_mask_kq(n_left, kh, stride_q, stride_m, q, mr); - accumulate_qkv(n_left, vh); + FlashQKbf16<D, q_step, k_step>::multiply_mask_kq(n_left, kh, stride_q, stride_m, q, mr, fms); + fqkv.accumulate_qkv(n_left, vh, fms); kh.next_block(); vh.next_block(); mr += k_step*sizeof(ggml_half); } - normalize_and_store(n_left, stride_qkv, qkv); + fqkv.normalize_and_store(fms, n_left, stride_qkv, qkv); } } - float cache[q_step*k_step]; - float qkv_cache[D*q_step]; - float S[q_step], M[q_step]; - int need_scaling[q_step]; - __m512 vms[q_step]; - __m512 vk[vk_size]; - const __m512 vscale; - const float softcap; - const ggml_half h_inf; - - typedef __m512 (*combine_t)(__m512, __m512); - typedef float (*reduce_t)(__m512); - template <reduce_t Op, combine_t Op_combine> - static inline float reduce_T(const __m512 * vals) { - float result; - if constexpr (k_step/16 == 1) { - result = Op(vals[0]); - } - else if constexpr (k_step/16 == 2) { - result = Op(Op_combine(vals[0], vals[1])); - } - else { - auto vmax = Op_combine(vals[0], vals[1]); - for (int l = 2; l < k_step/16; ++l) vmax = Op_combine(vmax, vals[l]); - result = Op(vmax); - } - return result; - } + FlashMS<q_step, k_step> fms; + FlashQKV<D, q_step, k_step> fqkv; }; +#endif template <int D, int q_step, int k_step, typename KHelper, typename VHelper> inline void iqk_flash_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv, @@ -6710,6 +7022,23 @@ inline void iqk_flash_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int str } } +#ifdef __AVX512BF16__ +template <int D, int q_step, int k_step> +inline void iqk_flash_helper_T(int nq1, int nk1, int stride_q, int stride_k, int stride_v, int stride_m, int stride_qkv, + const float * q, const char * k, const char * v, const char * mask, + float scale, float softcap, float * qkv) { + HelperBF16<D, k_step> kh(k, stride_k); + HelperBF16<D, k_step> vh(v, 
stride_v); + if (nq1 >= q_step) { + FlashAttnBF16<D, q_step, k_step> fa(scale, softcap); + fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv); + } else { + FlashAttnBF16<D, 1, k_step> fa(scale, softcap); + fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv); + } +} +#endif + template <int D, int q_step, int k_step, typename KHelper> inline void iqk_flash_helper_T(KHelper& kh, ggml_type type_v, int nq1, int nk1, int stride_q, int stride_v, int stride_m, int stride_qkv, @@ -6766,7 +7095,11 @@ inline void iqk_flash_helper_T(ggml_type type_k, ggml_type type_v, } inline bool flash_attn_is_supported(ggml_type type) { +#ifdef __AVX512BF16__ + return type == GGML_TYPE_F16 || type == GGML_TYPE_BF16 || type == GGML_TYPE_Q8_0 || type == GGML_TYPE_Q4_0 || type == GGML_TYPE_Q4_1; +#else return type == GGML_TYPE_F16 || type == GGML_TYPE_Q8_0 || type == GGML_TYPE_Q4_0 || type == GGML_TYPE_Q4_1; +#endif } } @@ -6799,14 +7132,33 @@ bool iqk_flash_attn_noalibi(int int_type_k, // type of k stride_q /= sizeof(float); // q stride as float +#ifdef __AVX512BF16__ + if (type_k == GGML_TYPE_BF16 && type_v == GGML_TYPE_BF16) { + switch (D) { + case 64: + iqk_flash_helper_T< 64, 8, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; + case 96: + iqk_flash_helper_T< 96, 8, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; + case 128: + iqk_flash_helper_T<128, 8, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; + case 256: + iqk_flash_helper_T<256, 8, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; + default: + return false; + } + + return true; + } +#endif + switch (D) { case 64: - iqk_flash_helper_T< 64, 4, 32>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; + iqk_flash_helper_T< 64, 8, 32>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; // Disable until we fix accumulate_qkv for odd D/16 //case 80: // iqk_flash_helper_T< 80, 4, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; case 96: - iqk_flash_helper_T< 96, 4, 32>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; + iqk_flash_helper_T< 96, 8, 32>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; // Disable until we fix accumulate_qkv for odd D/16 //case 112: // iqk_flash_helper_T<112, 4, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; |
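
To make the core of the new path concrete: the K·Q products in `FlashQKbf16` are accumulated with `_mm512_dpbf16_ps`, which multiplies 32 bf16 pairs and adds the pairwise sums into 16 fp32 lanes. The following is a minimal, self-contained sketch of that inner loop, not the patch's actual code; D = 128 and `uint16_t`-as-bf16 are assumptions of the example.

```cpp
// Sketch of a single bf16 K.Q dot product as accumulated in FlashQKbf16.
// Assumptions: D = 128, bf16 data passed as raw uint16_t bits, AVX512BF16 enabled.
#include <immintrin.h>
#include <cstdint>

constexpr int D = 128;  // head size, multiple of 32

static float bf16_dot(const uint16_t * k_row, const uint16_t * q_row) {
    __m512 acc = _mm512_setzero_ps();
    for (int i = 0; i < D/32; ++i) {
        auto vk = __m512bh(_mm512_loadu_si512((const __m512i *)k_row + i));  // 32 bf16 from K
        auto vq = __m512bh(_mm512_loadu_si512((const __m512i *)q_row + i));  // 32 bf16 from Q
        acc = _mm512_dpbf16_ps(acc, vk, vq);  // bf16 multiply, fp32 accumulate
    }
    return _mm512_reduce_add_ps(acc);  // horizontal sum -> one element of K*Q^T
}
```

On the user-facing side, the common.cpp and llama-bench hunks simply register "bf16" as a recognized cache-type name, so a BF16 K/V cache can presumably be requested through the existing cache-type options (e.g. the `-ctk`/`-ctv` flags) on CPUs that expose AVX512_BF16, such as Zen4.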