diff options
Diffstat (limited to 'ggml/src/ggml.c')
-rw-r--r-- | ggml/src/ggml.c | 112 |
1 file changed, 72 insertions, 40 deletions
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 46e1a548..e5ad15f2 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -17870,46 +17870,57 @@ static void ggml_compute_forward_flash_attn_ext_f16( } #if GGML_USE_IQK_MULMAT - if (max_bias <= 0.0f && q->type == GGML_TYPE_F32 && mask && mask->type == GGML_TYPE_F16) { - //if (ith == 0) printf("k: %ld x %ld x %ld, q: %ld x %ld x %ld, v: %ld x %ld x %ld mask: %ld x %ld x %ld\n", - // k->ne[0], k->ne[1], k->ne[2], q->ne[0], q->ne[1], q->ne[2], v->ne[0], v->ne[1], v->ne[2], mask->ne[0], mask->ne[1], mask->ne[2]); - // I keep changing my mind what is the best strategy to split the threads when processing - // multiple heads. This is my current thinking, the commented out code below was the previous. - int ntg = nth/simple_gcd(neq2*neq3, nth); - int64_t neq1g = (neq1 + ntg - 1)/ntg; - //int64_t work_per_slice = D*nek1*neq1; - //int ntg = 1; - // - // When neq1 is large, it is better to have more than one thread process one (iq2,iq3) matrix - // But we also want each thread to process the same amount of rows, so neq1 must be a multiple of - // the number of threads processing the (iq2, iq3) matrix. 
- // - //if (neq1 >= 8*nth) { - // if (nth%8 == 0 && neq1%8 == 0 && work_per_slice >= (1 << 23)) ntg = 8; - // else if (nth%4 == 0 && neq1%4 == 0 && work_per_slice >= (1 << 21)) ntg = 4; - // else if (nth%2 == 0 && neq1%2 == 0 && work_per_slice >= (1 << 19)) ntg = 2; - //} - int counter = 0; - for (int64_t iq3 = 0; iq3 < neq3; iq3++) { - for (int64_t iq2 = 0; iq2 < neq2; iq2++) { - if (counter++ % (nth/ntg) == ith/ntg) { - int iq1 = (ith%ntg)*neq1g; - int this_neq1 = MIN(neq1g, neq1-iq1); - if (!iqk_flash_attn_noalibi(k->type, v->type, - Dk, Dv, this_neq1, nek1, q->nb[1], k->nb[1], v->nb[1], mask->nb[1], ne1*nb1/sizeof(float), - (const float *)((const char *)q->data + iq2*q->nb[2] + iq3*q->nb[3] + iq1*q->nb[1]), - (const void *)((const char *)k->data + iq2/rk2*k->nb[2] + iq3/rk3*k->nb[3]), - (const void *)((const char *)v->data + iq2/rv2*v->nb[2] + iq3/rv3*v->nb[3]), - (const void *)((const char *)mask->data + iq1*mask->nb[1]), - scale, softcap, - (float *)((char *) dst->data + (iq3*ne2*ne1 + iq2 + iq1*ne1)*nb1))) goto IQK_Flash_Attn_NotAvailable; - } - } - } - return; -IQK_Flash_Attn_NotAvailable:; - printf("iqk_flash was rejected\n"); - } + if (iqk_flash_attn_noalibi(q->type, mask->type, max_bias, + q->ne[3], q->ne[2], q->nb[3], q->nb[2], + k->ne[3], k->ne[2], k->nb[3], k->nb[2], + v->ne[3], v->ne[2], v->nb[3], v->nb[2], + dst->ne[2], dst->ne[1], dst->nb[1], + k->type, v->type, + Dk, Dv, neq1, nek1, q->nb[1], k->nb[1], v->nb[1], mask->nb[1], + q->data, k->data, v->data, mask->data, + scale, softcap, (float *)dst->data, + params->wdata, (barrier_t)ggml_barrier, (void *)params->shared, ith, nth)) return; + +// if (max_bias <= 0.0f && q->type == GGML_TYPE_F32 && mask && mask->type == GGML_TYPE_F16) { +// //if (ith == 0) printf("k: %ld x %ld x %ld, q: %ld x %ld x %ld, v: %ld x %ld x %ld mask: %ld x %ld x %ld\n", +// // k->ne[0], k->ne[1], k->ne[2], q->ne[0], q->ne[1], q->ne[2], v->ne[0], v->ne[1], v->ne[2], mask->ne[0], mask->ne[1], mask->ne[2]); +// // I keep 
changing my mind what is the best strategy to split the threads when processing +// // multiple heads. This is my current thinking, the commented out code below was the previous. +// int ntg = nth/simple_gcd(neq2*neq3, nth); +// int64_t neq1g = (neq1 + ntg - 1)/ntg; +// //int64_t work_per_slice = D*nek1*neq1; +// //int ntg = 1; +// // +// // When neq1 is large, it is better to have more than one thread process one (iq2,iq3) matrix +// // But we also want each thread to process the same amount of rows, so neq1 must be a multiple of +// // the number of threads processing the (iq2, iq3) matrix. +// // +// //if (neq1 >= 8*nth) { +// // if (nth%8 == 0 && neq1%8 == 0 && work_per_slice >= (1 << 23)) ntg = 8; +// // else if (nth%4 == 0 && neq1%4 == 0 && work_per_slice >= (1 << 21)) ntg = 4; +// // else if (nth%2 == 0 && neq1%2 == 0 && work_per_slice >= (1 << 19)) ntg = 2; +// //} +// int counter = 0; +// for (int64_t iq3 = 0; iq3 < neq3; iq3++) { +// for (int64_t iq2 = 0; iq2 < neq2; iq2++) { +// if (counter++ % (nth/ntg) == ith/ntg) { +// int iq1 = (ith%ntg)*neq1g; +// int this_neq1 = MIN(neq1g, neq1-iq1); +// if (!iqk_flash_attn_noalibi(k->type, v->type, +// Dk, Dv, this_neq1, nek1, q->nb[1], k->nb[1], v->nb[1], mask->nb[1], ne1*nb1/sizeof(float), +// (const float *)((const char *)q->data + iq2*q->nb[2] + iq3*q->nb[3] + iq1*q->nb[1]), +// (const void *)((const char *)k->data + iq2/rk2*k->nb[2] + iq3/rk3*k->nb[3]), +// (const void *)((const char *)v->data + iq2/rv2*v->nb[2] + iq3/rv3*v->nb[3]), +// (const void *)((const char *)mask->data + iq1*mask->nb[1]), +// scale, softcap, +// (float *)((char *) dst->data + (iq3*ne2*ne1 + iq2 + iq1*ne1)*nb1))) goto IQK_Flash_Attn_NotAvailable; +// } +// } +// } +// return; +//IQK_Flash_Attn_NotAvailable:; +// printf("iqk_flash was rejected\n"); +// } #endif const uint32_t n_head = neq2; @@ -21534,6 +21545,27 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa const int64_t D = MAX(Dk, Dv); cur = 
3*sizeof(float)*D*n_tasks; // 3x head size/thread +#if GGML_USE_IQK_MULMAT + const struct ggml_tensor * q = node->src[0]; + const struct ggml_tensor * k = node->src[1]; + if (q->ne[1] == 1 && q->ne[3] == 1 && q->ne[2]/k->ne[2] > 1 && n_tasks > 1 && k->ne[1]/32 > 1) { + int nstep_k = k->ne[1]/32; + int gcd_k = simple_gcd(nstep_k, n_tasks); + if (gcd_k > 1) { + int nth_k = n_tasks/gcd_k; + int rk2 = q->ne[2]/k->ne[2]; + if (rk2%nth_k == 0) { + size_t size = (Dv + 16)*rk2/nth_k*sizeof(float)*n_tasks; + if (ggml_is_quantized(k->type)) { + enum ggml_type vec_dot_type = type_traits[k->type].vec_dot_type; + size_t row_size = ggml_row_size(vec_dot_type, q->ne[0]); + size += q->ne[2]*row_size; + } + cur = MAX(cur, size); + } + } + } +#endif } break; case GGML_OP_FLASH_ATTN_BACK: { |