vm: split monolithic MainViewModel into separate ViewModels
parent 32d778bb8e
commit 23d411d86e
@@ -170,7 +170,6 @@ fun AppContent(
         composable(AppDestinations.MODEL_SELECTION_ROUTE) {
             ModelSelectionScreen(
                 onModelSelected = { modelInfo ->
-                    mainVewModel.selectModel(modelInfo)
                     navigationActions.navigateToModelLoading()
                 },
                 onManageModelsClicked = {
@@ -184,34 +183,28 @@ fun AppContent(
         composable(AppDestinations.MODEL_LOADING_ROUTE) {
             ModelLoadingScreen(
                 engineState = engineState,
-                onBenchmarkSelected = {
-                    // Store a reference to the loading job
+                onBenchmarkSelected = { prepareJob ->
+                    // Wait for preparation to complete, then navigate if still active
                     val loadingJob = coroutineScope.launch {
-                        mainVewModel.prepareForBenchmark()
-                        // Check if the job wasn't cancelled before navigating
-                        if (isActive) {
-                            navigationActions.navigateToBenchmark()
-                        }
+                        prepareJob.join()
+                        if (isActive) { navigationActions.navigateToBenchmark() }
                     }

                     // Update the pendingNavigation handler to cancel any ongoing loading
                     pendingNavigation = {
+                        prepareJob.cancel()
                         loadingJob.cancel()
                         navigationActions.navigateUp()
                     }
                 },
-                onConversationSelected = { systemPrompt ->
-                    // Store a reference to the loading job
+                onConversationSelected = { systemPrompt, prepareJob ->
+                    // Wait for preparation to complete, then navigate if still active
                     val loadingJob = coroutineScope.launch {
-                        mainVewModel.prepareForConversation(systemPrompt)
-                        // Check if the job wasn't cancelled before navigating
-                        if (isActive) {
-                            navigationActions.navigateToConversation()
-                        }
+                        prepareJob.join()
+                        if (isActive) { navigationActions.navigateToConversation() }
                    }
                    // Update the pendingNavigation handler to cancel any ongoing loading
                    pendingNavigation = {
+                        prepareJob.cancel()
                         loadingJob.cancel()
                         navigationActions.navigateUp()
                    }
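The hunk above is the caller's half of a Job handoff: the screen now launches the preparation coroutine itself and passes the Job up, while the caller joins it before navigating and cancels both jobs if the user backs out. A minimal self-contained sketch of the pattern, with prepareModel and navigate as placeholder names rather than APIs from this codebase:

    import kotlinx.coroutines.*

    suspend fun prepareModel() { /* stands in for ModelLoadingViewModel.prepareForBenchmark() */ }
    fun navigate() { /* stands in for navigationActions.navigateToBenchmark() */ }

    fun main() = runBlocking {
        // Screen side: launch preparation and hand the Job to the caller.
        val prepareJob = launch { prepareModel() }

        // Caller side: await preparation, then navigate only if not cancelled.
        val loadingJob = launch {
            prepareJob.join()
            if (isActive) navigate()
        }

        // Back-press handler: cancelling both jobs suppresses the navigation.
        val pendingNavigation = { prepareJob.cancel(); loadingJob.cancel() }

        loadingJob.join()   // happy path; calling pendingNavigation() instead aborts
    }

Joining rather than re-launching the work means cancellation composes: cancelling prepareJob stops the model load, and cancelling loadingJob prevents the navigation even if preparation already finished.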
@@ -229,8 +222,7 @@ fun AppContent(
                 onBackPressed = {
                     // Need to unload model before going back
                     handleBackWithModelCheck()
                 },
-                viewModel = mainVewModel
             }
         )
     }
@@ -240,14 +232,7 @@ fun AppContent(
                 onBackPressed = {
                     // Need to unload model before going back
                     handleBackWithModelCheck()
                 },
-                onRerunPressed = {
-                    mainVewModel.rerunBenchmark()
-                },
-                onSharePressed = {
-                    // Stub for sharing functionality
-                },
-                viewModel = mainVewModel
             }
         )
     }
@@ -26,14 +26,13 @@ import androidx.hilt.navigation.compose.hiltViewModel
 import com.example.llama.revamp.engine.InferenceEngine
 import com.example.llama.revamp.ui.components.PerformanceAppScaffold
 import com.example.llama.revamp.ui.theme.MonospacedTextStyle
+import com.example.llama.revamp.viewmodel.BenchmarkViewModel
-import com.example.llama.revamp.viewmodel.MainViewModel

 @Composable
 fun BenchmarkScreen(
     onBackPressed: () -> Unit,
     onRerunPressed: () -> Unit,
     onSharePressed: () -> Unit,
-    viewModel: MainViewModel = hiltViewModel()
+    viewModel: BenchmarkViewModel = hiltViewModel()
 ) {
     val engineState by viewModel.engineState.collectAsState()
     val benchmarkResults by viewModel.benchmarkResults.collectAsState()
@@ -62,7 +62,7 @@ import androidx.lifecycle.LifecycleEventObserver
 import androidx.lifecycle.compose.LocalLifecycleOwner
 import com.example.llama.revamp.engine.InferenceEngine
 import com.example.llama.revamp.ui.components.PerformanceAppScaffold
-import com.example.llama.revamp.viewmodel.MainViewModel
+import com.example.llama.revamp.viewmodel.ConversationViewModel
 import com.example.llama.revamp.viewmodel.Message
 import kotlinx.coroutines.launch
@@ -72,7 +72,7 @@ import kotlinx.coroutines.launch
 @Composable
 fun ConversationScreen(
     onBackPressed: () -> Unit,
-    viewModel: MainViewModel = hiltViewModel()
+    viewModel: ConversationViewModel = hiltViewModel()
 ) {
     val engineState by viewModel.engineState.collectAsState()
     val messages by viewModel.messages.collectAsState()
@@ -40,6 +40,7 @@ import androidx.compose.runtime.collectAsState
 import androidx.compose.runtime.getValue
 import androidx.compose.runtime.mutableStateOf
 import androidx.compose.runtime.remember
+import androidx.compose.runtime.rememberCoroutineScope
 import androidx.compose.runtime.setValue
 import androidx.compose.ui.Alignment
 import androidx.compose.ui.Modifier
@@ -50,7 +51,10 @@ import androidx.hilt.navigation.compose.hiltViewModel
 import com.example.llama.revamp.data.model.SystemPrompt
 import com.example.llama.revamp.engine.InferenceEngine
 import com.example.llama.revamp.ui.components.PerformanceAppScaffold
+import com.example.llama.revamp.viewmodel.ModelLoadingViewModel
 import com.example.llama.revamp.viewmodel.SystemPromptViewModel
+import kotlinx.coroutines.Job
+import kotlinx.coroutines.launch

 enum class SystemPromptTab {
     PRESETS, CUSTOM, RECENTS
@@ -59,14 +63,17 @@ enum class SystemPromptTab {
 @OptIn(ExperimentalMaterial3Api::class, ExperimentalFoundationApi::class)
 @Composable
 fun ModelLoadingScreen(
-    viewModel: SystemPromptViewModel = hiltViewModel(),
     engineState: InferenceEngine.State,
-    onBenchmarkSelected: () -> Unit,
-    onConversationSelected: (String?) -> Unit,
+    onBenchmarkSelected: (prepareJob: Job) -> Unit,
+    onConversationSelected: (systemPrompt: String?, prepareJob: Job) -> Unit,
     onBackPressed: () -> Unit,
+    modelLoadingViewModel: ModelLoadingViewModel = hiltViewModel(),
+    systemPromptViewModel: SystemPromptViewModel = hiltViewModel(),
 ) {
-    val presetPrompts by viewModel.presetPrompts.collectAsState()
-    val recentPrompts by viewModel.recentPrompts.collectAsState()
+    val coroutineScope = rememberCoroutineScope()
+
+    val presetPrompts by systemPromptViewModel.presetPrompts.collectAsState()
+    val recentPrompts by systemPromptViewModel.recentPrompts.collectAsState()

     var selectedMode by remember { mutableStateOf<Mode?>(null) }
     var useSystemPrompt by remember { mutableStateOf(false) }
@@ -96,6 +103,21 @@ fun ModelLoadingScreen(
     engineState !is InferenceEngine.State.LibraryLoaded &&
         engineState !is InferenceEngine.State.AwaitingUserPrompt

+    // Mode selection callbacks
+    val handleBenchmarkSelected = {
+        val prepareJob = coroutineScope.launch {
+            modelLoadingViewModel.prepareForBenchmark()
+        }
+        onBenchmarkSelected(prepareJob)
+    }
+
+    val handleConversationSelected = { systemPrompt: String? ->
+        val prepareJob = coroutineScope.launch {
+            modelLoadingViewModel.prepareForConversation(systemPrompt)
+        }
+        onConversationSelected(systemPrompt, prepareJob)
+    }
+
     PerformanceAppScaffold(
         title = "Load Model",
         onNavigateBack = onBackPressed,
@@ -143,7 +165,7 @@ fun ModelLoadingScreen(
             modifier = Modifier
                 .fillMaxWidth()
                 .padding(bottom = 4.dp)
-                // Only use weight if system prompt is active, otherwise wrap content
+                // Only fill height if system prompt is active
                 .then(if (useSystemPrompt) Modifier.weight(1f) else Modifier)
         ) {
             Column(
@@ -355,14 +377,15 @@ fun ModelLoadingScreen(
         Button(
             onClick = {
                 when (selectedMode) {
-                    Mode.BENCHMARK -> onBenchmarkSelected()
+                    Mode.BENCHMARK -> handleBenchmarkSelected()
+
                     Mode.CONVERSATION -> {
                         val systemPrompt = if (useSystemPrompt) {
                             when (selectedTab) {
                                 SystemPromptTab.PRESETS, SystemPromptTab.RECENTS ->
                                     selectedPrompt?.let { prompt ->
                                         // Save the prompt to recent prompts database
-                                        viewModel.savePromptToRecents(prompt)
+                                        systemPromptViewModel.savePromptToRecents(prompt)
                                         prompt.content
                                     }
@@ -370,15 +393,15 @@ fun ModelLoadingScreen(
                                 customPromptText.takeIf { it.isNotBlank() }
                                     ?.also { promptText ->
                                         // Save custom prompt to database
-                                        viewModel.saveCustomPromptToRecents(promptText)
+                                        systemPromptViewModel.saveCustomPromptToRecents(promptText)
                                     }
                             }
                         } else null
-                        onConversationSelected(systemPrompt)
+
+                        handleConversationSelected(systemPrompt)
                     }

-                    null -> { /* No mode selected */
-                    }
+                    null -> { /* No mode selected */ }
                 }
             },
             modifier = Modifier
@@ -30,7 +30,7 @@ import com.example.llama.revamp.data.model.ModelInfo
 import com.example.llama.revamp.ui.components.ModelCard
 import com.example.llama.revamp.ui.components.ModelCardActions
 import com.example.llama.revamp.ui.components.PerformanceAppScaffold
-import com.example.llama.revamp.viewmodel.MainViewModel
+import com.example.llama.revamp.viewmodel.ModelSelectionViewModel

 @OptIn(ExperimentalMaterial3Api::class)
 @Composable
@@ -38,10 +38,15 @@ fun ModelSelectionScreen(
     onModelSelected: (ModelInfo) -> Unit,
     onManageModelsClicked: () -> Unit,
     onMenuClicked: () -> Unit,
-    viewModel: MainViewModel = hiltViewModel(),
+    viewModel: ModelSelectionViewModel = hiltViewModel(),
 ) {
     val models by viewModel.availableModels.collectAsState()

+    val handleModelSelection = { model: ModelInfo ->
+        viewModel.selectModel(model)
+        onModelSelected(model)
+    }
+
     PerformanceAppScaffold(
         title = "Models",
         onMenuOpen = onMenuClicked,
@@ -60,11 +65,13 @@ fun ModelSelectionScreen(
             items(models) { model ->
                 ModelCard(
                     model = model,
-                    onClick = { onModelSelected(model) },
+                    onClick = { handleModelSelection(model) },
                     modifier = Modifier.padding(vertical = 4.dp),
                     isSelected = null, // Not in selection mode
                     actionButton = {
-                        ModelCardActions.PlayButton(onClick = { onModelSelected(model) })
+                        ModelCardActions.PlayButton {
+                            handleModelSelection(model)
+                        }
                     }
                 )
                 Spacer(modifier = Modifier.height(8.dp))
@@ -0,0 +1,29 @@
+package com.example.llama.revamp.viewmodel
+
+import androidx.lifecycle.ViewModel
+import androidx.lifecycle.viewModelScope
+import com.example.llama.revamp.data.model.ModelInfo
+import com.example.llama.revamp.engine.InferenceEngine
+import com.example.llama.revamp.engine.InferenceManager
+import dagger.hilt.android.lifecycle.HiltViewModel
+import kotlinx.coroutines.flow.StateFlow
+import kotlinx.coroutines.launch
+import javax.inject.Inject
+
+@HiltViewModel
+class BenchmarkViewModel @Inject constructor(
+    private val inferenceManager: InferenceManager
+) : ViewModel() {
+
+    val engineState: StateFlow<InferenceEngine.State> = inferenceManager.engineState
+    val benchmarkResults: StateFlow<String?> = inferenceManager.benchmarkResults
+    val selectedModel: StateFlow<ModelInfo?> = inferenceManager.currentModel
+
+    /**
+     * Run benchmark with specified parameters
+     */
+    fun runBenchmark(pp: Int = 512, tg: Int = 128, pl: Int = 1, nr: Int = 3) =
+        viewModelScope.launch {
+            inferenceManager.benchmark(pp, tg, pl, nr)
+        }
+}
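For context, a hedged sketch of how a screen can drive this new ViewModel; BenchmarkScreen above is the real consumer, and the parameter defaults come from runBenchmark's declaration:

    import androidx.compose.material3.Button
    import androidx.compose.material3.Text
    import androidx.compose.runtime.Composable
    import androidx.compose.runtime.collectAsState
    import androidx.compose.runtime.getValue
    import androidx.hilt.navigation.compose.hiltViewModel

    @Composable
    fun BenchmarkRunner(viewModel: BenchmarkViewModel = hiltViewModel()) {
        // benchmarkResults is a StateFlow<String?>; null until a run completes.
        val results by viewModel.benchmarkResults.collectAsState()

        Button(onClick = { viewModel.runBenchmark() }) { Text("Run benchmark") }
        results?.let { Text(it) }
    }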
@@ -0,0 +1,166 @@
+package com.example.llama.revamp.viewmodel
+
+import androidx.lifecycle.ViewModel
+import androidx.lifecycle.viewModelScope
+import com.example.llama.revamp.engine.InferenceManager
+import com.example.llama.revamp.engine.TokenMetrics
+import dagger.hilt.android.lifecycle.HiltViewModel
+import kotlinx.coroutines.Job
+import kotlinx.coroutines.flow.MutableStateFlow
+import kotlinx.coroutines.flow.StateFlow
+import kotlinx.coroutines.flow.asStateFlow
+import kotlinx.coroutines.launch
+import java.text.SimpleDateFormat
+import java.util.Date
+import java.util.Locale
+import javax.inject.Inject
+import kotlin.getValue
+
+
+@HiltViewModel
+class ConversationViewModel @Inject constructor(
+    private val inferenceManager: InferenceManager
+) : ViewModel() {
+
+    val engineState = inferenceManager.engineState
+    val selectedModel = inferenceManager.currentModel
+    val systemPrompt = inferenceManager.systemPrompt
+
+    // Messages in conversation
+    private val _messages = MutableStateFlow<List<Message>>(emptyList())
+    val messages: StateFlow<List<Message>> = _messages.asStateFlow()
+
+    // Token generation job
+    private var tokenCollectionJob: Job? = null
+
+    /**
+     * Send a message with the provided content.
+     * Note: This matches the existing UI which manages input state outside the ViewModel.
+     */
+    fun sendMessage(content: String) {
+        if (content.isBlank()) return
+
+        // Cancel ongoing collection
+        tokenCollectionJob?.cancel()
+
+        // Add user message
+        val userMessage = Message.User(
+            content = content,
+            timestamp = System.currentTimeMillis()
+        )
+        _messages.value = _messages.value + userMessage
+
+        // Add placeholder for assistant response
+        val assistantMessage = Message.Assistant.Ongoing(
+            content = "",
+            timestamp = System.currentTimeMillis()
+        )
+        _messages.value = _messages.value + assistantMessage
+
+        // Collect response
+        tokenCollectionJob = viewModelScope.launch {
+            try {
+                inferenceManager.generateResponse(content)
+                    .collect { (text, isComplete) ->
+                        updateAssistantMessage(text, isComplete)
+                    }
+            } catch (e: Exception) {
+                // Handle error
+                handleResponseError(e)
+            }
+        }
+    }
+
+    /**
+     * Handle updating the assistant message
+     */
+    private fun updateAssistantMessage(text: String, isComplete: Boolean) {
+        val currentMessages = _messages.value.toMutableList()
+        val lastIndex = currentMessages.size - 1
+        val currentAssistantMessage = currentMessages.getOrNull(lastIndex) as? Message.Assistant.Ongoing
+
+        if (currentAssistantMessage != null) {
+            if (isComplete) {
+                // Final message with metrics
+                currentMessages[lastIndex] = Message.Assistant.Completed(
+                    content = text,
+                    timestamp = currentAssistantMessage.timestamp,
+                    metrics = inferenceManager.createTokenMetrics()
+                )
+            } else {
+                // Ongoing message update
+                currentMessages[lastIndex] = Message.Assistant.Ongoing(
+                    content = text,
+                    timestamp = currentAssistantMessage.timestamp
+                )
+            }
+            _messages.value = currentMessages
+        }
+    }
+
+    /**
+     * Handle response error
+     */
+    private fun handleResponseError(e: Exception) {
+        val currentMessages = _messages.value.toMutableList()
+        val lastIndex = currentMessages.size - 1
+        val currentAssistantMessage = currentMessages.getOrNull(lastIndex) as? Message.Assistant.Ongoing
+
+        if (currentAssistantMessage != null) {
+            currentMessages[lastIndex] = Message.Assistant.Completed(
+                content = "${currentAssistantMessage.content}[Error: ${e.message}]",
+                timestamp = currentAssistantMessage.timestamp,
+                metrics = inferenceManager.createTokenMetrics()
+            )
+            _messages.value = currentMessages
+        }
+    }
+
+    /**
+     * Clear conversation
+     */
+    fun clearConversation() {
+        tokenCollectionJob?.cancel()
+        _messages.value = emptyList()
+    }
+
+    override fun onCleared() {
+        tokenCollectionJob?.cancel()
+        super.onCleared()
+    }
+}
+
+
+/**
+ * Sealed class representing messages in a conversation.
+ */
+sealed class Message {
+    abstract val timestamp: Long
+    abstract val content: String
+
+    val formattedTime: String
+        get() = datetimeFormatter.format(Date(timestamp))
+
+    data class User(
+        override val timestamp: Long,
+        override val content: String
+    ) : Message()
+
+    sealed class Assistant : Message() {
+        data class Ongoing(
+            override val timestamp: Long,
+            override val content: String,
+        ) : Assistant()
+
+        data class Completed(
+            override val timestamp: Long,
+            override val content: String,
+            val metrics: TokenMetrics
+        ) : Assistant()
+    }
+
+    companion object {
+        private val datetimeFormatter by lazy { SimpleDateFormat("h:mm a", Locale.getDefault()) }
+    }
+}
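The destructuring collect { (text, isComplete) -> ... } in sendMessage implies InferenceManager.generateResponse returns a Flow of (accumulated text, completion flag) pairs. A sketch of that assumed contract; the real InferenceManager is not part of this diff, and tokenStream below is a hypothetical source:

    import kotlinx.coroutines.flow.Flow
    import kotlinx.coroutines.flow.flow

    // Assumed shape of the API consumed by ConversationViewModel.sendMessage.
    fun generateResponse(prompt: String, tokenStream: Flow<String>): Flow<Pair<String, Boolean>> = flow {
        val response = StringBuilder()
        tokenStream.collect { token ->
            response.append(token)
            emit(response.toString() to false)  // partial text, generation ongoing
        }
        emit(response.toString() to true)       // final text, generation complete
    }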
@@ -1,276 +1,25 @@
 package com.example.llama.revamp.viewmodel

 import androidx.lifecycle.ViewModel
-import androidx.lifecycle.viewModelScope
-import com.example.llama.revamp.data.model.ModelInfo
-import com.example.llama.revamp.data.repository.ModelRepository
-import com.example.llama.revamp.engine.InferenceEngine
+import com.example.llama.revamp.engine.InferenceManager
 import dagger.hilt.android.lifecycle.HiltViewModel
-import kotlinx.coroutines.Job
-import kotlinx.coroutines.flow.MutableStateFlow
-import kotlinx.coroutines.flow.SharingStarted
-import kotlinx.coroutines.flow.StateFlow
-import kotlinx.coroutines.flow.asStateFlow
-import kotlinx.coroutines.flow.catch
-import kotlinx.coroutines.flow.onCompletion
-import kotlinx.coroutines.flow.stateIn
-import kotlinx.coroutines.launch
-import java.text.SimpleDateFormat
-import java.util.Date
-import java.util.Locale
 import javax.inject.Inject

-/**
- * Main ViewModel that handles the LLM engine state and operations.
- */
 @HiltViewModel
+/**
+ * Main ViewModel that exposes the core states of [InferenceEngine]
+ */
 class MainViewModel @Inject constructor (
-    private val inferenceEngine: InferenceEngine,
-    private val modelRepository: ModelRepository,
+    private val inferenceManager: InferenceManager,
 ) : ViewModel() {

-    // Expose the engine state
-    val engineState: StateFlow<InferenceEngine.State> = inferenceEngine.state
-
-    // Benchmark results
-    val benchmarkResults: StateFlow<String?> = inferenceEngine.benchmarkResults
-
-    // Available models for selection
-    val availableModels: StateFlow<List<ModelInfo>> = modelRepository.getModels()
-        .stateIn(
-            scope = viewModelScope,
-            started = SharingStarted.WhileSubscribed(SUBSCRIPTION_TIMEOUT_MS),
-            initialValue = emptyList()
-        )
-
-    // Selected model information
-    private val _selectedModel = MutableStateFlow<ModelInfo?>(null)
-    val selectedModel: StateFlow<ModelInfo?> = _selectedModel.asStateFlow()
-
-    // Messages in the conversation
-    private val _messages = MutableStateFlow<List<Message>>(emptyList())
-    val messages: StateFlow<List<Message>> = _messages.asStateFlow()
-
-    // System prompt for the conversation
-    private val _systemPrompt = MutableStateFlow<String?>(null)
-    val systemPrompt: StateFlow<String?> = _systemPrompt.asStateFlow()
-
-    // Flag to track if token collection is active
-    private var tokenCollectionJob: Job? = null
+    val engineState = inferenceManager.engineState

-    /**
-     * Selects a model for use.
-     * Unload the current model and release the resources
-     */
-    fun selectModel(modelInfo: ModelInfo) {
-        _selectedModel.value = modelInfo
-
-        viewModelScope.launch {
-            modelRepository.updateModelLastUsed(modelInfo.id)
-        }
-    }
-
-    /**
-     * Prepares the engine for benchmark mode.
-     */
-    suspend fun prepareForBenchmark() {
-        _selectedModel.value?.let { model ->
-            inferenceEngine.loadModel(model.path)
-        }
-    }
-
-    /**
-     * Runs the benchmark with current parameters.
-     */
-    suspend fun runBenchmark() = inferenceEngine.bench(512, 128, 1, 3)
-
-    /**
-     * Reruns the benchmark.
-     */
-    fun rerunBenchmark() = viewModelScope.launch { runBenchmark() }
-
-    /**
-     * Prepares the engine for conversation mode.
-     */
-    suspend fun prepareForConversation(systemPrompt: String? = null) {
-        _systemPrompt.value = systemPrompt
-        _selectedModel.value?.let { model ->
-            inferenceEngine.loadModel(model.path, systemPrompt)
-        }
-    }
-
-    /**
-     * Tracks token generation metrics
-     */
-    private var generationStartTime: Long = 0L
-    private var firstTokenTime: Long = 0L
-    private var tokenCount: Int = 0
-    private var isFirstToken: Boolean = true
-
-    /**
-     * Sends a user message and collects the response.
-     */
-    fun sendMessage(content: String) {
-        if (content.isBlank()) return
-
-        // Cancel any ongoing token collection
-        tokenCollectionJob?.cancel()
-
-        // Add user message
-        val userMessage = Message.User(
-            content = content,
-            timestamp = System.currentTimeMillis()
-        )
-        _messages.value = _messages.value + userMessage
-
-        // Create placeholder for assistant message
-        val assistantMessage = Message.Assistant.Ongoing(
-            content = "",
-            timestamp = System.currentTimeMillis()
-        )
-        _messages.value = _messages.value + assistantMessage
-
-        // Reset metrics tracking
-        generationStartTime = System.currentTimeMillis()
-        firstTokenTime = 0L
-        tokenCount = 0
-        isFirstToken = true
-
-        // Get response from engine
-        tokenCollectionJob = viewModelScope.launch {
-            val response = StringBuilder()
-
-            try {
-                inferenceEngine.sendUserPrompt(content)
-                    .catch { e ->
-                        // Handle errors during token collection
-                        val currentMessages = _messages.value.toMutableList()
-                        if (currentMessages.size >= 2) {
-                            val messageIndex = currentMessages.size - 1
-                            val currentAssistantMessage = currentMessages[messageIndex] as? Message.Assistant.Ongoing
-                            if (currentAssistantMessage != null) {
-                                // Create metrics with error indication
-                                val errorMetrics = TokenMetrics(
-                                    tokensCount = tokenCount,
-                                    ttftMs = if (firstTokenTime > 0) firstTokenTime - generationStartTime else 0L,
-                                    tpsMs = calculateTPS(tokenCount, System.currentTimeMillis() - generationStartTime)
-                                )
-
-                                currentMessages[messageIndex] = Message.Assistant.Completed(
-                                    content = "${response}[Error: ${e.message}]",
-                                    timestamp = currentAssistantMessage.timestamp,
-                                    metrics = errorMetrics
-                                )
-                                _messages.value = currentMessages
-                            }
-                        }
-                    }
-                    .onCompletion { cause ->
-                        // Handle completion (normal or cancelled)
-                        val currentMessages = _messages.value.toMutableList()
-                        if (currentMessages.isNotEmpty()) {
-                            val messageIndex = currentMessages.size - 1
-                            val currentAssistantMessage = currentMessages.getOrNull(messageIndex) as? Message.Assistant.Ongoing
-                            if (currentAssistantMessage != null) {
-                                // Calculate final metrics
-                                val endTime = System.currentTimeMillis()
-                                val totalTimeMs = endTime - generationStartTime
-
-                                val metrics = TokenMetrics(
-                                    tokensCount = tokenCount,
-                                    ttftMs = if (firstTokenTime > 0) firstTokenTime - generationStartTime else 0L,
-                                    tpsMs = calculateTPS(tokenCount, totalTimeMs)
-                                )
-
-                                currentMessages[messageIndex] = Message.Assistant.Completed(
-                                    content = response.toString(),
-                                    timestamp = currentAssistantMessage.timestamp,
-                                    metrics = metrics
-                                )
-                                _messages.value = currentMessages
-                            }
-                        }
-                    }
-                    .collect { token ->
-                        // Track first token time
-                        if (isFirstToken && token.isNotBlank()) {
-                            firstTokenTime = System.currentTimeMillis()
-                            isFirstToken = false
-                        }
-
-                        // Count tokens - each non-empty emission is at least one token
-                        if (token.isNotBlank()) {
-                            tokenCount++
-                        }
-
-                        response.append(token)
-
-                        // Safely update the assistant message with the generated text
-                        val currentMessages = _messages.value.toMutableList()
-                        if (currentMessages.isNotEmpty()) {
-                            val messageIndex = currentMessages.size - 1
-                            val currentAssistantMessage = currentMessages.getOrNull(messageIndex) as? Message.Assistant.Ongoing
-                            if (currentAssistantMessage != null) {
-                                currentMessages[messageIndex] = Message.Assistant.Ongoing(
-                                    content = response.toString(),
-                                    timestamp = currentAssistantMessage.timestamp
-                                )
-                                _messages.value = currentMessages
-                            }
-                        }
-                    }
-            } catch (e: Exception) {
-                // Handle any unexpected exceptions
-                val currentMessages = _messages.value.toMutableList()
-                if (currentMessages.isNotEmpty()) {
-                    val messageIndex = currentMessages.size - 1
-                    val currentAssistantMessage = currentMessages.getOrNull(messageIndex) as? Message.Assistant.Ongoing
-                    if (currentAssistantMessage != null) {
-                        // Create metrics with error indication
-                        val errorMetrics = TokenMetrics(
-                            tokensCount = tokenCount,
-                            ttftMs = if (firstTokenTime > 0) firstTokenTime - generationStartTime else 0L,
-                            tpsMs = calculateTPS(tokenCount, System.currentTimeMillis() - generationStartTime)
-                        )
-
-                        currentMessages[messageIndex] = Message.Assistant.Completed(
-                            content = "${response}[Error: ${e.message}]",
-                            timestamp = currentAssistantMessage.timestamp,
-                            metrics = errorMetrics
-                        )
-                        _messages.value = currentMessages
-                    }
-                }
-            }
-        }
-    }
-
-    /**
-     * Calculate tokens per second.
-     */
-    private fun calculateTPS(tokens: Int, timeMs: Long): Float {
-        if (tokens <= 0 || timeMs <= 0) return 0f
-        return (tokens.toFloat() * 1000f) / timeMs
-    }
-
-    /**
-     * Unloads the currently loaded model after cleanup chores:
-     * - Cancel any ongoing token collection
-     * - Clear messages
-     */
-    suspend fun unloadModel() {
-        tokenCollectionJob?.cancel()
-        _messages.value = emptyList()
-
-        inferenceEngine.unloadModel()
-    }
-
-    /**
-     * Clean up resources when ViewModel is cleared.
-     */
-    override fun onCleared() {
-        inferenceEngine.destroy()
-        super.onCleared()
-    }
+    suspend fun unloadModel() = inferenceManager.unloadModel()

     companion object {
         private val TAG = MainViewModel::class.java.simpleName
@@ -279,44 +28,3 @@ class MainViewModel @Inject constructor (
     }
 }

-/**
- * Sealed class representing messages in a conversation.
- */
-sealed class Message {
-    abstract val timestamp: Long
-    abstract val content: String
-
-    val formattedTime: String
-        get() = datetimeFormatter.format(Date(timestamp))
-
-    data class User(
-        override val timestamp: Long,
-        override val content: String
-    ) : Message()
-
-    sealed class Assistant : Message() {
-        data class Ongoing(
-            override val timestamp: Long,
-            override val content: String,
-        ) : Assistant()
-
-        data class Completed(
-            override val timestamp: Long,
-            override val content: String,
-            val metrics: TokenMetrics
-        ) : Assistant()
-    }
-
-    companion object {
-        private val datetimeFormatter by lazy { SimpleDateFormat("h:mm a", Locale.getDefault()) }
-    }
-}
-
-data class TokenMetrics(
-    val tokensCount: Int,
-    val ttftMs: Long,
-    val tpsMs: Float,
-) {
-    val text: String
-        get() = "Tokens: $tokensCount, TTFT: ${ttftMs}ms, TPS: ${"%.1f".format(tpsMs)}"
-}
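Message and TokenMetrics are removed here rather than deleted outright: Message moves into ConversationViewModel.kt (above) and TokenMetrics into the engine package, per the com.example.llama.revamp.engine.TokenMetrics import. The throughput formula behind TokenMetrics.text is tokens times 1000 divided by elapsed milliseconds; a worked example using the calculateTPS body from the removed code:

    fun calculateTPS(tokens: Int, timeMs: Long): Float {
        if (tokens <= 0 || timeMs <= 0) return 0f
        return (tokens.toFloat() * 1000f) / timeMs
    }

    // 50 tokens generated over 2000 ms -> 25.0 tokens per second
    check(calculateTPS(50, 2_000L) == 25f)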
@@ -0,0 +1,27 @@
+package com.example.llama.revamp.viewmodel
+
+import androidx.lifecycle.ViewModel
+import com.example.llama.revamp.engine.InferenceManager
+import dagger.hilt.android.lifecycle.HiltViewModel
+import javax.inject.Inject
+
+@HiltViewModel
+class ModelLoadingViewModel @Inject constructor(
+    private val inferenceManager: InferenceManager
+) : ViewModel() {
+
+    val engineState = inferenceManager.engineState
+    val selectedModel = inferenceManager.currentModel
+
+    /**
+     * Prepares the engine for benchmark mode.
+     */
+    suspend fun prepareForBenchmark() =
+        inferenceManager.loadModelForBenchmark()
+
+    /**
+     * Prepare for conversation
+     */
+    suspend fun prepareForConversation(systemPrompt: String? = null) =
+        inferenceManager.loadModelForConversation(systemPrompt)
+}
@@ -0,0 +1,58 @@
+package com.example.llama.revamp.viewmodel
+
+import androidx.lifecycle.ViewModel
+import androidx.lifecycle.viewModelScope
+import com.example.llama.revamp.data.model.ModelInfo
+import com.example.llama.revamp.data.repository.ModelRepository
+import com.example.llama.revamp.engine.InferenceManager
+import dagger.hilt.android.lifecycle.HiltViewModel
+import kotlinx.coroutines.flow.SharingStarted
+import kotlinx.coroutines.flow.StateFlow
+import kotlinx.coroutines.flow.stateIn
+import kotlinx.coroutines.launch
+import javax.inject.Inject
+
+
+@HiltViewModel
+class ModelSelectionViewModel @Inject constructor(
+    private val inferenceManager: InferenceManager,
+    private val modelRepository: ModelRepository
+) : ViewModel() {
+
+    /**
+     * Available models for selection
+     */
+    val availableModels: StateFlow<List<ModelInfo>> = modelRepository.getModels()
+        .stateIn(
+            scope = viewModelScope,
+            started = SharingStarted.WhileSubscribed(SUBSCRIPTION_TIMEOUT_MS),
+            initialValue = emptyList()
+        )
+
+    /**
+     * Access to currently selected model
+     */
+    val selectedModel = inferenceManager.currentModel
+
+    /**
+     * Select a model and update its last used timestamp
+     */
+    fun selectModel(modelInfo: ModelInfo) {
+        inferenceManager.setCurrentModel(modelInfo)
+
+        viewModelScope.launch {
+            modelRepository.updateModelLastUsed(modelInfo.id)
+        }
+    }
+
+    /**
+     * Unload model when navigating away
+     */
+    suspend fun unloadModel() = inferenceManager.unloadModel()
+
+    companion object {
+        private val TAG = ModelSelectionViewModel::class.java.simpleName
+
+        private const val SUBSCRIPTION_TIMEOUT_MS = 5000L
+    }
+}
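One design note on the stateIn call above: SharingStarted.WhileSubscribed(5000) keeps the repository flow active for five seconds after the last collector disappears, so a configuration change such as screen rotation does not restart the underlying model query. The pattern in isolation, annotated; this fragment assumes it lives inside a ViewModel with modelRepository injected, as in the file above:

    // Convert a cold repository flow into hot, lifecycle-friendly state.
    val availableModels: StateFlow<List<ModelInfo>> = modelRepository.getModels()
        .stateIn(
            scope = viewModelScope,                            // cancelled with the ViewModel
            started = SharingStarted.WhileSubscribed(5_000L),  // survives brief unsubscription
            initialValue = emptyList()                         // value before first emission
        )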