diff --git a/examples/llama.android/app/build.gradle.kts b/examples/llama.android/app/build.gradle.kts
index 3524fe39c4..2edfe98845 100644
--- a/examples/llama.android/app/build.gradle.kts
+++ b/examples/llama.android/app/build.gradle.kts
@@ -41,11 +41,8 @@ android {
         }
     }
     compileOptions {
-        sourceCompatibility = JavaVersion.VERSION_1_8
-        targetCompatibility = JavaVersion.VERSION_1_8
-    }
-    kotlinOptions {
-        jvmTarget = "1.8"
+        sourceCompatibility = JavaVersion.VERSION_17
+        targetCompatibility = JavaVersion.VERSION_17
     }
 }
 
diff --git a/examples/llama.android/app/src/main/java/com/example/llama/MainActivity.kt b/examples/llama.android/app/src/main/java/com/example/llama/MainActivity.kt
index 52c5dc2154..872ec2b98a 100644
--- a/examples/llama.android/app/src/main/java/com/example/llama/MainActivity.kt
+++ b/examples/llama.android/app/src/main/java/com/example/llama/MainActivity.kt
@@ -6,6 +6,7 @@ import android.util.Log
 import android.widget.EditText
 import android.widget.TextView
 import android.widget.Toast
+import androidx.activity.addCallback
 import androidx.activity.enableEdgeToEdge
 import androidx.activity.result.contract.ActivityResultContracts
 import androidx.appcompat.app.AppCompatActivity
@@ -18,6 +19,7 @@ import com.arm.aichat.gguf.GgufMetadata
 import com.arm.aichat.gguf.GgufMetadataReader
 import com.google.android.material.floatingactionbutton.FloatingActionButton
 import kotlinx.coroutines.Dispatchers
+import kotlinx.coroutines.Job
 import kotlinx.coroutines.flow.onCompletion
 import kotlinx.coroutines.launch
 import kotlinx.coroutines.withContext
@@ -36,6 +38,7 @@ class MainActivity : AppCompatActivity() {
 
     // Arm AI Chat inference engine
     private lateinit var engine: InferenceEngine
+    private var generationJob: Job? = null
 
     // Conversation states
     private var isModelReady = false
@@ -47,11 +50,13 @@
         super.onCreate(savedInstanceState)
         enableEdgeToEdge()
         setContentView(R.layout.activity_main)
+        // View model boilerplate and state management are out of this basic sample's scope
+        onBackPressedDispatcher.addCallback { Log.w(TAG, "Ignoring back press for simplicity") }
 
         // Find views
         ggufTv = findViewById(R.id.gguf)
         messagesRv = findViewById(R.id.messages)
-        messagesRv.layoutManager = LinearLayoutManager(this)
+        messagesRv.layoutManager = LinearLayoutManager(this).apply { stackFromEnd = true }
         messagesRv.adapter = messageAdapter
         userInputEt = findViewById(R.id.user_input)
         userActionFab = findViewById(R.id.fab)
@@ -157,33 +162,35 @@
      * Validate and send the user message into [InferenceEngine]
      */
     private fun handleUserInput() {
-        userInputEt.text.toString().also { userSsg ->
-            if (userSsg.isEmpty()) {
+        userInputEt.text.toString().also { userMsg ->
+            if (userMsg.isEmpty()) {
                 Toast.makeText(this, "Input message is empty!", Toast.LENGTH_SHORT).show()
             } else {
                 userInputEt.text = null
+                userInputEt.isEnabled = false
                 userActionFab.isEnabled = false
 
                 // Update message states
-                messages.add(Message(UUID.randomUUID().toString(), userSsg, true))
+                messages.add(Message(UUID.randomUUID().toString(), userMsg, true))
                 lastAssistantMsg.clear()
                 messages.add(Message(UUID.randomUUID().toString(), lastAssistantMsg.toString(), false))
 
-                lifecycleScope.launch(Dispatchers.Default) {
-                    engine.sendUserPrompt(userSsg)
+                generationJob = lifecycleScope.launch(Dispatchers.Default) {
+                    engine.sendUserPrompt(userMsg)
                         .onCompletion {
                             withContext(Dispatchers.Main) {
+                                userInputEt.isEnabled = true
                                 userActionFab.isEnabled = true
                             }
                         }.collect { token ->
-                            val messageCount = messages.size
-                            check(messageCount > 0 && !messages[messageCount - 1].isUser)
-
-                            messages.removeAt(messageCount - 1).copy(
-                                content = lastAssistantMsg.append(token).toString()
-                            ).let { messages.add(it) }
                             withContext(Dispatchers.Main) {
+                                val messageCount = messages.size
+                                check(messageCount > 0 && !messages[messageCount - 1].isUser)
+
+                                messages.removeAt(messageCount - 1).copy(
+                                    content = lastAssistantMsg.append(token).toString()
+                                ).let { messages.add(it) }
                                 messageAdapter.notifyItemChanged(messages.size - 1)
                             }
                         }
                 }
@@ -195,6 +202,7 @@ class MainActivity : AppCompatActivity() {
     /**
      * Run a benchmark with the model file
      */
+    @Deprecated("This benchmark doesn't reflect the GUI performance app developers should expect")
     private suspend fun runBenchmark(modelName: String, modelFile: File) =
         withContext(Dispatchers.Default) {
             Log.i(TAG, "Starts benchmarking $modelName")
@@ -223,6 +231,16 @@ class MainActivity : AppCompatActivity() {
         if (!it.exists()) { it.mkdir() }
     }
 
+    override fun onStop() {
+        generationJob?.cancel()
+        super.onStop()
+    }
+
+    override fun onDestroy() {
+        engine.destroy()
+        super.onDestroy()
+    }
+
     companion object {
         private val TAG = MainActivity::class.java.simpleName
 
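Note on the MainActivity changes: keeping a handle on the collection Job is what lets onStop() cancel an in-flight generation, while onCompletion re-enables the input controls even on that cancellation path. A minimal, non-Android sketch of the same pattern, where fakeTokenStream() is a hypothetical stand-in for engine.sendUserPrompt():

```kotlin
import kotlinx.coroutines.*
import kotlinx.coroutines.flow.*

// Hypothetical stand-in for InferenceEngine.sendUserPrompt(): emits tokens with a delay.
fun fakeTokenStream(): Flow<String> = flow {
    for (token in listOf("Hello", ",", " ", "world", "!")) {
        delay(100)
        emit(token)
    }
}

fun main() = runBlocking {
    var inputEnabled = false

    // Keep a handle on the Job, mirroring `generationJob` in MainActivity
    val generationJob = launch(Dispatchers.Default) {
        fakeTokenStream()
            .onCompletion { inputEnabled = true } // also runs when the Job is cancelled
            .collect { token -> print(token) }
    }

    delay(250)
    generationJob.cancel() // what onStop() now does mid-generation
    generationJob.join()
    println("\ninput re-enabled: $inputEnabled")
}
```

Cancelling just this Job, rather than the whole lifecycleScope, leaves any other UI coroutines running.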
diff --git a/examples/llama.android/app/src/main/res/layout/activity_main.xml b/examples/llama.android/app/src/main/res/layout/activity_main.xml
index ad805a674e..d15772bd37 100644
--- a/examples/llama.android/app/src/main/res/layout/activity_main.xml
+++ b/examples/llama.android/app/src/main/res/layout/activity_main.xml
@@ -24,7 +24,7 @@
     <TextView
         android:id="@+id/gguf"
         android:layout_width="match_parent"
         android:layout_height="wrap_content"
-        android:layout_margin="16dp"
+        android:padding="16dp"
         android:text="Selected GGUF model's metadata will show here."
         style="@style/TextAppearance.MaterialComponents.Body2" />
 
@@ -33,8 +33,7 @@
     <androidx.recyclerview.widget.RecyclerView
         …
+        android:layout_marginHorizontal="16dp" />

diff --git a/…/InferenceEngineImpl.kt b/…/InferenceEngineImpl.kt
@@ … @@ internal class InferenceEngineImpl private constructor(
     private val _state = MutableStateFlow<InferenceEngine.State>(InferenceEngine.State.Uninitialized)
-    override val state: StateFlow<InferenceEngine.State> = _state
+    override val state: StateFlow<InferenceEngine.State> = _state.asStateFlow()
 
     private var _readyForSystemPrompt = false
+    @Volatile
+    private var _cancelGeneration = false
 
     /**
      * Single-threaded coroutine dispatcher & scope for LLama asynchronous operations
@@ -169,6 +173,8 @@
             }
             Log.i(TAG, "Model loaded!")
             _readyForSystemPrompt = true
+
+            _cancelGeneration = false
             _state.value = InferenceEngine.State.ModelReady
         } catch (e: Exception) {
             Log.e(TAG, (e.message ?: "Error loading model") + "\n" + pathToModel, e)
@@ -231,15 +237,19 @@
             Log.i(TAG, "User prompt processed. Generating assistant prompt...")
             _state.value = InferenceEngine.State.Generating
-            while (true) {
+            while (!_cancelGeneration) {
                 generateNextToken()?.let { utf8token ->
                     if (utf8token.isNotEmpty()) emit(utf8token)
                 } ?: break
             }
-            Log.i(TAG, "Assistant generation complete. Awaiting user prompt...")
+            if (_cancelGeneration) {
+                Log.i(TAG, "Assistant generation aborted as requested.")
+            } else {
+                Log.i(TAG, "Assistant generation complete. Awaiting user prompt...")
+            }
             _state.value = InferenceEngine.State.ModelReady
         } catch (e: CancellationException) {
-            Log.i(TAG, "Generation cancelled by user.")
+            Log.i(TAG, "Assistant generation's flow collection cancelled.")
             _state.value = InferenceEngine.State.ModelReady
             throw e
         } catch (e: Exception) {
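Note on the generation loop: the @Volatile flag complements coroutine cancellation rather than replacing it. generateNextToken() is a blocking native call with no suspension points, so another thread flips the flag and the loop exits at the next token boundary. A minimal sketch of that cooperative-cancellation shape, with generateNextToken() here a hypothetical in-memory stand-in for the JNI call:

```kotlin
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.flow

class TokenGenerator {
    @Volatile
    var cancelRequested = false

    // Hypothetical stand-in for the blocking JNI call; null signals end of stream.
    private val tokens = ArrayDeque(listOf("The", " ", "end", "."))
    private fun generateNextToken(): String? = tokens.removeFirstOrNull()

    // Mirrors the loop above: the flag is checked between blocking calls.
    fun generate(): Flow<String> = flow {
        while (!cancelRequested) {
            generateNextToken()?.let { token ->
                if (token.isNotEmpty()) emit(token)
            } ?: break
        }
    }
}
```

Cancellation latency is therefore bounded by the cost of one generateNextToken() call; the CancellationException branch still covers callers that cancel the collecting coroutine directly.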
@@ -268,8 +278,9 @@ internal class InferenceEngineImpl private constructor(
     /**
      * Unloads the model and frees resources, or reset error states
      */
-    override suspend fun cleanUp() =
-        withContext(llamaDispatcher) {
+    override fun cleanUp() {
+        _cancelGeneration = true
+        runBlocking(llamaDispatcher) {
             when (val state = _state.value) {
                 is InferenceEngine.State.ModelReady -> {
                     Log.i(TAG, "Unloading model and free resources...")
@@ -293,17 +304,21 @@ internal class InferenceEngineImpl private constructor(
                 else -> throw IllegalStateException("Cannot unload model in ${state.javaClass.simpleName}")
             }
         }
+    }
 
     /**
      * Cancel all ongoing coroutines and free GGML backends
      */
     override fun destroy() {
-        _readyForSystemPrompt = false
-        llamaScope.cancel()
-        when(_state.value) {
-            is InferenceEngine.State.Uninitialized -> {}
-            is InferenceEngine.State.Initialized -> shutdown()
-            else -> { unload(); shutdown() }
+        _cancelGeneration = true
+        runBlocking(llamaDispatcher) {
+            _readyForSystemPrompt = false
+            when(_state.value) {
+                is InferenceEngine.State.Uninitialized -> {}
+                is InferenceEngine.State.Initialized -> shutdown()
+                else -> { unload(); shutdown() }
+            }
         }
+        llamaScope.cancel()
     }
 }
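Note on the teardown ordering: cleanUp() and destroy() now enqueue their native teardown onto the same single-threaded dispatcher that runs generation, so the order matters. The flag must be flipped before runBlocking; otherwise the busy generation loop would never yield the worker thread and runBlocking would wait forever. A runnable sketch of that ordering, with unload()/shutdown() as hypothetical stand-ins for the native teardown calls:

```kotlin
import kotlinx.coroutines.*
import java.util.concurrent.Executors

class EngineShutdownDemo {
    private val dispatcher = Executors.newSingleThreadExecutor().asCoroutineDispatcher()
    private val scope = CoroutineScope(dispatcher)

    @Volatile
    private var cancelGeneration = false

    // Hypothetical stand-ins for the native unload()/shutdown() calls
    private fun unload() = println("unload()")
    private fun shutdown() = println("shutdown()")

    fun startGeneration() = scope.launch {
        while (!cancelGeneration) {
            Thread.sleep(50) // pretend to produce a token via a blocking native call
        }
        println("generation loop exited")
    }

    // Mirrors destroy(): flag first, drain the worker thread, cancel the scope last
    fun destroy() {
        cancelGeneration = true
        runBlocking(dispatcher) { // queues behind the loop; runs once it exits
            unload()
            shutdown()
        }
        scope.cancel()
        dispatcher.close()
    }
}

fun main() {
    val demo = EngineShutdownDemo()
    demo.startGeneration()
    Thread.sleep(120)
    demo.destroy()
}
```

Running unload()/shutdown() inside runBlocking(llamaDispatcher) also moves them onto the llama worker thread itself, serializing them behind any in-flight native call instead of racing it from the caller's thread.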