diff options
author | Georgi Gerganov <ggerganov@gmail.com> | 2024-02-13 15:14:22 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-02-13 15:14:22 +0200 |
commit | cf45252a7cfcb998bade46a886e20477cecc538a (patch) | |
tree | 89400db2102fb39809b61af13a7534e49ec5c27e /unicode.h | |
parent | 03bf161eb6dea6400ee49c6dc6b69bdcfa9fd3fc (diff) |
tests : multi-thread the tokenizer tests (#5474)
* tests : multi-thread the tokenizer tests
ggml-ci
* unicode : fix data race for unidentified codepoints
ggml-ci
* unicode : minor style fixes
ggml-ci
Diffstat (limited to 'unicode.h')
-rw-r--r-- | unicode.h | 72 |
1 files changed, 42 insertions, 30 deletions
@@ -264,26 +264,29 @@ static uint32_t codepoint_from_utf8(const std::string & utf8, size_t & offset) { offset += 1; return result; } - else if (!(utf8[offset + 0] & 0x40)) { + if (!(utf8[offset + 0] & 0x40)) { throw std::invalid_argument("invalid character"); } - else if (!(utf8[offset + 0] & 0x20)) { - if (offset + 1 >= utf8.size() || ! ((utf8[offset + 1] & 0xc0) == 0x80)) + if (!(utf8[offset + 0] & 0x20)) { + if (offset + 1 >= utf8.size() || ! ((utf8[offset + 1] & 0xc0) == 0x80)) { throw std::invalid_argument("invalid character"); + } auto result = ((utf8[offset + 0] & 0x1f) << 6) | (utf8[offset + 1] & 0x3f); offset += 2; return result; } - else if (!(utf8[offset + 0] & 0x10)) { - if (offset + 2 >= utf8.size() || ! ((utf8[offset + 1] & 0xc0) == 0x80) || ! ((utf8[offset + 2] & 0xc0) == 0x80)) + if (!(utf8[offset + 0] & 0x10)) { + if (offset + 2 >= utf8.size() || ! ((utf8[offset + 1] & 0xc0) == 0x80) || ! ((utf8[offset + 2] & 0xc0) == 0x80)) { throw std::invalid_argument("invalid character"); + } auto result = ((utf8[offset + 0] & 0x0f) << 12) | ((utf8[offset + 1] & 0x3f) << 6) | (utf8[offset + 2] & 0x3f); offset += 3; return result; } - else if (!(utf8[offset + 0] & 0x08)) { - if (offset + 3 >= utf8.size() || ! ((utf8[offset + 1] & 0xc0) == 0x80) || ! ((utf8[offset + 2] & 0xc0) == 0x80) || !((utf8[offset + 3] & 0xc0) == 0x80)) + if (!(utf8[offset + 0] & 0x08)) { + if (offset + 3 >= utf8.size() || ! ((utf8[offset + 1] & 0xc0) == 0x80) || ! ((utf8[offset + 2] & 0xc0) == 0x80) || !((utf8[offset + 3] & 0xc0) == 0x80)) { throw std::invalid_argument("invalid character"); + } auto result = ((utf8[offset + 0] & 0x07) << 18) | ((utf8[offset + 1] & 0x3f) << 12) | ((utf8[offset + 2] & 0x3f) << 6) | (utf8[offset + 3] & 0x3f); offset += 4; return result; @@ -331,21 +334,22 @@ static uint32_t codepoint_from_utf16(const std::vector<uint16_t> & utf16, size_t offset += 1; return result; } - else { - if (offset + 1 >= utf16.size() || !((utf16[1] & 0xdc00) == 0xdc00)) - throw std::invalid_argument("invalid character"); - auto result = 0x10000 + (((utf16[0] & 0x03ff) << 10) | (utf16[1] & 0x03ff)); - offset += 2; - return result; + + if (offset + 1 >= utf16.size() || !((utf16[1] & 0xdc00) == 0xdc00)) { + throw std::invalid_argument("invalid character"); } - throw std::invalid_argument("invalid string"); + + auto result = 0x10000 + (((utf16[0] & 0x03ff) << 10) | (utf16[1] & 0x03ff)); + offset += 2; + return result; } static std::vector<uint32_t> codepoints_from_utf16(const std::vector<uint16_t> & utf16) { std::vector<uint32_t> result; size_t offset = 0; - while (offset < utf16.size()) + while (offset < utf16.size()) { result.push_back(codepoint_from_utf16(utf16, offset)); + } return result; } @@ -361,44 +365,52 @@ static std::vector<uint32_t> codepoints_from_utf16(const std::vector<uint16_t> & static std::unordered_map<uint32_t, int> codepoint_type_map() { std::unordered_map<uint32_t, int> codepoint_types; for (auto p : digit_ranges) { - for(auto i = p.first; i <= p.second; ++ i) + for (auto i = p.first; i <= p.second; ++ i) { codepoint_types[i] = CODEPOINT_TYPE_DIGIT; + } } - for(auto p : letter_ranges) { - for(auto i = p.first; i <= p.second; ++ i) + for (auto p : letter_ranges) { + for (auto i = p.first; i <= p.second; ++ i) { codepoint_types[i] = CODEPOINT_TYPE_LETTER; + } } - for(auto p : whitespace_ranges) { - for(auto i = p.first; i <= p.second; ++ i) + for (auto p : whitespace_ranges) { + for (auto i = p.first; i <= p.second; ++ i) { codepoint_types[i] = CODEPOINT_TYPE_WHITESPACE; + } } - for(auto p : accent_mark_ranges) { - for(auto i = p.first; i <= p.second; ++ i) + for (auto p : accent_mark_ranges) { + for (auto i = p.first; i <= p.second; ++ i) { codepoint_types[i] = CODEPOINT_TYPE_ACCENT_MARK; + } } - for(auto p : punctuation_ranges) { - for(auto i = p.first; i <= p.second; ++ i) + for (auto p : punctuation_ranges) { + for (auto i = p.first; i <= p.second; ++ i) { codepoint_types[i] = CODEPOINT_TYPE_PUNCTUATION; + } } - for (auto p : symbol_ranges) { - for (auto i = p.first; i <= p.second; ++i) + for (auto p : symbol_ranges) { + for (auto i = p.first; i <= p.second; ++i) { codepoint_types[i] = CODEPOINT_TYPE_SYMBOL; + } } - for(auto p : control_ranges) { - for(auto i = p.first; i <= p.second; ++ i) + for (auto p : control_ranges) { + for (auto i = p.first; i <= p.second; ++ i) { codepoint_types[i] = CODEPOINT_TYPE_CONTROL; + } } return codepoint_types; } static int codepoint_type(uint32_t cp) { static std::unordered_map<uint32_t, int> codepoint_types = codepoint_type_map(); - return codepoint_types[cp]; + return codepoint_types.find(cp) == codepoint_types.end() ? CODEPOINT_TYPE_UNIDENTIFIED : codepoint_types.at(cp); } static int codepoint_type(const std::string & utf8) { - if (utf8.length() == 0) + if (utf8.length() == 0) { return CODEPOINT_TYPE_UNIDENTIFIED; + } size_t offset = 0; return codepoint_type(codepoint_from_utf8(utf8, offset)); } |