bugfix: properly handle the user quitting the conversation screen while tokens are being generated

Han Yin 2025-04-12 12:36:11 -07:00
parent 4848bf93d0
commit 75c986afc5
2 changed files with 115 additions and 35 deletions

View File

@ -1,9 +1,11 @@
package com.example.llama.revamp.engine
import kotlinx.coroutines.CancellationException
import kotlinx.coroutines.delay
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.flow.StateFlow
import kotlinx.coroutines.flow.catch
import kotlinx.coroutines.flow.flow
/**
@ -39,7 +41,8 @@ class InferenceEngine {
// Keep track of current benchmark results
private var _benchmarkResults: String? = null
val benchmarkResults: StateFlow<String?> = MutableStateFlow(_benchmarkResults)
private val _benchmarkResultsFlow = MutableStateFlow<String?>(null)
val benchmarkResults: StateFlow<String?> = _benchmarkResultsFlow
init {
// Simulate library loading
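The cast-based write, `(benchmarkResults as MutableStateFlow).value = ...`, is replaced above by a dedicated private backing flow. A minimal sketch of the backing-property pattern this hunk moves to (class and property names are illustrative, not from the codebase):

import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.flow.StateFlow
import kotlinx.coroutines.flow.asStateFlow

class ResultsHolder {
    // Private, mutable source of truth: only this class can write to it.
    private val _results = MutableStateFlow<String?>(null)

    // Public, read-only view: collectors cannot mutate it or cast it back.
    val results: StateFlow<String?> = _results.asStateFlow()

    fun publish(text: String) {
        _results.value = text // no unchecked cast needed
    }
}

Wrapping the backing flow with asStateFlow() is one option; the diff simply assigns it to a StateFlow-typed property, which gives the same encapsulation as long as callers never downcast.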
@ -66,6 +69,9 @@ class InferenceEngine {
}
_state.value = State.AwaitingUserPrompt
} catch (e: CancellationException) {
// If coroutine is cancelled, propagate cancellation
throw e
} catch (e: Exception) {
_state.value = State.Error(e.message ?: "Unknown error during model loading")
}
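CancellationException extends Exception, so a bare catch (e: Exception) would swallow cancellation and mark the engine as errored when the caller merely navigated away; rethrowing keeps structured concurrency intact. A minimal sketch of the pattern, with a hypothetical load lambda standing in for the real model-loading work:

import kotlinx.coroutines.CancellationException

suspend fun loadSafely(load: suspend () -> Unit) {
    try {
        load()
    } catch (e: CancellationException) {
        // Cancellation is not a failure: rethrow so the parent scope can
        // finish cancelling this coroutine normally.
        throw e
    } catch (e: Exception) {
        // Only genuine failures reach this branch.
        println("load failed: ${e.message}")
    }
}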
@ -79,21 +85,35 @@ class InferenceEngine {
// This would be replaced with actual token generation logic
return flow {
delay(500) // Simulate processing time
try {
delay(500) // Simulate processing time
_state.value = State.Generating
_state.value = State.Generating
// Simulate token generation
val response =
"This is a simulated response from the LLM model. The actual implementation would generate tokens one by one based on the input: $message"
val words = response.split(" ")
// Simulate token generation
val response = "This is a simulated response from the LLM model. The actual implementation would generate tokens one by one based on the input: $message"
val words = response.split(" ")
for (word in words) {
emit(word + " ")
delay(50) // Simulate token generation delay
for (word in words) {
emit(word + " ")
delay(50) // Simulate token generation delay
}
_state.value = State.AwaitingUserPrompt
} catch (e: CancellationException) {
// Handle cancellation gracefully
_state.value = State.AwaitingUserPrompt
throw e
} catch (e: Exception) {
_state.value = State.Error(e.message ?: "Unknown error during generation")
throw e
}
_state.value = State.AwaitingUserPrompt
}.catch { e ->
// If it's not a cancellation, update state to error
if (e !is CancellationException) {
_state.value = State.Error(e.message ?: "Unknown error during generation")
}
throw e
}
}
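When the collecting coroutine is cancelled (the user leaves the conversation screen), the next delay or emit inside the flow { } builder throws CancellationException; the builder's own catch can reset engine state before rethrowing, while the downstream .catch operator is transparent to cancellation and only sees genuine failures. A self-contained sketch of that behavior (token text and timings are invented):

import kotlinx.coroutines.*
import kotlinx.coroutines.flow.*

fun tokens(): Flow<String> = flow {
    try {
        repeat(100) { i ->
            emit("token$i ")
            delay(50)                  // cancellation surfaces at suspension points
        }
    } catch (e: CancellationException) {
        println("\ngeneration cancelled, state reset")
        throw e                        // must be propagated
    }
}.catch { e ->
    // The catch operator does not intercept the exception used to cancel
    // the collector, so this branch only handles real upstream failures.
    println("generation failed: ${e.message}")
}

fun main() = runBlocking {
    val job = launch { tokens().collect { print(it) } }
    delay(180)
    job.cancelAndJoin()                // simulates quitting mid-generation
}

This mirrors the split in the hunk above: the inner try/catch restores State.AwaitingUserPrompt on cancellation, and the .catch operator maps everything else to State.Error.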
@ -128,11 +148,15 @@ class InferenceEngine {
result.append("$backend | tg $tg | $tg_avg ± $tg_std |\n")
_benchmarkResults = result.toString()
(benchmarkResults as MutableStateFlow).value = _benchmarkResults
_benchmarkResultsFlow.value = _benchmarkResults
_state.value = State.AwaitingUserPrompt
return _benchmarkResults ?: ""
} catch (e: CancellationException) {
// If coroutine is cancelled, propagate cancellation
_state.value = State.AwaitingUserPrompt
throw e
} catch (e: Exception) {
_state.value = State.Error(e.message ?: "Unknown error during benchmarking")
return "Error: ${e.message}"
@ -147,7 +171,7 @@ class InferenceEngine {
delay(300)
_state.value = State.LibraryLoaded
_benchmarkResults = null
(benchmarkResults as MutableStateFlow).value = null
_benchmarkResultsFlow.value = null
}
/**

View File

@ -5,9 +5,12 @@ import androidx.lifecycle.ViewModelProvider
import androidx.lifecycle.viewModelScope
import com.example.llama.revamp.data.model.ModelInfo
import com.example.llama.revamp.engine.InferenceEngine
import kotlinx.coroutines.Job
import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.flow.StateFlow
import kotlinx.coroutines.flow.asStateFlow
import kotlinx.coroutines.flow.catch
import kotlinx.coroutines.flow.onCompletion
import kotlinx.coroutines.launch
import java.text.SimpleDateFormat
import java.util.Date
@ -43,6 +46,9 @@ class MainViewModel(
private val _systemPrompt = MutableStateFlow<String?>(null)
val systemPrompt: StateFlow<String?> = _systemPrompt.asStateFlow()
// Job for the in-flight token collection, kept so it can be cancelled
private var tokenCollectionJob: Job? = null
/**
* Selects a model for use.
*/
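Holding the collection coroutine in a nullable Job gives the ViewModel a handle it can cancel whenever a new prompt arrives or the model is unloaded. A minimal sketch of the pattern outside the ViewModel (names are illustrative):

import kotlinx.coroutines.*

class TokenCollector(private val scope: CoroutineScope) {
    private var job: Job? = null

    fun start(work: suspend () -> Unit) {
        job?.cancel()                 // at most one active collection
        job = scope.launch { work() }
    }

    fun stop() {
        job?.cancel()
        job = null
    }
}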
@ -96,6 +102,9 @@ class MainViewModel(
fun sendMessage(content: String) {
if (content.isBlank()) return
// Cancel any ongoing token collection
tokenCollectionJob?.cancel()
// Add user message
val userMessage = Message.User(
content = content,
@ -112,31 +121,72 @@ class MainViewModel(
_messages.value = _messages.value + assistantMessage
// Get response from engine
val messageIndex = _messages.value.size - 1
viewModelScope.launch {
tokenCollectionJob = viewModelScope.launch {
val response = StringBuilder()
inferenceEngine.sendUserPrompt(content).collect { token ->
response.append(token)
try {
inferenceEngine.sendUserPrompt(content)
.catch { e ->
// Handle errors during token collection
val currentMessages = _messages.value.toMutableList()
if (currentMessages.size >= 2) {
val messageIndex = currentMessages.size - 1
val currentAssistantMessage = currentMessages[messageIndex] as? Message.Assistant
if (currentAssistantMessage != null) {
currentMessages[messageIndex] = currentAssistantMessage.copy(
content = "${response}[Error: ${e.message}]",
isComplete = true
)
_messages.value = currentMessages
}
}
}
.onCompletion { cause ->
// Handle completion (normal or cancelled)
val currentMessages = _messages.value.toMutableList()
if (currentMessages.isNotEmpty()) {
val messageIndex = currentMessages.size - 1
val currentAssistantMessage = currentMessages.getOrNull(messageIndex) as? Message.Assistant
if (currentAssistantMessage != null) {
currentMessages[messageIndex] = currentAssistantMessage.copy(
isComplete = true
)
_messages.value = currentMessages
}
}
}
.collect { token ->
response.append(token)
// Update the assistant message with the generated text
// Safely update the assistant message with the generated text
val currentMessages = _messages.value.toMutableList()
if (currentMessages.isNotEmpty()) {
val messageIndex = currentMessages.size - 1
val currentAssistantMessage = currentMessages.getOrNull(messageIndex) as? Message.Assistant
if (currentAssistantMessage != null) {
currentMessages[messageIndex] = currentAssistantMessage.copy(
content = response.toString(),
isComplete = false
)
_messages.value = currentMessages
}
}
}
} catch (e: Exception) {
// Handle any unexpected exceptions
val currentMessages = _messages.value.toMutableList()
val currentAssistantMessage = currentMessages[messageIndex] as Message.Assistant
currentMessages[messageIndex] = currentAssistantMessage.copy(
content = response.toString(),
isComplete = false
)
_messages.value = currentMessages
if (currentMessages.isNotEmpty()) {
val messageIndex = currentMessages.size - 1
val currentAssistantMessage = currentMessages.getOrNull(messageIndex) as? Message.Assistant
if (currentAssistantMessage != null) {
currentMessages[messageIndex] = currentAssistantMessage.copy(
content = "${response}[Error: ${e.message}]",
isComplete = true
)
_messages.value = currentMessages
}
}
}
// Mark message as complete when generation finishes
val finalMessages = _messages.value.toMutableList()
val finalAssistantMessage = finalMessages[messageIndex] as Message.Assistant
finalMessages[messageIndex] = finalAssistantMessage.copy(
isComplete = true
)
_messages.value = finalMessages
}
}
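The interesting path here is cancellation: when tokenCollectionJob?.cancel() runs (new message, model unload, or the screen going away), collect stops at its next suspension point and onCompletion fires with a non-null cause, which is what lets the partially generated assistant message be marked complete instead of staying stuck in a streaming state. A self-contained sketch of that ordering (token contents and delays are invented):

import kotlinx.coroutines.*
import kotlinx.coroutines.flow.*

fun main() = runBlocking {
    val tokens = flow {
        repeat(10) { i ->
            emit("t$i ")
            delay(50)
        }
    }

    val job = launch {
        tokens
            .onCompletion { cause ->
                // cause == null on normal completion,
                // a CancellationException when the collector was cancelled.
                println("\nfinalize message, cause = $cause")
            }
            .collect { print(it) }
    }

    delay(120)
    job.cancelAndJoin()   // the user leaves the conversation screen
}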
@ -144,8 +194,14 @@ class MainViewModel(
* Unloads the currently loaded model.
*/
suspend fun unloadModel() {
inferenceEngine.unloadModel()
// Cancel any ongoing token collection
tokenCollectionJob?.cancel()
// Clear messages
_messages.value = emptyList()
// Unload model
inferenceEngine.unloadModel()
}
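cancel() only requests cancellation; if tear-down must never overlap with a collection that is still winding down, a variant could wait for the job before unloading. A hypothetical helper sketching that ordering:

import kotlinx.coroutines.Job
import kotlinx.coroutines.cancelAndJoin

// Hypothetical variant: ensure the collection coroutine has fully stopped
// before the engine tears the model down.
suspend fun shutdown(collection: Job?, unload: suspend () -> Unit) {
    collection?.cancelAndJoin()
    unload()
}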
/**