Polish binding: Remove verbose setup JNI APIs; Update state machine states.
parent 7dc9968f82
commit 0ade7fb4d7
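Most of the renames in this commit exist because JNI escapes each `_` in a Java/Kotlin method name as `_1` in the exported native symbol, so snake_case externals bind to awkward `_1`-riddled C symbols, while camelCase names map one-to-one. A brief illustration (signatures taken from this diff; the comments are editorial):

```kotlin
package android.llama.cpp

class LLamaAndroid {
    // camelCase: resolves to Java_android_llama_cpp_LLamaAndroid_systemInfo
    private external fun systemInfo(): String

    // The old snake_case form resolved to
    // Java_android_llama_cpp_LLamaAndroid_system_1info, because JNI
    // escapes "_" in the method name as "_1" in the native symbol.
    // private external fun system_info(): String
}
```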
llama-android.cpp

@@ -12,10 +12,11 @@
  * Logging utils
  */
 #define TAG "llama-android.cpp"
-#define LOGd(...) __android_log_print(ANDROID_LOG_DEBUG, TAG, __VA_ARGS__)
-#define LOGi(...) __android_log_print(ANDROID_LOG_INFO, TAG, __VA_ARGS__)
-#define LOGw(...) __android_log_print(ANDROID_LOG_WARN, TAG, __VA_ARGS__)
-#define LOGe(...) __android_log_print(ANDROID_LOG_ERROR, TAG, __VA_ARGS__)
+#define LOGv(...) __android_log_print(ANDROID_LOG_VERBOSE, TAG, __VA_ARGS__)
+#define LOGd(...) __android_log_print(ANDROID_LOG_DEBUG, TAG, __VA_ARGS__)
+#define LOGi(...) __android_log_print(ANDROID_LOG_INFO, TAG, __VA_ARGS__)
+#define LOGw(...) __android_log_print(ANDROID_LOG_WARN, TAG, __VA_ARGS__)
+#define LOGe(...) __android_log_print(ANDROID_LOG_ERROR, TAG, __VA_ARGS__)
 
 /**
  * LLama resources: context, model, batch and sampler
@@ -55,32 +56,35 @@ static void log_callback(ggml_log_level level, const char *fmt, void *data) {
     __android_log_print(priority, TAG, fmt, data);
 }
 
-extern "C"
-JNIEXPORT void JNICALL
-Java_android_llama_cpp_LLamaAndroid_log_1to_1android(JNIEnv * /*unused*/, jobject /*unused*/) {
-    llama_log_set(log_callback, nullptr);
-}
+JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM* vm, void* reserved) {
+    JNIEnv* env;
+    if (vm->GetEnv(reinterpret_cast<void**>(&env), JNI_VERSION_1_6) != JNI_OK) {
+        return JNI_ERR;
+    }
+
+    // Set llama log handler to Android
+    llama_log_set(log_callback, nullptr);
+
+    // Initialize backends
+    llama_backend_init();
+    LOGi("Backend initiated.");
+
+    return JNI_VERSION_1_6;
+}
 
 extern "C"
 JNIEXPORT jstring JNICALL
-Java_android_llama_cpp_LLamaAndroid_system_1info(JNIEnv *env, jobject /*unused*/) {
+Java_android_llama_cpp_LLamaAndroid_systemInfo(JNIEnv *env, jobject /*unused*/) {
     return env->NewStringUTF(llama_print_system_info());
 }
 
-extern "C"
-JNIEXPORT void JNICALL
-Java_android_llama_cpp_LLamaAndroid_backend_1init(JNIEnv * /*unused*/, jobject /*unused*/) {
-    llama_backend_init();
-}
-
 extern "C"
 JNIEXPORT jint JNICALL
-Java_android_llama_cpp_LLamaAndroid_load_1model(JNIEnv *env, jobject, jstring filename) {
+Java_android_llama_cpp_LLamaAndroid_loadModel(JNIEnv *env, jobject, jstring filename) {
     llama_model_params model_params = llama_model_default_params();
 
     const auto *path_to_model = env->GetStringUTFChars(filename, 0);
-    LOGi("Loading model from: %s", path_to_model);
+    LOGd("Loading model from: %s", path_to_model);
 
     model = llama_model_load_from_file(path_to_model, model_params);
     env->ReleaseStringUTFChars(filename, path_to_model);
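Note: the JVM calls `JNI_OnLoad` automatically when `System.loadLibrary` maps the shared object, which is what lets this commit delete the dedicated `log_to_android()` and `backend_init()` externals. A minimal sketch of the caller's view (the wrapper object is hypothetical):

```kotlin
object NativeBootstrap {
    init {
        // Triggers JNI_OnLoad, which installs the Android log callback and
        // calls llama_backend_init() before any external fun is invoked.
        System.loadLibrary("llama-android")
    }
}
```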
@@ -153,7 +157,7 @@ void new_sampler(float temp) {
 
 extern "C"
 JNIEXPORT jint JNICALL
-Java_android_llama_cpp_LLamaAndroid_ctx_1init(JNIEnv * /*env*/, jobject /*unused*/) {
+Java_android_llama_cpp_LLamaAndroid_initContext(JNIEnv * /*env*/, jobject /*unused*/) {
     int ret = init_context();
     if (ret != 0) { return ret; }
     new_batch(BATCH_SIZE);
@@ -163,7 +167,7 @@ Java_android_llama_cpp_LLamaAndroid_ctx_1init(JNIEnv * /*env*/, jobject /*unused*/) {
 
 extern "C"
 JNIEXPORT void JNICALL
-Java_android_llama_cpp_LLamaAndroid_clean_1up(JNIEnv * /*unused*/, jobject /*unused*/) {
+Java_android_llama_cpp_LLamaAndroid_cleanUp(JNIEnv * /*unused*/, jobject /*unused*/) {
     llama_model_free(model);
     llama_free(context);
     llama_backend_free();
@@ -173,7 +177,7 @@ Java_android_llama_cpp_LLamaAndroid_clean_1up(JNIEnv * /*unused*/, jobject /*unused*/) {
 
 extern "C"
 JNIEXPORT jstring JNICALL
-Java_android_llama_cpp_LLamaAndroid_bench_1model(JNIEnv *env, jobject /*unused*/, jint pp, jint tg, jint pl, jint nr) {
+Java_android_llama_cpp_LLamaAndroid_benchModel(JNIEnv *env, jobject /*unused*/, jint pp, jint tg, jint pl, jint nr) {
     auto pp_avg = 0.0;
     auto tg_avg = 0.0;
     auto pp_std = 0.0;
@@ -284,7 +288,7 @@ std::string cached_token_chars;
 
 extern "C"
 JNIEXPORT jint JNICALL
-Java_android_llama_cpp_LLamaAndroid_process_1system_1prompt(
+Java_android_llama_cpp_LLamaAndroid_processSystemPrompt(
     JNIEnv *env,
     jobject /*unused*/,
     jstring jsystem_prompt
@@ -299,10 +303,15 @@ Java_android_llama_cpp_LLamaAndroid_process_1system_1prompt(
 
     // Obtain and tokenize system prompt
     const auto *const system_text = env->GetStringUTFChars(jsystem_prompt, nullptr);
-    LOGi("System prompt: \n%s", system_text);
+    LOGd("System prompt received: \n%s", system_text);
     const auto system_tokens = common_tokenize(context, system_text, true, true);
     env->ReleaseStringUTFChars(jsystem_prompt, system_text);
 
+    // Print each token in verbose mode
+    for (auto id : system_tokens) {
+        LOGv("token: `%s`\t -> `%d`", common_token_to_piece(context, id).c_str(), id);
+    }
+
     // Add system prompt tokens to batch
     common_batch_clear(*batch);
     // TODO-hyin: support batch processing!
@@ -325,11 +334,11 @@ Java_android_llama_cpp_LLamaAndroid_process_1system_1prompt(
 // TODO-hyin: support KV cache backtracking
 extern "C"
 JNIEXPORT jint JNICALL
-Java_android_llama_cpp_LLamaAndroid_process_1user_1prompt(
+Java_android_llama_cpp_LLamaAndroid_processUserPrompt(
     JNIEnv *env,
     jobject /*unused*/,
     jstring juser_prompt,
-    jint n_len
+    jint n_predict
 ) {
     // Reset short-term states
     token_predict_budget = 0;
@@ -337,12 +346,17 @@ Java_android_llama_cpp_LLamaAndroid_process_1user_1prompt(
 
     // Obtain and tokenize user prompt
     const auto *const user_text = env->GetStringUTFChars(juser_prompt, nullptr);
-    LOGi("User prompt: \n%s", user_text);
+    LOGd("User prompt received: \n%s", user_text);
     const auto user_tokens = common_tokenize(context, user_text, true, true);
     env->ReleaseStringUTFChars(juser_prompt, user_text);
 
+    // Print each token in verbose mode
+    for (auto id : user_tokens) {
+        LOGv("token: `%s`\t -> `%d`", common_token_to_piece(context, id).c_str(), id);
+    }
+
     // Check if context space is enough for desired tokens
-    int desired_budget = current_position + user_tokens.size() + n_len;
+    int desired_budget = current_position + user_tokens.size() + n_predict;
     if (desired_budget > llama_n_ctx(context)) {
         LOGe("error: total tokens exceed context size");
         return -1;
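The guard above is plain arithmetic: the prompt is rejected when the running position plus the new prompt tokens plus the requested prediction budget would overflow the context window. A sketch of the same check with illustrative numbers (names mirror the C++ locals; not part of the binding):

```kotlin
// Illustrative only: with nCtx = 4096, currentPosition = 1000,
// promptTokens = 200, and nPredict = 64:
// 1000 + 200 + 64 = 1264 <= 4096, so the prompt fits.
fun fitsInContext(currentPosition: Int, promptTokens: Int, nPredict: Int, nCtx: Int): Boolean =
    currentPosition + promptTokens + nPredict <= nCtx
```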
@@ -404,13 +418,13 @@ bool is_valid_utf8(const char *string) {
 
 extern "C"
 JNIEXPORT jstring JNICALL
-Java_android_llama_cpp_LLamaAndroid_predict_1loop(
+Java_android_llama_cpp_LLamaAndroid_predictLoop(
     JNIEnv *env,
     jobject /*unused*/
 ) {
     // Stop if running out of token budget
     if (current_position >= token_predict_budget) {
-        LOGi("STOP: current position (%d) exceeds budget (%d)", current_position, token_predict_budget);
+        LOGw("STOP: current position (%d) exceeds budget (%d)", current_position, token_predict_budget);
         return nullptr;
     }
 
@@ -420,7 +434,7 @@ Java_android_llama_cpp_LLamaAndroid_predict_1loop(
 
     // Stop if next token is EOG
     if (llama_vocab_is_eog(llama_model_get_vocab(model), new_token_id)) {
-        LOGi("id: %d,\tIS EOG!\nSTOP.", new_token_id);
+        LOGd("id: %d,\tIS EOG!\nSTOP.", new_token_id);
         return nullptr;
     }
 

LLamaAndroid.kt
@@ -15,42 +15,36 @@ class LLamaAndroid {
      * JNI methods
      * @see llama-android.cpp
      */
-    private external fun log_to_android()
-    private external fun system_info(): String
-    private external fun backend_init()
+    private external fun systemInfo(): String
 
-    private external fun load_model(filename: String): Int
-    private external fun ctx_init(): Int
-    private external fun clean_up()
+    private external fun loadModel(filename: String): Int
+    private external fun initContext(): Int
+    private external fun cleanUp()
 
-    private external fun bench_model(pp: Int, tg: Int, pl: Int, nr: Int): String
+    private external fun benchModel(pp: Int, tg: Int, pl: Int, nr: Int): String
 
-    private external fun process_system_prompt(system_prompt: String): Int
-    private external fun process_user_prompt(user_prompt: String, nLen: Int): Int
-    private external fun predict_loop(): String?
+    private external fun processSystemPrompt(systemPrompt: String): Int
+    private external fun processUserPrompt(userPrompt: String, nPredict: Int): Int
+    private external fun predictLoop(): String?
 
     /**
      * Thread local state
      */
     private sealed interface State {
-        data object Idle: State
-        data object ModelLoaded: State
-        data object ReadyForUserPrompt: State
+        data object NotInitialized: State
+        data object EnvReady: State
+        data object AwaitingUserPrompt: State
+        data object Processing: State
     }
-    private val threadLocalState: ThreadLocal<State> = ThreadLocal.withInitial { State.Idle }
+    private val threadLocalState: ThreadLocal<State> = ThreadLocal.withInitial { State.NotInitialized }
 
     private val runLoop: CoroutineDispatcher = Executors.newSingleThreadExecutor {
-        thread(start = false, name = "Llm-RunLoop") {
+        thread(start = false, name = LLAMA_THREAD) {
             Log.d(TAG, "Dedicated thread for native code: ${Thread.currentThread().name}")
 
             // No-op if called more than once.
-            System.loadLibrary("llama-android")
-
-            // Set llama log handler to Android
-            log_to_android()
-            backend_init()
-
-            Log.d(TAG, system_info())
+            System.loadLibrary(LIB_LLAMA_ANDROID)
+            Log.d(TAG, systemInfo())
 
             it.run()
         }.apply {
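The renamed states read as a small lifecycle. A hedged sketch of the transition table they imply, reconstructed from load(), initWithSystemPrompt(), sendUserPrompt(), and unload() below, written as if it were a member of the class (not an API guarantee):

```kotlin
// Reconstructed from this diff; State is the sealed interface above.
private fun isLegalTransition(from: State, to: State): Boolean = when (from) {
    State.NotInitialized -> to == State.EnvReady                 // load()
    State.EnvReady -> to == State.Processing ||                  // system prompt given
            to == State.AwaitingUserPrompt ||                    // no system prompt
            to == State.NotInitialized                           // unload()
    State.Processing -> to == State.AwaitingUserPrompt           // prompt/generation done
    State.AwaitingUserPrompt -> to == State.Processing ||        // sendUserPrompt()
            to == State.NotInitialized                           // unload()
}
```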
@@ -61,26 +55,26 @@ class LLamaAndroid {
     }.asCoroutineDispatcher()
 
     /**
-     * Load the LLM, then process the system prompt if provided
+     * Load the LLM, then process the formatted system prompt if provided
      */
     suspend fun load(pathToModel: String, formattedSystemPrompt: String? = null) {
         withContext(runLoop) {
             when (threadLocalState.get()) {
-                is State.Idle -> {
-                    val model = load_model(pathToModel)
-                    if (model != 0) throw IllegalStateException("Load model failed")
+                is State.NotInitialized -> {
+                    val modelResult = loadModel(pathToModel)
+                    if (modelResult != 0) throw IllegalStateException("Load model failed: $modelResult")
 
-                    val result = ctx_init()
-                    if (result != 0) throw IllegalStateException("Initialization failed with error code: $result")
+                    val initResult = initContext()
+                    if (initResult != 0) throw IllegalStateException("Initialization failed with error code: $initResult")
 
                     Log.i(TAG, "Loaded model $pathToModel")
-                    threadLocalState.set(State.ModelLoaded)
+                    threadLocalState.set(State.EnvReady)
 
                     formattedSystemPrompt?.let {
                         initWithSystemPrompt(formattedSystemPrompt)
                     } ?: {
                         Log.w(TAG, "No system prompt to process.")
-                        threadLocalState.set(State.ReadyForUserPrompt)
+                        threadLocalState.set(State.AwaitingUserPrompt)
                     }
                 }
                 else -> throw IllegalStateException("Model already loaded")
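One caveat in the fallback branch above (a context line, so it predates this commit and survives it): in Kotlin, `?: { ... }` evaluates to the lambda object itself without invoking it, so the `Log.w` call and the state transition never run when `formattedSystemPrompt` is null. The presumably intended form uses `run`:

```kotlin
formattedSystemPrompt?.let {
    initWithSystemPrompt(it)
} ?: run { // run { } executes the block; a bare { } only constructs a lambda
    Log.w(TAG, "No system prompt to process.")
    threadLocalState.set(State.AwaitingUserPrompt)
}
```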
@@ -94,15 +88,16 @@ class LLamaAndroid {
     private suspend fun initWithSystemPrompt(formattedMessage: String) {
         withContext(runLoop) {
             when (threadLocalState.get()) {
-                is State.ModelLoaded -> {
+                is State.EnvReady -> {
                     Log.i(TAG, "Process system prompt...")
-                    process_system_prompt(formattedMessage).let {
+                    threadLocalState.set(State.Processing)
+                    processSystemPrompt(formattedMessage).let {
                         if (it != 0)
                             throw IllegalStateException("Failed to process system prompt: $it")
                     }
 
                     Log.i(TAG, "System prompt processed!")
-                    threadLocalState.set(State.ReadyForUserPrompt)
+                    threadLocalState.set(State.AwaitingUserPrompt)
                 }
                 else -> throw IllegalStateException(
                     "Failed to process system prompt: Model not loaded!"
@@ -112,31 +107,36 @@ class LLamaAndroid {
     }
 
     /**
-     * Send plain text user prompt to LLM
+     * Send formatted user prompt to LLM
      */
     fun sendUserPrompt(
         formattedMessage: String,
         nPredict: Int = DEFAULT_PREDICT_LENGTH,
     ): Flow<String> = flow {
-        when (threadLocalState.get()) {
-            is State.ReadyForUserPrompt -> {
-                process_user_prompt(formattedMessage, nPredict).let {
-                    if (it != 0) {
-                        Log.e(TAG, "Failed to process user prompt: $it")
+        when (val state = threadLocalState.get()) {
+            is State.AwaitingUserPrompt -> {
+                Log.i(TAG, "Sending user prompt...")
+                threadLocalState.set(State.Processing)
+                processUserPrompt(formattedMessage, nPredict).let { result ->
+                    if (result != 0) {
+                        Log.e(TAG, "Failed to process user prompt: $result")
                         return@flow
                     }
                 }
 
                 Log.i(TAG, "User prompt processed! Generating assistant prompt...")
                 while (true) {
-                    val str = predict_loop() ?: break
-                    if (str.isNotEmpty()) {
-                        emit(str)
-                    }
+                    predictLoop()?.let { utf8token ->
+                        if (utf8token.isNotEmpty()) emit(utf8token)
+                    } ?: break
                 }
+
+                Log.i(TAG, "Assistant generation complete!")
+                threadLocalState.set(State.AwaitingUserPrompt)
             }
-            else -> {}
+            else -> {
+                Log.w(TAG, "User prompt discarded due to incorrect state: $state")
+            }
         }
     }.flowOn(runLoop)
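Usage sketch for the new flow-based surface (hypothetical caller, not part of this diff; `llama` and `appendToTranscript` are assumptions):

```kotlin
import android.util.Log
import kotlinx.coroutines.flow.catch

suspend fun chat(llama: LLamaAndroid, prompt: String, appendToTranscript: (String) -> Unit) {
    llama.sendUserPrompt(prompt) // nPredict defaults to DEFAULT_PREDICT_LENGTH (64 after this commit)
        .catch { e -> Log.e("Chat", "Generation failed", e) }
        .collect { token -> appendToTranscript(token) }
}
```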
@@ -146,9 +146,9 @@ class LLamaAndroid {
     suspend fun bench(pp: Int, tg: Int, pl: Int, nr: Int = 1): String {
         return withContext(runLoop) {
             when (val state = threadLocalState.get()) {
-                is State.ModelLoaded -> {
+                is State.EnvReady -> {
                     Log.d(TAG, "bench(): $state")
-                    bench_model(pp, tg, pl, nr)
+                    benchModel(pp, tg, pl, nr)
                 }
 
                 // TODO-hyin: catch exception in ViewController; disable button when state incorrect
@@ -164,12 +164,14 @@ class LLamaAndroid {
      */
     suspend fun unload() {
         withContext(runLoop) {
-            when (threadLocalState.get()) {
-                is State.ModelLoaded -> {
-                    clean_up()
-                    threadLocalState.set(State.Idle)
+            when (val state = threadLocalState.get()) {
+                is State.EnvReady, State.AwaitingUserPrompt -> {
+                    cleanUp()
+                    threadLocalState.set(State.NotInitialized)
                 }
-                else -> {}
+                else -> {
+                    Log.w(TAG, "Cannot unload model due to incorrect state: $state")
+                }
             }
         }
     }
@@ -177,7 +179,10 @@ class LLamaAndroid {
     companion object {
         private val TAG = this::class.simpleName
 
-        private const val DEFAULT_PREDICT_LENGTH = 128
+        private const val LIB_LLAMA_ANDROID = "llama-android"
+        private const val LLAMA_THREAD = "llama-thread"
+
+        private const val DEFAULT_PREDICT_LENGTH = 64
 
         // Enforce only one instance of Llm.
         private val _instance: LLamaAndroid = LLamaAndroid()
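End-to-end sketch of the surface after this commit (hypothetical caller; the `instance()` accessor and the model path are assumptions based on the `_instance` singleton above):

```kotlin
suspend fun demo() {
    val llama = LLamaAndroid.instance()
    // NotInitialized -> EnvReady -> AwaitingUserPrompt
    llama.load("/data/local/tmp/model.gguf", formattedSystemPrompt = "You are a helpful assistant.")
    // AwaitingUserPrompt -> Processing -> AwaitingUserPrompt, emitting tokens
    llama.sendUserPrompt("Hello!").collect { token -> print(token) }
    // -> NotInitialized
    llama.unload()
}
```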