diff options
Diffstat (limited to 'llama.cpp')
-rw-r--r-- | llama.cpp | 25 |
1 files changed, 13 insertions, 12 deletions
@@ -1600,12 +1600,12 @@ struct llama_mlock { }; using llama_mlocks = std::vector<std::unique_ptr<llama_mlock>>; -static std::string llama_token_to_piece(const struct llama_context * ctx, llama_token token) { +static std::string llama_token_to_piece(const struct llama_context * ctx, llama_token token, bool special) { std::vector<char> result(8, 0); - const int n_tokens = llama_token_to_piece(llama_get_model(ctx), token, result.data(), result.size()); + const int n_tokens = llama_token_to_piece(llama_get_model(ctx), token, result.data(), result.size(), special); if (n_tokens < 0) { result.resize(-n_tokens); - int check = llama_token_to_piece(llama_get_model(ctx), token, result.data(), result.size()); + int check = llama_token_to_piece(llama_get_model(ctx), token, result.data(), result.size(), special); GGML_ASSERT(check == -n_tokens); } else { @@ -13312,7 +13312,8 @@ void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * c for (size_t i = 0; i < candidates->size; ++i) { const llama_token id = candidates->data[i].id; - const std::string piece = llama_token_to_piece(ctx, id); + const std::string piece = llama_token_to_piece(ctx, id, false); + if (llama_token_is_eog(&ctx->model, id)) { if (!allow_eog) { candidates->data[i].logit = -INFINITY; @@ -13512,7 +13513,7 @@ void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar GGML_ASSERT(false); } - const std::string piece = llama_token_to_piece(ctx, token); + const std::string piece = llama_token_to_piece(ctx, token, false); // Note terminating 0 in decoded string const auto decoded = decode_utf8(piece, grammar->partial_utf8); @@ -16991,7 +16992,7 @@ static std::string llama_decode_text(const std::string & text) { } // does not write null-terminator to buf -int32_t llama_token_to_piece(const struct llama_model * model, llama_token token, char * buf, int32_t length) { +int32_t llama_token_to_piece(const struct llama_model * model, llama_token token, char * buf, int32_t length, bool special) { if (0 <= token && token < llama_n_vocab(model)) { switch (llama_vocab_get_type(model->vocab)) { case LLAMA_VOCAB_TYPE_WPM: @@ -17006,7 +17007,9 @@ int32_t llama_token_to_piece(const struct llama_model * model, llama_token token } memcpy(buf, result.c_str(), result.length()); return result.length(); - } else if (llama_is_user_defined_token(model->vocab, token)) { + } else if ( + (llama_is_user_defined_token(model->vocab, token)) || + (llama_is_control_token (model->vocab, token) && special)) { std::string result = model->vocab.id_to_token[token].text; if (length < (int) result.length()) { return -(int) result.length(); @@ -17019,8 +17022,6 @@ int32_t llama_token_to_piece(const struct llama_model * model, llama_token token } memcpy(buf, "\xe2\x96\x85", 3); return 3; - } else if (llama_is_control_token(model->vocab, token)) { - ; } else if (llama_is_byte_token(model->vocab, token)) { if (length < 1) { return -1; @@ -17041,15 +17042,15 @@ int32_t llama_token_to_piece(const struct llama_model * model, llama_token token } memcpy(buf, result.c_str(), result.length()); return result.length(); - } else if (llama_is_user_defined_token(model->vocab, token)) { + } else if ( + (llama_is_user_defined_token(model->vocab, token)) || + (llama_is_control_token (model->vocab, token) && special)) { std::string result = model->vocab.id_to_token[token].text; if (length < (int) result.length()) { return -(int) result.length(); } memcpy(buf, result.c_str(), result.length()); return result.length(); - } else if (llama_is_control_token(model->vocab, token)) { - ; } break; } |