summaryrefslogtreecommitdiff
path: root/ggml/src/ggml.c
diff options
context:
space:
mode:
authorKawrakow <iwankawrakow@gmail.com>2025-03-23 07:28:21 +0100
committerGitHub <noreply@github.com>2025-03-23 07:28:21 +0100
commit5a4855e61c05b0c54ecad3f4155074d8f344b6f6 (patch)
tree0059972ba5a476b96e528022596d2cf5507b9bc9 /ggml/src/ggml.c
parentdd5ebd0e3d42e871ac398deaaa3f60edaa78a7eb (diff)
Attempt to improve FlashMLA on the CPU (#277)
* Fix it for nth > rk2 * Handle rk2%nth_k != 0 * Cleanup --------- Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
Diffstat (limited to 'ggml/src/ggml.c')
-rw-r--r--ggml/src/ggml.c15
1 files changed, 7 insertions, 8 deletions
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index a2bdc156..036bd8a8 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -21771,15 +21771,14 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
if (gcd_k > 1) {
int nth_k = n_tasks/gcd_k;
int rk2 = q->ne[2]/k->ne[2];
- if (rk2%nth_k == 0) {
- size_t size = (Dv + 16)*rk2/nth_k*sizeof(float)*n_tasks;
- if (ggml_is_quantized(k->type)) {
- enum ggml_type vec_dot_type = type_traits[k->type].vec_dot_type;
- size_t row_size = ggml_row_size(vec_dot_type, q->ne[0]);
- size += q->ne[2]*row_size;
- }
- cur = MAX(cur, size);
+ int nq_per_thread = (rk2 + nth_k - 1)/nth_k;
+ size_t size = (Dv + 16)*nq_per_thread*sizeof(float)*n_tasks;
+ if (ggml_is_quantized(k->type)) {
+ enum ggml_type vec_dot_type = type_traits[k->type].vec_dot_type;
+ size_t row_size = ggml_row_size(vec_dot_type, q->ne[0]);
+ size += q->ne[2]*row_size;
}
+ cur = MAX(cur, size);
}
}
#endif