Diffstat (limited to 'examples/llama.android/llama/src/main/java/android')
-rw-r--r--  examples/llama.android/llama/src/main/java/android/llama/cpp/LLamaAndroid.kt | 172
1 file changed, 172 insertions(+), 0 deletions(-)
diff --git a/examples/llama.android/llama/src/main/java/android/llama/cpp/LLamaAndroid.kt b/examples/llama.android/llama/src/main/java/android/llama/cpp/LLamaAndroid.kt
new file mode 100644
index 00000000..6c63e54e
--- /dev/null
+++ b/examples/llama.android/llama/src/main/java/android/llama/cpp/LLamaAndroid.kt
@@ -0,0 +1,172 @@
+package android.llama.cpp
+
+import android.util.Log
+import kotlinx.coroutines.CoroutineDispatcher
+import kotlinx.coroutines.asCoroutineDispatcher
+import kotlinx.coroutines.flow.Flow
+import kotlinx.coroutines.flow.flow
+import kotlinx.coroutines.flow.flowOn
+import kotlinx.coroutines.withContext
+import java.util.concurrent.Executors
+import kotlin.concurrent.thread
+
+class LLamaAndroid {
+    private val tag: String? = this::class.simpleName
+
+    private val threadLocalState: ThreadLocal<State> = ThreadLocal.withInitial { State.Idle }
+
+    private val runLoop: CoroutineDispatcher = Executors.newSingleThreadExecutor {
+        thread(start = false, name = "Llm-RunLoop") {
+            Log.d(tag, "Dedicated thread for native code: ${Thread.currentThread().name}")
+
+            // No-op if called more than once.
+            System.loadLibrary("llama-android")
+
+            // Set llama log handler to Android
+            log_to_android()
+            backend_init(false)
+
+            Log.d(tag, system_info())
+
+            it.run()
+        }.apply {
+            uncaughtExceptionHandler = Thread.UncaughtExceptionHandler { _, exception: Throwable ->
+                Log.e(tag, "Unhandled exception", exception)
+            }
+        }
+    }.asCoroutineDispatcher()
+
+    private val nlen: Int = 64
+
+    private external fun log_to_android()
+    private external fun load_model(filename: String): Long
+    private external fun free_model(model: Long)
+    private external fun new_context(model: Long): Long
+    private external fun free_context(context: Long)
+    private external fun backend_init(numa: Boolean)
+    private external fun backend_free()
+    private external fun free_batch(batch: Long)
+    private external fun new_batch(nTokens: Int, embd: Int, nSeqMax: Int): Long
+    private external fun bench_model(
+        context: Long,
+        model: Long,
+        batch: Long,
+        pp: Int,
+        tg: Int,
+        pl: Int,
+        nr: Int
+    ): String
+
+    private external fun system_info(): String
+
+    private external fun completion_init(
+        context: Long,
+        batch: Long,
+        text: String,
+        nLen: Int
+    ): Int
+
+    private external fun completion_loop(
+        context: Long,
+        batch: Long,
+        nLen: Int,
+        ncur: IntVar
+    ): String?
+
+    private external fun kv_cache_clear(context: Long)
+
+    suspend fun bench(pp: Int, tg: Int, pl: Int, nr: Int = 1): String {
+        return withContext(runLoop) {
+            when (val state = threadLocalState.get()) {
+                is State.Loaded -> {
+                    Log.d(tag, "bench(): $state")
+                    bench_model(state.context, state.model, state.batch, pp, tg, pl, nr)
+                }
+
+                else -> throw IllegalStateException("No model loaded")
+            }
+        }
+    }
+
+    suspend fun load(pathToModel: String) {
+        withContext(runLoop) {
+            when (threadLocalState.get()) {
+                is State.Idle -> {
+                    val model = load_model(pathToModel)
+                    if (model == 0L) throw IllegalStateException("load_model() failed")
+
+                    val context = new_context(model)
+                    if (context == 0L) throw IllegalStateException("new_context() failed")
+
+                    val batch = new_batch(512, 0, 1)
+                    if (batch == 0L) throw IllegalStateException("new_batch() failed")
+
+                    Log.i(tag, "Loaded model $pathToModel")
+                    threadLocalState.set(State.Loaded(model, context, batch))
+                }
+                else -> throw IllegalStateException("Model already loaded")
+            }
+        }
+    }
+
+    fun send(message: String): Flow<String> = flow {
+        when (val state = threadLocalState.get()) {
+            is State.Loaded -> {
+                val ncur = IntVar(completion_init(state.context, state.batch, message, nlen))
+                while (ncur.value <= nlen) {
+                    val str = completion_loop(state.context, state.batch, nlen, ncur)
+                    if (str == null) {
+                        break
+                    }
+                    emit(str)
+                }
+                kv_cache_clear(state.context)
+            }
+            else -> {}
+        }
+    }.flowOn(runLoop)
+
+    /**
+     * Unloads the model and frees resources.
+     *
+     * This is a no-op if there's no model loaded.
+     */
+    suspend fun unload() {
+        withContext(runLoop) {
+            when (val state = threadLocalState.get()) {
+                is State.Loaded -> {
+                    free_context(state.context)
+                    free_model(state.model)
+                    free_batch(state.batch)
+
+                    threadLocalState.set(State.Idle)
+                }
+                else -> {}
+            }
+        }
+    }
+
+    companion object {
+        private class IntVar(value: Int) {
+            @Volatile
+            var value: Int = value
+                private set
+
+            fun inc() {
+                synchronized(this) {
+                    value += 1
+                }
+            }
+        }
+
+        private sealed interface State {
+            data object Idle: State
+            data class Loaded(val model: Long, val context: Long, val batch: Long): State
+        }
+
+        // Enforce only one instance of LLamaAndroid.
+        private val _instance: LLamaAndroid = LLamaAndroid()
+
+        fun instance(): LLamaAndroid = _instance
+    }
+}
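
For context, here is a minimal sketch of how an app might drive the class added above from a coroutine scope. The caller names (runDemo, modelPath, scope, "LlamaDemo") are illustrative assumptions and are not part of this commit; only the LLamaAndroid calls follow the API in the file above: load() and unload() are suspend functions executed on the dedicated "Llm-RunLoop" dispatcher, and send() returns a Flow of generated text pieces.

import android.llama.cpp.LLamaAndroid
import android.util.Log
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.flow.collect
import kotlinx.coroutines.launch

// Hypothetical caller, e.g. invoked from a ViewModel; a sketch, not part of this commit.
fun runDemo(scope: CoroutineScope, modelPath: String) {
    val llama = LLamaAndroid.instance()
    scope.launch {
        llama.load(modelPath)              // throws IllegalStateException if loading fails
        llama.send("Tell me a story.")
            .collect { piece ->            // pieces are produced on the Llm-RunLoop thread
                Log.d("LlamaDemo", piece)
            }
        llama.unload()                     // frees context, model and batch; returns to State.Idle
    }
}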