diff --git a/examples/llama.android/app/src/main/java/com/example/llama/data/repo/ModelRepository.kt b/examples/llama.android/app/src/main/java/com/example/llama/data/repo/ModelRepository.kt index d915672baf..4e61bb852a 100644 --- a/examples/llama.android/app/src/main/java/com/example/llama/data/repo/ModelRepository.kt +++ b/examples/llama.android/app/src/main/java/com/example/llama/data/repo/ModelRepository.kt @@ -83,7 +83,7 @@ interface ModelRepository { /** * Search models on HuggingFace */ - suspend fun searchHuggingFaceModels(limit: Int = 20): Result> + suspend fun searchHuggingFaceModels(limit: Int): Result> /** * Obtain the model details from HuggingFace diff --git a/examples/llama.android/app/src/main/java/com/example/llama/data/source/remote/HuggingFaceModel.kt b/examples/llama.android/app/src/main/java/com/example/llama/data/source/remote/HuggingFaceModel.kt index 350a1c6505..ff7619fc3d 100644 --- a/examples/llama.android/app/src/main/java/com/example/llama/data/source/remote/HuggingFaceModel.kt +++ b/examples/llama.android/app/src/main/java/com/example/llama/data/source/remote/HuggingFaceModel.kt @@ -5,7 +5,8 @@ import androidx.core.net.toUri import com.example.llama.di.HUGGINGFACE_HOST import java.util.Date -private const val FILE_EXTENSION_GGUF = ".gguf" +private const val FILE_EXTENSION_GGUF = ".GGUF" +private val QUANTIZATION_Q4_0 = arrayOf("Q4_0", "Q4-0") data class HuggingFaceModel( val _id: String, @@ -34,10 +35,23 @@ data class HuggingFaceModel( val rfilename: String, ) - fun getGgufFilename(): String? = - siblings.map { it.rfilename }.firstOrNull { it.endsWith(FILE_EXTENSION_GGUF) } + fun getGgufFilename(keywords: Array = QUANTIZATION_Q4_0): String? = + siblings.map { it.rfilename } + .filter { + it.endsWith(FILE_EXTENSION_GGUF, ignoreCase = true) } + .firstOrNull { filename -> + keywords.any { filename.contains(it, ignoreCase = true) } + } - fun toDownloadInfo() = getGgufFilename()?.let { HuggingFaceDownloadInfo(_id, modelId, it) } + fun anyFilenameContains(keywords: Array): Boolean = + siblings.map { it.rfilename } + .any { filename -> + keywords.any { filename.contains(it, ignoreCase = true) } + } + + fun toDownloadInfo() = getGgufFilename()?.let { + HuggingFaceDownloadInfo(_id, modelId, it) + } } data class HuggingFaceDownloadInfo( diff --git a/examples/llama.android/app/src/main/java/com/example/llama/data/source/remote/HuggingFaceRemoteDataSource.kt b/examples/llama.android/app/src/main/java/com/example/llama/data/source/remote/HuggingFaceRemoteDataSource.kt index 9d5fcffb36..44e1723726 100644 --- a/examples/llama.android/app/src/main/java/com/example/llama/data/source/remote/HuggingFaceRemoteDataSource.kt +++ b/examples/llama.android/app/src/main/java/com/example/llama/data/source/remote/HuggingFaceRemoteDataSource.kt @@ -15,7 +15,9 @@ import javax.inject.Singleton private const val QUERY_Q4_0_GGUF = "gguf q4_0" private const val FILTER_TEXT_GENERATION = "text-generation" private const val SORT_BY_DOWNLOADS = "downloads" -private const val SEARCH_RESULT_LIMIT = 20 +private const val SEARCH_RESULT_LIMIT = 30 + +private val INVALID_KEYWORDS = arrayOf("-of-", "split", "70B", "30B", "27B", "14B", "13B", "12B") interface HuggingFaceRemoteDataSource { /** @@ -67,11 +69,18 @@ class HuggingFaceRemoteDataSourceImpl @Inject constructor( direction = direction, limit = limit, full = full, - ).filter { - it.gated != true && it.private != true && it.getGgufFilename() != null - }.let { - if (it.isEmpty()) Result.failure(FileNotFoundException()) else Result.success(it) - } + ) + .filterNot { it.gated || it.private } + .filterNot { + it.getGgufFilename().let { filename -> + filename.isNullOrBlank() || INVALID_KEYWORDS.any { + filename.contains(it, ignoreCase = true) + } + } + }.let { + if (it.isEmpty()) Result.failure(FileNotFoundException()) + else Result.success(it) + } } catch (e: Exception) { Log.e(TAG, "Error searching for models on HuggingFace: ${e.message}") Result.failure(e) diff --git a/examples/llama.android/app/src/main/java/com/example/llama/ui/screens/ModelsManagementAndDeletingScreen.kt b/examples/llama.android/app/src/main/java/com/example/llama/ui/screens/ModelsManagementAndDeletingScreen.kt index 714a19c814..2a0cf840b1 100644 --- a/examples/llama.android/app/src/main/java/com/example/llama/ui/screens/ModelsManagementAndDeletingScreen.kt +++ b/examples/llama.android/app/src/main/java/com/example/llama/ui/screens/ModelsManagementAndDeletingScreen.kt @@ -442,8 +442,11 @@ private fun ImportFromHuggingFaceDialog( ) { Text( modifier = Modifier.fillMaxWidth(), - text = models?.let { "Select a model to download:" } - ?: "Searching on HuggingFace for models available for direct download...", + text = models?.let { + "These open-source models from Hugging Face are ungated and free for everyone.\n\n" + + "Please use them at your own discretion.\n\n" + + "Select a model and tap download button:" + } ?: "Searching on HuggingFace for ungated models available for free downloading...", style = MaterialTheme.typography.bodyLarge, textAlign = TextAlign.Start, ) diff --git a/examples/llama.android/app/src/main/java/com/example/llama/viewmodel/ModelsManagementViewModel.kt b/examples/llama.android/app/src/main/java/com/example/llama/viewmodel/ModelsManagementViewModel.kt index ac6435a8fd..7283bd20d9 100644 --- a/examples/llama.android/app/src/main/java/com/example/llama/viewmodel/ModelsManagementViewModel.kt +++ b/examples/llama.android/app/src/main/java/com/example/llama/viewmodel/ModelsManagementViewModel.kt @@ -182,14 +182,14 @@ class ModelsManagementViewModel @Inject constructor( /** * Query models on HuggingFace available for download even without signing in */ - fun queryModelsFromHuggingFace() { + fun queryModelsFromHuggingFace(cap: Int = FETCH_HUGGINGFACE_MODELS_CAP_SIZE) { huggingFaceQueryJob = viewModelScope.launch { _managementState.emit(Download.Querying) try { - modelRepository.searchHuggingFaceModels().fold( + modelRepository.searchHuggingFaceModels(FETCH_HUGGINGFACE_MODELS_LIMIT_SIZE).fold( onSuccess = { models -> - Log.d(TAG, "Fetched ${models.size} models from HuggingFace:") - _managementState.emit(Download.Ready(models)) + Log.d(TAG, "Fetched ${models.size} models from HuggingFace, capped by $cap:\n") + _managementState.emit(Download.Ready(models.take(cap))) }, onFailure = { throw it } ) @@ -303,6 +303,8 @@ class ModelsManagementViewModel @Inject constructor( companion object { private val TAG = ModelsManagementViewModel::class.java.simpleName + private const val FETCH_HUGGINGFACE_MODELS_LIMIT_SIZE = 50 + private const val FETCH_HUGGINGFACE_MODELS_CAP_SIZE = 12 private const val DELETE_SUCCESS_RESET_TIMEOUT_MS = 1000L } }