author | slaren <slarengh@gmail.com> | 2023-08-22 16:03:12 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-08-22 16:03:12 +0200 |
commit | 519c981f8b65ee6c87c2965539685ced0a17223b (patch) | |
tree | b8eafbffa237d87761b13dce52e245332d4d2275 /examples | |
parent | 1123f7fbdfb8012e46f05e903e6f675922916378 (diff) |
embedding : evaluate prompt in batches (#2713)
Diffstat (limited to 'examples')
-rw-r--r-- | examples/embedding/embedding.cpp | 31 |
1 files changed, 19 insertions, 12 deletions
diff --git a/examples/embedding/embedding.cpp b/examples/embedding/embedding.cpp
index 8788571c..38395c75 100644
--- a/examples/embedding/embedding.cpp
+++ b/examples/embedding/embedding.cpp
@@ -72,22 +72,29 @@ int main(int argc, char ** argv) {
         fprintf(stderr, "\n");
     }
 
-    if (params.embedding){
-        if (embd_inp.size() > 0) {
-            if (llama_eval(ctx, embd_inp.data(), embd_inp.size(), n_past, params.n_threads)) {
-                fprintf(stderr, "%s : failed to eval\n", __func__);
-                return 1;
-            }
+    if (embd_inp.size() > (size_t)params.n_ctx) {
+        fprintf(stderr, "%s: error: prompt is longer than the context window (%zu tokens, n_ctx = %d)\n",
+                __func__, embd_inp.size(), params.n_ctx);
+        return 1;
+    }
+
+    while (!embd_inp.empty()) {
+        int n_tokens = std::min(params.n_batch, (int) embd_inp.size());
+        if (llama_eval(ctx, embd_inp.data(), n_tokens, n_past, params.n_threads)) {
+            fprintf(stderr, "%s : failed to eval\n", __func__);
+            return 1;
         }
+        n_past += n_tokens;
+        embd_inp.erase(embd_inp.begin(), embd_inp.begin() + n_tokens);
+    }
 
-        const int n_embd = llama_n_embd(ctx);
-        const auto embeddings = llama_get_embeddings(ctx);
+    const int n_embd = llama_n_embd(ctx);
+    const auto embeddings = llama_get_embeddings(ctx);
 
-        for (int i = 0; i < n_embd; i++) {
-            printf("%f ", embeddings[i]);
-        }
-        printf("\n");
+    for (int i = 0; i < n_embd; i++) {
+        printf("%f ", embeddings[i]);
     }
+    printf("\n");
 
     llama_print_timings(ctx);
     llama_free(ctx);
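To see the batching logic from this commit in isolation, here is a minimal standalone sketch of the same loop. It compiles without llama.cpp: `fake_eval()` is a hypothetical stand-in for `llama_eval()` and the prompt length and token values are made up; only the names `embd_inp`, `n_batch`, and `n_past` mirror the example code.

```cpp
// Sketch of the batched prompt evaluation loop (not the llama.cpp API itself).
#include <algorithm>
#include <cstdio>
#include <vector>

// Hypothetical stand-in for llama_eval(): pretends to process n_tokens tokens
// that start at position n_past, and reports success.
static bool fake_eval(const int * tokens, int n_tokens, int n_past) {
    std::printf("eval: %d tokens at n_past = %d (first token id %d)\n",
                n_tokens, n_past, tokens[0]);
    return true;
}

int main() {
    std::vector<int> embd_inp(70, 1); // pretend the prompt tokenized to 70 tokens
    const int n_batch = 32;           // maximum number of tokens per eval call
    int n_past = 0;                   // tokens already seen by the model

    while (!embd_inp.empty()) {
        // never feed more than n_batch tokens at once
        int n_tokens = std::min(n_batch, (int) embd_inp.size());
        if (!fake_eval(embd_inp.data(), n_tokens, n_past)) {
            std::fprintf(stderr, "failed to eval\n");
            return 1;
        }
        n_past += n_tokens; // the next chunk continues where this one ended
        embd_inp.erase(embd_inp.begin(), embd_inp.begin() + n_tokens);
    }
    return 0;
}
```

Compared with the previous behaviour, where the whole prompt was passed to a single `llama_eval()` call, each call now handles at most `n_batch` tokens and `n_past` advances so later chunks are evaluated as a continuation of the earlier ones; the added `n_ctx` check in the commit rejects prompts that would not fit in the context window at all.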