vm: split mono MainViewModel into separate individual ViewModels
This commit is contained in:
parent
32d778bb8e
commit
23d411d86e
|
|
@ -170,7 +170,6 @@ fun AppContent(
|
||||||
composable(AppDestinations.MODEL_SELECTION_ROUTE) {
|
composable(AppDestinations.MODEL_SELECTION_ROUTE) {
|
||||||
ModelSelectionScreen(
|
ModelSelectionScreen(
|
||||||
onModelSelected = { modelInfo ->
|
onModelSelected = { modelInfo ->
|
||||||
mainVewModel.selectModel(modelInfo)
|
|
||||||
navigationActions.navigateToModelLoading()
|
navigationActions.navigateToModelLoading()
|
||||||
},
|
},
|
||||||
onManageModelsClicked = {
|
onManageModelsClicked = {
|
||||||
|
|
@ -184,34 +183,28 @@ fun AppContent(
|
||||||
composable(AppDestinations.MODEL_LOADING_ROUTE) {
|
composable(AppDestinations.MODEL_LOADING_ROUTE) {
|
||||||
ModelLoadingScreen(
|
ModelLoadingScreen(
|
||||||
engineState = engineState,
|
engineState = engineState,
|
||||||
onBenchmarkSelected = {
|
onBenchmarkSelected = { prepareJob ->
|
||||||
// Store a reference to the loading job
|
// Wait for preparation to complete, then navigate if still active
|
||||||
val loadingJob = coroutineScope.launch {
|
val loadingJob = coroutineScope.launch {
|
||||||
mainVewModel.prepareForBenchmark()
|
prepareJob.join()
|
||||||
// Check if the job wasn't cancelled before navigating
|
if (isActive) { navigationActions.navigateToBenchmark() }
|
||||||
if (isActive) {
|
|
||||||
navigationActions.navigateToBenchmark()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
// Update the pendingNavigation handler to cancel any ongoing loading
|
|
||||||
pendingNavigation = {
|
pendingNavigation = {
|
||||||
|
prepareJob.cancel()
|
||||||
loadingJob.cancel()
|
loadingJob.cancel()
|
||||||
navigationActions.navigateUp()
|
navigationActions.navigateUp()
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
onConversationSelected = { systemPrompt ->
|
onConversationSelected = { systemPrompt, prepareJob ->
|
||||||
// Store a reference to the loading job
|
// Wait for preparation to complete, then navigate if still active
|
||||||
val loadingJob = coroutineScope.launch {
|
val loadingJob = coroutineScope.launch {
|
||||||
mainVewModel.prepareForConversation(systemPrompt)
|
prepareJob.join()
|
||||||
// Check if the job wasn't cancelled before navigating
|
if (isActive) { navigationActions.navigateToConversation() }
|
||||||
if (isActive) {
|
|
||||||
navigationActions.navigateToConversation()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
// Update the pendingNavigation handler to cancel any ongoing loading
|
|
||||||
pendingNavigation = {
|
pendingNavigation = {
|
||||||
|
prepareJob.cancel()
|
||||||
loadingJob.cancel()
|
loadingJob.cancel()
|
||||||
navigationActions.navigateUp()
|
navigationActions.navigateUp()
|
||||||
}
|
}
|
||||||
|
|
@ -229,8 +222,7 @@ fun AppContent(
|
||||||
onBackPressed = {
|
onBackPressed = {
|
||||||
// Need to unload model before going back
|
// Need to unload model before going back
|
||||||
handleBackWithModelCheck()
|
handleBackWithModelCheck()
|
||||||
},
|
}
|
||||||
viewModel = mainVewModel
|
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -240,14 +232,7 @@ fun AppContent(
|
||||||
onBackPressed = {
|
onBackPressed = {
|
||||||
// Need to unload model before going back
|
// Need to unload model before going back
|
||||||
handleBackWithModelCheck()
|
handleBackWithModelCheck()
|
||||||
},
|
}
|
||||||
onRerunPressed = {
|
|
||||||
mainVewModel.rerunBenchmark()
|
|
||||||
},
|
|
||||||
onSharePressed = {
|
|
||||||
// Stub for sharing functionality
|
|
||||||
},
|
|
||||||
viewModel = mainVewModel
|
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -26,14 +26,13 @@ import androidx.hilt.navigation.compose.hiltViewModel
|
||||||
import com.example.llama.revamp.engine.InferenceEngine
|
import com.example.llama.revamp.engine.InferenceEngine
|
||||||
import com.example.llama.revamp.ui.components.PerformanceAppScaffold
|
import com.example.llama.revamp.ui.components.PerformanceAppScaffold
|
||||||
import com.example.llama.revamp.ui.theme.MonospacedTextStyle
|
import com.example.llama.revamp.ui.theme.MonospacedTextStyle
|
||||||
|
import com.example.llama.revamp.viewmodel.BenchmarkViewModel
|
||||||
import com.example.llama.revamp.viewmodel.MainViewModel
|
import com.example.llama.revamp.viewmodel.MainViewModel
|
||||||
|
|
||||||
@Composable
|
@Composable
|
||||||
fun BenchmarkScreen(
|
fun BenchmarkScreen(
|
||||||
onBackPressed: () -> Unit,
|
onBackPressed: () -> Unit,
|
||||||
onRerunPressed: () -> Unit,
|
viewModel: BenchmarkViewModel = hiltViewModel()
|
||||||
onSharePressed: () -> Unit,
|
|
||||||
viewModel: MainViewModel = hiltViewModel()
|
|
||||||
) {
|
) {
|
||||||
val engineState by viewModel.engineState.collectAsState()
|
val engineState by viewModel.engineState.collectAsState()
|
||||||
val benchmarkResults by viewModel.benchmarkResults.collectAsState()
|
val benchmarkResults by viewModel.benchmarkResults.collectAsState()
|
||||||
|
|
|
||||||
|
|
@ -62,7 +62,7 @@ import androidx.lifecycle.LifecycleEventObserver
|
||||||
import androidx.lifecycle.compose.LocalLifecycleOwner
|
import androidx.lifecycle.compose.LocalLifecycleOwner
|
||||||
import com.example.llama.revamp.engine.InferenceEngine
|
import com.example.llama.revamp.engine.InferenceEngine
|
||||||
import com.example.llama.revamp.ui.components.PerformanceAppScaffold
|
import com.example.llama.revamp.ui.components.PerformanceAppScaffold
|
||||||
import com.example.llama.revamp.viewmodel.MainViewModel
|
import com.example.llama.revamp.viewmodel.ConversationViewModel
|
||||||
import com.example.llama.revamp.viewmodel.Message
|
import com.example.llama.revamp.viewmodel.Message
|
||||||
import kotlinx.coroutines.launch
|
import kotlinx.coroutines.launch
|
||||||
|
|
||||||
|
|
@ -72,7 +72,7 @@ import kotlinx.coroutines.launch
|
||||||
@Composable
|
@Composable
|
||||||
fun ConversationScreen(
|
fun ConversationScreen(
|
||||||
onBackPressed: () -> Unit,
|
onBackPressed: () -> Unit,
|
||||||
viewModel: MainViewModel = hiltViewModel()
|
viewModel: ConversationViewModel = hiltViewModel()
|
||||||
) {
|
) {
|
||||||
val engineState by viewModel.engineState.collectAsState()
|
val engineState by viewModel.engineState.collectAsState()
|
||||||
val messages by viewModel.messages.collectAsState()
|
val messages by viewModel.messages.collectAsState()
|
||||||
|
|
|
||||||
|
|
@ -40,6 +40,7 @@ import androidx.compose.runtime.collectAsState
|
||||||
import androidx.compose.runtime.getValue
|
import androidx.compose.runtime.getValue
|
||||||
import androidx.compose.runtime.mutableStateOf
|
import androidx.compose.runtime.mutableStateOf
|
||||||
import androidx.compose.runtime.remember
|
import androidx.compose.runtime.remember
|
||||||
|
import androidx.compose.runtime.rememberCoroutineScope
|
||||||
import androidx.compose.runtime.setValue
|
import androidx.compose.runtime.setValue
|
||||||
import androidx.compose.ui.Alignment
|
import androidx.compose.ui.Alignment
|
||||||
import androidx.compose.ui.Modifier
|
import androidx.compose.ui.Modifier
|
||||||
|
|
@ -50,7 +51,10 @@ import androidx.hilt.navigation.compose.hiltViewModel
|
||||||
import com.example.llama.revamp.data.model.SystemPrompt
|
import com.example.llama.revamp.data.model.SystemPrompt
|
||||||
import com.example.llama.revamp.engine.InferenceEngine
|
import com.example.llama.revamp.engine.InferenceEngine
|
||||||
import com.example.llama.revamp.ui.components.PerformanceAppScaffold
|
import com.example.llama.revamp.ui.components.PerformanceAppScaffold
|
||||||
|
import com.example.llama.revamp.viewmodel.ModelLoadingViewModel
|
||||||
import com.example.llama.revamp.viewmodel.SystemPromptViewModel
|
import com.example.llama.revamp.viewmodel.SystemPromptViewModel
|
||||||
|
import kotlinx.coroutines.Job
|
||||||
|
import kotlinx.coroutines.launch
|
||||||
|
|
||||||
enum class SystemPromptTab {
|
enum class SystemPromptTab {
|
||||||
PRESETS, CUSTOM, RECENTS
|
PRESETS, CUSTOM, RECENTS
|
||||||
|
|
@ -59,14 +63,17 @@ enum class SystemPromptTab {
|
||||||
@OptIn(ExperimentalMaterial3Api::class, ExperimentalFoundationApi::class)
|
@OptIn(ExperimentalMaterial3Api::class, ExperimentalFoundationApi::class)
|
||||||
@Composable
|
@Composable
|
||||||
fun ModelLoadingScreen(
|
fun ModelLoadingScreen(
|
||||||
viewModel: SystemPromptViewModel = hiltViewModel(),
|
|
||||||
engineState: InferenceEngine.State,
|
engineState: InferenceEngine.State,
|
||||||
onBenchmarkSelected: () -> Unit,
|
onBenchmarkSelected: (prepareJob: Job) -> Unit,
|
||||||
onConversationSelected: (String?) -> Unit,
|
onConversationSelected: (systemPrompt: String?, prepareJob: Job) -> Unit,
|
||||||
onBackPressed: () -> Unit,
|
onBackPressed: () -> Unit,
|
||||||
|
modelLoadingViewModel: ModelLoadingViewModel = hiltViewModel(),
|
||||||
|
systemPromptViewModel: SystemPromptViewModel = hiltViewModel(),
|
||||||
) {
|
) {
|
||||||
val presetPrompts by viewModel.presetPrompts.collectAsState()
|
val coroutineScope = rememberCoroutineScope()
|
||||||
val recentPrompts by viewModel.recentPrompts.collectAsState()
|
|
||||||
|
val presetPrompts by systemPromptViewModel.presetPrompts.collectAsState()
|
||||||
|
val recentPrompts by systemPromptViewModel.recentPrompts.collectAsState()
|
||||||
|
|
||||||
var selectedMode by remember { mutableStateOf<Mode?>(null) }
|
var selectedMode by remember { mutableStateOf<Mode?>(null) }
|
||||||
var useSystemPrompt by remember { mutableStateOf(false) }
|
var useSystemPrompt by remember { mutableStateOf(false) }
|
||||||
|
|
@ -96,6 +103,21 @@ fun ModelLoadingScreen(
|
||||||
engineState !is InferenceEngine.State.LibraryLoaded &&
|
engineState !is InferenceEngine.State.LibraryLoaded &&
|
||||||
engineState !is InferenceEngine.State.AwaitingUserPrompt
|
engineState !is InferenceEngine.State.AwaitingUserPrompt
|
||||||
|
|
||||||
|
// Mode selection callbacks
|
||||||
|
val handleBenchmarkSelected = {
|
||||||
|
val prepareJob = coroutineScope.launch {
|
||||||
|
modelLoadingViewModel.prepareForBenchmark()
|
||||||
|
}
|
||||||
|
onBenchmarkSelected(prepareJob)
|
||||||
|
}
|
||||||
|
|
||||||
|
val handleConversationSelected = { systemPrompt: String? ->
|
||||||
|
val prepareJob = coroutineScope.launch {
|
||||||
|
modelLoadingViewModel.prepareForConversation(systemPrompt)
|
||||||
|
}
|
||||||
|
onConversationSelected(systemPrompt, prepareJob)
|
||||||
|
}
|
||||||
|
|
||||||
PerformanceAppScaffold(
|
PerformanceAppScaffold(
|
||||||
title = "Load Model",
|
title = "Load Model",
|
||||||
onNavigateBack = onBackPressed,
|
onNavigateBack = onBackPressed,
|
||||||
|
|
@ -143,7 +165,7 @@ fun ModelLoadingScreen(
|
||||||
modifier = Modifier
|
modifier = Modifier
|
||||||
.fillMaxWidth()
|
.fillMaxWidth()
|
||||||
.padding(bottom = 4.dp)
|
.padding(bottom = 4.dp)
|
||||||
// Only use weight if system prompt is active, otherwise wrap content
|
// Only fill height if system prompt is active
|
||||||
.then(if (useSystemPrompt) Modifier.weight(1f) else Modifier)
|
.then(if (useSystemPrompt) Modifier.weight(1f) else Modifier)
|
||||||
) {
|
) {
|
||||||
Column(
|
Column(
|
||||||
|
|
@ -355,14 +377,15 @@ fun ModelLoadingScreen(
|
||||||
Button(
|
Button(
|
||||||
onClick = {
|
onClick = {
|
||||||
when (selectedMode) {
|
when (selectedMode) {
|
||||||
Mode.BENCHMARK -> onBenchmarkSelected()
|
Mode.BENCHMARK -> handleBenchmarkSelected()
|
||||||
|
|
||||||
Mode.CONVERSATION -> {
|
Mode.CONVERSATION -> {
|
||||||
val systemPrompt = if (useSystemPrompt) {
|
val systemPrompt = if (useSystemPrompt) {
|
||||||
when (selectedTab) {
|
when (selectedTab) {
|
||||||
SystemPromptTab.PRESETS, SystemPromptTab.RECENTS ->
|
SystemPromptTab.PRESETS, SystemPromptTab.RECENTS ->
|
||||||
selectedPrompt?.let { prompt ->
|
selectedPrompt?.let { prompt ->
|
||||||
// Save the prompt to recent prompts database
|
// Save the prompt to recent prompts database
|
||||||
viewModel.savePromptToRecents(prompt)
|
systemPromptViewModel.savePromptToRecents(prompt)
|
||||||
prompt.content
|
prompt.content
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -370,15 +393,15 @@ fun ModelLoadingScreen(
|
||||||
customPromptText.takeIf { it.isNotBlank() }
|
customPromptText.takeIf { it.isNotBlank() }
|
||||||
?.also { promptText ->
|
?.also { promptText ->
|
||||||
// Save custom prompt to database
|
// Save custom prompt to database
|
||||||
viewModel.saveCustomPromptToRecents(promptText)
|
systemPromptViewModel.saveCustomPromptToRecents(promptText)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else null
|
} else null
|
||||||
onConversationSelected(systemPrompt)
|
|
||||||
|
handleConversationSelected(systemPrompt)
|
||||||
}
|
}
|
||||||
|
|
||||||
null -> { /* No mode selected */
|
null -> { /* No mode selected */ }
|
||||||
}
|
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
modifier = Modifier
|
modifier = Modifier
|
||||||
|
|
|
||||||
|
|
@ -30,7 +30,7 @@ import com.example.llama.revamp.data.model.ModelInfo
|
||||||
import com.example.llama.revamp.ui.components.ModelCard
|
import com.example.llama.revamp.ui.components.ModelCard
|
||||||
import com.example.llama.revamp.ui.components.ModelCardActions
|
import com.example.llama.revamp.ui.components.ModelCardActions
|
||||||
import com.example.llama.revamp.ui.components.PerformanceAppScaffold
|
import com.example.llama.revamp.ui.components.PerformanceAppScaffold
|
||||||
import com.example.llama.revamp.viewmodel.MainViewModel
|
import com.example.llama.revamp.viewmodel.ModelSelectionViewModel
|
||||||
|
|
||||||
@OptIn(ExperimentalMaterial3Api::class)
|
@OptIn(ExperimentalMaterial3Api::class)
|
||||||
@Composable
|
@Composable
|
||||||
|
|
@ -38,10 +38,15 @@ fun ModelSelectionScreen(
|
||||||
onModelSelected: (ModelInfo) -> Unit,
|
onModelSelected: (ModelInfo) -> Unit,
|
||||||
onManageModelsClicked: () -> Unit,
|
onManageModelsClicked: () -> Unit,
|
||||||
onMenuClicked: () -> Unit,
|
onMenuClicked: () -> Unit,
|
||||||
viewModel: MainViewModel = hiltViewModel(),
|
viewModel: ModelSelectionViewModel = hiltViewModel(),
|
||||||
) {
|
) {
|
||||||
val models by viewModel.availableModels.collectAsState()
|
val models by viewModel.availableModels.collectAsState()
|
||||||
|
|
||||||
|
val handleModelSelection = { model: ModelInfo ->
|
||||||
|
viewModel.selectModel(model)
|
||||||
|
onModelSelected(model)
|
||||||
|
}
|
||||||
|
|
||||||
PerformanceAppScaffold(
|
PerformanceAppScaffold(
|
||||||
title = "Models",
|
title = "Models",
|
||||||
onMenuOpen = onMenuClicked,
|
onMenuOpen = onMenuClicked,
|
||||||
|
|
@ -60,11 +65,13 @@ fun ModelSelectionScreen(
|
||||||
items(models) { model ->
|
items(models) { model ->
|
||||||
ModelCard(
|
ModelCard(
|
||||||
model = model,
|
model = model,
|
||||||
onClick = { onModelSelected(model) },
|
onClick = { handleModelSelection(model) },
|
||||||
modifier = Modifier.padding(vertical = 4.dp),
|
modifier = Modifier.padding(vertical = 4.dp),
|
||||||
isSelected = null, // Not in selection mode
|
isSelected = null, // Not in selection mode
|
||||||
actionButton = {
|
actionButton = {
|
||||||
ModelCardActions.PlayButton(onClick = { onModelSelected(model) })
|
ModelCardActions.PlayButton {
|
||||||
|
handleModelSelection(model)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
Spacer(modifier = Modifier.height(8.dp))
|
Spacer(modifier = Modifier.height(8.dp))
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,29 @@
|
||||||
|
package com.example.llama.revamp.viewmodel
|
||||||
|
|
||||||
|
import androidx.lifecycle.ViewModel
|
||||||
|
import androidx.lifecycle.viewModelScope
|
||||||
|
import com.example.llama.revamp.data.model.ModelInfo
|
||||||
|
import com.example.llama.revamp.engine.InferenceEngine
|
||||||
|
import com.example.llama.revamp.engine.InferenceManager
|
||||||
|
import dagger.hilt.android.lifecycle.HiltViewModel
|
||||||
|
import kotlinx.coroutines.flow.StateFlow
|
||||||
|
import kotlinx.coroutines.launch
|
||||||
|
import javax.inject.Inject
|
||||||
|
|
||||||
|
@HiltViewModel
|
||||||
|
class BenchmarkViewModel @Inject constructor(
|
||||||
|
private val inferenceManager: InferenceManager
|
||||||
|
) : ViewModel() {
|
||||||
|
|
||||||
|
val engineState: StateFlow<InferenceEngine.State> = inferenceManager.engineState
|
||||||
|
val benchmarkResults: StateFlow<String?> = inferenceManager.benchmarkResults
|
||||||
|
val selectedModel: StateFlow<ModelInfo?> = inferenceManager.currentModel
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Run benchmark with specified parameters
|
||||||
|
*/
|
||||||
|
fun runBenchmark(pp: Int = 512, tg: Int = 128, pl: Int = 1, nr: Int = 3) =
|
||||||
|
viewModelScope.launch {
|
||||||
|
inferenceManager.benchmark(pp, tg, pl, nr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,166 @@
|
||||||
|
package com.example.llama.revamp.viewmodel
|
||||||
|
|
||||||
|
import androidx.lifecycle.ViewModel
|
||||||
|
import androidx.lifecycle.viewModelScope
|
||||||
|
import com.example.llama.revamp.engine.InferenceManager
|
||||||
|
import com.example.llama.revamp.engine.TokenMetrics
|
||||||
|
import dagger.hilt.android.lifecycle.HiltViewModel
|
||||||
|
import kotlinx.coroutines.Job
|
||||||
|
import kotlinx.coroutines.flow.MutableStateFlow
|
||||||
|
import kotlinx.coroutines.flow.StateFlow
|
||||||
|
import kotlinx.coroutines.flow.asStateFlow
|
||||||
|
import kotlinx.coroutines.launch
|
||||||
|
import java.text.SimpleDateFormat
|
||||||
|
import java.util.Date
|
||||||
|
import java.util.Locale
|
||||||
|
import javax.inject.Inject
|
||||||
|
import kotlin.getValue
|
||||||
|
|
||||||
|
|
||||||
|
@HiltViewModel
|
||||||
|
class ConversationViewModel @Inject constructor(
|
||||||
|
private val inferenceManager: InferenceManager
|
||||||
|
) : ViewModel() {
|
||||||
|
|
||||||
|
val engineState = inferenceManager.engineState
|
||||||
|
val selectedModel = inferenceManager.currentModel
|
||||||
|
val systemPrompt = inferenceManager.systemPrompt
|
||||||
|
|
||||||
|
// Messages in conversation
|
||||||
|
private val _messages = MutableStateFlow<List<Message>>(emptyList())
|
||||||
|
val messages: StateFlow<List<Message>> = _messages.asStateFlow()
|
||||||
|
|
||||||
|
// Token generation job
|
||||||
|
private var tokenCollectionJob: Job? = null
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Send a message with the provided content.
|
||||||
|
* Note: This matches the existing UI which manages input state outside the ViewModel.
|
||||||
|
*/
|
||||||
|
fun sendMessage(content: String) {
|
||||||
|
if (content.isBlank()) return
|
||||||
|
|
||||||
|
// Cancel ongoing collection
|
||||||
|
tokenCollectionJob?.cancel()
|
||||||
|
|
||||||
|
// Add user message
|
||||||
|
val userMessage = Message.User(
|
||||||
|
content = content,
|
||||||
|
timestamp = System.currentTimeMillis()
|
||||||
|
)
|
||||||
|
_messages.value = _messages.value + userMessage
|
||||||
|
|
||||||
|
// Add placeholder for assistant response
|
||||||
|
val assistantMessage = Message.Assistant.Ongoing(
|
||||||
|
content = "",
|
||||||
|
timestamp = System.currentTimeMillis()
|
||||||
|
)
|
||||||
|
_messages.value = _messages.value + assistantMessage
|
||||||
|
|
||||||
|
// Collect response
|
||||||
|
tokenCollectionJob = viewModelScope.launch {
|
||||||
|
try {
|
||||||
|
inferenceManager.generateResponse(content)
|
||||||
|
.collect { (text, isComplete) ->
|
||||||
|
updateAssistantMessage(text, isComplete)
|
||||||
|
}
|
||||||
|
} catch (e: Exception) {
|
||||||
|
// Handle error
|
||||||
|
handleResponseError(e)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Handle updating the assistant message
|
||||||
|
*/
|
||||||
|
private fun updateAssistantMessage(text: String, isComplete: Boolean) {
|
||||||
|
val currentMessages = _messages.value.toMutableList()
|
||||||
|
val lastIndex = currentMessages.size - 1
|
||||||
|
val currentAssistantMessage = currentMessages.getOrNull(lastIndex) as? Message.Assistant.Ongoing
|
||||||
|
|
||||||
|
if (currentAssistantMessage != null) {
|
||||||
|
if (isComplete) {
|
||||||
|
// Final message with metrics
|
||||||
|
currentMessages[lastIndex] = Message.Assistant.Completed(
|
||||||
|
content = text,
|
||||||
|
timestamp = currentAssistantMessage.timestamp,
|
||||||
|
metrics = inferenceManager.createTokenMetrics()
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
// Ongoing message update
|
||||||
|
currentMessages[lastIndex] = Message.Assistant.Ongoing(
|
||||||
|
content = text,
|
||||||
|
timestamp = currentAssistantMessage.timestamp
|
||||||
|
)
|
||||||
|
}
|
||||||
|
_messages.value = currentMessages
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Handle response error
|
||||||
|
*/
|
||||||
|
private fun handleResponseError(e: Exception) {
|
||||||
|
val currentMessages = _messages.value.toMutableList()
|
||||||
|
val lastIndex = currentMessages.size - 1
|
||||||
|
val currentAssistantMessage = currentMessages.getOrNull(lastIndex) as? Message.Assistant.Ongoing
|
||||||
|
|
||||||
|
if (currentAssistantMessage != null) {
|
||||||
|
currentMessages[lastIndex] = Message.Assistant.Completed(
|
||||||
|
content = "${currentAssistantMessage.content}[Error: ${e.message}]",
|
||||||
|
timestamp = currentAssistantMessage.timestamp,
|
||||||
|
metrics = inferenceManager.createTokenMetrics()
|
||||||
|
)
|
||||||
|
_messages.value = currentMessages
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Clear conversation
|
||||||
|
*/
|
||||||
|
fun clearConversation() {
|
||||||
|
tokenCollectionJob?.cancel()
|
||||||
|
_messages.value = emptyList()
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun onCleared() {
|
||||||
|
tokenCollectionJob?.cancel()
|
||||||
|
super.onCleared()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Sealed class representing messages in a conversation.
|
||||||
|
*/
|
||||||
|
sealed class Message {
|
||||||
|
abstract val timestamp: Long
|
||||||
|
abstract val content: String
|
||||||
|
|
||||||
|
val formattedTime: String
|
||||||
|
get() = datetimeFormatter.format(Date(timestamp))
|
||||||
|
|
||||||
|
data class User(
|
||||||
|
override val timestamp: Long,
|
||||||
|
override val content: String
|
||||||
|
) : Message()
|
||||||
|
|
||||||
|
sealed class Assistant : Message() {
|
||||||
|
data class Ongoing(
|
||||||
|
override val timestamp: Long,
|
||||||
|
override val content: String,
|
||||||
|
) : Assistant()
|
||||||
|
|
||||||
|
data class Completed(
|
||||||
|
override val timestamp: Long,
|
||||||
|
override val content: String,
|
||||||
|
val metrics: TokenMetrics
|
||||||
|
) : Assistant()
|
||||||
|
}
|
||||||
|
|
||||||
|
companion object {
|
||||||
|
private val datetimeFormatter by lazy { SimpleDateFormat("h:mm a", Locale.getDefault()) }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
@ -1,276 +1,25 @@
|
||||||
package com.example.llama.revamp.viewmodel
|
package com.example.llama.revamp.viewmodel
|
||||||
|
|
||||||
import androidx.lifecycle.ViewModel
|
import androidx.lifecycle.ViewModel
|
||||||
import androidx.lifecycle.viewModelScope
|
|
||||||
import com.example.llama.revamp.data.model.ModelInfo
|
|
||||||
import com.example.llama.revamp.data.repository.ModelRepository
|
|
||||||
import com.example.llama.revamp.engine.InferenceEngine
|
import com.example.llama.revamp.engine.InferenceEngine
|
||||||
|
import com.example.llama.revamp.engine.InferenceManager
|
||||||
import dagger.hilt.android.lifecycle.HiltViewModel
|
import dagger.hilt.android.lifecycle.HiltViewModel
|
||||||
import kotlinx.coroutines.Job
|
|
||||||
import kotlinx.coroutines.flow.MutableStateFlow
|
|
||||||
import kotlinx.coroutines.flow.SharingStarted
|
|
||||||
import kotlinx.coroutines.flow.StateFlow
|
|
||||||
import kotlinx.coroutines.flow.asStateFlow
|
|
||||||
import kotlinx.coroutines.flow.catch
|
|
||||||
import kotlinx.coroutines.flow.onCompletion
|
|
||||||
import kotlinx.coroutines.flow.stateIn
|
|
||||||
import kotlinx.coroutines.launch
|
|
||||||
import java.text.SimpleDateFormat
|
|
||||||
import java.util.Date
|
|
||||||
import java.util.Locale
|
|
||||||
import javax.inject.Inject
|
import javax.inject.Inject
|
||||||
|
|
||||||
/**
|
|
||||||
* Main ViewModel that handles the LLM engine state and operations.
|
|
||||||
*/
|
|
||||||
@HiltViewModel
|
@HiltViewModel
|
||||||
|
/**
|
||||||
|
* Main ViewModel that expose the core states of [InferenceEngine]
|
||||||
|
*/
|
||||||
class MainViewModel @Inject constructor (
|
class MainViewModel @Inject constructor (
|
||||||
private val inferenceEngine: InferenceEngine,
|
private val inferenceManager: InferenceManager,
|
||||||
private val modelRepository: ModelRepository,
|
|
||||||
) : ViewModel() {
|
) : ViewModel() {
|
||||||
|
|
||||||
// Expose the engine state
|
val engineState = inferenceManager.engineState
|
||||||
val engineState: StateFlow<InferenceEngine.State> = inferenceEngine.state
|
|
||||||
|
|
||||||
// Benchmark results
|
|
||||||
val benchmarkResults: StateFlow<String?> = inferenceEngine.benchmarkResults
|
|
||||||
|
|
||||||
// Available models for selection
|
|
||||||
val availableModels: StateFlow<List<ModelInfo>> = modelRepository.getModels()
|
|
||||||
.stateIn(
|
|
||||||
scope = viewModelScope,
|
|
||||||
started = SharingStarted.WhileSubscribed(SUBSCRIPTION_TIMEOUT_MS),
|
|
||||||
initialValue = emptyList()
|
|
||||||
)
|
|
||||||
|
|
||||||
// Selected model information
|
|
||||||
private val _selectedModel = MutableStateFlow<ModelInfo?>(null)
|
|
||||||
val selectedModel: StateFlow<ModelInfo?> = _selectedModel.asStateFlow()
|
|
||||||
|
|
||||||
// Messages in the conversation
|
|
||||||
private val _messages = MutableStateFlow<List<Message>>(emptyList())
|
|
||||||
val messages: StateFlow<List<Message>> = _messages.asStateFlow()
|
|
||||||
|
|
||||||
// System prompt for the conversation
|
|
||||||
private val _systemPrompt = MutableStateFlow<String?>(null)
|
|
||||||
val systemPrompt: StateFlow<String?> = _systemPrompt.asStateFlow()
|
|
||||||
|
|
||||||
// Flag to track if token collection is active
|
|
||||||
private var tokenCollectionJob: Job? = null
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Selects a model for use.
|
* Unload the current model and release the resources
|
||||||
*/
|
*/
|
||||||
fun selectModel(modelInfo: ModelInfo) {
|
suspend fun unloadModel() = inferenceManager.unloadModel()
|
||||||
_selectedModel.value = modelInfo
|
|
||||||
|
|
||||||
viewModelScope.launch {
|
|
||||||
modelRepository.updateModelLastUsed(modelInfo.id)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Prepares the engine for benchmark mode.
|
|
||||||
*/
|
|
||||||
suspend fun prepareForBenchmark() {
|
|
||||||
_selectedModel.value?.let { model ->
|
|
||||||
inferenceEngine.loadModel(model.path)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Runs the benchmark with current parameters.
|
|
||||||
*/
|
|
||||||
suspend fun runBenchmark() = inferenceEngine.bench(512, 128, 1, 3)
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Reruns the benchmark.
|
|
||||||
*/
|
|
||||||
fun rerunBenchmark() = viewModelScope.launch { runBenchmark() }
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Prepares the engine for conversation mode.
|
|
||||||
*/
|
|
||||||
suspend fun prepareForConversation(systemPrompt: String? = null) {
|
|
||||||
_systemPrompt.value = systemPrompt
|
|
||||||
_selectedModel.value?.let { model ->
|
|
||||||
inferenceEngine.loadModel(model.path, systemPrompt)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Tracks token generation metrics
|
|
||||||
*/
|
|
||||||
private var generationStartTime: Long = 0L
|
|
||||||
private var firstTokenTime: Long = 0L
|
|
||||||
private var tokenCount: Int = 0
|
|
||||||
private var isFirstToken: Boolean = true
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Sends a user message and collects the response.
|
|
||||||
*/
|
|
||||||
fun sendMessage(content: String) {
|
|
||||||
if (content.isBlank()) return
|
|
||||||
|
|
||||||
// Cancel any ongoing token collection
|
|
||||||
tokenCollectionJob?.cancel()
|
|
||||||
|
|
||||||
// Add user message
|
|
||||||
val userMessage = Message.User(
|
|
||||||
content = content,
|
|
||||||
timestamp = System.currentTimeMillis()
|
|
||||||
)
|
|
||||||
_messages.value = _messages.value + userMessage
|
|
||||||
|
|
||||||
// Create placeholder for assistant message
|
|
||||||
val assistantMessage = Message.Assistant.Ongoing(
|
|
||||||
content = "",
|
|
||||||
timestamp = System.currentTimeMillis()
|
|
||||||
)
|
|
||||||
_messages.value = _messages.value + assistantMessage
|
|
||||||
|
|
||||||
// Reset metrics tracking
|
|
||||||
generationStartTime = System.currentTimeMillis()
|
|
||||||
firstTokenTime = 0L
|
|
||||||
tokenCount = 0
|
|
||||||
isFirstToken = true
|
|
||||||
|
|
||||||
// Get response from engine
|
|
||||||
tokenCollectionJob = viewModelScope.launch {
|
|
||||||
val response = StringBuilder()
|
|
||||||
|
|
||||||
try {
|
|
||||||
inferenceEngine.sendUserPrompt(content)
|
|
||||||
.catch { e ->
|
|
||||||
// Handle errors during token collection
|
|
||||||
val currentMessages = _messages.value.toMutableList()
|
|
||||||
if (currentMessages.size >= 2) {
|
|
||||||
val messageIndex = currentMessages.size - 1
|
|
||||||
val currentAssistantMessage = currentMessages[messageIndex] as? Message.Assistant.Ongoing
|
|
||||||
if (currentAssistantMessage != null) {
|
|
||||||
// Create metrics with error indication
|
|
||||||
val errorMetrics = TokenMetrics(
|
|
||||||
tokensCount = tokenCount,
|
|
||||||
ttftMs = if (firstTokenTime > 0) firstTokenTime - generationStartTime else 0L,
|
|
||||||
tpsMs = calculateTPS(tokenCount, System.currentTimeMillis() - generationStartTime)
|
|
||||||
)
|
|
||||||
|
|
||||||
currentMessages[messageIndex] = Message.Assistant.Completed(
|
|
||||||
content = "${response}[Error: ${e.message}]",
|
|
||||||
timestamp = currentAssistantMessage.timestamp,
|
|
||||||
metrics = errorMetrics
|
|
||||||
)
|
|
||||||
_messages.value = currentMessages
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
.onCompletion { cause ->
|
|
||||||
// Handle completion (normal or cancelled)
|
|
||||||
val currentMessages = _messages.value.toMutableList()
|
|
||||||
if (currentMessages.isNotEmpty()) {
|
|
||||||
val messageIndex = currentMessages.size - 1
|
|
||||||
val currentAssistantMessage = currentMessages.getOrNull(messageIndex) as? Message.Assistant.Ongoing
|
|
||||||
if (currentAssistantMessage != null) {
|
|
||||||
// Calculate final metrics
|
|
||||||
val endTime = System.currentTimeMillis()
|
|
||||||
val totalTimeMs = endTime - generationStartTime
|
|
||||||
|
|
||||||
val metrics = TokenMetrics(
|
|
||||||
tokensCount = tokenCount,
|
|
||||||
ttftMs = if (firstTokenTime > 0) firstTokenTime - generationStartTime else 0L,
|
|
||||||
tpsMs = calculateTPS(tokenCount, totalTimeMs)
|
|
||||||
)
|
|
||||||
|
|
||||||
currentMessages[messageIndex] = Message.Assistant.Completed(
|
|
||||||
content = response.toString(),
|
|
||||||
timestamp = currentAssistantMessage.timestamp,
|
|
||||||
metrics = metrics
|
|
||||||
)
|
|
||||||
_messages.value = currentMessages
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
.collect { token ->
|
|
||||||
// Track first token time
|
|
||||||
if (isFirstToken && token.isNotBlank()) {
|
|
||||||
firstTokenTime = System.currentTimeMillis()
|
|
||||||
isFirstToken = false
|
|
||||||
}
|
|
||||||
|
|
||||||
// Count tokens - each non-empty emission is at least one token
|
|
||||||
if (token.isNotBlank()) {
|
|
||||||
tokenCount++
|
|
||||||
}
|
|
||||||
|
|
||||||
response.append(token)
|
|
||||||
|
|
||||||
// Safely update the assistant message with the generated text
|
|
||||||
val currentMessages = _messages.value.toMutableList()
|
|
||||||
if (currentMessages.isNotEmpty()) {
|
|
||||||
val messageIndex = currentMessages.size - 1
|
|
||||||
val currentAssistantMessage = currentMessages.getOrNull(messageIndex) as? Message.Assistant.Ongoing
|
|
||||||
if (currentAssistantMessage != null) {
|
|
||||||
currentMessages[messageIndex] = Message.Assistant.Ongoing(
|
|
||||||
content = response.toString(),
|
|
||||||
timestamp = currentAssistantMessage.timestamp
|
|
||||||
)
|
|
||||||
_messages.value = currentMessages
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} catch (e: Exception) {
|
|
||||||
// Handle any unexpected exceptions
|
|
||||||
val currentMessages = _messages.value.toMutableList()
|
|
||||||
if (currentMessages.isNotEmpty()) {
|
|
||||||
val messageIndex = currentMessages.size - 1
|
|
||||||
val currentAssistantMessage = currentMessages.getOrNull(messageIndex) as? Message.Assistant.Ongoing
|
|
||||||
if (currentAssistantMessage != null) {
|
|
||||||
// Create metrics with error indication
|
|
||||||
val errorMetrics = TokenMetrics(
|
|
||||||
tokensCount = tokenCount,
|
|
||||||
ttftMs = if (firstTokenTime > 0) firstTokenTime - generationStartTime else 0L,
|
|
||||||
tpsMs = calculateTPS(tokenCount, System.currentTimeMillis() - generationStartTime)
|
|
||||||
)
|
|
||||||
|
|
||||||
currentMessages[messageIndex] = Message.Assistant.Completed(
|
|
||||||
content = "${response}[Error: ${e.message}]",
|
|
||||||
timestamp = currentAssistantMessage.timestamp,
|
|
||||||
metrics = errorMetrics
|
|
||||||
)
|
|
||||||
_messages.value = currentMessages
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Calculate tokens per second.
|
|
||||||
*/
|
|
||||||
private fun calculateTPS(tokens: Int, timeMs: Long): Float {
|
|
||||||
if (tokens <= 0 || timeMs <= 0) return 0f
|
|
||||||
return (tokens.toFloat() * 1000f) / timeMs
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Unloads the currently loaded model after cleanup chores:
|
|
||||||
* - Cancel any ongoing token collection
|
|
||||||
* - Clear messages
|
|
||||||
*/
|
|
||||||
suspend fun unloadModel() {
|
|
||||||
tokenCollectionJob?.cancel()
|
|
||||||
_messages.value = emptyList()
|
|
||||||
|
|
||||||
inferenceEngine.unloadModel()
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Clean up resources when ViewModel is cleared.
|
|
||||||
*/
|
|
||||||
override fun onCleared() {
|
|
||||||
inferenceEngine.destroy()
|
|
||||||
super.onCleared()
|
|
||||||
}
|
|
||||||
|
|
||||||
companion object {
|
companion object {
|
||||||
private val TAG = MainViewModel::class.java.simpleName
|
private val TAG = MainViewModel::class.java.simpleName
|
||||||
|
|
@ -279,44 +28,3 @@ class MainViewModel @Inject constructor (
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* Sealed class representing messages in a conversation.
|
|
||||||
*/
|
|
||||||
sealed class Message {
|
|
||||||
abstract val timestamp: Long
|
|
||||||
abstract val content: String
|
|
||||||
|
|
||||||
val formattedTime: String
|
|
||||||
get() = datetimeFormatter.format(Date(timestamp))
|
|
||||||
|
|
||||||
data class User(
|
|
||||||
override val timestamp: Long,
|
|
||||||
override val content: String
|
|
||||||
) : Message()
|
|
||||||
|
|
||||||
sealed class Assistant : Message() {
|
|
||||||
data class Ongoing(
|
|
||||||
override val timestamp: Long,
|
|
||||||
override val content: String,
|
|
||||||
) : Assistant()
|
|
||||||
|
|
||||||
data class Completed(
|
|
||||||
override val timestamp: Long,
|
|
||||||
override val content: String,
|
|
||||||
val metrics: TokenMetrics
|
|
||||||
) : Assistant()
|
|
||||||
}
|
|
||||||
|
|
||||||
companion object {
|
|
||||||
private val datetimeFormatter by lazy { SimpleDateFormat("h:mm a", Locale.getDefault()) }
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
data class TokenMetrics(
|
|
||||||
val tokensCount: Int,
|
|
||||||
val ttftMs: Long,
|
|
||||||
val tpsMs: Float,
|
|
||||||
) {
|
|
||||||
val text: String
|
|
||||||
get() = "Tokens: $tokensCount, TTFT: ${ttftMs}ms, TPS: ${"%.1f".format(tpsMs)}"
|
|
||||||
}
|
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,27 @@
|
||||||
|
package com.example.llama.revamp.viewmodel
|
||||||
|
|
||||||
|
import androidx.lifecycle.ViewModel
|
||||||
|
import com.example.llama.revamp.engine.InferenceManager
|
||||||
|
import dagger.hilt.android.lifecycle.HiltViewModel
|
||||||
|
import javax.inject.Inject
|
||||||
|
|
||||||
|
@HiltViewModel
|
||||||
|
class ModelLoadingViewModel @Inject constructor(
|
||||||
|
private val inferenceManager: InferenceManager
|
||||||
|
) : ViewModel() {
|
||||||
|
|
||||||
|
val engineState = inferenceManager.engineState
|
||||||
|
val selectedModel = inferenceManager.currentModel
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Prepares the engine for benchmark mode.
|
||||||
|
*/
|
||||||
|
suspend fun prepareForBenchmark() =
|
||||||
|
inferenceManager.loadModelForBenchmark()
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Prepare for conversation
|
||||||
|
*/
|
||||||
|
suspend fun prepareForConversation(systemPrompt: String? = null) =
|
||||||
|
inferenceManager.loadModelForConversation(systemPrompt)
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,58 @@
|
||||||
|
package com.example.llama.revamp.viewmodel
|
||||||
|
|
||||||
|
import androidx.lifecycle.ViewModel
|
||||||
|
import androidx.lifecycle.viewModelScope
|
||||||
|
import com.example.llama.revamp.data.model.ModelInfo
|
||||||
|
import com.example.llama.revamp.data.repository.ModelRepository
|
||||||
|
import com.example.llama.revamp.engine.InferenceManager
|
||||||
|
import dagger.hilt.android.lifecycle.HiltViewModel
|
||||||
|
import kotlinx.coroutines.flow.SharingStarted
|
||||||
|
import kotlinx.coroutines.flow.StateFlow
|
||||||
|
import kotlinx.coroutines.flow.stateIn
|
||||||
|
import kotlinx.coroutines.launch
|
||||||
|
import javax.inject.Inject
|
||||||
|
|
||||||
|
|
||||||
|
@HiltViewModel
|
||||||
|
class ModelSelectionViewModel @Inject constructor(
|
||||||
|
private val inferenceManager: InferenceManager,
|
||||||
|
private val modelRepository: ModelRepository
|
||||||
|
) : ViewModel() {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Available models for selection
|
||||||
|
*/
|
||||||
|
val availableModels: StateFlow<List<ModelInfo>> = modelRepository.getModels()
|
||||||
|
.stateIn(
|
||||||
|
scope = viewModelScope,
|
||||||
|
started = SharingStarted.WhileSubscribed(SUBSCRIPTION_TIMEOUT_MS),
|
||||||
|
initialValue = emptyList()
|
||||||
|
)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Access to currently selected model
|
||||||
|
*/
|
||||||
|
val selectedModel = inferenceManager.currentModel
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Select a model and update its last used timestamp
|
||||||
|
*/
|
||||||
|
fun selectModel(modelInfo: ModelInfo) {
|
||||||
|
inferenceManager.setCurrentModel(modelInfo)
|
||||||
|
|
||||||
|
viewModelScope.launch {
|
||||||
|
modelRepository.updateModelLastUsed(modelInfo.id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Unload model when navigating away
|
||||||
|
*/
|
||||||
|
suspend fun unloadModel() = inferenceManager.unloadModel()
|
||||||
|
|
||||||
|
companion object {
|
||||||
|
private val TAG = ModelSelectionViewModel::class.java.simpleName
|
||||||
|
|
||||||
|
private const val SUBSCRIPTION_TIMEOUT_MS = 5000L
|
||||||
|
}
|
||||||
|
}
|
||||||
Loading…
Reference in New Issue