From 77edad5a019ee3a7aad988082926f24dc7034bca Mon Sep 17 00:00:00 2001 From: Han Yin Date: Sun, 20 Apr 2025 21:51:23 -0700 Subject: [PATCH] feature: support searching on Model Selection screen --- .../com/example/llama/revamp/MainActivity.kt | 61 ++++- .../llama/revamp/ui/components/AppScaffold.kt | 11 + .../revamp/ui/components/BottomAppBars.kt | 180 +++++++++++++- .../llama/revamp/ui/components/TopAppBars.kt | 7 +- .../revamp/ui/screens/ModelSelectionScreen.kt | 225 ++++++++++++++---- .../viewmodel/ModelSelectionViewModel.kt | 165 ++++++++++++- 6 files changed, 575 insertions(+), 74 deletions(-) diff --git a/examples/llama.android/app/src/main/java/com/example/llama/revamp/MainActivity.kt b/examples/llama.android/app/src/main/java/com/example/llama/revamp/MainActivity.kt index 1f17b7b7d2..8e104fbecb 100644 --- a/examples/llama.android/app/src/main/java/com/example/llama/revamp/MainActivity.kt +++ b/examples/llama.android/app/src/main/java/com/example/llama/revamp/MainActivity.kt @@ -118,16 +118,60 @@ fun AppContent( // Create scaffold's top & bottom bar configs based on current route val scaffoldConfig = when { // Model selection screen - currentRoute == AppDestinations.MODEL_SELECTION_ROUTE -> + currentRoute == AppDestinations.MODEL_SELECTION_ROUTE -> { + // Collect states for bottom bar + val isSearchActive by modelSelectionViewModel.isSearchActive.collectAsState() + val sortOrder by modelSelectionViewModel.sortOrder.collectAsState() + val showSortMenu by modelSelectionViewModel.showSortMenu.collectAsState() + val activeFilters by modelSelectionViewModel.activeFilters.collectAsState() + val showFilterMenu by modelSelectionViewModel.showFilterMenu.collectAsState() + val preselectedModel by modelSelectionViewModel.preselectedModel.collectAsState() + ScaffoldConfig( - topBarConfig = TopBarConfig.Default( - title = "Models", - navigationIcon = NavigationIcon.Menu { - modelSelectionViewModel.resetSelection() - openDrawer() - } + topBarConfig = + if (isSearchActive) TopBarConfig.None() + else TopBarConfig.Default( + title = "Select a Model", + navigationIcon = NavigationIcon.Menu { + modelSelectionViewModel.resetSelection() + openDrawer() + } + ), + bottomBarConfig = BottomBarConfig.ModelSelection( + search = BottomBarConfig.ModelSelection.SearchConfig( + isActive = isSearchActive, + onToggleSearch = modelSelectionViewModel::toggleSearchState, + textFieldState = modelSelectionViewModel.searchFieldState, + onSearch = { /* No-op for now */ } + ), + sorting = BottomBarConfig.ModelSelection.SortingConfig( + currentOrder = sortOrder, + isMenuVisible = showSortMenu, + toggleMenu = modelSelectionViewModel::toggleSortMenu, + selectOrder = { + modelSelectionViewModel.setSortOrder(it) + modelSelectionViewModel.toggleSortMenu(false) + } + ), + filtering = BottomBarConfig.ModelSelection.FilteringConfig( + isActive = activeFilters.any { it.value }, + filters = activeFilters, + onToggleFilter = modelSelectionViewModel::toggleFilter, + onClearFilters = modelSelectionViewModel::clearFilters, + isMenuVisible = showFilterMenu, + toggleMenu = modelSelectionViewModel::toggleFilterMenu + ), + runAction = BottomBarConfig.ModelSelection.RunActionConfig( + selectedModel = preselectedModel, + onRun = { model -> + modelSelectionViewModel.confirmSelectedModel(model) + navigationActions.navigateToModelLoading() + modelSelectionViewModel.toggleSearchState(false) + } + ) ) ) + } // Model loading screen currentRoute == AppDestinations.MODEL_LOADING_ROUTE -> @@ -297,9 +341,6 @@ fun AppContent( // Model Selection Screen composable(AppDestinations.MODEL_SELECTION_ROUTE) { ModelSelectionScreen( - onModelConfirmed = { modelInfo -> - navigationActions.navigateToModelLoading() - }, onManageModelsClicked = { navigationActions.navigateToModelsManagement() }, diff --git a/examples/llama.android/app/src/main/java/com/example/llama/revamp/ui/components/AppScaffold.kt b/examples/llama.android/app/src/main/java/com/example/llama/revamp/ui/components/AppScaffold.kt index 98c5f374ce..5f4b5f3214 100644 --- a/examples/llama.android/app/src/main/java/com/example/llama/revamp/ui/components/AppScaffold.kt +++ b/examples/llama.android/app/src/main/java/com/example/llama/revamp/ui/components/AppScaffold.kt @@ -40,6 +40,8 @@ fun AppScaffold( ) { val topBar: @Composable () -> Unit = { when (val topConfig = topBarconfig) { + is TopBarConfig.None -> {} + is TopBarConfig.Default -> DefaultTopBar( title = topBarconfig.title, onNavigateBack = topConfig.navigationIcon.backAction, @@ -66,6 +68,15 @@ fun AppScaffold( when (val config = bottomBarConfig) { is BottomBarConfig.None -> { /* No bottom bar */ } + is BottomBarConfig.ModelSelection -> { + ModelSelectionBottomBar( + search = config.search, + sorting = config.sorting, + filtering = config.filtering, + runAction = config.runAction + ) + } + is BottomBarConfig.ModelsManagement -> { ModelsManagementBottomBar( sorting = config.sorting, diff --git a/examples/llama.android/app/src/main/java/com/example/llama/revamp/ui/components/BottomAppBars.kt b/examples/llama.android/app/src/main/java/com/example/llama/revamp/ui/components/BottomAppBars.kt index e74885ebc1..698ab2eac7 100644 --- a/examples/llama.android/app/src/main/java/com/example/llama/revamp/ui/components/BottomAppBars.kt +++ b/examples/llama.android/app/src/main/java/com/example/llama/revamp/ui/components/BottomAppBars.kt @@ -1,8 +1,12 @@ package com.example.llama.revamp.ui.components +import androidx.compose.foundation.layout.padding import androidx.compose.foundation.layout.size +import androidx.compose.foundation.text.input.TextFieldState +import androidx.compose.foundation.text.input.clearText import androidx.compose.material.icons.Icons import androidx.compose.material.icons.automirrored.filled.Sort +import androidx.compose.material.icons.automirrored.outlined.Backspace import androidx.compose.material.icons.filled.Add import androidx.compose.material.icons.filled.Check import androidx.compose.material.icons.filled.ClearAll @@ -11,11 +15,17 @@ import androidx.compose.material.icons.filled.Delete import androidx.compose.material.icons.filled.DeleteSweep import androidx.compose.material.icons.filled.FilterAlt import androidx.compose.material.icons.filled.FolderOpen +import androidx.compose.material.icons.filled.PlayArrow +import androidx.compose.material.icons.filled.Search +import androidx.compose.material.icons.filled.SearchOff import androidx.compose.material.icons.filled.SelectAll +import androidx.compose.material.icons.outlined.FilterAlt import androidx.compose.material3.BottomAppBar +import androidx.compose.material3.Checkbox import androidx.compose.material3.DropdownMenu import androidx.compose.material3.DropdownMenuItem import androidx.compose.material3.FloatingActionButton +import androidx.compose.material3.HorizontalDivider import androidx.compose.material3.Icon import androidx.compose.material3.IconButton import androidx.compose.material3.MaterialTheme @@ -37,6 +47,41 @@ sealed class BottomBarConfig { object None : BottomBarConfig() + data class ModelSelection( + val search: SearchConfig, + val sorting: SortingConfig, + val filtering: FilteringConfig, + val runAction: RunActionConfig + ) : BottomBarConfig() { + data class SearchConfig( + val isActive: Boolean, + val onToggleSearch: (Boolean) -> Unit, + val textFieldState: TextFieldState, + val onSearch: (String) -> Unit, + ) + + data class SortingConfig( + val currentOrder: ModelSortOrder, + val isMenuVisible: Boolean, + val toggleMenu: (Boolean) -> Unit, + val selectOrder: (ModelSortOrder) -> Unit + ) + + data class FilteringConfig( + val isActive: Boolean, + val filters: Map, // Filter name -> enabled + val onToggleFilter: (String, Boolean) -> Unit, + val onClearFilters: () -> Unit, + val isMenuVisible: Boolean, + val toggleMenu: (Boolean) -> Unit + ) + + data class RunActionConfig( + val selectedModel: ModelInfo?, + val onRun: (ModelInfo) -> Unit + ) + } + data class ModelsManagement( val sorting: SortingConfig, val filtering: FilteringConfig, @@ -70,7 +115,140 @@ sealed class BottomBarConfig { ) } - // TODO-han.yin: add more bottom bar types here + // TODO-han.yin: add bottom bar config for Conversation Screen! +} + +@Composable +fun ModelSelectionBottomBar( + search: BottomBarConfig.ModelSelection.SearchConfig, + sorting: BottomBarConfig.ModelSelection.SortingConfig, + filtering: BottomBarConfig.ModelSelection.FilteringConfig, + runAction: BottomBarConfig.ModelSelection.RunActionConfig +) { + BottomAppBar( + actions = { + if (search.isActive) { + // Quit search action + IconButton(onClick = { search.onToggleSearch(false) }) { + Icon( + imageVector = Icons.Default.SearchOff, + contentDescription = "Quit search mode" + ) + } + + // Clear query action + IconButton(onClick = { search.textFieldState.clearText() }) { + Icon( + imageVector = Icons.AutoMirrored.Outlined.Backspace, + contentDescription = "Clear query text" + ) + } + } else { + // Enter search action + IconButton(onClick = { search.onToggleSearch(true) }) { + Icon( + imageVector = Icons.Default.Search, + contentDescription = "Search models" + ) + } + + // Sorting action + IconButton(onClick = { sorting.toggleMenu(true) }) { + Icon( + imageVector = Icons.AutoMirrored.Filled.Sort, + contentDescription = "Sort models" + ) + } + + // Sorting dropdown menu + DropdownMenu( + expanded = sorting.isMenuVisible, + onDismissRequest = { sorting.toggleMenu(false) } + ) { + val sortOptions = listOf( + Triple(ModelSortOrder.NAME_ASC, "Name (A-Z)", "Sort by name in ascending order"), + Triple(ModelSortOrder.NAME_DESC, "Name (Z-A)", "Sort by name in descending order"), + Triple(ModelSortOrder.SIZE_ASC, "Size (Smallest first)", "Sort by size in ascending order"), + Triple(ModelSortOrder.SIZE_DESC, "Size (Largest first)", "Sort by size in descending order"), + Triple(ModelSortOrder.LAST_USED, "Last used", "Sort by last used") + ) + + sortOptions.forEach { (order, label, description) -> + DropdownMenuItem( + text = { Text(label) }, + trailingIcon = { + if (sorting.currentOrder == order) + Icon( + imageVector = Icons.Default.Check, + contentDescription = "$description, selected" + ) + }, + onClick = { sorting.selectOrder(order) } + ) + } + } + + // Filter action + IconButton(onClick = { filtering.toggleMenu(true) }) { + Icon( + imageVector = + if (filtering.isActive) Icons.Default.FilterAlt + else Icons.Outlined.FilterAlt, + contentDescription = "Filter models" + ) + } + + // Filter dropdown menu + DropdownMenu( + expanded = filtering.isMenuVisible, + onDismissRequest = { filtering.toggleMenu(false) } + ) { + Text( + text = "Filter by", + style = MaterialTheme.typography.labelMedium, + modifier = Modifier.padding(horizontal = 16.dp, vertical = 8.dp) + ) + + filtering.filters.forEach { (filter, isEnabled) -> + DropdownMenuItem( + text = { Text(filter) }, + leadingIcon = { + Checkbox( + checked = isEnabled, + onCheckedChange = null + ) + }, + onClick = { filtering.onToggleFilter(filter, !isEnabled) } + ) + } + + HorizontalDivider() + + DropdownMenuItem( + text = { Text("Clear filters") }, + onClick = { + filtering.onClearFilters() + filtering.toggleMenu(false) + } + ) + } + } + }, + floatingActionButton = { + // Only show FAB if a model is selected + runAction.selectedModel?.let { model -> + FloatingActionButton( + onClick = { runAction.onRun(model) }, + containerColor = MaterialTheme.colorScheme.primary + ) { + Icon( + imageVector = Icons.Default.PlayArrow, + contentDescription = "Run with selected model" + ) + } + } + } + ) } @Composable diff --git a/examples/llama.android/app/src/main/java/com/example/llama/revamp/ui/components/TopAppBars.kt b/examples/llama.android/app/src/main/java/com/example/llama/revamp/ui/components/TopAppBars.kt index faac83e94e..5a3683fd92 100644 --- a/examples/llama.android/app/src/main/java/com/example/llama/revamp/ui/components/TopAppBars.kt +++ b/examples/llama.android/app/src/main/java/com/example/llama/revamp/ui/components/TopAppBars.kt @@ -14,7 +14,6 @@ import androidx.compose.material.icons.filled.WarningAmber import androidx.compose.material3.ExperimentalMaterial3Api import androidx.compose.material3.Icon import androidx.compose.material3.IconButton -import androidx.compose.material3.Label import androidx.compose.material3.MaterialTheme import androidx.compose.material3.Text import androidx.compose.material3.TopAppBar @@ -36,6 +35,12 @@ sealed class TopBarConfig { abstract val title: String abstract val navigationIcon: NavigationIcon + // No top bar at all + data class None( + override val title: String = "", + override val navigationIcon: NavigationIcon = NavigationIcon.None + ) : TopBarConfig() + // Default/simple top bar with only a navigation icon data class Default( override val title: String, diff --git a/examples/llama.android/app/src/main/java/com/example/llama/revamp/ui/screens/ModelSelectionScreen.kt b/examples/llama.android/app/src/main/java/com/example/llama/revamp/ui/screens/ModelSelectionScreen.kt index 875e772d50..7dad1e27f3 100644 --- a/examples/llama.android/app/src/main/java/com/example/llama/revamp/ui/screens/ModelSelectionScreen.kt +++ b/examples/llama.android/app/src/main/java/com/example/llama/revamp/ui/screens/ModelSelectionScreen.kt @@ -1,71 +1,174 @@ package com.example.llama.revamp.ui.screens import androidx.activity.compose.BackHandler -import androidx.compose.animation.AnimatedVisibility -import androidx.compose.animation.fadeIn -import androidx.compose.animation.fadeOut -import androidx.compose.animation.scaleIn -import androidx.compose.animation.scaleOut import androidx.compose.foundation.layout.Arrangement import androidx.compose.foundation.layout.Box import androidx.compose.foundation.layout.Column +import androidx.compose.foundation.layout.PaddingValues import androidx.compose.foundation.layout.Spacer import androidx.compose.foundation.layout.fillMaxSize +import androidx.compose.foundation.layout.fillMaxWidth import androidx.compose.foundation.layout.height import androidx.compose.foundation.layout.padding import androidx.compose.foundation.layout.size import androidx.compose.foundation.layout.width import androidx.compose.foundation.lazy.LazyColumn import androidx.compose.foundation.lazy.items +import androidx.compose.foundation.text.input.clearText import androidx.compose.material.icons.Icons import androidx.compose.material.icons.filled.Add import androidx.compose.material.icons.filled.FolderOpen -import androidx.compose.material.icons.filled.PlayArrow +import androidx.compose.material.icons.filled.MoreVert +import androidx.compose.material.icons.filled.Search +import androidx.compose.material.icons.filled.SearchOff import androidx.compose.material3.Button +import androidx.compose.material3.DockedSearchBar import androidx.compose.material3.ExperimentalMaterial3Api -import androidx.compose.material3.FloatingActionButton import androidx.compose.material3.Icon import androidx.compose.material3.MaterialTheme +import androidx.compose.material3.SearchBarDefaults import androidx.compose.material3.Text import androidx.compose.runtime.Composable +import androidx.compose.runtime.LaunchedEffect import androidx.compose.runtime.collectAsState +import androidx.compose.runtime.derivedStateOf import androidx.compose.runtime.getValue +import androidx.compose.runtime.remember import androidx.compose.ui.Alignment import androidx.compose.ui.Modifier +import androidx.compose.ui.focus.FocusRequester +import androidx.compose.ui.focus.focusRequester +import androidx.compose.ui.platform.LocalSoftwareKeyboardController import androidx.compose.ui.text.style.TextAlign import androidx.compose.ui.unit.dp -import com.example.llama.revamp.data.model.ModelInfo import com.example.llama.revamp.ui.components.ModelCardFullExpandable import com.example.llama.revamp.viewmodel.ModelSelectionViewModel @OptIn(ExperimentalMaterial3Api::class) @Composable fun ModelSelectionScreen( - onModelConfirmed: (ModelInfo) -> Unit, onManageModelsClicked: () -> Unit, viewModel: ModelSelectionViewModel, ) { - val models by viewModel.availableModels.collectAsState() + val filteredModels by viewModel.filteredModels.collectAsState() val preselectedModel by viewModel.preselectedModel.collectAsState() - // Handle back button press - BackHandler(preselectedModel != null) { - viewModel.onBackPressed() + val textFieldState = viewModel.searchFieldState + val isSearchActive by viewModel.isSearchActive.collectAsState() + val searchQuery by remember(textFieldState) { + derivedStateOf { textFieldState.text.toString() } + } + val queryResults by viewModel.queryResults.collectAsState() + + val activeFilters by viewModel.activeFilters.collectAsState() + val activeFiltersCount by remember(activeFilters) { + derivedStateOf { activeFilters.count { it.value } } } - Box(modifier = Modifier.fillMaxSize()) { - Column( - modifier = Modifier - .fillMaxSize() - .padding(horizontal = 16.dp) - ) { - if (models.isEmpty()) { - EmptyModelsView(onManageModelsClicked) + + val focusRequester = remember { FocusRequester() } + val keyboardController = LocalSoftwareKeyboardController.current + + val toggleSearchFocusAndIme: (Boolean) -> Unit = { show -> + if (show) { + focusRequester.requestFocus() + keyboardController?.show() + } else { + focusRequester.freeFocus() + keyboardController?.hide() + } + } + + // Handle back button press + BackHandler(preselectedModel != null || isSearchActive) { + if (isSearchActive) { + viewModel.toggleSearchState(false) + } else { + viewModel.onBackPressed() + } + } + + LaunchedEffect (isSearchActive) { + if (isSearchActive) { + toggleSearchFocusAndIme(true) + } + } + + Box( + modifier = Modifier.fillMaxSize() + ) { + if (isSearchActive) { + DockedSearchBar( + modifier = Modifier.align(Alignment.TopCenter), + inputField = { + SearchBarDefaults.InputField( + modifier = Modifier.focusRequester(focusRequester), + query = textFieldState.text.toString(), + onQueryChange = { textFieldState.edit { replace(0, length, it) } }, + onSearch = {}, + expanded = true, + onExpandedChange = { expanded -> + viewModel.toggleSearchState(expanded) + textFieldState.clearText() + }, + leadingIcon = { Icon(Icons.Default.Search, contentDescription = null) }, + trailingIcon = { Icon(Icons.Default.MoreVert, contentDescription = null) }, + placeholder = { Text("Type to search your models") } + ) + }, + expanded = true, + onExpandedChange = { + viewModel.toggleSearchState(it) + } + ) { + if (queryResults.isEmpty()) { + if (searchQuery.isNotBlank()) { + // Show "no results" message + EmptySearchResultsView( + onClearSearch = { + textFieldState.clearText() + toggleSearchFocusAndIme(true) + } + ) + } + } else { + LazyColumn( + Modifier.fillMaxSize(), + verticalArrangement = Arrangement.spacedBy(12.dp), + contentPadding = PaddingValues(vertical = 12.dp, horizontal = 16.dp), + ) { + items(items = queryResults, key = { it.id }) { model -> + ModelCardFullExpandable( + model = model, + isSelected = if (model == preselectedModel) true else null, + onSelected = { selected -> + if (selected) { + toggleSearchFocusAndIme(false) + } else { + viewModel.resetSelection() + toggleSearchFocusAndIme(true) + } + }, + isExpanded = model == preselectedModel, + onExpanded = { expanded -> + viewModel.preselectModel(model, expanded) + toggleSearchFocusAndIme(!expanded) + } + ) + } + } + } + } + } else { + if (filteredModels.isEmpty()) { + EmptyModelsView(activeFiltersCount, onManageModelsClicked) } else { LazyColumn( - verticalArrangement = Arrangement.spacedBy(12.dp) + Modifier.fillMaxSize(), // .padding(horizontal = 16.dp), + verticalArrangement = Arrangement.spacedBy(12.dp), + contentPadding = PaddingValues(vertical = 12.dp, horizontal = 16.dp), ) { - items(items = models, key = { it.id }) { model -> + items(items = filteredModels, key = { it.id }) { model -> ModelCardFullExpandable( model = model, isSelected = if (model == preselectedModel) true else null, @@ -81,38 +184,56 @@ fun ModelSelectionScreen( } } } + } +} - // Show FAB if a model is selected - AnimatedVisibility( - modifier = Modifier.padding(16.dp).align(Alignment.BottomEnd), - visible = preselectedModel != null, - enter = scaleIn() + fadeIn(), - exit = scaleOut() + fadeOut() - ) { - FloatingActionButton( - onClick = { - preselectedModel?.let { - viewModel.confirmSelectedModel(it) - onModelConfirmed(it) - } - }, - containerColor = MaterialTheme.colorScheme.primary - ) { - Icon( - imageVector = Icons.Default.PlayArrow, - contentDescription = "Start with selected model" - ) - } +@Composable +private fun EmptySearchResultsView( + onClearSearch: () -> Unit +) { + Column( + modifier = Modifier.fillMaxWidth().padding(horizontal = 16.dp, vertical = 32.dp), + horizontalAlignment = Alignment.CenterHorizontally, + verticalArrangement = Arrangement.Center + ) { + Icon( + imageVector = Icons.Default.SearchOff, + contentDescription = null, + modifier = Modifier.size(64.dp), + tint = MaterialTheme.colorScheme.primary.copy(alpha = 0.6f) + ) + + Spacer(modifier = Modifier.height(16.dp)) + + Text( + text = "No matching models found", + style = MaterialTheme.typography.headlineSmall + ) + + Spacer(modifier = Modifier.height(8.dp)) + + Text( + text = "Try a different search term", + style = MaterialTheme.typography.bodyLarge, + textAlign = TextAlign.Center, + color = MaterialTheme.colorScheme.onSurfaceVariant + ) + + Spacer(modifier = Modifier.height(24.dp)) + + Button(onClick = onClearSearch) { + Text("Clear Search") } } } @Composable -private fun EmptyModelsView(onManageModelsClicked: () -> Unit) { +private fun EmptyModelsView( + activeFiltersCount: Int, + onManageModelsClicked: () -> Unit +) { Column( - modifier = Modifier - .fillMaxSize() - .padding(16.dp), + modifier = Modifier.fillMaxSize().padding(16.dp), horizontalAlignment = Alignment.CenterHorizontally, verticalArrangement = Arrangement.Center ) { @@ -133,8 +254,12 @@ private fun EmptyModelsView(onManageModelsClicked: () -> Unit) { Spacer(modifier = Modifier.height(8.dp)) Text( - text = "Add models to get started with local LLM inference", - style = MaterialTheme.typography.bodyMedium, + text = when (activeFiltersCount) { + 0 -> "Import some models to get started!" + 1 -> "No models match the selected filter" + else -> "No models match the selected filters" + }, + style = MaterialTheme.typography.bodyLarge, textAlign = TextAlign.Center, color = MaterialTheme.colorScheme.onSurfaceVariant ) diff --git a/examples/llama.android/app/src/main/java/com/example/llama/revamp/viewmodel/ModelSelectionViewModel.kt b/examples/llama.android/app/src/main/java/com/example/llama/revamp/viewmodel/ModelSelectionViewModel.kt index 352db8502e..62d769fc21 100644 --- a/examples/llama.android/app/src/main/java/com/example/llama/revamp/viewmodel/ModelSelectionViewModel.kt +++ b/examples/llama.android/app/src/main/java/com/example/llama/revamp/viewmodel/ModelSelectionViewModel.kt @@ -1,38 +1,179 @@ package com.example.llama.revamp.viewmodel +import androidx.compose.foundation.text.input.TextFieldState +import androidx.compose.foundation.text.input.clearText +import androidx.compose.runtime.snapshotFlow import androidx.lifecycle.ViewModel import androidx.lifecycle.viewModelScope import com.example.llama.revamp.data.model.ModelInfo import com.example.llama.revamp.data.repository.ModelRepository import com.example.llama.revamp.engine.InferenceService import dagger.hilt.android.lifecycle.HiltViewModel +import kotlinx.coroutines.FlowPreview import kotlinx.coroutines.flow.MutableStateFlow -import kotlinx.coroutines.flow.SharingStarted import kotlinx.coroutines.flow.StateFlow import kotlinx.coroutines.flow.asStateFlow -import kotlinx.coroutines.flow.stateIn +import kotlinx.coroutines.flow.collectLatest +import kotlinx.coroutines.flow.combine +import kotlinx.coroutines.flow.debounce import kotlinx.coroutines.flow.update +import kotlinx.coroutines.launch import javax.inject.Inject +@OptIn(FlowPreview::class) @HiltViewModel class ModelSelectionViewModel @Inject constructor( private val inferenceService: InferenceService, modelRepository: ModelRepository ) : ViewModel() { + // UI state: search mode + private val _isSearchActive = MutableStateFlow(false) + val isSearchActive: StateFlow = _isSearchActive.asStateFlow() + + fun toggleSearchState(active: Boolean) { + _isSearchActive.value = active + if (active) { + resetSelection() + } else { + searchFieldState.clearText() + } + } + + val searchFieldState = TextFieldState() + + // UI state: sort menu + private val _sortOrder = MutableStateFlow(ModelSortOrder.LAST_USED) + val sortOrder: StateFlow = _sortOrder.asStateFlow() + + fun setSortOrder(order: ModelSortOrder) { + _sortOrder.value = order + } + + private val _showSortMenu = MutableStateFlow(false) + val showSortMenu: StateFlow = _showSortMenu.asStateFlow() + + fun toggleSortMenu(visible: Boolean) { + _showSortMenu.value = visible + } + + // UI state: filters + // TODO-han.yin: Refactor this into Enums! + private val _activeFilters = MutableStateFlow>(mapOf( + "Has context length" to false, + "Support system prompt" to false, + "7B models" to false, + "13B models" to false, + "70B models" to false + )) + val activeFilters: StateFlow> = _activeFilters.asStateFlow() + + fun toggleFilter(filter: String, enabled: Boolean) { + _activeFilters.update { current -> + current.toMutableMap().apply { + this[filter] = enabled + } + } + } + + fun clearFilters() { + _activeFilters.update { current -> + current.mapValues { false } + } + } + + private val _showFilterMenu = MutableStateFlow(false) + val showFilterMenu: StateFlow = _showFilterMenu.asStateFlow() + + fun toggleFilterMenu(visible: Boolean) { + _showFilterMenu.value = visible + } + + // Data: filtered & sorted models + private val _filteredModels = MutableStateFlow>(emptyList()) + val filteredModels: StateFlow> = _filteredModels.asStateFlow() + + // Data: queried models + private val _queryResults = MutableStateFlow>(emptyList()) + val queryResults: StateFlow> = _queryResults.asStateFlow() + + // Data: pre-selected model in expansion mode private val _preselectedModel = MutableStateFlow(null) val preselectedModel: StateFlow = _preselectedModel.asStateFlow() - /** - * Available models for selection - */ - val availableModels: StateFlow> = modelRepository.getModels() - .stateIn( - scope = viewModelScope, - started = SharingStarted.WhileSubscribed(SUBSCRIPTION_TIMEOUT_MS), - initialValue = emptyList() - ) + init { + viewModelScope.launch { + combine( + modelRepository.getModels(), + _activeFilters, + _sortOrder, + ) { models, filters, sortOrder -> + models.filterBy(filters).sortByOrder(sortOrder) + }.collect { + _filteredModels.value = it + } + } + + viewModelScope.launch { + combine( + modelRepository.getModels(), + snapshotFlow { searchFieldState.text }.debounce(QUERY_DEBOUNCE_TIMEOUT_MS) + ) { models, query -> + if (query.isBlank()) { + emptyList() + } else { + models.queryBy(query.toString()).sortedBy { it.dateLastUsed ?: it.dateAdded } + } + }.collectLatest { + _queryResults.value = it + } + } + } + + private fun List.queryBy(query: String): List { + if (query.isBlank()) return this + + return filter { model -> + model.name.contains(query, ignoreCase = true) || + model.metadata.fullModelName?.contains(query, ignoreCase = true) == true || + model.metadata.additional?.tags?.any { it.contains(query, ignoreCase = true) } == true || + model.metadata.additional?.languages?.any { it.contains(query, ignoreCase = true) } == true || + model.metadata.architecture?.architecture?.contains(query, ignoreCase = true) == true + } + } + + // TODO-han.yin: Refactor this into Enums! + private fun List.filterBy(filters: Map): List { + val activeFilters = filters.filterValues { it } + if (activeFilters.isEmpty()) return this + + return filter { model -> + activeFilters.all { (filter, _) -> + when (filter) { + "Has context length" -> model.metadata.dimensions?.contextLength != null + "Support system prompt" -> true + "7B models" -> model.metadata.basic.sizeLabel?.contains("7B") == true + "13B models" -> model.metadata.basic.sizeLabel?.contains("13B") == true + "70B models" -> model.metadata.basic.sizeLabel?.contains("70B") == true + else -> true + } + } + } + } + + private fun List.sortByOrder(order: ModelSortOrder): List { + return when (order) { + ModelSortOrder.NAME_ASC -> sortedBy { it.name } + ModelSortOrder.NAME_DESC -> sortedByDescending { it.name } + ModelSortOrder.SIZE_ASC -> sortedBy { it.sizeInBytes } + ModelSortOrder.SIZE_DESC -> sortedByDescending { it.sizeInBytes } + ModelSortOrder.LAST_USED -> sortedWith( + compareByDescending { it.dateLastUsed } + .thenBy { it.name } + ) + } + } /** * Pre-select a model @@ -67,6 +208,6 @@ class ModelSelectionViewModel @Inject constructor( companion object { private val TAG = ModelSelectionViewModel::class.java.simpleName - private const val SUBSCRIPTION_TIMEOUT_MS = 5000L + private const val QUERY_DEBOUNCE_TIMEOUT_MS = 500L } }