UI: implement a dialog UI to show fetched HuggingFace models.
This commit is contained in:
parent
310771f6aa
commit
f085d39c05
|
|
@ -351,7 +351,7 @@ fun AppContent(
|
||||||
// Create file launcher for importing local models
|
// Create file launcher for importing local models
|
||||||
val fileLauncher = rememberLauncherForActivityResult(
|
val fileLauncher = rememberLauncherForActivityResult(
|
||||||
contract = ActivityResultContracts.OpenDocument()
|
contract = ActivityResultContracts.OpenDocument()
|
||||||
) { uri -> uri?.let { modelsManagementViewModel.localModelFileSelected(it) } }
|
) { uri -> uri?.let { modelsManagementViewModel.importLocalModelFileSelected(it) } }
|
||||||
|
|
||||||
val bottomBarConfig = BottomBarConfig.ModelsManagement(
|
val bottomBarConfig = BottomBarConfig.ModelsManagement(
|
||||||
sorting = BottomBarConfig.ModelsManagement.SortingConfig(
|
sorting = BottomBarConfig.ModelsManagement.SortingConfig(
|
||||||
|
|
|
||||||
|
|
@ -35,7 +35,7 @@ data class HuggingFaceModel(
|
||||||
)
|
)
|
||||||
|
|
||||||
fun getGgufFilename(): String? =
|
fun getGgufFilename(): String? =
|
||||||
siblings.map { it.rfilename }.first { it.endsWith(FILE_EXTENSION_GGUF) }
|
siblings.map { it.rfilename }.firstOrNull { it.endsWith(FILE_EXTENSION_GGUF) }
|
||||||
|
|
||||||
fun toDownloadInfo() = getGgufFilename()?.let { HuggingFaceDownloadInfo(_id, modelId, it) }
|
fun toDownloadInfo() = getGgufFilename()?.let { HuggingFaceDownloadInfo(_id, modelId, it) }
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -64,7 +64,7 @@ class HuggingFaceRemoteDataSourceImpl @Inject constructor(
|
||||||
direction = direction,
|
direction = direction,
|
||||||
limit = limit,
|
limit = limit,
|
||||||
full = full,
|
full = full,
|
||||||
).filter { it.gated != true && it.private != true }
|
).filter { it.gated != true && it.private != true && it.getGgufFilename() != null }
|
||||||
}
|
}
|
||||||
|
|
||||||
override suspend fun getModelDetails(
|
override suspend fun getModelDetails(
|
||||||
|
|
|
||||||
|
|
@ -1,9 +1,11 @@
|
||||||
package com.example.llama.ui.screens
|
package com.example.llama.ui.screens
|
||||||
|
|
||||||
import androidx.activity.compose.BackHandler
|
import androidx.activity.compose.BackHandler
|
||||||
|
import androidx.compose.foundation.basicMarquee
|
||||||
import androidx.compose.foundation.layout.Arrangement
|
import androidx.compose.foundation.layout.Arrangement
|
||||||
import androidx.compose.foundation.layout.Box
|
import androidx.compose.foundation.layout.Box
|
||||||
import androidx.compose.foundation.layout.Column
|
import androidx.compose.foundation.layout.Column
|
||||||
|
import androidx.compose.foundation.layout.PaddingValues
|
||||||
import androidx.compose.foundation.layout.Row
|
import androidx.compose.foundation.layout.Row
|
||||||
import androidx.compose.foundation.layout.Spacer
|
import androidx.compose.foundation.layout.Spacer
|
||||||
import androidx.compose.foundation.layout.fillMaxSize
|
import androidx.compose.foundation.layout.fillMaxSize
|
||||||
|
|
@ -15,10 +17,18 @@ import androidx.compose.foundation.layout.width
|
||||||
import androidx.compose.foundation.lazy.LazyColumn
|
import androidx.compose.foundation.lazy.LazyColumn
|
||||||
import androidx.compose.foundation.lazy.items
|
import androidx.compose.foundation.lazy.items
|
||||||
import androidx.compose.material.icons.Icons
|
import androidx.compose.material.icons.Icons
|
||||||
|
import androidx.compose.material.icons.filled.Attribution
|
||||||
|
import androidx.compose.material.icons.filled.Download
|
||||||
|
import androidx.compose.material.icons.filled.Favorite
|
||||||
import androidx.compose.material.icons.filled.FolderOpen
|
import androidx.compose.material.icons.filled.FolderOpen
|
||||||
|
import androidx.compose.material.icons.filled.Today
|
||||||
import androidx.compose.material3.AlertDialog
|
import androidx.compose.material3.AlertDialog
|
||||||
import androidx.compose.material3.Button
|
import androidx.compose.material3.Button
|
||||||
|
import androidx.compose.material3.Card
|
||||||
|
import androidx.compose.material3.CardDefaults
|
||||||
|
import androidx.compose.material3.Checkbox
|
||||||
import androidx.compose.material3.CircularProgressIndicator
|
import androidx.compose.material3.CircularProgressIndicator
|
||||||
|
import androidx.compose.material3.Icon
|
||||||
import androidx.compose.material3.LinearProgressIndicator
|
import androidx.compose.material3.LinearProgressIndicator
|
||||||
import androidx.compose.material3.MaterialTheme
|
import androidx.compose.material3.MaterialTheme
|
||||||
import androidx.compose.material3.SnackbarDuration
|
import androidx.compose.material3.SnackbarDuration
|
||||||
|
|
@ -30,23 +40,31 @@ import androidx.compose.runtime.collectAsState
|
||||||
import androidx.compose.runtime.derivedStateOf
|
import androidx.compose.runtime.derivedStateOf
|
||||||
import androidx.compose.runtime.getValue
|
import androidx.compose.runtime.getValue
|
||||||
import androidx.compose.runtime.mutableStateMapOf
|
import androidx.compose.runtime.mutableStateMapOf
|
||||||
|
import androidx.compose.runtime.mutableStateOf
|
||||||
import androidx.compose.runtime.remember
|
import androidx.compose.runtime.remember
|
||||||
|
import androidx.compose.runtime.setValue
|
||||||
import androidx.compose.ui.Alignment
|
import androidx.compose.ui.Alignment
|
||||||
import androidx.compose.ui.Modifier
|
import androidx.compose.ui.Modifier
|
||||||
import androidx.compose.ui.text.font.FontStyle
|
import androidx.compose.ui.text.font.FontStyle
|
||||||
|
import androidx.compose.ui.text.font.FontWeight
|
||||||
import androidx.compose.ui.text.style.TextAlign
|
import androidx.compose.ui.text.style.TextAlign
|
||||||
import androidx.compose.ui.text.style.TextOverflow
|
import androidx.compose.ui.text.style.TextOverflow
|
||||||
import androidx.compose.ui.unit.dp
|
import androidx.compose.ui.unit.dp
|
||||||
import androidx.compose.ui.window.DialogProperties
|
import androidx.compose.ui.window.DialogProperties
|
||||||
import com.example.llama.data.model.ModelInfo
|
import com.example.llama.data.model.ModelInfo
|
||||||
|
import com.example.llama.data.remote.HuggingFaceModel
|
||||||
import com.example.llama.ui.components.InfoView
|
import com.example.llama.ui.components.InfoView
|
||||||
import com.example.llama.ui.components.ModelCardFullExpandable
|
import com.example.llama.ui.components.ModelCardFullExpandable
|
||||||
import com.example.llama.ui.scaffold.ScaffoldEvent
|
import com.example.llama.ui.scaffold.ScaffoldEvent
|
||||||
|
import com.example.llama.util.formatContextLength
|
||||||
import com.example.llama.util.formatFileByteSize
|
import com.example.llama.util.formatFileByteSize
|
||||||
import com.example.llama.viewmodel.ModelManagementState
|
import com.example.llama.viewmodel.ModelManagementState
|
||||||
import com.example.llama.viewmodel.ModelManagementState.Deletion
|
import com.example.llama.viewmodel.ModelManagementState.Deletion
|
||||||
|
import com.example.llama.viewmodel.ModelManagementState.Download
|
||||||
import com.example.llama.viewmodel.ModelManagementState.Importation
|
import com.example.llama.viewmodel.ModelManagementState.Importation
|
||||||
import com.example.llama.viewmodel.ModelsManagementViewModel
|
import com.example.llama.viewmodel.ModelsManagementViewModel
|
||||||
|
import java.text.SimpleDateFormat
|
||||||
|
import java.util.Locale
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Screen for managing LLM models (view, download, delete)
|
* Screen for managing LLM models (view, download, delete)
|
||||||
|
|
@ -103,7 +121,9 @@ fun ModelsManagementScreen(
|
||||||
} else {
|
} else {
|
||||||
// Model cards
|
// Model cards
|
||||||
LazyColumn(
|
LazyColumn(
|
||||||
modifier = Modifier.fillMaxSize().padding(horizontal = 16.dp),
|
modifier = Modifier
|
||||||
|
.fillMaxSize()
|
||||||
|
.padding(horizontal = 16.dp),
|
||||||
verticalArrangement = Arrangement.spacedBy(12.dp),
|
verticalArrangement = Arrangement.spacedBy(12.dp),
|
||||||
) {
|
) {
|
||||||
items(items = filteredModels, key = { it.id }) { model ->
|
items(items = filteredModels, key = { it.id }) { model ->
|
||||||
|
|
@ -134,20 +154,20 @@ fun ModelsManagementScreen(
|
||||||
// Model import progress overlay
|
// Model import progress overlay
|
||||||
when (val state = managementState) {
|
when (val state = managementState) {
|
||||||
is Importation.Confirming -> {
|
is Importation.Confirming -> {
|
||||||
ImportProgressDialog(
|
ImportFromLocalFileDialog(
|
||||||
fileName = state.fileName,
|
fileName = state.fileName,
|
||||||
fileSize = state.fileSize,
|
fileSize = state.fileSize,
|
||||||
isImporting = false,
|
isImporting = false,
|
||||||
progress = 0.0f,
|
progress = 0.0f,
|
||||||
onConfirm = {
|
onConfirm = {
|
||||||
viewModel.importLocalModelFile(state.uri, state.fileName, state.fileSize)
|
viewModel.importLocalModelFileConfirmed(state.uri, state.fileName, state.fileSize)
|
||||||
},
|
},
|
||||||
onCancel = { viewModel.resetManagementState() }
|
onCancel = { viewModel.resetManagementState() }
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
is Importation.Importing -> {
|
is Importation.Importing -> {
|
||||||
ImportProgressDialog(
|
ImportFromLocalFileDialog(
|
||||||
fileName = state.fileName,
|
fileName = state.fileName,
|
||||||
fileSize = state.fileSize,
|
fileSize = state.fileSize,
|
||||||
isImporting = true,
|
isImporting = true,
|
||||||
|
|
@ -178,6 +198,43 @@ fun ModelsManagementScreen(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
is Download.Querying -> {
|
||||||
|
ImportFromHuggingFaceDialog(
|
||||||
|
onCancel = { viewModel.resetManagementState() }
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
is Download.Ready -> {
|
||||||
|
ImportFromHuggingFaceDialog(
|
||||||
|
models = state.models,
|
||||||
|
onConfirm = { viewModel.downloadHuggingFaceModelConfirmed(it) },
|
||||||
|
onCancel = { viewModel.resetManagementState() }
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
is Download.Dispatched -> {
|
||||||
|
LaunchedEffect(state) {
|
||||||
|
onScaffoldEvent(
|
||||||
|
ScaffoldEvent.ShowSnackbar(
|
||||||
|
message = "Started downloading: ${state.downloadInfo.modelId}.\n" +
|
||||||
|
// TODO-han.yin: replace this with a broadcast receiver!
|
||||||
|
"Please come back to import it once completed.",
|
||||||
|
duration = SnackbarDuration.Long,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
viewModel.resetManagementState()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
is Download.Error -> {
|
||||||
|
ErrorDialog(
|
||||||
|
title = "Download Failed",
|
||||||
|
message = state.message,
|
||||||
|
onDismiss = { viewModel.resetManagementState() }
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
is Deletion.Confirming -> {
|
is Deletion.Confirming -> {
|
||||||
BatchDeleteConfirmationDialog(
|
BatchDeleteConfirmationDialog(
|
||||||
count = state.models.size,
|
count = state.models.size,
|
||||||
|
|
@ -219,33 +276,12 @@ fun ModelsManagementScreen(
|
||||||
}
|
}
|
||||||
|
|
||||||
is ModelManagementState.Idle -> { /* Idle state, nothing to show */ }
|
is ModelManagementState.Idle -> { /* Idle state, nothing to show */ }
|
||||||
|
|
||||||
else -> TODO()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO-han.yin: UI TO BE IMPLEMENTED
|
|
||||||
val huggingFaceModelsFlow = viewModel.huggingFaceModels
|
|
||||||
LaunchedEffect(Unit) {
|
|
||||||
huggingFaceModelsFlow.collect { models ->
|
|
||||||
val message = models.fold(
|
|
||||||
StringBuilder("Fetched ${models.size} models from HuggingFace")
|
|
||||||
) { builder, model ->
|
|
||||||
builder.append(model.id).append("\n")
|
|
||||||
}.toString()
|
|
||||||
|
|
||||||
onScaffoldEvent(
|
|
||||||
ScaffoldEvent.ShowSnackbar(
|
|
||||||
message = message,
|
|
||||||
duration = SnackbarDuration.Short
|
|
||||||
)
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Composable
|
@Composable
|
||||||
private fun ImportProgressDialog(
|
private fun ImportFromLocalFileDialog(
|
||||||
fileName: String,
|
fileName: String,
|
||||||
fileSize: Long,
|
fileSize: Long,
|
||||||
isImporting: Boolean,
|
isImporting: Boolean,
|
||||||
|
|
@ -344,6 +380,200 @@ private fun ImportProgressDialog(
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Composable
|
||||||
|
private fun ImportFromHuggingFaceDialog(
|
||||||
|
models: List<HuggingFaceModel>? = null,
|
||||||
|
onConfirm: ((HuggingFaceModel) -> Unit)? = null,
|
||||||
|
onCancel: () -> Unit,
|
||||||
|
) {
|
||||||
|
val dateFormatter = remember { SimpleDateFormat("MMM, yyyy", Locale.getDefault()) }
|
||||||
|
|
||||||
|
var selectedModel by remember { mutableStateOf<HuggingFaceModel?>(null) }
|
||||||
|
|
||||||
|
AlertDialog(
|
||||||
|
onDismissRequest = {},
|
||||||
|
properties = DialogProperties(
|
||||||
|
dismissOnBackPress = true,
|
||||||
|
dismissOnClickOutside = true
|
||||||
|
),
|
||||||
|
title = {
|
||||||
|
Text(models?.let { "Fetched ${it.size} models" } ?: "Fetching models")
|
||||||
|
},
|
||||||
|
text = {
|
||||||
|
Column(
|
||||||
|
horizontalAlignment = Alignment.CenterHorizontally,
|
||||||
|
verticalArrangement = Arrangement.Center
|
||||||
|
) {
|
||||||
|
Text(
|
||||||
|
modifier = Modifier.fillMaxWidth(),
|
||||||
|
text = models?.let { "Select a model to download:" }
|
||||||
|
?: "Searching on HuggingFace for models available for direct download...",
|
||||||
|
style = MaterialTheme.typography.bodyLarge,
|
||||||
|
textAlign = TextAlign.Start,
|
||||||
|
)
|
||||||
|
|
||||||
|
if (models == null) {
|
||||||
|
Spacer(modifier = Modifier.height(24.dp))
|
||||||
|
|
||||||
|
CircularProgressIndicator(
|
||||||
|
modifier = Modifier.size(64.dp),
|
||||||
|
strokeWidth = 6.dp
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
Spacer(modifier = Modifier.height(16.dp))
|
||||||
|
|
||||||
|
LazyColumn(
|
||||||
|
modifier = Modifier.fillMaxWidth(),
|
||||||
|
contentPadding = PaddingValues(vertical = 8.dp),
|
||||||
|
verticalArrangement = Arrangement.spacedBy(8.dp)
|
||||||
|
) {
|
||||||
|
items(models) { model ->
|
||||||
|
HuggingFaceModelListItem(
|
||||||
|
model = model,
|
||||||
|
isSelected = model._id == selectedModel?._id,
|
||||||
|
dateFormatter = dateFormatter,
|
||||||
|
onToggleSelect = { selected ->
|
||||||
|
selectedModel = if (selected) model else null
|
||||||
|
}
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
confirmButton = {
|
||||||
|
onConfirm?.let { onSelect ->
|
||||||
|
TextButton(
|
||||||
|
onClick = { selectedModel?.let { onSelect.invoke(it) } },
|
||||||
|
enabled = selectedModel != null
|
||||||
|
) {
|
||||||
|
Text("Download")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
dismissButton = {
|
||||||
|
TextButton(
|
||||||
|
onClick = onCancel
|
||||||
|
) {
|
||||||
|
Text("Cancel")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Composable
|
||||||
|
fun HuggingFaceModelListItem(
|
||||||
|
model: HuggingFaceModel,
|
||||||
|
isSelected: Boolean,
|
||||||
|
dateFormatter: SimpleDateFormat,
|
||||||
|
onToggleSelect: (Boolean) -> Unit,
|
||||||
|
) {
|
||||||
|
Card(
|
||||||
|
modifier = Modifier.fillMaxWidth(),
|
||||||
|
colors = when (isSelected) {
|
||||||
|
true -> CardDefaults.cardColors(
|
||||||
|
containerColor = MaterialTheme.colorScheme.primaryContainer
|
||||||
|
)
|
||||||
|
false -> CardDefaults.cardColors()
|
||||||
|
},
|
||||||
|
onClick = { onToggleSelect(!isSelected) }
|
||||||
|
) {
|
||||||
|
Column(modifier = Modifier.fillMaxWidth().padding(8.dp)) {
|
||||||
|
Text(
|
||||||
|
modifier = Modifier.basicMarquee(),
|
||||||
|
text = model.modelId.substringAfterLast("/"),
|
||||||
|
style = MaterialTheme.typography.bodyMedium,
|
||||||
|
fontWeight = if (isSelected) FontWeight.Bold else FontWeight.Medium,
|
||||||
|
)
|
||||||
|
|
||||||
|
Spacer(modifier = Modifier.size(8.dp))
|
||||||
|
|
||||||
|
Row(verticalAlignment = Alignment.Bottom) {
|
||||||
|
Column {
|
||||||
|
Row(verticalAlignment = Alignment.CenterVertically) {
|
||||||
|
Icon(
|
||||||
|
imageVector = Icons.Default.Attribution,
|
||||||
|
contentDescription = "Author",
|
||||||
|
modifier = Modifier.size(16.dp),
|
||||||
|
tint = MaterialTheme.colorScheme.onSurfaceVariant
|
||||||
|
)
|
||||||
|
|
||||||
|
Spacer(modifier = Modifier.size(4.dp))
|
||||||
|
|
||||||
|
Text(
|
||||||
|
text = model.author,
|
||||||
|
style = MaterialTheme.typography.bodyMedium,
|
||||||
|
color = MaterialTheme.colorScheme.onSurfaceVariant
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
Spacer(modifier = Modifier.size(8.dp))
|
||||||
|
|
||||||
|
Row(verticalAlignment = Alignment.CenterVertically) {
|
||||||
|
Icon(
|
||||||
|
imageVector = Icons.Default.Today,
|
||||||
|
contentDescription = "Author",
|
||||||
|
modifier = Modifier.size(16.dp),
|
||||||
|
tint = MaterialTheme.colorScheme.onSurfaceVariant
|
||||||
|
)
|
||||||
|
|
||||||
|
Spacer(modifier = Modifier.size(4.dp))
|
||||||
|
|
||||||
|
Text(
|
||||||
|
text = dateFormatter.format(model.lastModified),
|
||||||
|
style = MaterialTheme.typography.bodyMedium,
|
||||||
|
color = MaterialTheme.colorScheme.onSurfaceVariant
|
||||||
|
)
|
||||||
|
|
||||||
|
Spacer(modifier = Modifier.size(8.dp))
|
||||||
|
|
||||||
|
Icon(
|
||||||
|
imageVector = Icons.Default.Favorite,
|
||||||
|
contentDescription = "Favorite count",
|
||||||
|
modifier = Modifier.size(16.dp),
|
||||||
|
tint = MaterialTheme.colorScheme.onSurfaceVariant
|
||||||
|
)
|
||||||
|
|
||||||
|
Spacer(modifier = Modifier.size(4.dp))
|
||||||
|
|
||||||
|
Text(
|
||||||
|
text = formatContextLength(model.likes),
|
||||||
|
style = MaterialTheme.typography.bodyMedium,
|
||||||
|
color = MaterialTheme.colorScheme.onSurfaceVariant
|
||||||
|
)
|
||||||
|
|
||||||
|
Spacer(modifier = Modifier.size(8.dp))
|
||||||
|
|
||||||
|
Icon(
|
||||||
|
imageVector = Icons.Default.Download,
|
||||||
|
contentDescription = "Download count",
|
||||||
|
modifier = Modifier.size(16.dp),
|
||||||
|
tint = MaterialTheme.colorScheme.onSurfaceVariant
|
||||||
|
)
|
||||||
|
|
||||||
|
Spacer(modifier = Modifier.size(4.dp))
|
||||||
|
|
||||||
|
Text(
|
||||||
|
text = formatContextLength(model.downloads),
|
||||||
|
style = MaterialTheme.typography.bodyMedium,
|
||||||
|
color = MaterialTheme.colorScheme.onSurfaceVariant
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Spacer(Modifier.weight(1f))
|
||||||
|
|
||||||
|
if (isSelected) {
|
||||||
|
Checkbox(
|
||||||
|
checked = isSelected,
|
||||||
|
onCheckedChange = null, // handled by parent selectable
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
@Composable
|
@Composable
|
||||||
private fun BatchDeleteConfirmationDialog(
|
private fun BatchDeleteConfirmationDialog(
|
||||||
count: Int,
|
count: Int,
|
||||||
|
|
|
||||||
|
|
@ -22,10 +22,10 @@ import com.example.llama.viewmodel.ModelManagementState.Download
|
||||||
import com.example.llama.viewmodel.ModelManagementState.Importation
|
import com.example.llama.viewmodel.ModelManagementState.Importation
|
||||||
import dagger.hilt.android.lifecycle.HiltViewModel
|
import dagger.hilt.android.lifecycle.HiltViewModel
|
||||||
import dagger.hilt.android.qualifiers.ApplicationContext
|
import dagger.hilt.android.qualifiers.ApplicationContext
|
||||||
|
import kotlinx.coroutines.CancellationException
|
||||||
|
import kotlinx.coroutines.Job
|
||||||
import kotlinx.coroutines.delay
|
import kotlinx.coroutines.delay
|
||||||
import kotlinx.coroutines.flow.MutableSharedFlow
|
|
||||||
import kotlinx.coroutines.flow.MutableStateFlow
|
import kotlinx.coroutines.flow.MutableStateFlow
|
||||||
import kotlinx.coroutines.flow.SharedFlow
|
|
||||||
import kotlinx.coroutines.flow.StateFlow
|
import kotlinx.coroutines.flow.StateFlow
|
||||||
import kotlinx.coroutines.flow.asStateFlow
|
import kotlinx.coroutines.flow.asStateFlow
|
||||||
import kotlinx.coroutines.flow.collectLatest
|
import kotlinx.coroutines.flow.collectLatest
|
||||||
|
|
@ -34,7 +34,6 @@ import kotlinx.coroutines.flow.update
|
||||||
import kotlinx.coroutines.launch
|
import kotlinx.coroutines.launch
|
||||||
import java.io.FileNotFoundException
|
import java.io.FileNotFoundException
|
||||||
import javax.inject.Inject
|
import javax.inject.Inject
|
||||||
import kotlin.collections.set
|
|
||||||
|
|
||||||
@HiltViewModel
|
@HiltViewModel
|
||||||
class ModelsManagementViewModel @Inject constructor(
|
class ModelsManagementViewModel @Inject constructor(
|
||||||
|
|
@ -133,9 +132,8 @@ class ModelsManagementViewModel @Inject constructor(
|
||||||
_showImportModelMenu.value = show
|
_showImportModelMenu.value = show
|
||||||
}
|
}
|
||||||
|
|
||||||
// UI state: HuggingFace models query result
|
// Ongoing coroutine jobs
|
||||||
private val _huggingFaceModels = MutableSharedFlow<List<HuggingFaceModel>>()
|
private var huggingFaceQueryJob: Job? = null
|
||||||
val huggingFaceModels: SharedFlow<List<HuggingFaceModel>> = _huggingFaceModels
|
|
||||||
|
|
||||||
init {
|
init {
|
||||||
viewModelScope.launch {
|
viewModelScope.launch {
|
||||||
|
|
@ -156,13 +154,16 @@ class ModelsManagementViewModel @Inject constructor(
|
||||||
val managementState: StateFlow<ModelManagementState> = _managementState.asStateFlow()
|
val managementState: StateFlow<ModelManagementState> = _managementState.asStateFlow()
|
||||||
|
|
||||||
fun resetManagementState() {
|
fun resetManagementState() {
|
||||||
|
huggingFaceQueryJob?.let {
|
||||||
|
if (it.isActive) { it.cancel() }
|
||||||
|
}
|
||||||
_managementState.value = ModelManagementState.Idle
|
_managementState.value = ModelManagementState.Idle
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* First show confirmation instead of starting import local file immediately
|
* First show confirmation instead of starting import local file immediately
|
||||||
*/
|
*/
|
||||||
fun localModelFileSelected(uri: Uri) = viewModelScope.launch {
|
fun importLocalModelFileSelected(uri: Uri) = viewModelScope.launch {
|
||||||
try {
|
try {
|
||||||
val fileName = getFileNameFromUri(context, uri) ?: throw FileNotFoundException("File size N/A")
|
val fileName = getFileNameFromUri(context, uri) ?: throw FileNotFoundException("File size N/A")
|
||||||
val fileSize = getFileSizeFromUri(context, uri) ?: throw FileNotFoundException("File name N/A")
|
val fileSize = getFileSizeFromUri(context, uri) ?: throw FileNotFoundException("File name N/A")
|
||||||
|
|
@ -177,7 +178,7 @@ class ModelsManagementViewModel @Inject constructor(
|
||||||
/**
|
/**
|
||||||
* Import a local model file from device storage while updating UI states with realtime progress
|
* Import a local model file from device storage while updating UI states with realtime progress
|
||||||
*/
|
*/
|
||||||
fun importLocalModelFile(uri: Uri, fileName: String, fileSize: Long) = viewModelScope.launch {
|
fun importLocalModelFileConfirmed(uri: Uri, fileName: String, fileSize: Long) = viewModelScope.launch {
|
||||||
try {
|
try {
|
||||||
_managementState.value = Importation.Importing(0f, fileName, fileSize)
|
_managementState.value = Importation.Importing(0f, fileName, fileSize)
|
||||||
val model = modelRepository.importModel(uri, fileName, fileSize) { progress ->
|
val model = modelRepository.importModel(uri, fileName, fileSize) { progress ->
|
||||||
|
|
@ -221,29 +222,25 @@ class ModelsManagementViewModel @Inject constructor(
|
||||||
/**
|
/**
|
||||||
* Query models on HuggingFace available for download even without signing in
|
* Query models on HuggingFace available for download even without signing in
|
||||||
*/
|
*/
|
||||||
fun queryModelsFromHuggingFace() = viewModelScope.launch {
|
fun queryModelsFromHuggingFace() {
|
||||||
modelRepository.searchHuggingFaceModels().let { models ->
|
huggingFaceQueryJob = viewModelScope.launch {
|
||||||
_huggingFaceModels.emit(models)
|
_managementState.emit(Download.Querying)
|
||||||
|
try {
|
||||||
|
val models = modelRepository.searchHuggingFaceModels()
|
||||||
Log.d(TAG, "Fetched ${models.size} models from HuggingFace:")
|
Log.d(TAG, "Fetched ${models.size} models from HuggingFace:")
|
||||||
|
_managementState.emit(Download.Ready(models))
|
||||||
// TODO-han.yin: remove these logs
|
} catch (_: CancellationException) {
|
||||||
// models.forEachIndexed { index, model ->
|
// no-op
|
||||||
// Log.d(TAG, "#$index: $model")
|
} catch (e: Exception) {
|
||||||
// }
|
_managementState.emit(Download.Error(message = e.message ?: "Unknown error"))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* First show confirmation instead of dispatch download immediately
|
|
||||||
*/
|
|
||||||
fun downloadHuggingFaceModelSelected(model: HuggingFaceModel) {
|
|
||||||
_managementState.value = Download.Confirming(model)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Dispatch download request to [DownloadManager] and update UI
|
* Dispatch download request to [DownloadManager] and update UI
|
||||||
*/
|
*/
|
||||||
fun downloadHuggingFaceModel(model: HuggingFaceModel) = viewModelScope.launch {
|
fun downloadHuggingFaceModelConfirmed(model: HuggingFaceModel) = viewModelScope.launch {
|
||||||
try {
|
try {
|
||||||
require(!model.gated) { "Model is gated!" }
|
require(!model.gated) { "Model is gated!" }
|
||||||
require(!model.private) { "Model is private!" }
|
require(!model.private) { "Model is private!" }
|
||||||
|
|
@ -324,7 +321,8 @@ sealed class ModelManagementState {
|
||||||
}
|
}
|
||||||
|
|
||||||
sealed class Download: ModelManagementState() {
|
sealed class Download: ModelManagementState() {
|
||||||
data class Confirming(val model: HuggingFaceModel) : Download()
|
object Querying : Download()
|
||||||
|
data class Ready(val models: List<HuggingFaceModel>) : Download()
|
||||||
data class Dispatched(val downloadInfo: HuggingFaceDownloadInfo) : Download()
|
data class Dispatched(val downloadInfo: HuggingFaceDownloadInfo) : Download()
|
||||||
data class Error(val message: String) : Download()
|
data class Error(val message: String) : Download()
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue