data: allow canceling the ongoing model import

This commit is contained in:
Han Yin 2025-04-16 14:28:21 -07:00
parent d70b8fe323
commit 9ba74a9d3d
1 changed files with 107 additions and 9 deletions

View File

@ -16,7 +16,10 @@ import com.example.llama.revamp.util.extractQuantizationFromFilename
import com.example.llama.revamp.util.getFileNameFromUri import com.example.llama.revamp.util.getFileNameFromUri
import com.example.llama.revamp.util.getFileSizeFromUri import com.example.llama.revamp.util.getFileSizeFromUri
import dagger.hilt.android.qualifiers.ApplicationContext import dagger.hilt.android.qualifiers.ApplicationContext
import kotlinx.coroutines.CancellationException
import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.Job
import kotlinx.coroutines.cancel
import kotlinx.coroutines.delay import kotlinx.coroutines.delay
import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.flow import kotlinx.coroutines.flow.flow
@ -34,10 +37,16 @@ import javax.inject.Singleton
* Repository for managing available models on local device. * Repository for managing available models on local device.
*/ */
interface ModelRepository { interface ModelRepository {
/**
* Obtain the current status of local storage and available models.
*/
fun getStorageMetrics(): Flow<StorageMetrics> fun getStorageMetrics(): Flow<StorageMetrics>
fun getModels(): Flow<List<ModelInfo>> fun getModels(): Flow<List<ModelInfo>>
suspend fun getModelById(id: String): ModelInfo? suspend fun getModelById(id: String): ModelInfo?
/**
* Import a local model file from device storage.
*/
suspend fun importModel( suspend fun importModel(
uri: Uri, uri: Uri,
name: String? = null, name: String? = null,
@ -45,14 +54,29 @@ interface ModelRepository {
progressTracker: ImportProgressTracker? = null progressTracker: ImportProgressTracker? = null
): ModelInfo ): ModelInfo
suspend fun updateModelLastUsed(modelId: String)
suspend fun deleteModel(modelId: String)
suspend fun deleteModels(modelIds: List<String>)
fun interface ImportProgressTracker { fun interface ImportProgressTracker {
fun onProgress(progress: Float) // 0.0f to 1.0f fun onProgress(progress: Float) // 0.0f to 1.0f
} }
/**
* Cancels any ongoing local model import operation.
*
* @return null if no import is in progress,
* true if successfully canceled,
* false if cancellation failed
*/
suspend fun cancelImport(): Boolean?
/**
* Update a model's last used timestamp.
*/
suspend fun updateModelLastUsed(modelId: String)
/**
* Delete a model or in batches
*/
suspend fun deleteModel(modelId: String)
suspend fun deleteModels(modelIds: List<String>)
} }
@Singleton @Singleton
@ -93,16 +117,27 @@ class ModelRepositoryImpl @Inject constructor(
override suspend fun getModelById(id: String) = override suspend fun getModelById(id: String) =
modelDao.getModelById(id)?.toModelInfo() modelDao.getModelById(id)?.toModelInfo()
private var importJob: Job? = null
private var currentModelFile: File? = null
override suspend fun importModel( override suspend fun importModel(
uri: Uri, uri: Uri,
name: String?, name: String?,
size: Long?, size: Long?,
progressTracker: ImportProgressTracker? progressTracker: ImportProgressTracker?
): ModelInfo = withContext(Dispatchers.IO) { ): ModelInfo = withContext(Dispatchers.IO) {
if (importJob != null && importJob?.isActive == true) {
throw IllegalStateException("Another import is already in progress!")
}
val fileName = name ?: getFileNameFromUri(context, uri) ?: throw FileNotFoundException("Filename N/A") val fileName = name ?: getFileNameFromUri(context, uri) ?: throw FileNotFoundException("Filename N/A")
val fileSize = size ?: getFileSizeFromUri(context, uri) ?: throw FileNotFoundException("File size N/A") val fileSize = size ?: getFileSizeFromUri(context, uri) ?: throw FileNotFoundException("File size N/A")
val modelFile = File(modelsDir, fileName) val modelFile = File(modelsDir, fileName)
importJob = coroutineContext[Job]
currentModelFile = modelFile
try { try {
val inputStream = context.contentResolver.openInputStream(uri) val inputStream = context.contentResolver.openInputStream(uri)
?: throw IOException("Failed to open input stream") ?: throw IOException("Failed to open input stream")
@ -163,12 +198,74 @@ class ModelRepositoryImpl @Inject constructor(
dateAdded = System.currentTimeMillis() dateAdded = System.currentTimeMillis()
).let { ).let {
modelDao.insertModel(it) modelDao.insertModel(it)
importJob = null
currentModelFile = null
it.toModelInfo() it.toModelInfo()
} }
} catch (e: Exception) {
// Clean up partially downloaded file if error occurs } catch (e: CancellationException) {
if (modelFile.exists()) { modelFile.delete() } Log.i(TAG, "Import was cancelled for $fileName: ${e.message}")
cleanupPartialFile(modelFile)
throw e throw e
} catch (e: Exception) {
Log.e(TAG, "Import failed for $fileName: ${e.message}")
cleanupPartialFile(modelFile)
throw e
} finally {
importJob = null
currentModelFile = null
}
}
override suspend fun cancelImport(): Boolean? = withContext(Dispatchers.IO) {
val job = importJob
val file = currentModelFile
return@withContext when {
// No import in progress
job == null -> null
// Job already completed or cancelled
!job.isActive -> {
importJob = null
currentModelFile = null
null
}
// Job in progress
else -> try {
// Attempt to cancel the job
job.cancel("Import cancelled by user")
// Give the job a moment to clean up
delay(CANCEL_LOCAL_MODEL_IMPORT_TIMEOUT)
// Clean up the partial file (as a safety measure)
cleanupPartialFile(file)
// Reset state
importJob = null
currentModelFile = null
true // Successfully cancelled
} catch (e: Exception) {
Log.e(TAG, "Failed to cancel import: ${e.message}")
false
}
}
}
private fun cleanupPartialFile(file: File?) {
try {
if (file?.exists() == true && !file.delete()) {
Log.e(TAG, "Failed to delete partial file: ${file.absolutePath}")
}
} catch (e: Exception) {
Log.e(TAG, "Error cleaning up partial file: ${e.message}")
} }
} }
@ -212,6 +309,7 @@ class ModelRepositoryImpl @Inject constructor(
private const val INTERNAL_STORAGE_PATH = "models" private const val INTERNAL_STORAGE_PATH = "models"
private const val STORAGE_METRICS_UPDATE_INTERVAL = 5_000L
private const val BYTES_IN_GB = 1024f * 1024f * 1024f private const val BYTES_IN_GB = 1024f * 1024f * 1024f
private const val LARGE_MODEL_THRESHOLD_SIZE = 1024 * 1024 * 1024 private const val LARGE_MODEL_THRESHOLD_SIZE = 1024 * 1024 * 1024
@ -219,8 +317,8 @@ class ModelRepositoryImpl @Inject constructor(
private const val NIO_YIELD_SIZE = 128 * 1024 * 1024 private const val NIO_YIELD_SIZE = 128 * 1024 * 1024
private const val DEFAULT_BUFFER_SIZE = 4 * 1024 * 1024 private const val DEFAULT_BUFFER_SIZE = 4 * 1024 * 1024
private const val DEFAULT_YIELD_SIZE = 16 * 1024 * 1024 private const val DEFAULT_YIELD_SIZE = 16 * 1024 * 1024
private const val CANCEL_LOCAL_MODEL_IMPORT_TIMEOUT = 500L
private const val STORAGE_METRICS_UPDATE_INTERVAL = 5_000L
private const val DEFAULT_CONTEXT_SIZE = 8192 private const val DEFAULT_CONTEXT_SIZE = 8192
} }