data: add hand crafted rules to filter the models fetched from HuggingFace API
This commit is contained in:
parent
f1269f4d39
commit
b1bcb8126c
|
|
@ -83,7 +83,7 @@ interface ModelRepository {
|
|||
/**
|
||||
* Search models on HuggingFace
|
||||
*/
|
||||
suspend fun searchHuggingFaceModels(limit: Int = 20): Result<List<HuggingFaceModel>>
|
||||
suspend fun searchHuggingFaceModels(limit: Int): Result<List<HuggingFaceModel>>
|
||||
|
||||
/**
|
||||
* Obtain the model details from HuggingFace
|
||||
|
|
|
|||
|
|
@ -5,7 +5,8 @@ import androidx.core.net.toUri
|
|||
import com.example.llama.di.HUGGINGFACE_HOST
|
||||
import java.util.Date
|
||||
|
||||
private const val FILE_EXTENSION_GGUF = ".gguf"
|
||||
private const val FILE_EXTENSION_GGUF = ".GGUF"
|
||||
private val QUANTIZATION_Q4_0 = arrayOf("Q4_0", "Q4-0")
|
||||
|
||||
data class HuggingFaceModel(
|
||||
val _id: String,
|
||||
|
|
@ -34,10 +35,23 @@ data class HuggingFaceModel(
|
|||
val rfilename: String,
|
||||
)
|
||||
|
||||
fun getGgufFilename(): String? =
|
||||
siblings.map { it.rfilename }.firstOrNull { it.endsWith(FILE_EXTENSION_GGUF) }
|
||||
fun getGgufFilename(keywords: Array<String> = QUANTIZATION_Q4_0): String? =
|
||||
siblings.map { it.rfilename }
|
||||
.filter {
|
||||
it.endsWith(FILE_EXTENSION_GGUF, ignoreCase = true) }
|
||||
.firstOrNull { filename ->
|
||||
keywords.any { filename.contains(it, ignoreCase = true) }
|
||||
}
|
||||
|
||||
fun toDownloadInfo() = getGgufFilename()?.let { HuggingFaceDownloadInfo(_id, modelId, it) }
|
||||
fun anyFilenameContains(keywords: Array<String>): Boolean =
|
||||
siblings.map { it.rfilename }
|
||||
.any { filename ->
|
||||
keywords.any { filename.contains(it, ignoreCase = true) }
|
||||
}
|
||||
|
||||
fun toDownloadInfo() = getGgufFilename()?.let {
|
||||
HuggingFaceDownloadInfo(_id, modelId, it)
|
||||
}
|
||||
}
|
||||
|
||||
data class HuggingFaceDownloadInfo(
|
||||
|
|
|
|||
|
|
@ -15,7 +15,9 @@ import javax.inject.Singleton
|
|||
private const val QUERY_Q4_0_GGUF = "gguf q4_0"
|
||||
private const val FILTER_TEXT_GENERATION = "text-generation"
|
||||
private const val SORT_BY_DOWNLOADS = "downloads"
|
||||
private const val SEARCH_RESULT_LIMIT = 20
|
||||
private const val SEARCH_RESULT_LIMIT = 30
|
||||
|
||||
private val INVALID_KEYWORDS = arrayOf("-of-", "split", "70B", "30B", "27B", "14B", "13B", "12B")
|
||||
|
||||
interface HuggingFaceRemoteDataSource {
|
||||
/**
|
||||
|
|
@ -67,11 +69,18 @@ class HuggingFaceRemoteDataSourceImpl @Inject constructor(
|
|||
direction = direction,
|
||||
limit = limit,
|
||||
full = full,
|
||||
).filter {
|
||||
it.gated != true && it.private != true && it.getGgufFilename() != null
|
||||
}.let {
|
||||
if (it.isEmpty()) Result.failure(FileNotFoundException()) else Result.success(it)
|
||||
}
|
||||
)
|
||||
.filterNot { it.gated || it.private }
|
||||
.filterNot {
|
||||
it.getGgufFilename().let { filename ->
|
||||
filename.isNullOrBlank() || INVALID_KEYWORDS.any {
|
||||
filename.contains(it, ignoreCase = true)
|
||||
}
|
||||
}
|
||||
}.let {
|
||||
if (it.isEmpty()) Result.failure(FileNotFoundException())
|
||||
else Result.success(it)
|
||||
}
|
||||
} catch (e: Exception) {
|
||||
Log.e(TAG, "Error searching for models on HuggingFace: ${e.message}")
|
||||
Result.failure(e)
|
||||
|
|
|
|||
|
|
@ -442,8 +442,11 @@ private fun ImportFromHuggingFaceDialog(
|
|||
) {
|
||||
Text(
|
||||
modifier = Modifier.fillMaxWidth(),
|
||||
text = models?.let { "Select a model to download:" }
|
||||
?: "Searching on HuggingFace for models available for direct download...",
|
||||
text = models?.let {
|
||||
"These open-source models from Hugging Face are ungated and free for everyone.\n\n" +
|
||||
"Please use them at your own discretion.\n\n" +
|
||||
"Select a model and tap download button:"
|
||||
} ?: "Searching on HuggingFace for ungated models available for free downloading...",
|
||||
style = MaterialTheme.typography.bodyLarge,
|
||||
textAlign = TextAlign.Start,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -182,14 +182,14 @@ class ModelsManagementViewModel @Inject constructor(
|
|||
/**
|
||||
* Query models on HuggingFace available for download even without signing in
|
||||
*/
|
||||
fun queryModelsFromHuggingFace() {
|
||||
fun queryModelsFromHuggingFace(cap: Int = FETCH_HUGGINGFACE_MODELS_CAP_SIZE) {
|
||||
huggingFaceQueryJob = viewModelScope.launch {
|
||||
_managementState.emit(Download.Querying)
|
||||
try {
|
||||
modelRepository.searchHuggingFaceModels().fold(
|
||||
modelRepository.searchHuggingFaceModels(FETCH_HUGGINGFACE_MODELS_LIMIT_SIZE).fold(
|
||||
onSuccess = { models ->
|
||||
Log.d(TAG, "Fetched ${models.size} models from HuggingFace:")
|
||||
_managementState.emit(Download.Ready(models))
|
||||
Log.d(TAG, "Fetched ${models.size} models from HuggingFace, capped by $cap:\n")
|
||||
_managementState.emit(Download.Ready(models.take(cap)))
|
||||
},
|
||||
onFailure = { throw it }
|
||||
)
|
||||
|
|
@ -303,6 +303,8 @@ class ModelsManagementViewModel @Inject constructor(
|
|||
companion object {
|
||||
private val TAG = ModelsManagementViewModel::class.java.simpleName
|
||||
|
||||
private const val FETCH_HUGGINGFACE_MODELS_LIMIT_SIZE = 50
|
||||
private const val FETCH_HUGGINGFACE_MODELS_CAP_SIZE = 12
|
||||
private const val DELETE_SUCCESS_RESET_TIMEOUT_MS = 1000L
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue