From 0a10c34dc146b19cc256a66b36321d753899a684 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Asbj=C3=B8rn=20Olling?= Date: Thu, 12 Mar 2026 12:04:56 +0100 Subject: [PATCH] grammar: Fix grammar root symbol check (#19761) * grammar: fix bad check for root symbol, correct error logging * add tests to demonstrate root symbol check failure --- src/llama-grammar.cpp | 10 ++++---- tests/test-grammar-integration.cpp | 38 +++++++++++++++++++++++++++++- 2 files changed, 42 insertions(+), 6 deletions(-) diff --git a/src/llama-grammar.cpp b/src/llama-grammar.cpp index 3b7a625234..aac0d41f2b 100644 --- a/src/llama-grammar.cpp +++ b/src/llama-grammar.cpp @@ -1160,13 +1160,13 @@ struct llama_grammar * llama_grammar_init_impl( // if there is a grammar, parse it // rules will be empty (default) if there are parse errors if (!parser.parse(grammar_str) || parser.rules.empty()) { - fprintf(stderr, "%s: failed to parse grammar\n", __func__); + LLAMA_LOG_ERROR("failed to parse grammar\n"); return nullptr; } - // Ensure that there is a "root" node. - if (parser.symbol_ids.find("root") == parser.symbol_ids.end()) { - fprintf(stderr, "%s: grammar does not contain a 'root' symbol\n", __func__); + // Ensure that the grammar contains the start symbol + if (parser.symbol_ids.find(grammar_root) == parser.symbol_ids.end()) { + LLAMA_LOG_ERROR("grammar does not contain a '%s' symbol\n", grammar_root); return nullptr; } @@ -1195,7 +1195,7 @@ struct llama_grammar * llama_grammar_init_impl( continue; } if (llama_grammar_detect_left_recursion(vec_rules, i, &rules_visited, &rules_in_progress, &rules_may_be_empty)) { - LLAMA_LOG_ERROR("unsupported grammar, left recursion detected for nonterminal at index %zu", i); + LLAMA_LOG_ERROR("unsupported grammar, left recursion detected for nonterminal at index %zu\n", i); return nullptr; } } diff --git a/tests/test-grammar-integration.cpp b/tests/test-grammar-integration.cpp index 7aa7e58a5c..526470a224 100644 --- a/tests/test-grammar-integration.cpp +++ b/tests/test-grammar-integration.cpp @@ -15,8 +15,12 @@ using json = nlohmann::ordered_json; +static llama_grammar * build_grammar_with_root(const std::string & grammar_str, const char * grammar_root) { + return llama_grammar_init_impl(nullptr, grammar_str.c_str(), grammar_root, false, nullptr, 0, nullptr, 0); +} + static llama_grammar * build_grammar(const std::string & grammar_str) { - return llama_grammar_init_impl(nullptr, grammar_str.c_str(), "root", false, nullptr, 0, nullptr, 0); + return build_grammar_with_root(grammar_str, "root"); } static bool test_build_grammar_fails(const std::string & grammar_str) { @@ -860,6 +864,36 @@ static void test_failure_left_recursion() { fprintf(stderr, " ✅︎ Passed\n"); } +static void test_failure_missing_root_symbol() { + fprintf(stderr, "⚫ Testing missing root symbol:\n"); + + const std::string grammar_str = R"""( + root ::= "foobar" + )"""; + + llama_grammar * failure_result = build_grammar_with_root(grammar_str, "nonexistent"); + assert(failure_result == nullptr); + + fprintf(stderr, " ✅︎ Passed\n"); +} + +static void test_custom_root_symbol_check() { + fprintf(stderr, "⚫ Testing custom root symbol check:\n"); + + const std::string custom_root_grammar_str = R"""( + foobar ::= "foobar" + )"""; + + llama_grammar * failure_result = build_grammar_with_root(custom_root_grammar_str, "root"); + assert(failure_result == nullptr); + + llama_grammar * success_result = build_grammar_with_root(custom_root_grammar_str, "foobar"); + assert(success_result != nullptr); + llama_grammar_free_impl(success_result); + + fprintf(stderr, " ✅︎ Passed\n"); +} + static void test_json_schema() { // Note that this is similar to the regular grammar tests, // but we convert each json schema to a grammar before parsing. @@ -1433,6 +1467,8 @@ int main() { test_failure_missing_root(); test_failure_missing_reference(); test_failure_left_recursion(); + test_failure_missing_root_symbol(); + test_custom_root_symbol_check(); test_json_schema(); fprintf(stdout, "All tests passed.\n"); return 0;