Feature: chat template auto formatting

parent 1b0754c0f5
commit 8bf2f4d412
@@ -46,7 +46,7 @@ class MainViewModel(private val llamaAndroid: LLamaAndroid = LLamaAndroid.instan
         viewModelScope.launch {
             // TODO-hyin: implement format message
-            llamaAndroid.sendUserPrompt(formattedMessage = text)
+            llamaAndroid.sendUserPrompt(message = actualText)
                 .catch {
                     Log.e(tag, "send() failed", it)
                     messages += it.message!!
@@ -5,8 +5,10 @@
 #include <string>
 #include <unistd.h>
 #include <sampling.h>
-#include "llama.h"
+#include "chat.h"
 #include "common.h"
+#include "llama.h"
 
 template <class T> static std::string join(const std::vector<T> & values, const std::string & delim) {
     std::ostringstream str;
@@ -42,6 +44,7 @@ static llama_model * g_model;
 static llama_context * g_context;
 static llama_batch * g_batch;
 static common_sampler * g_sampler;
+static common_chat_templates_ptr g_chat_templates;
 
 static void log_callback(ggml_log_level level, const char *fmt, void *data) {
     int priority;
@@ -174,16 +177,18 @@ Java_android_llama_cpp_LLamaAndroid_initContext(JNIEnv * /*env*/, jobject /*unus
     if (ret != 0) { return ret; }
     new_batch(BATCH_SIZE);
     new_sampler(SAMPLER_TEMP);
+    g_chat_templates = common_chat_templates_init(g_model, "");
     return 0;
 }
 
 extern "C"
 JNIEXPORT void JNICALL
 Java_android_llama_cpp_LLamaAndroid_cleanUp(JNIEnv * /*unused*/, jobject /*unused*/) {
-    llama_model_free(g_model);
-    llama_free(g_context);
-    delete g_batch;
+    g_chat_templates.reset();
     common_sampler_free(g_sampler);
+    delete g_batch;
+    llama_free(g_context);
+    llama_model_free(g_model);
     llama_backend_free();
 }
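Note on the template lifecycle introduced above: calling common_chat_templates_init(g_model, "") with an empty override string picks up whatever chat template is embedded in the model's GGUF metadata, and common_chat_templates_ptr is a smart pointer, so g_chat_templates.reset() releases it before the model it was built from is freed (cleanUp now tears things down in reverse order of initialization). Below is a minimal sketch of the same lifecycle outside JNI, assuming llama.cpp's common/chat API at this revision; the model path is a placeholder.

// Sketch only: exercise the chat-template lifecycle without the JNI layer.
#include <cstdio>

#include "chat.h"
#include "llama.h"

int main() {
    llama_backend_init();

    llama_model_params mparams = llama_model_default_params();
    // Placeholder path, not a real model shipped with the example.
    llama_model * model = llama_model_load_from_file("/path/to/model.gguf", mparams);
    if (model == nullptr) {
        llama_backend_free();
        return 1;
    }

    // Empty override -> use the chat template stored in the GGUF metadata, if any.
    common_chat_templates_ptr templates = common_chat_templates_init(model, "");

    // Same check that gates formatting in processSystemPrompt/processUserPrompt.
    const bool has_template = common_chat_templates_was_explicit(templates.get());
    std::printf("model has an explicit chat template: %s\n", has_template ? "yes" : "no");

    // Teardown mirrors cleanUp(): release the templates before the model they wrap.
    templates.reset();
    llama_model_free(model);
    llama_backend_free();
    return 0;
}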
@@ -298,12 +303,25 @@ Java_android_llama_cpp_LLamaAndroid_benchModel(JNIEnv *env, jobject /*unused*/,
 
 
 /**
- * Prediction loop's long-term and short-term states
+ * Prediction loop's long-term states
  */
-static llama_pos current_position;
-static llama_pos token_predict_budget;
-static std::string cached_token_chars;
+constexpr const char* ROLE_SYSTEM = "system";
+constexpr const char* ROLE_USER = "user";
+constexpr const char* ROLE_ASSISTANT = "assistant";
+
+static llama_pos current_position;
+static std::vector<common_chat_msg> chat_msgs;
 
+static std::string chat_add_and_format(const std::string & role, const std::string & content) {
+    common_chat_msg new_msg;
+    new_msg.role = role;
+    new_msg.content = content;
+    auto formatted = common_chat_format_single(
+            g_chat_templates.get(), chat_msgs, new_msg, role == ROLE_USER, /* use_jinja */ false);
+    chat_msgs.push_back(new_msg);
+    LOGi("Formatted and added %s message: \n%s\n", role.c_str(), formatted.c_str());
+    return formatted;
+}
 
 static int decode_tokens_in_batches(
         llama_context *context,
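The helper above formats each new message incrementally: common_chat_format_single renders the new message against the accumulated history and returns only the text that still needs to be decoded, appending the assistant prefix when the new message is a user turn (add_ass). A standalone sketch of that pattern follows, assuming a templates object created as in initContext; format_delta is an illustrative name, not part of the llama.cpp API.

// Sketch of the incremental formatting idea behind chat_add_and_format.
#include <string>
#include <vector>

#include "chat.h"

static std::string format_delta(const common_chat_templates * templates,
                                std::vector<common_chat_msg> & history,
                                const std::string & role,
                                const std::string & content) {
    common_chat_msg new_msg;
    new_msg.role    = role;
    new_msg.content = content;

    // add_ass == true only for user turns: append the assistant prefix so the
    // model starts generating right after the formatted user message.
    const std::string delta = common_chat_format_single(
            templates, history, new_msg, /* add_ass */ role == "user", /* use_jinja */ false);

    history.push_back(new_msg);
    return delta;
}

// Typical call order, mirroring the JNI flow:
//   format_delta(tmpls, history, "system",    "You are a helpful assistant.");
//   format_delta(tmpls, history, "user",      "Hello!");
//   ... decode, stream tokens, collect the reply ...
//   format_delta(tmpls, history, "assistant", collected_reply); // keeps history in sync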
@@ -337,6 +355,13 @@ static int decode_tokens_in_batches(
     return 0;
 }
 
+/**
+ * Prediction loop's short-term states
+ */
+static llama_pos token_predict_budget;
+static std::string cached_token_chars;
+static std::ostringstream assistant_ss; // For storing current assistant message
+
 extern "C"
 JNIEXPORT jint JNICALL
 Java_android_llama_cpp_LLamaAndroid_processSystemPrompt(
@@ -351,14 +376,22 @@ Java_android_llama_cpp_LLamaAndroid_processSystemPrompt(
     // Reset short-term states
     token_predict_budget = 0;
     cached_token_chars.clear();
+    assistant_ss.str("");
 
-    // Obtain and tokenize system prompt
-    const auto *const system_text = env->GetStringUTFChars(jsystem_prompt, nullptr);
-    LOGd("System prompt received: \n%s", system_text);
-    const auto system_tokens = common_tokenize(g_context, system_text, true, true);
-    env->ReleaseStringUTFChars(jsystem_prompt, system_text);
+    // Obtain system prompt from JEnv
+    const auto *system_prompt = env->GetStringUTFChars(jsystem_prompt, nullptr);
+    LOGd("System prompt received: \n%s", system_prompt);
+    std::string formatted_system_prompt(system_prompt);
+    env->ReleaseStringUTFChars(jsystem_prompt, system_prompt);
 
-    // Print each token in verbose mode
+    // Format system prompt if applicable
+    const bool has_chat_template = common_chat_templates_was_explicit(g_chat_templates.get());
+    if (has_chat_template) {
+        formatted_system_prompt = chat_add_and_format(ROLE_SYSTEM, system_prompt);
+    }
+
+    // Tokenize system prompt
+    const auto system_tokens = common_tokenize(g_context, formatted_system_prompt, has_chat_template, has_chat_template);
     for (auto id : system_tokens) {
         LOGv("token: `%s`\t -> `%d`", common_token_to_piece(g_context, id).c_str(), id);
     }
@@ -386,14 +419,22 @@ Java_android_llama_cpp_LLamaAndroid_processUserPrompt(
     // Reset short-term states
     token_predict_budget = 0;
    cached_token_chars.clear();
+    assistant_ss.str("");
 
     // Obtain and tokenize user prompt
-    const auto *const user_text = env->GetStringUTFChars(juser_prompt, nullptr);
-    LOGd("User prompt received: \n%s", user_text);
-    const auto user_tokens = common_tokenize(g_context, user_text, true, true);
-    env->ReleaseStringUTFChars(juser_prompt, user_text);
+    const auto *const user_prompt = env->GetStringUTFChars(juser_prompt, nullptr);
+    LOGd("User prompt received: \n%s", user_prompt);
+    std::string formatted_user_prompt(user_prompt);
+    env->ReleaseStringUTFChars(juser_prompt, user_prompt);
 
-    // Print each token in verbose mode
+    // Format user prompt if applicable
+    const bool has_chat_template = common_chat_templates_was_explicit(g_chat_templates.get());
+    if (has_chat_template) {
+        formatted_user_prompt = chat_add_and_format(ROLE_USER, user_prompt);
+    }
+
+    // Decode formatted user prompts
+    const auto user_tokens = common_tokenize(g_context, formatted_user_prompt, has_chat_template, has_chat_template);
     for (auto id : user_tokens) {
         LOGv("token: `%s`\t -> `%d`", common_token_to_piece(g_context, id).c_str(), id);
     }
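Both prompt paths now pass has_chat_template as the add_special and parse_special arguments of common_tokenize: once a template has been applied, the formatted text spells out the template's control tokens as plain text (for a ChatML-style template, markers such as <|im_start|>), so they need to be parsed back into single special-token ids rather than tokenized character by character. A hedged sketch of that call, assuming the common library's common_tokenize(ctx, text, add_special, parse_special) overload; the ChatML tags are illustrative only.

// Sketch of the tokenization call shared by the system- and user-prompt paths.
#include <string>
#include <vector>

#include "common.h"
#include "llama.h"

static std::vector<llama_token> tokenize_prompt(llama_context * ctx,
                                                const std::string & text,
                                                bool has_chat_template) {
    // add_special: allow the tokenizer to add BOS/EOS per the model's vocab settings.
    // parse_special: map special-token text such as "<|im_start|>" to its dedicated
    // token id instead of splitting it into ordinary text tokens.
    return common_tokenize(ctx, text, /* add_special   */ has_chat_template,
                                      /* parse_special */ has_chat_template);
}

// e.g. a ChatML-style formatted user turn (roughly what chat_add_and_format may return):
//   "<|im_start|>user\nHello!<|im_end|>\n<|im_start|>assistant\n"
// With parse_special == true the <|im_start|>/<|im_end|> markers become single ids;
// with false they would be tokenized as literal text.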
@@ -481,6 +522,7 @@ Java_android_llama_cpp_LLamaAndroid_predictLoop(
     // Stop if next token is EOG
     if (llama_vocab_is_eog(llama_model_get_vocab(g_model), new_token_id)) {
         LOGd("id: %d,\tIS EOG!\nSTOP.", new_token_id);
+        chat_add_and_format(ROLE_ASSISTANT, assistant_ss.str());
         return nullptr;
     }
@@ -493,6 +535,8 @@ Java_android_llama_cpp_LLamaAndroid_predictLoop(
     if (is_valid_utf8(cached_token_chars.c_str())) {
         result = env->NewStringUTF(cached_token_chars.c_str());
         LOGv("id: %d,\tcached: `%s`,\tnew: `%s`", new_token_id, cached_token_chars.c_str(), new_token_chars.c_str());
+
+        assistant_ss << cached_token_chars;
         cached_token_chars.clear();
     } else {
         LOGv("id: %d,\tappend to cache", new_token_id);
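Taken together, these two hunks make the prediction loop accumulate every UTF-8-complete chunk into assistant_ss while streaming it to the caller, and on EOG commit the whole reply to the chat history via chat_add_and_format(ROLE_ASSISTANT, ...), so the next user turn is formatted against the full conversation. A condensed sketch of that accumulate-then-commit pattern; the callback types are placeholders, not the real JNI interfaces.

// Sketch only: wire emit/commit_turn to real sinks before use.
#include <functional>
#include <sstream>
#include <string>

struct reply_stream {
    std::ostringstream assistant_ss;                       // reply collected so far
    std::function<void(const std::string &)> emit;         // e.g. forward chunk to the UI
    std::function<void(const std::string &)> commit_turn;  // e.g. chat_add_and_format("assistant", ...)

    void on_chunk(const std::string & utf8_chunk) {
        emit(utf8_chunk);
        assistant_ss << utf8_chunk;   // mirror of `assistant_ss << cached_token_chars;`
    }

    void on_eog() {
        commit_turn(assistant_ss.str());   // mirror of chat_add_and_format(ROLE_ASSISTANT, ...)
        assistant_ss.str("");              // reset for the next turn
    }
};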
@@ -57,7 +57,7 @@ class LLamaAndroid {
     /**
      * Load the LLM, then process the formatted system prompt if provided
      */
-    suspend fun load(pathToModel: String, formattedSystemPrompt: String? = null) =
+    suspend fun load(pathToModel: String, systemPrompt: String? = null) =
         withContext(runLoop) {
             when (threadLocalState.get()) {
                 is State.NotInitialized -> {
@@ -70,8 +70,8 @@ class LLamaAndroid {
                     Log.i(TAG, "Loaded model $pathToModel")
                     threadLocalState.set(State.EnvReady)
 
-                    formattedSystemPrompt?.let {
-                        initWithSystemPrompt(formattedSystemPrompt)
+                    systemPrompt?.let {
+                        initWithSystemPrompt(systemPrompt)
                     } ?: run {
                         Log.w(TAG, "No system prompt to process.")
                         threadLocalState.set(State.AwaitingUserPrompt)
@@ -108,10 +108,10 @@ class LLamaAndroid {
      * Send formatted user prompt to LLM
      */
     fun sendUserPrompt(
-        formattedMessage: String,
-        nPredict: Int = DEFAULT_PREDICT_LENGTH,
+        message: String,
+        predictLength: Int = DEFAULT_PREDICT_LENGTH,
     ): Flow<String> = flow {
-        require(formattedMessage.isNotEmpty()) {
+        require(message.isNotEmpty()) {
             Log.w(TAG, "User prompt discarded due to being empty!")
         }
@@ -119,7 +119,7 @@ class LLamaAndroid {
             is State.AwaitingUserPrompt -> {
                 Log.i(TAG, "Sending user prompt...")
                 threadLocalState.set(State.Processing)
-                processUserPrompt(formattedMessage, nPredict).let { result ->
+                processUserPrompt(message, predictLength).let { result ->
                     if (result != 0) {
                         Log.e(TAG, "Failed to process user prompt: $result")
                         return@flow