navigation: sink model loading state management from AppContent down into ModelLoadingScreen; pass ModelLoadingMetrics to Benchmark and Conversation screens
This commit is contained in:
parent
8a682ff85d
commit
a9466c0370
|
|
@ -21,9 +21,12 @@ import androidx.compose.runtime.remember
|
|||
import androidx.compose.runtime.rememberCoroutineScope
|
||||
import androidx.compose.ui.Modifier
|
||||
import androidx.hilt.navigation.compose.hiltViewModel
|
||||
import androidx.navigation.NavType
|
||||
import androidx.navigation.compose.composable
|
||||
import androidx.navigation.compose.currentBackStackEntryAsState
|
||||
import androidx.navigation.compose.rememberNavController
|
||||
import androidx.navigation.navArgument
|
||||
import com.example.llama.revamp.engine.ModelLoadingMetrics
|
||||
import com.example.llama.revamp.navigation.AppDestinations
|
||||
import com.example.llama.revamp.navigation.NavigationActions
|
||||
import com.example.llama.revamp.ui.components.AnimatedNavHost
|
||||
|
|
@ -48,7 +51,6 @@ import com.example.llama.revamp.viewmodel.ModelLoadingViewModel
|
|||
import com.example.llama.revamp.viewmodel.ModelsManagementViewModel
|
||||
import com.example.llama.revamp.viewmodel.PerformanceViewModel
|
||||
import dagger.hilt.android.AndroidEntryPoint
|
||||
import kotlinx.coroutines.isActive
|
||||
import kotlinx.coroutines.launch
|
||||
|
||||
@AndroidEntryPoint
|
||||
|
|
@ -110,9 +112,9 @@ fun AppContent(
|
|||
val openDrawer: () -> Unit = { coroutineScope.launch { drawerState.open() } }
|
||||
|
||||
// Create scaffold's top & bottom bar configs based on current route
|
||||
val scaffoldConfig = when (currentRoute) {
|
||||
val scaffoldConfig = when {
|
||||
// Model selection screen
|
||||
AppDestinations.MODEL_SELECTION_ROUTE ->
|
||||
currentRoute == AppDestinations.MODEL_SELECTION_ROUTE ->
|
||||
ScaffoldConfig(
|
||||
topBarConfig = TopBarConfig.Default(
|
||||
title = "Models",
|
||||
|
|
@ -121,7 +123,7 @@ fun AppContent(
|
|||
)
|
||||
|
||||
// Model loading screen
|
||||
AppDestinations.MODEL_LOADING_ROUTE ->
|
||||
currentRoute == AppDestinations.MODEL_LOADING_ROUTE ->
|
||||
ScaffoldConfig(
|
||||
topBarConfig = TopBarConfig.Performance(
|
||||
title = "Load Model",
|
||||
|
|
@ -134,7 +136,7 @@ fun AppContent(
|
|||
)
|
||||
|
||||
// Benchmark screen
|
||||
AppDestinations.BENCHMARK_ROUTE ->
|
||||
currentRoute.startsWith(AppDestinations.BENCHMARK_ROUTE) ->
|
||||
ScaffoldConfig(
|
||||
topBarConfig = TopBarConfig.Performance(
|
||||
title = "Benchmark",
|
||||
|
|
@ -147,7 +149,7 @@ fun AppContent(
|
|||
)
|
||||
|
||||
// Conversation screen
|
||||
AppDestinations.CONVERSATION_ROUTE ->
|
||||
currentRoute.startsWith(AppDestinations.CONVERSATION_ROUTE) ->
|
||||
ScaffoldConfig(
|
||||
topBarConfig = TopBarConfig.Performance(
|
||||
title = "Chat",
|
||||
|
|
@ -160,7 +162,7 @@ fun AppContent(
|
|||
)
|
||||
|
||||
// Settings screen
|
||||
AppDestinations.SETTINGS_GENERAL_ROUTE ->
|
||||
currentRoute == AppDestinations.SETTINGS_GENERAL_ROUTE ->
|
||||
ScaffoldConfig(
|
||||
topBarConfig = TopBarConfig.Default(
|
||||
title = "Settings",
|
||||
|
|
@ -169,7 +171,7 @@ fun AppContent(
|
|||
)
|
||||
|
||||
// Storage management screen
|
||||
AppDestinations.MODELS_MANAGEMENT_ROUTE -> {
|
||||
currentRoute == AppDestinations.MODELS_MANAGEMENT_ROUTE -> {
|
||||
// Collect the needed states
|
||||
val sortOrder by modelsManagementViewModel.sortOrder.collectAsState()
|
||||
val isMultiSelectionMode by modelsManagementViewModel.isMultiSelectionMode.collectAsState()
|
||||
|
|
@ -301,35 +303,59 @@ fun AppContent(
|
|||
composable(AppDestinations.MODEL_LOADING_ROUTE) {
|
||||
ModelLoadingScreen(
|
||||
onNavigateBack = { navigationActions.navigateUp() },
|
||||
onBenchmarkSelected = { prepareJob ->
|
||||
// Wait for preparation to complete, then navigate if still active
|
||||
coroutineScope.launch {
|
||||
prepareJob.join()
|
||||
if (isActive) { navigationActions.navigateToBenchmark() }
|
||||
}
|
||||
},
|
||||
onConversationSelected = { systemPrompt, prepareJob ->
|
||||
// Wait for preparation to complete, then navigate if still active
|
||||
coroutineScope.launch {
|
||||
prepareJob.join()
|
||||
if (isActive) { navigationActions.navigateToConversation() }
|
||||
}
|
||||
},
|
||||
onNavigateToBenchmark = { navigationActions.navigateToBenchmark(it) },
|
||||
onNavigateToConversation = { navigationActions.navigateToConversation(it) },
|
||||
viewModel = modelLoadingViewModel
|
||||
)
|
||||
}
|
||||
|
||||
// Benchmark Screen
|
||||
composable(AppDestinations.BENCHMARK_ROUTE) {
|
||||
composable(
|
||||
route = AppDestinations.BENCHMARK_ROUTE_WITH_PARAMS,
|
||||
arguments = listOf(
|
||||
navArgument("modelLoadTimeMs") {
|
||||
type = NavType.LongType
|
||||
defaultValue = 0L
|
||||
}
|
||||
)
|
||||
) { backStackEntry ->
|
||||
val modelLoadTimeMs = backStackEntry.arguments?.getLong("modelLoadTimeMs") ?: 0L
|
||||
val metrics = if (modelLoadTimeMs > 0) {
|
||||
ModelLoadingMetrics(modelLoadTimeMs)
|
||||
} else throw IllegalArgumentException("Expecting a valid ModelLoadingMetrics!")
|
||||
|
||||
BenchmarkScreen(
|
||||
loadingMetrics = metrics,
|
||||
onNavigateBack = { navigationActions.navigateUp() },
|
||||
viewModel = benchmarkViewModel
|
||||
)
|
||||
}
|
||||
|
||||
// Conversation Screen
|
||||
composable(AppDestinations.CONVERSATION_ROUTE) {
|
||||
composable(
|
||||
route = AppDestinations.CONVERSATION_ROUTE_WITH_PARAMS,
|
||||
arguments = listOf(
|
||||
navArgument("modelLoadTimeMs") {
|
||||
type = NavType.LongType
|
||||
defaultValue = 0L
|
||||
},
|
||||
navArgument("promptTimeMs") {
|
||||
type = NavType.LongType
|
||||
defaultValue = 0L
|
||||
}
|
||||
)
|
||||
) { backStackEntry ->
|
||||
val modelLoadTimeMs = backStackEntry.arguments?.getLong("modelLoadTimeMs") ?: 0L
|
||||
val promptTimeMs = backStackEntry.arguments?.getLong("promptTimeMs") ?: 0L
|
||||
val metrics = if (modelLoadTimeMs > 0) {
|
||||
ModelLoadingMetrics(
|
||||
modelLoadingTimeMs = modelLoadTimeMs,
|
||||
systemPromptProcessingTimeMs = if (promptTimeMs > 0) promptTimeMs else null
|
||||
)
|
||||
} else throw IllegalArgumentException("Expecting a valid ModelLoadingMetrics!")
|
||||
|
||||
ConversationScreen(
|
||||
loadingMetrics = metrics,
|
||||
onNavigateBack = { navigationActions.navigateUp() },
|
||||
viewModel = conversationViewModel
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
package com.example.llama.revamp.navigation
|
||||
|
||||
import androidx.navigation.NavController
|
||||
import com.example.llama.revamp.engine.ModelLoadingMetrics
|
||||
|
||||
/**
|
||||
* Navigation destinations for the app
|
||||
|
|
@ -9,10 +10,14 @@ object AppDestinations {
|
|||
// Primary navigation destinations
|
||||
const val MODEL_SELECTION_ROUTE = "model_selection"
|
||||
const val MODEL_LOADING_ROUTE = "model_loading"
|
||||
const val CONVERSATION_ROUTE = "conversation"
|
||||
const val BENCHMARK_ROUTE = "benchmark"
|
||||
|
||||
// Settings destinations (moved from tabs to separate routes)
|
||||
const val CONVERSATION_ROUTE = "conversation"
|
||||
const val CONVERSATION_ROUTE_WITH_PARAMS = "conversation/{modelLoadTimeMs}/{promptTimeMs}"
|
||||
|
||||
const val BENCHMARK_ROUTE = "benchmark"
|
||||
const val BENCHMARK_ROUTE_WITH_PARAMS = "benchmark/{modelLoadTimeMs}"
|
||||
|
||||
// Settings destinations
|
||||
const val SETTINGS_GENERAL_ROUTE = "settings_general"
|
||||
const val MODELS_MANAGEMENT_ROUTE = "models_management"
|
||||
}
|
||||
|
|
@ -33,12 +38,17 @@ class NavigationActions(private val navController: NavController) {
|
|||
navController.navigate(AppDestinations.MODEL_LOADING_ROUTE)
|
||||
}
|
||||
|
||||
fun navigateToConversation() {
|
||||
navController.navigate(AppDestinations.CONVERSATION_ROUTE)
|
||||
fun navigateToConversation(metrics: ModelLoadingMetrics) {
|
||||
val route = AppDestinations.CONVERSATION_ROUTE
|
||||
val modelLoadTimeMs = metrics.modelLoadingTimeMs
|
||||
val promptTimeMs = metrics.systemPromptProcessingTimeMs ?: 0
|
||||
navController.navigate("$route/$modelLoadTimeMs/$promptTimeMs")
|
||||
}
|
||||
|
||||
fun navigateToBenchmark() {
|
||||
navController.navigate(AppDestinations.BENCHMARK_ROUTE)
|
||||
fun navigateToBenchmark(metrics: ModelLoadingMetrics) {
|
||||
val route = AppDestinations.BENCHMARK_ROUTE
|
||||
val modelLoadTimeMs = metrics.modelLoadingTimeMs
|
||||
navController.navigate("$route/$modelLoadTimeMs")
|
||||
}
|
||||
|
||||
fun navigateToSettingsGeneral() {
|
||||
|
|
|
|||
|
|
@ -24,6 +24,7 @@ import androidx.compose.runtime.getValue
|
|||
import androidx.compose.ui.Alignment
|
||||
import androidx.compose.ui.Modifier
|
||||
import androidx.compose.ui.unit.dp
|
||||
import com.example.llama.revamp.engine.ModelLoadingMetrics
|
||||
import com.example.llama.revamp.ui.components.ModelCard
|
||||
import com.example.llama.revamp.ui.components.ModelUnloadDialogHandler
|
||||
import com.example.llama.revamp.ui.theme.MonospacedTextStyle
|
||||
|
|
@ -31,6 +32,7 @@ import com.example.llama.revamp.viewmodel.BenchmarkViewModel
|
|||
|
||||
@Composable
|
||||
fun BenchmarkScreen(
|
||||
loadingMetrics: ModelLoadingMetrics,
|
||||
onNavigateBack: () -> Unit,
|
||||
viewModel: BenchmarkViewModel
|
||||
) {
|
||||
|
|
|
|||
|
|
@ -57,6 +57,7 @@ import androidx.lifecycle.Lifecycle
|
|||
import androidx.lifecycle.LifecycleEventObserver
|
||||
import androidx.lifecycle.compose.LocalLifecycleOwner
|
||||
import com.example.llama.revamp.APP_NAME
|
||||
import com.example.llama.revamp.engine.ModelLoadingMetrics
|
||||
import com.example.llama.revamp.ui.components.ModelCardWithSystemPrompt
|
||||
import com.example.llama.revamp.ui.components.ModelUnloadDialogHandler
|
||||
import com.example.llama.revamp.viewmodel.ConversationViewModel
|
||||
|
|
@ -68,6 +69,7 @@ import kotlinx.coroutines.launch
|
|||
*/
|
||||
@Composable
|
||||
fun ConversationScreen(
|
||||
loadingMetrics: ModelLoadingMetrics,
|
||||
onNavigateBack: () -> Unit,
|
||||
viewModel: ConversationViewModel
|
||||
) {
|
||||
|
|
|
|||
|
|
@ -42,7 +42,6 @@ import androidx.compose.runtime.collectAsState
|
|||
import androidx.compose.runtime.getValue
|
||||
import androidx.compose.runtime.mutableStateOf
|
||||
import androidx.compose.runtime.remember
|
||||
import androidx.compose.runtime.rememberCoroutineScope
|
||||
import androidx.compose.runtime.setValue
|
||||
import androidx.compose.ui.Alignment
|
||||
import androidx.compose.ui.Modifier
|
||||
|
|
@ -50,11 +49,10 @@ import androidx.compose.ui.semantics.Role
|
|||
import androidx.compose.ui.text.style.TextOverflow
|
||||
import androidx.compose.ui.unit.dp
|
||||
import com.example.llama.revamp.data.model.SystemPrompt
|
||||
import com.example.llama.revamp.engine.ModelLoadingMetrics
|
||||
import com.example.llama.revamp.ui.components.ModelCard
|
||||
import com.example.llama.revamp.ui.components.ModelUnloadDialogHandler
|
||||
import com.example.llama.revamp.viewmodel.ModelLoadingViewModel
|
||||
import kotlinx.coroutines.Job
|
||||
import kotlinx.coroutines.launch
|
||||
|
||||
|
||||
enum class Mode {
|
||||
|
|
@ -70,12 +68,10 @@ enum class SystemPromptTab(val label: String) {
|
|||
@Composable
|
||||
fun ModelLoadingScreen(
|
||||
onNavigateBack: () -> Unit,
|
||||
onBenchmarkSelected: (prepareJob: Job) -> Unit,
|
||||
onConversationSelected: (systemPrompt: String?, prepareJob: Job) -> Unit,
|
||||
onNavigateToBenchmark: (ModelLoadingMetrics) -> Unit,
|
||||
onNavigateToConversation: (ModelLoadingMetrics) -> Unit,
|
||||
viewModel: ModelLoadingViewModel,
|
||||
) {
|
||||
val coroutineScope = rememberCoroutineScope()
|
||||
|
||||
val engineState by viewModel.engineState.collectAsState()
|
||||
val selectedModel by viewModel.selectedModel.collectAsState()
|
||||
val presetPrompts by viewModel.presetPrompts.collectAsState()
|
||||
|
|
@ -108,22 +104,6 @@ fun ModelLoadingScreen(
|
|||
// Check if we're in a loading state
|
||||
val isLoading = engineState !is State.Initialized && engineState !is State.ModelReady
|
||||
|
||||
// Mode selection callbacks
|
||||
val handleBenchmarkSelected = {
|
||||
val prepareJob = coroutineScope.launch {
|
||||
viewModel.prepareForBenchmark()
|
||||
}
|
||||
onBenchmarkSelected(prepareJob)
|
||||
}
|
||||
|
||||
// TODO-han.yin: refactor this into ViewModel too
|
||||
val handleConversationSelected = { systemPrompt: String? ->
|
||||
val prepareJob = coroutineScope.launch {
|
||||
viewModel.prepareForConversation(systemPrompt)
|
||||
}
|
||||
onConversationSelected(systemPrompt, prepareJob)
|
||||
}
|
||||
|
||||
// Handle back navigation requests
|
||||
BackHandler {
|
||||
viewModel.onBackPressed(onNavigateBack)
|
||||
|
|
@ -301,7 +281,7 @@ fun ModelLoadingScreen(
|
|||
Button(
|
||||
onClick = {
|
||||
when (selectedMode) {
|
||||
Mode.BENCHMARK -> handleBenchmarkSelected()
|
||||
Mode.BENCHMARK -> viewModel.onBenchmarkSelected(onNavigateToBenchmark)
|
||||
|
||||
Mode.CONVERSATION -> {
|
||||
val systemPrompt = if (useSystemPrompt) {
|
||||
|
|
@ -324,7 +304,7 @@ fun ModelLoadingScreen(
|
|||
}
|
||||
} else null
|
||||
|
||||
handleConversationSelected(systemPrompt)
|
||||
viewModel.onConversationSelected(systemPrompt, onNavigateToConversation)
|
||||
}
|
||||
|
||||
null -> { /* No mode selected */
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ package com.example.llama.revamp.viewmodel
|
|||
import androidx.lifecycle.viewModelScope
|
||||
import com.example.llama.revamp.data.model.SystemPrompt
|
||||
import com.example.llama.revamp.data.repository.SystemPromptRepository
|
||||
import com.example.llama.revamp.engine.ModelLoadingMetrics
|
||||
import com.example.llama.revamp.engine.ModelLoadingService
|
||||
import dagger.hilt.android.lifecycle.HiltViewModel
|
||||
import kotlinx.coroutines.flow.SharingStarted
|
||||
|
|
@ -79,17 +80,23 @@ class ModelLoadingViewModel @Inject constructor(
|
|||
}
|
||||
|
||||
/**
|
||||
* Prepares the engine for benchmark mode.
|
||||
* Loads the model, then navigate to [BenchmarkScreen] with [ModelLoadingMetrics]
|
||||
*/
|
||||
suspend fun prepareForBenchmark() =
|
||||
modelLoadingService.loadModelForBenchmark()
|
||||
fun onBenchmarkSelected(onNavigateToBenchmark: (ModelLoadingMetrics) -> Unit) =
|
||||
viewModelScope.launch {
|
||||
onNavigateToBenchmark(modelLoadingService.loadModelForBenchmark())
|
||||
}
|
||||
|
||||
/**
|
||||
* Prepare for conversation
|
||||
* Loads the model, process system prompt if any,
|
||||
* then navigate to [ConversationScreen] with [ModelLoadingMetrics]
|
||||
*/
|
||||
suspend fun prepareForConversation(systemPrompt: String? = null) =
|
||||
modelLoadingService.loadModelForConversation(systemPrompt)
|
||||
|
||||
fun onConversationSelected(
|
||||
systemPrompt: String? = null,
|
||||
onNavigateToConversation: (ModelLoadingMetrics) -> Unit
|
||||
) = viewModelScope.launch {
|
||||
onNavigateToConversation(modelLoadingService.loadModelForConversation(systemPrompt))
|
||||
}
|
||||
|
||||
companion object {
|
||||
private val TAG = ModelLoadingViewModel::class.java.simpleName
|
||||
|
|
|
|||
Loading…
Reference in New Issue