grammar: Fix grammar root symbol check (#19761)

* grammar: fix bad check for root symbol, correct error logging

* add tests to demonstrate root symbol check failure
This commit is contained in:
Asbjørn Olling 2026-03-12 12:04:56 +01:00 committed by GitHub
parent deee23863b
commit 0a10c34dc1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 42 additions and 6 deletions

View File

@ -1160,13 +1160,13 @@ struct llama_grammar * llama_grammar_init_impl(
// if there is a grammar, parse it
// rules will be empty (default) if there are parse errors
if (!parser.parse(grammar_str) || parser.rules.empty()) {
fprintf(stderr, "%s: failed to parse grammar\n", __func__);
LLAMA_LOG_ERROR("failed to parse grammar\n");
return nullptr;
}
// Ensure that there is a "root" node.
if (parser.symbol_ids.find("root") == parser.symbol_ids.end()) {
fprintf(stderr, "%s: grammar does not contain a 'root' symbol\n", __func__);
// Ensure that the grammar contains the start symbol
if (parser.symbol_ids.find(grammar_root) == parser.symbol_ids.end()) {
LLAMA_LOG_ERROR("grammar does not contain a '%s' symbol\n", grammar_root);
return nullptr;
}
@ -1195,7 +1195,7 @@ struct llama_grammar * llama_grammar_init_impl(
continue;
}
if (llama_grammar_detect_left_recursion(vec_rules, i, &rules_visited, &rules_in_progress, &rules_may_be_empty)) {
LLAMA_LOG_ERROR("unsupported grammar, left recursion detected for nonterminal at index %zu", i);
LLAMA_LOG_ERROR("unsupported grammar, left recursion detected for nonterminal at index %zu\n", i);
return nullptr;
}
}

View File

@ -15,8 +15,12 @@
using json = nlohmann::ordered_json;
static llama_grammar * build_grammar_with_root(const std::string & grammar_str, const char * grammar_root) {
return llama_grammar_init_impl(nullptr, grammar_str.c_str(), grammar_root, false, nullptr, 0, nullptr, 0);
}
static llama_grammar * build_grammar(const std::string & grammar_str) {
return llama_grammar_init_impl(nullptr, grammar_str.c_str(), "root", false, nullptr, 0, nullptr, 0);
return build_grammar_with_root(grammar_str, "root");
}
static bool test_build_grammar_fails(const std::string & grammar_str) {
@ -860,6 +864,36 @@ static void test_failure_left_recursion() {
fprintf(stderr, " ✅︎ Passed\n");
}
static void test_failure_missing_root_symbol() {
fprintf(stderr, "⚫ Testing missing root symbol:\n");
const std::string grammar_str = R"""(
root ::= "foobar"
)""";
llama_grammar * failure_result = build_grammar_with_root(grammar_str, "nonexistent");
assert(failure_result == nullptr);
fprintf(stderr, " ✅︎ Passed\n");
}
static void test_custom_root_symbol_check() {
fprintf(stderr, "⚫ Testing custom root symbol check:\n");
const std::string custom_root_grammar_str = R"""(
foobar ::= "foobar"
)""";
llama_grammar * failure_result = build_grammar_with_root(custom_root_grammar_str, "root");
assert(failure_result == nullptr);
llama_grammar * success_result = build_grammar_with_root(custom_root_grammar_str, "foobar");
assert(success_result != nullptr);
llama_grammar_free_impl(success_result);
fprintf(stderr, " ✅︎ Passed\n");
}
static void test_json_schema() {
// Note that this is similar to the regular grammar tests,
// but we convert each json schema to a grammar before parsing.
@ -1433,6 +1467,8 @@ int main() {
test_failure_missing_root();
test_failure_missing_reference();
test_failure_left_recursion();
test_failure_missing_root_symbol();
test_custom_root_symbol_check();
test_json_schema();
fprintf(stdout, "All tests passed.\n");
return 0;