author    Kawrakow <iwankawrakow@gmail.com>    2025-03-07 09:46:58 +0200
committer    GitHub <noreply@github.com>    2025-03-07 09:46:58 +0200
commit    3d85a1d66302989401f92a5ae347577b03cbdaa7 (patch)
tree    7d9ea15568de65954ebddbf71792ad781841fd7f /ggml/src/iqk/iqk_flash_attn.cpp
parent    c67a37b251fc22b0f8b8313ea5c76a73ff6ed49f (diff)
Better FlashMLA (#243)
* This is a better FA for TG. It should benefit MLA and GQA. Tested to work with DeepSeek-Lite MLA, not yet for GQA. For tg64@pp8192 it is ~13% faster than MLA without FA, and 57% faster than the main branch FA.

* WIP

* Cleanup

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
Diffstat (limited to 'ggml/src/iqk/iqk_flash_attn.cpp')
-rw-r--r--    ggml/src/iqk/iqk_flash_attn.cpp    194
1 file changed, 194 insertions, 0 deletions
diff --git a/ggml/src/iqk/iqk_flash_attn.cpp b/ggml/src/iqk/iqk_flash_attn.cpp
new file mode 100644
index 00000000..fecd818b
--- /dev/null
+++ b/ggml/src/iqk/iqk_flash_attn.cpp
@@ -0,0 +1,194 @@
+#include "iqk_config.h"
+#include "iqk_mul_mat.h"
+#include "iqk_flash_impl.h"
+
+#ifdef IQK_IMPLEMENT
+
+#include <algorithm>
+#include <cstdio>
+#include <vector>
+#include <cstdint>
+#include <cstring>
+#include <cmath>
+
+namespace {
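+// Subtraction-based GCD. Note: this assumes both arguments are strictly
+// positive (a zero argument would never terminate); both call sites below
+// guarantee that.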
+inline uint32_t simple_gcd(uint32_t a, uint32_t b) {
+    while (a != b) {
+        if (a > b) a -= b;
+        else b -= a;
+    }
+    return a;
+}
+}
+
+// TODO: get the ggml_type enum here without pollution
+//
+bool iqk_flash_attn_noalibi(int type_q, int type_mask, float max_bias,
+                            int neq3, int neq2, long nbq3, long nbq2,
+                            int nek3, int nek2, long nbk3, long nbk2,
+                            int nev3, int nev2, long nbv3, long nbv2,
+                            int ne2, int ne1, long nb1,
+                            int int_type_k,      // type of k
+                            int int_type_v,      // type of v
+                            int Dk,              // K head size
+                            int Dv,              // V head size
+                            int neq1,            // number of columns in q
+                            int nek1,            // number of rows in k
+                            int stride_q,        // distance between q columns in bytes
+                            int stride_k,        // distance between k rows in bytes
+                            int stride_v,        // distance between v rows in bytes
+                            int stride_m,        // distance between mask rows (in bytes)
+                            const void * q,      // q matrix, neq1 columns of Dk floats
+                            const void * k,      // k matrix, nek1 rows of Dk elements of type int_type_k
+                            const void * v,      // v matrix, nek1 rows of Dv elements of type int_type_v
+                            const void * mask,   // mask; if not null, assumed to be fp16, neq1 x nek1 elements
+                            float scale,         // scale applied before softmax
+                            float softcap,       // if > 0, a "soft-cap" operation is applied before softmax
+                            float * qkv,         // v*softmax(scale*(k*q))
+                            [[maybe_unused]] void * work_buffer, [[maybe_unused]] barrier_t barrier, [[maybe_unused]] void * barrier_data,
+                            int ith, int nth) {
+
+    if (type_q != 0 || type_mask != 1 || max_bias > 0) return false;
+
+    int rk2 = neq2/nek2;
+    int rv2 = neq2/nev2;
+    int rk3 = neq3/nek3;
+    int rv3 = neq3/nev3;
+
+    // Getting confused all the time about where to load data from and store the results to
+    // (especially when combining the results from the threads).
+    // So, for now, making it work just for MLA (nek2 = 1).
+    // I think it would also speed up things for GQA, but I'm leaving this for another day.
+    if (neq3 == 1 && rk2 > 1 && neq1 == 1 && nth >= 1 && nek1/32 > 1 && nek2 == 1) {
+        int nstep_k = nek1/32;
+        int gcd_k = simple_gcd(nstep_k, nth);
+        if (gcd_k >= 1) {
+            int nth_k = nth/gcd_k;
+            if (rk2%nth_k == 0) {
+                int ith_k = ith%gcd_k;
+                int ith_q = ith/gcd_k;
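+                // Worked example (hypothetical numbers): nth = 12, nek1 = 8192 gives
+                // nstep_k = 256, gcd_k = gcd(256, 12) = 4, nth_k = 3. The 12 threads
+                // form 4 groups: ith_k picks which quarter of the K/V rows a thread
+                // sees, ith_q picks which block of rk2/3 heads it processes (this
+                // path requires rk2 to be a multiple of nth_k).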
+                auto kth = (const char *)k + ith_k*(nek1/gcd_k)*stride_k;
+                auto vth = (const char *)v + ith_k*(nek1/gcd_k)*stride_v;
+                auto qth = (const char *)q + ith_q*(rk2/nth_k)*nbq2;
+                auto mth = (const char *)mask + ith_k*(nek1/gcd_k)*sizeof(uint16_t); // we don't have ggml_half available here
+                auto work = (char *)work_buffer;
+
+                // Each thread will produce a result of size Dv*(rk2/nth_k)*sizeof(float).
+                // In addition, we need M, S for the rk2/nth_k rows the thread is processing
+                // => (Dv + 2)*rk2/nth_k*sizeof(float). We use (Dv + 16) instead to make sure threads are not
+                // writing onto the same cache line.
+                auto size_thread = (Dv + 16)*rk2/nth_k*sizeof(float);
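+                // E.g. (hypothetical numbers): Dv = 512 and rk2/nth_k = 8 give
+                // size_thread = (512 + 16)*8*4 = 16896 bytes of workspace per thread.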
+                auto result_buffer = work;
+                auto work_this_thread = (float *)(result_buffer + ith*size_thread);
+                if (!iqk_flash_attn_impl(int_type_k, int_type_v,
+                        Dk, Dv, rk2/nth_k, nek1/gcd_k, nbq2, stride_k, stride_v, 0, Dv, //Dk*sizeof(uint16_t), Dv,
+                        (const float *)qth, (const void *)kth, (const void *)vth, (const void *)mth,
+                        scale, softcap,
+                        work_this_thread, work_this_thread + (Dv+0)*rk2/nth_k, work_this_thread + (Dv+1)*rk2/nth_k)) return false;
+
+                barrier(barrier_data);
+
+                // TODO: simdify this
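+                // The per-slice partials are combined with the usual online-softmax
+                // merge: slice j1 carries (R_j1, M_j1, S_j1), where M_j1 is the max
+                // logit seen and S_j1 = sum(exp(logit - M_j1)). With M = max over M_j1,
+                // the merged sums are S = sum exp(M_j1 - M)*S_j1 and
+                // R = sum exp(M_j1 - M)*R_j1; the final row is R/S.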
+                for (int j = ith; j < rk2; j += nth) {
+                    auto Racc = qkv + j*nb1/sizeof(float);
+                    float M = -INFINITY, S = 0;
+                    int jth_q = j/(rk2/nth_k);
+                    int jj = j%(rk2/nth_k);
+                    for (int j1 = 0; j1 < rk2/nth_k; ++j1) {
+                        auto R = (const float *)(result_buffer + (jth_q*(rk2/nth_k) + j1)*size_thread);
+                        auto Mj = R + Dv*rk2/nth_k;
+                        auto Sj = Mj + rk2/nth_k;
+                        R += jj*Dv;
+                        if (Mj[jj] == -INFINITY) continue;
+                        if (Mj[jj] > M) {
+                            if (M == -INFINITY) {
+                                std::memcpy(Racc, R, Dv*sizeof(float));
+                                S = Sj[jj];
+                            } else {
+                                float c = exp(M - Mj[jj]);
+                                S = c*S + Sj[jj];
+                                for (int i = 0; i < Dv; ++i) Racc[i] = c*Racc[i] + R[i];
+                            }
+                            M = Mj[jj];
+                        } else {
+                            float c = exp(Mj[jj] - M);
+                            S += c*Sj[jj];
+                            for (int i = 0; i < Dv; ++i) Racc[i] += c*R[i];
+                        }
+                    }
+                    float norm = S > 0 ? 1/S : 1;
+                    for (int i = 0; i < Dv; ++i) Racc[i] *= norm;
+                }
+                return true;
+
+            }
+        }
+    }
+
+    // I keep changing my mind about the best strategy for splitting the threads when
+    // processing multiple heads. This is my current thinking; the commented-out code
+    // below was the previous approach.
+    int ntg = nth/simple_gcd(neq2*neq3, nth);
+    int neq1g = (neq1 + ntg - 1)/ntg;
+    //int64_t work_per_slice = D*nek1*neq1;
+    //int ntg = 1;
+    //
+    // When neq1 is large, it is better to have more than one thread process one (iq2,iq3) matrix.
+    // But we also want each thread to process the same number of rows, so neq1 must be a multiple of
+    // the number of threads processing the (iq2, iq3) matrix.
+    //
+    //if (neq1 >= 8*nth) {
+    //    if (nth%8 == 0 && neq1%8 == 0 && work_per_slice >= (1 << 23)) ntg = 8;
+    //    else if (nth%4 == 0 && neq1%4 == 0 && work_per_slice >= (1 << 21)) ntg = 4;
+    //    else if (nth%2 == 0 && neq1%2 == 0 && work_per_slice >= (1 << 19)) ntg = 2;
+    //}
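+    // Worked example (hypothetical numbers): neq2*neq3 = 32 head slices and nth = 24
+    // give ntg = 24/gcd(32, 24) = 3, so nth/ntg = 8 slices are processed concurrently,
+    // with 3 threads sharing each slice and each thread taking neq1g = ceil(neq1/3)
+    // consecutive q columns.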
+    int counter = 0;
+    for (int64_t iq3 = 0; iq3 < neq3; iq3++) {
+        for (int64_t iq2 = 0; iq2 < neq2; iq2++) {
+            if (counter++ % (nth/ntg) == ith/ntg) {
+                int iq1 = (ith%ntg)*neq1g;
+                int this_neq1 = std::min(neq1g, neq1-iq1);
+                if (!iqk_flash_attn_impl(int_type_k, int_type_v,
+                        Dk, Dv, this_neq1, nek1, stride_q, stride_k, stride_v, stride_m, ne1*nb1/sizeof(float),
+                        (const float *)((const char *)q + iq2*nbq2 + iq3*nbq3 + iq1*stride_q),
+                        (const void *)((const char *)k + iq2/rk2*nbk2 + iq3/rk3*nbk3),
+                        (const void *)((const char *)v + iq2/rv2*nbv2 + iq3/rv3*nbv3),
+                        (const void *)((const char *)mask + iq1*stride_m),
+                        scale, softcap,
+                        (float *)((char *)qkv + (iq3*ne2*ne1 + iq2 + iq1*ne1)*nb1), nullptr, nullptr)) return false;
+            }
+        }
+    }
+
+    return true;
+}
+
+#else
+
+bool iqk_flash_attn_noalibi([[maybe_unused]] int type_q, [[maybe_unused]] int type_mask, [[maybe_unused]] float max_bias,
+                            [[maybe_unused]] int neq3, [[maybe_unused]] int neq2, [[maybe_unused]] long nbq3, [[maybe_unused]] long nbq2,
+                            [[maybe_unused]] int nek3, [[maybe_unused]] int nek2, [[maybe_unused]] long nbk3, [[maybe_unused]] long nbk2,
+                            [[maybe_unused]] int nev3, [[maybe_unused]] int nev2, [[maybe_unused]] long nbv3, [[maybe_unused]] long nbv2,
+                            [[maybe_unused]] int ne2, [[maybe_unused]] int ne1, [[maybe_unused]] long nb1,
+                            [[maybe_unused]] int int_type_k,    // type of k
+                            [[maybe_unused]] int int_type_v,    // type of v
+                            [[maybe_unused]] int Dk,            // K head size
+                            [[maybe_unused]] int Dv,            // V head size
+                            [[maybe_unused]] int neq1,          // number of columns in q
+                            [[maybe_unused]] int nek1,          // number of rows in k
+                            [[maybe_unused]] int stride_q,      // distance between q columns in bytes
+                            [[maybe_unused]] int stride_k,      // distance between k rows in bytes
+                            [[maybe_unused]] int stride_v,      // distance between v rows in bytes
+                            [[maybe_unused]] int stride_m,      // distance between mask rows (in bytes)
+                            [[maybe_unused]] const void * q,    // q matrix
+                            [[maybe_unused]] const void * k,    // k matrix
+                            [[maybe_unused]] const void * v,    // v matrix
+                            [[maybe_unused]] const void * mask, // mask; if not null, assumed to be fp16
+                            [[maybe_unused]] float scale,       // scale applied before softmax
+                            [[maybe_unused]] float softcap,     // if > 0, a "soft-cap" operation is applied before softmax
+                            [[maybe_unused]] float * qkv,       // v*softmax(scale*(k*q))
+                            [[maybe_unused]] void * work_buffer, [[maybe_unused]] barrier_t barrier, [[maybe_unused]] void * barrier_data,
+                            [[maybe_unused]] int ith, [[maybe_unused]] int nth) {
+    return false;
+}
+
+#endif
+