diff --git a/examples/llama.android/app/src/main/java/com/example/llama/revamp/MainActivity.kt b/examples/llama.android/app/src/main/java/com/example/llama/revamp/MainActivity.kt index 10004641ce..89aa7d2408 100644 --- a/examples/llama.android/app/src/main/java/com/example/llama/revamp/MainActivity.kt +++ b/examples/llama.android/app/src/main/java/com/example/llama/revamp/MainActivity.kt @@ -21,9 +21,12 @@ import androidx.compose.runtime.remember import androidx.compose.runtime.rememberCoroutineScope import androidx.compose.ui.Modifier import androidx.hilt.navigation.compose.hiltViewModel +import androidx.navigation.NavType import androidx.navigation.compose.composable import androidx.navigation.compose.currentBackStackEntryAsState import androidx.navigation.compose.rememberNavController +import androidx.navigation.navArgument +import com.example.llama.revamp.engine.ModelLoadingMetrics import com.example.llama.revamp.navigation.AppDestinations import com.example.llama.revamp.navigation.NavigationActions import com.example.llama.revamp.ui.components.AnimatedNavHost @@ -48,7 +51,6 @@ import com.example.llama.revamp.viewmodel.ModelLoadingViewModel import com.example.llama.revamp.viewmodel.ModelsManagementViewModel import com.example.llama.revamp.viewmodel.PerformanceViewModel import dagger.hilt.android.AndroidEntryPoint -import kotlinx.coroutines.isActive import kotlinx.coroutines.launch @AndroidEntryPoint @@ -110,9 +112,9 @@ fun AppContent( val openDrawer: () -> Unit = { coroutineScope.launch { drawerState.open() } } // Create scaffold's top & bottom bar configs based on current route - val scaffoldConfig = when (currentRoute) { + val scaffoldConfig = when { // Model selection screen - AppDestinations.MODEL_SELECTION_ROUTE -> + currentRoute == AppDestinations.MODEL_SELECTION_ROUTE -> ScaffoldConfig( topBarConfig = TopBarConfig.Default( title = "Models", @@ -121,7 +123,7 @@ fun AppContent( ) // Model loading screen - AppDestinations.MODEL_LOADING_ROUTE -> + currentRoute == AppDestinations.MODEL_LOADING_ROUTE -> ScaffoldConfig( topBarConfig = TopBarConfig.Performance( title = "Load Model", @@ -134,7 +136,7 @@ fun AppContent( ) // Benchmark screen - AppDestinations.BENCHMARK_ROUTE -> + currentRoute.startsWith(AppDestinations.BENCHMARK_ROUTE) -> ScaffoldConfig( topBarConfig = TopBarConfig.Performance( title = "Benchmark", @@ -147,7 +149,7 @@ fun AppContent( ) // Conversation screen - AppDestinations.CONVERSATION_ROUTE -> + currentRoute.startsWith(AppDestinations.CONVERSATION_ROUTE) -> ScaffoldConfig( topBarConfig = TopBarConfig.Performance( title = "Chat", @@ -160,7 +162,7 @@ fun AppContent( ) // Settings screen - AppDestinations.SETTINGS_GENERAL_ROUTE -> + currentRoute == AppDestinations.SETTINGS_GENERAL_ROUTE -> ScaffoldConfig( topBarConfig = TopBarConfig.Default( title = "Settings", @@ -169,7 +171,7 @@ fun AppContent( ) // Storage management screen - AppDestinations.MODELS_MANAGEMENT_ROUTE -> { + currentRoute == AppDestinations.MODELS_MANAGEMENT_ROUTE -> { // Collect the needed states val sortOrder by modelsManagementViewModel.sortOrder.collectAsState() val isMultiSelectionMode by modelsManagementViewModel.isMultiSelectionMode.collectAsState() @@ -301,35 +303,59 @@ fun AppContent( composable(AppDestinations.MODEL_LOADING_ROUTE) { ModelLoadingScreen( onNavigateBack = { navigationActions.navigateUp() }, - onBenchmarkSelected = { prepareJob -> - // Wait for preparation to complete, then navigate if still active - coroutineScope.launch { - prepareJob.join() - if (isActive) { navigationActions.navigateToBenchmark() } - } - }, - onConversationSelected = { systemPrompt, prepareJob -> - // Wait for preparation to complete, then navigate if still active - coroutineScope.launch { - prepareJob.join() - if (isActive) { navigationActions.navigateToConversation() } - } - }, + onNavigateToBenchmark = { navigationActions.navigateToBenchmark(it) }, + onNavigateToConversation = { navigationActions.navigateToConversation(it) }, viewModel = modelLoadingViewModel ) } // Benchmark Screen - composable(AppDestinations.BENCHMARK_ROUTE) { + composable( + route = AppDestinations.BENCHMARK_ROUTE_WITH_PARAMS, + arguments = listOf( + navArgument("modelLoadTimeMs") { + type = NavType.LongType + defaultValue = 0L + } + ) + ) { backStackEntry -> + val modelLoadTimeMs = backStackEntry.arguments?.getLong("modelLoadTimeMs") ?: 0L + val metrics = if (modelLoadTimeMs > 0) { + ModelLoadingMetrics(modelLoadTimeMs) + } else throw IllegalArgumentException("Expecting a valid ModelLoadingMetrics!") + BenchmarkScreen( + loadingMetrics = metrics, onNavigateBack = { navigationActions.navigateUp() }, viewModel = benchmarkViewModel ) } // Conversation Screen - composable(AppDestinations.CONVERSATION_ROUTE) { + composable( + route = AppDestinations.CONVERSATION_ROUTE_WITH_PARAMS, + arguments = listOf( + navArgument("modelLoadTimeMs") { + type = NavType.LongType + defaultValue = 0L + }, + navArgument("promptTimeMs") { + type = NavType.LongType + defaultValue = 0L + } + ) + ) { backStackEntry -> + val modelLoadTimeMs = backStackEntry.arguments?.getLong("modelLoadTimeMs") ?: 0L + val promptTimeMs = backStackEntry.arguments?.getLong("promptTimeMs") ?: 0L + val metrics = if (modelLoadTimeMs > 0) { + ModelLoadingMetrics( + modelLoadingTimeMs = modelLoadTimeMs, + systemPromptProcessingTimeMs = if (promptTimeMs > 0) promptTimeMs else null + ) + } else throw IllegalArgumentException("Expecting a valid ModelLoadingMetrics!") + ConversationScreen( + loadingMetrics = metrics, onNavigateBack = { navigationActions.navigateUp() }, viewModel = conversationViewModel ) diff --git a/examples/llama.android/app/src/main/java/com/example/llama/revamp/navigation/AppDestinations.kt b/examples/llama.android/app/src/main/java/com/example/llama/revamp/navigation/AppDestinations.kt index e10b65f2a4..9392dab2d2 100644 --- a/examples/llama.android/app/src/main/java/com/example/llama/revamp/navigation/AppDestinations.kt +++ b/examples/llama.android/app/src/main/java/com/example/llama/revamp/navigation/AppDestinations.kt @@ -1,6 +1,7 @@ package com.example.llama.revamp.navigation import androidx.navigation.NavController +import com.example.llama.revamp.engine.ModelLoadingMetrics /** * Navigation destinations for the app @@ -9,10 +10,14 @@ object AppDestinations { // Primary navigation destinations const val MODEL_SELECTION_ROUTE = "model_selection" const val MODEL_LOADING_ROUTE = "model_loading" - const val CONVERSATION_ROUTE = "conversation" - const val BENCHMARK_ROUTE = "benchmark" - // Settings destinations (moved from tabs to separate routes) + const val CONVERSATION_ROUTE = "conversation" + const val CONVERSATION_ROUTE_WITH_PARAMS = "conversation/{modelLoadTimeMs}/{promptTimeMs}" + + const val BENCHMARK_ROUTE = "benchmark" + const val BENCHMARK_ROUTE_WITH_PARAMS = "benchmark/{modelLoadTimeMs}" + + // Settings destinations const val SETTINGS_GENERAL_ROUTE = "settings_general" const val MODELS_MANAGEMENT_ROUTE = "models_management" } @@ -33,12 +38,17 @@ class NavigationActions(private val navController: NavController) { navController.navigate(AppDestinations.MODEL_LOADING_ROUTE) } - fun navigateToConversation() { - navController.navigate(AppDestinations.CONVERSATION_ROUTE) + fun navigateToConversation(metrics: ModelLoadingMetrics) { + val route = AppDestinations.CONVERSATION_ROUTE + val modelLoadTimeMs = metrics.modelLoadingTimeMs + val promptTimeMs = metrics.systemPromptProcessingTimeMs ?: 0 + navController.navigate("$route/$modelLoadTimeMs/$promptTimeMs") } - fun navigateToBenchmark() { - navController.navigate(AppDestinations.BENCHMARK_ROUTE) + fun navigateToBenchmark(metrics: ModelLoadingMetrics) { + val route = AppDestinations.BENCHMARK_ROUTE + val modelLoadTimeMs = metrics.modelLoadingTimeMs + navController.navigate("$route/$modelLoadTimeMs") } fun navigateToSettingsGeneral() { diff --git a/examples/llama.android/app/src/main/java/com/example/llama/revamp/ui/screens/BenchmarkScreen.kt b/examples/llama.android/app/src/main/java/com/example/llama/revamp/ui/screens/BenchmarkScreen.kt index 768f60ef23..36297f2d89 100644 --- a/examples/llama.android/app/src/main/java/com/example/llama/revamp/ui/screens/BenchmarkScreen.kt +++ b/examples/llama.android/app/src/main/java/com/example/llama/revamp/ui/screens/BenchmarkScreen.kt @@ -24,6 +24,7 @@ import androidx.compose.runtime.getValue import androidx.compose.ui.Alignment import androidx.compose.ui.Modifier import androidx.compose.ui.unit.dp +import com.example.llama.revamp.engine.ModelLoadingMetrics import com.example.llama.revamp.ui.components.ModelCard import com.example.llama.revamp.ui.components.ModelUnloadDialogHandler import com.example.llama.revamp.ui.theme.MonospacedTextStyle @@ -31,6 +32,7 @@ import com.example.llama.revamp.viewmodel.BenchmarkViewModel @Composable fun BenchmarkScreen( + loadingMetrics: ModelLoadingMetrics, onNavigateBack: () -> Unit, viewModel: BenchmarkViewModel ) { diff --git a/examples/llama.android/app/src/main/java/com/example/llama/revamp/ui/screens/ConversationScreen.kt b/examples/llama.android/app/src/main/java/com/example/llama/revamp/ui/screens/ConversationScreen.kt index ff95861ee8..19584f2dd6 100644 --- a/examples/llama.android/app/src/main/java/com/example/llama/revamp/ui/screens/ConversationScreen.kt +++ b/examples/llama.android/app/src/main/java/com/example/llama/revamp/ui/screens/ConversationScreen.kt @@ -57,6 +57,7 @@ import androidx.lifecycle.Lifecycle import androidx.lifecycle.LifecycleEventObserver import androidx.lifecycle.compose.LocalLifecycleOwner import com.example.llama.revamp.APP_NAME +import com.example.llama.revamp.engine.ModelLoadingMetrics import com.example.llama.revamp.ui.components.ModelCardWithSystemPrompt import com.example.llama.revamp.ui.components.ModelUnloadDialogHandler import com.example.llama.revamp.viewmodel.ConversationViewModel @@ -68,6 +69,7 @@ import kotlinx.coroutines.launch */ @Composable fun ConversationScreen( + loadingMetrics: ModelLoadingMetrics, onNavigateBack: () -> Unit, viewModel: ConversationViewModel ) { diff --git a/examples/llama.android/app/src/main/java/com/example/llama/revamp/ui/screens/ModelLoadingScreen.kt b/examples/llama.android/app/src/main/java/com/example/llama/revamp/ui/screens/ModelLoadingScreen.kt index 3162f886d9..132a80afd1 100644 --- a/examples/llama.android/app/src/main/java/com/example/llama/revamp/ui/screens/ModelLoadingScreen.kt +++ b/examples/llama.android/app/src/main/java/com/example/llama/revamp/ui/screens/ModelLoadingScreen.kt @@ -42,7 +42,6 @@ import androidx.compose.runtime.collectAsState import androidx.compose.runtime.getValue import androidx.compose.runtime.mutableStateOf import androidx.compose.runtime.remember -import androidx.compose.runtime.rememberCoroutineScope import androidx.compose.runtime.setValue import androidx.compose.ui.Alignment import androidx.compose.ui.Modifier @@ -50,11 +49,10 @@ import androidx.compose.ui.semantics.Role import androidx.compose.ui.text.style.TextOverflow import androidx.compose.ui.unit.dp import com.example.llama.revamp.data.model.SystemPrompt +import com.example.llama.revamp.engine.ModelLoadingMetrics import com.example.llama.revamp.ui.components.ModelCard import com.example.llama.revamp.ui.components.ModelUnloadDialogHandler import com.example.llama.revamp.viewmodel.ModelLoadingViewModel -import kotlinx.coroutines.Job -import kotlinx.coroutines.launch enum class Mode { @@ -70,12 +68,10 @@ enum class SystemPromptTab(val label: String) { @Composable fun ModelLoadingScreen( onNavigateBack: () -> Unit, - onBenchmarkSelected: (prepareJob: Job) -> Unit, - onConversationSelected: (systemPrompt: String?, prepareJob: Job) -> Unit, + onNavigateToBenchmark: (ModelLoadingMetrics) -> Unit, + onNavigateToConversation: (ModelLoadingMetrics) -> Unit, viewModel: ModelLoadingViewModel, ) { - val coroutineScope = rememberCoroutineScope() - val engineState by viewModel.engineState.collectAsState() val selectedModel by viewModel.selectedModel.collectAsState() val presetPrompts by viewModel.presetPrompts.collectAsState() @@ -108,22 +104,6 @@ fun ModelLoadingScreen( // Check if we're in a loading state val isLoading = engineState !is State.Initialized && engineState !is State.ModelReady - // Mode selection callbacks - val handleBenchmarkSelected = { - val prepareJob = coroutineScope.launch { - viewModel.prepareForBenchmark() - } - onBenchmarkSelected(prepareJob) - } - - // TODO-han.yin: refactor this into ViewModel too - val handleConversationSelected = { systemPrompt: String? -> - val prepareJob = coroutineScope.launch { - viewModel.prepareForConversation(systemPrompt) - } - onConversationSelected(systemPrompt, prepareJob) - } - // Handle back navigation requests BackHandler { viewModel.onBackPressed(onNavigateBack) @@ -301,7 +281,7 @@ fun ModelLoadingScreen( Button( onClick = { when (selectedMode) { - Mode.BENCHMARK -> handleBenchmarkSelected() + Mode.BENCHMARK -> viewModel.onBenchmarkSelected(onNavigateToBenchmark) Mode.CONVERSATION -> { val systemPrompt = if (useSystemPrompt) { @@ -324,7 +304,7 @@ fun ModelLoadingScreen( } } else null - handleConversationSelected(systemPrompt) + viewModel.onConversationSelected(systemPrompt, onNavigateToConversation) } null -> { /* No mode selected */ diff --git a/examples/llama.android/app/src/main/java/com/example/llama/revamp/viewmodel/ModelLoadingViewModel.kt b/examples/llama.android/app/src/main/java/com/example/llama/revamp/viewmodel/ModelLoadingViewModel.kt index d12cfcb36a..6e15a6e359 100644 --- a/examples/llama.android/app/src/main/java/com/example/llama/revamp/viewmodel/ModelLoadingViewModel.kt +++ b/examples/llama.android/app/src/main/java/com/example/llama/revamp/viewmodel/ModelLoadingViewModel.kt @@ -3,6 +3,7 @@ package com.example.llama.revamp.viewmodel import androidx.lifecycle.viewModelScope import com.example.llama.revamp.data.model.SystemPrompt import com.example.llama.revamp.data.repository.SystemPromptRepository +import com.example.llama.revamp.engine.ModelLoadingMetrics import com.example.llama.revamp.engine.ModelLoadingService import dagger.hilt.android.lifecycle.HiltViewModel import kotlinx.coroutines.flow.SharingStarted @@ -79,17 +80,23 @@ class ModelLoadingViewModel @Inject constructor( } /** - * Prepares the engine for benchmark mode. + * Loads the model, then navigate to [BenchmarkScreen] with [ModelLoadingMetrics] */ - suspend fun prepareForBenchmark() = - modelLoadingService.loadModelForBenchmark() + fun onBenchmarkSelected(onNavigateToBenchmark: (ModelLoadingMetrics) -> Unit) = + viewModelScope.launch { + onNavigateToBenchmark(modelLoadingService.loadModelForBenchmark()) + } /** - * Prepare for conversation + * Loads the model, process system prompt if any, + * then navigate to [ConversationScreen] with [ModelLoadingMetrics] */ - suspend fun prepareForConversation(systemPrompt: String? = null) = - modelLoadingService.loadModelForConversation(systemPrompt) - + fun onConversationSelected( + systemPrompt: String? = null, + onNavigateToConversation: (ModelLoadingMetrics) -> Unit + ) = viewModelScope.launch { + onNavigateToConversation(modelLoadingService.loadModelForConversation(systemPrompt)) + } companion object { private val TAG = ModelLoadingViewModel::class.java.simpleName