Polish: adopt common naming; init modularization;
parent 8bf2f4d412
commit 4809112ec5
@@ -108,10 +108,10 @@ Java_android_llama_cpp_LLamaAndroid_loadModel(JNIEnv *env, jobject, jstring file
     return 0;
 }
 
-static int init_context(llama_model *model) {
+static llama_context* init_context(llama_model *model) {
     if (!model) {
         LOGe("init_context(): model cannot be null");
-        return -1;
+        return nullptr;
     }
 
     // Multi-threading setup
@@ -128,15 +128,13 @@ static int init_context(llama_model *model) {
     ctx_params.n_threads       = n_threads;
     ctx_params.n_threads_batch = n_threads;
     auto *context = llama_init_from_model(g_model, ctx_params);
-    if (!context) {
+    if (context == nullptr) {
         LOGe("llama_new_context_with_model() returned null)");
-        return -2;
     }
-    g_context = context;
-    return 0;
+    return context;
 }
 
-static void new_batch(int n_tokens, bool embd = false, int n_seq_max = 1) {
+static llama_batch * new_batch(int n_tokens, bool embd = false, int n_seq_max = 1) {
     // Source: Copy of llama.cpp:llama_batch_init but heap-allocated.
     auto *batch = new llama_batch{
         0,
@@ -161,22 +159,23 @@ static void new_batch(int n_tokens, bool embd = false, int n_seq_max = 1) {
         batch->seq_id[i] = (llama_seq_id *) malloc(sizeof(llama_seq_id) * n_seq_max);
     }
     batch->logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens);
-    g_batch = batch;
+    return batch;
 }
 
-void new_sampler(float temp) {
+static common_sampler* new_sampler(float temp) {
     common_params_sampling sparams;
     sparams.temp = temp;
-    g_sampler = common_sampler_init(g_model, sparams);
+    return common_sampler_init(g_model, sparams);
 }
 
 extern "C"
 JNIEXPORT jint JNICALL
 Java_android_llama_cpp_LLamaAndroid_initContext(JNIEnv * /*env*/, jobject /*unused*/) {
-    int ret = init_context(g_model);
-    if (ret != 0) { return ret; }
-    new_batch(BATCH_SIZE);
-    new_sampler(SAMPLER_TEMP);
+    auto *context = init_context(g_model);
+    if (!context) { return -1; }
+    g_context = context;
+    g_batch = new_batch(BATCH_SIZE);
+    g_sampler = new_sampler(SAMPLER_TEMP);
     g_chat_templates = common_chat_templates_init(g_model, "");
     return 0;
 }
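Since new_batch() now hands back a heap-allocated copy of what llama_batch_init() would produce, the refactor implies a matching teardown on the native side. A minimal sketch of what that could look like, assuming the fields are allocated exactly as in llama_batch_init() (free_batch is a hypothetical name, not code from this commit):

static void free_batch(llama_batch *batch) {
    // Hypothetical counterpart to new_batch(); not part of this commit.
    if (!batch) { return; }
    llama_batch_free(*batch); // reclaims token/embd/pos/n_seq_id/seq_id/logits
    delete batch;             // reclaims the heap-allocated wrapper itself
}

The other two globals set in initContext() would be released through the existing APIs, llama_free(g_context) and common_sampler_free(g_sampler).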
@@ -396,6 +395,8 @@ Java_android_llama_cpp_LLamaAndroid_processSystemPrompt(
         LOGv("token: `%s`\t -> `%d`", common_token_to_piece(g_context, id).c_str(), id);
     }
 
+    // TODO-hyin: handle context overflow
+
     // Decode system tokens in batches
     if (decode_tokens_in_batches(g_context, system_tokens, current_position)) {
         LOGe("llama_decode() failed!");
@@ -439,6 +440,7 @@ Java_android_llama_cpp_LLamaAndroid_processUserPrompt(
         LOGv("token: `%s`\t -> `%d`", common_token_to_piece(g_context, id).c_str(), id);
     }
 
+    // TODO-hyin: handle context overflow
     // Check if context space is enough for desired tokens
     int desired_budget = current_position + (int) user_tokens.size() + n_predict;
     if (desired_budget > llama_n_ctx(g_context)) {
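Both TODO-hyin markers flag the same gap: neither prompt path reclaims or rejects work when the context window fills up. One minimal direction, in the spirit of the desired_budget check above, would be a shared guard like this (fits_context is a hypothetical helper sketched here, not code from this commit):

static bool fits_context(llama_context *ctx, int current_position,
                         int n_prompt_tokens, int n_predict) {
    // Hypothetical sketch of the TODO: refuse up front when the tokens
    // already in the context plus the new prompt and the prediction budget
    // would exceed the context window reported by llama_n_ctx().
    const int desired_budget = current_position + n_prompt_tokens + n_predict;
    return desired_budget <= (int) llama_n_ctx(ctx);
}

A fuller fix would presumably evict or shift old KV-cache entries rather than fail outright, which is likely why the marker was left as a TODO.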
@@ -494,7 +496,7 @@ static bool is_valid_utf8(const char *string) {
 
 extern "C"
 JNIEXPORT jstring JNICALL
-Java_android_llama_cpp_LLamaAndroid_predictLoop(
+Java_android_llama_cpp_LLamaAndroid_completionLoop(
         JNIEnv *env,
         jobject /*unused*/
 ) {

@@ -24,8 +24,8 @@ class LLamaAndroid {
     private external fun benchModel(pp: Int, tg: Int, pl: Int, nr: Int): String
 
     private external fun processSystemPrompt(systemPrompt: String): Int
-    private external fun processUserPrompt(userPrompt: String, nPredict: Int): Int
-    private external fun predictLoop(): String?
+    private external fun processUserPrompt(userPrompt: String, predictLength: Int): Int
+    private external fun completionLoop(): String?
 
     /**
      * Thread local state
@@ -128,7 +128,7 @@ class LLamaAndroid {
 
             Log.i(TAG, "User prompt processed! Generating assistant prompt...")
             while (true) {
-                predictLoop()?.let { utf8token ->
+                completionLoop()?.let { utf8token ->
                     if (utf8token.isNotEmpty()) emit(utf8token)
                 } ?: break
             }