data: sort preselected models according to device's available RAM
This commit is contained in:
parent
687b86e924
commit
e0ddc37e2e
|
|
@ -344,7 +344,7 @@ fun AppContent(
|
|||
modelsManagementViewModel.toggleImportMenu(false)
|
||||
},
|
||||
importFromHuggingFace = {
|
||||
modelsManagementViewModel.queryModelsFromHuggingFace()
|
||||
modelsManagementViewModel.queryModelsFromHuggingFace(memoryUsage)
|
||||
modelsManagementViewModel.toggleImportMenu(false)
|
||||
}
|
||||
),
|
||||
|
|
|
|||
|
|
@ -141,56 +141,56 @@ enum class ModelFilter(val displayName: String, val predicate: (ModelInfo) -> Bo
|
|||
}),
|
||||
SMALL_PARAMS("Small (1-3B parameters)", {
|
||||
it.metadata.basic.sizeLabel?.let { size ->
|
||||
size.contains("B") && size.replace("B", "").toFloatOrNull()?.let { n -> n >= 1f && n <= 3f } == true
|
||||
size.contains("B") && size.replace("B", "").toFloatOrNull()?.let { n -> n >= 1f && n < 4f } == true
|
||||
} == true
|
||||
}),
|
||||
MEDIUM_PARAMS("Medium (4-7B parameters)", {
|
||||
it.metadata.basic.sizeLabel?.let { size ->
|
||||
size.contains("B") && size.replace("B", "").toFloatOrNull()?.let { n -> n >= 4f && n <= 7f } == true
|
||||
size.contains("B") && size.replace("B", "").toFloatOrNull()?.let { n -> n >= 4f && n < 8f } == true
|
||||
} == true
|
||||
}),
|
||||
LARGE_PARAMS("Large (8-13B parameters)", {
|
||||
LARGE_PARAMS("Large (8-12B parameters)", {
|
||||
it.metadata.basic.sizeLabel?.let { size ->
|
||||
size.contains("B") && size.replace("B", "").toFloatOrNull()?.let { n -> n >= 8f && n <= 13f } == true
|
||||
size.contains("B") && size.replace("B", "").toFloatOrNull()?.let { n -> n >= 8f && n < 13f } == true
|
||||
} == true
|
||||
}),
|
||||
XLARGE_PARAMS("X-Large (>13B parameters)", {
|
||||
it.metadata.basic.sizeLabel?.let { size ->
|
||||
size.contains("B") && size.replace("B", "").toFloatOrNull()?.let { n -> n > 13f } == true
|
||||
size.contains("B") && size.replace("B", "").toFloatOrNull()?.let { n -> n >= 13f } == true
|
||||
} == true
|
||||
}),
|
||||
|
||||
// Context length filters
|
||||
TINY_CONTEXT("Tiny context (<4K)", {
|
||||
it.metadata.dimensions?.contextLength?.let { it < 4096 } == true
|
||||
TINY_CONTEXT("Tiny context (<4K)", { model ->
|
||||
model.metadata.dimensions?.contextLength?.let { it < 4096 } == true
|
||||
}),
|
||||
SHORT_CONTEXT("Short context (4-8K)", {
|
||||
it.metadata.dimensions?.contextLength?.let { it >= 4096 && it <= 8192 } == true
|
||||
SHORT_CONTEXT("Short context (4-8K)", { model ->
|
||||
model.metadata.dimensions?.contextLength?.let { it >= 4096 && it <= 8192 } == true
|
||||
}),
|
||||
MEDIUM_CONTEXT("Medium context (8-32K)", {
|
||||
it.metadata.dimensions?.contextLength?.let { it > 8192 && it <= 32768 } == true
|
||||
MEDIUM_CONTEXT("Medium context (8-32K)", { model ->
|
||||
model.metadata.dimensions?.contextLength?.let { it > 8192 && it <= 32768 } == true
|
||||
}),
|
||||
LONG_CONTEXT("Long context (32-128K)", {
|
||||
it.metadata.dimensions?.contextLength?.let { it > 32768 && it <= 131072 } == true
|
||||
LONG_CONTEXT("Long context (32-128K)", { model ->
|
||||
model.metadata.dimensions?.contextLength?.let { it > 32768 && it <= 131072 } == true
|
||||
}),
|
||||
XLARGE_CONTEXT("Extended context (>128K)", {
|
||||
it.metadata.dimensions?.contextLength?.let { it > 131072 } == true
|
||||
XLARGE_CONTEXT("Extended context (>128K)", { model ->
|
||||
model.metadata.dimensions?.contextLength?.let { it > 131072 } == true
|
||||
}),
|
||||
|
||||
// Quantization filters
|
||||
INT2_QUANT("2-bit quantization", {
|
||||
it.formattedQuantization.let { it.contains("Q2") || it.contains("IQ2") }
|
||||
INT2_QUANT("2-bit quantization", { model ->
|
||||
model.formattedQuantization.let { it.contains("Q2") || it.contains("IQ2") }
|
||||
}),
|
||||
INT3_QUANT("3-bit quantization", {
|
||||
it.formattedQuantization.let { it.contains("Q3") || it.contains("IQ3") }
|
||||
INT3_QUANT("3-bit quantization", { model ->
|
||||
model.formattedQuantization.let { it.contains("Q3") || it.contains("IQ3") }
|
||||
}),
|
||||
INT4_QUANT("4-bit quantization", {
|
||||
it.formattedQuantization.let { it.contains("Q4") || it.contains("IQ4") }
|
||||
INT4_QUANT("4-bit quantization", { model ->
|
||||
model.formattedQuantization.let { it.contains("Q4") || it.contains("IQ4") }
|
||||
}),
|
||||
|
||||
// Special features
|
||||
MULTILINGUAL("Multilingual", {
|
||||
it.languages?.let { languages ->
|
||||
MULTILINGUAL("Multilingual", { model ->
|
||||
model.languages?.let { languages ->
|
||||
languages.size > 1 || languages.any { it.contains("multi", ignoreCase = true) }
|
||||
} == true
|
||||
}),
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ import com.example.llama.data.source.remote.HuggingFaceDownloadInfo
|
|||
import com.example.llama.data.source.remote.HuggingFaceModel
|
||||
import com.example.llama.data.source.remote.HuggingFaceModelDetails
|
||||
import com.example.llama.data.source.remote.HuggingFaceRemoteDataSource
|
||||
import com.example.llama.monitoring.MemoryMetrics
|
||||
import com.example.llama.monitoring.StorageMetrics
|
||||
import com.example.llama.util.formatFileByteSize
|
||||
import dagger.hilt.android.qualifiers.ApplicationContext
|
||||
|
|
@ -83,7 +84,7 @@ interface ModelRepository {
|
|||
/**
|
||||
* Fetch details of preselected models
|
||||
*/
|
||||
suspend fun fetchPreselectedHuggingFaceModels(): List<HuggingFaceModelDetails>
|
||||
suspend fun fetchPreselectedHuggingFaceModels(memoryUsage: MemoryMetrics): List<HuggingFaceModelDetails>
|
||||
|
||||
/**
|
||||
* Search models on HuggingFace
|
||||
|
|
@ -321,8 +322,8 @@ class ModelRepositoryImpl @Inject constructor(
|
|||
}
|
||||
}
|
||||
|
||||
override suspend fun fetchPreselectedHuggingFaceModels() = withContext(Dispatchers.IO) {
|
||||
huggingFaceRemoteDataSource.fetchPreselectedModels()
|
||||
override suspend fun fetchPreselectedHuggingFaceModels(memoryUsage: MemoryMetrics) = withContext(Dispatchers.IO) {
|
||||
huggingFaceRemoteDataSource.fetchPreselectedModels(memoryUsage)
|
||||
}
|
||||
|
||||
override suspend fun searchHuggingFaceModels(
|
||||
|
|
|
|||
|
|
@ -1,27 +1,12 @@
|
|||
package com.example.llama.data.source.remote
|
||||
|
||||
import android.app.DownloadManager
|
||||
import android.content.Context
|
||||
import android.os.Environment
|
||||
import android.util.Log
|
||||
import kotlinx.coroutines.Dispatchers
|
||||
import kotlinx.coroutines.async
|
||||
import kotlinx.coroutines.awaitAll
|
||||
import kotlinx.coroutines.delay
|
||||
import kotlinx.coroutines.supervisorScope
|
||||
import kotlinx.coroutines.sync.Semaphore
|
||||
import kotlinx.coroutines.sync.withPermit
|
||||
import kotlinx.coroutines.withContext
|
||||
import retrofit2.HttpException
|
||||
import java.io.FileNotFoundException
|
||||
import java.io.IOException
|
||||
import java.net.SocketTimeoutException
|
||||
import java.net.UnknownHostException
|
||||
import javax.inject.Inject
|
||||
import javax.inject.Singleton
|
||||
import kotlin.coroutines.cancellation.CancellationException
|
||||
import kotlin.math.ceil
|
||||
import com.example.llama.monitoring.MemoryMetrics
|
||||
|
||||
|
||||
/*
|
||||
* HuggingFace Search API
|
||||
*/
|
||||
private const val QUERY_Q4_0_GGUF = "gguf q4_0"
|
||||
private const val FILTER_TEXT_GENERATION = "text-generation"
|
||||
private const val SORT_BY_DOWNLOADS = "downloads"
|
||||
|
|
@ -29,21 +14,10 @@ private const val SEARCH_RESULT_LIMIT = 30
|
|||
|
||||
private val INVALID_KEYWORDS = arrayOf("-of-", "split", "70B", "30B", "27B", "14B", "13B", "12B")
|
||||
|
||||
private val PRESELECTED_MODEL_IDS = listOf(
|
||||
"unsloth/gemma-3-1b-it-GGUF",
|
||||
"unsloth/gemma-3-4b-it-GGUF",
|
||||
"bartowski/Llama-3.2-1B-Instruct-GGUF",
|
||||
"bartowski/Llama-3.2-3B-Instruct-GGUF",
|
||||
"Qwen/Qwen2.5-3B-Instruct-GGUF",
|
||||
"gaianet/Phi-4-mini-instruct-GGUF",
|
||||
"bartowski/granite-3.0-2b-instruct-GGUF",
|
||||
"bartowski/Meta-Llama-3.1-8B-Instruct-GGUF",
|
||||
)
|
||||
|
||||
interface HuggingFaceRemoteDataSource {
|
||||
|
||||
suspend fun fetchPreselectedModels(
|
||||
ids: List<String> = PRESELECTED_MODEL_IDS,
|
||||
memoryUsage: MemoryMetrics,
|
||||
parallelCount: Int = 3,
|
||||
quorum: Float = 0.5f,
|
||||
): List<HuggingFaceModelDetails>
|
||||
|
|
@ -58,6 +32,7 @@ interface HuggingFaceRemoteDataSource {
|
|||
direction: String? = "-1",
|
||||
limit: Int? = SEARCH_RESULT_LIMIT,
|
||||
full: Boolean = true,
|
||||
invalidKeywords: Array<String> = INVALID_KEYWORDS
|
||||
): Result<List<HuggingFaceModel>>
|
||||
|
||||
suspend fun getModelDetails(modelId: String): HuggingFaceModelDetails
|
||||
|
|
@ -76,208 +51,3 @@ interface HuggingFaceRemoteDataSource {
|
|||
): Result<Long>
|
||||
}
|
||||
|
||||
@Singleton
|
||||
class HuggingFaceRemoteDataSourceImpl @Inject constructor(
|
||||
private val apiService: HuggingFaceApiService
|
||||
) : HuggingFaceRemoteDataSource {
|
||||
|
||||
override suspend fun fetchPreselectedModels(
|
||||
ids: List<String>,
|
||||
parallelCount: Int,
|
||||
quorum: Float,
|
||||
): List<HuggingFaceModelDetails> = withContext(Dispatchers.IO) {
|
||||
val sem = Semaphore(parallelCount)
|
||||
val results = supervisorScope {
|
||||
ids.map { id ->
|
||||
async {
|
||||
sem.withPermit {
|
||||
try {
|
||||
Result.success(getModelDetails(id))
|
||||
} catch (t: CancellationException) {
|
||||
Result.failure(t)
|
||||
}
|
||||
}
|
||||
}
|
||||
}.awaitAll()
|
||||
}
|
||||
|
||||
val successes = results.mapNotNull { it.getOrNull() }
|
||||
val failures = results.mapNotNull { it.exceptionOrNull() }
|
||||
|
||||
val total = ids.size
|
||||
val failed = failures.size
|
||||
val ok = successes.size
|
||||
val shouldThrow = failed >= ceil(total * quorum).toInt()
|
||||
|
||||
if (!shouldThrow) return@withContext successes.toList()
|
||||
|
||||
// 1. No Network
|
||||
if (failures.count { it is UnknownHostException } >= ceil(failed * 0.5).toInt()) {
|
||||
throw UnknownHostException()
|
||||
}
|
||||
|
||||
// 2. Time out
|
||||
if (failures.count { it is SocketTimeoutException } >= ceil(failed * 0.5).toInt()) {
|
||||
throw SocketTimeoutException()
|
||||
}
|
||||
|
||||
// 3. known error codes: 404/410/204
|
||||
val http404ish = failures.count { (it as? HttpException)?.code() in listOf(404, 410, 204) }
|
||||
if (ok == 0 && (failed > 0) && (http404ish >= ceil(failed * 0.5).toInt() || failed == total)) {
|
||||
throw FileNotFoundException()
|
||||
}
|
||||
|
||||
// 4. Unknown issues
|
||||
val ioMajority = failures.count {
|
||||
it is IOException && it !is UnknownHostException && it !is SocketTimeoutException
|
||||
} >= ceil(failed * 0.5).toInt()
|
||||
if (ioMajority) {
|
||||
throw IOException(failures.first { it is IOException }.message)
|
||||
}
|
||||
|
||||
successes
|
||||
}
|
||||
|
||||
override suspend fun searchModels(
|
||||
query: String?,
|
||||
filter: String?,
|
||||
sort: String?,
|
||||
direction: String?,
|
||||
limit: Int?,
|
||||
full: Boolean,
|
||||
) = withContext(Dispatchers.IO) {
|
||||
try {
|
||||
apiService.getModels(
|
||||
search = query,
|
||||
filter = filter,
|
||||
sort = sort,
|
||||
direction = direction,
|
||||
limit = limit,
|
||||
full = full,
|
||||
)
|
||||
.filterNot { it.gated || it.private }
|
||||
.filterNot {
|
||||
it.getGgufFilename().let { filename ->
|
||||
filename.isNullOrBlank() || INVALID_KEYWORDS.any {
|
||||
filename.contains(it, ignoreCase = true)
|
||||
}
|
||||
}
|
||||
}.let {
|
||||
if (it.isEmpty()) Result.failure(FileNotFoundException())
|
||||
else Result.success(it)
|
||||
}
|
||||
} catch (e: Exception) {
|
||||
Log.e(TAG, "Error searching for models on HuggingFace: ${e.message}")
|
||||
Result.failure(e)
|
||||
}
|
||||
}
|
||||
|
||||
override suspend fun getModelDetails(
|
||||
modelId: String
|
||||
) = withContext(Dispatchers.IO) {
|
||||
apiService.getModelDetails(modelId)
|
||||
}
|
||||
|
||||
override suspend fun getFileSize(
|
||||
modelId: String,
|
||||
filePath: String
|
||||
): Result<Long> = withContext(Dispatchers.IO) {
|
||||
try {
|
||||
apiService.getModelFileHeader(modelId, filePath).let { resp ->
|
||||
if (resp.isSuccessful) {
|
||||
resp.headers()[HTTP_HEADER_CONTENT_LENGTH]?.toLongOrNull()?.let {
|
||||
Result.success(it)
|
||||
} ?: Result.failure(IOException("Content-Length header missing"))
|
||||
} else {
|
||||
Result.failure(
|
||||
when (resp.code()) {
|
||||
401 -> SecurityException("Model requires authentication")
|
||||
404 -> FileNotFoundException("Model file not found")
|
||||
else -> IOException("Failed to get file info: HTTP ${resp.code()}")
|
||||
}
|
||||
)
|
||||
}
|
||||
}
|
||||
} catch (e: Exception) {
|
||||
Log.e(TAG, "Error getting file size for $modelId: ${e.message}")
|
||||
Result.failure(e)
|
||||
}
|
||||
}
|
||||
|
||||
override suspend fun downloadModelFile(
|
||||
context: Context,
|
||||
downloadInfo: HuggingFaceDownloadInfo,
|
||||
): Result<Long> = withContext(Dispatchers.IO) {
|
||||
try {
|
||||
val downloadManager =
|
||||
context.getSystemService(Context.DOWNLOAD_SERVICE) as DownloadManager
|
||||
val request = DownloadManager.Request(downloadInfo.uri).apply {
|
||||
setTitle(downloadInfo.filename)
|
||||
setDescription("Downloading directly from HuggingFace")
|
||||
setNotificationVisibility(DownloadManager.Request.VISIBILITY_VISIBLE_NOTIFY_COMPLETED)
|
||||
setDestinationInExternalPublicDir(
|
||||
Environment.DIRECTORY_DOWNLOADS,
|
||||
downloadInfo.filename
|
||||
)
|
||||
setAllowedNetworkTypes(
|
||||
DownloadManager.Request.NETWORK_WIFI or DownloadManager.Request.NETWORK_MOBILE
|
||||
)
|
||||
setAllowedOverMetered(true)
|
||||
setAllowedOverRoaming(false)
|
||||
}
|
||||
Log.d(TAG, "Enqueuing download request for: ${downloadInfo.modelId}")
|
||||
val downloadId = downloadManager.enqueue(request)
|
||||
|
||||
delay(DOWNLOAD_MANAGER_DOUBLE_CHECK_DELAY)
|
||||
|
||||
val cursor = downloadManager.query(DownloadManager.Query().setFilterById(downloadId))
|
||||
if (cursor != null && cursor.moveToFirst()) {
|
||||
val statusIndex = cursor.getColumnIndex(DownloadManager.COLUMN_STATUS)
|
||||
if (statusIndex >= 0) {
|
||||
val status = cursor.getInt(statusIndex)
|
||||
cursor.close()
|
||||
|
||||
when (status) {
|
||||
DownloadManager.STATUS_FAILED -> {
|
||||
// Get failure reason if available
|
||||
val reasonIndex = cursor.getColumnIndex(DownloadManager.COLUMN_REASON)
|
||||
val reason = if (reasonIndex >= 0) cursor.getInt(reasonIndex) else -1
|
||||
val errorMessage = when (reason) {
|
||||
DownloadManager.ERROR_HTTP_DATA_ERROR -> "HTTP error"
|
||||
DownloadManager.ERROR_INSUFFICIENT_SPACE -> "Insufficient storage"
|
||||
DownloadManager.ERROR_TOO_MANY_REDIRECTS -> "Too many redirects"
|
||||
DownloadManager.ERROR_UNHANDLED_HTTP_CODE -> "Unhandled HTTP code"
|
||||
DownloadManager.ERROR_CANNOT_RESUME -> "Cannot resume download"
|
||||
DownloadManager.ERROR_FILE_ERROR -> "File error"
|
||||
else -> "Unknown error"
|
||||
}
|
||||
Result.failure(Exception(errorMessage))
|
||||
}
|
||||
else -> {
|
||||
// Download is pending, paused, or running
|
||||
Result.success(downloadId)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Assume success if we can't check status
|
||||
cursor.close()
|
||||
Result.success(downloadId)
|
||||
}
|
||||
} else {
|
||||
// Assume success if cursor is empty
|
||||
cursor?.close()
|
||||
Result.success(downloadId)
|
||||
}
|
||||
} catch (e: Exception) {
|
||||
Log.e(TAG, "Failed to enqueue download: ${e.message}")
|
||||
Result.failure(e)
|
||||
}
|
||||
}
|
||||
|
||||
companion object {
|
||||
private val TAG = HuggingFaceRemoteDataSourceImpl::class.java.simpleName
|
||||
|
||||
private const val HTTP_HEADER_CONTENT_LENGTH = "content-length"
|
||||
private const val DOWNLOAD_MANAGER_DOUBLE_CHECK_DELAY = 500L
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,272 @@
|
|||
package com.example.llama.data.source.remote
|
||||
|
||||
import android.app.DownloadManager
|
||||
import android.content.Context
|
||||
import android.os.Environment
|
||||
import android.util.Log
|
||||
import com.example.llama.monitoring.MemoryMetrics
|
||||
import kotlinx.coroutines.Dispatchers
|
||||
import kotlinx.coroutines.async
|
||||
import kotlinx.coroutines.awaitAll
|
||||
import kotlinx.coroutines.delay
|
||||
import kotlinx.coroutines.supervisorScope
|
||||
import kotlinx.coroutines.sync.Semaphore
|
||||
import kotlinx.coroutines.sync.withPermit
|
||||
import kotlinx.coroutines.withContext
|
||||
import retrofit2.HttpException
|
||||
import java.io.FileNotFoundException
|
||||
import java.io.IOException
|
||||
import java.net.SocketTimeoutException
|
||||
import java.net.UnknownHostException
|
||||
import javax.inject.Inject
|
||||
import javax.inject.Singleton
|
||||
import kotlin.collections.contains
|
||||
import kotlin.coroutines.cancellation.CancellationException
|
||||
import kotlin.math.ceil
|
||||
|
||||
|
||||
/*
|
||||
* Preselected models: sized <2GB
|
||||
*/
|
||||
private val PRESELECTED_MODEL_IDS_SMALL = listOf(
|
||||
"bartowski/Llama-3.2-1B-Instruct-GGUF",
|
||||
"unsloth/gemma-3-1b-it-GGUF",
|
||||
"bartowski/granite-3.0-2b-instruct-GGUF",
|
||||
)
|
||||
|
||||
/*
|
||||
* Preselected models: sized 2~3GB
|
||||
*/
|
||||
private val PRESELECTED_MODEL_IDS_MEDIUM = listOf(
|
||||
"bartowski/Llama-3.2-3B-Instruct-GGUF",
|
||||
"unsloth/gemma-3n-E2B-it-GGUF",
|
||||
"Qwen/Qwen2.5-3B-Instruct-GGUF",
|
||||
"gaianet/Phi-4-mini-instruct-GGUF",
|
||||
"unsloth/gemma-3-4b-it-GGUF",
|
||||
)
|
||||
|
||||
/*
|
||||
* Preselected models: sized 4~6B
|
||||
*/
|
||||
private val PRESELECTED_MODEL_IDS_LARGE = listOf(
|
||||
"unsloth/gemma-3n-E4B-it-GGUF",
|
||||
"bartowski/Meta-Llama-3.1-8B-Instruct-GGUF",
|
||||
)
|
||||
|
||||
@Singleton
|
||||
class HuggingFaceRemoteDataSourceImpl @Inject constructor(
|
||||
private val apiService: HuggingFaceApiService
|
||||
) : HuggingFaceRemoteDataSource {
|
||||
|
||||
override suspend fun fetchPreselectedModels(
|
||||
memoryUsage: MemoryMetrics,
|
||||
parallelCount: Int,
|
||||
quorum: Float,
|
||||
): List<HuggingFaceModelDetails> = withContext(Dispatchers.IO) {
|
||||
val ids: List<String> = when {
|
||||
memoryUsage.availableGB >= 7f ->
|
||||
PRESELECTED_MODEL_IDS_MEDIUM + PRESELECTED_MODEL_IDS_LARGE + PRESELECTED_MODEL_IDS_SMALL
|
||||
memoryUsage.availableGB >= 5f ->
|
||||
PRESELECTED_MODEL_IDS_SMALL + PRESELECTED_MODEL_IDS_MEDIUM + PRESELECTED_MODEL_IDS_LARGE
|
||||
memoryUsage.availableGB >= 3f ->
|
||||
PRESELECTED_MODEL_IDS_SMALL + PRESELECTED_MODEL_IDS_MEDIUM
|
||||
else ->
|
||||
PRESELECTED_MODEL_IDS_SMALL
|
||||
}
|
||||
|
||||
val sem = Semaphore(parallelCount)
|
||||
val results = supervisorScope {
|
||||
ids.map { id ->
|
||||
async {
|
||||
sem.withPermit {
|
||||
try {
|
||||
Result.success(getModelDetails(id))
|
||||
} catch (t: CancellationException) {
|
||||
Result.failure(t)
|
||||
}
|
||||
}
|
||||
}
|
||||
}.awaitAll()
|
||||
}
|
||||
|
||||
val successes = results.mapNotNull { it.getOrNull() }
|
||||
val failures = results.mapNotNull { it.exceptionOrNull() }
|
||||
|
||||
val total = ids.size
|
||||
val failed = failures.size
|
||||
val ok = successes.size
|
||||
val shouldThrow = failed >= ceil(total * quorum).toInt()
|
||||
|
||||
if (!shouldThrow) return@withContext successes.toList()
|
||||
|
||||
// 1. No Network
|
||||
if (failures.count { it is UnknownHostException } >= ceil(failed * 0.5).toInt()) {
|
||||
throw UnknownHostException()
|
||||
}
|
||||
|
||||
// 2. Time out
|
||||
if (failures.count { it is SocketTimeoutException } >= ceil(failed * 0.5).toInt()) {
|
||||
throw SocketTimeoutException()
|
||||
}
|
||||
|
||||
// 3. known error codes: 404/410/204
|
||||
val http404ish = failures.count { (it as? HttpException)?.code() in listOf(404, 410, 204) }
|
||||
if (ok == 0 && (failed > 0) && (http404ish >= ceil(failed * 0.5).toInt() || failed == total)) {
|
||||
throw FileNotFoundException()
|
||||
}
|
||||
|
||||
// 4. Unknown issues
|
||||
val ioMajority = failures.count {
|
||||
it is IOException && it !is UnknownHostException && it !is SocketTimeoutException
|
||||
} >= ceil(failed * 0.5).toInt()
|
||||
if (ioMajority) {
|
||||
throw IOException(failures.first { it is IOException }.message)
|
||||
}
|
||||
|
||||
successes
|
||||
}
|
||||
|
||||
override suspend fun searchModels(
|
||||
query: String?,
|
||||
filter: String?,
|
||||
sort: String?,
|
||||
direction: String?,
|
||||
limit: Int?,
|
||||
full: Boolean,
|
||||
invalidKeywords: Array<String>,
|
||||
) = withContext(Dispatchers.IO) {
|
||||
try {
|
||||
apiService.getModels(
|
||||
search = query,
|
||||
filter = filter,
|
||||
sort = sort,
|
||||
direction = direction,
|
||||
limit = limit,
|
||||
full = full,
|
||||
)
|
||||
.filterNot { it.gated || it.private }
|
||||
.filterNot {
|
||||
it.getGgufFilename().let { filename ->
|
||||
filename.isNullOrBlank() || invalidKeywords.any {
|
||||
filename.contains(it, ignoreCase = true)
|
||||
}
|
||||
}
|
||||
}.let {
|
||||
if (it.isEmpty()) Result.failure(FileNotFoundException())
|
||||
else Result.success(it)
|
||||
}
|
||||
} catch (e: Exception) {
|
||||
Log.e(TAG, "Error searching for models on HuggingFace: ${e.message}")
|
||||
Result.failure(e)
|
||||
}
|
||||
}
|
||||
|
||||
override suspend fun getModelDetails(
|
||||
modelId: String
|
||||
) = withContext(Dispatchers.IO) {
|
||||
apiService.getModelDetails(modelId)
|
||||
}
|
||||
|
||||
override suspend fun getFileSize(
|
||||
modelId: String,
|
||||
filePath: String
|
||||
): Result<Long> = withContext(Dispatchers.IO) {
|
||||
try {
|
||||
apiService.getModelFileHeader(modelId, filePath).let { resp ->
|
||||
if (resp.isSuccessful) {
|
||||
resp.headers()[HTTP_HEADER_CONTENT_LENGTH]?.toLongOrNull()?.let {
|
||||
Result.success(it)
|
||||
} ?: Result.failure(IOException("Content-Length header missing"))
|
||||
} else {
|
||||
Result.failure(
|
||||
when (resp.code()) {
|
||||
401 -> SecurityException("Model requires authentication")
|
||||
404 -> FileNotFoundException("Model file not found")
|
||||
else -> IOException("Failed to get file info: HTTP ${resp.code()}")
|
||||
}
|
||||
)
|
||||
}
|
||||
}
|
||||
} catch (e: Exception) {
|
||||
Log.e(TAG, "Error getting file size for $modelId: ${e.message}")
|
||||
Result.failure(e)
|
||||
}
|
||||
}
|
||||
|
||||
override suspend fun downloadModelFile(
|
||||
context: Context,
|
||||
downloadInfo: HuggingFaceDownloadInfo,
|
||||
): Result<Long> = withContext(Dispatchers.IO) {
|
||||
try {
|
||||
val downloadManager =
|
||||
context.getSystemService(Context.DOWNLOAD_SERVICE) as DownloadManager
|
||||
val request = DownloadManager.Request(downloadInfo.uri).apply {
|
||||
setTitle(downloadInfo.filename)
|
||||
setDescription("Downloading directly from HuggingFace")
|
||||
setNotificationVisibility(DownloadManager.Request.VISIBILITY_VISIBLE_NOTIFY_COMPLETED)
|
||||
setDestinationInExternalPublicDir(
|
||||
Environment.DIRECTORY_DOWNLOADS,
|
||||
downloadInfo.filename
|
||||
)
|
||||
setAllowedNetworkTypes(
|
||||
DownloadManager.Request.NETWORK_WIFI or DownloadManager.Request.NETWORK_MOBILE
|
||||
)
|
||||
setAllowedOverMetered(true)
|
||||
setAllowedOverRoaming(false)
|
||||
}
|
||||
Log.d(TAG, "Enqueuing download request for: ${downloadInfo.modelId}")
|
||||
val downloadId = downloadManager.enqueue(request)
|
||||
|
||||
delay(DOWNLOAD_MANAGER_DOUBLE_CHECK_DELAY)
|
||||
|
||||
val cursor = downloadManager.query(DownloadManager.Query().setFilterById(downloadId))
|
||||
if (cursor != null && cursor.moveToFirst()) {
|
||||
val statusIndex = cursor.getColumnIndex(DownloadManager.COLUMN_STATUS)
|
||||
if (statusIndex >= 0) {
|
||||
val status = cursor.getInt(statusIndex)
|
||||
cursor.close()
|
||||
|
||||
when (status) {
|
||||
DownloadManager.STATUS_FAILED -> {
|
||||
// Get failure reason if available
|
||||
val reasonIndex = cursor.getColumnIndex(DownloadManager.COLUMN_REASON)
|
||||
val reason = if (reasonIndex >= 0) cursor.getInt(reasonIndex) else -1
|
||||
val errorMessage = when (reason) {
|
||||
DownloadManager.ERROR_HTTP_DATA_ERROR -> "HTTP error"
|
||||
DownloadManager.ERROR_INSUFFICIENT_SPACE -> "Insufficient storage"
|
||||
DownloadManager.ERROR_TOO_MANY_REDIRECTS -> "Too many redirects"
|
||||
DownloadManager.ERROR_UNHANDLED_HTTP_CODE -> "Unhandled HTTP code"
|
||||
DownloadManager.ERROR_CANNOT_RESUME -> "Cannot resume download"
|
||||
DownloadManager.ERROR_FILE_ERROR -> "File error"
|
||||
else -> "Unknown error"
|
||||
}
|
||||
Result.failure(Exception(errorMessage))
|
||||
}
|
||||
else -> {
|
||||
// Download is pending, paused, or running
|
||||
Result.success(downloadId)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Assume success if we can't check status
|
||||
cursor.close()
|
||||
Result.success(downloadId)
|
||||
}
|
||||
} else {
|
||||
// Assume success if cursor is empty
|
||||
cursor?.close()
|
||||
Result.success(downloadId)
|
||||
}
|
||||
} catch (e: Exception) {
|
||||
Log.e(TAG, "Failed to enqueue download: ${e.message}")
|
||||
Result.failure(e)
|
||||
}
|
||||
}
|
||||
|
||||
companion object {
|
||||
private val TAG = HuggingFaceRemoteDataSourceImpl::class.java.simpleName
|
||||
|
||||
private const val HTTP_HEADER_CONTENT_LENGTH = "content-length"
|
||||
private const val DOWNLOAD_MANAGER_DOUBLE_CHECK_DELAY = 500L
|
||||
}
|
||||
}
|
||||
|
|
@ -17,6 +17,7 @@ import com.example.llama.data.repo.ModelRepository
|
|||
import com.example.llama.data.source.remote.HuggingFaceDownloadInfo
|
||||
import com.example.llama.data.source.remote.HuggingFaceModel
|
||||
import com.example.llama.data.source.remote.HuggingFaceModelDetails
|
||||
import com.example.llama.monitoring.MemoryMetrics
|
||||
import com.example.llama.util.formatFileByteSize
|
||||
import com.example.llama.util.getFileNameFromUri
|
||||
import com.example.llama.util.getFileSizeFromUri
|
||||
|
|
@ -183,11 +184,12 @@ class ModelsManagementViewModel @Inject constructor(
|
|||
/**
|
||||
* Query models on HuggingFace available for download even without signing in
|
||||
*/
|
||||
fun queryModelsFromHuggingFace(cap: Int = FETCH_HUGGINGFACE_MODELS_CAP_SIZE) {
|
||||
fun queryModelsFromHuggingFace(memoryUsage: MemoryMetrics) {
|
||||
huggingFaceQueryJob = viewModelScope.launch {
|
||||
_managementState.emit(Download.Querying)
|
||||
try {
|
||||
val models = modelRepository.fetchPreselectedHuggingFaceModels().map(HuggingFaceModelDetails::toModel)
|
||||
val models = modelRepository.fetchPreselectedHuggingFaceModels(memoryUsage)
|
||||
.map(HuggingFaceModelDetails::toModel)
|
||||
_managementState.emit(Download.Ready(models))
|
||||
} catch (_: CancellationException) {
|
||||
resetManagementState()
|
||||
|
|
@ -300,7 +302,6 @@ class ModelsManagementViewModel @Inject constructor(
|
|||
private val TAG = ModelsManagementViewModel::class.java.simpleName
|
||||
|
||||
private const val FETCH_HUGGINGFACE_MODELS_LIMIT_SIZE = 50
|
||||
private const val FETCH_HUGGINGFACE_MODELS_CAP_SIZE = 12
|
||||
private const val DELETE_SUCCESS_RESET_TIMEOUT_MS = 1000L
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue