diff options
author | Douglas Hanley <thesecretaryofwar@gmail.com> | 2024-02-13 06:06:58 -0600 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-02-13 14:06:58 +0200 |
commit | 03bf161eb6dea6400ee49c6dc6b69bdcfa9fd3fc (patch) | |
tree | 49320ac8aca35d2ba8162c2a280924bacbd7e06b /examples/embedding/embedding.cpp | |
parent | ad014bba97ef6ef6c3e2f78b2fc463e91ae94579 (diff) |
llama : support batched embeddings (#5466)
* batched embedding: pool outputs by sequence id. updated embedding example
* bring back non-causal attention
* embd : minor improvements
* llama : minor
---------
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Diffstat (limited to 'examples/embedding/embedding.cpp')
-rw-r--r-- | examples/embedding/embedding.cpp | 142 |
1 files changed, 106 insertions, 36 deletions
diff --git a/examples/embedding/embedding.cpp b/examples/embedding/embedding.cpp index 27376c8f..b4688cf5 100644 --- a/examples/embedding/embedding.cpp +++ b/examples/embedding/embedding.cpp @@ -7,6 +7,51 @@ #pragma warning(disable: 4244 4267) // possible loss of data #endif +static std::vector<std::string> split_lines(const std::string & s) { + std::string line; + std::vector<std::string> lines; + std::stringstream ss(s); + while (std::getline(ss, line)) { + lines.push_back(line); + } + return lines; +} + +static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & tokens, int seq_id) { + for (size_t i = 0; i < tokens.size(); i++) { + llama_batch_add(batch, tokens[i], i, { seq_id }, false); + } +} + +static void normalize(float * vec, float * out, int n) { + float norm = 0; + for (int i = 0; i < n; i++) { + norm += vec[i] * vec[i]; + } + norm = sqrt(norm); + for (int i = 0; i < n; i++) { + out[i] = vec[i] / norm; + } +} + +static void batch_decode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd) { + // clear previous kv_cache values (irrelevant for embeddings) + llama_kv_cache_clear(ctx); + + // run model + fprintf(stderr, "%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq); + if (llama_decode(ctx, batch) < 0) { + fprintf(stderr, "%s : failed to decode\n", __func__); + } + + // normalize on copy + for (int k = 0; k < n_seq; k++) { + float * emb = llama_get_embeddings_ith(ctx, k); + float * out = output + k * n_embd; + normalize(emb, out, n_embd); + } +} + int main(int argc, char ** argv) { gpt_params params; @@ -55,59 +100,84 @@ int main(int argc, char ** argv) { fprintf(stderr, "%s\n", get_system_info(params).c_str()); } - int n_past = 0; + // split the prompt into lines + std::vector<std::string> prompts = split_lines(params.prompt); - // tokenize the prompt - auto embd_inp = ::llama_tokenize(ctx, params.prompt, true); + // max batch size + const uint64_t n_batch = params.n_batch; + GGML_ASSERT(params.n_batch == params.n_ctx); - if (params.verbose_prompt) { - fprintf(stderr, "\n"); - fprintf(stderr, "%s: prompt: '%s'\n", __func__, params.prompt.c_str()); - fprintf(stderr, "%s: number of tokens in prompt = %zu\n", __func__, embd_inp.size()); - for (int i = 0; i < (int) embd_inp.size(); i++) { - fprintf(stderr, "%6d -> '%s'\n", embd_inp[i], llama_token_to_piece(ctx, embd_inp[i]).c_str()); + // tokenize the prompts and trim + std::vector<std::vector<int32_t>> inputs; + for (const auto & prompt : prompts) { + auto inp = ::llama_tokenize(ctx, prompt, true); + if (inp.size() > n_batch) { + inp.resize(n_batch); } - fprintf(stderr, "\n"); + inputs.push_back(inp); } - if (embd_inp.size() > (size_t)n_ctx) { - fprintf(stderr, "%s: error: prompt is longer than the context window (%zu tokens, n_ctx = %d)\n", - __func__, embd_inp.size(), n_ctx); - return 1; - } - - while (!embd_inp.empty()) { - int n_tokens = std::min(params.n_batch, (int) embd_inp.size()); - if (llama_decode(ctx, llama_batch_get_one(embd_inp.data(), n_tokens, n_past, 0))) { - fprintf(stderr, "%s : failed to eval\n", __func__); - return 1; + // tokenization stats + if (params.verbose_prompt) { + for (int i = 0; i < (int) inputs.size(); i++) { + fprintf(stderr, "%s: prompt %d: '%s'\n", __func__, i, prompts[i].c_str()); + fprintf(stderr, "%s: number of tokens in prompt = %zu\n", __func__, inputs[i].size()); + for (int j = 0; j < (int) inputs[i].size(); j++) { + fprintf(stderr, "%6d -> '%s'\n", inputs[i][j], llama_token_to_piece(ctx, inputs[i][j]).c_str()); + } + fprintf(stderr, "\n\n"); } - n_past += n_tokens; - embd_inp.erase(embd_inp.begin(), embd_inp.begin() + n_tokens); } + // initialize batch + const int n_prompts = prompts.size(); + struct llama_batch batch = llama_batch_init(n_batch, 0, n_prompts); + + // allocate output const int n_embd = llama_n_embd(model); - auto * embeddings = llama_get_embeddings(ctx); + std::vector<float> embeddings(n_prompts * n_embd, 0); + float * emb = embeddings.data(); + + // break into batches + int p = 0; // number of prompts processed already + int s = 0; // number of prompts in current batch + for (int k = 0; k < n_prompts; k++) { + // clamp to n_batch tokens + auto & inp = inputs[k]; + const uint64_t n_toks = inp.size(); + + // encode if at capacity + if (batch.n_tokens + n_toks > n_batch) { + float * out = emb + p * n_embd; + batch_decode(ctx, batch, out, s, n_embd); + llama_batch_clear(batch); + p += s; + s = 0; + } - // l2-normalize embeddings - float norm = 0; - for (int i = 0; i < n_embd; i++) { - norm += embeddings[i] * embeddings[i]; - } - norm = sqrt(norm); - for (int i = 0; i < n_embd; i++) { - embeddings[i] /= norm; + // add to batch + batch_add_seq(batch, inp, s); + s += 1; } - for (int i = 0; i < n_embd; i++) { - printf("%f ", embeddings[i]); + // final batch + float * out = emb + p * n_embd; + batch_decode(ctx, batch, out, s, n_embd); + + // print first 3 embeddings + for (int j = 0; j < std::min(3, n_prompts); j++) { + fprintf(stderr, "embedding %d: ", j); + for (int i = 0; i < n_embd; i++) { + fprintf(stderr, "%f ", emb[j * n_embd + i]); + } + fprintf(stderr, "\n\n"); } - printf("\n"); + fprintf(stderr, "\n"); + // clean up llama_print_timings(ctx); llama_free(ctx); llama_free_model(model); - llama_backend_free(); return 0; |