From e1bc87610eec839fab8f659e8fae03c9dc42d7f3 Mon Sep 17 00:00:00 2001
From: Han Yin
Date: Tue, 8 Apr 2025 16:39:17 -0700
Subject: [PATCH] Perf: allocate `llama_batch` on stack with `llama_batch_init`

---
 .../llama/src/main/cpp/llama-android.cpp | 69 +++++++------------
 1 file changed, 23 insertions(+), 46 deletions(-)

diff --git a/examples/llama.android/llama/src/main/cpp/llama-android.cpp b/examples/llama.android/llama/src/main/cpp/llama-android.cpp
index fceffcc693..06881104f2 100644
--- a/examples/llama.android/llama/src/main/cpp/llama-android.cpp
+++ b/examples/llama.android/llama/src/main/cpp/llama-android.cpp
@@ -44,9 +44,9 @@ constexpr float DEFAULT_SAMPLER_TEMP = 0.3f;
 
 static llama_model * g_model;
 static llama_context * g_context;
-static llama_batch * g_batch;
-static common_sampler * g_sampler;
+static llama_batch g_batch;
 static common_chat_templates_ptr g_chat_templates;
+static common_sampler * g_sampler;
 
 static void log_callback(ggml_log_level level, const char *fmt, void *data) {
     int priority;
@@ -140,29 +140,6 @@ static llama_context *init_context(llama_model *model) {
     return context;
 }
 
-static llama_batch *new_batch(int n_tokens, int n_seq_max = 1) {
-    // Source: Copy of llama.cpp:llama_batch_init but heap-allocated.
-    auto *batch = new llama_batch{
-        0,
-        nullptr,
-        nullptr,
-        nullptr,
-        nullptr,
-        nullptr,
-        nullptr,
-    };
-
-    batch->token = (llama_token *) malloc(sizeof(llama_token) * n_tokens);
-    batch->pos = (llama_pos *) malloc(sizeof(llama_pos) * n_tokens);
-    batch->n_seq_id = (int32_t *) malloc(sizeof(int32_t) * n_tokens);
-    batch->seq_id = (llama_seq_id **) malloc(sizeof(llama_seq_id *) * n_tokens);
-    for (int i = 0; i < n_tokens; ++i) {
-        batch->seq_id[i] = (llama_seq_id *) malloc(sizeof(llama_seq_id) * n_seq_max);
-    }
-    batch->logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens);
-    return batch;
-}
-
 static common_sampler *new_sampler(float temp) {
     common_params_sampling sparams;
     sparams.temp = temp;
@@ -175,18 +152,18 @@ Java_android_llama_cpp_LLamaAndroid_initContext(JNIEnv * /*env*/, jobject /*unus
     auto *context = init_context(g_model);
     if (!context) { return 1; }
     g_context = context;
-    g_batch = new_batch(BATCH_SIZE);
-    g_sampler = new_sampler(DEFAULT_SAMPLER_TEMP);
+    g_batch = llama_batch_init(BATCH_SIZE, 0, 1);
     g_chat_templates = common_chat_templates_init(g_model, "");
+    g_sampler = new_sampler(DEFAULT_SAMPLER_TEMP);
 
     return 0;
 }
 
 extern "C" JNIEXPORT void JNICALL
 Java_android_llama_cpp_LLamaAndroid_cleanUp(JNIEnv * /*unused*/, jobject /*unused*/) {
-    g_chat_templates.reset();
     common_sampler_free(g_sampler);
-    delete g_batch;
+    g_chat_templates.reset();
+    llama_batch_free(g_batch);
     llama_free(g_context);
     llama_model_free(g_model);
     llama_backend_free();
@@ -222,18 +199,18 @@ Java_android_llama_cpp_LLamaAndroid_benchModel(JNIEnv *env, jobject /*unused*/,
     for (nri = 0; nri < nr; nri++) {
         LOGi("Benchmark prompt processing (pp)");
 
-        common_batch_clear(*g_batch);
+        common_batch_clear(g_batch);
 
         const int n_tokens = pp;
         for (i = 0; i < n_tokens; i++) {
-            common_batch_add(*g_batch, 0, i, {0}, false);
+            common_batch_add(g_batch, 0, i, {0}, false);
         }
-        g_batch->logits[g_batch->n_tokens - 1] = true;
+        g_batch.logits[g_batch.n_tokens - 1] = true;
 
         llama_memory_clear(llama_get_memory(g_context), false);
 
         const auto t_pp_start = ggml_time_us();
-        if (llama_decode(g_context, *g_batch) != 0) {
+        if (llama_decode(g_context, g_batch) != 0) {
             LOGe("llama_decode() failed during prompt processing");
         }
         const auto t_pp_end = ggml_time_us();
@@ -245,12 +222,12 @@ Java_android_llama_cpp_LLamaAndroid_benchModel(JNIEnv *env, jobject /*unused*/,
         llama_memory_clear(llama_get_memory(g_context), false);
         const auto t_tg_start = ggml_time_us();
         for (i = 0; i < tg; i++) {
-            common_batch_clear(*g_batch);
+            common_batch_clear(g_batch);
             for (j = 0; j < pl; j++) {
-                common_batch_add(*g_batch, 0, i, {j}, true);
+                common_batch_add(g_batch, 0, i, {j}, true);
             }
 
-            if (llama_decode(g_context, *g_batch) != 0) {
+            if (llama_decode(g_context, g_batch) != 0) {
                 LOGe("llama_decode() failed during text generation");
             }
         }
@@ -371,15 +348,15 @@ static void reset_short_term_states() {
 
 static int decode_tokens_in_batches(
         llama_context *context,
+        llama_batch batch,
         const llama_tokens &tokens,
         const llama_pos start_pos,
-        bool compute_last_logit = false,
-        llama_batch *batch = g_batch) {
+        const bool compute_last_logit = false) {
     // Process tokens in batches using the global batch
     LOGd("%s: Decode %d tokens starting at position %d", __func__, (int) tokens.size(), start_pos);
     for (int i = 0; i < (int) tokens.size(); i += BATCH_SIZE) {
         const int cur_batch_size = std::min((int) tokens.size() - i, BATCH_SIZE);
-        common_batch_clear(*batch);
+        common_batch_clear(batch);
         LOGv("%s: Preparing a batch size of %d starting at: %d", __func__, cur_batch_size, i);
 
         // Shift context if current batch cannot fit into the context
@@ -393,11 +370,11 @@ static int decode_tokens_in_batches(
             const llama_token token_id = tokens[i + j];
             const llama_pos position = start_pos + i + j;
             const bool want_logit = compute_last_logit && (i + j == tokens.size() - 1);
-            common_batch_add(*batch, token_id, position, {0}, want_logit);
+            common_batch_add(batch, token_id, position, {0}, want_logit);
         }
         // Decode this batch
-        const int decode_result = llama_decode(context, *batch);
+        const int decode_result = llama_decode(context, batch);
         if (decode_result) {
             LOGe("%s: llama_decode failed w/ %d", __func__, decode_result);
             return 1;
         }
@@ -445,7 +422,7 @@ Java_android_llama_cpp_LLamaAndroid_processSystemPrompt(
     }
 
     // Decode system tokens in batches
-    if (decode_tokens_in_batches(g_context, system_tokens, current_position)) {
+    if (decode_tokens_in_batches(g_context, g_batch, system_tokens, current_position)) {
         LOGe("%s: llama_decode() failed!", __func__);
         return 2;
     }
@@ -494,7 +471,7 @@ Java_android_llama_cpp_LLamaAndroid_processUserPrompt(
     }
 
     // Decode user tokens in batches
-    if (decode_tokens_in_batches(g_context, user_tokens, current_position, true)) {
+    if (decode_tokens_in_batches(g_context, g_batch, user_tokens, current_position, true)) {
         LOGe("%s: llama_decode() failed!", __func__);
         return 2;
     }
@@ -562,9 +539,9 @@ Java_android_llama_cpp_LLamaAndroid_generateNextToken(
     common_sampler_accept(g_sampler, new_token_id, true);
 
     // Populate the batch with new token, then decode
-    common_batch_clear(*g_batch);
-    common_batch_add(*g_batch, new_token_id, current_position, {0}, true);
-    if (llama_decode(g_context, *g_batch) != 0) {
+    common_batch_clear(g_batch);
+    common_batch_add(g_batch, new_token_id, current_position, {0}, true);
+    if (llama_decode(g_context, g_batch) != 0) {
         LOGe("%s: llama_decode() failed for generated token", __func__);
         return nullptr;
     }
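For reference, the patch swaps the hand-rolled heap-allocating new_batch() helper for the upstream llama_batch_init()/llama_batch_free() pair and holds g_batch by value. A minimal sketch of that lifecycle, using only the public llama.h API (the 512-token capacity and the setup/teardown helper names are illustrative, not taken from the patch):

#include "llama.h"

// Held by value, as in the patch; llama_batch_init() heap-allocates the member arrays.
static llama_batch g_batch;

static void batch_setup() {
    // Capacity of 512 tokens, embd = 0 (token ids rather than embeddings), n_seq_max = 1.
    g_batch = llama_batch_init(512, 0, 1);
}

static void batch_teardown() {
    // Releases the token/pos/n_seq_id/seq_id/logits arrays owned by the batch.
    llama_batch_free(g_batch);
}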