diff options
author | Elton Kola <eltonkola@gmail.com> | 2024-05-14 03:30:30 -0400 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-05-14 17:30:30 +1000 |
commit | efc8f767c8c8c749a245dd96ad4e2f37c164b54c (patch) | |
tree | 4c2910c9f7b3fcd27fbc1d04dd5188be962a8a5d /examples/llama.android/app/src/main/java/com | |
parent | e0f556186b6e1f2b7032a1479edf5e89e2b1bd86 (diff) |
move ndk code to a new library (#6951)
Diffstat (limited to 'examples/llama.android/app/src/main/java/com')
-rw-r--r-- | examples/llama.android/app/src/main/java/com/example/llama/Llm.kt | 172 | ||||
-rw-r--r-- | examples/llama.android/app/src/main/java/com/example/llama/MainViewModel.kt | 13 |
2 files changed, 7 insertions, 178 deletions
diff --git a/examples/llama.android/app/src/main/java/com/example/llama/Llm.kt b/examples/llama.android/app/src/main/java/com/example/llama/Llm.kt deleted file mode 100644 index d86afee3..00000000 --- a/examples/llama.android/app/src/main/java/com/example/llama/Llm.kt +++ /dev/null @@ -1,172 +0,0 @@ -package com.example.llama - -import android.util.Log -import kotlinx.coroutines.CoroutineDispatcher -import kotlinx.coroutines.asCoroutineDispatcher -import kotlinx.coroutines.flow.Flow -import kotlinx.coroutines.flow.flow -import kotlinx.coroutines.flow.flowOn -import kotlinx.coroutines.withContext -import java.util.concurrent.Executors -import kotlin.concurrent.thread - -class Llm { - private val tag: String? = this::class.simpleName - - private val threadLocalState: ThreadLocal<State> = ThreadLocal.withInitial { State.Idle } - - private val runLoop: CoroutineDispatcher = Executors.newSingleThreadExecutor { - thread(start = false, name = "Llm-RunLoop") { - Log.d(tag, "Dedicated thread for native code: ${Thread.currentThread().name}") - - // No-op if called more than once. - System.loadLibrary("llama-android") - - // Set llama log handler to Android - log_to_android() - backend_init(false) - - Log.d(tag, system_info()) - - it.run() - }.apply { - uncaughtExceptionHandler = Thread.UncaughtExceptionHandler { _, exception: Throwable -> - Log.e(tag, "Unhandled exception", exception) - } - } - }.asCoroutineDispatcher() - - private val nlen: Int = 64 - - private external fun log_to_android() - private external fun load_model(filename: String): Long - private external fun free_model(model: Long) - private external fun new_context(model: Long): Long - private external fun free_context(context: Long) - private external fun backend_init(numa: Boolean) - private external fun backend_free() - private external fun free_batch(batch: Long) - private external fun new_batch(nTokens: Int, embd: Int, nSeqMax: Int): Long - private external fun bench_model( - context: Long, - model: Long, - batch: Long, - pp: Int, - tg: Int, - pl: Int, - nr: Int - ): String - - private external fun system_info(): String - - private external fun completion_init( - context: Long, - batch: Long, - text: String, - nLen: Int - ): Int - - private external fun completion_loop( - context: Long, - batch: Long, - nLen: Int, - ncur: IntVar - ): String? - - private external fun kv_cache_clear(context: Long) - - suspend fun bench(pp: Int, tg: Int, pl: Int, nr: Int = 1): String { - return withContext(runLoop) { - when (val state = threadLocalState.get()) { - is State.Loaded -> { - Log.d(tag, "bench(): $state") - bench_model(state.context, state.model, state.batch, pp, tg, pl, nr) - } - - else -> throw IllegalStateException("No model loaded") - } - } - } - - suspend fun load(pathToModel: String) { - withContext(runLoop) { - when (threadLocalState.get()) { - is State.Idle -> { - val model = load_model(pathToModel) - if (model == 0L) throw IllegalStateException("load_model() failed") - - val context = new_context(model) - if (context == 0L) throw IllegalStateException("new_context() failed") - - val batch = new_batch(512, 0, 1) - if (batch == 0L) throw IllegalStateException("new_batch() failed") - - Log.i(tag, "Loaded model $pathToModel") - threadLocalState.set(State.Loaded(model, context, batch)) - } - else -> throw IllegalStateException("Model already loaded") - } - } - } - - fun send(message: String): Flow<String> = flow { - when (val state = threadLocalState.get()) { - is State.Loaded -> { - val ncur = IntVar(completion_init(state.context, state.batch, message, nlen)) - while (ncur.value <= nlen) { - val str = completion_loop(state.context, state.batch, nlen, ncur) - if (str == null) { - break - } - emit(str) - } - kv_cache_clear(state.context) - } - else -> {} - } - }.flowOn(runLoop) - - /** - * Unloads the model and frees resources. - * - * This is a no-op if there's no model loaded. - */ - suspend fun unload() { - withContext(runLoop) { - when (val state = threadLocalState.get()) { - is State.Loaded -> { - free_context(state.context) - free_model(state.model) - free_batch(state.batch) - - threadLocalState.set(State.Idle) - } - else -> {} - } - } - } - - companion object { - private class IntVar(value: Int) { - @Volatile - var value: Int = value - private set - - fun inc() { - synchronized(this) { - value += 1 - } - } - } - - private sealed interface State { - data object Idle: State - data class Loaded(val model: Long, val context: Long, val batch: Long): State - } - - // Enforce only one instance of Llm. - private val _instance: Llm = Llm() - - fun instance(): Llm = _instance - } -} diff --git a/examples/llama.android/app/src/main/java/com/example/llama/MainViewModel.kt b/examples/llama.android/app/src/main/java/com/example/llama/MainViewModel.kt index be95e222..45ac2993 100644 --- a/examples/llama.android/app/src/main/java/com/example/llama/MainViewModel.kt +++ b/examples/llama.android/app/src/main/java/com/example/llama/MainViewModel.kt @@ -1,5 +1,6 @@ package com.example.llama +import android.llama.cpp.LLamaAndroid import android.util.Log import androidx.compose.runtime.getValue import androidx.compose.runtime.mutableStateOf @@ -9,7 +10,7 @@ import androidx.lifecycle.viewModelScope import kotlinx.coroutines.flow.catch import kotlinx.coroutines.launch -class MainViewModel(private val llm: Llm = Llm.instance()): ViewModel() { +class MainViewModel(private val llamaAndroid: LLamaAndroid = LLamaAndroid.instance()): ViewModel() { companion object { @JvmStatic private val NanosPerSecond = 1_000_000_000.0 @@ -28,7 +29,7 @@ class MainViewModel(private val llm: Llm = Llm.instance()): ViewModel() { viewModelScope.launch { try { - llm.unload() + llamaAndroid.unload() } catch (exc: IllegalStateException) { messages += exc.message!! } @@ -44,7 +45,7 @@ class MainViewModel(private val llm: Llm = Llm.instance()): ViewModel() { messages += "" viewModelScope.launch { - llm.send(text) + llamaAndroid.send(text) .catch { Log.e(tag, "send() failed", it) messages += it.message!! @@ -57,7 +58,7 @@ class MainViewModel(private val llm: Llm = Llm.instance()): ViewModel() { viewModelScope.launch { try { val start = System.nanoTime() - val warmupResult = llm.bench(pp, tg, pl, nr) + val warmupResult = llamaAndroid.bench(pp, tg, pl, nr) val end = System.nanoTime() messages += warmupResult @@ -70,7 +71,7 @@ class MainViewModel(private val llm: Llm = Llm.instance()): ViewModel() { return@launch } - messages += llm.bench(512, 128, 1, 3) + messages += llamaAndroid.bench(512, 128, 1, 3) } catch (exc: IllegalStateException) { Log.e(tag, "bench() failed", exc) messages += exc.message!! @@ -81,7 +82,7 @@ class MainViewModel(private val llm: Llm = Llm.instance()): ViewModel() { fun load(pathToModel: String) { viewModelScope.launch { try { - llm.load(pathToModel) + llamaAndroid.load(pathToModel) messages += "Loaded $pathToModel" } catch (exc: IllegalStateException) { Log.e(tag, "load() failed", exc) |