Polish binding: Remove verbose setup JNI APIs; Update state machine states.

Han Yin 2025-03-31 12:53:00 -07:00
parent 7dc9968f82
commit 0ade7fb4d7
2 changed files with 101 additions and 82 deletions
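Reviewer note: the renamed lifecycle states, with the transitions implied by the Kotlin changes below. This summary is inferred from the diff and is not part of the commit itself.

```kotlin
// States as renamed in LLamaAndroid.kt (transition notes inferred from the diff).
sealed interface State {
    data object NotInitialized : State     // was: Idle
    data object EnvReady : State           // was: ModelLoaded
    data object AwaitingUserPrompt : State // was: ReadyForUserPrompt
    data object Processing : State         // new: prompt processing / generation in flight
}
// NotInitialized --load()--> EnvReady --processSystemPrompt()--> Processing --> AwaitingUserPrompt
// AwaitingUserPrompt --sendUserPrompt()--> Processing --generation done--> AwaitingUserPrompt
// EnvReady | AwaitingUserPrompt --unload()--> NotInitialized
```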

llama-android.cpp

@@ -12,10 +12,11 @@
* Logging utils
*/
#define TAG "llama-android.cpp"
#define LOGd(...) __android_log_print(ANDROID_LOG_DEBUG, TAG, __VA_ARGS__)
#define LOGi(...) __android_log_print(ANDROID_LOG_INFO, TAG, __VA_ARGS__)
#define LOGw(...) __android_log_print(ANDROID_LOG_WARN, TAG, __VA_ARGS__)
#define LOGe(...) __android_log_print(ANDROID_LOG_ERROR, TAG, __VA_ARGS__)
#define LOGv(...) __android_log_print(ANDROID_LOG_VERBOSE, TAG, __VA_ARGS__)
#define LOGd(...) __android_log_print(ANDROID_LOG_DEBUG, TAG, __VA_ARGS__)
#define LOGi(...) __android_log_print(ANDROID_LOG_INFO, TAG, __VA_ARGS__)
#define LOGw(...) __android_log_print(ANDROID_LOG_WARN, TAG, __VA_ARGS__)
#define LOGe(...) __android_log_print(ANDROID_LOG_ERROR, TAG, __VA_ARGS__)
/**
* LLama resources: context, model, batch and sampler
@@ -55,32 +56,35 @@ static void log_callback(ggml_log_level level, const char *fmt, void *data) {
__android_log_print(priority, TAG, fmt, data);
}
extern "C"
JNIEXPORT void JNICALL
Java_android_llama_cpp_LLamaAndroid_log_1to_1android(JNIEnv * /*unused*/, jobject /*unused*/) {
llama_log_set(log_callback, nullptr);
}
JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM* vm, void* reserved) {
JNIEnv* env;
if (vm->GetEnv(reinterpret_cast<void**>(&env), JNI_VERSION_1_6) != JNI_OK) {
return JNI_ERR;
}
// Set llama log handler to Android
llama_log_set(log_callback, nullptr);
// Initialize backends
llama_backend_init();
LOGi("Backend initiated.");
return JNI_VERSION_1_6;
}
extern "C"
JNIEXPORT jstring JNICALL
Java_android_llama_cpp_LLamaAndroid_system_1info(JNIEnv *env, jobject /*unused*/) {
Java_android_llama_cpp_LLamaAndroid_systemInfo(JNIEnv *env, jobject /*unused*/) {
return env->NewStringUTF(llama_print_system_info());
}
extern "C"
JNIEXPORT void JNICALL
Java_android_llama_cpp_LLamaAndroid_backend_1init(JNIEnv * /*unused*/, jobject /*unused*/) {
llama_backend_init();
}
extern "C"
JNIEXPORT jint JNICALL
Java_android_llama_cpp_LLamaAndroid_load_1model(JNIEnv *env, jobject, jstring filename) {
Java_android_llama_cpp_LLamaAndroid_loadModel(JNIEnv *env, jobject, jstring filename) {
llama_model_params model_params = llama_model_default_params();
const auto *path_to_model = env->GetStringUTFChars(filename, 0);
LOGi("Loading model from: %s", path_to_model);
LOGd("Loading model from: %s", path_to_model);
model = llama_model_load_from_file(path_to_model, model_params);
env->ReleaseStringUTFChars(filename, path_to_model);
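Reviewer note: since JNI_OnLoad runs automatically when the JVM loads the native library, callers now get the Android log handler and llama_backend_init() as a side effect of System.loadLibrary(). A minimal, hypothetical Kotlin sketch of the remaining setup:

```kotlin
// Hypothetical caller-side bootstrap: System.loadLibrary() triggers JNI_OnLoad,
// which installs the Android log handler and initializes the llama backend.
// The removed log_to_android() / backend_init() externals are no longer needed.
object NativeBootstrap {
    init {
        System.loadLibrary("llama-android")
    }
}
```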
@@ -153,7 +157,7 @@ void new_sampler(float temp) {
extern "C"
JNIEXPORT jint JNICALL
Java_android_llama_cpp_LLamaAndroid_ctx_1init(JNIEnv * /*env*/, jobject /*unused*/) {
Java_android_llama_cpp_LLamaAndroid_initContext(JNIEnv * /*env*/, jobject /*unused*/) {
int ret = init_context();
if (ret != 0) { return ret; }
new_batch(BATCH_SIZE);
@@ -163,7 +167,7 @@ Java_android_llama_cpp_LLamaAndroid_ctx_1init(JNIEnv * /*env*/, jobject /*unused
extern "C"
JNIEXPORT void JNICALL
Java_android_llama_cpp_LLamaAndroid_clean_1up(JNIEnv * /*unused*/, jobject /*unused*/) {
Java_android_llama_cpp_LLamaAndroid_cleanUp(JNIEnv * /*unused*/, jobject /*unused*/) {
llama_model_free(model);
llama_free(context);
llama_backend_free();
@@ -173,7 +177,7 @@ Java_android_llama_cpp_LLamaAndroid_clean_1up(JNIEnv * /*unused*/, jobject /*unu
extern "C"
JNIEXPORT jstring JNICALL
Java_android_llama_cpp_LLamaAndroid_bench_1model(JNIEnv *env, jobject /*unused*/, jint pp, jint tg, jint pl, jint nr) {
Java_android_llama_cpp_LLamaAndroid_benchModel(JNIEnv *env, jobject /*unused*/, jint pp, jint tg, jint pl, jint nr) {
auto pp_avg = 0.0;
auto tg_avg = 0.0;
auto pp_std = 0.0;
@@ -284,7 +288,7 @@ std::string cached_token_chars;
extern "C"
JNIEXPORT jint JNICALL
Java_android_llama_cpp_LLamaAndroid_process_1system_1prompt(
Java_android_llama_cpp_LLamaAndroid_processSystemPrompt(
JNIEnv *env,
jobject /*unused*/,
jstring jsystem_prompt
@@ -299,10 +303,15 @@ Java_android_llama_cpp_LLamaAndroid_process_1system_1prompt(
// Obtain and tokenize system prompt
const auto *const system_text = env->GetStringUTFChars(jsystem_prompt, nullptr);
LOGi("System prompt: \n%s", system_text);
LOGd("System prompt received: \n%s", system_text);
const auto system_tokens = common_tokenize(context, system_text, true, true);
env->ReleaseStringUTFChars(jsystem_prompt, system_text);
// Print each token in verbose mode
for (auto id : system_tokens) {
LOGv("token: `%s`\t -> `%d`", common_token_to_piece(context, id).c_str(), id);
}
// Add system prompt tokens to batch
common_batch_clear(*batch);
// TODO-hyin: support batch processing!
@@ -325,11 +334,11 @@
// TODO-hyin: support KV cache backtracking
extern "C"
JNIEXPORT jint JNICALL
Java_android_llama_cpp_LLamaAndroid_process_1user_1prompt(
Java_android_llama_cpp_LLamaAndroid_processUserPrompt(
JNIEnv *env,
jobject /*unused*/,
jstring juser_prompt,
jint n_len
jint n_predict
) {
// Reset short-term states
token_predict_budget = 0;
@@ -337,12 +346,17 @@ Java_android_llama_cpp_LLamaAndroid_process_1user_1prompt(
// Obtain and tokenize user prompt
const auto *const user_text = env->GetStringUTFChars(juser_prompt, nullptr);
LOGi("User prompt: \n%s", user_text);
LOGd("User prompt received: \n%s", user_text);
const auto user_tokens = common_tokenize(context, user_text, true, true);
env->ReleaseStringUTFChars(juser_prompt, user_text);
// Print each token in verbose mode
for (auto id : user_tokens) {
LOGv("token: `%s`\t -> `%d`", common_token_to_piece(context, id).c_str(), id);
}
// Check if context space is enough for desired tokens
int desired_budget = current_position + user_tokens.size() + n_len;
int desired_budget = current_position + user_tokens.size() + n_predict;
if (desired_budget > llama_n_ctx(context)) {
LOGe("error: total tokens exceed context size");
return -1;
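Reviewer note: the budget check above is plain arithmetic; a hypothetical Kotlin mirror of the same condition, for clarity:

```kotlin
// Hypothetical mirror of the native check: a request fits only while consumed
// positions + new prompt tokens + requested predictions stay within n_ctx.
fun fitsInContext(currentPosition: Int, nPromptTokens: Int, nPredict: Int, nCtx: Int): Boolean =
    currentPosition + nPromptTokens + nPredict <= nCtx
```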
@@ -404,13 +418,13 @@ bool is_valid_utf8(const char *string) {
extern "C"
JNIEXPORT jstring JNICALL
Java_android_llama_cpp_LLamaAndroid_predict_1loop(
Java_android_llama_cpp_LLamaAndroid_predictLoop(
JNIEnv *env,
jobject /*unused*/
) {
// Stop if running out of token budget
if (current_position >= token_predict_budget) {
LOGi("STOP: current position (%d) exceeds budget (%d)", current_position, token_predict_budget);
LOGw("STOP: current position (%d) exceeds budget (%d)", current_position, token_predict_budget);
return nullptr;
}
@@ -420,7 +434,7 @@ Java_android_llama_cpp_LLamaAndroid_predict_1loop(
// Stop if next token is EOG
if (llama_vocab_is_eog(llama_model_get_vocab(model), new_token_id)) {
LOGi("id: %d,\tIS EOG!\nSTOP.", new_token_id);
LOGd("id: %d,\tIS EOG!\nSTOP.", new_token_id);
return nullptr;
}
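Reviewer note: predictLoop() uses null as its stop sentinel (token budget exhausted or end-of-generation token). On the Kotlin side that maps naturally onto null-terminated iteration; an illustrative sketch, not code from the commit:

```kotlin
// Illustrative only: generateSequence() stops at the first null, mirroring the
// nullptr returns above (token budget exhausted or EOG token sampled).
fun drainTokens(step: () -> String?): Sequence<String> =
    generateSequence(step).filter { it.isNotEmpty() }
```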

LLamaAndroid.kt

@@ -15,42 +15,36 @@ class LLamaAndroid {
* JNI methods
* @see llama-android.cpp
*/
private external fun log_to_android()
private external fun system_info(): String
private external fun backend_init()
private external fun systemInfo(): String
private external fun load_model(filename: String): Int
private external fun ctx_init(): Int
private external fun clean_up()
private external fun loadModel(filename: String): Int
private external fun initContext(): Int
private external fun cleanUp()
private external fun bench_model(pp: Int, tg: Int, pl: Int, nr: Int): String
private external fun benchModel(pp: Int, tg: Int, pl: Int, nr: Int): String
private external fun process_system_prompt(system_prompt: String): Int
private external fun process_user_prompt(user_prompt: String, nLen: Int): Int
private external fun predict_loop(): String?
private external fun processSystemPrompt(systemPrompt: String): Int
private external fun processUserPrompt(userPrompt: String, nPredict: Int): Int
private external fun predictLoop(): String?
/**
* Thread local state
*/
private sealed interface State {
data object Idle: State
data object ModelLoaded: State
data object ReadyForUserPrompt: State
data object NotInitialized: State
data object EnvReady: State
data object AwaitingUserPrompt: State
data object Processing: State
}
private val threadLocalState: ThreadLocal<State> = ThreadLocal.withInitial { State.Idle }
private val threadLocalState: ThreadLocal<State> = ThreadLocal.withInitial { State.NotInitialized }
private val runLoop: CoroutineDispatcher = Executors.newSingleThreadExecutor {
thread(start = false, name = "Llm-RunLoop") {
thread(start = false, name = LLAMA_THREAD) {
Log.d(TAG, "Dedicated thread for native code: ${Thread.currentThread().name}")
// No-op if called more than once.
System.loadLibrary("llama-android")
// Set llama log handler to Android
log_to_android()
backend_init()
Log.d(TAG, system_info())
System.loadLibrary(LIB_LLAMA_ANDROID)
Log.d(TAG, systemInfo())
it.run()
}.apply {
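Reviewer note: the dedicated run loop keeps every JNI call, and the ThreadLocal state, on a single thread. The pattern in isolation, as an illustrative sketch:

```kotlin
import java.util.concurrent.Executors
import kotlinx.coroutines.asCoroutineDispatcher

// Illustrative: a single-threaded executor wrapped as a CoroutineDispatcher
// confines all native calls to one thread, so ThreadLocal<State> stays coherent.
val nativeLoop = Executors.newSingleThreadExecutor { r ->
    Thread(r, "llama-thread")
}.asCoroutineDispatcher()
```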
@@ -61,26 +55,26 @@
}.asCoroutineDispatcher()
/**
* Load the LLM, then process the system prompt if provided
* Load the LLM, then process the formatted system prompt if provided
*/
suspend fun load(pathToModel: String, formattedSystemPrompt: String? = null) {
withContext(runLoop) {
when (threadLocalState.get()) {
is State.Idle -> {
val model = load_model(pathToModel)
if (model != 0) throw IllegalStateException("Load model failed")
is State.NotInitialized -> {
val modelResult = loadModel(pathToModel)
if (modelResult != 0) throw IllegalStateException("Load model failed: $modelResult")
val result = ctx_init()
if (result != 0) throw IllegalStateException("Initialization failed with error code: $result")
val initResult = initContext()
if (initResult != 0) throw IllegalStateException("Initialization failed with error code: $initResult")
Log.i(TAG, "Loaded model $pathToModel")
threadLocalState.set(State.ModelLoaded)
threadLocalState.set(State.EnvReady)
formattedSystemPrompt?.let {
initWithSystemPrompt(formattedSystemPrompt)
} ?: run {
Log.w(TAG, "No system prompt to process.")
threadLocalState.set(State.ReadyForUserPrompt)
threadLocalState.set(State.AwaitingUserPrompt)
}
}
else -> throw IllegalStateException("Model already loaded")
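Reviewer note: a hypothetical call site for the updated load(), with placeholder path and prompt:

```kotlin
// Hypothetical usage; load() throws IllegalStateException on native failure
// or when a model is already loaded. Path and prompt are placeholders.
suspend fun warmUp(llama: LLamaAndroid) {
    try {
        llama.load(
            pathToModel = "/data/local/tmp/model.gguf",
            formattedSystemPrompt = "You are a helpful assistant."
        )
    } catch (e: IllegalStateException) {
        println("load failed: ${e.message}")
    }
}
```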
@@ -94,15 +88,16 @@ class LLamaAndroid {
private suspend fun initWithSystemPrompt(formattedMessage: String) {
withContext(runLoop) {
when (threadLocalState.get()) {
is State.ModelLoaded -> {
is State.EnvReady -> {
Log.i(TAG, "Process system prompt...")
process_system_prompt(formattedMessage).let {
threadLocalState.set(State.Processing)
processSystemPrompt(formattedMessage).let {
if (it != 0)
throw IllegalStateException("Failed to process system prompt: $it")
}
Log.i(TAG, "System prompt processed!")
threadLocalState.set(State.ReadyForUserPrompt)
threadLocalState.set(State.AwaitingUserPrompt)
}
else -> throw IllegalStateException(
"Failed to process system prompt: Model not loaded!"
@@ -112,31 +107,36 @@
}
/**
* Send plain text user prompt to LLM
* Send formatted user prompt to LLM
*/
fun sendUserPrompt(
formattedMessage: String,
nPredict: Int = DEFAULT_PREDICT_LENGTH,
): Flow<String> = flow {
when (threadLocalState.get()) {
is State.ReadyForUserPrompt -> {
process_user_prompt(formattedMessage, nPredict).let {
if (it != 0) {
Log.e(TAG, "Failed to process user prompt: $it")
when (val state = threadLocalState.get()) {
is State.AwaitingUserPrompt -> {
Log.i(TAG, "Sending user prompt...")
threadLocalState.set(State.Processing)
processUserPrompt(formattedMessage, nPredict).let { result ->
if (result != 0) {
Log.e(TAG, "Failed to process user prompt: $result")
return@flow
}
}
Log.i(TAG, "User prompt processed! Generating assistant prompt...")
while (true) {
val str = predict_loop() ?: break
if (str.isNotEmpty()) {
emit(str)
}
predictLoop()?.let { utf8token ->
if (utf8token.isNotEmpty()) emit(utf8token)
} ?: break
}
Log.i(TAG, "Assistant generation complete!")
threadLocalState.set(State.AwaitingUserPrompt)
}
else -> {
Log.w(TAG, "User prompt discarded due to incorrect state: $state")
}
else -> {}
}
}.flowOn(runLoop)
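Reviewer note: a hypothetical consumer of the token Flow; pieces are emitted as they are generated and the flow completes when predictLoop() signals stop:

```kotlin
// Hypothetical consumer (llama is an LLamaAndroid instance).
suspend fun ask(llama: LLamaAndroid, formattedPrompt: String) {
    llama.sendUserPrompt(formattedPrompt).collect { piece ->
        print(piece) // stream pieces as they arrive
    }
}
```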
@@ -146,9 +146,9 @@ class LLamaAndroid {
suspend fun bench(pp: Int, tg: Int, pl: Int, nr: Int = 1): String {
return withContext(runLoop) {
when (val state = threadLocalState.get()) {
is State.ModelLoaded -> {
is State.EnvReady -> {
Log.d(TAG, "bench(): $state")
bench_model(pp, tg, pl, nr)
benchModel(pp, tg, pl, nr)
}
// TODO-hyin: catch exception in ViewController; disable button when state incorrect
@@ -164,12 +164,14 @@ class LLamaAndroid {
*/
suspend fun unload() {
withContext(runLoop) {
when (threadLocalState.get()) {
is State.ModelLoaded -> {
clean_up()
threadLocalState.set(State.Idle)
when (val state = threadLocalState.get()) {
is State.EnvReady, State.AwaitingUserPrompt -> {
cleanUp()
threadLocalState.set(State.NotInitialized)
}
else -> {
Log.w(TAG, "Cannot unload model due to incorrect state: $state")
}
else -> {}
}
}
}
@@ -177,7 +179,10 @@
companion object {
private val TAG = this::class.simpleName
private const val DEFAULT_PREDICT_LENGTH = 128
private const val LIB_LLAMA_ANDROID = "llama-android"
private const val LLAMA_THREAD = "llama-thread"
private const val DEFAULT_PREDICT_LENGTH = 64
// Enforce only one instance of Llm.
private val _instance: LLamaAndroid = LLamaAndroid()