From 4809112ec51126335dbfa8facfd096a8fadbec10 Mon Sep 17 00:00:00 2001
From: Han Yin
Date: Mon, 7 Apr 2025 20:48:15 -0700
Subject: [PATCH] Polish: adopt common naming; init modularization;

---
 .../llama/src/main/cpp/llama-android.cpp      | 32 ++++++++++---------
 .../java/android/llama/cpp/LLamaAndroid.kt    |  6 ++--
 2 files changed, 20 insertions(+), 18 deletions(-)

diff --git a/examples/llama.android/llama/src/main/cpp/llama-android.cpp b/examples/llama.android/llama/src/main/cpp/llama-android.cpp
index 415cc197de..4a0492aebb 100644
--- a/examples/llama.android/llama/src/main/cpp/llama-android.cpp
+++ b/examples/llama.android/llama/src/main/cpp/llama-android.cpp
@@ -108,10 +108,10 @@ Java_android_llama_cpp_LLamaAndroid_loadModel(JNIEnv *env, jobject, jstring file
     return 0;
 }
 
-static int init_context(llama_model *model) {
+static llama_context* init_context(llama_model *model) {
     if (!model) {
         LOGe("init_context(): model cannot be null");
-        return -1;
+        return nullptr;
     }
 
     // Multi-threading setup
@@ -128,15 +128,13 @@ static int init_context(llama_model *model) {
     ctx_params.n_threads = n_threads;
     ctx_params.n_threads_batch = n_threads;
     auto *context = llama_init_from_model(g_model, ctx_params);
-    if (!context) {
+    if (context == nullptr) {
         LOGe("llama_new_context_with_model() returned null)");
-        return -2;
     }
-    g_context = context;
-    return 0;
+    return context;
 }
 
-static void new_batch(int n_tokens, bool embd = false, int n_seq_max = 1) {
+static llama_batch * new_batch(int n_tokens, bool embd = false, int n_seq_max = 1) {
     // Source: Copy of llama.cpp:llama_batch_init but heap-allocated.
     auto *batch = new llama_batch{
         0,
@@ -161,22 +159,23 @@
         batch->seq_id[i] = (llama_seq_id *) malloc(sizeof(llama_seq_id) * n_seq_max);
     }
     batch->logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens);
-    g_batch = batch;
+    return batch;
 }
 
-void new_sampler(float temp) {
+static common_sampler* new_sampler(float temp) {
     common_params_sampling sparams;
     sparams.temp = temp;
-    g_sampler = common_sampler_init(g_model, sparams);
+    return common_sampler_init(g_model, sparams);
 }
 
 extern "C"
 JNIEXPORT jint JNICALL
 Java_android_llama_cpp_LLamaAndroid_initContext(JNIEnv * /*env*/, jobject /*unused*/) {
-    int ret = init_context(g_model);
-    if (ret != 0) { return ret; }
-    new_batch(BATCH_SIZE);
-    new_sampler(SAMPLER_TEMP);
+    auto *context = init_context(g_model);
+    if (!context) { return -1; }
+    g_context = context;
+    g_batch = new_batch(BATCH_SIZE);
+    g_sampler = new_sampler(SAMPLER_TEMP);
     g_chat_templates = common_chat_templates_init(g_model, "");
     return 0;
 }
@@ -396,6 +395,8 @@ Java_android_llama_cpp_LLamaAndroid_processSystemPrompt(
         LOGv("token: `%s`\t -> `%d`", common_token_to_piece(g_context, id).c_str(), id);
     }
 
+    // TODO-hyin: handle context overflow
+
     // Decode system tokens in batches
     if (decode_tokens_in_batches(g_context, system_tokens, current_position)) {
         LOGe("llama_decode() failed!");
@@ -439,6 +440,7 @@ Java_android_llama_cpp_LLamaAndroid_processUserPrompt(
         LOGv("token: `%s`\t -> `%d`", common_token_to_piece(g_context, id).c_str(), id);
     }
 
+    // TODO-hyin: handle context overflow
     // Check if context space is enough for desired tokens
     int desired_budget = current_position + (int) user_tokens.size() + n_predict;
     if (desired_budget > llama_n_ctx(g_context)) {
@@ -494,7 +496,7 @@ static bool is_valid_utf8(const char *string) {
 
 extern "C"
 JNIEXPORT jstring JNICALL
-Java_android_llama_cpp_LLamaAndroid_predictLoop(
+Java_android_llama_cpp_LLamaAndroid_completionLoop(
         JNIEnv *env, jobject /*unused*/
 ) {
diff --git a/examples/llama.android/llama/src/main/java/android/llama/cpp/LLamaAndroid.kt b/examples/llama.android/llama/src/main/java/android/llama/cpp/LLamaAndroid.kt
index b99ab2867b..4895086f97 100644
--- a/examples/llama.android/llama/src/main/java/android/llama/cpp/LLamaAndroid.kt
+++ b/examples/llama.android/llama/src/main/java/android/llama/cpp/LLamaAndroid.kt
@@ -24,8 +24,8 @@ class LLamaAndroid {
     private external fun benchModel(pp: Int, tg: Int, pl: Int, nr: Int): String
     private external fun processSystemPrompt(systemPrompt: String): Int
-    private external fun processUserPrompt(userPrompt: String, nPredict: Int): Int
-    private external fun predictLoop(): String?
+    private external fun processUserPrompt(userPrompt: String, predictLength: Int): Int
+    private external fun completionLoop(): String?
 
     /**
      * Thread local state
      */
@@ -128,7 +128,7 @@ class LLamaAndroid {
                     Log.i(TAG, "User prompt processed! Generating assistant prompt...")
 
                     while (true) {
-                        predictLoop()?.let { utf8token ->
+                        completionLoop()?.let { utf8token ->
                             if (utf8token.isNotEmpty()) emit(utf8token)
                         } ?: break
                    }
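
Note (reviewer sketch, not part of the patch): the renamed externals form a prime-then-poll API: processSystemPrompt()/processUserPrompt() decode the prompts into the shared context, then completionLoop() is polled until it returns null. A minimal Kotlin sketch of how a public wrapper inside LLamaAndroid might tie them together; the wrapper name `generate` is hypothetical, and the 0-on-success / null-on-done conventions are read off the code above:

    // Hypothetical wrapper around the private externals declared in this patch.
    // Assumes: the process* calls return 0 on success (mirroring initContext),
    // and completionLoop() returns null once generation is finished.
    // Requires: kotlinx.coroutines.flow.Flow and kotlinx.coroutines.flow.flow.
    fun generate(system: String, user: String, predictLength: Int): Flow<String> = flow {
        check(processSystemPrompt(system) == 0) { "processSystemPrompt failed" }
        check(processUserPrompt(user, predictLength) == 0) { "processUserPrompt failed" }
        while (true) {
            val piece = completionLoop() ?: break   // null -> generation done
            if (piece.isNotEmpty()) emit(piece)     // empty pieces skipped, as in the diff
        }
    }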