diff options
Diffstat (limited to 'ggml/src')
-rw-r--r-- | ggml/src/ggml-backend.c | 45 |
1 files changed, 37 insertions, 8 deletions
diff --git a/ggml/src/ggml-backend.c b/ggml/src/ggml-backend.c index e1651cc6..76d37f74 100644 --- a/ggml/src/ggml-backend.c +++ b/ggml/src/ggml-backend.c @@ -1040,6 +1040,13 @@ struct ggml_backend_sched_split { struct ggml_cgraph graph; }; +// Object to facilitate GGML graph caching +struct ggml_cached_graph { + bool is_active; + ggml_backend_t input_backend; + struct ggml_tensor * input_cpy[GGML_SCHED_MAX_SPLIT_INPUTS]; +}; + struct ggml_backend_sched { bool is_reset; // true if the scheduler has been reset since the last graph split bool is_alloc; @@ -1085,6 +1092,8 @@ struct ggml_backend_sched { size_t context_buffer_size; bool debug; + + struct ggml_cached_graph cached_graph; }; #define hash_id(tensor) ggml_hash_find_or_insert(&sched->hash_set, tensor) @@ -1762,6 +1771,14 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s struct ggml_tensor * input = split->inputs[j]; struct ggml_tensor * input_cpy = tensor_copy(input, split_backend_id, sched->cur_copy); + if (!sched->cached_graph.is_active) { + sched->cached_graph.input_backend = input_backend; + sched->cached_graph.input_cpy[j] = input_cpy; + } else { + input_backend = sched->cached_graph.input_backend; + input_cpy = sched->cached_graph.input_cpy[j]; + } + if (input->flags & GGML_TENSOR_FLAG_INPUT) { // inputs from the user must be copied immediately to prevent the user overwriting the data before the copy is done if (sched->events[split_backend_id][sched->cur_copy] != NULL) { @@ -1893,6 +1910,8 @@ ggml_backend_sched_t ggml_backend_sched_new( ggml_backend_sched_reset(sched); + sched->cached_graph.is_active = false; + return sched; } @@ -1969,16 +1988,16 @@ enum ggml_status ggml_backend_sched_graph_compute(ggml_backend_sched_t sched, st } enum ggml_status ggml_backend_sched_graph_compute_async(ggml_backend_sched_t sched, struct ggml_cgraph * graph) { - if (!sched->is_reset && !sched->is_alloc) { - ggml_backend_sched_reset(sched); - } - - if (!sched->is_alloc) { - if 
(!ggml_backend_sched_alloc_graph(sched, graph)) { - return GGML_STATUS_ALLOC_FAILED; + if(!sched->cached_graph.is_active) { + if (!sched->is_reset && !sched->is_alloc) { + ggml_backend_sched_reset(sched); + } + if (!sched->is_alloc) { + if (!ggml_backend_sched_alloc_graph(sched, graph)) { + return GGML_STATUS_ALLOC_FAILED; + } } } - return ggml_backend_sched_compute_splits(sched); } @@ -2243,3 +2262,13 @@ bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t return true; } + +bool ggml_use_cached_graph(ggml_backend_sched_t sched) { + return sched->cached_graph.is_active; +} + +void ggml_set_cached_graph(ggml_backend_sched_t sched, bool set_value) { + sched->cached_graph.is_active = set_value; +} + + |