Diffstat (limited to 'examples/embedding')
-rw-r--r--  examples/embedding/embedding.cpp | 142
1 files changed, 106 insertions, 36 deletions
diff --git a/examples/embedding/embedding.cpp b/examples/embedding/embedding.cpp
index 27376c8f..b4688cf5 100644
--- a/examples/embedding/embedding.cpp
+++ b/examples/embedding/embedding.cpp
@@ -7,6 +7,51 @@
#pragma warning(disable: 4244 4267) // possible loss of data
#endif
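+// split a multi-line prompt into individual lines; each line is embedded as a separate sequence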
+static std::vector<std::string> split_lines(const std::string & s) {
+ std::string line;
+ std::vector<std::string> lines;
+ std::stringstream ss(s);
+ while (std::getline(ss, line)) {
+ lines.push_back(line);
+ }
+ return lines;
+}
+
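+// append all tokens of one prompt to the batch under the given sequence id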
+static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & tokens, int seq_id) {
+ for (size_t i = 0; i < tokens.size(); i++) {
+ llama_batch_add(batch, tokens[i], i, { seq_id }, false);
+ }
+}
+
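+// write the L2-normalized copy of vec (length n) into out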
+static void normalize(float * vec, float * out, int n) {
+ float norm = 0;
+ for (int i = 0; i < n; i++) {
+ norm += vec[i] * vec[i];
+ }
+ norm = sqrt(norm);
+ for (int i = 0; i < n; i++) {
+ out[i] = vec[i] / norm;
+ }
+}
+
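+// run the model on the batch and store one normalized embedding per sequence in output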
+static void batch_decode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd) {
+ // clear previous kv_cache values (irrelevant for embeddings)
+ llama_kv_cache_clear(ctx);
+
+ // run model
+ fprintf(stderr, "%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq);
+ if (llama_decode(ctx, batch) < 0) {
+ fprintf(stderr, "%s : failed to decode\n", __func__);
+ }
+
+ // copy each sequence's embedding into the output buffer, L2-normalizing along the way
+ for (int k = 0; k < n_seq; k++) {
+ float * emb = llama_get_embeddings_ith(ctx, k);
+ float * out = output + k * n_embd;
+ normalize(emb, out, n_embd);
+ }
+}
+
int main(int argc, char ** argv) {
gpt_params params;
@@ -55,59 +100,84 @@ int main(int argc, char ** argv) {
fprintf(stderr, "%s\n", get_system_info(params).c_str());
}
- int n_past = 0;
+ // split the prompt into lines
+ std::vector<std::string> prompts = split_lines(params.prompt);
- // tokenize the prompt
- auto embd_inp = ::llama_tokenize(ctx, params.prompt, true);
+ // max batch size
+ const uint64_t n_batch = params.n_batch;
+ GGML_ASSERT(params.n_batch == params.n_ctx);
- if (params.verbose_prompt) {
- fprintf(stderr, "\n");
- fprintf(stderr, "%s: prompt: '%s'\n", __func__, params.prompt.c_str());
- fprintf(stderr, "%s: number of tokens in prompt = %zu\n", __func__, embd_inp.size());
- for (int i = 0; i < (int) embd_inp.size(); i++) {
- fprintf(stderr, "%6d -> '%s'\n", embd_inp[i], llama_token_to_piece(ctx, embd_inp[i]).c_str());
+ // tokenize the prompts and trim
+ std::vector<std::vector<int32_t>> inputs;
+ for (const auto & prompt : prompts) {
+ auto inp = ::llama_tokenize(ctx, prompt, true);
+ if (inp.size() > n_batch) {
+ inp.resize(n_batch);
}
- fprintf(stderr, "\n");
+ inputs.push_back(inp);
}
- if (embd_inp.size() > (size_t)n_ctx) {
- fprintf(stderr, "%s: error: prompt is longer than the context window (%zu tokens, n_ctx = %d)\n",
- __func__, embd_inp.size(), n_ctx);
- return 1;
- }
-
- while (!embd_inp.empty()) {
- int n_tokens = std::min(params.n_batch, (int) embd_inp.size());
- if (llama_decode(ctx, llama_batch_get_one(embd_inp.data(), n_tokens, n_past, 0))) {
- fprintf(stderr, "%s : failed to eval\n", __func__);
- return 1;
+ // tokenization stats
+ if (params.verbose_prompt) {
+ for (int i = 0; i < (int) inputs.size(); i++) {
+ fprintf(stderr, "%s: prompt %d: '%s'\n", __func__, i, prompts[i].c_str());
+ fprintf(stderr, "%s: number of tokens in prompt = %zu\n", __func__, inputs[i].size());
+ for (int j = 0; j < (int) inputs[i].size(); j++) {
+ fprintf(stderr, "%6d -> '%s'\n", inputs[i][j], llama_token_to_piece(ctx, inputs[i][j]).c_str());
+ }
+ fprintf(stderr, "\n\n");
}
- n_past += n_tokens;
- embd_inp.erase(embd_inp.begin(), embd_inp.begin() + n_tokens);
}
+ // initialize a single shared batch: room for n_batch tokens, token ids only (embd = 0), up to n_prompts sequence ids
+ const int n_prompts = prompts.size();
+ struct llama_batch batch = llama_batch_init(n_batch, 0, n_prompts);
+
+ // allocate output
const int n_embd = llama_n_embd(model);
- auto * embeddings = llama_get_embeddings(ctx);
+ std::vector<float> embeddings(n_prompts * n_embd, 0);
+ float * emb = embeddings.data();
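+ // the embedding for prompt j is written to emb + j * n_embd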
+
+ // break into batches
+ int p = 0; // number of prompts processed already
+ int s = 0; // number of prompts in current batch
+ for (int k = 0; k < n_prompts; k++) {
+ // this prompt's token count (already clamped to n_batch during tokenization)
+ auto & inp = inputs[k];
+ const uint64_t n_toks = inp.size();
+
+ // if this prompt would overflow the batch, decode what has been gathered so far
+ if (batch.n_tokens + n_toks > n_batch) {
+ float * out = emb + p * n_embd;
+ batch_decode(ctx, batch, out, s, n_embd);
+ llama_batch_clear(batch);
+ p += s;
+ s = 0;
+ }
- // l2-normalize embeddings
- float norm = 0;
- for (int i = 0; i < n_embd; i++) {
- norm += embeddings[i] * embeddings[i];
- }
- norm = sqrt(norm);
- for (int i = 0; i < n_embd; i++) {
- embeddings[i] /= norm;
+ // add to batch
+ batch_add_seq(batch, inp, s);
+ s += 1;
}
- for (int i = 0; i < n_embd; i++) {
- printf("%f ", embeddings[i]);
+ // final batch
+ float * out = emb + p * n_embd;
+ batch_decode(ctx, batch, out, s, n_embd);
+
+ // print first 3 embeddings
+ for (int j = 0; j < std::min(3, n_prompts); j++) {
+ fprintf(stderr, "embedding %d: ", j);
+ for (int i = 0; i < n_embd; i++) {
+ fprintf(stderr, "%f ", emb[j * n_embd + i]);
+ }
+ fprintf(stderr, "\n\n");
}
- printf("\n");
+ fprintf(stderr, "\n");
+ // clean up
llama_print_timings(ctx);
llama_free(ctx);
llama_free_model(model);
-
llama_backend_free();
return 0;