summaryrefslogtreecommitdiff
path: root/examples/server
diff options
context:
space:
mode:
Diffstat (limited to 'examples/server')
-rw-r--r--examples/server/CMakeLists.txt8
-rw-r--r--examples/server/server.cpp54
2 files changed, 60 insertions, 2 deletions
diff --git a/examples/server/CMakeLists.txt b/examples/server/CMakeLists.txt
index 33fd25f0..20ddc5c5 100644
--- a/examples/server/CMakeLists.txt
+++ b/examples/server/CMakeLists.txt
@@ -37,7 +37,13 @@ install(TARGETS ${TARGET} RUNTIME)
target_compile_definitions(${TARGET} PRIVATE
SERVER_VERBOSE=$<BOOL:${LLAMA_SERVER_VERBOSE}>
)
-
+if (MSVC)
+ target_link_options(${TARGET} PRIVATE
+ $<$<CONFIG:DEBUG>:/STACK:20971520,1048576 >
+ $<$<CONFIG:RELEASE>:/STACK:20971520,1048576>
+ )
+endif()
+# target_link_libraries(${TARGET} PRIVATE "/STACK:104857600")
target_include_directories(${TARGET} PRIVATE ${CMAKE_SOURCE_DIR})
target_link_libraries(${TARGET} PRIVATE common ${CMAKE_THREAD_LIBS_INIT})
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 31f7383e..563570ad 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -977,6 +977,10 @@ struct server_context {
slot.sparams.penalty_repeat = json_value(data, "repeat_penalty", default_sparams.penalty_repeat);
slot.sparams.penalty_freq = json_value(data, "frequency_penalty", default_sparams.penalty_freq);
slot.sparams.penalty_present = json_value(data, "presence_penalty", default_sparams.penalty_present);
+ slot.sparams.dry_multiplier = json_value(data, "dry_multiplier", default_sparams.dry_multiplier);
+ slot.sparams.dry_base = json_value(data, "dry_base", default_sparams.dry_base);
+ slot.sparams.dry_allowed_length = json_value(data, "dry_allowed_length", default_sparams.dry_allowed_length);
+ slot.sparams.dry_penalty_last_n = json_value(data, "dry_penalty_last_n", default_sparams.dry_penalty_last_n);
slot.sparams.mirostat = json_value(data, "mirostat", default_sparams.mirostat);
slot.sparams.mirostat_tau = json_value(data, "mirostat_tau", default_sparams.mirostat_tau);
slot.sparams.mirostat_eta = json_value(data, "mirostat_eta", default_sparams.mirostat_eta);
@@ -987,6 +991,42 @@ struct server_context {
slot.sparams.n_probs = json_value(data, "n_probs", default_sparams.n_probs);
slot.sparams.min_keep = json_value(data, "min_keep", default_sparams.min_keep);
+ if (slot.sparams.penalty_last_n < -1) {
+ throw std::runtime_error("Error: repeat_last_n must be >= -1");
+ }
+
+ if (slot.sparams.dry_penalty_last_n < -1) {
+ throw std::runtime_error("Error: dry_penalty_last_n must be >= -1");
+ }
+
+ if (slot.sparams.penalty_last_n == -1) {
+ // note: should be the slot's context and not the full context, but it's ok
+ slot.sparams.penalty_last_n = llama_n_ctx(ctx);
+ }
+
+ if (slot.sparams.dry_penalty_last_n == -1) {
+ slot.sparams.dry_penalty_last_n = llama_n_ctx(ctx);
+
+ }
+ if (slot.sparams.dry_base < 1.0f)
+ {
+ slot.sparams.dry_base = default_sparams.dry_base;
+ }
+
+ // sequence breakers for DRY
+ {
+ // Currently, this is not compatible with TextGen WebUI, Koboldcpp and SillyTavern format
+ // Ref: https://github.com/oobabooga/text-generation-webui/blob/d1af7a41ade7bd3c3a463bfa640725edb818ebaf/extensions/openai/typing.py#L39
+
+ if (data.contains("dry_sequence_breakers")) {
+ slot.sparams.dry_sequence_breakers = json_value(data, "dry_sequence_breakers", std::vector<std::string>());
+ if (slot.sparams.dry_sequence_breakers.empty()) {
+ send_error(task, "Error: dry_sequence_breakers must be a non-empty array of strings", ERROR_TYPE_INVALID_REQUEST);
+ return false;
+ }
+ }
+ }
+
// process "json_schema" and "grammar"
if (data.contains("json_schema") && !data.at("json_schema").is_null() && data.contains("grammar") && !data.at("grammar").is_null()) {
send_error(task, "Either \"json_schema\" or \"grammar\" can be specified, but not both", ERROR_TYPE_INVALID_REQUEST);
@@ -1156,7 +1196,7 @@ struct server_context {
if (slot.ctx_sampling != nullptr) {
llama_sampling_free(slot.ctx_sampling);
}
- slot.ctx_sampling = llama_sampling_init(slot.sparams);
+ slot.ctx_sampling = llama_sampling_init(llama_get_model_vocab(model),slot.sparams);
if (slot.ctx_sampling == nullptr) {
// for now, the only error that may happen here is invalid grammar
send_error(task, "Failed to parse grammar", ERROR_TYPE_INVALID_REQUEST);
@@ -1405,6 +1445,11 @@ struct server_context {
{"frequency_penalty", slot.sparams.penalty_freq},
{"penalty_prompt_tokens", slot.sparams.penalty_prompt_tokens},
{"use_penalty_prompt_tokens", slot.sparams.use_penalty_prompt_tokens},
+ {"dry_multiplier", slot.sparams.dry_multiplier},
+ {"dry_base", slot.sparams.dry_base},
+ {"dry_allowed_length", slot.sparams.dry_allowed_length},
+ {"dry_penalty_last_n", slot.sparams.dry_penalty_last_n},
+ {"dry_sequence_breakers", slot.sparams.dry_sequence_breakers},
{"mirostat", slot.sparams.mirostat},
{"mirostat_tau", slot.sparams.mirostat_tau},
{"mirostat_eta", slot.sparams.mirostat_eta},
@@ -2337,6 +2382,13 @@ struct server_context {
slot.command = SLOT_COMMAND_NONE;
GGML_ASSERT(batch.n_tokens > 0);
+ llama_sampling_reset(slot.ctx_sampling);
+ for (int i = 0; i < slot.n_prompt_tokens; ++i) {
+ llama_token id = slot.prompt_tokens[i];
+ if (id != LLAMA_TOKEN_NULL) {
+ llama_sampling_accept(slot.ctx_sampling, ctx, id, false);
+ }
+ }
// extract the logits only for the last token
batch.logits[batch.n_tokens - 1] = true;