Restructure `LLamaAndroid.kt`
parent 44720859d6
commit 7dc9968f82
MainViewModel.kt:

@@ -45,7 +45,8 @@ class MainViewModel(private val llamaAndroid: LLamaAndroid = LLamaAndroid.instan
         messages += ""
 
         viewModelScope.launch {
-            llamaAndroid.sendMessage(text)
+            // TODO-hyin: implement format message
+            llamaAndroid.sendUserPrompt(formattedMessage = text)
                 .catch {
                     Log.e(tag, "send() failed", it)
                     messages += it.message!!
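The ViewModel now passes the raw text straight through as `formattedMessage`, leaving the actual chat formatting to the `TODO-hyin` above. A minimal sketch of what that formatter might look like, assuming a ChatML-style template; the helper name `formatUserMessage` and the template tokens are illustrative, not part of this commit:

// Hypothetical helper (not in this commit): wrap the raw user text in a
// chat template before handing it to sendUserPrompt(). The template must
// match whatever model load() was given.
private fun formatUserMessage(text: String): String =
    "<|im_start|>user\n$text<|im_end|>\n<|im_start|>assistant\n"

// Possible call site once the TODO is implemented:
// llamaAndroid.sendUserPrompt(formattedMessage = formatUserMessage(text))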
LLamaAndroid.kt:

@@ -11,31 +11,10 @@ import java.util.concurrent.Executors
 import kotlin.concurrent.thread
 
 class LLamaAndroid {
-    private val tag: String? = this::class.simpleName
-
-    private val threadLocalState: ThreadLocal<State> = ThreadLocal.withInitial { State.Idle }
-
-    private val runLoop: CoroutineDispatcher = Executors.newSingleThreadExecutor {
-        thread(start = false, name = "Llm-RunLoop") {
-            Log.d(tag, "Dedicated thread for native code: ${Thread.currentThread().name}")
-
-            // No-op if called more than once.
-            System.loadLibrary("llama-android")
-
-            // Set llama log handler to Android
-            log_to_android()
-            backend_init()
-
-            Log.d(tag, system_info())
-
-            it.run()
-        }.apply {
-            uncaughtExceptionHandler = Thread.UncaughtExceptionHandler { _, exception: Throwable ->
-                Log.e(tag, "Unhandled exception", exception)
-            }
-        }
-    }.asCoroutineDispatcher()
-
+    /**
+     * JNI methods
+     * @see llama-android.cpp
+     */
     private external fun log_to_android()
     private external fun system_info(): String
     private external fun backend_init()
@@ -50,6 +29,40 @@ class LLamaAndroid {
     private external fun process_user_prompt(user_prompt: String, nLen: Int): Int
     private external fun predict_loop(): String?
 
+    /**
+     * Thread local state
+     */
+    private sealed interface State {
+        data object Idle: State
+        data object ModelLoaded: State
+        data object ReadyForUserPrompt: State
+    }
+    private val threadLocalState: ThreadLocal<State> = ThreadLocal.withInitial { State.Idle }
+
+    private val runLoop: CoroutineDispatcher = Executors.newSingleThreadExecutor {
+        thread(start = false, name = "Llm-RunLoop") {
+            Log.d(TAG, "Dedicated thread for native code: ${Thread.currentThread().name}")
+
+            // No-op if called more than once.
+            System.loadLibrary("llama-android")
+
+            // Set llama log handler to Android
+            log_to_android()
+            backend_init()
+
+            Log.d(TAG, system_info())
+
+            it.run()
+        }.apply {
+            uncaughtExceptionHandler = Thread.UncaughtExceptionHandler { _, exception: Throwable ->
+                Log.e(TAG, "Unhandled exception", exception)
+            }
+        }
+    }.asCoroutineDispatcher()
+
+    /**
+     * Load the LLM, then process the system prompt if provided
+     */
     suspend fun load(pathToModel: String, formattedSystemPrompt: String? = null) {
         withContext(runLoop) {
             when (threadLocalState.get()) {
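For context, a hypothetical caller of the reworked `load()` entry point; the model path and the system-prompt string are placeholders, and `viewModelScope` assumes the call happens inside a ViewModel as in the first hunk:

// Sketch only; the path and the prompt template are placeholders, not part of this commit.
viewModelScope.launch {
    try {
        // Runs on the dedicated "Llm-RunLoop" thread; with a system prompt the
        // state advances Idle -> ModelLoaded -> ReadyForUserPrompt in one call.
        llamaAndroid.load(
            pathToModel = "/data/local/tmp/model.gguf",
            formattedSystemPrompt = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n",
        )
    } catch (e: IllegalStateException) {
        Log.e(tag, "load() failed", e)
    }
}

Because every native call is funneled through the single-threaded `runLoop` dispatcher, the `ThreadLocal<State>` is effectively the state of that one run-loop thread.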
@@ -60,12 +73,13 @@ class LLamaAndroid {
                     val result = ctx_init()
                     if (result != 0) throw IllegalStateException("Initialization failed with error code: $result")
 
-                    Log.i(tag, "Loaded model $pathToModel")
+                    Log.i(TAG, "Loaded model $pathToModel")
                     threadLocalState.set(State.ModelLoaded)
 
                     formattedSystemPrompt?.let {
                         initWithSystemPrompt(formattedSystemPrompt)
                     } ?: {
+                        Log.w(TAG, "No system prompt to process.")
                         threadLocalState.set(State.ReadyForUserPrompt)
                     }
                 }
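One caveat about the `} ?: {` branch above: the right-hand side of the elvis operator here is a lambda literal, so when `formattedSystemPrompt` is null the expression only evaluates to an uninvoked lambda and neither the new `Log.w` nor the `ReadyForUserPrompt` transition actually runs. A sketch of an `if`/`else` equivalent that does execute the fallback (illustrative only, not part of this commit):

if (formattedSystemPrompt != null) {
    initWithSystemPrompt(formattedSystemPrompt)
} else {
    // This branch genuinely runs when no system prompt is supplied.
    Log.w(TAG, "No system prompt to process.")
    threadLocalState.set(State.ReadyForUserPrompt)
}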
@@ -74,11 +88,66 @@ class LLamaAndroid {
             }
         }
     }
 
+    /**
+     * Helper method to process system prompt and update [State]
+     */
+    private suspend fun initWithSystemPrompt(formattedMessage: String) {
+        withContext(runLoop) {
+            when (threadLocalState.get()) {
+                is State.ModelLoaded -> {
+                    Log.i(TAG, "Process system prompt...")
+                    process_system_prompt(formattedMessage).let {
+                        if (it != 0)
+                            throw IllegalStateException("Failed to process system prompt: $it")
+                    }
+
+                    Log.i(TAG, "System prompt processed!")
+                    threadLocalState.set(State.ReadyForUserPrompt)
+                }
+                else -> throw IllegalStateException(
+                    "Failed to process system prompt: Model not loaded!"
+                )
+            }
+        }
+    }
+
+    /**
+     * Send plain text user prompt to LLM
+     */
+    fun sendUserPrompt(
+        formattedMessage: String,
+        nPredict: Int = DEFAULT_PREDICT_LENGTH,
+    ): Flow<String> = flow {
+        when (threadLocalState.get()) {
+            is State.ReadyForUserPrompt -> {
+                process_user_prompt(formattedMessage, nPredict).let {
+                    if (it != 0) {
+                        Log.e(TAG, "Failed to process user prompt: $it")
+                        return@flow
+                    }
+                }
+
+                Log.i(TAG, "User prompt processed! Generating assistant prompt...")
+                while (true) {
+                    val str = predict_loop() ?: break
+                    if (str.isNotEmpty()) {
+                        emit(str)
+                    }
+                }
+                Log.i(TAG, "Assistant generation complete!")
+            }
+            else -> {}
+        }
+    }.flowOn(runLoop)
+
+    /**
+     * Benchmark the model
+     */
     suspend fun bench(pp: Int, tg: Int, pl: Int, nr: Int = 1): String {
         return withContext(runLoop) {
             when (val state = threadLocalState.get()) {
                 is State.ModelLoaded -> {
-                    Log.d(tag, "bench(): $state")
+                    Log.d(TAG, "bench(): $state")
                    bench_model(pp, tg, pl, nr)
                 }
 
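A hypothetical collector for the new `sendUserPrompt()` flow; `messages`, `text`, and the `nPredict` value mirror the ViewModel hunk above and are placeholders rather than part of this commit:

// Sketch only, assuming messages is a mutable List<String> state in the ViewModel.
viewModelScope.launch {
    llamaAndroid.sendUserPrompt(formattedMessage = text, nPredict = 256)
        .catch { Log.e(tag, "send() failed", it) }
        .collect { token ->
            // Each emission is one decoded chunk of the assistant reply;
            // append it to the message currently being streamed into the UI.
            messages = messages.dropLast(1) + (messages.last() + token)
        }
}

Since the flow is built with `flowOn(runLoop)`, the `process_user_prompt`/`predict_loop` calls stay on the dedicated native thread while the collector runs in the caller's context.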
@@ -88,50 +157,6 @@ class LLamaAndroid {
         }
     }
 
-    private suspend fun initWithSystemPrompt(systemPrompt: String) {
-        withContext(runLoop) {
-            when (threadLocalState.get()) {
-                is State.ModelLoaded -> {
-                    process_system_prompt(systemPrompt).let {
-                        if (it != 0) {
-                            throw IllegalStateException("Failed to process system prompt: $it")
-                        }
-                    }
-
-                    Log.i(tag, "System prompt processed!")
-                    threadLocalState.set(State.ReadyForUserPrompt)
-                }
-                else -> throw IllegalStateException("Model not loaded")
-            }
-        }
-    }
-
-    fun sendMessage(
-        formattedUserPrompt: String,
-        nPredict: Int = DEFAULT_PREDICT_LENGTH,
-    ): Flow<String> = flow {
-        when (threadLocalState.get()) {
-            is State.ReadyForUserPrompt -> {
-                process_user_prompt(formattedUserPrompt, nPredict).let {
-                    if (it != 0) {
-                        Log.e(tag, "Failed to process user prompt: $it")
-                        return@flow
-                    }
-                }
-
-                Log.i(tag, "User prompt processed! Generating assistant prompt...")
-                while (true) {
-                    val str = predict_loop() ?: break
-                    if (str.isNotEmpty()) {
-                        emit(str)
-                    }
-                }
-                Log.i(tag, "Assistant generation complete!")
-            }
-            else -> {}
-        }
-    }.flowOn(runLoop)
-
     /**
      * Unloads the model and frees resources.
      *
@@ -150,17 +175,12 @@ class LLamaAndroid {
     }
 
     companion object {
-        private const val DEFAULT_PREDICT_LENGTH = 128
-
-        private sealed interface State {
-            data object Idle: State
-            data object ModelLoaded: State
-            data object ReadyForUserPrompt: State
-        }
+        private val TAG = this::class.simpleName
+        private const val DEFAULT_PREDICT_LENGTH = 128
 
         // Enforce only one instance of Llm.
         private val _instance: LLamaAndroid = LLamaAndroid()
 
         fun instance(): LLamaAndroid = _instance
     }
 }
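One note on the new `TAG`: inside a `companion object`, `this::class.simpleName` resolves to the companion's own class name ("Companion") rather than "LLamaAndroid", so log lines end up tagged accordingly. Two illustrative alternatives that keep the enclosing class name (not part of this commit):

companion object {
    // Resolves to "LLamaAndroid":
    private val TAG = LLamaAndroid::class.simpleName
    // Or avoid reflection entirely:
    // private const val TAG = "LLamaAndroid"
}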