From 37f3e1c415fc64bdd17e714411026f3cbc86b409 Mon Sep 17 00:00:00 2001
From: Han Yin
Date: Fri, 11 Apr 2025 10:21:08 -0700
Subject: [PATCH] Feature: use local llama_context for benchmarking; support context init with custom context size

---
 .../llama/src/main/cpp/llama-android.cpp      | 32 ++++++++++++-------
 1 file changed, 20 insertions(+), 12 deletions(-)

diff --git a/examples/llama.android/llama/src/main/cpp/llama-android.cpp b/examples/llama.android/llama/src/main/cpp/llama-android.cpp
index b098940d44..4b51110a65 100644
--- a/examples/llama.android/llama/src/main/cpp/llama-android.cpp
+++ b/examples/llama.android/llama/src/main/cpp/llama-android.cpp
@@ -98,7 +98,7 @@ Java_android_llama_cpp_LLamaAndroid_load(JNIEnv *env, jobject, jstring jmodel_pa
     return 0;
 }
 
-static llama_context *init_context(llama_model *model) {
+static llama_context *init_context(llama_model *model, const int n_ctx = DEFAULT_CONTEXT_SIZE) {
     if (!model) {
         LOGe("%s: model cannot be null", __func__);
         return nullptr;
@@ -113,11 +113,11 @@ static llama_context *init_context(llama_model *model) {
     // Context parameters setup
     llama_context_params ctx_params = llama_context_default_params();
     const int trained_context_size = llama_model_n_ctx_train(model);
-    if (DEFAULT_CONTEXT_SIZE > trained_context_size) {
-        LOGe("%s: Model was trained with only %d context size! Enforcing %d context size...",
-             __func__, trained_context_size, DEFAULT_CONTEXT_SIZE);
+    if (n_ctx > trained_context_size) {
+        LOGw("%s: Model was trained with only %d context size! Enforcing %d context size...",
+             __func__, trained_context_size, n_ctx);
     }
-    ctx_params.n_ctx = DEFAULT_CONTEXT_SIZE;
+    ctx_params.n_ctx = n_ctx;
     ctx_params.n_batch = BATCH_SIZE;
     ctx_params.n_ubatch = BATCH_SIZE;
     ctx_params.n_threads = n_threads;
@@ -169,13 +169,19 @@ extern "C" JNIEXPORT jstring JNICALL
 Java_android_llama_cpp_LLamaAndroid_benchModel(JNIEnv *env, jobject /*unused*/,
                                                jint pp, jint tg, jint pl, jint nr) {
+    auto *context = init_context(g_model, pp);
+    if (!context) {
+        const auto err_msg = "Failed to init_context! Bench aborted.";
+        LOGe("%s", err_msg);
+        return env->NewStringUTF(err_msg);
+    }
+
     auto pp_avg = 0.0;
     auto tg_avg = 0.0;
     auto pp_std = 0.0;
     auto tg_std = 0.0;
 
-    const uint32_t n_ctx = llama_n_ctx(g_context);
-
+    const uint32_t n_ctx = llama_n_ctx(context);
     LOGi("n_ctx = %d", n_ctx);
 
     int i, j;
@@ -191,10 +197,10 @@ Java_android_llama_cpp_LLamaAndroid_benchModel(JNIEnv *env, jobject /*unused*/,
         }
         g_batch.logits[g_batch.n_tokens - 1] = true;
 
-        llama_memory_clear(llama_get_memory(g_context), false);
+        llama_memory_clear(llama_get_memory(context), false);
 
         const auto t_pp_start = ggml_time_us();
-        if (llama_decode(g_context, g_batch) != 0) {
+        if (llama_decode(context, g_batch) != 0) {
             LOGe("llama_decode() failed during prompt processing");
         }
         const auto t_pp_end = ggml_time_us();
@@ -203,7 +209,7 @@ Java_android_llama_cpp_LLamaAndroid_benchModel(JNIEnv *env, jobject /*unused*/,
 
         LOGi("Benchmark text generation (tg)");
 
-        llama_memory_clear(llama_get_memory(g_context), false);
+        llama_memory_clear(llama_get_memory(context), false);
         const auto t_tg_start = ggml_time_us();
         for (i = 0; i < tg; i++) {
             common_batch_clear(g_batch);
@@ -211,13 +217,13 @@ Java_android_llama_cpp_LLamaAndroid_benchModel(JNIEnv *env, jobject /*unused*/,
                 common_batch_add(g_batch, 0, i, {j}, true);
             }
 
-            if (llama_decode(g_context, g_batch) != 0) {
+            if (llama_decode(context, g_batch) != 0) {
                 LOGe("llama_decode() failed during text generation");
             }
         }
         const auto t_tg_end = ggml_time_us();
 
-        llama_memory_clear(llama_get_memory(g_context), false);
+        llama_memory_clear(llama_get_memory(context), false);
 
         const auto t_pp = double(t_pp_end - t_pp_start) / 1000000.0;
         const auto t_tg = double(t_tg_end - t_tg_start) / 1000000.0;
@@ -234,6 +240,8 @@ Java_android_llama_cpp_LLamaAndroid_benchModel(JNIEnv *env, jobject /*unused*/,
         LOGi("pp %f t/s, tg %f t/s", speed_pp, speed_tg);
     }
 
+    llama_free(context);
+
     pp_avg /= double(nr);
     tg_avg /= double(nr);
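
Note on the new init_context() signature: because the second parameter defaults to DEFAULT_CONTEXT_SIZE, existing call sites keep compiling unchanged, while the benchmark path can size the context to the workload. A minimal sketch of both call styles, assuming the file's existing g_model and DEFAULT_CONTEXT_SIZE definitions (the value 2048 is an arbitrary example, not from the patch):

    // Existing callers: unchanged, context sized to DEFAULT_CONTEXT_SIZE.
    llama_context * ctx_chat  = init_context(g_model);

    // Benchmark caller: context sized to the prompt-processing length.
    // If the requested size exceeds llama_model_n_ctx_train(model),
    // init_context() now logs a LOGw warning but still applies the size.
    llama_context * ctx_bench = init_context(g_model, 2048);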
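
A design note on cleanup: the benchmark now owns its local llama_context and releases it with the single llama_free(context) call at the end. If more early returns are ever added between init and free, a scoped guard would keep every exit path leak-free. A minimal sketch, assuming only <memory> and the existing llama_free(); llama_ctx_deleter and llama_ctx_ptr are hypothetical names, not part of llama.cpp's API:

    #include <memory>

    // Calls llama_free() automatically when the pointer leaves scope.
    struct llama_ctx_deleter {
        void operator()(llama_context * ctx) const { llama_free(ctx); }
    };
    using llama_ctx_ptr = std::unique_ptr<llama_context, llama_ctx_deleter>;

    // Hypothetical usage inside benchModel():
    //   llama_ctx_ptr context(init_context(g_model, pp));
    //   if (!context) { return env->NewStringUTF("..."); }  // nothing leaks
    //   ... run the pp/tg loops with context.get() ...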