Clang-tidy linting: make functions & global variables static
This commit is contained in:
parent
f44882aeeb
commit
7bbb53aaf8
|
|
@ -29,10 +29,10 @@ constexpr int CONTEXT_SIZE = 4096;
|
||||||
constexpr int BATCH_SIZE = 512;
|
constexpr int BATCH_SIZE = 512;
|
||||||
constexpr float SAMPLER_TEMP = 0.3f;
|
constexpr float SAMPLER_TEMP = 0.3f;
|
||||||
|
|
||||||
llama_model * model;
|
static llama_model * g_model;
|
||||||
llama_context * context;
|
static llama_context * g_context;
|
||||||
llama_batch * batch;
|
static llama_batch * g_batch;
|
||||||
common_sampler * sampler;
|
static common_sampler * g_sampler;
|
||||||
|
|
||||||
static void log_callback(ggml_log_level level, const char *fmt, void *data) {
|
static void log_callback(ggml_log_level level, const char *fmt, void *data) {
|
||||||
int priority;
|
int priority;
|
||||||
|
|
@ -86,17 +86,17 @@ Java_android_llama_cpp_LLamaAndroid_loadModel(JNIEnv *env, jobject, jstring file
|
||||||
const auto *path_to_model = env->GetStringUTFChars(filename, 0);
|
const auto *path_to_model = env->GetStringUTFChars(filename, 0);
|
||||||
LOGd("Loading model from: %s", path_to_model);
|
LOGd("Loading model from: %s", path_to_model);
|
||||||
|
|
||||||
model = llama_model_load_from_file(path_to_model, model_params);
|
auto *model = llama_model_load_from_file(path_to_model, model_params);
|
||||||
env->ReleaseStringUTFChars(filename, path_to_model);
|
env->ReleaseStringUTFChars(filename, path_to_model);
|
||||||
|
|
||||||
if (!model) {
|
if (!model) {
|
||||||
LOGe("load_model() failed");
|
LOGe("load_model() failed");
|
||||||
return -1;
|
return -1;
|
||||||
}
|
}
|
||||||
|
g_model = model;
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
int init_context() {
|
static int init_context(llama_model *model) {
|
||||||
if (!model) {
|
if (!model) {
|
||||||
LOGe("init_context(): model cannot be null");
|
LOGe("init_context(): model cannot be null");
|
||||||
return -1;
|
return -1;
|
||||||
|
|
@ -113,18 +113,18 @@ int init_context() {
|
||||||
ctx_params.n_ctx = CONTEXT_SIZE;
|
ctx_params.n_ctx = CONTEXT_SIZE;
|
||||||
ctx_params.n_threads = n_threads;
|
ctx_params.n_threads = n_threads;
|
||||||
ctx_params.n_threads_batch = n_threads;
|
ctx_params.n_threads_batch = n_threads;
|
||||||
|
auto *context = llama_init_from_model(g_model, ctx_params);
|
||||||
context = llama_init_from_model(model, ctx_params);
|
|
||||||
if (!context) {
|
if (!context) {
|
||||||
LOGe("llama_new_context_with_model() returned null)");
|
LOGe("llama_new_context_with_model() returned null)");
|
||||||
return -2;
|
return -2;
|
||||||
}
|
}
|
||||||
|
g_context = context;
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
void new_batch(int n_tokens, bool embd = false, int n_seq_max = 1) {
|
static void new_batch(int n_tokens, bool embd = false, int n_seq_max = 1) {
|
||||||
// Source: Copy of llama.cpp:llama_batch_init but heap-allocated.
|
// Source: Copy of llama.cpp:llama_batch_init but heap-allocated.
|
||||||
batch = new llama_batch{
|
auto *batch = new llama_batch{
|
||||||
0,
|
0,
|
||||||
nullptr,
|
nullptr,
|
||||||
nullptr,
|
nullptr,
|
||||||
|
|
@ -147,18 +147,19 @@ void new_batch(int n_tokens, bool embd = false, int n_seq_max = 1) {
|
||||||
batch->seq_id[i] = (llama_seq_id *) malloc(sizeof(llama_seq_id) * n_seq_max);
|
batch->seq_id[i] = (llama_seq_id *) malloc(sizeof(llama_seq_id) * n_seq_max);
|
||||||
}
|
}
|
||||||
batch->logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens);
|
batch->logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens);
|
||||||
|
g_batch = batch;
|
||||||
}
|
}
|
||||||
|
|
||||||
void new_sampler(float temp) {
|
void new_sampler(float temp) {
|
||||||
common_params_sampling sparams;
|
common_params_sampling sparams;
|
||||||
sparams.temp = temp;
|
sparams.temp = temp;
|
||||||
sampler = common_sampler_init(model, sparams);
|
g_sampler = common_sampler_init(g_model, sparams);
|
||||||
}
|
}
|
||||||
|
|
||||||
extern "C"
|
extern "C"
|
||||||
JNIEXPORT jint JNICALL
|
JNIEXPORT jint JNICALL
|
||||||
Java_android_llama_cpp_LLamaAndroid_initContext(JNIEnv * /*env*/, jobject /*unused*/) {
|
Java_android_llama_cpp_LLamaAndroid_initContext(JNIEnv * /*env*/, jobject /*unused*/) {
|
||||||
int ret = init_context();
|
int ret = init_context(g_model);
|
||||||
if (ret != 0) { return ret; }
|
if (ret != 0) { return ret; }
|
||||||
new_batch(BATCH_SIZE);
|
new_batch(BATCH_SIZE);
|
||||||
new_sampler(SAMPLER_TEMP);
|
new_sampler(SAMPLER_TEMP);
|
||||||
|
|
@ -168,11 +169,11 @@ Java_android_llama_cpp_LLamaAndroid_initContext(JNIEnv * /*env*/, jobject /*unus
|
||||||
extern "C"
|
extern "C"
|
||||||
JNIEXPORT void JNICALL
|
JNIEXPORT void JNICALL
|
||||||
Java_android_llama_cpp_LLamaAndroid_cleanUp(JNIEnv * /*unused*/, jobject /*unused*/) {
|
Java_android_llama_cpp_LLamaAndroid_cleanUp(JNIEnv * /*unused*/, jobject /*unused*/) {
|
||||||
llama_model_free(model);
|
llama_model_free(g_model);
|
||||||
llama_free(context);
|
llama_free(g_context);
|
||||||
llama_backend_free();
|
llama_backend_free();
|
||||||
delete batch;
|
delete g_batch;
|
||||||
common_sampler_free(sampler);
|
common_sampler_free(g_sampler);
|
||||||
}
|
}
|
||||||
|
|
||||||
extern "C"
|
extern "C"
|
||||||
|
|
@ -183,7 +184,7 @@ Java_android_llama_cpp_LLamaAndroid_benchModel(JNIEnv *env, jobject /*unused*/,
|
||||||
auto pp_std = 0.0;
|
auto pp_std = 0.0;
|
||||||
auto tg_std = 0.0;
|
auto tg_std = 0.0;
|
||||||
|
|
||||||
const uint32_t n_ctx = llama_n_ctx(context);
|
const uint32_t n_ctx = llama_n_ctx(g_context);
|
||||||
|
|
||||||
LOGi("n_ctx = %d", n_ctx);
|
LOGi("n_ctx = %d", n_ctx);
|
||||||
|
|
||||||
|
|
@ -192,18 +193,18 @@ Java_android_llama_cpp_LLamaAndroid_benchModel(JNIEnv *env, jobject /*unused*/,
|
||||||
for (nri = 0; nri < nr; nri++) {
|
for (nri = 0; nri < nr; nri++) {
|
||||||
LOGi("Benchmark prompt processing (pp)");
|
LOGi("Benchmark prompt processing (pp)");
|
||||||
|
|
||||||
common_batch_clear(*batch);
|
common_batch_clear(*g_batch);
|
||||||
|
|
||||||
const int n_tokens = pp;
|
const int n_tokens = pp;
|
||||||
for (i = 0; i < n_tokens; i++) {
|
for (i = 0; i < n_tokens; i++) {
|
||||||
common_batch_add(*batch, 0, i, {0}, false);
|
common_batch_add(*g_batch, 0, i, {0}, false);
|
||||||
}
|
}
|
||||||
|
|
||||||
batch->logits[batch->n_tokens - 1] = true;
|
g_batch->logits[g_batch->n_tokens - 1] = true;
|
||||||
llama_memory_clear(llama_get_memory(context), false);
|
llama_memory_clear(llama_get_memory(g_context), false);
|
||||||
|
|
||||||
const auto t_pp_start = ggml_time_us();
|
const auto t_pp_start = ggml_time_us();
|
||||||
if (llama_decode(context, *batch) != 0) {
|
if (llama_decode(g_context, *g_batch) != 0) {
|
||||||
LOGw("llama_decode() failed during prompt processing");
|
LOGw("llama_decode() failed during prompt processing");
|
||||||
}
|
}
|
||||||
const auto t_pp_end = ggml_time_us();
|
const auto t_pp_end = ggml_time_us();
|
||||||
|
|
@ -212,24 +213,24 @@ Java_android_llama_cpp_LLamaAndroid_benchModel(JNIEnv *env, jobject /*unused*/,
|
||||||
|
|
||||||
LOGi("Benchmark text generation (tg)");
|
LOGi("Benchmark text generation (tg)");
|
||||||
|
|
||||||
llama_memory_clear(llama_get_memory(context), false);
|
llama_memory_clear(llama_get_memory(g_context), false);
|
||||||
const auto t_tg_start = ggml_time_us();
|
const auto t_tg_start = ggml_time_us();
|
||||||
for (i = 0; i < tg; i++) {
|
for (i = 0; i < tg; i++) {
|
||||||
|
|
||||||
common_batch_clear(*batch);
|
common_batch_clear(*g_batch);
|
||||||
for (j = 0; j < pl; j++) {
|
for (j = 0; j < pl; j++) {
|
||||||
common_batch_add(*batch, 0, i, {j}, true);
|
common_batch_add(*g_batch, 0, i, {j}, true);
|
||||||
}
|
}
|
||||||
|
|
||||||
LOGi("llama_decode() text generation: %d", i);
|
LOGi("llama_decode() text generation: %d", i);
|
||||||
if (llama_decode(context, *batch) != 0) {
|
if (llama_decode(g_context, *g_batch) != 0) {
|
||||||
LOGw("llama_decode() failed during text generation");
|
LOGw("llama_decode() failed during text generation");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const auto t_tg_end = ggml_time_us();
|
const auto t_tg_end = ggml_time_us();
|
||||||
|
|
||||||
llama_memory_clear(llama_get_memory(context), false);
|
llama_memory_clear(llama_get_memory(g_context), false);
|
||||||
|
|
||||||
const auto t_pp = double(t_pp_end - t_pp_start) / 1000000.0;
|
const auto t_pp = double(t_pp_end - t_pp_start) / 1000000.0;
|
||||||
const auto t_tg = double(t_tg_end - t_tg_start) / 1000000.0;
|
const auto t_tg = double(t_tg_end - t_tg_start) / 1000000.0;
|
||||||
|
|
@ -258,10 +259,10 @@ Java_android_llama_cpp_LLamaAndroid_benchModel(JNIEnv *env, jobject /*unused*/,
|
||||||
}
|
}
|
||||||
|
|
||||||
char model_desc[128];
|
char model_desc[128];
|
||||||
llama_model_desc(model, model_desc, sizeof(model_desc));
|
llama_model_desc(g_model, model_desc, sizeof(model_desc));
|
||||||
|
|
||||||
const auto model_size = double(llama_model_size(model)) / 1024.0 / 1024.0 / 1024.0;
|
const auto model_size = double(llama_model_size(g_model)) / 1024.0 / 1024.0 / 1024.0;
|
||||||
const auto model_n_params = double(llama_model_n_params(model)) / 1e9;
|
const auto model_n_params = double(llama_model_n_params(g_model)) / 1e9;
|
||||||
|
|
||||||
const auto *const backend = "(Android)"; // TODO: What should this be?
|
const auto *const backend = "(Android)"; // TODO: What should this be?
|
||||||
|
|
||||||
|
|
@ -279,9 +280,12 @@ Java_android_llama_cpp_LLamaAndroid_benchModel(JNIEnv *env, jobject /*unused*/,
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Prediction loop's states
|
* Prediction loop's long-term and short-term states
|
||||||
*/
|
*/
|
||||||
int current_position;
|
static int current_position;
|
||||||
|
|
||||||
|
static int token_predict_budget;
|
||||||
|
static std::string cached_token_chars;
|
||||||
|
|
||||||
int token_predict_budget;
|
int token_predict_budget;
|
||||||
std::string cached_token_chars;
|
std::string cached_token_chars;
|
||||||
|
|
@ -295,7 +299,7 @@ Java_android_llama_cpp_LLamaAndroid_processSystemPrompt(
|
||||||
) {
|
) {
|
||||||
// Reset long-term states and reset KV cache
|
// Reset long-term states and reset KV cache
|
||||||
current_position = 0;
|
current_position = 0;
|
||||||
llama_memory_clear(llama_get_memory(context), false);
|
llama_memory_clear(llama_get_memory(g_context), false);
|
||||||
|
|
||||||
// Reset short-term states
|
// Reset short-term states
|
||||||
token_predict_budget = 0;
|
token_predict_budget = 0;
|
||||||
|
|
@ -304,23 +308,23 @@ Java_android_llama_cpp_LLamaAndroid_processSystemPrompt(
|
||||||
// Obtain and tokenize system prompt
|
// Obtain and tokenize system prompt
|
||||||
const auto *const system_text = env->GetStringUTFChars(jsystem_prompt, nullptr);
|
const auto *const system_text = env->GetStringUTFChars(jsystem_prompt, nullptr);
|
||||||
LOGd("System prompt received: \n%s", system_text);
|
LOGd("System prompt received: \n%s", system_text);
|
||||||
const auto system_tokens = common_tokenize(context, system_text, true, true);
|
const auto system_tokens = common_tokenize(g_context, system_text, true, true);
|
||||||
env->ReleaseStringUTFChars(jsystem_prompt, system_text);
|
env->ReleaseStringUTFChars(jsystem_prompt, system_text);
|
||||||
|
|
||||||
// Print each token in verbose mode
|
// Print each token in verbose mode
|
||||||
for (auto id : system_tokens) {
|
for (auto id : system_tokens) {
|
||||||
LOGv("token: `%s`\t -> `%d`", common_token_to_piece(context, id).c_str(), id);
|
LOGv("token: `%s`\t -> `%d`", common_token_to_piece(g_context, id).c_str(), id);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add system prompt tokens to batch
|
// Add system prompt tokens to batch
|
||||||
common_batch_clear(*batch);
|
common_batch_clear(*g_batch);
|
||||||
// TODO-hyin: support batch processing!
|
// TODO-hyin: support batch processing!
|
||||||
for (int i = 0; i < system_tokens.size(); i++) {
|
for (int i = 0; i < system_tokens.size(); i++) {
|
||||||
common_batch_add(*batch, system_tokens[i], i, {0}, false);
|
common_batch_add(*g_batch, system_tokens[i], i, {0}, false);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Decode batch
|
// Decode batch
|
||||||
int decode_result = llama_decode(context, *batch);
|
int decode_result = llama_decode(g_context, *g_batch);
|
||||||
if (decode_result != 0) {
|
if (decode_result != 0) {
|
||||||
LOGe("llama_decode() failed: %d", decode_result);
|
LOGe("llama_decode() failed: %d", decode_result);
|
||||||
return -1;
|
return -1;
|
||||||
|
|
@ -347,31 +351,31 @@ Java_android_llama_cpp_LLamaAndroid_processUserPrompt(
|
||||||
// Obtain and tokenize user prompt
|
// Obtain and tokenize user prompt
|
||||||
const auto *const user_text = env->GetStringUTFChars(juser_prompt, nullptr);
|
const auto *const user_text = env->GetStringUTFChars(juser_prompt, nullptr);
|
||||||
LOGd("User prompt received: \n%s", user_text);
|
LOGd("User prompt received: \n%s", user_text);
|
||||||
const auto user_tokens = common_tokenize(context, user_text, true, true);
|
const auto user_tokens = common_tokenize(g_context, user_text, true, true);
|
||||||
env->ReleaseStringUTFChars(juser_prompt, user_text);
|
env->ReleaseStringUTFChars(juser_prompt, user_text);
|
||||||
|
|
||||||
// Print each token in verbose mode
|
// Print each token in verbose mode
|
||||||
for (auto id : user_tokens) {
|
for (auto id : user_tokens) {
|
||||||
LOGv("token: `%s`\t -> `%d`", common_token_to_piece(context, id).c_str(), id);
|
LOGv("token: `%s`\t -> `%d`", common_token_to_piece(g_context, id).c_str(), id);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if context space is enough for desired tokens
|
// Check if context space is enough for desired tokens
|
||||||
int desired_budget = current_position + user_tokens.size() + n_predict;
|
int desired_budget = current_position + user_tokens.size() + n_predict;
|
||||||
if (desired_budget > llama_n_ctx(context)) {
|
if (desired_budget > llama_n_ctx(g_context)) {
|
||||||
LOGe("error: total tokens exceed context size");
|
LOGe("error: total tokens exceed context size");
|
||||||
return -1;
|
return -1;
|
||||||
}
|
}
|
||||||
token_predict_budget = desired_budget;
|
token_predict_budget = desired_budget;
|
||||||
|
|
||||||
// Add user prompt tokens to batch
|
// Add user prompt tokens to batch
|
||||||
common_batch_clear(*batch);
|
common_batch_clear(*g_batch);
|
||||||
for (int i = 0; i < user_tokens.size(); i++) {
|
for (int i = 0; i < user_tokens.size(); i++) {
|
||||||
common_batch_add(*batch, user_tokens[i], current_position + i, {0}, false);
|
common_batch_add(*g_batch, user_tokens[i], current_position + i, {0}, false);
|
||||||
}
|
}
|
||||||
batch->logits[batch->n_tokens - 1] = true; // Set logits true only for last token
|
g_batch->logits[g_batch->n_tokens - 1] = true; // Set logits true only for last token
|
||||||
|
|
||||||
// Decode batch
|
// Decode batch
|
||||||
int decode_result = llama_decode(context, *batch);
|
int decode_result = llama_decode(g_context, *g_batch);
|
||||||
if (decode_result != 0) {
|
if (decode_result != 0) {
|
||||||
LOGe("llama_decode() failed: %d", decode_result);
|
LOGe("llama_decode() failed: %d", decode_result);
|
||||||
return -2;
|
return -2;
|
||||||
|
|
@ -382,7 +386,7 @@ Java_android_llama_cpp_LLamaAndroid_processUserPrompt(
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool is_valid_utf8(const char *string) {
|
static bool is_valid_utf8(const char *string) {
|
||||||
if (!string) { return true; }
|
if (!string) { return true; }
|
||||||
|
|
||||||
const auto *bytes = (const unsigned char *) string;
|
const auto *bytes = (const unsigned char *) string;
|
||||||
|
|
@ -429,25 +433,25 @@ Java_android_llama_cpp_LLamaAndroid_predictLoop(
|
||||||
}
|
}
|
||||||
|
|
||||||
// Sample next token
|
// Sample next token
|
||||||
const auto new_token_id = common_sampler_sample(sampler, context, -1);
|
const auto new_token_id = common_sampler_sample(g_sampler, g_context, -1);
|
||||||
common_sampler_accept(sampler, new_token_id, true);
|
common_sampler_accept(g_sampler, new_token_id, true);
|
||||||
|
|
||||||
// Stop if next token is EOG
|
// Stop if next token is EOG
|
||||||
if (llama_vocab_is_eog(llama_model_get_vocab(model), new_token_id)) {
|
if (llama_vocab_is_eog(llama_model_get_vocab(g_model), new_token_id)) {
|
||||||
LOGd("id: %d,\tIS EOG!\nSTOP.", new_token_id);
|
LOGd("id: %d,\tIS EOG!\nSTOP.", new_token_id);
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Update the context with the new token
|
// Update the context with the new token
|
||||||
common_batch_clear(*batch);
|
common_batch_clear(*g_batch);
|
||||||
common_batch_add(*batch, new_token_id, current_position, {0}, true);
|
common_batch_add(*g_batch, new_token_id, current_position, {0}, true);
|
||||||
if (llama_decode(context, *batch) != 0) {
|
if (llama_decode(g_context, *g_batch) != 0) {
|
||||||
LOGe("llama_decode() failed for generated token");
|
LOGe("llama_decode() failed for generated token");
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Convert to text
|
// Convert to text
|
||||||
auto new_token_chars = common_token_to_piece(context, new_token_id);
|
auto new_token_chars = common_token_to_piece(g_context, new_token_id);
|
||||||
cached_token_chars += new_token_chars;
|
cached_token_chars += new_token_chars;
|
||||||
|
|
||||||
// Create Java string
|
// Create Java string
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue