summaryrefslogtreecommitdiff
path: root/examples/llama.android/app/src/main/java/com
diff options
context:
space:
mode:
authorBrian <mofosyne@gmail.com>2024-05-14 23:10:39 +1000
committerGitHub <noreply@github.com>2024-05-14 16:10:39 +0300
commit1265c670fd8e41e1947352c96c5179adda97fb2c (patch)
tree8ef6858e5701e5ff6e7a4e656e26589af5067963 /examples/llama.android/app/src/main/java/com
parent5e31828d3e35c76ecfee665bc23771a4bec1d130 (diff)
Revert "move ndk code to a new library (#6951)" (#7282)
This reverts commit efc8f767c8c8c749a245dd96ad4e2f37c164b54c.
Diffstat (limited to 'examples/llama.android/app/src/main/java/com')
-rw-r--r-- examples/llama.android/app/src/main/java/com/example/llama/Llm.kt | 172
-rw-r--r-- examples/llama.android/app/src/main/java/com/example/llama/MainViewModel.kt | 13
2 files changed, 178 insertions, 7 deletions
diff --git a/examples/llama.android/app/src/main/java/com/example/llama/Llm.kt b/examples/llama.android/app/src/main/java/com/example/llama/Llm.kt
new file mode 100644
index 00000000..d86afee3
--- /dev/null
+++ b/examples/llama.android/app/src/main/java/com/example/llama/Llm.kt
@@ -0,0 +1,172 @@
+package com.example.llama
+
+import android.util.Log
+import kotlinx.coroutines.CoroutineDispatcher
+import kotlinx.coroutines.asCoroutineDispatcher
+import kotlinx.coroutines.flow.Flow
+import kotlinx.coroutines.flow.flow
+import kotlinx.coroutines.flow.flowOn
+import kotlinx.coroutines.withContext
+import java.util.concurrent.Executors
+import kotlin.concurrent.thread
+
+/**
+ * Kotlin-side wrapper around the llama.cpp JNI bindings (native library "llama-android").
+ *
+ * Every native call is dispatched onto [runLoop], a coroutine dispatcher backed by a
+ * single dedicated thread, so the native code only ever executes on one thread.
+ * Obtain the process-wide singleton via [instance].
+ */
+class Llm {
+ // Log tag; KClass.simpleName is nullable, hence String?.
+ private val tag: String? = this::class.simpleName
+
+ // Current load state. Because runLoop is single-threaded, this ThreadLocal
+ // effectively pins the state to that one dedicated native thread.
+ private val threadLocalState: ThreadLocal<State> = ThreadLocal.withInitial { State.Idle }
+
+ // Single-threaded dispatcher for all native calls. The thread factory loads the
+ // native library, installs the log handler and initializes the backend before the
+ // executor starts serving work; uncaught exceptions on the thread are logged
+ // instead of crashing the process.
+ private val runLoop: CoroutineDispatcher = Executors.newSingleThreadExecutor {
+ thread(start = false, name = "Llm-RunLoop") {
+ Log.d(tag, "Dedicated thread for native code: ${Thread.currentThread().name}")
+
+ // No-op if called more than once.
+ System.loadLibrary("llama-android")
+
+ // Set llama log handler to Android
+ log_to_android()
+ backend_init(false)
+
+ Log.d(tag, system_info())
+
+ // Hand control to the executor's worker loop.
+ it.run()
+ }.apply {
+ uncaughtExceptionHandler = Thread.UncaughtExceptionHandler { _, exception: Throwable ->
+ Log.e(tag, "Unhandled exception", exception)
+ }
+ }
+ }.asCoroutineDispatcher()
+
+ // Maximum token position produced by send(); its loop runs while ncur <= nlen.
+ private val nlen: Int = 64
+
+ // JNI bindings implemented by the "llama-android" native library. Functions that
+ // return Long hand back opaque native handles; 0L signals failure (see load()).
+ private external fun log_to_android()
+ private external fun load_model(filename: String): Long
+ private external fun free_model(model: Long)
+ private external fun new_context(model: Long): Long
+ private external fun free_context(context: Long)
+ private external fun backend_init(numa: Boolean)
+ private external fun backend_free()
+ private external fun free_batch(batch: Long)
+ private external fun new_batch(nTokens: Int, embd: Int, nSeqMax: Int): Long
+ private external fun bench_model(
+ context: Long,
+ model: Long,
+ batch: Long,
+ pp: Int,
+ tg: Int,
+ pl: Int,
+ nr: Int
+ ): String
+
+ private external fun system_info(): String
+
+ private external fun completion_init(
+ context: Long,
+ batch: Long,
+ text: String,
+ nLen: Int
+ ): Int
+
+ // Returns the next generated text piece, or null when generation is finished.
+ // NOTE(review): ncur is presumably advanced by the native side via IntVar.inc()
+ // (inc() has no Kotlin caller) — confirm against the JNI implementation.
+ private external fun completion_loop(
+ context: Long,
+ batch: Long,
+ nLen: Int,
+ ncur: IntVar
+ ): String?
+
+ private external fun kv_cache_clear(context: Long)
+
+ /**
+ * Runs the native benchmark on the loaded model and returns its textual report.
+ *
+ * All parameters are passed straight through to bench_model(); nr defaults to 1.
+ *
+ * @throws IllegalStateException if no model is loaded.
+ */
+ suspend fun bench(pp: Int, tg: Int, pl: Int, nr: Int = 1): String {
+ return withContext(runLoop) {
+ when (val state = threadLocalState.get()) {
+ is State.Loaded -> {
+ Log.d(tag, "bench(): $state")
+ bench_model(state.context, state.model, state.batch, pp, tg, pl, nr)
+ }
+
+ else -> throw IllegalStateException("No model loaded")
+ }
+ }
+ }
+
+ /**
+ * Loads the model file at [pathToModel], creates a context and a 512-token batch,
+ * and transitions the state from Idle to Loaded.
+ *
+ * @throws IllegalStateException if a model is already loaded or any native step
+ * returns a null (0L) handle.
+ */
+ suspend fun load(pathToModel: String) {
+ withContext(runLoop) {
+ when (threadLocalState.get()) {
+ is State.Idle -> {
+ val model = load_model(pathToModel)
+ if (model == 0L) throw IllegalStateException("load_model() failed")
+
+ val context = new_context(model)
+ if (context == 0L) throw IllegalStateException("new_context() failed")
+
+ val batch = new_batch(512, 0, 1)
+ if (batch == 0L) throw IllegalStateException("new_batch() failed")
+
+ Log.i(tag, "Loaded model $pathToModel")
+ threadLocalState.set(State.Loaded(model, context, batch))
+ }
+ else -> throw IllegalStateException("Model already loaded")
+ }
+ }
+ }
+
+ /**
+ * Runs completion for [message], emitting each generated text piece as produced.
+ *
+ * Generation stops when completion_loop() returns null or the token position
+ * exceeds [nlen]; the KV cache is cleared afterwards. If no model is loaded the
+ * flow completes without emitting anything (the else branch is a no-op).
+ */
+ fun send(message: String): Flow<String> = flow {
+ when (val state = threadLocalState.get()) {
+ is State.Loaded -> {
+ val ncur = IntVar(completion_init(state.context, state.batch, message, nlen))
+ while (ncur.value <= nlen) {
+ val str = completion_loop(state.context, state.batch, nlen, ncur)
+ if (str == null) {
+ break
+ }
+ emit(str)
+ }
+ // Reset native KV cache so the next send() starts fresh.
+ kv_cache_clear(state.context)
+ }
+ else -> {}
+ }
+ }.flowOn(runLoop) // run the whole flow on the dedicated native thread
+
+ /**
+ * Unloads the model and frees resources.
+ *
+ * Frees the context, model and batch handles and resets the state to Idle.
+ * This is a no-op if there's no model loaded.
+ * NOTE(review): backend_free() is never called here, so the native backend stays
+ * initialized for the lifetime of the process — confirm that is intended.
+ */
+ suspend fun unload() {
+ withContext(runLoop) {
+ when (val state = threadLocalState.get()) {
+ is State.Loaded -> {
+ free_context(state.context)
+ free_model(state.model)
+ free_batch(state.batch)
+
+ threadLocalState.set(State.Idle)
+ }
+ else -> {}
+ }
+ }
+ }
+
+ companion object {
+ // Mutable int box passed across JNI so native code can report the current token
+ // position back to send()'s loop. @Volatile plus the synchronized inc() make
+ // updates visible across threads.
+ private class IntVar(value: Int) {
+ @Volatile
+ var value: Int = value
+ private set
+
+ fun inc() {
+ synchronized(this) {
+ value += 1
+ }
+ }
+ }
+
+ // Load-state machine: Idle (nothing loaded) or Loaded holding the three native handles.
+ private sealed interface State {
+ data object Idle: State
+ data class Loaded(val model: Long, val context: Long, val batch: Long): State
+ }
+
+ // Enforce only one instance of Llm.
+ private val _instance: Llm = Llm()
+
+ fun instance(): Llm = _instance
+ }
+}
diff --git a/examples/llama.android/app/src/main/java/com/example/llama/MainViewModel.kt b/examples/llama.android/app/src/main/java/com/example/llama/MainViewModel.kt
index 45ac2993..be95e222 100644
--- a/examples/llama.android/app/src/main/java/com/example/llama/MainViewModel.kt
+++ b/examples/llama.android/app/src/main/java/com/example/llama/MainViewModel.kt
@@ -1,6 +1,5 @@
package com.example.llama
-import android.llama.cpp.LLamaAndroid
import android.util.Log
import androidx.compose.runtime.getValue
import androidx.compose.runtime.mutableStateOf
@@ -10,7 +9,7 @@ import androidx.lifecycle.viewModelScope
import kotlinx.coroutines.flow.catch
import kotlinx.coroutines.launch
-class MainViewModel(private val llamaAndroid: LLamaAndroid = LLamaAndroid.instance()): ViewModel() {
+class MainViewModel(private val llm: Llm = Llm.instance()): ViewModel() {
companion object {
@JvmStatic
private val NanosPerSecond = 1_000_000_000.0
@@ -29,7 +28,7 @@ class MainViewModel(private val llamaAndroid: LLamaAndroid = LLamaAndroid.instan
viewModelScope.launch {
try {
- llamaAndroid.unload()
+ llm.unload()
} catch (exc: IllegalStateException) {
messages += exc.message!!
}
@@ -45,7 +44,7 @@ class MainViewModel(private val llamaAndroid: LLamaAndroid = LLamaAndroid.instan
messages += ""
viewModelScope.launch {
- llamaAndroid.send(text)
+ llm.send(text)
.catch {
Log.e(tag, "send() failed", it)
messages += it.message!!
@@ -58,7 +57,7 @@ class MainViewModel(private val llamaAndroid: LLamaAndroid = LLamaAndroid.instan
viewModelScope.launch {
try {
val start = System.nanoTime()
- val warmupResult = llamaAndroid.bench(pp, tg, pl, nr)
+ val warmupResult = llm.bench(pp, tg, pl, nr)
val end = System.nanoTime()
messages += warmupResult
@@ -71,7 +70,7 @@ class MainViewModel(private val llamaAndroid: LLamaAndroid = LLamaAndroid.instan
return@launch
}
- messages += llamaAndroid.bench(512, 128, 1, 3)
+ messages += llm.bench(512, 128, 1, 3)
} catch (exc: IllegalStateException) {
Log.e(tag, "bench() failed", exc)
messages += exc.message!!
@@ -82,7 +81,7 @@ class MainViewModel(private val llamaAndroid: LLamaAndroid = LLamaAndroid.instan
fun load(pathToModel: String) {
viewModelScope.launch {
try {
- llamaAndroid.load(pathToModel)
+ llm.load(pathToModel)
messages += "Loaded $pathToModel"
} catch (exc: IllegalStateException) {
Log.e(tag, "load() failed", exc)