diff --git a/examples/llama.android/llama/src/main/cpp/llama-android.cpp b/examples/llama.android/llama/src/main/cpp/llama-android.cpp
index 366d721fb7..3713cff4fd 100644
--- a/examples/llama.android/llama/src/main/cpp/llama-android.cpp
+++ b/examples/llama.android/llama/src/main/cpp/llama-android.cpp
@@ -29,10 +29,10 @@ constexpr int CONTEXT_SIZE = 4096;
 constexpr int BATCH_SIZE = 512;
 constexpr float SAMPLER_TEMP = 0.3f;
 
-llama_model * model;
-llama_context * context;
-llama_batch * batch;
-common_sampler * sampler;
+static llama_model * g_model;
+static llama_context * g_context;
+static llama_batch * g_batch;
+static common_sampler * g_sampler;
 
 static void log_callback(ggml_log_level level, const char *fmt, void *data) {
     int priority;
@@ -86,17 +86,17 @@ Java_android_llama_cpp_LLamaAndroid_loadModel(JNIEnv *env, jobject, jstring file
     const auto *path_to_model = env->GetStringUTFChars(filename, 0);
     LOGd("Loading model from: %s", path_to_model);
-    model = llama_model_load_from_file(path_to_model, model_params);
+    auto *model = llama_model_load_from_file(path_to_model, model_params);
     env->ReleaseStringUTFChars(filename, path_to_model);
-
     if (!model) {
         LOGe("load_model() failed");
         return -1;
     }
+    g_model = model;
 
     return 0;
 }
 
-int init_context() {
+static int init_context(llama_model *model) {
     if (!model) {
         LOGe("init_context(): model cannot be null");
         return -1;
@@ -113,18 +113,18 @@
     ctx_params.n_ctx = CONTEXT_SIZE;
     ctx_params.n_threads = n_threads;
     ctx_params.n_threads_batch = n_threads;
-
-    context = llama_init_from_model(model, ctx_params);
+    auto *context = llama_init_from_model(g_model, ctx_params);
     if (!context) {
         LOGe("llama_new_context_with_model() returned null)");
         return -2;
     }
+    g_context = context;
     return 0;
 }
 
-void new_batch(int n_tokens, bool embd = false, int n_seq_max = 1) {
+static void new_batch(int n_tokens, bool embd = false, int n_seq_max = 1) {
     // Source: Copy of llama.cpp:llama_batch_init but heap-allocated.
-    batch = new llama_batch{
+    auto *batch = new llama_batch{
         0,
         nullptr,
         nullptr,
@@ -147,18 +147,19 @@ void new_batch(int n_tokens, bool embd = false, int n_seq_max = 1) {
         batch->seq_id[i] = (llama_seq_id *) malloc(sizeof(llama_seq_id) * n_seq_max);
     }
     batch->logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens);
+    g_batch = batch;
 }
 
 void new_sampler(float temp) {
     common_params_sampling sparams;
     sparams.temp = temp;
-    sampler = common_sampler_init(model, sparams);
+    g_sampler = common_sampler_init(g_model, sparams);
 }
 
 extern "C"
 JNIEXPORT jint JNICALL
 Java_android_llama_cpp_LLamaAndroid_initContext(JNIEnv * /*env*/, jobject /*unused*/) {
-    int ret = init_context();
+    int ret = init_context(g_model);
     if (ret != 0) { return ret; }
     new_batch(BATCH_SIZE);
     new_sampler(SAMPLER_TEMP);
@@ -168,11 +169,11 @@ Java_android_llama_cpp_LLamaAndroid_initContext(JNIEnv * /*env*/, jobject /*unus
 
 extern "C"
 JNIEXPORT void JNICALL
 Java_android_llama_cpp_LLamaAndroid_cleanUp(JNIEnv * /*unused*/, jobject /*unused*/) {
-    llama_model_free(model);
-    llama_free(context);
+    llama_model_free(g_model);
+    llama_free(g_context);
     llama_backend_free();
-    delete batch;
-    common_sampler_free(sampler);
+    delete g_batch;
+    common_sampler_free(g_sampler);
 }
 
@@ -183,7 +184,7 @@ Java_android_llama_cpp_LLamaAndroid_benchModel(JNIEnv *env, jobject /*unused*/,
     auto pp_std = 0.0;
     auto tg_std = 0.0;
 
-    const uint32_t n_ctx = llama_n_ctx(context);
+    const uint32_t n_ctx = llama_n_ctx(g_context);
 
     LOGi("n_ctx = %d", n_ctx);
 
@@ -192,18 +193,18 @@
     for (nri = 0; nri < nr; nri++) {
         LOGi("Benchmark prompt processing (pp)");
 
-        common_batch_clear(*batch);
+        common_batch_clear(*g_batch);
 
         const int n_tokens = pp;
         for (i = 0; i < n_tokens; i++) {
-            common_batch_add(*batch, 0, i, {0}, false);
+            common_batch_add(*g_batch, 0, i, {0}, false);
         }
 
-        batch->logits[batch->n_tokens - 1] = true;
-        llama_memory_clear(llama_get_memory(context), false);
+        g_batch->logits[g_batch->n_tokens - 1] = true;
+        llama_memory_clear(llama_get_memory(g_context), false);
 
         const auto t_pp_start = ggml_time_us();
-        if (llama_decode(context, *batch) != 0) {
+        if (llama_decode(g_context, *g_batch) != 0) {
             LOGw("llama_decode() failed during prompt processing");
         }
         const auto t_pp_end = ggml_time_us();
@@ -212,24 +213,24 @@
 
         LOGi("Benchmark text generation (tg)");
 
-        llama_memory_clear(llama_get_memory(context), false);
+        llama_memory_clear(llama_get_memory(g_context), false);
 
         const auto t_tg_start = ggml_time_us();
         for (i = 0; i < tg; i++) {
-            common_batch_clear(*batch);
+            common_batch_clear(*g_batch);
             for (j = 0; j < pl; j++) {
-                common_batch_add(*batch, 0, i, {j}, true);
+                common_batch_add(*g_batch, 0, i, {j}, true);
             }
 
             LOGi("llama_decode() text generation: %d", i);
-            if (llama_decode(context, *batch) != 0) {
+            if (llama_decode(g_context, *g_batch) != 0) {
                 LOGw("llama_decode() failed during text generation");
             }
         }
 
         const auto t_tg_end = ggml_time_us();
 
-        llama_memory_clear(llama_get_memory(context), false);
+        llama_memory_clear(llama_get_memory(g_context), false);
 
         const auto t_pp = double(t_pp_end - t_pp_start) / 1000000.0;
         const auto t_tg = double(t_tg_end - t_tg_start) / 1000000.0;
@@ -258,10 +259,10 @@ Java_android_llama_cpp_LLamaAndroid_benchModel(JNIEnv *env, jobject /*unused*/,
     }
 
     char model_desc[128];
-    llama_model_desc(model, model_desc, sizeof(model_desc));
+    llama_model_desc(g_model, model_desc, sizeof(model_desc));
 
-    const auto model_size = double(llama_model_size(model)) / 1024.0 / 1024.0 / 1024.0;
-    const auto model_n_params = double(llama_model_n_params(model)) / 1e9;
+    const auto model_size = double(llama_model_size(g_model)) / 1024.0 / 1024.0 / 1024.0;
+    const auto model_n_params = double(llama_model_n_params(g_model)) / 1e9;
 
     const auto *const backend = "(Android)"; // TODO: What should this be?
 
@@ -279,9 +280,10 @@ Java_android_llama_cpp_LLamaAndroid_benchModel(JNIEnv *env, jobject /*unused*/,
 
 /**
- * Prediction loop's states
+ * Prediction loop's long-term and short-term states
  */
-int current_position;
+static int current_position;
+
+static int token_predict_budget;
+static std::string cached_token_chars;
 
-int token_predict_budget;
-std::string cached_token_chars;
 
@@ -295,7 +299,7 @@ Java_android_llama_cpp_LLamaAndroid_processSystemPrompt(
 ) {
     // Reset long-term states and reset KV cache
     current_position = 0;
-    llama_memory_clear(llama_get_memory(context), false);
+    llama_memory_clear(llama_get_memory(g_context), false);
 
     // Reset short-term states
     token_predict_budget = 0;
@@ -304,23 +308,23 @@
     // Obtain and tokenize system prompt
     const auto *const system_text = env->GetStringUTFChars(jsystem_prompt, nullptr);
     LOGd("System prompt received: \n%s", system_text);
-    const auto system_tokens = common_tokenize(context, system_text, true, true);
+    const auto system_tokens = common_tokenize(g_context, system_text, true, true);
     env->ReleaseStringUTFChars(jsystem_prompt, system_text);
 
     // Print each token in verbose mode
     for (auto id : system_tokens) {
-        LOGv("token: `%s`\t -> `%d`", common_token_to_piece(context, id).c_str(), id);
+        LOGv("token: `%s`\t -> `%d`", common_token_to_piece(g_context, id).c_str(), id);
     }
 
     // Add system prompt tokens to batch
-    common_batch_clear(*batch);
+    common_batch_clear(*g_batch);
     // TODO-hyin: support batch processing!
     for (int i = 0; i < system_tokens.size(); i++) {
-        common_batch_add(*batch, system_tokens[i], i, {0}, false);
+        common_batch_add(*g_batch, system_tokens[i], i, {0}, false);
     }
 
     // Decode batch
-    int decode_result = llama_decode(context, *batch);
+    int decode_result = llama_decode(g_context, *g_batch);
     if (decode_result != 0) {
         LOGe("llama_decode() failed: %d", decode_result);
         return -1;
@@ -347,31 +351,31 @@ Java_android_llama_cpp_LLamaAndroid_processUserPrompt(
     // Obtain and tokenize user prompt
     const auto *const user_text = env->GetStringUTFChars(juser_prompt, nullptr);
     LOGd("User prompt received: \n%s", user_text);
-    const auto user_tokens = common_tokenize(context, user_text, true, true);
+    const auto user_tokens = common_tokenize(g_context, user_text, true, true);
     env->ReleaseStringUTFChars(juser_prompt, user_text);
 
     // Print each token in verbose mode
     for (auto id : user_tokens) {
-        LOGv("token: `%s`\t -> `%d`", common_token_to_piece(context, id).c_str(), id);
+        LOGv("token: `%s`\t -> `%d`", common_token_to_piece(g_context, id).c_str(), id);
     }
 
     // Check if context space is enough for desired tokens
     int desired_budget = current_position + user_tokens.size() + n_predict;
-    if (desired_budget > llama_n_ctx(context)) {
+    if (desired_budget > llama_n_ctx(g_context)) {
         LOGe("error: total tokens exceed context size");
         return -1;
     }
     token_predict_budget = desired_budget;
 
     // Add user prompt tokens to batch
-    common_batch_clear(*batch);
+    common_batch_clear(*g_batch);
     for (int i = 0; i < user_tokens.size(); i++) {
-        common_batch_add(*batch, user_tokens[i], current_position + i, {0}, false);
+        common_batch_add(*g_batch, user_tokens[i], current_position + i, {0}, false);
     }
-    batch->logits[batch->n_tokens - 1] = true; // Set logits true only for last token
+    g_batch->logits[g_batch->n_tokens - 1] = true; // Set logits true only for last token
 
     // Decode batch
-    int decode_result = llama_decode(context, *batch);
+    int decode_result = llama_decode(g_context, *g_batch);
     if (decode_result != 0) {
         LOGe("llama_decode() failed: %d", decode_result);
         return -2;
@@ -382,7 +386,7 @@ Java_android_llama_cpp_LLamaAndroid_processUserPrompt(
     return 0;
 }
 
-bool is_valid_utf8(const char *string) {
+static bool is_valid_utf8(const char *string) {
     if (!string) { return true; }
 
     const auto *bytes = (const unsigned char *) string;
@@ -429,25 +433,25 @@ Java_android_llama_cpp_LLamaAndroid_predictLoop(
     }
 
     // Sample next token
-    const auto new_token_id = common_sampler_sample(sampler, context, -1);
-    common_sampler_accept(sampler, new_token_id, true);
+    const auto new_token_id = common_sampler_sample(g_sampler, g_context, -1);
+    common_sampler_accept(g_sampler, new_token_id, true);
 
     // Stop if next token is EOG
-    if (llama_vocab_is_eog(llama_model_get_vocab(model), new_token_id)) {
+    if (llama_vocab_is_eog(llama_model_get_vocab(g_model), new_token_id)) {
         LOGd("id: %d,\tIS EOG!\nSTOP.", new_token_id);
         return nullptr;
     }
 
     // Update the context with the new token
-    common_batch_clear(*batch);
-    common_batch_add(*batch, new_token_id, current_position, {0}, true);
-    if (llama_decode(context, *batch) != 0) {
+    common_batch_clear(*g_batch);
+    common_batch_add(*g_batch, new_token_id, current_position, {0}, true);
+    if (llama_decode(g_context, *g_batch) != 0) {
         LOGe("llama_decode() failed for generated token");
         return nullptr;
     }
 
     // Convert to text
-    auto new_token_chars = common_token_to_piece(context, new_token_id);
+    auto new_token_chars = common_token_to_piece(g_context, new_token_id);
     cached_token_chars += new_token_chars;
 
     // Create Java string