diff --git a/examples/llama.android/app/src/main/java/com/example/llama/revamp/MainActivity.kt b/examples/llama.android/app/src/main/java/com/example/llama/revamp/MainActivity.kt index 8b552bc1a7..846fbea9ad 100644 --- a/examples/llama.android/app/src/main/java/com/example/llama/revamp/MainActivity.kt +++ b/examples/llama.android/app/src/main/java/com/example/llama/revamp/MainActivity.kt @@ -44,7 +44,6 @@ import com.example.llama.revamp.ui.components.NavigationIcon import com.example.llama.revamp.ui.components.ScaffoldConfig import com.example.llama.revamp.ui.components.ScaffoldEvent import com.example.llama.revamp.ui.components.TopBarConfig -import com.example.llama.revamp.ui.components.UnloadModelConfirmationDialog import com.example.llama.revamp.ui.screens.BenchmarkScreen import com.example.llama.revamp.ui.screens.ConversationScreen import com.example.llama.revamp.ui.screens.ModelLoadingScreen @@ -52,8 +51,10 @@ import com.example.llama.revamp.ui.screens.ModelSelectionScreen import com.example.llama.revamp.ui.screens.ModelsManagementScreen import com.example.llama.revamp.ui.screens.SettingsGeneralScreen import com.example.llama.revamp.ui.theme.LlamaTheme +import com.example.llama.revamp.viewmodel.BenchmarkViewModel import com.example.llama.revamp.viewmodel.ConversationViewModel import com.example.llama.revamp.viewmodel.MainViewModel +import com.example.llama.revamp.viewmodel.ModelLoadingViewModel import com.example.llama.revamp.viewmodel.ModelsManagementViewModel import com.example.llama.revamp.viewmodel.PerformanceViewModel import dagger.hilt.android.AndroidEntryPoint @@ -81,28 +82,16 @@ class MainActivity : ComponentActivity() { fun AppContent( mainViewModel: MainViewModel = hiltViewModel(), performanceViewModel: PerformanceViewModel = hiltViewModel(), - modelsManagementViewModel: ModelsManagementViewModel = hiltViewModel(), + modelLoadingViewModel: ModelLoadingViewModel = hiltViewModel(), + benchmarkViewModel: BenchmarkViewModel = hiltViewModel(), conversationViewModel: ConversationViewModel = hiltViewModel(), + modelsManagementViewModel: ModelsManagementViewModel = hiltViewModel(), ) { - val lifecycleOwner = LocalLifecycleOwner.current val coroutineScope = rememberCoroutineScope() val snackbarHostState = remember { SnackbarHostState() } // Inference engine state val engineState by mainViewModel.engineState.collectAsState() - val isModelUninterruptible by remember(engineState) { - derivedStateOf { - engineState is State.LoadingModel - || engineState is State.Benchmarking - || engineState is State.ProcessingUserPrompt - || engineState is State.ProcessingSystemPrompt - } - } - val isModelLoaded by remember(engineState) { - derivedStateOf { - engineState !is State.Uninitialized && engineState !is State.LibraryLoaded - } - } // Metric states for scaffolds val memoryUsage by performanceViewModel.memoryUsage.collectAsState() @@ -117,24 +106,6 @@ fun AppContent( val currentRoute by remember(navBackStackEntry) { derivedStateOf { navBackStackEntry?.destination?.route ?: "" } } - var pendingNavigation by remember { mutableStateOf<(() -> Unit)?>(null) } - - // Model unloading confirmation - var showUnloadDialog by remember { mutableStateOf(false) } - val handleBackWithModelCheck = { - when { - isModelUninterruptible -> { - // If model is non-interruptible at all, ignore the request - } - isModelLoaded -> { - showUnloadDialog = true - pendingNavigation = { navigationActions.navigateUp() } - } - else -> { - navigationActions.navigateUp() - } - } - } // Determine if drawer gestures should be enabled based on route val drawerState = rememberDrawerState(initialValue = DrawerValue.Closed) @@ -164,22 +135,35 @@ fun AppContent( ScaffoldConfig( topBarConfig = TopBarConfig.Performance( title = "Load Model", - navigationIcon = NavigationIcon.Back(handleBackWithModelCheck), + navigationIcon = NavigationIcon.Back { navigationActions.navigateUp() }, memoryMetrics = memoryUsage, temperatureInfo = null ) ) - // Benchmark and Conversation screens - AppDestinations.BENCHMARK_ROUTE, AppDestinations.CONVERSATION_ROUTE -> + // Benchmark screen + AppDestinations.BENCHMARK_ROUTE -> ScaffoldConfig( topBarConfig = TopBarConfig.Performance( - title = when(currentRoute) { - AppDestinations.CONVERSATION_ROUTE -> "Chat" - AppDestinations.BENCHMARK_ROUTE -> "Benchmark" - else -> "LlamaAndroid" - }, - navigationIcon = NavigationIcon.Back(handleBackWithModelCheck), + title = "Benchmark", + navigationIcon = NavigationIcon.Back { + android.util.Log.w("JOJO", "Benchmark navigation icon tapped") + benchmarkViewModel.onBackPressed() + }, + memoryMetrics = memoryUsage, + temperatureInfo = Pair(temperatureInfo, useFahrenheit) + ) + ) + + // Conversation screen + AppDestinations.CONVERSATION_ROUTE -> + ScaffoldConfig( + topBarConfig = TopBarConfig.Performance( + title = "Chat", + navigationIcon = NavigationIcon.Back { + // TODO-han.yin: uncomment after [ConversationViewModel] done + // conversationViewModel.onBackPressed() + }, memoryMetrics = memoryUsage, temperatureInfo = Pair(temperatureInfo, useFahrenheit) ) @@ -292,15 +276,6 @@ fun AppContent( } } - // Register system back handler - BackHandlerSetup( - lifecycleOwner = lifecycleOwner, - backDispatcher = LocalOnBackPressedDispatcherOwner.current?.onBackPressedDispatcher, - currentRoute = currentRoute, - isModelLoaded = isModelLoaded, - handleBackWithModelCheck = handleBackWithModelCheck - ) - // Main UI hierarchy AppNavigationDrawer( drawerState = drawerState, @@ -338,69 +313,46 @@ fun AppContent( engineState = engineState, onBenchmarkSelected = { prepareJob -> // Wait for preparation to complete, then navigate if still active - val loadingJob = coroutineScope.launch { + coroutineScope.launch { prepareJob.join() if (isActive) { navigationActions.navigateToBenchmark() } } - - pendingNavigation = { - prepareJob.cancel() - loadingJob.cancel() - navigationActions.navigateUp() - } }, onConversationSelected = { systemPrompt, prepareJob -> // Wait for preparation to complete, then navigate if still active - val loadingJob = coroutineScope.launch { + coroutineScope.launch { prepareJob.join() if (isActive) { navigationActions.navigateToConversation() } } - - pendingNavigation = { - prepareJob.cancel() - loadingJob.cancel() - navigationActions.navigateUp() - } - }, - onBackPressed = { - // Need to unload model before going back - handleBackWithModelCheck() }, + viewModel = modelLoadingViewModel ) } // Benchmark Screen composable(AppDestinations.BENCHMARK_ROUTE) { BenchmarkScreen( - onBackPressed = { - // Need to unload model before going back - handleBackWithModelCheck() - } + onNavigateBack = { navigationActions.navigateUp() }, + viewModel = benchmarkViewModel ) } // Conversation Screen composable(AppDestinations.CONVERSATION_ROUTE) { ConversationScreen( - onBackPressed = { - // Need to unload model before going back - handleBackWithModelCheck() - }, + onNavigateBack = { navigationActions.navigateUp() }, viewModel = conversationViewModel ) } // Settings General Screen composable(AppDestinations.SETTINGS_GENERAL_ROUTE) { - SettingsGeneralScreen( - onBackPressed = { navigationActions.navigateUp() }, - ) + SettingsGeneralScreen() } // Models Management Screen composable(AppDestinations.MODELS_MANAGEMENT_ROUTE) { ModelsManagementScreen( - onBackPressed = { navigationActions.navigateUp() }, onScaffoldEvent = handleScaffoldEvent, viewModel = modelsManagementViewModel ) @@ -408,74 +360,4 @@ fun AppContent( } } } - - // Model unload confirmation dialog - var isUnloading by remember { mutableStateOf(false) } - - if (showUnloadDialog) { - UnloadModelConfirmationDialog( - onConfirm = { - isUnloading = true - coroutineScope.launch { - // TODO-han.yin: Clear conversation upon normal exiting - // Handle screen specific cleanups - when(engineState) { - is State.Benchmarking -> {} - is State.Generating -> conversationViewModel.clearConversation() - else -> {} - } - - // Unload model - mainViewModel.unloadModel() - isUnloading = false - showUnloadDialog = false - pendingNavigation?.invoke() - pendingNavigation = null - } - }, - onDismiss = { - if (!isUnloading) { - showUnloadDialog = false - pendingNavigation = null - } - }, - isUnloading = isUnloading - ) - } -} - -@Composable -private fun BackHandlerSetup( - lifecycleOwner: LifecycleOwner, - backDispatcher: OnBackPressedDispatcher?, - currentRoute: String, - isModelLoaded: Boolean, - handleBackWithModelCheck: () -> Unit -) { - val routeNeedsModelUnloading = currentRoute in listOf( - AppDestinations.CONVERSATION_ROUTE, - AppDestinations.BENCHMARK_ROUTE, - AppDestinations.MODEL_LOADING_ROUTE - ) - - DisposableEffect(lifecycleOwner, backDispatcher, currentRoute, isModelLoaded) { - android.util.Log.w("JOJO", "BackHandlerSetup: currentRoute: $currentRoute") - - val callback = object : OnBackPressedCallback( - routeNeedsModelUnloading && isModelLoaded - ) { - override fun handleOnBackPressed() { - handleBackWithModelCheck() - } - } - - backDispatcher?.addCallback(lifecycleOwner, callback) - onDispose { callback.remove() } - } - - BackHandler( - enabled = routeNeedsModelUnloading && isModelLoaded - ) { - handleBackWithModelCheck() - } }