author    Elton Kola <eltonkola@gmail.com>  2024-05-14 03:30:30 -0400
committer GitHub <noreply@github.com>       2024-05-14 17:30:30 +1000
commit    efc8f767c8c8c749a245dd96ad4e2f37c164b54c (patch)
tree      4c2910c9f7b3fcd27fbc1d04dd5188be962a8a5d /examples/llama.android/app/src/main/java/com
parent    e0f556186b6e1f2b7032a1479edf5e89e2b1bd86 (diff)
move ndk code to a new library (#6951)
Diffstat (limited to 'examples/llama.android/app/src/main/java/com')
-rw-r--r--  examples/llama.android/app/src/main/java/com/example/llama/Llm.kt            172
-rw-r--r--  examples/llama.android/app/src/main/java/com/example/llama/MainViewModel.kt   13
2 files changed, 7 insertions, 178 deletions
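
This commit deletes the in-app Llm.kt wrapper and moves the NDK/JNI code into a separate library module; MainViewModel.kt now imports android.llama.cpp.LLamaAndroid instead. Judging from the calls that survive in the diff below (instance(), load(), send(), bench(), unload()), the relocated class appears to keep the old wrapper's surface. A minimal sketch of that assumed public API, with signatures inferred from the deleted Llm.kt rather than read from the new module's source:

    // Sketch of the assumed public surface of the relocated wrapper.
    // Package and class name come from the import added in MainViewModel.kt;
    // signatures are carried over from the deleted Llm.kt and may differ
    // from the actual library module introduced by this commit.
    package android.llama.cpp

    import kotlinx.coroutines.flow.Flow

    class LLamaAndroid {
        suspend fun load(pathToModel: String) { /* load_model/new_context/new_batch over JNI */ }
        suspend fun unload() { /* free_context/free_model/free_batch */ }
        suspend fun bench(pp: Int, tg: Int, pl: Int, nr: Int = 1): String = TODO()
        fun send(message: String): Flow<String> = TODO() // streams generated text piece by piece

        companion object {
            private val _instance: LLamaAndroid = LLamaAndroid()
            fun instance(): LLamaAndroid = _instance
        }
    }
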
diff --git a/examples/llama.android/app/src/main/java/com/example/llama/Llm.kt b/examples/llama.android/app/src/main/java/com/example/llama/Llm.kt
deleted file mode 100644
index d86afee3..00000000
--- a/examples/llama.android/app/src/main/java/com/example/llama/Llm.kt
+++ /dev/null
@@ -1,172 +0,0 @@
-package com.example.llama
-
-import android.util.Log
-import kotlinx.coroutines.CoroutineDispatcher
-import kotlinx.coroutines.asCoroutineDispatcher
-import kotlinx.coroutines.flow.Flow
-import kotlinx.coroutines.flow.flow
-import kotlinx.coroutines.flow.flowOn
-import kotlinx.coroutines.withContext
-import java.util.concurrent.Executors
-import kotlin.concurrent.thread
-
-class Llm {
- private val tag: String? = this::class.simpleName
-
- private val threadLocalState: ThreadLocal<State> = ThreadLocal.withInitial { State.Idle }
-
- private val runLoop: CoroutineDispatcher = Executors.newSingleThreadExecutor {
- thread(start = false, name = "Llm-RunLoop") {
- Log.d(tag, "Dedicated thread for native code: ${Thread.currentThread().name}")
-
- // No-op if called more than once.
- System.loadLibrary("llama-android")
-
- // Set llama log handler to Android
- log_to_android()
- backend_init(false)
-
- Log.d(tag, system_info())
-
- it.run()
- }.apply {
- uncaughtExceptionHandler = Thread.UncaughtExceptionHandler { _, exception: Throwable ->
- Log.e(tag, "Unhandled exception", exception)
- }
- }
- }.asCoroutineDispatcher()
-
- private val nlen: Int = 64
-
- private external fun log_to_android()
- private external fun load_model(filename: String): Long
- private external fun free_model(model: Long)
- private external fun new_context(model: Long): Long
- private external fun free_context(context: Long)
- private external fun backend_init(numa: Boolean)
- private external fun backend_free()
- private external fun free_batch(batch: Long)
- private external fun new_batch(nTokens: Int, embd: Int, nSeqMax: Int): Long
- private external fun bench_model(
- context: Long,
- model: Long,
- batch: Long,
- pp: Int,
- tg: Int,
- pl: Int,
- nr: Int
- ): String
-
- private external fun system_info(): String
-
- private external fun completion_init(
- context: Long,
- batch: Long,
- text: String,
- nLen: Int
- ): Int
-
- private external fun completion_loop(
- context: Long,
- batch: Long,
- nLen: Int,
- ncur: IntVar
- ): String?
-
- private external fun kv_cache_clear(context: Long)
-
- suspend fun bench(pp: Int, tg: Int, pl: Int, nr: Int = 1): String {
- return withContext(runLoop) {
- when (val state = threadLocalState.get()) {
- is State.Loaded -> {
- Log.d(tag, "bench(): $state")
- bench_model(state.context, state.model, state.batch, pp, tg, pl, nr)
- }
-
- else -> throw IllegalStateException("No model loaded")
- }
- }
- }
-
- suspend fun load(pathToModel: String) {
- withContext(runLoop) {
- when (threadLocalState.get()) {
- is State.Idle -> {
- val model = load_model(pathToModel)
- if (model == 0L) throw IllegalStateException("load_model() failed")
-
- val context = new_context(model)
- if (context == 0L) throw IllegalStateException("new_context() failed")
-
- val batch = new_batch(512, 0, 1)
- if (batch == 0L) throw IllegalStateException("new_batch() failed")
-
- Log.i(tag, "Loaded model $pathToModel")
- threadLocalState.set(State.Loaded(model, context, batch))
- }
- else -> throw IllegalStateException("Model already loaded")
- }
- }
- }
-
- fun send(message: String): Flow<String> = flow {
- when (val state = threadLocalState.get()) {
- is State.Loaded -> {
- val ncur = IntVar(completion_init(state.context, state.batch, message, nlen))
- while (ncur.value <= nlen) {
- val str = completion_loop(state.context, state.batch, nlen, ncur)
- if (str == null) {
- break
- }
- emit(str)
- }
- kv_cache_clear(state.context)
- }
- else -> {}
- }
- }.flowOn(runLoop)
-
- /**
- * Unloads the model and frees resources.
- *
- * This is a no-op if there's no model loaded.
- */
- suspend fun unload() {
- withContext(runLoop) {
- when (val state = threadLocalState.get()) {
- is State.Loaded -> {
- free_context(state.context)
- free_model(state.model)
- free_batch(state.batch)
-
- threadLocalState.set(State.Idle)
- }
- else -> {}
- }
- }
- }
-
- companion object {
- private class IntVar(value: Int) {
- @Volatile
- var value: Int = value
- private set
-
- fun inc() {
- synchronized(this) {
- value += 1
- }
- }
- }
-
- private sealed interface State {
- data object Idle: State
- data class Loaded(val model: Long, val context: Long, val batch: Long): State
- }
-
- // Enforce only one instance of Llm.
- private val _instance: Llm = Llm()
-
- fun instance(): Llm = _instance
- }
-}
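
With the wrapper removed from the app module, the app presumably consumes it as a project dependency on the new library module. The module path and Gradle Kotlin DSL below are assumptions for illustration only; the build-script changes are not part of this diff:

    // Hypothetical app-module build.gradle.kts snippet; the actual module name
    // introduced by this commit is not shown here.
    dependencies {
        implementation(project(":llama"))   // assumed path of the new NDK library module
    }
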
diff --git a/examples/llama.android/app/src/main/java/com/example/llama/MainViewModel.kt b/examples/llama.android/app/src/main/java/com/example/llama/MainViewModel.kt
index be95e222..45ac2993 100644
--- a/examples/llama.android/app/src/main/java/com/example/llama/MainViewModel.kt
+++ b/examples/llama.android/app/src/main/java/com/example/llama/MainViewModel.kt
@@ -1,5 +1,6 @@
package com.example.llama
+import android.llama.cpp.LLamaAndroid
import android.util.Log
import androidx.compose.runtime.getValue
import androidx.compose.runtime.mutableStateOf
@@ -9,7 +10,7 @@ import androidx.lifecycle.viewModelScope
import kotlinx.coroutines.flow.catch
import kotlinx.coroutines.launch
-class MainViewModel(private val llm: Llm = Llm.instance()): ViewModel() {
+class MainViewModel(private val llamaAndroid: LLamaAndroid = LLamaAndroid.instance()): ViewModel() {
companion object {
@JvmStatic
private val NanosPerSecond = 1_000_000_000.0
@@ -28,7 +29,7 @@ class MainViewModel(private val llm: Llm = Llm.instance()): ViewModel() {
viewModelScope.launch {
try {
- llm.unload()
+ llamaAndroid.unload()
} catch (exc: IllegalStateException) {
messages += exc.message!!
}
@@ -44,7 +45,7 @@ class MainViewModel(private val llm: Llm = Llm.instance()): ViewModel() {
messages += ""
viewModelScope.launch {
- llm.send(text)
+ llamaAndroid.send(text)
.catch {
Log.e(tag, "send() failed", it)
messages += it.message!!
@@ -57,7 +58,7 @@ class MainViewModel(private val llm: Llm = Llm.instance()): ViewModel() {
viewModelScope.launch {
try {
val start = System.nanoTime()
- val warmupResult = llm.bench(pp, tg, pl, nr)
+ val warmupResult = llamaAndroid.bench(pp, tg, pl, nr)
val end = System.nanoTime()
messages += warmupResult
@@ -70,7 +71,7 @@ class MainViewModel(private val llm: Llm = Llm.instance()): ViewModel() {
return@launch
}
- messages += llm.bench(512, 128, 1, 3)
+ messages += llamaAndroid.bench(512, 128, 1, 3)
} catch (exc: IllegalStateException) {
Log.e(tag, "bench() failed", exc)
messages += exc.message!!
@@ -81,7 +82,7 @@ class MainViewModel(private val llm: Llm = Llm.instance()): ViewModel() {
fun load(pathToModel: String) {
viewModelScope.launch {
try {
- llm.load(pathToModel)
+ llamaAndroid.load(pathToModel)
messages += "Loaded $pathToModel"
} catch (exc: IllegalStateException) {
Log.e(tag, "load() failed", exc)
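
For context on how the renamed dependency is consumed: send() returns a Flow<String> of generated text (per the deleted Llm.kt above), so the ViewModel collects it inside viewModelScope. A minimal collection sketch, assuming that Flow contract carried over to LLamaAndroid; MyViewModel and its output field are illustrative stand-ins, not the app's exact code:

    // Illustrative sketch of collecting the token stream from LLamaAndroid.send().
    // Assumes the Flow<String> contract carried over from the deleted Llm.kt.
    import android.llama.cpp.LLamaAndroid
    import androidx.lifecycle.ViewModel
    import androidx.lifecycle.viewModelScope
    import kotlinx.coroutines.flow.catch
    import kotlinx.coroutines.launch

    class MyViewModel(private val llamaAndroid: LLamaAndroid = LLamaAndroid.instance()) : ViewModel() {
        var output: String = ""
            private set

        fun send(text: String) {
            viewModelScope.launch {
                llamaAndroid.send(text)
                    .catch { output += "\n${it.message}" }   // surface errors, as the app does via messages
                    .collect { token -> output += token }    // append streamed tokens as they arrive
            }
        }
    }
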