data: import local model with file picker
This commit is contained in:
parent
a3ebdac58f
commit
5de0b5d6d0
|
|
@ -11,10 +11,10 @@ data class ModelEntity(
|
|||
val name: String,
|
||||
val path: String,
|
||||
val sizeInBytes: Long,
|
||||
val parameters: String,
|
||||
val quantization: String,
|
||||
val type: String,
|
||||
val contextLength: Int,
|
||||
val parameters: String?,
|
||||
val quantization: String?,
|
||||
val type: String?,
|
||||
val contextLength: Int?,
|
||||
val lastUsed: Long?,
|
||||
val dateAdded: Long
|
||||
) {
|
||||
|
|
|
|||
|
|
@ -1,5 +1,7 @@
|
|||
package com.example.llama.revamp.data.model
|
||||
|
||||
import java.util.Locale
|
||||
|
||||
/**
|
||||
* Data class containing information about an LLM model.
|
||||
*/
|
||||
|
|
@ -8,10 +10,10 @@ data class ModelInfo(
|
|||
val name: String,
|
||||
val path: String,
|
||||
val sizeInBytes: Long,
|
||||
val parameters: String,
|
||||
val quantization: String,
|
||||
val type: String,
|
||||
val contextLength: Int,
|
||||
val parameters: String?,
|
||||
val quantization: String?,
|
||||
val type: String?,
|
||||
val contextLength: Int?,
|
||||
val lastUsed: Long? = null
|
||||
) {
|
||||
val formattedSize: String
|
||||
|
|
@ -19,15 +21,15 @@ data class ModelInfo(
|
|||
return when {
|
||||
sizeInBytes >= 1_000_000_000 -> {
|
||||
val sizeInGb = sizeInBytes / 1_000_000_000.0
|
||||
String.format("%.2f GB", sizeInGb)
|
||||
String.format(Locale.getDefault(), "%.2f GB", sizeInGb)
|
||||
}
|
||||
sizeInBytes >= 1_000_000 -> {
|
||||
val sizeInMb = sizeInBytes / 1_000_000.0
|
||||
String.format("%.2f MB", sizeInMb)
|
||||
String.format(Locale.getDefault(), "%.2f MB", sizeInMb)
|
||||
}
|
||||
else -> {
|
||||
val sizeInKb = sizeInBytes / 1_000.0
|
||||
String.format("%.2f KB", sizeInKb)
|
||||
String.format(Locale.getDefault(), "%.2f KB", sizeInKb)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -4,17 +4,31 @@ import android.content.Context
|
|||
import android.net.Uri
|
||||
import android.os.StatFs
|
||||
import android.provider.OpenableColumns
|
||||
import android.util.Log
|
||||
import com.example.llama.revamp.data.local.ModelDao
|
||||
import com.example.llama.revamp.data.local.ModelEntity
|
||||
import com.example.llama.revamp.data.model.ModelInfo
|
||||
import com.example.llama.revamp.data.repository.ModelRepository.ImportProgressTracker
|
||||
import dagger.hilt.android.qualifiers.ApplicationContext
|
||||
import kotlinx.coroutines.Dispatchers
|
||||
import kotlinx.coroutines.delay
|
||||
import kotlinx.coroutines.flow.Flow
|
||||
import kotlinx.coroutines.flow.flow
|
||||
import kotlinx.coroutines.flow.map
|
||||
import kotlinx.coroutines.withContext
|
||||
import kotlinx.coroutines.yield
|
||||
import java.io.BufferedInputStream
|
||||
import java.io.BufferedOutputStream
|
||||
import java.io.File
|
||||
import java.io.FileNotFoundException
|
||||
import java.io.FileOutputStream
|
||||
import java.io.IOException
|
||||
import java.io.InputStream
|
||||
import java.io.OutputStream
|
||||
import java.nio.ByteBuffer
|
||||
import java.nio.channels.Channels
|
||||
import java.nio.channels.ReadableByteChannel
|
||||
import java.nio.channels.WritableByteChannel
|
||||
import java.util.Locale
|
||||
import java.util.UUID
|
||||
import javax.inject.Inject
|
||||
|
|
@ -27,9 +41,14 @@ interface ModelRepository {
|
|||
fun getStorageMetrics(): Flow<StorageMetrics>
|
||||
fun getModels(): Flow<List<ModelInfo>>
|
||||
|
||||
suspend fun importModel(uri: Uri): ModelInfo
|
||||
suspend fun importModel(uri: Uri, progressTracker: ImportProgressTracker? = null): ModelInfo
|
||||
|
||||
suspend fun deleteModel(modelId: String)
|
||||
suspend fun deleteModels(modelIds: Collection<String>)
|
||||
|
||||
fun interface ImportProgressTracker {
|
||||
fun onProgress(progress: Float) // 0.0f to 1.0f
|
||||
}
|
||||
}
|
||||
|
||||
@Singleton
|
||||
|
|
@ -67,38 +86,138 @@ class ModelRepositoryImpl @Inject constructor(
|
|||
}
|
||||
}
|
||||
|
||||
override suspend fun importModel(uri: Uri): ModelInfo {
|
||||
// Obtain the local model's file via provided URI
|
||||
val fileName = getFileNameFromUri(uri)
|
||||
override suspend fun importModel(
|
||||
uri: Uri,
|
||||
progressTracker: ImportProgressTracker?
|
||||
): ModelInfo = withContext(Dispatchers.IO) {
|
||||
val fileName = getFileNameFromUri(uri) ?: throw FileNotFoundException("Filename N/A")
|
||||
val fileSize = getFileSizeFromUri(uri) ?: throw FileNotFoundException("File size N/A")
|
||||
val modelFile = File(modelsDir, fileName)
|
||||
|
||||
// Copy file to app's internal storage
|
||||
context.contentResolver.openInputStream(uri)?.use { inputStream ->
|
||||
FileOutputStream(modelFile).use { outputStream ->
|
||||
inputStream.copyTo(outputStream)
|
||||
try {
|
||||
val inputStream = context.contentResolver.openInputStream(uri)
|
||||
?: throw IOException("Failed to open input stream")
|
||||
val outputStream = FileOutputStream(modelFile)
|
||||
|
||||
if (fileSize > LARGE_MODEL_THRESHOLD_SIZE) {
|
||||
Log.i(TAG, "Copying $fileName (size: $fileSize) via NIO...")
|
||||
|
||||
// Use NIO channels for large models
|
||||
copyWithChannels(inputStream, outputStream, fileSize, progressTracker)
|
||||
} else {
|
||||
Log.i(TAG, "Copying $fileName (size: $fileSize) via buffer...")
|
||||
|
||||
// Default copy with buffer for small models
|
||||
val bufferedInput = BufferedInputStream(inputStream, DEFAULT_BUFFER_SIZE)
|
||||
val bufferedOutput = BufferedOutputStream(outputStream, DEFAULT_BUFFER_SIZE)
|
||||
copyWithBuffer(bufferedInput, bufferedOutput, fileSize, progressTracker)
|
||||
|
||||
// Close streams
|
||||
bufferedOutput.flush()
|
||||
bufferedOutput.close()
|
||||
bufferedInput.close()
|
||||
}
|
||||
} ?: throw IOException("Failed to open input stream")
|
||||
|
||||
// Extract model parameters from filename
|
||||
val modelType = extractModelTypeFromFilename(fileName) ?: "unknown"
|
||||
val parameters = extractParametersFromFilename(fileName) ?: "unknown"
|
||||
val quantization = extractQuantizationFromFilename(fileName) ?: "unknown"
|
||||
// Extract model parameters from filename
|
||||
val modelType = extractModelTypeFromFilename(fileName)
|
||||
val parameters = extractParametersFromFilename(fileName)
|
||||
val quantization = extractQuantizationFromFilename(fileName)
|
||||
|
||||
// Create model entity and save via DAO
|
||||
val modelEntity = ModelEntity(
|
||||
id = UUID.randomUUID().toString(),
|
||||
name = fileName.substringBeforeLast('.'),
|
||||
path = modelFile.absolutePath,
|
||||
sizeInBytes = modelFile.length(),
|
||||
parameters = parameters,
|
||||
quantization = quantization,
|
||||
type = modelType,
|
||||
contextLength = DEFAULT_CONTEXT_SIZE,
|
||||
lastUsed = null,
|
||||
dateAdded = System.currentTimeMillis()
|
||||
)
|
||||
modelDao.insertModel(modelEntity)
|
||||
return modelEntity.toModelInfo()
|
||||
// Create model entity and save via DAO
|
||||
ModelEntity(
|
||||
id = UUID.randomUUID().toString(),
|
||||
name = fileName.substringBeforeLast('.'),
|
||||
path = modelFile.absolutePath,
|
||||
sizeInBytes = modelFile.length(),
|
||||
parameters = parameters,
|
||||
quantization = quantization,
|
||||
type = modelType,
|
||||
contextLength = DEFAULT_CONTEXT_SIZE,
|
||||
lastUsed = null,
|
||||
dateAdded = System.currentTimeMillis()
|
||||
).let {
|
||||
modelDao.insertModel(it)
|
||||
it.toModelInfo()
|
||||
}
|
||||
} catch (e: Exception) {
|
||||
// Clean up partially downloaded file if error occurs
|
||||
if (modelFile.exists()) {
|
||||
modelFile.delete()
|
||||
}
|
||||
throw e
|
||||
}
|
||||
}
|
||||
|
||||
private suspend fun copyWithChannels(
|
||||
input: InputStream,
|
||||
output: OutputStream,
|
||||
totalSize: Long,
|
||||
progressTracker: ImportProgressTracker?
|
||||
) {
|
||||
val inChannel: ReadableByteChannel = Channels.newChannel(input)
|
||||
val outChannel: WritableByteChannel = Channels.newChannel(output)
|
||||
|
||||
val buffer = ByteBuffer.allocateDirect(NIO_BUFFER_SIZE)
|
||||
var totalBytesRead = 0L
|
||||
|
||||
while (inChannel.read(buffer) != -1) {
|
||||
buffer.flip()
|
||||
while (buffer.hasRemaining()) {
|
||||
outChannel.write(buffer)
|
||||
}
|
||||
totalBytesRead += buffer.position()
|
||||
buffer.clear()
|
||||
|
||||
// Report progress
|
||||
progressTracker?.let {
|
||||
val progress = totalBytesRead.toFloat() / totalSize
|
||||
withContext(Dispatchers.Main) {
|
||||
it.onProgress(progress)
|
||||
}
|
||||
}
|
||||
|
||||
if (totalBytesRead % (NIO_YIELD_SIZE) == 0L) {
|
||||
yield()
|
||||
}
|
||||
}
|
||||
|
||||
inChannel.close()
|
||||
outChannel.close()
|
||||
output.close()
|
||||
input.close()
|
||||
}
|
||||
|
||||
private suspend fun copyWithBuffer(
|
||||
input: BufferedInputStream,
|
||||
output: BufferedOutputStream,
|
||||
totalSize: Long,
|
||||
progressTracker: ImportProgressTracker?
|
||||
) {
|
||||
val buffer = ByteArray(DEFAULT_BUFFER_SIZE)
|
||||
|
||||
var bytesRead: Int
|
||||
var totalBytesRead = 0L
|
||||
|
||||
while (input.read(buffer).also { bytesRead = it } != -1) {
|
||||
output.write(buffer, 0, bytesRead)
|
||||
totalBytesRead += bytesRead
|
||||
|
||||
// Report progress
|
||||
if (progressTracker != null) {
|
||||
val progress = totalBytesRead.toFloat() / totalSize
|
||||
withContext(Dispatchers.Main) {
|
||||
progressTracker.onProgress(progress)
|
||||
}
|
||||
}
|
||||
|
||||
// Yield less frequently with larger buffers
|
||||
if (totalBytesRead % (DEFAULT_YIELD_SIZE) == 0L) { // Every 64MB
|
||||
yield()
|
||||
}
|
||||
}
|
||||
|
||||
output.close()
|
||||
input.close()
|
||||
}
|
||||
|
||||
override suspend fun deleteModel(modelId: String) {
|
||||
|
|
@ -132,7 +251,7 @@ class ModelRepositoryImpl @Inject constructor(
|
|||
val totalSpaceBytes: Long
|
||||
get() = StatFs(context.filesDir.path).totalBytes
|
||||
|
||||
private fun getFileNameFromUri(uri: Uri): String =
|
||||
private fun getFileNameFromUri(uri: Uri): String? =
|
||||
context.contentResolver.query(uri, null, null, null, null)?.use { cursor ->
|
||||
if (cursor.moveToFirst()) {
|
||||
cursor.getColumnIndex(OpenableColumns.DISPLAY_NAME).let { nameIndex ->
|
||||
|
|
@ -141,7 +260,21 @@ class ModelRepositoryImpl @Inject constructor(
|
|||
} else {
|
||||
null
|
||||
}
|
||||
} ?: uri.lastPathSegment ?: "unknown_model.gguf"
|
||||
} ?: uri.lastPathSegment
|
||||
|
||||
/**
|
||||
* Gets the file size from a content URI, or returns 0 if size is unknown.
|
||||
*/
|
||||
private fun getFileSizeFromUri(uri: Uri): Long? =
|
||||
context.contentResolver.query(uri, null, null, null, null)?.use { cursor ->
|
||||
if (cursor.moveToFirst()) {
|
||||
cursor.getColumnIndex(OpenableColumns.SIZE).let { sizeIndex ->
|
||||
if (sizeIndex != -1) cursor.getLong(sizeIndex) else null
|
||||
}
|
||||
} else {
|
||||
null
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Try to extract parameters by looking for patterns like 7B, 13B, etc.
|
||||
|
|
@ -182,9 +315,18 @@ class ModelRepositoryImpl @Inject constructor(
|
|||
}
|
||||
|
||||
companion object {
|
||||
private val TAG = ModelRepository::class.java.simpleName
|
||||
|
||||
private const val INTERNAL_STORAGE_PATH = "models"
|
||||
|
||||
private const val BYTES_IN_GB = 1024f * 1024f * 1024f
|
||||
|
||||
private const val LARGE_MODEL_THRESHOLD_SIZE = 1024 * 1024 * 1024
|
||||
private const val NIO_BUFFER_SIZE = 32 * 1024 * 1024
|
||||
private const val NIO_YIELD_SIZE = 128 * 1024 * 1024
|
||||
private const val DEFAULT_BUFFER_SIZE = 4 * 1024 * 1024
|
||||
private const val DEFAULT_YIELD_SIZE = 16 * 1024 * 1024
|
||||
|
||||
private const val STORAGE_METRICS_UPDATE_INTERVAL = 5_000L
|
||||
private const val DEFAULT_CONTEXT_SIZE = 8192
|
||||
|
||||
|
|
|
|||
|
|
@ -99,13 +99,13 @@ fun ModelCard(
|
|||
|
||||
Row {
|
||||
Text(
|
||||
text = model.parameters,
|
||||
text = model.parameters ?: " - ",
|
||||
style = MaterialTheme.typography.bodyMedium,
|
||||
color = MaterialTheme.colorScheme.onSurfaceVariant
|
||||
)
|
||||
|
||||
Text(
|
||||
text = " • ${model.quantization}",
|
||||
text = " • ${model.quantization ?: " - "}",
|
||||
style = MaterialTheme.typography.bodyMedium,
|
||||
color = MaterialTheme.colorScheme.onSurfaceVariant
|
||||
)
|
||||
|
|
@ -120,7 +120,7 @@ fun ModelCard(
|
|||
Spacer(modifier = Modifier.height(4.dp))
|
||||
|
||||
Text(
|
||||
text = "Context Length: ${model.contextLength}",
|
||||
text = "Context Length: ${model.contextLength ?: " - "}",
|
||||
style = MaterialTheme.typography.bodySmall,
|
||||
color = MaterialTheme.colorScheme.onSurfaceVariant
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,6 +1,11 @@
|
|||
package com.example.llama.revamp.ui.screens
|
||||
|
||||
import androidx.activity.compose.BackHandler
|
||||
import androidx.activity.compose.rememberLauncherForActivityResult
|
||||
import androidx.activity.result.contract.ActivityResultContracts
|
||||
import androidx.compose.foundation.background
|
||||
import androidx.compose.foundation.clickable
|
||||
import androidx.compose.foundation.layout.Box
|
||||
import androidx.compose.foundation.layout.Column
|
||||
import androidx.compose.foundation.layout.Row
|
||||
import androidx.compose.foundation.layout.Spacer
|
||||
|
|
@ -23,7 +28,10 @@ import androidx.compose.material.icons.filled.FilterAlt
|
|||
import androidx.compose.material.icons.filled.FolderOpen
|
||||
import androidx.compose.material.icons.filled.Info
|
||||
import androidx.compose.material.icons.filled.SelectAll
|
||||
import androidx.compose.material3.AlertDialog
|
||||
import androidx.compose.material3.BottomAppBar
|
||||
import androidx.compose.material3.Button
|
||||
import androidx.compose.material3.ButtonDefaults
|
||||
import androidx.compose.material3.Card
|
||||
import androidx.compose.material3.CardDefaults
|
||||
import androidx.compose.material3.Checkbox
|
||||
|
|
@ -32,9 +40,11 @@ import androidx.compose.material3.DropdownMenuItem
|
|||
import androidx.compose.material3.FloatingActionButton
|
||||
import androidx.compose.material3.Icon
|
||||
import androidx.compose.material3.IconButton
|
||||
import androidx.compose.material3.LinearProgressIndicator
|
||||
import androidx.compose.material3.MaterialTheme
|
||||
import androidx.compose.material3.Text
|
||||
import androidx.compose.runtime.Composable
|
||||
import androidx.compose.runtime.LaunchedEffect
|
||||
import androidx.compose.runtime.collectAsState
|
||||
import androidx.compose.runtime.getValue
|
||||
import androidx.compose.runtime.mutableStateMapOf
|
||||
|
|
@ -45,6 +55,7 @@ import androidx.compose.ui.Alignment
|
|||
import androidx.compose.ui.Modifier
|
||||
import androidx.compose.ui.graphics.Color
|
||||
import androidx.compose.ui.res.painterResource
|
||||
import androidx.compose.ui.text.style.TextOverflow
|
||||
import androidx.compose.ui.unit.dp
|
||||
import androidx.hilt.navigation.compose.hiltViewModel
|
||||
import com.example.llama.revamp.data.model.ModelInfo
|
||||
|
|
@ -55,6 +66,7 @@ import java.text.SimpleDateFormat
|
|||
import java.util.Date
|
||||
import java.util.Locale
|
||||
import com.example.llama.R
|
||||
import com.example.llama.revamp.viewmodel.ModelImportState
|
||||
|
||||
/**
|
||||
* Screen for managing LLM models (view, download, delete)
|
||||
|
|
@ -66,13 +78,24 @@ fun ModelsManagementScreen(
|
|||
) {
|
||||
val storageMetrics by viewModel.storageMetrics.collectAsState()
|
||||
val sortedModels by viewModel.sortedModels.collectAsState()
|
||||
|
||||
// Model sorting
|
||||
val sortOrder by viewModel.sortOrder.collectAsState()
|
||||
|
||||
// UI: menu states
|
||||
var showSortMenu by remember { mutableStateOf(false) }
|
||||
var showAddModelMenu by remember { mutableStateOf(false) }
|
||||
|
||||
// UI: multi-selection states
|
||||
// Model importing
|
||||
val importState by viewModel.importState.collectAsState()
|
||||
|
||||
var showImportModelMenu by remember { mutableStateOf(false) }
|
||||
val fileLauncher = rememberLauncherForActivityResult(
|
||||
contract = ActivityResultContracts.OpenDocument()
|
||||
) { uri -> uri?.let { viewModel.importLocalModel(it) } }
|
||||
|
||||
BackHandler(enabled = importState is ModelImportState.Importing) {
|
||||
/* Ignore back press while importing model */
|
||||
}
|
||||
|
||||
// Multi-selection
|
||||
var isMultiSelectionMode by remember { mutableStateOf(false) }
|
||||
val selectedModels = remember { mutableStateMapOf<String, ModelInfo>() }
|
||||
val exitSelectionMode = {
|
||||
|
|
@ -238,7 +261,7 @@ fun ModelsManagementScreen(
|
|||
if (isMultiSelectionMode) {
|
||||
exitSelectionMode()
|
||||
} else {
|
||||
showAddModelMenu = true
|
||||
showImportModelMenu = true
|
||||
}
|
||||
},
|
||||
containerColor = MaterialTheme.colorScheme.primaryContainer
|
||||
|
|
@ -251,8 +274,8 @@ fun ModelsManagementScreen(
|
|||
|
||||
// Add model dropdown menu
|
||||
DropdownMenu(
|
||||
expanded = showAddModelMenu,
|
||||
onDismissRequest = { showAddModelMenu = false }
|
||||
expanded = showImportModelMenu,
|
||||
onDismissRequest = { showImportModelMenu = false }
|
||||
) {
|
||||
DropdownMenuItem(
|
||||
text = { Text("Import local model") },
|
||||
|
|
@ -263,9 +286,8 @@ fun ModelsManagementScreen(
|
|||
)
|
||||
},
|
||||
onClick = {
|
||||
// TODO-han.yin: uncomment once file picker done
|
||||
// viewModel.importLocalModel()
|
||||
showAddModelMenu = false
|
||||
fileLauncher.launch(arrayOf("application/octet-stream", "*/*"))
|
||||
showImportModelMenu = false
|
||||
}
|
||||
)
|
||||
DropdownMenuItem(
|
||||
|
|
@ -280,7 +302,7 @@ fun ModelsManagementScreen(
|
|||
},
|
||||
onClick = {
|
||||
viewModel.importFromHuggingFace()
|
||||
showAddModelMenu = false
|
||||
showImportModelMenu = false
|
||||
}
|
||||
)
|
||||
}
|
||||
|
|
@ -288,38 +310,64 @@ fun ModelsManagementScreen(
|
|||
)
|
||||
},
|
||||
) { paddingValues ->
|
||||
// Main content
|
||||
ModelList(
|
||||
models = sortedModels,
|
||||
isMultiSelectionMode = isMultiSelectionMode,
|
||||
selectedModels = selectedModels,
|
||||
onModelClick = { modelId ->
|
||||
if (isMultiSelectionMode) {
|
||||
// Toggle selection
|
||||
if (selectedModels.contains(modelId)) {
|
||||
selectedModels.remove(modelId)
|
||||
Box(modifier = Modifier.fillMaxSize()) {
|
||||
// Model cards
|
||||
ModelCardList(
|
||||
models = sortedModels,
|
||||
isMultiSelectionMode = isMultiSelectionMode,
|
||||
selectedModels = selectedModels,
|
||||
onModelClick = { modelId ->
|
||||
if (isMultiSelectionMode) {
|
||||
// Toggle selection
|
||||
if (selectedModels.contains(modelId)) {
|
||||
selectedModels.remove(modelId)
|
||||
} else {
|
||||
selectedModels.put(modelId, sortedModels.first { it.id == modelId } )
|
||||
}
|
||||
} else {
|
||||
selectedModels.put(modelId, sortedModels.first { it.id == modelId } )
|
||||
// View model details
|
||||
viewModel.viewModelDetails(modelId)
|
||||
}
|
||||
} else {
|
||||
// View model details
|
||||
},
|
||||
onModelInfoClick = { modelId ->
|
||||
viewModel.viewModelDetails(modelId)
|
||||
},
|
||||
onModelDeleteClick = { modelId ->
|
||||
viewModel.deleteModel(modelId)
|
||||
},
|
||||
modifier = Modifier.padding(paddingValues)
|
||||
)
|
||||
|
||||
// Model import progress overlay
|
||||
when (val state = importState) {
|
||||
is ModelImportState.Importing -> {
|
||||
ImportProgressOverlay(
|
||||
progress = state.progress,
|
||||
filename = state.filename,
|
||||
onCancel = { /* Implement cancellation if needed */ }
|
||||
)
|
||||
}
|
||||
},
|
||||
onModelInfoClick = { modelId ->
|
||||
viewModel.viewModelDetails(modelId)
|
||||
},
|
||||
onModelDeleteClick = { modelId ->
|
||||
viewModel.deleteModel(modelId)
|
||||
},
|
||||
modifier = Modifier.padding(paddingValues)
|
||||
)
|
||||
is ModelImportState.Error -> {
|
||||
ErrorDialog(
|
||||
message = state.message,
|
||||
onDismiss = { viewModel.resetImportState() }
|
||||
)
|
||||
}
|
||||
is ModelImportState.Success -> {
|
||||
LaunchedEffect(state) {
|
||||
// Show success snackbar or message
|
||||
// This will auto-dismiss after the delay in viewModel
|
||||
}
|
||||
}
|
||||
else -> { /* Idle state, nothing to show */ }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@Composable
|
||||
private fun ModelList(
|
||||
private fun ModelCardList(
|
||||
models: List<ModelInfo>,
|
||||
isMultiSelectionMode: Boolean,
|
||||
selectedModels: Map<String, ModelInfo>,
|
||||
|
|
@ -337,13 +385,16 @@ private fun ModelList(
|
|||
items = models,
|
||||
key = { it.id }
|
||||
) { model ->
|
||||
ModelItem(
|
||||
ModelCard(
|
||||
model = model,
|
||||
isMultiSelectionMode = isMultiSelectionMode,
|
||||
isSelected = selectedModels.contains(model.id),
|
||||
onClick = { onModelClick(model.id) },
|
||||
onInfoClick = { onModelInfoClick(model.id) },
|
||||
onDeleteClick = { onModelDeleteClick(model.id) }
|
||||
onDeleteClick = {
|
||||
// TODO-han.yin: pop up an AlertDialog asking user for confirmation
|
||||
onModelDeleteClick(model.id)
|
||||
}
|
||||
)
|
||||
Spacer(modifier = Modifier.height(8.dp))
|
||||
}
|
||||
|
|
@ -351,7 +402,7 @@ private fun ModelList(
|
|||
}
|
||||
|
||||
@Composable
|
||||
private fun ModelItem(
|
||||
private fun ModelCard(
|
||||
model: ModelInfo,
|
||||
isMultiSelectionMode: Boolean,
|
||||
isSelected: Boolean,
|
||||
|
|
@ -424,3 +475,95 @@ private fun ModelItem(
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Composable
|
||||
fun ImportProgressOverlay(
|
||||
progress: Float,
|
||||
filename: String,
|
||||
onCancel: () -> Unit
|
||||
) {
|
||||
Box(
|
||||
modifier = Modifier
|
||||
.fillMaxSize()
|
||||
.background(Color.Black.copy(alpha = 0.7f))
|
||||
.padding(32.dp),
|
||||
contentAlignment = Alignment.Center
|
||||
) {
|
||||
Card(
|
||||
modifier = Modifier
|
||||
.fillMaxWidth()
|
||||
.padding(16.dp),
|
||||
elevation = CardDefaults.cardElevation(defaultElevation = 8.dp)
|
||||
) {
|
||||
Column(
|
||||
modifier = Modifier.padding(24.dp),
|
||||
horizontalAlignment = Alignment.CenterHorizontally
|
||||
) {
|
||||
Text(
|
||||
text = "Importing Model",
|
||||
style = MaterialTheme.typography.headlineSmall
|
||||
)
|
||||
|
||||
Spacer(modifier = Modifier.height(8.dp))
|
||||
|
||||
Text(
|
||||
text = filename,
|
||||
style = MaterialTheme.typography.bodyMedium,
|
||||
maxLines = 1,
|
||||
overflow = TextOverflow.Ellipsis
|
||||
)
|
||||
|
||||
Spacer(modifier = Modifier.height(24.dp))
|
||||
|
||||
LinearProgressIndicator(
|
||||
progress = { progress },
|
||||
modifier = Modifier.fillMaxWidth()
|
||||
)
|
||||
|
||||
Spacer(modifier = Modifier.height(8.dp))
|
||||
|
||||
Text(
|
||||
text = "${(progress * 100).toInt()}%",
|
||||
style = MaterialTheme.typography.bodyLarge
|
||||
)
|
||||
|
||||
Spacer(modifier = Modifier.height(16.dp))
|
||||
|
||||
Text(
|
||||
text = "This may take several minutes for large models",
|
||||
style = MaterialTheme.typography.bodySmall,
|
||||
color = MaterialTheme.colorScheme.onSurfaceVariant
|
||||
)
|
||||
|
||||
Spacer(modifier = Modifier.height(24.dp))
|
||||
|
||||
Button(
|
||||
onClick = onCancel,
|
||||
colors = ButtonDefaults.buttonColors(
|
||||
containerColor = MaterialTheme.colorScheme.errorContainer,
|
||||
contentColor = MaterialTheme.colorScheme.onErrorContainer
|
||||
)
|
||||
) {
|
||||
Text("Cancel")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Composable
|
||||
fun ErrorDialog(
|
||||
message: String,
|
||||
onDismiss: () -> Unit
|
||||
) {
|
||||
AlertDialog(
|
||||
onDismissRequest = onDismiss,
|
||||
title = { Text("Import Failed") },
|
||||
text = { Text(message) },
|
||||
confirmButton = {
|
||||
Button(onClick = onDismiss) {
|
||||
Text("OK")
|
||||
}
|
||||
}
|
||||
)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,13 +1,13 @@
|
|||
package com.example.llama.revamp.viewmodel
|
||||
|
||||
import android.net.Uri
|
||||
import android.util.Log
|
||||
import androidx.lifecycle.ViewModel
|
||||
import androidx.lifecycle.viewModelScope
|
||||
import com.example.llama.revamp.data.model.ModelInfo
|
||||
import com.example.llama.revamp.data.repository.ModelRepository
|
||||
import com.example.llama.revamp.data.repository.StorageMetrics
|
||||
import dagger.hilt.android.lifecycle.HiltViewModel
|
||||
import kotlinx.coroutines.delay
|
||||
import kotlinx.coroutines.flow.MutableStateFlow
|
||||
import kotlinx.coroutines.flow.SharingStarted
|
||||
import kotlinx.coroutines.flow.StateFlow
|
||||
|
|
@ -76,16 +76,34 @@ class ModelsManagementViewModel @Inject constructor(
|
|||
modelRepository.deleteModels(models.keys)
|
||||
}
|
||||
|
||||
private val _importState = MutableStateFlow<ModelImportState>(ModelImportState.Idle)
|
||||
val importState: StateFlow<ModelImportState> = _importState.asStateFlow()
|
||||
|
||||
fun importLocalModel(uri: Uri) =
|
||||
viewModelScope.launch {
|
||||
try {
|
||||
modelRepository.importModel(uri)
|
||||
// Get filename for progress updates
|
||||
val filename = uri.lastPathSegment ?: throw Exception("Model name unknown")
|
||||
_importState.value = ModelImportState.Importing(0f, filename)
|
||||
|
||||
// Import with progress reporting
|
||||
val model = modelRepository.importModel(uri) { progress ->
|
||||
_importState.value = ModelImportState.Importing(progress, filename)
|
||||
}
|
||||
_importState.value = ModelImportState.Success(model)
|
||||
|
||||
// Reset state after a delay
|
||||
delay(1000)
|
||||
_importState.value = ModelImportState.Idle
|
||||
} catch (e: Exception) {
|
||||
// TODO-han.yin: add UI to prompt user about import failure!
|
||||
Log.e(TAG, "Failed to import model from: $uri", e)
|
||||
_importState.value = ModelImportState.Error(e.message ?: "Unknown error")
|
||||
}
|
||||
}
|
||||
|
||||
fun resetImportState() {
|
||||
_importState.value = ModelImportState.Idle
|
||||
}
|
||||
|
||||
fun importFromHuggingFace() {
|
||||
// TODO-han.yin: Stub for now. Would need to investigate HuggingFace APIs
|
||||
}
|
||||
|
|
@ -105,4 +123,9 @@ enum class ModelSortOrder {
|
|||
LAST_USED
|
||||
}
|
||||
|
||||
|
||||
sealed class ModelImportState {
|
||||
object Idle : ModelImportState()
|
||||
data class Importing(val progress: Float = 0f, val filename: String = "") : ModelImportState()
|
||||
data class Success(val model: ModelInfo) : ModelImportState()
|
||||
data class Error(val message: String) : ModelImportState()
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue