author     Georgi Gerganov <ggerganov@gmail.com>    2024-03-04 22:31:20 +0200
committer  GitHub <noreply@github.com>              2024-03-04 22:31:20 +0200
commit     29ae62d2ae163e2b68aa0ad3bf2ab4636de0c957 (patch)
tree       a65058dfddf1672f1d765e324dac9f66abf1a7c1 /examples
parent     e0843afe1b37890b631bc7d3d2da2ed36c862b91 (diff)
llama : fix embeddings (#5796)
* llama : fix embeddings
ggml-ci
* llama : do not use KV cache for non-causal models
ggml-ci
* embeddings : fix llama_batch_init arg
* llama : add pooling switch
* llama : distinguish token vs sequence embeddings
ggml-ci
* llama : assert pooling tensor
* llama : simplify causal mask condition
ggml-ci
* llama : assert input batch with pooling enabled
* readme : update API changes list
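The core of the change, as used by both updated examples below, is the distinction between per-token and pooled per-sequence embeddings: llama_get_embeddings_seq() returns one vector per sequence when a pooling type (mean/cls) is active, while llama_get_embeddings_ith() remains the per-token accessor that the examples fall back to. A minimal sketch of that retrieval pattern, assuming a decoded batch plus the ctx/n_embd variables from the examples (the helper name here is illustrative, not part of the commit):

// Sketch only: fetch the embedding for output position i of a decoded batch,
// preferring the pooled per-sequence embedding, and L2-normalize it on copy
// (the same pattern embedding.cpp and server.cpp adopt in this commit).
#include "llama.h"

#include <cmath>
#include <cstdio>
#include <vector>

static std::vector<float> get_normalized_embd(llama_context * ctx, const llama_batch & batch, int i, int n_embd) {
    // pooled sequence embedding - only available when pooling_type is not NONE
    const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
    if (embd == NULL) {
        // fall back to the per-token embedding of position i
        embd = llama_get_embeddings_ith(ctx, i);
    }

    std::vector<float> out(n_embd, 0.0f);
    if (embd == NULL) {
        fprintf(stderr, "failed to get embeddings for token %d\n", i);
        return out;
    }

    // normalize on copy
    float norm = 0.0f;
    for (int k = 0; k < n_embd; k++) {
        norm += embd[k]*embd[k];
    }
    norm = std::sqrt(norm);
    for (int k = 0; k < n_embd; k++) {
        out[k] = norm > 0.0f ? embd[k]/norm : 0.0f;
    }
    return out;
}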
Diffstat (limited to 'examples')
-rw-r--r--  examples/embedding/embedding.cpp  28
-rw-r--r--  examples/server-embd.py           34
-rw-r--r--  examples/server/server.cpp        53
3 files changed, 97 insertions, 18 deletions
diff --git a/examples/embedding/embedding.cpp b/examples/embedding/embedding.cpp
index acff715e..ff5883da 100644
--- a/examples/embedding/embedding.cpp
+++ b/examples/embedding/embedding.cpp
@@ -19,11 +19,11 @@ static std::vector<std::string> split_lines(const std::string & s) {
static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & tokens, int seq_id) {
for (size_t i = 0; i < tokens.size(); i++) {
- llama_batch_add(batch, tokens[i], i, { seq_id }, false);
+ llama_batch_add(batch, tokens[i], i, { seq_id }, i == tokens.size() - 1);
}
}
-static void normalize(float * vec, float * out, int n) {
+static void normalize(const float * vec, float * out, int n) {
float norm = 0;
for (int i = 0; i < n; i++) {
norm += vec[i] * vec[i];
@@ -45,10 +45,23 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu
}
// normalize on copy
- for (int k = 0; k < n_seq; k++) {
- float * emb = llama_get_embeddings_ith(ctx, k);
- float * out = output + k * n_embd;
- normalize(emb, out, n_embd);
+ for (int i = 0; i < batch.n_tokens; i++) {
+ if (!batch.logits[i]) {
+ continue;
+ }
+
+ // try to get sequence embeddings - supported only when pooling_type is not NONE
+ const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
+ if (embd == NULL) {
+ embd = llama_get_embeddings_ith(ctx, i);
+ if (embd == NULL) {
+ fprintf(stderr, "%s: failed to get embeddings for token %d\n", __func__, i);
+ continue;
+ }
+ }
+
+ float * out = output + batch.seq_id[i][0] * n_embd;
+ normalize(embd, out, n_embd);
}
}
@@ -132,7 +145,7 @@ int main(int argc, char ** argv) {
// initialize batch
const int n_prompts = prompts.size();
- struct llama_batch batch = llama_batch_init(n_batch, 0, n_prompts);
+ struct llama_batch batch = llama_batch_init(n_batch, 0, 1);
// allocate output
const int n_embd = llama_n_embd(model);
@@ -145,6 +158,7 @@ int main(int argc, char ** argv) {
for (int k = 0; k < n_prompts; k++) {
// clamp to n_batch tokens
auto & inp = inputs[k];
+
const uint64_t n_toks = inp.size();
// encode if at capacity
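Two things change in this example: batch_add_seq() now marks only the final token of each sequence for output, and batch_decode() writes one L2-normalized vector per sequence, indexed by seq_id, preferring llama_get_embeddings_seq() over the old per-index llama_get_embeddings_ith() call. The llama_batch_init() fix follows from its third argument being the maximum number of sequence ids per token (one here), not the number of prompts. Roughly, main() drives the helpers like this (a simplified sketch that assumes every prompt fits into one batch; model, ctx, n_batch and the tokenized inputs come from the example's setup code):

// inside main() of embedding.cpp, after tokenizing the prompts into `inputs`
const int n_prompts = (int) inputs.size();
const int n_embd    = llama_n_embd(model);

// one slot per token, at most one sequence id per token
llama_batch batch = llama_batch_init(n_batch, 0, 1);

std::vector<float> emb(n_prompts * n_embd, 0.0f);

for (int k = 0; k < n_prompts; k++) {
    // sequence id k; only the last token of the sequence requests output
    batch_add_seq(batch, inputs[k], k);
}

// decode once and collect one normalized embedding per sequence
batch_decode(ctx, batch, emb.data(), n_prompts, n_embd);

llama_batch_free(batch);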
diff --git a/examples/server-embd.py b/examples/server-embd.py
new file mode 100644
index 00000000..c5c4ea87
--- /dev/null
+++ b/examples/server-embd.py
@@ -0,0 +1,34 @@
+import asyncio
+import requests
+import numpy as np
+
+n = 8
+
+result = []
+
+async def requests_post_async(*args, **kwargs):
+ return await asyncio.to_thread(requests.post, *args, **kwargs)
+
+async def main():
+ model_url = "http://127.0.0.1:6900"
+ responses: list[requests.Response] = await asyncio.gather(*[requests_post_async(
+ url= f"{model_url}/embedding",
+ json= {"content": str(i)*1024}
+ ) for i in range(n)])
+
+ for response in responses:
+ embedding = response.json()["embedding"]
+ print(embedding[-8:])
+ result.append(embedding)
+
+asyncio.run(main())
+
+# compute cosine similarity
+
+for i in range(n-1):
+ for j in range(i+1, n):
+ embedding1 = np.array(result[i])
+ embedding2 = np.array(result[j])
+ similarity = np.dot(embedding1, embedding2) / (np.linalg.norm(embedding1) * np.linalg.norm(embedding2))
+ print(f"Similarity between {i} and {j}: {similarity:.2f}")
+
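The new server-embd.py script fires n concurrent POSTs at the /embedding endpoint of a llama.cpp server running with embeddings enabled on 127.0.0.1:6900, then checks pairwise cosine similarities of the returned vectors with numpy. For reference, the numpy expression dot(a, b) / (norm(a) * norm(b)) corresponds to the following standalone C++ helper (illustrative only, not part of the commit):

#include <cmath>
#include <cstddef>
#include <vector>

// cosine similarity between two embedding vectors, as computed in server-embd.py
static float cosine_similarity(const std::vector<float> & a, const std::vector<float> & b) {
    float dot = 0.0f, na = 0.0f, nb = 0.0f;
    for (size_t i = 0; i < a.size() && i < b.size(); i++) {
        dot += a[i]*b[i];
        na  += a[i]*a[i];
        nb  += b[i]*b[i];
    }
    if (na == 0.0f || nb == 0.0f) {
        return 0.0f; // degenerate case: zero vector
    }
    return dot / (std::sqrt(na) * std::sqrt(nb));
}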
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 208edd57..8fe5e0b1 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -1210,7 +1210,7 @@ struct llama_server_context
queue_results.send(res);
}
- void send_embedding(server_slot &slot)
+ void send_embedding(server_slot & slot, const llama_batch & batch)
{
task_result res;
res.id = slot.task_id;
@@ -1219,6 +1219,7 @@ struct llama_server_context
res.stop = true;
const int n_embd = llama_n_embd(model);
+
if (!params.embedding)
{
LOG_WARNING("embedding disabled", {{"params.embedding", params.embedding}});
@@ -1229,12 +1230,29 @@ struct llama_server_context
}
else
{
- const float *data = llama_get_embeddings(ctx);
- std::vector<float> embedding(data, data + n_embd);
- res.result_json = json
- {
- {"embedding", embedding},
- };
+ for (int i = 0; i < batch.n_tokens; ++i) {
+ if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) {
+ continue;
+ }
+
+ const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
+ if (embd == NULL) {
+ embd = llama_get_embeddings_ith(ctx, i);
+ if (embd == NULL) {
+ LOG_ERROR("failed to get embeddings for token", {{"token", batch.token[i]}, {"seq_id", batch.seq_id[i][0]}});
+ res.result_json = json
+ {
+ {"embedding", std::vector<float>(n_embd, 0.0f)},
+ };
+ continue;
+ }
+ }
+
+ res.result_json = json
+ {
+ {"embedding", std::vector<float>(embd, embd + n_embd)},
+ };
+ }
}
queue_results.send(res);
}
@@ -1845,7 +1863,7 @@ struct llama_server_context
ga_i += ga_w/ga_n;
}
}
- llama_batch_add(batch, prefix_tokens[slot.n_past], system_tokens.size() + slot_npast, {slot.id }, false);
+ llama_batch_add(batch, prefix_tokens[slot.n_past], system_tokens.size() + slot_npast, { slot.id }, false);
slot_npast++;
}
@@ -1881,7 +1899,7 @@ struct llama_server_context
for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch)
{
- const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));
+ const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i);
for (auto & slot : slots)
{
@@ -1954,7 +1972,7 @@ struct llama_server_context
// prompt evaluated for embedding
if (slot.embedding)
{
- send_embedding(slot);
+ send_embedding(slot, batch_view);
slot.release();
slot.i_batch = -1;
continue;
@@ -2036,6 +2054,8 @@ static void server_print_usage(const char *argv0, const gpt_params &params,
printf(" --yarn-attn-factor N YaRN: scale sqrt(t) or attention magnitude (default: 1.0)\n");
printf(" --yarn-beta-slow N YaRN: high correction dim or alpha (default: %.1f)\n", params.yarn_beta_slow);
printf(" --yarn-beta-fast N YaRN: low correction dim or beta (default: %.1f)\n", params.yarn_beta_fast);
+ printf(" --pooling {none,mean,cls}\n");
+ printf(" pooling type for embeddings, use model default if unspecified\n");
printf(" -b N, --batch-size N batch size for prompt processing (default: %d)\n", params.n_batch);
printf(" --memory-f32 use f32 instead of f16 for memory key+value (default: disabled)\n");
printf(" not recommended: doubles context memory required and no measurable increase in quality\n");
@@ -2276,6 +2296,18 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
}
params.yarn_beta_slow = std::stof(argv[i]);
}
+ else if (arg == "--pooling")
+ {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ std::string value(argv[i]);
+ /**/ if (value == "none") { params.pooling_type = LLAMA_POOLING_TYPE_NONE; }
+ else if (value == "mean") { params.pooling_type = LLAMA_POOLING_TYPE_MEAN; }
+ else if (value == "cls") { params.pooling_type = LLAMA_POOLING_TYPE_CLS; }
+ else { invalid_param = true; break; }
+ }
else if (arg == "--threads" || arg == "-t")
{
if (++i >= argc)
@@ -2330,7 +2362,6 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
break;
}
params.n_batch = std::stoi(argv[i]);
- params.n_batch = std::min(512, params.n_batch);
}
else if (arg == "--gpu-layers" || arg == "-ngl" || arg == "--n-gpu-layers")
{