summary refs log tree commit diff
path: root/ggml/src/ggml.c
diff options
context:
space:
mode:
Diffstat (limited to 'ggml/src/ggml.c')
-rw-r--r-- ggml/src/ggml.c 32
1 files changed, 20 insertions, 12 deletions
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index 25694fc7..83a48cb6 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -21781,19 +21781,27 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
const struct ggml_tensor * q = node->src[0];
const struct ggml_tensor * k = node->src[1];
if (q->ne[1] == 1 && q->ne[3] == 1 && q->ne[2]/k->ne[2] > 1 && n_tasks > 1 && k->ne[1]/32 > 1) {
- int nstep_k = k->ne[1]/32;
- int gcd_k = simple_gcd(nstep_k, n_tasks);
- if (gcd_k > 1) {
- int nth_k = n_tasks/gcd_k;
- int rk2 = q->ne[2]/k->ne[2];
- int nq_per_thread = (rk2 + nth_k - 1)/nth_k;
- size_t size = (Dv + 16)*nq_per_thread*sizeof(float)*n_tasks;
- if (ggml_is_quantized(k->type)) {
- enum ggml_type vec_dot_type = type_traits[k->type].vec_dot_type;
- size_t row_size = ggml_row_size(vec_dot_type, q->ne[0]);
- size += q->ne[2]*row_size;
- }
+ if (k->ne[2] > 1) {
+ int nk = 32 * (k->ne[2]*k->ne[1]/(32*n_tasks));
+ int nstep_k = k->ne[2]*k->ne[1]/nk;
+ size_t result_size = (Dv + 16)*q->ne[2]/k->ne[2]*sizeof(float);
+ size_t size = nstep_k*result_size;
cur = MAX(cur, size);
+ } else {
+ int nstep_k = k->ne[1]/32;
+ int gcd_k = simple_gcd(nstep_k, n_tasks);
+ if (gcd_k > 1) {
+ int nth_k = n_tasks/gcd_k;
+ int rk2 = q->ne[2]/k->ne[2];
+ int nq_per_thread = (rk2 + nth_k - 1)/nth_k;
+ size_t size = (Dv + 16)*nq_per_thread*sizeof(float)*n_tasks;
+ if (ggml_is_quantized(k->type)) {
+ enum ggml_type vec_dot_type = type_traits[k->type].vec_dot_type;
+ size_t row_size = ggml_row_size(vec_dot_type, q->ne[0]);
+ size += q->ne[2]*row_size;
+ }
+ cur = MAX(cur, size);
+ }
}
}
#endif