From 29674ab4e847fcaba60cc6558f0d46d5f74ae279 Mon Sep 17 00:00:00 2001
From: Jhen-Jie Hong
Date: Fri, 25 Aug 2023 18:32:45 +0800
Subject: server : display token probabilities in the UI (#2489)

* server : add n_probs param in chat UI
* server : keep message data array & show in probabilities component
* server : add simple popover component
* server : fix completion_probabilities undefined if not set n_probs
* server : implement Probabilities
* server : handle bytes
* server : make n_probs max to 10 for easy scroll
* server : adjust for dark/light mode
* server : Fix regenerated prompt
* server : update index.html.hpp
* server : convert prob to percentage + show original value as div title
* server : fix Probabilities not used if included empty str
* server : skip byte pair in display probabilities
* server : remove array check of completion_probabilities in messages
* skip empty array or byte pair (> 1) in Probabilities
* generate index.html.hpp
* fix incorrect prob convert if the str is already a known token
* use final response to show probabilities on stop
* revert unnecessary change
* correct probabilities usage
* remove unused function
* always send partial response to get correct probs of last to_send
* fix typo
* fix content of format_final_response
* refactor probs render & make pColor transparent if not found
* send empty string when got stop_pos in partial
* avoid unnecessary empty data event & send rest of partial tokens on stop
* use <br/> for new line
* skip -1 tok in loop to avoid send '' on end
* trim last new lines on stop
* revert unnecessary change
---
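Note (not part of the original patch): the first hunk below tightens the partial-UTF-8 check in
tokens_to_output_formatted_string(). A one-byte string whose high bit is set is an incomplete
multi-byte character and gets rendered as a hex escape, while a longer string is already a
complete, known token. A minimal standalone sketch of that check follows; format_token_text()
is a hypothetical name used only for this illustration, not a function from the patch.

    // Standalone sketch of the partial-byte check (assumed names, not from the patch).
    #include <iostream>
    #include <sstream>
    #include <string>

    static std::string format_token_text(const std::string & out) {
        // size == 1 and first bit set -> partial multi-byte character, show as hex
        if (out.size() == 1 && (out[0] & 0x80) == 0x80) {
            std::stringstream ss;
            ss << std::hex << (out[0] & 0xff);
            return "byte: \\x" + ss.str();
        }
        return out; // size > 1 means the string is already a complete, known token
    }

    int main() {
        std::cout << format_token_text("\xe6") << "\n";   // -> byte: \xe6
        std::cout << format_token_text("hello") << "\n";  // -> hello
        return 0;
    }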
 examples/server/server.cpp | 86 ++++++++++++++++++++++++++++++----------------
 1 file changed, 57 insertions(+), 29 deletions(-)

(limited to 'examples/server/server.cpp')

diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 1e6d10c1..025b385c 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -124,8 +124,9 @@ static void server_log(const char *level, const char *function, int line,
 static std::string tokens_to_output_formatted_string(const llama_context *ctx, const llama_token token)
 {
     std::string out = token == -1 ? "" : llama_token_to_str(ctx, token);
-    // if first bit is 1, meaning it's a partial character
-    if (out.size() > 0 && (out[0] & 0x80) == 0x80)
+    // if the size is 1 and first bit is 1, meaning it's a partial character
+    // (size > 1 meaning it's already a known token)
+    if (out.size() == 1 && (out[0] & 0x80) == 0x80)
     {
         std::stringstream ss;
         ss << std::hex << (out[0] & 0xff);
@@ -1321,59 +1322,86 @@ int main(int argc, char **argv)
             while (llama.has_next_token) {
                 const completion_token_output token_with_probs = llama.doCompletion();
-                const std::string token_text = token_with_probs.tok == -1 ? "" : llama_token_to_str(llama.ctx, token_with_probs.tok);
-                if (llama.multibyte_pending > 0) {
+                if (token_with_probs.tok == -1 || llama.multibyte_pending > 0) {
                     continue;
                 }
+                const std::string token_text = llama_token_to_str(llama.ctx, token_with_probs.tok);

                 size_t pos = std::min(sent_count, llama.generated_text.size());

                 const std::string str_test = llama.generated_text.substr(pos);
+                bool is_stop_full = false;
                 size_t stop_pos = llama.findStoppingStrings(str_test, token_text.size(), STOP_FULL);
                 if (stop_pos != std::string::npos) {
+                    is_stop_full = true;
                     llama.generated_text.erase(
                         llama.generated_text.begin() + pos + stop_pos,
                         llama.generated_text.end());
                     pos = std::min(sent_count, llama.generated_text.size());
                 } else {
+                    is_stop_full = false;
                     stop_pos = llama.findStoppingStrings(str_test, token_text.size(), STOP_PARTIAL);
                 }

-                const std::string to_send = llama.generated_text.substr(pos, stop_pos);
-                sent_count += to_send.size();
+                if (
+                    stop_pos == std::string::npos ||
+                    // Send rest of the text if we are at the end of the generation
+                    (!llama.has_next_token && !is_stop_full && stop_pos > 0)
+                ) {
+                    const std::string to_send = llama.generated_text.substr(pos, std::string::npos);
+
+                    sent_count += to_send.size();
+
+                    std::vector<completion_token_output> probs_output = {};
+
+                    if (llama.params.n_probs > 0) {
+                        const std::vector<llama_token> to_send_toks = llama_tokenize(llama.ctx, to_send, false);
+                        size_t probs_pos = std::min(sent_token_probs_index, llama.generated_token_probs.size());
+                        size_t probs_stop_pos = std::min(sent_token_probs_index + to_send_toks.size(), llama.generated_token_probs.size());
+                        if (probs_pos < probs_stop_pos) {
+                            probs_output = std::vector<completion_token_output>(llama.generated_token_probs.begin() + probs_pos, llama.generated_token_probs.begin() + probs_stop_pos);
+                        }
+                        sent_token_probs_index = probs_stop_pos;
+                    }
+
+                    const json data = format_partial_response(llama, to_send, probs_output);
+
+                    const std::string str =
+                        "data: " +
+                        data.dump(-1, ' ', false, json::error_handler_t::replace) +
+                        "\n\n";

-                std::vector<completion_token_output> probs_output = {};
+                    LOG_VERBOSE("data stream", {
+                        { "to_send", str }
+                    });

-                if (llama.params.n_probs > 0) {
-                    const std::vector<llama_token> to_send_toks = llama_tokenize(llama.ctx, to_send, false);
-                    size_t probs_pos = std::min(sent_token_probs_index, llama.generated_token_probs.size());
-                    size_t probs_stop_pos = std::min(sent_token_probs_index + to_send_toks.size(), llama.generated_token_probs.size());
-                    if (probs_pos < probs_stop_pos) {
-                        probs_output = std::vector<completion_token_output>(llama.generated_token_probs.begin() + probs_pos, llama.generated_token_probs.begin() + probs_stop_pos);
+                    if (!sink.write(str.data(), str.size())) {
+                        LOG_VERBOSE("stream closed", {});
+                        llama_print_timings(llama.ctx);
+                        return false;
                     }
-                    sent_token_probs_index = probs_stop_pos;
                 }

-                const json data = llama.has_next_token
-                                      ? format_partial_response(llama, to_send, probs_output)
-                                      // Generation is done, send extra information.
-                                      : format_final_response(llama, to_send, llama.generated_token_probs);
+                if (!llama.has_next_token) {
+                    // Generation is done, send extra information.
+                    const json data = format_final_response(llama, "", llama.generated_token_probs);

-                const std::string str =
-                    "data: " +
-                    data.dump(-1, ' ', false, json::error_handler_t::replace) +
-                    "\n\n";
+                    const std::string str =
+                        "data: " +
+                        data.dump(-1, ' ', false, json::error_handler_t::replace) +
+                        "\n\n";

-                LOG_VERBOSE("data stream", {
-                    { "to_send", str }
-                });
+                    LOG_VERBOSE("data stream", {
+                        { "to_send", str }
+                    });

-                if (!sink.write(str.data(), str.size())) {
-                    LOG_VERBOSE("stream closed", {});
-                    llama_print_timings(llama.ctx);
-                    return false;
+                    if (!sink.write(str.data(), str.size())) {
+                        LOG_VERBOSE("stream closed", {});
+                        llama_print_timings(llama.ctx);
+                        return false;
+                    }
                 }
             }
--
cgit v1.2.3
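Note (not part of the original patch): the second hunk streams, with each partial response, only
the slice of llama.generated_token_probs that corresponds to the text chunk being sent, and
tracks its progress in sent_token_probs_index. A minimal, self-contained sketch of that window
bookkeeping follows; token_prob_stub and n_to_send_toks are hypothetical names standing in for
completion_token_output and the tokenized length of to_send.

    // Sketch of the probability-window bookkeeping (assumed names, not from the patch).
    #include <algorithm>
    #include <cstddef>
    #include <iostream>
    #include <string>
    #include <vector>

    struct token_prob_stub {      // stand-in for the server's completion_token_output
        std::string text;
        float       prob;
    };

    int main() {
        // Pretend these per-token probabilities were accumulated during generation.
        std::vector<token_prob_stub> generated = {
            {"Hello", 0.91f}, {",", 0.72f}, {" world", 0.88f}, {"!", 0.64f},
        };

        size_t sent_token_probs_index = 0;  // how many records were already streamed

        // Assume the chunk being flushed this iteration covers two tokens.
        const size_t n_to_send_toks = 2;

        // Same windowing as the patch: clamp both ends to the data we actually have.
        const size_t probs_pos      = std::min(sent_token_probs_index, generated.size());
        const size_t probs_stop_pos = std::min(sent_token_probs_index + n_to_send_toks, generated.size());

        std::vector<token_prob_stub> probs_output;
        if (probs_pos < probs_stop_pos) {
            probs_output.assign(generated.begin() + probs_pos, generated.begin() + probs_stop_pos);
        }
        sent_token_probs_index = probs_stop_pos;  // next chunk starts after this one

        for (const auto & p : probs_output) {
            std::cout << p.text << " -> " << p.prob << "\n";
        }
        return 0;
    }

Clamping both ends of the window with std::min keeps the slice valid even when fewer probability
records exist than tokens in the chunk, which is what lets the final flush on stop reuse the same
path without sending duplicate or out-of-range entries.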