Abort if the system prompt is too long; truncate the user prompt if too long.

Han Yin 2025-04-08 11:27:00 -07:00
parent 4809112ec5
commit 4e515727b4
1 changed file with 49 additions and 32 deletions


@@ -10,7 +10,8 @@
 #include "common.h"
 #include "llama.h"
 
-template <class T> static std::string join(const std::vector<T> & values, const std::string & delim) {
+template<class T>
+static std::string join(const std::vector<T> &values, const std::string &delim) {
     std::ostringstream str;
     for (size_t i = 0; i < values.size(); i++) {
         str << values[i];
@@ -37,6 +38,7 @@ constexpr int N_THREADS_MAX = 8;
 constexpr int N_THREADS_HEADROOM = 2;
 constexpr int CONTEXT_SIZE = 4096;
+constexpr int OVERFLOW_HEADROOM = 4;
 constexpr int BATCH_SIZE = 512;
 constexpr float SAMPLER_TEMP = 0.3f;
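
Taken together, the two constants above define the effective prompt capacity that the new overflow checks enforce. A minimal sketch of the arithmetic (values copied from this diff; MAX_PROMPT_TOKENS is a hypothetical name, not one the commit introduces):

    // Sketch: effective prompt capacity implied by the new constants.
    // The 4-token headroom presumably leaves room for special/EOG tokens.
    constexpr int CONTEXT_SIZE      = 4096;
    constexpr int OVERFLOW_HEADROOM = 4;
    constexpr int MAX_PROMPT_TOKENS = CONTEXT_SIZE - OVERFLOW_HEADROOM;
    static_assert(MAX_PROMPT_TOKENS == 4092, "largest prompt the new checks accept");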
@@ -102,7 +104,7 @@ Java_android_llama_cpp_LLamaAndroid_loadModel(JNIEnv *env, jobject, jstring file
     env->ReleaseStringUTFChars(filename, path_to_model);
     if (!model) {
         LOGe("load_model() failed");
-        return -1;
+        return 1;
     }
     g_model = model;
     return 0;
@@ -172,7 +174,7 @@ extern "C"
 JNIEXPORT jint JNICALL
 Java_android_llama_cpp_LLamaAndroid_initContext(JNIEnv * /*env*/, jobject /*unused*/) {
     auto *context = init_context(g_model);
-    if (!context) { return -1; }
+    if (!context) { return 1; }
     g_context = context;
     g_batch = new_batch(BATCH_SIZE);
     g_sampler = new_sampler(SAMPLER_TEMP);
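
Note the -1 → 1 changes here and in loadModel: together with the 1/2 codes returned by the prompt-processing paths below, the JNI entry points now report small positive jint codes instead of -1/-2. A hypothetical enum documenting that convention (the names are mine; the commit itself returns bare literals):

    // Hypothetical names for the jint codes these entry points now return.
    enum jni_status : int {
        JNI_STATUS_OK         = 0, // success
        JNI_STATUS_ERR        = 1, // load/init failure, prompt too long, no context space
        JNI_STATUS_ERR_DECODE = 2, // llama_decode() failed while processing a prompt
    };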
@@ -205,7 +207,8 @@ static std::string get_backend() {
 
 extern "C"
 JNIEXPORT jstring JNICALL
-Java_android_llama_cpp_LLamaAndroid_benchModel(JNIEnv *env, jobject /*unused*/, jint pp, jint tg, jint pl, jint nr) {
+Java_android_llama_cpp_LLamaAndroid_benchModel(JNIEnv *env, jobject /*unused*/, jint pp, jint tg,
+                                               jint pl, jint nr) {
     auto pp_avg = 0.0;
     auto tg_avg = 0.0;
     auto pp_std = 0.0;
@@ -329,7 +332,7 @@ static int decode_tokens_in_batches(
         bool compute_last_logit = false,
         llama_batch *batch = g_batch) {
     // Process tokens in batches using the global batch
-    LOGd("Decode %d tokens starting at position %d", tokens.size(), start_pos);
+    LOGd("Decode %d tokens starting at position %d", (int) tokens.size(), start_pos);
     for (int i = 0; i < (int) tokens.size(); i += BATCH_SIZE) {
         int cur_batch_size = std::min((int) tokens.size() - i, BATCH_SIZE);
         common_batch_clear(*batch);
@@ -347,10 +350,9 @@
         int decode_result = llama_decode(context, *batch);
         if (decode_result) {
             LOGe("llama_decode failed w/ %d", decode_result);
-            return -1;
+            return 1;
         }
     }
-
 
     return 0;
 }
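
For reference, the chunking pattern that decode_tokens_in_batches implements (per the context lines above) looks roughly like the following standalone sketch. It assumes llama.cpp's common_batch_clear/common_batch_add helpers, whose exact signatures vary across llama.cpp revisions:

    // Rough sketch of the chunked-decode loop, under the assumptions above.
    static int decode_in_chunks(llama_context *ctx, llama_batch &batch,
                                const std::vector<llama_token> &tokens,
                                int start_pos, bool compute_last_logit) {
        for (int i = 0; i < (int) tokens.size(); i += BATCH_SIZE) {
            const int cur = std::min((int) tokens.size() - i, BATCH_SIZE);
            common_batch_clear(batch);
            for (int j = 0; j < cur; j++) {
                const bool last = (i + j == (int) tokens.size() - 1);
                // Request logits only for the final token, and only when asked to.
                common_batch_add(batch, tokens[i + j], start_pos + i + j,
                                 { 0 }, compute_last_logit && last);
            }
            if (int rc = llama_decode(ctx, batch)) {
                return rc; // non-zero signals failure, as checked above
            }
        }
        return 0;
    }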
@@ -390,17 +392,24 @@ Java_android_llama_cpp_LLamaAndroid_processSystemPrompt(
     }
 
     // Tokenize system prompt
-    const auto system_tokens = common_tokenize(g_context, formatted_system_prompt, has_chat_template, has_chat_template);
+    const auto system_tokens = common_tokenize(g_context, formatted_system_prompt,
+                                               has_chat_template, has_chat_template);
     for (auto id: system_tokens) {
         LOGv("token: `%s`\t -> `%d`", common_token_to_piece(g_context, id).c_str(), id);
     }
 
-    // TODO-hyin: handle context overflow
+    // Handle context overflow
+    const int max_batch_size = CONTEXT_SIZE - OVERFLOW_HEADROOM;
+    if ((int) system_tokens.size() > max_batch_size) {
+        LOGe("System prompt too long for context! %d tokens, max: %d",
+             (int) system_tokens.size(), max_batch_size);
+        return 1;
+    }
 
     // Decode system tokens in batches
     if (decode_tokens_in_batches(g_context, system_tokens, current_position)) {
         LOGe("llama_decode() failed!");
-        return -1;
+        return 2;
     }
 
     // Update position
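
The max_batch_size computation added here is duplicated verbatim in processUserPrompt below; a small helper could factor it out. A sketch (fits_in_context is hypothetical, not part of the commit):

    // Hypothetical helper consolidating the overflow check used by both
    // prompt-processing entry points.
    static bool fits_in_context(int n_tokens) {
        const int max_batch_size = CONTEXT_SIZE - OVERFLOW_HEADROOM;
        if (n_tokens > max_batch_size) {
            LOGe("Prompt too long for context! %d tokens, max: %d", n_tokens, max_batch_size);
            return false;
        }
        return true;
    }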
@@ -435,28 +444,36 @@ Java_android_llama_cpp_LLamaAndroid_processUserPrompt(
     }
 
     // Decode formatted user prompts
-    const auto user_tokens = common_tokenize(g_context, formatted_user_prompt, has_chat_template, has_chat_template);
+    auto user_tokens = common_tokenize(g_context, formatted_user_prompt, has_chat_template, has_chat_template);
     for (auto id: user_tokens) {
         LOGv("token: `%s`\t -> `%d`", common_token_to_piece(g_context, id).c_str(), id);
     }
 
-    // TODO-hyin: handle context overflow
+    // Ensure user prompt doesn't exceed the context size by truncating if necessary.
+    const int max_batch_size = CONTEXT_SIZE - OVERFLOW_HEADROOM;
+    if ((int) user_tokens.size() > max_batch_size) {
+        const int skipped_tokens = (int) user_tokens.size() - max_batch_size;
+        user_tokens.resize(max_batch_size);
+        LOGw("User prompt too long! Skipped %d tokens!", skipped_tokens);
+    }
 
+    // TODO-hyin: implement context shifting
     // Check if context space is enough for desired tokens
     int desired_budget = current_position + (int) user_tokens.size() + n_predict;
-    if (desired_budget > llama_n_ctx(g_context)) {
-        LOGe("error: total tokens exceed context size");
-        return -1;
+    if (desired_budget > max_batch_size) {
+        LOGe("Not enough context! %d total tokens, max: %d", desired_budget, max_batch_size);
+        return 1;
     }
     token_predict_budget = desired_budget;
 
     // Decode user tokens in batches
     if (decode_tokens_in_batches(g_context, user_tokens, current_position, true)) {
         LOGe("llama_decode() failed!");
-        return -2;
+        return 2;
    }
 
-    current_position += (int) user_tokens.size(); // Update position
+    // Update position
+    current_position += (int) user_tokens.size();
 
     return 0;
 }
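
One design note on the truncation above: user_tokens.resize(max_batch_size) drops tokens from the tail, so an over-long user prompt loses its ending, which often carries the actual question. A sketch of the opposite policy, keeping the most recent tokens (not what this commit does, and it would clip the chat-template header unless applied before formatting):

    // Alternative sketch: drop the oldest tokens instead of the newest ones.
    if ((int) user_tokens.size() > max_batch_size) {
        const int skipped_tokens = (int) user_tokens.size() - max_batch_size;
        user_tokens.erase(user_tokens.begin(), user_tokens.begin() + skipped_tokens);
        LOGw("User prompt too long! Dropped the first %d tokens!", skipped_tokens);
    }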