bugfix: handle user quitting during model loading

Han Yin 2025-04-12 22:21:50 -07:00
parent e8b84c6ebf
commit 6b341b0fbe
3 changed files with 45 additions and 21 deletions

View File

@@ -22,7 +22,7 @@ import androidx.compose.runtime.remember
 import androidx.compose.runtime.rememberCoroutineScope
 import androidx.compose.runtime.setValue
 import androidx.compose.ui.Modifier
-import androidx.compose.ui.platform.LocalLifecycleOwner
+import androidx.lifecycle.compose.LocalLifecycleOwner
 import androidx.lifecycle.viewmodel.compose.viewModel
 import androidx.navigation.compose.NavHost
 import androidx.navigation.compose.composable
@@ -42,6 +42,7 @@ import com.example.llama.revamp.ui.screens.SettingsGeneralScreen
 import com.example.llama.revamp.ui.theme.LlamaTheme
 import com.example.llama.revamp.util.ViewModelFactoryProvider
 import com.example.llama.revamp.viewmodel.MainViewModel
+import kotlinx.coroutines.isActive
 import kotlinx.coroutines.launch

 class MainActivity : ComponentActivity() {
@@ -122,7 +123,10 @@ fun AppContent() {
     // Helper function to handle back press with model unloading check
     val handleBackWithModelCheck = {
-        if (viewModel.isModelLoaded()) {
+        if (viewModel.isModelLoading()) {
+            // If model is still loading, ignore the request
+            true // Mark as handled
+        } else if (viewModel.isModelLoaded()) {
             showUnloadDialog = true
             pendingNavigation = { navController.popBackStack() }
             true // Mark as handled
@@ -211,8 +215,19 @@ fun AppContent() {
             navigationActions.navigateToBenchmark()
         },
         onConversationSelected = { systemPrompt ->
-            viewModel.prepareForConversation(systemPrompt)
-            navigationActions.navigateToConversation()
+            // Store a reference to the loading job
+            val loadingJob = coroutineScope.launch {
+                viewModel.prepareForConversation(systemPrompt)
+                // Check if the job wasn't cancelled before navigating
+                if (isActive) {
+                    navigationActions.navigateToConversation()
+                }
+            }
+            // Update the pendingNavigation handler to cancel any ongoing loading
+            pendingNavigation = {
+                loadingJob.cancel()
+                navController.popBackStack()
+            }
         },
         onBackPressed = {
             // Need to unload model before going back
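
The MainActivity change ties navigation to the lifetime of the loading coroutine: the conversation screen is only entered if the job is still active, and the pending-navigation handler cancels the job before popping the back stack. A minimal standalone sketch of the same pattern (the names here are illustrative, not the app's actual API):

import kotlinx.coroutines.*

class LoadingNavigator(private val scope: CoroutineScope) {
    private var loadingJob: Job? = null

    // Launch the load; navigate forward only if nobody cancelled us meanwhile.
    fun startLoading(load: suspend () -> Unit, onLoaded: () -> Unit) {
        loadingJob = scope.launch {
            load()                    // cancellable suspension point
            if (isActive) onLoaded()  // skip navigation after a cancel
        }
    }

    // Back press during loading: abandon the load, then leave the screen.
    fun quit(onQuit: () -> Unit) {
        loadingJob?.cancel()
        onQuit()
    }
}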

View File

@@ -57,7 +57,7 @@ class InferenceEngine {
         _state.value = State.LoadingModel

         // Simulate model loading
-        delay(1000)
+        delay(2000)

         _state.value = State.ModelLoaded
@@ -65,7 +65,7 @@ class InferenceEngine {
             _state.value = State.ProcessingSystemPrompt
             // Simulate processing system prompt
-            delay(500)
+            delay(3000)
         }

         _state.value = State.AwaitingUserPrompt
@@ -127,7 +127,7 @@ class InferenceEngine {
         try {
             // Simulate benchmark running
-            delay(2000)
+            delay(4000)

             // Generate fake benchmark results
             val modelDesc = "LlamaModel"
@@ -170,7 +170,7 @@ class InferenceEngine {
      */
     suspend fun unloadModel() {
         // Simulate model unloading time
-        delay(300)
+        delay(2000)
         _state.value = State.LibraryLoaded
         _benchmarkResults = null
         _benchmarkResultsFlow.value = null
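
The InferenceEngine tweaks lengthen the simulated delays, widening the window in which a user can quit mid-load. They also make the cancellation path observable, because kotlinx.coroutines.delay is itself cancellable: when MainActivity cancels the loading job, the engine stops at its current delay with a CancellationException instead of running on to set State.ModelLoaded. A runnable sketch of that behavior (illustrative names only):

import kotlinx.coroutines.*

fun main() = runBlocking {
    val job = launch {
        try {
            println("loading model...")
            delay(2000)              // cancellable suspension point
            println("model loaded")  // not reached if cancelled in time
        } catch (e: CancellationException) {
            println("load cancelled mid-flight")
            throw e                  // rethrow so cancellation completes normally
        }
    }
    delay(500)    // user backs out while the model is still loading
    job.cancel()
    job.join()
}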

View File

@@ -87,12 +87,10 @@ class MainViewModel(
     /**
      * Prepares the engine for conversation mode.
      */
-    fun prepareForConversation(systemPrompt: String? = null) {
+    suspend fun prepareForConversation(systemPrompt: String? = null) {
         _systemPrompt.value = systemPrompt
-        viewModelScope.launch {
-            _selectedModel.value?.let { model ->
-                inferenceEngine.loadModel(model.path, systemPrompt)
-            }
-        }
+        _selectedModel.value?.let { model ->
+            inferenceEngine.loadModel(model.path, systemPrompt)
+        }
     }
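
Making prepareForConversation a suspend function, rather than letting it launch into viewModelScope internally, is what makes the cancellation above possible: the load now runs inside whichever job calls it, so loadingJob.cancel() in MainActivity propagates into inferenceEngine.loadModel. A sketch of the difference, with hypothetical names:

import kotlinx.coroutines.*

class PrepareSketch {
    // New shape: runs in the caller's job, so the caller can cancel it.
    suspend fun prepare() = delay(2000)

    // Old shape: runs in a scope the ViewModel owns; a caller's
    // cancellation never reaches this job.
    private val ownScope = CoroutineScope(SupervisorJob() + Dispatchers.Default)
    fun prepareFireAndForget() { ownScope.launch { delay(2000) } }
}

fun main() = runBlocking {
    val vm = PrepareSketch()
    val job = launch { vm.prepare() }
    delay(100)
    job.cancel()               // stops the delay inside prepare()
    vm.prepareFireAndForget()  // cancelling `job` could not stop this one
}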
@@ -131,7 +129,8 @@ class MainViewModel(
                 val currentMessages = _messages.value.toMutableList()
                 if (currentMessages.size >= 2) {
                     val messageIndex = currentMessages.size - 1
-                    val currentAssistantMessage = currentMessages[messageIndex] as? Message.Assistant
+                    val currentAssistantMessage =
+                        currentMessages[messageIndex] as? Message.Assistant
                     if (currentAssistantMessage != null) {
                         currentMessages[messageIndex] = currentAssistantMessage.copy(
                             content = "${response}[Error: ${e.message}]",
@@ -146,7 +145,8 @@ class MainViewModel(
                 val currentMessages = _messages.value.toMutableList()
                 if (currentMessages.isNotEmpty()) {
                     val messageIndex = currentMessages.size - 1
-                    val currentAssistantMessage = currentMessages.getOrNull(messageIndex) as? Message.Assistant
+                    val currentAssistantMessage =
+                        currentMessages.getOrNull(messageIndex) as? Message.Assistant
                     if (currentAssistantMessage != null) {
                         currentMessages[messageIndex] = currentAssistantMessage.copy(
                             isComplete = true
@@ -162,7 +162,8 @@ class MainViewModel(
                 val currentMessages = _messages.value.toMutableList()
                 if (currentMessages.isNotEmpty()) {
                     val messageIndex = currentMessages.size - 1
-                    val currentAssistantMessage = currentMessages.getOrNull(messageIndex) as? Message.Assistant
+                    val currentAssistantMessage =
+                        currentMessages.getOrNull(messageIndex) as? Message.Assistant
                     if (currentAssistantMessage != null) {
                         currentMessages[messageIndex] = currentAssistantMessage.copy(
                             content = response.toString(),
@@ -177,7 +178,8 @@ class MainViewModel(
                 val currentMessages = _messages.value.toMutableList()
                 if (currentMessages.isNotEmpty()) {
                     val messageIndex = currentMessages.size - 1
-                    val currentAssistantMessage = currentMessages.getOrNull(messageIndex) as? Message.Assistant
+                    val currentAssistantMessage =
+                        currentMessages.getOrNull(messageIndex) as? Message.Assistant
                     if (currentAssistantMessage != null) {
                         currentMessages[messageIndex] = currentAssistantMessage.copy(
                             content = "${response}[Error: ${e.message}]",
@@ -204,13 +206,20 @@ class MainViewModel(
         inferenceEngine.unloadModel()
     }

+    fun isModelLoading() =
+        engineState.value.let {
+            it is InferenceEngine.State.LoadingModel
+                || it is InferenceEngine.State.ProcessingSystemPrompt
+        }
+
     /**
      * Checks if a model is currently loaded.
      */
-    fun isModelLoaded(): Boolean {
-        return engineState.value !is InferenceEngine.State.Uninitialized &&
-            engineState.value !is InferenceEngine.State.LibraryLoaded
-    }
+    fun isModelLoaded() =
+        engineState.value.let {
+            it !is InferenceEngine.State.Uninitialized
+                && it !is InferenceEngine.State.LibraryLoaded
+        }

     /**
      * Clean up resources when ViewModel is cleared.
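
Both predicates read the engine's state machine; judging from the states this commit references, it is roughly the sealed hierarchy sketched below (an assumption — the real type may carry data or extra states, e.g. for benchmarking). Note that the two checks overlap: isModelLoaded() is true in every state past LibraryLoaded, including both loading states, which is why handleBackWithModelCheck asks isModelLoading() first.

// Assumed shape, reconstructed from the states referenced in this diff
// (nested as InferenceEngine.State in the real code).
sealed interface State {
    object Uninitialized : State
    object LibraryLoaded : State
    object LoadingModel : State
    object ProcessingSystemPrompt : State
    object ModelLoaded : State
    object AwaitingUserPrompt : State
}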