author	Didzis Gosko <didzis@users.noreply.github.com>	2023-06-24 11:47:58 +0300
committer	GitHub <noreply@github.com>	2023-06-24 11:47:58 +0300
commit	527b6fba1d237befb324fd846bda7418c0fa394d (patch)
tree	360b44abac0c9a53739444b8ba9e4ccf903938cd /examples/common.cpp
parent	d7b7484f74d486f77feb4c0b7af7e1718ed91651 (diff)
llama : make model stateless and context stateful (llama_state) (#1797)
* llama : make model stateless and context stateful
* llama : minor cleanup
* llama : update internal API declaration
* Apply suggestions from code review (fix style)
* Missing model memory release
* Fix style
* Add deprecated warning for public API function llama_init_from_file
* Update public API use cases: move away from deprecated llama_init_from_file
* Deprecate public API function llama_apply_lora_from_file

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
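For API consumers, the practical effect is that the old single-call initialization is split into a model-load step and a context-creation step. A minimal migration sketch, assuming the llama.h declarations introduced by this commit (the model path and error handling are placeholders):

    // before (now deprecated): one call loaded the model and created a context
    llama_context * ctx = llama_init_from_file(path, lparams);

    // after: load the stateless model once, then create a stateful context on top of it
    llama_model   * model = llama_load_model_from_file(path, lparams);
    llama_context * ctx   = llama_new_context_with_model(model, lparams);

    // teardown in reverse order: contexts first, then the model they reference
    llama_free(ctx);
    llama_free_model(model);

Separating the two is what allows a single loaded model to back more than one context, which is the point of making the model stateless.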
Diffstat (limited to 'examples/common.cpp')
-rw-r--r--	examples/common.cpp	22
1 file changed, 15 insertions(+), 7 deletions(-)
diff --git a/examples/common.cpp b/examples/common.cpp
index fed24e02..6ac48455 100644
--- a/examples/common.cpp
+++ b/examples/common.cpp
@@ -536,7 +536,7 @@ std::vector<llama_token> llama_tokenize(struct llama_context * ctx, const std::s
     return res;
 }
 
-struct llama_context * llama_init_from_gpt_params(const gpt_params & params) {
+std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_params(const gpt_params & params) {
     auto lparams = llama_context_default_params();
 
     lparams.n_ctx        = params.n_ctx;
@@ -552,25 +552,33 @@ struct llama_context * llama_init_from_gpt_params(const gpt_params & params) {
     lparams.logits_all   = params.perplexity;
     lparams.embedding    = params.embedding;
 
-    llama_context * lctx = llama_init_from_file(params.model.c_str(), lparams);
+    llama_model * model  = llama_load_model_from_file(params.model.c_str(), lparams);
+    if (model == NULL) {
+        fprintf(stderr, "%s: error: failed to load model '%s'\n", __func__, params.model.c_str());
+        return std::make_tuple(nullptr, nullptr);
+    }
 
+    llama_context * lctx = llama_new_context_with_model(model, lparams);
     if (lctx == NULL) {
-        fprintf(stderr, "%s: error: failed to load model '%s'\n", __func__, params.model.c_str());
-        return NULL;
+        fprintf(stderr, "%s: error: failed to create context with model '%s'\n", __func__, params.model.c_str());
+        llama_free_model(model);
+        return std::make_tuple(nullptr, nullptr);
     }
 
     if (!params.lora_adapter.empty()) {
-        int err = llama_apply_lora_from_file(lctx,
+        int err = llama_model_apply_lora_from_file(model,
                                              params.lora_adapter.c_str(),
                                              params.lora_base.empty() ? NULL : params.lora_base.c_str(),
                                              params.n_threads);
         if (err != 0) {
             fprintf(stderr, "%s: error: failed to apply lora adapter\n", __func__);
-            return NULL;
+            llama_free(lctx);
+            llama_free_model(model);
+            return std::make_tuple(nullptr, nullptr);
         }
     }
 
-    return lctx;
+    return std::make_tuple(model, lctx);
 }
 
 void console_init(console_state & con_st) {
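Since llama_init_from_gpt_params now returns a (model, context) tuple, every caller must unpack both pointers and free both on shutdown. A sketch of the caller side, assuming C++11 and std::tie from <tuple>, in the style of the example programs this commit updates:

    llama_model * model;
    llama_context * ctx;

    // unpack the tuple returned by the helper; both are null on failure
    std::tie(model, ctx) = llama_init_from_gpt_params(params);
    if (model == NULL) {
        fprintf(stderr, "%s: error: unable to load model\n", __func__);
        return 1;
    }

    // ... run inference through ctx ...

    llama_free(ctx);
    llama_free_model(model);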