summaryrefslogtreecommitdiff
path: root/examples/simple/simple.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'examples/simple/simple.cpp')
-rw-r--r--examples/simple/simple.cpp24
1 files changed, 17 insertions, 7 deletions
diff --git a/examples/simple/simple.cpp b/examples/simple/simple.cpp
index 1616a4a7..24fb16b7 100644
--- a/examples/simple/simple.cpp
+++ b/examples/simple/simple.cpp
@@ -33,18 +33,28 @@ int main(int argc, char ** argv) {
llama_backend_init(params.numa);
- llama_context_params ctx_params = llama_context_default_params();
+ // initialize the model
- ctx_params.seed = 1234;
- ctx_params.n_ctx = 2048;
+ llama_model_params model_params = llama_model_default_params();
- llama_model * model = llama_load_model_from_file(params.model.c_str(), ctx_params);
+ // model_params.n_gpu_layers = 99; // offload all layers to the GPU
+
+ llama_model * model = llama_load_model_from_file(params.model.c_str(), model_params);
if (model == NULL) {
fprintf(stderr , "%s: error: unable to load model\n" , __func__);
return 1;
}
+ // initialize the context
+
+ llama_context_params ctx_params = llama_context_default_params();
+
+ ctx_params.seed = 1234;
+ ctx_params.n_ctx = 2048;
+ ctx_params.n_threads = params.n_threads;
+ ctx_params.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch;
+
llama_context * ctx = llama_new_context_with_model(model, ctx_params);
if (ctx == NULL) {
@@ -97,7 +107,7 @@ int main(int argc, char ** argv) {
// llama_decode will output logits only for the last token of the prompt
batch.logits[batch.n_tokens - 1] = true;
- if (llama_decode(ctx, batch, params.n_threads) != 0) {
+ if (llama_decode(ctx, batch) != 0) {
LOG_TEE("%s: llama_decode() failed\n", __func__);
return 1;
}
@@ -112,7 +122,7 @@ int main(int argc, char ** argv) {
while (n_cur <= n_len) {
// sample the next token
{
- auto n_vocab = llama_n_vocab(ctx);
+ auto n_vocab = llama_n_vocab(model);
auto * logits = llama_get_logits_ith(ctx, batch.n_tokens - 1);
std::vector<llama_token_data> candidates;
@@ -154,7 +164,7 @@ int main(int argc, char ** argv) {
n_cur += 1;
// evaluate the current batch with the transformer model
- if (llama_decode(ctx, batch, params.n_threads)) {
+ if (llama_decode(ctx, batch)) {
fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1);
return 1;
}