Abort on system prompt too long; Truncate user prompt if too long.

This commit is contained in:
Han Yin 2025-04-08 11:27:00 -07:00
parent 4809112ec5
commit 4e515727b4
1 changed files with 49 additions and 32 deletions

View File

@ -10,7 +10,8 @@
#include "common.h"
#include "llama.h"
template <class T> static std::string join(const std::vector<T> & values, const std::string & delim) {
template<class T>
static std::string join(const std::vector<T> &values, const std::string &delim) {
std::ostringstream str;
for (size_t i = 0; i < values.size(); i++) {
str << values[i];
@ -37,6 +38,7 @@ constexpr int N_THREADS_MAX = 8;
constexpr int N_THREADS_HEADROOM = 2;
constexpr int CONTEXT_SIZE = 4096;
constexpr int OVERFLOW_HEADROOM = 4;
constexpr int BATCH_SIZE = 512;
constexpr float SAMPLER_TEMP = 0.3f;
@ -44,7 +46,7 @@ static llama_model * g_model;
static llama_context * g_context;
static llama_batch * g_batch;
static common_sampler * g_sampler;
static common_chat_templates_ptr g_chat_templates;
static common_chat_templates_ptr g_chat_templates;
static void log_callback(ggml_log_level level, const char *fmt, void *data) {
int priority;
@ -68,9 +70,9 @@ static void log_callback(ggml_log_level level, const char *fmt, void *data) {
__android_log_print(priority, TAG, fmt, data);
}
JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM* vm, void* reserved) {
JNIEnv* env;
if (vm->GetEnv(reinterpret_cast<void**>(&env), JNI_VERSION_1_6) != JNI_OK) {
JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM *vm, void *reserved) {
JNIEnv *env;
if (vm->GetEnv(reinterpret_cast<void **>(&env), JNI_VERSION_1_6) != JNI_OK) {
return JNI_ERR;
}
@ -102,13 +104,13 @@ Java_android_llama_cpp_LLamaAndroid_loadModel(JNIEnv *env, jobject, jstring file
env->ReleaseStringUTFChars(filename, path_to_model);
if (!model) {
LOGe("load_model() failed");
return -1;
return 1;
}
g_model = model;
return 0;
}
static llama_context* init_context(llama_model *model) {
static llama_context *init_context(llama_model *model) {
if (!model) {
LOGe("init_context(): model cannot be null");
return nullptr;
@ -134,7 +136,7 @@ static llama_context* init_context(llama_model *model) {
return context;
}
static llama_batch * new_batch(int n_tokens, bool embd = false, int n_seq_max = 1) {
static llama_batch *new_batch(int n_tokens, bool embd = false, int n_seq_max = 1) {
// Source: Copy of llama.cpp:llama_batch_init but heap-allocated.
auto *batch = new llama_batch{
0,
@ -162,7 +164,7 @@ static llama_batch * new_batch(int n_tokens, bool embd = false, int n_seq_max =
return batch;
}
static common_sampler* new_sampler(float temp) {
static common_sampler *new_sampler(float temp) {
common_params_sampling sparams;
sparams.temp = temp;
return common_sampler_init(g_model, sparams);
@ -172,7 +174,7 @@ extern "C"
JNIEXPORT jint JNICALL
Java_android_llama_cpp_LLamaAndroid_initContext(JNIEnv * /*env*/, jobject /*unused*/) {
auto *context = init_context(g_model);
if (!context) { return -1; }
if (!context) { return 1; }
g_context = context;
g_batch = new_batch(BATCH_SIZE);
g_sampler = new_sampler(SAMPLER_TEMP);
@ -194,7 +196,7 @@ Java_android_llama_cpp_LLamaAndroid_cleanUp(JNIEnv * /*unused*/, jobject /*unuse
static std::string get_backend() {
std::vector<std::string> backends;
for (size_t i = 0; i < ggml_backend_reg_count(); i++) {
auto * reg = ggml_backend_reg_get(i);
auto *reg = ggml_backend_reg_get(i);
std::string name = ggml_backend_reg_name(reg);
if (name != "CPU") {
backends.push_back(ggml_backend_reg_name(reg));
@ -205,7 +207,8 @@ static std::string get_backend() {
extern "C"
JNIEXPORT jstring JNICALL
Java_android_llama_cpp_LLamaAndroid_benchModel(JNIEnv *env, jobject /*unused*/, jint pp, jint tg, jint pl, jint nr) {
Java_android_llama_cpp_LLamaAndroid_benchModel(JNIEnv *env, jobject /*unused*/, jint pp, jint tg,
jint pl, jint nr) {
auto pp_avg = 0.0;
auto tg_avg = 0.0;
auto pp_std = 0.0;
@ -304,14 +307,14 @@ Java_android_llama_cpp_LLamaAndroid_benchModel(JNIEnv *env, jobject /*unused*/,
/**
* Prediction loop's long-term states
*/
constexpr const char* ROLE_SYSTEM = "system";
constexpr const char* ROLE_USER = "user";
constexpr const char* ROLE_ASSISTANT = "assistant";
constexpr const char *ROLE_SYSTEM = "system";
constexpr const char *ROLE_USER = "user";
constexpr const char *ROLE_ASSISTANT = "assistant";
static llama_pos current_position;
static std::vector<common_chat_msg> chat_msgs;
static std::string chat_add_and_format(const std::string & role, const std::string & content) {
static std::string chat_add_and_format(const std::string &role, const std::string &content) {
common_chat_msg new_msg;
new_msg.role = role;
new_msg.content = content;
@ -324,12 +327,12 @@ static std::string chat_add_and_format(const std::string & role, const std::stri
static int decode_tokens_in_batches(
llama_context *context,
const llama_tokens& tokens,
const llama_tokens &tokens,
const llama_pos start_pos,
bool compute_last_logit = false,
llama_batch *batch = g_batch) {
// Process tokens in batches using the global batch
LOGd("Decode %d tokens starting at position %d", tokens.size(), start_pos);
LOGd("Decode %d tokens starting at position %d", (int) tokens.size(), start_pos);
for (int i = 0; i < (int) tokens.size(); i += BATCH_SIZE) {
int cur_batch_size = std::min((int) tokens.size() - i, BATCH_SIZE);
common_batch_clear(*batch);
@ -347,10 +350,9 @@ static int decode_tokens_in_batches(
int decode_result = llama_decode(context, *batch);
if (decode_result) {
LOGe("llama_decode failed w/ %d", decode_result);
return -1;
return 1;
}
}
return 0;
}
@ -390,17 +392,24 @@ Java_android_llama_cpp_LLamaAndroid_processSystemPrompt(
}
// Tokenize system prompt
const auto system_tokens = common_tokenize(g_context, formatted_system_prompt, has_chat_template, has_chat_template);
for (auto id : system_tokens) {
const auto system_tokens = common_tokenize(g_context, formatted_system_prompt,
has_chat_template, has_chat_template);
for (auto id: system_tokens) {
LOGv("token: `%s`\t -> `%d`", common_token_to_piece(g_context, id).c_str(), id);
}
// TODO-hyin: handle context overflow
// Handle context overflow
const int max_batch_size = CONTEXT_SIZE - OVERFLOW_HEADROOM;
if ((int) system_tokens.size() > max_batch_size) {
LOGe("System prompt too long for context! %d tokens, max: %d",
(int) system_tokens.size(), max_batch_size);
return 1;
}
// Decode system tokens in batches
if (decode_tokens_in_batches(g_context, system_tokens, current_position)) {
LOGe("llama_decode() failed!");
return -1;
return 2;
}
// Update position
@ -435,28 +444,36 @@ Java_android_llama_cpp_LLamaAndroid_processUserPrompt(
}
// Decode formatted user prompts
const auto user_tokens = common_tokenize(g_context, formatted_user_prompt, has_chat_template, has_chat_template);
for (auto id : user_tokens) {
auto user_tokens = common_tokenize(g_context, formatted_user_prompt, has_chat_template, has_chat_template);
for (auto id: user_tokens) {
LOGv("token: `%s`\t -> `%d`", common_token_to_piece(g_context, id).c_str(), id);
}
// TODO-hyin: handle context overflow
// Ensure user prompt doesn't exceed the context size by truncating if necessary.
const int max_batch_size = CONTEXT_SIZE - OVERFLOW_HEADROOM;
if ((int) user_tokens.size() > max_batch_size) {
const int skipped_tokens = (int) user_tokens.size() - max_batch_size;
user_tokens.resize(max_batch_size);
LOGw("User prompt too long! Skipped %d tokens!", skipped_tokens);
}
// TODO-hyin: implement context shifting
// Check if context space is enough for desired tokens
int desired_budget = current_position + (int) user_tokens.size() + n_predict;
if (desired_budget > llama_n_ctx(g_context)) {
LOGe("error: total tokens exceed context size");
return -1;
if (desired_budget > max_batch_size) {
LOGe("Not enough context! %d total tokens, max: %d", desired_budget, max_batch_size);
return 1;
}
token_predict_budget = desired_budget;
// Decode user tokens in batches
if (decode_tokens_in_batches(g_context, user_tokens, current_position, true)) {
LOGe("llama_decode() failed!");
return -2;
return 2;
}
// Update position
current_position += (int) user_tokens.size(); // Update position
current_position += (int) user_tokens.size();
return 0;
}