vm: replace token metrics stubs with actual implementation

Han Yin 2025-04-12 22:32:32 -07:00
parent e47e3b77ee
commit 2a41c0e354
2 changed files with 127 additions and 57 deletions


@@ -65,6 +65,7 @@ import com.example.llama.revamp.navigation.NavigationActions
 import com.example.llama.revamp.ui.components.AppScaffold
 import com.example.llama.revamp.viewmodel.MainViewModel
 import com.example.llama.revamp.viewmodel.Message
+import com.example.llama.revamp.viewmodel.TokenMetrics
 import kotlinx.coroutines.launch

 /**
@@ -244,23 +245,22 @@ fun ConversationMessageList(
 fun MessageBubble(message: Message) {
     when (message) {
         is Message.User -> UserMessageBubble(
-            content = message.content,
-            formattedTime = message.formattedTime
-        )
-        is Message.Assistant -> AssistantMessageBubble(
-            content = message.content,
-            formattedTime = message.formattedTime,
-            isComplete = message.isComplete,
-            isThinking = !message.isComplete && message.content.isBlank(),
-            metrics = if (message.isComplete && message.content.isNotBlank()) {
-                // TODO-han.yin: Generate some example metrics for now
-                // This would come from the actual LLM engine in a real implementation
-                val tokenCount = message.content.split("\\s+".toRegex()).size
-                val ttft = (200 + (Math.random() * 80)).toInt()
-                val tps = 8.5 + (Math.random() * 1.5)
-                "Tokens: $tokenCount, TTFT: ${ttft}ms, TPS: ${"%.1f".format(tps)}"
-            } else null
+            formattedTime = message.formattedTime,
+            content = message.content
+        )
+        is Message.Assistant.Ongoing -> AssistantMessageBubble(
+            formattedTime = message.formattedTime,
+            content = message.content,
+            isThinking = message.content.isBlank(),
+            isComplete = false,
+            metrics = null
+        )
+        is Message.Assistant.Completed -> AssistantMessageBubble(
+            formattedTime = message.formattedTime,
+            content = message.content,
+            isThinking = false,
+            isComplete = true,
+            metrics = message.metrics.text
         )
     }
 }
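A side effect worth noting: because Assistant is now a sealed hierarchy (declared in the second file of this diff), the compiler checks the `when` above for exhaustiveness, which the old `isComplete` flag could not guarantee. A minimal standalone sketch of the pattern, using hypothetical stand-in types rather than the app's real ones:

sealed interface Msg {
    data class Ongoing(val content: String) : Msg
    data class Completed(val content: String, val metrics: String) : Msg
}

// No 'else' branch needed: if a third Msg subtype is ever added,
// this 'when' stops compiling until the new state is rendered.
fun render(m: Msg): String = when (m) {
    is Msg.Ongoing -> "…${m.content}"                  // still streaming
    is Msg.Completed -> "${m.content} [${m.metrics}]"  // finished, metrics attached
}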
@@ -304,10 +304,10 @@ fun UserMessageBubble(content: String, formattedTime: String) {
 @Composable
 fun AssistantMessageBubble(
-    content: String,
     formattedTime: String,
-    isComplete: Boolean,
+    content: String,
     isThinking: Boolean,
+    isComplete: Boolean,
     metrics: String? = null
 ) {
     Row(


@@ -33,11 +33,6 @@ class MainViewModel(
     private val _selectedModel = MutableStateFlow<ModelInfo?>(null)
     val selectedModel: StateFlow<ModelInfo?> = _selectedModel.asStateFlow()

-    // Benchmark parameters
-    private var pp: Int = 32
-    private var tg: Int = 32
-    private var pl: Int = 512
-
     // Messages in the conversation
     private val _messages = MutableStateFlow<List<Message>>(emptyList())
     val messages: StateFlow<List<Message>> = _messages.asStateFlow()
@@ -72,7 +67,7 @@ class MainViewModel(
      * Runs the benchmark with current parameters.
      */
     private suspend fun runBenchmark() {
-        inferenceEngine.bench(pp, tg, pl)
+        inferenceEngine.bench(512, 128, 1, 3)
     }

     /**
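The benchmark call replaces the removed pp/tg/pl fields with four positional literals. The engine's signature is not visible in this diff; if it mirrors llama.cpp's bench convention of (pp, tg, pl, nr) — prompt tokens, generated tokens, parallel sequences, repetitions — named arguments would keep the intent readable at the call site. A hypothetical sketch under that assumption:

// Assumed parameter order, mirroring llama.cpp's bench(pp, tg, pl, nr);
// the real InferenceEngine signature is not shown in this diff.
interface Benchmarkable {
    suspend fun bench(pp: Int, tg: Int, pl: Int, nr: Int): String
}

suspend fun runBenchmark(engine: Benchmarkable) {
    // Named arguments document what 512 / 128 / 1 / 3 mean.
    println(engine.bench(pp = 512, tg = 128, pl = 1, nr = 3))
}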
@@ -94,6 +89,14 @@ class MainViewModel(
         }
     }

+    /**
+     * Tracks token generation metrics
+     */
+    private var generationStartTime: Long = 0L
+    private var firstTokenTime: Long = 0L
+    private var tokenCount: Int = 0
+    private var isFirstToken: Boolean = true
+
     /**
      * Sends a user message and collects the response.
      */
@@ -111,13 +114,18 @@
         _messages.value = _messages.value + userMessage

         // Create placeholder for assistant message
-        val assistantMessage = Message.Assistant(
+        val assistantMessage = Message.Assistant.Ongoing(
             content = "",
-            timestamp = System.currentTimeMillis(),
-            isComplete = false
+            timestamp = System.currentTimeMillis()
         )
         _messages.value = _messages.value + assistantMessage

+        // Reset metrics tracking
+        generationStartTime = System.currentTimeMillis()
+        firstTokenTime = 0L
+        tokenCount = 0
+        isFirstToken = true
+
         // Get response from engine
         tokenCollectionJob = viewModelScope.launch {
             val response = StringBuilder()
@@ -129,12 +137,19 @@
                 val currentMessages = _messages.value.toMutableList()
                 if (currentMessages.size >= 2) {
                     val messageIndex = currentMessages.size - 1
-                    val currentAssistantMessage =
-                        currentMessages[messageIndex] as? Message.Assistant
+                    val currentAssistantMessage = currentMessages[messageIndex] as? Message.Assistant.Ongoing
                     if (currentAssistantMessage != null) {
-                        currentMessages[messageIndex] = currentAssistantMessage.copy(
+                        // Create metrics with error indication
+                        val errorMetrics = TokenMetrics(
+                            tokensCount = tokenCount,
+                            ttftMs = if (firstTokenTime > 0) firstTokenTime - generationStartTime else 0L,
+                            tpsMs = calculateTPS(tokenCount, System.currentTimeMillis() - generationStartTime)
+                        )
+                        currentMessages[messageIndex] = Message.Assistant.Completed(
                             content = "${response}[Error: ${e.message}]",
-                            isComplete = true
+                            timestamp = currentAssistantMessage.timestamp,
+                            metrics = errorMetrics
                         )
                         _messages.value = currentMessages
                     }
@@ -145,29 +160,50 @@
                 val currentMessages = _messages.value.toMutableList()
                 if (currentMessages.isNotEmpty()) {
                     val messageIndex = currentMessages.size - 1
-                    val currentAssistantMessage =
-                        currentMessages.getOrNull(messageIndex) as? Message.Assistant
+                    val currentAssistantMessage = currentMessages.getOrNull(messageIndex) as? Message.Assistant.Ongoing
                     if (currentAssistantMessage != null) {
-                        currentMessages[messageIndex] = currentAssistantMessage.copy(
-                            isComplete = true
+                        // Calculate final metrics
+                        val endTime = System.currentTimeMillis()
+                        val totalTimeMs = endTime - generationStartTime
+                        val metrics = TokenMetrics(
+                            tokensCount = tokenCount,
+                            ttftMs = if (firstTokenTime > 0) firstTokenTime - generationStartTime else 0L,
+                            tpsMs = calculateTPS(tokenCount, totalTimeMs)
+                        )
+                        currentMessages[messageIndex] = Message.Assistant.Completed(
+                            content = response.toString(),
+                            timestamp = currentAssistantMessage.timestamp,
+                            metrics = metrics
                         )
                         _messages.value = currentMessages
                     }
                 }
             }
             .collect { token ->
+                // Track first token time
+                if (isFirstToken && token.isNotBlank()) {
+                    firstTokenTime = System.currentTimeMillis()
+                    isFirstToken = false
+                }
+
+                // Count tokens - each non-empty emission is at least one token
+                if (token.isNotBlank()) {
+                    tokenCount++
+                }
+
                 response.append(token)

                 // Safely update the assistant message with the generated text
                 val currentMessages = _messages.value.toMutableList()
                 if (currentMessages.isNotEmpty()) {
                     val messageIndex = currentMessages.size - 1
-                    val currentAssistantMessage =
-                        currentMessages.getOrNull(messageIndex) as? Message.Assistant
+                    val currentAssistantMessage = currentMessages.getOrNull(messageIndex) as? Message.Assistant.Ongoing
                     if (currentAssistantMessage != null) {
-                        currentMessages[messageIndex] = currentAssistantMessage.copy(
+                        currentMessages[messageIndex] = Message.Assistant.Ongoing(
                             content = response.toString(),
-                            isComplete = false
+                            timestamp = currentAssistantMessage.timestamp
                         )
                         _messages.value = currentMessages
                     }
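Note how the counters approximate the metrics: tokenCount increments once per non-blank emission (the flow may batch several model tokens into one string, and whitespace-only emissions are appended to the text but never counted), and TTFT is stamped on the first non-blank emission. A self-contained sketch of the same bookkeeping over a canned token stream:

import kotlinx.coroutines.delay
import kotlinx.coroutines.flow.*
import kotlinx.coroutines.runBlocking

// Simulates the collect block above: blank emissions are appended to the
// response but not counted; the first non-blank emission sets the TTFT stamp.
fun main() = runBlocking {
    val fakeTokens = flow {
        listOf("Hello", " ", "world", "!").forEach { delay(50); emit(it) }
    }
    val start = System.currentTimeMillis()
    var firstTokenAt = 0L
    var count = 0
    val response = StringBuilder()
    fakeTokens.collect { token ->
        if (token.isNotBlank()) {
            if (firstTokenAt == 0L) firstTokenAt = System.currentTimeMillis()
            count++
        }
        response.append(token)
    }
    val totalMs = System.currentTimeMillis() - start
    // e.g. "Hello world!" tokens=3 ttft=~50ms for this canned stream
    println("\"$response\" tokens=$count ttft=${firstTokenAt - start}ms tps=${count * 1000f / totalMs}")
}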
@@ -178,12 +214,19 @@
                 val currentMessages = _messages.value.toMutableList()
                 if (currentMessages.isNotEmpty()) {
                     val messageIndex = currentMessages.size - 1
-                    val currentAssistantMessage =
-                        currentMessages.getOrNull(messageIndex) as? Message.Assistant
+                    val currentAssistantMessage = currentMessages.getOrNull(messageIndex) as? Message.Assistant.Ongoing
                     if (currentAssistantMessage != null) {
-                        currentMessages[messageIndex] = currentAssistantMessage.copy(
+                        // Create metrics with error indication
+                        val errorMetrics = TokenMetrics(
+                            tokensCount = tokenCount,
+                            ttftMs = if (firstTokenTime > 0) firstTokenTime - generationStartTime else 0L,
+                            tpsMs = calculateTPS(tokenCount, System.currentTimeMillis() - generationStartTime)
+                        )
+                        currentMessages[messageIndex] = Message.Assistant.Completed(
                             content = "${response}[Error: ${e.message}]",
-                            isComplete = true
+                            timestamp = currentAssistantMessage.timestamp,
+                            metrics = errorMetrics
                         )
                         _messages.value = currentMessages
                     }
@@ -192,6 +235,14 @@
         }
     }

+    /**
+     * Calculate tokens per second.
+     */
+    private fun calculateTPS(tokens: Int, timeMs: Long): Float {
+        if (tokens <= 0 || timeMs <= 0) return 0f
+        return (tokens.toFloat() * 1000f) / timeMs
+    }
+
     /**
      * Unloads the currently loaded model.
      */
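A quick standalone check of the helper's arithmetic and guard clauses (same body as above; note the result feeds TokenMetrics.tpsMs, which despite the Ms suffix holds a tokens-per-second rate):

// 42 tokens in 3.5 s -> 42 * 1000 / 3500 = 12.0 tokens/sec
fun calculateTPS(tokens: Int, timeMs: Long): Float {
    if (tokens <= 0 || timeMs <= 0) return 0f
    return (tokens.toFloat() * 1000f) / timeMs
}

fun main() {
    println(calculateTPS(42, 3_500))  // 12.0
    println(calculateTPS(0, 1_000))   // 0.0 — no tokens generated
    println(calculateTPS(10, 0))      // 0.0 — avoids division by zero
}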
@@ -206,6 +257,9 @@
         inferenceEngine.unloadModel()
     }

+    /**
+     * Checks if a model is currently being loaded.
+     */
     fun isModelLoading() =
         engineState.value.let {
             it is InferenceEngine.State.LoadingModel
@@ -213,7 +267,7 @@
         }

     /**
-     * Checks if a model is currently loaded.
+     * Checks if a model has already been loaded.
      */
     fun isModelLoaded() =
         engineState.value.let {
@@ -247,24 +301,40 @@
  * Sealed class representing messages in a conversation.
  */
 sealed class Message {
-    abstract val content: String
     abstract val timestamp: Long
+    abstract val content: String

     val formattedTime: String
-        get() {
-            val formatter = SimpleDateFormat("h:mm a", Locale.getDefault())
-            return formatter.format(Date(timestamp))
-        }
+        get() = datetimeFormatter.format(Date(timestamp))

     data class User(
-        override val content: String,
-        override val timestamp: Long
+        override val timestamp: Long,
+        override val content: String
     ) : Message()

-    // TODO-han.yin: break down into ongoing & completed message subtypes
-    data class Assistant(
-        override val content: String,
-        override val timestamp: Long,
-        val isComplete: Boolean = true
-    ) : Message()
+    sealed class Assistant : Message() {
+        data class Ongoing(
+            override val timestamp: Long,
+            override val content: String,
+        ) : Assistant()
+
+        data class Completed(
+            override val timestamp: Long,
+            override val content: String,
+            val metrics: TokenMetrics
+        ) : Assistant()
+    }
+
+    companion object {
+        private val datetimeFormatter by lazy { SimpleDateFormat("h:mm a", Locale.getDefault()) }
+    }
 }
+
+data class TokenMetrics(
+    val tokensCount: Int,
+    val ttftMs: Long,
+    val tpsMs: Float,
+) {
+    val text: String
+        get() = "Tokens: $tokensCount, TTFT: ${ttftMs}ms, TPS: ${"%.1f".format(tpsMs)}"
+}
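To illustrate the new shapes end to end — an Ongoing placeholder promoted to Completed with its original timestamp preserved, and the metrics line rendered — here is a small usage sketch assuming the declarations above are in scope (the metric values are made up):

fun promoteExample() {
    val ongoing = Message.Assistant.Ongoing(
        timestamp = System.currentTimeMillis(),
        content = "Hello"
    )
    // Promotion keeps the placeholder's timestamp, as the ViewModel does.
    val completed = Message.Assistant.Completed(
        timestamp = ongoing.timestamp,
        content = ongoing.content + ", world!",
        metrics = TokenMetrics(tokensCount = 3, ttftMs = 240, tpsMs = 9.5f)
    )
    println(completed.metrics.text)   // Tokens: 3, TTFT: 240ms, TPS: 9.5
    println(completed.formattedTime)  // e.g. 10:32 PM
}

One caveat on the shared formatter: SimpleDateFormat is not thread-safe, so the lazily created companion instance is only safe while formattedTime is read from a single thread, which holds for Compose reads on the main thread.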