vm: split monolithic MainViewModel into separate per-feature ViewModels

This commit is contained in:
Han Yin 2025-04-15 13:48:01 -07:00
parent 32d778bb8e
commit 23d411d86e
10 changed files with 351 additions and 349 deletions

View File

@ -170,7 +170,6 @@ fun AppContent(
composable(AppDestinations.MODEL_SELECTION_ROUTE) { composable(AppDestinations.MODEL_SELECTION_ROUTE) {
ModelSelectionScreen( ModelSelectionScreen(
onModelSelected = { modelInfo -> onModelSelected = { modelInfo ->
mainVewModel.selectModel(modelInfo)
navigationActions.navigateToModelLoading() navigationActions.navigateToModelLoading()
}, },
onManageModelsClicked = { onManageModelsClicked = {
@ -184,34 +183,28 @@ fun AppContent(
composable(AppDestinations.MODEL_LOADING_ROUTE) { composable(AppDestinations.MODEL_LOADING_ROUTE) {
ModelLoadingScreen( ModelLoadingScreen(
engineState = engineState, engineState = engineState,
onBenchmarkSelected = { onBenchmarkSelected = { prepareJob ->
// Store a reference to the loading job // Wait for preparation to complete, then navigate if still active
val loadingJob = coroutineScope.launch { val loadingJob = coroutineScope.launch {
mainVewModel.prepareForBenchmark() prepareJob.join()
// Check if the job wasn't cancelled before navigating if (isActive) { navigationActions.navigateToBenchmark() }
if (isActive) {
navigationActions.navigateToBenchmark()
}
} }
// Update the pendingNavigation handler to cancel any ongoing loading
pendingNavigation = { pendingNavigation = {
prepareJob.cancel()
loadingJob.cancel() loadingJob.cancel()
navigationActions.navigateUp() navigationActions.navigateUp()
} }
}, },
onConversationSelected = { systemPrompt -> onConversationSelected = { systemPrompt, prepareJob ->
// Store a reference to the loading job // Wait for preparation to complete, then navigate if still active
val loadingJob = coroutineScope.launch { val loadingJob = coroutineScope.launch {
mainVewModel.prepareForConversation(systemPrompt) prepareJob.join()
// Check if the job wasn't cancelled before navigating if (isActive) { navigationActions.navigateToConversation() }
if (isActive) {
navigationActions.navigateToConversation()
} }
}
// Update the pendingNavigation handler to cancel any ongoing loading
pendingNavigation = { pendingNavigation = {
prepareJob.cancel()
loadingJob.cancel() loadingJob.cancel()
navigationActions.navigateUp() navigationActions.navigateUp()
} }
@ -229,8 +222,7 @@ fun AppContent(
onBackPressed = { onBackPressed = {
// Need to unload model before going back // Need to unload model before going back
handleBackWithModelCheck() handleBackWithModelCheck()
}, }
viewModel = mainVewModel
) )
} }
@ -240,14 +232,7 @@ fun AppContent(
onBackPressed = { onBackPressed = {
// Need to unload model before going back // Need to unload model before going back
handleBackWithModelCheck() handleBackWithModelCheck()
}, }
onRerunPressed = {
mainVewModel.rerunBenchmark()
},
onSharePressed = {
// Stub for sharing functionality
},
viewModel = mainVewModel
) )
} }

View File

@ -26,14 +26,13 @@ import androidx.hilt.navigation.compose.hiltViewModel
import com.example.llama.revamp.engine.InferenceEngine import com.example.llama.revamp.engine.InferenceEngine
import com.example.llama.revamp.ui.components.PerformanceAppScaffold import com.example.llama.revamp.ui.components.PerformanceAppScaffold
import com.example.llama.revamp.ui.theme.MonospacedTextStyle import com.example.llama.revamp.ui.theme.MonospacedTextStyle
import com.example.llama.revamp.viewmodel.BenchmarkViewModel
import com.example.llama.revamp.viewmodel.MainViewModel import com.example.llama.revamp.viewmodel.MainViewModel
@Composable @Composable
fun BenchmarkScreen( fun BenchmarkScreen(
onBackPressed: () -> Unit, onBackPressed: () -> Unit,
onRerunPressed: () -> Unit, viewModel: BenchmarkViewModel = hiltViewModel()
onSharePressed: () -> Unit,
viewModel: MainViewModel = hiltViewModel()
) { ) {
val engineState by viewModel.engineState.collectAsState() val engineState by viewModel.engineState.collectAsState()
val benchmarkResults by viewModel.benchmarkResults.collectAsState() val benchmarkResults by viewModel.benchmarkResults.collectAsState()

View File

@ -62,7 +62,7 @@ import androidx.lifecycle.LifecycleEventObserver
import androidx.lifecycle.compose.LocalLifecycleOwner import androidx.lifecycle.compose.LocalLifecycleOwner
import com.example.llama.revamp.engine.InferenceEngine import com.example.llama.revamp.engine.InferenceEngine
import com.example.llama.revamp.ui.components.PerformanceAppScaffold import com.example.llama.revamp.ui.components.PerformanceAppScaffold
import com.example.llama.revamp.viewmodel.MainViewModel import com.example.llama.revamp.viewmodel.ConversationViewModel
import com.example.llama.revamp.viewmodel.Message import com.example.llama.revamp.viewmodel.Message
import kotlinx.coroutines.launch import kotlinx.coroutines.launch
@ -72,7 +72,7 @@ import kotlinx.coroutines.launch
@Composable @Composable
fun ConversationScreen( fun ConversationScreen(
onBackPressed: () -> Unit, onBackPressed: () -> Unit,
viewModel: MainViewModel = hiltViewModel() viewModel: ConversationViewModel = hiltViewModel()
) { ) {
val engineState by viewModel.engineState.collectAsState() val engineState by viewModel.engineState.collectAsState()
val messages by viewModel.messages.collectAsState() val messages by viewModel.messages.collectAsState()

View File

@ -40,6 +40,7 @@ import androidx.compose.runtime.collectAsState
import androidx.compose.runtime.getValue import androidx.compose.runtime.getValue
import androidx.compose.runtime.mutableStateOf import androidx.compose.runtime.mutableStateOf
import androidx.compose.runtime.remember import androidx.compose.runtime.remember
import androidx.compose.runtime.rememberCoroutineScope
import androidx.compose.runtime.setValue import androidx.compose.runtime.setValue
import androidx.compose.ui.Alignment import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier import androidx.compose.ui.Modifier
@ -50,7 +51,10 @@ import androidx.hilt.navigation.compose.hiltViewModel
import com.example.llama.revamp.data.model.SystemPrompt import com.example.llama.revamp.data.model.SystemPrompt
import com.example.llama.revamp.engine.InferenceEngine import com.example.llama.revamp.engine.InferenceEngine
import com.example.llama.revamp.ui.components.PerformanceAppScaffold import com.example.llama.revamp.ui.components.PerformanceAppScaffold
import com.example.llama.revamp.viewmodel.ModelLoadingViewModel
import com.example.llama.revamp.viewmodel.SystemPromptViewModel import com.example.llama.revamp.viewmodel.SystemPromptViewModel
import kotlinx.coroutines.Job
import kotlinx.coroutines.launch
enum class SystemPromptTab { enum class SystemPromptTab {
PRESETS, CUSTOM, RECENTS PRESETS, CUSTOM, RECENTS
@ -59,14 +63,17 @@ enum class SystemPromptTab {
@OptIn(ExperimentalMaterial3Api::class, ExperimentalFoundationApi::class) @OptIn(ExperimentalMaterial3Api::class, ExperimentalFoundationApi::class)
@Composable @Composable
fun ModelLoadingScreen( fun ModelLoadingScreen(
viewModel: SystemPromptViewModel = hiltViewModel(),
engineState: InferenceEngine.State, engineState: InferenceEngine.State,
onBenchmarkSelected: () -> Unit, onBenchmarkSelected: (prepareJob: Job) -> Unit,
onConversationSelected: (String?) -> Unit, onConversationSelected: (systemPrompt: String?, prepareJob: Job) -> Unit,
onBackPressed: () -> Unit, onBackPressed: () -> Unit,
modelLoadingViewModel: ModelLoadingViewModel = hiltViewModel(),
systemPromptViewModel: SystemPromptViewModel = hiltViewModel(),
) { ) {
val presetPrompts by viewModel.presetPrompts.collectAsState() val coroutineScope = rememberCoroutineScope()
val recentPrompts by viewModel.recentPrompts.collectAsState()
val presetPrompts by systemPromptViewModel.presetPrompts.collectAsState()
val recentPrompts by systemPromptViewModel.recentPrompts.collectAsState()
var selectedMode by remember { mutableStateOf<Mode?>(null) } var selectedMode by remember { mutableStateOf<Mode?>(null) }
var useSystemPrompt by remember { mutableStateOf(false) } var useSystemPrompt by remember { mutableStateOf(false) }
@ -96,6 +103,21 @@ fun ModelLoadingScreen(
engineState !is InferenceEngine.State.LibraryLoaded && engineState !is InferenceEngine.State.LibraryLoaded &&
engineState !is InferenceEngine.State.AwaitingUserPrompt engineState !is InferenceEngine.State.AwaitingUserPrompt
// Mode selection callbacks
val handleBenchmarkSelected = {
val prepareJob = coroutineScope.launch {
modelLoadingViewModel.prepareForBenchmark()
}
onBenchmarkSelected(prepareJob)
}
val handleConversationSelected = { systemPrompt: String? ->
val prepareJob = coroutineScope.launch {
modelLoadingViewModel.prepareForConversation(systemPrompt)
}
onConversationSelected(systemPrompt, prepareJob)
}
PerformanceAppScaffold( PerformanceAppScaffold(
title = "Load Model", title = "Load Model",
onNavigateBack = onBackPressed, onNavigateBack = onBackPressed,
@ -143,7 +165,7 @@ fun ModelLoadingScreen(
modifier = Modifier modifier = Modifier
.fillMaxWidth() .fillMaxWidth()
.padding(bottom = 4.dp) .padding(bottom = 4.dp)
// Only use weight if system prompt is active, otherwise wrap content // Only fill height if system prompt is active
.then(if (useSystemPrompt) Modifier.weight(1f) else Modifier) .then(if (useSystemPrompt) Modifier.weight(1f) else Modifier)
) { ) {
Column( Column(
@ -355,14 +377,15 @@ fun ModelLoadingScreen(
Button( Button(
onClick = { onClick = {
when (selectedMode) { when (selectedMode) {
Mode.BENCHMARK -> onBenchmarkSelected() Mode.BENCHMARK -> handleBenchmarkSelected()
Mode.CONVERSATION -> { Mode.CONVERSATION -> {
val systemPrompt = if (useSystemPrompt) { val systemPrompt = if (useSystemPrompt) {
when (selectedTab) { when (selectedTab) {
SystemPromptTab.PRESETS, SystemPromptTab.RECENTS -> SystemPromptTab.PRESETS, SystemPromptTab.RECENTS ->
selectedPrompt?.let { prompt -> selectedPrompt?.let { prompt ->
// Save the prompt to recent prompts database // Save the prompt to recent prompts database
viewModel.savePromptToRecents(prompt) systemPromptViewModel.savePromptToRecents(prompt)
prompt.content prompt.content
} }
@ -370,15 +393,15 @@ fun ModelLoadingScreen(
customPromptText.takeIf { it.isNotBlank() } customPromptText.takeIf { it.isNotBlank() }
?.also { promptText -> ?.also { promptText ->
// Save custom prompt to database // Save custom prompt to database
viewModel.saveCustomPromptToRecents(promptText) systemPromptViewModel.saveCustomPromptToRecents(promptText)
} }
} }
} else null } else null
onConversationSelected(systemPrompt)
handleConversationSelected(systemPrompt)
} }
null -> { /* No mode selected */ null -> { /* No mode selected */ }
}
} }
}, },
modifier = Modifier modifier = Modifier

View File

@ -30,7 +30,7 @@ import com.example.llama.revamp.data.model.ModelInfo
import com.example.llama.revamp.ui.components.ModelCard import com.example.llama.revamp.ui.components.ModelCard
import com.example.llama.revamp.ui.components.ModelCardActions import com.example.llama.revamp.ui.components.ModelCardActions
import com.example.llama.revamp.ui.components.PerformanceAppScaffold import com.example.llama.revamp.ui.components.PerformanceAppScaffold
import com.example.llama.revamp.viewmodel.MainViewModel import com.example.llama.revamp.viewmodel.ModelSelectionViewModel
@OptIn(ExperimentalMaterial3Api::class) @OptIn(ExperimentalMaterial3Api::class)
@Composable @Composable
@ -38,10 +38,15 @@ fun ModelSelectionScreen(
onModelSelected: (ModelInfo) -> Unit, onModelSelected: (ModelInfo) -> Unit,
onManageModelsClicked: () -> Unit, onManageModelsClicked: () -> Unit,
onMenuClicked: () -> Unit, onMenuClicked: () -> Unit,
viewModel: MainViewModel = hiltViewModel(), viewModel: ModelSelectionViewModel = hiltViewModel(),
) { ) {
val models by viewModel.availableModels.collectAsState() val models by viewModel.availableModels.collectAsState()
val handleModelSelection = { model: ModelInfo ->
viewModel.selectModel(model)
onModelSelected(model)
}
PerformanceAppScaffold( PerformanceAppScaffold(
title = "Models", title = "Models",
onMenuOpen = onMenuClicked, onMenuOpen = onMenuClicked,
@ -60,11 +65,13 @@ fun ModelSelectionScreen(
items(models) { model -> items(models) { model ->
ModelCard( ModelCard(
model = model, model = model,
onClick = { onModelSelected(model) }, onClick = { handleModelSelection(model) },
modifier = Modifier.padding(vertical = 4.dp), modifier = Modifier.padding(vertical = 4.dp),
isSelected = null, // Not in selection mode isSelected = null, // Not in selection mode
actionButton = { actionButton = {
ModelCardActions.PlayButton(onClick = { onModelSelected(model) }) ModelCardActions.PlayButton {
handleModelSelection(model)
}
} }
) )
Spacer(modifier = Modifier.height(8.dp)) Spacer(modifier = Modifier.height(8.dp))

View File

@ -0,0 +1,29 @@
package com.example.llama.revamp.viewmodel
import androidx.lifecycle.ViewModel
import androidx.lifecycle.viewModelScope
import com.example.llama.revamp.data.model.ModelInfo
import com.example.llama.revamp.engine.InferenceEngine
import com.example.llama.revamp.engine.InferenceManager
import dagger.hilt.android.lifecycle.HiltViewModel
import kotlinx.coroutines.flow.StateFlow
import kotlinx.coroutines.launch
import javax.inject.Inject
@HiltViewModel
class BenchmarkViewModel @Inject constructor(
    private val inferenceManager: InferenceManager
) : ViewModel() {

    /** Engine state forwarded verbatim from [InferenceManager]. */
    val engineState: StateFlow<InferenceEngine.State> = inferenceManager.engineState

    /** Latest benchmark output forwarded from [InferenceManager]; null until a run completes. */
    val benchmarkResults: StateFlow<String?> = inferenceManager.benchmarkResults

    /** Model currently selected for benchmarking, forwarded from [InferenceManager]. */
    val selectedModel: StateFlow<ModelInfo?> = inferenceManager.currentModel

    /**
     * Launches a benchmark run in the ViewModel scope and returns the job.
     *
     * All parameters are passed straight through to [InferenceManager.benchmark];
     * the defaults mirror the values previously hard-coded in MainViewModel.
     * (Presumably pp = prompt-processing tokens, tg = generated tokens,
     * pl = parallel sequences, nr = repetitions — confirm against the engine.)
     */
    fun runBenchmark(pp: Int = 512, tg: Int = 128, pl: Int = 1, nr: Int = 3) =
        viewModelScope.launch { inferenceManager.benchmark(pp, tg, pl, nr) }
}

View File

@ -0,0 +1,166 @@
package com.example.llama.revamp.viewmodel
import androidx.lifecycle.ViewModel
import androidx.lifecycle.viewModelScope
import com.example.llama.revamp.engine.InferenceManager
import com.example.llama.revamp.engine.TokenMetrics
import dagger.hilt.android.lifecycle.HiltViewModel
import kotlinx.coroutines.CancellationException
import kotlinx.coroutines.Job
import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.flow.StateFlow
import kotlinx.coroutines.flow.asStateFlow
import kotlinx.coroutines.launch
import java.text.SimpleDateFormat
import java.util.Date
import java.util.Locale
import javax.inject.Inject
import kotlin.getValue
@HiltViewModel
class ConversationViewModel @Inject constructor(
    private val inferenceManager: InferenceManager
) : ViewModel() {

    /** Engine state forwarded verbatim from [InferenceManager]. */
    val engineState = inferenceManager.engineState

    /** Model currently loaded, forwarded from [InferenceManager]. */
    val selectedModel = inferenceManager.currentModel

    /** System prompt in effect, forwarded from [InferenceManager]. */
    val systemPrompt = inferenceManager.systemPrompt

    // Backing state for the conversation transcript.
    private val _messages = MutableStateFlow<List<Message>>(emptyList())
    val messages: StateFlow<List<Message>> = _messages.asStateFlow()

    // Job collecting streamed tokens for the in-flight assistant response.
    private var tokenCollectionJob: Job? = null

    /**
     * Send a message with the provided content and start collecting the
     * assistant's streamed response.
     *
     * Blank content is ignored. Any collection still in flight is cancelled
     * first, then a user message plus an [Message.Assistant.Ongoing]
     * placeholder are appended to [messages].
     *
     * Note: this matches the existing UI, which manages input state outside
     * the ViewModel.
     */
    fun sendMessage(content: String) {
        if (content.isBlank()) return

        // Cancel ongoing collection
        tokenCollectionJob?.cancel()

        // Add user message
        val userMessage = Message.User(
            content = content,
            timestamp = System.currentTimeMillis()
        )
        _messages.value = _messages.value + userMessage

        // Add placeholder for assistant response
        val assistantMessage = Message.Assistant.Ongoing(
            content = "",
            timestamp = System.currentTimeMillis()
        )
        _messages.value = _messages.value + assistantMessage

        // Collect response
        tokenCollectionJob = viewModelScope.launch {
            try {
                inferenceManager.generateResponse(content)
                    .collect { (text, isComplete) ->
                        updateAssistantMessage(text, isComplete)
                    }
            } catch (e: CancellationException) {
                // Fix: never swallow cancellation — rethrow so structured
                // concurrency can propagate it. The previous catch(Exception)
                // also caught CancellationException and rendered it as an
                // error message in the transcript.
                throw e
            } catch (e: Exception) {
                // Handle error
                handleResponseError(e)
            }
        }
    }

    /**
     * Replace the trailing [Message.Assistant.Ongoing] entry with updated text;
     * on completion, finalize it as [Message.Assistant.Completed] with metrics.
     * No-op if the last message is not an ongoing assistant message.
     */
    private fun updateAssistantMessage(text: String, isComplete: Boolean) {
        val currentMessages = _messages.value.toMutableList()
        val lastIndex = currentMessages.size - 1
        val currentAssistantMessage = currentMessages.getOrNull(lastIndex) as? Message.Assistant.Ongoing
        if (currentAssistantMessage != null) {
            if (isComplete) {
                // Final message with metrics
                currentMessages[lastIndex] = Message.Assistant.Completed(
                    content = text,
                    timestamp = currentAssistantMessage.timestamp,
                    metrics = inferenceManager.createTokenMetrics()
                )
            } else {
                // Ongoing message update
                currentMessages[lastIndex] = Message.Assistant.Ongoing(
                    content = text,
                    timestamp = currentAssistantMessage.timestamp
                )
            }
            _messages.value = currentMessages
        }
    }

    /**
     * Finalize the trailing ongoing assistant message with the error appended
     * to whatever content had been streamed so far.
     */
    private fun handleResponseError(e: Exception) {
        val currentMessages = _messages.value.toMutableList()
        val lastIndex = currentMessages.size - 1
        val currentAssistantMessage = currentMessages.getOrNull(lastIndex) as? Message.Assistant.Ongoing
        if (currentAssistantMessage != null) {
            currentMessages[lastIndex] = Message.Assistant.Completed(
                content = "${currentAssistantMessage.content}[Error: ${e.message}]",
                timestamp = currentAssistantMessage.timestamp,
                metrics = inferenceManager.createTokenMetrics()
            )
            _messages.value = currentMessages
        }
    }

    /**
     * Cancel any in-flight response collection and clear the transcript.
     */
    fun clearConversation() {
        tokenCollectionJob?.cancel()
        _messages.value = emptyList()
    }

    override fun onCleared() {
        // Stop token collection when the ViewModel is destroyed.
        tokenCollectionJob?.cancel()
        super.onCleared()
    }
}
/**
 * Sealed hierarchy of conversation messages: either a [User] message or an
 * [Assistant] response, which is still streaming ([Assistant.Ongoing]) or
 * finished with generation metrics attached ([Assistant.Completed]).
 */
sealed class Message {
    abstract val timestamp: Long
    abstract val content: String

    /** [timestamp] rendered as a short clock time in the device default locale. */
    val formattedTime: String
        get() = timeFormatter.format(Date(timestamp))

    /** A message authored by the user. */
    data class User(
        override val timestamp: Long,
        override val content: String
    ) : Message()

    /** A message authored by the model. */
    sealed class Assistant : Message() {

        /** Response text still being streamed. */
        data class Ongoing(
            override val timestamp: Long,
            override val content: String,
        ) : Assistant()

        /** Finished response carrying its generation [metrics]. */
        data class Completed(
            override val timestamp: Long,
            override val content: String,
            val metrics: TokenMetrics
        ) : Assistant()
    }

    companion object {
        // Shared by all messages; created lazily on first format call.
        private val timeFormatter by lazy { SimpleDateFormat("h:mm a", Locale.getDefault()) }
    }
}

View File

@ -1,276 +1,25 @@
package com.example.llama.revamp.viewmodel package com.example.llama.revamp.viewmodel
import androidx.lifecycle.ViewModel import androidx.lifecycle.ViewModel
import androidx.lifecycle.viewModelScope
import com.example.llama.revamp.data.model.ModelInfo
import com.example.llama.revamp.data.repository.ModelRepository
import com.example.llama.revamp.engine.InferenceEngine import com.example.llama.revamp.engine.InferenceEngine
import com.example.llama.revamp.engine.InferenceManager
import dagger.hilt.android.lifecycle.HiltViewModel import dagger.hilt.android.lifecycle.HiltViewModel
import kotlinx.coroutines.Job
import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.flow.SharingStarted
import kotlinx.coroutines.flow.StateFlow
import kotlinx.coroutines.flow.asStateFlow
import kotlinx.coroutines.flow.catch
import kotlinx.coroutines.flow.onCompletion
import kotlinx.coroutines.flow.stateIn
import kotlinx.coroutines.launch
import java.text.SimpleDateFormat
import java.util.Date
import java.util.Locale
import javax.inject.Inject import javax.inject.Inject
/**
* Main ViewModel that handles the LLM engine state and operations.
*/
@HiltViewModel @HiltViewModel
/**
* Main ViewModel that expose the core states of [InferenceEngine]
*/
class MainViewModel @Inject constructor ( class MainViewModel @Inject constructor (
private val inferenceEngine: InferenceEngine, private val inferenceManager: InferenceManager,
private val modelRepository: ModelRepository,
) : ViewModel() { ) : ViewModel() {
// Expose the engine state val engineState = inferenceManager.engineState
val engineState: StateFlow<InferenceEngine.State> = inferenceEngine.state
// Benchmark results
val benchmarkResults: StateFlow<String?> = inferenceEngine.benchmarkResults
// Available models for selection
val availableModels: StateFlow<List<ModelInfo>> = modelRepository.getModels()
.stateIn(
scope = viewModelScope,
started = SharingStarted.WhileSubscribed(SUBSCRIPTION_TIMEOUT_MS),
initialValue = emptyList()
)
// Selected model information
private val _selectedModel = MutableStateFlow<ModelInfo?>(null)
val selectedModel: StateFlow<ModelInfo?> = _selectedModel.asStateFlow()
// Messages in the conversation
private val _messages = MutableStateFlow<List<Message>>(emptyList())
val messages: StateFlow<List<Message>> = _messages.asStateFlow()
// System prompt for the conversation
private val _systemPrompt = MutableStateFlow<String?>(null)
val systemPrompt: StateFlow<String?> = _systemPrompt.asStateFlow()
// Flag to track if token collection is active
private var tokenCollectionJob: Job? = null
/** /**
* Selects a model for use. * Unload the current model and release the resources
*/ */
fun selectModel(modelInfo: ModelInfo) { suspend fun unloadModel() = inferenceManager.unloadModel()
_selectedModel.value = modelInfo
viewModelScope.launch {
modelRepository.updateModelLastUsed(modelInfo.id)
}
}
/**
* Prepares the engine for benchmark mode.
*/
suspend fun prepareForBenchmark() {
_selectedModel.value?.let { model ->
inferenceEngine.loadModel(model.path)
}
}
/**
* Runs the benchmark with current parameters.
*/
suspend fun runBenchmark() = inferenceEngine.bench(512, 128, 1, 3)
/**
* Reruns the benchmark.
*/
fun rerunBenchmark() = viewModelScope.launch { runBenchmark() }
/**
* Prepares the engine for conversation mode.
*/
suspend fun prepareForConversation(systemPrompt: String? = null) {
_systemPrompt.value = systemPrompt
_selectedModel.value?.let { model ->
inferenceEngine.loadModel(model.path, systemPrompt)
}
}
/**
* Tracks token generation metrics
*/
private var generationStartTime: Long = 0L
private var firstTokenTime: Long = 0L
private var tokenCount: Int = 0
private var isFirstToken: Boolean = true
/**
* Sends a user message and collects the response.
*/
fun sendMessage(content: String) {
if (content.isBlank()) return
// Cancel any ongoing token collection
tokenCollectionJob?.cancel()
// Add user message
val userMessage = Message.User(
content = content,
timestamp = System.currentTimeMillis()
)
_messages.value = _messages.value + userMessage
// Create placeholder for assistant message
val assistantMessage = Message.Assistant.Ongoing(
content = "",
timestamp = System.currentTimeMillis()
)
_messages.value = _messages.value + assistantMessage
// Reset metrics tracking
generationStartTime = System.currentTimeMillis()
firstTokenTime = 0L
tokenCount = 0
isFirstToken = true
// Get response from engine
tokenCollectionJob = viewModelScope.launch {
val response = StringBuilder()
try {
inferenceEngine.sendUserPrompt(content)
.catch { e ->
// Handle errors during token collection
val currentMessages = _messages.value.toMutableList()
if (currentMessages.size >= 2) {
val messageIndex = currentMessages.size - 1
val currentAssistantMessage = currentMessages[messageIndex] as? Message.Assistant.Ongoing
if (currentAssistantMessage != null) {
// Create metrics with error indication
val errorMetrics = TokenMetrics(
tokensCount = tokenCount,
ttftMs = if (firstTokenTime > 0) firstTokenTime - generationStartTime else 0L,
tpsMs = calculateTPS(tokenCount, System.currentTimeMillis() - generationStartTime)
)
currentMessages[messageIndex] = Message.Assistant.Completed(
content = "${response}[Error: ${e.message}]",
timestamp = currentAssistantMessage.timestamp,
metrics = errorMetrics
)
_messages.value = currentMessages
}
}
}
.onCompletion { cause ->
// Handle completion (normal or cancelled)
val currentMessages = _messages.value.toMutableList()
if (currentMessages.isNotEmpty()) {
val messageIndex = currentMessages.size - 1
val currentAssistantMessage = currentMessages.getOrNull(messageIndex) as? Message.Assistant.Ongoing
if (currentAssistantMessage != null) {
// Calculate final metrics
val endTime = System.currentTimeMillis()
val totalTimeMs = endTime - generationStartTime
val metrics = TokenMetrics(
tokensCount = tokenCount,
ttftMs = if (firstTokenTime > 0) firstTokenTime - generationStartTime else 0L,
tpsMs = calculateTPS(tokenCount, totalTimeMs)
)
currentMessages[messageIndex] = Message.Assistant.Completed(
content = response.toString(),
timestamp = currentAssistantMessage.timestamp,
metrics = metrics
)
_messages.value = currentMessages
}
}
}
.collect { token ->
// Track first token time
if (isFirstToken && token.isNotBlank()) {
firstTokenTime = System.currentTimeMillis()
isFirstToken = false
}
// Count tokens - each non-empty emission is at least one token
if (token.isNotBlank()) {
tokenCount++
}
response.append(token)
// Safely update the assistant message with the generated text
val currentMessages = _messages.value.toMutableList()
if (currentMessages.isNotEmpty()) {
val messageIndex = currentMessages.size - 1
val currentAssistantMessage = currentMessages.getOrNull(messageIndex) as? Message.Assistant.Ongoing
if (currentAssistantMessage != null) {
currentMessages[messageIndex] = Message.Assistant.Ongoing(
content = response.toString(),
timestamp = currentAssistantMessage.timestamp
)
_messages.value = currentMessages
}
}
}
} catch (e: Exception) {
// Handle any unexpected exceptions
val currentMessages = _messages.value.toMutableList()
if (currentMessages.isNotEmpty()) {
val messageIndex = currentMessages.size - 1
val currentAssistantMessage = currentMessages.getOrNull(messageIndex) as? Message.Assistant.Ongoing
if (currentAssistantMessage != null) {
// Create metrics with error indication
val errorMetrics = TokenMetrics(
tokensCount = tokenCount,
ttftMs = if (firstTokenTime > 0) firstTokenTime - generationStartTime else 0L,
tpsMs = calculateTPS(tokenCount, System.currentTimeMillis() - generationStartTime)
)
currentMessages[messageIndex] = Message.Assistant.Completed(
content = "${response}[Error: ${e.message}]",
timestamp = currentAssistantMessage.timestamp,
metrics = errorMetrics
)
_messages.value = currentMessages
}
}
}
}
}
/**
* Calculate tokens per second.
*/
private fun calculateTPS(tokens: Int, timeMs: Long): Float {
if (tokens <= 0 || timeMs <= 0) return 0f
return (tokens.toFloat() * 1000f) / timeMs
}
/**
* Unloads the currently loaded model after cleanup chores:
* - Cancel any ongoing token collection
* - Clear messages
*/
suspend fun unloadModel() {
tokenCollectionJob?.cancel()
_messages.value = emptyList()
inferenceEngine.unloadModel()
}
/**
* Clean up resources when ViewModel is cleared.
*/
override fun onCleared() {
inferenceEngine.destroy()
super.onCleared()
}
companion object { companion object {
private val TAG = MainViewModel::class.java.simpleName private val TAG = MainViewModel::class.java.simpleName
@ -279,44 +28,3 @@ class MainViewModel @Inject constructor (
} }
} }
/**
* Sealed class representing messages in a conversation.
*/
sealed class Message {
abstract val timestamp: Long
abstract val content: String
val formattedTime: String
get() = datetimeFormatter.format(Date(timestamp))
data class User(
override val timestamp: Long,
override val content: String
) : Message()
sealed class Assistant : Message() {
data class Ongoing(
override val timestamp: Long,
override val content: String,
) : Assistant()
data class Completed(
override val timestamp: Long,
override val content: String,
val metrics: TokenMetrics
) : Assistant()
}
companion object {
private val datetimeFormatter by lazy { SimpleDateFormat("h:mm a", Locale.getDefault()) }
}
}
data class TokenMetrics(
val tokensCount: Int,
val ttftMs: Long,
val tpsMs: Float,
) {
val text: String
get() = "Tokens: $tokensCount, TTFT: ${ttftMs}ms, TPS: ${"%.1f".format(tpsMs)}"
}

View File

@ -0,0 +1,27 @@
package com.example.llama.revamp.viewmodel
import androidx.lifecycle.ViewModel
import com.example.llama.revamp.engine.InferenceManager
import dagger.hilt.android.lifecycle.HiltViewModel
import javax.inject.Inject
@HiltViewModel
class ModelLoadingViewModel @Inject constructor(
    private val inferenceManager: InferenceManager
) : ViewModel() {

    // Engine state forwarded verbatim from InferenceManager.
    val engineState = inferenceManager.engineState

    // Currently selected model, forwarded from InferenceManager.
    val selectedModel = inferenceManager.currentModel

    /**
     * Prepares the engine for benchmark mode by delegating to
     * [InferenceManager.loadModelForBenchmark]. Suspends until loading
     * finishes; callers race this against navigation (see ModelLoadingScreen).
     */
    suspend fun prepareForBenchmark() =
        inferenceManager.loadModelForBenchmark()

    /**
     * Prepares the engine for conversation mode, optionally applying
     * [systemPrompt], by delegating to
     * [InferenceManager.loadModelForConversation].
     */
    suspend fun prepareForConversation(systemPrompt: String? = null) =
        inferenceManager.loadModelForConversation(systemPrompt)
}

View File

@ -0,0 +1,58 @@
package com.example.llama.revamp.viewmodel
import androidx.lifecycle.ViewModel
import androidx.lifecycle.viewModelScope
import com.example.llama.revamp.data.model.ModelInfo
import com.example.llama.revamp.data.repository.ModelRepository
import com.example.llama.revamp.engine.InferenceManager
import dagger.hilt.android.lifecycle.HiltViewModel
import kotlinx.coroutines.flow.SharingStarted
import kotlinx.coroutines.flow.StateFlow
import kotlinx.coroutines.flow.stateIn
import kotlinx.coroutines.launch
import javax.inject.Inject
@HiltViewModel
class ModelSelectionViewModel @Inject constructor(
    private val inferenceManager: InferenceManager,
    private val modelRepository: ModelRepository
) : ViewModel() {

    /**
     * Models available for selection, backed by [ModelRepository]. Sharing
     * stops SUBSCRIPTION_TIMEOUT_MS after the last collector unsubscribes;
     * the list is empty until the repository emits.
     */
    val availableModels: StateFlow<List<ModelInfo>> = modelRepository.getModels()
        .stateIn(
            scope = viewModelScope,
            started = SharingStarted.WhileSubscribed(SUBSCRIPTION_TIMEOUT_MS),
            initialValue = emptyList()
        )

    /** Currently selected model, forwarded from [InferenceManager]. */
    val selectedModel = inferenceManager.currentModel

    /**
     * Makes [modelInfo] the current model and persists its last-used
     * timestamp in the background.
     */
    fun selectModel(modelInfo: ModelInfo) {
        inferenceManager.setCurrentModel(modelInfo)
        viewModelScope.launch { modelRepository.updateModelLastUsed(modelInfo.id) }
    }

    /** Unloads the current model, e.g. when navigating away. */
    suspend fun unloadModel() = inferenceManager.unloadModel()

    companion object {
        private val TAG = ModelSelectionViewModel::class.java.simpleName
        private const val SUBSCRIPTION_TIMEOUT_MS = 5000L
    }
}