From c14c11dcbd7fd1efe826e9be4f81d94a2a65f761 Mon Sep 17 00:00:00 2001
From: Han Yin
Date: Mon, 7 Apr 2025 14:20:51 -0700
Subject: [PATCH] Feature: decode system and user prompt in batches

---
 .../llama/src/main/cpp/llama-android.cpp | 122 +++++++++++-------
 1 file changed, 73 insertions(+), 49 deletions(-)

diff --git a/examples/llama.android/llama/src/main/cpp/llama-android.cpp b/examples/llama.android/llama/src/main/cpp/llama-android.cpp
index 3713cff4fd..ac7dd67030 100644
--- a/examples/llama.android/llama/src/main/cpp/llama-android.cpp
+++ b/examples/llama.android/llama/src/main/cpp/llama-android.cpp
@@ -111,6 +111,8 @@ static int init_context(llama_model *model) {
     // Context parameters setup
     llama_context_params ctx_params = llama_context_default_params();
     ctx_params.n_ctx = CONTEXT_SIZE;
+    ctx_params.n_batch = BATCH_SIZE;
+    ctx_params.n_ubatch = BATCH_SIZE;
     ctx_params.n_threads = n_threads;
     ctx_params.n_threads_batch = n_threads;
     auto *context = llama_init_from_model(g_model, ctx_params);
@@ -171,9 +173,21 @@ JNIEXPORT void JNICALL
 Java_android_llama_cpp_LLamaAndroid_cleanUp(JNIEnv * /*unused*/, jobject /*unused*/) {
     llama_model_free(g_model);
     llama_free(g_context);
-    llama_backend_free();
     delete g_batch;
     common_sampler_free(g_sampler);
+    llama_backend_free();
+}
+
+static std::string get_backend() {
+    std::vector<std::string> backends;
+    for (size_t i = 0; i < ggml_backend_reg_count(); i++) {
+        auto * reg = ggml_backend_reg_get(i);
+        std::string name = ggml_backend_reg_name(reg);
+        if (name != "CPU") {
+            backends.push_back(ggml_backend_reg_name(reg));
+        }
+    }
+    return backends.empty() ? "CPU" : join(backends, ",");
 }

 extern "C"
@@ -205,7 +219,7 @@ Java_android_llama_cpp_LLamaAndroid_benchModel(JNIEnv *env, jobject /*unused*/,

     const auto t_pp_start = ggml_time_us();
     if (llama_decode(g_context, *g_batch) != 0) {
-        LOGw("llama_decode() failed during prompt processing");
+        LOGe("llama_decode() failed during prompt processing");
     }

     const auto t_pp_end = ggml_time_us();
@@ -216,18 +230,15 @@ Java_android_llama_cpp_LLamaAndroid_benchModel(JNIEnv *env, jobject /*unused*/,
     llama_memory_clear(llama_get_memory(g_context), false);

     const auto t_tg_start = ggml_time_us();
     for (i = 0; i < tg; i++) {
-        common_batch_clear(*g_batch);
         for (j = 0; j < pl; j++) {
             common_batch_add(*g_batch, 0, i, {j}, true);
         }

-        LOGi("llama_decode() text generation: %d", i);
         if (llama_decode(g_context, *g_batch) != 0) {
-            LOGw("llama_decode() failed during text generation");
+            LOGe("llama_decode() failed during text generation");
         }
     }
-
     const auto t_tg_end = ggml_time_us();

     llama_memory_clear(llama_get_memory(g_context), false);
@@ -282,13 +293,42 @@ Java_android_llama_cpp_LLamaAndroid_benchModel(JNIEnv *env, jobject /*unused*/,
 /**
  * Prediction loop's long-term and short-term states
  */
-static int current_position;
+static llama_pos current_position;
-static int token_predict_budget;
+static llama_pos token_predict_budget;
 static std::string cached_token_chars;
-int token_predict_budget;
-std::string cached_token_chars;
+
+static int decode_tokens_in_batches(
+        llama_context *context,
+        const llama_tokens& tokens,
+        const llama_pos start_pos,
+        bool compute_last_logit = false,
+        llama_batch *batch = g_batch) {
+    // Process tokens in batches using the global batch
+    LOGd("Decode %d tokens starting at position %d", tokens.size(), start_pos);
+    for (int i = 0; i < (int) tokens.size(); i += BATCH_SIZE) {
+        int cur_batch_size = std::min((int) tokens.size() - i, BATCH_SIZE);
+        common_batch_clear(*batch);
+        LOGv("Preparing a batch size of %d starting at: %d", cur_batch_size, i);
starting at: %d", cur_batch_size, i); + + // Add tokens to the batch with proper positions + for (int j = 0; j < cur_batch_size; j++) { + llama_token token_id = tokens[i + j]; + llama_pos position = start_pos + i + j; + bool want_logit = compute_last_logit && (i + j == tokens.size() - 1); + common_batch_add(*batch, token_id, position, {0}, want_logit); + } + + // Decode this batch + int decode_result = llama_decode(context, *batch); + if (decode_result) { + LOGe("llama_decode failed w/ %d", decode_result); + return -1; + } + } + + return 0; +} extern "C" JNIEXPORT jint JNICALL @@ -316,22 +356,14 @@ Java_android_llama_cpp_LLamaAndroid_processSystemPrompt( LOGv("token: `%s`\t -> `%d`", common_token_to_piece(g_context, id).c_str(), id); } - // Add system prompt tokens to batch - common_batch_clear(*g_batch); - // TODO-hyin: support batch processing! - for (int i = 0; i < system_tokens.size(); i++) { - common_batch_add(*g_batch, system_tokens[i], i, {0}, false); - } - - // Decode batch - int decode_result = llama_decode(g_context, *g_batch); - if (decode_result != 0) { - LOGe("llama_decode() failed: %d", decode_result); + // Decode system tokens in batches + if (decode_tokens_in_batches(g_context, system_tokens, current_position)) { + LOGe("llama_decode() failed!"); return -1; } // Update position - current_position = system_tokens.size(); + current_position = (int) system_tokens.size(); return 0; } @@ -360,29 +392,21 @@ Java_android_llama_cpp_LLamaAndroid_processUserPrompt( } // Check if context space is enough for desired tokens - int desired_budget = current_position + user_tokens.size() + n_predict; + int desired_budget = current_position + (int) user_tokens.size() + n_predict; if (desired_budget > llama_n_ctx(g_context)) { LOGe("error: total tokens exceed context size"); return -1; } token_predict_budget = desired_budget; - // Add user prompt tokens to batch - common_batch_clear(*g_batch); - for (int i = 0; i < user_tokens.size(); i++) { - common_batch_add(*g_batch, user_tokens[i], current_position + i, {0}, false); - } - g_batch->logits[g_batch->n_tokens - 1] = true; // Set logits true only for last token - - // Decode batch - int decode_result = llama_decode(g_context, *g_batch); - if (decode_result != 0) { - LOGe("llama_decode() failed: %d", decode_result); + // Decode user tokens in batches + if (decode_tokens_in_batches(g_context, user_tokens, current_position, true)) { + LOGe("llama_decode() failed!"); return -2; } // Update position - current_position += user_tokens.size(); // Update position + current_position += (int) user_tokens.size(); // Update position return 0; } @@ -436,13 +460,7 @@ Java_android_llama_cpp_LLamaAndroid_predictLoop( const auto new_token_id = common_sampler_sample(g_sampler, g_context, -1); common_sampler_accept(g_sampler, new_token_id, true); - // Stop if next token is EOG - if (llama_vocab_is_eog(llama_model_get_vocab(g_model), new_token_id)) { - LOGd("id: %d,\tIS EOG!\nSTOP.", new_token_id); - return nullptr; - } - - // Update the context with the new token + // Populate the batch with new token, then decode common_batch_clear(*g_batch); common_batch_add(*g_batch, new_token_id, current_position, {0}, true); if (llama_decode(g_context, *g_batch) != 0) { @@ -450,22 +468,28 @@ Java_android_llama_cpp_LLamaAndroid_predictLoop( return nullptr; } - // Convert to text + // Update position + current_position++; + + // Stop if next token is EOG + if (llama_vocab_is_eog(llama_model_get_vocab(g_model), new_token_id)) { + LOGd("id: %d,\tIS EOG!\nSTOP.", new_token_id); + 
+        return nullptr;
+    }
+
+    // If not EOG, convert to text
     auto new_token_chars = common_token_to_piece(g_context, new_token_id);
     cached_token_chars += new_token_chars;

-    // Create Java string
+    // Create and return Java string
     jstring result = nullptr;
     if (is_valid_utf8(cached_token_chars.c_str())) {
         result = env->NewStringUTF(cached_token_chars.c_str());
-        LOGd("id: %d,\tcached: `%s`,\tnew: `%s`", new_token_id, cached_token_chars.c_str(), new_token_chars.c_str());
+        LOGv("id: %d,\tcached: `%s`,\tnew: `%s`", new_token_id, cached_token_chars.c_str(), new_token_chars.c_str());
         cached_token_chars.clear();
     } else {
-        LOGd("id: %d,\tappend to cache", new_token_id);
+        LOGv("id: %d,\tappend to cache", new_token_id);
         result = env->NewStringUTF("");
     }
-
-    // Update position
-    current_position++;
     return result;
 }
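
Reviewer note (illustrative, not part of the patch): the sketch below is a minimal, self-contained view of the chunked-decode pattern that decode_tokens_in_batches() introduces, and of how processSystemPrompt()/processUserPrompt() are expected to advance current_position. The names decode_in_chunks, run_prompts, and kBatchSize are assumptions for illustration only; it presumes a llama_batch allocated with llama_batch_init(kBatchSize, 0, 1) and a context whose n_batch/n_ubatch are at least kBatchSize, mirroring the init_context() change in the first hunk.

// Sketch only: uses the same common.h helpers as the patch (common_batch_add,
// common_batch_clear, common_tokenize), outside the JNI entry points.
#include <algorithm>
#include <string>
#include <vector>

#include "common.h"
#include "llama.h"

static constexpr int kBatchSize = 512;  // assumed value, standing in for BATCH_SIZE

// Feed `tokens` to llama_decode() in slices of at most kBatchSize tokens,
// starting at KV-cache position `start_pos`; request logits only for the
// final token when the caller intends to sample right after this prompt.
static int decode_in_chunks(llama_context * ctx, llama_batch & batch,
                            const std::vector<llama_token> & tokens,
                            llama_pos start_pos, bool want_last_logit) {
    for (int i = 0; i < (int) tokens.size(); i += kBatchSize) {
        const int n = std::min((int) tokens.size() - i, kBatchSize);
        common_batch_clear(batch);
        for (int j = 0; j < n; j++) {
            const bool last = want_last_logit && (i + j == (int) tokens.size() - 1);
            common_batch_add(batch, tokens[i + j], start_pos + i + j, { 0 }, last);
        }
        if (llama_decode(ctx, batch) != 0) {
            return -1;  // this chunk failed to decode
        }
    }
    return 0;
}

// Intended call sequence: system prompt first (no logits), then the user
// prompt with the last-token logit enabled, tracking the running position
// the way current_position is tracked in the patch.
static int run_prompts(llama_context * ctx, llama_batch & batch,
                       const std::string & system_prompt,
                       const std::string & user_prompt) {
    llama_pos pos = 0;

    const auto sys_tokens  = common_tokenize(ctx, system_prompt, /*add_special=*/true,  /*parse_special=*/true);
    const auto user_tokens = common_tokenize(ctx, user_prompt,   /*add_special=*/false, /*parse_special=*/true);

    if (decode_in_chunks(ctx, batch, sys_tokens, pos, /*want_last_logit=*/false) != 0) {
        return -1;
    }
    pos += (llama_pos) sys_tokens.size();

    if (decode_in_chunks(ctx, batch, user_tokens, pos, /*want_last_logit=*/true) != 0) {
        return -2;
    }
    pos += (llama_pos) user_tokens.size();

    return 0;  // generation would continue from position `pos`
}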