diff options
Diffstat (limited to 'examples/server')
-rw-r--r-- | examples/server/server.cpp | 53 |
1 files changed, 42 insertions, 11 deletions
diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 208edd57..8fe5e0b1 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -1210,7 +1210,7 @@ struct llama_server_context queue_results.send(res); } - void send_embedding(server_slot &slot) + void send_embedding(server_slot & slot, const llama_batch & batch) { task_result res; res.id = slot.task_id; @@ -1219,6 +1219,7 @@ struct llama_server_context res.stop = true; const int n_embd = llama_n_embd(model); + if (!params.embedding) { LOG_WARNING("embedding disabled", {{"params.embedding", params.embedding}}); @@ -1229,12 +1230,29 @@ struct llama_server_context } else { - const float *data = llama_get_embeddings(ctx); - std::vector<float> embedding(data, data + n_embd); - res.result_json = json - { - {"embedding", embedding}, - }; + for (int i = 0; i < batch.n_tokens; ++i) { + if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) { + continue; + } + + const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]); + if (embd == NULL) { + embd = llama_get_embeddings_ith(ctx, i); + if (embd == NULL) { + LOG_ERROR("failed to get embeddings for token", {{"token", batch.token[i]}, {"seq_id", batch.seq_id[i][0]}}); + res.result_json = json + { + {"embedding", std::vector<float>(n_embd, 0.0f)}, + }; + continue; + } + } + + res.result_json = json + { + {"embedding", std::vector<float>(embd, embd + n_embd)}, + }; + } } queue_results.send(res); } @@ -1845,7 +1863,7 @@ struct llama_server_context ga_i += ga_w/ga_n; } } - llama_batch_add(batch, prefix_tokens[slot.n_past], system_tokens.size() + slot_npast, {slot.id }, false); + llama_batch_add(batch, prefix_tokens[slot.n_past], system_tokens.size() + slot_npast, { slot.id }, false); slot_npast++; } @@ -1881,7 +1899,7 @@ struct llama_server_context for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) { - const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i)); + const int32_t n_tokens = 
std::min(n_batch, batch.n_tokens - i); for (auto & slot : slots) { @@ -1954,7 +1972,7 @@ struct llama_server_context // prompt evaluated for embedding if (slot.embedding) { - send_embedding(slot); + send_embedding(slot, batch_view); slot.release(); slot.i_batch = -1; continue; @@ -2036,6 +2054,8 @@ static void server_print_usage(const char *argv0, const gpt_params &params, printf(" --yarn-attn-factor N YaRN: scale sqrt(t) or attention magnitude (default: 1.0)\n"); printf(" --yarn-beta-slow N YaRN: high correction dim or alpha (default: %.1f)\n", params.yarn_beta_slow); printf(" --yarn-beta-fast N YaRN: low correction dim or beta (default: %.1f)\n", params.yarn_beta_fast); + printf(" --pooling {none,mean,cls}\n"); + printf(" pooling type for embeddings, use model default if unspecified\n"); printf(" -b N, --batch-size N batch size for prompt processing (default: %d)\n", params.n_batch); printf(" --memory-f32 use f32 instead of f16 for memory key+value (default: disabled)\n"); printf(" not recommended: doubles context memory required and no measurable increase in quality\n"); @@ -2276,6 +2296,18 @@ static void server_params_parse(int argc, char **argv, server_params &sparams, } params.yarn_beta_slow = std::stof(argv[i]); } + else if (arg == "--pooling") + { + if (++i >= argc) { + invalid_param = true; + break; + } + std::string value(argv[i]); + /**/ if (value == "none") { params.pooling_type = LLAMA_POOLING_TYPE_NONE; } + else if (value == "mean") { params.pooling_type = LLAMA_POOLING_TYPE_MEAN; } + else if (value == "cls") { params.pooling_type = LLAMA_POOLING_TYPE_CLS; } + else { invalid_param = true; break; } + } else if (arg == "--threads" || arg == "-t") { if (++i >= argc) @@ -2330,7 +2362,6 @@ static void server_params_parse(int argc, char **argv, server_params &sparams, break; } params.n_batch = std::stoi(argv[i]); - params.n_batch = std::min(512, params.n_batch); } else if (arg == "--gpu-layers" || arg == "-ngl" || arg == "--n-gpu-layers") { |