From 05c620cc522696888b9a86c938a0bb8e66fff567 Mon Sep 17 00:00:00 2001
From: Han Yin
Date: Mon, 21 Apr 2025 14:08:26 -0700
Subject: [PATCH] data: move Model related actions (query, filter, sort) into ModelInfo file

---
 .../llama/revamp/data/model/ModelInfo.kt      | 189 ++++++++++++++++++
 .../llama/revamp/ui/scaffold/BottomAppBars.kt |   9 +-
 .../viewmodel/ModelSelectionViewModel.kt      |  66 +-----
 .../viewmodel/ModelsManagementViewModel.kt    |   9 +-
 4 files changed, 206 insertions(+), 67 deletions(-)

diff --git a/examples/llama.android/app/src/main/java/com/example/llama/revamp/data/model/ModelInfo.kt b/examples/llama.android/app/src/main/java/com/example/llama/revamp/data/model/ModelInfo.kt
index b4c58ba2a2..123bc06272 100644
--- a/examples/llama.android/app/src/main/java/com/example/llama/revamp/data/model/ModelInfo.kt
+++ b/examples/llama.android/app/src/main/java/com/example/llama/revamp/data/model/ModelInfo.kt
@@ -8,6 +8,17 @@ import com.example.llama.revamp.util.formatFileByteSize
 
 /**
  * Data class containing information about an LLM model.
+ *
+ * This class represents a language model with its associated metadata, including
+ * file information, architecture details, and usage statistics.
+ *
+ * @property id Unique identifier for the model
+ * @property name Display name of the model
+ * @property path File path to the model on device storage
+ * @property sizeInBytes Size of the model file in bytes
+ * @property metadata Structured metadata extracted from the GGUF file
+ * @property dateAdded Timestamp when the model was added to the app
+ * @property dateLastUsed Timestamp when the model was last used, or null if never used
  */
 data class ModelInfo(
     val id: String,
@@ -18,25 +29,203 @@ data class ModelInfo(
     val dateAdded: Long,
     val dateLastUsed: Long? = null,
 ) {
+    /**
+     * Full model name including version and parameter size if available, otherwise falls back to the file name.
+     */
     val formattedFullName: String
         get() = metadata.fullModelName ?: name
 
+    /**
+     * Human-readable file size with appropriate unit (KB, MB, GB).
+     */
     val formattedFileSize: String
         get() = formatFileByteSize(sizeInBytes)
 
+    /**
+     * Architecture name of the model (e.g., "llama", "mistral"), or "-" if unavailable.
+     */
     val formattedArchitecture: String
         get() = metadata.architecture?.architecture ?: "-"
 
+    /**
+     * Model parameter size with suffix (e.g., "7B", "13B"), or "-" if unavailable.
+     */
     val formattedParamSize: String
         get() = metadata.basic.sizeLabel ?: "-"
 
+    /**
+     * Human-readable context length (e.g., "4K", "8K tokens"), or "-" if unavailable.
+     */
     val formattedContextLength: String
         get() = metadata.dimensions?.contextLength?.let { formatContextLength(it) } ?: "-"
 
+    /**
+     * Quantization format of the model (e.g., "Q4_0", "Q5_K_M"), or "-" if unavailable.
+     */
     val formattedQuantization: String
         get() = metadata.architecture?.fileType?.let { FileType.fromCode(it).label } ?: "-"
 
+    /**
+     * Tags associated with the model, or null if none are defined.
+     */
     val tags: List<String>? = metadata.additional?.tags?.takeIf { it.isNotEmpty() }
 
+    /**
+     * Languages supported by the model, or null if none are defined.
+     */
     val languages: List<String>? = metadata.additional?.languages?.takeIf { it.isNotEmpty() }
 }
+
+/**
+ * Filters models by search query.
+ *
+ * Searches through model names, tags, languages, and architecture.
+ * Returns the original list if the query is blank.
+ *
+ * @param query The search term to filter by
+ * @return List of models matching the search criteria
+ */
+fun List<ModelInfo>.queryBy(query: String): List<ModelInfo> {
+    if (query.isBlank()) return this
+
+    return filter { model ->
+        model.name.contains(query, ignoreCase = true) ||
+            model.metadata.fullModelName?.contains(query, ignoreCase = true) == true ||
+            model.metadata.additional?.tags?.any { it.contains(query, ignoreCase = true) } == true ||
+            model.metadata.additional?.languages?.any { it.contains(query, ignoreCase = true) } == true ||
+            model.metadata.architecture?.architecture?.contains(query, ignoreCase = true) == true
+    }
+}
+
+/**
+ * Sorting options for model lists.
+ */
+enum class ModelSortOrder {
+    NAME_ASC,
+    NAME_DESC,
+    SIZE_ASC,
+    SIZE_DESC,
+    LAST_USED
+}
+
+/**
+ * Sorts models according to the specified order.
+ *
+ * @param order The sort order to apply
+ * @return Sorted list of models
+ */
+fun List<ModelInfo>.sortByOrder(order: ModelSortOrder): List<ModelInfo> {
+    return when (order) {
+        ModelSortOrder.NAME_ASC -> sortedBy { it.name }
+        ModelSortOrder.NAME_DESC -> sortedByDescending { it.name }
+        ModelSortOrder.SIZE_ASC -> sortedBy { it.sizeInBytes }
+        ModelSortOrder.SIZE_DESC -> sortedByDescending { it.sizeInBytes }
+        ModelSortOrder.LAST_USED -> sortedWith(
+            compareByDescending<ModelInfo> { it.dateLastUsed }
+                .thenBy { it.name }
+        )
+    }
+}
+
+/**
+ * Filters for categorizing and filtering models.
+ *
+ * @property displayName Human-readable name shown in the UI
+ * @property predicate Function that determines if a model matches this filter
+ */
+enum class ModelFilter(val displayName: String, val predicate: (ModelInfo) -> Boolean) {
+    // Parameter size filters; buckets are contiguous so fractional labels such as "3.8B" still match one
+    TINY_PARAMS("Tiny (<1B parameters)", {
+        it.metadata.basic.sizeLabel?.let { size ->
+            size.contains("M") || (size.contains("B") && size.replace("B", "").toFloatOrNull()?.let { n -> n < 1f } == true)
+        } == true
+    }),
+    SMALL_PARAMS("Small (1-3B parameters)", {
+        it.metadata.basic.sizeLabel?.let { size ->
+            size.contains("B") && size.replace("B", "").toFloatOrNull()?.let { n -> n >= 1f && n < 4f } == true
+        } == true
+    }),
+    MEDIUM_PARAMS("Medium (4-7B parameters)", {
+        it.metadata.basic.sizeLabel?.let { size ->
+            size.contains("B") && size.replace("B", "").toFloatOrNull()?.let { n -> n >= 4f && n < 8f } == true
+        } == true
+    }),
+    LARGE_PARAMS("Large (8-13B parameters)", {
+        it.metadata.basic.sizeLabel?.let { size ->
+            size.contains("B") && size.replace("B", "").toFloatOrNull()?.let { n -> n >= 8f && n <= 13f } == true
+        } == true
+    }),
+    XLARGE_PARAMS("X-Large (>13B parameters)", {
+        it.metadata.basic.sizeLabel?.let { size ->
+            size.contains("B") && size.replace("B", "").toFloatOrNull()?.let { n -> n > 13f } == true
+        } == true
+    }),
+
+    // Context length filters
+    TINY_CONTEXT("Tiny context (<4K)", {
+        it.metadata.dimensions?.contextLength?.let { it < 4096 } == true
+    }),
+    SHORT_CONTEXT("Short context (4-8K)", {
+        it.metadata.dimensions?.contextLength?.let { it >= 4096 && it <= 8192 } == true
+    }),
+    MEDIUM_CONTEXT("Medium context (8-32K)", {
+        it.metadata.dimensions?.contextLength?.let { it > 8192 && it <= 32768 } == true
+    }),
+    LONG_CONTEXT("Long context (32-128K)", {
+        it.metadata.dimensions?.contextLength?.let { it > 32768 && it <= 131072 } == true
+    }),
+    XLARGE_CONTEXT("Extended context (>128K)", {
+        it.metadata.dimensions?.contextLength?.let { it > 131072 } == true
+    }),
+
+    // Quantization filters
+    INT2_QUANT("2-bit quantization", {
+        it.formattedQuantization.let { it.contains("Q2") || it.contains("IQ2") }
+    }),
+    INT3_QUANT("3-bit quantization", {
+        it.formattedQuantization.let { it.contains("Q3") || it.contains("IQ3") }
+    }),
+    INT4_QUANT("4-bit quantization", {
+        it.formattedQuantization.let { it.contains("Q4") || it.contains("IQ4") }
+    }),
+
+    // Special features
+    MULTILINGUAL("Multilingual", {
+        it.languages?.let { languages ->
+            languages.size > 1 || languages.any { it.contains("multi", ignoreCase = true) }
+        } == true
+    }),
+    HAS_TAGS("Has tags", {
+        !it.tags.isNullOrEmpty()
+    });
+
+    companion object {
+        // Group filters by category for UI; every declared entry must appear in exactly one group
+        private val PARAMETER_FILTERS = listOf(TINY_PARAMS, SMALL_PARAMS, MEDIUM_PARAMS, LARGE_PARAMS, XLARGE_PARAMS)
+        private val CONTEXT_FILTERS = listOf(TINY_CONTEXT, SHORT_CONTEXT, MEDIUM_CONTEXT, LONG_CONTEXT, XLARGE_CONTEXT)
+        private val QUANTIZATION_FILTERS = listOf(INT2_QUANT, INT3_QUANT, INT4_QUANT)
+        private val FEATURE_FILTERS = listOf(MULTILINGUAL, HAS_TAGS)
+
+        // All filters flattened, in UI display order
+        val ALL_FILTERS = PARAMETER_FILTERS + CONTEXT_FILTERS + QUANTIZATION_FILTERS + FEATURE_FILTERS
+    }
+}
+
+/**
+ * Filters models based on a set of active filters.
+ *
+ * @param filters Map of filters to their enabled state
+ * @return List of models that match all active filters
+ */
+fun List<ModelInfo>.filterBy(filters: Map<ModelFilter, Boolean>): List<ModelInfo> {
+    val activeFilters = filters.filterValues { it }
+    return if (activeFilters.isEmpty()) {
+        this
+    } else {
+        filter { model ->
+            activeFilters.keys.all { filter ->
+                filter.predicate(model)
+            }
+        }
+    }
+}
diff --git a/examples/llama.android/app/src/main/java/com/example/llama/revamp/ui/scaffold/BottomAppBars.kt b/examples/llama.android/app/src/main/java/com/example/llama/revamp/ui/scaffold/BottomAppBars.kt
index b611277a12..e6e639c06c 100644
--- a/examples/llama.android/app/src/main/java/com/example/llama/revamp/ui/scaffold/BottomAppBars.kt
+++ b/examples/llama.android/app/src/main/java/com/example/llama/revamp/ui/scaffold/BottomAppBars.kt
@@ -36,8 +36,9 @@ import androidx.compose.ui.graphics.Color
 import androidx.compose.ui.res.painterResource
 import androidx.compose.ui.unit.dp
 import com.example.llama.R
+import com.example.llama.revamp.data.model.ModelFilter
 import com.example.llama.revamp.data.model.ModelInfo
-import com.example.llama.revamp.viewmodel.ModelSortOrder
+import com.example.llama.revamp.data.model.ModelSortOrder
 
 /**
  * [BottomAppBar] configurations
@@ -68,8 +69,8 @@ sealed class BottomBarConfig {
 
         data class FilteringConfig(
             val isActive: Boolean,
-            val filters: Map<String, Boolean>, // Filter name -> enabled
-            val onToggleFilter: (String, Boolean) -> Unit,
+            val filters: Map<ModelFilter, Boolean>,
+            val onToggleFilter: (ModelFilter, Boolean) -> Unit,
             val onClearFilters: () -> Unit,
             val isMenuVisible: Boolean,
             val toggleMenu: (Boolean) -> Unit
@@ -210,7 +211,7 @@ fun ModelSelectionBottomBar(
 
                     filtering.filters.forEach { (filter, isEnabled) ->
                         DropdownMenuItem(
-                            text = { Text(filter) },
+                            text = { Text(filter.displayName) },
                             leadingIcon = {
                                 Checkbox(
                                     checked = isEnabled,
diff --git a/examples/llama.android/app/src/main/java/com/example/llama/revamp/viewmodel/ModelSelectionViewModel.kt b/examples/llama.android/app/src/main/java/com/example/llama/revamp/viewmodel/ModelSelectionViewModel.kt
index 62d769fc21..28438d1d76 100644
--- a/examples/llama.android/app/src/main/java/com/example/llama/revamp/viewmodel/ModelSelectionViewModel.kt
+++ b/examples/llama.android/app/src/main/java/com/example/llama/revamp/viewmodel/ModelSelectionViewModel.kt
@@ -5,7 +5,12 @@ import androidx.compose.foundation.text.input.clearText
 import androidx.compose.runtime.snapshotFlow
 import androidx.lifecycle.ViewModel
 import androidx.lifecycle.viewModelScope
+import com.example.llama.revamp.data.model.ModelFilter
 import com.example.llama.revamp.data.model.ModelInfo
+import com.example.llama.revamp.data.model.ModelSortOrder
+import com.example.llama.revamp.data.model.filterBy
+import com.example.llama.revamp.data.model.queryBy
+import com.example.llama.revamp.data.model.sortByOrder
 import com.example.llama.revamp.data.repository.ModelRepository
 import com.example.llama.revamp.engine.InferenceService
 import dagger.hilt.android.lifecycle.HiltViewModel
@@ -59,17 +64,12 @@ class ModelSelectionViewModel @Inject constructor(
     }
 
     // UI state: filters
-    // TODO-han.yin: Refactor this into Enums!
-    private val _activeFilters = MutableStateFlow<Map<String, Boolean>>(mapOf(
-        "Has context length" to false,
-        "Support system prompt" to false,
-        "7B models" to false,
-        "13B models" to false,
-        "70B models" to false
-    ))
-    val activeFilters: StateFlow<Map<String, Boolean>> = _activeFilters.asStateFlow()
+    private val _activeFilters = MutableStateFlow<Map<ModelFilter, Boolean>>(
+        ModelFilter.ALL_FILTERS.associateWith { false }
+    )
+    val activeFilters: StateFlow<Map<ModelFilter, Boolean>> = _activeFilters.asStateFlow()
 
-    fun toggleFilter(filter: String, enabled: Boolean) {
+    fun toggleFilter(filter: ModelFilter, enabled: Boolean) {
         _activeFilters.update { current ->
             current.toMutableMap().apply {
                 this[filter] = enabled
@@ -110,7 +110,7 @@ class ModelSelectionViewModel @Inject constructor(
                 _sortOrder,
             ) { models, filters, sortOrder ->
                 models.filterBy(filters).sortByOrder(sortOrder)
-            }.collect {
+            }.collectLatest {
                 _filteredModels.value = it
             }
         }
@@ -131,50 +131,6 @@ class ModelSelectionViewModel @Inject constructor(
         }
     }
 
-    private fun List<ModelInfo>.queryBy(query: String): List<ModelInfo> {
-        if (query.isBlank()) return this
-
-        return filter { model ->
-            model.name.contains(query, ignoreCase = true) ||
-                model.metadata.fullModelName?.contains(query, ignoreCase = true) == true ||
-                model.metadata.additional?.tags?.any { it.contains(query, ignoreCase = true) } == true ||
-                model.metadata.additional?.languages?.any { it.contains(query, ignoreCase = true) } == true ||
-                model.metadata.architecture?.architecture?.contains(query, ignoreCase = true) == true
-        }
-    }
-
-    // TODO-han.yin: Refactor this into Enums!
-    private fun List<ModelInfo>.filterBy(filters: Map<String, Boolean>): List<ModelInfo> {
-        val activeFilters = filters.filterValues { it }
-        if (activeFilters.isEmpty()) return this
-
-        return filter { model ->
-            activeFilters.all { (filter, _) ->
-                when (filter) {
-                    "Has context length" -> model.metadata.dimensions?.contextLength != null
-                    "Support system prompt" -> true
-                    "7B models" -> model.metadata.basic.sizeLabel?.contains("7B") == true
-                    "13B models" -> model.metadata.basic.sizeLabel?.contains("13B") == true
-                    "70B models" -> model.metadata.basic.sizeLabel?.contains("70B") == true
-                    else -> true
-                }
-            }
-        }
-    }
-
-    private fun List<ModelInfo>.sortByOrder(order: ModelSortOrder): List<ModelInfo> {
-        return when (order) {
-            ModelSortOrder.NAME_ASC -> sortedBy { it.name }
-            ModelSortOrder.NAME_DESC -> sortedByDescending { it.name }
-            ModelSortOrder.SIZE_ASC -> sortedBy { it.sizeInBytes }
-            ModelSortOrder.SIZE_DESC -> sortedByDescending { it.sizeInBytes }
-            ModelSortOrder.LAST_USED -> sortedWith(
-                compareByDescending<ModelInfo> { it.dateLastUsed }
-                    .thenBy { it.name }
-            )
-        }
-    }
-
     /**
      * Pre-select a model
      */
diff --git a/examples/llama.android/app/src/main/java/com/example/llama/revamp/viewmodel/ModelsManagementViewModel.kt b/examples/llama.android/app/src/main/java/com/example/llama/revamp/viewmodel/ModelsManagementViewModel.kt
index e8dcc46527..5d3da860d0 100644
--- a/examples/llama.android/app/src/main/java/com/example/llama/revamp/viewmodel/ModelsManagementViewModel.kt
+++ b/examples/llama.android/app/src/main/java/com/example/llama/revamp/viewmodel/ModelsManagementViewModel.kt
@@ -6,6 +6,7 @@ import android.util.Log
 import androidx.lifecycle.ViewModel
 import androidx.lifecycle.viewModelScope
 import com.example.llama.revamp.data.model.ModelInfo
+import com.example.llama.revamp.data.model.ModelSortOrder
 import com.example.llama.revamp.data.repository.InsufficientStorageException
 import com.example.llama.revamp.data.repository.ModelRepository
 import com.example.llama.revamp.util.getFileNameFromUri
@@ -236,14 +237,6 @@ class ModelsManagementViewModel @Inject constructor(
     }
 }
 
-enum class ModelSortOrder {
-    NAME_ASC,
-    NAME_DESC,
-    SIZE_ASC,
-    SIZE_DESC,
-    LAST_USED
-}
-
 sealed class ModelManagementState {
     object Idle : ModelManagementState()