UI: avoid duplicated calculation of token metrics
This commit is contained in:
parent
dd5b20d74d
commit
7968216235
|
|
@ -96,7 +96,8 @@ data class ModelLoadingMetrics(
|
||||||
*/
|
*/
|
||||||
data class GenerationUpdate(
|
data class GenerationUpdate(
|
||||||
val text: String,
|
val text: String,
|
||||||
val isComplete: Boolean
|
val isComplete: Boolean,
|
||||||
|
val metrics: TokenMetrics? = null
|
||||||
)
|
)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
@ -221,6 +222,8 @@ internal class InferenceServiceImpl @Inject internal constructor(
|
||||||
private var isFirstToken: Boolean = true
|
private var isFirstToken: Boolean = true
|
||||||
|
|
||||||
override fun generateResponse(prompt: String): Flow<GenerationUpdate> = flow {
|
override fun generateResponse(prompt: String): Flow<GenerationUpdate> = flow {
|
||||||
|
val response = StringBuilder()
|
||||||
|
|
||||||
try {
|
try {
|
||||||
// Reset metrics tracking
|
// Reset metrics tracking
|
||||||
generationStartTime = System.currentTimeMillis()
|
generationStartTime = System.currentTimeMillis()
|
||||||
|
|
@ -228,8 +231,6 @@ internal class InferenceServiceImpl @Inject internal constructor(
|
||||||
tokenCount = 0
|
tokenCount = 0
|
||||||
isFirstToken = true
|
isFirstToken = true
|
||||||
|
|
||||||
val response = StringBuilder()
|
|
||||||
|
|
||||||
inferenceEngine.sendUserPrompt(prompt)
|
inferenceEngine.sendUserPrompt(prompt)
|
||||||
.collect { token ->
|
.collect { token ->
|
||||||
// Track first token time
|
// Track first token time
|
||||||
|
|
@ -253,10 +254,11 @@ internal class InferenceServiceImpl @Inject internal constructor(
|
||||||
val metrics = createTokenMetrics()
|
val metrics = createTokenMetrics()
|
||||||
|
|
||||||
// Emit final response with completion flag
|
// Emit final response with completion flag
|
||||||
emit(GenerationUpdate(response.toString(), true))
|
emit(GenerationUpdate(response.toString(), true, metrics))
|
||||||
} catch (e: Exception) {
|
} catch (e: Exception) {
|
||||||
// Emit error
|
// Emit error
|
||||||
val metrics = createTokenMetrics()
|
val metrics = createTokenMetrics()
|
||||||
|
emit(GenerationUpdate(response.toString(), true, metrics))
|
||||||
throw e
|
throw e
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -301,15 +301,15 @@ private fun MessageBubble(message: Message) {
|
||||||
formattedTime = message.formattedTime,
|
formattedTime = message.formattedTime,
|
||||||
content = message.content,
|
content = message.content,
|
||||||
isThinking = message.content.isBlank(),
|
isThinking = message.content.isBlank(),
|
||||||
isComplete = false,
|
isGenerating = true,
|
||||||
metrics = null
|
metrics = null
|
||||||
)
|
)
|
||||||
|
|
||||||
is Message.Assistant.Completed -> AssistantMessageBubble(
|
is Message.Assistant.Stopped -> AssistantMessageBubble(
|
||||||
formattedTime = message.formattedTime,
|
formattedTime = message.formattedTime,
|
||||||
content = message.content,
|
content = message.content,
|
||||||
isThinking = false,
|
isThinking = false,
|
||||||
isComplete = true,
|
isGenerating = false,
|
||||||
metrics = message.metrics.text
|
metrics = message.metrics.text
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
@ -357,7 +357,7 @@ private fun AssistantMessageBubble(
|
||||||
formattedTime: String,
|
formattedTime: String,
|
||||||
content: String,
|
content: String,
|
||||||
isThinking: Boolean,
|
isThinking: Boolean,
|
||||||
isComplete: Boolean,
|
isGenerating: Boolean,
|
||||||
metrics: String? = null
|
metrics: String? = null
|
||||||
) {
|
) {
|
||||||
Row(
|
Row(
|
||||||
|
|
@ -417,7 +417,7 @@ private fun AssistantMessageBubble(
|
||||||
.padding(top = 4.dp),
|
.padding(top = 4.dp),
|
||||||
verticalAlignment = Alignment.CenterVertically
|
verticalAlignment = Alignment.CenterVertically
|
||||||
) {
|
) {
|
||||||
if (!isComplete) {
|
if (isGenerating) {
|
||||||
PulsatingDots(small = true)
|
PulsatingDots(small = true)
|
||||||
|
|
||||||
Spacer(modifier = Modifier.width(4.dp))
|
Spacer(modifier = Modifier.width(4.dp))
|
||||||
|
|
|
||||||
|
|
@ -14,6 +14,7 @@ import kotlinx.coroutines.flow.StateFlow
|
||||||
import kotlinx.coroutines.flow.asStateFlow
|
import kotlinx.coroutines.flow.asStateFlow
|
||||||
import kotlinx.coroutines.flow.onCompletion
|
import kotlinx.coroutines.flow.onCompletion
|
||||||
import kotlinx.coroutines.launch
|
import kotlinx.coroutines.launch
|
||||||
|
import okhttp3.internal.toImmutableList
|
||||||
import java.text.SimpleDateFormat
|
import java.text.SimpleDateFormat
|
||||||
import java.util.Date
|
import java.util.Date
|
||||||
import java.util.Locale
|
import java.util.Locale
|
||||||
|
|
@ -94,74 +95,55 @@ class ConversationViewModel @Inject constructor(
|
||||||
/**
|
/**
|
||||||
* Stop ongoing generation
|
* Stop ongoing generation
|
||||||
*/
|
*/
|
||||||
fun stopGeneration() {
|
fun stopGeneration() =
|
||||||
tokenCollectionJob?.let { job ->
|
tokenCollectionJob?.let { job ->
|
||||||
// handled by the catch blocks
|
// handled by the catch blocks
|
||||||
if (job.isActive) { job.cancel() }
|
if (job.isActive) { job.cancel() }
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Handle the case when generation is explicitly cancelled
|
* Handle the case when generation is explicitly cancelled by adding a stopping suffix
|
||||||
*/
|
*/
|
||||||
private fun handleCancellation() {
|
private fun handleCancellation() =
|
||||||
val currentMessages = _messages.value.toMutableList()
|
_messages.value.toMutableList().apply {
|
||||||
val lastIndex = currentMessages.size - 1
|
(removeLastOrNull() as? Message.Assistant.Stopped)?.let {
|
||||||
val currentAssistantMessage = currentMessages.getOrNull(lastIndex) as? Message.Assistant.Ongoing
|
add(it.copy(content = it.content + SUFFIX_GENERATION_STOPPED))
|
||||||
|
_messages.value = toImmutableList()
|
||||||
if (currentAssistantMessage != null) {
|
|
||||||
// Replace with completed message, adding note that it was interrupted
|
|
||||||
currentMessages[lastIndex] = Message.Assistant.Completed(
|
|
||||||
content = currentAssistantMessage.content + " [Generation stopped]",
|
|
||||||
timestamp = currentAssistantMessage.timestamp,
|
|
||||||
metrics = conversationService.createTokenMetrics()
|
|
||||||
)
|
|
||||||
_messages.value = currentMessages
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Handle response error
|
* Handle response error by appending an error suffix
|
||||||
*/
|
*/
|
||||||
private fun handleResponseError(e: Exception) {
|
private fun handleResponseError(e: Exception) =
|
||||||
val currentMessages = _messages.value.toMutableList()
|
_messages.value.toMutableList().apply {
|
||||||
val lastIndex = currentMessages.size - 1
|
(removeLastOrNull() as? Message.Assistant.Stopped)?.let {
|
||||||
val currentAssistantMessage = currentMessages.getOrNull(lastIndex) as? Message.Assistant.Ongoing
|
add(it.copy(content = it.content + SUFFIX_GENERATION_ERROR.format(e.message)))
|
||||||
|
_messages.value = toImmutableList()
|
||||||
if (currentAssistantMessage != null) {
|
|
||||||
currentMessages[lastIndex] = Message.Assistant.Completed(
|
|
||||||
content = currentAssistantMessage.content + " [Error: ${e.message}]",
|
|
||||||
timestamp = currentAssistantMessage.timestamp,
|
|
||||||
metrics = conversationService.createTokenMetrics()
|
|
||||||
)
|
|
||||||
_messages.value = currentMessages
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Handle updating the assistant message
|
* Handle updating the assistant message
|
||||||
*/
|
*/
|
||||||
private fun updateAssistantMessage(update: GenerationUpdate) {
|
private fun updateAssistantMessage(update: GenerationUpdate) =
|
||||||
val currentMessages = _messages.value.toMutableList()
|
_messages.value.toMutableList().apply {
|
||||||
val lastIndex = currentMessages.size - 1
|
(removeLastOrNull() as? Message.Assistant.Ongoing)?.let {
|
||||||
val currentAssistantMessage = currentMessages.getOrNull(lastIndex) as? Message.Assistant.Ongoing
|
if (update.metrics != null) {
|
||||||
|
// Finalized message (partial or complete) with metrics
|
||||||
if (currentAssistantMessage != null) {
|
add(Message.Assistant.Stopped(
|
||||||
if (update.isComplete) {
|
|
||||||
// Final message with metrics
|
|
||||||
currentMessages[lastIndex] = Message.Assistant.Completed(
|
|
||||||
content = update.text,
|
content = update.text,
|
||||||
timestamp = currentAssistantMessage.timestamp,
|
timestamp = it.timestamp,
|
||||||
metrics = conversationService.createTokenMetrics()
|
metrics = update.metrics
|
||||||
)
|
))
|
||||||
} else {
|
} else if (!update.isComplete) {
|
||||||
// Ongoing message update
|
// Ongoing message update
|
||||||
currentMessages[lastIndex] = Message.Assistant.Ongoing(
|
add(Message.Assistant.Ongoing(
|
||||||
content = update.text,
|
content = update.text,
|
||||||
timestamp = currentAssistantMessage.timestamp
|
timestamp = it.timestamp
|
||||||
)
|
))
|
||||||
}
|
}
|
||||||
_messages.value = currentMessages
|
_messages.value = toImmutableList()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -179,6 +161,11 @@ class ConversationViewModel @Inject constructor(
|
||||||
stopGeneration()
|
stopGeneration()
|
||||||
super.onCleared()
|
super.onCleared()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
companion object {
|
||||||
|
private const val SUFFIX_GENERATION_STOPPED = " [Generation stopped]"
|
||||||
|
private const val SUFFIX_GENERATION_ERROR = " [Error: %s]"
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -203,7 +190,7 @@ sealed class Message {
|
||||||
override val content: String,
|
override val content: String,
|
||||||
) : Assistant()
|
) : Assistant()
|
||||||
|
|
||||||
data class Completed(
|
data class Stopped(
|
||||||
override val timestamp: Long,
|
override val timestamp: Long,
|
||||||
override val content: String,
|
override val content: String,
|
||||||
val metrics: TokenMetrics
|
val metrics: TokenMetrics
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue