jinja : refactor token advancement (#20864)

* refactor token advancement

* exercise sub-expressions
This commit is contained in:
Sigbjørn Skjæret 2026-03-22 17:45:10 +01:00 committed by GitHub
parent 81bc4d3ddc
commit 23c9182ce8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 41 additions and 27 deletions

View File

@ -53,6 +53,13 @@ private:
return tokens[current + offset]; return tokens[current + offset];
} }
// Consume the token at the cursor: returns a reference to tokens[current]
// and then moves the cursor forward by one.
//
// Throws parser_exception ("Parser Error: Unexpected EOF") when the cursor is
// already at or past the end of the token list. The position reported with
// the error is the `pos` of the last token, or 0 when there are no tokens.
const token & next() {
    if (current < tokens.size()) {
        const token & consumed = tokens[current];
        ++current; // advance only after the bounds check has passed
        return consumed;
    }
    // Nothing left to consume; the cursor is left untouched.
    throw parser_exception("Parser Error: Unexpected EOF", source, tokens.empty() ? 0 : tokens.back().pos);
}
token expect(token::type type, const std::string& error) { token expect(token::type type, const std::string& error) {
const auto & t = peek(); const auto & t = peek();
if (t.t != type) { if (t.t != type) {
@ -90,9 +97,9 @@ private:
size_t start_pos = current; size_t start_pos = current;
switch (peek().t) { switch (peek().t) {
case token::comment: case token::comment:
return mk_stmt<comment_statement>(start_pos, tokens[current++].value); return mk_stmt<comment_statement>(start_pos, next().value);
case token::text: case token::text:
return mk_stmt<string_literal>(start_pos, tokens[current++].value); return mk_stmt<string_literal>(start_pos, next().value);
case token::open_statement: case token::open_statement:
return parse_jinja_statement(); return parse_jinja_statement();
case token::open_expression: case token::open_expression:
@ -119,8 +126,7 @@ private:
} }
size_t start_pos = current; size_t start_pos = current;
std::string name = peek().value; std::string name = next().value;
current++; // consume identifier
statement_ptr result; statement_ptr result;
if (name == "set") { if (name == "set") {
@ -202,7 +208,7 @@ private:
// Ignore generation blocks (transformers-specific) // Ignore generation blocks (transformers-specific)
// See https://github.com/huggingface/transformers/pull/30650 for more information. // See https://github.com/huggingface/transformers/pull/30650 for more information.
result = mk_stmt<noop_statement>(start_pos); result = mk_stmt<noop_statement>(start_pos);
current++; ++current;
} else { } else {
throw std::runtime_error("Unknown statement: " + name); throw std::runtime_error("Unknown statement: " + name);
@ -217,7 +223,7 @@ private:
statements body; statements body;
if (is(token::equals)) { if (is(token::equals)) {
current++; ++current;
value = parse_expression_sequence(); value = parse_expression_sequence();
} else { } else {
// parsing multiline set here // parsing multiline set here
@ -280,7 +286,7 @@ private:
exprs.push_back(primary ? parse_primary_expression() : parse_expression()); exprs.push_back(primary ? parse_primary_expression() : parse_expression());
bool is_tuple = is(token::comma); bool is_tuple = is(token::comma);
while (is(token::comma)) { while (is(token::comma)) {
current++; // consume comma ++current; // consume comma
exprs.push_back(primary ? parse_primary_expression() : parse_expression()); exprs.push_back(primary ? parse_primary_expression() : parse_expression());
} }
return is_tuple ? mk_stmt<tuple_literal>(start_pos, std::move(exprs)) : std::move(exprs[0]); return is_tuple ? mk_stmt<tuple_literal>(start_pos, std::move(exprs)) : std::move(exprs[0]);
@ -290,7 +296,7 @@ private:
// e.g., `message` in `for message in messages` // e.g., `message` in `for message in messages`
auto loop_var = parse_expression_sequence(true); // should be an identifier/tuple auto loop_var = parse_expression_sequence(true); // should be an identifier/tuple
if (!is_identifier("in")) throw std::runtime_error("Expected 'in'"); if (!is_identifier("in")) throw std::runtime_error("Expected 'in'");
current++; ++current; // consume 'in'
// `messages` in `for message in messages` // `messages` in `for message in messages`
auto iterable = parse_expression(); auto iterable = parse_expression();
@ -305,7 +311,8 @@ private:
} }
if (is_statement({"else"})) { if (is_statement({"else"})) {
current += 2; ++current; // consume {%
++current; // consume 'else'
expect(token::close_statement, "Expected %}"); expect(token::close_statement, "Expected %}");
while (!is_statement({"endfor"})) { while (!is_statement({"endfor"})) {
alternate.push_back(parse_any()); alternate.push_back(parse_any());
@ -347,7 +354,7 @@ private:
auto left = parse_logical_and_expression(); auto left = parse_logical_and_expression();
while (is_identifier("or")) { while (is_identifier("or")) {
size_t start_pos = current; size_t start_pos = current;
token op = tokens[current++]; token op = next();
left = mk_stmt<binary_expression>(start_pos, op, std::move(left), parse_logical_and_expression()); left = mk_stmt<binary_expression>(start_pos, op, std::move(left), parse_logical_and_expression());
} }
return left; return left;
@ -357,7 +364,7 @@ private:
auto left = parse_logical_negation_expression(); auto left = parse_logical_negation_expression();
while (is_identifier("and")) { while (is_identifier("and")) {
size_t start_pos = current; size_t start_pos = current;
auto op = tokens[current++]; auto op = next();
left = mk_stmt<binary_expression>(start_pos, op, std::move(left), parse_logical_negation_expression()); left = mk_stmt<binary_expression>(start_pos, op, std::move(left), parse_logical_negation_expression());
} }
return left; return left;
@ -367,7 +374,7 @@ private:
// Try parse unary operators // Try parse unary operators
if (is_identifier("not")) { if (is_identifier("not")) {
size_t start_pos = current; size_t start_pos = current;
auto op = tokens[current++]; auto op = next();
return mk_stmt<unary_expression>(start_pos, op, parse_logical_negation_expression()); return mk_stmt<unary_expression>(start_pos, op, parse_logical_negation_expression());
} }
return parse_comparison_expression(); return parse_comparison_expression();
@ -382,11 +389,12 @@ private:
size_t start_pos = current; size_t start_pos = current;
if (is_identifier("not") && peek(1).t == token::identifier && peek(1).value == "in") { if (is_identifier("not") && peek(1).t == token::identifier && peek(1).value == "in") {
op = {token::identifier, "not in", tokens[current].pos}; op = {token::identifier, "not in", tokens[current].pos};
current += 2; ++current; // consume 'not'
++current; // consume 'in'
} else if (is_identifier("in")) { } else if (is_identifier("in")) {
op = tokens[current++]; op = next();
} else if (is(token::comparison_binary_operator)) { } else if (is(token::comparison_binary_operator)) {
op = tokens[current++]; op = next();
} else break; } else break;
left = mk_stmt<binary_expression>(start_pos, op, std::move(left), parse_additive_expression()); left = mk_stmt<binary_expression>(start_pos, op, std::move(left), parse_additive_expression());
} }
@ -397,7 +405,7 @@ private:
auto left = parse_multiplicative_expression(); auto left = parse_multiplicative_expression();
while (is(token::additive_binary_operator)) { while (is(token::additive_binary_operator)) {
size_t start_pos = current; size_t start_pos = current;
auto op = tokens[current++]; auto op = next();
left = mk_stmt<binary_expression>(start_pos, op, std::move(left), parse_multiplicative_expression()); left = mk_stmt<binary_expression>(start_pos, op, std::move(left), parse_multiplicative_expression());
} }
return left; return left;
@ -407,7 +415,7 @@ private:
auto left = parse_test_expression(); auto left = parse_test_expression();
while (is(token::multiplicative_binary_operator)) { while (is(token::multiplicative_binary_operator)) {
size_t start_pos = current; size_t start_pos = current;
auto op = tokens[current++]; auto op = next();
left = mk_stmt<binary_expression>(start_pos, op, std::move(left), parse_test_expression()); left = mk_stmt<binary_expression>(start_pos, op, std::move(left), parse_test_expression());
} }
return left; return left;
@ -417,9 +425,9 @@ private:
auto operand = parse_filter_expression(); auto operand = parse_filter_expression();
while (is_identifier("is")) { while (is_identifier("is")) {
size_t start_pos = current; size_t start_pos = current;
current++; ++current; // consume 'is'
bool negate = false; bool negate = false;
if (is_identifier("not")) { current++; negate = true; } if (is_identifier("not")) { ++current; negate = true; }
auto test_id = parse_primary_expression(); auto test_id = parse_primary_expression();
// FIXME: tests can also be expressed like this: if x is eq 3 // FIXME: tests can also be expressed like this: if x is eq 3
if (is(token::open_paren)) test_id = parse_call_expression(std::move(test_id)); if (is(token::open_paren)) test_id = parse_call_expression(std::move(test_id));
@ -432,7 +440,7 @@ private:
auto operand = parse_call_member_expression(); auto operand = parse_call_member_expression();
while (is(token::pipe)) { while (is(token::pipe)) {
size_t start_pos = current; size_t start_pos = current;
current++; ++current; // consume pipe
auto filter = parse_primary_expression(); auto filter = parse_primary_expression();
if (is(token::open_paren)) filter = parse_call_expression(std::move(filter)); if (is(token::open_paren)) filter = parse_call_expression(std::move(filter));
operand = mk_stmt<filter_expression>(start_pos, std::move(operand), std::move(filter)); operand = mk_stmt<filter_expression>(start_pos, std::move(operand), std::move(filter));
@ -490,7 +498,7 @@ private:
statement_ptr parse_member_expression(statement_ptr object) { statement_ptr parse_member_expression(statement_ptr object) {
size_t start_pos = current; size_t start_pos = current;
while (is(token::dot) || is(token::open_square_bracket)) { while (is(token::dot) || is(token::open_square_bracket)) {
auto op = tokens[current++]; auto op = next();
bool computed = op.t == token::open_square_bracket; bool computed = op.t == token::open_square_bracket;
statement_ptr prop; statement_ptr prop;
if (computed) { if (computed) {
@ -536,7 +544,7 @@ private:
statement_ptr parse_primary_expression() { statement_ptr parse_primary_expression() {
size_t start_pos = current; size_t start_pos = current;
auto t = tokens[current++]; auto t = next();
switch (t.t) { switch (t.t) {
case token::numeric_literal: case token::numeric_literal:
if (t.value.find('.') != std::string::npos) { if (t.value.find('.') != std::string::npos) {
@ -547,7 +555,7 @@ private:
case token::string_literal: { case token::string_literal: {
std::string val = t.value; std::string val = t.value;
while (is(token::string_literal)) { while (is(token::string_literal)) {
val += tokens[current++].value; val += next().value;
} }
return mk_stmt<string_literal>(start_pos, val); return mk_stmt<string_literal>(start_pos, val);
} }
@ -562,9 +570,9 @@ private:
statements vals; statements vals;
while (!is(token::close_square_bracket)) { while (!is(token::close_square_bracket)) {
vals.push_back(parse_expression()); vals.push_back(parse_expression());
if (is(token::comma)) current++; if (is(token::comma)) ++current;
} }
current++; ++current;
return mk_stmt<array_literal>(start_pos, std::move(vals)); return mk_stmt<array_literal>(start_pos, std::move(vals));
} }
case token::open_curly_bracket: { case token::open_curly_bracket: {
@ -573,9 +581,9 @@ private:
auto key = parse_expression(); auto key = parse_expression();
expect(token::colon, "Expected :"); expect(token::colon, "Expected :");
pairs.push_back({std::move(key), parse_expression()}); pairs.push_back({std::move(key), parse_expression()});
if (is(token::comma)) current++; if (is(token::comma)) ++current;
} }
current++; ++current;
return mk_stmt<object_literal>(start_pos, std::move(pairs)); return mk_stmt<object_literal>(start_pos, std::move(pairs));
} }
default: default:

View File

@ -2264,6 +2264,7 @@ static void test_fuzzing(testing & t) {
t.test("malformed templates (should error, not crash)", [&](testing & t) { t.test("malformed templates (should error, not crash)", [&](testing & t) {
const std::vector<std::string> malformed = { const std::vector<std::string> malformed = {
"",
"{{ x", "{{ x",
"{% if %}", "{% if %}",
"{% for %}", "{% for %}",
@ -2284,6 +2285,11 @@ static void test_fuzzing(testing & t) {
for (const auto & tmpl : malformed) { for (const auto & tmpl : malformed) {
t.assert_true("malformed: " + tmpl, fuzz_test_template(tmpl, json::object())); t.assert_true("malformed: " + tmpl, fuzz_test_template(tmpl, json::object()));
} }
std::string tmpl = "{% for message in messages %}{{ message.role | string }} : {{ message.content if ('content' in message and message.content is not none) }}{% endfor %";
while (tmpl.length() > 0) {
t.assert_true("malformed: " + tmpl, fuzz_test_template(tmpl, json::object()));
tmpl.pop_back();
}
}); });
t.test("type coercion edge cases", [&](testing & t) { t.test("type coercion edge cases", [&](testing & t) {