From 9ba74a9d3de24bc37b52414207fa35e0184aee65 Mon Sep 17 00:00:00 2001 From: Han Yin Date: Wed, 16 Apr 2025 14:28:21 -0700 Subject: [PATCH] data: allow canceling the ongoing model import --- .../revamp/data/repository/ModelRepository.kt | 116 ++++++++++++++++-- 1 file changed, 107 insertions(+), 9 deletions(-) diff --git a/examples/llama.android/app/src/main/java/com/example/llama/revamp/data/repository/ModelRepository.kt b/examples/llama.android/app/src/main/java/com/example/llama/revamp/data/repository/ModelRepository.kt index 40d6d33247..bc76ad8f9a 100644 --- a/examples/llama.android/app/src/main/java/com/example/llama/revamp/data/repository/ModelRepository.kt +++ b/examples/llama.android/app/src/main/java/com/example/llama/revamp/data/repository/ModelRepository.kt @@ -16,7 +16,10 @@ import com.example.llama.revamp.util.extractQuantizationFromFilename import com.example.llama.revamp.util.getFileNameFromUri import com.example.llama.revamp.util.getFileSizeFromUri import dagger.hilt.android.qualifiers.ApplicationContext +import kotlinx.coroutines.CancellationException import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.Job +import kotlinx.coroutines.cancel import kotlinx.coroutines.delay import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.flow @@ -34,10 +37,16 @@ import javax.inject.Singleton * Repository for managing available models on local device. */ interface ModelRepository { + /** + * Obtain the current status of local storage and available models. + */ fun getStorageMetrics(): Flow fun getModels(): Flow> suspend fun getModelById(id: String): ModelInfo? + /** + * Import a local model file from device storage. + */ suspend fun importModel( uri: Uri, name: String? = null, @@ -45,14 +54,29 @@ interface ModelRepository { progressTracker: ImportProgressTracker? = null ): ModelInfo - suspend fun updateModelLastUsed(modelId: String) - - suspend fun deleteModel(modelId: String) - suspend fun deleteModels(modelIds: List) - fun interface ImportProgressTracker { fun onProgress(progress: Float) // 0.0f to 1.0f } + + /** + * Cancels any ongoing local model import operation. + * + * @return null if no import is in progress, + * true if successfully canceled, + * false if cancellation failed + */ + suspend fun cancelImport(): Boolean? + + /** + * Update a model's last used timestamp. + */ + suspend fun updateModelLastUsed(modelId: String) + + /** + * Delete a model or in batches + */ + suspend fun deleteModel(modelId: String) + suspend fun deleteModels(modelIds: List) } @Singleton @@ -93,16 +117,27 @@ class ModelRepositoryImpl @Inject constructor( override suspend fun getModelById(id: String) = modelDao.getModelById(id)?.toModelInfo() + + private var importJob: Job? = null + private var currentModelFile: File? = null + override suspend fun importModel( uri: Uri, name: String?, size: Long?, progressTracker: ImportProgressTracker? ): ModelInfo = withContext(Dispatchers.IO) { + if (importJob != null && importJob?.isActive == true) { + throw IllegalStateException("Another import is already in progress!") + } + val fileName = name ?: getFileNameFromUri(context, uri) ?: throw FileNotFoundException("Filename N/A") val fileSize = size ?: getFileSizeFromUri(context, uri) ?: throw FileNotFoundException("File size N/A") val modelFile = File(modelsDir, fileName) + importJob = coroutineContext[Job] + currentModelFile = modelFile + try { val inputStream = context.contentResolver.openInputStream(uri) ?: throw IOException("Failed to open input stream") @@ -163,12 +198,74 @@ class ModelRepositoryImpl @Inject constructor( dateAdded = System.currentTimeMillis() ).let { modelDao.insertModel(it) + + importJob = null + currentModelFile = null + it.toModelInfo() } - } catch (e: Exception) { - // Clean up partially downloaded file if error occurs - if (modelFile.exists()) { modelFile.delete() } + + } catch (e: CancellationException) { + Log.i(TAG, "Import was cancelled for $fileName: ${e.message}") + cleanupPartialFile(modelFile) throw e + + } catch (e: Exception) { + Log.e(TAG, "Import failed for $fileName: ${e.message}") + cleanupPartialFile(modelFile) + throw e + + } finally { + importJob = null + currentModelFile = null + } + } + + override suspend fun cancelImport(): Boolean? = withContext(Dispatchers.IO) { + val job = importJob + val file = currentModelFile + + return@withContext when { + // No import in progress + job == null -> null + + // Job already completed or cancelled + !job.isActive -> { + importJob = null + currentModelFile = null + null + } + + // Job in progress + else -> try { + // Attempt to cancel the job + job.cancel("Import cancelled by user") + + // Give the job a moment to clean up + delay(CANCEL_LOCAL_MODEL_IMPORT_TIMEOUT) + + // Clean up the partial file (as a safety measure) + cleanupPartialFile(file) + + // Reset state + importJob = null + currentModelFile = null + + true // Successfully cancelled + } catch (e: Exception) { + Log.e(TAG, "Failed to cancel import: ${e.message}") + false + } + } + } + + private fun cleanupPartialFile(file: File?) { + try { + if (file?.exists() == true && !file.delete()) { + Log.e(TAG, "Failed to delete partial file: ${file.absolutePath}") + } + } catch (e: Exception) { + Log.e(TAG, "Error cleaning up partial file: ${e.message}") } } @@ -212,6 +309,7 @@ class ModelRepositoryImpl @Inject constructor( private const val INTERNAL_STORAGE_PATH = "models" + private const val STORAGE_METRICS_UPDATE_INTERVAL = 5_000L private const val BYTES_IN_GB = 1024f * 1024f * 1024f private const val LARGE_MODEL_THRESHOLD_SIZE = 1024 * 1024 * 1024 @@ -219,8 +317,8 @@ class ModelRepositoryImpl @Inject constructor( private const val NIO_YIELD_SIZE = 128 * 1024 * 1024 private const val DEFAULT_BUFFER_SIZE = 4 * 1024 * 1024 private const val DEFAULT_YIELD_SIZE = 16 * 1024 * 1024 + private const val CANCEL_LOCAL_MODEL_IMPORT_TIMEOUT = 500L - private const val STORAGE_METRICS_UPDATE_INTERVAL = 5_000L private const val DEFAULT_CONTEXT_SIZE = 8192 }