Feature: decode system and user prompt in batches
This commit is contained in:
parent
02465137ca
commit
c14c11dcbd
|
|
@ -111,6 +111,8 @@ static int init_context(llama_model *model) {
|
||||||
// Context parameters setup
|
// Context parameters setup
|
||||||
llama_context_params ctx_params = llama_context_default_params();
|
llama_context_params ctx_params = llama_context_default_params();
|
||||||
ctx_params.n_ctx = CONTEXT_SIZE;
|
ctx_params.n_ctx = CONTEXT_SIZE;
|
||||||
|
ctx_params.n_batch = BATCH_SIZE;
|
||||||
|
ctx_params.n_ubatch = BATCH_SIZE;
|
||||||
ctx_params.n_threads = n_threads;
|
ctx_params.n_threads = n_threads;
|
||||||
ctx_params.n_threads_batch = n_threads;
|
ctx_params.n_threads_batch = n_threads;
|
||||||
auto *context = llama_init_from_model(g_model, ctx_params);
|
auto *context = llama_init_from_model(g_model, ctx_params);
|
||||||
|
|
@ -171,9 +173,21 @@ JNIEXPORT void JNICALL
|
||||||
Java_android_llama_cpp_LLamaAndroid_cleanUp(JNIEnv * /*unused*/, jobject /*unused*/) {
|
Java_android_llama_cpp_LLamaAndroid_cleanUp(JNIEnv * /*unused*/, jobject /*unused*/) {
|
||||||
llama_model_free(g_model);
|
llama_model_free(g_model);
|
||||||
llama_free(g_context);
|
llama_free(g_context);
|
||||||
llama_backend_free();
|
|
||||||
delete g_batch;
|
delete g_batch;
|
||||||
common_sampler_free(g_sampler);
|
common_sampler_free(g_sampler);
|
||||||
|
llama_backend_free();
|
||||||
|
}
|
||||||
|
|
||||||
|
static std::string get_backend() {
|
||||||
|
std::vector<std::string> backends;
|
||||||
|
for (size_t i = 0; i < ggml_backend_reg_count(); i++) {
|
||||||
|
auto * reg = ggml_backend_reg_get(i);
|
||||||
|
std::string name = ggml_backend_reg_name(reg);
|
||||||
|
if (name != "CPU") {
|
||||||
|
backends.push_back(ggml_backend_reg_name(reg));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return backends.empty() ? "CPU" : join(backends, ",");
|
||||||
}
|
}
|
||||||
|
|
||||||
extern "C"
|
extern "C"
|
||||||
|
|
@ -205,7 +219,7 @@ Java_android_llama_cpp_LLamaAndroid_benchModel(JNIEnv *env, jobject /*unused*/,
|
||||||
|
|
||||||
const auto t_pp_start = ggml_time_us();
|
const auto t_pp_start = ggml_time_us();
|
||||||
if (llama_decode(g_context, *g_batch) != 0) {
|
if (llama_decode(g_context, *g_batch) != 0) {
|
||||||
LOGw("llama_decode() failed during prompt processing");
|
LOGe("llama_decode() failed during prompt processing");
|
||||||
}
|
}
|
||||||
const auto t_pp_end = ggml_time_us();
|
const auto t_pp_end = ggml_time_us();
|
||||||
|
|
||||||
|
|
@ -216,18 +230,15 @@ Java_android_llama_cpp_LLamaAndroid_benchModel(JNIEnv *env, jobject /*unused*/,
|
||||||
llama_memory_clear(llama_get_memory(g_context), false);
|
llama_memory_clear(llama_get_memory(g_context), false);
|
||||||
const auto t_tg_start = ggml_time_us();
|
const auto t_tg_start = ggml_time_us();
|
||||||
for (i = 0; i < tg; i++) {
|
for (i = 0; i < tg; i++) {
|
||||||
|
|
||||||
common_batch_clear(*g_batch);
|
common_batch_clear(*g_batch);
|
||||||
for (j = 0; j < pl; j++) {
|
for (j = 0; j < pl; j++) {
|
||||||
common_batch_add(*g_batch, 0, i, {j}, true);
|
common_batch_add(*g_batch, 0, i, {j}, true);
|
||||||
}
|
}
|
||||||
|
|
||||||
LOGi("llama_decode() text generation: %d", i);
|
|
||||||
if (llama_decode(g_context, *g_batch) != 0) {
|
if (llama_decode(g_context, *g_batch) != 0) {
|
||||||
LOGw("llama_decode() failed during text generation");
|
LOGe("llama_decode() failed during text generation");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const auto t_tg_end = ggml_time_us();
|
const auto t_tg_end = ggml_time_us();
|
||||||
|
|
||||||
llama_memory_clear(llama_get_memory(g_context), false);
|
llama_memory_clear(llama_get_memory(g_context), false);
|
||||||
|
|
@ -282,13 +293,42 @@ Java_android_llama_cpp_LLamaAndroid_benchModel(JNIEnv *env, jobject /*unused*/,
|
||||||
/**
|
/**
|
||||||
* Prediction loop's long-term and short-term states
|
* Prediction loop's long-term and short-term states
|
||||||
*/
|
*/
|
||||||
static int current_position;
|
static llama_pos current_position;
|
||||||
|
|
||||||
static int token_predict_budget;
|
static llama_pos token_predict_budget;
|
||||||
static std::string cached_token_chars;
|
static std::string cached_token_chars;
|
||||||
|
|
||||||
int token_predict_budget;
|
static int decode_tokens_in_batches(
|
||||||
std::string cached_token_chars;
|
llama_context *context,
|
||||||
|
const llama_tokens& tokens,
|
||||||
|
const llama_pos start_pos,
|
||||||
|
bool compute_last_logit = false,
|
||||||
|
llama_batch *batch = g_batch) {
|
||||||
|
// Process tokens in batches using the global batch
|
||||||
|
LOGd("Decode %d tokens starting at position %d", tokens.size(), start_pos);
|
||||||
|
for (int i = 0; i < (int) tokens.size(); i += BATCH_SIZE) {
|
||||||
|
int cur_batch_size = std::min((int) tokens.size() - i, BATCH_SIZE);
|
||||||
|
common_batch_clear(*batch);
|
||||||
|
LOGv("Preparing a batch size of %d starting at: %d", cur_batch_size, i);
|
||||||
|
|
||||||
|
// Add tokens to the batch with proper positions
|
||||||
|
for (int j = 0; j < cur_batch_size; j++) {
|
||||||
|
llama_token token_id = tokens[i + j];
|
||||||
|
llama_pos position = start_pos + i + j;
|
||||||
|
bool want_logit = compute_last_logit && (i + j == tokens.size() - 1);
|
||||||
|
common_batch_add(*batch, token_id, position, {0}, want_logit);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Decode this batch
|
||||||
|
int decode_result = llama_decode(context, *batch);
|
||||||
|
if (decode_result) {
|
||||||
|
LOGe("llama_decode failed w/ %d", decode_result);
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
extern "C"
|
extern "C"
|
||||||
JNIEXPORT jint JNICALL
|
JNIEXPORT jint JNICALL
|
||||||
|
|
@ -316,22 +356,14 @@ Java_android_llama_cpp_LLamaAndroid_processSystemPrompt(
|
||||||
LOGv("token: `%s`\t -> `%d`", common_token_to_piece(g_context, id).c_str(), id);
|
LOGv("token: `%s`\t -> `%d`", common_token_to_piece(g_context, id).c_str(), id);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add system prompt tokens to batch
|
// Decode system tokens in batches
|
||||||
common_batch_clear(*g_batch);
|
if (decode_tokens_in_batches(g_context, system_tokens, current_position)) {
|
||||||
// TODO-hyin: support batch processing!
|
LOGe("llama_decode() failed!");
|
||||||
for (int i = 0; i < system_tokens.size(); i++) {
|
|
||||||
common_batch_add(*g_batch, system_tokens[i], i, {0}, false);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Decode batch
|
|
||||||
int decode_result = llama_decode(g_context, *g_batch);
|
|
||||||
if (decode_result != 0) {
|
|
||||||
LOGe("llama_decode() failed: %d", decode_result);
|
|
||||||
return -1;
|
return -1;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Update position
|
// Update position
|
||||||
current_position = system_tokens.size();
|
current_position = (int) system_tokens.size();
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -360,29 +392,21 @@ Java_android_llama_cpp_LLamaAndroid_processUserPrompt(
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if context space is enough for desired tokens
|
// Check if context space is enough for desired tokens
|
||||||
int desired_budget = current_position + user_tokens.size() + n_predict;
|
int desired_budget = current_position + (int) user_tokens.size() + n_predict;
|
||||||
if (desired_budget > llama_n_ctx(g_context)) {
|
if (desired_budget > llama_n_ctx(g_context)) {
|
||||||
LOGe("error: total tokens exceed context size");
|
LOGe("error: total tokens exceed context size");
|
||||||
return -1;
|
return -1;
|
||||||
}
|
}
|
||||||
token_predict_budget = desired_budget;
|
token_predict_budget = desired_budget;
|
||||||
|
|
||||||
// Add user prompt tokens to batch
|
// Decode user tokens in batches
|
||||||
common_batch_clear(*g_batch);
|
if (decode_tokens_in_batches(g_context, user_tokens, current_position, true)) {
|
||||||
for (int i = 0; i < user_tokens.size(); i++) {
|
LOGe("llama_decode() failed!");
|
||||||
common_batch_add(*g_batch, user_tokens[i], current_position + i, {0}, false);
|
|
||||||
}
|
|
||||||
g_batch->logits[g_batch->n_tokens - 1] = true; // Set logits true only for last token
|
|
||||||
|
|
||||||
// Decode batch
|
|
||||||
int decode_result = llama_decode(g_context, *g_batch);
|
|
||||||
if (decode_result != 0) {
|
|
||||||
LOGe("llama_decode() failed: %d", decode_result);
|
|
||||||
return -2;
|
return -2;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Update position
|
// Update position
|
||||||
current_position += user_tokens.size(); // Update position
|
current_position += (int) user_tokens.size(); // Update position
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -436,13 +460,7 @@ Java_android_llama_cpp_LLamaAndroid_predictLoop(
|
||||||
const auto new_token_id = common_sampler_sample(g_sampler, g_context, -1);
|
const auto new_token_id = common_sampler_sample(g_sampler, g_context, -1);
|
||||||
common_sampler_accept(g_sampler, new_token_id, true);
|
common_sampler_accept(g_sampler, new_token_id, true);
|
||||||
|
|
||||||
// Stop if next token is EOG
|
// Populate the batch with new token, then decode
|
||||||
if (llama_vocab_is_eog(llama_model_get_vocab(g_model), new_token_id)) {
|
|
||||||
LOGd("id: %d,\tIS EOG!\nSTOP.", new_token_id);
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Update the context with the new token
|
|
||||||
common_batch_clear(*g_batch);
|
common_batch_clear(*g_batch);
|
||||||
common_batch_add(*g_batch, new_token_id, current_position, {0}, true);
|
common_batch_add(*g_batch, new_token_id, current_position, {0}, true);
|
||||||
if (llama_decode(g_context, *g_batch) != 0) {
|
if (llama_decode(g_context, *g_batch) != 0) {
|
||||||
|
|
@ -450,22 +468,28 @@ Java_android_llama_cpp_LLamaAndroid_predictLoop(
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Convert to text
|
// Update position
|
||||||
|
current_position++;
|
||||||
|
|
||||||
|
// Stop if next token is EOG
|
||||||
|
if (llama_vocab_is_eog(llama_model_get_vocab(g_model), new_token_id)) {
|
||||||
|
LOGd("id: %d,\tIS EOG!\nSTOP.", new_token_id);
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
// If not EOG, convert to text
|
||||||
auto new_token_chars = common_token_to_piece(g_context, new_token_id);
|
auto new_token_chars = common_token_to_piece(g_context, new_token_id);
|
||||||
cached_token_chars += new_token_chars;
|
cached_token_chars += new_token_chars;
|
||||||
|
|
||||||
// Create Java string
|
// Create and return Java string
|
||||||
jstring result = nullptr;
|
jstring result = nullptr;
|
||||||
if (is_valid_utf8(cached_token_chars.c_str())) {
|
if (is_valid_utf8(cached_token_chars.c_str())) {
|
||||||
result = env->NewStringUTF(cached_token_chars.c_str());
|
result = env->NewStringUTF(cached_token_chars.c_str());
|
||||||
LOGd("id: %d,\tcached: `%s`,\tnew: `%s`", new_token_id, cached_token_chars.c_str(), new_token_chars.c_str());
|
LOGv("id: %d,\tcached: `%s`,\tnew: `%s`", new_token_id, cached_token_chars.c_str(), new_token_chars.c_str());
|
||||||
cached_token_chars.clear();
|
cached_token_chars.clear();
|
||||||
} else {
|
} else {
|
||||||
LOGd("id: %d,\tappend to cache", new_token_id);
|
LOGv("id: %d,\tappend to cache", new_token_id);
|
||||||
result = env->NewStringUTF("");
|
result = env->NewStringUTF("");
|
||||||
}
|
}
|
||||||
|
|
||||||
// Update position
|
|
||||||
current_position++;
|
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue