From 7968216235583e3ea8318c01202d7f0cfef18367 Mon Sep 17 00:00:00 2001 From: Han Yin Date: Mon, 21 Jul 2025 14:39:55 -0700 Subject: [PATCH] UI: avoid duplicated calculation of token metrics --- .../example/llama/engine/InferenceServices.kt | 10 +- .../llama/ui/screens/ConversationScreen.kt | 10 +- .../llama/viewmodel/ConversationViewModel.kt | 93 ++++++++----------- 3 files changed, 51 insertions(+), 62 deletions(-) diff --git a/examples/llama.android/app/src/main/java/com/example/llama/engine/InferenceServices.kt b/examples/llama.android/app/src/main/java/com/example/llama/engine/InferenceServices.kt index c9ffec68c8..00b29506d2 100644 --- a/examples/llama.android/app/src/main/java/com/example/llama/engine/InferenceServices.kt +++ b/examples/llama.android/app/src/main/java/com/example/llama/engine/InferenceServices.kt @@ -96,7 +96,8 @@ data class ModelLoadingMetrics( */ data class GenerationUpdate( val text: String, - val isComplete: Boolean + val isComplete: Boolean, + val metrics: TokenMetrics? = null ) /** @@ -221,6 +222,8 @@ internal class InferenceServiceImpl @Inject internal constructor( private var isFirstToken: Boolean = true override fun generateResponse(prompt: String): Flow = flow { + val response = StringBuilder() + try { // Reset metrics tracking generationStartTime = System.currentTimeMillis() @@ -228,8 +231,6 @@ internal class InferenceServiceImpl @Inject internal constructor( tokenCount = 0 isFirstToken = true - val response = StringBuilder() - inferenceEngine.sendUserPrompt(prompt) .collect { token -> // Track first token time @@ -253,10 +254,11 @@ internal class InferenceServiceImpl @Inject internal constructor( val metrics = createTokenMetrics() // Emit final response with completion flag - emit(GenerationUpdate(response.toString(), true)) + emit(GenerationUpdate(response.toString(), true, metrics)) } catch (e: Exception) { // Emit error val metrics = createTokenMetrics() + emit(GenerationUpdate(response.toString(), true, metrics)) throw e } } diff --git a/examples/llama.android/app/src/main/java/com/example/llama/ui/screens/ConversationScreen.kt b/examples/llama.android/app/src/main/java/com/example/llama/ui/screens/ConversationScreen.kt index 7cf2295c6b..1c680e67b8 100644 --- a/examples/llama.android/app/src/main/java/com/example/llama/ui/screens/ConversationScreen.kt +++ b/examples/llama.android/app/src/main/java/com/example/llama/ui/screens/ConversationScreen.kt @@ -301,15 +301,15 @@ private fun MessageBubble(message: Message) { formattedTime = message.formattedTime, content = message.content, isThinking = message.content.isBlank(), - isComplete = false, + isGenerating = true, metrics = null ) - is Message.Assistant.Completed -> AssistantMessageBubble( + is Message.Assistant.Stopped -> AssistantMessageBubble( formattedTime = message.formattedTime, content = message.content, isThinking = false, - isComplete = true, + isGenerating = false, metrics = message.metrics.text ) } @@ -357,7 +357,7 @@ private fun AssistantMessageBubble( formattedTime: String, content: String, isThinking: Boolean, - isComplete: Boolean, + isGenerating: Boolean, metrics: String? = null ) { Row( @@ -417,7 +417,7 @@ private fun AssistantMessageBubble( .padding(top = 4.dp), verticalAlignment = Alignment.CenterVertically ) { - if (!isComplete) { + if (isGenerating) { PulsatingDots(small = true) Spacer(modifier = Modifier.width(4.dp)) diff --git a/examples/llama.android/app/src/main/java/com/example/llama/viewmodel/ConversationViewModel.kt b/examples/llama.android/app/src/main/java/com/example/llama/viewmodel/ConversationViewModel.kt index e19c55ddba..c650caef4c 100644 --- a/examples/llama.android/app/src/main/java/com/example/llama/viewmodel/ConversationViewModel.kt +++ b/examples/llama.android/app/src/main/java/com/example/llama/viewmodel/ConversationViewModel.kt @@ -14,6 +14,7 @@ import kotlinx.coroutines.flow.StateFlow import kotlinx.coroutines.flow.asStateFlow import kotlinx.coroutines.flow.onCompletion import kotlinx.coroutines.launch +import okhttp3.internal.toImmutableList import java.text.SimpleDateFormat import java.util.Date import java.util.Locale @@ -94,76 +95,57 @@ class ConversationViewModel @Inject constructor( /** * Stop ongoing generation */ - fun stopGeneration() { + fun stopGeneration() = tokenCollectionJob?.let { job -> // handled by the catch blocks if (job.isActive) { job.cancel() } } - } /** - * Handle the case when generation is explicitly cancelled + * Handle the case when generation is explicitly cancelled by adding a stopping suffix */ - private fun handleCancellation() { - val currentMessages = _messages.value.toMutableList() - val lastIndex = currentMessages.size - 1 - val currentAssistantMessage = currentMessages.getOrNull(lastIndex) as? Message.Assistant.Ongoing - - if (currentAssistantMessage != null) { - // Replace with completed message, adding note that it was interrupted - currentMessages[lastIndex] = Message.Assistant.Completed( - content = currentAssistantMessage.content + " [Generation stopped]", - timestamp = currentAssistantMessage.timestamp, - metrics = conversationService.createTokenMetrics() - ) - _messages.value = currentMessages + private fun handleCancellation() = + _messages.value.toMutableList().apply { + (removeLastOrNull() as? Message.Assistant.Stopped)?.let { + add(it.copy(content = it.content + SUFFIX_GENERATION_STOPPED)) + _messages.value = toImmutableList() + } } - } /** - * Handle response error + * Handle response error by appending an error suffix */ - private fun handleResponseError(e: Exception) { - val currentMessages = _messages.value.toMutableList() - val lastIndex = currentMessages.size - 1 - val currentAssistantMessage = currentMessages.getOrNull(lastIndex) as? Message.Assistant.Ongoing - - if (currentAssistantMessage != null) { - currentMessages[lastIndex] = Message.Assistant.Completed( - content = currentAssistantMessage.content + " [Error: ${e.message}]", - timestamp = currentAssistantMessage.timestamp, - metrics = conversationService.createTokenMetrics() - ) - _messages.value = currentMessages + private fun handleResponseError(e: Exception) = + _messages.value.toMutableList().apply { + (removeLastOrNull() as? Message.Assistant.Stopped)?.let { + add(it.copy(content = it.content + SUFFIX_GENERATION_ERROR.format(e.message))) + _messages.value = toImmutableList() + } } - } /** * Handle updating the assistant message */ - private fun updateAssistantMessage(update: GenerationUpdate) { - val currentMessages = _messages.value.toMutableList() - val lastIndex = currentMessages.size - 1 - val currentAssistantMessage = currentMessages.getOrNull(lastIndex) as? Message.Assistant.Ongoing - - if (currentAssistantMessage != null) { - if (update.isComplete) { - // Final message with metrics - currentMessages[lastIndex] = Message.Assistant.Completed( - content = update.text, - timestamp = currentAssistantMessage.timestamp, - metrics = conversationService.createTokenMetrics() - ) - } else { - // Ongoing message update - currentMessages[lastIndex] = Message.Assistant.Ongoing( - content = update.text, - timestamp = currentAssistantMessage.timestamp - ) + private fun updateAssistantMessage(update: GenerationUpdate) = + _messages.value.toMutableList().apply { + (removeLastOrNull() as? Message.Assistant.Ongoing)?.let { + if (update.metrics != null) { + // Finalized message (partial or complete) with metrics + add(Message.Assistant.Stopped( + content = update.text, + timestamp = it.timestamp, + metrics = update.metrics + )) + } else if (!update.isComplete) { + // Ongoing message update + add(Message.Assistant.Ongoing( + content = update.text, + timestamp = it.timestamp + )) + } + _messages.value = toImmutableList() } - _messages.value = currentMessages } - } override suspend fun performCleanup() = clearConversation() @@ -179,6 +161,11 @@ class ConversationViewModel @Inject constructor( stopGeneration() super.onCleared() } + + companion object { + private const val SUFFIX_GENERATION_STOPPED = " [Generation stopped]" + private const val SUFFIX_GENERATION_ERROR = " [Error: %s]" + } } @@ -203,7 +190,7 @@ sealed class Message { override val content: String, ) : Assistant() - data class Completed( + data class Stopped( override val timestamp: Long, override val content: String, val metrics: TokenMetrics