|
|
@@ -6,6 +6,7 @@ import android.util.Log
|
|
|
import android.widget.EditText
|
|
|
import android.widget.TextView
|
|
|
import android.widget.Toast
|
|
|
+import androidx.activity.addCallback
|
|
|
import androidx.activity.enableEdgeToEdge
|
|
|
import androidx.activity.result.contract.ActivityResultContracts
|
|
|
import androidx.appcompat.app.AppCompatActivity
|
|
|
@@ -18,6 +19,7 @@ import com.arm.aichat.gguf.GgufMetadata
|
|
|
import com.arm.aichat.gguf.GgufMetadataReader
|
|
|
import com.google.android.material.floatingactionbutton.FloatingActionButton
|
|
|
import kotlinx.coroutines.Dispatchers
|
|
|
+import kotlinx.coroutines.Job
|
|
|
import kotlinx.coroutines.flow.onCompletion
|
|
|
import kotlinx.coroutines.launch
|
|
|
import kotlinx.coroutines.withContext
|
|
|
@@ -36,6 +38,7 @@ class MainActivity : AppCompatActivity() {
|
|
|
|
|
|
// Arm AI Chat inference engine
|
|
|
private lateinit var engine: InferenceEngine
|
|
|
+ private var generationJob: Job? = null
|
|
|
|
|
|
// Conversation states
|
|
|
private var isModelReady = false
|
|
|
@@ -47,11 +50,13 @@ class MainActivity : AppCompatActivity() {
|
|
|
super.onCreate(savedInstanceState)
|
|
|
enableEdgeToEdge()
|
|
|
setContentView(R.layout.activity_main)
|
|
|
+ // View model boilerplate and state management is out of this basic sample's scope
|
|
|
+ onBackPressedDispatcher.addCallback { Log.w(TAG, "Ignore back press for simplicity") }
|
|
|
|
|
|
// Find views
|
|
|
ggufTv = findViewById(R.id.gguf)
|
|
|
messagesRv = findViewById(R.id.messages)
|
|
|
- messagesRv.layoutManager = LinearLayoutManager(this)
|
|
|
+ messagesRv.layoutManager = LinearLayoutManager(this).apply { stackFromEnd = true }
|
|
|
messagesRv.adapter = messageAdapter
|
|
|
userInputEt = findViewById(R.id.user_input)
|
|
|
userActionFab = findViewById(R.id.fab)
|
|
|
@@ -157,33 +162,35 @@ class MainActivity : AppCompatActivity() {
|
|
|
* Validate and send the user message into [InferenceEngine]
|
|
|
*/
|
|
|
private fun handleUserInput() {
|
|
|
- userInputEt.text.toString().also { userSsg ->
|
|
|
- if (userSsg.isEmpty()) {
|
|
|
+ userInputEt.text.toString().also { userMsg ->
|
|
|
+ if (userMsg.isEmpty()) {
|
|
|
Toast.makeText(this, "Input message is empty!", Toast.LENGTH_SHORT).show()
|
|
|
} else {
|
|
|
userInputEt.text = null
|
|
|
+ userInputEt.isEnabled = false
|
|
|
userActionFab.isEnabled = false
|
|
|
|
|
|
// Update message states
|
|
|
- messages.add(Message(UUID.randomUUID().toString(), userSsg, true))
|
|
|
+ messages.add(Message(UUID.randomUUID().toString(), userMsg, true))
|
|
|
lastAssistantMsg.clear()
|
|
|
messages.add(Message(UUID.randomUUID().toString(), lastAssistantMsg.toString(), false))
|
|
|
|
|
|
- lifecycleScope.launch(Dispatchers.Default) {
|
|
|
- engine.sendUserPrompt(userSsg)
|
|
|
+ generationJob = lifecycleScope.launch(Dispatchers.Default) {
|
|
|
+ engine.sendUserPrompt(userMsg)
|
|
|
.onCompletion {
|
|
|
withContext(Dispatchers.Main) {
|
|
|
+ userInputEt.isEnabled = true
|
|
|
userActionFab.isEnabled = true
|
|
|
}
|
|
|
}.collect { token ->
|
|
|
- val messageCount = messages.size
|
|
|
- check(messageCount > 0 && !messages[messageCount - 1].isUser)
|
|
|
+ withContext(Dispatchers.Main) {
|
|
|
+ val messageCount = messages.size
|
|
|
+ check(messageCount > 0 && !messages[messageCount - 1].isUser)
|
|
|
|
|
|
- messages.removeAt(messageCount - 1).copy(
|
|
|
- content = lastAssistantMsg.append(token).toString()
|
|
|
- ).let { messages.add(it) }
|
|
|
+ messages.removeAt(messageCount - 1).copy(
|
|
|
+ content = lastAssistantMsg.append(token).toString()
|
|
|
+ ).let { messages.add(it) }
|
|
|
|
|
|
- withContext(Dispatchers.Main) {
|
|
|
messageAdapter.notifyItemChanged(messages.size - 1)
|
|
|
}
|
|
|
}
|
|
|
@@ -195,6 +202,7 @@ class MainActivity : AppCompatActivity() {
|
|
|
/**
|
|
|
* Run a benchmark with the model file
|
|
|
*/
|
|
|
+ @Deprecated("This benchmark doesn't accurately indicate GUI performance expected by app developers")
|
|
|
private suspend fun runBenchmark(modelName: String, modelFile: File) =
|
|
|
withContext(Dispatchers.Default) {
|
|
|
Log.i(TAG, "Starts benchmarking $modelName")
|
|
|
@@ -223,6 +231,16 @@ class MainActivity : AppCompatActivity() {
|
|
|
if (!it.exists()) { it.mkdir() }
|
|
|
}
|
|
|
|
|
|
+ override fun onStop() {
|
|
|
+ generationJob?.cancel()
|
|
|
+ super.onStop()
|
|
|
+ }
|
|
|
+
|
|
|
+ override fun onDestroy() {
|
|
|
+ engine.destroy()
|
|
|
+ super.onDestroy()
|
|
|
+ }
|
|
|
+
|
|
|
companion object {
|
|
|
private val TAG = MainActivity::class.java.simpleName
|
|
|
|