navigation: sink model loading state management from AppContent down into ModelLoadingScreen; pass ModelLoadingMetrics to Benchmark and Conversation screens

This commit is contained in:
Han Yin 2025-04-18 16:46:25 -07:00
parent 8a682ff85d
commit a9466c0370
6 changed files with 90 additions and 63 deletions

View File

@ -21,9 +21,12 @@ import androidx.compose.runtime.remember
import androidx.compose.runtime.rememberCoroutineScope import androidx.compose.runtime.rememberCoroutineScope
import androidx.compose.ui.Modifier import androidx.compose.ui.Modifier
import androidx.hilt.navigation.compose.hiltViewModel import androidx.hilt.navigation.compose.hiltViewModel
import androidx.navigation.NavType
import androidx.navigation.compose.composable import androidx.navigation.compose.composable
import androidx.navigation.compose.currentBackStackEntryAsState import androidx.navigation.compose.currentBackStackEntryAsState
import androidx.navigation.compose.rememberNavController import androidx.navigation.compose.rememberNavController
import androidx.navigation.navArgument
import com.example.llama.revamp.engine.ModelLoadingMetrics
import com.example.llama.revamp.navigation.AppDestinations import com.example.llama.revamp.navigation.AppDestinations
import com.example.llama.revamp.navigation.NavigationActions import com.example.llama.revamp.navigation.NavigationActions
import com.example.llama.revamp.ui.components.AnimatedNavHost import com.example.llama.revamp.ui.components.AnimatedNavHost
@ -48,7 +51,6 @@ import com.example.llama.revamp.viewmodel.ModelLoadingViewModel
import com.example.llama.revamp.viewmodel.ModelsManagementViewModel import com.example.llama.revamp.viewmodel.ModelsManagementViewModel
import com.example.llama.revamp.viewmodel.PerformanceViewModel import com.example.llama.revamp.viewmodel.PerformanceViewModel
import dagger.hilt.android.AndroidEntryPoint import dagger.hilt.android.AndroidEntryPoint
import kotlinx.coroutines.isActive
import kotlinx.coroutines.launch import kotlinx.coroutines.launch
@AndroidEntryPoint @AndroidEntryPoint
@ -110,9 +112,9 @@ fun AppContent(
val openDrawer: () -> Unit = { coroutineScope.launch { drawerState.open() } } val openDrawer: () -> Unit = { coroutineScope.launch { drawerState.open() } }
// Create scaffold's top & bottom bar configs based on current route // Create scaffold's top & bottom bar configs based on current route
val scaffoldConfig = when (currentRoute) { val scaffoldConfig = when {
// Model selection screen // Model selection screen
AppDestinations.MODEL_SELECTION_ROUTE -> currentRoute == AppDestinations.MODEL_SELECTION_ROUTE ->
ScaffoldConfig( ScaffoldConfig(
topBarConfig = TopBarConfig.Default( topBarConfig = TopBarConfig.Default(
title = "Models", title = "Models",
@ -121,7 +123,7 @@ fun AppContent(
) )
// Model loading screen // Model loading screen
AppDestinations.MODEL_LOADING_ROUTE -> currentRoute == AppDestinations.MODEL_LOADING_ROUTE ->
ScaffoldConfig( ScaffoldConfig(
topBarConfig = TopBarConfig.Performance( topBarConfig = TopBarConfig.Performance(
title = "Load Model", title = "Load Model",
@ -134,7 +136,7 @@ fun AppContent(
) )
// Benchmark screen // Benchmark screen
AppDestinations.BENCHMARK_ROUTE -> currentRoute.startsWith(AppDestinations.BENCHMARK_ROUTE) ->
ScaffoldConfig( ScaffoldConfig(
topBarConfig = TopBarConfig.Performance( topBarConfig = TopBarConfig.Performance(
title = "Benchmark", title = "Benchmark",
@ -147,7 +149,7 @@ fun AppContent(
) )
// Conversation screen // Conversation screen
AppDestinations.CONVERSATION_ROUTE -> currentRoute.startsWith(AppDestinations.CONVERSATION_ROUTE) ->
ScaffoldConfig( ScaffoldConfig(
topBarConfig = TopBarConfig.Performance( topBarConfig = TopBarConfig.Performance(
title = "Chat", title = "Chat",
@ -160,7 +162,7 @@ fun AppContent(
) )
// Settings screen // Settings screen
AppDestinations.SETTINGS_GENERAL_ROUTE -> currentRoute == AppDestinations.SETTINGS_GENERAL_ROUTE ->
ScaffoldConfig( ScaffoldConfig(
topBarConfig = TopBarConfig.Default( topBarConfig = TopBarConfig.Default(
title = "Settings", title = "Settings",
@ -169,7 +171,7 @@ fun AppContent(
) )
// Storage management screen // Storage management screen
AppDestinations.MODELS_MANAGEMENT_ROUTE -> { currentRoute == AppDestinations.MODELS_MANAGEMENT_ROUTE -> {
// Collect the needed states // Collect the needed states
val sortOrder by modelsManagementViewModel.sortOrder.collectAsState() val sortOrder by modelsManagementViewModel.sortOrder.collectAsState()
val isMultiSelectionMode by modelsManagementViewModel.isMultiSelectionMode.collectAsState() val isMultiSelectionMode by modelsManagementViewModel.isMultiSelectionMode.collectAsState()
@ -301,35 +303,59 @@ fun AppContent(
composable(AppDestinations.MODEL_LOADING_ROUTE) { composable(AppDestinations.MODEL_LOADING_ROUTE) {
ModelLoadingScreen( ModelLoadingScreen(
onNavigateBack = { navigationActions.navigateUp() }, onNavigateBack = { navigationActions.navigateUp() },
onBenchmarkSelected = { prepareJob -> onNavigateToBenchmark = { navigationActions.navigateToBenchmark(it) },
// Wait for preparation to complete, then navigate if still active onNavigateToConversation = { navigationActions.navigateToConversation(it) },
coroutineScope.launch {
prepareJob.join()
if (isActive) { navigationActions.navigateToBenchmark() }
}
},
onConversationSelected = { systemPrompt, prepareJob ->
// Wait for preparation to complete, then navigate if still active
coroutineScope.launch {
prepareJob.join()
if (isActive) { navigationActions.navigateToConversation() }
}
},
viewModel = modelLoadingViewModel viewModel = modelLoadingViewModel
) )
} }
// Benchmark Screen // Benchmark Screen
composable(AppDestinations.BENCHMARK_ROUTE) { composable(
route = AppDestinations.BENCHMARK_ROUTE_WITH_PARAMS,
arguments = listOf(
navArgument("modelLoadTimeMs") {
type = NavType.LongType
defaultValue = 0L
}
)
) { backStackEntry ->
val modelLoadTimeMs = backStackEntry.arguments?.getLong("modelLoadTimeMs") ?: 0L
val metrics = if (modelLoadTimeMs > 0) {
ModelLoadingMetrics(modelLoadTimeMs)
} else throw IllegalArgumentException("Expecting a valid ModelLoadingMetrics!")
BenchmarkScreen( BenchmarkScreen(
loadingMetrics = metrics,
onNavigateBack = { navigationActions.navigateUp() }, onNavigateBack = { navigationActions.navigateUp() },
viewModel = benchmarkViewModel viewModel = benchmarkViewModel
) )
} }
// Conversation Screen // Conversation Screen
composable(AppDestinations.CONVERSATION_ROUTE) { composable(
route = AppDestinations.CONVERSATION_ROUTE_WITH_PARAMS,
arguments = listOf(
navArgument("modelLoadTimeMs") {
type = NavType.LongType
defaultValue = 0L
},
navArgument("promptTimeMs") {
type = NavType.LongType
defaultValue = 0L
}
)
) { backStackEntry ->
val modelLoadTimeMs = backStackEntry.arguments?.getLong("modelLoadTimeMs") ?: 0L
val promptTimeMs = backStackEntry.arguments?.getLong("promptTimeMs") ?: 0L
val metrics = if (modelLoadTimeMs > 0) {
ModelLoadingMetrics(
modelLoadingTimeMs = modelLoadTimeMs,
systemPromptProcessingTimeMs = if (promptTimeMs > 0) promptTimeMs else null
)
} else throw IllegalArgumentException("Expecting a valid ModelLoadingMetrics!")
ConversationScreen( ConversationScreen(
loadingMetrics = metrics,
onNavigateBack = { navigationActions.navigateUp() }, onNavigateBack = { navigationActions.navigateUp() },
viewModel = conversationViewModel viewModel = conversationViewModel
) )

View File

@ -1,6 +1,7 @@
package com.example.llama.revamp.navigation package com.example.llama.revamp.navigation
import androidx.navigation.NavController import androidx.navigation.NavController
import com.example.llama.revamp.engine.ModelLoadingMetrics
/** /**
* Navigation destinations for the app * Navigation destinations for the app
@ -9,10 +10,14 @@ object AppDestinations {
// Primary navigation destinations // Primary navigation destinations
const val MODEL_SELECTION_ROUTE = "model_selection" const val MODEL_SELECTION_ROUTE = "model_selection"
const val MODEL_LOADING_ROUTE = "model_loading" const val MODEL_LOADING_ROUTE = "model_loading"
const val CONVERSATION_ROUTE = "conversation"
const val BENCHMARK_ROUTE = "benchmark"
// Settings destinations (moved from tabs to separate routes) const val CONVERSATION_ROUTE = "conversation"
const val CONVERSATION_ROUTE_WITH_PARAMS = "conversation/{modelLoadTimeMs}/{promptTimeMs}"
const val BENCHMARK_ROUTE = "benchmark"
const val BENCHMARK_ROUTE_WITH_PARAMS = "benchmark/{modelLoadTimeMs}"
// Settings destinations
const val SETTINGS_GENERAL_ROUTE = "settings_general" const val SETTINGS_GENERAL_ROUTE = "settings_general"
const val MODELS_MANAGEMENT_ROUTE = "models_management" const val MODELS_MANAGEMENT_ROUTE = "models_management"
} }
@ -33,12 +38,17 @@ class NavigationActions(private val navController: NavController) {
navController.navigate(AppDestinations.MODEL_LOADING_ROUTE) navController.navigate(AppDestinations.MODEL_LOADING_ROUTE)
} }
fun navigateToConversation() { fun navigateToConversation(metrics: ModelLoadingMetrics) {
navController.navigate(AppDestinations.CONVERSATION_ROUTE) val route = AppDestinations.CONVERSATION_ROUTE
val modelLoadTimeMs = metrics.modelLoadingTimeMs
val promptTimeMs = metrics.systemPromptProcessingTimeMs ?: 0
navController.navigate("$route/$modelLoadTimeMs/$promptTimeMs")
} }
fun navigateToBenchmark() { fun navigateToBenchmark(metrics: ModelLoadingMetrics) {
navController.navigate(AppDestinations.BENCHMARK_ROUTE) val route = AppDestinations.BENCHMARK_ROUTE
val modelLoadTimeMs = metrics.modelLoadingTimeMs
navController.navigate("$route/$modelLoadTimeMs")
} }
fun navigateToSettingsGeneral() { fun navigateToSettingsGeneral() {

View File

@ -24,6 +24,7 @@ import androidx.compose.runtime.getValue
import androidx.compose.ui.Alignment import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier import androidx.compose.ui.Modifier
import androidx.compose.ui.unit.dp import androidx.compose.ui.unit.dp
import com.example.llama.revamp.engine.ModelLoadingMetrics
import com.example.llama.revamp.ui.components.ModelCard import com.example.llama.revamp.ui.components.ModelCard
import com.example.llama.revamp.ui.components.ModelUnloadDialogHandler import com.example.llama.revamp.ui.components.ModelUnloadDialogHandler
import com.example.llama.revamp.ui.theme.MonospacedTextStyle import com.example.llama.revamp.ui.theme.MonospacedTextStyle
@ -31,6 +32,7 @@ import com.example.llama.revamp.viewmodel.BenchmarkViewModel
@Composable @Composable
fun BenchmarkScreen( fun BenchmarkScreen(
loadingMetrics: ModelLoadingMetrics,
onNavigateBack: () -> Unit, onNavigateBack: () -> Unit,
viewModel: BenchmarkViewModel viewModel: BenchmarkViewModel
) { ) {

View File

@ -57,6 +57,7 @@ import androidx.lifecycle.Lifecycle
import androidx.lifecycle.LifecycleEventObserver import androidx.lifecycle.LifecycleEventObserver
import androidx.lifecycle.compose.LocalLifecycleOwner import androidx.lifecycle.compose.LocalLifecycleOwner
import com.example.llama.revamp.APP_NAME import com.example.llama.revamp.APP_NAME
import com.example.llama.revamp.engine.ModelLoadingMetrics
import com.example.llama.revamp.ui.components.ModelCardWithSystemPrompt import com.example.llama.revamp.ui.components.ModelCardWithSystemPrompt
import com.example.llama.revamp.ui.components.ModelUnloadDialogHandler import com.example.llama.revamp.ui.components.ModelUnloadDialogHandler
import com.example.llama.revamp.viewmodel.ConversationViewModel import com.example.llama.revamp.viewmodel.ConversationViewModel
@ -68,6 +69,7 @@ import kotlinx.coroutines.launch
*/ */
@Composable @Composable
fun ConversationScreen( fun ConversationScreen(
loadingMetrics: ModelLoadingMetrics,
onNavigateBack: () -> Unit, onNavigateBack: () -> Unit,
viewModel: ConversationViewModel viewModel: ConversationViewModel
) { ) {

View File

@ -42,7 +42,6 @@ import androidx.compose.runtime.collectAsState
import androidx.compose.runtime.getValue import androidx.compose.runtime.getValue
import androidx.compose.runtime.mutableStateOf import androidx.compose.runtime.mutableStateOf
import androidx.compose.runtime.remember import androidx.compose.runtime.remember
import androidx.compose.runtime.rememberCoroutineScope
import androidx.compose.runtime.setValue import androidx.compose.runtime.setValue
import androidx.compose.ui.Alignment import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier import androidx.compose.ui.Modifier
@ -50,11 +49,10 @@ import androidx.compose.ui.semantics.Role
import androidx.compose.ui.text.style.TextOverflow import androidx.compose.ui.text.style.TextOverflow
import androidx.compose.ui.unit.dp import androidx.compose.ui.unit.dp
import com.example.llama.revamp.data.model.SystemPrompt import com.example.llama.revamp.data.model.SystemPrompt
import com.example.llama.revamp.engine.ModelLoadingMetrics
import com.example.llama.revamp.ui.components.ModelCard import com.example.llama.revamp.ui.components.ModelCard
import com.example.llama.revamp.ui.components.ModelUnloadDialogHandler import com.example.llama.revamp.ui.components.ModelUnloadDialogHandler
import com.example.llama.revamp.viewmodel.ModelLoadingViewModel import com.example.llama.revamp.viewmodel.ModelLoadingViewModel
import kotlinx.coroutines.Job
import kotlinx.coroutines.launch
enum class Mode { enum class Mode {
@ -70,12 +68,10 @@ enum class SystemPromptTab(val label: String) {
@Composable @Composable
fun ModelLoadingScreen( fun ModelLoadingScreen(
onNavigateBack: () -> Unit, onNavigateBack: () -> Unit,
onBenchmarkSelected: (prepareJob: Job) -> Unit, onNavigateToBenchmark: (ModelLoadingMetrics) -> Unit,
onConversationSelected: (systemPrompt: String?, prepareJob: Job) -> Unit, onNavigateToConversation: (ModelLoadingMetrics) -> Unit,
viewModel: ModelLoadingViewModel, viewModel: ModelLoadingViewModel,
) { ) {
val coroutineScope = rememberCoroutineScope()
val engineState by viewModel.engineState.collectAsState() val engineState by viewModel.engineState.collectAsState()
val selectedModel by viewModel.selectedModel.collectAsState() val selectedModel by viewModel.selectedModel.collectAsState()
val presetPrompts by viewModel.presetPrompts.collectAsState() val presetPrompts by viewModel.presetPrompts.collectAsState()
@ -108,22 +104,6 @@ fun ModelLoadingScreen(
// Check if we're in a loading state // Check if we're in a loading state
val isLoading = engineState !is State.Initialized && engineState !is State.ModelReady val isLoading = engineState !is State.Initialized && engineState !is State.ModelReady
// Mode selection callbacks
val handleBenchmarkSelected = {
val prepareJob = coroutineScope.launch {
viewModel.prepareForBenchmark()
}
onBenchmarkSelected(prepareJob)
}
// TODO-han.yin: refactor this into ViewModel too
val handleConversationSelected = { systemPrompt: String? ->
val prepareJob = coroutineScope.launch {
viewModel.prepareForConversation(systemPrompt)
}
onConversationSelected(systemPrompt, prepareJob)
}
// Handle back navigation requests // Handle back navigation requests
BackHandler { BackHandler {
viewModel.onBackPressed(onNavigateBack) viewModel.onBackPressed(onNavigateBack)
@ -301,7 +281,7 @@ fun ModelLoadingScreen(
Button( Button(
onClick = { onClick = {
when (selectedMode) { when (selectedMode) {
Mode.BENCHMARK -> handleBenchmarkSelected() Mode.BENCHMARK -> viewModel.onBenchmarkSelected(onNavigateToBenchmark)
Mode.CONVERSATION -> { Mode.CONVERSATION -> {
val systemPrompt = if (useSystemPrompt) { val systemPrompt = if (useSystemPrompt) {
@ -324,7 +304,7 @@ fun ModelLoadingScreen(
} }
} else null } else null
handleConversationSelected(systemPrompt) viewModel.onConversationSelected(systemPrompt, onNavigateToConversation)
} }
null -> { /* No mode selected */ null -> { /* No mode selected */

View File

@ -3,6 +3,7 @@ package com.example.llama.revamp.viewmodel
import androidx.lifecycle.viewModelScope import androidx.lifecycle.viewModelScope
import com.example.llama.revamp.data.model.SystemPrompt import com.example.llama.revamp.data.model.SystemPrompt
import com.example.llama.revamp.data.repository.SystemPromptRepository import com.example.llama.revamp.data.repository.SystemPromptRepository
import com.example.llama.revamp.engine.ModelLoadingMetrics
import com.example.llama.revamp.engine.ModelLoadingService import com.example.llama.revamp.engine.ModelLoadingService
import dagger.hilt.android.lifecycle.HiltViewModel import dagger.hilt.android.lifecycle.HiltViewModel
import kotlinx.coroutines.flow.SharingStarted import kotlinx.coroutines.flow.SharingStarted
@ -79,17 +80,23 @@ class ModelLoadingViewModel @Inject constructor(
} }
/** /**
* Prepares the engine for benchmark mode. * Loads the model, then navigate to [BenchmarkScreen] with [ModelLoadingMetrics]
*/ */
suspend fun prepareForBenchmark() = fun onBenchmarkSelected(onNavigateToBenchmark: (ModelLoadingMetrics) -> Unit) =
modelLoadingService.loadModelForBenchmark() viewModelScope.launch {
onNavigateToBenchmark(modelLoadingService.loadModelForBenchmark())
}
/** /**
* Prepare for conversation * Loads the model, process system prompt if any,
* then navigate to [ConversationScreen] with [ModelLoadingMetrics]
*/ */
suspend fun prepareForConversation(systemPrompt: String? = null) = fun onConversationSelected(
modelLoadingService.loadModelForConversation(systemPrompt) systemPrompt: String? = null,
onNavigateToConversation: (ModelLoadingMetrics) -> Unit
) = viewModelScope.launch {
onNavigateToConversation(modelLoadingService.loadModelForConversation(systemPrompt))
}
companion object { companion object {
private val TAG = ModelLoadingViewModel::class.java.simpleName private val TAG = ModelLoadingViewModel::class.java.simpleName