UI: avoid duplicated calculation of token metrics

This commit is contained in:
Han Yin 2025-07-21 14:39:55 -07:00
parent dd5b20d74d
commit 7968216235
3 changed files with 51 additions and 62 deletions

View File

@ -96,7 +96,8 @@ data class ModelLoadingMetrics(
*/ */
data class GenerationUpdate( data class GenerationUpdate(
val text: String, val text: String,
val isComplete: Boolean val isComplete: Boolean,
val metrics: TokenMetrics? = null
) )
/** /**
@ -221,6 +222,8 @@ internal class InferenceServiceImpl @Inject internal constructor(
private var isFirstToken: Boolean = true private var isFirstToken: Boolean = true
override fun generateResponse(prompt: String): Flow<GenerationUpdate> = flow { override fun generateResponse(prompt: String): Flow<GenerationUpdate> = flow {
val response = StringBuilder()
try { try {
// Reset metrics tracking // Reset metrics tracking
generationStartTime = System.currentTimeMillis() generationStartTime = System.currentTimeMillis()
@ -228,8 +231,6 @@ internal class InferenceServiceImpl @Inject internal constructor(
tokenCount = 0 tokenCount = 0
isFirstToken = true isFirstToken = true
val response = StringBuilder()
inferenceEngine.sendUserPrompt(prompt) inferenceEngine.sendUserPrompt(prompt)
.collect { token -> .collect { token ->
// Track first token time // Track first token time
@ -253,10 +254,11 @@ internal class InferenceServiceImpl @Inject internal constructor(
val metrics = createTokenMetrics() val metrics = createTokenMetrics()
// Emit final response with completion flag // Emit final response with completion flag
emit(GenerationUpdate(response.toString(), true)) emit(GenerationUpdate(response.toString(), true, metrics))
} catch (e: Exception) { } catch (e: Exception) {
// Emit error // Emit error
val metrics = createTokenMetrics() val metrics = createTokenMetrics()
emit(GenerationUpdate(response.toString(), true, metrics))
throw e throw e
} }
} }

View File

@ -301,15 +301,15 @@ private fun MessageBubble(message: Message) {
formattedTime = message.formattedTime, formattedTime = message.formattedTime,
content = message.content, content = message.content,
isThinking = message.content.isBlank(), isThinking = message.content.isBlank(),
isComplete = false, isGenerating = true,
metrics = null metrics = null
) )
is Message.Assistant.Completed -> AssistantMessageBubble( is Message.Assistant.Stopped -> AssistantMessageBubble(
formattedTime = message.formattedTime, formattedTime = message.formattedTime,
content = message.content, content = message.content,
isThinking = false, isThinking = false,
isComplete = true, isGenerating = false,
metrics = message.metrics.text metrics = message.metrics.text
) )
} }
@ -357,7 +357,7 @@ private fun AssistantMessageBubble(
formattedTime: String, formattedTime: String,
content: String, content: String,
isThinking: Boolean, isThinking: Boolean,
isComplete: Boolean, isGenerating: Boolean,
metrics: String? = null metrics: String? = null
) { ) {
Row( Row(
@ -417,7 +417,7 @@ private fun AssistantMessageBubble(
.padding(top = 4.dp), .padding(top = 4.dp),
verticalAlignment = Alignment.CenterVertically verticalAlignment = Alignment.CenterVertically
) { ) {
if (!isComplete) { if (isGenerating) {
PulsatingDots(small = true) PulsatingDots(small = true)
Spacer(modifier = Modifier.width(4.dp)) Spacer(modifier = Modifier.width(4.dp))

View File

@ -14,6 +14,7 @@ import kotlinx.coroutines.flow.StateFlow
import kotlinx.coroutines.flow.asStateFlow import kotlinx.coroutines.flow.asStateFlow
import kotlinx.coroutines.flow.onCompletion import kotlinx.coroutines.flow.onCompletion
import kotlinx.coroutines.launch import kotlinx.coroutines.launch
import okhttp3.internal.toImmutableList
import java.text.SimpleDateFormat import java.text.SimpleDateFormat
import java.util.Date import java.util.Date
import java.util.Locale import java.util.Locale
@ -94,74 +95,55 @@ class ConversationViewModel @Inject constructor(
/** /**
* Stop ongoing generation * Stop ongoing generation
*/ */
fun stopGeneration() { fun stopGeneration() =
tokenCollectionJob?.let { job -> tokenCollectionJob?.let { job ->
// handled by the catch blocks // handled by the catch blocks
if (job.isActive) { job.cancel() } if (job.isActive) { job.cancel() }
} }
}
/** /**
* Handle the case when generation is explicitly cancelled * Handle the case when generation is explicitly cancelled by adding a stopping suffix
*/ */
private fun handleCancellation() { private fun handleCancellation() =
val currentMessages = _messages.value.toMutableList() _messages.value.toMutableList().apply {
val lastIndex = currentMessages.size - 1 (removeLastOrNull() as? Message.Assistant.Stopped)?.let {
val currentAssistantMessage = currentMessages.getOrNull(lastIndex) as? Message.Assistant.Ongoing add(it.copy(content = it.content + SUFFIX_GENERATION_STOPPED))
_messages.value = toImmutableList()
if (currentAssistantMessage != null) {
// Replace with completed message, adding note that it was interrupted
currentMessages[lastIndex] = Message.Assistant.Completed(
content = currentAssistantMessage.content + " [Generation stopped]",
timestamp = currentAssistantMessage.timestamp,
metrics = conversationService.createTokenMetrics()
)
_messages.value = currentMessages
} }
} }
/** /**
* Handle response error * Handle response error by appending an error suffix
*/ */
private fun handleResponseError(e: Exception) { private fun handleResponseError(e: Exception) =
val currentMessages = _messages.value.toMutableList() _messages.value.toMutableList().apply {
val lastIndex = currentMessages.size - 1 (removeLastOrNull() as? Message.Assistant.Stopped)?.let {
val currentAssistantMessage = currentMessages.getOrNull(lastIndex) as? Message.Assistant.Ongoing add(it.copy(content = it.content + SUFFIX_GENERATION_ERROR.format(e.message)))
_messages.value = toImmutableList()
if (currentAssistantMessage != null) {
currentMessages[lastIndex] = Message.Assistant.Completed(
content = currentAssistantMessage.content + " [Error: ${e.message}]",
timestamp = currentAssistantMessage.timestamp,
metrics = conversationService.createTokenMetrics()
)
_messages.value = currentMessages
} }
} }
/** /**
* Handle updating the assistant message * Handle updating the assistant message
*/ */
private fun updateAssistantMessage(update: GenerationUpdate) { private fun updateAssistantMessage(update: GenerationUpdate) =
val currentMessages = _messages.value.toMutableList() _messages.value.toMutableList().apply {
val lastIndex = currentMessages.size - 1 (removeLastOrNull() as? Message.Assistant.Ongoing)?.let {
val currentAssistantMessage = currentMessages.getOrNull(lastIndex) as? Message.Assistant.Ongoing if (update.metrics != null) {
// Finalized message (partial or complete) with metrics
if (currentAssistantMessage != null) { add(Message.Assistant.Stopped(
if (update.isComplete) {
// Final message with metrics
currentMessages[lastIndex] = Message.Assistant.Completed(
content = update.text, content = update.text,
timestamp = currentAssistantMessage.timestamp, timestamp = it.timestamp,
metrics = conversationService.createTokenMetrics() metrics = update.metrics
) ))
} else { } else if (!update.isComplete) {
// Ongoing message update // Ongoing message update
currentMessages[lastIndex] = Message.Assistant.Ongoing( add(Message.Assistant.Ongoing(
content = update.text, content = update.text,
timestamp = currentAssistantMessage.timestamp timestamp = it.timestamp
) ))
} }
_messages.value = currentMessages _messages.value = toImmutableList()
} }
} }
@ -179,6 +161,11 @@ class ConversationViewModel @Inject constructor(
stopGeneration() stopGeneration()
super.onCleared() super.onCleared()
} }
companion object {
private const val SUFFIX_GENERATION_STOPPED = " [Generation stopped]"
private const val SUFFIX_GENERATION_ERROR = " [Error: %s]"
}
} }
@ -203,7 +190,7 @@ sealed class Message {
override val content: String, override val content: String,
) : Assistant() ) : Assistant()
data class Completed( data class Stopped(
override val timestamp: Long, override val timestamp: Long,
override val content: String, override val content: String,
val metrics: TokenMetrics val metrics: TokenMetrics