jinja : refactor token advancement (#20864)

* refactor token advancement

* exercise sub-expressions
This commit is contained in:
Sigbjørn Skjæret 2026-03-22 17:45:10 +01:00 committed by GitHub
parent 81bc4d3ddc
commit 23c9182ce8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 41 additions and 27 deletions

View File

@ -53,6 +53,13 @@ private:
return tokens[current + offset];
}
const token & next() {
if (current >= tokens.size()) {
throw parser_exception("Parser Error: Unexpected EOF", source, tokens.empty() ? 0 : tokens.back().pos);
}
return tokens[current++];
}
token expect(token::type type, const std::string& error) {
const auto & t = peek();
if (t.t != type) {
@ -90,9 +97,9 @@ private:
size_t start_pos = current;
switch (peek().t) {
case token::comment:
return mk_stmt<comment_statement>(start_pos, tokens[current++].value);
return mk_stmt<comment_statement>(start_pos, next().value);
case token::text:
return mk_stmt<string_literal>(start_pos, tokens[current++].value);
return mk_stmt<string_literal>(start_pos, next().value);
case token::open_statement:
return parse_jinja_statement();
case token::open_expression:
@ -119,8 +126,7 @@ private:
}
size_t start_pos = current;
std::string name = peek().value;
current++; // consume identifier
std::string name = next().value;
statement_ptr result;
if (name == "set") {
@ -202,7 +208,7 @@ private:
// Ignore generation blocks (transformers-specific)
// See https://github.com/huggingface/transformers/pull/30650 for more information.
result = mk_stmt<noop_statement>(start_pos);
current++;
++current;
} else {
throw std::runtime_error("Unknown statement: " + name);
@ -217,7 +223,7 @@ private:
statements body;
if (is(token::equals)) {
current++;
++current;
value = parse_expression_sequence();
} else {
// parsing multiline set here
@ -280,7 +286,7 @@ private:
exprs.push_back(primary ? parse_primary_expression() : parse_expression());
bool is_tuple = is(token::comma);
while (is(token::comma)) {
current++; // consume comma
++current; // consume comma
exprs.push_back(primary ? parse_primary_expression() : parse_expression());
}
return is_tuple ? mk_stmt<tuple_literal>(start_pos, std::move(exprs)) : std::move(exprs[0]);
@ -290,7 +296,7 @@ private:
// e.g., `message` in `for message in messages`
auto loop_var = parse_expression_sequence(true); // should be an identifier/tuple
if (!is_identifier("in")) throw std::runtime_error("Expected 'in'");
current++;
++current; // consume 'in'
// `messages` in `for message in messages`
auto iterable = parse_expression();
@ -305,7 +311,8 @@ private:
}
if (is_statement({"else"})) {
current += 2;
++current; // consume {%
++current; // consume 'else'
expect(token::close_statement, "Expected %}");
while (!is_statement({"endfor"})) {
alternate.push_back(parse_any());
@ -347,7 +354,7 @@ private:
auto left = parse_logical_and_expression();
while (is_identifier("or")) {
size_t start_pos = current;
token op = tokens[current++];
token op = next();
left = mk_stmt<binary_expression>(start_pos, op, std::move(left), parse_logical_and_expression());
}
return left;
@ -357,7 +364,7 @@ private:
auto left = parse_logical_negation_expression();
while (is_identifier("and")) {
size_t start_pos = current;
auto op = tokens[current++];
auto op = next();
left = mk_stmt<binary_expression>(start_pos, op, std::move(left), parse_logical_negation_expression());
}
return left;
@ -367,7 +374,7 @@ private:
// Try parse unary operators
if (is_identifier("not")) {
size_t start_pos = current;
auto op = tokens[current++];
auto op = next();
return mk_stmt<unary_expression>(start_pos, op, parse_logical_negation_expression());
}
return parse_comparison_expression();
@ -382,11 +389,12 @@ private:
size_t start_pos = current;
if (is_identifier("not") && peek(1).t == token::identifier && peek(1).value == "in") {
op = {token::identifier, "not in", tokens[current].pos};
current += 2;
++current; // consume 'not'
++current; // consume 'in'
} else if (is_identifier("in")) {
op = tokens[current++];
op = next();
} else if (is(token::comparison_binary_operator)) {
op = tokens[current++];
op = next();
} else break;
left = mk_stmt<binary_expression>(start_pos, op, std::move(left), parse_additive_expression());
}
@ -397,7 +405,7 @@ private:
auto left = parse_multiplicative_expression();
while (is(token::additive_binary_operator)) {
size_t start_pos = current;
auto op = tokens[current++];
auto op = next();
left = mk_stmt<binary_expression>(start_pos, op, std::move(left), parse_multiplicative_expression());
}
return left;
@ -407,7 +415,7 @@ private:
auto left = parse_test_expression();
while (is(token::multiplicative_binary_operator)) {
size_t start_pos = current;
auto op = tokens[current++];
auto op = next();
left = mk_stmt<binary_expression>(start_pos, op, std::move(left), parse_test_expression());
}
return left;
@ -417,9 +425,9 @@ private:
auto operand = parse_filter_expression();
while (is_identifier("is")) {
size_t start_pos = current;
current++;
++current; // consume 'is'
bool negate = false;
if (is_identifier("not")) { current++; negate = true; }
if (is_identifier("not")) { ++current; negate = true; }
auto test_id = parse_primary_expression();
// FIXME: tests can also be expressed like this: if x is eq 3
if (is(token::open_paren)) test_id = parse_call_expression(std::move(test_id));
@ -432,7 +440,7 @@ private:
auto operand = parse_call_member_expression();
while (is(token::pipe)) {
size_t start_pos = current;
current++;
++current; // consume pipe
auto filter = parse_primary_expression();
if (is(token::open_paren)) filter = parse_call_expression(std::move(filter));
operand = mk_stmt<filter_expression>(start_pos, std::move(operand), std::move(filter));
@ -490,7 +498,7 @@ private:
statement_ptr parse_member_expression(statement_ptr object) {
size_t start_pos = current;
while (is(token::dot) || is(token::open_square_bracket)) {
auto op = tokens[current++];
auto op = next();
bool computed = op.t == token::open_square_bracket;
statement_ptr prop;
if (computed) {
@ -536,7 +544,7 @@ private:
statement_ptr parse_primary_expression() {
size_t start_pos = current;
auto t = tokens[current++];
auto t = next();
switch (t.t) {
case token::numeric_literal:
if (t.value.find('.') != std::string::npos) {
@ -547,7 +555,7 @@ private:
case token::string_literal: {
std::string val = t.value;
while (is(token::string_literal)) {
val += tokens[current++].value;
val += next().value;
}
return mk_stmt<string_literal>(start_pos, val);
}
@ -562,9 +570,9 @@ private:
statements vals;
while (!is(token::close_square_bracket)) {
vals.push_back(parse_expression());
if (is(token::comma)) current++;
if (is(token::comma)) ++current;
}
current++;
++current;
return mk_stmt<array_literal>(start_pos, std::move(vals));
}
case token::open_curly_bracket: {
@ -573,9 +581,9 @@ private:
auto key = parse_expression();
expect(token::colon, "Expected :");
pairs.push_back({std::move(key), parse_expression()});
if (is(token::comma)) current++;
if (is(token::comma)) ++current;
}
current++;
++current;
return mk_stmt<object_literal>(start_pos, std::move(pairs));
}
default:

View File

@ -2264,6 +2264,7 @@ static void test_fuzzing(testing & t) {
t.test("malformed templates (should error, not crash)", [&](testing & t) {
const std::vector<std::string> malformed = {
"",
"{{ x",
"{% if %}",
"{% for %}",
@ -2284,6 +2285,11 @@ static void test_fuzzing(testing & t) {
for (const auto & tmpl : malformed) {
t.assert_true("malformed: " + tmpl, fuzz_test_template(tmpl, json::object()));
}
std::string tmpl = "{% for message in messages %}{{ message.role | string }} : {{ message.content if ('content' in message and message.content is not none) }}{% endfor %";
while (tmpl.length() > 0) {
t.assert_true("malformed: " + tmpl, fuzz_test_template(tmpl, json::object()));
tmpl.pop_back();
}
});
t.test("type coercion edge cases", [&](testing & t) {