Util: split FileUtils from ModelUtils; extract copy methods into FileUtils
This commit is contained in:
parent
4913ad0dae
commit
59f5caa699
|
|
@ -8,6 +8,8 @@ import com.example.llama.revamp.data.local.ModelDao
|
||||||
import com.example.llama.revamp.data.local.ModelEntity
|
import com.example.llama.revamp.data.local.ModelEntity
|
||||||
import com.example.llama.revamp.data.model.ModelInfo
|
import com.example.llama.revamp.data.model.ModelInfo
|
||||||
import com.example.llama.revamp.data.repository.ModelRepository.ImportProgressTracker
|
import com.example.llama.revamp.data.repository.ModelRepository.ImportProgressTracker
|
||||||
|
import com.example.llama.revamp.util.copyWithBuffer
|
||||||
|
import com.example.llama.revamp.util.copyWithChannels
|
||||||
import com.example.llama.revamp.util.extractModelTypeFromFilename
|
import com.example.llama.revamp.util.extractModelTypeFromFilename
|
||||||
import com.example.llama.revamp.util.extractParametersFromFilename
|
import com.example.llama.revamp.util.extractParametersFromFilename
|
||||||
import com.example.llama.revamp.util.extractQuantizationFromFilename
|
import com.example.llama.revamp.util.extractQuantizationFromFilename
|
||||||
|
|
@ -20,19 +22,10 @@ import kotlinx.coroutines.flow.Flow
|
||||||
import kotlinx.coroutines.flow.flow
|
import kotlinx.coroutines.flow.flow
|
||||||
import kotlinx.coroutines.flow.map
|
import kotlinx.coroutines.flow.map
|
||||||
import kotlinx.coroutines.withContext
|
import kotlinx.coroutines.withContext
|
||||||
import kotlinx.coroutines.yield
|
|
||||||
import java.io.BufferedInputStream
|
|
||||||
import java.io.BufferedOutputStream
|
|
||||||
import java.io.File
|
import java.io.File
|
||||||
import java.io.FileNotFoundException
|
import java.io.FileNotFoundException
|
||||||
import java.io.FileOutputStream
|
import java.io.FileOutputStream
|
||||||
import java.io.IOException
|
import java.io.IOException
|
||||||
import java.io.InputStream
|
|
||||||
import java.io.OutputStream
|
|
||||||
import java.nio.ByteBuffer
|
|
||||||
import java.nio.channels.Channels
|
|
||||||
import java.nio.channels.ReadableByteChannel
|
|
||||||
import java.nio.channels.WritableByteChannel
|
|
||||||
import java.util.UUID
|
import java.util.UUID
|
||||||
import javax.inject.Inject
|
import javax.inject.Inject
|
||||||
import javax.inject.Singleton
|
import javax.inject.Singleton
|
||||||
|
|
@ -115,19 +108,36 @@ class ModelRepositoryImpl @Inject constructor(
|
||||||
Log.i(TAG, "Copying $fileName (size: $fileSize) via NIO...")
|
Log.i(TAG, "Copying $fileName (size: $fileSize) via NIO...")
|
||||||
|
|
||||||
// Use NIO channels for large models
|
// Use NIO channels for large models
|
||||||
copyWithChannels(inputStream, outputStream, fileSize, progressTracker)
|
copyWithChannels(
|
||||||
|
input = inputStream,
|
||||||
|
output = outputStream,
|
||||||
|
totalSize = fileSize,
|
||||||
|
bufferSize = NIO_BUFFER_SIZE,
|
||||||
|
yieldSize = NIO_YIELD_SIZE
|
||||||
|
) { progress ->
|
||||||
|
progressTracker?.let {
|
||||||
|
withContext(Dispatchers.Main) {
|
||||||
|
it.onProgress(progress)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
Log.i(TAG, "Copying $fileName (size: $fileSize) via buffer...")
|
Log.i(TAG, "Copying $fileName (size: $fileSize) via buffer...")
|
||||||
|
|
||||||
// Default copy with buffer for small models
|
// Default copy with buffer for small models
|
||||||
val bufferedInput = BufferedInputStream(inputStream, DEFAULT_BUFFER_SIZE)
|
copyWithBuffer(
|
||||||
val bufferedOutput = BufferedOutputStream(outputStream, DEFAULT_BUFFER_SIZE)
|
input = inputStream,
|
||||||
copyWithBuffer(bufferedInput, bufferedOutput, fileSize, progressTracker)
|
output = outputStream,
|
||||||
|
totalSize = fileSize,
|
||||||
// Close streams
|
bufferSize = DEFAULT_BUFFER_SIZE,
|
||||||
bufferedOutput.flush()
|
yieldSize = DEFAULT_YIELD_SIZE
|
||||||
bufferedOutput.close()
|
) { progress ->
|
||||||
bufferedInput.close()
|
progressTracker?.let {
|
||||||
|
withContext(Dispatchers.Main) {
|
||||||
|
it.onProgress(progress)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Extract model parameters from filename
|
// Extract model parameters from filename
|
||||||
|
|
@ -153,85 +163,11 @@ class ModelRepositoryImpl @Inject constructor(
|
||||||
}
|
}
|
||||||
} catch (e: Exception) {
|
} catch (e: Exception) {
|
||||||
// Clean up partially downloaded file if error occurs
|
// Clean up partially downloaded file if error occurs
|
||||||
if (modelFile.exists()) {
|
if (modelFile.exists()) { modelFile.delete() }
|
||||||
modelFile.delete()
|
|
||||||
}
|
|
||||||
throw e
|
throw e
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private suspend fun copyWithChannels(
|
|
||||||
input: InputStream,
|
|
||||||
output: OutputStream,
|
|
||||||
totalSize: Long,
|
|
||||||
progressTracker: ImportProgressTracker?
|
|
||||||
) {
|
|
||||||
val inChannel: ReadableByteChannel = Channels.newChannel(input)
|
|
||||||
val outChannel: WritableByteChannel = Channels.newChannel(output)
|
|
||||||
|
|
||||||
val buffer = ByteBuffer.allocateDirect(NIO_BUFFER_SIZE)
|
|
||||||
var totalBytesRead = 0L
|
|
||||||
|
|
||||||
while (inChannel.read(buffer) != -1) {
|
|
||||||
buffer.flip()
|
|
||||||
while (buffer.hasRemaining()) {
|
|
||||||
outChannel.write(buffer)
|
|
||||||
}
|
|
||||||
totalBytesRead += buffer.position()
|
|
||||||
buffer.clear()
|
|
||||||
|
|
||||||
// Report progress
|
|
||||||
progressTracker?.let {
|
|
||||||
val progress = totalBytesRead.toFloat() / totalSize
|
|
||||||
withContext(Dispatchers.Main) {
|
|
||||||
it.onProgress(progress)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (totalBytesRead % (NIO_YIELD_SIZE) == 0L) {
|
|
||||||
yield()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
inChannel.close()
|
|
||||||
outChannel.close()
|
|
||||||
output.close()
|
|
||||||
input.close()
|
|
||||||
}
|
|
||||||
|
|
||||||
private suspend fun copyWithBuffer(
|
|
||||||
input: BufferedInputStream,
|
|
||||||
output: BufferedOutputStream,
|
|
||||||
totalSize: Long,
|
|
||||||
progressTracker: ImportProgressTracker?
|
|
||||||
) {
|
|
||||||
val buffer = ByteArray(DEFAULT_BUFFER_SIZE)
|
|
||||||
|
|
||||||
var bytesRead: Int
|
|
||||||
var totalBytesRead = 0L
|
|
||||||
|
|
||||||
while (input.read(buffer).also { bytesRead = it } != -1) {
|
|
||||||
output.write(buffer, 0, bytesRead)
|
|
||||||
totalBytesRead += bytesRead
|
|
||||||
|
|
||||||
// Report progress
|
|
||||||
if (progressTracker != null) {
|
|
||||||
val progress = totalBytesRead.toFloat() / totalSize
|
|
||||||
withContext(Dispatchers.Main) {
|
|
||||||
progressTracker.onProgress(progress)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Yield less frequently with larger buffers
|
|
||||||
if (totalBytesRead % (DEFAULT_YIELD_SIZE) == 0L) { // Every 64MB
|
|
||||||
yield()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
output.close()
|
|
||||||
input.close()
|
|
||||||
}
|
|
||||||
|
|
||||||
override suspend fun updateModelLastUsed(modelId: String) = withContext(Dispatchers.IO) {
|
override suspend fun updateModelLastUsed(modelId: String) = withContext(Dispatchers.IO) {
|
||||||
modelDao.updateLastUsed(modelId, System.currentTimeMillis())
|
modelDao.updateLastUsed(modelId, System.currentTimeMillis())
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,110 @@
|
||||||
|
package com.example.llama.revamp.util
|
||||||
|
|
||||||
|
import android.content.Context
|
||||||
|
import android.net.Uri
|
||||||
|
import android.provider.OpenableColumns
|
||||||
|
import kotlinx.coroutines.yield
|
||||||
|
import java.io.BufferedInputStream
|
||||||
|
import java.io.BufferedOutputStream
|
||||||
|
import java.io.InputStream
|
||||||
|
import java.io.OutputStream
|
||||||
|
import java.nio.ByteBuffer
|
||||||
|
import java.nio.channels.Channels
|
||||||
|
import java.nio.channels.ReadableByteChannel
|
||||||
|
import java.nio.channels.WritableByteChannel
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Gets the file name from a content URI
|
||||||
|
*/
|
||||||
|
fun getFileNameFromUri(context: Context, uri: Uri): String? =
|
||||||
|
context.contentResolver.query(uri, null, null, null, null)?.use { cursor ->
|
||||||
|
if (cursor.moveToFirst()) {
|
||||||
|
cursor.getColumnIndex(OpenableColumns.DISPLAY_NAME).let { nameIndex ->
|
||||||
|
if (nameIndex != -1) cursor.getString(nameIndex) else null
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
null
|
||||||
|
}
|
||||||
|
} ?: uri.lastPathSegment
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Gets the file size from a content URI
|
||||||
|
*/
|
||||||
|
fun getFileSizeFromUri(context: Context, uri: Uri): Long? =
|
||||||
|
context.contentResolver.query(uri, null, null, null, null)?.use { cursor ->
|
||||||
|
if (cursor.moveToFirst()) {
|
||||||
|
cursor.getColumnIndex(OpenableColumns.SIZE).let { sizeIndex ->
|
||||||
|
if (sizeIndex != -1) cursor.getLong(sizeIndex) else null
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
null
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
suspend fun copyWithChannels(
|
||||||
|
input: InputStream,
|
||||||
|
output: OutputStream,
|
||||||
|
totalSize: Long,
|
||||||
|
bufferSize: Int,
|
||||||
|
yieldSize: Int,
|
||||||
|
onProgress: (suspend (Float) -> Unit)?
|
||||||
|
) {
|
||||||
|
val inChannel: ReadableByteChannel = Channels.newChannel(input)
|
||||||
|
val outChannel: WritableByteChannel = Channels.newChannel(output)
|
||||||
|
|
||||||
|
val buffer = ByteBuffer.allocateDirect(bufferSize)
|
||||||
|
var totalBytesRead = 0L
|
||||||
|
|
||||||
|
while (inChannel.read(buffer) != -1) {
|
||||||
|
buffer.flip()
|
||||||
|
while (buffer.hasRemaining()) {
|
||||||
|
outChannel.write(buffer)
|
||||||
|
}
|
||||||
|
totalBytesRead += buffer.position()
|
||||||
|
buffer.clear()
|
||||||
|
|
||||||
|
// Report progress
|
||||||
|
onProgress?.invoke(totalBytesRead.toFloat() / totalSize)
|
||||||
|
|
||||||
|
if (totalBytesRead % (yieldSize) == 0L) {
|
||||||
|
yield()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
outChannel.close()
|
||||||
|
inChannel.close()
|
||||||
|
}
|
||||||
|
|
||||||
|
suspend fun copyWithBuffer(
|
||||||
|
input: InputStream,
|
||||||
|
output: OutputStream,
|
||||||
|
totalSize: Long,
|
||||||
|
bufferSize: Int,
|
||||||
|
yieldSize: Int,
|
||||||
|
onProgress: (suspend (Float) -> Unit)?
|
||||||
|
) {
|
||||||
|
val bufferedInput = BufferedInputStream(input, bufferSize)
|
||||||
|
val bufferedOutput = BufferedOutputStream(output, bufferSize)
|
||||||
|
val buffer = ByteArray(bufferSize)
|
||||||
|
|
||||||
|
var bytesRead: Int
|
||||||
|
var totalBytesRead = 0L
|
||||||
|
|
||||||
|
while (input.read(buffer).also { bytesRead = it } != -1) {
|
||||||
|
output.write(buffer, 0, bytesRead)
|
||||||
|
totalBytesRead += bytesRead
|
||||||
|
|
||||||
|
// Report progress
|
||||||
|
onProgress?.invoke(totalBytesRead.toFloat() / totalSize)
|
||||||
|
|
||||||
|
// Yield less frequently with larger buffers
|
||||||
|
if (totalBytesRead % (yieldSize) == 0L) { // Every 64MB
|
||||||
|
yield()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
bufferedOutput.close()
|
||||||
|
bufferedInput.close()
|
||||||
|
}
|
||||||
|
|
@ -1,37 +1,7 @@
|
||||||
package com.example.llama.revamp.util
|
package com.example.llama.revamp.util
|
||||||
|
|
||||||
import android.content.Context
|
|
||||||
import android.net.Uri
|
|
||||||
import android.provider.OpenableColumns
|
|
||||||
import java.util.Locale
|
import java.util.Locale
|
||||||
|
|
||||||
/**
|
|
||||||
* Gets the file name from a content URI
|
|
||||||
*/
|
|
||||||
fun getFileNameFromUri(context: Context, uri: Uri): String? =
|
|
||||||
context.contentResolver.query(uri, null, null, null, null)?.use { cursor ->
|
|
||||||
if (cursor.moveToFirst()) {
|
|
||||||
cursor.getColumnIndex(OpenableColumns.DISPLAY_NAME).let { nameIndex ->
|
|
||||||
if (nameIndex != -1) cursor.getString(nameIndex) else null
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
null
|
|
||||||
}
|
|
||||||
} ?: uri.lastPathSegment
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Gets the file size from a content URI
|
|
||||||
*/
|
|
||||||
fun getFileSizeFromUri(context: Context, uri: Uri): Long? =
|
|
||||||
context.contentResolver.query(uri, null, null, null, null)?.use { cursor ->
|
|
||||||
if (cursor.moveToFirst()) {
|
|
||||||
cursor.getColumnIndex(OpenableColumns.SIZE).let { sizeIndex ->
|
|
||||||
if (sizeIndex != -1) cursor.getLong(sizeIndex) else null
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
null
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Convert bytes into human readable sizes
|
* Convert bytes into human readable sizes
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue