diff --git a/examples/llama.android/app/src/main/java/com/example/llama/MainActivity.kt b/examples/llama.android/app/src/main/java/com/example/llama/MainActivity.kt index e833439cd0..08a1f231e8 100644 --- a/examples/llama.android/app/src/main/java/com/example/llama/MainActivity.kt +++ b/examples/llama.android/app/src/main/java/com/example/llama/MainActivity.kt @@ -351,7 +351,7 @@ fun AppContent( // Create file launcher for importing local models val fileLauncher = rememberLauncherForActivityResult( contract = ActivityResultContracts.OpenDocument() - ) { uri -> uri?.let { modelsManagementViewModel.localModelFileSelected(it) } } + ) { uri -> uri?.let { modelsManagementViewModel.importLocalModelFileSelected(it) } } val bottomBarConfig = BottomBarConfig.ModelsManagement( sorting = BottomBarConfig.ModelsManagement.SortingConfig( diff --git a/examples/llama.android/app/src/main/java/com/example/llama/data/remote/HuggingFaceModel.kt b/examples/llama.android/app/src/main/java/com/example/llama/data/remote/HuggingFaceModel.kt index 78ee31bc1d..71f0364b95 100644 --- a/examples/llama.android/app/src/main/java/com/example/llama/data/remote/HuggingFaceModel.kt +++ b/examples/llama.android/app/src/main/java/com/example/llama/data/remote/HuggingFaceModel.kt @@ -35,7 +35,7 @@ data class HuggingFaceModel( ) fun getGgufFilename(): String? = - siblings.map { it.rfilename }.first { it.endsWith(FILE_EXTENSION_GGUF) } + siblings.map { it.rfilename }.firstOrNull { it.endsWith(FILE_EXTENSION_GGUF) } fun toDownloadInfo() = getGgufFilename()?.let { HuggingFaceDownloadInfo(_id, modelId, it) } } diff --git a/examples/llama.android/app/src/main/java/com/example/llama/data/remote/HuggingFaceRemoteDataSource.kt b/examples/llama.android/app/src/main/java/com/example/llama/data/remote/HuggingFaceRemoteDataSource.kt index 8f3434b74a..f2d52ce9ef 100644 --- a/examples/llama.android/app/src/main/java/com/example/llama/data/remote/HuggingFaceRemoteDataSource.kt +++ b/examples/llama.android/app/src/main/java/com/example/llama/data/remote/HuggingFaceRemoteDataSource.kt @@ -64,7 +64,7 @@ class HuggingFaceRemoteDataSourceImpl @Inject constructor( direction = direction, limit = limit, full = full, - ).filter { it.gated != true && it.private != true } + ).filter { it.gated != true && it.private != true && it.getGgufFilename() != null } } override suspend fun getModelDetails( diff --git a/examples/llama.android/app/src/main/java/com/example/llama/ui/screens/ModelsManagementScreen.kt b/examples/llama.android/app/src/main/java/com/example/llama/ui/screens/ModelsManagementScreen.kt index 935ff506f8..19a8e34750 100644 --- a/examples/llama.android/app/src/main/java/com/example/llama/ui/screens/ModelsManagementScreen.kt +++ b/examples/llama.android/app/src/main/java/com/example/llama/ui/screens/ModelsManagementScreen.kt @@ -1,9 +1,11 @@ package com.example.llama.ui.screens import androidx.activity.compose.BackHandler +import androidx.compose.foundation.basicMarquee import androidx.compose.foundation.layout.Arrangement import androidx.compose.foundation.layout.Box import androidx.compose.foundation.layout.Column +import androidx.compose.foundation.layout.PaddingValues import androidx.compose.foundation.layout.Row import androidx.compose.foundation.layout.Spacer import androidx.compose.foundation.layout.fillMaxSize @@ -15,10 +17,18 @@ import androidx.compose.foundation.layout.width import androidx.compose.foundation.lazy.LazyColumn import androidx.compose.foundation.lazy.items import androidx.compose.material.icons.Icons +import androidx.compose.material.icons.filled.Attribution +import androidx.compose.material.icons.filled.Download +import androidx.compose.material.icons.filled.Favorite import androidx.compose.material.icons.filled.FolderOpen +import androidx.compose.material.icons.filled.Today import androidx.compose.material3.AlertDialog import androidx.compose.material3.Button +import androidx.compose.material3.Card +import androidx.compose.material3.CardDefaults +import androidx.compose.material3.Checkbox import androidx.compose.material3.CircularProgressIndicator +import androidx.compose.material3.Icon import androidx.compose.material3.LinearProgressIndicator import androidx.compose.material3.MaterialTheme import androidx.compose.material3.SnackbarDuration @@ -30,23 +40,31 @@ import androidx.compose.runtime.collectAsState import androidx.compose.runtime.derivedStateOf import androidx.compose.runtime.getValue import androidx.compose.runtime.mutableStateMapOf +import androidx.compose.runtime.mutableStateOf import androidx.compose.runtime.remember +import androidx.compose.runtime.setValue import androidx.compose.ui.Alignment import androidx.compose.ui.Modifier import androidx.compose.ui.text.font.FontStyle +import androidx.compose.ui.text.font.FontWeight import androidx.compose.ui.text.style.TextAlign import androidx.compose.ui.text.style.TextOverflow import androidx.compose.ui.unit.dp import androidx.compose.ui.window.DialogProperties import com.example.llama.data.model.ModelInfo +import com.example.llama.data.remote.HuggingFaceModel import com.example.llama.ui.components.InfoView import com.example.llama.ui.components.ModelCardFullExpandable import com.example.llama.ui.scaffold.ScaffoldEvent +import com.example.llama.util.formatContextLength import com.example.llama.util.formatFileByteSize import com.example.llama.viewmodel.ModelManagementState import com.example.llama.viewmodel.ModelManagementState.Deletion +import com.example.llama.viewmodel.ModelManagementState.Download import com.example.llama.viewmodel.ModelManagementState.Importation import com.example.llama.viewmodel.ModelsManagementViewModel +import java.text.SimpleDateFormat +import java.util.Locale /** * Screen for managing LLM models (view, download, delete) @@ -103,7 +121,9 @@ fun ModelsManagementScreen( } else { // Model cards LazyColumn( - modifier = Modifier.fillMaxSize().padding(horizontal = 16.dp), + modifier = Modifier + .fillMaxSize() + .padding(horizontal = 16.dp), verticalArrangement = Arrangement.spacedBy(12.dp), ) { items(items = filteredModels, key = { it.id }) { model -> @@ -134,20 +154,20 @@ fun ModelsManagementScreen( // Model import progress overlay when (val state = managementState) { is Importation.Confirming -> { - ImportProgressDialog( + ImportFromLocalFileDialog( fileName = state.fileName, fileSize = state.fileSize, isImporting = false, progress = 0.0f, onConfirm = { - viewModel.importLocalModelFile(state.uri, state.fileName, state.fileSize) + viewModel.importLocalModelFileConfirmed(state.uri, state.fileName, state.fileSize) }, onCancel = { viewModel.resetManagementState() } ) } is Importation.Importing -> { - ImportProgressDialog( + ImportFromLocalFileDialog( fileName = state.fileName, fileSize = state.fileSize, isImporting = true, @@ -178,6 +198,43 @@ fun ModelsManagementScreen( } } + is Download.Querying -> { + ImportFromHuggingFaceDialog( + onCancel = { viewModel.resetManagementState() } + ) + } + + is Download.Ready -> { + ImportFromHuggingFaceDialog( + models = state.models, + onConfirm = { viewModel.downloadHuggingFaceModelConfirmed(it) }, + onCancel = { viewModel.resetManagementState() } + ) + } + + is Download.Dispatched -> { + LaunchedEffect(state) { + onScaffoldEvent( + ScaffoldEvent.ShowSnackbar( + message = "Started downloading: ${state.downloadInfo.modelId}.\n" + + // TODO-han.yin: replace this with a broadcast receiver! + "Please come back to import it once completed.", + duration = SnackbarDuration.Long, + ) + ) + + viewModel.resetManagementState() + } + } + + is Download.Error -> { + ErrorDialog( + title = "Download Failed", + message = state.message, + onDismiss = { viewModel.resetManagementState() } + ) + } + is Deletion.Confirming -> { BatchDeleteConfirmationDialog( count = state.models.size, @@ -219,33 +276,12 @@ fun ModelsManagementScreen( } is ModelManagementState.Idle -> { /* Idle state, nothing to show */ } - - else -> TODO() - } - } - - // TODO-han.yin: UI TO BE IMPLEMENTED - val huggingFaceModelsFlow = viewModel.huggingFaceModels - LaunchedEffect(Unit) { - huggingFaceModelsFlow.collect { models -> - val message = models.fold( - StringBuilder("Fetched ${models.size} models from HuggingFace") - ) { builder, model -> - builder.append(model.id).append("\n") - }.toString() - - onScaffoldEvent( - ScaffoldEvent.ShowSnackbar( - message = message, - duration = SnackbarDuration.Short - ) - ) } } } @Composable -private fun ImportProgressDialog( +private fun ImportFromLocalFileDialog( fileName: String, fileSize: Long, isImporting: Boolean, @@ -344,6 +380,200 @@ private fun ImportProgressDialog( ) } +@Composable +private fun ImportFromHuggingFaceDialog( + models: List? = null, + onConfirm: ((HuggingFaceModel) -> Unit)? = null, + onCancel: () -> Unit, +) { + val dateFormatter = remember { SimpleDateFormat("MMM, yyyy", Locale.getDefault()) } + + var selectedModel by remember { mutableStateOf(null) } + + AlertDialog( + onDismissRequest = {}, + properties = DialogProperties( + dismissOnBackPress = true, + dismissOnClickOutside = true + ), + title = { + Text(models?.let { "Fetched ${it.size} models" } ?: "Fetching models") + }, + text = { + Column( + horizontalAlignment = Alignment.CenterHorizontally, + verticalArrangement = Arrangement.Center + ) { + Text( + modifier = Modifier.fillMaxWidth(), + text = models?.let { "Select a model to download:" } + ?: "Searching on HuggingFace for models available for direct download...", + style = MaterialTheme.typography.bodyLarge, + textAlign = TextAlign.Start, + ) + + if (models == null) { + Spacer(modifier = Modifier.height(24.dp)) + + CircularProgressIndicator( + modifier = Modifier.size(64.dp), + strokeWidth = 6.dp + ) + } else { + Spacer(modifier = Modifier.height(16.dp)) + + LazyColumn( + modifier = Modifier.fillMaxWidth(), + contentPadding = PaddingValues(vertical = 8.dp), + verticalArrangement = Arrangement.spacedBy(8.dp) + ) { + items(models) { model -> + HuggingFaceModelListItem( + model = model, + isSelected = model._id == selectedModel?._id, + dateFormatter = dateFormatter, + onToggleSelect = { selected -> + selectedModel = if (selected) model else null + } + ) + } + } + } + } + }, + confirmButton = { + onConfirm?.let { onSelect -> + TextButton( + onClick = { selectedModel?.let { onSelect.invoke(it) } }, + enabled = selectedModel != null + ) { + Text("Download") + } + } + }, + dismissButton = { + TextButton( + onClick = onCancel + ) { + Text("Cancel") + } + } + ) +} + +@Composable +fun HuggingFaceModelListItem( + model: HuggingFaceModel, + isSelected: Boolean, + dateFormatter: SimpleDateFormat, + onToggleSelect: (Boolean) -> Unit, +) { + Card( + modifier = Modifier.fillMaxWidth(), + colors = when (isSelected) { + true -> CardDefaults.cardColors( + containerColor = MaterialTheme.colorScheme.primaryContainer + ) + false -> CardDefaults.cardColors() + }, + onClick = { onToggleSelect(!isSelected) } + ) { + Column(modifier = Modifier.fillMaxWidth().padding(8.dp)) { + Text( + modifier = Modifier.basicMarquee(), + text = model.modelId.substringAfterLast("/"), + style = MaterialTheme.typography.bodyMedium, + fontWeight = if (isSelected) FontWeight.Bold else FontWeight.Medium, + ) + + Spacer(modifier = Modifier.size(8.dp)) + + Row(verticalAlignment = Alignment.Bottom) { + Column { + Row(verticalAlignment = Alignment.CenterVertically) { + Icon( + imageVector = Icons.Default.Attribution, + contentDescription = "Author", + modifier = Modifier.size(16.dp), + tint = MaterialTheme.colorScheme.onSurfaceVariant + ) + + Spacer(modifier = Modifier.size(4.dp)) + + Text( + text = model.author, + style = MaterialTheme.typography.bodyMedium, + color = MaterialTheme.colorScheme.onSurfaceVariant + ) + } + + Spacer(modifier = Modifier.size(8.dp)) + + Row(verticalAlignment = Alignment.CenterVertically) { + Icon( + imageVector = Icons.Default.Today, + contentDescription = "Author", + modifier = Modifier.size(16.dp), + tint = MaterialTheme.colorScheme.onSurfaceVariant + ) + + Spacer(modifier = Modifier.size(4.dp)) + + Text( + text = dateFormatter.format(model.lastModified), + style = MaterialTheme.typography.bodyMedium, + color = MaterialTheme.colorScheme.onSurfaceVariant + ) + + Spacer(modifier = Modifier.size(8.dp)) + + Icon( + imageVector = Icons.Default.Favorite, + contentDescription = "Favorite count", + modifier = Modifier.size(16.dp), + tint = MaterialTheme.colorScheme.onSurfaceVariant + ) + + Spacer(modifier = Modifier.size(4.dp)) + + Text( + text = formatContextLength(model.likes), + style = MaterialTheme.typography.bodyMedium, + color = MaterialTheme.colorScheme.onSurfaceVariant + ) + + Spacer(modifier = Modifier.size(8.dp)) + + Icon( + imageVector = Icons.Default.Download, + contentDescription = "Download count", + modifier = Modifier.size(16.dp), + tint = MaterialTheme.colorScheme.onSurfaceVariant + ) + + Spacer(modifier = Modifier.size(4.dp)) + + Text( + text = formatContextLength(model.downloads), + style = MaterialTheme.typography.bodyMedium, + color = MaterialTheme.colorScheme.onSurfaceVariant + ) + } + } + + Spacer(Modifier.weight(1f)) + + if (isSelected) { + Checkbox( + checked = isSelected, + onCheckedChange = null, // handled by parent selectable + ) + } + } + } + } +} + @Composable private fun BatchDeleteConfirmationDialog( count: Int, diff --git a/examples/llama.android/app/src/main/java/com/example/llama/viewmodel/ModelsManagementViewModel.kt b/examples/llama.android/app/src/main/java/com/example/llama/viewmodel/ModelsManagementViewModel.kt index 6f8f7cc562..59f14227ef 100644 --- a/examples/llama.android/app/src/main/java/com/example/llama/viewmodel/ModelsManagementViewModel.kt +++ b/examples/llama.android/app/src/main/java/com/example/llama/viewmodel/ModelsManagementViewModel.kt @@ -22,10 +22,10 @@ import com.example.llama.viewmodel.ModelManagementState.Download import com.example.llama.viewmodel.ModelManagementState.Importation import dagger.hilt.android.lifecycle.HiltViewModel import dagger.hilt.android.qualifiers.ApplicationContext +import kotlinx.coroutines.CancellationException +import kotlinx.coroutines.Job import kotlinx.coroutines.delay -import kotlinx.coroutines.flow.MutableSharedFlow import kotlinx.coroutines.flow.MutableStateFlow -import kotlinx.coroutines.flow.SharedFlow import kotlinx.coroutines.flow.StateFlow import kotlinx.coroutines.flow.asStateFlow import kotlinx.coroutines.flow.collectLatest @@ -34,7 +34,6 @@ import kotlinx.coroutines.flow.update import kotlinx.coroutines.launch import java.io.FileNotFoundException import javax.inject.Inject -import kotlin.collections.set @HiltViewModel class ModelsManagementViewModel @Inject constructor( @@ -133,9 +132,8 @@ class ModelsManagementViewModel @Inject constructor( _showImportModelMenu.value = show } - // UI state: HuggingFace models query result - private val _huggingFaceModels = MutableSharedFlow>() - val huggingFaceModels: SharedFlow> = _huggingFaceModels + // Ongoing coroutine jobs + private var huggingFaceQueryJob: Job? = null init { viewModelScope.launch { @@ -156,13 +154,16 @@ class ModelsManagementViewModel @Inject constructor( val managementState: StateFlow = _managementState.asStateFlow() fun resetManagementState() { + huggingFaceQueryJob?.let { + if (it.isActive) { it.cancel() } + } _managementState.value = ModelManagementState.Idle } /** * First show confirmation instead of starting import local file immediately */ - fun localModelFileSelected(uri: Uri) = viewModelScope.launch { + fun importLocalModelFileSelected(uri: Uri) = viewModelScope.launch { try { val fileName = getFileNameFromUri(context, uri) ?: throw FileNotFoundException("File size N/A") val fileSize = getFileSizeFromUri(context, uri) ?: throw FileNotFoundException("File name N/A") @@ -177,7 +178,7 @@ class ModelsManagementViewModel @Inject constructor( /** * Import a local model file from device storage while updating UI states with realtime progress */ - fun importLocalModelFile(uri: Uri, fileName: String, fileSize: Long) = viewModelScope.launch { + fun importLocalModelFileConfirmed(uri: Uri, fileName: String, fileSize: Long) = viewModelScope.launch { try { _managementState.value = Importation.Importing(0f, fileName, fileSize) val model = modelRepository.importModel(uri, fileName, fileSize) { progress -> @@ -221,29 +222,25 @@ class ModelsManagementViewModel @Inject constructor( /** * Query models on HuggingFace available for download even without signing in */ - fun queryModelsFromHuggingFace() = viewModelScope.launch { - modelRepository.searchHuggingFaceModels().let { models -> - _huggingFaceModels.emit(models) - Log.d(TAG, "Fetched ${models.size} models from HuggingFace:") - - // TODO-han.yin: remove these logs -// models.forEachIndexed { index, model -> -// Log.d(TAG, "#$index: $model") -// } + fun queryModelsFromHuggingFace() { + huggingFaceQueryJob = viewModelScope.launch { + _managementState.emit(Download.Querying) + try { + val models = modelRepository.searchHuggingFaceModels() + Log.d(TAG, "Fetched ${models.size} models from HuggingFace:") + _managementState.emit(Download.Ready(models)) + } catch (_: CancellationException) { + // no-op + } catch (e: Exception) { + _managementState.emit(Download.Error(message = e.message ?: "Unknown error")) + } } } - /** - * First show confirmation instead of dispatch download immediately - */ - fun downloadHuggingFaceModelSelected(model: HuggingFaceModel) { - _managementState.value = Download.Confirming(model) - } - /** * Dispatch download request to [DownloadManager] and update UI */ - fun downloadHuggingFaceModel(model: HuggingFaceModel) = viewModelScope.launch { + fun downloadHuggingFaceModelConfirmed(model: HuggingFaceModel) = viewModelScope.launch { try { require(!model.gated) { "Model is gated!" } require(!model.private) { "Model is private!" } @@ -324,7 +321,8 @@ sealed class ModelManagementState { } sealed class Download: ModelManagementState() { - data class Confirming(val model: HuggingFaceModel) : Download() + object Querying : Download() + data class Ready(val models: List) : Download() data class Dispatched(val downloadInfo: HuggingFaceDownloadInfo) : Download() data class Error(val message: String) : Download() }