From ce0a0eac747e0cfd4c6b6fe4c04b92381c4a5770 Mon Sep 17 00:00:00 2001
From: Lucas 'Paperboy' Rose-Winters
Date: Sun, 25 Jan 2026 21:59:06 +1100
Subject: [PATCH] fix(grammar): enforce recursion limit to prevent stack overflow

---
Review notes (text in this section is not part of the commit message):
 * Dropped the tests/CMakeLists.txt line that registered test-jinja-dos.cpp;
   this patch does not add that file.
 * test-grammar-dos now exits non-zero when the parser accepts the deeply
   nested grammar or throws an unrelated exception.
 * NOTE(review): indentation of context and removed lines was re-created by
   hand; if a hunk is rejected, retry with `git apply --ignore-whitespace`.
 * NOTE(review): blob ids are omitted for the two files whose post-image
   changed during review.

 src/llama-grammar.cpp      | 16 +++++++++++-----
 src/llama-grammar.h        |  6 ++++--
 tests/CMakeLists.txt       |  1 +
 tests/test-grammar-dos.cpp | 43 +++++++++++++++++++++++++++++++++++++++++++
 4 files changed, 59 insertions(+), 7 deletions(-)
 create mode 100644 tests/test-grammar-dos.cpp

diff --git a/src/llama-grammar.cpp b/src/llama-grammar.cpp
index 64ea2fd00a..d6caa5b298 100644
--- a/src/llama-grammar.cpp
+++ b/src/llama-grammar.cpp
@@ -12,3 +12,4 @@
+#define MAX_RECURSION_DEPTH 100
 #define MAX_REPETITION_THRESHOLD 2000
 //
 // helpers
@@ -434,13 +435,17 @@ const char * llama_grammar_parser::parse_alternates(
         const char        * src,
         const std::string & rule_name,
         uint32_t            rule_id,
-        bool                is_nested) {
+        bool                is_nested,
+        int                 depth) {
+    if (depth > MAX_RECURSION_DEPTH) {
+        throw std::runtime_error("grammar recursion depth exceeded");
+    }
     llama_grammar_rule rule;
-    const char * pos = parse_sequence(src, rule_name, rule, is_nested);
+    const char * pos = parse_sequence(src, rule_name, rule, is_nested, depth);
     while (*pos == '|') {
         rule.push_back({LLAMA_GRETYPE_ALT, 0});
         pos = parse_space(pos + 1, true);
-        pos = parse_sequence(pos, rule_name, rule, is_nested);
+        pos = parse_sequence(pos, rule_name, rule, is_nested, depth);
     }
     rule.push_back({LLAMA_GRETYPE_END, 0});
     add_rule(rule_id, rule);
@@ -451,6 +456,7 @@ const char * llama_grammar_parser::parse_sequence(
         const char         * src,
         const std::string  & rule_name,
         llama_grammar_rule & rule,
-        bool                 is_nested) {
+        bool                 is_nested,
+        int                  depth) {
     size_t last_sym_start = rule.size();
     const char * pos = src;
@@ -573,7 +579,7 @@ const char * llama_grammar_parser::parse_sequence(
             // parse nested alternates into synthesized rule
             pos = parse_space(pos + 1, true);
             uint32_t sub_rule_id = generate_symbol_id(rule_name);
-            pos = parse_alternates(pos, rule_name, sub_rule_id, true);
+            pos = parse_alternates(pos, rule_name, sub_rule_id, true, depth + 1);
             last_sym_start = rule.size();
             // output reference to synthesized rule
             rule.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id});
diff --git a/src/llama-grammar.h b/src/llama-grammar.h
index b5a0e588e9..63e912f3e7 100644
--- a/src/llama-grammar.h
+++ b/src/llama-grammar.h
@@ -102,12 +102,14 @@ struct llama_grammar_parser {
             const char        * src,
             const std::string & rule_name,
             uint32_t            rule_id,
-            bool                is_nested);
+            bool                is_nested,
+            int                 depth = 0);
 
     const char * parse_sequence(
             const char         * src,
             const std::string  & rule_name,
             llama_grammar_rule & rule,
-            bool                 is_nested);
+            bool                 is_nested,
+            int                  depth = 0);
 
     const char * parse_rule(const char * src);
diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt
--- a/tests/CMakeLists.txt
+++ b/tests/CMakeLists.txt
@@ -147,6 +147,7 @@ if (NOT WIN32 OR NOT BUILD_SHARED_LIBS)
     llama_build_and_test(test-sampling.cpp)
     llama_build_and_test(test-grammar-parser.cpp)
     llama_build_and_test(test-grammar-integration.cpp)
+    llama_build_and_test(test-grammar-dos.cpp)
     llama_build_and_test(test-llama-grammar.cpp)
     llama_build_and_test(test-chat.cpp)
     # TODO: disabled on loongarch64 because the ggml-ci node lacks Python 3.8
diff --git a/tests/test-grammar-dos.cpp b/tests/test-grammar-dos.cpp
new file mode 100644
--- /dev/null
+++ b/tests/test-grammar-dos.cpp
@@ -0,0 +1,43 @@
+#include "../src/llama-grammar.h"
+#include "../src/llama-vocab.h"
+#include <exception>
+#include <iostream>
+#include <string>
+
+// Returns true when the parser rejects an excessively nested grammar instead of
+// recursing until the stack overflows.
+static bool test_grammar_recursion_dos() {
+    // Build root ::= ((( ... "a" ... ))). Each '(' makes parse_sequence call
+    // parse_alternates, which calls parse_sequence again, so nesting == recursion depth.
+    const size_t depth = 100000;
+    std::string grammar = "root ::= ";
+    grammar.append(depth, '(');
+    grammar += " \"a\" ";
+    grammar.append(depth, ')');
+
+    std::cout << "Attempting to parse deeply nested grammar (depth=" << depth << ")..." << std::endl;
+
+    try {
+        // Only string literals are used (no token references), so no vocab is set up here.
+        llama_grammar_parser parser;
+        if (!parser.parse(grammar.c_str())) {
+            std::cout << "Parser returned false (failed gracefully)" << std::endl;
+            return true;
+        }
+        std::cout << "Parser succeeded (unexpected)" << std::endl;
+        return false;
+    } catch (const std::exception & e) {
+        const std::string what = e.what();
+        if (what.find("grammar recursion depth exceeded") != std::string::npos) {
+            std::cout << "Caught expected exception: " << what << std::endl;
+            return true;
+        }
+        std::cout << "Caught unexpected exception: " << what << std::endl;
+        return false;
+    }
+}
+
+int main() {
+    // A non-zero exit status is what makes the test runner report a failure.
+    return test_grammar_recursion_dos() ? 0 : 1;
+}