diff --git a/examples/llama.android/app/src/main/java/com/example/llama/revamp/data/repository/ModelRepository.kt b/examples/llama.android/app/src/main/java/com/example/llama/revamp/data/repository/ModelRepository.kt index 3c657d9d37..e28660495d 100644 --- a/examples/llama.android/app/src/main/java/com/example/llama/revamp/data/repository/ModelRepository.kt +++ b/examples/llama.android/app/src/main/java/com/example/llama/revamp/data/repository/ModelRepository.kt @@ -3,12 +3,16 @@ package com.example.llama.revamp.data.repository import android.content.Context import android.net.Uri import android.os.StatFs -import android.provider.OpenableColumns import android.util.Log import com.example.llama.revamp.data.local.ModelDao import com.example.llama.revamp.data.local.ModelEntity import com.example.llama.revamp.data.model.ModelInfo import com.example.llama.revamp.data.repository.ModelRepository.ImportProgressTracker +import com.example.llama.revamp.util.extractModelTypeFromFilename +import com.example.llama.revamp.util.extractParametersFromFilename +import com.example.llama.revamp.util.extractQuantizationFromFilename +import com.example.llama.revamp.util.getFileNameFromUri +import com.example.llama.revamp.util.getFileSizeFromUri import dagger.hilt.android.qualifiers.ApplicationContext import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.delay @@ -29,7 +33,6 @@ import java.nio.ByteBuffer import java.nio.channels.Channels import java.nio.channels.ReadableByteChannel import java.nio.channels.WritableByteChannel -import java.util.Locale import java.util.UUID import javax.inject.Inject import javax.inject.Singleton @@ -41,7 +44,12 @@ interface ModelRepository { fun getStorageMetrics(): Flow fun getModels(): Flow> - suspend fun importModel(uri: Uri, progressTracker: ImportProgressTracker? = null): ModelInfo + suspend fun importModel( + uri: Uri, + name: String? = null, + size: Long? = null, + progressTracker: ImportProgressTracker? = null + ): ModelInfo suspend fun deleteModel(modelId: String) suspend fun deleteModels(modelIds: List) @@ -88,10 +96,12 @@ class ModelRepositoryImpl @Inject constructor( override suspend fun importModel( uri: Uri, + name: String?, + size: Long?, progressTracker: ImportProgressTracker? ): ModelInfo = withContext(Dispatchers.IO) { - val fileName = getFileNameFromUri(uri) ?: throw FileNotFoundException("Filename N/A") - val fileSize = getFileSizeFromUri(uri) ?: throw FileNotFoundException("File size N/A") + val fileName = name ?: getFileNameFromUri(context, uri) ?: throw FileNotFoundException("Filename N/A") + val fileSize = size ?: getFileSizeFromUri(context, uri) ?: throw FileNotFoundException("File size N/A") val modelFile = File(modelsDir, fileName) try { @@ -251,69 +261,6 @@ class ModelRepositoryImpl @Inject constructor( val totalSpaceBytes: Long get() = StatFs(context.filesDir.path).totalBytes - private fun getFileNameFromUri(uri: Uri): String? = - context.contentResolver.query(uri, null, null, null, null)?.use { cursor -> - if (cursor.moveToFirst()) { - cursor.getColumnIndex(OpenableColumns.DISPLAY_NAME).let { nameIndex -> - if (nameIndex != -1) cursor.getString(nameIndex) else null - } - } else { - null - } - } ?: uri.lastPathSegment - - /** - * Gets the file size from a content URI, or returns 0 if size is unknown. - */ - private fun getFileSizeFromUri(uri: Uri): Long? = - context.contentResolver.query(uri, null, null, null, null)?.use { cursor -> - if (cursor.moveToFirst()) { - cursor.getColumnIndex(OpenableColumns.SIZE).let { sizeIndex -> - if (sizeIndex != -1) cursor.getLong(sizeIndex) else null - } - } else { - null - } - } - - /** - * Try to extract parameters by looking for patterns like 7B, 13B, etc. - * - * TODO-han.yin: Enhance and move into a utility object for unit testing - */ - private fun extractParametersFromFilename(filename: String): String? = - Regex("([0-9]+(\\.[0-9]+)?)[bB]").find(filename)?.value?.uppercase() - - /** - * Try to extract quantization by looking for patterns like Q4_0, Q5_K_M, etc. - */ - private fun extractQuantizationFromFilename(filename: String) = - listOf( - Regex("[qQ][0-9]_[0-9]"), - Regex("[qQ][0-9]_[kK]_[mM]"), - Regex("[qQ][0-9]_[kK]"), - Regex("[qQ][0-9][fF](16|32)") - ).firstNotNullOfOrNull { - it.find(filename)?.value?.uppercase() - } - - /** - * Try to extract model type (Llama, Mistral, etc.) - * - * TODO-han.yin: Replace with GGUF parsing, also to be moved into the util object - */ - private fun extractModelTypeFromFilename(filename: String): String? { - val lowerFilename = filename.lowercase() - return listOf("llama", "mistral", "phi", "qwen", "falcon", "mpt") - .firstNotNullOfOrNull { type -> - if (lowerFilename.contains(type)) { - type.replaceFirstChar { - if (it.isLowerCase()) it.titlecase(Locale.ROOT) else it.toString() - } - } else { null } - } - } - companion object { private val TAG = ModelRepository::class.java.simpleName diff --git a/examples/llama.android/app/src/main/java/com/example/llama/revamp/util/ModelUtils.kt b/examples/llama.android/app/src/main/java/com/example/llama/revamp/util/ModelUtils.kt new file mode 100644 index 0000000000..48ac961d87 --- /dev/null +++ b/examples/llama.android/app/src/main/java/com/example/llama/revamp/util/ModelUtils.kt @@ -0,0 +1,70 @@ +package com.example.llama.revamp.util + +import android.content.Context +import android.net.Uri +import android.provider.OpenableColumns +import java.util.Locale + +/** + * Gets the file name from a content URI + */ +fun getFileNameFromUri(context: Context, uri: Uri): String? = + context.contentResolver.query(uri, null, null, null, null)?.use { cursor -> + if (cursor.moveToFirst()) { + cursor.getColumnIndex(OpenableColumns.DISPLAY_NAME).let { nameIndex -> + if (nameIndex != -1) cursor.getString(nameIndex) else null + } + } else { + null + } + } ?: uri.lastPathSegment + +/** + * Gets the file size from a content URI + */ +fun getFileSizeFromUri(context: Context, uri: Uri): Long? = + context.contentResolver.query(uri, null, null, null, null)?.use { cursor -> + if (cursor.moveToFirst()) { + cursor.getColumnIndex(OpenableColumns.SIZE).let { sizeIndex -> + if (sizeIndex != -1) cursor.getLong(sizeIndex) else null + } + } else { + null + } + } + +/** + * Try to extract parameters by looking for patterns like 7B, 13B, etc. + */ +fun extractParametersFromFilename(filename: String): String? = + Regex("([0-9]+(\\.[0-9]+)?)[bB]").find(filename)?.value?.uppercase() + +/** + * Try to extract quantization by looking for patterns like Q4_0, Q5_K_M, etc. + */ +fun extractQuantizationFromFilename(filename: String) = + listOf( + Regex("[qQ][0-9]_[0-9]"), + Regex("[qQ][0-9]_[kK]_[mM]"), + Regex("[qQ][0-9]_[kK]"), + Regex("[qQ][0-9][fF](16|32)") + ).firstNotNullOfOrNull { + it.find(filename)?.value?.uppercase() + } + +/** + * Try to extract model type (Llama, Mistral, etc.) + * + * TODO-han.yin: Replace with GGUF parsing, also to be moved into the util object + */ +fun extractModelTypeFromFilename(filename: String): String? { + val lowerFilename = filename.lowercase() + return listOf("llama", "mistral", "phi", "qwen", "falcon", "mpt") + .firstNotNullOfOrNull { type -> + if (lowerFilename.contains(type)) { + type.replaceFirstChar { + if (it.isLowerCase()) it.titlecase(Locale.ROOT) else it.toString() + } + } else { null } + } +}