Diffstat (limited to 'examples/infill')
-rw-r--r-- examples/infill/CMakeLists.txt |  2 +-
-rw-r--r-- examples/infill/infill.cpp     | 67 ++++++-------------------
2 files changed, 18 insertions(+), 51 deletions(-)
diff --git a/examples/infill/CMakeLists.txt b/examples/infill/CMakeLists.txt
index 046f9b1e..57d01cb0 100644
--- a/examples/infill/CMakeLists.txt
+++ b/examples/infill/CMakeLists.txt
@@ -4,5 +4,5 @@ install(TARGETS ${TARGET} RUNTIME)
 target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
 target_compile_features(${TARGET} PRIVATE cxx_std_11)
 if(TARGET BUILD_INFO)
-    add_dependencies(${TARGET} BUILD_INFO)
+  add_dependencies(${TARGET} BUILD_INFO)
 endif()
diff --git a/examples/infill/infill.cpp b/examples/infill/infill.cpp
index 128d6708..6331335e 100644
--- a/examples/infill/infill.cpp
+++ b/examples/infill/infill.cpp
@@ -39,8 +39,8 @@ static gpt_params * g_params;
 static std::vector<llama_token> * g_input_tokens;
 static std::ostringstream       * g_output_ss;
 static std::vector<llama_token> * g_output_tokens;
-static bool is_interacting  = false;
+static bool is_interacting = false;
 
 static void write_logfile(
     const llama_context * ctx, const gpt_params & params, const llama_model * model,
@@ -104,7 +104,7 @@ static void sigint_handler(int signo) {
 
 int main(int argc, char ** argv) {
     gpt_params params;
-    llama_sampling_params & sparams = params.sampling_params;
+    llama_sampling_params & sparams = params.sparams;
     g_params = &params;
 
     if (!gpt_params_parse(argc, argv, params)) {
@@ -358,36 +358,10 @@ int main(int argc, char ** argv) {
             LOG_TEE("Input suffix: '%s'\n", params.input_suffix.c_str());
         }
     }
-    LOG_TEE("sampling: repeat_last_n = %d, repeat_penalty = %f, presence_penalty = %f, frequency_penalty = %f, top_k = %d, tfs_z = %f, top_p = %f, typical_p = %f, temp = %f, mirostat = %d, mirostat_lr = %f, mirostat_ent = %f\n",
-        sparams.repeat_last_n, sparams.repeat_penalty, sparams.presence_penalty, sparams.frequency_penalty, sparams.top_k, sparams.tfs_z, sparams.top_p, sparams.typical_p, sparams.temp, sparams.mirostat, sparams.mirostat_eta, sparams.mirostat_tau);
+    LOG_TEE("sampling: \n%s\n", llama_sampling_print(sparams).c_str());
     LOG_TEE("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
     LOG_TEE("\n\n");
 
-    struct llama_grammar * grammar = NULL;
-    grammar_parser::parse_state parsed_grammar;
-
-    if (!params.grammar.empty()) {
-        parsed_grammar = grammar_parser::parse(params.grammar.c_str());
-        // will be empty (default) if there are parse errors
-        if (parsed_grammar.rules.empty()) {
-            return 1;
-        }
-        LOG_TEE("%s: grammar:\n", __func__);
-        grammar_parser::print_grammar(stderr, parsed_grammar);
-        LOG_TEE("\n");
-
-        {
-            auto it = sparams.logit_bias.find(llama_token_eos(ctx));
-            if (it != sparams.logit_bias.end() && it->second == -INFINITY) {
-                LOG_TEE("%s: warning: EOS token is disabled, which will cause most grammars to fail\n", __func__);
-            }
-        }
-
-        std::vector<const llama_grammar_element *> grammar_rules(parsed_grammar.c_rules());
-        grammar = llama_grammar_init(
-            grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
-    }
-
     LOG_TEE("\n##### Infill mode #####\n\n");
     if (params.infill) {
         printf("\n************\n");
@@ -430,7 +404,7 @@ int main(int argc, char ** argv) {
     std::vector<llama_token> embd;
     std::vector<llama_token> embd_guidance;
 
-    struct llama_sampling_context * ctx_sampling = llama_sampling_init(params);
+    struct llama_sampling_context * ctx_sampling = llama_sampling_init(sparams);
 
     while (n_remain != 0 || params.interactive) {
         // predict
@@ -549,7 +523,7 @@ int main(int argc, char ** argv) {
 
             const llama_token id = llama_sampling_sample(ctx_sampling, ctx, ctx_guidance);
 
-            llama_sampling_accept(ctx_sampling, ctx, id);
+            llama_sampling_accept(ctx_sampling, ctx, id, true);
 
             LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, ctx_sampling->prev).c_str());
@@ -567,8 +541,11 @@ int main(int argc, char ** argv) {
             LOG("embd_inp.size(): %d, n_consumed: %d\n", (int) embd_inp.size(), n_consumed);
             while ((int) embd_inp.size() > n_consumed) {
                 embd.push_back(embd_inp[n_consumed]);
-                ctx_sampling->prev.erase(ctx_sampling->prev.begin());
-                ctx_sampling->prev.push_back(embd_inp[n_consumed]);
+
+                // push the prompt in the sampling context in order to apply repetition penalties later
+                // for the prompt, we don't apply grammar rules
+                llama_sampling_accept(ctx_sampling, ctx, embd_inp[n_consumed], false);
+
                 ++n_consumed;
                 if ((int) embd.size() >= params.n_batch) {
                     break;
@@ -600,7 +577,7 @@ int main(int argc, char ** argv) {
         if ((int) embd_inp.size() <= n_consumed) {
 
             // deal with eot token in infill mode
-            if ((ctx_sampling->prev.back() == llama_token_eot(ctx) || is_interacting) && params.interactive){
+            if ((llama_sampling_last(ctx_sampling) == llama_token_eot(ctx) || is_interacting) && params.interactive){
                 if(is_interacting && !params.interactive_first) {
                     // print an eot token
                     printf("%s", llama_token_to_piece(ctx, llama_token_eot(ctx)).c_str());
@@ -617,7 +594,7 @@ int main(int argc, char ** argv) {
                     buffer += line;
                 } while (another_line);
                 // check if we got an empty line, if so we use the old input
-                if(!buffer.empty() && !(buffer.length() == 1 && buffer[0] == '\n')) {
+                if (!buffer.empty() && !(buffer.length() == 1 && buffer[0] == '\n')) {
                     params.input_prefix = buffer;
                 }
                 buffer.clear();
@@ -627,7 +604,7 @@ int main(int argc, char ** argv) {
                     buffer += line;
                 } while (another_line);
                 // check if we got an empty line
-                if(!buffer.empty() && !(buffer.length() == 1 && buffer[0] == '\n')) {
+                if (!buffer.empty() && !(buffer.length() == 1 && buffer[0] == '\n')) {
                     params.input_suffix = buffer;
                 }
                 buffer.clear();
@@ -640,7 +617,7 @@ int main(int argc, char ** argv) {
                     process_escapes(params.input_suffix);
                 }
                 suff_rm_leading_spc = params.escape;
-                if (suff_rm_leading_spc && params.input_suffix.find_first_of(" ") == 0 && params.input_suffix.size() > 1) {
+                if (suff_rm_leading_spc && params.input_suffix.find_first_of(' ') == 0 && params.input_suffix.size() > 1) {
                     params.input_suffix.erase(0, 1);
                     suff_rm_leading_spc = false;
                 }
@@ -667,7 +644,7 @@ int main(int argc, char ** argv) {
                 is_interacting = false;
             }
             // deal with end of text token in interactive mode
-            else if (ctx_sampling->prev.back() == llama_token_eos(ctx)) {
+            else if (llama_sampling_last(ctx_sampling) == llama_token_eos(ctx)) {
                 LOG("found EOS token\n");
 
                 if (params.interactive) {
@@ -740,15 +717,7 @@ int main(int argc, char ** argv) {
 
             if (n_past > 0) {
                 if (is_interacting) {
-                    // reset grammar state if we're restarting generation
-                    if (grammar != NULL) {
-                        llama_grammar_free(grammar);
-
-                        std::vector<const llama_grammar_element *> grammar_rules(parsed_grammar.c_rules());
-                        grammar = llama_grammar_init(
-                            grammar_rules.data(), grammar_rules.size(),
-                            parsed_grammar.symbol_ids.at("root"));
-                    }
+                    llama_sampling_reset(ctx_sampling);
                 }
                 is_interacting = false;
             }
@@ -778,9 +747,7 @@ int main(int argc, char ** argv) {
     llama_free(ctx);
     llama_free_model(model);
 
-    if (grammar != NULL) {
-        llama_grammar_free(grammar);
-    }
+    llama_sampling_free(ctx_sampling);
     llama_backend_free();
 
 #ifndef LOG_DISABLE_LOGS
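
For context, the diff above swaps the example's hand-rolled grammar plumbing for the llama_sampling_* helpers used throughout this tree. A minimal sketch of that lifecycle, assuming a valid llama_context * ctx and already-tokenized input (prompt_tokens is a placeholder for illustration, not part of this commit):

    // create the sampling state from the parsed sampling params
    struct llama_sampling_context * ctx_sampling = llama_sampling_init(sparams);

    // feed prompt tokens so repetition penalties can see them;
    // grammar rules are not applied to the prompt (last argument = false)
    for (const llama_token tok : prompt_tokens) {
        llama_sampling_accept(ctx_sampling, ctx, tok, false);
    }

    // sample one token, then accept it with the grammar applied (last argument = true);
    // the guidance context may be NULL when classifier-free guidance is unused
    const llama_token id = llama_sampling_sample(ctx_sampling, ctx, NULL);
    llama_sampling_accept(ctx_sampling, ctx, id, true);

    // query the most recently accepted token, as the EOS/EOT checks above do
    const llama_token last = llama_sampling_last(ctx_sampling);

    // on interactive restart, one reset call replaces the manual grammar re-init
    llama_sampling_reset(ctx_sampling);

    // teardown replaces the manual llama_grammar_free
    llama_sampling_free(ctx_sampling);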