summaryrefslogtreecommitdiff
path: root/examples/batched.swift/Sources
diff options
context:
space:
mode:
authorGeorgi Gerganov <ggerganov@gmail.com>2023-10-18 16:21:57 +0300
committerGitHub <noreply@github.com>2023-10-18 16:21:57 +0300
commit0e89203b517c95ec6675eda75d200a60d1e8921d (patch)
tree3aba40ef0362d061f240bd43c52e86a8f728f89d /examples/batched.swift/Sources
parentc67fe68e417f766970fb1feaf2e66458aa24116a (diff)
speculative : add tree-based sampling example (#3624)
* sampling : one sequence per sampling context ggml-ci * speculative : add tree-based sampling support ggml-ci * speculative : reuse the n_parallel CLI param * speculative : refactor sampling * examples : fix build after sampling refactoring ggml-ci * batched : fix n_seq_id * sampling : fix malloc ggml-ci * swift : fix build ggml-ci * swift : try to fix build ggml-ci * prompts : add assistant.txt * common : add llama_batch_add() and llama_batch_clear() helpers * speculative : minor refactor ggml-ci * minor : comments + rename ggml-ci * speculative : fix off-by-one for n_drafted * speculative : fix the n_drafted fix + p constants
Diffstat (limited to 'examples/batched.swift/Sources')
-rw-r--r--examples/batched.swift/Sources/main.swift14
1 files changed, 11 insertions, 3 deletions
diff --git a/examples/batched.swift/Sources/main.swift b/examples/batched.swift/Sources/main.swift
index 05d1bb9d..77273038 100644
--- a/examples/batched.swift/Sources/main.swift
+++ b/examples/batched.swift/Sources/main.swift
@@ -69,7 +69,7 @@ for id: llama_token in tokens {
print("\n")
-var batch = llama_batch_init(max(Int32(tokens.count), Int32(n_parallel)), 0)
+var batch = llama_batch_init(max(Int32(tokens.count), Int32(n_parallel)), 0, 1)
defer {
llama_batch_free(batch)
}
@@ -80,7 +80,12 @@ batch.n_tokens = Int32(tokens.count)
for (i, token) in tokens.enumerated() {
batch.token[i] = token
batch.pos[i] = Int32(i)
- batch.seq_id[i] = 0
+ batch.n_seq_id[i] = 1
+ // batch.seq_id[i][0] = 0
+ // TODO: is this the proper way to do this?
+ if let seq_id = batch.seq_id[i] {
+ seq_id[0] = 0
+ }
batch.logits[i] = 0
}
@@ -169,7 +174,10 @@ while n_cur <= n_len {
// push this new token for next evaluation
batch.token[Int(batch.n_tokens)] = new_token_id
batch.pos[Int(batch.n_tokens)] = n_cur
- batch.seq_id[Int(batch.n_tokens)] = Int32(i)
+ batch.n_seq_id[Int(batch.n_tokens)] = 1
+ if let seq_id = batch.seq_id[Int(batch.n_tokens)] {
+ seq_id[0] = Int32(i)
+ }
batch.logits[Int(batch.n_tokens)] = 1
i_batch[i] = batch.n_tokens