summaryrefslogtreecommitdiff
path: root/common
diff options
context:
space:
mode:
Diffstat (limited to 'common')
-rw-r--r--common/train.cpp1
-rw-r--r--common/train.h2
2 files changed, 3 insertions, 0 deletions
diff --git a/common/train.cpp b/common/train.cpp
index bc15b7a0..964b156b 100644
--- a/common/train.cpp
+++ b/common/train.cpp
@@ -32,6 +32,7 @@ struct train_state * init_train_state() {
state->opt = new struct ggml_opt_context;
state->opt->ctx = NULL;
state->opt->params = ggml_opt_default_params(GGML_OPT_ADAM);
+ state->opt->params.graph_size = LLAMA_TRAIN_MAX_NODES;
state->opt->loss_after = 0.0f;
return state;
diff --git a/common/train.h b/common/train.h
index d86c93cc..263d940c 100644
--- a/common/train.h
+++ b/common/train.h
@@ -9,6 +9,8 @@
#include "ggml.h"
#include "llama.h"
+#define LLAMA_TRAIN_MAX_NODES 16384
+
typedef std::string mt19937_state;
struct train_state {