Diffstat (limited to 'examples')
-rw-r--r--  examples/embd-input/embd-input-lib.cpp       |  19
-rw-r--r--  examples/infill/infill.cpp                   |  18
-rw-r--r--  examples/main/main.cpp                       |  18
-rw-r--r--  examples/parallel/parallel.cpp               |   6
-rw-r--r--  examples/save-load-state/save-load-state.cpp |   5
-rw-r--r--  examples/server/server.cpp                   | 100
-rw-r--r--  examples/speculative/speculative.cpp         |  12
7 files changed, 101 insertions, 77 deletions
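
Reading guide: the common thread across these files is that per-request sampling settings move out of gpt_params into params.sampling_params (a llama_sampling_params), and token selection goes through a llama_sampling_context created with llama_sampling_context_init() and queried with llama_sampling_sample() instead of the old llama_sample_token(). The sketch below condenses the caller-side pattern used in main.cpp and infill.cpp; it is illustrative only — the header names, the generate_some_tokens helper, and the surrounding setup (model/context creation, last_tokens, n_predict) are assumptions, not part of this diff.

// Minimal caller-side sketch of the new sampling flow (assumed includes/setup;
// only the llama_sampling_* calls and field names are taken from this diff).
#include "common.h"   // assumed location of gpt_params / llama_sampling_* helpers
#include "llama.h"

#include <cmath>
#include <vector>

static void generate_some_tokens(llama_context * ctx, llama_context * ctx_guidance,
                                 gpt_params & params, llama_grammar * grammar,
                                 std::vector<llama_token> & last_tokens, int n_predict) {
    // Sampling knobs now live under params.sampling_params.
    llama_sampling_params & sparams = params.sampling_params;

    // Example of a relocated field: the EOS logit-bias check from main.cpp/infill.cpp.
    const auto it = sparams.logit_bias.find(llama_token_eos(ctx));
    if (it != sparams.logit_bias.end() && it->second == -INFINITY) {
        // EOS is disabled; grammars that require EOS will fail.
    }

    std::vector<llama_token_data> candidates;
    candidates.reserve(llama_n_vocab(llama_get_model(ctx)));

    // One sampling context per generation, built from the full params plus an optional grammar.
    llama_sampling_context ctx_sampling = llama_sampling_context_init(params, grammar);

    for (int i = 0; i < n_predict; ++i) {
        // Replaces llama_sample_token(ctx, ctx_guidance, grammar, params, last_tokens, candidates).
        const llama_token id = llama_sampling_sample(ctx, ctx_guidance, ctx_sampling, last_tokens, candidates);

        last_tokens.erase(last_tokens.begin());
        last_tokens.push_back(id);

        if (id == llama_token_eos(ctx)) {
            break; // decoding/output of `id` omitted
        }
    }
}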
diff --git a/examples/embd-input/embd-input-lib.cpp b/examples/embd-input/embd-input-lib.cpp
index 99e6bdad..87a5a1c2 100644
--- a/examples/embd-input/embd-input-lib.cpp
+++ b/examples/embd-input/embd-input-lib.cpp
@@ -128,21 +128,22 @@ bool eval_string(struct MyModel * mymodel,const char* str){
 llama_token sampling_id(struct MyModel* mymodel) {
     llama_context* ctx = mymodel->ctx;
     gpt_params params = mymodel->params;
+    llama_sampling_params & sparams = params.sampling_params;
     // int n_ctx = llama_n_ctx(ctx);
 
     // out of user input, sample next token
-    const float   temp      = params.temp;
-    const int32_t top_k     = params.top_k <= 0 ? llama_n_vocab(llama_get_model(ctx)) : params.top_k;
-    const float   top_p     = params.top_p;
-    const float   tfs_z     = params.tfs_z;
-    const float   typical_p = params.typical_p;
+    const float   temp      = sparams.temp;
+    const int32_t top_k     = sparams.top_k <= 0 ? llama_n_vocab(llama_get_model(ctx)) : sparams.top_k;
+    const float   top_p     = sparams.top_p;
+    const float   tfs_z     = sparams.tfs_z;
+    const float   typical_p = sparams.typical_p;
     // const int32_t repeat_last_n   = params.repeat_last_n < 0 ? n_ctx : params.repeat_last_n;
     // const float   repeat_penalty  = params.repeat_penalty;
     // const float   alpha_presence  = params.presence_penalty;
     // const float   alpha_frequency = params.frequency_penalty;
-    const int     mirostat     = params.mirostat;
-    const float   mirostat_tau = params.mirostat_tau;
-    const float   mirostat_eta = params.mirostat_eta;
+    const int     mirostat     = sparams.mirostat;
+    const float   mirostat_tau = sparams.mirostat_tau;
+    const float   mirostat_eta = sparams.mirostat_eta;
     // const bool    penalize_nl     = params.penalize_nl;
 
     llama_token id = 0;
@@ -151,7 +152,7 @@ llama_token sampling_id(struct MyModel* mymodel) {
         auto n_vocab = llama_n_vocab(llama_get_model(ctx));
 
         // Apply params.logit_bias map
-        for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) {
+        for (auto it = sparams.logit_bias.begin(); it != sparams.logit_bias.end(); it++) {
             logits[it->first] += it->second;
         }
 
diff --git a/examples/infill/infill.cpp b/examples/infill/infill.cpp
index d994de5e..187623f5 100644
--- a/examples/infill/infill.cpp
+++ b/examples/infill/infill.cpp
@@ -104,6 +104,7 @@ static void sigint_handler(int signo) {
 
 int main(int argc, char ** argv) {
     gpt_params params;
+    llama_sampling_params & sparams = params.sampling_params;
     g_params = &params;
 
     if (!gpt_params_parse(argc, argv, params)) {
@@ -206,7 +207,7 @@ int main(int argc, char ** argv) {
     // load the model and apply lora adapter, if any
     LOG("%s: load the model and apply lora adapter, if any\n", __func__);
     std::tie(model, ctx) = llama_init_from_gpt_params(params);
-    if (params.cfg_scale > 1.f) {
+    if (sparams.cfg_scale > 1.f) {
         struct llama_context_params lparams = llama_context_params_from_gpt_params(params);
         ctx_guidance = llama_new_context_with_model(model, lparams);
     }
@@ -269,9 +270,9 @@ int main(int argc, char ** argv) {
     int guidance_offset = 0;
     int original_prompt_len = 0;
     if (ctx_guidance) {
-        LOG("cfg_negative_prompt: \"%s\"\n", log_tostr(params.cfg_negative_prompt));
+        LOG("cfg_negative_prompt: \"%s\"\n", log_tostr(sparams.cfg_negative_prompt));
 
-        guidance_inp = ::llama_tokenize(ctx_guidance, params.cfg_negative_prompt, add_bos);
+        guidance_inp = ::llama_tokenize(ctx_guidance, sparams.cfg_negative_prompt, add_bos);
         LOG("guidance_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_guidance, guidance_inp));
 
         std::vector<llama_token> original_inp = ::llama_tokenize(ctx, params.prompt, add_bos);
@@ -312,7 +313,7 @@ int main(int argc, char ** argv) {
 
     if (ctx_guidance) {
         LOG_TEE("\n");
-        LOG_TEE("%s: negative prompt: '%s'\n", __func__, params.cfg_negative_prompt.c_str());
+        LOG_TEE("%s: negative prompt: '%s'\n", __func__, sparams.cfg_negative_prompt.c_str());
         LOG_TEE("%s: number of tokens in negative prompt = %zu\n", __func__, guidance_inp.size());
         for (int i = 0; i < (int) guidance_inp.size(); i++) {
             LOG_TEE("%6d -> '%s'\n", guidance_inp[i], llama_token_to_piece(ctx, guidance_inp[i]).c_str());
@@ -358,7 +359,7 @@ int main(int argc, char ** argv) {
         }
     }
     LOG_TEE("sampling: repeat_last_n = %d, repeat_penalty = %f, presence_penalty = %f, frequency_penalty = %f, top_k = %d, tfs_z = %f, top_p = %f, typical_p = %f, temp = %f, mirostat = %d, mirostat_lr = %f, mirostat_ent = %f\n",
-            params.repeat_last_n, params.repeat_penalty, params.presence_penalty, params.frequency_penalty, params.top_k, params.tfs_z, params.top_p, params.typical_p, params.temp, params.mirostat, params.mirostat_eta, params.mirostat_tau);
+            sparams.repeat_last_n, sparams.repeat_penalty, sparams.presence_penalty, sparams.frequency_penalty, sparams.top_k, sparams.tfs_z, sparams.top_p, sparams.typical_p, sparams.temp, sparams.mirostat, sparams.mirostat_eta, sparams.mirostat_tau);
     LOG_TEE("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
     LOG_TEE("\n\n");
 
@@ -376,8 +377,8 @@ int main(int argc, char ** argv) {
         LOG_TEE("\n");
 
         {
-            auto it = params.logit_bias.find(llama_token_eos(ctx));
-            if (it != params.logit_bias.end() && it->second == -INFINITY) {
+            auto it = sparams.logit_bias.find(llama_token_eos(ctx));
+            if (it != sparams.logit_bias.end() && it->second == -INFINITY) {
                 LOG_TEE("%s: warning: EOS token is disabled, which will cause most grammars to fail\n", __func__);
             }
         }
@@ -434,6 +435,7 @@ int main(int argc, char ** argv) {
 
     const int n_vocab = llama_n_vocab(model);
 
+    llama_sampling_context ctx_sampling = llama_sampling_context_init(params, grammar);
     std::vector<llama_token_data> candidates;
     candidates.reserve(n_vocab);
 
@@ -552,7 +554,7 @@ int main(int argc, char ** argv) {
 
         if ((int) embd_inp.size() <= n_consumed && !is_interacting) {
 
-            const llama_token id = llama_sample_token(ctx, ctx_guidance, grammar, params, last_tokens, candidates);
+            const llama_token id = llama_sampling_sample(ctx, ctx_guidance, ctx_sampling, last_tokens, candidates);
 
             last_tokens.erase(last_tokens.begin());
             last_tokens.push_back(id);
diff --git a/examples/main/main.cpp b/examples/main/main.cpp
index 775a5a20..b39a67d9 100644
--- a/examples/main/main.cpp
+++ b/examples/main/main.cpp
@@ -109,6 +109,7 @@ int main(int argc, char ** argv) {
     if (!gpt_params_parse(argc, argv, params)) {
         return 1;
     }
+    llama_sampling_params & sparams = params.sampling_params;
 
 #ifndef LOG_DISABLE_LOGS
     log_set_target(log_filename_generator("main", "log"));
@@ -179,7 +180,7 @@ int main(int argc, char ** argv) {
     // load the model and apply lora adapter, if any
     LOG("%s: load the model and apply lora adapter, if any\n", __func__);
     std::tie(model, ctx) = llama_init_from_gpt_params(params);
-    if (params.cfg_scale > 1.f) {
+    if (sparams.cfg_scale > 1.f) {
         struct llama_context_params lparams = llama_context_params_from_gpt_params(params);
         ctx_guidance = llama_new_context_with_model(model, lparams);
     }
@@ -257,9 +258,9 @@ int main(int argc, char ** argv) {
     int guidance_offset = 0;
     int original_prompt_len = 0;
     if (ctx_guidance) {
-        LOG("cfg_negative_prompt: \"%s\"\n", log_tostr(params.cfg_negative_prompt));
+        LOG("cfg_negative_prompt: \"%s\"\n", log_tostr(sparams.cfg_negative_prompt));
 
-        guidance_inp = ::llama_tokenize(ctx_guidance, params.cfg_negative_prompt, add_bos);
+        guidance_inp = ::llama_tokenize(ctx_guidance, sparams.cfg_negative_prompt, add_bos);
         LOG("guidance_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_guidance, guidance_inp));
 
         std::vector<llama_token> original_inp = ::llama_tokenize(ctx, params.prompt, add_bos);
@@ -343,7 +344,7 @@ int main(int argc, char ** argv) {
 
     if (ctx_guidance) {
         LOG_TEE("\n");
-        LOG_TEE("%s: negative prompt: '%s'\n", __func__, params.cfg_negative_prompt.c_str());
+        LOG_TEE("%s: negative prompt: '%s'\n", __func__, sparams.cfg_negative_prompt.c_str());
         LOG_TEE("%s: number of tokens in negative prompt = %zu\n", __func__, guidance_inp.size());
         for (int i = 0; i < (int) guidance_inp.size(); i++) {
             LOG_TEE("%6d -> '%s'\n", guidance_inp[i], llama_token_to_piece(ctx, guidance_inp[i]).c_str());
@@ -395,7 +396,7 @@ int main(int argc, char ** argv) {
         }
     }
     LOG_TEE("sampling: repeat_last_n = %d, repeat_penalty = %f, presence_penalty = %f, frequency_penalty = %f, top_k = %d, tfs_z = %f, top_p = %f, typical_p = %f, temp = %f, mirostat = %d, mirostat_lr = %f, mirostat_ent = %f\n",
-            params.repeat_last_n, params.repeat_penalty, params.presence_penalty, params.frequency_penalty, params.top_k, params.tfs_z, params.top_p, params.typical_p, params.temp, params.mirostat, params.mirostat_eta, params.mirostat_tau);
+            sparams.repeat_last_n, sparams.repeat_penalty, sparams.presence_penalty, sparams.frequency_penalty, sparams.top_k, sparams.tfs_z, sparams.top_p, sparams.typical_p, sparams.temp, sparams.mirostat, sparams.mirostat_eta, sparams.mirostat_tau);
     LOG_TEE("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
     LOG_TEE("\n\n");
 
@@ -413,8 +414,8 @@ int main(int argc, char ** argv) {
         LOG_TEE("\n");
 
         {
-            auto it = params.logit_bias.find(llama_token_eos(ctx));
-            if (it != params.logit_bias.end() && it->second == -INFINITY) {
+            auto it = sparams.logit_bias.find(llama_token_eos(ctx));
+            if (it != sparams.logit_bias.end() && it->second == -INFINITY) {
                 LOG_TEE("%s: warning: EOS token is disabled, which will cause most grammars to fail\n", __func__);
             }
         }
@@ -469,6 +470,7 @@ int main(int argc, char ** argv) {
 
     const int n_vocab = llama_n_vocab(model);
 
+    llama_sampling_context ctx_sampling = llama_sampling_context_init(params, grammar);
     std::vector<llama_token_data> candidates;
     candidates.reserve(n_vocab);
 
@@ -625,7 +627,7 @@ int main(int argc, char ** argv) {
                 LOG("saved session to %s\n", path_session.c_str());
             }
 
-            const llama_token id = llama_sample_token(ctx, ctx_guidance, grammar, params, last_tokens, candidates);
+            const llama_token id = llama_sampling_sample(ctx, ctx_guidance, ctx_sampling, last_tokens, candidates);
 
             last_tokens.erase(last_tokens.begin());
             last_tokens.push_back(id);
diff --git a/examples/parallel/parallel.cpp b/examples/parallel/parallel.cpp
index 04f1e45b..63ddcd8e 100644
--- a/examples/parallel/parallel.cpp
+++ b/examples/parallel/parallel.cpp
@@ -125,6 +125,8 @@ int main(int argc, char ** argv) {
     params.logits_all = true;
     std::tie(model, ctx) = llama_init_from_gpt_params(params);
 
+    llama_sampling_context ctx_sampling = llama_sampling_context_init(params, NULL);
+
     // load the prompts from an external file if there are any
     if (params.prompt.empty()) {
         printf("\n\033[32mNo new questions so proceed with build-in defaults.\033[0m\n");
@@ -339,7 +341,7 @@ int main(int argc, char ** argv) {
             //printf("client %d, seq %d, token %d, pos %d, batch %d\n",
             //        client.id, client.seq_id, client.sampled, client.n_decoded, client.i_batch);
 
-            const llama_token id = llama_sample_token(ctx, NULL, NULL, params, client.tokens_prev, candidates, client.i_batch - i);
+            const llama_token id = llama_sampling_sample(ctx, NULL, ctx_sampling, client.tokens_prev, candidates, client.i_batch - i, client.seq_id);
 
             if (client.n_decoded == 1) {
                 // start measuring generation time after the first token to make sure all concurrent clients
@@ -384,7 +386,7 @@ int main(int argc, char ** argv) {
 
                 n_total_prompt += client.n_prompt;
                 n_total_gen    += client.n_decoded;
-
+                llama_sampling_context_reset(ctx_sampling, client.seq_id);
                 client.seq_id = -1;
             }
 
diff --git a/examples/save-load-state/save-load-state.cpp b/examples/save-load-state/save-load-state.cpp
index acc6dbdf..f9e3c98a 100644
--- a/examples/save-load-state/save-load-state.cpp
+++ b/examples/save-load-state/save-load-state.cpp
@@ -8,9 +8,10 @@
 
 int main(int argc, char ** argv) {
     gpt_params params;
+    llama_sampling_params & sparams = params.sampling_params;
     params.seed = 42;
     params.n_threads = 4;
-    params.repeat_last_n = 64;
+    sparams.repeat_last_n = 64;
     params.prompt = "The quick brown fox";
 
     if (!gpt_params_parse(argc, argv, params)) {
@@ -24,7 +25,7 @@ int main(int argc, char ** argv) {
     }
 
     auto n_past = 0;
-    auto last_n_tokens_data = std::vector<llama_token>(params.repeat_last_n, 0);
+    auto last_n_tokens_data = std::vector<llama_token>(sparams.repeat_last_n, 0);
 
     // init
     llama_model * model;
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 8c5318c6..58af78de 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -200,6 +200,7 @@ struct llama_server_context
     llama_model *model = nullptr;
     llama_context *ctx = nullptr;
     gpt_params params;
+    llama_sampling_context ctx_sampling;
     int n_ctx;
 
     grammar_parser::parse_state parsed_grammar;
@@ -254,6 +255,7 @@ struct llama_server_context
         if (grammar != nullptr) {
             llama_grammar_free(grammar);
             grammar = nullptr;
+            ctx_sampling = llama_sampling_context_init(params, NULL);
         }
     }
 
@@ -329,8 +331,8 @@ struct llama_server_context
             grammar_parser::print_grammar(stderr, parsed_grammar);
 
             {
-                auto it = params.logit_bias.find(llama_token_eos(ctx));
-                if (it != params.logit_bias.end() && it->second == -INFINITY) {
+                auto it = params.sampling_params.logit_bias.find(llama_token_eos(ctx));
+                if (it != params.sampling_params.logit_bias.end() && it->second == -INFINITY) {
                     LOG_WARNING("EOS token is disabled, which will cause most grammars to fail", {});
                 }
             }
@@ -339,6 +341,7 @@ struct llama_server_context
             grammar = llama_grammar_init(
                 grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
         }
+        ctx_sampling = llama_sampling_context_init(params, grammar);
         return true;
     }
 
@@ -550,12 +553,12 @@ struct llama_server_context
         std::vector<llama_token_data> candidates;
         candidates.reserve(llama_n_vocab(model));
 
-        result.tok = llama_sample_token(ctx, NULL, grammar, params, last_n_tokens, candidates);
+        result.tok = llama_sampling_sample(ctx, NULL, ctx_sampling, last_n_tokens, candidates);
 
         llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
 
-        const int32_t n_probs = params.n_probs;
-        if (params.temp <= 0 && n_probs > 0)
+        const int32_t n_probs = params.sampling_params.n_probs;
+        if (params.sampling_params.temp <= 0 && n_probs > 0)
         {
             // For llama_sample_token_greedy we need to sort candidates
             llama_sample_softmax(ctx, &candidates_p);
@@ -630,7 +633,7 @@ struct llama_server_context
         const std::string token_text = token_with_probs.tok == -1 ? "" : llama_token_to_piece(ctx, token_with_probs.tok);
         generated_text += token_text;
 
-        if (params.n_probs > 0)
+        if (params.sampling_params.n_probs > 0)
         {
             generated_token_probs.push_back(token_with_probs);
         }
@@ -1018,34 +1021,35 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
 
 static json format_generation_settings(llama_server_context &llama)
 {
-    const auto eos_bias = llama.params.logit_bias.find(llama_token_eos(llama.ctx));
-    const bool ignore_eos = eos_bias != llama.params.logit_bias.end() &&
+    const auto & sparams = llama.params.sampling_params;
+    const auto eos_bias = sparams.logit_bias.find(llama_token_eos(llama.ctx));
+    const bool ignore_eos = eos_bias != sparams.logit_bias.end() &&
                             eos_bias->second < 0.0f && std::isinf(eos_bias->second);
 
     return json{
        {"n_ctx",             llama.n_ctx},
        {"model",             llama.params.model_alias},
        {"seed",              llama.params.seed},
-       {"temp",              llama.params.temp},
-       {"top_k",             llama.params.top_k},
-       {"top_p",             llama.params.top_p},
-       {"tfs_z",             llama.params.tfs_z},
-       {"typical_p",         llama.params.typical_p},
-       {"repeat_last_n",     llama.params.repeat_last_n},
-       {"repeat_penalty",    llama.params.repeat_penalty},
-       {"presence_penalty",  llama.params.presence_penalty},
-       {"frequency_penalty", llama.params.frequency_penalty},
-       {"mirostat",          llama.params.mirostat},
-       {"mirostat_tau",      llama.params.mirostat_tau},
-       {"mirostat_eta",      llama.params.mirostat_eta},
-       {"penalize_nl",       llama.params.penalize_nl},
+       {"temp",              sparams.temp},
+       {"top_k",             sparams.top_k},
+       {"top_p",             sparams.top_p},
+       {"tfs_z",             sparams.tfs_z},
+       {"typical_p",         sparams.typical_p},
+       {"repeat_last_n",     sparams.repeat_last_n},
+       {"repeat_penalty",    sparams.repeat_penalty},
+       {"presence_penalty",  sparams.presence_penalty},
+       {"frequency_penalty", sparams.frequency_penalty},
+       {"mirostat",          sparams.mirostat},
+       {"mirostat_tau",      sparams.mirostat_tau},
+       {"mirostat_eta",      sparams.mirostat_eta},
+       {"penalize_nl",       sparams.penalize_nl},
        {"stop",              llama.params.antiprompt},
        {"n_predict",         llama.params.n_predict},
        {"n_keep",            llama.params.n_keep},
        {"ignore_eos",        ignore_eos},
        {"stream",            llama.stream},
-       {"logit_bias",        llama.params.logit_bias},
-       {"n_probs",           llama.params.n_probs},
+       {"logit_bias",        sparams.logit_bias},
+       {"n_probs",           sparams.n_probs},
        {"grammar",           llama.params.grammar},
     };
 }
@@ -1094,7 +1098,7 @@ static json format_final_response(llama_server_context &llama, const std::string
         {"timings",             format_timings(llama)},
     };
 
-    if (llama.params.n_probs > 0)
+    if (llama.params.sampling_params.n_probs > 0)
     {
         res["completion_probabilities"] = probs_vector_to_json(llama.ctx, probs);
     }
@@ -1110,7 +1114,7 @@ static json format_partial_response(
         {"stop",    false},
     };
 
-    if (llama.params.n_probs > 0)
+    if (llama.params.sampling_params.n_probs > 0)
     {
         res["completion_probabilities"] = probs_vector_to_json(llama.ctx, probs);
     }
@@ -1142,26 +1146,28 @@ static T json_value(const json &body, const std::string &key, const T &default_v
 static void parse_options_completion(const json &body, llama_server_context &llama)
 {
     gpt_params default_params;
+    const auto & default_sparams = default_params.sampling_params;
+    auto & sparams = llama.params.sampling_params;
 
     llama.stream = json_value(body, "stream", false);
     llama.params.n_predict = json_value(body, "n_predict", default_params.n_predict);
-    llama.params.top_k = json_value(body, "top_k", default_params.top_k);
-    llama.params.top_p = json_value(body, "top_p", default_params.top_p);
-    llama.params.tfs_z = json_value(body, "tfs_z", default_params.tfs_z);
-    llama.params.typical_p = json_value(body, "typical_p", default_params.typical_p);
-    llama.params.repeat_last_n = json_value(body, "repeat_last_n", default_params.repeat_last_n);
-    llama.params.temp = json_value(body, "temperature", default_params.temp);
-    llama.params.repeat_penalty = json_value(body, "repeat_penalty", default_params.repeat_penalty);
-    llama.params.presence_penalty = json_value(body, "presence_penalty", default_params.presence_penalty);
-    llama.params.frequency_penalty = json_value(body, "frequency_penalty", default_params.frequency_penalty);
-    llama.params.mirostat = json_value(body, "mirostat", default_params.mirostat);
-    llama.params.mirostat_tau = json_value(body, "mirostat_tau", default_params.mirostat_tau);
-    llama.params.mirostat_eta = json_value(body, "mirostat_eta", default_params.mirostat_eta);
-    llama.params.penalize_nl = json_value(body, "penalize_nl", default_params.penalize_nl);
+    sparams.top_k = json_value(body, "top_k", default_sparams.top_k);
+    sparams.top_p = json_value(body, "top_p", default_sparams.top_p);
+    sparams.tfs_z = json_value(body, "tfs_z", default_sparams.tfs_z);
+    sparams.typical_p = json_value(body, "typical_p", default_sparams.typical_p);
+    sparams.repeat_last_n = json_value(body, "repeat_last_n", default_sparams.repeat_last_n);
+    sparams.temp = json_value(body, "temperature", default_sparams.temp);
+    sparams.repeat_penalty = json_value(body, "repeat_penalty", default_sparams.repeat_penalty);
+    sparams.presence_penalty = json_value(body, "presence_penalty", default_sparams.presence_penalty);
+    sparams.frequency_penalty = json_value(body, "frequency_penalty", default_sparams.frequency_penalty);
+    sparams.mirostat = json_value(body, "mirostat", default_sparams.mirostat);
+    sparams.mirostat_tau = json_value(body, "mirostat_tau", default_sparams.mirostat_tau);
+    sparams.mirostat_eta = json_value(body, "mirostat_eta", default_sparams.mirostat_eta);
+    sparams.penalize_nl = json_value(body, "penalize_nl", default_sparams.penalize_nl);
     llama.params.n_keep = json_value(body, "n_keep", default_params.n_keep);
     llama.params.seed = json_value(body, "seed", default_params.seed);
     llama.params.grammar = json_value(body, "grammar", default_params.grammar);
-    llama.params.n_probs = json_value(body, "n_probs", default_params.n_probs);
+    sparams.n_probs = json_value(body, "n_probs", default_sparams.n_probs);
 
     if (body.count("prompt") != 0)
     {
@@ -1172,10 +1178,10 @@ static void parse_options_completion(const json &body, llama_server_context &lla
         llama.prompt = "";
     }
 
-    llama.params.logit_bias.clear();
+    sparams.logit_bias.clear();
     if (json_value(body, "ignore_eos", false))
     {
-        llama.params.logit_bias[llama_token_eos(llama.ctx)] = -INFINITY;
+        sparams.logit_bias[llama_token_eos(llama.ctx)] = -INFINITY;
     }
 
     const auto &logit_bias = body.find("logit_bias");
@@ -1191,11 +1197,11 @@ static void parse_options_completion(const json &body, llama_server_context &lla
                 {
                     if (el[1].is_number())
                     {
-                        llama.params.logit_bias[tok] = el[1].get<float>();
+                        sparams.logit_bias[tok] = el[1].get<float>();
                     }
                     else if (el[1].is_boolean() && !el[1].get<bool>())
                     {
-                        llama.params.logit_bias[tok] = -INFINITY;
+                        sparams.logit_bias[tok] = -INFINITY;
                     }
                 }
             }
@@ -1215,6 +1221,8 @@ static void parse_options_completion(const json &body, llama_server_context &lla
         }
     }
 
+    llama.ctx_sampling = llama_sampling_context_init(llama.params, llama.grammar);
+
     LOG_VERBOSE("completion parameters parsed", format_generation_settings(llama));
 }
 
@@ -1423,7 +1431,7 @@ int main(int argc, char **argv)
             }
 
             auto probs = llama.generated_token_probs;
-            if (llama.params.n_probs > 0 && llama.stopped_word) {
+            if (llama.params.sampling_params.n_probs > 0 && llama.stopped_word) {
                 const std::vector<llama_token> stop_word_toks = llama_tokenize(llama.ctx, llama.stopping_word, false);
                 probs = std::vector<completion_token_output>(llama.generated_token_probs.begin(), llama.generated_token_probs.end() - stop_word_toks.size());
             }
@@ -1475,7 +1483,7 @@ int main(int argc, char **argv)
 
                         std::vector<completion_token_output> probs_output = {};
 
-                        if (llama.params.n_probs > 0) {
+                        if (llama.params.sampling_params.n_probs > 0) {
                             const std::vector<llama_token> to_send_toks = llama_tokenize(llama.ctx, to_send, false);
                             size_t probs_pos = std::min(sent_token_probs_index, llama.generated_token_probs.size());
                             size_t probs_stop_pos = std::min(sent_token_probs_index + to_send_toks.size(), llama.generated_token_probs.size());
@@ -1596,7 +1604,7 @@ int main(int argc, char **argv)
 
                         std::vector<completion_token_output> probs_output = {};
 
-                        if (llama.params.n_probs > 0) {
+                        if (llama.params.sampling_params.n_probs > 0) {
                             const std::vector<llama_token> to_send_toks = llama_tokenize(llama.ctx, to_send, false);
                             size_t probs_pos = std::min(sent_token_probs_index, llama.generated_token_probs.size());
                             size_t probs_stop_pos = std::min(sent_token_probs_index + to_send_toks.size(), llama.generated_token_probs.size());
diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp
index 75a2e5e2..018dbf9a 100644
--- a/examples/speculative/speculative.cpp
+++ b/examples/speculative/speculative.cpp
@@ -125,6 +125,8 @@ int main(int argc, char ** argv) {
         grammar_tgt = llama_grammar_init(grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
     }
 
+    llama_sampling_context ctx_sampling = llama_sampling_context_init(params, grammar_tgt);
+
     const auto t_dec_start = ggml_time_us();
 
     while (true) {
@@ -134,7 +136,7 @@ int main(int argc, char ** argv) {
 
         while (true) {
             // sample from the target model
-            llama_token id = llama_sample_token(ctx_tgt, NULL, grammar_tgt, params, last_tokens, candidates, i_dft);
+            llama_token id = llama_sampling_sample(ctx_tgt, NULL, ctx_sampling, last_tokens, candidates, i_dft);
 
             // remember which tokens were sampled - used for repetition penalties during sampling
             last_tokens.erase(last_tokens.begin());
@@ -211,7 +213,13 @@ int main(int argc, char ** argv) {
             if (grammar_dft) {
                 llama_grammar_free(grammar_dft);
             }
-            grammar_dft = llama_grammar_copy(grammar_tgt);
+            // Note: Hardcoded to sequence id 0, if this ever supports parallel generation
+            // that will need to change.
+            auto it = ctx_sampling.sequence_contexts.find(0);
+            GGML_ASSERT(it != ctx_sampling.sequence_contexts.end());
+            // This is necessary because each sequence id in sequence_contexts
+            // uses a copy of the original grammar.
+            grammar_dft = llama_grammar_copy(it->second.grammar);
 
             LOG("copied target grammar to draft grammar\n");
         }
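
For multi-sequence callers the sampling context is shared and addressed per sequence: parallel.cpp passes the batch index and sequence id as extra arguments to llama_sampling_sample() and calls llama_sampling_context_reset() when a client's sequence finishes, while speculative.cpp reaches into ctx_sampling.sequence_contexts to copy the per-sequence grammar for the draft model. A rough per-client sketch follows, with the same caveats as the sketch above (the client_slot struct and helper names are illustrative, not part of this diff, and the includes from the earlier sketch are assumed):

// Illustrative per-sequence usage; only llama_sampling_sample, llama_sampling_context_reset
// and the extra (batch index, seq id) arguments come from the parallel.cpp changes above.
struct client_slot {
    int32_t seq_id  = -1;
    int32_t i_batch = -1;
    std::vector<llama_token> tokens_prev;
};

static llama_token sample_for_client(llama_context * ctx, llama_sampling_context & ctx_sampling,
                                     client_slot & client, std::vector<llama_token_data> & candidates,
                                     int i) {
    // The trailing arguments select the logits row in the batch and the per-sequence sampler state.
    return llama_sampling_sample(ctx, NULL, ctx_sampling, client.tokens_prev, candidates,
                                 client.i_batch - i, client.seq_id);
}

static void finish_client(llama_sampling_context & ctx_sampling, client_slot & client) {
    // Drop this sequence's state (including its grammar copy) once the client is done.
    llama_sampling_context_reset(ctx_sampling, client.seq_id);
    client.seq_id = -1;
}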