From 026730e8e3b029c45748421e5ae06c06f42e2321 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Mon, 29 Dec 2025 12:53:31 +0100 Subject: [PATCH] more fix, more tests --- common/jinja/jinja-lexer.cpp | 74 ++++++++++++++++++++++++++++++----- common/jinja/jinja-parser.cpp | 27 ++++++++++--- common/jinja/jinja-parser.h | 2 + common/jinja/jinja-value.cpp | 18 ++++----- common/jinja/jinja-vm.h | 2 +- tests/test-chat-jinja.cpp | 10 +++-- 6 files changed, 106 insertions(+), 27 deletions(-) diff --git a/common/jinja/jinja-lexer.cpp b/common/jinja/jinja-lexer.cpp index 541452f3fe..285ccc0151 100644 --- a/common/jinja/jinja-lexer.cpp +++ b/common/jinja/jinja-lexer.cpp @@ -1,4 +1,5 @@ #include "jinja-lexer.h" +#include "jinja-vm.h" #include #include @@ -7,13 +8,73 @@ #include #include #include +#include - -// #define JJ_DEBUG(msg, ...) printf("jinja-lexer: " msg "\n", __VA_ARGS__) -#define JJ_DEBUG(msg, ...) // no-op +#define FILENAME "jinja-lexer" namespace jinja { +// Trim template markers with '-' for whitespace control +// Example: [spaces]{%- ... -%} --> {% ... %} +#include +#include + +static void trim_template_markers_inplace(std::string & s) { + // i = head ; j = tail (i <= j) + size_t j = 0; // Write pointer + const size_t len = s.length(); + + for (size_t i = 0; i < len; ) { + bool handled = false; + + // We need at least 3 characters for any marker: {X- or -X} + if (i + 2 < len) { + const char c1 = s[i]; + const char c2 = s[i + 1]; + const char c3 = s[i + 2]; + + // 1. Closing trim: -X} where X = %, }, # + // Example: [content]-%} [spaces] -> [content]%} + if (c1 == '-' && c3 == '}' && (c2 == '%' || c2 == '}' || c2 == '#')) { + s[j++] = c2; + s[j++] = '}'; + i += 3; + // Strip leading whitespace AFTER the tag + while (i < len && std::isspace(static_cast(s[i]))) { + i++; + } + handled = true; + } + // 2. Opening trim: {X- where X = %, {, # + // Example: [spaces]{%- [content] -> {% [content] + else if (c1 == '{' && c3 == '-' && (c2 == '%' || c2 == '{' || c2 == '#')) { + // Trim trailing whitespace BEFORE the tag by moving write pointer back + while (j > 0 && std::isspace(static_cast(s[j - 1]))) { + j--; + } + + // Safety: Prevent merging '{' with tag start (avoid creating '{{%' or '{{{') + // if the character immediately before our new tag is a literal '{'. + if (j > 0 && s[j - 1] == '{') { + s[j++] = ' '; + } + + s[j++] = '{'; + s[j++] = c2; + i += 3; + handled = true; + } + } + + if (!handled) { + // Note: j is always <= i here, so this is safe. + s[j++] = s[i++]; + } + } + + s.resize(j); +} + std::string lexer::preprocess(const std::string & template_str, const preprocess_options & options) const { std::string result = template_str; // According to https://jinja.palletsprojects.com/en/3.0.x/templates/#whitespace-control @@ -40,12 +101,7 @@ std::string lexer::preprocess(const std::string & template_str, const preprocess } // Handle whitespace control with - in tags - result = std::regex_replace(result, std::regex(R"(-%\}\s*)"), "%}"); - result = std::regex_replace(result, std::regex(R"(\s*\{%-)"), "{%"); - result = std::regex_replace(result, std::regex(R"(-\}\}\s*)"), "}}"); - result = std::regex_replace(result, std::regex(R"(\s*\{\{-)"), "{{"); - result = std::regex_replace(result, std::regex(R"(-#\}\s*)"), "#}"); - result = std::regex_replace(result, std::regex(R"(\s*\{\#-)"), "{#"); + trim_template_markers_inplace(result); // Handle custom transformers-specific `generation` tag // See https://github.com/huggingface/transformers/pull/30650 for more information. diff --git a/common/jinja/jinja-parser.cpp b/common/jinja/jinja-parser.cpp index 5f42b0bd89..8cbb41eca6 100644 --- a/common/jinja/jinja-parser.cpp +++ b/common/jinja/jinja-parser.cpp @@ -26,8 +26,10 @@ class parser { // for debugging; a token can be multiple chars in source std::vector tok_pos_to_src_pos; + std::string source; // for error reporting + public: - parser(const std::vector & t) : tokens(t) { + parser(const std::vector & t, const std::string & src) : tokens(t), source(src) { tok_pos_to_src_pos.resize(tokens.size()); for (size_t i = 0; i < tokens.size(); i++) { tok_pos_to_src_pos[i] = tokens[i].pos; @@ -46,7 +48,16 @@ public: std::unique_ptr mk_stmt(Args&&... args) { auto ptr = std::make_unique(std::forward(args)...); ptr->pos = tok_pos_to_src_pos[prev_cur]; - JJ_DEBUG("Created %s statement at src pos %zu", ptr->type().c_str(), ptr->pos); + + std::string snippet = "no source"; + if (!source.empty()) { + size_t start_pos = ptr->pos; + size_t end_pos = start_pos + 20; + if (end_pos > source.size()) end_pos = source.size(); + snippet = source.substr(start_pos, end_pos - start_pos); + } + JJ_DEBUG("Created %-20s statement at src pos %-4zu (%s)", ptr->type().c_str(), ptr->pos, snippet.c_str()); + return ptr; } @@ -544,7 +555,9 @@ private: return mk_stmt(std::stoll(t.value)); case token::string_literal: { std::string val = t.value; - while (is(token::string_literal)) val += tokens[current++].value; + while (is(token::string_literal)) { + val += tokens[current++].value; + } return mk_stmt(val); } case token::identifier: @@ -575,13 +588,17 @@ private: return mk_stmt(std::move(pairs)); } default: - throw std::runtime_error("Unexpected token: " + t.value); + throw std::runtime_error("Unexpected token: " + t.value + " of type " + std::to_string(t.t)); } } }; program parse_from_tokens(const std::vector & tokens) { - return parser(tokens).parse(); + return parser(tokens, "").parse(); +} + +program parse_from_tokens(const lexer_result & lexer_res) { + return parser(lexer_res.tokens, lexer_res.preprocessed_source).parse(); } } // namespace jinja diff --git a/common/jinja/jinja-parser.h b/common/jinja/jinja-parser.h index ea212ad181..14ce135432 100644 --- a/common/jinja/jinja-parser.h +++ b/common/jinja/jinja-parser.h @@ -13,4 +13,6 @@ namespace jinja { program parse_from_tokens(const std::vector & tokens); +program parse_from_tokens(const lexer_result & lexer_res); + } // namespace jinja diff --git a/common/jinja/jinja-value.cpp b/common/jinja/jinja-value.cpp index 70cca62cff..218d893e26 100644 --- a/common/jinja/jinja-value.cpp +++ b/common/jinja/jinja-value.cpp @@ -131,23 +131,23 @@ const func_builtins & global_builtins() { if (args.args.size() < 1 || args.args.size() > 3) { throw raised_exception("slice() takes between 1 and 3 arguments"); } - int64_t arg0 = is_val(args.args[0]) ? args.args[0]->as_int() : 0; - int64_t arg1 = is_val(args.args[1]) ? args.args[1]->as_int() : -1; - int64_t arg2 = is_val(args.args[2]) ? args.args[2]->as_int() : 1; + auto & arg0 = args.args[0]; + auto & arg1 = args.args[1]; + auto & arg2 = args.args[2]; int64_t start, stop, step; if (args.args.size() == 1) { start = 0; - stop = arg0; + stop = arg0->as_int(); step = 1; } else if (args.args.size() == 2) { - start = arg0; - stop = arg1; + start = arg0->as_int(); + stop = arg1->as_int(); step = 1; } else { - start = arg0; - stop = arg1; - step = arg2; + start = arg0->as_int(); + stop = arg1->as_int(); + step = arg2->as_int(); } auto out = mk_val(); diff --git a/common/jinja/jinja-vm.h b/common/jinja/jinja-vm.h index 045d45d980..02790945a9 100644 --- a/common/jinja/jinja-vm.h +++ b/common/jinja/jinja-vm.h @@ -10,7 +10,7 @@ #include #include -#define JJ_DEBUG(msg, ...) if (g_jinja_debug) printf("%s:%3d : " msg "\n", FILENAME, __LINE__, __VA_ARGS__) +#define JJ_DEBUG(msg, ...) if (g_jinja_debug) printf("%s:%-3d : " msg "\n", FILENAME, __LINE__, __VA_ARGS__) extern bool g_jinja_debug; diff --git a/tests/test-chat-jinja.cpp b/tests/test-chat-jinja.cpp index 61ce80d8ac..f16ebb9e07 100644 --- a/tests/test-chat-jinja.cpp +++ b/tests/test-chat-jinja.cpp @@ -28,6 +28,8 @@ int main(void) { std::vector failed_tests; + bool stop_on_first_failure = false; + auto is_ignored_file = [](const std::string & filename) -> bool { std::vector ignored_files = { "Apriel-", @@ -64,7 +66,9 @@ int main(void) { std::cout << "Exception: " << e.what() << "\n"; std::cout << "=== ERROR WITH TEMPLATE FILE: " << entry.path().string() << " ===\n"; failed_tests.push_back(entry.path().string()); - exit(1); + if (stop_on_first_failure) { + break; + } } } } @@ -85,7 +89,7 @@ void run(std::string contents) { jinja::lexer lexer; jinja::preprocess_options options; - options.trim_blocks = true; + options.trim_blocks = false; options.lstrip_blocks = false; auto lexer_res = lexer.tokenize(contents, options); for (const auto & tok : lexer_res.tokens) { @@ -93,7 +97,7 @@ void run(std::string contents) { } std::cout << "\n=== AST ===\n"; - jinja::program ast = jinja::parse_from_tokens(lexer_res.tokens); + jinja::program ast = jinja::parse_from_tokens(lexer_res); for (const auto & stmt : ast.body) { //std::cout << "stmt type: " << stmt->type() << "\n"; }