#include "common.h"
// NOTE(review): the original header carried two more #include lines whose targets
// are not visible in this copy of the patch; restore them before building.

namespace jinja {

// Experimental description of the Jinja template syntax, assembled with the
// project's PEG parser builder. The constructor registers every rule and leaves
// the finished top-level parser in `root`.
//
// NOTE(review): the builder/parser types and the operators used on them below
// (`<<`, `+`, `|`, `|=`, `+=`) are declared elsewhere; their exact semantics
// (sequence vs. choice, whitespace handling, value vs. reference capture) should
// be confirmed against those declarations.
struct compiler {
    // Declared before `root` on purpose: members are initialised in declaration
    // order and `root` is initialised from `builder.choice()`.
    common_chat_peg_native_builder builder;
    common_peg_parser root;

    compiler() : root(builder.choice()) {
        auto & p = builder; // short alias used for every rule below

        // Character-class helpers. The two trailing integers look like min/max
        // repetition counts with -1 meaning "no upper bound" — TODO confirm.
        auto ws = p.rule("ws", p.chars("[ \t]", 0, -1));
        auto num = p.rule("num", p.chars("[0-9]", 1, -1));

        //
        // expressions
        //

        // Starts as an empty choice; alternatives are appended with `|=` below.
        auto expression = p.choice();

        auto var_name = p.rule("var_name", p.chars("[a-zA-Z_]", 1, -1) << p.chars("[a-zA-Z0-9_]", 0, -1));
        expression |= var_name;

        // value
        auto p_int = p.rule("value_int", num);
        auto p_flt = p.rule("value_flt", num << "." << p.optional(num));
        // NOTE(review): the single-quoted branch passes "[^']*" (with a trailing
        // '*') in addition to the 0/-1 counts — verify this is what chars() expects.
        auto p_str = p.rule("value_str",
            p.json_string() |
            p.literal("'") + p.chars("[^']*", 0, -1) + p.literal("'")
        );

        // NOTE(review): value_int is appended before value_flt; if the choice is
        // ordered, "1.5" may be matched as value_int first — TODO confirm.
        expression |= p_int;
        expression |= p_flt;
        expression |= p_str;

        // function calls
        // `expression` is referenced here while alternatives are still being added
        // to it; whether later additions are visible through this reference depends
        // on the parser type — TODO confirm.
        auto p_args = p.rule("args", expression << ws << p.zero_or_more("," << ws << expression));
        auto p_func = p.rule("func", ws << var_name << ws << "(" << ws << p_args << ws << ")");
        expression |= p_func;

        // indexing
        auto p_idx = p.rule("idx", ws << "[" << ws << expression << ws << "]");
        expression |= p_idx;

        // set
        auto p_set = p.rule("set", "set " << ws << var_name << ws << "=" << expression);
        expression |= p_set;

        // if, else, endif
        auto p_if = p.rule("if", "if " << ws << expression << ws);
        auto p_else = p.rule("else", "else " << ws << expression << ws);
        auto p_endif = p.rule("endif", p.literal("endif"));

        expression |= p_if;
        expression |= p_else;
        expression |= p_endif;

        // Wrap the finished choice in p.space() on both sides.
        expression = p.space() + expression + p.space();

        //
        // root
        //

        // auto strip = p.rule("strip", "-" << expression << "-");
        auto print = p.rule("print", "{{" << (expression) << "}}"); // {{ ... }}
        auto ctrl = p.rule("ctrl", "{%" << (expression) << "%}");   // {% ... %}

        root |= print;
        root |= ctrl;
        // "text" is defined as the negation of what `root` holds at this point
        // (print | ctrl).
        root |= p.rule("text", p.negate(root));

        // A template is one or more of the above, followed by end of input.
        root = p.one_or_more(root);
        root += p.end();
    }
};

} // namespace jinja
// NOTE(review): the header names below were reconstructed from the identifiers
// this file uses; the original include targets are not visible in this copy.
#include <cctype>
#include <cstddef>
#include <initializer_list>
#include <map>
#include <regex>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

// Swap the two definitions to trace the lexer on stdout (needs <cstdio>):
// #define JJ_DEBUG(msg, ...)  printf("jinja-lexer: " msg "\n", __VA_ARGS__)
#ifndef JJ_DEBUG
#define JJ_DEBUG(msg, ...)  // no-op
#endif

namespace jinja {

// Whitespace-handling switches applied by lexer::preprocess().
struct preprocess_options {
    bool trim_blocks = false;   // remove the first newline after a tag
    bool lstrip_blocks = false; // not implemented: preprocess() throws when set
};

// One lexical unit: its category plus the text it carries.
struct token {
    enum type {
        undefined,
        text, // The text between Jinja statements or expressions

        numeric_literal, // e.g., 123, 1.0
        string_literal, // 'string'
        identifier, // Variables, functions, statements, booleans, etc.
        equals, // =
        open_paren, // (
        close_paren, // )
        open_statement, // {%
        close_statement, // %}
        open_expression, // {{
        close_expression, // }}
        open_square_bracket, // [
        close_square_bracket, // ]
        open_curly_bracket, // {
        close_curly_bracket, // }
        comma, // ,
        dot, // .
        colon, // :
        pipe, // |

        call_operator, // ()
        additive_binary_operator, // + - ~
        multiplicative_binary_operator, // * / %
        comparison_binary_operator, // < > <= >= == !=
        unary_operator, // ! -

        comment, // {# ... #}
    };
    type t;
    std::string value;
};

// Splits a Jinja template into a flat token stream.
struct lexer {
    // Character following a backslash -> character it stands for.
    const std::map<char, char> escape_chars = {
        {'n', '\n'},
        {'t', '\t'},
        {'r', '\r'},
        {'b', '\b'},
        {'f', '\f'},
        {'v', '\v'},
        {'\\', '\\'},
        {'\'', '\''},
        {'\"', '\"'},
    };

    // True for [A-Za-z0-9_].
    static bool is_word(char c) {
        return std::isalnum(static_cast<unsigned char>(c)) || c == '_';
    }

    // True for a decimal digit.
    static bool is_integer(char c) {
        return std::isdigit(static_cast<unsigned char>(c));
    }

    // Fixed spellings and their token types. The table is scanned front to back
    // and the first hit wins, so a longer spelling must precede any shorter one
    // that is its prefix ("{{" before "{", "<=" before "<", "==" before "=").
    const std::vector<std::pair<std::string, token::type>> ordered_mapping_table = {
        // Control sequences
        {"{%", token::open_statement},
        {"%}", token::close_statement},
        {"{{", token::open_expression},
        {"}}", token::close_expression},
        // Single character tokens
        {"(", token::open_paren},
        {")", token::close_paren},
        {"{", token::open_curly_bracket},
        {"}", token::close_curly_bracket},
        {"[", token::open_square_bracket},
        {"]", token::close_square_bracket},
        {",", token::comma},
        {".", token::dot},
        {":", token::colon},
        {"|", token::pipe},
        // Comparison operators
        {"<=", token::comparison_binary_operator},
        {">=", token::comparison_binary_operator},
        {"==", token::comparison_binary_operator},
        {"!=", token::comparison_binary_operator},
        {"<", token::comparison_binary_operator},
        {">", token::comparison_binary_operator},
        // Arithmetic operators
        {"+", token::additive_binary_operator},
        {"-", token::additive_binary_operator},
        {"~", token::additive_binary_operator},
        {"*", token::multiplicative_binary_operator},
        {"/", token::multiplicative_binary_operator},
        {"%", token::multiplicative_binary_operator},
        // Assignment operator
        {"=", token::equals},
    };

    // Applies Jinja whitespace control to the raw template text.
    // See https://jinja.palletsprojects.com/en/3.0.x/templates/#whitespace-control
    //
    // Steps, in order:
    //   1. strip a single trailing newline;
    //   2. with trim_blocks, drop the newline that directly follows a tag;
    //   3. resolve the `-` markers ({%- -%} {{- -}} {#- -#}) by removing the
    //      adjacent whitespace together with the marker.
    //
    // Throws std::runtime_error when options.lstrip_blocks is set (not implemented).
    std::string preprocess(const std::string & template_str, const preprocess_options & options) const {
        std::string result = template_str;

        // In the default configuration a single trailing newline is stripped if
        // present; other whitespace is returned unchanged.
        if (!result.empty() && result.back() == '\n') {
            result.pop_back();
        }

        if (options.lstrip_blocks) {
            // Would strip tabs and spaces from the beginning of a line up to the
            // start of a block:
            // result = std::regex_replace(result, std::regex(R"((?m)^[ \t]*(\{[#%-]))"), "$1");
            throw std::runtime_error("lstrip_blocks option is not implemented yet");
        }

        // The patterns never change, so each one is compiled once and reused.
        if (options.trim_blocks) {
            static const std::regex newline_after_tag(R"(([#%-]\})\n)");
            result = std::regex_replace(result, newline_after_tag, "$1");
        }

        struct rewrite {
            std::regex   pattern;
            const char * replacement;
        };
        static const rewrite whitespace_control[] = {
            { std::regex(R"(-%\}\s*)"),  "%}" },
            { std::regex(R"(\s*\{%-)"),  "{%" },
            { std::regex(R"(-\}\}\s*)"), "}}" },
            { std::regex(R"(\s*\{\{-)"), "{{" },
            { std::regex(R"(-#\}\s*)"),  "#}" },
            { std::regex(R"(\s*\{\#-)"), "{#" },
        };
        for (const rewrite & rule : whitespace_control) {
            result = std::regex_replace(result, rule.pattern, rule.replacement);
        }

        // Handle custom transformers-specific `generation` tag
        // See https://github.com/huggingface/transformers/pull/30650 for more information.
        // result = std::regex_replace(result, std::regex(R"((?s)\{%\s*generation\s*%\}.+?\{%\s*endgeneration\s*%\})"), "");

        return result;
    }

    // Tokenizes `input` after running it through preprocess().
    //
    // Text outside tags becomes `text` tokens, `{# ... #}` becomes a `comment`
    // token holding the comment body, and everything inside `{% %}` / `{{ }}` is
    // split into operators, literals and identifiers. Whitespace inside tags is
    // discarded. Backslash escapes listed in `escape_chars` are decoded.
    //
    // Throws std::runtime_error on an unterminated comment or string literal, an
    // unknown or truncated escape sequence, an unexpected character, or when
    // preprocess() throws.
    std::vector<token> tokenize(const std::string & input, const preprocess_options & options = {}) const {
        std::vector<token> tokens;
        const std::string src = preprocess(input, options);
        JJ_DEBUG("preprocessed input: '%s'", src.c_str());

        size_t pos = 0;
        // Number of currently open '{' inside an expression; while it is non-zero
        // a "}}" closes two object literals rather than the expression.
        size_t curly_bracket_depth = 0;

        // Collects characters while `predicate` accepts them, decoding backslash
        // escapes on the way. Stops at end of input without throwing; callers
        // that need a terminator must check for it themselves.
        auto consume_while = [&](auto predicate) -> std::string {
            std::string str;
            while (pos < src.size() && predicate(src[pos])) {
                if (src[pos] != '\\') {
                    str += src[pos++];
                    continue;
                }
                ++pos; // consume the backslash
                if (pos >= src.size()) {
                    throw std::runtime_error("lexer: unexpected end of input after escape character");
                }
                const char escaped_char = src[pos++];
                const auto found = escape_chars.find(escaped_char);
                if (found == escape_chars.end()) {
                    throw std::runtime_error(std::string("lexer: unknown escape character \\") + escaped_char);
                }
                str += found->second;
            }
            return str;
        };

        // True when the character right after `pos` is one of `chars`.
        auto next_pos_is = [&](std::initializer_list<char> chars) -> bool {
            if (pos + 1 >= src.size()) {
                return false;
            }
            const char next = src[pos + 1];
            for (const char c : chars) {
                if (c == next) {
                    return true;
                }
            }
            return false;
        };

        while (pos < src.size()) {
            JJ_DEBUG("lexer main loop at pos %zu: '%s...'", pos, src.substr(pos, 10).c_str());

            const token::type last_token_type = tokens.empty() ? token::undefined : tokens.back().t;

            // Outside of any tag: everything up to the next "{%", "{{" or "{#" is text.
            const bool outside_tag =
                last_token_type == token::undefined ||
                last_token_type == token::close_statement ||
                last_token_type == token::close_expression ||
                last_token_type == token::comment;
            if (outside_tag) {
                const size_t text_begin = pos;
                while (pos < src.size() && !(src[pos] == '{' && next_pos_is({'%', '{', '#'}))) {
                    ++pos;
                }
                if (pos > text_begin) {
                    tokens.push_back({token::text, src.substr(text_begin, pos - text_begin)});
                    JJ_DEBUG("consumed text: '%s'", tokens.back().value.c_str());
                    continue;
                }
            }

            // Comment: keep the body, drop the delimiters.
            if (src[pos] == '{' && next_pos_is({'#'})) {
                const size_t body_begin = pos + 2;
                const size_t body_end   = src.find("#}", body_begin);
                if (body_end == std::string::npos) {
                    throw std::runtime_error("lexer: missing end of comment tag");
                }
                tokens.push_back({token::comment, src.substr(body_begin, body_end - body_begin)});
                JJ_DEBUG("consumed comment: '%s'", tokens.back().value.c_str());
                pos = body_end + 2;
                continue;
            }

            // Whitespace inside statements and expressions carries no meaning.
            consume_while([](char c) { return std::isspace(static_cast<unsigned char>(c)) != 0; });
            if (pos >= src.size()) {
                break;
            }

            const char ch = src[pos];

            // '+' / '-' is binary after an operand and unary everywhere else.
            if (ch == '-' || ch == '+') {
                if (last_token_type == token::text || last_token_type == token::undefined) {
                    throw std::runtime_error(std::string("lexer: unexpected character: ") + ch);
                }
                switch (last_token_type) {
                    case token::identifier:
                    case token::numeric_literal:
                    case token::string_literal:
                    case token::close_paren:
                    case token::close_square_bracket:
                        // a - 1, 1 - 1, true - 1, "apple" - 1, (1) - 1, a[1] - 1
                        // -> binary operator, picked up by the mapping table below
                        break;
                    default: {
                        // (-1), [-1], (1 + -1), not -1, -apple
                        ++pos; // consume the sign
                        // Digits directly after the sign fold into a signed literal.
                        const std::string digits = consume_while(is_integer);
                        const token::type t = digits.empty() ? token::unary_operator : token::numeric_literal;
                        tokens.push_back({t, std::string(1, ch) + digits});
                        JJ_DEBUG("consumed unary operator or numeric literal: '%s'", tokens.back().value.c_str());
                        continue;
                    }
                }
            }

            // Fixed spellings from the mapping table, first match wins.
            bool matched = false;
            for (const auto & [seq, typ] : ordered_mapping_table) {
                // Inside an object literal, don't treat "}}" as expression-end
                if (seq == "}}" && curly_bracket_depth > 0) {
                    continue;
                }
                // compare() clamps to the end of src, so a truncated tail never matches.
                if (src.compare(pos, seq.size(), seq) != 0) {
                    continue;
                }
                tokens.push_back({typ, seq});
                if (typ == token::open_expression) {
                    curly_bracket_depth = 0;
                } else if (typ == token::open_curly_bracket) {
                    ++curly_bracket_depth;
                } else if (typ == token::close_curly_bracket && curly_bracket_depth > 0) {
                    // Guarded: an unmatched '}' must not wrap the unsigned counter,
                    // otherwise "}}" would stop being recognised as expression-end.
                    --curly_bracket_depth;
                }
                pos += seq.size();
                matched = true;
                break;
            }
            if (matched) {
                continue;
            }

            // Strings: single- or double-quoted, closed by the same quote.
            if (ch == '\'' || ch == '"') {
                ++pos; // skip opening quote
                std::string str = consume_while([ch](char c) { return c != ch; });
                if (pos >= src.size()) {
                    throw std::runtime_error("lexer: unterminated string literal");
                }
                ++pos; // skip closing quote
                tokens.push_back({token::string_literal, std::move(str)});
                continue;
            }

            // Numbers: digits, optionally followed by '.' and more digits.
            if (is_integer(ch)) {
                std::string num = consume_while(is_integer);
                const bool has_fraction =
                    pos + 1 < src.size() && src[pos] == '.' && is_integer(src[pos + 1]);
                if (has_fraction) {
                    ++pos; // consume '.'
                    num += '.';
                    num += consume_while(is_integer);
                }
                tokens.push_back({token::numeric_literal, std::move(num)});
                continue;
            }

            // Identifiers
            if (is_word(ch)) {
                tokens.push_back({token::identifier, consume_while(is_word)});
                continue;
            }

            throw std::runtime_error(std::string("lexer: unexpected character: ") + ch);
        }

        return tokens;
    }
};

} // namespace jinja
if add_generation_prompt %}{{'model\\n'}}{% endif %}"; + + std::cout << "=== INPUT ===\n" << contents << "\n\n"; + + jinja::lexer lexer; + jinja::preprocess_options options; + options.trim_blocks = true; + options.lstrip_blocks = false; + auto tokens = lexer.tokenize(contents, options); + for (const auto & tok : tokens) { + std::cout << "token: type=" << static_cast(tok.t) << " text='" << tok.value << "'\n"; + } + + // jinja::compiler compiler; + // compiler.builder.set_root(compiler.root); + // auto parser = compiler.builder.build(); + + // auto grammar = build_grammar([&](const common_grammar_builder & builder0) { + // parser.build_grammar(builder0); + // }); + // printf("== GRAMMAR ==\n"); + // printf("%s\n", grammar.c_str()); + + // // printf("== DUMP ==\n"); + // // printf("%s\n", parser.dump(compiler.root.id()).c_str()); + + // printf("== PARSE ==\n"); + + // common_peg_parse_context ctx(contents); + // const auto result = parser.parse(ctx); + // if (!result.success()) { + // throw std::runtime_error("failed to parse, type = " + std::to_string(result.type)); + // } + + // ctx.ast.visit(result, [&](const common_peg_ast_node & node) { + // printf("node: rule='%s' text='%s'\n", node.rule.c_str(), std::string(node.text).c_str()); + // }); + + return 0; +}