summaryrefslogtreecommitdiff
path: root/examples/batched.swift/Sources
diff options
context:
space:
mode:
Diffstat (limited to 'examples/batched.swift/Sources')
-rw-r--r--examples/batched.swift/Sources/main.swift14
1 files changed, 11 insertions, 3 deletions
diff --git a/examples/batched.swift/Sources/main.swift b/examples/batched.swift/Sources/main.swift
index 05d1bb9d..77273038 100644
--- a/examples/batched.swift/Sources/main.swift
+++ b/examples/batched.swift/Sources/main.swift
@@ -69,7 +69,7 @@ for id: llama_token in tokens {
print("\n")
-var batch = llama_batch_init(max(Int32(tokens.count), Int32(n_parallel)), 0)
+var batch = llama_batch_init(max(Int32(tokens.count), Int32(n_parallel)), 0, 1)
defer {
llama_batch_free(batch)
}
@@ -80,7 +80,12 @@ batch.n_tokens = Int32(tokens.count)
for (i, token) in tokens.enumerated() {
batch.token[i] = token
batch.pos[i] = Int32(i)
- batch.seq_id[i] = 0
+ batch.n_seq_id[i] = 1
+ // batch.seq_id[i][0] = 0
+ // TODO: is this the proper way to do this?
+ if let seq_id = batch.seq_id[i] {
+ seq_id[0] = 0
+ }
batch.logits[i] = 0
}
@@ -169,7 +174,10 @@ while n_cur <= n_len {
// push this new token for next evaluation
batch.token[Int(batch.n_tokens)] = new_token_id
batch.pos[Int(batch.n_tokens)] = n_cur
- batch.seq_id[Int(batch.n_tokens)] = Int32(i)
+ batch.n_seq_id[Int(batch.n_tokens)] = 1
+ if let seq_id = batch.seq_id[Int(batch.n_tokens)] {
+ seq_id[0] = Int32(i)
+ }
batch.logits[Int(batch.n_tokens)] = 1
i_batch[i] = batch.n_tokens