Diffstat (limited to 'examples/server/server.cpp')
 examples/server/server.cpp | 73 ++++++++++++++++++++-----------------------------------------------------
 1 file changed, 20 insertions(+), 53 deletions(-)
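
The refactor below folds the server's hand-rolled sampling state (the grammar object, the last_n_tokens ring buffer, and the per-call candidates vector) into the llama_sampling_context of the common sampling API. For orientation, a condensed before/after of the sampling call site; every identifier here is lifted from the hunks that follow, only the juxtaposition is editorial:

    // before: the caller owns the grammar, the token history and the candidate buffer
    std::vector<llama_token_data> candidates;
    candidates.reserve(llama_n_vocab(model));
    result.tok = llama_sampling_sample(ctx, NULL, ctx_sampling, last_n_tokens, candidates);

    // after: that state lives inside llama_sampling_context
    result.tok = llama_sampling_sample(ctx_sampling, ctx, NULL);
    llama_sampling_accept(ctx_sampling, ctx, result.tok);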
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index ee0ababb..28b3f3f5 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -1,7 +1,6 @@
 #include "common.h"
 #include "llama.h"
 #include "build-info.h"
-#include "grammar-parser.h"

 #ifndef NDEBUG
 // crash the server in debug mode, otherwise send an http 500 error
@@ -195,17 +194,13 @@ struct llama_server_context
     json prompt;
     std::vector<llama_token> embd;
-    std::vector<llama_token> last_n_tokens;

     llama_model *model = nullptr;
     llama_context *ctx = nullptr;
     gpt_params params;
-    llama_sampling_context ctx_sampling;
+    llama_sampling_context *ctx_sampling;

     int n_ctx;

-    grammar_parser::parse_state parsed_grammar;
-    llama_grammar *grammar = nullptr;
-
     bool truncated = false;
     bool stopped_eos = false;
     bool stopped_word = false;
@@ -252,11 +247,10 @@ struct llama_server_context
         n_remain = 0;
         n_past = 0;

-        if (grammar != nullptr) {
-            llama_grammar_free(grammar);
-            grammar = nullptr;
-            ctx_sampling = llama_sampling_context_init(params, NULL);
+        if (ctx_sampling != nullptr) {
+            llama_sampling_free(ctx_sampling);
         }
+        ctx_sampling = llama_sampling_init(params);
     }

     bool loadModel(const gpt_params &params_)
@@ -269,8 +263,6 @@ struct llama_server_context
             return false;
         }

         n_ctx = llama_n_ctx(ctx);
-        last_n_tokens.resize(n_ctx);
-        std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0);
         return true;
     }
@@ -321,27 +313,7 @@ struct llama_server_context
     bool loadGrammar()
     {
-        if (!params.grammar.empty()) {
-            parsed_grammar = grammar_parser::parse(params.grammar.c_str());
-            // will be empty (default) if there are parse errors
-            if (parsed_grammar.rules.empty()) {
-                LOG_ERROR("grammar parse error", {{"grammar", params.grammar}});
-                return false;
-            }
-            grammar_parser::print_grammar(stderr, parsed_grammar);
-
-            {
-                auto it = params.sampling_params.logit_bias.find(llama_token_eos(ctx));
-                if (it != params.sampling_params.logit_bias.end() && it->second == -INFINITY) {
-                    LOG_WARNING("EOS token is disabled, which will cause most grammars to fail", {});
-                }
-            }
-
-            std::vector<const llama_grammar_element *> grammar_rules(parsed_grammar.c_rules());
-            grammar = llama_grammar_init(
-                grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
-        }
-        ctx_sampling = llama_sampling_context_init(params, grammar);
+        ctx_sampling = llama_sampling_init(params);
         return true;
     }
@@ -383,7 +355,7 @@ struct llama_server_context
             std::vector<llama_token> new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + params.n_keep);
             const int erased_blocks = (num_prompt_tokens - params.n_keep - n_left - 1) / n_left;
             new_tokens.insert(new_tokens.end(), prompt_tokens.begin() + params.n_keep + erased_blocks * n_left, prompt_tokens.end());
-            std::copy(prompt_tokens.end() - params.n_ctx, prompt_tokens.end(), last_n_tokens.begin());
+            std::copy(prompt_tokens.end() - params.n_ctx, prompt_tokens.end(), ctx_sampling->prev.begin());

             LOG_VERBOSE("input truncated", {
                                                {"n_ctx", params.n_ctx},
@@ -398,8 +370,8 @@ struct llama_server_context
         else
         {
             const size_t ps = num_prompt_tokens;
-            std::fill(last_n_tokens.begin(), last_n_tokens.end() - ps, 0);
-            std::copy(prompt_tokens.begin(), prompt_tokens.end(), last_n_tokens.end() - ps);
+            std::fill(ctx_sampling->prev.begin(), ctx_sampling->prev.end() - ps, 0);
+            std::copy(prompt_tokens.begin(), prompt_tokens.end(), ctx_sampling->prev.end() - ps);
         }

         // compare the evaluated prompt with the new prompt
@@ -443,7 +415,7 @@ struct llama_server_context
             std::vector<llama_token> new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + params.n_keep);
             const int erased_blocks = (num_prompt_tokens - params.n_keep - n_left - 1) / n_left;
             new_tokens.insert(new_tokens.end(), prompt_tokens.begin() + params.n_keep + erased_blocks * n_left, prompt_tokens.end());
-            std::copy(prompt_tokens.end() - n_ctx, prompt_tokens.end(), last_n_tokens.begin());
+            std::copy(prompt_tokens.end() - n_ctx, prompt_tokens.end(), ctx_sampling->prev.begin());

             LOG_VERBOSE("input truncated", {
                                                {"n_ctx", n_ctx},
@@ -458,8 +430,8 @@ struct llama_server_context
         else
        {
             const size_t ps = num_prompt_tokens;
-            std::fill(last_n_tokens.begin(), last_n_tokens.end() - ps, 0);
-            std::copy(prompt_tokens.begin(), prompt_tokens.end(), last_n_tokens.end() - ps);
+            std::fill(ctx_sampling->prev.begin(), ctx_sampling->prev.end() - ps, 0);
+            std::copy(prompt_tokens.begin(), prompt_tokens.end(), ctx_sampling->prev.end() - ps);
         }

         // compare the evaluated prompt with the new prompt
@@ -554,27 +526,24 @@ struct llama_server_context
         {
             // out of user input, sample next token
-            std::vector<llama_token_data> candidates;
-            candidates.reserve(llama_n_vocab(model));
-
-            result.tok = llama_sampling_sample(ctx, NULL, ctx_sampling, last_n_tokens, candidates);
+            result.tok = llama_sampling_sample(ctx_sampling, ctx, NULL);

-            llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
+            llama_token_data_array cur_p = { ctx_sampling->cur.data(), ctx_sampling->cur.size(), false };

             const int32_t n_probs = params.sampling_params.n_probs;
             if (params.sampling_params.temp <= 0 && n_probs > 0)
             {
                 // For llama_sample_token_greedy we need to sort candidates
-                llama_sample_softmax(ctx, &candidates_p);
+                llama_sample_softmax(ctx, &cur_p);
             }

-            for (size_t i = 0; i < std::min(candidates_p.size, (size_t)n_probs); ++i)
+            for (size_t i = 0; i < std::min(cur_p.size, (size_t)n_probs); ++i)
             {
-                result.probs.push_back({candidates_p.data[i].id, candidates_p.data[i].p});
+                result.probs.push_back({cur_p.data[i].id, cur_p.data[i].p});
             }

-            last_n_tokens.erase(last_n_tokens.begin());
-            last_n_tokens.push_back(result.tok);
+            llama_sampling_accept(ctx_sampling, ctx, result.tok);
+
             if (tg) {
                 num_tokens_predicted++;
             }
@@ -1235,7 +1204,7 @@ static void parse_options_completion(const json &body, llama_server_context &lla
         }
     }

-    llama.ctx_sampling = llama_sampling_context_init(llama.params, llama.grammar);
+    llama.ctx_sampling = llama_sampling_init(llama.params);

     LOG_VERBOSE("completion parameters parsed", format_generation_settings(llama));
 }
@@ -1793,9 +1762,7 @@ int main(int argc, char **argv)
         return 1;
     }

-    if (llama.grammar != nullptr) {
-        llama_grammar_free(llama.grammar);
-    }
+    llama_sampling_free(llama.ctx_sampling);
     llama_backend_free();

     return 0;
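
Taken together, the hunks reduce the server's sampling to a small lifecycle: init, sample, accept, free. A minimal sketch of that flow, assuming exactly the signatures that appear in the diff; the loop scaffolding and names such as prompt_tokens and n_remain are illustrative, not part of the commit:

    // one sampling context per generation; rewind() frees and re-creates it
    llama_sampling_context * ctx_sampling = llama_sampling_init(params);

    // seed the penalty window (formerly last_n_tokens) with the prompt,
    // mirroring the std::fill/std::copy pair in the prompt-loading hunks
    const size_t ps = prompt_tokens.size();
    std::fill(ctx_sampling->prev.begin(), ctx_sampling->prev.end() - ps, 0);
    std::copy(prompt_tokens.begin(), prompt_tokens.end(), ctx_sampling->prev.end() - ps);

    while (n_remain-- > 0) {
        // grammar constraints, repetition penalties and the candidate buffer
        // are all applied inside the context; no per-call candidates vector
        const llama_token tok = llama_sampling_sample(ctx_sampling, ctx, NULL);

        // advance the token history and grammar state together, replacing the
        // manual last_n_tokens.erase()/push_back() pair
        llama_sampling_accept(ctx_sampling, ctx, tok);

        // the candidates the sampler just ranked stay readable in ctx_sampling->cur,
        // which is what the server uses to report n_probs top tokens
        llama_token_data_array cur_p = { ctx_sampling->cur.data(), ctx_sampling->cur.size(), false };
    }

    llama_sampling_free(ctx_sampling);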