summaryrefslogtreecommitdiff
path: root/unicode.h
diff options
context:
space:
mode:
authorGeorgi Gerganov <ggerganov@gmail.com>2024-02-13 15:14:22 +0200
committerGitHub <noreply@github.com>2024-02-13 15:14:22 +0200
commitcf45252a7cfcb998bade46a886e20477cecc538a (patch)
tree89400db2102fb39809b61af13a7534e49ec5c27e /unicode.h
parent03bf161eb6dea6400ee49c6dc6b69bdcfa9fd3fc (diff)
tests : multi-thread the tokenizer tests (#5474)
* tests : multi-thread the tokenizer tests ggml-ci * unicode : fix data race for unidentified codepoints ggml-ci * unicode : minor style fixes ggml-ci
Diffstat (limited to 'unicode.h')
-rw-r--r--unicode.h72
1 files changed, 42 insertions, 30 deletions
diff --git a/unicode.h b/unicode.h
index 844eff3d..26326070 100644
--- a/unicode.h
+++ b/unicode.h
@@ -264,26 +264,29 @@ static uint32_t codepoint_from_utf8(const std::string & utf8, size_t & offset) {
offset += 1;
return result;
}
- else if (!(utf8[offset + 0] & 0x40)) {
+ if (!(utf8[offset + 0] & 0x40)) {
throw std::invalid_argument("invalid character");
}
- else if (!(utf8[offset + 0] & 0x20)) {
- if (offset + 1 >= utf8.size() || ! ((utf8[offset + 1] & 0xc0) == 0x80))
+ if (!(utf8[offset + 0] & 0x20)) {
+ if (offset + 1 >= utf8.size() || ! ((utf8[offset + 1] & 0xc0) == 0x80)) {
throw std::invalid_argument("invalid character");
+ }
auto result = ((utf8[offset + 0] & 0x1f) << 6) | (utf8[offset + 1] & 0x3f);
offset += 2;
return result;
}
- else if (!(utf8[offset + 0] & 0x10)) {
- if (offset + 2 >= utf8.size() || ! ((utf8[offset + 1] & 0xc0) == 0x80) || ! ((utf8[offset + 2] & 0xc0) == 0x80))
+ if (!(utf8[offset + 0] & 0x10)) {
+ if (offset + 2 >= utf8.size() || ! ((utf8[offset + 1] & 0xc0) == 0x80) || ! ((utf8[offset + 2] & 0xc0) == 0x80)) {
throw std::invalid_argument("invalid character");
+ }
auto result = ((utf8[offset + 0] & 0x0f) << 12) | ((utf8[offset + 1] & 0x3f) << 6) | (utf8[offset + 2] & 0x3f);
offset += 3;
return result;
}
- else if (!(utf8[offset + 0] & 0x08)) {
- if (offset + 3 >= utf8.size() || ! ((utf8[offset + 1] & 0xc0) == 0x80) || ! ((utf8[offset + 2] & 0xc0) == 0x80) || !((utf8[offset + 3] & 0xc0) == 0x80))
+ if (!(utf8[offset + 0] & 0x08)) {
+ if (offset + 3 >= utf8.size() || ! ((utf8[offset + 1] & 0xc0) == 0x80) || ! ((utf8[offset + 2] & 0xc0) == 0x80) || !((utf8[offset + 3] & 0xc0) == 0x80)) {
throw std::invalid_argument("invalid character");
+ }
auto result = ((utf8[offset + 0] & 0x07) << 18) | ((utf8[offset + 1] & 0x3f) << 12) | ((utf8[offset + 2] & 0x3f) << 6) | (utf8[offset + 3] & 0x3f);
offset += 4;
return result;
@@ -331,21 +334,22 @@ static uint32_t codepoint_from_utf16(const std::vector<uint16_t> & utf16, size_t
offset += 1;
return result;
}
- else {
- if (offset + 1 >= utf16.size() || !((utf16[1] & 0xdc00) == 0xdc00))
- throw std::invalid_argument("invalid character");
- auto result = 0x10000 + (((utf16[0] & 0x03ff) << 10) | (utf16[1] & 0x03ff));
- offset += 2;
- return result;
+
+ if (offset + 1 >= utf16.size() || !((utf16[1] & 0xdc00) == 0xdc00)) {
+ throw std::invalid_argument("invalid character");
}
- throw std::invalid_argument("invalid string");
+
+ auto result = 0x10000 + (((utf16[0] & 0x03ff) << 10) | (utf16[1] & 0x03ff));
+ offset += 2;
+ return result;
}
static std::vector<uint32_t> codepoints_from_utf16(const std::vector<uint16_t> & utf16) {
std::vector<uint32_t> result;
size_t offset = 0;
- while (offset < utf16.size())
+ while (offset < utf16.size()) {
result.push_back(codepoint_from_utf16(utf16, offset));
+ }
return result;
}
@@ -361,44 +365,52 @@ static std::vector<uint32_t> codepoints_from_utf16(const std::vector<uint16_t> &
static std::unordered_map<uint32_t, int> codepoint_type_map() {
std::unordered_map<uint32_t, int> codepoint_types;
for (auto p : digit_ranges) {
- for(auto i = p.first; i <= p.second; ++ i)
+ for (auto i = p.first; i <= p.second; ++ i) {
codepoint_types[i] = CODEPOINT_TYPE_DIGIT;
+ }
}
- for(auto p : letter_ranges) {
- for(auto i = p.first; i <= p.second; ++ i)
+ for (auto p : letter_ranges) {
+ for (auto i = p.first; i <= p.second; ++ i) {
codepoint_types[i] = CODEPOINT_TYPE_LETTER;
+ }
}
- for(auto p : whitespace_ranges) {
- for(auto i = p.first; i <= p.second; ++ i)
+ for (auto p : whitespace_ranges) {
+ for (auto i = p.first; i <= p.second; ++ i) {
codepoint_types[i] = CODEPOINT_TYPE_WHITESPACE;
+ }
}
- for(auto p : accent_mark_ranges) {
- for(auto i = p.first; i <= p.second; ++ i)
+ for (auto p : accent_mark_ranges) {
+ for (auto i = p.first; i <= p.second; ++ i) {
codepoint_types[i] = CODEPOINT_TYPE_ACCENT_MARK;
+ }
}
- for(auto p : punctuation_ranges) {
- for(auto i = p.first; i <= p.second; ++ i)
+ for (auto p : punctuation_ranges) {
+ for (auto i = p.first; i <= p.second; ++ i) {
codepoint_types[i] = CODEPOINT_TYPE_PUNCTUATION;
+ }
}
- for (auto p : symbol_ranges) {
- for (auto i = p.first; i <= p.second; ++i)
+ for (auto p : symbol_ranges) {
+ for (auto i = p.first; i <= p.second; ++i) {
codepoint_types[i] = CODEPOINT_TYPE_SYMBOL;
+ }
}
- for(auto p : control_ranges) {
- for(auto i = p.first; i <= p.second; ++ i)
+ for (auto p : control_ranges) {
+ for (auto i = p.first; i <= p.second; ++ i) {
codepoint_types[i] = CODEPOINT_TYPE_CONTROL;
+ }
}
return codepoint_types;
}
static int codepoint_type(uint32_t cp) {
static std::unordered_map<uint32_t, int> codepoint_types = codepoint_type_map();
- return codepoint_types[cp];
+ return codepoint_types.find(cp) == codepoint_types.end() ? CODEPOINT_TYPE_UNIDENTIFIED : codepoint_types.at(cp);
}
static int codepoint_type(const std::string & utf8) {
- if (utf8.length() == 0)
+ if (utf8.length() == 0) {
return CODEPOINT_TYPE_UNIDENTIFIED;
+ }
size_t offset = 0;
return codepoint_type(codepoint_from_utf8(utf8, offset));
}