summaryrefslogtreecommitdiff
path: root/llama.h
diff options
context:
space:
mode:
authorslaren <slarengh@gmail.com>2024-03-13 18:54:21 +0100
committerGitHub <noreply@github.com>2024-03-13 18:54:21 +0100
commitf30ea47a87ed4446ad55adb265755dc9102956a2 (patch)
treefc885962ca3d537cfdfbd6b4a2820b7c864b1ee0 /llama.h
parentd8fd0ccf6ac8b07791ffd1575eed436930854ae3 (diff)
llama : add pipeline parallelism support (#6017)
* llama : add pipeline parallelism support for batch processing with multiple CUDA GPUs ggml-ci * server : add -ub, --ubatch-size parameter * fix server embedding test * llama : fix Mamba inference for pipeline parallelism Tested to work correctly with both `main` and `parallel` examples. * llama : limit max batch size to n_batch * add LLAMA_SCHED_MAX_COPIES to configure the number of input copies for pipeline parallelism default increase to 4 (from 2) changing this value may improve performance for some systems, but increases memory usage * fix hip build * fix sycl build (disable cpy_tensor_async) * fix hip build * llama : limit n_batch and n_ubatch to n_ctx during context creation * llama : fix norm backend * batched-bench : sync after decode * swiftui : sync after decode * ggml : allow ggml_get_rows to use multiple threads if they are available * check n_ubatch >= n_tokens with non-casual attention * llama : do not limit n_batch to n_ctx with non-casual attn * server : construct batch with size of llama_n_batch * ggml_backend_cpu_graph_compute : fix return value when alloc fails * llama : better n_batch and n_ubatch comment * fix merge * small fix * reduce default n_batch to 2048 --------- Co-authored-by: Francis Couture-Harpin <git@compilade.net> Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Diffstat (limited to 'llama.h')
-rw-r--r--llama.h9
1 files changed, 8 insertions, 1 deletions
diff --git a/llama.h b/llama.h
index 446899da..2d16cc9b 100644
--- a/llama.h
+++ b/llama.h
@@ -234,7 +234,8 @@ extern "C" {
struct llama_context_params {
uint32_t seed; // RNG seed, -1 for random
uint32_t n_ctx; // text context, 0 = from model
- uint32_t n_batch; // prompt processing maximum batch size
+ uint32_t n_batch; // logical maximum batch size that can be submitted to llama_decode
+ uint32_t n_ubatch; // physical maximum batch size
uint32_t n_seq_max; // max number of sequences (i.e. distinct states for recurrent models)
uint32_t n_threads; // number of threads to use for generation
uint32_t n_threads_batch; // number of threads to use for batch processing
@@ -377,6 +378,7 @@ extern "C" {
LLAMA_API uint32_t llama_n_ctx (const struct llama_context * ctx);
LLAMA_API uint32_t llama_n_batch (const struct llama_context * ctx);
+ LLAMA_API uint32_t llama_n_ubatch (const struct llama_context * ctx);
LLAMA_API uint32_t llama_n_seq_max (const struct llama_context * ctx);
LLAMA_API enum llama_vocab_type llama_vocab_type(const struct llama_model * model);
@@ -650,6 +652,11 @@ extern "C" {
// Set abort callback
LLAMA_API void llama_set_abort_callback(struct llama_context * ctx, ggml_abort_callback abort_callback, void * abort_callback_data);
+ // Wait until all computations are finished
+ // This is automatically done when using one of the functions below to obtain the computation results
+ // and is not necessary to call it explicitly in most cases
+ LLAMA_API void llama_synchronize(struct llama_context * ctx);
+
// Token logits obtained from the last call to llama_decode()
// The logits for the last token are stored in the last row
// Logits for which llama_batch.logits[i] == 0 are undefined