Diffstat (limited to 'examples/simple')
 examples/simple/simple.cpp | 24 +++++++++++++++++-------
 1 file changed, 17 insertions(+), 7 deletions(-)
diff --git a/examples/simple/simple.cpp b/examples/simple/simple.cpp
index 1616a4a7..24fb16b7 100644
--- a/examples/simple/simple.cpp
+++ b/examples/simple/simple.cpp
@@ -33,18 +33,28 @@ int main(int argc, char ** argv) {
 
     llama_backend_init(params.numa);
 
-    llama_context_params ctx_params = llama_context_default_params();
+    // initialize the model
 
-    ctx_params.seed = 1234;
-    ctx_params.n_ctx = 2048;
+    llama_model_params model_params = llama_model_default_params();
 
-    llama_model * model = llama_load_model_from_file(params.model.c_str(), ctx_params);
+    // model_params.n_gpu_layers = 99; // offload all layers to the GPU
+
+    llama_model * model = llama_load_model_from_file(params.model.c_str(), model_params);
 
     if (model == NULL) {
         fprintf(stderr , "%s: error: unable to load model\n" , __func__);
         return 1;
     }
 
+    // initialize the context
+
+    llama_context_params ctx_params = llama_context_default_params();
+
+    ctx_params.seed = 1234;
+    ctx_params.n_ctx = 2048;
+    ctx_params.n_threads = params.n_threads;
+    ctx_params.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch;
+
     llama_context * ctx = llama_new_context_with_model(model, ctx_params);
 
     if (ctx == NULL) {
@@ -97,7 +107,7 @@ int main(int argc, char ** argv) {
     // llama_decode will output logits only for the last token of the prompt
     batch.logits[batch.n_tokens - 1] = true;
 
-    if (llama_decode(ctx, batch, params.n_threads) != 0) {
+    if (llama_decode(ctx, batch) != 0) {
         LOG_TEE("%s: llama_decode() failed\n", __func__);
         return 1;
     }
@@ -112,7 +122,7 @@ int main(int argc, char ** argv) {
     while (n_cur <= n_len) {
         // sample the next token
         {
-            auto n_vocab = llama_n_vocab(ctx);
+            auto n_vocab = llama_n_vocab(model);
             auto * logits = llama_get_logits_ith(ctx, batch.n_tokens - 1);
 
             std::vector<llama_token_data> candidates;
@@ -154,7 +164,7 @@ int main(int argc, char ** argv) {
         n_cur += 1;
 
         // evaluate the current batch with the transformer model
-        if (llama_decode(ctx, batch, params.n_threads)) {
+        if (llama_decode(ctx, batch)) {
            fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1);
            return 1;
        }
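
The net effect of the patch: model loading and context creation now take separate parameter structs (llama_model_params vs. llama_context_params), the thread counts move into the context params, and llama_decode no longer takes a thread count. Below is a minimal sketch of the resulting initialization order, not part of the commit: error handling, tokenization, batching, and the sampling loop from simple.cpp are elided, and the cleanup calls (llama_free, llama_free_model, llama_backend_free) are assumed from the llama.h of this revision.

// minimal sketch: model/context split after this change (assumptions noted above)
#include "llama.h"

#include <cstdio>

int main(int argc, char ** argv) {
    if (argc < 2) {
        fprintf(stderr, "usage: %s MODEL_PATH\n", argv[0]);
        return 1;
    }

    llama_backend_init(false); // numa = false

    // model-level settings (weights, GPU offload) come from llama_model_params
    llama_model_params model_params = llama_model_default_params();
    // model_params.n_gpu_layers = 99; // offload all layers to the GPU

    llama_model * model = llama_load_model_from_file(argv[1], model_params);
    if (model == NULL) {
        fprintf(stderr, "failed to load model\n");
        return 1;
    }

    // per-context settings (seed, context size, threads) stay in llama_context_params
    llama_context_params ctx_params = llama_context_default_params();
    ctx_params.seed      = 1234;
    ctx_params.n_ctx     = 2048;
    ctx_params.n_threads = 4;

    llama_context * ctx = llama_new_context_with_model(model, ctx_params);
    if (ctx == NULL) {
        llama_free_model(model);
        return 1;
    }

    // ... fill a llama_batch and evaluate it; note the new signature:
    // llama_decode(ctx, batch) -- the thread count is no longer passed here.

    llama_free(ctx);
    llama_free_model(model);
    llama_backend_free();
    return 0;
}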