diff options
Diffstat (limited to 'ggml/src/iqk/iqk_flash_attn.cpp')
-rw-r--r-- | ggml/src/iqk/iqk_flash_attn.cpp | 61 |
1 files changed, 61 insertions, 0 deletions
diff --git a/ggml/src/iqk/iqk_flash_attn.cpp b/ggml/src/iqk/iqk_flash_attn.cpp index e302c944..3f1d6dc2 100644 --- a/ggml/src/iqk/iqk_flash_attn.cpp +++ b/ggml/src/iqk/iqk_flash_attn.cpp @@ -153,6 +153,67 @@ bool iqk_flash_attn_noalibi(int type_q, int type_mask, float max_bias, } } + if (neq3 == 1 && rk2 > 1 && rk2 == rv2 && neq1 == 1 && nth >= 1 && nek2*nek1 >= 32*nth) { + int nk = 32 * (nek2*nek1/(32*nth)); + int nkk = (nek1 + nk - 1)/nk; + int nstep_k = nek2*nkk; + auto result_size = (Dv + 16)*rk2*sizeof(float); + //if (ith == 0) printf("rk2 = %d, nek1 = %d, nek2 = %d, nk = %d, nkk = %d, nstep_k = %d\n", (int)rk2, (int)nek1, (int)nek2, nk, nkk, nstep_k); + for (int istep_k = ith; istep_k < nstep_k; istep_k += nth) { + int ik02 = istep_k/nkk; + int ik01 = nk*(istep_k - ik02*nkk); + int this_nk = ik01 + nk <= nek1 ? nk : nek1 - ik01; + if (this_nk <= 0) break; + auto this_result = (float *)((char *)work_buffer + istep_k*result_size); + auto this_q = (const float *)((const char *)q + ik02*rk2*nbq2); + auto this_k = (const char *)k + ik01*stride_k + ik02*nbk2; + auto this_v = (const char *)v + ik01*stride_v + ik02*nbv2; + auto this_m = (const char *)mask + ik01*sizeof(uint16_t); // we don't have ggml_half available here + if (!iqk_flash_attn_impl(int_type_k, int_type_v, + Dk, Dv, rk2, this_nk, nbq2, stride_k, stride_v, 0, Dv, + this_q, (const void *)this_k, (const void *)this_v, (const void *)this_m, + scale, softcap, this_result, this_result + (Dv+0)*rk2, this_result + (Dv+1)*rk2)) return false; + } + + barrier(barrier_data); + + // We have nkk results for each head + for (int iq2 = ith; iq2 < neq2; iq2 += nth) { + // ik02*rk2 + il = iq2 (il = 0...rk2-1) => ik02 = iq2/rk2, il = iq2%rk2; + int ik02 = iq2/rk2; + int il = iq2 - ik02*rk2; + auto Racc = qkv + iq2*nb1/sizeof(float); + std::memset(Racc, 0, Dv*sizeof(float)); + float M = -INFINITY, S = 0; + for (int ikk = 0; ikk < nkk; ++ikk) { + int istep_k = ik02*nkk + ikk; + auto this_result = (float *)((char *)work_buffer + istep_k*result_size); + const float * R = this_result + il*Dv; + const float * Mj = this_result + Dv*rk2; + const float * Sj = Mj + rk2; + if (Mj[il] == -INFINITY) continue; + if (Mj[il] > M) { + if (M == -INFINITY) { + std::memcpy(Racc, R, Dv*sizeof(float)); + S = Sj[il]; + } else { + float c = exp(M - Mj[il]); + S = c*S + Sj[il]; + for (int i = 0; i < Dv; ++i) Racc[i] = c*Racc[i] + R[i]; + } + M = Mj[il]; + } else { + float c = exp(Mj[il] - M); + S += c*Sj[il]; + for (int i = 0; i < Dv; ++i) Racc[i] += c*R[i]; + } + } + float norm = S > 0 ? 1/S : 1; + for (int i = 0; i < Dv; ++i) Racc[i] *= norm; + } + return true; + } + // I keep changing my mind what is the best strategy to split the threads when processing // multiple heads. This is my current thinking, the commented out code below was the previous. int ntg = nth/simple_gcd(neq2*neq3, nth); |