data: import local model with file picker

This commit is contained in:
Han Yin 2025-04-14 15:44:42 -07:00
parent a3ebdac58f
commit 5de0b5d6d0
6 changed files with 396 additions and 86 deletions

View File

@ -11,10 +11,10 @@ data class ModelEntity(
val name: String,
val path: String,
val sizeInBytes: Long,
val parameters: String,
val quantization: String,
val type: String,
val contextLength: Int,
val parameters: String?,
val quantization: String?,
val type: String?,
val contextLength: Int?,
val lastUsed: Long?,
val dateAdded: Long
) {

View File

@ -1,5 +1,7 @@
package com.example.llama.revamp.data.model
import java.util.Locale
/**
* Data class containing information about an LLM model.
*/
@ -8,10 +10,10 @@ data class ModelInfo(
val name: String,
val path: String,
val sizeInBytes: Long,
val parameters: String,
val quantization: String,
val type: String,
val contextLength: Int,
val parameters: String?,
val quantization: String?,
val type: String?,
val contextLength: Int?,
val lastUsed: Long? = null
) {
val formattedSize: String
@ -19,15 +21,15 @@ data class ModelInfo(
return when {
sizeInBytes >= 1_000_000_000 -> {
val sizeInGb = sizeInBytes / 1_000_000_000.0
String.format("%.2f GB", sizeInGb)
String.format(Locale.getDefault(), "%.2f GB", sizeInGb)
}
sizeInBytes >= 1_000_000 -> {
val sizeInMb = sizeInBytes / 1_000_000.0
String.format("%.2f MB", sizeInMb)
String.format(Locale.getDefault(), "%.2f MB", sizeInMb)
}
else -> {
val sizeInKb = sizeInBytes / 1_000.0
String.format("%.2f KB", sizeInKb)
String.format(Locale.getDefault(), "%.2f KB", sizeInKb)
}
}
}

View File

@ -4,17 +4,31 @@ import android.content.Context
import android.net.Uri
import android.os.StatFs
import android.provider.OpenableColumns
import android.util.Log
import com.example.llama.revamp.data.local.ModelDao
import com.example.llama.revamp.data.local.ModelEntity
import com.example.llama.revamp.data.model.ModelInfo
import com.example.llama.revamp.data.repository.ModelRepository.ImportProgressTracker
import dagger.hilt.android.qualifiers.ApplicationContext
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.delay
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.flow
import kotlinx.coroutines.flow.map
import kotlinx.coroutines.withContext
import kotlinx.coroutines.yield
import java.io.BufferedInputStream
import java.io.BufferedOutputStream
import java.io.File
import java.io.FileNotFoundException
import java.io.FileOutputStream
import java.io.IOException
import java.io.InputStream
import java.io.OutputStream
import java.nio.ByteBuffer
import java.nio.channels.Channels
import java.nio.channels.ReadableByteChannel
import java.nio.channels.WritableByteChannel
import java.util.Locale
import java.util.UUID
import javax.inject.Inject
@ -27,9 +41,14 @@ interface ModelRepository {
fun getStorageMetrics(): Flow<StorageMetrics>
fun getModels(): Flow<List<ModelInfo>>
suspend fun importModel(uri: Uri): ModelInfo
suspend fun importModel(uri: Uri, progressTracker: ImportProgressTracker? = null): ModelInfo
suspend fun deleteModel(modelId: String)
suspend fun deleteModels(modelIds: Collection<String>)
fun interface ImportProgressTracker {
fun onProgress(progress: Float) // 0.0f to 1.0f
}
}
@Singleton
@ -67,38 +86,138 @@ class ModelRepositoryImpl @Inject constructor(
}
}
override suspend fun importModel(uri: Uri): ModelInfo {
// Obtain the local model's file via provided URI
val fileName = getFileNameFromUri(uri)
override suspend fun importModel(
uri: Uri,
progressTracker: ImportProgressTracker?
): ModelInfo = withContext(Dispatchers.IO) {
val fileName = getFileNameFromUri(uri) ?: throw FileNotFoundException("Filename N/A")
val fileSize = getFileSizeFromUri(uri) ?: throw FileNotFoundException("File size N/A")
val modelFile = File(modelsDir, fileName)
// Copy file to app's internal storage
context.contentResolver.openInputStream(uri)?.use { inputStream ->
FileOutputStream(modelFile).use { outputStream ->
inputStream.copyTo(outputStream)
try {
val inputStream = context.contentResolver.openInputStream(uri)
?: throw IOException("Failed to open input stream")
val outputStream = FileOutputStream(modelFile)
if (fileSize > LARGE_MODEL_THRESHOLD_SIZE) {
Log.i(TAG, "Copying $fileName (size: $fileSize) via NIO...")
// Use NIO channels for large models
copyWithChannels(inputStream, outputStream, fileSize, progressTracker)
} else {
Log.i(TAG, "Copying $fileName (size: $fileSize) via buffer...")
// Default copy with buffer for small models
val bufferedInput = BufferedInputStream(inputStream, DEFAULT_BUFFER_SIZE)
val bufferedOutput = BufferedOutputStream(outputStream, DEFAULT_BUFFER_SIZE)
copyWithBuffer(bufferedInput, bufferedOutput, fileSize, progressTracker)
// Close streams
bufferedOutput.flush()
bufferedOutput.close()
bufferedInput.close()
}
} ?: throw IOException("Failed to open input stream")
// Extract model parameters from filename
val modelType = extractModelTypeFromFilename(fileName) ?: "unknown"
val parameters = extractParametersFromFilename(fileName) ?: "unknown"
val quantization = extractQuantizationFromFilename(fileName) ?: "unknown"
// Extract model parameters from filename
val modelType = extractModelTypeFromFilename(fileName)
val parameters = extractParametersFromFilename(fileName)
val quantization = extractQuantizationFromFilename(fileName)
// Create model entity and save via DAO
val modelEntity = ModelEntity(
id = UUID.randomUUID().toString(),
name = fileName.substringBeforeLast('.'),
path = modelFile.absolutePath,
sizeInBytes = modelFile.length(),
parameters = parameters,
quantization = quantization,
type = modelType,
contextLength = DEFAULT_CONTEXT_SIZE,
lastUsed = null,
dateAdded = System.currentTimeMillis()
)
modelDao.insertModel(modelEntity)
return modelEntity.toModelInfo()
// Create model entity and save via DAO
ModelEntity(
id = UUID.randomUUID().toString(),
name = fileName.substringBeforeLast('.'),
path = modelFile.absolutePath,
sizeInBytes = modelFile.length(),
parameters = parameters,
quantization = quantization,
type = modelType,
contextLength = DEFAULT_CONTEXT_SIZE,
lastUsed = null,
dateAdded = System.currentTimeMillis()
).let {
modelDao.insertModel(it)
it.toModelInfo()
}
} catch (e: Exception) {
// Clean up partially downloaded file if error occurs
if (modelFile.exists()) {
modelFile.delete()
}
throw e
}
}
private suspend fun copyWithChannels(
input: InputStream,
output: OutputStream,
totalSize: Long,
progressTracker: ImportProgressTracker?
) {
val inChannel: ReadableByteChannel = Channels.newChannel(input)
val outChannel: WritableByteChannel = Channels.newChannel(output)
val buffer = ByteBuffer.allocateDirect(NIO_BUFFER_SIZE)
var totalBytesRead = 0L
while (inChannel.read(buffer) != -1) {
buffer.flip()
while (buffer.hasRemaining()) {
outChannel.write(buffer)
}
totalBytesRead += buffer.position()
buffer.clear()
// Report progress
progressTracker?.let {
val progress = totalBytesRead.toFloat() / totalSize
withContext(Dispatchers.Main) {
it.onProgress(progress)
}
}
if (totalBytesRead % (NIO_YIELD_SIZE) == 0L) {
yield()
}
}
inChannel.close()
outChannel.close()
output.close()
input.close()
}
private suspend fun copyWithBuffer(
input: BufferedInputStream,
output: BufferedOutputStream,
totalSize: Long,
progressTracker: ImportProgressTracker?
) {
val buffer = ByteArray(DEFAULT_BUFFER_SIZE)
var bytesRead: Int
var totalBytesRead = 0L
while (input.read(buffer).also { bytesRead = it } != -1) {
output.write(buffer, 0, bytesRead)
totalBytesRead += bytesRead
// Report progress
if (progressTracker != null) {
val progress = totalBytesRead.toFloat() / totalSize
withContext(Dispatchers.Main) {
progressTracker.onProgress(progress)
}
}
// Yield less frequently with larger buffers
if (totalBytesRead % (DEFAULT_YIELD_SIZE) == 0L) { // Every 64MB
yield()
}
}
output.close()
input.close()
}
override suspend fun deleteModel(modelId: String) {
@ -132,7 +251,7 @@ class ModelRepositoryImpl @Inject constructor(
val totalSpaceBytes: Long
get() = StatFs(context.filesDir.path).totalBytes
private fun getFileNameFromUri(uri: Uri): String =
private fun getFileNameFromUri(uri: Uri): String? =
context.contentResolver.query(uri, null, null, null, null)?.use { cursor ->
if (cursor.moveToFirst()) {
cursor.getColumnIndex(OpenableColumns.DISPLAY_NAME).let { nameIndex ->
@ -141,7 +260,21 @@ class ModelRepositoryImpl @Inject constructor(
} else {
null
}
} ?: uri.lastPathSegment ?: "unknown_model.gguf"
} ?: uri.lastPathSegment
/**
* Gets the file size from a content URI, or returns 0 if size is unknown.
*/
private fun getFileSizeFromUri(uri: Uri): Long? =
context.contentResolver.query(uri, null, null, null, null)?.use { cursor ->
if (cursor.moveToFirst()) {
cursor.getColumnIndex(OpenableColumns.SIZE).let { sizeIndex ->
if (sizeIndex != -1) cursor.getLong(sizeIndex) else null
}
} else {
null
}
}
/**
* Try to extract parameters by looking for patterns like 7B, 13B, etc.
@ -182,9 +315,18 @@ class ModelRepositoryImpl @Inject constructor(
}
companion object {
private val TAG = ModelRepository::class.java.simpleName
private const val INTERNAL_STORAGE_PATH = "models"
private const val BYTES_IN_GB = 1024f * 1024f * 1024f
private const val LARGE_MODEL_THRESHOLD_SIZE = 1024 * 1024 * 1024
private const val NIO_BUFFER_SIZE = 32 * 1024 * 1024
private const val NIO_YIELD_SIZE = 128 * 1024 * 1024
private const val DEFAULT_BUFFER_SIZE = 4 * 1024 * 1024
private const val DEFAULT_YIELD_SIZE = 16 * 1024 * 1024
private const val STORAGE_METRICS_UPDATE_INTERVAL = 5_000L
private const val DEFAULT_CONTEXT_SIZE = 8192

View File

@ -99,13 +99,13 @@ fun ModelCard(
Row {
Text(
text = model.parameters,
text = model.parameters ?: " - ",
style = MaterialTheme.typography.bodyMedium,
color = MaterialTheme.colorScheme.onSurfaceVariant
)
Text(
text = "${model.quantization}",
text = "${model.quantization ?: " - "}",
style = MaterialTheme.typography.bodyMedium,
color = MaterialTheme.colorScheme.onSurfaceVariant
)
@ -120,7 +120,7 @@ fun ModelCard(
Spacer(modifier = Modifier.height(4.dp))
Text(
text = "Context Length: ${model.contextLength}",
text = "Context Length: ${model.contextLength ?: " - "}",
style = MaterialTheme.typography.bodySmall,
color = MaterialTheme.colorScheme.onSurfaceVariant
)

View File

@ -1,6 +1,11 @@
package com.example.llama.revamp.ui.screens
import androidx.activity.compose.BackHandler
import androidx.activity.compose.rememberLauncherForActivityResult
import androidx.activity.result.contract.ActivityResultContracts
import androidx.compose.foundation.background
import androidx.compose.foundation.clickable
import androidx.compose.foundation.layout.Box
import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.Row
import androidx.compose.foundation.layout.Spacer
@ -23,7 +28,10 @@ import androidx.compose.material.icons.filled.FilterAlt
import androidx.compose.material.icons.filled.FolderOpen
import androidx.compose.material.icons.filled.Info
import androidx.compose.material.icons.filled.SelectAll
import androidx.compose.material3.AlertDialog
import androidx.compose.material3.BottomAppBar
import androidx.compose.material3.Button
import androidx.compose.material3.ButtonDefaults
import androidx.compose.material3.Card
import androidx.compose.material3.CardDefaults
import androidx.compose.material3.Checkbox
@ -32,9 +40,11 @@ import androidx.compose.material3.DropdownMenuItem
import androidx.compose.material3.FloatingActionButton
import androidx.compose.material3.Icon
import androidx.compose.material3.IconButton
import androidx.compose.material3.LinearProgressIndicator
import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.Text
import androidx.compose.runtime.Composable
import androidx.compose.runtime.LaunchedEffect
import androidx.compose.runtime.collectAsState
import androidx.compose.runtime.getValue
import androidx.compose.runtime.mutableStateMapOf
@ -45,6 +55,7 @@ import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
import androidx.compose.ui.graphics.Color
import androidx.compose.ui.res.painterResource
import androidx.compose.ui.text.style.TextOverflow
import androidx.compose.ui.unit.dp
import androidx.hilt.navigation.compose.hiltViewModel
import com.example.llama.revamp.data.model.ModelInfo
@ -55,6 +66,7 @@ import java.text.SimpleDateFormat
import java.util.Date
import java.util.Locale
import com.example.llama.R
import com.example.llama.revamp.viewmodel.ModelImportState
/**
* Screen for managing LLM models (view, download, delete)
@ -66,13 +78,24 @@ fun ModelsManagementScreen(
) {
val storageMetrics by viewModel.storageMetrics.collectAsState()
val sortedModels by viewModel.sortedModels.collectAsState()
// Model sorting
val sortOrder by viewModel.sortOrder.collectAsState()
// UI: menu states
var showSortMenu by remember { mutableStateOf(false) }
var showAddModelMenu by remember { mutableStateOf(false) }
// UI: multi-selection states
// Model importing
val importState by viewModel.importState.collectAsState()
var showImportModelMenu by remember { mutableStateOf(false) }
val fileLauncher = rememberLauncherForActivityResult(
contract = ActivityResultContracts.OpenDocument()
) { uri -> uri?.let { viewModel.importLocalModel(it) } }
BackHandler(enabled = importState is ModelImportState.Importing) {
/* Ignore back press while importing model */
}
// Multi-selection
var isMultiSelectionMode by remember { mutableStateOf(false) }
val selectedModels = remember { mutableStateMapOf<String, ModelInfo>() }
val exitSelectionMode = {
@ -238,7 +261,7 @@ fun ModelsManagementScreen(
if (isMultiSelectionMode) {
exitSelectionMode()
} else {
showAddModelMenu = true
showImportModelMenu = true
}
},
containerColor = MaterialTheme.colorScheme.primaryContainer
@ -251,8 +274,8 @@ fun ModelsManagementScreen(
// Add model dropdown menu
DropdownMenu(
expanded = showAddModelMenu,
onDismissRequest = { showAddModelMenu = false }
expanded = showImportModelMenu,
onDismissRequest = { showImportModelMenu = false }
) {
DropdownMenuItem(
text = { Text("Import local model") },
@ -263,9 +286,8 @@ fun ModelsManagementScreen(
)
},
onClick = {
// TODO-han.yin: uncomment once file picker done
// viewModel.importLocalModel()
showAddModelMenu = false
fileLauncher.launch(arrayOf("application/octet-stream", "*/*"))
showImportModelMenu = false
}
)
DropdownMenuItem(
@ -280,7 +302,7 @@ fun ModelsManagementScreen(
},
onClick = {
viewModel.importFromHuggingFace()
showAddModelMenu = false
showImportModelMenu = false
}
)
}
@ -288,38 +310,64 @@ fun ModelsManagementScreen(
)
},
) { paddingValues ->
// Main content
ModelList(
models = sortedModels,
isMultiSelectionMode = isMultiSelectionMode,
selectedModels = selectedModels,
onModelClick = { modelId ->
if (isMultiSelectionMode) {
// Toggle selection
if (selectedModels.contains(modelId)) {
selectedModels.remove(modelId)
Box(modifier = Modifier.fillMaxSize()) {
// Model cards
ModelCardList(
models = sortedModels,
isMultiSelectionMode = isMultiSelectionMode,
selectedModels = selectedModels,
onModelClick = { modelId ->
if (isMultiSelectionMode) {
// Toggle selection
if (selectedModels.contains(modelId)) {
selectedModels.remove(modelId)
} else {
selectedModels.put(modelId, sortedModels.first { it.id == modelId } )
}
} else {
selectedModels.put(modelId, sortedModels.first { it.id == modelId } )
// View model details
viewModel.viewModelDetails(modelId)
}
} else {
// View model details
},
onModelInfoClick = { modelId ->
viewModel.viewModelDetails(modelId)
},
onModelDeleteClick = { modelId ->
viewModel.deleteModel(modelId)
},
modifier = Modifier.padding(paddingValues)
)
// Model import progress overlay
when (val state = importState) {
is ModelImportState.Importing -> {
ImportProgressOverlay(
progress = state.progress,
filename = state.filename,
onCancel = { /* Implement cancellation if needed */ }
)
}
},
onModelInfoClick = { modelId ->
viewModel.viewModelDetails(modelId)
},
onModelDeleteClick = { modelId ->
viewModel.deleteModel(modelId)
},
modifier = Modifier.padding(paddingValues)
)
is ModelImportState.Error -> {
ErrorDialog(
message = state.message,
onDismiss = { viewModel.resetImportState() }
)
}
is ModelImportState.Success -> {
LaunchedEffect(state) {
// Show success snackbar or message
// This will auto-dismiss after the delay in viewModel
}
}
else -> { /* Idle state, nothing to show */ }
}
}
}
}
@Composable
private fun ModelList(
private fun ModelCardList(
models: List<ModelInfo>,
isMultiSelectionMode: Boolean,
selectedModels: Map<String, ModelInfo>,
@ -337,13 +385,16 @@ private fun ModelList(
items = models,
key = { it.id }
) { model ->
ModelItem(
ModelCard(
model = model,
isMultiSelectionMode = isMultiSelectionMode,
isSelected = selectedModels.contains(model.id),
onClick = { onModelClick(model.id) },
onInfoClick = { onModelInfoClick(model.id) },
onDeleteClick = { onModelDeleteClick(model.id) }
onDeleteClick = {
// TODO-han.yin: pop up an AlertDialog asking user for confirmation
onModelDeleteClick(model.id)
}
)
Spacer(modifier = Modifier.height(8.dp))
}
@ -351,7 +402,7 @@ private fun ModelList(
}
@Composable
private fun ModelItem(
private fun ModelCard(
model: ModelInfo,
isMultiSelectionMode: Boolean,
isSelected: Boolean,
@ -424,3 +475,95 @@ private fun ModelItem(
}
}
}
@Composable
fun ImportProgressOverlay(
progress: Float,
filename: String,
onCancel: () -> Unit
) {
Box(
modifier = Modifier
.fillMaxSize()
.background(Color.Black.copy(alpha = 0.7f))
.padding(32.dp),
contentAlignment = Alignment.Center
) {
Card(
modifier = Modifier
.fillMaxWidth()
.padding(16.dp),
elevation = CardDefaults.cardElevation(defaultElevation = 8.dp)
) {
Column(
modifier = Modifier.padding(24.dp),
horizontalAlignment = Alignment.CenterHorizontally
) {
Text(
text = "Importing Model",
style = MaterialTheme.typography.headlineSmall
)
Spacer(modifier = Modifier.height(8.dp))
Text(
text = filename,
style = MaterialTheme.typography.bodyMedium,
maxLines = 1,
overflow = TextOverflow.Ellipsis
)
Spacer(modifier = Modifier.height(24.dp))
LinearProgressIndicator(
progress = { progress },
modifier = Modifier.fillMaxWidth()
)
Spacer(modifier = Modifier.height(8.dp))
Text(
text = "${(progress * 100).toInt()}%",
style = MaterialTheme.typography.bodyLarge
)
Spacer(modifier = Modifier.height(16.dp))
Text(
text = "This may take several minutes for large models",
style = MaterialTheme.typography.bodySmall,
color = MaterialTheme.colorScheme.onSurfaceVariant
)
Spacer(modifier = Modifier.height(24.dp))
Button(
onClick = onCancel,
colors = ButtonDefaults.buttonColors(
containerColor = MaterialTheme.colorScheme.errorContainer,
contentColor = MaterialTheme.colorScheme.onErrorContainer
)
) {
Text("Cancel")
}
}
}
}
}
@Composable
fun ErrorDialog(
message: String,
onDismiss: () -> Unit
) {
AlertDialog(
onDismissRequest = onDismiss,
title = { Text("Import Failed") },
text = { Text(message) },
confirmButton = {
Button(onClick = onDismiss) {
Text("OK")
}
}
)
}

View File

@ -1,13 +1,13 @@
package com.example.llama.revamp.viewmodel
import android.net.Uri
import android.util.Log
import androidx.lifecycle.ViewModel
import androidx.lifecycle.viewModelScope
import com.example.llama.revamp.data.model.ModelInfo
import com.example.llama.revamp.data.repository.ModelRepository
import com.example.llama.revamp.data.repository.StorageMetrics
import dagger.hilt.android.lifecycle.HiltViewModel
import kotlinx.coroutines.delay
import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.flow.SharingStarted
import kotlinx.coroutines.flow.StateFlow
@ -76,16 +76,34 @@ class ModelsManagementViewModel @Inject constructor(
modelRepository.deleteModels(models.keys)
}
private val _importState = MutableStateFlow<ModelImportState>(ModelImportState.Idle)
val importState: StateFlow<ModelImportState> = _importState.asStateFlow()
fun importLocalModel(uri: Uri) =
viewModelScope.launch {
try {
modelRepository.importModel(uri)
// Get filename for progress updates
val filename = uri.lastPathSegment ?: throw Exception("Model name unknown")
_importState.value = ModelImportState.Importing(0f, filename)
// Import with progress reporting
val model = modelRepository.importModel(uri) { progress ->
_importState.value = ModelImportState.Importing(progress, filename)
}
_importState.value = ModelImportState.Success(model)
// Reset state after a delay
delay(1000)
_importState.value = ModelImportState.Idle
} catch (e: Exception) {
// TODO-han.yin: add UI to prompt user about import failure!
Log.e(TAG, "Failed to import model from: $uri", e)
_importState.value = ModelImportState.Error(e.message ?: "Unknown error")
}
}
fun resetImportState() {
_importState.value = ModelImportState.Idle
}
fun importFromHuggingFace() {
// TODO-han.yin: Stub for now. Would need to investigate HuggingFace APIs
}
@ -105,4 +123,9 @@ enum class ModelSortOrder {
LAST_USED
}
sealed class ModelImportState {
object Idle : ModelImportState()
data class Importing(val progress: Float = 0f, val filename: String = "") : ModelImportState()
data class Success(val model: ModelInfo) : ModelImportState()
data class Error(val message: String) : ModelImportState()
}