[WIP] ui: polish Benchmark screen; implement its bottom app bar

This commit is contained in:
Han Yin 2025-04-21 15:47:44 -07:00
parent ec907d01ba
commit 32f37a4316
5 changed files with 234 additions and 91 deletions

View File

@ -1,5 +1,6 @@
package com.example.llama.revamp package com.example.llama.revamp
import android.llama.cpp.isUninterruptible
import android.os.Bundle import android.os.Bundle
import androidx.activity.ComponentActivity import androidx.activity.ComponentActivity
import androidx.activity.compose.rememberLauncherForActivityResult import androidx.activity.compose.rememberLauncherForActivityResult
@ -115,6 +116,36 @@ fun AppContent(
} }
val openDrawer: () -> Unit = { coroutineScope.launch { drawerState.open() } } val openDrawer: () -> Unit = { coroutineScope.launch { drawerState.open() } }
// Handle child screens' scaffold events
val handleScaffoldEvent: (ScaffoldEvent) -> Unit = { event ->
when (event) {
is ScaffoldEvent.ShowSnackbar -> {
coroutineScope.launch {
if (event.actionLabel != null && event.onAction != null) {
val result = snackbarHostState.showSnackbar(
message = event.message,
actionLabel = event.actionLabel,
withDismissAction = event.withDismissAction,
duration = event.duration
)
if (result == SnackbarResult.ActionPerformed) {
event.onAction()
}
} else {
snackbarHostState.showSnackbar(
message = event.message,
withDismissAction = event.withDismissAction,
duration = event.duration
)
}
}
}
is ScaffoldEvent.ChangeTitle -> {
// TODO-han.yin: TBD
}
}
}
// Create scaffold's top & bottom bar configs based on current route // Create scaffold's top & bottom bar configs based on current route
val scaffoldConfig = when { val scaffoldConfig = when {
// Model selection screen // Model selection screen
@ -131,7 +162,7 @@ fun AppContent(
topBarConfig = topBarConfig =
if (isSearchActive) TopBarConfig.None() if (isSearchActive) TopBarConfig.None()
else TopBarConfig.Default( else TopBarConfig.Default(
title = "Select a Model", title = "Pick your model",
navigationIcon = NavigationIcon.Menu { navigationIcon = NavigationIcon.Menu {
modelSelectionViewModel.resetSelection() modelSelectionViewModel.resetSelection()
openDrawer() openDrawer()
@ -177,9 +208,9 @@ fun AppContent(
currentRoute == AppDestinations.MODEL_LOADING_ROUTE -> currentRoute == AppDestinations.MODEL_LOADING_ROUTE ->
ScaffoldConfig( ScaffoldConfig(
topBarConfig = TopBarConfig.Performance( topBarConfig = TopBarConfig.Performance(
title = "Load Model", title = "Select a mode",
navigationIcon = NavigationIcon.Back { navigationIcon = NavigationIcon.Back {
benchmarkViewModel.onBackPressed { navigationActions.navigateUp() } modelLoadingViewModel.onBackPressed { navigationActions.navigateUp() }
}, },
memoryMetrics = memoryUsage, memoryMetrics = memoryUsage,
temperatureInfo = null temperatureInfo = null
@ -187,7 +218,9 @@ fun AppContent(
) )
// Benchmark screen // Benchmark screen
currentRoute.startsWith(AppDestinations.BENCHMARK_ROUTE) -> currentRoute.startsWith(AppDestinations.BENCHMARK_ROUTE) -> {
val engineState by benchmarkViewModel.engineState.collectAsState()
ScaffoldConfig( ScaffoldConfig(
topBarConfig = TopBarConfig.Performance( topBarConfig = TopBarConfig.Performance(
title = "Benchmark", title = "Benchmark",
@ -196,8 +229,23 @@ fun AppContent(
}, },
memoryMetrics = memoryUsage, memoryMetrics = memoryUsage,
temperatureInfo = Pair(temperatureInfo, useFahrenheit) temperatureInfo = Pair(temperatureInfo, useFahrenheit)
),
bottomBarConfig = BottomBarConfig.Benchmark(
engineIdle = !engineState.isUninterruptible,
onRerun = {
if (engineState.isUninterruptible) {
handleScaffoldEvent(ScaffoldEvent.ShowSnackbar(
message = "Benchmark already in progress!\n" +
"Please wait for the current run to complete."
))
} else {
benchmarkViewModel.runBenchmark()
}
},
onShare = benchmarkViewModel::shareResults,
) )
) )
}
// Conversation screen // Conversation screen
currentRoute.startsWith(AppDestinations.CONVERSATION_ROUTE) -> currentRoute.startsWith(AppDestinations.CONVERSATION_ROUTE) ->
@ -296,36 +344,6 @@ fun AppContent(
) )
} }
// Handle child screens' scaffold events
val handleScaffoldEvent: (ScaffoldEvent) -> Unit = { event ->
when (event) {
is ScaffoldEvent.ShowSnackbar -> {
coroutineScope.launch {
if (event.actionLabel != null && event.onAction != null) {
val result = snackbarHostState.showSnackbar(
message = event.message,
actionLabel = event.actionLabel,
withDismissAction = event.withDismissAction,
duration = event.duration
)
if (result == SnackbarResult.ActionPerformed) {
event.onAction()
}
} else {
snackbarHostState.showSnackbar(
message = event.message,
withDismissAction = event.withDismissAction,
duration = event.duration
)
}
}
}
is ScaffoldEvent.ChangeTitle -> {
// TODO-han.yin: TBD
}
}
}
// Main UI hierarchy // Main UI hierarchy
AppNavigationDrawer( AppNavigationDrawer(
drawerState = drawerState, drawerState = drawerState,

View File

@ -85,6 +85,14 @@ fun AppScaffold(
importing = config.importing, importing = config.importing,
) )
} }
is BottomBarConfig.Benchmark -> {
BenchmarkBottomBar(
engineIdle = config.engineIdle,
onRerun = config.onRerun,
onShare = config.onShare
)
}
} }
} }

View File

@ -20,9 +20,11 @@ import androidx.compose.material.icons.filled.Delete
import androidx.compose.material.icons.filled.FilterAlt import androidx.compose.material.icons.filled.FilterAlt
import androidx.compose.material.icons.filled.FolderOpen import androidx.compose.material.icons.filled.FolderOpen
import androidx.compose.material.icons.filled.PlayArrow import androidx.compose.material.icons.filled.PlayArrow
import androidx.compose.material.icons.filled.Replay
import androidx.compose.material.icons.filled.Search import androidx.compose.material.icons.filled.Search
import androidx.compose.material.icons.filled.SearchOff import androidx.compose.material.icons.filled.SearchOff
import androidx.compose.material.icons.filled.SelectAll import androidx.compose.material.icons.filled.SelectAll
import androidx.compose.material.icons.filled.Share
import androidx.compose.material.icons.outlined.DeleteSweep import androidx.compose.material.icons.outlined.DeleteSweep
import androidx.compose.material.icons.outlined.FilterAlt import androidx.compose.material.icons.outlined.FilterAlt
import androidx.compose.material.icons.outlined.FilterAltOff import androidx.compose.material.icons.outlined.FilterAltOff
@ -126,6 +128,12 @@ sealed class BottomBarConfig {
) )
} }
data class Benchmark(
val engineIdle: Boolean,
val onRerun: () -> Unit,
val onShare: () -> Unit,
) : BottomBarConfig()
// TODO-han.yin: add bottom bar config for Conversation Screen! // TODO-han.yin: add bottom bar config for Conversation Screen!
} }
@ -451,3 +459,42 @@ fun ModelsManagementBottomBar(
} }
) )
} }
@Composable
fun BenchmarkBottomBar(
engineIdle: Boolean,
onRerun: () -> Unit,
onShare: () -> Unit,
) {
BottomAppBar(
actions = {
IconButton(onClick = onRerun) {
Icon(
imageVector = Icons.Default.Replay,
contentDescription = "Run the benchmark again",
tint =
if (engineIdle) MaterialTheme.colorScheme.onSurface
else MaterialTheme.colorScheme.onSurfaceVariant.copy(alpha = 0.3f)
)
}
},
floatingActionButton = {
// Only show FAB if the benchmark result is ready
AnimatedVisibility(
visible = engineIdle,
enter = scaleIn() + fadeIn(),
exit = scaleOut() + fadeOut()
) {
FloatingActionButton(
onClick = onShare,
containerColor = MaterialTheme.colorScheme.primary
) {
Icon(
imageVector = Icons.Default.Share,
contentDescription = "Share the benchmark results"
)
}
}
}
)
}

View File

@ -3,19 +3,26 @@ package com.example.llama.revamp.ui.screens
import android.llama.cpp.InferenceEngine.State import android.llama.cpp.InferenceEngine.State
import androidx.activity.compose.BackHandler import androidx.activity.compose.BackHandler
import androidx.compose.foundation.background import androidx.compose.foundation.background
import androidx.compose.foundation.layout.Arrangement
import androidx.compose.foundation.layout.Box import androidx.compose.foundation.layout.Box
import androidx.compose.foundation.layout.Column import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.PaddingValues
import androidx.compose.foundation.layout.Spacer import androidx.compose.foundation.layout.Spacer
import androidx.compose.foundation.layout.fillMaxSize import androidx.compose.foundation.layout.fillMaxSize
import androidx.compose.foundation.layout.fillMaxWidth import androidx.compose.foundation.layout.fillMaxWidth
import androidx.compose.foundation.layout.height import androidx.compose.foundation.layout.height
import androidx.compose.foundation.layout.padding import androidx.compose.foundation.layout.padding
import androidx.compose.foundation.layout.size
import androidx.compose.foundation.lazy.LazyColumn
import androidx.compose.foundation.lazy.items
import androidx.compose.foundation.rememberScrollState import androidx.compose.foundation.rememberScrollState
import androidx.compose.foundation.shape.RoundedCornerShape import androidx.compose.foundation.shape.RoundedCornerShape
import androidx.compose.foundation.verticalScroll import androidx.compose.foundation.verticalScroll
import androidx.compose.material3.Card import androidx.compose.material3.Card
import androidx.compose.material3.CardDefaults
import androidx.compose.material3.CircularProgressIndicator import androidx.compose.material3.CircularProgressIndicator
import androidx.compose.material3.MaterialTheme import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.ProgressIndicatorDefaults
import androidx.compose.material3.Text import androidx.compose.material3.Text
import androidx.compose.runtime.Composable import androidx.compose.runtime.Composable
import androidx.compose.runtime.LaunchedEffect import androidx.compose.runtime.LaunchedEffect
@ -26,6 +33,7 @@ import androidx.compose.runtime.remember
import androidx.compose.runtime.setValue import androidx.compose.runtime.setValue
import androidx.compose.ui.Alignment import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier import androidx.compose.ui.Modifier
import androidx.compose.ui.text.style.TextAlign
import androidx.compose.ui.unit.dp import androidx.compose.ui.unit.dp
import com.example.llama.revamp.data.model.ModelInfo import com.example.llama.revamp.data.model.ModelInfo
import com.example.llama.revamp.engine.ModelLoadingMetrics import com.example.llama.revamp.engine.ModelLoadingMetrics
@ -64,77 +72,95 @@ fun BenchmarkScreen(
} }
Column( Column(
modifier = Modifier modifier = Modifier.fillMaxSize().verticalScroll(rememberScrollState())
.fillMaxSize()
.padding(16.dp)
.verticalScroll(rememberScrollState())
) { ) {
// Selected model card // Selected model card
selectedModel?.let { model -> selectedModel?.let { model ->
ModelCardWithLoadingMetrics( Box(
model = model, modifier = Modifier.padding(start = 16.dp, top = 16.dp, end = 16.dp)
loadingMetrics = loadingMetrics, ) {
isExpanded = isModelCardExpanded, ModelCardWithLoadingMetrics(
onExpanded = { isModelCardExpanded = !isModelCardExpanded }, model = model,
) loadingMetrics = loadingMetrics,
isExpanded = isModelCardExpanded,
onExpanded = { isModelCardExpanded = !isModelCardExpanded },
)
}
} }
// Benchmark results or loading indicator Box(
when { modifier = Modifier.fillMaxWidth().weight(1f),
engineState is State.Benchmarking -> { contentAlignment = Alignment.Center
Box( ) {
modifier = Modifier // Benchmark results
.fillMaxWidth() LazyColumn(
.height(200.dp), modifier = Modifier.fillMaxSize(),
contentAlignment = Alignment.Center verticalArrangement = Arrangement.spacedBy(16.dp),
) { contentPadding = PaddingValues(horizontal = 16.dp, vertical = 16.dp),
Column(horizontalAlignment = Alignment.CenterHorizontally) { ) {
CircularProgressIndicator() items(items = benchmarkResults) { result ->
Spacer(modifier = Modifier.height(16.dp)) Card(
Text( modifier = Modifier.fillMaxWidth()
text = "Running benchmark...", ) {
style = MaterialTheme.typography.bodyMedium Column(
) modifier = Modifier
.fillMaxWidth()
.background(
color = MaterialTheme.colorScheme.surfaceVariant,
shape = RoundedCornerShape(8.dp)
)
.padding(16.dp)
) {
Text(
text = result.text,
style = MonospacedTextStyle,
color = MaterialTheme.colorScheme.onSurfaceVariant
)
Spacer(modifier = Modifier.height(4.dp))
ModelCardContentField("Time spent: ", formatMilliSeconds(result.duration))
}
} }
} }
} }
benchmarkResults != null -> { // Loading indicator
if (engineState is State.Benchmarking) {
Card( Card(
modifier = Modifier.fillMaxWidth() modifier = Modifier.align(Alignment.Center),
colors = CardDefaults.cardColors(
containerColor = MaterialTheme.colorScheme.primaryContainer
),
shape = MaterialTheme.shapes.extraLarge
) { ) {
Box( Column(
modifier = Modifier modifier = Modifier.padding(horizontal = 32.dp, vertical = 48.dp),
.fillMaxWidth() horizontalAlignment = Alignment.CenterHorizontally
.background(
color = MaterialTheme.colorScheme.surfaceVariant,
shape = RoundedCornerShape(8.dp)
)
.padding(16.dp)
) { ) {
CircularProgressIndicator(
modifier = Modifier.size(64.dp),
strokeWidth = ProgressIndicatorDefaults.CircularStrokeWidth * 1.5f
)
Spacer(modifier = Modifier.height(16.dp))
Text( Text(
text = benchmarkResults ?: "", text = "Running benchmark...",
style = MonospacedTextStyle, style = MaterialTheme.typography.headlineSmall
)
Spacer(modifier = Modifier.height(8.dp))
Text(
text = "This usually takes a few minutes",
style = MaterialTheme.typography.bodyLarge,
textAlign = TextAlign.Center,
color = MaterialTheme.colorScheme.onSurfaceVariant color = MaterialTheme.colorScheme.onSurfaceVariant
) )
} }
} }
} }
else -> {
Box(
modifier = Modifier
.fillMaxWidth()
.height(200.dp),
contentAlignment = Alignment.Center
) {
Text(
text = "Benchmark results will appear here",
style = MaterialTheme.typography.bodyMedium,
color = MaterialTheme.colorScheme.onSurfaceVariant
)
}
}
} }
} }

View File

@ -1,10 +1,18 @@
package com.example.llama.revamp.viewmodel package com.example.llama.revamp.viewmodel
import android.llama.cpp.isUninterruptible
import androidx.lifecycle.viewModelScope import androidx.lifecycle.viewModelScope
import com.example.llama.revamp.data.model.ModelInfo import com.example.llama.revamp.data.model.ModelInfo
import com.example.llama.revamp.engine.BenchmarkService import com.example.llama.revamp.engine.BenchmarkService
import dagger.hilt.android.lifecycle.HiltViewModel import dagger.hilt.android.lifecycle.HiltViewModel
import kotlinx.coroutines.flow.MutableSharedFlow
import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.flow.StateFlow import kotlinx.coroutines.flow.StateFlow
import kotlinx.coroutines.flow.asStateFlow
import kotlinx.coroutines.flow.collect
import kotlinx.coroutines.flow.filterNotNull
import kotlinx.coroutines.flow.update
import kotlinx.coroutines.flow.zip
import kotlinx.coroutines.launch import kotlinx.coroutines.launch
import javax.inject.Inject import javax.inject.Inject
@ -16,13 +24,49 @@ class BenchmarkViewModel @Inject constructor(
* UI states * UI states
*/ */
val selectedModel: StateFlow<ModelInfo?> = benchmarkService.currentSelectedModel val selectedModel: StateFlow<ModelInfo?> = benchmarkService.currentSelectedModel
val benchmarkResults: StateFlow<String?> = benchmarkService.benchmarkResults
private val _benchmarkDuration = MutableSharedFlow<Long>()
private val _benchmarkResults = MutableStateFlow<List<BenchmarkResult>>(emptyList())
val benchmarkResults: StateFlow<List<BenchmarkResult>> = _benchmarkResults.asStateFlow()
init {
viewModelScope.launch {
benchmarkService.benchmarkResults
.filterNotNull()
.zip(_benchmarkDuration) { result, duration ->
_benchmarkResults.update { oldResults ->
oldResults.toMutableList().apply {
add(BenchmarkResult(result, duration))
}
}
}.collect()
}
}
/** /**
* Run benchmark with specified parameters * Run benchmark with specified parameters
*/ */
fun runBenchmark(pp: Int = 512, tg: Int = 128, pl: Int = 1, nr: Int = 3) = fun runBenchmark(pp: Int = 512, tg: Int = 128, pl: Int = 1, nr: Int = 3): Boolean {
viewModelScope.launch { if (engineState.value.isUninterruptible) {
benchmarkService.benchmark(pp, tg, pl, nr) return false
} }
viewModelScope.launch {
val benchmarkStartTs = System.currentTimeMillis()
benchmarkService.benchmark(pp, tg, pl, nr)
val benchmarkEndTs = System.currentTimeMillis()
_benchmarkDuration.emit(benchmarkEndTs - benchmarkStartTs)
}
return true
}
fun shareResults() {
// TODO-han.yin: TO BE IMPLEMENTED
}
} }
data class BenchmarkResult(
val text: String,
val duration: Long
)