Diffstat (limited to 'examples/finetune')
 examples/finetune/finetune.cpp | 23 +++++++++++------------
 1 file changed, 11 insertions(+), 12 deletions(-)
diff --git a/examples/finetune/finetune.cpp b/examples/finetune/finetune.cpp
index fa7dbe49..5a6cf22c 100644
--- a/examples/finetune/finetune.cpp
+++ b/examples/finetune/finetune.cpp
@@ -772,7 +772,7 @@ static struct ggml_tensor * llama_build_lora_finetune_graphs(
if (enable_checkpointing) {
ggml_build_backward_gradient_checkpointing(ctx, gf, gb, gb_tmp, checkpoints.data(), (int) checkpoints.size());
} else {
- *gb = *gf;
+ ggml_graph_cpy(gf, gb);
ggml_build_backward_expand(ctx, gf, gb, true);
}
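
Note: the shallow assignment *gb = *gf; was only correct while ggml_cgraph was a plain fixed-size struct. Once graphs are allocated with a custom node capacity, the struct carries pointers into storage placed next to it, so a by-value copy would leave gb aliasing gf's node arrays. ggml_graph_cpy copies the node, leaf, and gradient lists into gb's own storage instead. A minimal sketch of the pattern, assuming the ggml API at this revision (ctx stands in for the compute context):

    // both graphs need the same capacity and gradient slots enabled,
    // otherwise ggml_graph_cpy would overflow the destination
    struct ggml_cgraph * gf = ggml_new_graph_custom(ctx, LLAMA_TRAIN_MAX_NODES, true);
    struct ggml_cgraph * gb = ggml_new_graph_custom(ctx, LLAMA_TRAIN_MAX_NODES, true);
    ggml_graph_cpy(gf, gb);                        // gb starts as a copy of the forward graph
    ggml_build_backward_expand(ctx, gf, gb, true); // backward ops are appended to gb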
@@ -1615,6 +1615,7 @@ int main(int argc, char ** argv) {
opt->params = ggml_opt_default_params(GGML_OPT_ADAM);
opt->params.print_forward_graph = false;
opt->params.print_backward_graph = false;
+ opt->params.graph_size = LLAMA_TRAIN_MAX_NODES;
opt->params.n_threads = params.common.n_threads;
opt->params.past = params.common.opt_past;
opt->params.delta = params.common.opt_delta;
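
Note: graph_size is the node capacity ggml_opt uses for the graphs it constructs internally (e.g. when ggml_opt_resume rebuilds the backward graph); left at its default it presumably matches the standard graph size, which a LoRA finetune graph can exceed, so it is pinned to LLAMA_TRAIN_MAX_NODES here. Sketch of the override, assuming the ggml_opt API at this revision:

    struct ggml_opt_params p = ggml_opt_default_params(GGML_OPT_ADAM);
    p.graph_size = LLAMA_TRAIN_MAX_NODES; // capacity for the optimizer's internal graphs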
@@ -1741,11 +1742,9 @@ int main(int argc, char ** argv) {
ggml_allocr_free(alloc);
// context for compute tensors without their data
-    size_t estimated_compute_size_wo_data = (
-            ggml_tensor_overhead()*GGML_MAX_NODES*2
-          + (GGML_OBJECT_SIZE+GGML_GRAPH_SIZE)*(
-                params.common.use_checkpointing ? 3 : 2
-            )
+    const size_t estimated_compute_size_wo_data = (
+        2*LLAMA_TRAIN_MAX_NODES*ggml_tensor_overhead() +
+        (params.common.use_checkpointing ? 3 : 2)*(GGML_OBJECT_SIZE+ggml_graph_overhead_custom(LLAMA_TRAIN_MAX_NODES, true))
    );
struct ggml_init_params ctx_compute_params = {
estimated_compute_size_wo_data, // mem_size
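
Note: the old estimate used the compile-time constants GGML_MAX_NODES and GGML_GRAPH_SIZE, both gone now that graphs are dynamically sized. The new formula budgets tensor headers for up to 2*LLAMA_TRAIN_MAX_NODES nodes (forward plus backward) plus a per-graph overhead for each cgraph object: gf, gb, and gb_tmp when checkpointing is enabled, hence 3 versus 2. A worked sketch of the same arithmetic, where use_checkpointing stands in for params.common.use_checkpointing:

    // per graph: the ggml_object header plus the graph's own node/grad
    // storage, reported by ggml_graph_overhead_custom for this capacity
    const size_t per_graph =
        GGML_OBJECT_SIZE + ggml_graph_overhead_custom(LLAMA_TRAIN_MAX_NODES, true);
    const size_t n_graphs = use_checkpointing ? 3 : 2; // gf, gb (, gb_tmp)
    const size_t estimate =
        2*LLAMA_TRAIN_MAX_NODES*ggml_tensor_overhead() + n_graphs*per_graph;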
@@ -1768,11 +1767,11 @@ int main(int argc, char ** argv) {
for (unsigned order = 0; order < (unsigned) GGML_CGRAPH_EVAL_ORDER_COUNT; ++order) {
ctx_compute = ggml_init(ctx_compute_params);
alloc = ggml_allocr_new_measure(tensor_alignment);
- gf = ggml_new_graph(ctx_compute);
+ gf = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true);
gf->order = (enum ggml_cgraph_eval_order) order;
- gb = ggml_new_graph(ctx_compute);
+ gb = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true);
gb_tmp = params.common.use_checkpointing
- ? ggml_new_graph(ctx_compute)
+ ? ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true)
: NULL;
loss = llama_build_lora_finetune_graphs(
&model, &lora, alloc, ctx_compute,
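
Note: ggml_new_graph creates a graph with the default capacity (GGML_DEFAULT_GRAPH_SIZE nodes) and no gradient slots; a training graph needs both more nodes and grads, hence ggml_new_graph_custom(..., LLAMA_TRAIN_MAX_NODES, true) for all three graphs. The surrounding loop is the usual measure pass: with a no_alloc context and a measure allocator, each eval order is tried and only the peak tensor memory is recorded. Sketch of the pattern, assuming the pre-gallocr ggml-alloc API used here:

    struct ggml_allocr * alloc = ggml_allocr_new_measure(tensor_alignment);
    struct ggml_cgraph * gf = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true);
    gf->order = (enum ggml_cgraph_eval_order) order;
    // ... build forward/backward graphs; the measure allocator only tracks sizes ...
    max_compute_size = std::max(max_compute_size,
                                ggml_allocr_max_size(alloc) + tensor_alignment);
    ggml_allocr_free(alloc);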
@@ -1801,11 +1800,11 @@ int main(int argc, char ** argv) {
mem_compute_data.resize(max_compute_size);
ctx_compute = ggml_init(ctx_compute_params);
alloc = ggml_allocr_new(mem_compute_data.data(), mem_compute_data.size(), tensor_alignment);
- gf = ggml_new_graph(ctx_compute);
+ gf = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true);
gf->order = best_order;
- gb = ggml_new_graph(ctx_compute);
+ gb = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true);
gb_tmp = params.common.use_checkpointing
- ? ggml_new_graph(ctx_compute)
+ ? ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true)
: NULL;
loss = llama_build_lora_finetune_graphs(
&model, &lora, alloc, ctx_compute,
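
Note: this second block is the real pass and must construct the graphs exactly as the measure pass did (same capacity, same grads flag, best_order as the eval order); the measured max_compute_size is only a valid upper bound if graph construction is identical. Sketch, mirroring the lines above:

    mem_compute_data.resize(max_compute_size); // buffer sized from the measure pass
    alloc = ggml_allocr_new(mem_compute_data.data(), mem_compute_data.size(),
                            tensor_alignment);
    gf = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true);
    gf->order = best_order;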