From d70b8fe323b709c3b3b987ba995680ad0bf75b1e Mon Sep 17 00:00:00 2001
From: Han Yin
Date: Wed, 16 Apr 2025 11:14:14 -0700
Subject: [PATCH] core: swap in LLamaAndroid and mark stub engine for testing only

---
 .../com/example/llama/revamp/MainActivity.kt  | 12 +--
 .../com/example/llama/revamp/di/AppModule.kt  | 10 +-
 .../llama/revamp/engine/InferenceServices.kt  | 20 ++--
 ...erenceEngine.kt => StubInferenceEngine.kt} | 91 +++++++------
 .../revamp/monitoring/PerformanceMonitor.kt   |  5 +-
 .../revamp/ui/screens/BenchmarkScreen.kt      |  5 +-
 .../revamp/ui/screens/ConversationScreen.kt   | 13 +--
 .../revamp/ui/screens/ModelLoadingScreen.kt   | 16 ++--
 .../revamp/viewmodel/BenchmarkViewModel.kt    |  6 +-
 .../llama/revamp/viewmodel/MainViewModel.kt   |  1 -
 .../java/android/llama/cpp/InferenceEngine.kt | 64 +++++++++++++
 .../java/android/llama/cpp/LLamaAndroid.kt    | 43 +++------
 12 files changed, 157 insertions(+), 129 deletions(-)
 rename examples/llama.android/app/src/main/java/com/example/llama/revamp/engine/{InferenceEngine.kt => StubInferenceEngine.kt} (67%)
 create mode 100644 examples/llama.android/llama/src/main/java/android/llama/cpp/InferenceEngine.kt

diff --git a/examples/llama.android/app/src/main/java/com/example/llama/revamp/MainActivity.kt b/examples/llama.android/app/src/main/java/com/example/llama/revamp/MainActivity.kt
index 03f57348ed..cd77902f37 100644
--- a/examples/llama.android/app/src/main/java/com/example/llama/revamp/MainActivity.kt
+++ b/examples/llama.android/app/src/main/java/com/example/llama/revamp/MainActivity.kt
@@ -1,5 +1,6 @@
 package com.example.llama.revamp
 
+import android.llama.cpp.InferenceEngine.State
 import android.os.Bundle
 import androidx.activity.ComponentActivity
 import androidx.activity.OnBackPressedCallback
@@ -27,7 +28,6 @@ import androidx.lifecycle.compose.LocalLifecycleOwner
 import androidx.navigation.compose.composable
 import androidx.navigation.compose.currentBackStackEntryAsState
 import androidx.navigation.compose.rememberNavController
-import com.example.llama.revamp.engine.InferenceEngine
 import com.example.llama.revamp.navigation.AppDestinations
 import com.example.llama.revamp.navigation.NavigationActions
 import com.example.llama.revamp.ui.components.AnimatedNavHost
@@ -35,9 +35,9 @@ import com.example.llama.revamp.ui.components.AppNavigationDrawer
 import com.example.llama.revamp.ui.components.UnloadModelConfirmationDialog
 import com.example.llama.revamp.ui.screens.BenchmarkScreen
 import com.example.llama.revamp.ui.screens.ConversationScreen
+import com.example.llama.revamp.ui.screens.ModelLoadingScreen
 import com.example.llama.revamp.ui.screens.ModelSelectionScreen
 import com.example.llama.revamp.ui.screens.ModelsManagementScreen
-import com.example.llama.revamp.ui.screens.ModelLoadingScreen
 import com.example.llama.revamp.ui.screens.SettingsGeneralScreen
 import com.example.llama.revamp.ui.theme.LlamaTheme
 import com.example.llama.revamp.viewmodel.MainViewModel
@@ -72,10 +72,10 @@ fun AppContent(
 
     // LLM Inference engine status
     val engineState by mainVewModel.engineState.collectAsState()
-    val isModelLoading = engineState is InferenceEngine.State.LoadingModel
-        || engineState is InferenceEngine.State.ProcessingSystemPrompt
-    val isModelLoaded = engineState !is InferenceEngine.State.Uninitialized
-        && engineState !is InferenceEngine.State.LibraryLoaded
+    val isModelLoading = engineState is State.LoadingModel
+        || engineState is State.ProcessingSystemPrompt
+    val isModelLoaded = engineState !is State.Uninitialized
+        && engineState !is State.LibraryLoaded
 
     // Navigation
     val navController = rememberNavController()
diff --git a/examples/llama.android/app/src/main/java/com/example/llama/revamp/di/AppModule.kt b/examples/llama.android/app/src/main/java/com/example/llama/revamp/di/AppModule.kt
index fef6c26ea2..9f63a42a76 100644
--- a/examples/llama.android/app/src/main/java/com/example/llama/revamp/di/AppModule.kt
+++ b/examples/llama.android/app/src/main/java/com/example/llama/revamp/di/AppModule.kt
@@ -1,6 +1,8 @@
 package com.example.llama.revamp.di
 
 import android.content.Context
+import android.llama.cpp.InferenceEngine
+import android.llama.cpp.LLamaAndroid
 import com.example.llama.revamp.data.local.AppDatabase
 import com.example.llama.revamp.data.repository.ModelRepository
 import com.example.llama.revamp.data.repository.ModelRepositoryImpl
@@ -8,7 +10,7 @@ import com.example.llama.revamp.data.repository.SystemPromptRepository
 import com.example.llama.revamp.data.repository.SystemPromptRepositoryImpl
 import com.example.llama.revamp.engine.BenchmarkService
 import com.example.llama.revamp.engine.ConversationService
-import com.example.llama.revamp.engine.InferenceEngine
+import com.example.llama.revamp.engine.StubInferenceEngine
 import com.example.llama.revamp.engine.InferenceService
 import com.example.llama.revamp.engine.InferenceServiceImpl
 import com.example.llama.revamp.engine.ModelLoadingService
@@ -46,10 +48,12 @@ internal abstract class AppModule {
     companion object {
         @Provides
         @Singleton
-        fun provideInferenceEngine() = InferenceEngine()
+        fun provideInferenceEngine(): InferenceEngine {
+            val useRealEngine = true
+            return if (useRealEngine) LLamaAndroid.instance() else StubInferenceEngine()
+        }
 
         @Provides
-        @Singleton
         fun providePerformanceMonitor(@ApplicationContext context: Context) = PerformanceMonitor(context)
 
         @Provides
diff --git a/examples/llama.android/app/src/main/java/com/example/llama/revamp/engine/InferenceServices.kt b/examples/llama.android/app/src/main/java/com/example/llama/revamp/engine/InferenceServices.kt
index 3da3065224..1d9d356e33 100644
--- a/examples/llama.android/app/src/main/java/com/example/llama/revamp/engine/InferenceServices.kt
+++ b/examples/llama.android/app/src/main/java/com/example/llama/revamp/engine/InferenceServices.kt
@@ -1,5 +1,7 @@
 package com.example.llama.revamp.engine
 
+import android.llama.cpp.InferenceEngine
+import android.llama.cpp.InferenceEngine.State
 import android.util.Log
 import com.example.llama.revamp.data.model.ModelInfo
 import kotlinx.coroutines.flow.Flow
@@ -14,7 +16,7 @@ interface InferenceService {
     /**
      * Expose engine state
      */
-    val engineState: StateFlow<InferenceEngine.State>
+    val engineState: StateFlow<State>
 
     /**
      * Currently selected model
@@ -51,14 +53,14 @@
      * @param pp: Prompt Processing size
      * @param tg: Token Generation size
      * @param pl: Parallel sequences
-     * @param nr: repetitions (Number of Runs)
+     * @param nr: Number of Runs, i.e. repetitions
      */
     suspend fun benchmark(pp: Int, tg: Int, pl: Int, nr: Int): String
 
     /**
      * Benchmark results
      */
-    val results: StateFlow<String?>
+    val benchmarkResults: StateFlow<String?>
 }
 
 interface ConversationService : InferenceService {
@@ -108,7 +110,7 @@ internal class InferenceServiceImpl @Inject internal constructor(
 
     /* InferenceService implementation */
 
-    override val engineState: StateFlow<InferenceEngine.State> = inferenceEngine.state
+    override val engineState: StateFlow<State> = inferenceEngine.state
 
     private val _currentModel = MutableStateFlow<ModelInfo?>(null)
     override val currentSelectedModel: StateFlow<ModelInfo?> = _currentModel.asStateFlow()
@@ -156,9 +158,15 @@
 
     /* BenchmarkService implementation */
 
     override suspend fun benchmark(pp: Int, tg: Int, pl: Int, nr: Int): String =
-        inferenceEngine.bench(pp, tg, pl, nr)
+        inferenceEngine.bench(pp, tg, pl, nr).also {
+            _benchmarkResults.value = it
+        }
 
-    override val results: StateFlow<String?> = inferenceEngine.benchmarkResults
+    /**
+     * Benchmark results if available
+     */
+    private val _benchmarkResults = MutableStateFlow<String?>(null)
+    override val benchmarkResults: StateFlow<String?> = _benchmarkResults
 
     /* ConversationService implementation */
diff --git a/examples/llama.android/app/src/main/java/com/example/llama/revamp/engine/InferenceEngine.kt b/examples/llama.android/app/src/main/java/com/example/llama/revamp/engine/StubInferenceEngine.kt
similarity index 67%
rename from examples/llama.android/app/src/main/java/com/example/llama/revamp/engine/InferenceEngine.kt
rename to examples/llama.android/app/src/main/java/com/example/llama/revamp/engine/StubInferenceEngine.kt
index 61f089dc66..97943f386d 100644
--- a/examples/llama.android/app/src/main/java/com/example/llama/revamp/engine/InferenceEngine.kt
+++ b/examples/llama.android/app/src/main/java/com/example/llama/revamp/engine/StubInferenceEngine.kt
@@ -1,5 +1,7 @@
 package com.example.llama.revamp.engine
 
+import android.llama.cpp.InferenceEngine
+import android.llama.cpp.InferenceEngine.State
 import android.util.Log
 import kotlinx.coroutines.CancellationException
 import kotlinx.coroutines.delay
@@ -8,44 +10,29 @@ import kotlinx.coroutines.flow.MutableStateFlow
 import kotlinx.coroutines.flow.StateFlow
 import kotlinx.coroutines.flow.catch
 import kotlinx.coroutines.flow.flow
+import org.jetbrains.annotations.TestOnly
+import org.jetbrains.annotations.VisibleForTesting
+import javax.inject.Singleton
 
 /**
- * LLM inference engine that handles model loading and text generation.
+ * A stub [InferenceEngine] for agile development & testing
  */
-class InferenceEngine {
+@VisibleForTesting
+@TestOnly
+@Singleton
+class StubInferenceEngine : InferenceEngine {
     companion object {
-        private val TAG = InferenceEngine::class.java.simpleName
+        private val TAG = StubInferenceEngine::class.java.simpleName
 
-        private const val DEFAULT_PREDICT_LENGTH = 1024
-    }
-
-    sealed class State {
-        object Uninitialized : State()
-        object LibraryLoaded : State()
-
-        object LoadingModel : State()
-        object ModelLoaded : State()
-
-        object ProcessingSystemPrompt : State()
-        object AwaitingUserPrompt : State()
-
-        object ProcessingUserPrompt : State()
-        object Generating : State()
-
-        object Benchmarking : State()
-
-        data class Error(
-            val errorMessage: String = ""
-        ) : State()
+        private const val STUB_MODEL_LOADING_TIME = 2000L
+        private const val STUB_BENCHMARKING_TIME = 4000L
+        private const val STUB_SYSTEM_PROMPT_PROCESSING_TIME = 3000L
+        private const val STUB_USER_PROMPT_PROCESSING_TIME = 1500L
+        private const val STUB_TOKEN_GENERATION_TIME = 200L
     }
 
     private val _state = MutableStateFlow<State>(State.Uninitialized)
-    val state: StateFlow<State> = _state
-
-    // Keep track of current benchmark results
-    private var _benchmarkResults: String? = null
-    private val _benchmarkResultsFlow = MutableStateFlow<String?>(null)
-    val benchmarkResults: StateFlow<String?> = _benchmarkResultsFlow
+    override val state: StateFlow<State> = _state
 
     init {
         Log.i(TAG, "Initiated!")
@@ -57,14 +44,14 @@
     /**
      * Loads a model from the given path with an optional system prompt.
      */
-    suspend fun loadModel(pathToModel: String, systemPrompt: String? = null) {
+    override suspend fun loadModel(pathToModel: String, systemPrompt: String?) {
         Log.i(TAG, "loadModel! state: ${_state.value}")
 
         try {
            _state.value = State.LoadingModel
 
            // Simulate model loading
-           delay(2000)
+           delay(STUB_MODEL_LOADING_TIME)
 
            _state.value = State.ModelLoaded
 
@@ -72,7 +59,7 @@
                _state.value = State.ProcessingSystemPrompt
 
                // Simulate processing system prompt
-               delay(3000)
+               delay(STUB_SYSTEM_PROMPT_PROCESSING_TIME)
            }
 
            _state.value = State.AwaitingUserPrompt
@@ -87,7 +74,7 @@
     /**
      * Sends a user prompt to the loaded model and returns a Flow of generated tokens.
      */
-    fun sendUserPrompt(message: String, predictLength: Int = DEFAULT_PREDICT_LENGTH): Flow<String> {
+    override fun sendUserPrompt(message: String, predictLength: Int): Flow<String> {
         Log.i(TAG, "sendUserPrompt! state: ${_state.value}")
 
         _state.value = State.ProcessingUserPrompt
@@ -96,18 +83,15 @@
         return flow {
             try {
                 // Simulate longer processing time (1.5 seconds)
-                delay(1500)
+                delay(STUB_USER_PROMPT_PROCESSING_TIME)
 
                 _state.value = State.Generating
 
                 // Simulate token generation
                 val response = "This is a simulated response from the LLM model. The actual implementation would generate tokens one by one based on the input: $message"
-                val words = response.split(" ")
-
-                for (word in words) {
-                    emit(word + " ")
-                    // Slower token generation (200ms per token instead of 50ms)
-                    delay(200)
+                response.split(" ").forEach {
+                    emit("$it ")
+                    delay(STUB_TOKEN_GENERATION_TIME)
                 }
 
                 _state.value = State.AwaitingUserPrompt
@@ -131,26 +115,26 @@
     /**
      * Runs a benchmark with the specified parameters.
      */
-    suspend fun bench(pp: Int, tg: Int, pl: Int, nr: Int = 1): String {
+    override suspend fun bench(pp: Int, tg: Int, pl: Int, nr: Int): String {
         Log.i(TAG, "bench! state: ${_state.value}")
 
         _state.value = State.Benchmarking
 
         try {
             // Simulate benchmark running
-            delay(4000)
+            delay(STUB_BENCHMARKING_TIME)
 
             // Generate fake benchmark results
-            val modelDesc = "LlamaModel"
+            val modelDesc = "Kleidi Llama"
             val model_size = "7"
             val model_n_params = "7"
             val backend = "CPU"
 
             // Random values for benchmarks
-            val pp_avg = (15.0 + Math.random() * 10.0).toFloat()
-            val pp_std = (0.5 + Math.random() * 2.0).toFloat()
-            val tg_avg = (20.0 + Math.random() * 15.0).toFloat()
-            val tg_std = (0.7 + Math.random() * 3.0).toFloat()
+            val pp_avg = (51.4 + Math.random() * 5.14).toFloat()
+            val pp_std = (5.14 + Math.random() * 0.514).toFloat()
+            val tg_avg = (11.4 + Math.random() * 1.14).toFloat()
+            val tg_std = (1.14 + Math.random() * 0.114).toFloat()
 
             val result = StringBuilder()
             result.append("| model | size | params | backend | test | t/s |\n")
@@ -160,12 +144,9 @@
             result.append("| $modelDesc | ${model_size}GiB | ${model_n_params}B | ")
             result.append("$backend | tg $tg | $tg_avg ± $tg_std |\n")
 
-            _benchmarkResults = result.toString()
-            _benchmarkResultsFlow.value = _benchmarkResults
-
             _state.value = State.AwaitingUserPrompt
 
-            return _benchmarkResults ?: ""
+            return result.toString()
         } catch (e: CancellationException) {
             // If coroutine is cancelled, propagate cancellation
             _state.value = State.AwaitingUserPrompt
@@ -179,20 +160,18 @@
     /**
      * Unloads the currently loaded model.
     */
-    suspend fun unloadModel() {
+    override suspend fun unloadModel() {
         Log.i(TAG, "unloadModel! state: ${_state.value}")
 
         // Simulate model unloading time
         delay(2000)
 
         _state.value = State.LibraryLoaded
-
-        _benchmarkResults = null
-        _benchmarkResultsFlow.value = null
     }
 
     /**
     * Cleans up resources when the engine is no longer needed.
    */
-    fun destroy() {
+    override fun destroy() {
        Log.i(TAG, "destroy! state: ${_state.value}")
 
        _state.value = State.Uninitialized
diff --git a/examples/llama.android/app/src/main/java/com/example/llama/revamp/monitoring/PerformanceMonitor.kt b/examples/llama.android/app/src/main/java/com/example/llama/revamp/monitoring/PerformanceMonitor.kt
index a6a68af126..d3f55b300f 100644
--- a/examples/llama.android/app/src/main/java/com/example/llama/revamp/monitoring/PerformanceMonitor.kt
+++ b/examples/llama.android/app/src/main/java/com/example/llama/revamp/monitoring/PerformanceMonitor.kt
@@ -5,18 +5,21 @@ import android.content.Context
 import android.content.Intent
 import android.content.IntentFilter
 import android.os.BatteryManager
+import dagger.hilt.android.qualifiers.ApplicationContext
 import kotlinx.coroutines.Dispatchers
 import kotlinx.coroutines.delay
 import kotlinx.coroutines.flow.Flow
 import kotlinx.coroutines.flow.flow
 import kotlinx.coroutines.withContext
+import javax.inject.Singleton
 import kotlin.math.roundToInt
 
 /**
  * Service that monitors device performance metrics such as memory usage,
  * battery level, and temperature.
  */
-class PerformanceMonitor(private val context: Context) {
+@Singleton
+class PerformanceMonitor(@ApplicationContext private val context: Context) {
 
     /**
      * Provides a flow of memory usage information that updates at the specified interval.
diff --git a/examples/llama.android/app/src/main/java/com/example/llama/revamp/ui/screens/BenchmarkScreen.kt b/examples/llama.android/app/src/main/java/com/example/llama/revamp/ui/screens/BenchmarkScreen.kt
index 28e2661d69..d132b7e4fb 100644
--- a/examples/llama.android/app/src/main/java/com/example/llama/revamp/ui/screens/BenchmarkScreen.kt
+++ b/examples/llama.android/app/src/main/java/com/example/llama/revamp/ui/screens/BenchmarkScreen.kt
@@ -1,5 +1,6 @@
 package com.example.llama.revamp.ui.screens
 
+import android.llama.cpp.InferenceEngine.State
 import androidx.compose.foundation.background
 import androidx.compose.foundation.layout.Box
 import androidx.compose.foundation.layout.Column
@@ -23,12 +24,10 @@ import androidx.compose.ui.Alignment
 import androidx.compose.ui.Modifier
 import androidx.compose.ui.unit.dp
 import androidx.hilt.navigation.compose.hiltViewModel
-import com.example.llama.revamp.engine.InferenceEngine
 import com.example.llama.revamp.ui.components.ModelCard
 import com.example.llama.revamp.ui.components.PerformanceAppScaffold
 import com.example.llama.revamp.ui.theme.MonospacedTextStyle
 import com.example.llama.revamp.viewmodel.BenchmarkViewModel
-import com.example.llama.revamp.viewmodel.MainViewModel
 
 @Composable
 fun BenchmarkScreen(
@@ -67,7 +66,7 @@ fun BenchmarkScreen(
 
         // Benchmark results or loading indicator
         when {
-            engineState is InferenceEngine.State.Benchmarking -> {
+            engineState is State.Benchmarking -> {
                 Box(
                     modifier = Modifier
                         .fillMaxWidth()
diff --git a/examples/llama.android/app/src/main/java/com/example/llama/revamp/ui/screens/ConversationScreen.kt b/examples/llama.android/app/src/main/java/com/example/llama/revamp/ui/screens/ConversationScreen.kt
index 54e8f22b93..86ce7f9f5b 100644
--- a/examples/llama.android/app/src/main/java/com/example/llama/revamp/ui/screens/ConversationScreen.kt
+++ b/examples/llama.android/app/src/main/java/com/example/llama/revamp/ui/screens/ConversationScreen.kt
@@ -1,16 +1,12 @@
 package com.example.llama.revamp.ui.screens
 
-import androidx.compose.animation.AnimatedVisibility
+import android.llama.cpp.InferenceEngine.State
 import androidx.compose.animation.core.LinearEasing
 import androidx.compose.animation.core.RepeatMode
 import androidx.compose.animation.core.animateFloat
 import androidx.compose.animation.core.infiniteRepeatable
 import androidx.compose.animation.core.rememberInfiniteTransition
 import androidx.compose.animation.core.tween
-import androidx.compose.animation.expandVertically
-import androidx.compose.animation.fadeIn
-import androidx.compose.animation.fadeOut
-import androidx.compose.animation.shrinkVertically
 import androidx.compose.foundation.background
 import androidx.compose.foundation.layout.Box
 import androidx.compose.foundation.layout.Column
@@ -33,7 +29,6 @@ import androidx.compose.material.icons.filled.Send
 import androidx.compose.material3.Card
 import androidx.compose.material3.CardDefaults
 import androidx.compose.material3.CircularProgressIndicator
-import androidx.compose.material3.HorizontalDivider
 import androidx.compose.material3.Icon
 import androidx.compose.material3.IconButton
 import androidx.compose.material3.MaterialTheme
@@ -61,8 +56,6 @@ import androidx.hilt.navigation.compose.hiltViewModel
 import androidx.lifecycle.Lifecycle
 import androidx.lifecycle.LifecycleEventObserver
 import androidx.lifecycle.compose.LocalLifecycleOwner
-import com.example.llama.revamp.data.model.ModelInfo
-import com.example.llama.revamp.engine.InferenceEngine
 import com.example.llama.revamp.ui.components.ModelCardWithSystemPrompt
 import com.example.llama.revamp.ui.components.PerformanceAppScaffold
 import com.example.llama.revamp.viewmodel.ConversationViewModel
@@ -82,8 +75,8 @@ fun ConversationScreen(
     val systemPrompt by viewModel.systemPrompt.collectAsState()
     val selectedModel by viewModel.selectedModel.collectAsState()
 
-    val isProcessing = engineState is InferenceEngine.State.ProcessingUserPrompt
-    val isGenerating = engineState is InferenceEngine.State.Generating
+    val isProcessing = engineState is State.ProcessingUserPrompt
+    val isGenerating = engineState is State.Generating
 
     val listState = rememberLazyListState()
     var inputText by remember { mutableStateOf("") }
diff --git a/examples/llama.android/app/src/main/java/com/example/llama/revamp/ui/screens/ModelLoadingScreen.kt b/examples/llama.android/app/src/main/java/com/example/llama/revamp/ui/screens/ModelLoadingScreen.kt
index 412726a8cc..ecaca9703d 100644
--- a/examples/llama.android/app/src/main/java/com/example/llama/revamp/ui/screens/ModelLoadingScreen.kt
+++ b/examples/llama.android/app/src/main/java/com/example/llama/revamp/ui/screens/ModelLoadingScreen.kt
@@ -1,5 +1,6 @@
 package com.example.llama.revamp.ui.screens
 
+import android.llama.cpp.InferenceEngine.State
 import androidx.compose.animation.AnimatedVisibility
 import androidx.compose.animation.expandVertically
 import androidx.compose.animation.fadeIn
@@ -49,7 +50,6 @@ import androidx.compose.ui.text.style.TextOverflow
 import androidx.compose.ui.unit.dp
 import androidx.hilt.navigation.compose.hiltViewModel
 import com.example.llama.revamp.data.model.SystemPrompt
-import com.example.llama.revamp.engine.InferenceEngine
 import com.example.llama.revamp.ui.components.ModelCard
 import com.example.llama.revamp.ui.components.PerformanceAppScaffold
 import com.example.llama.revamp.viewmodel.ModelLoadingViewModel
@@ -63,7 +63,7 @@ enum class SystemPromptTab {
 @OptIn(ExperimentalMaterial3Api::class, ExperimentalFoundationApi::class)
 @Composable
 fun ModelLoadingScreen(
-    engineState: InferenceEngine.State,
+    engineState: State,
     onBenchmarkSelected: (prepareJob: Job) -> Unit,
     onConversationSelected: (systemPrompt: String?, prepareJob: Job) -> Unit,
     onBackPressed: () -> Unit,
@@ -99,9 +99,9 @@ fun ModelLoadingScreen(
     }
 
     // Check if we're in a loading state
-    val isLoading = engineState !is InferenceEngine.State.Uninitialized &&
-        engineState !is InferenceEngine.State.LibraryLoaded &&
-        engineState !is InferenceEngine.State.AwaitingUserPrompt
+    val isLoading = engineState !is State.Uninitialized &&
+        engineState !is State.LibraryLoaded &&
+        engineState !is State.AwaitingUserPrompt
 
     // Mode selection callbacks
     val handleBenchmarkSelected = {
@@ -429,9 +429,9 @@ fun ModelLoadingScreen(
                             Spacer(modifier = Modifier.width(8.dp))
                             Text(
                                 text = when (engineState) {
-                                    is InferenceEngine.State.LoadingModel -> "Loading model..."
-                                    is InferenceEngine.State.ProcessingSystemPrompt -> "Processing system prompt..."
-                                    is InferenceEngine.State.ModelLoaded -> "Preparing conversation..."
+                                    is State.LoadingModel -> "Loading model..."
+                                    is State.ProcessingSystemPrompt -> "Processing system prompt..."
+                                    is State.ModelLoaded -> "Preparing conversation..."
                                     else -> "Processing..."
                                 },
                                 style = MaterialTheme.typography.titleMedium
diff --git a/examples/llama.android/app/src/main/java/com/example/llama/revamp/viewmodel/BenchmarkViewModel.kt b/examples/llama.android/app/src/main/java/com/example/llama/revamp/viewmodel/BenchmarkViewModel.kt
index d4427ad3e8..078992705c 100644
--- a/examples/llama.android/app/src/main/java/com/example/llama/revamp/viewmodel/BenchmarkViewModel.kt
+++ b/examples/llama.android/app/src/main/java/com/example/llama/revamp/viewmodel/BenchmarkViewModel.kt
@@ -1,10 +1,10 @@
 package com.example.llama.revamp.viewmodel
 
+import android.llama.cpp.InferenceEngine.State
 import androidx.lifecycle.ViewModel
 import androidx.lifecycle.viewModelScope
 import com.example.llama.revamp.data.model.ModelInfo
 import com.example.llama.revamp.engine.BenchmarkService
-import com.example.llama.revamp.engine.InferenceEngine
 import dagger.hilt.android.lifecycle.HiltViewModel
 import kotlinx.coroutines.flow.StateFlow
 import kotlinx.coroutines.launch
@@ -15,8 +15,8 @@ class BenchmarkViewModel @Inject constructor(
     private val benchmarkService: BenchmarkService
 ) : ViewModel() {
 
-    val engineState: StateFlow<InferenceEngine.State> = benchmarkService.engineState
-    val benchmarkResults: StateFlow<String?> = benchmarkService.results
+    val engineState: StateFlow<State> = benchmarkService.engineState
+    val benchmarkResults: StateFlow<String?> = benchmarkService.benchmarkResults
     val selectedModel: StateFlow<ModelInfo?> = benchmarkService.currentSelectedModel
 
     /**
diff --git a/examples/llama.android/app/src/main/java/com/example/llama/revamp/viewmodel/MainViewModel.kt b/examples/llama.android/app/src/main/java/com/example/llama/revamp/viewmodel/MainViewModel.kt
index 5ef1608a5a..6a70e1f9ec 100644
--- a/examples/llama.android/app/src/main/java/com/example/llama/revamp/viewmodel/MainViewModel.kt
+++ b/examples/llama.android/app/src/main/java/com/example/llama/revamp/viewmodel/MainViewModel.kt
@@ -1,7 +1,6 @@
 package com.example.llama.revamp.viewmodel
 
 import androidx.lifecycle.ViewModel
-import com.example.llama.revamp.engine.InferenceEngine
 import com.example.llama.revamp.engine.InferenceService
 import dagger.hilt.android.lifecycle.HiltViewModel
 import javax.inject.Inject
diff --git a/examples/llama.android/llama/src/main/java/android/llama/cpp/InferenceEngine.kt b/examples/llama.android/llama/src/main/java/android/llama/cpp/InferenceEngine.kt
new file mode 100644
index 0000000000..043f0a0037
--- /dev/null
+++ b/examples/llama.android/llama/src/main/java/android/llama/cpp/InferenceEngine.kt
@@ -0,0 +1,64 @@
+package android.llama.cpp
+
+import kotlinx.coroutines.flow.Flow
+import kotlinx.coroutines.flow.StateFlow
+
+/**
+ * Interface defining the core LLM inference operations.
+ */
+interface InferenceEngine {
+    /**
+     * Current state of the inference engine
+     */
+    val state: StateFlow<State>
+
+    /**
+     * Load a model from the given path with an optional system prompt.
+     */
+    suspend fun loadModel(pathToModel: String, systemPrompt: String? = null)
+
+    /**
+     * Sends a user prompt to the loaded model and returns a Flow of generated tokens.
+     */
+    fun sendUserPrompt(message: String, predictLength: Int = DEFAULT_PREDICT_LENGTH): Flow<String>
+
+    /**
+     * Runs a benchmark with the specified parameters.
+     */
+    suspend fun bench(pp: Int, tg: Int, pl: Int, nr: Int = 1): String
+
+    /**
+     * Unloads the currently loaded model.
+     */
+    suspend fun unloadModel()
+
+    /**
+     * Cleans up resources when the engine is no longer needed.
+     */
+    fun destroy()
+
+    /**
+     * States of the inference engine
+     */
+    sealed class State {
+        object Uninitialized : State()
+        object LibraryLoaded : State()
+
+        object LoadingModel : State()
+        object ModelLoaded : State()
+
+        object ProcessingSystemPrompt : State()
+        object AwaitingUserPrompt : State()
+
+        object ProcessingUserPrompt : State()
+        object Generating : State()
+
+        object Benchmarking : State()
+
+        data class Error(val errorMessage: String = "") : State()
+    }
+
+    companion object {
+        const val DEFAULT_PREDICT_LENGTH = 1024
+    }
+}
diff --git a/examples/llama.android/llama/src/main/java/android/llama/cpp/LLamaAndroid.kt b/examples/llama.android/llama/src/main/java/android/llama/cpp/LLamaAndroid.kt
index 7f52ccf68a..9110e264e9 100644
--- a/examples/llama.android/llama/src/main/java/android/llama/cpp/LLamaAndroid.kt
+++ b/examples/llama.android/llama/src/main/java/android/llama/cpp/LLamaAndroid.kt
@@ -1,5 +1,7 @@
 package android.llama.cpp
 
+import android.llama.cpp.InferenceEngine.State
+import android.llama.cpp.LLamaAndroid.Companion.instance
 import android.util.Log
 import kotlinx.coroutines.CoroutineScope
 import kotlinx.coroutines.Dispatchers
@@ -38,7 +40,7 @@ annotation class RequiresCleanup(val message: String = "Remember to call this me
 *
 * @see llama-android.cpp for the native implementation details
 */
-class LLamaAndroid private constructor() {
+class LLamaAndroid private constructor() : InferenceEngine {
     /**
      * JNI methods
      * @see llama-android.cpp
@@ -57,32 +59,8 @@
     private external fun unload()
     private external fun shutdown()
 
-
-    /**
-     * Fine-grained state management
-     */
-    sealed class State {
-        object Uninitialized : State()
-        object LibraryLoaded : State()
-
-        object LoadingModel : State()
-        object ModelLoaded : State()
-
-        object ProcessingSystemPrompt : State()
-        object AwaitingUserPrompt : State()
-
-        object ProcessingUserPrompt : State()
-        object Generating : State()
-
-        object Benchmarking : State()
-
-        data class Error(
-            val errorMessage: String = ""
-        ) : State()
-    }
-
     private val _state = MutableStateFlow<State>(State.Uninitialized)
-    val state: StateFlow<State> = _state
+    override val state: StateFlow<State> = _state
 
     /**
      * Single-threaded coroutine dispatcher & scope for LLama asynchronous operations
@@ -108,7 +86,7 @@
     /**
      * Load the LLM, then process the plain text system prompt if provided
     */
-    suspend fun loadModel(pathToModel: String, systemPrompt: String? = null) =
+    override suspend fun loadModel(pathToModel: String, systemPrompt: String?) =
        withContext(llamaDispatcher) {
            check(_state.value is State.LibraryLoaded) { "Cannot load model in ${_state.value}!" }
            File(pathToModel).let {
@@ -147,9 +125,9 @@
     /**
      * Send plain text user prompt to LLM, which starts generating tokens in a [Flow]
     */
-    fun sendUserPrompt(
+    override fun sendUserPrompt(
         message: String,
-        predictLength: Int = DEFAULT_PREDICT_LENGTH,
+        predictLength: Int,
     ): Flow<String> = flow {
         require(message.isNotEmpty()) { "User prompt discarded due to being empty!" }
         check(_state.value is State.AwaitingUserPrompt) {
@@ -179,7 +157,7 @@
     /**
      * Benchmark the model
     */
-    suspend fun bench(pp: Int, tg: Int, pl: Int, nr: Int = 1): String =
+    override suspend fun bench(pp: Int, tg: Int, pl: Int, nr: Int): String =
        withContext(llamaDispatcher) {
            check(_state.value is State.AwaitingUserPrompt) {
                "Benchmark request discarded due to: $state"
@@ -194,7 +172,7 @@
     /**
      * Unloads the model and frees resources
     */
-    suspend fun unloadModel() =
+    override suspend fun unloadModel() =
        withContext(llamaDispatcher) {
            when(_state.value) {
                is State.AwaitingUserPrompt, is State.Error -> {
@@ -202,6 +180,7 @@
                    unload()
                    _state.value = State.LibraryLoaded
                    Log.i(TAG, "Model unloaded!")
+                   Unit
                }
                else -> throw IllegalStateException("Cannot unload model in ${_state.value}")
            }
@@ -211,7 +190,7 @@
     * Cancel all ongoing coroutines and free GGML backends
     */
    @RequiresCleanup("Call from `ViewModel.onCleared()` to prevent resource leaks!")
-    fun destroy() {
+    override fun destroy() {
        llamaScope.cancel()
        when(_state.value) {
            is State.Uninitialized -> {}
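
A minimal usage sketch of the InferenceEngine contract this patch introduces (illustrative only: the model path below is a placeholder, and runBlocking stands in for a real coroutine scope such as viewModelScope). Either implementation can sit behind the same variable, which is what the useRealEngine toggle in AppModule relies on:

    import android.llama.cpp.InferenceEngine
    import android.llama.cpp.InferenceEngine.State
    import com.example.llama.revamp.engine.StubInferenceEngine
    import kotlinx.coroutines.runBlocking

    fun main() = runBlocking {
        // Swap in LLamaAndroid.instance() to exercise the real JNI-backed engine.
        val engine: InferenceEngine = StubInferenceEngine()

        // Suspends until the engine settles in AwaitingUserPrompt (or Error).
        engine.loadModel("/path/to/model.gguf", systemPrompt = "You are concise.")

        if (engine.state.value is State.AwaitingUserPrompt) {
            // Tokens arrive one by one through the Flow<String>; the interface's
            // default predictLength applies when the argument is omitted.
            engine.sendUserPrompt("Hello!").collect { token -> print(token) }
        }

        engine.unloadModel()
        engine.destroy()
    }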