diff --git a/examples/llama.android/app/src/main/java/com/example/llama/revamp/MainActivity.kt b/examples/llama.android/app/src/main/java/com/example/llama/revamp/MainActivity.kt index 5ef00a399c..8c487ada09 100644 --- a/examples/llama.android/app/src/main/java/com/example/llama/revamp/MainActivity.kt +++ b/examples/llama.android/app/src/main/java/com/example/llama/revamp/MainActivity.kt @@ -22,13 +22,12 @@ import androidx.compose.runtime.remember import androidx.compose.runtime.rememberCoroutineScope import androidx.compose.runtime.setValue import androidx.compose.ui.Modifier +import androidx.hilt.navigation.compose.hiltViewModel import androidx.lifecycle.compose.LocalLifecycleOwner -import androidx.lifecycle.viewmodel.compose.viewModel import androidx.navigation.compose.NavHost import androidx.navigation.compose.composable import androidx.navigation.compose.currentBackStackEntryAsState import androidx.navigation.compose.rememberNavController -import com.example.llama.revamp.engine.InferenceEngine import com.example.llama.revamp.navigation.AppDestinations import com.example.llama.revamp.navigation.NavigationActions import com.example.llama.revamp.ui.components.AppNavigationDrawer @@ -40,11 +39,12 @@ import com.example.llama.revamp.ui.screens.ModelsManagementScreen import com.example.llama.revamp.ui.screens.ModelLoadingScreen import com.example.llama.revamp.ui.screens.SettingsGeneralScreen import com.example.llama.revamp.ui.theme.LlamaTheme -import com.example.llama.revamp.util.ViewModelFactoryProvider import com.example.llama.revamp.viewmodel.MainViewModel +import dagger.hilt.android.AndroidEntryPoint import kotlinx.coroutines.isActive import kotlinx.coroutines.launch +@AndroidEntryPoint class MainActivity : ComponentActivity() { override fun onCreate(savedInstanceState: Bundle?) { super.onCreate(savedInstanceState) @@ -62,26 +62,19 @@ class MainActivity : ComponentActivity() { } @Composable -fun AppContent() { - val navController = rememberNavController() - val drawerState = rememberDrawerState(initialValue = DrawerValue.Closed) +fun AppContent( + mainVewModel: MainViewModel = hiltViewModel() +) { val coroutineScope = rememberCoroutineScope() - // Create inference engine - val inferenceEngine = remember { InferenceEngine() } + val navController = rememberNavController() + val navigationActions = remember(navController) { NavigationActions(navController) } + val drawerState = rememberDrawerState(initialValue = DrawerValue.Closed) - // Create factory for MainViewModel - val factory = remember { ViewModelFactoryProvider.getMainViewModelFactory(inferenceEngine) } + val engineState by mainVewModel.engineState.collectAsState() + // TODO-han.yin: Also use delegate for `isModelLoaded`: + val isModelLoaded = remember(engineState) { mainVewModel.isModelLoaded() } - // Get ViewModel instance with factory - val viewModel: MainViewModel = viewModel(factory = factory) - - val engineState by viewModel.engineState.collectAsState() - val isModelLoaded = remember(engineState) { viewModel.isModelLoaded() } - - val navigationActions = remember(navController) { - NavigationActions(navController) - } // Model unloading confirmation var showUnloadDialog by remember { mutableStateOf(false) } @@ -123,10 +116,10 @@ fun AppContent() { // Helper function to handle back press with model unloading check val handleBackWithModelCheck = { - if (viewModel.isModelLoading()) { + if (mainVewModel.isModelLoading()) { // If model is still loading, ignore the request true // Mark as handled - } else if (viewModel.isModelLoaded()) { + } else if (mainVewModel.isModelLoaded()) { showUnloadDialog = true pendingNavigation = { navController.popBackStack() } true // Mark as handled @@ -194,7 +187,7 @@ fun AppContent() { composable(AppDestinations.MODEL_SELECTION_ROUTE) { ModelSelectionScreen( onModelSelected = { modelInfo -> - viewModel.selectModel(modelInfo) + mainVewModel.selectModel(modelInfo) navigationActions.navigateToModelLoading() }, onManageModelsClicked = { @@ -211,13 +204,13 @@ fun AppContent() { ModelLoadingScreen( engineState = engineState, onBenchmarkSelected = { - viewModel.prepareForBenchmark() + mainVewModel.prepareForBenchmark() navigationActions.navigateToBenchmark() }, onConversationSelected = { systemPrompt -> // Store a reference to the loading job val loadingJob = coroutineScope.launch { - viewModel.prepareForConversation(systemPrompt) + mainVewModel.prepareForConversation(systemPrompt) // Check if the job wasn't cancelled before navigating if (isActive) { navigationActions.navigateToConversation() @@ -245,7 +238,7 @@ fun AppContent() { // Need to unload model before going back handleBackWithModelCheck() }, - viewModel = viewModel + viewModel = mainVewModel ) } @@ -257,14 +250,14 @@ fun AppContent() { handleBackWithModelCheck() }, onRerunPressed = { - viewModel.rerunBenchmark() + mainVewModel.rerunBenchmark() }, onSharePressed = { // Stub for sharing functionality }, drawerState = drawerState, navigationActions = navigationActions, - viewModel = viewModel + viewModel = mainVewModel ) } @@ -295,7 +288,7 @@ fun AppContent() { onConfirm = { isUnloading = true coroutineScope.launch { - viewModel.unloadModel() + mainVewModel.unloadModel() isUnloading = false showUnloadDialog = false pendingNavigation?.invoke() diff --git a/examples/llama.android/app/src/main/java/com/example/llama/revamp/ui/components/AppScaffolds.kt b/examples/llama.android/app/src/main/java/com/example/llama/revamp/ui/components/AppScaffolds.kt index 7955b18b7c..b564b5e108 100644 --- a/examples/llama.android/app/src/main/java/com/example/llama/revamp/ui/components/AppScaffolds.kt +++ b/examples/llama.android/app/src/main/java/com/example/llama/revamp/ui/components/AppScaffolds.kt @@ -8,11 +8,7 @@ import androidx.compose.runtime.Composable import androidx.compose.runtime.collectAsState import androidx.compose.runtime.getValue import androidx.compose.runtime.remember -import androidx.compose.ui.platform.LocalContext -import androidx.lifecycle.viewmodel.compose.viewModel -import com.example.llama.revamp.data.preferences.UserPreferences -import com.example.llama.revamp.monitoring.PerformanceMonitor -import com.example.llama.revamp.util.ViewModelFactoryProvider +import androidx.hilt.navigation.compose.hiltViewModel import com.example.llama.revamp.viewmodel.PerformanceViewModel // DefaultAppScaffold.kt @@ -42,6 +38,7 @@ fun DefaultAppScaffold( // PerformanceAppScaffold.kt @Composable fun PerformanceAppScaffold( + performanceViewModel: PerformanceViewModel = hiltViewModel(), title: String, onNavigateBack: (() -> Unit)? = null, onMenuOpen: (() -> Unit)? = null, @@ -49,17 +46,6 @@ fun PerformanceAppScaffold( snackbarHostState: SnackbarHostState = remember { SnackbarHostState() }, content: @Composable (PaddingValues) -> Unit ) { - // Create dependencies for PerformanceViewModel - val context = LocalContext.current - val performanceMonitor = remember { PerformanceMonitor(context) } - val userPreferences = remember { UserPreferences(context) } - - // Create factory for PerformanceViewModel - val factory = remember { ViewModelFactoryProvider.getPerformanceViewModelFactory(performanceMonitor, userPreferences) } - - // Get ViewModel instance with factory - val performanceViewModel: PerformanceViewModel = viewModel(factory = factory) - // Collect performance metrics val memoryUsage by performanceViewModel.memoryUsage.collectAsState() val temperatureInfo by performanceViewModel.temperatureMetrics.collectAsState() diff --git a/examples/llama.android/app/src/main/java/com/example/llama/revamp/ui/screens/BenchmarkScreen.kt b/examples/llama.android/app/src/main/java/com/example/llama/revamp/ui/screens/BenchmarkScreen.kt index 6b5daca836..ca16ce5504 100644 --- a/examples/llama.android/app/src/main/java/com/example/llama/revamp/ui/screens/BenchmarkScreen.kt +++ b/examples/llama.android/app/src/main/java/com/example/llama/revamp/ui/screens/BenchmarkScreen.kt @@ -22,7 +22,7 @@ import androidx.compose.runtime.getValue import androidx.compose.ui.Alignment import androidx.compose.ui.Modifier import androidx.compose.ui.unit.dp -import androidx.lifecycle.viewmodel.compose.viewModel +import androidx.hilt.navigation.compose.hiltViewModel import com.example.llama.revamp.engine.InferenceEngine import com.example.llama.revamp.navigation.NavigationActions import com.example.llama.revamp.ui.components.PerformanceAppScaffold @@ -36,7 +36,7 @@ fun BenchmarkScreen( onSharePressed: () -> Unit, drawerState: DrawerState, navigationActions: NavigationActions, - viewModel: MainViewModel = viewModel() + viewModel: MainViewModel = hiltViewModel() ) { val engineState by viewModel.engineState.collectAsState() val benchmarkResults by viewModel.benchmarkResults.collectAsState() diff --git a/examples/llama.android/app/src/main/java/com/example/llama/revamp/ui/screens/ConversationScreen.kt b/examples/llama.android/app/src/main/java/com/example/llama/revamp/ui/screens/ConversationScreen.kt index 57de254d5b..7fad8b8773 100644 --- a/examples/llama.android/app/src/main/java/com/example/llama/revamp/ui/screens/ConversationScreen.kt +++ b/examples/llama.android/app/src/main/java/com/example/llama/revamp/ui/screens/ConversationScreen.kt @@ -57,6 +57,7 @@ import androidx.compose.ui.graphics.Color import androidx.compose.ui.graphics.StrokeCap import androidx.compose.ui.platform.LocalLifecycleOwner import androidx.compose.ui.unit.dp +import androidx.hilt.navigation.compose.hiltViewModel import androidx.lifecycle.Lifecycle import androidx.lifecycle.LifecycleEventObserver import com.example.llama.revamp.engine.InferenceEngine @@ -71,7 +72,7 @@ import kotlinx.coroutines.launch @Composable fun ConversationScreen( onBackPressed: () -> Unit, - viewModel: MainViewModel + viewModel: MainViewModel = hiltViewModel() ) { val engineState by viewModel.engineState.collectAsState() val messages by viewModel.messages.collectAsState() diff --git a/examples/llama.android/app/src/main/java/com/example/llama/revamp/ui/screens/ModelLoadingScreen.kt b/examples/llama.android/app/src/main/java/com/example/llama/revamp/ui/screens/ModelLoadingScreen.kt index 7cc252ba48..80ecd745ad 100644 --- a/examples/llama.android/app/src/main/java/com/example/llama/revamp/ui/screens/ModelLoadingScreen.kt +++ b/examples/llama.android/app/src/main/java/com/example/llama/revamp/ui/screens/ModelLoadingScreen.kt @@ -44,17 +44,14 @@ import androidx.compose.runtime.remember import androidx.compose.runtime.setValue import androidx.compose.ui.Alignment import androidx.compose.ui.Modifier -import androidx.compose.ui.platform.LocalContext import androidx.compose.ui.semantics.Role import androidx.compose.ui.text.style.TextOverflow import androidx.compose.ui.unit.dp -import androidx.lifecycle.viewmodel.compose.viewModel +import androidx.hilt.navigation.compose.hiltViewModel import com.example.llama.revamp.data.model.SystemPrompt -import com.example.llama.revamp.data.repository.SystemPromptRepository import com.example.llama.revamp.engine.InferenceEngine import com.example.llama.revamp.navigation.NavigationActions import com.example.llama.revamp.ui.components.PerformanceAppScaffold -import com.example.llama.revamp.util.ViewModelFactoryProvider import com.example.llama.revamp.viewmodel.SystemPromptViewModel enum class SystemPromptTab { @@ -64,6 +61,7 @@ enum class SystemPromptTab { @OptIn(ExperimentalMaterial3Api::class, ExperimentalFoundationApi::class) @Composable fun ModelLoadingScreen( + viewModel: SystemPromptViewModel = hiltViewModel(), engineState: InferenceEngine.State, onBenchmarkSelected: () -> Unit, onConversationSelected: (String?) -> Unit, @@ -71,12 +69,6 @@ fun ModelLoadingScreen( drawerState: DrawerState, navigationActions: NavigationActions ) { - // Set up SystemPromptViewModel - val context = LocalContext.current - val repository = remember { SystemPromptRepository(context) } - val factory = remember { ViewModelFactoryProvider.getSystemPromptViewModelFactory(repository) } - val viewModel: SystemPromptViewModel = viewModel(factory = factory) - val presetPrompts by viewModel.presetPrompts.collectAsState() val recentPrompts by viewModel.recentPrompts.collectAsState() diff --git a/examples/llama.android/app/src/main/java/com/example/llama/revamp/ui/screens/ModelsManagementScreen.kt b/examples/llama.android/app/src/main/java/com/example/llama/revamp/ui/screens/ModelsManagementScreen.kt index ae9d0161dd..74f912c0d7 100644 --- a/examples/llama.android/app/src/main/java/com/example/llama/revamp/ui/screens/ModelsManagementScreen.kt +++ b/examples/llama.android/app/src/main/java/com/example/llama/revamp/ui/screens/ModelsManagementScreen.kt @@ -10,11 +10,7 @@ import androidx.compose.foundation.layout.padding import androidx.compose.foundation.lazy.LazyColumn import androidx.compose.foundation.lazy.items import androidx.compose.material.icons.Icons -import androidx.compose.material.icons.filled.CloudDownload import androidx.compose.material.icons.filled.Delete -import androidx.compose.material.icons.filled.Done -import androidx.compose.material.icons.filled.Edit -import androidx.compose.material.icons.filled.FileOpen import androidx.compose.material.icons.filled.Info import androidx.compose.material3.Card import androidx.compose.material3.CardDefaults diff --git a/examples/llama.android/app/src/main/java/com/example/llama/revamp/ui/screens/SettingsGeneralScreen.kt b/examples/llama.android/app/src/main/java/com/example/llama/revamp/ui/screens/SettingsGeneralScreen.kt index 99a6a2f733..b8649455ab 100644 --- a/examples/llama.android/app/src/main/java/com/example/llama/revamp/ui/screens/SettingsGeneralScreen.kt +++ b/examples/llama.android/app/src/main/java/com/example/llama/revamp/ui/screens/SettingsGeneralScreen.kt @@ -18,17 +18,12 @@ import androidx.compose.material3.Text import androidx.compose.runtime.Composable import androidx.compose.runtime.collectAsState import androidx.compose.runtime.getValue -import androidx.compose.runtime.remember import androidx.compose.ui.Alignment import androidx.compose.ui.Modifier -import androidx.compose.ui.platform.LocalContext import androidx.compose.ui.unit.dp -import androidx.lifecycle.viewmodel.compose.viewModel -import com.example.llama.revamp.data.preferences.UserPreferences -import com.example.llama.revamp.monitoring.PerformanceMonitor +import androidx.hilt.navigation.compose.hiltViewModel import com.example.llama.revamp.navigation.NavigationActions import com.example.llama.revamp.ui.components.DefaultAppScaffold -import com.example.llama.revamp.util.ViewModelFactoryProvider import com.example.llama.revamp.viewmodel.PerformanceViewModel /** @@ -36,22 +31,12 @@ import com.example.llama.revamp.viewmodel.PerformanceViewModel */ @Composable fun SettingsGeneralScreen( + performanceViewModel: PerformanceViewModel = hiltViewModel(), onBackPressed: () -> Unit, drawerState: DrawerState, navigationActions: NavigationActions, onMenuClicked: () -> Unit ) { - // Create dependencies for PerformanceViewModel - val context = LocalContext.current - val performanceMonitor = remember { PerformanceMonitor(context) } - val userPreferences = remember { UserPreferences(context) } - - // Create factory for PerformanceViewModel - val factory = remember { ViewModelFactoryProvider.getPerformanceViewModelFactory(performanceMonitor, userPreferences) } - - // Get ViewModel instance with factory - val performanceViewModel: PerformanceViewModel = viewModel(factory = factory) - // Collect state from ViewModel val isMonitoringEnabled by performanceViewModel.isMonitoringEnabled.collectAsState() val useFahrenheit by performanceViewModel.useFahrenheitUnit.collectAsState() diff --git a/examples/llama.android/app/src/main/java/com/example/llama/revamp/util/ViewModelFactoryProvider.kt b/examples/llama.android/app/src/main/java/com/example/llama/revamp/util/ViewModelFactoryProvider.kt deleted file mode 100644 index 2b7e72fe14..0000000000 --- a/examples/llama.android/app/src/main/java/com/example/llama/revamp/util/ViewModelFactoryProvider.kt +++ /dev/null @@ -1,69 +0,0 @@ -package com.example.llama.revamp.util - -import androidx.lifecycle.ViewModel -import androidx.lifecycle.ViewModelProvider -import com.example.llama.revamp.data.preferences.UserPreferences -import com.example.llama.revamp.data.repository.SystemPromptRepository -import com.example.llama.revamp.engine.InferenceEngine -import com.example.llama.revamp.monitoring.PerformanceMonitor -import com.example.llama.revamp.viewmodel.MainViewModel -import com.example.llama.revamp.viewmodel.PerformanceViewModel -import com.example.llama.revamp.viewmodel.SystemPromptViewModel - -/** - * Utility class to provide ViewModel factories. - */ -object ViewModelFactoryProvider { - - /** - * Creates a factory for PerformanceViewModel. - */ - fun getPerformanceViewModelFactory( - performanceMonitor: PerformanceMonitor, - userPreferences: UserPreferences - ): ViewModelProvider.Factory { - return object : ViewModelProvider.Factory { - @Suppress("UNCHECKED_CAST") - override fun create(modelClass: Class): T { - if (modelClass.isAssignableFrom(PerformanceViewModel::class.java)) { - return PerformanceViewModel(performanceMonitor, userPreferences) as T - } - throw IllegalArgumentException("Unknown ViewModel class: ${modelClass.name}") - } - } - } - - /** - * Creates a factory for MainViewModel. - */ - fun getMainViewModelFactory( - inferenceEngine: InferenceEngine - ): ViewModelProvider.Factory { - return object : ViewModelProvider.Factory { - @Suppress("UNCHECKED_CAST") - override fun create(modelClass: Class): T { - if (modelClass.isAssignableFrom(MainViewModel::class.java)) { - return MainViewModel(inferenceEngine) as T - } - throw IllegalArgumentException("Unknown ViewModel class: ${modelClass.name}") - } - } - } - - /** - * Creates a factory for SystemPromptViewModel. - */ - fun getSystemPromptViewModelFactory( - repository: SystemPromptRepository - ): ViewModelProvider.Factory { - return object : ViewModelProvider.Factory { - @Suppress("UNCHECKED_CAST") - override fun create(modelClass: Class): T { - if (modelClass.isAssignableFrom(SystemPromptViewModel::class.java)) { - return SystemPromptViewModel(repository) as T - } - throw IllegalArgumentException("Unknown ViewModel class: ${modelClass.name}") - } - } - } -}