path: root/llama.h
author    Douglas Hanley <thesecretaryofwar@gmail.com>    2024-03-03 04:40:27 -0600
committer GitHub <noreply@github.com>    2024-03-03 12:40:27 +0200
commit    475df1d6cf817060028d3ff763cb8097d4ec40d6 (patch)
tree      5cad43f149f24b7b3f40604b78b7971e458aa309 /llama.h
parent    87c2e8b2797860a06af3d6c06b8488a8ff1a09ab (diff)
llama : allow for user specified embedding pooling type (#5849)
* allow for user specified pooling type

* llama : use enum types over int

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Diffstat (limited to 'llama.h')
-rw-r--r--  llama.h  7
1 file changed, 5 insertions, 2 deletions
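For context, a minimal sketch of how a caller might use the new field, assuming the llama.cpp C API as it stood at this commit; the model path is a placeholder and error handling is omitted:

// Sketch (not part of this commit): request mean pooling of embeddings
// through the new pooling_type field in llama_context_params.
#include "llama.h"

int main(void) {
    llama_backend_init();

    struct llama_model_params mparams = llama_model_default_params();
    struct llama_model * model = llama_load_model_from_file("model.gguf", mparams); // placeholder path

    struct llama_context_params cparams = llama_context_default_params();
    cparams.embedding    = true;                    // embedding mode only
    cparams.pooling_type = LLAMA_POOLING_TYPE_MEAN; // replaces the old do_pooling flag;
                                                    // LLAMA_POOLING_TYPE_UNSPECIFIED (the new default)
                                                    // presumably defers to the model's own pooling type

    struct llama_context * ctx = llama_new_context_with_model(model, cparams);

    // ... tokenize the prompt, call llama_decode(), then read the pooled embeddings ...

    llama_free(ctx);
    llama_free_model(model);
    llama_backend_free();
    return 0;
}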
diff --git a/llama.h b/llama.h
index 6406b527..70da4cb3 100644
--- a/llama.h
+++ b/llama.h
@@ -129,6 +129,7 @@ extern "C" {
     };
 
     enum llama_pooling_type {
+        LLAMA_POOLING_TYPE_UNSPECIFIED = -1,
         LLAMA_POOLING_TYPE_NONE = 0,
         LLAMA_POOLING_TYPE_MEAN = 1,
         LLAMA_POOLING_TYPE_CLS  = 2,
@@ -236,7 +237,10 @@
         uint32_t n_batch;         // prompt processing maximum batch size
         uint32_t n_threads;       // number of threads to use for generation
         uint32_t n_threads_batch; // number of threads to use for batch processing
-        int32_t  rope_scaling_type; // RoPE scaling type, from `enum llama_rope_scaling_type`
+
+        enum llama_rope_scaling_type rope_scaling_type; // RoPE scaling type, from `enum llama_rope_scaling_type`
+        enum llama_pooling_type      pooling_type;      // whether to pool (sum) embedding results by sequence id
+                                                        // (ignored if no pooling layer)
 
         // ref: https://github.com/ggerganov/llama.cpp/pull/2054
         float rope_freq_base;   // RoPE base frequency, 0 = from model
@@ -258,7 +262,6 @@
         bool logits_all;  // the llama_decode() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead)
         bool embedding;   // embedding mode only
         bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
-        bool do_pooling;  // whether to pool (sum) embedding results by sequence id (ignored if no pooling layer)
 
         // Abort callback
         // if it returns true, execution of llama_decode() will be aborted