Feature: chat template auto formatting
This commit is contained in:
parent
1b0754c0f5
commit
8bf2f4d412
|
|
@ -46,7 +46,7 @@ class MainViewModel(private val llamaAndroid: LLamaAndroid = LLamaAndroid.instan
|
|||
|
||||
viewModelScope.launch {
|
||||
// TODO-hyin: implement format message
|
||||
llamaAndroid.sendUserPrompt(formattedMessage = text)
|
||||
llamaAndroid.sendUserPrompt(message = actualText)
|
||||
.catch {
|
||||
Log.e(tag, "send() failed", it)
|
||||
messages += it.message!!
|
||||
|
|
|
|||
|
|
@ -5,8 +5,10 @@
|
|||
#include <string>
|
||||
#include <unistd.h>
|
||||
#include <sampling.h>
|
||||
#include "llama.h"
|
||||
|
||||
#include "chat.h"
|
||||
#include "common.h"
|
||||
#include "llama.h"
|
||||
|
||||
template <class T> static std::string join(const std::vector<T> & values, const std::string & delim) {
|
||||
std::ostringstream str;
|
||||
|
|
@ -42,6 +44,7 @@ static llama_model * g_model;
|
|||
static llama_context * g_context;
|
||||
static llama_batch * g_batch;
|
||||
static common_sampler * g_sampler;
|
||||
static common_chat_templates_ptr g_chat_templates;
|
||||
|
||||
static void log_callback(ggml_log_level level, const char *fmt, void *data) {
|
||||
int priority;
|
||||
|
|
@ -174,16 +177,18 @@ Java_android_llama_cpp_LLamaAndroid_initContext(JNIEnv * /*env*/, jobject /*unus
|
|||
if (ret != 0) { return ret; }
|
||||
new_batch(BATCH_SIZE);
|
||||
new_sampler(SAMPLER_TEMP);
|
||||
g_chat_templates = common_chat_templates_init(g_model, "");
|
||||
return 0;
|
||||
}
|
||||
|
||||
extern "C"
|
||||
JNIEXPORT void JNICALL
|
||||
Java_android_llama_cpp_LLamaAndroid_cleanUp(JNIEnv * /*unused*/, jobject /*unused*/) {
|
||||
llama_model_free(g_model);
|
||||
llama_free(g_context);
|
||||
delete g_batch;
|
||||
g_chat_templates.reset();
|
||||
common_sampler_free(g_sampler);
|
||||
delete g_batch;
|
||||
llama_free(g_context);
|
||||
llama_model_free(g_model);
|
||||
llama_backend_free();
|
||||
}
|
||||
|
||||
|
|
@ -298,12 +303,25 @@ Java_android_llama_cpp_LLamaAndroid_benchModel(JNIEnv *env, jobject /*unused*/,
|
|||
|
||||
|
||||
/**
|
||||
* Prediction loop's long-term and short-term states
|
||||
* Prediction loop's long-term states
|
||||
*/
|
||||
static llama_pos current_position;
|
||||
constexpr const char* ROLE_SYSTEM = "system";
|
||||
constexpr const char* ROLE_USER = "user";
|
||||
constexpr const char* ROLE_ASSISTANT = "assistant";
|
||||
|
||||
static llama_pos token_predict_budget;
|
||||
static std::string cached_token_chars;
|
||||
static llama_pos current_position;
|
||||
static std::vector<common_chat_msg> chat_msgs;
|
||||
|
||||
static std::string chat_add_and_format(const std::string & role, const std::string & content) {
|
||||
common_chat_msg new_msg;
|
||||
new_msg.role = role;
|
||||
new_msg.content = content;
|
||||
auto formatted = common_chat_format_single(
|
||||
g_chat_templates.get(), chat_msgs, new_msg, role == ROLE_USER, /* use_jinja */ false);
|
||||
chat_msgs.push_back(new_msg);
|
||||
LOGi("Formatted and added %s message: \n%s\n", role.c_str(), formatted.c_str());
|
||||
return formatted;
|
||||
}
|
||||
|
||||
static int decode_tokens_in_batches(
|
||||
llama_context *context,
|
||||
|
|
@ -337,6 +355,13 @@ static int decode_tokens_in_batches(
|
|||
return 0;
|
||||
}
|
||||
|
||||
/**
|
||||
* Prediction loop's short-term states
|
||||
*/
|
||||
static llama_pos token_predict_budget;
|
||||
static std::string cached_token_chars;
|
||||
static std::ostringstream assistant_ss; // For storing current assistant message
|
||||
|
||||
extern "C"
|
||||
JNIEXPORT jint JNICALL
|
||||
Java_android_llama_cpp_LLamaAndroid_processSystemPrompt(
|
||||
|
|
@ -351,14 +376,22 @@ Java_android_llama_cpp_LLamaAndroid_processSystemPrompt(
|
|||
// Reset short-term states
|
||||
token_predict_budget = 0;
|
||||
cached_token_chars.clear();
|
||||
assistant_ss.str("");
|
||||
|
||||
// Obtain and tokenize system prompt
|
||||
const auto *const system_text = env->GetStringUTFChars(jsystem_prompt, nullptr);
|
||||
LOGd("System prompt received: \n%s", system_text);
|
||||
const auto system_tokens = common_tokenize(g_context, system_text, true, true);
|
||||
env->ReleaseStringUTFChars(jsystem_prompt, system_text);
|
||||
// Obtain system prompt from JEnv
|
||||
const auto *system_prompt = env->GetStringUTFChars(jsystem_prompt, nullptr);
|
||||
LOGd("System prompt received: \n%s", system_prompt);
|
||||
std::string formatted_system_prompt(system_prompt);
|
||||
env->ReleaseStringUTFChars(jsystem_prompt, system_prompt);
|
||||
|
||||
// Print each token in verbose mode
|
||||
// Format system prompt if applicable
|
||||
const bool has_chat_template = common_chat_templates_was_explicit(g_chat_templates.get());
|
||||
if (has_chat_template) {
|
||||
formatted_system_prompt = chat_add_and_format(ROLE_SYSTEM, system_prompt);
|
||||
}
|
||||
|
||||
// Tokenize system prompt
|
||||
const auto system_tokens = common_tokenize(g_context, formatted_system_prompt, has_chat_template, has_chat_template);
|
||||
for (auto id : system_tokens) {
|
||||
LOGv("token: `%s`\t -> `%d`", common_token_to_piece(g_context, id).c_str(), id);
|
||||
}
|
||||
|
|
@ -386,14 +419,22 @@ Java_android_llama_cpp_LLamaAndroid_processUserPrompt(
|
|||
// Reset short-term states
|
||||
token_predict_budget = 0;
|
||||
cached_token_chars.clear();
|
||||
assistant_ss.str("");
|
||||
|
||||
// Obtain and tokenize user prompt
|
||||
const auto *const user_text = env->GetStringUTFChars(juser_prompt, nullptr);
|
||||
LOGd("User prompt received: \n%s", user_text);
|
||||
const auto user_tokens = common_tokenize(g_context, user_text, true, true);
|
||||
env->ReleaseStringUTFChars(juser_prompt, user_text);
|
||||
const auto *const user_prompt = env->GetStringUTFChars(juser_prompt, nullptr);
|
||||
LOGd("User prompt received: \n%s", user_prompt);
|
||||
std::string formatted_user_prompt(user_prompt);
|
||||
env->ReleaseStringUTFChars(juser_prompt, user_prompt);
|
||||
|
||||
// Print each token in verbose mode
|
||||
// Format user prompt if applicable
|
||||
const bool has_chat_template = common_chat_templates_was_explicit(g_chat_templates.get());
|
||||
if (has_chat_template) {
|
||||
formatted_user_prompt = chat_add_and_format(ROLE_USER, user_prompt);
|
||||
}
|
||||
|
||||
// Decode formatted user prompts
|
||||
const auto user_tokens = common_tokenize(g_context, formatted_user_prompt, has_chat_template, has_chat_template);
|
||||
for (auto id : user_tokens) {
|
||||
LOGv("token: `%s`\t -> `%d`", common_token_to_piece(g_context, id).c_str(), id);
|
||||
}
|
||||
|
|
@ -481,6 +522,7 @@ Java_android_llama_cpp_LLamaAndroid_predictLoop(
|
|||
// Stop if next token is EOG
|
||||
if (llama_vocab_is_eog(llama_model_get_vocab(g_model), new_token_id)) {
|
||||
LOGd("id: %d,\tIS EOG!\nSTOP.", new_token_id);
|
||||
chat_add_and_format(ROLE_ASSISTANT, assistant_ss.str());
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
|
|
@ -493,6 +535,8 @@ Java_android_llama_cpp_LLamaAndroid_predictLoop(
|
|||
if (is_valid_utf8(cached_token_chars.c_str())) {
|
||||
result = env->NewStringUTF(cached_token_chars.c_str());
|
||||
LOGv("id: %d,\tcached: `%s`,\tnew: `%s`", new_token_id, cached_token_chars.c_str(), new_token_chars.c_str());
|
||||
|
||||
assistant_ss << cached_token_chars;
|
||||
cached_token_chars.clear();
|
||||
} else {
|
||||
LOGv("id: %d,\tappend to cache", new_token_id);
|
||||
|
|
|
|||
|
|
@ -57,7 +57,7 @@ class LLamaAndroid {
|
|||
/**
|
||||
* Load the LLM, then process the formatted system prompt if provided
|
||||
*/
|
||||
suspend fun load(pathToModel: String, formattedSystemPrompt: String? = null) =
|
||||
suspend fun load(pathToModel: String, systemPrompt: String? = null) =
|
||||
withContext(runLoop) {
|
||||
when (threadLocalState.get()) {
|
||||
is State.NotInitialized -> {
|
||||
|
|
@ -70,8 +70,8 @@ class LLamaAndroid {
|
|||
Log.i(TAG, "Loaded model $pathToModel")
|
||||
threadLocalState.set(State.EnvReady)
|
||||
|
||||
formattedSystemPrompt?.let {
|
||||
initWithSystemPrompt(formattedSystemPrompt)
|
||||
systemPrompt?.let {
|
||||
initWithSystemPrompt(systemPrompt)
|
||||
} ?: run {
|
||||
Log.w(TAG, "No system prompt to process.")
|
||||
threadLocalState.set(State.AwaitingUserPrompt)
|
||||
|
|
@ -108,10 +108,10 @@ class LLamaAndroid {
|
|||
* Send formatted user prompt to LLM
|
||||
*/
|
||||
fun sendUserPrompt(
|
||||
formattedMessage: String,
|
||||
nPredict: Int = DEFAULT_PREDICT_LENGTH,
|
||||
message: String,
|
||||
predictLength: Int = DEFAULT_PREDICT_LENGTH,
|
||||
): Flow<String> = flow {
|
||||
require(formattedMessage.isNotEmpty()) {
|
||||
require(message.isNotEmpty()) {
|
||||
Log.w(TAG, "User prompt discarded due to being empty!")
|
||||
}
|
||||
|
||||
|
|
@ -119,7 +119,7 @@ class LLamaAndroid {
|
|||
is State.AwaitingUserPrompt -> {
|
||||
Log.i(TAG, "Sending user prompt...")
|
||||
threadLocalState.set(State.Processing)
|
||||
processUserPrompt(formattedMessage, nPredict).let { result ->
|
||||
processUserPrompt(message, predictLength).let { result ->
|
||||
if (result != 0) {
|
||||
Log.e(TAG, "Failed to process user prompt: $result")
|
||||
return@flow
|
||||
|
|
|
|||
Loading…
Reference in New Issue