data: sort preselected models according to device's available RAM
This commit is contained in:
parent
687b86e924
commit
e0ddc37e2e
|
|
@ -344,7 +344,7 @@ fun AppContent(
|
||||||
modelsManagementViewModel.toggleImportMenu(false)
|
modelsManagementViewModel.toggleImportMenu(false)
|
||||||
},
|
},
|
||||||
importFromHuggingFace = {
|
importFromHuggingFace = {
|
||||||
modelsManagementViewModel.queryModelsFromHuggingFace()
|
modelsManagementViewModel.queryModelsFromHuggingFace(memoryUsage)
|
||||||
modelsManagementViewModel.toggleImportMenu(false)
|
modelsManagementViewModel.toggleImportMenu(false)
|
||||||
}
|
}
|
||||||
),
|
),
|
||||||
|
|
|
||||||
|
|
@ -141,56 +141,56 @@ enum class ModelFilter(val displayName: String, val predicate: (ModelInfo) -> Bo
|
||||||
}),
|
}),
|
||||||
SMALL_PARAMS("Small (1-3B parameters)", {
|
SMALL_PARAMS("Small (1-3B parameters)", {
|
||||||
it.metadata.basic.sizeLabel?.let { size ->
|
it.metadata.basic.sizeLabel?.let { size ->
|
||||||
size.contains("B") && size.replace("B", "").toFloatOrNull()?.let { n -> n >= 1f && n <= 3f } == true
|
size.contains("B") && size.replace("B", "").toFloatOrNull()?.let { n -> n >= 1f && n < 4f } == true
|
||||||
} == true
|
} == true
|
||||||
}),
|
}),
|
||||||
MEDIUM_PARAMS("Medium (4-7B parameters)", {
|
MEDIUM_PARAMS("Medium (4-7B parameters)", {
|
||||||
it.metadata.basic.sizeLabel?.let { size ->
|
it.metadata.basic.sizeLabel?.let { size ->
|
||||||
size.contains("B") && size.replace("B", "").toFloatOrNull()?.let { n -> n >= 4f && n <= 7f } == true
|
size.contains("B") && size.replace("B", "").toFloatOrNull()?.let { n -> n >= 4f && n < 8f } == true
|
||||||
} == true
|
} == true
|
||||||
}),
|
}),
|
||||||
LARGE_PARAMS("Large (8-13B parameters)", {
|
LARGE_PARAMS("Large (8-12B parameters)", {
|
||||||
it.metadata.basic.sizeLabel?.let { size ->
|
it.metadata.basic.sizeLabel?.let { size ->
|
||||||
size.contains("B") && size.replace("B", "").toFloatOrNull()?.let { n -> n >= 8f && n <= 13f } == true
|
size.contains("B") && size.replace("B", "").toFloatOrNull()?.let { n -> n >= 8f && n < 13f } == true
|
||||||
} == true
|
} == true
|
||||||
}),
|
}),
|
||||||
XLARGE_PARAMS("X-Large (>13B parameters)", {
|
XLARGE_PARAMS("X-Large (>13B parameters)", {
|
||||||
it.metadata.basic.sizeLabel?.let { size ->
|
it.metadata.basic.sizeLabel?.let { size ->
|
||||||
size.contains("B") && size.replace("B", "").toFloatOrNull()?.let { n -> n > 13f } == true
|
size.contains("B") && size.replace("B", "").toFloatOrNull()?.let { n -> n >= 13f } == true
|
||||||
} == true
|
} == true
|
||||||
}),
|
}),
|
||||||
|
|
||||||
// Context length filters
|
// Context length filters
|
||||||
TINY_CONTEXT("Tiny context (<4K)", {
|
TINY_CONTEXT("Tiny context (<4K)", { model ->
|
||||||
it.metadata.dimensions?.contextLength?.let { it < 4096 } == true
|
model.metadata.dimensions?.contextLength?.let { it < 4096 } == true
|
||||||
}),
|
}),
|
||||||
SHORT_CONTEXT("Short context (4-8K)", {
|
SHORT_CONTEXT("Short context (4-8K)", { model ->
|
||||||
it.metadata.dimensions?.contextLength?.let { it >= 4096 && it <= 8192 } == true
|
model.metadata.dimensions?.contextLength?.let { it >= 4096 && it <= 8192 } == true
|
||||||
}),
|
}),
|
||||||
MEDIUM_CONTEXT("Medium context (8-32K)", {
|
MEDIUM_CONTEXT("Medium context (8-32K)", { model ->
|
||||||
it.metadata.dimensions?.contextLength?.let { it > 8192 && it <= 32768 } == true
|
model.metadata.dimensions?.contextLength?.let { it > 8192 && it <= 32768 } == true
|
||||||
}),
|
}),
|
||||||
LONG_CONTEXT("Long context (32-128K)", {
|
LONG_CONTEXT("Long context (32-128K)", { model ->
|
||||||
it.metadata.dimensions?.contextLength?.let { it > 32768 && it <= 131072 } == true
|
model.metadata.dimensions?.contextLength?.let { it > 32768 && it <= 131072 } == true
|
||||||
}),
|
}),
|
||||||
XLARGE_CONTEXT("Extended context (>128K)", {
|
XLARGE_CONTEXT("Extended context (>128K)", { model ->
|
||||||
it.metadata.dimensions?.contextLength?.let { it > 131072 } == true
|
model.metadata.dimensions?.contextLength?.let { it > 131072 } == true
|
||||||
}),
|
}),
|
||||||
|
|
||||||
// Quantization filters
|
// Quantization filters
|
||||||
INT2_QUANT("2-bit quantization", {
|
INT2_QUANT("2-bit quantization", { model ->
|
||||||
it.formattedQuantization.let { it.contains("Q2") || it.contains("IQ2") }
|
model.formattedQuantization.let { it.contains("Q2") || it.contains("IQ2") }
|
||||||
}),
|
}),
|
||||||
INT3_QUANT("3-bit quantization", {
|
INT3_QUANT("3-bit quantization", { model ->
|
||||||
it.formattedQuantization.let { it.contains("Q3") || it.contains("IQ3") }
|
model.formattedQuantization.let { it.contains("Q3") || it.contains("IQ3") }
|
||||||
}),
|
}),
|
||||||
INT4_QUANT("4-bit quantization", {
|
INT4_QUANT("4-bit quantization", { model ->
|
||||||
it.formattedQuantization.let { it.contains("Q4") || it.contains("IQ4") }
|
model.formattedQuantization.let { it.contains("Q4") || it.contains("IQ4") }
|
||||||
}),
|
}),
|
||||||
|
|
||||||
// Special features
|
// Special features
|
||||||
MULTILINGUAL("Multilingual", {
|
MULTILINGUAL("Multilingual", { model ->
|
||||||
it.languages?.let { languages ->
|
model.languages?.let { languages ->
|
||||||
languages.size > 1 || languages.any { it.contains("multi", ignoreCase = true) }
|
languages.size > 1 || languages.any { it.contains("multi", ignoreCase = true) }
|
||||||
} == true
|
} == true
|
||||||
}),
|
}),
|
||||||
|
|
|
||||||
|
|
@ -16,6 +16,7 @@ import com.example.llama.data.source.remote.HuggingFaceDownloadInfo
|
||||||
import com.example.llama.data.source.remote.HuggingFaceModel
|
import com.example.llama.data.source.remote.HuggingFaceModel
|
||||||
import com.example.llama.data.source.remote.HuggingFaceModelDetails
|
import com.example.llama.data.source.remote.HuggingFaceModelDetails
|
||||||
import com.example.llama.data.source.remote.HuggingFaceRemoteDataSource
|
import com.example.llama.data.source.remote.HuggingFaceRemoteDataSource
|
||||||
|
import com.example.llama.monitoring.MemoryMetrics
|
||||||
import com.example.llama.monitoring.StorageMetrics
|
import com.example.llama.monitoring.StorageMetrics
|
||||||
import com.example.llama.util.formatFileByteSize
|
import com.example.llama.util.formatFileByteSize
|
||||||
import dagger.hilt.android.qualifiers.ApplicationContext
|
import dagger.hilt.android.qualifiers.ApplicationContext
|
||||||
|
|
@ -83,7 +84,7 @@ interface ModelRepository {
|
||||||
/**
|
/**
|
||||||
* Fetch details of preselected models
|
* Fetch details of preselected models
|
||||||
*/
|
*/
|
||||||
suspend fun fetchPreselectedHuggingFaceModels(): List<HuggingFaceModelDetails>
|
suspend fun fetchPreselectedHuggingFaceModels(memoryUsage: MemoryMetrics): List<HuggingFaceModelDetails>
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Search models on HuggingFace
|
* Search models on HuggingFace
|
||||||
|
|
@ -321,8 +322,8 @@ class ModelRepositoryImpl @Inject constructor(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
override suspend fun fetchPreselectedHuggingFaceModels() = withContext(Dispatchers.IO) {
|
override suspend fun fetchPreselectedHuggingFaceModels(memoryUsage: MemoryMetrics) = withContext(Dispatchers.IO) {
|
||||||
huggingFaceRemoteDataSource.fetchPreselectedModels()
|
huggingFaceRemoteDataSource.fetchPreselectedModels(memoryUsage)
|
||||||
}
|
}
|
||||||
|
|
||||||
override suspend fun searchHuggingFaceModels(
|
override suspend fun searchHuggingFaceModels(
|
||||||
|
|
|
||||||
|
|
@ -1,27 +1,12 @@
|
||||||
package com.example.llama.data.source.remote
|
package com.example.llama.data.source.remote
|
||||||
|
|
||||||
import android.app.DownloadManager
|
|
||||||
import android.content.Context
|
import android.content.Context
|
||||||
import android.os.Environment
|
import com.example.llama.monitoring.MemoryMetrics
|
||||||
import android.util.Log
|
|
||||||
import kotlinx.coroutines.Dispatchers
|
|
||||||
import kotlinx.coroutines.async
|
|
||||||
import kotlinx.coroutines.awaitAll
|
|
||||||
import kotlinx.coroutines.delay
|
|
||||||
import kotlinx.coroutines.supervisorScope
|
|
||||||
import kotlinx.coroutines.sync.Semaphore
|
|
||||||
import kotlinx.coroutines.sync.withPermit
|
|
||||||
import kotlinx.coroutines.withContext
|
|
||||||
import retrofit2.HttpException
|
|
||||||
import java.io.FileNotFoundException
|
|
||||||
import java.io.IOException
|
|
||||||
import java.net.SocketTimeoutException
|
|
||||||
import java.net.UnknownHostException
|
|
||||||
import javax.inject.Inject
|
|
||||||
import javax.inject.Singleton
|
|
||||||
import kotlin.coroutines.cancellation.CancellationException
|
|
||||||
import kotlin.math.ceil
|
|
||||||
|
|
||||||
|
|
||||||
|
/*
|
||||||
|
* HuggingFace Search API
|
||||||
|
*/
|
||||||
private const val QUERY_Q4_0_GGUF = "gguf q4_0"
|
private const val QUERY_Q4_0_GGUF = "gguf q4_0"
|
||||||
private const val FILTER_TEXT_GENERATION = "text-generation"
|
private const val FILTER_TEXT_GENERATION = "text-generation"
|
||||||
private const val SORT_BY_DOWNLOADS = "downloads"
|
private const val SORT_BY_DOWNLOADS = "downloads"
|
||||||
|
|
@ -29,21 +14,10 @@ private const val SEARCH_RESULT_LIMIT = 30
|
||||||
|
|
||||||
private val INVALID_KEYWORDS = arrayOf("-of-", "split", "70B", "30B", "27B", "14B", "13B", "12B")
|
private val INVALID_KEYWORDS = arrayOf("-of-", "split", "70B", "30B", "27B", "14B", "13B", "12B")
|
||||||
|
|
||||||
private val PRESELECTED_MODEL_IDS = listOf(
|
|
||||||
"unsloth/gemma-3-1b-it-GGUF",
|
|
||||||
"unsloth/gemma-3-4b-it-GGUF",
|
|
||||||
"bartowski/Llama-3.2-1B-Instruct-GGUF",
|
|
||||||
"bartowski/Llama-3.2-3B-Instruct-GGUF",
|
|
||||||
"Qwen/Qwen2.5-3B-Instruct-GGUF",
|
|
||||||
"gaianet/Phi-4-mini-instruct-GGUF",
|
|
||||||
"bartowski/granite-3.0-2b-instruct-GGUF",
|
|
||||||
"bartowski/Meta-Llama-3.1-8B-Instruct-GGUF",
|
|
||||||
)
|
|
||||||
|
|
||||||
interface HuggingFaceRemoteDataSource {
|
interface HuggingFaceRemoteDataSource {
|
||||||
|
|
||||||
suspend fun fetchPreselectedModels(
|
suspend fun fetchPreselectedModels(
|
||||||
ids: List<String> = PRESELECTED_MODEL_IDS,
|
memoryUsage: MemoryMetrics,
|
||||||
parallelCount: Int = 3,
|
parallelCount: Int = 3,
|
||||||
quorum: Float = 0.5f,
|
quorum: Float = 0.5f,
|
||||||
): List<HuggingFaceModelDetails>
|
): List<HuggingFaceModelDetails>
|
||||||
|
|
@ -58,6 +32,7 @@ interface HuggingFaceRemoteDataSource {
|
||||||
direction: String? = "-1",
|
direction: String? = "-1",
|
||||||
limit: Int? = SEARCH_RESULT_LIMIT,
|
limit: Int? = SEARCH_RESULT_LIMIT,
|
||||||
full: Boolean = true,
|
full: Boolean = true,
|
||||||
|
invalidKeywords: Array<String> = INVALID_KEYWORDS
|
||||||
): Result<List<HuggingFaceModel>>
|
): Result<List<HuggingFaceModel>>
|
||||||
|
|
||||||
suspend fun getModelDetails(modelId: String): HuggingFaceModelDetails
|
suspend fun getModelDetails(modelId: String): HuggingFaceModelDetails
|
||||||
|
|
@ -76,208 +51,3 @@ interface HuggingFaceRemoteDataSource {
|
||||||
): Result<Long>
|
): Result<Long>
|
||||||
}
|
}
|
||||||
|
|
||||||
@Singleton
|
|
||||||
class HuggingFaceRemoteDataSourceImpl @Inject constructor(
|
|
||||||
private val apiService: HuggingFaceApiService
|
|
||||||
) : HuggingFaceRemoteDataSource {
|
|
||||||
|
|
||||||
override suspend fun fetchPreselectedModels(
|
|
||||||
ids: List<String>,
|
|
||||||
parallelCount: Int,
|
|
||||||
quorum: Float,
|
|
||||||
): List<HuggingFaceModelDetails> = withContext(Dispatchers.IO) {
|
|
||||||
val sem = Semaphore(parallelCount)
|
|
||||||
val results = supervisorScope {
|
|
||||||
ids.map { id ->
|
|
||||||
async {
|
|
||||||
sem.withPermit {
|
|
||||||
try {
|
|
||||||
Result.success(getModelDetails(id))
|
|
||||||
} catch (t: CancellationException) {
|
|
||||||
Result.failure(t)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}.awaitAll()
|
|
||||||
}
|
|
||||||
|
|
||||||
val successes = results.mapNotNull { it.getOrNull() }
|
|
||||||
val failures = results.mapNotNull { it.exceptionOrNull() }
|
|
||||||
|
|
||||||
val total = ids.size
|
|
||||||
val failed = failures.size
|
|
||||||
val ok = successes.size
|
|
||||||
val shouldThrow = failed >= ceil(total * quorum).toInt()
|
|
||||||
|
|
||||||
if (!shouldThrow) return@withContext successes.toList()
|
|
||||||
|
|
||||||
// 1. No Network
|
|
||||||
if (failures.count { it is UnknownHostException } >= ceil(failed * 0.5).toInt()) {
|
|
||||||
throw UnknownHostException()
|
|
||||||
}
|
|
||||||
|
|
||||||
// 2. Time out
|
|
||||||
if (failures.count { it is SocketTimeoutException } >= ceil(failed * 0.5).toInt()) {
|
|
||||||
throw SocketTimeoutException()
|
|
||||||
}
|
|
||||||
|
|
||||||
// 3. known error codes: 404/410/204
|
|
||||||
val http404ish = failures.count { (it as? HttpException)?.code() in listOf(404, 410, 204) }
|
|
||||||
if (ok == 0 && (failed > 0) && (http404ish >= ceil(failed * 0.5).toInt() || failed == total)) {
|
|
||||||
throw FileNotFoundException()
|
|
||||||
}
|
|
||||||
|
|
||||||
// 4. Unknown issues
|
|
||||||
val ioMajority = failures.count {
|
|
||||||
it is IOException && it !is UnknownHostException && it !is SocketTimeoutException
|
|
||||||
} >= ceil(failed * 0.5).toInt()
|
|
||||||
if (ioMajority) {
|
|
||||||
throw IOException(failures.first { it is IOException }.message)
|
|
||||||
}
|
|
||||||
|
|
||||||
successes
|
|
||||||
}
|
|
||||||
|
|
||||||
override suspend fun searchModels(
|
|
||||||
query: String?,
|
|
||||||
filter: String?,
|
|
||||||
sort: String?,
|
|
||||||
direction: String?,
|
|
||||||
limit: Int?,
|
|
||||||
full: Boolean,
|
|
||||||
) = withContext(Dispatchers.IO) {
|
|
||||||
try {
|
|
||||||
apiService.getModels(
|
|
||||||
search = query,
|
|
||||||
filter = filter,
|
|
||||||
sort = sort,
|
|
||||||
direction = direction,
|
|
||||||
limit = limit,
|
|
||||||
full = full,
|
|
||||||
)
|
|
||||||
.filterNot { it.gated || it.private }
|
|
||||||
.filterNot {
|
|
||||||
it.getGgufFilename().let { filename ->
|
|
||||||
filename.isNullOrBlank() || INVALID_KEYWORDS.any {
|
|
||||||
filename.contains(it, ignoreCase = true)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}.let {
|
|
||||||
if (it.isEmpty()) Result.failure(FileNotFoundException())
|
|
||||||
else Result.success(it)
|
|
||||||
}
|
|
||||||
} catch (e: Exception) {
|
|
||||||
Log.e(TAG, "Error searching for models on HuggingFace: ${e.message}")
|
|
||||||
Result.failure(e)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
override suspend fun getModelDetails(
|
|
||||||
modelId: String
|
|
||||||
) = withContext(Dispatchers.IO) {
|
|
||||||
apiService.getModelDetails(modelId)
|
|
||||||
}
|
|
||||||
|
|
||||||
override suspend fun getFileSize(
|
|
||||||
modelId: String,
|
|
||||||
filePath: String
|
|
||||||
): Result<Long> = withContext(Dispatchers.IO) {
|
|
||||||
try {
|
|
||||||
apiService.getModelFileHeader(modelId, filePath).let { resp ->
|
|
||||||
if (resp.isSuccessful) {
|
|
||||||
resp.headers()[HTTP_HEADER_CONTENT_LENGTH]?.toLongOrNull()?.let {
|
|
||||||
Result.success(it)
|
|
||||||
} ?: Result.failure(IOException("Content-Length header missing"))
|
|
||||||
} else {
|
|
||||||
Result.failure(
|
|
||||||
when (resp.code()) {
|
|
||||||
401 -> SecurityException("Model requires authentication")
|
|
||||||
404 -> FileNotFoundException("Model file not found")
|
|
||||||
else -> IOException("Failed to get file info: HTTP ${resp.code()}")
|
|
||||||
}
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} catch (e: Exception) {
|
|
||||||
Log.e(TAG, "Error getting file size for $modelId: ${e.message}")
|
|
||||||
Result.failure(e)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
override suspend fun downloadModelFile(
|
|
||||||
context: Context,
|
|
||||||
downloadInfo: HuggingFaceDownloadInfo,
|
|
||||||
): Result<Long> = withContext(Dispatchers.IO) {
|
|
||||||
try {
|
|
||||||
val downloadManager =
|
|
||||||
context.getSystemService(Context.DOWNLOAD_SERVICE) as DownloadManager
|
|
||||||
val request = DownloadManager.Request(downloadInfo.uri).apply {
|
|
||||||
setTitle(downloadInfo.filename)
|
|
||||||
setDescription("Downloading directly from HuggingFace")
|
|
||||||
setNotificationVisibility(DownloadManager.Request.VISIBILITY_VISIBLE_NOTIFY_COMPLETED)
|
|
||||||
setDestinationInExternalPublicDir(
|
|
||||||
Environment.DIRECTORY_DOWNLOADS,
|
|
||||||
downloadInfo.filename
|
|
||||||
)
|
|
||||||
setAllowedNetworkTypes(
|
|
||||||
DownloadManager.Request.NETWORK_WIFI or DownloadManager.Request.NETWORK_MOBILE
|
|
||||||
)
|
|
||||||
setAllowedOverMetered(true)
|
|
||||||
setAllowedOverRoaming(false)
|
|
||||||
}
|
|
||||||
Log.d(TAG, "Enqueuing download request for: ${downloadInfo.modelId}")
|
|
||||||
val downloadId = downloadManager.enqueue(request)
|
|
||||||
|
|
||||||
delay(DOWNLOAD_MANAGER_DOUBLE_CHECK_DELAY)
|
|
||||||
|
|
||||||
val cursor = downloadManager.query(DownloadManager.Query().setFilterById(downloadId))
|
|
||||||
if (cursor != null && cursor.moveToFirst()) {
|
|
||||||
val statusIndex = cursor.getColumnIndex(DownloadManager.COLUMN_STATUS)
|
|
||||||
if (statusIndex >= 0) {
|
|
||||||
val status = cursor.getInt(statusIndex)
|
|
||||||
cursor.close()
|
|
||||||
|
|
||||||
when (status) {
|
|
||||||
DownloadManager.STATUS_FAILED -> {
|
|
||||||
// Get failure reason if available
|
|
||||||
val reasonIndex = cursor.getColumnIndex(DownloadManager.COLUMN_REASON)
|
|
||||||
val reason = if (reasonIndex >= 0) cursor.getInt(reasonIndex) else -1
|
|
||||||
val errorMessage = when (reason) {
|
|
||||||
DownloadManager.ERROR_HTTP_DATA_ERROR -> "HTTP error"
|
|
||||||
DownloadManager.ERROR_INSUFFICIENT_SPACE -> "Insufficient storage"
|
|
||||||
DownloadManager.ERROR_TOO_MANY_REDIRECTS -> "Too many redirects"
|
|
||||||
DownloadManager.ERROR_UNHANDLED_HTTP_CODE -> "Unhandled HTTP code"
|
|
||||||
DownloadManager.ERROR_CANNOT_RESUME -> "Cannot resume download"
|
|
||||||
DownloadManager.ERROR_FILE_ERROR -> "File error"
|
|
||||||
else -> "Unknown error"
|
|
||||||
}
|
|
||||||
Result.failure(Exception(errorMessage))
|
|
||||||
}
|
|
||||||
else -> {
|
|
||||||
// Download is pending, paused, or running
|
|
||||||
Result.success(downloadId)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// Assume success if we can't check status
|
|
||||||
cursor.close()
|
|
||||||
Result.success(downloadId)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// Assume success if cursor is empty
|
|
||||||
cursor?.close()
|
|
||||||
Result.success(downloadId)
|
|
||||||
}
|
|
||||||
} catch (e: Exception) {
|
|
||||||
Log.e(TAG, "Failed to enqueue download: ${e.message}")
|
|
||||||
Result.failure(e)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
companion object {
|
|
||||||
private val TAG = HuggingFaceRemoteDataSourceImpl::class.java.simpleName
|
|
||||||
|
|
||||||
private const val HTTP_HEADER_CONTENT_LENGTH = "content-length"
|
|
||||||
private const val DOWNLOAD_MANAGER_DOUBLE_CHECK_DELAY = 500L
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,272 @@
|
||||||
|
package com.example.llama.data.source.remote
|
||||||
|
|
||||||
|
import android.app.DownloadManager
|
||||||
|
import android.content.Context
|
||||||
|
import android.os.Environment
|
||||||
|
import android.util.Log
|
||||||
|
import com.example.llama.monitoring.MemoryMetrics
|
||||||
|
import kotlinx.coroutines.Dispatchers
|
||||||
|
import kotlinx.coroutines.async
|
||||||
|
import kotlinx.coroutines.awaitAll
|
||||||
|
import kotlinx.coroutines.delay
|
||||||
|
import kotlinx.coroutines.supervisorScope
|
||||||
|
import kotlinx.coroutines.sync.Semaphore
|
||||||
|
import kotlinx.coroutines.sync.withPermit
|
||||||
|
import kotlinx.coroutines.withContext
|
||||||
|
import retrofit2.HttpException
|
||||||
|
import java.io.FileNotFoundException
|
||||||
|
import java.io.IOException
|
||||||
|
import java.net.SocketTimeoutException
|
||||||
|
import java.net.UnknownHostException
|
||||||
|
import javax.inject.Inject
|
||||||
|
import javax.inject.Singleton
|
||||||
|
import kotlin.collections.contains
|
||||||
|
import kotlin.coroutines.cancellation.CancellationException
|
||||||
|
import kotlin.math.ceil
|
||||||
|
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Preselected models: sized <2GB
|
||||||
|
*/
|
||||||
|
private val PRESELECTED_MODEL_IDS_SMALL = listOf(
|
||||||
|
"bartowski/Llama-3.2-1B-Instruct-GGUF",
|
||||||
|
"unsloth/gemma-3-1b-it-GGUF",
|
||||||
|
"bartowski/granite-3.0-2b-instruct-GGUF",
|
||||||
|
)
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Preselected models: sized 2~3GB
|
||||||
|
*/
|
||||||
|
private val PRESELECTED_MODEL_IDS_MEDIUM = listOf(
|
||||||
|
"bartowski/Llama-3.2-3B-Instruct-GGUF",
|
||||||
|
"unsloth/gemma-3n-E2B-it-GGUF",
|
||||||
|
"Qwen/Qwen2.5-3B-Instruct-GGUF",
|
||||||
|
"gaianet/Phi-4-mini-instruct-GGUF",
|
||||||
|
"unsloth/gemma-3-4b-it-GGUF",
|
||||||
|
)
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Preselected models: sized 4~6GB
|
||||||
|
*/
|
||||||
|
private val PRESELECTED_MODEL_IDS_LARGE = listOf(
|
||||||
|
"unsloth/gemma-3n-E4B-it-GGUF",
|
||||||
|
"bartowski/Meta-Llama-3.1-8B-Instruct-GGUF",
|
||||||
|
)
|
||||||
|
|
||||||
|
@Singleton
|
||||||
|
class HuggingFaceRemoteDataSourceImpl @Inject constructor(
|
||||||
|
private val apiService: HuggingFaceApiService
|
||||||
|
) : HuggingFaceRemoteDataSource {
|
||||||
|
|
||||||
|
/**
 * Fetches details for the preselected models that suit the device's available memory.
 *
 * The candidate ids are chosen from `memoryUsage.availableGB`:
 * - `>= 7`: medium, then large, then small models
 * - `>= 5`: small, then medium, then large models
 * - `>= 3`: small, then medium models
 * - otherwise: small models only
 *
 * Detail requests run concurrently, at most [parallelCount] at a time. Individual
 * failures are collected instead of aborting the whole fetch; the call only fails
 * once the number of failed requests reaches `ceil(total * quorum)`.
 *
 * @param memoryUsage memory snapshot used to pick and order the candidate models.
 * @param parallelCount maximum number of detail requests in flight at once.
 * @param quorum fraction of candidates whose failure makes the whole fetch count as failed.
 * @return details of the successfully fetched models, in candidate order.
 * @throws UnknownHostException when at least half of the failures are host-resolution errors.
 * @throws SocketTimeoutException when at least half of the failures are timeouts.
 * @throws FileNotFoundException when nothing could be fetched at all.
 * @throws IOException when at least half of the failures are other I/O errors.
 */
override suspend fun fetchPreselectedModels(
    memoryUsage: MemoryMetrics,
    parallelCount: Int,
    quorum: Float,
): List<HuggingFaceModelDetails> = withContext(Dispatchers.IO) {
    val ids: List<String> = when {
        memoryUsage.availableGB >= 7f ->
            PRESELECTED_MODEL_IDS_MEDIUM + PRESELECTED_MODEL_IDS_LARGE + PRESELECTED_MODEL_IDS_SMALL
        memoryUsage.availableGB >= 5f ->
            PRESELECTED_MODEL_IDS_SMALL + PRESELECTED_MODEL_IDS_MEDIUM + PRESELECTED_MODEL_IDS_LARGE
        memoryUsage.availableGB >= 3f ->
            PRESELECTED_MODEL_IDS_SMALL + PRESELECTED_MODEL_IDS_MEDIUM
        else ->
            PRESELECTED_MODEL_IDS_SMALL
    }

    val sem = Semaphore(parallelCount)
    val results = supervisorScope {
        ids.map { id ->
            async {
                sem.withPermit {
                    try {
                        Result.success(getModelDetails(id))
                    } catch (e: CancellationException) {
                        // Cancellation must propagate; it is not a fetch failure.
                        throw e
                    } catch (e: Exception) {
                        // Record the failure so the quorum logic below can classify it,
                        // instead of letting awaitAll() rethrow the first error.
                        Result.failure(e)
                    }
                }
            }
        }.awaitAll()
    }

    val successes = results.mapNotNull { it.getOrNull() }
    val failures = results.mapNotNull { it.exceptionOrNull() }

    val total = ids.size
    val failed = failures.size
    val ok = successes.size

    // Nothing failed, or too few failed to reach the quorum: return what we have.
    if (failed == 0 || failed < ceil(total * quorum).toInt()) return@withContext successes

    // A failure category "wins" once it accounts for at least half of all failures.
    val majority = ceil(failed * 0.5).toInt()

    // 1. No network
    if (failures.count { it is UnknownHostException } >= majority) {
        throw UnknownHostException()
    }

    // 2. Time out
    if (failures.count { it is SocketTimeoutException } >= majority) {
        throw SocketTimeoutException()
    }

    // 3. Known error codes: 404/410/204
    val notFoundCodes = setOf(404, 410, 204)
    val http404ish = failures.count { (it as? HttpException)?.code() in notFoundCodes }
    if (ok == 0 && (http404ish >= majority || failed == total)) {
        throw FileNotFoundException()
    }

    // 4. Other I/O issues
    val otherIoFailures = failures.filter {
        it is IOException && it !is UnknownHostException && it !is SocketTimeoutException
    }
    if (otherIoFailures.size >= majority) {
        val first = otherIoFailures.first()
        throw IOException(first.message, first)
    }

    successes
}
|
||||||
|
|
||||||
|
/**
 * Searches HuggingFace for models and filters out the ones that cannot be used.
 *
 * A model is dropped when it is gated or private, when it has no GGUF filename,
 * or when that filename contains any of [invalidKeywords] (case-insensitive).
 *
 * @return a success holding the remaining models; a failure wrapping
 * [FileNotFoundException] when none remain, or wrapping the exception thrown by
 * the request. Coroutine cancellation is rethrown rather than wrapped.
 */
override suspend fun searchModels(
    query: String?,
    filter: String?,
    sort: String?,
    direction: String?,
    limit: Int?,
    full: Boolean,
    invalidKeywords: Array<String>,
): Result<List<HuggingFaceModel>> = withContext(Dispatchers.IO) {
    try {
        val usableModels = apiService.getModels(
            search = query,
            filter = filter,
            sort = sort,
            direction = direction,
            limit = limit,
            full = full,
        )
            .filterNot { model -> model.gated || model.private }
            .filterNot { model ->
                val filename = model.getGgufFilename()
                filename.isNullOrBlank() || invalidKeywords.any { keyword ->
                    filename.contains(keyword, ignoreCase = true)
                }
            }

        if (usableModels.isEmpty()) {
            Result.failure(FileNotFoundException())
        } else {
            Result.success(usableModels)
        }
    } catch (e: CancellationException) {
        // Never convert cancellation into a failure result.
        throw e
    } catch (e: Exception) {
        Log.e(TAG, "Error searching for models on HuggingFace: ${e.message}")
        Result.failure(e)
    }
}
|
||||||
|
|
||||||
|
/**
 * Loads the details of a single model from the API service on the IO dispatcher.
 *
 * Exceptions thrown by the underlying request are not caught here.
 *
 * @param modelId id of the model to look up.
 */
override suspend fun getModelDetails(modelId: String): HuggingFaceModelDetails =
    withContext(Dispatchers.IO) {
        val details = apiService.getModelDetails(modelId)
        details
    }
|
||||||
|
|
||||||
|
/**
 * Reads the size of a model file from the `content-length` header of its header request.
 *
 * @param modelId id of the model that owns the file.
 * @param filePath path of the file within the model.
 * @return a success holding the header value as a [Long]; otherwise a failure wrapping
 * - [IOException] when the header is absent or not a valid number,
 * - [SecurityException] for HTTP 401,
 * - [FileNotFoundException] for HTTP 404,
 * - [IOException] for any other unsuccessful status,
 * - or the exception thrown by the request itself.
 * Coroutine cancellation is rethrown rather than wrapped.
 */
override suspend fun getFileSize(
    modelId: String,
    filePath: String
): Result<Long> = withContext(Dispatchers.IO) {
    try {
        val resp = apiService.getModelFileHeader(modelId, filePath)
        if (resp.isSuccessful) {
            resp.headers()[HTTP_HEADER_CONTENT_LENGTH]?.toLongOrNull()?.let {
                Result.success(it)
            } ?: Result.failure(IOException("Content-Length header missing"))
        } else {
            Result.failure(
                when (resp.code()) {
                    401 -> SecurityException("Model requires authentication")
                    404 -> FileNotFoundException("Model file not found")
                    else -> IOException("Failed to get file info: HTTP ${resp.code()}")
                }
            )
        }
    } catch (e: CancellationException) {
        // Never convert cancellation into a failure result.
        throw e
    } catch (e: Exception) {
        Log.e(TAG, "Error getting file size for $modelId: ${e.message}")
        Result.failure(e)
    }
}
|
||||||
|
|
||||||
|
/**
 * Enqueues a model file download with the system [DownloadManager] and re-checks it
 * shortly afterwards to catch downloads that fail immediately.
 *
 * The file is saved into the public Downloads directory under `downloadInfo.filename`.
 * After enqueueing, the function waits [DOWNLOAD_MANAGER_DOUBLE_CHECK_DELAY] and then
 * queries the download's status. Only an explicit `STATUS_FAILED` is reported as a
 * failure; if the status cannot be determined the download is assumed to be in progress.
 *
 * @param context context used to obtain the [DownloadManager] system service.
 * @param downloadInfo uri, filename and model id of the file to download.
 * @return a success holding the download id, or a failure describing why the download
 * failed or could not be enqueued. Coroutine cancellation is rethrown rather than wrapped.
 */
override suspend fun downloadModelFile(
    context: Context,
    downloadInfo: HuggingFaceDownloadInfo,
): Result<Long> = withContext(Dispatchers.IO) {
    try {
        val downloadManager =
            context.getSystemService(Context.DOWNLOAD_SERVICE) as DownloadManager
        val request = DownloadManager.Request(downloadInfo.uri).apply {
            setTitle(downloadInfo.filename)
            setDescription("Downloading directly from HuggingFace")
            setNotificationVisibility(DownloadManager.Request.VISIBILITY_VISIBLE_NOTIFY_COMPLETED)
            setDestinationInExternalPublicDir(
                Environment.DIRECTORY_DOWNLOADS,
                downloadInfo.filename
            )
            setAllowedNetworkTypes(
                DownloadManager.Request.NETWORK_WIFI or DownloadManager.Request.NETWORK_MOBILE
            )
            setAllowedOverMetered(true)
            setAllowedOverRoaming(false)
        }
        Log.d(TAG, "Enqueuing download request for: ${downloadInfo.modelId}")
        val downloadId = downloadManager.enqueue(request)

        delay(DOWNLOAD_MANAGER_DOUBLE_CHECK_DELAY)

        // Null means "no failure detected". The cursor stays open until every column
        // has been read; `use` closes it on all paths, including exceptions.
        val failureMessage: String? = downloadManager
            .query(DownloadManager.Query().setFilterById(downloadId))
            ?.use { cursor ->
                // Assume success if the cursor is empty.
                if (!cursor.moveToFirst()) return@use null

                // Assume success if the status cannot be checked,
                // or the download is pending, paused, running or finished.
                val statusIndex = cursor.getColumnIndex(DownloadManager.COLUMN_STATUS)
                if (statusIndex < 0 ||
                    cursor.getInt(statusIndex) != DownloadManager.STATUS_FAILED
                ) {
                    return@use null
                }

                // Get the failure reason if available.
                val reasonIndex = cursor.getColumnIndex(DownloadManager.COLUMN_REASON)
                val reason = if (reasonIndex >= 0) cursor.getInt(reasonIndex) else -1
                when (reason) {
                    DownloadManager.ERROR_HTTP_DATA_ERROR -> "HTTP error"
                    DownloadManager.ERROR_INSUFFICIENT_SPACE -> "Insufficient storage"
                    DownloadManager.ERROR_TOO_MANY_REDIRECTS -> "Too many redirects"
                    DownloadManager.ERROR_UNHANDLED_HTTP_CODE -> "Unhandled HTTP code"
                    DownloadManager.ERROR_CANNOT_RESUME -> "Cannot resume download"
                    DownloadManager.ERROR_FILE_ERROR -> "File error"
                    else -> "Unknown error"
                }
            }

        if (failureMessage != null) {
            Result.failure<Long>(Exception(failureMessage))
        } else {
            Result.success(downloadId)
        }
    } catch (e: CancellationException) {
        // Never convert cancellation (e.g. during the delay) into a failure result.
        throw e
    } catch (e: Exception) {
        Log.e(TAG, "Failed to enqueue download: ${e.message}")
        Result.failure(e)
    }
}
|
||||||
|
|
||||||
|
companion object {
    /** Value passed to `delay` between enqueueing a download and re-checking its status. */
    private const val DOWNLOAD_MANAGER_DOUBLE_CHECK_DELAY = 500L

    /** Name of the response header that `getFileSize` reads the file size from. */
    private const val HTTP_HEADER_CONTENT_LENGTH = "content-length"

    /** Log tag: the simple name of this class. */
    private val TAG = HuggingFaceRemoteDataSourceImpl::class.java.simpleName
}
|
||||||
|
}
|
||||||
|
|
@ -17,6 +17,7 @@ import com.example.llama.data.repo.ModelRepository
|
||||||
import com.example.llama.data.source.remote.HuggingFaceDownloadInfo
|
import com.example.llama.data.source.remote.HuggingFaceDownloadInfo
|
||||||
import com.example.llama.data.source.remote.HuggingFaceModel
|
import com.example.llama.data.source.remote.HuggingFaceModel
|
||||||
import com.example.llama.data.source.remote.HuggingFaceModelDetails
|
import com.example.llama.data.source.remote.HuggingFaceModelDetails
|
||||||
|
import com.example.llama.monitoring.MemoryMetrics
|
||||||
import com.example.llama.util.formatFileByteSize
|
import com.example.llama.util.formatFileByteSize
|
||||||
import com.example.llama.util.getFileNameFromUri
|
import com.example.llama.util.getFileNameFromUri
|
||||||
import com.example.llama.util.getFileSizeFromUri
|
import com.example.llama.util.getFileSizeFromUri
|
||||||
|
|
@ -183,11 +184,12 @@ class ModelsManagementViewModel @Inject constructor(
|
||||||
/**
|
/**
|
||||||
* Query models on HuggingFace available for download even without signing in
|
* Query models on HuggingFace available for download even without signing in
|
||||||
*/
|
*/
|
||||||
fun queryModelsFromHuggingFace(cap: Int = FETCH_HUGGINGFACE_MODELS_CAP_SIZE) {
|
fun queryModelsFromHuggingFace(memoryUsage: MemoryMetrics) {
|
||||||
huggingFaceQueryJob = viewModelScope.launch {
|
huggingFaceQueryJob = viewModelScope.launch {
|
||||||
_managementState.emit(Download.Querying)
|
_managementState.emit(Download.Querying)
|
||||||
try {
|
try {
|
||||||
val models = modelRepository.fetchPreselectedHuggingFaceModels().map(HuggingFaceModelDetails::toModel)
|
val models = modelRepository.fetchPreselectedHuggingFaceModels(memoryUsage)
|
||||||
|
.map(HuggingFaceModelDetails::toModel)
|
||||||
_managementState.emit(Download.Ready(models))
|
_managementState.emit(Download.Ready(models))
|
||||||
} catch (_: CancellationException) {
|
} catch (_: CancellationException) {
|
||||||
resetManagementState()
|
resetManagementState()
|
||||||
|
|
@ -300,7 +302,6 @@ class ModelsManagementViewModel @Inject constructor(
|
||||||
private val TAG = ModelsManagementViewModel::class.java.simpleName
|
private val TAG = ModelsManagementViewModel::class.java.simpleName
|
||||||
|
|
||||||
private const val FETCH_HUGGINGFACE_MODELS_LIMIT_SIZE = 50
|
private const val FETCH_HUGGINGFACE_MODELS_LIMIT_SIZE = 50
|
||||||
private const val FETCH_HUGGINGFACE_MODELS_CAP_SIZE = 12
|
|
||||||
private const val DELETE_SUCCESS_RESET_TIMEOUT_MS = 1000L
|
private const val DELETE_SUCCESS_RESET_TIMEOUT_MS = 1000L
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue