Bug fix: update state when no system prompt is provided; safeguard against empty user prompts

Han Yin 2025-04-02 18:50:30 -07:00
parent 7bbb53aaf8
commit 02465137ca
1 changed file with 19 additions and 16 deletions


@@ -57,7 +57,7 @@ class LLamaAndroid {
/**
* Load the LLM, then process the formatted system prompt if provided
*/
suspend fun load(pathToModel: String, formattedSystemPrompt: String? = null) {
suspend fun load(pathToModel: String, formattedSystemPrompt: String? = null) =
withContext(runLoop) {
when (threadLocalState.get()) {
is State.NotInitialized -> {
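Note on the signature change above: load() goes from a block body to an expression body (= withContext(runLoop) { ... }), so the function now returns the result of withContext directly and one level of closing braces disappears later in this diff. A standalone sketch of the two forms, assuming kotlinx.coroutines and using Dispatchers.Default as a stand-in for the project's runLoop:

import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.withContext

// Block body: the function returns Unit unless a return type is declared.
suspend fun loadBlockBody() {
    withContext(Dispatchers.Default) { /* load the model */ }
}

// Expression body: the result is whatever withContext returns,
// and the extra closing brace is no longer needed.
suspend fun loadExpressionBody() = withContext(Dispatchers.Default) { /* load the model */ }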
@@ -72,7 +72,7 @@ class LLamaAndroid {
formattedSystemPrompt?.let {
initWithSystemPrompt(formattedSystemPrompt)
} ?: {
} ?: run {
Log.w(TAG, "No system prompt to process.")
threadLocalState.set(State.AwaitingUserPrompt)
}
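Note on the fix above: in Kotlin, expr ?: { ... } does not execute the block — the Elvis operator simply yields the lambda object — so when no system prompt was provided, the warning was never logged and the state never moved to AwaitingUserPrompt. Wrapping the block in run { } makes it evaluate. A standalone sketch of the difference, not the project's code:

fun main() {
    val systemPrompt: String? = null

    // Buggy pattern: the Elvis operator yields the lambda itself;
    // its body never runs, so nothing is logged and no state changes.
    systemPrompt?.let { println("processing $it") } ?: { println("fallback (never printed)") }

    // Fixed pattern: run { } evaluates the block when the left-hand side is null.
    systemPrompt?.let { println("processing $it") } ?: run { println("fallback (printed)") }
}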
@@ -80,12 +80,11 @@ class LLamaAndroid {
else -> throw IllegalStateException("Model already loaded")
}
}
}
/**
* Helper method to process system prompt and update [State]
*/
private suspend fun initWithSystemPrompt(formattedMessage: String) {
private suspend fun initWithSystemPrompt(formattedMessage: String) =
withContext(runLoop) {
when (threadLocalState.get()) {
is State.EnvReady -> {
@@ -104,7 +103,6 @@
)
}
}
}
/**
* Send formatted user prompt to LLM
@@ -113,6 +111,10 @@
formattedMessage: String,
nPredict: Int = DEFAULT_PREDICT_LENGTH,
): Flow<String> = flow {
require(formattedMessage.isNotEmpty()) {
Log.w(TAG, "User prompt discarded due to being empty!")
}
when (val state = threadLocalState.get()) {
is State.AwaitingUserPrompt -> {
Log.i(TAG, "Sending user prompt...")
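Note on the new guard above: require(condition) { ... } throws IllegalArgumentException when the condition is false, and the trailing lambda is evaluated lazily — only on failure — with its result used as the exception message (here it also writes the warning log). Because the check sits inside the flow { } builder, it runs when the returned Flow is collected. A standalone sketch of the behaviour, using a plain string message instead of the log call:

fun send(formattedMessage: String) {
    // Throws IllegalArgumentException when the message is empty;
    // the lambda runs only in that case and its value becomes the message.
    require(formattedMessage.isNotEmpty()) { "User prompt discarded due to being empty!" }
    println("sending: $formattedMessage")
}

fun main() {
    send("hello")                          // ok
    runCatching { send("") }               // fails with IllegalArgumentException
        .onFailure { println(it.message) } // prints the message above
}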
@@ -143,26 +145,28 @@ class LLamaAndroid {
/**
* Benchmark the model
*/
suspend fun bench(pp: Int, tg: Int, pl: Int, nr: Int = 1): String {
return withContext(runLoop) {
when (val state = threadLocalState.get()) {
is State.EnvReady -> {
Log.d(TAG, "bench(): $state")
benchModel(pp, tg, pl, nr)
suspend fun bench(pp: Int, tg: Int, pl: Int, nr: Int = 1): String =
withContext(runLoop) {
when (threadLocalState.get()) {
is State.AwaitingUserPrompt -> {
threadLocalState.set(State.Processing)
Log.d(TAG, "Start benchmark (pp: $pp, tg: $tg, pl: $pl, nr: $nr)")
benchModel(pp, tg, pl, nr).also {
threadLocalState.set(State.AwaitingUserPrompt)
}
}
// TODO-hyin: catch exception in ViewController; disable button when state incorrect
// TODO-hyin: disable button when state incorrect
else -> throw IllegalStateException("No model loaded")
}
}
}
/**
* Unloads the model and frees resources.
*
* This is a no-op if there's no model loaded.
*/
suspend fun unload() {
suspend fun unload() =
withContext(runLoop) {
when (val state = threadLocalState.get()) {
is State.EnvReady, State.AwaitingUserPrompt -> {
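Note on the bench() changes above: the state now flips to Processing for the duration of the benchmark and is restored with .also { }, which runs a side effect and returns its receiver, so the report produced by benchModel() is still what the caller gets back. A standalone sketch of the also pattern, with placeholder values rather than the project's code:

fun main() {
    var state = "Processing"

    // also { } performs the side effect and returns its receiver,
    // so the benchmark report string is still the expression's value.
    val report = "pp=512 tg=128".also { state = "AwaitingUserPrompt" }

    println(report) // pp=512 tg=128
    println(state)  // AwaitingUserPrompt
}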
@@ -174,10 +178,9 @@ class LLamaAndroid {
}
}
}
}
companion object {
private val TAG = this::class.simpleName
private val TAG = LLamaAndroid::class.simpleName
private const val LIB_LLAMA_ANDROID = "llama-android"
private const val LLAMA_THREAD = "llama-thread"
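Note on the TAG change above: inside a companion object, this::class refers to the companion itself, so this::class.simpleName evaluates to "Companion" rather than the enclosing class name; referencing LLamaAndroid::class explicitly yields the intended tag. A standalone sketch with a hypothetical class name:

class LLamaAndroidSketch {
    companion object {
        // this::class names the companion object's class, "Companion" by default.
        val wrongTag = this::class.simpleName
        // Referencing the enclosing class gives the intended tag.
        val rightTag = LLamaAndroidSketch::class.simpleName
    }
}

fun main() {
    println(LLamaAndroidSketch.wrongTag)  // Companion
    println(LLamaAndroidSketch.rightTag)  // LLamaAndroidSketch
}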