vm: replace token metrics stubs with actual implementation

This commit is contained in:
Han Yin 2025-04-12 22:32:32 -07:00
parent e47e3b77ee
commit 2a41c0e354
2 changed files with 127 additions and 57 deletions

View File

@ -65,6 +65,7 @@ import com.example.llama.revamp.navigation.NavigationActions
import com.example.llama.revamp.ui.components.AppScaffold
import com.example.llama.revamp.viewmodel.MainViewModel
import com.example.llama.revamp.viewmodel.Message
import com.example.llama.revamp.viewmodel.TokenMetrics
import kotlinx.coroutines.launch
/**
@ -244,23 +245,22 @@ fun ConversationMessageList(
fun MessageBubble(message: Message) {
when (message) {
is Message.User -> UserMessageBubble(
content = message.content,
formattedTime = message.formattedTime
)
is Message.Assistant -> AssistantMessageBubble(
content = message.content,
formattedTime = message.formattedTime,
isComplete = message.isComplete,
isThinking = !message.isComplete && message.content.isBlank(),
metrics = if (message.isComplete && message.content.isNotBlank()) {
// TODO-han.yin: Generate some example metrics for now
// This would come from the actual LLM engine in a real implementation
val tokenCount = message.content.split("\\s+".toRegex()).size
val ttft = (200 + (Math.random() * 80)).toInt()
val tps = 8.5 + (Math.random() * 1.5)
"Tokens: $tokenCount, TTFT: ${ttft}ms, TPS: ${"%.1f".format(tps)}"
} else null
content = message.content
)
is Message.Assistant.Ongoing -> AssistantMessageBubble(
formattedTime = message.formattedTime,
content = message.content,
isThinking = message.content.isBlank(),
isComplete = false,
metrics = null
)
is Message.Assistant.Completed -> AssistantMessageBubble(
formattedTime = message.formattedTime,
content = message.content,
isThinking = false,
isComplete = true,
metrics = message.metrics.text
)
}
}
@ -304,10 +304,10 @@ fun UserMessageBubble(content: String, formattedTime: String) {
@Composable
fun AssistantMessageBubble(
content: String,
formattedTime: String,
isComplete: Boolean,
content: String,
isThinking: Boolean,
isComplete: Boolean,
metrics: String? = null
) {
Row(

View File

@ -33,11 +33,6 @@ class MainViewModel(
private val _selectedModel = MutableStateFlow<ModelInfo?>(null)
val selectedModel: StateFlow<ModelInfo?> = _selectedModel.asStateFlow()
// Benchmark parameters
private var pp: Int = 32
private var tg: Int = 32
private var pl: Int = 512
// Messages in the conversation
private val _messages = MutableStateFlow<List<Message>>(emptyList())
val messages: StateFlow<List<Message>> = _messages.asStateFlow()
@ -72,7 +67,7 @@ class MainViewModel(
* Runs the benchmark with current parameters.
*/
private suspend fun runBenchmark() {
inferenceEngine.bench(pp, tg, pl)
inferenceEngine.bench(512, 128, 1, 3)
}
/**
@ -94,6 +89,14 @@ class MainViewModel(
}
}
/**
* Tracks token generation metrics
*/
private var generationStartTime: Long = 0L
private var firstTokenTime: Long = 0L
private var tokenCount: Int = 0
private var isFirstToken: Boolean = true
/**
* Sends a user message and collects the response.
*/
@ -111,13 +114,18 @@ class MainViewModel(
_messages.value = _messages.value + userMessage
// Create placeholder for assistant message
val assistantMessage = Message.Assistant(
val assistantMessage = Message.Assistant.Ongoing(
content = "",
timestamp = System.currentTimeMillis(),
isComplete = false
timestamp = System.currentTimeMillis()
)
_messages.value = _messages.value + assistantMessage
// Reset metrics tracking
generationStartTime = System.currentTimeMillis()
firstTokenTime = 0L
tokenCount = 0
isFirstToken = true
// Get response from engine
tokenCollectionJob = viewModelScope.launch {
val response = StringBuilder()
@ -129,12 +137,19 @@ class MainViewModel(
val currentMessages = _messages.value.toMutableList()
if (currentMessages.size >= 2) {
val messageIndex = currentMessages.size - 1
val currentAssistantMessage =
currentMessages[messageIndex] as? Message.Assistant
val currentAssistantMessage = currentMessages[messageIndex] as? Message.Assistant.Ongoing
if (currentAssistantMessage != null) {
currentMessages[messageIndex] = currentAssistantMessage.copy(
// Create metrics with error indication
val errorMetrics = TokenMetrics(
tokensCount = tokenCount,
ttftMs = if (firstTokenTime > 0) firstTokenTime - generationStartTime else 0L,
tpsMs = calculateTPS(tokenCount, System.currentTimeMillis() - generationStartTime)
)
currentMessages[messageIndex] = Message.Assistant.Completed(
content = "${response}[Error: ${e.message}]",
isComplete = true
timestamp = currentAssistantMessage.timestamp,
metrics = errorMetrics
)
_messages.value = currentMessages
}
@ -145,29 +160,50 @@ class MainViewModel(
val currentMessages = _messages.value.toMutableList()
if (currentMessages.isNotEmpty()) {
val messageIndex = currentMessages.size - 1
val currentAssistantMessage =
currentMessages.getOrNull(messageIndex) as? Message.Assistant
val currentAssistantMessage = currentMessages.getOrNull(messageIndex) as? Message.Assistant.Ongoing
if (currentAssistantMessage != null) {
currentMessages[messageIndex] = currentAssistantMessage.copy(
isComplete = true
// Calculate final metrics
val endTime = System.currentTimeMillis()
val totalTimeMs = endTime - generationStartTime
val metrics = TokenMetrics(
tokensCount = tokenCount,
ttftMs = if (firstTokenTime > 0) firstTokenTime - generationStartTime else 0L,
tpsMs = calculateTPS(tokenCount, totalTimeMs)
)
currentMessages[messageIndex] = Message.Assistant.Completed(
content = response.toString(),
timestamp = currentAssistantMessage.timestamp,
metrics = metrics
)
_messages.value = currentMessages
}
}
}
.collect { token ->
// Track first token time
if (isFirstToken && token.isNotBlank()) {
firstTokenTime = System.currentTimeMillis()
isFirstToken = false
}
// Count tokens - each non-empty emission is at least one token
if (token.isNotBlank()) {
tokenCount++
}
response.append(token)
// Safely update the assistant message with the generated text
val currentMessages = _messages.value.toMutableList()
if (currentMessages.isNotEmpty()) {
val messageIndex = currentMessages.size - 1
val currentAssistantMessage =
currentMessages.getOrNull(messageIndex) as? Message.Assistant
val currentAssistantMessage = currentMessages.getOrNull(messageIndex) as? Message.Assistant.Ongoing
if (currentAssistantMessage != null) {
currentMessages[messageIndex] = currentAssistantMessage.copy(
currentMessages[messageIndex] = Message.Assistant.Ongoing(
content = response.toString(),
isComplete = false
timestamp = currentAssistantMessage.timestamp
)
_messages.value = currentMessages
}
@ -178,12 +214,19 @@ class MainViewModel(
val currentMessages = _messages.value.toMutableList()
if (currentMessages.isNotEmpty()) {
val messageIndex = currentMessages.size - 1
val currentAssistantMessage =
currentMessages.getOrNull(messageIndex) as? Message.Assistant
val currentAssistantMessage = currentMessages.getOrNull(messageIndex) as? Message.Assistant.Ongoing
if (currentAssistantMessage != null) {
currentMessages[messageIndex] = currentAssistantMessage.copy(
// Create metrics with error indication
val errorMetrics = TokenMetrics(
tokensCount = tokenCount,
ttftMs = if (firstTokenTime > 0) firstTokenTime - generationStartTime else 0L,
tpsMs = calculateTPS(tokenCount, System.currentTimeMillis() - generationStartTime)
)
currentMessages[messageIndex] = Message.Assistant.Completed(
content = "${response}[Error: ${e.message}]",
isComplete = true
timestamp = currentAssistantMessage.timestamp,
metrics = errorMetrics
)
_messages.value = currentMessages
}
@ -192,6 +235,14 @@ class MainViewModel(
}
}
/**
* Calculate tokens per second.
*/
private fun calculateTPS(tokens: Int, timeMs: Long): Float {
if (tokens <= 0 || timeMs <= 0) return 0f
return (tokens.toFloat() * 1000f) / timeMs
}
/**
* Unloads the currently loaded model.
*/
@ -206,6 +257,9 @@ class MainViewModel(
inferenceEngine.unloadModel()
}
/**
* Checks if a model is currently being loaded.
*/
fun isModelLoading() =
engineState.value.let {
it is InferenceEngine.State.LoadingModel
@ -213,7 +267,7 @@ class MainViewModel(
}
/**
* Checks if a model is currently loaded.
* Checks if a model has already been loaded.
*/
fun isModelLoaded() =
engineState.value.let {
@ -247,24 +301,40 @@ class MainViewModel(
* Sealed class representing messages in a conversation.
*/
sealed class Message {
abstract val content: String
abstract val timestamp: Long
abstract val content: String
val formattedTime: String
get() {
val formatter = SimpleDateFormat("h:mm a", Locale.getDefault())
return formatter.format(Date(timestamp))
}
get() = datetimeFormatter.format(Date(timestamp))
data class User(
override val content: String,
override val timestamp: Long
override val timestamp: Long,
override val content: String
) : Message()
// TODO-han.yin: break down into ongoing & completed message subtypes
data class Assistant(
override val content: String,
override val timestamp: Long,
val isComplete: Boolean = true
) : Message()
sealed class Assistant : Message() {
data class Ongoing(
override val timestamp: Long,
override val content: String,
) : Assistant()
data class Completed(
override val timestamp: Long,
override val content: String,
val metrics: TokenMetrics
) : Assistant()
}
companion object {
private val datetimeFormatter by lazy { SimpleDateFormat("h:mm a", Locale.getDefault()) }
}
}
data class TokenMetrics(
val tokensCount: Int,
val ttftMs: Long,
val tpsMs: Float,
) {
val text: String
get() = "Tokens: $tokensCount, TTFT: ${ttftMs}ms, TPS: ${"%.1f".format(tpsMs)}"
}