diff --git a/common/chat-parser.cpp b/common/chat-parser.cpp
index 23e23ca8c7..d740dac065 100644
--- a/common/chat-parser.cpp
+++ b/common/chat-parser.cpp
@@ -1395,14 +1395,6 @@ static void common_chat_parse_seed_oss(common_chat_msg_parser & builder) {
     builder.consume_reasoning_with_xml_tool_calls(form, "", "");
 }
 
-static void common_chat_parse_solar_open(common_chat_msg_parser & builder) {
-    builder.try_parse_reasoning("<|think|>", "<|end|><|begin|>assistant<|content|>");
-
-    // TODO: Tool calling
-
-    builder.add_content(builder.consume_rest());
-}
-
 static void common_chat_parse_content_only(common_chat_msg_parser & builder) {
     builder.try_parse_reasoning("", "");
     builder.add_content(builder.consume_rest());
@@ -1487,9 +1479,6 @@ static void common_chat_parse(common_chat_msg_parser & builder) {
         case COMMON_CHAT_FORMAT_XIAOMI_MIMO:
             common_chat_parse_xiaomi_mimo(builder);
             break;
-        case COMMON_CHAT_FORMAT_SOLAR_OPEN:
-            common_chat_parse_solar_open(builder);
-            break;
         default:
             throw std::runtime_error(std::string("Unsupported format: ") + common_chat_format_name(builder.syntax().format));
     }
diff --git a/common/chat.cpp b/common/chat.cpp
index b98ab21ce1..e54e4b1dee 100644
--- a/common/chat.cpp
+++ b/common/chat.cpp
@@ -669,7 +669,6 @@ const char * common_chat_format_name(common_chat_format format) {
         case COMMON_CHAT_FORMAT_QWEN3_CODER_XML: return "Qwen3 Coder";
         case COMMON_CHAT_FORMAT_APRIEL_1_5: return "Apriel 1.5";
         case COMMON_CHAT_FORMAT_XIAOMI_MIMO: return "Xiaomi MiMo";
-        case COMMON_CHAT_FORMAT_SOLAR_OPEN: return "Solar Open";
         case COMMON_CHAT_FORMAT_PEG_SIMPLE: return "peg-simple";
         case COMMON_CHAT_FORMAT_PEG_NATIVE: return "peg-native";
         case COMMON_CHAT_FORMAT_PEG_CONSTRUCTED: return "peg-constructed";
@@ -2521,20 +2520,165 @@ static common_chat_params common_chat_params_init_granite(const common_chat_temp
+// Solar Open handler: renders the prompt, then builds a PEG parser for the reasoning, content and
+// tool-call segments of a reply (plus a grammar unless the reply is content-only).
 static common_chat_params common_chat_params_init_solar_open(const common_chat_template & tmpl, const struct templates_params & inputs) {
     common_chat_params data;
 
-    // TODO: Reasoning effort
-    json additional_context = {};
+    // Copy `reasoning_content` to `reasoning`
+    auto adjusted_messages = json::array();
+    for (const auto & msg : inputs.messages) {
+        if (msg.contains("reasoning_content") && msg.at("reasoning_content").is_string()) {
+            auto adjusted_message = msg;
+            adjusted_message["reasoning"] = msg.at("reasoning_content");
+            adjusted_message.erase("reasoning_content");
+            adjusted_messages.push_back(adjusted_message);
+        } else {
+            adjusted_messages.push_back(msg);
+        }
+    }
 
-    data.prompt = apply(tmpl, inputs, std::nullopt, std::nullopt, additional_context);
-    data.format = COMMON_CHAT_FORMAT_SOLAR_OPEN;
+    auto has_tools = inputs.tools.is_array() && !inputs.tools.empty();
+    auto include_grammar = true;
+    auto prompt = apply(tmpl, inputs, /* messages_override= */ adjusted_messages);
+
+    // Check if we need to replace the flush token with end token during inference and without generation prompt.
+    if (inputs.is_inference && !inputs.add_generation_prompt) {
+        static constexpr std::string_view return_token = "<|flush|>";
+        static constexpr std::string_view end_token = "<|end|>";
+        if (size_t pos = prompt.rfind(return_token); pos != std::string::npos) {
+            prompt.replace(pos, return_token.length(), end_token);
+        }
+    }
+
+    // The template renders "<|think|><|end|>" into the generation prompt for low/minimal reasoning effort
+    auto enable_thinking = prompt.rfind("<|think|><|end|>") == std::string::npos;
+
+    data.prompt = prompt;
+    data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
 
     data.preserved_tokens = {
         "<|think|>",
         "<|content|>",
         "<|begin|>",
         "<|end|>",
+        "<|tool_calls|>",
+        "<|tool_call:begin|>",
+        "<|tool_call:end|>",
+        "<|tool_call:name|>",
+        "<|tool_call:args|>",
     };
 
-    // TODO: Tool calling
+    auto parser = build_chat_peg_native_parser([&](common_chat_peg_native_builder & p) {
+        auto lit_think = p.atomic(p.literal("<|think|>"));
+        auto lit_assistant_begin = p.atomic(p.literal("<|begin|>assistant"));
+        auto lit_content = p.atomic(p.literal("<|content|>"));
+        auto lit_end = p.atomic(p.literal("<|end|>"));
+        auto parser_until_end = p.until("<|end|>");
+
+        // reasoning <- "<|think|>" (!"<|end|>" .)*
+        auto parser_reasoning = p.rule("reasoning", lit_think + p.reasoning(parser_until_end));
+
+        // content <- "<|content|>" (!"<|end|>" .)*
+        auto parser_content = p.rule("content", lit_content + p.content(parser_until_end));
+
+        // wrap_choice(items) <- item-choice wrapped*
+        // item-choice <- items[0] / ... / items[n]
+        // wrapped <- "<|end|><|begin|>assistant" item-choice
+        auto wrap_choice = [&](const std::vector<common_peg_parser> & items) {
+            auto choice = p.choice(items);
+            auto first = enable_thinking ? choice : lit_assistant_begin + choice;
+            return first + p.zero_or_more(lit_end + lit_assistant_begin + choice);
+        };
+
+        // wrap_seq(items) <- item[0] "<|end|><|begin|>assistant" item[1] ...
+        auto wrap_seq = [&](const std::vector<common_peg_parser> & items) {
+            auto seq = p.sequence();
+            for (auto i = 0u; i < items.size(); i++) {
+                if (i == 0) {
+                    seq += enable_thinking ? items[i] : lit_assistant_begin + items[i];
+                    continue;
+                }
+                seq += lit_end + lit_assistant_begin + items[i];
+            }
+            return seq;
+        };
+
+        // Response format parser
+        if (inputs.json_schema.is_object() && !inputs.json_schema.empty()) {
+            auto parser_response_format = lit_content + p.content(p.schema(p.json(), "response-format", inputs.json_schema));
+            return p.choice({
+                wrap_seq({parser_reasoning, parser_response_format}),
+                wrap_seq({parser_response_format})
+            });
+        }
+
+        auto lit_tool_call_begin = p.literal("<|tool_call:begin|>");
+        auto lit_tool_call_name = p.literal("<|tool_call:name|>");
+        auto lit_tool_call_args = p.literal("<|tool_call:args|>");
+        auto lit_tool_call_end = p.literal("<|tool_call:end|>");
+
+        // Tool call parser
+        if (has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE) {
+            auto parser_tool_call = p.choice();
+            foreach_function(inputs.tools, [&](const json & tool) {
+                const auto & function = tool.at("function");
+                std::string name = function.at("name");
+                const auto & schema = function.at("parameters");
+
+                parser_tool_call |= p.rule("tool-" + name,
+                    p.atomic(p.tool_name(p.literal(name)) + lit_tool_call_args)
+                    + p.tool_args(p.schema(p.json(), "tool-" + name + "-schema", schema)));
+            });
+
+            auto min_calls = inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED ? 1 : 0;
+            auto max_calls = inputs.parallel_tool_calls ? -1 : 1;
+
+            auto parser_tool_calls = p.trigger_rule("tool-calls",
+                p.atomic(p.literal("<|tool_calls|>"))
+                + p.repeat(
+                    p.tool_open(
+                        lit_tool_call_begin
+                        + p.tool_id(p.chars("[a-zA-Z0-9_-]", 1, -1))
+                        + lit_tool_call_name
+                        + p.peek(p.chars("[^<]", 1, -1) + lit_tool_call_args))
+                    + parser_tool_call
+                    + p.tool_close(lit_tool_call_end),
+                    1, max_calls));
+
+            if (min_calls == 1) {
+                // If required, every accepted sequence must end with tool calls; reasoning and content stay optional
+                return p.choice({
+                    wrap_seq({parser_reasoning, parser_content, parser_tool_calls}),
+                    wrap_seq({parser_reasoning, parser_tool_calls}),
+                    wrap_seq({parser_content, parser_tool_calls}),
+                    wrap_seq({parser_tool_calls})
+                });
+            }
+
+            return wrap_choice({parser_reasoning, parser_content, parser_tool_calls});
+        }
+
+        // Content only parser
+        include_grammar = false;
+        return wrap_choice({parser_reasoning, parser_content});
+    });
+
+    data.parser = parser.save();
+
+    if (include_grammar) {
+        data.grammar_lazy = has_tools && inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_AUTO;
+
+        data.grammar = build_grammar([&](const common_grammar_builder & builder) {
+            foreach_function(inputs.tools, [&](const json & tool) {
+                const auto & function = tool.at("function");
+                auto schema = function.at("parameters");
+                builder.resolve_refs(schema);
+            });
+            parser.build_grammar(builder, data.grammar_lazy);
+        });
+
+        data.grammar_triggers = {
+            {COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|tool_calls|>"}
+        };
+    }
 
     return data;
 }
@@ -2763,6 +2907,13 @@ static common_chat_params common_chat_templates_apply_jinja(
         return common_chat_params_init_apriel_1_5(tmpl, params);
     }
 
+    // Solar Open
+    if (src.find("<|tool_response:begin|>") != std::string::npos &&
+        src.find("<|tool_response:name|>") != std::string::npos &&
+        src.find("<|tool_response:result|>") != std::string::npos) {
+        return common_chat_params_init_solar_open(tmpl, params);
+    }
+
     // Use generic handler when mixing tools + JSON schema.
     // TODO: support that mix in handlers below.
     if ((params.tools.is_array() && params.json_schema.is_object())) {
@@ -2802,13 +2953,6 @@ static common_chat_params common_chat_templates_apply_jinja(
         return common_chat_params_init_magistral(tmpl, params);
     }
 
-    // Solar Open
-    if (src.find("<|tool_response:begin|>") != std::string::npos &&
-        src.find("<|tool_response:name|>") != std::string::npos &&
-        src.find("<|tool_response:result|>") != std::string::npos) {
-        return common_chat_params_init_solar_open(tmpl, params);
-    }
-
     // Plain handler (no tools)
     if (params.tools.is_null() || inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_NONE) {
         return common_chat_params_init_without_tools(tmpl, params);
diff --git a/common/chat.h b/common/chat.h
index 8bd4a325ff..6085510a40 100644
--- a/common/chat.h
+++ b/common/chat.h
@@ -124,7 +124,6 @@ enum common_chat_format {
     COMMON_CHAT_FORMAT_QWEN3_CODER_XML,
     COMMON_CHAT_FORMAT_APRIEL_1_5,
     COMMON_CHAT_FORMAT_XIAOMI_MIMO,
-    COMMON_CHAT_FORMAT_SOLAR_OPEN,
 
     // These are intended to be parsed by the PEG parser
     COMMON_CHAT_FORMAT_PEG_SIMPLE,
diff --git a/models/templates/upstage-Solar-Open-100B.jinja b/models/templates/upstage-Solar-Open-100B.jinja
new file mode 100644
index 0000000000..13268c1a84
--- /dev/null
+++ b/models/templates/upstage-Solar-Open-100B.jinja
@@ -0,0 +1,156 @@
+{#- ======== Template Parameters ======== #}
+{%- set add_generation_prompt = add_generation_prompt if add_generation_prompt is defined else true %}
+{%- set default_system_prompt = default_system_prompt if default_system_prompt is defined else true %}
+{%- set reasoning_effort = reasoning_effort if reasoning_effort is defined else "high" %}
+{%- set think_render_option = think_render_option if think_render_option is defined else "lastthink" %}
+
+{#- ======== System Block State ======== #}
+{%- set sys_ns = namespace(is_first_block=true) -%}
+
+{#- ======== Find last user message index ======== #}
+{%- set last_user_idx = namespace(value=-1) -%}
+{%- for message in messages -%}
+    {%- if message.role == 'user' -%}
+        {%- set last_user_idx.value = loop.index0 -%}
+    {%- endif -%}
+{%- endfor -%}
+
+{#- ======== System messages renderers ======== #}
+{%- macro render_system_message(user_system_messages) %}
+    {%- if default_system_prompt %}
+        {%- if not sys_ns.is_first_block %}{{- "\n\n" }}{%- endif %}
+        {%- set sys_ns.is_first_block = false %}
+        {{- "## Provider System Prompt\n\nYou are Solar Open 100B, a large language model trained by Upstage AI, a Korean startup. Your knowledge cutoff is 2025-07. The current date is " + strftime_now("%Y-%m-%d") + "." }}
+    {%- endif -%}
+    {%- if user_system_messages %}
+        {%- if not sys_ns.is_first_block %}{{- "\n\n" }}{%- endif %}
+        {%- set sys_ns.is_first_block = false %}
+        {{- "## System Prompt" }}
+        {%- for system_message in user_system_messages %}
+            {{- "\n\n" }}
+            {{- system_message }}
+        {%- endfor %}
+    {%- endif -%}
+{%- endmacro %}
+
+{%- macro render_tool_instruction(tools) %}
+    {%- if not sys_ns.is_first_block %}{{- "\n\n" }}{%- endif %}
+    {%- set sys_ns.is_first_block = false %}
+    {{- "## Tools\n\n### Tool Call Instruction" }}
+    {{- "\nYou may invoke one or more tools to assist with the user's query. Available tools are provided in JSON Schema format: <|tools:begin|><|tool:begin|><|tool:end|>...<|tools:end|>\n" }}
+    {{- "\n### Available Tools\n" }}
+    {{- "<|tools:begin|>" }}
+    {%- for tool in tools %}
+        {{- "<|tool:begin|>" }}
+        {{- tool.function | tojson }}
+        {{- "<|tool:end|>" }}
+    {%- endfor %}
+    {{- "<|tools:end|>\n" }}
+    {{- "\n### Tool Call Format\n" }}
+    {{- "For each tool call, return a JSON object with the following structure, enclosed within <|tool_call:begin|> and <|tool_call:end|> tags: \n<|tool_call:begin|><|tool_call:name|><|tool_call:args|><|tool_call:end|>\n" }}
+    {{- "- The must be a randomly generated string consisting of 10 lowercase letters (a-z) and/or digits (0-9) (e.g., a1b2c3d4e5)\n" }}
+    {{- "\n### Tool Response Format\n" }}
+    {{- "Each tool is responded by `tool` with the following structure:\n<|tool_response:id|><|tool_response:name|><|tool_response:result|><|tool_response:end|>\n" }}
+    {{- "- Ensure the matches the corresponding tool call" -}}
+{%- endmacro %}
+
+{%- macro render_json_response_format_instruction(response_format) %}
+    {%- if not sys_ns.is_first_block %}{{- "\n\n" }}{%- endif %}
+    {%- set sys_ns.is_first_block = false %}
+    {{- "## Output Format Constraint" }}
+    {{- "\n\nYour final response should follow the JSON schema: \n[Start of schema]" }}
+    {{- response_format }}
+    {{- "\n[End of schema]\nPlease ensure your answers adhere to this format and do not contain any unnecessary text." }}
+{%- endmacro %}
+
+{%- macro get_tool_name(messages, tool_call_id) %}
+    {%- for msg in messages -%}
+        {%- if msg.role == 'assistant' and msg.tool_calls -%}
+            {%- for tool_call in msg.tool_calls -%}
+                {%- if tool_call.id == tool_call_id -%}
+                    {{- tool_call.function.name }}
+                {%- endif -%}
+            {%- endfor -%}
+        {%- endif -%}
+    {%- endfor -%}
+{%- endmacro %}
+
+{%- macro render_tool_arguments(tool_arguments) %}
+    {%- if tool_arguments is mapping -%}
+        {{- tool_arguments | tojson }}
+    {%- else -%}
+        {{- tool_arguments }}
+    {%- endif -%}
+{%- endmacro %}
+
+{#- ======== Render system message ======== #}
+{%- set ns = namespace(system_messages=[]) -%}
+{%- for message in messages -%}
+    {%- if message.role == 'system' -%}
+        {%- set ns.system_messages = ns.system_messages + [message.content] -%}
+    {%- endif -%}
+{%- endfor -%}
+
+{%- if ns.system_messages or default_system_prompt or tools or response_format -%}
+    {{- "<|begin|>system<|content|>" }}
+    {{- render_system_message(ns.system_messages) }}
+    {%- if tools -%}
+        {{- render_tool_instruction(tools) }}
+    {%- endif %}
+    {%- if response_format -%}
+        {{- render_json_response_format_instruction(response_format) }}
+    {%- endif %}
+    {{- "<|end|>" }}
+{%- endif -%}
+
+{#- ======== Render main messages ======== #}
+{%- for message in messages -%}
+    {%- if message.role == 'user' -%}
+        {{- "<|begin|>user<|content|>" + message.content + "<|end|>" }}
+    {%- elif message.role == 'tool' -%}
+        {%- set prev_is_tool = loop.index0 > 0 and messages[loop.index0 - 1].role == 'tool' -%}
+        {%- set next_is_tool = loop.index0 < (messages | length - 1) and messages[loop.index0 + 1].role == 'tool' -%}
+        {%- if not prev_is_tool -%}
+            {{- "<|begin|>tool<|tool_response|>" }}
+        {%- endif -%}
+        {{- "<|tool_response:begin|>" + message.tool_call_id + "<|tool_response:name|>" }}
+        {{- get_tool_name(messages, message.tool_call_id) }}
+        {{- "<|tool_response:result|>" }}
+        {{- message.content }}
+        {{- "<|tool_response:end|>" }}
+        {%- if not next_is_tool -%}
+            {{- "<|end|>" }}
+        {%- endif -%}
+    {%- elif message.role == 'assistant' -%}
+        {#- ======== Assistant Thinking ======== #}
+        {%- if think_render_option == "all" -%}
+            {%- if message.reasoning -%}
+                {{- "<|begin|>assistant<|think|>" + message.reasoning + "<|end|>" }}
+            {%- endif -%}
+        {%- elif think_render_option == "lastthink" -%}
+            {%- if message.reasoning and loop.index0 > last_user_idx.value -%}
+                {{- "<|begin|>assistant<|think|>" + message.reasoning + "<|end|>" }}
+            {%- endif -%}
+        {%- endif -%}
+
+        {#- ======== Assistant Messages ======== #}
+        {%- if message.tool_calls -%}
+            {{- "<|begin|>assistant<|tool_calls|>" }}
+            {%- for tool_call in message.tool_calls -%}
+                {{- "<|tool_call:begin|>" + tool_call.id +"<|tool_call:name|>" + tool_call.function.name + "<|tool_call:args|>" }}
+                {{- render_tool_arguments(tool_call.function.arguments) }}
+                {{- "<|tool_call:end|>" }}
+            {%- endfor -%}
+            {{- "<|calls|>" }}
+        {%- else -%}
+            {{- "<|begin|>assistant<|content|>" + message.content + "<|end|>" }}
+        {%- endif -%}
+    {%- endif -%}
+{%- endfor -%}
+
+{%- if add_generation_prompt -%}
+    {%- if reasoning_effort in ["low", "minimal"] -%}
+        {{- "<|begin|>assistant<|think|><|end|>" }}
+    {%- endif -%}
+    {{- "<|begin|>assistant" }}
+{%- endif -%}
diff --git a/tests/test-chat.cpp b/tests/test-chat.cpp
index a07c81fba6..3a8086ca91 100644
--- a/tests/test-chat.cpp
+++ b/tests/test-chat.cpp
@@ -589,7 +589,7 @@ static void test_peg_parser(common_chat_templates * tmpls, const std::function123456789"
+                      "<|tool_call:name|>special_function"
+                      "<|tool_call:args|>{\"arg1\":1}"
+                      "<|tool_call:end|>";
+
+            t.params.chat_template_kwargs["reasoning_effort"] = "\"low\"";
+            t.params.tools = {special_function_tool};
+            t.expect = message_assist_call_id;
+        });
+
+        // Test tool call with reasoning
+        test_peg_parser(tmpls.get(), [&](auto & t) {
+            t.input = "<|think|>I'm\nthinking<|end|>"
+                      "<|begin|>assistant<|tool_calls|>"
+                      "<|tool_call:begin|>0"
+                      "<|tool_call:name|>special_function"
+                      "<|tool_call:args|>{\"arg1\":1}"
+                      "<|tool_call:end|>";
+
+            t.params.tools = {special_function_tool};
+            t.expect = message_assist_thoughts_call_idx;
+        });
+
+        // Test tool call with reasoning and tool_choice = required
+        test_peg_parser(tmpls.get(), [&](auto & t) {
+            t.input = "<|think|>I'm\nthinking<|end|>"
+                      "<|begin|>assistant<|tool_calls|>"
+                      "<|tool_call:begin|>0"
+                      "<|tool_call:name|>special_function"
+                      "<|tool_call:args|>{\"arg1\":1}"
+                      "<|tool_call:end|>";
+
+            t.params.tools = {special_function_tool};
+            t.params.tool_choice = COMMON_CHAT_TOOL_CHOICE_REQUIRED;
+            t.expect = message_assist_thoughts_call_idx;
+        });
+
+        // Test tool call without reasoning and tool_choice = required
+        test_peg_parser(tmpls.get(), [&](auto & t) {
+            t.input = "<|begin|>assistant<|tool_calls|>"
+                      "<|tool_call:begin|>0"
+                      "<|tool_call:name|>special_function"
+                      "<|tool_call:args|>{\"arg1\":1}"
+                      "<|tool_call:end|>";
+
+            t.params.tools = {special_function_tool};
+            t.params.tool_choice = COMMON_CHAT_TOOL_CHOICE_REQUIRED;
+            t.params.chat_template_kwargs["reasoning_effort"] = "\"low\"";
+            t.expect = message_assist_call_idx;
+        });
+
+        // Test parallel tool calls
+        test_peg_parser(tmpls.get(), [&](auto & t) {
+            t.input = "<|think|>I'm\nthinking<|end|>"
+                      "<|begin|>assistant<|tool_calls|>"
+                      "<|tool_call:begin|>0"
+                      "<|tool_call:name|>special_function"
+                      "<|tool_call:args|>{\"arg1\":1}"
+                      "<|tool_call:end|>"
+                      "<|tool_call:begin|>1"
+                      "<|tool_call:name|>special_function_with_opt"
+                      "<|tool_call:args|>{\"arg1\": 1, \"arg2\": 2}"
+                      "<|tool_call:end|>";
+
+            t.params.parallel_tool_calls = true;
+            t.params.tools = {special_function_tool, special_function_tool_with_optional_param};
+
+            t.expect.reasoning_content = "I'm\nthinking";
+            t.expect.tool_calls = {{
+                /* .name = */ "special_function",
+                /* .arguments = */ R"({"arg1": 1})",
+                /* .id = */ "0",
+            }, {
+                /* .name = */ "special_function_with_opt",
+                /* .arguments = */ R"({"arg1": 1, "arg2": 2})",
+                /* .id = */ "1",
+            }};
+        });
+
+        // Test response format
+        test_peg_parser(tmpls.get(), [&](auto & t) {
+            t.input = "<|think|>I need to output the invoice details in JSON<|end|>"
+                      "<|begin|>assistant<|content|>"
+                      R"({"amount": 123.45, "date": "2025-12-03"})";
+
+            t.params.json_schema = invoice_schema;
+
+            t.expect.reasoning_content = "I need to output the invoice details in JSON";
+            t.expect.content =R"({"amount": 123.45, "date": "2025-12-03"})";
+        });
+
+        // Test response format no reasoning
+        test_peg_parser(tmpls.get(), [&](auto & t) {
+            t.input = "<|begin|>assistant<|content|>"
+                      R"({"amount": 123.45, "date": "2025-12-03"})";
+
+            t.params.chat_template_kwargs["reasoning_effort"] = "\"low\"";
+            t.params.json_schema = invoice_schema;
+
+            t.expect.content =R"({"amount": 123.45, "date": "2025-12-03"})";
+        });
+    }
 }
 
 static void test_msg_diffs_compute() {