diff options
Diffstat (limited to 'ggml/src/iqk/iqk_flash_attn.cpp')
-rw-r--r-- | ggml/src/iqk/iqk_flash_attn.cpp | 147 |
1 files changed, 110 insertions, 37 deletions
diff --git a/ggml/src/iqk/iqk_flash_attn.cpp b/ggml/src/iqk/iqk_flash_attn.cpp index 0de68b94..fd0d5dd0 100644 --- a/ggml/src/iqk/iqk_flash_attn.cpp +++ b/ggml/src/iqk/iqk_flash_attn.cpp @@ -25,6 +25,24 @@ inline uint32_t simple_gcd(uint32_t a, uint32_t b) { } return a; } +inline void accumulate_qkv(int Dv, float& M, float& S, float Mj, float Sj, float * Racc, const float * R) { + if (Mj == -INFINITY) return; + if (Mj > M) { + if (M == -INFINITY) { + std::memcpy(Racc, R, Dv*sizeof(float)); + S = Sj; + } else { + float c = exp(M - Mj); + S = c*S + Sj; + for (int i = 0; i < Dv; ++i) Racc[i] = c*Racc[i] + R[i]; + } + M = Mj; + } else { + float c = exp(Mj - M); + S += c*Sj; + for (int i = 0; i < Dv; ++i) Racc[i] += c*R[i]; + } +} } // TODO: get the ggml_type enum here without polution @@ -34,7 +52,7 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float int nek3, int nek2, long nbk3, long nbk2, int nev3, int nev2, long nbv3, long nbv2, int ne2, int ne1, long nb1, - int int_type_k, // type of k + int int_type_k_in, // type of k int int_type_v, // type of v int Dk, // K head size int Dv, // V head size @@ -51,7 +69,7 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float float scale, // scale applied before softmax float softcap, // if > 0, a "soft-cap" operation is applied before softmax float * qkv, // v*softmax(scale*(k*q)) - [[maybe_unused]] void * work_buffer, [[maybe_unused]] barrier_t barrier, [[maybe_unused]] void * barrier_data, + [[maybe_unused]] void * work_buffer_in, [[maybe_unused]] barrier_t barrier, [[maybe_unused]] void * barrier_data, int ith, int nth) { if (type_q != 0 || type_mask != 1 || max_bias > 0) return false; @@ -61,6 +79,29 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float int rk3 = neq3/nek3; int rv3 = neq3/nev3; + int int_type_k = int_type_k_in; + auto work_buffer = work_buffer_in; + if (neq1 >= 8 || rk2 >= 8) { + uint64_t row_size = 0; + work_buffer = iqk_repack_k(int_type_k, Dk, nek1, nek2, nek3, stride_k, nbk2, nbk3, k, work_buffer_in, ith, nth, int_type_k, row_size); + if (int_type_k != int_type_k_in) { + stride_k = row_size; + nbk2 = stride_k*nek1; + nbk3 = nbk2*nek2; + k = work_buffer_in; + barrier(barrier_data); + } + } + //uint64_t row_size = 0; + //auto work_buffer = iqk_repack_k(int_type_k, Dk, nek1, nek2, nek3, stride_k, nbk2, nbk3, k, work_buffer_in, ith, nth, int_type_k, row_size); + //if (int_type_k != int_type_k_in) { + // stride_k = row_size; + // nbk2 = stride_k*nek1; + // nbk3 = nbk2*nek2; + // k = work_buffer_in; + // barrier(barrier_data); + //} + // Getting confused all the time about where to load data from and store the results to // (especially when combining the results from the threads). // So, for now, making it work just for MLA (nek2 = 1). @@ -128,22 +169,7 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float auto Mj = R + Dv*nq_this_j; auto Sj = Mj + nq_this_j; R += jj*Dv; - if (Mj[jj] == -INFINITY) continue; - if (Mj[jj] > M) { - if (M == -INFINITY) { - std::memcpy(Racc, R, Dv*sizeof(float)); - S = Sj[jj]; - } else { - float c = exp(M - Mj[jj]); - S = c*S + Sj[jj]; - for (int i = 0; i < Dv; ++i) Racc[i] = c*Racc[i] + R[i]; - } - M = Mj[jj]; - } else { - float c = exp(Mj[jj] - M); - S += c*Sj[jj]; - for (int i = 0; i < Dv; ++i) Racc[i] += c*R[i]; - } + accumulate_qkv(Dv, M, S, Mj[jj], Sj[jj], Racc, R); } float norm = S > 0 ? 1/S : 1; for (int i = 0; i < Dv; ++i) Racc[i] *= norm; @@ -154,10 +180,72 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float } if (neq3 == 1 && rk2 > 1 && rk2 == rv2 && neq1 == 1 && nth >= 1 && nek2*nek1 >= 32*nth) { - int nk = std::max(1, 32 * (nek2*nek1/(32*nth))); + auto result_size = (Dv + 16)*rk2*sizeof(float); + int gcd = simple_gcd(nek2, nth); + if (false && gcd > 1) { + int nth_g = nth/gcd; + int ith_g = ith%nth_g; + int nek1_32 = nek1/32; + int nek1_pt = (nek1_32 + nth_g - 1)/nth_g; + int ith_mid = nth_g; + if (nek1_pt*nth_g > nek1_32) { + ith_mid = nek1_32 - nth_g*(nek1_pt - 1); + } + nek1_pt *= 32; + int nek1_mid = ith_mid*nek1_pt; + int nek1_thread = ith_g < ith_mid ? nek1_pt : nek1_pt - 32; + for (int ik02 = ith/nth_g; ik02 < nek2; ik02 += gcd) { + int ik01 = ith_g < ith_mid ? ith_g*nek1_pt : nek1_mid + (ith_g - ith_mid)*nek1_thread; + auto this_result = (float *)((char *)work_buffer + (ik02*nth_g + ith_g)*result_size); + auto this_q = (const float *)((const char *)q + ik02*rk2*nbq2); + auto this_k = (const char *)k + ik01*stride_k + ik02*nbk2; + auto this_v = (const char *)v + ik01*stride_v + ik02*nbv2; + auto this_m = (const char *)mask + ik01*sizeof(uint16_t); // we don't have ggml_half available here + if (!iqk_flash_attn_impl(int_type_k, int_type_v, + Dk, Dv, rk2, nek1_thread, nbq2, stride_k, stride_v, 0, Dv, + this_q, (const void *)this_k, (const void *)this_v, (const void *)this_m, + scale, softcap, this_result, this_result + (Dv+0)*rk2, this_result + (Dv+1)*rk2)) return false; + } + + barrier(barrier_data); + + for (int iq2 = ith; iq2 < neq2; iq2 += nth) { + int ik02 = iq2/rk2; + int il = iq2 - ik02*rk2; + auto Racc = qkv + iq2*nb1/sizeof(float); + float M = -INFINITY, S = 0; + for (int ig = 0; ig < nth_g; ++ig) { + int istep_k = ik02*nth_g + ig; + auto this_result = (float *)((char *)work_buffer + istep_k*result_size); + const float * R = this_result + il*Dv; + const float * Mj = this_result + Dv*rk2; + const float * Sj = Mj + rk2; + accumulate_qkv(Dv, M, S, Mj[il], Sj[il], Racc, R); + } + float norm = S > 0 ? 1/S : 1; + for (int i = 0; i < Dv; ++i) Racc[i] *= norm; + } + return true; + } + int nth_k = nth/gcd; + int nek2_k = nek2/gcd; + int nchunk = nek2_k*nek1/32; + int npt = (nchunk + nth_k - 1)/nth_k; + int nk; + if (npt*nth_k == nchunk) { + nk = 32 * (nek2*nek1/(32*nth)); + } else { + //int nm = std::max(1, npt/8); + int nm = 1; + while (true) { + if (nm*4 >= npt) break; + nm *= 2; + } + nk = 32*nm; + } + //int nk = 32 * (nek2*nek1/(32*nth)); int nkk = (nek1 + nk - 1)/nk; int nstep_k = nek2*nkk; - auto result_size = (Dv + 16)*rk2*sizeof(float); //if (ith == 0) printf("rk2 = %d, nek1 = %d, nek2 = %d, nk = %d, nkk = %d, nstep_k = %d\n", (int)rk2, (int)nek1, (int)nek2, nk, nkk, nstep_k); for (int istep_k = ith; istep_k < nstep_k; istep_k += nth) { int ik02 = istep_k/nkk; @@ -183,7 +271,7 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float int ik02 = iq2/rk2; int il = iq2 - ik02*rk2; auto Racc = qkv + iq2*nb1/sizeof(float); - std::memset(Racc, 0, Dv*sizeof(float)); + //std::memset(Racc, 0, Dv*sizeof(float)); float M = -INFINITY, S = 0; for (int ikk = 0; ikk < nkk; ++ikk) { int istep_k = ik02*nkk + ikk; @@ -191,22 +279,7 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float const float * R = this_result + il*Dv; const float * Mj = this_result + Dv*rk2; const float * Sj = Mj + rk2; - if (Mj[il] == -INFINITY) continue; - if (Mj[il] > M) { - if (M == -INFINITY) { - std::memcpy(Racc, R, Dv*sizeof(float)); - S = Sj[il]; - } else { - float c = exp(M - Mj[il]); - S = c*S + Sj[il]; - for (int i = 0; i < Dv; ++i) Racc[i] = c*Racc[i] + R[i]; - } - M = Mj[il]; - } else { - float c = exp(Mj[il] - M); - S += c*Sj[il]; - for (int i = 0; i < Dv; ++i) Racc[i] += c*R[i]; - } + accumulate_qkv(Dv, M, S, Mj[il], Sj[il], Racc, R); } float norm = S > 0 ? 1/S : 1; for (int i = 0; i < Dv; ++i) Racc[i] *= norm; |