author     jaime-m-p <167997752+jaime-m-p@users.noreply.github.com>    2024-06-18 18:40:52 +0200
committer  GitHub <noreply@github.com>                                 2024-06-18 18:40:52 +0200
commit     37bef8943312d91183ff06d8f1214082a17344a5 (patch)
tree       7713dc5aceb3b181568db3d21b1383762de41c4a /llama.cpp
parent     91c188d6c296bd3384f2a02a83b71187aa3d18b3 (diff)
tokenizer : BPE fixes (#7530)
* Random test: add_bos_token, add_eos_token
* Random test: add BPE models for testing
* Custom regex split fails with codepoint 0
* Fix falcon punctuation regex
* Refactor llm_tokenizer_bpe: move code to constructor
* Move 'add_special_bos/eos' logic to llm_tokenizer_bpe
* Move tokenizer flags to vocab structure.
* Default values for special_add_bos/eos
* Build vocab.special_tokens_cache using vocab token types
* Generalize 'jina-v2' per token attributes
* Fix unicode whitespaces (deepseek-coder, deepseek-llm)
* Skip missing byte tokens (falcon)
* Better unicode data generation
* Replace char32_t with uint32_t
Diffstat (limited to 'llama.cpp')
-rw-r--r--   llama.cpp   309
1 file changed, 172 insertions(+), 137 deletions(-)
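
Note (not part of the commit): the refactor described above moves the tokenizer flags into llama_vocab as booleans, builds llm_tokenizer_bpe once per tokenize call instead of once per text fragment, and centralizes BOS/EOS handling in member functions. The following minimal, self-contained C++ sketch illustrates that control flow; the toy_* types and token ids are stand-ins for illustration only, not llama.cpp API.

// Minimal sketch (illustrative only, not llama.cpp code) of the refactored flow:
// boolean tokenizer flags live on the vocab, the tokenizer is constructed once,
// and BOS/EOS appending plus the double-BOS/EOS check are member functions.
#include <cstdio>
#include <string>
#include <vector>

using token_id = int;

struct toy_vocab {
    bool tokenizer_add_bos  = true;   // replaces the old tri-state special_add_bos
    bool tokenizer_add_eos  = false;  // replaces the old tri-state special_add_eos
    token_id special_bos_id = 1;
    token_id special_eos_id = 2;
};

struct toy_bpe_tokenizer {
    explicit toy_bpe_tokenizer(const toy_vocab & vocab) : vocab(vocab) {}

    bool append_bos(std::vector<token_id> & out) const {
        if (vocab.tokenizer_add_bos) { out.push_back(vocab.special_bos_id); return true; }
        return false;
    }
    bool append_eos(std::vector<token_id> & out) const {
        if (vocab.tokenizer_add_eos) { out.push_back(vocab.special_eos_id); return true; }
        return false;
    }
    void check_double_bos_eos(const std::vector<token_id> & out) const {
        if (vocab.tokenizer_add_bos && out.size() >= 2 && out[1] == vocab.special_bos_id) {
            std::fprintf(stderr, "warning: prompt now starts with 2 BOS tokens\n");
        }
        if (vocab.tokenizer_add_eos && out.size() >= 2 && *(out.end() - 2) == vocab.special_eos_id) {
            std::fprintf(stderr, "warning: prompt now ends with 2 EOS tokens\n");
        }
    }
    // stand-in for the real regex split + BPE merge loop
    void tokenize(const std::string & text, std::vector<token_id> & out) const {
        for (unsigned char c : text) { out.push_back(100 + c); }
    }

    const toy_vocab & vocab;
};

int main() {
    toy_vocab vocab;
    toy_bpe_tokenizer tokenizer(vocab);      // constructed once, reused for every fragment

    std::vector<token_id> output;
    tokenizer.append_bos(output);
    tokenizer.tokenize("hi", output);        // raw-text fragment
    tokenizer.append_eos(output);
    tokenizer.check_double_bos_eos(output);  // warn if the prompt already had BOS/EOS

    for (token_id id : output) { std::printf("%d ", id); }
    std::printf("\n");
    return 0;
}
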
diff --git a/llama.cpp b/llama.cpp
index e06c851a..8818c692 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -2310,16 +2310,17 @@ struct llama_vocab {
id special_cls_id = -1;
id special_mask_id = -1;
- int special_add_bos = -1; // -1 unknown, 1 add, 0 don't add.
- int special_add_eos = -1; // -1 unknown, 1 add, 0 don't add.
-
id linefeed_id = 13;
id special_prefix_id = -1;
id special_suffix_id = -1;
id special_middle_id = -1;
id special_eot_id = -1; // TODO: move above after "eos_id", and here add "file separator" token
- bool add_space_prefix = true;
+ // tokenizer flags
+ bool tokenizer_add_space_prefix = true;
+ bool tokenizer_add_bos = false;
+ bool tokenizer_add_eos = false;
+ bool tokenizer_ignore_merges = false;
int find_bpe_rank(const std::string & token_left, const std::string & token_right) const {
GGML_ASSERT(token_left.find(' ') == std::string::npos);
@@ -4770,7 +4771,7 @@ static void llm_load_vocab(
const int add_space_prefix_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_ADD_PREFIX).c_str());
if (add_space_prefix_keyidx != -1) {
- vocab.add_space_prefix = gguf_get_val_bool(ctx, add_space_prefix_keyidx);
+ vocab.tokenizer_add_space_prefix = gguf_get_val_bool(ctx, add_space_prefix_keyidx);
} // The default value of add_space_prefix is true.
} else if (tokenizer_model == "bert") {
vocab.type = LLAMA_VOCAB_TYPE_WPM;
@@ -4783,13 +4784,13 @@ static void llm_load_vocab(
vocab.special_pad_id = 0;
vocab.special_cls_id = 101;
vocab.special_mask_id = 103;
- vocab.add_space_prefix = false;
+ vocab.tokenizer_add_space_prefix = false;
} else if (tokenizer_model == "gpt2") {
vocab.type = LLAMA_VOCAB_TYPE_BPE;
const int add_space_prefix_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_ADD_PREFIX).c_str());
if (add_space_prefix_keyidx != -1) {
- vocab.add_space_prefix = gguf_get_val_bool(ctx, add_space_prefix_keyidx);
+ vocab.tokenizer_add_space_prefix = gguf_get_val_bool(ctx, add_space_prefix_keyidx);
}
// read bpe merges and populate bpe ranks
@@ -4847,6 +4848,8 @@ static void llm_load_vocab(
tokenizer_pre == "llama-v3" ||
tokenizer_pre == "llama-bpe") {
vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_LLAMA3;
+ vocab.tokenizer_ignore_merges = true;
+ vocab.tokenizer_add_bos = true;
} else if (
tokenizer_pre == "deepseek-llm") {
vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_LLM;
@@ -4897,6 +4900,14 @@ static void llm_load_vocab(
} else {
throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str()));
}
+ } else if (vocab.type == LLAMA_VOCAB_TYPE_SPM) {
+ vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
+ vocab.tokenizer_add_bos = true;
+ vocab.tokenizer_add_eos = false;
+ } else if (vocab.type == LLAMA_VOCAB_TYPE_WPM) {
+ vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
+ vocab.tokenizer_add_bos = true;
+ vocab.tokenizer_add_eos = false;
} else {
vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
}
@@ -5041,10 +5052,10 @@ static void llm_load_vocab(
bool temp = true;
if (ml.get_key(LLM_KV_TOKENIZER_ADD_BOS, temp, false)) {
- vocab.special_add_bos = int(temp);
+ vocab.tokenizer_add_bos = temp;
}
if (ml.get_key(LLM_KV_TOKENIZER_ADD_EOS, temp, false)) {
- vocab.special_add_eos = int(temp);
+ vocab.tokenizer_add_eos = temp;
}
}
@@ -5144,7 +5155,7 @@ static void llm_load_vocab(
);
// set attributes by model/tokenizer name
- if (_contains_any(tokenizer_pre, {"jina-v2-es", "jina-v2-de"})) {
+ if (_contains_any(tokenizer_pre, {"jina-v2-de", "jina-v2-es", "jina-v2-code"})) {
_set_token_attr("<mask>", LLAMA_TOKEN_ATTR_LSTRIP, true);
} else if (_contains_any(model_name, {"phi-3", "phi3"})) {
for (auto id : vocab.cache_special_tokens) {
@@ -13158,112 +13169,142 @@ struct llm_bigram_bpe {
};
struct llm_tokenizer_bpe {
- llm_tokenizer_bpe(const llama_vocab & vocab): vocab(vocab) {}
-
- void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) {
- int final_prev_index = -1;
- bool ignore_merges = false;
-
- std::vector<std::string> word_collection;
- switch (vocab.type) {
- case LLAMA_VOCAB_TYPE_BPE:
- switch (vocab.type_pre) {
- case LLAMA_VOCAB_PRE_TYPE_LLAMA3:
- ignore_merges = true;
- word_collection = unicode_regex_split(text, {
- // original regex from tokenizer.json
- //"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
-
- // adapted: https://github.com/ggerganov/llama.cpp/pull/6920#issuecomment-2080233989
- "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
- });
- break;
- case LLAMA_VOCAB_PRE_TYPE_DBRX:
- case LLAMA_VOCAB_PRE_TYPE_SMAUG:
- word_collection = unicode_regex_split(text, {
- // same as llama3
- "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
- });
- break;
- case LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_LLM:
- word_collection = unicode_regex_split(text, {
- "[\r\n]",
- "\\s?[A-Za-zµÀ-ÖØ-öø-ƺƼ-ƿDŽ-ʓʕ-ʯͰ-ͳͶͷͻ-ͽͿΆΈ-ΊΌΎ-ΡΣ-ϵϷ-ҁҊ-ԯԱ-ՖႠ-ჅᎠ-Ᏽᏸ-ᏽᲐ-ᲺᲽ-Ჿᴀ-ᴫᵫ-ᵷᵹ-ᶚḀ-ἕἘ-Ἕἠ-ὅὈ-Ὅὐ-ὗὙὛὝὟ-ώᾀ-ᾴᾶ-ᾼιῂ-ῄῆ-ῌῐ-ΐῖ-Ίῠ-Ῥῲ-ῴῶ-ῼℂℇℊ-ℓℕℙ-ℝℤΩℨK-ℭℯ-ℴℹℼ-ℿⅅ-ⅉⅎↃↄⰀ-ⱻⱾ-ⳤⳫ-ⳮⳲⳳꙀ-ꙭꚀ-ꚛꜢ-ꝯꝱ-ꞇꞋ-ꞎꭰ-ꮿff-stﬓ-ﬗA-Za-z𐐀-𐑏𐒰-𐓓𐓘-𐓻𐲀-𐲲𐳀-𐳲𑢠-𑣟𞤀-𞥃]+",
- "\\s?[!-/:-~!-/:-~‘-‟ -。]+",
- "\\s+$",
- "[一-龥ࠀ-一가-퟿]+",
- "\\p{N}+",
- });
- break;
- case LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_CODER:
- word_collection = unicode_regex_split(text, {
- "[\r\n]",
- "\\s?\\p{L}+",
- "\\s?\\p{P}+",
- "[一-龥ࠀ-一가-퟿]+",
- "\\p{N}",
- });
- break;
- case LLAMA_VOCAB_PRE_TYPE_FALCON:
- word_collection = unicode_regex_split(text, {
- "[\\p{P}\\$\\+<=>\\^~\\|]+",
- "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
- "[0-9][0-9][0-9]",
- });
- break;
- case LLAMA_VOCAB_PRE_TYPE_MPT:
- // TODO: MPT pre-tokenization regexes are unknown
- // the following are close, but not exact. run the following:
- // ./bin/test-tokenizer-0 ../models/ggml-vocab-mpt.gguf
- GGML_ASSERT("MPT pre-tokenization regexes are unknown - fixes needed");
- word_collection = unicode_regex_split(text, {
- "\\s?\\p{L}+",
- "\\s?\\p{P}+",
- "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
- });
- break;
- case LLAMA_VOCAB_PRE_TYPE_STARCODER:
- case LLAMA_VOCAB_PRE_TYPE_REFACT:
- case LLAMA_VOCAB_PRE_TYPE_COMMAND_R:
- word_collection = unicode_regex_split(text, {
- "\\p{N}",
- "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
- });
- break;
- case LLAMA_VOCAB_PRE_TYPE_GPT2:
- case LLAMA_VOCAB_PRE_TYPE_OLMO:
- word_collection = unicode_regex_split(text, {
- "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
- });
- break;
- case LLAMA_VOCAB_PRE_TYPE_STABLELM2:
- case LLAMA_VOCAB_PRE_TYPE_QWEN2:
- word_collection = unicode_regex_split(text, {
- // original regex from tokenizer.json
- // "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"
- "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
- });
- break;
- case LLAMA_VOCAB_PRE_TYPE_PORO:
- word_collection = unicode_regex_split(text, {
- " ?[^(\\s|.,!?…。,、।۔،)]+",
- });
- break;
- default:
- // default regex for BPE tokenization pre-processing
- word_collection = unicode_regex_split(text, {
- "[\\p{P}\\$\\+<=>\\^~\\|]+",
- "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
- "\\p{N}+",
- "[0-9][0-9][0-9]",
- });
- break;
- }
+ llm_tokenizer_bpe(const llama_vocab & vocab): vocab(vocab) {
+ GGML_ASSERT(vocab.type == LLAMA_VOCAB_TYPE_BPE);
+ switch (vocab.type_pre) {
+ case LLAMA_VOCAB_PRE_TYPE_LLAMA3:
+ regex_exprs = {
+ // original regex from tokenizer.json
+ //"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
+
+ // adapted: https://github.com/ggerganov/llama.cpp/pull/6920#issuecomment-2080233989
+ "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
+ };
+ break;
+ case LLAMA_VOCAB_PRE_TYPE_DBRX:
+ case LLAMA_VOCAB_PRE_TYPE_SMAUG:
+ regex_exprs = {
+ // same as llama3
+ "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
+ };
+ break;
+ case LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_LLM:
+ regex_exprs = {
+ "[\r\n]",
+ "\\s?[A-Za-zµÀ-ÖØ-öø-ƺƼ-ƿDŽ-ʓʕ-ʯͰ-ͳͶͷͻ-ͽͿΆΈ-ΊΌΎ-ΡΣ-ϵϷ-ҁҊ-ԯԱ-ՖႠ-ჅᎠ-Ᏽᏸ-ᏽᲐ-ᲺᲽ-Ჿᴀ-ᴫᵫ-ᵷᵹ-ᶚḀ-ἕἘ-Ἕἠ-ὅὈ-Ὅὐ-ὗὙὛὝὟ-ώᾀ-ᾴᾶ-ᾼιῂ-ῄῆ-ῌῐ-ΐῖ-Ίῠ-Ῥῲ-ῴῶ-ῼℂℇℊ-ℓℕℙ-ℝℤΩℨK-ℭℯ-ℴℹℼ-ℿⅅ-ⅉⅎↃↄⰀ-ⱻⱾ-ⳤⳫ-ⳮⳲⳳꙀ-ꙭꚀ-ꚛꜢ-ꝯꝱ-ꞇꞋ-ꞎꭰ-ꮿff-stﬓ-ﬗA-Za-z𐐀-𐑏𐒰-𐓓𐓘-𐓻𐲀-𐲲𐳀-𐳲𑢠-𑣟𞤀-𞥃]+",
+ "\\s?[!-/:-~!-/:-~‘-‟ -。]+",
+ "\\s+$",
+ "[一-龥ࠀ-一가-퟿]+",
+ "\\p{N}+",
+ };
+ break;
+ case LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_CODER:
+ regex_exprs = {
+ "[\r\n]",
+ "\\s?\\p{L}+",
+ "\\s?\\p{P}+",
+ "[一-龥ࠀ-一가-퟿]+",
+ "\\p{N}",
+ };
+ break;
+ case LLAMA_VOCAB_PRE_TYPE_FALCON:
+ regex_exprs = {
+ "[\\p{P}\\$\\+<=>\\^~\\|`]+",
+ "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
+ "[0-9][0-9][0-9]",
+ };
+ break;
+ case LLAMA_VOCAB_PRE_TYPE_MPT:
+ // TODO: MPT pre-tokenization regexes are unknown
+ // the following are close, but not exact. run the following:
+ // ./bin/test-tokenizer-0 ../models/ggml-vocab-mpt.gguf
+ GGML_ASSERT("MPT pre-tokenization regexes are unknown - fixes needed");
+ regex_exprs = {
+ "\\s?\\p{L}+",
+ "\\s?\\p{P}+",
+ "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
+ };
+ break;
+ case LLAMA_VOCAB_PRE_TYPE_STARCODER:
+ case LLAMA_VOCAB_PRE_TYPE_REFACT:
+ case LLAMA_VOCAB_PRE_TYPE_COMMAND_R:
+ regex_exprs = {
+ "\\p{N}",
+ "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
+ };
+ break;
+ case LLAMA_VOCAB_PRE_TYPE_GPT2:
+ case LLAMA_VOCAB_PRE_TYPE_OLMO:
+ regex_exprs = {
+ "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
+ };
+ break;
+ case LLAMA_VOCAB_PRE_TYPE_STABLELM2:
+ case LLAMA_VOCAB_PRE_TYPE_QWEN2:
+ regex_exprs = {
+ // original regex from tokenizer.json
+ // "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"
+ "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
+ };
+ break;
+ case LLAMA_VOCAB_PRE_TYPE_PORO:
+ regex_exprs = {
+ " ?[^(\\s|.,!?…。,、।۔،)]+",
+ };
break;
default:
- GGML_ASSERT(false);
+ // default regex for BPE tokenization pre-processing
+ regex_exprs = {
+ "[\\p{P}\\$\\+<=>\\^~\\|]+",
+ "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
+ "\\p{N}+",
+ "[0-9][0-9][0-9]",
+ };
break;
}
+ }
+
+ void append(const llama_vocab::id token_id, std::vector<llama_vocab::id> & output) const {
+ output.push_back(token_id);
+ }
+
+ bool append_bos(std::vector<llama_vocab::id> & output) const {
+ if (vocab.tokenizer_add_bos) {
+ GGML_ASSERT(vocab.special_bos_id != -1);
+ output.push_back(vocab.special_bos_id);
+ return true;
+ }
+ return false;
+ }
+
+ bool append_eos(std::vector<llama_vocab::id> & output) const {
+ if (vocab.tokenizer_add_eos) {
+ GGML_ASSERT(vocab.special_eos_id != -1);
+ output.push_back(vocab.special_eos_id);
+ return true;
+ }
+ return false;
+ }
+
+ void check_double_bos_eos(const std::vector<llama_vocab::id> & output) const {
+ if (vocab.tokenizer_add_bos && output.size() >= 2 && output[1] == vocab.special_bos_id) {
+ LLAMA_LOG_WARN(
+ "%s: Added a BOS token to the prompt as specified by the model but the prompt "
+ "also starts with a BOS token. So now the final prompt starts with 2 BOS tokens. "
+ "Are you sure this is what you want?\n", __FUNCTION__);
+ }
+ if (vocab.tokenizer_add_eos && output.size() >= 2 && *(output.end()-2) == vocab.special_eos_id) {
+ LLAMA_LOG_WARN(
+ "%s: Added a EOS token to the prompt as specified by the model but the prompt "
+ "also ends with a EOS token. So now the final prompt ends with 2 EOS tokens. "
+ "Are you sure this is what you want?\n", __FUNCTION__);
+ }
+ }
+
+ void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) {
+ int final_prev_index = -1;
+
+ const auto word_collection = unicode_regex_split(text, regex_exprs);
symbols_final.clear();
@@ -13274,7 +13315,7 @@ struct llm_tokenizer_bpe {
int index = 0;
size_t offset = 0;
- if (ignore_merges && vocab.token_to_id.find(word) != vocab.token_to_id.end()) {
+ if (vocab.tokenizer_ignore_merges && vocab.token_to_id.find(word) != vocab.token_to_id.end()) {
symbols.emplace_back(llm_symbol{-1, -1, word.c_str(), word.size()});
offset = word.size();
}
@@ -13355,10 +13396,9 @@ struct llm_tokenizer_bpe {
for (auto j = str.begin(); j != str.end(); ++j) {
std::string byte_str(1, *j);
auto token_multibyte = vocab.token_to_id.find(byte_str);
- if (token_multibyte == vocab.token_to_id.end()) {
- throw std::runtime_error("ERROR: byte not found in vocab");
+ if (token_multibyte != vocab.token_to_id.end()) {
+ output.push_back(token_multibyte->second);
}
- output.push_back((*token_multibyte).second);
}
} else {
output.push_back((*token).second);
@@ -13397,6 +13437,8 @@ private:
const llama_vocab & vocab;
+ std::vector<std::string> regex_exprs;
+
std::vector<llm_symbol> symbols;
std::vector<llm_symbol> symbols_final;
@@ -13677,7 +13719,7 @@ static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab &
bool is_prev_special = false;
- if (add_special && vocab.special_add_bos != 0) {
+ if (add_special && vocab.tokenizer_add_bos) {
GGML_ASSERT(vocab.special_bos_id != -1);
output.push_back(vocab.special_bos_id);
is_prev_special = true;
@@ -13687,7 +13729,7 @@ static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab &
if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length);
- if (vocab.add_space_prefix) {
+ if (vocab.tokenizer_add_space_prefix) {
if (!output.size() || is_prev_special) { // prefix with space if first token
raw_text = " " + raw_text;
}
@@ -13705,23 +13747,24 @@ static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab &
}
}
- if (add_special && vocab.special_add_bos != 0 && output.size() >= 2 && output[1] == vocab.special_bos_id) {
+ if (add_special && vocab.tokenizer_add_bos && output.size() >= 2 && output[1] == vocab.special_bos_id) {
LLAMA_LOG_WARN(
"%s: Added a BOS token to the prompt as specified by the model but the prompt "
"also starts with a BOS token. So now the final prompt starts with 2 BOS tokens. "
"Are you sure this is what you want?\n", __FUNCTION__);
}
- if (add_special && vocab.special_add_eos == 1) {
+ if (add_special && vocab.tokenizer_add_eos) {
GGML_ASSERT(vocab.special_eos_id != -1);
output.push_back(vocab.special_eos_id);
}
} break;
case LLAMA_VOCAB_TYPE_BPE:
{
- if (add_special && vocab.special_add_bos != 0) {
- GGML_ASSERT(vocab.special_bos_id != -1);
- output.push_back(vocab.special_bos_id);
+ llm_tokenizer_bpe tokenizer(vocab);
+
+ if (add_special) {
+ tokenizer.append_bos(output);
}
for (const auto & fragment : fragment_buffer) {
@@ -13731,23 +13774,15 @@ static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab &
#ifdef PRETOKENIZERDEBUG
LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str());
#endif
- llm_tokenizer_bpe tokenizer(vocab);
tokenizer.tokenize(raw_text, output);
} else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
- output.push_back(fragment.token);
+ tokenizer.append(fragment.token, output);
}
}
- if (add_special && vocab.special_add_bos != 0 && output.size() >= 2 && output[1] == vocab.special_bos_id) {
- LLAMA_LOG_WARN(
- "%s: Added a BOS token to the prompt as specified by the model but the prompt "
- "also starts with a BOS token. So now the final prompt starts with 2 BOS tokens. "
- "Are you sure this is what you want?\n", __FUNCTION__);
- }
-
- if (add_special && vocab.special_add_eos == 1) {
- GGML_ASSERT(vocab.special_add_eos != -1);
- output.push_back(vocab.special_eos_id);
+ if (add_special) {
+ tokenizer.append_eos(output);
+ tokenizer.check_double_bos_eos(output);
}
} break;
case LLAMA_VOCAB_TYPE_WPM:
@@ -18320,11 +18355,11 @@ llama_token llama_token_nl(const struct llama_model * model) {
}
int32_t llama_add_bos_token(const struct llama_model * model) {
- return model->vocab.special_add_bos;
+ return model->vocab.tokenizer_add_bos;
}
int32_t llama_add_eos_token(const struct llama_model * model) {
- return model->vocab.special_add_eos;
+ return model->vocab.tokenizer_add_eos;
}
llama_token llama_token_prefix(const struct llama_model * model) {
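
A note on the public API touched by the last hunk: llama_add_bos_token() and llama_add_eos_token() keep their int32_t return type but now forward the boolean tokenizer_add_bos / tokenizer_add_eos flags, so the previous -1 ("unknown") value can no longer be returned. A hedged caller-side sketch follows; the helper name is hypothetical and not part of llama.cpp.

// Hypothetical caller-side helper, assuming llama.h from this revision.
// Before this patch the return value could be -1 ("unknown"); afterwards only
// 0 or 1 occurs, so treating the result as a plain boolean is sufficient.
#include "llama.h"

static bool should_add_bos(const struct llama_model * model) {
    return llama_add_bos_token(model) != 0;
}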