author    Jhen-Jie Hong <iainst0409@gmail.com>    2023-09-02 08:31:46 +0800
committer GitHub <noreply@github.com>             2023-09-02 08:31:46 +0800
commit    571083f508266c4eb5cb5457d836df5dd3c173ce (patch)
tree      91f1a6b16ba767062bf472e07b0507dac49e51d0 /examples/server/server.cpp
parent    f04d0028444bc9b3d4225fba47e19d4c3aeb3741 (diff)
server : avoid antiprompt in probabilities of final response (#2849)
Diffstat (limited to 'examples/server/server.cpp')
-rw-r--r--    examples/server/server.cpp    14    ++++++++++++--
1 file changed, 12 insertions(+), 2 deletions(-)
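
Editorial note on the change below: when generation stopped on a stopping word, the patch tokenizes that word and drops the same number of trailing entries from the token-probability vector, so the stop word itself never shows up in the probabilities of the final response. A minimal standalone sketch of that trimming idea, assuming a hypothetical token_prob struct in place of llama.cpp's completion_token_output:

#include <cstdio>
#include <string>
#include <vector>

struct token_prob {       // hypothetical stand-in for completion_token_output
    std::string text;     // token text
    float       prob;     // probability the model assigned to the token
};

// Drop the trailing entries that correspond to the stop word's tokens,
// mirroring the blocking-mode branch of the patch. The count is clamped
// here so the iterator arithmetic cannot underflow; the patch itself
// subtracts stop_word_toks.size() from end() directly.
static std::vector<token_prob> drop_stop_word_tail(
        const std::vector<token_prob> & probs, size_t n_stop_tokens) {
    const size_t keep = n_stop_tokens < probs.size() ? probs.size() - n_stop_tokens : 0;
    return std::vector<token_prob>(probs.begin(), probs.begin() + keep);
}

int main() {
    const std::vector<token_prob> probs = {
        {"Hello", 0.90f}, {" world", 0.80f}, {"</s>", 0.99f},
    };
    // Pretend the stopping word tokenized to one token: "</s>" is dropped.
    for (const auto & p : drop_stop_word_tail(probs, 1)) {
        std::printf("%s: %.2f\n", p.text.c_str(), p.prob);
    }
    return 0;
}

Unlike the one-liner in the patch, the sketch clamps the count, so it stays safe even if the stopping word tokenizes to more tokens than were generated.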
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 09eac2ec..94def943 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -1379,7 +1379,13 @@ int main(int argc, char **argv)
                 }
             }

-            const json data = format_final_response(llama, llama.generated_text, llama.generated_token_probs);
+            auto probs = llama.generated_token_probs;
+            if (llama.params.n_probs > 0 && llama.stopped_word) {
+                const std::vector<llama_token> stop_word_toks = llama_tokenize(llama.ctx, llama.stopping_word, false);
+                probs = std::vector<completion_token_output>(llama.generated_token_probs.begin(), llama.generated_token_probs.end() - stop_word_toks.size());
+            }
+
+            const json data = format_final_response(llama, llama.generated_text, probs);

             llama_print_timings(llama.ctx);
@@ -1456,7 +1462,11 @@ int main(int argc, char **argv)
                 if (!llama.has_next_token) {
                     // Generation is done, send extra information.
-                    const json data = format_final_response(llama, "", llama.generated_token_probs);
+                    const json data = format_final_response(
+                        llama,
+                        "",
+                        std::vector<completion_token_output>(llama.generated_token_probs.begin(), llama.generated_token_probs.begin() + sent_token_probs_index)
+                    );

                     const std::string str =
                         "data: " +