summaryrefslogtreecommitdiff
path: root/examples
diff options
context:
space:
mode:
authorChristian Demsar <crasm@git.vczf.us>2023-08-10 10:28:27 -0400
committerGitHub <noreply@github.com>2023-08-10 16:28:27 +0200
commite59fcb2bc129881f4a269fee748fb38bce0a64de (patch)
treef96cb28cdf28e315cd4bea28dbc10b77afbc7fde /examples
parent1638757767072a4957f52b9e3594f0b67610631b (diff)
Add --n-predict -2 for stopping generation on full context (#2565)
Diffstat (limited to 'examples')
-rw-r--r--examples/common.cpp2
-rw-r--r--examples/main/README.md8
-rw-r--r--examples/main/main.cpp6
3 files changed, 12 insertions, 4 deletions
diff --git a/examples/common.cpp b/examples/common.cpp
index 4d3ba9bb..9f8aab9a 100644
--- a/examples/common.cpp
+++ b/examples/common.cpp
@@ -543,7 +543,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
fprintf(stdout, " --in-suffix STRING string to suffix after user inputs with (default: empty)\n");
fprintf(stdout, " -f FNAME, --file FNAME\n");
fprintf(stdout, " prompt file to start generation.\n");
- fprintf(stdout, " -n N, --n-predict N number of tokens to predict (default: %d, -1 = infinity)\n", params.n_predict);
+ fprintf(stdout, " -n N, --n-predict N number of tokens to predict (default: %d, -1 = infinity, -2 = until context filled)\n", params.n_predict);
fprintf(stdout, " -c N, --ctx-size N size of the prompt context (default: %d)\n", params.n_ctx);
fprintf(stdout, " -b N, --batch-size N batch size for prompt processing (default: %d)\n", params.n_batch);
fprintf(stdout, " -gqa N, --gqa N grouped-query attention factor (TEMP!!! use 8 for LLaMAv2 70B) (default: %d)\n", params.n_gqa);
diff --git a/examples/main/README.md b/examples/main/README.md
index 55c16096..60e3907d 100644
--- a/examples/main/README.md
+++ b/examples/main/README.md
@@ -160,9 +160,13 @@ The following options allow you to control the text generation process and fine-
### Number of Tokens to Predict
-- `-n N, --n-predict N`: Set the number of tokens to predict when generating text (default: 128, -1 = infinity).
+- `-n N, --n-predict N`: Set the number of tokens to predict when generating text (default: 128, -1 = infinity, -2 = until context filled)
-The `--n-predict` option controls the number of tokens the model generates in response to the input prompt. By adjusting this value, you can influence the length of the generated text. A higher value will result in longer text, while a lower value will produce shorter text. A value of -1 will cause text to be generated without limit.
+The `--n-predict` option controls the number of tokens the model generates in response to the input prompt. By adjusting this value, you can influence the length of the generated text. A higher value will result in longer text, while a lower value will produce shorter text.
+
+A value of -1 will enable infinite text generation, even though we have a finite context window. When the context window is full, some of the earlier tokens (half of the tokens after `--n-keep`) will be discarded. The context must then be re-evaluated before generation can resume. On large models and/or large context windows, this will result in a significant pause in output.
+
+If the pause is undesirable, a value of -2 will stop generation immediately when the context is filled.
It is important to note that the generated text may be shorter than the specified number of tokens if an End-of-Sequence (EOS) token or a reverse prompt is encountered. In interactive mode text generation will pause and control will be returned to the user. In non-interactive mode, the program will end. In both cases, the text generation may stop before reaching the specified `n-predict` value. If you want the model to keep going without ever producing End-of-Sequence on its own, you can use the `--ignore-eos` parameter.
diff --git a/examples/main/main.cpp b/examples/main/main.cpp
index 56ada7e6..a632bea1 100644
--- a/examples/main/main.cpp
+++ b/examples/main/main.cpp
@@ -431,8 +431,12 @@ int main(int argc, char ** argv) {
// - take the n_keep first tokens from the original prompt (via n_past)
// - take half of the last (n_ctx - n_keep) tokens and recompute the logits in batches
if (n_past + (int) embd.size() + std::max<int>(0, guidance_offset) > n_ctx) {
- const int n_left = n_past - params.n_keep;
+ if (params.n_predict == -2) {
+ fprintf(stderr, "\n\n%s: context full, stopping generation\n", __func__);
+ break;
+ }
+ const int n_left = n_past - params.n_keep;
// always keep the first token - BOS
n_past = std::max(1, params.n_keep);
n_past_guidance = std::max(1, params.n_keep + guidance_offset);