REWRITE JNI bridge; Update viewmodel

Han Yin 2025-04-09 10:09:23 -07:00
parent e1bc87610e
commit 6d2279e9cd
3 changed files with 218 additions and 163 deletions

View File

@@ -29,7 +29,7 @@ class MainViewModel(private val llamaAndroid: LLamaAndroid = LLamaAndroid.instan
         viewModelScope.launch {
             try {
-                llamaAndroid.unload()
+                llamaAndroid.destroy()
             } catch (exc: IllegalStateException) {
                 messages += exc.message!!
             }
@@ -83,7 +83,7 @@ class MainViewModel(private val llamaAndroid: LLamaAndroid = LLamaAndroid.instan
     fun load(pathToModel: String) {
         viewModelScope.launch {
             try {
-                llamaAndroid.load(pathToModel)
+                llamaAndroid.loadModel(pathToModel)
                 messages += "Loaded $pathToModel"
             } catch (exc: IllegalStateException) {
                 Log.e(tag, "load() failed", exc)
@@ -103,4 +103,14 @@ class MainViewModel(private val llamaAndroid: LLamaAndroid = LLamaAndroid.instan
     fun log(message: String) {
         messages += message
     }
+
+    fun unload() =
+        viewModelScope.launch {
+            try {
+                llamaAndroid.unloadModel()
+            } catch (exc: IllegalStateException) {
+                Log.e(tag, "unload() failed", exc)
+                messages += exc.message!!
+            }
+        }
 }
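
Taken together, the changes in this file swap the single unload() call for a two-level teardown: unloadModel() while the app keeps running, destroy() when the ViewModel goes away. The sketch below shows one way the updated calls could be wired end to end. Only loadModel(), unloadModel(), and destroy() come from this commit; the ViewModel scaffolding, the onCleared() placement, and how the LLamaAndroid instance is obtained are assumptions (the onCleared() placement follows the @RequiresCleanup note in LLamaAndroid.kt further down).

// Sketch only: scaffolding and onCleared() placement are assumptions;
// loadModel(), unloadModel() and destroy() are the API names introduced by this commit.
import android.llama.cpp.LLamaAndroid
import android.util.Log
import androidx.lifecycle.ViewModel
import androidx.lifecycle.viewModelScope
import kotlinx.coroutines.launch

class ChatViewModelSketch(
    private val llamaAndroid: LLamaAndroid   // obtained via the companion accessor, see LLamaAndroid.kt
) : ViewModel() {
    private val tag = "ChatViewModelSketch"
    private var messages = listOf<String>()

    fun load(pathToModel: String) {
        viewModelScope.launch {
            try {
                llamaAndroid.loadModel(pathToModel)   // was llamaAndroid.load(pathToModel)
                messages += "Loaded $pathToModel"
            } catch (exc: IllegalStateException) {
                Log.e(tag, "load() failed", exc)
                messages += exc.message!!
            }
        }
    }

    fun unload() =
        viewModelScope.launch {
            try {
                llamaAndroid.unloadModel()            // per-model teardown, library stays initialized
            } catch (exc: IllegalStateException) {
                Log.e(tag, "unload() failed", exc)
                messages += exc.message!!
            }
        }

    override fun onCleared() {
        super.onCleared()
        // Simplified: the first hunk above keeps the existing viewModelScope.launch wrapper around this call.
        llamaAndroid.destroy()                        // cancels the llama scope and frees the backend
    }
}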

View File

@@ -33,8 +33,8 @@ static std::string join(const std::vector<T> &values, const std::string &delim)
 /**
  * LLama resources: context, model, batch and sampler
  */
-constexpr int N_THREADS_MIN = 1;
-constexpr int N_THREADS_MAX = 8;
+constexpr int N_THREADS_MIN = 2;
+constexpr int N_THREADS_MAX = 4;
 constexpr int N_THREADS_HEADROOM = 2;
 constexpr int DEFAULT_CONTEXT_SIZE = 8192;
 
@@ -70,38 +70,27 @@ static void log_callback(ggml_log_level level, const char *fmt, void *data) {
     __android_log_print(priority, TAG, fmt, data);
 }
 
-JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM *vm, void *reserved) {
-    JNIEnv *env;
-    if (vm->GetEnv(reinterpret_cast<void **>(&env), JNI_VERSION_1_6) != JNI_OK) {
-        return JNI_ERR;
-    }
-
+extern "C"
+JNIEXPORT void JNICALL
+Java_android_llama_cpp_LLamaAndroid_init(JNIEnv *env, jobject /*unused*/) {
     // Set llama log handler to Android
     llama_log_set(log_callback, nullptr);
 
     // Initialize backends
     llama_backend_init();
-    LOGi("Backend initiated.");
-
-    return JNI_VERSION_1_6;
-}
-
-extern "C"
-JNIEXPORT jstring JNICALL
-Java_android_llama_cpp_LLamaAndroid_systemInfo(JNIEnv *env, jobject /*unused*/) {
-    return env->NewStringUTF(llama_print_system_info());
+    LOGi("Backend initiated; Log handler set.");
 }
 
 extern "C"
 JNIEXPORT jint JNICALL
-Java_android_llama_cpp_LLamaAndroid_loadModel(JNIEnv *env, jobject, jstring filename) {
+Java_android_llama_cpp_LLamaAndroid_load(JNIEnv *env, jobject, jstring jmodel_path) {
     llama_model_params model_params = llama_model_default_params();
 
-    const auto *path_to_model = env->GetStringUTFChars(filename, 0);
-    LOGd("%s: Loading model from: \n%s\n", __func__, path_to_model);
+    const auto *model_path = env->GetStringUTFChars(jmodel_path, 0);
+    LOGd("%s: Loading model from: \n%s\n", __func__, model_path);
 
-    auto *model = llama_model_load_from_file(path_to_model, model_params);
-    env->ReleaseStringUTFChars(filename, path_to_model);
+    auto *model = llama_model_load_from_file(model_path, model_params);
+    env->ReleaseStringUTFChars(jmodel_path, model_path);
 
     if (!model) {
         return 1;
     }
@@ -148,7 +137,7 @@ static common_sampler *new_sampler(float temp) {
 
 extern "C"
 JNIEXPORT jint JNICALL
-Java_android_llama_cpp_LLamaAndroid_initContext(JNIEnv * /*env*/, jobject /*unused*/) {
+Java_android_llama_cpp_LLamaAndroid_prepare(JNIEnv * /*env*/, jobject /*unused*/) {
     auto *context = init_context(g_model);
     if (!context) { return 1; }
     g_context = context;
@@ -158,17 +147,6 @@ Java_android_llama_cpp_LLamaAndroid_initContext(JNIEnv * /*env*/, jobject /*unus
     return 0;
 }
 
-extern "C"
-JNIEXPORT void JNICALL
-Java_android_llama_cpp_LLamaAndroid_cleanUp(JNIEnv * /*unused*/, jobject /*unused*/) {
-    common_sampler_free(g_sampler);
-    g_chat_templates.reset();
-    llama_batch_free(g_batch);
-    llama_free(g_context);
-    llama_model_free(g_model);
-    llama_backend_free();
-}
-
 static std::string get_backend() {
     std::vector<std::string> backends;
     for (size_t i = 0; i < ggml_backend_reg_count(); i++) {
@@ -181,6 +159,12 @@ static std::string get_backend() {
     return backends.empty() ? "CPU" : join(backends, ",");
 }
 
+extern "C"
+JNIEXPORT jstring JNICALL
+Java_android_llama_cpp_LLamaAndroid_systemInfo(JNIEnv *env, jobject /*unused*/) {
+    return env->NewStringUTF(llama_print_system_info());
+}
+
 extern "C"
 JNIEXPORT jstring JNICALL
 Java_android_llama_cpp_LLamaAndroid_benchModel(JNIEnv *env, jobject /*unused*/, jint pp, jint tg,
@@ -269,7 +253,7 @@ Java_android_llama_cpp_LLamaAndroid_benchModel(JNIEnv *env, jobject /*unused*/,
     const auto backend = get_backend();
 
     std::stringstream result;
-    result << std::setprecision(2);
+    result << std::setprecision(3);
     result << "| model | size | params | backend | test | t/s |\n";
     result << "| --- | --- | --- | --- | --- | --- |\n";
     result << "| " << model_desc << " | " << model_size << "GiB | " << model_n_params << "B | "
@@ -348,7 +332,7 @@ static void reset_short_term_states() {
 
 static int decode_tokens_in_batches(
     llama_context *context,
-    llama_batch batch,
+    llama_batch &batch,
     const llama_tokens &tokens,
     const llama_pos start_pos,
     const bool compute_last_logit = false) {
@@ -574,3 +558,25 @@ Java_android_llama_cpp_LLamaAndroid_generateNextToken(
     }
 
     return result;
 }
+
+extern "C"
+JNIEXPORT void JNICALL
+Java_android_llama_cpp_LLamaAndroid_unload(JNIEnv * /*unused*/, jobject /*unused*/) {
+    // Reset long-term & short-term states
+    reset_long_term_states();
+    reset_short_term_states();
+
+    // Free up resources
+    common_sampler_free(g_sampler);
+    g_chat_templates.reset();
+    llama_batch_free(g_batch);
+    llama_free(g_context);
+    llama_model_free(g_model);
+}
+
+extern "C"
+JNIEXPORT void JNICALL
+Java_android_llama_cpp_LLamaAndroid_shutdown(JNIEnv *env, jobject /*unused*/) {
+    llama_backend_free();
+}
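
The old cleanUp() freed the backend along with everything else; the rewrite splits it into unload() for per-model resources and shutdown() for the ggml backends. On the Kotlin side that split is what allows swapping models on one instance without reinitializing the library. A minimal sketch of that flow against the new LLamaAndroid API follows; the model paths and the blocking wrapper are placeholders, not part of the commit.

// Sketch: swapping models on a single LLamaAndroid instance.
// unloadModel() maps to the native unload() above (frees sampler, batch, context, model)
// and returns the wrapper to LibraryLoaded, so a second loadModel() passes its state check;
// destroy() maps to unload() followed by shutdown() (llama_backend_free()).
import android.llama.cpp.LLamaAndroid
import kotlinx.coroutines.runBlocking

fun swapModels(llamaAndroid: LLamaAndroid) = runBlocking {
    llamaAndroid.loadModel("/path/to/model-a.gguf")   // placeholder path
    // ... send prompts against model A ...
    llamaAndroid.unloadModel()                        // backends stay initialized
    llamaAndroid.loadModel("/path/to/model-b.gguf")   // placeholder path
    // ... send prompts against model B ...
    llamaAndroid.destroy()                            // final teardown
}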

View File

@@ -1,124 +1,163 @@
 package android.llama.cpp
 
 import android.util.Log
-import kotlinx.coroutines.CoroutineDispatcher
-import kotlinx.coroutines.asCoroutineDispatcher
+import kotlinx.coroutines.CoroutineScope
+import kotlinx.coroutines.Dispatchers
+import kotlinx.coroutines.ExperimentalCoroutinesApi
+import kotlinx.coroutines.SupervisorJob
+import kotlinx.coroutines.cancel
 import kotlinx.coroutines.flow.Flow
+import kotlinx.coroutines.flow.MutableStateFlow
+import kotlinx.coroutines.flow.StateFlow
 import kotlinx.coroutines.flow.flow
 import kotlinx.coroutines.flow.flowOn
+import kotlinx.coroutines.launch
 import kotlinx.coroutines.withContext
-import java.util.concurrent.Executors
-import kotlin.concurrent.thread
+import java.io.File
 
-class LLamaAndroid {
+@Target(AnnotationTarget.FUNCTION)
+@Retention(AnnotationRetention.SOURCE)
+annotation class RequiresCleanup(val message: String = "Remember to call this method for proper cleanup!")
+
+/**
+ * JNI wrapper for the llama.cpp library providing Android-friendly access to large language models.
+ *
+ * This class implements a singleton pattern for managing the lifecycle of a single LLM instance.
+ * All operations are executed on a dedicated single-threaded dispatcher to ensure thread safety
+ * with the underlying C++ native code.
+ *
+ * The typical usage flow is:
+ * 1. Get instance via [instance]
+ * 2. Load a model with [loadModel]
+ * 3. Send prompts with [sendUserPrompt]
+ * 4. Generate responses as token streams
+ * 5. Unload the model with [unloadModel] when switching models
+ * 6. Call [destroy] when completely done
+ *
+ * State transitions are managed automatically and validated at each operation.
+ *
+ * @see llama-android.cpp for the native implementation details
+ */
+class LLamaAndroid private constructor() {
     /**
      * JNI methods
      * @see llama-android.cpp
     */
+    private external fun init()
+    private external fun load(modelPath: String): Int
+    private external fun prepare(): Int
     private external fun systemInfo(): String
-    private external fun loadModel(filename: String): Int
-    private external fun initContext(): Int
-    private external fun cleanUp()
     private external fun benchModel(pp: Int, tg: Int, pl: Int, nr: Int): String
     private external fun processSystemPrompt(systemPrompt: String): Int
    private external fun processUserPrompt(userPrompt: String, predictLength: Int): Int
     private external fun generateNextToken(): String?
+    private external fun unload()
+    private external fun shutdown()
 
     /**
-     * Thread local state
+     * Fine-grained state management
      */
-    private sealed interface State {
-        data object NotInitialized: State
-        data object EnvReady: State
-        data object AwaitingUserPrompt: State
-        data object Processing: State
+    sealed class State {
+        object Uninitialized : State()
+        object LibraryLoaded : State()
+
+        object LoadingModel : State()
+        object ModelLoaded : State()
+        object ProcessingSystemPrompt : State()
+        object AwaitingUserPrompt : State()
+        object ProcessingUserPrompt : State()
+        object Generating : State()
+        object Benchmarking : State()
+
+        data class Error(
+            val errorMessage: String = ""
+        ) : State()
     }
 
-    private val threadLocalState: ThreadLocal<State> = ThreadLocal.withInitial { State.NotInitialized }
+    private val _state = MutableStateFlow<State>(State.Uninitialized)
+    val state: StateFlow<State> = _state
 
-    private val runLoop: CoroutineDispatcher = Executors.newSingleThreadExecutor {
-        thread(start = false, name = LLAMA_THREAD) {
-            Log.d(TAG, "Dedicated thread for native code: ${Thread.currentThread().name}")
-
-            // No-op if called more than once.
-            System.loadLibrary(LIB_LLAMA_ANDROID)
-
-            Log.d(TAG, systemInfo())
-
-            it.run()
-        }.apply {
-            uncaughtExceptionHandler = Thread.UncaughtExceptionHandler { _, exception: Throwable ->
-                Log.e(TAG, "Unhandled exception", exception)
-            }
-        }
-    }.asCoroutineDispatcher()
+    /**
+     * Single-threaded coroutine dispatcher & scope for LLama asynchronous operations
+     */
+    @OptIn(ExperimentalCoroutinesApi::class)
+    private val llamaDispatcher = Dispatchers.IO.limitedParallelism(1)
+    private val llamaScope = CoroutineScope(llamaDispatcher + SupervisorJob())
+
+    init {
+        llamaScope.launch {
+            try {
+                System.loadLibrary(LIB_LLAMA_ANDROID)
+                init()
+                _state.value = State.LibraryLoaded
+                Log.i(TAG, "Native library loaded! System info: \n${systemInfo()}")
+            } catch (e: Exception) {
+                _state.value = State.Error("Failed to load native library: ${e.message}")
+                Log.e(TAG, "Failed to load native library", e)
+            }
+        }
+    }
 
     /**
-     * Load the LLM, then process the formatted system prompt if provided
+     * Load the LLM, then process the plain text system prompt if provided
      */
-    suspend fun load(pathToModel: String, systemPrompt: String? = null) =
-        withContext(runLoop) {
-            when (threadLocalState.get()) {
-                is State.NotInitialized -> {
-                    val modelResult = loadModel(pathToModel)
-                    if (modelResult != 0) throw IllegalStateException("Load model failed: $modelResult")
-
-                    val initResult = initContext()
-                    if (initResult != 0) throw IllegalStateException("Initialization failed with error code: $initResult")
-
-                    Log.i(TAG, "Loaded model $pathToModel")
-                    threadLocalState.set(State.EnvReady)
-
-                    systemPrompt?.let {
-                        initWithSystemPrompt(systemPrompt)
-                    } ?: run {
-                        Log.w(TAG, "No system prompt to process.")
-                        threadLocalState.set(State.AwaitingUserPrompt)
-                    }
-                }
-                else -> throw IllegalStateException("Model already loaded")
-            }
-        }
+    suspend fun loadModel(pathToModel: String, systemPrompt: String? = null) =
+        withContext(llamaDispatcher) {
+            check(_state.value is State.LibraryLoaded) { "Cannot load model in ${_state.value}!" }
+            File(pathToModel).let {
+                require(it.exists()) { "Model file not found: $pathToModel" }
+                require(it.isFile) { "Model file is not a file: $pathToModel" }
+            }
+
+            Log.i(TAG, "Loading model... \n$pathToModel")
+            _state.value = State.LoadingModel
+            load(pathToModel).let { result ->
+                if (result != 0) throw IllegalStateException("Failed to Load model: $result")
+            }
+            prepare().let { result ->
+                if (result != 0) throw IllegalStateException("Failed to prepare resources: $result")
+            }
+            Log.i(TAG, "Model loaded!")
+            _state.value = State.ModelLoaded
+
+            systemPrompt?.let { prompt ->
+                Log.i(TAG, "Sending system prompt...")
+                _state.value = State.ProcessingSystemPrompt
+                processSystemPrompt(prompt).let { result ->
+                    if (result != 0) {
+                        val errorMessage = "Failed to process system prompt: $result"
+                        _state.value = State.Error(errorMessage)
+                        throw IllegalStateException(errorMessage)
+                    }
+                }
+                Log.i(TAG, "System prompt processed! Awaiting user prompt...")
+            } ?: run {
+                Log.w(TAG, "No system prompt to process.")
+            }
+
+            _state.value = State.AwaitingUserPrompt
+        }
 
     /**
-     * Helper method to process system prompt and update [State]
-     */
-    private suspend fun initWithSystemPrompt(formattedMessage: String) =
-        withContext(runLoop) {
-            when (threadLocalState.get()) {
-                is State.EnvReady -> {
-                    Log.i(TAG, "Process system prompt...")
-                    threadLocalState.set(State.Processing)
-
-                    processSystemPrompt(formattedMessage).let {
-                        if (it != 0)
-                            throw IllegalStateException("Failed to process system prompt: $it")
-                    }
-
-                    Log.i(TAG, "System prompt processed!")
-                    threadLocalState.set(State.AwaitingUserPrompt)
-                }
-                else -> throw IllegalStateException(
-                    "Failed to process system prompt: Model not loaded!"
-                )
-            }
-        }
-
-    /**
-     * Send formatted user prompt to LLM
+     * Send plain text user prompt to LLM, which starts generating tokens in a [Flow]
      */
     fun sendUserPrompt(
         message: String,
         predictLength: Int = DEFAULT_PREDICT_LENGTH,
     ): Flow<String> = flow {
-        require(message.isNotEmpty()) {
-            Log.w(TAG, "User prompt discarded due to being empty!")
+        require(message.isNotEmpty()) { "User prompt discarded due to being empty!" }
+        check(_state.value is State.AwaitingUserPrompt) {
+            "User prompt discarded due to: ${_state.value}"
         }
 
-        when (val state = threadLocalState.get()) {
-            is State.AwaitingUserPrompt -> {
         Log.i(TAG, "Sending user prompt...")
-        threadLocalState.set(State.Processing)
+        _state.value = State.ProcessingUserPrompt
         processUserPrompt(message, predictLength).let { result ->
            if (result != 0) {
                 Log.e(TAG, "Failed to process user prompt: $result")
@@ -127,64 +166,64 @@ class LLamaAndroid {
             }
 
        Log.i(TAG, "User prompt processed! Generating assistant prompt...")
+        _state.value = State.Generating
 
         while (true) {
             generateNextToken()?.let { utf8token ->
                 if (utf8token.isNotEmpty()) emit(utf8token)
            } ?: break
         }
 
-        Log.i(TAG, "Assistant generation complete!")
-        threadLocalState.set(State.AwaitingUserPrompt)
-            }
-            else -> {
-                Log.w(TAG, "User prompt discarded due to incorrect state: $state")
-            }
-        }
-    }.flowOn(runLoop)
+        Log.i(TAG, "Assistant generation complete! Awaiting user prompt...")
+        _state.value = State.AwaitingUserPrompt
+    }.flowOn(llamaDispatcher)
 
     /**
      * Benchmark the model
     */
     suspend fun bench(pp: Int, tg: Int, pl: Int, nr: Int = 1): String =
-        withContext(runLoop) {
-            when (threadLocalState.get()) {
-                is State.AwaitingUserPrompt -> {
-                    threadLocalState.set(State.Processing)
-                    Log.d(TAG, "Start benchmark (pp: $pp, tg: $tg, pl: $pl, nr: $nr)")
-                    benchModel(pp, tg, pl, nr).also {
-                        threadLocalState.set(State.AwaitingUserPrompt)
-                    }
-                }
-                // TODO-hyin: disable button when state incorrect
-                else -> throw IllegalStateException("No model loaded")
-            }
-        }
+        withContext(llamaDispatcher) {
+            check(_state.value is State.AwaitingUserPrompt) {
+                "Benchmark request discarded due to: $state"
+            }
+
+            Log.i(TAG, "Start benchmark (pp: $pp, tg: $tg, pl: $pl, nr: $nr)")
+            _state.value = State.Benchmarking
+            benchModel(pp, tg, pl, nr).also {
+                _state.value = State.AwaitingUserPrompt
+            }
+        }
 
     /**
-     * Unloads the model and frees resources.
-     *
-     * This is a no-op if there's no model loaded.
+     * Unloads the model and frees resources
      */
-    suspend fun unload() =
-        withContext(runLoop) {
-            when (val state = threadLocalState.get()) {
-                is State.EnvReady, State.AwaitingUserPrompt -> {
-                    cleanUp()
-                    threadLocalState.set(State.NotInitialized)
-                }
-                else -> {
-                    Log.w(TAG, "Cannot unload model due to incorrect state: $state")
-                }
-            }
-        }
+    suspend fun unloadModel() =
+        withContext(llamaDispatcher) {
+            when(_state.value) {
+                is State.AwaitingUserPrompt, is State.Error -> {
+                    Log.i(TAG, "Unloading model and free resources...")
+                    unload()
+                    _state.value = State.LibraryLoaded
+                    Log.i(TAG, "Model unloaded!")
+                }
+                else -> throw IllegalStateException("Cannot unload model in ${_state.value}")
+            }
+        }
+
+    /**
+     * Cancel all ongoing coroutines and free GGML backends
+     */
+    @RequiresCleanup("Call from `ViewModel.onCleared()` to prevent resource leaks!")
+    fun destroy() {
+        llamaScope.cancel()
+        when(_state.value) {
+            is State.Uninitialized -> {}
+            is State.LibraryLoaded -> shutdown()
+            else -> { unload(); shutdown() }
+        }
+    }
 
     companion object {
         private val TAG = LLamaAndroid::class.simpleName
         private const val LIB_LLAMA_ANDROID = "llama-android"
-        private const val LLAMA_THREAD = "llama-thread"
         private const val DEFAULT_PREDICT_LENGTH = 64
 
         // Enforce only one instance of Llm.
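
For reference, the six-step flow from the class KDoc above, written out against the new API. This is a hedged sketch: the model path, prompts, and the blocking wrapper are placeholders, and the instance is assumed to come from the companion's [instance] accessor (step 1), whose exact shape is cut off in this excerpt.

// Sketch of the KDoc usage flow; all literals are placeholders.
import android.llama.cpp.LLamaAndroid
import kotlinx.coroutines.runBlocking

fun usageFlowSketch(llamaAndroid: LLamaAndroid) = runBlocking {   // step 1: instance from [instance]
    llamaAndroid.loadModel(                                       // step 2: native load() + prepare()
        pathToModel = "/data/local/tmp/model.gguf",               // placeholder path
        systemPrompt = "You are a helpful assistant."             // optional, processed before user prompts
    )
    llamaAndroid.sendUserPrompt("Hello!")                         // steps 3 and 4: tokens arrive as a Flow
        .collect { token -> print(token) }
    llamaAndroid.unloadModel()                                    // step 5: back to LibraryLoaded
    llamaAndroid.destroy()                                        // step 6: cancel scope, free backends
}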