data: allow canceling the ongoing model import

This commit is contained in:
Han Yin 2025-04-16 14:28:21 -07:00
parent d70b8fe323
commit 9ba74a9d3d
1 changed files with 107 additions and 9 deletions

View File

@ -16,7 +16,10 @@ import com.example.llama.revamp.util.extractQuantizationFromFilename
import com.example.llama.revamp.util.getFileNameFromUri
import com.example.llama.revamp.util.getFileSizeFromUri
import dagger.hilt.android.qualifiers.ApplicationContext
import kotlinx.coroutines.CancellationException
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.Job
import kotlinx.coroutines.cancel
import kotlinx.coroutines.delay
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.flow
@ -34,10 +37,16 @@ import javax.inject.Singleton
* Repository for managing available models on local device.
*/
interface ModelRepository {
/**
* Obtain the current status of local storage and available models.
*/
fun getStorageMetrics(): Flow<StorageMetrics>
fun getModels(): Flow<List<ModelInfo>>
suspend fun getModelById(id: String): ModelInfo?
/**
* Import a local model file from device storage.
*/
suspend fun importModel(
uri: Uri,
name: String? = null,
@ -45,14 +54,29 @@ interface ModelRepository {
progressTracker: ImportProgressTracker? = null
): ModelInfo
suspend fun updateModelLastUsed(modelId: String)
suspend fun deleteModel(modelId: String)
suspend fun deleteModels(modelIds: List<String>)
fun interface ImportProgressTracker {
fun onProgress(progress: Float) // 0.0f to 1.0f
}
/**
* Cancels any ongoing local model import operation.
*
* @return null if no import is in progress,
* true if successfully canceled,
* false if cancellation failed
*/
suspend fun cancelImport(): Boolean?
/**
* Update a model's last used timestamp.
*/
suspend fun updateModelLastUsed(modelId: String)
/**
* Delete a model or in batches
*/
suspend fun deleteModel(modelId: String)
suspend fun deleteModels(modelIds: List<String>)
}
@Singleton
@ -93,16 +117,27 @@ class ModelRepositoryImpl @Inject constructor(
override suspend fun getModelById(id: String) =
modelDao.getModelById(id)?.toModelInfo()
private var importJob: Job? = null
private var currentModelFile: File? = null
override suspend fun importModel(
uri: Uri,
name: String?,
size: Long?,
progressTracker: ImportProgressTracker?
): ModelInfo = withContext(Dispatchers.IO) {
if (importJob != null && importJob?.isActive == true) {
throw IllegalStateException("Another import is already in progress!")
}
val fileName = name ?: getFileNameFromUri(context, uri) ?: throw FileNotFoundException("Filename N/A")
val fileSize = size ?: getFileSizeFromUri(context, uri) ?: throw FileNotFoundException("File size N/A")
val modelFile = File(modelsDir, fileName)
importJob = coroutineContext[Job]
currentModelFile = modelFile
try {
val inputStream = context.contentResolver.openInputStream(uri)
?: throw IOException("Failed to open input stream")
@ -163,12 +198,74 @@ class ModelRepositoryImpl @Inject constructor(
dateAdded = System.currentTimeMillis()
).let {
modelDao.insertModel(it)
importJob = null
currentModelFile = null
it.toModelInfo()
}
} catch (e: Exception) {
// Clean up partially downloaded file if error occurs
if (modelFile.exists()) { modelFile.delete() }
} catch (e: CancellationException) {
Log.i(TAG, "Import was cancelled for $fileName: ${e.message}")
cleanupPartialFile(modelFile)
throw e
} catch (e: Exception) {
Log.e(TAG, "Import failed for $fileName: ${e.message}")
cleanupPartialFile(modelFile)
throw e
} finally {
importJob = null
currentModelFile = null
}
}
override suspend fun cancelImport(): Boolean? = withContext(Dispatchers.IO) {
val job = importJob
val file = currentModelFile
return@withContext when {
// No import in progress
job == null -> null
// Job already completed or cancelled
!job.isActive -> {
importJob = null
currentModelFile = null
null
}
// Job in progress
else -> try {
// Attempt to cancel the job
job.cancel("Import cancelled by user")
// Give the job a moment to clean up
delay(CANCEL_LOCAL_MODEL_IMPORT_TIMEOUT)
// Clean up the partial file (as a safety measure)
cleanupPartialFile(file)
// Reset state
importJob = null
currentModelFile = null
true // Successfully cancelled
} catch (e: Exception) {
Log.e(TAG, "Failed to cancel import: ${e.message}")
false
}
}
}
private fun cleanupPartialFile(file: File?) {
try {
if (file?.exists() == true && !file.delete()) {
Log.e(TAG, "Failed to delete partial file: ${file.absolutePath}")
}
} catch (e: Exception) {
Log.e(TAG, "Error cleaning up partial file: ${e.message}")
}
}
@ -212,6 +309,7 @@ class ModelRepositoryImpl @Inject constructor(
private const val INTERNAL_STORAGE_PATH = "models"
private const val STORAGE_METRICS_UPDATE_INTERVAL = 5_000L
private const val BYTES_IN_GB = 1024f * 1024f * 1024f
private const val LARGE_MODEL_THRESHOLD_SIZE = 1024 * 1024 * 1024
@ -219,8 +317,8 @@ class ModelRepositoryImpl @Inject constructor(
private const val NIO_YIELD_SIZE = 128 * 1024 * 1024
private const val DEFAULT_BUFFER_SIZE = 4 * 1024 * 1024
private const val DEFAULT_YIELD_SIZE = 16 * 1024 * 1024
private const val CANCEL_LOCAL_MODEL_IMPORT_TIMEOUT = 500L
private const val STORAGE_METRICS_UPDATE_INTERVAL = 5_000L
private const val DEFAULT_CONTEXT_SIZE = 8192
}