core: swap in LLamaAndroid and mark stub engine for testing only

Han Yin 2025-04-16 11:14:14 -07:00
parent c2426a42e5
commit d70b8fe323
12 changed files with 157 additions and 129 deletions

View File

@@ -1,5 +1,6 @@
 package com.example.llama.revamp
+import android.llama.cpp.InferenceEngine.State
 import android.os.Bundle
 import androidx.activity.ComponentActivity
 import androidx.activity.OnBackPressedCallback
@@ -27,7 +28,6 @@ import androidx.lifecycle.compose.LocalLifecycleOwner
 import androidx.navigation.compose.composable
 import androidx.navigation.compose.currentBackStackEntryAsState
 import androidx.navigation.compose.rememberNavController
-import com.example.llama.revamp.engine.InferenceEngine
 import com.example.llama.revamp.navigation.AppDestinations
 import com.example.llama.revamp.navigation.NavigationActions
 import com.example.llama.revamp.ui.components.AnimatedNavHost
@@ -35,9 +35,9 @@ import com.example.llama.revamp.ui.components.AppNavigationDrawer
 import com.example.llama.revamp.ui.components.UnloadModelConfirmationDialog
 import com.example.llama.revamp.ui.screens.BenchmarkScreen
 import com.example.llama.revamp.ui.screens.ConversationScreen
+import com.example.llama.revamp.ui.screens.ModelLoadingScreen
 import com.example.llama.revamp.ui.screens.ModelSelectionScreen
 import com.example.llama.revamp.ui.screens.ModelsManagementScreen
-import com.example.llama.revamp.ui.screens.ModelLoadingScreen
 import com.example.llama.revamp.ui.screens.SettingsGeneralScreen
 import com.example.llama.revamp.ui.theme.LlamaTheme
 import com.example.llama.revamp.viewmodel.MainViewModel
@@ -72,10 +72,10 @@ fun AppContent(
     // LLM Inference engine status
     val engineState by mainVewModel.engineState.collectAsState()
-    val isModelLoading = engineState is InferenceEngine.State.LoadingModel
-        || engineState is InferenceEngine.State.ProcessingSystemPrompt
-    val isModelLoaded = engineState !is InferenceEngine.State.Uninitialized
-        && engineState !is InferenceEngine.State.LibraryLoaded
+    val isModelLoading = engineState is State.LoadingModel
+        || engineState is State.ProcessingSystemPrompt
+    val isModelLoaded = engineState !is State.Uninitialized
+        && engineState !is State.LibraryLoaded

     // Navigation
     val navController = rememberNavController()

View File

@@ -1,6 +1,8 @@
 package com.example.llama.revamp.di
 import android.content.Context
+import android.llama.cpp.InferenceEngine
+import android.llama.cpp.LLamaAndroid
 import com.example.llama.revamp.data.local.AppDatabase
 import com.example.llama.revamp.data.repository.ModelRepository
 import com.example.llama.revamp.data.repository.ModelRepositoryImpl
@@ -8,7 +10,7 @@ import com.example.llama.revamp.data.repository.SystemPromptRepository
 import com.example.llama.revamp.data.repository.SystemPromptRepositoryImpl
 import com.example.llama.revamp.engine.BenchmarkService
 import com.example.llama.revamp.engine.ConversationService
-import com.example.llama.revamp.engine.InferenceEngine
+import com.example.llama.revamp.engine.StubInferenceEngine
 import com.example.llama.revamp.engine.InferenceService
 import com.example.llama.revamp.engine.InferenceServiceImpl
 import com.example.llama.revamp.engine.ModelLoadingService
@@ -46,10 +48,12 @@ internal abstract class AppModule {
     companion object {
         @Provides
         @Singleton
-        fun provideInferenceEngine() = InferenceEngine()
+        fun provideInferenceEngine(): InferenceEngine {
+            val useRealEngine = true
+            return if (useRealEngine) LLamaAndroid.instance() else StubInferenceEngine()
+        }

         @Provides
-        @Singleton
        fun providePerformanceMonitor(@ApplicationContext context: Context) = PerformanceMonitor(context)

         @Provides
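The `useRealEngine` toggle above is a hard-coded local, so swapping engines requires a code edit and rebuild. A minimal alternative sketch, not part of this commit and assuming the module has a standard Android `BuildConfig`, would gate the choice on the build type instead:

    // Sketch only: BuildConfig.DEBUG assumed available in this module.
    @Provides
    @Singleton
    fun provideInferenceEngine(): InferenceEngine =
        if (BuildConfig.DEBUG) StubInferenceEngine() else LLamaAndroid.instance()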

View File

@@ -1,5 +1,7 @@
 package com.example.llama.revamp.engine
+import android.llama.cpp.InferenceEngine
+import android.llama.cpp.InferenceEngine.State
 import android.util.Log
 import com.example.llama.revamp.data.model.ModelInfo
 import kotlinx.coroutines.flow.Flow
@@ -14,7 +16,7 @@ interface InferenceService {
     /**
      * Expose engine state
      */
-    val engineState: StateFlow<InferenceEngine.State>
+    val engineState: StateFlow<State>

     /**
      * Currently selected model
@@ -51,14 +53,14 @@ interface BenchmarkService : InferenceService {
      * @param pp: Prompt Processing size
      * @param tg: Token Generation size
      * @param pl: Parallel sequences
-     * @param nr: repetitions (Number of Runs)
+     * @param nr: Number of Runs, i.e. repetitions
      */
     suspend fun benchmark(pp: Int, tg: Int, pl: Int, nr: Int): String

     /**
      * Benchmark results
      */
-    val results: StateFlow<String?>
+    val benchmarkResults: StateFlow<String?>
 }

 interface ConversationService : InferenceService {
@@ -108,7 +110,7 @@ internal class InferenceServiceImpl @Inject internal constructor(
     /* InferenceService implementation */
-    override val engineState: StateFlow<InferenceEngine.State> = inferenceEngine.state
+    override val engineState: StateFlow<State> = inferenceEngine.state

     private val _currentModel = MutableStateFlow<ModelInfo?>(null)
     override val currentSelectedModel: StateFlow<ModelInfo?> = _currentModel.asStateFlow()
@@ -156,9 +158,15 @@ internal class InferenceServiceImpl @Inject internal constructor(
     /* BenchmarkService implementation */
     override suspend fun benchmark(pp: Int, tg: Int, pl: Int, nr: Int): String =
-        inferenceEngine.bench(pp, tg, pl, nr)
+        inferenceEngine.bench(pp, tg, pl, nr).also {
+            _benchmarkResults.value = it
+        }

-    override val results: StateFlow<String?> = inferenceEngine.benchmarkResults
+    /**
+     * Benchmark results if available
+     */
+    private val _benchmarkResults = MutableStateFlow<String?>(null)
+    override val benchmarkResults: StateFlow<String?> = _benchmarkResults

     /* ConversationService implementation */
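With `benchmark(...)` now publishing its result into the service-local `_benchmarkResults`, callers can either await the returned string or observe the `benchmarkResults` flow. An illustrative caller sketch (the pp/tg/pl/nr values are arbitrary):

    suspend fun runAndObserve(service: BenchmarkService) {
        // Run the benchmark; the result is also published to benchmarkResults.
        val table: String = service.benchmark(pp = 512, tg = 128, pl = 1, nr = 3)
        // Later collectors see the same table via the StateFlow.
        check(service.benchmarkResults.value == table)
    }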

View File

@@ -1,5 +1,7 @@
 package com.example.llama.revamp.engine
+import android.llama.cpp.InferenceEngine
+import android.llama.cpp.InferenceEngine.State
 import android.util.Log
 import kotlinx.coroutines.CancellationException
 import kotlinx.coroutines.delay
@@ -8,44 +10,29 @@ import kotlinx.coroutines.flow.MutableStateFlow
 import kotlinx.coroutines.flow.StateFlow
 import kotlinx.coroutines.flow.catch
 import kotlinx.coroutines.flow.flow
+import org.jetbrains.annotations.TestOnly
+import org.jetbrains.annotations.VisibleForTesting
+import javax.inject.Singleton

 /**
- * LLM inference engine that handles model loading and text generation.
+ * A stub [InferenceEngine] for agile development & testing
  */
-class InferenceEngine {
+@VisibleForTesting
+@TestOnly
+@Singleton
+class StubInferenceEngine : InferenceEngine {
     companion object {
-        private val TAG = InferenceEngine::class.java.simpleName
-        private const val DEFAULT_PREDICT_LENGTH = 1024
-    }
-
-    sealed class State {
-        object Uninitialized : State()
-        object LibraryLoaded : State()
-        object LoadingModel : State()
-        object ModelLoaded : State()
-        object ProcessingSystemPrompt : State()
-        object AwaitingUserPrompt : State()
-        object ProcessingUserPrompt : State()
-        object Generating : State()
-        object Benchmarking : State()
-        data class Error(
-            val errorMessage: String = ""
-        ) : State()
+        private val TAG = StubInferenceEngine::class.java.simpleName
+        private const val STUB_MODEL_LOADING_TIME = 2000L
+        private const val STUB_BENCHMARKING_TIME = 4000L
+        private const val STUB_SYSTEM_PROMPT_PROCESSING_TIME = 3000L
+        private const val STUB_USER_PROMPT_PROCESSING_TIME = 1500L
+        private const val STUB_TOKEN_GENERATION_TIME = 200L
     }

     private val _state = MutableStateFlow<State>(State.Uninitialized)
-    val state: StateFlow<State> = _state
-
-    // Keep track of current benchmark results
-    private var _benchmarkResults: String? = null
-    private val _benchmarkResultsFlow = MutableStateFlow<String?>(null)
-    val benchmarkResults: StateFlow<String?> = _benchmarkResultsFlow
+    override val state: StateFlow<State> = _state

     init {
         Log.i(TAG, "Initiated!")
@@ -57,14 +44,14 @@ class InferenceEngine {
     /**
      * Loads a model from the given path with an optional system prompt.
      */
-    suspend fun loadModel(pathToModel: String, systemPrompt: String? = null) {
+    override suspend fun loadModel(pathToModel: String, systemPrompt: String?) {
         Log.i(TAG, "loadModel! state: ${_state.value}")
         try {
             _state.value = State.LoadingModel

             // Simulate model loading
-            delay(2000)
+            delay(STUB_MODEL_LOADING_TIME)

             _state.value = State.ModelLoaded
@@ -72,7 +59,7 @@ class InferenceEngine {
                 _state.value = State.ProcessingSystemPrompt

                 // Simulate processing system prompt
-                delay(3000)
+                delay(STUB_SYSTEM_PROMPT_PROCESSING_TIME)
             }

             _state.value = State.AwaitingUserPrompt
@@ -87,7 +74,7 @@ class InferenceEngine {
     /**
      * Sends a user prompt to the loaded model and returns a Flow of generated tokens.
      */
-    fun sendUserPrompt(message: String, predictLength: Int = DEFAULT_PREDICT_LENGTH): Flow<String> {
+    override fun sendUserPrompt(message: String, predictLength: Int): Flow<String> {
         Log.i(TAG, "sendUserPrompt! state: ${_state.value}")

         _state.value = State.ProcessingUserPrompt
@@ -96,18 +83,15 @@ class InferenceEngine {
         return flow {
             try {
                 // Simulate longer processing time (1.5 seconds)
-                delay(1500)
+                delay(STUB_USER_PROMPT_PROCESSING_TIME)

                 _state.value = State.Generating

                 // Simulate token generation
                 val response = "This is a simulated response from the LLM model. The actual implementation would generate tokens one by one based on the input: $message"
-                val words = response.split(" ")
-
-                for (word in words) {
-                    emit(word + " ")
-                    // Slower token generation (200ms per token instead of 50ms)
-                    delay(200)
+                response.split(" ").forEach {
+                    emit("$it ")
+                    delay(STUB_TOKEN_GENERATION_TIME)
                 }

                 _state.value = State.AwaitingUserPrompt
@@ -131,26 +115,26 @@ class InferenceEngine {
     /**
      * Runs a benchmark with the specified parameters.
      */
-    suspend fun bench(pp: Int, tg: Int, pl: Int, nr: Int = 1): String {
+    override suspend fun bench(pp: Int, tg: Int, pl: Int, nr: Int): String {
         Log.i(TAG, "bench! state: ${_state.value}")

         _state.value = State.Benchmarking
         try {
             // Simulate benchmark running
-            delay(4000)
+            delay(STUB_BENCHMARKING_TIME)

             // Generate fake benchmark results
-            val modelDesc = "LlamaModel"
+            val modelDesc = "Kleidi Llama"
             val model_size = "7"
             val model_n_params = "7"
             val backend = "CPU"

             // Random values for benchmarks
-            val pp_avg = (15.0 + Math.random() * 10.0).toFloat()
-            val pp_std = (0.5 + Math.random() * 2.0).toFloat()
-            val tg_avg = (20.0 + Math.random() * 15.0).toFloat()
-            val tg_std = (0.7 + Math.random() * 3.0).toFloat()
+            val pp_avg = (51.4 + Math.random() * 5.14).toFloat()
+            val pp_std = (5.14 + Math.random() * 0.514).toFloat()
+            val tg_avg = (11.4 + Math.random() * 1.14).toFloat()
+            val tg_std = (1.14 + Math.random() * 0.114).toFloat()

             val result = StringBuilder()
             result.append("| model | size | params | backend | test | t/s |\n")
@@ -160,12 +144,9 @@ class InferenceEngine {
             result.append("| $modelDesc | ${model_size}GiB | ${model_n_params}B | ")
             result.append("$backend | tg $tg | $tg_avg ± $tg_std |\n")

-            _benchmarkResults = result.toString()
-            _benchmarkResultsFlow.value = _benchmarkResults
-
             _state.value = State.AwaitingUserPrompt
-            return _benchmarkResults ?: ""
+            return result.toString()
         } catch (e: CancellationException) {
             // If coroutine is cancelled, propagate cancellation
             _state.value = State.AwaitingUserPrompt
@@ -179,20 +160,18 @@ class InferenceEngine {
     /**
      * Unloads the currently loaded model.
      */
-    suspend fun unloadModel() {
+    override suspend fun unloadModel() {
         Log.i(TAG, "unloadModel! state: ${_state.value}")

         // Simulate model unloading time
         delay(2000)

         _state.value = State.LibraryLoaded
-
-        _benchmarkResults = null
-        _benchmarkResultsFlow.value = null
     }

     /**
      * Cleans up resources when the engine is no longer needed.
      */
-    fun destroy() {
+    override fun destroy() {
         Log.i(TAG, "destroy! state: ${_state.value}")
         _state.value = State.Uninitialized
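Because the stub only delays and emits canned tokens, it can drive UI and service tests without a model file on disk. A sketch of exercising it end to end (the model path is a placeholder; `runBlocking` is used for brevity):

    import kotlinx.coroutines.flow.toList
    import kotlinx.coroutines.runBlocking

    fun main() = runBlocking {
        val engine: InferenceEngine = StubInferenceEngine()
        engine.loadModel("/placeholder/model.gguf")   // just delays, no file I/O
        val reply = engine.sendUserPrompt("Hello").toList().joinToString("")
        println(reply)                                // the canned simulated response
        engine.unloadModel()
    }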

View File

@@ -5,18 +5,21 @@ import android.content.Context
 import android.content.Intent
 import android.content.IntentFilter
 import android.os.BatteryManager
+import dagger.hilt.android.qualifiers.ApplicationContext
 import kotlinx.coroutines.Dispatchers
 import kotlinx.coroutines.delay
 import kotlinx.coroutines.flow.Flow
 import kotlinx.coroutines.flow.flow
 import kotlinx.coroutines.withContext
+import javax.inject.Singleton
 import kotlin.math.roundToInt

 /**
  * Service that monitors device performance metrics such as memory usage,
  * battery level, and temperature.
  */
-class PerformanceMonitor(private val context: Context) {
+@Singleton
+class PerformanceMonitor(@ApplicationContext private val context: Context) {
     /**
      * Provides a flow of memory usage information that updates at the specified interval.
View File

@@ -1,5 +1,6 @@
 package com.example.llama.revamp.ui.screens
+import android.llama.cpp.InferenceEngine.State
 import androidx.compose.foundation.background
 import androidx.compose.foundation.layout.Box
 import androidx.compose.foundation.layout.Column
@@ -23,12 +24,10 @@ import androidx.compose.ui.Alignment
 import androidx.compose.ui.Modifier
 import androidx.compose.ui.unit.dp
 import androidx.hilt.navigation.compose.hiltViewModel
-import com.example.llama.revamp.engine.InferenceEngine
 import com.example.llama.revamp.ui.components.ModelCard
 import com.example.llama.revamp.ui.components.PerformanceAppScaffold
 import com.example.llama.revamp.ui.theme.MonospacedTextStyle
 import com.example.llama.revamp.viewmodel.BenchmarkViewModel
-import com.example.llama.revamp.viewmodel.MainViewModel

 @Composable
 fun BenchmarkScreen(
@@ -67,7 +66,7 @@ fun BenchmarkScreen(
         // Benchmark results or loading indicator
         when {
-            engineState is InferenceEngine.State.Benchmarking -> {
+            engineState is State.Benchmarking -> {
                 Box(
                     modifier = Modifier
                         .fillMaxWidth()

View File

@@ -1,16 +1,12 @@
 package com.example.llama.revamp.ui.screens
-import androidx.compose.animation.AnimatedVisibility
+import android.llama.cpp.InferenceEngine.State
 import androidx.compose.animation.core.LinearEasing
 import androidx.compose.animation.core.RepeatMode
 import androidx.compose.animation.core.animateFloat
 import androidx.compose.animation.core.infiniteRepeatable
 import androidx.compose.animation.core.rememberInfiniteTransition
 import androidx.compose.animation.core.tween
-import androidx.compose.animation.expandVertically
-import androidx.compose.animation.fadeIn
-import androidx.compose.animation.fadeOut
-import androidx.compose.animation.shrinkVertically
 import androidx.compose.foundation.background
 import androidx.compose.foundation.layout.Box
 import androidx.compose.foundation.layout.Column
@@ -33,7 +29,6 @@ import androidx.compose.material.icons.filled.Send
 import androidx.compose.material3.Card
 import androidx.compose.material3.CardDefaults
 import androidx.compose.material3.CircularProgressIndicator
-import androidx.compose.material3.HorizontalDivider
 import androidx.compose.material3.Icon
 import androidx.compose.material3.IconButton
 import androidx.compose.material3.MaterialTheme
@@ -61,8 +56,6 @@ import androidx.hilt.navigation.compose.hiltViewModel
 import androidx.lifecycle.Lifecycle
 import androidx.lifecycle.LifecycleEventObserver
 import androidx.lifecycle.compose.LocalLifecycleOwner
-import com.example.llama.revamp.data.model.ModelInfo
-import com.example.llama.revamp.engine.InferenceEngine
 import com.example.llama.revamp.ui.components.ModelCardWithSystemPrompt
 import com.example.llama.revamp.ui.components.PerformanceAppScaffold
 import com.example.llama.revamp.viewmodel.ConversationViewModel
@@ -82,8 +75,8 @@ fun ConversationScreen(
     val systemPrompt by viewModel.systemPrompt.collectAsState()
     val selectedModel by viewModel.selectedModel.collectAsState()

-    val isProcessing = engineState is InferenceEngine.State.ProcessingUserPrompt
-    val isGenerating = engineState is InferenceEngine.State.Generating
+    val isProcessing = engineState is State.ProcessingUserPrompt
+    val isGenerating = engineState is State.Generating

     val listState = rememberLazyListState()
     var inputText by remember { mutableStateOf("") }

View File

@@ -1,5 +1,6 @@
 package com.example.llama.revamp.ui.screens
+import android.llama.cpp.InferenceEngine.State
 import androidx.compose.animation.AnimatedVisibility
 import androidx.compose.animation.expandVertically
 import androidx.compose.animation.fadeIn
@@ -49,7 +50,6 @@ import androidx.compose.ui.text.style.TextOverflow
 import androidx.compose.ui.unit.dp
 import androidx.hilt.navigation.compose.hiltViewModel
 import com.example.llama.revamp.data.model.SystemPrompt
-import com.example.llama.revamp.engine.InferenceEngine
 import com.example.llama.revamp.ui.components.ModelCard
 import com.example.llama.revamp.ui.components.PerformanceAppScaffold
 import com.example.llama.revamp.viewmodel.ModelLoadingViewModel
@@ -63,7 +63,7 @@ enum class SystemPromptTab {
 @OptIn(ExperimentalMaterial3Api::class, ExperimentalFoundationApi::class)
 @Composable
 fun ModelLoadingScreen(
-    engineState: InferenceEngine.State,
+    engineState: State,
     onBenchmarkSelected: (prepareJob: Job) -> Unit,
     onConversationSelected: (systemPrompt: String?, prepareJob: Job) -> Unit,
     onBackPressed: () -> Unit,
@@ -99,9 +99,9 @@ fun ModelLoadingScreen(
     }

     // Check if we're in a loading state
-    val isLoading = engineState !is InferenceEngine.State.Uninitialized &&
-        engineState !is InferenceEngine.State.LibraryLoaded &&
-        engineState !is InferenceEngine.State.AwaitingUserPrompt
+    val isLoading = engineState !is State.Uninitialized &&
+        engineState !is State.LibraryLoaded &&
+        engineState !is State.AwaitingUserPrompt

     // Mode selection callbacks
     val handleBenchmarkSelected = {
@@ -429,9 +429,9 @@ fun ModelLoadingScreen(
                 Spacer(modifier = Modifier.width(8.dp))
                 Text(
                     text = when (engineState) {
-                        is InferenceEngine.State.LoadingModel -> "Loading model..."
-                        is InferenceEngine.State.ProcessingSystemPrompt -> "Processing system prompt..."
-                        is InferenceEngine.State.ModelLoaded -> "Preparing conversation..."
+                        is State.LoadingModel -> "Loading model..."
+                        is State.ProcessingSystemPrompt -> "Processing system prompt..."
+                        is State.ModelLoaded -> "Preparing conversation..."
                         else -> "Processing..."
                     },
                     style = MaterialTheme.typography.titleMedium

View File

@@ -1,10 +1,10 @@
 package com.example.llama.revamp.viewmodel
+import android.llama.cpp.InferenceEngine.State
 import androidx.lifecycle.ViewModel
 import androidx.lifecycle.viewModelScope
 import com.example.llama.revamp.data.model.ModelInfo
 import com.example.llama.revamp.engine.BenchmarkService
-import com.example.llama.revamp.engine.InferenceEngine
 import dagger.hilt.android.lifecycle.HiltViewModel
 import kotlinx.coroutines.flow.StateFlow
 import kotlinx.coroutines.launch
@@ -15,8 +15,8 @@ class BenchmarkViewModel @Inject constructor(
     private val benchmarkService: BenchmarkService
 ) : ViewModel() {

-    val engineState: StateFlow<InferenceEngine.State> = benchmarkService.engineState
-    val benchmarkResults: StateFlow<String?> = benchmarkService.results
+    val engineState: StateFlow<State> = benchmarkService.engineState
+    val benchmarkResults: StateFlow<String?> = benchmarkService.benchmarkResults
     val selectedModel: StateFlow<ModelInfo?> = benchmarkService.currentSelectedModel

     /**

View File

@@ -1,7 +1,6 @@
 package com.example.llama.revamp.viewmodel
 import androidx.lifecycle.ViewModel
-import com.example.llama.revamp.engine.InferenceEngine
 import com.example.llama.revamp.engine.InferenceService
 import dagger.hilt.android.lifecycle.HiltViewModel
 import javax.inject.Inject

View File

@@ -0,0 +1,64 @@
+package android.llama.cpp
+
+import kotlinx.coroutines.flow.Flow
+import kotlinx.coroutines.flow.StateFlow
+
+/**
+ * Interface defining the core LLM inference operations.
+ */
+interface InferenceEngine {
+    /**
+     * Current state of the inference engine
+     */
+    val state: StateFlow<State>
+
+    /**
+     * Load a model from the given path with an optional system prompt.
+     */
+    suspend fun loadModel(pathToModel: String, systemPrompt: String? = null)
+
+    /**
+     * Sends a user prompt to the loaded model and returns a Flow of generated tokens.
+     */
+    fun sendUserPrompt(message: String, predictLength: Int = DEFAULT_PREDICT_LENGTH): Flow<String>
+
+    /**
+     * Runs a benchmark with the specified parameters.
+     */
+    suspend fun bench(pp: Int, tg: Int, pl: Int, nr: Int = 1): String
+
+    /**
+     * Unloads the currently loaded model.
+     */
+    suspend fun unloadModel()
+
+    /**
+     * Cleans up resources when the engine is no longer needed.
+     */
+    fun destroy()
+
+    /**
+     * States of the inference engine
+     */
+    sealed class State {
+        object Uninitialized : State()
+        object LibraryLoaded : State()
+        object LoadingModel : State()
+        object ModelLoaded : State()
+        object ProcessingSystemPrompt : State()
+        object AwaitingUserPrompt : State()
+        object ProcessingUserPrompt : State()
+        object Generating : State()
+        object Benchmarking : State()
+        data class Error(val errorMessage: String = "") : State()
+    }
+
+    companion object {
+        const val DEFAULT_PREDICT_LENGTH = 1024
+    }
+}
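Both engines now share this contract, so call sites can branch on `State` without knowing which implementation is bound. A small sketch of the state-to-label mapping pattern the screens above use:

    fun label(state: InferenceEngine.State): String = when (state) {
        is InferenceEngine.State.LoadingModel -> "Loading model..."
        is InferenceEngine.State.Generating -> "Generating..."
        is InferenceEngine.State.Error -> "Error: ${state.errorMessage}"
        else -> "Idle"
    }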

View File

@@ -1,5 +1,7 @@
 package android.llama.cpp
+import android.llama.cpp.InferenceEngine.State
+import android.llama.cpp.LLamaAndroid.Companion.instance
 import android.util.Log
 import kotlinx.coroutines.CoroutineScope
 import kotlinx.coroutines.Dispatchers
@@ -38,7 +40,7 @@ annotation class RequiresCleanup(val message: String = "Remember to call this me
  *
  * @see llama-android.cpp for the native implementation details
  */
-class LLamaAndroid private constructor() {
+class LLamaAndroid private constructor() : InferenceEngine {
     /**
      * JNI methods
      * @see llama-android.cpp
@@ -57,32 +59,8 @@ class LLamaAndroid private constructor() {
     private external fun unload()
     private external fun shutdown()

-    /**
-     * Fine-grained state management
-     */
-    sealed class State {
-        object Uninitialized : State()
-        object LibraryLoaded : State()
-        object LoadingModel : State()
-        object ModelLoaded : State()
-        object ProcessingSystemPrompt : State()
-        object AwaitingUserPrompt : State()
-        object ProcessingUserPrompt : State()
-        object Generating : State()
-        object Benchmarking : State()
-        data class Error(
-            val errorMessage: String = ""
-        ) : State()
-    }
-
     private val _state = MutableStateFlow<State>(State.Uninitialized)
-    val state: StateFlow<State> = _state
+    override val state: StateFlow<State> = _state

     /**
      * Single-threaded coroutine dispatcher & scope for LLama asynchronous operations
@@ -108,7 +86,7 @@ class LLamaAndroid private constructor() {
     /**
      * Load the LLM, then process the plain text system prompt if provided
      */
-    suspend fun loadModel(pathToModel: String, systemPrompt: String? = null) =
+    override suspend fun loadModel(pathToModel: String, systemPrompt: String?) =
         withContext(llamaDispatcher) {
             check(_state.value is State.LibraryLoaded) { "Cannot load model in ${_state.value}!" }
             File(pathToModel).let {
@@ -147,9 +125,9 @@ class LLamaAndroid private constructor() {
     /**
      * Send plain text user prompt to LLM, which starts generating tokens in a [Flow]
     */
-    fun sendUserPrompt(
+    override fun sendUserPrompt(
         message: String,
-        predictLength: Int = DEFAULT_PREDICT_LENGTH,
+        predictLength: Int,
     ): Flow<String> = flow {
         require(message.isNotEmpty()) { "User prompt discarded due to being empty!" }
         check(_state.value is State.AwaitingUserPrompt) {
@@ -179,7 +157,7 @@ class LLamaAndroid private constructor() {
     /**
      * Benchmark the model
     */
-    suspend fun bench(pp: Int, tg: Int, pl: Int, nr: Int = 1): String =
+    override suspend fun bench(pp: Int, tg: Int, pl: Int, nr: Int): String =
         withContext(llamaDispatcher) {
             check(_state.value is State.AwaitingUserPrompt) {
                 "Benchmark request discarded due to: $state"
@@ -194,7 +172,7 @@ class LLamaAndroid private constructor() {
     /**
      * Unloads the model and frees resources
     */
-    suspend fun unloadModel() =
+    override suspend fun unloadModel() =
         withContext(llamaDispatcher) {
             when(_state.value) {
                 is State.AwaitingUserPrompt, is State.Error -> {
@@ -202,6 +180,7 @@ class LLamaAndroid private constructor() {
                     unload()
                     _state.value = State.LibraryLoaded
                     Log.i(TAG, "Model unloaded!")
+                    Unit
                 }
                 else -> throw IllegalStateException("Cannot unload model in ${_state.value}")
             }
@@ -211,7 +190,7 @@ class LLamaAndroid private constructor() {
      * Cancel all ongoing coroutines and free GGML backends
      */
     @RequiresCleanup("Call from `ViewModel.onCleared()` to prevent resource leaks!")
-    fun destroy() {
+    override fun destroy() {
        llamaScope.cancel()
        when(_state.value) {
            is State.Uninitialized -> {}
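Per the `@RequiresCleanup` contract above, `destroy()` should be invoked from `ViewModel.onCleared()`. A minimal sketch of an owning ViewModel (the class name is hypothetical; Hilt annotations assumed imported):

    @HiltViewModel
    class EngineHostViewModel @Inject constructor(
        private val engine: InferenceEngine
    ) : ViewModel() {
        override fun onCleared() {
            engine.destroy()   // cancels llamaScope and frees GGML backends
            super.onCleared()
        }
    }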