UI: implement a dialog UI to show fetched HuggingFace models.

This commit is contained in:
Han Yin 2025-07-07 21:45:30 -07:00
parent 310771f6aa
commit f085d39c05
5 changed files with 283 additions and 55 deletions

View File

@ -351,7 +351,7 @@ fun AppContent(
// Create file launcher for importing local models // Create file launcher for importing local models
val fileLauncher = rememberLauncherForActivityResult( val fileLauncher = rememberLauncherForActivityResult(
contract = ActivityResultContracts.OpenDocument() contract = ActivityResultContracts.OpenDocument()
) { uri -> uri?.let { modelsManagementViewModel.localModelFileSelected(it) } } ) { uri -> uri?.let { modelsManagementViewModel.importLocalModelFileSelected(it) } }
val bottomBarConfig = BottomBarConfig.ModelsManagement( val bottomBarConfig = BottomBarConfig.ModelsManagement(
sorting = BottomBarConfig.ModelsManagement.SortingConfig( sorting = BottomBarConfig.ModelsManagement.SortingConfig(

View File

@ -35,7 +35,7 @@ data class HuggingFaceModel(
) )
fun getGgufFilename(): String? = fun getGgufFilename(): String? =
siblings.map { it.rfilename }.first { it.endsWith(FILE_EXTENSION_GGUF) } siblings.map { it.rfilename }.firstOrNull { it.endsWith(FILE_EXTENSION_GGUF) }
fun toDownloadInfo() = getGgufFilename()?.let { HuggingFaceDownloadInfo(_id, modelId, it) } fun toDownloadInfo() = getGgufFilename()?.let { HuggingFaceDownloadInfo(_id, modelId, it) }
} }

View File

@ -64,7 +64,7 @@ class HuggingFaceRemoteDataSourceImpl @Inject constructor(
direction = direction, direction = direction,
limit = limit, limit = limit,
full = full, full = full,
).filter { it.gated != true && it.private != true } ).filter { it.gated != true && it.private != true && it.getGgufFilename() != null }
} }
override suspend fun getModelDetails( override suspend fun getModelDetails(

View File

@ -1,9 +1,11 @@
package com.example.llama.ui.screens package com.example.llama.ui.screens
import androidx.activity.compose.BackHandler import androidx.activity.compose.BackHandler
import androidx.compose.foundation.basicMarquee
import androidx.compose.foundation.layout.Arrangement import androidx.compose.foundation.layout.Arrangement
import androidx.compose.foundation.layout.Box import androidx.compose.foundation.layout.Box
import androidx.compose.foundation.layout.Column import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.PaddingValues
import androidx.compose.foundation.layout.Row import androidx.compose.foundation.layout.Row
import androidx.compose.foundation.layout.Spacer import androidx.compose.foundation.layout.Spacer
import androidx.compose.foundation.layout.fillMaxSize import androidx.compose.foundation.layout.fillMaxSize
@ -15,10 +17,18 @@ import androidx.compose.foundation.layout.width
import androidx.compose.foundation.lazy.LazyColumn import androidx.compose.foundation.lazy.LazyColumn
import androidx.compose.foundation.lazy.items import androidx.compose.foundation.lazy.items
import androidx.compose.material.icons.Icons import androidx.compose.material.icons.Icons
import androidx.compose.material.icons.filled.Attribution
import androidx.compose.material.icons.filled.Download
import androidx.compose.material.icons.filled.Favorite
import androidx.compose.material.icons.filled.FolderOpen import androidx.compose.material.icons.filled.FolderOpen
import androidx.compose.material.icons.filled.Today
import androidx.compose.material3.AlertDialog import androidx.compose.material3.AlertDialog
import androidx.compose.material3.Button import androidx.compose.material3.Button
import androidx.compose.material3.Card
import androidx.compose.material3.CardDefaults
import androidx.compose.material3.Checkbox
import androidx.compose.material3.CircularProgressIndicator import androidx.compose.material3.CircularProgressIndicator
import androidx.compose.material3.Icon
import androidx.compose.material3.LinearProgressIndicator import androidx.compose.material3.LinearProgressIndicator
import androidx.compose.material3.MaterialTheme import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.SnackbarDuration import androidx.compose.material3.SnackbarDuration
@ -30,23 +40,31 @@ import androidx.compose.runtime.collectAsState
import androidx.compose.runtime.derivedStateOf import androidx.compose.runtime.derivedStateOf
import androidx.compose.runtime.getValue import androidx.compose.runtime.getValue
import androidx.compose.runtime.mutableStateMapOf import androidx.compose.runtime.mutableStateMapOf
import androidx.compose.runtime.mutableStateOf
import androidx.compose.runtime.remember import androidx.compose.runtime.remember
import androidx.compose.runtime.setValue
import androidx.compose.ui.Alignment import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier import androidx.compose.ui.Modifier
import androidx.compose.ui.text.font.FontStyle import androidx.compose.ui.text.font.FontStyle
import androidx.compose.ui.text.font.FontWeight
import androidx.compose.ui.text.style.TextAlign import androidx.compose.ui.text.style.TextAlign
import androidx.compose.ui.text.style.TextOverflow import androidx.compose.ui.text.style.TextOverflow
import androidx.compose.ui.unit.dp import androidx.compose.ui.unit.dp
import androidx.compose.ui.window.DialogProperties import androidx.compose.ui.window.DialogProperties
import com.example.llama.data.model.ModelInfo import com.example.llama.data.model.ModelInfo
import com.example.llama.data.remote.HuggingFaceModel
import com.example.llama.ui.components.InfoView import com.example.llama.ui.components.InfoView
import com.example.llama.ui.components.ModelCardFullExpandable import com.example.llama.ui.components.ModelCardFullExpandable
import com.example.llama.ui.scaffold.ScaffoldEvent import com.example.llama.ui.scaffold.ScaffoldEvent
import com.example.llama.util.formatContextLength
import com.example.llama.util.formatFileByteSize import com.example.llama.util.formatFileByteSize
import com.example.llama.viewmodel.ModelManagementState import com.example.llama.viewmodel.ModelManagementState
import com.example.llama.viewmodel.ModelManagementState.Deletion import com.example.llama.viewmodel.ModelManagementState.Deletion
import com.example.llama.viewmodel.ModelManagementState.Download
import com.example.llama.viewmodel.ModelManagementState.Importation import com.example.llama.viewmodel.ModelManagementState.Importation
import com.example.llama.viewmodel.ModelsManagementViewModel import com.example.llama.viewmodel.ModelsManagementViewModel
import java.text.SimpleDateFormat
import java.util.Locale
/** /**
* Screen for managing LLM models (view, download, delete) * Screen for managing LLM models (view, download, delete)
@ -103,7 +121,9 @@ fun ModelsManagementScreen(
} else { } else {
// Model cards // Model cards
LazyColumn( LazyColumn(
modifier = Modifier.fillMaxSize().padding(horizontal = 16.dp), modifier = Modifier
.fillMaxSize()
.padding(horizontal = 16.dp),
verticalArrangement = Arrangement.spacedBy(12.dp), verticalArrangement = Arrangement.spacedBy(12.dp),
) { ) {
items(items = filteredModels, key = { it.id }) { model -> items(items = filteredModels, key = { it.id }) { model ->
@ -134,20 +154,20 @@ fun ModelsManagementScreen(
// Model import progress overlay // Model import progress overlay
when (val state = managementState) { when (val state = managementState) {
is Importation.Confirming -> { is Importation.Confirming -> {
ImportProgressDialog( ImportFromLocalFileDialog(
fileName = state.fileName, fileName = state.fileName,
fileSize = state.fileSize, fileSize = state.fileSize,
isImporting = false, isImporting = false,
progress = 0.0f, progress = 0.0f,
onConfirm = { onConfirm = {
viewModel.importLocalModelFile(state.uri, state.fileName, state.fileSize) viewModel.importLocalModelFileConfirmed(state.uri, state.fileName, state.fileSize)
}, },
onCancel = { viewModel.resetManagementState() } onCancel = { viewModel.resetManagementState() }
) )
} }
is Importation.Importing -> { is Importation.Importing -> {
ImportProgressDialog( ImportFromLocalFileDialog(
fileName = state.fileName, fileName = state.fileName,
fileSize = state.fileSize, fileSize = state.fileSize,
isImporting = true, isImporting = true,
@ -178,6 +198,43 @@ fun ModelsManagementScreen(
} }
} }
is Download.Querying -> {
ImportFromHuggingFaceDialog(
onCancel = { viewModel.resetManagementState() }
)
}
is Download.Ready -> {
ImportFromHuggingFaceDialog(
models = state.models,
onConfirm = { viewModel.downloadHuggingFaceModelConfirmed(it) },
onCancel = { viewModel.resetManagementState() }
)
}
is Download.Dispatched -> {
LaunchedEffect(state) {
onScaffoldEvent(
ScaffoldEvent.ShowSnackbar(
message = "Started downloading: ${state.downloadInfo.modelId}.\n" +
// TODO-han.yin: replace this with a broadcast receiver!
"Please come back to import it once completed.",
duration = SnackbarDuration.Long,
)
)
viewModel.resetManagementState()
}
}
is Download.Error -> {
ErrorDialog(
title = "Download Failed",
message = state.message,
onDismiss = { viewModel.resetManagementState() }
)
}
is Deletion.Confirming -> { is Deletion.Confirming -> {
BatchDeleteConfirmationDialog( BatchDeleteConfirmationDialog(
count = state.models.size, count = state.models.size,
@ -219,33 +276,12 @@ fun ModelsManagementScreen(
} }
is ModelManagementState.Idle -> { /* Idle state, nothing to show */ } is ModelManagementState.Idle -> { /* Idle state, nothing to show */ }
else -> TODO()
}
}
// TODO-han.yin: UI TO BE IMPLEMENTED
val huggingFaceModelsFlow = viewModel.huggingFaceModels
LaunchedEffect(Unit) {
huggingFaceModelsFlow.collect { models ->
val message = models.fold(
StringBuilder("Fetched ${models.size} models from HuggingFace")
) { builder, model ->
builder.append(model.id).append("\n")
}.toString()
onScaffoldEvent(
ScaffoldEvent.ShowSnackbar(
message = message,
duration = SnackbarDuration.Short
)
)
} }
} }
} }
@Composable @Composable
private fun ImportProgressDialog( private fun ImportFromLocalFileDialog(
fileName: String, fileName: String,
fileSize: Long, fileSize: Long,
isImporting: Boolean, isImporting: Boolean,
@ -344,6 +380,200 @@ private fun ImportProgressDialog(
) )
} }
@Composable
private fun ImportFromHuggingFaceDialog(
models: List<HuggingFaceModel>? = null,
onConfirm: ((HuggingFaceModel) -> Unit)? = null,
onCancel: () -> Unit,
) {
val dateFormatter = remember { SimpleDateFormat("MMM, yyyy", Locale.getDefault()) }
var selectedModel by remember { mutableStateOf<HuggingFaceModel?>(null) }
AlertDialog(
onDismissRequest = {},
properties = DialogProperties(
dismissOnBackPress = true,
dismissOnClickOutside = true
),
title = {
Text(models?.let { "Fetched ${it.size} models" } ?: "Fetching models")
},
text = {
Column(
horizontalAlignment = Alignment.CenterHorizontally,
verticalArrangement = Arrangement.Center
) {
Text(
modifier = Modifier.fillMaxWidth(),
text = models?.let { "Select a model to download:" }
?: "Searching on HuggingFace for models available for direct download...",
style = MaterialTheme.typography.bodyLarge,
textAlign = TextAlign.Start,
)
if (models == null) {
Spacer(modifier = Modifier.height(24.dp))
CircularProgressIndicator(
modifier = Modifier.size(64.dp),
strokeWidth = 6.dp
)
} else {
Spacer(modifier = Modifier.height(16.dp))
LazyColumn(
modifier = Modifier.fillMaxWidth(),
contentPadding = PaddingValues(vertical = 8.dp),
verticalArrangement = Arrangement.spacedBy(8.dp)
) {
items(models) { model ->
HuggingFaceModelListItem(
model = model,
isSelected = model._id == selectedModel?._id,
dateFormatter = dateFormatter,
onToggleSelect = { selected ->
selectedModel = if (selected) model else null
}
)
}
}
}
}
},
confirmButton = {
onConfirm?.let { onSelect ->
TextButton(
onClick = { selectedModel?.let { onSelect.invoke(it) } },
enabled = selectedModel != null
) {
Text("Download")
}
}
},
dismissButton = {
TextButton(
onClick = onCancel
) {
Text("Cancel")
}
}
)
}
@Composable
fun HuggingFaceModelListItem(
model: HuggingFaceModel,
isSelected: Boolean,
dateFormatter: SimpleDateFormat,
onToggleSelect: (Boolean) -> Unit,
) {
Card(
modifier = Modifier.fillMaxWidth(),
colors = when (isSelected) {
true -> CardDefaults.cardColors(
containerColor = MaterialTheme.colorScheme.primaryContainer
)
false -> CardDefaults.cardColors()
},
onClick = { onToggleSelect(!isSelected) }
) {
Column(modifier = Modifier.fillMaxWidth().padding(8.dp)) {
Text(
modifier = Modifier.basicMarquee(),
text = model.modelId.substringAfterLast("/"),
style = MaterialTheme.typography.bodyMedium,
fontWeight = if (isSelected) FontWeight.Bold else FontWeight.Medium,
)
Spacer(modifier = Modifier.size(8.dp))
Row(verticalAlignment = Alignment.Bottom) {
Column {
Row(verticalAlignment = Alignment.CenterVertically) {
Icon(
imageVector = Icons.Default.Attribution,
contentDescription = "Author",
modifier = Modifier.size(16.dp),
tint = MaterialTheme.colorScheme.onSurfaceVariant
)
Spacer(modifier = Modifier.size(4.dp))
Text(
text = model.author,
style = MaterialTheme.typography.bodyMedium,
color = MaterialTheme.colorScheme.onSurfaceVariant
)
}
Spacer(modifier = Modifier.size(8.dp))
Row(verticalAlignment = Alignment.CenterVertically) {
Icon(
imageVector = Icons.Default.Today,
contentDescription = "Author",
modifier = Modifier.size(16.dp),
tint = MaterialTheme.colorScheme.onSurfaceVariant
)
Spacer(modifier = Modifier.size(4.dp))
Text(
text = dateFormatter.format(model.lastModified),
style = MaterialTheme.typography.bodyMedium,
color = MaterialTheme.colorScheme.onSurfaceVariant
)
Spacer(modifier = Modifier.size(8.dp))
Icon(
imageVector = Icons.Default.Favorite,
contentDescription = "Favorite count",
modifier = Modifier.size(16.dp),
tint = MaterialTheme.colorScheme.onSurfaceVariant
)
Spacer(modifier = Modifier.size(4.dp))
Text(
text = formatContextLength(model.likes),
style = MaterialTheme.typography.bodyMedium,
color = MaterialTheme.colorScheme.onSurfaceVariant
)
Spacer(modifier = Modifier.size(8.dp))
Icon(
imageVector = Icons.Default.Download,
contentDescription = "Download count",
modifier = Modifier.size(16.dp),
tint = MaterialTheme.colorScheme.onSurfaceVariant
)
Spacer(modifier = Modifier.size(4.dp))
Text(
text = formatContextLength(model.downloads),
style = MaterialTheme.typography.bodyMedium,
color = MaterialTheme.colorScheme.onSurfaceVariant
)
}
}
Spacer(Modifier.weight(1f))
if (isSelected) {
Checkbox(
checked = isSelected,
onCheckedChange = null, // handled by parent selectable
)
}
}
}
}
}
@Composable @Composable
private fun BatchDeleteConfirmationDialog( private fun BatchDeleteConfirmationDialog(
count: Int, count: Int,

View File

@ -22,10 +22,10 @@ import com.example.llama.viewmodel.ModelManagementState.Download
import com.example.llama.viewmodel.ModelManagementState.Importation import com.example.llama.viewmodel.ModelManagementState.Importation
import dagger.hilt.android.lifecycle.HiltViewModel import dagger.hilt.android.lifecycle.HiltViewModel
import dagger.hilt.android.qualifiers.ApplicationContext import dagger.hilt.android.qualifiers.ApplicationContext
import kotlinx.coroutines.CancellationException
import kotlinx.coroutines.Job
import kotlinx.coroutines.delay import kotlinx.coroutines.delay
import kotlinx.coroutines.flow.MutableSharedFlow
import kotlinx.coroutines.flow.MutableStateFlow import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.flow.SharedFlow
import kotlinx.coroutines.flow.StateFlow import kotlinx.coroutines.flow.StateFlow
import kotlinx.coroutines.flow.asStateFlow import kotlinx.coroutines.flow.asStateFlow
import kotlinx.coroutines.flow.collectLatest import kotlinx.coroutines.flow.collectLatest
@ -34,7 +34,6 @@ import kotlinx.coroutines.flow.update
import kotlinx.coroutines.launch import kotlinx.coroutines.launch
import java.io.FileNotFoundException import java.io.FileNotFoundException
import javax.inject.Inject import javax.inject.Inject
import kotlin.collections.set
@HiltViewModel @HiltViewModel
class ModelsManagementViewModel @Inject constructor( class ModelsManagementViewModel @Inject constructor(
@ -133,9 +132,8 @@ class ModelsManagementViewModel @Inject constructor(
_showImportModelMenu.value = show _showImportModelMenu.value = show
} }
// UI state: HuggingFace models query result // Ongoing coroutine jobs
private val _huggingFaceModels = MutableSharedFlow<List<HuggingFaceModel>>() private var huggingFaceQueryJob: Job? = null
val huggingFaceModels: SharedFlow<List<HuggingFaceModel>> = _huggingFaceModels
init { init {
viewModelScope.launch { viewModelScope.launch {
@ -156,13 +154,16 @@ class ModelsManagementViewModel @Inject constructor(
val managementState: StateFlow<ModelManagementState> = _managementState.asStateFlow() val managementState: StateFlow<ModelManagementState> = _managementState.asStateFlow()
fun resetManagementState() { fun resetManagementState() {
huggingFaceQueryJob?.let {
if (it.isActive) { it.cancel() }
}
_managementState.value = ModelManagementState.Idle _managementState.value = ModelManagementState.Idle
} }
/** /**
* First show confirmation instead of starting import local file immediately * First show confirmation instead of starting import local file immediately
*/ */
fun localModelFileSelected(uri: Uri) = viewModelScope.launch { fun importLocalModelFileSelected(uri: Uri) = viewModelScope.launch {
try { try {
val fileName = getFileNameFromUri(context, uri) ?: throw FileNotFoundException("File size N/A") val fileName = getFileNameFromUri(context, uri) ?: throw FileNotFoundException("File size N/A")
val fileSize = getFileSizeFromUri(context, uri) ?: throw FileNotFoundException("File name N/A") val fileSize = getFileSizeFromUri(context, uri) ?: throw FileNotFoundException("File name N/A")
@ -177,7 +178,7 @@ class ModelsManagementViewModel @Inject constructor(
/** /**
* Import a local model file from device storage while updating UI states with realtime progress * Import a local model file from device storage while updating UI states with realtime progress
*/ */
fun importLocalModelFile(uri: Uri, fileName: String, fileSize: Long) = viewModelScope.launch { fun importLocalModelFileConfirmed(uri: Uri, fileName: String, fileSize: Long) = viewModelScope.launch {
try { try {
_managementState.value = Importation.Importing(0f, fileName, fileSize) _managementState.value = Importation.Importing(0f, fileName, fileSize)
val model = modelRepository.importModel(uri, fileName, fileSize) { progress -> val model = modelRepository.importModel(uri, fileName, fileSize) { progress ->
@ -221,29 +222,25 @@ class ModelsManagementViewModel @Inject constructor(
/** /**
* Query models on HuggingFace available for download even without signing in * Query models on HuggingFace available for download even without signing in
*/ */
fun queryModelsFromHuggingFace() = viewModelScope.launch { fun queryModelsFromHuggingFace() {
modelRepository.searchHuggingFaceModels().let { models -> huggingFaceQueryJob = viewModelScope.launch {
_huggingFaceModels.emit(models) _managementState.emit(Download.Querying)
try {
val models = modelRepository.searchHuggingFaceModels()
Log.d(TAG, "Fetched ${models.size} models from HuggingFace:") Log.d(TAG, "Fetched ${models.size} models from HuggingFace:")
_managementState.emit(Download.Ready(models))
// TODO-han.yin: remove these logs } catch (_: CancellationException) {
// models.forEachIndexed { index, model -> // no-op
// Log.d(TAG, "#$index: $model") } catch (e: Exception) {
// } _managementState.emit(Download.Error(message = e.message ?: "Unknown error"))
} }
} }
/**
* First show confirmation instead of dispatch download immediately
*/
fun downloadHuggingFaceModelSelected(model: HuggingFaceModel) {
_managementState.value = Download.Confirming(model)
} }
/** /**
* Dispatch download request to [DownloadManager] and update UI * Dispatch download request to [DownloadManager] and update UI
*/ */
fun downloadHuggingFaceModel(model: HuggingFaceModel) = viewModelScope.launch { fun downloadHuggingFaceModelConfirmed(model: HuggingFaceModel) = viewModelScope.launch {
try { try {
require(!model.gated) { "Model is gated!" } require(!model.gated) { "Model is gated!" }
require(!model.private) { "Model is private!" } require(!model.private) { "Model is private!" }
@ -324,7 +321,8 @@ sealed class ModelManagementState {
} }
sealed class Download: ModelManagementState() { sealed class Download: ModelManagementState() {
data class Confirming(val model: HuggingFaceModel) : Download() object Querying : Download()
data class Ready(val models: List<HuggingFaceModel>) : Download()
data class Dispatched(val downloadInfo: HuggingFaceDownloadInfo) : Download() data class Dispatched(val downloadInfo: HuggingFaceDownloadInfo) : Download()
data class Error(val message: String) : Download() data class Error(val message: String) : Download()
} }