summary | refs | log | tree | commit | diff
path: root/ggml/src/ggml.c
diff options
context:
space:
mode:
Diffstat (limited to 'ggml/src/ggml.c')
-rw-r--r-- ggml/src/ggml.c | 29
1 files changed, 26 insertions, 3 deletions
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index f3cfd9a0..4cd18a28 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -21786,15 +21786,36 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
cur = 3*sizeof(float)*D*n_tasks; // 3x head size/thread
#if GGML_USE_IQK_MULMAT
+ size_t qsize = 0;
const struct ggml_tensor * q = node->src[0];
const struct ggml_tensor * k = node->src[1];
+ if (k->type == GGML_TYPE_Q8_0) {
+ qsize = ggml_nrows(k)*ggml_row_size(k->type, k->ne[0]);
+ }
if (q->ne[1] == 1 && q->ne[3] == 1 && q->ne[2]/k->ne[2] > 1 && n_tasks > 1 && k->ne[1]/32 > 1) {
if (k->ne[2] > 1) {
- int nk = MAX(1, 32 * (k->ne[2]*k->ne[1]/(32*n_tasks)));
+ int gcd = simple_gcd(k->ne[2], n_tasks);
+ int nth_k = n_tasks/gcd;
+ int nek2_k = k->ne[2]/gcd;
+ int nchunk = nek2_k*k->ne[1]/32;
+ int npt = (nchunk + nth_k - 1)/nth_k;
+ int nk;
+ if (npt*nth_k == nchunk) {
+ nk = 32 * (k->ne[1]*k->ne[2]/(32*n_tasks));
+ } else {
+ //int nm = std::max(1, npt/8);
+ int nm = 1;
+ while (true) {
+ if (nm*4 >= npt) break;
+ nm *= 2;
+ }
+ nk = 32*nm;
+ }
+ //int nk = 32 * (k->ne[2]*k->ne[1]/(32*n_tasks));
int nstep_k = k->ne[2]*k->ne[1]/nk;
size_t result_size = (Dv + 16)*q->ne[2]/k->ne[2]*sizeof(float);
size_t size = nstep_k*result_size;
- cur = MAX(cur, size);
+ cur = MAX(cur, size+qsize);
} else {
int nstep_k = k->ne[1]/32;
int gcd_k = simple_gcd(nstep_k, n_tasks);
@@ -21808,9 +21829,11 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
size_t row_size = ggml_row_size(vec_dot_type, q->ne[0]);
size += q->ne[2]*row_size;
}
- cur = MAX(cur, size);
+ cur = MAX(cur, size+qsize);
}
}
+ } else {
+ cur = MAX(cur, qsize);
}
#endif
} break;