DI: replace manual DI with Hilt DI

This commit is contained in:
Han Yin 2025-04-13 18:08:38 -07:00
parent a1f6e7e476
commit 0afd087f35
8 changed files with 31 additions and 147 deletions

View File

@ -22,13 +22,12 @@ import androidx.compose.runtime.remember
import androidx.compose.runtime.rememberCoroutineScope import androidx.compose.runtime.rememberCoroutineScope
import androidx.compose.runtime.setValue import androidx.compose.runtime.setValue
import androidx.compose.ui.Modifier import androidx.compose.ui.Modifier
import androidx.hilt.navigation.compose.hiltViewModel
import androidx.lifecycle.compose.LocalLifecycleOwner import androidx.lifecycle.compose.LocalLifecycleOwner
import androidx.lifecycle.viewmodel.compose.viewModel
import androidx.navigation.compose.NavHost import androidx.navigation.compose.NavHost
import androidx.navigation.compose.composable import androidx.navigation.compose.composable
import androidx.navigation.compose.currentBackStackEntryAsState import androidx.navigation.compose.currentBackStackEntryAsState
import androidx.navigation.compose.rememberNavController import androidx.navigation.compose.rememberNavController
import com.example.llama.revamp.engine.InferenceEngine
import com.example.llama.revamp.navigation.AppDestinations import com.example.llama.revamp.navigation.AppDestinations
import com.example.llama.revamp.navigation.NavigationActions import com.example.llama.revamp.navigation.NavigationActions
import com.example.llama.revamp.ui.components.AppNavigationDrawer import com.example.llama.revamp.ui.components.AppNavigationDrawer
@ -40,11 +39,12 @@ import com.example.llama.revamp.ui.screens.ModelsManagementScreen
import com.example.llama.revamp.ui.screens.ModelLoadingScreen import com.example.llama.revamp.ui.screens.ModelLoadingScreen
import com.example.llama.revamp.ui.screens.SettingsGeneralScreen import com.example.llama.revamp.ui.screens.SettingsGeneralScreen
import com.example.llama.revamp.ui.theme.LlamaTheme import com.example.llama.revamp.ui.theme.LlamaTheme
import com.example.llama.revamp.util.ViewModelFactoryProvider
import com.example.llama.revamp.viewmodel.MainViewModel import com.example.llama.revamp.viewmodel.MainViewModel
import dagger.hilt.android.AndroidEntryPoint
import kotlinx.coroutines.isActive import kotlinx.coroutines.isActive
import kotlinx.coroutines.launch import kotlinx.coroutines.launch
@AndroidEntryPoint
class MainActivity : ComponentActivity() { class MainActivity : ComponentActivity() {
override fun onCreate(savedInstanceState: Bundle?) { override fun onCreate(savedInstanceState: Bundle?) {
super.onCreate(savedInstanceState) super.onCreate(savedInstanceState)
@ -62,26 +62,19 @@ class MainActivity : ComponentActivity() {
} }
@Composable @Composable
fun AppContent() { fun AppContent(
val navController = rememberNavController() mainVewModel: MainViewModel = hiltViewModel()
val drawerState = rememberDrawerState(initialValue = DrawerValue.Closed) ) {
val coroutineScope = rememberCoroutineScope() val coroutineScope = rememberCoroutineScope()
// Create inference engine val navController = rememberNavController()
val inferenceEngine = remember { InferenceEngine() } val navigationActions = remember(navController) { NavigationActions(navController) }
val drawerState = rememberDrawerState(initialValue = DrawerValue.Closed)
// Create factory for MainViewModel val engineState by mainVewModel.engineState.collectAsState()
val factory = remember { ViewModelFactoryProvider.getMainViewModelFactory(inferenceEngine) } // TODO-han.yin: Also use delegate for `isModelLoaded`:
val isModelLoaded = remember(engineState) { mainVewModel.isModelLoaded() }
// Get ViewModel instance with factory
val viewModel: MainViewModel = viewModel(factory = factory)
val engineState by viewModel.engineState.collectAsState()
val isModelLoaded = remember(engineState) { viewModel.isModelLoaded() }
val navigationActions = remember(navController) {
NavigationActions(navController)
}
// Model unloading confirmation // Model unloading confirmation
var showUnloadDialog by remember { mutableStateOf(false) } var showUnloadDialog by remember { mutableStateOf(false) }
@ -123,10 +116,10 @@ fun AppContent() {
// Helper function to handle back press with model unloading check // Helper function to handle back press with model unloading check
val handleBackWithModelCheck = { val handleBackWithModelCheck = {
if (viewModel.isModelLoading()) { if (mainVewModel.isModelLoading()) {
// If model is still loading, ignore the request // If model is still loading, ignore the request
true // Mark as handled true // Mark as handled
} else if (viewModel.isModelLoaded()) { } else if (mainVewModel.isModelLoaded()) {
showUnloadDialog = true showUnloadDialog = true
pendingNavigation = { navController.popBackStack() } pendingNavigation = { navController.popBackStack() }
true // Mark as handled true // Mark as handled
@ -194,7 +187,7 @@ fun AppContent() {
composable(AppDestinations.MODEL_SELECTION_ROUTE) { composable(AppDestinations.MODEL_SELECTION_ROUTE) {
ModelSelectionScreen( ModelSelectionScreen(
onModelSelected = { modelInfo -> onModelSelected = { modelInfo ->
viewModel.selectModel(modelInfo) mainVewModel.selectModel(modelInfo)
navigationActions.navigateToModelLoading() navigationActions.navigateToModelLoading()
}, },
onManageModelsClicked = { onManageModelsClicked = {
@ -211,13 +204,13 @@ fun AppContent() {
ModelLoadingScreen( ModelLoadingScreen(
engineState = engineState, engineState = engineState,
onBenchmarkSelected = { onBenchmarkSelected = {
viewModel.prepareForBenchmark() mainVewModel.prepareForBenchmark()
navigationActions.navigateToBenchmark() navigationActions.navigateToBenchmark()
}, },
onConversationSelected = { systemPrompt -> onConversationSelected = { systemPrompt ->
// Store a reference to the loading job // Store a reference to the loading job
val loadingJob = coroutineScope.launch { val loadingJob = coroutineScope.launch {
viewModel.prepareForConversation(systemPrompt) mainVewModel.prepareForConversation(systemPrompt)
// Check if the job wasn't cancelled before navigating // Check if the job wasn't cancelled before navigating
if (isActive) { if (isActive) {
navigationActions.navigateToConversation() navigationActions.navigateToConversation()
@ -245,7 +238,7 @@ fun AppContent() {
// Need to unload model before going back // Need to unload model before going back
handleBackWithModelCheck() handleBackWithModelCheck()
}, },
viewModel = viewModel viewModel = mainVewModel
) )
} }
@ -257,14 +250,14 @@ fun AppContent() {
handleBackWithModelCheck() handleBackWithModelCheck()
}, },
onRerunPressed = { onRerunPressed = {
viewModel.rerunBenchmark() mainVewModel.rerunBenchmark()
}, },
onSharePressed = { onSharePressed = {
// Stub for sharing functionality // Stub for sharing functionality
}, },
drawerState = drawerState, drawerState = drawerState,
navigationActions = navigationActions, navigationActions = navigationActions,
viewModel = viewModel viewModel = mainVewModel
) )
} }
@ -295,7 +288,7 @@ fun AppContent() {
onConfirm = { onConfirm = {
isUnloading = true isUnloading = true
coroutineScope.launch { coroutineScope.launch {
viewModel.unloadModel() mainVewModel.unloadModel()
isUnloading = false isUnloading = false
showUnloadDialog = false showUnloadDialog = false
pendingNavigation?.invoke() pendingNavigation?.invoke()

View File

@ -8,11 +8,7 @@ import androidx.compose.runtime.Composable
import androidx.compose.runtime.collectAsState import androidx.compose.runtime.collectAsState
import androidx.compose.runtime.getValue import androidx.compose.runtime.getValue
import androidx.compose.runtime.remember import androidx.compose.runtime.remember
import androidx.compose.ui.platform.LocalContext import androidx.hilt.navigation.compose.hiltViewModel
import androidx.lifecycle.viewmodel.compose.viewModel
import com.example.llama.revamp.data.preferences.UserPreferences
import com.example.llama.revamp.monitoring.PerformanceMonitor
import com.example.llama.revamp.util.ViewModelFactoryProvider
import com.example.llama.revamp.viewmodel.PerformanceViewModel import com.example.llama.revamp.viewmodel.PerformanceViewModel
// DefaultAppScaffold.kt // DefaultAppScaffold.kt
@ -42,6 +38,7 @@ fun DefaultAppScaffold(
// PerformanceAppScaffold.kt // PerformanceAppScaffold.kt
@Composable @Composable
fun PerformanceAppScaffold( fun PerformanceAppScaffold(
performanceViewModel: PerformanceViewModel = hiltViewModel(),
title: String, title: String,
onNavigateBack: (() -> Unit)? = null, onNavigateBack: (() -> Unit)? = null,
onMenuOpen: (() -> Unit)? = null, onMenuOpen: (() -> Unit)? = null,
@ -49,17 +46,6 @@ fun PerformanceAppScaffold(
snackbarHostState: SnackbarHostState = remember { SnackbarHostState() }, snackbarHostState: SnackbarHostState = remember { SnackbarHostState() },
content: @Composable (PaddingValues) -> Unit content: @Composable (PaddingValues) -> Unit
) { ) {
// Create dependencies for PerformanceViewModel
val context = LocalContext.current
val performanceMonitor = remember { PerformanceMonitor(context) }
val userPreferences = remember { UserPreferences(context) }
// Create factory for PerformanceViewModel
val factory = remember { ViewModelFactoryProvider.getPerformanceViewModelFactory(performanceMonitor, userPreferences) }
// Get ViewModel instance with factory
val performanceViewModel: PerformanceViewModel = viewModel(factory = factory)
// Collect performance metrics // Collect performance metrics
val memoryUsage by performanceViewModel.memoryUsage.collectAsState() val memoryUsage by performanceViewModel.memoryUsage.collectAsState()
val temperatureInfo by performanceViewModel.temperatureMetrics.collectAsState() val temperatureInfo by performanceViewModel.temperatureMetrics.collectAsState()

View File

@ -22,7 +22,7 @@ import androidx.compose.runtime.getValue
import androidx.compose.ui.Alignment import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier import androidx.compose.ui.Modifier
import androidx.compose.ui.unit.dp import androidx.compose.ui.unit.dp
import androidx.lifecycle.viewmodel.compose.viewModel import androidx.hilt.navigation.compose.hiltViewModel
import com.example.llama.revamp.engine.InferenceEngine import com.example.llama.revamp.engine.InferenceEngine
import com.example.llama.revamp.navigation.NavigationActions import com.example.llama.revamp.navigation.NavigationActions
import com.example.llama.revamp.ui.components.PerformanceAppScaffold import com.example.llama.revamp.ui.components.PerformanceAppScaffold
@ -36,7 +36,7 @@ fun BenchmarkScreen(
onSharePressed: () -> Unit, onSharePressed: () -> Unit,
drawerState: DrawerState, drawerState: DrawerState,
navigationActions: NavigationActions, navigationActions: NavigationActions,
viewModel: MainViewModel = viewModel() viewModel: MainViewModel = hiltViewModel()
) { ) {
val engineState by viewModel.engineState.collectAsState() val engineState by viewModel.engineState.collectAsState()
val benchmarkResults by viewModel.benchmarkResults.collectAsState() val benchmarkResults by viewModel.benchmarkResults.collectAsState()

View File

@ -57,6 +57,7 @@ import androidx.compose.ui.graphics.Color
import androidx.compose.ui.graphics.StrokeCap import androidx.compose.ui.graphics.StrokeCap
import androidx.compose.ui.platform.LocalLifecycleOwner import androidx.compose.ui.platform.LocalLifecycleOwner
import androidx.compose.ui.unit.dp import androidx.compose.ui.unit.dp
import androidx.hilt.navigation.compose.hiltViewModel
import androidx.lifecycle.Lifecycle import androidx.lifecycle.Lifecycle
import androidx.lifecycle.LifecycleEventObserver import androidx.lifecycle.LifecycleEventObserver
import com.example.llama.revamp.engine.InferenceEngine import com.example.llama.revamp.engine.InferenceEngine
@ -71,7 +72,7 @@ import kotlinx.coroutines.launch
@Composable @Composable
fun ConversationScreen( fun ConversationScreen(
onBackPressed: () -> Unit, onBackPressed: () -> Unit,
viewModel: MainViewModel viewModel: MainViewModel = hiltViewModel()
) { ) {
val engineState by viewModel.engineState.collectAsState() val engineState by viewModel.engineState.collectAsState()
val messages by viewModel.messages.collectAsState() val messages by viewModel.messages.collectAsState()

View File

@ -44,17 +44,14 @@ import androidx.compose.runtime.remember
import androidx.compose.runtime.setValue import androidx.compose.runtime.setValue
import androidx.compose.ui.Alignment import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier import androidx.compose.ui.Modifier
import androidx.compose.ui.platform.LocalContext
import androidx.compose.ui.semantics.Role import androidx.compose.ui.semantics.Role
import androidx.compose.ui.text.style.TextOverflow import androidx.compose.ui.text.style.TextOverflow
import androidx.compose.ui.unit.dp import androidx.compose.ui.unit.dp
import androidx.lifecycle.viewmodel.compose.viewModel import androidx.hilt.navigation.compose.hiltViewModel
import com.example.llama.revamp.data.model.SystemPrompt import com.example.llama.revamp.data.model.SystemPrompt
import com.example.llama.revamp.data.repository.SystemPromptRepository
import com.example.llama.revamp.engine.InferenceEngine import com.example.llama.revamp.engine.InferenceEngine
import com.example.llama.revamp.navigation.NavigationActions import com.example.llama.revamp.navigation.NavigationActions
import com.example.llama.revamp.ui.components.PerformanceAppScaffold import com.example.llama.revamp.ui.components.PerformanceAppScaffold
import com.example.llama.revamp.util.ViewModelFactoryProvider
import com.example.llama.revamp.viewmodel.SystemPromptViewModel import com.example.llama.revamp.viewmodel.SystemPromptViewModel
enum class SystemPromptTab { enum class SystemPromptTab {
@ -64,6 +61,7 @@ enum class SystemPromptTab {
@OptIn(ExperimentalMaterial3Api::class, ExperimentalFoundationApi::class) @OptIn(ExperimentalMaterial3Api::class, ExperimentalFoundationApi::class)
@Composable @Composable
fun ModelLoadingScreen( fun ModelLoadingScreen(
viewModel: SystemPromptViewModel = hiltViewModel(),
engineState: InferenceEngine.State, engineState: InferenceEngine.State,
onBenchmarkSelected: () -> Unit, onBenchmarkSelected: () -> Unit,
onConversationSelected: (String?) -> Unit, onConversationSelected: (String?) -> Unit,
@ -71,12 +69,6 @@ fun ModelLoadingScreen(
drawerState: DrawerState, drawerState: DrawerState,
navigationActions: NavigationActions navigationActions: NavigationActions
) { ) {
// Set up SystemPromptViewModel
val context = LocalContext.current
val repository = remember { SystemPromptRepository(context) }
val factory = remember { ViewModelFactoryProvider.getSystemPromptViewModelFactory(repository) }
val viewModel: SystemPromptViewModel = viewModel(factory = factory)
val presetPrompts by viewModel.presetPrompts.collectAsState() val presetPrompts by viewModel.presetPrompts.collectAsState()
val recentPrompts by viewModel.recentPrompts.collectAsState() val recentPrompts by viewModel.recentPrompts.collectAsState()

View File

@ -10,11 +10,7 @@ import androidx.compose.foundation.layout.padding
import androidx.compose.foundation.lazy.LazyColumn import androidx.compose.foundation.lazy.LazyColumn
import androidx.compose.foundation.lazy.items import androidx.compose.foundation.lazy.items
import androidx.compose.material.icons.Icons import androidx.compose.material.icons.Icons
import androidx.compose.material.icons.filled.CloudDownload
import androidx.compose.material.icons.filled.Delete import androidx.compose.material.icons.filled.Delete
import androidx.compose.material.icons.filled.Done
import androidx.compose.material.icons.filled.Edit
import androidx.compose.material.icons.filled.FileOpen
import androidx.compose.material.icons.filled.Info import androidx.compose.material.icons.filled.Info
import androidx.compose.material3.Card import androidx.compose.material3.Card
import androidx.compose.material3.CardDefaults import androidx.compose.material3.CardDefaults

View File

@ -18,17 +18,12 @@ import androidx.compose.material3.Text
import androidx.compose.runtime.Composable import androidx.compose.runtime.Composable
import androidx.compose.runtime.collectAsState import androidx.compose.runtime.collectAsState
import androidx.compose.runtime.getValue import androidx.compose.runtime.getValue
import androidx.compose.runtime.remember
import androidx.compose.ui.Alignment import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier import androidx.compose.ui.Modifier
import androidx.compose.ui.platform.LocalContext
import androidx.compose.ui.unit.dp import androidx.compose.ui.unit.dp
import androidx.lifecycle.viewmodel.compose.viewModel import androidx.hilt.navigation.compose.hiltViewModel
import com.example.llama.revamp.data.preferences.UserPreferences
import com.example.llama.revamp.monitoring.PerformanceMonitor
import com.example.llama.revamp.navigation.NavigationActions import com.example.llama.revamp.navigation.NavigationActions
import com.example.llama.revamp.ui.components.DefaultAppScaffold import com.example.llama.revamp.ui.components.DefaultAppScaffold
import com.example.llama.revamp.util.ViewModelFactoryProvider
import com.example.llama.revamp.viewmodel.PerformanceViewModel import com.example.llama.revamp.viewmodel.PerformanceViewModel
/** /**
@ -36,22 +31,12 @@ import com.example.llama.revamp.viewmodel.PerformanceViewModel
*/ */
@Composable @Composable
fun SettingsGeneralScreen( fun SettingsGeneralScreen(
performanceViewModel: PerformanceViewModel = hiltViewModel(),
onBackPressed: () -> Unit, onBackPressed: () -> Unit,
drawerState: DrawerState, drawerState: DrawerState,
navigationActions: NavigationActions, navigationActions: NavigationActions,
onMenuClicked: () -> Unit onMenuClicked: () -> Unit
) { ) {
// Create dependencies for PerformanceViewModel
val context = LocalContext.current
val performanceMonitor = remember { PerformanceMonitor(context) }
val userPreferences = remember { UserPreferences(context) }
// Create factory for PerformanceViewModel
val factory = remember { ViewModelFactoryProvider.getPerformanceViewModelFactory(performanceMonitor, userPreferences) }
// Get ViewModel instance with factory
val performanceViewModel: PerformanceViewModel = viewModel(factory = factory)
// Collect state from ViewModel // Collect state from ViewModel
val isMonitoringEnabled by performanceViewModel.isMonitoringEnabled.collectAsState() val isMonitoringEnabled by performanceViewModel.isMonitoringEnabled.collectAsState()
val useFahrenheit by performanceViewModel.useFahrenheitUnit.collectAsState() val useFahrenheit by performanceViewModel.useFahrenheitUnit.collectAsState()

View File

@ -1,69 +0,0 @@
package com.example.llama.revamp.util
import androidx.lifecycle.ViewModel
import androidx.lifecycle.ViewModelProvider
import com.example.llama.revamp.data.preferences.UserPreferences
import com.example.llama.revamp.data.repository.SystemPromptRepository
import com.example.llama.revamp.engine.InferenceEngine
import com.example.llama.revamp.monitoring.PerformanceMonitor
import com.example.llama.revamp.viewmodel.MainViewModel
import com.example.llama.revamp.viewmodel.PerformanceViewModel
import com.example.llama.revamp.viewmodel.SystemPromptViewModel
/**
* Utility class to provide ViewModel factories.
*/
object ViewModelFactoryProvider {
/**
* Creates a factory for PerformanceViewModel.
*/
fun getPerformanceViewModelFactory(
performanceMonitor: PerformanceMonitor,
userPreferences: UserPreferences
): ViewModelProvider.Factory {
return object : ViewModelProvider.Factory {
@Suppress("UNCHECKED_CAST")
override fun <T : ViewModel> create(modelClass: Class<T>): T {
if (modelClass.isAssignableFrom(PerformanceViewModel::class.java)) {
return PerformanceViewModel(performanceMonitor, userPreferences) as T
}
throw IllegalArgumentException("Unknown ViewModel class: ${modelClass.name}")
}
}
}
/**
* Creates a factory for MainViewModel.
*/
fun getMainViewModelFactory(
inferenceEngine: InferenceEngine
): ViewModelProvider.Factory {
return object : ViewModelProvider.Factory {
@Suppress("UNCHECKED_CAST")
override fun <T : ViewModel> create(modelClass: Class<T>): T {
if (modelClass.isAssignableFrom(MainViewModel::class.java)) {
return MainViewModel(inferenceEngine) as T
}
throw IllegalArgumentException("Unknown ViewModel class: ${modelClass.name}")
}
}
}
/**
* Creates a factory for SystemPromptViewModel.
*/
fun getSystemPromptViewModelFactory(
repository: SystemPromptRepository
): ViewModelProvider.Factory {
return object : ViewModelProvider.Factory {
@Suppress("UNCHECKED_CAST")
override fun <T : ViewModel> create(modelClass: Class<T>): T {
if (modelClass.isAssignableFrom(SystemPromptViewModel::class.java)) {
return SystemPromptViewModel(repository) as T
}
throw IllegalArgumentException("Unknown ViewModel class: ${modelClass.name}")
}
}
}
}