Feature: chat template auto formatting

Han Yin 2025-04-07 20:37:33 -07:00
parent 1b0754c0f5
commit 8bf2f4d412
3 changed files with 78 additions and 34 deletions


@@ -46,7 +46,7 @@ class MainViewModel(private val llamaAndroid: LLamaAndroid = LLamaAndroid.instan
viewModelScope.launch {
// TODO-hyin: implement format message
llamaAndroid.sendUserPrompt(formattedMessage = text)
llamaAndroid.sendUserPrompt(message = actualText)
.catch {
Log.e(tag, "send() failed", it)
messages += it.message!!


@@ -5,8 +5,10 @@
#include <string>
#include <unistd.h>
#include <sampling.h>
#include "llama.h"
#include "chat.h"
#include "common.h"
#include "llama.h"
template <class T> static std::string join(const std::vector<T> & values, const std::string & delim) {
std::ostringstream str;
@@ -42,6 +44,7 @@ static llama_model * g_model;
static llama_context * g_context;
static llama_batch * g_batch;
static common_sampler * g_sampler;
static common_chat_templates_ptr g_chat_templates;
static void log_callback(ggml_log_level level, const char *fmt, void *data) {
int priority;
@@ -174,16 +177,18 @@ Java_android_llama_cpp_LLamaAndroid_initContext(JNIEnv * /*env*/, jobject /*unus
if (ret != 0) { return ret; }
new_batch(BATCH_SIZE);
new_sampler(SAMPLER_TEMP);
g_chat_templates = common_chat_templates_init(g_model, "");
return 0;
}
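For orientation, a minimal sketch (not part of this commit; the helper name is hypothetical) of how the common_chat_templates handle used above is created from a loaded model and probed for an explicit template:

// Sketch only: build the template handle from an already-loaded model and
// report whether the model actually ships a chat template.
static common_chat_templates_ptr init_templates_sketch(llama_model * model) {
    // Empty override string: use the template stored in the model's metadata, if any.
    common_chat_templates_ptr tmpls = common_chat_templates_init(model, "");
    if (!common_chat_templates_was_explicit(tmpls.get())) {
        // No template found; prompts are passed through untouched, matching the
        // fallback path in processSystemPrompt/processUserPrompt below.
        LOGi("No explicit chat template in model metadata; using raw prompts");
    }
    return tmpls;
}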
extern "C"
JNIEXPORT void JNICALL
Java_android_llama_cpp_LLamaAndroid_cleanUp(JNIEnv * /*unused*/, jobject /*unused*/) {
llama_model_free(g_model);
llama_free(g_context);
delete g_batch;
g_chat_templates.reset();
common_sampler_free(g_sampler);
delete g_batch;
llama_free(g_context);
llama_model_free(g_model);
llama_backend_free();
}
@@ -298,12 +303,25 @@ Java_android_llama_cpp_LLamaAndroid_benchModel(JNIEnv *env, jobject /*unused*/,
/**
* Prediction loop's long-term and short-term states
* Prediction loop's long-term states
*/
static llama_pos current_position;
constexpr const char* ROLE_SYSTEM = "system";
constexpr const char* ROLE_USER = "user";
constexpr const char* ROLE_ASSISTANT = "assistant";
static llama_pos token_predict_budget;
static std::string cached_token_chars;
static llama_pos current_position;
static std::vector<common_chat_msg> chat_msgs;
static std::string chat_add_and_format(const std::string & role, const std::string & content) {
common_chat_msg new_msg;
new_msg.role = role;
new_msg.content = content;
auto formatted = common_chat_format_single(
g_chat_templates.get(), chat_msgs, new_msg, role == ROLE_USER, /* use_jinja */ false);
chat_msgs.push_back(new_msg);
LOGi("Formatted and added %s message: \n%s\n", role.c_str(), formatted.c_str());
return formatted;
}
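To make the helper's behaviour concrete, an illustrative (non-commit) exchange: common_chat_format_single renders the history with and without the new message and returns only the appended portion, so each call yields exactly the text that still has to be tokenized and decoded. The function name and string literals below are hypothetical.

// Illustration only; the exact markers depend on the model's template.
static void example_exchange_sketch() {
    const std::string sys  = chat_add_and_format(ROLE_SYSTEM, "You are a helpful assistant.");
    const std::string user = chat_add_and_format(ROLE_USER,   "Hello!");
    // With a ChatML-style template, sys + user now ends in the assistant header,
    // e.g. "...<|im_start|>assistant\n", because add_ass is true for user turns.

    // After generation finishes, the reply is recorded the same way so the next
    // user turn is formatted against the full history:
    chat_add_and_format(ROLE_ASSISTANT, "Hi! How can I help?");
}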
static int decode_tokens_in_batches(
llama_context *context,
@@ -337,6 +355,13 @@ static int decode_tokens_in_batches(
return 0;
}
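The body of decode_tokens_in_batches is not visible in this hunk; as a rough sketch of what a chunked decode over the llama.cpp batch API typically looks like (an assumed shape, not the commit's implementation):

// Rough sketch; assumes the batch was allocated with capacity >= n_batch.
static int decode_in_chunks_sketch(llama_context * ctx, llama_batch & batch,
                                   const std::vector<llama_token> & tokens,
                                   llama_pos start_pos, size_t n_batch) {
    for (size_t i = 0; i < tokens.size(); i += n_batch) {
        common_batch_clear(batch);
        size_t end = i + n_batch;
        if (end > tokens.size()) end = tokens.size();
        for (size_t j = i; j < end; j++) {
            // Request logits only for the very last token; the sampler reads those.
            const bool want_logits = (j + 1 == tokens.size());
            common_batch_add(batch, tokens[j], start_pos + (llama_pos) j, { 0 }, want_logits);
        }
        if (llama_decode(ctx, batch) != 0) {
            return 1; // decode failed
        }
    }
    return 0;
}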
/**
* Prediction loop's short-term states
*/
static llama_pos token_predict_budget;
static std::string cached_token_chars;
static std::ostringstream assistant_ss; // For storing current assistant message
extern "C"
JNIEXPORT jint JNICALL
Java_android_llama_cpp_LLamaAndroid_processSystemPrompt(
@@ -351,14 +376,22 @@ Java_android_llama_cpp_LLamaAndroid_processSystemPrompt(
// Reset short-term states
token_predict_budget = 0;
cached_token_chars.clear();
assistant_ss.str("");
// Obtain and tokenize system prompt
const auto *const system_text = env->GetStringUTFChars(jsystem_prompt, nullptr);
LOGd("System prompt received: \n%s", system_text);
const auto system_tokens = common_tokenize(g_context, system_text, true, true);
env->ReleaseStringUTFChars(jsystem_prompt, system_text);
// Obtain system prompt from JEnv
const auto *system_prompt = env->GetStringUTFChars(jsystem_prompt, nullptr);
LOGd("System prompt received: \n%s", system_prompt);
std::string formatted_system_prompt(system_prompt);
env->ReleaseStringUTFChars(jsystem_prompt, system_prompt);
// Print each token in verbose mode
// Format system prompt if applicable
const bool has_chat_template = common_chat_templates_was_explicit(g_chat_templates.get());
if (has_chat_template) {
formatted_system_prompt = chat_add_and_format(ROLE_SYSTEM, system_prompt);
}
// Tokenize system prompt
const auto system_tokens = common_tokenize(g_context, formatted_system_prompt, has_chat_template, has_chat_template);
for (auto id : system_tokens) {
LOGv("token: `%s`\t -> `%d`", common_token_to_piece(g_context, id).c_str(), id);
}
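One thing worth spelling out about the common_tokenize flags above: once a prompt has been rendered by a chat template, it already contains control markers as text, and parse_special = true lets the tokenizer map them back to single special-token ids rather than splitting them into plain-text pieces, while add_special lets the vocab decide whether to prepend BOS. A small illustration (not from the commit; the function name is hypothetical):

// Illustration only. `formatted` is a template-rendered prompt such as
//   "<|im_start|>system\nYou are helpful.<|im_end|>\n"
// parse_special = true  -> "<|im_start|>" becomes one control-token id
// parse_special = false -> "<|im_start|>" is tokenized as ordinary text
static void tokenize_demo_sketch(const std::string & formatted, bool has_chat_template) {
    const auto tokens = common_tokenize(g_context, formatted,
                                        /* add_special   */ has_chat_template,
                                        /* parse_special */ has_chat_template);
    for (const auto id : tokens) {
        LOGv("token: `%s` -> %d", common_token_to_piece(g_context, id).c_str(), id);
    }
}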
@@ -386,14 +419,22 @@ Java_android_llama_cpp_LLamaAndroid_processUserPrompt(
// Reset short-term states
token_predict_budget = 0;
cached_token_chars.clear();
assistant_ss.str("");
// Obtain and tokenize user prompt
const auto *const user_text = env->GetStringUTFChars(juser_prompt, nullptr);
LOGd("User prompt received: \n%s", user_text);
const auto user_tokens = common_tokenize(g_context, user_text, true, true);
env->ReleaseStringUTFChars(juser_prompt, user_text);
const auto *const user_prompt = env->GetStringUTFChars(juser_prompt, nullptr);
LOGd("User prompt received: \n%s", user_prompt);
std::string formatted_user_prompt(user_prompt);
env->ReleaseStringUTFChars(juser_prompt, user_prompt);
// Print each token in verbose mode
// Format user prompt if applicable
const bool has_chat_template = common_chat_templates_was_explicit(g_chat_templates.get());
if (has_chat_template) {
formatted_user_prompt = chat_add_and_format(ROLE_USER, user_prompt);
}
// Decode formatted user prompts
const auto user_tokens = common_tokenize(g_context, formatted_user_prompt, has_chat_template, has_chat_template);
for (auto id : user_tokens) {
LOGv("token: `%s`\t -> `%d`", common_token_to_piece(g_context, id).c_str(), id);
}
@@ -481,6 +522,7 @@ Java_android_llama_cpp_LLamaAndroid_predictLoop(
// Stop if next token is EOG
if (llama_vocab_is_eog(llama_model_get_vocab(g_model), new_token_id)) {
LOGd("id: %d,\tIS EOG!\nSTOP.", new_token_id);
chat_add_and_format(ROLE_ASSISTANT, assistant_ss.str());
return nullptr;
}
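A condensed sketch of the turn boundary handled above (the sampling calls are assumed; the commit's full predictLoop, including the UTF-8 cache and prediction budget, is not reproduced here): once an end-of-generation token is sampled, the accumulated assistant text is pushed into chat_msgs so the next user prompt is formatted against the complete history.

// Sketch of one iteration's tail, not the commit's exact predictLoop body.
static bool step_once_sketch() {
    // Sample from the logits of the last decoded token (index -1).
    const llama_token new_token_id = common_sampler_sample(g_sampler, g_context, -1);
    common_sampler_accept(g_sampler, new_token_id, /* accept_grammar */ true);

    if (llama_vocab_is_eog(llama_model_get_vocab(g_model), new_token_id)) {
        // Close the turn: the finished reply joins the formatting history.
        chat_add_and_format(ROLE_ASSISTANT, assistant_ss.str());
        return false; // stop generating
    }
    assistant_ss << common_token_to_piece(g_context, new_token_id);
    return true;     // keep generating
}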
@@ -493,6 +535,8 @@ Java_android_llama_cpp_LLamaAndroid_predictLoop(
if (is_valid_utf8(cached_token_chars.c_str())) {
result = env->NewStringUTF(cached_token_chars.c_str());
LOGv("id: %d,\tcached: `%s`,\tnew: `%s`", new_token_id, cached_token_chars.c_str(), new_token_chars.c_str());
assistant_ss << cached_token_chars;
cached_token_chars.clear();
} else {
LOGv("id: %d,\tappend to cache", new_token_id);

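For reference, the reason for cached_token_chars: a single Unicode character can be split across several tokens, so the loop buffers detokenized bytes until they form valid UTF-8 before handing them to NewStringUTF and to assistant_ss. A condensed sketch of that flush pattern (the helper name is hypothetical), reusing the file's is_valid_utf8 check:

// Condensed sketch of the streaming flush; `piece` holds the bytes of the
// newly sampled token.
static std::string flush_if_complete_sketch(const std::string & piece) {
    cached_token_chars += piece;
    if (!is_valid_utf8(cached_token_chars.c_str())) {
        return "";                        // incomplete multi-byte character: keep buffering
    }
    std::string out = cached_token_chars; // complete: safe to emit to Java
    assistant_ss << out;
    cached_token_chars.clear();
    return out;
}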

@@ -57,7 +57,7 @@ class LLamaAndroid {
/**
* Load the LLM, then process the formatted system prompt if provided
*/
suspend fun load(pathToModel: String, formattedSystemPrompt: String? = null) =
suspend fun load(pathToModel: String, systemPrompt: String? = null) =
withContext(runLoop) {
when (threadLocalState.get()) {
is State.NotInitialized -> {
@@ -70,8 +70,8 @@
Log.i(TAG, "Loaded model $pathToModel")
threadLocalState.set(State.EnvReady)
formattedSystemPrompt?.let {
initWithSystemPrompt(formattedSystemPrompt)
systemPrompt?.let {
initWithSystemPrompt(systemPrompt)
} ?: run {
Log.w(TAG, "No system prompt to process.")
threadLocalState.set(State.AwaitingUserPrompt)
@@ -108,10 +108,10 @@
* Send formatted user prompt to LLM
*/
fun sendUserPrompt(
formattedMessage: String,
nPredict: Int = DEFAULT_PREDICT_LENGTH,
message: String,
predictLength: Int = DEFAULT_PREDICT_LENGTH,
): Flow<String> = flow {
require(formattedMessage.isNotEmpty()) {
require(message.isNotEmpty()) {
Log.w(TAG, "User prompt discarded due to being empty!")
}
@@ -119,7 +119,7 @@
is State.AwaitingUserPrompt -> {
Log.i(TAG, "Sending user prompt...")
threadLocalState.set(State.Processing)
processUserPrompt(formattedMessage, nPredict).let { result ->
processUserPrompt(message, predictLength).let { result ->
if (result != 0) {
Log.e(TAG, "Failed to process user prompt: $result")
return@flow