summaryrefslogtreecommitdiff
path: root/common
diff options
context:
space:
mode:
Diffstat (limited to 'common')
-rw-r--r--common/common.cpp10
-rw-r--r--common/common.h3
2 files changed, 13 insertions, 0 deletions
diff --git a/common/common.cpp b/common/common.cpp
index 2b0865ff..ce20360a 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -681,6 +681,14 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
break;
}
params.hellaswag_tasks = std::stoi(argv[i]);
+ } else if (arg == "--winogrande") {
+ params.winogrande = true;
+ } else if (arg == "--winogrande-tasks") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params.winogrande_tasks = std::stoi(argv[i]);
} else if (arg == "--ignore-eos") {
params.ignore_eos = true;
} else if (arg == "--no-penalize-nl") {
@@ -926,6 +934,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
printf(" --logits-all return logits for all tokens in the batch (default: disabled)\n");
printf(" --hellaswag compute HellaSwag score over random tasks from datafile supplied with -f\n");
printf(" --hellaswag-tasks N number of tasks to use when computing the HellaSwag score (default: %zu)\n", params.hellaswag_tasks);
+ printf(" --winogrande compute Winogrande score over random tasks from datafile supplied with -f\n");
+ printf(" --winogrande-tasks N number of tasks to use when computing the Winogrande score (default: %zu)\n", params.winogrande_tasks);
printf(" --keep N number of tokens to keep from the initial prompt (default: %d, -1 = all)\n", params.n_keep);
printf(" --draft N number of tokens to draft for speculative decoding (default: %d)\n", params.n_draft);
printf(" --chunks N max number of chunks to process (default: %d, -1 = all)\n", params.n_chunks);
diff --git a/common/common.h b/common/common.h
index 1f43e628..0ae9c18b 100644
--- a/common/common.h
+++ b/common/common.h
@@ -105,6 +105,9 @@ struct gpt_params {
bool hellaswag = false; // compute HellaSwag score over random tasks from datafile supplied in prompt
size_t hellaswag_tasks = 400; // number of tasks to use when computing the HellaSwag score
+ bool winogrande = false; // compute Winogrande score over random tasks from datafile supplied in prompt
+ size_t winogrande_tasks= 0; // number of tasks to use when computing the Winogrande score. If 0, all tasks will be computed
+
bool mul_mat_q = true; // if true, use mul_mat_q kernels instead of cuBLAS
bool random_prompt = false; // do not randomize prompt if none provided
bool use_color = false; // use color to distinguish generations and inputs