summaryrefslogtreecommitdiff
path: root/ggml-backend.c
diff options
context:
space:
mode:
authorGeorgi Gerganov <ggerganov@gmail.com>2024-01-17 18:39:41 +0200
committerGitHub <noreply@github.com>2024-01-17 18:39:41 +0200
commit44a1a4a41a4c0b03afaa7d9e06bcbc7cf95aa1e6 (patch)
tree3c0973be05046780e14ca8048b7dbe1372aa5833 /ggml-backend.c
parentc918fe8dca8fa1c4602427e0a4b88e20046f6c34 (diff)
backend : add eval callback (#4935)
* backend : add eval callback ggml-ci * backend : group nodes in a single compute when the user doesn't need them * backend : clean-up the implementation ggml-ci * simple : do not perform tensor data copy if not needed * simple : fix * simple : no need for ggml_is_contiguous + fix bool parse * llama : fix callback placement in llama_context_params * backend : avoid double-ask callback calls * simple : restore examples, imatrix will serve as a demo
Diffstat (limited to 'ggml-backend.c')
-rw-r--r--ggml-backend.c42
1 file changed, 40 insertions, 2 deletions
diff --git a/ggml-backend.c b/ggml-backend.c
index f5424fb9..4266250f 100644
--- a/ggml-backend.c
+++ b/ggml-backend.c
@@ -802,6 +802,9 @@ struct ggml_backend_sched {
__attribute__((aligned(GGML_MEM_ALIGN)))
#endif
char context_buffer[GGML_MAX_SPLITS*GGML_MAX_SPLIT_INPUTS*sizeof(struct ggml_tensor) + sizeof(struct ggml_cgraph)];
+
+ ggml_backend_sched_eval_callback callback_eval;
+ void * callback_eval_user_data;
};
#define hash_id(node) ggml_hash_find_or_insert(sched->hash_set, node)
@@ -1324,9 +1327,38 @@ static void sched_compute_splits(ggml_backend_sched_t sched) {
ggml_graph_dump_dot(split->graph, NULL, split_filename);
#endif
+
uint64_t compute_start_us = ggml_time_us();
- ggml_backend_graph_compute(split_backend, &split->graph);
- //ggml_backend_synchronize(split_backend); // necessary to measure compute time
+ if (!sched->callback_eval) {
+ ggml_backend_graph_compute(split_backend, &split->graph);
+ //ggml_backend_synchronize(split_backend); // necessary to measure compute time
+ } else {
+ // similar to ggml_backend_compare_graph_backend
+ for (int j0 = 0; j0 < split->graph.n_nodes; j0++) {
+ struct ggml_tensor * t = split->graph.nodes[j0];
+
+ // check if the user needs data from this node
+ bool need = sched->callback_eval(t, true, sched->callback_eval_user_data);
+
+ int j1 = j0;
+
+ // determine the range [j0, j1] of nodes that can be computed together
+ while (!need && j1 < split->graph.n_nodes - 1) {
+ t = split->graph.nodes[++j1];
+ need = sched->callback_eval(t, true, sched->callback_eval_user_data);
+ }
+
+ struct ggml_cgraph gv = ggml_graph_view(&split->graph, j0, j1 + 1);
+
+ ggml_backend_graph_compute(split_backend, &gv);
+
+ if (need && !sched->callback_eval(t, false, sched->callback_eval_user_data)) {
+ break;
+ }
+
+ j0 = j1;
+ }
+ }
uint64_t compute_end_us = ggml_time_us();
compute_us[split_backend_id] += compute_end_us - compute_start_us;
}
@@ -1431,6 +1463,12 @@ void ggml_backend_sched_reset(ggml_backend_sched_t sched) {
sched_reset(sched);
}
+
+// Install (or clear, by passing NULL) the per-node eval callback consulted by
+// the scheduler's compute loop; user_data is handed back verbatim on each call.
+void ggml_backend_sched_set_eval_callback(ggml_backend_sched_t sched, ggml_backend_sched_eval_callback callback, void * user_data) {
+ sched->callback_eval = callback;
+ sched->callback_eval_user_data = user_data;
+}
+
// Returns sched->n_splits — presumably the number of graph splits produced by
// the most recent scheduling pass (NOTE(review): field set outside this view).
int ggml_backend_sched_get_n_splits(ggml_backend_sched_t sched) {
return sched->n_splits;
}