summaryrefslogtreecommitdiff
path: root/examples/perplexity
diff options
context:
space:
mode:
authorGeorgi Gerganov <ggerganov@gmail.com>2024-06-04 21:23:39 +0300
committerGitHub <noreply@github.com>2024-06-04 21:23:39 +0300
commit1442677f92e45a475be7b4d056e3633d1d6f813b (patch)
treed9dbb111ccaedc44cba527dbddd90bedd1e04ea8 /examples/perplexity
parent554c247caffed64465f372661f2826640cb10430 (diff)
common : refactor cli arg parsing (#7675)
* common : gpt_params_parse do not print usage * common : rework usage print (wip) * common : valign * common : rework print_usage * infill : remove cfg support * common : reorder args * server : deduplicate parameters ggml-ci * common : add missing header ggml-ci * common : remote --random-prompt usages ggml-ci * examples : migrate to gpt_params ggml-ci * batched-bench : migrate to gpt_params * retrieval : migrate to gpt_params * common : change defaults for escape and n_ctx * common : remove chatml and instruct params ggml-ci * common : passkey use gpt_params
Diffstat (limited to 'examples/perplexity')
-rw-r--r--examples/perplexity/perplexity.cpp12
1 files changed, 6 insertions, 6 deletions
diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp
index 30e5e282..0bd78c21 100644
--- a/examples/perplexity/perplexity.cpp
+++ b/examples/perplexity/perplexity.cpp
@@ -1032,7 +1032,7 @@ struct winogrande_entry {
std::vector<llama_token> seq_tokens[2];
};
-static std::vector<winogrande_entry> load_winogrande_from_csv(const std::string& prompt) {
+static std::vector<winogrande_entry> load_winogrande_from_csv(const std::string & prompt) {
std::vector<winogrande_entry> result;
std::istringstream in(prompt);
std::string line;
@@ -1964,12 +1964,14 @@ static void kl_divergence(llama_context * ctx, const gpt_params & params) {
int main(int argc, char ** argv) {
gpt_params params;
+ params.n_ctx = 512;
+ params.logits_all = true;
+
if (!gpt_params_parse(argc, argv, params)) {
+ gpt_params_print_usage(argc, argv, params);
return 1;
}
- params.logits_all = true;
-
const int32_t n_ctx = params.n_ctx;
if (n_ctx <= 0) {
@@ -2006,9 +2008,6 @@ int main(int argc, char ** argv) {
fprintf(stderr, "%s: seed = %u\n", __func__, params.seed);
std::mt19937 rng(params.seed);
- if (params.random_prompt) {
- params.prompt = string_random_prompt(rng);
- }
llama_backend_init();
llama_numa_init(params.numa);
@@ -2027,6 +2026,7 @@ int main(int argc, char ** argv) {
}
const int n_ctx_train = llama_n_ctx_train(model);
+
if (params.n_ctx > n_ctx_train) {
fprintf(stderr, "%s: warning: model was trained on only %d context tokens (%d specified)\n",
__func__, n_ctx_train, params.n_ctx);