data: sort preselected models according to device's available RAM

This commit is contained in:
Han Yin 2025-09-04 14:14:44 -07:00
parent 687b86e924
commit e0ddc37e2e
6 changed files with 311 additions and 267 deletions

View File

@ -344,7 +344,7 @@ fun AppContent(
modelsManagementViewModel.toggleImportMenu(false)
},
importFromHuggingFace = {
modelsManagementViewModel.queryModelsFromHuggingFace()
modelsManagementViewModel.queryModelsFromHuggingFace(memoryUsage)
modelsManagementViewModel.toggleImportMenu(false)
}
),

View File

@ -141,56 +141,56 @@ enum class ModelFilter(val displayName: String, val predicate: (ModelInfo) -> Bo
}),
SMALL_PARAMS("Small (1-3B parameters)", {
it.metadata.basic.sizeLabel?.let { size ->
size.contains("B") && size.replace("B", "").toFloatOrNull()?.let { n -> n >= 1f && n <= 3f } == true
size.contains("B") && size.replace("B", "").toFloatOrNull()?.let { n -> n >= 1f && n < 4f } == true
} == true
}),
MEDIUM_PARAMS("Medium (4-7B parameters)", {
it.metadata.basic.sizeLabel?.let { size ->
size.contains("B") && size.replace("B", "").toFloatOrNull()?.let { n -> n >= 4f && n <= 7f } == true
size.contains("B") && size.replace("B", "").toFloatOrNull()?.let { n -> n >= 4f && n < 8f } == true
} == true
}),
LARGE_PARAMS("Large (8-13B parameters)", {
LARGE_PARAMS("Large (8-12B parameters)", {
it.metadata.basic.sizeLabel?.let { size ->
size.contains("B") && size.replace("B", "").toFloatOrNull()?.let { n -> n >= 8f && n <= 13f } == true
size.contains("B") && size.replace("B", "").toFloatOrNull()?.let { n -> n >= 8f && n < 13f } == true
} == true
}),
XLARGE_PARAMS("X-Large (>13B parameters)", {
it.metadata.basic.sizeLabel?.let { size ->
size.contains("B") && size.replace("B", "").toFloatOrNull()?.let { n -> n > 13f } == true
size.contains("B") && size.replace("B", "").toFloatOrNull()?.let { n -> n >= 13f } == true
} == true
}),
// Context length filters
TINY_CONTEXT("Tiny context (<4K)", {
it.metadata.dimensions?.contextLength?.let { it < 4096 } == true
TINY_CONTEXT("Tiny context (<4K)", { model ->
model.metadata.dimensions?.contextLength?.let { it < 4096 } == true
}),
SHORT_CONTEXT("Short context (4-8K)", {
it.metadata.dimensions?.contextLength?.let { it >= 4096 && it <= 8192 } == true
SHORT_CONTEXT("Short context (4-8K)", { model ->
model.metadata.dimensions?.contextLength?.let { it >= 4096 && it <= 8192 } == true
}),
MEDIUM_CONTEXT("Medium context (8-32K)", {
it.metadata.dimensions?.contextLength?.let { it > 8192 && it <= 32768 } == true
MEDIUM_CONTEXT("Medium context (8-32K)", { model ->
model.metadata.dimensions?.contextLength?.let { it > 8192 && it <= 32768 } == true
}),
LONG_CONTEXT("Long context (32-128K)", {
it.metadata.dimensions?.contextLength?.let { it > 32768 && it <= 131072 } == true
LONG_CONTEXT("Long context (32-128K)", { model ->
model.metadata.dimensions?.contextLength?.let { it > 32768 && it <= 131072 } == true
}),
XLARGE_CONTEXT("Extended context (>128K)", {
it.metadata.dimensions?.contextLength?.let { it > 131072 } == true
XLARGE_CONTEXT("Extended context (>128K)", { model ->
model.metadata.dimensions?.contextLength?.let { it > 131072 } == true
}),
// Quantization filters
INT2_QUANT("2-bit quantization", {
it.formattedQuantization.let { it.contains("Q2") || it.contains("IQ2") }
INT2_QUANT("2-bit quantization", { model ->
model.formattedQuantization.let { it.contains("Q2") || it.contains("IQ2") }
}),
INT3_QUANT("3-bit quantization", {
it.formattedQuantization.let { it.contains("Q3") || it.contains("IQ3") }
INT3_QUANT("3-bit quantization", { model ->
model.formattedQuantization.let { it.contains("Q3") || it.contains("IQ3") }
}),
INT4_QUANT("4-bit quantization", {
it.formattedQuantization.let { it.contains("Q4") || it.contains("IQ4") }
INT4_QUANT("4-bit quantization", { model ->
model.formattedQuantization.let { it.contains("Q4") || it.contains("IQ4") }
}),
// Special features
MULTILINGUAL("Multilingual", {
it.languages?.let { languages ->
MULTILINGUAL("Multilingual", { model ->
model.languages?.let { languages ->
languages.size > 1 || languages.any { it.contains("multi", ignoreCase = true) }
} == true
}),

View File

@ -16,6 +16,7 @@ import com.example.llama.data.source.remote.HuggingFaceDownloadInfo
import com.example.llama.data.source.remote.HuggingFaceModel
import com.example.llama.data.source.remote.HuggingFaceModelDetails
import com.example.llama.data.source.remote.HuggingFaceRemoteDataSource
import com.example.llama.monitoring.MemoryMetrics
import com.example.llama.monitoring.StorageMetrics
import com.example.llama.util.formatFileByteSize
import dagger.hilt.android.qualifiers.ApplicationContext
@ -83,7 +84,7 @@ interface ModelRepository {
/**
* Fetch details of preselected models
*/
suspend fun fetchPreselectedHuggingFaceModels(): List<HuggingFaceModelDetails>
suspend fun fetchPreselectedHuggingFaceModels(memoryUsage: MemoryMetrics): List<HuggingFaceModelDetails>
/**
* Search models on HuggingFace
@ -321,8 +322,8 @@ class ModelRepositoryImpl @Inject constructor(
}
}
override suspend fun fetchPreselectedHuggingFaceModels() = withContext(Dispatchers.IO) {
huggingFaceRemoteDataSource.fetchPreselectedModels()
override suspend fun fetchPreselectedHuggingFaceModels(memoryUsage: MemoryMetrics) = withContext(Dispatchers.IO) {
huggingFaceRemoteDataSource.fetchPreselectedModels(memoryUsage)
}
override suspend fun searchHuggingFaceModels(

View File

@ -1,27 +1,12 @@
package com.example.llama.data.source.remote
import android.app.DownloadManager
import android.content.Context
import android.os.Environment
import android.util.Log
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.async
import kotlinx.coroutines.awaitAll
import kotlinx.coroutines.delay
import kotlinx.coroutines.supervisorScope
import kotlinx.coroutines.sync.Semaphore
import kotlinx.coroutines.sync.withPermit
import kotlinx.coroutines.withContext
import retrofit2.HttpException
import java.io.FileNotFoundException
import java.io.IOException
import java.net.SocketTimeoutException
import java.net.UnknownHostException
import javax.inject.Inject
import javax.inject.Singleton
import kotlin.coroutines.cancellation.CancellationException
import kotlin.math.ceil
import com.example.llama.monitoring.MemoryMetrics
/*
* HuggingFace Search API
*/
private const val QUERY_Q4_0_GGUF = "gguf q4_0"
private const val FILTER_TEXT_GENERATION = "text-generation"
private const val SORT_BY_DOWNLOADS = "downloads"
@ -29,21 +14,10 @@ private const val SEARCH_RESULT_LIMIT = 30
private val INVALID_KEYWORDS = arrayOf("-of-", "split", "70B", "30B", "27B", "14B", "13B", "12B")
private val PRESELECTED_MODEL_IDS = listOf(
"unsloth/gemma-3-1b-it-GGUF",
"unsloth/gemma-3-4b-it-GGUF",
"bartowski/Llama-3.2-1B-Instruct-GGUF",
"bartowski/Llama-3.2-3B-Instruct-GGUF",
"Qwen/Qwen2.5-3B-Instruct-GGUF",
"gaianet/Phi-4-mini-instruct-GGUF",
"bartowski/granite-3.0-2b-instruct-GGUF",
"bartowski/Meta-Llama-3.1-8B-Instruct-GGUF",
)
interface HuggingFaceRemoteDataSource {
suspend fun fetchPreselectedModels(
ids: List<String> = PRESELECTED_MODEL_IDS,
memoryUsage: MemoryMetrics,
parallelCount: Int = 3,
quorum: Float = 0.5f,
): List<HuggingFaceModelDetails>
@ -58,6 +32,7 @@ interface HuggingFaceRemoteDataSource {
direction: String? = "-1",
limit: Int? = SEARCH_RESULT_LIMIT,
full: Boolean = true,
invalidKeywords: Array<String> = INVALID_KEYWORDS
): Result<List<HuggingFaceModel>>
suspend fun getModelDetails(modelId: String): HuggingFaceModelDetails
@ -76,208 +51,3 @@ interface HuggingFaceRemoteDataSource {
): Result<Long>
}
/**
 * [HuggingFaceRemoteDataSource] implementation backed by [HuggingFaceApiService] and the
 * Android [DownloadManager].
 *
 * Every operation switches to [Dispatchers.IO] before doing its work.
 */
@Singleton
class HuggingFaceRemoteDataSourceImpl @Inject constructor(
    private val apiService: HuggingFaceApiService
) : HuggingFaceRemoteDataSource {

    /**
     * Fetches the details of every model in [ids], running at most [parallelCount] requests
     * at a time. Results keep the order of [ids].
     *
     * Individual request failures are tolerated: as long as fewer than `ceil(ids.size * quorum)`
     * requests fail, the successfully fetched models are returned. Once that threshold is
     * reached, the collected failures are classified and a representative exception is thrown.
     *
     * @throws UnknownHostException when at least half of the failures are [UnknownHostException]
     * @throws SocketTimeoutException when at least half of the failures are [SocketTimeoutException]
     * @throws FileNotFoundException when nothing succeeded and either every request failed or at
     *   least half of the failures are HTTP 404/410/204
     * @throws IOException when at least half of the failures are other [IOException]s
     */
    override suspend fun fetchPreselectedModels(
        ids: List<String>,
        parallelCount: Int,
        quorum: Float,
    ): List<HuggingFaceModelDetails> = withContext(Dispatchers.IO) {
        val sem = Semaphore(parallelCount)
        val results = supervisorScope {
            ids.map { id ->
                async {
                    sem.withPermit {
                        try {
                            Result.success(getModelDetails(id))
                        } catch (e: CancellationException) {
                            // Cancellation must propagate; it is not a fetch failure.
                            throw e
                        } catch (e: Exception) {
                            // Record the failure so the quorum logic below can classify it
                            // instead of letting one bad request abort the whole batch.
                            Result.failure(e)
                        }
                    }
                }
            }.awaitAll()
        }

        val successes = results.mapNotNull { it.getOrNull() }
        val failures = results.mapNotNull { it.exceptionOrNull() }

        val total = ids.size
        val failed = failures.size
        val ok = successes.size

        val shouldThrow = failed >= ceil(total * quorum).toInt()
        // Without any failure there is nothing to report, even when the threshold is zero
        // (empty `ids` or a non-positive `quorum`).
        if (failures.isEmpty() || !shouldThrow) return@withContext successes

        // A failure category "wins" when it accounts for at least half of all failures.
        val majority = ceil(failed * 0.5).toInt()

        // 1. No Network
        if (failures.count { it is UnknownHostException } >= majority) {
            throw UnknownHostException()
        }

        // 2. Time out
        if (failures.count { it is SocketTimeoutException } >= majority) {
            throw SocketTimeoutException()
        }

        // 3. known error codes: 404/410/204
        val http404ish = failures.count { (it as? HttpException)?.code() in listOf(404, 410, 204) }
        if (ok == 0 && (http404ish >= majority || failed == total)) {
            throw FileNotFoundException()
        }

        // 4. Unknown issues
        val ioMajority = failures.count {
            it is IOException && it !is UnknownHostException && it !is SocketTimeoutException
        } >= majority
        if (ioMajority) {
            throw IOException(failures.first { it is IOException }.message)
        }

        // No dominant failure category: fall back to whatever was fetched.
        successes
    }

    /**
     * Searches HuggingFace and drops results that are gated, private, have no GGUF filename,
     * or whose GGUF filename contains any of `INVALID_KEYWORDS` (case-insensitive).
     *
     * @return the filtered models, a failure wrapping [FileNotFoundException] when nothing is
     *   left after filtering, or a failure wrapping the exception raised by the request
     */
    override suspend fun searchModels(
        query: String?,
        filter: String?,
        sort: String?,
        direction: String?,
        limit: Int?,
        full: Boolean,
    ) = withContext(Dispatchers.IO) {
        try {
            apiService.getModels(
                search = query,
                filter = filter,
                sort = sort,
                direction = direction,
                limit = limit,
                full = full,
            )
                .filterNot { it.gated || it.private }
                .filterNot {
                    it.getGgufFilename().let { filename ->
                        filename.isNullOrBlank() || INVALID_KEYWORDS.any {
                            filename.contains(it, ignoreCase = true)
                        }
                    }
                }.let {
                    if (it.isEmpty()) Result.failure(FileNotFoundException())
                    else Result.success(it)
                }
        } catch (e: Exception) {
            Log.e(TAG, "Error searching for models on HuggingFace: ${e.message}")
            Result.failure(e)
        }
    }

    /**
     * Fetches the details of a single model. Exceptions from the API call propagate to the caller.
     */
    override suspend fun getModelDetails(
        modelId: String
    ) = withContext(Dispatchers.IO) {
        apiService.getModelDetails(modelId)
    }

    /**
     * Reads the size of [filePath] in [modelId] from the `content-length` response header.
     *
     * @return the size in bytes, or a failure: [SecurityException] for HTTP 401,
     *   [FileNotFoundException] for HTTP 404, [IOException] for any other unsuccessful status
     *   or a missing/non-numeric header, or the exception raised by the request itself
     */
    override suspend fun getFileSize(
        modelId: String,
        filePath: String
    ): Result<Long> = withContext(Dispatchers.IO) {
        try {
            apiService.getModelFileHeader(modelId, filePath).let { resp ->
                if (resp.isSuccessful) {
                    resp.headers()[HTTP_HEADER_CONTENT_LENGTH]?.toLongOrNull()?.let {
                        Result.success(it)
                    } ?: Result.failure(IOException("Content-Length header missing"))
                } else {
                    Result.failure(
                        when (resp.code()) {
                            401 -> SecurityException("Model requires authentication")
                            404 -> FileNotFoundException("Model file not found")
                            else -> IOException("Failed to get file info: HTTP ${resp.code()}")
                        }
                    )
                }
            }
        } catch (e: Exception) {
            Log.e(TAG, "Error getting file size for $modelId: ${e.message}")
            Result.failure(e)
        }
    }

    /**
     * Enqueues a download of [downloadInfo] into the public Downloads directory through the
     * system [DownloadManager], then re-queries its status after a short delay.
     *
     * @return the download id when the download is not reported as failed (or its status cannot
     *   be determined), otherwise a failure whose message names the reported failure reason
     */
    override suspend fun downloadModelFile(
        context: Context,
        downloadInfo: HuggingFaceDownloadInfo,
    ): Result<Long> = withContext(Dispatchers.IO) {
        try {
            val downloadManager =
                context.getSystemService(Context.DOWNLOAD_SERVICE) as DownloadManager

            val request = DownloadManager.Request(downloadInfo.uri).apply {
                setTitle(downloadInfo.filename)
                setDescription("Downloading directly from HuggingFace")
                setNotificationVisibility(DownloadManager.Request.VISIBILITY_VISIBLE_NOTIFY_COMPLETED)
                setDestinationInExternalPublicDir(
                    Environment.DIRECTORY_DOWNLOADS,
                    downloadInfo.filename
                )
                setAllowedNetworkTypes(
                    DownloadManager.Request.NETWORK_WIFI or DownloadManager.Request.NETWORK_MOBILE
                )
                setAllowedOverMetered(true)
                setAllowedOverRoaming(false)
            }

            Log.d(TAG, "Enqueuing download request for: ${downloadInfo.modelId}")
            val downloadId = downloadManager.enqueue(request)

            // Give the DownloadManager a moment to report an immediate failure.
            delay(DOWNLOAD_MANAGER_DOUBLE_CHECK_DELAY)

            // `use` closes the cursor only after all columns have been read; closing it before
            // reading COLUMN_REASON would lose the failure reason.
            val failureMessage: String? = downloadManager
                .query(DownloadManager.Query().setFilterById(downloadId))
                ?.use { cursor ->
                    val statusIndex =
                        if (cursor.moveToFirst()) cursor.getColumnIndex(DownloadManager.COLUMN_STATUS) else -1
                    if (statusIndex < 0 || cursor.getInt(statusIndex) != DownloadManager.STATUS_FAILED) {
                        // Empty cursor, unknown status, or download pending/paused/running:
                        // treat as success.
                        null
                    } else {
                        // Get failure reason if available
                        val reasonIndex = cursor.getColumnIndex(DownloadManager.COLUMN_REASON)
                        val reason = if (reasonIndex >= 0) cursor.getInt(reasonIndex) else -1
                        when (reason) {
                            DownloadManager.ERROR_HTTP_DATA_ERROR -> "HTTP error"
                            DownloadManager.ERROR_INSUFFICIENT_SPACE -> "Insufficient storage"
                            DownloadManager.ERROR_TOO_MANY_REDIRECTS -> "Too many redirects"
                            DownloadManager.ERROR_UNHANDLED_HTTP_CODE -> "Unhandled HTTP code"
                            DownloadManager.ERROR_CANNOT_RESUME -> "Cannot resume download"
                            DownloadManager.ERROR_FILE_ERROR -> "File error"
                            else -> "Unknown error"
                        }
                    }
                }

            if (failureMessage != null) Result.failure(Exception(failureMessage))
            else Result.success(downloadId)
        } catch (e: Exception) {
            Log.e(TAG, "Failed to enqueue download: ${e.message}")
            Result.failure(e)
        }
    }

    companion object {
        private val TAG = HuggingFaceRemoteDataSourceImpl::class.java.simpleName

        private const val HTTP_HEADER_CONTENT_LENGTH = "content-length"
        private const val DOWNLOAD_MANAGER_DOUBLE_CHECK_DELAY = 500L
    }
}

View File

@ -0,0 +1,272 @@
package com.example.llama.data.source.remote
import android.app.DownloadManager
import android.content.Context
import android.os.Environment
import android.util.Log
import com.example.llama.monitoring.MemoryMetrics
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.async
import kotlinx.coroutines.awaitAll
import kotlinx.coroutines.delay
import kotlinx.coroutines.supervisorScope
import kotlinx.coroutines.sync.Semaphore
import kotlinx.coroutines.sync.withPermit
import kotlinx.coroutines.withContext
import retrofit2.HttpException
import java.io.FileNotFoundException
import java.io.IOException
import java.net.SocketTimeoutException
import java.net.UnknownHostException
import javax.inject.Inject
import javax.inject.Singleton
import kotlin.collections.contains
import kotlin.coroutines.cancellation.CancellationException
import kotlin.math.ceil
/*
 * Preselected models: sized <2GB
 *
 * Included by fetchPreselectedModels for every available-RAM tier.
 */
private val PRESELECTED_MODEL_IDS_SMALL = listOf(
"bartowski/Llama-3.2-1B-Instruct-GGUF",
"unsloth/gemma-3-1b-it-GGUF",
"bartowski/granite-3.0-2b-instruct-GGUF",
)
/*
 * Preselected models: sized 2~3GB
 *
 * Included by fetchPreselectedModels when available RAM is at least 3 GB.
 */
private val PRESELECTED_MODEL_IDS_MEDIUM = listOf(
"bartowski/Llama-3.2-3B-Instruct-GGUF",
"unsloth/gemma-3n-E2B-it-GGUF",
"Qwen/Qwen2.5-3B-Instruct-GGUF",
"gaianet/Phi-4-mini-instruct-GGUF",
"unsloth/gemma-3-4b-it-GGUF",
)
/*
 * Preselected models: sized 4~6GB
 * NOTE(review): this comment originally read "4~6B"; GB is assumed from the sibling
 * comments above (file size rather than parameter count) - confirm.
 *
 * Included by fetchPreselectedModels when available RAM is at least 5 GB.
 */
private val PRESELECTED_MODEL_IDS_LARGE = listOf(
"unsloth/gemma-3n-E4B-it-GGUF",
"bartowski/Meta-Llama-3.1-8B-Instruct-GGUF",
)
/**
 * [HuggingFaceRemoteDataSource] implementation backed by [HuggingFaceApiService] and the
 * Android [DownloadManager].
 *
 * Every operation switches to [Dispatchers.IO] before doing its work.
 */
@Singleton
class HuggingFaceRemoteDataSourceImpl @Inject constructor(
    private val apiService: HuggingFaceApiService
) : HuggingFaceRemoteDataSource {

    /**
     * Fetches the details of the preselected models, choosing and ordering them by the
     * available RAM reported in [memoryUsage]:
     * - `availableGB >= 7`: medium, then large, then small models
     * - `availableGB >= 5`: small, then medium, then large models
     * - `availableGB >= 3`: small, then medium models
     * - otherwise: small models only
     *
     * At most [parallelCount] requests run at a time, and results keep the order above.
     *
     * Individual request failures are tolerated: as long as fewer than `ceil(total * quorum)`
     * requests fail, the successfully fetched models are returned. Once that threshold is
     * reached, the collected failures are classified and a representative exception is thrown.
     *
     * @throws UnknownHostException when at least half of the failures are [UnknownHostException]
     * @throws SocketTimeoutException when at least half of the failures are [SocketTimeoutException]
     * @throws FileNotFoundException when nothing succeeded and either every request failed or at
     *   least half of the failures are HTTP 404/410/204
     * @throws IOException when at least half of the failures are other [IOException]s
     */
    override suspend fun fetchPreselectedModels(
        memoryUsage: MemoryMetrics,
        parallelCount: Int,
        quorum: Float,
    ): List<HuggingFaceModelDetails> = withContext(Dispatchers.IO) {
        val ids: List<String> = when {
            memoryUsage.availableGB >= 7f ->
                PRESELECTED_MODEL_IDS_MEDIUM + PRESELECTED_MODEL_IDS_LARGE + PRESELECTED_MODEL_IDS_SMALL

            memoryUsage.availableGB >= 5f ->
                PRESELECTED_MODEL_IDS_SMALL + PRESELECTED_MODEL_IDS_MEDIUM + PRESELECTED_MODEL_IDS_LARGE

            memoryUsage.availableGB >= 3f ->
                PRESELECTED_MODEL_IDS_SMALL + PRESELECTED_MODEL_IDS_MEDIUM

            else ->
                PRESELECTED_MODEL_IDS_SMALL
        }

        val sem = Semaphore(parallelCount)
        val results = supervisorScope {
            ids.map { id ->
                async {
                    sem.withPermit {
                        try {
                            Result.success(getModelDetails(id))
                        } catch (e: CancellationException) {
                            // Cancellation must propagate; it is not a fetch failure.
                            throw e
                        } catch (e: Exception) {
                            // Record the failure so the quorum logic below can classify it
                            // instead of letting one bad request abort the whole batch.
                            Result.failure(e)
                        }
                    }
                }
            }.awaitAll()
        }

        val successes = results.mapNotNull { it.getOrNull() }
        val failures = results.mapNotNull { it.exceptionOrNull() }

        val total = ids.size
        val failed = failures.size
        val ok = successes.size

        val shouldThrow = failed >= ceil(total * quorum).toInt()
        // Without any failure there is nothing to report, even when the threshold is zero
        // (a non-positive `quorum`).
        if (failures.isEmpty() || !shouldThrow) return@withContext successes

        // A failure category "wins" when it accounts for at least half of all failures.
        val majority = ceil(failed * 0.5).toInt()

        // 1. No Network
        if (failures.count { it is UnknownHostException } >= majority) {
            throw UnknownHostException()
        }

        // 2. Time out
        if (failures.count { it is SocketTimeoutException } >= majority) {
            throw SocketTimeoutException()
        }

        // 3. known error codes: 404/410/204
        val http404ish = failures.count { (it as? HttpException)?.code() in listOf(404, 410, 204) }
        if (ok == 0 && (http404ish >= majority || failed == total)) {
            throw FileNotFoundException()
        }

        // 4. Unknown issues
        val ioMajority = failures.count {
            it is IOException && it !is UnknownHostException && it !is SocketTimeoutException
        } >= majority
        if (ioMajority) {
            throw IOException(failures.first { it is IOException }.message)
        }

        // No dominant failure category: fall back to whatever was fetched.
        successes
    }

    /**
     * Searches HuggingFace and drops results that are gated, private, have no GGUF filename,
     * or whose GGUF filename contains any of [invalidKeywords] (case-insensitive).
     *
     * @return the filtered models, a failure wrapping [FileNotFoundException] when nothing is
     *   left after filtering, or a failure wrapping the exception raised by the request
     */
    override suspend fun searchModels(
        query: String?,
        filter: String?,
        sort: String?,
        direction: String?,
        limit: Int?,
        full: Boolean,
        invalidKeywords: Array<String>,
    ) = withContext(Dispatchers.IO) {
        try {
            apiService.getModels(
                search = query,
                filter = filter,
                sort = sort,
                direction = direction,
                limit = limit,
                full = full,
            )
                .filterNot { it.gated || it.private }
                .filterNot {
                    it.getGgufFilename().let { filename ->
                        filename.isNullOrBlank() || invalidKeywords.any {
                            filename.contains(it, ignoreCase = true)
                        }
                    }
                }.let {
                    if (it.isEmpty()) Result.failure(FileNotFoundException())
                    else Result.success(it)
                }
        } catch (e: Exception) {
            Log.e(TAG, "Error searching for models on HuggingFace: ${e.message}")
            Result.failure(e)
        }
    }

    /**
     * Fetches the details of a single model. Exceptions from the API call propagate to the caller.
     */
    override suspend fun getModelDetails(
        modelId: String
    ) = withContext(Dispatchers.IO) {
        apiService.getModelDetails(modelId)
    }

    /**
     * Reads the size of [filePath] in [modelId] from the `content-length` response header.
     *
     * @return the size in bytes, or a failure: [SecurityException] for HTTP 401,
     *   [FileNotFoundException] for HTTP 404, [IOException] for any other unsuccessful status
     *   or a missing/non-numeric header, or the exception raised by the request itself
     */
    override suspend fun getFileSize(
        modelId: String,
        filePath: String
    ): Result<Long> = withContext(Dispatchers.IO) {
        try {
            apiService.getModelFileHeader(modelId, filePath).let { resp ->
                if (resp.isSuccessful) {
                    resp.headers()[HTTP_HEADER_CONTENT_LENGTH]?.toLongOrNull()?.let {
                        Result.success(it)
                    } ?: Result.failure(IOException("Content-Length header missing"))
                } else {
                    Result.failure(
                        when (resp.code()) {
                            401 -> SecurityException("Model requires authentication")
                            404 -> FileNotFoundException("Model file not found")
                            else -> IOException("Failed to get file info: HTTP ${resp.code()}")
                        }
                    )
                }
            }
        } catch (e: Exception) {
            Log.e(TAG, "Error getting file size for $modelId: ${e.message}")
            Result.failure(e)
        }
    }

    /**
     * Enqueues a download of [downloadInfo] into the public Downloads directory through the
     * system [DownloadManager], then re-queries its status after a short delay.
     *
     * @return the download id when the download is not reported as failed (or its status cannot
     *   be determined), otherwise a failure whose message names the reported failure reason
     */
    override suspend fun downloadModelFile(
        context: Context,
        downloadInfo: HuggingFaceDownloadInfo,
    ): Result<Long> = withContext(Dispatchers.IO) {
        try {
            val downloadManager =
                context.getSystemService(Context.DOWNLOAD_SERVICE) as DownloadManager

            val request = DownloadManager.Request(downloadInfo.uri).apply {
                setTitle(downloadInfo.filename)
                setDescription("Downloading directly from HuggingFace")
                setNotificationVisibility(DownloadManager.Request.VISIBILITY_VISIBLE_NOTIFY_COMPLETED)
                setDestinationInExternalPublicDir(
                    Environment.DIRECTORY_DOWNLOADS,
                    downloadInfo.filename
                )
                setAllowedNetworkTypes(
                    DownloadManager.Request.NETWORK_WIFI or DownloadManager.Request.NETWORK_MOBILE
                )
                setAllowedOverMetered(true)
                setAllowedOverRoaming(false)
            }

            Log.d(TAG, "Enqueuing download request for: ${downloadInfo.modelId}")
            val downloadId = downloadManager.enqueue(request)

            // Give the DownloadManager a moment to report an immediate failure.
            delay(DOWNLOAD_MANAGER_DOUBLE_CHECK_DELAY)

            // `use` closes the cursor only after all columns have been read; closing it before
            // reading COLUMN_REASON would lose the failure reason.
            val failureMessage: String? = downloadManager
                .query(DownloadManager.Query().setFilterById(downloadId))
                ?.use { cursor ->
                    val statusIndex =
                        if (cursor.moveToFirst()) cursor.getColumnIndex(DownloadManager.COLUMN_STATUS) else -1
                    if (statusIndex < 0 || cursor.getInt(statusIndex) != DownloadManager.STATUS_FAILED) {
                        // Empty cursor, unknown status, or download pending/paused/running:
                        // treat as success.
                        null
                    } else {
                        // Get failure reason if available
                        val reasonIndex = cursor.getColumnIndex(DownloadManager.COLUMN_REASON)
                        val reason = if (reasonIndex >= 0) cursor.getInt(reasonIndex) else -1
                        when (reason) {
                            DownloadManager.ERROR_HTTP_DATA_ERROR -> "HTTP error"
                            DownloadManager.ERROR_INSUFFICIENT_SPACE -> "Insufficient storage"
                            DownloadManager.ERROR_TOO_MANY_REDIRECTS -> "Too many redirects"
                            DownloadManager.ERROR_UNHANDLED_HTTP_CODE -> "Unhandled HTTP code"
                            DownloadManager.ERROR_CANNOT_RESUME -> "Cannot resume download"
                            DownloadManager.ERROR_FILE_ERROR -> "File error"
                            else -> "Unknown error"
                        }
                    }
                }

            if (failureMessage != null) Result.failure(Exception(failureMessage))
            else Result.success(downloadId)
        } catch (e: Exception) {
            Log.e(TAG, "Failed to enqueue download: ${e.message}")
            Result.failure(e)
        }
    }

    companion object {
        private val TAG = HuggingFaceRemoteDataSourceImpl::class.java.simpleName

        private const val HTTP_HEADER_CONTENT_LENGTH = "content-length"
        private const val DOWNLOAD_MANAGER_DOUBLE_CHECK_DELAY = 500L
    }
}

View File

@ -17,6 +17,7 @@ import com.example.llama.data.repo.ModelRepository
import com.example.llama.data.source.remote.HuggingFaceDownloadInfo
import com.example.llama.data.source.remote.HuggingFaceModel
import com.example.llama.data.source.remote.HuggingFaceModelDetails
import com.example.llama.monitoring.MemoryMetrics
import com.example.llama.util.formatFileByteSize
import com.example.llama.util.getFileNameFromUri
import com.example.llama.util.getFileSizeFromUri
@ -183,11 +184,12 @@ class ModelsManagementViewModel @Inject constructor(
/**
* Query models on HuggingFace available for download even without signing in
*/
fun queryModelsFromHuggingFace(cap: Int = FETCH_HUGGINGFACE_MODELS_CAP_SIZE) {
fun queryModelsFromHuggingFace(memoryUsage: MemoryMetrics) {
huggingFaceQueryJob = viewModelScope.launch {
_managementState.emit(Download.Querying)
try {
val models = modelRepository.fetchPreselectedHuggingFaceModels().map(HuggingFaceModelDetails::toModel)
val models = modelRepository.fetchPreselectedHuggingFaceModels(memoryUsage)
.map(HuggingFaceModelDetails::toModel)
_managementState.emit(Download.Ready(models))
} catch (_: CancellationException) {
resetManagementState()
@ -300,7 +302,6 @@ class ModelsManagementViewModel @Inject constructor(
private val TAG = ModelsManagementViewModel::class.java.simpleName
private const val FETCH_HUGGINGFACE_MODELS_LIMIT_SIZE = 50
private const val FETCH_HUGGINGFACE_MODELS_CAP_SIZE = 12
private const val DELETE_SUCCESS_RESET_TIMEOUT_MS = 1000L
}
}