diff --git a/examples/llama.android/app/src/main/java/com/example/llama/data/repo/ModelRepository.kt b/examples/llama.android/app/src/main/java/com/example/llama/data/repo/ModelRepository.kt index 4e61bb852a..1bed5b8ae4 100644 --- a/examples/llama.android/app/src/main/java/com/example/llama/data/repo/ModelRepository.kt +++ b/examples/llama.android/app/src/main/java/com/example/llama/data/repo/ModelRepository.kt @@ -80,6 +80,11 @@ interface ModelRepository { suspend fun deleteModel(modelId: String) suspend fun deleteModels(modelIds: List) + /** + * Fetch details of preselected models + */ + suspend fun fetchPreselectedHuggingFaceModels(): List + /** * Search models on HuggingFace */ @@ -316,6 +321,10 @@ class ModelRepositoryImpl @Inject constructor( } } + override suspend fun fetchPreselectedHuggingFaceModels() = withContext(Dispatchers.IO) { + huggingFaceRemoteDataSource.fetchPreselectedModels() + } + override suspend fun searchHuggingFaceModels( limit: Int ) = withContext(Dispatchers.IO) { diff --git a/examples/llama.android/app/src/main/java/com/example/llama/data/source/remote/HuggingFaceModel.kt b/examples/llama.android/app/src/main/java/com/example/llama/data/source/remote/HuggingFaceModel.kt index ff7619fc3d..9dffb7da13 100644 --- a/examples/llama.android/app/src/main/java/com/example/llama/data/source/remote/HuggingFaceModel.kt +++ b/examples/llama.android/app/src/main/java/com/example/llama/data/source/remote/HuggingFaceModel.kt @@ -5,8 +5,8 @@ import androidx.core.net.toUri import com.example.llama.di.HUGGINGFACE_HOST import java.util.Date -private const val FILE_EXTENSION_GGUF = ".GGUF" -private val QUANTIZATION_Q4_0 = arrayOf("Q4_0", "Q4-0") +internal const val FILE_EXTENSION_GGUF = ".GGUF" +internal val QUANTIZATION_Q4_0 = arrayOf("Q4_0", "Q4-0") data class HuggingFaceModel( val _id: String, @@ -31,10 +31,6 @@ data class HuggingFaceModel( val library_name: String?, ) { - data class Sibling( - val rfilename: String, - ) - fun getGgufFilename(keywords: Array = QUANTIZATION_Q4_0): String? = siblings.map { it.rfilename } .filter { diff --git a/examples/llama.android/app/src/main/java/com/example/llama/data/source/remote/HuggingFaceModelDetails.kt b/examples/llama.android/app/src/main/java/com/example/llama/data/source/remote/HuggingFaceModelDetails.kt index ab7dcc38eb..c30780cf89 100644 --- a/examples/llama.android/app/src/main/java/com/example/llama/data/source/remote/HuggingFaceModelDetails.kt +++ b/examples/llama.android/app/src/main/java/com/example/llama/data/source/remote/HuggingFaceModelDetails.kt @@ -1,6 +1,7 @@ package com.example.llama.data.source.remote import java.util.Date +import kotlin.String data class HuggingFaceModelDetails( val _id: String, @@ -8,51 +9,82 @@ data class HuggingFaceModelDetails( val modelId: String, val author: String, - val createdAt: Date?, - val lastModified: Date?, + val createdAt: Date, + val lastModified: Date, - val library_name: String?, - val pipeline_tag: String?, - val tags: List?, + val pipeline_tag: String, + val tags: List, - val private: Boolean?, + val private: Boolean, + val gated: Boolean, val disabled: Boolean?, - val gated: Boolean?, - val likes: Int?, - val downloads: Int?, + val likes: Int, + val downloads: Int, val usedStorage: Long?, - val sha: String?, + val sha: String, + val siblings: List, val cardData: CardData?, - val siblings: List?, val widgetData: List?, val gguf: Gguf?, + + val library_name: String?, ) { - data class Sibling( - val rfilename: String, + fun getGgufFilename(keywords: Array = QUANTIZATION_Q4_0): String? = + siblings.map { it.rfilename } + .filter { + it.endsWith(FILE_EXTENSION_GGUF, ignoreCase = true) } + .firstOrNull { filename -> + keywords.any { filename.contains(it, ignoreCase = true) } + } + + fun toModel() = HuggingFaceModel( + _id = this._id, + id = this.id, + modelId = this.modelId, + author = this.author, + createdAt = this.createdAt, + lastModified = this.lastModified, + pipeline_tag = this.pipeline_tag, + tags = this.tags, + private = this.private, + gated = this.gated, + likes = this.likes, + downloads = this.downloads, + sha = this.sha, + siblings = this.siblings.map { Sibling(it.rfilename) }, + library_name = this.library_name, ) - data class Gguf( - val total: Long?, - val architecture: String?, - val context_length: Int?, - val chat_template: String?, - val bos_token: String?, - val eos_token: String?, - ) - - data class CardData( - val base_model: String?, - val language: List?, - val license: String?, - val pipeline_tag: String?, - val tags: List?, - ) - - data class WidgetData( - val text: String - ) + fun toDownloadInfo() = getGgufFilename()?.let { + HuggingFaceDownloadInfo(_id, modelId, it) + } } + +data class Sibling( + val rfilename: String, +) + +data class Gguf( + val total: Long?, + val architecture: String?, + val context_length: Int?, + val chat_template: String?, + val bos_token: String?, + val eos_token: String?, +) + +data class CardData( + val base_model: String?, + val language: List?, + val license: String?, + val pipeline_tag: String?, + val tags: List?, +) + +data class WidgetData( + val text: String +) diff --git a/examples/llama.android/app/src/main/java/com/example/llama/data/source/remote/HuggingFaceRemoteDataSource.kt b/examples/llama.android/app/src/main/java/com/example/llama/data/source/remote/HuggingFaceRemoteDataSource.kt index 44e1723726..33953f830f 100644 --- a/examples/llama.android/app/src/main/java/com/example/llama/data/source/remote/HuggingFaceRemoteDataSource.kt +++ b/examples/llama.android/app/src/main/java/com/example/llama/data/source/remote/HuggingFaceRemoteDataSource.kt @@ -5,12 +5,22 @@ import android.content.Context import android.os.Environment import android.util.Log import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.async +import kotlinx.coroutines.awaitAll import kotlinx.coroutines.delay +import kotlinx.coroutines.supervisorScope +import kotlinx.coroutines.sync.Semaphore +import kotlinx.coroutines.sync.withPermit import kotlinx.coroutines.withContext +import retrofit2.HttpException import java.io.FileNotFoundException import java.io.IOException +import java.net.SocketTimeoutException +import java.net.UnknownHostException import javax.inject.Inject import javax.inject.Singleton +import kotlin.coroutines.cancellation.CancellationException +import kotlin.math.ceil private const val QUERY_Q4_0_GGUF = "gguf q4_0" private const val FILTER_TEXT_GENERATION = "text-generation" @@ -19,7 +29,25 @@ private const val SEARCH_RESULT_LIMIT = 30 private val INVALID_KEYWORDS = arrayOf("-of-", "split", "70B", "30B", "27B", "14B", "13B", "12B") +private val PRESELECTED_MODEL_IDS = listOf( + "unsloth/gemma-3-1b-it-GGUF", + "unsloth/gemma-3-4b-it-GGUF", + "bartowski/Llama-3.2-1B-Instruct-GGUF", + "bartowski/Llama-3.2-3B-Instruct-GGUF", + "Qwen/Qwen2.5-3B-Instruct-GGUF", + "gaianet/Phi-4-mini-instruct-GGUF", + "bartowski/granite-3.0-2b-instruct-GGUF", + "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF", +) + interface HuggingFaceRemoteDataSource { + + suspend fun fetchPreselectedModels( + ids: List = PRESELECTED_MODEL_IDS, + parallelCount: Int = 3, + quorum: Float = 0.5f, + ): List + /** * Query openly available Q4_0 GGUF models on HuggingFace */ @@ -53,6 +81,64 @@ class HuggingFaceRemoteDataSourceImpl @Inject constructor( private val apiService: HuggingFaceApiService ) : HuggingFaceRemoteDataSource { + override suspend fun fetchPreselectedModels( + ids: List, + parallelCount: Int, + quorum: Float, + ): List = withContext(Dispatchers.IO) { + val successes = mutableListOf() + val failures = mutableListOf() + + val sem = Semaphore(parallelCount) + supervisorScope { + ids.map { id -> + async { + sem.withPermit { + runCatching { getModelDetails(id) } + .onSuccess { synchronized(successes) { successes += it } } + .onFailure { t -> + if (t is CancellationException) throw t + synchronized(failures) { failures += t } + } + } + } + }.awaitAll() + } + + val total = ids.size + val failed = failures.size + val ok = successes.size + val shouldThrow = failed >= ceil(total * quorum).toInt() + + if (!shouldThrow) return@withContext successes.toList() + + // 1. No Network + if (failures.count { it is UnknownHostException } >= ceil(failed * 0.5).toInt()) { + throw UnknownHostException() + } + + // 2. Time out + if (failures.count { it is SocketTimeoutException } >= ceil(failed * 0.5).toInt()) { + throw SocketTimeoutException() + } + + // 3. known error codes: 404/410/204 + val http404ish = failures.count { (it as? HttpException)?.code() in listOf(404, 410, 204) } + if (ok == 0 && (failed > 0) && (http404ish >= ceil(failed * 0.5).toInt() || failed == total)) { + throw FileNotFoundException() + } + + // 4. Unknown issues + val ioMajority = failures.count { + it is IOException && it !is UnknownHostException && it !is SocketTimeoutException + } >= ceil(failed * 0.5).toInt() + if (ioMajority) { + throw IOException(failures.first { it is IOException }.message) + } + + successes + } + override suspend fun searchModels( query: String?, filter: String?, diff --git a/examples/llama.android/app/src/main/java/com/example/llama/di/AppModule.kt b/examples/llama.android/app/src/main/java/com/example/llama/di/AppModule.kt index 0ed08226a0..272874ea3d 100644 --- a/examples/llama.android/app/src/main/java/com/example/llama/di/AppModule.kt +++ b/examples/llama.android/app/src/main/java/com/example/llama/di/AppModule.kt @@ -33,7 +33,6 @@ import dagger.hilt.InstallIn import dagger.hilt.android.qualifiers.ApplicationContext import dagger.hilt.components.SingletonComponent import okhttp3.OkHttpClient -import okhttp3.logging.HttpLoggingInterceptor import retrofit2.Retrofit import retrofit2.converter.gson.GsonConverterFactory import javax.inject.Singleton @@ -109,10 +108,7 @@ internal abstract class AppModule { @Provides @Singleton - fun provideOkhttpClient() = OkHttpClient.Builder() - .addInterceptor(HttpLoggingInterceptor().apply { - level = HttpLoggingInterceptor.Level.BODY - }).build() + fun provideOkhttpClient() = OkHttpClient.Builder().build() @Provides @Singleton diff --git a/examples/llama.android/app/src/main/java/com/example/llama/ui/screens/ModelsManagementAndDeletingScreen.kt b/examples/llama.android/app/src/main/java/com/example/llama/ui/screens/ModelsManagementAndDeletingScreen.kt index 39fc63496b..a1bf43368a 100644 --- a/examples/llama.android/app/src/main/java/com/example/llama/ui/screens/ModelsManagementAndDeletingScreen.kt +++ b/examples/llama.android/app/src/main/java/com/example/llama/ui/screens/ModelsManagementAndDeletingScreen.kt @@ -245,16 +245,10 @@ fun ModelsManagementAndDeletingScreen( } is Download.Dispatched -> { - LaunchedEffect(state) { - onScaffoldEvent( - ScaffoldEvent.ShowSnackbar( - message = "Started downloading:\n${state.downloadInfo.modelId}", - duration = SnackbarDuration.Long, - ) - ) - - managementViewModel.resetManagementState() - } + DownloadHuggingFaceDispatchedDialog( + state.downloadInfo.modelId, + onConfirm = { managementViewModel.resetManagementState() } + ) } is Download.Completed -> { @@ -627,6 +621,33 @@ fun HuggingFaceModelListItem( } } +@Composable +private fun DownloadHuggingFaceDispatchedDialog( + modelId: String, + onConfirm: () -> Unit, +) { + AlertDialog( + onDismissRequest = {}, + properties = DialogProperties( + dismissOnBackPress = false, + dismissOnClickOutside = false + ), + title = {}, + text = { + InfoView( + title = "Download has started", + icon = Icons.Default.Download, + message = "Your Android system download manager has started downloading the model: $modelId.\n\n" + + "You can track its progress in your notification drawer.\n" + + "Feel free to stay on this screen, or come back to import it after complete.", + ) + }, + confirmButton = { + Button(onClick = onConfirm) { Text("Okay") } + }, + ) +} + @Composable private fun FirstModelImportSuccessDialog( onConfirm: () -> Unit, diff --git a/examples/llama.android/app/src/main/java/com/example/llama/viewmodel/ModelsManagementViewModel.kt b/examples/llama.android/app/src/main/java/com/example/llama/viewmodel/ModelsManagementViewModel.kt index 7283bd20d9..212d0f6cb8 100644 --- a/examples/llama.android/app/src/main/java/com/example/llama/viewmodel/ModelsManagementViewModel.kt +++ b/examples/llama.android/app/src/main/java/com/example/llama/viewmodel/ModelsManagementViewModel.kt @@ -16,6 +16,7 @@ import com.example.llama.data.repo.InsufficientStorageException import com.example.llama.data.repo.ModelRepository import com.example.llama.data.source.remote.HuggingFaceDownloadInfo import com.example.llama.data.source.remote.HuggingFaceModel +import com.example.llama.data.source.remote.HuggingFaceModelDetails import com.example.llama.util.formatFileByteSize import com.example.llama.util.getFileNameFromUri import com.example.llama.util.getFileSizeFromUri @@ -186,13 +187,8 @@ class ModelsManagementViewModel @Inject constructor( huggingFaceQueryJob = viewModelScope.launch { _managementState.emit(Download.Querying) try { - modelRepository.searchHuggingFaceModels(FETCH_HUGGINGFACE_MODELS_LIMIT_SIZE).fold( - onSuccess = { models -> - Log.d(TAG, "Fetched ${models.size} models from HuggingFace, capped by $cap:\n") - _managementState.emit(Download.Ready(models.take(cap))) - }, - onFailure = { throw it } - ) + val models = modelRepository.fetchPreselectedHuggingFaceModels().map(HuggingFaceModelDetails::toModel) + _managementState.emit(Download.Ready(models)) } catch (_: CancellationException) { // no-op } catch (_: UnknownHostException) {