Add proper reasoning detection for CLI

This commit is contained in:
Piotr Wilkin 2026-02-16 01:44:16 +01:00
parent 4f61ee302d
commit 562fe10dd5
8 changed files with 71 additions and 26 deletions

View File

@@ -2,6 +2,7 @@
#include "chat-auto-parser.h"
#include "ggml.h"
#include "peg-parser.h"
#include <nlohmann/json.hpp>
@@ -195,6 +196,22 @@ tagged_parse_result tagged_peg_parser::parse_and_extract(const std::string & inp
    return { std::move(parse_result), std::move(mapper.tags) };
}

tagged_parse_result tagged_peg_parser::parse_anywhere_and_extract(const std::string & input) const {
    if (input.empty()) {
        return parse_and_extract(input, false);
    }
    for (int i = (int) input.size() - 1; i >= 0; i--) {
        common_peg_parse_context ctx(input.substr(i), false);
        auto parse_result = arena.parse(ctx);
        if (parse_result.success() || i == 0) {
            tag_based_peg_mapper mapper;
            mapper.from_ast(ctx.ast, parse_result);
            return { std::move(parse_result), std::move(mapper.tags) };
        }
    }
    GGML_ABORT("unreachable: the i == 0 iteration always returns");
}
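The loop scans suffixes of the input from shortest (the last character) to longest (the whole string) and returns the first one the grammar accepts; the i == 0 iteration returns unconditionally, which is why the abort is unreachable. A minimal standalone sketch of the same search strategy, with a plain predicate standing in for the PEG arena (none of the llama.cpp types are reproduced here):

#include <cstddef>
#include <functional>
#include <iostream>
#include <string>

// Returns the start offset of the shortest matching suffix, falling back to 0
// (a parse of the whole string) when nothing shorter matches.
static size_t find_shortest_matching_suffix(
        const std::string & input,
        const std::function<bool(const std::string &)> & accepts) {
    if (input.empty()) {
        return 0;
    }
    for (size_t i = input.size(); i-- > 0; ) { // i runs from the last index down to 0
        if (accepts(input.substr(i)) || i == 0) {
            return i;
        }
    }
    return 0; // unreachable, mirrors the GGML_ABORT above
}

int main() {
    const std::string text = "prose, prose, then a call: { \"fun_name\": 1 }";
    const size_t pos = find_shortest_matching_suffix(text, [](const std::string & s) {
        return s.rfind("{ \"", 0) == 0; // toy stand-in for arena.parse()
    });
    std::cout << "match starts at offset " << pos << "\n"; // prints 27
}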
tagged_peg_parser build_tagged_peg_parser(
        const std::function<common_peg_parser(common_peg_parser_builder & builder)> & fn) {
    common_peg_parser_builder builder;

View File

@@ -154,6 +154,7 @@ struct tagged_peg_parser {
    common_peg_arena arena;
    tagged_parse_result parse_and_extract(const std::string & input, bool is_partial = false) const;
    tagged_parse_result parse_anywhere_and_extract(const std::string & input) const;
};
tagged_peg_parser build_tagged_peg_parser(

View File

@@ -227,22 +227,18 @@ common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::strin
}
bool common_chat_templates_support_enable_thinking(const common_chat_templates * chat_templates) {
    common_chat_templates_inputs dummy_inputs;
    common_chat_msg msg;
    msg.role = "user";
    msg.content = "test";
    dummy_inputs.messages = { msg };
    dummy_inputs.enable_thinking = false;
    const auto rendered_no_thinking = common_chat_templates_apply(chat_templates, dummy_inputs);
    dummy_inputs.enable_thinking = true;
    const auto rendered_with_thinking = common_chat_templates_apply(chat_templates, dummy_inputs);
    bool detect = rendered_no_thinking.prompt != rendered_with_thinking.prompt;
    const auto & tmpl = chat_templates->template_tool_use
        ? *chat_templates->template_tool_use
        : *chat_templates->template_default;
    autoparser::analyze_template result(tmpl);
    detect |= result.reasoning.mode != autoparser::reasoning_mode::NONE;
    return detect;
    common_chat_templates_inputs inputs;
    common_chat_msg msg;
    msg.role = "user";
    msg.content = "test";
    inputs.messages = { msg };
    inputs.enable_thinking = true;
    inputs.add_generation_prompt = true;
    inputs.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
    auto params = common_chat_templates_apply(chat_templates, inputs);
    return params.supports_thinking;
}
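The rewrite replaces the two-render prompt diff plus ad-hoc template analysis with a single source of truth: render a one-message dummy conversation once with enable_thinking = true and read back the supports_thinking flag that the per-template init paths now set. A minimal caller-side sketch, assuming only the function shown in this hunk (the helper itself is hypothetical):

// Hypothetical UI helper: only offer a thinking toggle when the template can honor it.
static bool should_offer_thinking_toggle(const common_chat_templates * tmpls) {
    // Renders a dummy conversation with enable_thinking = true and returns the
    // supports_thinking flag computed during template application.
    return common_chat_templates_support_enable_thinking(tmpls);
}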
std::vector<common_chat_msg> common_chat_msgs_parse_oaicompat(const json & messages) {
@@ -850,9 +846,10 @@ static common_chat_params common_chat_params_init_ministral_3(const common_chat_
    auto extract_reasoning = inputs.reasoning_format != COMMON_REASONING_FORMAT_NONE;
    auto include_grammar = true;
    data.prompt = common_chat_template_direct_apply(tmpl, inputs, /* messages_override = */ adjusted_messages);
    data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
    data.preserved_tokens = {
    data.supports_thinking = true;
    data.prompt = common_chat_template_direct_apply(tmpl, inputs, /* messages_override = */ adjusted_messages);
    data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
    data.preserved_tokens = {
        "[THINK]",
        "[/THINK]",
        "[TOOL_CALLS]",
@@ -949,8 +946,9 @@ static common_chat_params common_chat_params_init_gpt_oss(const common_chat_temp
        }
    }
    data.prompt = prompt;
    data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
    data.prompt = prompt;
    data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
    data.supports_thinking = true;
    // These special tokens are required to parse properly, so we include them
    // even if parse_tool_calls is false.
@@ -1281,7 +1279,9 @@ static common_chat_params common_chat_templates_apply_jinja(const struct common_
    try {
        LOG_DBG("Using differential autoparser\n");
        auto auto_params = autoparser::universal_peg_generator::generate_parser(tmpl, params);
        auto analysis = autoparser::analyze_template(tmpl);
        auto auto_params = autoparser::universal_peg_generator::generate_parser(tmpl, params, analysis);
        auto_params.supports_thinking = analysis.reasoning.mode != autoparser::reasoning_mode::NONE;
        return auto_params;
    } catch (const std::exception & e) {
        LOG_WRN("Automatic parser generation failed: %s\n", e.what());

View File

@@ -198,7 +198,7 @@ struct common_chat_templates_inputs {
    std::vector<common_chat_tool> tools;
    common_chat_tool_choice tool_choice = COMMON_CHAT_TOOL_CHOICE_AUTO;
    bool parallel_tool_calls = false;
    common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_NONE; // TODO: refactor this to "bool enable_thinking"
    common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_NONE; // TODO: refactor this to "bool enable_thinking"
    bool enable_thinking = true;
    std::chrono::system_clock::time_point now = std::chrono::system_clock::now();
    std::map<std::string, std::string> chat_template_kwargs;
@@ -212,6 +212,7 @@ struct common_chat_params {
    std::string grammar;
    bool grammar_lazy = false;
    bool thinking_forced_open = false;
    bool supports_thinking = false;
    std::vector<common_grammar_trigger> grammar_triggers;
    std::vector<std::string> preserved_tokens;
    std::vector<std::string> additional_stops;
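The new supports_thinking flag travels alongside the other template-derived outputs, so consumers can branch on reasoning support without re-rendering the template. An illustrative consumer, assuming only the fields declared in this file (the glue function itself is hypothetical):

// Hypothetical glue: seed parser params from chat params, skipping reasoning
// extraction when the template has no reasoning markup.
static common_chat_parser_params make_parser_params(const common_chat_params & params) {
    common_chat_parser_params out;
    out.format               = params.format;
    out.thinking_forced_open = params.thinking_forced_open;
    out.reasoning_format     = params.supports_thinking
        ? COMMON_REASONING_FORMAT_DEEPSEEK
        : COMMON_REASONING_FORMAT_NONE;
    return out;
}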
@@ -222,7 +223,7 @@
// should be derived from common_chat_params
struct common_chat_parser_params {
    common_chat_format format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
    common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_NONE; // TODO: refactor this to "bool parse_reasoning"
    common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_NONE; // TODO: refactor this to "bool parse_reasoning"
    // Whether reasoning_content should be inlined in the content (e.g. for reasoning_format=deepseek in stream mode)
    bool reasoning_in_content = false;
    bool thinking_forced_open = false;

View File

@@ -1268,6 +1268,9 @@ class TextModel(ModelBase):
        if chkhsh == "d30d75d9059f1aa2c19359de71047b3ae408c70875e8a3ccf8c5fba56c9d8af4":
            # ref: https://huggingface.co/Qwen/Qwen3.5-9B-Instruct
            res = "qwen35"
        if chkhsh == "b4b8ca1f9769494fbd956ebc4c249de6131fb277a4a3345a7a92c7dd7a55808d":
            # ref: https://huggingface.co/jdopensource/JoyAI-LLM-Flash
            res = "joyai-llm-flash"
        if res is None:
            logger.warning("\n")

View File

@@ -149,7 +149,8 @@ models = [
    {"name": "youtu", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tencent/Youtu-LLM-2B", },
    {"name": "solar-open", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/upstage/Solar-Open-100B", },
    {"name": "exaone-moe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LGAI-EXAONE/K-EXAONE-236B-A23B", },
    {"name": "qwen35", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/Qwen/Qwen3.5-9B-Instruct", }
    {"name": "qwen35", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/Qwen/Qwen3.5-9B-Instruct", },
    {"name": "joyai-llm-flash", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/jdopensource/JoyAI-LLM-Flash", },
]
# some models are known to be broken upstream, so we will skip them as exceptions

View File

@@ -936,4 +936,24 @@ static void test_tagged_peg_parser(testing & t) {
        t.assert_equal("prefix tag", "key", result.tags.at("prefix"));
        t.assert_equal("value tag", "val", result.tags.at("value"));
    });
    t.test("find in the middle", [&](testing & t) {
        auto parser = build_tagged_peg_parser([](common_peg_parser_builder & p) {
            return p.choice({ p.literal("{"), p.literal(":") }) + p.space() + p.literal("\"") + p.atomic(p.literal("fun_name"));
        });
        std::string tpl = "This is a very long jinja template string. We have tools. We will try to call them now: <tool_call>{ \"fun_name\" : { \"arg\" : 1 }</tool_call>";
        auto result = parser.parse_anywhere_and_extract(tpl);
        t.assert_true("success", result.result.success());
    });
    t.test("fail find in the middle", [&](testing & t) {
        auto parser = build_tagged_peg_parser([](common_peg_parser_builder & p) {
            return p.choice({ p.literal("{"), p.literal(":") }) + p.space() + p.literal("\"") + p.atomic(p.literal("fun_name"));
        });
        std::string tpl = "This is a very long jinja template string. We have tools. We will try to call them now: <tool_call><fun=fun_name><arg name=arg>1</arg></tool_call>";
        auto result = parser.parse_anywhere_and_extract(tpl);
        t.assert_true("failure", result.result.fail());
    });
}

View File

@@ -1,3 +1,4 @@
#include "chat.h"
#include "common.h"
#include "arg.h"
#include "console.h"
@@ -188,7 +189,8 @@ struct cli_context {
        inputs.use_jinja = chat_params.use_jinja;
        inputs.parallel_tool_calls = false;
        inputs.add_generation_prompt = true;
        inputs.enable_thinking = chat_params.enable_thinking;
        inputs.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
        inputs.enable_thinking = common_chat_templates_support_enable_thinking(chat_params.tmpls.get());
        // Apply chat template to the list of messages
        return common_chat_templates_apply(chat_params.tmpls.get(), inputs);
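Previously the CLI forwarded the user-requested chat_params.enable_thinking verbatim, even for templates with no reasoning markup; with this change thinking is enabled only when the template can actually express it, and reasoning_format is pinned to DeepSeek-style parsing. A hedged sketch of reconciling user preference with template capability (the `want` variable and the warning are illustrative, not part of the commit):

// Illustrative reconciliation: honor the user's setting only when supported.
bool want = chat_params.enable_thinking; // the setting the removed line forwarded
bool can  = common_chat_templates_support_enable_thinking(chat_params.tmpls.get());
if (want && !can) {
    fprintf(stderr, "note: chat template has no reasoning support; thinking stays disabled\n");
}
inputs.enable_thinking = want && can;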