bugfix: handle user quitting on model loading
This commit is contained in:
parent
e8b84c6ebf
commit
6b341b0fbe
|
|
@ -22,7 +22,7 @@ import androidx.compose.runtime.remember
|
|||
import androidx.compose.runtime.rememberCoroutineScope
|
||||
import androidx.compose.runtime.setValue
|
||||
import androidx.compose.ui.Modifier
|
||||
import androidx.compose.ui.platform.LocalLifecycleOwner
|
||||
import androidx.lifecycle.compose.LocalLifecycleOwner
|
||||
import androidx.lifecycle.viewmodel.compose.viewModel
|
||||
import androidx.navigation.compose.NavHost
|
||||
import androidx.navigation.compose.composable
|
||||
|
|
@ -42,6 +42,7 @@ import com.example.llama.revamp.ui.screens.SettingsGeneralScreen
|
|||
import com.example.llama.revamp.ui.theme.LlamaTheme
|
||||
import com.example.llama.revamp.util.ViewModelFactoryProvider
|
||||
import com.example.llama.revamp.viewmodel.MainViewModel
|
||||
import kotlinx.coroutines.isActive
|
||||
import kotlinx.coroutines.launch
|
||||
|
||||
class MainActivity : ComponentActivity() {
|
||||
|
|
@ -122,7 +123,10 @@ fun AppContent() {
|
|||
|
||||
// Helper function to handle back press with model unloading check
|
||||
val handleBackWithModelCheck = {
|
||||
if (viewModel.isModelLoaded()) {
|
||||
if (viewModel.isModelLoading()) {
|
||||
// If model is still loading, ignore the request
|
||||
true // Mark as handled
|
||||
} else if (viewModel.isModelLoaded()) {
|
||||
showUnloadDialog = true
|
||||
pendingNavigation = { navController.popBackStack() }
|
||||
true // Mark as handled
|
||||
|
|
@ -211,8 +215,19 @@ fun AppContent() {
|
|||
navigationActions.navigateToBenchmark()
|
||||
},
|
||||
onConversationSelected = { systemPrompt ->
|
||||
viewModel.prepareForConversation(systemPrompt)
|
||||
navigationActions.navigateToConversation()
|
||||
// Store a reference to the loading job
|
||||
val loadingJob = coroutineScope.launch {
|
||||
viewModel.prepareForConversation(systemPrompt)
|
||||
// Check if the job wasn't cancelled before navigating
|
||||
if (isActive) {
|
||||
navigationActions.navigateToConversation()
|
||||
}
|
||||
}
|
||||
// Update the pendingNavigation handler to cancel any ongoing loading
|
||||
pendingNavigation = {
|
||||
loadingJob.cancel()
|
||||
navController.popBackStack()
|
||||
}
|
||||
},
|
||||
onBackPressed = {
|
||||
// Need to unload model before going back
|
||||
|
|
|
|||
|
|
@ -57,7 +57,7 @@ class InferenceEngine {
|
|||
_state.value = State.LoadingModel
|
||||
|
||||
// Simulate model loading
|
||||
delay(1000)
|
||||
delay(2000)
|
||||
|
||||
_state.value = State.ModelLoaded
|
||||
|
||||
|
|
@ -65,7 +65,7 @@ class InferenceEngine {
|
|||
_state.value = State.ProcessingSystemPrompt
|
||||
|
||||
// Simulate processing system prompt
|
||||
delay(500)
|
||||
delay(3000)
|
||||
}
|
||||
|
||||
_state.value = State.AwaitingUserPrompt
|
||||
|
|
@ -127,7 +127,7 @@ class InferenceEngine {
|
|||
|
||||
try {
|
||||
// Simulate benchmark running
|
||||
delay(2000)
|
||||
delay(4000)
|
||||
|
||||
// Generate fake benchmark results
|
||||
val modelDesc = "LlamaModel"
|
||||
|
|
@ -170,7 +170,7 @@ class InferenceEngine {
|
|||
*/
|
||||
suspend fun unloadModel() {
|
||||
// Simulate model unloading time
|
||||
delay(300)
|
||||
delay(2000)
|
||||
_state.value = State.LibraryLoaded
|
||||
_benchmarkResults = null
|
||||
_benchmarkResultsFlow.value = null
|
||||
|
|
|
|||
|
|
@ -87,12 +87,10 @@ class MainViewModel(
|
|||
/**
|
||||
* Prepares the engine for conversation mode.
|
||||
*/
|
||||
fun prepareForConversation(systemPrompt: String? = null) {
|
||||
suspend fun prepareForConversation(systemPrompt: String? = null) {
|
||||
_systemPrompt.value = systemPrompt
|
||||
viewModelScope.launch {
|
||||
_selectedModel.value?.let { model ->
|
||||
inferenceEngine.loadModel(model.path, systemPrompt)
|
||||
}
|
||||
_selectedModel.value?.let { model ->
|
||||
inferenceEngine.loadModel(model.path, systemPrompt)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -131,7 +129,8 @@ class MainViewModel(
|
|||
val currentMessages = _messages.value.toMutableList()
|
||||
if (currentMessages.size >= 2) {
|
||||
val messageIndex = currentMessages.size - 1
|
||||
val currentAssistantMessage = currentMessages[messageIndex] as? Message.Assistant
|
||||
val currentAssistantMessage =
|
||||
currentMessages[messageIndex] as? Message.Assistant
|
||||
if (currentAssistantMessage != null) {
|
||||
currentMessages[messageIndex] = currentAssistantMessage.copy(
|
||||
content = "${response}[Error: ${e.message}]",
|
||||
|
|
@ -146,7 +145,8 @@ class MainViewModel(
|
|||
val currentMessages = _messages.value.toMutableList()
|
||||
if (currentMessages.isNotEmpty()) {
|
||||
val messageIndex = currentMessages.size - 1
|
||||
val currentAssistantMessage = currentMessages.getOrNull(messageIndex) as? Message.Assistant
|
||||
val currentAssistantMessage =
|
||||
currentMessages.getOrNull(messageIndex) as? Message.Assistant
|
||||
if (currentAssistantMessage != null) {
|
||||
currentMessages[messageIndex] = currentAssistantMessage.copy(
|
||||
isComplete = true
|
||||
|
|
@ -162,7 +162,8 @@ class MainViewModel(
|
|||
val currentMessages = _messages.value.toMutableList()
|
||||
if (currentMessages.isNotEmpty()) {
|
||||
val messageIndex = currentMessages.size - 1
|
||||
val currentAssistantMessage = currentMessages.getOrNull(messageIndex) as? Message.Assistant
|
||||
val currentAssistantMessage =
|
||||
currentMessages.getOrNull(messageIndex) as? Message.Assistant
|
||||
if (currentAssistantMessage != null) {
|
||||
currentMessages[messageIndex] = currentAssistantMessage.copy(
|
||||
content = response.toString(),
|
||||
|
|
@ -177,7 +178,8 @@ class MainViewModel(
|
|||
val currentMessages = _messages.value.toMutableList()
|
||||
if (currentMessages.isNotEmpty()) {
|
||||
val messageIndex = currentMessages.size - 1
|
||||
val currentAssistantMessage = currentMessages.getOrNull(messageIndex) as? Message.Assistant
|
||||
val currentAssistantMessage =
|
||||
currentMessages.getOrNull(messageIndex) as? Message.Assistant
|
||||
if (currentAssistantMessage != null) {
|
||||
currentMessages[messageIndex] = currentAssistantMessage.copy(
|
||||
content = "${response}[Error: ${e.message}]",
|
||||
|
|
@ -204,13 +206,20 @@ class MainViewModel(
|
|||
inferenceEngine.unloadModel()
|
||||
}
|
||||
|
||||
fun isModelLoading() =
|
||||
engineState.value.let {
|
||||
it is InferenceEngine.State.LoadingModel
|
||||
|| it is InferenceEngine.State.ProcessingSystemPrompt
|
||||
}
|
||||
|
||||
/**
|
||||
* Checks if a model is currently loaded.
|
||||
*/
|
||||
fun isModelLoaded(): Boolean {
|
||||
return engineState.value !is InferenceEngine.State.Uninitialized &&
|
||||
engineState.value !is InferenceEngine.State.LibraryLoaded
|
||||
}
|
||||
fun isModelLoaded() =
|
||||
engineState.value.let {
|
||||
it !is InferenceEngine.State.Uninitialized
|
||||
&& it !is InferenceEngine.State.LibraryLoaded
|
||||
}
|
||||
|
||||
/**
|
||||
* Clean up resources when ViewModel is cleared.
|
||||
|
|
|
|||
Loading…
Reference in New Issue