DI: replace manual DI with Hilt DI

This commit is contained in:
Han Yin 2025-04-13 18:08:38 -07:00
parent a1f6e7e476
commit 0afd087f35
8 changed files with 31 additions and 147 deletions

View File

@ -22,13 +22,12 @@ import androidx.compose.runtime.remember
import androidx.compose.runtime.rememberCoroutineScope
import androidx.compose.runtime.setValue
import androidx.compose.ui.Modifier
import androidx.hilt.navigation.compose.hiltViewModel
import androidx.lifecycle.compose.LocalLifecycleOwner
import androidx.lifecycle.viewmodel.compose.viewModel
import androidx.navigation.compose.NavHost
import androidx.navigation.compose.composable
import androidx.navigation.compose.currentBackStackEntryAsState
import androidx.navigation.compose.rememberNavController
import com.example.llama.revamp.engine.InferenceEngine
import com.example.llama.revamp.navigation.AppDestinations
import com.example.llama.revamp.navigation.NavigationActions
import com.example.llama.revamp.ui.components.AppNavigationDrawer
@ -40,11 +39,12 @@ import com.example.llama.revamp.ui.screens.ModelsManagementScreen
import com.example.llama.revamp.ui.screens.ModelLoadingScreen
import com.example.llama.revamp.ui.screens.SettingsGeneralScreen
import com.example.llama.revamp.ui.theme.LlamaTheme
import com.example.llama.revamp.util.ViewModelFactoryProvider
import com.example.llama.revamp.viewmodel.MainViewModel
import dagger.hilt.android.AndroidEntryPoint
import kotlinx.coroutines.isActive
import kotlinx.coroutines.launch
@AndroidEntryPoint
class MainActivity : ComponentActivity() {
override fun onCreate(savedInstanceState: Bundle?) {
super.onCreate(savedInstanceState)
@ -62,26 +62,19 @@ class MainActivity : ComponentActivity() {
}
@Composable
fun AppContent() {
val navController = rememberNavController()
val drawerState = rememberDrawerState(initialValue = DrawerValue.Closed)
fun AppContent(
mainVewModel: MainViewModel = hiltViewModel()
) {
val coroutineScope = rememberCoroutineScope()
// Create inference engine
val inferenceEngine = remember { InferenceEngine() }
val navController = rememberNavController()
val navigationActions = remember(navController) { NavigationActions(navController) }
val drawerState = rememberDrawerState(initialValue = DrawerValue.Closed)
// Create factory for MainViewModel
val factory = remember { ViewModelFactoryProvider.getMainViewModelFactory(inferenceEngine) }
val engineState by mainVewModel.engineState.collectAsState()
// TODO-han.yin: Also use delegate for `isModelLoaded`:
val isModelLoaded = remember(engineState) { mainVewModel.isModelLoaded() }
// Get ViewModel instance with factory
val viewModel: MainViewModel = viewModel(factory = factory)
val engineState by viewModel.engineState.collectAsState()
val isModelLoaded = remember(engineState) { viewModel.isModelLoaded() }
val navigationActions = remember(navController) {
NavigationActions(navController)
}
// Model unloading confirmation
var showUnloadDialog by remember { mutableStateOf(false) }
@ -123,10 +116,10 @@ fun AppContent() {
// Helper function to handle back press with model unloading check
val handleBackWithModelCheck = {
if (viewModel.isModelLoading()) {
if (mainVewModel.isModelLoading()) {
// If model is still loading, ignore the request
true // Mark as handled
} else if (viewModel.isModelLoaded()) {
} else if (mainVewModel.isModelLoaded()) {
showUnloadDialog = true
pendingNavigation = { navController.popBackStack() }
true // Mark as handled
@ -194,7 +187,7 @@ fun AppContent() {
composable(AppDestinations.MODEL_SELECTION_ROUTE) {
ModelSelectionScreen(
onModelSelected = { modelInfo ->
viewModel.selectModel(modelInfo)
mainVewModel.selectModel(modelInfo)
navigationActions.navigateToModelLoading()
},
onManageModelsClicked = {
@ -211,13 +204,13 @@ fun AppContent() {
ModelLoadingScreen(
engineState = engineState,
onBenchmarkSelected = {
viewModel.prepareForBenchmark()
mainVewModel.prepareForBenchmark()
navigationActions.navigateToBenchmark()
},
onConversationSelected = { systemPrompt ->
// Store a reference to the loading job
val loadingJob = coroutineScope.launch {
viewModel.prepareForConversation(systemPrompt)
mainVewModel.prepareForConversation(systemPrompt)
// Check if the job wasn't cancelled before navigating
if (isActive) {
navigationActions.navigateToConversation()
@ -245,7 +238,7 @@ fun AppContent() {
// Need to unload model before going back
handleBackWithModelCheck()
},
viewModel = viewModel
viewModel = mainVewModel
)
}
@ -257,14 +250,14 @@ fun AppContent() {
handleBackWithModelCheck()
},
onRerunPressed = {
viewModel.rerunBenchmark()
mainVewModel.rerunBenchmark()
},
onSharePressed = {
// Stub for sharing functionality
},
drawerState = drawerState,
navigationActions = navigationActions,
viewModel = viewModel
viewModel = mainVewModel
)
}
@ -295,7 +288,7 @@ fun AppContent() {
onConfirm = {
isUnloading = true
coroutineScope.launch {
viewModel.unloadModel()
mainVewModel.unloadModel()
isUnloading = false
showUnloadDialog = false
pendingNavigation?.invoke()

View File

@ -8,11 +8,7 @@ import androidx.compose.runtime.Composable
import androidx.compose.runtime.collectAsState
import androidx.compose.runtime.getValue
import androidx.compose.runtime.remember
import androidx.compose.ui.platform.LocalContext
import androidx.lifecycle.viewmodel.compose.viewModel
import com.example.llama.revamp.data.preferences.UserPreferences
import com.example.llama.revamp.monitoring.PerformanceMonitor
import com.example.llama.revamp.util.ViewModelFactoryProvider
import androidx.hilt.navigation.compose.hiltViewModel
import com.example.llama.revamp.viewmodel.PerformanceViewModel
// DefaultAppScaffold.kt
@ -42,6 +38,7 @@ fun DefaultAppScaffold(
// PerformanceAppScaffold.kt
@Composable
fun PerformanceAppScaffold(
performanceViewModel: PerformanceViewModel = hiltViewModel(),
title: String,
onNavigateBack: (() -> Unit)? = null,
onMenuOpen: (() -> Unit)? = null,
@ -49,17 +46,6 @@ fun PerformanceAppScaffold(
snackbarHostState: SnackbarHostState = remember { SnackbarHostState() },
content: @Composable (PaddingValues) -> Unit
) {
// Create dependencies for PerformanceViewModel
val context = LocalContext.current
val performanceMonitor = remember { PerformanceMonitor(context) }
val userPreferences = remember { UserPreferences(context) }
// Create factory for PerformanceViewModel
val factory = remember { ViewModelFactoryProvider.getPerformanceViewModelFactory(performanceMonitor, userPreferences) }
// Get ViewModel instance with factory
val performanceViewModel: PerformanceViewModel = viewModel(factory = factory)
// Collect performance metrics
val memoryUsage by performanceViewModel.memoryUsage.collectAsState()
val temperatureInfo by performanceViewModel.temperatureMetrics.collectAsState()

View File

@ -22,7 +22,7 @@ import androidx.compose.runtime.getValue
import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
import androidx.compose.ui.unit.dp
import androidx.lifecycle.viewmodel.compose.viewModel
import androidx.hilt.navigation.compose.hiltViewModel
import com.example.llama.revamp.engine.InferenceEngine
import com.example.llama.revamp.navigation.NavigationActions
import com.example.llama.revamp.ui.components.PerformanceAppScaffold
@ -36,7 +36,7 @@ fun BenchmarkScreen(
onSharePressed: () -> Unit,
drawerState: DrawerState,
navigationActions: NavigationActions,
viewModel: MainViewModel = viewModel()
viewModel: MainViewModel = hiltViewModel()
) {
val engineState by viewModel.engineState.collectAsState()
val benchmarkResults by viewModel.benchmarkResults.collectAsState()

View File

@ -57,6 +57,7 @@ import androidx.compose.ui.graphics.Color
import androidx.compose.ui.graphics.StrokeCap
import androidx.compose.ui.platform.LocalLifecycleOwner
import androidx.compose.ui.unit.dp
import androidx.hilt.navigation.compose.hiltViewModel
import androidx.lifecycle.Lifecycle
import androidx.lifecycle.LifecycleEventObserver
import com.example.llama.revamp.engine.InferenceEngine
@ -71,7 +72,7 @@ import kotlinx.coroutines.launch
@Composable
fun ConversationScreen(
onBackPressed: () -> Unit,
viewModel: MainViewModel
viewModel: MainViewModel = hiltViewModel()
) {
val engineState by viewModel.engineState.collectAsState()
val messages by viewModel.messages.collectAsState()

View File

@ -44,17 +44,14 @@ import androidx.compose.runtime.remember
import androidx.compose.runtime.setValue
import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
import androidx.compose.ui.platform.LocalContext
import androidx.compose.ui.semantics.Role
import androidx.compose.ui.text.style.TextOverflow
import androidx.compose.ui.unit.dp
import androidx.lifecycle.viewmodel.compose.viewModel
import androidx.hilt.navigation.compose.hiltViewModel
import com.example.llama.revamp.data.model.SystemPrompt
import com.example.llama.revamp.data.repository.SystemPromptRepository
import com.example.llama.revamp.engine.InferenceEngine
import com.example.llama.revamp.navigation.NavigationActions
import com.example.llama.revamp.ui.components.PerformanceAppScaffold
import com.example.llama.revamp.util.ViewModelFactoryProvider
import com.example.llama.revamp.viewmodel.SystemPromptViewModel
enum class SystemPromptTab {
@ -64,6 +61,7 @@ enum class SystemPromptTab {
@OptIn(ExperimentalMaterial3Api::class, ExperimentalFoundationApi::class)
@Composable
fun ModelLoadingScreen(
viewModel: SystemPromptViewModel = hiltViewModel(),
engineState: InferenceEngine.State,
onBenchmarkSelected: () -> Unit,
onConversationSelected: (String?) -> Unit,
@ -71,12 +69,6 @@ fun ModelLoadingScreen(
drawerState: DrawerState,
navigationActions: NavigationActions
) {
// Set up SystemPromptViewModel
val context = LocalContext.current
val repository = remember { SystemPromptRepository(context) }
val factory = remember { ViewModelFactoryProvider.getSystemPromptViewModelFactory(repository) }
val viewModel: SystemPromptViewModel = viewModel(factory = factory)
val presetPrompts by viewModel.presetPrompts.collectAsState()
val recentPrompts by viewModel.recentPrompts.collectAsState()

View File

@ -10,11 +10,7 @@ import androidx.compose.foundation.layout.padding
import androidx.compose.foundation.lazy.LazyColumn
import androidx.compose.foundation.lazy.items
import androidx.compose.material.icons.Icons
import androidx.compose.material.icons.filled.CloudDownload
import androidx.compose.material.icons.filled.Delete
import androidx.compose.material.icons.filled.Done
import androidx.compose.material.icons.filled.Edit
import androidx.compose.material.icons.filled.FileOpen
import androidx.compose.material.icons.filled.Info
import androidx.compose.material3.Card
import androidx.compose.material3.CardDefaults

View File

@ -18,17 +18,12 @@ import androidx.compose.material3.Text
import androidx.compose.runtime.Composable
import androidx.compose.runtime.collectAsState
import androidx.compose.runtime.getValue
import androidx.compose.runtime.remember
import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
import androidx.compose.ui.platform.LocalContext
import androidx.compose.ui.unit.dp
import androidx.lifecycle.viewmodel.compose.viewModel
import com.example.llama.revamp.data.preferences.UserPreferences
import com.example.llama.revamp.monitoring.PerformanceMonitor
import androidx.hilt.navigation.compose.hiltViewModel
import com.example.llama.revamp.navigation.NavigationActions
import com.example.llama.revamp.ui.components.DefaultAppScaffold
import com.example.llama.revamp.util.ViewModelFactoryProvider
import com.example.llama.revamp.viewmodel.PerformanceViewModel
/**
@ -36,22 +31,12 @@ import com.example.llama.revamp.viewmodel.PerformanceViewModel
*/
@Composable
fun SettingsGeneralScreen(
performanceViewModel: PerformanceViewModel = hiltViewModel(),
onBackPressed: () -> Unit,
drawerState: DrawerState,
navigationActions: NavigationActions,
onMenuClicked: () -> Unit
) {
// Create dependencies for PerformanceViewModel
val context = LocalContext.current
val performanceMonitor = remember { PerformanceMonitor(context) }
val userPreferences = remember { UserPreferences(context) }
// Create factory for PerformanceViewModel
val factory = remember { ViewModelFactoryProvider.getPerformanceViewModelFactory(performanceMonitor, userPreferences) }
// Get ViewModel instance with factory
val performanceViewModel: PerformanceViewModel = viewModel(factory = factory)
// Collect state from ViewModel
val isMonitoringEnabled by performanceViewModel.isMonitoringEnabled.collectAsState()
val useFahrenheit by performanceViewModel.useFahrenheitUnit.collectAsState()

View File

@ -1,69 +0,0 @@
package com.example.llama.revamp.util
import androidx.lifecycle.ViewModel
import androidx.lifecycle.ViewModelProvider
import com.example.llama.revamp.data.preferences.UserPreferences
import com.example.llama.revamp.data.repository.SystemPromptRepository
import com.example.llama.revamp.engine.InferenceEngine
import com.example.llama.revamp.monitoring.PerformanceMonitor
import com.example.llama.revamp.viewmodel.MainViewModel
import com.example.llama.revamp.viewmodel.PerformanceViewModel
import com.example.llama.revamp.viewmodel.SystemPromptViewModel
/**
* Utility class to provide ViewModel factories.
*/
object ViewModelFactoryProvider {
/**
* Creates a factory for PerformanceViewModel.
*/
fun getPerformanceViewModelFactory(
performanceMonitor: PerformanceMonitor,
userPreferences: UserPreferences
): ViewModelProvider.Factory {
return object : ViewModelProvider.Factory {
@Suppress("UNCHECKED_CAST")
override fun <T : ViewModel> create(modelClass: Class<T>): T {
if (modelClass.isAssignableFrom(PerformanceViewModel::class.java)) {
return PerformanceViewModel(performanceMonitor, userPreferences) as T
}
throw IllegalArgumentException("Unknown ViewModel class: ${modelClass.name}")
}
}
}
/**
* Creates a factory for MainViewModel.
*/
fun getMainViewModelFactory(
inferenceEngine: InferenceEngine
): ViewModelProvider.Factory {
return object : ViewModelProvider.Factory {
@Suppress("UNCHECKED_CAST")
override fun <T : ViewModel> create(modelClass: Class<T>): T {
if (modelClass.isAssignableFrom(MainViewModel::class.java)) {
return MainViewModel(inferenceEngine) as T
}
throw IllegalArgumentException("Unknown ViewModel class: ${modelClass.name}")
}
}
}
/**
* Creates a factory for SystemPromptViewModel.
*/
fun getSystemPromptViewModelFactory(
repository: SystemPromptRepository
): ViewModelProvider.Factory {
return object : ViewModelProvider.Factory {
@Suppress("UNCHECKED_CAST")
override fun <T : ViewModel> create(modelClass: Class<T>): T {
if (modelClass.isAssignableFrom(SystemPromptViewModel::class.java)) {
return SystemPromptViewModel(repository) as T
}
throw IllegalArgumentException("Unknown ViewModel class: ${modelClass.name}")
}
}
}
}