diff options
Diffstat (limited to 'examples')
-rw-r--r-- | examples/main/main.cpp | 15 | ||||
-rw-r--r-- | examples/perplexity/perplexity.cpp | 12 |
2 files changed, 13 insertions, 14 deletions
diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 922b9a98..9201b53b 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -151,14 +151,6 @@ int main(int argc, char ** argv) { LOG_TEE("%s: warning: scaling RoPE frequency by %g (default 1.0)\n", __func__, params.rope_freq_scale); } - if (params.n_ctx > 2048) { - // TODO: determine the actual max context of the model (e.g. 4096 for LLaMA v2) and use that instead of 2048 - LOG_TEE("%s: warning: base model only supports context sizes no greater than 2048 tokens (%d specified)\n", __func__, params.n_ctx); - } else if (params.n_ctx < 8) { - LOG_TEE("%s: warning: minimum context size is 8, using minimum size.\n", __func__); - params.n_ctx = 8; - } - LOG_TEE("%s: build = %d (%s)\n", __func__, BUILD_NUMBER, BUILD_COMMIT); if (params.seed == LLAMA_DEFAULT_SEED) { @@ -194,6 +186,13 @@ int main(int argc, char ** argv) { return 1; } + if (params.n_ctx > llama_n_ctx(ctx)) { + LOG_TEE("%s: warning: base model only supports context sizes no greater than %d tokens (%d specified)\n", __func__, llama_n_ctx(ctx), params.n_ctx); + } else if (params.n_ctx < 8) { + LOG_TEE("%s: warning: minimum context size is 8, using minimum size.\n", __func__); + params.n_ctx = 8; + } + // print system information { LOG_TEE("\n"); diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp index 7c02b6d4..843b2ae3 100644 --- a/examples/perplexity/perplexity.cpp +++ b/examples/perplexity/perplexity.cpp @@ -368,7 +368,7 @@ results_perplexity perplexity(llama_context * ctx, const gpt_params & params) { // Example, we have a context window of 512, we will compute perplexity for each of the // last 256 tokens. Then, we split the input up into context window size chunks to // process the entire prompt. - const int first = std::min(512, params.n_ctx/2); + const int first = params.n_ctx/2; process_logits(n_vocab, logits.data() + first*n_vocab, tokens.data() + start + first, params.n_ctx - 1 - first, workers, nll, nll2, logit_history.data() + start + first, prob_history.data() + start + first); count += params.n_ctx - first - 1; @@ -668,11 +668,6 @@ int main(int argc, char ** argv) { params.n_ctx += params.ppl_stride/2; } - if (params.n_ctx > 2048) { - fprintf(stderr, "%s: warning: model might not support context sizes greater than 2048 tokens (%d specified);" - "expect poor results\n", __func__, params.n_ctx); - } - fprintf(stderr, "%s: build = %d (%s)\n", __func__, BUILD_NUMBER, BUILD_COMMIT); if (params.seed == LLAMA_DEFAULT_SEED) { @@ -698,6 +693,11 @@ int main(int argc, char ** argv) { return 1; } + if (params.n_ctx > llama_n_ctx(ctx)) { + fprintf(stderr, "%s: warning: model might not support context sizes greater than %d tokens (%d specified);" + "expect poor results\n", __func__, llama_n_ctx(ctx), params.n_ctx); + } + // print system information { fprintf(stderr, "\n"); |