diff options
author | Jan Boon <jan.boon@kaetemi.be> | 2024-03-26 16:47:43 +0800 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-03-26 10:47:43 +0200 |
commit | 3d032ece8e6973441273601ca2981130608e287d (patch) | |
tree | c1ee27234308702f8dbfb9b25ce46427a9c8ebde /examples/server/server.cpp | |
parent | e190f1fca6f60d80944f9e8709d343a025c4d245 (diff) |
server : add `n_discard` parameter (#6300)
Diffstat (limited to 'examples/server/server.cpp')
-rw-r--r-- | examples/server/server.cpp | 5 |
1 files changed, 4 insertions, 1 deletions
diff --git a/examples/server/server.cpp b/examples/server/server.cpp index c4c545c3..526de596 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -99,6 +99,7 @@ struct slot_params { uint32_t seed = -1; // RNG seed int32_t n_keep = 0; // number of tokens to keep from initial prompt + int32_t n_discard = 0; // number of tokens after n_keep that may be discarded when shifting context, 0 defaults to half int32_t n_predict = -1; // new tokens to predict std::vector<std::string> antiprompt; @@ -846,6 +847,7 @@ struct server_context { slot.sparams.mirostat_eta = json_value(data, "mirostat_eta", default_sparams.mirostat_eta); slot.sparams.penalize_nl = json_value(data, "penalize_nl", default_sparams.penalize_nl); slot.params.n_keep = json_value(data, "n_keep", slot.params.n_keep); + slot.params.n_discard = json_value(data, "n_discard", default_params.n_discard); slot.params.seed = json_value(data, "seed", default_params.seed); slot.sparams.n_probs = json_value(data, "n_probs", default_sparams.n_probs); slot.sparams.min_keep = json_value(data, "min_keep", default_sparams.min_keep); @@ -1253,6 +1255,7 @@ struct server_context { {"stop", slot.params.antiprompt}, {"n_predict", slot.params.n_predict}, // TODO: fix duplicate key n_predict {"n_keep", slot.params.n_keep}, + {"n_discard", slot.params.n_discard}, {"ignore_eos", ignore_eos}, {"stream", slot.params.stream}, {"logit_bias", slot.sparams.logit_bias}, @@ -1696,7 +1699,7 @@ struct server_context { // Shift context const int n_keep = slot.params.n_keep + add_bos_token; const int n_left = (int) system_tokens.size() + slot.n_past - n_keep; - const int n_discard = n_left / 2; + const int n_discard = slot.params.n_discard ? slot.params.n_discard : (n_left / 2); LOG_INFO("slot context shift", { {"id_slot", slot.id}, |