Rewrite llama-android JNI implementation
This commit is contained in:
parent
d4ab3832cf
commit
44720859d6
|
|
@ -45,7 +45,7 @@ class MainViewModel(private val llamaAndroid: LLamaAndroid = LLamaAndroid.instan
|
||||||
messages += ""
|
messages += ""
|
||||||
|
|
||||||
viewModelScope.launch {
|
viewModelScope.launch {
|
||||||
llamaAndroid.send(text)
|
llamaAndroid.sendMessage(text)
|
||||||
.catch {
|
.catch {
|
||||||
Log.e(tag, "send() failed", it)
|
Log.e(tag, "send() failed", it)
|
||||||
messages += it.message!!
|
messages += it.message!!
|
||||||
|
|
|
||||||
|
|
@ -8,57 +8,30 @@
|
||||||
#include "llama.h"
|
#include "llama.h"
|
||||||
#include "common.h"
|
#include "common.h"
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Logging utils
|
||||||
|
*/
|
||||||
#define TAG "llama-android.cpp"
|
#define TAG "llama-android.cpp"
|
||||||
|
#define LOGd(...) __android_log_print(ANDROID_LOG_DEBUG, TAG, __VA_ARGS__)
|
||||||
#define LOGi(...) __android_log_print(ANDROID_LOG_INFO, TAG, __VA_ARGS__)
|
#define LOGi(...) __android_log_print(ANDROID_LOG_INFO, TAG, __VA_ARGS__)
|
||||||
|
#define LOGw(...) __android_log_print(ANDROID_LOG_WARN, TAG, __VA_ARGS__)
|
||||||
#define LOGe(...) __android_log_print(ANDROID_LOG_ERROR, TAG, __VA_ARGS__)
|
#define LOGe(...) __android_log_print(ANDROID_LOG_ERROR, TAG, __VA_ARGS__)
|
||||||
|
|
||||||
constexpr int CONTEXT_SIZE = 2048;
|
/**
|
||||||
|
* LLama resources: context, model, batch and sampler
|
||||||
|
*/
|
||||||
|
constexpr int N_THREADS_MIN = 1;
|
||||||
|
constexpr int N_THREADS_MAX = 8;
|
||||||
|
constexpr int N_THREADS_HEADROOM = 2;
|
||||||
|
|
||||||
|
constexpr int CONTEXT_SIZE = 4096;
|
||||||
|
constexpr int BATCH_SIZE = 512;
|
||||||
constexpr float SAMPLER_TEMP = 0.3f;
|
constexpr float SAMPLER_TEMP = 0.3f;
|
||||||
|
|
||||||
constexpr int N_THREADS_MIN = 1;
|
llama_model * model;
|
||||||
constexpr int N_THREADS_MAX = 8;
|
llama_context * context;
|
||||||
constexpr int N_THREADS_HEADROOM = 2;
|
llama_batch * batch;
|
||||||
|
common_sampler * sampler;
|
||||||
jclass la_int_var;
|
|
||||||
jmethodID la_int_var_value;
|
|
||||||
jmethodID la_int_var_inc;
|
|
||||||
|
|
||||||
std::string cached_token_chars;
|
|
||||||
|
|
||||||
bool is_valid_utf8(const char *string) {
|
|
||||||
if (!string) { return true; }
|
|
||||||
|
|
||||||
const auto *bytes = (const unsigned char *) string;
|
|
||||||
int num;
|
|
||||||
|
|
||||||
while (*bytes != 0x00) {
|
|
||||||
if ((*bytes & 0x80) == 0x00) {
|
|
||||||
// U+0000 to U+007F
|
|
||||||
num = 1;
|
|
||||||
} else if ((*bytes & 0xE0) == 0xC0) {
|
|
||||||
// U+0080 to U+07FF
|
|
||||||
num = 2;
|
|
||||||
} else if ((*bytes & 0xF0) == 0xE0) {
|
|
||||||
// U+0800 to U+FFFF
|
|
||||||
num = 3;
|
|
||||||
} else if ((*bytes & 0xF8) == 0xF0) {
|
|
||||||
// U+10000 to U+10FFFF
|
|
||||||
num = 4;
|
|
||||||
} else {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
bytes += 1;
|
|
||||||
for (int i = 1; i < num; ++i) {
|
|
||||||
if ((*bytes & 0xC0) != 0x80) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
bytes += 1;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
static void log_callback(ggml_log_level level, const char *fmt, void *data) {
|
static void log_callback(ggml_log_level level, const char *fmt, void *data) {
|
||||||
int priority;
|
int priority;
|
||||||
|
|
@ -83,105 +56,129 @@ static void log_callback(ggml_log_level level, const char *fmt, void *data) {
|
||||||
}
|
}
|
||||||
|
|
||||||
extern "C"
|
extern "C"
|
||||||
JNIEXPORT jlong JNICALL
|
JNIEXPORT void JNICALL
|
||||||
Java_android_llama_cpp_LLamaAndroid_load_1model(JNIEnv *env, jobject, jstring filename) {
|
Java_android_llama_cpp_LLamaAndroid_log_1to_1android(JNIEnv * /*unused*/, jobject /*unused*/) {
|
||||||
llama_model_params model_params = llama_model_default_params();
|
llama_log_set(log_callback, nullptr);
|
||||||
|
}
|
||||||
|
|
||||||
auto path_to_model = env->GetStringUTFChars(filename, 0);
|
|
||||||
LOGi("Loading model from %s", path_to_model);
|
|
||||||
|
|
||||||
auto model = llama_model_load_from_file(path_to_model, model_params);
|
extern "C"
|
||||||
env->ReleaseStringUTFChars(filename, path_to_model);
|
JNIEXPORT jstring JNICALL
|
||||||
|
Java_android_llama_cpp_LLamaAndroid_system_1info(JNIEnv *env, jobject /*unused*/) {
|
||||||
if (!model) {
|
return env->NewStringUTF(llama_print_system_info());
|
||||||
LOGe("load_model() failed");
|
|
||||||
env->ThrowNew(env->FindClass("java/lang/IllegalStateException"), "load_model() failed");
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
return reinterpret_cast<jlong>(model);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
extern "C"
|
extern "C"
|
||||||
JNIEXPORT void JNICALL
|
JNIEXPORT void JNICALL
|
||||||
Java_android_llama_cpp_LLamaAndroid_free_1model(JNIEnv *, jobject, jlong model) {
|
Java_android_llama_cpp_LLamaAndroid_backend_1init(JNIEnv * /*unused*/, jobject /*unused*/) {
|
||||||
llama_model_free(reinterpret_cast<llama_model *>(model));
|
llama_backend_init();
|
||||||
}
|
}
|
||||||
|
|
||||||
extern "C"
|
extern "C"
|
||||||
JNIEXPORT jlong JNICALL
|
JNIEXPORT jint JNICALL
|
||||||
Java_android_llama_cpp_LLamaAndroid_new_1context(JNIEnv *env, jobject, jlong jmodel) {
|
Java_android_llama_cpp_LLamaAndroid_load_1model(JNIEnv *env, jobject, jstring filename) {
|
||||||
auto model = reinterpret_cast<llama_model *>(jmodel);
|
llama_model_params model_params = llama_model_default_params();
|
||||||
|
|
||||||
|
const auto *path_to_model = env->GetStringUTFChars(filename, 0);
|
||||||
|
LOGi("Loading model from: %s", path_to_model);
|
||||||
|
|
||||||
|
model = llama_model_load_from_file(path_to_model, model_params);
|
||||||
|
env->ReleaseStringUTFChars(filename, path_to_model);
|
||||||
|
|
||||||
if (!model) {
|
if (!model) {
|
||||||
LOGe("new_context(): model cannot be null");
|
LOGe("load_model() failed");
|
||||||
env->ThrowNew(env->FindClass("java/lang/IllegalArgumentException"), "Model cannot be null");
|
return -1;
|
||||||
return 0;
|
}
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
int init_context() {
|
||||||
|
if (!model) {
|
||||||
|
LOGe("init_context(): model cannot be null");
|
||||||
|
return -1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Multi-threading setup
|
||||||
int n_threads = std::max(N_THREADS_MIN, std::min(N_THREADS_MAX,
|
int n_threads = std::max(N_THREADS_MIN, std::min(N_THREADS_MAX,
|
||||||
(int) sysconf(_SC_NPROCESSORS_ONLN) -
|
(int) sysconf(_SC_NPROCESSORS_ONLN) -
|
||||||
N_THREADS_HEADROOM));
|
N_THREADS_HEADROOM));
|
||||||
LOGi("Using %d threads", n_threads);
|
LOGi("Using %d threads", n_threads);
|
||||||
|
|
||||||
|
// Context parameters setup
|
||||||
llama_context_params ctx_params = llama_context_default_params();
|
llama_context_params ctx_params = llama_context_default_params();
|
||||||
|
|
||||||
ctx_params.n_ctx = CONTEXT_SIZE;
|
ctx_params.n_ctx = CONTEXT_SIZE;
|
||||||
ctx_params.n_threads = n_threads;
|
ctx_params.n_threads = n_threads;
|
||||||
ctx_params.n_threads_batch = n_threads;
|
ctx_params.n_threads_batch = n_threads;
|
||||||
|
|
||||||
llama_context *context = llama_init_from_model(model, ctx_params);
|
context = llama_init_from_model(model, ctx_params);
|
||||||
|
|
||||||
if (!context) {
|
if (!context) {
|
||||||
LOGe("llama_new_context_with_model() returned null)");
|
LOGe("llama_new_context_with_model() returned null)");
|
||||||
env->ThrowNew(env->FindClass("java/lang/IllegalStateException"),
|
return -2;
|
||||||
"llama_new_context_with_model() returned null)");
|
}
|
||||||
return 0;
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
void new_batch(int n_tokens, bool embd = false, int n_seq_max = 1) {
|
||||||
|
// Source: Copy of llama.cpp:llama_batch_init but heap-allocated.
|
||||||
|
batch = new llama_batch{
|
||||||
|
0,
|
||||||
|
nullptr,
|
||||||
|
nullptr,
|
||||||
|
nullptr,
|
||||||
|
nullptr,
|
||||||
|
nullptr,
|
||||||
|
nullptr,
|
||||||
|
};
|
||||||
|
|
||||||
|
if (embd) {
|
||||||
|
batch->embd = (float *) malloc(sizeof(float) * n_tokens * embd);
|
||||||
|
} else {
|
||||||
|
batch->token = (llama_token *) malloc(sizeof(llama_token) * n_tokens);
|
||||||
}
|
}
|
||||||
|
|
||||||
return reinterpret_cast<jlong>(context);
|
batch->pos = (llama_pos *) malloc(sizeof(llama_pos) * n_tokens);
|
||||||
|
batch->n_seq_id = (int32_t *) malloc(sizeof(int32_t) * n_tokens);
|
||||||
|
batch->seq_id = (llama_seq_id **) malloc(sizeof(llama_seq_id *) * n_tokens);
|
||||||
|
for (int i = 0; i < n_tokens; ++i) {
|
||||||
|
batch->seq_id[i] = (llama_seq_id *) malloc(sizeof(llama_seq_id) * n_seq_max);
|
||||||
|
}
|
||||||
|
batch->logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens);
|
||||||
|
}
|
||||||
|
|
||||||
|
void new_sampler(float temp) {
|
||||||
|
common_params_sampling sparams;
|
||||||
|
sparams.temp = temp;
|
||||||
|
sampler = common_sampler_init(model, sparams);
|
||||||
|
}
|
||||||
|
|
||||||
|
extern "C"
|
||||||
|
JNIEXPORT jint JNICALL
|
||||||
|
Java_android_llama_cpp_LLamaAndroid_ctx_1init(JNIEnv * /*env*/, jobject /*unused*/) {
|
||||||
|
int ret = init_context();
|
||||||
|
if (ret != 0) { return ret; }
|
||||||
|
new_batch(BATCH_SIZE);
|
||||||
|
new_sampler(SAMPLER_TEMP);
|
||||||
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
extern "C"
|
extern "C"
|
||||||
JNIEXPORT void JNICALL
|
JNIEXPORT void JNICALL
|
||||||
Java_android_llama_cpp_LLamaAndroid_free_1context(JNIEnv *, jobject, jlong context) {
|
Java_android_llama_cpp_LLamaAndroid_clean_1up(JNIEnv * /*unused*/, jobject /*unused*/) {
|
||||||
llama_free(reinterpret_cast<llama_context *>(context));
|
llama_model_free(model);
|
||||||
}
|
llama_free(context);
|
||||||
|
|
||||||
extern "C"
|
|
||||||
JNIEXPORT void JNICALL
|
|
||||||
Java_android_llama_cpp_LLamaAndroid_backend_1free(JNIEnv *, jobject) {
|
|
||||||
llama_backend_free();
|
llama_backend_free();
|
||||||
}
|
delete batch;
|
||||||
|
common_sampler_free(sampler);
|
||||||
extern "C"
|
|
||||||
JNIEXPORT void JNICALL
|
|
||||||
Java_android_llama_cpp_LLamaAndroid_log_1to_1android(JNIEnv *, jobject) {
|
|
||||||
llama_log_set(log_callback, nullptr);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
extern "C"
|
extern "C"
|
||||||
JNIEXPORT jstring JNICALL
|
JNIEXPORT jstring JNICALL
|
||||||
Java_android_llama_cpp_LLamaAndroid_bench_1model(
|
Java_android_llama_cpp_LLamaAndroid_bench_1model(JNIEnv *env, jobject /*unused*/, jint pp, jint tg, jint pl, jint nr) {
|
||||||
JNIEnv *env,
|
|
||||||
jobject,
|
|
||||||
jlong context_pointer,
|
|
||||||
jlong model_pointer,
|
|
||||||
jlong batch_pointer,
|
|
||||||
jint pp,
|
|
||||||
jint tg,
|
|
||||||
jint pl,
|
|
||||||
jint nr
|
|
||||||
) {
|
|
||||||
auto pp_avg = 0.0;
|
auto pp_avg = 0.0;
|
||||||
auto tg_avg = 0.0;
|
auto tg_avg = 0.0;
|
||||||
auto pp_std = 0.0;
|
auto pp_std = 0.0;
|
||||||
auto tg_std = 0.0;
|
auto tg_std = 0.0;
|
||||||
|
|
||||||
auto *const context = reinterpret_cast<llama_context *>(context_pointer);
|
|
||||||
auto *const model = reinterpret_cast<llama_model *>(model_pointer);
|
|
||||||
auto *const batch = reinterpret_cast<llama_batch *>(batch_pointer);
|
|
||||||
|
|
||||||
const uint32_t n_ctx = llama_n_ctx(context);
|
const uint32_t n_ctx = llama_n_ctx(context);
|
||||||
|
|
||||||
LOGi("n_ctx = %d", n_ctx);
|
LOGi("n_ctx = %d", n_ctx);
|
||||||
|
|
@ -203,7 +200,7 @@ Java_android_llama_cpp_LLamaAndroid_bench_1model(
|
||||||
|
|
||||||
const auto t_pp_start = ggml_time_us();
|
const auto t_pp_start = ggml_time_us();
|
||||||
if (llama_decode(context, *batch) != 0) {
|
if (llama_decode(context, *batch) != 0) {
|
||||||
LOGi("llama_decode() failed during prompt processing");
|
LOGw("llama_decode() failed during prompt processing");
|
||||||
}
|
}
|
||||||
const auto t_pp_end = ggml_time_us();
|
const auto t_pp_end = ggml_time_us();
|
||||||
|
|
||||||
|
|
@ -222,7 +219,7 @@ Java_android_llama_cpp_LLamaAndroid_bench_1model(
|
||||||
|
|
||||||
LOGi("llama_decode() text generation: %d", i);
|
LOGi("llama_decode() text generation: %d", i);
|
||||||
if (llama_decode(context, *batch) != 0) {
|
if (llama_decode(context, *batch) != 0) {
|
||||||
LOGi("llama_decode() failed during text generation");
|
LOGw("llama_decode() failed during text generation");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -276,187 +273,181 @@ Java_android_llama_cpp_LLamaAndroid_bench_1model(
|
||||||
return env->NewStringUTF(result.str().c_str());
|
return env->NewStringUTF(result.str().c_str());
|
||||||
}
|
}
|
||||||
|
|
||||||
extern "C"
|
|
||||||
JNIEXPORT jlong JNICALL
|
|
||||||
Java_android_llama_cpp_LLamaAndroid_new_1batch(JNIEnv *, jobject, jint n_tokens, jint embd,
|
|
||||||
jint n_seq_max) {
|
|
||||||
// Source: Copy of llama.cpp:llama_batch_init but heap-allocated.
|
|
||||||
|
|
||||||
auto *batch = new llama_batch{
|
/**
|
||||||
0,
|
* Prediction loop's states
|
||||||
nullptr,
|
*/
|
||||||
nullptr,
|
int current_position;
|
||||||
nullptr,
|
|
||||||
nullptr,
|
|
||||||
nullptr,
|
|
||||||
nullptr,
|
|
||||||
};
|
|
||||||
|
|
||||||
if (embd) {
|
int token_predict_budget;
|
||||||
batch->embd = (float *) malloc(sizeof(float) * n_tokens * embd);
|
std::string cached_token_chars;
|
||||||
} else {
|
|
||||||
batch->token = (llama_token *) malloc(sizeof(llama_token) * n_tokens);
|
|
||||||
}
|
|
||||||
|
|
||||||
batch->pos = (llama_pos *) malloc(sizeof(llama_pos) * n_tokens);
|
|
||||||
batch->n_seq_id = (int32_t *) malloc(sizeof(int32_t) * n_tokens);
|
|
||||||
batch->seq_id = (llama_seq_id **) malloc(sizeof(llama_seq_id *) * n_tokens);
|
|
||||||
for (int i = 0; i < n_tokens; ++i) {
|
|
||||||
batch->seq_id[i] = (llama_seq_id *) malloc(sizeof(llama_seq_id) * n_seq_max);
|
|
||||||
}
|
|
||||||
batch->logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens);
|
|
||||||
|
|
||||||
return reinterpret_cast<jlong>(batch);
|
|
||||||
}
|
|
||||||
|
|
||||||
extern "C"
|
|
||||||
JNIEXPORT void JNICALL
|
|
||||||
Java_android_llama_cpp_LLamaAndroid_free_1batch(JNIEnv *, jobject, jlong batch_pointer) {
|
|
||||||
//llama_batch_free(*reinterpret_cast<llama_batch *>(batch_pointer)); // TODO: what is this?
|
|
||||||
auto *const batch = reinterpret_cast<llama_batch *>(batch_pointer);
|
|
||||||
delete batch;
|
|
||||||
}
|
|
||||||
|
|
||||||
extern "C"
|
|
||||||
JNIEXPORT jlong JNICALL
|
|
||||||
Java_android_llama_cpp_LLamaAndroid_new_1sampler(JNIEnv *, jobject, jlong model_pointer) {
|
|
||||||
auto *const model = reinterpret_cast<llama_model *>(model_pointer);
|
|
||||||
common_params_sampling sparams;
|
|
||||||
sparams.temp = SAMPLER_TEMP;
|
|
||||||
common_sampler *smpl = common_sampler_init(model, sparams);
|
|
||||||
|
|
||||||
return reinterpret_cast<jlong>(smpl);
|
|
||||||
}
|
|
||||||
|
|
||||||
extern "C"
|
|
||||||
JNIEXPORT void JNICALL
|
|
||||||
Java_android_llama_cpp_LLamaAndroid_free_1sampler(JNIEnv *, jobject, jlong sampler_pointer) {
|
|
||||||
// Properly cast from jlong to pointer type
|
|
||||||
auto* sampler = (llama_sampler*)(void*)(sampler_pointer);
|
|
||||||
llama_sampler_free(sampler);
|
|
||||||
}
|
|
||||||
|
|
||||||
extern "C"
|
|
||||||
JNIEXPORT void JNICALL
|
|
||||||
Java_android_llama_cpp_LLamaAndroid_backend_1init(JNIEnv *, jobject) {
|
|
||||||
llama_backend_init();
|
|
||||||
}
|
|
||||||
|
|
||||||
extern "C"
|
|
||||||
JNIEXPORT jstring JNICALL
|
|
||||||
Java_android_llama_cpp_LLamaAndroid_system_1info(JNIEnv *env, jobject) {
|
|
||||||
return env->NewStringUTF(llama_print_system_info());
|
|
||||||
}
|
|
||||||
|
|
||||||
extern "C"
|
extern "C"
|
||||||
JNIEXPORT jint JNICALL
|
JNIEXPORT jint JNICALL
|
||||||
Java_android_llama_cpp_LLamaAndroid_completion_1init(
|
Java_android_llama_cpp_LLamaAndroid_process_1system_1prompt(
|
||||||
JNIEnv *env,
|
JNIEnv *env,
|
||||||
jobject,
|
jobject /*unused*/,
|
||||||
jlong context_pointer,
|
jstring jsystem_prompt
|
||||||
jlong batch_pointer,
|
|
||||||
jstring jtext,
|
|
||||||
jboolean format_chat,
|
|
||||||
jint n_len
|
|
||||||
) {
|
) {
|
||||||
|
// Reset long-term states and reset KV cache
|
||||||
|
current_position = 0;
|
||||||
|
llama_memory_clear(llama_get_memory(context), false);
|
||||||
|
|
||||||
|
// Reset short-term states
|
||||||
|
token_predict_budget = 0;
|
||||||
cached_token_chars.clear();
|
cached_token_chars.clear();
|
||||||
|
|
||||||
const auto *const text = env->GetStringUTFChars(jtext, 0);
|
// Obtain and tokenize system prompt
|
||||||
auto *const context = reinterpret_cast<llama_context *>(context_pointer);
|
const auto *const system_text = env->GetStringUTFChars(jsystem_prompt, nullptr);
|
||||||
auto *const batch = reinterpret_cast<llama_batch *>(batch_pointer);
|
LOGi("System prompt: \n%s", system_text);
|
||||||
|
const auto system_tokens = common_tokenize(context, system_text, true, true);
|
||||||
bool parse_special = (format_chat == JNI_TRUE);
|
env->ReleaseStringUTFChars(jsystem_prompt, system_text);
|
||||||
const auto tokens_list = common_tokenize(context, text, true, parse_special);
|
|
||||||
|
|
||||||
auto n_ctx = llama_n_ctx(context);
|
|
||||||
auto n_tokens = tokens_list.size();
|
|
||||||
auto n_kv_req = n_tokens + n_len;
|
|
||||||
LOGi("n_ctx = %d, n_kv_req = n_tokens (%d) + n_len (%d) = %d", n_ctx, n_tokens, n_len, n_kv_req);
|
|
||||||
|
|
||||||
if (n_kv_req > n_ctx) {
|
|
||||||
LOGe("error: n_kv_req > n_ctx, the required KV cache size is not big enough");
|
|
||||||
}
|
|
||||||
|
|
||||||
for (auto id: tokens_list) {
|
|
||||||
LOGi("token: `%s`-> %d ", common_token_to_piece(context, id).c_str(), id);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
// Add system prompt tokens to batch
|
||||||
common_batch_clear(*batch);
|
common_batch_clear(*batch);
|
||||||
|
// TODO-hyin: support batch processing!
|
||||||
// evaluate the initial prompt
|
for (int i = 0; i < system_tokens.size(); i++) {
|
||||||
for (auto i = 0; i < tokens_list.size(); i++) {
|
common_batch_add(*batch, system_tokens[i], i, {0}, false);
|
||||||
common_batch_add(*batch, tokens_list[i], i, {0}, false);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// llama_decode will output logits only for the last token of the prompt
|
// Decode batch
|
||||||
batch->logits[batch->n_tokens - 1] = true;
|
int decode_result = llama_decode(context, *batch);
|
||||||
|
if (decode_result != 0) {
|
||||||
if (llama_decode(context, *batch) != 0) {
|
LOGe("llama_decode() failed: %d", decode_result);
|
||||||
LOGe("llama_decode() failed");
|
return -1;
|
||||||
}
|
}
|
||||||
|
|
||||||
env->ReleaseStringUTFChars(jtext, text);
|
// Update position
|
||||||
|
current_position = system_tokens.size();
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
return batch->n_tokens;
|
// TODO-hyin: support KV cache backtracking
|
||||||
|
extern "C"
|
||||||
|
JNIEXPORT jint JNICALL
|
||||||
|
Java_android_llama_cpp_LLamaAndroid_process_1user_1prompt(
|
||||||
|
JNIEnv *env,
|
||||||
|
jobject /*unused*/,
|
||||||
|
jstring juser_prompt,
|
||||||
|
jint n_len
|
||||||
|
) {
|
||||||
|
// Reset short-term states
|
||||||
|
token_predict_budget = 0;
|
||||||
|
cached_token_chars.clear();
|
||||||
|
|
||||||
|
// Obtain and tokenize user prompt
|
||||||
|
const auto *const user_text = env->GetStringUTFChars(juser_prompt, nullptr);
|
||||||
|
LOGi("User prompt: \n%s", user_text);
|
||||||
|
const auto user_tokens = common_tokenize(context, user_text, true, true);
|
||||||
|
env->ReleaseStringUTFChars(juser_prompt, user_text);
|
||||||
|
|
||||||
|
// Check if context space is enough for desired tokens
|
||||||
|
int desired_budget = current_position + user_tokens.size() + n_len;
|
||||||
|
if (desired_budget > llama_n_ctx(context)) {
|
||||||
|
LOGe("error: total tokens exceed context size");
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
token_predict_budget = desired_budget;
|
||||||
|
|
||||||
|
// Add user prompt tokens to batch
|
||||||
|
common_batch_clear(*batch);
|
||||||
|
for (int i = 0; i < user_tokens.size(); i++) {
|
||||||
|
common_batch_add(*batch, user_tokens[i], current_position + i, {0}, false);
|
||||||
|
}
|
||||||
|
batch->logits[batch->n_tokens - 1] = true; // Set logits true only for last token
|
||||||
|
|
||||||
|
// Decode batch
|
||||||
|
int decode_result = llama_decode(context, *batch);
|
||||||
|
if (decode_result != 0) {
|
||||||
|
LOGe("llama_decode() failed: %d", decode_result);
|
||||||
|
return -2;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update position
|
||||||
|
current_position += user_tokens.size(); // Update position
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool is_valid_utf8(const char *string) {
|
||||||
|
if (!string) { return true; }
|
||||||
|
|
||||||
|
const auto *bytes = (const unsigned char *) string;
|
||||||
|
int num;
|
||||||
|
|
||||||
|
while (*bytes != 0x00) {
|
||||||
|
if ((*bytes & 0x80) == 0x00) {
|
||||||
|
// U+0000 to U+007F
|
||||||
|
num = 1;
|
||||||
|
} else if ((*bytes & 0xE0) == 0xC0) {
|
||||||
|
// U+0080 to U+07FF
|
||||||
|
num = 2;
|
||||||
|
} else if ((*bytes & 0xF0) == 0xE0) {
|
||||||
|
// U+0800 to U+FFFF
|
||||||
|
num = 3;
|
||||||
|
} else if ((*bytes & 0xF8) == 0xF0) {
|
||||||
|
// U+10000 to U+10FFFF
|
||||||
|
num = 4;
|
||||||
|
} else {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
bytes += 1;
|
||||||
|
for (int i = 1; i < num; ++i) {
|
||||||
|
if ((*bytes & 0xC0) != 0x80) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
bytes += 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
extern "C"
|
extern "C"
|
||||||
JNIEXPORT jstring JNICALL
|
JNIEXPORT jstring JNICALL
|
||||||
Java_android_llama_cpp_LLamaAndroid_completion_1loop(
|
Java_android_llama_cpp_LLamaAndroid_predict_1loop(
|
||||||
JNIEnv *env,
|
JNIEnv *env,
|
||||||
jobject,
|
jobject /*unused*/
|
||||||
jlong context_pointer,
|
|
||||||
jlong batch_pointer,
|
|
||||||
jlong sampler_pointer,
|
|
||||||
jint n_len,
|
|
||||||
jobject intvar_ncur
|
|
||||||
) {
|
) {
|
||||||
auto *const context = reinterpret_cast<llama_context *>(context_pointer);
|
// Stop if running out of token budget
|
||||||
auto *const batch = reinterpret_cast<llama_batch *>(batch_pointer);
|
if (current_position >= token_predict_budget) {
|
||||||
auto *const sampler = reinterpret_cast<common_sampler *>(sampler_pointer);
|
LOGi("STOP: current position (%d) exceeds budget (%d)", current_position, token_predict_budget);
|
||||||
const auto *const model = llama_get_model(context);
|
|
||||||
const auto *const vocab = llama_model_get_vocab(model);
|
|
||||||
|
|
||||||
if (!la_int_var) la_int_var = env->GetObjectClass(intvar_ncur);
|
|
||||||
if (!la_int_var_value) la_int_var_value = env->GetMethodID(la_int_var, "getValue", "()I");
|
|
||||||
if (!la_int_var_inc) la_int_var_inc = env->GetMethodID(la_int_var, "inc", "()V");
|
|
||||||
|
|
||||||
// sample the most likely token
|
|
||||||
const auto new_token_id = common_sampler_sample(sampler, context, -1);
|
|
||||||
common_sampler_accept(sampler, new_token_id, /* accept_grammar= */ true);
|
|
||||||
|
|
||||||
const auto n_cur = env->CallIntMethod(intvar_ncur, la_int_var_value);
|
|
||||||
if (llama_vocab_is_eog(vocab, new_token_id) || n_cur == n_len) {
|
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Sample next token
|
||||||
|
const auto new_token_id = common_sampler_sample(sampler, context, -1);
|
||||||
|
common_sampler_accept(sampler, new_token_id, true);
|
||||||
|
|
||||||
|
// Stop if next token is EOG
|
||||||
|
if (llama_vocab_is_eog(llama_model_get_vocab(model), new_token_id)) {
|
||||||
|
LOGi("id: %d,\tIS EOG!\nSTOP.", new_token_id);
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update the context with the new token
|
||||||
|
common_batch_clear(*batch);
|
||||||
|
common_batch_add(*batch, new_token_id, current_position, {0}, true);
|
||||||
|
if (llama_decode(context, *batch) != 0) {
|
||||||
|
LOGe("llama_decode() failed for generated token");
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert to text
|
||||||
auto new_token_chars = common_token_to_piece(context, new_token_id);
|
auto new_token_chars = common_token_to_piece(context, new_token_id);
|
||||||
cached_token_chars += new_token_chars;
|
cached_token_chars += new_token_chars;
|
||||||
|
|
||||||
jstring new_token = nullptr;
|
// Create Java string
|
||||||
|
jstring result = nullptr;
|
||||||
if (is_valid_utf8(cached_token_chars.c_str())) {
|
if (is_valid_utf8(cached_token_chars.c_str())) {
|
||||||
new_token = env->NewStringUTF(cached_token_chars.c_str());
|
result = env->NewStringUTF(cached_token_chars.c_str());
|
||||||
LOGi("cached: %s, new_token_chars: `%s`, id: %d", cached_token_chars.c_str(),
|
LOGd("id: %d,\tcached: `%s`,\tnew: `%s`", new_token_id, cached_token_chars.c_str(), new_token_chars.c_str());
|
||||||
new_token_chars.c_str(), new_token_id);
|
|
||||||
cached_token_chars.clear();
|
cached_token_chars.clear();
|
||||||
} else {
|
} else {
|
||||||
new_token = env->NewStringUTF("");
|
LOGd("id: %d,\tappend to cache", new_token_id);
|
||||||
|
result = env->NewStringUTF("");
|
||||||
}
|
}
|
||||||
|
|
||||||
common_batch_clear(*batch);
|
// Update position
|
||||||
common_batch_add(*batch, new_token_id, n_cur, {0}, true);
|
current_position++;
|
||||||
|
return result;
|
||||||
env->CallVoidMethod(intvar_ncur, la_int_var_inc);
|
|
||||||
|
|
||||||
if (llama_decode(context, *batch) != 0) {
|
|
||||||
LOGe("llama_decode() returned null");
|
|
||||||
}
|
|
||||||
|
|
||||||
return new_token;
|
|
||||||
}
|
|
||||||
|
|
||||||
extern "C"
|
|
||||||
JNIEXPORT void JNICALL
|
|
||||||
Java_android_llama_cpp_LLamaAndroid_kv_1cache_1clear(JNIEnv *, jobject, jlong context) {
|
|
||||||
llama_memory_clear(llama_get_memory(reinterpret_cast<llama_context *>(context)), true);
|
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -24,7 +24,7 @@ class LLamaAndroid {
|
||||||
|
|
||||||
// Set llama log handler to Android
|
// Set llama log handler to Android
|
||||||
log_to_android()
|
log_to_android()
|
||||||
backend_init(false)
|
backend_init()
|
||||||
|
|
||||||
Log.d(tag, system_info())
|
Log.d(tag, system_info())
|
||||||
|
|
||||||
|
|
@ -37,103 +37,96 @@ class LLamaAndroid {
|
||||||
}.asCoroutineDispatcher()
|
}.asCoroutineDispatcher()
|
||||||
|
|
||||||
private external fun log_to_android()
|
private external fun log_to_android()
|
||||||
private external fun load_model(filename: String): Long
|
|
||||||
private external fun free_model(model: Long)
|
|
||||||
private external fun new_context(model: Long): Long
|
|
||||||
private external fun free_context(context: Long)
|
|
||||||
private external fun backend_init(numa: Boolean)
|
|
||||||
private external fun backend_free()
|
|
||||||
private external fun new_batch(nTokens: Int, embd: Int, nSeqMax: Int): Long
|
|
||||||
private external fun free_batch(batch: Long)
|
|
||||||
private external fun new_sampler(model: Long): Long
|
|
||||||
private external fun free_sampler(sampler: Long)
|
|
||||||
private external fun bench_model(
|
|
||||||
context: Long,
|
|
||||||
model: Long,
|
|
||||||
batch: Long,
|
|
||||||
pp: Int,
|
|
||||||
tg: Int,
|
|
||||||
pl: Int,
|
|
||||||
nr: Int
|
|
||||||
): String
|
|
||||||
|
|
||||||
private external fun system_info(): String
|
private external fun system_info(): String
|
||||||
|
private external fun backend_init()
|
||||||
|
|
||||||
private external fun completion_init(
|
private external fun load_model(filename: String): Int
|
||||||
context: Long,
|
private external fun ctx_init(): Int
|
||||||
batch: Long,
|
private external fun clean_up()
|
||||||
text: String,
|
|
||||||
formatChat: Boolean,
|
|
||||||
nLen: Int
|
|
||||||
): Int
|
|
||||||
|
|
||||||
private external fun completion_loop(
|
private external fun bench_model(pp: Int, tg: Int, pl: Int, nr: Int): String
|
||||||
context: Long,
|
|
||||||
batch: Long,
|
|
||||||
sampler: Long,
|
|
||||||
nLen: Int,
|
|
||||||
ncur: IntVar
|
|
||||||
): String?
|
|
||||||
|
|
||||||
private external fun kv_cache_clear(context: Long)
|
private external fun process_system_prompt(system_prompt: String): Int
|
||||||
|
private external fun process_user_prompt(user_prompt: String, nLen: Int): Int
|
||||||
|
private external fun predict_loop(): String?
|
||||||
|
|
||||||
suspend fun bench(pp: Int, tg: Int, pl: Int, nr: Int = 1): String {
|
suspend fun load(pathToModel: String, formattedSystemPrompt: String? = null) {
|
||||||
return withContext(runLoop) {
|
|
||||||
when (val state = threadLocalState.get()) {
|
|
||||||
is State.Loaded -> {
|
|
||||||
Log.d(tag, "bench(): $state")
|
|
||||||
bench_model(state.context, state.model, state.batch, pp, tg, pl, nr)
|
|
||||||
}
|
|
||||||
|
|
||||||
else -> throw IllegalStateException("No model loaded")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
suspend fun load(pathToModel: String) {
|
|
||||||
withContext(runLoop) {
|
withContext(runLoop) {
|
||||||
when (threadLocalState.get()) {
|
when (threadLocalState.get()) {
|
||||||
is State.Idle -> {
|
is State.Idle -> {
|
||||||
val model = load_model(pathToModel)
|
val model = load_model(pathToModel)
|
||||||
if (model == 0L) throw IllegalStateException("load_model() failed")
|
if (model != 0) throw IllegalStateException("Load model failed")
|
||||||
|
|
||||||
val context = new_context(model)
|
val result = ctx_init()
|
||||||
if (context == 0L) throw IllegalStateException("new_context() failed")
|
if (result != 0) throw IllegalStateException("Initialization failed with error code: $result")
|
||||||
|
|
||||||
val batch = new_batch(DEFAULT_BATCH_SIZE, 0, 1)
|
|
||||||
if (batch == 0L) throw IllegalStateException("new_batch() failed")
|
|
||||||
|
|
||||||
val sampler = new_sampler(model)
|
|
||||||
if (sampler == 0L) throw IllegalStateException("new_sampler() failed")
|
|
||||||
|
|
||||||
Log.i(tag, "Loaded model $pathToModel")
|
Log.i(tag, "Loaded model $pathToModel")
|
||||||
threadLocalState.set(State.Loaded(model, context, batch, sampler))
|
threadLocalState.set(State.ModelLoaded)
|
||||||
|
|
||||||
|
formattedSystemPrompt?.let {
|
||||||
|
initWithSystemPrompt(formattedSystemPrompt)
|
||||||
|
} ?: {
|
||||||
|
threadLocalState.set(State.ReadyForUserPrompt)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
else -> throw IllegalStateException("Model already loaded")
|
else -> throw IllegalStateException("Model already loaded")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fun send(
|
suspend fun bench(pp: Int, tg: Int, pl: Int, nr: Int = 1): String {
|
||||||
message: String,
|
return withContext(runLoop) {
|
||||||
formatChat: Boolean = false,
|
when (val state = threadLocalState.get()) {
|
||||||
predictLength: Int = DEFAULT_PREDICT_LENGTH,
|
is State.ModelLoaded -> {
|
||||||
): Flow<String> = flow {
|
Log.d(tag, "bench(): $state")
|
||||||
when (val state = threadLocalState.get()) {
|
bench_model(pp, tg, pl, nr)
|
||||||
is State.Loaded -> {
|
|
||||||
val nCur = IntVar(
|
|
||||||
completion_init(state.context, state.batch, message, formatChat, predictLength)
|
|
||||||
)
|
|
||||||
|
|
||||||
while (nCur.value <= predictLength) {
|
|
||||||
val str = completion_loop(
|
|
||||||
state.context, state.batch, state.sampler, predictLength, nCur
|
|
||||||
) ?: break
|
|
||||||
|
|
||||||
emit(str)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// kv_cache_clear(state.context)
|
// TODO-hyin: catch exception in ViewController; disable button when state incorrect
|
||||||
|
else -> throw IllegalStateException("No model loaded")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private suspend fun initWithSystemPrompt(systemPrompt: String) {
|
||||||
|
withContext(runLoop) {
|
||||||
|
when (threadLocalState.get()) {
|
||||||
|
is State.ModelLoaded -> {
|
||||||
|
process_system_prompt(systemPrompt).let {
|
||||||
|
if (it != 0) {
|
||||||
|
throw IllegalStateException("Failed to process system prompt: $it")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Log.i(tag, "System prompt processed!")
|
||||||
|
threadLocalState.set(State.ReadyForUserPrompt)
|
||||||
|
}
|
||||||
|
else -> throw IllegalStateException("Model not loaded")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fun sendMessage(
|
||||||
|
formattedUserPrompt: String,
|
||||||
|
nPredict: Int = DEFAULT_PREDICT_LENGTH,
|
||||||
|
): Flow<String> = flow {
|
||||||
|
when (threadLocalState.get()) {
|
||||||
|
is State.ReadyForUserPrompt -> {
|
||||||
|
process_user_prompt(formattedUserPrompt, nPredict).let {
|
||||||
|
if (it != 0) {
|
||||||
|
Log.e(tag, "Failed to process user prompt: $it")
|
||||||
|
return@flow
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Log.i(tag, "User prompt processed! Generating assistant prompt...")
|
||||||
|
while (true) {
|
||||||
|
val str = predict_loop() ?: break
|
||||||
|
if (str.isNotEmpty()) {
|
||||||
|
emit(str)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Log.i(tag, "Assistant generation complete!")
|
||||||
}
|
}
|
||||||
else -> {}
|
else -> {}
|
||||||
}
|
}
|
||||||
|
|
@ -146,13 +139,9 @@ class LLamaAndroid {
|
||||||
*/
|
*/
|
||||||
suspend fun unload() {
|
suspend fun unload() {
|
||||||
withContext(runLoop) {
|
withContext(runLoop) {
|
||||||
when (val state = threadLocalState.get()) {
|
when (threadLocalState.get()) {
|
||||||
is State.Loaded -> {
|
is State.ModelLoaded -> {
|
||||||
free_context(state.context)
|
clean_up()
|
||||||
free_model(state.model)
|
|
||||||
free_batch(state.batch)
|
|
||||||
free_sampler(state.sampler);
|
|
||||||
|
|
||||||
threadLocalState.set(State.Idle)
|
threadLocalState.set(State.Idle)
|
||||||
}
|
}
|
||||||
else -> {}
|
else -> {}
|
||||||
|
|
@ -161,24 +150,12 @@ class LLamaAndroid {
|
||||||
}
|
}
|
||||||
|
|
||||||
companion object {
|
companion object {
|
||||||
private const val DEFAULT_BATCH_SIZE = 512
|
|
||||||
private const val DEFAULT_PREDICT_LENGTH = 128
|
private const val DEFAULT_PREDICT_LENGTH = 128
|
||||||
|
|
||||||
private class IntVar(value: Int) {
|
|
||||||
@Volatile
|
|
||||||
var value: Int = value
|
|
||||||
private set
|
|
||||||
|
|
||||||
fun inc() {
|
|
||||||
synchronized(this) {
|
|
||||||
value += 1
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private sealed interface State {
|
private sealed interface State {
|
||||||
data object Idle: State
|
data object Idle: State
|
||||||
data class Loaded(val model: Long, val context: Long, val batch: Long, val sampler: Long): State
|
data object ModelLoaded: State
|
||||||
|
data object ReadyForUserPrompt: State
|
||||||
}
|
}
|
||||||
|
|
||||||
// Enforce only one instance of Llm.
|
// Enforce only one instance of Llm.
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue