summaryrefslogtreecommitdiff
path: root/common/common.cpp
diff options
context:
space:
mode:
authorDouglas Hanley <thesecretaryofwar@gmail.com>2024-03-03 04:40:27 -0600
committerGitHub <noreply@github.com>2024-03-03 12:40:27 +0200
commit475df1d6cf817060028d3ff763cb8097d4ec40d6 (patch)
tree5cad43f149f24b7b3f40604b78b7971e458aa309 /common/common.cpp
parent87c2e8b2797860a06af3d6c06b8488a8ff1a09ab (diff)
llama : allow for user specified embedding pooling type (#5849)
* allow for user specified pooling type * llama : use enum types over int --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Diffstat (limited to 'common/common.cpp')
-rw-r--r--common/common.cpp13
1 files changed, 13 insertions, 0 deletions
diff --git a/common/common.cpp b/common/common.cpp
index 1c0b7c40..dbe7e922 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -335,6 +335,16 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
break;
}
params.yarn_beta_slow = std::stof(argv[i]);
+ } else if (arg == "--pooling") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ std::string value(argv[i]);
+ /**/ if (value == "none") { params.pooling_type = LLAMA_POOLING_TYPE_NONE; }
+ else if (value == "mean") { params.pooling_type = LLAMA_POOLING_TYPE_MEAN; }
+ else if (value == "cls") { params.pooling_type = LLAMA_POOLING_TYPE_CLS; }
+ else { invalid_param = true; break; }
} else if (arg == "--defrag-thold" || arg == "-dt") {
if (++i >= argc) {
invalid_param = true;
@@ -1014,6 +1024,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
printf(" --yarn-attn-factor N YaRN: scale sqrt(t) or attention magnitude (default: 1.0)\n");
printf(" --yarn-beta-slow N YaRN: high correction dim or alpha (default: %.1f)\n", params.yarn_beta_slow);
printf(" --yarn-beta-fast N YaRN: low correction dim or beta (default: %.1f)\n", params.yarn_beta_fast);
+ printf(" --pooling {none,mean,cls}\n");
+ printf(" pooling type for embeddings, use model default if unspecified\n");
printf(" -dt N, --defrag-thold N\n");
printf(" KV cache defragmentation threshold (default: %.1f, < 0 - disabled)\n", params.defrag_thold);
printf(" --ignore-eos ignore end of stream token and continue generating (implies --logit-bias 2-inf)\n");
@@ -1296,6 +1308,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
cparams.yarn_beta_fast = params.yarn_beta_fast;
cparams.yarn_beta_slow = params.yarn_beta_slow;
cparams.yarn_orig_ctx = params.yarn_orig_ctx;
+ cparams.pooling_type = params.pooling_type;
cparams.defrag_thold = params.defrag_thold;
cparams.offload_kqv = !params.no_kv_offload;