diff options
author | Georgi Gerganov <ggerganov@gmail.com> | 2023-11-13 14:16:23 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-11-13 14:16:23 +0200 |
commit | 4760e7cc0b68570d58f55e8dda469805d1759d0d (patch) | |
tree | cd983b1f2833f0094c0539f7943703c6787bf12b /examples/finetune | |
parent | bb50a792ec2a49944470c82694fa364345e95170 (diff) |
sync : ggml (backend v2) (#3912)
* sync : ggml (backend v2) (wip)
* sync : migrate examples and llama.cpp to dynamic graphs (wip)
* sync : update tests + fix max op params to 64
ggml-ci
* sync : ggml-cuda
ggml-ci
* llama : fix save/load state context size
ggml-ci
* sync : try to fix build on tvOS
* sync : pass custom graph sizes in training examples
* sync : update graph copies to new ggml API
* sync : update sync-ggml.sh with new files
* scripts : fix header in sync script
* train : fix context size calculations
* llama : increase inference graph size up to 4096 nodes
* train : allocate grads for backward graphs
* train : allocate grads for gb_tmp
Diffstat (limited to 'examples/finetune')
-rw-r--r-- | examples/finetune/finetune.cpp | 23 |
1 files changed, 11 insertions, 12 deletions
```diff
diff --git a/examples/finetune/finetune.cpp b/examples/finetune/finetune.cpp
index fa7dbe49..5a6cf22c 100644
--- a/examples/finetune/finetune.cpp
+++ b/examples/finetune/finetune.cpp
@@ -772,7 +772,7 @@ static struct ggml_tensor * llama_build_lora_finetune_graphs(
     if (enable_checkpointing) {
         ggml_build_backward_gradient_checkpointing(ctx, gf, gb, gb_tmp, checkpoints.data(), (int) checkpoints.size());
     } else {
-        *gb = *gf;
+        ggml_graph_cpy(gf, gb);
         ggml_build_backward_expand(ctx, gf, gb, true);
     }
 
@@ -1615,6 +1615,7 @@ int main(int argc, char ** argv) {
     opt->params = ggml_opt_default_params(GGML_OPT_ADAM);
     opt->params.print_forward_graph = false;
     opt->params.print_backward_graph = false;
+    opt->params.graph_size = LLAMA_TRAIN_MAX_NODES;
     opt->params.n_threads = params.common.n_threads;
     opt->params.past = params.common.opt_past;
     opt->params.delta = params.common.opt_delta;
@@ -1741,11 +1742,9 @@ int main(int argc, char ** argv) {
     ggml_allocr_free(alloc);
 
     // context for compute tensors without their data
-    size_t estimated_compute_size_wo_data = (
-        ggml_tensor_overhead()*GGML_MAX_NODES*2
-      + (GGML_OBJECT_SIZE+GGML_GRAPH_SIZE)*(
-            params.common.use_checkpointing ? 3 : 2
-        )
+    const size_t estimated_compute_size_wo_data = (
+            2*LLAMA_TRAIN_MAX_NODES*ggml_tensor_overhead() +
+            (params.common.use_checkpointing ? 3 : 2)*(GGML_OBJECT_SIZE+ggml_graph_overhead_custom(LLAMA_TRAIN_MAX_NODES, true))
     );
     struct ggml_init_params ctx_compute_params = {
         estimated_compute_size_wo_data, // mem_size
@@ -1768,11 +1767,11 @@ int main(int argc, char ** argv) {
     for (unsigned order = 0; order < (unsigned) GGML_CGRAPH_EVAL_ORDER_COUNT; ++order) {
         ctx_compute = ggml_init(ctx_compute_params);
         alloc = ggml_allocr_new_measure(tensor_alignment);
-        gf = ggml_new_graph(ctx_compute);
+        gf = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true);
         gf->order = (enum ggml_cgraph_eval_order) order;
-        gb = ggml_new_graph(ctx_compute);
+        gb = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true);
         gb_tmp = params.common.use_checkpointing
-            ? ggml_new_graph(ctx_compute)
+            ? ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true)
             : NULL;
         loss = llama_build_lora_finetune_graphs(
             &model, &lora, alloc, ctx_compute,
@@ -1801,11 +1800,11 @@ int main(int argc, char ** argv) {
     mem_compute_data.resize(max_compute_size);
     ctx_compute = ggml_init(ctx_compute_params);
     alloc = ggml_allocr_new(mem_compute_data.data(), mem_compute_data.size(), tensor_alignment);
-    gf = ggml_new_graph(ctx_compute);
+    gf = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true);
     gf->order = best_order;
-    gb = ggml_new_graph(ctx_compute);
+    gb = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true);
     gb_tmp = params.common.use_checkpointing
-        ? ggml_new_graph(ctx_compute)
+        ? ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true)
         : NULL;
     loss = llama_build_lora_finetune_graphs(
         &model, &lora, alloc, ctx_compute,
```