Diffstat (limited to 'common/common.h')
-rw-r--r-- | common/common.h | 105
1 file changed, 82 insertions, 23 deletions
diff --git a/common/common.h b/common/common.h
index 26450483..e0a08a61 100644
--- a/common/common.h
+++ b/common/common.h
@@ -60,7 +60,7 @@ struct gpt_params {
     int32_t n_threads_batch       = -1;   // number of threads to use for batch processing (-1 = use n_threads)
     int32_t n_threads_batch_draft = -1;
     int32_t n_predict             = -1;   // new tokens to predict
-    int32_t n_ctx                 = 512;  // context size
+    int32_t n_ctx                 = 0;    // context size
     int32_t n_batch               = 2048; // logical batch size for prompt processing (must be >=32 to use BLAS)
     int32_t n_ubatch              = 512;  // physical batch size for prompt processing (must be >=32 to use BLAS)
     int32_t n_keep                = 0;    // number of tokens to keep from initial prompt
@@ -99,23 +99,23 @@ struct gpt_params {
     // // sampling parameters
     struct llama_sampling_params sparams;
 
-    std::string model             = ""; // model path
-    std::string model_draft       = ""; // draft model for speculative decoding
+    std::string model                = ""; // model path
+    std::string model_draft          = ""; // draft model for speculative decoding
     std::string model_alias          = "unknown"; // model alias
-    std::string model_url         = ""; // model url to download
-    std::string hf_repo           = ""; // HF repo
-    std::string hf_file           = ""; // HF file
+    std::string model_url            = ""; // model url to download
+    std::string hf_repo              = ""; // HF repo
+    std::string hf_file              = ""; // HF file
     std::string prompt               = "";
-    std::string prompt_file       = ""; // store the external prompt file name
-    std::string path_prompt_cache = ""; // path to file for saving/loading prompt eval state
-    std::string input_prefix      = ""; // string to prefix user inputs with
-    std::string input_suffix      = ""; // string to suffix user inputs with
-    std::vector<std::string> antiprompt; // string upon seeing which more user input is prompted
-    std::string logdir            = ""; // directory in which to save YAML log files
+    std::string prompt_file          = ""; // store the external prompt file name
+    std::string path_prompt_cache    = ""; // path to file for saving/loading prompt eval state
+    std::string input_prefix         = ""; // string to prefix user inputs with
+    std::string input_suffix         = ""; // string to suffix user inputs with
+    std::string logdir               = ""; // directory in which to save YAML log files
     std::string lookup_cache_static  = ""; // path of static ngram cache file for lookup decoding
     std::string lookup_cache_dynamic = ""; // path of dynamic ngram cache file for lookup decoding
-    std::string logits_file       = ""; // file for saving *all* logits
+    std::string logits_file          = ""; // file for saving *all* logits
+    std::vector<std::string> antiprompt; // strings upon which more user input is prompted (a.k.a. reverse prompts)
 
     std::vector<llama_model_kv_override> kv_overrides;
 
     // TODO: avoid tuple, use struct
@@ -127,8 +127,8 @@ struct gpt_params {
     int32_t control_vector_layer_start = -1; // layer range for control vector
     int32_t control_vector_layer_end   = -1; // layer range for control vector
 
-    int     ppl_stride      = 0;  // stride for perplexity calculations. If left at 0, the pre-existing approach will be used.
-    int     ppl_output_type = 0;  // = 0 -> ppl output is as usual, = 1 -> ppl output is num_tokens, ppl, one per line
+    int32_t ppl_stride      = 0;  // stride for perplexity calculations. If left at 0, the pre-existing approach will be used.
+    int32_t ppl_output_type = 0;  // = 0 -> ppl output is as usual, = 1 -> ppl output is num_tokens, ppl, one per line
                                   // (which is more convenient to use for plotting)
 
     bool   hellaswag       = false;  // compute HellaSwag score over random tasks from datafile supplied in prompt
@@ -142,19 +142,17 @@ struct gpt_params {
     bool kl_divergence     = false; // compute KL divergence
 
-    bool random_prompt     = false; // do not randomize prompt if none provided
+    bool usage             = false; // print usage
     bool use_color         = false; // use color to distinguish generations and inputs
-    bool interactive       = false; // interactive mode
-    bool interactive_specials = false; // whether to allow special tokens from user, during interactive mode
     bool special           = false; // enable special token output
+    bool interactive       = false; // interactive mode
+    bool interactive_first = false; // wait for user input immediately
     bool conversation      = false; // conversation mode (does not print special tokens and suffix/prefix)
-    bool chatml            = false; // chatml mode (used for models trained on chatml syntax)
     bool prompt_cache_all  = false; // save user input and generations to prompt cache
     bool prompt_cache_ro   = false; // open the prompt cache read-only and do not update it
     bool embedding         = false; // get only sentence embedding
-    bool escape            = false; // escape "\n", "\r", "\t", "\'", "\"", and "\\"
-    bool interactive_first = false; // wait for user input immediately
+    bool escape            = true;  // escape "\n", "\r", "\t", "\'", "\"", and "\\"
     bool multiline_input   = false; // reverse the usage of `\`
     bool simple_io         = false; // improves compatibility with subprocesses and limited consoles
     bool cont_batching     = true;  // insert new sequences for decoding on-the-fly
@@ -162,10 +160,10 @@ struct gpt_params {
     bool input_prefix_bos  = false; // prefix BOS to user inputs, preceding input_prefix
     bool ignore_eos        = false; // ignore generated EOS tokens
-    bool instruct          = false; // instruction mode (used for Alpaca models)
     bool logits_all        = false; // return logits for all tokens in the batch
     bool use_mmap          = true;  // use mmap for faster loads
     bool use_mlock         = false; // use mlock to keep model in memory
+    bool verbose           = false;
     bool verbose_prompt    = false; // print prompt tokens before generation
     bool display_prompt    = true;  // print prompt before generation
     bool infill            = false; // use infill mode
@@ -180,6 +178,47 @@ struct gpt_params {
     // multimodal models (see examples/llava)
     std::string mmproj = "";        // path to multimodal projector
     std::vector<std::string> image; // path to image file(s)
+
+    // server params
+    int32_t port           = 8080;
+    int32_t timeout_read   = 600;
+    int32_t timeout_write  = timeout_read;
+    int32_t n_threads_http = -1;
+
+    std::string hostname      = "127.0.0.1";
+    std::string public_path   = "";
+    std::string chat_template = "";
+    std::string system_prompt = "";
+
+    std::vector<std::string> api_keys;
+
+    std::string ssl_file_key  = "";
+    std::string ssl_file_cert = "";
+
+    bool endpoint_slots   = true;
+    bool endpoint_metrics = false;
+
+    bool log_json = false;
+
+    std::string slot_save_path;
+
+    // batched-bench params
+    bool is_pp_shared = false;
+
+    std::vector<int32_t> n_pp;
+    std::vector<int32_t> n_tg;
+    std::vector<int32_t> n_pl;
+
+    // retrieval params
+    std::vector<std::string> context_files; // context files to embed
+
+    int32_t chunk_size = 64; // chunk size for context embedding
+
+    std::string chunk_separator = "\n"; // chunk separator for context embedding
+
+    // passkey params
+    int32_t n_junk = 250; // number of times to repeat the junk text
+    int32_t i_pos  = -1;  // position of the passkey in the junk text
 };
 
 void gpt_params_handle_model_default(gpt_params & params);
@@ -199,7 +238,20 @@ std::vector<std::string> string_split(std::string input, char separator);
 std::string string_strip(const std::string & str);
 std::string string_get_sortable_timestamp();
-std::string string_random_prompt(std::mt19937 & rng);
+
+template<class T>
+static std::vector<T> string_split(const std::string & str, char delim) {
+    std::vector<T> values;
+    std::istringstream str_stream(str);
+    std::string token;
+    while (std::getline(str_stream, token, delim)) {
+        T value;
+        std::istringstream token_stream(token);
+        token_stream >> value;
+        values.push_back(value);
+    }
+    return values;
+}
 
 bool string_parse_kv_override(const char * data, std::vector<llama_model_kv_override> & overrides);
 
 void string_process_escapes(std::string & input);
@@ -283,6 +335,13 @@ std::string llama_detokenize_bpe(
 bool llama_should_add_bos_token(const llama_model * model);
 
 //
+// Chat template utils
+//
+
+// Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid
+bool llama_chat_verify_template(const std::string & tmpl);
+
+//
 // KV cache utils
 //
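
Note on the n_ctx default change in the first hunk: 0 is the conventional "defer to the model" value, resolved at context-creation time to the model's training context length. A minimal caller-side sketch of that resolution, assuming the llama_n_ctx_train accessor from llama.h (the actual resolution happens inside the library, not in this header):

#include <cstdint>
#include "llama.h"

// Sketch only: illustrates what the new n_ctx == 0 default means.
// Assumption: llama_n_ctx_train(model) returns the context length the
// model was trained with (accessor from llama.h).
static uint32_t resolve_n_ctx(const llama_model * model, int32_t n_ctx_param) {
    return n_ctx_param == 0 ? (uint32_t) llama_n_ctx_train(model)  // 0 = use the model's own context size
                            : (uint32_t) n_ctx_param;              // otherwise honor the user's value
}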
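
The templated string_split<T> added above parses delimiter-separated scalar lists, e.g. values for the new batched-bench n_pp/n_tg/n_pl vectors. A self-contained usage sketch; the main() harness and the example input are illustrative only, while the template body is copied from the hunk:

#include <cstdint>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

// From the diff: split "a,b,c" into a vector of T by streaming each token
// through operator>>.
template<class T>
static std::vector<T> string_split(const std::string & str, char delim) {
    std::vector<T> values;
    std::istringstream str_stream(str);
    std::string token;
    while (std::getline(str_stream, token, delim)) {
        T value;
        std::istringstream token_stream(token);
        token_stream >> value;
        values.push_back(value);
    }
    return values;
}

int main() {
    // e.g. a comma-separated prompt-length list for batched-bench
    const std::vector<int32_t> n_pp = string_split<int32_t>("128,256,512", ',');
    for (const int32_t v : n_pp) {
        std::cout << v << '\n'; // prints 128, 256, 512 on separate lines
    }
    return 0;
}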
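
llama_chat_verify_template appears in the last hunk only as a prototype. One plausible implementation (an assumption; this diff does not show the .cpp side) is a zero-length dry run through llama_chat_apply_template from llama.h, which returns a negative value for templates it cannot handle:

#include <string>
#include "llama.h"

// Sketch only: assumed implementation, not taken from this diff.
// llama_chat_apply_template returns a negative length when it does not
// recognize the template, so a dry run with no output buffer doubles
// as validation of a user-supplied "--chat-template" value.
bool llama_chat_verify_template(const std::string & tmpl) {
    llama_chat_message chat[] = {{"user", "test"}};
    const int res = llama_chat_apply_template(nullptr, tmpl.c_str(), chat, 1, true, nullptr, 0);
    return res >= 0;
}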