Util: split FileUtils from ModelUtils; extract copy methods into FileUtils

This commit is contained in:
Han Yin 2025-04-14 23:58:36 -07:00
parent 4913ad0dae
commit 59f5caa699
3 changed files with 139 additions and 123 deletions

View File

@ -8,6 +8,8 @@ import com.example.llama.revamp.data.local.ModelDao
import com.example.llama.revamp.data.local.ModelEntity
import com.example.llama.revamp.data.model.ModelInfo
import com.example.llama.revamp.data.repository.ModelRepository.ImportProgressTracker
import com.example.llama.revamp.util.copyWithBuffer
import com.example.llama.revamp.util.copyWithChannels
import com.example.llama.revamp.util.extractModelTypeFromFilename
import com.example.llama.revamp.util.extractParametersFromFilename
import com.example.llama.revamp.util.extractQuantizationFromFilename
@ -20,19 +22,10 @@ import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.flow
import kotlinx.coroutines.flow.map
import kotlinx.coroutines.withContext
import kotlinx.coroutines.yield
import java.io.BufferedInputStream
import java.io.BufferedOutputStream
import java.io.File
import java.io.FileNotFoundException
import java.io.FileOutputStream
import java.io.IOException
import java.io.InputStream
import java.io.OutputStream
import java.nio.ByteBuffer
import java.nio.channels.Channels
import java.nio.channels.ReadableByteChannel
import java.nio.channels.WritableByteChannel
import java.util.UUID
import javax.inject.Inject
import javax.inject.Singleton
@ -115,19 +108,36 @@ class ModelRepositoryImpl @Inject constructor(
Log.i(TAG, "Copying $fileName (size: $fileSize) via NIO...")
// Use NIO channels for large models
copyWithChannels(inputStream, outputStream, fileSize, progressTracker)
copyWithChannels(
input = inputStream,
output = outputStream,
totalSize = fileSize,
bufferSize = NIO_BUFFER_SIZE,
yieldSize = NIO_YIELD_SIZE
) { progress ->
progressTracker?.let {
withContext(Dispatchers.Main) {
it.onProgress(progress)
}
}
}
} else {
Log.i(TAG, "Copying $fileName (size: $fileSize) via buffer...")
// Default copy with buffer for small models
val bufferedInput = BufferedInputStream(inputStream, DEFAULT_BUFFER_SIZE)
val bufferedOutput = BufferedOutputStream(outputStream, DEFAULT_BUFFER_SIZE)
copyWithBuffer(bufferedInput, bufferedOutput, fileSize, progressTracker)
// Close streams
bufferedOutput.flush()
bufferedOutput.close()
bufferedInput.close()
copyWithBuffer(
input = inputStream,
output = outputStream,
totalSize = fileSize,
bufferSize = DEFAULT_BUFFER_SIZE,
yieldSize = DEFAULT_YIELD_SIZE
) { progress ->
progressTracker?.let {
withContext(Dispatchers.Main) {
it.onProgress(progress)
}
}
}
}
// Extract model parameters from filename
@ -153,85 +163,11 @@ class ModelRepositoryImpl @Inject constructor(
}
} catch (e: Exception) {
// Clean up partially downloaded file if error occurs
if (modelFile.exists()) {
modelFile.delete()
}
if (modelFile.exists()) { modelFile.delete() }
throw e
}
}
private suspend fun copyWithChannels(
input: InputStream,
output: OutputStream,
totalSize: Long,
progressTracker: ImportProgressTracker?
) {
val inChannel: ReadableByteChannel = Channels.newChannel(input)
val outChannel: WritableByteChannel = Channels.newChannel(output)
val buffer = ByteBuffer.allocateDirect(NIO_BUFFER_SIZE)
var totalBytesRead = 0L
while (inChannel.read(buffer) != -1) {
buffer.flip()
while (buffer.hasRemaining()) {
outChannel.write(buffer)
}
totalBytesRead += buffer.position()
buffer.clear()
// Report progress
progressTracker?.let {
val progress = totalBytesRead.toFloat() / totalSize
withContext(Dispatchers.Main) {
it.onProgress(progress)
}
}
if (totalBytesRead % (NIO_YIELD_SIZE) == 0L) {
yield()
}
}
inChannel.close()
outChannel.close()
output.close()
input.close()
}
private suspend fun copyWithBuffer(
input: BufferedInputStream,
output: BufferedOutputStream,
totalSize: Long,
progressTracker: ImportProgressTracker?
) {
val buffer = ByteArray(DEFAULT_BUFFER_SIZE)
var bytesRead: Int
var totalBytesRead = 0L
while (input.read(buffer).also { bytesRead = it } != -1) {
output.write(buffer, 0, bytesRead)
totalBytesRead += bytesRead
// Report progress
if (progressTracker != null) {
val progress = totalBytesRead.toFloat() / totalSize
withContext(Dispatchers.Main) {
progressTracker.onProgress(progress)
}
}
// Yield less frequently with larger buffers
if (totalBytesRead % (DEFAULT_YIELD_SIZE) == 0L) { // Every 64MB
yield()
}
}
output.close()
input.close()
}
override suspend fun updateModelLastUsed(modelId: String) = withContext(Dispatchers.IO) {
modelDao.updateLastUsed(modelId, System.currentTimeMillis())
}

View File

@ -0,0 +1,110 @@
package com.example.llama.revamp.util
import android.content.Context
import android.net.Uri
import android.provider.OpenableColumns
import kotlinx.coroutines.yield
import java.io.BufferedInputStream
import java.io.BufferedOutputStream
import java.io.InputStream
import java.io.OutputStream
import java.nio.ByteBuffer
import java.nio.channels.Channels
import java.nio.channels.ReadableByteChannel
import java.nio.channels.WritableByteChannel
/**
* Gets the file name from a content URI
*/
fun getFileNameFromUri(context: Context, uri: Uri): String? =
context.contentResolver.query(uri, null, null, null, null)?.use { cursor ->
if (cursor.moveToFirst()) {
cursor.getColumnIndex(OpenableColumns.DISPLAY_NAME).let { nameIndex ->
if (nameIndex != -1) cursor.getString(nameIndex) else null
}
} else {
null
}
} ?: uri.lastPathSegment
/**
* Gets the file size from a content URI
*/
fun getFileSizeFromUri(context: Context, uri: Uri): Long? =
context.contentResolver.query(uri, null, null, null, null)?.use { cursor ->
if (cursor.moveToFirst()) {
cursor.getColumnIndex(OpenableColumns.SIZE).let { sizeIndex ->
if (sizeIndex != -1) cursor.getLong(sizeIndex) else null
}
} else {
null
}
}
suspend fun copyWithChannels(
input: InputStream,
output: OutputStream,
totalSize: Long,
bufferSize: Int,
yieldSize: Int,
onProgress: (suspend (Float) -> Unit)?
) {
val inChannel: ReadableByteChannel = Channels.newChannel(input)
val outChannel: WritableByteChannel = Channels.newChannel(output)
val buffer = ByteBuffer.allocateDirect(bufferSize)
var totalBytesRead = 0L
while (inChannel.read(buffer) != -1) {
buffer.flip()
while (buffer.hasRemaining()) {
outChannel.write(buffer)
}
totalBytesRead += buffer.position()
buffer.clear()
// Report progress
onProgress?.invoke(totalBytesRead.toFloat() / totalSize)
if (totalBytesRead % (yieldSize) == 0L) {
yield()
}
}
outChannel.close()
inChannel.close()
}
suspend fun copyWithBuffer(
input: InputStream,
output: OutputStream,
totalSize: Long,
bufferSize: Int,
yieldSize: Int,
onProgress: (suspend (Float) -> Unit)?
) {
val bufferedInput = BufferedInputStream(input, bufferSize)
val bufferedOutput = BufferedOutputStream(output, bufferSize)
val buffer = ByteArray(bufferSize)
var bytesRead: Int
var totalBytesRead = 0L
while (input.read(buffer).also { bytesRead = it } != -1) {
output.write(buffer, 0, bytesRead)
totalBytesRead += bytesRead
// Report progress
onProgress?.invoke(totalBytesRead.toFloat() / totalSize)
// Yield less frequently with larger buffers
if (totalBytesRead % (yieldSize) == 0L) { // Every 64MB
yield()
}
}
bufferedOutput.close()
bufferedInput.close()
}

View File

@ -1,37 +1,7 @@
package com.example.llama.revamp.util
import android.content.Context
import android.net.Uri
import android.provider.OpenableColumns
import java.util.Locale
/**
* Gets the file name from a content URI
*/
fun getFileNameFromUri(context: Context, uri: Uri): String? =
context.contentResolver.query(uri, null, null, null, null)?.use { cursor ->
if (cursor.moveToFirst()) {
cursor.getColumnIndex(OpenableColumns.DISPLAY_NAME).let { nameIndex ->
if (nameIndex != -1) cursor.getString(nameIndex) else null
}
} else {
null
}
} ?: uri.lastPathSegment
/**
* Gets the file size from a content URI
*/
fun getFileSizeFromUri(context: Context, uri: Uri): Long? =
context.contentResolver.query(uri, null, null, null, null)?.use { cursor ->
if (cursor.moveToFirst()) {
cursor.getColumnIndex(OpenableColumns.SIZE).let { sizeIndex ->
if (sizeIndex != -1) cursor.getLong(sizeIndex) else null
}
} else {
null
}
}
/**
* Convert bytes into human readable sizes