From dc023bc3be1a7ac42d1030f86c4d77563a019286 Mon Sep 17 00:00:00 2001 From: Kawrakow <48489457+ikawrakow@users.noreply.github.com> Date: Sun, 1 Sep 2024 16:08:21 +0300 Subject: Zen4 Flash Attention (#32) * Zen4 flash attention: moving useful parts from the kq_fused_softmax branch * Add flash attention with soft-cap and fix D = 256 case * Flash attention refinements * Update FlashAttn comment --------- Co-authored-by: Iwan Kawrakow --- ggml/src/ggml.c | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) (limited to 'ggml/src/ggml.c') diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index cebac584..4546eac3 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -16149,6 +16149,38 @@ static void ggml_compute_forward_flash_attn_ext_f16( scale /= softcap; } +#if GGML_USE_IQK_MULMAT + if (max_bias <= 0.0f && q->type == GGML_TYPE_F32 && k->type == GGML_TYPE_F16 && v->type == GGML_TYPE_F16 && + mask && mask->type == GGML_TYPE_F16) { + int64_t work_per_slice = D*nek1*neq1; + int ntg = 1; + if (nth%8 == 0 && work_per_slice >= (1 << 23)) ntg = 8; + else if (nth%4 == 0 && work_per_slice >= (1 << 21)) ntg = 4; + else if (nth%2 == 0 && work_per_slice >= (1 << 19)) ntg = 2; + if ((neq2*neq3)%(nth/ntg) == 0) { + //if (ith == 0) printf("%s: D = %d, neq2 = %d, neq1 = %d, nek1 = %d\n", __func__, (int)D, (int)neq2, (int)neq1, (int)nek1); + int counter = 0; + for (int64_t iq3 = 0; iq3 < neq3; iq3++) { + for (int64_t iq2 = 0; iq2 < neq2; iq2++) { + if (counter++ % (nth/ntg) == ith/ntg) { + int iq1 = (ith%ntg)*neq1/ntg; + if (!iqk_flash_attn_noalibi(D, neq1/ntg, nek1, q->nb[1], k->nb[1], v->nb[1], mask->nb[1], ne1*nb1/sizeof(float), + (const float *)((const char *)q->data + iq2*q->nb[2] + iq3*q->nb[3] + iq1*q->nb[1]), + (const void *)((const char *)k->data + iq2/rk2*k->nb[2] + iq3/rk3*k->nb[3]), + (const void *)((const char *)v->data + iq2/rv2*v->nb[2] + iq3/rv3*v->nb[3]), + (const void *)((const char *)mask->data + iq1*mask->nb[1]), + scale, softcap, + (float *)((char *) dst->data + (iq3*ne2*ne1 + iq2 + iq1*ne1)*nb1))) goto IQK_Flash_Attn_NotAvailable; + } + } + } + return; + } +IQK_Flash_Attn_NotAvailable:; + } + +#endif + const uint32_t n_head = neq2; const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head)); -- cgit v1.2.3