diff --git a/examples/llama.android/app/src/main/java/com/example/llama/revamp/data/repository/ModelRepository.kt b/examples/llama.android/app/src/main/java/com/example/llama/revamp/data/repository/ModelRepository.kt index 0767368bab..8690486c62 100644 --- a/examples/llama.android/app/src/main/java/com/example/llama/revamp/data/repository/ModelRepository.kt +++ b/examples/llama.android/app/src/main/java/com/example/llama/revamp/data/repository/ModelRepository.kt @@ -8,6 +8,8 @@ import com.example.llama.revamp.data.local.ModelDao import com.example.llama.revamp.data.local.ModelEntity import com.example.llama.revamp.data.model.ModelInfo import com.example.llama.revamp.data.repository.ModelRepository.ImportProgressTracker +import com.example.llama.revamp.util.copyWithBuffer +import com.example.llama.revamp.util.copyWithChannels import com.example.llama.revamp.util.extractModelTypeFromFilename import com.example.llama.revamp.util.extractParametersFromFilename import com.example.llama.revamp.util.extractQuantizationFromFilename @@ -20,19 +22,10 @@ import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.flow import kotlinx.coroutines.flow.map import kotlinx.coroutines.withContext -import kotlinx.coroutines.yield -import java.io.BufferedInputStream -import java.io.BufferedOutputStream import java.io.File import java.io.FileNotFoundException import java.io.FileOutputStream import java.io.IOException -import java.io.InputStream -import java.io.OutputStream -import java.nio.ByteBuffer -import java.nio.channels.Channels -import java.nio.channels.ReadableByteChannel -import java.nio.channels.WritableByteChannel import java.util.UUID import javax.inject.Inject import javax.inject.Singleton @@ -115,19 +108,36 @@ class ModelRepositoryImpl @Inject constructor( Log.i(TAG, "Copying $fileName (size: $fileSize) via NIO...") // Use NIO channels for large models - copyWithChannels(inputStream, outputStream, fileSize, progressTracker) + copyWithChannels( + input = inputStream, + output = outputStream, + totalSize = fileSize, + bufferSize = NIO_BUFFER_SIZE, + yieldSize = NIO_YIELD_SIZE + ) { progress -> + progressTracker?.let { + withContext(Dispatchers.Main) { + it.onProgress(progress) + } + } + } } else { Log.i(TAG, "Copying $fileName (size: $fileSize) via buffer...") // Default copy with buffer for small models - val bufferedInput = BufferedInputStream(inputStream, DEFAULT_BUFFER_SIZE) - val bufferedOutput = BufferedOutputStream(outputStream, DEFAULT_BUFFER_SIZE) - copyWithBuffer(bufferedInput, bufferedOutput, fileSize, progressTracker) - - // Close streams - bufferedOutput.flush() - bufferedOutput.close() - bufferedInput.close() + copyWithBuffer( + input = inputStream, + output = outputStream, + totalSize = fileSize, + bufferSize = DEFAULT_BUFFER_SIZE, + yieldSize = DEFAULT_YIELD_SIZE + ) { progress -> + progressTracker?.let { + withContext(Dispatchers.Main) { + it.onProgress(progress) + } + } + } } // Extract model parameters from filename @@ -153,85 +163,11 @@ class ModelRepositoryImpl @Inject constructor( } } catch (e: Exception) { // Clean up partially downloaded file if error occurs - if (modelFile.exists()) { - modelFile.delete() - } + if (modelFile.exists()) { modelFile.delete() } throw e } } - private suspend fun copyWithChannels( - input: InputStream, - output: OutputStream, - totalSize: Long, - progressTracker: ImportProgressTracker? - ) { - val inChannel: ReadableByteChannel = Channels.newChannel(input) - val outChannel: WritableByteChannel = Channels.newChannel(output) - - val buffer = ByteBuffer.allocateDirect(NIO_BUFFER_SIZE) - var totalBytesRead = 0L - - while (inChannel.read(buffer) != -1) { - buffer.flip() - while (buffer.hasRemaining()) { - outChannel.write(buffer) - } - totalBytesRead += buffer.position() - buffer.clear() - - // Report progress - progressTracker?.let { - val progress = totalBytesRead.toFloat() / totalSize - withContext(Dispatchers.Main) { - it.onProgress(progress) - } - } - - if (totalBytesRead % (NIO_YIELD_SIZE) == 0L) { - yield() - } - } - - inChannel.close() - outChannel.close() - output.close() - input.close() - } - - private suspend fun copyWithBuffer( - input: BufferedInputStream, - output: BufferedOutputStream, - totalSize: Long, - progressTracker: ImportProgressTracker? - ) { - val buffer = ByteArray(DEFAULT_BUFFER_SIZE) - - var bytesRead: Int - var totalBytesRead = 0L - - while (input.read(buffer).also { bytesRead = it } != -1) { - output.write(buffer, 0, bytesRead) - totalBytesRead += bytesRead - - // Report progress - if (progressTracker != null) { - val progress = totalBytesRead.toFloat() / totalSize - withContext(Dispatchers.Main) { - progressTracker.onProgress(progress) - } - } - - // Yield less frequently with larger buffers - if (totalBytesRead % (DEFAULT_YIELD_SIZE) == 0L) { // Every 64MB - yield() - } - } - - output.close() - input.close() - } - override suspend fun updateModelLastUsed(modelId: String) = withContext(Dispatchers.IO) { modelDao.updateLastUsed(modelId, System.currentTimeMillis()) } diff --git a/examples/llama.android/app/src/main/java/com/example/llama/revamp/util/FileUtils.kt b/examples/llama.android/app/src/main/java/com/example/llama/revamp/util/FileUtils.kt new file mode 100644 index 0000000000..6f25785a4b --- /dev/null +++ b/examples/llama.android/app/src/main/java/com/example/llama/revamp/util/FileUtils.kt @@ -0,0 +1,110 @@ +package com.example.llama.revamp.util + +import android.content.Context +import android.net.Uri +import android.provider.OpenableColumns +import kotlinx.coroutines.yield +import java.io.BufferedInputStream +import java.io.BufferedOutputStream +import java.io.InputStream +import java.io.OutputStream +import java.nio.ByteBuffer +import java.nio.channels.Channels +import java.nio.channels.ReadableByteChannel +import java.nio.channels.WritableByteChannel + + +/** + * Gets the file name from a content URI + */ +fun getFileNameFromUri(context: Context, uri: Uri): String? = + context.contentResolver.query(uri, null, null, null, null)?.use { cursor -> + if (cursor.moveToFirst()) { + cursor.getColumnIndex(OpenableColumns.DISPLAY_NAME).let { nameIndex -> + if (nameIndex != -1) cursor.getString(nameIndex) else null + } + } else { + null + } + } ?: uri.lastPathSegment + +/** + * Gets the file size from a content URI + */ +fun getFileSizeFromUri(context: Context, uri: Uri): Long? = + context.contentResolver.query(uri, null, null, null, null)?.use { cursor -> + if (cursor.moveToFirst()) { + cursor.getColumnIndex(OpenableColumns.SIZE).let { sizeIndex -> + if (sizeIndex != -1) cursor.getLong(sizeIndex) else null + } + } else { + null + } + } + + +suspend fun copyWithChannels( + input: InputStream, + output: OutputStream, + totalSize: Long, + bufferSize: Int, + yieldSize: Int, + onProgress: (suspend (Float) -> Unit)? +) { + val inChannel: ReadableByteChannel = Channels.newChannel(input) + val outChannel: WritableByteChannel = Channels.newChannel(output) + + val buffer = ByteBuffer.allocateDirect(bufferSize) + var totalBytesRead = 0L + + while (inChannel.read(buffer) != -1) { + buffer.flip() + while (buffer.hasRemaining()) { + outChannel.write(buffer) + } + totalBytesRead += buffer.position() + buffer.clear() + + // Report progress + onProgress?.invoke(totalBytesRead.toFloat() / totalSize) + + if (totalBytesRead % (yieldSize) == 0L) { + yield() + } + } + + outChannel.close() + inChannel.close() +} + +suspend fun copyWithBuffer( + input: InputStream, + output: OutputStream, + totalSize: Long, + bufferSize: Int, + yieldSize: Int, + onProgress: (suspend (Float) -> Unit)? +) { + val bufferedInput = BufferedInputStream(input, bufferSize) + val bufferedOutput = BufferedOutputStream(output, bufferSize) + val buffer = ByteArray(bufferSize) + + var bytesRead: Int + var totalBytesRead = 0L + + while (input.read(buffer).also { bytesRead = it } != -1) { + output.write(buffer, 0, bytesRead) + totalBytesRead += bytesRead + + // Report progress + onProgress?.invoke(totalBytesRead.toFloat() / totalSize) + + // Yield less frequently with larger buffers + if (totalBytesRead % (yieldSize) == 0L) { // Every 64MB + yield() + } + } + + bufferedOutput.close() + bufferedInput.close() +} diff --git a/examples/llama.android/app/src/main/java/com/example/llama/revamp/util/ModelUtils.kt b/examples/llama.android/app/src/main/java/com/example/llama/revamp/util/ModelUtils.kt index f4ea74edd1..5861ddb260 100644 --- a/examples/llama.android/app/src/main/java/com/example/llama/revamp/util/ModelUtils.kt +++ b/examples/llama.android/app/src/main/java/com/example/llama/revamp/util/ModelUtils.kt @@ -1,37 +1,7 @@ package com.example.llama.revamp.util -import android.content.Context -import android.net.Uri -import android.provider.OpenableColumns import java.util.Locale -/** - * Gets the file name from a content URI - */ -fun getFileNameFromUri(context: Context, uri: Uri): String? = - context.contentResolver.query(uri, null, null, null, null)?.use { cursor -> - if (cursor.moveToFirst()) { - cursor.getColumnIndex(OpenableColumns.DISPLAY_NAME).let { nameIndex -> - if (nameIndex != -1) cursor.getString(nameIndex) else null - } - } else { - null - } - } ?: uri.lastPathSegment - -/** - * Gets the file size from a content URI - */ -fun getFileSizeFromUri(context: Context, uri: Uri): Long? = - context.contentResolver.query(uri, null, null, null, null)?.use { cursor -> - if (cursor.moveToFirst()) { - cursor.getColumnIndex(OpenableColumns.SIZE).let { sizeIndex -> - if (sizeIndex != -1) cursor.getLong(sizeIndex) else null - } - } else { - null - } - } /** * Convert bytes into human readable sizes