UI: refactor back handling by removing centralized BackHandlerSetup and UnloadModelConfirmationDialog from AppContent

This commit is contained in:
Han Yin 2025-04-18 12:10:15 -07:00
parent c08d02d233
commit 8203ddb97a
1 changed files with 34 additions and 152 deletions

View File

@ -44,7 +44,6 @@ import com.example.llama.revamp.ui.components.NavigationIcon
import com.example.llama.revamp.ui.components.ScaffoldConfig
import com.example.llama.revamp.ui.components.ScaffoldEvent
import com.example.llama.revamp.ui.components.TopBarConfig
import com.example.llama.revamp.ui.components.UnloadModelConfirmationDialog
import com.example.llama.revamp.ui.screens.BenchmarkScreen
import com.example.llama.revamp.ui.screens.ConversationScreen
import com.example.llama.revamp.ui.screens.ModelLoadingScreen
@ -52,8 +51,10 @@ import com.example.llama.revamp.ui.screens.ModelSelectionScreen
import com.example.llama.revamp.ui.screens.ModelsManagementScreen
import com.example.llama.revamp.ui.screens.SettingsGeneralScreen
import com.example.llama.revamp.ui.theme.LlamaTheme
import com.example.llama.revamp.viewmodel.BenchmarkViewModel
import com.example.llama.revamp.viewmodel.ConversationViewModel
import com.example.llama.revamp.viewmodel.MainViewModel
import com.example.llama.revamp.viewmodel.ModelLoadingViewModel
import com.example.llama.revamp.viewmodel.ModelsManagementViewModel
import com.example.llama.revamp.viewmodel.PerformanceViewModel
import dagger.hilt.android.AndroidEntryPoint
@ -81,28 +82,16 @@ class MainActivity : ComponentActivity() {
fun AppContent(
mainViewModel: MainViewModel = hiltViewModel(),
performanceViewModel: PerformanceViewModel = hiltViewModel(),
modelsManagementViewModel: ModelsManagementViewModel = hiltViewModel(),
modelLoadingViewModel: ModelLoadingViewModel = hiltViewModel(),
benchmarkViewModel: BenchmarkViewModel = hiltViewModel(),
conversationViewModel: ConversationViewModel = hiltViewModel(),
modelsManagementViewModel: ModelsManagementViewModel = hiltViewModel(),
) {
val lifecycleOwner = LocalLifecycleOwner.current
val coroutineScope = rememberCoroutineScope()
val snackbarHostState = remember { SnackbarHostState() }
// Inference engine state
val engineState by mainViewModel.engineState.collectAsState()
val isModelUninterruptible by remember(engineState) {
derivedStateOf {
engineState is State.LoadingModel
|| engineState is State.Benchmarking
|| engineState is State.ProcessingUserPrompt
|| engineState is State.ProcessingSystemPrompt
}
}
val isModelLoaded by remember(engineState) {
derivedStateOf {
engineState !is State.Uninitialized && engineState !is State.LibraryLoaded
}
}
// Metric states for scaffolds
val memoryUsage by performanceViewModel.memoryUsage.collectAsState()
@ -117,24 +106,6 @@ fun AppContent(
val currentRoute by remember(navBackStackEntry) {
derivedStateOf { navBackStackEntry?.destination?.route ?: "" }
}
var pendingNavigation by remember { mutableStateOf<(() -> Unit)?>(null) }
// Model unloading confirmation
var showUnloadDialog by remember { mutableStateOf(false) }
val handleBackWithModelCheck = {
when {
isModelUninterruptible -> {
// If model is non-interruptible at all, ignore the request
}
isModelLoaded -> {
showUnloadDialog = true
pendingNavigation = { navigationActions.navigateUp() }
}
else -> {
navigationActions.navigateUp()
}
}
}
// Determine if drawer gestures should be enabled based on route
val drawerState = rememberDrawerState(initialValue = DrawerValue.Closed)
@ -164,22 +135,35 @@ fun AppContent(
ScaffoldConfig(
topBarConfig = TopBarConfig.Performance(
title = "Load Model",
navigationIcon = NavigationIcon.Back(handleBackWithModelCheck),
navigationIcon = NavigationIcon.Back { navigationActions.navigateUp() },
memoryMetrics = memoryUsage,
temperatureInfo = null
)
)
// Benchmark and Conversation screens
AppDestinations.BENCHMARK_ROUTE, AppDestinations.CONVERSATION_ROUTE ->
// Benchmark screen
AppDestinations.BENCHMARK_ROUTE ->
ScaffoldConfig(
topBarConfig = TopBarConfig.Performance(
title = when(currentRoute) {
AppDestinations.CONVERSATION_ROUTE -> "Chat"
AppDestinations.BENCHMARK_ROUTE -> "Benchmark"
else -> "LlamaAndroid"
},
navigationIcon = NavigationIcon.Back(handleBackWithModelCheck),
title = "Benchmark",
navigationIcon = NavigationIcon.Back {
android.util.Log.w("JOJO", "Benchmark navigation icon tapped")
benchmarkViewModel.onBackPressed()
},
memoryMetrics = memoryUsage,
temperatureInfo = Pair(temperatureInfo, useFahrenheit)
)
)
// Conversation screen
AppDestinations.CONVERSATION_ROUTE ->
ScaffoldConfig(
topBarConfig = TopBarConfig.Performance(
title = "Chat",
navigationIcon = NavigationIcon.Back {
// TODO-han.yin: uncomment after [ConversationViewModel] done
// conversationViewModel.onBackPressed()
},
memoryMetrics = memoryUsage,
temperatureInfo = Pair(temperatureInfo, useFahrenheit)
)
@ -292,15 +276,6 @@ fun AppContent(
}
}
// Register system back handler
BackHandlerSetup(
lifecycleOwner = lifecycleOwner,
backDispatcher = LocalOnBackPressedDispatcherOwner.current?.onBackPressedDispatcher,
currentRoute = currentRoute,
isModelLoaded = isModelLoaded,
handleBackWithModelCheck = handleBackWithModelCheck
)
// Main UI hierarchy
AppNavigationDrawer(
drawerState = drawerState,
@ -338,69 +313,46 @@ fun AppContent(
engineState = engineState,
onBenchmarkSelected = { prepareJob ->
// Wait for preparation to complete, then navigate if still active
val loadingJob = coroutineScope.launch {
coroutineScope.launch {
prepareJob.join()
if (isActive) { navigationActions.navigateToBenchmark() }
}
pendingNavigation = {
prepareJob.cancel()
loadingJob.cancel()
navigationActions.navigateUp()
}
},
onConversationSelected = { systemPrompt, prepareJob ->
// Wait for preparation to complete, then navigate if still active
val loadingJob = coroutineScope.launch {
coroutineScope.launch {
prepareJob.join()
if (isActive) { navigationActions.navigateToConversation() }
}
pendingNavigation = {
prepareJob.cancel()
loadingJob.cancel()
navigationActions.navigateUp()
}
},
onBackPressed = {
// Need to unload model before going back
handleBackWithModelCheck()
},
viewModel = modelLoadingViewModel
)
}
// Benchmark Screen
composable(AppDestinations.BENCHMARK_ROUTE) {
BenchmarkScreen(
onBackPressed = {
// Need to unload model before going back
handleBackWithModelCheck()
}
onNavigateBack = { navigationActions.navigateUp() },
viewModel = benchmarkViewModel
)
}
// Conversation Screen
composable(AppDestinations.CONVERSATION_ROUTE) {
ConversationScreen(
onBackPressed = {
// Need to unload model before going back
handleBackWithModelCheck()
},
onNavigateBack = { navigationActions.navigateUp() },
viewModel = conversationViewModel
)
}
// Settings General Screen
composable(AppDestinations.SETTINGS_GENERAL_ROUTE) {
SettingsGeneralScreen(
onBackPressed = { navigationActions.navigateUp() },
)
SettingsGeneralScreen()
}
// Models Management Screen
composable(AppDestinations.MODELS_MANAGEMENT_ROUTE) {
ModelsManagementScreen(
onBackPressed = { navigationActions.navigateUp() },
onScaffoldEvent = handleScaffoldEvent,
viewModel = modelsManagementViewModel
)
@ -408,74 +360,4 @@ fun AppContent(
}
}
}
// Model unload confirmation dialog
var isUnloading by remember { mutableStateOf(false) }
if (showUnloadDialog) {
UnloadModelConfirmationDialog(
onConfirm = {
isUnloading = true
coroutineScope.launch {
// TODO-han.yin: Clear conversation upon normal exiting
// Handle screen specific cleanups
when(engineState) {
is State.Benchmarking -> {}
is State.Generating -> conversationViewModel.clearConversation()
else -> {}
}
// Unload model
mainViewModel.unloadModel()
isUnloading = false
showUnloadDialog = false
pendingNavigation?.invoke()
pendingNavigation = null
}
},
onDismiss = {
if (!isUnloading) {
showUnloadDialog = false
pendingNavigation = null
}
},
isUnloading = isUnloading
)
}
}
@Composable
private fun BackHandlerSetup(
lifecycleOwner: LifecycleOwner,
backDispatcher: OnBackPressedDispatcher?,
currentRoute: String,
isModelLoaded: Boolean,
handleBackWithModelCheck: () -> Unit
) {
val routeNeedsModelUnloading = currentRoute in listOf(
AppDestinations.CONVERSATION_ROUTE,
AppDestinations.BENCHMARK_ROUTE,
AppDestinations.MODEL_LOADING_ROUTE
)
DisposableEffect(lifecycleOwner, backDispatcher, currentRoute, isModelLoaded) {
android.util.Log.w("JOJO", "BackHandlerSetup: currentRoute: $currentRoute")
val callback = object : OnBackPressedCallback(
routeNeedsModelUnloading && isModelLoaded
) {
override fun handleOnBackPressed() {
handleBackWithModelCheck()
}
}
backDispatcher?.addCallback(lifecycleOwner, callback)
onDispose { callback.remove() }
}
BackHandler(
enabled = routeNeedsModelUnloading && isModelLoaded
) {
handleBackWithModelCheck()
}
}