common: map developer role to system (#20215)
* Map developer role to system * Simplify
This commit is contained in:
parent
43e1cbd6c1
commit
f76565db92
|
|
@ -1352,6 +1352,17 @@ static common_chat_params common_chat_params_init_lfm2(const common_chat_templat
|
|||
|
||||
namespace workaround {
|
||||
|
||||
static void map_developer_role_to_system(json & messages) {
|
||||
for (auto & message : messages) {
|
||||
if (message.contains("role")) {
|
||||
if (message["role"] == "developer") {
|
||||
message["role"] = "system";
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// if first message is system and template does not support it, merge it with next message
|
||||
static void system_message_not_supported(json & messages) {
|
||||
if (!messages.empty() && messages.front().at("role") == "system") {
|
||||
|
|
@ -1429,6 +1440,10 @@ static common_chat_params common_chat_templates_apply_jinja(const struct common_
|
|||
params.add_bos = tmpls->add_bos;
|
||||
params.add_eos = tmpls->add_eos;
|
||||
|
||||
if (src.find("<|channel|>") == std::string::npos) {
|
||||
// map developer to system for all models except for GPT-OSS
|
||||
workaround::map_developer_role_to_system(params.messages);
|
||||
}
|
||||
workaround::func_args_not_string(params.messages);
|
||||
|
||||
if (!tmpl.original_caps().supports_system_role) {
|
||||
|
|
|
|||
|
|
@ -800,258 +800,6 @@ const common_chat_msg message_assist_call_python_lines_unclosed =
|
|||
const common_chat_msg message_assist_json_content =
|
||||
simple_assist_msg("{\n \"response\": \"Hello, world!\\nWhat's up?\"\n}");
|
||||
|
||||
struct delta_data {
|
||||
std::string delta;
|
||||
common_chat_params params;
|
||||
};
|
||||
|
||||
static delta_data init_delta(const struct common_chat_templates * tmpls,
|
||||
const std::vector<std::string> & end_tokens,
|
||||
const common_chat_msg & user_message,
|
||||
const common_chat_msg & delta_message,
|
||||
const std::vector<common_chat_tool> & tools,
|
||||
const common_chat_tool_choice & tool_choice) {
|
||||
common_chat_templates_inputs inputs;
|
||||
inputs.parallel_tool_calls = true;
|
||||
inputs.messages.push_back(user_message);
|
||||
inputs.tools = tools;
|
||||
inputs.tool_choice = tool_choice;
|
||||
auto params_prefix = common_chat_templates_apply(tmpls, inputs);
|
||||
|
||||
inputs.messages.push_back(delta_message);
|
||||
inputs.add_generation_prompt = false;
|
||||
auto params_full = common_chat_templates_apply(tmpls, inputs);
|
||||
|
||||
std::string prefix = params_prefix.prompt;
|
||||
std::string full = params_full.prompt;
|
||||
|
||||
if (full == prefix) {
|
||||
throw std::runtime_error("Full message is the same as the prefix");
|
||||
}
|
||||
|
||||
size_t common_prefix_length = 0;
|
||||
for (size_t i = 0; i < prefix.size() && i < full.size(); ++i) {
|
||||
if (prefix[i] != full[i]) {
|
||||
break;
|
||||
}
|
||||
if (prefix[i] == '<') {
|
||||
// DeepSeek R1's template (as of 20250209) adds a trailing <think> if add_generation_prompt,
|
||||
// but it removes thinking tags for past messages.
|
||||
// The prefix and full strings diverge at <think> vs. <|tool▁calls▁begin|>, we avoid consuming the leading <.
|
||||
continue;
|
||||
}
|
||||
common_prefix_length = i + 1;
|
||||
}
|
||||
auto delta = full.substr(common_prefix_length);
|
||||
|
||||
// Strip end tokens
|
||||
for (const auto & end_token : end_tokens) {
|
||||
// rfind to find the last occurrence
|
||||
auto pos = delta.rfind(end_token);
|
||||
if (pos != std::string::npos) {
|
||||
delta = delta.substr(0, pos);
|
||||
break;
|
||||
}
|
||||
}
|
||||
return { delta, params_full };
|
||||
}
|
||||
|
||||
/*
|
||||
Applies the template to 1 user message w/ add_generation_prompt=true, then w/ the test message w/ add_generation_prompt=false,
|
||||
gets the diff, removes any end tokens and parses the result w/ the grammar, checking that
|
||||
the parsed message is the same as the test_message
|
||||
*/
|
||||
static void test_templates(const struct common_chat_templates * tmpls,
|
||||
const std::vector<std::string> & end_tokens,
|
||||
const common_chat_msg & test_message,
|
||||
const std::vector<common_chat_tool> & tools = {},
|
||||
const std::string & expected_delta = "",
|
||||
bool expect_grammar_triggered = true,
|
||||
bool test_grammar_if_triggered = true,
|
||||
common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_NONE,
|
||||
bool ignore_whitespace_differences = false) {
|
||||
common_chat_msg user_message;
|
||||
user_message.role = "user";
|
||||
user_message.content = "Hello, world!";
|
||||
|
||||
common_chat_templates_inputs inputs_tools;
|
||||
inputs_tools.messages = { message_user };
|
||||
inputs_tools.tools = { special_function_tool };
|
||||
|
||||
common_chat_params params = common_chat_templates_apply(tmpls, inputs_tools);
|
||||
|
||||
for (const auto & tool_choice :
|
||||
std::vector<common_chat_tool_choice>{ COMMON_CHAT_TOOL_CHOICE_AUTO, COMMON_CHAT_TOOL_CHOICE_REQUIRED }) {
|
||||
auto data = init_delta(tmpls, end_tokens, user_message, test_message, tools, tool_choice);
|
||||
if (!expected_delta.empty()) {
|
||||
if (ignore_whitespace_differences) {
|
||||
assert_equals(string_strip(expected_delta), string_strip(data.delta));
|
||||
} else {
|
||||
assert_equals(expected_delta, data.delta);
|
||||
}
|
||||
}
|
||||
|
||||
if (expect_grammar_triggered) {
|
||||
// TODO @ngxson : refactor common_chat_parse to avoid passing format/reasoning_format every time
|
||||
common_chat_parser_params parser_params;
|
||||
parser_params.format = data.params.format;
|
||||
parser_params.reasoning_format = reasoning_format;
|
||||
if (!parser_params.parser.empty()) {
|
||||
parser_params.parser = common_peg_arena();
|
||||
parser_params.parser.load(params.parser);
|
||||
}
|
||||
const auto msg = common_chat_parse(data.delta, /* is_partial= */ false, parser_params);
|
||||
assert_msg_equals(test_message, msg, ignore_whitespace_differences);
|
||||
}
|
||||
|
||||
if (!test_message.tool_calls.empty()) {
|
||||
GGML_ASSERT(!data.params.grammar.empty());
|
||||
}
|
||||
if (!data.params.grammar.empty()) {
|
||||
auto grammar = build_grammar(data.params.grammar);
|
||||
if (!grammar) {
|
||||
throw std::runtime_error("Failed to build grammar");
|
||||
}
|
||||
auto earliest_trigger_pos = std::string::npos;
|
||||
auto constrained = data.delta;
|
||||
for (const auto & trigger : data.params.grammar_triggers) {
|
||||
size_t pos = std::string::npos;
|
||||
std::smatch match;
|
||||
switch (trigger.type) {
|
||||
case COMMON_GRAMMAR_TRIGGER_TYPE_WORD:
|
||||
{
|
||||
const auto & word = trigger.value;
|
||||
pos = constrained.find(word);
|
||||
break;
|
||||
}
|
||||
case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN:
|
||||
{
|
||||
const auto & pattern = std::regex(trigger.value);
|
||||
if (std::regex_search(constrained, match, pattern)) {
|
||||
pos = match.position(pattern.mark_count());
|
||||
}
|
||||
break;
|
||||
}
|
||||
case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL:
|
||||
{
|
||||
const auto & pattern = trigger.value;
|
||||
if (std::regex_match(constrained, match, std::regex(pattern))) {
|
||||
auto mpos = std::string::npos;
|
||||
for (size_t i = 1; i < match.size(); ++i) {
|
||||
if (match[i].length() > 0) {
|
||||
mpos = match.position(i);
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (mpos == std::string::npos) {
|
||||
mpos = match.position(0);
|
||||
}
|
||||
pos = mpos;
|
||||
}
|
||||
break;
|
||||
}
|
||||
default:
|
||||
throw std::runtime_error("Unknown trigger type");
|
||||
}
|
||||
if (pos == std::string::npos) {
|
||||
continue;
|
||||
}
|
||||
if (earliest_trigger_pos == std::string::npos || pos < earliest_trigger_pos) {
|
||||
earliest_trigger_pos = pos;
|
||||
}
|
||||
}
|
||||
auto grammar_triggered = false;
|
||||
if (earliest_trigger_pos != std::string::npos) {
|
||||
constrained = constrained.substr(earliest_trigger_pos);
|
||||
grammar_triggered = true;
|
||||
}
|
||||
if (data.params.grammar_lazy) {
|
||||
assert_equals(expect_grammar_triggered, grammar_triggered);
|
||||
}
|
||||
|
||||
if (grammar_triggered && test_grammar_if_triggered && !match_string(constrained, grammar.get())) {
|
||||
throw std::runtime_error("Failed to match delta against grammar:\n\n" + data.delta +
|
||||
"\n\nConstrained: " + constrained + "\n\nGrammar: " + data.params.grammar);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Test if streaming=true is consistent with streaming=false for given partial parser
|
||||
* Also test if there is any problem with partial message
|
||||
*/
|
||||
template <typename T>
|
||||
static void test_parser_with_streaming(const common_chat_msg & expected, const std::string & raw_message, T parse_msg) {
|
||||
constexpr auto utf8_truncate_safe_len = [](const std::string_view s) -> size_t {
|
||||
auto len = s.size();
|
||||
if (len == 0) {
|
||||
return 0;
|
||||
}
|
||||
auto i = len;
|
||||
for (size_t back = 0; back < 4 && i > 0; ++back) {
|
||||
--i;
|
||||
unsigned char c = s[i];
|
||||
if ((c & 0x80) == 0) {
|
||||
return len;
|
||||
}
|
||||
if ((c & 0xC0) == 0xC0) {
|
||||
size_t expected_len = 0;
|
||||
if ((c & 0xE0) == 0xC0) {
|
||||
expected_len = 2;
|
||||
} else if ((c & 0xF0) == 0xE0) {
|
||||
expected_len = 3;
|
||||
} else if ((c & 0xF8) == 0xF0) {
|
||||
expected_len = 4;
|
||||
} else {
|
||||
return i;
|
||||
}
|
||||
if (len - i >= expected_len) {
|
||||
return len;
|
||||
}
|
||||
return i;
|
||||
}
|
||||
}
|
||||
return len - std::min(len, size_t(3));
|
||||
};
|
||||
constexpr auto utf8_truncate_safe_view = [utf8_truncate_safe_len](const std::string_view s) {
|
||||
return s.substr(0, utf8_truncate_safe_len(s));
|
||||
};
|
||||
|
||||
auto merged = simple_assist_msg("");
|
||||
auto last_msg = parse_msg("");
|
||||
for (size_t i = 1; i <= raw_message.size(); ++i) {
|
||||
auto curr_msg = parse_msg(std::string(utf8_truncate_safe_view(std::string_view(raw_message).substr(0, i))));
|
||||
if (curr_msg == simple_assist_msg("")) {
|
||||
continue;
|
||||
}
|
||||
LOG_INF("Streaming msg: %s\n", common_chat_msgs_to_json_oaicompat({ curr_msg }).dump().c_str());
|
||||
for (auto diff : common_chat_msg_diff::compute_diffs(last_msg, curr_msg)) {
|
||||
LOG_INF("Streaming diff: %s\n", common_chat_msg_diff_to_json_oaicompat(diff).dump().c_str());
|
||||
if (!diff.reasoning_content_delta.empty()) {
|
||||
merged.reasoning_content += diff.reasoning_content_delta;
|
||||
}
|
||||
if (!diff.content_delta.empty()) {
|
||||
merged.content += diff.content_delta;
|
||||
}
|
||||
if (diff.tool_call_index != std::string::npos) {
|
||||
if (!diff.tool_call_delta.name.empty()) {
|
||||
merged.tool_calls.push_back({ diff.tool_call_delta.name, "", "" });
|
||||
}
|
||||
if (!diff.tool_call_delta.arguments.empty()) {
|
||||
GGML_ASSERT(!merged.tool_calls.empty());
|
||||
merged.tool_calls.back().arguments += diff.tool_call_delta.arguments;
|
||||
}
|
||||
}
|
||||
LOG_INF("Streaming merged: %s\n", common_chat_msgs_to_json_oaicompat({ merged }).dump().c_str());
|
||||
}
|
||||
assert_msg_equals(curr_msg, merged, true);
|
||||
last_msg = curr_msg;
|
||||
}
|
||||
assert_msg_equals(expected, parse_msg(raw_message), true);
|
||||
assert_msg_equals(expected, merged, true);
|
||||
}
|
||||
|
||||
// Use for PEG parser implementations
|
||||
struct peg_test_case {
|
||||
common_chat_templates_inputs params;
|
||||
|
|
@ -3019,6 +2767,44 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
|
|||
}
|
||||
}
|
||||
|
||||
// Test the developer role to system workaround with a simple mock template
|
||||
static void test_developer_role_to_system_workaround() {
|
||||
LOG_DBG("%s\n", __func__);
|
||||
|
||||
// Simple mock template that supports system role
|
||||
const std::string mock_template =
|
||||
"{%- for message in messages -%}\n"
|
||||
" {{- '<|' + message.role + '|>' + message.content + '<|end|>' -}}\n"
|
||||
"{%- endfor -%}\n"
|
||||
"{%- if add_generation_prompt -%}\n"
|
||||
" {{- '<|assistant|>' -}}\n"
|
||||
"{%- endif -%}";
|
||||
|
||||
auto tmpls = common_chat_templates_ptr(common_chat_templates_init(/* model= */ nullptr, mock_template));
|
||||
|
||||
// Test case 1: Developer message - should be changed to system
|
||||
// After simplification we only test this case
|
||||
{
|
||||
common_chat_templates_inputs inputs;
|
||||
common_chat_msg developer_msg;
|
||||
developer_msg.role = "developer";
|
||||
developer_msg.content = "You are a helpful developer assistant.";
|
||||
inputs.messages = { developer_msg };
|
||||
inputs.add_generation_prompt = false;
|
||||
|
||||
auto params = common_chat_templates_apply(tmpls.get(), inputs);
|
||||
|
||||
// The developer role should have been changed to system
|
||||
if (params.prompt.find("<|developer|>") != std::string::npos) {
|
||||
throw std::runtime_error("Test failed: developer role was not changed to system");
|
||||
}
|
||||
if (params.prompt.find("<|system|>You are a helpful developer assistant.<|end|>") == std::string::npos) {
|
||||
throw std::runtime_error("Test failed: system message not found in output");
|
||||
}
|
||||
LOG_ERR("Test 1 passed: developer role changed to system\n");
|
||||
}
|
||||
}
|
||||
|
||||
static void test_msg_diffs_compute() {
|
||||
LOG_DBG("%s\n", __func__);
|
||||
{
|
||||
|
|
@ -3155,6 +2941,7 @@ int main(int argc, char ** argv) {
|
|||
test_msg_diffs_compute();
|
||||
test_msgs_oaicompat_json_conversion();
|
||||
test_tools_oaicompat_json_conversion();
|
||||
test_developer_role_to_system_workaround();
|
||||
test_template_output_peg_parsers(detailed_debug);
|
||||
std::cout << "\n[chat] All tests passed!" << '\n';
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue