Diffstat (limited to 'examples/perplexity/perplexity.cpp')
-rw-r--r--  examples/perplexity/perplexity.cpp | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)
diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp
index 30e5e282..0bd78c21 100644
--- a/examples/perplexity/perplexity.cpp
+++ b/examples/perplexity/perplexity.cpp
@@ -1032,7 +1032,7 @@ struct winogrande_entry {
     std::vector<llama_token> seq_tokens[2];
 };
 
-static std::vector<winogrande_entry> load_winogrande_from_csv(const std::string& prompt) {
+static std::vector<winogrande_entry> load_winogrande_from_csv(const std::string & prompt) {
     std::vector<winogrande_entry> result;
     std::istringstream in(prompt);
     std::string line;
@@ -1964,12 +1964,14 @@ static void kl_divergence(llama_context * ctx, const gpt_params & params) {
 int main(int argc, char ** argv) {
     gpt_params params;
 
+    params.n_ctx = 512;
+    params.logits_all = true;
+
     if (!gpt_params_parse(argc, argv, params)) {
+        gpt_params_print_usage(argc, argv, params);
         return 1;
     }
 
-    params.logits_all = true;
-
     const int32_t n_ctx = params.n_ctx;
 
     if (n_ctx <= 0) {
@@ -2006,9 +2008,6 @@ int main(int argc, char ** argv) {
     fprintf(stderr, "%s: seed = %u\n", __func__, params.seed);
 
     std::mt19937 rng(params.seed);
-    if (params.random_prompt) {
-        params.prompt = string_random_prompt(rng);
-    }
 
     llama_backend_init();
     llama_numa_init(params.numa);
@@ -2027,6 +2026,7 @@ int main(int argc, char ** argv) {
     }
 
     const int n_ctx_train = llama_n_ctx_train(model);
+
     if (params.n_ctx > n_ctx_train) {
         fprintf(stderr, "%s: warning: model was trained on only %d context tokens (%d specified)\n",
                 __func__, n_ctx_train, params.n_ctx);
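
For reference, a minimal sketch of how the top of main() reads after this patch, reconstructed only from the hunks above (surrounding code elided). The reordering means the perplexity-specific defaults (n_ctx = 512, logits_all = true) are assigned before gpt_params_parse() runs, so values passed on the command line (e.g. a context-size argument) can still override them, and a failed parse now prints the usage text via gpt_params_print_usage() before exiting:

int main(int argc, char ** argv) {
    gpt_params params;

    // perplexity-specific defaults, set before argument parsing
    // so command-line values can still override them
    params.n_ctx      = 512;
    params.logits_all = true;

    if (!gpt_params_parse(argc, argv, params)) {
        // on parse failure, show usage instead of exiting silently
        gpt_params_print_usage(argc, argv, params);
        return 1;
    }

    const int32_t n_ctx = params.n_ctx;
    // ... rest of main() unchanged ...
}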