diff options
Diffstat (limited to 'llama.cpp')
-rw-r--r-- | llama.cpp | 135 |
1 files changed, 42 insertions, 93 deletions
@@ -68,10 +68,12 @@ #include <cstdio> #include <cstring> #include <ctime> +#include <cwctype> #include <forward_list> #include <fstream> #include <functional> #include <initializer_list> +#include <locale> #include <map> #include <memory> #include <mutex> @@ -8941,37 +8943,46 @@ struct llm_tokenizer_wpm { } std::vector<std::string> preprocess(const std::string & text) { - std::string ori_str = normalize(text); - uint64_t ori_size = ori_str.size(); + // normalization form D + std::vector<uint32_t> codepoints = codepoints_from_utf8(text); + std::vector<uint32_t> nfd_codepoints; + for (uint32_t code : codepoints) { + auto it = nfd_map.find(code); + if (it != nfd_map.end()) { + for (uint32_t c : it->second) { + nfd_codepoints.push_back(c); + } + } else { + nfd_codepoints.push_back(code); + } + } - // single punct / single symbol / single digit - // baseline: add whitespace on the left and right of punct and chinese characters - std::vector<std::string> words; + // strip accents, strip control, uniformize whitespace, + // to lowercase, pad chinese characters, pad punctuation std::string new_str = ""; - uint64_t i = 0; - while (i < ori_size) { - int utf_char_len = utf8_len(ori_str[i]); - if ((utf_char_len == 1) && ispunct(ori_str[i])) { - new_str += " "; - new_str += ori_str[i]; - new_str += " "; - i += 1; + for (uint32_t code : nfd_codepoints) { + int type = codepoint_type(code); + if (type == CODEPOINT_TYPE_ACCENT_MARK || type == CODEPOINT_TYPE_CONTROL) { + continue; } - else if ((utf_char_len == 3) && is_chinese_char(ori_str.substr(i, 3))) { + code = to_lower(code); + if (type == CODEPOINT_TYPE_WHITESPACE) { + code = ' '; + } + std::string s = codepoint_to_utf8(code); + if (type == CODEPOINT_TYPE_PUNCTUATION || is_ascii_punct(code) || is_chinese_char(code)) { new_str += " "; - new_str += ori_str.substr(i, 3); + new_str += s; new_str += " "; - i += 3; - } - else { - new_str += ori_str[i]; - i += 1; + } else { + new_str += s; } } // split by whitespace uint64_t l = 0; 
uint64_t r = 0; + std::vector<std::string> words; while (r < new_str.size()) { // if is whitespace if (isspace(new_str[r])) { @@ -8989,47 +9000,20 @@ struct llm_tokenizer_wpm { return words; } - std::string normalize(const std::string & text) { - // TODO: handle chinese characters? https://github.com/huggingface/tokenizers/blob/ef5f50605ddf9f8caef1598c0e4853862b9707a7/tokenizers/src/normalizers/bert.rs#L98 - std::string text2 = strip_accents(text); - for (size_t i = 0; i < text2.size(); i += utf8_len(text2[i])) { - char c = text2[i]; - if (c >= 'A' && c <= 'Z') { - text2[i] = c - 'A' + 'a'; - } + uint32_t to_lower(uint32_t code) { +#if defined(_WIN32) + if (code > 0xFFFF) { + return code; } - return text2; +#endif + return std::tolower(wchar_t(code), std::locale("en_US.UTF-8")); } - bool is_chinese_char(const std::string & str) { - int len = str.length(); - unsigned int codepoint = 0; - int num_bytes = 0; - int i = 0; - unsigned char ch = static_cast<unsigned char>(str[i]); - if (ch <= 0x7f) { - codepoint = ch; - num_bytes = 1; - } else if ((ch >> 5) == 0x06) { - codepoint = ch & 0x1f; - num_bytes = 2; - } else if ((ch >> 4) == 0x0e) { - codepoint = ch & 0x0f; - num_bytes = 3; - } else if ((ch >> 3) == 0x1e) { - codepoint = ch & 0x07; - num_bytes = 4; - } - for (int j = 1; j < num_bytes; ++j) { - if (i + j >= len) { - return false; // incomplete UTF-8 character - } - unsigned char next_ch = static_cast<unsigned char>(str[i + j]); - if ((next_ch >> 6) != 0x02) { - return false; // invalid trailing byte - } - codepoint = (codepoint << 6) | (next_ch & 0x3f); - } + bool is_ascii_punct(uint32_t code) { + return code < 256 && ispunct(code); + } + + bool is_chinese_char(uint32_t codepoint) { if ((codepoint >= 0x4E00 && codepoint <= 0x9FFF) || (codepoint >= 0x3400 && codepoint <= 0x4DBF) || (codepoint >= 0x20000 && codepoint <= 0x2A6DF) || @@ -9045,41 +9029,6 @@ struct llm_tokenizer_wpm { return false; } - std::string strip_accents(const std::string & input_string) { - 
std::string resultString; - std::map<std::string, char> accent_map = { - {"À", 'A'}, {"Á", 'A'}, {"Â", 'A'}, {"Ã", 'A'}, {"Ä", 'A'}, {"Å", 'A'}, - {"à", 'a'}, {"á", 'a'}, {"â", 'a'}, {"ã", 'a'}, {"ä", 'a'}, {"å", 'a'}, - {"È", 'E'}, {"É", 'E'}, {"Ê", 'E'}, {"Ë", 'E'}, {"è", 'e'}, {"é", 'e'}, - {"ê", 'e'}, {"ë", 'e'}, {"Ì", 'I'}, {"Í", 'I'}, {"Î", 'I'}, {"Ï", 'I'}, - {"ì", 'i'}, {"í", 'i'}, {"î", 'i'}, {"ï", 'i'}, {"Ò", 'O'}, {"Ó", 'O'}, - {"Ô", 'O'}, {"Õ", 'O'}, {"Ö", 'O'}, {"ò", 'o'}, {"ó", 'o'}, {"ô", 'o'}, - {"õ", 'o'}, {"ö", 'o'}, {"Ù", 'U'}, {"Ú", 'U'}, {"Û", 'U'}, {"Ü", 'U'}, - {"ù", 'u'}, {"ú", 'u'}, {"û", 'u'}, {"ü", 'u'}, {"Ý", 'Y'}, {"ý", 'y'}, - {"Ç", 'C'}, {"ç", 'c'}, {"Ñ", 'N'}, {"ñ", 'n'}, - }; - - for (size_t i = 0; i < input_string.length();) { - int len = utf8_len(input_string[i]); - std::string curChar = input_string.substr(i, len); - auto iter = accent_map.find(curChar); - if (iter != accent_map.end()) { - resultString += iter->second; - } else { - resultString += curChar; - } - i += len; - } - - return resultString; - } - - static size_t utf8_len(char src) { - const size_t lookup[] = {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4}; - uint8_t highbits = static_cast<uint8_t>(src) >> 4; - return lookup[highbits]; - } - const llama_vocab & vocab; }; |