data: add a util file for extracting file name & size and model metadata
This commit is contained in:
parent
290a6bfebe
commit
adfbfe3ffb
|
|
@ -3,12 +3,16 @@ package com.example.llama.revamp.data.repository
|
|||
import android.content.Context
|
||||
import android.net.Uri
|
||||
import android.os.StatFs
|
||||
import android.provider.OpenableColumns
|
||||
import android.util.Log
|
||||
import com.example.llama.revamp.data.local.ModelDao
|
||||
import com.example.llama.revamp.data.local.ModelEntity
|
||||
import com.example.llama.revamp.data.model.ModelInfo
|
||||
import com.example.llama.revamp.data.repository.ModelRepository.ImportProgressTracker
|
||||
import com.example.llama.revamp.util.extractModelTypeFromFilename
|
||||
import com.example.llama.revamp.util.extractParametersFromFilename
|
||||
import com.example.llama.revamp.util.extractQuantizationFromFilename
|
||||
import com.example.llama.revamp.util.getFileNameFromUri
|
||||
import com.example.llama.revamp.util.getFileSizeFromUri
|
||||
import dagger.hilt.android.qualifiers.ApplicationContext
|
||||
import kotlinx.coroutines.Dispatchers
|
||||
import kotlinx.coroutines.delay
|
||||
|
|
@ -29,7 +33,6 @@ import java.nio.ByteBuffer
|
|||
import java.nio.channels.Channels
|
||||
import java.nio.channels.ReadableByteChannel
|
||||
import java.nio.channels.WritableByteChannel
|
||||
import java.util.Locale
|
||||
import java.util.UUID
|
||||
import javax.inject.Inject
|
||||
import javax.inject.Singleton
|
||||
|
|
@ -41,7 +44,12 @@ interface ModelRepository {
|
|||
fun getStorageMetrics(): Flow<StorageMetrics>
|
||||
fun getModels(): Flow<List<ModelInfo>>
|
||||
|
||||
suspend fun importModel(uri: Uri, progressTracker: ImportProgressTracker? = null): ModelInfo
|
||||
suspend fun importModel(
|
||||
uri: Uri,
|
||||
name: String? = null,
|
||||
size: Long? = null,
|
||||
progressTracker: ImportProgressTracker? = null
|
||||
): ModelInfo
|
||||
|
||||
suspend fun deleteModel(modelId: String)
|
||||
suspend fun deleteModels(modelIds: List<String>)
|
||||
|
|
@ -88,10 +96,12 @@ class ModelRepositoryImpl @Inject constructor(
|
|||
|
||||
override suspend fun importModel(
|
||||
uri: Uri,
|
||||
name: String?,
|
||||
size: Long?,
|
||||
progressTracker: ImportProgressTracker?
|
||||
): ModelInfo = withContext(Dispatchers.IO) {
|
||||
val fileName = getFileNameFromUri(uri) ?: throw FileNotFoundException("Filename N/A")
|
||||
val fileSize = getFileSizeFromUri(uri) ?: throw FileNotFoundException("File size N/A")
|
||||
val fileName = name ?: getFileNameFromUri(context, uri) ?: throw FileNotFoundException("Filename N/A")
|
||||
val fileSize = size ?: getFileSizeFromUri(context, uri) ?: throw FileNotFoundException("File size N/A")
|
||||
val modelFile = File(modelsDir, fileName)
|
||||
|
||||
try {
|
||||
|
|
@ -251,69 +261,6 @@ class ModelRepositoryImpl @Inject constructor(
|
|||
val totalSpaceBytes: Long
|
||||
get() = StatFs(context.filesDir.path).totalBytes
|
||||
|
||||
private fun getFileNameFromUri(uri: Uri): String? =
|
||||
context.contentResolver.query(uri, null, null, null, null)?.use { cursor ->
|
||||
if (cursor.moveToFirst()) {
|
||||
cursor.getColumnIndex(OpenableColumns.DISPLAY_NAME).let { nameIndex ->
|
||||
if (nameIndex != -1) cursor.getString(nameIndex) else null
|
||||
}
|
||||
} else {
|
||||
null
|
||||
}
|
||||
} ?: uri.lastPathSegment
|
||||
|
||||
/**
|
||||
* Gets the file size from a content URI, or returns 0 if size is unknown.
|
||||
*/
|
||||
private fun getFileSizeFromUri(uri: Uri): Long? =
|
||||
context.contentResolver.query(uri, null, null, null, null)?.use { cursor ->
|
||||
if (cursor.moveToFirst()) {
|
||||
cursor.getColumnIndex(OpenableColumns.SIZE).let { sizeIndex ->
|
||||
if (sizeIndex != -1) cursor.getLong(sizeIndex) else null
|
||||
}
|
||||
} else {
|
||||
null
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Try to extract parameters by looking for patterns like 7B, 13B, etc.
|
||||
*
|
||||
* TODO-han.yin: Enhance and move into a utility object for unit testing
|
||||
*/
|
||||
private fun extractParametersFromFilename(filename: String): String? =
|
||||
Regex("([0-9]+(\\.[0-9]+)?)[bB]").find(filename)?.value?.uppercase()
|
||||
|
||||
/**
|
||||
* Try to extract quantization by looking for patterns like Q4_0, Q5_K_M, etc.
|
||||
*/
|
||||
private fun extractQuantizationFromFilename(filename: String) =
|
||||
listOf(
|
||||
Regex("[qQ][0-9]_[0-9]"),
|
||||
Regex("[qQ][0-9]_[kK]_[mM]"),
|
||||
Regex("[qQ][0-9]_[kK]"),
|
||||
Regex("[qQ][0-9][fF](16|32)")
|
||||
).firstNotNullOfOrNull {
|
||||
it.find(filename)?.value?.uppercase()
|
||||
}
|
||||
|
||||
/**
|
||||
* Try to extract model type (Llama, Mistral, etc.)
|
||||
*
|
||||
* TODO-han.yin: Replace with GGUF parsing, also to be moved into the util object
|
||||
*/
|
||||
private fun extractModelTypeFromFilename(filename: String): String? {
|
||||
val lowerFilename = filename.lowercase()
|
||||
return listOf("llama", "mistral", "phi", "qwen", "falcon", "mpt")
|
||||
.firstNotNullOfOrNull { type ->
|
||||
if (lowerFilename.contains(type)) {
|
||||
type.replaceFirstChar {
|
||||
if (it.isLowerCase()) it.titlecase(Locale.ROOT) else it.toString()
|
||||
}
|
||||
} else { null }
|
||||
}
|
||||
}
|
||||
|
||||
companion object {
|
||||
private val TAG = ModelRepository::class.java.simpleName
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,70 @@
|
|||
package com.example.llama.revamp.util
|
||||
|
||||
import android.content.Context
|
||||
import android.net.Uri
|
||||
import android.provider.OpenableColumns
|
||||
import java.util.Locale
|
||||
|
||||
/**
|
||||
* Gets the file name from a content URI
|
||||
*/
|
||||
fun getFileNameFromUri(context: Context, uri: Uri): String? =
|
||||
context.contentResolver.query(uri, null, null, null, null)?.use { cursor ->
|
||||
if (cursor.moveToFirst()) {
|
||||
cursor.getColumnIndex(OpenableColumns.DISPLAY_NAME).let { nameIndex ->
|
||||
if (nameIndex != -1) cursor.getString(nameIndex) else null
|
||||
}
|
||||
} else {
|
||||
null
|
||||
}
|
||||
} ?: uri.lastPathSegment
|
||||
|
||||
/**
|
||||
* Gets the file size from a content URI
|
||||
*/
|
||||
fun getFileSizeFromUri(context: Context, uri: Uri): Long? =
|
||||
context.contentResolver.query(uri, null, null, null, null)?.use { cursor ->
|
||||
if (cursor.moveToFirst()) {
|
||||
cursor.getColumnIndex(OpenableColumns.SIZE).let { sizeIndex ->
|
||||
if (sizeIndex != -1) cursor.getLong(sizeIndex) else null
|
||||
}
|
||||
} else {
|
||||
null
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Try to extract parameters by looking for patterns like 7B, 13B, etc.
|
||||
*/
|
||||
fun extractParametersFromFilename(filename: String): String? =
|
||||
Regex("([0-9]+(\\.[0-9]+)?)[bB]").find(filename)?.value?.uppercase()
|
||||
|
||||
/**
|
||||
* Try to extract quantization by looking for patterns like Q4_0, Q5_K_M, etc.
|
||||
*/
|
||||
fun extractQuantizationFromFilename(filename: String) =
|
||||
listOf(
|
||||
Regex("[qQ][0-9]_[0-9]"),
|
||||
Regex("[qQ][0-9]_[kK]_[mM]"),
|
||||
Regex("[qQ][0-9]_[kK]"),
|
||||
Regex("[qQ][0-9][fF](16|32)")
|
||||
).firstNotNullOfOrNull {
|
||||
it.find(filename)?.value?.uppercase()
|
||||
}
|
||||
|
||||
/**
|
||||
* Try to extract model type (Llama, Mistral, etc.)
|
||||
*
|
||||
* TODO-han.yin: Replace with GGUF parsing, also to be moved into the util object
|
||||
*/
|
||||
fun extractModelTypeFromFilename(filename: String): String? {
|
||||
val lowerFilename = filename.lowercase()
|
||||
return listOf("llama", "mistral", "phi", "qwen", "falcon", "mpt")
|
||||
.firstNotNullOfOrNull { type ->
|
||||
if (lowerFilename.contains(type)) {
|
||||
type.replaceFirstChar {
|
||||
if (it.isLowerCase()) it.titlecase(Locale.ROOT) else it.toString()
|
||||
}
|
||||
} else { null }
|
||||
}
|
||||
}
|
||||
Loading…
Reference in New Issue