LLM: stub a local inference engine for faster iteration

Han Yin 2025-04-11 14:36:26 -07:00
parent 3787fbddb0
commit 3f913ce440
1 changed file with 159 additions and 0 deletions


@@ -0,0 +1,159 @@
package com.example.llama.revamp.engine

import kotlin.random.Random
import kotlinx.coroutines.delay
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.flow.StateFlow
import kotlinx.coroutines.flow.flow

/**
 * Stubbed-out LLM inference engine that simulates model loading and text
 * generation, allowing faster iteration on the surrounding app code.
 */
class InferenceEngine {

    companion object {
        const val DEFAULT_PREDICT_LENGTH = 1024
    }

    sealed class State {
        object Uninitialized : State()
        object LibraryLoaded : State()
        object LoadingModel : State()
        object ModelLoaded : State()
        object ProcessingSystemPrompt : State()
        object AwaitingUserPrompt : State()
        object ProcessingUserPrompt : State()
        object Generating : State()
        object Benchmarking : State()
        data class Error(
            val errorMessage: String = ""
        ) : State()
    }

    private val _state = MutableStateFlow<State>(State.Uninitialized)
    val state: StateFlow<State> = _state

    // Keep track of the most recent benchmark results
    private val _benchmarkResults = MutableStateFlow<String?>(null)
    val benchmarkResults: StateFlow<String?> = _benchmarkResults

    init {
        // Simulate library loading
        _state.value = State.LibraryLoaded
    }

    /**
     * Loads a model from the given path with an optional system prompt.
     */
    suspend fun loadModel(pathToModel: String, systemPrompt: String? = null) {
        try {
            _state.value = State.LoadingModel
            // Simulate model loading
            delay(1000)
            _state.value = State.ModelLoaded

            if (systemPrompt != null) {
                _state.value = State.ProcessingSystemPrompt
                // Simulate processing the system prompt
                delay(500)
            }
            _state.value = State.AwaitingUserPrompt
        } catch (e: Exception) {
            _state.value = State.Error(e.message ?: "Unknown error during model loading")
        }
    }

    /**
     * Sends a user prompt to the loaded model and returns a cold [Flow] that
     * emits simulated tokens once collected.
     */
    fun sendUserPrompt(message: String, predictLength: Int = DEFAULT_PREDICT_LENGTH): Flow<String> {
        // predictLength is accepted for API parity but unused in this stub.
        // State flips to ProcessingUserPrompt immediately; Generating is set
        // once the returned flow is collected.
        _state.value = State.ProcessingUserPrompt

        // This would be replaced with actual token generation logic
        return flow {
            delay(500) // Simulate prompt processing time
            _state.value = State.Generating

            // Simulate token generation by streaming a canned response word by word
            val response =
                "This is a simulated response from the LLM model. The actual implementation would generate tokens one by one based on the input: $message"
            for (word in response.split(" ")) {
                emit("$word ")
                delay(50) // Simulate per-token generation delay
            }
            _state.value = State.AwaitingUserPrompt
        }
    }

    /**
     * Runs a simulated benchmark and returns the results as a markdown table
     * in the style of llama.cpp's llama-bench output.
     *
     * @param pp prompt-processing token count
     * @param tg text-generation token count
     * @param pl unused in this stub
     * @param nr number of repetitions (unused in this stub)
     */
    suspend fun bench(pp: Int, tg: Int, pl: Int, nr: Int = 1): String {
        _state.value = State.Benchmarking
        try {
            // Simulate the benchmark run
            delay(2000)

            // Generate fake benchmark results
            val modelDesc = "LlamaModel"
            val modelSize = "7"
            val modelNParams = "7"
            val backend = "CPU"

            // Random tokens-per-second figures for the fake results
            val ppAvg = (15.0 + Random.nextDouble() * 10.0).toFloat()
            val ppStd = (0.5 + Random.nextDouble() * 2.0).toFloat()
            val tgAvg = (20.0 + Random.nextDouble() * 15.0).toFloat()
            val tgStd = (0.7 + Random.nextDouble() * 3.0).toFloat()

            val result = StringBuilder()
            result.append("| model | size | params | backend | test | t/s |\n")
            result.append("| --- | --- | --- | --- | --- | --- |\n")
            result.append("| $modelDesc | ${modelSize}GiB | ${modelNParams}B | ")
            result.append("$backend | pp $pp | $ppAvg ± $ppStd |\n")
            result.append("| $modelDesc | ${modelSize}GiB | ${modelNParams}B | ")
            result.append("$backend | tg $tg | $tgAvg ± $tgStd |\n")

            _benchmarkResults.value = result.toString()
            _state.value = State.AwaitingUserPrompt
            return _benchmarkResults.value ?: ""
        } catch (e: Exception) {
            _state.value = State.Error(e.message ?: "Unknown error during benchmarking")
            return "Error: ${e.message}"
        }
    }

    /**
     * Unloads the currently loaded model and clears any benchmark results.
     */
    suspend fun unloadModel() {
        // Simulate model unloading time
        delay(300)
        _state.value = State.LibraryLoaded
        _benchmarkResults.value = null
    }

    /**
     * Cleans up resources when the engine is no longer needed.
     */
    fun destroy() {
        // In a real implementation, this would release native resources
    }
}
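
Usage note (not part of the diff): a minimal sketch of how a caller such as a ViewModel might drive this stub. The coroutine scope, model path, and prompts below are illustrative assumptions, not APIs introduced by this commit.

// Illustrative sketch: assumes some CoroutineScope `scope` (e.g. viewModelScope).
val engine = InferenceEngine()

// Observe state transitions, e.g. to drive UI.
scope.launch {
    engine.state.collect { println("Engine state: $it") }
}

scope.launch {
    engine.loadModel("/path/to/model.gguf", systemPrompt = "You are a helpful assistant.")
    engine.sendUserPrompt("Hello!").collect { token -> print(token) } // streams word by word
    println(engine.bench(pp = 512, tg = 128, pl = 1)) // markdown results table
}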