summaryrefslogtreecommitdiff
path: root/examples/main/main.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'examples/main/main.cpp')
-rw-r--r--examples/main/main.cpp20
1 files changed, 20 insertions, 0 deletions
diff --git a/examples/main/main.cpp b/examples/main/main.cpp
index 34e84d0d..47059e58 100644
--- a/examples/main/main.cpp
+++ b/examples/main/main.cpp
@@ -511,6 +511,14 @@ int main(int argc, char ** argv) {
std::vector<llama_token> embd;
std::vector<llama_token> embd_guidance;
+ // tokenized antiprompts
+ std::vector<std::vector<llama_token>> antiprompt_ids;
+
+ antiprompt_ids.reserve(params.antiprompt.size());
+ for (const std::string & antiprompt : params.antiprompt) {
+ antiprompt_ids.emplace_back(::llama_tokenize(ctx, antiprompt, false, true));
+ }
+
struct llama_sampling_context * ctx_sampling = llama_sampling_init(sparams);
while ((n_remain != 0 && !is_antiprompt) || params.interactive) {
@@ -769,6 +777,18 @@ int main(int argc, char ** argv) {
}
}
+ // check for reverse prompt using special tokens
+ llama_token last_token = llama_sampling_last(ctx_sampling);
+ for (std::vector<llama_token> ids : antiprompt_ids) {
+ if (ids.size() == 1 && last_token == ids[0]) {
+ if (params.interactive) {
+ is_interacting = true;
+ }
+ is_antiprompt = true;
+ break;
+ }
+ }
+
if (is_antiprompt) {
LOG("found antiprompt: %s\n", last_output.c_str());
}