LLM: stub a local inference engine for faster iteration
parent 3787fbddb0
commit 3f913ce440
@@ -0,0 +1,159 @@
package com.example.llama.revamp.engine

import kotlinx.coroutines.delay
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.flow.StateFlow
import kotlinx.coroutines.flow.flow

/**
 * Stubbed LLM inference engine that simulates model loading and text
 * generation, so the rest of the app can iterate without a real local backend.
 */
class InferenceEngine {

    companion object {
        const val DEFAULT_PREDICT_LENGTH = 1024
    }

    sealed class State {
        object Uninitialized : State()
        object LibraryLoaded : State()

        object LoadingModel : State()
        object ModelLoaded : State()

        object ProcessingSystemPrompt : State()
        object AwaitingUserPrompt : State()

        object ProcessingUserPrompt : State()
        object Generating : State()

        object Benchmarking : State()

        data class Error(
            val errorMessage: String = ""
        ) : State()
    }
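
    // Typical lifecycle, as driven by the functions below (the stub does not
    // otherwise enforce these transitions):
    //   Uninitialized -> LibraryLoaded -> LoadingModel -> ModelLoaded
    //     -> [ProcessingSystemPrompt] -> AwaitingUserPrompt
    //     -> ProcessingUserPrompt -> Generating -> AwaitingUserPrompt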

    private val _state = MutableStateFlow<State>(State.Uninitialized)
    val state: StateFlow<State> = _state

    // Keep track of current benchmark results; null until a benchmark has run.
    private val _benchmarkResults = MutableStateFlow<String?>(null)
    val benchmarkResults: StateFlow<String?> = _benchmarkResults

    init {
        // Simulate library loading
        _state.value = State.LibraryLoaded
    }

    /**
     * Loads a model from the given path with an optional system prompt.
     */
    suspend fun loadModel(pathToModel: String, systemPrompt: String? = null) {
        try {
            _state.value = State.LoadingModel

            // Simulate model loading
            delay(1000)

            _state.value = State.ModelLoaded

            if (systemPrompt != null) {
                _state.value = State.ProcessingSystemPrompt

                // Simulate processing system prompt
                delay(500)
            }

            _state.value = State.AwaitingUserPrompt
        } catch (e: Exception) {
            _state.value = State.Error(e.message ?: "Unknown error during model loading")
        }
    }

    /**
     * Sends a user prompt to the loaded model and returns a Flow of generated
     * tokens. [predictLength] is accepted for API parity with a real engine
     * but is unused by this stub.
     */
    fun sendUserPrompt(message: String, predictLength: Int = DEFAULT_PREDICT_LENGTH): Flow<String> {
        _state.value = State.ProcessingUserPrompt

        // This would be replaced with actual token generation logic
        return flow {
            delay(500) // Simulate prompt processing time

            _state.value = State.Generating

            // Simulate token generation by streaming a canned response word by word
            val response =
                "This is a simulated response from the LLM model. The actual implementation would generate tokens one by one based on the input: $message"
            val words = response.split(" ")

            for (word in words) {
                emit("$word ")
                delay(50) // Simulate token generation delay
            }

            _state.value = State.AwaitingUserPrompt
        }
    }
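
    // Note: the returned flow is cold, so the Generating transition and token
    // emission only begin once a caller collects it, e.g. (hypothetical caller):
    //   engine.sendUserPrompt("Hello").collect { token -> print(token) }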

    /**
     * Runs a benchmark with the specified parameters and returns the results
     * as a markdown table. [pl] and [nr] are accepted for API parity but
     * unused by this stub.
     */
    suspend fun bench(pp: Int, tg: Int, pl: Int, nr: Int = 1): String {
        _state.value = State.Benchmarking

        try {
            // Simulate benchmark running
            delay(2000)

            // Generate fake benchmark results
            val modelDesc = "LlamaModel"
            val modelSize = "7"
            val modelNParams = "7"
            val backend = "CPU"

            // Random but plausible tokens/sec values for the benchmarks
            val ppAvg = (15.0 + Math.random() * 10.0).toFloat()
            val ppStd = (0.5 + Math.random() * 2.0).toFloat()
            val tgAvg = (20.0 + Math.random() * 15.0).toFloat()
            val tgStd = (0.7 + Math.random() * 3.0).toFloat()

            val result = StringBuilder()
            result.append("| model | size | params | backend | test | t/s |\n")
            result.append("| --- | --- | --- | --- | --- | --- |\n")
            result.append("| $modelDesc | ${modelSize}GiB | ${modelNParams}B | ")
            result.append("$backend | pp $pp | ${"%.2f".format(ppAvg)} ± ${"%.2f".format(ppStd)} |\n")
            result.append("| $modelDesc | ${modelSize}GiB | ${modelNParams}B | ")
            result.append("$backend | tg $tg | ${"%.2f".format(tgAvg)} ± ${"%.2f".format(tgStd)} |\n")

            _benchmarkResults.value = result.toString()

            _state.value = State.AwaitingUserPrompt

            return _benchmarkResults.value ?: ""
        } catch (e: Exception) {
            _state.value = State.Error(e.message ?: "Unknown error during benchmarking")
            return "Error: ${e.message}"
        }
    }
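
    // The pp/tg names appear to follow llama.cpp's llama-bench convention
    // (pp = prompt processing tokens, tg = text generation tokens). Example
    // (hypothetical caller):
    //   val table = engine.bench(pp = 512, tg = 128, pl = 1)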

    /**
     * Unloads the currently loaded model and clears any benchmark results.
     */
    suspend fun unloadModel() {
        // Simulate model unloading time
        delay(300)
        _state.value = State.LibraryLoaded
        _benchmarkResults.value = null
    }

    /**
     * Cleans up resources when the engine is no longer needed.
     */
    fun destroy() {
        // In a real implementation, this would release native resources
    }
}
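
A minimal sketch of how a caller might drive the stub end to end (hypothetical example code, not part of this commit; the model path and prompts are placeholders):

import kotlinx.coroutines.coroutineScope
import kotlinx.coroutines.flow.collect
import kotlinx.coroutines.launch

suspend fun demo() = coroutineScope {
    val engine = InferenceEngine()

    // Log every state transition in the background.
    val observer = launch {
        engine.state.collect { println("state: $it") }
    }

    engine.loadModel("/path/to/model.gguf", systemPrompt = "You are a helpful assistant.")

    // Collect the cold flow to stream the simulated response.
    engine.sendUserPrompt("Hello").collect { token -> print(token) }

    engine.unloadModel()
    engine.destroy()
    observer.cancel()
}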