vm: replace token metrics stubs with actual implementation
parent e47e3b77ee
commit 2a41c0e354
@@ -65,6 +65,7 @@ import com.example.llama.revamp.navigation.NavigationActions
 import com.example.llama.revamp.ui.components.AppScaffold
 import com.example.llama.revamp.viewmodel.MainViewModel
 import com.example.llama.revamp.viewmodel.Message
+import com.example.llama.revamp.viewmodel.TokenMetrics
 import kotlinx.coroutines.launch
 
 /**
@@ -244,23 +245,22 @@ fun ConversationMessageList(
 fun MessageBubble(message: Message) {
     when (message) {
         is Message.User -> UserMessageBubble(
             content = message.content,
             formattedTime = message.formattedTime
         )
 
-        is Message.Assistant -> AssistantMessageBubble(
-            content = message.content,
-            formattedTime = message.formattedTime,
-            isComplete = message.isComplete,
-            isThinking = !message.isComplete && message.content.isBlank(),
-            metrics = if (message.isComplete && message.content.isNotBlank()) {
-                // TODO-han.yin: Generate some example metrics for now
-                // This would come from the actual LLM engine in a real implementation
-                val tokenCount = message.content.split("\\s+".toRegex()).size
-                val ttft = (200 + (Math.random() * 80)).toInt()
-                val tps = 8.5 + (Math.random() * 1.5)
-                "Tokens: $tokenCount, TTFT: ${ttft}ms, TPS: ${"%.1f".format(tps)}"
-            } else null
+        is Message.Assistant.Ongoing -> AssistantMessageBubble(
+            formattedTime = message.formattedTime,
+            content = message.content,
+            isThinking = message.content.isBlank(),
+            isComplete = false,
+            metrics = null
+        )
+
+        is Message.Assistant.Completed -> AssistantMessageBubble(
+            formattedTime = message.formattedTime,
+            content = message.content,
+            isThinking = false,
+            isComplete = true,
+            metrics = message.metrics.text
         )
     }
 }
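Splitting Assistant into Ongoing and Completed makes the `when` above exhaustive over the sealed hierarchy, so the compiler rejects any dispatch that forgets a state. A minimal sketch of the same pattern, where render() is a hypothetical stand-in for the bubble composables:

    // Sketch only: render() stands in for the real composables.
    fun render(message: Message): String = when (message) {
        is Message.User -> "user: ${message.content}"
        is Message.Assistant.Ongoing -> "assistant (typing): ${message.content}"
        is Message.Assistant.Completed -> "${message.content} [${message.metrics.text}]"
        // No else branch needed: the sealed classes make this exhaustive.
    }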
@@ -304,10 +304,10 @@ fun UserMessageBubble(content: String, formattedTime: String) {
 
 @Composable
 fun AssistantMessageBubble(
-    content: String,
     formattedTime: String,
-    isComplete: Boolean,
+    content: String,
     isThinking: Boolean,
+    isComplete: Boolean,
     metrics: String? = null
 ) {
     Row(
@@ -33,11 +33,6 @@ class MainViewModel(
     private val _selectedModel = MutableStateFlow<ModelInfo?>(null)
     val selectedModel: StateFlow<ModelInfo?> = _selectedModel.asStateFlow()
 
-    // Benchmark parameters
-    private var pp: Int = 32
-    private var tg: Int = 32
-    private var pl: Int = 512
-
     // Messages in the conversation
     private val _messages = MutableStateFlow<List<Message>>(emptyList())
     val messages: StateFlow<List<Message>> = _messages.asStateFlow()
@@ -72,7 +67,7 @@ class MainViewModel(
      * Runs the benchmark with current parameters.
      */
     private suspend fun runBenchmark() {
-        inferenceEngine.bench(pp, tg, pl)
+        inferenceEngine.bench(512, 128, 1, 3)
     }
 
     /**
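The hard-coded arguments replace the removed pp/tg/pl fields. If bench here mirrors the four-argument signature in llama.cpp's llama.android example, bench(pp, tg, pl, nr), the call reads as below; the parameter meanings are an assumption, not something this diff confirms:

    // Assumed parameter meanings, following llama.cpp's Android example:
    inferenceEngine.bench(
        512,  // pp: prompt tokens to process
        128,  // tg: tokens to generate
        1,    // pl: parallel sequences
        3     // nr: repetitions
    )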
@@ -94,6 +89,14 @@ class MainViewModel(
         }
     }
 
+    /**
+     * Tracks token generation metrics
+     */
+    private var generationStartTime: Long = 0L
+    private var firstTokenTime: Long = 0L
+    private var tokenCount: Int = 0
+    private var isFirstToken: Boolean = true
+
     /**
      * Sends a user message and collects the response.
      */
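The four fields describe one generation pass: generationStartTime anchors both measurements, isFirstToken latches firstTokenTime exactly once, and tokenCount accumulates emissions. A worked example with made-up timestamps:

    // Hypothetical pass:
    //   generationStartTime = 10_000 ms  (sendMessage called)
    //   firstTokenTime      = 10_240 ms  (first non-blank token)
    //   generation ends at    19_600 ms  with tokenCount = 120
    // TTFT = 10_240 - 10_000                = 240 ms
    // TPS  = 120 * 1000 / (19_600 - 10_000) = 12.5 tokens/s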
@@ -111,13 +114,18 @@ class MainViewModel(
         _messages.value = _messages.value + userMessage
 
         // Create placeholder for assistant message
-        val assistantMessage = Message.Assistant(
+        val assistantMessage = Message.Assistant.Ongoing(
             content = "",
-            timestamp = System.currentTimeMillis(),
-            isComplete = false
+            timestamp = System.currentTimeMillis()
         )
         _messages.value = _messages.value + assistantMessage
 
+        // Reset metrics tracking
+        generationStartTime = System.currentTimeMillis()
+        firstTokenTime = 0L
+        tokenCount = 0
+        isFirstToken = true
+
         // Get response from engine
         tokenCollectionJob = viewModelScope.launch {
             val response = StringBuilder()
@@ -129,12 +137,19 @@ class MainViewModel(
                 val currentMessages = _messages.value.toMutableList()
                 if (currentMessages.size >= 2) {
                     val messageIndex = currentMessages.size - 1
-                    val currentAssistantMessage =
-                        currentMessages[messageIndex] as? Message.Assistant
+                    val currentAssistantMessage = currentMessages[messageIndex] as? Message.Assistant.Ongoing
                     if (currentAssistantMessage != null) {
-                        currentMessages[messageIndex] = currentAssistantMessage.copy(
+                        // Create metrics with error indication
+                        val errorMetrics = TokenMetrics(
+                            tokensCount = tokenCount,
+                            ttftMs = if (firstTokenTime > 0) firstTokenTime - generationStartTime else 0L,
+                            tpsMs = calculateTPS(tokenCount, System.currentTimeMillis() - generationStartTime)
+                        )
+
+                        currentMessages[messageIndex] = Message.Assistant.Completed(
                             content = "${response}[Error: ${e.message}]",
-                            isComplete = true
+                            timestamp = currentAssistantMessage.timestamp,
+                            metrics = errorMetrics
                         )
                         _messages.value = currentMessages
                     }
@@ -145,29 +160,50 @@ class MainViewModel(
                     val currentMessages = _messages.value.toMutableList()
                     if (currentMessages.isNotEmpty()) {
                         val messageIndex = currentMessages.size - 1
-                        val currentAssistantMessage =
-                            currentMessages.getOrNull(messageIndex) as? Message.Assistant
+                        val currentAssistantMessage = currentMessages.getOrNull(messageIndex) as? Message.Assistant.Ongoing
                         if (currentAssistantMessage != null) {
-                            currentMessages[messageIndex] = currentAssistantMessage.copy(
-                                isComplete = true
+                            // Calculate final metrics
+                            val endTime = System.currentTimeMillis()
+                            val totalTimeMs = endTime - generationStartTime
+
+                            val metrics = TokenMetrics(
+                                tokensCount = tokenCount,
+                                ttftMs = if (firstTokenTime > 0) firstTokenTime - generationStartTime else 0L,
+                                tpsMs = calculateTPS(tokenCount, totalTimeMs)
+                            )
+
+                            currentMessages[messageIndex] = Message.Assistant.Completed(
+                                content = response.toString(),
+                                timestamp = currentAssistantMessage.timestamp,
+                                metrics = metrics
                             )
                             _messages.value = currentMessages
                         }
                     }
                 }
                 .collect { token ->
+                    // Track first token time
+                    if (isFirstToken && token.isNotBlank()) {
+                        firstTokenTime = System.currentTimeMillis()
+                        isFirstToken = false
+                    }
+
+                    // Count tokens - each non-empty emission is at least one token
+                    if (token.isNotBlank()) {
+                        tokenCount++
+                    }
+
                     response.append(token)
+
                     // Safely update the assistant message with the generated text
                     val currentMessages = _messages.value.toMutableList()
                     if (currentMessages.isNotEmpty()) {
                         val messageIndex = currentMessages.size - 1
-                        val currentAssistantMessage =
-                            currentMessages.getOrNull(messageIndex) as? Message.Assistant
+                        val currentAssistantMessage = currentMessages.getOrNull(messageIndex) as? Message.Assistant.Ongoing
                         if (currentAssistantMessage != null) {
-                            currentMessages[messageIndex] = currentAssistantMessage.copy(
+                            currentMessages[messageIndex] = Message.Assistant.Ongoing(
                                 content = response.toString(),
-                                isComplete = false
+                                timestamp = currentAssistantMessage.timestamp
                             )
                             _messages.value = currentMessages
                         }
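The collector counts every non-blank emission as one token, which is exact only when the engine emits one token per flow item; whitespace-only emissions extend the text but not the count. A self-contained sketch of that bookkeeping, using a hard-coded flow in place of the engine:

    import kotlinx.coroutines.flow.flowOf
    import kotlinx.coroutines.runBlocking

    fun main() = runBlocking {
        var tokenCount = 0
        var isFirstToken = true
        var firstTokenTime = 0L
        val start = System.currentTimeMillis()

        flowOf("Hello", ",", " ", "world", "!").collect { token ->
            if (isFirstToken && token.isNotBlank()) {
                firstTokenTime = System.currentTimeMillis()  // latch TTFT once
                isFirstToken = false
            }
            if (token.isNotBlank()) tokenCount++  // the lone " " is not counted
        }
        println("tokens=$tokenCount, ttft=${firstTokenTime - start}ms")  // tokens=4
    }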
@@ -178,12 +214,19 @@ class MainViewModel(
             val currentMessages = _messages.value.toMutableList()
             if (currentMessages.isNotEmpty()) {
                 val messageIndex = currentMessages.size - 1
-                val currentAssistantMessage =
-                    currentMessages.getOrNull(messageIndex) as? Message.Assistant
+                val currentAssistantMessage = currentMessages.getOrNull(messageIndex) as? Message.Assistant.Ongoing
                 if (currentAssistantMessage != null) {
-                    currentMessages[messageIndex] = currentAssistantMessage.copy(
+                    // Create metrics with error indication
+                    val errorMetrics = TokenMetrics(
+                        tokensCount = tokenCount,
+                        ttftMs = if (firstTokenTime > 0) firstTokenTime - generationStartTime else 0L,
+                        tpsMs = calculateTPS(tokenCount, System.currentTimeMillis() - generationStartTime)
+                    )
+
+                    currentMessages[messageIndex] = Message.Assistant.Completed(
                         content = "${response}[Error: ${e.message}]",
-                        isComplete = true
+                        timestamp = currentAssistantMessage.timestamp,
+                        metrics = errorMetrics
                     )
                     _messages.value = currentMessages
                 }
@@ -192,6 +235,14 @@ class MainViewModel(
         }
     }
 
+    /**
+     * Calculate tokens per second.
+     */
+    private fun calculateTPS(tokens: Int, timeMs: Long): Float {
+        if (tokens <= 0 || timeMs <= 0) return 0f
+        return (tokens.toFloat() * 1000f) / timeMs
+    }
+
     /**
      * Unloads the currently loaded model.
      */
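calculateTPS turns a token count and an elapsed wall-clock span into tokens per second, returning 0f in the degenerate cases rather than dividing by zero:

    calculateTPS(120, 9_600)  // == 120f * 1000f / 9_600 == 12.5f tokens/s
    calculateTPS(0, 9_600)    // == 0f: nothing was generated
    calculateTPS(120, 0)      // == 0f: guards the division by zero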
@@ -206,6 +257,9 @@ class MainViewModel(
         inferenceEngine.unloadModel()
     }
 
+    /**
+     * Checks if a model is currently being loaded.
+     */
     fun isModelLoading() =
         engineState.value.let {
             it is InferenceEngine.State.LoadingModel
@@ -213,7 +267,7 @@ class MainViewModel(
         }
 
     /**
-     * Checks if a model is currently loaded.
+     * Checks if a model has already been loaded.
      */
     fun isModelLoaded() =
         engineState.value.let {
@@ -247,24 +301,40 @@ class MainViewModel(
  * Sealed class representing messages in a conversation.
  */
 sealed class Message {
-    abstract val content: String
     abstract val timestamp: Long
+    abstract val content: String
 
     val formattedTime: String
-        get() {
-            val formatter = SimpleDateFormat("h:mm a", Locale.getDefault())
-            return formatter.format(Date(timestamp))
-        }
+        get() = datetimeFormatter.format(Date(timestamp))
 
     data class User(
-        override val content: String,
-        override val timestamp: Long
+        override val timestamp: Long,
+        override val content: String
     ) : Message()
 
-    // TODO-han.yin: break down into ongoing & completed message subtypes
-    data class Assistant(
-        override val content: String,
-        override val timestamp: Long,
-        val isComplete: Boolean = true
-    ) : Message()
+    sealed class Assistant : Message() {
+        data class Ongoing(
+            override val timestamp: Long,
+            override val content: String,
+        ) : Assistant()
+
+        data class Completed(
+            override val timestamp: Long,
+            override val content: String,
+            val metrics: TokenMetrics
+        ) : Assistant()
+    }
+
+    companion object {
+        private val datetimeFormatter by lazy { SimpleDateFormat("h:mm a", Locale.getDefault()) }
+    }
 }
+
+data class TokenMetrics(
+    val tokensCount: Int,
+    val ttftMs: Long,
+    val tpsMs: Float,
+) {
+    val text: String
+        get() = "Tokens: $tokensCount, TTFT: ${ttftMs}ms, TPS: ${"%.1f".format(tpsMs)}"
+}
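End to end, a completed message now carries structured metrics and the UI renders the derived text property instead of hand-building a string. A minimal construction with made-up values:

    val done = Message.Assistant.Completed(
        timestamp = System.currentTimeMillis(),
        content = "Hello, world!",
        metrics = TokenMetrics(tokensCount = 120, ttftMs = 240L, tpsMs = 12.5f)
    )
    println(done.metrics.text)  // Tokens: 120, TTFT: 240ms, TPS: 12.5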