| author    | Didzis Gosko <didzis@users.noreply.github.com> | 2023-06-24 11:47:58 +0300 |
|-----------|------------------------------------------------|---------------------------|
| committer | GitHub <noreply@github.com> | 2023-06-24 11:47:58 +0300 |
| commit    | 527b6fba1d237befb324fd846bda7418c0fa394d (patch) | |
| tree      | 360b44abac0c9a53739444b8ba9e4ccf903938cd /examples/common.cpp | |
| parent    | d7b7484f74d486f77feb4c0b7af7e1718ed91651 (diff) | |
llama : make model stateless and context stateful (llama_state) (#1797)
* llama : make model stateless and context stateful
* llama : minor cleanup
* llama : update internal API declaration
* Apply suggestions from code review (fix style)
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
* Missing model memory release
* Fix style
* Add deprecated warning for public API function llama_init_from_file
* Update public API use cases: move away from deprecated llama_init_from_file
* Deprecate public API function llama_apply_lora_from_file
---------
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
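The net effect on callers: initialization splits into a model load plus a context attach, and teardown frees both handles. A minimal caller-side sketch, assuming only the llama.h functions named in this commit ("model.bin" is a placeholder path):

```cpp
#include "llama.h"

#include <cstdio>

int main() {
    llama_context_params lparams = llama_context_default_params();

    // before this commit, a single (now deprecated) call owned both weights and state:
    //     llama_context * ctx = llama_init_from_file("model.bin", lparams);

    // after: the stateless model is loaded once ...
    llama_model * model = llama_load_model_from_file("model.bin", lparams);
    if (model == NULL) {
        fprintf(stderr, "failed to load model\n");
        return 1;
    }

    // ... and a stateful context (KV cache, logits) is created on top of it
    llama_context * ctx = llama_new_context_with_model(model, lparams);
    if (ctx == NULL) {
        llama_free_model(model);
        return 1;
    }

    // evaluation / sampling would happen here

    // free the context before the model that backs it
    llama_free(ctx);
    llama_free_model(model);
    return 0;
}
```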
Diffstat (limited to 'examples/common.cpp')
-rw-r--r-- | examples/common.cpp | 22 |
1 file changed, 15 insertions, 7 deletions
```diff
diff --git a/examples/common.cpp b/examples/common.cpp
index fed24e02..6ac48455 100644
--- a/examples/common.cpp
+++ b/examples/common.cpp
@@ -536,7 +536,7 @@ std::vector<llama_token> llama_tokenize(struct llama_context * ctx, const std::s
     return res;
 }
 
-struct llama_context * llama_init_from_gpt_params(const gpt_params & params) {
+std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_params(const gpt_params & params) {
     auto lparams = llama_context_default_params();
 
     lparams.n_ctx = params.n_ctx;
@@ -552,25 +552,33 @@ struct llama_context * llama_init_from_gpt_params(const gpt_params & params) {
     lparams.logits_all = params.perplexity;
     lparams.embedding = params.embedding;
 
-    llama_context * lctx = llama_init_from_file(params.model.c_str(), lparams);
+    llama_model * model = llama_load_model_from_file(params.model.c_str(), lparams);
+    if (model == NULL) {
+        fprintf(stderr, "%s: error: failed to load model '%s'\n", __func__, params.model.c_str());
+        return std::make_tuple(nullptr, nullptr);
+    }
 
+    llama_context * lctx = llama_new_context_with_model(model, lparams);
     if (lctx == NULL) {
-        fprintf(stderr, "%s: error: failed to load model '%s'\n", __func__, params.model.c_str());
-        return NULL;
+        fprintf(stderr, "%s: error: failed to create context with model '%s'\n", __func__, params.model.c_str());
+        llama_free_model(model);
+        return std::make_tuple(nullptr, nullptr);
     }
 
     if (!params.lora_adapter.empty()) {
-        int err = llama_apply_lora_from_file(lctx,
+        int err = llama_model_apply_lora_from_file(model,
                                              params.lora_adapter.c_str(),
                                              params.lora_base.empty() ? NULL : params.lora_base.c_str(),
                                              params.n_threads);
         if (err != 0) {
             fprintf(stderr, "%s: error: failed to apply lora adapter\n", __func__);
-            return NULL;
+            llama_free(lctx);
+            llama_free_model(model);
+            return std::make_tuple(nullptr, nullptr);
         }
     }
 
-    return lctx;
+    return std::make_tuple(model, lctx);
 }
 
 void console_init(console_state & con_st) {
```
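For reference, callers of the updated helper now unpack the (model, context) pair instead of receiving a bare context, and must release both handles. A sketch of the consuming side, assuming the common.h declarations from this diff; run_example() is a hypothetical stand-in for an example's main loop:

```cpp
#include "common.h"
#include "llama.h"

#include <cstdio>
#include <tuple>

// hypothetical wrapper standing in for an example's main() body
static int run_example(const gpt_params & params) {
    llama_model * model;
    llama_context * ctx;

    // the helper now returns both handles; (nullptr, nullptr) signals failure
    std::tie(model, ctx) = llama_init_from_gpt_params(params);
    if (model == NULL) {
        fprintf(stderr, "error: unable to load model\n");
        return 1;
    }

    // ... tokenize / evaluate / sample using ctx ...

    // teardown mirrors initialization: context first, then the model
    llama_free(ctx);
    llama_free_model(model);
    return 0;
}
```

Returning the pair keeps ownership explicit: the context borrows the model, which is why every error path in the diff frees whatever was already allocated before bailing out.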