-rw-r--r--   ggml/src/ggml.c                  | 15
-rw-r--r--   ggml/src/iqk/iqk_flash_attn.cpp  | 60
2 files changed, 48 insertions, 27 deletions
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index a2bdc156..036bd8a8 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -21771,15 +21771,14 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
                     if (gcd_k > 1) {
                         int nth_k = n_tasks/gcd_k;
                         int rk2 = q->ne[2]/k->ne[2];
-                        if (rk2%nth_k == 0) {
-                            size_t size = (Dv + 16)*rk2/nth_k*sizeof(float)*n_tasks;
-                            if (ggml_is_quantized(k->type)) {
-                                enum ggml_type vec_dot_type = type_traits[k->type].vec_dot_type;
-                                size_t row_size = ggml_row_size(vec_dot_type, q->ne[0]);
-                                size += q->ne[2]*row_size;
-                            }
-                            cur = MAX(cur, size);
+                        int nq_per_thread = (rk2 + nth_k - 1)/nth_k;
+                        size_t size = (Dv + 16)*nq_per_thread*sizeof(float)*n_tasks;
+                        if (ggml_is_quantized(k->type)) {
+                            enum ggml_type vec_dot_type = type_traits[k->type].vec_dot_type;
+                            size_t row_size = ggml_row_size(vec_dot_type, q->ne[0]);
+                            size += q->ne[2]*row_size;
                         }
+                        cur = MAX(cur, size);
                     }
                 }
 #endif
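Note on the ggml.c hunk above: the planner no longer requires the GQA ratio rk2 = q->ne[2]/k->ne[2] to be divisible by nth_k; it now sizes the flash-attention work buffer for nq_per_thread = (rk2 + nth_k - 1)/nth_k rows per thread, i.e. ceil division. A minimal standalone C sketch of that sizing follows (not part of the patch; Dv, n_tasks, nth_k and the rk2 values are made-up, and the extra row_size term for quantized K is omitted for brevity):

// Sketch only: reproduces the planner's ceil-division sizing from the hunk above.
#include <stddef.h>
#include <stdio.h>

int main(void) {
    const int Dv = 128, n_tasks = 16, nth_k = 5;   // illustrative values
    const int rk2_values[] = {8, 10, 12};          // q->ne[2]/k->ne[2] (GQA ratio)
    for (int i = 0; i < 3; ++i) {
        int rk2 = rk2_values[i];
        // The old code only handled rk2 % nth_k == 0; the patch rounds up instead.
        int nq_per_thread = (rk2 + nth_k - 1)/nth_k;
        size_t size = (size_t)(Dv + 16)*nq_per_thread*sizeof(float)*n_tasks;
        printf("rk2=%2d nth_k=%d -> nq_per_thread=%d, work size=%zu bytes\n",
               rk2, nth_k, nq_per_thread, size);
    }
    return 0;
}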
diff --git a/ggml/src/iqk/iqk_flash_attn.cpp b/ggml/src/iqk/iqk_flash_attn.cpp
index fecd818b..2c6da2b9 100644
--- a/ggml/src/iqk/iqk_flash_attn.cpp
+++ b/ggml/src/iqk/iqk_flash_attn.cpp
@@ -64,40 +64,63 @@ bool iqk_flash_attn_noalibi(int type_q, int type_mask, float max_bias,
         int gcd_k = simple_gcd(nstep_k, nth);
         if (gcd_k >= 1) {
             int nth_k = nth/gcd_k;
-            if (rk2%nth_k == 0) {
-                int ith_k = ith%gcd_k;
-                int ith_q = ith/gcd_k;
+            int ith_k = ith%gcd_k;
+            int ith_q = ith/gcd_k;
+            int nq_per_thread = (rk2 + nth_k - 1)/nth_k;
+            if (nq_per_thread > 1) {
+                int ith_mid = nth_k;
+                int nq_this_thread = nq_per_thread;
+                if (nq_per_thread*nth_k > rk2) {
+                    ith_mid = rk2 - nth_k*(nq_per_thread - 1);
+                    if (ith_q >= ith_mid) --nq_this_thread;
+                }
+                int j_mid = ith_mid*nq_per_thread;
+                auto work = (char *)work_buffer;
+                auto size_thread = (Dv + 16)*nq_per_thread*sizeof(float);
+                auto result_buffer = work;
+
                 auto kth = (const char *)k + ith_k*(nek1/gcd_k)*stride_k;
                 auto vth = (const char *)v + ith_k*(nek1/gcd_k)*stride_v;
-                auto qth = (const char *)q + ith_q*(rk2/nth_k)*nbq2;
+                auto q_offset = ith_q < ith_mid ? ith_q*nq_per_thread*nbq2 : (ith_mid*nq_per_thread + (ith_q - ith_mid)*nq_this_thread)*nbq2;
+                auto qth = (const char *)q + q_offset;
                 auto mth = (const char *)mask + ith_k*(nek1/gcd_k)*sizeof(uint16_t); // we don't have ggml_half available here
-                auto work = (char *)work_buffer;
-                // Each thread will produce a result of size Dv*(rk2/nth_k)*sizeof(float)
-                // In addition, we need M, S for the rk2/nth_k rows the thread is processing
-                // => (Dv + 2)*rk2/nth_k*sizeof(float). We use (Dv + 16) instead to make sure threads are not
+                // Each thread will produce a result of size Dv*nq_this_thread*sizeof(float)
+                // In addition, we need M, S for the nq_this_thread rows the thread is processing
+                // => (Dv + 2)*nq_per_thread*sizeof(float). We use (Dv + 16) instead to make sure threads are not
                 // writing onto the same cache line.
-                auto size_thread = (Dv + 16)*rk2/nth_k*sizeof(float);
-                auto result_buffer = work;
                 auto work_this_thread = (float *)(result_buffer + ith*size_thread);
                 if (!iqk_flash_attn_impl(int_type_k, int_type_v,
-                        Dk, Dv, rk2/nth_k, nek1/gcd_k, nbq2, stride_k, stride_v, 0, Dv, //Dk*sizeof(uint16_t), Dv,
+                        Dk, Dv, nq_this_thread, nek1/gcd_k, nbq2, stride_k, stride_v, 0, Dv, //Dk*sizeof(uint16_t), Dv,
                         (const float *)qth, (const void *)kth, (const void *)vth, (const void *)mth,
                         scale, softcap,
-                        work_this_thread, work_this_thread + (Dv+0)*rk2/nth_k, work_this_thread + (Dv+1)*rk2/nth_k)) return false;
+                        work_this_thread, work_this_thread + (Dv+0)*nq_this_thread, work_this_thread + (Dv+1)*nq_this_thread)) return false;
                 barrier(barrier_data);
+                // There are nek1/gcd_k contributions for each j that we need to sum up
+                // Thread i computed k/v (i%gcd_k)*(nek1/gcd_k) for j (i/gcd_k)*(rk2/nth_k)...((i/gcd_k)+1)*(rk2/nth_k) and results at offset i*size_thread
+                // TODO: simdify this
+                // TODO: if nth > rk2, have threads process portions of the rows instead of entire rows as it is now
                 for (int j = ith; j < rk2; j += nth) {
                     auto Racc = qkv + j*nb1/sizeof(float);
                     float M = -INFINITY, S = 0;
-                    int jth_q = j/(rk2/nth_k);
-                    int jj = j%(rk2/nth_k);
-                    for (int j1 = 0; j1 < rk2/nth_k; ++j1) {
-                        auto R = (const float *)(result_buffer + (jth_q*(rk2/nth_k) + j1)*size_thread);
-                        auto Mj = R + Dv*rk2/nth_k;
-                        auto Sj = Mj + rk2/nth_k;
+                    int jth_first, jj, nq_this_j;
+                    if (j < j_mid) {
+                        jth_first = j/nq_per_thread;
+                        jj = j%nq_per_thread;
+                        nq_this_j = nq_per_thread;
+                    } else {
+                        jth_first = ith_mid + (j - j_mid)/(nq_per_thread-1);
+                        jj = (j - j_mid)%(nq_per_thread-1);
+                        nq_this_j = nq_per_thread - 1;
+                    }
+                    jth_first *= gcd_k;
+                    for (int jth = jth_first; jth < jth_first + gcd_k; ++jth) {
+                        auto R = (const float *)(result_buffer + jth*size_thread);
+                        auto Mj = R + Dv*nq_this_j;
+                        auto Sj = Mj + nq_this_j;
                         R += jj*Dv;
                         if (Mj[jj] == -INFINITY) continue;
                         if (Mj[jj] > M) {
@@ -120,7 +143,6 @@ bool iqk_flash_attn_noalibi(int type_q, int type_mask, float max_bias,
                     for (int i = 0; i < Dv; ++i) Racc[i] *= norm;
                 }
                 return true;
-            }
         }
     }
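Note on the iqk_flash_attn.cpp hunk above: with nq_per_thread = (rk2 + nth_k - 1)/nth_k, the rk2 query rows are split unevenly across the nth_k q-threads; the first ith_mid = rk2 - nth_k*(nq_per_thread - 1) threads take nq_per_thread rows and the remaining ones take one row fewer, which is what the q_offset and jth_first arithmetic relies on. The standalone C sketch below (not part of the patch; rk2 and nth_k are made-up values) replays that split and checks that the rows are covered contiguously and exactly once:

// Sketch only: replays the uneven per-thread row split from the patch for one made-up case.
#include <assert.h>
#include <stdio.h>

int main(void) {
    const int rk2 = 13, nth_k = 5;                   // made-up values with rk2 % nth_k != 0
    int nq_per_thread = (rk2 + nth_k - 1)/nth_k;     // ceil(13/5) = 3
    int ith_mid = nth_k;
    if (nq_per_thread*nth_k > rk2) {
        ith_mid = rk2 - nth_k*(nq_per_thread - 1);   // threads below ith_mid get the larger share
    }
    int covered = 0;
    for (int ith_q = 0; ith_q < nth_k; ++ith_q) {
        int nq_this_thread = ith_q < ith_mid ? nq_per_thread : nq_per_thread - 1;
        int first_row = ith_q < ith_mid
            ? ith_q*nq_per_thread
            : ith_mid*nq_per_thread + (ith_q - ith_mid)*nq_this_thread;  // mirrors the q_offset arithmetic
        assert(first_row == covered);                // rows are contiguous: no gaps, no overlap
        covered += nq_this_thread;
        printf("thread %d: rows [%d, %d)\n", ith_q, first_row, first_row + nq_this_thread);
    }
    assert(covered == rk2);                          // every row assigned exactly once
    return 0;
}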