Diffstat (limited to 'common/common.h')
 common/common.h | 31 ++++++++++++++++++++++++++++++-
 1 file changed, 30 insertions(+), 1 deletion(-)
diff --git a/common/common.h b/common/common.h
index d250eef8..687f3425 100644
--- a/common/common.h
+++ b/common/common.h
@@ -37,10 +37,13 @@ extern char const *LLAMA_COMMIT;
 extern char const *LLAMA_COMPILER;
 extern char const *LLAMA_BUILD_TARGET;
 
+struct llama_control_vector_load_info;
+
+int32_t get_num_physical_cores();
+
 //
 // CLI argument parsing
 //
-int32_t get_num_physical_cores();
 
 struct gpt_params {
     uint32_t seed = LLAMA_DEFAULT_SEED; // RNG seed
@@ -103,6 +106,11 @@ struct gpt_params {
     std::vector<std::tuple<std::string, float>> lora_adapter; // lora adapter path with user defined scale
     std::string lora_base = "";                               // base model path for the lora adapter
 
+    std::vector<llama_control_vector_load_info> control_vectors; // control vector with user defined scale
+
+    int32_t control_vector_layer_start = -1; // layer range for control vector
+    int32_t control_vector_layer_end   = -1; // layer range for control vector
+
     int ppl_stride      = 0; // stride for perplexity calculations. If left at 0, the pre-existing approach will be used.
     int ppl_output_type = 0; // = 0 -> ppl output is as usual, = 1 -> ppl output is num_tokens, ppl, one per line
                              //                                       (which is more convenient to use for plotting)
@@ -269,3 +277,24 @@ void dump_kv_cache_view_seqs(const llama_kv_cache_view & view, int row_size = 40
 void llama_embd_normalize(const float * inp, float * out, int n);
 
 float llama_embd_similarity_cos(const float * embd1, const float * embd2, int n);
+
+//
+// Control vector utils
+//
+
+struct llama_control_vector_data {
+    int n_embd;
+
+    // stores data for layers [1, n_layer] where n_layer = data.size() / n_embd
+    std::vector<float> data;
+};
+
+struct llama_control_vector_load_info {
+    float strength;
+
+    std::string fname;
+};
+
+// Load control vectors, scale each by strength, and add them together.
+// On error, returns {-1, empty}
+llama_control_vector_data llama_control_vector_load(const std::vector<llama_control_vector_load_info> & load_infos);
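Usage sketch (not part of the diff): a minimal, hypothetical example of how a
caller might drive the API added above, assuming it links against llama.cpp's
common library. The file names and strength values are placeholders; the error
check relies on the documented {-1, empty} return of llama_control_vector_load.

#include "common.h"

#include <cstdio>
#include <vector>

int main() {
    // Hypothetical inputs: two control vector files, each with a
    // user-defined strength; a negative strength flips that vector's
    // direction before it is added into the sum.
    std::vector<llama_control_vector_load_info> load_infos = {
        { /* strength */  0.8f, /* fname */ "happy.gguf" },
        { /* strength */ -0.4f, /* fname */ "angry.gguf" },
    };

    // Load each file, scale its data by strength, and sum them together.
    llama_control_vector_data cvec = llama_control_vector_load(load_infos);
    if (cvec.n_embd == -1) {
        fprintf(stderr, "failed to load control vectors\n");
        return 1;
    }

    // data holds one n_embd-sized vector per layer, for layers [1, n_layer].
    const int n_layer = (int) cvec.data.size() / cvec.n_embd;
    printf("loaded: n_embd = %d, n_layer = %d\n", cvec.n_embd, n_layer);
    return 0;
}

The new gpt_params fields control_vector_layer_start and control_vector_layer_end
then presumably bound the layer range over which the summed vector is applied to
the model, with -1 meaning the full range.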