UI: use a broadcast receiver to listen for download complete events and show local import dialog.
This commit is contained in:
parent
f085d39c05
commit
85434e6580
|
|
@ -41,7 +41,7 @@ interface HuggingFaceRemoteDataSource {
|
|||
suspend fun downloadModelFile(
|
||||
context: Context,
|
||||
downloadInfo: HuggingFaceDownloadInfo,
|
||||
): Result<Unit>
|
||||
): Result<Long>
|
||||
}
|
||||
|
||||
@Singleton
|
||||
|
|
@ -94,13 +94,13 @@ class HuggingFaceRemoteDataSourceImpl @Inject constructor(
|
|||
override suspend fun downloadModelFile(
|
||||
context: Context,
|
||||
downloadInfo: HuggingFaceDownloadInfo,
|
||||
): Result<Unit> = withContext(Dispatchers.IO) {
|
||||
): Result<Long> = withContext(Dispatchers.IO) {
|
||||
try {
|
||||
val downloadManager =
|
||||
context.getSystemService(Context.DOWNLOAD_SERVICE) as DownloadManager
|
||||
val request = DownloadManager.Request(downloadInfo.uri).apply {
|
||||
setTitle("HuggingFace model download")
|
||||
setDescription("Downloading ${downloadInfo.filename}")
|
||||
setTitle(downloadInfo.filename)
|
||||
setDescription("Downloading directly from HuggingFace")
|
||||
setNotificationVisibility(DownloadManager.Request.VISIBILITY_VISIBLE_NOTIFY_COMPLETED)
|
||||
setDestinationInExternalPublicDir(
|
||||
Environment.DIRECTORY_DOWNLOADS,
|
||||
|
|
@ -142,18 +142,18 @@ class HuggingFaceRemoteDataSourceImpl @Inject constructor(
|
|||
}
|
||||
else -> {
|
||||
// Download is pending, paused, or running
|
||||
Result.success(Unit)
|
||||
Result.success(downloadId)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Assume success if we can't check status
|
||||
cursor.close()
|
||||
Result.success(Unit)
|
||||
Result.success(downloadId)
|
||||
}
|
||||
} else {
|
||||
// Assume success if cursor is empty
|
||||
cursor?.close()
|
||||
Result.success(Unit)
|
||||
Result.success(downloadId)
|
||||
}
|
||||
} catch (e: Exception) {
|
||||
Log.e(TAG, "Failed to enqueue download: ${e.message}")
|
||||
|
|
|
|||
|
|
@ -99,12 +99,12 @@ interface ModelRepository {
|
|||
suspend fun getHuggingFaceModelFileSize(downloadInfo: HuggingFaceDownloadInfo): Long?
|
||||
|
||||
/**
|
||||
* Download and import a HuggingFace model
|
||||
* Download a HuggingFace model via system download manager
|
||||
*/
|
||||
suspend fun importHuggingFaceModel(
|
||||
suspend fun downloadHuggingFaceModel(
|
||||
downloadInfo: HuggingFaceDownloadInfo,
|
||||
actualSize: Long,
|
||||
): Result<Unit>
|
||||
): Result<Long>
|
||||
}
|
||||
|
||||
class InsufficientStorageException(message: String) : IOException(message)
|
||||
|
|
@ -367,10 +367,10 @@ class ModelRepositoryImpl @Inject constructor(
|
|||
huggingFaceRemoteDataSource.getFileSize(downloadInfo.modelId, downloadInfo.filename)
|
||||
}
|
||||
|
||||
override suspend fun importHuggingFaceModel(
|
||||
override suspend fun downloadHuggingFaceModel(
|
||||
downloadInfo: HuggingFaceDownloadInfo,
|
||||
actualSize: Long,
|
||||
): Result<Unit> = withContext(Dispatchers.IO) {
|
||||
): Result<Long> = withContext(Dispatchers.IO) {
|
||||
if (!hasEnoughSpaceForImport(actualSize)) {
|
||||
throw InsufficientStorageException(
|
||||
"Not enough storage space! " +
|
||||
|
|
|
|||
|
|
@ -160,7 +160,9 @@ fun ModelsManagementScreen(
|
|||
isImporting = false,
|
||||
progress = 0.0f,
|
||||
onConfirm = {
|
||||
viewModel.importLocalModelFileConfirmed(state.uri, state.fileName, state.fileSize)
|
||||
viewModel.importLocalModelFileConfirmed(
|
||||
state.uri, state.fileName, state.fileSize
|
||||
)
|
||||
},
|
||||
onCancel = { viewModel.resetManagementState() }
|
||||
)
|
||||
|
|
@ -216,9 +218,7 @@ fun ModelsManagementScreen(
|
|||
LaunchedEffect(state) {
|
||||
onScaffoldEvent(
|
||||
ScaffoldEvent.ShowSnackbar(
|
||||
message = "Started downloading: ${state.downloadInfo.modelId}.\n" +
|
||||
// TODO-han.yin: replace this with a broadcast receiver!
|
||||
"Please come back to import it once completed.",
|
||||
message = "Started downloading:\n${state.downloadInfo.modelId}",
|
||||
duration = SnackbarDuration.Long,
|
||||
)
|
||||
)
|
||||
|
|
@ -227,6 +227,21 @@ fun ModelsManagementScreen(
|
|||
}
|
||||
}
|
||||
|
||||
is Download.Completed -> {
|
||||
ImportFromLocalFileDialog(
|
||||
fileName = state.fileName,
|
||||
fileSize = state.fileSize,
|
||||
isImporting = false,
|
||||
progress = 0.0f,
|
||||
onConfirm = {
|
||||
viewModel.importLocalModelFileConfirmed(
|
||||
state.uri, state.fileName, state.fileSize
|
||||
)
|
||||
},
|
||||
onCancel = { viewModel.resetManagementState() }
|
||||
)
|
||||
}
|
||||
|
||||
is Download.Error -> {
|
||||
ErrorDialog(
|
||||
title = "Download Failed",
|
||||
|
|
|
|||
|
|
@ -1,6 +1,11 @@
|
|||
package com.example.llama.viewmodel
|
||||
|
||||
import android.app.DownloadManager
|
||||
import android.content.BroadcastReceiver
|
||||
import android.content.Context
|
||||
import android.content.Context.RECEIVER_EXPORTED
|
||||
import android.content.Intent
|
||||
import android.content.IntentFilter
|
||||
import android.net.Uri
|
||||
import android.util.Log
|
||||
import androidx.lifecycle.ViewModel
|
||||
|
|
@ -132,9 +137,21 @@ class ModelsManagementViewModel @Inject constructor(
|
|||
_showImportModelMenu.value = show
|
||||
}
|
||||
|
||||
// Ongoing coroutine jobs
|
||||
// HuggingFace: ongoing query jobs
|
||||
private var huggingFaceQueryJob: Job? = null
|
||||
|
||||
// HuggingFace: Ongoing download jobs
|
||||
private val activeDownloads = mutableMapOf<Long, HuggingFaceModel>()
|
||||
private val downloadReceiver = object : BroadcastReceiver() {
|
||||
override fun onReceive(context: Context, intent: Intent) {
|
||||
intent.getLongExtra(DownloadManager.EXTRA_DOWNLOAD_ID, -1).let { id ->
|
||||
if (id in activeDownloads) {
|
||||
handleDownloadComplete(id)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
init {
|
||||
viewModelScope.launch {
|
||||
combine(
|
||||
|
|
@ -147,6 +164,9 @@ class ModelsManagementViewModel @Inject constructor(
|
|||
_filteredModels.value = it
|
||||
}
|
||||
}
|
||||
|
||||
val filter = IntentFilter(DownloadManager.ACTION_DOWNLOAD_COMPLETE)
|
||||
context.registerReceiver(downloadReceiver, filter, RECEIVER_EXPORTED)
|
||||
}
|
||||
|
||||
// Internal state
|
||||
|
|
@ -251,8 +271,9 @@ class ModelsManagementViewModel @Inject constructor(
|
|||
requireNotNull(actualSize) { "Unknown model file size!" }
|
||||
Log.d(TAG, "Model file size: ${formatFileByteSize(actualSize)}")
|
||||
|
||||
modelRepository.importHuggingFaceModel(downloadInfo, actualSize)
|
||||
.onSuccess {
|
||||
modelRepository.downloadHuggingFaceModel(downloadInfo, actualSize)
|
||||
.onSuccess { downloadId ->
|
||||
activeDownloads[downloadId] = model
|
||||
_managementState.value = Download.Dispatched(downloadInfo)
|
||||
}
|
||||
.onFailure { throw it }
|
||||
|
|
@ -268,6 +289,23 @@ class ModelsManagementViewModel @Inject constructor(
|
|||
}
|
||||
}
|
||||
|
||||
private fun handleDownloadComplete(downloadId: Long) = viewModelScope.launch {
|
||||
val model = activeDownloads.remove(downloadId) ?: return@launch
|
||||
|
||||
(context.getSystemService(Context.DOWNLOAD_SERVICE) as DownloadManager)
|
||||
.getUriForDownloadedFile(downloadId)?.let { uri ->
|
||||
try {
|
||||
val fileName = getFileNameFromUri(context, uri) ?: throw FileNotFoundException("File size N/A")
|
||||
val fileSize = getFileSizeFromUri(context, uri) ?: throw FileNotFoundException("File name N/A")
|
||||
_managementState.emit(Download.Completed(model, uri, fileName, fileSize))
|
||||
} catch (e: Exception) {
|
||||
_managementState.value = Download.Error(
|
||||
message = e.message ?: "Unknown error downloading ${model.modelId}"
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* First show confirmation instead of starting deletion immediately
|
||||
*/
|
||||
|
|
@ -324,6 +362,7 @@ sealed class ModelManagementState {
|
|||
object Querying : Download()
|
||||
data class Ready(val models: List<HuggingFaceModel>) : Download()
|
||||
data class Dispatched(val downloadInfo: HuggingFaceDownloadInfo) : Download()
|
||||
data class Completed(val model: HuggingFaceModel, val uri: Uri, val fileName: String, val fileSize: Long) : Download()
|
||||
data class Error(val message: String) : Download()
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue