Restructure `LLamaAndroid.kt`
parent 44720859d6
commit 7dc9968f82
@@ -45,7 +45,8 @@ class MainViewModel(private val llamaAndroid: LLamaAndroid = LLamaAndroid.instance())
        messages += ""

        viewModelScope.launch {
            llamaAndroid.sendMessage(text)
            // TODO-hyin: implement format message
            llamaAndroid.sendUserPrompt(formattedMessage = text)
                .catch {
                    Log.e(tag, "send() failed", it)
                    messages += it.message!!

@@ -11,31 +11,10 @@ import java.util.concurrent.Executors
import kotlin.concurrent.thread

class LLamaAndroid {
    private val tag: String? = this::class.simpleName

    private val threadLocalState: ThreadLocal<State> = ThreadLocal.withInitial { State.Idle }

    private val runLoop: CoroutineDispatcher = Executors.newSingleThreadExecutor {
        thread(start = false, name = "Llm-RunLoop") {
            Log.d(tag, "Dedicated thread for native code: ${Thread.currentThread().name}")

            // No-op if called more than once.
            System.loadLibrary("llama-android")

            // Set llama log handler to Android
            log_to_android()
            backend_init()

            Log.d(tag, system_info())

            it.run()
        }.apply {
            uncaughtExceptionHandler = Thread.UncaughtExceptionHandler { _, exception: Throwable ->
                Log.e(tag, "Unhandled exception", exception)
            }
        }
    }.asCoroutineDispatcher()

    /**
     * JNI methods
     * @see llama-android.cpp
     */
    private external fun log_to_android()
    private external fun system_info(): String
    private external fun backend_init()
@@ -50,6 +29,40 @@ class LLamaAndroid {
    private external fun process_user_prompt(user_prompt: String, nLen: Int): Int
    private external fun predict_loop(): String?

    /**
     * Thread local state
     */
    private sealed interface State {
        data object Idle: State
        data object ModelLoaded: State
        data object ReadyForUserPrompt: State
    }
    private val threadLocalState: ThreadLocal<State> = ThreadLocal.withInitial { State.Idle }

    private val runLoop: CoroutineDispatcher = Executors.newSingleThreadExecutor {
        thread(start = false, name = "Llm-RunLoop") {
            Log.d(TAG, "Dedicated thread for native code: ${Thread.currentThread().name}")

            // No-op if called more than once.
            System.loadLibrary("llama-android")

            // Set llama log handler to Android
            log_to_android()
            backend_init()

            Log.d(TAG, system_info())

            it.run()
        }.apply {
            uncaughtExceptionHandler = Thread.UncaughtExceptionHandler { _, exception: Throwable ->
                Log.e(TAG, "Unhandled exception", exception)
            }
        }
    }.asCoroutineDispatcher()

    /**
     * Load the LLM, then process the system prompt if provided
     */
    suspend fun load(pathToModel: String, formattedSystemPrompt: String? = null) {
        withContext(runLoop) {
            when (threadLocalState.get()) {
@@ -60,12 +73,13 @@
                    val result = ctx_init()
                    if (result != 0) throw IllegalStateException("Initialization failed with error code: $result")

                    Log.i(tag, "Loaded model $pathToModel")
                    Log.i(TAG, "Loaded model $pathToModel")
                    threadLocalState.set(State.ModelLoaded)

                    formattedSystemPrompt?.let {
                        initWithSystemPrompt(formattedSystemPrompt)
                    } ?: {
                        Log.w(TAG, "No system prompt to process.")
                        threadLocalState.set(State.ReadyForUserPrompt)
                    }
                }
@@ -74,11 +88,66 @@
        }
    }

    /**
     * Helper method to process system prompt and update [State]
     */
    private suspend fun initWithSystemPrompt(formattedMessage: String) {
        withContext(runLoop) {
            when (threadLocalState.get()) {
                is State.ModelLoaded -> {
                    Log.i(TAG, "Process system prompt...")
                    process_system_prompt(formattedMessage).let {
                        if (it != 0)
                            throw IllegalStateException("Failed to process system prompt: $it")
                    }

                    Log.i(TAG, "System prompt processed!")
                    threadLocalState.set(State.ReadyForUserPrompt)
                }
                else -> throw IllegalStateException(
                    "Failed to process system prompt: Model not loaded!"
                )
            }
        }
    }

    /**
     * Send plain text user prompt to LLM
     */
    fun sendUserPrompt(
        formattedMessage: String,
        nPredict: Int = DEFAULT_PREDICT_LENGTH,
    ): Flow<String> = flow {
        when (threadLocalState.get()) {
            is State.ReadyForUserPrompt -> {
                process_user_prompt(formattedMessage, nPredict).let {
                    if (it != 0) {
                        Log.e(TAG, "Failed to process user prompt: $it")
                        return@flow
                    }
                }

                Log.i(TAG, "User prompt processed! Generating assistant prompt...")
                while (true) {
                    val str = predict_loop() ?: break
                    if (str.isNotEmpty()) {
                        emit(str)
                    }
                }
                Log.i(TAG, "Assistant generation complete!")
            }
            else -> {}
        }
    }.flowOn(runLoop)

    /**
     * Benchmark the model
     */
    suspend fun bench(pp: Int, tg: Int, pl: Int, nr: Int = 1): String {
        return withContext(runLoop) {
            when (val state = threadLocalState.get()) {
                is State.ModelLoaded -> {
                    Log.d(tag, "bench(): $state")
                    Log.d(TAG, "bench(): $state")
                    bench_model(pp, tg, pl, nr)
                }
@@ -88,50 +157,6 @@ class LLamaAndroid {
        }
    }

    private suspend fun initWithSystemPrompt(systemPrompt: String) {
        withContext(runLoop) {
            when (threadLocalState.get()) {
                is State.ModelLoaded -> {
                    process_system_prompt(systemPrompt).let {
                        if (it != 0) {
                            throw IllegalStateException("Failed to process system prompt: $it")
                        }
                    }

                    Log.i(tag, "System prompt processed!")
                    threadLocalState.set(State.ReadyForUserPrompt)
                }
                else -> throw IllegalStateException("Model not loaded")
            }
        }
    }

    fun sendMessage(
        formattedUserPrompt: String,
        nPredict: Int = DEFAULT_PREDICT_LENGTH,
    ): Flow<String> = flow {
        when (threadLocalState.get()) {
            is State.ReadyForUserPrompt -> {
                process_user_prompt(formattedUserPrompt, nPredict).let {
                    if (it != 0) {
                        Log.e(tag, "Failed to process user prompt: $it")
                        return@flow
                    }
                }

                Log.i(tag, "User prompt processed! Generating assistant prompt...")
                while (true) {
                    val str = predict_loop() ?: break
                    if (str.isNotEmpty()) {
                        emit(str)
                    }
                }
                Log.i(tag, "Assistant generation complete!")
            }
            else -> {}
        }
    }.flowOn(runLoop)

    /**
     * Unloads the model and frees resources.
     *
@@ -150,17 +175,12 @@
    }

    companion object {
        private const val DEFAULT_PREDICT_LENGTH = 128
        private val TAG = this::class.simpleName

        private sealed interface State {
            data object Idle: State
            data object ModelLoaded: State
            data object ReadyForUserPrompt: State
        }
        private const val DEFAULT_PREDICT_LENGTH = 128

        // Enforce only one instance of Llm.
        private val _instance: LLamaAndroid = LLamaAndroid()

        fun instance(): LLamaAndroid = _instance
    }
}
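
For context, a minimal caller-side sketch of the restructured API in this commit. The model path, prompt strings, log tag, and surrounding coroutine scope are illustrative assumptions, not part of the change:

// Hypothetical usage sketch; path, prompts, and scope are placeholders.
val llm = LLamaAndroid.instance()

scope.launch {
    // load() runs on the dedicated run loop, moves the state to ModelLoaded,
    // and ends in ReadyForUserPrompt once the optional system prompt is processed.
    llm.load("/path/to/model.gguf", formattedSystemPrompt = "You are a helpful assistant.")

    // sendUserPrompt() emits generated text pieces as a Flow<String> until predict_loop() returns null.
    llm.sendUserPrompt(formattedMessage = "Hello!")
        .catch { Log.e("LlmDemo", "generation failed", it) }
        .collect { piece -> print(piece) }
}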