data: allow canceling the ongoing model import
This commit is contained in:
parent
d70b8fe323
commit
9ba74a9d3d
|
|
@ -16,7 +16,10 @@ import com.example.llama.revamp.util.extractQuantizationFromFilename
|
|||
import com.example.llama.revamp.util.getFileNameFromUri
|
||||
import com.example.llama.revamp.util.getFileSizeFromUri
|
||||
import dagger.hilt.android.qualifiers.ApplicationContext
|
||||
import kotlinx.coroutines.CancellationException
|
||||
import kotlinx.coroutines.Dispatchers
|
||||
import kotlinx.coroutines.Job
|
||||
import kotlinx.coroutines.cancel
|
||||
import kotlinx.coroutines.delay
|
||||
import kotlinx.coroutines.flow.Flow
|
||||
import kotlinx.coroutines.flow.flow
|
||||
|
|
@ -34,10 +37,16 @@ import javax.inject.Singleton
|
|||
* Repository for managing available models on local device.
|
||||
*/
|
||||
interface ModelRepository {
|
||||
/**
|
||||
* Obtain the current status of local storage and available models.
|
||||
*/
|
||||
fun getStorageMetrics(): Flow<StorageMetrics>
|
||||
fun getModels(): Flow<List<ModelInfo>>
|
||||
suspend fun getModelById(id: String): ModelInfo?
|
||||
|
||||
/**
|
||||
* Import a local model file from device storage.
|
||||
*/
|
||||
suspend fun importModel(
|
||||
uri: Uri,
|
||||
name: String? = null,
|
||||
|
|
@ -45,14 +54,29 @@ interface ModelRepository {
|
|||
progressTracker: ImportProgressTracker? = null
|
||||
): ModelInfo
|
||||
|
||||
suspend fun updateModelLastUsed(modelId: String)
|
||||
|
||||
suspend fun deleteModel(modelId: String)
|
||||
suspend fun deleteModels(modelIds: List<String>)
|
||||
|
||||
fun interface ImportProgressTracker {
|
||||
fun onProgress(progress: Float) // 0.0f to 1.0f
|
||||
}
|
||||
|
||||
/**
|
||||
* Cancels any ongoing local model import operation.
|
||||
*
|
||||
* @return null if no import is in progress,
|
||||
* true if successfully canceled,
|
||||
* false if cancellation failed
|
||||
*/
|
||||
suspend fun cancelImport(): Boolean?
|
||||
|
||||
/**
|
||||
* Update a model's last used timestamp.
|
||||
*/
|
||||
suspend fun updateModelLastUsed(modelId: String)
|
||||
|
||||
/**
|
||||
* Delete a model or in batches
|
||||
*/
|
||||
suspend fun deleteModel(modelId: String)
|
||||
suspend fun deleteModels(modelIds: List<String>)
|
||||
}
|
||||
|
||||
@Singleton
|
||||
|
|
@ -93,16 +117,27 @@ class ModelRepositoryImpl @Inject constructor(
|
|||
override suspend fun getModelById(id: String) =
|
||||
modelDao.getModelById(id)?.toModelInfo()
|
||||
|
||||
|
||||
private var importJob: Job? = null
|
||||
private var currentModelFile: File? = null
|
||||
|
||||
override suspend fun importModel(
|
||||
uri: Uri,
|
||||
name: String?,
|
||||
size: Long?,
|
||||
progressTracker: ImportProgressTracker?
|
||||
): ModelInfo = withContext(Dispatchers.IO) {
|
||||
if (importJob != null && importJob?.isActive == true) {
|
||||
throw IllegalStateException("Another import is already in progress!")
|
||||
}
|
||||
|
||||
val fileName = name ?: getFileNameFromUri(context, uri) ?: throw FileNotFoundException("Filename N/A")
|
||||
val fileSize = size ?: getFileSizeFromUri(context, uri) ?: throw FileNotFoundException("File size N/A")
|
||||
val modelFile = File(modelsDir, fileName)
|
||||
|
||||
importJob = coroutineContext[Job]
|
||||
currentModelFile = modelFile
|
||||
|
||||
try {
|
||||
val inputStream = context.contentResolver.openInputStream(uri)
|
||||
?: throw IOException("Failed to open input stream")
|
||||
|
|
@ -163,12 +198,74 @@ class ModelRepositoryImpl @Inject constructor(
|
|||
dateAdded = System.currentTimeMillis()
|
||||
).let {
|
||||
modelDao.insertModel(it)
|
||||
|
||||
importJob = null
|
||||
currentModelFile = null
|
||||
|
||||
it.toModelInfo()
|
||||
}
|
||||
} catch (e: Exception) {
|
||||
// Clean up partially downloaded file if error occurs
|
||||
if (modelFile.exists()) { modelFile.delete() }
|
||||
|
||||
} catch (e: CancellationException) {
|
||||
Log.i(TAG, "Import was cancelled for $fileName: ${e.message}")
|
||||
cleanupPartialFile(modelFile)
|
||||
throw e
|
||||
|
||||
} catch (e: Exception) {
|
||||
Log.e(TAG, "Import failed for $fileName: ${e.message}")
|
||||
cleanupPartialFile(modelFile)
|
||||
throw e
|
||||
|
||||
} finally {
|
||||
importJob = null
|
||||
currentModelFile = null
|
||||
}
|
||||
}
|
||||
|
||||
override suspend fun cancelImport(): Boolean? = withContext(Dispatchers.IO) {
|
||||
val job = importJob
|
||||
val file = currentModelFile
|
||||
|
||||
return@withContext when {
|
||||
// No import in progress
|
||||
job == null -> null
|
||||
|
||||
// Job already completed or cancelled
|
||||
!job.isActive -> {
|
||||
importJob = null
|
||||
currentModelFile = null
|
||||
null
|
||||
}
|
||||
|
||||
// Job in progress
|
||||
else -> try {
|
||||
// Attempt to cancel the job
|
||||
job.cancel("Import cancelled by user")
|
||||
|
||||
// Give the job a moment to clean up
|
||||
delay(CANCEL_LOCAL_MODEL_IMPORT_TIMEOUT)
|
||||
|
||||
// Clean up the partial file (as a safety measure)
|
||||
cleanupPartialFile(file)
|
||||
|
||||
// Reset state
|
||||
importJob = null
|
||||
currentModelFile = null
|
||||
|
||||
true // Successfully cancelled
|
||||
} catch (e: Exception) {
|
||||
Log.e(TAG, "Failed to cancel import: ${e.message}")
|
||||
false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private fun cleanupPartialFile(file: File?) {
|
||||
try {
|
||||
if (file?.exists() == true && !file.delete()) {
|
||||
Log.e(TAG, "Failed to delete partial file: ${file.absolutePath}")
|
||||
}
|
||||
} catch (e: Exception) {
|
||||
Log.e(TAG, "Error cleaning up partial file: ${e.message}")
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -212,6 +309,7 @@ class ModelRepositoryImpl @Inject constructor(
|
|||
|
||||
private const val INTERNAL_STORAGE_PATH = "models"
|
||||
|
||||
private const val STORAGE_METRICS_UPDATE_INTERVAL = 5_000L
|
||||
private const val BYTES_IN_GB = 1024f * 1024f * 1024f
|
||||
|
||||
private const val LARGE_MODEL_THRESHOLD_SIZE = 1024 * 1024 * 1024
|
||||
|
|
@ -219,8 +317,8 @@ class ModelRepositoryImpl @Inject constructor(
|
|||
private const val NIO_YIELD_SIZE = 128 * 1024 * 1024
|
||||
private const val DEFAULT_BUFFER_SIZE = 4 * 1024 * 1024
|
||||
private const val DEFAULT_YIELD_SIZE = 16 * 1024 * 1024
|
||||
private const val CANCEL_LOCAL_MODEL_IMPORT_TIMEOUT = 500L
|
||||
|
||||
private const val STORAGE_METRICS_UPDATE_INTERVAL = 5_000L
|
||||
private const val DEFAULT_CONTEXT_SIZE = 8192
|
||||
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue