Diffstat (limited to 'examples/server/server.cpp')
-rw-r--r--  examples/server/server.cpp  229
1 file changed, 110 insertions, 119 deletions
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 0471528a..b5ad3cc9 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -195,10 +195,12 @@ struct llama_server_context
     json prompt;
     std::vector<llama_token> embd;
 
+    gpt_params params;
+
     llama_model *model = nullptr;
     llama_context *ctx = nullptr;
-    gpt_params params;
     llama_sampling_context *ctx_sampling = nullptr;
+
     int n_ctx;
 
     bool truncated = false;
@@ -232,7 +234,7 @@ struct llama_server_context
     void rewind()
    {
         params.antiprompt.clear();
-        params.grammar.clear();
+        params.sparams.grammar.clear();
         num_prompt_tokens = 0;
         num_tokens_predicted = 0;
         generated_text = "";
@@ -246,11 +248,14 @@ struct llama_server_context
         multibyte_pending = 0;
         n_remain = 0;
         n_past = 0;
+        params.sparams.n_prev = n_ctx;
+    }
 
+    void initSampling() {
         if (ctx_sampling != nullptr) {
             llama_sampling_free(ctx_sampling);
         }
-        ctx_sampling = llama_sampling_init(params);
+        ctx_sampling = llama_sampling_init(params.sparams);
     }
 
     bool loadModel(const gpt_params &params_)
@@ -311,16 +316,32 @@ struct llama_server_context
         return prompt_tokens;
     }
 
-    bool loadGrammar()
-    {
-        ctx_sampling = llama_sampling_init(params);
-        return true;
+    void truncatePrompt(std::vector<llama_token> &prompt_tokens) {
+        const int n_left = n_ctx - params.n_keep;
+        const int n_block_size = n_left / 2;
+        const int erased_blocks = (prompt_tokens.size() - params.n_keep - n_block_size) / n_block_size;
+
+        // Keep n_keep tokens at start of prompt (at most n_ctx - 4)
+        std::vector<llama_token> new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + params.n_keep);
+
+        new_tokens.insert(new_tokens.end(), prompt_tokens.begin() + params.n_keep + erased_blocks * n_block_size, prompt_tokens.end());
+
+        LOG_VERBOSE("input truncated", {
+            {"n_ctx", n_ctx},
+            {"n_keep", params.n_keep},
+            {"n_left", n_left},
+            {"new_tokens", tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend())},
+            {"num_prompt_tokens", new_tokens.size()}
+        });
+
+        truncated = true;
+        prompt_tokens = new_tokens;
     }
 
     void loadInfill()
     {
         bool suff_rm_leading_spc = true;
-        if (params.input_suffix.find_first_of(" ") == 0 && params.input_suffix.size() > 1) {
+        if (params.input_suffix.find_first_of(' ') == 0 && params.input_suffix.size() > 1) {
             params.input_suffix.erase(0, 1);
             suff_rm_leading_spc = false;
         }
@@ -336,6 +357,7 @@ struct llama_server_context
         prefix_tokens.insert(prefix_tokens.end(), llama_token_suffix(ctx));
         prefix_tokens.insert(prefix_tokens.end(), suffix_tokens.begin(), suffix_tokens.end());
         prefix_tokens.push_back(llama_token_middle(ctx));
+
         auto prompt_tokens = prefix_tokens;
 
         num_prompt_tokens = prompt_tokens.size();
@@ -347,31 +369,18 @@ struct llama_server_context
         params.n_keep = std::min(params.n_ctx - 4, params.n_keep);
 
         // if input prompt is too big, truncate like normal
-        if (num_prompt_tokens >= (size_t)params.n_ctx)
+        if (num_prompt_tokens >= (size_t) n_ctx)
         {
-            printf("Input prompt is too big, truncating. Can only take %d tokens but got %zu\n", params.n_ctx, num_prompt_tokens);
-            // todo we probably want to cut from both sides
-            const int n_left = (params.n_ctx - params.n_keep) / 2;
-            std::vector<llama_token> new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + params.n_keep);
-            const int erased_blocks = (num_prompt_tokens - params.n_keep - n_left - 1) / n_left;
-            new_tokens.insert(new_tokens.end(), prompt_tokens.begin() + params.n_keep + erased_blocks * n_left, prompt_tokens.end());
-            std::copy(prompt_tokens.end() - params.n_ctx, prompt_tokens.end(), ctx_sampling->prev.begin());
-
-            LOG_VERBOSE("input truncated", {
-                {"n_ctx", params.n_ctx},
-                {"n_keep", params.n_keep},
-                {"n_left", n_left},
-                {"new_tokens", tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend())},
-            });
+            truncatePrompt(prompt_tokens);
+            num_prompt_tokens = prompt_tokens.size();
 
-            truncated = true;
-            prompt_tokens = new_tokens;
+            GGML_ASSERT(num_prompt_tokens < (size_t)n_ctx);
         }
-        else
+
+        // push the prompt into the sampling context (do not apply grammar)
+        for (auto & token : prompt_tokens)
         {
-            const size_t ps = num_prompt_tokens;
-            std::fill(ctx_sampling->prev.begin(), ctx_sampling->prev.end() - ps, 0);
-            std::copy(prompt_tokens.begin(), prompt_tokens.end(), ctx_sampling->prev.end() - ps);
+            llama_sampling_accept(ctx_sampling, ctx, token, false);
        }
 
         // compare the evaluated prompt with the new prompt
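
Aside: the block-wise arithmetic in the new truncatePrompt() helper is easy to check in isolation. Below is a standalone sketch with plain int token IDs; the function name and the n_ctx/n_keep values are illustrative, not taken from the patch.

    // Standalone sketch of the block-wise truncation done by truncatePrompt().
    #include <cstdio>
    #include <vector>

    using llama_token = int; // stand-in for the real typedef

    static std::vector<llama_token> truncate_blockwise(
            const std::vector<llama_token> & prompt, int n_ctx, int n_keep) {
        const int n_left       = n_ctx - n_keep; // room left after the keep-prefix
        const int n_block_size = n_left / 2;     // erase in half-context blocks
        const int erased_blocks =
            ((int) prompt.size() - n_keep - n_block_size) / n_block_size;

        // keep the first n_keep tokens ...
        std::vector<llama_token> out(prompt.begin(), prompt.begin() + n_keep);
        // ... then drop whole blocks and keep the remaining tail
        out.insert(out.end(),
                   prompt.begin() + n_keep + erased_blocks * n_block_size,
                   prompt.end());
        return out;
    }

    int main() {
        std::vector<llama_token> prompt(100);
        for (int i = 0; i < 100; ++i) prompt[i] = i;

        const auto out = truncate_blockwise(prompt, /*n_ctx =*/ 64, /*n_keep =*/ 8);
        printf("%zu tokens after truncation\n", out.size()); // prints 44, below n_ctx
    }

With 100 input tokens, n_ctx = 64 and n_keep = 8, two 28-token blocks are erased and 44 tokens survive, which satisfies the GGML_ASSERT(num_prompt_tokens < (size_t) n_ctx) the callers now enforce.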
@@ -409,29 +418,18 @@ struct llama_server_context
         params.n_keep = std::min(n_ctx - 4, params.n_keep);
 
         // if input prompt is too big, truncate like normal
-        if (num_prompt_tokens >= (size_t)n_ctx)
+        if (num_prompt_tokens >= (size_t) n_ctx)
         {
-            const int n_left = (n_ctx - params.n_keep) / 2;
-            std::vector<llama_token> new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + params.n_keep);
-            const int erased_blocks = (num_prompt_tokens - params.n_keep - n_left - 1) / n_left;
-            new_tokens.insert(new_tokens.end(), prompt_tokens.begin() + params.n_keep + erased_blocks * n_left, prompt_tokens.end());
-            std::copy(prompt_tokens.end() - n_ctx, prompt_tokens.end(), ctx_sampling->prev.begin());
+            truncatePrompt(prompt_tokens);
+            num_prompt_tokens = prompt_tokens.size();
 
-            LOG_VERBOSE("input truncated", {
-                {"n_ctx", n_ctx},
-                {"n_keep", params.n_keep},
-                {"n_left", n_left},
-                {"new_tokens", tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend())},
-            });
-
-            truncated = true;
-            prompt_tokens = new_tokens;
+            GGML_ASSERT(num_prompt_tokens < (size_t)n_ctx);
         }
-        else
+
+        // push the prompt into the sampling context (do not apply grammar)
+        for (auto & token : prompt_tokens)
         {
-            const size_t ps = num_prompt_tokens;
-            std::fill(ctx_sampling->prev.begin(), ctx_sampling->prev.end() - ps, 0);
-            std::copy(prompt_tokens.begin(), prompt_tokens.end(), ctx_sampling->prev.end() - ps);
+            llama_sampling_accept(ctx_sampling, ctx, token, false);
         }
 
         // compare the evaluated prompt with the new prompt
@@ -530,8 +528,8 @@ struct llama_server_context
 
         llama_token_data_array cur_p = { ctx_sampling->cur.data(), ctx_sampling->cur.size(), false };
 
-        const int32_t n_probs = params.sampling_params.n_probs;
-        if (params.sampling_params.temp <= 0 && n_probs > 0)
+        const int32_t n_probs = params.sparams.n_probs;
+        if (params.sparams.temp <= 0 && n_probs > 0)
         {
             // For llama_sample_token_greedy we need to sort candidates
             llama_sample_softmax(ctx, &cur_p);
@@ -542,7 +540,7 @@ struct llama_server_context
             result.probs.push_back({cur_p.data[i].id, cur_p.data[i].p});
         }
 
-        llama_sampling_accept(ctx_sampling, ctx, result.tok);
+        llama_sampling_accept(ctx_sampling, ctx, result.tok, true);
 
         if (tg) {
             num_tokens_predicted++;
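
Aside: the two llama_sampling_accept() call sites differ only in the final apply_grammar flag: prompt tokens are pushed with false, freshly sampled tokens (result.tok above) with true. The toy context below sketches what that flag separates; it is not llama.cpp's real implementation, just the shape of it.

    #include <cstddef>
    #include <deque>

    using llama_token = int; // stand-in for the real typedef

    struct toy_sampling_context {
        size_t                  n_prev; // window size; rewind() sets the real one to n_ctx
        std::deque<llama_token> prev;   // recent tokens for the repetition penalties

        void accept(llama_token tok, bool apply_grammar) {
            prev.push_back(tok);
            if (prev.size() > n_prev) {
                prev.pop_front(); // forget the oldest token
            }
            if (apply_grammar) {
                // a real implementation advances the grammar parser here;
                // prompt tokens pass 'false' so they only feed the window
            }
        }
    };

The prompt is not required to match the grammar, so advancing the grammar on it could wedge the parser; it still has to enter the penalty window, which is presumably why rewind() now sets params.sparams.n_prev = n_ctx, letting the window cover the whole context.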
@@ -606,7 +604,7 @@ struct llama_server_context
         const std::string token_text = token_with_probs.tok == -1 ? "" : llama_token_to_piece(ctx, token_with_probs.tok);
         generated_text += token_text;
 
-        if (params.sampling_params.n_probs > 0)
+        if (params.sparams.n_probs > 0)
         {
             generated_token_probs.push_back(token_with_probs);
         }
@@ -1004,36 +1002,36 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
 
 static json format_generation_settings(llama_server_context &llama)
 {
-    const auto & sparams = llama.params.sampling_params;
+    const auto & sparams = llama.params.sparams;
     const auto eos_bias = sparams.logit_bias.find(llama_token_eos(llama.ctx));
     const bool ignore_eos = eos_bias != sparams.logit_bias.end() &&
                             eos_bias->second < 0.0f && std::isinf(eos_bias->second);
 
     return json{
-        {"n_ctx", llama.n_ctx},
-        {"model", llama.params.model_alias},
-        {"seed", llama.params.seed},
-        {"temp", sparams.temp},
-        {"top_k", sparams.top_k},
-        {"top_p", sparams.top_p},
-        {"tfs_z", sparams.tfs_z},
-        {"typical_p", sparams.typical_p},
-        {"repeat_last_n", sparams.repeat_last_n},
-        {"repeat_penalty", sparams.repeat_penalty},
-        {"presence_penalty", sparams.presence_penalty},
-        {"frequency_penalty", sparams.frequency_penalty},
-        {"mirostat", sparams.mirostat},
-        {"mirostat_tau", sparams.mirostat_tau},
-        {"mirostat_eta", sparams.mirostat_eta},
-        {"penalize_nl", sparams.penalize_nl},
-        {"stop", llama.params.antiprompt},
-        {"n_predict", llama.params.n_predict},
-        {"n_keep", llama.params.n_keep},
-        {"ignore_eos", ignore_eos},
-        {"stream", llama.stream},
-        {"logit_bias", sparams.logit_bias},
-        {"n_probs", sparams.n_probs},
-        {"grammar", llama.params.grammar},
+        {"n_ctx",             llama.n_ctx},
+        {"model",             llama.params.model_alias},
+        {"seed",              llama.params.seed},
+        {"temp",              sparams.temp},
+        {"top_k",             sparams.top_k},
+        {"top_p",             sparams.top_p},
+        {"tfs_z",             sparams.tfs_z},
+        {"typical_p",         sparams.typical_p},
+        {"repeat_last_n",     sparams.penalty_last_n},
+        {"repeat_penalty",    sparams.penalty_repeat},
+        {"frequency_penalty", sparams.penalty_freq},
+        {"presence_penalty",  sparams.penalty_present},
+        {"mirostat",          sparams.mirostat},
+        {"mirostat_tau",      sparams.mirostat_tau},
+        {"mirostat_eta",      sparams.mirostat_eta},
+        {"penalize_nl",       sparams.penalize_nl},
+        {"stop",              llama.params.antiprompt},
+        {"n_predict",         llama.params.n_predict},
+        {"n_keep",            llama.params.n_keep},
+        {"ignore_eos",        ignore_eos},
+        {"stream",            llama.stream},
+        {"logit_bias",        sparams.logit_bias},
+        {"n_probs",           sparams.n_probs},
+        {"grammar",           llama.params.sparams.grammar},
     };
 }
 
@@ -1081,7 +1079,7 @@ static json format_final_response(llama_server_context &llama, const std::string
         {"timings", format_timings(llama)},
     };
 
-    if (llama.params.sampling_params.n_probs > 0)
+    if (llama.params.sparams.n_probs > 0)
     {
         res["completion_probabilities"] = probs_vector_to_json(llama.ctx, probs);
     }
@@ -1097,7 +1095,7 @@ static json format_partial_response(
         {"stop", false},
     };
 
-    if (llama.params.sampling_params.n_probs > 0)
+    if (llama.params.sparams.n_probs > 0)
     {
         res["completion_probabilities"] = probs_vector_to_json(llama.ctx, probs);
     }
@@ -1129,28 +1127,30 @@ static T json_value(const json &body, const std::string &key, const T &default_v
 static void parse_options_completion(const json &body, llama_server_context &llama)
 {
     gpt_params default_params;
-    const auto & default_sparams = default_params.sampling_params;
-    auto & sparams = llama.params.sampling_params;
-
-    llama.stream = json_value(body, "stream", false);
-    llama.params.n_predict = json_value(body, "n_predict", default_params.n_predict);
-    sparams.top_k = json_value(body, "top_k", default_sparams.top_k);
-    sparams.top_p = json_value(body, "top_p", default_sparams.top_p);
-    sparams.tfs_z = json_value(body, "tfs_z", default_sparams.tfs_z);
-    sparams.typical_p = json_value(body, "typical_p", default_sparams.typical_p);
-    sparams.repeat_last_n = json_value(body, "repeat_last_n", default_sparams.repeat_last_n);
-    sparams.temp = json_value(body, "temperature", default_sparams.temp);
-    sparams.repeat_penalty = json_value(body, "repeat_penalty", default_sparams.repeat_penalty);
-    sparams.presence_penalty = json_value(body, "presence_penalty", default_sparams.presence_penalty);
-    sparams.frequency_penalty = json_value(body, "frequency_penalty", default_sparams.frequency_penalty);
-    sparams.mirostat = json_value(body, "mirostat", default_sparams.mirostat);
-    sparams.mirostat_tau = json_value(body, "mirostat_tau", default_sparams.mirostat_tau);
-    sparams.mirostat_eta = json_value(body, "mirostat_eta", default_sparams.mirostat_eta);
-    sparams.penalize_nl = json_value(body, "penalize_nl", default_sparams.penalize_nl);
-    llama.params.n_keep = json_value(body, "n_keep", default_params.n_keep);
-    llama.params.seed = json_value(body, "seed", default_params.seed);
-    llama.params.grammar = json_value(body, "grammar", default_params.grammar);
-    sparams.n_probs = json_value(body, "n_probs", default_sparams.n_probs);
+    const auto & default_sparams = default_params.sparams;
+
+    auto & params  = llama.params;
+    auto & sparams = llama.params.sparams;
+
+    llama.stream            = json_value(body, "stream", false);
+    params.n_predict        = json_value(body, "n_predict", default_params.n_predict);
+    sparams.top_k           = json_value(body, "top_k", default_sparams.top_k);
+    sparams.top_p           = json_value(body, "top_p", default_sparams.top_p);
+    sparams.tfs_z           = json_value(body, "tfs_z", default_sparams.tfs_z);
+    sparams.typical_p       = json_value(body, "typical_p", default_sparams.typical_p);
+    sparams.temp            = json_value(body, "temperature", default_sparams.temp);
+    sparams.penalty_last_n  = json_value(body, "repeat_last_n", default_sparams.penalty_last_n);
+    sparams.penalty_repeat  = json_value(body, "repeat_penalty", default_sparams.penalty_repeat);
+    sparams.penalty_freq    = json_value(body, "frequency_penalty", default_sparams.penalty_freq);
+    sparams.penalty_present = json_value(body, "presence_penalty", default_sparams.penalty_present);
+    sparams.mirostat        = json_value(body, "mirostat", default_sparams.mirostat);
+    sparams.mirostat_tau    = json_value(body, "mirostat_tau", default_sparams.mirostat_tau);
+    sparams.mirostat_eta    = json_value(body, "mirostat_eta", default_sparams.mirostat_eta);
+    sparams.penalize_nl     = json_value(body, "penalize_nl", default_sparams.penalize_nl);
+    params.n_keep           = json_value(body, "n_keep", default_params.n_keep);
+    params.seed             = json_value(body, "seed", default_params.seed);
+    sparams.grammar         = json_value(body, "grammar", default_sparams.grammar);
+    sparams.n_probs         = json_value(body, "n_probs", default_sparams.n_probs);
 
     if (body.count("prompt") != 0)
     {
@@ -1204,8 +1204,6 @@ static void parse_options_completion(const json &body, llama_server_context &lla
         }
     }
 
-    llama.ctx_sampling = llama_sampling_init(llama.params);
-
     LOG_VERBOSE("completion parameters parsed", format_generation_settings(llama));
 }
 
@@ -1374,15 +1372,9 @@ int main(int argc, char **argv)
             llama.rewind();
 
             llama_reset_timings(llama.ctx);
-
             parse_options_completion(json::parse(req.body), llama);
-            if (!llama.loadGrammar())
-            {
-                res.status = 400;
-                return;
-            }
-
+            llama.initSampling();
             llama.loadPrompt();
             llama.beginCompletion();
 
             if (!llama.stream) {
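
Aside: parse_options_completion() leans on the json_value() helper visible in the hunk header above. A minimal reconstruction (assuming nlohmann::json, which is what the json type in this file is) would look like this; the real helper may handle nulls and type errors more carefully.

    #include <nlohmann/json.hpp>
    #include <string>

    using json = nlohmann::json;

    // Read body[key] if present, otherwise fall back to the supplied default.
    template <typename T>
    static T json_value(const json & body, const std::string & key, const T & default_value) {
        return body.count(key) != 0 ? body.at(key).get<T>() : default_value;
    }

This pattern is what keeps the internal renames invisible to API clients: the request key is still "repeat_penalty", only the destination changed to sparams.penalty_repeat.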
@@ -1414,7 +1406,7 @@ int main(int argc, char **argv)
             }
 
             auto probs = llama.generated_token_probs;
-            if (llama.params.sampling_params.n_probs > 0 && llama.stopped_word) {
+            if (llama.params.sparams.n_probs > 0 && llama.stopped_word) {
                 const std::vector<llama_token> stop_word_toks = llama_tokenize(llama.ctx, llama.stopping_word, false);
                 probs = std::vector<completion_token_output>(llama.generated_token_probs.begin(), llama.generated_token_probs.end() - stop_word_toks.size());
             }
@@ -1466,7 +1458,7 @@ int main(int argc, char **argv)
 
                     std::vector<completion_token_output> probs_output = {};
 
-                    if (llama.params.sampling_params.n_probs > 0) {
+                    if (llama.params.sparams.n_probs > 0) {
                         const std::vector<llama_token> to_send_toks = llama_tokenize(llama.ctx, to_send, false);
                         size_t probs_pos = std::min(sent_token_probs_index, llama.generated_token_probs.size());
                         size_t probs_stop_pos = std::min(sent_token_probs_index + to_send_toks.size(), llama.generated_token_probs.size());
@@ -1537,14 +1529,9 @@ int main(int argc, char **argv)
             llama.rewind();
 
             llama_reset_timings(llama.ctx);
-
             parse_options_infill(json::parse(req.body), llama);
-            if (!llama.loadGrammar())
-            {
-                res.status = 400;
-                return;
-            }
+            llama.initSampling();
             llama.loadInfill();
             llama.beginCompletion();
 
             const auto chunked_content_provider = [&](size_t, DataSink & sink) {
@@ -1587,7 +1574,7 @@ int main(int argc, char **argv)
 
                     std::vector<completion_token_output> probs_output = {};
 
-                    if (llama.params.sampling_params.n_probs > 0) {
+                    if (llama.params.sparams.n_probs > 0) {
                         const std::vector<llama_token> to_send_toks = llama_tokenize(llama.ctx, to_send, false);
                         size_t probs_pos = std::min(sent_token_probs_index, llama.generated_token_probs.size());
                         size_t probs_stop_pos = std::min(sent_token_probs_index + to_send_toks.size(), llama.generated_token_probs.size());
@@ -1694,7 +1681,9 @@ int main(int argc, char **argv)
             const json body = json::parse(req.body);
 
             llama.rewind();
+
             llama_reset_timings(llama.ctx);
+
             if (body.count("content") != 0)
             {
                 llama.prompt = body["content"];
@@ -1704,6 +1693,8 @@ int main(int argc, char **argv)
                 llama.prompt = "";
             }
             llama.params.n_predict = 0;
+
+            llama.initSampling();
             llama.loadPrompt();
             llama.beginCompletion();
             llama.doCompletion();
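
Taken together, the completion and infill handlers now share a single setup sequence, and the embedding handler follows the same shape minus option parsing. Condensed from the hunks above, with error handling and the streaming branch omitted, so a sketch rather than a verbatim excerpt:

    llama.rewind();                  // reset per-request state, incl. params.sparams.grammar
    llama_reset_timings(llama.ctx);

    parse_options_completion(json::parse(req.body), llama);  // or parse_options_infill

    llama.initSampling();            // replaces loadGrammar(): rebuild ctx_sampling from params.sparams
    llama.loadPrompt();              // or loadInfill(): tokenize, truncate, accept into the sampler
    llama.beginCompletion();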