From 2a41c0e3541c8d5c8605595e61feecaa9fb1cc5a Mon Sep 17 00:00:00 2001
From: Han Yin
Date: Sat, 12 Apr 2025 22:32:32 -0700
Subject: [PATCH] vm: replace token metrics stubs with actual implementation

---
 .../revamp/ui/screens/ConversationScreen.kt   |  36 ++---
 .../llama/revamp/viewmodel/MainViewModel.kt   | 148 +++++++++++++-----
 2 files changed, 127 insertions(+), 57 deletions(-)

diff --git a/examples/llama.android/app/src/main/java/com/example/llama/revamp/ui/screens/ConversationScreen.kt b/examples/llama.android/app/src/main/java/com/example/llama/revamp/ui/screens/ConversationScreen.kt
index 4b567e7f4d..d343d1c9a5 100644
--- a/examples/llama.android/app/src/main/java/com/example/llama/revamp/ui/screens/ConversationScreen.kt
+++ b/examples/llama.android/app/src/main/java/com/example/llama/revamp/ui/screens/ConversationScreen.kt
@@ -65,6 +65,7 @@ import com.example.llama.revamp.navigation.NavigationActions
 import com.example.llama.revamp.ui.components.AppScaffold
 import com.example.llama.revamp.viewmodel.MainViewModel
 import com.example.llama.revamp.viewmodel.Message
+import com.example.llama.revamp.viewmodel.TokenMetrics
 import kotlinx.coroutines.launch
 
 /**
@@ -244,23 +245,22 @@ fun ConversationMessageList(
 fun MessageBubble(message: Message) {
     when (message) {
         is Message.User -> UserMessageBubble(
-            content = message.content,
-            formattedTime = message.formattedTime
-        )
-
-        is Message.Assistant -> AssistantMessageBubble(
-            content = message.content,
             formattedTime = message.formattedTime,
-            isComplete = message.isComplete,
-            isThinking = !message.isComplete && message.content.isBlank(),
-            metrics = if (message.isComplete && message.content.isNotBlank()) {
-                // TODO-han.yin: Generate some example metrics for now
-                // This would come from the actual LLM engine in a real implementation
-                val tokenCount = message.content.split("\\s+".toRegex()).size
-                val ttft = (200 + (Math.random() * 80)).toInt()
-                val tps = 8.5 + (Math.random() * 1.5)
-                "Tokens: $tokenCount, TTFT: ${ttft}ms, TPS: ${"%.1f".format(tps)}"
-            } else null
+            content = message.content
+        )
+        is Message.Assistant.Ongoing -> AssistantMessageBubble(
+            formattedTime = message.formattedTime,
+            content = message.content,
+            isThinking = message.content.isBlank(),
+            isComplete = false,
+            metrics = null
+        )
+        is Message.Assistant.Completed -> AssistantMessageBubble(
+            formattedTime = message.formattedTime,
+            content = message.content,
+            isThinking = false,
+            isComplete = true,
+            metrics = message.metrics.text
         )
     }
 }
@@ -304,10 +304,10 @@ fun UserMessageBubble(content: String, formattedTime: String) {
 
 @Composable
 fun AssistantMessageBubble(
-    content: String,
     formattedTime: String,
-    isComplete: Boolean,
+    content: String,
     isThinking: Boolean,
+    isComplete: Boolean,
     metrics: String? = null
 ) {
     Row(
diff --git a/examples/llama.android/app/src/main/java/com/example/llama/revamp/viewmodel/MainViewModel.kt b/examples/llama.android/app/src/main/java/com/example/llama/revamp/viewmodel/MainViewModel.kt
index 8aae216c28..178f0126f6 100644
--- a/examples/llama.android/app/src/main/java/com/example/llama/revamp/viewmodel/MainViewModel.kt
+++ b/examples/llama.android/app/src/main/java/com/example/llama/revamp/viewmodel/MainViewModel.kt
@@ -33,11 +33,6 @@ class MainViewModel(
     private val _selectedModel = MutableStateFlow(null)
     val selectedModel: StateFlow = _selectedModel.asStateFlow()
 
-    // Benchmark parameters
-    private var pp: Int = 32
-    private var tg: Int = 32
-    private var pl: Int = 512
-
     // Messages in the conversation
     private val _messages = MutableStateFlow<List<Message>>(emptyList())
     val messages: StateFlow<List<Message>> = _messages.asStateFlow()
@@ -72,7 +67,7 @@
      * Runs the benchmark with current parameters.
      */
     private suspend fun runBenchmark() {
-        inferenceEngine.bench(pp, tg, pl)
+        inferenceEngine.bench(512, 128, 1, 3)
     }
 
     /**
@@ -94,6 +89,14 @@
         }
     }
 
+    /**
+     * Tracks token generation metrics
+     */
+    private var generationStartTime: Long = 0L
+    private var firstTokenTime: Long = 0L
+    private var tokenCount: Int = 0
+    private var isFirstToken: Boolean = true
+
     /**
     * Sends a user message and collects the response.
     */
@@ -111,13 +114,18 @@
         _messages.value = _messages.value + userMessage
 
         // Create placeholder for assistant message
-        val assistantMessage = Message.Assistant(
+        val assistantMessage = Message.Assistant.Ongoing(
             content = "",
-            timestamp = System.currentTimeMillis(),
-            isComplete = false
+            timestamp = System.currentTimeMillis()
         )
         _messages.value = _messages.value + assistantMessage
 
+        // Reset metrics tracking
+        generationStartTime = System.currentTimeMillis()
+        firstTokenTime = 0L
+        tokenCount = 0
+        isFirstToken = true
+
         // Get response from engine
         tokenCollectionJob = viewModelScope.launch {
             val response = StringBuilder()
@@ -129,12 +137,19 @@
                     val currentMessages = _messages.value.toMutableList()
                     if (currentMessages.size >= 2) {
                         val messageIndex = currentMessages.size - 1
-                        val currentAssistantMessage =
-                            currentMessages[messageIndex] as? Message.Assistant
+                        val currentAssistantMessage = currentMessages[messageIndex] as? Message.Assistant.Ongoing
                         if (currentAssistantMessage != null) {
-                            currentMessages[messageIndex] = currentAssistantMessage.copy(
+                            // Create metrics with error indication
+                            val errorMetrics = TokenMetrics(
+                                tokensCount = tokenCount,
+                                ttftMs = if (firstTokenTime > 0) firstTokenTime - generationStartTime else 0L,
+                                tpsMs = calculateTPS(tokenCount, System.currentTimeMillis() - generationStartTime)
+                            )
+
+                            currentMessages[messageIndex] = Message.Assistant.Completed(
                                 content = "${response}[Error: ${e.message}]",
-                                isComplete = true
+                                timestamp = currentAssistantMessage.timestamp,
+                                metrics = errorMetrics
                             )
                             _messages.value = currentMessages
                         }
@@ -145,29 +160,50 @@
                     val currentMessages = _messages.value.toMutableList()
                     if (currentMessages.isNotEmpty()) {
                         val messageIndex = currentMessages.size - 1
-                        val currentAssistantMessage =
-                            currentMessages.getOrNull(messageIndex) as? Message.Assistant
+                        val currentAssistantMessage = currentMessages.getOrNull(messageIndex) as? Message.Assistant.Ongoing
                         if (currentAssistantMessage != null) {
-                            currentMessages[messageIndex] = currentAssistantMessage.copy(
-                                isComplete = true
+                            // Calculate final metrics
+                            val endTime = System.currentTimeMillis()
+                            val totalTimeMs = endTime - generationStartTime
+
+                            val metrics = TokenMetrics(
+                                tokensCount = tokenCount,
+                                ttftMs = if (firstTokenTime > 0) firstTokenTime - generationStartTime else 0L,
+                                tpsMs = calculateTPS(tokenCount, totalTimeMs)
+                            )
+
+                            currentMessages[messageIndex] = Message.Assistant.Completed(
+                                content = response.toString(),
+                                timestamp = currentAssistantMessage.timestamp,
+                                metrics = metrics
                             )
                             _messages.value = currentMessages
                         }
                     }
                 }
                 .collect { token ->
+                    // Track first token time
+                    if (isFirstToken && token.isNotBlank()) {
+                        firstTokenTime = System.currentTimeMillis()
+                        isFirstToken = false
+                    }
+
+                    // Count tokens - each non-empty emission is at least one token
+                    if (token.isNotBlank()) {
+                        tokenCount++
+                    }
+
                    response.append(token)
 
                     // Safely update the assistant message with the generated text
                     val currentMessages = _messages.value.toMutableList()
                     if (currentMessages.isNotEmpty()) {
                         val messageIndex = currentMessages.size - 1
-                        val currentAssistantMessage =
-                            currentMessages.getOrNull(messageIndex) as? Message.Assistant
+                        val currentAssistantMessage = currentMessages.getOrNull(messageIndex) as? Message.Assistant.Ongoing
                         if (currentAssistantMessage != null) {
-                            currentMessages[messageIndex] = currentAssistantMessage.copy(
+                            currentMessages[messageIndex] = Message.Assistant.Ongoing(
                                 content = response.toString(),
-                                isComplete = false
+                                timestamp = currentAssistantMessage.timestamp
                             )
                             _messages.value = currentMessages
                         }
@@ -178,12 +214,19 @@
                 val currentMessages = _messages.value.toMutableList()
                 if (currentMessages.isNotEmpty()) {
                     val messageIndex = currentMessages.size - 1
-                    val currentAssistantMessage =
-                        currentMessages.getOrNull(messageIndex) as? Message.Assistant
+                    val currentAssistantMessage = currentMessages.getOrNull(messageIndex) as? Message.Assistant.Ongoing
                     if (currentAssistantMessage != null) {
-                        currentMessages[messageIndex] = currentAssistantMessage.copy(
+                        // Create metrics with error indication
+                        val errorMetrics = TokenMetrics(
+                            tokensCount = tokenCount,
+                            ttftMs = if (firstTokenTime > 0) firstTokenTime - generationStartTime else 0L,
+                            tpsMs = calculateTPS(tokenCount, System.currentTimeMillis() - generationStartTime)
+                        )
+
+                        currentMessages[messageIndex] = Message.Assistant.Completed(
                             content = "${response}[Error: ${e.message}]",
-                            isComplete = true
+                            timestamp = currentAssistantMessage.timestamp,
+                            metrics = errorMetrics
                         )
                         _messages.value = currentMessages
                     }
@@ -192,6 +235,14 @@
         }
     }
 
+    /**
+     * Calculate tokens per second.
+     */
+    private fun calculateTPS(tokens: Int, timeMs: Long): Float {
+        if (tokens <= 0 || timeMs <= 0) return 0f
+        return (tokens.toFloat() * 1000f) / timeMs
+    }
+
     /**
      * Unloads the currently loaded model.
      */
@@ -206,6 +257,9 @@
         inferenceEngine.unloadModel()
     }
 
+    /**
+     * Checks if a model is currently being loaded.
+     */
     fun isModelLoading() =
         engineState.value.let {
             it is InferenceEngine.State.LoadingModel
@@ -213,7 +267,7 @@
         }
 
     /**
-     * Checks if a model is currently loaded.
+     * Checks if a model has already been loaded.
      */
     fun isModelLoaded() =
         engineState.value.let {
@@ -247,24 +301,40 @@
  * Sealed class representing messages in a conversation.
  */
 sealed class Message {
-    abstract val content: String
     abstract val timestamp: Long
+    abstract val content: String
 
     val formattedTime: String
-        get() {
-            val formatter = SimpleDateFormat("h:mm a", Locale.getDefault())
-            return formatter.format(Date(timestamp))
-        }
+        get() = datetimeFormatter.format(Date(timestamp))
 
     data class User(
-        override val content: String,
-        override val timestamp: Long
+        override val timestamp: Long,
+        override val content: String
     ) : Message()
 
-    // TODO-han.yin: break down into ongoing & completed message subtypes
-    data class Assistant(
-        override val content: String,
-        override val timestamp: Long,
-        val isComplete: Boolean = true
-    ) : Message()
+    sealed class Assistant : Message() {
+        data class Ongoing(
+            override val timestamp: Long,
+            override val content: String,
+        ) : Assistant()
+
+        data class Completed(
+            override val timestamp: Long,
+            override val content: String,
+            val metrics: TokenMetrics
+        ) : Assistant()
+    }
+
+    companion object {
+        private val datetimeFormatter by lazy { SimpleDateFormat("h:mm a", Locale.getDefault()) }
+    }
+}
+
+data class TokenMetrics(
+    val tokensCount: Int,
+    val ttftMs: Long,
+    val tpsMs: Float,
+) {
+    val text: String
+        get() = "Tokens: $tokensCount, TTFT: ${ttftMs}ms, TPS: ${"%.1f".format(tpsMs)}"
 }
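
For reference, the metrics path added by this patch reduces to the TokenMetrics holder plus calculateTPS. Below is a minimal, self-contained Kotlin sketch of that math outside the ViewModel: the class and function bodies are copied from the diff above, while the timing numbers are invented example values, not measurements from the engine.

// Standalone sketch of the metrics math introduced in this patch.
// TokenMetrics and calculateTPS mirror the diff; the inputs in main()
// are made-up example values.
data class TokenMetrics(
    val tokensCount: Int,
    val ttftMs: Long,
    val tpsMs: Float,
) {
    val text: String
        get() = "Tokens: $tokensCount, TTFT: ${ttftMs}ms, TPS: ${"%.1f".format(tpsMs)}"
}

// Same guard and formula as MainViewModel.calculateTPS.
fun calculateTPS(tokens: Int, timeMs: Long): Float {
    if (tokens <= 0 || timeMs <= 0) return 0f
    return (tokens.toFloat() * 1000f) / timeMs
}

fun main() {
    // Invented example: first token 230 ms after the request,
    // 42 tokens collected over 4200 ms in total.
    val generationStartTime = 0L
    val firstTokenTime = 230L
    val totalTimeMs = 4200L
    val tokenCount = 42

    val metrics = TokenMetrics(
        tokensCount = tokenCount,
        ttftMs = firstTokenTime - generationStartTime,
        tpsMs = calculateTPS(tokenCount, totalTimeMs)
    )
    println(metrics.text) // Tokens: 42, TTFT: 230ms, TPS: 10.0
}

With these inputs the completed bubble's footer would read "Tokens: 42, TTFT: 230ms, TPS: 10.0", which is the string ConversationScreen now renders via message.metrics.text instead of the previous randomly generated placeholder.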