summaryrefslogtreecommitdiff
path: root/tests/test-grad0.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'tests/test-grad0.cpp')
-rw-r--r--tests/test-grad0.cpp7
1 files changed, 4 insertions, 3 deletions
diff --git a/tests/test-grad0.cpp b/tests/test-grad0.cpp
index 0a559b27..7fe9154d 100644
--- a/tests/test-grad0.cpp
+++ b/tests/test-grad0.cpp
@@ -231,9 +231,10 @@ static bool check_gradient(
printf("GGML_N_THREADS = %d\n", n_threads);
}
- struct ggml_cgraph * gf = ggml_build_forward_ctx(ctx0, f);
- struct ggml_cgraph * gb = ggml_new_graph(ctx0);
- *gb = *gf;
+ struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, GGML_DEFAULT_GRAPH_SIZE, true);
+ struct ggml_cgraph * gb = ggml_new_graph_custom(ctx0, GGML_DEFAULT_GRAPH_SIZE, true);
+ ggml_build_forward_expand(gf, f);
+ ggml_graph_cpy(gf, gb);
ggml_build_backward_expand(ctx0, gf, gb, false);
ggml_graph_compute_with_ctx(ctx0, gf, n_threads);