Feature: use local llama_context for benchmarking; support context init with custom context size

Han Yin 2025-04-11 10:21:08 -07:00
parent 6d2279e9cd
commit 37f3e1c415
1 changed file with 20 additions and 12 deletions


@@ -98,7 +98,7 @@ Java_android_llama_cpp_LLamaAndroid_load(JNIEnv *env, jobject, jstring jmodel_pa
     return 0;
 }
 
-static llama_context *init_context(llama_model *model) {
+static llama_context *init_context(llama_model *model, const int n_ctx = DEFAULT_CONTEXT_SIZE) {
     if (!model) {
         LOGe("%s: model cannot be null", __func__);
         return nullptr;
@@ -113,11 +113,11 @@ static llama_context *init_context(llama_model *model) {
     // Context parameters setup
     llama_context_params ctx_params = llama_context_default_params();
     const int trained_context_size = llama_model_n_ctx_train(model);
-    if (DEFAULT_CONTEXT_SIZE > trained_context_size) {
-        LOGe("%s: Model was trained with only %d context size! Enforcing %d context size...",
-             __func__, trained_context_size, DEFAULT_CONTEXT_SIZE);
+    if (n_ctx > trained_context_size) {
+        LOGw("%s: model was trained with a context size of only %d; using the requested size %d anyway",
+             __func__, trained_context_size, n_ctx);
     }
-    ctx_params.n_ctx = DEFAULT_CONTEXT_SIZE;
+    ctx_params.n_ctx = n_ctx;
     ctx_params.n_batch = BATCH_SIZE;
     ctx_params.n_ubatch = BATCH_SIZE;
     ctx_params.n_threads = n_threads;
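With the default argument, existing call sites keep compiling unchanged while the benchmark can request a specific size. A minimal sketch of the two call patterns, plus a hypothetical clamping helper (an assumption, not part of this commit) for callers that would rather cap the request at the trained limit than only warn:

    // Existing call site: picks up DEFAULT_CONTEXT_SIZE implicitly.
    llama_context *ctx = init_context(model);

    // New call site: request an explicit context size (here, a prompt length).
    llama_context *bench_ctx = init_context(model, 512);

    // Hypothetical helper (not in this commit): clamp the request to the
    // model's trained context size instead of merely logging a warning.
    static int clamp_n_ctx(llama_model *model, int requested) {
        const int trained = llama_model_n_ctx_train(model);
        return requested > trained ? trained : requested;
    }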
@@ -169,13 +169,19 @@ extern "C"
 JNIEXPORT jstring JNICALL
 Java_android_llama_cpp_LLamaAndroid_benchModel(JNIEnv *env, jobject /*unused*/, jint pp, jint tg,
                                                jint pl, jint nr) {
+    auto *context = init_context(g_model, pp);
+    if (!context) {
+        const auto err_msg = "Failed to init_context! Bench aborted.";
+        LOGe("%s", err_msg);
+        return env->NewStringUTF(err_msg);
+    }
     auto pp_avg = 0.0;
     auto tg_avg = 0.0;
 
     auto pp_std = 0.0;
     auto tg_std = 0.0;
 
-    const uint32_t n_ctx = llama_n_ctx(g_context);
+    const uint32_t n_ctx = llama_n_ctx(context);
     LOGi("n_ctx = %u", n_ctx);
 
     int i, j;
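Because the context is now function-local, every return path out of benchModel must release it; the commit does this with a single llama_free before the final return. A sketch of an equivalent scope-guard formulation, assuming only <memory> and llama_free from the C API (an alternative, not what the commit does):

    #include <memory>

    // unique_ptr with llama_free as the deleter frees the context on every
    // return path, including the early error return above.
    using llama_context_ptr = std::unique_ptr<llama_context, decltype(&llama_free)>;

    llama_context_ptr context(init_context(g_model, pp), &llama_free);
    if (!context) {
        return env->NewStringUTF("Failed to init_context! Bench aborted.");
    }

The llama_* calls below would then take context.get() instead of the raw pointer.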
@@ -191,10 +197,10 @@ Java_android_llama_cpp_LLamaAndroid_benchModel(JNIEnv *env, jobject /*unused*/,
     }
     g_batch.logits[g_batch.n_tokens - 1] = true;
 
-    llama_memory_clear(llama_get_memory(g_context), false);
+    llama_memory_clear(llama_get_memory(context), false);
 
     const auto t_pp_start = ggml_time_us();
-    if (llama_decode(g_context, g_batch) != 0) {
+    if (llama_decode(context, g_batch) != 0) {
         LOGe("llama_decode() failed during prompt processing");
     }
     const auto t_pp_end = ggml_time_us();
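The loop that fills g_batch with pp dummy tokens sits just above this hunk and is unchanged, so the diff elides it. For reference, a sketch of what that fill presumably looks like, following the usual llama.cpp bench pattern (an assumption, since those lines are not shown):

    common_batch_clear(g_batch);
    for (i = 0; i < pp; i++) {
        // One dummy token (id 0) per position; per-token logits stay off,
        // and the hunk above then enables logits for the final position only.
        common_batch_add(g_batch, 0, i, {0}, false);
    }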
@@ -203,7 +209,7 @@ Java_android_llama_cpp_LLamaAndroid_benchModel(JNIEnv *env, jobject /*unused*/,
     LOGi("Benchmark text generation (tg)");
 
-    llama_memory_clear(llama_get_memory(g_context), false);
+    llama_memory_clear(llama_get_memory(context), false);
 
     const auto t_tg_start = ggml_time_us();
     for (i = 0; i < tg; i++) {
         common_batch_clear(g_batch);
@@ -211,13 +217,13 @@ Java_android_llama_cpp_LLamaAndroid_benchModel(JNIEnv *env, jobject /*unused*/,
             common_batch_add(g_batch, 0, i, {j}, true);
         }
 
-        if (llama_decode(g_context, g_batch) != 0) {
+        if (llama_decode(context, g_batch) != 0) {
             LOGe("llama_decode() failed during text generation");
         }
     }
     const auto t_tg_end = ggml_time_us();
 
-    llama_memory_clear(llama_get_memory(g_context), false);
+    llama_memory_clear(llama_get_memory(context), false);
 
     const auto t_pp = double(t_pp_end - t_pp_start) / 1000000.0;
     const auto t_tg = double(t_tg_end - t_tg_start) / 1000000.0;
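The per-run speeds and running sums that feed the averages below fall outside the changed lines. A sketch of that arithmetic, consistent with the pp/tg t/s log line in the next hunk (names assumed from context):

    const auto speed_pp = double(pp) / t_pp;      // prompt tokens per second
    const auto speed_tg = double(pl * tg) / t_tg; // generated tokens per second

    pp_avg += speed_pp;
    tg_avg += speed_tg;
    pp_std += speed_pp * speed_pp; // sums of squares, finished into a
    tg_std += speed_tg * speed_tg; // standard deviation after the loop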
@@ -234,6 +240,8 @@ Java_android_llama_cpp_LLamaAndroid_benchModel(JNIEnv *env, jobject /*unused*/,
         LOGi("pp %f t/s, tg %f t/s", speed_pp, speed_tg);
     }
 
+    llama_free(context);
+
     pp_avg /= double(nr);
     tg_avg /= double(nr);
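After the division by nr turns the running sums into means, the accumulated sums of squares presumably get finished into a sample standard deviation in the lines that follow (elided from the diff). A sketch of that step, guarding the nr == 1 case (an assumption about the elided code; needs <cmath>):

    // Sample std dev from sums of squares: var = sum(x^2)/(n-1) - mean^2 * n/(n-1)
    pp_std = nr > 1 ? sqrt(pp_std / double(nr - 1) - pp_avg * pp_avg * double(nr) / double(nr - 1)) : 0.0;
    tg_std = nr > 1 ? sqrt(tg_std / double(nr - 1) - tg_avg * tg_avg * double(nr) / double(nr - 1)) : 0.0;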