diff options
Diffstat (limited to 'examples/server/server.cpp')
-rw-r--r-- | examples/server/server.cpp | 32 |
1 files changed, 24 insertions, 8 deletions
diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 3172d96d..895d608f 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -147,7 +147,7 @@ struct server_slot { int32_t n_decoded = 0; int32_t n_remaining = -1; int32_t i_batch = -1; - int32_t n_predict = -1; + int32_t n_predict = -1; // TODO: disambiguate from params.n_predict int32_t n_prompt_tokens = 0; int32_t n_prompt_tokens_processed = 0; @@ -739,7 +739,13 @@ struct server_context { default_generation_settings_for_props = get_formated_generation(slots.front()); default_generation_settings_for_props["seed"] = -1; - batch = llama_batch_init(n_ctx, 0, params.n_parallel); + // the update_slots() logic will always submit a maximum of n_batch tokens + // note that n_batch can be > n_ctx (e.g. for non-causal attention models such as BERT where the KV cache is not used) + { + const int32_t n_batch = llama_n_batch(ctx); + + batch = llama_batch_init(n_batch, 0, params.n_parallel); + } metrics.init(); } @@ -1036,8 +1042,10 @@ struct server_context { llama_batch_add(batch, system_tokens[i], i, { 0 }, false); } - for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += params.n_batch) { - const int32_t n_tokens = std::min(params.n_batch, (int32_t) (batch.n_tokens - i)); + const int32_t n_batch = llama_n_batch(ctx); + + for (int32_t i = 0; i < batch.n_tokens; i += n_batch) { + const int32_t n_tokens = std::min(params.n_batch, batch.n_tokens - i); llama_batch batch_view = { n_tokens, batch.token + i, @@ -1226,7 +1234,7 @@ struct server_context { {"mirostat_eta", slot.sparams.mirostat_eta}, {"penalize_nl", slot.sparams.penalize_nl}, {"stop", slot.params.antiprompt}, - {"n_predict", slot.params.n_predict}, + {"n_predict", slot.params.n_predict}, // TODO: fix duplicate key n_predict {"n_keep", params.n_keep}, {"ignore_eos", ignore_eos}, {"stream", slot.params.stream}, @@ -1738,7 +1746,8 @@ struct server_context { } // process in chunks of params.n_batch - int32_t n_batch = params.n_batch; + int32_t n_batch = llama_n_batch(ctx); + int32_t n_ubatch = llama_n_ubatch(ctx); // next, batch any pending prompts without exceeding n_batch if (params.cont_batching || batch.n_tokens == 0) { @@ -1811,7 +1820,7 @@ struct server_context { if (slot.embedding) { // this prompt is too large to process - discard it - if (slot.n_prompt_tokens > n_batch) { + if (slot.n_prompt_tokens > n_ubatch) { slot.state = SLOT_STATE_PROCESSING; slot.command = SLOT_COMMAND_NONE; slot.release(); @@ -2157,7 +2166,8 @@ static void server_print_usage(const char * argv0, const gpt_params & params, co printf(" --pooling {none,mean,cls} pooling type for embeddings, use model default if unspecified\n"); printf(" -dt N, --defrag-thold N\n"); printf(" KV cache defragmentation threshold (default: %.1f, < 0 - disabled)\n", params.defrag_thold); - printf(" -b N, --batch-size N batch size for prompt processing (default: %d)\n", params.n_batch); + printf(" -b N, --batch-size N logical maximum batch size (default: %d)\n", params.n_batch); + printf(" -ub N, --ubatch-size N physical maximum batch size (default: %d)\n", params.n_ubatch); printf(" --memory-f32 use f32 instead of f16 for memory key+value (default: disabled)\n"); printf(" not recommended: doubles context memory required and no measurable increase in quality\n"); if (llama_supports_mlock()) { @@ -2424,6 +2434,12 @@ static void server_params_parse(int argc, char ** argv, server_params & sparams, break; } params.n_batch = std::stoi(argv[i]); + } else if (arg == "-ub" || arg == "--ubatch-size") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.n_ubatch = std::stoi(argv[i]); } else if (arg == "--gpu-layers" || arg == "-ngl" || arg == "--n-gpu-layers") { if (++i >= argc) { invalid_param = true; |