diff --git a/examples/llama.android/app/src/main/java/com/example/llama/revamp/MainActivity.kt b/examples/llama.android/app/src/main/java/com/example/llama/revamp/MainActivity.kt index 01d9167a59..ad50828cf8 100644 --- a/examples/llama.android/app/src/main/java/com/example/llama/revamp/MainActivity.kt +++ b/examples/llama.android/app/src/main/java/com/example/llama/revamp/MainActivity.kt @@ -1,5 +1,6 @@ package com.example.llama.revamp +import android.llama.cpp.isUninterruptible import android.os.Bundle import androidx.activity.ComponentActivity import androidx.activity.compose.rememberLauncherForActivityResult @@ -115,6 +116,36 @@ fun AppContent( } val openDrawer: () -> Unit = { coroutineScope.launch { drawerState.open() } } + // Handle child screens' scaffold events + val handleScaffoldEvent: (ScaffoldEvent) -> Unit = { event -> + when (event) { + is ScaffoldEvent.ShowSnackbar -> { + coroutineScope.launch { + if (event.actionLabel != null && event.onAction != null) { + val result = snackbarHostState.showSnackbar( + message = event.message, + actionLabel = event.actionLabel, + withDismissAction = event.withDismissAction, + duration = event.duration + ) + if (result == SnackbarResult.ActionPerformed) { + event.onAction() + } + } else { + snackbarHostState.showSnackbar( + message = event.message, + withDismissAction = event.withDismissAction, + duration = event.duration + ) + } + } + } + is ScaffoldEvent.ChangeTitle -> { + // TODO-han.yin: TBD + } + } + } + // Create scaffold's top & bottom bar configs based on current route val scaffoldConfig = when { // Model selection screen @@ -131,7 +162,7 @@ fun AppContent( topBarConfig = if (isSearchActive) TopBarConfig.None() else TopBarConfig.Default( - title = "Select a Model", + title = "Pick your model", navigationIcon = NavigationIcon.Menu { modelSelectionViewModel.resetSelection() openDrawer() @@ -177,9 +208,9 @@ fun AppContent( currentRoute == AppDestinations.MODEL_LOADING_ROUTE -> ScaffoldConfig( topBarConfig = TopBarConfig.Performance( - title = "Load Model", + title = "Select a mode", navigationIcon = NavigationIcon.Back { - benchmarkViewModel.onBackPressed { navigationActions.navigateUp() } + modelLoadingViewModel.onBackPressed { navigationActions.navigateUp() } }, memoryMetrics = memoryUsage, temperatureInfo = null @@ -187,7 +218,9 @@ fun AppContent( ) // Benchmark screen - currentRoute.startsWith(AppDestinations.BENCHMARK_ROUTE) -> + currentRoute.startsWith(AppDestinations.BENCHMARK_ROUTE) -> { + val engineState by benchmarkViewModel.engineState.collectAsState() + ScaffoldConfig( topBarConfig = TopBarConfig.Performance( title = "Benchmark", @@ -196,8 +229,23 @@ fun AppContent( }, memoryMetrics = memoryUsage, temperatureInfo = Pair(temperatureInfo, useFahrenheit) + ), + bottomBarConfig = BottomBarConfig.Benchmark( + engineIdle = !engineState.isUninterruptible, + onRerun = { + if (engineState.isUninterruptible) { + handleScaffoldEvent(ScaffoldEvent.ShowSnackbar( + message = "Benchmark already in progress!\n" + + "Please wait for the current run to complete." + )) + } else { + benchmarkViewModel.runBenchmark() + } + }, + onShare = benchmarkViewModel::shareResults, ) ) + } // Conversation screen currentRoute.startsWith(AppDestinations.CONVERSATION_ROUTE) -> @@ -296,36 +344,6 @@ fun AppContent( ) } - // Handle child screens' scaffold events - val handleScaffoldEvent: (ScaffoldEvent) -> Unit = { event -> - when (event) { - is ScaffoldEvent.ShowSnackbar -> { - coroutineScope.launch { - if (event.actionLabel != null && event.onAction != null) { - val result = snackbarHostState.showSnackbar( - message = event.message, - actionLabel = event.actionLabel, - withDismissAction = event.withDismissAction, - duration = event.duration - ) - if (result == SnackbarResult.ActionPerformed) { - event.onAction() - } - } else { - snackbarHostState.showSnackbar( - message = event.message, - withDismissAction = event.withDismissAction, - duration = event.duration - ) - } - } - } - is ScaffoldEvent.ChangeTitle -> { - // TODO-han.yin: TBD - } - } - } - // Main UI hierarchy AppNavigationDrawer( drawerState = drawerState, diff --git a/examples/llama.android/app/src/main/java/com/example/llama/revamp/ui/scaffold/AppScaffold.kt b/examples/llama.android/app/src/main/java/com/example/llama/revamp/ui/scaffold/AppScaffold.kt index 1c034b2a6a..56ee068897 100644 --- a/examples/llama.android/app/src/main/java/com/example/llama/revamp/ui/scaffold/AppScaffold.kt +++ b/examples/llama.android/app/src/main/java/com/example/llama/revamp/ui/scaffold/AppScaffold.kt @@ -85,6 +85,14 @@ fun AppScaffold( importing = config.importing, ) } + + is BottomBarConfig.Benchmark -> { + BenchmarkBottomBar( + engineIdle = config.engineIdle, + onRerun = config.onRerun, + onShare = config.onShare + ) + } } } diff --git a/examples/llama.android/app/src/main/java/com/example/llama/revamp/ui/scaffold/BottomAppBars.kt b/examples/llama.android/app/src/main/java/com/example/llama/revamp/ui/scaffold/BottomAppBars.kt index 061a2946e2..fdf80a9616 100644 --- a/examples/llama.android/app/src/main/java/com/example/llama/revamp/ui/scaffold/BottomAppBars.kt +++ b/examples/llama.android/app/src/main/java/com/example/llama/revamp/ui/scaffold/BottomAppBars.kt @@ -20,9 +20,11 @@ import androidx.compose.material.icons.filled.Delete import androidx.compose.material.icons.filled.FilterAlt import androidx.compose.material.icons.filled.FolderOpen import androidx.compose.material.icons.filled.PlayArrow +import androidx.compose.material.icons.filled.Replay import androidx.compose.material.icons.filled.Search import androidx.compose.material.icons.filled.SearchOff import androidx.compose.material.icons.filled.SelectAll +import androidx.compose.material.icons.filled.Share import androidx.compose.material.icons.outlined.DeleteSweep import androidx.compose.material.icons.outlined.FilterAlt import androidx.compose.material.icons.outlined.FilterAltOff @@ -126,6 +128,12 @@ sealed class BottomBarConfig { ) } + data class Benchmark( + val engineIdle: Boolean, + val onRerun: () -> Unit, + val onShare: () -> Unit, + ) : BottomBarConfig() + // TODO-han.yin: add bottom bar config for Conversation Screen! } @@ -451,3 +459,42 @@ fun ModelsManagementBottomBar( } ) } + +@Composable +fun BenchmarkBottomBar( + engineIdle: Boolean, + onRerun: () -> Unit, + onShare: () -> Unit, +) { + BottomAppBar( + actions = { + IconButton(onClick = onRerun) { + Icon( + imageVector = Icons.Default.Replay, + contentDescription = "Run the benchmark again", + tint = + if (engineIdle) MaterialTheme.colorScheme.onSurface + else MaterialTheme.colorScheme.onSurfaceVariant.copy(alpha = 0.3f) + ) + } + }, + floatingActionButton = { + // Only show FAB if the benchmark result is ready + AnimatedVisibility( + visible = engineIdle, + enter = scaleIn() + fadeIn(), + exit = scaleOut() + fadeOut() + ) { + FloatingActionButton( + onClick = onShare, + containerColor = MaterialTheme.colorScheme.primary + ) { + Icon( + imageVector = Icons.Default.Share, + contentDescription = "Share the benchmark results" + ) + } + } + } + ) +} diff --git a/examples/llama.android/app/src/main/java/com/example/llama/revamp/ui/screens/BenchmarkScreen.kt b/examples/llama.android/app/src/main/java/com/example/llama/revamp/ui/screens/BenchmarkScreen.kt index fce93d89e0..106581cc4b 100644 --- a/examples/llama.android/app/src/main/java/com/example/llama/revamp/ui/screens/BenchmarkScreen.kt +++ b/examples/llama.android/app/src/main/java/com/example/llama/revamp/ui/screens/BenchmarkScreen.kt @@ -3,19 +3,26 @@ package com.example.llama.revamp.ui.screens import android.llama.cpp.InferenceEngine.State import androidx.activity.compose.BackHandler import androidx.compose.foundation.background +import androidx.compose.foundation.layout.Arrangement import androidx.compose.foundation.layout.Box import androidx.compose.foundation.layout.Column +import androidx.compose.foundation.layout.PaddingValues import androidx.compose.foundation.layout.Spacer import androidx.compose.foundation.layout.fillMaxSize import androidx.compose.foundation.layout.fillMaxWidth import androidx.compose.foundation.layout.height import androidx.compose.foundation.layout.padding +import androidx.compose.foundation.layout.size +import androidx.compose.foundation.lazy.LazyColumn +import androidx.compose.foundation.lazy.items import androidx.compose.foundation.rememberScrollState import androidx.compose.foundation.shape.RoundedCornerShape import androidx.compose.foundation.verticalScroll import androidx.compose.material3.Card +import androidx.compose.material3.CardDefaults import androidx.compose.material3.CircularProgressIndicator import androidx.compose.material3.MaterialTheme +import androidx.compose.material3.ProgressIndicatorDefaults import androidx.compose.material3.Text import androidx.compose.runtime.Composable import androidx.compose.runtime.LaunchedEffect @@ -26,6 +33,7 @@ import androidx.compose.runtime.remember import androidx.compose.runtime.setValue import androidx.compose.ui.Alignment import androidx.compose.ui.Modifier +import androidx.compose.ui.text.style.TextAlign import androidx.compose.ui.unit.dp import com.example.llama.revamp.data.model.ModelInfo import com.example.llama.revamp.engine.ModelLoadingMetrics @@ -64,77 +72,95 @@ fun BenchmarkScreen( } Column( - modifier = Modifier - .fillMaxSize() - .padding(16.dp) - .verticalScroll(rememberScrollState()) + modifier = Modifier.fillMaxSize().verticalScroll(rememberScrollState()) ) { // Selected model card selectedModel?.let { model -> - ModelCardWithLoadingMetrics( - model = model, - loadingMetrics = loadingMetrics, - isExpanded = isModelCardExpanded, - onExpanded = { isModelCardExpanded = !isModelCardExpanded }, - ) + Box( + modifier = Modifier.padding(start = 16.dp, top = 16.dp, end = 16.dp) + ) { + ModelCardWithLoadingMetrics( + model = model, + loadingMetrics = loadingMetrics, + isExpanded = isModelCardExpanded, + onExpanded = { isModelCardExpanded = !isModelCardExpanded }, + ) + } } - // Benchmark results or loading indicator - when { - engineState is State.Benchmarking -> { - Box( - modifier = Modifier - .fillMaxWidth() - .height(200.dp), - contentAlignment = Alignment.Center - ) { - Column(horizontalAlignment = Alignment.CenterHorizontally) { - CircularProgressIndicator() - Spacer(modifier = Modifier.height(16.dp)) - Text( - text = "Running benchmark...", - style = MaterialTheme.typography.bodyMedium - ) + Box( + modifier = Modifier.fillMaxWidth().weight(1f), + contentAlignment = Alignment.Center + ) { + // Benchmark results + LazyColumn( + modifier = Modifier.fillMaxSize(), + verticalArrangement = Arrangement.spacedBy(16.dp), + contentPadding = PaddingValues(horizontal = 16.dp, vertical = 16.dp), + ) { + items(items = benchmarkResults) { result -> + Card( + modifier = Modifier.fillMaxWidth() + ) { + Column( + modifier = Modifier + .fillMaxWidth() + .background( + color = MaterialTheme.colorScheme.surfaceVariant, + shape = RoundedCornerShape(8.dp) + ) + .padding(16.dp) + ) { + Text( + text = result.text, + style = MonospacedTextStyle, + color = MaterialTheme.colorScheme.onSurfaceVariant + ) + + Spacer(modifier = Modifier.height(4.dp)) + + ModelCardContentField("Time spent: ", formatMilliSeconds(result.duration)) + } } } } - benchmarkResults != null -> { + // Loading indicator + if (engineState is State.Benchmarking) { Card( - modifier = Modifier.fillMaxWidth() + modifier = Modifier.align(Alignment.Center), + colors = CardDefaults.cardColors( + containerColor = MaterialTheme.colorScheme.primaryContainer + ), + shape = MaterialTheme.shapes.extraLarge ) { - Box( - modifier = Modifier - .fillMaxWidth() - .background( - color = MaterialTheme.colorScheme.surfaceVariant, - shape = RoundedCornerShape(8.dp) - ) - .padding(16.dp) + Column( + modifier = Modifier.padding(horizontal = 32.dp, vertical = 48.dp), + horizontalAlignment = Alignment.CenterHorizontally ) { + CircularProgressIndicator( + modifier = Modifier.size(64.dp), + strokeWidth = ProgressIndicatorDefaults.CircularStrokeWidth * 1.5f + ) + + Spacer(modifier = Modifier.height(16.dp)) + Text( - text = benchmarkResults ?: "", - style = MonospacedTextStyle, + text = "Running benchmark...", + style = MaterialTheme.typography.headlineSmall + ) + + Spacer(modifier = Modifier.height(8.dp)) + + Text( + text = "This usually takes a few minutes", + style = MaterialTheme.typography.bodyLarge, + textAlign = TextAlign.Center, color = MaterialTheme.colorScheme.onSurfaceVariant ) } } } - - else -> { - Box( - modifier = Modifier - .fillMaxWidth() - .height(200.dp), - contentAlignment = Alignment.Center - ) { - Text( - text = "Benchmark results will appear here", - style = MaterialTheme.typography.bodyMedium, - color = MaterialTheme.colorScheme.onSurfaceVariant - ) - } - } } } diff --git a/examples/llama.android/app/src/main/java/com/example/llama/revamp/viewmodel/BenchmarkViewModel.kt b/examples/llama.android/app/src/main/java/com/example/llama/revamp/viewmodel/BenchmarkViewModel.kt index b28981fc59..cb336e45a1 100644 --- a/examples/llama.android/app/src/main/java/com/example/llama/revamp/viewmodel/BenchmarkViewModel.kt +++ b/examples/llama.android/app/src/main/java/com/example/llama/revamp/viewmodel/BenchmarkViewModel.kt @@ -1,10 +1,18 @@ package com.example.llama.revamp.viewmodel +import android.llama.cpp.isUninterruptible import androidx.lifecycle.viewModelScope import com.example.llama.revamp.data.model.ModelInfo import com.example.llama.revamp.engine.BenchmarkService import dagger.hilt.android.lifecycle.HiltViewModel +import kotlinx.coroutines.flow.MutableSharedFlow +import kotlinx.coroutines.flow.MutableStateFlow import kotlinx.coroutines.flow.StateFlow +import kotlinx.coroutines.flow.asStateFlow +import kotlinx.coroutines.flow.collect +import kotlinx.coroutines.flow.filterNotNull +import kotlinx.coroutines.flow.update +import kotlinx.coroutines.flow.zip import kotlinx.coroutines.launch import javax.inject.Inject @@ -16,13 +24,49 @@ class BenchmarkViewModel @Inject constructor( * UI states */ val selectedModel: StateFlow = benchmarkService.currentSelectedModel - val benchmarkResults: StateFlow = benchmarkService.benchmarkResults + + private val _benchmarkDuration = MutableSharedFlow() + + private val _benchmarkResults = MutableStateFlow>(emptyList()) + val benchmarkResults: StateFlow> = _benchmarkResults.asStateFlow() + + init { + viewModelScope.launch { + benchmarkService.benchmarkResults + .filterNotNull() + .zip(_benchmarkDuration) { result, duration -> + _benchmarkResults.update { oldResults -> + oldResults.toMutableList().apply { + add(BenchmarkResult(result, duration)) + } + } + }.collect() + } + } /** * Run benchmark with specified parameters */ - fun runBenchmark(pp: Int = 512, tg: Int = 128, pl: Int = 1, nr: Int = 3) = - viewModelScope.launch { - benchmarkService.benchmark(pp, tg, pl, nr) + fun runBenchmark(pp: Int = 512, tg: Int = 128, pl: Int = 1, nr: Int = 3): Boolean { + if (engineState.value.isUninterruptible) { + return false } + + viewModelScope.launch { + val benchmarkStartTs = System.currentTimeMillis() + benchmarkService.benchmark(pp, tg, pl, nr) + val benchmarkEndTs = System.currentTimeMillis() + _benchmarkDuration.emit(benchmarkEndTs - benchmarkStartTs) + } + return true + } + + fun shareResults() { + // TODO-han.yin: TO BE IMPLEMENTED + } } + +data class BenchmarkResult( + val text: String, + val duration: Long +)