data: allow canceling the ongoing model import
This commit is contained in:
parent
d70b8fe323
commit
9ba74a9d3d
|
|
@ -16,7 +16,10 @@ import com.example.llama.revamp.util.extractQuantizationFromFilename
|
||||||
import com.example.llama.revamp.util.getFileNameFromUri
|
import com.example.llama.revamp.util.getFileNameFromUri
|
||||||
import com.example.llama.revamp.util.getFileSizeFromUri
|
import com.example.llama.revamp.util.getFileSizeFromUri
|
||||||
import dagger.hilt.android.qualifiers.ApplicationContext
|
import dagger.hilt.android.qualifiers.ApplicationContext
|
||||||
|
import kotlinx.coroutines.CancellationException
|
||||||
import kotlinx.coroutines.Dispatchers
|
import kotlinx.coroutines.Dispatchers
|
||||||
|
import kotlinx.coroutines.Job
|
||||||
|
import kotlinx.coroutines.cancel
|
||||||
import kotlinx.coroutines.delay
|
import kotlinx.coroutines.delay
|
||||||
import kotlinx.coroutines.flow.Flow
|
import kotlinx.coroutines.flow.Flow
|
||||||
import kotlinx.coroutines.flow.flow
|
import kotlinx.coroutines.flow.flow
|
||||||
|
|
@ -34,10 +37,16 @@ import javax.inject.Singleton
|
||||||
* Repository for managing available models on local device.
|
* Repository for managing available models on local device.
|
||||||
*/
|
*/
|
||||||
interface ModelRepository {
|
interface ModelRepository {
|
||||||
|
/**
|
||||||
|
* Obtain the current status of local storage and available models.
|
||||||
|
*/
|
||||||
fun getStorageMetrics(): Flow<StorageMetrics>
|
fun getStorageMetrics(): Flow<StorageMetrics>
|
||||||
fun getModels(): Flow<List<ModelInfo>>
|
fun getModels(): Flow<List<ModelInfo>>
|
||||||
suspend fun getModelById(id: String): ModelInfo?
|
suspend fun getModelById(id: String): ModelInfo?
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Import a local model file from device storage.
|
||||||
|
*/
|
||||||
suspend fun importModel(
|
suspend fun importModel(
|
||||||
uri: Uri,
|
uri: Uri,
|
||||||
name: String? = null,
|
name: String? = null,
|
||||||
|
|
@ -45,14 +54,29 @@ interface ModelRepository {
|
||||||
progressTracker: ImportProgressTracker? = null
|
progressTracker: ImportProgressTracker? = null
|
||||||
): ModelInfo
|
): ModelInfo
|
||||||
|
|
||||||
suspend fun updateModelLastUsed(modelId: String)
|
|
||||||
|
|
||||||
suspend fun deleteModel(modelId: String)
|
|
||||||
suspend fun deleteModels(modelIds: List<String>)
|
|
||||||
|
|
||||||
fun interface ImportProgressTracker {
|
fun interface ImportProgressTracker {
|
||||||
fun onProgress(progress: Float) // 0.0f to 1.0f
|
fun onProgress(progress: Float) // 0.0f to 1.0f
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Cancels any ongoing local model import operation.
|
||||||
|
*
|
||||||
|
* @return null if no import is in progress,
|
||||||
|
* true if successfully canceled,
|
||||||
|
* false if cancellation failed
|
||||||
|
*/
|
||||||
|
suspend fun cancelImport(): Boolean?
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Update a model's last used timestamp.
|
||||||
|
*/
|
||||||
|
suspend fun updateModelLastUsed(modelId: String)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Delete a model or in batches
|
||||||
|
*/
|
||||||
|
suspend fun deleteModel(modelId: String)
|
||||||
|
suspend fun deleteModels(modelIds: List<String>)
|
||||||
}
|
}
|
||||||
|
|
||||||
@Singleton
|
@Singleton
|
||||||
|
|
@ -93,16 +117,27 @@ class ModelRepositoryImpl @Inject constructor(
|
||||||
override suspend fun getModelById(id: String) =
|
override suspend fun getModelById(id: String) =
|
||||||
modelDao.getModelById(id)?.toModelInfo()
|
modelDao.getModelById(id)?.toModelInfo()
|
||||||
|
|
||||||
|
|
||||||
|
private var importJob: Job? = null
|
||||||
|
private var currentModelFile: File? = null
|
||||||
|
|
||||||
override suspend fun importModel(
|
override suspend fun importModel(
|
||||||
uri: Uri,
|
uri: Uri,
|
||||||
name: String?,
|
name: String?,
|
||||||
size: Long?,
|
size: Long?,
|
||||||
progressTracker: ImportProgressTracker?
|
progressTracker: ImportProgressTracker?
|
||||||
): ModelInfo = withContext(Dispatchers.IO) {
|
): ModelInfo = withContext(Dispatchers.IO) {
|
||||||
|
if (importJob != null && importJob?.isActive == true) {
|
||||||
|
throw IllegalStateException("Another import is already in progress!")
|
||||||
|
}
|
||||||
|
|
||||||
val fileName = name ?: getFileNameFromUri(context, uri) ?: throw FileNotFoundException("Filename N/A")
|
val fileName = name ?: getFileNameFromUri(context, uri) ?: throw FileNotFoundException("Filename N/A")
|
||||||
val fileSize = size ?: getFileSizeFromUri(context, uri) ?: throw FileNotFoundException("File size N/A")
|
val fileSize = size ?: getFileSizeFromUri(context, uri) ?: throw FileNotFoundException("File size N/A")
|
||||||
val modelFile = File(modelsDir, fileName)
|
val modelFile = File(modelsDir, fileName)
|
||||||
|
|
||||||
|
importJob = coroutineContext[Job]
|
||||||
|
currentModelFile = modelFile
|
||||||
|
|
||||||
try {
|
try {
|
||||||
val inputStream = context.contentResolver.openInputStream(uri)
|
val inputStream = context.contentResolver.openInputStream(uri)
|
||||||
?: throw IOException("Failed to open input stream")
|
?: throw IOException("Failed to open input stream")
|
||||||
|
|
@ -163,12 +198,74 @@ class ModelRepositoryImpl @Inject constructor(
|
||||||
dateAdded = System.currentTimeMillis()
|
dateAdded = System.currentTimeMillis()
|
||||||
).let {
|
).let {
|
||||||
modelDao.insertModel(it)
|
modelDao.insertModel(it)
|
||||||
|
|
||||||
|
importJob = null
|
||||||
|
currentModelFile = null
|
||||||
|
|
||||||
it.toModelInfo()
|
it.toModelInfo()
|
||||||
}
|
}
|
||||||
} catch (e: Exception) {
|
|
||||||
// Clean up partially downloaded file if error occurs
|
} catch (e: CancellationException) {
|
||||||
if (modelFile.exists()) { modelFile.delete() }
|
Log.i(TAG, "Import was cancelled for $fileName: ${e.message}")
|
||||||
|
cleanupPartialFile(modelFile)
|
||||||
throw e
|
throw e
|
||||||
|
|
||||||
|
} catch (e: Exception) {
|
||||||
|
Log.e(TAG, "Import failed for $fileName: ${e.message}")
|
||||||
|
cleanupPartialFile(modelFile)
|
||||||
|
throw e
|
||||||
|
|
||||||
|
} finally {
|
||||||
|
importJob = null
|
||||||
|
currentModelFile = null
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
override suspend fun cancelImport(): Boolean? = withContext(Dispatchers.IO) {
|
||||||
|
val job = importJob
|
||||||
|
val file = currentModelFile
|
||||||
|
|
||||||
|
return@withContext when {
|
||||||
|
// No import in progress
|
||||||
|
job == null -> null
|
||||||
|
|
||||||
|
// Job already completed or cancelled
|
||||||
|
!job.isActive -> {
|
||||||
|
importJob = null
|
||||||
|
currentModelFile = null
|
||||||
|
null
|
||||||
|
}
|
||||||
|
|
||||||
|
// Job in progress
|
||||||
|
else -> try {
|
||||||
|
// Attempt to cancel the job
|
||||||
|
job.cancel("Import cancelled by user")
|
||||||
|
|
||||||
|
// Give the job a moment to clean up
|
||||||
|
delay(CANCEL_LOCAL_MODEL_IMPORT_TIMEOUT)
|
||||||
|
|
||||||
|
// Clean up the partial file (as a safety measure)
|
||||||
|
cleanupPartialFile(file)
|
||||||
|
|
||||||
|
// Reset state
|
||||||
|
importJob = null
|
||||||
|
currentModelFile = null
|
||||||
|
|
||||||
|
true // Successfully cancelled
|
||||||
|
} catch (e: Exception) {
|
||||||
|
Log.e(TAG, "Failed to cancel import: ${e.message}")
|
||||||
|
false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private fun cleanupPartialFile(file: File?) {
|
||||||
|
try {
|
||||||
|
if (file?.exists() == true && !file.delete()) {
|
||||||
|
Log.e(TAG, "Failed to delete partial file: ${file.absolutePath}")
|
||||||
|
}
|
||||||
|
} catch (e: Exception) {
|
||||||
|
Log.e(TAG, "Error cleaning up partial file: ${e.message}")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -212,6 +309,7 @@ class ModelRepositoryImpl @Inject constructor(
|
||||||
|
|
||||||
private const val INTERNAL_STORAGE_PATH = "models"
|
private const val INTERNAL_STORAGE_PATH = "models"
|
||||||
|
|
||||||
|
private const val STORAGE_METRICS_UPDATE_INTERVAL = 5_000L
|
||||||
private const val BYTES_IN_GB = 1024f * 1024f * 1024f
|
private const val BYTES_IN_GB = 1024f * 1024f * 1024f
|
||||||
|
|
||||||
private const val LARGE_MODEL_THRESHOLD_SIZE = 1024 * 1024 * 1024
|
private const val LARGE_MODEL_THRESHOLD_SIZE = 1024 * 1024 * 1024
|
||||||
|
|
@ -219,8 +317,8 @@ class ModelRepositoryImpl @Inject constructor(
|
||||||
private const val NIO_YIELD_SIZE = 128 * 1024 * 1024
|
private const val NIO_YIELD_SIZE = 128 * 1024 * 1024
|
||||||
private const val DEFAULT_BUFFER_SIZE = 4 * 1024 * 1024
|
private const val DEFAULT_BUFFER_SIZE = 4 * 1024 * 1024
|
||||||
private const val DEFAULT_YIELD_SIZE = 16 * 1024 * 1024
|
private const val DEFAULT_YIELD_SIZE = 16 * 1024 * 1024
|
||||||
|
private const val CANCEL_LOCAL_MODEL_IMPORT_TIMEOUT = 500L
|
||||||
|
|
||||||
private const val STORAGE_METRICS_UPDATE_INTERVAL = 5_000L
|
|
||||||
private const val DEFAULT_CONTEXT_SIZE = 8192
|
private const val DEFAULT_CONTEXT_SIZE = 8192
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue