Perf: allocate `llama_batch` on stack with `llama_batch_init`
parent 2b52563737
commit e1bc87610e
@@ -44,9 +44,9 @@ constexpr float DEFAULT_SAMPLER_TEMP = 0.3f;
 
 static llama_model * g_model;
 static llama_context * g_context;
-static llama_batch * g_batch;
-static common_sampler * g_sampler;
+static llama_batch g_batch;
 static common_chat_templates_ptr g_chat_templates;
+static common_sampler * g_sampler;
 
 static void log_callback(ggml_log_level level, const char *fmt, void *data) {
     int priority;
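For context, `llama_batch` is a small header struct whose members are pointers into separately allocated arrays, so holding it by value is cheap. A sketch of the declaration from recent llama.h (the exact field set may vary by llama.cpp revision):

```cpp
// Approximate declaration from llama.h; check your llama.cpp revision.
typedef struct llama_batch {
    int32_t         n_tokens; // number of tokens currently in the batch
    llama_token  *  token;    // token ids
    float        *  embd;     // alternative input: raw embeddings instead of ids
    llama_pos    *  pos;      // position of each token in its sequence
    int32_t      *  n_seq_id; // number of sequence ids per token
    llama_seq_id ** seq_id;   // sequence ids per token
    int8_t       *  logits;   // whether logits should be computed for each token
} llama_batch;
```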
@@ -140,29 +140,6 @@ static llama_context *init_context(llama_model *model) {
     return context;
 }
 
-static llama_batch *new_batch(int n_tokens, int n_seq_max = 1) {
-    // Source: Copy of llama.cpp:llama_batch_init but heap-allocated.
-    auto *batch = new llama_batch{
-        0,
-        nullptr,
-        nullptr,
-        nullptr,
-        nullptr,
-        nullptr,
-        nullptr,
-    };
-
-    batch->token = (llama_token *) malloc(sizeof(llama_token) * n_tokens);
-    batch->pos = (llama_pos *) malloc(sizeof(llama_pos) * n_tokens);
-    batch->n_seq_id = (int32_t *) malloc(sizeof(int32_t) * n_tokens);
-    batch->seq_id = (llama_seq_id **) malloc(sizeof(llama_seq_id *) * n_tokens);
-    for (int i = 0; i < n_tokens; ++i) {
-        batch->seq_id[i] = (llama_seq_id *) malloc(sizeof(llama_seq_id) * n_seq_max);
-    }
-    batch->logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens);
-    return batch;
-}
-
 static common_sampler *new_sampler(float temp) {
     common_params_sampling sparams;
     sparams.temp = temp;
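The hand-rolled heap allocator removed above duplicated `llama_batch_init`, which upstream already provides: it returns the batch by value and is paired with `llama_batch_free`. A minimal usage sketch, assuming the signatures declared in llama.h:

```cpp
// llama.h: llama_batch llama_batch_init(int32_t n_tokens, int32_t embd, int32_t n_seq_max);
//          void        llama_batch_free(llama_batch batch);
llama_batch batch = llama_batch_init(/*n_tokens=*/512, /*embd=*/0, /*n_seq_max=*/1);
// ... fill the batch and call llama_decode() ...
llama_batch_free(batch); // frees the internal arrays; the struct itself needs no delete
```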
@@ -175,18 +152,18 @@ Java_android_llama_cpp_LLamaAndroid_initContext(JNIEnv * /*env*/, jobject /*unus
     auto *context = init_context(g_model);
     if (!context) { return 1; }
     g_context = context;
-    g_batch = new_batch(BATCH_SIZE);
-    g_sampler = new_sampler(DEFAULT_SAMPLER_TEMP);
+    g_batch = llama_batch_init(BATCH_SIZE, 0, 1);
     g_chat_templates = common_chat_templates_init(g_model, "");
+    g_sampler = new_sampler(DEFAULT_SAMPLER_TEMP);
     return 0;
 }
 
 extern "C"
 JNIEXPORT void JNICALL
 Java_android_llama_cpp_LLamaAndroid_cleanUp(JNIEnv * /*unused*/, jobject /*unused*/) {
-    g_chat_templates.reset();
     common_sampler_free(g_sampler);
-    delete g_batch;
+    g_chat_templates.reset();
+    llama_batch_free(g_batch);
     llama_free(g_context);
     llama_model_free(g_model);
     llama_backend_free();
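Note that the old `delete g_batch;` released only the struct itself and leaked the arrays `new_batch` had malloc'd; `llama_batch_free` releases those arrays as well. A hypothetical RAII wrapper, not part of this commit, would make the init/free pairing automatic:

```cpp
// Hypothetical alternative (not in this commit): tie llama_batch_free to scope
// so early-return paths cannot leak the batch arrays.
struct batch_guard {
    llama_batch batch;
    explicit batch_guard(int32_t n_tokens, int32_t embd = 0, int32_t n_seq_max = 1)
        : batch(llama_batch_init(n_tokens, embd, n_seq_max)) {}
    ~batch_guard() { llama_batch_free(batch); }
    batch_guard(const batch_guard &) = delete;            // prevent double-free
    batch_guard &operator=(const batch_guard &) = delete;
};
```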
@@ -222,18 +199,18 @@ Java_android_llama_cpp_LLamaAndroid_benchModel(JNIEnv *env, jobject /*unused*/,
     for (nri = 0; nri < nr; nri++) {
         LOGi("Benchmark prompt processing (pp)");
 
-        common_batch_clear(*g_batch);
+        common_batch_clear(g_batch);
 
         const int n_tokens = pp;
         for (i = 0; i < n_tokens; i++) {
-            common_batch_add(*g_batch, 0, i, {0}, false);
+            common_batch_add(g_batch, 0, i, {0}, false);
         }
 
-        g_batch->logits[g_batch->n_tokens - 1] = true;
+        g_batch.logits[g_batch.n_tokens - 1] = true;
         llama_memory_clear(llama_get_memory(g_context), false);
 
         const auto t_pp_start = ggml_time_us();
-        if (llama_decode(g_context, *g_batch) != 0) {
+        if (llama_decode(g_context, g_batch) != 0) {
             LOGe("llama_decode() failed during prompt processing");
         }
         const auto t_pp_end = ggml_time_us();
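The `*g_batch` to `g_batch` change works because the `common_batch_*` helpers take the batch by reference. A simplified sketch of what they do (see llama.cpp's common/ directory for the real implementations, which also assert against overflow):

```cpp
// Simplified sketch of the common.h batch helpers.
void common_batch_clear(llama_batch & batch) {
    batch.n_tokens = 0; // arrays stay allocated; only the count resets
}

void common_batch_add(llama_batch & batch, llama_token id, llama_pos pos,
                      const std::vector<llama_seq_id> & seq_ids, bool logits) {
    const int32_t i = batch.n_tokens;
    batch.token[i]    = id;
    batch.pos[i]      = pos;
    batch.n_seq_id[i] = (int32_t) seq_ids.size();
    for (size_t s = 0; s < seq_ids.size(); ++s) {
        batch.seq_id[i][s] = seq_ids[s];
    }
    batch.logits[i] = logits;
    batch.n_tokens++;
}
```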
@@ -245,12 +222,12 @@ Java_android_llama_cpp_LLamaAndroid_benchModel(JNIEnv *env, jobject /*unused*/,
         llama_memory_clear(llama_get_memory(g_context), false);
         const auto t_tg_start = ggml_time_us();
         for (i = 0; i < tg; i++) {
-            common_batch_clear(*g_batch);
+            common_batch_clear(g_batch);
             for (j = 0; j < pl; j++) {
-                common_batch_add(*g_batch, 0, i, {j}, true);
+                common_batch_add(g_batch, 0, i, {j}, true);
             }
 
-            if (llama_decode(g_context, *g_batch) != 0) {
+            if (llama_decode(g_context, g_batch) != 0) {
                 LOGe("llama_decode() failed during text generation");
             }
         }
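`llama_decode` itself takes the batch by value, so the old `*g_batch` dereference was copying the struct at every call anyway; passing the value directly just drops one level of indirection:

```cpp
// Declaration from llama.h: the small batch header struct is passed by value.
int32_t llama_decode(llama_context * ctx, llama_batch batch);
```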
@@ -371,15 +348,15 @@ static void reset_short_term_states() {
 
 static int decode_tokens_in_batches(
         llama_context *context,
+        llama_batch batch,
         const llama_tokens &tokens,
         const llama_pos start_pos,
-        bool compute_last_logit = false,
-        llama_batch *batch = g_batch) {
+        const bool compute_last_logit = false) {
     // Process tokens in batches using the global batch
     LOGd("%s: Decode %d tokens starting at position %d", __func__, (int) tokens.size(), start_pos);
     for (int i = 0; i < (int) tokens.size(); i += BATCH_SIZE) {
         const int cur_batch_size = std::min((int) tokens.size() - i, BATCH_SIZE);
-        common_batch_clear(*batch);
+        common_batch_clear(batch);
         LOGv("%s: Preparing a batch size of %d starting at: %d", __func__, cur_batch_size, i);
 
         // Shift context if current batch cannot fit into the context
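Passing `llama_batch` by value copies only the header struct; `token`, `pos`, `seq_id`, and `logits` still point at the arrays allocated by `llama_batch_init`, and the helpers mutate the local copy's `n_tokens`, which is what `llama_decode` then sees. An illustrative, hypothetical caller:

```cpp
// Illustrative only: the by-value copy shares its arrays with the caller's batch.
static int decode_one(llama_context *ctx, llama_batch batch, llama_token id) {
    common_batch_clear(batch);                         // resets the copy's n_tokens
    common_batch_add(batch, id, /*pos=*/0, {0}, true); // writes into the shared arrays
    return llama_decode(ctx, batch);                   // decodes the single token
}
```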
@@ -393,11 +370,11 @@ static int decode_tokens_in_batches(
             const llama_token token_id = tokens[i + j];
             const llama_pos position = start_pos + i + j;
             const bool want_logit = compute_last_logit && (i + j == tokens.size() - 1);
-            common_batch_add(*batch, token_id, position, {0}, want_logit);
+            common_batch_add(batch, token_id, position, {0}, want_logit);
         }
 
         // Decode this batch
-        const int decode_result = llama_decode(context, *batch);
+        const int decode_result = llama_decode(context, batch);
         if (decode_result) {
             LOGe("%s: llama_decode failed w/ %d", __func__, decode_result);
             return 1;
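For example, with a `BATCH_SIZE` of 512 (illustrative; the actual constant is defined earlier in this file) and a 1300-token prompt, the loop decodes three chunks of 512, 512, and 276 tokens, and when `compute_last_logit` is true only the final token of the last chunk has its logits flag set.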
@@ -445,7 +422,7 @@ Java_android_llama_cpp_LLamaAndroid_processSystemPrompt(
     }
 
     // Decode system tokens in batches
-    if (decode_tokens_in_batches(g_context, system_tokens, current_position)) {
+    if (decode_tokens_in_batches(g_context, g_batch, system_tokens, current_position)) {
         LOGe("%s: llama_decode() failed!", __func__);
         return 2;
     }
@@ -494,7 +471,7 @@ Java_android_llama_cpp_LLamaAndroid_processUserPrompt(
     }
 
     // Decode user tokens in batches
-    if (decode_tokens_in_batches(g_context, user_tokens, current_position, true)) {
+    if (decode_tokens_in_batches(g_context, g_batch, user_tokens, current_position, true)) {
         LOGe("%s: llama_decode() failed!", __func__);
         return 2;
     }
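With the defaulted `llama_batch *batch = g_batch` parameter gone, both call sites now name the batch explicitly, which also lets a caller substitute a private batch. A hypothetical example:

```cpp
// Hypothetical: decode into a caller-owned batch instead of the global one.
const llama_tokens tokens = { /* ... token ids ... */ };
llama_batch local = llama_batch_init(BATCH_SIZE, 0, 1);
const int rc = decode_tokens_in_batches(g_context, local, tokens,
                                        /*start_pos=*/0, /*compute_last_logit=*/true);
llama_batch_free(local);
```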
@@ -562,9 +539,9 @@ Java_android_llama_cpp_LLamaAndroid_generateNextToken(
     common_sampler_accept(g_sampler, new_token_id, true);
 
     // Populate the batch with new token, then decode
-    common_batch_clear(*g_batch);
-    common_batch_add(*g_batch, new_token_id, current_position, {0}, true);
-    if (llama_decode(g_context, *g_batch) != 0) {
+    common_batch_clear(g_batch);
+    common_batch_add(g_batch, new_token_id, current_position, {0}, true);
+    if (llama_decode(g_context, g_batch) != 0) {
         LOGe("%s: llama_decode() failed for generated token", __func__);
         return nullptr;
     }