summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorCebtenzzre <cebtenzzre@gmail.com>2023-09-08 11:43:35 -0400
committerGitHub <noreply@github.com>2023-09-08 11:43:35 -0400
commite64f5b55783e910d8287363895d652b4bea6527a (patch)
treedba460e07999ff2d9d0c4a7499e5cf3c591f9748
parent94f10b91ed69980f299441e49c8dbdb448f0ccc6 (diff)
examples : make n_ctx warning work again (#3066)
This was broken by commit e36ecdcc ("build : on Mac OS enable Metal by default (#2901)").
-rw-r--r--examples/embedding/embedding.cpp11
-rw-r--r--examples/main/main.cpp6
-rw-r--r--examples/perplexity/perplexity.cpp7
-rw-r--r--llama.cpp14
-rw-r--r--llama.h14
5 files changed, 33 insertions, 19 deletions
diff --git a/examples/embedding/embedding.cpp b/examples/embedding/embedding.cpp
index 49ab3e06..e4a0a38c 100644
--- a/examples/embedding/embedding.cpp
+++ b/examples/embedding/embedding.cpp
@@ -17,11 +17,6 @@ int main(int argc, char ** argv) {
params.embedding = true;
- if (params.n_ctx > 2048) {
- fprintf(stderr, "%s: warning: model might not support context sizes greater than 2048 tokens (%d specified);"
- "expect poor results\n", __func__, params.n_ctx);
- }
-
fprintf(stderr, "%s: build = %d (%s)\n", __func__, BUILD_NUMBER, BUILD_COMMIT);
if (params.seed == LLAMA_DEFAULT_SEED) {
@@ -47,6 +42,12 @@ int main(int argc, char ** argv) {
return 1;
}
+ const int n_ctx_train = llama_n_ctx_train(ctx);
+ if (params.n_ctx > n_ctx_train) {
+ fprintf(stderr, "%s: warning: model was trained on only %d context tokens (%d specified)\n",
+ __func__, n_ctx_train, params.n_ctx);
+ }
+
// print system information
{
fprintf(stderr, "\n");
diff --git a/examples/main/main.cpp b/examples/main/main.cpp
index be030fff..baec6ba1 100644
--- a/examples/main/main.cpp
+++ b/examples/main/main.cpp
@@ -182,8 +182,10 @@ int main(int argc, char ** argv) {
return 1;
}
- if (params.n_ctx > llama_n_ctx(ctx)) {
- LOG_TEE("%s: warning: base model only supports context sizes no greater than %d tokens (%d specified)\n", __func__, llama_n_ctx(ctx), params.n_ctx);
+ const int n_ctx_train = llama_n_ctx_train(ctx);
+ if (params.n_ctx > n_ctx_train) {
+ LOG_TEE("%s: warning: model was trained on only %d context tokens (%d specified)\n",
+ __func__, n_ctx_train, params.n_ctx);
} else if (params.n_ctx < 8) {
LOG_TEE("%s: warning: minimum context size is 8, using minimum size.\n", __func__);
params.n_ctx = 8;
diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp
index 1b760683..3a1c8c28 100644
--- a/examples/perplexity/perplexity.cpp
+++ b/examples/perplexity/perplexity.cpp
@@ -693,9 +693,10 @@ int main(int argc, char ** argv) {
return 1;
}
- if (params.n_ctx > llama_n_ctx(ctx)) {
- fprintf(stderr, "%s: warning: model might not support context sizes greater than %d tokens (%d specified);"
- "expect poor results\n", __func__, llama_n_ctx(ctx), params.n_ctx);
+ const int n_ctx_train = llama_n_ctx_train(ctx);
+ if (params.n_ctx > n_ctx_train) {
+ fprintf(stderr, "%s: warning: model was trained on only %d context tokens (%d specified)\n",
+ __func__, n_ctx_train, params.n_ctx);
}
// print system information
diff --git a/llama.cpp b/llama.cpp
index 3f119022..2a2a0c9c 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -5633,15 +5633,19 @@ void llama_free(struct llama_context * ctx) {
}
int llama_n_vocab(const struct llama_context * ctx) {
- return ctx->model.vocab.id_to_token.size();
+ return llama_model_n_vocab(&ctx->model);
}
int llama_n_ctx(const struct llama_context * ctx) {
- return ctx->model.hparams.n_ctx;
+ return llama_model_n_ctx(&ctx->model);
+}
+
+int llama_n_ctx_train(const struct llama_context * ctx) {
+ return llama_model_n_ctx_train(&ctx->model);
}
int llama_n_embd(const struct llama_context * ctx) {
- return ctx->model.hparams.n_embd;
+ return llama_model_n_embd(&ctx->model);
}
enum llama_vocab_type llama_vocab_type(const struct llama_context * ctx) {
@@ -5656,6 +5660,10 @@ int llama_model_n_ctx(const struct llama_model * model) {
return model->hparams.n_ctx;
}
+int llama_model_n_ctx_train(const struct llama_model * model) {
+ return model->hparams.n_ctx_train;
+}
+
int llama_model_n_embd(const struct llama_model * model) {
return model->hparams.n_embd;
}
diff --git a/llama.h b/llama.h
index 5b95aaa8..37975beb 100644
--- a/llama.h
+++ b/llama.h
@@ -245,15 +245,17 @@ extern "C" {
LLAMA_API bool llama_mmap_supported (void);
LLAMA_API bool llama_mlock_supported(void);
- LLAMA_API int llama_n_vocab(const struct llama_context * ctx);
- LLAMA_API int llama_n_ctx (const struct llama_context * ctx);
- LLAMA_API int llama_n_embd (const struct llama_context * ctx);
+ LLAMA_API int llama_n_vocab (const struct llama_context * ctx);
+ LLAMA_API int llama_n_ctx (const struct llama_context * ctx);
+ LLAMA_API int llama_n_ctx_train(const struct llama_context * ctx);
+ LLAMA_API int llama_n_embd (const struct llama_context * ctx);
LLAMA_API enum llama_vocab_type llama_vocab_type(const struct llama_context * ctx);
- LLAMA_API int llama_model_n_vocab(const struct llama_model * model);
- LLAMA_API int llama_model_n_ctx (const struct llama_model * model);
- LLAMA_API int llama_model_n_embd (const struct llama_model * model);
+ LLAMA_API int llama_model_n_vocab (const struct llama_model * model);
+ LLAMA_API int llama_model_n_ctx (const struct llama_model * model);
+ LLAMA_API int llama_model_n_ctx_train(const struct llama_model * model);
+ LLAMA_API int llama_model_n_embd (const struct llama_model * model);
// Get a string describing the model type
LLAMA_API int llama_model_desc(const struct llama_model * model, char * buf, size_t buf_size);