bugfix: properly handle the user quitting the conversation screen while tokens are being generated
This commit is contained in:
parent 4848bf93d0
commit 75c986afc5
@@ -1,9 +1,11 @@
 package com.example.llama.revamp.engine
 
+import kotlinx.coroutines.CancellationException
 import kotlinx.coroutines.delay
 import kotlinx.coroutines.flow.Flow
 import kotlinx.coroutines.flow.MutableStateFlow
 import kotlinx.coroutines.flow.StateFlow
+import kotlinx.coroutines.flow.catch
 import kotlinx.coroutines.flow.flow
 
 /**
@@ -39,7 +41,8 @@ class InferenceEngine {
 
     // Keep track of current benchmark results
     private var _benchmarkResults: String? = null
-    val benchmarkResults: StateFlow<String?> = MutableStateFlow(_benchmarkResults)
+    private val _benchmarkResultsFlow = MutableStateFlow<String?>(null)
+    val benchmarkResults: StateFlow<String?> = _benchmarkResultsFlow
 
     init {
         // Simulate library loading
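The removed line exposed a MutableStateFlow through the read-only benchmarkResults property and later mutated it by casting the public property back to MutableStateFlow; the new code keeps a private backing flow and exposes it read-only. A minimal sketch of that backing-property idiom, using illustrative names rather than anything from this repository:

import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.flow.StateFlow
import kotlinx.coroutines.flow.asStateFlow

class ResultsHolder {
    // Private, mutable source of truth: only this class can emit new values.
    private val _results = MutableStateFlow<String?>(null)

    // Public, read-only view: collectors cannot mutate it or rely on unsafe casts.
    val results: StateFlow<String?> = _results.asStateFlow()

    fun publish(text: String) { _results.value = text }

    fun clear() { _results.value = null }
}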
@@ -66,6 +69,9 @@ class InferenceEngine {
             }
 
             _state.value = State.AwaitingUserPrompt
+        } catch (e: CancellationException) {
+            // If coroutine is cancelled, propagate cancellation
+            throw e
         } catch (e: Exception) {
             _state.value = State.Error(e.message ?: "Unknown error during model loading")
         }
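The new catch branch is ordered before the generic handler because CancellationException is itself an Exception: without a dedicated branch that rethrows it, cancelling the coroutine that loads the model would be swallowed and reported as an error state. A standalone sketch of the ordering, with a hypothetical loader in place of the engine's real one:

import kotlinx.coroutines.CancellationException
import kotlinx.coroutines.delay

// Hypothetical loader, used only to illustrate why the cancellation branch comes first.
suspend fun loadSomething(onError: (String) -> Unit) {
    try {
        delay(1_000) // stands in for real loading work
    } catch (e: CancellationException) {
        // Rethrow so the parent coroutine actually observes the cancellation.
        throw e
    } catch (e: Exception) {
        // Only genuine failures reach this branch.
        onError(e.message ?: "Unknown error")
    }
}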
@@ -79,21 +85,35 @@ class InferenceEngine {
 
         // This would be replaced with actual token generation logic
         return flow {
-            delay(500) // Simulate processing time
-
-            _state.value = State.Generating
-
-            // Simulate token generation
-            val response =
-                "This is a simulated response from the LLM model. The actual implementation would generate tokens one by one based on the input: $message"
-            val words = response.split(" ")
-
-            for (word in words) {
-                emit(word + " ")
-                delay(50) // Simulate token generation delay
-            }
-
-            _state.value = State.AwaitingUserPrompt
+            try {
+                delay(500) // Simulate processing time
+
+                _state.value = State.Generating
+
+                // Simulate token generation
+                val response = "This is a simulated response from the LLM model. The actual implementation would generate tokens one by one based on the input: $message"
+                val words = response.split(" ")
+
+                for (word in words) {
+                    emit(word + " ")
+                    delay(50) // Simulate token generation delay
+                }
+
+                _state.value = State.AwaitingUserPrompt
+            } catch (e: CancellationException) {
+                // Handle cancellation gracefully
+                _state.value = State.AwaitingUserPrompt
+                throw e
+            } catch (e: Exception) {
+                _state.value = State.Error(e.message ?: "Unknown error during generation")
+                throw e
+            }
         }.catch { e ->
             // If it's not a cancellation, update state to error
             if (e !is CancellationException) {
                 _state.value = State.Error(e.message ?: "Unknown error during generation")
             }
             throw e
         }
     }
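With the flow body wrapped in try/catch, cancelling the collector mid-generation should return the engine to AwaitingUserPrompt instead of leaving it stuck in Generating. A rough usage sketch against the code in this diff, assuming the engine exposes its _state backing flow as a public state property and that a model is already loaded:

import kotlinx.coroutines.cancelAndJoin
import kotlinx.coroutines.delay
import kotlinx.coroutines.launch
import kotlinx.coroutines.runBlocking

fun main() = runBlocking {
    val engine = InferenceEngine()

    // Start collecting tokens, then cancel partway through,
    // as happens when the user leaves the conversation screen.
    val job = launch {
        engine.sendUserPrompt("Hello").collect { token -> print(token) }
    }
    delay(200)
    job.cancelAndJoin()

    // The CancellationException branch should have restored the idle state.
    println(engine.state.value) // expected: State.AwaitingUserPrompt
}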
@@ -128,11 +148,15 @@ class InferenceEngine {
             result.append("$backend | tg $tg | $tg_avg ± $tg_std |\n")
 
             _benchmarkResults = result.toString()
-            (benchmarkResults as MutableStateFlow).value = _benchmarkResults
+            _benchmarkResultsFlow.value = _benchmarkResults
 
             _state.value = State.AwaitingUserPrompt
 
             return _benchmarkResults ?: ""
+        } catch (e: CancellationException) {
+            // If coroutine is cancelled, propagate cancellation
+            _state.value = State.AwaitingUserPrompt
+            throw e
         } catch (e: Exception) {
             _state.value = State.Error(e.message ?: "Unknown error during benchmarking")
             return "Error: ${e.message}"
@@ -147,7 +171,7 @@ class InferenceEngine {
         delay(300)
         _state.value = State.LibraryLoaded
         _benchmarkResults = null
-        (benchmarkResults as MutableStateFlow).value = null
+        _benchmarkResultsFlow.value = null
     }
 
     /**
@@ -5,9 +5,12 @@ import androidx.lifecycle.ViewModelProvider
 import androidx.lifecycle.viewModelScope
 import com.example.llama.revamp.data.model.ModelInfo
 import com.example.llama.revamp.engine.InferenceEngine
+import kotlinx.coroutines.Job
 import kotlinx.coroutines.flow.MutableStateFlow
 import kotlinx.coroutines.flow.StateFlow
 import kotlinx.coroutines.flow.asStateFlow
+import kotlinx.coroutines.flow.catch
+import kotlinx.coroutines.flow.onCompletion
 import kotlinx.coroutines.launch
 import java.text.SimpleDateFormat
 import java.util.Date
@@ -43,6 +46,9 @@ class MainViewModel(
     private val _systemPrompt = MutableStateFlow<String?>(null)
     val systemPrompt: StateFlow<String?> = _systemPrompt.asStateFlow()
 
+    // Flag to track if token collection is active
+    private var tokenCollectionJob: Job? = null
+
     /**
      * Selects a model for use.
      */
@@ -96,6 +102,9 @@ class MainViewModel(
     fun sendMessage(content: String) {
         if (content.isBlank()) return
 
+        // Cancel any ongoing token collection
+        tokenCollectionJob?.cancel()
+
         // Add user message
         val userMessage = Message.User(
             content = content,
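Cancelling tokenCollectionJob before starting a new collection keeps two generations from racing to update the same assistant message, and it is the same handle unloadModel() cancels when the user quits the screen. The pattern reduced to a standalone sketch (names are illustrative):

import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Job
import kotlinx.coroutines.delay
import kotlinx.coroutines.launch

class TokenCollector(private val scope: CoroutineScope) {
    private var job: Job? = null

    fun start(prompt: String) {
        // Stop any in-flight collection before launching a new one.
        job?.cancel()
        job = scope.launch {
            repeat(10) { i ->
                delay(50) // stands in for consuming one token
                println("[$prompt] token $i")
            }
        }
    }

    fun stop() {
        job?.cancel()
    }
}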
@@ -112,31 +121,72 @@ class MainViewModel(
         _messages.value = _messages.value + assistantMessage
 
         // Get response from engine
-        val messageIndex = _messages.value.size - 1
-
-        viewModelScope.launch {
+        tokenCollectionJob = viewModelScope.launch {
             val response = StringBuilder()
 
-            inferenceEngine.sendUserPrompt(content).collect { token ->
-                response.append(token)
-
-                // Update the assistant message with the generated text
-                val currentMessages = _messages.value.toMutableList()
-                val currentAssistantMessage = currentMessages[messageIndex] as Message.Assistant
-                currentMessages[messageIndex] = currentAssistantMessage.copy(
-                    content = response.toString(),
-                    isComplete = false
-                )
-                _messages.value = currentMessages
-            }
-
-            // Mark message as complete when generation finishes
-            val finalMessages = _messages.value.toMutableList()
-            val finalAssistantMessage = finalMessages[messageIndex] as Message.Assistant
-            finalMessages[messageIndex] = finalAssistantMessage.copy(
-                isComplete = true
-            )
-            _messages.value = finalMessages
+            try {
+                inferenceEngine.sendUserPrompt(content)
+                    .catch { e ->
+                        // Handle errors during token collection
+                        val currentMessages = _messages.value.toMutableList()
+                        if (currentMessages.size >= 2) {
+                            val messageIndex = currentMessages.size - 1
+                            val currentAssistantMessage = currentMessages[messageIndex] as? Message.Assistant
+                            if (currentAssistantMessage != null) {
+                                currentMessages[messageIndex] = currentAssistantMessage.copy(
+                                    content = "${response}[Error: ${e.message}]",
+                                    isComplete = true
+                                )
+                                _messages.value = currentMessages
+                            }
+                        }
+                    }
+                    .onCompletion { cause ->
+                        // Handle completion (normal or cancelled)
+                        val currentMessages = _messages.value.toMutableList()
+                        if (currentMessages.isNotEmpty()) {
+                            val messageIndex = currentMessages.size - 1
+                            val currentAssistantMessage = currentMessages.getOrNull(messageIndex) as? Message.Assistant
+                            if (currentAssistantMessage != null) {
+                                currentMessages[messageIndex] = currentAssistantMessage.copy(
+                                    isComplete = true
+                                )
+                                _messages.value = currentMessages
+                            }
+                        }
+                    }
+                    .collect { token ->
+                        response.append(token)
+
+                        // Safely update the assistant message with the generated text
+                        val currentMessages = _messages.value.toMutableList()
+                        if (currentMessages.isNotEmpty()) {
+                            val messageIndex = currentMessages.size - 1
+                            val currentAssistantMessage = currentMessages.getOrNull(messageIndex) as? Message.Assistant
+                            if (currentAssistantMessage != null) {
+                                currentMessages[messageIndex] = currentAssistantMessage.copy(
+                                    content = response.toString(),
+                                    isComplete = false
+                                )
+                                _messages.value = currentMessages
+                            }
+                        }
+                    }
+            } catch (e: Exception) {
+                // Handle any unexpected exceptions
+                val currentMessages = _messages.value.toMutableList()
+                if (currentMessages.isNotEmpty()) {
+                    val messageIndex = currentMessages.size - 1
+                    val currentAssistantMessage = currentMessages.getOrNull(messageIndex) as? Message.Assistant
+                    if (currentAssistantMessage != null) {
+                        currentMessages[messageIndex] = currentAssistantMessage.copy(
+                            content = "${response}[Error: ${e.message}]",
+                            isComplete = true
+                        )
+                        _messages.value = currentMessages
+                    }
+                }
+            }
         }
     }
 
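onCompletion is what lets the view model mark the assistant message complete even when the user quits mid-generation: unlike code placed after collect, it also runs when collection fails or is cancelled, with the cause passed as its argument. A small self-contained illustration:

import kotlinx.coroutines.cancelAndJoin
import kotlinx.coroutines.delay
import kotlinx.coroutines.flow.flow
import kotlinx.coroutines.flow.onCompletion
import kotlinx.coroutines.launch
import kotlinx.coroutines.runBlocking

fun main() = runBlocking {
    val tokens = flow {
        repeat(100) { i ->
            emit("token$i ")
            delay(50)
        }
    }

    val job = launch {
        tokens
            .onCompletion { cause ->
                // cause is null on normal completion, a CancellationException when cancelled.
                println()
                println("completed, cause=$cause")
            }
            .collect { print(it) }
    }

    delay(120)
    job.cancelAndJoin() // onCompletion still fires, so the UI can finalize the message
}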
@@ -144,8 +194,14 @@ class MainViewModel(
      * Unloads the currently loaded model.
      */
     suspend fun unloadModel() {
-        inferenceEngine.unloadModel()
+        // Cancel any ongoing token collection
+        tokenCollectionJob?.cancel()
+
+        // Clear messages
+        _messages.value = emptyList()
+
+        // Unload model
+        inferenceEngine.unloadModel()
     }
 
     /**
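Since unloadModel() now cancels tokenCollectionJob before clearing messages and unloading, wiring it into the screen-exit path is enough to stop generation when the user leaves. A hedged sketch of such a caller; the exit handler and scope are hypothetical and not part of this diff:

import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.launch

// Hypothetical exit handler, shown only to illustrate the call order on leaving the screen.
fun onConversationScreenExit(scope: CoroutineScope, viewModel: MainViewModel) {
    scope.launch {
        // Cancels token collection, clears messages, then unloads the model.
        viewModel.unloadModel()
    }
}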