summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorgoerch <jhr.walter@t-online.de>2023-08-22 23:10:42 +0200
committerGitHub <noreply@github.com>2023-08-23 00:10:42 +0300
commit46ef5b5fcf4c366e1fb27726b6394adbbf8fd0ea (patch)
tree96f771ef97596af6e59bdcfeea76d15a7c80153f
parentc63bb1d16a70c03440671b76954bb767513cead8 (diff)
llama : fix whitespace escaping in tokenizer (#2724)
-rw-r--r--llama.cpp13
-rw-r--r--tests/test-tokenizer-0.cpp11
-rw-r--r--tests/test-tokenizer-1.cpp13
3 files changed, 16 insertions, 21 deletions
diff --git a/llama.cpp b/llama.cpp
index 6abdc44f..6c5da130 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -2253,18 +2253,11 @@ static llama_token llama_byte_to_token(const llama_vocab & vocab, uint8_t ch) {
}
static std::string llama_escape_whitespace(const std::string& text) {
- std::string result;
- bool escaping = false;
- result += "\xe2\x96\x81";
+ std::string result = "\xe2\x96\x81";
for (size_t offs = 0; offs < text.length(); ++offs) {
if (text[offs] == ' ') {
- if (!escaping) {
- result += "\xe2\x96\x81";
- escaping = true;
- }
- }
- else {
- escaping = false;
+ result += "\xe2\x96\x81";
+ } else {
result += text[offs];
}
}
diff --git a/tests/test-tokenizer-0.cpp b/tests/test-tokenizer-0.cpp
index 81764565..f3ee851a 100644
--- a/tests/test-tokenizer-0.cpp
+++ b/tests/test-tokenizer-0.cpp
@@ -17,6 +17,8 @@ static std::string unescape_whitespace(llama_context* ctx, const std::vector<lla
static const std::map<std::string, std::vector<llama_token>> & k_tests() {
static std::map<std::string, std::vector<llama_token>> _k_tests = {
{ " ", {1, 259, }, },
+ { " ", { 1, 1678, }, },
+ { " ", { 1, 268, }, },
{ "\t", { 1, 29871, 12, }, },
{ "\n", { 1, 29871, 13, }, },
{ "\t\n", { 1, 29871, 12, 13, }, },
@@ -38,6 +40,12 @@ static const std::map<std::string, std::vector<llama_token>> & k_tests() {
243, 162, 155, 185, 30722, 243, 162, 143, 174, 30598,
313, 20787, 953, 3848, 275, 16125, 630, 29897, 29871, 31681,
313, 6194, 953, 29877, 2397, 393, 756, 967, 1914, 5993, 29897, }, },
+ { "Hello", { 1, 15043 }, },
+ { " Hello", { 1, 29871, 15043 }, },
+ { " Hello", { 1, 259, 15043 }, },
+ { " Hello", { 1, 1678, 15043 }, },
+ { " Hello", { 1, 268, 15043 }, },
+ { " Hello\n Hello", { 1, 268, 15043, 13, 1678, 15043 }, },
};
return _k_tests;
@@ -106,7 +114,8 @@ int main(int argc, char **argv) {
if (!correct) {
fprintf(stderr, "%s : failed test: '%s'\n", __func__, test_kv.first.c_str());
- fprintf(stderr, "%s : detokenized to: '%s'\n", __func__, unescape_whitespace(ctx, test_kv.second).c_str());
+ fprintf(stderr, "%s : detokenized to: '%s' instead of '%s'\n", __func__,
+ unescape_whitespace(ctx, res).c_str(), unescape_whitespace(ctx, test_kv.second).c_str());
fprintf(stderr, "%s : expected tokens: ", __func__);
for (const auto & t : test_kv.second) {
fprintf(stderr, "%6d, ", t);
diff --git a/tests/test-tokenizer-1.cpp b/tests/test-tokenizer-1.cpp
index d8db7cd9..993d17f1 100644
--- a/tests/test-tokenizer-1.cpp
+++ b/tests/test-tokenizer-1.cpp
@@ -11,18 +11,11 @@
#include <locale>
static std::string escape_whitespace(const std::string& text) {
- std::string result;
- bool escaping = false;
- result += "\xe2\x96\x81";
+ std::string result = "\xe2\x96\x81";
for (size_t offs = 0; offs < text.length(); ++offs) {
if (text[offs] == ' ') {
- if (!escaping) {
- result += "\xe2\x96\x81";
- escaping = true;
- }
- }
- else {
- escaping = false;
+ result += "\xe2\x96\x81";
+ } else {
result += text[offs];
}
}