fix(grammar): enforce recursion limit to prevent stack overflow

This commit is contained in:
Lucas 'Paperboy' Rose-Winters 2026-01-25 21:59:06 +11:00
parent 080b161995
commit ce0a0eac74
4 changed files with 64 additions and 7 deletions

View File

@ -9,6 +9,7 @@
#include <cstdint>
#include <stdexcept>
#define MAX_RECURSION_DEPTH 100
#define MAX_REPETITION_THRESHOLD 2000
//
// helpers
@ -434,13 +435,17 @@ const char * llama_grammar_parser::parse_alternates(
const char * src,
const std::string & rule_name,
uint32_t rule_id,
bool is_nested) {
bool is_nested,
int depth) {
if (depth > MAX_RECURSION_DEPTH) {
throw std::runtime_error("grammar recursion depth exceeded");
}
llama_grammar_rule rule;
const char * pos = parse_sequence(src, rule_name, rule, is_nested);
const char * pos = parse_sequence(src, rule_name, rule, is_nested, depth);
while (*pos == '|') {
rule.push_back({LLAMA_GRETYPE_ALT, 0});
pos = parse_space(pos + 1, true);
pos = parse_sequence(pos, rule_name, rule, is_nested);
pos = parse_sequence(pos, rule_name, rule, is_nested, depth);
}
rule.push_back({LLAMA_GRETYPE_END, 0});
add_rule(rule_id, rule);
@ -451,7 +456,8 @@ const char * llama_grammar_parser::parse_sequence(
const char * src,
const std::string & rule_name,
llama_grammar_rule & rule,
bool is_nested) {
bool is_nested,
int depth) {
size_t last_sym_start = rule.size();
const char * pos = src;
@ -573,7 +579,7 @@ const char * llama_grammar_parser::parse_sequence(
// parse nested alternates into synthesized rule
pos = parse_space(pos + 1, true);
uint32_t sub_rule_id = generate_symbol_id(rule_name);
pos = parse_alternates(pos, rule_name, sub_rule_id, true);
pos = parse_alternates(pos, rule_name, sub_rule_id, true, depth + 1);
last_sym_start = rule.size();
// output reference to synthesized rule
rule.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id});

View File

@ -102,13 +102,15 @@ struct llama_grammar_parser {
const char * src,
const std::string & rule_name,
uint32_t rule_id,
bool is_nested);
bool is_nested,
int depth = 0);
const char * parse_sequence(
const char * src,
const std::string & rule_name,
llama_grammar_rule & rule,
bool is_nested);
bool is_nested,
int depth = 0);
const char * parse_rule(const char * src);

View File

@ -147,6 +147,7 @@ if (NOT WIN32 OR NOT BUILD_SHARED_LIBS)
llama_build_and_test(test-sampling.cpp)
llama_build_and_test(test-grammar-parser.cpp)
llama_build_and_test(test-grammar-integration.cpp)
llama_build_and_test(test-grammar-dos.cpp)
llama_build_and_test(test-llama-grammar.cpp)
llama_build_and_test(test-chat.cpp)
# TODO: disabled on loongarch64 because the ggml-ci node lacks Python 3.8
@ -188,6 +189,7 @@ llama_build_and_test(test-chat-peg-parser.cpp peg-parser/simple-tokenize.cpp)
llama_build_and_test(test-chat-template.cpp)
llama_build_and_test(test-jinja.cpp)
llama_test(test-jinja NAME test-jinja-py ARGS -py LABEL python)
llama_build_and_test(test-jinja-dos.cpp)
llama_build_and_test(test-json-partial.cpp)
llama_build_and_test(test-log.cpp)
llama_build_and_test(

View File

@ -0,0 +1,47 @@
#include "../src/llama-grammar.h"
#include "../src/llama-vocab.h"
#include <iostream>
#include <string>
#include <vector>
// Regression test for the grammar parser's recursion limit.
//
// Builds `root ::= (((... "a" ...)))` with 100000 nested groups. Each "(" makes
// parse_sequence recurse into parse_alternates, so without a depth limit this input
// would exhaust the stack.
//
// Returns true when the grammar was rejected: either parse() returned false, or an
// exception whose message contains "grammar recursion depth exceeded" reached this
// function. Returns false when the parser accepted the grammar. Any other exception
// is reported and rethrown.
static bool test_grammar_recursion_dos() {
    const std::size_t depth = 100000;

    std::string grammar = "root ::= ";
    grammar.append(depth, '(');
    grammar += " \"a\" ";
    grammar.append(depth, ')');

    try {
        std::cout << "Attempting to parse deeply nested grammar (depth=" << depth << ")..." << std::endl;
        // The grammar uses string literals only (no <token> references), so a
        // default-constructed parser without a vocab is used here.
        llama_grammar_parser parser;
        if (!parser.parse(grammar.c_str())) {
            std::cout << "Parser returned false (failed gracefully)" << std::endl;
            return true;
        }
        // Accepting this input means the recursion limit was not enforced; report it
        // as a failure instead of letting the test pass silently.
        std::cerr << "Parser succeeded (unexpected)" << std::endl;
        return false;
    } catch (const std::exception & e) {
        const std::string what = e.what();
        if (what.find("grammar recursion depth exceeded") != std::string::npos) {
            std::cout << "Caught expected exception: " << what << std::endl;
            return true;
        }
        std::cout << "Caught unexpected exception: " << what << std::endl;
        throw;
    }
}

// Exits with 0 when the over-deep grammar was rejected and 1 when it was accepted,
// so the test runner sees a failure if the recursion limit stops working.
int main() {
    return test_grammar_recursion_dos() ? 0 : 1;
}