diff --git a/examples/llama.android/app/src/main/java/com/example/llama/MainViewModel.kt b/examples/llama.android/app/src/main/java/com/example/llama/MainViewModel.kt index 45ac29938f..c7ffa3f77c 100644 --- a/examples/llama.android/app/src/main/java/com/example/llama/MainViewModel.kt +++ b/examples/llama.android/app/src/main/java/com/example/llama/MainViewModel.kt @@ -45,7 +45,7 @@ class MainViewModel(private val llamaAndroid: LLamaAndroid = LLamaAndroid.instan messages += "" viewModelScope.launch { - llamaAndroid.send(text) + llamaAndroid.sendMessage(text) .catch { Log.e(tag, "send() failed", it) messages += it.message!! diff --git a/examples/llama.android/llama/src/main/cpp/llama-android.cpp b/examples/llama.android/llama/src/main/cpp/llama-android.cpp index 458fdbdc60..109e630727 100644 --- a/examples/llama.android/llama/src/main/cpp/llama-android.cpp +++ b/examples/llama.android/llama/src/main/cpp/llama-android.cpp @@ -8,57 +8,30 @@ #include "llama.h" #include "common.h" +/** + * Logging utils + */ #define TAG "llama-android.cpp" +#define LOGd(...) __android_log_print(ANDROID_LOG_DEBUG, TAG, __VA_ARGS__) #define LOGi(...) __android_log_print(ANDROID_LOG_INFO, TAG, __VA_ARGS__) +#define LOGw(...) __android_log_print(ANDROID_LOG_WARN, TAG, __VA_ARGS__) #define LOGe(...) __android_log_print(ANDROID_LOG_ERROR, TAG, __VA_ARGS__) -constexpr int CONTEXT_SIZE = 2048; +/** + * LLama resources: context, model, batch and sampler + */ +constexpr int N_THREADS_MIN = 1; +constexpr int N_THREADS_MAX = 8; +constexpr int N_THREADS_HEADROOM = 2; + +constexpr int CONTEXT_SIZE = 4096; +constexpr int BATCH_SIZE = 512; constexpr float SAMPLER_TEMP = 0.3f; -constexpr int N_THREADS_MIN = 1; -constexpr int N_THREADS_MAX = 8; -constexpr int N_THREADS_HEADROOM = 2; - -jclass la_int_var; -jmethodID la_int_var_value; -jmethodID la_int_var_inc; - -std::string cached_token_chars; - -bool is_valid_utf8(const char *string) { - if (!string) { return true; } - - const auto *bytes = (const unsigned char *) string; - int num; - - while (*bytes != 0x00) { - if ((*bytes & 0x80) == 0x00) { - // U+0000 to U+007F - num = 1; - } else if ((*bytes & 0xE0) == 0xC0) { - // U+0080 to U+07FF - num = 2; - } else if ((*bytes & 0xF0) == 0xE0) { - // U+0800 to U+FFFF - num = 3; - } else if ((*bytes & 0xF8) == 0xF0) { - // U+10000 to U+10FFFF - num = 4; - } else { - return false; - } - - bytes += 1; - for (int i = 1; i < num; ++i) { - if ((*bytes & 0xC0) != 0x80) { - return false; - } - bytes += 1; - } - } - - return true; -} +llama_model * model; +llama_context * context; +llama_batch * batch; +common_sampler * sampler; static void log_callback(ggml_log_level level, const char *fmt, void *data) { int priority; @@ -83,105 +56,129 @@ static void log_callback(ggml_log_level level, const char *fmt, void *data) { } extern "C" -JNIEXPORT jlong JNICALL -Java_android_llama_cpp_LLamaAndroid_load_1model(JNIEnv *env, jobject, jstring filename) { - llama_model_params model_params = llama_model_default_params(); +JNIEXPORT void JNICALL +Java_android_llama_cpp_LLamaAndroid_log_1to_1android(JNIEnv * /*unused*/, jobject /*unused*/) { + llama_log_set(log_callback, nullptr); +} - auto path_to_model = env->GetStringUTFChars(filename, 0); - LOGi("Loading model from %s", path_to_model); - auto model = llama_model_load_from_file(path_to_model, model_params); - env->ReleaseStringUTFChars(filename, path_to_model); - - if (!model) { - LOGe("load_model() failed"); - env->ThrowNew(env->FindClass("java/lang/IllegalStateException"), "load_model() failed"); - return 0; - } - - return reinterpret_cast(model); +extern "C" +JNIEXPORT jstring JNICALL +Java_android_llama_cpp_LLamaAndroid_system_1info(JNIEnv *env, jobject /*unused*/) { + return env->NewStringUTF(llama_print_system_info()); } extern "C" JNIEXPORT void JNICALL -Java_android_llama_cpp_LLamaAndroid_free_1model(JNIEnv *, jobject, jlong model) { - llama_model_free(reinterpret_cast(model)); +Java_android_llama_cpp_LLamaAndroid_backend_1init(JNIEnv * /*unused*/, jobject /*unused*/) { + llama_backend_init(); } extern "C" -JNIEXPORT jlong JNICALL -Java_android_llama_cpp_LLamaAndroid_new_1context(JNIEnv *env, jobject, jlong jmodel) { - auto model = reinterpret_cast(jmodel); +JNIEXPORT jint JNICALL +Java_android_llama_cpp_LLamaAndroid_load_1model(JNIEnv *env, jobject, jstring filename) { + llama_model_params model_params = llama_model_default_params(); + + const auto *path_to_model = env->GetStringUTFChars(filename, 0); + LOGi("Loading model from: %s", path_to_model); + + model = llama_model_load_from_file(path_to_model, model_params); + env->ReleaseStringUTFChars(filename, path_to_model); if (!model) { - LOGe("new_context(): model cannot be null"); - env->ThrowNew(env->FindClass("java/lang/IllegalArgumentException"), "Model cannot be null"); - return 0; + LOGe("load_model() failed"); + return -1; + } + return 0; +} + +int init_context() { + if (!model) { + LOGe("init_context(): model cannot be null"); + return -1; } + // Multi-threading setup int n_threads = std::max(N_THREADS_MIN, std::min(N_THREADS_MAX, (int) sysconf(_SC_NPROCESSORS_ONLN) - N_THREADS_HEADROOM)); LOGi("Using %d threads", n_threads); + // Context parameters setup llama_context_params ctx_params = llama_context_default_params(); - ctx_params.n_ctx = CONTEXT_SIZE; ctx_params.n_threads = n_threads; ctx_params.n_threads_batch = n_threads; - llama_context *context = llama_init_from_model(model, ctx_params); - + context = llama_init_from_model(model, ctx_params); if (!context) { LOGe("llama_new_context_with_model() returned null)"); - env->ThrowNew(env->FindClass("java/lang/IllegalStateException"), - "llama_new_context_with_model() returned null)"); - return 0; + return -2; + } + return 0; +} + +void new_batch(int n_tokens, bool embd = false, int n_seq_max = 1) { + // Source: Copy of llama.cpp:llama_batch_init but heap-allocated. + batch = new llama_batch{ + 0, + nullptr, + nullptr, + nullptr, + nullptr, + nullptr, + nullptr, + }; + + if (embd) { + batch->embd = (float *) malloc(sizeof(float) * n_tokens * embd); + } else { + batch->token = (llama_token *) malloc(sizeof(llama_token) * n_tokens); } - return reinterpret_cast(context); + batch->pos = (llama_pos *) malloc(sizeof(llama_pos) * n_tokens); + batch->n_seq_id = (int32_t *) malloc(sizeof(int32_t) * n_tokens); + batch->seq_id = (llama_seq_id **) malloc(sizeof(llama_seq_id *) * n_tokens); + for (int i = 0; i < n_tokens; ++i) { + batch->seq_id[i] = (llama_seq_id *) malloc(sizeof(llama_seq_id) * n_seq_max); + } + batch->logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens); +} + +void new_sampler(float temp) { + common_params_sampling sparams; + sparams.temp = temp; + sampler = common_sampler_init(model, sparams); +} + +extern "C" +JNIEXPORT jint JNICALL +Java_android_llama_cpp_LLamaAndroid_ctx_1init(JNIEnv * /*env*/, jobject /*unused*/) { + int ret = init_context(); + if (ret != 0) { return ret; } + new_batch(BATCH_SIZE); + new_sampler(SAMPLER_TEMP); + return 0; } extern "C" JNIEXPORT void JNICALL -Java_android_llama_cpp_LLamaAndroid_free_1context(JNIEnv *, jobject, jlong context) { - llama_free(reinterpret_cast(context)); -} - -extern "C" -JNIEXPORT void JNICALL -Java_android_llama_cpp_LLamaAndroid_backend_1free(JNIEnv *, jobject) { +Java_android_llama_cpp_LLamaAndroid_clean_1up(JNIEnv * /*unused*/, jobject /*unused*/) { + llama_model_free(model); + llama_free(context); llama_backend_free(); -} - -extern "C" -JNIEXPORT void JNICALL -Java_android_llama_cpp_LLamaAndroid_log_1to_1android(JNIEnv *, jobject) { - llama_log_set(log_callback, nullptr); + delete batch; + common_sampler_free(sampler); } extern "C" JNIEXPORT jstring JNICALL -Java_android_llama_cpp_LLamaAndroid_bench_1model( - JNIEnv *env, - jobject, - jlong context_pointer, - jlong model_pointer, - jlong batch_pointer, - jint pp, - jint tg, - jint pl, - jint nr -) { +Java_android_llama_cpp_LLamaAndroid_bench_1model(JNIEnv *env, jobject /*unused*/, jint pp, jint tg, jint pl, jint nr) { auto pp_avg = 0.0; auto tg_avg = 0.0; auto pp_std = 0.0; auto tg_std = 0.0; - auto *const context = reinterpret_cast(context_pointer); - auto *const model = reinterpret_cast(model_pointer); - auto *const batch = reinterpret_cast(batch_pointer); - const uint32_t n_ctx = llama_n_ctx(context); LOGi("n_ctx = %d", n_ctx); @@ -203,7 +200,7 @@ Java_android_llama_cpp_LLamaAndroid_bench_1model( const auto t_pp_start = ggml_time_us(); if (llama_decode(context, *batch) != 0) { - LOGi("llama_decode() failed during prompt processing"); + LOGw("llama_decode() failed during prompt processing"); } const auto t_pp_end = ggml_time_us(); @@ -222,7 +219,7 @@ Java_android_llama_cpp_LLamaAndroid_bench_1model( LOGi("llama_decode() text generation: %d", i); if (llama_decode(context, *batch) != 0) { - LOGi("llama_decode() failed during text generation"); + LOGw("llama_decode() failed during text generation"); } } @@ -276,187 +273,181 @@ Java_android_llama_cpp_LLamaAndroid_bench_1model( return env->NewStringUTF(result.str().c_str()); } -extern "C" -JNIEXPORT jlong JNICALL -Java_android_llama_cpp_LLamaAndroid_new_1batch(JNIEnv *, jobject, jint n_tokens, jint embd, - jint n_seq_max) { - // Source: Copy of llama.cpp:llama_batch_init but heap-allocated. - auto *batch = new llama_batch{ - 0, - nullptr, - nullptr, - nullptr, - nullptr, - nullptr, - nullptr, - }; +/** + * Prediction loop's states + */ +int current_position; - if (embd) { - batch->embd = (float *) malloc(sizeof(float) * n_tokens * embd); - } else { - batch->token = (llama_token *) malloc(sizeof(llama_token) * n_tokens); - } - - batch->pos = (llama_pos *) malloc(sizeof(llama_pos) * n_tokens); - batch->n_seq_id = (int32_t *) malloc(sizeof(int32_t) * n_tokens); - batch->seq_id = (llama_seq_id **) malloc(sizeof(llama_seq_id *) * n_tokens); - for (int i = 0; i < n_tokens; ++i) { - batch->seq_id[i] = (llama_seq_id *) malloc(sizeof(llama_seq_id) * n_seq_max); - } - batch->logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens); - - return reinterpret_cast(batch); -} - -extern "C" -JNIEXPORT void JNICALL -Java_android_llama_cpp_LLamaAndroid_free_1batch(JNIEnv *, jobject, jlong batch_pointer) { - //llama_batch_free(*reinterpret_cast(batch_pointer)); // TODO: what is this? - auto *const batch = reinterpret_cast(batch_pointer); - delete batch; -} - -extern "C" -JNIEXPORT jlong JNICALL -Java_android_llama_cpp_LLamaAndroid_new_1sampler(JNIEnv *, jobject, jlong model_pointer) { - auto *const model = reinterpret_cast(model_pointer); - common_params_sampling sparams; - sparams.temp = SAMPLER_TEMP; - common_sampler *smpl = common_sampler_init(model, sparams); - - return reinterpret_cast(smpl); -} - -extern "C" -JNIEXPORT void JNICALL -Java_android_llama_cpp_LLamaAndroid_free_1sampler(JNIEnv *, jobject, jlong sampler_pointer) { - // Properly cast from jlong to pointer type - auto* sampler = (llama_sampler*)(void*)(sampler_pointer); - llama_sampler_free(sampler); -} - -extern "C" -JNIEXPORT void JNICALL -Java_android_llama_cpp_LLamaAndroid_backend_1init(JNIEnv *, jobject) { - llama_backend_init(); -} - -extern "C" -JNIEXPORT jstring JNICALL -Java_android_llama_cpp_LLamaAndroid_system_1info(JNIEnv *env, jobject) { - return env->NewStringUTF(llama_print_system_info()); -} +int token_predict_budget; +std::string cached_token_chars; extern "C" JNIEXPORT jint JNICALL -Java_android_llama_cpp_LLamaAndroid_completion_1init( +Java_android_llama_cpp_LLamaAndroid_process_1system_1prompt( JNIEnv *env, - jobject, - jlong context_pointer, - jlong batch_pointer, - jstring jtext, - jboolean format_chat, - jint n_len + jobject /*unused*/, + jstring jsystem_prompt ) { + // Reset long-term states and reset KV cache + current_position = 0; + llama_memory_clear(llama_get_memory(context), false); + + // Reset short-term states + token_predict_budget = 0; cached_token_chars.clear(); - const auto *const text = env->GetStringUTFChars(jtext, 0); - auto *const context = reinterpret_cast(context_pointer); - auto *const batch = reinterpret_cast(batch_pointer); - - bool parse_special = (format_chat == JNI_TRUE); - const auto tokens_list = common_tokenize(context, text, true, parse_special); - - auto n_ctx = llama_n_ctx(context); - auto n_tokens = tokens_list.size(); - auto n_kv_req = n_tokens + n_len; - LOGi("n_ctx = %d, n_kv_req = n_tokens (%d) + n_len (%d) = %d", n_ctx, n_tokens, n_len, n_kv_req); - - if (n_kv_req > n_ctx) { - LOGe("error: n_kv_req > n_ctx, the required KV cache size is not big enough"); - } - - for (auto id: tokens_list) { - LOGi("token: `%s`-> %d ", common_token_to_piece(context, id).c_str(), id); - } + // Obtain and tokenize system prompt + const auto *const system_text = env->GetStringUTFChars(jsystem_prompt, nullptr); + LOGi("System prompt: \n%s", system_text); + const auto system_tokens = common_tokenize(context, system_text, true, true); + env->ReleaseStringUTFChars(jsystem_prompt, system_text); + // Add system prompt tokens to batch common_batch_clear(*batch); - - // evaluate the initial prompt - for (auto i = 0; i < tokens_list.size(); i++) { - common_batch_add(*batch, tokens_list[i], i, {0}, false); + // TODO-hyin: support batch processing! + for (int i = 0; i < system_tokens.size(); i++) { + common_batch_add(*batch, system_tokens[i], i, {0}, false); } - // llama_decode will output logits only for the last token of the prompt - batch->logits[batch->n_tokens - 1] = true; - - if (llama_decode(context, *batch) != 0) { - LOGe("llama_decode() failed"); + // Decode batch + int decode_result = llama_decode(context, *batch); + if (decode_result != 0) { + LOGe("llama_decode() failed: %d", decode_result); + return -1; } - env->ReleaseStringUTFChars(jtext, text); + // Update position + current_position = system_tokens.size(); + return 0; +} - return batch->n_tokens; +// TODO-hyin: support KV cache backtracking +extern "C" +JNIEXPORT jint JNICALL +Java_android_llama_cpp_LLamaAndroid_process_1user_1prompt( + JNIEnv *env, + jobject /*unused*/, + jstring juser_prompt, + jint n_len +) { + // Reset short-term states + token_predict_budget = 0; + cached_token_chars.clear(); + + // Obtain and tokenize user prompt + const auto *const user_text = env->GetStringUTFChars(juser_prompt, nullptr); + LOGi("User prompt: \n%s", user_text); + const auto user_tokens = common_tokenize(context, user_text, true, true); + env->ReleaseStringUTFChars(juser_prompt, user_text); + + // Check if context space is enough for desired tokens + int desired_budget = current_position + user_tokens.size() + n_len; + if (desired_budget > llama_n_ctx(context)) { + LOGe("error: total tokens exceed context size"); + return -1; + } + token_predict_budget = desired_budget; + + // Add user prompt tokens to batch + common_batch_clear(*batch); + for (int i = 0; i < user_tokens.size(); i++) { + common_batch_add(*batch, user_tokens[i], current_position + i, {0}, false); + } + batch->logits[batch->n_tokens - 1] = true; // Set logits true only for last token + + // Decode batch + int decode_result = llama_decode(context, *batch); + if (decode_result != 0) { + LOGe("llama_decode() failed: %d", decode_result); + return -2; + } + + // Update position + current_position += user_tokens.size(); // Update position + return 0; +} + +bool is_valid_utf8(const char *string) { + if (!string) { return true; } + + const auto *bytes = (const unsigned char *) string; + int num; + + while (*bytes != 0x00) { + if ((*bytes & 0x80) == 0x00) { + // U+0000 to U+007F + num = 1; + } else if ((*bytes & 0xE0) == 0xC0) { + // U+0080 to U+07FF + num = 2; + } else if ((*bytes & 0xF0) == 0xE0) { + // U+0800 to U+FFFF + num = 3; + } else if ((*bytes & 0xF8) == 0xF0) { + // U+10000 to U+10FFFF + num = 4; + } else { + return false; + } + + bytes += 1; + for (int i = 1; i < num; ++i) { + if ((*bytes & 0xC0) != 0x80) { + return false; + } + bytes += 1; + } + } + return true; } extern "C" JNIEXPORT jstring JNICALL -Java_android_llama_cpp_LLamaAndroid_completion_1loop( +Java_android_llama_cpp_LLamaAndroid_predict_1loop( JNIEnv *env, - jobject, - jlong context_pointer, - jlong batch_pointer, - jlong sampler_pointer, - jint n_len, - jobject intvar_ncur + jobject /*unused*/ ) { - auto *const context = reinterpret_cast(context_pointer); - auto *const batch = reinterpret_cast(batch_pointer); - auto *const sampler = reinterpret_cast(sampler_pointer); - const auto *const model = llama_get_model(context); - const auto *const vocab = llama_model_get_vocab(model); - - if (!la_int_var) la_int_var = env->GetObjectClass(intvar_ncur); - if (!la_int_var_value) la_int_var_value = env->GetMethodID(la_int_var, "getValue", "()I"); - if (!la_int_var_inc) la_int_var_inc = env->GetMethodID(la_int_var, "inc", "()V"); - - // sample the most likely token - const auto new_token_id = common_sampler_sample(sampler, context, -1); - common_sampler_accept(sampler, new_token_id, /* accept_grammar= */ true); - - const auto n_cur = env->CallIntMethod(intvar_ncur, la_int_var_value); - if (llama_vocab_is_eog(vocab, new_token_id) || n_cur == n_len) { + // Stop if running out of token budget + if (current_position >= token_predict_budget) { + LOGi("STOP: current position (%d) exceeds budget (%d)", current_position, token_predict_budget); return nullptr; } + // Sample next token + const auto new_token_id = common_sampler_sample(sampler, context, -1); + common_sampler_accept(sampler, new_token_id, true); + + // Stop if next token is EOG + if (llama_vocab_is_eog(llama_model_get_vocab(model), new_token_id)) { + LOGi("id: %d,\tIS EOG!\nSTOP.", new_token_id); + return nullptr; + } + + // Update the context with the new token + common_batch_clear(*batch); + common_batch_add(*batch, new_token_id, current_position, {0}, true); + if (llama_decode(context, *batch) != 0) { + LOGe("llama_decode() failed for generated token"); + return nullptr; + } + + // Convert to text auto new_token_chars = common_token_to_piece(context, new_token_id); cached_token_chars += new_token_chars; - jstring new_token = nullptr; + // Create Java string + jstring result = nullptr; if (is_valid_utf8(cached_token_chars.c_str())) { - new_token = env->NewStringUTF(cached_token_chars.c_str()); - LOGi("cached: %s, new_token_chars: `%s`, id: %d", cached_token_chars.c_str(), - new_token_chars.c_str(), new_token_id); + result = env->NewStringUTF(cached_token_chars.c_str()); + LOGd("id: %d,\tcached: `%s`,\tnew: `%s`", new_token_id, cached_token_chars.c_str(), new_token_chars.c_str()); cached_token_chars.clear(); } else { - new_token = env->NewStringUTF(""); + LOGd("id: %d,\tappend to cache", new_token_id); + result = env->NewStringUTF(""); } - common_batch_clear(*batch); - common_batch_add(*batch, new_token_id, n_cur, {0}, true); - - env->CallVoidMethod(intvar_ncur, la_int_var_inc); - - if (llama_decode(context, *batch) != 0) { - LOGe("llama_decode() returned null"); - } - - return new_token; -} - -extern "C" -JNIEXPORT void JNICALL -Java_android_llama_cpp_LLamaAndroid_kv_1cache_1clear(JNIEnv *, jobject, jlong context) { - llama_memory_clear(llama_get_memory(reinterpret_cast(context)), true); + // Update position + current_position++; + return result; } diff --git a/examples/llama.android/llama/src/main/java/android/llama/cpp/LLamaAndroid.kt b/examples/llama.android/llama/src/main/java/android/llama/cpp/LLamaAndroid.kt index 65eedd2d41..dfba00475b 100644 --- a/examples/llama.android/llama/src/main/java/android/llama/cpp/LLamaAndroid.kt +++ b/examples/llama.android/llama/src/main/java/android/llama/cpp/LLamaAndroid.kt @@ -24,7 +24,7 @@ class LLamaAndroid { // Set llama log handler to Android log_to_android() - backend_init(false) + backend_init() Log.d(tag, system_info()) @@ -37,103 +37,96 @@ class LLamaAndroid { }.asCoroutineDispatcher() private external fun log_to_android() - private external fun load_model(filename: String): Long - private external fun free_model(model: Long) - private external fun new_context(model: Long): Long - private external fun free_context(context: Long) - private external fun backend_init(numa: Boolean) - private external fun backend_free() - private external fun new_batch(nTokens: Int, embd: Int, nSeqMax: Int): Long - private external fun free_batch(batch: Long) - private external fun new_sampler(model: Long): Long - private external fun free_sampler(sampler: Long) - private external fun bench_model( - context: Long, - model: Long, - batch: Long, - pp: Int, - tg: Int, - pl: Int, - nr: Int - ): String - private external fun system_info(): String + private external fun backend_init() - private external fun completion_init( - context: Long, - batch: Long, - text: String, - formatChat: Boolean, - nLen: Int - ): Int + private external fun load_model(filename: String): Int + private external fun ctx_init(): Int + private external fun clean_up() - private external fun completion_loop( - context: Long, - batch: Long, - sampler: Long, - nLen: Int, - ncur: IntVar - ): String? + private external fun bench_model(pp: Int, tg: Int, pl: Int, nr: Int): String - private external fun kv_cache_clear(context: Long) + private external fun process_system_prompt(system_prompt: String): Int + private external fun process_user_prompt(user_prompt: String, nLen: Int): Int + private external fun predict_loop(): String? - suspend fun bench(pp: Int, tg: Int, pl: Int, nr: Int = 1): String { - return withContext(runLoop) { - when (val state = threadLocalState.get()) { - is State.Loaded -> { - Log.d(tag, "bench(): $state") - bench_model(state.context, state.model, state.batch, pp, tg, pl, nr) - } - - else -> throw IllegalStateException("No model loaded") - } - } - } - - suspend fun load(pathToModel: String) { + suspend fun load(pathToModel: String, formattedSystemPrompt: String? = null) { withContext(runLoop) { when (threadLocalState.get()) { is State.Idle -> { val model = load_model(pathToModel) - if (model == 0L) throw IllegalStateException("load_model() failed") + if (model != 0) throw IllegalStateException("Load model failed") - val context = new_context(model) - if (context == 0L) throw IllegalStateException("new_context() failed") - - val batch = new_batch(DEFAULT_BATCH_SIZE, 0, 1) - if (batch == 0L) throw IllegalStateException("new_batch() failed") - - val sampler = new_sampler(model) - if (sampler == 0L) throw IllegalStateException("new_sampler() failed") + val result = ctx_init() + if (result != 0) throw IllegalStateException("Initialization failed with error code: $result") Log.i(tag, "Loaded model $pathToModel") - threadLocalState.set(State.Loaded(model, context, batch, sampler)) + threadLocalState.set(State.ModelLoaded) + + formattedSystemPrompt?.let { + initWithSystemPrompt(formattedSystemPrompt) + } ?: { + threadLocalState.set(State.ReadyForUserPrompt) + } } else -> throw IllegalStateException("Model already loaded") } } } - fun send( - message: String, - formatChat: Boolean = false, - predictLength: Int = DEFAULT_PREDICT_LENGTH, - ): Flow = flow { - when (val state = threadLocalState.get()) { - is State.Loaded -> { - val nCur = IntVar( - completion_init(state.context, state.batch, message, formatChat, predictLength) - ) - - while (nCur.value <= predictLength) { - val str = completion_loop( - state.context, state.batch, state.sampler, predictLength, nCur - ) ?: break - - emit(str) + suspend fun bench(pp: Int, tg: Int, pl: Int, nr: Int = 1): String { + return withContext(runLoop) { + when (val state = threadLocalState.get()) { + is State.ModelLoaded -> { + Log.d(tag, "bench(): $state") + bench_model(pp, tg, pl, nr) } - // kv_cache_clear(state.context) + // TODO-hyin: catch exception in ViewController; disable button when state incorrect + else -> throw IllegalStateException("No model loaded") + } + } + } + + private suspend fun initWithSystemPrompt(systemPrompt: String) { + withContext(runLoop) { + when (threadLocalState.get()) { + is State.ModelLoaded -> { + process_system_prompt(systemPrompt).let { + if (it != 0) { + throw IllegalStateException("Failed to process system prompt: $it") + } + } + + Log.i(tag, "System prompt processed!") + threadLocalState.set(State.ReadyForUserPrompt) + } + else -> throw IllegalStateException("Model not loaded") + } + } + } + + fun sendMessage( + formattedUserPrompt: String, + nPredict: Int = DEFAULT_PREDICT_LENGTH, + ): Flow = flow { + when (threadLocalState.get()) { + is State.ReadyForUserPrompt -> { + process_user_prompt(formattedUserPrompt, nPredict).let { + if (it != 0) { + Log.e(tag, "Failed to process user prompt: $it") + return@flow + } + } + + Log.i(tag, "User prompt processed! Generating assistant prompt...") + while (true) { + val str = predict_loop() ?: break + if (str.isNotEmpty()) { + emit(str) + } + } + Log.i(tag, "Assistant generation complete!") } else -> {} } @@ -146,13 +139,9 @@ class LLamaAndroid { */ suspend fun unload() { withContext(runLoop) { - when (val state = threadLocalState.get()) { - is State.Loaded -> { - free_context(state.context) - free_model(state.model) - free_batch(state.batch) - free_sampler(state.sampler); - + when (threadLocalState.get()) { + is State.ModelLoaded -> { + clean_up() threadLocalState.set(State.Idle) } else -> {} @@ -161,24 +150,12 @@ class LLamaAndroid { } companion object { - private const val DEFAULT_BATCH_SIZE = 512 private const val DEFAULT_PREDICT_LENGTH = 128 - private class IntVar(value: Int) { - @Volatile - var value: Int = value - private set - - fun inc() { - synchronized(this) { - value += 1 - } - } - } - private sealed interface State { data object Idle: State - data class Loaded(val model: Long, val context: Long, val batch: Long, val sampler: Long): State + data object ModelLoaded: State + data object ReadyForUserPrompt: State } // Enforce only one instance of Llm.