diff options
author | Kawrakow <48489457+ikawrakow@users.noreply.github.com> | 2024-01-22 16:10:14 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-01-22 16:10:14 +0200 |
commit | 6f9939d119b2d004c264952eb510bd106455531e (patch) | |
tree | bacfc729a033d1858019b53423d0ebc6afc26124 /common | |
parent | 780e24a22eb595b705cbe8284771e9ceff1c4dd2 (diff) |
KL-divergence (#5076)
* kl-divergence: be able to save all logits to a file
* Add ability to compute KL-divergence
---------
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
Diffstat (limited to 'common')
-rw-r--r-- | common/common.cpp | 9 | ||||
-rw-r--r-- | common/common.h | 3 |
2 files changed, 12 insertions, 0 deletions
```diff
diff --git a/common/common.cpp b/common/common.cpp
index 0e4b8bab..0a709617 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -672,6 +672,12 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
             if (params.logdir.back() != DIRECTORY_SEPARATOR) {
                 params.logdir += DIRECTORY_SEPARATOR;
             }
+        } else if (arg == "--save-all-logits" || arg == "--kl-divergence-base") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.logits_file = argv[i];
         } else if (arg == "--perplexity" || arg == "--all-logits") {
             params.logits_all = true;
         } else if (arg == "--ppl-stride") {
@@ -716,6 +722,8 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
                 break;
             }
             params.multiple_choice_tasks = std::stoi(argv[i]);
+        } else if (arg == "--kl-divergence") {
+            params.kl_divergence = true;
         } else if (arg == "--ignore-eos") {
             params.ignore_eos = true;
         } else if (arg == "--no-penalize-nl") {
@@ -967,6 +975,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     printf(" --winogrande-tasks N number of tasks to use when computing the Winogrande score (default: %zu)\n", params.winogrande_tasks);
     printf(" --multiple-choice compute multiple choice score over random tasks from datafile supplied with -f\n");
     printf(" --multiple-choice-tasks N number of tasks to use when computing the multiple choice score (default: %zu)\n", params.winogrande_tasks);
+    printf(" --kl-divergence computes KL-divergence to logits provided via --kl-divergence-base");
     printf(" --keep N number of tokens to keep from the initial prompt (default: %d, -1 = all)\n", params.n_keep);
     printf(" --draft N number of tokens to draft for speculative decoding (default: %d)\n", params.n_draft);
     printf(" --chunks N max number of chunks to process (default: %d, -1 = all)\n", params.n_chunks);
diff --git a/common/common.h b/common/common.h
index c69ad7e9..214a379b 100644
--- a/common/common.h
+++ b/common/common.h
@@ -91,6 +91,7 @@ struct gpt_params {
     std::string input_suffix = ""; // string to suffix user inputs with
     std::vector<std::string> antiprompt; // string upon seeing which more user input is prompted
     std::string logdir = ""; // directory in which to save YAML log files
+    std::string logits_file = ""; // file for saving *all* logits
 
     std::vector<llama_model_kv_override> kv_overrides;
 
@@ -111,6 +112,8 @@ struct gpt_params {
     bool multiple_choice = false; // compute TruthfulQA score over random tasks from datafile supplied in prompt
     size_t multiple_choice_tasks = 0; // number of tasks to use when computing the TruthfulQA score. If 0, all tasks will be computed
 
+    bool kl_divergence = false; // compute KL-divergence
+
     bool mul_mat_q = true; // if true, use mul_mat_q kernels instead of cuBLAS
     bool random_prompt = false; // do not randomize prompt if none provided
     bool use_color = false; // use color to distinguish generations and inputs
```