| author    | Georgi Gerganov <ggerganov@gmail.com> | 2024-06-04 21:23:39 +0300 |
| committer | GitHub <noreply@github.com>           | 2024-06-04 21:23:39 +0300 |
| commit    | 1442677f92e45a475be7b4d056e3633d1d6f813b (patch) | |
| tree      | d9dbb111ccaedc44cba527dbddd90bedd1e04ea8 | /examples/perplexity |
| parent    | 554c247caffed64465f372661f2826640cb10430 (diff) | |
common : refactor cli arg parsing (#7675)
* common : gpt_params_parse do not print usage
* common : rework usage print (wip)
* common : valign
* common : rework print_usage
* infill : remove cfg support
* common : reorder args
* server : deduplicate parameters
ggml-ci
* common : add missing header
ggml-ci
* common : remove --random-prompt usages
ggml-ci
* examples : migrate to gpt_params
ggml-ci
* batched-bench : migrate to gpt_params
* retrieval : migrate to gpt_params
* common : change defaults for escape and n_ctx
* common : remove chatml and instruct params
ggml-ci
* common : passkey use gpt_params
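Taken together, these changes move per-example defaults in front of argument parsing and make usage printing the caller's responsibility: gpt_params_parse no longer prints usage on failure, and each example calls gpt_params_print_usage itself. A minimal sketch of the resulting pattern, distilled from the perplexity diff below (everything after parsing is elided):

```cpp
#include "common.h"

int main(int argc, char ** argv) {
    gpt_params params;

    // Example-specific defaults are assigned before parsing,
    // so command-line flags can still override them.
    params.n_ctx      = 512;
    params.logits_all = true;

    // gpt_params_parse no longer prints usage on failure;
    // the caller decides when and how to show it.
    if (!gpt_params_parse(argc, argv, params)) {
        gpt_params_print_usage(argc, argv, params);
        return 1;
    }

    // ... example-specific work using the parsed params ...

    return 0;
}
```

Setting defaults before the parse, rather than forcing them afterwards as the old code did with logits_all, is what makes per-example defaults overridable from the command line.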
Diffstat (limited to 'examples/perplexity')
-rw-r--r-- | examples/perplexity/perplexity.cpp | 12 |
1 file changed, 6 insertions, 6 deletions
diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp
index 30e5e282..0bd78c21 100644
--- a/examples/perplexity/perplexity.cpp
+++ b/examples/perplexity/perplexity.cpp
@@ -1032,7 +1032,7 @@ struct winogrande_entry {
     std::vector<llama_token> seq_tokens[2];
 };
 
-static std::vector<winogrande_entry> load_winogrande_from_csv(const std::string& prompt) {
+static std::vector<winogrande_entry> load_winogrande_from_csv(const std::string & prompt) {
     std::vector<winogrande_entry> result;
     std::istringstream in(prompt);
     std::string line;
@@ -1964,12 +1964,14 @@ static void kl_divergence(llama_context * ctx, const gpt_params & params) {
 int main(int argc, char ** argv) {
     gpt_params params;
 
+    params.n_ctx = 512;
+    params.logits_all = true;
+
     if (!gpt_params_parse(argc, argv, params)) {
+        gpt_params_print_usage(argc, argv, params);
         return 1;
     }
 
-    params.logits_all = true;
-
     const int32_t n_ctx = params.n_ctx;
 
     if (n_ctx <= 0) {
@@ -2006,9 +2008,6 @@ int main(int argc, char ** argv) {
     fprintf(stderr, "%s: seed = %u\n", __func__, params.seed);
 
     std::mt19937 rng(params.seed);
-    if (params.random_prompt) {
-        params.prompt = string_random_prompt(rng);
-    }
 
     llama_backend_init();
     llama_numa_init(params.numa);
@@ -2027,6 +2026,7 @@ int main(int argc, char ** argv) {
     }
 
     const int n_ctx_train = llama_n_ctx_train(model);
+
     if (params.n_ctx > n_ctx_train) {
         fprintf(stderr, "%s: warning: model was trained on only %d context tokens (%d specified)\n",
                 __func__, n_ctx_train, params.n_ctx);
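For context on the two pinned defaults: perplexity scores every token of the evaluated text, so the model must return logits at every position (params.logits_all), not just at the final one. The quantity the example reports is the standard perplexity

$$\mathrm{PPL} = \exp\left(-\frac{1}{N}\sum_{i=1}^{N}\log p(x_i \mid x_{<i})\right)$$

which needs a log-probability for each of the N evaluated tokens. The explicit params.n_ctx = 512 presumably preserves this example's previous working default now that the shared n_ctx default changed in this commit (see "common : change defaults for escape and n_ctx" above).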