diff options
Diffstat (limited to 'common')
-rw-r--r-- | common/train.cpp | 1 | ||||
-rw-r--r-- | common/train.h | 2 |
2 files changed, 3 insertions, 0 deletions
diff --git a/common/train.cpp b/common/train.cpp index bc15b7a0..964b156b 100644 --- a/common/train.cpp +++ b/common/train.cpp @@ -32,6 +32,7 @@ struct train_state * init_train_state() { state->opt = new struct ggml_opt_context; state->opt->ctx = NULL; state->opt->params = ggml_opt_default_params(GGML_OPT_ADAM); + state->opt->params.graph_size = LLAMA_TRAIN_MAX_NODES; state->opt->loss_after = 0.0f; return state; diff --git a/common/train.h b/common/train.h index d86c93cc..263d940c 100644 --- a/common/train.h +++ b/common/train.h @@ -9,6 +9,8 @@ #include "ggml.h" #include "llama.h" +#define LLAMA_TRAIN_MAX_NODES 16384 + typedef std::string mt19937_state; struct train_state { |