core: swap in LLamaAndroid and mark stub engine for testing only
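Extract the engine API into an InferenceEngine interface under android.llama.cpp (state flow, loadModel, sendUserPrompt, bench, unloadModel, destroy, plus the State machine and DEFAULT_PREDICT_LENGTH). LLamaAndroid now implements it, and the old simulated engine in com.example.llama.revamp.engine becomes StubInferenceEngine, annotated @TestOnly/@VisibleForTesting. AppModule provides LLamaAndroid.instance() by default and keeps the stub for agile development and testing; benchmark-result caching moves out of the engine into InferenceServiceImpl.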
parent c2426a42e5
commit d70b8fe323
@@ -1,5 +1,6 @@
 package com.example.llama.revamp
 
+import android.llama.cpp.InferenceEngine.State
 import android.os.Bundle
 import androidx.activity.ComponentActivity
 import androidx.activity.OnBackPressedCallback
@@ -27,7 +28,6 @@ import androidx.lifecycle.compose.LocalLifecycleOwner
 import androidx.navigation.compose.composable
 import androidx.navigation.compose.currentBackStackEntryAsState
 import androidx.navigation.compose.rememberNavController
-import com.example.llama.revamp.engine.InferenceEngine
 import com.example.llama.revamp.navigation.AppDestinations
 import com.example.llama.revamp.navigation.NavigationActions
 import com.example.llama.revamp.ui.components.AnimatedNavHost
@@ -35,9 +35,9 @@ import com.example.llama.revamp.ui.components.AppNavigationDrawer
 import com.example.llama.revamp.ui.components.UnloadModelConfirmationDialog
 import com.example.llama.revamp.ui.screens.BenchmarkScreen
 import com.example.llama.revamp.ui.screens.ConversationScreen
+import com.example.llama.revamp.ui.screens.ModelLoadingScreen
 import com.example.llama.revamp.ui.screens.ModelSelectionScreen
 import com.example.llama.revamp.ui.screens.ModelsManagementScreen
-import com.example.llama.revamp.ui.screens.ModelLoadingScreen
 import com.example.llama.revamp.ui.screens.SettingsGeneralScreen
 import com.example.llama.revamp.ui.theme.LlamaTheme
 import com.example.llama.revamp.viewmodel.MainViewModel
@@ -72,10 +72,10 @@ fun AppContent(
 
     // LLM Inference engine status
     val engineState by mainVewModel.engineState.collectAsState()
-    val isModelLoading = engineState is InferenceEngine.State.LoadingModel
-            || engineState is InferenceEngine.State.ProcessingSystemPrompt
-    val isModelLoaded = engineState !is InferenceEngine.State.Uninitialized
-            && engineState !is InferenceEngine.State.LibraryLoaded
+    val isModelLoading = engineState is State.LoadingModel
+            || engineState is State.ProcessingSystemPrompt
+    val isModelLoaded = engineState !is State.Uninitialized
+            && engineState !is State.LibraryLoaded
 
     // Navigation
     val navController = rememberNavController()

@@ -1,6 +1,8 @@
 package com.example.llama.revamp.di
 
 import android.content.Context
+import android.llama.cpp.InferenceEngine
+import android.llama.cpp.LLamaAndroid
 import com.example.llama.revamp.data.local.AppDatabase
 import com.example.llama.revamp.data.repository.ModelRepository
 import com.example.llama.revamp.data.repository.ModelRepositoryImpl
@@ -8,7 +10,7 @@ import com.example.llama.revamp.data.repository.SystemPromptRepository
 import com.example.llama.revamp.data.repository.SystemPromptRepositoryImpl
 import com.example.llama.revamp.engine.BenchmarkService
 import com.example.llama.revamp.engine.ConversationService
-import com.example.llama.revamp.engine.InferenceEngine
+import com.example.llama.revamp.engine.StubInferenceEngine
 import com.example.llama.revamp.engine.InferenceService
 import com.example.llama.revamp.engine.InferenceServiceImpl
 import com.example.llama.revamp.engine.ModelLoadingService
@@ -46,10 +48,12 @@ internal abstract class AppModule {
     companion object {
         @Provides
        @Singleton
-        fun provideInferenceEngine() = InferenceEngine()
+        fun provideInferenceEngine(): InferenceEngine {
+            val useRealEngine = true
+            return if (useRealEngine) LLamaAndroid.instance() else StubInferenceEngine()
+        }
 
         @Provides
         @Singleton
         fun providePerformanceMonitor(@ApplicationContext context: Context) = PerformanceMonitor(context)
 
         @Provides

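Note: the provider above hard-codes `val useRealEngine = true`. A possible refinement (a sketch only, not part of this commit; `BuildConfig.USE_STUB_ENGINE` is a hypothetical flag that would be declared as a buildConfigField in the module's Gradle script) is to let the build type pick the engine:

    @Provides
    @Singleton
    fun provideInferenceEngine(): InferenceEngine =
        if (BuildConfig.USE_STUB_ENGINE) StubInferenceEngine() else LLamaAndroid.instance()
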
@@ -1,5 +1,7 @@
 package com.example.llama.revamp.engine
 
+import android.llama.cpp.InferenceEngine
+import android.llama.cpp.InferenceEngine.State
 import android.util.Log
 import com.example.llama.revamp.data.model.ModelInfo
 import kotlinx.coroutines.flow.Flow
@@ -14,7 +16,7 @@ interface InferenceService {
     /**
      * Expose engine state
      */
-    val engineState: StateFlow<InferenceEngine.State>
+    val engineState: StateFlow<State>
 
     /**
      * Currently selected model
@@ -51,14 +53,14 @@ interface BenchmarkService : InferenceService {
      * @param pp: Prompt Processing size
      * @param tg: Token Generation size
      * @param pl: Parallel sequences
-     * @param nr: repetitions (Number of Runs)
+     * @param nr: Number of Runs, i.e. repetitions
      */
     suspend fun benchmark(pp: Int, tg: Int, pl: Int, nr: Int): String
 
     /**
      * Benchmark results
      */
-    val results: StateFlow<String?>
+    val benchmarkResults: StateFlow<String?>
 }
 
 interface ConversationService : InferenceService {
@@ -108,7 +110,7 @@ internal class InferenceServiceImpl @Inject internal constructor(
 
     /* InferenceService implementation */
 
-    override val engineState: StateFlow<InferenceEngine.State> = inferenceEngine.state
+    override val engineState: StateFlow<State> = inferenceEngine.state
 
     private val _currentModel = MutableStateFlow<ModelInfo?>(null)
     override val currentSelectedModel: StateFlow<ModelInfo?> = _currentModel.asStateFlow()
@@ -156,9 +158,15 @@ internal class InferenceServiceImpl @Inject internal constructor(
     /* BenchmarkService implementation */
 
     override suspend fun benchmark(pp: Int, tg: Int, pl: Int, nr: Int): String =
-        inferenceEngine.bench(pp, tg, pl, nr)
+        inferenceEngine.bench(pp, tg, pl, nr).also {
+            _benchmarkResults.value = it
+        }
 
-    override val results: StateFlow<String?> = inferenceEngine.benchmarkResults
+    /**
+     * Benchmark results if available
+     */
+    private val _benchmarkResults = MutableStateFlow<String?>(null)
+    override val benchmarkResults: StateFlow<String?> = _benchmarkResults
 
 
     /* ConversationService implementation */

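With this change the engine no longer caches benchmark output: InferenceServiceImpl stores the last result in its own MutableStateFlow, fed from bench's return value via `.also { ... }`, so both LLamaAndroid and the stub stay stateless with respect to results.
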
@@ -1,5 +1,7 @@
 package com.example.llama.revamp.engine
 
+import android.llama.cpp.InferenceEngine
+import android.llama.cpp.InferenceEngine.State
 import android.util.Log
 import kotlinx.coroutines.CancellationException
 import kotlinx.coroutines.delay
@@ -8,44 +10,29 @@ import kotlinx.coroutines.flow.MutableStateFlow
 import kotlinx.coroutines.flow.StateFlow
 import kotlinx.coroutines.flow.catch
 import kotlinx.coroutines.flow.flow
+import org.jetbrains.annotations.TestOnly
+import org.jetbrains.annotations.VisibleForTesting
+import javax.inject.Singleton
 
 /**
- * LLM inference engine that handles model loading and text generation.
+ * A stub [InferenceEngine] for agile development & testing
  */
-class InferenceEngine {
+@VisibleForTesting
+@TestOnly
+@Singleton
+class StubInferenceEngine : InferenceEngine {
     companion object {
-        private val TAG = InferenceEngine::class.java.simpleName
+        private val TAG = StubInferenceEngine::class.java.simpleName
 
-        private const val DEFAULT_PREDICT_LENGTH = 1024
-    }
-
-    sealed class State {
-        object Uninitialized : State()
-        object LibraryLoaded : State()
-
-        object LoadingModel : State()
-        object ModelLoaded : State()
-
-        object ProcessingSystemPrompt : State()
-        object AwaitingUserPrompt : State()
-
-        object ProcessingUserPrompt : State()
-        object Generating : State()
-
-        object Benchmarking : State()
-
-        data class Error(
-            val errorMessage: String = ""
-        ) : State()
+        private const val STUB_MODEL_LOADING_TIME = 2000L
+        private const val STUB_BENCHMARKING_TIME = 4000L
+        private const val STUB_SYSTEM_PROMPT_PROCESSING_TIME = 3000L
+        private const val STUB_USER_PROMPT_PROCESSING_TIME = 1500L
+        private const val STUB_TOKEN_GENERATION_TIME = 200L
     }
 
     private val _state = MutableStateFlow<State>(State.Uninitialized)
-    val state: StateFlow<State> = _state
-
-    // Keep track of current benchmark results
-    private var _benchmarkResults: String? = null
-    private val _benchmarkResultsFlow = MutableStateFlow<String?>(null)
-    val benchmarkResults: StateFlow<String?> = _benchmarkResultsFlow
+    override val state: StateFlow<State> = _state
 
     init {
         Log.i(TAG, "Initiated!")
@@ -57,14 +44,14 @@ class InferenceEngine {
     /**
      * Loads a model from the given path with an optional system prompt.
      */
-    suspend fun loadModel(pathToModel: String, systemPrompt: String? = null) {
+    override suspend fun loadModel(pathToModel: String, systemPrompt: String?) {
         Log.i(TAG, "loadModel! state: ${_state.value}")
 
         try {
             _state.value = State.LoadingModel
 
             // Simulate model loading
-            delay(2000)
+            delay(STUB_MODEL_LOADING_TIME)
 
             _state.value = State.ModelLoaded
 
@@ -72,7 +59,7 @@ class InferenceEngine {
                 _state.value = State.ProcessingSystemPrompt
 
                 // Simulate processing system prompt
-                delay(3000)
+                delay(STUB_SYSTEM_PROMPT_PROCESSING_TIME)
             }
 
             _state.value = State.AwaitingUserPrompt
@@ -87,7 +74,7 @@ class InferenceEngine {
     /**
      * Sends a user prompt to the loaded model and returns a Flow of generated tokens.
      */
-    fun sendUserPrompt(message: String, predictLength: Int = DEFAULT_PREDICT_LENGTH): Flow<String> {
+    override fun sendUserPrompt(message: String, predictLength: Int): Flow<String> {
         Log.i(TAG, "sendUserPrompt! state: ${_state.value}")
 
         _state.value = State.ProcessingUserPrompt
@@ -96,18 +83,15 @@ class InferenceEngine {
         return flow {
             try {
                 // Simulate longer processing time (1.5 seconds)
-                delay(1500)
+                delay(STUB_USER_PROMPT_PROCESSING_TIME)
 
                 _state.value = State.Generating
 
                 // Simulate token generation
                 val response = "This is a simulated response from the LLM model. The actual implementation would generate tokens one by one based on the input: $message"
-                val words = response.split(" ")
-
-                for (word in words) {
-                    emit(word + " ")
-                    // Slower token generation (200ms per token instead of 50ms)
-                    delay(200)
+                response.split(" ").forEach {
+                    emit("$it ")
+                    delay(STUB_TOKEN_GENERATION_TIME)
                 }
 
                 _state.value = State.AwaitingUserPrompt
@@ -131,26 +115,26 @@ class InferenceEngine {
     /**
      * Runs a benchmark with the specified parameters.
      */
-    suspend fun bench(pp: Int, tg: Int, pl: Int, nr: Int = 1): String {
+    override suspend fun bench(pp: Int, tg: Int, pl: Int, nr: Int): String {
         Log.i(TAG, "bench! state: ${_state.value}")
 
         _state.value = State.Benchmarking
 
         try {
             // Simulate benchmark running
-            delay(4000)
+            delay(STUB_BENCHMARKING_TIME)
 
             // Generate fake benchmark results
-            val modelDesc = "LlamaModel"
+            val modelDesc = "Kleidi Llama"
             val model_size = "7"
             val model_n_params = "7"
             val backend = "CPU"
 
             // Random values for benchmarks
-            val pp_avg = (15.0 + Math.random() * 10.0).toFloat()
-            val pp_std = (0.5 + Math.random() * 2.0).toFloat()
-            val tg_avg = (20.0 + Math.random() * 15.0).toFloat()
-            val tg_std = (0.7 + Math.random() * 3.0).toFloat()
+            val pp_avg = (51.4 + Math.random() * 5.14).toFloat()
+            val pp_std = (5.14 + Math.random() * 0.514).toFloat()
+            val tg_avg = (11.4 + Math.random() * 1.14).toFloat()
+            val tg_std = (1.14 + Math.random() * 0.114).toFloat()
 
             val result = StringBuilder()
             result.append("| model | size | params | backend | test | t/s |\n")
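For reference, the stub emits a Markdown-style table. With the new value ranges a run might render roughly as below (figures are random per run and purely illustrative; "pp 512" / "tg 128" stand in for whatever arguments the caller passed, and lines elided between these hunks may append a separator row):

    | model | size | params | backend | test | t/s |
    | Kleidi Llama | 7GiB | 7B | CPU | pp 512 | 53.2 ± 5.3 |
    | Kleidi Llama | 7GiB | 7B | CPU | tg 128 | 11.9 ± 1.2 |
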
@@ -160,12 +144,9 @@ class InferenceEngine {
             result.append("| $modelDesc | ${model_size}GiB | ${model_n_params}B | ")
             result.append("$backend | tg $tg | $tg_avg ± $tg_std |\n")
 
-            _benchmarkResults = result.toString()
-            _benchmarkResultsFlow.value = _benchmarkResults
-
             _state.value = State.AwaitingUserPrompt
 
-            return _benchmarkResults ?: ""
+            return result.toString()
         } catch (e: CancellationException) {
             // If coroutine is cancelled, propagate cancellation
             _state.value = State.AwaitingUserPrompt
@@ -179,20 +160,18 @@ class InferenceEngine {
     /**
      * Unloads the currently loaded model.
      */
-    suspend fun unloadModel() {
+    override suspend fun unloadModel() {
         Log.i(TAG, "unloadModel! state: ${_state.value}")
 
         // Simulate model unloading time
         delay(2000)
         _state.value = State.LibraryLoaded
-        _benchmarkResults = null
-        _benchmarkResultsFlow.value = null
     }
 
     /**
      * Cleans up resources when the engine is no longer needed.
      */
-    fun destroy() {
+    override fun destroy() {
         Log.i(TAG, "destroy! state: ${_state.value}")
 
         _state.value = State.Uninitialized

@@ -5,18 +5,21 @@ import android.content.Context
 import android.content.Intent
 import android.content.IntentFilter
 import android.os.BatteryManager
+import dagger.hilt.android.qualifiers.ApplicationContext
 import kotlinx.coroutines.Dispatchers
 import kotlinx.coroutines.delay
 import kotlinx.coroutines.flow.Flow
 import kotlinx.coroutines.flow.flow
 import kotlinx.coroutines.withContext
+import javax.inject.Singleton
 import kotlin.math.roundToInt
 
 /**
  * Service that monitors device performance metrics such as memory usage,
  * battery level, and temperature.
  */
-class PerformanceMonitor(private val context: Context) {
+@Singleton
+class PerformanceMonitor(@ApplicationContext private val context: Context) {
 
     /**
      * Provides a flow of memory usage information that updates at the specified interval.

@@ -1,5 +1,6 @@
 package com.example.llama.revamp.ui.screens
 
+import android.llama.cpp.InferenceEngine.State
 import androidx.compose.foundation.background
 import androidx.compose.foundation.layout.Box
 import androidx.compose.foundation.layout.Column
@@ -23,12 +24,10 @@ import androidx.compose.ui.Alignment
 import androidx.compose.ui.Modifier
 import androidx.compose.ui.unit.dp
 import androidx.hilt.navigation.compose.hiltViewModel
-import com.example.llama.revamp.engine.InferenceEngine
 import com.example.llama.revamp.ui.components.ModelCard
 import com.example.llama.revamp.ui.components.PerformanceAppScaffold
 import com.example.llama.revamp.ui.theme.MonospacedTextStyle
 import com.example.llama.revamp.viewmodel.BenchmarkViewModel
-import com.example.llama.revamp.viewmodel.MainViewModel
 
 @Composable
 fun BenchmarkScreen(
@@ -67,7 +66,7 @@ fun BenchmarkScreen(
 
         // Benchmark results or loading indicator
         when {
-            engineState is InferenceEngine.State.Benchmarking -> {
+            engineState is State.Benchmarking -> {
                 Box(
                     modifier = Modifier
                         .fillMaxWidth()

@@ -1,16 +1,12 @@
 package com.example.llama.revamp.ui.screens
 
-import androidx.compose.animation.AnimatedVisibility
+import android.llama.cpp.InferenceEngine.State
 import androidx.compose.animation.core.LinearEasing
 import androidx.compose.animation.core.RepeatMode
 import androidx.compose.animation.core.animateFloat
 import androidx.compose.animation.core.infiniteRepeatable
 import androidx.compose.animation.core.rememberInfiniteTransition
 import androidx.compose.animation.core.tween
-import androidx.compose.animation.expandVertically
-import androidx.compose.animation.fadeIn
-import androidx.compose.animation.fadeOut
-import androidx.compose.animation.shrinkVertically
 import androidx.compose.foundation.background
 import androidx.compose.foundation.layout.Box
 import androidx.compose.foundation.layout.Column
@@ -33,7 +29,6 @@ import androidx.compose.material.icons.filled.Send
 import androidx.compose.material3.Card
 import androidx.compose.material3.CardDefaults
-import androidx.compose.material3.CircularProgressIndicator
 import androidx.compose.material3.HorizontalDivider
 import androidx.compose.material3.Icon
 import androidx.compose.material3.IconButton
 import androidx.compose.material3.MaterialTheme
@@ -61,8 +56,6 @@ import androidx.hilt.navigation.compose.hiltViewModel
 import androidx.lifecycle.Lifecycle
 import androidx.lifecycle.LifecycleEventObserver
 import androidx.lifecycle.compose.LocalLifecycleOwner
-import com.example.llama.revamp.data.model.ModelInfo
-import com.example.llama.revamp.engine.InferenceEngine
 import com.example.llama.revamp.ui.components.ModelCardWithSystemPrompt
 import com.example.llama.revamp.ui.components.PerformanceAppScaffold
 import com.example.llama.revamp.viewmodel.ConversationViewModel
@@ -82,8 +75,8 @@ fun ConversationScreen(
     val systemPrompt by viewModel.systemPrompt.collectAsState()
     val selectedModel by viewModel.selectedModel.collectAsState()
 
-    val isProcessing = engineState is InferenceEngine.State.ProcessingUserPrompt
-    val isGenerating = engineState is InferenceEngine.State.Generating
+    val isProcessing = engineState is State.ProcessingUserPrompt
+    val isGenerating = engineState is State.Generating
 
     val listState = rememberLazyListState()
     var inputText by remember { mutableStateOf("") }

@@ -1,5 +1,6 @@
 package com.example.llama.revamp.ui.screens
 
+import android.llama.cpp.InferenceEngine.State
 import androidx.compose.animation.AnimatedVisibility
 import androidx.compose.animation.expandVertically
 import androidx.compose.animation.fadeIn
@@ -49,7 +50,6 @@ import androidx.compose.ui.text.style.TextOverflow
 import androidx.compose.ui.unit.dp
 import androidx.hilt.navigation.compose.hiltViewModel
 import com.example.llama.revamp.data.model.SystemPrompt
-import com.example.llama.revamp.engine.InferenceEngine
 import com.example.llama.revamp.ui.components.ModelCard
 import com.example.llama.revamp.ui.components.PerformanceAppScaffold
 import com.example.llama.revamp.viewmodel.ModelLoadingViewModel
@@ -63,7 +63,7 @@ enum class SystemPromptTab {
 @OptIn(ExperimentalMaterial3Api::class, ExperimentalFoundationApi::class)
 @Composable
 fun ModelLoadingScreen(
-    engineState: InferenceEngine.State,
+    engineState: State,
     onBenchmarkSelected: (prepareJob: Job) -> Unit,
     onConversationSelected: (systemPrompt: String?, prepareJob: Job) -> Unit,
     onBackPressed: () -> Unit,
@@ -99,9 +99,9 @@ fun ModelLoadingScreen(
     }
 
     // Check if we're in a loading state
-    val isLoading = engineState !is InferenceEngine.State.Uninitialized &&
-            engineState !is InferenceEngine.State.LibraryLoaded &&
-            engineState !is InferenceEngine.State.AwaitingUserPrompt
+    val isLoading = engineState !is State.Uninitialized &&
+            engineState !is State.LibraryLoaded &&
+            engineState !is State.AwaitingUserPrompt
 
     // Mode selection callbacks
     val handleBenchmarkSelected = {
@@ -429,9 +429,9 @@ fun ModelLoadingScreen(
             Spacer(modifier = Modifier.width(8.dp))
             Text(
                 text = when (engineState) {
-                    is InferenceEngine.State.LoadingModel -> "Loading model..."
-                    is InferenceEngine.State.ProcessingSystemPrompt -> "Processing system prompt..."
-                    is InferenceEngine.State.ModelLoaded -> "Preparing conversation..."
+                    is State.LoadingModel -> "Loading model..."
+                    is State.ProcessingSystemPrompt -> "Processing system prompt..."
+                    is State.ModelLoaded -> "Preparing conversation..."
                     else -> "Processing..."
                 },
                 style = MaterialTheme.typography.titleMedium

@@ -1,10 +1,10 @@
 package com.example.llama.revamp.viewmodel
 
+import android.llama.cpp.InferenceEngine.State
 import androidx.lifecycle.ViewModel
 import androidx.lifecycle.viewModelScope
 import com.example.llama.revamp.data.model.ModelInfo
 import com.example.llama.revamp.engine.BenchmarkService
-import com.example.llama.revamp.engine.InferenceEngine
 import dagger.hilt.android.lifecycle.HiltViewModel
 import kotlinx.coroutines.flow.StateFlow
 import kotlinx.coroutines.launch
@@ -15,8 +15,8 @@ class BenchmarkViewModel @Inject constructor(
     private val benchmarkService: BenchmarkService
 ) : ViewModel() {
 
-    val engineState: StateFlow<InferenceEngine.State> = benchmarkService.engineState
-    val benchmarkResults: StateFlow<String?> = benchmarkService.results
+    val engineState: StateFlow<State> = benchmarkService.engineState
+    val benchmarkResults: StateFlow<String?> = benchmarkService.benchmarkResults
     val selectedModel: StateFlow<ModelInfo?> = benchmarkService.currentSelectedModel
 
     /**

@@ -1,7 +1,6 @@
 package com.example.llama.revamp.viewmodel
 
 import androidx.lifecycle.ViewModel
-import com.example.llama.revamp.engine.InferenceEngine
 import com.example.llama.revamp.engine.InferenceService
 import dagger.hilt.android.lifecycle.HiltViewModel
 import javax.inject.Inject

@@ -0,0 +1,64 @@
+package android.llama.cpp
+
+import kotlinx.coroutines.flow.Flow
+import kotlinx.coroutines.flow.StateFlow
+
+/**
+ * Interface defining the core LLM inference operations.
+ */
+interface InferenceEngine {
+    /**
+     * Current state of the inference engine
+     */
+    val state: StateFlow<State>
+
+    /**
+     * Load a model from the given path with an optional system prompt.
+     */
+    suspend fun loadModel(pathToModel: String, systemPrompt: String? = null)
+
+    /**
+     * Sends a user prompt to the loaded model and returns a Flow of generated tokens.
+     */
+    fun sendUserPrompt(message: String, predictLength: Int = DEFAULT_PREDICT_LENGTH): Flow<String>
+
+    /**
+     * Runs a benchmark with the specified parameters.
+     */
+    suspend fun bench(pp: Int, tg: Int, pl: Int, nr: Int = 1): String
+
+    /**
+     * Unloads the currently loaded model.
+     */
+    suspend fun unloadModel()
+
+    /**
+     * Cleans up resources when the engine is no longer needed.
+     */
+    fun destroy()
+
+    /**
+     * States of the inference engine
+     */
+    sealed class State {
+        object Uninitialized : State()
+        object LibraryLoaded : State()
+
+        object LoadingModel : State()
+        object ModelLoaded : State()
+
+        object ProcessingSystemPrompt : State()
+        object AwaitingUserPrompt : State()
+
+        object ProcessingUserPrompt : State()
+        object Generating : State()
+
+        object Benchmarking : State()
+
+        data class Error(val errorMessage: String = "") : State()
+    }
+
+    companion object {
+        const val DEFAULT_PREDICT_LENGTH = 1024
+    }
+}

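Callers now program against this interface rather than a concrete engine. A minimal usage sketch (the model path and prompts are illustrative, not from this commit):

    suspend fun demo(engine: InferenceEngine) {
        // Drives the state machine: LoadingModel -> ModelLoaded ->
        // ProcessingSystemPrompt -> AwaitingUserPrompt
        engine.loadModel("/data/local/tmp/model.gguf", systemPrompt = "You are concise.")

        // Collecting the flow moves through ProcessingUserPrompt and Generating,
        // then back to AwaitingUserPrompt
        engine.sendUserPrompt("Hello!").collect { token -> print(token) }

        engine.unloadModel()   // back to LibraryLoaded
        engine.destroy()       // back to Uninitialized
    }
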
@@ -1,5 +1,7 @@
 package android.llama.cpp
 
+import android.llama.cpp.InferenceEngine.State
+import android.llama.cpp.LLamaAndroid.Companion.instance
 import android.util.Log
 import kotlinx.coroutines.CoroutineScope
 import kotlinx.coroutines.Dispatchers
@@ -38,7 +40,7 @@ annotation class RequiresCleanup(val message: String = "Remember to call this me
  *
  * @see llama-android.cpp for the native implementation details
  */
-class LLamaAndroid private constructor() {
+class LLamaAndroid private constructor() : InferenceEngine {
     /**
      * JNI methods
      * @see llama-android.cpp
@@ -57,32 +59,8 @@ class LLamaAndroid private constructor() {
     private external fun unload()
     private external fun shutdown()
 
-
-    /**
-     * Fine-grained state management
-     */
-    sealed class State {
-        object Uninitialized : State()
-        object LibraryLoaded : State()
-
-        object LoadingModel : State()
-        object ModelLoaded : State()
-
-        object ProcessingSystemPrompt : State()
-        object AwaitingUserPrompt : State()
-
-        object ProcessingUserPrompt : State()
-        object Generating : State()
-
-        object Benchmarking : State()
-
-        data class Error(
-            val errorMessage: String = ""
-        ) : State()
-    }
-
     private val _state = MutableStateFlow<State>(State.Uninitialized)
-    val state: StateFlow<State> = _state
+    override val state: StateFlow<State> = _state
 
     /**
      * Single-threaded coroutine dispatcher & scope for LLama asynchronous operations
@@ -108,7 +86,7 @@ class LLamaAndroid private constructor() {
     /**
     * Load the LLM, then process the plain text system prompt if provided
     */
-    suspend fun loadModel(pathToModel: String, systemPrompt: String? = null) =
+    override suspend fun loadModel(pathToModel: String, systemPrompt: String?) =
         withContext(llamaDispatcher) {
             check(_state.value is State.LibraryLoaded) { "Cannot load model in ${_state.value}!" }
             File(pathToModel).let {
@@ -147,9 +125,9 @@ class LLamaAndroid private constructor() {
     /**
     * Send plain text user prompt to LLM, which starts generating tokens in a [Flow]
     */
-    fun sendUserPrompt(
+    override fun sendUserPrompt(
         message: String,
-        predictLength: Int = DEFAULT_PREDICT_LENGTH,
+        predictLength: Int,
     ): Flow<String> = flow {
         require(message.isNotEmpty()) { "User prompt discarded due to being empty!" }
         check(_state.value is State.AwaitingUserPrompt) {
@@ -179,7 +157,7 @@ class LLamaAndroid private constructor() {
     /**
     * Benchmark the model
     */
-    suspend fun bench(pp: Int, tg: Int, pl: Int, nr: Int = 1): String =
+    override suspend fun bench(pp: Int, tg: Int, pl: Int, nr: Int): String =
         withContext(llamaDispatcher) {
             check(_state.value is State.AwaitingUserPrompt) {
                 "Benchmark request discarded due to: $state"
@@ -194,7 +172,7 @@ class LLamaAndroid private constructor() {
     /**
     * Unloads the model and frees resources
     */
-    suspend fun unloadModel() =
+    override suspend fun unloadModel() =
         withContext(llamaDispatcher) {
             when(_state.value) {
                 is State.AwaitingUserPrompt, is State.Error -> {
@@ -202,6 +180,7 @@ class LLamaAndroid private constructor() {
                     unload()
                     _state.value = State.LibraryLoaded
                     Log.i(TAG, "Model unloaded!")
+                    Unit
                 }
                 else -> throw IllegalStateException("Cannot unload model in ${_state.value}")
             }
@@ -211,7 +190,7 @@ class LLamaAndroid private constructor() {
     * Cancel all ongoing coroutines and free GGML backends
     */
    @RequiresCleanup("Call from `ViewModel.onCleared()` to prevent resource leaks!")
-    fun destroy() {
+    override fun destroy() {
        llamaScope.cancel()
        when(_state.value) {
            is State.Uninitialized -> {}