Abort if the system prompt is too long; truncate the user prompt if too long.
This commit is contained in:
parent 4809112ec5
commit 4e515727b4
@@ -10,7 +10,8 @@
 #include "common.h"
 #include "llama.h"
 
-template <class T> static std::string join(const std::vector<T> & values, const std::string & delim) {
+template<class T>
+static std::string join(const std::vector<T> &values, const std::string &delim) {
     std::ostringstream str;
     for (size_t i = 0; i < values.size(); i++) {
         str << values[i];
@@ -37,6 +38,7 @@ constexpr int N_THREADS_MAX = 8;
 constexpr int N_THREADS_HEADROOM = 2;
 
 constexpr int CONTEXT_SIZE = 4096;
+constexpr int OVERFLOW_HEADROOM = 4;
 constexpr int BATCH_SIZE = 512;
 constexpr float SAMPLER_TEMP = 0.3f;
 
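Side note (not part of the diff): with the values above, the effective prompt budget used by the overflow checks further down is CONTEXT_SIZE - OVERFLOW_HEADROOM = 4096 - 4 = 4092 tokens. A minimal sketch of that arithmetic; kMaxPromptTokens is a hypothetical name, not an identifier from this commit:

constexpr int CONTEXT_SIZE      = 4096;                              // value from this commit
constexpr int OVERFLOW_HEADROOM = 4;                                 // value from this commit
constexpr int kMaxPromptTokens  = CONTEXT_SIZE - OVERFLOW_HEADROOM;  // hypothetical name, equals 4092
static_assert(kMaxPromptTokens == 4092, "prompt budget left after headroom");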
@@ -102,7 +104,7 @@ Java_android_llama_cpp_LLamaAndroid_loadModel(JNIEnv *env, jobject, jstring file
     env->ReleaseStringUTFChars(filename, path_to_model);
     if (!model) {
         LOGe("load_model() failed");
-        return -1;
+        return 1;
     }
     g_model = model;
     return 0;
@@ -172,7 +174,7 @@ extern "C"
 JNIEXPORT jint JNICALL
 Java_android_llama_cpp_LLamaAndroid_initContext(JNIEnv * /*env*/, jobject /*unused*/) {
     auto *context = init_context(g_model);
-    if (!context) { return -1; }
+    if (!context) { return 1; }
     g_context = context;
     g_batch = new_batch(BATCH_SIZE);
     g_sampler = new_sampler(SAMPLER_TEMP);
@@ -205,7 +207,8 @@ static std::string get_backend() {
 
 extern "C"
 JNIEXPORT jstring JNICALL
-Java_android_llama_cpp_LLamaAndroid_benchModel(JNIEnv *env, jobject /*unused*/, jint pp, jint tg, jint pl, jint nr) {
+Java_android_llama_cpp_LLamaAndroid_benchModel(JNIEnv *env, jobject /*unused*/, jint pp, jint tg,
+                                               jint pl, jint nr) {
     auto pp_avg = 0.0;
     auto tg_avg = 0.0;
     auto pp_std = 0.0;
@@ -329,7 +332,7 @@ static int decode_tokens_in_batches(
         bool compute_last_logit = false,
         llama_batch *batch = g_batch) {
     // Process tokens in batches using the global batch
-    LOGd("Decode %d tokens starting at position %d", tokens.size(), start_pos);
+    LOGd("Decode %d tokens starting at position %d", (int) tokens.size(), start_pos);
     for (int i = 0; i < (int) tokens.size(); i += BATCH_SIZE) {
         int cur_batch_size = std::min((int) tokens.size() - i, BATCH_SIZE);
         common_batch_clear(*batch);
@@ -347,10 +350,9 @@ static int decode_tokens_in_batches(
         int decode_result = llama_decode(context, *batch);
         if (decode_result) {
            LOGe("llama_decode failed w/ %d", decode_result);
-            return -1;
+            return 1;
         }
     }
-
     return 0;
 }
 
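For reference, the loop shown in the hunks above walks the token list in BATCH_SIZE-sized chunks. A hedged, llama-free sketch of the same chunk arithmetic; the 1300-token input is only an illustrative figure, not from the commit:

#include <algorithm>
#include <cstdio>

constexpr int BATCH_SIZE = 512; // value from this commit

int main() {
    const int n_tokens = 1300;  // example input size, not from the commit
    for (int i = 0; i < n_tokens; i += BATCH_SIZE) {
        // Same chunking as decode_tokens_in_batches: full batches, then the remainder.
        int cur_batch_size = std::min(n_tokens - i, BATCH_SIZE);
        std::printf("decode tokens [%d, %d) as one batch of %d\n",
                    i, i + cur_batch_size, cur_batch_size);
    }
    // Prints chunks of 512, 512, and 276 tokens.
    return 0;
}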
@@ -390,17 +392,24 @@ Java_android_llama_cpp_LLamaAndroid_processSystemPrompt(
     }
 
     // Tokenize system prompt
-    const auto system_tokens = common_tokenize(g_context, formatted_system_prompt, has_chat_template, has_chat_template);
+    const auto system_tokens = common_tokenize(g_context, formatted_system_prompt,
+                                               has_chat_template, has_chat_template);
     for (auto id: system_tokens) {
         LOGv("token: `%s`\t -> `%d`", common_token_to_piece(g_context, id).c_str(), id);
     }
 
-    // TODO-hyin: handle context overflow
+    // Handle context overflow
+    const int max_batch_size = CONTEXT_SIZE - OVERFLOW_HEADROOM;
+    if ((int) system_tokens.size() > max_batch_size) {
+        LOGe("System prompt too long for context! %d tokens, max: %d",
+             (int) system_tokens.size(), max_batch_size);
+        return 1;
+    }
 
     // Decode system tokens in batches
     if (decode_tokens_in_batches(g_context, system_tokens, current_position)) {
         LOGe("llama_decode() failed!");
-        return -1;
+        return 2;
     }
 
     // Update position
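A hedged reading of the new system-prompt branch: an oversized system prompt is rejected outright rather than truncated, since silently dropping system instructions would change model behavior. The helper below is hypothetical (not in the commit) and uses int to stand in for llama_token:

#include <cstdio>
#include <vector>

constexpr int CONTEXT_SIZE      = 4096; // value from this commit
constexpr int OVERFLOW_HEADROOM = 4;    // value from this commit

// Hypothetical helper mirroring the abort policy of processSystemPrompt.
static bool system_prompt_fits(const std::vector<int> &system_tokens) {
    const int max_batch_size = CONTEXT_SIZE - OVERFLOW_HEADROOM;
    if ((int) system_tokens.size() > max_batch_size) {
        std::fprintf(stderr, "System prompt too long: %zu tokens, max %d\n",
                     system_tokens.size(), max_batch_size);
        return false; // caller aborts, as processSystemPrompt now returns 1
    }
    return true;
}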
@@ -435,28 +444,36 @@ Java_android_llama_cpp_LLamaAndroid_processUserPrompt(
     }
 
     // Decode formatted user prompts
-    const auto user_tokens = common_tokenize(g_context, formatted_user_prompt, has_chat_template, has_chat_template);
+    auto user_tokens = common_tokenize(g_context, formatted_user_prompt, has_chat_template, has_chat_template);
     for (auto id: user_tokens) {
         LOGv("token: `%s`\t -> `%d`", common_token_to_piece(g_context, id).c_str(), id);
     }
 
-    // TODO-hyin: handle context overflow
+    // Ensure user prompt doesn't exceed the context size by truncating if necessary.
+    const int max_batch_size = CONTEXT_SIZE - OVERFLOW_HEADROOM;
+    if ((int) user_tokens.size() > max_batch_size) {
+        const int skipped_tokens = (int) user_tokens.size() - max_batch_size;
+        user_tokens.resize(max_batch_size);
+        LOGw("User prompt too long! Skipped %d tokens!", skipped_tokens);
+    }
 
     // TODO-hyin: implement context shifting
+    // Check if context space is enough for desired tokens
     int desired_budget = current_position + (int) user_tokens.size() + n_predict;
-    if (desired_budget > llama_n_ctx(g_context)) {
-        LOGe("error: total tokens exceed context size");
-        return -1;
+    if (desired_budget > max_batch_size) {
+        LOGe("Not enough context! %d total tokens, max: %d", desired_budget, max_batch_size);
+        return 1;
     }
     token_predict_budget = desired_budget;
 
     // Decode user tokens in batches
     if (decode_tokens_in_batches(g_context, user_tokens, current_position, true)) {
         LOGe("llama_decode() failed!");
-        return -2;
+        return 2;
     }
 
-    current_position += (int) user_tokens.size(); // Update position
+    // Update position
+    current_position += (int) user_tokens.size();
     return 0;
 }
 
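The user-prompt branch takes the opposite approach and truncates instead of aborting, with the remaining TODO-hyin note pointing at context shifting as the longer-term fix. A standalone sketch of that policy under the same assumptions as above; truncate_user_prompt is a hypothetical helper and int stands in for llama_token:

#include <cstdio>
#include <vector>

constexpr int CONTEXT_SIZE      = 4096; // value from this commit
constexpr int OVERFLOW_HEADROOM = 4;    // value from this commit

// Keep only the first CONTEXT_SIZE - OVERFLOW_HEADROOM tokens; return how many were dropped.
static int truncate_user_prompt(std::vector<int> &user_tokens) {
    const int max_batch_size = CONTEXT_SIZE - OVERFLOW_HEADROOM;
    if ((int) user_tokens.size() <= max_batch_size) {
        return 0; // fits as-is, nothing dropped
    }
    const int skipped_tokens = (int) user_tokens.size() - max_batch_size;
    user_tokens.resize(max_batch_size); // drop the tail, matching the commit's behavior
    std::fprintf(stderr, "User prompt too long, skipped %d tokens\n", skipped_tokens);
    return skipped_tokens;
}

Because resize() drops the tail, the end of an oversized user message is what gets lost; the system prompt path never accepts that trade-off and returns an error instead.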