Restructure `LLamaAndroid.kt`

Han Yin 2025-03-28 12:21:14 -07:00
parent 44720859d6
commit 7dc9968f82
2 changed files with 100 additions and 79 deletions

MainViewModel.kt

@@ -45,7 +45,8 @@ class MainViewModel(private val llamaAndroid: LLamaAndroid = LLamaAndroid.instan
         messages += ""
         viewModelScope.launch {
-            llamaAndroid.sendMessage(text)
+            // TODO-hyin: implement format message
+            llamaAndroid.sendUserPrompt(formattedMessage = text)
                 .catch {
                     Log.e(tag, "send() failed", it)
                     messages += it.message!!

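For context, the renamed call is consumed as a Flow at the call site. A minimal sketch of the surrounding code, assuming `messages` is a `MutableList<String>` whose last element (the `messages += ""` above) is the placeholder that collects the streamed reply; the `.collect` block is an assumption, not part of this hunk:

// Hypothetical call-site sketch; `messages`, `text`, and `tag`
// come from MainViewModel.
viewModelScope.launch {
    llamaAndroid.sendUserPrompt(formattedMessage = text)
        .catch {
            Log.e(tag, "send() failed", it)
            messages += it.message!!
        }
        .collect { token ->
            // Append each streamed token to the empty placeholder message.
            messages[messages.lastIndex] = messages.last() + token
        }
}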
LLamaAndroid.kt

@@ -11,31 +11,10 @@ import java.util.concurrent.Executors
 import kotlin.concurrent.thread
 
 class LLamaAndroid {
-    private val tag: String? = this::class.simpleName
-
-    private val threadLocalState: ThreadLocal<State> = ThreadLocal.withInitial { State.Idle }
-
-    private val runLoop: CoroutineDispatcher = Executors.newSingleThreadExecutor {
-        thread(start = false, name = "Llm-RunLoop") {
-            Log.d(tag, "Dedicated thread for native code: ${Thread.currentThread().name}")
-
-            // No-op if called more than once.
-            System.loadLibrary("llama-android")
-
-            // Set llama log handler to Android
-            log_to_android()
-            backend_init()
-
-            Log.d(tag, system_info())
-
-            it.run()
-        }.apply {
-            uncaughtExceptionHandler = Thread.UncaughtExceptionHandler { _, exception: Throwable ->
-                Log.e(tag, "Unhandled exception", exception)
-            }
-        }
-    }.asCoroutineDispatcher()
-
+    /**
+     * JNI methods
+     * @see llama-android.cpp
+     */
     private external fun log_to_android()
     private external fun system_info(): String
     private external fun backend_init()
@@ -50,6 +29,40 @@ class LLamaAndroid {
     private external fun process_user_prompt(user_prompt: String, nLen: Int): Int
     private external fun predict_loop(): String?
 
+    /**
+     * Thread local state
+     */
+    private sealed interface State {
+        data object Idle: State
+        data object ModelLoaded: State
+        data object ReadyForUserPrompt: State
+    }
+
+    private val threadLocalState: ThreadLocal<State> = ThreadLocal.withInitial { State.Idle }
+
+    private val runLoop: CoroutineDispatcher = Executors.newSingleThreadExecutor {
+        thread(start = false, name = "Llm-RunLoop") {
+            Log.d(TAG, "Dedicated thread for native code: ${Thread.currentThread().name}")
+
+            // No-op if called more than once.
+            System.loadLibrary("llama-android")
+
+            // Set llama log handler to Android
+            log_to_android()
+            backend_init()
+
+            Log.d(TAG, system_info())
+
+            it.run()
+        }.apply {
+            uncaughtExceptionHandler = Thread.UncaughtExceptionHandler { _, exception: Throwable ->
+                Log.e(TAG, "Unhandled exception", exception)
+            }
+        }
+    }.asCoroutineDispatcher()
+
+    /**
+     * Load the LLM, then process the system prompt if provided
+     */
     suspend fun load(pathToModel: String, formattedSystemPrompt: String? = null) {
         withContext(runLoop) {
             when (threadLocalState.get()) {
@@ -60,12 +73,13 @@ class LLamaAndroid {
                     val result = ctx_init()
                     if (result != 0) throw IllegalStateException("Initialization failed with error code: $result")
-                    Log.i(tag, "Loaded model $pathToModel")
+                    Log.i(TAG, "Loaded model $pathToModel")
                     threadLocalState.set(State.ModelLoaded)
 
                     formattedSystemPrompt?.let {
                         initWithSystemPrompt(formattedSystemPrompt)
                     } ?: {
+                        Log.w(TAG, "No system prompt to process.")
                         threadLocalState.set(State.ReadyForUserPrompt)
                     }
                 }
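One caveat in the hunk above: in Kotlin, `expr ?: { ... }` evaluates to the lambda object itself rather than invoking it, so the new `Log.w` call and the state transition in the null branch never actually execute. A corrected sketch (not part of this commit) would wrap the fallback in `run`:

formattedSystemPrompt?.let {
    initWithSystemPrompt(formattedSystemPrompt)
} ?: run {
    // `run` invokes the block; a bare `?: { ... }` only creates an unused lambda.
    Log.w(TAG, "No system prompt to process.")
    threadLocalState.set(State.ReadyForUserPrompt)
}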
@@ -74,11 +88,66 @@ class LLamaAndroid {
             }
         }
 
+    /**
+     * Helper method to process system prompt and update [State]
+     */
+    private suspend fun initWithSystemPrompt(formattedMessage: String) {
+        withContext(runLoop) {
+            when (threadLocalState.get()) {
+                is State.ModelLoaded -> {
+                    Log.i(TAG, "Process system prompt...")
+                    process_system_prompt(formattedMessage).let {
+                        if (it != 0)
+                            throw IllegalStateException("Failed to process system prompt: $it")
+                    }
+                    Log.i(TAG, "System prompt processed!")
+                    threadLocalState.set(State.ReadyForUserPrompt)
+                }
+                else -> throw IllegalStateException(
+                    "Failed to process system prompt: Model not loaded!"
+                )
+            }
+        }
+    }
+
+    /**
+     * Send plain text user prompt to LLM
+     */
+    fun sendUserPrompt(
+        formattedMessage: String,
+        nPredict: Int = DEFAULT_PREDICT_LENGTH,
+    ): Flow<String> = flow {
+        when (threadLocalState.get()) {
+            is State.ReadyForUserPrompt -> {
+                process_user_prompt(formattedMessage, nPredict).let {
+                    if (it != 0) {
+                        Log.e(TAG, "Failed to process user prompt: $it")
+                        return@flow
+                    }
+                }
+
+                Log.i(TAG, "User prompt processed! Generating assistant prompt...")
+                while (true) {
+                    val str = predict_loop() ?: break
+                    if (str.isNotEmpty()) {
+                        emit(str)
+                    }
+                }
+                Log.i(TAG, "Assistant generation complete!")
+            }
+            else -> {}
+        }
+    }.flowOn(runLoop)
+
+    /**
+     * Benchmark the model
+     */
     suspend fun bench(pp: Int, tg: Int, pl: Int, nr: Int = 1): String {
         return withContext(runLoop) {
             when (val state = threadLocalState.get()) {
                 is State.ModelLoaded -> {
-                    Log.d(tag, "bench(): $state")
+                    Log.d(TAG, "bench(): $state")
                     bench_model(pp, tg, pl, nr)
                 }
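The new `sendUserPrompt` builds a cold Flow: nothing runs until a collector attaches, `flowOn(runLoop)` pins the native calls (`process_user_prompt`, `predict_loop`) to the dedicated `Llm-RunLoop` thread, and tokens are emitted until `predict_loop()` returns null. A minimal collection sketch, assuming `load()` has already succeeded (`demoPrompt` is a hypothetical caller, not part of this commit):

import android.util.Log
import kotlinx.coroutines.flow.onCompletion

// Must be called from a coroutine (e.g. inside viewModelScope.launch).
suspend fun demoPrompt() {
    val llm = LLamaAndroid.instance()
    llm.sendUserPrompt(formattedMessage = "Hello!", nPredict = 64)
        .onCompletion { Log.d("Demo", "Generation finished") }
        .collect { token -> print(token) }  // tokens arrive incrementally
}

Note that when the state is not `ReadyForUserPrompt`, the `else -> {}` branch means the flow completes without emitting, so collectors see an empty stream rather than an error.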
@@ -88,50 +157,6 @@ class LLamaAndroid {
         }
     }
 
-    private suspend fun initWithSystemPrompt(systemPrompt: String) {
-        withContext(runLoop) {
-            when (threadLocalState.get()) {
-                is State.ModelLoaded -> {
-                    process_system_prompt(systemPrompt).let {
-                        if (it != 0) {
-                            throw IllegalStateException("Failed to process system prompt: $it")
-                        }
-                    }
-                    Log.i(tag, "System prompt processed!")
-                    threadLocalState.set(State.ReadyForUserPrompt)
-                }
-                else -> throw IllegalStateException("Model not loaded")
-            }
-        }
-    }
-
-    fun sendMessage(
-        formattedUserPrompt: String,
-        nPredict: Int = DEFAULT_PREDICT_LENGTH,
-    ): Flow<String> = flow {
-        when (threadLocalState.get()) {
-            is State.ReadyForUserPrompt -> {
-                process_user_prompt(formattedUserPrompt, nPredict).let {
-                    if (it != 0) {
-                        Log.e(tag, "Failed to process user prompt: $it")
-                        return@flow
-                    }
-                }
-
-                Log.i(tag, "User prompt processed! Generating assistant prompt...")
-                while (true) {
-                    val str = predict_loop() ?: break
-                    if (str.isNotEmpty()) {
-                        emit(str)
-                    }
-                }
-                Log.i(tag, "Assistant generation complete!")
-            }
-            else -> {}
-        }
-    }.flowOn(runLoop)
-
     /**
      * Unloads the model and frees resources.
      *
@@ -150,17 +175,12 @@ class LLamaAndroid {
     }
 
     companion object {
-        private const val DEFAULT_PREDICT_LENGTH = 128
-
-        private sealed interface State {
-            data object Idle: State
-            data object ModelLoaded: State
-            data object ReadyForUserPrompt: State
-        }
+        private val TAG = this::class.simpleName
+        private const val DEFAULT_PREDICT_LENGTH = 128
 
         // Enforce only one instance of Llm.
         private val _instance: LLamaAndroid = LLamaAndroid()
         fun instance(): LLamaAndroid = _instance
     }
 }
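Taken together, the restructuring makes the lifecycle an explicit state machine: `Idle` -> `ModelLoaded` (after `ctx_init` succeeds) -> `ReadyForUserPrompt` (after the system prompt is processed, or immediately if none is given), and `sendUserPrompt` only emits from the final state. An end-to-end sketch of the public API with a placeholder model path and prompts (hypothetical; error handling omitted):

import android.util.Log
import kotlinx.coroutines.runBlocking

// Desktop-style entry point for illustration only; on Android this
// would live in a ViewModel coroutine scope.
fun main() = runBlocking {
    val llm = LLamaAndroid.instance()

    // Idle -> ModelLoaded -> ReadyForUserPrompt
    llm.load(
        pathToModel = "/data/local/tmp/model.gguf",            // placeholder path
        formattedSystemPrompt = "You are a helpful assistant." // optional; null skips it
    )

    // Streams tokens until predict_loop() returns null.
    llm.sendUserPrompt(formattedMessage = "Why is the sky blue?")
        .collect { token -> Log.d("Demo", token) }
}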