data: import local model with file picker

This commit is contained in:
Han Yin 2025-04-14 15:44:42 -07:00
parent a3ebdac58f
commit 5de0b5d6d0
6 changed files with 396 additions and 86 deletions

View File

@ -11,10 +11,10 @@ data class ModelEntity(
val name: String, val name: String,
val path: String, val path: String,
val sizeInBytes: Long, val sizeInBytes: Long,
val parameters: String, val parameters: String?,
val quantization: String, val quantization: String?,
val type: String, val type: String?,
val contextLength: Int, val contextLength: Int?,
val lastUsed: Long?, val lastUsed: Long?,
val dateAdded: Long val dateAdded: Long
) { ) {

View File

@ -1,5 +1,7 @@
package com.example.llama.revamp.data.model package com.example.llama.revamp.data.model
import java.util.Locale
/** /**
* Data class containing information about an LLM model. * Data class containing information about an LLM model.
*/ */
@ -8,10 +10,10 @@ data class ModelInfo(
val name: String, val name: String,
val path: String, val path: String,
val sizeInBytes: Long, val sizeInBytes: Long,
val parameters: String, val parameters: String?,
val quantization: String, val quantization: String?,
val type: String, val type: String?,
val contextLength: Int, val contextLength: Int?,
val lastUsed: Long? = null val lastUsed: Long? = null
) { ) {
val formattedSize: String val formattedSize: String
@ -19,15 +21,15 @@ data class ModelInfo(
return when { return when {
sizeInBytes >= 1_000_000_000 -> { sizeInBytes >= 1_000_000_000 -> {
val sizeInGb = sizeInBytes / 1_000_000_000.0 val sizeInGb = sizeInBytes / 1_000_000_000.0
String.format("%.2f GB", sizeInGb) String.format(Locale.getDefault(), "%.2f GB", sizeInGb)
} }
sizeInBytes >= 1_000_000 -> { sizeInBytes >= 1_000_000 -> {
val sizeInMb = sizeInBytes / 1_000_000.0 val sizeInMb = sizeInBytes / 1_000_000.0
String.format("%.2f MB", sizeInMb) String.format(Locale.getDefault(), "%.2f MB", sizeInMb)
} }
else -> { else -> {
val sizeInKb = sizeInBytes / 1_000.0 val sizeInKb = sizeInBytes / 1_000.0
String.format("%.2f KB", sizeInKb) String.format(Locale.getDefault(), "%.2f KB", sizeInKb)
} }
} }
} }

View File

@ -4,17 +4,31 @@ import android.content.Context
import android.net.Uri import android.net.Uri
import android.os.StatFs import android.os.StatFs
import android.provider.OpenableColumns import android.provider.OpenableColumns
import android.util.Log
import com.example.llama.revamp.data.local.ModelDao import com.example.llama.revamp.data.local.ModelDao
import com.example.llama.revamp.data.local.ModelEntity import com.example.llama.revamp.data.local.ModelEntity
import com.example.llama.revamp.data.model.ModelInfo import com.example.llama.revamp.data.model.ModelInfo
import com.example.llama.revamp.data.repository.ModelRepository.ImportProgressTracker
import dagger.hilt.android.qualifiers.ApplicationContext import dagger.hilt.android.qualifiers.ApplicationContext
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.delay import kotlinx.coroutines.delay
import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.flow import kotlinx.coroutines.flow.flow
import kotlinx.coroutines.flow.map import kotlinx.coroutines.flow.map
import kotlinx.coroutines.withContext
import kotlinx.coroutines.yield
import java.io.BufferedInputStream
import java.io.BufferedOutputStream
import java.io.File import java.io.File
import java.io.FileNotFoundException
import java.io.FileOutputStream import java.io.FileOutputStream
import java.io.IOException import java.io.IOException
import java.io.InputStream
import java.io.OutputStream
import java.nio.ByteBuffer
import java.nio.channels.Channels
import java.nio.channels.ReadableByteChannel
import java.nio.channels.WritableByteChannel
import java.util.Locale import java.util.Locale
import java.util.UUID import java.util.UUID
import javax.inject.Inject import javax.inject.Inject
@ -27,9 +41,14 @@ interface ModelRepository {
fun getStorageMetrics(): Flow<StorageMetrics> fun getStorageMetrics(): Flow<StorageMetrics>
fun getModels(): Flow<List<ModelInfo>> fun getModels(): Flow<List<ModelInfo>>
suspend fun importModel(uri: Uri): ModelInfo suspend fun importModel(uri: Uri, progressTracker: ImportProgressTracker? = null): ModelInfo
suspend fun deleteModel(modelId: String) suspend fun deleteModel(modelId: String)
suspend fun deleteModels(modelIds: Collection<String>) suspend fun deleteModels(modelIds: Collection<String>)
fun interface ImportProgressTracker {
fun onProgress(progress: Float) // 0.0f to 1.0f
}
} }
@Singleton @Singleton
@ -67,38 +86,138 @@ class ModelRepositoryImpl @Inject constructor(
} }
} }
override suspend fun importModel(uri: Uri): ModelInfo { override suspend fun importModel(
// Obtain the local model's file via provided URI uri: Uri,
val fileName = getFileNameFromUri(uri) progressTracker: ImportProgressTracker?
): ModelInfo = withContext(Dispatchers.IO) {
val fileName = getFileNameFromUri(uri) ?: throw FileNotFoundException("Filename N/A")
val fileSize = getFileSizeFromUri(uri) ?: throw FileNotFoundException("File size N/A")
val modelFile = File(modelsDir, fileName) val modelFile = File(modelsDir, fileName)
// Copy file to app's internal storage try {
context.contentResolver.openInputStream(uri)?.use { inputStream -> val inputStream = context.contentResolver.openInputStream(uri)
FileOutputStream(modelFile).use { outputStream -> ?: throw IOException("Failed to open input stream")
inputStream.copyTo(outputStream) val outputStream = FileOutputStream(modelFile)
if (fileSize > LARGE_MODEL_THRESHOLD_SIZE) {
Log.i(TAG, "Copying $fileName (size: $fileSize) via NIO...")
// Use NIO channels for large models
copyWithChannels(inputStream, outputStream, fileSize, progressTracker)
} else {
Log.i(TAG, "Copying $fileName (size: $fileSize) via buffer...")
// Default copy with buffer for small models
val bufferedInput = BufferedInputStream(inputStream, DEFAULT_BUFFER_SIZE)
val bufferedOutput = BufferedOutputStream(outputStream, DEFAULT_BUFFER_SIZE)
copyWithBuffer(bufferedInput, bufferedOutput, fileSize, progressTracker)
// Close streams
bufferedOutput.flush()
bufferedOutput.close()
bufferedInput.close()
} }
} ?: throw IOException("Failed to open input stream")
// Extract model parameters from filename // Extract model parameters from filename
val modelType = extractModelTypeFromFilename(fileName) ?: "unknown" val modelType = extractModelTypeFromFilename(fileName)
val parameters = extractParametersFromFilename(fileName) ?: "unknown" val parameters = extractParametersFromFilename(fileName)
val quantization = extractQuantizationFromFilename(fileName) ?: "unknown" val quantization = extractQuantizationFromFilename(fileName)
// Create model entity and save via DAO // Create model entity and save via DAO
val modelEntity = ModelEntity( ModelEntity(
id = UUID.randomUUID().toString(), id = UUID.randomUUID().toString(),
name = fileName.substringBeforeLast('.'), name = fileName.substringBeforeLast('.'),
path = modelFile.absolutePath, path = modelFile.absolutePath,
sizeInBytes = modelFile.length(), sizeInBytes = modelFile.length(),
parameters = parameters, parameters = parameters,
quantization = quantization, quantization = quantization,
type = modelType, type = modelType,
contextLength = DEFAULT_CONTEXT_SIZE, contextLength = DEFAULT_CONTEXT_SIZE,
lastUsed = null, lastUsed = null,
dateAdded = System.currentTimeMillis() dateAdded = System.currentTimeMillis()
) ).let {
modelDao.insertModel(modelEntity) modelDao.insertModel(it)
return modelEntity.toModelInfo() it.toModelInfo()
}
} catch (e: Exception) {
// Clean up partially downloaded file if error occurs
if (modelFile.exists()) {
modelFile.delete()
}
throw e
}
}
private suspend fun copyWithChannels(
input: InputStream,
output: OutputStream,
totalSize: Long,
progressTracker: ImportProgressTracker?
) {
val inChannel: ReadableByteChannel = Channels.newChannel(input)
val outChannel: WritableByteChannel = Channels.newChannel(output)
val buffer = ByteBuffer.allocateDirect(NIO_BUFFER_SIZE)
var totalBytesRead = 0L
while (inChannel.read(buffer) != -1) {
buffer.flip()
while (buffer.hasRemaining()) {
outChannel.write(buffer)
}
totalBytesRead += buffer.position()
buffer.clear()
// Report progress
progressTracker?.let {
val progress = totalBytesRead.toFloat() / totalSize
withContext(Dispatchers.Main) {
it.onProgress(progress)
}
}
if (totalBytesRead % (NIO_YIELD_SIZE) == 0L) {
yield()
}
}
inChannel.close()
outChannel.close()
output.close()
input.close()
}
private suspend fun copyWithBuffer(
input: BufferedInputStream,
output: BufferedOutputStream,
totalSize: Long,
progressTracker: ImportProgressTracker?
) {
val buffer = ByteArray(DEFAULT_BUFFER_SIZE)
var bytesRead: Int
var totalBytesRead = 0L
while (input.read(buffer).also { bytesRead = it } != -1) {
output.write(buffer, 0, bytesRead)
totalBytesRead += bytesRead
// Report progress
if (progressTracker != null) {
val progress = totalBytesRead.toFloat() / totalSize
withContext(Dispatchers.Main) {
progressTracker.onProgress(progress)
}
}
// Yield less frequently with larger buffers
if (totalBytesRead % (DEFAULT_YIELD_SIZE) == 0L) { // Every 64MB
yield()
}
}
output.close()
input.close()
} }
override suspend fun deleteModel(modelId: String) { override suspend fun deleteModel(modelId: String) {
@ -132,7 +251,7 @@ class ModelRepositoryImpl @Inject constructor(
val totalSpaceBytes: Long val totalSpaceBytes: Long
get() = StatFs(context.filesDir.path).totalBytes get() = StatFs(context.filesDir.path).totalBytes
private fun getFileNameFromUri(uri: Uri): String = private fun getFileNameFromUri(uri: Uri): String? =
context.contentResolver.query(uri, null, null, null, null)?.use { cursor -> context.contentResolver.query(uri, null, null, null, null)?.use { cursor ->
if (cursor.moveToFirst()) { if (cursor.moveToFirst()) {
cursor.getColumnIndex(OpenableColumns.DISPLAY_NAME).let { nameIndex -> cursor.getColumnIndex(OpenableColumns.DISPLAY_NAME).let { nameIndex ->
@ -141,7 +260,21 @@ class ModelRepositoryImpl @Inject constructor(
} else { } else {
null null
} }
} ?: uri.lastPathSegment ?: "unknown_model.gguf" } ?: uri.lastPathSegment
/**
* Gets the file size from a content URI, or returns 0 if size is unknown.
*/
private fun getFileSizeFromUri(uri: Uri): Long? =
context.contentResolver.query(uri, null, null, null, null)?.use { cursor ->
if (cursor.moveToFirst()) {
cursor.getColumnIndex(OpenableColumns.SIZE).let { sizeIndex ->
if (sizeIndex != -1) cursor.getLong(sizeIndex) else null
}
} else {
null
}
}
/** /**
* Try to extract parameters by looking for patterns like 7B, 13B, etc. * Try to extract parameters by looking for patterns like 7B, 13B, etc.
@ -182,9 +315,18 @@ class ModelRepositoryImpl @Inject constructor(
} }
companion object { companion object {
private val TAG = ModelRepository::class.java.simpleName
private const val INTERNAL_STORAGE_PATH = "models" private const val INTERNAL_STORAGE_PATH = "models"
private const val BYTES_IN_GB = 1024f * 1024f * 1024f private const val BYTES_IN_GB = 1024f * 1024f * 1024f
private const val LARGE_MODEL_THRESHOLD_SIZE = 1024 * 1024 * 1024
private const val NIO_BUFFER_SIZE = 32 * 1024 * 1024
private const val NIO_YIELD_SIZE = 128 * 1024 * 1024
private const val DEFAULT_BUFFER_SIZE = 4 * 1024 * 1024
private const val DEFAULT_YIELD_SIZE = 16 * 1024 * 1024
private const val STORAGE_METRICS_UPDATE_INTERVAL = 5_000L private const val STORAGE_METRICS_UPDATE_INTERVAL = 5_000L
private const val DEFAULT_CONTEXT_SIZE = 8192 private const val DEFAULT_CONTEXT_SIZE = 8192

View File

@ -99,13 +99,13 @@ fun ModelCard(
Row { Row {
Text( Text(
text = model.parameters, text = model.parameters ?: " - ",
style = MaterialTheme.typography.bodyMedium, style = MaterialTheme.typography.bodyMedium,
color = MaterialTheme.colorScheme.onSurfaceVariant color = MaterialTheme.colorScheme.onSurfaceVariant
) )
Text( Text(
text = "${model.quantization}", text = "${model.quantization ?: " - "}",
style = MaterialTheme.typography.bodyMedium, style = MaterialTheme.typography.bodyMedium,
color = MaterialTheme.colorScheme.onSurfaceVariant color = MaterialTheme.colorScheme.onSurfaceVariant
) )
@ -120,7 +120,7 @@ fun ModelCard(
Spacer(modifier = Modifier.height(4.dp)) Spacer(modifier = Modifier.height(4.dp))
Text( Text(
text = "Context Length: ${model.contextLength}", text = "Context Length: ${model.contextLength ?: " - "}",
style = MaterialTheme.typography.bodySmall, style = MaterialTheme.typography.bodySmall,
color = MaterialTheme.colorScheme.onSurfaceVariant color = MaterialTheme.colorScheme.onSurfaceVariant
) )

View File

@ -1,6 +1,11 @@
package com.example.llama.revamp.ui.screens package com.example.llama.revamp.ui.screens
import androidx.activity.compose.BackHandler
import androidx.activity.compose.rememberLauncherForActivityResult
import androidx.activity.result.contract.ActivityResultContracts
import androidx.compose.foundation.background
import androidx.compose.foundation.clickable import androidx.compose.foundation.clickable
import androidx.compose.foundation.layout.Box
import androidx.compose.foundation.layout.Column import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.Row import androidx.compose.foundation.layout.Row
import androidx.compose.foundation.layout.Spacer import androidx.compose.foundation.layout.Spacer
@ -23,7 +28,10 @@ import androidx.compose.material.icons.filled.FilterAlt
import androidx.compose.material.icons.filled.FolderOpen import androidx.compose.material.icons.filled.FolderOpen
import androidx.compose.material.icons.filled.Info import androidx.compose.material.icons.filled.Info
import androidx.compose.material.icons.filled.SelectAll import androidx.compose.material.icons.filled.SelectAll
import androidx.compose.material3.AlertDialog
import androidx.compose.material3.BottomAppBar import androidx.compose.material3.BottomAppBar
import androidx.compose.material3.Button
import androidx.compose.material3.ButtonDefaults
import androidx.compose.material3.Card import androidx.compose.material3.Card
import androidx.compose.material3.CardDefaults import androidx.compose.material3.CardDefaults
import androidx.compose.material3.Checkbox import androidx.compose.material3.Checkbox
@ -32,9 +40,11 @@ import androidx.compose.material3.DropdownMenuItem
import androidx.compose.material3.FloatingActionButton import androidx.compose.material3.FloatingActionButton
import androidx.compose.material3.Icon import androidx.compose.material3.Icon
import androidx.compose.material3.IconButton import androidx.compose.material3.IconButton
import androidx.compose.material3.LinearProgressIndicator
import androidx.compose.material3.MaterialTheme import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.Text import androidx.compose.material3.Text
import androidx.compose.runtime.Composable import androidx.compose.runtime.Composable
import androidx.compose.runtime.LaunchedEffect
import androidx.compose.runtime.collectAsState import androidx.compose.runtime.collectAsState
import androidx.compose.runtime.getValue import androidx.compose.runtime.getValue
import androidx.compose.runtime.mutableStateMapOf import androidx.compose.runtime.mutableStateMapOf
@ -45,6 +55,7 @@ import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier import androidx.compose.ui.Modifier
import androidx.compose.ui.graphics.Color import androidx.compose.ui.graphics.Color
import androidx.compose.ui.res.painterResource import androidx.compose.ui.res.painterResource
import androidx.compose.ui.text.style.TextOverflow
import androidx.compose.ui.unit.dp import androidx.compose.ui.unit.dp
import androidx.hilt.navigation.compose.hiltViewModel import androidx.hilt.navigation.compose.hiltViewModel
import com.example.llama.revamp.data.model.ModelInfo import com.example.llama.revamp.data.model.ModelInfo
@ -55,6 +66,7 @@ import java.text.SimpleDateFormat
import java.util.Date import java.util.Date
import java.util.Locale import java.util.Locale
import com.example.llama.R import com.example.llama.R
import com.example.llama.revamp.viewmodel.ModelImportState
/** /**
* Screen for managing LLM models (view, download, delete) * Screen for managing LLM models (view, download, delete)
@ -66,13 +78,24 @@ fun ModelsManagementScreen(
) { ) {
val storageMetrics by viewModel.storageMetrics.collectAsState() val storageMetrics by viewModel.storageMetrics.collectAsState()
val sortedModels by viewModel.sortedModels.collectAsState() val sortedModels by viewModel.sortedModels.collectAsState()
// Model sorting
val sortOrder by viewModel.sortOrder.collectAsState() val sortOrder by viewModel.sortOrder.collectAsState()
// UI: menu states
var showSortMenu by remember { mutableStateOf(false) } var showSortMenu by remember { mutableStateOf(false) }
var showAddModelMenu by remember { mutableStateOf(false) }
// UI: multi-selection states // Model importing
val importState by viewModel.importState.collectAsState()
var showImportModelMenu by remember { mutableStateOf(false) }
val fileLauncher = rememberLauncherForActivityResult(
contract = ActivityResultContracts.OpenDocument()
) { uri -> uri?.let { viewModel.importLocalModel(it) } }
BackHandler(enabled = importState is ModelImportState.Importing) {
/* Ignore back press while importing model */
}
// Multi-selection
var isMultiSelectionMode by remember { mutableStateOf(false) } var isMultiSelectionMode by remember { mutableStateOf(false) }
val selectedModels = remember { mutableStateMapOf<String, ModelInfo>() } val selectedModels = remember { mutableStateMapOf<String, ModelInfo>() }
val exitSelectionMode = { val exitSelectionMode = {
@ -238,7 +261,7 @@ fun ModelsManagementScreen(
if (isMultiSelectionMode) { if (isMultiSelectionMode) {
exitSelectionMode() exitSelectionMode()
} else { } else {
showAddModelMenu = true showImportModelMenu = true
} }
}, },
containerColor = MaterialTheme.colorScheme.primaryContainer containerColor = MaterialTheme.colorScheme.primaryContainer
@ -251,8 +274,8 @@ fun ModelsManagementScreen(
// Add model dropdown menu // Add model dropdown menu
DropdownMenu( DropdownMenu(
expanded = showAddModelMenu, expanded = showImportModelMenu,
onDismissRequest = { showAddModelMenu = false } onDismissRequest = { showImportModelMenu = false }
) { ) {
DropdownMenuItem( DropdownMenuItem(
text = { Text("Import local model") }, text = { Text("Import local model") },
@ -263,9 +286,8 @@ fun ModelsManagementScreen(
) )
}, },
onClick = { onClick = {
// TODO-han.yin: uncomment once file picker done fileLauncher.launch(arrayOf("application/octet-stream", "*/*"))
// viewModel.importLocalModel() showImportModelMenu = false
showAddModelMenu = false
} }
) )
DropdownMenuItem( DropdownMenuItem(
@ -280,7 +302,7 @@ fun ModelsManagementScreen(
}, },
onClick = { onClick = {
viewModel.importFromHuggingFace() viewModel.importFromHuggingFace()
showAddModelMenu = false showImportModelMenu = false
} }
) )
} }
@ -288,38 +310,64 @@ fun ModelsManagementScreen(
) )
}, },
) { paddingValues -> ) { paddingValues ->
// Main content Box(modifier = Modifier.fillMaxSize()) {
ModelList( // Model cards
models = sortedModels, ModelCardList(
isMultiSelectionMode = isMultiSelectionMode, models = sortedModels,
selectedModels = selectedModels, isMultiSelectionMode = isMultiSelectionMode,
onModelClick = { modelId -> selectedModels = selectedModels,
if (isMultiSelectionMode) { onModelClick = { modelId ->
// Toggle selection if (isMultiSelectionMode) {
if (selectedModels.contains(modelId)) { // Toggle selection
selectedModels.remove(modelId) if (selectedModels.contains(modelId)) {
selectedModels.remove(modelId)
} else {
selectedModels.put(modelId, sortedModels.first { it.id == modelId } )
}
} else { } else {
selectedModels.put(modelId, sortedModels.first { it.id == modelId } ) // View model details
viewModel.viewModelDetails(modelId)
} }
} else { },
// View model details onModelInfoClick = { modelId ->
viewModel.viewModelDetails(modelId) viewModel.viewModelDetails(modelId)
},
onModelDeleteClick = { modelId ->
viewModel.deleteModel(modelId)
},
modifier = Modifier.padding(paddingValues)
)
// Model import progress overlay
when (val state = importState) {
is ModelImportState.Importing -> {
ImportProgressOverlay(
progress = state.progress,
filename = state.filename,
onCancel = { /* Implement cancellation if needed */ }
)
} }
}, is ModelImportState.Error -> {
onModelInfoClick = { modelId -> ErrorDialog(
viewModel.viewModelDetails(modelId) message = state.message,
}, onDismiss = { viewModel.resetImportState() }
onModelDeleteClick = { modelId -> )
viewModel.deleteModel(modelId) }
}, is ModelImportState.Success -> {
modifier = Modifier.padding(paddingValues) LaunchedEffect(state) {
) // Show success snackbar or message
// This will auto-dismiss after the delay in viewModel
}
}
else -> { /* Idle state, nothing to show */ }
}
}
} }
} }
@Composable @Composable
private fun ModelList( private fun ModelCardList(
models: List<ModelInfo>, models: List<ModelInfo>,
isMultiSelectionMode: Boolean, isMultiSelectionMode: Boolean,
selectedModels: Map<String, ModelInfo>, selectedModels: Map<String, ModelInfo>,
@ -337,13 +385,16 @@ private fun ModelList(
items = models, items = models,
key = { it.id } key = { it.id }
) { model -> ) { model ->
ModelItem( ModelCard(
model = model, model = model,
isMultiSelectionMode = isMultiSelectionMode, isMultiSelectionMode = isMultiSelectionMode,
isSelected = selectedModels.contains(model.id), isSelected = selectedModels.contains(model.id),
onClick = { onModelClick(model.id) }, onClick = { onModelClick(model.id) },
onInfoClick = { onModelInfoClick(model.id) }, onInfoClick = { onModelInfoClick(model.id) },
onDeleteClick = { onModelDeleteClick(model.id) } onDeleteClick = {
// TODO-han.yin: pop up an AlertDialog asking user for confirmation
onModelDeleteClick(model.id)
}
) )
Spacer(modifier = Modifier.height(8.dp)) Spacer(modifier = Modifier.height(8.dp))
} }
@ -351,7 +402,7 @@ private fun ModelList(
} }
@Composable @Composable
private fun ModelItem( private fun ModelCard(
model: ModelInfo, model: ModelInfo,
isMultiSelectionMode: Boolean, isMultiSelectionMode: Boolean,
isSelected: Boolean, isSelected: Boolean,
@ -424,3 +475,95 @@ private fun ModelItem(
} }
} }
} }
@Composable
fun ImportProgressOverlay(
progress: Float,
filename: String,
onCancel: () -> Unit
) {
Box(
modifier = Modifier
.fillMaxSize()
.background(Color.Black.copy(alpha = 0.7f))
.padding(32.dp),
contentAlignment = Alignment.Center
) {
Card(
modifier = Modifier
.fillMaxWidth()
.padding(16.dp),
elevation = CardDefaults.cardElevation(defaultElevation = 8.dp)
) {
Column(
modifier = Modifier.padding(24.dp),
horizontalAlignment = Alignment.CenterHorizontally
) {
Text(
text = "Importing Model",
style = MaterialTheme.typography.headlineSmall
)
Spacer(modifier = Modifier.height(8.dp))
Text(
text = filename,
style = MaterialTheme.typography.bodyMedium,
maxLines = 1,
overflow = TextOverflow.Ellipsis
)
Spacer(modifier = Modifier.height(24.dp))
LinearProgressIndicator(
progress = { progress },
modifier = Modifier.fillMaxWidth()
)
Spacer(modifier = Modifier.height(8.dp))
Text(
text = "${(progress * 100).toInt()}%",
style = MaterialTheme.typography.bodyLarge
)
Spacer(modifier = Modifier.height(16.dp))
Text(
text = "This may take several minutes for large models",
style = MaterialTheme.typography.bodySmall,
color = MaterialTheme.colorScheme.onSurfaceVariant
)
Spacer(modifier = Modifier.height(24.dp))
Button(
onClick = onCancel,
colors = ButtonDefaults.buttonColors(
containerColor = MaterialTheme.colorScheme.errorContainer,
contentColor = MaterialTheme.colorScheme.onErrorContainer
)
) {
Text("Cancel")
}
}
}
}
}
@Composable
fun ErrorDialog(
message: String,
onDismiss: () -> Unit
) {
AlertDialog(
onDismissRequest = onDismiss,
title = { Text("Import Failed") },
text = { Text(message) },
confirmButton = {
Button(onClick = onDismiss) {
Text("OK")
}
}
)
}

View File

@ -1,13 +1,13 @@
package com.example.llama.revamp.viewmodel package com.example.llama.revamp.viewmodel
import android.net.Uri import android.net.Uri
import android.util.Log
import androidx.lifecycle.ViewModel import androidx.lifecycle.ViewModel
import androidx.lifecycle.viewModelScope import androidx.lifecycle.viewModelScope
import com.example.llama.revamp.data.model.ModelInfo import com.example.llama.revamp.data.model.ModelInfo
import com.example.llama.revamp.data.repository.ModelRepository import com.example.llama.revamp.data.repository.ModelRepository
import com.example.llama.revamp.data.repository.StorageMetrics import com.example.llama.revamp.data.repository.StorageMetrics
import dagger.hilt.android.lifecycle.HiltViewModel import dagger.hilt.android.lifecycle.HiltViewModel
import kotlinx.coroutines.delay
import kotlinx.coroutines.flow.MutableStateFlow import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.flow.SharingStarted import kotlinx.coroutines.flow.SharingStarted
import kotlinx.coroutines.flow.StateFlow import kotlinx.coroutines.flow.StateFlow
@ -76,16 +76,34 @@ class ModelsManagementViewModel @Inject constructor(
modelRepository.deleteModels(models.keys) modelRepository.deleteModels(models.keys)
} }
private val _importState = MutableStateFlow<ModelImportState>(ModelImportState.Idle)
val importState: StateFlow<ModelImportState> = _importState.asStateFlow()
fun importLocalModel(uri: Uri) = fun importLocalModel(uri: Uri) =
viewModelScope.launch { viewModelScope.launch {
try { try {
modelRepository.importModel(uri) // Get filename for progress updates
val filename = uri.lastPathSegment ?: throw Exception("Model name unknown")
_importState.value = ModelImportState.Importing(0f, filename)
// Import with progress reporting
val model = modelRepository.importModel(uri) { progress ->
_importState.value = ModelImportState.Importing(progress, filename)
}
_importState.value = ModelImportState.Success(model)
// Reset state after a delay
delay(1000)
_importState.value = ModelImportState.Idle
} catch (e: Exception) { } catch (e: Exception) {
// TODO-han.yin: add UI to prompt user about import failure! _importState.value = ModelImportState.Error(e.message ?: "Unknown error")
Log.e(TAG, "Failed to import model from: $uri", e)
} }
} }
fun resetImportState() {
_importState.value = ModelImportState.Idle
}
fun importFromHuggingFace() { fun importFromHuggingFace() {
// TODO-han.yin: Stub for now. Would need to investigate HuggingFace APIs // TODO-han.yin: Stub for now. Would need to investigate HuggingFace APIs
} }
@ -105,4 +123,9 @@ enum class ModelSortOrder {
LAST_USED LAST_USED
} }
sealed class ModelImportState {
object Idle : ModelImportState()
data class Importing(val progress: Float = 0f, val filename: String = "") : ModelImportState()
data class Success(val model: ModelInfo) : ModelImportState()
data class Error(val message: String) : ModelImportState()
}