 llama.cpp | 26 ++++++++++++++++++++--------
 1 file changed, 18 insertions(+), 8 deletions(-)
diff --git a/llama.cpp b/llama.cpp
index 27545608..2190ea7a 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -10305,6 +10305,8 @@ int32_t llama_token_to_piece(const struct llama_model * model, llama_token token
if (0 <= token && token < llama_n_vocab(model)) {
switch (llama_vocab_get_type(model->vocab)) {
case LLAMA_VOCAB_TYPE_SPM: {
+ // NOTE: we accept all unsupported token types,
+ // suppressing them like CONTROL tokens.
if (llama_is_normal_token(model->vocab, token)) {
std::string result = model->vocab.id_to_token[token].text;
llama_unescape_whitespace(result);
@@ -10313,6 +10315,13 @@ int32_t llama_token_to_piece(const struct llama_model * model, llama_token token
}
memcpy(buf, result.c_str(), result.length());
return result.length();
+ } else if (llama_is_user_defined_token(model->vocab, token)) {
+ std::string result = model->vocab.id_to_token[token].text;
+ if (length < (int) result.length()) {
+ return -result.length();
+ }
+ memcpy(buf, result.c_str(), result.length());
+ return result.length();
} else if (llama_is_unknown_token(model->vocab, token)) { // NOLINT
if (length < 3) {
return -3;
@@ -10327,14 +10336,12 @@ int32_t llama_token_to_piece(const struct llama_model * model, llama_token token
}
buf[0] = llama_token_to_byte(model->vocab, token);
return 1;
- } else {
- // TODO: for now we accept all unsupported token types,
- // suppressing them like CONTROL tokens.
- // GGML_ASSERT(false);
}
break;
}
case LLAMA_VOCAB_TYPE_BPE: {
+ // NOTE: we accept all unsupported token types,
+ // suppressing them like CONTROL tokens.
if (llama_is_normal_token(model->vocab, token)) {
std::string result = model->vocab.id_to_token[token].text;
result = llama_decode_text(result);
@@ -10343,12 +10350,15 @@ int32_t llama_token_to_piece(const struct llama_model * model, llama_token token
}
memcpy(buf, result.c_str(), result.length());
return result.length();
+ } else if (llama_is_user_defined_token(model->vocab, token)) {
+ std::string result = model->vocab.id_to_token[token].text;
+ if (length < (int) result.length()) {
+ return -result.length();
+ }
+ memcpy(buf, result.c_str(), result.length());
+ return result.length();
} else if (llama_is_control_token(model->vocab, token)) {
;
- } else {
- // TODO: for now we accept all unsupported token types,
- // suppressing them like CONTROL tokens.
- // GGML_ASSERT(false);
}
break;
}
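
The caller-facing contract is unchanged by this patch: when the destination buffer is too small, llama_token_to_piece returns the negated length the piece requires. A minimal caller-side sketch of that retry convention follows; the helper name and initial buffer size are illustrative, not part of the commit.

#include <algorithm>
#include <string>
#include <vector>

#include "llama.h"

// Illustrative helper (not part of this commit): detokenize one token,
// retrying once when the first buffer is too small. A negative return
// from llama_token_to_piece is the negated length the piece needs.
static std::string token_to_string(const struct llama_model * model, llama_token token) {
    std::vector<char> buf(8);
    int32_t n = llama_token_to_piece(model, token, buf.data(), (int32_t) buf.size());
    if (n < 0) {
        buf.resize(-n);                   // -n is the required length
        n = llama_token_to_piece(model, token, buf.data(), (int32_t) buf.size());
    }
    return std::string(buf.data(), std::max<int32_t>(n, 0));
}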