data: extract local file info, copy and cleanup logics into LocalFileDataSource
This commit is contained in:
parent
33d1e24ac4
commit
6f901e5203
|
|
@ -9,17 +9,14 @@ import com.example.llama.data.db.dao.ModelDao
|
|||
import com.example.llama.data.db.entity.ModelEntity
|
||||
import com.example.llama.data.model.GgufMetadata
|
||||
import com.example.llama.data.model.ModelInfo
|
||||
import com.example.llama.data.repo.ModelRepository.ImportProgressTracker
|
||||
import com.example.llama.data.source.local.LocalFileDataSource
|
||||
import com.example.llama.data.source.remote.HuggingFaceDownloadInfo
|
||||
import com.example.llama.data.source.remote.HuggingFaceModel
|
||||
import com.example.llama.data.source.remote.HuggingFaceModelDetails
|
||||
import com.example.llama.data.source.remote.HuggingFaceRemoteDataSource
|
||||
import com.example.llama.data.repo.ModelRepository.ImportProgressTracker
|
||||
import com.example.llama.monitoring.StorageMetrics
|
||||
import com.example.llama.util.copyWithBuffer
|
||||
import com.example.llama.util.copyWithChannels
|
||||
import com.example.llama.util.formatFileByteSize
|
||||
import com.example.llama.util.getFileNameFromUri
|
||||
import com.example.llama.util.getFileSizeFromUri
|
||||
import dagger.hilt.android.qualifiers.ApplicationContext
|
||||
import kotlinx.coroutines.CancellationException
|
||||
import kotlinx.coroutines.Dispatchers
|
||||
|
|
@ -32,7 +29,6 @@ import kotlinx.coroutines.flow.map
|
|||
import kotlinx.coroutines.withContext
|
||||
import java.io.File
|
||||
import java.io.FileNotFoundException
|
||||
import java.io.FileOutputStream
|
||||
import java.io.IOException
|
||||
import java.util.UUID
|
||||
import javax.inject.Inject
|
||||
|
|
@ -113,6 +109,7 @@ class InsufficientStorageException(message: String) : IOException(message)
|
|||
class ModelRepositoryImpl @Inject constructor(
|
||||
@ApplicationContext private val context: Context,
|
||||
private val modelDao: ModelDao,
|
||||
private val localFileDataSource: LocalFileDataSource,
|
||||
private val huggingFaceRemoteDataSource: HuggingFaceRemoteDataSource,
|
||||
private val ggufMetadataReader: GgufMetadataReader,
|
||||
) : ModelRepository {
|
||||
|
|
@ -174,7 +171,9 @@ class ModelRepositoryImpl @Inject constructor(
|
|||
throw IllegalStateException("Another import is already in progress!")
|
||||
}
|
||||
|
||||
val fileSize = size ?: getFileSizeFromUri(context, uri) ?: throw FileNotFoundException("File size N/A")
|
||||
val fileInfo = localFileDataSource.getFileInfo(uri)
|
||||
val fileSize = size ?: fileInfo?.size ?: throw FileNotFoundException("File size N/A")
|
||||
val fileName = name ?: fileInfo?.name ?: throw FileNotFoundException("File name N/A")
|
||||
if (!hasEnoughSpaceForImport(fileSize)) {
|
||||
throw InsufficientStorageException(
|
||||
"Not enough storage space! " +
|
||||
|
|
@ -182,53 +181,24 @@ class ModelRepositoryImpl @Inject constructor(
|
|||
"Available: ${formatFileByteSize(availableSpaceBytes)}"
|
||||
)
|
||||
}
|
||||
|
||||
val fileName = name ?: getFileNameFromUri(context, uri) ?: throw FileNotFoundException("Filename N/A")
|
||||
val modelFile = File(modelsDir, fileName)
|
||||
|
||||
importJob = coroutineContext[Job]
|
||||
currentModelFile = modelFile
|
||||
|
||||
try {
|
||||
val inputStream = context.contentResolver.openInputStream(uri)
|
||||
?: throw IOException("Failed to open input stream")
|
||||
val outputStream = FileOutputStream(modelFile)
|
||||
|
||||
if (fileSize > LARGE_MODEL_THRESHOLD_SIZE) {
|
||||
Log.i(TAG, "Copying $fileName (size: $fileSize) via NIO...")
|
||||
|
||||
// Use NIO channels for large models
|
||||
copyWithChannels(
|
||||
input = inputStream,
|
||||
output = outputStream,
|
||||
totalSize = fileSize,
|
||||
bufferSize = NIO_BUFFER_SIZE,
|
||||
yieldSize = NIO_YIELD_SIZE
|
||||
) { progress ->
|
||||
localFileDataSource.copyFile(
|
||||
sourceUri = uri,
|
||||
destinationFile = modelFile,
|
||||
fileSize = fileSize,
|
||||
onProgress = { progress ->
|
||||
progressTracker?.let {
|
||||
withContext(Dispatchers.Main) {
|
||||
it.onProgress(progress)
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
Log.i(TAG, "Copying $fileName (size: $fileSize) via buffer...")
|
||||
|
||||
// Default copy with buffer for small models
|
||||
copyWithBuffer(
|
||||
input = inputStream,
|
||||
output = outputStream,
|
||||
totalSize = fileSize,
|
||||
bufferSize = DEFAULT_BUFFER_SIZE,
|
||||
yieldSize = DEFAULT_YIELD_SIZE
|
||||
) { progress ->
|
||||
progressTracker?.let {
|
||||
withContext(Dispatchers.Main) {
|
||||
it.onProgress(progress)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
).getOrThrow()
|
||||
|
||||
// Extract GGUF metadata if possible
|
||||
val metadata = try {
|
||||
|
|
@ -256,12 +226,12 @@ class ModelRepositoryImpl @Inject constructor(
|
|||
|
||||
} catch (e: CancellationException) {
|
||||
Log.i(TAG, "Import was cancelled for $fileName: ${e.message}")
|
||||
cleanupPartialFile(modelFile)
|
||||
localFileDataSource.cleanupPartialFile(modelFile)
|
||||
throw e
|
||||
|
||||
} catch (e: Exception) {
|
||||
Log.e(TAG, "Import failed for $fileName: ${e.message}")
|
||||
cleanupPartialFile(modelFile)
|
||||
localFileDataSource.cleanupPartialFile(modelFile)
|
||||
throw e
|
||||
|
||||
} finally {
|
||||
|
|
@ -300,8 +270,8 @@ class ModelRepositoryImpl @Inject constructor(
|
|||
// Give the job a moment to clean up
|
||||
delay(CANCEL_LOCAL_MODEL_IMPORT_TIMEOUT)
|
||||
|
||||
// Clean up the partial file (as a safety measure)
|
||||
cleanupPartialFile(file)
|
||||
// Clean up the partial file
|
||||
file?.let { localFileDataSource.cleanupPartialFile(it) }
|
||||
|
||||
// Reset state
|
||||
importJob = null
|
||||
|
|
@ -315,16 +285,6 @@ class ModelRepositoryImpl @Inject constructor(
|
|||
}
|
||||
}
|
||||
|
||||
private fun cleanupPartialFile(file: File?) {
|
||||
try {
|
||||
if (file?.exists() == true && !file.delete()) {
|
||||
Log.e(TAG, "Failed to delete partial file: ${file.absolutePath}")
|
||||
}
|
||||
} catch (e: Exception) {
|
||||
Log.e(TAG, "Error cleaning up partial file: ${e.message}")
|
||||
}
|
||||
}
|
||||
|
||||
override suspend fun updateModelLastUsed(modelId: String) = withContext(Dispatchers.IO) {
|
||||
modelDao.updateLastUsed(modelId, System.currentTimeMillis())
|
||||
}
|
||||
|
|
@ -399,11 +359,6 @@ class ModelRepositoryImpl @Inject constructor(
|
|||
private const val BYTES_IN_GB = 1024f * 1024f * 1024f
|
||||
|
||||
private const val MODEL_IMPORT_SPACE_BUFFER_SCALE = 1.2f
|
||||
private const val LARGE_MODEL_THRESHOLD_SIZE = 1024 * 1024 * 1024
|
||||
private const val NIO_BUFFER_SIZE = 32 * 1024 * 1024
|
||||
private const val NIO_YIELD_SIZE = 128 * 1024 * 1024
|
||||
private const val DEFAULT_BUFFER_SIZE = 4 * 1024 * 1024
|
||||
private const val DEFAULT_YIELD_SIZE = 16 * 1024 * 1024
|
||||
private const val CANCEL_LOCAL_MODEL_IMPORT_TIMEOUT = 500L
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,118 @@
|
|||
package com.example.llama.data.source.local
|
||||
|
||||
import android.content.Context
|
||||
import android.net.Uri
|
||||
import android.util.Log
|
||||
import com.example.llama.data.source.local.LocalFileDataSource.FileInfo
|
||||
import com.example.llama.util.copyWithBuffer
|
||||
import com.example.llama.util.copyWithChannels
|
||||
import com.example.llama.util.getFileNameFromUri
|
||||
import com.example.llama.util.getFileSizeFromUri
|
||||
import dagger.hilt.android.qualifiers.ApplicationContext
|
||||
import kotlinx.coroutines.Dispatchers
|
||||
import kotlinx.coroutines.withContext
|
||||
import java.io.File
|
||||
import java.io.FileOutputStream
|
||||
import java.io.IOException
|
||||
import javax.inject.Inject
|
||||
import javax.inject.Singleton
|
||||
|
||||
interface LocalFileDataSource {
|
||||
/**
|
||||
* Copy local file from [sourceUri] into [destinationFile]
|
||||
*/
|
||||
suspend fun copyFile(
|
||||
sourceUri: Uri,
|
||||
destinationFile: File,
|
||||
fileSize: Long,
|
||||
onProgress: (suspend (Float) -> Unit)? = null
|
||||
): Result<File>
|
||||
|
||||
/**
|
||||
* Obtain the file name and size from given [uri]
|
||||
*/
|
||||
suspend fun getFileInfo(uri: Uri): FileInfo?
|
||||
|
||||
/**
|
||||
* Clean up incomplete file due to unfinished import
|
||||
*/
|
||||
suspend fun cleanupPartialFile(file: File): Boolean
|
||||
|
||||
data class FileInfo(val name: String, val size: Long)
|
||||
}
|
||||
|
||||
@Singleton
|
||||
class LocalFileDataSourceImpl @Inject constructor(
|
||||
@ApplicationContext private val context: Context
|
||||
) : LocalFileDataSource {
|
||||
|
||||
override suspend fun copyFile(
|
||||
sourceUri: Uri,
|
||||
destinationFile: File,
|
||||
fileSize: Long,
|
||||
onProgress: (suspend (Float) -> Unit)?
|
||||
): Result<File> = withContext(Dispatchers.IO) {
|
||||
try {
|
||||
val inputStream = context.contentResolver.openInputStream(sourceUri)
|
||||
?: return@withContext Result.failure(IOException("Failed to open input stream"))
|
||||
val outputStream = FileOutputStream(destinationFile)
|
||||
|
||||
if (fileSize > LARGE_MODEL_THRESHOLD_SIZE) {
|
||||
// Use NIO channels for large models
|
||||
Log.i(TAG, "Copying ${destinationFile.name} (size: $fileSize) via NIO...")
|
||||
copyWithChannels(
|
||||
input = inputStream,
|
||||
output = outputStream,
|
||||
totalSize = fileSize,
|
||||
bufferSize = NIO_BUFFER_SIZE,
|
||||
yieldSize = NIO_YIELD_SIZE,
|
||||
onProgress = onProgress
|
||||
)
|
||||
} else {
|
||||
// Default copy with buffer for small models
|
||||
Log.i(TAG, "Copying ${destinationFile.name} (size: $fileSize) via buffer...")
|
||||
copyWithBuffer(
|
||||
input = inputStream,
|
||||
output = outputStream,
|
||||
totalSize = fileSize,
|
||||
bufferSize = DEFAULT_BUFFER_SIZE,
|
||||
yieldSize = DEFAULT_YIELD_SIZE,
|
||||
onProgress = onProgress
|
||||
)
|
||||
}
|
||||
|
||||
Result.success(destinationFile)
|
||||
} catch (e: Exception) {
|
||||
if (destinationFile.exists()) {
|
||||
destinationFile.delete()
|
||||
}
|
||||
Result.failure(e)
|
||||
}
|
||||
}
|
||||
|
||||
override suspend fun getFileInfo(uri: Uri): FileInfo? {
|
||||
val name = getFileNameFromUri(context, uri)
|
||||
val size = getFileSizeFromUri(context, uri)
|
||||
return if (name != null && size != null) {
|
||||
FileInfo(name, size)
|
||||
} else null
|
||||
}
|
||||
|
||||
override suspend fun cleanupPartialFile(file: File): Boolean =
|
||||
try {
|
||||
if (file.exists()) file.delete() else false
|
||||
} catch (e: Exception) {
|
||||
Log.e(TAG, "Failed to delete file: ${e.message}")
|
||||
false
|
||||
}
|
||||
|
||||
companion object {
|
||||
private val TAG = LocalFileDataSourceImpl::class.java.simpleName
|
||||
|
||||
private const val LARGE_MODEL_THRESHOLD_SIZE = 1024 * 1024 * 1024
|
||||
private const val NIO_BUFFER_SIZE = 32 * 1024 * 1024
|
||||
private const val NIO_YIELD_SIZE = 128 * 1024 * 1024
|
||||
private const val DEFAULT_BUFFER_SIZE = 4 * 1024 * 1024
|
||||
private const val DEFAULT_YIELD_SIZE = 16 * 1024 * 1024
|
||||
}
|
||||
}
|
||||
|
|
@ -6,14 +6,16 @@ import android.llama.cpp.KleidiLlama
|
|||
import android.llama.cpp.TierDetection
|
||||
import android.llama.cpp.gguf.GgufMetadataReader
|
||||
import com.example.llama.data.db.AppDatabase
|
||||
import com.example.llama.data.source.remote.GatedTypeAdapter
|
||||
import com.example.llama.data.source.remote.HuggingFaceApiService
|
||||
import com.example.llama.data.source.remote.HuggingFaceRemoteDataSource
|
||||
import com.example.llama.data.source.remote.HuggingFaceRemoteDataSourceImpl
|
||||
import com.example.llama.data.repo.ModelRepository
|
||||
import com.example.llama.data.repo.ModelRepositoryImpl
|
||||
import com.example.llama.data.repo.SystemPromptRepository
|
||||
import com.example.llama.data.repo.SystemPromptRepositoryImpl
|
||||
import com.example.llama.data.source.local.LocalFileDataSource
|
||||
import com.example.llama.data.source.local.LocalFileDataSourceImpl
|
||||
import com.example.llama.data.source.remote.GatedTypeAdapter
|
||||
import com.example.llama.data.source.remote.HuggingFaceApiService
|
||||
import com.example.llama.data.source.remote.HuggingFaceRemoteDataSource
|
||||
import com.example.llama.data.source.remote.HuggingFaceRemoteDataSourceImpl
|
||||
import com.example.llama.engine.BenchmarkService
|
||||
import com.example.llama.engine.ConversationService
|
||||
import com.example.llama.engine.InferenceService
|
||||
|
|
@ -56,15 +58,16 @@ internal abstract class AppModule {
|
|||
abstract fun bindConversationService(impl: InferenceServiceImpl) : ConversationService
|
||||
|
||||
@Binds
|
||||
abstract fun bindsModelsRepository(impl: ModelRepositoryImpl): ModelRepository
|
||||
abstract fun bindModelsRepository(impl: ModelRepositoryImpl): ModelRepository
|
||||
|
||||
@Binds
|
||||
abstract fun bindsSystemPromptRepository(impl: SystemPromptRepositoryImpl): SystemPromptRepository
|
||||
abstract fun bindSystemPromptRepository(impl: SystemPromptRepositoryImpl): SystemPromptRepository
|
||||
|
||||
@Binds
|
||||
abstract fun bindHuggingFaceRemoteDataSource(
|
||||
impl: HuggingFaceRemoteDataSourceImpl
|
||||
): HuggingFaceRemoteDataSource
|
||||
abstract fun bindLocalFileDataSource(impl: LocalFileDataSourceImpl) : LocalFileDataSource
|
||||
|
||||
@Binds
|
||||
abstract fun bindHuggingFaceRemoteDataSource(impl: HuggingFaceRemoteDataSourceImpl): HuggingFaceRemoteDataSource
|
||||
|
||||
companion object {
|
||||
const val USE_STUB_ENGINE = false
|
||||
|
|
|
|||
|
|
@ -55,20 +55,25 @@ suspend fun copyWithChannels(
|
|||
|
||||
val buffer = ByteBuffer.allocateDirect(bufferSize)
|
||||
var totalBytesRead = 0L
|
||||
var yieldCounter = 0L
|
||||
|
||||
while (inChannel.read(buffer) != -1) {
|
||||
buffer.flip()
|
||||
while (buffer.hasRemaining()) {
|
||||
outChannel.write(buffer)
|
||||
}
|
||||
totalBytesRead += buffer.position()
|
||||
|
||||
val bytesRead = buffer.position()
|
||||
totalBytesRead += bytesRead
|
||||
yieldCounter += bytesRead
|
||||
buffer.clear()
|
||||
|
||||
// Report progress
|
||||
onProgress?.invoke(totalBytesRead.toFloat() / totalSize)
|
||||
|
||||
if (totalBytesRead % (yieldSize) == 0L) {
|
||||
if (yieldCounter >= yieldSize) {
|
||||
yield()
|
||||
yieldCounter = 0L
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue