diff options
Diffstat (limited to 'ggml/src/ggml.c')
-rw-r--r-- | ggml/src/ggml.c | 23 |
1 files changed, 14 insertions, 9 deletions
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index bcb8bf41..b3c8a951 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -17471,25 +17471,30 @@ static void ggml_compute_forward_flash_attn_ext_f16( #if GGML_USE_IQK_MULMAT if (max_bias <= 0.0f && q->type == GGML_TYPE_F32 && mask && mask->type == GGML_TYPE_F16) { - int64_t work_per_slice = D*nek1*neq1; - int ntg = 1; + // I keep changing my mind what is the best strategy to split the threads when processing + // multiple heads. This is my current thinking, the commented out code below was the previous. + int ntg = nth/simple_gcd(neq2*neq3, nth); + int64_t neq1g = (neq1 + ntg - 1)/ntg; + //int64_t work_per_slice = D*nek1*neq1; + //int ntg = 1; // // When neq1 is large, it is better to have more than one thread process one (iq2,iq3) matrix // But we also want each thread to process the same amount of rows, so neq1 must be a multiple of // the number of threads processing the (iq2, iq3) matrix. // - if (neq1 >= 8*nth) { - if (nth%8 == 0 && neq1%8 == 0 && work_per_slice >= (1 << 23)) ntg = 8; - else if (nth%4 == 0 && neq1%4 == 0 && work_per_slice >= (1 << 21)) ntg = 4; - else if (nth%2 == 0 && neq1%2 == 0 && work_per_slice >= (1 << 19)) ntg = 2; - } + //if (neq1 >= 8*nth) { + // if (nth%8 == 0 && neq1%8 == 0 && work_per_slice >= (1 << 23)) ntg = 8; + // else if (nth%4 == 0 && neq1%4 == 0 && work_per_slice >= (1 << 21)) ntg = 4; + // else if (nth%2 == 0 && neq1%2 == 0 && work_per_slice >= (1 << 19)) ntg = 2; + //} int counter = 0; for (int64_t iq3 = 0; iq3 < neq3; iq3++) { for (int64_t iq2 = 0; iq2 < neq2; iq2++) { if (counter++ % (nth/ntg) == ith/ntg) { - int iq1 = (ith%ntg)*neq1/ntg; + int iq1 = (ith%ntg)*neq1g; + int this_neq1 = MIN(neq1g, neq1-iq1); if (!iqk_flash_attn_noalibi(k->type, v->type, - D, neq1/ntg, nek1, q->nb[1], k->nb[1], v->nb[1], mask->nb[1], ne1*nb1/sizeof(float), + D, this_neq1, nek1, q->nb[1], k->nb[1], v->nb[1], mask->nb[1], ne1*nb1/sizeof(float), (const float *)((const char *)q->data + iq2*q->nb[2] + iq3*q->nb[3] + iq1*q->nb[1]), (const void *)((const char *)k->data + iq2/rk2*k->nb[2] + iq3/rk3*k->nb[3]), (const void *)((const char *)v->data + iq2/rv2*v->nb[2] + iq3/rv3*v->nb[3]), |