package com.example.llama

import android.util.Log
import kotlinx.coroutines.CoroutineDispatcher
import kotlinx.coroutines.asCoroutineDispatcher
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.flow
import kotlinx.coroutines.flow.flowOn
import kotlinx.coroutines.withContext
import java.util.concurrent.Executors
import kotlin.concurrent.thread

class Llm {
    private val tag: String? = this::class.simpleName

    private val threadLocalState: ThreadLocal<State> = ThreadLocal.withInitial { State.Idle }

    // Single-threaded dispatcher so every native llama.cpp call runs on one dedicated thread.
    private val runLoop: CoroutineDispatcher = Executors.newSingleThreadExecutor {
        thread(start = false, name = "Llm-RunLoop") {
            Log.d(tag, "Dedicated thread for native code: ${Thread.currentThread().name}")

            // No-op if called more than once.
            System.loadLibrary("llama-android")

            // Set llama log handler to Android
            log_to_android()
            backend_init(false)

            Log.d(tag, system_info())

            it.run()
        }.apply {
            uncaughtExceptionHandler = Thread.UncaughtExceptionHandler { _, exception: Throwable ->
                Log.e(tag, "Unhandled exception", exception)
            }
        }
    }.asCoroutineDispatcher()

    private val nlen: Int = 64

    private external fun log_to_android()
    private external fun load_model(filename: String): Long
    private external fun free_model(model: Long)
    private external fun new_context(model: Long): Long
    private external fun free_context(context: Long)
    private external fun backend_init(numa: Boolean)
    private external fun backend_free()
    private external fun free_batch(batch: Long)
    private external fun new_batch(nTokens: Int, embd: Int, nSeqMax: Int): Long

    private external fun bench_model(
        context: Long,
        model: Long,
        batch: Long,
        pp: Int,
        tg: Int,
        pl: Int,
        nr: Int
    ): String

    private external fun system_info(): String

    private external fun completion_init(
        context: Long,
        batch: Long,
        text: String,
        nLen: Int
    ): Int

    private external fun completion_loop(
        context: Long,
        batch: Long,
        nLen: Int,
        ncur: IntVar
    ): String?

    private external fun kv_cache_clear(context: Long)

    /**
     * Benchmarks the loaded model (prompt processing and text generation).
     *
     * Throws [IllegalStateException] if no model is loaded.
     */
    suspend fun bench(pp: Int, tg: Int, pl: Int, nr: Int = 1): String {
        return withContext(runLoop) {
            when (val state = threadLocalState.get()) {
                is State.Loaded -> {
                    Log.d(tag, "bench(): $state")
                    bench_model(state.context, state.model, state.batch, pp, tg, pl, nr)
                }

                else -> throw IllegalStateException("No model loaded")
            }
        }
    }

    /**
     * Loads the model at [pathToModel] and allocates a context and batch for it.
     *
     * Throws [IllegalStateException] if a model is already loaded or a native call fails.
     */
    suspend fun load(pathToModel: String) {
        withContext(runLoop) {
            when (threadLocalState.get()) {
                is State.Idle -> {
                    val model = load_model(pathToModel)
                    if (model == 0L) throw IllegalStateException("load_model() failed")

                    val context = new_context(model)
                    if (context == 0L) throw IllegalStateException("new_context() failed")

                    val batch = new_batch(512, 0, 1)
                    if (batch == 0L) throw IllegalStateException("new_batch() failed")

                    Log.i(tag, "Loaded model $pathToModel")
                    threadLocalState.set(State.Loaded(model, context, batch))
                }
                else -> throw IllegalStateException("Model already loaded")
            }
        }
    }

    /**
     * Streams the completion for [message], emitting each decoded piece of text.
     *
     * Emits nothing if no model is loaded.
     */
    fun send(message: String): Flow<String> = flow {
        when (val state = threadLocalState.get()) {
            is State.Loaded -> {
                val ncur = IntVar(completion_init(state.context, state.batch, message, nlen))
                while (ncur.value <= nlen) {
                    val str = completion_loop(state.context, state.batch, nlen, ncur)
                    if (str == null) {
                        break
                    }
                    emit(str)
                }

                kv_cache_clear(state.context)
            }
            else -> {}
        }
    }.flowOn(runLoop)

    /**
     * Unloads the model and frees resources.
     *
     * This is a no-op if there's no model loaded.
     */
    suspend fun unload() {
        withContext(runLoop) {
            when (val state = threadLocalState.get()) {
                is State.Loaded -> {
                    free_context(state.context)
                    free_model(state.model)
                    free_batch(state.batch)

                    threadLocalState.set(State.Idle)
                }
                else -> {}
            }
        }
    }

    companion object {
        // Mutable token counter shared with native code, which advances it via inc().
        private class IntVar(value: Int) {
            @Volatile
            var value: Int = value
                private set

            fun inc() {
                synchronized(this) {
                    value += 1
                }
            }
        }

        private sealed interface State {
            data object Idle : State
            data class Loaded(val model: Long, val context: Long, val batch: Long) : State
        }

        // Enforce only one instance of Llm.
        private val _instance: Llm = Llm()

        fun instance(): Llm = _instance
    }
}
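
// Usage sketch (kept as a comment; not part of the class above). It assumes the caller is already
// inside a coroutine, that `modelPath` points to a GGUF model file present on the device, and that
// kotlinx.coroutines.flow.collect is in scope; `demo` itself is a hypothetical name used only here.
//
// suspend fun demo(modelPath: String) {
//     val llm = Llm.instance()
//     llm.load(modelPath)                   // throws IllegalStateException if a native call fails
//     llm.send("Hello").collect { piece ->  // pieces are emitted as the native loop decodes them
//         Log.d("Demo", piece)
//     }
//     Log.d("Demo", llm.bench(8, 4, 1))     // pp = 8, tg = 4, pl = 1, nr defaults to 1
//     llm.unload()                          // no-op if nothing is loaded
// }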