diff options
Diffstat (limited to 'ggml/src/ggml.c')
-rw-r--r-- | ggml/src/ggml.c | 29 |
1 file changed, 26 insertions, 3 deletions
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index f3cfd9a0..4cd18a28 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -21786,15 +21786,36 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
                     cur = 3*sizeof(float)*D*n_tasks; // 3x head size/thread
 #if GGML_USE_IQK_MULMAT
+                    size_t qsize = 0;
                     const struct ggml_tensor * q = node->src[0];
                     const struct ggml_tensor * k = node->src[1];
+                    if (k->type == GGML_TYPE_Q8_0) {
+                        qsize = ggml_nrows(k)*ggml_row_size(k->type, k->ne[0]);
+                    }
                     if (q->ne[1] == 1 && q->ne[3] == 1 && q->ne[2]/k->ne[2] > 1 && n_tasks > 1 && k->ne[1]/32 > 1) {
                         if (k->ne[2] > 1) {
-                            int nk = MAX(1, 32 * (k->ne[2]*k->ne[1]/(32*n_tasks)));
+                            int gcd = simple_gcd(k->ne[2], n_tasks);
+                            int nth_k = n_tasks/gcd;
+                            int nek2_k = k->ne[2]/gcd;
+                            int nchunk = nek2_k*k->ne[1]/32;
+                            int npt = (nchunk + nth_k - 1)/nth_k;
+                            int nk;
+                            if (npt*nth_k == nchunk) {
+                                nk = 32 * (k->ne[1]*k->ne[2]/(32*n_tasks));
+                            } else {
+                                //int nm = std::max(1, npt/8);
+                                int nm = 1;
+                                while (true) {
+                                    if (nm*4 >= npt) break;
+                                    nm *= 2;
+                                }
+                                nk = 32*nm;
+                            }
+                            //int nk = 32 * (k->ne[2]*k->ne[1]/(32*n_tasks));
                             int nstep_k = k->ne[2]*k->ne[1]/nk;
                             size_t result_size = (Dv + 16)*q->ne[2]/k->ne[2]*sizeof(float);
                             size_t size = nstep_k*result_size;
-                            cur = MAX(cur, size);
+                            cur = MAX(cur, size+qsize);
                         } else {
                             int nstep_k = k->ne[1]/32;
                             int gcd_k = simple_gcd(nstep_k, n_tasks);
@@ -21808,9 +21829,11 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
                                     size_t row_size = ggml_row_size(vec_dot_type, q->ne[0]);
                                     size += q->ne[2]*row_size;
                                 }
-                                cur = MAX(cur, size);
+                                cur = MAX(cur, size+qsize);
                             }
                         }
+                    } else {
+                        cur = MAX(cur, qsize);
                     }
 #endif
                 } break;