diff options
author | Brian <mofosyne@gmail.com> | 2024-05-14 23:10:39 +1000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-05-14 16:10:39 +0300 |
commit | 1265c670fd8e41e1947352c96c5179adda97fb2c (patch) | |
tree | 8ef6858e5701e5ff6e7a4e656e26589af5067963 /examples/llama.android/app/src/main/java/com | |
parent | 5e31828d3e35c76ecfee665bc23771a4bec1d130 (diff) |
Revert "move ndk code to a new library (#6951)" (#7282)
This reverts commit efc8f767c8c8c749a245dd96ad4e2f37c164b54c.
Diffstat (limited to 'examples/llama.android/app/src/main/java/com')
-rw-r--r-- | examples/llama.android/app/src/main/java/com/example/llama/Llm.kt | 172 | ||||
-rw-r--r-- | examples/llama.android/app/src/main/java/com/example/llama/MainViewModel.kt | 13 |
2 files changed, 178 insertions, 7 deletions
diff --git a/examples/llama.android/app/src/main/java/com/example/llama/Llm.kt b/examples/llama.android/app/src/main/java/com/example/llama/Llm.kt new file mode 100644 index 00000000..d86afee3 --- /dev/null +++ b/examples/llama.android/app/src/main/java/com/example/llama/Llm.kt @@ -0,0 +1,172 @@ +package com.example.llama + +import android.util.Log +import kotlinx.coroutines.CoroutineDispatcher +import kotlinx.coroutines.asCoroutineDispatcher +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.flow +import kotlinx.coroutines.flow.flowOn +import kotlinx.coroutines.withContext +import java.util.concurrent.Executors +import kotlin.concurrent.thread + +class Llm { + private val tag: String? = this::class.simpleName + + private val threadLocalState: ThreadLocal<State> = ThreadLocal.withInitial { State.Idle } + + private val runLoop: CoroutineDispatcher = Executors.newSingleThreadExecutor { + thread(start = false, name = "Llm-RunLoop") { + Log.d(tag, "Dedicated thread for native code: ${Thread.currentThread().name}") + + // No-op if called more than once. + System.loadLibrary("llama-android") + + // Set llama log handler to Android + log_to_android() + backend_init(false) + + Log.d(tag, system_info()) + + it.run() + }.apply { + uncaughtExceptionHandler = Thread.UncaughtExceptionHandler { _, exception: Throwable -> + Log.e(tag, "Unhandled exception", exception) + } + } + }.asCoroutineDispatcher() + + private val nlen: Int = 64 + + private external fun log_to_android() + private external fun load_model(filename: String): Long + private external fun free_model(model: Long) + private external fun new_context(model: Long): Long + private external fun free_context(context: Long) + private external fun backend_init(numa: Boolean) + private external fun backend_free() + private external fun free_batch(batch: Long) + private external fun new_batch(nTokens: Int, embd: Int, nSeqMax: Int): Long + private external fun bench_model( + context: Long, + model: Long, + batch: Long, + pp: Int, + tg: Int, + pl: Int, + nr: Int + ): String + + private external fun system_info(): String + + private external fun completion_init( + context: Long, + batch: Long, + text: String, + nLen: Int + ): Int + + private external fun completion_loop( + context: Long, + batch: Long, + nLen: Int, + ncur: IntVar + ): String? + + private external fun kv_cache_clear(context: Long) + + suspend fun bench(pp: Int, tg: Int, pl: Int, nr: Int = 1): String { + return withContext(runLoop) { + when (val state = threadLocalState.get()) { + is State.Loaded -> { + Log.d(tag, "bench(): $state") + bench_model(state.context, state.model, state.batch, pp, tg, pl, nr) + } + + else -> throw IllegalStateException("No model loaded") + } + } + } + + suspend fun load(pathToModel: String) { + withContext(runLoop) { + when (threadLocalState.get()) { + is State.Idle -> { + val model = load_model(pathToModel) + if (model == 0L) throw IllegalStateException("load_model() failed") + + val context = new_context(model) + if (context == 0L) throw IllegalStateException("new_context() failed") + + val batch = new_batch(512, 0, 1) + if (batch == 0L) throw IllegalStateException("new_batch() failed") + + Log.i(tag, "Loaded model $pathToModel") + threadLocalState.set(State.Loaded(model, context, batch)) + } + else -> throw IllegalStateException("Model already loaded") + } + } + } + + fun send(message: String): Flow<String> = flow { + when (val state = threadLocalState.get()) { + is State.Loaded -> { + val ncur = IntVar(completion_init(state.context, state.batch, message, nlen)) + while (ncur.value <= nlen) { + val str = completion_loop(state.context, state.batch, nlen, ncur) + if (str == null) { + break + } + emit(str) + } + kv_cache_clear(state.context) + } + else -> {} + } + }.flowOn(runLoop) + + /** + * Unloads the model and frees resources. + * + * This is a no-op if there's no model loaded. + */ + suspend fun unload() { + withContext(runLoop) { + when (val state = threadLocalState.get()) { + is State.Loaded -> { + free_context(state.context) + free_model(state.model) + free_batch(state.batch) + + threadLocalState.set(State.Idle) + } + else -> {} + } + } + } + + companion object { + private class IntVar(value: Int) { + @Volatile + var value: Int = value + private set + + fun inc() { + synchronized(this) { + value += 1 + } + } + } + + private sealed interface State { + data object Idle: State + data class Loaded(val model: Long, val context: Long, val batch: Long): State + } + + // Enforce only one instance of Llm. + private val _instance: Llm = Llm() + + fun instance(): Llm = _instance + } +} diff --git a/examples/llama.android/app/src/main/java/com/example/llama/MainViewModel.kt b/examples/llama.android/app/src/main/java/com/example/llama/MainViewModel.kt index 45ac2993..be95e222 100644 --- a/examples/llama.android/app/src/main/java/com/example/llama/MainViewModel.kt +++ b/examples/llama.android/app/src/main/java/com/example/llama/MainViewModel.kt @@ -1,6 +1,5 @@ package com.example.llama -import android.llama.cpp.LLamaAndroid import android.util.Log import androidx.compose.runtime.getValue import androidx.compose.runtime.mutableStateOf @@ -10,7 +9,7 @@ import androidx.lifecycle.viewModelScope import kotlinx.coroutines.flow.catch import kotlinx.coroutines.launch -class MainViewModel(private val llamaAndroid: LLamaAndroid = LLamaAndroid.instance()): ViewModel() { +class MainViewModel(private val llm: Llm = Llm.instance()): ViewModel() { companion object { @JvmStatic private val NanosPerSecond = 1_000_000_000.0 @@ -29,7 +28,7 @@ class MainViewModel(private val llamaAndroid: LLamaAndroid = LLamaAndroid.instan viewModelScope.launch { try { - llamaAndroid.unload() + llm.unload() } catch (exc: IllegalStateException) { messages += exc.message!! } @@ -45,7 +44,7 @@ class MainViewModel(private val llamaAndroid: LLamaAndroid = LLamaAndroid.instan messages += "" viewModelScope.launch { - llamaAndroid.send(text) + llm.send(text) .catch { Log.e(tag, "send() failed", it) messages += it.message!! @@ -58,7 +57,7 @@ class MainViewModel(private val llamaAndroid: LLamaAndroid = LLamaAndroid.instan viewModelScope.launch { try { val start = System.nanoTime() - val warmupResult = llamaAndroid.bench(pp, tg, pl, nr) + val warmupResult = llm.bench(pp, tg, pl, nr) val end = System.nanoTime() messages += warmupResult @@ -71,7 +70,7 @@ class MainViewModel(private val llamaAndroid: LLamaAndroid = LLamaAndroid.instan return@launch } - messages += llamaAndroid.bench(512, 128, 1, 3) + messages += llm.bench(512, 128, 1, 3) } catch (exc: IllegalStateException) { Log.e(tag, "bench() failed", exc) messages += exc.message!! @@ -82,7 +81,7 @@ class MainViewModel(private val llamaAndroid: LLamaAndroid = LLamaAndroid.instan fun load(pathToModel: String) { viewModelScope.launch { try { - llamaAndroid.load(pathToModel) + llm.load(pathToModel) messages += "Loaded $pathToModel" } catch (exc: IllegalStateException) { Log.e(tag, "load() failed", exc) |