Abort on system prompt too long; Truncate user prompt if too long.
This commit is contained in:
parent
4809112ec5
commit
4e515727b4
|
|
@ -10,7 +10,8 @@
|
|||
#include "common.h"
|
||||
#include "llama.h"
|
||||
|
||||
template <class T> static std::string join(const std::vector<T> & values, const std::string & delim) {
|
||||
template<class T>
|
||||
static std::string join(const std::vector<T> &values, const std::string &delim) {
|
||||
std::ostringstream str;
|
||||
for (size_t i = 0; i < values.size(); i++) {
|
||||
str << values[i];
|
||||
|
|
@ -37,6 +38,7 @@ constexpr int N_THREADS_MAX = 8;
|
|||
constexpr int N_THREADS_HEADROOM = 2;
|
||||
|
||||
constexpr int CONTEXT_SIZE = 4096;
|
||||
constexpr int OVERFLOW_HEADROOM = 4;
|
||||
constexpr int BATCH_SIZE = 512;
|
||||
constexpr float SAMPLER_TEMP = 0.3f;
|
||||
|
||||
|
|
@ -44,7 +46,7 @@ static llama_model * g_model;
|
|||
static llama_context * g_context;
|
||||
static llama_batch * g_batch;
|
||||
static common_sampler * g_sampler;
|
||||
static common_chat_templates_ptr g_chat_templates;
|
||||
static common_chat_templates_ptr g_chat_templates;
|
||||
|
||||
static void log_callback(ggml_log_level level, const char *fmt, void *data) {
|
||||
int priority;
|
||||
|
|
@ -68,9 +70,9 @@ static void log_callback(ggml_log_level level, const char *fmt, void *data) {
|
|||
__android_log_print(priority, TAG, fmt, data);
|
||||
}
|
||||
|
||||
JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM* vm, void* reserved) {
|
||||
JNIEnv* env;
|
||||
if (vm->GetEnv(reinterpret_cast<void**>(&env), JNI_VERSION_1_6) != JNI_OK) {
|
||||
JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM *vm, void *reserved) {
|
||||
JNIEnv *env;
|
||||
if (vm->GetEnv(reinterpret_cast<void **>(&env), JNI_VERSION_1_6) != JNI_OK) {
|
||||
return JNI_ERR;
|
||||
}
|
||||
|
||||
|
|
@ -102,13 +104,13 @@ Java_android_llama_cpp_LLamaAndroid_loadModel(JNIEnv *env, jobject, jstring file
|
|||
env->ReleaseStringUTFChars(filename, path_to_model);
|
||||
if (!model) {
|
||||
LOGe("load_model() failed");
|
||||
return -1;
|
||||
return 1;
|
||||
}
|
||||
g_model = model;
|
||||
return 0;
|
||||
}
|
||||
|
||||
static llama_context* init_context(llama_model *model) {
|
||||
static llama_context *init_context(llama_model *model) {
|
||||
if (!model) {
|
||||
LOGe("init_context(): model cannot be null");
|
||||
return nullptr;
|
||||
|
|
@ -134,7 +136,7 @@ static llama_context* init_context(llama_model *model) {
|
|||
return context;
|
||||
}
|
||||
|
||||
static llama_batch * new_batch(int n_tokens, bool embd = false, int n_seq_max = 1) {
|
||||
static llama_batch *new_batch(int n_tokens, bool embd = false, int n_seq_max = 1) {
|
||||
// Source: Copy of llama.cpp:llama_batch_init but heap-allocated.
|
||||
auto *batch = new llama_batch{
|
||||
0,
|
||||
|
|
@ -162,7 +164,7 @@ static llama_batch * new_batch(int n_tokens, bool embd = false, int n_seq_max =
|
|||
return batch;
|
||||
}
|
||||
|
||||
static common_sampler* new_sampler(float temp) {
|
||||
static common_sampler *new_sampler(float temp) {
|
||||
common_params_sampling sparams;
|
||||
sparams.temp = temp;
|
||||
return common_sampler_init(g_model, sparams);
|
||||
|
|
@ -172,7 +174,7 @@ extern "C"
|
|||
JNIEXPORT jint JNICALL
|
||||
Java_android_llama_cpp_LLamaAndroid_initContext(JNIEnv * /*env*/, jobject /*unused*/) {
|
||||
auto *context = init_context(g_model);
|
||||
if (!context) { return -1; }
|
||||
if (!context) { return 1; }
|
||||
g_context = context;
|
||||
g_batch = new_batch(BATCH_SIZE);
|
||||
g_sampler = new_sampler(SAMPLER_TEMP);
|
||||
|
|
@ -194,7 +196,7 @@ Java_android_llama_cpp_LLamaAndroid_cleanUp(JNIEnv * /*unused*/, jobject /*unuse
|
|||
static std::string get_backend() {
|
||||
std::vector<std::string> backends;
|
||||
for (size_t i = 0; i < ggml_backend_reg_count(); i++) {
|
||||
auto * reg = ggml_backend_reg_get(i);
|
||||
auto *reg = ggml_backend_reg_get(i);
|
||||
std::string name = ggml_backend_reg_name(reg);
|
||||
if (name != "CPU") {
|
||||
backends.push_back(ggml_backend_reg_name(reg));
|
||||
|
|
@ -205,7 +207,8 @@ static std::string get_backend() {
|
|||
|
||||
extern "C"
|
||||
JNIEXPORT jstring JNICALL
|
||||
Java_android_llama_cpp_LLamaAndroid_benchModel(JNIEnv *env, jobject /*unused*/, jint pp, jint tg, jint pl, jint nr) {
|
||||
Java_android_llama_cpp_LLamaAndroid_benchModel(JNIEnv *env, jobject /*unused*/, jint pp, jint tg,
|
||||
jint pl, jint nr) {
|
||||
auto pp_avg = 0.0;
|
||||
auto tg_avg = 0.0;
|
||||
auto pp_std = 0.0;
|
||||
|
|
@ -304,14 +307,14 @@ Java_android_llama_cpp_LLamaAndroid_benchModel(JNIEnv *env, jobject /*unused*/,
|
|||
/**
|
||||
* Prediction loop's long-term states
|
||||
*/
|
||||
constexpr const char* ROLE_SYSTEM = "system";
|
||||
constexpr const char* ROLE_USER = "user";
|
||||
constexpr const char* ROLE_ASSISTANT = "assistant";
|
||||
constexpr const char *ROLE_SYSTEM = "system";
|
||||
constexpr const char *ROLE_USER = "user";
|
||||
constexpr const char *ROLE_ASSISTANT = "assistant";
|
||||
|
||||
static llama_pos current_position;
|
||||
static std::vector<common_chat_msg> chat_msgs;
|
||||
|
||||
static std::string chat_add_and_format(const std::string & role, const std::string & content) {
|
||||
static std::string chat_add_and_format(const std::string &role, const std::string &content) {
|
||||
common_chat_msg new_msg;
|
||||
new_msg.role = role;
|
||||
new_msg.content = content;
|
||||
|
|
@ -324,12 +327,12 @@ static std::string chat_add_and_format(const std::string & role, const std::stri
|
|||
|
||||
static int decode_tokens_in_batches(
|
||||
llama_context *context,
|
||||
const llama_tokens& tokens,
|
||||
const llama_tokens &tokens,
|
||||
const llama_pos start_pos,
|
||||
bool compute_last_logit = false,
|
||||
llama_batch *batch = g_batch) {
|
||||
// Process tokens in batches using the global batch
|
||||
LOGd("Decode %d tokens starting at position %d", tokens.size(), start_pos);
|
||||
LOGd("Decode %d tokens starting at position %d", (int) tokens.size(), start_pos);
|
||||
for (int i = 0; i < (int) tokens.size(); i += BATCH_SIZE) {
|
||||
int cur_batch_size = std::min((int) tokens.size() - i, BATCH_SIZE);
|
||||
common_batch_clear(*batch);
|
||||
|
|
@ -347,10 +350,9 @@ static int decode_tokens_in_batches(
|
|||
int decode_result = llama_decode(context, *batch);
|
||||
if (decode_result) {
|
||||
LOGe("llama_decode failed w/ %d", decode_result);
|
||||
return -1;
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
|
@ -390,17 +392,24 @@ Java_android_llama_cpp_LLamaAndroid_processSystemPrompt(
|
|||
}
|
||||
|
||||
// Tokenize system prompt
|
||||
const auto system_tokens = common_tokenize(g_context, formatted_system_prompt, has_chat_template, has_chat_template);
|
||||
for (auto id : system_tokens) {
|
||||
const auto system_tokens = common_tokenize(g_context, formatted_system_prompt,
|
||||
has_chat_template, has_chat_template);
|
||||
for (auto id: system_tokens) {
|
||||
LOGv("token: `%s`\t -> `%d`", common_token_to_piece(g_context, id).c_str(), id);
|
||||
}
|
||||
|
||||
// TODO-hyin: handle context overflow
|
||||
// Handle context overflow
|
||||
const int max_batch_size = CONTEXT_SIZE - OVERFLOW_HEADROOM;
|
||||
if ((int) system_tokens.size() > max_batch_size) {
|
||||
LOGe("System prompt too long for context! %d tokens, max: %d",
|
||||
(int) system_tokens.size(), max_batch_size);
|
||||
return 1;
|
||||
}
|
||||
|
||||
// Decode system tokens in batches
|
||||
if (decode_tokens_in_batches(g_context, system_tokens, current_position)) {
|
||||
LOGe("llama_decode() failed!");
|
||||
return -1;
|
||||
return 2;
|
||||
}
|
||||
|
||||
// Update position
|
||||
|
|
@ -435,28 +444,36 @@ Java_android_llama_cpp_LLamaAndroid_processUserPrompt(
|
|||
}
|
||||
|
||||
// Decode formatted user prompts
|
||||
const auto user_tokens = common_tokenize(g_context, formatted_user_prompt, has_chat_template, has_chat_template);
|
||||
for (auto id : user_tokens) {
|
||||
auto user_tokens = common_tokenize(g_context, formatted_user_prompt, has_chat_template, has_chat_template);
|
||||
for (auto id: user_tokens) {
|
||||
LOGv("token: `%s`\t -> `%d`", common_token_to_piece(g_context, id).c_str(), id);
|
||||
}
|
||||
|
||||
// TODO-hyin: handle context overflow
|
||||
// Ensure user prompt doesn't exceed the context size by truncating if necessary.
|
||||
const int max_batch_size = CONTEXT_SIZE - OVERFLOW_HEADROOM;
|
||||
if ((int) user_tokens.size() > max_batch_size) {
|
||||
const int skipped_tokens = (int) user_tokens.size() - max_batch_size;
|
||||
user_tokens.resize(max_batch_size);
|
||||
LOGw("User prompt too long! Skipped %d tokens!", skipped_tokens);
|
||||
}
|
||||
|
||||
// TODO-hyin: implement context shifting
|
||||
// Check if context space is enough for desired tokens
|
||||
int desired_budget = current_position + (int) user_tokens.size() + n_predict;
|
||||
if (desired_budget > llama_n_ctx(g_context)) {
|
||||
LOGe("error: total tokens exceed context size");
|
||||
return -1;
|
||||
if (desired_budget > max_batch_size) {
|
||||
LOGe("Not enough context! %d total tokens, max: %d", desired_budget, max_batch_size);
|
||||
return 1;
|
||||
}
|
||||
token_predict_budget = desired_budget;
|
||||
|
||||
// Decode user tokens in batches
|
||||
if (decode_tokens_in_batches(g_context, user_tokens, current_position, true)) {
|
||||
LOGe("llama_decode() failed!");
|
||||
return -2;
|
||||
return 2;
|
||||
}
|
||||
|
||||
// Update position
|
||||
current_position += (int) user_tokens.size(); // Update position
|
||||
current_position += (int) user_tokens.size();
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue