author     slaren <slarengh@gmail.com>   2023-09-28 21:42:38 +0200
committer  GitHub <noreply@github.com>   2023-09-28 22:42:38 +0300
commit     16bc66d9479edd5ee12ec734973554d4493c5dfa (patch)
tree       4cca787ebd86dd55fd176d27112117c74e9b34c6 /examples/batched
parent     0512d66670de3f650c579519833c085014b0f200 (diff)
llama.cpp : split llama_context_params into model and context params (#3301)
* llama.cpp : split llama_context_params into model and context params (ggml-ci)
* fix metal build
* fix freq_base/scale default to model value
* llama-bench : keep the same model between tests when possible
* move n_threads to llama_context_params, add n_threads_batch
* fix mpi build
* remove kv_size(), cuda scratch fixes
* remove low-vram option
* add n_threads_batch to system info, refactor to get_system_info()
* add documentation about --threads-batch to the READMEs
* llama-bench fix
* main : fix rope freq/scale warning
* llama.cpp : add llama_get_model
* common : add llama_tokenize from model
* remove duplicated ctx/model functions (ggml-ci)
* cuda : print total VRAM used
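
As a rough illustration of the split described above (a minimal sketch assuming the post-commit llama.h API; the model path and parameter values are placeholders, and the batch/decoding loop is omitted):

    #include "llama.h"
    #include <cstdio>

    int main() {
        llama_backend_init(false /* numa */);

        // model-level settings now live in llama_model_params
        llama_model_params model_params = llama_model_default_params();
        // model_params.n_gpu_layers = 99; // offload all layers to the GPU

        llama_model * model = llama_load_model_from_file("model.gguf", model_params);
        if (model == NULL) {
            fprintf(stderr, "failed to load model\n");
            return 1;
        }

        // per-context settings stay in llama_context_params;
        // n_threads_batch is the new, separate thread count for prompt/batch processing
        llama_context_params ctx_params = llama_context_default_params();
        ctx_params.seed            = 1234;
        ctx_params.n_ctx           = 512;
        ctx_params.n_batch         = 512;
        ctx_params.n_threads       = 4;
        ctx_params.n_threads_batch = 4;

        llama_context * ctx = llama_new_context_with_model(model, ctx_params);
        if (ctx == NULL) {
            return 1;
        }

        // ... build a llama_batch and call llama_decode(ctx, batch) — the thread
        // count argument is gone, threads now come from ctx_params ...

        llama_free(ctx);
        llama_free_model(model);
        llama_backend_free();
        return 0;
    }

The diff below shows the corresponding migration in examples/batched/batched.cpp.
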
Diffstat (limited to 'examples/batched')
-rw-r--r--  examples/batched/batched.cpp  39
1 file changed, 24 insertions, 15 deletions
diff --git a/examples/batched/batched.cpp b/examples/batched/batched.cpp
index 4dd1d553..688ef221 100644
--- a/examples/batched/batched.cpp
+++ b/examples/batched/batched.cpp
@@ -40,20 +40,35 @@ int main(int argc, char ** argv) {
llama_backend_init(params.numa);
- llama_context_params ctx_params = llama_context_default_params();
+ // initialize the model
- ctx_params.seed = 1234;
- ctx_params.n_ctx = n_len*n_parallel; // FIXME: use n_kv_req instead (tokenize with model after #3301)
- ctx_params.n_batch = std::max(n_len, n_parallel);
- // ctx_params.n_gpu_layers = 99; // offload all layers to the GPU
+ llama_model_params model_params = llama_model_default_params();
- llama_model * model = llama_load_model_from_file(params.model.c_str(), ctx_params);
+ // model_params.n_gpu_layers = 99; // offload all layers to the GPU
+
+ llama_model * model = llama_load_model_from_file(params.model.c_str(), model_params);
if (model == NULL) {
fprintf(stderr , "%s: error: unable to load model\n" , __func__);
return 1;
}
+ // tokenize the prompt
+
+ std::vector<llama_token> tokens_list;
+ tokens_list = ::llama_tokenize(model, params.prompt, true);
+ const int n_kv_req = tokens_list.size() + (n_len - tokens_list.size())*n_parallel;
+
+ // initialize the context
+
+ llama_context_params ctx_params = llama_context_default_params();
+
+ ctx_params.seed = 1234;
+ ctx_params.n_ctx = n_kv_req;
+ ctx_params.n_batch = std::max(n_len, n_parallel);
+ ctx_params.n_threads = params.n_threads;
+ ctx_params.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch;
+
llama_context * ctx = llama_new_context_with_model(model, ctx_params);
if (ctx == NULL) {
@@ -61,13 +76,7 @@ int main(int argc, char ** argv) {
return 1;
}
- // tokenize the prompt
-
- std::vector<llama_token> tokens_list;
- tokens_list = ::llama_tokenize(ctx, params.prompt, true);
-
const int n_ctx = llama_n_ctx(ctx);
- const int n_kv_req = tokens_list.size() + (n_len - tokens_list.size())*n_parallel;
LOG_TEE("\n%s: n_len = %d, n_ctx = %d, n_batch = %d, n_parallel = %d, n_kv_req = %d\n", __func__, n_len, n_ctx, ctx_params.n_batch, n_parallel, n_kv_req);
@@ -106,7 +115,7 @@ int main(int argc, char ** argv) {
// llama_decode will output logits only for the last token of the prompt
batch.logits[batch.n_tokens - 1] = true;
- if (llama_decode(ctx, batch, params.n_threads) != 0) {
+ if (llama_decode(ctx, batch) != 0) {
LOG_TEE("%s: llama_decode() failed\n", __func__);
return 1;
}
@@ -146,7 +155,7 @@ int main(int argc, char ** argv) {
continue;
}
- auto n_vocab = llama_n_vocab(ctx);
+ auto n_vocab = llama_n_vocab(model);
auto * logits = llama_get_logits_ith(ctx, i_batch[i]);
std::vector<llama_token_data> candidates;
@@ -210,7 +219,7 @@ int main(int argc, char ** argv) {
n_cur += 1;
// evaluate the current batch with the transformer model
- if (llama_decode(ctx, batch, params.n_threads)) {
+ if (llama_decode(ctx, batch)) {
fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1);
return 1;
}
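
For context on the n_ctx change above, a small worked example with illustrative numbers (not taken from the diff, only the formula is): with a 32-token prompt, n_len = 128 and n_parallel = 4, the parallel sequences share the prompt in the KV cache, so the required cache size is smaller than the old n_len*n_parallel estimate:

    const int n_prompt   = 32;   // tokens_list.size(), hypothetical
    const int n_len      = 128;
    const int n_parallel = 4;

    // new sizing: prompt stored once, only the generated tail is per-sequence
    const int n_kv_req  = n_prompt + (n_len - n_prompt)*n_parallel; // 32 + 96*4 = 416

    // previous sizing reserved the full length for every sequence
    const int n_ctx_old = n_len*n_parallel;                         // 128*4     = 512

Being able to tokenize with just the model (before the context exists) is what lets the example compute n_kv_req up front and resolve the earlier FIXME referencing #3301.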