diff --git a/examples/llama.android/app/build.gradle.kts b/examples/llama.android/app/build.gradle.kts
index 3524fe39c4..2edfe98845 100644
--- a/examples/llama.android/app/build.gradle.kts
+++ b/examples/llama.android/app/build.gradle.kts
@@ -41,11 +41,8 @@ android {
}
}
compileOptions {
- sourceCompatibility = JavaVersion.VERSION_1_8
- targetCompatibility = JavaVersion.VERSION_1_8
- }
- kotlinOptions {
- jvmTarget = "1.8"
+ sourceCompatibility = JavaVersion.VERSION_17
+ targetCompatibility = JavaVersion.VERSION_17
}
}
diff --git a/examples/llama.android/app/src/main/java/com/example/llama/MainActivity.kt b/examples/llama.android/app/src/main/java/com/example/llama/MainActivity.kt
index 52c5dc2154..872ec2b98a 100644
--- a/examples/llama.android/app/src/main/java/com/example/llama/MainActivity.kt
+++ b/examples/llama.android/app/src/main/java/com/example/llama/MainActivity.kt
@@ -6,6 +6,7 @@ import android.util.Log
import android.widget.EditText
import android.widget.TextView
import android.widget.Toast
+import androidx.activity.addCallback
import androidx.activity.enableEdgeToEdge
import androidx.activity.result.contract.ActivityResultContracts
import androidx.appcompat.app.AppCompatActivity
@@ -18,6 +19,7 @@ import com.arm.aichat.gguf.GgufMetadata
import com.arm.aichat.gguf.GgufMetadataReader
import com.google.android.material.floatingactionbutton.FloatingActionButton
import kotlinx.coroutines.Dispatchers
+import kotlinx.coroutines.Job
import kotlinx.coroutines.flow.onCompletion
import kotlinx.coroutines.launch
import kotlinx.coroutines.withContext
@@ -36,6 +38,7 @@ class MainActivity : AppCompatActivity() {
// Arm AI Chat inference engine
private lateinit var engine: InferenceEngine
+ private var generationJob: Job? = null
// Conversation states
private var isModelReady = false
@@ -47,11 +50,13 @@ class MainActivity : AppCompatActivity() {
super.onCreate(savedInstanceState)
enableEdgeToEdge()
setContentView(R.layout.activity_main)
+ // View model boilerplate and state management is out of this basic sample's scope
+ onBackPressedDispatcher.addCallback { Log.w(TAG, "Ignore back press for simplicity") }
// Find views
ggufTv = findViewById(R.id.gguf)
messagesRv = findViewById(R.id.messages)
- messagesRv.layoutManager = LinearLayoutManager(this)
+ messagesRv.layoutManager = LinearLayoutManager(this).apply { stackFromEnd = true }
messagesRv.adapter = messageAdapter
userInputEt = findViewById(R.id.user_input)
userActionFab = findViewById(R.id.fab)
@@ -157,33 +162,35 @@ class MainActivity : AppCompatActivity() {
* Validate and send the user message into [InferenceEngine]
*/
private fun handleUserInput() {
- userInputEt.text.toString().also { userSsg ->
- if (userSsg.isEmpty()) {
+ userInputEt.text.toString().also { userMsg ->
+ if (userMsg.isEmpty()) {
Toast.makeText(this, "Input message is empty!", Toast.LENGTH_SHORT).show()
} else {
userInputEt.text = null
+ userInputEt.isEnabled = false
userActionFab.isEnabled = false
// Update message states
- messages.add(Message(UUID.randomUUID().toString(), userSsg, true))
+ messages.add(Message(UUID.randomUUID().toString(), userMsg, true))
lastAssistantMsg.clear()
messages.add(Message(UUID.randomUUID().toString(), lastAssistantMsg.toString(), false))
- lifecycleScope.launch(Dispatchers.Default) {
- engine.sendUserPrompt(userSsg)
+ generationJob = lifecycleScope.launch(Dispatchers.Default) {
+ engine.sendUserPrompt(userMsg)
.onCompletion {
withContext(Dispatchers.Main) {
+ userInputEt.isEnabled = true
userActionFab.isEnabled = true
}
}.collect { token ->
- val messageCount = messages.size
- check(messageCount > 0 && !messages[messageCount - 1].isUser)
-
- messages.removeAt(messageCount - 1).copy(
- content = lastAssistantMsg.append(token).toString()
- ).let { messages.add(it) }
-
withContext(Dispatchers.Main) {
+ val messageCount = messages.size
+ check(messageCount > 0 && !messages[messageCount - 1].isUser)
+
+ messages.removeAt(messageCount - 1).copy(
+ content = lastAssistantMsg.append(token).toString()
+ ).let { messages.add(it) }
+
messageAdapter.notifyItemChanged(messages.size - 1)
}
}
@@ -195,6 +202,7 @@ class MainActivity : AppCompatActivity() {
/**
* Run a benchmark with the model file
*/
+ @Deprecated("This benchmark doesn't accurately indicate GUI performance expected by app developers")
private suspend fun runBenchmark(modelName: String, modelFile: File) =
withContext(Dispatchers.Default) {
Log.i(TAG, "Starts benchmarking $modelName")
@@ -223,6 +231,16 @@ class MainActivity : AppCompatActivity() {
if (!it.exists()) { it.mkdir() }
}
+ override fun onStop() {
+ generationJob?.cancel()
+ super.onStop()
+ }
+
+ override fun onDestroy() {
+ engine.destroy()
+ super.onDestroy()
+ }
+
companion object {
private val TAG = MainActivity::class.java.simpleName
diff --git a/examples/llama.android/app/src/main/res/layout/activity_main.xml b/examples/llama.android/app/src/main/res/layout/activity_main.xml
index ad805a674e..d15772bd37 100644
--- a/examples/llama.android/app/src/main/res/layout/activity_main.xml
+++ b/examples/llama.android/app/src/main/res/layout/activity_main.xml
@@ -24,7 +24,7 @@
android:id="@+id/gguf"
android:layout_width="match_parent"
android:layout_height="wrap_content"
- android:layout_margin="16dp"
+ android:padding="16dp"
android:text="Selected GGUF model's metadata will show here."
style="@style/TextAppearance.MaterialComponents.Body2" />
@@ -33,8 +33,7 @@
+ android:layout_marginHorizontal="16dp" />
(InferenceEngine.State.Uninitialized)
- override val state: StateFlow = _state
+ override val state: StateFlow = _state.asStateFlow()
private var _readyForSystemPrompt = false
+ @Volatile
+ private var _cancelGeneration = false
/**
* Single-threaded coroutine dispatcher & scope for LLama asynchronous operations
@@ -169,6 +173,8 @@ internal class InferenceEngineImpl private constructor(
}
Log.i(TAG, "Model loaded!")
_readyForSystemPrompt = true
+
+ _cancelGeneration = false
_state.value = InferenceEngine.State.ModelReady
} catch (e: Exception) {
Log.e(TAG, (e.message ?: "Error loading model") + "\n" + pathToModel, e)
@@ -231,15 +237,19 @@ internal class InferenceEngineImpl private constructor(
Log.i(TAG, "User prompt processed. Generating assistant prompt...")
_state.value = InferenceEngine.State.Generating
- while (true) {
+ while (!_cancelGeneration) {
generateNextToken()?.let { utf8token ->
if (utf8token.isNotEmpty()) emit(utf8token)
} ?: break
}
- Log.i(TAG, "Assistant generation complete. Awaiting user prompt...")
+ if (_cancelGeneration) {
+ Log.i(TAG, "Assistant generation aborted per requested.")
+ } else {
+ Log.i(TAG, "Assistant generation complete. Awaiting user prompt...")
+ }
_state.value = InferenceEngine.State.ModelReady
} catch (e: CancellationException) {
- Log.i(TAG, "Generation cancelled by user.")
+ Log.i(TAG, "Assistant generation's flow collection cancelled.")
_state.value = InferenceEngine.State.ModelReady
throw e
} catch (e: Exception) {
@@ -268,8 +278,9 @@ internal class InferenceEngineImpl private constructor(
/**
* Unloads the model and frees resources, or reset error states
*/
- override suspend fun cleanUp() =
- withContext(llamaDispatcher) {
+ override fun cleanUp() {
+ _cancelGeneration = true
+ runBlocking(llamaDispatcher) {
when (val state = _state.value) {
is InferenceEngine.State.ModelReady -> {
Log.i(TAG, "Unloading model and free resources...")
@@ -293,17 +304,21 @@ internal class InferenceEngineImpl private constructor(
else -> throw IllegalStateException("Cannot unload model in ${state.javaClass.simpleName}")
}
}
+ }
/**
* Cancel all ongoing coroutines and free GGML backends
*/
override fun destroy() {
- _readyForSystemPrompt = false
- llamaScope.cancel()
- when(_state.value) {
- is InferenceEngine.State.Uninitialized -> {}
- is InferenceEngine.State.Initialized -> shutdown()
- else -> { unload(); shutdown() }
+ _cancelGeneration = true
+ runBlocking(llamaDispatcher) {
+ _readyForSystemPrompt = false
+ when(_state.value) {
+ is InferenceEngine.State.Uninitialized -> {}
+ is InferenceEngine.State.Initialized -> shutdown()
+ else -> { unload(); shutdown() }
+ }
}
+ llamaScope.cancel()
}
}