navigation: sink model loading state management from AppContent down into ModelLoadingScreen; pass ModelLoadingMetrics to Benchmark and Conversation screens
This commit is contained in:
parent
8a682ff85d
commit
a9466c0370
|
|
@ -21,9 +21,12 @@ import androidx.compose.runtime.remember
|
||||||
import androidx.compose.runtime.rememberCoroutineScope
|
import androidx.compose.runtime.rememberCoroutineScope
|
||||||
import androidx.compose.ui.Modifier
|
import androidx.compose.ui.Modifier
|
||||||
import androidx.hilt.navigation.compose.hiltViewModel
|
import androidx.hilt.navigation.compose.hiltViewModel
|
||||||
|
import androidx.navigation.NavType
|
||||||
import androidx.navigation.compose.composable
|
import androidx.navigation.compose.composable
|
||||||
import androidx.navigation.compose.currentBackStackEntryAsState
|
import androidx.navigation.compose.currentBackStackEntryAsState
|
||||||
import androidx.navigation.compose.rememberNavController
|
import androidx.navigation.compose.rememberNavController
|
||||||
|
import androidx.navigation.navArgument
|
||||||
|
import com.example.llama.revamp.engine.ModelLoadingMetrics
|
||||||
import com.example.llama.revamp.navigation.AppDestinations
|
import com.example.llama.revamp.navigation.AppDestinations
|
||||||
import com.example.llama.revamp.navigation.NavigationActions
|
import com.example.llama.revamp.navigation.NavigationActions
|
||||||
import com.example.llama.revamp.ui.components.AnimatedNavHost
|
import com.example.llama.revamp.ui.components.AnimatedNavHost
|
||||||
|
|
@ -48,7 +51,6 @@ import com.example.llama.revamp.viewmodel.ModelLoadingViewModel
|
||||||
import com.example.llama.revamp.viewmodel.ModelsManagementViewModel
|
import com.example.llama.revamp.viewmodel.ModelsManagementViewModel
|
||||||
import com.example.llama.revamp.viewmodel.PerformanceViewModel
|
import com.example.llama.revamp.viewmodel.PerformanceViewModel
|
||||||
import dagger.hilt.android.AndroidEntryPoint
|
import dagger.hilt.android.AndroidEntryPoint
|
||||||
import kotlinx.coroutines.isActive
|
|
||||||
import kotlinx.coroutines.launch
|
import kotlinx.coroutines.launch
|
||||||
|
|
||||||
@AndroidEntryPoint
|
@AndroidEntryPoint
|
||||||
|
|
@ -110,9 +112,9 @@ fun AppContent(
|
||||||
val openDrawer: () -> Unit = { coroutineScope.launch { drawerState.open() } }
|
val openDrawer: () -> Unit = { coroutineScope.launch { drawerState.open() } }
|
||||||
|
|
||||||
// Create scaffold's top & bottom bar configs based on current route
|
// Create scaffold's top & bottom bar configs based on current route
|
||||||
val scaffoldConfig = when (currentRoute) {
|
val scaffoldConfig = when {
|
||||||
// Model selection screen
|
// Model selection screen
|
||||||
AppDestinations.MODEL_SELECTION_ROUTE ->
|
currentRoute == AppDestinations.MODEL_SELECTION_ROUTE ->
|
||||||
ScaffoldConfig(
|
ScaffoldConfig(
|
||||||
topBarConfig = TopBarConfig.Default(
|
topBarConfig = TopBarConfig.Default(
|
||||||
title = "Models",
|
title = "Models",
|
||||||
|
|
@ -121,7 +123,7 @@ fun AppContent(
|
||||||
)
|
)
|
||||||
|
|
||||||
// Model loading screen
|
// Model loading screen
|
||||||
AppDestinations.MODEL_LOADING_ROUTE ->
|
currentRoute == AppDestinations.MODEL_LOADING_ROUTE ->
|
||||||
ScaffoldConfig(
|
ScaffoldConfig(
|
||||||
topBarConfig = TopBarConfig.Performance(
|
topBarConfig = TopBarConfig.Performance(
|
||||||
title = "Load Model",
|
title = "Load Model",
|
||||||
|
|
@ -134,7 +136,7 @@ fun AppContent(
|
||||||
)
|
)
|
||||||
|
|
||||||
// Benchmark screen
|
// Benchmark screen
|
||||||
AppDestinations.BENCHMARK_ROUTE ->
|
currentRoute.startsWith(AppDestinations.BENCHMARK_ROUTE) ->
|
||||||
ScaffoldConfig(
|
ScaffoldConfig(
|
||||||
topBarConfig = TopBarConfig.Performance(
|
topBarConfig = TopBarConfig.Performance(
|
||||||
title = "Benchmark",
|
title = "Benchmark",
|
||||||
|
|
@ -147,7 +149,7 @@ fun AppContent(
|
||||||
)
|
)
|
||||||
|
|
||||||
// Conversation screen
|
// Conversation screen
|
||||||
AppDestinations.CONVERSATION_ROUTE ->
|
currentRoute.startsWith(AppDestinations.CONVERSATION_ROUTE) ->
|
||||||
ScaffoldConfig(
|
ScaffoldConfig(
|
||||||
topBarConfig = TopBarConfig.Performance(
|
topBarConfig = TopBarConfig.Performance(
|
||||||
title = "Chat",
|
title = "Chat",
|
||||||
|
|
@ -160,7 +162,7 @@ fun AppContent(
|
||||||
)
|
)
|
||||||
|
|
||||||
// Settings screen
|
// Settings screen
|
||||||
AppDestinations.SETTINGS_GENERAL_ROUTE ->
|
currentRoute == AppDestinations.SETTINGS_GENERAL_ROUTE ->
|
||||||
ScaffoldConfig(
|
ScaffoldConfig(
|
||||||
topBarConfig = TopBarConfig.Default(
|
topBarConfig = TopBarConfig.Default(
|
||||||
title = "Settings",
|
title = "Settings",
|
||||||
|
|
@ -169,7 +171,7 @@ fun AppContent(
|
||||||
)
|
)
|
||||||
|
|
||||||
// Storage management screen
|
// Storage management screen
|
||||||
AppDestinations.MODELS_MANAGEMENT_ROUTE -> {
|
currentRoute == AppDestinations.MODELS_MANAGEMENT_ROUTE -> {
|
||||||
// Collect the needed states
|
// Collect the needed states
|
||||||
val sortOrder by modelsManagementViewModel.sortOrder.collectAsState()
|
val sortOrder by modelsManagementViewModel.sortOrder.collectAsState()
|
||||||
val isMultiSelectionMode by modelsManagementViewModel.isMultiSelectionMode.collectAsState()
|
val isMultiSelectionMode by modelsManagementViewModel.isMultiSelectionMode.collectAsState()
|
||||||
|
|
@ -301,35 +303,59 @@ fun AppContent(
|
||||||
composable(AppDestinations.MODEL_LOADING_ROUTE) {
|
composable(AppDestinations.MODEL_LOADING_ROUTE) {
|
||||||
ModelLoadingScreen(
|
ModelLoadingScreen(
|
||||||
onNavigateBack = { navigationActions.navigateUp() },
|
onNavigateBack = { navigationActions.navigateUp() },
|
||||||
onBenchmarkSelected = { prepareJob ->
|
onNavigateToBenchmark = { navigationActions.navigateToBenchmark(it) },
|
||||||
// Wait for preparation to complete, then navigate if still active
|
onNavigateToConversation = { navigationActions.navigateToConversation(it) },
|
||||||
coroutineScope.launch {
|
|
||||||
prepareJob.join()
|
|
||||||
if (isActive) { navigationActions.navigateToBenchmark() }
|
|
||||||
}
|
|
||||||
},
|
|
||||||
onConversationSelected = { systemPrompt, prepareJob ->
|
|
||||||
// Wait for preparation to complete, then navigate if still active
|
|
||||||
coroutineScope.launch {
|
|
||||||
prepareJob.join()
|
|
||||||
if (isActive) { navigationActions.navigateToConversation() }
|
|
||||||
}
|
|
||||||
},
|
|
||||||
viewModel = modelLoadingViewModel
|
viewModel = modelLoadingViewModel
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Benchmark Screen
|
// Benchmark Screen
|
||||||
composable(AppDestinations.BENCHMARK_ROUTE) {
|
composable(
|
||||||
|
route = AppDestinations.BENCHMARK_ROUTE_WITH_PARAMS,
|
||||||
|
arguments = listOf(
|
||||||
|
navArgument("modelLoadTimeMs") {
|
||||||
|
type = NavType.LongType
|
||||||
|
defaultValue = 0L
|
||||||
|
}
|
||||||
|
)
|
||||||
|
) { backStackEntry ->
|
||||||
|
val modelLoadTimeMs = backStackEntry.arguments?.getLong("modelLoadTimeMs") ?: 0L
|
||||||
|
val metrics = if (modelLoadTimeMs > 0) {
|
||||||
|
ModelLoadingMetrics(modelLoadTimeMs)
|
||||||
|
} else throw IllegalArgumentException("Expecting a valid ModelLoadingMetrics!")
|
||||||
|
|
||||||
BenchmarkScreen(
|
BenchmarkScreen(
|
||||||
|
loadingMetrics = metrics,
|
||||||
onNavigateBack = { navigationActions.navigateUp() },
|
onNavigateBack = { navigationActions.navigateUp() },
|
||||||
viewModel = benchmarkViewModel
|
viewModel = benchmarkViewModel
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Conversation Screen
|
// Conversation Screen
|
||||||
composable(AppDestinations.CONVERSATION_ROUTE) {
|
composable(
|
||||||
|
route = AppDestinations.CONVERSATION_ROUTE_WITH_PARAMS,
|
||||||
|
arguments = listOf(
|
||||||
|
navArgument("modelLoadTimeMs") {
|
||||||
|
type = NavType.LongType
|
||||||
|
defaultValue = 0L
|
||||||
|
},
|
||||||
|
navArgument("promptTimeMs") {
|
||||||
|
type = NavType.LongType
|
||||||
|
defaultValue = 0L
|
||||||
|
}
|
||||||
|
)
|
||||||
|
) { backStackEntry ->
|
||||||
|
val modelLoadTimeMs = backStackEntry.arguments?.getLong("modelLoadTimeMs") ?: 0L
|
||||||
|
val promptTimeMs = backStackEntry.arguments?.getLong("promptTimeMs") ?: 0L
|
||||||
|
val metrics = if (modelLoadTimeMs > 0) {
|
||||||
|
ModelLoadingMetrics(
|
||||||
|
modelLoadingTimeMs = modelLoadTimeMs,
|
||||||
|
systemPromptProcessingTimeMs = if (promptTimeMs > 0) promptTimeMs else null
|
||||||
|
)
|
||||||
|
} else throw IllegalArgumentException("Expecting a valid ModelLoadingMetrics!")
|
||||||
|
|
||||||
ConversationScreen(
|
ConversationScreen(
|
||||||
|
loadingMetrics = metrics,
|
||||||
onNavigateBack = { navigationActions.navigateUp() },
|
onNavigateBack = { navigationActions.navigateUp() },
|
||||||
viewModel = conversationViewModel
|
viewModel = conversationViewModel
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
package com.example.llama.revamp.navigation
|
package com.example.llama.revamp.navigation
|
||||||
|
|
||||||
import androidx.navigation.NavController
|
import androidx.navigation.NavController
|
||||||
|
import com.example.llama.revamp.engine.ModelLoadingMetrics
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Navigation destinations for the app
|
* Navigation destinations for the app
|
||||||
|
|
@ -9,10 +10,14 @@ object AppDestinations {
|
||||||
// Primary navigation destinations
|
// Primary navigation destinations
|
||||||
const val MODEL_SELECTION_ROUTE = "model_selection"
|
const val MODEL_SELECTION_ROUTE = "model_selection"
|
||||||
const val MODEL_LOADING_ROUTE = "model_loading"
|
const val MODEL_LOADING_ROUTE = "model_loading"
|
||||||
const val CONVERSATION_ROUTE = "conversation"
|
|
||||||
const val BENCHMARK_ROUTE = "benchmark"
|
|
||||||
|
|
||||||
// Settings destinations (moved from tabs to separate routes)
|
const val CONVERSATION_ROUTE = "conversation"
|
||||||
|
const val CONVERSATION_ROUTE_WITH_PARAMS = "conversation/{modelLoadTimeMs}/{promptTimeMs}"
|
||||||
|
|
||||||
|
const val BENCHMARK_ROUTE = "benchmark"
|
||||||
|
const val BENCHMARK_ROUTE_WITH_PARAMS = "benchmark/{modelLoadTimeMs}"
|
||||||
|
|
||||||
|
// Settings destinations
|
||||||
const val SETTINGS_GENERAL_ROUTE = "settings_general"
|
const val SETTINGS_GENERAL_ROUTE = "settings_general"
|
||||||
const val MODELS_MANAGEMENT_ROUTE = "models_management"
|
const val MODELS_MANAGEMENT_ROUTE = "models_management"
|
||||||
}
|
}
|
||||||
|
|
@ -33,12 +38,17 @@ class NavigationActions(private val navController: NavController) {
|
||||||
navController.navigate(AppDestinations.MODEL_LOADING_ROUTE)
|
navController.navigate(AppDestinations.MODEL_LOADING_ROUTE)
|
||||||
}
|
}
|
||||||
|
|
||||||
fun navigateToConversation() {
|
fun navigateToConversation(metrics: ModelLoadingMetrics) {
|
||||||
navController.navigate(AppDestinations.CONVERSATION_ROUTE)
|
val route = AppDestinations.CONVERSATION_ROUTE
|
||||||
|
val modelLoadTimeMs = metrics.modelLoadingTimeMs
|
||||||
|
val promptTimeMs = metrics.systemPromptProcessingTimeMs ?: 0
|
||||||
|
navController.navigate("$route/$modelLoadTimeMs/$promptTimeMs")
|
||||||
}
|
}
|
||||||
|
|
||||||
fun navigateToBenchmark() {
|
fun navigateToBenchmark(metrics: ModelLoadingMetrics) {
|
||||||
navController.navigate(AppDestinations.BENCHMARK_ROUTE)
|
val route = AppDestinations.BENCHMARK_ROUTE
|
||||||
|
val modelLoadTimeMs = metrics.modelLoadingTimeMs
|
||||||
|
navController.navigate("$route/$modelLoadTimeMs")
|
||||||
}
|
}
|
||||||
|
|
||||||
fun navigateToSettingsGeneral() {
|
fun navigateToSettingsGeneral() {
|
||||||
|
|
|
||||||
|
|
@ -24,6 +24,7 @@ import androidx.compose.runtime.getValue
|
||||||
import androidx.compose.ui.Alignment
|
import androidx.compose.ui.Alignment
|
||||||
import androidx.compose.ui.Modifier
|
import androidx.compose.ui.Modifier
|
||||||
import androidx.compose.ui.unit.dp
|
import androidx.compose.ui.unit.dp
|
||||||
|
import com.example.llama.revamp.engine.ModelLoadingMetrics
|
||||||
import com.example.llama.revamp.ui.components.ModelCard
|
import com.example.llama.revamp.ui.components.ModelCard
|
||||||
import com.example.llama.revamp.ui.components.ModelUnloadDialogHandler
|
import com.example.llama.revamp.ui.components.ModelUnloadDialogHandler
|
||||||
import com.example.llama.revamp.ui.theme.MonospacedTextStyle
|
import com.example.llama.revamp.ui.theme.MonospacedTextStyle
|
||||||
|
|
@ -31,6 +32,7 @@ import com.example.llama.revamp.viewmodel.BenchmarkViewModel
|
||||||
|
|
||||||
@Composable
|
@Composable
|
||||||
fun BenchmarkScreen(
|
fun BenchmarkScreen(
|
||||||
|
loadingMetrics: ModelLoadingMetrics,
|
||||||
onNavigateBack: () -> Unit,
|
onNavigateBack: () -> Unit,
|
||||||
viewModel: BenchmarkViewModel
|
viewModel: BenchmarkViewModel
|
||||||
) {
|
) {
|
||||||
|
|
|
||||||
|
|
@ -57,6 +57,7 @@ import androidx.lifecycle.Lifecycle
|
||||||
import androidx.lifecycle.LifecycleEventObserver
|
import androidx.lifecycle.LifecycleEventObserver
|
||||||
import androidx.lifecycle.compose.LocalLifecycleOwner
|
import androidx.lifecycle.compose.LocalLifecycleOwner
|
||||||
import com.example.llama.revamp.APP_NAME
|
import com.example.llama.revamp.APP_NAME
|
||||||
|
import com.example.llama.revamp.engine.ModelLoadingMetrics
|
||||||
import com.example.llama.revamp.ui.components.ModelCardWithSystemPrompt
|
import com.example.llama.revamp.ui.components.ModelCardWithSystemPrompt
|
||||||
import com.example.llama.revamp.ui.components.ModelUnloadDialogHandler
|
import com.example.llama.revamp.ui.components.ModelUnloadDialogHandler
|
||||||
import com.example.llama.revamp.viewmodel.ConversationViewModel
|
import com.example.llama.revamp.viewmodel.ConversationViewModel
|
||||||
|
|
@ -68,6 +69,7 @@ import kotlinx.coroutines.launch
|
||||||
*/
|
*/
|
||||||
@Composable
|
@Composable
|
||||||
fun ConversationScreen(
|
fun ConversationScreen(
|
||||||
|
loadingMetrics: ModelLoadingMetrics,
|
||||||
onNavigateBack: () -> Unit,
|
onNavigateBack: () -> Unit,
|
||||||
viewModel: ConversationViewModel
|
viewModel: ConversationViewModel
|
||||||
) {
|
) {
|
||||||
|
|
|
||||||
|
|
@ -42,7 +42,6 @@ import androidx.compose.runtime.collectAsState
|
||||||
import androidx.compose.runtime.getValue
|
import androidx.compose.runtime.getValue
|
||||||
import androidx.compose.runtime.mutableStateOf
|
import androidx.compose.runtime.mutableStateOf
|
||||||
import androidx.compose.runtime.remember
|
import androidx.compose.runtime.remember
|
||||||
import androidx.compose.runtime.rememberCoroutineScope
|
|
||||||
import androidx.compose.runtime.setValue
|
import androidx.compose.runtime.setValue
|
||||||
import androidx.compose.ui.Alignment
|
import androidx.compose.ui.Alignment
|
||||||
import androidx.compose.ui.Modifier
|
import androidx.compose.ui.Modifier
|
||||||
|
|
@ -50,11 +49,10 @@ import androidx.compose.ui.semantics.Role
|
||||||
import androidx.compose.ui.text.style.TextOverflow
|
import androidx.compose.ui.text.style.TextOverflow
|
||||||
import androidx.compose.ui.unit.dp
|
import androidx.compose.ui.unit.dp
|
||||||
import com.example.llama.revamp.data.model.SystemPrompt
|
import com.example.llama.revamp.data.model.SystemPrompt
|
||||||
|
import com.example.llama.revamp.engine.ModelLoadingMetrics
|
||||||
import com.example.llama.revamp.ui.components.ModelCard
|
import com.example.llama.revamp.ui.components.ModelCard
|
||||||
import com.example.llama.revamp.ui.components.ModelUnloadDialogHandler
|
import com.example.llama.revamp.ui.components.ModelUnloadDialogHandler
|
||||||
import com.example.llama.revamp.viewmodel.ModelLoadingViewModel
|
import com.example.llama.revamp.viewmodel.ModelLoadingViewModel
|
||||||
import kotlinx.coroutines.Job
|
|
||||||
import kotlinx.coroutines.launch
|
|
||||||
|
|
||||||
|
|
||||||
enum class Mode {
|
enum class Mode {
|
||||||
|
|
@ -70,12 +68,10 @@ enum class SystemPromptTab(val label: String) {
|
||||||
@Composable
|
@Composable
|
||||||
fun ModelLoadingScreen(
|
fun ModelLoadingScreen(
|
||||||
onNavigateBack: () -> Unit,
|
onNavigateBack: () -> Unit,
|
||||||
onBenchmarkSelected: (prepareJob: Job) -> Unit,
|
onNavigateToBenchmark: (ModelLoadingMetrics) -> Unit,
|
||||||
onConversationSelected: (systemPrompt: String?, prepareJob: Job) -> Unit,
|
onNavigateToConversation: (ModelLoadingMetrics) -> Unit,
|
||||||
viewModel: ModelLoadingViewModel,
|
viewModel: ModelLoadingViewModel,
|
||||||
) {
|
) {
|
||||||
val coroutineScope = rememberCoroutineScope()
|
|
||||||
|
|
||||||
val engineState by viewModel.engineState.collectAsState()
|
val engineState by viewModel.engineState.collectAsState()
|
||||||
val selectedModel by viewModel.selectedModel.collectAsState()
|
val selectedModel by viewModel.selectedModel.collectAsState()
|
||||||
val presetPrompts by viewModel.presetPrompts.collectAsState()
|
val presetPrompts by viewModel.presetPrompts.collectAsState()
|
||||||
|
|
@ -108,22 +104,6 @@ fun ModelLoadingScreen(
|
||||||
// Check if we're in a loading state
|
// Check if we're in a loading state
|
||||||
val isLoading = engineState !is State.Initialized && engineState !is State.ModelReady
|
val isLoading = engineState !is State.Initialized && engineState !is State.ModelReady
|
||||||
|
|
||||||
// Mode selection callbacks
|
|
||||||
val handleBenchmarkSelected = {
|
|
||||||
val prepareJob = coroutineScope.launch {
|
|
||||||
viewModel.prepareForBenchmark()
|
|
||||||
}
|
|
||||||
onBenchmarkSelected(prepareJob)
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO-han.yin: refactor this into ViewModel too
|
|
||||||
val handleConversationSelected = { systemPrompt: String? ->
|
|
||||||
val prepareJob = coroutineScope.launch {
|
|
||||||
viewModel.prepareForConversation(systemPrompt)
|
|
||||||
}
|
|
||||||
onConversationSelected(systemPrompt, prepareJob)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Handle back navigation requests
|
// Handle back navigation requests
|
||||||
BackHandler {
|
BackHandler {
|
||||||
viewModel.onBackPressed(onNavigateBack)
|
viewModel.onBackPressed(onNavigateBack)
|
||||||
|
|
@ -301,7 +281,7 @@ fun ModelLoadingScreen(
|
||||||
Button(
|
Button(
|
||||||
onClick = {
|
onClick = {
|
||||||
when (selectedMode) {
|
when (selectedMode) {
|
||||||
Mode.BENCHMARK -> handleBenchmarkSelected()
|
Mode.BENCHMARK -> viewModel.onBenchmarkSelected(onNavigateToBenchmark)
|
||||||
|
|
||||||
Mode.CONVERSATION -> {
|
Mode.CONVERSATION -> {
|
||||||
val systemPrompt = if (useSystemPrompt) {
|
val systemPrompt = if (useSystemPrompt) {
|
||||||
|
|
@ -324,7 +304,7 @@ fun ModelLoadingScreen(
|
||||||
}
|
}
|
||||||
} else null
|
} else null
|
||||||
|
|
||||||
handleConversationSelected(systemPrompt)
|
viewModel.onConversationSelected(systemPrompt, onNavigateToConversation)
|
||||||
}
|
}
|
||||||
|
|
||||||
null -> { /* No mode selected */
|
null -> { /* No mode selected */
|
||||||
|
|
|
||||||
|
|
@ -3,6 +3,7 @@ package com.example.llama.revamp.viewmodel
|
||||||
import androidx.lifecycle.viewModelScope
|
import androidx.lifecycle.viewModelScope
|
||||||
import com.example.llama.revamp.data.model.SystemPrompt
|
import com.example.llama.revamp.data.model.SystemPrompt
|
||||||
import com.example.llama.revamp.data.repository.SystemPromptRepository
|
import com.example.llama.revamp.data.repository.SystemPromptRepository
|
||||||
|
import com.example.llama.revamp.engine.ModelLoadingMetrics
|
||||||
import com.example.llama.revamp.engine.ModelLoadingService
|
import com.example.llama.revamp.engine.ModelLoadingService
|
||||||
import dagger.hilt.android.lifecycle.HiltViewModel
|
import dagger.hilt.android.lifecycle.HiltViewModel
|
||||||
import kotlinx.coroutines.flow.SharingStarted
|
import kotlinx.coroutines.flow.SharingStarted
|
||||||
|
|
@ -79,17 +80,23 @@ class ModelLoadingViewModel @Inject constructor(
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Prepares the engine for benchmark mode.
|
* Loads the model, then navigate to [BenchmarkScreen] with [ModelLoadingMetrics]
|
||||||
*/
|
*/
|
||||||
suspend fun prepareForBenchmark() =
|
fun onBenchmarkSelected(onNavigateToBenchmark: (ModelLoadingMetrics) -> Unit) =
|
||||||
modelLoadingService.loadModelForBenchmark()
|
viewModelScope.launch {
|
||||||
|
onNavigateToBenchmark(modelLoadingService.loadModelForBenchmark())
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Prepare for conversation
|
* Loads the model, process system prompt if any,
|
||||||
|
* then navigate to [ConversationScreen] with [ModelLoadingMetrics]
|
||||||
*/
|
*/
|
||||||
suspend fun prepareForConversation(systemPrompt: String? = null) =
|
fun onConversationSelected(
|
||||||
modelLoadingService.loadModelForConversation(systemPrompt)
|
systemPrompt: String? = null,
|
||||||
|
onNavigateToConversation: (ModelLoadingMetrics) -> Unit
|
||||||
|
) = viewModelScope.launch {
|
||||||
|
onNavigateToConversation(modelLoadingService.loadModelForConversation(systemPrompt))
|
||||||
|
}
|
||||||
|
|
||||||
companion object {
|
companion object {
|
||||||
private val TAG = ModelLoadingViewModel::class.java.simpleName
|
private val TAG = ModelLoadingViewModel::class.java.simpleName
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue