Bug fix: null system prompt state update; Safeguard empty user prompt

This commit is contained in:
Han Yin 2025-04-02 18:50:30 -07:00
parent 7bbb53aaf8
commit 02465137ca
1 changed files with 19 additions and 16 deletions

View File

@ -57,7 +57,7 @@ class LLamaAndroid {
/** /**
* Load the LLM, then process the formatted system prompt if provided * Load the LLM, then process the formatted system prompt if provided
*/ */
suspend fun load(pathToModel: String, formattedSystemPrompt: String? = null) { suspend fun load(pathToModel: String, formattedSystemPrompt: String? = null) =
withContext(runLoop) { withContext(runLoop) {
when (threadLocalState.get()) { when (threadLocalState.get()) {
is State.NotInitialized -> { is State.NotInitialized -> {
@ -72,7 +72,7 @@ class LLamaAndroid {
formattedSystemPrompt?.let { formattedSystemPrompt?.let {
initWithSystemPrompt(formattedSystemPrompt) initWithSystemPrompt(formattedSystemPrompt)
} ?: { } ?: run {
Log.w(TAG, "No system prompt to process.") Log.w(TAG, "No system prompt to process.")
threadLocalState.set(State.AwaitingUserPrompt) threadLocalState.set(State.AwaitingUserPrompt)
} }
@ -80,12 +80,11 @@ class LLamaAndroid {
else -> throw IllegalStateException("Model already loaded") else -> throw IllegalStateException("Model already loaded")
} }
} }
}
/** /**
* Helper method to process system prompt and update [State] * Helper method to process system prompt and update [State]
*/ */
private suspend fun initWithSystemPrompt(formattedMessage: String) { private suspend fun initWithSystemPrompt(formattedMessage: String) =
withContext(runLoop) { withContext(runLoop) {
when (threadLocalState.get()) { when (threadLocalState.get()) {
is State.EnvReady -> { is State.EnvReady -> {
@ -104,7 +103,6 @@ class LLamaAndroid {
) )
} }
} }
}
/** /**
* Send formatted user prompt to LLM * Send formatted user prompt to LLM
@ -113,6 +111,10 @@ class LLamaAndroid {
formattedMessage: String, formattedMessage: String,
nPredict: Int = DEFAULT_PREDICT_LENGTH, nPredict: Int = DEFAULT_PREDICT_LENGTH,
): Flow<String> = flow { ): Flow<String> = flow {
require(formattedMessage.isNotEmpty()) {
Log.w(TAG, "User prompt discarded due to being empty!")
}
when (val state = threadLocalState.get()) { when (val state = threadLocalState.get()) {
is State.AwaitingUserPrompt -> { is State.AwaitingUserPrompt -> {
Log.i(TAG, "Sending user prompt...") Log.i(TAG, "Sending user prompt...")
@ -143,26 +145,28 @@ class LLamaAndroid {
/** /**
* Benchmark the model * Benchmark the model
*/ */
suspend fun bench(pp: Int, tg: Int, pl: Int, nr: Int = 1): String { suspend fun bench(pp: Int, tg: Int, pl: Int, nr: Int = 1): String =
return withContext(runLoop) { withContext(runLoop) {
when (val state = threadLocalState.get()) { when (threadLocalState.get()) {
is State.EnvReady -> { is State.AwaitingUserPrompt -> {
Log.d(TAG, "bench(): $state") threadLocalState.set(State.Processing)
benchModel(pp, tg, pl, nr) Log.d(TAG, "Start benchmark (pp: $pp, tg: $tg, pl: $pl, nr: $nr)")
benchModel(pp, tg, pl, nr).also {
threadLocalState.set(State.AwaitingUserPrompt)
}
} }
// TODO-hyin: catch exception in ViewController; disable button when state incorrect // TODO-hyin: disable button when state incorrect
else -> throw IllegalStateException("No model loaded") else -> throw IllegalStateException("No model loaded")
} }
} }
}
/** /**
* Unloads the model and frees resources. * Unloads the model and frees resources.
* *
* This is a no-op if there's no model loaded. * This is a no-op if there's no model loaded.
*/ */
suspend fun unload() { suspend fun unload() =
withContext(runLoop) { withContext(runLoop) {
when (val state = threadLocalState.get()) { when (val state = threadLocalState.get()) {
is State.EnvReady, State.AwaitingUserPrompt -> { is State.EnvReady, State.AwaitingUserPrompt -> {
@ -174,10 +178,9 @@ class LLamaAndroid {
} }
} }
} }
}
companion object { companion object {
private val TAG = this::class.simpleName private val TAG = LLamaAndroid::class.simpleName
private const val LIB_LLAMA_ANDROID = "llama-android" private const val LIB_LLAMA_ANDROID = "llama-android"
private const val LLAMA_THREAD = "llama-thread" private const val LLAMA_THREAD = "llama-thread"