From 5de0b5d6d098b23302c82ccd4113447888007537 Mon Sep 17 00:00:00 2001 From: Han Yin Date: Mon, 14 Apr 2025 15:44:42 -0700 Subject: [PATCH] data: import local model with file picker --- .../llama/revamp/data/local/ModelEntity.kt | 8 +- .../llama/revamp/data/model/ModelInfo.kt | 16 +- .../revamp/data/repository/ModelRepository.kt | 202 +++++++++++++--- .../revamp/ui/screens/ModelSelectionScreen.kt | 6 +- .../ui/screens/ModelsManagementScreen.kt | 217 +++++++++++++++--- .../viewmodel/ModelsManagementViewModel.kt | 33 ++- 6 files changed, 396 insertions(+), 86 deletions(-) diff --git a/examples/llama.android/app/src/main/java/com/example/llama/revamp/data/local/ModelEntity.kt b/examples/llama.android/app/src/main/java/com/example/llama/revamp/data/local/ModelEntity.kt index c5a67babf1..b1a7c0a1f9 100644 --- a/examples/llama.android/app/src/main/java/com/example/llama/revamp/data/local/ModelEntity.kt +++ b/examples/llama.android/app/src/main/java/com/example/llama/revamp/data/local/ModelEntity.kt @@ -11,10 +11,10 @@ data class ModelEntity( val name: String, val path: String, val sizeInBytes: Long, - val parameters: String, - val quantization: String, - val type: String, - val contextLength: Int, + val parameters: String?, + val quantization: String?, + val type: String?, + val contextLength: Int?, val lastUsed: Long?, val dateAdded: Long ) { diff --git a/examples/llama.android/app/src/main/java/com/example/llama/revamp/data/model/ModelInfo.kt b/examples/llama.android/app/src/main/java/com/example/llama/revamp/data/model/ModelInfo.kt index d7be84c967..68ef17db3e 100644 --- a/examples/llama.android/app/src/main/java/com/example/llama/revamp/data/model/ModelInfo.kt +++ b/examples/llama.android/app/src/main/java/com/example/llama/revamp/data/model/ModelInfo.kt @@ -1,5 +1,7 @@ package com.example.llama.revamp.data.model +import java.util.Locale + /** * Data class containing information about an LLM model. */ @@ -8,10 +10,10 @@ data class ModelInfo( val name: String, val path: String, val sizeInBytes: Long, - val parameters: String, - val quantization: String, - val type: String, - val contextLength: Int, + val parameters: String?, + val quantization: String?, + val type: String?, + val contextLength: Int?, val lastUsed: Long? = null ) { val formattedSize: String @@ -19,15 +21,15 @@ data class ModelInfo( return when { sizeInBytes >= 1_000_000_000 -> { val sizeInGb = sizeInBytes / 1_000_000_000.0 - String.format("%.2f GB", sizeInGb) + String.format(Locale.getDefault(), "%.2f GB", sizeInGb) } sizeInBytes >= 1_000_000 -> { val sizeInMb = sizeInBytes / 1_000_000.0 - String.format("%.2f MB", sizeInMb) + String.format(Locale.getDefault(), "%.2f MB", sizeInMb) } else -> { val sizeInKb = sizeInBytes / 1_000.0 - String.format("%.2f KB", sizeInKb) + String.format(Locale.getDefault(), "%.2f KB", sizeInKb) } } } diff --git a/examples/llama.android/app/src/main/java/com/example/llama/revamp/data/repository/ModelRepository.kt b/examples/llama.android/app/src/main/java/com/example/llama/revamp/data/repository/ModelRepository.kt index fa262ede9b..1548f98232 100644 --- a/examples/llama.android/app/src/main/java/com/example/llama/revamp/data/repository/ModelRepository.kt +++ b/examples/llama.android/app/src/main/java/com/example/llama/revamp/data/repository/ModelRepository.kt @@ -4,17 +4,31 @@ import android.content.Context import android.net.Uri import android.os.StatFs import android.provider.OpenableColumns +import android.util.Log import com.example.llama.revamp.data.local.ModelDao import com.example.llama.revamp.data.local.ModelEntity import com.example.llama.revamp.data.model.ModelInfo +import com.example.llama.revamp.data.repository.ModelRepository.ImportProgressTracker import dagger.hilt.android.qualifiers.ApplicationContext +import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.delay import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.flow import kotlinx.coroutines.flow.map +import kotlinx.coroutines.withContext +import kotlinx.coroutines.yield +import java.io.BufferedInputStream +import java.io.BufferedOutputStream import java.io.File +import java.io.FileNotFoundException import java.io.FileOutputStream import java.io.IOException +import java.io.InputStream +import java.io.OutputStream +import java.nio.ByteBuffer +import java.nio.channels.Channels +import java.nio.channels.ReadableByteChannel +import java.nio.channels.WritableByteChannel import java.util.Locale import java.util.UUID import javax.inject.Inject @@ -27,9 +41,14 @@ interface ModelRepository { fun getStorageMetrics(): Flow fun getModels(): Flow> - suspend fun importModel(uri: Uri): ModelInfo + suspend fun importModel(uri: Uri, progressTracker: ImportProgressTracker? = null): ModelInfo + suspend fun deleteModel(modelId: String) suspend fun deleteModels(modelIds: Collection) + + fun interface ImportProgressTracker { + fun onProgress(progress: Float) // 0.0f to 1.0f + } } @Singleton @@ -67,38 +86,138 @@ class ModelRepositoryImpl @Inject constructor( } } - override suspend fun importModel(uri: Uri): ModelInfo { - // Obtain the local model's file via provided URI - val fileName = getFileNameFromUri(uri) + override suspend fun importModel( + uri: Uri, + progressTracker: ImportProgressTracker? + ): ModelInfo = withContext(Dispatchers.IO) { + val fileName = getFileNameFromUri(uri) ?: throw FileNotFoundException("Filename N/A") + val fileSize = getFileSizeFromUri(uri) ?: throw FileNotFoundException("File size N/A") val modelFile = File(modelsDir, fileName) - // Copy file to app's internal storage - context.contentResolver.openInputStream(uri)?.use { inputStream -> - FileOutputStream(modelFile).use { outputStream -> - inputStream.copyTo(outputStream) + try { + val inputStream = context.contentResolver.openInputStream(uri) + ?: throw IOException("Failed to open input stream") + val outputStream = FileOutputStream(modelFile) + + if (fileSize > LARGE_MODEL_THRESHOLD_SIZE) { + Log.i(TAG, "Copying $fileName (size: $fileSize) via NIO...") + + // Use NIO channels for large models + copyWithChannels(inputStream, outputStream, fileSize, progressTracker) + } else { + Log.i(TAG, "Copying $fileName (size: $fileSize) via buffer...") + + // Default copy with buffer for small models + val bufferedInput = BufferedInputStream(inputStream, DEFAULT_BUFFER_SIZE) + val bufferedOutput = BufferedOutputStream(outputStream, DEFAULT_BUFFER_SIZE) + copyWithBuffer(bufferedInput, bufferedOutput, fileSize, progressTracker) + + // Close streams + bufferedOutput.flush() + bufferedOutput.close() + bufferedInput.close() } - } ?: throw IOException("Failed to open input stream") - // Extract model parameters from filename - val modelType = extractModelTypeFromFilename(fileName) ?: "unknown" - val parameters = extractParametersFromFilename(fileName) ?: "unknown" - val quantization = extractQuantizationFromFilename(fileName) ?: "unknown" + // Extract model parameters from filename + val modelType = extractModelTypeFromFilename(fileName) + val parameters = extractParametersFromFilename(fileName) + val quantization = extractQuantizationFromFilename(fileName) - // Create model entity and save via DAO - val modelEntity = ModelEntity( - id = UUID.randomUUID().toString(), - name = fileName.substringBeforeLast('.'), - path = modelFile.absolutePath, - sizeInBytes = modelFile.length(), - parameters = parameters, - quantization = quantization, - type = modelType, - contextLength = DEFAULT_CONTEXT_SIZE, - lastUsed = null, - dateAdded = System.currentTimeMillis() - ) - modelDao.insertModel(modelEntity) - return modelEntity.toModelInfo() + // Create model entity and save via DAO + ModelEntity( + id = UUID.randomUUID().toString(), + name = fileName.substringBeforeLast('.'), + path = modelFile.absolutePath, + sizeInBytes = modelFile.length(), + parameters = parameters, + quantization = quantization, + type = modelType, + contextLength = DEFAULT_CONTEXT_SIZE, + lastUsed = null, + dateAdded = System.currentTimeMillis() + ).let { + modelDao.insertModel(it) + it.toModelInfo() + } + } catch (e: Exception) { + // Clean up partially downloaded file if error occurs + if (modelFile.exists()) { + modelFile.delete() + } + throw e + } + } + + private suspend fun copyWithChannels( + input: InputStream, + output: OutputStream, + totalSize: Long, + progressTracker: ImportProgressTracker? + ) { + val inChannel: ReadableByteChannel = Channels.newChannel(input) + val outChannel: WritableByteChannel = Channels.newChannel(output) + + val buffer = ByteBuffer.allocateDirect(NIO_BUFFER_SIZE) + var totalBytesRead = 0L + + while (inChannel.read(buffer) != -1) { + buffer.flip() + while (buffer.hasRemaining()) { + outChannel.write(buffer) + } + totalBytesRead += buffer.position() + buffer.clear() + + // Report progress + progressTracker?.let { + val progress = totalBytesRead.toFloat() / totalSize + withContext(Dispatchers.Main) { + it.onProgress(progress) + } + } + + if (totalBytesRead % (NIO_YIELD_SIZE) == 0L) { + yield() + } + } + + inChannel.close() + outChannel.close() + output.close() + input.close() + } + + private suspend fun copyWithBuffer( + input: BufferedInputStream, + output: BufferedOutputStream, + totalSize: Long, + progressTracker: ImportProgressTracker? + ) { + val buffer = ByteArray(DEFAULT_BUFFER_SIZE) + + var bytesRead: Int + var totalBytesRead = 0L + + while (input.read(buffer).also { bytesRead = it } != -1) { + output.write(buffer, 0, bytesRead) + totalBytesRead += bytesRead + + // Report progress + if (progressTracker != null) { + val progress = totalBytesRead.toFloat() / totalSize + withContext(Dispatchers.Main) { + progressTracker.onProgress(progress) + } + } + + // Yield less frequently with larger buffers + if (totalBytesRead % (DEFAULT_YIELD_SIZE) == 0L) { // Every 64MB + yield() + } + } + + output.close() + input.close() } override suspend fun deleteModel(modelId: String) { @@ -132,7 +251,7 @@ class ModelRepositoryImpl @Inject constructor( val totalSpaceBytes: Long get() = StatFs(context.filesDir.path).totalBytes - private fun getFileNameFromUri(uri: Uri): String = + private fun getFileNameFromUri(uri: Uri): String? = context.contentResolver.query(uri, null, null, null, null)?.use { cursor -> if (cursor.moveToFirst()) { cursor.getColumnIndex(OpenableColumns.DISPLAY_NAME).let { nameIndex -> @@ -141,7 +260,21 @@ class ModelRepositoryImpl @Inject constructor( } else { null } - } ?: uri.lastPathSegment ?: "unknown_model.gguf" + } ?: uri.lastPathSegment + + /** + * Gets the file size from a content URI, or returns 0 if size is unknown. + */ + private fun getFileSizeFromUri(uri: Uri): Long? = + context.contentResolver.query(uri, null, null, null, null)?.use { cursor -> + if (cursor.moveToFirst()) { + cursor.getColumnIndex(OpenableColumns.SIZE).let { sizeIndex -> + if (sizeIndex != -1) cursor.getLong(sizeIndex) else null + } + } else { + null + } + } /** * Try to extract parameters by looking for patterns like 7B, 13B, etc. @@ -182,9 +315,18 @@ class ModelRepositoryImpl @Inject constructor( } companion object { + private val TAG = ModelRepository::class.java.simpleName + private const val INTERNAL_STORAGE_PATH = "models" + private const val BYTES_IN_GB = 1024f * 1024f * 1024f + private const val LARGE_MODEL_THRESHOLD_SIZE = 1024 * 1024 * 1024 + private const val NIO_BUFFER_SIZE = 32 * 1024 * 1024 + private const val NIO_YIELD_SIZE = 128 * 1024 * 1024 + private const val DEFAULT_BUFFER_SIZE = 4 * 1024 * 1024 + private const val DEFAULT_YIELD_SIZE = 16 * 1024 * 1024 + private const val STORAGE_METRICS_UPDATE_INTERVAL = 5_000L private const val DEFAULT_CONTEXT_SIZE = 8192 diff --git a/examples/llama.android/app/src/main/java/com/example/llama/revamp/ui/screens/ModelSelectionScreen.kt b/examples/llama.android/app/src/main/java/com/example/llama/revamp/ui/screens/ModelSelectionScreen.kt index c7ba477191..86107a8d00 100644 --- a/examples/llama.android/app/src/main/java/com/example/llama/revamp/ui/screens/ModelSelectionScreen.kt +++ b/examples/llama.android/app/src/main/java/com/example/llama/revamp/ui/screens/ModelSelectionScreen.kt @@ -99,13 +99,13 @@ fun ModelCard( Row { Text( - text = model.parameters, + text = model.parameters ?: " - ", style = MaterialTheme.typography.bodyMedium, color = MaterialTheme.colorScheme.onSurfaceVariant ) Text( - text = " • ${model.quantization}", + text = " • ${model.quantization ?: " - "}", style = MaterialTheme.typography.bodyMedium, color = MaterialTheme.colorScheme.onSurfaceVariant ) @@ -120,7 +120,7 @@ fun ModelCard( Spacer(modifier = Modifier.height(4.dp)) Text( - text = "Context Length: ${model.contextLength}", + text = "Context Length: ${model.contextLength ?: " - "}", style = MaterialTheme.typography.bodySmall, color = MaterialTheme.colorScheme.onSurfaceVariant ) diff --git a/examples/llama.android/app/src/main/java/com/example/llama/revamp/ui/screens/ModelsManagementScreen.kt b/examples/llama.android/app/src/main/java/com/example/llama/revamp/ui/screens/ModelsManagementScreen.kt index 5587ffe819..0cbfbfd58a 100644 --- a/examples/llama.android/app/src/main/java/com/example/llama/revamp/ui/screens/ModelsManagementScreen.kt +++ b/examples/llama.android/app/src/main/java/com/example/llama/revamp/ui/screens/ModelsManagementScreen.kt @@ -1,6 +1,11 @@ package com.example.llama.revamp.ui.screens +import androidx.activity.compose.BackHandler +import androidx.activity.compose.rememberLauncherForActivityResult +import androidx.activity.result.contract.ActivityResultContracts +import androidx.compose.foundation.background import androidx.compose.foundation.clickable +import androidx.compose.foundation.layout.Box import androidx.compose.foundation.layout.Column import androidx.compose.foundation.layout.Row import androidx.compose.foundation.layout.Spacer @@ -23,7 +28,10 @@ import androidx.compose.material.icons.filled.FilterAlt import androidx.compose.material.icons.filled.FolderOpen import androidx.compose.material.icons.filled.Info import androidx.compose.material.icons.filled.SelectAll +import androidx.compose.material3.AlertDialog import androidx.compose.material3.BottomAppBar +import androidx.compose.material3.Button +import androidx.compose.material3.ButtonDefaults import androidx.compose.material3.Card import androidx.compose.material3.CardDefaults import androidx.compose.material3.Checkbox @@ -32,9 +40,11 @@ import androidx.compose.material3.DropdownMenuItem import androidx.compose.material3.FloatingActionButton import androidx.compose.material3.Icon import androidx.compose.material3.IconButton +import androidx.compose.material3.LinearProgressIndicator import androidx.compose.material3.MaterialTheme import androidx.compose.material3.Text import androidx.compose.runtime.Composable +import androidx.compose.runtime.LaunchedEffect import androidx.compose.runtime.collectAsState import androidx.compose.runtime.getValue import androidx.compose.runtime.mutableStateMapOf @@ -45,6 +55,7 @@ import androidx.compose.ui.Alignment import androidx.compose.ui.Modifier import androidx.compose.ui.graphics.Color import androidx.compose.ui.res.painterResource +import androidx.compose.ui.text.style.TextOverflow import androidx.compose.ui.unit.dp import androidx.hilt.navigation.compose.hiltViewModel import com.example.llama.revamp.data.model.ModelInfo @@ -55,6 +66,7 @@ import java.text.SimpleDateFormat import java.util.Date import java.util.Locale import com.example.llama.R +import com.example.llama.revamp.viewmodel.ModelImportState /** * Screen for managing LLM models (view, download, delete) @@ -66,13 +78,24 @@ fun ModelsManagementScreen( ) { val storageMetrics by viewModel.storageMetrics.collectAsState() val sortedModels by viewModel.sortedModels.collectAsState() + + // Model sorting val sortOrder by viewModel.sortOrder.collectAsState() - - // UI: menu states var showSortMenu by remember { mutableStateOf(false) } - var showAddModelMenu by remember { mutableStateOf(false) } - // UI: multi-selection states + // Model importing + val importState by viewModel.importState.collectAsState() + + var showImportModelMenu by remember { mutableStateOf(false) } + val fileLauncher = rememberLauncherForActivityResult( + contract = ActivityResultContracts.OpenDocument() + ) { uri -> uri?.let { viewModel.importLocalModel(it) } } + + BackHandler(enabled = importState is ModelImportState.Importing) { + /* Ignore back press while importing model */ + } + + // Multi-selection var isMultiSelectionMode by remember { mutableStateOf(false) } val selectedModels = remember { mutableStateMapOf() } val exitSelectionMode = { @@ -238,7 +261,7 @@ fun ModelsManagementScreen( if (isMultiSelectionMode) { exitSelectionMode() } else { - showAddModelMenu = true + showImportModelMenu = true } }, containerColor = MaterialTheme.colorScheme.primaryContainer @@ -251,8 +274,8 @@ fun ModelsManagementScreen( // Add model dropdown menu DropdownMenu( - expanded = showAddModelMenu, - onDismissRequest = { showAddModelMenu = false } + expanded = showImportModelMenu, + onDismissRequest = { showImportModelMenu = false } ) { DropdownMenuItem( text = { Text("Import local model") }, @@ -263,9 +286,8 @@ fun ModelsManagementScreen( ) }, onClick = { - // TODO-han.yin: uncomment once file picker done - // viewModel.importLocalModel() - showAddModelMenu = false + fileLauncher.launch(arrayOf("application/octet-stream", "*/*")) + showImportModelMenu = false } ) DropdownMenuItem( @@ -280,7 +302,7 @@ fun ModelsManagementScreen( }, onClick = { viewModel.importFromHuggingFace() - showAddModelMenu = false + showImportModelMenu = false } ) } @@ -288,38 +310,64 @@ fun ModelsManagementScreen( ) }, ) { paddingValues -> - // Main content - ModelList( - models = sortedModels, - isMultiSelectionMode = isMultiSelectionMode, - selectedModels = selectedModels, - onModelClick = { modelId -> - if (isMultiSelectionMode) { - // Toggle selection - if (selectedModels.contains(modelId)) { - selectedModels.remove(modelId) + Box(modifier = Modifier.fillMaxSize()) { + // Model cards + ModelCardList( + models = sortedModels, + isMultiSelectionMode = isMultiSelectionMode, + selectedModels = selectedModels, + onModelClick = { modelId -> + if (isMultiSelectionMode) { + // Toggle selection + if (selectedModels.contains(modelId)) { + selectedModels.remove(modelId) + } else { + selectedModels.put(modelId, sortedModels.first { it.id == modelId } ) + } } else { - selectedModels.put(modelId, sortedModels.first { it.id == modelId } ) + // View model details + viewModel.viewModelDetails(modelId) } - } else { - // View model details + }, + onModelInfoClick = { modelId -> viewModel.viewModelDetails(modelId) + }, + onModelDeleteClick = { modelId -> + viewModel.deleteModel(modelId) + }, + modifier = Modifier.padding(paddingValues) + ) + + // Model import progress overlay + when (val state = importState) { + is ModelImportState.Importing -> { + ImportProgressOverlay( + progress = state.progress, + filename = state.filename, + onCancel = { /* Implement cancellation if needed */ } + ) } - }, - onModelInfoClick = { modelId -> - viewModel.viewModelDetails(modelId) - }, - onModelDeleteClick = { modelId -> - viewModel.deleteModel(modelId) - }, - modifier = Modifier.padding(paddingValues) - ) + is ModelImportState.Error -> { + ErrorDialog( + message = state.message, + onDismiss = { viewModel.resetImportState() } + ) + } + is ModelImportState.Success -> { + LaunchedEffect(state) { + // Show success snackbar or message + // This will auto-dismiss after the delay in viewModel + } + } + else -> { /* Idle state, nothing to show */ } + } + } } } @Composable -private fun ModelList( +private fun ModelCardList( models: List, isMultiSelectionMode: Boolean, selectedModels: Map, @@ -337,13 +385,16 @@ private fun ModelList( items = models, key = { it.id } ) { model -> - ModelItem( + ModelCard( model = model, isMultiSelectionMode = isMultiSelectionMode, isSelected = selectedModels.contains(model.id), onClick = { onModelClick(model.id) }, onInfoClick = { onModelInfoClick(model.id) }, - onDeleteClick = { onModelDeleteClick(model.id) } + onDeleteClick = { + // TODO-han.yin: pop up an AlertDialog asking user for confirmation + onModelDeleteClick(model.id) + } ) Spacer(modifier = Modifier.height(8.dp)) } @@ -351,7 +402,7 @@ private fun ModelList( } @Composable -private fun ModelItem( +private fun ModelCard( model: ModelInfo, isMultiSelectionMode: Boolean, isSelected: Boolean, @@ -424,3 +475,95 @@ private fun ModelItem( } } } + +@Composable +fun ImportProgressOverlay( + progress: Float, + filename: String, + onCancel: () -> Unit +) { + Box( + modifier = Modifier + .fillMaxSize() + .background(Color.Black.copy(alpha = 0.7f)) + .padding(32.dp), + contentAlignment = Alignment.Center + ) { + Card( + modifier = Modifier + .fillMaxWidth() + .padding(16.dp), + elevation = CardDefaults.cardElevation(defaultElevation = 8.dp) + ) { + Column( + modifier = Modifier.padding(24.dp), + horizontalAlignment = Alignment.CenterHorizontally + ) { + Text( + text = "Importing Model", + style = MaterialTheme.typography.headlineSmall + ) + + Spacer(modifier = Modifier.height(8.dp)) + + Text( + text = filename, + style = MaterialTheme.typography.bodyMedium, + maxLines = 1, + overflow = TextOverflow.Ellipsis + ) + + Spacer(modifier = Modifier.height(24.dp)) + + LinearProgressIndicator( + progress = { progress }, + modifier = Modifier.fillMaxWidth() + ) + + Spacer(modifier = Modifier.height(8.dp)) + + Text( + text = "${(progress * 100).toInt()}%", + style = MaterialTheme.typography.bodyLarge + ) + + Spacer(modifier = Modifier.height(16.dp)) + + Text( + text = "This may take several minutes for large models", + style = MaterialTheme.typography.bodySmall, + color = MaterialTheme.colorScheme.onSurfaceVariant + ) + + Spacer(modifier = Modifier.height(24.dp)) + + Button( + onClick = onCancel, + colors = ButtonDefaults.buttonColors( + containerColor = MaterialTheme.colorScheme.errorContainer, + contentColor = MaterialTheme.colorScheme.onErrorContainer + ) + ) { + Text("Cancel") + } + } + } + } +} + +@Composable +fun ErrorDialog( + message: String, + onDismiss: () -> Unit +) { + AlertDialog( + onDismissRequest = onDismiss, + title = { Text("Import Failed") }, + text = { Text(message) }, + confirmButton = { + Button(onClick = onDismiss) { + Text("OK") + } + } + ) +} diff --git a/examples/llama.android/app/src/main/java/com/example/llama/revamp/viewmodel/ModelsManagementViewModel.kt b/examples/llama.android/app/src/main/java/com/example/llama/revamp/viewmodel/ModelsManagementViewModel.kt index af108e5c69..a138c17d3e 100644 --- a/examples/llama.android/app/src/main/java/com/example/llama/revamp/viewmodel/ModelsManagementViewModel.kt +++ b/examples/llama.android/app/src/main/java/com/example/llama/revamp/viewmodel/ModelsManagementViewModel.kt @@ -1,13 +1,13 @@ package com.example.llama.revamp.viewmodel import android.net.Uri -import android.util.Log import androidx.lifecycle.ViewModel import androidx.lifecycle.viewModelScope import com.example.llama.revamp.data.model.ModelInfo import com.example.llama.revamp.data.repository.ModelRepository import com.example.llama.revamp.data.repository.StorageMetrics import dagger.hilt.android.lifecycle.HiltViewModel +import kotlinx.coroutines.delay import kotlinx.coroutines.flow.MutableStateFlow import kotlinx.coroutines.flow.SharingStarted import kotlinx.coroutines.flow.StateFlow @@ -76,16 +76,34 @@ class ModelsManagementViewModel @Inject constructor( modelRepository.deleteModels(models.keys) } + private val _importState = MutableStateFlow(ModelImportState.Idle) + val importState: StateFlow = _importState.asStateFlow() + fun importLocalModel(uri: Uri) = viewModelScope.launch { try { - modelRepository.importModel(uri) + // Get filename for progress updates + val filename = uri.lastPathSegment ?: throw Exception("Model name unknown") + _importState.value = ModelImportState.Importing(0f, filename) + + // Import with progress reporting + val model = modelRepository.importModel(uri) { progress -> + _importState.value = ModelImportState.Importing(progress, filename) + } + _importState.value = ModelImportState.Success(model) + + // Reset state after a delay + delay(1000) + _importState.value = ModelImportState.Idle } catch (e: Exception) { - // TODO-han.yin: add UI to prompt user about import failure! - Log.e(TAG, "Failed to import model from: $uri", e) + _importState.value = ModelImportState.Error(e.message ?: "Unknown error") } } + fun resetImportState() { + _importState.value = ModelImportState.Idle + } + fun importFromHuggingFace() { // TODO-han.yin: Stub for now. Would need to investigate HuggingFace APIs } @@ -105,4 +123,9 @@ enum class ModelSortOrder { LAST_USED } - +sealed class ModelImportState { + object Idle : ModelImportState() + data class Importing(val progress: Float = 0f, val filename: String = "") : ModelImportState() + data class Success(val model: ModelInfo) : ModelImportState() + data class Error(val message: String) : ModelImportState() +}