path: root/examples/embedding/embedding.cpp
author    slaren <slarengh@gmail.com>    2023-09-28 21:42:38 +0200
committer GitHub <noreply@github.com>   2023-09-28 22:42:38 +0300
commit    16bc66d9479edd5ee12ec734973554d4493c5dfa (patch)
tree      4cca787ebd86dd55fd176d27112117c74e9b34c6 /examples/embedding/embedding.cpp
parent    0512d66670de3f650c579519833c085014b0f200 (diff)
llama.cpp : split llama_context_params into model and context params (#3301)
* llama.cpp : split llama_context_params into model and context params ggml-ci
* fix metal build
* fix freq_base/scale default to model value
* llama-bench : keep the same model between tests when possible
* move n_threads to llama_context_params, add n_threads_batch
* fix mpi build
* remove kv_size(), cuda scratch fixes
* remove low-vram option
* add n_threads_batch to system info, refactor to get_system_info()
* add documentation about --threads-batch to the READMEs
* llama-bench fix
* main : fix rope freq/scale warning
* llama.cpp : add llama_get_model
* common : add llama_tokenize from model
* remove duplicated ctx/model functions ggml-ci
* cuda : print total VRAM used
Diffstat (limited to 'examples/embedding/embedding.cpp')
-rw-r--r--  examples/embedding/embedding.cpp  21
1 file changed, 11 insertions, 10 deletions
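
The core of this change is the split of the old llama_context_params into a model-level and a context-level struct. Below is a minimal sketch (not part of the diff) of how the new API is meant to be used; field names follow llama.h around this commit, while the model path and all values are placeholders.

    // Sketch only: illustrates the model/context split introduced by this commit.
    #include "llama.h"

    int load_example() {
        llama_model_params mparams = llama_model_default_params();
        mparams.n_gpu_layers = 0;                    // model-level option

        llama_context_params cparams = llama_context_default_params();
        cparams.n_ctx           = 2048;              // runtime context size
        cparams.n_threads       = 8;                 // generation threads
        cparams.n_threads_batch = 8;                 // batch/prompt threads (new in this commit)
        cparams.embedding       = true;              // required for llama_get_embeddings()

        llama_model * model = llama_load_model_from_file("model.gguf", mparams);
        if (model == NULL) {
            return 1;
        }

        llama_context * ctx = llama_new_context_with_model(model, cparams);
        if (ctx == NULL) {
            llama_free_model(model);
            return 1;
        }

        // training context size is a model property, runtime size a context property
        const int n_ctx_train = llama_n_ctx_train(model);
        const int n_ctx       = llama_n_ctx(ctx);
        (void) n_ctx_train;
        (void) n_ctx;

        llama_free(ctx);
        llama_free_model(model);
        return 0;
    }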
diff --git a/examples/embedding/embedding.cpp b/examples/embedding/embedding.cpp
index 18cefa23..14075609 100644
--- a/examples/embedding/embedding.cpp
+++ b/examples/embedding/embedding.cpp
@@ -42,17 +42,18 @@ int main(int argc, char ** argv) {
return 1;
}
- const int n_ctx_train = llama_n_ctx_train(ctx);
- if (params.n_ctx > n_ctx_train) {
+ const int n_ctx_train = llama_n_ctx_train(model);
+ const int n_ctx = llama_n_ctx(ctx);
+
+ if (n_ctx > n_ctx_train) {
fprintf(stderr, "%s: warning: model was trained on only %d context tokens (%d specified)\n",
- __func__, n_ctx_train, params.n_ctx);
+ __func__, n_ctx_train, n_ctx);
}
// print system information
{
fprintf(stderr, "\n");
- fprintf(stderr, "system_info: n_threads = %d / %d | %s\n",
- params.n_threads, std::thread::hardware_concurrency(), llama_print_system_info());
+ fprintf(stderr, "%s\n", get_system_info(params).c_str());
}
int n_past = 0;
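
The system-info print is now delegated to a common helper that, per the commit message, also reports n_threads_batch. A rough sketch of what such a helper assembles is shown below; the real implementation lives in common.cpp and takes gpt_params, so the stand-in struct here is only for illustration and the exact output format may differ.

    #include <sstream>
    #include <string>
    #include <thread>
    #include "llama.h"

    // Stand-in for the relevant gpt_params fields (illustration only).
    struct params_view {
        int n_threads       = 8;
        int n_threads_batch = -1;   // -1 means: reuse n_threads for batches
    };

    static std::string get_system_info_sketch(const params_view & params) {
        std::ostringstream os;
        os << "system_info: n_threads = " << params.n_threads;
        if (params.n_threads_batch != -1) {
            os << " (n_threads_batch = " << params.n_threads_batch << ")";
        }
        os << " / " << std::thread::hardware_concurrency()
           << " | " << llama_print_system_info();
        return os.str();
    }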
@@ -70,15 +71,15 @@ int main(int argc, char ** argv) {
fprintf(stderr, "\n");
}
- if (embd_inp.size() > (size_t)params.n_ctx) {
+ if (embd_inp.size() > (size_t)n_ctx) {
fprintf(stderr, "%s: error: prompt is longer than the context window (%zu tokens, n_ctx = %d)\n",
- __func__, embd_inp.size(), params.n_ctx);
+ __func__, embd_inp.size(), n_ctx);
return 1;
}
while (!embd_inp.empty()) {
int n_tokens = std::min(params.n_batch, (int) embd_inp.size());
- if (llama_decode(ctx, llama_batch_get_one(embd_inp.data(), n_tokens, n_past, 0), params.n_threads)) {
+ if (llama_decode(ctx, llama_batch_get_one(embd_inp.data(), n_tokens, n_past, 0))) {
fprintf(stderr, "%s : failed to eval\n", __func__);
return 1;
}
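
With the thread counts carried by the context, llama_decode now takes only the context and a batch. The same loop as in the hunk above, with explanatory comments added and the error message shortened:

    int n_past = 0;
    while (!embd_inp.empty()) {
        const int n_tokens = std::min(params.n_batch, (int) embd_inp.size());

        // llama_batch_get_one() wraps a plain token array into a llama_batch
        // starting at position n_past in sequence 0
        if (llama_decode(ctx, llama_batch_get_one(embd_inp.data(), n_tokens, n_past, 0))) {
            fprintf(stderr, "failed to eval\n");   // non-zero return means the eval failed
            return 1;
        }

        n_past += n_tokens;
        embd_inp.erase(embd_inp.begin(), embd_inp.begin() + n_tokens);
    }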
@@ -86,8 +87,8 @@ int main(int argc, char ** argv) {
embd_inp.erase(embd_inp.begin(), embd_inp.begin() + n_tokens);
}
- const int n_embd = llama_n_embd(ctx);
- const auto embeddings = llama_get_embeddings(ctx);
+ const int n_embd = llama_n_embd(model);
+ const auto * embeddings = llama_get_embeddings(ctx);
for (int i = 0; i < n_embd; i++) {
printf("%f ", embeddings[i]);