path: root/llama.h
author    Georgi Gerganov <ggerganov@gmail.com>  2023-10-18 16:21:57 +0300
committer GitHub <noreply@github.com>            2023-10-18 16:21:57 +0300
commit    0e89203b517c95ec6675eda75d200a60d1e8921d (patch)
tree      3aba40ef0362d061f240bd43c52e86a8f728f89d /llama.h
parent    c67fe68e417f766970fb1feaf2e66458aa24116a (diff)
speculative : add tree-based sampling example (#3624)
* sampling : one sequence per sampling context ggml-ci
* speculative : add tree-based sampling support ggml-ci
* speculative : reuse the n_parallel CLI param
* speculative : refactor sampling
* examples : fix build after sampling refactoring ggml-ci
* batched : fix n_seq_id
* sampling : fix malloc ggml-ci
* swift : fix build ggml-ci
* swift : try to fix build ggml-ci
* prompts : add assistant.txt
* common : add llama_batch_add() and llama_batch_clear() helpers
* speculative : minor refactor ggml-ci
* minor : comments + rename ggml-ci
* speculative : fix off-by-one for n_drafted
* speculative : fix the n_drafted fix + p constants
Diffstat (limited to 'llama.h')
-rw-r--r--  llama.h  17
1 file changed, 10 insertions, 7 deletions
diff --git a/llama.h b/llama.h
index b13f2312..51010e03 100644
--- a/llama.h
+++ b/llama.h
@@ -133,11 +133,12 @@ extern "C" {
     typedef struct llama_batch {
         int32_t n_tokens;
 
-        llama_token  * token;
-        float        * embd;
-        llama_pos    * pos;
-        llama_seq_id * seq_id;
-        int8_t       * logits;
+        llama_token  *  token;
+        float        *  embd;
+        llama_pos    *  pos;
+        int32_t      *  n_seq_id;
+        llama_seq_id ** seq_id;
+        int8_t       *  logits;
 
         // NOTE: helpers for smooth API transition - can be deprecated in the future
         // for future-proof code, use the above fields instead and ignore everything below
@@ -446,7 +447,8 @@ extern "C" {
             llama_pos    pos_0,
             llama_seq_id seq_id);
 
-    // Allocates a batch of tokens on the heap
+    // Allocates a batch of tokens on the heap that can hold a maximum of n_tokens
+    // Each token can be assigned up to n_seq_max sequence ids
     // The batch has to be freed with llama_batch_free()
     // If embd != 0, llama_batch.embd will be allocated with size of n_tokens * embd * sizeof(float)
     // Otherwise, llama_batch.token will be allocated to store n_tokens llama_token
@@ -454,7 +456,8 @@ extern "C" {
     // All members are left uninitialized
     LLAMA_API struct llama_batch llama_batch_init(
             int32_t n_tokens,
-            int32_t embd);
+            int32_t embd,
+            int32_t n_seq_max);
 
     // Frees a batch of tokens allocated with llama_batch_init()
     LLAMA_API void llama_batch_free(struct llama_batch batch);
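
For illustration, a minimal sketch of how the changed API can be used after this commit. The token id, positions, and sequence ids below are hypothetical placeholders, and the fields are populated by hand because the llama_batch_add()/llama_batch_clear() helpers added in this commit live in common, not in llama.h:

    #include "llama.h"

    int main(void) {
        // Room for up to 32 tokens; embd == 0 means token ids, not embeddings;
        // each token may be assigned up to 4 sequence ids (n_seq_max)
        struct llama_batch batch = llama_batch_init(32, 0, 4);

        // llama_batch_init() leaves all members uninitialized
        batch.n_tokens = 0;

        // Add one token shared by sequences 0 and 1 - e.g. a common prefix
        // in tree-based speculative decoding (token id 1 is a placeholder)
        const int i = batch.n_tokens;
        batch.token   [i]    = 1;
        batch.pos     [i]    = 0;
        batch.n_seq_id[i]    = 2;  // this token belongs to 2 sequences
        batch.seq_id  [i][0] = 0;
        batch.seq_id  [i][1] = 1;
        batch.logits  [i]    = 1;  // request logits for this token
        batch.n_tokens++;

        // ... llama_decode(ctx, batch) would be called here ...

        llama_batch_free(batch);
        return 0;
    }

Before this change, seq_id was a flat llama_seq_id * with one sequence id per token; promoting it to llama_seq_id ** with a per-token n_seq_id count is what lets several draft branches share a common prefix token in a single batch.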