data: add hand-crafted rules to filter the models fetched from HuggingFace API

This commit is contained in:
Han Yin 2025-08-31 00:05:42 -07:00
parent f1269f4d39
commit b1bcb8126c
5 changed files with 45 additions and 17 deletions

View File

@ -83,7 +83,7 @@ interface ModelRepository {
/** /**
* Search models on HuggingFace * Search models on HuggingFace
*/ */
suspend fun searchHuggingFaceModels(limit: Int = 20): Result<List<HuggingFaceModel>> suspend fun searchHuggingFaceModels(limit: Int): Result<List<HuggingFaceModel>>
/** /**
* Obtain the model details from HuggingFace * Obtain the model details from HuggingFace

View File

@ -5,7 +5,8 @@ import androidx.core.net.toUri
import com.example.llama.di.HUGGINGFACE_HOST import com.example.llama.di.HUGGINGFACE_HOST
import java.util.Date import java.util.Date
private const val FILE_EXTENSION_GGUF = ".gguf" private const val FILE_EXTENSION_GGUF = ".GGUF"
private val QUANTIZATION_Q4_0 = arrayOf("Q4_0", "Q4-0")
data class HuggingFaceModel( data class HuggingFaceModel(
val _id: String, val _id: String,
@ -34,10 +35,23 @@ data class HuggingFaceModel(
val rfilename: String, val rfilename: String,
) )
fun getGgufFilename(): String? = fun getGgufFilename(keywords: Array<String> = QUANTIZATION_Q4_0): String? =
siblings.map { it.rfilename }.firstOrNull { it.endsWith(FILE_EXTENSION_GGUF) } siblings.map { it.rfilename }
.filter {
it.endsWith(FILE_EXTENSION_GGUF, ignoreCase = true) }
.firstOrNull { filename ->
keywords.any { filename.contains(it, ignoreCase = true) }
}
fun toDownloadInfo() = getGgufFilename()?.let { HuggingFaceDownloadInfo(_id, modelId, it) } fun anyFilenameContains(keywords: Array<String>): Boolean =
siblings.map { it.rfilename }
.any { filename ->
keywords.any { filename.contains(it, ignoreCase = true) }
}
fun toDownloadInfo() = getGgufFilename()?.let {
HuggingFaceDownloadInfo(_id, modelId, it)
}
} }
data class HuggingFaceDownloadInfo( data class HuggingFaceDownloadInfo(

View File

@ -15,7 +15,9 @@ import javax.inject.Singleton
private const val QUERY_Q4_0_GGUF = "gguf q4_0" private const val QUERY_Q4_0_GGUF = "gguf q4_0"
private const val FILTER_TEXT_GENERATION = "text-generation" private const val FILTER_TEXT_GENERATION = "text-generation"
private const val SORT_BY_DOWNLOADS = "downloads" private const val SORT_BY_DOWNLOADS = "downloads"
private const val SEARCH_RESULT_LIMIT = 20 private const val SEARCH_RESULT_LIMIT = 30
private val INVALID_KEYWORDS = arrayOf("-of-", "split", "70B", "30B", "27B", "14B", "13B", "12B")
interface HuggingFaceRemoteDataSource { interface HuggingFaceRemoteDataSource {
/** /**
@ -67,11 +69,18 @@ class HuggingFaceRemoteDataSourceImpl @Inject constructor(
direction = direction, direction = direction,
limit = limit, limit = limit,
full = full, full = full,
).filter { )
it.gated != true && it.private != true && it.getGgufFilename() != null .filterNot { it.gated || it.private }
}.let { .filterNot {
if (it.isEmpty()) Result.failure(FileNotFoundException()) else Result.success(it) it.getGgufFilename().let { filename ->
} filename.isNullOrBlank() || INVALID_KEYWORDS.any {
filename.contains(it, ignoreCase = true)
}
}
}.let {
if (it.isEmpty()) Result.failure(FileNotFoundException())
else Result.success(it)
}
} catch (e: Exception) { } catch (e: Exception) {
Log.e(TAG, "Error searching for models on HuggingFace: ${e.message}") Log.e(TAG, "Error searching for models on HuggingFace: ${e.message}")
Result.failure(e) Result.failure(e)

View File

@ -442,8 +442,11 @@ private fun ImportFromHuggingFaceDialog(
) { ) {
Text( Text(
modifier = Modifier.fillMaxWidth(), modifier = Modifier.fillMaxWidth(),
text = models?.let { "Select a model to download:" } text = models?.let {
?: "Searching on HuggingFace for models available for direct download...", "These open-source models from Hugging Face are ungated and free for everyone.\n\n" +
"Please use them at your own discretion.\n\n" +
"Select a model and tap download button:"
} ?: "Searching on HuggingFace for ungated models available for free downloading...",
style = MaterialTheme.typography.bodyLarge, style = MaterialTheme.typography.bodyLarge,
textAlign = TextAlign.Start, textAlign = TextAlign.Start,
) )

View File

@ -182,14 +182,14 @@ class ModelsManagementViewModel @Inject constructor(
/** /**
* Query models on HuggingFace available for download even without signing in * Query models on HuggingFace available for download even without signing in
*/ */
fun queryModelsFromHuggingFace() { fun queryModelsFromHuggingFace(cap: Int = FETCH_HUGGINGFACE_MODELS_CAP_SIZE) {
huggingFaceQueryJob = viewModelScope.launch { huggingFaceQueryJob = viewModelScope.launch {
_managementState.emit(Download.Querying) _managementState.emit(Download.Querying)
try { try {
modelRepository.searchHuggingFaceModels().fold( modelRepository.searchHuggingFaceModels(FETCH_HUGGINGFACE_MODELS_LIMIT_SIZE).fold(
onSuccess = { models -> onSuccess = { models ->
Log.d(TAG, "Fetched ${models.size} models from HuggingFace:") Log.d(TAG, "Fetched ${models.size} models from HuggingFace, capped by $cap:\n")
_managementState.emit(Download.Ready(models)) _managementState.emit(Download.Ready(models.take(cap)))
}, },
onFailure = { throw it } onFailure = { throw it }
) )
@ -303,6 +303,8 @@ class ModelsManagementViewModel @Inject constructor(
companion object { companion object {
private val TAG = ModelsManagementViewModel::class.java.simpleName private val TAG = ModelsManagementViewModel::class.java.simpleName
private const val FETCH_HUGGINGFACE_MODELS_LIMIT_SIZE = 50
private const val FETCH_HUGGINGFACE_MODELS_CAP_SIZE = 12
private const val DELETE_SUCCESS_RESET_TIMEOUT_MS = 1000L private const val DELETE_SUCCESS_RESET_TIMEOUT_MS = 1000L
} }
} }