bugfix: handle user quitting during model loading

This commit is contained in:
Han Yin 2025-04-12 22:21:50 -07:00
parent e8b84c6ebf
commit 6b341b0fbe
3 changed files with 45 additions and 21 deletions

View File

@ -22,7 +22,7 @@ import androidx.compose.runtime.remember
import androidx.compose.runtime.rememberCoroutineScope
import androidx.compose.runtime.setValue
import androidx.compose.ui.Modifier
import androidx.compose.ui.platform.LocalLifecycleOwner
import androidx.lifecycle.compose.LocalLifecycleOwner
import androidx.lifecycle.viewmodel.compose.viewModel
import androidx.navigation.compose.NavHost
import androidx.navigation.compose.composable
@ -42,6 +42,7 @@ import com.example.llama.revamp.ui.screens.SettingsGeneralScreen
import com.example.llama.revamp.ui.theme.LlamaTheme
import com.example.llama.revamp.util.ViewModelFactoryProvider
import com.example.llama.revamp.viewmodel.MainViewModel
import kotlinx.coroutines.isActive
import kotlinx.coroutines.launch
class MainActivity : ComponentActivity() {
@ -122,7 +123,10 @@ fun AppContent() {
// Helper function to handle back press with model unloading check
val handleBackWithModelCheck = {
if (viewModel.isModelLoaded()) {
if (viewModel.isModelLoading()) {
// If model is still loading, ignore the request
true // Mark as handled
} else if (viewModel.isModelLoaded()) {
showUnloadDialog = true
pendingNavigation = { navController.popBackStack() }
true // Mark as handled
@ -211,8 +215,19 @@ fun AppContent() {
navigationActions.navigateToBenchmark()
},
onConversationSelected = { systemPrompt ->
viewModel.prepareForConversation(systemPrompt)
navigationActions.navigateToConversation()
// Store a reference to the loading job
val loadingJob = coroutineScope.launch {
viewModel.prepareForConversation(systemPrompt)
// Check if the job wasn't cancelled before navigating
if (isActive) {
navigationActions.navigateToConversation()
}
}
// Update the pendingNavigation handler to cancel any ongoing loading
pendingNavigation = {
loadingJob.cancel()
navController.popBackStack()
}
},
onBackPressed = {
// Need to unload model before going back

View File

@ -57,7 +57,7 @@ class InferenceEngine {
_state.value = State.LoadingModel
// Simulate model loading
delay(1000)
delay(2000)
_state.value = State.ModelLoaded
@ -65,7 +65,7 @@ class InferenceEngine {
_state.value = State.ProcessingSystemPrompt
// Simulate processing system prompt
delay(500)
delay(3000)
}
_state.value = State.AwaitingUserPrompt
@ -127,7 +127,7 @@ class InferenceEngine {
try {
// Simulate benchmark running
delay(2000)
delay(4000)
// Generate fake benchmark results
val modelDesc = "LlamaModel"
@ -170,7 +170,7 @@ class InferenceEngine {
*/
suspend fun unloadModel() {
// Simulate model unloading time
delay(300)
delay(2000)
_state.value = State.LibraryLoaded
_benchmarkResults = null
_benchmarkResultsFlow.value = null

View File

@ -87,12 +87,10 @@ class MainViewModel(
/**
* Prepares the engine for conversation mode.
*/
fun prepareForConversation(systemPrompt: String? = null) {
suspend fun prepareForConversation(systemPrompt: String? = null) {
_systemPrompt.value = systemPrompt
viewModelScope.launch {
_selectedModel.value?.let { model ->
inferenceEngine.loadModel(model.path, systemPrompt)
}
_selectedModel.value?.let { model ->
inferenceEngine.loadModel(model.path, systemPrompt)
}
}
@ -131,7 +129,8 @@ class MainViewModel(
val currentMessages = _messages.value.toMutableList()
if (currentMessages.size >= 2) {
val messageIndex = currentMessages.size - 1
val currentAssistantMessage = currentMessages[messageIndex] as? Message.Assistant
val currentAssistantMessage =
currentMessages[messageIndex] as? Message.Assistant
if (currentAssistantMessage != null) {
currentMessages[messageIndex] = currentAssistantMessage.copy(
content = "${response}[Error: ${e.message}]",
@ -146,7 +145,8 @@ class MainViewModel(
val currentMessages = _messages.value.toMutableList()
if (currentMessages.isNotEmpty()) {
val messageIndex = currentMessages.size - 1
val currentAssistantMessage = currentMessages.getOrNull(messageIndex) as? Message.Assistant
val currentAssistantMessage =
currentMessages.getOrNull(messageIndex) as? Message.Assistant
if (currentAssistantMessage != null) {
currentMessages[messageIndex] = currentAssistantMessage.copy(
isComplete = true
@ -162,7 +162,8 @@ class MainViewModel(
val currentMessages = _messages.value.toMutableList()
if (currentMessages.isNotEmpty()) {
val messageIndex = currentMessages.size - 1
val currentAssistantMessage = currentMessages.getOrNull(messageIndex) as? Message.Assistant
val currentAssistantMessage =
currentMessages.getOrNull(messageIndex) as? Message.Assistant
if (currentAssistantMessage != null) {
currentMessages[messageIndex] = currentAssistantMessage.copy(
content = response.toString(),
@ -177,7 +178,8 @@ class MainViewModel(
val currentMessages = _messages.value.toMutableList()
if (currentMessages.isNotEmpty()) {
val messageIndex = currentMessages.size - 1
val currentAssistantMessage = currentMessages.getOrNull(messageIndex) as? Message.Assistant
val currentAssistantMessage =
currentMessages.getOrNull(messageIndex) as? Message.Assistant
if (currentAssistantMessage != null) {
currentMessages[messageIndex] = currentAssistantMessage.copy(
content = "${response}[Error: ${e.message}]",
@ -204,13 +206,20 @@ class MainViewModel(
inferenceEngine.unloadModel()
}
/**
 * Returns true while the engine is busy bringing a model up:
 * either loading the model itself or processing its system prompt.
 */
fun isModelLoading(): Boolean {
    // Snapshot the state once so both checks observe the same value.
    val state = engineState.value
    return state is InferenceEngine.State.LoadingModel ||
        state is InferenceEngine.State.ProcessingSystemPrompt
}
/**
* Checks if a model is currently loaded.
*/
fun isModelLoaded(): Boolean {
    // Read the StateFlow value exactly once: the original read
    // `engineState.value` twice, so a concurrent state transition could
    // make the two `!is` checks see different states and return an
    // inconsistent answer. A single snapshot makes the check atomic
    // with respect to state updates.
    val state = engineState.value
    return state !is InferenceEngine.State.Uninitialized &&
        state !is InferenceEngine.State.LibraryLoaded
}
/**
 * Returns true once a model has been loaded: any engine state past
 * library initialization counts as "loaded".
 */
fun isModelLoaded(): Boolean = when (engineState.value) {
    // Only the two pre-load states mean "no model loaded".
    is InferenceEngine.State.Uninitialized,
    is InferenceEngine.State.LibraryLoaded -> false
    else -> true
}
/**
* Clean up resources when ViewModel is cleared.