DI: replace manual DI with Hilt DI
This commit is contained in:
parent
a1f6e7e476
commit
0afd087f35
|
|
@ -22,13 +22,12 @@ import androidx.compose.runtime.remember
|
|||
import androidx.compose.runtime.rememberCoroutineScope
|
||||
import androidx.compose.runtime.setValue
|
||||
import androidx.compose.ui.Modifier
|
||||
import androidx.hilt.navigation.compose.hiltViewModel
|
||||
import androidx.lifecycle.compose.LocalLifecycleOwner
|
||||
import androidx.lifecycle.viewmodel.compose.viewModel
|
||||
import androidx.navigation.compose.NavHost
|
||||
import androidx.navigation.compose.composable
|
||||
import androidx.navigation.compose.currentBackStackEntryAsState
|
||||
import androidx.navigation.compose.rememberNavController
|
||||
import com.example.llama.revamp.engine.InferenceEngine
|
||||
import com.example.llama.revamp.navigation.AppDestinations
|
||||
import com.example.llama.revamp.navigation.NavigationActions
|
||||
import com.example.llama.revamp.ui.components.AppNavigationDrawer
|
||||
|
|
@ -40,11 +39,12 @@ import com.example.llama.revamp.ui.screens.ModelsManagementScreen
|
|||
import com.example.llama.revamp.ui.screens.ModelLoadingScreen
|
||||
import com.example.llama.revamp.ui.screens.SettingsGeneralScreen
|
||||
import com.example.llama.revamp.ui.theme.LlamaTheme
|
||||
import com.example.llama.revamp.util.ViewModelFactoryProvider
|
||||
import com.example.llama.revamp.viewmodel.MainViewModel
|
||||
import dagger.hilt.android.AndroidEntryPoint
|
||||
import kotlinx.coroutines.isActive
|
||||
import kotlinx.coroutines.launch
|
||||
|
||||
@AndroidEntryPoint
|
||||
class MainActivity : ComponentActivity() {
|
||||
override fun onCreate(savedInstanceState: Bundle?) {
|
||||
super.onCreate(savedInstanceState)
|
||||
|
|
@ -62,26 +62,19 @@ class MainActivity : ComponentActivity() {
|
|||
}
|
||||
|
||||
@Composable
|
||||
fun AppContent() {
|
||||
val navController = rememberNavController()
|
||||
val drawerState = rememberDrawerState(initialValue = DrawerValue.Closed)
|
||||
fun AppContent(
|
||||
mainVewModel: MainViewModel = hiltViewModel()
|
||||
) {
|
||||
val coroutineScope = rememberCoroutineScope()
|
||||
|
||||
// Create inference engine
|
||||
val inferenceEngine = remember { InferenceEngine() }
|
||||
val navController = rememberNavController()
|
||||
val navigationActions = remember(navController) { NavigationActions(navController) }
|
||||
val drawerState = rememberDrawerState(initialValue = DrawerValue.Closed)
|
||||
|
||||
// Create factory for MainViewModel
|
||||
val factory = remember { ViewModelFactoryProvider.getMainViewModelFactory(inferenceEngine) }
|
||||
val engineState by mainVewModel.engineState.collectAsState()
|
||||
// TODO-han.yin: Also use delegate for `isModelLoaded`:
|
||||
val isModelLoaded = remember(engineState) { mainVewModel.isModelLoaded() }
|
||||
|
||||
// Get ViewModel instance with factory
|
||||
val viewModel: MainViewModel = viewModel(factory = factory)
|
||||
|
||||
val engineState by viewModel.engineState.collectAsState()
|
||||
val isModelLoaded = remember(engineState) { viewModel.isModelLoaded() }
|
||||
|
||||
val navigationActions = remember(navController) {
|
||||
NavigationActions(navController)
|
||||
}
|
||||
|
||||
// Model unloading confirmation
|
||||
var showUnloadDialog by remember { mutableStateOf(false) }
|
||||
|
|
@ -123,10 +116,10 @@ fun AppContent() {
|
|||
|
||||
// Helper function to handle back press with model unloading check
|
||||
val handleBackWithModelCheck = {
|
||||
if (viewModel.isModelLoading()) {
|
||||
if (mainVewModel.isModelLoading()) {
|
||||
// If model is still loading, ignore the request
|
||||
true // Mark as handled
|
||||
} else if (viewModel.isModelLoaded()) {
|
||||
} else if (mainVewModel.isModelLoaded()) {
|
||||
showUnloadDialog = true
|
||||
pendingNavigation = { navController.popBackStack() }
|
||||
true // Mark as handled
|
||||
|
|
@ -194,7 +187,7 @@ fun AppContent() {
|
|||
composable(AppDestinations.MODEL_SELECTION_ROUTE) {
|
||||
ModelSelectionScreen(
|
||||
onModelSelected = { modelInfo ->
|
||||
viewModel.selectModel(modelInfo)
|
||||
mainVewModel.selectModel(modelInfo)
|
||||
navigationActions.navigateToModelLoading()
|
||||
},
|
||||
onManageModelsClicked = {
|
||||
|
|
@ -211,13 +204,13 @@ fun AppContent() {
|
|||
ModelLoadingScreen(
|
||||
engineState = engineState,
|
||||
onBenchmarkSelected = {
|
||||
viewModel.prepareForBenchmark()
|
||||
mainVewModel.prepareForBenchmark()
|
||||
navigationActions.navigateToBenchmark()
|
||||
},
|
||||
onConversationSelected = { systemPrompt ->
|
||||
// Store a reference to the loading job
|
||||
val loadingJob = coroutineScope.launch {
|
||||
viewModel.prepareForConversation(systemPrompt)
|
||||
mainVewModel.prepareForConversation(systemPrompt)
|
||||
// Check if the job wasn't cancelled before navigating
|
||||
if (isActive) {
|
||||
navigationActions.navigateToConversation()
|
||||
|
|
@ -245,7 +238,7 @@ fun AppContent() {
|
|||
// Need to unload model before going back
|
||||
handleBackWithModelCheck()
|
||||
},
|
||||
viewModel = viewModel
|
||||
viewModel = mainVewModel
|
||||
)
|
||||
}
|
||||
|
||||
|
|
@ -257,14 +250,14 @@ fun AppContent() {
|
|||
handleBackWithModelCheck()
|
||||
},
|
||||
onRerunPressed = {
|
||||
viewModel.rerunBenchmark()
|
||||
mainVewModel.rerunBenchmark()
|
||||
},
|
||||
onSharePressed = {
|
||||
// Stub for sharing functionality
|
||||
},
|
||||
drawerState = drawerState,
|
||||
navigationActions = navigationActions,
|
||||
viewModel = viewModel
|
||||
viewModel = mainVewModel
|
||||
)
|
||||
}
|
||||
|
||||
|
|
@ -295,7 +288,7 @@ fun AppContent() {
|
|||
onConfirm = {
|
||||
isUnloading = true
|
||||
coroutineScope.launch {
|
||||
viewModel.unloadModel()
|
||||
mainVewModel.unloadModel()
|
||||
isUnloading = false
|
||||
showUnloadDialog = false
|
||||
pendingNavigation?.invoke()
|
||||
|
|
|
|||
|
|
@ -8,11 +8,7 @@ import androidx.compose.runtime.Composable
|
|||
import androidx.compose.runtime.collectAsState
|
||||
import androidx.compose.runtime.getValue
|
||||
import androidx.compose.runtime.remember
|
||||
import androidx.compose.ui.platform.LocalContext
|
||||
import androidx.lifecycle.viewmodel.compose.viewModel
|
||||
import com.example.llama.revamp.data.preferences.UserPreferences
|
||||
import com.example.llama.revamp.monitoring.PerformanceMonitor
|
||||
import com.example.llama.revamp.util.ViewModelFactoryProvider
|
||||
import androidx.hilt.navigation.compose.hiltViewModel
|
||||
import com.example.llama.revamp.viewmodel.PerformanceViewModel
|
||||
|
||||
// DefaultAppScaffold.kt
|
||||
|
|
@ -42,6 +38,7 @@ fun DefaultAppScaffold(
|
|||
// PerformanceAppScaffold.kt
|
||||
@Composable
|
||||
fun PerformanceAppScaffold(
|
||||
performanceViewModel: PerformanceViewModel = hiltViewModel(),
|
||||
title: String,
|
||||
onNavigateBack: (() -> Unit)? = null,
|
||||
onMenuOpen: (() -> Unit)? = null,
|
||||
|
|
@ -49,17 +46,6 @@ fun PerformanceAppScaffold(
|
|||
snackbarHostState: SnackbarHostState = remember { SnackbarHostState() },
|
||||
content: @Composable (PaddingValues) -> Unit
|
||||
) {
|
||||
// Create dependencies for PerformanceViewModel
|
||||
val context = LocalContext.current
|
||||
val performanceMonitor = remember { PerformanceMonitor(context) }
|
||||
val userPreferences = remember { UserPreferences(context) }
|
||||
|
||||
// Create factory for PerformanceViewModel
|
||||
val factory = remember { ViewModelFactoryProvider.getPerformanceViewModelFactory(performanceMonitor, userPreferences) }
|
||||
|
||||
// Get ViewModel instance with factory
|
||||
val performanceViewModel: PerformanceViewModel = viewModel(factory = factory)
|
||||
|
||||
// Collect performance metrics
|
||||
val memoryUsage by performanceViewModel.memoryUsage.collectAsState()
|
||||
val temperatureInfo by performanceViewModel.temperatureMetrics.collectAsState()
|
||||
|
|
|
|||
|
|
@ -22,7 +22,7 @@ import androidx.compose.runtime.getValue
|
|||
import androidx.compose.ui.Alignment
|
||||
import androidx.compose.ui.Modifier
|
||||
import androidx.compose.ui.unit.dp
|
||||
import androidx.lifecycle.viewmodel.compose.viewModel
|
||||
import androidx.hilt.navigation.compose.hiltViewModel
|
||||
import com.example.llama.revamp.engine.InferenceEngine
|
||||
import com.example.llama.revamp.navigation.NavigationActions
|
||||
import com.example.llama.revamp.ui.components.PerformanceAppScaffold
|
||||
|
|
@ -36,7 +36,7 @@ fun BenchmarkScreen(
|
|||
onSharePressed: () -> Unit,
|
||||
drawerState: DrawerState,
|
||||
navigationActions: NavigationActions,
|
||||
viewModel: MainViewModel = viewModel()
|
||||
viewModel: MainViewModel = hiltViewModel()
|
||||
) {
|
||||
val engineState by viewModel.engineState.collectAsState()
|
||||
val benchmarkResults by viewModel.benchmarkResults.collectAsState()
|
||||
|
|
|
|||
|
|
@ -57,6 +57,7 @@ import androidx.compose.ui.graphics.Color
|
|||
import androidx.compose.ui.graphics.StrokeCap
|
||||
import androidx.compose.ui.platform.LocalLifecycleOwner
|
||||
import androidx.compose.ui.unit.dp
|
||||
import androidx.hilt.navigation.compose.hiltViewModel
|
||||
import androidx.lifecycle.Lifecycle
|
||||
import androidx.lifecycle.LifecycleEventObserver
|
||||
import com.example.llama.revamp.engine.InferenceEngine
|
||||
|
|
@ -71,7 +72,7 @@ import kotlinx.coroutines.launch
|
|||
@Composable
|
||||
fun ConversationScreen(
|
||||
onBackPressed: () -> Unit,
|
||||
viewModel: MainViewModel
|
||||
viewModel: MainViewModel = hiltViewModel()
|
||||
) {
|
||||
val engineState by viewModel.engineState.collectAsState()
|
||||
val messages by viewModel.messages.collectAsState()
|
||||
|
|
|
|||
|
|
@ -44,17 +44,14 @@ import androidx.compose.runtime.remember
|
|||
import androidx.compose.runtime.setValue
|
||||
import androidx.compose.ui.Alignment
|
||||
import androidx.compose.ui.Modifier
|
||||
import androidx.compose.ui.platform.LocalContext
|
||||
import androidx.compose.ui.semantics.Role
|
||||
import androidx.compose.ui.text.style.TextOverflow
|
||||
import androidx.compose.ui.unit.dp
|
||||
import androidx.lifecycle.viewmodel.compose.viewModel
|
||||
import androidx.hilt.navigation.compose.hiltViewModel
|
||||
import com.example.llama.revamp.data.model.SystemPrompt
|
||||
import com.example.llama.revamp.data.repository.SystemPromptRepository
|
||||
import com.example.llama.revamp.engine.InferenceEngine
|
||||
import com.example.llama.revamp.navigation.NavigationActions
|
||||
import com.example.llama.revamp.ui.components.PerformanceAppScaffold
|
||||
import com.example.llama.revamp.util.ViewModelFactoryProvider
|
||||
import com.example.llama.revamp.viewmodel.SystemPromptViewModel
|
||||
|
||||
enum class SystemPromptTab {
|
||||
|
|
@ -64,6 +61,7 @@ enum class SystemPromptTab {
|
|||
@OptIn(ExperimentalMaterial3Api::class, ExperimentalFoundationApi::class)
|
||||
@Composable
|
||||
fun ModelLoadingScreen(
|
||||
viewModel: SystemPromptViewModel = hiltViewModel(),
|
||||
engineState: InferenceEngine.State,
|
||||
onBenchmarkSelected: () -> Unit,
|
||||
onConversationSelected: (String?) -> Unit,
|
||||
|
|
@ -71,12 +69,6 @@ fun ModelLoadingScreen(
|
|||
drawerState: DrawerState,
|
||||
navigationActions: NavigationActions
|
||||
) {
|
||||
// Set up SystemPromptViewModel
|
||||
val context = LocalContext.current
|
||||
val repository = remember { SystemPromptRepository(context) }
|
||||
val factory = remember { ViewModelFactoryProvider.getSystemPromptViewModelFactory(repository) }
|
||||
val viewModel: SystemPromptViewModel = viewModel(factory = factory)
|
||||
|
||||
val presetPrompts by viewModel.presetPrompts.collectAsState()
|
||||
val recentPrompts by viewModel.recentPrompts.collectAsState()
|
||||
|
||||
|
|
|
|||
|
|
@ -10,11 +10,7 @@ import androidx.compose.foundation.layout.padding
|
|||
import androidx.compose.foundation.lazy.LazyColumn
|
||||
import androidx.compose.foundation.lazy.items
|
||||
import androidx.compose.material.icons.Icons
|
||||
import androidx.compose.material.icons.filled.CloudDownload
|
||||
import androidx.compose.material.icons.filled.Delete
|
||||
import androidx.compose.material.icons.filled.Done
|
||||
import androidx.compose.material.icons.filled.Edit
|
||||
import androidx.compose.material.icons.filled.FileOpen
|
||||
import androidx.compose.material.icons.filled.Info
|
||||
import androidx.compose.material3.Card
|
||||
import androidx.compose.material3.CardDefaults
|
||||
|
|
|
|||
|
|
@ -18,17 +18,12 @@ import androidx.compose.material3.Text
|
|||
import androidx.compose.runtime.Composable
|
||||
import androidx.compose.runtime.collectAsState
|
||||
import androidx.compose.runtime.getValue
|
||||
import androidx.compose.runtime.remember
|
||||
import androidx.compose.ui.Alignment
|
||||
import androidx.compose.ui.Modifier
|
||||
import androidx.compose.ui.platform.LocalContext
|
||||
import androidx.compose.ui.unit.dp
|
||||
import androidx.lifecycle.viewmodel.compose.viewModel
|
||||
import com.example.llama.revamp.data.preferences.UserPreferences
|
||||
import com.example.llama.revamp.monitoring.PerformanceMonitor
|
||||
import androidx.hilt.navigation.compose.hiltViewModel
|
||||
import com.example.llama.revamp.navigation.NavigationActions
|
||||
import com.example.llama.revamp.ui.components.DefaultAppScaffold
|
||||
import com.example.llama.revamp.util.ViewModelFactoryProvider
|
||||
import com.example.llama.revamp.viewmodel.PerformanceViewModel
|
||||
|
||||
/**
|
||||
|
|
@ -36,22 +31,12 @@ import com.example.llama.revamp.viewmodel.PerformanceViewModel
|
|||
*/
|
||||
@Composable
|
||||
fun SettingsGeneralScreen(
|
||||
performanceViewModel: PerformanceViewModel = hiltViewModel(),
|
||||
onBackPressed: () -> Unit,
|
||||
drawerState: DrawerState,
|
||||
navigationActions: NavigationActions,
|
||||
onMenuClicked: () -> Unit
|
||||
) {
|
||||
// Create dependencies for PerformanceViewModel
|
||||
val context = LocalContext.current
|
||||
val performanceMonitor = remember { PerformanceMonitor(context) }
|
||||
val userPreferences = remember { UserPreferences(context) }
|
||||
|
||||
// Create factory for PerformanceViewModel
|
||||
val factory = remember { ViewModelFactoryProvider.getPerformanceViewModelFactory(performanceMonitor, userPreferences) }
|
||||
|
||||
// Get ViewModel instance with factory
|
||||
val performanceViewModel: PerformanceViewModel = viewModel(factory = factory)
|
||||
|
||||
// Collect state from ViewModel
|
||||
val isMonitoringEnabled by performanceViewModel.isMonitoringEnabled.collectAsState()
|
||||
val useFahrenheit by performanceViewModel.useFahrenheitUnit.collectAsState()
|
||||
|
|
|
|||
|
|
@ -1,69 +0,0 @@
|
|||
package com.example.llama.revamp.util
|
||||
|
||||
import androidx.lifecycle.ViewModel
|
||||
import androidx.lifecycle.ViewModelProvider
|
||||
import com.example.llama.revamp.data.preferences.UserPreferences
|
||||
import com.example.llama.revamp.data.repository.SystemPromptRepository
|
||||
import com.example.llama.revamp.engine.InferenceEngine
|
||||
import com.example.llama.revamp.monitoring.PerformanceMonitor
|
||||
import com.example.llama.revamp.viewmodel.MainViewModel
|
||||
import com.example.llama.revamp.viewmodel.PerformanceViewModel
|
||||
import com.example.llama.revamp.viewmodel.SystemPromptViewModel
|
||||
|
||||
/**
|
||||
* Utility class to provide ViewModel factories.
|
||||
*/
|
||||
object ViewModelFactoryProvider {
|
||||
|
||||
/**
|
||||
* Creates a factory for PerformanceViewModel.
|
||||
*/
|
||||
fun getPerformanceViewModelFactory(
|
||||
performanceMonitor: PerformanceMonitor,
|
||||
userPreferences: UserPreferences
|
||||
): ViewModelProvider.Factory {
|
||||
return object : ViewModelProvider.Factory {
|
||||
@Suppress("UNCHECKED_CAST")
|
||||
override fun <T : ViewModel> create(modelClass: Class<T>): T {
|
||||
if (modelClass.isAssignableFrom(PerformanceViewModel::class.java)) {
|
||||
return PerformanceViewModel(performanceMonitor, userPreferences) as T
|
||||
}
|
||||
throw IllegalArgumentException("Unknown ViewModel class: ${modelClass.name}")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a factory for MainViewModel.
|
||||
*/
|
||||
fun getMainViewModelFactory(
|
||||
inferenceEngine: InferenceEngine
|
||||
): ViewModelProvider.Factory {
|
||||
return object : ViewModelProvider.Factory {
|
||||
@Suppress("UNCHECKED_CAST")
|
||||
override fun <T : ViewModel> create(modelClass: Class<T>): T {
|
||||
if (modelClass.isAssignableFrom(MainViewModel::class.java)) {
|
||||
return MainViewModel(inferenceEngine) as T
|
||||
}
|
||||
throw IllegalArgumentException("Unknown ViewModel class: ${modelClass.name}")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a factory for SystemPromptViewModel.
|
||||
*/
|
||||
fun getSystemPromptViewModelFactory(
|
||||
repository: SystemPromptRepository
|
||||
): ViewModelProvider.Factory {
|
||||
return object : ViewModelProvider.Factory {
|
||||
@Suppress("UNCHECKED_CAST")
|
||||
override fun <T : ViewModel> create(modelClass: Class<T>): T {
|
||||
if (modelClass.isAssignableFrom(SystemPromptViewModel::class.java)) {
|
||||
return SystemPromptViewModel(repository) as T
|
||||
}
|
||||
throw IllegalArgumentException("Unknown ViewModel class: ${modelClass.name}")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Loading…
Reference in New Issue