diff --git a/examples/llama.android/app/src/main/java/com/example/llama/data/remote/HuggingFaceRemoteDataSource.kt b/examples/llama.android/app/src/main/java/com/example/llama/data/remote/HuggingFaceRemoteDataSource.kt index f2d52ce9ef..6e6c1a6c3c 100644 --- a/examples/llama.android/app/src/main/java/com/example/llama/data/remote/HuggingFaceRemoteDataSource.kt +++ b/examples/llama.android/app/src/main/java/com/example/llama/data/remote/HuggingFaceRemoteDataSource.kt @@ -41,7 +41,7 @@ interface HuggingFaceRemoteDataSource { suspend fun downloadModelFile( context: Context, downloadInfo: HuggingFaceDownloadInfo, - ): Result + ): Result } @Singleton @@ -94,13 +94,13 @@ class HuggingFaceRemoteDataSourceImpl @Inject constructor( override suspend fun downloadModelFile( context: Context, downloadInfo: HuggingFaceDownloadInfo, - ): Result = withContext(Dispatchers.IO) { + ): Result = withContext(Dispatchers.IO) { try { val downloadManager = context.getSystemService(Context.DOWNLOAD_SERVICE) as DownloadManager val request = DownloadManager.Request(downloadInfo.uri).apply { - setTitle("HuggingFace model download") - setDescription("Downloading ${downloadInfo.filename}") + setTitle(downloadInfo.filename) + setDescription("Downloading directly from HuggingFace") setNotificationVisibility(DownloadManager.Request.VISIBILITY_VISIBLE_NOTIFY_COMPLETED) setDestinationInExternalPublicDir( Environment.DIRECTORY_DOWNLOADS, @@ -142,18 +142,18 @@ class HuggingFaceRemoteDataSourceImpl @Inject constructor( } else -> { // Download is pending, paused, or running - Result.success(Unit) + Result.success(downloadId) } } } else { // Assume success if we can't check status cursor.close() - Result.success(Unit) + Result.success(downloadId) } } else { // Assume success if cursor is empty cursor?.close() - Result.success(Unit) + Result.success(downloadId) } } catch (e: Exception) { Log.e(TAG, "Failed to enqueue download: ${e.message}") diff --git a/examples/llama.android/app/src/main/java/com/example/llama/data/repository/ModelRepository.kt b/examples/llama.android/app/src/main/java/com/example/llama/data/repository/ModelRepository.kt index a0efa04aea..4a398b332f 100644 --- a/examples/llama.android/app/src/main/java/com/example/llama/data/repository/ModelRepository.kt +++ b/examples/llama.android/app/src/main/java/com/example/llama/data/repository/ModelRepository.kt @@ -99,12 +99,12 @@ interface ModelRepository { suspend fun getHuggingFaceModelFileSize(downloadInfo: HuggingFaceDownloadInfo): Long? /** - * Download and import a HuggingFace model + * Download a HuggingFace model via system download manager */ - suspend fun importHuggingFaceModel( + suspend fun downloadHuggingFaceModel( downloadInfo: HuggingFaceDownloadInfo, actualSize: Long, - ): Result + ): Result } class InsufficientStorageException(message: String) : IOException(message) @@ -367,10 +367,10 @@ class ModelRepositoryImpl @Inject constructor( huggingFaceRemoteDataSource.getFileSize(downloadInfo.modelId, downloadInfo.filename) } - override suspend fun importHuggingFaceModel( + override suspend fun downloadHuggingFaceModel( downloadInfo: HuggingFaceDownloadInfo, actualSize: Long, - ): Result = withContext(Dispatchers.IO) { + ): Result = withContext(Dispatchers.IO) { if (!hasEnoughSpaceForImport(actualSize)) { throw InsufficientStorageException( "Not enough storage space! " + diff --git a/examples/llama.android/app/src/main/java/com/example/llama/ui/screens/ModelsManagementScreen.kt b/examples/llama.android/app/src/main/java/com/example/llama/ui/screens/ModelsManagementScreen.kt index 19a8e34750..caf6d45798 100644 --- a/examples/llama.android/app/src/main/java/com/example/llama/ui/screens/ModelsManagementScreen.kt +++ b/examples/llama.android/app/src/main/java/com/example/llama/ui/screens/ModelsManagementScreen.kt @@ -160,7 +160,9 @@ fun ModelsManagementScreen( isImporting = false, progress = 0.0f, onConfirm = { - viewModel.importLocalModelFileConfirmed(state.uri, state.fileName, state.fileSize) + viewModel.importLocalModelFileConfirmed( + state.uri, state.fileName, state.fileSize + ) }, onCancel = { viewModel.resetManagementState() } ) @@ -216,9 +218,7 @@ fun ModelsManagementScreen( LaunchedEffect(state) { onScaffoldEvent( ScaffoldEvent.ShowSnackbar( - message = "Started downloading: ${state.downloadInfo.modelId}.\n" + - // TODO-han.yin: replace this with a broadcast receiver! - "Please come back to import it once completed.", + message = "Started downloading:\n${state.downloadInfo.modelId}", duration = SnackbarDuration.Long, ) ) @@ -227,6 +227,21 @@ fun ModelsManagementScreen( } } + is Download.Completed -> { + ImportFromLocalFileDialog( + fileName = state.fileName, + fileSize = state.fileSize, + isImporting = false, + progress = 0.0f, + onConfirm = { + viewModel.importLocalModelFileConfirmed( + state.uri, state.fileName, state.fileSize + ) + }, + onCancel = { viewModel.resetManagementState() } + ) + } + is Download.Error -> { ErrorDialog( title = "Download Failed", diff --git a/examples/llama.android/app/src/main/java/com/example/llama/viewmodel/ModelsManagementViewModel.kt b/examples/llama.android/app/src/main/java/com/example/llama/viewmodel/ModelsManagementViewModel.kt index 59f14227ef..6223c0a601 100644 --- a/examples/llama.android/app/src/main/java/com/example/llama/viewmodel/ModelsManagementViewModel.kt +++ b/examples/llama.android/app/src/main/java/com/example/llama/viewmodel/ModelsManagementViewModel.kt @@ -1,6 +1,11 @@ package com.example.llama.viewmodel +import android.app.DownloadManager +import android.content.BroadcastReceiver import android.content.Context +import android.content.Context.RECEIVER_EXPORTED +import android.content.Intent +import android.content.IntentFilter import android.net.Uri import android.util.Log import androidx.lifecycle.ViewModel @@ -132,9 +137,21 @@ class ModelsManagementViewModel @Inject constructor( _showImportModelMenu.value = show } - // Ongoing coroutine jobs + // HuggingFace: ongoing query jobs private var huggingFaceQueryJob: Job? = null + // HuggingFace: Ongoing download jobs + private val activeDownloads = mutableMapOf() + private val downloadReceiver = object : BroadcastReceiver() { + override fun onReceive(context: Context, intent: Intent) { + intent.getLongExtra(DownloadManager.EXTRA_DOWNLOAD_ID, -1).let { id -> + if (id in activeDownloads) { + handleDownloadComplete(id) + } + } + } + } + init { viewModelScope.launch { combine( @@ -147,6 +164,9 @@ class ModelsManagementViewModel @Inject constructor( _filteredModels.value = it } } + + val filter = IntentFilter(DownloadManager.ACTION_DOWNLOAD_COMPLETE) + context.registerReceiver(downloadReceiver, filter, RECEIVER_EXPORTED) } // Internal state @@ -251,8 +271,9 @@ class ModelsManagementViewModel @Inject constructor( requireNotNull(actualSize) { "Unknown model file size!" } Log.d(TAG, "Model file size: ${formatFileByteSize(actualSize)}") - modelRepository.importHuggingFaceModel(downloadInfo, actualSize) - .onSuccess { + modelRepository.downloadHuggingFaceModel(downloadInfo, actualSize) + .onSuccess { downloadId -> + activeDownloads[downloadId] = model _managementState.value = Download.Dispatched(downloadInfo) } .onFailure { throw it } @@ -268,6 +289,23 @@ class ModelsManagementViewModel @Inject constructor( } } + private fun handleDownloadComplete(downloadId: Long) = viewModelScope.launch { + val model = activeDownloads.remove(downloadId) ?: return@launch + + (context.getSystemService(Context.DOWNLOAD_SERVICE) as DownloadManager) + .getUriForDownloadedFile(downloadId)?.let { uri -> + try { + val fileName = getFileNameFromUri(context, uri) ?: throw FileNotFoundException("File size N/A") + val fileSize = getFileSizeFromUri(context, uri) ?: throw FileNotFoundException("File name N/A") + _managementState.emit(Download.Completed(model, uri, fileName, fileSize)) + } catch (e: Exception) { + _managementState.value = Download.Error( + message = e.message ?: "Unknown error downloading ${model.modelId}" + ) + } + } + } + /** * First show confirmation instead of starting deletion immediately */ @@ -324,6 +362,7 @@ sealed class ModelManagementState { object Querying : Download() data class Ready(val models: List) : Download() data class Dispatched(val downloadInfo: HuggingFaceDownloadInfo) : Download() + data class Completed(val model: HuggingFaceModel, val uri: Uri, val fileName: String, val fileSize: Long) : Download() data class Error(val message: String) : Download() }