diff --git a/examples/llama.android/app/src/main/java/com/example/llama/data/repo/ModelRepository.kt b/examples/llama.android/app/src/main/java/com/example/llama/data/repo/ModelRepository.kt index beeb47b0b0..0a204950c4 100644 --- a/examples/llama.android/app/src/main/java/com/example/llama/data/repo/ModelRepository.kt +++ b/examples/llama.android/app/src/main/java/com/example/llama/data/repo/ModelRepository.kt @@ -9,17 +9,14 @@ import com.example.llama.data.db.dao.ModelDao import com.example.llama.data.db.entity.ModelEntity import com.example.llama.data.model.GgufMetadata import com.example.llama.data.model.ModelInfo +import com.example.llama.data.repo.ModelRepository.ImportProgressTracker +import com.example.llama.data.source.local.LocalFileDataSource import com.example.llama.data.source.remote.HuggingFaceDownloadInfo import com.example.llama.data.source.remote.HuggingFaceModel import com.example.llama.data.source.remote.HuggingFaceModelDetails import com.example.llama.data.source.remote.HuggingFaceRemoteDataSource -import com.example.llama.data.repo.ModelRepository.ImportProgressTracker import com.example.llama.monitoring.StorageMetrics -import com.example.llama.util.copyWithBuffer -import com.example.llama.util.copyWithChannels import com.example.llama.util.formatFileByteSize -import com.example.llama.util.getFileNameFromUri -import com.example.llama.util.getFileSizeFromUri import dagger.hilt.android.qualifiers.ApplicationContext import kotlinx.coroutines.CancellationException import kotlinx.coroutines.Dispatchers @@ -32,7 +29,6 @@ import kotlinx.coroutines.flow.map import kotlinx.coroutines.withContext import java.io.File import java.io.FileNotFoundException -import java.io.FileOutputStream import java.io.IOException import java.util.UUID import javax.inject.Inject @@ -113,6 +109,7 @@ class InsufficientStorageException(message: String) : IOException(message) class ModelRepositoryImpl @Inject constructor( @ApplicationContext private val context: Context, private val modelDao: ModelDao, + private val localFileDataSource: LocalFileDataSource, private val huggingFaceRemoteDataSource: HuggingFaceRemoteDataSource, private val ggufMetadataReader: GgufMetadataReader, ) : ModelRepository { @@ -174,7 +171,9 @@ class ModelRepositoryImpl @Inject constructor( throw IllegalStateException("Another import is already in progress!") } - val fileSize = size ?: getFileSizeFromUri(context, uri) ?: throw FileNotFoundException("File size N/A") + val fileInfo = localFileDataSource.getFileInfo(uri) + val fileSize = size ?: fileInfo?.size ?: throw FileNotFoundException("File size N/A") + val fileName = name ?: fileInfo?.name ?: throw FileNotFoundException("File name N/A") if (!hasEnoughSpaceForImport(fileSize)) { throw InsufficientStorageException( "Not enough storage space! " + @@ -182,53 +181,24 @@ class ModelRepositoryImpl @Inject constructor( "Available: ${formatFileByteSize(availableSpaceBytes)}" ) } - - val fileName = name ?: getFileNameFromUri(context, uri) ?: throw FileNotFoundException("Filename N/A") val modelFile = File(modelsDir, fileName) importJob = coroutineContext[Job] currentModelFile = modelFile try { - val inputStream = context.contentResolver.openInputStream(uri) - ?: throw IOException("Failed to open input stream") - val outputStream = FileOutputStream(modelFile) - - if (fileSize > LARGE_MODEL_THRESHOLD_SIZE) { - Log.i(TAG, "Copying $fileName (size: $fileSize) via NIO...") - - // Use NIO channels for large models - copyWithChannels( - input = inputStream, - output = outputStream, - totalSize = fileSize, - bufferSize = NIO_BUFFER_SIZE, - yieldSize = NIO_YIELD_SIZE - ) { progress -> + localFileDataSource.copyFile( + sourceUri = uri, + destinationFile = modelFile, + fileSize = fileSize, + onProgress = { progress -> progressTracker?.let { withContext(Dispatchers.Main) { it.onProgress(progress) } } } - } else { - Log.i(TAG, "Copying $fileName (size: $fileSize) via buffer...") - - // Default copy with buffer for small models - copyWithBuffer( - input = inputStream, - output = outputStream, - totalSize = fileSize, - bufferSize = DEFAULT_BUFFER_SIZE, - yieldSize = DEFAULT_YIELD_SIZE - ) { progress -> - progressTracker?.let { - withContext(Dispatchers.Main) { - it.onProgress(progress) - } - } - } - } + ).getOrThrow() // Extract GGUF metadata if possible val metadata = try { @@ -256,12 +226,12 @@ class ModelRepositoryImpl @Inject constructor( } catch (e: CancellationException) { Log.i(TAG, "Import was cancelled for $fileName: ${e.message}") - cleanupPartialFile(modelFile) + localFileDataSource.cleanupPartialFile(modelFile) throw e } catch (e: Exception) { Log.e(TAG, "Import failed for $fileName: ${e.message}") - cleanupPartialFile(modelFile) + localFileDataSource.cleanupPartialFile(modelFile) throw e } finally { @@ -300,8 +270,8 @@ class ModelRepositoryImpl @Inject constructor( // Give the job a moment to clean up delay(CANCEL_LOCAL_MODEL_IMPORT_TIMEOUT) - // Clean up the partial file (as a safety measure) - cleanupPartialFile(file) + // Clean up the partial file + file?.let { localFileDataSource.cleanupPartialFile(it) } // Reset state importJob = null @@ -315,16 +285,6 @@ class ModelRepositoryImpl @Inject constructor( } } - private fun cleanupPartialFile(file: File?) { - try { - if (file?.exists() == true && !file.delete()) { - Log.e(TAG, "Failed to delete partial file: ${file.absolutePath}") - } - } catch (e: Exception) { - Log.e(TAG, "Error cleaning up partial file: ${e.message}") - } - } - override suspend fun updateModelLastUsed(modelId: String) = withContext(Dispatchers.IO) { modelDao.updateLastUsed(modelId, System.currentTimeMillis()) } @@ -399,11 +359,6 @@ class ModelRepositoryImpl @Inject constructor( private const val BYTES_IN_GB = 1024f * 1024f * 1024f private const val MODEL_IMPORT_SPACE_BUFFER_SCALE = 1.2f - private const val LARGE_MODEL_THRESHOLD_SIZE = 1024 * 1024 * 1024 - private const val NIO_BUFFER_SIZE = 32 * 1024 * 1024 - private const val NIO_YIELD_SIZE = 128 * 1024 * 1024 - private const val DEFAULT_BUFFER_SIZE = 4 * 1024 * 1024 - private const val DEFAULT_YIELD_SIZE = 16 * 1024 * 1024 private const val CANCEL_LOCAL_MODEL_IMPORT_TIMEOUT = 500L } } diff --git a/examples/llama.android/app/src/main/java/com/example/llama/data/source/local/LocalFileDataSource.kt b/examples/llama.android/app/src/main/java/com/example/llama/data/source/local/LocalFileDataSource.kt new file mode 100644 index 0000000000..041f100dc0 --- /dev/null +++ b/examples/llama.android/app/src/main/java/com/example/llama/data/source/local/LocalFileDataSource.kt @@ -0,0 +1,118 @@ +package com.example.llama.data.source.local + +import android.content.Context +import android.net.Uri +import android.util.Log +import com.example.llama.data.source.local.LocalFileDataSource.FileInfo +import com.example.llama.util.copyWithBuffer +import com.example.llama.util.copyWithChannels +import com.example.llama.util.getFileNameFromUri +import com.example.llama.util.getFileSizeFromUri +import dagger.hilt.android.qualifiers.ApplicationContext +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.withContext +import java.io.File +import java.io.FileOutputStream +import java.io.IOException +import javax.inject.Inject +import javax.inject.Singleton + +interface LocalFileDataSource { + /** + * Copy local file from [sourceUri] into [destinationFile] + */ + suspend fun copyFile( + sourceUri: Uri, + destinationFile: File, + fileSize: Long, + onProgress: (suspend (Float) -> Unit)? = null + ): Result + + /** + * Obtain the file name and size from given [uri] + */ + suspend fun getFileInfo(uri: Uri): FileInfo? + + /** + * Clean up incomplete file due to unfinished import + */ + suspend fun cleanupPartialFile(file: File): Boolean + + data class FileInfo(val name: String, val size: Long) +} + +@Singleton +class LocalFileDataSourceImpl @Inject constructor( + @ApplicationContext private val context: Context +) : LocalFileDataSource { + + override suspend fun copyFile( + sourceUri: Uri, + destinationFile: File, + fileSize: Long, + onProgress: (suspend (Float) -> Unit)? + ): Result = withContext(Dispatchers.IO) { + try { + val inputStream = context.contentResolver.openInputStream(sourceUri) + ?: return@withContext Result.failure(IOException("Failed to open input stream")) + val outputStream = FileOutputStream(destinationFile) + + if (fileSize > LARGE_MODEL_THRESHOLD_SIZE) { + // Use NIO channels for large models + Log.i(TAG, "Copying ${destinationFile.name} (size: $fileSize) via NIO...") + copyWithChannels( + input = inputStream, + output = outputStream, + totalSize = fileSize, + bufferSize = NIO_BUFFER_SIZE, + yieldSize = NIO_YIELD_SIZE, + onProgress = onProgress + ) + } else { + // Default copy with buffer for small models + Log.i(TAG, "Copying ${destinationFile.name} (size: $fileSize) via buffer...") + copyWithBuffer( + input = inputStream, + output = outputStream, + totalSize = fileSize, + bufferSize = DEFAULT_BUFFER_SIZE, + yieldSize = DEFAULT_YIELD_SIZE, + onProgress = onProgress + ) + } + + Result.success(destinationFile) + } catch (e: Exception) { + if (destinationFile.exists()) { + destinationFile.delete() + } + Result.failure(e) + } + } + + override suspend fun getFileInfo(uri: Uri): FileInfo? { + val name = getFileNameFromUri(context, uri) + val size = getFileSizeFromUri(context, uri) + return if (name != null && size != null) { + FileInfo(name, size) + } else null + } + + override suspend fun cleanupPartialFile(file: File): Boolean = + try { + if (file.exists()) file.delete() else false + } catch (e: Exception) { + Log.e(TAG, "Failed to delete file: ${e.message}") + false + } + + companion object { + private val TAG = LocalFileDataSourceImpl::class.java.simpleName + + private const val LARGE_MODEL_THRESHOLD_SIZE = 1024 * 1024 * 1024 + private const val NIO_BUFFER_SIZE = 32 * 1024 * 1024 + private const val NIO_YIELD_SIZE = 128 * 1024 * 1024 + private const val DEFAULT_BUFFER_SIZE = 4 * 1024 * 1024 + private const val DEFAULT_YIELD_SIZE = 16 * 1024 * 1024 + } +} diff --git a/examples/llama.android/app/src/main/java/com/example/llama/di/AppModule.kt b/examples/llama.android/app/src/main/java/com/example/llama/di/AppModule.kt index 8ca4854bb5..0ed08226a0 100644 --- a/examples/llama.android/app/src/main/java/com/example/llama/di/AppModule.kt +++ b/examples/llama.android/app/src/main/java/com/example/llama/di/AppModule.kt @@ -6,14 +6,16 @@ import android.llama.cpp.KleidiLlama import android.llama.cpp.TierDetection import android.llama.cpp.gguf.GgufMetadataReader import com.example.llama.data.db.AppDatabase -import com.example.llama.data.source.remote.GatedTypeAdapter -import com.example.llama.data.source.remote.HuggingFaceApiService -import com.example.llama.data.source.remote.HuggingFaceRemoteDataSource -import com.example.llama.data.source.remote.HuggingFaceRemoteDataSourceImpl import com.example.llama.data.repo.ModelRepository import com.example.llama.data.repo.ModelRepositoryImpl import com.example.llama.data.repo.SystemPromptRepository import com.example.llama.data.repo.SystemPromptRepositoryImpl +import com.example.llama.data.source.local.LocalFileDataSource +import com.example.llama.data.source.local.LocalFileDataSourceImpl +import com.example.llama.data.source.remote.GatedTypeAdapter +import com.example.llama.data.source.remote.HuggingFaceApiService +import com.example.llama.data.source.remote.HuggingFaceRemoteDataSource +import com.example.llama.data.source.remote.HuggingFaceRemoteDataSourceImpl import com.example.llama.engine.BenchmarkService import com.example.llama.engine.ConversationService import com.example.llama.engine.InferenceService @@ -56,15 +58,16 @@ internal abstract class AppModule { abstract fun bindConversationService(impl: InferenceServiceImpl) : ConversationService @Binds - abstract fun bindsModelsRepository(impl: ModelRepositoryImpl): ModelRepository + abstract fun bindModelsRepository(impl: ModelRepositoryImpl): ModelRepository @Binds - abstract fun bindsSystemPromptRepository(impl: SystemPromptRepositoryImpl): SystemPromptRepository + abstract fun bindSystemPromptRepository(impl: SystemPromptRepositoryImpl): SystemPromptRepository @Binds - abstract fun bindHuggingFaceRemoteDataSource( - impl: HuggingFaceRemoteDataSourceImpl - ): HuggingFaceRemoteDataSource + abstract fun bindLocalFileDataSource(impl: LocalFileDataSourceImpl) : LocalFileDataSource + + @Binds + abstract fun bindHuggingFaceRemoteDataSource(impl: HuggingFaceRemoteDataSourceImpl): HuggingFaceRemoteDataSource companion object { const val USE_STUB_ENGINE = false diff --git a/examples/llama.android/app/src/main/java/com/example/llama/util/FileUtils.kt b/examples/llama.android/app/src/main/java/com/example/llama/util/FileUtils.kt index ce9c4c3d8e..8befc3f847 100644 --- a/examples/llama.android/app/src/main/java/com/example/llama/util/FileUtils.kt +++ b/examples/llama.android/app/src/main/java/com/example/llama/util/FileUtils.kt @@ -55,20 +55,25 @@ suspend fun copyWithChannels( val buffer = ByteBuffer.allocateDirect(bufferSize) var totalBytesRead = 0L + var yieldCounter = 0L while (inChannel.read(buffer) != -1) { buffer.flip() while (buffer.hasRemaining()) { outChannel.write(buffer) } - totalBytesRead += buffer.position() + + val bytesRead = buffer.position() + totalBytesRead += bytesRead + yieldCounter += bytesRead buffer.clear() // Report progress onProgress?.invoke(totalBytesRead.toFloat() / totalSize) - if (totalBytesRead % (yieldSize) == 0L) { + if (yieldCounter >= yieldSize) { yield() + yieldCounter = 0L } }