data: import local model with file picker
This commit is contained in:
parent
a3ebdac58f
commit
5de0b5d6d0
|
|
@ -11,10 +11,10 @@ data class ModelEntity(
|
||||||
val name: String,
|
val name: String,
|
||||||
val path: String,
|
val path: String,
|
||||||
val sizeInBytes: Long,
|
val sizeInBytes: Long,
|
||||||
val parameters: String,
|
val parameters: String?,
|
||||||
val quantization: String,
|
val quantization: String?,
|
||||||
val type: String,
|
val type: String?,
|
||||||
val contextLength: Int,
|
val contextLength: Int?,
|
||||||
val lastUsed: Long?,
|
val lastUsed: Long?,
|
||||||
val dateAdded: Long
|
val dateAdded: Long
|
||||||
) {
|
) {
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,7 @@
|
||||||
package com.example.llama.revamp.data.model
|
package com.example.llama.revamp.data.model
|
||||||
|
|
||||||
|
import java.util.Locale
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Data class containing information about an LLM model.
|
* Data class containing information about an LLM model.
|
||||||
*/
|
*/
|
||||||
|
|
@ -8,10 +10,10 @@ data class ModelInfo(
|
||||||
val name: String,
|
val name: String,
|
||||||
val path: String,
|
val path: String,
|
||||||
val sizeInBytes: Long,
|
val sizeInBytes: Long,
|
||||||
val parameters: String,
|
val parameters: String?,
|
||||||
val quantization: String,
|
val quantization: String?,
|
||||||
val type: String,
|
val type: String?,
|
||||||
val contextLength: Int,
|
val contextLength: Int?,
|
||||||
val lastUsed: Long? = null
|
val lastUsed: Long? = null
|
||||||
) {
|
) {
|
||||||
val formattedSize: String
|
val formattedSize: String
|
||||||
|
|
@ -19,15 +21,15 @@ data class ModelInfo(
|
||||||
return when {
|
return when {
|
||||||
sizeInBytes >= 1_000_000_000 -> {
|
sizeInBytes >= 1_000_000_000 -> {
|
||||||
val sizeInGb = sizeInBytes / 1_000_000_000.0
|
val sizeInGb = sizeInBytes / 1_000_000_000.0
|
||||||
String.format("%.2f GB", sizeInGb)
|
String.format(Locale.getDefault(), "%.2f GB", sizeInGb)
|
||||||
}
|
}
|
||||||
sizeInBytes >= 1_000_000 -> {
|
sizeInBytes >= 1_000_000 -> {
|
||||||
val sizeInMb = sizeInBytes / 1_000_000.0
|
val sizeInMb = sizeInBytes / 1_000_000.0
|
||||||
String.format("%.2f MB", sizeInMb)
|
String.format(Locale.getDefault(), "%.2f MB", sizeInMb)
|
||||||
}
|
}
|
||||||
else -> {
|
else -> {
|
||||||
val sizeInKb = sizeInBytes / 1_000.0
|
val sizeInKb = sizeInBytes / 1_000.0
|
||||||
String.format("%.2f KB", sizeInKb)
|
String.format(Locale.getDefault(), "%.2f KB", sizeInKb)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -4,17 +4,31 @@ import android.content.Context
|
||||||
import android.net.Uri
|
import android.net.Uri
|
||||||
import android.os.StatFs
|
import android.os.StatFs
|
||||||
import android.provider.OpenableColumns
|
import android.provider.OpenableColumns
|
||||||
|
import android.util.Log
|
||||||
import com.example.llama.revamp.data.local.ModelDao
|
import com.example.llama.revamp.data.local.ModelDao
|
||||||
import com.example.llama.revamp.data.local.ModelEntity
|
import com.example.llama.revamp.data.local.ModelEntity
|
||||||
import com.example.llama.revamp.data.model.ModelInfo
|
import com.example.llama.revamp.data.model.ModelInfo
|
||||||
|
import com.example.llama.revamp.data.repository.ModelRepository.ImportProgressTracker
|
||||||
import dagger.hilt.android.qualifiers.ApplicationContext
|
import dagger.hilt.android.qualifiers.ApplicationContext
|
||||||
|
import kotlinx.coroutines.Dispatchers
|
||||||
import kotlinx.coroutines.delay
|
import kotlinx.coroutines.delay
|
||||||
import kotlinx.coroutines.flow.Flow
|
import kotlinx.coroutines.flow.Flow
|
||||||
import kotlinx.coroutines.flow.flow
|
import kotlinx.coroutines.flow.flow
|
||||||
import kotlinx.coroutines.flow.map
|
import kotlinx.coroutines.flow.map
|
||||||
|
import kotlinx.coroutines.withContext
|
||||||
|
import kotlinx.coroutines.yield
|
||||||
|
import java.io.BufferedInputStream
|
||||||
|
import java.io.BufferedOutputStream
|
||||||
import java.io.File
|
import java.io.File
|
||||||
|
import java.io.FileNotFoundException
|
||||||
import java.io.FileOutputStream
|
import java.io.FileOutputStream
|
||||||
import java.io.IOException
|
import java.io.IOException
|
||||||
|
import java.io.InputStream
|
||||||
|
import java.io.OutputStream
|
||||||
|
import java.nio.ByteBuffer
|
||||||
|
import java.nio.channels.Channels
|
||||||
|
import java.nio.channels.ReadableByteChannel
|
||||||
|
import java.nio.channels.WritableByteChannel
|
||||||
import java.util.Locale
|
import java.util.Locale
|
||||||
import java.util.UUID
|
import java.util.UUID
|
||||||
import javax.inject.Inject
|
import javax.inject.Inject
|
||||||
|
|
@ -27,9 +41,14 @@ interface ModelRepository {
|
||||||
fun getStorageMetrics(): Flow<StorageMetrics>
|
fun getStorageMetrics(): Flow<StorageMetrics>
|
||||||
fun getModels(): Flow<List<ModelInfo>>
|
fun getModels(): Flow<List<ModelInfo>>
|
||||||
|
|
||||||
suspend fun importModel(uri: Uri): ModelInfo
|
suspend fun importModel(uri: Uri, progressTracker: ImportProgressTracker? = null): ModelInfo
|
||||||
|
|
||||||
suspend fun deleteModel(modelId: String)
|
suspend fun deleteModel(modelId: String)
|
||||||
suspend fun deleteModels(modelIds: Collection<String>)
|
suspend fun deleteModels(modelIds: Collection<String>)
|
||||||
|
|
||||||
|
fun interface ImportProgressTracker {
|
||||||
|
fun onProgress(progress: Float) // 0.0f to 1.0f
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Singleton
|
@Singleton
|
||||||
|
|
@ -67,38 +86,138 @@ class ModelRepositoryImpl @Inject constructor(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
override suspend fun importModel(uri: Uri): ModelInfo {
|
override suspend fun importModel(
|
||||||
// Obtain the local model's file via provided URI
|
uri: Uri,
|
||||||
val fileName = getFileNameFromUri(uri)
|
progressTracker: ImportProgressTracker?
|
||||||
|
): ModelInfo = withContext(Dispatchers.IO) {
|
||||||
|
val fileName = getFileNameFromUri(uri) ?: throw FileNotFoundException("Filename N/A")
|
||||||
|
val fileSize = getFileSizeFromUri(uri) ?: throw FileNotFoundException("File size N/A")
|
||||||
val modelFile = File(modelsDir, fileName)
|
val modelFile = File(modelsDir, fileName)
|
||||||
|
|
||||||
// Copy file to app's internal storage
|
try {
|
||||||
context.contentResolver.openInputStream(uri)?.use { inputStream ->
|
val inputStream = context.contentResolver.openInputStream(uri)
|
||||||
FileOutputStream(modelFile).use { outputStream ->
|
?: throw IOException("Failed to open input stream")
|
||||||
inputStream.copyTo(outputStream)
|
val outputStream = FileOutputStream(modelFile)
|
||||||
|
|
||||||
|
if (fileSize > LARGE_MODEL_THRESHOLD_SIZE) {
|
||||||
|
Log.i(TAG, "Copying $fileName (size: $fileSize) via NIO...")
|
||||||
|
|
||||||
|
// Use NIO channels for large models
|
||||||
|
copyWithChannels(inputStream, outputStream, fileSize, progressTracker)
|
||||||
|
} else {
|
||||||
|
Log.i(TAG, "Copying $fileName (size: $fileSize) via buffer...")
|
||||||
|
|
||||||
|
// Default copy with buffer for small models
|
||||||
|
val bufferedInput = BufferedInputStream(inputStream, DEFAULT_BUFFER_SIZE)
|
||||||
|
val bufferedOutput = BufferedOutputStream(outputStream, DEFAULT_BUFFER_SIZE)
|
||||||
|
copyWithBuffer(bufferedInput, bufferedOutput, fileSize, progressTracker)
|
||||||
|
|
||||||
|
// Close streams
|
||||||
|
bufferedOutput.flush()
|
||||||
|
bufferedOutput.close()
|
||||||
|
bufferedInput.close()
|
||||||
}
|
}
|
||||||
} ?: throw IOException("Failed to open input stream")
|
|
||||||
|
|
||||||
// Extract model parameters from filename
|
// Extract model parameters from filename
|
||||||
val modelType = extractModelTypeFromFilename(fileName) ?: "unknown"
|
val modelType = extractModelTypeFromFilename(fileName)
|
||||||
val parameters = extractParametersFromFilename(fileName) ?: "unknown"
|
val parameters = extractParametersFromFilename(fileName)
|
||||||
val quantization = extractQuantizationFromFilename(fileName) ?: "unknown"
|
val quantization = extractQuantizationFromFilename(fileName)
|
||||||
|
|
||||||
// Create model entity and save via DAO
|
// Create model entity and save via DAO
|
||||||
val modelEntity = ModelEntity(
|
ModelEntity(
|
||||||
id = UUID.randomUUID().toString(),
|
id = UUID.randomUUID().toString(),
|
||||||
name = fileName.substringBeforeLast('.'),
|
name = fileName.substringBeforeLast('.'),
|
||||||
path = modelFile.absolutePath,
|
path = modelFile.absolutePath,
|
||||||
sizeInBytes = modelFile.length(),
|
sizeInBytes = modelFile.length(),
|
||||||
parameters = parameters,
|
parameters = parameters,
|
||||||
quantization = quantization,
|
quantization = quantization,
|
||||||
type = modelType,
|
type = modelType,
|
||||||
contextLength = DEFAULT_CONTEXT_SIZE,
|
contextLength = DEFAULT_CONTEXT_SIZE,
|
||||||
lastUsed = null,
|
lastUsed = null,
|
||||||
dateAdded = System.currentTimeMillis()
|
dateAdded = System.currentTimeMillis()
|
||||||
)
|
).let {
|
||||||
modelDao.insertModel(modelEntity)
|
modelDao.insertModel(it)
|
||||||
return modelEntity.toModelInfo()
|
it.toModelInfo()
|
||||||
|
}
|
||||||
|
} catch (e: Exception) {
|
||||||
|
// Clean up partially downloaded file if error occurs
|
||||||
|
if (modelFile.exists()) {
|
||||||
|
modelFile.delete()
|
||||||
|
}
|
||||||
|
throw e
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private suspend fun copyWithChannels(
|
||||||
|
input: InputStream,
|
||||||
|
output: OutputStream,
|
||||||
|
totalSize: Long,
|
||||||
|
progressTracker: ImportProgressTracker?
|
||||||
|
) {
|
||||||
|
val inChannel: ReadableByteChannel = Channels.newChannel(input)
|
||||||
|
val outChannel: WritableByteChannel = Channels.newChannel(output)
|
||||||
|
|
||||||
|
val buffer = ByteBuffer.allocateDirect(NIO_BUFFER_SIZE)
|
||||||
|
var totalBytesRead = 0L
|
||||||
|
|
||||||
|
while (inChannel.read(buffer) != -1) {
|
||||||
|
buffer.flip()
|
||||||
|
while (buffer.hasRemaining()) {
|
||||||
|
outChannel.write(buffer)
|
||||||
|
}
|
||||||
|
totalBytesRead += buffer.position()
|
||||||
|
buffer.clear()
|
||||||
|
|
||||||
|
// Report progress
|
||||||
|
progressTracker?.let {
|
||||||
|
val progress = totalBytesRead.toFloat() / totalSize
|
||||||
|
withContext(Dispatchers.Main) {
|
||||||
|
it.onProgress(progress)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (totalBytesRead % (NIO_YIELD_SIZE) == 0L) {
|
||||||
|
yield()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
inChannel.close()
|
||||||
|
outChannel.close()
|
||||||
|
output.close()
|
||||||
|
input.close()
|
||||||
|
}
|
||||||
|
|
||||||
|
private suspend fun copyWithBuffer(
|
||||||
|
input: BufferedInputStream,
|
||||||
|
output: BufferedOutputStream,
|
||||||
|
totalSize: Long,
|
||||||
|
progressTracker: ImportProgressTracker?
|
||||||
|
) {
|
||||||
|
val buffer = ByteArray(DEFAULT_BUFFER_SIZE)
|
||||||
|
|
||||||
|
var bytesRead: Int
|
||||||
|
var totalBytesRead = 0L
|
||||||
|
|
||||||
|
while (input.read(buffer).also { bytesRead = it } != -1) {
|
||||||
|
output.write(buffer, 0, bytesRead)
|
||||||
|
totalBytesRead += bytesRead
|
||||||
|
|
||||||
|
// Report progress
|
||||||
|
if (progressTracker != null) {
|
||||||
|
val progress = totalBytesRead.toFloat() / totalSize
|
||||||
|
withContext(Dispatchers.Main) {
|
||||||
|
progressTracker.onProgress(progress)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Yield less frequently with larger buffers
|
||||||
|
if (totalBytesRead % (DEFAULT_YIELD_SIZE) == 0L) { // Every 64MB
|
||||||
|
yield()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
output.close()
|
||||||
|
input.close()
|
||||||
}
|
}
|
||||||
|
|
||||||
override suspend fun deleteModel(modelId: String) {
|
override suspend fun deleteModel(modelId: String) {
|
||||||
|
|
@ -132,7 +251,7 @@ class ModelRepositoryImpl @Inject constructor(
|
||||||
val totalSpaceBytes: Long
|
val totalSpaceBytes: Long
|
||||||
get() = StatFs(context.filesDir.path).totalBytes
|
get() = StatFs(context.filesDir.path).totalBytes
|
||||||
|
|
||||||
private fun getFileNameFromUri(uri: Uri): String =
|
private fun getFileNameFromUri(uri: Uri): String? =
|
||||||
context.contentResolver.query(uri, null, null, null, null)?.use { cursor ->
|
context.contentResolver.query(uri, null, null, null, null)?.use { cursor ->
|
||||||
if (cursor.moveToFirst()) {
|
if (cursor.moveToFirst()) {
|
||||||
cursor.getColumnIndex(OpenableColumns.DISPLAY_NAME).let { nameIndex ->
|
cursor.getColumnIndex(OpenableColumns.DISPLAY_NAME).let { nameIndex ->
|
||||||
|
|
@ -141,7 +260,21 @@ class ModelRepositoryImpl @Inject constructor(
|
||||||
} else {
|
} else {
|
||||||
null
|
null
|
||||||
}
|
}
|
||||||
} ?: uri.lastPathSegment ?: "unknown_model.gguf"
|
} ?: uri.lastPathSegment
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Gets the file size from a content URI, or returns 0 if size is unknown.
|
||||||
|
*/
|
||||||
|
private fun getFileSizeFromUri(uri: Uri): Long? =
|
||||||
|
context.contentResolver.query(uri, null, null, null, null)?.use { cursor ->
|
||||||
|
if (cursor.moveToFirst()) {
|
||||||
|
cursor.getColumnIndex(OpenableColumns.SIZE).let { sizeIndex ->
|
||||||
|
if (sizeIndex != -1) cursor.getLong(sizeIndex) else null
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
null
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Try to extract parameters by looking for patterns like 7B, 13B, etc.
|
* Try to extract parameters by looking for patterns like 7B, 13B, etc.
|
||||||
|
|
@ -182,9 +315,18 @@ class ModelRepositoryImpl @Inject constructor(
|
||||||
}
|
}
|
||||||
|
|
||||||
companion object {
|
companion object {
|
||||||
|
private val TAG = ModelRepository::class.java.simpleName
|
||||||
|
|
||||||
private const val INTERNAL_STORAGE_PATH = "models"
|
private const val INTERNAL_STORAGE_PATH = "models"
|
||||||
|
|
||||||
private const val BYTES_IN_GB = 1024f * 1024f * 1024f
|
private const val BYTES_IN_GB = 1024f * 1024f * 1024f
|
||||||
|
|
||||||
|
private const val LARGE_MODEL_THRESHOLD_SIZE = 1024 * 1024 * 1024
|
||||||
|
private const val NIO_BUFFER_SIZE = 32 * 1024 * 1024
|
||||||
|
private const val NIO_YIELD_SIZE = 128 * 1024 * 1024
|
||||||
|
private const val DEFAULT_BUFFER_SIZE = 4 * 1024 * 1024
|
||||||
|
private const val DEFAULT_YIELD_SIZE = 16 * 1024 * 1024
|
||||||
|
|
||||||
private const val STORAGE_METRICS_UPDATE_INTERVAL = 5_000L
|
private const val STORAGE_METRICS_UPDATE_INTERVAL = 5_000L
|
||||||
private const val DEFAULT_CONTEXT_SIZE = 8192
|
private const val DEFAULT_CONTEXT_SIZE = 8192
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -99,13 +99,13 @@ fun ModelCard(
|
||||||
|
|
||||||
Row {
|
Row {
|
||||||
Text(
|
Text(
|
||||||
text = model.parameters,
|
text = model.parameters ?: " - ",
|
||||||
style = MaterialTheme.typography.bodyMedium,
|
style = MaterialTheme.typography.bodyMedium,
|
||||||
color = MaterialTheme.colorScheme.onSurfaceVariant
|
color = MaterialTheme.colorScheme.onSurfaceVariant
|
||||||
)
|
)
|
||||||
|
|
||||||
Text(
|
Text(
|
||||||
text = " • ${model.quantization}",
|
text = " • ${model.quantization ?: " - "}",
|
||||||
style = MaterialTheme.typography.bodyMedium,
|
style = MaterialTheme.typography.bodyMedium,
|
||||||
color = MaterialTheme.colorScheme.onSurfaceVariant
|
color = MaterialTheme.colorScheme.onSurfaceVariant
|
||||||
)
|
)
|
||||||
|
|
@ -120,7 +120,7 @@ fun ModelCard(
|
||||||
Spacer(modifier = Modifier.height(4.dp))
|
Spacer(modifier = Modifier.height(4.dp))
|
||||||
|
|
||||||
Text(
|
Text(
|
||||||
text = "Context Length: ${model.contextLength}",
|
text = "Context Length: ${model.contextLength ?: " - "}",
|
||||||
style = MaterialTheme.typography.bodySmall,
|
style = MaterialTheme.typography.bodySmall,
|
||||||
color = MaterialTheme.colorScheme.onSurfaceVariant
|
color = MaterialTheme.colorScheme.onSurfaceVariant
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,11 @@
|
||||||
package com.example.llama.revamp.ui.screens
|
package com.example.llama.revamp.ui.screens
|
||||||
|
|
||||||
|
import androidx.activity.compose.BackHandler
|
||||||
|
import androidx.activity.compose.rememberLauncherForActivityResult
|
||||||
|
import androidx.activity.result.contract.ActivityResultContracts
|
||||||
|
import androidx.compose.foundation.background
|
||||||
import androidx.compose.foundation.clickable
|
import androidx.compose.foundation.clickable
|
||||||
|
import androidx.compose.foundation.layout.Box
|
||||||
import androidx.compose.foundation.layout.Column
|
import androidx.compose.foundation.layout.Column
|
||||||
import androidx.compose.foundation.layout.Row
|
import androidx.compose.foundation.layout.Row
|
||||||
import androidx.compose.foundation.layout.Spacer
|
import androidx.compose.foundation.layout.Spacer
|
||||||
|
|
@ -23,7 +28,10 @@ import androidx.compose.material.icons.filled.FilterAlt
|
||||||
import androidx.compose.material.icons.filled.FolderOpen
|
import androidx.compose.material.icons.filled.FolderOpen
|
||||||
import androidx.compose.material.icons.filled.Info
|
import androidx.compose.material.icons.filled.Info
|
||||||
import androidx.compose.material.icons.filled.SelectAll
|
import androidx.compose.material.icons.filled.SelectAll
|
||||||
|
import androidx.compose.material3.AlertDialog
|
||||||
import androidx.compose.material3.BottomAppBar
|
import androidx.compose.material3.BottomAppBar
|
||||||
|
import androidx.compose.material3.Button
|
||||||
|
import androidx.compose.material3.ButtonDefaults
|
||||||
import androidx.compose.material3.Card
|
import androidx.compose.material3.Card
|
||||||
import androidx.compose.material3.CardDefaults
|
import androidx.compose.material3.CardDefaults
|
||||||
import androidx.compose.material3.Checkbox
|
import androidx.compose.material3.Checkbox
|
||||||
|
|
@ -32,9 +40,11 @@ import androidx.compose.material3.DropdownMenuItem
|
||||||
import androidx.compose.material3.FloatingActionButton
|
import androidx.compose.material3.FloatingActionButton
|
||||||
import androidx.compose.material3.Icon
|
import androidx.compose.material3.Icon
|
||||||
import androidx.compose.material3.IconButton
|
import androidx.compose.material3.IconButton
|
||||||
|
import androidx.compose.material3.LinearProgressIndicator
|
||||||
import androidx.compose.material3.MaterialTheme
|
import androidx.compose.material3.MaterialTheme
|
||||||
import androidx.compose.material3.Text
|
import androidx.compose.material3.Text
|
||||||
import androidx.compose.runtime.Composable
|
import androidx.compose.runtime.Composable
|
||||||
|
import androidx.compose.runtime.LaunchedEffect
|
||||||
import androidx.compose.runtime.collectAsState
|
import androidx.compose.runtime.collectAsState
|
||||||
import androidx.compose.runtime.getValue
|
import androidx.compose.runtime.getValue
|
||||||
import androidx.compose.runtime.mutableStateMapOf
|
import androidx.compose.runtime.mutableStateMapOf
|
||||||
|
|
@ -45,6 +55,7 @@ import androidx.compose.ui.Alignment
|
||||||
import androidx.compose.ui.Modifier
|
import androidx.compose.ui.Modifier
|
||||||
import androidx.compose.ui.graphics.Color
|
import androidx.compose.ui.graphics.Color
|
||||||
import androidx.compose.ui.res.painterResource
|
import androidx.compose.ui.res.painterResource
|
||||||
|
import androidx.compose.ui.text.style.TextOverflow
|
||||||
import androidx.compose.ui.unit.dp
|
import androidx.compose.ui.unit.dp
|
||||||
import androidx.hilt.navigation.compose.hiltViewModel
|
import androidx.hilt.navigation.compose.hiltViewModel
|
||||||
import com.example.llama.revamp.data.model.ModelInfo
|
import com.example.llama.revamp.data.model.ModelInfo
|
||||||
|
|
@ -55,6 +66,7 @@ import java.text.SimpleDateFormat
|
||||||
import java.util.Date
|
import java.util.Date
|
||||||
import java.util.Locale
|
import java.util.Locale
|
||||||
import com.example.llama.R
|
import com.example.llama.R
|
||||||
|
import com.example.llama.revamp.viewmodel.ModelImportState
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Screen for managing LLM models (view, download, delete)
|
* Screen for managing LLM models (view, download, delete)
|
||||||
|
|
@ -66,13 +78,24 @@ fun ModelsManagementScreen(
|
||||||
) {
|
) {
|
||||||
val storageMetrics by viewModel.storageMetrics.collectAsState()
|
val storageMetrics by viewModel.storageMetrics.collectAsState()
|
||||||
val sortedModels by viewModel.sortedModels.collectAsState()
|
val sortedModels by viewModel.sortedModels.collectAsState()
|
||||||
|
|
||||||
|
// Model sorting
|
||||||
val sortOrder by viewModel.sortOrder.collectAsState()
|
val sortOrder by viewModel.sortOrder.collectAsState()
|
||||||
|
|
||||||
// UI: menu states
|
|
||||||
var showSortMenu by remember { mutableStateOf(false) }
|
var showSortMenu by remember { mutableStateOf(false) }
|
||||||
var showAddModelMenu by remember { mutableStateOf(false) }
|
|
||||||
|
|
||||||
// UI: multi-selection states
|
// Model importing
|
||||||
|
val importState by viewModel.importState.collectAsState()
|
||||||
|
|
||||||
|
var showImportModelMenu by remember { mutableStateOf(false) }
|
||||||
|
val fileLauncher = rememberLauncherForActivityResult(
|
||||||
|
contract = ActivityResultContracts.OpenDocument()
|
||||||
|
) { uri -> uri?.let { viewModel.importLocalModel(it) } }
|
||||||
|
|
||||||
|
BackHandler(enabled = importState is ModelImportState.Importing) {
|
||||||
|
/* Ignore back press while importing model */
|
||||||
|
}
|
||||||
|
|
||||||
|
// Multi-selection
|
||||||
var isMultiSelectionMode by remember { mutableStateOf(false) }
|
var isMultiSelectionMode by remember { mutableStateOf(false) }
|
||||||
val selectedModels = remember { mutableStateMapOf<String, ModelInfo>() }
|
val selectedModels = remember { mutableStateMapOf<String, ModelInfo>() }
|
||||||
val exitSelectionMode = {
|
val exitSelectionMode = {
|
||||||
|
|
@ -238,7 +261,7 @@ fun ModelsManagementScreen(
|
||||||
if (isMultiSelectionMode) {
|
if (isMultiSelectionMode) {
|
||||||
exitSelectionMode()
|
exitSelectionMode()
|
||||||
} else {
|
} else {
|
||||||
showAddModelMenu = true
|
showImportModelMenu = true
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
containerColor = MaterialTheme.colorScheme.primaryContainer
|
containerColor = MaterialTheme.colorScheme.primaryContainer
|
||||||
|
|
@ -251,8 +274,8 @@ fun ModelsManagementScreen(
|
||||||
|
|
||||||
// Add model dropdown menu
|
// Add model dropdown menu
|
||||||
DropdownMenu(
|
DropdownMenu(
|
||||||
expanded = showAddModelMenu,
|
expanded = showImportModelMenu,
|
||||||
onDismissRequest = { showAddModelMenu = false }
|
onDismissRequest = { showImportModelMenu = false }
|
||||||
) {
|
) {
|
||||||
DropdownMenuItem(
|
DropdownMenuItem(
|
||||||
text = { Text("Import local model") },
|
text = { Text("Import local model") },
|
||||||
|
|
@ -263,9 +286,8 @@ fun ModelsManagementScreen(
|
||||||
)
|
)
|
||||||
},
|
},
|
||||||
onClick = {
|
onClick = {
|
||||||
// TODO-han.yin: uncomment once file picker done
|
fileLauncher.launch(arrayOf("application/octet-stream", "*/*"))
|
||||||
// viewModel.importLocalModel()
|
showImportModelMenu = false
|
||||||
showAddModelMenu = false
|
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
DropdownMenuItem(
|
DropdownMenuItem(
|
||||||
|
|
@ -280,7 +302,7 @@ fun ModelsManagementScreen(
|
||||||
},
|
},
|
||||||
onClick = {
|
onClick = {
|
||||||
viewModel.importFromHuggingFace()
|
viewModel.importFromHuggingFace()
|
||||||
showAddModelMenu = false
|
showImportModelMenu = false
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
@ -288,38 +310,64 @@ fun ModelsManagementScreen(
|
||||||
)
|
)
|
||||||
},
|
},
|
||||||
) { paddingValues ->
|
) { paddingValues ->
|
||||||
// Main content
|
Box(modifier = Modifier.fillMaxSize()) {
|
||||||
ModelList(
|
// Model cards
|
||||||
models = sortedModels,
|
ModelCardList(
|
||||||
isMultiSelectionMode = isMultiSelectionMode,
|
models = sortedModels,
|
||||||
selectedModels = selectedModels,
|
isMultiSelectionMode = isMultiSelectionMode,
|
||||||
onModelClick = { modelId ->
|
selectedModels = selectedModels,
|
||||||
if (isMultiSelectionMode) {
|
onModelClick = { modelId ->
|
||||||
// Toggle selection
|
if (isMultiSelectionMode) {
|
||||||
if (selectedModels.contains(modelId)) {
|
// Toggle selection
|
||||||
selectedModels.remove(modelId)
|
if (selectedModels.contains(modelId)) {
|
||||||
|
selectedModels.remove(modelId)
|
||||||
|
} else {
|
||||||
|
selectedModels.put(modelId, sortedModels.first { it.id == modelId } )
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
selectedModels.put(modelId, sortedModels.first { it.id == modelId } )
|
// View model details
|
||||||
|
viewModel.viewModelDetails(modelId)
|
||||||
}
|
}
|
||||||
} else {
|
},
|
||||||
// View model details
|
onModelInfoClick = { modelId ->
|
||||||
viewModel.viewModelDetails(modelId)
|
viewModel.viewModelDetails(modelId)
|
||||||
|
},
|
||||||
|
onModelDeleteClick = { modelId ->
|
||||||
|
viewModel.deleteModel(modelId)
|
||||||
|
},
|
||||||
|
modifier = Modifier.padding(paddingValues)
|
||||||
|
)
|
||||||
|
|
||||||
|
// Model import progress overlay
|
||||||
|
when (val state = importState) {
|
||||||
|
is ModelImportState.Importing -> {
|
||||||
|
ImportProgressOverlay(
|
||||||
|
progress = state.progress,
|
||||||
|
filename = state.filename,
|
||||||
|
onCancel = { /* Implement cancellation if needed */ }
|
||||||
|
)
|
||||||
}
|
}
|
||||||
},
|
is ModelImportState.Error -> {
|
||||||
onModelInfoClick = { modelId ->
|
ErrorDialog(
|
||||||
viewModel.viewModelDetails(modelId)
|
message = state.message,
|
||||||
},
|
onDismiss = { viewModel.resetImportState() }
|
||||||
onModelDeleteClick = { modelId ->
|
)
|
||||||
viewModel.deleteModel(modelId)
|
}
|
||||||
},
|
is ModelImportState.Success -> {
|
||||||
modifier = Modifier.padding(paddingValues)
|
LaunchedEffect(state) {
|
||||||
)
|
// Show success snackbar or message
|
||||||
|
// This will auto-dismiss after the delay in viewModel
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else -> { /* Idle state, nothing to show */ }
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Composable
|
@Composable
|
||||||
private fun ModelList(
|
private fun ModelCardList(
|
||||||
models: List<ModelInfo>,
|
models: List<ModelInfo>,
|
||||||
isMultiSelectionMode: Boolean,
|
isMultiSelectionMode: Boolean,
|
||||||
selectedModels: Map<String, ModelInfo>,
|
selectedModels: Map<String, ModelInfo>,
|
||||||
|
|
@ -337,13 +385,16 @@ private fun ModelList(
|
||||||
items = models,
|
items = models,
|
||||||
key = { it.id }
|
key = { it.id }
|
||||||
) { model ->
|
) { model ->
|
||||||
ModelItem(
|
ModelCard(
|
||||||
model = model,
|
model = model,
|
||||||
isMultiSelectionMode = isMultiSelectionMode,
|
isMultiSelectionMode = isMultiSelectionMode,
|
||||||
isSelected = selectedModels.contains(model.id),
|
isSelected = selectedModels.contains(model.id),
|
||||||
onClick = { onModelClick(model.id) },
|
onClick = { onModelClick(model.id) },
|
||||||
onInfoClick = { onModelInfoClick(model.id) },
|
onInfoClick = { onModelInfoClick(model.id) },
|
||||||
onDeleteClick = { onModelDeleteClick(model.id) }
|
onDeleteClick = {
|
||||||
|
// TODO-han.yin: pop up an AlertDialog asking user for confirmation
|
||||||
|
onModelDeleteClick(model.id)
|
||||||
|
}
|
||||||
)
|
)
|
||||||
Spacer(modifier = Modifier.height(8.dp))
|
Spacer(modifier = Modifier.height(8.dp))
|
||||||
}
|
}
|
||||||
|
|
@ -351,7 +402,7 @@ private fun ModelList(
|
||||||
}
|
}
|
||||||
|
|
||||||
@Composable
|
@Composable
|
||||||
private fun ModelItem(
|
private fun ModelCard(
|
||||||
model: ModelInfo,
|
model: ModelInfo,
|
||||||
isMultiSelectionMode: Boolean,
|
isMultiSelectionMode: Boolean,
|
||||||
isSelected: Boolean,
|
isSelected: Boolean,
|
||||||
|
|
@ -424,3 +475,95 @@ private fun ModelItem(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Composable
|
||||||
|
fun ImportProgressOverlay(
|
||||||
|
progress: Float,
|
||||||
|
filename: String,
|
||||||
|
onCancel: () -> Unit
|
||||||
|
) {
|
||||||
|
Box(
|
||||||
|
modifier = Modifier
|
||||||
|
.fillMaxSize()
|
||||||
|
.background(Color.Black.copy(alpha = 0.7f))
|
||||||
|
.padding(32.dp),
|
||||||
|
contentAlignment = Alignment.Center
|
||||||
|
) {
|
||||||
|
Card(
|
||||||
|
modifier = Modifier
|
||||||
|
.fillMaxWidth()
|
||||||
|
.padding(16.dp),
|
||||||
|
elevation = CardDefaults.cardElevation(defaultElevation = 8.dp)
|
||||||
|
) {
|
||||||
|
Column(
|
||||||
|
modifier = Modifier.padding(24.dp),
|
||||||
|
horizontalAlignment = Alignment.CenterHorizontally
|
||||||
|
) {
|
||||||
|
Text(
|
||||||
|
text = "Importing Model",
|
||||||
|
style = MaterialTheme.typography.headlineSmall
|
||||||
|
)
|
||||||
|
|
||||||
|
Spacer(modifier = Modifier.height(8.dp))
|
||||||
|
|
||||||
|
Text(
|
||||||
|
text = filename,
|
||||||
|
style = MaterialTheme.typography.bodyMedium,
|
||||||
|
maxLines = 1,
|
||||||
|
overflow = TextOverflow.Ellipsis
|
||||||
|
)
|
||||||
|
|
||||||
|
Spacer(modifier = Modifier.height(24.dp))
|
||||||
|
|
||||||
|
LinearProgressIndicator(
|
||||||
|
progress = { progress },
|
||||||
|
modifier = Modifier.fillMaxWidth()
|
||||||
|
)
|
||||||
|
|
||||||
|
Spacer(modifier = Modifier.height(8.dp))
|
||||||
|
|
||||||
|
Text(
|
||||||
|
text = "${(progress * 100).toInt()}%",
|
||||||
|
style = MaterialTheme.typography.bodyLarge
|
||||||
|
)
|
||||||
|
|
||||||
|
Spacer(modifier = Modifier.height(16.dp))
|
||||||
|
|
||||||
|
Text(
|
||||||
|
text = "This may take several minutes for large models",
|
||||||
|
style = MaterialTheme.typography.bodySmall,
|
||||||
|
color = MaterialTheme.colorScheme.onSurfaceVariant
|
||||||
|
)
|
||||||
|
|
||||||
|
Spacer(modifier = Modifier.height(24.dp))
|
||||||
|
|
||||||
|
Button(
|
||||||
|
onClick = onCancel,
|
||||||
|
colors = ButtonDefaults.buttonColors(
|
||||||
|
containerColor = MaterialTheme.colorScheme.errorContainer,
|
||||||
|
contentColor = MaterialTheme.colorScheme.onErrorContainer
|
||||||
|
)
|
||||||
|
) {
|
||||||
|
Text("Cancel")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Composable
|
||||||
|
fun ErrorDialog(
|
||||||
|
message: String,
|
||||||
|
onDismiss: () -> Unit
|
||||||
|
) {
|
||||||
|
AlertDialog(
|
||||||
|
onDismissRequest = onDismiss,
|
||||||
|
title = { Text("Import Failed") },
|
||||||
|
text = { Text(message) },
|
||||||
|
confirmButton = {
|
||||||
|
Button(onClick = onDismiss) {
|
||||||
|
Text("OK")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,13 +1,13 @@
|
||||||
package com.example.llama.revamp.viewmodel
|
package com.example.llama.revamp.viewmodel
|
||||||
|
|
||||||
import android.net.Uri
|
import android.net.Uri
|
||||||
import android.util.Log
|
|
||||||
import androidx.lifecycle.ViewModel
|
import androidx.lifecycle.ViewModel
|
||||||
import androidx.lifecycle.viewModelScope
|
import androidx.lifecycle.viewModelScope
|
||||||
import com.example.llama.revamp.data.model.ModelInfo
|
import com.example.llama.revamp.data.model.ModelInfo
|
||||||
import com.example.llama.revamp.data.repository.ModelRepository
|
import com.example.llama.revamp.data.repository.ModelRepository
|
||||||
import com.example.llama.revamp.data.repository.StorageMetrics
|
import com.example.llama.revamp.data.repository.StorageMetrics
|
||||||
import dagger.hilt.android.lifecycle.HiltViewModel
|
import dagger.hilt.android.lifecycle.HiltViewModel
|
||||||
|
import kotlinx.coroutines.delay
|
||||||
import kotlinx.coroutines.flow.MutableStateFlow
|
import kotlinx.coroutines.flow.MutableStateFlow
|
||||||
import kotlinx.coroutines.flow.SharingStarted
|
import kotlinx.coroutines.flow.SharingStarted
|
||||||
import kotlinx.coroutines.flow.StateFlow
|
import kotlinx.coroutines.flow.StateFlow
|
||||||
|
|
@ -76,16 +76,34 @@ class ModelsManagementViewModel @Inject constructor(
|
||||||
modelRepository.deleteModels(models.keys)
|
modelRepository.deleteModels(models.keys)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private val _importState = MutableStateFlow<ModelImportState>(ModelImportState.Idle)
|
||||||
|
val importState: StateFlow<ModelImportState> = _importState.asStateFlow()
|
||||||
|
|
||||||
fun importLocalModel(uri: Uri) =
|
fun importLocalModel(uri: Uri) =
|
||||||
viewModelScope.launch {
|
viewModelScope.launch {
|
||||||
try {
|
try {
|
||||||
modelRepository.importModel(uri)
|
// Get filename for progress updates
|
||||||
|
val filename = uri.lastPathSegment ?: throw Exception("Model name unknown")
|
||||||
|
_importState.value = ModelImportState.Importing(0f, filename)
|
||||||
|
|
||||||
|
// Import with progress reporting
|
||||||
|
val model = modelRepository.importModel(uri) { progress ->
|
||||||
|
_importState.value = ModelImportState.Importing(progress, filename)
|
||||||
|
}
|
||||||
|
_importState.value = ModelImportState.Success(model)
|
||||||
|
|
||||||
|
// Reset state after a delay
|
||||||
|
delay(1000)
|
||||||
|
_importState.value = ModelImportState.Idle
|
||||||
} catch (e: Exception) {
|
} catch (e: Exception) {
|
||||||
// TODO-han.yin: add UI to prompt user about import failure!
|
_importState.value = ModelImportState.Error(e.message ?: "Unknown error")
|
||||||
Log.e(TAG, "Failed to import model from: $uri", e)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fun resetImportState() {
|
||||||
|
_importState.value = ModelImportState.Idle
|
||||||
|
}
|
||||||
|
|
||||||
fun importFromHuggingFace() {
|
fun importFromHuggingFace() {
|
||||||
// TODO-han.yin: Stub for now. Would need to investigate HuggingFace APIs
|
// TODO-han.yin: Stub for now. Would need to investigate HuggingFace APIs
|
||||||
}
|
}
|
||||||
|
|
@ -105,4 +123,9 @@ enum class ModelSortOrder {
|
||||||
LAST_USED
|
LAST_USED
|
||||||
}
|
}
|
||||||
|
|
||||||
|
sealed class ModelImportState {
|
||||||
|
object Idle : ModelImportState()
|
||||||
|
data class Importing(val progress: Float = 0f, val filename: String = "") : ModelImportState()
|
||||||
|
data class Success(val model: ModelInfo) : ModelImportState()
|
||||||
|
data class Error(val message: String) : ModelImportState()
|
||||||
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue