From 475df1d6cf817060028d3ff763cb8097d4ec40d6 Mon Sep 17 00:00:00 2001 From: Douglas Hanley Date: Sun, 3 Mar 2024 04:40:27 -0600 Subject: llama : allow for user specified embedding pooling type (#5849) * allow for user specified pooling type * llama : use enum types over int --------- Co-authored-by: Georgi Gerganov --- common/common.cpp | 13 +++++++++++++ 1 file changed, 13 insertions(+) (limited to 'common/common.cpp') diff --git a/common/common.cpp b/common/common.cpp index 1c0b7c40..dbe7e922 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -335,6 +335,16 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { break; } params.yarn_beta_slow = std::stof(argv[i]); + } else if (arg == "--pooling") { + if (++i >= argc) { + invalid_param = true; + break; + } + std::string value(argv[i]); + /**/ if (value == "none") { params.pooling_type = LLAMA_POOLING_TYPE_NONE; } + else if (value == "mean") { params.pooling_type = LLAMA_POOLING_TYPE_MEAN; } + else if (value == "cls") { params.pooling_type = LLAMA_POOLING_TYPE_CLS; } + else { invalid_param = true; break; } } else if (arg == "--defrag-thold" || arg == "-dt") { if (++i >= argc) { invalid_param = true; @@ -1014,6 +1024,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { printf(" --yarn-attn-factor N YaRN: scale sqrt(t) or attention magnitude (default: 1.0)\n"); printf(" --yarn-beta-slow N YaRN: high correction dim or alpha (default: %.1f)\n", params.yarn_beta_slow); printf(" --yarn-beta-fast N YaRN: low correction dim or beta (default: %.1f)\n", params.yarn_beta_fast); + printf(" --pooling {none,mean,cls}\n"); + printf(" pooling type for embeddings, use model default if unspecified\n"); printf(" -dt N, --defrag-thold N\n"); printf(" KV cache defragmentation threshold (default: %.1f, < 0 - disabled)\n", params.defrag_thold); printf(" --ignore-eos ignore end of stream token and continue generating (implies --logit-bias 2-inf)\n"); @@ -1296,6 +1308,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param cparams.yarn_beta_fast = params.yarn_beta_fast; cparams.yarn_beta_slow = params.yarn_beta_slow; cparams.yarn_orig_ctx = params.yarn_orig_ctx; + cparams.pooling_type = params.pooling_type; cparams.defrag_thold = params.defrag_thold; cparams.offload_kqv = !params.no_kv_offload; -- cgit v1.2.3