From 6d2279e9cd4283c803d04711d9107004079a109b Mon Sep 17 00:00:00 2001
From: Han Yin
Date: Wed, 9 Apr 2025 10:09:23 -0700
Subject: [PATCH] REWRITE JNI bridge; Update viewmodel

---
 .../java/com/example/llama/MainViewModel.kt  |  14 +-
 .../llama/src/main/cpp/llama-android.cpp     |  78 ++---
 .../java/android/llama/cpp/LLamaAndroid.kt   | 289 ++++++++++--------
 3 files changed, 218 insertions(+), 163 deletions(-)

diff --git a/examples/llama.android/app/src/main/java/com/example/llama/MainViewModel.kt b/examples/llama.android/app/src/main/java/com/example/llama/MainViewModel.kt
index bb29cb08f6..bdee38c7b5 100644
--- a/examples/llama.android/app/src/main/java/com/example/llama/MainViewModel.kt
+++ b/examples/llama.android/app/src/main/java/com/example/llama/MainViewModel.kt
@@ -29,7 +29,7 @@ class MainViewModel(private val llamaAndroid: LLamaAndroid = LLamaAndroid.instan
 
         viewModelScope.launch {
             try {
-                llamaAndroid.unload()
+                llamaAndroid.destroy()
             } catch (exc: IllegalStateException) {
                 messages += exc.message!!
             }
@@ -83,7 +83,7 @@ class MainViewModel(private val llamaAndroid: LLamaAndroid = LLamaAndroid.instan
     fun load(pathToModel: String) {
         viewModelScope.launch {
             try {
-                llamaAndroid.load(pathToModel)
+                llamaAndroid.loadModel(pathToModel)
                 messages += "Loaded $pathToModel"
             } catch (exc: IllegalStateException) {
                 Log.e(tag, "load() failed", exc)
@@ -103,4 +103,14 @@ class MainViewModel(private val llamaAndroid: LLamaAndroid = LLamaAndroid.instan
     fun log(message: String) {
         messages += message
     }
+
+    fun unload() =
+        viewModelScope.launch {
+            try {
+                llamaAndroid.unloadModel()
+            } catch (exc: IllegalStateException) {
+                Log.e(tag, "unload() failed", exc)
+                messages += exc.message!!
+            }
+        }
 }
diff --git a/examples/llama.android/llama/src/main/cpp/llama-android.cpp b/examples/llama.android/llama/src/main/cpp/llama-android.cpp
index 06881104f2..b098940d44 100644
--- a/examples/llama.android/llama/src/main/cpp/llama-android.cpp
+++ b/examples/llama.android/llama/src/main/cpp/llama-android.cpp
@@ -33,8 +33,8 @@ static std::string join(const std::vector &values, const std::string &delim)
 /**
  * LLama resources: context, model, batch and sampler
  */
-constexpr int N_THREADS_MIN = 1;
-constexpr int N_THREADS_MAX = 8;
+constexpr int N_THREADS_MIN = 2;
+constexpr int N_THREADS_MAX = 4;
 constexpr int N_THREADS_HEADROOM = 2;
 constexpr int DEFAULT_CONTEXT_SIZE = 8192;
 
@@ -70,38 +70,27 @@ static void log_callback(ggml_log_level level, const char *fmt, void *data) {
     __android_log_print(priority, TAG, fmt, data);
 }
 
-JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM *vm, void *reserved) {
-    JNIEnv *env;
-    if (vm->GetEnv(reinterpret_cast<void **>(&env), JNI_VERSION_1_6) != JNI_OK) {
-        return JNI_ERR;
-    }
-
+extern "C"
+JNIEXPORT void JNICALL
+Java_android_llama_cpp_LLamaAndroid_init(JNIEnv *env, jobject /*unused*/) {
     // Set llama log handler to Android
     llama_log_set(log_callback, nullptr);
 
     // Initialize backends
    llama_backend_init();
-    LOGi("Backend initiated.");
-
-    return JNI_VERSION_1_6;
-}
-
-extern "C"
-JNIEXPORT jstring JNICALL
-Java_android_llama_cpp_LLamaAndroid_systemInfo(JNIEnv *env, jobject /*unused*/) {
-    return env->NewStringUTF(llama_print_system_info());
+    LOGi("Backend initiated; Log handler set.");
 }
 
 extern "C"
 JNIEXPORT jint JNICALL
-Java_android_llama_cpp_LLamaAndroid_loadModel(JNIEnv *env, jobject, jstring filename) {
+Java_android_llama_cpp_LLamaAndroid_load(JNIEnv *env, jobject, jstring jmodel_path) {
     llama_model_params model_params = llama_model_default_params();
 
-    const auto *path_to_model = env->GetStringUTFChars(filename, 0);
-    LOGd("%s: Loading model from: \n%s\n", __func__, path_to_model);
+    const auto *model_path = env->GetStringUTFChars(jmodel_path, 0);
+    LOGd("%s: Loading model from: \n%s\n", __func__, model_path);
 
-    auto *model = llama_model_load_from_file(path_to_model, model_params);
-    env->ReleaseStringUTFChars(filename, path_to_model);
+    auto *model = llama_model_load_from_file(model_path, model_params);
+    env->ReleaseStringUTFChars(jmodel_path, model_path);
     if (!model) {
         return 1;
     }
@@ -148,7 +137,7 @@ static common_sampler *new_sampler(float temp) {
 extern "C"
 JNIEXPORT jint JNICALL
-Java_android_llama_cpp_LLamaAndroid_initContext(JNIEnv * /*env*/, jobject /*unused*/) {
+Java_android_llama_cpp_LLamaAndroid_prepare(JNIEnv * /*env*/, jobject /*unused*/) {
     auto *context = init_context(g_model);
     if (!context) { return 1; }
     g_context = context;
@@ -158,17 +147,6 @@ Java_android_llama_cpp_LLamaAndroid_prepare(JNIEnv * /*env*/, jobject /*unus
     return 0;
 }
 
-extern "C"
-JNIEXPORT void JNICALL
-Java_android_llama_cpp_LLamaAndroid_cleanUp(JNIEnv * /*unused*/, jobject /*unused*/) {
-    common_sampler_free(g_sampler);
-    g_chat_templates.reset();
-    llama_batch_free(g_batch);
-    llama_free(g_context);
-    llama_model_free(g_model);
-    llama_backend_free();
-}
-
 static std::string get_backend() {
     std::vector<std::string> backends;
     for (size_t i = 0; i < ggml_backend_reg_count(); i++) {
@@ -181,6 +159,12 @@ static std::string get_backend() {
     return backends.empty() ? "CPU" : join(backends, ",");
 }
 
+extern "C"
+JNIEXPORT jstring JNICALL
+Java_android_llama_cpp_LLamaAndroid_systemInfo(JNIEnv *env, jobject /*unused*/) {
+    return env->NewStringUTF(llama_print_system_info());
+}
+
 extern "C"
 JNIEXPORT jstring JNICALL
 Java_android_llama_cpp_LLamaAndroid_benchModel(JNIEnv *env, jobject /*unused*/, jint pp, jint tg,
@@ -269,7 +253,7 @@ Java_android_llama_cpp_LLamaAndroid_benchModel(JNIEnv *env, jobject /*unused*/,
     const auto backend = get_backend();
 
     std::stringstream result;
-    result << std::setprecision(2);
+    result << std::setprecision(3);
     result << "| model | size | params | backend | test | t/s |\n";
     result << "| --- | --- | --- | --- | --- | --- |\n";
     result << "| " << model_desc << " | " << model_size << "GiB | " << model_n_params << "B | "
@@ -348,7 +332,7 @@ static void reset_short_term_states() {
 
 static int decode_tokens_in_batches(
         llama_context *context,
-        llama_batch batch,
+        llama_batch &batch,
        const llama_tokens &tokens,
         const llama_pos start_pos,
         const bool compute_last_logit = false) {
@@ -574,3 +558,25 @@ Java_android_llama_cpp_LLamaAndroid_generateNextToken(
     }
     return result;
 }
+
+
+extern "C"
+JNIEXPORT void JNICALL
+Java_android_llama_cpp_LLamaAndroid_unload(JNIEnv * /*unused*/, jobject /*unused*/) {
+    // Reset long-term & short-term states
+    reset_long_term_states();
+    reset_short_term_states();
+
+    // Free up resources
+    common_sampler_free(g_sampler);
+    g_chat_templates.reset();
+    llama_batch_free(g_batch);
+    llama_free(g_context);
+    llama_model_free(g_model);
+}
+
+extern "C"
+JNIEXPORT void JNICALL
+Java_android_llama_cpp_LLamaAndroid_shutdown(JNIEnv *env, jobject /*unused*/) {
+    llama_backend_free();
+}
diff --git a/examples/llama.android/llama/src/main/java/android/llama/cpp/LLamaAndroid.kt b/examples/llama.android/llama/src/main/java/android/llama/cpp/LLamaAndroid.kt
index 41e84f3fc2..7f52ccf68a 100644
--- a/examples/llama.android/llama/src/main/java/android/llama/cpp/LLamaAndroid.kt
+++ b/examples/llama.android/llama/src/main/java/android/llama/cpp/LLamaAndroid.kt
@@ -1,190 +1,229 @@
 package android.llama.cpp
 
 import android.util.Log
-import kotlinx.coroutines.CoroutineDispatcher
-import kotlinx.coroutines.asCoroutineDispatcher
+import kotlinx.coroutines.CoroutineScope
+import kotlinx.coroutines.Dispatchers
+import kotlinx.coroutines.ExperimentalCoroutinesApi
+import kotlinx.coroutines.SupervisorJob
+import kotlinx.coroutines.cancel
 import kotlinx.coroutines.flow.Flow
+import kotlinx.coroutines.flow.MutableStateFlow
+import kotlinx.coroutines.flow.StateFlow
 import kotlinx.coroutines.flow.flow
 import kotlinx.coroutines.flow.flowOn
+import kotlinx.coroutines.launch
 import kotlinx.coroutines.withContext
-import java.util.concurrent.Executors
-import kotlin.concurrent.thread
+import java.io.File
 
-class LLamaAndroid {
+@Target(AnnotationTarget.FUNCTION)
+@Retention(AnnotationRetention.SOURCE)
+annotation class RequiresCleanup(val message: String = "Remember to call this method for proper cleanup!")
+
+/**
+ * JNI wrapper for the llama.cpp library providing Android-friendly access to large language models.
+ *
+ * This class implements a singleton pattern for managing the lifecycle of a single LLM instance.
+ * All operations are executed on a dedicated single-threaded dispatcher to ensure thread safety
+ * with the underlying C++ native code.
+ *
+ * The typical usage flow is:
+ * 1. Get instance via [instance]
+ * 2. Load a model with [loadModel]
+ * 3. Send prompts with [sendUserPrompt]
+ * 4. Generate responses as token streams
+ * 5. Unload the model with [unloadModel] when switching models
+ * 6. Call [destroy] when completely done
+ *
+ * State transitions are managed automatically and validated at each operation.
+ *
+ * @see llama-android.cpp for the native implementation details
+ */
+class LLamaAndroid private constructor() {
     /**
     * JNI methods
     * @see llama-android.cpp
     */
+    private external fun init()
+    private external fun load(modelPath: String): Int
+    private external fun prepare(): Int
+    private external fun systemInfo(): String
-
-    private external fun loadModel(filename: String): Int
-    private external fun initContext(): Int
-    private external fun cleanUp()
-
     private external fun benchModel(pp: Int, tg: Int, pl: Int, nr: Int): String
     private external fun processSystemPrompt(systemPrompt: String): Int
     private external fun processUserPrompt(userPrompt: String, predictLength: Int): Int
     private external fun generateNextToken(): String?
+    private external fun unload()
+    private external fun shutdown()
+
     /**
-     * Thread local state
+     * Fine-grained state management
     */
-    private sealed interface State {
-        data object NotInitialized: State
-        data object EnvReady: State
-        data object AwaitingUserPrompt: State
-        data object Processing: State
+    sealed class State {
+        object Uninitialized : State()
+        object LibraryLoaded : State()
+
+        object LoadingModel : State()
+        object ModelLoaded : State()
+
+        object ProcessingSystemPrompt : State()
+        object AwaitingUserPrompt : State()
+
+        object ProcessingUserPrompt : State()
+        object Generating : State()
+
+        object Benchmarking : State()
+
+        data class Error(
+            val errorMessage: String = ""
+        ) : State()
     }
 
-    private val threadLocalState: ThreadLocal<State> = ThreadLocal.withInitial { State.NotInitialized }
-    private val runLoop: CoroutineDispatcher = Executors.newSingleThreadExecutor {
-        thread(start = false, name = LLAMA_THREAD) {
-            Log.d(TAG, "Dedicated thread for native code: ${Thread.currentThread().name}")
-
-            // No-op if called more than once.
-            System.loadLibrary(LIB_LLAMA_ANDROID)
-            Log.d(TAG, systemInfo())
-
-            it.run()
-        }.apply {
-            uncaughtExceptionHandler = Thread.UncaughtExceptionHandler { _, exception: Throwable ->
-                Log.e(TAG, "Unhandled exception", exception)
-            }
-        }
-    }.asCoroutineDispatcher()
+    private val _state = MutableStateFlow<State>(State.Uninitialized)
+    val state: StateFlow<State> = _state
 
     /**
-     * Load the LLM, then process the formatted system prompt if provided
+     * Single-threaded coroutine dispatcher & scope for LLama asynchronous operations
     */
-    suspend fun load(pathToModel: String, systemPrompt: String? = null) =
-        withContext(runLoop) {
-            when (threadLocalState.get()) {
-                is State.NotInitialized -> {
-                    val modelResult = loadModel(pathToModel)
-                    if (modelResult != 0) throw IllegalStateException("Load model failed: $modelResult")
+    @OptIn(ExperimentalCoroutinesApi::class)
+    private val llamaDispatcher = Dispatchers.IO.limitedParallelism(1)
+    private val llamaScope = CoroutineScope(llamaDispatcher + SupervisorJob())
 
-                    val initResult = initContext()
-                    if (initResult != 0) throw IllegalStateException("Initialization failed with error code: $initResult")
+    init {
+        llamaScope.launch {
+            try {
+                System.loadLibrary(LIB_LLAMA_ANDROID)
+                init()
+                _state.value = State.LibraryLoaded
+                Log.i(TAG, "Native library loaded! System info: \n${systemInfo()}")
+            } catch (e: Exception) {
+                _state.value = State.Error("Failed to load native library: ${e.message}")
+                Log.e(TAG, "Failed to load native library", e)
+            }
+        }
+    }
 
-                    Log.i(TAG, "Loaded model $pathToModel")
-                    threadLocalState.set(State.EnvReady)
+    /**
+     * Load the LLM, then process the plain text system prompt if provided
+     */
+    suspend fun loadModel(pathToModel: String, systemPrompt: String? = null) =
+        withContext(llamaDispatcher) {
+            check(_state.value is State.LibraryLoaded) { "Cannot load model in ${_state.value}!" }
+            File(pathToModel).let {
+                require(it.exists()) { "Model file not found: $pathToModel" }
+                require(it.isFile) { "Model path is not a file: $pathToModel" }
+            }
 
-                    systemPrompt?.let {
-                        initWithSystemPrompt(systemPrompt)
-                    } ?: run {
-                        Log.w(TAG, "No system prompt to process.")
-                        threadLocalState.set(State.AwaitingUserPrompt)
+            Log.i(TAG, "Loading model... \n$pathToModel")
+            _state.value = State.LoadingModel
+            load(pathToModel).let { result ->
+                if (result != 0) throw IllegalStateException("Failed to load model: $result")
+            }
+            prepare().let { result ->
+                if (result != 0) throw IllegalStateException("Failed to prepare resources: $result")
+            }
+            Log.i(TAG, "Model loaded!")
+            _state.value = State.ModelLoaded
+
+            systemPrompt?.let { prompt ->
+                Log.i(TAG, "Sending system prompt...")
+                _state.value = State.ProcessingSystemPrompt
+                processSystemPrompt(prompt).let { result ->
+                    if (result != 0) {
+                        val errorMessage = "Failed to process system prompt: $result"
+                        _state.value = State.Error(errorMessage)
+                        throw IllegalStateException(errorMessage)
                    }
                }
-                else -> throw IllegalStateException("Model already loaded")
+                Log.i(TAG, "System prompt processed! Awaiting user prompt...")
+            } ?: run {
+                Log.w(TAG, "No system prompt to process.")
            }
+            _state.value = State.AwaitingUserPrompt
        }
 
     /**
-     * Helper method to process system prompt and update [State]
-     */
-    private suspend fun initWithSystemPrompt(formattedMessage: String) =
-        withContext(runLoop) {
-            when (threadLocalState.get()) {
-                is State.EnvReady -> {
-                    Log.i(TAG, "Process system prompt...")
-                    threadLocalState.set(State.Processing)
-                    processSystemPrompt(formattedMessage).let {
-                        if (it != 0)
-                            throw IllegalStateException("Failed to process system prompt: $it")
-                    }
-
-                    Log.i(TAG, "System prompt processed!")
-                    threadLocalState.set(State.AwaitingUserPrompt)
-                }
-                else -> throw IllegalStateException(
-                    "Failed to process system prompt: Model not loaded!"
-                )
-            }
-        }
-
-    /**
-     * Send formatted user prompt to LLM
+     * Send plain text user prompt to LLM, which starts generating tokens in a [Flow]
     */
     fun sendUserPrompt(
        message: String,
        predictLength: Int = DEFAULT_PREDICT_LENGTH,
     ): Flow<String> = flow {
-        require(message.isNotEmpty()) {
-            Log.w(TAG, "User prompt discarded due to being empty!")
+        require(message.isNotEmpty()) { "User prompt discarded due to being empty!" }
+        check(_state.value is State.AwaitingUserPrompt) {
+            "User prompt discarded due to: ${_state.value}"
        }
 
-        when (val state = threadLocalState.get()) {
-            is State.AwaitingUserPrompt -> {
-                Log.i(TAG, "Sending user prompt...")
-                threadLocalState.set(State.Processing)
-                processUserPrompt(message, predictLength).let { result ->
-                    if (result != 0) {
-                        Log.e(TAG, "Failed to process user prompt: $result")
-                        return@flow
-                    }
-                }
-
-                Log.i(TAG, "User prompt processed! Generating assistant prompt...")
-                while (true) {
-                    generateNextToken()?.let { utf8token ->
-                        if (utf8token.isNotEmpty()) emit(utf8token)
-                    } ?: break
-                }
-
-                Log.i(TAG, "Assistant generation complete!")
-                threadLocalState.set(State.AwaitingUserPrompt)
-            }
-            else -> {
-                Log.w(TAG, "User prompt discarded due to incorrect state: $state")
+        Log.i(TAG, "Sending user prompt...")
+        _state.value = State.ProcessingUserPrompt
+        processUserPrompt(message, predictLength).let { result ->
+            if (result != 0) {
+                Log.e(TAG, "Failed to process user prompt: $result")
+                return@flow
            }
        }
-    }.flowOn(runLoop)
+
+        Log.i(TAG, "User prompt processed! Generating assistant prompt...")
+        _state.value = State.Generating
+        while (true) {
+            generateNextToken()?.let { utf8token ->
+                if (utf8token.isNotEmpty()) emit(utf8token)
+            } ?: break
+        }
+        Log.i(TAG, "Assistant generation complete! Awaiting user prompt...")
+        _state.value = State.AwaitingUserPrompt
+    }.flowOn(llamaDispatcher)
 
     /**
     * Benchmark the model
     */
     suspend fun bench(pp: Int, tg: Int, pl: Int, nr: Int = 1): String =
-        withContext(runLoop) {
-            when (threadLocalState.get()) {
-                is State.AwaitingUserPrompt -> {
-                    threadLocalState.set(State.Processing)
-                    Log.d(TAG, "Start benchmark (pp: $pp, tg: $tg, pl: $pl, nr: $nr)")
-                    benchModel(pp, tg, pl, nr).also {
-                        threadLocalState.set(State.AwaitingUserPrompt)
-                    }
-                }
-
-                // TODO-hyin: disable button when state incorrect
-                else -> throw IllegalStateException("No model loaded")
+        withContext(llamaDispatcher) {
+            check(_state.value is State.AwaitingUserPrompt) {
+                "Benchmark request discarded due to: ${_state.value}"
+            }
+            Log.i(TAG, "Start benchmark (pp: $pp, tg: $tg, pl: $pl, nr: $nr)")
+            _state.value = State.Benchmarking
+            benchModel(pp, tg, pl, nr).also {
+                _state.value = State.AwaitingUserPrompt
            }
        }
 
     /**
-     * Unloads the model and frees resources.
-     *
-     * This is a no-op if there's no model loaded.
+     * Unloads the model and frees resources
     */
-    suspend fun unload() =
-        withContext(runLoop) {
-            when (val state = threadLocalState.get()) {
-                is State.EnvReady, State.AwaitingUserPrompt -> {
-                    cleanUp()
-                    threadLocalState.set(State.NotInitialized)
-                }
-                else -> {
-                    Log.w(TAG, "Cannot unload model due to incorrect state: $state")
+    suspend fun unloadModel() =
+        withContext(llamaDispatcher) {
+            when(_state.value) {
+                is State.AwaitingUserPrompt, is State.Error -> {
+                    Log.i(TAG, "Unloading model and freeing resources...")
+                    unload()
+                    _state.value = State.LibraryLoaded
+                    Log.i(TAG, "Model unloaded!")
                }
+                else -> throw IllegalStateException("Cannot unload model in ${_state.value}")
            }
        }
 
+    /**
+     * Cancel all ongoing coroutines and free GGML backends
+     */
+    @RequiresCleanup("Call from `ViewModel.onCleared()` to prevent resource leaks!")
+    fun destroy() {
+        llamaScope.cancel()
+        when(_state.value) {
+            is State.Uninitialized -> {}
+            is State.LibraryLoaded -> shutdown()
+            else -> { unload(); shutdown() }
+        }
+    }
+
     companion object {
        private val TAG = LLamaAndroid::class.simpleName
        private const val LIB_LLAMA_ANDROID = "llama-android"
-        private const val LLAMA_THREAD = "llama-thread"
-
        private const val DEFAULT_PREDICT_LENGTH = 64
 
        // Enforce only one instance of Llm.