From 0e89203b517c95ec6675eda75d200a60d1e8921d Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 18 Oct 2023 16:21:57 +0300 Subject: speculative : add tree-based sampling example (#3624) * sampling : one sequence per sampling context ggml-ci * speculative : add tree-based sampling support ggml-ci * speculative : reuse the n_parallel CLI param * speculative : refactor sampling * examples : fix build after sampling refactoring ggml-ci * batched : fix n_seq_id * sampling : fix malloc ggml-ci * swift : fix build ggml-ci * swift : try to fix build ggml-ci * prompts : add assistant.txt * common : add llama_batch_add() and llama_batch_clear() helpers * speculative : minor refactor ggml-ci * minor : comments + rename ggml-ci * speculative : fix off-by-one for n_drafted * speculative : fix the n_drafted fix + p constants --- examples/batched.swift/Sources/main.swift | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) (limited to 'examples/batched.swift/Sources') diff --git a/examples/batched.swift/Sources/main.swift b/examples/batched.swift/Sources/main.swift index 05d1bb9d..77273038 100644 --- a/examples/batched.swift/Sources/main.swift +++ b/examples/batched.swift/Sources/main.swift @@ -69,7 +69,7 @@ for id: llama_token in tokens { print("\n") -var batch = llama_batch_init(max(Int32(tokens.count), Int32(n_parallel)), 0) +var batch = llama_batch_init(max(Int32(tokens.count), Int32(n_parallel)), 0, 1) defer { llama_batch_free(batch) } @@ -80,7 +80,12 @@ batch.n_tokens = Int32(tokens.count) for (i, token) in tokens.enumerated() { batch.token[i] = token batch.pos[i] = Int32(i) - batch.seq_id[i] = 0 + batch.n_seq_id[i] = 1 + // batch.seq_id[i][0] = 0 + // TODO: is this the proper way to do this? + if let seq_id = batch.seq_id[i] { + seq_id[0] = 0 + } batch.logits[i] = 0 } @@ -169,7 +174,10 @@ while n_cur <= n_len { // push this new token for next evaluation batch.token[Int(batch.n_tokens)] = new_token_id batch.pos[Int(batch.n_tokens)] = n_cur - batch.seq_id[Int(batch.n_tokens)] = Int32(i) + batch.n_seq_id[Int(batch.n_tokens)] = 1 + if let seq_id = batch.seq_id[Int(batch.n_tokens)] { + seq_id[0] = Int32(i) + } batch.logits[Int(batch.n_tokens)] = 1 i_batch[i] = batch.n_tokens -- cgit v1.2.3