Polish: adopt common naming; init modularization;

Han Yin 2025-04-07 20:48:15 -07:00
parent 8bf2f4d412
commit 4809112ec5
2 changed files with 20 additions and 18 deletions

File 1 (C++ JNI bindings):

@@ -108,10 +108,10 @@ Java_android_llama_cpp_LLamaAndroid_loadModel(JNIEnv *env, jobject, jstring file
     return 0;
 }
 
-static int init_context(llama_model *model) {
+static llama_context* init_context(llama_model *model) {
     if (!model) {
         LOGe("init_context(): model cannot be null");
-        return -1;
+        return nullptr;
     }
 
     // Multi-threading setup
@@ -128,15 +128,13 @@ static int init_context(llama_model *model) {
     ctx_params.n_threads = n_threads;
     ctx_params.n_threads_batch = n_threads;
     auto *context = llama_init_from_model(g_model, ctx_params);
-    if (!context) {
+    if (context == nullptr) {
         LOGe("llama_new_context_with_model() returned null)");
-        return -2;
     }
-    g_context = context;
-    return 0;
+    return context;
 }
 
-static void new_batch(int n_tokens, bool embd = false, int n_seq_max = 1) {
+static llama_batch * new_batch(int n_tokens, bool embd = false, int n_seq_max = 1) {
     // Source: Copy of llama.cpp:llama_batch_init but heap-allocated.
     auto *batch = new llama_batch{
         0,
@@ -161,22 +159,23 @@ static void new_batch(int n_tokens, bool embd = false, int n_seq_max = 1) {
         batch->seq_id[i] = (llama_seq_id *) malloc(sizeof(llama_seq_id) * n_seq_max);
     }
     batch->logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens);
-    g_batch = batch;
+    return batch;
 }
 
-void new_sampler(float temp) {
+static common_sampler* new_sampler(float temp) {
     common_params_sampling sparams;
     sparams.temp = temp;
-    g_sampler = common_sampler_init(g_model, sparams);
+    return common_sampler_init(g_model, sparams);
 }
 
 extern "C"
 JNIEXPORT jint JNICALL
 Java_android_llama_cpp_LLamaAndroid_initContext(JNIEnv * /*env*/, jobject /*unused*/) {
-    int ret = init_context(g_model);
-    if (ret != 0) { return ret; }
-    new_batch(BATCH_SIZE);
-    new_sampler(SAMPLER_TEMP);
+    auto *context = init_context(g_model);
+    if (!context) { return -1; }
+    g_context = context;
+    g_batch = new_batch(BATCH_SIZE);
+    g_sampler = new_sampler(SAMPLER_TEMP);
     g_chat_templates = common_chat_templates_init(g_model, "");
     return 0;
 }
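
A note on ownership after this change: new_batch() returns a llama_batch allocated with new whose member arrays are malloc'd, so any teardown has to undo both layers. A minimal sketch of a free_batch() counterpart (hypothetical, not in this commit), assuming the copied llama_batch_init logic keeps its trailing nullptr sentinel in seq_id:

// Hypothetical counterpart to new_batch() above; not part of this commit.
// Assumes the llama_batch_init layout, where seq_id is over-allocated by
// one entry set to nullptr, which can then serve as the loop sentinel.
static void free_batch(llama_batch *batch) {
    if (batch == nullptr) {
        return;
    }
    free(batch->token);       // nullptr when embd == true; free(nullptr) is a no-op
    free(batch->embd);        // nullptr when embd == false
    free(batch->pos);
    free(batch->n_seq_id);
    if (batch->seq_id) {
        for (int i = 0; batch->seq_id[i] != nullptr; ++i) {
            free(batch->seq_id[i]);
        }
        free(batch->seq_id);
    }
    free(batch->logits);
    delete batch;             // the struct itself was allocated with `new`
}
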
@@ -396,6 +395,8 @@ Java_android_llama_cpp_LLamaAndroid_processSystemPrompt(
         LOGv("token: `%s`\t -> `%d`", common_token_to_piece(g_context, id).c_str(), id);
     }
 
+    // TODO-hyin: handle context overflow
+
     // Decode system tokens in batches
     if (decode_tokens_in_batches(g_context, system_tokens, current_position)) {
         LOGe("llama_decode() failed!");
@@ -439,6 +440,7 @@ Java_android_llama_cpp_LLamaAndroid_processUserPrompt(
         LOGv("token: `%s`\t -> `%d`", common_token_to_piece(g_context, id).c_str(), id);
     }
 
+    // TODO-hyin: handle context overflow
     // Check if context space is enough for desired tokens
     int desired_budget = current_position + (int) user_tokens.size() + n_predict;
     if (desired_budget > llama_n_ctx(g_context)) {
@@ -494,7 +496,7 @@ static bool is_valid_utf8(const char *string) {
 
 extern "C"
 JNIEXPORT jstring JNICALL
-Java_android_llama_cpp_LLamaAndroid_predictLoop(
+Java_android_llama_cpp_LLamaAndroid_completionLoop(
         JNIEnv *env,
         jobject /*unused*/
 ) {
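
The predictLoop -> completionLoop rename lines up with the naming in upstream llama.cpp's Android example. The function body falls outside this diff's context, but its contract is clear from the Kotlin caller in the second file: each call returns one UTF-8 piece, and null terminates generation. A sketch of that contract (an assumption-laden reconstruction, not this commit's code; sampler acceptance, llama_decode of the new token, and position bookkeeping are omitted):

// Contract sketch only -- the real body is not shown in this diff.
extern "C" JNIEXPORT jstring JNICALL
Java_android_llama_cpp_LLamaAndroid_completionLoop(JNIEnv *env, jobject /*unused*/) {
    // Sample the next token from the current logits.
    const llama_token id = common_sampler_sample(g_sampler, g_context, -1);

    // End of generation: return null so the Kotlin `?: break` exits its loop.
    if (llama_vocab_is_eog(llama_model_get_vocab(g_model), id)) {
        return nullptr;
    }

    // Convert the token to text; partial UTF-8 sequences come back as "",
    // which the Kotlin side skips via isNotEmpty() (see is_valid_utf8() above).
    const std::string piece = common_token_to_piece(g_context, id);
    if (!is_valid_utf8(piece.c_str())) {
        return env->NewStringUTF("");
    }
    return env->NewStringUTF(piece.c_str());
}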

File 2 (Kotlin wrapper, class LLamaAndroid):

@@ -24,8 +24,8 @@ class LLamaAndroid {
     private external fun benchModel(pp: Int, tg: Int, pl: Int, nr: Int): String
 
     private external fun processSystemPrompt(systemPrompt: String): Int
-    private external fun processUserPrompt(userPrompt: String, nPredict: Int): Int
-    private external fun predictLoop(): String?
+    private external fun processUserPrompt(userPrompt: String, predictLength: Int): Int
+    private external fun completionLoop(): String?
 
     /**
      * Thread local state
@@ -128,7 +128,7 @@ class LLamaAndroid {
             Log.i(TAG, "User prompt processed! Generating assistant prompt...")
 
             while (true) {
-                predictLoop()?.let { utf8token ->
+                completionLoop()?.let { utf8token ->
                     if (utf8token.isNotEmpty()) emit(utf8token)
                 } ?: break
             }