VM: handle the cancellation of ongoing token generation

This commit is contained in:
Han Yin 2025-04-16 15:22:30 -07:00
parent 46859c10f0
commit 9f77155535
3 changed files with 115 additions and 57 deletions

View File

@ -40,6 +40,7 @@ import com.example.llama.revamp.ui.screens.ModelSelectionScreen
import com.example.llama.revamp.ui.screens.ModelsManagementScreen
import com.example.llama.revamp.ui.screens.SettingsGeneralScreen
import com.example.llama.revamp.ui.theme.LlamaTheme
import com.example.llama.revamp.viewmodel.ConversationViewModel
import com.example.llama.revamp.viewmodel.MainViewModel
import dagger.hilt.android.AndroidEntryPoint
import kotlinx.coroutines.isActive
@ -64,18 +65,27 @@ class MainActivity : ComponentActivity() {
@Composable
fun AppContent(
mainVewModel: MainViewModel = hiltViewModel()
mainViewModel: MainViewModel = hiltViewModel(),
conversationViewModel: ConversationViewModel = hiltViewModel(),
) {
// Lifecycle and Coroutine scope
val lifecycleOwner = LocalLifecycleOwner.current
val coroutineScope = rememberCoroutineScope()
// LLM Inference engine status
val engineState by mainVewModel.engineState.collectAsState()
val isModelLoading = engineState is State.LoadingModel
|| engineState is State.ProcessingSystemPrompt
val isModelLoaded = engineState !is State.Uninitialized
&& engineState !is State.LibraryLoaded
// Inference engine state
val engineState by mainViewModel.engineState.collectAsState()
val isModelUninterruptible by remember(engineState) {
derivedStateOf {
engineState is State.LoadingModel
|| engineState is State.Benchmarking
|| engineState is State.ProcessingUserPrompt
|| engineState is State.ProcessingSystemPrompt
}
}
val isModelLoaded by remember(engineState) {
derivedStateOf {
engineState !is State.Uninitialized && engineState !is State.LibraryLoaded
}
}
// Navigation
val navController = rememberNavController()
@ -103,16 +113,20 @@ fun AppContent(
// Model unloading confirmation
var showUnloadDialog by remember { mutableStateOf(false) }
val handleBackWithModelCheck = {
if (isModelLoading) {
// If model is still loading, ignore the request
true // Mark as handled
} else if (isModelLoaded) {
showUnloadDialog = true
pendingNavigation = { navigationActions.navigateUp() }
true // Mark as handled
} else {
navigationActions.navigateUp()
true // Mark as handled
when {
isModelUninterruptible -> {
// If the model is in an uninterruptible state, ignore the request
true // Mark as handled
}
isModelLoaded -> {
showUnloadDialog = true
pendingNavigation = { navigationActions.navigateUp() }
true // Mark as handled
}
else -> {
navigationActions.navigateUp()
true // Mark as handled
}
}
}
@ -216,16 +230,6 @@ fun AppContent(
)
}
// Conversation Screen
composable(AppDestinations.CONVERSATION_ROUTE) {
ConversationScreen(
onBackPressed = {
// Need to unload model before going back
handleBackWithModelCheck()
}
)
}
// Benchmark Screen
composable(AppDestinations.BENCHMARK_ROUTE) {
BenchmarkScreen(
@ -236,6 +240,17 @@ fun AppContent(
)
}
// Conversation Screen
composable(AppDestinations.CONVERSATION_ROUTE) {
ConversationScreen(
onBackPressed = {
// Need to unload model before going back
handleBackWithModelCheck()
},
viewModel = conversationViewModel
)
}
// Settings General Screen
composable(AppDestinations.SETTINGS_GENERAL_ROUTE) {
SettingsGeneralScreen(
@ -261,7 +276,15 @@ fun AppContent(
onConfirm = {
isUnloading = true
coroutineScope.launch {
mainVewModel.unloadModel()
// Handle screen specific cleanups
when(engineState) {
is State.Benchmarking -> {}
is State.Generating -> conversationViewModel.clearConversation()
else -> {}
}
// Unload model
mainViewModel.unloadModel()
isUnloading = false
showUnloadDialog = false
pendingNavigation?.invoke()

View File

@ -68,7 +68,7 @@ import kotlinx.coroutines.launch
@Composable
fun ConversationScreen(
onBackPressed: () -> Unit,
viewModel: ConversationViewModel = hiltViewModel()
viewModel: ConversationViewModel
) {
val engineState by viewModel.engineState.collectAsState()
val messages by viewModel.messages.collectAsState()

View File

@ -3,12 +3,15 @@ package com.example.llama.revamp.viewmodel
import androidx.lifecycle.ViewModel
import androidx.lifecycle.viewModelScope
import com.example.llama.revamp.engine.ConversationService
import com.example.llama.revamp.engine.GenerationUpdate
import com.example.llama.revamp.engine.TokenMetrics
import dagger.hilt.android.lifecycle.HiltViewModel
import kotlinx.coroutines.CancellationException
import kotlinx.coroutines.Job
import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.flow.StateFlow
import kotlinx.coroutines.flow.asStateFlow
import kotlinx.coroutines.flow.onCompletion
import kotlinx.coroutines.launch
import java.text.SimpleDateFormat
import java.util.Date
@ -29,18 +32,17 @@ class ConversationViewModel @Inject constructor(
private val _messages = MutableStateFlow<List<Message>>(emptyList())
val messages: StateFlow<List<Message>> = _messages.asStateFlow()
// Token generation job
// Keep track of token generation job
private var tokenCollectionJob: Job? = null
/**
* Send a message with the provided content.
* Note: This matches the existing UI which manages input state outside the ViewModel.
* Send a message with the provided content
*/
fun sendMessage(content: String) {
if (content.isBlank()) return
// Cancel ongoing collection
tokenCollectionJob?.cancel()
stopGeneration()
// Add user message
val userMessage = Message.User(
@ -60,39 +62,45 @@ class ConversationViewModel @Inject constructor(
tokenCollectionJob = viewModelScope.launch {
try {
conversationService.generateResponse(content)
.collect { (text, isComplete) ->
updateAssistantMessage(text, isComplete)
}
.onCompletion { tokenCollectionJob = null }
.collect(::updateAssistantMessage)
} catch (_: CancellationException) {
handleCancellation()
tokenCollectionJob = null
} catch (e: Exception) {
// Handle error
handleResponseError(e)
tokenCollectionJob = null
}
}
}
/**
* Handle updating the assistant message
* Stop ongoing generation
*/
private fun updateAssistantMessage(text: String, isComplete: Boolean) {
/**
 * Cancel the in-flight token-generation job, if any.
 *
 * Cancelling the job raises a [CancellationException] inside the collector,
 * so message cleanup is performed by the catch block in [sendMessage] rather
 * than here. Safe to call when no generation is running.
 */
fun stopGeneration() {
    val job = tokenCollectionJob ?: return
    if (job.isActive) {
        job.cancel()
    }
}
/**
* Handle the case when generation is explicitly cancelled
*/
private fun handleCancellation() {
val currentMessages = _messages.value.toMutableList()
val lastIndex = currentMessages.size - 1
val currentAssistantMessage = currentMessages.getOrNull(lastIndex) as? Message.Assistant.Ongoing
if (currentAssistantMessage != null) {
if (isComplete) {
// Final message with metrics
currentMessages[lastIndex] = Message.Assistant.Completed(
content = text,
timestamp = currentAssistantMessage.timestamp,
metrics = conversationService.createTokenMetrics()
)
} else {
// Ongoing message update
currentMessages[lastIndex] = Message.Assistant.Ongoing(
content = text,
timestamp = currentAssistantMessage.timestamp
)
}
// Replace with completed message, adding note that it was interrupted
currentMessages[lastIndex] = Message.Assistant.Completed(
content = currentAssistantMessage.content + " [Generation stopped]",
timestamp = currentAssistantMessage.timestamp,
metrics = conversationService.createTokenMetrics()
)
_messages.value = currentMessages
}
}
@ -107,7 +115,7 @@ class ConversationViewModel @Inject constructor(
if (currentAssistantMessage != null) {
currentMessages[lastIndex] = Message.Assistant.Completed(
content = "${currentAssistantMessage.content}[Error: ${e.message}]",
content = currentAssistantMessage.content + " [Error: ${e.message}]",
timestamp = currentAssistantMessage.timestamp,
metrics = conversationService.createTokenMetrics()
)
@ -115,16 +123,43 @@ class ConversationViewModel @Inject constructor(
}
}
/**
* Handle updating the assistant message
*/
/**
 * Fold a [GenerationUpdate] into the message list by rewriting the trailing
 * assistant message.
 *
 * No-op unless the last message is an [Message.Assistant.Ongoing]: a complete
 * update promotes it to [Message.Assistant.Completed] with token metrics from
 * the service; a partial update just refreshes its content. The original
 * timestamp is kept in both cases.
 */
private fun updateAssistantMessage(update: GenerationUpdate) {
    val snapshot = _messages.value
    val ongoing = snapshot.lastOrNull() as? Message.Assistant.Ongoing ?: return
    val replacement = if (update.isComplete) {
        // Final message with metrics
        Message.Assistant.Completed(
            content = update.text,
            timestamp = ongoing.timestamp,
            metrics = conversationService.createTokenMetrics()
        )
    } else {
        // Ongoing message update
        Message.Assistant.Ongoing(
            content = update.text,
            timestamp = ongoing.timestamp
        )
    }
    _messages.value = snapshot.dropLast(1) + replacement
}
/**
* Clear conversation
*/
fun clearConversation() {
tokenCollectionJob?.cancel()
stopGeneration()
_messages.value = emptyList()
}
override fun onCleared() {
tokenCollectionJob?.cancel()
stopGeneration()
super.onCleared()
}
}