bugfix: properly handle the user quitting the conversation screen while tokens are being generated
This commit is contained in:
parent 4848bf93d0
commit 75c986afc5
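The diff below has two halves: the engine's simulated generation flow now treats CancellationException as a normal outcome (it resets its state and rethrows), and the view model keeps the token-collecting coroutine in a Job so it can be cancelled when the user quits the conversation screen or sends a new prompt. A minimal, self-contained sketch of that Job-tracking side (the fakeTokens flow and the timings are invented for illustration; in the real code the flow comes from InferenceEngine.sendUserPrompt):

```kotlin
import kotlinx.coroutines.*
import kotlinx.coroutines.flow.*

// Hypothetical stand-in for InferenceEngine.sendUserPrompt(): emits tokens with a delay.
fun fakeTokens(): Flow<String> = flow {
    repeat(100) { i ->
        emit("token$i ")
        delay(50)
    }
}

var tokenCollectionJob: Job? = null

fun main() = runBlocking {
    // Keep the collecting coroutine in a Job so it can be cancelled later.
    tokenCollectionJob = launch {
        fakeTokens()
            .onCompletion { cause -> println("\ncollection finished, cause=$cause") }
            .collect { token -> print(token) }
    }

    delay(300)                    // ...the user quits the conversation screen here...
    tokenCollectionJob?.cancel()  // cancelling the Job stops generation promptly
    tokenCollectionJob?.join()
}
```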
@@ -1,9 +1,11 @@
 package com.example.llama.revamp.engine
 
+import kotlinx.coroutines.CancellationException
 import kotlinx.coroutines.delay
 import kotlinx.coroutines.flow.Flow
 import kotlinx.coroutines.flow.MutableStateFlow
 import kotlinx.coroutines.flow.StateFlow
+import kotlinx.coroutines.flow.catch
 import kotlinx.coroutines.flow.flow
 
 /**
@@ -39,7 +41,8 @@ class InferenceEngine {
 
     // Keep track of current benchmark results
     private var _benchmarkResults: String? = null
-    val benchmarkResults: StateFlow<String?> = MutableStateFlow(_benchmarkResults)
+    private val _benchmarkResultsFlow = MutableStateFlow<String?>(null)
+    val benchmarkResults: StateFlow<String?> = _benchmarkResultsFlow
 
     init {
         // Simulate library loading
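The second hunk above also switches benchmarkResults to the usual private-backing-property convention, so later writes no longer need the unsafe downcast `(benchmarkResults as MutableStateFlow).value = ...` that is removed further down in this diff. A small sketch of the convention (the BenchmarkHolder and publish names are illustrative, not part of the commit):

```kotlin
import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.flow.StateFlow

class BenchmarkHolder {
    // The writable flow stays private to the class...
    private val _benchmarkResultsFlow = MutableStateFlow<String?>(null)

    // ...and only a read-only StateFlow is exposed, so callers cannot mutate it
    // and the class itself never needs a cast in order to write to it.
    val benchmarkResults: StateFlow<String?> = _benchmarkResultsFlow

    fun publish(results: String) {
        _benchmarkResultsFlow.value = results
    }
}
```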
@@ -66,6 +69,9 @@ class InferenceEngine {
             }
 
             _state.value = State.AwaitingUserPrompt
+        } catch (e: CancellationException) {
+            // If coroutine is cancelled, propagate cancellation
+            throw e
         } catch (e: Exception) {
             _state.value = State.Error(e.message ?: "Unknown error during model loading")
         }
@@ -79,21 +85,35 @@ class InferenceEngine {
 
         // This would be replaced with actual token generation logic
         return flow {
-            delay(500) // Simulate processing time
+            try {
+                delay(500) // Simulate processing time
 
-            _state.value = State.Generating
+                _state.value = State.Generating
 
-            // Simulate token generation
-            val response =
-                "This is a simulated response from the LLM model. The actual implementation would generate tokens one by one based on the input: $message"
-            val words = response.split(" ")
+                // Simulate token generation
+                val response = "This is a simulated response from the LLM model. The actual implementation would generate tokens one by one based on the input: $message"
+                val words = response.split(" ")
 
-            for (word in words) {
-                emit(word + " ")
-                delay(50) // Simulate token generation delay
-            }
+                for (word in words) {
+                    emit(word + " ")
+                    delay(50) // Simulate token generation delay
+                }
 
-            _state.value = State.AwaitingUserPrompt
-        }
+                _state.value = State.AwaitingUserPrompt
+            } catch (e: CancellationException) {
+                // Handle cancellation gracefully
+                _state.value = State.AwaitingUserPrompt
+                throw e
+            } catch (e: Exception) {
+                _state.value = State.Error(e.message ?: "Unknown error during generation")
+                throw e
+            }
+        }.catch { e ->
+            // If it's not a cancellation, update state to error
+            if (e !is CancellationException) {
+                _state.value = State.Error(e.message ?: "Unknown error during generation")
+            }
+            throw e
+        }
     }
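The reason the engine needs the CancellationException branches: when the collecting coroutine is cancelled (the view model cancels its Job as the user leaves the screen), the suspension points inside the flow builder, emit and delay, throw CancellationException. Catching it lets the engine put itself back into AwaitingUserPrompt instead of staying stuck in Generating, and rethrowing keeps the cancellation contract intact. A self-contained sketch of just that path (state machine reduced to two values; names simplified from the diff):

```kotlin
import kotlinx.coroutines.*
import kotlinx.coroutines.flow.*

sealed interface State {
    object Generating : State
    object AwaitingUserPrompt : State
}

var state: State = State.AwaitingUserPrompt

fun generate(): Flow<String> = flow {
    try {
        state = State.Generating
        repeat(100) { i ->
            emit("tok$i ")   // emit() and delay() are suspension points: they throw
            delay(50)        // CancellationException once the collector is cancelled
        }
        state = State.AwaitingUserPrompt
    } catch (e: CancellationException) {
        // Reset the state so the engine stays usable, then rethrow so
        // cancellation still propagates to the caller.
        state = State.AwaitingUserPrompt
        throw e
    }
}

fun main() = runBlocking {
    val job = launch { generate().collect { print(it) } }
    delay(200)
    job.cancelAndJoin()   // the user quits mid-generation
    println("\nback to AwaitingUserPrompt: ${state == State.AwaitingUserPrompt}")
}
```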
@@ -128,11 +148,15 @@ class InferenceEngine {
             result.append("$backend | tg $tg | $tg_avg ± $tg_std |\n")
 
             _benchmarkResults = result.toString()
-            (benchmarkResults as MutableStateFlow).value = _benchmarkResults
+            _benchmarkResultsFlow.value = _benchmarkResults
 
             _state.value = State.AwaitingUserPrompt
 
             return _benchmarkResults ?: ""
+        } catch (e: CancellationException) {
+            // If coroutine is cancelled, propagate cancellation
+            _state.value = State.AwaitingUserPrompt
+            throw e
         } catch (e: Exception) {
             _state.value = State.Error(e.message ?: "Unknown error during benchmarking")
             return "Error: ${e.message}"
@@ -147,7 +171,7 @@ class InferenceEngine {
         delay(300)
         _state.value = State.LibraryLoaded
         _benchmarkResults = null
-        (benchmarkResults as MutableStateFlow).value = null
+        _benchmarkResultsFlow.value = null
     }
 
     /**
@@ -5,9 +5,12 @@ import androidx.lifecycle.ViewModelProvider
 import androidx.lifecycle.viewModelScope
 import com.example.llama.revamp.data.model.ModelInfo
 import com.example.llama.revamp.engine.InferenceEngine
+import kotlinx.coroutines.Job
 import kotlinx.coroutines.flow.MutableStateFlow
 import kotlinx.coroutines.flow.StateFlow
 import kotlinx.coroutines.flow.asStateFlow
+import kotlinx.coroutines.flow.catch
+import kotlinx.coroutines.flow.onCompletion
 import kotlinx.coroutines.launch
 import java.text.SimpleDateFormat
 import java.util.Date
@@ -43,6 +46,9 @@ class MainViewModel(
     private val _systemPrompt = MutableStateFlow<String?>(null)
     val systemPrompt: StateFlow<String?> = _systemPrompt.asStateFlow()
 
+    // Flag to track if token collection is active
+    private var tokenCollectionJob: Job? = null
+
     /**
      * Selects a model for use.
      */
@@ -96,6 +102,9 @@ class MainViewModel(
     fun sendMessage(content: String) {
         if (content.isBlank()) return
 
+        // Cancel any ongoing token collection
+        tokenCollectionJob?.cancel()
+
         // Add user message
         val userMessage = Message.User(
             content = content,
@@ -112,31 +121,72 @@ class MainViewModel(
         _messages.value = _messages.value + assistantMessage
 
         // Get response from engine
-        val messageIndex = _messages.value.size - 1
-
-        viewModelScope.launch {
+        tokenCollectionJob = viewModelScope.launch {
             val response = StringBuilder()
 
-            inferenceEngine.sendUserPrompt(content).collect { token ->
-                response.append(token)
-
-                // Update the assistant message with the generated text
-                val currentMessages = _messages.value.toMutableList()
-                val currentAssistantMessage = currentMessages[messageIndex] as Message.Assistant
-                currentMessages[messageIndex] = currentAssistantMessage.copy(
-                    content = response.toString(),
-                    isComplete = false
-                )
-                _messages.value = currentMessages
-            }
-
-            // Mark message as complete when generation finishes
-            val finalMessages = _messages.value.toMutableList()
-            val finalAssistantMessage = finalMessages[messageIndex] as Message.Assistant
-            finalMessages[messageIndex] = finalAssistantMessage.copy(
-                isComplete = true
-            )
-            _messages.value = finalMessages
+            try {
+                inferenceEngine.sendUserPrompt(content)
+                    .catch { e ->
+                        // Handle errors during token collection
+                        val currentMessages = _messages.value.toMutableList()
+                        if (currentMessages.size >= 2) {
+                            val messageIndex = currentMessages.size - 1
+                            val currentAssistantMessage = currentMessages[messageIndex] as? Message.Assistant
+                            if (currentAssistantMessage != null) {
+                                currentMessages[messageIndex] = currentAssistantMessage.copy(
+                                    content = "${response}[Error: ${e.message}]",
+                                    isComplete = true
+                                )
+                                _messages.value = currentMessages
+                            }
+                        }
+                    }
+                    .onCompletion { cause ->
+                        // Handle completion (normal or cancelled)
+                        val currentMessages = _messages.value.toMutableList()
+                        if (currentMessages.isNotEmpty()) {
+                            val messageIndex = currentMessages.size - 1
+                            val currentAssistantMessage = currentMessages.getOrNull(messageIndex) as? Message.Assistant
+                            if (currentAssistantMessage != null) {
+                                currentMessages[messageIndex] = currentAssistantMessage.copy(
+                                    isComplete = true
+                                )
+                                _messages.value = currentMessages
+                            }
+                        }
+                    }
+                    .collect { token ->
+                        response.append(token)
+
+                        // Safely update the assistant message with the generated text
+                        val currentMessages = _messages.value.toMutableList()
+                        if (currentMessages.isNotEmpty()) {
+                            val messageIndex = currentMessages.size - 1
+                            val currentAssistantMessage = currentMessages.getOrNull(messageIndex) as? Message.Assistant
+                            if (currentAssistantMessage != null) {
+                                currentMessages[messageIndex] = currentAssistantMessage.copy(
+                                    content = response.toString(),
+                                    isComplete = false
+                                )
+                                _messages.value = currentMessages
+                            }
+                        }
+                    }
+            } catch (e: Exception) {
+                // Handle any unexpected exceptions
+                val currentMessages = _messages.value.toMutableList()
+                if (currentMessages.isNotEmpty()) {
+                    val messageIndex = currentMessages.size - 1
+                    val currentAssistantMessage = currentMessages.getOrNull(messageIndex) as? Message.Assistant
+                    if (currentAssistantMessage != null) {
+                        currentMessages[messageIndex] = currentAssistantMessage.copy(
+                            content = "${response}[Error: ${e.message}]",
+                            isComplete = true
+                        )
+                        _messages.value = currentMessages
+                    }
+                }
+            }
         }
     }
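The operator order in the hunk above matters: the catch block only sees failures coming from the upstream flow (it does not fire when the collecting coroutine is cancelled), while onCompletion runs for normal completion, upstream failure, and cancellation alike, which is why marking the assistant message as complete lives in onCompletion. A small self-contained sketch of that behaviour (the flow and timings are invented):

```kotlin
import kotlinx.coroutines.*
import kotlinx.coroutines.flow.*

fun main() = runBlocking {
    val job = launch {
        flow {
            repeat(10) { i ->
                emit(i)
                delay(100)
            }
        }
            .catch { e ->
                // Only upstream failures land here; cancelling the collector does not.
                println("upstream error: $e")
            }
            .onCompletion { cause ->
                // Runs on normal completion, upstream failure, and cancellation,
                // so it is a safe place to mark the message as complete.
                println("onCompletion, cause=$cause")
            }
            .collect { println("got $it") }
    }

    delay(250)
    job.cancelAndJoin()   // simulates the user quitting mid-generation
}
```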
@@ -144,8 +194,14 @@ class MainViewModel(
      * Unloads the currently loaded model.
      */
     suspend fun unloadModel() {
-        inferenceEngine.unloadModel()
+        // Cancel any ongoing token collection
+        tokenCollectionJob?.cancel()
+
+        // Clear messages
         _messages.value = emptyList()
+
+        // Unload model
+        inferenceEngine.unloadModel()
     }
 
     /**
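How this path is driven from the UI is outside the commit; a hypothetical caller-side sketch (the onConversationScreenExit name and the scope wiring are assumptions, not code from this repository, and MainViewModel is assumed to be in scope) of what quitting the conversation screen now triggers:

```kotlin
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.launch

// Hypothetical: invoked when the user navigates away from the conversation screen.
fun onConversationScreenExit(viewModel: MainViewModel, scope: CoroutineScope) {
    scope.launch {
        // unloadModel() now cancels the in-flight token collection first,
        // then clears the messages and unloads the model.
        viewModel.unloadModel()
    }
}
```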