common: map developer role to system (#20215)

* Map developer role to system
* Simplify
This commit is contained in:
Piotr Wilkin (ilintar) 2026-03-09 14:25:11 +01:00 committed by GitHub
parent 43e1cbd6c1
commit f76565db92
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 54 additions and 252 deletions

View File

@ -1352,6 +1352,17 @@ static common_chat_params common_chat_params_init_lfm2(const common_chat_templat
namespace workaround {
static void map_developer_role_to_system(json & messages) {
for (auto & message : messages) {
if (message.contains("role")) {
if (message["role"] == "developer") {
message["role"] = "system";
}
}
}
}
// if first message is system and template does not support it, merge it with next message
static void system_message_not_supported(json & messages) {
if (!messages.empty() && messages.front().at("role") == "system") {
@ -1429,6 +1440,10 @@ static common_chat_params common_chat_templates_apply_jinja(const struct common_
params.add_bos = tmpls->add_bos;
params.add_eos = tmpls->add_eos;
if (src.find("<|channel|>") == std::string::npos) {
// map developer to system for all models except for GPT-OSS
workaround::map_developer_role_to_system(params.messages);
}
workaround::func_args_not_string(params.messages);
if (!tmpl.original_caps().supports_system_role) {

View File

@ -800,258 +800,6 @@ const common_chat_msg message_assist_call_python_lines_unclosed =
const common_chat_msg message_assist_json_content =
simple_assist_msg("{\n \"response\": \"Hello, world!\\nWhat's up?\"\n}");
struct delta_data {
std::string delta;
common_chat_params params;
};
static delta_data init_delta(const struct common_chat_templates * tmpls,
const std::vector<std::string> & end_tokens,
const common_chat_msg & user_message,
const common_chat_msg & delta_message,
const std::vector<common_chat_tool> & tools,
const common_chat_tool_choice & tool_choice) {
common_chat_templates_inputs inputs;
inputs.parallel_tool_calls = true;
inputs.messages.push_back(user_message);
inputs.tools = tools;
inputs.tool_choice = tool_choice;
auto params_prefix = common_chat_templates_apply(tmpls, inputs);
inputs.messages.push_back(delta_message);
inputs.add_generation_prompt = false;
auto params_full = common_chat_templates_apply(tmpls, inputs);
std::string prefix = params_prefix.prompt;
std::string full = params_full.prompt;
if (full == prefix) {
throw std::runtime_error("Full message is the same as the prefix");
}
size_t common_prefix_length = 0;
for (size_t i = 0; i < prefix.size() && i < full.size(); ++i) {
if (prefix[i] != full[i]) {
break;
}
if (prefix[i] == '<') {
// DeepSeek R1's template (as of 20250209) adds a trailing <think> if add_generation_prompt,
// but it removes thinking tags for past messages.
// The prefix and full strings diverge at <think> vs. <tool▁calls▁begin>, we avoid consuming the leading <.
continue;
}
common_prefix_length = i + 1;
}
auto delta = full.substr(common_prefix_length);
// Strip end tokens
for (const auto & end_token : end_tokens) {
// rfind to find the last occurrence
auto pos = delta.rfind(end_token);
if (pos != std::string::npos) {
delta = delta.substr(0, pos);
break;
}
}
return { delta, params_full };
}
/*
Applies the template to 1 user message w/ add_generation_prompt=true, then w/ the test message w/ add_generation_prompt=false,
gets the diff, removes any end tokens and parses the result w/ the grammar, checking that
the parsed message is the same as the test_message
*/
static void test_templates(const struct common_chat_templates * tmpls,
const std::vector<std::string> & end_tokens,
const common_chat_msg & test_message,
const std::vector<common_chat_tool> & tools = {},
const std::string & expected_delta = "",
bool expect_grammar_triggered = true,
bool test_grammar_if_triggered = true,
common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_NONE,
bool ignore_whitespace_differences = false) {
common_chat_msg user_message;
user_message.role = "user";
user_message.content = "Hello, world!";
common_chat_templates_inputs inputs_tools;
inputs_tools.messages = { message_user };
inputs_tools.tools = { special_function_tool };
common_chat_params params = common_chat_templates_apply(tmpls, inputs_tools);
for (const auto & tool_choice :
std::vector<common_chat_tool_choice>{ COMMON_CHAT_TOOL_CHOICE_AUTO, COMMON_CHAT_TOOL_CHOICE_REQUIRED }) {
auto data = init_delta(tmpls, end_tokens, user_message, test_message, tools, tool_choice);
if (!expected_delta.empty()) {
if (ignore_whitespace_differences) {
assert_equals(string_strip(expected_delta), string_strip(data.delta));
} else {
assert_equals(expected_delta, data.delta);
}
}
if (expect_grammar_triggered) {
// TODO @ngxson : refactor common_chat_parse to avoid passing format/reasoning_format every time
common_chat_parser_params parser_params;
parser_params.format = data.params.format;
parser_params.reasoning_format = reasoning_format;
if (!parser_params.parser.empty()) {
parser_params.parser = common_peg_arena();
parser_params.parser.load(params.parser);
}
const auto msg = common_chat_parse(data.delta, /* is_partial= */ false, parser_params);
assert_msg_equals(test_message, msg, ignore_whitespace_differences);
}
if (!test_message.tool_calls.empty()) {
GGML_ASSERT(!data.params.grammar.empty());
}
if (!data.params.grammar.empty()) {
auto grammar = build_grammar(data.params.grammar);
if (!grammar) {
throw std::runtime_error("Failed to build grammar");
}
auto earliest_trigger_pos = std::string::npos;
auto constrained = data.delta;
for (const auto & trigger : data.params.grammar_triggers) {
size_t pos = std::string::npos;
std::smatch match;
switch (trigger.type) {
case COMMON_GRAMMAR_TRIGGER_TYPE_WORD:
{
const auto & word = trigger.value;
pos = constrained.find(word);
break;
}
case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN:
{
const auto & pattern = std::regex(trigger.value);
if (std::regex_search(constrained, match, pattern)) {
pos = match.position(pattern.mark_count());
}
break;
}
case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL:
{
const auto & pattern = trigger.value;
if (std::regex_match(constrained, match, std::regex(pattern))) {
auto mpos = std::string::npos;
for (size_t i = 1; i < match.size(); ++i) {
if (match[i].length() > 0) {
mpos = match.position(i);
break;
}
}
if (mpos == std::string::npos) {
mpos = match.position(0);
}
pos = mpos;
}
break;
}
default:
throw std::runtime_error("Unknown trigger type");
}
if (pos == std::string::npos) {
continue;
}
if (earliest_trigger_pos == std::string::npos || pos < earliest_trigger_pos) {
earliest_trigger_pos = pos;
}
}
auto grammar_triggered = false;
if (earliest_trigger_pos != std::string::npos) {
constrained = constrained.substr(earliest_trigger_pos);
grammar_triggered = true;
}
if (data.params.grammar_lazy) {
assert_equals(expect_grammar_triggered, grammar_triggered);
}
if (grammar_triggered && test_grammar_if_triggered && !match_string(constrained, grammar.get())) {
throw std::runtime_error("Failed to match delta against grammar:\n\n" + data.delta +
"\n\nConstrained: " + constrained + "\n\nGrammar: " + data.params.grammar);
}
}
}
}
/**
* Test if streaming=true is consistent with streaming=false for given partial parser
* Also test if there is any problem with partial message
*/
template <typename T>
static void test_parser_with_streaming(const common_chat_msg & expected, const std::string & raw_message, T parse_msg) {
constexpr auto utf8_truncate_safe_len = [](const std::string_view s) -> size_t {
auto len = s.size();
if (len == 0) {
return 0;
}
auto i = len;
for (size_t back = 0; back < 4 && i > 0; ++back) {
--i;
unsigned char c = s[i];
if ((c & 0x80) == 0) {
return len;
}
if ((c & 0xC0) == 0xC0) {
size_t expected_len = 0;
if ((c & 0xE0) == 0xC0) {
expected_len = 2;
} else if ((c & 0xF0) == 0xE0) {
expected_len = 3;
} else if ((c & 0xF8) == 0xF0) {
expected_len = 4;
} else {
return i;
}
if (len - i >= expected_len) {
return len;
}
return i;
}
}
return len - std::min(len, size_t(3));
};
constexpr auto utf8_truncate_safe_view = [utf8_truncate_safe_len](const std::string_view s) {
return s.substr(0, utf8_truncate_safe_len(s));
};
auto merged = simple_assist_msg("");
auto last_msg = parse_msg("");
for (size_t i = 1; i <= raw_message.size(); ++i) {
auto curr_msg = parse_msg(std::string(utf8_truncate_safe_view(std::string_view(raw_message).substr(0, i))));
if (curr_msg == simple_assist_msg("")) {
continue;
}
LOG_INF("Streaming msg: %s\n", common_chat_msgs_to_json_oaicompat({ curr_msg }).dump().c_str());
for (auto diff : common_chat_msg_diff::compute_diffs(last_msg, curr_msg)) {
LOG_INF("Streaming diff: %s\n", common_chat_msg_diff_to_json_oaicompat(diff).dump().c_str());
if (!diff.reasoning_content_delta.empty()) {
merged.reasoning_content += diff.reasoning_content_delta;
}
if (!diff.content_delta.empty()) {
merged.content += diff.content_delta;
}
if (diff.tool_call_index != std::string::npos) {
if (!diff.tool_call_delta.name.empty()) {
merged.tool_calls.push_back({ diff.tool_call_delta.name, "", "" });
}
if (!diff.tool_call_delta.arguments.empty()) {
GGML_ASSERT(!merged.tool_calls.empty());
merged.tool_calls.back().arguments += diff.tool_call_delta.arguments;
}
}
LOG_INF("Streaming merged: %s\n", common_chat_msgs_to_json_oaicompat({ merged }).dump().c_str());
}
assert_msg_equals(curr_msg, merged, true);
last_msg = curr_msg;
}
assert_msg_equals(expected, parse_msg(raw_message), true);
assert_msg_equals(expected, merged, true);
}
// Use for PEG parser implementations
struct peg_test_case {
common_chat_templates_inputs params;
@ -3019,6 +2767,44 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
}
}
// Test the developer role to system workaround with a simple mock template
static void test_developer_role_to_system_workaround() {
LOG_DBG("%s\n", __func__);
// Simple mock template that supports system role
const std::string mock_template =
"{%- for message in messages -%}\n"
" {{- '<|' + message.role + '|>' + message.content + '<|end|>' -}}\n"
"{%- endfor -%}\n"
"{%- if add_generation_prompt -%}\n"
" {{- '<|assistant|>' -}}\n"
"{%- endif -%}";
auto tmpls = common_chat_templates_ptr(common_chat_templates_init(/* model= */ nullptr, mock_template));
// Test case 1: Developer message - should be changed to system
// After simplification we only test this case
{
common_chat_templates_inputs inputs;
common_chat_msg developer_msg;
developer_msg.role = "developer";
developer_msg.content = "You are a helpful developer assistant.";
inputs.messages = { developer_msg };
inputs.add_generation_prompt = false;
auto params = common_chat_templates_apply(tmpls.get(), inputs);
// The developer role should have been changed to system
if (params.prompt.find("<|developer|>") != std::string::npos) {
throw std::runtime_error("Test failed: developer role was not changed to system");
}
if (params.prompt.find("<|system|>You are a helpful developer assistant.<|end|>") == std::string::npos) {
throw std::runtime_error("Test failed: system message not found in output");
}
LOG_ERR("Test 1 passed: developer role changed to system\n");
}
}
static void test_msg_diffs_compute() {
LOG_DBG("%s\n", __func__);
{
@ -3155,6 +2941,7 @@ int main(int argc, char ** argv) {
test_msg_diffs_compute();
test_msgs_oaicompat_json_conversion();
test_tools_oaicompat_json_conversion();
test_developer_role_to_system_workaround();
test_template_output_peg_parsers(detailed_debug);
std::cout << "\n[chat] All tests passed!" << '\n';
}