From 527b6fba1d237befb324fd846bda7418c0fa394d Mon Sep 17 00:00:00 2001 From: Didzis Gosko Date: Sat, 24 Jun 2023 11:47:58 +0300 Subject: llama : make model stateless and context stateful (llama_state) (#1797) * llama : make model stateless and context stateful * llama : minor cleanup * llama : update internal API declaration * Apply suggestions from code review fix style Co-authored-by: Georgi Gerganov * Missing model memory release * Fix style * Add deprecated warning for public API function llama_init_from_file * Update public API use cases: move away from deprecated llama_init_from_file * Deprecate public API function llama_apply_lora_from_file --------- Co-authored-by: Georgi Gerganov --- examples/common.cpp | 22 +++++++++++++++------- 1 file changed, 15 insertions(+), 7 deletions(-) (limited to 'examples/common.cpp') diff --git a/examples/common.cpp b/examples/common.cpp index fed24e02..6ac48455 100644 --- a/examples/common.cpp +++ b/examples/common.cpp @@ -536,7 +536,7 @@ std::vector llama_tokenize(struct llama_context * ctx, const std::s return res; } -struct llama_context * llama_init_from_gpt_params(const gpt_params & params) { +std::tuple llama_init_from_gpt_params(const gpt_params & params) { auto lparams = llama_context_default_params(); lparams.n_ctx = params.n_ctx; @@ -552,25 +552,33 @@ struct llama_context * llama_init_from_gpt_params(const gpt_params & params) { lparams.logits_all = params.perplexity; lparams.embedding = params.embedding; - llama_context * lctx = llama_init_from_file(params.model.c_str(), lparams); + llama_model * model = llama_load_model_from_file(params.model.c_str(), lparams); + if (model == NULL) { + fprintf(stderr, "%s: error: failed to load model '%s'\n", __func__, params.model.c_str()); + return std::make_tuple(nullptr, nullptr); + } + llama_context * lctx = llama_new_context_with_model(model, lparams); if (lctx == NULL) { - fprintf(stderr, "%s: error: failed to load model '%s'\n", __func__, params.model.c_str()); - return NULL; + fprintf(stderr, "%s: error: failed to create context with model '%s'\n", __func__, params.model.c_str()); + llama_free_model(model); + return std::make_tuple(nullptr, nullptr); } if (!params.lora_adapter.empty()) { - int err = llama_apply_lora_from_file(lctx, + int err = llama_model_apply_lora_from_file(model, params.lora_adapter.c_str(), params.lora_base.empty() ? NULL : params.lora_base.c_str(), params.n_threads); if (err != 0) { fprintf(stderr, "%s: error: failed to apply lora adapter\n", __func__); - return NULL; + llama_free(lctx); + llama_free_model(model); + return std::make_tuple(nullptr, nullptr); } } - return lctx; + return std::make_tuple(model, lctx); } void console_init(console_state & con_st) { -- cgit v1.2.3