Abort on system prompt too long; Truncate user prompt if too long.
parent 4809112ec5
commit 4e515727b4
@@ -10,7 +10,8 @@
 #include "common.h"
 #include "llama.h"
 
-template <class T> static std::string join(const std::vector<T> & values, const std::string & delim) {
+template<class T>
+static std::string join(const std::vector<T> &values, const std::string &delim) {
     std::ostringstream str;
     for (size_t i = 0; i < values.size(); i++) {
         str << values[i];
@@ -37,6 +38,7 @@ constexpr int N_THREADS_MAX = 8;
 constexpr int N_THREADS_HEADROOM = 2;
 
 constexpr int CONTEXT_SIZE = 4096;
+constexpr int OVERFLOW_HEADROOM = 4;
 constexpr int BATCH_SIZE = 512;
 constexpr float SAMPLER_TEMP = 0.3f;
 
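
Note: the new OVERFLOW_HEADROOM constant reserves a few token slots below CONTEXT_SIZE; the overflow checks added later in this commit compare prompt lengths against CONTEXT_SIZE - OVERFLOW_HEADROOM. A minimal sketch of that budget arithmetic (the MAX_PROMPT_TOKENS name is illustrative and not part of the patch):

    constexpr int CONTEXT_SIZE      = 4096;
    constexpr int OVERFLOW_HEADROOM = 4;

    // Largest prompt, in tokens, that the new checks will accept.
    constexpr int MAX_PROMPT_TOKENS = CONTEXT_SIZE - OVERFLOW_HEADROOM;
    static_assert(MAX_PROMPT_TOKENS == 4092, "4096 - 4 leaves 4092 usable token slots");
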
@@ -44,7 +46,7 @@ static llama_model * g_model;
 static llama_context * g_context;
 static llama_batch * g_batch;
 static common_sampler * g_sampler;
 static common_chat_templates_ptr g_chat_templates;
 
 static void log_callback(ggml_log_level level, const char *fmt, void *data) {
     int priority;
@@ -68,9 +70,9 @@ static void log_callback(ggml_log_level level, const char *fmt, void *data) {
     __android_log_print(priority, TAG, fmt, data);
 }
 
-JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM* vm, void* reserved) {
-    JNIEnv* env;
-    if (vm->GetEnv(reinterpret_cast<void**>(&env), JNI_VERSION_1_6) != JNI_OK) {
+JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM *vm, void *reserved) {
+    JNIEnv *env;
+    if (vm->GetEnv(reinterpret_cast<void **>(&env), JNI_VERSION_1_6) != JNI_OK) {
         return JNI_ERR;
     }
 
@@ -102,13 +104,13 @@ Java_android_llama_cpp_LLamaAndroid_loadModel(JNIEnv *env, jobject, jstring file
     env->ReleaseStringUTFChars(filename, path_to_model);
     if (!model) {
         LOGe("load_model() failed");
-        return -1;
+        return 1;
     }
     g_model = model;
     return 0;
 }
 
-static llama_context* init_context(llama_model *model) {
+static llama_context *init_context(llama_model *model) {
     if (!model) {
         LOGe("init_context(): model cannot be null");
         return nullptr;
@@ -134,7 +136,7 @@ static llama_context* init_context(llama_model *model) {
     return context;
 }
 
-static llama_batch * new_batch(int n_tokens, bool embd = false, int n_seq_max = 1) {
+static llama_batch *new_batch(int n_tokens, bool embd = false, int n_seq_max = 1) {
     // Source: Copy of llama.cpp:llama_batch_init but heap-allocated.
     auto *batch = new llama_batch{
         0,
@@ -162,7 +164,7 @@ static llama_batch * new_batch(int n_tokens, bool embd = false, int n_seq_max =
     return batch;
 }
 
-static common_sampler* new_sampler(float temp) {
+static common_sampler *new_sampler(float temp) {
     common_params_sampling sparams;
     sparams.temp = temp;
     return common_sampler_init(g_model, sparams);
@@ -172,7 +174,7 @@ extern "C"
 JNIEXPORT jint JNICALL
 Java_android_llama_cpp_LLamaAndroid_initContext(JNIEnv * /*env*/, jobject /*unused*/) {
     auto *context = init_context(g_model);
-    if (!context) { return -1; }
+    if (!context) { return 1; }
     g_context = context;
     g_batch = new_batch(BATCH_SIZE);
     g_sampler = new_sampler(SAMPLER_TEMP);
@@ -194,7 +196,7 @@ Java_android_llama_cpp_LLamaAndroid_cleanUp(JNIEnv * /*unused*/, jobject /*unuse
 static std::string get_backend() {
     std::vector<std::string> backends;
     for (size_t i = 0; i < ggml_backend_reg_count(); i++) {
-        auto * reg = ggml_backend_reg_get(i);
+        auto *reg = ggml_backend_reg_get(i);
         std::string name = ggml_backend_reg_name(reg);
         if (name != "CPU") {
             backends.push_back(ggml_backend_reg_name(reg));
@@ -205,7 +207,8 @@ static std::string get_backend() {
 
 extern "C"
 JNIEXPORT jstring JNICALL
-Java_android_llama_cpp_LLamaAndroid_benchModel(JNIEnv *env, jobject /*unused*/, jint pp, jint tg, jint pl, jint nr) {
+Java_android_llama_cpp_LLamaAndroid_benchModel(JNIEnv *env, jobject /*unused*/, jint pp, jint tg,
+                                               jint pl, jint nr) {
     auto pp_avg = 0.0;
     auto tg_avg = 0.0;
     auto pp_std = 0.0;
@@ -304,14 +307,14 @@ Java_android_llama_cpp_LLamaAndroid_benchModel(JNIEnv *env, jobject /*unused*/,
 /**
  * Prediction loop's long-term states
  */
-constexpr const char* ROLE_SYSTEM = "system";
-constexpr const char* ROLE_USER = "user";
-constexpr const char* ROLE_ASSISTANT = "assistant";
+constexpr const char *ROLE_SYSTEM = "system";
+constexpr const char *ROLE_USER = "user";
+constexpr const char *ROLE_ASSISTANT = "assistant";
 
 static llama_pos current_position;
 static std::vector<common_chat_msg> chat_msgs;
 
-static std::string chat_add_and_format(const std::string & role, const std::string & content) {
+static std::string chat_add_and_format(const std::string &role, const std::string &content) {
     common_chat_msg new_msg;
     new_msg.role = role;
     new_msg.content = content;
@@ -324,12 +327,12 @@ static std::string chat_add_and_format(const std::string & role, const std::stri
 
 static int decode_tokens_in_batches(
         llama_context *context,
-        const llama_tokens& tokens,
+        const llama_tokens &tokens,
         const llama_pos start_pos,
         bool compute_last_logit = false,
         llama_batch *batch = g_batch) {
     // Process tokens in batches using the global batch
-    LOGd("Decode %d tokens starting at position %d", tokens.size(), start_pos);
+    LOGd("Decode %d tokens starting at position %d", (int) tokens.size(), start_pos);
     for (int i = 0; i < (int) tokens.size(); i += BATCH_SIZE) {
         int cur_batch_size = std::min((int) tokens.size() - i, BATCH_SIZE);
         common_batch_clear(*batch);
@@ -347,10 +350,9 @@ static int decode_tokens_in_batches(
         int decode_result = llama_decode(context, *batch);
         if (decode_result) {
             LOGe("llama_decode failed w/ %d", decode_result);
-            return -1;
+            return 1;
         }
     }
 
     return 0;
 }
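
For reference, the loop in decode_tokens_in_batches() above walks the token list in BATCH_SIZE steps, so only the last sub-batch is smaller. A self-contained sketch of that chunking arithmetic (the chunk_sizes helper and the main() driver are illustrative, not part of the patch):

    #include <algorithm>
    #include <cstdio>
    #include <vector>

    constexpr int BATCH_SIZE = 512;

    // Sizes of the sub-batches the decode loop would submit for n_tokens tokens.
    static std::vector<int> chunk_sizes(int n_tokens) {
        std::vector<int> sizes;
        for (int i = 0; i < n_tokens; i += BATCH_SIZE) {
            sizes.push_back(std::min(n_tokens - i, BATCH_SIZE));
        }
        return sizes;
    }

    int main() {
        for (int size : chunk_sizes(1200)) {   // prints: 512 512 176
            std::printf("%d ", size);
        }
        std::printf("\n");
        return 0;
    }
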
@@ -390,17 +392,24 @@ Java_android_llama_cpp_LLamaAndroid_processSystemPrompt(
     }
 
     // Tokenize system prompt
-    const auto system_tokens = common_tokenize(g_context, formatted_system_prompt, has_chat_template, has_chat_template);
-    for (auto id : system_tokens) {
+    const auto system_tokens = common_tokenize(g_context, formatted_system_prompt,
+                                               has_chat_template, has_chat_template);
+    for (auto id: system_tokens) {
         LOGv("token: `%s`\t -> `%d`", common_token_to_piece(g_context, id).c_str(), id);
     }
 
-    // TODO-hyin: handle context overflow
+    // Handle context overflow
+    const int max_batch_size = CONTEXT_SIZE - OVERFLOW_HEADROOM;
+    if ((int) system_tokens.size() > max_batch_size) {
+        LOGe("System prompt too long for context! %d tokens, max: %d",
+             (int) system_tokens.size(), max_batch_size);
+        return 1;
+    }
 
     // Decode system tokens in batches
     if (decode_tokens_in_batches(g_context, system_tokens, current_position)) {
         LOGe("llama_decode() failed!");
-        return -1;
+        return 2;
     }
 
     // Update position
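
After this hunk, processSystemPrompt() aborts rather than truncates when the system prompt alone cannot fit, returning 1 for an oversized prompt and 2 for a decode failure. A minimal standalone sketch of the same check, using std::vector<int> in place of llama_tokens (the check_system_prompt helper is illustrative, not part of the patch):

    #include <cstdio>
    #include <vector>

    constexpr int CONTEXT_SIZE      = 4096;
    constexpr int OVERFLOW_HEADROOM = 4;

    // Mirrors the abort path above: an oversized system prompt is rejected, not truncated.
    static int check_system_prompt(const std::vector<int> &system_tokens) {
        const int max_batch_size = CONTEXT_SIZE - OVERFLOW_HEADROOM;
        if ((int) system_tokens.size() > max_batch_size) {
            std::fprintf(stderr, "System prompt too long: %d tokens, max: %d\n",
                         (int) system_tokens.size(), max_batch_size);
            return 1;
        }
        return 0;
    }
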
@@ -435,28 +444,36 @@ Java_android_llama_cpp_LLamaAndroid_processUserPrompt(
     }
 
     // Decode formatted user prompts
-    const auto user_tokens = common_tokenize(g_context, formatted_user_prompt, has_chat_template, has_chat_template);
-    for (auto id : user_tokens) {
+    auto user_tokens = common_tokenize(g_context, formatted_user_prompt, has_chat_template, has_chat_template);
+    for (auto id: user_tokens) {
         LOGv("token: `%s`\t -> `%d`", common_token_to_piece(g_context, id).c_str(), id);
     }
 
-    // TODO-hyin: handle context overflow
+    // Ensure user prompt doesn't exceed the context size by truncating if necessary.
+    const int max_batch_size = CONTEXT_SIZE - OVERFLOW_HEADROOM;
+    if ((int) user_tokens.size() > max_batch_size) {
+        const int skipped_tokens = (int) user_tokens.size() - max_batch_size;
+        user_tokens.resize(max_batch_size);
+        LOGw("User prompt too long! Skipped %d tokens!", skipped_tokens);
+    }
 
+    // TODO-hyin: implement context shifting
     // Check if context space is enough for desired tokens
     int desired_budget = current_position + (int) user_tokens.size() + n_predict;
-    if (desired_budget > llama_n_ctx(g_context)) {
-        LOGe("error: total tokens exceed context size");
-        return -1;
+    if (desired_budget > max_batch_size) {
+        LOGe("Not enough context! %d total tokens, max: %d", desired_budget, max_batch_size);
+        return 1;
     }
     token_predict_budget = desired_budget;
 
     // Decode user tokens in batches
     if (decode_tokens_in_batches(g_context, user_tokens, current_position, true)) {
         LOGe("llama_decode() failed!");
-        return -2;
+        return 2;
     }
 
     // Update position
-    current_position += (int) user_tokens.size(); // Update position
+    current_position += (int) user_tokens.size();
     return 0;
 }
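
Unlike the system-prompt path, an oversized user prompt is truncated rather than rejected: the trailing tokens are dropped via user_tokens.resize() and only a warning is logged, and the existing budget check now compares current_position + prompt length + n_predict against the same max_batch_size limit. A standalone sketch of the truncation step, again using std::vector<int> in place of llama_tokens (the truncate_user_prompt helper is illustrative, not part of the patch):

    #include <cstdio>
    #include <vector>

    constexpr int CONTEXT_SIZE      = 4096;
    constexpr int OVERFLOW_HEADROOM = 4;

    // Keeps the first max_batch_size tokens and drops the rest,
    // matching the user_tokens.resize() call in the hunk above.
    static int truncate_user_prompt(std::vector<int> &user_tokens) {
        const int max_batch_size = CONTEXT_SIZE - OVERFLOW_HEADROOM;
        if ((int) user_tokens.size() <= max_batch_size) {
            return 0;                       // nothing to drop
        }
        const int skipped_tokens = (int) user_tokens.size() - max_batch_size;
        user_tokens.resize(max_batch_size);
        std::fprintf(stderr, "User prompt too long! Skipped %d tokens!\n", skipped_tokens);
        return skipped_tokens;              // number of trailing tokens dropped
    }
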