diff options
author | Kawrakow <iwankawrakow@gmail.com> | 2025-04-29 07:19:43 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2025-04-29 07:19:43 +0200 |
commit | cda24b58cbef34154651d0083910fed860a506c1 (patch) | |
tree | 90cd3bd7f772c3b240a6553eca5e50edf95c53da /ggml/src/ggml.c | |
parent | baeefb4731fb24cdace168f6dbc74516d470efc0 (diff) |
CPU FA improvements (#351)
* FA: provide work buffer for K repacking
* Add header to avoid compiler warnings
* WIP
* WIP
* WIP
* WIP
* Slightly better
* WIP (Zen4)
* WIP
* Try to improve for unusual number of heads/number of threads
* Use mul_mat_qX_0_q8_2_Tx for q6_0 in FA
* Use mul_mat_qX_0_q8_2_Tx for q4_0 in FA
* Use Sum4q4 for q4_0
* WIP
* WIP
* Much better FA TG with q8_0 KV cache
Just repack it even for TG. But do the repacking for k_step rows,
not the whole K tensor.
---------
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
Diffstat (limited to 'ggml/src/ggml.c')
-rw-r--r-- | ggml/src/ggml.c | 29 |
1 files changed, 26 insertions, 3 deletions
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index f3cfd9a0..4cd18a28 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -21786,15 +21786,36 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa cur = 3*sizeof(float)*D*n_tasks; // 3x head size/thread #if GGML_USE_IQK_MULMAT + size_t qsize = 0; const struct ggml_tensor * q = node->src[0]; const struct ggml_tensor * k = node->src[1]; + if (k->type == GGML_TYPE_Q8_0) { + qsize = ggml_nrows(k)*ggml_row_size(k->type, k->ne[0]); + } if (q->ne[1] == 1 && q->ne[3] == 1 && q->ne[2]/k->ne[2] > 1 && n_tasks > 1 && k->ne[1]/32 > 1) { if (k->ne[2] > 1) { - int nk = MAX(1, 32 * (k->ne[2]*k->ne[1]/(32*n_tasks))); + int gcd = simple_gcd(k->ne[2], n_tasks); + int nth_k = n_tasks/gcd; + int nek2_k = k->ne[2]/gcd; + int nchunk = nek2_k*k->ne[1]/32; + int npt = (nchunk + nth_k - 1)/nth_k; + int nk; + if (npt*nth_k == nchunk) { + nk = 32 * (k->ne[1]*k->ne[2]/(32*n_tasks)); + } else { + //int nm = std::max(1, npt/8); + int nm = 1; + while (true) { + if (nm*4 >= npt) break; + nm *= 2; + } + nk = 32*nm; + } + //int nk = 32 * (k->ne[2]*k->ne[1]/(32*n_tasks)); int nstep_k = k->ne[2]*k->ne[1]/nk; size_t result_size = (Dv + 16)*q->ne[2]/k->ne[2]*sizeof(float); size_t size = nstep_k*result_size; - cur = MAX(cur, size); + cur = MAX(cur, size+qsize); } else { int nstep_k = k->ne[1]/32; int gcd_k = simple_gcd(nstep_k, n_tasks); @@ -21808,9 +21829,11 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa size_t row_size = ggml_row_size(vec_dot_type, q->ne[0]); size += q->ne[2]*row_size; } - cur = MAX(cur, size); + cur = MAX(cur, size+qsize); } } + } else { + cur = MAX(cur, qsize); } #endif } break; |