diff --git a/examples/llama.android/llama/src/main/cpp/llama-android.cpp b/examples/llama.android/llama/src/main/cpp/llama-android.cpp
index 711ddc5d19..c1c8fa380f 100644
--- a/examples/llama.android/llama/src/main/cpp/llama-android.cpp
+++ b/examples/llama.android/llama/src/main/cpp/llama-android.cpp
@@ -1,46 +1,31 @@
 #include <android/log.h>
 #include <jni.h>
 #include <iomanip>
-#include <math.h>
+#include <cmath>
 #include <string>
 #include <unistd.h>
 #include "llama.h"
 #include "common.h"
 
-// Write C++ code here.
-//
-// Do not forget to dynamically load the C++ library into your application.
-//
-// For instance,
-//
-// In MainActivity.java:
-//    static {
-//       System.loadLibrary("llama-android");
-//    }
-//
-// Or, in MainActivity.kt:
-//    companion object {
-//      init {
-//          System.loadLibrary("llama-android")
-//      }
-//    }
-
 #define TAG "llama-android.cpp"
 #define LOGi(...) __android_log_print(ANDROID_LOG_INFO, TAG, __VA_ARGS__)
 #define LOGe(...) __android_log_print(ANDROID_LOG_ERROR, TAG, __VA_ARGS__)
 
+constexpr int CONTEXT_SIZE = 2048;
+constexpr int N_THREADS_MIN = 1;
+constexpr int N_THREADS_MAX = 8;
+constexpr int N_THREADS_HEADROOM = 2;
+
 jclass la_int_var;
 jmethodID la_int_var_value;
 jmethodID la_int_var_inc;
 
 std::string cached_token_chars;
 
-bool is_valid_utf8(const char * string) {
-    if (!string) {
-        return true;
-    }
+bool is_valid_utf8(const char *string) {
+    if (!string) { return true; }
 
-    const unsigned char * bytes = (const unsigned char *)string;
+    const auto *bytes = (const unsigned char *) string;
     int num;
 
     while (*bytes != 0x00) {
@@ -72,11 +57,26 @@ bool is_valid_utf8(const char * string) {
     return true;
 }
 
-static void log_callback(ggml_log_level level, const char * fmt, void * data) {
-    if (level == GGML_LOG_LEVEL_ERROR)     __android_log_print(ANDROID_LOG_ERROR, TAG, fmt, data);
-    else if (level == GGML_LOG_LEVEL_INFO) __android_log_print(ANDROID_LOG_INFO, TAG, fmt, data);
-    else if (level == GGML_LOG_LEVEL_WARN) __android_log_print(ANDROID_LOG_WARN, TAG, fmt, data);
-    else __android_log_print(ANDROID_LOG_DEFAULT, TAG, fmt, data);
+static void log_callback(ggml_log_level level, const char *fmt, void *data) {
+    int priority;
+    switch (level) {
+        case GGML_LOG_LEVEL_ERROR:
+            priority = ANDROID_LOG_ERROR;
+            break;
+        case GGML_LOG_LEVEL_WARN:
+            priority = ANDROID_LOG_WARN;
+            break;
+        case GGML_LOG_LEVEL_INFO:
+            priority = ANDROID_LOG_INFO;
+            break;
+        case GGML_LOG_LEVEL_DEBUG:
+            priority = ANDROID_LOG_DEBUG;
+            break;
+        default:
+            priority = ANDROID_LOG_DEFAULT;
+            break;
+    }
+    __android_log_print(priority, TAG, fmt, data);
 }
 
 extern "C"
@@ -116,16 +116,18 @@ Java_android_llama_cpp_LLamaAndroid_new_1context(JNIEnv *env, jobject, jlong jmo
         return 0;
     }
 
-    int n_threads = std::max(1, std::min(8, (int) sysconf(_SC_NPROCESSORS_ONLN) - 2));
+    int n_threads = std::max(N_THREADS_MIN, std::min(N_THREADS_MAX,
+                                                     (int) sysconf(_SC_NPROCESSORS_ONLN) -
+                                                     N_THREADS_HEADROOM));
     LOGi("Using %d threads", n_threads);
 
     llama_context_params ctx_params = llama_context_default_params();
 
-    ctx_params.n_ctx           = 2048;
-    ctx_params.n_threads       = n_threads;
+    ctx_params.n_ctx = CONTEXT_SIZE;
+    ctx_params.n_threads = n_threads;
     ctx_params.n_threads_batch = n_threads;
 
-    llama_context * context = llama_new_context_with_model(model, ctx_params);
+    llama_context *context = llama_init_from_model(model, ctx_params);
 
     if (!context) {
         LOGe("llama_new_context_with_model() returned null)");
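
Note on the thread heuristic in the hunk above: the clamp reserves N_THREADS_HEADROOM cores (presumably for the UI and system) while keeping the count inside [N_THREADS_MIN, N_THREADS_MAX]. A minimal standalone sketch of the same rule, illustrative only and not part of the patch:

    #include <algorithm>

    // Same clamp as new_1context(): leave two cores free, stay within [1, 8].
    int pick_n_threads(int n_online_cores) {
        return std::max(1, std::min(8, n_online_cores - 2));
    }
    // pick_n_threads(8) == 6, pick_n_threads(2) == 1, pick_n_threads(16) == 8
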
@@ -152,7 +154,7 @@ Java_android_llama_cpp_LLamaAndroid_backend_1free(JNIEnv *, jobject) {
 extern "C"
 JNIEXPORT void JNICALL
 Java_android_llama_cpp_LLamaAndroid_log_1to_1android(JNIEnv *, jobject) {
-    llama_log_set(log_callback, NULL);
+    llama_log_set(log_callback, nullptr);
 }
 
 extern "C"
@@ -167,17 +169,17 @@ Java_android_llama_cpp_LLamaAndroid_bench_1model(
         jint tg,
         jint pl,
         jint nr
-    ) {
+) {
     auto pp_avg = 0.0;
     auto tg_avg = 0.0;
     auto pp_std = 0.0;
     auto tg_std = 0.0;
 
-    const auto context = reinterpret_cast<llama_context *>(context_pointer);
-    const auto model = reinterpret_cast<llama_model *>(model_pointer);
-    const auto batch = reinterpret_cast<llama_batch *>(batch_pointer);
+    auto *const context = reinterpret_cast<llama_context *>(context_pointer);
+    auto *const model = reinterpret_cast<llama_model *>(model_pointer);
+    auto *const batch = reinterpret_cast<llama_batch *>(batch_pointer);
 
-    const int n_ctx = llama_n_ctx(context);
+    const uint32_t n_ctx = llama_n_ctx(context);
 
     LOGi("n_ctx = %d", n_ctx);
 
@@ -190,7 +192,7 @@ Java_android_llama_cpp_LLamaAndroid_bench_1model(
 
         const int n_tokens = pp;
         for (i = 0; i < n_tokens; i++) {
-            common_batch_add(*batch, 0, i, { 0 }, false);
+            common_batch_add(*batch, 0, i, {0}, false);
         }
 
         batch->logits[batch->n_tokens - 1] = true;
@@ -212,7 +214,7 @@ Java_android_llama_cpp_LLamaAndroid_bench_1model(
         common_batch_clear(*batch);
 
         for (j = 0; j < pl; j++) {
-            common_batch_add(*batch, 0, i, { j }, true);
+            common_batch_add(*batch, 0, i, {j}, true);
         }
 
         LOGi("llama_decode() text generation: %d", i);
@@ -254,35 +256,37 @@ Java_android_llama_cpp_LLamaAndroid_bench_1model(
     char model_desc[128];
     llama_model_desc(model, model_desc, sizeof(model_desc));
 
-    const auto model_size     = double(llama_model_size(model)) / 1024.0 / 1024.0 / 1024.0;
+    const auto model_size = double(llama_model_size(model)) / 1024.0 / 1024.0 / 1024.0;
     const auto model_n_params = double(llama_model_n_params(model)) / 1e9;
 
-    const auto backend    = "(Android)"; // TODO: What should this be?
+    const auto *const backend = "(Android)"; // TODO: What should this be?
 
     std::stringstream result;
     result << std::setprecision(2);
     result << "| model | size | params | backend | test | t/s |\n";
     result << "| --- | --- | --- | --- | --- | --- |\n";
-    result << "| " << model_desc << " | " << model_size << "GiB | " << model_n_params << "B | " << backend << " | pp " << pp << " | " << pp_avg << " ± " << pp_std << " |\n";
-    result << "| " << model_desc << " | " << model_size << "GiB | " << model_n_params << "B | " << backend << " | tg " << tg << " | " << tg_avg << " ± " << tg_std << " |\n";
+    result << "| " << model_desc << " | " << model_size << "GiB | " << model_n_params << "B | "
+           << backend << " | pp " << pp << " | " << pp_avg << " ± " << pp_std << " |\n";
+    result << "| " << model_desc << " | " << model_size << "GiB | " << model_n_params << "B | "
+           << backend << " | tg " << tg << " | " << tg_avg << " ± " << tg_std << " |\n";
 
     return env->NewStringUTF(result.str().c_str());
 }
 
 extern "C"
 JNIEXPORT jlong JNICALL
-Java_android_llama_cpp_LLamaAndroid_new_1batch(JNIEnv *, jobject, jint n_tokens, jint embd, jint n_seq_max) {
-
+Java_android_llama_cpp_LLamaAndroid_new_1batch(JNIEnv *, jobject, jint n_tokens, jint embd,
+                                               jint n_seq_max) {
     // Source: Copy of llama.cpp:llama_batch_init but heap-allocated.
-    llama_batch *batch = new llama_batch {
-        0,
-        nullptr,
-        nullptr,
-        nullptr,
-        nullptr,
-        nullptr,
-        nullptr,
+    auto *batch = new llama_batch{
+            0,
+            nullptr,
+            nullptr,
+            nullptr,
+            nullptr,
+            nullptr,
+            nullptr,
     };
 
     if (embd) {
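
Note: the seven positional initializers in the `new llama_batch{...}` above map onto llama_batch's fields in declaration order. A sketch with the field names spelled out, assuming the current llama.h layout; not part of the patch:

    #include "llama.h"

    // Positional aggregate initialization of llama_batch, annotated per field.
    auto *annotated_batch = new llama_batch{
        /* n_tokens */ 0,
        /* token    */ nullptr,
        /* embd     */ nullptr,
        /* pos      */ nullptr,
        /* n_seq_id */ nullptr,
        /* seq_id   */ nullptr,
        /* logits   */ nullptr,
    };
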
@@ -291,13 +295,13 @@ Java_android_llama_cpp_LLamaAndroid_new_1batch(JNIEnv *, jobject, jint n_tokens,
         batch->token = (llama_token *) malloc(sizeof(llama_token) * n_tokens);
     }
 
-    batch->pos      = (llama_pos *) malloc(sizeof(llama_pos) * n_tokens);
-    batch->n_seq_id = (int32_t *) malloc(sizeof(int32_t) * n_tokens);
-    batch->seq_id   = (llama_seq_id **) malloc(sizeof(llama_seq_id *) * n_tokens);
+    batch->pos = (llama_pos *) malloc(sizeof(llama_pos) * n_tokens);
+    batch->n_seq_id = (int32_t *) malloc(sizeof(int32_t) * n_tokens);
+    batch->seq_id = (llama_seq_id **) malloc(sizeof(llama_seq_id *) * n_tokens);
     for (int i = 0; i < n_tokens; ++i) {
         batch->seq_id[i] = (llama_seq_id *) malloc(sizeof(llama_seq_id) * n_seq_max);
     }
-    batch->logits   = (int8_t *) malloc(sizeof(int8_t) * n_tokens);
+    batch->logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens);
 
     return reinterpret_cast<jlong>(batch);
 }
 
@@ -305,8 +309,8 @@ Java_android_llama_cpp_LLamaAndroid_new_1batch(JNIEnv *, jobject, jint n_tokens,
 extern "C"
 JNIEXPORT void JNICALL
 Java_android_llama_cpp_LLamaAndroid_free_1batch(JNIEnv *, jobject, jlong batch_pointer) {
-    //llama_batch_free(*reinterpret_cast<llama_batch *>(batch_pointer));
-    const auto batch = reinterpret_cast<llama_batch *>(batch_pointer);
+    //llama_batch_free(*reinterpret_cast<llama_batch *>(batch_pointer)); // TODO: what is this?
+    auto *const batch = reinterpret_cast<llama_batch *>(batch_pointer);
     delete batch;
 }
 
@@ -315,7 +319,7 @@ JNIEXPORT jlong JNICALL
 Java_android_llama_cpp_LLamaAndroid_new_1sampler(JNIEnv *, jobject) {
     auto sparams = llama_sampler_chain_default_params();
     sparams.no_perf = true;
-    llama_sampler * smpl = llama_sampler_chain_init(sparams);
+    llama_sampler *smpl = llama_sampler_chain_init(sparams);
     llama_sampler_chain_add(smpl, llama_sampler_init_greedy());
 
     return reinterpret_cast<jlong>(smpl);
@@ -324,7 +328,9 @@ Java_android_llama_cpp_LLamaAndroid_new_1sampler(JNIEnv *, jobject) {
 extern "C"
 JNIEXPORT void JNICALL
 Java_android_llama_cpp_LLamaAndroid_free_1sampler(JNIEnv *, jobject, jlong sampler_pointer) {
-    llama_sampler_free(reinterpret_cast<llama_sampler *>(sampler_pointer));
+    // Cast the jlong handle back to its pointer type before freeing.
+    auto *const sampler = reinterpret_cast<llama_sampler *>(sampler_pointer);
+    llama_sampler_free(sampler);
 }
 
 extern "C"
@@ -349,13 +355,12 @@ Java_android_llama_cpp_LLamaAndroid_completion_1init(
         jstring jtext,
         jboolean format_chat,
         jint n_len
-    ) {
-
+) {
     cached_token_chars.clear();
 
-    const auto text = env->GetStringUTFChars(jtext, 0);
-    const auto context = reinterpret_cast<llama_context *>(context_pointer);
-    const auto batch = reinterpret_cast<llama_batch *>(batch_pointer);
+    const auto *const text = env->GetStringUTFChars(jtext, 0);
+    auto *const context = reinterpret_cast<llama_context *>(context_pointer);
+    auto *const batch = reinterpret_cast<llama_batch *>(batch_pointer);
 
     bool parse_special = (format_chat == JNI_TRUE);
     const auto tokens_list = common_tokenize(context, text, true, parse_special);
@@ -369,7 +374,7 @@ Java_android_llama_cpp_LLamaAndroid_completion_1init(
         LOGe("error: n_kv_req > n_ctx, the required KV cache size is not big enough");
     }
 
-    for (auto id : tokens_list) {
+    for (auto id: tokens_list) {
         LOGi("token: `%s`-> %d ", common_token_to_piece(context, id).c_str(), id);
     }
 
@@ -377,7 +382,7 @@ Java_android_llama_cpp_LLamaAndroid_completion_1init(
 
     // evaluate the initial prompt
     for (auto i = 0; i < tokens_list.size(); i++) {
-        common_batch_add(*batch, tokens_list[i], i, { 0 }, false);
+        common_batch_add(*batch, tokens_list[i], i, {0}, false);
     }
 
     // llama_decode will output logits only for the last token of the prompt
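
Note: the n_kv_req > n_ctx guard in completion_init means the prompt plus the requested completion must fit in the context window. Back-of-the-envelope arithmetic under that assumption, using the defaults this patch introduces (illustrative, not part of the patch):

    // kContextSize mirrors CONTEXT_SIZE (llama-android.cpp); kPredictLength
    // mirrors DEFAULT_PREDICT_LENGTH (LLamaAndroid.kt).
    constexpr int kContextSize = 2048;
    constexpr int kPredictLength = 128;
    constexpr int kMaxPromptTokens = kContextSize - kPredictLength;
    static_assert(kMaxPromptTokens == 1920, "prompt budget under default settings");
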
@@ -395,7 +400,7 @@ Java_android_llama_cpp_LLamaAndroid_completion_1init(
 extern "C"
 JNIEXPORT jstring JNICALL
 Java_android_llama_cpp_LLamaAndroid_completion_1loop(
-        JNIEnv * env,
+        JNIEnv *env,
         jobject,
         jlong context_pointer,
         jlong batch_pointer,
@@ -403,11 +408,11 @@ Java_android_llama_cpp_LLamaAndroid_completion_1loop(
         jint n_len,
         jobject intvar_ncur
 ) {
-    const auto context = reinterpret_cast<llama_context *>(context_pointer);
-    const auto batch = reinterpret_cast<llama_batch *>(batch_pointer);
-    const auto sampler = reinterpret_cast<llama_sampler *>(sampler_pointer);
-    const auto model = llama_get_model(context);
-    const auto vocab = llama_model_get_vocab(model);
+    auto *const context = reinterpret_cast<llama_context *>(context_pointer);
+    auto *const batch = reinterpret_cast<llama_batch *>(batch_pointer);
+    auto *const sampler = reinterpret_cast<llama_sampler *>(sampler_pointer);
+    const auto *const model = llama_get_model(context);
+    const auto *const vocab = llama_model_get_vocab(model);
 
     if (!la_int_var) la_int_var = env->GetObjectClass(intvar_ncur);
     if (!la_int_var_value) la_int_var_value = env->GetMethodID(la_int_var, "getValue", "()I");
@@ -427,14 +432,15 @@ Java_android_llama_cpp_LLamaAndroid_completion_1loop(
     jstring new_token = nullptr;
     if (is_valid_utf8(cached_token_chars.c_str())) {
         new_token = env->NewStringUTF(cached_token_chars.c_str());
-        LOGi("cached: %s, new_token_chars: `%s`, id: %d", cached_token_chars.c_str(), new_token_chars.c_str(), new_token_id);
+        LOGi("cached: %s, new_token_chars: `%s`, id: %d", cached_token_chars.c_str(),
+             new_token_chars.c_str(), new_token_id);
         cached_token_chars.clear();
     } else {
         new_token = env->NewStringUTF("");
     }
 
     common_batch_clear(*batch);
-    common_batch_add(*batch, new_token_id, n_cur, { 0 }, true);
+    common_batch_add(*batch, new_token_id, n_cur, {0}, true);
 
     env->CallVoidMethod(intvar_ncur, la_int_var_inc);
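
Note: on both sides of the JNI boundary this patch keeps native objects alive as jlong handles (model, context, batch, sampler). A minimal self-contained sketch of that round-trip pattern; names are illustrative, not from the patch:

    #include <jni.h>

    struct native_state { int value; };

    // The Kotlin side stores the returned jlong and passes it back verbatim.
    extern "C" jlong create_handle() {
        return reinterpret_cast<jlong>(new native_state{42});
    }

    extern "C" void destroy_handle(jlong handle) {
        delete reinterpret_cast<native_state *>(handle);
    }
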
diff --git a/examples/llama.android/llama/src/main/java/android/llama/cpp/LLamaAndroid.kt b/examples/llama.android/llama/src/main/java/android/llama/cpp/LLamaAndroid.kt
index b964d93e37..95a8177426 100644
--- a/examples/llama.android/llama/src/main/java/android/llama/cpp/LLamaAndroid.kt
+++ b/examples/llama.android/llama/src/main/java/android/llama/cpp/LLamaAndroid.kt
@@ -36,8 +36,6 @@ class LLamaAndroid {
         }
     }.asCoroutineDispatcher()
 
-    private val nlen: Int = 64
-
     private external fun log_to_android()
     private external fun load_model(filename: String): Long
     private external fun free_model(model: Long)
@@ -102,7 +100,7 @@ class LLamaAndroid {
                     val context = new_context(model)
                     if (context == 0L) throw IllegalStateException("new_context() failed")
 
-                    val batch = new_batch(512, 0, 1)
+                    val batch = new_batch(DEFAULT_BATCH_SIZE, 0, 1)
                     if (batch == 0L) throw IllegalStateException("new_batch() failed")
 
                     val sampler = new_sampler()
@@ -116,17 +114,25 @@ class LLamaAndroid {
         }
     }
 
-    fun send(message: String, formatChat: Boolean = false): Flow<String> = flow {
+    fun send(
+        message: String,
+        formatChat: Boolean = false,
+        predictLength: Int = DEFAULT_PREDICT_LENGTH,
+    ): Flow<String> = flow {
         when (val state = threadLocalState.get()) {
             is State.Loaded -> {
-                val ncur = IntVar(completion_init(state.context, state.batch, message, formatChat, nlen))
-                while (ncur.value <= nlen) {
-                    val str = completion_loop(state.context, state.batch, state.sampler, nlen, ncur)
-                    if (str == null) {
-                        break
-                    }
+                val nCur = IntVar(
+                    completion_init(state.context, state.batch, message, formatChat, predictLength)
+                )
+
+                while (nCur.value <= predictLength) {
+                    val str = completion_loop(
+                        state.context, state.batch, state.sampler, predictLength, nCur
+                    ) ?: break
+
                     emit(str)
                 }
+
                 kv_cache_clear(state.context)
             }
             else -> {}
@@ -155,6 +161,9 @@ class LLamaAndroid {
     }
 
     companion object {
+        private const val DEFAULT_BATCH_SIZE = 512
+        private const val DEFAULT_PREDICT_LENGTH = 128
+
         private class IntVar(value: Int) {
             @Volatile
             var value: Int = value
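
Note: from the caller's perspective the only API change is the new optional predictLength parameter on send(). A hypothetical Kotlin call site (prompt and names are illustrative; assumes load() has already completed on this instance):

    import kotlinx.coroutines.flow.collect

    suspend fun generate(llm: LLamaAndroid) {
        // Request up to 256 tokens instead of DEFAULT_PREDICT_LENGTH (128).
        llm.send("Why is the sky blue?", predictLength = 256)
            .collect { piece -> print(piece) }
    }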