path: root/examples/llama.swiftui
Diffstat (limited to 'examples/llama.swiftui')
-rw-r--r--  examples/llama.swiftui/llama.cpp.swift/LibLlama.swift          | 8
-rw-r--r--  examples/llama.swiftui/llama.swiftui/Models/LlamaState.swift   | 2
2 files changed, 6 insertions, 4 deletions
diff --git a/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift b/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift
index 737f882f..58c32ca5 100644
--- a/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift
+++ b/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift
@@ -26,11 +26,12 @@ actor LlamaContext {
private var context: OpaquePointer
private var batch: llama_batch
private var tokens_list: [llama_token]
+ var is_done: Bool = false
/// This variable is used to store temporarily invalid cchars
private var temporary_invalid_cchars: [CChar]
- var n_len: Int32 = 64
+ var n_len: Int32 = 1024
var n_cur: Int32 = 0
var n_decode: Int32 = 0
@@ -160,6 +161,7 @@ actor LlamaContext {
if llama_token_is_eog(model, new_token_id) || n_cur == n_len {
print("\n")
+ is_done = true
let new_token_str = String(cString: temporary_invalid_cchars + [0])
temporary_invalid_cchars.removeAll()
return new_token_str
@@ -322,7 +324,7 @@ actor LlamaContext {
defer {
result.deallocate()
}
- let nTokens = llama_token_to_piece(model, token, result, 8, false)
+ let nTokens = llama_token_to_piece(model, token, result, 8, 0, false)
if nTokens < 0 {
let newResult = UnsafeMutablePointer<Int8>.allocate(capacity: Int(-nTokens))
@@ -330,7 +332,7 @@ actor LlamaContext {
defer {
newResult.deallocate()
}
- let nNewTokens = llama_token_to_piece(model, token, newResult, -nTokens, false)
+ let nNewTokens = llama_token_to_piece(model, token, newResult, -nTokens, 0, false)
let bufferPointer = UnsafeBufferPointer(start: newResult, count: Int(nNewTokens))
return Array(bufferPointer)
} else {
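Note: the two llama_token_to_piece changes above only adapt the call sites to a newer C API that takes one extra integer argument (passed as 0 here; in recent llama.cpp headers this is the lstrip parameter). Below is a minimal sketch of the resulting pattern: convert one token to its UTF-8 bytes with a small fixed buffer, and retry with a larger buffer when the first call reports (as a negative count) that 8 bytes were not enough. It assumes the same bridged llama C API as LibLlama.swift; the helper name pieceBytes is illustrative and not part of the example code.

    // Sketch only: mirrors the retry-on-negative-length pattern from the diff.
    private func pieceBytes(model: OpaquePointer, token: llama_token) -> [CChar] {
        let result = UnsafeMutablePointer<Int8>.allocate(capacity: 8)
        result.initialize(repeating: Int8(0), count: 8)
        defer { result.deallocate() }

        let nTokens = llama_token_to_piece(model, token, result, 8, 0, false)
        if nTokens < 0 {
            // The piece needs -nTokens bytes; allocate exactly that much and retry.
            let newResult = UnsafeMutablePointer<Int8>.allocate(capacity: Int(-nTokens))
            newResult.initialize(repeating: Int8(0), count: Int(-nTokens))
            defer { newResult.deallocate() }
            let nNewTokens = llama_token_to_piece(model, token, newResult, -nTokens, 0, false)
            return Array(UnsafeBufferPointer(start: newResult, count: Int(nNewTokens)))
        } else {
            return Array(UnsafeBufferPointer(start: result, count: Int(nTokens)))
        }
    }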
diff --git a/examples/llama.swiftui/llama.swiftui/Models/LlamaState.swift b/examples/llama.swiftui/llama.swiftui/Models/LlamaState.swift
index 2c1e3f61..b8f6a31d 100644
--- a/examples/llama.swiftui/llama.swiftui/Models/LlamaState.swift
+++ b/examples/llama.swiftui/llama.swiftui/Models/LlamaState.swift
@@ -132,7 +132,7 @@ class LlamaState: ObservableObject {
messageLog += "\(text)"
Task.detached {
- while await llamaContext.n_cur < llamaContext.n_len {
+ while await !llamaContext.is_done {
let result = await llamaContext.completion_loop()
await MainActor.run {
self.messageLog += "\(result)"
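The LlamaState change keys the UI loop off the new is_done flag instead of re-deriving the stop condition from n_cur and n_len; the actor sets the flag when it sees an end-of-generation token or hits n_len, so generation can stop before n_cur reaches n_len. Below is a minimal, self-contained sketch of that interplay; GenerationContext is a hypothetical stand-in for LlamaContext, not code from this repository.

    // Sketch only: producer flips is_done, consumer polls it across the actor boundary.
    actor GenerationContext {
        var is_done = false
        private var produced = 0
        private let limit = 1024   // stands in for n_len

        func completion_loop() -> String {
            produced += 1
            if produced >= limit {  // the real code also stops on an end-of-generation token
                is_done = true
                return "\n"
            }
            return "token\(produced) "
        }
    }

    // Usage, mirroring the detached task in LlamaState.complete():
    let context = GenerationContext()
    Task.detached {
        while await !context.is_done {
            let piece = await context.completion_loop()
            print(piece, terminator: "")
        }
    }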