From 23c9182ce872666aff6bd0fc571c3e3b6ae5fd89 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sigbj=C3=B8rn=20Skj=C3=A6ret?= Date: Sun, 22 Mar 2026 17:45:10 +0100 Subject: [PATCH] jinja : refactor token advancement (#20864) * refactor token advancement * exercise sub-expressions --- common/jinja/parser.cpp | 62 +++++++++++++++++++++++------------------ tests/test-jinja.cpp | 6 ++++ 2 files changed, 41 insertions(+), 27 deletions(-) diff --git a/common/jinja/parser.cpp b/common/jinja/parser.cpp index 7970336ac0..4ae4477445 100644 --- a/common/jinja/parser.cpp +++ b/common/jinja/parser.cpp @@ -53,6 +53,13 @@ private: return tokens[current + offset]; } + const token & next() { + if (current >= tokens.size()) { + throw parser_exception("Parser Error: Unexpected EOF", source, tokens.empty() ? 0 : tokens.back().pos); + } + return tokens[current++]; + } + token expect(token::type type, const std::string& error) { const auto & t = peek(); if (t.t != type) { @@ -90,9 +97,9 @@ private: size_t start_pos = current; switch (peek().t) { case token::comment: - return mk_stmt(start_pos, tokens[current++].value); + return mk_stmt(start_pos, next().value); case token::text: - return mk_stmt(start_pos, tokens[current++].value); + return mk_stmt(start_pos, next().value); case token::open_statement: return parse_jinja_statement(); case token::open_expression: @@ -119,8 +126,7 @@ private: } size_t start_pos = current; - std::string name = peek().value; - current++; // consume identifier + std::string name = next().value; statement_ptr result; if (name == "set") { @@ -202,7 +208,7 @@ private: // Ignore generation blocks (transformers-specific) // See https://github.com/huggingface/transformers/pull/30650 for more information. result = mk_stmt(start_pos); - current++; + ++current; } else { throw std::runtime_error("Unknown statement: " + name); @@ -217,7 +223,7 @@ private: statements body; if (is(token::equals)) { - current++; + ++current; value = parse_expression_sequence(); } else { // parsing multiline set here @@ -280,7 +286,7 @@ private: exprs.push_back(primary ? parse_primary_expression() : parse_expression()); bool is_tuple = is(token::comma); while (is(token::comma)) { - current++; // consume comma + ++current; // consume comma exprs.push_back(primary ? parse_primary_expression() : parse_expression()); } return is_tuple ? mk_stmt(start_pos, std::move(exprs)) : std::move(exprs[0]); @@ -290,7 +296,7 @@ private: // e.g., `message` in `for message in messages` auto loop_var = parse_expression_sequence(true); // should be an identifier/tuple if (!is_identifier("in")) throw std::runtime_error("Expected 'in'"); - current++; + ++current; // consume 'in' // `messages` in `for message in messages` auto iterable = parse_expression(); @@ -305,7 +311,8 @@ private: } if (is_statement({"else"})) { - current += 2; + ++current; // consume {% + ++current; // consume 'else' expect(token::close_statement, "Expected %}"); while (!is_statement({"endfor"})) { alternate.push_back(parse_any()); @@ -347,7 +354,7 @@ private: auto left = parse_logical_and_expression(); while (is_identifier("or")) { size_t start_pos = current; - token op = tokens[current++]; + token op = next(); left = mk_stmt(start_pos, op, std::move(left), parse_logical_and_expression()); } return left; @@ -357,7 +364,7 @@ private: auto left = parse_logical_negation_expression(); while (is_identifier("and")) { size_t start_pos = current; - auto op = tokens[current++]; + auto op = next(); left = mk_stmt(start_pos, op, std::move(left), parse_logical_negation_expression()); } return left; @@ -367,7 +374,7 @@ private: // Try parse unary operators if (is_identifier("not")) { size_t start_pos = current; - auto op = tokens[current++]; + auto op = next(); return mk_stmt(start_pos, op, parse_logical_negation_expression()); } return parse_comparison_expression(); @@ -382,11 +389,12 @@ private: size_t start_pos = current; if (is_identifier("not") && peek(1).t == token::identifier && peek(1).value == "in") { op = {token::identifier, "not in", tokens[current].pos}; - current += 2; + ++current; // consume 'not' + ++current; // consume 'in' } else if (is_identifier("in")) { - op = tokens[current++]; + op = next(); } else if (is(token::comparison_binary_operator)) { - op = tokens[current++]; + op = next(); } else break; left = mk_stmt(start_pos, op, std::move(left), parse_additive_expression()); } @@ -397,7 +405,7 @@ private: auto left = parse_multiplicative_expression(); while (is(token::additive_binary_operator)) { size_t start_pos = current; - auto op = tokens[current++]; + auto op = next(); left = mk_stmt(start_pos, op, std::move(left), parse_multiplicative_expression()); } return left; @@ -407,7 +415,7 @@ private: auto left = parse_test_expression(); while (is(token::multiplicative_binary_operator)) { size_t start_pos = current; - auto op = tokens[current++]; + auto op = next(); left = mk_stmt(start_pos, op, std::move(left), parse_test_expression()); } return left; @@ -417,9 +425,9 @@ private: auto operand = parse_filter_expression(); while (is_identifier("is")) { size_t start_pos = current; - current++; + ++current; // consume 'is' bool negate = false; - if (is_identifier("not")) { current++; negate = true; } + if (is_identifier("not")) { ++current; negate = true; } auto test_id = parse_primary_expression(); // FIXME: tests can also be expressed like this: if x is eq 3 if (is(token::open_paren)) test_id = parse_call_expression(std::move(test_id)); @@ -432,7 +440,7 @@ private: auto operand = parse_call_member_expression(); while (is(token::pipe)) { size_t start_pos = current; - current++; + ++current; // consume pipe auto filter = parse_primary_expression(); if (is(token::open_paren)) filter = parse_call_expression(std::move(filter)); operand = mk_stmt(start_pos, std::move(operand), std::move(filter)); @@ -490,7 +498,7 @@ private: statement_ptr parse_member_expression(statement_ptr object) { size_t start_pos = current; while (is(token::dot) || is(token::open_square_bracket)) { - auto op = tokens[current++]; + auto op = next(); bool computed = op.t == token::open_square_bracket; statement_ptr prop; if (computed) { @@ -536,7 +544,7 @@ private: statement_ptr parse_primary_expression() { size_t start_pos = current; - auto t = tokens[current++]; + auto t = next(); switch (t.t) { case token::numeric_literal: if (t.value.find('.') != std::string::npos) { @@ -547,7 +555,7 @@ private: case token::string_literal: { std::string val = t.value; while (is(token::string_literal)) { - val += tokens[current++].value; + val += next().value; } return mk_stmt(start_pos, val); } @@ -562,9 +570,9 @@ private: statements vals; while (!is(token::close_square_bracket)) { vals.push_back(parse_expression()); - if (is(token::comma)) current++; + if (is(token::comma)) ++current; } - current++; + ++current; return mk_stmt(start_pos, std::move(vals)); } case token::open_curly_bracket: { @@ -573,9 +581,9 @@ private: auto key = parse_expression(); expect(token::colon, "Expected :"); pairs.push_back({std::move(key), parse_expression()}); - if (is(token::comma)) current++; + if (is(token::comma)) ++current; } - current++; + ++current; return mk_stmt(start_pos, std::move(pairs)); } default: diff --git a/tests/test-jinja.cpp b/tests/test-jinja.cpp index ef9c8f73c8..1550627bf0 100644 --- a/tests/test-jinja.cpp +++ b/tests/test-jinja.cpp @@ -2264,6 +2264,7 @@ static void test_fuzzing(testing & t) { t.test("malformed templates (should error, not crash)", [&](testing & t) { const std::vector malformed = { + "", "{{ x", "{% if %}", "{% for %}", @@ -2284,6 +2285,11 @@ static void test_fuzzing(testing & t) { for (const auto & tmpl : malformed) { t.assert_true("malformed: " + tmpl, fuzz_test_template(tmpl, json::object())); } + std::string tmpl = "{% for message in messages %}{{ message.role | string }} : {{ message.content if ('content' in message and message.content is not none) }}{% endfor %"; + while (tmpl.length() > 0) { + t.assert_true("malformed: " + tmpl, fuzz_test_template(tmpl, json::object())); + tmpl.pop_back(); + } }); t.test("type coercion edge cases", [&](testing & t) {