Polish: adopt common naming; init modularization;

Han Yin 2025-04-07 20:48:15 -07:00
parent 8bf2f4d412
commit 4809112ec5
2 changed files with 20 additions and 18 deletions
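In short, the native init path is modularized: each helper now returns the resource it creates (or nullptr on failure) instead of writing into the file's globals, and the initContext JNI entry point does the wiring. A rough sketch of the new shape, assuming the llama.h / common.h types already used in this file (illustration only, not the complete implementation):

static llama_context  *init_context(llama_model *model);     // nullptr on failure instead of an int error code
static llama_batch    *new_batch(int n_tokens, bool embd = false, int n_seq_max = 1);
static common_sampler *new_sampler(float temp);

// Java_android_llama_cpp_LLamaAndroid_initContext() assigns the returned
// resources to g_context, g_batch and g_sampler and reports a single jint status.

On the Kotlin side, predictLoop/nPredict are renamed to completionLoop/predictLength, which is the "common naming" the commit message refers to.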


@@ -108,10 +108,10 @@ Java_android_llama_cpp_LLamaAndroid_loadModel(JNIEnv *env, jobject, jstring file
return 0;
}
-static int init_context(llama_model *model) {
+static llama_context* init_context(llama_model *model) {
if (!model) {
LOGe("init_context(): model cannot be null");
-return -1;
+return nullptr;
}
// Multi-threading setup
@@ -128,15 +128,13 @@ static int init_context(llama_model *model) {
ctx_params.n_threads = n_threads;
ctx_params.n_threads_batch = n_threads;
auto *context = llama_init_from_model(g_model, ctx_params);
-if (!context) {
+if (context == nullptr) {
LOGe("llama_new_context_with_model() returned null)");
-return -2;
}
-g_context = context;
-return 0;
+return context;
}
-static void new_batch(int n_tokens, bool embd = false, int n_seq_max = 1) {
+static llama_batch * new_batch(int n_tokens, bool embd = false, int n_seq_max = 1) {
// Source: Copy of llama.cpp:llama_batch_init but heap-allocated.
auto *batch = new llama_batch{
0,
@@ -161,22 +159,23 @@ static void new_batch(int n_tokens, bool embd = false, int n_seq_max = 1) {
batch->seq_id[i] = (llama_seq_id *) malloc(sizeof(llama_seq_id) * n_seq_max);
}
batch->logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens);
-g_batch = batch;
+return batch;
}
-void new_sampler(float temp) {
+static common_sampler* new_sampler(float temp) {
common_params_sampling sparams;
sparams.temp = temp;
-g_sampler = common_sampler_init(g_model, sparams);
+return common_sampler_init(g_model, sparams);
}
extern "C"
JNIEXPORT jint JNICALL
Java_android_llama_cpp_LLamaAndroid_initContext(JNIEnv * /*env*/, jobject /*unused*/) {
-int ret = init_context(g_model);
-if (ret != 0) { return ret; }
-new_batch(BATCH_SIZE);
-new_sampler(SAMPLER_TEMP);
+auto *context = init_context(g_model);
+if (!context) { return -1; }
+g_context = context;
+g_batch = new_batch(BATCH_SIZE);
+g_sampler = new_sampler(SAMPLER_TEMP);
g_chat_templates = common_chat_templates_init(g_model, "");
return 0;
}
@@ -396,6 +395,8 @@ Java_android_llama_cpp_LLamaAndroid_processSystemPrompt(
LOGv("token: `%s`\t -> `%d`", common_token_to_piece(g_context, id).c_str(), id);
}
+// TODO-hyin: handle context overflow
// Decode system tokens in batches
if (decode_tokens_in_batches(g_context, system_tokens, current_position)) {
LOGe("llama_decode() failed!");
@@ -439,6 +440,7 @@ Java_android_llama_cpp_LLamaAndroid_processUserPrompt(
LOGv("token: `%s`\t -> `%d`", common_token_to_piece(g_context, id).c_str(), id);
}
+// TODO-hyin: handle context overflow
// Check if context space is enough for desired tokens
int desired_budget = current_position + (int) user_tokens.size() + n_predict;
if (desired_budget > llama_n_ctx(g_context)) {
@@ -494,7 +496,7 @@ static bool is_valid_utf8(const char *string) {
extern "C"
JNIEXPORT jstring JNICALL
-Java_android_llama_cpp_LLamaAndroid_predictLoop(
+Java_android_llama_cpp_LLamaAndroid_completionLoop(
JNIEnv *env,
jobject /*unused*/
) {


@@ -24,8 +24,8 @@ class LLamaAndroid {
private external fun benchModel(pp: Int, tg: Int, pl: Int, nr: Int): String
private external fun processSystemPrompt(systemPrompt: String): Int
-private external fun processUserPrompt(userPrompt: String, nPredict: Int): Int
-private external fun predictLoop(): String?
+private external fun processUserPrompt(userPrompt: String, predictLength: Int): Int
+private external fun completionLoop(): String?
/**
* Thread local state
@@ -128,7 +128,7 @@
Log.i(TAG, "User prompt processed! Generating assistant prompt...")
while (true) {
-predictLoop()?.let { utf8token ->
+completionLoop()?.let { utf8token ->
if (utf8token.isNotEmpty()) emit(utf8token)
} ?: break
}