diff options
Diffstat (limited to 'ggml/src/ggml.c')
-rw-r--r-- | ggml/src/ggml.c | 32 |
1 files changed, 20 insertions, 12 deletions
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 25694fc7..83a48cb6 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -21781,19 +21781,27 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa const struct ggml_tensor * q = node->src[0]; const struct ggml_tensor * k = node->src[1]; if (q->ne[1] == 1 && q->ne[3] == 1 && q->ne[2]/k->ne[2] > 1 && n_tasks > 1 && k->ne[1]/32 > 1) { - int nstep_k = k->ne[1]/32; - int gcd_k = simple_gcd(nstep_k, n_tasks); - if (gcd_k > 1) { - int nth_k = n_tasks/gcd_k; - int rk2 = q->ne[2]/k->ne[2]; - int nq_per_thread = (rk2 + nth_k - 1)/nth_k; - size_t size = (Dv + 16)*nq_per_thread*sizeof(float)*n_tasks; - if (ggml_is_quantized(k->type)) { - enum ggml_type vec_dot_type = type_traits[k->type].vec_dot_type; - size_t row_size = ggml_row_size(vec_dot_type, q->ne[0]); - size += q->ne[2]*row_size; - } + if (k->ne[2] > 1) { + int nk = 32 * (k->ne[2]*k->ne[1]/(32*n_tasks)); + int nstep_k = k->ne[2]*k->ne[1]/nk; + size_t result_size = (Dv + 16)*q->ne[2]/k->ne[2]*sizeof(float); + size_t size = nstep_k*result_size; cur = MAX(cur, size); + } else { + int nstep_k = k->ne[1]/32; + int gcd_k = simple_gcd(nstep_k, n_tasks); + if (gcd_k > 1) { + int nth_k = n_tasks/gcd_k; + int rk2 = q->ne[2]/k->ne[2]; + int nq_per_thread = (rk2 + nth_k - 1)/nth_k; + size_t size = (Dv + 16)*nq_per_thread*sizeof(float)*n_tasks; + if (ggml_is_quantized(k->type)) { + enum ggml_type vec_dot_type = type_traits[k->type].vec_dot_type; + size_t row_size = ggml_row_size(vec_dot_type, q->ne[0]); + size += q->ne[2]*row_size; + } + cur = MAX(cur, size); + } } } #endif |