REWRITE JNI bridge; Update viewmodel

Han Yin 2025-04-09 10:09:23 -07:00
parent e1bc87610e
commit 6d2279e9cd
3 changed files with 218 additions and 163 deletions

View File

@@ -29,7 +29,7 @@ class MainViewModel(private val llamaAndroid: LLamaAndroid = LLamaAndroid.instan
viewModelScope.launch {
try {
llamaAndroid.unload()
llamaAndroid.destroy()
} catch (exc: IllegalStateException) {
messages += exc.message!!
}
@@ -83,7 +83,7 @@ class MainViewModel(private val llamaAndroid: LLamaAndroid = LLamaAndroid.instan
fun load(pathToModel: String) {
viewModelScope.launch {
try {
llamaAndroid.load(pathToModel)
llamaAndroid.loadModel(pathToModel)
messages += "Loaded $pathToModel"
} catch (exc: IllegalStateException) {
Log.e(tag, "load() failed", exc)
@@ -103,4 +103,14 @@ class MainViewModel(private val llamaAndroid: LLamaAndroid = LLamaAndroid.instan
fun log(message: String) {
messages += message
}
fun unload() =
viewModelScope.launch {
try {
llamaAndroid.unloadModel()
} catch (exc: IllegalStateException) {
Log.e(tag, "unload() failed", exc)
messages += exc.message!!
}
}
}
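
For orientation, the sketch below (not part of this commit) shows how a ViewModel might pair the renamed API after this change: loadModel() per model, unloadModel() when switching models, and a single destroy() from onCleared(), as the @RequiresCleanup annotation in LLamaAndroid.kt further down suggests. The class name ChatViewModel and the injected constructor parameter are illustrative assumptions.

import android.llama.cpp.LLamaAndroid
import android.util.Log
import androidx.lifecycle.ViewModel
import androidx.lifecycle.viewModelScope
import kotlinx.coroutines.launch

// Illustrative only; mirrors the calls shown in the diff above.
class ChatViewModel(private val llamaAndroid: LLamaAndroid) : ViewModel() {

    // Load (and prepare) a model; unloadModel() is the counterpart when switching models.
    fun load(pathToModel: String) = viewModelScope.launch {
        try {
            llamaAndroid.loadModel(pathToModel)
        } catch (exc: IllegalStateException) {
            Log.e("ChatViewModel", "load() failed", exc)
        }
    }

    // destroy() runs exactly once, when the ViewModel is cleared, to cancel the
    // library's coroutine scope and free the GGML backends.
    override fun onCleared() {
        super.onCleared()
        llamaAndroid.destroy()
    }
}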

View File

@@ -33,8 +33,8 @@ static std::string join(const std::vector<T> &values, const std::string &delim)
/**
* LLama resources: context, model, batch and sampler
*/
constexpr int N_THREADS_MIN = 1;
constexpr int N_THREADS_MAX = 8;
constexpr int N_THREADS_MIN = 2;
constexpr int N_THREADS_MAX = 4;
constexpr int N_THREADS_HEADROOM = 2;
constexpr int DEFAULT_CONTEXT_SIZE = 8192;
@@ -70,38 +70,27 @@ static void log_callback(ggml_log_level level, const char *fmt, void *data) {
__android_log_print(priority, TAG, fmt, data);
}
JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM *vm, void *reserved) {
JNIEnv *env;
if (vm->GetEnv(reinterpret_cast<void **>(&env), JNI_VERSION_1_6) != JNI_OK) {
return JNI_ERR;
}
extern "C"
JNIEXPORT void JNICALL
Java_android_llama_cpp_LLamaAndroid_init(JNIEnv *env, jobject /*unused*/) {
// Set llama log handler to Android
llama_log_set(log_callback, nullptr);
// Initialize backends
llama_backend_init();
LOGi("Backend initiated.");
return JNI_VERSION_1_6;
}
extern "C"
JNIEXPORT jstring JNICALL
Java_android_llama_cpp_LLamaAndroid_systemInfo(JNIEnv *env, jobject /*unused*/) {
return env->NewStringUTF(llama_print_system_info());
LOGi("Backend initiated; Log handler set.");
}
extern "C"
JNIEXPORT jint JNICALL
Java_android_llama_cpp_LLamaAndroid_loadModel(JNIEnv *env, jobject, jstring filename) {
Java_android_llama_cpp_LLamaAndroid_load(JNIEnv *env, jobject, jstring jmodel_path) {
llama_model_params model_params = llama_model_default_params();
const auto *path_to_model = env->GetStringUTFChars(filename, 0);
LOGd("%s: Loading model from: \n%s\n", __func__, path_to_model);
const auto *model_path = env->GetStringUTFChars(jmodel_path, 0);
LOGd("%s: Loading model from: \n%s\n", __func__, model_path);
auto *model = llama_model_load_from_file(path_to_model, model_params);
env->ReleaseStringUTFChars(filename, path_to_model);
auto *model = llama_model_load_from_file(model_path, model_params);
env->ReleaseStringUTFChars(jmodel_path, model_path);
if (!model) {
return 1;
}
@@ -148,7 +137,7 @@ static common_sampler *new_sampler(float temp) {
extern "C"
JNIEXPORT jint JNICALL
Java_android_llama_cpp_LLamaAndroid_initContext(JNIEnv * /*env*/, jobject /*unused*/) {
Java_android_llama_cpp_LLamaAndroid_prepare(JNIEnv * /*env*/, jobject /*unused*/) {
auto *context = init_context(g_model);
if (!context) { return 1; }
g_context = context;
@@ -158,17 +147,6 @@ Java_android_llama_cpp_LLamaAndroid_initContext(JNIEnv * /*env*/, jobject /*unus
return 0;
}
extern "C"
JNIEXPORT void JNICALL
Java_android_llama_cpp_LLamaAndroid_cleanUp(JNIEnv * /*unused*/, jobject /*unused*/) {
common_sampler_free(g_sampler);
g_chat_templates.reset();
llama_batch_free(g_batch);
llama_free(g_context);
llama_model_free(g_model);
llama_backend_free();
}
static std::string get_backend() {
std::vector<std::string> backends;
for (size_t i = 0; i < ggml_backend_reg_count(); i++) {
@@ -181,6 +159,12 @@ static std::string get_backend() {
return backends.empty() ? "CPU" : join(backends, ",");
}
extern "C"
JNIEXPORT jstring JNICALL
Java_android_llama_cpp_LLamaAndroid_systemInfo(JNIEnv *env, jobject /*unused*/) {
return env->NewStringUTF(llama_print_system_info());
}
extern "C"
JNIEXPORT jstring JNICALL
Java_android_llama_cpp_LLamaAndroid_benchModel(JNIEnv *env, jobject /*unused*/, jint pp, jint tg,
@@ -269,7 +253,7 @@ Java_android_llama_cpp_LLamaAndroid_benchModel(JNIEnv *env, jobject /*unused*/,
const auto backend = get_backend();
std::stringstream result;
result << std::setprecision(2);
result << std::setprecision(3);
result << "| model | size | params | backend | test | t/s |\n";
result << "| --- | --- | --- | --- | --- | --- |\n";
result << "| " << model_desc << " | " << model_size << "GiB | " << model_n_params << "B | "
@@ -348,7 +332,7 @@ static void reset_short_term_states() {
static int decode_tokens_in_batches(
llama_context *context,
llama_batch batch,
llama_batch &batch,
const llama_tokens &tokens,
const llama_pos start_pos,
const bool compute_last_logit = false) {
@@ -574,3 +558,25 @@ Java_android_llama_cpp_LLamaAndroid_generateNextToken(
}
return result;
}
extern "C"
JNIEXPORT void JNICALL
Java_android_llama_cpp_LLamaAndroid_unload(JNIEnv * /*unused*/, jobject /*unused*/) {
// Reset long-term & short-term states
reset_long_term_states();
reset_short_term_states();
// Free up resources
common_sampler_free(g_sampler);
g_chat_templates.reset();
llama_batch_free(g_batch);
llama_free(g_context);
llama_model_free(g_model);
}
extern "C"
JNIEXPORT void JNICALL
Java_android_llama_cpp_LLamaAndroid_shutdown(JNIEnv *env, jobject /*unused*/) {
llama_backend_free();
}
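
The rewritten entry points above follow a consistent convention: lifecycle steps that can fail (load, prepare, processSystemPrompt, processUserPrompt) return a jint status where 0 means success, while the teardown calls (unload, shutdown) return nothing. As a hedged illustration (checkNative is not part of the commit), the Kotlin side could fold that convention into a small helper instead of checking each return code inline:

// Illustrative helper: turns the 0-on-success status codes of the native
// entry points into exceptions on the Kotlin side.
private inline fun checkNative(operation: String, call: () -> Int) {
    val status = call()
    check(status == 0) { "$operation failed with native status $status" }
}

// Usage, mirroring the load path in LLamaAndroid.loadModel() below:
//   checkNative("load model")        { load(pathToModel) }
//   checkNative("prepare resources") { prepare() }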

View File

@@ -1,190 +1,229 @@
package android.llama.cpp
import android.util.Log
import kotlinx.coroutines.CoroutineDispatcher
import kotlinx.coroutines.asCoroutineDispatcher
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.SupervisorJob
import kotlinx.coroutines.cancel
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.flow.StateFlow
import kotlinx.coroutines.flow.flow
import kotlinx.coroutines.flow.flowOn
import kotlinx.coroutines.launch
import kotlinx.coroutines.withContext
import java.util.concurrent.Executors
import kotlin.concurrent.thread
import java.io.File
class LLamaAndroid {
@Target(AnnotationTarget.FUNCTION)
@Retention(AnnotationRetention.SOURCE)
annotation class RequiresCleanup(val message: String = "Remember to call this method for proper cleanup!")
/**
* JNI wrapper for the llama.cpp library providing Android-friendly access to large language models.
*
* This class implements a singleton pattern for managing the lifecycle of a single LLM instance.
* All operations are executed on a dedicated single-threaded dispatcher to ensure thread safety
* with the underlying C++ native code.
*
* The typical usage flow is:
* 1. Get instance via [instance]
* 2. Load a model with [loadModel]
* 3. Send prompts with [sendUserPrompt]
* 4. Generate responses as token streams
* 5. Unload the model with [unloadModel] when switching models
* 6. Call [destroy] when completely done
*
* State transitions are managed automatically and validated at each operation.
*
* @see llama-android.cpp for the native implementation details
*/
class LLamaAndroid private constructor() {
/**
* JNI methods
* @see llama-android.cpp
*/
private external fun init()
private external fun load(modelPath: String): Int
private external fun prepare(): Int
private external fun systemInfo(): String
private external fun loadModel(filename: String): Int
private external fun initContext(): Int
private external fun cleanUp()
private external fun benchModel(pp: Int, tg: Int, pl: Int, nr: Int): String
private external fun processSystemPrompt(systemPrompt: String): Int
private external fun processUserPrompt(userPrompt: String, predictLength: Int): Int
private external fun generateNextToken(): String?
private external fun unload()
private external fun shutdown()
/**
* Thread local state
* Fine-grained state management
*/
private sealed interface State {
data object NotInitialized: State
data object EnvReady: State
data object AwaitingUserPrompt: State
data object Processing: State
sealed class State {
object Uninitialized : State()
object LibraryLoaded : State()
object LoadingModel : State()
object ModelLoaded : State()
object ProcessingSystemPrompt : State()
object AwaitingUserPrompt : State()
object ProcessingUserPrompt : State()
object Generating : State()
object Benchmarking : State()
data class Error(
val errorMessage: String = ""
) : State()
}
private val threadLocalState: ThreadLocal<State> = ThreadLocal.withInitial { State.NotInitialized }
private val runLoop: CoroutineDispatcher = Executors.newSingleThreadExecutor {
thread(start = false, name = LLAMA_THREAD) {
Log.d(TAG, "Dedicated thread for native code: ${Thread.currentThread().name}")
// No-op if called more than once.
System.loadLibrary(LIB_LLAMA_ANDROID)
Log.d(TAG, systemInfo())
it.run()
}.apply {
uncaughtExceptionHandler = Thread.UncaughtExceptionHandler { _, exception: Throwable ->
Log.e(TAG, "Unhandled exception", exception)
}
}
}.asCoroutineDispatcher()
private val _state = MutableStateFlow<State>(State.Uninitialized)
val state: StateFlow<State> = _state
/**
* Load the LLM, then process the formatted system prompt if provided
* Single-threaded coroutine dispatcher & scope for LLama asynchronous operations
*/
suspend fun load(pathToModel: String, systemPrompt: String? = null) =
withContext(runLoop) {
when (threadLocalState.get()) {
is State.NotInitialized -> {
val modelResult = loadModel(pathToModel)
if (modelResult != 0) throw IllegalStateException("Load model failed: $modelResult")
@OptIn(ExperimentalCoroutinesApi::class)
private val llamaDispatcher = Dispatchers.IO.limitedParallelism(1)
private val llamaScope = CoroutineScope(llamaDispatcher + SupervisorJob())
val initResult = initContext()
if (initResult != 0) throw IllegalStateException("Initialization failed with error code: $initResult")
init {
llamaScope.launch {
try {
System.loadLibrary(LIB_LLAMA_ANDROID)
init()
_state.value = State.LibraryLoaded
Log.i(TAG, "Native library loaded! System info: \n${systemInfo()}")
} catch (e: Exception) {
_state.value = State.Error("Failed to load native library: ${e.message}")
Log.e(TAG, "Failed to load native library", e)
}
}
}
Log.i(TAG, "Loaded model $pathToModel")
threadLocalState.set(State.EnvReady)
/**
* Load the LLM, then process the plain text system prompt if provided
*/
suspend fun loadModel(pathToModel: String, systemPrompt: String? = null) =
withContext(llamaDispatcher) {
check(_state.value is State.LibraryLoaded) { "Cannot load model in ${_state.value}!" }
File(pathToModel).let {
require(it.exists()) { "Model file not found: $pathToModel" }
require(it.isFile) { "Model file is not a file: $pathToModel" }
}
systemPrompt?.let {
initWithSystemPrompt(systemPrompt)
} ?: run {
Log.w(TAG, "No system prompt to process.")
threadLocalState.set(State.AwaitingUserPrompt)
Log.i(TAG, "Loading model... \n$pathToModel")
_state.value = State.LoadingModel
load(pathToModel).let { result ->
if (result != 0) throw IllegalStateException("Failed to load model: $result")
}
prepare().let { result ->
if (result != 0) throw IllegalStateException("Failed to prepare resources: $result")
}
Log.i(TAG, "Model loaded!")
_state.value = State.ModelLoaded
systemPrompt?.let { prompt ->
Log.i(TAG, "Sending system prompt...")
_state.value = State.ProcessingSystemPrompt
processSystemPrompt(prompt).let { result ->
if (result != 0) {
val errorMessage = "Failed to process system prompt: $result"
_state.value = State.Error(errorMessage)
throw IllegalStateException(errorMessage)
}
}
else -> throw IllegalStateException("Model already loaded")
Log.i(TAG, "System prompt processed! Awaiting user prompt...")
} ?: run {
Log.w(TAG, "No system prompt to process.")
}
_state.value = State.AwaitingUserPrompt
}
/**
* Helper method to process system prompt and update [State]
*/
private suspend fun initWithSystemPrompt(formattedMessage: String) =
withContext(runLoop) {
when (threadLocalState.get()) {
is State.EnvReady -> {
Log.i(TAG, "Process system prompt...")
threadLocalState.set(State.Processing)
processSystemPrompt(formattedMessage).let {
if (it != 0)
throw IllegalStateException("Failed to process system prompt: $it")
}
Log.i(TAG, "System prompt processed!")
threadLocalState.set(State.AwaitingUserPrompt)
}
else -> throw IllegalStateException(
"Failed to process system prompt: Model not loaded!"
)
}
}
/**
* Send formatted user prompt to LLM
* Send plain text user prompt to LLM, which starts generating tokens in a [Flow]
*/
fun sendUserPrompt(
message: String,
predictLength: Int = DEFAULT_PREDICT_LENGTH,
): Flow<String> = flow {
require(message.isNotEmpty()) {
Log.w(TAG, "User prompt discarded due to being empty!")
require(message.isNotEmpty()) { "User prompt discarded due to being empty!" }
check(_state.value is State.AwaitingUserPrompt) {
"User prompt discarded due to: ${_state.value}"
}
when (val state = threadLocalState.get()) {
is State.AwaitingUserPrompt -> {
Log.i(TAG, "Sending user prompt...")
threadLocalState.set(State.Processing)
processUserPrompt(message, predictLength).let { result ->
if (result != 0) {
Log.e(TAG, "Failed to process user prompt: $result")
return@flow
}
}
Log.i(TAG, "User prompt processed! Generating assistant prompt...")
while (true) {
generateNextToken()?.let { utf8token ->
if (utf8token.isNotEmpty()) emit(utf8token)
} ?: break
}
Log.i(TAG, "Assistant generation complete!")
threadLocalState.set(State.AwaitingUserPrompt)
}
else -> {
Log.w(TAG, "User prompt discarded due to incorrect state: $state")
Log.i(TAG, "Sending user prompt...")
_state.value = State.ProcessingUserPrompt
processUserPrompt(message, predictLength).let { result ->
if (result != 0) {
Log.e(TAG, "Failed to process user prompt: $result")
return@flow
}
}
}.flowOn(runLoop)
Log.i(TAG, "User prompt processed! Generating assistant prompt...")
_state.value = State.Generating
while (true) {
generateNextToken()?.let { utf8token ->
if (utf8token.isNotEmpty()) emit(utf8token)
} ?: break
}
Log.i(TAG, "Assistant generation complete! Awaiting user prompt...")
_state.value = State.AwaitingUserPrompt
}.flowOn(llamaDispatcher)
/**
* Benchmark the model
*/
suspend fun bench(pp: Int, tg: Int, pl: Int, nr: Int = 1): String =
withContext(runLoop) {
when (threadLocalState.get()) {
is State.AwaitingUserPrompt -> {
threadLocalState.set(State.Processing)
Log.d(TAG, "Start benchmark (pp: $pp, tg: $tg, pl: $pl, nr: $nr)")
benchModel(pp, tg, pl, nr).also {
threadLocalState.set(State.AwaitingUserPrompt)
}
}
// TODO-hyin: disable button when state incorrect
else -> throw IllegalStateException("No model loaded")
withContext(llamaDispatcher) {
check(_state.value is State.AwaitingUserPrompt) {
"Benchmark request discarded due to: $state"
}
Log.i(TAG, "Start benchmark (pp: $pp, tg: $tg, pl: $pl, nr: $nr)")
_state.value = State.Benchmarking
benchModel(pp, tg, pl, nr).also {
_state.value = State.AwaitingUserPrompt
}
}
/**
* Unloads the model and frees resources.
*
* This is a no-op if there's no model loaded.
* Unloads the model and frees resources
*/
suspend fun unload() =
withContext(runLoop) {
when (val state = threadLocalState.get()) {
is State.EnvReady, State.AwaitingUserPrompt -> {
cleanUp()
threadLocalState.set(State.NotInitialized)
}
else -> {
Log.w(TAG, "Cannot unload model due to incorrect state: $state")
suspend fun unloadModel() =
withContext(llamaDispatcher) {
when(_state.value) {
is State.AwaitingUserPrompt, is State.Error -> {
Log.i(TAG, "Unloading model and free resources...")
unload()
_state.value = State.LibraryLoaded
Log.i(TAG, "Model unloaded!")
}
else -> throw IllegalStateException("Cannot unload model in ${_state.value}")
}
}
/**
* Cancel all ongoing coroutines and free GGML backends
*/
@RequiresCleanup("Call from `ViewModel.onCleared()` to prevent resource leaks!")
fun destroy() {
llamaScope.cancel()
when(_state.value) {
is State.Uninitialized -> {}
is State.LibraryLoaded -> shutdown()
else -> { unload(); shutdown() }
}
}
companion object {
private val TAG = LLamaAndroid::class.simpleName
private const val LIB_LLAMA_ANDROID = "llama-android"
private const val LLAMA_THREAD = "llama-thread"
private const val DEFAULT_PREDICT_LENGTH = 64
// Enforce only one instance of Llm.
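
As a closing illustration (separate from the diff above), this is one way a caller might consume the token Flow returned by sendUserPrompt(); the flow already runs on the library's single-threaded dispatcher via flowOn, so the collector only has to handle tokens and failures. The function name collectResponse and the log tag are assumptions.

import android.llama.cpp.LLamaAndroid
import android.util.Log
import kotlinx.coroutines.flow.catch
import kotlinx.coroutines.flow.collect
import kotlinx.coroutines.flow.onCompletion

// Illustrative collector for the streaming response.
suspend fun collectResponse(llamaAndroid: LLamaAndroid, prompt: String, onToken: (String) -> Unit) {
    llamaAndroid.sendUserPrompt(prompt)
        .catch { e -> Log.e("LlamaClient", "Generation failed", e) } // require()/check() failures surface here
        .onCompletion { Log.i("LlamaClient", "Generation finished") }
        .collect { token -> onToken(token) }
}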