Add proper reasoning detection for cli
This commit is contained in:
parent
4f61ee302d
commit
562fe10dd5
|
|
@ -2,6 +2,7 @@
|
|||
|
||||
#include "chat-auto-parser.h"
|
||||
#include "ggml.h"
|
||||
#include "peg-parser.h"
|
||||
|
||||
#include <nlohmann/json.hpp>
|
||||
|
||||
|
|
@ -195,6 +196,22 @@ tagged_parse_result tagged_peg_parser::parse_and_extract(const std::string & inp
|
|||
return { std::move(parse_result), std::move(mapper.tags) };
|
||||
}
|
||||
|
||||
// Looks for a position inside `input` from which the grammar matches and extracts the tagged values.
//
// Candidate start offsets are tried from the last character back to the first. The result for the
// first suffix whose parse succeeds is returned. If no suffix succeeds, the result of parsing the
// whole input (offset 0) is returned regardless of its outcome, so callers must still check
// result.success()/fail(). An empty input is delegated to parse_and_extract().
tagged_parse_result tagged_peg_parser::parse_anywhere_and_extract(const std::string & input) const {
    if (input.empty()) {
        return parse_and_extract(input, false);
    }
    // Count down with an unsigned index: storing input.size() - 1 in an int would narrow size_t
    // and break for inputs longer than INT_MAX.
    for (size_t i = input.size(); i-- > 0;) {
        common_peg_parse_context ctx(input.substr(i), false);
        auto parse_result = arena.parse(ctx);
        if (parse_result.success() || i == 0) {
            tag_based_peg_mapper mapper;
            mapper.from_ast(ctx.ast, parse_result);
            return { std::move(parse_result), std::move(mapper.tags) };
        }
    }
    // The i == 0 iteration always returns, so control is not expected to reach this point.
    GGML_ABORT("Should not happen");
}
|
||||
|
||||
tagged_peg_parser build_tagged_peg_parser(
|
||||
const std::function<common_peg_parser(common_peg_parser_builder & builder)> & fn) {
|
||||
common_peg_parser_builder builder;
|
||||
|
|
|
|||
|
|
@ -154,6 +154,7 @@ struct tagged_peg_parser {
|
|||
common_peg_arena arena;
|
||||
|
||||
tagged_parse_result parse_and_extract(const std::string & input, bool is_partial = false) const;
|
||||
tagged_parse_result parse_anywhere_and_extract(const std::string & input) const;
|
||||
};
|
||||
|
||||
tagged_peg_parser build_tagged_peg_parser(
|
||||
|
|
|
|||
|
|
@ -227,22 +227,18 @@ common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::strin
|
|||
}
|
||||
|
||||
bool common_chat_templates_support_enable_thinking(const common_chat_templates * chat_templates) {
|
||||
common_chat_templates_inputs dummy_inputs;
|
||||
common_chat_msg msg;
|
||||
msg.role = "user";
|
||||
msg.content = "test";
|
||||
dummy_inputs.messages = { msg };
|
||||
dummy_inputs.enable_thinking = false;
|
||||
const auto rendered_no_thinking = common_chat_templates_apply(chat_templates, dummy_inputs);
|
||||
dummy_inputs.enable_thinking = true;
|
||||
const auto rendered_with_thinking = common_chat_templates_apply(chat_templates, dummy_inputs);
|
||||
bool detect = rendered_no_thinking.prompt != rendered_with_thinking.prompt;
|
||||
const auto & tmpl = chat_templates->template_tool_use
|
||||
? *chat_templates->template_tool_use
|
||||
: *chat_templates->template_default;
|
||||
autoparser::analyze_template result(tmpl);
|
||||
detect |= result.reasoning.mode != autoparser::reasoning_mode::NONE;
|
||||
return detect;
|
||||
common_chat_templates_inputs inputs;
|
||||
inputs.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
|
||||
common_chat_msg msg;
|
||||
msg.role = "user";
|
||||
msg.content = "test";
|
||||
inputs.messages = { msg };
|
||||
inputs.enable_thinking = true;
|
||||
inputs.add_generation_prompt = true;
|
||||
inputs.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
|
||||
|
||||
auto params = common_chat_templates_apply(chat_templates, inputs);
|
||||
return params.supports_thinking;
|
||||
}
|
||||
|
||||
std::vector<common_chat_msg> common_chat_msgs_parse_oaicompat(const json & messages) {
|
||||
|
|
@ -850,9 +846,10 @@ static common_chat_params common_chat_params_init_ministral_3(const common_chat_
|
|||
auto extract_reasoning = inputs.reasoning_format != COMMON_REASONING_FORMAT_NONE;
|
||||
auto include_grammar = true;
|
||||
|
||||
data.prompt = common_chat_template_direct_apply(tmpl, inputs, /* messages_override = */ adjusted_messages);
|
||||
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
|
||||
data.preserved_tokens = {
|
||||
data.supports_thinking = true;
|
||||
data.prompt = common_chat_template_direct_apply(tmpl, inputs, /* messages_override = */ adjusted_messages);
|
||||
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
|
||||
data.preserved_tokens = {
|
||||
"[THINK]",
|
||||
"[/THINK]",
|
||||
"[TOOL_CALLS]",
|
||||
|
|
@ -949,8 +946,9 @@ static common_chat_params common_chat_params_init_gpt_oss(const common_chat_temp
|
|||
}
|
||||
}
|
||||
|
||||
data.prompt = prompt;
|
||||
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
|
||||
data.prompt = prompt;
|
||||
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
|
||||
data.supports_thinking = true;
|
||||
|
||||
// These special tokens are required to parse properly, so we include them
|
||||
// even if parse_tool_calls is false.
|
||||
|
|
@ -1281,7 +1279,9 @@ static common_chat_params common_chat_templates_apply_jinja(const struct common_
|
|||
|
||||
try {
|
||||
LOG_DBG("Using differential autoparser\n");
|
||||
auto auto_params = autoparser::universal_peg_generator::generate_parser(tmpl, params);
|
||||
auto analysis = autoparser::analyze_template(tmpl);
|
||||
auto auto_params = autoparser::universal_peg_generator::generate_parser(tmpl, params, analysis);
|
||||
auto_params.supports_thinking = analysis.reasoning.mode != autoparser::reasoning_mode::NONE;
|
||||
return auto_params;
|
||||
} catch (const std::exception & e) {
|
||||
LOG_WRN("Automatic parser generation failed: %s\n", e.what());
|
||||
|
|
|
|||
|
|
@ -198,7 +198,7 @@ struct common_chat_templates_inputs {
|
|||
std::vector<common_chat_tool> tools;
|
||||
common_chat_tool_choice tool_choice = COMMON_CHAT_TOOL_CHOICE_AUTO;
|
||||
bool parallel_tool_calls = false;
|
||||
common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_NONE; // TODO: refactor this to "bool enable_thinking"
|
||||
common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_NONE; // TODO: refactor this to "bool enable_thinking"
|
||||
bool enable_thinking = true;
|
||||
std::chrono::system_clock::time_point now = std::chrono::system_clock::now();
|
||||
std::map<std::string, std::string> chat_template_kwargs;
|
||||
|
|
@ -212,6 +212,7 @@ struct common_chat_params {
|
|||
std::string grammar;
|
||||
bool grammar_lazy = false;
|
||||
bool thinking_forced_open = false;
|
||||
bool supports_thinking = false;
|
||||
std::vector<common_grammar_trigger> grammar_triggers;
|
||||
std::vector<std::string> preserved_tokens;
|
||||
std::vector<std::string> additional_stops;
|
||||
|
|
@ -222,7 +223,7 @@ struct common_chat_params {
|
|||
// should be derived from common_chat_params
|
||||
struct common_chat_parser_params {
|
||||
common_chat_format format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
|
||||
common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_NONE; // TODO: refactor this to "bool parse_reasoning"
|
||||
common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_NONE; // TODO: refactor this to "bool parse_reasoning"
|
||||
// Whether reasoning_content should be inlined in the content (e.g. for reasoning_format=deepseek in stream mode)
|
||||
bool reasoning_in_content = false;
|
||||
bool thinking_forced_open = false;
|
||||
|
|
|
|||
|
|
@ -1268,6 +1268,9 @@ class TextModel(ModelBase):
|
|||
if chkhsh == "d30d75d9059f1aa2c19359de71047b3ae408c70875e8a3ccf8c5fba56c9d8af4":
|
||||
# ref: https://huggingface.co/Qwen/Qwen3.5-9B-Instruct
|
||||
res = "qwen35"
|
||||
if chkhsh == "b4b8ca1f9769494fbd956ebc4c249de6131fb277a4a3345a7a92c7dd7a55808d":
|
||||
# ref: https://huggingface.co/jdopensource/JoyAI-LLM-Flash
|
||||
res = "joyai-llm-flash"
|
||||
|
||||
if res is None:
|
||||
logger.warning("\n")
|
||||
|
|
|
|||
|
|
@ -149,7 +149,8 @@ models = [
|
|||
{"name": "youtu", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tencent/Youtu-LLM-2B", },
|
||||
{"name": "solar-open", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/upstage/Solar-Open-100B", },
|
||||
{"name": "exaone-moe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LGAI-EXAONE/K-EXAONE-236B-A23B", },
|
||||
{"name": "qwen35", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/Qwen/Qwen3.5-9B-Instruct", }
|
||||
{"name": "qwen35", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/Qwen/Qwen3.5-9B-Instruct", },
|
||||
{"name": "joyai-llm-flash", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/jdopensource/JoyAI-LLM-Flash", },
|
||||
]
|
||||
|
||||
# some models are known to be broken upstream, so we will skip them as exceptions
|
||||
|
|
|
|||
|
|
@ -936,4 +936,24 @@ static void test_tagged_peg_parser(testing & t) {
|
|||
t.assert_equal("prefix tag", "key", result.tags.at("prefix"));
|
||||
t.assert_equal("value tag", "val", result.tags.at("value"));
|
||||
});
|
||||
|
||||
t.test("find in the middle", [&](testing & t) {
|
||||
auto parser = build_tagged_peg_parser([](common_peg_parser_builder & p) {
|
||||
return p.choice({ p.literal("{"), p.literal(":") }) + p.space() + p.literal("\"") + p.atomic(p.literal("fun_name"));
|
||||
});
|
||||
|
||||
std::string tpl = "This is a very long jinja template string. We have tools. We will try to call them now: <tool_call>{ \"fun_name\" : { \"arg\" : 1 }</tool_call>";
|
||||
auto result = parser.parse_anywhere_and_extract(tpl);
|
||||
t.assert_true("success", result.result.success());
|
||||
});
|
||||
|
||||
t.test("fail find in the middle", [&](testing & t) {
|
||||
auto parser = build_tagged_peg_parser([](common_peg_parser_builder & p) {
|
||||
return p.choice({ p.literal("{"), p.literal(":") }) + p.space() + p.literal("\"") + p.atomic(p.literal("fun_name"));
|
||||
});
|
||||
|
||||
std::string tpl = "This is a very long jinja template string. We have tools. We will try to call them now: <tool_call><fun=fun_name><arg name=arg>1</arg></tool_call>";
|
||||
auto result = parser.parse_anywhere_and_extract(tpl);
|
||||
t.assert_true("failure", result.result.fail());
|
||||
});
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
#include "chat.h"
|
||||
#include "common.h"
|
||||
#include "arg.h"
|
||||
#include "console.h"
|
||||
|
|
@ -188,7 +189,8 @@ struct cli_context {
|
|||
inputs.use_jinja = chat_params.use_jinja;
|
||||
inputs.parallel_tool_calls = false;
|
||||
inputs.add_generation_prompt = true;
|
||||
inputs.enable_thinking = chat_params.enable_thinking;
|
||||
inputs.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
|
||||
inputs.enable_thinking = common_chat_templates_support_enable_thinking(chat_params.tmpls.get());
|
||||
|
||||
// Apply chat template to the list of messages
|
||||
return common_chat_templates_apply(chat_params.tmpls.get(), inputs);
|
||||
|
|
|
|||
Loading…
Reference in New Issue