// Llm.kt (5.3 KB)
  1. package com.example.llama
  2. import android.util.Log
  3. import kotlinx.coroutines.CoroutineDispatcher
  4. import kotlinx.coroutines.asCoroutineDispatcher
  5. import kotlinx.coroutines.flow.Flow
  6. import kotlinx.coroutines.flow.flow
  7. import kotlinx.coroutines.flow.flowOn
  8. import kotlinx.coroutines.withContext
  9. import java.util.concurrent.Executors
  10. import kotlin.concurrent.thread
  /**
   * Kotlin facade over the native `llama-android` JNI library.
   *
   * Every native call is funnelled through [runLoop], a dispatcher backed by one dedicated
   * thread. The current [State] lives in a [ThreadLocal], so it is only meaningful when it is
   * read or written on that thread; all public entry points switch to [runLoop] themselves.
   *
   * Use [instance] to obtain the shared instance.
   */
  class Llm {
      private val tag: String? = this::class.simpleName

      // Per-thread state; in practice only the run-loop thread ever moves it away from Idle.
      private val threadLocalState: ThreadLocal<State> = ThreadLocal.withInitial { State.Idle }

      // Single-threaded dispatcher. The worker thread loads the native library and initialises
      // the backend once, before it starts draining the executor's queue.
      private val runLoop: CoroutineDispatcher = Executors.newSingleThreadExecutor { runnable ->
          thread(start = false, name = "Llm-RunLoop") {
              Log.d(tag, "Dedicated thread for native code: ${Thread.currentThread().name}")

              // No-op if called more than once.
              System.loadLibrary("llama-android")

              // Set llama log handler to Android
              log_to_android()
              backend_init(false)

              Log.d(tag, system_info())

              // Hand control to the executor's worker loop.
              runnable.run()
          }.apply {
              uncaughtExceptionHandler = Thread.UncaughtExceptionHandler { _, exception: Throwable ->
                  Log.e(tag, "Unhandled exception", exception)
              }
          }
      }.asCoroutineDispatcher()

      // Passed to the native completion calls as `nLen` and used as the loop bound in send().
      private val nlen: Int = 64

      // --- JNI entry points -------------------------------------------------------------------
      // The names below are resolved by the native library and must not be renamed.
      // Long values are opaque native handles; this class treats 0 as "allocation failed".

      private external fun log_to_android()
      private external fun load_model(filename: String): Long
      private external fun free_model(model: Long)
      private external fun new_context(model: Long): Long
      private external fun free_context(context: Long)
      private external fun backend_init(numa: Boolean)
      private external fun backend_free()
      private external fun free_batch(batch: Long)
      private external fun new_batch(nTokens: Int, embd: Int, nSeqMax: Int): Long
      private external fun bench_model(
          context: Long,
          model: Long,
          batch: Long,
          pp: Int,
          tg: Int,
          pl: Int,
          nr: Int
      ): String

      private external fun system_info(): String

      private external fun completion_init(
          context: Long,
          batch: Long,
          text: String,
          nLen: Int
      ): Int

      // Returns null when there is nothing more to emit (send() stops on null).
      private external fun completion_loop(
          context: Long,
          batch: Long,
          nLen: Int,
          ncur: IntVar
      ): String?

      private external fun kv_cache_clear(context: Long)

      /**
       * Runs the native benchmark against the loaded model and returns its textual report.
       *
       * The parameters are forwarded unchanged to the native `bench_model` call.
       *
       * @throws IllegalStateException if no model is loaded.
       */
      suspend fun bench(pp: Int, tg: Int, pl: Int, nr: Int = 1): String = withContext(runLoop) {
          when (val state = threadLocalState.get()) {
              is State.Loaded -> {
                  Log.d(tag, "bench(): $state")
                  bench_model(state.context, state.model, state.batch, pp, tg, pl, nr)
              }
              is State.Idle -> error("No model loaded")
          }
      }

      /**
       * Loads the model at [pathToModel] and allocates the context and batch that go with it.
       *
       * If a later allocation step fails, the handles obtained so far are released before the
       * failure is rethrown, so a failed load leaves this object in the idle state without
       * leaking native resources.
       *
       * @throws IllegalStateException if a model is already loaded or a native allocation fails.
       */
      suspend fun load(pathToModel: String) {
          withContext(runLoop) {
              when (threadLocalState.get()) {
                  is State.Idle -> {
                      val model = load_model(pathToModel)
                      check(model != 0L) { "load_model() failed" }

                      val context = try {
                          new_context(model).also { check(it != 0L) { "new_context() failed" } }
                      } catch (failure: Throwable) {
                          // Covers both a zero handle and an exception raised by the native call.
                          free_model(model)
                          throw failure
                      }

                      val batch = try {
                          new_batch(512, 0, 1).also { check(it != 0L) { "new_batch() failed" } }
                      } catch (failure: Throwable) {
                          // Same release order as unload(): context first, then model.
                          free_context(context)
                          free_model(model)
                          throw failure
                      }

                      Log.i(tag, "Loaded model $pathToModel")
                      threadLocalState.set(State.Loaded(model, context, batch))
                  }
                  is State.Loaded -> error("Model already loaded")
              }
          }
      }

      /**
       * Returns a cold flow that feeds [message] to the model and emits each generated piece.
       *
       * The flow body runs on [runLoop]. If no model is loaded the flow completes without
       * emitting anything. The KV cache is cleared when generation ends for any reason
       * (normal completion, an exception, or cancellation by the collector), so the next
       * prompt starts from a clean cache.
       */
      fun send(message: String): Flow<String> = flow {
          when (val state = threadLocalState.get()) {
              is State.Loaded -> {
                  try {
                      val ncur = IntVar(completion_init(state.context, state.batch, message, nlen))
                      while (ncur.value <= nlen) {
                          val str = completion_loop(state.context, state.batch, nlen, ncur) ?: break
                          emit(str)
                      }
                  } finally {
                      // emit() suspends, so unload() may have run in between; only touch the
                      // context if it is still the one currently loaded.
                      if (threadLocalState.get() == state) {
                          kv_cache_clear(state.context)
                      }
                  }
              }
              is State.Idle -> {}
          }
      }.flowOn(runLoop)

      /**
       * Unloads the model and frees resources.
       *
       * This is a no-op if there's no model loaded.
       */
      suspend fun unload() {
          withContext(runLoop) {
              when (val state = threadLocalState.get()) {
                  is State.Loaded -> {
                      free_context(state.context)
                      free_model(state.model)
                      free_batch(state.batch)

                      threadLocalState.set(State.Idle)
                  }
                  is State.Idle -> {}
              }
          }
      }

      companion object {
          /**
           * Mutable integer box handed to the native `completion_loop` call.
           *
           * Nothing in this file calls [inc]; presumably the native side advances the counter
           * through it (TODO confirm against the JNI implementation).
           */
          private class IntVar(value: Int) {
              @Volatile
              var value: Int = value
                  private set

              fun inc() {
                  synchronized(this) {
                      value += 1
                  }
              }
          }

          /** Lifecycle of the native model: nothing loaded, or a set of live native handles. */
          private sealed interface State {
              data object Idle : State
              data class Loaded(val model: Long, val context: Long, val batch: Long) : State
          }

          // Shared instance handed out by instance().
          // NOTE(review): the constructor is public, so a single instance is a convention here
          // rather than something the compiler enforces.
          private val _instance: Llm = Llm()

          /** Returns the shared [Llm] instance. */
          fun instance(): Llm = _instance
      }
  }