MainViewModel.kt 2.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104
  package com.example.llama

  import android.util.Log
  import androidx.compose.runtime.getValue
  import androidx.compose.runtime.mutableStateOf
  import androidx.compose.runtime.setValue
  import androidx.lifecycle.ViewModel
  import androidx.lifecycle.viewModelScope
  import kotlinx.coroutines.CoroutineScope
  import kotlinx.coroutines.Dispatchers
  import kotlinx.coroutines.flow.catch
  import kotlinx.coroutines.launch
  /**
   * ViewModel for the console screen: holds the transcript ([messages]) and the pending
   * input ([message]), and forwards load / send / bench / unload requests to [llm].
   *
   * @param llm model wrapper to drive; defaults to [Llm.instance].
   */
  class MainViewModel(private val llm: Llm = Llm.instance()) : ViewModel() {
      private val tag: String? = this::class.simpleName

      /** Console transcript; the last entry is the one streamed tokens are appended to by [send]. */
      var messages by mutableStateOf(listOf("Initializing..."))
          private set

      /** Text currently in the input field; consumed and reset by [send]. */
      var message by mutableStateOf("")
          private set

      override fun onCleared() {
          super.onCleared()
          // viewModelScope is cancelled when the ViewModel is cleared, before onCleared() runs,
          // so a coroutine launched there never starts and the model would never be unloaded.
          // `llm` presumably comes from a shared Llm.instance() that can outlive this ViewModel,
          // so the unload runs in a scope that is not tied to this ViewModel's lifecycle.
          // NOTE(review): assumes llm.unload() is safe to call off the main thread — confirm.
          CoroutineScope(Dispatchers.Default).launch {
              try {
                  llm.unload()
              } catch (exc: IllegalStateException) {
                  // The UI is gone at this point, so log instead of appending to `messages`.
                  Log.e(tag, "unload() failed", exc)
              }
          }
      }

      /**
       * Moves the pending [message] into the transcript and streams the model's reply into a
       * new trailing entry. Failures from the token flow are logged and appended as a new line.
       */
      fun send() {
          val text = message
          message = ""

          // Echo the prompt, then open an empty entry that the reply is streamed into.
          messages += text
          messages += ""

          viewModelScope.launch {
              llm.send(text)
                  .catch { error ->
                      Log.e(tag, "send() failed", error)
                      messages += describe(error)
                  }
                  .collect { token ->
                      // lastOrNull(): clear() may have emptied the transcript mid-stream.
                      messages = messages.dropLast(1) + (messages.lastOrNull().orEmpty() + token)
                  }
          }
      }

      /**
       * Runs a warm-up benchmark, reports its result and wall-clock time, and then — only if the
       * warm-up finished within [maxWarmupSeconds] — runs the full benchmark.
       *
       * @param pp first argument forwarded to [Llm.bench] for the warm-up run.
       * @param tg second argument forwarded to [Llm.bench] for the warm-up run.
       * @param pl third argument forwarded to [Llm.bench] for the warm-up run.
       * @param nr fourth argument forwarded to [Llm.bench] for the warm-up run.
       *   (Presumably prompt tokens / generated tokens / parallel sequences / repetitions —
       *   confirm against [Llm.bench].)
       * @param benchPp first argument of the full run; default 512.
       * @param benchTg second argument of the full run; default 128.
       * @param benchPl third argument of the full run; default 1.
       * @param benchNr fourth argument of the full run; default 3.
       *   The four defaults reproduce the previously hard-coded values.
       * @param maxWarmupSeconds the full run is skipped if the warm-up took longer than this.
       */
      fun bench(
          pp: Int,
          tg: Int,
          pl: Int,
          nr: Int = 1,
          benchPp: Int = 512,
          benchTg: Int = 128,
          benchPl: Int = 1,
          benchNr: Int = 3,
          maxWarmupSeconds: Double = 5.0
      ) {
          viewModelScope.launch {
              try {
                  val start = System.nanoTime()
                  val warmupResult = llm.bench(pp, tg, pl, nr)
                  val warmup = (System.nanoTime() - start).toDouble() / NANOS_PER_SECOND

                  messages += warmupResult
                  messages += "Warm up time: $warmup seconds, please wait..."

                  if (warmup > maxWarmupSeconds) {
                      messages += "Warm up took too long, aborting benchmark"
                  } else {
                      messages += llm.bench(benchPp, benchTg, benchPl, benchNr)
                  }
              } catch (exc: IllegalStateException) {
                  Log.e(tag, "bench() failed", exc)
                  messages += describe(exc)
              }
          }
      }

      /** Loads the model at [pathToModel], reporting success or an [IllegalStateException] in [messages]. */
      fun load(pathToModel: String) {
          viewModelScope.launch {
              try {
                  llm.load(pathToModel)
                  messages += "Loaded $pathToModel"
              } catch (exc: IllegalStateException) {
                  Log.e(tag, "load() failed", exc)
                  messages += describe(exc)
              }
          }
      }

      /** Replaces the pending input text. */
      fun updateMessage(newMessage: String) {
          message = newMessage
      }

      /** Empties the transcript. */
      fun clear() {
          messages = emptyList()
      }

      /** Appends [message] to the transcript as a new entry. */
      fun log(message: String) {
          messages += message
      }

      /**
       * Text to show the user for [error]. Throwable.message is nullable, so this falls back to
       * toString() instead of `!!`, which would itself throw inside the error handler.
       */
      private fun describe(error: Throwable): String = error.message ?: error.toString()

      companion object {
          private const val NANOS_PER_SECOND = 1_000_000_000.0
      }
  }