UI: implement a dialog UI to show fetched HuggingFace models.
This commit is contained in:
parent
310771f6aa
commit
f085d39c05
|
|
@ -351,7 +351,7 @@ fun AppContent(
|
|||
// Create file launcher for importing local models
|
||||
val fileLauncher = rememberLauncherForActivityResult(
|
||||
contract = ActivityResultContracts.OpenDocument()
|
||||
) { uri -> uri?.let { modelsManagementViewModel.localModelFileSelected(it) } }
|
||||
) { uri -> uri?.let { modelsManagementViewModel.importLocalModelFileSelected(it) } }
|
||||
|
||||
val bottomBarConfig = BottomBarConfig.ModelsManagement(
|
||||
sorting = BottomBarConfig.ModelsManagement.SortingConfig(
|
||||
|
|
|
|||
|
|
@ -35,7 +35,7 @@ data class HuggingFaceModel(
|
|||
)
|
||||
|
||||
fun getGgufFilename(): String? =
|
||||
siblings.map { it.rfilename }.first { it.endsWith(FILE_EXTENSION_GGUF) }
|
||||
siblings.map { it.rfilename }.firstOrNull { it.endsWith(FILE_EXTENSION_GGUF) }
|
||||
|
||||
fun toDownloadInfo() = getGgufFilename()?.let { HuggingFaceDownloadInfo(_id, modelId, it) }
|
||||
}
|
||||
|
|
|
|||
|
|
@ -64,7 +64,7 @@ class HuggingFaceRemoteDataSourceImpl @Inject constructor(
|
|||
direction = direction,
|
||||
limit = limit,
|
||||
full = full,
|
||||
).filter { it.gated != true && it.private != true }
|
||||
).filter { it.gated != true && it.private != true && it.getGgufFilename() != null }
|
||||
}
|
||||
|
||||
override suspend fun getModelDetails(
|
||||
|
|
|
|||
|
|
@ -1,9 +1,11 @@
|
|||
package com.example.llama.ui.screens
|
||||
|
||||
import androidx.activity.compose.BackHandler
|
||||
import androidx.compose.foundation.basicMarquee
|
||||
import androidx.compose.foundation.layout.Arrangement
|
||||
import androidx.compose.foundation.layout.Box
|
||||
import androidx.compose.foundation.layout.Column
|
||||
import androidx.compose.foundation.layout.PaddingValues
|
||||
import androidx.compose.foundation.layout.Row
|
||||
import androidx.compose.foundation.layout.Spacer
|
||||
import androidx.compose.foundation.layout.fillMaxSize
|
||||
|
|
@ -15,10 +17,18 @@ import androidx.compose.foundation.layout.width
|
|||
import androidx.compose.foundation.lazy.LazyColumn
|
||||
import androidx.compose.foundation.lazy.items
|
||||
import androidx.compose.material.icons.Icons
|
||||
import androidx.compose.material.icons.filled.Attribution
|
||||
import androidx.compose.material.icons.filled.Download
|
||||
import androidx.compose.material.icons.filled.Favorite
|
||||
import androidx.compose.material.icons.filled.FolderOpen
|
||||
import androidx.compose.material.icons.filled.Today
|
||||
import androidx.compose.material3.AlertDialog
|
||||
import androidx.compose.material3.Button
|
||||
import androidx.compose.material3.Card
|
||||
import androidx.compose.material3.CardDefaults
|
||||
import androidx.compose.material3.Checkbox
|
||||
import androidx.compose.material3.CircularProgressIndicator
|
||||
import androidx.compose.material3.Icon
|
||||
import androidx.compose.material3.LinearProgressIndicator
|
||||
import androidx.compose.material3.MaterialTheme
|
||||
import androidx.compose.material3.SnackbarDuration
|
||||
|
|
@ -30,23 +40,31 @@ import androidx.compose.runtime.collectAsState
|
|||
import androidx.compose.runtime.derivedStateOf
|
||||
import androidx.compose.runtime.getValue
|
||||
import androidx.compose.runtime.mutableStateMapOf
|
||||
import androidx.compose.runtime.mutableStateOf
|
||||
import androidx.compose.runtime.remember
|
||||
import androidx.compose.runtime.setValue
|
||||
import androidx.compose.ui.Alignment
|
||||
import androidx.compose.ui.Modifier
|
||||
import androidx.compose.ui.text.font.FontStyle
|
||||
import androidx.compose.ui.text.font.FontWeight
|
||||
import androidx.compose.ui.text.style.TextAlign
|
||||
import androidx.compose.ui.text.style.TextOverflow
|
||||
import androidx.compose.ui.unit.dp
|
||||
import androidx.compose.ui.window.DialogProperties
|
||||
import com.example.llama.data.model.ModelInfo
|
||||
import com.example.llama.data.remote.HuggingFaceModel
|
||||
import com.example.llama.ui.components.InfoView
|
||||
import com.example.llama.ui.components.ModelCardFullExpandable
|
||||
import com.example.llama.ui.scaffold.ScaffoldEvent
|
||||
import com.example.llama.util.formatContextLength
|
||||
import com.example.llama.util.formatFileByteSize
|
||||
import com.example.llama.viewmodel.ModelManagementState
|
||||
import com.example.llama.viewmodel.ModelManagementState.Deletion
|
||||
import com.example.llama.viewmodel.ModelManagementState.Download
|
||||
import com.example.llama.viewmodel.ModelManagementState.Importation
|
||||
import com.example.llama.viewmodel.ModelsManagementViewModel
|
||||
import java.text.SimpleDateFormat
|
||||
import java.util.Locale
|
||||
|
||||
/**
|
||||
* Screen for managing LLM models (view, download, delete)
|
||||
|
|
@ -103,7 +121,9 @@ fun ModelsManagementScreen(
|
|||
} else {
|
||||
// Model cards
|
||||
LazyColumn(
|
||||
modifier = Modifier.fillMaxSize().padding(horizontal = 16.dp),
|
||||
modifier = Modifier
|
||||
.fillMaxSize()
|
||||
.padding(horizontal = 16.dp),
|
||||
verticalArrangement = Arrangement.spacedBy(12.dp),
|
||||
) {
|
||||
items(items = filteredModels, key = { it.id }) { model ->
|
||||
|
|
@ -134,20 +154,20 @@ fun ModelsManagementScreen(
|
|||
// Model import progress overlay
|
||||
when (val state = managementState) {
|
||||
is Importation.Confirming -> {
|
||||
ImportProgressDialog(
|
||||
ImportFromLocalFileDialog(
|
||||
fileName = state.fileName,
|
||||
fileSize = state.fileSize,
|
||||
isImporting = false,
|
||||
progress = 0.0f,
|
||||
onConfirm = {
|
||||
viewModel.importLocalModelFile(state.uri, state.fileName, state.fileSize)
|
||||
viewModel.importLocalModelFileConfirmed(state.uri, state.fileName, state.fileSize)
|
||||
},
|
||||
onCancel = { viewModel.resetManagementState() }
|
||||
)
|
||||
}
|
||||
|
||||
is Importation.Importing -> {
|
||||
ImportProgressDialog(
|
||||
ImportFromLocalFileDialog(
|
||||
fileName = state.fileName,
|
||||
fileSize = state.fileSize,
|
||||
isImporting = true,
|
||||
|
|
@ -178,6 +198,43 @@ fun ModelsManagementScreen(
|
|||
}
|
||||
}
|
||||
|
||||
is Download.Querying -> {
|
||||
ImportFromHuggingFaceDialog(
|
||||
onCancel = { viewModel.resetManagementState() }
|
||||
)
|
||||
}
|
||||
|
||||
is Download.Ready -> {
|
||||
ImportFromHuggingFaceDialog(
|
||||
models = state.models,
|
||||
onConfirm = { viewModel.downloadHuggingFaceModelConfirmed(it) },
|
||||
onCancel = { viewModel.resetManagementState() }
|
||||
)
|
||||
}
|
||||
|
||||
is Download.Dispatched -> {
|
||||
LaunchedEffect(state) {
|
||||
onScaffoldEvent(
|
||||
ScaffoldEvent.ShowSnackbar(
|
||||
message = "Started downloading: ${state.downloadInfo.modelId}.\n" +
|
||||
// TODO-han.yin: replace this with a broadcast receiver!
|
||||
"Please come back to import it once completed.",
|
||||
duration = SnackbarDuration.Long,
|
||||
)
|
||||
)
|
||||
|
||||
viewModel.resetManagementState()
|
||||
}
|
||||
}
|
||||
|
||||
is Download.Error -> {
|
||||
ErrorDialog(
|
||||
title = "Download Failed",
|
||||
message = state.message,
|
||||
onDismiss = { viewModel.resetManagementState() }
|
||||
)
|
||||
}
|
||||
|
||||
is Deletion.Confirming -> {
|
||||
BatchDeleteConfirmationDialog(
|
||||
count = state.models.size,
|
||||
|
|
@ -219,33 +276,12 @@ fun ModelsManagementScreen(
|
|||
}
|
||||
|
||||
is ModelManagementState.Idle -> { /* Idle state, nothing to show */ }
|
||||
|
||||
else -> TODO()
|
||||
}
|
||||
}
|
||||
|
||||
// TODO-han.yin: UI TO BE IMPLEMENTED
|
||||
val huggingFaceModelsFlow = viewModel.huggingFaceModels
|
||||
LaunchedEffect(Unit) {
|
||||
huggingFaceModelsFlow.collect { models ->
|
||||
val message = models.fold(
|
||||
StringBuilder("Fetched ${models.size} models from HuggingFace")
|
||||
) { builder, model ->
|
||||
builder.append(model.id).append("\n")
|
||||
}.toString()
|
||||
|
||||
onScaffoldEvent(
|
||||
ScaffoldEvent.ShowSnackbar(
|
||||
message = message,
|
||||
duration = SnackbarDuration.Short
|
||||
)
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Composable
|
||||
private fun ImportProgressDialog(
|
||||
private fun ImportFromLocalFileDialog(
|
||||
fileName: String,
|
||||
fileSize: Long,
|
||||
isImporting: Boolean,
|
||||
|
|
@ -344,6 +380,200 @@ private fun ImportProgressDialog(
|
|||
)
|
||||
}
|
||||
|
||||
@Composable
|
||||
private fun ImportFromHuggingFaceDialog(
|
||||
models: List<HuggingFaceModel>? = null,
|
||||
onConfirm: ((HuggingFaceModel) -> Unit)? = null,
|
||||
onCancel: () -> Unit,
|
||||
) {
|
||||
val dateFormatter = remember { SimpleDateFormat("MMM, yyyy", Locale.getDefault()) }
|
||||
|
||||
var selectedModel by remember { mutableStateOf<HuggingFaceModel?>(null) }
|
||||
|
||||
AlertDialog(
|
||||
onDismissRequest = {},
|
||||
properties = DialogProperties(
|
||||
dismissOnBackPress = true,
|
||||
dismissOnClickOutside = true
|
||||
),
|
||||
title = {
|
||||
Text(models?.let { "Fetched ${it.size} models" } ?: "Fetching models")
|
||||
},
|
||||
text = {
|
||||
Column(
|
||||
horizontalAlignment = Alignment.CenterHorizontally,
|
||||
verticalArrangement = Arrangement.Center
|
||||
) {
|
||||
Text(
|
||||
modifier = Modifier.fillMaxWidth(),
|
||||
text = models?.let { "Select a model to download:" }
|
||||
?: "Searching on HuggingFace for models available for direct download...",
|
||||
style = MaterialTheme.typography.bodyLarge,
|
||||
textAlign = TextAlign.Start,
|
||||
)
|
||||
|
||||
if (models == null) {
|
||||
Spacer(modifier = Modifier.height(24.dp))
|
||||
|
||||
CircularProgressIndicator(
|
||||
modifier = Modifier.size(64.dp),
|
||||
strokeWidth = 6.dp
|
||||
)
|
||||
} else {
|
||||
Spacer(modifier = Modifier.height(16.dp))
|
||||
|
||||
LazyColumn(
|
||||
modifier = Modifier.fillMaxWidth(),
|
||||
contentPadding = PaddingValues(vertical = 8.dp),
|
||||
verticalArrangement = Arrangement.spacedBy(8.dp)
|
||||
) {
|
||||
items(models) { model ->
|
||||
HuggingFaceModelListItem(
|
||||
model = model,
|
||||
isSelected = model._id == selectedModel?._id,
|
||||
dateFormatter = dateFormatter,
|
||||
onToggleSelect = { selected ->
|
||||
selectedModel = if (selected) model else null
|
||||
}
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
confirmButton = {
|
||||
onConfirm?.let { onSelect ->
|
||||
TextButton(
|
||||
onClick = { selectedModel?.let { onSelect.invoke(it) } },
|
||||
enabled = selectedModel != null
|
||||
) {
|
||||
Text("Download")
|
||||
}
|
||||
}
|
||||
},
|
||||
dismissButton = {
|
||||
TextButton(
|
||||
onClick = onCancel
|
||||
) {
|
||||
Text("Cancel")
|
||||
}
|
||||
}
|
||||
)
|
||||
}
|
||||
|
||||
@Composable
|
||||
fun HuggingFaceModelListItem(
|
||||
model: HuggingFaceModel,
|
||||
isSelected: Boolean,
|
||||
dateFormatter: SimpleDateFormat,
|
||||
onToggleSelect: (Boolean) -> Unit,
|
||||
) {
|
||||
Card(
|
||||
modifier = Modifier.fillMaxWidth(),
|
||||
colors = when (isSelected) {
|
||||
true -> CardDefaults.cardColors(
|
||||
containerColor = MaterialTheme.colorScheme.primaryContainer
|
||||
)
|
||||
false -> CardDefaults.cardColors()
|
||||
},
|
||||
onClick = { onToggleSelect(!isSelected) }
|
||||
) {
|
||||
Column(modifier = Modifier.fillMaxWidth().padding(8.dp)) {
|
||||
Text(
|
||||
modifier = Modifier.basicMarquee(),
|
||||
text = model.modelId.substringAfterLast("/"),
|
||||
style = MaterialTheme.typography.bodyMedium,
|
||||
fontWeight = if (isSelected) FontWeight.Bold else FontWeight.Medium,
|
||||
)
|
||||
|
||||
Spacer(modifier = Modifier.size(8.dp))
|
||||
|
||||
Row(verticalAlignment = Alignment.Bottom) {
|
||||
Column {
|
||||
Row(verticalAlignment = Alignment.CenterVertically) {
|
||||
Icon(
|
||||
imageVector = Icons.Default.Attribution,
|
||||
contentDescription = "Author",
|
||||
modifier = Modifier.size(16.dp),
|
||||
tint = MaterialTheme.colorScheme.onSurfaceVariant
|
||||
)
|
||||
|
||||
Spacer(modifier = Modifier.size(4.dp))
|
||||
|
||||
Text(
|
||||
text = model.author,
|
||||
style = MaterialTheme.typography.bodyMedium,
|
||||
color = MaterialTheme.colorScheme.onSurfaceVariant
|
||||
)
|
||||
}
|
||||
|
||||
Spacer(modifier = Modifier.size(8.dp))
|
||||
|
||||
Row(verticalAlignment = Alignment.CenterVertically) {
|
||||
Icon(
|
||||
imageVector = Icons.Default.Today,
|
||||
contentDescription = "Author",
|
||||
modifier = Modifier.size(16.dp),
|
||||
tint = MaterialTheme.colorScheme.onSurfaceVariant
|
||||
)
|
||||
|
||||
Spacer(modifier = Modifier.size(4.dp))
|
||||
|
||||
Text(
|
||||
text = dateFormatter.format(model.lastModified),
|
||||
style = MaterialTheme.typography.bodyMedium,
|
||||
color = MaterialTheme.colorScheme.onSurfaceVariant
|
||||
)
|
||||
|
||||
Spacer(modifier = Modifier.size(8.dp))
|
||||
|
||||
Icon(
|
||||
imageVector = Icons.Default.Favorite,
|
||||
contentDescription = "Favorite count",
|
||||
modifier = Modifier.size(16.dp),
|
||||
tint = MaterialTheme.colorScheme.onSurfaceVariant
|
||||
)
|
||||
|
||||
Spacer(modifier = Modifier.size(4.dp))
|
||||
|
||||
Text(
|
||||
text = formatContextLength(model.likes),
|
||||
style = MaterialTheme.typography.bodyMedium,
|
||||
color = MaterialTheme.colorScheme.onSurfaceVariant
|
||||
)
|
||||
|
||||
Spacer(modifier = Modifier.size(8.dp))
|
||||
|
||||
Icon(
|
||||
imageVector = Icons.Default.Download,
|
||||
contentDescription = "Download count",
|
||||
modifier = Modifier.size(16.dp),
|
||||
tint = MaterialTheme.colorScheme.onSurfaceVariant
|
||||
)
|
||||
|
||||
Spacer(modifier = Modifier.size(4.dp))
|
||||
|
||||
Text(
|
||||
text = formatContextLength(model.downloads),
|
||||
style = MaterialTheme.typography.bodyMedium,
|
||||
color = MaterialTheme.colorScheme.onSurfaceVariant
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
Spacer(Modifier.weight(1f))
|
||||
|
||||
if (isSelected) {
|
||||
Checkbox(
|
||||
checked = isSelected,
|
||||
onCheckedChange = null, // handled by parent selectable
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Composable
|
||||
private fun BatchDeleteConfirmationDialog(
|
||||
count: Int,
|
||||
|
|
|
|||
|
|
@ -22,10 +22,10 @@ import com.example.llama.viewmodel.ModelManagementState.Download
|
|||
import com.example.llama.viewmodel.ModelManagementState.Importation
|
||||
import dagger.hilt.android.lifecycle.HiltViewModel
|
||||
import dagger.hilt.android.qualifiers.ApplicationContext
|
||||
import kotlinx.coroutines.CancellationException
|
||||
import kotlinx.coroutines.Job
|
||||
import kotlinx.coroutines.delay
|
||||
import kotlinx.coroutines.flow.MutableSharedFlow
|
||||
import kotlinx.coroutines.flow.MutableStateFlow
|
||||
import kotlinx.coroutines.flow.SharedFlow
|
||||
import kotlinx.coroutines.flow.StateFlow
|
||||
import kotlinx.coroutines.flow.asStateFlow
|
||||
import kotlinx.coroutines.flow.collectLatest
|
||||
|
|
@ -34,7 +34,6 @@ import kotlinx.coroutines.flow.update
|
|||
import kotlinx.coroutines.launch
|
||||
import java.io.FileNotFoundException
|
||||
import javax.inject.Inject
|
||||
import kotlin.collections.set
|
||||
|
||||
@HiltViewModel
|
||||
class ModelsManagementViewModel @Inject constructor(
|
||||
|
|
@ -133,9 +132,8 @@ class ModelsManagementViewModel @Inject constructor(
|
|||
_showImportModelMenu.value = show
|
||||
}
|
||||
|
||||
// UI state: HuggingFace models query result
|
||||
private val _huggingFaceModels = MutableSharedFlow<List<HuggingFaceModel>>()
|
||||
val huggingFaceModels: SharedFlow<List<HuggingFaceModel>> = _huggingFaceModels
|
||||
// Ongoing coroutine jobs
|
||||
private var huggingFaceQueryJob: Job? = null
|
||||
|
||||
init {
|
||||
viewModelScope.launch {
|
||||
|
|
@ -156,13 +154,16 @@ class ModelsManagementViewModel @Inject constructor(
|
|||
val managementState: StateFlow<ModelManagementState> = _managementState.asStateFlow()
|
||||
|
||||
fun resetManagementState() {
|
||||
huggingFaceQueryJob?.let {
|
||||
if (it.isActive) { it.cancel() }
|
||||
}
|
||||
_managementState.value = ModelManagementState.Idle
|
||||
}
|
||||
|
||||
/**
|
||||
* First show confirmation instead of starting import local file immediately
|
||||
*/
|
||||
fun localModelFileSelected(uri: Uri) = viewModelScope.launch {
|
||||
fun importLocalModelFileSelected(uri: Uri) = viewModelScope.launch {
|
||||
try {
|
||||
val fileName = getFileNameFromUri(context, uri) ?: throw FileNotFoundException("File size N/A")
|
||||
val fileSize = getFileSizeFromUri(context, uri) ?: throw FileNotFoundException("File name N/A")
|
||||
|
|
@ -177,7 +178,7 @@ class ModelsManagementViewModel @Inject constructor(
|
|||
/**
|
||||
* Import a local model file from device storage while updating UI states with realtime progress
|
||||
*/
|
||||
fun importLocalModelFile(uri: Uri, fileName: String, fileSize: Long) = viewModelScope.launch {
|
||||
fun importLocalModelFileConfirmed(uri: Uri, fileName: String, fileSize: Long) = viewModelScope.launch {
|
||||
try {
|
||||
_managementState.value = Importation.Importing(0f, fileName, fileSize)
|
||||
val model = modelRepository.importModel(uri, fileName, fileSize) { progress ->
|
||||
|
|
@ -221,29 +222,25 @@ class ModelsManagementViewModel @Inject constructor(
|
|||
/**
|
||||
* Query models on HuggingFace available for download even without signing in
|
||||
*/
|
||||
fun queryModelsFromHuggingFace() = viewModelScope.launch {
|
||||
modelRepository.searchHuggingFaceModels().let { models ->
|
||||
_huggingFaceModels.emit(models)
|
||||
Log.d(TAG, "Fetched ${models.size} models from HuggingFace:")
|
||||
|
||||
// TODO-han.yin: remove these logs
|
||||
// models.forEachIndexed { index, model ->
|
||||
// Log.d(TAG, "#$index: $model")
|
||||
// }
|
||||
fun queryModelsFromHuggingFace() {
|
||||
huggingFaceQueryJob = viewModelScope.launch {
|
||||
_managementState.emit(Download.Querying)
|
||||
try {
|
||||
val models = modelRepository.searchHuggingFaceModels()
|
||||
Log.d(TAG, "Fetched ${models.size} models from HuggingFace:")
|
||||
_managementState.emit(Download.Ready(models))
|
||||
} catch (_: CancellationException) {
|
||||
// no-op
|
||||
} catch (e: Exception) {
|
||||
_managementState.emit(Download.Error(message = e.message ?: "Unknown error"))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* First show confirmation instead of dispatch download immediately
|
||||
*/
|
||||
fun downloadHuggingFaceModelSelected(model: HuggingFaceModel) {
|
||||
_managementState.value = Download.Confirming(model)
|
||||
}
|
||||
|
||||
/**
|
||||
* Dispatch download request to [DownloadManager] and update UI
|
||||
*/
|
||||
fun downloadHuggingFaceModel(model: HuggingFaceModel) = viewModelScope.launch {
|
||||
fun downloadHuggingFaceModelConfirmed(model: HuggingFaceModel) = viewModelScope.launch {
|
||||
try {
|
||||
require(!model.gated) { "Model is gated!" }
|
||||
require(!model.private) { "Model is private!" }
|
||||
|
|
@ -324,7 +321,8 @@ sealed class ModelManagementState {
|
|||
}
|
||||
|
||||
sealed class Download: ModelManagementState() {
|
||||
data class Confirming(val model: HuggingFaceModel) : Download()
|
||||
object Querying : Download()
|
||||
data class Ready(val models: List<HuggingFaceModel>) : Download()
|
||||
data class Dispatched(val downloadInfo: HuggingFaceDownloadInfo) : Download()
|
||||
data class Error(val message: String) : Download()
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue