author    | Kawrakow <48489457+ikawrakow@users.noreply.github.com> | 2024-09-01 16:08:21 +0300
committer | GitHub <noreply@github.com>                            | 2024-09-01 16:08:21 +0300
commit    | dc023bc3be1a7ac42d1030f86c4d77563a019286 (patch)
tree      | 565cc8a7be7d54ac164c0e7efc23b9dadf06cd92
parent    | dbb1db989991025881679a60b0a81a92d2fa471b (diff)
Zen4 Flash Attention (#32)
* Zen4 flash attention: moving useful parts from the kq_fused_softmax branch
* Add flash attention with soft-cap and fix D = 256 case
* Flash attention refinements
* Update FlashAttn comment
---------
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
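
For readers new to the technique: the `FlashAttn` struct added in `iqk_mul_mat.cpp` walks K/V in blocks of `k_step` rows and keeps, per q row, a running maximum `M` and running sum `S`, so the full attention matrix is never materialized. Below is a minimal scalar sketch of that streaming-softmax recurrence, with the mask and the AVX-512 blocking omitted; all names in it are illustrative, not taken from the commit.

```cpp
#include <algorithm>
#include <cmath>
#include <vector>

// One q row attending over all K/V rows with a running max/sum ("online" softmax).
// scale and softcap have the same meaning as in iqk_flash_attn_noalibi.
void attend_one_row(const std::vector<std::vector<float>>& K,   // n_kv x D
                    const std::vector<std::vector<float>>& V,   // n_kv x D
                    const std::vector<float>& q,                // D
                    float scale, float softcap,
                    std::vector<float>& out) {                  // D
    const size_t D = q.size();
    float M = -INFINITY, S = 0.0f;
    std::vector<float> acc(D, 0.0f);
    for (size_t l = 0; l < K.size(); ++l) {
        float s = 0.0f;
        for (size_t i = 0; i < D; ++i) s += K[l][i]*q[i];
        s *= scale;
        if (softcap > 0.0f) s = softcap*std::tanh(s);               // optional soft-cap
        const float M_new = std::max(M, s);
        const float c = M > -INFINITY ? std::exp(M - M_new) : 0.0f; // rescale factor for old state
        const float p = std::exp(s - M_new);
        S = c*S + p;
        for (size_t i = 0; i < D; ++i) acc[i] = c*acc[i] + p*V[l][i];
        M = M_new;
    }
    out.resize(D);
    for (size_t i = 0; i < D; ++i) out[i] = acc[i]/S;               // normalize once at the end
}
```

The committed kernel does the same per block of `k_step = 32` K rows: `update_M_S` computes the block maximum and the `exp(M_old - M_new)` factor (tracked via `vms` and `need_scaling`), and `accumulate_qkv` applies it to the partial V accumulation before adding the new block.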
-rw-r--r-- | ggml/src/ggml.c              |  32
-rw-r--r-- | ggml/src/iqk/iqk_mul_mat.cpp | 587
-rw-r--r-- | ggml/src/iqk/iqk_mul_mat.h   |  15
3 files changed, 634 insertions, 0 deletions
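
The first hunk below wires the new kernel into `ggml_compute_forward_flash_attn_ext_f16`: when there is no ALiBi bias, Q is f32, and K, V and the mask are f16, each head/batch slice is handed to a small group of threads. A hedged sketch of just the group-size selection (the helper name is mine, not the commit's):

```cpp
// Pick how many threads cooperate on one (iq2, iq3) slice, based on the amount of
// work D*nek1*neq1 in that slice; mirrors the thresholds in the ggml.c hunk below.
static int pick_thread_group(int nth, long long work_per_slice) {
    int ntg = 1;
    if      (nth % 8 == 0 && work_per_slice >= (1LL << 23)) ntg = 8;
    else if (nth % 4 == 0 && work_per_slice >= (1LL << 21)) ntg = 4;
    else if (nth % 2 == 0 && work_per_slice >= (1LL << 19)) ntg = 2;
    return ntg;
}
```

Each group of `ntg` threads then splits the `neq1` q rows, the `neq2*neq3` slices are distributed round-robin over the `nth/ntg` groups, and the code falls back to the existing path when the slice count is not divisible by the number of groups or when `iqk_flash_attn_noalibi` returns `false`.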
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index cebac584..4546eac3 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -16149,6 +16149,38 @@ static void ggml_compute_forward_flash_attn_ext_f16( scale /= softcap; } +#if GGML_USE_IQK_MULMAT + if (max_bias <= 0.0f && q->type == GGML_TYPE_F32 && k->type == GGML_TYPE_F16 && v->type == GGML_TYPE_F16 && + mask && mask->type == GGML_TYPE_F16) { + int64_t work_per_slice = D*nek1*neq1; + int ntg = 1; + if (nth%8 == 0 && work_per_slice >= (1 << 23)) ntg = 8; + else if (nth%4 == 0 && work_per_slice >= (1 << 21)) ntg = 4; + else if (nth%2 == 0 && work_per_slice >= (1 << 19)) ntg = 2; + if ((neq2*neq3)%(nth/ntg) == 0) { + //if (ith == 0) printf("%s: D = %d, neq2 = %d, neq1 = %d, nek1 = %d\n", __func__, (int)D, (int)neq2, (int)neq1, (int)nek1); + int counter = 0; + for (int64_t iq3 = 0; iq3 < neq3; iq3++) { + for (int64_t iq2 = 0; iq2 < neq2; iq2++) { + if (counter++ % (nth/ntg) == ith/ntg) { + int iq1 = (ith%ntg)*neq1/ntg; + if (!iqk_flash_attn_noalibi(D, neq1/ntg, nek1, q->nb[1], k->nb[1], v->nb[1], mask->nb[1], ne1*nb1/sizeof(float), + (const float *)((const char *)q->data + iq2*q->nb[2] + iq3*q->nb[3] + iq1*q->nb[1]), + (const void *)((const char *)k->data + iq2/rk2*k->nb[2] + iq3/rk3*k->nb[3]), + (const void *)((const char *)v->data + iq2/rv2*v->nb[2] + iq3/rv3*v->nb[3]), + (const void *)((const char *)mask->data + iq1*mask->nb[1]), + scale, softcap, + (float *)((char *) dst->data + (iq3*ne2*ne1 + iq2 + iq1*ne1)*nb1))) goto IQK_Flash_Attn_NotAvailable; + } + } + } + return; + } +IQK_Flash_Attn_NotAvailable:; + } + +#endif + const uint32_t n_head = neq2; const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head)); diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index 32ddb3ff..55dd016c 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -5915,6 +5915,575 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) { #endif // __aarch64__ +namespace { + +#if defined(__ARM_NEON) && defined(__aarch64__) +// copy-pasted from Justine Tunney's contribution to llama.cpp +// adapted from arm limited optimized routine +// the maximum error is 1.45358 plus 0.5 ulps +// numbers above 88.38 will flush to infinity +// numbers beneath -103.97 will flush to zero +inline float32x4_t v_expf(float32x4_t x) { + const float32x4_t r = vdupq_n_f32(0x1.8p23f); + const float32x4_t z = vfmaq_f32(r, x, vdupq_n_f32(0x1.715476p+0f)); + const float32x4_t n = vsubq_f32(z, r); + const float32x4_t b = vfmsq_f32(vfmsq_f32(x, n, vdupq_n_f32(0x1.62e4p-1f)), n, + vdupq_n_f32(0x1.7f7d1cp-20f)); + const uint32x4_t e = vshlq_n_u32(vreinterpretq_u32_f32(z), 23); + const float32x4_t k = vreinterpretq_f32_u32(vaddq_u32(e, vreinterpretq_u32_f32(vdupq_n_f32(1)))); + const uint32x4_t c = vcagtq_f32(n, vdupq_n_f32(126)); + const float32x4_t u = vmulq_f32(b, b); + const float32x4_t j = vfmaq_f32( + vmulq_f32(vdupq_n_f32(0x1.ffffecp-1f), b), + vfmaq_f32(vfmaq_f32(vdupq_n_f32(0x1.fffdb6p-2f), vdupq_n_f32(0x1.555e66p-3f), b), + vfmaq_f32(vdupq_n_f32(0x1.573e2ep-5f), vdupq_n_f32(0x1.0e4020p-7f), b), u), u); + if (!vpaddd_u64(vreinterpretq_u64_u32(c))) + return vfmaq_f32(k, j, k); + const uint32x4_t d = vandq_u32(vclezq_f32(n), vdupq_n_u32(0x82000000)); + const float32x4_t s1 = vreinterpretq_f32_u32(vaddq_u32(d, vdupq_n_u32(0x7f000000))); + const float32x4_t s2 = vreinterpretq_f32_u32(vsubq_u32(e, d)); + return vbslq_f32(vcagtq_f32(n, vdupq_n_f32(192)), vmulq_f32(s1, s1), + vbslq_f32(c, vmulq_f32(vfmaq_f32(s2, s2, j), 
s1), vfmaq_f32(k, k, j))); +} +inline float32x4_t v_tanh(float32x4_t x) { + const float32x4_t one = vdupq_n_f32(1.0f); + const float32x4_t two_x = vmulq_f32(x, vdupq_n_f32(2.f)); + const float32x4_t exp_two_x = v_expf(two_x); + const uint32x4_t mask = vcgtq_f32(x, vdupq_n_f32(10.f)); + const float32x4_t res = vdivq_f32(vsubq_f32(exp_two_x, one), vaddq_f32(exp_two_x, one)); + return vreinterpretq_f32_u32(vorrq_u32(vandq_u32(vreinterpretq_u32_f32(one), mask), vbicq_u32(vreinterpretq_u32_f32(res), mask))); + //return vdivq_f32(vsubq_f32(exp_two_x, one), vaddq_f32(exp_two_x, one)); +} +#endif + +#if defined(__AVX512F__) && defined(__AVX512DQ__) + +// copy-pasted from Justine Tunney's contribution to llama.cpp +// adapted from arm limited optimized routine +// the maximum error is 1.45358 plus 0.5 ulps +// numbers above 88.38 will flush to infinity +// numbers beneath -103.97 will flush to zero +inline __m512 v_expf(__m512 x) { + const __m512 r = _mm512_set1_ps(0x1.8p23f); + const __m512 z = _mm512_fmadd_ps(x, _mm512_set1_ps(0x1.715476p+0f), r); + const __m512 n = _mm512_sub_ps(z, r); + const __m512 b = + _mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.7f7d1cp-20f), + _mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.62e4p-1f), x)); + const __mmask16 d = + _mm512_cmp_ps_mask(_mm512_abs_ps(n), _mm512_set1_ps(192), _CMP_GT_OQ); + const __m512 u = _mm512_mul_ps(b, b); + const __m512 j = _mm512_fmadd_ps( + _mm512_fmadd_ps(_mm512_fmadd_ps(_mm512_set1_ps(0x1.0e4020p-7f), b, + _mm512_set1_ps(0x1.573e2ep-5f)), + u, + _mm512_fmadd_ps(_mm512_set1_ps(0x1.555e66p-3f), b, + _mm512_set1_ps(0x1.fffdb6p-2f))), + u, + _mm512_fmadd_ps(_mm512_set1_ps(0x1.ffffecp-1f), b, _mm512_set1_ps(1.0F))); + const __m512 res = _mm512_scalef_ps(j, n); + if (_mm512_kortestz(d, d)) + return res; + const __m512 zero = _mm512_setzero_ps(); + const __m512 alt = _mm512_mask_blend_ps( + _mm512_cmp_ps_mask(n, zero, _CMP_LE_OQ), _mm512_set1_ps(INFINITY), zero); + return _mm512_mask_blend_ps(d, res, alt); +} +inline __m512 v_tanh(__m512 x) { + const __m512 one = _mm512_set1_ps(1.0f); + const __m512 exp_two_x = v_expf(_mm512_mul_ps(x, _mm512_set1_ps(2.f))); + const __mmask16 mask = _mm512_cmp_ps_mask(x, _mm512_set1_ps(10.f), _CMP_GT_OQ); + const __m512 res = _mm512_div_ps(_mm512_sub_ps(exp_two_x, one), _mm512_add_ps(exp_two_x, one)); + return _mm512_mask_blend_ps(mask, res, one); +} +#endif + +#if defined(__AVX2__) && defined(__FMA__) + +// adapted from arm limited optimized routine +// the maximum error is 1.45358 plus 0.5 ulps +// numbers above 88.38 will flush to infinity +// numbers beneath -103.97 will flush to zero +inline __m256 v_expf(__m256 x) { + const __m256 r = _mm256_set1_ps(0x1.8p23f); + const __m256 z = _mm256_fmadd_ps(x, _mm256_set1_ps(0x1.715476p+0f), r); + const __m256 n = _mm256_sub_ps(z, r); + const __m256 b = _mm256_fnmadd_ps(n, _mm256_set1_ps(0x1.7f7d1cp-20f), + _mm256_fnmadd_ps(n, _mm256_set1_ps(0x1.62e4p-1f), x)); + const __m256i e = _mm256_slli_epi32(_mm256_castps_si256(z), 23); + const __m256 k = _mm256_castsi256_ps( + _mm256_add_epi32(e, _mm256_castps_si256(_mm256_set1_ps(1)))); + const __m256i c = _mm256_castps_si256( + _mm256_cmp_ps(_mm256_andnot_ps(_mm256_set1_ps(-0.f), n), + _mm256_set1_ps(126), _CMP_GT_OQ)); + const __m256 u = _mm256_mul_ps(b, b); + const __m256 j = _mm256_fmadd_ps(_mm256_fmadd_ps(_mm256_fmadd_ps(_mm256_set1_ps(0x1.0e4020p-7f), b, + _mm256_set1_ps(0x1.573e2ep-5f)), u, + _mm256_fmadd_ps(_mm256_set1_ps(0x1.555e66p-3f), b, + _mm256_set1_ps(0x1.fffdb6p-2f))), + u, _mm256_mul_ps(_mm256_set1_ps(0x1.ffffecp-1f), 
b)); + if (!_mm256_movemask_ps(_mm256_castsi256_ps(c))) + return _mm256_fmadd_ps(j, k, k); + const __m256i g = _mm256_and_si256( + _mm256_castps_si256(_mm256_cmp_ps(n, _mm256_setzero_ps(), _CMP_LE_OQ)), + _mm256_set1_epi32(0x82000000u)); + const __m256 s1 = + _mm256_castsi256_ps(_mm256_add_epi32(g, _mm256_set1_epi32(0x7f000000u))); + const __m256 s2 = _mm256_castsi256_ps(_mm256_sub_epi32(e, g)); + const __m256i d = _mm256_castps_si256( + _mm256_cmp_ps(_mm256_andnot_ps(_mm256_set1_ps(-0.f), n), + _mm256_set1_ps(192), _CMP_GT_OQ)); + return _mm256_or_ps( + _mm256_and_ps(_mm256_castsi256_ps(d), _mm256_mul_ps(s1, s1)), + _mm256_andnot_ps( + _mm256_castsi256_ps(d), + _mm256_or_ps( + _mm256_and_ps(_mm256_castsi256_ps(c), + _mm256_mul_ps(_mm256_fmadd_ps(s2, j, s2), s1)), + _mm256_andnot_ps(_mm256_castsi256_ps(c), _mm256_fmadd_ps(k, j, k))))); +} +inline __m256 v_tanh(__m256 x) { + const __m256 one = _mm256_set1_ps(1.0f); + const __m256 exp_two_x = v_expf(_mm256_mul_ps(x, _mm256_set1_ps(2.f))); + const __m256 res = _mm256_div_ps(_mm256_sub_ps(exp_two_x, one), _mm256_add_ps(exp_two_x, one)); + const __m256 mask = _mm256_cmp_ps(x, _mm256_set1_ps(10.f), _CMP_GT_OQ); + return _mm256_or_ps(_mm256_and_ps(mask, one), _mm256_andnot_ps(mask, res)); +} + +#endif +} // namespace + +#ifdef HAVE_FANCY_SIMD + +namespace { + +// Some of the methods in FlashAttn have two identical implementations that only differ by +// one version using a loop over the template parameter q_step, while the other using a loop +// over an input parameter nq (these are loops over the rows of q^T). I dislike this a lot, +// but performance drops signficantly if I remove the version with fixed q_step iterations. +// We only instantiate FlashAttn with q_step = 1 and q_step = 4 or 8 (depending on head size D), +// so when we have to process Nq rows, we process q_step*(Nq/q_step) using fixed q_step loops, +// and use the variable nq version (with lower performance) only for the remaining i1...q_step-1 +// rows (if Nq is not a multiple of q_step). One could have made the number of q^T rows to +// process template parameter of such functions, but this would result in the compiler generating +// q_step-1 versions of these functions for us, which I though was too much with q_step = 8. +template <int D, int q_step, int k_step> +struct FlashAttn { + static_assert(D%16 == 0 && D <= 256); + static_assert(k_step%16 == 0); + static_assert(q_step <= 4 || q_step%4 == 0); + + constexpr static bool is_small_head = D <= 128; + + constexpr static int vk_size = is_small_head ? 
D/8 : D/16; + static_assert(2*q_step <= vk_size); + + FlashAttn(float scale, float softcap) : vscale(_mm512_set1_ps(scale)), softcap(softcap), h_inf(GGML_FP32_TO_FP16(-INFINITY)) {} + + inline void init_qstep() { + for (int j = 0; j < q_step; ++j) { + S[j] = 0; M[j] = -INFINITY; + } + } + + template <bool small = is_small_head, class = std::enable_if<small>> + inline void mult_mask_kq_one(int l1, int m1, int stride_q, int stride_m, const float * q, const char * mask, __m512 * qv) { + // q index is q_step*i1 + m1 + // k index is k_step*k1 + l1 + const ggml_half * mp = (const ggml_half *)(mask + stride_m*m1); + cache[k_step*m1 + l1 + 0] = cache[k_step*m1 + l1 + 1] = -INFINITY; + if (mp[l1+0] == h_inf && mp[l1+1] == h_inf) { + return; + } + auto qr = q + m1*stride_q; + for (int i = 0; i < D/16; ++i) qv[i] = _mm512_loadu_ps(qr + 16*i); + if (mp[l1+0] != h_inf) { + auto vsum = _mm512_mul_ps(vk[0], qv[0]); + for (int i = 1; i < D/16; ++i) vsum = _mm512_fmadd_ps(vk[i], qv[i], vsum); + cache[k_step*m1 + l1 + 0] = _mm512_reduce_add_ps(vsum); + } + if (mp[l1+1] != h_inf) { + auto vsum = _mm512_mul_ps(vk[D/16], qv[0]); + for (int i = 1; i < D/16; ++i) vsum = _mm512_fmadd_ps(vk[i+D/16], qv[i], vsum); + cache[k_step*m1 + l1 + 1] = _mm512_reduce_add_ps(vsum); + } + } + + template <bool small = is_small_head, class = std::enable_if<!small>> + inline void mult_mask_kq_one(int l1, int m1, int stride_q, int stride_m, const float * q, const char * mask) { + // q index is q_step*i1 + m1 + // k index is k_step*k1 + l1 + const ggml_half * mp = (const ggml_half *)(mask + stride_m*m1); + if (mp[l1] == h_inf) { + cache[k_step*m1 + l1] = -INFINITY; + return; + } + auto qr = q + m1*stride_q; + auto vsum = _mm512_mul_ps(vk[0], _mm512_loadu_ps(qr)); + for (int i = 0; i < D/16; ++i) { + vsum = _mm512_fmadd_ps(vk[i], _mm512_loadu_ps(qr + 16*i), vsum); + } + cache[k_step*m1 + l1] = _mm512_reduce_add_ps(vsum); + } + + inline void update_M_S(int j) { + if (softcap <= 0.0f) { + for (int l = 0; l < k_step/16; ++l) vk[l] = _mm512_mul_ps(vscale, _mm512_loadu_ps(cache + k_step*j + 16*l)); + } else { + auto v_softcap = _mm512_set1_ps(softcap); + for (int l = 0; l < k_step/16; ++l) { + auto val = _mm512_loadu_ps(cache + k_step*j + 16*l); + vk[l] = _mm512_mul_ps(v_softcap, v_tanh(_mm512_mul_ps(vscale, val))); + } + } + + float smax = reduce_T<_mm512_reduce_max_ps, _mm512_max_ps>(vk); + need_scaling[j] = 0; + if (smax > M[j]) { + if (M[j] > -INFINITY) { + float m = expf(M[j] - smax); + vms[j] = _mm512_set1_ps(m); + need_scaling[j] = 1; + S[j] *= m; + } else { + need_scaling[j] = 2; + S[j] = 0; + } + M[j] = smax; + } + auto vm = _mm512_set1_ps(M[j]); + for (int l = 0; l < k_step/16; ++l) { + vk[l] = v_expf(_mm512_sub_ps(vk[l], vm)); + _mm512_storeu_ps(cache + k_step*j + 16*l, vk[l]); + } + S[j] += reduce_T<_mm512_reduce_add_ps, _mm512_add_ps>(vk); + } + + inline void normalize_and_store(int j, const float * R, float * qkv) const { + GGML_ASSERT(S[j] > 0); + auto norm = _mm512_set1_ps(1/S[j]); + for (int i = 0; i < D/16; ++i) { + auto r = _mm512_loadu_ps(R + 16*i); + _mm512_storeu_ps(qkv + 16*i, _mm512_mul_ps(norm, r)); + } + } + + inline void normalize_and_store(int nq1, int stride_qkv, float * qkv) const { + auto R = qkv_cache; + for (int j = 0; j < nq1; ++j) { + normalize_and_store(j, R, qkv); + qkv += stride_qkv; + R += D; + } + } + + inline void normalize_and_store(int stride_qkv, float * qkv) const { + auto R = qkv_cache; + for (int j = 0; j < q_step; ++j) { + normalize_and_store(j, R, qkv); + qkv += stride_qkv; + R += D; + } + 
} + + template <bool small = is_small_head, class = std::enable_if<small>> + inline void mult_mask_kq(int stride_k, int stride_q, int stride_m, + const char * k, const float * q, const char * mask) { + __m512 qv[D/16]; + for (int l1 = 0; l1 < k_step; l1 += 2) { + auto kr1 = (const ggml_half *)(k + (l1 + 0)*stride_k); + auto kr2 = (const ggml_half *)(k + (l1 + 1)*stride_k); + for (int i = 0; i < D/16; ++i) vk[i+ 0] = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)kr1 + i)); + for (int i = 0; i < D/16; ++i) vk[i+D/16] = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)kr2 + i)); + for (int m1 = 0; m1 < q_step; ++m1) { + mult_mask_kq_one(l1, m1, stride_q, stride_m, q, mask, qv); + } + } + } + + template <bool small = is_small_head, class = std::enable_if<!small>> + inline void mult_mask_kq_l(int stride_k, int stride_q, int stride_m, + const char * k, const float * q, const char * mask) { + for (int l1 = 0; l1 < k_step; ++l1) { + auto kr = (const ggml_half *)(k + l1*stride_k); + for (int i = 0; i < D/16; ++i) vk[i] = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)kr + i)); + for (int m1 = 0; m1 < q_step; ++m1) { + mult_mask_kq_one(l1, m1, stride_q, stride_m, q, mask); + } + } + } + + template <bool small = is_small_head, class = std::enable_if<small>> + inline void mult_mask_kq(int nq, int stride_k, int stride_q, int stride_m, + const char * k, const float * q, const char * mask) { + __m512 qv[D/16]; + for (int l1 = 0; l1 < k_step; l1 += 2) { + auto kr1 = (const ggml_half *)(k + (l1 + 0)*stride_k); + auto kr2 = (const ggml_half *)(k + (l1 + 1)*stride_k); + for (int i = 0; i < D/16; ++i) vk[i+ 0] = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)kr1 + i)); + for (int i = 0; i < D/16; ++i) vk[i+D/16] = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)kr2 + i)); + for (int m1 = 0; m1 < nq; ++m1) { + mult_mask_kq_one(l1, m1, stride_q, stride_m, q, mask, qv); + } + } + } + + template <bool small = is_small_head, class = std::enable_if<!small>> + inline void mult_mask_kq_l(int nq, int stride_k, int stride_q, int stride_m, + const char * k, const float * q, const char * mask) { + for (int l1 = 0; l1 < k_step; ++l1) { + auto kr = (const ggml_half *)(k + l1*stride_k); + for (int i = 0; i < D/16; ++i) vk[i] = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)kr + i)); + for (int m1 = 0; m1 < nq; ++m1) { + mult_mask_kq_one(l1, m1, stride_q, stride_m, q, mask); + } + } + } + + inline void multiply_mask_kq(int stride_k, int stride_q, int stride_m, const char * k, const float * q, const char * mask) { + if constexpr (is_small_head) { + mult_mask_kq(stride_k, stride_q, stride_m, k, q, mask); + } + else { + mult_mask_kq_l(stride_k, stride_q, stride_m, k, q, mask); + } + for (int j = 0; j < q_step; ++j) { + update_M_S(j); + } + } + + inline void multiply_mask_kq(int nq, int stride_k, int stride_q, int stride_m, const char * k, const float * q, const char * mask) { + if constexpr (is_small_head) { + mult_mask_kq(nq, stride_k, stride_q, stride_m, k, q, mask); + } + else { + mult_mask_kq_l(nq, stride_k, stride_q, stride_m, k, q, mask); + } + for (int j = 0; j < nq; ++j) { + update_M_S(j); + } + } + + // This fails for head sizes of 80 and 112 as D/16 is odd, so we cannot do steps of 2 + // Hence, for now, we will not handle head sizes of 80 and 112 + inline void accumulate_qkv(int stride_v, const char * v) { + for (int i = 0; i < D/16; i += 2) { + for (int j = 0; j < q_step; ++j) { + if (need_scaling[j] == 2) { + vk[2*j+0] = vk[2*j+1] = _mm512_setzero_ps(); + } else { + auto R = qkv_cache + 
D*j; + vk[2*j+0] = _mm512_loadu_ps(R + 16*i); + vk[2*j+1] = _mm512_loadu_ps(R + 16*i + 16); + if (need_scaling[j] == 1) { + vk[2*j+0] = _mm512_mul_ps(vk[2*j+0], vms[j]); + vk[2*j+1] = _mm512_mul_ps(vk[2*j+1], vms[j]); + } + } + } + for (int l1 = 0; l1 < k_step; ++l1) { + auto vr = (const ggml_half *)(v + l1*stride_v); + auto v1 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*)vr+i+0)); + auto v2 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*)vr+i+1)); + for (int j = 0; j < q_step; ++j) { + auto vs = _mm512_set1_ps(cache[k_step*j + l1]); + vk[2*j+0] = _mm512_fmadd_ps(v1, vs, vk[2*j+0]); + vk[2*j+1] = _mm512_fmadd_ps(v2, vs, vk[2*j+1]); + } + } + for (int j = 0; j < q_step; ++j) { + auto R = qkv_cache + D*j; + _mm512_storeu_ps(R + 16*i, vk[2*j+0]); + _mm512_storeu_ps(R + 16*i + 16, vk[2*j+1]); + } + } + } + + template <int Nq = q_step, class = std::enable_if<Nq >= 2>> + inline void accumulate_qkv(int nq1, int stride_v, const char * v) { + for (int i = 0; i < D/16; i += 2) { + for (int j = 0; j < nq1; ++j) { + if (need_scaling[j] == 2) { + vk[2*j+0] = vk[2*j+1] = _mm512_setzero_ps(); + } else { + auto R = qkv_cache + D*j; + vk[2*j+0] = _mm512_loadu_ps(R + 16*i); + vk[2*j+1] = _mm512_loadu_ps(R + 16*i + 16); + if (need_scaling[j] == 1) { + vk[2*j+0] = _mm512_mul_ps(vk[2*j+0], vms[j]); + vk[2*j+1] = _mm512_mul_ps(vk[2*j+1], vms[j]); + } + } + } + for (int l1 = 0; l1 < k_step; ++l1) { + auto vr = (const ggml_half *)(v + l1*stride_v); + auto v1 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*)vr+i+0)); + auto v2 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*)vr+i+1)); + for (int j = 0; j < nq1; ++j) { + auto vs = _mm512_set1_ps(cache[k_step*j + l1]); + vk[2*j+0] = _mm512_fmadd_ps(v1, vs, vk[2*j+0]); + vk[2*j+1] = _mm512_fmadd_ps(v2, vs, vk[2*j+1]); + } + } + for (int j = 0; j < nq1; ++j) { + auto R = qkv_cache + D*j; + _mm512_storeu_ps(R + 16*i, vk[2*j+0]); + _mm512_storeu_ps(R + 16*i + 16, vk[2*j+1]); + } + } + } + + void compute(int nq1, int nk1, int stride_k, int stride_q, int stride_m, int stride_v, int stride_qkv, + const char * k, const float * q, const char * mask, const char * v, float * qkv) { + for (int i1 = 0; i1 < nq1/q_step; ++i1) { + init_qstep(); + auto kr = k; + auto vr = v; + auto mr = mask; + for (int k1 = 0; k1 < nk1/k_step; ++k1) { + multiply_mask_kq(stride_k, stride_q, stride_m, kr, q, mr); + accumulate_qkv(stride_v, vr); + kr += k_step*stride_k; + vr += k_step*stride_v; + mr += k_step*sizeof(ggml_half); + } + normalize_and_store(stride_qkv, qkv); + + q += q_step*stride_q; + mask += q_step*stride_m; + qkv += q_step*stride_qkv; + } + int n_left = nq1 - q_step*(nq1/q_step); + if (n_left > 0) { + init_qstep(); + auto kr = k; + auto vr = v; + auto mr = mask; + for (int k1 = 0; k1 < nk1/k_step; ++k1) { + multiply_mask_kq(n_left, stride_k, stride_q, stride_m, kr, q, mr); + accumulate_qkv(n_left, stride_v, vr); + kr += k_step*stride_k; + vr += k_step*stride_v; + mr += k_step*sizeof(ggml_half); + } + normalize_and_store(n_left, stride_qkv, qkv); + } + } + + float cache[q_step*k_step]; + float qkv_cache[D*q_step]; + float S[q_step], M[q_step]; + int need_scaling[q_step]; + __m512 vms[q_step]; + __m512 vk[vk_size]; + const __m512 vscale; + const float softcap; + const ggml_half h_inf; + + typedef __m512 (*combine_t)(__m512, __m512); + typedef float (*reduce_t)(__m512); + template <reduce_t Op, combine_t Op_combine> + static inline float reduce_T(const __m512 * vals) { + float result; + if constexpr (k_step/16 == 1) { + result = Op(vals[0]); + } + else if constexpr 
(k_step/16 == 2) { + result = Op(Op_combine(vals[0], vals[1])); + } + else { + auto vmax = Op_combine(vals[0], vals[1]); + for (int l = 2; l < k_step/16; ++l) vmax = Op_combine(vmax, vals[l]); + result = Op(vmax); + } + return result; + } +}; + +template <int D, int q_step, int k_step> +inline void iqk_flash_helper_T(int nq1, int nk1, int stride_q, int stride_k, int stride_v, int stride_m, int stride_qkv, + const float * q, const char * k, const char * v, const char * mask, + float scale, float softcap, float * qkv) { + if (nq1 >= q_step) { + FlashAttn<D, q_step, k_step> fa(scale, softcap); + fa.compute(nq1, nk1, stride_k, stride_q, stride_m, stride_v, stride_qkv, + (const char *)k, q, (const char *)mask, (const char *)v, qkv); + } else { + FlashAttn<D, 1, k_step> fa(scale, softcap); + fa.compute(nq1, nk1, stride_k, stride_q, stride_m, stride_v, stride_qkv, + (const char *)k, q, (const char *)mask, (const char *)v, qkv); + } +} +} + +bool iqk_flash_attn_noalibi(int D, // head size + int nq1, // number of columns in q + int nk1, // number of rows in k + int stride_q, // distance between q columns in bytes + int stride_k, // distance between k rows in bytes + int stride_v, // distance between v rows in bytes + int stride_m, // distance between mask rows (in bytes + int stride_qkv, // distance between rows in mask (in bytes) + const float * q, // q matrix. + const void * k, // k matrix. Assumed to be fp16, nq x nk elements + const void * v, // v matrix. Assumed to be fp16, nq x nk elements + const void * mask, // mask. If not null, assumed to be fp16. nq x nk elements + float scale, // scale applied before softmax + float softcap, // if > 0, a "soft-cap" operation is applied before softmax + float * qkv) { // v*softmax(scale*(k*q)) + + if (!mask || nk1%32 != 0) return false; // the implementation assumes mask is not null and nk is a multiple of 32 + + auto ck = (const char *)k; + auto cv = (const char *)v; + auto cm = (const char *)mask; + + stride_q /= sizeof(float); // q stride as float + + switch (D) { + case 64: + iqk_flash_helper_T< 64, 4, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; + // Disable until we fix accumulate_qkv for odd D/16 + //case 80: + // iqk_flash_helper_T< 80, 4, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; + case 96: + iqk_flash_helper_T< 96, 4, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; + // Disable until we fix accumulate_qkv for odd D/16 + //case 112: + // iqk_flash_helper_T<112, 4, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; + case 128: + iqk_flash_helper_T<128, 8, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; + case 256: + iqk_flash_helper_T<256, 8, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; + default: + return false; + } + + return true; +} + +#else +// TODO +bool iqk_flash_attn_noalibi([[maybe_unused]] int D, // head size + [[maybe_unused]] int nq, // number of columns in q + [[maybe_unused]] int nk, // number of rows in k + [[maybe_unused]] int stride_q, // distance between q columns in bytes + [[maybe_unused]] int stride_k, // distance between k rows in bytes + [[maybe_unused]] int stride_v, // distance between v rows in bytes + [[maybe_unused]] int stride_m, // distance 
between mask rows (in bytes + [[maybe_unused]] int stride_qkv, // distance between rows in mask (in bytes) + [[maybe_unused]] const float * q, // q matrix. + [[maybe_unused]] const void * k, // k matrix. Assumed to be fp16, nq x nk elements + [[maybe_unused]] const void * v, // v matrix. Assumed to be fp16, nq x nk elements + [[maybe_unused]] const void * mask, // mask. If not null, assumed to be fp16. nq x nk elements + [[maybe_unused]] float scale, // scale applied before softmax + [[maybe_unused]] float softcap, // if > 0, a "soft-cap" operation is applied before softmax + [[maybe_unused]] float * qkv) { // v*softmax(scale*(k*q)) + return false; +} + +#endif + #else // IQK_IMPLEMENT bool iqk_mul_mat(int, long, long, long, int, const void *, long, int, const void *, long, float *, long, int, int) { @@ -5926,4 +6495,22 @@ bool iqk_mul_mat_moe(long, long, long, int, int, const void *, long, int, const return false; } +bool iqk_flash_attn_noalibi([[maybe_unused]] int D, // head size + [[maybe_unused]] int nq, // number of columns in q + [[maybe_unused]] int nk, // number of rows in k + [[maybe_unused]] int stride_q, // distance between q columns in bytes + [[maybe_unused]] int stride_k, // distance between k rows in bytes + [[maybe_unused]] int stride_v, // distance between v rows in bytes + [[maybe_unused]] int stride_m, // distance between mask rows (in bytes + [[maybe_unused]] int stride_qkv, // distance between rows in mask (in bytes) + [[maybe_unused]] const float * q, // q matrix. + [[maybe_unused]] const void * k, // k matrix. Assumed to be fp16, nq x nk elements + [[maybe_unused]] const void * v, // v matrix. Assumed to be fp16, nq x nk elements + [[maybe_unused]] const void * mask, // mask. If not null, assumed to be fp16. nq x nk elements + [[maybe_unused]] float scale, // scale applied before softmax + [[maybe_unused]] float softcap, // if > 0, a "soft-cap" operation is applied before softmax + [[maybe_unused]] float * qkv) { // v*softmax(scale*(k*q)) + return false; +} + #endif diff --git a/ggml/src/iqk/iqk_mul_mat.h b/ggml/src/iqk/iqk_mul_mat.h index 6bed5f5a..de50bc6b 100644 --- a/ggml/src/iqk/iqk_mul_mat.h +++ b/ggml/src/iqk/iqk_mul_mat.h @@ -21,6 +21,21 @@ bool iqk_mul_mat_moe(long Nx, long Ny, long ne00, int ne11, int typeB, const void * B, long strideB, float * C, long nb1, long nb2, const void * vrow_mapping, int ith, int nth); +bool iqk_flash_attn_noalibi(int D, // head size + int nq, // number of columns in q + int nk, // number of rows in k + int stride_q, // distance between q columns in bytes + int stride_k, // distance between k rows in bytes + int stride_v, // distance between v rows in bytes + int stride_m, // distance between mask rows (in bytes + int stride_qkv, // distance between rows in mask (in bytes) + const float * q, // q matrix. + const void * k, // k matrix. Assumed to be fp16, nq x nk elements + const void * v, // v matrix. Assumed to be fp16, nq x nk elements + const void * mask, // mask. If not null, assumed to be fp16. nq x nk elements + float scale, // scale applied before softmax + float softcap, // if > 0, a "soft-cap" operation is applied before softmax + float * qkv); // v*softmax(scale*(k*q)) #ifdef __cplusplus } |
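
Finally, a hedged usage sketch for the new `iqk_flash_attn_noalibi` entry point declared in `iqk_mul_mat.h`. The wrapper and the densely packed, row-major layout are illustrative assumptions, not part of the commit; note that the `ggml.c` call site passes `stride_qkv` in float elements (`ne1*nb1/sizeof(float)`) rather than bytes, and that the kernel requires a non-null f16 mask and `nk` divisible by 32.

```cpp
#include <cstdint>
#include "iqk_mul_mat.h"   // path depends on how ggml is vendored

// Illustrative wrapper (not in the commit) for one attention head with contiguous rows.
bool run_one_head(int D, int n_q, int n_kv,
                  const float    * q,     // n_q  x D, f32
                  const uint16_t * k,     // n_kv x D, f16 (ggml_half)
                  const uint16_t * v,     // n_kv x D, f16
                  const uint16_t * mask,  // n_q  x n_kv, f16; -INFINITY marks masked positions
                  float scale, float softcap,
                  float * out) {          // n_q x D, f32
    return iqk_flash_attn_noalibi(D, n_q, n_kv,
            int(D*sizeof(float)),         // stride_q in bytes (converted to floats internally)
            int(D*sizeof(uint16_t)),      // stride_k in bytes
            int(D*sizeof(uint16_t)),      // stride_v in bytes
            int(n_kv*sizeof(uint16_t)),   // stride_m in bytes
            D,                            // stride_qkv in floats, following the ggml.c call site
            q, k, v, mask, scale, softcap, out);
}
```

Only the AVX-512 (`HAVE_FANCY_SIMD`) build implements the kernel, and only for head sizes 64, 96, 128 and 256 (80 and 112 stay disabled until `accumulate_qkv` handles odd `D/16`); every other configuration returns `false`, so callers keep the generic path as a fallback.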