| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172 |
- package com.example.llama
- import android.util.Log
- import kotlinx.coroutines.CoroutineDispatcher
- import kotlinx.coroutines.asCoroutineDispatcher
- import kotlinx.coroutines.flow.Flow
- import kotlinx.coroutines.flow.flow
- import kotlinx.coroutines.flow.flowOn
- import kotlinx.coroutines.withContext
- import java.util.concurrent.Executors
- import kotlin.concurrent.thread
/**
 * Kotlin wrapper around the native `llama-android` library.
 *
 * Every native call is funnelled through [runLoop], a dispatcher backed by a single dedicated
 * thread. The current model state lives in a [ThreadLocal], so it is only meaningful when read
 * or written from that thread; all public entry points switch to [runLoop] before touching it.
 *
 * Obtain the shared instance through [instance].
 */
class Llm {
    private val tag: String? = this::class.simpleName

    // Only the run-loop thread ever stores a non-Idle value here; any other thread sees Idle.
    private val threadLocalState: ThreadLocal<State> = ThreadLocal.withInitial { State.Idle }

    // Single-threaded dispatcher. The worker thread loads the native library and initialises
    // the backend once, before it starts draining the executor's task queue.
    private val runLoop: CoroutineDispatcher = Executors.newSingleThreadExecutor { worker ->
        thread(start = false, name = "Llm-RunLoop") {
            Log.d(tag, "Dedicated thread for native code: ${Thread.currentThread().name}")

            // No-op if called more than once.
            System.loadLibrary("llama-android")

            // Set llama log handler to Android
            log_to_android()
            backend_init(false)

            Log.d(tag, system_info())

            worker.run()
        }.apply {
            uncaughtExceptionHandler = Thread.UncaughtExceptionHandler { _, exception: Throwable ->
                Log.e(tag, "Unhandled exception", exception)
            }
        }
    }.asCoroutineDispatcher()

    // Value passed as `nLen` to the native completion calls and used as the loop bound in [send].
    private val nlen: Int = 64

    // Native entry points. Names are kept verbatim so they keep matching the native symbols.
    private external fun log_to_android()
    private external fun load_model(filename: String): Long
    private external fun free_model(model: Long)
    private external fun new_context(model: Long): Long
    private external fun free_context(context: Long)
    private external fun backend_init(numa: Boolean)
    private external fun backend_free()
    private external fun free_batch(batch: Long)
    private external fun new_batch(nTokens: Int, embd: Int, nSeqMax: Int): Long
    private external fun bench_model(
        context: Long,
        model: Long,
        batch: Long,
        pp: Int,
        tg: Int,
        pl: Int,
        nr: Int
    ): String

    private external fun system_info(): String

    private external fun completion_init(
        context: Long,
        batch: Long,
        text: String,
        nLen: Int
    ): Int

    private external fun completion_loop(
        context: Long,
        batch: Long,
        nLen: Int,
        ncur: IntVar
    ): String?

    private external fun kv_cache_clear(context: Long)

    /**
     * Runs the native benchmark against the currently loaded model.
     *
     * The four integers are forwarded unchanged to the native `bench_model` call.
     *
     * @return the string produced by the native benchmark.
     * @throws IllegalStateException if no model is loaded.
     */
    suspend fun bench(pp: Int, tg: Int, pl: Int, nr: Int = 1): String {
        return withContext(runLoop) {
            when (val state: State = threadLocalState.get()) {
                is State.Loaded -> {
                    Log.d(tag, "bench(): $state")
                    bench_model(state.context, state.model, state.batch, pp, tg, pl, nr)
                }
                State.Idle -> error("No model loaded")
            }
        }
    }

    /**
     * Loads the model at [pathToModel] and creates the context and batch that go with it.
     *
     * If a later step fails, the handles created by the earlier steps are freed before the
     * exception is thrown, so a failed load leaves nothing allocated and the state stays idle.
     *
     * @throws IllegalStateException if a model is already loaded or any native step returns 0.
     */
    suspend fun load(pathToModel: String) {
        withContext(runLoop) {
            when (threadLocalState.get() as State) {
                State.Idle -> {
                    val model = load_model(pathToModel)
                    check(model != 0L) { "load_model() failed" }

                    val context = new_context(model)
                    if (context == 0L) {
                        // Nothing else will ever see this handle; release it before failing.
                        free_model(model)
                        throw IllegalStateException("new_context() failed")
                    }

                    val batch = new_batch(512, 0, 1)
                    if (batch == 0L) {
                        // Same release order as unload(): context first, then model.
                        free_context(context)
                        free_model(model)
                        throw IllegalStateException("new_batch() failed")
                    }

                    Log.i(tag, "Loaded model $pathToModel")
                    threadLocalState.set(State.Loaded(model, context, batch))
                }
                is State.Loaded -> error("Model already loaded")
            }
        }
    }

    /**
     * Streams the completion for [message] as a cold [Flow] of text pieces.
     *
     * The generation loop runs on [runLoop]; collection happens in the collector's context.
     * If no model is loaded the flow completes without emitting anything.
     *
     * The KV cache is cleared when generation ends for any reason, including collector
     * cancellation or a failure inside the loop.
     */
    fun send(message: String): Flow<String> = flow {
        when (val state: State = threadLocalState.get()) {
            is State.Loaded -> {
                try {
                    val ncur = IntVar(completion_init(state.context, state.batch, message, nlen))
                    // NOTE(review): ncur is presumably advanced by the native side through
                    // IntVar.inc(); nothing in Kotlin increments it. Confirm against the JNI code.
                    while (ncur.value <= nlen) {
                        // A null piece from the native side ends generation early.
                        val piece = completion_loop(state.context, state.batch, nlen, ncur) ?: break
                        emit(piece)
                    }
                } finally {
                    kv_cache_clear(state.context)
                }
            }
            State.Idle -> Unit
        }
    }.flowOn(runLoop)

    /**
     * Unloads the model and frees resources.
     *
     * This is a no-op if there's no model loaded.
     */
    suspend fun unload() {
        withContext(runLoop) {
            when (val state: State = threadLocalState.get()) {
                is State.Loaded -> {
                    free_context(state.context)
                    free_model(state.model)
                    free_batch(state.batch)

                    threadLocalState.set(State.Idle)
                }
                State.Idle -> Unit
            }
        }
    }

    companion object {
        // Mutable integer holder handed to the native completion loop.
        // `value` and `inc` are not called from Kotlin in this class; keep their names and
        // shapes stable because native code is expected to look them up by name.
        private class IntVar(value: Int) {
            @Volatile
            var value: Int = value
                private set

            fun inc() {
                synchronized(this) {
                    value += 1
                }
            }
        }

        // Model lifecycle: either nothing is loaded, or all three native handles are live.
        private sealed interface State {
            data object Idle : State
            data class Loaded(val model: Long, val context: Long, val batch: Long) : State
        }

        // Enforce only one instance of Llm.
        // NOTE(review): the constructor is still public, so this is a convention, not a guarantee.
        private val _instance: Llm = Llm()

        fun instance(): Llm = _instance
    }
}
|