diff --git a/examples/llama.android/app/src/main/java/com/example/llama/revamp/MainActivity.kt b/examples/llama.android/app/src/main/java/com/example/llama/revamp/MainActivity.kt index 89aa7d2408..17f2923ba6 100644 --- a/examples/llama.android/app/src/main/java/com/example/llama/revamp/MainActivity.kt +++ b/examples/llama.android/app/src/main/java/com/example/llama/revamp/MainActivity.kt @@ -48,6 +48,7 @@ import com.example.llama.revamp.viewmodel.BenchmarkViewModel import com.example.llama.revamp.viewmodel.ConversationViewModel import com.example.llama.revamp.viewmodel.MainViewModel import com.example.llama.revamp.viewmodel.ModelLoadingViewModel +import com.example.llama.revamp.viewmodel.ModelSelectionViewModel import com.example.llama.revamp.viewmodel.ModelsManagementViewModel import com.example.llama.revamp.viewmodel.PerformanceViewModel import dagger.hilt.android.AndroidEntryPoint @@ -74,6 +75,7 @@ class MainActivity : ComponentActivity() { fun AppContent( mainViewModel: MainViewModel = hiltViewModel(), performanceViewModel: PerformanceViewModel = hiltViewModel(), + modelSelectionViewModel: ModelSelectionViewModel = hiltViewModel(), modelLoadingViewModel: ModelLoadingViewModel = hiltViewModel(), benchmarkViewModel: BenchmarkViewModel = hiltViewModel(), conversationViewModel: ConversationViewModel = hiltViewModel(), @@ -104,9 +106,7 @@ fun AppContent( val drawerGesturesEnabled by remember(currentRoute, drawerState.currentValue) { derivedStateOf { // Always allow gesture dismissal when drawer is open - if (drawerState.currentValue == DrawerValue.Open) true - // Only enable drawer opening by gesture on these screens - else currentRoute == AppDestinations.MODEL_SELECTION_ROUTE + if (drawerState.currentValue == DrawerValue.Open) true else false } } val openDrawer: () -> Unit = { coroutineScope.launch { drawerState.open() } } @@ -118,7 +118,10 @@ fun AppContent( ScaffoldConfig( topBarConfig = TopBarConfig.Default( title = "Models", - navigationIcon = NavigationIcon.Menu(openDrawer) + navigationIcon = NavigationIcon.Menu { + modelSelectionViewModel.resetSelection() + openDrawer() + } ) ) @@ -290,12 +293,13 @@ fun AppContent( // Model Selection Screen composable(AppDestinations.MODEL_SELECTION_ROUTE) { ModelSelectionScreen( - onModelSelected = { modelInfo -> + onModelConfirmed = { modelInfo -> navigationActions.navigateToModelLoading() }, onManageModelsClicked = { navigationActions.navigateToModelsManagement() }, + viewModel = modelSelectionViewModel ) } diff --git a/examples/llama.android/app/src/main/java/com/example/llama/revamp/ui/components/ModelCards.kt b/examples/llama.android/app/src/main/java/com/example/llama/revamp/ui/components/ModelCards.kt index 487c0a14fb..b8a30f9a8a 100644 --- a/examples/llama.android/app/src/main/java/com/example/llama/revamp/ui/components/ModelCards.kt +++ b/examples/llama.android/app/src/main/java/com/example/llama/revamp/ui/components/ModelCards.kt @@ -140,25 +140,26 @@ fun ModelCardContentCore( * toggled by clicking on the content area of the card. * * @param model The model information to display - * @param onClick Action to perform when the card is clicked (for selection) - * @param expanded Whether additional details are currently shown * @param isSelected Optional selection state (shows checkbox when not null) + * @param onSelected Action to perform when the card is selected (in multi-selection mode) + * @param isExpanded Whether additional details is expanded or shrunk + * @param onExpanded Action to perform when the card is expanded or shrunk */ @Composable fun ModelCardExpandable( model: ModelInfo, - onClick: () -> Unit, - expanded: Boolean, isSelected: Boolean? = null, + onSelected: ((Boolean) -> Unit)? = null, + isExpanded: Boolean = false, + onExpanded: ((Boolean) -> Unit)? = null, ) { - var isExpanded by remember { mutableStateOf(expanded) } - CompositionLocalProvider(LocalMinimumInteractiveComponentSize provides Dp.Unspecified) { Card( modifier = Modifier .fillMaxWidth() - .clickable { isExpanded = !isExpanded } - , + .clickable { + onExpanded?.invoke(!isExpanded) + }, colors = when (isSelected) { true -> CardDefaults.cardColors(containerColor = MaterialTheme.colorScheme.primaryContainer) false -> CardDefaults.cardColors() @@ -173,16 +174,17 @@ fun ModelCardExpandable( verticalAlignment = Alignment.Top ) { // Show checkbox if in selection mode - if (isSelected != null) { + isSelected?.let { selected -> Checkbox( - checked = isSelected, - onCheckedChange = { onClick() }, + checked = selected, + onCheckedChange = { onSelected?.invoke(it) }, modifier = Modifier.padding(top = 16.dp, start = 16.dp) ) } Box( - modifier = Modifier.weight(1f) + modifier = Modifier + .weight(1f) .padding(start = 16.dp, top = 16.dp, end = 16.dp) ) { // Core content always visible @@ -197,7 +199,9 @@ fun ModelCardExpandable( exit = shrinkVertically() + fadeOut() ) { Box( - modifier = Modifier.weight(1f).padding(horizontal = 16.dp) + modifier = Modifier + .weight(1f) + .padding(horizontal = 16.dp) ) { ExpandableModelDetails(model = model) } diff --git a/examples/llama.android/app/src/main/java/com/example/llama/revamp/ui/screens/ModelSelectionScreen.kt b/examples/llama.android/app/src/main/java/com/example/llama/revamp/ui/screens/ModelSelectionScreen.kt index a7a8fc4cb7..2d87c3be6e 100644 --- a/examples/llama.android/app/src/main/java/com/example/llama/revamp/ui/screens/ModelSelectionScreen.kt +++ b/examples/llama.android/app/src/main/java/com/example/llama/revamp/ui/screens/ModelSelectionScreen.kt @@ -1,6 +1,12 @@ package com.example.llama.revamp.ui.screens +import androidx.compose.animation.AnimatedVisibility +import androidx.compose.animation.fadeIn +import androidx.compose.animation.fadeOut +import androidx.compose.animation.scaleIn +import androidx.compose.animation.scaleOut import androidx.compose.foundation.layout.Arrangement +import androidx.compose.foundation.layout.Box import androidx.compose.foundation.layout.Column import androidx.compose.foundation.layout.Spacer import androidx.compose.foundation.layout.fillMaxSize @@ -13,8 +19,10 @@ import androidx.compose.foundation.lazy.items import androidx.compose.material.icons.Icons import androidx.compose.material.icons.filled.Add import androidx.compose.material.icons.filled.FolderOpen +import androidx.compose.material.icons.filled.PlayArrow import androidx.compose.material3.Button import androidx.compose.material3.ExperimentalMaterial3Api +import androidx.compose.material3.FloatingActionButton import androidx.compose.material3.Icon import androidx.compose.material3.MaterialTheme import androidx.compose.material3.Text @@ -25,7 +33,6 @@ import androidx.compose.ui.Alignment import androidx.compose.ui.Modifier import androidx.compose.ui.text.style.TextAlign import androidx.compose.ui.unit.dp -import androidx.hilt.navigation.compose.hiltViewModel import com.example.llama.revamp.data.model.ModelInfo import com.example.llama.revamp.ui.components.ModelCardExpandable import com.example.llama.revamp.viewmodel.ModelSelectionViewModel @@ -33,43 +40,61 @@ import com.example.llama.revamp.viewmodel.ModelSelectionViewModel @OptIn(ExperimentalMaterial3Api::class) @Composable fun ModelSelectionScreen( - onModelSelected: (ModelInfo) -> Unit, + onModelConfirmed: (ModelInfo) -> Unit, onManageModelsClicked: () -> Unit, - viewModel: ModelSelectionViewModel = hiltViewModel(), + viewModel: ModelSelectionViewModel, ) { val models by viewModel.availableModels.collectAsState() + val preselectedModel by viewModel.preselectedModel.collectAsState() val handleModelSelection = { model: ModelInfo -> - viewModel.selectModel(model) - onModelSelected(model) + viewModel.confirmSelectedModel(model) + onModelConfirmed(model) } - Column( - modifier = Modifier - .fillMaxSize() - .padding(horizontal = 16.dp) - ) { - if (models.isEmpty()) { - EmptyModelsView(onManageModelsClicked) - } else { - LazyColumn { - items(models) { model -> - ModelCardExpandable( - model = model, - onClick = { handleModelSelection(model) }, - expanded = false, - isSelected = null, // Not in selection mode - // TODO-han.yin: refactor this -// actionButton = { -// ModelCardActions.PlayButton { -// handleModelSelection(model) -// } -// }, - ) - Spacer(modifier = Modifier.height(8.dp)) + Box(modifier = Modifier.fillMaxSize()) { + Column( + modifier = Modifier + .fillMaxSize() + .padding(horizontal = 16.dp) + ) { + if (models.isEmpty()) { + EmptyModelsView(onManageModelsClicked) + } else { + LazyColumn( + verticalArrangement = Arrangement.spacedBy(12.dp) + ) { + items(items = models, key = { it.id }) { model -> + ModelCardExpandable( + model = model, + isSelected = if (model == preselectedModel) true else null, + isExpanded = model == preselectedModel, + onExpanded = { expanded -> + viewModel.preselectModel(model, expanded) + } + ) + } } } } + + // Show FAB if a model is selected + AnimatedVisibility( + modifier = Modifier.padding(16.dp).align(Alignment.BottomEnd), + visible = preselectedModel != null, + enter = scaleIn() + fadeIn(), + exit = scaleOut() + fadeOut() + ) { + FloatingActionButton( + onClick = { preselectedModel?.let { handleModelSelection(it) } }, + containerColor = MaterialTheme.colorScheme.primary + ) { + Icon( + imageVector = Icons.Default.PlayArrow, + contentDescription = "Start with selected model" + ) + } + } } } diff --git a/examples/llama.android/app/src/main/java/com/example/llama/revamp/ui/screens/ModelsManagementScreen.kt b/examples/llama.android/app/src/main/java/com/example/llama/revamp/ui/screens/ModelsManagementScreen.kt index 73265e906f..7af1b1d02c 100644 --- a/examples/llama.android/app/src/main/java/com/example/llama/revamp/ui/screens/ModelsManagementScreen.kt +++ b/examples/llama.android/app/src/main/java/com/example/llama/revamp/ui/screens/ModelsManagementScreen.kt @@ -26,6 +26,8 @@ import androidx.compose.runtime.Composable import androidx.compose.runtime.LaunchedEffect import androidx.compose.runtime.collectAsState import androidx.compose.runtime.getValue +import androidx.compose.runtime.mutableStateMapOf +import androidx.compose.runtime.remember import androidx.compose.ui.Alignment import androidx.compose.ui.Modifier import androidx.compose.ui.text.font.FontStyle @@ -33,6 +35,7 @@ import androidx.compose.ui.text.style.TextAlign import androidx.compose.ui.text.style.TextOverflow import androidx.compose.ui.unit.dp import androidx.compose.ui.window.DialogProperties +import com.example.llama.revamp.data.model.ModelInfo import com.example.llama.revamp.ui.components.ModelCardExpandable import com.example.llama.revamp.ui.components.ScaffoldEvent import com.example.llama.revamp.util.formatFileByteSize @@ -57,6 +60,8 @@ fun ModelsManagementScreen( val isMultiSelectionMode by viewModel.isMultiSelectionMode.collectAsState() val selectedModels by viewModel.selectedModels.collectAsState() + var expandedModels = remember { mutableStateMapOf() } + BackHandler( enabled = isMultiSelectionMode || managementState is Importation.Importing @@ -73,33 +78,28 @@ fun ModelsManagementScreen( Box(modifier = Modifier.fillMaxSize()) { // Model cards LazyColumn( - modifier = Modifier - .fillMaxSize() - .padding(horizontal = 16.dp) + modifier = Modifier.fillMaxSize().padding(horizontal = 16.dp), + verticalArrangement = Arrangement.spacedBy(12.dp), ) { items(items = sortedModels, key = { it.id }) { model -> val isSelected = if (isMultiSelectionMode) selectedModels.contains(model.id) else null ModelCardExpandable( model = model, - onClick = { + isSelected = isSelected, + onSelected = { if (isMultiSelectionMode) { viewModel.toggleModelSelectionById(model.id) - } else { - viewModel.viewModelDetails(model) } }, - expanded = isSelected == true, - isSelected = isSelected, - // TODO-han.yin: refactor this -// actionButton = -// if (!isMultiSelectionMode) { -// { -// ModelCardActions.InfoButton( -// onClick = { viewModel.viewModelDetails(model) } -// ) -// } -// } else null + isExpanded = expandedModels.contains(model.id), + onExpanded = { expanded -> + if (expanded) { + expandedModels.put(model.id, model) + } else { + expandedModels.remove(model.id) + } + } ) } } diff --git a/examples/llama.android/app/src/main/java/com/example/llama/revamp/viewmodel/ModelLoadingViewModel.kt b/examples/llama.android/app/src/main/java/com/example/llama/revamp/viewmodel/ModelLoadingViewModel.kt index 6e15a6e359..9383f6ab5d 100644 --- a/examples/llama.android/app/src/main/java/com/example/llama/revamp/viewmodel/ModelLoadingViewModel.kt +++ b/examples/llama.android/app/src/main/java/com/example/llama/revamp/viewmodel/ModelLoadingViewModel.kt @@ -2,6 +2,7 @@ package com.example.llama.revamp.viewmodel import androidx.lifecycle.viewModelScope import com.example.llama.revamp.data.model.SystemPrompt +import com.example.llama.revamp.data.repository.ModelRepository import com.example.llama.revamp.data.repository.SystemPromptRepository import com.example.llama.revamp.engine.ModelLoadingMetrics import com.example.llama.revamp.engine.ModelLoadingService @@ -15,7 +16,8 @@ import javax.inject.Inject @HiltViewModel class ModelLoadingViewModel @Inject constructor( private val modelLoadingService: ModelLoadingService, - private val repository: SystemPromptRepository + private val systemPromptRepository: SystemPromptRepository, + private val modelRepository: ModelRepository, ) : ModelUnloadingViewModel(modelLoadingService) { /** @@ -26,7 +28,7 @@ class ModelLoadingViewModel @Inject constructor( /** * Preset prompts */ - val presetPrompts: StateFlow> = repository.getPresetPrompts() + val presetPrompts: StateFlow> = systemPromptRepository.getPresetPrompts() .stateIn( scope = viewModelScope, started = SharingStarted.WhileSubscribed(SUBSCRIPTION_TIMEOUT_MS), @@ -36,7 +38,7 @@ class ModelLoadingViewModel @Inject constructor( /** * Recent prompts */ - val recentPrompts: StateFlow> = repository.getRecentPrompts() + val recentPrompts: StateFlow> = systemPromptRepository.getRecentPrompts() .stateIn( scope = viewModelScope, started = SharingStarted.WhileSubscribed(SUBSCRIPTION_TIMEOUT_MS), @@ -48,7 +50,7 @@ class ModelLoadingViewModel @Inject constructor( */ fun savePromptToRecents(prompt: SystemPrompt) { viewModelScope.launch { - repository.savePromptToRecents(prompt) + systemPromptRepository.savePromptToRecents(prompt) } } @@ -57,7 +59,7 @@ class ModelLoadingViewModel @Inject constructor( */ fun saveCustomPromptToRecents(content: String) { viewModelScope.launch { - repository.saveCustomPrompt(content) + systemPromptRepository.saveCustomPrompt(content) } } @@ -66,7 +68,7 @@ class ModelLoadingViewModel @Inject constructor( */ fun deletePrompt(id: String) { viewModelScope.launch { - repository.deletePrompt(id) + systemPromptRepository.deletePrompt(id) } } @@ -75,7 +77,7 @@ class ModelLoadingViewModel @Inject constructor( */ fun clearRecentPrompts() { viewModelScope.launch { - repository.deleteAllPrompts() + systemPromptRepository.deleteAllPrompts() } } @@ -84,6 +86,9 @@ class ModelLoadingViewModel @Inject constructor( */ fun onBenchmarkSelected(onNavigateToBenchmark: (ModelLoadingMetrics) -> Unit) = viewModelScope.launch { + selectedModel.value?.let { + modelRepository.updateModelLastUsed(it.id) + } onNavigateToBenchmark(modelLoadingService.loadModelForBenchmark()) } @@ -95,6 +100,9 @@ class ModelLoadingViewModel @Inject constructor( systemPrompt: String? = null, onNavigateToConversation: (ModelLoadingMetrics) -> Unit ) = viewModelScope.launch { + selectedModel.value?.let { + modelRepository.updateModelLastUsed(it.id) + } onNavigateToConversation(modelLoadingService.loadModelForConversation(systemPrompt)) } diff --git a/examples/llama.android/app/src/main/java/com/example/llama/revamp/viewmodel/ModelSelectionViewModel.kt b/examples/llama.android/app/src/main/java/com/example/llama/revamp/viewmodel/ModelSelectionViewModel.kt index dd68bfc63e..304422f110 100644 --- a/examples/llama.android/app/src/main/java/com/example/llama/revamp/viewmodel/ModelSelectionViewModel.kt +++ b/examples/llama.android/app/src/main/java/com/example/llama/revamp/viewmodel/ModelSelectionViewModel.kt @@ -6,19 +6,24 @@ import com.example.llama.revamp.data.model.ModelInfo import com.example.llama.revamp.data.repository.ModelRepository import com.example.llama.revamp.engine.InferenceService import dagger.hilt.android.lifecycle.HiltViewModel +import kotlinx.coroutines.flow.MutableStateFlow import kotlinx.coroutines.flow.SharingStarted import kotlinx.coroutines.flow.StateFlow +import kotlinx.coroutines.flow.asStateFlow import kotlinx.coroutines.flow.stateIn -import kotlinx.coroutines.launch +import kotlinx.coroutines.flow.update import javax.inject.Inject @HiltViewModel class ModelSelectionViewModel @Inject constructor( private val inferenceService: InferenceService, - private val modelRepository: ModelRepository + modelRepository: ModelRepository ) : ViewModel() { + private val _preselectedModel = MutableStateFlow(null) + val preselectedModel: StateFlow = _preselectedModel.asStateFlow() + /** * Available models for selection */ @@ -30,14 +35,24 @@ class ModelSelectionViewModel @Inject constructor( ) /** - * Select a model and update its last used timestamp + * Pre-select a model */ - fun selectModel(modelInfo: ModelInfo) { + fun preselectModel(modelInfo: ModelInfo, preselected: Boolean) = + _preselectedModel.update { current -> + if (preselected) modelInfo else null + } + + /** + * Confirm currently selected model + */ + fun confirmSelectedModel(modelInfo: ModelInfo) = inferenceService.setCurrentModel(modelInfo) - viewModelScope.launch { - modelRepository.updateModelLastUsed(modelInfo.id) - } + /** + * Reset selected model to none (before navigating away) + */ + fun resetSelection() { + _preselectedModel.value = null } companion object {