navigation: sink model loading state management from AppContent down into ModelLoadingScreen; pass ModelLoadingMetrics to Benchmark and Conversation screens

This commit is contained in:
Han Yin 2025-04-18 16:46:25 -07:00
parent 8a682ff85d
commit a9466c0370
6 changed files with 90 additions and 63 deletions

View File

@ -21,9 +21,12 @@ import androidx.compose.runtime.remember
import androidx.compose.runtime.rememberCoroutineScope
import androidx.compose.ui.Modifier
import androidx.hilt.navigation.compose.hiltViewModel
import androidx.navigation.NavType
import androidx.navigation.compose.composable
import androidx.navigation.compose.currentBackStackEntryAsState
import androidx.navigation.compose.rememberNavController
import androidx.navigation.navArgument
import com.example.llama.revamp.engine.ModelLoadingMetrics
import com.example.llama.revamp.navigation.AppDestinations
import com.example.llama.revamp.navigation.NavigationActions
import com.example.llama.revamp.ui.components.AnimatedNavHost
@ -48,7 +51,6 @@ import com.example.llama.revamp.viewmodel.ModelLoadingViewModel
import com.example.llama.revamp.viewmodel.ModelsManagementViewModel
import com.example.llama.revamp.viewmodel.PerformanceViewModel
import dagger.hilt.android.AndroidEntryPoint
import kotlinx.coroutines.isActive
import kotlinx.coroutines.launch
@AndroidEntryPoint
@ -110,9 +112,9 @@ fun AppContent(
val openDrawer: () -> Unit = { coroutineScope.launch { drawerState.open() } }
// Create scaffold's top & bottom bar configs based on current route
val scaffoldConfig = when (currentRoute) {
val scaffoldConfig = when {
// Model selection screen
AppDestinations.MODEL_SELECTION_ROUTE ->
currentRoute == AppDestinations.MODEL_SELECTION_ROUTE ->
ScaffoldConfig(
topBarConfig = TopBarConfig.Default(
title = "Models",
@ -121,7 +123,7 @@ fun AppContent(
)
// Model loading screen
AppDestinations.MODEL_LOADING_ROUTE ->
currentRoute == AppDestinations.MODEL_LOADING_ROUTE ->
ScaffoldConfig(
topBarConfig = TopBarConfig.Performance(
title = "Load Model",
@ -134,7 +136,7 @@ fun AppContent(
)
// Benchmark screen
AppDestinations.BENCHMARK_ROUTE ->
currentRoute.startsWith(AppDestinations.BENCHMARK_ROUTE) ->
ScaffoldConfig(
topBarConfig = TopBarConfig.Performance(
title = "Benchmark",
@ -147,7 +149,7 @@ fun AppContent(
)
// Conversation screen
AppDestinations.CONVERSATION_ROUTE ->
currentRoute.startsWith(AppDestinations.CONVERSATION_ROUTE) ->
ScaffoldConfig(
topBarConfig = TopBarConfig.Performance(
title = "Chat",
@ -160,7 +162,7 @@ fun AppContent(
)
// Settings screen
AppDestinations.SETTINGS_GENERAL_ROUTE ->
currentRoute == AppDestinations.SETTINGS_GENERAL_ROUTE ->
ScaffoldConfig(
topBarConfig = TopBarConfig.Default(
title = "Settings",
@ -169,7 +171,7 @@ fun AppContent(
)
// Storage management screen
AppDestinations.MODELS_MANAGEMENT_ROUTE -> {
currentRoute == AppDestinations.MODELS_MANAGEMENT_ROUTE -> {
// Collect the needed states
val sortOrder by modelsManagementViewModel.sortOrder.collectAsState()
val isMultiSelectionMode by modelsManagementViewModel.isMultiSelectionMode.collectAsState()
@ -301,35 +303,59 @@ fun AppContent(
composable(AppDestinations.MODEL_LOADING_ROUTE) {
ModelLoadingScreen(
onNavigateBack = { navigationActions.navigateUp() },
onBenchmarkSelected = { prepareJob ->
// Wait for preparation to complete, then navigate if still active
coroutineScope.launch {
prepareJob.join()
if (isActive) { navigationActions.navigateToBenchmark() }
}
},
onConversationSelected = { systemPrompt, prepareJob ->
// Wait for preparation to complete, then navigate if still active
coroutineScope.launch {
prepareJob.join()
if (isActive) { navigationActions.navigateToConversation() }
}
},
onNavigateToBenchmark = { navigationActions.navigateToBenchmark(it) },
onNavigateToConversation = { navigationActions.navigateToConversation(it) },
viewModel = modelLoadingViewModel
)
}
// Benchmark Screen
composable(AppDestinations.BENCHMARK_ROUTE) {
composable(
route = AppDestinations.BENCHMARK_ROUTE_WITH_PARAMS,
arguments = listOf(
navArgument("modelLoadTimeMs") {
type = NavType.LongType
defaultValue = 0L
}
)
) { backStackEntry ->
val modelLoadTimeMs = backStackEntry.arguments?.getLong("modelLoadTimeMs") ?: 0L
val metrics = if (modelLoadTimeMs > 0) {
ModelLoadingMetrics(modelLoadTimeMs)
} else throw IllegalArgumentException("Expecting a valid ModelLoadingMetrics!")
BenchmarkScreen(
loadingMetrics = metrics,
onNavigateBack = { navigationActions.navigateUp() },
viewModel = benchmarkViewModel
)
}
// Conversation Screen
composable(AppDestinations.CONVERSATION_ROUTE) {
composable(
route = AppDestinations.CONVERSATION_ROUTE_WITH_PARAMS,
arguments = listOf(
navArgument("modelLoadTimeMs") {
type = NavType.LongType
defaultValue = 0L
},
navArgument("promptTimeMs") {
type = NavType.LongType
defaultValue = 0L
}
)
) { backStackEntry ->
val modelLoadTimeMs = backStackEntry.arguments?.getLong("modelLoadTimeMs") ?: 0L
val promptTimeMs = backStackEntry.arguments?.getLong("promptTimeMs") ?: 0L
val metrics = if (modelLoadTimeMs > 0) {
ModelLoadingMetrics(
modelLoadingTimeMs = modelLoadTimeMs,
systemPromptProcessingTimeMs = if (promptTimeMs > 0) promptTimeMs else null
)
} else throw IllegalArgumentException("Expecting a valid ModelLoadingMetrics!")
ConversationScreen(
loadingMetrics = metrics,
onNavigateBack = { navigationActions.navigateUp() },
viewModel = conversationViewModel
)

View File

@ -1,6 +1,7 @@
package com.example.llama.revamp.navigation
import androidx.navigation.NavController
import com.example.llama.revamp.engine.ModelLoadingMetrics
/**
* Navigation destinations for the app
@ -9,10 +10,14 @@ object AppDestinations {
// Primary navigation destinations
const val MODEL_SELECTION_ROUTE = "model_selection"
const val MODEL_LOADING_ROUTE = "model_loading"
const val CONVERSATION_ROUTE = "conversation"
const val BENCHMARK_ROUTE = "benchmark"
// Settings destinations (moved from tabs to separate routes)
const val CONVERSATION_ROUTE = "conversation"
const val CONVERSATION_ROUTE_WITH_PARAMS = "conversation/{modelLoadTimeMs}/{promptTimeMs}"
const val BENCHMARK_ROUTE = "benchmark"
const val BENCHMARK_ROUTE_WITH_PARAMS = "benchmark/{modelLoadTimeMs}"
// Settings destinations
const val SETTINGS_GENERAL_ROUTE = "settings_general"
const val MODELS_MANAGEMENT_ROUTE = "models_management"
}
@ -33,12 +38,17 @@ class NavigationActions(private val navController: NavController) {
navController.navigate(AppDestinations.MODEL_LOADING_ROUTE)
}
fun navigateToConversation() {
navController.navigate(AppDestinations.CONVERSATION_ROUTE)
fun navigateToConversation(metrics: ModelLoadingMetrics) {
val route = AppDestinations.CONVERSATION_ROUTE
val modelLoadTimeMs = metrics.modelLoadingTimeMs
val promptTimeMs = metrics.systemPromptProcessingTimeMs ?: 0
navController.navigate("$route/$modelLoadTimeMs/$promptTimeMs")
}
fun navigateToBenchmark() {
navController.navigate(AppDestinations.BENCHMARK_ROUTE)
fun navigateToBenchmark(metrics: ModelLoadingMetrics) {
val route = AppDestinations.BENCHMARK_ROUTE
val modelLoadTimeMs = metrics.modelLoadingTimeMs
navController.navigate("$route/$modelLoadTimeMs")
}
fun navigateToSettingsGeneral() {

View File

@ -24,6 +24,7 @@ import androidx.compose.runtime.getValue
import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
import androidx.compose.ui.unit.dp
import com.example.llama.revamp.engine.ModelLoadingMetrics
import com.example.llama.revamp.ui.components.ModelCard
import com.example.llama.revamp.ui.components.ModelUnloadDialogHandler
import com.example.llama.revamp.ui.theme.MonospacedTextStyle
@ -31,6 +32,7 @@ import com.example.llama.revamp.viewmodel.BenchmarkViewModel
@Composable
fun BenchmarkScreen(
loadingMetrics: ModelLoadingMetrics,
onNavigateBack: () -> Unit,
viewModel: BenchmarkViewModel
) {

View File

@ -57,6 +57,7 @@ import androidx.lifecycle.Lifecycle
import androidx.lifecycle.LifecycleEventObserver
import androidx.lifecycle.compose.LocalLifecycleOwner
import com.example.llama.revamp.APP_NAME
import com.example.llama.revamp.engine.ModelLoadingMetrics
import com.example.llama.revamp.ui.components.ModelCardWithSystemPrompt
import com.example.llama.revamp.ui.components.ModelUnloadDialogHandler
import com.example.llama.revamp.viewmodel.ConversationViewModel
@ -68,6 +69,7 @@ import kotlinx.coroutines.launch
*/
@Composable
fun ConversationScreen(
loadingMetrics: ModelLoadingMetrics,
onNavigateBack: () -> Unit,
viewModel: ConversationViewModel
) {

View File

@ -42,7 +42,6 @@ import androidx.compose.runtime.collectAsState
import androidx.compose.runtime.getValue
import androidx.compose.runtime.mutableStateOf
import androidx.compose.runtime.remember
import androidx.compose.runtime.rememberCoroutineScope
import androidx.compose.runtime.setValue
import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
@ -50,11 +49,10 @@ import androidx.compose.ui.semantics.Role
import androidx.compose.ui.text.style.TextOverflow
import androidx.compose.ui.unit.dp
import com.example.llama.revamp.data.model.SystemPrompt
import com.example.llama.revamp.engine.ModelLoadingMetrics
import com.example.llama.revamp.ui.components.ModelCard
import com.example.llama.revamp.ui.components.ModelUnloadDialogHandler
import com.example.llama.revamp.viewmodel.ModelLoadingViewModel
import kotlinx.coroutines.Job
import kotlinx.coroutines.launch
enum class Mode {
@ -70,12 +68,10 @@ enum class SystemPromptTab(val label: String) {
@Composable
fun ModelLoadingScreen(
onNavigateBack: () -> Unit,
onBenchmarkSelected: (prepareJob: Job) -> Unit,
onConversationSelected: (systemPrompt: String?, prepareJob: Job) -> Unit,
onNavigateToBenchmark: (ModelLoadingMetrics) -> Unit,
onNavigateToConversation: (ModelLoadingMetrics) -> Unit,
viewModel: ModelLoadingViewModel,
) {
val coroutineScope = rememberCoroutineScope()
val engineState by viewModel.engineState.collectAsState()
val selectedModel by viewModel.selectedModel.collectAsState()
val presetPrompts by viewModel.presetPrompts.collectAsState()
@ -108,22 +104,6 @@ fun ModelLoadingScreen(
// Check if we're in a loading state
val isLoading = engineState !is State.Initialized && engineState !is State.ModelReady
// Mode selection callbacks
val handleBenchmarkSelected = {
val prepareJob = coroutineScope.launch {
viewModel.prepareForBenchmark()
}
onBenchmarkSelected(prepareJob)
}
// TODO-han.yin: refactor this into ViewModel too
val handleConversationSelected = { systemPrompt: String? ->
val prepareJob = coroutineScope.launch {
viewModel.prepareForConversation(systemPrompt)
}
onConversationSelected(systemPrompt, prepareJob)
}
// Handle back navigation requests
BackHandler {
viewModel.onBackPressed(onNavigateBack)
@ -301,7 +281,7 @@ fun ModelLoadingScreen(
Button(
onClick = {
when (selectedMode) {
Mode.BENCHMARK -> handleBenchmarkSelected()
Mode.BENCHMARK -> viewModel.onBenchmarkSelected(onNavigateToBenchmark)
Mode.CONVERSATION -> {
val systemPrompt = if (useSystemPrompt) {
@ -324,7 +304,7 @@ fun ModelLoadingScreen(
}
} else null
handleConversationSelected(systemPrompt)
viewModel.onConversationSelected(systemPrompt, onNavigateToConversation)
}
null -> { /* No mode selected */

View File

@ -3,6 +3,7 @@ package com.example.llama.revamp.viewmodel
import androidx.lifecycle.viewModelScope
import com.example.llama.revamp.data.model.SystemPrompt
import com.example.llama.revamp.data.repository.SystemPromptRepository
import com.example.llama.revamp.engine.ModelLoadingMetrics
import com.example.llama.revamp.engine.ModelLoadingService
import dagger.hilt.android.lifecycle.HiltViewModel
import kotlinx.coroutines.flow.SharingStarted
@ -79,17 +80,23 @@ class ModelLoadingViewModel @Inject constructor(
}
/**
* Prepares the engine for benchmark mode.
* Loads the model, then navigate to [BenchmarkScreen] with [ModelLoadingMetrics]
*/
suspend fun prepareForBenchmark() =
modelLoadingService.loadModelForBenchmark()
fun onBenchmarkSelected(onNavigateToBenchmark: (ModelLoadingMetrics) -> Unit) =
viewModelScope.launch {
onNavigateToBenchmark(modelLoadingService.loadModelForBenchmark())
}
/**
* Prepare for conversation
* Loads the model, process system prompt if any,
* then navigate to [ConversationScreen] with [ModelLoadingMetrics]
*/
suspend fun prepareForConversation(systemPrompt: String? = null) =
modelLoadingService.loadModelForConversation(systemPrompt)
fun onConversationSelected(
systemPrompt: String? = null,
onNavigateToConversation: (ModelLoadingMetrics) -> Unit
) = viewModelScope.launch {
onNavigateToConversation(modelLoadingService.loadModelForConversation(systemPrompt))
}
companion object {
private val TAG = ModelLoadingViewModel::class.java.simpleName