data: extract local file info, copy, and cleanup logic into LocalFileDataSource

This commit is contained in:
Han Yin 2025-07-08 12:43:02 -07:00
parent 33d1e24ac4
commit 6f901e5203
4 changed files with 153 additions and 72 deletions

View File

@ -9,17 +9,14 @@ import com.example.llama.data.db.dao.ModelDao
import com.example.llama.data.db.entity.ModelEntity
import com.example.llama.data.model.GgufMetadata
import com.example.llama.data.model.ModelInfo
import com.example.llama.data.repo.ModelRepository.ImportProgressTracker
import com.example.llama.data.source.local.LocalFileDataSource
import com.example.llama.data.source.remote.HuggingFaceDownloadInfo
import com.example.llama.data.source.remote.HuggingFaceModel
import com.example.llama.data.source.remote.HuggingFaceModelDetails
import com.example.llama.data.source.remote.HuggingFaceRemoteDataSource
import com.example.llama.data.repo.ModelRepository.ImportProgressTracker
import com.example.llama.monitoring.StorageMetrics
import com.example.llama.util.copyWithBuffer
import com.example.llama.util.copyWithChannels
import com.example.llama.util.formatFileByteSize
import com.example.llama.util.getFileNameFromUri
import com.example.llama.util.getFileSizeFromUri
import dagger.hilt.android.qualifiers.ApplicationContext
import kotlinx.coroutines.CancellationException
import kotlinx.coroutines.Dispatchers
@ -32,7 +29,6 @@ import kotlinx.coroutines.flow.map
import kotlinx.coroutines.withContext
import java.io.File
import java.io.FileNotFoundException
import java.io.FileOutputStream
import java.io.IOException
import java.util.UUID
import javax.inject.Inject
@ -113,6 +109,7 @@ class InsufficientStorageException(message: String) : IOException(message)
class ModelRepositoryImpl @Inject constructor(
@ApplicationContext private val context: Context,
private val modelDao: ModelDao,
private val localFileDataSource: LocalFileDataSource,
private val huggingFaceRemoteDataSource: HuggingFaceRemoteDataSource,
private val ggufMetadataReader: GgufMetadataReader,
) : ModelRepository {
@ -174,7 +171,9 @@ class ModelRepositoryImpl @Inject constructor(
throw IllegalStateException("Another import is already in progress!")
}
val fileSize = size ?: getFileSizeFromUri(context, uri) ?: throw FileNotFoundException("File size N/A")
val fileInfo = localFileDataSource.getFileInfo(uri)
val fileSize = size ?: fileInfo?.size ?: throw FileNotFoundException("File size N/A")
val fileName = name ?: fileInfo?.name ?: throw FileNotFoundException("File name N/A")
if (!hasEnoughSpaceForImport(fileSize)) {
throw InsufficientStorageException(
"Not enough storage space! " +
@ -182,53 +181,24 @@ class ModelRepositoryImpl @Inject constructor(
"Available: ${formatFileByteSize(availableSpaceBytes)}"
)
}
val fileName = name ?: getFileNameFromUri(context, uri) ?: throw FileNotFoundException("Filename N/A")
val modelFile = File(modelsDir, fileName)
importJob = coroutineContext[Job]
currentModelFile = modelFile
try {
val inputStream = context.contentResolver.openInputStream(uri)
?: throw IOException("Failed to open input stream")
val outputStream = FileOutputStream(modelFile)
if (fileSize > LARGE_MODEL_THRESHOLD_SIZE) {
Log.i(TAG, "Copying $fileName (size: $fileSize) via NIO...")
// Use NIO channels for large models
copyWithChannels(
input = inputStream,
output = outputStream,
totalSize = fileSize,
bufferSize = NIO_BUFFER_SIZE,
yieldSize = NIO_YIELD_SIZE
) { progress ->
localFileDataSource.copyFile(
sourceUri = uri,
destinationFile = modelFile,
fileSize = fileSize,
onProgress = { progress ->
progressTracker?.let {
withContext(Dispatchers.Main) {
it.onProgress(progress)
}
}
}
} else {
Log.i(TAG, "Copying $fileName (size: $fileSize) via buffer...")
// Default copy with buffer for small models
copyWithBuffer(
input = inputStream,
output = outputStream,
totalSize = fileSize,
bufferSize = DEFAULT_BUFFER_SIZE,
yieldSize = DEFAULT_YIELD_SIZE
) { progress ->
progressTracker?.let {
withContext(Dispatchers.Main) {
it.onProgress(progress)
}
}
}
}
).getOrThrow()
// Extract GGUF metadata if possible
val metadata = try {
@ -256,12 +226,12 @@ class ModelRepositoryImpl @Inject constructor(
} catch (e: CancellationException) {
Log.i(TAG, "Import was cancelled for $fileName: ${e.message}")
cleanupPartialFile(modelFile)
localFileDataSource.cleanupPartialFile(modelFile)
throw e
} catch (e: Exception) {
Log.e(TAG, "Import failed for $fileName: ${e.message}")
cleanupPartialFile(modelFile)
localFileDataSource.cleanupPartialFile(modelFile)
throw e
} finally {
@ -300,8 +270,8 @@ class ModelRepositoryImpl @Inject constructor(
// Give the job a moment to clean up
delay(CANCEL_LOCAL_MODEL_IMPORT_TIMEOUT)
// Clean up the partial file (as a safety measure)
cleanupPartialFile(file)
// Clean up the partial file
file?.let { localFileDataSource.cleanupPartialFile(it) }
// Reset state
importJob = null
@ -315,16 +285,6 @@ class ModelRepositoryImpl @Inject constructor(
}
}
/**
 * Best-effort removal of a partially written [file].
 *
 * A `null` or non-existent file is ignored. A failed delete, or any exception raised while
 * checking/deleting, is logged with [Log.e] and never propagated to the caller.
 */
private fun cleanupPartialFile(file: File?) {
try {
// Only report an error when the file is present but could not be removed
if (file?.exists() == true && !file.delete()) {
Log.e(TAG, "Failed to delete partial file: ${file.absolutePath}")
}
} catch (e: Exception) {
Log.e(TAG, "Error cleaning up partial file: ${e.message}")
}
}
/**
 * Records "now" as the last-used time of the model identified by [modelId].
 *
 * The timestamp is taken from [System.currentTimeMillis] and written through
 * `modelDao.updateLastUsed` on [Dispatchers.IO].
 */
override suspend fun updateModelLastUsed(modelId: String) = withContext(Dispatchers.IO) {
    val nowMillis = System.currentTimeMillis()
    modelDao.updateLastUsed(modelId, nowMillis)
}
@ -399,11 +359,6 @@ class ModelRepositoryImpl @Inject constructor(
private const val BYTES_IN_GB = 1024f * 1024f * 1024f
private const val MODEL_IMPORT_SPACE_BUFFER_SCALE = 1.2f
private const val LARGE_MODEL_THRESHOLD_SIZE = 1024 * 1024 * 1024
private const val NIO_BUFFER_SIZE = 32 * 1024 * 1024
private const val NIO_YIELD_SIZE = 128 * 1024 * 1024
private const val DEFAULT_BUFFER_SIZE = 4 * 1024 * 1024
private const val DEFAULT_YIELD_SIZE = 16 * 1024 * 1024
private const val CANCEL_LOCAL_MODEL_IMPORT_TIMEOUT = 500L
}
}

View File

@ -0,0 +1,118 @@
package com.example.llama.data.source.local
import android.content.Context
import android.net.Uri
import android.util.Log
import com.example.llama.data.source.local.LocalFileDataSource.FileInfo
import com.example.llama.util.copyWithBuffer
import com.example.llama.util.copyWithChannels
import com.example.llama.util.getFileNameFromUri
import com.example.llama.util.getFileSizeFromUri
import dagger.hilt.android.qualifiers.ApplicationContext
import kotlinx.coroutines.CancellationException
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.withContext
import java.io.File
import java.io.FileOutputStream
import java.io.IOException
import javax.inject.Inject
import javax.inject.Singleton
/**
 * Data source for files reachable through a content [Uri] and for files on local storage.
 */
interface LocalFileDataSource {

    /**
     * Name and size of a file, as reported by [getFileInfo].
     */
    data class FileInfo(val name: String, val size: Long)

    /**
     * Resolve the file name and size for the given [uri].
     *
     * @return the resolved [FileInfo], or `null` when it is not available
     */
    suspend fun getFileInfo(uri: Uri): FileInfo?

    /**
     * Copy the content behind [sourceUri] into [destinationFile].
     *
     * @param sourceUri content location to read from
     * @param destinationFile local file to write into
     * @param fileSize total size of the source
     * @param onProgress optional suspending callback that receives copy progress
     * @return a [Result] holding [destinationFile] on success, or the failure cause otherwise
     */
    suspend fun copyFile(
        sourceUri: Uri,
        destinationFile: File,
        fileSize: Long,
        onProgress: (suspend (Float) -> Unit)? = null,
    ): Result<File>

    /**
     * Remove an incomplete [file] left behind by an unfinished import.
     *
     * @return whether the file was deleted
     */
    suspend fun cleanupPartialFile(file: File): Boolean
}
/**
 * [LocalFileDataSource] implementation that reads sources through the application
 * [Context]'s content resolver and writes to plain [File]s.
 */
@Singleton
class LocalFileDataSourceImpl @Inject constructor(
    @ApplicationContext private val context: Context
) : LocalFileDataSource {

    /**
     * Copies the content behind [sourceUri] into [destinationFile] on [Dispatchers.IO].
     *
     * Sources larger than [LARGE_MODEL_THRESHOLD_SIZE] bytes are copied with [copyWithChannels];
     * everything else goes through [copyWithBuffer].
     *
     * @param sourceUri content location opened via the content resolver
     * @param destinationFile file that receives the copied bytes
     * @param fileSize total size of the source; selects the copy strategy and is forwarded to
     *   the copy helper as `totalSize`
     * @param onProgress optional callback forwarded unchanged to the copy helper
     * @return [Result.success] with [destinationFile], or [Result.failure] carrying the cause
     *   (the partially written destination is removed first)
     * @throws CancellationException when the copy is cancelled; the partial destination is
     *   removed before the exception is rethrown
     */
    override suspend fun copyFile(
        sourceUri: Uri,
        destinationFile: File,
        fileSize: Long,
        onProgress: (suspend (Float) -> Unit)?
    ): Result<File> = withContext(Dispatchers.IO) {
        try {
            val inputStream = context.contentResolver.openInputStream(sourceUri)
                ?: return@withContext Result.failure(IOException("Failed to open input stream"))

            // `use` guarantees both streams are closed on every exit path, including the case
            // where opening the destination fails after the source was already opened.
            inputStream.use { input ->
                FileOutputStream(destinationFile).use { output ->
                    if (fileSize > LARGE_MODEL_THRESHOLD_SIZE) {
                        // Use NIO channels for large models
                        Log.i(TAG, "Copying ${destinationFile.name} (size: $fileSize) via NIO...")
                        copyWithChannels(
                            input = input,
                            output = output,
                            totalSize = fileSize,
                            bufferSize = NIO_BUFFER_SIZE,
                            yieldSize = NIO_YIELD_SIZE,
                            onProgress = onProgress
                        )
                    } else {
                        // Default copy with buffer for small models
                        Log.i(TAG, "Copying ${destinationFile.name} (size: $fileSize) via buffer...")
                        copyWithBuffer(
                            input = input,
                            output = output,
                            totalSize = fileSize,
                            bufferSize = DEFAULT_BUFFER_SIZE,
                            yieldSize = DEFAULT_YIELD_SIZE,
                            onProgress = onProgress
                        )
                    }
                }
            }

            Result.success(destinationFile)
        } catch (e: CancellationException) {
            // Cancellation must propagate as an exception rather than be wrapped in a Result,
            // otherwise structured concurrency cannot see it.
            deleteIfExists(destinationFile)
            throw e
        } catch (e: Exception) {
            deleteIfExists(destinationFile)
            Result.failure(e)
        }
    }

    /**
     * Looks up the file name and size for [uri] via [getFileNameFromUri] and
     * [getFileSizeFromUri].
     *
     * @return a [FileInfo] when both values are available, `null` otherwise
     */
    override suspend fun getFileInfo(uri: Uri): FileInfo? {
        val name = getFileNameFromUri(context, uri)
        val size = getFileSizeFromUri(context, uri)
        return if (name != null && size != null) {
            FileInfo(name, size)
        } else null
    }

    /**
     * Deletes [file] if it exists.
     *
     * @return `true` only when the file existed and was deleted; `false` when it did not
     *   exist, could not be deleted, or the attempt threw
     */
    override suspend fun cleanupPartialFile(file: File): Boolean = deleteIfExists(file)

    /**
     * Deletes [file] when present, logging (never throwing) on failure so that cleanup cannot
     * mask the error that triggered it.
     */
    private fun deleteIfExists(file: File): Boolean =
        try {
            when {
                !file.exists() -> false
                file.delete() -> true
                else -> {
                    Log.e(TAG, "Failed to delete file: ${file.absolutePath}")
                    false
                }
            }
        } catch (e: Exception) {
            Log.e(TAG, "Failed to delete file: ${e.message}")
            false
        }

    companion object {
        private val TAG = LocalFileDataSourceImpl::class.java.simpleName

        // Sources above this many bytes are copied via NIO channels
        private const val LARGE_MODEL_THRESHOLD_SIZE = 1024 * 1024 * 1024

        // Buffer / yield sizes handed to copyWithChannels
        private const val NIO_BUFFER_SIZE = 32 * 1024 * 1024
        private const val NIO_YIELD_SIZE = 128 * 1024 * 1024

        // Buffer / yield sizes handed to copyWithBuffer
        private const val DEFAULT_BUFFER_SIZE = 4 * 1024 * 1024
        private const val DEFAULT_YIELD_SIZE = 16 * 1024 * 1024
    }
}

View File

@ -6,14 +6,16 @@ import android.llama.cpp.KleidiLlama
import android.llama.cpp.TierDetection
import android.llama.cpp.gguf.GgufMetadataReader
import com.example.llama.data.db.AppDatabase
import com.example.llama.data.source.remote.GatedTypeAdapter
import com.example.llama.data.source.remote.HuggingFaceApiService
import com.example.llama.data.source.remote.HuggingFaceRemoteDataSource
import com.example.llama.data.source.remote.HuggingFaceRemoteDataSourceImpl
import com.example.llama.data.repo.ModelRepository
import com.example.llama.data.repo.ModelRepositoryImpl
import com.example.llama.data.repo.SystemPromptRepository
import com.example.llama.data.repo.SystemPromptRepositoryImpl
import com.example.llama.data.source.local.LocalFileDataSource
import com.example.llama.data.source.local.LocalFileDataSourceImpl
import com.example.llama.data.source.remote.GatedTypeAdapter
import com.example.llama.data.source.remote.HuggingFaceApiService
import com.example.llama.data.source.remote.HuggingFaceRemoteDataSource
import com.example.llama.data.source.remote.HuggingFaceRemoteDataSourceImpl
import com.example.llama.engine.BenchmarkService
import com.example.llama.engine.ConversationService
import com.example.llama.engine.InferenceService
@ -56,15 +58,16 @@ internal abstract class AppModule {
abstract fun bindConversationService(impl: InferenceServiceImpl) : ConversationService
@Binds
abstract fun bindsModelsRepository(impl: ModelRepositoryImpl): ModelRepository
abstract fun bindModelsRepository(impl: ModelRepositoryImpl): ModelRepository
@Binds
abstract fun bindsSystemPromptRepository(impl: SystemPromptRepositoryImpl): SystemPromptRepository
abstract fun bindSystemPromptRepository(impl: SystemPromptRepositoryImpl): SystemPromptRepository
@Binds
abstract fun bindHuggingFaceRemoteDataSource(
impl: HuggingFaceRemoteDataSourceImpl
): HuggingFaceRemoteDataSource
abstract fun bindLocalFileDataSource(impl: LocalFileDataSourceImpl) : LocalFileDataSource
@Binds
abstract fun bindHuggingFaceRemoteDataSource(impl: HuggingFaceRemoteDataSourceImpl): HuggingFaceRemoteDataSource
companion object {
const val USE_STUB_ENGINE = false

View File

@ -55,20 +55,25 @@ suspend fun copyWithChannels(
val buffer = ByteBuffer.allocateDirect(bufferSize)
var totalBytesRead = 0L
var yieldCounter = 0L
while (inChannel.read(buffer) != -1) {
buffer.flip()
while (buffer.hasRemaining()) {
outChannel.write(buffer)
}
totalBytesRead += buffer.position()
val bytesRead = buffer.position()
totalBytesRead += bytesRead
yieldCounter += bytesRead
buffer.clear()
// Report progress
onProgress?.invoke(totalBytesRead.toFloat() / totalSize)
if (totalBytesRead % (yieldSize) == 0L) {
if (yieldCounter >= yieldSize) {
yield()
yieldCounter = 0L
}
}