VM: handle the cancellation of ongoing token generation
This commit is contained in:
parent
46859c10f0
commit
9f77155535
|
|
@ -40,6 +40,7 @@ import com.example.llama.revamp.ui.screens.ModelSelectionScreen
|
|||
import com.example.llama.revamp.ui.screens.ModelsManagementScreen
|
||||
import com.example.llama.revamp.ui.screens.SettingsGeneralScreen
|
||||
import com.example.llama.revamp.ui.theme.LlamaTheme
|
||||
import com.example.llama.revamp.viewmodel.ConversationViewModel
|
||||
import com.example.llama.revamp.viewmodel.MainViewModel
|
||||
import dagger.hilt.android.AndroidEntryPoint
|
||||
import kotlinx.coroutines.isActive
|
||||
|
|
@ -64,18 +65,27 @@ class MainActivity : ComponentActivity() {
|
|||
|
||||
@Composable
|
||||
fun AppContent(
|
||||
mainVewModel: MainViewModel = hiltViewModel()
|
||||
mainViewModel: MainViewModel = hiltViewModel(),
|
||||
conversationViewModel: ConversationViewModel = hiltViewModel(),
|
||||
) {
|
||||
// Lifecycle and Coroutine scope
|
||||
val lifecycleOwner = LocalLifecycleOwner.current
|
||||
val coroutineScope = rememberCoroutineScope()
|
||||
|
||||
// LLM Inference engine status
|
||||
val engineState by mainVewModel.engineState.collectAsState()
|
||||
val isModelLoading = engineState is State.LoadingModel
|
||||
|| engineState is State.ProcessingSystemPrompt
|
||||
val isModelLoaded = engineState !is State.Uninitialized
|
||||
&& engineState !is State.LibraryLoaded
|
||||
// Inference engine state
|
||||
val engineState by mainViewModel.engineState.collectAsState()
|
||||
val isModelUninterruptible by remember(engineState) {
|
||||
derivedStateOf {
|
||||
engineState is State.LoadingModel
|
||||
|| engineState is State.Benchmarking
|
||||
|| engineState is State.ProcessingUserPrompt
|
||||
|| engineState is State.ProcessingSystemPrompt
|
||||
}
|
||||
}
|
||||
val isModelLoaded by remember(engineState) {
|
||||
derivedStateOf {
|
||||
engineState !is State.Uninitialized && engineState !is State.LibraryLoaded
|
||||
}
|
||||
}
|
||||
|
||||
// Navigation
|
||||
val navController = rememberNavController()
|
||||
|
|
@ -103,16 +113,20 @@ fun AppContent(
|
|||
// Model unloading confirmation
|
||||
var showUnloadDialog by remember { mutableStateOf(false) }
|
||||
val handleBackWithModelCheck = {
|
||||
if (isModelLoading) {
|
||||
// If model is still loading, ignore the request
|
||||
true // Mark as handled
|
||||
} else if (isModelLoaded) {
|
||||
showUnloadDialog = true
|
||||
pendingNavigation = { navigationActions.navigateUp() }
|
||||
true // Mark as handled
|
||||
} else {
|
||||
navigationActions.navigateUp()
|
||||
true // Mark as handled
|
||||
when {
|
||||
isModelUninterruptible -> {
|
||||
// The engine is in an uninterruptible state, so ignore the back request
|
||||
true // Mark as handled
|
||||
}
|
||||
isModelLoaded -> {
|
||||
showUnloadDialog = true
|
||||
pendingNavigation = { navigationActions.navigateUp() }
|
||||
true // Mark as handled
|
||||
}
|
||||
else -> {
|
||||
navigationActions.navigateUp()
|
||||
true // Mark as handled
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -216,16 +230,6 @@ fun AppContent(
|
|||
)
|
||||
}
|
||||
|
||||
// Conversation Screen
|
||||
composable(AppDestinations.CONVERSATION_ROUTE) {
|
||||
ConversationScreen(
|
||||
onBackPressed = {
|
||||
// Need to unload model before going back
|
||||
handleBackWithModelCheck()
|
||||
}
|
||||
)
|
||||
}
|
||||
|
||||
// Benchmark Screen
|
||||
composable(AppDestinations.BENCHMARK_ROUTE) {
|
||||
BenchmarkScreen(
|
||||
|
|
@ -236,6 +240,17 @@ fun AppContent(
|
|||
)
|
||||
}
|
||||
|
||||
// Conversation Screen
|
||||
composable(AppDestinations.CONVERSATION_ROUTE) {
|
||||
ConversationScreen(
|
||||
onBackPressed = {
|
||||
// Need to unload model before going back
|
||||
handleBackWithModelCheck()
|
||||
},
|
||||
viewModel = conversationViewModel
|
||||
)
|
||||
}
|
||||
|
||||
// Settings General Screen
|
||||
composable(AppDestinations.SETTINGS_GENERAL_ROUTE) {
|
||||
SettingsGeneralScreen(
|
||||
|
|
@ -261,7 +276,15 @@ fun AppContent(
|
|||
onConfirm = {
|
||||
isUnloading = true
|
||||
coroutineScope.launch {
|
||||
mainVewModel.unloadModel()
|
||||
// Handle screen specific cleanups
|
||||
when(engineState) {
|
||||
is State.Benchmarking -> {}
|
||||
is State.Generating -> conversationViewModel.clearConversation()
|
||||
else -> {}
|
||||
}
|
||||
|
||||
// Unload model
|
||||
mainViewModel.unloadModel()
|
||||
isUnloading = false
|
||||
showUnloadDialog = false
|
||||
pendingNavigation?.invoke()
|
||||
|
|
|
|||
|
|
@ -68,7 +68,7 @@ import kotlinx.coroutines.launch
|
|||
@Composable
|
||||
fun ConversationScreen(
|
||||
onBackPressed: () -> Unit,
|
||||
viewModel: ConversationViewModel = hiltViewModel()
|
||||
viewModel: ConversationViewModel
|
||||
) {
|
||||
val engineState by viewModel.engineState.collectAsState()
|
||||
val messages by viewModel.messages.collectAsState()
|
||||
|
|
|
|||
|
|
@ -3,12 +3,15 @@ package com.example.llama.revamp.viewmodel
|
|||
import androidx.lifecycle.ViewModel
|
||||
import androidx.lifecycle.viewModelScope
|
||||
import com.example.llama.revamp.engine.ConversationService
|
||||
import com.example.llama.revamp.engine.GenerationUpdate
|
||||
import com.example.llama.revamp.engine.TokenMetrics
|
||||
import dagger.hilt.android.lifecycle.HiltViewModel
|
||||
import kotlinx.coroutines.CancellationException
|
||||
import kotlinx.coroutines.Job
|
||||
import kotlinx.coroutines.flow.MutableStateFlow
|
||||
import kotlinx.coroutines.flow.StateFlow
|
||||
import kotlinx.coroutines.flow.asStateFlow
|
||||
import kotlinx.coroutines.flow.onCompletion
|
||||
import kotlinx.coroutines.launch
|
||||
import java.text.SimpleDateFormat
|
||||
import java.util.Date
|
||||
|
|
@ -29,18 +32,17 @@ class ConversationViewModel @Inject constructor(
|
|||
private val _messages = MutableStateFlow<List<Message>>(emptyList())
|
||||
val messages: StateFlow<List<Message>> = _messages.asStateFlow()
|
||||
|
||||
// Token generation job
|
||||
// Keep track of token generation job
|
||||
private var tokenCollectionJob: Job? = null
|
||||
|
||||
/**
|
||||
* Send a message with the provided content.
|
||||
* Note: This matches the existing UI which manages input state outside the ViewModel.
|
||||
* Send a message with the provided content
|
||||
*/
|
||||
fun sendMessage(content: String) {
|
||||
if (content.isBlank()) return
|
||||
|
||||
// Cancel ongoing collection
|
||||
tokenCollectionJob?.cancel()
|
||||
stopGeneration()
|
||||
|
||||
// Add user message
|
||||
val userMessage = Message.User(
|
||||
|
|
@ -60,39 +62,45 @@ class ConversationViewModel @Inject constructor(
|
|||
tokenCollectionJob = viewModelScope.launch {
|
||||
try {
|
||||
conversationService.generateResponse(content)
|
||||
.collect { (text, isComplete) ->
|
||||
updateAssistantMessage(text, isComplete)
|
||||
}
|
||||
.onCompletion { tokenCollectionJob = null }
|
||||
.collect(::updateAssistantMessage)
|
||||
|
||||
} catch (_: CancellationException) {
|
||||
handleCancellation()
|
||||
tokenCollectionJob = null
|
||||
|
||||
} catch (e: Exception) {
|
||||
// Handle error
|
||||
handleResponseError(e)
|
||||
tokenCollectionJob = null
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Stop ongoing generation.
 *
 * Cancels [tokenCollectionJob] when it is still active. Does nothing when no
 * job is tracked or the tracked job is no longer active.
 */
fun stopGeneration() {
    // No message cleanup here: the resulting cancellation is handled by the
    // catch blocks of the coroutine that collects the generated tokens.
    tokenCollectionJob?.takeIf { it.isActive }?.cancel()
}
|
||||
|
||||
/**
 * Handle the case when generation is explicitly cancelled.
 *
 * If the last message is still a [Message.Assistant.Ongoing], it is replaced
 * by a [Message.Assistant.Completed] that keeps the partial content, appends
 * a " [Generation stopped]" marker and attaches the current token metrics.
 * When the last message is anything else (or the list is empty), the message
 * list is left untouched.
 */
private fun handleCancellation() {
    val snapshot = _messages.value
    val interrupted = snapshot.lastOrNull() as? Message.Assistant.Ongoing ?: return

    // Replace the ongoing message with a completed one, noting the interruption
    _messages.value = snapshot.dropLast(1) + Message.Assistant.Completed(
        content = interrupted.content + " [Generation stopped]",
        timestamp = interrupted.timestamp,
        metrics = conversationService.createTokenMetrics()
    )
}
|
||||
|
|
@ -107,7 +115,7 @@ class ConversationViewModel @Inject constructor(
|
|||
|
||||
if (currentAssistantMessage != null) {
|
||||
currentMessages[lastIndex] = Message.Assistant.Completed(
|
||||
content = "${currentAssistantMessage.content}[Error: ${e.message}]",
|
||||
content = currentAssistantMessage.content + " [Error: ${e.message}]",
|
||||
timestamp = currentAssistantMessage.timestamp,
|
||||
metrics = conversationService.createTokenMetrics()
|
||||
)
|
||||
|
|
@ -115,16 +123,43 @@ class ConversationViewModel @Inject constructor(
|
|||
}
|
||||
}
|
||||
|
||||
/**
 * Apply a [GenerationUpdate] to the assistant message currently being written.
 *
 * Only acts when the last message in the list is a [Message.Assistant.Ongoing];
 * otherwise the update is dropped and the list is left unchanged.
 *
 * @param update the latest text together with a flag telling whether the
 *   generation has completed
 */
private fun updateAssistantMessage(update: GenerationUpdate) {
    val snapshot = _messages.value
    val ongoing = snapshot.lastOrNull() as? Message.Assistant.Ongoing ?: return

    val replacement: Message = if (update.isComplete) {
        // Final message: attach the token metrics
        Message.Assistant.Completed(
            content = update.text,
            timestamp = ongoing.timestamp,
            metrics = conversationService.createTokenMetrics()
        )
    } else {
        // Still streaming: refresh the text, keep the original timestamp
        Message.Assistant.Ongoing(
            content = update.text,
            timestamp = ongoing.timestamp
        )
    }

    _messages.value = snapshot.dropLast(1) + replacement
}
|
||||
|
||||
/**
 * Clear conversation.
 *
 * Stops any generation that is still running, then replaces the message list
 * with an empty one.
 */
fun clearConversation() {
    // stopGeneration() already cancels the active token collection job, so a
    // separate tokenCollectionJob?.cancel() call is not needed.
    stopGeneration()
    _messages.value = emptyList()
}
|
||||
|
||||
/**
 * Stops any generation that is still running before the ViewModel is cleared.
 */
override fun onCleared() {
    // stopGeneration() already cancels the active token collection job, so a
    // separate tokenCollectionJob?.cancel() call is not needed.
    stopGeneration()
    super.onCleared()
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue