bugfix: handle user quitting on model loading
This commit is contained in:
parent
e8b84c6ebf
commit
6b341b0fbe
|
|
@ -22,7 +22,7 @@ import androidx.compose.runtime.remember
|
||||||
import androidx.compose.runtime.rememberCoroutineScope
|
import androidx.compose.runtime.rememberCoroutineScope
|
||||||
import androidx.compose.runtime.setValue
|
import androidx.compose.runtime.setValue
|
||||||
import androidx.compose.ui.Modifier
|
import androidx.compose.ui.Modifier
|
||||||
import androidx.compose.ui.platform.LocalLifecycleOwner
|
import androidx.lifecycle.compose.LocalLifecycleOwner
|
||||||
import androidx.lifecycle.viewmodel.compose.viewModel
|
import androidx.lifecycle.viewmodel.compose.viewModel
|
||||||
import androidx.navigation.compose.NavHost
|
import androidx.navigation.compose.NavHost
|
||||||
import androidx.navigation.compose.composable
|
import androidx.navigation.compose.composable
|
||||||
|
|
@ -42,6 +42,7 @@ import com.example.llama.revamp.ui.screens.SettingsGeneralScreen
|
||||||
import com.example.llama.revamp.ui.theme.LlamaTheme
|
import com.example.llama.revamp.ui.theme.LlamaTheme
|
||||||
import com.example.llama.revamp.util.ViewModelFactoryProvider
|
import com.example.llama.revamp.util.ViewModelFactoryProvider
|
||||||
import com.example.llama.revamp.viewmodel.MainViewModel
|
import com.example.llama.revamp.viewmodel.MainViewModel
|
||||||
|
import kotlinx.coroutines.isActive
|
||||||
import kotlinx.coroutines.launch
|
import kotlinx.coroutines.launch
|
||||||
|
|
||||||
class MainActivity : ComponentActivity() {
|
class MainActivity : ComponentActivity() {
|
||||||
|
|
@ -122,7 +123,10 @@ fun AppContent() {
|
||||||
|
|
||||||
// Helper function to handle back press with model unloading check
|
// Helper function to handle back press with model unloading check
|
||||||
val handleBackWithModelCheck = {
|
val handleBackWithModelCheck = {
|
||||||
if (viewModel.isModelLoaded()) {
|
if (viewModel.isModelLoading()) {
|
||||||
|
// If model is still loading, ignore the request
|
||||||
|
true // Mark as handled
|
||||||
|
} else if (viewModel.isModelLoaded()) {
|
||||||
showUnloadDialog = true
|
showUnloadDialog = true
|
||||||
pendingNavigation = { navController.popBackStack() }
|
pendingNavigation = { navController.popBackStack() }
|
||||||
true // Mark as handled
|
true // Mark as handled
|
||||||
|
|
@ -211,8 +215,19 @@ fun AppContent() {
|
||||||
navigationActions.navigateToBenchmark()
|
navigationActions.navigateToBenchmark()
|
||||||
},
|
},
|
||||||
onConversationSelected = { systemPrompt ->
|
onConversationSelected = { systemPrompt ->
|
||||||
viewModel.prepareForConversation(systemPrompt)
|
// Store a reference to the loading job
|
||||||
navigationActions.navigateToConversation()
|
val loadingJob = coroutineScope.launch {
|
||||||
|
viewModel.prepareForConversation(systemPrompt)
|
||||||
|
// Check if the job wasn't cancelled before navigating
|
||||||
|
if (isActive) {
|
||||||
|
navigationActions.navigateToConversation()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Update the pendingNavigation handler to cancel any ongoing loading
|
||||||
|
pendingNavigation = {
|
||||||
|
loadingJob.cancel()
|
||||||
|
navController.popBackStack()
|
||||||
|
}
|
||||||
},
|
},
|
||||||
onBackPressed = {
|
onBackPressed = {
|
||||||
// Need to unload model before going back
|
// Need to unload model before going back
|
||||||
|
|
|
||||||
|
|
@ -57,7 +57,7 @@ class InferenceEngine {
|
||||||
_state.value = State.LoadingModel
|
_state.value = State.LoadingModel
|
||||||
|
|
||||||
// Simulate model loading
|
// Simulate model loading
|
||||||
delay(1000)
|
delay(2000)
|
||||||
|
|
||||||
_state.value = State.ModelLoaded
|
_state.value = State.ModelLoaded
|
||||||
|
|
||||||
|
|
@ -65,7 +65,7 @@ class InferenceEngine {
|
||||||
_state.value = State.ProcessingSystemPrompt
|
_state.value = State.ProcessingSystemPrompt
|
||||||
|
|
||||||
// Simulate processing system prompt
|
// Simulate processing system prompt
|
||||||
delay(500)
|
delay(3000)
|
||||||
}
|
}
|
||||||
|
|
||||||
_state.value = State.AwaitingUserPrompt
|
_state.value = State.AwaitingUserPrompt
|
||||||
|
|
@ -127,7 +127,7 @@ class InferenceEngine {
|
||||||
|
|
||||||
try {
|
try {
|
||||||
// Simulate benchmark running
|
// Simulate benchmark running
|
||||||
delay(2000)
|
delay(4000)
|
||||||
|
|
||||||
// Generate fake benchmark results
|
// Generate fake benchmark results
|
||||||
val modelDesc = "LlamaModel"
|
val modelDesc = "LlamaModel"
|
||||||
|
|
@ -170,7 +170,7 @@ class InferenceEngine {
|
||||||
*/
|
*/
|
||||||
suspend fun unloadModel() {
|
suspend fun unloadModel() {
|
||||||
// Simulate model unloading time
|
// Simulate model unloading time
|
||||||
delay(300)
|
delay(2000)
|
||||||
_state.value = State.LibraryLoaded
|
_state.value = State.LibraryLoaded
|
||||||
_benchmarkResults = null
|
_benchmarkResults = null
|
||||||
_benchmarkResultsFlow.value = null
|
_benchmarkResultsFlow.value = null
|
||||||
|
|
|
||||||
|
|
@ -87,12 +87,10 @@ class MainViewModel(
|
||||||
/**
|
/**
|
||||||
* Prepares the engine for conversation mode.
|
* Prepares the engine for conversation mode.
|
||||||
*/
|
*/
|
||||||
fun prepareForConversation(systemPrompt: String? = null) {
|
suspend fun prepareForConversation(systemPrompt: String? = null) {
|
||||||
_systemPrompt.value = systemPrompt
|
_systemPrompt.value = systemPrompt
|
||||||
viewModelScope.launch {
|
_selectedModel.value?.let { model ->
|
||||||
_selectedModel.value?.let { model ->
|
inferenceEngine.loadModel(model.path, systemPrompt)
|
||||||
inferenceEngine.loadModel(model.path, systemPrompt)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -131,7 +129,8 @@ class MainViewModel(
|
||||||
val currentMessages = _messages.value.toMutableList()
|
val currentMessages = _messages.value.toMutableList()
|
||||||
if (currentMessages.size >= 2) {
|
if (currentMessages.size >= 2) {
|
||||||
val messageIndex = currentMessages.size - 1
|
val messageIndex = currentMessages.size - 1
|
||||||
val currentAssistantMessage = currentMessages[messageIndex] as? Message.Assistant
|
val currentAssistantMessage =
|
||||||
|
currentMessages[messageIndex] as? Message.Assistant
|
||||||
if (currentAssistantMessage != null) {
|
if (currentAssistantMessage != null) {
|
||||||
currentMessages[messageIndex] = currentAssistantMessage.copy(
|
currentMessages[messageIndex] = currentAssistantMessage.copy(
|
||||||
content = "${response}[Error: ${e.message}]",
|
content = "${response}[Error: ${e.message}]",
|
||||||
|
|
@ -146,7 +145,8 @@ class MainViewModel(
|
||||||
val currentMessages = _messages.value.toMutableList()
|
val currentMessages = _messages.value.toMutableList()
|
||||||
if (currentMessages.isNotEmpty()) {
|
if (currentMessages.isNotEmpty()) {
|
||||||
val messageIndex = currentMessages.size - 1
|
val messageIndex = currentMessages.size - 1
|
||||||
val currentAssistantMessage = currentMessages.getOrNull(messageIndex) as? Message.Assistant
|
val currentAssistantMessage =
|
||||||
|
currentMessages.getOrNull(messageIndex) as? Message.Assistant
|
||||||
if (currentAssistantMessage != null) {
|
if (currentAssistantMessage != null) {
|
||||||
currentMessages[messageIndex] = currentAssistantMessage.copy(
|
currentMessages[messageIndex] = currentAssistantMessage.copy(
|
||||||
isComplete = true
|
isComplete = true
|
||||||
|
|
@ -162,7 +162,8 @@ class MainViewModel(
|
||||||
val currentMessages = _messages.value.toMutableList()
|
val currentMessages = _messages.value.toMutableList()
|
||||||
if (currentMessages.isNotEmpty()) {
|
if (currentMessages.isNotEmpty()) {
|
||||||
val messageIndex = currentMessages.size - 1
|
val messageIndex = currentMessages.size - 1
|
||||||
val currentAssistantMessage = currentMessages.getOrNull(messageIndex) as? Message.Assistant
|
val currentAssistantMessage =
|
||||||
|
currentMessages.getOrNull(messageIndex) as? Message.Assistant
|
||||||
if (currentAssistantMessage != null) {
|
if (currentAssistantMessage != null) {
|
||||||
currentMessages[messageIndex] = currentAssistantMessage.copy(
|
currentMessages[messageIndex] = currentAssistantMessage.copy(
|
||||||
content = response.toString(),
|
content = response.toString(),
|
||||||
|
|
@ -177,7 +178,8 @@ class MainViewModel(
|
||||||
val currentMessages = _messages.value.toMutableList()
|
val currentMessages = _messages.value.toMutableList()
|
||||||
if (currentMessages.isNotEmpty()) {
|
if (currentMessages.isNotEmpty()) {
|
||||||
val messageIndex = currentMessages.size - 1
|
val messageIndex = currentMessages.size - 1
|
||||||
val currentAssistantMessage = currentMessages.getOrNull(messageIndex) as? Message.Assistant
|
val currentAssistantMessage =
|
||||||
|
currentMessages.getOrNull(messageIndex) as? Message.Assistant
|
||||||
if (currentAssistantMessage != null) {
|
if (currentAssistantMessage != null) {
|
||||||
currentMessages[messageIndex] = currentAssistantMessage.copy(
|
currentMessages[messageIndex] = currentAssistantMessage.copy(
|
||||||
content = "${response}[Error: ${e.message}]",
|
content = "${response}[Error: ${e.message}]",
|
||||||
|
|
@ -204,13 +206,20 @@ class MainViewModel(
|
||||||
inferenceEngine.unloadModel()
|
inferenceEngine.unloadModel()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fun isModelLoading() =
|
||||||
|
engineState.value.let {
|
||||||
|
it is InferenceEngine.State.LoadingModel
|
||||||
|
|| it is InferenceEngine.State.ProcessingSystemPrompt
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Checks if a model is currently loaded.
|
* Checks if a model is currently loaded.
|
||||||
*/
|
*/
|
||||||
fun isModelLoaded(): Boolean {
|
fun isModelLoaded() =
|
||||||
return engineState.value !is InferenceEngine.State.Uninitialized &&
|
engineState.value.let {
|
||||||
engineState.value !is InferenceEngine.State.LibraryLoaded
|
it !is InferenceEngine.State.Uninitialized
|
||||||
}
|
&& it !is InferenceEngine.State.LibraryLoaded
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Clean up resources when ViewModel is cleared.
|
* Clean up resources when ViewModel is cleared.
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue