summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
-rw-r--r-- ggml/src/ggml.c 15
-rw-r--r-- ggml/src/iqk/iqk_flash_attn.cpp 60
2 files changed, 48 insertions, 27 deletions
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index a2bdc156..036bd8a8 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -21771,15 +21771,14 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
if (gcd_k > 1) {
int nth_k = n_tasks/gcd_k;
int rk2 = q->ne[2]/k->ne[2];
- if (rk2%nth_k == 0) {
- size_t size = (Dv + 16)*rk2/nth_k*sizeof(float)*n_tasks;
- if (ggml_is_quantized(k->type)) {
- enum ggml_type vec_dot_type = type_traits[k->type].vec_dot_type;
- size_t row_size = ggml_row_size(vec_dot_type, q->ne[0]);
- size += q->ne[2]*row_size;
- }
- cur = MAX(cur, size);
+ int nq_per_thread = (rk2 + nth_k - 1)/nth_k;
+ size_t size = (Dv + 16)*nq_per_thread*sizeof(float)*n_tasks;
+ if (ggml_is_quantized(k->type)) {
+ enum ggml_type vec_dot_type = type_traits[k->type].vec_dot_type;
+ size_t row_size = ggml_row_size(vec_dot_type, q->ne[0]);
+ size += q->ne[2]*row_size;
}
+ cur = MAX(cur, size);
}
}
#endif
diff --git a/ggml/src/iqk/iqk_flash_attn.cpp b/ggml/src/iqk/iqk_flash_attn.cpp
index fecd818b..2c6da2b9 100644
--- a/ggml/src/iqk/iqk_flash_attn.cpp
+++ b/ggml/src/iqk/iqk_flash_attn.cpp
@@ -64,40 +64,63 @@ bool iqk_flash_attn_noalibi(int type_q, int type_mask, float max_bias,
int gcd_k = simple_gcd(nstep_k, nth);
if (gcd_k >= 1) {
int nth_k = nth/gcd_k;
- if (rk2%nth_k == 0) {
- int ith_k = ith%gcd_k;
- int ith_q = ith/gcd_k;
+ int ith_k = ith%gcd_k;
+ int ith_q = ith/gcd_k;
+ int nq_per_thread = (rk2 + nth_k - 1)/nth_k;
+ if (nq_per_thread > 1) {
+ int ith_mid = nth_k;
+ int nq_this_thread = nq_per_thread;
+ if (nq_per_thread*nth_k > rk2) {
+ ith_mid = rk2 - nth_k*(nq_per_thread - 1);
+ if (ith_q >= ith_mid) --nq_this_thread;
+ }
+ int j_mid = ith_mid*nq_per_thread;
+ auto work = (char *)work_buffer;
+ auto size_thread = (Dv + 16)*nq_per_thread*sizeof(float);
+ auto result_buffer = work;
+
auto kth = (const char *)k + ith_k*(nek1/gcd_k)*stride_k;
auto vth = (const char *)v + ith_k*(nek1/gcd_k)*stride_v;
- auto qth = (const char *)q + ith_q*(rk2/nth_k)*nbq2;
+ auto q_offset = ith_q < ith_mid ? ith_q*nq_per_thread*nbq2 : (ith_mid*nq_per_thread + (ith_q - ith_mid)*nq_this_thread)*nbq2;
+ auto qth = (const char *)q + q_offset;
auto mth = (const char *)mask + ith_k*(nek1/gcd_k)*sizeof(uint16_t); // we don't have ggml_half available here
- auto work = (char *)work_buffer;
- // Each thread will produce a result of size Dv*(rk2/nth_k)*sizeof(float)
- // In addition, we need M, S for the rk2/nth_k rows the thread is processing
- // => (Dv + 2)*rk2/nth_k*sizeof(float). We use (Dv + 16) instead to make sure threads are not
+ // Each thread will produce a result of size Dv*nq_this_thread*sizeof(float)
+ // In addition, we need M, S for the nq_this_thread rows the thread is processing
+ // => (Dv + 2)*nq_per_thread*sizeof(float). We use (Dv + 16) instead to make sure threads are not
// writing onto the same cache line.
- auto size_thread = (Dv + 16)*rk2/nth_k*sizeof(float);
- auto result_buffer = work;
auto work_this_thread = (float *)(result_buffer + ith*size_thread);
if (!iqk_flash_attn_impl(int_type_k, int_type_v,
- Dk, Dv, rk2/nth_k, nek1/gcd_k, nbq2, stride_k, stride_v, 0, Dv, //Dk*sizeof(uint16_t), Dv,
+ Dk, Dv, nq_this_thread, nek1/gcd_k, nbq2, stride_k, stride_v, 0, Dv, //Dk*sizeof(uint16_t), Dv,
(const float *)qth, (const void *)kth, (const void *)vth, (const void *)mth,
scale, softcap,
- work_this_thread, work_this_thread + (Dv+0)*rk2/nth_k, work_this_thread + (Dv+1)*rk2/nth_k)) return false;
+ work_this_thread, work_this_thread + (Dv+0)*nq_this_thread, work_this_thread + (Dv+1)*nq_this_thread)) return false;
barrier(barrier_data);
+ // Each row j has gcd_k partial results (one per k/v slice of nek1/gcd_k rows) that we need to combine
+ // Thread i processed k/v slice i%gcd_k for query group i/gcd_k (nq_per_thread rows for groups below ith_mid, one fewer otherwise) and stored its results at offset i*size_thread
+
// TODO: simdify this
+ // TODO: if nth > rk2, have threads process portions of the rows instead of entire rows as it is now
for (int j = ith; j < rk2; j += nth) {
auto Racc = qkv + j*nb1/sizeof(float);
float M = -INFINITY, S = 0;
- int jth_q = j/(rk2/nth_k);
- int jj = j%(rk2/nth_k);
- for (int j1 = 0; j1 < rk2/nth_k; ++j1) {
- auto R = (const float *)(result_buffer + (jth_q*(rk2/nth_k) + j1)*size_thread);
- auto Mj = R + Dv*rk2/nth_k;
- auto Sj = Mj + rk2/nth_k;
+ int jth_first, jj, nq_this_j;
+ if (j < j_mid) {
+ jth_first = j/nq_per_thread;
+ jj = j%nq_per_thread;
+ nq_this_j = nq_per_thread;
+ } else {
+ jth_first = ith_mid + (j - j_mid)/(nq_per_thread-1);
+ jj = (j - j_mid)%(nq_per_thread-1);
+ nq_this_j = nq_per_thread - 1;
+ }
+ jth_first *= gcd_k;
+ for (int jth = jth_first; jth < jth_first + gcd_k; ++jth) {
+ auto R = (const float *)(result_buffer + jth*size_thread);
+ auto Mj = R + Dv*nq_this_j;
+ auto Sj = Mj + nq_this_j;
R += jj*Dv;
if (Mj[jj] == -INFINITY) continue;
if (Mj[jj] > M) {
@@ -120,7 +143,6 @@ bool iqk_flash_attn_noalibi(int type_q, int type_mask, float max_bias,
for (int i = 0; i < Dv; ++i) Racc[i] *= norm;
}
return true;
-
}
}
}