summaryrefslogtreecommitdiff
path: root/llama.cpp
diff options
context:
space:
mode:
authorgoerch <jhr.walter@t-online.de>2023-10-10 18:59:52 +0200
committerGitHub <noreply@github.com>2023-10-10 18:59:52 +0200
commit233fc1c69f6f415f35363e18a755f9610e89161b (patch)
treed949e9cdaa21419b2a03e7eeb81852cd7a5e6240 /llama.cpp
parentc5b49360d0d9e49f32e05a9116e90bd0b39a282d (diff)
Minor improvements in GPT2 tokenizer (#3567)
* Fixing minor bugs in bpe_gpt2_preprocess * Don't add bos token in test
Diffstat (limited to 'llama.cpp')
-rw-r--r--llama.cpp9
1 files changed, 4 insertions, 5 deletions
diff --git a/llama.cpp b/llama.cpp
index 4653c802..7ed87223 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -6342,7 +6342,6 @@ private:
for (int i = 0; i < (int)text_utf.size(); i++) {
const std::string & utf_char = text_utf[i];
bool split_condition = false;
- // const char* text_pos = raw_text_p + utf_char.seq_offset_bytes;
int bytes_remain = text_utf.size() - i;
// forward backward lookups
const std::string & utf_char_next = (i + 1 < (int)text_utf.size()) ? text_utf[i + 1] : "";
@@ -6368,9 +6367,9 @@ private:
if (!split_condition && bytes_remain >= 3) {
// 're|'ve|'ll
if (utf_char == "\'" && (
- (utf_char_next == "r" || utf_char_next_next == "e") ||
- (utf_char_next == "v" || utf_char_next_next == "e") ||
- (utf_char_next == "l" || utf_char_next_next == "l"))
+ (utf_char_next == "r" && utf_char_next_next == "e") ||
+ (utf_char_next == "v" && utf_char_next_next == "e") ||
+ (utf_char_next == "l" && utf_char_next_next == "l"))
) {
split_condition = true;
}
@@ -6421,7 +6420,7 @@ private:
else if (collecting_special && (codepoint_type(utf_char) == CODEPOINT_TYPE_LETTER || codepoint_type(utf_char) == CODEPOINT_TYPE_DIGIT || codepoint_type(utf_char) == CODEPOINT_TYPE_WHITESPACE)) {
split_condition = true;
}
- else if (collecting_whitespace_lookahead && codepoint_type(utf_char_next) != CODEPOINT_TYPE_WHITESPACE) {
+ else if (collecting_whitespace_lookahead && (codepoint_type(utf_char_next) == CODEPOINT_TYPE_LETTER || codepoint_type(utf_char_next) == CODEPOINT_TYPE_DIGIT)) {
split_condition = true;
}
}