vm: replace token metrics stubs with actual implementation
This commit is contained in:
parent
e47e3b77ee
commit
2a41c0e354
|
|
@ -65,6 +65,7 @@ import com.example.llama.revamp.navigation.NavigationActions
|
||||||
import com.example.llama.revamp.ui.components.AppScaffold
|
import com.example.llama.revamp.ui.components.AppScaffold
|
||||||
import com.example.llama.revamp.viewmodel.MainViewModel
|
import com.example.llama.revamp.viewmodel.MainViewModel
|
||||||
import com.example.llama.revamp.viewmodel.Message
|
import com.example.llama.revamp.viewmodel.Message
|
||||||
|
import com.example.llama.revamp.viewmodel.TokenMetrics
|
||||||
import kotlinx.coroutines.launch
|
import kotlinx.coroutines.launch
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
@ -244,23 +245,22 @@ fun ConversationMessageList(
|
||||||
fun MessageBubble(message: Message) {
|
fun MessageBubble(message: Message) {
|
||||||
when (message) {
|
when (message) {
|
||||||
is Message.User -> UserMessageBubble(
|
is Message.User -> UserMessageBubble(
|
||||||
content = message.content,
|
|
||||||
formattedTime = message.formattedTime
|
|
||||||
)
|
|
||||||
|
|
||||||
is Message.Assistant -> AssistantMessageBubble(
|
|
||||||
content = message.content,
|
|
||||||
formattedTime = message.formattedTime,
|
formattedTime = message.formattedTime,
|
||||||
isComplete = message.isComplete,
|
content = message.content
|
||||||
isThinking = !message.isComplete && message.content.isBlank(),
|
)
|
||||||
metrics = if (message.isComplete && message.content.isNotBlank()) {
|
is Message.Assistant.Ongoing -> AssistantMessageBubble(
|
||||||
// TODO-han.yin: Generate some example metrics for now
|
formattedTime = message.formattedTime,
|
||||||
// This would come from the actual LLM engine in a real implementation
|
content = message.content,
|
||||||
val tokenCount = message.content.split("\\s+".toRegex()).size
|
isThinking = message.content.isBlank(),
|
||||||
val ttft = (200 + (Math.random() * 80)).toInt()
|
isComplete = false,
|
||||||
val tps = 8.5 + (Math.random() * 1.5)
|
metrics = null
|
||||||
"Tokens: $tokenCount, TTFT: ${ttft}ms, TPS: ${"%.1f".format(tps)}"
|
)
|
||||||
} else null
|
is Message.Assistant.Completed -> AssistantMessageBubble(
|
||||||
|
formattedTime = message.formattedTime,
|
||||||
|
content = message.content,
|
||||||
|
isThinking = false,
|
||||||
|
isComplete = true,
|
||||||
|
metrics = message.metrics.text
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -304,10 +304,10 @@ fun UserMessageBubble(content: String, formattedTime: String) {
|
||||||
|
|
||||||
@Composable
|
@Composable
|
||||||
fun AssistantMessageBubble(
|
fun AssistantMessageBubble(
|
||||||
content: String,
|
|
||||||
formattedTime: String,
|
formattedTime: String,
|
||||||
isComplete: Boolean,
|
content: String,
|
||||||
isThinking: Boolean,
|
isThinking: Boolean,
|
||||||
|
isComplete: Boolean,
|
||||||
metrics: String? = null
|
metrics: String? = null
|
||||||
) {
|
) {
|
||||||
Row(
|
Row(
|
||||||
|
|
|
||||||
|
|
@ -33,11 +33,6 @@ class MainViewModel(
|
||||||
private val _selectedModel = MutableStateFlow<ModelInfo?>(null)
|
private val _selectedModel = MutableStateFlow<ModelInfo?>(null)
|
||||||
val selectedModel: StateFlow<ModelInfo?> = _selectedModel.asStateFlow()
|
val selectedModel: StateFlow<ModelInfo?> = _selectedModel.asStateFlow()
|
||||||
|
|
||||||
// Benchmark parameters
|
|
||||||
private var pp: Int = 32
|
|
||||||
private var tg: Int = 32
|
|
||||||
private var pl: Int = 512
|
|
||||||
|
|
||||||
// Messages in the conversation
|
// Messages in the conversation
|
||||||
private val _messages = MutableStateFlow<List<Message>>(emptyList())
|
private val _messages = MutableStateFlow<List<Message>>(emptyList())
|
||||||
val messages: StateFlow<List<Message>> = _messages.asStateFlow()
|
val messages: StateFlow<List<Message>> = _messages.asStateFlow()
|
||||||
|
|
@ -72,7 +67,7 @@ class MainViewModel(
|
||||||
* Runs the benchmark with current parameters.
|
* Runs the benchmark with fixed, hard-coded parameters.
|
||||||
*/
|
*/
|
||||||
private suspend fun runBenchmark() {
|
private suspend fun runBenchmark() {
|
||||||
inferenceEngine.bench(pp, tg, pl)
|
inferenceEngine.bench(512, 128, 1, 3)
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
@ -94,6 +89,14 @@ class MainViewModel(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Tracks token generation metrics
|
||||||
|
*/
|
||||||
|
private var generationStartTime: Long = 0L
|
||||||
|
private var firstTokenTime: Long = 0L
|
||||||
|
private var tokenCount: Int = 0
|
||||||
|
private var isFirstToken: Boolean = true
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Sends a user message and collects the response.
|
* Sends a user message and collects the response.
|
||||||
*/
|
*/
|
||||||
|
|
@ -111,13 +114,18 @@ class MainViewModel(
|
||||||
_messages.value = _messages.value + userMessage
|
_messages.value = _messages.value + userMessage
|
||||||
|
|
||||||
// Create placeholder for assistant message
|
// Create placeholder for assistant message
|
||||||
val assistantMessage = Message.Assistant(
|
val assistantMessage = Message.Assistant.Ongoing(
|
||||||
content = "",
|
content = "",
|
||||||
timestamp = System.currentTimeMillis(),
|
timestamp = System.currentTimeMillis()
|
||||||
isComplete = false
|
|
||||||
)
|
)
|
||||||
_messages.value = _messages.value + assistantMessage
|
_messages.value = _messages.value + assistantMessage
|
||||||
|
|
||||||
|
// Reset metrics tracking
|
||||||
|
generationStartTime = System.currentTimeMillis()
|
||||||
|
firstTokenTime = 0L
|
||||||
|
tokenCount = 0
|
||||||
|
isFirstToken = true
|
||||||
|
|
||||||
// Get response from engine
|
// Get response from engine
|
||||||
tokenCollectionJob = viewModelScope.launch {
|
tokenCollectionJob = viewModelScope.launch {
|
||||||
val response = StringBuilder()
|
val response = StringBuilder()
|
||||||
|
|
@ -129,12 +137,19 @@ class MainViewModel(
|
||||||
val currentMessages = _messages.value.toMutableList()
|
val currentMessages = _messages.value.toMutableList()
|
||||||
if (currentMessages.size >= 2) {
|
if (currentMessages.size >= 2) {
|
||||||
val messageIndex = currentMessages.size - 1
|
val messageIndex = currentMessages.size - 1
|
||||||
val currentAssistantMessage =
|
val currentAssistantMessage = currentMessages[messageIndex] as? Message.Assistant.Ongoing
|
||||||
currentMessages[messageIndex] as? Message.Assistant
|
|
||||||
if (currentAssistantMessage != null) {
|
if (currentAssistantMessage != null) {
|
||||||
currentMessages[messageIndex] = currentAssistantMessage.copy(
|
// Capture the metrics gathered before the failure (the error text itself goes into the message content)
|
||||||
|
val errorMetrics = TokenMetrics(
|
||||||
|
tokensCount = tokenCount,
|
||||||
|
ttftMs = if (firstTokenTime > 0) firstTokenTime - generationStartTime else 0L,
|
||||||
|
tpsMs = calculateTPS(tokenCount, System.currentTimeMillis() - generationStartTime)
|
||||||
|
)
|
||||||
|
|
||||||
|
currentMessages[messageIndex] = Message.Assistant.Completed(
|
||||||
content = "${response}[Error: ${e.message}]",
|
content = "${response}[Error: ${e.message}]",
|
||||||
isComplete = true
|
timestamp = currentAssistantMessage.timestamp,
|
||||||
|
metrics = errorMetrics
|
||||||
)
|
)
|
||||||
_messages.value = currentMessages
|
_messages.value = currentMessages
|
||||||
}
|
}
|
||||||
|
|
@ -145,29 +160,50 @@ class MainViewModel(
|
||||||
val currentMessages = _messages.value.toMutableList()
|
val currentMessages = _messages.value.toMutableList()
|
||||||
if (currentMessages.isNotEmpty()) {
|
if (currentMessages.isNotEmpty()) {
|
||||||
val messageIndex = currentMessages.size - 1
|
val messageIndex = currentMessages.size - 1
|
||||||
val currentAssistantMessage =
|
val currentAssistantMessage = currentMessages.getOrNull(messageIndex) as? Message.Assistant.Ongoing
|
||||||
currentMessages.getOrNull(messageIndex) as? Message.Assistant
|
|
||||||
if (currentAssistantMessage != null) {
|
if (currentAssistantMessage != null) {
|
||||||
currentMessages[messageIndex] = currentAssistantMessage.copy(
|
// Calculate final metrics
|
||||||
isComplete = true
|
val endTime = System.currentTimeMillis()
|
||||||
|
val totalTimeMs = endTime - generationStartTime
|
||||||
|
|
||||||
|
val metrics = TokenMetrics(
|
||||||
|
tokensCount = tokenCount,
|
||||||
|
ttftMs = if (firstTokenTime > 0) firstTokenTime - generationStartTime else 0L,
|
||||||
|
tpsMs = calculateTPS(tokenCount, totalTimeMs)
|
||||||
|
)
|
||||||
|
|
||||||
|
currentMessages[messageIndex] = Message.Assistant.Completed(
|
||||||
|
content = response.toString(),
|
||||||
|
timestamp = currentAssistantMessage.timestamp,
|
||||||
|
metrics = metrics
|
||||||
)
|
)
|
||||||
_messages.value = currentMessages
|
_messages.value = currentMessages
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
.collect { token ->
|
.collect { token ->
|
||||||
|
// Track first token time
|
||||||
|
if (isFirstToken && token.isNotBlank()) {
|
||||||
|
firstTokenTime = System.currentTimeMillis()
|
||||||
|
isFirstToken = false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Count tokens - each non-blank emission is counted as one token (whitespace-only emissions are skipped)
|
||||||
|
if (token.isNotBlank()) {
|
||||||
|
tokenCount++
|
||||||
|
}
|
||||||
|
|
||||||
response.append(token)
|
response.append(token)
|
||||||
|
|
||||||
// Safely update the assistant message with the generated text
|
// Safely update the assistant message with the generated text
|
||||||
val currentMessages = _messages.value.toMutableList()
|
val currentMessages = _messages.value.toMutableList()
|
||||||
if (currentMessages.isNotEmpty()) {
|
if (currentMessages.isNotEmpty()) {
|
||||||
val messageIndex = currentMessages.size - 1
|
val messageIndex = currentMessages.size - 1
|
||||||
val currentAssistantMessage =
|
val currentAssistantMessage = currentMessages.getOrNull(messageIndex) as? Message.Assistant.Ongoing
|
||||||
currentMessages.getOrNull(messageIndex) as? Message.Assistant
|
|
||||||
if (currentAssistantMessage != null) {
|
if (currentAssistantMessage != null) {
|
||||||
currentMessages[messageIndex] = currentAssistantMessage.copy(
|
currentMessages[messageIndex] = Message.Assistant.Ongoing(
|
||||||
content = response.toString(),
|
content = response.toString(),
|
||||||
isComplete = false
|
timestamp = currentAssistantMessage.timestamp
|
||||||
)
|
)
|
||||||
_messages.value = currentMessages
|
_messages.value = currentMessages
|
||||||
}
|
}
|
||||||
|
|
@ -178,12 +214,19 @@ class MainViewModel(
|
||||||
val currentMessages = _messages.value.toMutableList()
|
val currentMessages = _messages.value.toMutableList()
|
||||||
if (currentMessages.isNotEmpty()) {
|
if (currentMessages.isNotEmpty()) {
|
||||||
val messageIndex = currentMessages.size - 1
|
val messageIndex = currentMessages.size - 1
|
||||||
val currentAssistantMessage =
|
val currentAssistantMessage = currentMessages.getOrNull(messageIndex) as? Message.Assistant.Ongoing
|
||||||
currentMessages.getOrNull(messageIndex) as? Message.Assistant
|
|
||||||
if (currentAssistantMessage != null) {
|
if (currentAssistantMessage != null) {
|
||||||
currentMessages[messageIndex] = currentAssistantMessage.copy(
|
// Capture the metrics gathered before the failure (the error text itself goes into the message content)
|
||||||
|
val errorMetrics = TokenMetrics(
|
||||||
|
tokensCount = tokenCount,
|
||||||
|
ttftMs = if (firstTokenTime > 0) firstTokenTime - generationStartTime else 0L,
|
||||||
|
tpsMs = calculateTPS(tokenCount, System.currentTimeMillis() - generationStartTime)
|
||||||
|
)
|
||||||
|
|
||||||
|
currentMessages[messageIndex] = Message.Assistant.Completed(
|
||||||
content = "${response}[Error: ${e.message}]",
|
content = "${response}[Error: ${e.message}]",
|
||||||
isComplete = true
|
timestamp = currentAssistantMessage.timestamp,
|
||||||
|
metrics = errorMetrics
|
||||||
)
|
)
|
||||||
_messages.value = currentMessages
|
_messages.value = currentMessages
|
||||||
}
|
}
|
||||||
|
|
@ -192,6 +235,14 @@ class MainViewModel(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Calculates generation throughput in tokens per second.
*
* The token count is scaled by 1000 and divided by the elapsed time in
* milliseconds, which converts a per-millisecond rate into a per-second rate.
*
* @param tokens number of tokens produced so far
* @param timeMs elapsed time in milliseconds
* @return tokens per second, or 0 when either argument is zero or negative
*         (this guard also prevents a division by zero)
|
||||||
|
*/
|
||||||
|
private fun calculateTPS(tokens: Int, timeMs: Long): Float {
|
||||||
|
if (tokens <= 0 || timeMs <= 0) return 0f
|
||||||
|
return (tokens.toFloat() * 1000f) / timeMs
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Unloads the currently loaded model.
|
* Unloads the currently loaded model.
|
||||||
*/
|
*/
|
||||||
|
|
@ -206,6 +257,9 @@ class MainViewModel(
|
||||||
inferenceEngine.unloadModel()
|
inferenceEngine.unloadModel()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Checks if a model is currently being loaded.
|
||||||
|
*/
|
||||||
fun isModelLoading() =
|
fun isModelLoading() =
|
||||||
engineState.value.let {
|
engineState.value.let {
|
||||||
it is InferenceEngine.State.LoadingModel
|
it is InferenceEngine.State.LoadingModel
|
||||||
|
|
@ -213,7 +267,7 @@ class MainViewModel(
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Checks if a model is currently loaded.
|
* Checks if a model has already been loaded.
|
||||||
*/
|
*/
|
||||||
fun isModelLoaded() =
|
fun isModelLoaded() =
|
||||||
engineState.value.let {
|
engineState.value.let {
|
||||||
|
|
@ -247,24 +301,40 @@ class MainViewModel(
|
||||||
* Sealed class representing messages in a conversation.
|
* Sealed class representing messages in a conversation.
|
||||||
*/
|
*/
|
||||||
sealed class Message {
|
sealed class Message {
|
||||||
abstract val content: String
|
|
||||||
abstract val timestamp: Long
|
abstract val timestamp: Long
|
||||||
|
abstract val content: String
|
||||||
|
|
||||||
val formattedTime: String
|
val formattedTime: String
|
||||||
get() {
|
get() = datetimeFormatter.format(Date(timestamp))
|
||||||
val formatter = SimpleDateFormat("h:mm a", Locale.getDefault())
|
|
||||||
return formatter.format(Date(timestamp))
|
|
||||||
}
|
|
||||||
|
|
||||||
data class User(
|
data class User(
|
||||||
override val content: String,
|
override val timestamp: Long,
|
||||||
override val timestamp: Long
|
override val content: String
|
||||||
) : Message()
|
) : Message()
|
||||||
|
|
||||||
// TODO-han.yin: break down into ongoing & completed message subtypes
|
sealed class Assistant : Message() {
|
||||||
data class Assistant(
|
data class Ongoing(
|
||||||
override val content: String,
|
override val timestamp: Long,
|
||||||
override val timestamp: Long,
|
override val content: String,
|
||||||
val isComplete: Boolean = true
|
) : Assistant()
|
||||||
) : Message()
|
|
||||||
|
data class Completed(
|
||||||
|
override val timestamp: Long,
|
||||||
|
override val content: String,
|
||||||
|
val metrics: TokenMetrics
|
||||||
|
) : Assistant()
|
||||||
|
}
|
||||||
|
|
||||||
|
companion object {
|
||||||
|
private val datetimeFormatter by lazy { SimpleDateFormat("h:mm a", Locale.getDefault()) }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
* Token generation statistics attached to a completed assistant message.
*
* @property tokensCount number of tokens counted during generation
* @property ttftMs time to first token, in milliseconds (0 when no first token was recorded)
* @property tpsMs generation throughput. NOTE(review): despite the "Ms" suffix, the visible
*           callers pass the result of calculateTPS, which is tokens per second - consider renaming.
* @property text one-line summary of the three values, with TPS formatted to one decimal place
*/
data class TokenMetrics(
|
||||||
|
val tokensCount: Int,
|
||||||
|
val ttftMs: Long,
|
||||||
|
val tpsMs: Float,
|
||||||
|
) {
|
||||||
|
val text: String
|
||||||
|
get() = "Tokens: $tokensCount, TTFT: ${ttftMs}ms, TPS: ${"%.1f".format(tpsMs)}"
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue