summaryrefslogtreecommitdiff
path: root/ggml/src/ggml.c
diff options
context:
space:
mode:
authorKawrakow <iwankawrakow@gmail.com>2025-04-29 07:19:43 +0200
committerGitHub <noreply@github.com>2025-04-29 07:19:43 +0200
commitcda24b58cbef34154651d0083910fed860a506c1 (patch)
tree90cd3bd7f772c3b240a6553eca5e50edf95c53da /ggml/src/ggml.c
parentbaeefb4731fb24cdace168f6dbc74516d470efc0 (diff)
CPU FA improvements (#351)
* FA: provide work buffer for K repacking * Add header to avoid comp0iler warnings * WIP * WIP * WIP * WIP * Slightly better * WIP (Zen4) * WIP * Try to improve for unusual number of heads/number of threads * Use mul_mat_qX_0_q8_2_Tx for q6_0 in FA * Use mul_mat_qX_0_q8_2_Tx for q4_0 in FA * Use Sum4q4 for q4_0 * WIP * WIP * Much better FA TG with q8_0 KV cache Just repack it even for TG. But do the repacking for k_step rows, not the whole K tensor. --------- Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
Diffstat (limited to 'ggml/src/ggml.c')
-rw-r--r--ggml/src/ggml.c29
1 files changed, 26 insertions, 3 deletions
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index f3cfd9a0..4cd18a28 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -21786,15 +21786,36 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
cur = 3*sizeof(float)*D*n_tasks; // 3x head size/thread
#if GGML_USE_IQK_MULMAT
+ size_t qsize = 0;
const struct ggml_tensor * q = node->src[0];
const struct ggml_tensor * k = node->src[1];
+ if (k->type == GGML_TYPE_Q8_0) {
+ qsize = ggml_nrows(k)*ggml_row_size(k->type, k->ne[0]);
+ }
if (q->ne[1] == 1 && q->ne[3] == 1 && q->ne[2]/k->ne[2] > 1 && n_tasks > 1 && k->ne[1]/32 > 1) {
if (k->ne[2] > 1) {
- int nk = MAX(1, 32 * (k->ne[2]*k->ne[1]/(32*n_tasks)));
+ int gcd = simple_gcd(k->ne[2], n_tasks);
+ int nth_k = n_tasks/gcd;
+ int nek2_k = k->ne[2]/gcd;
+ int nchunk = nek2_k*k->ne[1]/32;
+ int npt = (nchunk + nth_k - 1)/nth_k;
+ int nk;
+ if (npt*nth_k == nchunk) {
+ nk = 32 * (k->ne[1]*k->ne[2]/(32*n_tasks));
+ } else {
+ //int nm = std::max(1, npt/8);
+ int nm = 1;
+ while (true) {
+ if (nm*4 >= npt) break;
+ nm *= 2;
+ }
+ nk = 32*nm;
+ }
+ //int nk = 32 * (k->ne[2]*k->ne[1]/(32*n_tasks));
int nstep_k = k->ne[2]*k->ne[1]/nk;
size_t result_size = (Dv + 16)*q->ne[2]/k->ne[2]*sizeof(float);
size_t size = nstep_k*result_size;
- cur = MAX(cur, size);
+ cur = MAX(cur, size+qsize);
} else {
int nstep_k = k->ne[1]/32;
int gcd_k = simple_gcd(nstep_k, n_tasks);
@@ -21808,9 +21829,11 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
size_t row_size = ggml_row_size(vec_dot_type, q->ne[0]);
size += q->ne[2]*row_size;
}
- cur = MAX(cur, size);
+ cur = MAX(cur, size+qsize);
}
}
+ } else {
+ cur = MAX(cur, qsize);
}
#endif
} break;