bugfix: properly handle the user quitting the conversation screen while tokens are being generated

Han Yin 2025-04-12 12:36:11 -07:00
parent 4848bf93d0
commit 75c986afc5
2 changed files with 115 additions and 35 deletions
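The fix relies on standard coroutine cancellation: when the collecting coroutine is cancelled (for example because the user leaves the conversation screen), the suspension points inside the flow builder throw CancellationException, and onCompletion receives it as the cause. Both files below build on that mechanism. A minimal, self-contained sketch of it, using only plain kotlinx.coroutines (not project code):

import kotlinx.coroutines.*
import kotlinx.coroutines.flow.*

fun main() = runBlocking {
    val tokens = flow {
        try {
            repeat(100) {
                emit("token$it ")
                delay(50)
            }
        } catch (e: CancellationException) {
            // Runs when the collector's coroutine is cancelled mid-generation;
            // rethrow so cancellation still propagates.
            println("flow cancelled, cleaning up")
            throw e
        }
    }

    val job = launch {
        tokens
            .onCompletion { cause -> println("completed, cause=$cause") }
            .collect { print(it) }
    }

    delay(200)   // let a few tokens arrive
    job.cancel() // simulate the user quitting the screen
    job.join()
}

Catching CancellationException only to rethrow it lets the producer reset its own state without swallowing the cancellation.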

InferenceEngine.kt

@@ -1,9 +1,11 @@
 package com.example.llama.revamp.engine
 
+import kotlinx.coroutines.CancellationException
 import kotlinx.coroutines.delay
 import kotlinx.coroutines.flow.Flow
 import kotlinx.coroutines.flow.MutableStateFlow
 import kotlinx.coroutines.flow.StateFlow
+import kotlinx.coroutines.flow.catch
 import kotlinx.coroutines.flow.flow
 
 /**
@@ -39,7 +41,8 @@ class InferenceEngine {
     // Keep track of current benchmark results
     private var _benchmarkResults: String? = null
-    val benchmarkResults: StateFlow<String?> = MutableStateFlow(_benchmarkResults)
+    private val _benchmarkResultsFlow = MutableStateFlow<String?>(null)
+    val benchmarkResults: StateFlow<String?> = _benchmarkResultsFlow
 
     init {
         // Simulate library loading
@@ -66,6 +69,9 @@ class InferenceEngine {
             }
 
             _state.value = State.AwaitingUserPrompt
+        } catch (e: CancellationException) {
+            // If coroutine is cancelled, propagate cancellation
+            throw e
         } catch (e: Exception) {
             _state.value = State.Error(e.message ?: "Unknown error during model loading")
         }
@@ -79,13 +85,13 @@ class InferenceEngine {
         // This would be replaced with actual token generation logic
         return flow {
+            try {
                 delay(500) // Simulate processing time
                 _state.value = State.Generating
 
                 // Simulate token generation
-            val response =
-                "This is a simulated response from the LLM model. The actual implementation would generate tokens one by one based on the input: $message"
+                val response = "This is a simulated response from the LLM model. The actual implementation would generate tokens one by one based on the input: $message"
 
                 val words = response.split(" ")
                 for (word in words) {
@@ -94,6 +100,20 @@ class InferenceEngine {
                 }
 
                 _state.value = State.AwaitingUserPrompt
+            } catch (e: CancellationException) {
+                // Handle cancellation gracefully
+                _state.value = State.AwaitingUserPrompt
+                throw e
+            } catch (e: Exception) {
+                _state.value = State.Error(e.message ?: "Unknown error during generation")
+                throw e
+            }
+        }.catch { e ->
+            // If it's not a cancellation, update state to error
+            if (e !is CancellationException) {
+                _state.value = State.Error(e.message ?: "Unknown error during generation")
+            }
+            throw e
         }
     }
@@ -128,11 +148,15 @@ class InferenceEngine {
             result.append("$backend | tg $tg | $tg_avg ± $tg_std |\n")
 
             _benchmarkResults = result.toString()
-            (benchmarkResults as MutableStateFlow).value = _benchmarkResults
+            _benchmarkResultsFlow.value = _benchmarkResults
 
             _state.value = State.AwaitingUserPrompt
             return _benchmarkResults ?: ""
+        } catch (e: CancellationException) {
+            // If coroutine is cancelled, propagate cancellation
+            _state.value = State.AwaitingUserPrompt
+            throw e
         } catch (e: Exception) {
             _state.value = State.Error(e.message ?: "Unknown error during benchmarking")
             return "Error: ${e.message}"
@@ -147,7 +171,7 @@ class InferenceEngine {
         delay(300)
         _state.value = State.LibraryLoaded
         _benchmarkResults = null
-        (benchmarkResults as MutableStateFlow).value = null
+        _benchmarkResultsFlow.value = null
     }
 
     /**
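One change above is unrelated to cancellation: the benchmark results flow is no longer mutated by down-casting the public StateFlow; the engine now keeps a private MutableStateFlow backing property and exposes it read-only. A small generic sketch of that pattern (illustrative names, not project code):

import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.flow.StateFlow

class ResultsHolder {
    // Private mutable state; only this class can write to it.
    private val _results = MutableStateFlow<String?>(null)

    // Public read-only view; callers can collect but never cast-and-mutate.
    val results: StateFlow<String?> = _results

    fun publish(value: String) {
        _results.value = value
    }
}

The removed `(benchmarkResults as MutableStateFlow).value = ...` cast compiled, but it depended on the property's concrete type and left the public flow open to the same cast by any caller.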

MainViewModel.kt

@@ -5,9 +5,12 @@ import androidx.lifecycle.ViewModelProvider
 import androidx.lifecycle.viewModelScope
 import com.example.llama.revamp.data.model.ModelInfo
 import com.example.llama.revamp.engine.InferenceEngine
+import kotlinx.coroutines.Job
 import kotlinx.coroutines.flow.MutableStateFlow
 import kotlinx.coroutines.flow.StateFlow
 import kotlinx.coroutines.flow.asStateFlow
+import kotlinx.coroutines.flow.catch
+import kotlinx.coroutines.flow.onCompletion
 import kotlinx.coroutines.launch
 import java.text.SimpleDateFormat
 import java.util.Date
@@ -43,6 +46,9 @@ class MainViewModel(
     private val _systemPrompt = MutableStateFlow<String?>(null)
     val systemPrompt: StateFlow<String?> = _systemPrompt.asStateFlow()
 
+    // Flag to track if token collection is active
+    private var tokenCollectionJob: Job? = null
+
     /**
      * Selects a model for use.
      */
@@ -96,6 +102,9 @@ class MainViewModel(
     fun sendMessage(content: String) {
         if (content.isBlank()) return
 
+        // Cancel any ongoing token collection
+        tokenCollectionJob?.cancel()
+
         // Add user message
         val userMessage = Message.User(
             content = content,
@@ -112,31 +121,72 @@ class MainViewModel(
         _messages.value = _messages.value + assistantMessage
 
         // Get response from engine
-        val messageIndex = _messages.value.size - 1
-        viewModelScope.launch {
+        tokenCollectionJob = viewModelScope.launch {
             val response = StringBuilder()
-            inferenceEngine.sendUserPrompt(content).collect { token ->
+            try {
+                inferenceEngine.sendUserPrompt(content)
+                    .catch { e ->
+                        // Handle errors during token collection
+                        val currentMessages = _messages.value.toMutableList()
+                        if (currentMessages.size >= 2) {
+                            val messageIndex = currentMessages.size - 1
+                            val currentAssistantMessage = currentMessages[messageIndex] as? Message.Assistant
+                            if (currentAssistantMessage != null) {
+                                currentMessages[messageIndex] = currentAssistantMessage.copy(
+                                    content = "${response}[Error: ${e.message}]",
+                                    isComplete = true
+                                )
+                                _messages.value = currentMessages
+                            }
+                        }
+                    }
+                    .onCompletion { cause ->
+                        // Handle completion (normal or cancelled)
+                        val currentMessages = _messages.value.toMutableList()
+                        if (currentMessages.isNotEmpty()) {
+                            val messageIndex = currentMessages.size - 1
+                            val currentAssistantMessage = currentMessages.getOrNull(messageIndex) as? Message.Assistant
+                            if (currentAssistantMessage != null) {
+                                currentMessages[messageIndex] = currentAssistantMessage.copy(
+                                    isComplete = true
+                                )
+                                _messages.value = currentMessages
+                            }
+                        }
+                    }
+                    .collect { token ->
                         response.append(token)
 
-                // Update the assistant message with the generated text
+                        // Safely update the assistant message with the generated text
                         val currentMessages = _messages.value.toMutableList()
-                val currentAssistantMessage = currentMessages[messageIndex] as Message.Assistant
+                        if (currentMessages.isNotEmpty()) {
+                            val messageIndex = currentMessages.size - 1
+                            val currentAssistantMessage = currentMessages.getOrNull(messageIndex) as? Message.Assistant
+                            if (currentAssistantMessage != null) {
                                 currentMessages[messageIndex] = currentAssistantMessage.copy(
                                     content = response.toString(),
                                     isComplete = false
                                 )
                                 _messages.value = currentMessages
                             }
+                        }
+                    }
-            // Mark message as complete when generation finishes
-            val finalMessages = _messages.value.toMutableList()
-            val finalAssistantMessage = finalMessages[messageIndex] as Message.Assistant
-            finalMessages[messageIndex] = finalAssistantMessage.copy(
+            } catch (e: Exception) {
+                // Handle any unexpected exceptions
+                val currentMessages = _messages.value.toMutableList()
+                if (currentMessages.isNotEmpty()) {
+                    val messageIndex = currentMessages.size - 1
+                    val currentAssistantMessage = currentMessages.getOrNull(messageIndex) as? Message.Assistant
+                    if (currentAssistantMessage != null) {
+                        currentMessages[messageIndex] = currentAssistantMessage.copy(
+                            content = "${response}[Error: ${e.message}]",
                             isComplete = true
                         )
-            _messages.value = finalMessages
+                        _messages.value = currentMessages
+                    }
+                }
+            }
         }
     }
@@ -144,8 +194,14 @@ class MainViewModel(
      * Unloads the currently loaded model.
      */
     suspend fun unloadModel() {
-        inferenceEngine.unloadModel()
+        // Cancel any ongoing token collection
+        tokenCollectionJob?.cancel()
+
+        // Clear messages
         _messages.value = emptyList()
+
+        // Unload model
+        inferenceEngine.unloadModel()
     }
 
     /**
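MainViewModel now keeps the collection in a tracked Job so the same cancellation path covers both cases: a new prompt replaces the in-flight one in sendMessage(), and unloadModel() cancels it when the user abandons the conversation. A generic sketch of that job-replacement pattern (illustrative class and function names, not project code):

import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Job
import kotlinx.coroutines.flow.*
import kotlinx.coroutines.launch

class StreamingController(private val scope: CoroutineScope) {
    private var collectionJob: Job? = null

    fun start(tokens: Flow<String>, onToken: (String) -> Unit, onDone: () -> Unit) {
        // Replace any in-flight collection before starting a new one.
        collectionJob?.cancel()
        collectionJob = scope.launch {
            tokens
                .onCompletion { onDone() } // runs on normal completion and on cancellation
                .collect { onToken(it) }
        }
    }

    // Called when the user leaves the screen or the model is unloaded.
    fun stop() {
        collectionJob?.cancel()
    }
}

Because onCompletion also runs on cancellation, the partially generated assistant message is still marked complete instead of being left in a perpetual "typing" state.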