summaryrefslogtreecommitdiff
path: root/examples/llava/llava-cli.cpp
diff options
context:
space:
mode:
authorXiao-Yong Jin <jinxiaoyong@gmail.com>2024-02-07 02:17:25 -0600
committerGitHub <noreply@github.com>2024-02-07 10:17:25 +0200
commit0ef46da632c32faa1a538e5dc180994e8bbb46e1 (patch)
treebdd3e10dc2f129b432094c4eb6a69ac8bd07e854 /examples/llava/llava-cli.cpp
parentee1628bdfea8b0079fed0140ac2f00ef1b465b57 (diff)
llava-cli : always tokenize special tokens (#5382)
* llava-cli: tokenize special tokens in prompt * llava-cli: use the escape CLI argument, remove incomplete separate escaping process
Diffstat (limited to 'examples/llava/llava-cli.cpp')
-rw-r--r--examples/llava/llava-cli.cpp14
1 files changed, 1 insertions, 13 deletions
diff --git a/examples/llava/llava-cli.cpp b/examples/llava/llava-cli.cpp
index 6ac70ba6..031e9806 100644
--- a/examples/llava/llava-cli.cpp
+++ b/examples/llava/llava-cli.cpp
@@ -34,7 +34,7 @@ static bool eval_id(struct llama_context * ctx_llama, int id, int * n_past) {
static bool eval_string(struct llama_context * ctx_llama, const char* str, int n_batch, int * n_past, bool add_bos){
std::string str2 = str;
- std::vector<llama_token> embd_inp = ::llama_tokenize(ctx_llama, str2, add_bos);
+ std::vector<llama_token> embd_inp = ::llama_tokenize(ctx_llama, str2, add_bos, true);
eval_tokens(ctx_llama, embd_inp, n_batch, n_past);
return true;
}
@@ -152,20 +152,8 @@ static void process_prompt(struct llava_context * ctx_llava, struct llava_image_
size_t image_pos = prompt.find("<image>");
if (image_pos != std::string::npos) {
// new templating mode: Provide the full prompt including system message and use <image> as a placeholder for the image
-
system_prompt = prompt.substr(0, image_pos);
user_prompt = prompt.substr(image_pos + std::string("<image>").length());
- // We replace \n with actual newlines in user_prompt, just in case -e was not used in templating string
- size_t pos = 0;
- while ((pos = user_prompt.find("\\n", pos)) != std::string::npos) {
- user_prompt.replace(pos, 2, "\n");
- pos += 1; // Advance past the replaced newline
- }
- while ((pos = system_prompt.find("\\n", pos)) != std::string::npos) {
- system_prompt.replace(pos, 2, "\n");
- pos += 1; // Advance past the replaced newline
- }
-
printf("system_prompt: %s\n", system_prompt.c_str());
printf("user_prompt: %s\n", user_prompt.c_str());
} else {