Diffstat (limited to 'llama.cpp')
-rw-r--r--  llama.cpp  278
1 file changed, 226 insertions, 52 deletions
diff --git a/llama.cpp b/llama.cpp
index 05b570bd..4a61eecd 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -1,6 +1,8 @@
#define LLAMA_API_INTERNAL
#include "llama.h"
+#include "unicode.h"
+
#include "ggml.h"
#include "ggml-alloc.h"
@@ -1980,6 +1982,7 @@ static void llm_load_vocab(
for (int i = 0; i < n_merges; i++) {
const std::string word = gguf_get_arr_str(ctx, merges_keyidx, i);
+ GGML_ASSERT(codepoints_from_utf8(word).size() > 0);
std::string first;
std::string second;
@@ -2014,6 +2017,7 @@ static void llm_load_vocab(
for (uint32_t i = 0; i < n_vocab; i++) {
std::string word = gguf_get_arr_str(ctx, token_idx, i);
+ GGML_ASSERT(codepoints_from_utf8(word).size() > 0);
vocab.token_to_id[word] = i;
@@ -2022,12 +2026,13 @@ static void llm_load_vocab(
token_data.score = scores ? scores[i] : 0.0f;
token_data.type = toktypes ? (llama_token_type) toktypes[i] : LLAMA_TOKEN_TYPE_NORMAL;
}
+ GGML_ASSERT(vocab.id_to_token.size() == vocab.token_to_id.size());
// determine the newline token: LLaMA "<0x0A>" == 10 == '\n', Falcon 193 == '\n'
if (vocab.type == LLAMA_VOCAB_TYPE_SPM) {
vocab.linefeed_id = llama_byte_to_token(vocab, '\n');
} else {
- vocab.linefeed_id = llama_tokenize_internal(vocab, "\n", false)[0];
+ vocab.linefeed_id = llama_tokenize_internal(vocab, "\u010A", false)[0];
}
// special tokens
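
Note on the "\u010A" lookup above: GPT-2-style BPE vocabs store every token over a byte-to-unicode remapped alphabet, and byte 0x0A ('\n') lands on codepoint U+010A ("Ċ"), so the linefeed token can only be found under its encoded spelling. A minimal, self-contained sketch of that remapping table (standalone helper names, not the unicode.h API):

    #include <cstdint>
    #include <map>
    #include <string>

    // Encode a small codepoint (< U+0800) as UTF-8; enough for this table.
    static std::string cp_to_utf8(int cp) {
        std::string s;
        if (cp < 0x80) {
            s.push_back((char) cp);
        } else {
            s.push_back((char) (0xC0 | (cp >> 6)));
            s.push_back((char) (0x80 | (cp & 0x3F)));
        }
        return s;
    }

    // GPT-2 byte-to-unicode table: printable bytes map to themselves,
    // the 68 non-printable bytes map to U+0100, U+0101, ... in order.
    static std::map<uint8_t, std::string> gpt2_byte_table() {
        std::map<uint8_t, std::string> m;
        int n = 0;
        for (int b = 0; b < 256; ++b) {
            const bool printable = (b >= 0x21 && b <= 0x7E)
                                || (b >= 0xA1 && b <= 0xAC)
                                || (b >= 0xAE && b <= 0xFF);
            m[(uint8_t) b] = cp_to_utf8(printable ? b : 256 + n++);
        }
        return m;
    }

    // gpt2_byte_table()[0x0A] == "\u010A" -- hence the lookup string above.
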
@@ -4236,18 +4241,41 @@ static bool llama_is_byte_token(const llama_vocab & vocab, llama_token id) {
return vocab.id_to_token[id].type == LLAMA_TOKEN_TYPE_BYTE;
}
-static uint8_t llama_token_to_byte(const llama_vocab & vocab, llama_token id) {
+static bool llama_is_user_defined_token(const llama_vocab & vocab, llama_token id) {
+ return vocab.id_to_token[id].type == LLAMA_TOKEN_TYPE_USER_DEFINED;
+}
+
+static uint8_t llama_token_to_byte(const llama_vocab & vocab, llama_token id) {
GGML_ASSERT(llama_is_byte_token(vocab, id));
const auto& token_data = vocab.id_to_token.at(id);
- auto buf = token_data.text.substr(3, 2);
- return strtol(buf.c_str(), NULL, 16);
+ switch (llama_vocab_get_type(vocab)) {
+ case LLAMA_VOCAB_TYPE_SPM: {
+ auto buf = token_data.text.substr(3, 2);
+ return strtol(buf.c_str(), NULL, 16);
+ }
+ case LLAMA_VOCAB_TYPE_BPE: {
+ GGML_ASSERT(false); // byte-to-token conversion is not used with BPE vocabs; this path is intentionally unreachable
+ return unicode_to_bytes_bpe(token_data.text);
+ }
+ default:
+ GGML_ASSERT(false);
+ }
}
static llama_token llama_byte_to_token(const llama_vocab & vocab, uint8_t ch) {
- char buf[7];
- int result = snprintf(buf, sizeof(buf), "<0x%02X>", ch);
- GGML_ASSERT(0 <= result && result < 7);
- return vocab.token_to_id.at(buf);
+ switch (llama_vocab_get_type(vocab)) {
+ case LLAMA_VOCAB_TYPE_SPM: {
+ char buf[7];
+ int result = snprintf(buf, sizeof(buf), "<0x%02X>", ch);
+ GGML_ASSERT(0 <= result && result < 7);
+ return vocab.token_to_id.at(buf);
+ }
+ case LLAMA_VOCAB_TYPE_BPE: {
+ return vocab.token_to_id.at(bytes_to_unicode_bpe(ch));
+ }
+ default:
+ GGML_ASSERT(false);
+ }
}
static void llama_escape_whitespace(std::string & text) {
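
The SPM branches above depend on byte-fallback tokens being spelled literally as "<0xNN>": substr(3, 2) grabs the two hex digits and strtol turns them back into the byte. A self-contained round-trip check of that convention (independent of llama.cpp):

    #include <cassert>
    #include <cstdio>
    #include <cstdlib>
    #include <string>

    int main() {
        const unsigned char ch = 0x0A;               // '\n'
        char buf[7];
        snprintf(buf, sizeof(buf), "<0x%02X>", ch);  // -> "<0x0A>"
        const std::string token_text = buf;
        const std::string hex = token_text.substr(3, 2);  // -> "0A"
        assert(strtol(hex.c_str(), NULL, 16) == ch);
        return 0;
    }
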
@@ -4527,15 +4555,9 @@ struct llm_tokenizer_bpe {
std::string byte_str(1, *j);
auto token_multibyte = vocab.token_to_id.find(byte_str);
if (token_multibyte == vocab.token_to_id.end()) {
- try {
- llama_token token_byte = llama_byte_to_token(vocab, *j);
- output.push_back(token_byte);
- } catch (const std::out_of_range & err) {
- fprintf(stderr,"ERROR: byte not found in vocab: '%s'\n", byte_str.c_str());
- }
- } else {
- output.push_back((*token_multibyte).second);
+ throw std::runtime_error("ERROR: byte not found in vocab");
}
+ output.push_back((*token_multibyte).second);
}
} else {
output.push_back((*token).second);
@@ -4572,23 +4594,144 @@ private:
work_queue.push(bigram);
}
- // probably not 100% correct
- static std::vector<std::string> bpe_gpt2_preprocess(const std::string & text) {
- std::vector<std::string> words;
+ std::vector<std::string> bpe_gpt2_preprocess(const std::string & text) {
+ std::vector<std::string> bpe_words;
+ std::vector<std::string> bpe_encoded_words;
+
+ std::string token = "";
+ // GPT2 system regex: 's|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+
+ bool collecting_numeric = false;
+ bool collecting_letter = false;
+ bool collecting_special = false;
+ bool collecting_whitespace_lookahead = false;
+ bool collecting = false;
+
+ std::vector<std::string> text_utf;
+ text_utf.reserve(text.size());
+ bpe_words.reserve(text.size());
+ bpe_encoded_words.reserve(text.size());
+
+ auto cps = codepoints_from_utf8(text);
+ for (size_t i = 0; i < cps.size(); ++i)
+ text_utf.emplace_back(codepoint_to_utf8(cps[i]));
+
+ for (int i = 0; i < (int)text_utf.size(); i++) {
+ const std::string & utf_char = text_utf[i];
+ bool split_condition = false;
+ int bytes_remain = text_utf.size() - i; // number of codepoints remaining, despite the name
+ // forward backward lookups
+ const std::string & utf_char_next = (i + 1 < (int)text_utf.size()) ? text_utf[i + 1] : "";
+ const std::string & utf_char_next_next = (i + 2 < (int)text_utf.size()) ? text_utf[i + 2] : "";
+
+ // handling contractions
+ if (!split_condition && bytes_remain >= 2) {
+ // 's|'t|'m|'d
+ if (utf_char == "\'" && (utf_char_next == "s" || utf_char_next == "t" || utf_char_next == "m" || utf_char_next == "d")) {
+ split_condition = true;
+ }
+ if (split_condition) {
+ if (token.size()) {
+ bpe_words.emplace_back(token); // push previous content as token
+ }
+ token = utf_char + utf_char_next;
+ bpe_words.emplace_back(token);
+ token = "";
+ i++;
+ continue;
+ }
+ }
+ if (!split_condition && bytes_remain >= 3) {
+ // 're|'ve|'ll
+ if (utf_char == "\'" && (
+ (utf_char_next == "r" || utf_char_next_next == "e") ||
+ (utf_char_next == "v" || utf_char_next_next == "e") ||
+ (utf_char_next == "l" || utf_char_next_next == "l"))
+ ) {
+ split_condition = true;
+ }
+ if (split_condition) {
+ // current token + next token can be defined
+ if (token.size()) {
+ bpe_words.emplace_back(token); // push previous content as token
+ }
+ token = utf_char + utf_char_next + utf_char_next_next;
+ bpe_words.emplace_back(token); // the contraction
+ token = "";
+ i += 2;
+ continue;
+ }
+ }
+
+ if (!split_condition && !collecting) {
+ if (codepoint_type(utf_char) == CODEPOINT_TYPE_LETTER || (!token.size() && utf_char == " " && codepoint_type(utf_char_next) == CODEPOINT_TYPE_LETTER)) {
+ collecting_letter = true;
+ collecting = true;
+ }
+ else if (codepoint_type(utf_char) == CODEPOINT_TYPE_DIGIT || (!token.size() && utf_char == " " && codepoint_type(utf_char_next) == CODEPOINT_TYPE_DIGIT)) {
+ collecting_numeric = true;
+ collecting = true;
+ }
+ else if (
+ ((codepoint_type(utf_char) != CODEPOINT_TYPE_LETTER && codepoint_type(utf_char) != CODEPOINT_TYPE_DIGIT) && (codepoint_type(utf_char) != CODEPOINT_TYPE_WHITESPACE)) ||
+ (!token.size() && utf_char == " " && codepoint_type(utf_char_next) != CODEPOINT_TYPE_LETTER && codepoint_type(utf_char_next) != CODEPOINT_TYPE_DIGIT && codepoint_type(utf_char_next) != CODEPOINT_TYPE_WHITESPACE)
+ ) {
+ collecting_special = true;
+ collecting = true;
+ }
+ else if (codepoint_type(utf_char) == CODEPOINT_TYPE_WHITESPACE && codepoint_type(utf_char_next) == CODEPOINT_TYPE_WHITESPACE) {
+ collecting_whitespace_lookahead = true;
+ collecting = true;
+ }
+ else if (codepoint_type(utf_char) == CODEPOINT_TYPE_WHITESPACE) {
+ split_condition = true;
+ }
+ }
+ else if (!split_condition && collecting) {
+ if (collecting_letter && codepoint_type(utf_char) != CODEPOINT_TYPE_LETTER) {
+ split_condition = true;
+ }
+ else if (collecting_numeric && codepoint_type(utf_char) != CODEPOINT_TYPE_DIGIT) {
+ split_condition = true;
+ }
+ else if (collecting_special && (codepoint_type(utf_char) == CODEPOINT_TYPE_LETTER || codepoint_type(utf_char) == CODEPOINT_TYPE_DIGIT || codepoint_type(utf_char) == CODEPOINT_TYPE_WHITESPACE)) {
+ split_condition = true;
+ }
+ else if (collecting_whitespace_lookahead && codepoint_type(utf_char_next) != CODEPOINT_TYPE_WHITESPACE) {
+ split_condition = true;
+ }
+ }
+
+ if (utf_char_next == "") {
+ split_condition = true; // final
+ token += utf_char;
+ }
- // ref: https://github.com/openai/gpt-2/blob/a74da5d99abaaba920de8131d64da2862a8f213b/src/encoder.py#L53
- const std::string pattern = R"('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\s[:alpha:][:digit:]]+|\s+(?!\S)|\s+)";
- const std::regex re(pattern);
+ if (split_condition) {
+ if (token.size()) {
+ bpe_words.emplace_back(token);
+ }
+ token = utf_char;
+ collecting = false;
+ collecting_letter = false;
+ collecting_numeric = false;
+ collecting_special = false;
+ collecting_whitespace_lookahead = false;
+ }
+ else {
+ token += utf_char;
+ }
+ }
- auto words_begin = std::sregex_iterator(text.begin(), text.end(), re);
- auto words_end = std::sregex_iterator();
- auto n_words = std::distance(words_begin, words_end);
- words.reserve(n_words);
- for (auto it = words_begin; it != words_end; ++it) {
- words.push_back(it->str());
+ for (std::string & word : bpe_words) {
+ std::string encoded_token = "";
+ for (char & c : word) {
+ encoded_token += bytes_to_unicode_bpe(c);
+ }
+ bpe_encoded_words.emplace_back(encoded_token);
}
- return words;
+ return bpe_encoded_words;
}
const llama_vocab & vocab;
@@ -7532,35 +7675,66 @@ int llama_tokenize(
return res.size();
}
+static std::string llama_decode_text(const std::string & text) {
+ std::string decoded_text;
+ auto unicode_sequences = codepoints_from_utf8(text);
+ for (auto& unicode_sequence : unicode_sequences) {
+ decoded_text += unicode_to_bytes_bpe(codepoint_to_utf8(unicode_sequence));
+ }
+
+ return decoded_text;
+}
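
For intuition, llama_decode_text inverts the byte-encoding described earlier: the stored piece "Ġcats" (bytes 0xC4 0xA0 for U+0120, then "cats") decodes back to the raw text " cats".

    std::string piece = llama_decode_text("\xC4\xA0" "cats"); // == " cats"
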
+
// does not write null-terminator to buf
int llama_token_to_piece(const struct llama_model * model, llama_token token, char * buf, int length) {
if (0 <= token && token < llama_n_vocab(model)) {
- if (llama_is_normal_token(model->vocab, token)) {
- std::string result = model->vocab.id_to_token[token].text;
- if (llama_vocab_get_type(model->vocab) == LLAMA_VOCAB_TYPE_SPM) {
+ switch (llama_vocab_get_type(model->vocab)) {
+ case LLAMA_VOCAB_TYPE_SPM: {
+ if (llama_is_normal_token(model->vocab, token)) {
+ std::string result = model->vocab.id_to_token[token].text;
llama_unescape_whitespace(result);
+ if (length < (int) result.length()) {
+ return -result.length();
+ }
+ memcpy(buf, result.c_str(), result.length());
+ return result.length();
+ } else if (llama_is_unknown_token(model->vocab, token)) { // NOLINT
+ if (length < 3) {
+ return -3;
+ }
+ memcpy(buf, "\xe2\x96\x85", 3);
+ return 3;
+ } else if (llama_is_control_token(model->vocab, token)) {
+ ; // do nothing: control tokens are not rendered
+ } else if (llama_is_byte_token(model->vocab, token)) {
+ if (length < 1) {
+ return -1;
+ }
+ buf[0] = llama_token_to_byte(model->vocab, token);
+ return 1;
+ } else {
+ GGML_ASSERT(false);
}
- if (length < (int) result.length()) {
- return -result.length();
- }
- memcpy(buf, result.c_str(), result.length());
- return result.length();
- } else if (llama_is_unknown_token(model->vocab, token)) { // NOLINT
- if (length < 3) {
- return -3;
- }
- buf[0] = '\xe2';
- buf[1] = '\x96';
- buf[2] = '\x85';
- return 3;
- } else if (llama_is_control_token(model->vocab, token)) {
- // do nothing
- } else if (llama_is_byte_token(model->vocab, token)) {
- if (length < 1) {
- return -1;
+ break;
+ }
+ case LLAMA_VOCAB_TYPE_BPE: {
+ if (llama_is_normal_token(model->vocab, token)) {
+ std::string result = model->vocab.id_to_token[token].text;
+ result = llama_decode_text(result);
+ if (length < (int) result.length()) {
+ return -result.length();
+ }
+ memcpy(buf, result.c_str(), result.length());
+ return result.length();
+ } else if (llama_is_control_token(model->vocab, token)) {
+ ; // do nothing: control tokens are not rendered
+ } else {
+ GGML_ASSERT(false);
}
- buf[0] = llama_token_to_byte(model->vocab, token);
- return 1;
+ break;
+ }
+ default:
+ GGML_ASSERT(false);
}
}
return 0;
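
A hedged usage sketch for the length convention above (a negative return means the buffer was too small, and its magnitude is the required size; the wrapper name is ours, not part of the API):

    static std::string token_to_piece(const struct llama_model * model, llama_token token) {
        std::string piece(8, '\0');                  // small initial guess
        int n = llama_token_to_piece(model, token, &piece[0], (int) piece.size());
        if (n < 0) {
            piece.resize(-n);                        // grow to the required size
            n = llama_token_to_piece(model, token, &piece[0], (int) piece.size());
        }
        piece.resize(n);                             // n == 0 for control tokens
        return piece;
    }
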