UI: address Rojin's UX feedbacks - part 2
This commit is contained in:
parent
e067f7051b
commit
8268d70518
|
|
@ -80,6 +80,11 @@ interface ModelRepository {
|
|||
suspend fun deleteModel(modelId: String)
|
||||
suspend fun deleteModels(modelIds: List<String>)
|
||||
|
||||
/**
|
||||
* Fetch details of preselected models
|
||||
*/
|
||||
suspend fun fetchPreselectedHuggingFaceModels(): List<HuggingFaceModelDetails>
|
||||
|
||||
/**
|
||||
* Search models on HuggingFace
|
||||
*/
|
||||
|
|
@ -316,6 +321,10 @@ class ModelRepositoryImpl @Inject constructor(
|
|||
}
|
||||
}
|
||||
|
||||
/** Loads the details of the preselected HuggingFace models from the remote data source on [Dispatchers.IO]. */
override suspend fun fetchPreselectedHuggingFaceModels(): List<HuggingFaceModelDetails> =
    withContext(Dispatchers.IO) { huggingFaceRemoteDataSource.fetchPreselectedModels() }
|
||||
|
||||
override suspend fun searchHuggingFaceModels(
|
||||
limit: Int
|
||||
) = withContext(Dispatchers.IO) {
|
||||
|
|
|
|||
|
|
@ -5,8 +5,8 @@ import androidx.core.net.toUri
|
|||
import com.example.llama.di.HUGGINGFACE_HOST
|
||||
import java.util.Date
|
||||
|
||||
private const val FILE_EXTENSION_GGUF = ".GGUF"
|
||||
private val QUANTIZATION_Q4_0 = arrayOf("Q4_0", "Q4-0")
|
||||
internal const val FILE_EXTENSION_GGUF = ".GGUF"
|
||||
internal val QUANTIZATION_Q4_0 = arrayOf("Q4_0", "Q4-0")
|
||||
|
||||
data class HuggingFaceModel(
|
||||
val _id: String,
|
||||
|
|
@ -31,10 +31,6 @@ data class HuggingFaceModel(
|
|||
|
||||
val library_name: String?,
|
||||
) {
|
||||
data class Sibling(
|
||||
val rfilename: String,
|
||||
)
|
||||
|
||||
fun getGgufFilename(keywords: Array<String> = QUANTIZATION_Q4_0): String? =
|
||||
siblings.map { it.rfilename }
|
||||
.filter {
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
package com.example.llama.data.source.remote
|
||||
|
||||
import java.util.Date
|
||||
import kotlin.String
|
||||
|
||||
data class HuggingFaceModelDetails(
|
||||
val _id: String,
|
||||
|
|
@ -8,51 +9,82 @@ data class HuggingFaceModelDetails(
|
|||
val modelId: String,
|
||||
|
||||
val author: String,
|
||||
val createdAt: Date?,
|
||||
val lastModified: Date?,
|
||||
val createdAt: Date,
|
||||
val lastModified: Date,
|
||||
|
||||
val library_name: String?,
|
||||
val pipeline_tag: String?,
|
||||
val tags: List<String>?,
|
||||
val pipeline_tag: String,
|
||||
val tags: List<String>,
|
||||
|
||||
val private: Boolean?,
|
||||
val private: Boolean,
|
||||
val gated: Boolean,
|
||||
val disabled: Boolean?,
|
||||
val gated: Boolean?,
|
||||
|
||||
val likes: Int?,
|
||||
val downloads: Int?,
|
||||
val likes: Int,
|
||||
val downloads: Int,
|
||||
|
||||
val usedStorage: Long?,
|
||||
val sha: String?,
|
||||
|
||||
val sha: String,
|
||||
val siblings: List<Sibling>,
|
||||
val cardData: CardData?,
|
||||
val siblings: List<Sibling>?,
|
||||
val widgetData: List<WidgetData>?,
|
||||
|
||||
val gguf: Gguf?,
|
||||
|
||||
val library_name: String?,
|
||||
) {
|
||||
data class Sibling(
|
||||
val rfilename: String,
|
||||
/**
 * Picks the first sibling file that has the GGUF extension and whose name mentions one of [keywords].
 * Both checks ignore case.
 *
 * @param keywords substrings to look for in the filename; defaults to [QUANTIZATION_Q4_0].
 * @return the matching filename, or `null` when no sibling qualifies.
 */
fun getGgufFilename(keywords: Array<String> = QUANTIZATION_Q4_0): String? =
    siblings
        .firstOrNull { sibling ->
            val name = sibling.rfilename
            name.endsWith(FILE_EXTENSION_GGUF, ignoreCase = true) &&
                keywords.any { keyword -> name.contains(keyword, ignoreCase = true) }
        }
        ?.rfilename
|
||||
|
||||
/** Converts these details into a [HuggingFaceModel], copying the shared fields one-to-one. */
fun toModel(): HuggingFaceModel {
    // Each sibling is re-created from its filename rather than reused.
    val copiedSiblings = siblings.map { sibling -> Sibling(sibling.rfilename) }

    return HuggingFaceModel(
        _id = _id,
        id = id,
        modelId = modelId,
        author = author,
        createdAt = createdAt,
        lastModified = lastModified,
        pipeline_tag = pipeline_tag,
        tags = tags,
        // `private` is also a modifier keyword, so the receiver is kept explicit here.
        private = this.private,
        gated = gated,
        likes = likes,
        downloads = downloads,
        sha = sha,
        siblings = copiedSiblings,
        library_name = library_name,
    )
}
|
||||
|
||||
data class Gguf(
|
||||
val total: Long?,
|
||||
val architecture: String?,
|
||||
val context_length: Int?,
|
||||
val chat_template: String?,
|
||||
val bos_token: String?,
|
||||
val eos_token: String?,
|
||||
)
|
||||
|
||||
data class CardData(
|
||||
val base_model: String?,
|
||||
val language: List<String>?,
|
||||
val license: String?,
|
||||
val pipeline_tag: String?,
|
||||
val tags: List<String>?,
|
||||
)
|
||||
|
||||
data class WidgetData(
|
||||
val text: String
|
||||
)
|
||||
/**
 * Builds the download descriptor for the file selected by [getGgufFilename] (called with its defaults).
 *
 * @return the descriptor made from `_id`, `modelId` and the selected filename, or `null` when
 * no filename was selected.
 */
fun toDownloadInfo() = when (val filename = getGgufFilename()) {
    null -> null
    else -> HuggingFaceDownloadInfo(_id, modelId, filename)
}
|
||||
}
|
||||
|
||||
/**
 * One entry of a model's `siblings` list.
 *
 * [rfilename] is the name that `getGgufFilename` checks for the GGUF extension and the
 * quantization keywords.
 * NOTE(review): presumably the file's path inside the HuggingFace repository — confirm against the API.
 */
data class Sibling(
|
||||
val rfilename: String,
|
||||
)
|
||||
|
||||
/**
 * Type of the optional `gguf` field of the model details; every property is nullable.
 *
 * NOTE(review): presumably GGUF file metadata as reported by the HuggingFace API — confirm the
 * meaning and units of each field there.
 */
data class Gguf(
|
||||
val total: Long?,
|
||||
val architecture: String?,
|
||||
val context_length: Int?,
|
||||
val chat_template: String?,
|
||||
val bos_token: String?,
|
||||
val eos_token: String?,
|
||||
)
|
||||
|
||||
/**
 * Type of the optional `cardData` field of the model details; every property is nullable.
 *
 * NOTE(review): presumably mirrors the model card metadata returned by the HuggingFace API — confirm.
 */
data class CardData(
|
||||
val base_model: String?,
|
||||
val language: List<String>?,
|
||||
val license: String?,
|
||||
val pipeline_tag: String?,
|
||||
val tags: List<String>?,
|
||||
)
|
||||
|
||||
/**
 * Element type of the optional `widgetData` list of the model details.
 *
 * NOTE(review): presumably a sample prompt shown by the HuggingFace model widget — confirm.
 */
data class WidgetData(
|
||||
val text: String
|
||||
)
|
||||
|
|
|
|||
|
|
@ -5,12 +5,22 @@ import android.content.Context
|
|||
import android.os.Environment
|
||||
import android.util.Log
|
||||
import kotlinx.coroutines.Dispatchers
|
||||
import kotlinx.coroutines.async
|
||||
import kotlinx.coroutines.awaitAll
|
||||
import kotlinx.coroutines.delay
|
||||
import kotlinx.coroutines.supervisorScope
|
||||
import kotlinx.coroutines.sync.Semaphore
|
||||
import kotlinx.coroutines.sync.withPermit
|
||||
import kotlinx.coroutines.withContext
|
||||
import retrofit2.HttpException
|
||||
import java.io.FileNotFoundException
|
||||
import java.io.IOException
|
||||
import java.net.SocketTimeoutException
|
||||
import java.net.UnknownHostException
|
||||
import javax.inject.Inject
|
||||
import javax.inject.Singleton
|
||||
import kotlin.coroutines.cancellation.CancellationException
|
||||
import kotlin.math.ceil
|
||||
|
||||
private const val QUERY_Q4_0_GGUF = "gguf q4_0"
|
||||
private const val FILTER_TEXT_GENERATION = "text-generation"
|
||||
|
|
@ -19,7 +29,25 @@ private const val SEARCH_RESULT_LIMIT = 30
|
|||
|
||||
private val INVALID_KEYWORDS = arrayOf("-of-", "split", "70B", "30B", "27B", "14B", "13B", "12B")
|
||||
|
||||
// Model ids requested by default by fetchPreselectedModels (default value of its `ids` parameter).
private val PRESELECTED_MODEL_IDS = listOf(
|
||||
"unsloth/gemma-3-1b-it-GGUF",
|
||||
"unsloth/gemma-3-4b-it-GGUF",
|
||||
"bartowski/Llama-3.2-1B-Instruct-GGUF",
|
||||
"bartowski/Llama-3.2-3B-Instruct-GGUF",
|
||||
"Qwen/Qwen2.5-3B-Instruct-GGUF",
|
||||
"gaianet/Phi-4-mini-instruct-GGUF",
|
||||
"bartowski/granite-3.0-2b-instruct-GGUF",
|
||||
"bartowski/Meta-Llama-3.1-8B-Instruct-GGUF",
|
||||
)
|
||||
|
||||
interface HuggingFaceRemoteDataSource {
|
||||
|
||||
/**
 * Fetch the details of a fixed set of models, tolerating a limited number of failed requests.
 *
 * Contract as implemented by HuggingFaceRemoteDataSourceImpl:
 *
 * @param ids model ids to fetch; defaults to [PRESELECTED_MODEL_IDS].
 * @param parallelCount number of semaphore permits limiting how many requests run at once.
 * @param quorum fraction of [ids]; once the number of failed requests reaches
 * `ceil(ids.size * quorum)` the call may throw instead of returning the partial result.
 * @return the details of the models that were fetched successfully.
 * @throws java.net.UnknownHostException when the failures are mostly unknown-host errors.
 * @throws java.net.SocketTimeoutException when the failures are mostly timeouts.
 * @throws java.io.FileNotFoundException when nothing could be fetched.
 * @throws java.io.IOException when the failures are mostly other I/O errors.
 */
suspend fun fetchPreselectedModels(
|
||||
ids: List<String> = PRESELECTED_MODEL_IDS,
|
||||
parallelCount: Int = 3,
|
||||
quorum: Float = 0.5f,
|
||||
): List<HuggingFaceModelDetails>
|
||||
|
||||
/**
|
||||
* Query openly available Q4_0 GGUF models on HuggingFace
|
||||
*/
|
||||
|
|
@ -53,6 +81,64 @@ class HuggingFaceRemoteDataSourceImpl @Inject constructor(
|
|||
private val apiService: HuggingFaceApiService
|
||||
) : HuggingFaceRemoteDataSource {
|
||||
|
||||
/**
 * Fetches the details of every model in [ids] concurrently, tolerating a limited number of failures.
 *
 * Each id is requested through `getModelDetails`, with at most [parallelCount] requests in flight
 * at once. While fewer than `ceil(ids.size * quorum)` requests fail, the models that did load are
 * returned. Once that threshold is reached the failures are classified and one representative
 * exception is thrown; if no category dominates, the partial result is returned after all.
 *
 * A call in which no request failed always returns normally, even when the threshold is 0
 * (empty [ids] or a non-positive [quorum]).
 *
 * @param ids model ids to fetch.
 * @param parallelCount number of semaphore permits, i.e. the cap on concurrent requests.
 * @param quorum fraction of [ids] whose failure makes the whole call eligible to fail.
 * @return the successfully fetched details, in the same order as [ids].
 * @throws UnknownHostException when at least half of the failures are unknown-host errors ("no network").
 * @throws SocketTimeoutException when at least half of the failures are timeouts.
 * @throws FileNotFoundException when not a single model could be fetched and neither case above applies.
 * @throws IOException when at least half of the failures are other I/O errors; the first such
 * failure is attached as the cause.
 */
override suspend fun fetchPreselectedModels(
    ids: List<String>,
    parallelCount: Int,
    quorum: Float,
): List<HuggingFaceModelDetails> = withContext(Dispatchers.IO) {
    val sem = Semaphore(parallelCount)

    // supervisorScope keeps one failed request from cancelling its siblings. Every request
    // reports its own Result, so no shared mutable state or locking is needed and the
    // outcome order follows `ids` instead of completion order.
    val results = supervisorScope {
        ids.map { id ->
            async {
                sem.withPermit {
                    runCatching { getModelDetails(id) }
                        // Cancellation must propagate; it is never recorded as a fetch failure.
                        .onFailure { t -> if (t is CancellationException) throw t }
                }
            }
        }.awaitAll()
    }

    val successes = results.mapNotNull { it.getOrNull() }
    val failures = results.mapNotNull { it.exceptionOrNull() }

    val total = ids.size
    val failed = failures.size

    // Without any failure there is nothing to report. This guard matters when the threshold
    // is 0 (empty `ids`, quorum <= 0): otherwise `0 >= 0` would raise a spurious exception.
    if (failed == 0 || failed < ceil(total * quorum).toInt()) return@withContext successes

    // failed > 0 from here on, so `majority` is at least 1.
    val majority = ceil(failed * 0.5).toInt()

    // 1. No network
    if (failures.count { it is UnknownHostException } >= majority) {
        throw UnknownHostException()
    }

    // 2. Time out
    if (failures.count { it is SocketTimeoutException } >= majority) {
        throw SocketTimeoutException()
    }

    // 3. Known error codes: 404/410/204, or nothing could be fetched at all
    val http404ish = failures.count { (it as? HttpException)?.code() in listOf(404, 410, 204) }
    if (successes.isEmpty() && (http404ish >= majority || failed == total)) {
        throw FileNotFoundException()
    }

    // 4. Other I/O issues
    val otherIoFailures = failures.filter {
        it is IOException && it !is UnknownHostException && it !is SocketTimeoutException
    }
    if (otherIoFailures.size >= majority) {
        val first = otherIoFailures.first()
        // Keep the original failure as the cause so the stack trace is not lost.
        throw IOException(first.message, first)
    }

    // No dominant failure category: fall back to whatever was fetched.
    successes
}
|
||||
|
||||
override suspend fun searchModels(
|
||||
query: String?,
|
||||
filter: String?,
|
||||
|
|
|
|||
|
|
@ -33,7 +33,6 @@ import dagger.hilt.InstallIn
|
|||
import dagger.hilt.android.qualifiers.ApplicationContext
|
||||
import dagger.hilt.components.SingletonComponent
|
||||
import okhttp3.OkHttpClient
|
||||
import okhttp3.logging.HttpLoggingInterceptor
|
||||
import retrofit2.Retrofit
|
||||
import retrofit2.converter.gson.GsonConverterFactory
|
||||
import javax.inject.Singleton
|
||||
|
|
@ -109,10 +108,7 @@ internal abstract class AppModule {
|
|||
|
||||
@Provides
|
||||
@Singleton
|
||||
fun provideOkhttpClient() = OkHttpClient.Builder()
|
||||
.addInterceptor(HttpLoggingInterceptor().apply {
|
||||
level = HttpLoggingInterceptor.Level.BODY
|
||||
}).build()
|
||||
fun provideOkhttpClient() = OkHttpClient.Builder().build()
|
||||
|
||||
@Provides
|
||||
@Singleton
|
||||
|
|
|
|||
|
|
@ -245,16 +245,10 @@ fun ModelsManagementAndDeletingScreen(
|
|||
}
|
||||
|
||||
is Download.Dispatched -> {
|
||||
LaunchedEffect(state) {
|
||||
onScaffoldEvent(
|
||||
ScaffoldEvent.ShowSnackbar(
|
||||
message = "Started downloading:\n${state.downloadInfo.modelId}",
|
||||
duration = SnackbarDuration.Long,
|
||||
)
|
||||
)
|
||||
|
||||
managementViewModel.resetManagementState()
|
||||
}
|
||||
DownloadHuggingFaceDispatchedDialog(
|
||||
state.downloadInfo.modelId,
|
||||
onConfirm = { managementViewModel.resetManagementState() }
|
||||
)
|
||||
}
|
||||
|
||||
is Download.Completed -> {
|
||||
|
|
@ -627,6 +621,33 @@ fun HuggingFaceModelListItem(
|
|||
}
|
||||
}
|
||||
|
||||
/**
 * Dialog informing the user that the download of [modelId] has started.
 *
 * Dismissal by back press or by clicking outside is disabled and the dismiss request is a
 * no-op, so the only interaction offered is the "Okay" button, which invokes [onConfirm].
 */
@Composable
private fun DownloadHuggingFaceDispatchedDialog(
    modelId: String,
    onConfirm: () -> Unit,
) {
    val message = "Your Android system download manager has started downloading the model: $modelId.\n\n" +
        "You can track its progress in your notification drawer.\n" +
        "Feel free to stay on this screen, or come back to import it after complete."
    val nonDismissible = DialogProperties(
        dismissOnBackPress = false,
        dismissOnClickOutside = false,
    )

    AlertDialog(
        onDismissRequest = {},
        confirmButton = {
            Button(onClick = onConfirm) { Text("Okay") }
        },
        title = {},
        text = {
            InfoView(
                title = "Download has started",
                icon = Icons.Default.Download,
                message = message,
            )
        },
        properties = nonDismissible,
    )
}
|
||||
|
||||
@Composable
|
||||
private fun FirstModelImportSuccessDialog(
|
||||
onConfirm: () -> Unit,
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ import com.example.llama.data.repo.InsufficientStorageException
|
|||
import com.example.llama.data.repo.ModelRepository
|
||||
import com.example.llama.data.source.remote.HuggingFaceDownloadInfo
|
||||
import com.example.llama.data.source.remote.HuggingFaceModel
|
||||
import com.example.llama.data.source.remote.HuggingFaceModelDetails
|
||||
import com.example.llama.util.formatFileByteSize
|
||||
import com.example.llama.util.getFileNameFromUri
|
||||
import com.example.llama.util.getFileSizeFromUri
|
||||
|
|
@ -186,13 +187,8 @@ class ModelsManagementViewModel @Inject constructor(
|
|||
huggingFaceQueryJob = viewModelScope.launch {
|
||||
_managementState.emit(Download.Querying)
|
||||
try {
|
||||
modelRepository.searchHuggingFaceModels(FETCH_HUGGINGFACE_MODELS_LIMIT_SIZE).fold(
|
||||
onSuccess = { models ->
|
||||
Log.d(TAG, "Fetched ${models.size} models from HuggingFace, capped by $cap:\n")
|
||||
_managementState.emit(Download.Ready(models.take(cap)))
|
||||
},
|
||||
onFailure = { throw it }
|
||||
)
|
||||
val models = modelRepository.fetchPreselectedHuggingFaceModels().map(HuggingFaceModelDetails::toModel)
|
||||
_managementState.emit(Download.Ready(models))
|
||||
} catch (_: CancellationException) {
|
||||
// no-op
|
||||
} catch (_: UnknownHostException) {
|
||||
|
|
|
|||
Loading…
Reference in New Issue