From e0ddc37e2e7cd185feba7e03b4561009ceb79cbe Mon Sep 17 00:00:00 2001 From: Han Yin Date: Thu, 4 Sep 2025 14:14:44 -0700 Subject: [PATCH] data: sort preselected models according to device's available RAM --- .../java/com/example/llama/MainActivity.kt | 2 +- .../com/example/llama/data/model/ModelInfo.kt | 46 +-- .../llama/data/repo/ModelRepository.kt | 7 +- .../remote/HuggingFaceRemoteDataSource.kt | 244 +--------------- .../remote/HuggingFaceRemoteDataSourceImpl.kt | 272 ++++++++++++++++++ .../viewmodel/ModelsManagementViewModel.kt | 7 +- 6 files changed, 311 insertions(+), 267 deletions(-) create mode 100644 examples/llama.android/app/src/main/java/com/example/llama/data/source/remote/HuggingFaceRemoteDataSourceImpl.kt diff --git a/examples/llama.android/app/src/main/java/com/example/llama/MainActivity.kt b/examples/llama.android/app/src/main/java/com/example/llama/MainActivity.kt index 43044ed68b..19e21768da 100644 --- a/examples/llama.android/app/src/main/java/com/example/llama/MainActivity.kt +++ b/examples/llama.android/app/src/main/java/com/example/llama/MainActivity.kt @@ -344,7 +344,7 @@ fun AppContent( modelsManagementViewModel.toggleImportMenu(false) }, importFromHuggingFace = { - modelsManagementViewModel.queryModelsFromHuggingFace() + modelsManagementViewModel.queryModelsFromHuggingFace(memoryUsage) modelsManagementViewModel.toggleImportMenu(false) } ), diff --git a/examples/llama.android/app/src/main/java/com/example/llama/data/model/ModelInfo.kt b/examples/llama.android/app/src/main/java/com/example/llama/data/model/ModelInfo.kt index 757f6568b8..919816a68a 100644 --- a/examples/llama.android/app/src/main/java/com/example/llama/data/model/ModelInfo.kt +++ b/examples/llama.android/app/src/main/java/com/example/llama/data/model/ModelInfo.kt @@ -141,56 +141,56 @@ enum class ModelFilter(val displayName: String, val predicate: (ModelInfo) -> Bo }), SMALL_PARAMS("Small (1-3B parameters)", { it.metadata.basic.sizeLabel?.let { size -> - size.contains("B") && 
size.replace("B", "").toFloatOrNull()?.let { n -> n >= 1f && n <= 3f } == true + size.contains("B") && size.replace("B", "").toFloatOrNull()?.let { n -> n >= 1f && n < 4f } == true } == true }), MEDIUM_PARAMS("Medium (4-7B parameters)", { it.metadata.basic.sizeLabel?.let { size -> - size.contains("B") && size.replace("B", "").toFloatOrNull()?.let { n -> n >= 4f && n <= 7f } == true + size.contains("B") && size.replace("B", "").toFloatOrNull()?.let { n -> n >= 4f && n < 8f } == true } == true }), - LARGE_PARAMS("Large (8-13B parameters)", { + LARGE_PARAMS("Large (8-12B parameters)", { it.metadata.basic.sizeLabel?.let { size -> - size.contains("B") && size.replace("B", "").toFloatOrNull()?.let { n -> n >= 8f && n <= 13f } == true + size.contains("B") && size.replace("B", "").toFloatOrNull()?.let { n -> n >= 8f && n < 13f } == true } == true }), XLARGE_PARAMS("X-Large (>13B parameters)", { it.metadata.basic.sizeLabel?.let { size -> - size.contains("B") && size.replace("B", "").toFloatOrNull()?.let { n -> n > 13f } == true + size.contains("B") && size.replace("B", "").toFloatOrNull()?.let { n -> n >= 13f } == true } == true }), // Context length filters - TINY_CONTEXT("Tiny context (<4K)", { - it.metadata.dimensions?.contextLength?.let { it < 4096 } == true + TINY_CONTEXT("Tiny context (<4K)", { model -> + model.metadata.dimensions?.contextLength?.let { it < 4096 } == true }), - SHORT_CONTEXT("Short context (4-8K)", { - it.metadata.dimensions?.contextLength?.let { it >= 4096 && it <= 8192 } == true + SHORT_CONTEXT("Short context (4-8K)", { model -> + model.metadata.dimensions?.contextLength?.let { it >= 4096 && it <= 8192 } == true }), - MEDIUM_CONTEXT("Medium context (8-32K)", { - it.metadata.dimensions?.contextLength?.let { it > 8192 && it <= 32768 } == true + MEDIUM_CONTEXT("Medium context (8-32K)", { model -> + model.metadata.dimensions?.contextLength?.let { it > 8192 && it <= 32768 } == true }), - LONG_CONTEXT("Long context (32-128K)", { - 
it.metadata.dimensions?.contextLength?.let { it > 32768 && it <= 131072 } == true + LONG_CONTEXT("Long context (32-128K)", { model -> + model.metadata.dimensions?.contextLength?.let { it > 32768 && it <= 131072 } == true }), - XLARGE_CONTEXT("Extended context (>128K)", { - it.metadata.dimensions?.contextLength?.let { it > 131072 } == true + XLARGE_CONTEXT("Extended context (>128K)", { model -> + model.metadata.dimensions?.contextLength?.let { it > 131072 } == true }), // Quantization filters - INT2_QUANT("2-bit quantization", { - it.formattedQuantization.let { it.contains("Q2") || it.contains("IQ2") } + INT2_QUANT("2-bit quantization", { model -> + model.formattedQuantization.let { it.contains("Q2") || it.contains("IQ2") } }), - INT3_QUANT("3-bit quantization", { - it.formattedQuantization.let { it.contains("Q3") || it.contains("IQ3") } + INT3_QUANT("3-bit quantization", { model -> + model.formattedQuantization.let { it.contains("Q3") || it.contains("IQ3") } }), - INT4_QUANT("4-bit quantization", { - it.formattedQuantization.let { it.contains("Q4") || it.contains("IQ4") } + INT4_QUANT("4-bit quantization", { model -> + model.formattedQuantization.let { it.contains("Q4") || it.contains("IQ4") } }), // Special features - MULTILINGUAL("Multilingual", { - it.languages?.let { languages -> + MULTILINGUAL("Multilingual", { model -> + model.languages?.let { languages -> languages.size > 1 || languages.any { it.contains("multi", ignoreCase = true) } } == true }), diff --git a/examples/llama.android/app/src/main/java/com/example/llama/data/repo/ModelRepository.kt b/examples/llama.android/app/src/main/java/com/example/llama/data/repo/ModelRepository.kt index 1bed5b8ae4..6008d5cbec 100644 --- a/examples/llama.android/app/src/main/java/com/example/llama/data/repo/ModelRepository.kt +++ b/examples/llama.android/app/src/main/java/com/example/llama/data/repo/ModelRepository.kt @@ -16,6 +16,7 @@ import com.example.llama.data.source.remote.HuggingFaceDownloadInfo import 
com.example.llama.data.source.remote.HuggingFaceModel import com.example.llama.data.source.remote.HuggingFaceModelDetails import com.example.llama.data.source.remote.HuggingFaceRemoteDataSource +import com.example.llama.monitoring.MemoryMetrics import com.example.llama.monitoring.StorageMetrics import com.example.llama.util.formatFileByteSize import dagger.hilt.android.qualifiers.ApplicationContext @@ -83,7 +84,7 @@ interface ModelRepository { /** * Fetch details of preselected models */ - suspend fun fetchPreselectedHuggingFaceModels(): List + suspend fun fetchPreselectedHuggingFaceModels(memoryUsage: MemoryMetrics): List /** * Search models on HuggingFace @@ -321,8 +322,8 @@ class ModelRepositoryImpl @Inject constructor( } } - override suspend fun fetchPreselectedHuggingFaceModels() = withContext(Dispatchers.IO) { - huggingFaceRemoteDataSource.fetchPreselectedModels() + override suspend fun fetchPreselectedHuggingFaceModels(memoryUsage: MemoryMetrics) = withContext(Dispatchers.IO) { + huggingFaceRemoteDataSource.fetchPreselectedModels(memoryUsage) } override suspend fun searchHuggingFaceModels( diff --git a/examples/llama.android/app/src/main/java/com/example/llama/data/source/remote/HuggingFaceRemoteDataSource.kt b/examples/llama.android/app/src/main/java/com/example/llama/data/source/remote/HuggingFaceRemoteDataSource.kt index 0352651ff6..647a788682 100644 --- a/examples/llama.android/app/src/main/java/com/example/llama/data/source/remote/HuggingFaceRemoteDataSource.kt +++ b/examples/llama.android/app/src/main/java/com/example/llama/data/source/remote/HuggingFaceRemoteDataSource.kt @@ -1,27 +1,12 @@ package com.example.llama.data.source.remote -import android.app.DownloadManager import android.content.Context -import android.os.Environment -import android.util.Log -import kotlinx.coroutines.Dispatchers -import kotlinx.coroutines.async -import kotlinx.coroutines.awaitAll -import kotlinx.coroutines.delay -import kotlinx.coroutines.supervisorScope -import 
kotlinx.coroutines.sync.Semaphore -import kotlinx.coroutines.sync.withPermit -import kotlinx.coroutines.withContext -import retrofit2.HttpException -import java.io.FileNotFoundException -import java.io.IOException -import java.net.SocketTimeoutException -import java.net.UnknownHostException -import javax.inject.Inject -import javax.inject.Singleton -import kotlin.coroutines.cancellation.CancellationException -import kotlin.math.ceil +import com.example.llama.monitoring.MemoryMetrics + +/* + * HuggingFace Search API + */ private const val QUERY_Q4_0_GGUF = "gguf q4_0" private const val FILTER_TEXT_GENERATION = "text-generation" private const val SORT_BY_DOWNLOADS = "downloads" @@ -29,21 +14,10 @@ private const val SEARCH_RESULT_LIMIT = 30 private val INVALID_KEYWORDS = arrayOf("-of-", "split", "70B", "30B", "27B", "14B", "13B", "12B") -private val PRESELECTED_MODEL_IDS = listOf( - "unsloth/gemma-3-1b-it-GGUF", - "unsloth/gemma-3-4b-it-GGUF", - "bartowski/Llama-3.2-1B-Instruct-GGUF", - "bartowski/Llama-3.2-3B-Instruct-GGUF", - "Qwen/Qwen2.5-3B-Instruct-GGUF", - "gaianet/Phi-4-mini-instruct-GGUF", - "bartowski/granite-3.0-2b-instruct-GGUF", - "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF", -) - interface HuggingFaceRemoteDataSource { suspend fun fetchPreselectedModels( - ids: List = PRESELECTED_MODEL_IDS, + memoryUsage: MemoryMetrics, parallelCount: Int = 3, quorum: Float = 0.5f, ): List @@ -58,6 +32,7 @@ interface HuggingFaceRemoteDataSource { direction: String? = "-1", limit: Int? 
= SEARCH_RESULT_LIMIT, full: Boolean = true, + invalidKeywords: Array = INVALID_KEYWORDS ): Result> suspend fun getModelDetails(modelId: String): HuggingFaceModelDetails @@ -76,208 +51,3 @@ interface HuggingFaceRemoteDataSource { ): Result } -@Singleton -class HuggingFaceRemoteDataSourceImpl @Inject constructor( - private val apiService: HuggingFaceApiService -) : HuggingFaceRemoteDataSource { - - override suspend fun fetchPreselectedModels( - ids: List, - parallelCount: Int, - quorum: Float, - ): List = withContext(Dispatchers.IO) { - val sem = Semaphore(parallelCount) - val results = supervisorScope { - ids.map { id -> - async { - sem.withPermit { - try { - Result.success(getModelDetails(id)) - } catch (t: CancellationException) { - Result.failure(t) - } - } - } - }.awaitAll() - } - - val successes = results.mapNotNull { it.getOrNull() } - val failures = results.mapNotNull { it.exceptionOrNull() } - - val total = ids.size - val failed = failures.size - val ok = successes.size - val shouldThrow = failed >= ceil(total * quorum).toInt() - - if (!shouldThrow) return@withContext successes.toList() - - // 1. No Network - if (failures.count { it is UnknownHostException } >= ceil(failed * 0.5).toInt()) { - throw UnknownHostException() - } - - // 2. Time out - if (failures.count { it is SocketTimeoutException } >= ceil(failed * 0.5).toInt()) { - throw SocketTimeoutException() - } - - // 3. known error codes: 404/410/204 - val http404ish = failures.count { (it as? HttpException)?.code() in listOf(404, 410, 204) } - if (ok == 0 && (failed > 0) && (http404ish >= ceil(failed * 0.5).toInt() || failed == total)) { - throw FileNotFoundException() - } - - // 4. 
Unknown issues - val ioMajority = failures.count { - it is IOException && it !is UnknownHostException && it !is SocketTimeoutException - } >= ceil(failed * 0.5).toInt() - if (ioMajority) { - throw IOException(failures.first { it is IOException }.message) - } - - successes - } - - override suspend fun searchModels( - query: String?, - filter: String?, - sort: String?, - direction: String?, - limit: Int?, - full: Boolean, - ) = withContext(Dispatchers.IO) { - try { - apiService.getModels( - search = query, - filter = filter, - sort = sort, - direction = direction, - limit = limit, - full = full, - ) - .filterNot { it.gated || it.private } - .filterNot { - it.getGgufFilename().let { filename -> - filename.isNullOrBlank() || INVALID_KEYWORDS.any { - filename.contains(it, ignoreCase = true) - } - } - }.let { - if (it.isEmpty()) Result.failure(FileNotFoundException()) - else Result.success(it) - } - } catch (e: Exception) { - Log.e(TAG, "Error searching for models on HuggingFace: ${e.message}") - Result.failure(e) - } - } - - override suspend fun getModelDetails( - modelId: String - ) = withContext(Dispatchers.IO) { - apiService.getModelDetails(modelId) - } - - override suspend fun getFileSize( - modelId: String, - filePath: String - ): Result = withContext(Dispatchers.IO) { - try { - apiService.getModelFileHeader(modelId, filePath).let { resp -> - if (resp.isSuccessful) { - resp.headers()[HTTP_HEADER_CONTENT_LENGTH]?.toLongOrNull()?.let { - Result.success(it) - } ?: Result.failure(IOException("Content-Length header missing")) - } else { - Result.failure( - when (resp.code()) { - 401 -> SecurityException("Model requires authentication") - 404 -> FileNotFoundException("Model file not found") - else -> IOException("Failed to get file info: HTTP ${resp.code()}") - } - ) - } - } - } catch (e: Exception) { - Log.e(TAG, "Error getting file size for $modelId: ${e.message}") - Result.failure(e) - } - } - - override suspend fun downloadModelFile( - context: Context, - 
downloadInfo: HuggingFaceDownloadInfo, - ): Result = withContext(Dispatchers.IO) { - try { - val downloadManager = - context.getSystemService(Context.DOWNLOAD_SERVICE) as DownloadManager - val request = DownloadManager.Request(downloadInfo.uri).apply { - setTitle(downloadInfo.filename) - setDescription("Downloading directly from HuggingFace") - setNotificationVisibility(DownloadManager.Request.VISIBILITY_VISIBLE_NOTIFY_COMPLETED) - setDestinationInExternalPublicDir( - Environment.DIRECTORY_DOWNLOADS, - downloadInfo.filename - ) - setAllowedNetworkTypes( - DownloadManager.Request.NETWORK_WIFI or DownloadManager.Request.NETWORK_MOBILE - ) - setAllowedOverMetered(true) - setAllowedOverRoaming(false) - } - Log.d(TAG, "Enqueuing download request for: ${downloadInfo.modelId}") - val downloadId = downloadManager.enqueue(request) - - delay(DOWNLOAD_MANAGER_DOUBLE_CHECK_DELAY) - - val cursor = downloadManager.query(DownloadManager.Query().setFilterById(downloadId)) - if (cursor != null && cursor.moveToFirst()) { - val statusIndex = cursor.getColumnIndex(DownloadManager.COLUMN_STATUS) - if (statusIndex >= 0) { - val status = cursor.getInt(statusIndex) - cursor.close() - - when (status) { - DownloadManager.STATUS_FAILED -> { - // Get failure reason if available - val reasonIndex = cursor.getColumnIndex(DownloadManager.COLUMN_REASON) - val reason = if (reasonIndex >= 0) cursor.getInt(reasonIndex) else -1 - val errorMessage = when (reason) { - DownloadManager.ERROR_HTTP_DATA_ERROR -> "HTTP error" - DownloadManager.ERROR_INSUFFICIENT_SPACE -> "Insufficient storage" - DownloadManager.ERROR_TOO_MANY_REDIRECTS -> "Too many redirects" - DownloadManager.ERROR_UNHANDLED_HTTP_CODE -> "Unhandled HTTP code" - DownloadManager.ERROR_CANNOT_RESUME -> "Cannot resume download" - DownloadManager.ERROR_FILE_ERROR -> "File error" - else -> "Unknown error" - } - Result.failure(Exception(errorMessage)) - } - else -> { - // Download is pending, paused, or running - Result.success(downloadId) - } - 
} - } else { - // Assume success if we can't check status - cursor.close() - Result.success(downloadId) - } - } else { - // Assume success if cursor is empty - cursor?.close() - Result.success(downloadId) - } - } catch (e: Exception) { - Log.e(TAG, "Failed to enqueue download: ${e.message}") - Result.failure(e) - } - } - - companion object { - private val TAG = HuggingFaceRemoteDataSourceImpl::class.java.simpleName - - private const val HTTP_HEADER_CONTENT_LENGTH = "content-length" - private const val DOWNLOAD_MANAGER_DOUBLE_CHECK_DELAY = 500L - } -} diff --git a/examples/llama.android/app/src/main/java/com/example/llama/data/source/remote/HuggingFaceRemoteDataSourceImpl.kt b/examples/llama.android/app/src/main/java/com/example/llama/data/source/remote/HuggingFaceRemoteDataSourceImpl.kt new file mode 100644 index 0000000000..0b0bfe7e50 --- /dev/null +++ b/examples/llama.android/app/src/main/java/com/example/llama/data/source/remote/HuggingFaceRemoteDataSourceImpl.kt @@ -0,0 +1,272 @@ +package com.example.llama.data.source.remote + +import android.app.DownloadManager +import android.content.Context +import android.os.Environment +import android.util.Log +import com.example.llama.monitoring.MemoryMetrics +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.async +import kotlinx.coroutines.awaitAll +import kotlinx.coroutines.delay +import kotlinx.coroutines.supervisorScope +import kotlinx.coroutines.sync.Semaphore +import kotlinx.coroutines.sync.withPermit +import kotlinx.coroutines.withContext +import retrofit2.HttpException +import java.io.FileNotFoundException +import java.io.IOException +import java.net.SocketTimeoutException +import java.net.UnknownHostException +import javax.inject.Inject +import javax.inject.Singleton +import kotlin.collections.contains +import kotlin.coroutines.cancellation.CancellationException +import kotlin.math.ceil + + +/* + * Preselected models: sized <2GB + */ +private val PRESELECTED_MODEL_IDS_SMALL = listOf( + 
"bartowski/Llama-3.2-1B-Instruct-GGUF", + "unsloth/gemma-3-1b-it-GGUF", + "bartowski/granite-3.0-2b-instruct-GGUF", +) + +/* + * Preselected models: sized 2~3GB + */ +private val PRESELECTED_MODEL_IDS_MEDIUM = listOf( + "bartowski/Llama-3.2-3B-Instruct-GGUF", + "unsloth/gemma-3n-E2B-it-GGUF", + "Qwen/Qwen2.5-3B-Instruct-GGUF", + "gaianet/Phi-4-mini-instruct-GGUF", + "unsloth/gemma-3-4b-it-GGUF", +) + +/* + * Preselected models: sized 4~6GB + */ +private val PRESELECTED_MODEL_IDS_LARGE = listOf( + "unsloth/gemma-3n-E4B-it-GGUF", + "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF", +) + +@Singleton +class HuggingFaceRemoteDataSourceImpl @Inject constructor( + private val apiService: HuggingFaceApiService +) : HuggingFaceRemoteDataSource { + + override suspend fun fetchPreselectedModels( + memoryUsage: MemoryMetrics, + parallelCount: Int, + quorum: Float, + ): List = withContext(Dispatchers.IO) { + val ids: List = when { + memoryUsage.availableGB >= 7f -> + PRESELECTED_MODEL_IDS_MEDIUM + PRESELECTED_MODEL_IDS_LARGE + PRESELECTED_MODEL_IDS_SMALL + memoryUsage.availableGB >= 5f -> + PRESELECTED_MODEL_IDS_SMALL + PRESELECTED_MODEL_IDS_MEDIUM + PRESELECTED_MODEL_IDS_LARGE + memoryUsage.availableGB >= 3f -> + PRESELECTED_MODEL_IDS_SMALL + PRESELECTED_MODEL_IDS_MEDIUM + else -> + PRESELECTED_MODEL_IDS_SMALL + } + + val sem = Semaphore(parallelCount) + val results = supervisorScope { + ids.map { id -> + async { + sem.withPermit { + try { + Result.success(getModelDetails(id)) + } catch (t: CancellationException) { + Result.failure(t) + } + } + } + }.awaitAll() + } + + val successes = results.mapNotNull { it.getOrNull() } + val failures = results.mapNotNull { it.exceptionOrNull() } + + val total = ids.size + val failed = failures.size + val ok = successes.size + val shouldThrow = failed >= ceil(total * quorum).toInt() + + if (!shouldThrow) return@withContext successes.toList() + + // 1. 
No Network + if (failures.count { it is UnknownHostException } >= ceil(failed * 0.5).toInt()) { + throw UnknownHostException() + } + + // 2. Time out + if (failures.count { it is SocketTimeoutException } >= ceil(failed * 0.5).toInt()) { + throw SocketTimeoutException() + } + + // 3. known error codes: 404/410/204 + val http404ish = failures.count { (it as? HttpException)?.code() in listOf(404, 410, 204) } + if (ok == 0 && (failed > 0) && (http404ish >= ceil(failed * 0.5).toInt() || failed == total)) { + throw FileNotFoundException() + } + + // 4. Unknown issues + val ioMajority = failures.count { + it is IOException && it !is UnknownHostException && it !is SocketTimeoutException + } >= ceil(failed * 0.5).toInt() + if (ioMajority) { + throw IOException(failures.first { it is IOException }.message) + } + + successes + } + + override suspend fun searchModels( + query: String?, + filter: String?, + sort: String?, + direction: String?, + limit: Int?, + full: Boolean, + invalidKeywords: Array, + ) = withContext(Dispatchers.IO) { + try { + apiService.getModels( + search = query, + filter = filter, + sort = sort, + direction = direction, + limit = limit, + full = full, + ) + .filterNot { it.gated || it.private } + .filterNot { + it.getGgufFilename().let { filename -> + filename.isNullOrBlank() || invalidKeywords.any { + filename.contains(it, ignoreCase = true) + } + } + }.let { + if (it.isEmpty()) Result.failure(FileNotFoundException()) + else Result.success(it) + } + } catch (e: Exception) { + Log.e(TAG, "Error searching for models on HuggingFace: ${e.message}") + Result.failure(e) + } + } + + override suspend fun getModelDetails( + modelId: String + ) = withContext(Dispatchers.IO) { + apiService.getModelDetails(modelId) + } + + override suspend fun getFileSize( + modelId: String, + filePath: String + ): Result = withContext(Dispatchers.IO) { + try { + apiService.getModelFileHeader(modelId, filePath).let { resp -> + if (resp.isSuccessful) { + 
resp.headers()[HTTP_HEADER_CONTENT_LENGTH]?.toLongOrNull()?.let { + Result.success(it) + } ?: Result.failure(IOException("Content-Length header missing")) + } else { + Result.failure( + when (resp.code()) { + 401 -> SecurityException("Model requires authentication") + 404 -> FileNotFoundException("Model file not found") + else -> IOException("Failed to get file info: HTTP ${resp.code()}") + } + ) + } + } + } catch (e: Exception) { + Log.e(TAG, "Error getting file size for $modelId: ${e.message}") + Result.failure(e) + } + } + + override suspend fun downloadModelFile( + context: Context, + downloadInfo: HuggingFaceDownloadInfo, + ): Result = withContext(Dispatchers.IO) { + try { + val downloadManager = + context.getSystemService(Context.DOWNLOAD_SERVICE) as DownloadManager + val request = DownloadManager.Request(downloadInfo.uri).apply { + setTitle(downloadInfo.filename) + setDescription("Downloading directly from HuggingFace") + setNotificationVisibility(DownloadManager.Request.VISIBILITY_VISIBLE_NOTIFY_COMPLETED) + setDestinationInExternalPublicDir( + Environment.DIRECTORY_DOWNLOADS, + downloadInfo.filename + ) + setAllowedNetworkTypes( + DownloadManager.Request.NETWORK_WIFI or DownloadManager.Request.NETWORK_MOBILE + ) + setAllowedOverMetered(true) + setAllowedOverRoaming(false) + } + Log.d(TAG, "Enqueuing download request for: ${downloadInfo.modelId}") + val downloadId = downloadManager.enqueue(request) + + delay(DOWNLOAD_MANAGER_DOUBLE_CHECK_DELAY) + + val cursor = downloadManager.query(DownloadManager.Query().setFilterById(downloadId)) + if (cursor != null && cursor.moveToFirst()) { + val statusIndex = cursor.getColumnIndex(DownloadManager.COLUMN_STATUS) + if (statusIndex >= 0) { + val status = cursor.getInt(statusIndex) + cursor.close() + + when (status) { + DownloadManager.STATUS_FAILED -> { + // Get failure reason if available + val reasonIndex = cursor.getColumnIndex(DownloadManager.COLUMN_REASON) + val reason = if (reasonIndex >= 0) 
cursor.getInt(reasonIndex) else -1 + val errorMessage = when (reason) { + DownloadManager.ERROR_HTTP_DATA_ERROR -> "HTTP error" + DownloadManager.ERROR_INSUFFICIENT_SPACE -> "Insufficient storage" + DownloadManager.ERROR_TOO_MANY_REDIRECTS -> "Too many redirects" + DownloadManager.ERROR_UNHANDLED_HTTP_CODE -> "Unhandled HTTP code" + DownloadManager.ERROR_CANNOT_RESUME -> "Cannot resume download" + DownloadManager.ERROR_FILE_ERROR -> "File error" + else -> "Unknown error" + } + Result.failure(Exception(errorMessage)) + } + else -> { + // Download is pending, paused, or running + Result.success(downloadId) + } + } + } else { + // Assume success if we can't check status + cursor.close() + Result.success(downloadId) + } + } else { + // Assume success if cursor is empty + cursor?.close() + Result.success(downloadId) + } + } catch (e: Exception) { + Log.e(TAG, "Failed to enqueue download: ${e.message}") + Result.failure(e) + } + } + + companion object { + private val TAG = HuggingFaceRemoteDataSourceImpl::class.java.simpleName + + private const val HTTP_HEADER_CONTENT_LENGTH = "content-length" + private const val DOWNLOAD_MANAGER_DOUBLE_CHECK_DELAY = 500L + } +} diff --git a/examples/llama.android/app/src/main/java/com/example/llama/viewmodel/ModelsManagementViewModel.kt b/examples/llama.android/app/src/main/java/com/example/llama/viewmodel/ModelsManagementViewModel.kt index ab1addbf4f..fe372eda7f 100644 --- a/examples/llama.android/app/src/main/java/com/example/llama/viewmodel/ModelsManagementViewModel.kt +++ b/examples/llama.android/app/src/main/java/com/example/llama/viewmodel/ModelsManagementViewModel.kt @@ -17,6 +17,7 @@ import com.example.llama.data.repo.ModelRepository import com.example.llama.data.source.remote.HuggingFaceDownloadInfo import com.example.llama.data.source.remote.HuggingFaceModel import com.example.llama.data.source.remote.HuggingFaceModelDetails +import com.example.llama.monitoring.MemoryMetrics import com.example.llama.util.formatFileByteSize 
import com.example.llama.util.getFileNameFromUri import com.example.llama.util.getFileSizeFromUri @@ -183,11 +184,12 @@ class ModelsManagementViewModel @Inject constructor( /** * Query models on HuggingFace available for download even without signing in */ - fun queryModelsFromHuggingFace(cap: Int = FETCH_HUGGINGFACE_MODELS_CAP_SIZE) { + fun queryModelsFromHuggingFace(memoryUsage: MemoryMetrics) { huggingFaceQueryJob = viewModelScope.launch { _managementState.emit(Download.Querying) try { - val models = modelRepository.fetchPreselectedHuggingFaceModels().map(HuggingFaceModelDetails::toModel) + val models = modelRepository.fetchPreselectedHuggingFaceModels(memoryUsage) + .map(HuggingFaceModelDetails::toModel) _managementState.emit(Download.Ready(models)) } catch (_: CancellationException) { resetManagementState() @@ -300,7 +302,6 @@ class ModelsManagementViewModel @Inject constructor( private val TAG = ModelsManagementViewModel::class.java.simpleName private const val FETCH_HUGGINGFACE_MODELS_LIMIT_SIZE = 50 - private const val FETCH_HUGGINGFACE_MODELS_CAP_SIZE = 12 private const val DELETE_SUCCESS_RESET_TIMEOUT_MS = 1000L } }