UI: update ModelSelectionScreen with a preselect mechanism

This commit is contained in:
Han Yin 2025-04-20 16:24:31 -07:00
parent b81a0c6e6d
commit c12ef7a779
6 changed files with 133 additions and 77 deletions

View File

@ -48,6 +48,7 @@ import com.example.llama.revamp.viewmodel.BenchmarkViewModel
import com.example.llama.revamp.viewmodel.ConversationViewModel import com.example.llama.revamp.viewmodel.ConversationViewModel
import com.example.llama.revamp.viewmodel.MainViewModel import com.example.llama.revamp.viewmodel.MainViewModel
import com.example.llama.revamp.viewmodel.ModelLoadingViewModel import com.example.llama.revamp.viewmodel.ModelLoadingViewModel
import com.example.llama.revamp.viewmodel.ModelSelectionViewModel
import com.example.llama.revamp.viewmodel.ModelsManagementViewModel import com.example.llama.revamp.viewmodel.ModelsManagementViewModel
import com.example.llama.revamp.viewmodel.PerformanceViewModel import com.example.llama.revamp.viewmodel.PerformanceViewModel
import dagger.hilt.android.AndroidEntryPoint import dagger.hilt.android.AndroidEntryPoint
@ -74,6 +75,7 @@ class MainActivity : ComponentActivity() {
fun AppContent( fun AppContent(
mainViewModel: MainViewModel = hiltViewModel(), mainViewModel: MainViewModel = hiltViewModel(),
performanceViewModel: PerformanceViewModel = hiltViewModel(), performanceViewModel: PerformanceViewModel = hiltViewModel(),
modelSelectionViewModel: ModelSelectionViewModel = hiltViewModel(),
modelLoadingViewModel: ModelLoadingViewModel = hiltViewModel(), modelLoadingViewModel: ModelLoadingViewModel = hiltViewModel(),
benchmarkViewModel: BenchmarkViewModel = hiltViewModel(), benchmarkViewModel: BenchmarkViewModel = hiltViewModel(),
conversationViewModel: ConversationViewModel = hiltViewModel(), conversationViewModel: ConversationViewModel = hiltViewModel(),
@ -104,9 +106,7 @@ fun AppContent(
val drawerGesturesEnabled by remember(currentRoute, drawerState.currentValue) { val drawerGesturesEnabled by remember(currentRoute, drawerState.currentValue) {
derivedStateOf { derivedStateOf {
// Always allow gesture dismissal when drawer is open // Always allow gesture dismissal when drawer is open
if (drawerState.currentValue == DrawerValue.Open) true if (drawerState.currentValue == DrawerValue.Open) true else false
// Only enable drawer opening by gesture on these screens
else currentRoute == AppDestinations.MODEL_SELECTION_ROUTE
} }
} }
val openDrawer: () -> Unit = { coroutineScope.launch { drawerState.open() } } val openDrawer: () -> Unit = { coroutineScope.launch { drawerState.open() } }
@ -118,7 +118,10 @@ fun AppContent(
ScaffoldConfig( ScaffoldConfig(
topBarConfig = TopBarConfig.Default( topBarConfig = TopBarConfig.Default(
title = "Models", title = "Models",
navigationIcon = NavigationIcon.Menu(openDrawer) navigationIcon = NavigationIcon.Menu {
modelSelectionViewModel.resetSelection()
openDrawer()
}
) )
) )
@ -290,12 +293,13 @@ fun AppContent(
// Model Selection Screen // Model Selection Screen
composable(AppDestinations.MODEL_SELECTION_ROUTE) { composable(AppDestinations.MODEL_SELECTION_ROUTE) {
ModelSelectionScreen( ModelSelectionScreen(
onModelSelected = { modelInfo -> onModelConfirmed = { modelInfo ->
navigationActions.navigateToModelLoading() navigationActions.navigateToModelLoading()
}, },
onManageModelsClicked = { onManageModelsClicked = {
navigationActions.navigateToModelsManagement() navigationActions.navigateToModelsManagement()
}, },
viewModel = modelSelectionViewModel
) )
} }

View File

@ -140,25 +140,26 @@ fun ModelCardContentCore(
* toggled by clicking on the content area of the card. * toggled by clicking on the content area of the card.
* *
* @param model The model information to display * @param model The model information to display
* @param onClick Action to perform when the card is clicked (for selection)
* @param expanded Whether additional details are currently shown
* @param isSelected Optional selection state (shows checkbox when not null) * @param isSelected Optional selection state (shows checkbox when not null)
* @param onSelected Action to perform when the card is selected (in multi-selection mode)
* @param isExpanded Whether additional details is expanded or shrunk
* @param onExpanded Action to perform when the card is expanded or shrunk
*/ */
@Composable @Composable
fun ModelCardExpandable( fun ModelCardExpandable(
model: ModelInfo, model: ModelInfo,
onClick: () -> Unit,
expanded: Boolean,
isSelected: Boolean? = null, isSelected: Boolean? = null,
onSelected: ((Boolean) -> Unit)? = null,
isExpanded: Boolean = false,
onExpanded: ((Boolean) -> Unit)? = null,
) { ) {
var isExpanded by remember { mutableStateOf(expanded) }
CompositionLocalProvider(LocalMinimumInteractiveComponentSize provides Dp.Unspecified) { CompositionLocalProvider(LocalMinimumInteractiveComponentSize provides Dp.Unspecified) {
Card( Card(
modifier = Modifier modifier = Modifier
.fillMaxWidth() .fillMaxWidth()
.clickable { isExpanded = !isExpanded } .clickable {
, onExpanded?.invoke(!isExpanded)
},
colors = when (isSelected) { colors = when (isSelected) {
true -> CardDefaults.cardColors(containerColor = MaterialTheme.colorScheme.primaryContainer) true -> CardDefaults.cardColors(containerColor = MaterialTheme.colorScheme.primaryContainer)
false -> CardDefaults.cardColors() false -> CardDefaults.cardColors()
@ -173,16 +174,17 @@ fun ModelCardExpandable(
verticalAlignment = Alignment.Top verticalAlignment = Alignment.Top
) { ) {
// Show checkbox if in selection mode // Show checkbox if in selection mode
if (isSelected != null) { isSelected?.let { selected ->
Checkbox( Checkbox(
checked = isSelected, checked = selected,
onCheckedChange = { onClick() }, onCheckedChange = { onSelected?.invoke(it) },
modifier = Modifier.padding(top = 16.dp, start = 16.dp) modifier = Modifier.padding(top = 16.dp, start = 16.dp)
) )
} }
Box( Box(
modifier = Modifier.weight(1f) modifier = Modifier
.weight(1f)
.padding(start = 16.dp, top = 16.dp, end = 16.dp) .padding(start = 16.dp, top = 16.dp, end = 16.dp)
) { ) {
// Core content always visible // Core content always visible
@ -197,7 +199,9 @@ fun ModelCardExpandable(
exit = shrinkVertically() + fadeOut() exit = shrinkVertically() + fadeOut()
) { ) {
Box( Box(
modifier = Modifier.weight(1f).padding(horizontal = 16.dp) modifier = Modifier
.weight(1f)
.padding(horizontal = 16.dp)
) { ) {
ExpandableModelDetails(model = model) ExpandableModelDetails(model = model)
} }

View File

@ -1,6 +1,12 @@
package com.example.llama.revamp.ui.screens package com.example.llama.revamp.ui.screens
import androidx.compose.animation.AnimatedVisibility
import androidx.compose.animation.fadeIn
import androidx.compose.animation.fadeOut
import androidx.compose.animation.scaleIn
import androidx.compose.animation.scaleOut
import androidx.compose.foundation.layout.Arrangement import androidx.compose.foundation.layout.Arrangement
import androidx.compose.foundation.layout.Box
import androidx.compose.foundation.layout.Column import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.Spacer import androidx.compose.foundation.layout.Spacer
import androidx.compose.foundation.layout.fillMaxSize import androidx.compose.foundation.layout.fillMaxSize
@ -13,8 +19,10 @@ import androidx.compose.foundation.lazy.items
import androidx.compose.material.icons.Icons import androidx.compose.material.icons.Icons
import androidx.compose.material.icons.filled.Add import androidx.compose.material.icons.filled.Add
import androidx.compose.material.icons.filled.FolderOpen import androidx.compose.material.icons.filled.FolderOpen
import androidx.compose.material.icons.filled.PlayArrow
import androidx.compose.material3.Button import androidx.compose.material3.Button
import androidx.compose.material3.ExperimentalMaterial3Api import androidx.compose.material3.ExperimentalMaterial3Api
import androidx.compose.material3.FloatingActionButton
import androidx.compose.material3.Icon import androidx.compose.material3.Icon
import androidx.compose.material3.MaterialTheme import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.Text import androidx.compose.material3.Text
@ -25,7 +33,6 @@ import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier import androidx.compose.ui.Modifier
import androidx.compose.ui.text.style.TextAlign import androidx.compose.ui.text.style.TextAlign
import androidx.compose.ui.unit.dp import androidx.compose.ui.unit.dp
import androidx.hilt.navigation.compose.hiltViewModel
import com.example.llama.revamp.data.model.ModelInfo import com.example.llama.revamp.data.model.ModelInfo
import com.example.llama.revamp.ui.components.ModelCardExpandable import com.example.llama.revamp.ui.components.ModelCardExpandable
import com.example.llama.revamp.viewmodel.ModelSelectionViewModel import com.example.llama.revamp.viewmodel.ModelSelectionViewModel
@ -33,43 +40,61 @@ import com.example.llama.revamp.viewmodel.ModelSelectionViewModel
@OptIn(ExperimentalMaterial3Api::class) @OptIn(ExperimentalMaterial3Api::class)
@Composable @Composable
fun ModelSelectionScreen( fun ModelSelectionScreen(
onModelSelected: (ModelInfo) -> Unit, onModelConfirmed: (ModelInfo) -> Unit,
onManageModelsClicked: () -> Unit, onManageModelsClicked: () -> Unit,
viewModel: ModelSelectionViewModel = hiltViewModel(), viewModel: ModelSelectionViewModel,
) { ) {
val models by viewModel.availableModels.collectAsState() val models by viewModel.availableModels.collectAsState()
val preselectedModel by viewModel.preselectedModel.collectAsState()
val handleModelSelection = { model: ModelInfo -> val handleModelSelection = { model: ModelInfo ->
viewModel.selectModel(model) viewModel.confirmSelectedModel(model)
onModelSelected(model) onModelConfirmed(model)
} }
Column( Box(modifier = Modifier.fillMaxSize()) {
modifier = Modifier Column(
.fillMaxSize() modifier = Modifier
.padding(horizontal = 16.dp) .fillMaxSize()
) { .padding(horizontal = 16.dp)
if (models.isEmpty()) { ) {
EmptyModelsView(onManageModelsClicked) if (models.isEmpty()) {
} else { EmptyModelsView(onManageModelsClicked)
LazyColumn { } else {
items(models) { model -> LazyColumn(
ModelCardExpandable( verticalArrangement = Arrangement.spacedBy(12.dp)
model = model, ) {
onClick = { handleModelSelection(model) }, items(items = models, key = { it.id }) { model ->
expanded = false, ModelCardExpandable(
isSelected = null, // Not in selection mode model = model,
// TODO-han.yin: refactor this isSelected = if (model == preselectedModel) true else null,
// actionButton = { isExpanded = model == preselectedModel,
// ModelCardActions.PlayButton { onExpanded = { expanded ->
// handleModelSelection(model) viewModel.preselectModel(model, expanded)
// } }
// }, )
) }
Spacer(modifier = Modifier.height(8.dp))
} }
} }
} }
// Show FAB if a model is selected
AnimatedVisibility(
modifier = Modifier.padding(16.dp).align(Alignment.BottomEnd),
visible = preselectedModel != null,
enter = scaleIn() + fadeIn(),
exit = scaleOut() + fadeOut()
) {
FloatingActionButton(
onClick = { preselectedModel?.let { handleModelSelection(it) } },
containerColor = MaterialTheme.colorScheme.primary
) {
Icon(
imageVector = Icons.Default.PlayArrow,
contentDescription = "Start with selected model"
)
}
}
} }
} }

View File

@ -26,6 +26,8 @@ import androidx.compose.runtime.Composable
import androidx.compose.runtime.LaunchedEffect import androidx.compose.runtime.LaunchedEffect
import androidx.compose.runtime.collectAsState import androidx.compose.runtime.collectAsState
import androidx.compose.runtime.getValue import androidx.compose.runtime.getValue
import androidx.compose.runtime.mutableStateMapOf
import androidx.compose.runtime.remember
import androidx.compose.ui.Alignment import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier import androidx.compose.ui.Modifier
import androidx.compose.ui.text.font.FontStyle import androidx.compose.ui.text.font.FontStyle
@ -33,6 +35,7 @@ import androidx.compose.ui.text.style.TextAlign
import androidx.compose.ui.text.style.TextOverflow import androidx.compose.ui.text.style.TextOverflow
import androidx.compose.ui.unit.dp import androidx.compose.ui.unit.dp
import androidx.compose.ui.window.DialogProperties import androidx.compose.ui.window.DialogProperties
import com.example.llama.revamp.data.model.ModelInfo
import com.example.llama.revamp.ui.components.ModelCardExpandable import com.example.llama.revamp.ui.components.ModelCardExpandable
import com.example.llama.revamp.ui.components.ScaffoldEvent import com.example.llama.revamp.ui.components.ScaffoldEvent
import com.example.llama.revamp.util.formatFileByteSize import com.example.llama.revamp.util.formatFileByteSize
@ -57,6 +60,8 @@ fun ModelsManagementScreen(
val isMultiSelectionMode by viewModel.isMultiSelectionMode.collectAsState() val isMultiSelectionMode by viewModel.isMultiSelectionMode.collectAsState()
val selectedModels by viewModel.selectedModels.collectAsState() val selectedModels by viewModel.selectedModels.collectAsState()
var expandedModels = remember { mutableStateMapOf<String, ModelInfo>() }
BackHandler( BackHandler(
enabled = isMultiSelectionMode enabled = isMultiSelectionMode
|| managementState is Importation.Importing || managementState is Importation.Importing
@ -73,33 +78,28 @@ fun ModelsManagementScreen(
Box(modifier = Modifier.fillMaxSize()) { Box(modifier = Modifier.fillMaxSize()) {
// Model cards // Model cards
LazyColumn( LazyColumn(
modifier = Modifier modifier = Modifier.fillMaxSize().padding(horizontal = 16.dp),
.fillMaxSize() verticalArrangement = Arrangement.spacedBy(12.dp),
.padding(horizontal = 16.dp)
) { ) {
items(items = sortedModels, key = { it.id }) { model -> items(items = sortedModels, key = { it.id }) { model ->
val isSelected = if (isMultiSelectionMode) selectedModels.contains(model.id) else null val isSelected = if (isMultiSelectionMode) selectedModels.contains(model.id) else null
ModelCardExpandable( ModelCardExpandable(
model = model, model = model,
onClick = { isSelected = isSelected,
onSelected = {
if (isMultiSelectionMode) { if (isMultiSelectionMode) {
viewModel.toggleModelSelectionById(model.id) viewModel.toggleModelSelectionById(model.id)
} else {
viewModel.viewModelDetails(model)
} }
}, },
expanded = isSelected == true, isExpanded = expandedModels.contains(model.id),
isSelected = isSelected, onExpanded = { expanded ->
// TODO-han.yin: refactor this if (expanded) {
// actionButton = expandedModels.put(model.id, model)
// if (!isMultiSelectionMode) { } else {
// { expandedModels.remove(model.id)
// ModelCardActions.InfoButton( }
// onClick = { viewModel.viewModelDetails(model) } }
// )
// }
// } else null
) )
} }
} }

View File

@ -2,6 +2,7 @@ package com.example.llama.revamp.viewmodel
import androidx.lifecycle.viewModelScope import androidx.lifecycle.viewModelScope
import com.example.llama.revamp.data.model.SystemPrompt import com.example.llama.revamp.data.model.SystemPrompt
import com.example.llama.revamp.data.repository.ModelRepository
import com.example.llama.revamp.data.repository.SystemPromptRepository import com.example.llama.revamp.data.repository.SystemPromptRepository
import com.example.llama.revamp.engine.ModelLoadingMetrics import com.example.llama.revamp.engine.ModelLoadingMetrics
import com.example.llama.revamp.engine.ModelLoadingService import com.example.llama.revamp.engine.ModelLoadingService
@ -15,7 +16,8 @@ import javax.inject.Inject
@HiltViewModel @HiltViewModel
class ModelLoadingViewModel @Inject constructor( class ModelLoadingViewModel @Inject constructor(
private val modelLoadingService: ModelLoadingService, private val modelLoadingService: ModelLoadingService,
private val repository: SystemPromptRepository private val systemPromptRepository: SystemPromptRepository,
private val modelRepository: ModelRepository,
) : ModelUnloadingViewModel(modelLoadingService) { ) : ModelUnloadingViewModel(modelLoadingService) {
/** /**
@ -26,7 +28,7 @@ class ModelLoadingViewModel @Inject constructor(
/** /**
* Preset prompts * Preset prompts
*/ */
val presetPrompts: StateFlow<List<SystemPrompt>> = repository.getPresetPrompts() val presetPrompts: StateFlow<List<SystemPrompt>> = systemPromptRepository.getPresetPrompts()
.stateIn( .stateIn(
scope = viewModelScope, scope = viewModelScope,
started = SharingStarted.WhileSubscribed(SUBSCRIPTION_TIMEOUT_MS), started = SharingStarted.WhileSubscribed(SUBSCRIPTION_TIMEOUT_MS),
@ -36,7 +38,7 @@ class ModelLoadingViewModel @Inject constructor(
/** /**
* Recent prompts * Recent prompts
*/ */
val recentPrompts: StateFlow<List<SystemPrompt>> = repository.getRecentPrompts() val recentPrompts: StateFlow<List<SystemPrompt>> = systemPromptRepository.getRecentPrompts()
.stateIn( .stateIn(
scope = viewModelScope, scope = viewModelScope,
started = SharingStarted.WhileSubscribed(SUBSCRIPTION_TIMEOUT_MS), started = SharingStarted.WhileSubscribed(SUBSCRIPTION_TIMEOUT_MS),
@ -48,7 +50,7 @@ class ModelLoadingViewModel @Inject constructor(
*/ */
fun savePromptToRecents(prompt: SystemPrompt) { fun savePromptToRecents(prompt: SystemPrompt) {
viewModelScope.launch { viewModelScope.launch {
repository.savePromptToRecents(prompt) systemPromptRepository.savePromptToRecents(prompt)
} }
} }
@ -57,7 +59,7 @@ class ModelLoadingViewModel @Inject constructor(
*/ */
fun saveCustomPromptToRecents(content: String) { fun saveCustomPromptToRecents(content: String) {
viewModelScope.launch { viewModelScope.launch {
repository.saveCustomPrompt(content) systemPromptRepository.saveCustomPrompt(content)
} }
} }
@ -66,7 +68,7 @@ class ModelLoadingViewModel @Inject constructor(
*/ */
fun deletePrompt(id: String) { fun deletePrompt(id: String) {
viewModelScope.launch { viewModelScope.launch {
repository.deletePrompt(id) systemPromptRepository.deletePrompt(id)
} }
} }
@ -75,7 +77,7 @@ class ModelLoadingViewModel @Inject constructor(
*/ */
fun clearRecentPrompts() { fun clearRecentPrompts() {
viewModelScope.launch { viewModelScope.launch {
repository.deleteAllPrompts() systemPromptRepository.deleteAllPrompts()
} }
} }
@ -84,6 +86,9 @@ class ModelLoadingViewModel @Inject constructor(
*/ */
fun onBenchmarkSelected(onNavigateToBenchmark: (ModelLoadingMetrics) -> Unit) = fun onBenchmarkSelected(onNavigateToBenchmark: (ModelLoadingMetrics) -> Unit) =
viewModelScope.launch { viewModelScope.launch {
selectedModel.value?.let {
modelRepository.updateModelLastUsed(it.id)
}
onNavigateToBenchmark(modelLoadingService.loadModelForBenchmark()) onNavigateToBenchmark(modelLoadingService.loadModelForBenchmark())
} }
@ -95,6 +100,9 @@ class ModelLoadingViewModel @Inject constructor(
systemPrompt: String? = null, systemPrompt: String? = null,
onNavigateToConversation: (ModelLoadingMetrics) -> Unit onNavigateToConversation: (ModelLoadingMetrics) -> Unit
) = viewModelScope.launch { ) = viewModelScope.launch {
selectedModel.value?.let {
modelRepository.updateModelLastUsed(it.id)
}
onNavigateToConversation(modelLoadingService.loadModelForConversation(systemPrompt)) onNavigateToConversation(modelLoadingService.loadModelForConversation(systemPrompt))
} }

View File

@ -6,19 +6,24 @@ import com.example.llama.revamp.data.model.ModelInfo
import com.example.llama.revamp.data.repository.ModelRepository import com.example.llama.revamp.data.repository.ModelRepository
import com.example.llama.revamp.engine.InferenceService import com.example.llama.revamp.engine.InferenceService
import dagger.hilt.android.lifecycle.HiltViewModel import dagger.hilt.android.lifecycle.HiltViewModel
import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.flow.SharingStarted import kotlinx.coroutines.flow.SharingStarted
import kotlinx.coroutines.flow.StateFlow import kotlinx.coroutines.flow.StateFlow
import kotlinx.coroutines.flow.asStateFlow
import kotlinx.coroutines.flow.stateIn import kotlinx.coroutines.flow.stateIn
import kotlinx.coroutines.launch import kotlinx.coroutines.flow.update
import javax.inject.Inject import javax.inject.Inject
@HiltViewModel @HiltViewModel
class ModelSelectionViewModel @Inject constructor( class ModelSelectionViewModel @Inject constructor(
private val inferenceService: InferenceService, private val inferenceService: InferenceService,
private val modelRepository: ModelRepository modelRepository: ModelRepository
) : ViewModel() { ) : ViewModel() {
private val _preselectedModel = MutableStateFlow<ModelInfo?>(null)
val preselectedModel: StateFlow<ModelInfo?> = _preselectedModel.asStateFlow()
/** /**
* Available models for selection * Available models for selection
*/ */
@ -30,14 +35,24 @@ class ModelSelectionViewModel @Inject constructor(
) )
/** /**
* Select a model and update its last used timestamp * Pre-select a model
*/ */
fun selectModel(modelInfo: ModelInfo) { fun preselectModel(modelInfo: ModelInfo, preselected: Boolean) =
_preselectedModel.update { current ->
if (preselected) modelInfo else null
}
/**
* Confirm currently selected model
*/
fun confirmSelectedModel(modelInfo: ModelInfo) =
inferenceService.setCurrentModel(modelInfo) inferenceService.setCurrentModel(modelInfo)
viewModelScope.launch { /**
modelRepository.updateModelLastUsed(modelInfo.id) * Reset selected model to none (before navigating away)
} */
fun resetSelection() {
_preselectedModel.value = null
} }
companion object { companion object {