diff options
Diffstat (limited to 'common')
-rw-r--r-- | common/common.cpp | 12 | ||||
-rw-r--r-- | common/common.h | 4 |
2 files changed, 16 insertions, 0 deletions
diff --git a/common/common.cpp b/common/common.cpp index 2a83b379..88a962ae 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -417,6 +417,18 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { params.antiprompt.push_back(argv[i]); } else if (arg == "--perplexity") { params.perplexity = true; + } else if (arg == "--ppl-stride") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.ppl_stride = std::stoi(argv[i]); + } else if (arg == "--ppl-output-type") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.ppl_output_type = std::stoi(argv[i]); } else if (arg == "--hellaswag") { params.hellaswag = true; } else if (arg == "--hellaswag-tasks") { diff --git a/common/common.h b/common/common.h index 18fd951e..d68a8ef8 100644 --- a/common/common.h +++ b/common/common.h @@ -64,6 +64,10 @@ struct gpt_params { std::string lora_adapter = ""; // lora adapter path std::string lora_base = ""; // base model path for the lora adapter + int ppl_stride = 0; // stride for perplexity calculations. If left at 0, the pre-existing approach will be used. + int ppl_output_type = 0; // = 0 -> ppl output is as usual, = 1 -> ppl output is num_tokens, ppl, one per line + // (which is more convenient to use for plotting) + // bool hellaswag = false; // compute HellaSwag score over random tasks from datafile supplied in prompt size_t hellaswag_tasks = 400; // number of tasks to use when computing the HellaSwag score |