From 6b341b0fbe68bd6e64c29e4e6ebc440a45862129 Mon Sep 17 00:00:00 2001 From: Han Yin Date: Sat, 12 Apr 2025 22:21:50 -0700 Subject: [PATCH] bugfix: handle user quitting on model loading --- .../com/example/llama/revamp/MainActivity.kt | 23 +++++++++--- .../llama/revamp/engine/InferenceEngine.kt | 8 ++--- .../llama/revamp/viewmodel/MainViewModel.kt | 35 ++++++++++++------- 3 files changed, 45 insertions(+), 21 deletions(-) diff --git a/examples/llama.android/app/src/main/java/com/example/llama/revamp/MainActivity.kt b/examples/llama.android/app/src/main/java/com/example/llama/revamp/MainActivity.kt index a18a5d2efb..3142e5d977 100644 --- a/examples/llama.android/app/src/main/java/com/example/llama/revamp/MainActivity.kt +++ b/examples/llama.android/app/src/main/java/com/example/llama/revamp/MainActivity.kt @@ -22,7 +22,7 @@ import androidx.compose.runtime.remember import androidx.compose.runtime.rememberCoroutineScope import androidx.compose.runtime.setValue import androidx.compose.ui.Modifier -import androidx.compose.ui.platform.LocalLifecycleOwner +import androidx.lifecycle.compose.LocalLifecycleOwner import androidx.lifecycle.viewmodel.compose.viewModel import androidx.navigation.compose.NavHost import androidx.navigation.compose.composable @@ -42,6 +42,7 @@ import com.example.llama.revamp.ui.screens.SettingsGeneralScreen import com.example.llama.revamp.ui.theme.LlamaTheme import com.example.llama.revamp.util.ViewModelFactoryProvider import com.example.llama.revamp.viewmodel.MainViewModel +import kotlinx.coroutines.isActive import kotlinx.coroutines.launch class MainActivity : ComponentActivity() { @@ -122,7 +123,10 @@ fun AppContent() { // Helper function to handle back press with model unloading check val handleBackWithModelCheck = { - if (viewModel.isModelLoaded()) { + if (viewModel.isModelLoading()) { + // If model is still loading, ignore the request + true // Mark as handled + } else if (viewModel.isModelLoaded()) { showUnloadDialog = true pendingNavigation = { navController.popBackStack() } true // Mark as handled @@ -211,8 +215,19 @@ fun AppContent() { navigationActions.navigateToBenchmark() }, onConversationSelected = { systemPrompt -> - viewModel.prepareForConversation(systemPrompt) - navigationActions.navigateToConversation() + // Store a reference to the loading job + val loadingJob = coroutineScope.launch { + viewModel.prepareForConversation(systemPrompt) + // Check if the job wasn't cancelled before navigating + if (isActive) { + navigationActions.navigateToConversation() + } + } + // Update the pendingNavigation handler to cancel any ongoing loading + pendingNavigation = { + loadingJob.cancel() + navController.popBackStack() + } }, onBackPressed = { // Need to unload model before going back diff --git a/examples/llama.android/app/src/main/java/com/example/llama/revamp/engine/InferenceEngine.kt b/examples/llama.android/app/src/main/java/com/example/llama/revamp/engine/InferenceEngine.kt index 50ea4b7028..9c175d56d9 100644 --- a/examples/llama.android/app/src/main/java/com/example/llama/revamp/engine/InferenceEngine.kt +++ b/examples/llama.android/app/src/main/java/com/example/llama/revamp/engine/InferenceEngine.kt @@ -57,7 +57,7 @@ class InferenceEngine { _state.value = State.LoadingModel // Simulate model loading - delay(1000) + delay(2000) _state.value = State.ModelLoaded @@ -65,7 +65,7 @@ class InferenceEngine { _state.value = State.ProcessingSystemPrompt // Simulate processing system prompt - delay(500) + delay(3000) } _state.value = State.AwaitingUserPrompt @@ -127,7 +127,7 @@ class InferenceEngine { try { // Simulate benchmark running - delay(2000) + delay(4000) // Generate fake benchmark results val modelDesc = "LlamaModel" @@ -170,7 +170,7 @@ class InferenceEngine { */ suspend fun unloadModel() { // Simulate model unloading time - delay(300) + delay(2000) _state.value = State.LibraryLoaded _benchmarkResults = null _benchmarkResultsFlow.value = null diff --git a/examples/llama.android/app/src/main/java/com/example/llama/revamp/viewmodel/MainViewModel.kt b/examples/llama.android/app/src/main/java/com/example/llama/revamp/viewmodel/MainViewModel.kt index ea3798328a..8aae216c28 100644 --- a/examples/llama.android/app/src/main/java/com/example/llama/revamp/viewmodel/MainViewModel.kt +++ b/examples/llama.android/app/src/main/java/com/example/llama/revamp/viewmodel/MainViewModel.kt @@ -87,12 +87,10 @@ class MainViewModel( /** * Prepares the engine for conversation mode. */ - fun prepareForConversation(systemPrompt: String? = null) { + suspend fun prepareForConversation(systemPrompt: String? = null) { _systemPrompt.value = systemPrompt - viewModelScope.launch { - _selectedModel.value?.let { model -> - inferenceEngine.loadModel(model.path, systemPrompt) - } + _selectedModel.value?.let { model -> + inferenceEngine.loadModel(model.path, systemPrompt) } } @@ -131,7 +129,8 @@ class MainViewModel( val currentMessages = _messages.value.toMutableList() if (currentMessages.size >= 2) { val messageIndex = currentMessages.size - 1 - val currentAssistantMessage = currentMessages[messageIndex] as? Message.Assistant + val currentAssistantMessage = + currentMessages[messageIndex] as? Message.Assistant if (currentAssistantMessage != null) { currentMessages[messageIndex] = currentAssistantMessage.copy( content = "${response}[Error: ${e.message}]", @@ -146,7 +145,8 @@ class MainViewModel( val currentMessages = _messages.value.toMutableList() if (currentMessages.isNotEmpty()) { val messageIndex = currentMessages.size - 1 - val currentAssistantMessage = currentMessages.getOrNull(messageIndex) as? Message.Assistant + val currentAssistantMessage = + currentMessages.getOrNull(messageIndex) as? Message.Assistant if (currentAssistantMessage != null) { currentMessages[messageIndex] = currentAssistantMessage.copy( isComplete = true @@ -162,7 +162,8 @@ class MainViewModel( val currentMessages = _messages.value.toMutableList() if (currentMessages.isNotEmpty()) { val messageIndex = currentMessages.size - 1 - val currentAssistantMessage = currentMessages.getOrNull(messageIndex) as? Message.Assistant + val currentAssistantMessage = + currentMessages.getOrNull(messageIndex) as? Message.Assistant if (currentAssistantMessage != null) { currentMessages[messageIndex] = currentAssistantMessage.copy( content = response.toString(), @@ -177,7 +178,8 @@ class MainViewModel( val currentMessages = _messages.value.toMutableList() if (currentMessages.isNotEmpty()) { val messageIndex = currentMessages.size - 1 - val currentAssistantMessage = currentMessages.getOrNull(messageIndex) as? Message.Assistant + val currentAssistantMessage = + currentMessages.getOrNull(messageIndex) as? Message.Assistant if (currentAssistantMessage != null) { currentMessages[messageIndex] = currentAssistantMessage.copy( content = "${response}[Error: ${e.message}]", @@ -204,13 +206,20 @@ class MainViewModel( inferenceEngine.unloadModel() } + fun isModelLoading() = + engineState.value.let { + it is InferenceEngine.State.LoadingModel + || it is InferenceEngine.State.ProcessingSystemPrompt + } + /** * Checks if a model is currently loaded. */ - fun isModelLoaded(): Boolean { - return engineState.value !is InferenceEngine.State.Uninitialized && - engineState.value !is InferenceEngine.State.LibraryLoaded - } + fun isModelLoaded() = + engineState.value.let { + it !is InferenceEngine.State.Uninitialized + && it !is InferenceEngine.State.LibraryLoaded + } /** * Clean up resources when ViewModel is cleared.