diff --git a/examples/llama.android/app/src/main/java/com/example/llama/ui/screens/BenchmarkScreen.kt b/examples/llama.android/app/src/main/java/com/example/llama/ui/screens/BenchmarkScreen.kt index 1769c8584d..475dc9a726 100644 --- a/examples/llama.android/app/src/main/java/com/example/llama/ui/screens/BenchmarkScreen.kt +++ b/examples/llama.android/app/src/main/java/com/example/llama/ui/screens/BenchmarkScreen.kt @@ -3,22 +3,26 @@ package com.example.llama.ui.screens import android.llama.cpp.InferenceEngine.State import androidx.activity.compose.BackHandler import androidx.compose.foundation.background +import androidx.compose.foundation.border import androidx.compose.foundation.layout.Arrangement import androidx.compose.foundation.layout.Box import androidx.compose.foundation.layout.Column import androidx.compose.foundation.layout.PaddingValues +import androidx.compose.foundation.layout.Row import androidx.compose.foundation.layout.Spacer import androidx.compose.foundation.layout.fillMaxSize import androidx.compose.foundation.layout.fillMaxWidth import androidx.compose.foundation.layout.height import androidx.compose.foundation.layout.padding import androidx.compose.foundation.layout.size +import androidx.compose.foundation.layout.width import androidx.compose.foundation.lazy.LazyColumn import androidx.compose.foundation.lazy.items import androidx.compose.foundation.shape.RoundedCornerShape import androidx.compose.material3.Card import androidx.compose.material3.CardDefaults import androidx.compose.material3.CircularProgressIndicator +import androidx.compose.material3.HorizontalDivider import androidx.compose.material3.MaterialTheme import androidx.compose.material3.ProgressIndicatorDefaults import androidx.compose.material3.Text @@ -32,6 +36,8 @@ import androidx.compose.runtime.saveable.rememberSaveable import androidx.compose.runtime.setValue import androidx.compose.ui.Alignment import androidx.compose.ui.Modifier +import androidx.compose.ui.text.font.FontStyle +import androidx.compose.ui.text.font.FontWeight import androidx.compose.ui.text.style.TextAlign import androidx.compose.ui.unit.dp import com.example.llama.data.model.ModelInfo @@ -41,10 +47,13 @@ import com.example.llama.ui.components.ModelCardContentContextRow import com.example.llama.ui.components.ModelCardContentField import com.example.llama.ui.components.ModelCardCoreExpandable import com.example.llama.ui.components.ModelUnloadDialogHandler -import com.example.llama.ui.theme.MonospacedTextStyle +import com.example.llama.util.TableData import com.example.llama.util.formatMilliSeconds +import com.example.llama.util.parseMarkdownTable +import com.example.llama.viewmodel.BenchmarkResult import com.example.llama.viewmodel.BenchmarkViewModel + @Composable fun BenchmarkScreen( loadingMetrics: ModelLoadingMetrics, @@ -86,30 +95,8 @@ fun BenchmarkScreen( contentPadding = PaddingValues(8.dp), verticalArrangement = Arrangement.Bottom, ) { - items(items = benchmarkResults) { result -> - Card( - modifier = Modifier.fillMaxWidth().padding(8.dp) - ) { - Column( - modifier = Modifier - .fillMaxWidth() - .background( - color = MaterialTheme.colorScheme.surfaceVariant, - shape = RoundedCornerShape(8.dp) - ) - .padding(16.dp) - ) { - Text( - text = result.text, - style = MonospacedTextStyle, - color = MaterialTheme.colorScheme.onSurfaceVariant - ) - - Spacer(modifier = Modifier.height(4.dp)) - - ModelCardContentField("Time spent: ", formatMilliSeconds(result.duration)) - } - } + items(items = benchmarkResults) { + BenchmarkResultCard(it) } } @@ -153,9 +140,7 @@ fun BenchmarkScreen( // Selected model card and loading metrics if (showModelCard) { selectedModel?.let { model -> - Box( - modifier = Modifier.padding(start = 16.dp, top = 16.dp, end = 16.dp) - ) { + Box(modifier = Modifier.padding(start = 16.dp, top = 16.dp, end = 16.dp)) { ModelCardWithLoadingMetrics( model = model, loadingMetrics = loadingMetrics, @@ -198,3 +183,106 @@ private fun ModelCardWithLoadingMetrics( // Row 4: Model loading time ModelCardContentField("Loading time", formatMilliSeconds(loadingMetrics.modelLoadingTimeMs)) } + + +@Composable +fun BenchmarkResultCard(result: BenchmarkResult) { + val rawTable = parseMarkdownTable(result.text.trimIndent()) + val model = rawTable.getColumn("model").firstOrNull() ?: "Unknown" + val parameters = rawTable.getColumn("params").firstOrNull() ?: "-" + val size = rawTable.getColumn("size").firstOrNull() ?: "-" + + Card( + modifier = Modifier.fillMaxWidth().padding(8.dp) + ) { + Column( + modifier = Modifier + .fillMaxWidth() + .background( + color = MaterialTheme.colorScheme.surfaceVariant, + shape = RoundedCornerShape(8.dp) + ) + .padding(16.dp) + ) { + Row { + Text( + text = "Model", + style = MaterialTheme.typography.titleMedium, + fontWeight = FontWeight.Normal, + ) + + Spacer(modifier = Modifier.width(16.dp)) + + Text( + modifier = Modifier.weight(1f), + text = model, + textAlign = TextAlign.Start, + style = MaterialTheme.typography.titleMedium, + fontWeight = FontWeight.Light, + fontStyle = FontStyle.Italic, + ) + } + + Spacer(modifier = Modifier.height(8.dp)) + + Row { + ModelCardContentField("Parameters", parameters) + + Spacer(modifier = Modifier.weight(1f)) + + ModelCardContentField("Size", size) + } + + BenchmarkResultTable(rawTable) + + ModelCardContentField("Time spent: ", formatMilliSeconds(result.duration)) + } + } +} + +// Needs to be aligned with `bench` implementation +private val COLUMNS_TO_KEEP = setOf("backend", "test", "t/s") +private val WEIGHTS_EACH_COLUMN = listOf(1f, 1f, 2f) + +@Composable +fun BenchmarkResultTable( + rawTable: TableData, + columnsToKeep: Set = COLUMNS_TO_KEEP, + columnWeights: List = WEIGHTS_EACH_COLUMN +) { + val (headers, rows) = rawTable.filterColumns(columnsToKeep) + + Column( + modifier = Modifier + .padding(horizontal = 12.dp, vertical = 16.dp) + .border(1.dp, MaterialTheme.colorScheme.outline, shape = RoundedCornerShape(4.dp)) + .padding(12.dp), + verticalArrangement = Arrangement.spacedBy(8.dp) + ) { + BenchmarkResultTableRow(headers, columnWeights, isHeader = true) + HorizontalDivider(thickness = 1.dp) + rows.forEach { BenchmarkResultTableRow(it, columnWeights) } + } +} + +@Composable +fun BenchmarkResultTableRow( + cells: List, + weights: List? = null, + isHeader: Boolean = false, +) { + val effectiveWeights = weights ?: List(cells.size) { 1f } + + Row(modifier = Modifier.fillMaxWidth()) { + cells.forEachIndexed { index, cell -> + Text( + modifier = Modifier.weight(effectiveWeights.getOrElse(index) { 1f }), + text = cell, + textAlign = TextAlign.Center, + style = MaterialTheme.typography.bodyLarge, + fontWeight = if (isHeader) FontWeight.Normal else FontWeight.Light, + fontStyle = if (isHeader) FontStyle.Normal else FontStyle.Italic + ) + } + } +} diff --git a/examples/llama.android/app/src/main/java/com/example/llama/util/TableUtils.kt b/examples/llama.android/app/src/main/java/com/example/llama/util/TableUtils.kt index 69aa77008f..34b4baa1b0 100644 --- a/examples/llama.android/app/src/main/java/com/example/llama/util/TableUtils.kt +++ b/examples/llama.android/app/src/main/java/com/example/llama/util/TableUtils.kt @@ -2,37 +2,48 @@ package com.example.llama.util /** - * A basic + * A basic table data holder separating rows and columns */ -data class MarkdownTableData( +data class TableData( val headers: List, val rows: List> ) { val columnCount: Int get() = headers.size val rowCount: Int get() = rows.size + + /** + * Generate a copy of the original table with only the [keep] columns + */ + fun filterColumns(keep: Set): TableData = + headers.mapIndexedNotNull { index, name -> + if (name in keep) index else null + }.let { keepIndices -> + val newHeaders = keepIndices.map { headers[it] } + val newRows = rows.map { row -> keepIndices.map { row.getOrElse(it) { "" } } } + TableData(newHeaders, newRows) + } + + /** + * Obtain the data in the specified column + */ + fun getColumn(name: String): List { + val index = headers.indexOf(name) + if (index == -1) return emptyList() + return rows.mapNotNull { it.getOrNull(index) } + } } /** * Formats llama-bench's markdown output into structured [MarkdownTableData] */ -fun parseMarkdownTableFiltered( - markdown: String, - keepColumns: Set -): MarkdownTableData { +fun parseMarkdownTable(markdown: String): TableData { val lines = markdown.trim().lines().filter { it.startsWith("|") } - if (lines.size < 2) return MarkdownTableData(emptyList(), emptyList()) - - val rawHeaders = lines[0].split("|").map { it.trim() }.filter { it.isNotEmpty() } - val keepIndices = rawHeaders.mapIndexedNotNull { index, name -> - if (name in keepColumns) index else null - } - - val headers = keepIndices.map { rawHeaders[it] } + if (lines.size < 2) return TableData(emptyList(), emptyList()) + val headers = lines[0].split("|").map { it.trim() }.filter { it.isNotEmpty() } val rows = lines.drop(2).map { line -> - val cells = line.split("|").map { it.trim() }.filter { it.isNotEmpty() } - keepIndices.map { cells.getOrElse(it) { "" } } + line.split("|").map { it.trim() }.filter { it.isNotEmpty() } } - return MarkdownTableData(headers, rows) + return TableData(headers, rows) }