summary | refs | log | tree | commit | diff
path: root/common
diff options
context:
space:
mode:
author: Kawrakow <48489457+ikawrakow@users.noreply.github.com> 2023-08-23 12:56:42 +0300
committer: GitHub <noreply@github.com> 2023-08-23 12:56:42 +0300
commit: 62959e740e8759d246ac8d09036950efde09981c (patch)
tree: f868215d6def28423e1f161e84a45c64ad24019a /common
parent: 7f7ddd5002040804e33fcdbde44aa22f8635f57d (diff)
Strided perplexity (#2714)
* Implementing strided computation of perplexity
* Alternative way to output PPL results

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
Diffstat (limited to 'common')
-rw-r--r-- common/common.cpp 12
-rw-r--r-- common/common.h 4
2 files changed, 16 insertions, 0 deletions
diff --git a/common/common.cpp b/common/common.cpp
index 2a83b379..88a962ae 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -417,6 +417,18 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
params.antiprompt.push_back(argv[i]);
} else if (arg == "--perplexity") {
params.perplexity = true;
+ } else if (arg == "--ppl-stride") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params.ppl_stride = std::stoi(argv[i]);
+ } else if (arg == "--ppl-output-type") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params.ppl_output_type = std::stoi(argv[i]);
} else if (arg == "--hellaswag") {
params.hellaswag = true;
} else if (arg == "--hellaswag-tasks") {
diff --git a/common/common.h b/common/common.h
index 18fd951e..d68a8ef8 100644
--- a/common/common.h
+++ b/common/common.h
@@ -64,6 +64,10 @@ struct gpt_params {
std::string lora_adapter = ""; // lora adapter path
std::string lora_base = ""; // base model path for the lora adapter
+ int ppl_stride = 0; // stride for perplexity calculations. If left at 0, the pre-existing approach will be used.
+ int ppl_output_type = 0; // = 0 -> ppl output is as usual, = 1 -> ppl output is num_tokens, ppl, one per line
+ // (which is more convenient to use for plotting)
+ //
bool hellaswag = false; // compute HellaSwag score over random tasks from datafile supplied in prompt
size_t hellaswag_tasks = 400; // number of tasks to use when computing the HellaSwag score