UI: avoid duplicated calculation of token metrics
parent dd5b20d74d
commit 7968216235
@@ -96,7 +96,8 @@ data class ModelLoadingMetrics(
  */
 data class GenerationUpdate(
     val text: String,
-    val isComplete: Boolean
+    val isComplete: Boolean,
+    val metrics: TokenMetrics? = null
 )

 /**
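GenerationUpdate now optionally carries the metrics computed by the inference service, so consumers read them from the final update instead of recomputing them. A minimal sketch of how a collector might use the new field; the TokenMetrics shape shown here (tokenCount, tokensPerSecond) is an illustrative placeholder, not the project's actual definition:

    // Hypothetical stand-in for the project's TokenMetrics.
    data class TokenMetrics(val tokenCount: Int, val tokensPerSecond: Double)

    data class GenerationUpdate(
        val text: String,
        val isComplete: Boolean,
        val metrics: TokenMetrics? = null // null for intermediate streaming updates
    )

    fun onUpdate(update: GenerationUpdate) {
        // Intermediate updates carry only text; the final update also carries
        // the metrics the service computed once.
        update.metrics?.let {
            println("generated ${it.tokenCount} tokens at ${it.tokensPerSecond} tok/s")
        }
    }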
@@ -221,6 +222,8 @@ internal class InferenceServiceImpl @Inject internal constructor(
     private var isFirstToken: Boolean = true

     override fun generateResponse(prompt: String): Flow<GenerationUpdate> = flow {
+        val response = StringBuilder()
+
         try {
             // Reset metrics tracking
             generationStartTime = System.currentTimeMillis()
@@ -228,8 +231,6 @@ internal class InferenceServiceImpl @Inject internal constructor(
             tokenCount = 0
             isFirstToken = true

-            val response = StringBuilder()
-
             inferenceEngine.sendUserPrompt(prompt)
                 .collect { token ->
                     // Track first token time
@@ -253,10 +254,11 @@ internal class InferenceServiceImpl @Inject internal constructor(
             val metrics = createTokenMetrics()

             // Emit final response with completion flag
-            emit(GenerationUpdate(response.toString(), true))
+            emit(GenerationUpdate(response.toString(), true, metrics))
         } catch (e: Exception) {
             // Emit error
             val metrics = createTokenMetrics()
+            emit(GenerationUpdate(response.toString(), true, metrics))
             throw e
         }
     }
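Hoisting the StringBuilder out of the try block lets the catch path emit the partial text together with metrics before rethrowing, so metrics are produced by the service in one place. A condensed sketch of the resulting control flow, reusing the placeholder types from the sketch above; the token-source parameter and the createTokenMetrics stub are illustrative, not the project's code:

    import kotlinx.coroutines.flow.Flow
    import kotlinx.coroutines.flow.collect
    import kotlinx.coroutines.flow.flow

    fun createTokenMetrics() = TokenMetrics(tokenCount = 0, tokensPerSecond = 0.0) // stub

    fun generateResponse(tokens: Flow<String>): Flow<GenerationUpdate> = flow {
        val response = StringBuilder() // outside try, so the catch block can still emit it
        try {
            tokens.collect { token ->
                response.append(token)
                emit(GenerationUpdate(response.toString(), isComplete = false))
            }
            // Success: metrics computed once and shipped with the final update.
            emit(GenerationUpdate(response.toString(), true, createTokenMetrics()))
        } catch (e: Exception) {
            // Failure or cancellation: the partial text still goes out with metrics
            // before the exception propagates.
            emit(GenerationUpdate(response.toString(), true, createTokenMetrics()))
            throw e
        }
    }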
@@ -301,15 +301,15 @@ private fun MessageBubble(message: Message) {
             formattedTime = message.formattedTime,
             content = message.content,
             isThinking = message.content.isBlank(),
-            isComplete = false,
+            isGenerating = true,
             metrics = null
         )

-        is Message.Assistant.Completed -> AssistantMessageBubble(
+        is Message.Assistant.Stopped -> AssistantMessageBubble(
             formattedTime = message.formattedTime,
             content = message.content,
             isThinking = false,
-            isComplete = true,
+            isGenerating = false,
             metrics = message.metrics.text
         )
     }
@@ -357,7 +357,7 @@ private fun AssistantMessageBubble(
     formattedTime: String,
     content: String,
     isThinking: Boolean,
-    isComplete: Boolean,
+    isGenerating: Boolean,
     metrics: String? = null
 ) {
     Row(
@@ -417,7 +417,7 @@ private fun AssistantMessageBubble(
                 .padding(top = 4.dp),
             verticalAlignment = Alignment.CenterVertically
         ) {
-            if (!isComplete) {
+            if (isGenerating) {
                 PulsatingDots(small = true)

                 Spacer(modifier = Modifier.width(4.dp))
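On the UI side the bubble is keyed off an explicit isGenerating flag rather than the negation of isComplete, and the Stopped variant supplies a pre-formatted metrics string. A stripped-down sketch of that mapping, using hypothetical stand-in types instead of the real composable parameters:

    // Hypothetical, simplified stand-ins for the Message.Assistant subtypes.
    sealed interface Bubble {
        data class Ongoing(val content: String) : Bubble
        data class Stopped(val content: String, val metricsText: String) : Bubble
    }

    // Roughly what MessageBubble now passes down for each branch.
    data class BubbleState(
        val isThinking: Boolean,
        val isGenerating: Boolean, // drives PulsatingDots instead of !isComplete
        val metrics: String?
    )

    fun bubbleState(message: Bubble): BubbleState = when (message) {
        is Bubble.Ongoing -> BubbleState(
            isThinking = message.content.isBlank(),
            isGenerating = true,
            metrics = null
        )
        is Bubble.Stopped -> BubbleState(
            isThinking = false,
            isGenerating = false,
            metrics = message.metricsText
        )
    }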
@@ -14,6 +14,7 @@ import kotlinx.coroutines.flow.StateFlow
 import kotlinx.coroutines.flow.asStateFlow
 import kotlinx.coroutines.flow.onCompletion
 import kotlinx.coroutines.launch
+import okhttp3.internal.toImmutableList
 import java.text.SimpleDateFormat
 import java.util.Date
 import java.util.Locale
@@ -94,76 +95,57 @@ class ConversationViewModel @Inject constructor(
     /**
      * Stop ongoing generation
      */
-    fun stopGeneration() {
+    fun stopGeneration() =
         tokenCollectionJob?.let { job ->
             // handled by the catch blocks
             if (job.isActive) { job.cancel() }
         }
-    }

     /**
-     * Handle the case when generation is explicitly cancelled
+     * Handle the case when generation is explicitly cancelled by adding a stopping suffix
      */
-    private fun handleCancellation() {
-        val currentMessages = _messages.value.toMutableList()
-        val lastIndex = currentMessages.size - 1
-        val currentAssistantMessage = currentMessages.getOrNull(lastIndex) as? Message.Assistant.Ongoing
-
-        if (currentAssistantMessage != null) {
-            // Replace with completed message, adding note that it was interrupted
-            currentMessages[lastIndex] = Message.Assistant.Completed(
-                content = currentAssistantMessage.content + " [Generation stopped]",
-                timestamp = currentAssistantMessage.timestamp,
-                metrics = conversationService.createTokenMetrics()
-            )
-            _messages.value = currentMessages
+    private fun handleCancellation() =
+        _messages.value.toMutableList().apply {
+            (removeLastOrNull() as? Message.Assistant.Stopped)?.let {
+                add(it.copy(content = it.content + SUFFIX_GENERATION_STOPPED))
+                _messages.value = toImmutableList()
+            }
         }
-    }

     /**
-     * Handle response error
+     * Handle response error by appending an error suffix
      */
-    private fun handleResponseError(e: Exception) {
-        val currentMessages = _messages.value.toMutableList()
-        val lastIndex = currentMessages.size - 1
-        val currentAssistantMessage = currentMessages.getOrNull(lastIndex) as? Message.Assistant.Ongoing
-
-        if (currentAssistantMessage != null) {
-            currentMessages[lastIndex] = Message.Assistant.Completed(
-                content = currentAssistantMessage.content + " [Error: ${e.message}]",
-                timestamp = currentAssistantMessage.timestamp,
-                metrics = conversationService.createTokenMetrics()
-            )
-            _messages.value = currentMessages
+    private fun handleResponseError(e: Exception) =
+        _messages.value.toMutableList().apply {
+            (removeLastOrNull() as? Message.Assistant.Stopped)?.let {
+                add(it.copy(content = it.content + SUFFIX_GENERATION_ERROR.format(e.message)))
+                _messages.value = toImmutableList()
+            }
         }
-    }

     /**
      * Handle updating the assistant message
      */
-    private fun updateAssistantMessage(update: GenerationUpdate) {
-        val currentMessages = _messages.value.toMutableList()
-        val lastIndex = currentMessages.size - 1
-        val currentAssistantMessage = currentMessages.getOrNull(lastIndex) as? Message.Assistant.Ongoing
-
-        if (currentAssistantMessage != null) {
-            if (update.isComplete) {
-                // Final message with metrics
-                currentMessages[lastIndex] = Message.Assistant.Completed(
-                    content = update.text,
-                    timestamp = currentAssistantMessage.timestamp,
-                    metrics = conversationService.createTokenMetrics()
-                )
-            } else {
-                // Ongoing message update
-                currentMessages[lastIndex] = Message.Assistant.Ongoing(
-                    content = update.text,
-                    timestamp = currentAssistantMessage.timestamp
-                )
+    private fun updateAssistantMessage(update: GenerationUpdate) =
+        _messages.value.toMutableList().apply {
+            (removeLastOrNull() as? Message.Assistant.Ongoing)?.let {
+                if (update.metrics != null) {
+                    // Finalized message (partial or complete) with metrics
+                    add(Message.Assistant.Stopped(
+                        content = update.text,
+                        timestamp = it.timestamp,
+                        metrics = update.metrics
+                    ))
+                } else if (!update.isComplete) {
+                    // Ongoing message update
+                    add(Message.Assistant.Ongoing(
+                        content = update.text,
+                        timestamp = it.timestamp
+                    ))
+                }
+                _messages.value = toImmutableList()
             }
-            _messages.value = currentMessages
         }
-    }

     override suspend fun performCleanup() = clearConversation()
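The ViewModel no longer calls conversationService.createTokenMetrics(); it promotes the trailing Ongoing message to Stopped whenever an update arrives with metrics attached, and otherwise keeps streaming. A sketch of that decision in isolation, reusing the placeholder GenerationUpdate from the first sketch; the Assistant types here are simplified stand-ins for Message.Assistant:

    // Hypothetical, simplified stand-ins for Message.Assistant.Ongoing / Stopped.
    sealed interface Assistant {
        data class Ongoing(val timestamp: Long, val content: String) : Assistant
        data class Stopped(val timestamp: Long, val content: String, val metrics: TokenMetrics) : Assistant
    }

    // What replaces the trailing Ongoing message for a given update.
    fun replaceLast(previous: Assistant.Ongoing, update: GenerationUpdate): Assistant? {
        val metrics = update.metrics
        return when {
            // Metrics attached: the service has finalized this response (completed or stopped).
            metrics != null -> Assistant.Stopped(previous.timestamp, update.text, metrics)
            // Plain streaming chunk: keep the message Ongoing with the newer text.
            !update.isComplete -> Assistant.Ongoing(previous.timestamp, update.text)
            // Complete but metrics-less updates should not occur under the new contract.
            else -> null
        }
    }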
@@ -179,6 +161,11 @@ class ConversationViewModel @Inject constructor(
         stopGeneration()
         super.onCleared()
     }
+
+    companion object {
+        private const val SUFFIX_GENERATION_STOPPED = " [Generation stopped]"
+        private const val SUFFIX_GENERATION_ERROR = " [Error: %s]"
+    }
 }
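The ad-hoc strings from the old handlers become named constants; a quick illustration of the text they produce (constant values copied from the diff, the rest illustrative):

    fun main() {
        val suffixStopped = " [Generation stopped]" // SUFFIX_GENERATION_STOPPED
        val suffixError = " [Error: %s]"            // SUFFIX_GENERATION_ERROR

        println("partial answer" + suffixStopped)                 // partial answer [Generation stopped]
        println("partial answer" + suffixError.format("timeout")) // partial answer [Error: timeout]

        // e.message can be null; String.format renders it as the literal "null".
        val message: String? = null
        println("partial answer" + suffixError.format(message))   // partial answer [Error: null]
    }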
@@ -203,7 +190,7 @@ sealed class Message {
         override val content: String,
     ) : Assistant()

-    data class Completed(
+    data class Stopped(
         override val timestamp: Long,
         override val content: String,
         val metrics: TokenMetrics