From 8d8030142e57bec1d69dd7e128ba864528cef954 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Thu, 25 Dec 2025 00:19:23 +0100 Subject: [PATCH 01/47] jinja vm --- common/jinja/jinja-vm.cpp | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) create mode 100644 common/jinja/jinja-vm.cpp diff --git a/common/jinja/jinja-vm.cpp b/common/jinja/jinja-vm.cpp new file mode 100644 index 0000000000..7c8d0cf732 --- /dev/null +++ b/common/jinja/jinja-vm.cpp @@ -0,0 +1,28 @@ +#include +#include + +struct vm_context { + std::ostringstream out; +}; + +struct op_base { + virtual ~op_base() = default; + virtual void execute(vm_context & ctx) = 0; +}; + +struct op_print : public op_base { + std::string message; + op_print(const std::string & message) : message(message) {} + void execute(vm_context & ctx) override { + ctx.out << message; + } +}; + +struct op_load : public op_base { + std::string dst; + std::string src; + std::string value; + op_load(const std::string & dst) : dst(dst) {} + void execute(vm_context & ctx) override { + } +}; From 15b7c50e95f4824e30b3edf7a5689809a8c3fa3e Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Thu, 25 Dec 2025 21:08:51 +0100 Subject: [PATCH 02/47] lexer --- common/jinja/jinja-compiler.h | 79 ++++++++ common/jinja/jinja-lexer.h | 336 ++++++++++++++++++++++++++++++++++ tests/CMakeLists.txt | 1 + tests/test-chat-jinja.cpp | 55 ++++++ 4 files changed, 471 insertions(+) create mode 100644 common/jinja/jinja-compiler.h create mode 100644 common/jinja/jinja-lexer.h create mode 100644 tests/test-chat-jinja.cpp diff --git a/common/jinja/jinja-compiler.h b/common/jinja/jinja-compiler.h new file mode 100644 index 0000000000..32792d6050 --- /dev/null +++ b/common/jinja/jinja-compiler.h @@ -0,0 +1,79 @@ +#include "common.h" +#include +#include + +namespace jinja { + +struct compiler { + common_chat_peg_native_builder builder; + common_peg_parser root; + + compiler() : root(builder.choice()) { + auto & p = builder; + + auto ws = p.rule("ws", p.chars("[ \t]", 0, -1)); + auto num = p.rule("num", p.chars("[0-9]", 1, -1)); + + // + // expressions + // + + auto expression = p.choice(); + + auto var_name = p.rule("var_name", p.chars("[a-zA-Z_]", 1, -1) << p.chars("[a-zA-Z0-9_]", 0, -1)); + expression |= var_name; + + // value + auto p_int = p.rule("value_int", num); + auto p_flt = p.rule("value_flt", num << "." << p.optional(num)); + auto p_str = p.rule("value_str", + p.json_string() | + p.literal("'") + p.chars("[^']*", 0, -1) + p.literal("'") + ); + + expression |= p_int; + expression |= p_flt; + expression |= p_str; + + // function calls + auto p_args = p.rule("args", expression << ws << p.zero_or_more("," << ws << expression)); + auto p_func = p.rule("func", ws << var_name << ws << "(" << ws << p_args << ws << ")"); + expression |= p_func; + + // indexing + auto p_idx = p.rule("idx", ws << "[" << ws << expression << ws << "]"); + expression |= p_idx; + + // set + auto p_set = p.rule("set", "set " << ws << var_name << ws << "=" << expression); + expression |= p_set; + + // if, else, endif + auto p_if = p.rule("if", "if " << ws << expression << ws); + auto p_else = p.rule("else", "else " << ws << expression << ws); + auto p_endif = p.rule("endif", p.literal("endif")); + + expression |= p_if; + expression |= p_else; + expression |= p_endif; + + expression = p.space() + expression + p.space(); + + // + // root + // + + // auto strip = p.rule("strip", "-" << expression << "-"); + auto print = p.rule("print", "{{" << (expression) << "}}"); + auto ctrl = p.rule("ctrl", "{%" << (expression) << "%}"); + + root |= print; + root |= ctrl; + root |= p.rule("text", p.negate(root)); + + root = p.one_or_more(root); + root += p.end(); + } +}; + +} // namespace jinja diff --git a/common/jinja/jinja-lexer.h b/common/jinja/jinja-lexer.h new file mode 100644 index 0000000000..62850a33f6 --- /dev/null +++ b/common/jinja/jinja-lexer.h @@ -0,0 +1,336 @@ +#include +#include +#include +#include +#include +#include +#include + +// #define JJ_DEBUG(msg, ...) printf("jinja-lexer: " msg "\n", __VA_ARGS__) +#define JJ_DEBUG(msg, ...) // no-op + +namespace jinja { + +struct preprocess_options { + bool trim_blocks = false; + bool lstrip_blocks = false; +}; + +struct token { + enum type { + undefined, + text, // The text between Jinja statements or expressions + + numeric_literal, // e.g., 123, 1.0 + string_literal, // 'string' + identifier, // Variables, functions, statements, booleans, etc. + equals, // = + open_paren, // ( + close_paren, // ) + open_statement, // {% + close_statement, // %} + open_expression, // {{ + close_expression, // }} + open_square_bracket, // [ + close_square_bracket, // ] + open_curly_bracket, // { + close_curly_bracket, // } + comma, // , + dot, // . + colon, // : + pipe, // | + + call_operator, // () + additive_binary_operator, // + - ~ + multiplicative_binary_operator, // * / % + comparison_binary_operator, // < > <= >= == != + unary_operator, // ! - + + comment, // {# ... #} + }; + type t; + std::string value; +}; + +struct lexer { + const std::map escape_chars = { + {'n', '\n'}, + {'t', '\t'}, + {'r', '\r'}, + {'b', '\b'}, + {'f', '\f'}, + {'v', '\v'}, + {'\\', '\\'}, + {'\'', '\''}, + {'\"', '\"'}, + }; + + static bool is_word(char c) { + return std::isalnum(static_cast(c)) || c == '_'; + } + + static bool is_integer(char c) { + return std::isdigit(static_cast(c)); + } + + const std::vector> ordered_mapping_table = { + // Control sequences + {"{%", token::open_statement}, + {"%}", token::close_statement}, + {"{{", token::open_expression}, + {"}}", token::close_expression}, + // Single character tokens + {"(", token::open_paren}, + {")", token::close_paren}, + {"{", token::open_curly_bracket}, + {"}", token::close_curly_bracket}, + {"[", token::open_square_bracket}, + {"]", token::close_square_bracket}, + {",", token::comma}, + {".", token::dot}, + {":", token::colon}, + {"|", token::pipe}, + // Comparison operators + {"<=", token::comparison_binary_operator}, + {">=", token::comparison_binary_operator}, + {"==", token::comparison_binary_operator}, + {"!=", token::comparison_binary_operator}, + {"<", token::comparison_binary_operator}, + {">", token::comparison_binary_operator}, + // Arithmetic operators + {"+", token::additive_binary_operator}, + {"-", token::additive_binary_operator}, + {"~", token::additive_binary_operator}, + {"*", token::multiplicative_binary_operator}, + {"/", token::multiplicative_binary_operator}, + {"%", token::multiplicative_binary_operator}, + // Assignment operator + {"=", token::equals}, + }; + + std::string preprocess(const std::string& template_str, const preprocess_options& options) const { + std::string result = template_str; + // According to https://jinja.palletsprojects.com/en/3.0.x/templates/#whitespace-control + + // In the default configuration: + // - a single trailing newline is stripped if present + // - other whitespace (spaces, tabs, newlines etc.) is returned unchanged + if (!result.empty() && result.back() == '\n') { + result.pop_back(); + } + + if (options.lstrip_blocks) { + // The lstrip_blocks option can also be set to strip tabs and spaces from the + // beginning of a line to the start of a block. (Nothing will be stripped if + // there are other characters before the start of the block.) + // result = std::regex_replace(result, std::regex(R"((?m)^[ \t]*(\{[#%-]))"), "$1"); + throw std::runtime_error("lstrip_blocks option is not implemented yet"); + } + + if (options.trim_blocks) { + // If an application configures Jinja to trim_blocks, the first newline after + // a template tag is removed automatically (like in PHP). + result = std::regex_replace(result, std::regex(R"(([#%-]\})\n)"), "$1"); + } + + // Handle whitespace control with - in tags + result = std::regex_replace(result, std::regex(R"(-%\}\s*)"), "%}"); + result = std::regex_replace(result, std::regex(R"(\s*\{%-)"), "{%"); + result = std::regex_replace(result, std::regex(R"(-\}\}\s*)"), "}}"); + result = std::regex_replace(result, std::regex(R"(\s*\{\{-)"), "{{"); + result = std::regex_replace(result, std::regex(R"(-#\}\s*)"), "#}"); + result = std::regex_replace(result, std::regex(R"(\s*\{\#-)"), "{#"); + + // Handle custom transformers-specific `generation` tag + // See https://github.com/huggingface/transformers/pull/30650 for more information. + // result = std::regex_replace(result, std::regex(R"((?s)\{%\s*generation\s*%\}.+?\{%\s*endgeneration\s*%\})"), ""); + + return result; + } + + std::vector tokenize(const std::string & input, const preprocess_options & options = {}) { + std::vector tokens; + std::string src = preprocess(input, options); + JJ_DEBUG("preprocessed input: '%s'", src.c_str()); + + size_t pos = 0; + size_t curly_bracket_depth = 0; + + using pred = std::function; + auto consume_while = [&](pred predicate) -> std::string { + std::string str; + while (predicate(src[pos])) { + // check for escape char + if (src[pos] == '\\') { + // consume backslash + ++pos; + // check for end of input + if (pos >= src.size()) { + throw std::runtime_error("lexer: unexpected end of input after escape character"); + } + // add escaped char + char escaped_char = src[pos++]; + if (escape_chars.find(escaped_char) == escape_chars.end()) { + throw std::runtime_error(std::string("lexer: unknown escape character \\") + escaped_char); + } + char unescaped_char = escape_chars.at(escaped_char); + str += unescaped_char; + continue; + } + + str += src[pos++]; + if (pos > src.size()) { + throw std::runtime_error("lexer: unexpected end of input during consume_while"); + } + } + return str; + }; + + auto next_pos_is = [&](std::initializer_list chars) -> bool { + if (pos + 1 >= src.size()) return false; + for (char c : chars) { + if (src[pos + 1] == c) return true; + } + return false; + }; + + while (pos < src.size()) { + JJ_DEBUG("lexer main loop at pos %zu: '%s...'", pos, src.substr(pos, 10).c_str()); + + // First, consume all text that is outside of a Jinja statement or expression + token::type last_token_type = tokens.empty() + ? token::undefined + : tokens.back().t; + if (last_token_type == token::undefined || + last_token_type == token::close_statement || + last_token_type == token::close_expression || + last_token_type == token::comment) { + std::string text; + while (pos < src.size() && + // Keep going until we hit the next Jinja statement or expression + !( + src[pos] == '{' && + next_pos_is( {'%', '{', '#'} ) + )) { + text += src[pos++]; + } + JJ_DEBUG("consumed text: '%s'", text.c_str()); + if (!text.empty()) { + tokens.push_back({token::text, text}); + continue; + } + } + + // Possibly consume a comment + if (src[pos] == '{' && next_pos_is( {'#'} )) { + pos += 2; // Skip the opening {# + std::string comment; + while (!(src[pos] == '#' && next_pos_is( {'}'} ))) { + if (pos + 2 >= src.size()) { + throw std::runtime_error("lexer: missing end of comment tag"); + } + comment += src[pos++]; + } + JJ_DEBUG("consumed comment: '%s'", comment.c_str()); + tokens.push_back({token::comment, comment}); + pos += 2; // Skip the closing #} + continue; + } + + // Consume (and ignore) all whitespace inside Jinja statements or expressions + consume_while([](char c) { return std::isspace(static_cast(c)); }); + + if (pos >= src.size()) break; + + char ch = src[pos]; + + // Check for unary operators + if (ch == '-' || ch == '+') { + token::type last_token_type = tokens.empty() ? token::undefined : tokens.back().t; + if (last_token_type == token::text || last_token_type == token::undefined) { + throw std::runtime_error(std::string("lexer: unexpected character: ") + ch); + } + switch (last_token_type) { + case token::identifier: + case token::numeric_literal: + case token::string_literal: + case token::close_paren: + case token::close_square_bracket: + // Part of a binary operator + // a - 1, 1 - 1, true - 1, "apple" - 1, (1) - 1, a[1] - 1 + // Continue parsing normally + break; + default: { + // Is part of a unary operator + // (-1), [-1], (1 + -1), not -1, -apple + ++pos; // Consume the operator + + // Check for numbers following the unary operator + std::string num = consume_while(is_integer); + std::string value = std::string(1, ch) + num; + token::type t = num.empty() ? token::unary_operator : token::numeric_literal; + JJ_DEBUG("consumed unary operator or numeric literal: '%s'", value.c_str()); + tokens.push_back({t, value}); + continue; + } + } + } + + // Try to match one of the tokens in the mapping table + bool matched = false; + for (const auto & [seq, typ] : ordered_mapping_table) { + // Inside an object literal, don't treat "}}" as expression-end + if (seq == "}}" && curly_bracket_depth > 0) { + continue; + } + if (pos + seq.size() <= src.size() && src.substr(pos, seq.size()) == seq) { + tokens.push_back({typ, seq}); + if (typ == token::open_expression) { + curly_bracket_depth = 0; + } else if (typ == token::open_curly_bracket) { + ++curly_bracket_depth; + } else if (typ == token::close_curly_bracket) { + --curly_bracket_depth; + } + pos += seq.size(); + matched = true; + break; // continue main loop + } + } + if (matched) continue; // continue main loop + + // Strings + if (ch == '\'' || ch == '"') { + ++pos; // Skip opening quote + std::string str = consume_while([ch](char c) { return c != ch; }); + tokens.push_back({token::string_literal, str}); + ++pos; // Skip closing quote + continue; + } + + // Numbers + if (is_integer(ch)) { + std::string num = consume_while(is_integer); + if (pos < src.size() && src[pos] == '.' && pos + 1 < src.size() && is_integer(src[pos + 1])) { + ++pos; // Consume '.' + std::string frac = consume_while(is_integer); + num += "." + frac; + } + tokens.push_back({token::numeric_literal, num}); + continue; + } + + // Identifiers + if (is_word(ch)) { + std::string word = consume_while(is_word); + tokens.push_back({token::identifier, word}); + continue; + } + + throw std::runtime_error(std::string("lexer: unexpected character: ") + ch); + } + + return tokens; + } +}; + +} // namespace jinja diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index c3d9f9c324..f86a5b6657 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -186,6 +186,7 @@ endif() llama_build_and_test(test-chat-parser.cpp) llama_build_and_test(test-chat-peg-parser.cpp peg-parser/simple-tokenize.cpp) llama_build_and_test(test-chat-template.cpp) +llama_build_and_test(test-chat-jinja.cpp) llama_build_and_test(test-json-partial.cpp) llama_build_and_test(test-log.cpp) llama_build_and_test( diff --git a/tests/test-chat-jinja.cpp b/tests/test-chat-jinja.cpp new file mode 100644 index 0000000000..9fa0c7c817 --- /dev/null +++ b/tests/test-chat-jinja.cpp @@ -0,0 +1,55 @@ +#include +#include +#include +#include +#include + +#undef NDEBUG +#include + +#include "peg-parser.h" +#include "json-schema-to-grammar.h" +#include "jinja/jinja-compiler.h" +#include "jinja/jinja-lexer.h" + +int main(void) { + std::string contents = "{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '' + role + '\\n' + message['content'] | trim + '\\n' }}{% endfor %}{% if add_generation_prompt %}{{'model\\n'}}{% endif %}"; + + std::cout << "=== INPUT ===\n" << contents << "\n\n"; + + jinja::lexer lexer; + jinja::preprocess_options options; + options.trim_blocks = true; + options.lstrip_blocks = false; + auto tokens = lexer.tokenize(contents, options); + for (const auto & tok : tokens) { + std::cout << "token: type=" << static_cast(tok.t) << " text='" << tok.value << "'\n"; + } + + // jinja::compiler compiler; + // compiler.builder.set_root(compiler.root); + // auto parser = compiler.builder.build(); + + // auto grammar = build_grammar([&](const common_grammar_builder & builder0) { + // parser.build_grammar(builder0); + // }); + // printf("== GRAMMAR ==\n"); + // printf("%s\n", grammar.c_str()); + + // // printf("== DUMP ==\n"); + // // printf("%s\n", parser.dump(compiler.root.id()).c_str()); + + // printf("== PARSE ==\n"); + + // common_peg_parse_context ctx(contents); + // const auto result = parser.parse(ctx); + // if (!result.success()) { + // throw std::runtime_error("failed to parse, type = " + std::to_string(result.type)); + // } + + // ctx.ast.visit(result, [&](const common_peg_ast_node & node) { + // printf("node: rule='%s' text='%s'\n", node.rule.c_str(), std::string(node.text).c_str()); + // }); + + return 0; +} From a35fcb00b5dad35bea361fef4bc89f9fa5daabdc Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Sat, 27 Dec 2025 12:12:07 +0100 Subject: [PATCH 03/47] add vm types --- common/jinja/jinja-lexer.cpp | 242 +++++++++++++++++++++ common/jinja/jinja-lexer.h | 228 +------------------- common/jinja/jinja-vm.h | 393 +++++++++++++++++++++++++++++++++++ 3 files changed, 637 insertions(+), 226 deletions(-) create mode 100644 common/jinja/jinja-lexer.cpp create mode 100644 common/jinja/jinja-vm.h diff --git a/common/jinja/jinja-lexer.cpp b/common/jinja/jinja-lexer.cpp new file mode 100644 index 0000000000..a5ce7af9e1 --- /dev/null +++ b/common/jinja/jinja-lexer.cpp @@ -0,0 +1,242 @@ +#include "jinja-lexer.h" + +#include +#include +#include +#include +#include +#include +#include + + +// #define JJ_DEBUG(msg, ...) printf("jinja-lexer: " msg "\n", __VA_ARGS__) +#define JJ_DEBUG(msg, ...) // no-op + +namespace jinja { + +std::string lexer::preprocess(const std::string & template_str, const preprocess_options & options) const { + std::string result = template_str; + // According to https://jinja.palletsprojects.com/en/3.0.x/templates/#whitespace-control + + // In the default configuration: + // - a single trailing newline is stripped if present + // - other whitespace (spaces, tabs, newlines etc.) is returned unchanged + if (!result.empty() && result.back() == '\n') { + result.pop_back(); + } + + if (options.lstrip_blocks) { + // The lstrip_blocks option can also be set to strip tabs and spaces from the + // beginning of a line to the start of a block. (Nothing will be stripped if + // there are other characters before the start of the block.) + // result = std::regex_replace(result, std::regex(R"((?m)^[ \t]*(\{[#%-]))"), "$1"); + throw std::runtime_error("lstrip_blocks option is not implemented yet"); + } + + if (options.trim_blocks) { + // If an application configures Jinja to trim_blocks, the first newline after + // a template tag is removed automatically (like in PHP). + result = std::regex_replace(result, std::regex(R"(([#%-]\})\n)"), "$1"); + } + + // Handle whitespace control with - in tags + result = std::regex_replace(result, std::regex(R"(-%\}\s*)"), "%}"); + result = std::regex_replace(result, std::regex(R"(\s*\{%-)"), "{%"); + result = std::regex_replace(result, std::regex(R"(-\}\}\s*)"), "}}"); + result = std::regex_replace(result, std::regex(R"(\s*\{\{-)"), "{{"); + result = std::regex_replace(result, std::regex(R"(-#\}\s*)"), "#}"); + result = std::regex_replace(result, std::regex(R"(\s*\{\#-)"), "{#"); + + // Handle custom transformers-specific `generation` tag + // See https://github.com/huggingface/transformers/pull/30650 for more information. + // result = std::regex_replace(result, std::regex(R"((?s)\{%\s*generation\s*%\}.+?\{%\s*endgeneration\s*%\})"), ""); + + return result; +} + +std::vector lexer::tokenize(const std::string & input, const preprocess_options & options) { + std::vector tokens; + std::string src = preprocess(input, options); + JJ_DEBUG("preprocessed input: '%s'", src.c_str()); + + size_t pos = 0; + size_t curly_bracket_depth = 0; + + using pred = std::function; + auto consume_while = [&](pred predicate) -> std::string { + std::string str; + while (predicate(src[pos])) { + // check for escape char + if (src[pos] == '\\') { + // consume backslash + ++pos; + // check for end of input + if (pos >= src.size()) { + throw std::runtime_error("lexer: unexpected end of input after escape character"); + } + // add escaped char + char escaped_char = src[pos++]; + if (escape_chars.find(escaped_char) == escape_chars.end()) { + throw std::runtime_error(std::string("lexer: unknown escape character \\") + escaped_char); + } + char unescaped_char = escape_chars.at(escaped_char); + str += unescaped_char; + continue; + } + + str += src[pos++]; + if (pos > src.size()) { + throw std::runtime_error("lexer: unexpected end of input during consume_while"); + } + } + return str; + }; + + auto next_pos_is = [&](std::initializer_list chars) -> bool { + if (pos + 1 >= src.size()) return false; + for (char c : chars) { + if (src[pos + 1] == c) return true; + } + return false; + }; + + while (pos < src.size()) { + JJ_DEBUG("lexer main loop at pos %zu: '%s...'", pos, src.substr(pos, 10).c_str()); + + // First, consume all text that is outside of a Jinja statement or expression + token::type last_token_type = tokens.empty() + ? token::undefined + : tokens.back().t; + if (last_token_type == token::undefined || + last_token_type == token::close_statement || + last_token_type == token::close_expression || + last_token_type == token::comment) { + std::string text; + while (pos < src.size() && + // Keep going until we hit the next Jinja statement or expression + !( + src[pos] == '{' && + next_pos_is( {'%', '{', '#'} ) + )) { + text += src[pos++]; + } + JJ_DEBUG("consumed text: '%s'", text.c_str()); + if (!text.empty()) { + tokens.push_back({token::text, text}); + continue; + } + } + + // Possibly consume a comment + if (src[pos] == '{' && next_pos_is( {'#'} )) { + pos += 2; // Skip the opening {# + std::string comment; + while (!(src[pos] == '#' && next_pos_is( {'}'} ))) { + if (pos + 2 >= src.size()) { + throw std::runtime_error("lexer: missing end of comment tag"); + } + comment += src[pos++]; + } + JJ_DEBUG("consumed comment: '%s'", comment.c_str()); + tokens.push_back({token::comment, comment}); + pos += 2; // Skip the closing #} + continue; + } + + // Consume (and ignore) all whitespace inside Jinja statements or expressions + consume_while([](char c) { return std::isspace(static_cast(c)); }); + + if (pos >= src.size()) break; + + char ch = src[pos]; + + // Check for unary operators + if (ch == '-' || ch == '+') { + token::type last_token_type = tokens.empty() ? token::undefined : tokens.back().t; + if (last_token_type == token::text || last_token_type == token::undefined) { + throw std::runtime_error(std::string("lexer: unexpected character: ") + ch); + } + switch (last_token_type) { + case token::identifier: + case token::numeric_literal: + case token::string_literal: + case token::close_paren: + case token::close_square_bracket: + // Part of a binary operator + // a - 1, 1 - 1, true - 1, "apple" - 1, (1) - 1, a[1] - 1 + // Continue parsing normally + break; + default: { + // Is part of a unary operator + // (-1), [-1], (1 + -1), not -1, -apple + ++pos; // Consume the operator + + // Check for numbers following the unary operator + std::string num = consume_while(is_integer); + std::string value = std::string(1, ch) + num; + token::type t = num.empty() ? token::unary_operator : token::numeric_literal; + JJ_DEBUG("consumed unary operator or numeric literal: '%s'", value.c_str()); + tokens.push_back({t, value}); + continue; + } + } + } + + // Try to match one of the tokens in the mapping table + bool matched = false; + for (const auto & [seq, typ] : ordered_mapping_table) { + // Inside an object literal, don't treat "}}" as expression-end + if (seq == "}}" && curly_bracket_depth > 0) { + continue; + } + if (pos + seq.size() <= src.size() && src.substr(pos, seq.size()) == seq) { + tokens.push_back({typ, seq}); + if (typ == token::open_expression) { + curly_bracket_depth = 0; + } else if (typ == token::open_curly_bracket) { + ++curly_bracket_depth; + } else if (typ == token::close_curly_bracket) { + --curly_bracket_depth; + } + pos += seq.size(); + matched = true; + break; // continue main loop + } + } + if (matched) continue; // continue main loop + + // Strings + if (ch == '\'' || ch == '"') { + ++pos; // Skip opening quote + std::string str = consume_while([ch](char c) { return c != ch; }); + tokens.push_back({token::string_literal, str}); + ++pos; // Skip closing quote + continue; + } + + // Numbers + if (is_integer(ch)) { + std::string num = consume_while(is_integer); + if (pos < src.size() && src[pos] == '.' && pos + 1 < src.size() && is_integer(src[pos + 1])) { + ++pos; // Consume '.' + std::string frac = consume_while(is_integer); + num += "." + frac; + } + tokens.push_back({token::numeric_literal, num}); + continue; + } + + // Identifiers + if (is_word(ch)) { + std::string word = consume_while(is_word); + tokens.push_back({token::identifier, word}); + continue; + } + + throw std::runtime_error(std::string("lexer: unexpected character: ") + ch); + } + + return tokens; +} + +} // namespace jinja diff --git a/common/jinja/jinja-lexer.h b/common/jinja/jinja-lexer.h index 62850a33f6..554f30500a 100644 --- a/common/jinja/jinja-lexer.h +++ b/common/jinja/jinja-lexer.h @@ -6,9 +6,6 @@ #include #include -// #define JJ_DEBUG(msg, ...) printf("jinja-lexer: " msg "\n", __VA_ARGS__) -#define JJ_DEBUG(msg, ...) // no-op - namespace jinja { struct preprocess_options { @@ -107,230 +104,9 @@ struct lexer { {"=", token::equals}, }; - std::string preprocess(const std::string& template_str, const preprocess_options& options) const { - std::string result = template_str; - // According to https://jinja.palletsprojects.com/en/3.0.x/templates/#whitespace-control + std::string preprocess(const std::string& template_str, const preprocess_options& options) const; - // In the default configuration: - // - a single trailing newline is stripped if present - // - other whitespace (spaces, tabs, newlines etc.) is returned unchanged - if (!result.empty() && result.back() == '\n') { - result.pop_back(); - } - - if (options.lstrip_blocks) { - // The lstrip_blocks option can also be set to strip tabs and spaces from the - // beginning of a line to the start of a block. (Nothing will be stripped if - // there are other characters before the start of the block.) - // result = std::regex_replace(result, std::regex(R"((?m)^[ \t]*(\{[#%-]))"), "$1"); - throw std::runtime_error("lstrip_blocks option is not implemented yet"); - } - - if (options.trim_blocks) { - // If an application configures Jinja to trim_blocks, the first newline after - // a template tag is removed automatically (like in PHP). - result = std::regex_replace(result, std::regex(R"(([#%-]\})\n)"), "$1"); - } - - // Handle whitespace control with - in tags - result = std::regex_replace(result, std::regex(R"(-%\}\s*)"), "%}"); - result = std::regex_replace(result, std::regex(R"(\s*\{%-)"), "{%"); - result = std::regex_replace(result, std::regex(R"(-\}\}\s*)"), "}}"); - result = std::regex_replace(result, std::regex(R"(\s*\{\{-)"), "{{"); - result = std::regex_replace(result, std::regex(R"(-#\}\s*)"), "#}"); - result = std::regex_replace(result, std::regex(R"(\s*\{\#-)"), "{#"); - - // Handle custom transformers-specific `generation` tag - // See https://github.com/huggingface/transformers/pull/30650 for more information. - // result = std::regex_replace(result, std::regex(R"((?s)\{%\s*generation\s*%\}.+?\{%\s*endgeneration\s*%\})"), ""); - - return result; - } - - std::vector tokenize(const std::string & input, const preprocess_options & options = {}) { - std::vector tokens; - std::string src = preprocess(input, options); - JJ_DEBUG("preprocessed input: '%s'", src.c_str()); - - size_t pos = 0; - size_t curly_bracket_depth = 0; - - using pred = std::function; - auto consume_while = [&](pred predicate) -> std::string { - std::string str; - while (predicate(src[pos])) { - // check for escape char - if (src[pos] == '\\') { - // consume backslash - ++pos; - // check for end of input - if (pos >= src.size()) { - throw std::runtime_error("lexer: unexpected end of input after escape character"); - } - // add escaped char - char escaped_char = src[pos++]; - if (escape_chars.find(escaped_char) == escape_chars.end()) { - throw std::runtime_error(std::string("lexer: unknown escape character \\") + escaped_char); - } - char unescaped_char = escape_chars.at(escaped_char); - str += unescaped_char; - continue; - } - - str += src[pos++]; - if (pos > src.size()) { - throw std::runtime_error("lexer: unexpected end of input during consume_while"); - } - } - return str; - }; - - auto next_pos_is = [&](std::initializer_list chars) -> bool { - if (pos + 1 >= src.size()) return false; - for (char c : chars) { - if (src[pos + 1] == c) return true; - } - return false; - }; - - while (pos < src.size()) { - JJ_DEBUG("lexer main loop at pos %zu: '%s...'", pos, src.substr(pos, 10).c_str()); - - // First, consume all text that is outside of a Jinja statement or expression - token::type last_token_type = tokens.empty() - ? token::undefined - : tokens.back().t; - if (last_token_type == token::undefined || - last_token_type == token::close_statement || - last_token_type == token::close_expression || - last_token_type == token::comment) { - std::string text; - while (pos < src.size() && - // Keep going until we hit the next Jinja statement or expression - !( - src[pos] == '{' && - next_pos_is( {'%', '{', '#'} ) - )) { - text += src[pos++]; - } - JJ_DEBUG("consumed text: '%s'", text.c_str()); - if (!text.empty()) { - tokens.push_back({token::text, text}); - continue; - } - } - - // Possibly consume a comment - if (src[pos] == '{' && next_pos_is( {'#'} )) { - pos += 2; // Skip the opening {# - std::string comment; - while (!(src[pos] == '#' && next_pos_is( {'}'} ))) { - if (pos + 2 >= src.size()) { - throw std::runtime_error("lexer: missing end of comment tag"); - } - comment += src[pos++]; - } - JJ_DEBUG("consumed comment: '%s'", comment.c_str()); - tokens.push_back({token::comment, comment}); - pos += 2; // Skip the closing #} - continue; - } - - // Consume (and ignore) all whitespace inside Jinja statements or expressions - consume_while([](char c) { return std::isspace(static_cast(c)); }); - - if (pos >= src.size()) break; - - char ch = src[pos]; - - // Check for unary operators - if (ch == '-' || ch == '+') { - token::type last_token_type = tokens.empty() ? token::undefined : tokens.back().t; - if (last_token_type == token::text || last_token_type == token::undefined) { - throw std::runtime_error(std::string("lexer: unexpected character: ") + ch); - } - switch (last_token_type) { - case token::identifier: - case token::numeric_literal: - case token::string_literal: - case token::close_paren: - case token::close_square_bracket: - // Part of a binary operator - // a - 1, 1 - 1, true - 1, "apple" - 1, (1) - 1, a[1] - 1 - // Continue parsing normally - break; - default: { - // Is part of a unary operator - // (-1), [-1], (1 + -1), not -1, -apple - ++pos; // Consume the operator - - // Check for numbers following the unary operator - std::string num = consume_while(is_integer); - std::string value = std::string(1, ch) + num; - token::type t = num.empty() ? token::unary_operator : token::numeric_literal; - JJ_DEBUG("consumed unary operator or numeric literal: '%s'", value.c_str()); - tokens.push_back({t, value}); - continue; - } - } - } - - // Try to match one of the tokens in the mapping table - bool matched = false; - for (const auto & [seq, typ] : ordered_mapping_table) { - // Inside an object literal, don't treat "}}" as expression-end - if (seq == "}}" && curly_bracket_depth > 0) { - continue; - } - if (pos + seq.size() <= src.size() && src.substr(pos, seq.size()) == seq) { - tokens.push_back({typ, seq}); - if (typ == token::open_expression) { - curly_bracket_depth = 0; - } else if (typ == token::open_curly_bracket) { - ++curly_bracket_depth; - } else if (typ == token::close_curly_bracket) { - --curly_bracket_depth; - } - pos += seq.size(); - matched = true; - break; // continue main loop - } - } - if (matched) continue; // continue main loop - - // Strings - if (ch == '\'' || ch == '"') { - ++pos; // Skip opening quote - std::string str = consume_while([ch](char c) { return c != ch; }); - tokens.push_back({token::string_literal, str}); - ++pos; // Skip closing quote - continue; - } - - // Numbers - if (is_integer(ch)) { - std::string num = consume_while(is_integer); - if (pos < src.size() && src[pos] == '.' && pos + 1 < src.size() && is_integer(src[pos + 1])) { - ++pos; // Consume '.' - std::string frac = consume_while(is_integer); - num += "." + frac; - } - tokens.push_back({token::numeric_literal, num}); - continue; - } - - // Identifiers - if (is_word(ch)) { - std::string word = consume_while(is_word); - tokens.push_back({token::identifier, word}); - continue; - } - - throw std::runtime_error(std::string("lexer: unexpected character: ") + ch); - } - - return tokens; - } + std::vector tokenize(const std::string & input, const preprocess_options & options); }; } // namespace jinja diff --git a/common/jinja/jinja-vm.h b/common/jinja/jinja-vm.h new file mode 100644 index 0000000000..9ee2917531 --- /dev/null +++ b/common/jinja/jinja-vm.h @@ -0,0 +1,393 @@ +#include "jinja-lexer.h" + +#include +#include +#include +#include + + +namespace jinja { + +struct context { + // TODO +}; + +/** + * Base class for all nodes in the AST. + */ +struct statement { + virtual ~statement() = default; + virtual std::string type() const { return "Statement"; } + virtual void execute(context & ctx) = 0; +}; + +using statement_ptr = std::unique_ptr; +using statements = std::vector; + +// Type Checking Utilities + +template +static void chk_type(const statement_ptr & ptr) { + if (!ptr) return; // Allow null for optional fields + assert(dynamic_cast(ptr.get()) != nullptr); +} + +template +static void chk_type(const statement_ptr & ptr) { + if (!ptr) return; + assert(dynamic_cast(ptr.get()) != nullptr || dynamic_cast(ptr.get()) != nullptr); +} + +// Base Types + +/** + * Expressions will result in a value at runtime (unlike statements). + */ +struct expression : public statement { + std::string type() const override { return "Expression"; } + void execute(context & ctx) override {} +}; + +// Statements + +struct program : public statement { + statements body; + + explicit program(statements && body) : body(std::move(body)) {} + std::string type() const override { return "Program"; } + void execute(context & ctx) override {} +}; + +struct if_statement : public statement { + statement_ptr test; + statements body; + statements alternate; + + if_statement(statement_ptr && test, statements && body, statements && alternate) + : test(std::move(test)), body(std::move(body)), alternate(std::move(alternate)) { + chk_type(this->test); + } + + std::string type() const override { return "If"; } + void execute(context & ctx) override {} +}; + +struct identifier; +struct tuple_literal; + +/** + * Loop over each item in a sequence + * https://jinja.palletsprojects.com/en/3.0.x/templates/#for + */ +struct for_statement : public statement { + statement_ptr loopvar; // Identifier | TupleLiteral + statement_ptr iterable; + statements body; + statements default_block; // if no iteration took place + + for_statement(statement_ptr && loopvar, statement_ptr && iterable, statements && body, statements && default_block) + : loopvar(std::move(loopvar)), iterable(std::move(iterable)), + body(std::move(body)), default_block(std::move(default_block)) { + chk_type(this->loopvar); + chk_type(this->iterable); + } + + std::string type() const override { return "For"; } + void execute(context & ctx) override {} +}; + +struct break_statement : public statement { + std::string type() const override { return "Break"; } + void execute(context & ctx) override {} +}; + +struct continue_statement : public statement { + std::string type() const override { return "Continue"; } + void execute(context & ctx) override {} +}; + +struct set_statement : public statement { + statement_ptr assignee; + statement_ptr value; + statements body; + + set_statement(statement_ptr && assignee, statement_ptr && value, statements && body) + : assignee(std::move(assignee)), value(std::move(value)), body(std::move(body)) { + chk_type(this->assignee); + chk_type(this->value); + } + + std::string type() const override { return "Set"; } + void execute(context & ctx) override {} +}; + +struct macro_statement : public statement { + statement_ptr name; + statements args; + statements body; + + macro_statement(statement_ptr && name, statements && args, statements && body) + : name(std::move(name)), args(std::move(args)), body(std::move(body)) { + chk_type(this->name); + for (const auto& arg : this->args) chk_type(arg); + } + + std::string type() const override { return "Macro"; } + void execute(context & ctx) override {} +}; + +struct comment_statement : public statement { + std::string value; + explicit comment_statement(const std::string & value) : value(value) {} + std::string type() const override { return "Comment"; } + void execute(context & ctx) override {} +}; + +// Expressions + +struct member_expression : public expression { + statement_ptr object; + statement_ptr property; + bool computed; + + member_expression(statement_ptr && object, statement_ptr && property, bool computed) + : object(std::move(object)), property(std::move(property)), computed(computed) { + chk_type(this->object); + chk_type(this->property); + } + std::string type() const override { return "MemberExpression"; } +}; + +struct call_expression : public expression { + statement_ptr callee; + statements args; + + call_expression(statement_ptr && callee, statements && args) + : callee(std::move(callee)), args(std::move(args)) { + chk_type(this->callee); + for (const auto& arg : this->args) chk_type(arg); + } + std::string type() const override { return "CallExpression"; } +}; + +/** + * Represents a user-defined variable or symbol in the template. + */ +struct identifier : public expression { + std::string value; + explicit identifier(const std::string & value) : value(value) {} + std::string type() const override { return "Identifier"; } +}; + +// Literals + +/** + * Abstract base class for all Literal expressions. + * Should not be instantiated directly. + */ +template +struct literal : public expression { + T value; + explicit literal(T && value) : value(std::move(value)) {} + std::string type() const override { return "Literal"; } +}; + +struct integer_literal : public literal { + std::string type() const override { return "IntegerLiteral"; } +}; + +struct float_literal : public literal { + std::string type() const override { return "FloatLiteral"; } +}; + +struct string_literal : public literal { + std::string type() const override { return "StringLiteral"; } +}; + +struct array_literal : public expression { + statements value; + explicit array_literal(statements && value) : value(std::move(value)) { + for (const auto& item : this->value) chk_type(item); + } + std::string type() const override { return "ArrayLiteral"; } +}; + +struct tuple_literal : public expression { + statements value; + explicit tuple_literal(statements && value) : value(std::move(value)) { + for (const auto& item : this->value) chk_type(item); + } + std::string type() const override { return "TupleLiteral"; } +}; + +struct object_literal : public expression { + std::vector> value; + explicit object_literal(std::vector> && value) + : value(std::move(value)) { + for (const auto & pair : this->value) { + chk_type(pair.first); + chk_type(pair.second); + } + } + std::string type() const override { return "ObjectLiteral"; } +}; + +// Complex Expressions + +/** + * An operation with two sides, separated by an operator. + * Note: Either side can be a Complex Expression, with order + * of operations being determined by the operator. + */ +struct binary_expression : public expression { + token::type op; + statement_ptr left; + statement_ptr right; + + binary_expression(token::type op, statement_ptr && left, statement_ptr && right) + : op(op), left(std::move(left)), right(std::move(right)) { + chk_type(this->left); + chk_type(this->right); + } + std::string type() const override { return "BinaryExpression"; } +}; + +/** + * An operation with two sides, separated by the | operator. + * Operator precedence: https://github.com/pallets/jinja/issues/379#issuecomment-168076202 + */ +struct filter_expression : public expression { + statement_ptr operand; + statement_ptr filter; + + filter_expression(statement_ptr && operand, statement_ptr && filter) + : operand(std::move(operand)), filter(std::move(filter)) { + chk_type(this->operand); + chk_type(this->filter); + } + std::string type() const override { return "FilterExpression"; } +}; + +struct filter_statement : public statement { + statement_ptr filter; + statements body; + + filter_statement(statement_ptr && filter, statements && body) + : filter(std::move(filter)), body(std::move(body)) { + chk_type(this->filter); + } + std::string type() const override { return "FilterStatement"; } + void execute(context & ctx) override {} +}; + +/** + * An operation which filters a sequence of objects by applying a test to each object, + * and only selecting the objects with the test succeeding. + * + * It may also be used as a shortcut for a ternary operator. + */ +struct select_expression : public expression { + statement_ptr lhs; + statement_ptr test; + + select_expression(statement_ptr && lhs, statement_ptr && test) + : lhs(std::move(lhs)), test(std::move(test)) { + chk_type(this->lhs); + chk_type(this->test); + } + std::string type() const override { return "SelectExpression"; } +}; + +/** + * An operation with two sides, separated by the "is" operator. + */ +struct test_expression : public expression { + statement_ptr operand; + bool negate; + statement_ptr test; + + test_expression(statement_ptr && operand, bool negate, statement_ptr && test) + : operand(std::move(operand)), negate(negate), test(std::move(test)) { + chk_type(this->operand); + chk_type(this->test); + } + std::string type() const override { return "TestExpression"; } +}; + +/** + * An operation with one side (operator on the left). + */ +struct unary_expression : public expression { + token op; + statement_ptr argument; + + unary_expression(token op, statement_ptr && argument) + : op(std::move(op)), argument(std::move(argument)) { + chk_type(this->argument); + } + std::string type() const override { return "UnaryExpression"; } +}; + +struct slice_expression : public expression { + statement_ptr start; + statement_ptr stop; + statement_ptr step; + + slice_expression(statement_ptr && start, statement_ptr && stop, statement_ptr && step) + : start(std::move(start)), stop(std::move(stop)), step(std::move(step)) { + chk_type(this->start); + chk_type(this->stop); + chk_type(this->step); + } + std::string type() const override { return "SliceExpression"; } +}; + +struct keyword_argument_expression : public expression { + statement_ptr key; + statement_ptr value; + + keyword_argument_expression(statement_ptr && key, statement_ptr && value) + : key(std::move(key)), value(std::move(value)) { + chk_type(this->key); + chk_type(this->value); + } + std::string type() const override { return "KeywordArgumentExpression"; } +}; + +struct spread_expression : public expression { + statement_ptr argument; + explicit spread_expression(statement_ptr && argument) : argument(std::move(argument)) { + chk_type(this->argument); + } + std::string type() const override { return "SpreadExpression"; } +}; + +struct call_statement : public statement { + statement_ptr call; + statements caller_args; + statements body; + + call_statement(statement_ptr && call, statements && caller_args, statements && body) + : call(std::move(call)), caller_args(std::move(caller_args)), body(std::move(body)) { + chk_type(this->call); + for (const auto& arg : this->caller_args) chk_type(arg); + } + std::string type() const override { return "CallStatement"; } + void execute(context & ctx) override {} +}; + +struct ternary_expression : public expression { + statement_ptr condition; + statement_ptr true_expr; + statement_ptr false_expr; + + ternary_expression(statement_ptr && condition, statement_ptr && true_expr, statement_ptr && false_expr) + : condition(std::move(condition)), true_expr(std::move(true_expr)), false_expr(std::move(false_expr)) { + chk_type(this->condition); + chk_type(this->true_expr); + chk_type(this->false_expr); + } + std::string type() const override { return "Ternary"; } +}; + +} // namespace jinja From a6e0ae7a85b698baa05d7b1b631da6c54e9e9f6d Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Sat, 27 Dec 2025 12:22:34 +0100 Subject: [PATCH 04/47] demo --- common/jinja/jinja-compiler.h | 79 ----------------------------------- common/jinja/jinja-lexer.h | 32 ++++++++++++++ common/jinja/jinja-parser.cpp | 39 +++++++++++++++++ common/jinja/jinja-vm.cpp | 28 ------------- 4 files changed, 71 insertions(+), 107 deletions(-) delete mode 100644 common/jinja/jinja-compiler.h create mode 100644 common/jinja/jinja-parser.cpp delete mode 100644 common/jinja/jinja-vm.cpp diff --git a/common/jinja/jinja-compiler.h b/common/jinja/jinja-compiler.h deleted file mode 100644 index 32792d6050..0000000000 --- a/common/jinja/jinja-compiler.h +++ /dev/null @@ -1,79 +0,0 @@ -#include "common.h" -#include -#include - -namespace jinja { - -struct compiler { - common_chat_peg_native_builder builder; - common_peg_parser root; - - compiler() : root(builder.choice()) { - auto & p = builder; - - auto ws = p.rule("ws", p.chars("[ \t]", 0, -1)); - auto num = p.rule("num", p.chars("[0-9]", 1, -1)); - - // - // expressions - // - - auto expression = p.choice(); - - auto var_name = p.rule("var_name", p.chars("[a-zA-Z_]", 1, -1) << p.chars("[a-zA-Z0-9_]", 0, -1)); - expression |= var_name; - - // value - auto p_int = p.rule("value_int", num); - auto p_flt = p.rule("value_flt", num << "." << p.optional(num)); - auto p_str = p.rule("value_str", - p.json_string() | - p.literal("'") + p.chars("[^']*", 0, -1) + p.literal("'") - ); - - expression |= p_int; - expression |= p_flt; - expression |= p_str; - - // function calls - auto p_args = p.rule("args", expression << ws << p.zero_or_more("," << ws << expression)); - auto p_func = p.rule("func", ws << var_name << ws << "(" << ws << p_args << ws << ")"); - expression |= p_func; - - // indexing - auto p_idx = p.rule("idx", ws << "[" << ws << expression << ws << "]"); - expression |= p_idx; - - // set - auto p_set = p.rule("set", "set " << ws << var_name << ws << "=" << expression); - expression |= p_set; - - // if, else, endif - auto p_if = p.rule("if", "if " << ws << expression << ws); - auto p_else = p.rule("else", "else " << ws << expression << ws); - auto p_endif = p.rule("endif", p.literal("endif")); - - expression |= p_if; - expression |= p_else; - expression |= p_endif; - - expression = p.space() + expression + p.space(); - - // - // root - // - - // auto strip = p.rule("strip", "-" << expression << "-"); - auto print = p.rule("print", "{{" << (expression) << "}}"); - auto ctrl = p.rule("ctrl", "{%" << (expression) << "%}"); - - root |= print; - root |= ctrl; - root |= p.rule("text", p.negate(root)); - - root = p.one_or_more(root); - root += p.end(); - } -}; - -} // namespace jinja diff --git a/common/jinja/jinja-lexer.h b/common/jinja/jinja-lexer.h index 554f30500a..2011e487b1 100644 --- a/common/jinja/jinja-lexer.h +++ b/common/jinja/jinja-lexer.h @@ -48,6 +48,38 @@ struct token { std::string value; }; +std::string type_to_string(token::type t) { + switch (t) { + case token::undefined: return "undefined"; + case token::text: return "text"; + case token::numeric_literal: return "numeric_literal"; + case token::string_literal: return "string_literal"; + case token::identifier: return "identifier"; + case token::equals: return "equals"; + case token::open_paren: return "open_paren"; + case token::close_paren: return "close_paren"; + case token::open_statement: return "open_statement"; + case token::close_statement: return "close_statement"; + case token::open_expression: return "open_expression"; + case token::close_expression: return "close_expression"; + case token::open_square_bracket: return "open_square_bracket"; + case token::close_square_bracket: return "close_square_bracket"; + case token::open_curly_bracket: return "open_curly_bracket"; + case token::close_curly_bracket: return "close_curly_bracket"; + case token::comma: return "comma"; + case token::dot: return "dot"; + case token::colon: return "colon"; + case token::pipe: return "pipe"; + case token::call_operator: return "call_operator"; + case token::additive_binary_operator: return "additive_binary_operator"; + case token::multiplicative_binary_operator: return "multiplicative_binary_operator"; + case token::comparison_binary_operator: return "comparison_binary_operator"; + case token::unary_operator: return "unary_operator"; + case token::comment: return "comment"; + default: return "unknown"; + } +} + struct lexer { const std::map escape_chars = { {'n', '\n'}, diff --git a/common/jinja/jinja-parser.cpp b/common/jinja/jinja-parser.cpp new file mode 100644 index 0000000000..fa8cd9785a --- /dev/null +++ b/common/jinja/jinja-parser.cpp @@ -0,0 +1,39 @@ +#include "jinja-lexer.h" +#include "jinja-vm.h" + +namespace jinja { + +void parse(const std::vector & tokens) { + auto program = std::make_unique(); + size_t current = 0; + + /** + * Consume the next token if it matches the expected type, otherwise throw an error. + * @param type The expected token type + * @param error The error message to throw if the token does not match the expected type + * @returns The consumed token + */ + auto expect = [&](const token::type & type, const std::string & error) -> token { + const auto & prev = tokens[current++]; + if (prev.t != type) { + throw std::runtime_error("Parser Error: " + error + " (" + type_to_string(prev.t) + " != " + type_to_string(type) + ")"); + } + return prev; + }; + + auto next_token = [&]() -> const token & { + if (current >= tokens.size()) { + return token{token::undefined, ""}; + } + return tokens[current++]; + }; + + auto expect_identifier = [&](const std::string & name) -> void { + if (!is_identifier(name)) { + throw std::runtime_error("Expected " + name); + } + ++current; + }; +} + +}; // namespace jinja diff --git a/common/jinja/jinja-vm.cpp b/common/jinja/jinja-vm.cpp deleted file mode 100644 index 7c8d0cf732..0000000000 --- a/common/jinja/jinja-vm.cpp +++ /dev/null @@ -1,28 +0,0 @@ -#include -#include - -struct vm_context { - std::ostringstream out; -}; - -struct op_base { - virtual ~op_base() = default; - virtual void execute(vm_context & ctx) = 0; -}; - -struct op_print : public op_base { - std::string message; - op_print(const std::string & message) : message(message) {} - void execute(vm_context & ctx) override { - ctx.out << message; - } -}; - -struct op_load : public op_base { - std::string dst; - std::string src; - std::string value; - op_load(const std::string & dst) : dst(dst) {} - void execute(vm_context & ctx) override { - } -}; From 7ac8e98b2838835749d5a5f4ad88a9a2a945c3d0 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Sat, 27 Dec 2025 12:35:19 +0100 Subject: [PATCH 05/47] clean up --- common/jinja/jinja-parser.cpp | 570 ++++++++++++++++++++++++++++++++-- 1 file changed, 538 insertions(+), 32 deletions(-) diff --git a/common/jinja/jinja-parser.cpp b/common/jinja/jinja-parser.cpp index fa8cd9785a..07cb71fe11 100644 --- a/common/jinja/jinja-parser.cpp +++ b/common/jinja/jinja-parser.cpp @@ -1,39 +1,545 @@ #include "jinja-lexer.h" #include "jinja-vm.h" +#include +#include +#include +#include +#include + namespace jinja { -void parse(const std::vector & tokens) { - auto program = std::make_unique(); - size_t current = 0; - - /** - * Consume the next token if it matches the expected type, otherwise throw an error. - * @param type The expected token type - * @param error The error message to throw if the token does not match the expected type - * @returns The consumed token - */ - auto expect = [&](const token::type & type, const std::string & error) -> token { - const auto & prev = tokens[current++]; - if (prev.t != type) { - throw std::runtime_error("Parser Error: " + error + " (" + type_to_string(prev.t) + " != " + type_to_string(type) + ")"); - } - return prev; - }; - - auto next_token = [&]() -> const token & { - if (current >= tokens.size()) { - return token{token::undefined, ""}; - } - return tokens[current++]; - }; - - auto expect_identifier = [&](const std::string & name) -> void { - if (!is_identifier(name)) { - throw std::runtime_error("Expected " + name); - } - ++current; - }; +// Helper to check type without asserting (useful for logic) +template +static bool is_type(const statement_ptr & ptr) { + return dynamic_cast(ptr.get()) != nullptr; } -}; // namespace jinja +class parser { + const std::vector & tokens; + size_t current = 0; + +public: + parser(const std::vector & t) : tokens(t) {} + + statement_ptr parse() { + statements body; + while (current < tokens.size()) { + body.push_back(parse_any()); + } + return std::make_unique(std::move(body)); + } + +private: + const token & peek(size_t offset = 0) const { + if (current + offset >= tokens.size()) { + static const token end_token{token::undefined, ""}; + return end_token; + } + return tokens[current + offset]; + } + + token expect(token::type type, const std::string& error) { + const auto & t = peek(); + if (t.t != type) { + throw std::runtime_error("Parser Error: " + error + " (Got " + t.value + ")"); + } + current++; + return t; + } + + void expect_identifier(const std::string& name) { + const auto & t = peek(); + if (t.t != token::identifier || t.value != name) { + throw std::runtime_error("Expected identifier: " + name); + } + current++; + } + + bool is(token::type type) const { + return peek().t == type; + } + + bool is_identifier(const std::string& name) const { + return peek().t == token::identifier && peek().value == name; + } + + bool is_statement(const std::vector& names) const { + if (peek(0).t != token::open_statement || peek(1).t != token::identifier) { + return false; + } + std::string val = peek(1).value; + return std::find(names.begin(), names.end(), val) != names.end(); + } + + statement_ptr parse_any() { + switch (peek().t) { + case token::comment: + return std::make_unique(tokens[current++].value); + case token::text: + return std::make_unique(tokens[current++].value); + case token::open_statement: + return parse_jinja_statement(); + case token::open_expression: + return parse_jinja_expression(); + default: + throw std::runtime_error("Unexpected token type"); + } + } + + statement_ptr parse_jinja_expression() { + // Consume {{ }} tokens + expect(token::open_expression, "Expected {{"); + auto result = parse_expression(); + expect(token::close_expression, "Expected }}"); + return result; + } + + statement_ptr parse_jinja_statement() { + // Consume {% token + expect(token::open_statement, "Expected {%"); + + if (peek().t != token::identifier) { + throw std::runtime_error("Unknown statement"); + } + + std::string name = peek().value; + current++; // consume identifier + + statement_ptr result; + if (name == "set") { + result = parse_set_statement(); + + } else if (name == "if") { + result = parse_if_statement(); + // expect {% endif %} + expect(token::open_statement, "Expected {%"); + expect_identifier("endif"); + expect(token::close_statement, "Expected %}"); + + } else if (name == "macro") { + result = parse_macro_statement(); + // expect {% endmacro %} + expect(token::open_statement, "Expected {%"); + expect_identifier("endmacro"); + expect(token::close_statement, "Expected %}"); + + } else if (name == "for") { + result = parse_for_statement(); + // expect {% endfor %} + expect(token::open_statement, "Expected {%"); + expect_identifier("endfor"); + expect(token::close_statement, "Expected %}"); + + } else if (name == "break") { + expect(token::close_statement, "Expected %}"); + result = std::make_unique(); + + } else if (name == "continue") { + expect(token::close_statement, "Expected %}"); + result = std::make_unique(); + + } else if (name == "call") { + statements caller_args; + bool has_caller_args = false; + if (is(token::open_paren)) { + // Optional caller arguments, e.g. {% call(user) dump_users(...) %} + caller_args = parse_args(); + has_caller_args = true; + } + auto callee = parse_primary_expression(); + if (!is_type(callee)) throw std::runtime_error("Expected identifier"); + + auto call_args = parse_args(); + expect(token::close_statement, "Expected %}"); + + statements body; + while (!is_statement({"endcall"})) { + body.push_back(parse_any()); + } + + expect(token::open_statement, "Expected {%"); + expect_identifier("endcall"); + expect(token::close_statement, "Expected %}"); + + auto call_expr = std::make_unique(std::move(callee), std::move(call_args)); + result = std::make_unique(std::move(call_expr), std::move(caller_args), std::move(body)); + + } else if (name == "filter") { + auto filter_node = parse_primary_expression(); + if (is_type(filter_node) && is(token::open_paren)) { + filter_node = parse_call_expression(std::move(filter_node)); + } + expect(token::close_statement, "Expected %}"); + + statements body; + while (!is_statement({"endfilter"})) { + body.push_back(parse_any()); + } + + expect(token::open_statement, "Expected {%"); + expect_identifier("endfilter"); + expect(token::close_statement, "Expected %}"); + result = std::make_unique(std::move(filter_node), std::move(body)); + + } else { + throw std::runtime_error("Unknown statement: " + name); + } + return result; + } + + statement_ptr parse_set_statement() { + // NOTE: `set` acts as both declaration statement and assignment expression + auto left = parse_expression_sequence(); + statement_ptr value = nullptr; + statements body; + + if (is(token::equals)) { + current++; + value = parse_expression_sequence(); + } else { + // parsing multiline set here + expect(token::close_statement, "Expected %}"); + while (!is_statement({"endset"})) { + body.push_back(parse_any()); + } + expect(token::open_statement, "Expected {%"); + expect_identifier("endset"); + } + expect(token::close_statement, "Expected %}"); + return std::make_unique(std::move(left), std::move(value), std::move(body)); + } + + statement_ptr parse_if_statement() { + auto test = parse_expression(); + expect(token::close_statement, "Expected %}"); + + statements body; + statements alternate; + + // Keep parsing 'if' body until we reach the first {% elif %} or {% else %} or {% endif %} + while (!is_statement({"elif", "else", "endif"})) { + body.push_back(parse_any()); + } + + if (is_statement({"elif"})) { + ++current; // consume {% + ++current; // consume 'elif' + alternate.push_back(parse_if_statement()); // nested If + } else if (is_statement({"else"})) { + ++current; // consume {% + ++current; // consume 'else' + expect(token::close_statement, "Expected %}"); + + // keep going until we hit {% endif %} + while (!is_statement({"endif"})) { + alternate.push_back(parse_any()); + } + } + return std::make_unique(std::move(test), std::move(body), std::move(alternate)); + } + + statement_ptr parse_macro_statement() { + auto name = parse_primary_expression(); + auto args = parse_args(); + expect(token::close_statement, "Expected %}"); + statements body; + // Keep going until we hit {% endmacro + while (!is_statement({"endmacro"})) { + body.push_back(parse_any()); + } + return std::make_unique(std::move(name), std::move(args), std::move(body)); + } + + statement_ptr parse_expression_sequence(bool primary = false) { + statements exprs; + exprs.push_back(primary ? parse_primary_expression() : parse_expression()); + bool is_tuple = is(token::comma); + while (is(token::comma)) { + current++; // consume comma + exprs.push_back(primary ? parse_primary_expression() : parse_expression()); + if (!is(token::comma)) break; + } + return is_tuple ? std::make_unique(std::move(exprs)) : std::move(exprs[0]); + } + + statement_ptr parse_for_statement() { + // e.g., `message` in `for message in messages` + auto loop_var = parse_expression_sequence(true); // should be an identifier/tuple + if (!is_identifier("in")) throw std::runtime_error("Expected 'in'"); + current++; + + // `messages` in `for message in messages` + auto iterable = parse_expression(); + expect(token::close_statement, "Expected %}"); + + statements body; + statements alternate; + + // Keep going until we hit {% endfor or {% else + while (!is_statement({"endfor", "else"})) { + body.push_back(parse_any()); + } + + if (is_statement({"else"})) { + current += 2; + expect(token::close_statement, "Expected %}"); + while (!is_statement({"endfor"})) { + alternate.push_back(parse_any()); + } + } + return std::make_unique( + std::move(loop_var), std::move(iterable), + std::move(body), std::move(alternate)); + } + + statement_ptr parse_expression() { + // Choose parse function with lowest precedence + return parse_if_expression(); + } + + statement_ptr parse_if_expression() { + auto a = parse_logical_or_expression(); + if (is_identifier("if")) { + // Ternary expression + ++current; // consume 'if' + auto test = parse_logical_or_expression(); + if (is_identifier("else")) { + // Ternary expression with else + ++current; // consume 'else' + auto false_expr = parse_if_expression(); // recurse to support chained ternaries + return std::make_unique(std::move(test), std::move(a), std::move(false_expr)); + } else { + // Select expression on iterable + return std::make_unique(std::move(a), std::move(test)); + } + } + return a; + } + + statement_ptr parse_logical_or_expression() { + auto left = parse_logical_and_expression(); + while (is_identifier("or")) { + auto op = tokens[current++]; + left = std::make_unique(op, std::move(left), parse_logical_and_expression()); + } + return left; + } + + statement_ptr parse_logical_and_expression() { + auto left = parse_logical_negation_expression(); + while (is_identifier("and")) { + auto op = tokens[current++]; + left = std::make_unique(op, std::move(left), parse_logical_negation_expression()); + } + return left; + } + + statement_ptr parse_logical_negation_expression() { + // Try parse unary operators + if (is_identifier("not")) { + auto op = tokens[current]; + ++current; // consume 'not' + return std::make_unique(op, parse_logical_negation_expression()); + } + return parse_comparison_expression(); + } + + statement_ptr parse_comparison_expression() { + // NOTE: membership has same precedence as comparison + // e.g., ('a' in 'apple' == 'b' in 'banana') evaluates as ('a' in ('apple' == ('b' in 'banana'))) + auto left = parse_additive_expression(); + while (true) { + token op; + if (is_identifier("not") && peek(1).t == token::identifier && peek(1).value == "in") { + op = {token::identifier, "not in"}; + current += 2; + } else if (is_identifier("in")) { + op = tokens[current++]; + } else if (is(token::comparison_binary_operator)) { + op = tokens[current++]; + } else break; + left = std::make_unique(op, std::move(left), parse_additive_expression()); + } + return left; + } + + statement_ptr parse_additive_expression() { + auto left = parse_multiplicative_expression(); + while (is(token::additive_binary_operator)) { + auto op = tokens[current++]; + left = std::make_unique(op, std::move(left), parse_multiplicative_expression()); + } + return left; + } + + statement_ptr parse_multiplicative_expression() { + auto left = parse_test_expression(); + while (is(token::multiplicative_binary_operator)) { + auto op = tokens[current++]; + left = std::make_unique(op, std::move(left), parse_test_expression()); + } + return left; + } + + statement_ptr parse_test_expression() { + auto operand = parse_filter_expression(); + while (is_identifier("is")) { + current++; + bool negate = false; + if (is_identifier("not")) { current++; negate = true; } + auto test_id = parse_primary_expression(); + operand = std::make_unique(std::move(operand), negate, std::move(test_id)); + } + return operand; + } + + statement_ptr parse_filter_expression() { + auto operand = parse_call_member_expression(); + while (is(token::pipe)) { + current++; + auto filter = parse_primary_expression(); + if (is(token::open_paren)) filter = parse_call_expression(std::move(filter)); + operand = std::make_unique(std::move(operand), std::move(filter)); + } + return operand; + } + + statement_ptr parse_call_member_expression() { + // Handle member expressions recursively + auto member = parse_member_expression(parse_primary_expression()); + return is(token::open_paren) + ? parse_call_expression(std::move(member)) // foo.x() + : std::move(member); + } + + statement_ptr parse_call_expression(statement_ptr callee) { + auto expr = std::make_unique(std::move(callee), parse_args()); + auto member = parse_member_expression(std::move(expr)); // foo.x().y + return is(token::open_paren) + ? parse_call_expression(std::move(member)) // foo.x()() + : std::move(member); + } + + statements parse_args() { + // comma-separated arguments list + expect(token::open_paren, "Expected ("); + statements args; + while (!is(token::close_paren)) { + statement_ptr arg; + // unpacking: *expr + if (peek().t == token::multiplicative_binary_operator && peek().value == "*") { + ++current; // consume * + arg = std::make_unique(parse_expression()); + } else { + arg = parse_expression(); + if (is(token::equals)) { + // keyword argument + // e.g., func(x = 5, y = a or b) + ++current; // consume equals + arg = std::make_unique(std::move(arg), parse_expression()); + } + } + args.push_back(std::move(arg)); + if (is(token::comma)) { + ++current; // consume comma + } + } + expect(token::close_paren, "Expected )"); + return args; + } + + statement_ptr parse_member_expression(statement_ptr object) { + while (is(token::dot) || is(token::open_square_bracket)) { + auto op = tokens[current++]; + bool computed = op.t == token::open_square_bracket; + statement_ptr prop; + if (computed) { + prop = parse_member_expression_arguments(); + expect(token::close_square_bracket, "Expected ]"); + } else { + prop = parse_primary_expression(); + } + object = std::make_unique(std::move(object), std::move(prop), computed); + } + return object; + } + + statement_ptr parse_member_expression_arguments() { + // NOTE: This also handles slice expressions colon-separated arguments list + // e.g., ['test'], [0], [:2], [1:], [1:2], [1:2:3] + statements slices; + bool is_slice = false; + while (!is(token::close_square_bracket)) { + if (is(token::colon)) { + // A case where a default is used + // e.g., [:2] will be parsed as [undefined, 2] + slices.push_back(nullptr); + ++current; // consume colon + is_slice = true; + } else { + slices.push_back(parse_expression()); + if (is(token::colon)) { + ++current; // consume colon after expression, if it exists + is_slice = true; + } + } + } + if (is_slice) { + statement_ptr start = slices.size() > 0 ? std::move(slices[0]) : nullptr; + statement_ptr stop = slices.size() > 1 ? std::move(slices[1]) : nullptr; + statement_ptr step = slices.size() > 2 ? std::move(slices[2]) : nullptr; + return std::make_unique(std::move(start), std::move(stop), std::move(step)); + } + return std::move(slices[0]); + } + + statement_ptr parse_primary_expression() { + auto t = tokens[current++]; + switch (t.t) { + case token::numeric_literal: + if (t.value.find('.') != std::string::npos) return std::make_unique(std::stod(t.value)); + return std::make_unique(std::stoll(t.value)); + case token::string_literal: { + std::string val = t.value; + while (is(token::string_literal)) val += tokens[current++].value; + return std::make_unique(val); + } + case token::identifier: + return std::make_unique(t.value); + case token::open_paren: { + auto expr = parse_expression_sequence(); + expect(token::close_paren, "Expected )"); + return expr; + } + case token::open_square_bracket: { + statements vals; + while (!is(token::close_square_bracket)) { + vals.push_back(parse_expression()); + if (is(token::comma)) current++; + } + current++; + return std::make_unique(std::move(vals)); + } + case token::open_curly_bracket: { + std::vector> pairs; + while (!is(token::close_curly_bracket)) { + auto key = parse_expression(); + expect(token::colon, "Expected :"); + pairs.push_back({std::move(key), parse_expression()}); + if (is(token::comma)) current++; + } + current++; + return std::make_unique(std::move(pairs)); + } + default: + throw std::runtime_error("Unexpected token: " + t.value); + } + } +}; + +statement_ptr parse(const std::vector& tokens) { + return parser(tokens).parse(); +} + +} // namespace jinja From 8cea1ed6b0d81fada93e60a4e41f2b31df5cc283 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Sat, 27 Dec 2025 12:55:01 +0100 Subject: [PATCH 06/47] parser ok --- common/CMakeLists.txt | 3 +++ common/jinja/jinja-lexer.h | 4 +++- common/jinja/jinja-parser.cpp | 9 +++++---- common/jinja/jinja-parser.h | 16 ++++++++++++++++ common/jinja/jinja-vm.cpp | 0 common/jinja/jinja-vm.h | 28 ++++++++++++---------------- tests/test-chat-jinja.cpp | 33 ++++++--------------------------- 7 files changed, 45 insertions(+), 48 deletions(-) create mode 100644 common/jinja/jinja-parser.h create mode 100644 common/jinja/jinja-vm.cpp diff --git a/common/CMakeLists.txt b/common/CMakeLists.txt index f7b99159e3..49ce25a842 100644 --- a/common/CMakeLists.txt +++ b/common/CMakeLists.txt @@ -83,6 +83,9 @@ add_library(${TARGET} STATIC speculative.h unicode.cpp unicode.h + jinja/jinja-lexer.cpp + jinja/jinja-parser.cpp + jinja/jinja-vm.cpp ) target_include_directories(${TARGET} PUBLIC . ../vendor) diff --git a/common/jinja/jinja-lexer.h b/common/jinja/jinja-lexer.h index 2011e487b1..3ed173a4f0 100644 --- a/common/jinja/jinja-lexer.h +++ b/common/jinja/jinja-lexer.h @@ -1,3 +1,5 @@ +#pragma once + #include #include #include @@ -48,7 +50,7 @@ struct token { std::string value; }; -std::string type_to_string(token::type t) { +static std::string type_to_string(token::type t) { switch (t) { case token::undefined: return "undefined"; case token::text: return "text"; diff --git a/common/jinja/jinja-parser.cpp b/common/jinja/jinja-parser.cpp index 07cb71fe11..5b20f010dc 100644 --- a/common/jinja/jinja-parser.cpp +++ b/common/jinja/jinja-parser.cpp @@ -1,5 +1,6 @@ #include "jinja-lexer.h" #include "jinja-vm.h" +#include "jinja-parser.h" #include #include @@ -22,12 +23,12 @@ class parser { public: parser(const std::vector & t) : tokens(t) {} - statement_ptr parse() { + program parse() { statements body; while (current < tokens.size()) { body.push_back(parse_any()); } - return std::make_unique(std::move(body)); + return program(std::move(body)); } private: @@ -320,7 +321,7 @@ private: statement_ptr parse_logical_or_expression() { auto left = parse_logical_and_expression(); while (is_identifier("or")) { - auto op = tokens[current++]; + token op = tokens[current++]; left = std::make_unique(op, std::move(left), parse_logical_and_expression()); } return left; @@ -538,7 +539,7 @@ private: } }; -statement_ptr parse(const std::vector& tokens) { +program parse_from_tokens(const std::vector & tokens) { return parser(tokens).parse(); } diff --git a/common/jinja/jinja-parser.h b/common/jinja/jinja-parser.h new file mode 100644 index 0000000000..ea212ad181 --- /dev/null +++ b/common/jinja/jinja-parser.h @@ -0,0 +1,16 @@ +#pragma once + +#include "jinja-lexer.h" +#include "jinja-vm.h" + +#include +#include +#include +#include +#include + +namespace jinja { + +program parse_from_tokens(const std::vector & tokens); + +} // namespace jinja diff --git a/common/jinja/jinja-vm.cpp b/common/jinja/jinja-vm.cpp new file mode 100644 index 0000000000..e69de29bb2 diff --git a/common/jinja/jinja-vm.h b/common/jinja/jinja-vm.h index 9ee2917531..b848ec4d9b 100644 --- a/common/jinja/jinja-vm.h +++ b/common/jinja/jinja-vm.h @@ -1,3 +1,4 @@ +#pragma once #include "jinja-lexer.h" #include @@ -181,26 +182,21 @@ struct identifier : public expression { // Literals -/** - * Abstract base class for all Literal expressions. - * Should not be instantiated directly. - */ -template -struct literal : public expression { - T value; - explicit literal(T && value) : value(std::move(value)) {} - std::string type() const override { return "Literal"; } -}; - -struct integer_literal : public literal { +struct integer_literal : public expression { + int64_t value; + explicit integer_literal(int64_t value) : value(value) {} std::string type() const override { return "IntegerLiteral"; } }; -struct float_literal : public literal { +struct float_literal : public expression { + double value; + explicit float_literal(double value) : value(value) {} std::string type() const override { return "FloatLiteral"; } }; -struct string_literal : public literal { +struct string_literal : public expression { + std::string value; + explicit string_literal(const std::string & value) : value(value) {} std::string type() const override { return "StringLiteral"; } }; @@ -240,11 +236,11 @@ struct object_literal : public expression { * of operations being determined by the operator. */ struct binary_expression : public expression { - token::type op; + token op; statement_ptr left; statement_ptr right; - binary_expression(token::type op, statement_ptr && left, statement_ptr && right) + binary_expression(token op, statement_ptr && left, statement_ptr && right) : op(op), left(std::move(left)), right(std::move(right)) { chk_type(this->left); chk_type(this->right); diff --git a/tests/test-chat-jinja.cpp b/tests/test-chat-jinja.cpp index 9fa0c7c817..ebebba37b1 100644 --- a/tests/test-chat-jinja.cpp +++ b/tests/test-chat-jinja.cpp @@ -7,9 +7,7 @@ #undef NDEBUG #include -#include "peg-parser.h" -#include "json-schema-to-grammar.h" -#include "jinja/jinja-compiler.h" +#include "jinja/jinja-parser.h" #include "jinja/jinja-lexer.h" int main(void) { @@ -26,30 +24,11 @@ int main(void) { std::cout << "token: type=" << static_cast(tok.t) << " text='" << tok.value << "'\n"; } - // jinja::compiler compiler; - // compiler.builder.set_root(compiler.root); - // auto parser = compiler.builder.build(); - - // auto grammar = build_grammar([&](const common_grammar_builder & builder0) { - // parser.build_grammar(builder0); - // }); - // printf("== GRAMMAR ==\n"); - // printf("%s\n", grammar.c_str()); - - // // printf("== DUMP ==\n"); - // // printf("%s\n", parser.dump(compiler.root.id()).c_str()); - - // printf("== PARSE ==\n"); - - // common_peg_parse_context ctx(contents); - // const auto result = parser.parse(ctx); - // if (!result.success()) { - // throw std::runtime_error("failed to parse, type = " + std::to_string(result.type)); - // } - - // ctx.ast.visit(result, [&](const common_peg_ast_node & node) { - // printf("node: rule='%s' text='%s'\n", node.rule.c_str(), std::string(node.text).c_str()); - // }); + jinja::program ast = jinja::parse_from_tokens(tokens); + std::cout << "\n=== AST ===\n"; + for (const auto & stmt : ast.body) { + std::cout << "stmt type: " << stmt->type() << "\n"; + } return 0; } From 7ad6eb39caf2ba75fd585f317a255ebc9ca47080 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Sat, 27 Dec 2025 16:00:07 +0100 Subject: [PATCH 07/47] binary_expression::execute --- common/jinja/jinja-parser.cpp | 2 +- common/jinja/jinja-value.h | 98 ++++++++++++++++++++++ common/jinja/jinja-vm.cpp | 151 ++++++++++++++++++++++++++++++++++ common/jinja/jinja-vm.h | 30 ++++--- 4 files changed, 267 insertions(+), 14 deletions(-) create mode 100644 common/jinja/jinja-value.h diff --git a/common/jinja/jinja-parser.cpp b/common/jinja/jinja-parser.cpp index 5b20f010dc..de61023560 100644 --- a/common/jinja/jinja-parser.cpp +++ b/common/jinja/jinja-parser.cpp @@ -474,7 +474,7 @@ private: while (!is(token::close_square_bracket)) { if (is(token::colon)) { // A case where a default is used - // e.g., [:2] will be parsed as [undefined, 2] + // e.g., [:2] will be parsed as [undefined, 2] slices.push_back(nullptr); ++current; // consume colon is_slice = true; diff --git a/common/jinja/jinja-value.h b/common/jinja/jinja-value.h new file mode 100644 index 0000000000..b06f465a1d --- /dev/null +++ b/common/jinja/jinja-value.h @@ -0,0 +1,98 @@ +#pragma once + +#include +#include +#include + + +namespace jinja { + +struct value_t; +using value = std::unique_ptr; + +struct value_t { + int64_t val_int; + double val_flt; + std::string val_str; + bool val_bool; + std::vector val_arr; + std::map val_obj; + + virtual std::string type() const { return ""; } + + virtual ~value_t() = default; + virtual int64_t as_int() const { throw std::runtime_error("Not an int value"); } + virtual double as_float() const { throw std::runtime_error("Not a float value"); } + virtual std::string as_string() const { throw std::runtime_error("Not a string value"); } + virtual bool as_bool() const { throw std::runtime_error("Not a bool value"); } + virtual const std::vector & as_array() const { throw std::runtime_error("Not an array value"); } + virtual const std::map & as_object() const { throw std::runtime_error("Not an object value"); } + virtual bool is_null() const { return false; } + virtual bool is_undefined() const { return false; } + + virtual bool operator==(const value & other) const { + // TODO + return false; + } + virtual bool operator!=(const value & other) const { + return !(*this == other); + } +}; + +struct value_int_t : public value_t { + value_int_t(int64_t v) { val_int = v; } + virtual std::string type() const override { return "Integer"; } + virtual int64_t as_int() const override { return val_int; } + virtual double as_float() const override { return static_cast(val_int); } +}; +using value_int = std::unique_ptr; + +struct value_float_t : public value_t { + value_float_t(double v) { val_flt = v; } + virtual std::string type() const override { return "Float"; } + virtual double as_float() const override { return val_flt; } + virtual int64_t as_int() const override { return static_cast(val_flt); } +}; +using value_float = std::unique_ptr; + +struct value_string_t : public value_t { + value_string_t(const std::string & v) { val_str = v; } + virtual std::string type() const override { return "String"; } + virtual std::string as_string() const override { return val_str; } +}; +using value_string = std::unique_ptr; + +struct value_bool_t : public value_t { + value_bool_t(bool v) { val_bool = v; } + virtual std::string type() const override { return "Boolean"; } + virtual bool as_bool() const override { return val_bool; } +}; +using value_bool = std::unique_ptr; + +struct value_array_t : public value_t { + value_array_t(const std::vector && v) { val_arr = std::move(v); } + virtual std::string type() const override { return "Array"; } + virtual const std::vector & as_array() const override { return val_arr; } +}; +using value_array = std::unique_ptr; + +struct value_object_t : public value_t { + value_object_t(const std::map & v) { val_obj = v; } + virtual std::string type() const override { return "Object"; } + virtual const std::map & as_object() const override { return val_obj; } +}; +using value_object = std::unique_ptr; + +struct value_null_t : public value_t { + virtual std::string type() const override { return "Null"; } + virtual bool is_null() const override { return true; } +}; +using value_null = std::unique_ptr; + +struct value_undefined_t : public value_t { + virtual std::string type() const override { return "Undefined"; } + virtual bool is_undefined() const override { return true; } +}; +using value_undefined = std::unique_ptr; + +} // namespace jinja diff --git a/common/jinja/jinja-vm.cpp b/common/jinja/jinja-vm.cpp index e69de29bb2..1c3ec49013 100644 --- a/common/jinja/jinja-vm.cpp +++ b/common/jinja/jinja-vm.cpp @@ -0,0 +1,151 @@ +#include "jinja-lexer.h" +#include "jinja-vm.h" +#include "jinja-parser.h" + +#include +#include +#include +#include + +namespace jinja { + +// Helper to check type without asserting (useful for logic) +template +static bool is_type(const value & ptr) { + return dynamic_cast(ptr.get()) != nullptr; +} + +struct vm { + context & ctx; + explicit vm(context & ctx) : ctx(ctx) {} + + void execute(program & prog) { + for (auto & stmt : prog.body) { + stmt->execute(ctx); + } + } +}; + +value binary_expression::execute(context & ctx) { + value left_val = left->execute(ctx); + + // Logical operators + if (op.value == "and") { + return left_val->as_bool() ? right->execute(ctx) : std::move(left_val); + } else if (op.value == "or") { + return left_val->as_bool() ? std::move(left_val) : right->execute(ctx); + } + + // Equality operators + value right_val = right->execute(ctx); + if (op.value == "==") { + return std::make_unique(left_val == right_val); + } else if (op.value == "!=") { + return std::make_unique(left_val != right_val); + } + + // Handle undefined and null values + if (is_type(left_val) || is_type(right_val)) { + if (is_type(right_val) && (op.value == "in" || op.value == "not in")) { + // Special case: `anything in undefined` is `false` and `anything not in undefined` is `true` + return std::make_unique(op.value == "not in"); + } + throw std::runtime_error("Cannot perform operation " + op.value + " on undefined values"); + } else if (is_type(left_val) || is_type(right_val)) { + throw std::runtime_error("Cannot perform operation on null values"); + } + + // String concatenation with ~ + if (op.value == "~") { + return std::make_unique(left_val->as_string() + right_val->as_string()); + } + + // Float operations + if ((is_type(left_val) || is_type(left_val)) && + (is_type(right_val) || is_type(right_val))) { + double a = left_val->as_float(); + double b = right_val->as_float(); + if (op.value == "+" || op.value == "-" || op.value == "*") { + double res = (op.value == "+") ? a + b : (op.value == "-") ? a - b : a * b; + bool is_float = is_type(left_val) || is_type(right_val); + if (is_float) { + return std::make_unique(res); + } else { + return std::make_unique(static_cast(res)); + } + } else if (op.value == "/") { + return std::make_unique(a / b); + } else if (op.value == "%") { + double rem = std::fmod(a, b); + bool is_float = is_type(left_val) || is_type(right_val); + if (is_float) { + return std::make_unique(rem); + } else { + return std::make_unique(static_cast(rem)); + } + } else if (op.value == "<") { + return std::make_unique(a < b); + } else if (op.value == ">") { + return std::make_unique(a > b); + } else if (op.value == ">=") { + return std::make_unique(a >= b); + } else if (op.value == "<=") { + return std::make_unique(a <= b); + } + } + + // Array operations + if (is_type(left_val) && is_type(right_val)) { + if (op.value == "+") { + auto& left_arr = left_val->as_array(); + auto& right_arr = right_val->as_array(); + std::vector result = left_arr; + for (auto & v : right_arr) { + result.push_back(std::move(v)); + } + return std::make_unique(result); + } + } else if (is_type(right_val)) { + auto & arr = right_val->as_array(); + bool member = std::find_if(arr.begin(), arr.end(), [&](const value& v) { return v == left_val; }) != arr.end(); + if (op.value == "in") { + return std::make_unique(member); + } else if (op.value == "not in") { + return std::make_unique(!member); + } + } + + // String concatenation + if (is_type(left_val) || is_type(right_val)) { + if (op.value == "+") { + return std::make_unique(left_val->as_string() + right_val->as_string()); + } + } + + // String membership + if (is_type(left_val) && is_type(right_val)) { + auto left_str = left_val->as_string(); + auto right_str = right_val->as_string(); + if (op.value == "in") { + return std::make_unique(right_str.find(left_str) != std::string::npos); + } else if (op.value == "not in") { + return std::make_unique(right_str.find(left_str) == std::string::npos); + } + } + + // String in object + if (is_type(left_val) && is_type(right_val)) { + auto key = left_val->as_string(); + auto & obj = right_val->as_object(); + bool has_key = obj.find(key) != obj.end(); + if (op.value == "in") { + return std::make_unique(has_key); + } else if (op.value == "not in") { + return std::make_unique(!has_key); + } + } + + throw std::runtime_error("Unknown operator \"" + op.value + "\" between " + left_val->type() + " and " + right_val->type()); +} + +} // namespace jinja diff --git a/common/jinja/jinja-vm.h b/common/jinja/jinja-vm.h index b848ec4d9b..a77f21cdfa 100644 --- a/common/jinja/jinja-vm.h +++ b/common/jinja/jinja-vm.h @@ -1,16 +1,20 @@ #pragma once + #include "jinja-lexer.h" +#include "jinja-value.h" #include #include #include #include +#include namespace jinja { struct context { - // TODO + std::ostringstream out; + std::map var; }; /** @@ -19,7 +23,7 @@ struct context { struct statement { virtual ~statement() = default; virtual std::string type() const { return "Statement"; } - virtual void execute(context & ctx) = 0; + virtual value execute(context & ctx) = 0; }; using statement_ptr = std::unique_ptr; @@ -46,7 +50,6 @@ static void chk_type(const statement_ptr & ptr) { */ struct expression : public statement { std::string type() const override { return "Expression"; } - void execute(context & ctx) override {} }; // Statements @@ -56,7 +59,7 @@ struct program : public statement { explicit program(statements && body) : body(std::move(body)) {} std::string type() const override { return "Program"; } - void execute(context & ctx) override {} + value execute(context & ctx) override {} }; struct if_statement : public statement { @@ -70,7 +73,7 @@ struct if_statement : public statement { } std::string type() const override { return "If"; } - void execute(context & ctx) override {} + value execute(context & ctx) override {} }; struct identifier; @@ -94,17 +97,17 @@ struct for_statement : public statement { } std::string type() const override { return "For"; } - void execute(context & ctx) override {} + value execute(context & ctx) override {} }; struct break_statement : public statement { std::string type() const override { return "Break"; } - void execute(context & ctx) override {} + value execute(context & ctx) override {} }; struct continue_statement : public statement { std::string type() const override { return "Continue"; } - void execute(context & ctx) override {} + value execute(context & ctx) override {} }; struct set_statement : public statement { @@ -119,7 +122,7 @@ struct set_statement : public statement { } std::string type() const override { return "Set"; } - void execute(context & ctx) override {} + value execute(context & ctx) override {} }; struct macro_statement : public statement { @@ -134,14 +137,14 @@ struct macro_statement : public statement { } std::string type() const override { return "Macro"; } - void execute(context & ctx) override {} + value execute(context & ctx) override {} }; struct comment_statement : public statement { std::string value; explicit comment_statement(const std::string & value) : value(value) {} std::string type() const override { return "Comment"; } - void execute(context & ctx) override {} + value execute(context & ctx) override {} }; // Expressions @@ -246,6 +249,7 @@ struct binary_expression : public expression { chk_type(this->right); } std::string type() const override { return "BinaryExpression"; } + value execute(context & ctx) override; }; /** @@ -273,7 +277,7 @@ struct filter_statement : public statement { chk_type(this->filter); } std::string type() const override { return "FilterStatement"; } - void execute(context & ctx) override {} + value execute(context & ctx) override {} }; /** @@ -369,7 +373,7 @@ struct call_statement : public statement { for (const auto& arg : this->caller_args) chk_type(arg); } std::string type() const override { return "CallStatement"; } - void execute(context & ctx) override {} + value execute(context & ctx) override {} }; struct ternary_expression : public expression { From 8d1e9a0d127b9bef10883bdb890f50f83799fda8 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Sat, 27 Dec 2025 16:06:23 +0100 Subject: [PATCH 08/47] shadow naming --- common/jinja/jinja-vm.h | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/common/jinja/jinja-vm.h b/common/jinja/jinja-vm.h index a77f21cdfa..58a71abe24 100644 --- a/common/jinja/jinja-vm.h +++ b/common/jinja/jinja-vm.h @@ -112,13 +112,13 @@ struct continue_statement : public statement { struct set_statement : public statement { statement_ptr assignee; - statement_ptr value; + statement_ptr val; statements body; set_statement(statement_ptr && assignee, statement_ptr && value, statements && body) - : assignee(std::move(assignee)), value(std::move(value)), body(std::move(body)) { + : assignee(std::move(assignee)), val(std::move(value)), body(std::move(body)) { chk_type(this->assignee); - chk_type(this->value); + chk_type(this->val); } std::string type() const override { return "Set"; } @@ -141,8 +141,8 @@ struct macro_statement : public statement { }; struct comment_statement : public statement { - std::string value; - explicit comment_statement(const std::string & value) : value(value) {} + std::string val; + explicit comment_statement(const std::string & v) : val(v) {} std::string type() const override { return "Comment"; } value execute(context & ctx) override {} }; @@ -266,6 +266,7 @@ struct filter_expression : public expression { chk_type(this->filter); } std::string type() const override { return "FilterExpression"; } + value execute(context & ctx) override; }; struct filter_statement : public statement { From d8ef00e610071267f90dca1c582b23eec042401d Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Sat, 27 Dec 2025 20:16:46 +0100 Subject: [PATCH 09/47] bin ops works! --- common/jinja/jinja-value.h | 84 ++++++++++++++++++++++++++++++---- common/jinja/jinja-vm.cpp | 94 ++++++++++++++++++++++++++++++-------- common/jinja/jinja-vm.h | 59 +++++++++++++++++------- tests/test-chat-jinja.cpp | 15 +++++- 4 files changed, 206 insertions(+), 46 deletions(-) diff --git a/common/jinja/jinja-value.h b/common/jinja/jinja-value.h index b06f465a1d..01cfffe529 100644 --- a/common/jinja/jinja-value.h +++ b/common/jinja/jinja-value.h @@ -15,12 +15,22 @@ struct value_t { double val_flt; std::string val_str; bool val_bool; - std::vector val_arr; - std::map val_obj; + + // array and object are stored as shared_ptr to allow reference access + // example: + // my_obj = {"a": 1, "b": 2} + // my_arr = [my_obj] + // my_obj["a"] = 3 + // print(my_arr[0]["a"]) # should print 3 + std::shared_ptr> val_arr; + std::shared_ptr> val_obj; + + value_t() = default; + value_t(const value_t &) = default; + virtual ~value_t() = default; virtual std::string type() const { return ""; } - virtual ~value_t() = default; virtual int64_t as_int() const { throw std::runtime_error("Not an int value"); } virtual double as_float() const { throw std::runtime_error("Not a float value"); } virtual std::string as_string() const { throw std::runtime_error("Not a string value"); } @@ -30,6 +40,10 @@ struct value_t { virtual bool is_null() const { return false; } virtual bool is_undefined() const { return false; } + virtual value clone() const { + return std::make_unique(*this); + } + virtual bool operator==(const value & other) const { // TODO return false; @@ -44,6 +58,8 @@ struct value_int_t : public value_t { virtual std::string type() const override { return "Integer"; } virtual int64_t as_int() const override { return val_int; } virtual double as_float() const override { return static_cast(val_int); } + virtual std::string as_string() const override { return std::to_string(val_int); } + virtual value clone() const override { return std::make_unique(*this); } }; using value_int = std::unique_ptr; @@ -52,6 +68,8 @@ struct value_float_t : public value_t { virtual std::string type() const override { return "Float"; } virtual double as_float() const override { return val_flt; } virtual int64_t as_int() const override { return static_cast(val_flt); } + virtual std::string as_string() const override { return std::to_string(val_flt); } + virtual value clone() const override { return std::make_unique(*this); } }; using value_float = std::unique_ptr; @@ -59,6 +77,7 @@ struct value_string_t : public value_t { value_string_t(const std::string & v) { val_str = v; } virtual std::string type() const override { return "String"; } virtual std::string as_string() const override { return val_str; } + virtual value clone() const override { return std::make_unique(*this); } }; using value_string = std::unique_ptr; @@ -66,32 +85,81 @@ struct value_bool_t : public value_t { value_bool_t(bool v) { val_bool = v; } virtual std::string type() const override { return "Boolean"; } virtual bool as_bool() const override { return val_bool; } + virtual std::string as_string() const override { return val_bool ? "True" : "False"; } + virtual value clone() const override { return std::make_unique(*this); } }; using value_bool = std::unique_ptr; struct value_array_t : public value_t { - value_array_t(const std::vector && v) { val_arr = std::move(v); } + value_array_t() { + val_arr = std::make_shared>(); + } + value_array_t(value & v) { + // point to the same underlying data + val_arr = v->val_arr; + } + value_array_t(value_array_t & other, size_t start = 0, size_t end = -1) { + val_arr = std::make_shared>(); + size_t sz = other.val_arr->size(); + if (end == static_cast(-1) || end > sz) { + end = sz; + } + if (start > end || start >= sz) { + return; + } + for (size_t i = start; i < end; i++) { + val_arr->push_back(other.val_arr->at(i)->clone()); + } + } virtual std::string type() const override { return "Array"; } - virtual const std::vector & as_array() const override { return val_arr; } + virtual const std::vector & as_array() const override { return *val_arr; } + virtual value clone() const override { + auto tmp = std::make_unique(); + tmp->val_arr = this->val_arr; + return tmp; + } }; using value_array = std::unique_ptr; -struct value_object_t : public value_t { - value_object_t(const std::map & v) { val_obj = v; } +/*struct value_object_t : public value_t { + value_object_t() { + val_obj = std::make_shared>(); + } + value_object_t(value & v) { + // point to the same underlying data + val_obj = v->val_obj; + } + value_object_t(const std::map & obj) { + val_obj = std::make_shared>(obj); + } virtual std::string type() const override { return "Object"; } - virtual const std::map & as_object() const override { return val_obj; } + virtual const std::map & as_object() const override { return *val_obj; } + virtual value clone() const override { + auto tmp = std::make_unique(); + tmp->val_obj = this->val_obj; + return tmp; + } +}; +using value_object = std::unique_ptr;*/ + +struct value_object_t : public value_t { + virtual std::string type() const override { return "TEST"; } + virtual bool is_null() const override { return true; } + virtual value clone() const override { return std::make_unique(*this); } }; using value_object = std::unique_ptr; struct value_null_t : public value_t { virtual std::string type() const override { return "Null"; } virtual bool is_null() const override { return true; } + virtual value clone() const override { return std::make_unique(*this); } }; using value_null = std::unique_ptr; struct value_undefined_t : public value_t { virtual std::string type() const override { return "Undefined"; } virtual bool is_undefined() const override { return true; } + virtual value clone() const override { return std::make_unique(*this); } }; using value_undefined = std::unique_ptr; diff --git a/common/jinja/jinja-vm.cpp b/common/jinja/jinja-vm.cpp index 1c3ec49013..aff6e90603 100644 --- a/common/jinja/jinja-vm.cpp +++ b/common/jinja/jinja-vm.cpp @@ -9,22 +9,27 @@ namespace jinja { -// Helper to check type without asserting (useful for logic) +// Helper to extract the inner type if T is unique_ptr, else T itself template -static bool is_type(const value & ptr) { - return dynamic_cast(ptr.get()) != nullptr; +struct extract_pointee { + using type = T; +}; + +template +struct extract_pointee> { + using type = U; +}; + +template +static bool is_type(const value& ptr) { + using PointeeType = typename extract_pointee::type; + return dynamic_cast(ptr.get()) != nullptr; } -struct vm { - context & ctx; - explicit vm(context & ctx) : ctx(ctx) {} - - void execute(program & prog) { - for (auto & stmt : prog.body) { - stmt->execute(ctx); - } - } -}; +template +static bool is_stmt(const statement_ptr & ptr) { + return dynamic_cast(ptr.get()) != nullptr; +} value binary_expression::execute(context & ctx) { value left_val = left->execute(ctx); @@ -97,13 +102,16 @@ value binary_expression::execute(context & ctx) { // Array operations if (is_type(left_val) && is_type(right_val)) { if (op.value == "+") { - auto& left_arr = left_val->as_array(); - auto& right_arr = right_val->as_array(); - std::vector result = left_arr; - for (auto & v : right_arr) { - result.push_back(std::move(v)); + auto & left_arr = left_val->as_array(); + auto & right_arr = right_val->as_array(); + auto result = std::make_unique(); + for (const auto & item : left_arr) { + result->val_arr->push_back(item->clone()); } - return std::make_unique(result); + for (const auto & item : right_arr) { + result->val_arr->push_back(item->clone()); + } + return result; } } else if (is_type(right_val)) { auto & arr = right_val->as_array(); @@ -148,4 +156,52 @@ value binary_expression::execute(context & ctx) { throw std::runtime_error("Unknown operator \"" + op.value + "\" between " + left_val->type() + " and " + right_val->type()); } +value filter_expression::execute(context & ctx) { + value input = operand->execute(ctx); + value filter_func = filter->execute(ctx); + + if (is_stmt(filter)) { + auto filter_val = dynamic_cast(filter.get())->value; + + if (filter_val == "to_json") { + // TODO: Implement to_json filter + throw std::runtime_error("to_json filter not implemented"); + } + + if (is_type(input)) { + auto & arr = input->as_array(); + if (filter_val == "list") { + return std::make_unique(input); + } else if (filter_val == "first") { + if (arr.empty()) { + return std::make_unique(); + } + return arr[0]->clone(); + } else if (filter_val == "last") { + if (arr.empty()) { + return std::make_unique(); + } + return arr[arr.size() - 1]->clone(); + } else if (filter_val == "length") { + return std::make_unique(static_cast(arr.size())); + } else { + // TODO: reverse, sort, join, string, unique + throw std::runtime_error("Unknown filter '" + filter_val + "' for array"); + } + + } else if (is_type(input)) { + auto str = input->as_string(); + // TODO + throw std::runtime_error("Unknown filter '" + filter_val + "' for string"); + + } else if (is_type(input) || is_type(input)) { + // TODO + throw std::runtime_error("Unknown filter '" + filter_val + "' for number"); + + } else { + throw std::runtime_error("Filters not supported for type " + input->type()); + } + } +} + } // namespace jinja diff --git a/common/jinja/jinja-vm.h b/common/jinja/jinja-vm.h index 58a71abe24..2c547294a8 100644 --- a/common/jinja/jinja-vm.h +++ b/common/jinja/jinja-vm.h @@ -23,7 +23,7 @@ struct context { struct statement { virtual ~statement() = default; virtual std::string type() const { return "Statement"; } - virtual value execute(context & ctx) = 0; + virtual value execute(context & ctx) { throw std::runtime_error("cannot exec " + type()); }; }; using statement_ptr = std::unique_ptr; @@ -186,44 +186,53 @@ struct identifier : public expression { // Literals struct integer_literal : public expression { - int64_t value; - explicit integer_literal(int64_t value) : value(value) {} + int64_t val; + explicit integer_literal(int64_t val) : val(val) {} std::string type() const override { return "IntegerLiteral"; } + value execute(context & ctx) override { + return std::make_unique(val); + } }; struct float_literal : public expression { - double value; - explicit float_literal(double value) : value(value) {} + double val; + explicit float_literal(double val) : val(val) {} std::string type() const override { return "FloatLiteral"; } + value execute(context & ctx) override { + return std::make_unique(val); + } }; struct string_literal : public expression { - std::string value; - explicit string_literal(const std::string & value) : value(value) {} + std::string val; + explicit string_literal(const std::string & val) : val(val) {} std::string type() const override { return "StringLiteral"; } + value execute(context & ctx) override { + return std::make_unique(val); + } }; struct array_literal : public expression { - statements value; - explicit array_literal(statements && value) : value(std::move(value)) { - for (const auto& item : this->value) chk_type(item); + statements val; + explicit array_literal(statements && val) : val(std::move(val)) { + for (const auto& item : this->val) chk_type(item); } std::string type() const override { return "ArrayLiteral"; } }; struct tuple_literal : public expression { - statements value; - explicit tuple_literal(statements && value) : value(std::move(value)) { - for (const auto& item : this->value) chk_type(item); + statements val; + explicit tuple_literal(statements && val) : val(std::move(val)) { + for (const auto & item : this->val) chk_type(item); } std::string type() const override { return "TupleLiteral"; } }; struct object_literal : public expression { - std::vector> value; - explicit object_literal(std::vector> && value) - : value(std::move(value)) { - for (const auto & pair : this->value) { + std::vector> val; + explicit object_literal(std::vector> && val) + : val(std::move(val)) { + for (const auto & pair : this->val) { chk_type(pair.first); chk_type(pair.second); } @@ -391,4 +400,20 @@ struct ternary_expression : public expression { std::string type() const override { return "Ternary"; } }; +////////////////////// + +struct vm { + context & ctx; + explicit vm(context & ctx) : ctx(ctx) {} + + std::vector execute(program & prog) { + std::vector results; + for (auto & stmt : prog.body) { + value res = stmt->execute(ctx); + results.push_back(std::move(res)); + } + return results; + } +}; + } // namespace jinja diff --git a/tests/test-chat-jinja.cpp b/tests/test-chat-jinja.cpp index ebebba37b1..e0b5d8f8d9 100644 --- a/tests/test-chat-jinja.cpp +++ b/tests/test-chat-jinja.cpp @@ -11,7 +11,9 @@ #include "jinja/jinja-lexer.h" int main(void) { - std::string contents = "{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '' + role + '\\n' + message['content'] | trim + '\\n' }}{% endfor %}{% if add_generation_prompt %}{{'model\\n'}}{% endif %}"; + //std::string contents = "{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '' + role + '\\n' + message['content'] | trim + '\\n' }}{% endfor %}{% if add_generation_prompt %}{{'model\\n'}}{% endif %}"; + + std::string contents = "{{ 'hi' + 'fi' }}"; std::cout << "=== INPUT ===\n" << contents << "\n\n"; @@ -24,11 +26,20 @@ int main(void) { std::cout << "token: type=" << static_cast(tok.t) << " text='" << tok.value << "'\n"; } - jinja::program ast = jinja::parse_from_tokens(tokens); std::cout << "\n=== AST ===\n"; + jinja::program ast = jinja::parse_from_tokens(tokens); for (const auto & stmt : ast.body) { std::cout << "stmt type: " << stmt->type() << "\n"; } + std::cout << "\n=== OUTPUT ===\n"; + jinja::context ctx; + jinja::vm vm(ctx); + auto results = vm.execute(ast); + for (const auto & res : results) { + std::cout << "result type: " << res->type() << "\n"; + std::cout << "result value: " << res->as_string() << "\n"; + } + return 0; } From 5a041e65b8aabf5238aec771035b96bbdeda144e Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Sat, 27 Dec 2025 20:38:06 +0100 Subject: [PATCH 10/47] fix map object --- common/jinja/jinja-value.h | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/common/jinja/jinja-value.h b/common/jinja/jinja-value.h index 01cfffe529..87ee91f693 100644 --- a/common/jinja/jinja-value.h +++ b/common/jinja/jinja-value.h @@ -121,7 +121,7 @@ struct value_array_t : public value_t { }; using value_array = std::unique_ptr; -/*struct value_object_t : public value_t { +struct value_object_t : public value_t { value_object_t() { val_obj = std::make_shared>(); } @@ -130,7 +130,10 @@ using value_array = std::unique_ptr; val_obj = v->val_obj; } value_object_t(const std::map & obj) { - val_obj = std::make_shared>(obj); + val_obj = std::make_shared>(); + for (const auto & pair : obj) { + (*val_obj)[pair.first] = pair.second->clone(); + } } virtual std::string type() const override { return "Object"; } virtual const std::map & as_object() const override { return *val_obj; } @@ -140,13 +143,6 @@ using value_array = std::unique_ptr; return tmp; } }; -using value_object = std::unique_ptr;*/ - -struct value_object_t : public value_t { - virtual std::string type() const override { return "TEST"; } - virtual bool is_null() const override { return true; } - virtual value clone() const override { return std::make_unique(*this); } -}; using value_object = std::unique_ptr; struct value_null_t : public value_t { From 15b3dbab05a85a892c2d0ebaf6f3b6913d3ea24e Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Sat, 27 Dec 2025 21:52:50 +0100 Subject: [PATCH 11/47] add string builtins --- common/CMakeLists.txt | 4 + common/jinja/jinja-value.h | 78 ++++++++++++++++ common/jinja/jinja-vm-builtins.cpp | 139 +++++++++++++++++++++++++++++ common/jinja/jinja-vm.cpp | 58 +++++------- common/jinja/jinja-vm.h | 2 +- tests/test-chat-jinja.cpp | 2 +- 6 files changed, 247 insertions(+), 36 deletions(-) create mode 100644 common/jinja/jinja-vm-builtins.cpp diff --git a/common/CMakeLists.txt b/common/CMakeLists.txt index 49ce25a842..4ed0df100f 100644 --- a/common/CMakeLists.txt +++ b/common/CMakeLists.txt @@ -84,8 +84,12 @@ add_library(${TARGET} STATIC unicode.cpp unicode.h jinja/jinja-lexer.cpp + jinja/jinja-lexer.h jinja/jinja-parser.cpp + jinja/jinja-parser.h jinja/jinja-vm.cpp + jinja/jinja-vm.h + jinja/jinja-vm-builtins.cpp ) target_include_directories(${TARGET} PUBLIC . ../vendor) diff --git a/common/jinja/jinja-value.h b/common/jinja/jinja-value.h index 87ee91f693..289acb1c7d 100644 --- a/common/jinja/jinja-value.h +++ b/common/jinja/jinja-value.h @@ -3,6 +3,8 @@ #include #include #include +#include +#include namespace jinja { @@ -10,6 +12,63 @@ namespace jinja { struct value_t; using value = std::unique_ptr; + +// Helper to check the type of a value +template +struct extract_pointee { + using type = T; +}; +template +struct extract_pointee> { + using type = U; +}; +template +bool is_val(const value & ptr) { + using PointeeType = typename extract_pointee::type; + return dynamic_cast(ptr.get()) != nullptr; +} +template +bool mk_val(Args&&... args) { + using PointeeType = typename extract_pointee::type; + return std::make_unique(std::forward(args)...); +} +template +void ensure_val(const value & ptr) { + if (!is_val(ptr)) { + throw std::runtime_error("Expected value of type " + std::string(typeid(T).name())); + } +} +// End Helper + + +struct func_args { + std::vector args; + void ensure_count(size_t count) const { + if (args.size() != count) { + throw std::runtime_error("Expected " + std::to_string(count) + " arguments, got " + std::to_string(args.size())); + } + } + // utility functions + template void ensure_vals() const { + ensure_count(1); + ensure_val(args[0]); + } + template void ensure_vals() const { + ensure_count(2); + ensure_val(args[0]); + ensure_val(args[1]); + } + template void ensure_vals() const { + ensure_count(3); + ensure_val(args[0]); + ensure_val(args[1]); + ensure_val(args[2]); + } +}; + +using func_handler = std::function; +using func_builtins = std::map; + struct value_t { int64_t val_int; double val_flt; @@ -25,6 +84,8 @@ struct value_t { std::shared_ptr> val_arr; std::shared_ptr> val_obj; + func_handler val_func; + value_t() = default; value_t(const value_t &) = default; virtual ~value_t() = default; @@ -37,8 +98,12 @@ struct value_t { virtual bool as_bool() const { throw std::runtime_error("Not a bool value"); } virtual const std::vector & as_array() const { throw std::runtime_error("Not an array value"); } virtual const std::map & as_object() const { throw std::runtime_error("Not an object value"); } + virtual value invoke(const func_args &) const { throw std::runtime_error("Not a function value"); } virtual bool is_null() const { return false; } virtual bool is_undefined() const { return false; } + virtual const func_builtins & get_builtins() const { + throw std::runtime_error("No builtins available for type " + type()); + } virtual value clone() const { return std::make_unique(*this); @@ -78,6 +143,7 @@ struct value_string_t : public value_t { virtual std::string type() const override { return "String"; } virtual std::string as_string() const override { return val_str; } virtual value clone() const override { return std::make_unique(*this); } + const func_builtins & get_builtins() const override; }; using value_string = std::unique_ptr; @@ -145,6 +211,18 @@ struct value_object_t : public value_t { }; using value_object = std::unique_ptr; +struct value_func_t : public value_t { + value_func_t(func_handler & func) { + val_func = func; + } + virtual value invoke(const func_args & args) const override { + return val_func(args); + } + virtual std::string type() const override { return "Function"; } + virtual value clone() const override { return std::make_unique(*this); } +}; +using value_func = std::unique_ptr; + struct value_null_t : public value_t { virtual std::string type() const override { return "Null"; } virtual bool is_null() const override { return true; } diff --git a/common/jinja/jinja-vm-builtins.cpp b/common/jinja/jinja-vm-builtins.cpp new file mode 100644 index 0000000000..85d0681867 --- /dev/null +++ b/common/jinja/jinja-vm-builtins.cpp @@ -0,0 +1,139 @@ +#include "jinja-lexer.h" +#include "jinja-vm.h" +#include "jinja-parser.h" +#include "jinja-value.h" + +#include +#include + +namespace jinja { + +static std::string string_strip(const std::string & str, bool left, bool right) { + size_t start = 0; + size_t end = str.length(); + if (left) { + while (start < end && isspace(static_cast(str[start]))) { + ++start; + } + } + if (right) { + while (end > start && isspace(static_cast(str[end - 1]))) { + --end; + } + } + return str.substr(start, end - start); +} + +static bool string_startswith(const std::string & str, const std::string & prefix) { + if (str.length() < prefix.length()) return false; + return str.compare(0, prefix.length(), prefix) == 0; +} + +static bool string_endswith(const std::string & str, const std::string & suffix) { + if (str.length() < suffix.length()) return false; + return str.compare(str.length() - suffix.length(), suffix.length(), suffix) == 0; +} + +const func_builtins & value_string_t::get_builtins() const { + static const func_builtins builtins = { + {"upper", [](const func_args & args) -> value { + args.ensure_vals(); + std::string str = args.args[0]->as_string(); + std::transform(str.begin(), str.end(), str.begin(), ::toupper); + return std::make_unique(str); + }}, + {"lower", [](const func_args & args) -> value { + args.ensure_vals(); + std::string str = args.args[0]->as_string(); + std::transform(str.begin(), str.end(), str.begin(), ::tolower); + return std::make_unique(str); + }}, + {"strip", [](const func_args & args) -> value { + args.ensure_vals(); + std::string str = args.args[0]->as_string(); + return std::make_unique(string_strip(str, true, true)); + }}, + {"rstrip", [](const func_args & args) -> value { + args.ensure_vals(); + std::string str = args.args[0]->as_string(); + return std::make_unique(string_strip(str, false, true)); + }}, + {"lstrip", [](const func_args & args) -> value { + args.ensure_vals(); + std::string str = args.args[0]->as_string(); + return std::make_unique(string_strip(str, true, false)); + }}, + {"title", [](const func_args & args) -> value { + args.ensure_vals(); + std::string str = args.args[0]->as_string(); + bool capitalize_next = true; + for (char &c : str) { + if (isspace(static_cast(c))) { + capitalize_next = true; + } else if (capitalize_next) { + c = ::toupper(static_cast(c)); + capitalize_next = false; + } else { + c = ::tolower(static_cast(c)); + } + } + return std::make_unique(str); + }}, + {"capitalize", [](const func_args & args) -> value { + args.ensure_vals(); + std::string str = args.args[0]->as_string(); + if (!str.empty()) { + str[0] = ::toupper(static_cast(str[0])); + std::transform(str.begin() + 1, str.end(), str.begin() + 1, ::tolower); + } + return std::make_unique(str); + }}, + {"length", [](const func_args & args) -> value { + args.ensure_vals(); + std::string str = args.args[0]->as_string(); + return std::make_unique(str.length()); + }}, + {"startswith", [](const func_args & args) -> value { + args.ensure_vals(); + std::string str = args.args[0]->as_string(); + std::string prefix = args.args[1]->as_string(); + return std::make_unique(string_startswith(str, prefix)); + }}, + {"endswith", [](const func_args & args) -> value { + args.ensure_vals(); + std::string str = args.args[0]->as_string(); + std::string suffix = args.args[1]->as_string(); + return std::make_unique(string_endswith(str, suffix)); + }}, + {"split", [](const func_args & args) -> value { + args.ensure_vals(); + std::string str = args.args[0]->as_string(); + std::string delim = (args.args.size() > 1) ? args.args[1]->as_string() : " "; + auto result = std::make_unique(); + size_t pos = 0; + std::string token; + while ((pos = str.find(delim)) != std::string::npos) { + token = str.substr(0, pos); + result->val_arr->push_back(std::make_unique(token)); + str.erase(0, pos + delim.length()); + } + result->val_arr->push_back(std::make_unique(str)); + return std::move(result); + }}, + {"replace", [](const func_args & args) -> value { + args.ensure_vals(); + std::string str = args.args[0]->as_string(); + std::string old_str = args.args[1]->as_string(); + std::string new_str = args.args[2]->as_string(); + size_t pos = 0; + while ((pos = str.find(old_str, pos)) != std::string::npos) { + str.replace(pos, old_str.length(), new_str); + pos += new_str.length(); + } + return std::make_unique(str); + }}, + }; + return builtins; +}; + +} // namespace jinja diff --git a/common/jinja/jinja-vm.cpp b/common/jinja/jinja-vm.cpp index aff6e90603..25106f1e4a 100644 --- a/common/jinja/jinja-vm.cpp +++ b/common/jinja/jinja-vm.cpp @@ -1,6 +1,7 @@ #include "jinja-lexer.h" #include "jinja-vm.h" #include "jinja-parser.h" +#include "jinja-value.h" #include #include @@ -9,23 +10,6 @@ namespace jinja { -// Helper to extract the inner type if T is unique_ptr, else T itself -template -struct extract_pointee { - using type = T; -}; - -template -struct extract_pointee> { - using type = U; -}; - -template -static bool is_type(const value& ptr) { - using PointeeType = typename extract_pointee::type; - return dynamic_cast(ptr.get()) != nullptr; -} - template static bool is_stmt(const statement_ptr & ptr) { return dynamic_cast(ptr.get()) != nullptr; @@ -50,13 +34,13 @@ value binary_expression::execute(context & ctx) { } // Handle undefined and null values - if (is_type(left_val) || is_type(right_val)) { - if (is_type(right_val) && (op.value == "in" || op.value == "not in")) { + if (is_val(left_val) || is_val(right_val)) { + if (is_val(right_val) && (op.value == "in" || op.value == "not in")) { // Special case: `anything in undefined` is `false` and `anything not in undefined` is `true` return std::make_unique(op.value == "not in"); } throw std::runtime_error("Cannot perform operation " + op.value + " on undefined values"); - } else if (is_type(left_val) || is_type(right_val)) { + } else if (is_val(left_val) || is_val(right_val)) { throw std::runtime_error("Cannot perform operation on null values"); } @@ -66,13 +50,13 @@ value binary_expression::execute(context & ctx) { } // Float operations - if ((is_type(left_val) || is_type(left_val)) && - (is_type(right_val) || is_type(right_val))) { + if ((is_val(left_val) || is_val(left_val)) && + (is_val(right_val) || is_val(right_val))) { double a = left_val->as_float(); double b = right_val->as_float(); if (op.value == "+" || op.value == "-" || op.value == "*") { double res = (op.value == "+") ? a + b : (op.value == "-") ? a - b : a * b; - bool is_float = is_type(left_val) || is_type(right_val); + bool is_float = is_val(left_val) || is_val(right_val); if (is_float) { return std::make_unique(res); } else { @@ -82,7 +66,7 @@ value binary_expression::execute(context & ctx) { return std::make_unique(a / b); } else if (op.value == "%") { double rem = std::fmod(a, b); - bool is_float = is_type(left_val) || is_type(right_val); + bool is_float = is_val(left_val) || is_val(right_val); if (is_float) { return std::make_unique(rem); } else { @@ -100,7 +84,7 @@ value binary_expression::execute(context & ctx) { } // Array operations - if (is_type(left_val) && is_type(right_val)) { + if (is_val(left_val) && is_val(right_val)) { if (op.value == "+") { auto & left_arr = left_val->as_array(); auto & right_arr = right_val->as_array(); @@ -113,7 +97,7 @@ value binary_expression::execute(context & ctx) { } return result; } - } else if (is_type(right_val)) { + } else if (is_val(right_val)) { auto & arr = right_val->as_array(); bool member = std::find_if(arr.begin(), arr.end(), [&](const value& v) { return v == left_val; }) != arr.end(); if (op.value == "in") { @@ -124,14 +108,14 @@ value binary_expression::execute(context & ctx) { } // String concatenation - if (is_type(left_val) || is_type(right_val)) { + if (is_val(left_val) || is_val(right_val)) { if (op.value == "+") { return std::make_unique(left_val->as_string() + right_val->as_string()); } } // String membership - if (is_type(left_val) && is_type(right_val)) { + if (is_val(left_val) && is_val(right_val)) { auto left_str = left_val->as_string(); auto right_str = right_val->as_string(); if (op.value == "in") { @@ -142,7 +126,7 @@ value binary_expression::execute(context & ctx) { } // String in object - if (is_type(left_val) && is_type(right_val)) { + if (is_val(left_val) && is_val(right_val)) { auto key = left_val->as_string(); auto & obj = right_val->as_object(); bool has_key = obj.find(key) != obj.end(); @@ -158,7 +142,7 @@ value binary_expression::execute(context & ctx) { value filter_expression::execute(context & ctx) { value input = operand->execute(ctx); - value filter_func = filter->execute(ctx); + // value filter_func = filter->execute(ctx); if (is_stmt(filter)) { auto filter_val = dynamic_cast(filter.get())->value; @@ -168,7 +152,7 @@ value filter_expression::execute(context & ctx) { throw std::runtime_error("to_json filter not implemented"); } - if (is_type(input)) { + if (is_val(input)) { auto & arr = input->as_array(); if (filter_val == "list") { return std::make_unique(input); @@ -189,12 +173,18 @@ value filter_expression::execute(context & ctx) { throw std::runtime_error("Unknown filter '" + filter_val + "' for array"); } - } else if (is_type(input)) { + } else if (is_val(input)) { auto str = input->as_string(); - // TODO + auto builtins = input->get_builtins(); + auto it = builtins.find(filter_val); + if (it != builtins.end()) { + func_args args; + args.args.push_back(input->clone()); + return it->second(args); + } throw std::runtime_error("Unknown filter '" + filter_val + "' for string"); - } else if (is_type(input) || is_type(input)) { + } else if (is_val(input) || is_val(input)) { // TODO throw std::runtime_error("Unknown filter '" + filter_val + "' for number"); diff --git a/common/jinja/jinja-vm.h b/common/jinja/jinja-vm.h index 2c547294a8..ac5d679e88 100644 --- a/common/jinja/jinja-vm.h +++ b/common/jinja/jinja-vm.h @@ -23,7 +23,7 @@ struct context { struct statement { virtual ~statement() = default; virtual std::string type() const { return "Statement"; } - virtual value execute(context & ctx) { throw std::runtime_error("cannot exec " + type()); }; + virtual value execute(context & ctx) { throw std::runtime_error("cannot exec " + type()); } }; using statement_ptr = std::unique_ptr; diff --git a/tests/test-chat-jinja.cpp b/tests/test-chat-jinja.cpp index e0b5d8f8d9..3a8fc0cd87 100644 --- a/tests/test-chat-jinja.cpp +++ b/tests/test-chat-jinja.cpp @@ -13,7 +13,7 @@ int main(void) { //std::string contents = "{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '' + role + '\\n' + message['content'] | trim + '\\n' }}{% endfor %}{% if add_generation_prompt %}{{'model\\n'}}{% endif %}"; - std::string contents = "{{ 'hi' + 'fi' }}"; + std::string contents = "{{ ('hi' + 'fi') | upper }}"; std::cout << "=== INPUT ===\n" << contents << "\n\n"; From 7ed11f78f94d57f618223b7cabbe9dc8f75930fd Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Sat, 27 Dec 2025 22:10:45 +0100 Subject: [PATCH 12/47] add more builtins --- common/jinja/jinja-value.h | 16 ++- common/jinja/jinja-vm-builtins.cpp | 172 +++++++++++++++++++++++++++++ common/jinja/jinja-vm.cpp | 57 ++++++---- 3 files changed, 220 insertions(+), 25 deletions(-) diff --git a/common/jinja/jinja-value.h b/common/jinja/jinja-value.h index 289acb1c7d..ac742b2f44 100644 --- a/common/jinja/jinja-value.h +++ b/common/jinja/jinja-value.h @@ -118,6 +118,7 @@ struct value_t { } }; + struct value_int_t : public value_t { value_int_t(int64_t v) { val_int = v; } virtual std::string type() const override { return "Integer"; } @@ -125,9 +126,11 @@ struct value_int_t : public value_t { virtual double as_float() const override { return static_cast(val_int); } virtual std::string as_string() const override { return std::to_string(val_int); } virtual value clone() const override { return std::make_unique(*this); } + virtual const func_builtins & get_builtins() const override; }; using value_int = std::unique_ptr; + struct value_float_t : public value_t { value_float_t(double v) { val_flt = v; } virtual std::string type() const override { return "Float"; } @@ -135,27 +138,32 @@ struct value_float_t : public value_t { virtual int64_t as_int() const override { return static_cast(val_flt); } virtual std::string as_string() const override { return std::to_string(val_flt); } virtual value clone() const override { return std::make_unique(*this); } + virtual const func_builtins & get_builtins() const override; }; using value_float = std::unique_ptr; + struct value_string_t : public value_t { value_string_t(const std::string & v) { val_str = v; } virtual std::string type() const override { return "String"; } virtual std::string as_string() const override { return val_str; } virtual value clone() const override { return std::make_unique(*this); } - const func_builtins & get_builtins() const override; + virtual const func_builtins & get_builtins() const override; }; using value_string = std::unique_ptr; + struct value_bool_t : public value_t { value_bool_t(bool v) { val_bool = v; } virtual std::string type() const override { return "Boolean"; } virtual bool as_bool() const override { return val_bool; } virtual std::string as_string() const override { return val_bool ? "True" : "False"; } virtual value clone() const override { return std::make_unique(*this); } + virtual const func_builtins & get_builtins() const override; }; using value_bool = std::unique_ptr; + struct value_array_t : public value_t { value_array_t() { val_arr = std::make_shared>(); @@ -184,9 +192,11 @@ struct value_array_t : public value_t { tmp->val_arr = this->val_arr; return tmp; } + virtual const func_builtins & get_builtins() const override; }; using value_array = std::unique_ptr; + struct value_object_t : public value_t { value_object_t() { val_obj = std::make_shared>(); @@ -208,9 +218,11 @@ struct value_object_t : public value_t { tmp->val_obj = this->val_obj; return tmp; } + virtual const func_builtins & get_builtins() const override; }; using value_object = std::unique_ptr; + struct value_func_t : public value_t { value_func_t(func_handler & func) { val_func = func; @@ -223,6 +235,7 @@ struct value_func_t : public value_t { }; using value_func = std::unique_ptr; + struct value_null_t : public value_t { virtual std::string type() const override { return "Null"; } virtual bool is_null() const override { return true; } @@ -230,6 +243,7 @@ struct value_null_t : public value_t { }; using value_null = std::unique_ptr; + struct value_undefined_t : public value_t { virtual std::string type() const override { return "Undefined"; } virtual bool is_undefined() const override { return true; } diff --git a/common/jinja/jinja-vm-builtins.cpp b/common/jinja/jinja-vm-builtins.cpp index 85d0681867..c369455fde 100644 --- a/common/jinja/jinja-vm-builtins.cpp +++ b/common/jinja/jinja-vm-builtins.cpp @@ -8,6 +8,40 @@ namespace jinja { +const func_builtins & value_int_t::get_builtins() const { + static const func_builtins builtins = { + {"abs", [](const func_args & args) -> value { + args.ensure_vals(); + int64_t val = args.args[0]->as_int(); + return std::make_unique(val < 0 ? -val : val); + }}, + {"float", [](const func_args & args) -> value { + args.ensure_vals(); + double val = static_cast(args.args[0]->as_int()); + return std::make_unique(val); + }}, + }; + return builtins; +} + + +const func_builtins & value_float_t::get_builtins() const { + static const func_builtins builtins = { + {"abs", [](const func_args & args) -> value { + args.ensure_vals(); + double val = args.args[0]->as_float(); + return std::make_unique(val < 0.0 ? -val : val); + }}, + {"int", [](const func_args & args) -> value { + args.ensure_vals(); + int64_t val = static_cast(args.args[0]->as_float()); + return std::make_unique(val); + }}, + }; + return builtins; +} + + static std::string string_strip(const std::string & str, bool left, bool right) { size_t start = 0; size_t end = str.length(); @@ -132,8 +166,146 @@ const func_builtins & value_string_t::get_builtins() const { } return std::make_unique(str); }}, + {"int", [](const func_args & args) -> value { + args.ensure_vals(); + std::string str = args.args[0]->as_string(); + try { + return std::make_unique(std::stoi(str)); + } catch (...) { + throw std::runtime_error("Cannot convert string '" + str + "' to int"); + } + }}, + {"float", [](const func_args & args) -> value { + args.ensure_vals(); + std::string str = args.args[0]->as_string(); + try { + return std::make_unique(std::stod(str)); + } catch (...) { + throw std::runtime_error("Cannot convert string '" + str + "' to float"); + } + }}, + {"string", [](const func_args & args) -> value { + // no-op + args.ensure_vals(); + return std::make_unique(args.args[0]->as_string()); + }}, + {"indent", [](const func_args & args) -> value { + throw std::runtime_error("indent builtin not implemented"); + }}, + {"join", [](const func_args & args) -> value { + throw std::runtime_error("join builtin not implemented"); + }}, }; return builtins; }; + +const func_builtins & value_bool_t::get_builtins() const { + static const func_builtins builtins = { + {"int", [](const func_args & args) -> value { + args.ensure_vals(); + bool val = args.args[0]->as_bool(); + return std::make_unique(val ? 1 : 0); + }}, + {"float", [](const func_args & args) -> value { + args.ensure_vals(); + bool val = args.args[0]->as_bool(); + return std::make_unique(val ? 1.0 : 0.0); + }}, + {"string", [](const func_args & args) -> value { + args.ensure_vals(); + bool val = args.args[0]->as_bool(); + return std::make_unique(val ? "True" : "False"); + }}, + }; + return builtins; +} + + +const func_builtins & value_array_t::get_builtins() const { + static const func_builtins builtins = { + {"list", [](const func_args & args) -> value { + args.ensure_vals(); + const auto & arr = args.args[0]->as_array(); + auto result = std::make_unique(); + for (const auto& v : arr) { + result->val_arr->push_back(v->clone()); + } + return result; + }}, + {"first", [](const func_args & args) -> value { + args.ensure_vals(); + const auto & arr = args.args[0]->as_array(); + if (arr.empty()) { + return std::make_unique(); + } + return arr[0]->clone(); + }}, + {"last", [](const func_args & args) -> value { + args.ensure_vals(); + const auto & arr = args.args[0]->as_array(); + if (arr.empty()) { + return std::make_unique(); + } + return arr[arr.size() - 1]->clone(); + }}, + {"length", [](const func_args & args) -> value { + args.ensure_vals(); + const auto & arr = args.args[0]->as_array(); + return std::make_unique(static_cast(arr.size())); + }}, + // TODO: reverse, sort, join, string, unique + }; + return builtins; +} + + +const func_builtins & value_object_t::get_builtins() const { + static const func_builtins builtins = { + {"get", [](const func_args & args) -> value { + args.ensure_vals(); // TODO: add default value + const auto & obj = args.args[0]->as_object(); + std::string key = args.args[1]->as_string(); + auto it = obj.find(key); + if (it != obj.end()) { + return it->second->clone(); + } else { + return std::make_unique(); + } + }}, + {"keys", [](const func_args & args) -> value { + args.ensure_vals(); + const auto & obj = args.args[0]->as_object(); + auto result = std::make_unique(); + for (const auto & pair : obj) { + result->val_arr->push_back(std::make_unique(pair.first)); + } + return result; + }}, + {"values", [](const func_args & args) -> value { + args.ensure_vals(); + const auto & obj = args.args[0]->as_object(); + auto result = std::make_unique(); + for (const auto & pair : obj) { + result->val_arr->push_back(pair.second->clone()); + } + return result; + }}, + {"items", [](const func_args & args) -> value { + args.ensure_vals(); + const auto & obj = args.args[0]->as_object(); + auto result = std::make_unique(); + for (const auto & pair : obj) { + auto item = std::make_unique(); + item->val_arr->push_back(std::make_unique(pair.first)); + item->val_arr->push_back(pair.second->clone()); + result->val_arr->push_back(std::move(item)); + } + return result; + }}, + }; + return builtins; +} + + } // namespace jinja diff --git a/common/jinja/jinja-vm.cpp b/common/jinja/jinja-vm.cpp index 25106f1e4a..bd1017f5db 100644 --- a/common/jinja/jinja-vm.cpp +++ b/common/jinja/jinja-vm.cpp @@ -142,7 +142,17 @@ value binary_expression::execute(context & ctx) { value filter_expression::execute(context & ctx) { value input = operand->execute(ctx); - // value filter_func = filter->execute(ctx); + + auto try_builtin = [&](const std::string & name) -> value { + auto builtins = input->get_builtins(); + auto it = builtins.find(name); + if (it != builtins.end()) { + func_args args; + args.args.push_back(input->clone()); + return it->second(args); + } + return nullptr; + }; if (is_stmt(filter)) { auto filter_val = dynamic_cast(filter.get())->value; @@ -154,43 +164,42 @@ value filter_expression::execute(context & ctx) { if (is_val(input)) { auto & arr = input->as_array(); - if (filter_val == "list") { - return std::make_unique(input); - } else if (filter_val == "first") { - if (arr.empty()) { - return std::make_unique(); - } - return arr[0]->clone(); - } else if (filter_val == "last") { - if (arr.empty()) { - return std::make_unique(); - } - return arr[arr.size() - 1]->clone(); - } else if (filter_val == "length") { - return std::make_unique(static_cast(arr.size())); - } else { - // TODO: reverse, sort, join, string, unique - throw std::runtime_error("Unknown filter '" + filter_val + "' for array"); + auto res = try_builtin(filter_val); + if (res) { + return res; } + throw std::runtime_error("Unknown filter '" + filter_val + "' for array"); } else if (is_val(input)) { auto str = input->as_string(); auto builtins = input->get_builtins(); - auto it = builtins.find(filter_val); - if (it != builtins.end()) { - func_args args; - args.args.push_back(input->clone()); - return it->second(args); + if (filter_val == "trim") { + filter_val = "strip"; // alias + } + auto res = try_builtin(filter_val); + if (res) { + return res; } throw std::runtime_error("Unknown filter '" + filter_val + "' for string"); } else if (is_val(input) || is_val(input)) { - // TODO + auto res = try_builtin(filter_val); + if (res) { + return res; + } throw std::runtime_error("Unknown filter '" + filter_val + "' for number"); } else { throw std::runtime_error("Filters not supported for type " + input->type()); } + + } else if (is_stmt(filter)) { + // TODO + // value filter_func = filter->execute(ctx); + throw std::runtime_error("Filter with arguments not implemented"); + + } else { + throw std::runtime_error("Invalid filter expression"); } } From da7bbe5813b936260be6568075dfe180b74d04e9 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Sat, 27 Dec 2025 22:25:19 +0100 Subject: [PATCH 13/47] wip --- common/jinja/jinja-vm.cpp | 20 ++++++++++++++++++++ common/jinja/jinja-vm.h | 25 ++++++++++++++++++------- tests/test-chat-jinja.cpp | 4 ++-- 3 files changed, 40 insertions(+), 9 deletions(-) diff --git a/common/jinja/jinja-vm.cpp b/common/jinja/jinja-vm.cpp index bd1017f5db..5ad2cd2826 100644 --- a/common/jinja/jinja-vm.cpp +++ b/common/jinja/jinja-vm.cpp @@ -203,4 +203,24 @@ value filter_expression::execute(context & ctx) { } } +value if_statement::execute(context & ctx) { + throw std::runtime_error("if_statement::execute not implemented"); +} + +value for_statement::execute(context & ctx) { + throw std::runtime_error("for_statement::execute not implemented"); +} + +value break_statement::execute(context & ctx) { + throw std::runtime_error("break_statement::execute not implemented"); +} + +value continue_statement::execute(context & ctx) { + throw std::runtime_error("continue_statement::execute not implemented"); +} + +value set_statement::execute(context & ctx) { + throw std::runtime_error("set_statement::execute not implemented"); +} + } // namespace jinja diff --git a/common/jinja/jinja-vm.h b/common/jinja/jinja-vm.h index ac5d679e88..5b620026a2 100644 --- a/common/jinja/jinja-vm.h +++ b/common/jinja/jinja-vm.h @@ -13,8 +13,17 @@ namespace jinja { struct context { - std::ostringstream out; std::map var; + + context() = default; + ~context() = default; + + context(const context & parent) { + // inherit variables (for example, when entering a new scope) + for (const auto & pair : parent.var) { + var[pair.first] = pair.second->clone(); + } + } }; /** @@ -59,7 +68,9 @@ struct program : public statement { explicit program(statements && body) : body(std::move(body)) {} std::string type() const override { return "Program"; } - value execute(context & ctx) override {} + value execute(context & ctx) override { + throw std::runtime_error("Cannot execute program directly, use jinja::vm instead"); + } }; struct if_statement : public statement { @@ -73,7 +84,7 @@ struct if_statement : public statement { } std::string type() const override { return "If"; } - value execute(context & ctx) override {} + value execute(context & ctx) override; }; struct identifier; @@ -97,17 +108,17 @@ struct for_statement : public statement { } std::string type() const override { return "For"; } - value execute(context & ctx) override {} + value execute(context & ctx) override; }; struct break_statement : public statement { std::string type() const override { return "Break"; } - value execute(context & ctx) override {} + value execute(context & ctx) override; }; struct continue_statement : public statement { std::string type() const override { return "Continue"; } - value execute(context & ctx) override {} + value execute(context & ctx) override; }; struct set_statement : public statement { @@ -122,7 +133,7 @@ struct set_statement : public statement { } std::string type() const override { return "Set"; } - value execute(context & ctx) override {} + value execute(context & ctx) override; }; struct macro_statement : public statement { diff --git a/tests/test-chat-jinja.cpp b/tests/test-chat-jinja.cpp index 3a8fc0cd87..e923da4481 100644 --- a/tests/test-chat-jinja.cpp +++ b/tests/test-chat-jinja.cpp @@ -11,9 +11,9 @@ #include "jinja/jinja-lexer.h" int main(void) { - //std::string contents = "{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '' + role + '\\n' + message['content'] | trim + '\\n' }}{% endfor %}{% if add_generation_prompt %}{{'model\\n'}}{% endif %}"; + std::string contents = "{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '' + role + '\\n' + message['content'] | trim + '\\n' }}{% endfor %}{% if add_generation_prompt %}{{'model\\n'}}{% endif %}"; - std::string contents = "{{ ('hi' + 'fi') | upper }}"; + //std::string contents = "{{ ('hi' + 'fi') | upper }}"; std::cout << "=== INPUT ===\n" << contents << "\n\n"; From c08f4ddf01776c85fdadda168f6dc17fec26b72e Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Sat, 27 Dec 2025 22:28:54 +0100 Subject: [PATCH 14/47] use mk_val --- common/jinja/jinja-value.h | 2 +- common/jinja/jinja-vm-builtins.cpp | 70 +++++++++++++++--------------- common/jinja/jinja-vm.cpp | 42 +++++++++--------- 3 files changed, 57 insertions(+), 57 deletions(-) diff --git a/common/jinja/jinja-value.h b/common/jinja/jinja-value.h index ac742b2f44..9a6acae7e2 100644 --- a/common/jinja/jinja-value.h +++ b/common/jinja/jinja-value.h @@ -28,7 +28,7 @@ bool is_val(const value & ptr) { return dynamic_cast(ptr.get()) != nullptr; } template -bool mk_val(Args&&... args) { +value mk_val(Args&&... args) { using PointeeType = typename extract_pointee::type; return std::make_unique(std::forward(args)...); } diff --git a/common/jinja/jinja-vm-builtins.cpp b/common/jinja/jinja-vm-builtins.cpp index c369455fde..cc2b2b39a0 100644 --- a/common/jinja/jinja-vm-builtins.cpp +++ b/common/jinja/jinja-vm-builtins.cpp @@ -13,12 +13,12 @@ const func_builtins & value_int_t::get_builtins() const { {"abs", [](const func_args & args) -> value { args.ensure_vals(); int64_t val = args.args[0]->as_int(); - return std::make_unique(val < 0 ? -val : val); + return mk_val(val < 0 ? -val : val); }}, {"float", [](const func_args & args) -> value { args.ensure_vals(); double val = static_cast(args.args[0]->as_int()); - return std::make_unique(val); + return mk_val(val); }}, }; return builtins; @@ -30,12 +30,12 @@ const func_builtins & value_float_t::get_builtins() const { {"abs", [](const func_args & args) -> value { args.ensure_vals(); double val = args.args[0]->as_float(); - return std::make_unique(val < 0.0 ? -val : val); + return mk_val(val < 0.0 ? -val : val); }}, {"int", [](const func_args & args) -> value { args.ensure_vals(); int64_t val = static_cast(args.args[0]->as_float()); - return std::make_unique(val); + return mk_val(val); }}, }; return builtins; @@ -74,28 +74,28 @@ const func_builtins & value_string_t::get_builtins() const { args.ensure_vals(); std::string str = args.args[0]->as_string(); std::transform(str.begin(), str.end(), str.begin(), ::toupper); - return std::make_unique(str); + return mk_val(str); }}, {"lower", [](const func_args & args) -> value { args.ensure_vals(); std::string str = args.args[0]->as_string(); std::transform(str.begin(), str.end(), str.begin(), ::tolower); - return std::make_unique(str); + return mk_val(str); }}, {"strip", [](const func_args & args) -> value { args.ensure_vals(); std::string str = args.args[0]->as_string(); - return std::make_unique(string_strip(str, true, true)); + return mk_val(string_strip(str, true, true)); }}, {"rstrip", [](const func_args & args) -> value { args.ensure_vals(); std::string str = args.args[0]->as_string(); - return std::make_unique(string_strip(str, false, true)); + return mk_val(string_strip(str, false, true)); }}, {"lstrip", [](const func_args & args) -> value { args.ensure_vals(); std::string str = args.args[0]->as_string(); - return std::make_unique(string_strip(str, true, false)); + return mk_val(string_strip(str, true, false)); }}, {"title", [](const func_args & args) -> value { args.ensure_vals(); @@ -111,7 +111,7 @@ const func_builtins & value_string_t::get_builtins() const { c = ::tolower(static_cast(c)); } } - return std::make_unique(str); + return mk_val(str); }}, {"capitalize", [](const func_args & args) -> value { args.ensure_vals(); @@ -120,38 +120,38 @@ const func_builtins & value_string_t::get_builtins() const { str[0] = ::toupper(static_cast(str[0])); std::transform(str.begin() + 1, str.end(), str.begin() + 1, ::tolower); } - return std::make_unique(str); + return mk_val(str); }}, {"length", [](const func_args & args) -> value { args.ensure_vals(); std::string str = args.args[0]->as_string(); - return std::make_unique(str.length()); + return mk_val(str.length()); }}, {"startswith", [](const func_args & args) -> value { args.ensure_vals(); std::string str = args.args[0]->as_string(); std::string prefix = args.args[1]->as_string(); - return std::make_unique(string_startswith(str, prefix)); + return mk_val(string_startswith(str, prefix)); }}, {"endswith", [](const func_args & args) -> value { args.ensure_vals(); std::string str = args.args[0]->as_string(); std::string suffix = args.args[1]->as_string(); - return std::make_unique(string_endswith(str, suffix)); + return mk_val(string_endswith(str, suffix)); }}, {"split", [](const func_args & args) -> value { args.ensure_vals(); std::string str = args.args[0]->as_string(); std::string delim = (args.args.size() > 1) ? args.args[1]->as_string() : " "; - auto result = std::make_unique(); + auto result = mk_val(); size_t pos = 0; std::string token; while ((pos = str.find(delim)) != std::string::npos) { token = str.substr(0, pos); - result->val_arr->push_back(std::make_unique(token)); + result->val_arr->push_back(mk_val(token)); str.erase(0, pos + delim.length()); } - result->val_arr->push_back(std::make_unique(str)); + result->val_arr->push_back(mk_val(str)); return std::move(result); }}, {"replace", [](const func_args & args) -> value { @@ -164,13 +164,13 @@ const func_builtins & value_string_t::get_builtins() const { str.replace(pos, old_str.length(), new_str); pos += new_str.length(); } - return std::make_unique(str); + return mk_val(str); }}, {"int", [](const func_args & args) -> value { args.ensure_vals(); std::string str = args.args[0]->as_string(); try { - return std::make_unique(std::stoi(str)); + return mk_val(std::stoi(str)); } catch (...) { throw std::runtime_error("Cannot convert string '" + str + "' to int"); } @@ -179,7 +179,7 @@ const func_builtins & value_string_t::get_builtins() const { args.ensure_vals(); std::string str = args.args[0]->as_string(); try { - return std::make_unique(std::stod(str)); + return mk_val(std::stod(str)); } catch (...) { throw std::runtime_error("Cannot convert string '" + str + "' to float"); } @@ -187,7 +187,7 @@ const func_builtins & value_string_t::get_builtins() const { {"string", [](const func_args & args) -> value { // no-op args.ensure_vals(); - return std::make_unique(args.args[0]->as_string()); + return mk_val(args.args[0]->as_string()); }}, {"indent", [](const func_args & args) -> value { throw std::runtime_error("indent builtin not implemented"); @@ -205,17 +205,17 @@ const func_builtins & value_bool_t::get_builtins() const { {"int", [](const func_args & args) -> value { args.ensure_vals(); bool val = args.args[0]->as_bool(); - return std::make_unique(val ? 1 : 0); + return mk_val(val ? 1 : 0); }}, {"float", [](const func_args & args) -> value { args.ensure_vals(); bool val = args.args[0]->as_bool(); - return std::make_unique(val ? 1.0 : 0.0); + return mk_val(val ? 1.0 : 0.0); }}, {"string", [](const func_args & args) -> value { args.ensure_vals(); bool val = args.args[0]->as_bool(); - return std::make_unique(val ? "True" : "False"); + return mk_val(val ? "True" : "False"); }}, }; return builtins; @@ -227,7 +227,7 @@ const func_builtins & value_array_t::get_builtins() const { {"list", [](const func_args & args) -> value { args.ensure_vals(); const auto & arr = args.args[0]->as_array(); - auto result = std::make_unique(); + auto result = mk_val(); for (const auto& v : arr) { result->val_arr->push_back(v->clone()); } @@ -237,7 +237,7 @@ const func_builtins & value_array_t::get_builtins() const { args.ensure_vals(); const auto & arr = args.args[0]->as_array(); if (arr.empty()) { - return std::make_unique(); + return mk_val(); } return arr[0]->clone(); }}, @@ -245,14 +245,14 @@ const func_builtins & value_array_t::get_builtins() const { args.ensure_vals(); const auto & arr = args.args[0]->as_array(); if (arr.empty()) { - return std::make_unique(); + return mk_val(); } return arr[arr.size() - 1]->clone(); }}, {"length", [](const func_args & args) -> value { args.ensure_vals(); const auto & arr = args.args[0]->as_array(); - return std::make_unique(static_cast(arr.size())); + return mk_val(static_cast(arr.size())); }}, // TODO: reverse, sort, join, string, unique }; @@ -270,22 +270,22 @@ const func_builtins & value_object_t::get_builtins() const { if (it != obj.end()) { return it->second->clone(); } else { - return std::make_unique(); + return mk_val(); } }}, {"keys", [](const func_args & args) -> value { args.ensure_vals(); const auto & obj = args.args[0]->as_object(); - auto result = std::make_unique(); + auto result = mk_val(); for (const auto & pair : obj) { - result->val_arr->push_back(std::make_unique(pair.first)); + result->val_arr->push_back(mk_val(pair.first)); } return result; }}, {"values", [](const func_args & args) -> value { args.ensure_vals(); const auto & obj = args.args[0]->as_object(); - auto result = std::make_unique(); + auto result = mk_val(); for (const auto & pair : obj) { result->val_arr->push_back(pair.second->clone()); } @@ -294,10 +294,10 @@ const func_builtins & value_object_t::get_builtins() const { {"items", [](const func_args & args) -> value { args.ensure_vals(); const auto & obj = args.args[0]->as_object(); - auto result = std::make_unique(); + auto result = mk_val(); for (const auto & pair : obj) { - auto item = std::make_unique(); - item->val_arr->push_back(std::make_unique(pair.first)); + auto item = mk_val(); + item->val_arr->push_back(mk_val(pair.first)); item->val_arr->push_back(pair.second->clone()); result->val_arr->push_back(std::move(item)); } diff --git a/common/jinja/jinja-vm.cpp b/common/jinja/jinja-vm.cpp index 5ad2cd2826..3a28977e6b 100644 --- a/common/jinja/jinja-vm.cpp +++ b/common/jinja/jinja-vm.cpp @@ -28,16 +28,16 @@ value binary_expression::execute(context & ctx) { // Equality operators value right_val = right->execute(ctx); if (op.value == "==") { - return std::make_unique(left_val == right_val); + return mk_val(left_val == right_val); } else if (op.value == "!=") { - return std::make_unique(left_val != right_val); + return mk_val(left_val != right_val); } // Handle undefined and null values if (is_val(left_val) || is_val(right_val)) { if (is_val(right_val) && (op.value == "in" || op.value == "not in")) { // Special case: `anything in undefined` is `false` and `anything not in undefined` is `true` - return std::make_unique(op.value == "not in"); + return mk_val(op.value == "not in"); } throw std::runtime_error("Cannot perform operation " + op.value + " on undefined values"); } else if (is_val(left_val) || is_val(right_val)) { @@ -46,7 +46,7 @@ value binary_expression::execute(context & ctx) { // String concatenation with ~ if (op.value == "~") { - return std::make_unique(left_val->as_string() + right_val->as_string()); + return mk_val(left_val->as_string() + right_val->as_string()); } // Float operations @@ -58,28 +58,28 @@ value binary_expression::execute(context & ctx) { double res = (op.value == "+") ? a + b : (op.value == "-") ? a - b : a * b; bool is_float = is_val(left_val) || is_val(right_val); if (is_float) { - return std::make_unique(res); + return mk_val(res); } else { - return std::make_unique(static_cast(res)); + return mk_val(static_cast(res)); } } else if (op.value == "/") { - return std::make_unique(a / b); + return mk_val(a / b); } else if (op.value == "%") { double rem = std::fmod(a, b); bool is_float = is_val(left_val) || is_val(right_val); if (is_float) { - return std::make_unique(rem); + return mk_val(rem); } else { - return std::make_unique(static_cast(rem)); + return mk_val(static_cast(rem)); } } else if (op.value == "<") { - return std::make_unique(a < b); + return mk_val(a < b); } else if (op.value == ">") { - return std::make_unique(a > b); + return mk_val(a > b); } else if (op.value == ">=") { - return std::make_unique(a >= b); + return mk_val(a >= b); } else if (op.value == "<=") { - return std::make_unique(a <= b); + return mk_val(a <= b); } } @@ -88,7 +88,7 @@ value binary_expression::execute(context & ctx) { if (op.value == "+") { auto & left_arr = left_val->as_array(); auto & right_arr = right_val->as_array(); - auto result = std::make_unique(); + auto result = mk_val(); for (const auto & item : left_arr) { result->val_arr->push_back(item->clone()); } @@ -101,16 +101,16 @@ value binary_expression::execute(context & ctx) { auto & arr = right_val->as_array(); bool member = std::find_if(arr.begin(), arr.end(), [&](const value& v) { return v == left_val; }) != arr.end(); if (op.value == "in") { - return std::make_unique(member); + return mk_val(member); } else if (op.value == "not in") { - return std::make_unique(!member); + return mk_val(!member); } } // String concatenation if (is_val(left_val) || is_val(right_val)) { if (op.value == "+") { - return std::make_unique(left_val->as_string() + right_val->as_string()); + return mk_val(left_val->as_string() + right_val->as_string()); } } @@ -119,9 +119,9 @@ value binary_expression::execute(context & ctx) { auto left_str = left_val->as_string(); auto right_str = right_val->as_string(); if (op.value == "in") { - return std::make_unique(right_str.find(left_str) != std::string::npos); + return mk_val(right_str.find(left_str) != std::string::npos); } else if (op.value == "not in") { - return std::make_unique(right_str.find(left_str) == std::string::npos); + return mk_val(right_str.find(left_str) == std::string::npos); } } @@ -131,9 +131,9 @@ value binary_expression::execute(context & ctx) { auto & obj = right_val->as_object(); bool has_key = obj.find(key) != obj.end(); if (op.value == "in") { - return std::make_unique(has_key); + return mk_val(has_key); } else if (op.value == "not in") { - return std::make_unique(!has_key); + return mk_val(!has_key); } } From 10835f2720b2e482f86616dc413a94cc98093acb Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Sat, 27 Dec 2025 23:25:20 +0100 Subject: [PATCH 15/47] eval with is_user_input --- common/jinja/jinja-value.h | 13 ++++ common/jinja/jinja-vm-builtins.cpp | 2 +- common/jinja/jinja-vm.cpp | 103 ++++++++++++++++++++++++++++- common/jinja/jinja-vm.h | 6 +- tests/test-chat-jinja.cpp | 33 +++++++-- 5 files changed, 147 insertions(+), 10 deletions(-) diff --git a/common/jinja/jinja-value.h b/common/jinja/jinja-value.h index 9a6acae7e2..a5362169c4 100644 --- a/common/jinja/jinja-value.h +++ b/common/jinja/jinja-value.h @@ -5,6 +5,7 @@ #include #include #include +#include namespace jinja { @@ -144,6 +145,8 @@ using value_float = std::unique_ptr; struct value_string_t : public value_t { + bool is_user_input = false; // may skip parsing special tokens if true + value_string_t(const std::string & v) { val_str = v; } virtual std::string type() const override { return "String"; } virtual std::string as_string() const override { return val_str; } @@ -192,6 +195,16 @@ struct value_array_t : public value_t { tmp->val_arr = this->val_arr; return tmp; } + virtual std::string as_string() const override { + std::ostringstream ss; + ss << "["; + for (size_t i = 0; i < val_arr->size(); i++) { + if (i > 0) ss << ", "; + ss << val_arr->at(i)->as_string(); + } + ss << "]"; + return ss.str(); + } virtual const func_builtins & get_builtins() const override; }; using value_array = std::unique_ptr; diff --git a/common/jinja/jinja-vm-builtins.cpp b/common/jinja/jinja-vm-builtins.cpp index cc2b2b39a0..860f67b629 100644 --- a/common/jinja/jinja-vm-builtins.cpp +++ b/common/jinja/jinja-vm-builtins.cpp @@ -197,7 +197,7 @@ const func_builtins & value_string_t::get_builtins() const { }}, }; return builtins; -}; +} const func_builtins & value_bool_t::get_builtins() const { diff --git a/common/jinja/jinja-vm.cpp b/common/jinja/jinja-vm.cpp index 3a28977e6b..73ad5bae0d 100644 --- a/common/jinja/jinja-vm.cpp +++ b/common/jinja/jinja-vm.cpp @@ -8,6 +8,9 @@ #include #include +#define JJ_DEBUG(msg, ...) printf("jinja-vm: " msg "\n", __VA_ARGS__) +//#define JJ_DEBUG(msg, ...) // no-op + namespace jinja { template @@ -15,6 +18,17 @@ static bool is_stmt(const statement_ptr & ptr) { return dynamic_cast(ptr.get()) != nullptr; } +value identifier::execute(context & ctx) { + auto it = ctx.var.find(val); + if (it != ctx.var.end()) { + JJ_DEBUG("Identifier '%s' found", val.c_str()); + return it->second->clone(); + } else { + JJ_DEBUG("Identifier '%s' not found, returning undefined", val.c_str()); + return mk_val(); + } +} + value binary_expression::execute(context & ctx) { value left_val = left->execute(ctx); @@ -151,11 +165,11 @@ value filter_expression::execute(context & ctx) { args.args.push_back(input->clone()); return it->second(args); } - return nullptr; + throw std::runtime_error("Unknown (built-in) filter '" + name + "' for type " + input->type()); }; if (is_stmt(filter)) { - auto filter_val = dynamic_cast(filter.get())->value; + auto filter_val = dynamic_cast(filter.get())->val; if (filter_val == "to_json") { // TODO: Implement to_json filter @@ -204,7 +218,15 @@ value filter_expression::execute(context & ctx) { } value if_statement::execute(context & ctx) { - throw std::runtime_error("if_statement::execute not implemented"); + value test_val = test->execute(ctx); + auto out = mk_val(); + if (test_val->as_bool()) { + for (auto & stmt : body) { + JJ_DEBUG("Executing if body statement of type %s", stmt->type().c_str()); + out->val_arr->push_back(stmt->execute(ctx)); + } + } + return out; } value for_statement::execute(context & ctx) { @@ -223,4 +245,79 @@ value set_statement::execute(context & ctx) { throw std::runtime_error("set_statement::execute not implemented"); } +value member_expression::execute(context & ctx) { + value object = this->object->execute(ctx); + + value property; + if (this->computed) { + property = this->property->execute(ctx); + } else { + property = mk_val(dynamic_cast(this->property.get())->val); + } + + value val = mk_val(); + + if (is_val(object)) { + if (!is_val(property)) { + throw std::runtime_error("Cannot access object with non-string: got " + property->type()); + } + auto key = property->as_string(); + auto & obj = object->as_object(); + auto it = obj.find(key); + if (it != obj.end()) { + val = it->second->clone(); + } else { + auto builtins = object->get_builtins(); + auto bit = builtins.find(key); + if (bit != builtins.end()) { + func_args args; + args.args.push_back(object->clone()); + val = bit->second(args); + } + } + + } else if (is_val(object) || is_val(object)) { + if (is_val(property)) { + int64_t index = property->as_int(); + if (is_val(object)) { + auto & arr = object->as_array(); + if (index >= 0 && index < static_cast(arr.size())) { + val = arr[index]->clone(); + } + } else { // value_string + auto str = object->as_string(); + if (index >= 0 && index < static_cast(str.size())) { + val = mk_val(std::string(1, str[index])); + } + } + } else if (is_val(property)) { + auto key = property->as_string(); + auto builtins = object->get_builtins(); + auto bit = builtins.find(key); + if (bit != builtins.end()) { + func_args args; + args.args.push_back(object->clone()); + val = bit->second(args); + } + } else { + throw std::runtime_error("Cannot access property with non-string/non-number: got " + property->type()); + } + + } else { + if (!is_val(property)) { + throw std::runtime_error("Cannot access property with non-string: got " + property->type()); + } + auto key = property->as_string(); + auto builtins = object->get_builtins(); + auto bit = builtins.find(key); + if (bit != builtins.end()) { + func_args args; + args.args.push_back(object->clone()); + val = bit->second(args); + } + } + + return val; +} + } // namespace jinja diff --git a/common/jinja/jinja-vm.h b/common/jinja/jinja-vm.h index 5b620026a2..d2e763b13b 100644 --- a/common/jinja/jinja-vm.h +++ b/common/jinja/jinja-vm.h @@ -171,6 +171,7 @@ struct member_expression : public expression { chk_type(this->property); } std::string type() const override { return "MemberExpression"; } + value execute(context & ctx) override; }; struct call_expression : public expression { @@ -189,9 +190,10 @@ struct call_expression : public expression { * Represents a user-defined variable or symbol in the template. */ struct identifier : public expression { - std::string value; - explicit identifier(const std::string & value) : value(value) {} + std::string val; + explicit identifier(const std::string & val) : val(val) {} std::string type() const override { return "Identifier"; } + value execute(context & ctx) override; }; // Literals diff --git a/tests/test-chat-jinja.cpp b/tests/test-chat-jinja.cpp index e923da4481..63048841c3 100644 --- a/tests/test-chat-jinja.cpp +++ b/tests/test-chat-jinja.cpp @@ -11,9 +11,11 @@ #include "jinja/jinja-lexer.h" int main(void) { - std::string contents = "{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '' + role + '\\n' + message['content'] | trim + '\\n' }}{% endfor %}{% if add_generation_prompt %}{{'model\\n'}}{% endif %}"; + //std::string contents = "{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '' + role + '\\n' + message['content'] | trim + '\\n' }}{% endfor %}{% if add_generation_prompt %}{{'model\\n'}}{% endif %}"; - //std::string contents = "{{ ('hi' + 'fi') | upper }}"; + //std::string contents = "{% if messages[0]['role'] != 'system' %}nice {{ messages[0]['content'] }}{% endif %}"; + + std::string contents = " {{ messages[0]['content'] }} "; std::cout << "=== INPUT ===\n" << contents << "\n\n"; @@ -34,11 +36,34 @@ int main(void) { std::cout << "\n=== OUTPUT ===\n"; jinja::context ctx; + + auto make_non_special_string = [](const std::string & s) { + jinja::value_string str_val = std::make_unique(s); + str_val->is_user_input = true; + return str_val; + }; + + jinja::value messages = jinja::mk_val(); + jinja::value msg1 = jinja::mk_val(); + (*msg1->val_obj)["role"] = make_non_special_string("user"); + (*msg1->val_obj)["content"] = make_non_special_string("Hello, how are you?"); + messages->val_arr->push_back(std::move(msg1)); + jinja::value msg2 = jinja::mk_val(); + (*msg2->val_obj)["role"] = make_non_special_string("assistant"); + (*msg2->val_obj)["content"] = make_non_special_string("I am fine, thank you!"); + messages->val_arr->push_back(std::move(msg2)); + + ctx.var["messages"] = std::move(messages); + jinja::vm vm(ctx); auto results = vm.execute(ast); for (const auto & res : results) { - std::cout << "result type: " << res->type() << "\n"; - std::cout << "result value: " << res->as_string() << "\n"; + auto str_ptr = dynamic_cast(res.get()); + std::string is_user_input = "false"; + if (str_ptr) { + is_user_input = str_ptr->is_user_input ? "true" : "false"; + } + std::cout << "result type: " << res->type() << " | value: " << res->as_string() << " | is_user_input: " << is_user_input << "\n"; } return 0; From 81310d29c1adfe1770443862abb7734d19d864e9 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Sun, 28 Dec 2025 12:04:23 +0100 Subject: [PATCH 16/47] render gemma tmpl ok --- common/jinja/jinja-parser.cpp | 4 +- common/jinja/jinja-value.h | 49 ++++-- common/jinja/jinja-vm-builtins.cpp | 17 +- common/jinja/jinja-vm.cpp | 250 +++++++++++++++++++++++++++-- common/jinja/jinja-vm.h | 45 ++++-- tests/test-chat-jinja.cpp | 13 +- 6 files changed, 330 insertions(+), 48 deletions(-) diff --git a/common/jinja/jinja-parser.cpp b/common/jinja/jinja-parser.cpp index de61023560..8b7058b8fa 100644 --- a/common/jinja/jinja-parser.cpp +++ b/common/jinja/jinja-parser.cpp @@ -142,11 +142,11 @@ private: } else if (name == "call") { statements caller_args; - bool has_caller_args = false; + // bool has_caller_args = false; if (is(token::open_paren)) { // Optional caller arguments, e.g. {% call(user) dump_users(...) %} caller_args = parse_args(); - has_caller_args = true; + // has_caller_args = true; } auto callee = parse_primary_expression(); if (!is_type(callee)) throw std::runtime_error("Expected identifier"); diff --git a/common/jinja/jinja-value.h b/common/jinja/jinja-value.h index a5362169c4..8b2d74ae35 100644 --- a/common/jinja/jinja-value.h +++ b/common/jinja/jinja-value.h @@ -28,8 +28,13 @@ bool is_val(const value & ptr) { using PointeeType = typename extract_pointee::type; return dynamic_cast(ptr.get()) != nullptr; } +template +bool is_val(const value_t * ptr) { + using PointeeType = typename extract_pointee::type; + return dynamic_cast(ptr) != nullptr; +} template -value mk_val(Args&&... args) { +std::unique_ptr::type> mk_val(Args&&... args) { using PointeeType = typename extract_pointee::type; return std::make_unique(std::forward(args)...); } @@ -70,6 +75,8 @@ struct func_args { using func_handler = std::function; using func_builtins = std::map; +bool value_compare(const value & a, const value & b); + struct value_t { int64_t val_int; double val_flt; @@ -93,12 +100,12 @@ struct value_t { virtual std::string type() const { return ""; } - virtual int64_t as_int() const { throw std::runtime_error("Not an int value"); } - virtual double as_float() const { throw std::runtime_error("Not a float value"); } - virtual std::string as_string() const { throw std::runtime_error("Not a string value"); } - virtual bool as_bool() const { throw std::runtime_error("Not a bool value"); } - virtual const std::vector & as_array() const { throw std::runtime_error("Not an array value"); } - virtual const std::map & as_object() const { throw std::runtime_error("Not an object value"); } + virtual int64_t as_int() const { throw std::runtime_error(type() + " is not an int value"); } + virtual double as_float() const { throw std::runtime_error(type() + " is not a float value"); } + virtual std::string as_string() const { throw std::runtime_error(type() + " is not a string value"); } + virtual bool as_bool() const { throw std::runtime_error(type() + " is not a bool value"); } + virtual const std::vector & as_array() const { throw std::runtime_error(type() + " is not an array value"); } + virtual const std::map & as_object() const { throw std::runtime_error(type() + " is not an object value"); } virtual value invoke(const func_args &) const { throw std::runtime_error("Not a function value"); } virtual bool is_null() const { return false; } virtual bool is_undefined() const { return false; } @@ -106,17 +113,11 @@ struct value_t { throw std::runtime_error("No builtins available for type " + type()); } + virtual std::string as_repr() const { return as_string(); } + virtual value clone() const { return std::make_unique(*this); } - - virtual bool operator==(const value & other) const { - // TODO - return false; - } - virtual bool operator!=(const value & other) const { - return !(*this == other); - } }; @@ -188,8 +189,12 @@ struct value_array_t : public value_t { val_arr->push_back(other.val_arr->at(i)->clone()); } } + void push_back(const value & val) { + val_arr->push_back(val->clone()); + } virtual std::string type() const override { return "Array"; } virtual const std::vector & as_array() const override { return *val_arr; } + // clone will also share the underlying data (point to the same vector) virtual value clone() const override { auto tmp = std::make_unique(); tmp->val_arr = this->val_arr; @@ -200,7 +205,7 @@ struct value_array_t : public value_t { ss << "["; for (size_t i = 0; i < val_arr->size(); i++) { if (i > 0) ss << ", "; - ss << val_arr->at(i)->as_string(); + ss << val_arr->at(i)->as_repr(); } ss << "]"; return ss.str(); @@ -224,8 +229,12 @@ struct value_object_t : public value_t { (*val_obj)[pair.first] = pair.second->clone(); } } + void insert(const std::string & key, const value & val) { + (*val_obj)[key] = val->clone(); + } virtual std::string type() const override { return "Object"; } virtual const std::map & as_object() const override { return *val_obj; } + // clone will also share the underlying data (point to the same map) virtual value clone() const override { auto tmp = std::make_unique(); tmp->val_obj = this->val_obj; @@ -244,6 +253,7 @@ struct value_func_t : public value_t { return val_func(args); } virtual std::string type() const override { return "Function"; } + virtual std::string as_repr() const override { return type(); } virtual value clone() const override { return std::make_unique(*this); } }; using value_func = std::unique_ptr; @@ -252,6 +262,8 @@ using value_func = std::unique_ptr; struct value_null_t : public value_t { virtual std::string type() const override { return "Null"; } virtual bool is_null() const override { return true; } + virtual bool as_bool() const override { return false; } + virtual std::string as_repr() const override { return type(); } virtual value clone() const override { return std::make_unique(*this); } }; using value_null = std::unique_ptr; @@ -260,8 +272,13 @@ using value_null = std::unique_ptr; struct value_undefined_t : public value_t { virtual std::string type() const override { return "Undefined"; } virtual bool is_undefined() const override { return true; } + virtual bool as_bool() const override { return false; } + virtual std::string as_repr() const override { return type(); } virtual value clone() const override { return std::make_unique(*this); } }; using value_undefined = std::unique_ptr; + +const func_builtins & global_builtins(); + } // namespace jinja diff --git a/common/jinja/jinja-vm-builtins.cpp b/common/jinja/jinja-vm-builtins.cpp index 860f67b629..493c71e25e 100644 --- a/common/jinja/jinja-vm-builtins.cpp +++ b/common/jinja/jinja-vm-builtins.cpp @@ -8,6 +8,18 @@ namespace jinja { +const func_builtins & global_builtins() { + static const func_builtins builtins = { + {"raise_exception", [](const func_args & args) -> value { + args.ensure_count(1); + std::string msg = args.args[0]->as_string(); + throw raised_exception("Jinja Exception: " + msg); + }}, + }; + return builtins; +} + + const func_builtins & value_int_t::get_builtins() const { static const func_builtins builtins = { {"abs", [](const func_args & args) -> value { @@ -189,10 +201,10 @@ const func_builtins & value_string_t::get_builtins() const { args.ensure_vals(); return mk_val(args.args[0]->as_string()); }}, - {"indent", [](const func_args & args) -> value { + {"indent", [](const func_args &) -> value { throw std::runtime_error("indent builtin not implemented"); }}, - {"join", [](const func_args & args) -> value { + {"join", [](const func_args &) -> value { throw std::runtime_error("join builtin not implemented"); }}, }; @@ -307,5 +319,4 @@ const func_builtins & value_object_t::get_builtins() const { return builtins; } - } // namespace jinja diff --git a/common/jinja/jinja-vm.cpp b/common/jinja/jinja-vm.cpp index 73ad5bae0d..7fb323c58b 100644 --- a/common/jinja/jinja-vm.cpp +++ b/common/jinja/jinja-vm.cpp @@ -18,11 +18,24 @@ static bool is_stmt(const statement_ptr & ptr) { return dynamic_cast(ptr.get()) != nullptr; } +static value_array exec_statements(const statements & stmts, context & ctx) { + auto result = mk_val(); + for (const auto & stmt : stmts) { + JJ_DEBUG("Executing statement of type %s", stmt->type().c_str()); + result->val_arr->push_back(stmt->execute(ctx)); + } + return result; +} + value identifier::execute(context & ctx) { auto it = ctx.var.find(val); + auto builtins = global_builtins(); if (it != ctx.var.end()) { JJ_DEBUG("Identifier '%s' found", val.c_str()); return it->second->clone(); + } else if (builtins.find(val) != builtins.end()) { + JJ_DEBUG("Identifier '%s' found in builtins", val.c_str()); + return mk_val(builtins.at(val)); } else { JJ_DEBUG("Identifier '%s' not found, returning undefined", val.c_str()); return mk_val(); @@ -31,6 +44,7 @@ value identifier::execute(context & ctx) { value binary_expression::execute(context & ctx) { value left_val = left->execute(ctx); + JJ_DEBUG("Executing binary expression with operator '%s'", op.value.c_str()); // Logical operators if (op.value == "and") { @@ -42,9 +56,9 @@ value binary_expression::execute(context & ctx) { // Equality operators value right_val = right->execute(ctx); if (op.value == "==") { - return mk_val(left_val == right_val); + return mk_val(value_compare(left_val, right_val)); } else if (op.value == "!=") { - return mk_val(left_val != right_val); + return mk_val(!value_compare(left_val, right_val)); } // Handle undefined and null values @@ -70,6 +84,7 @@ value binary_expression::execute(context & ctx) { double b = right_val->as_float(); if (op.value == "+" || op.value == "-" || op.value == "*") { double res = (op.value == "+") ? a + b : (op.value == "-") ? a - b : a * b; + JJ_DEBUG("Arithmetic operation: %f %s %f = %f", a, op.value.c_str(), b, res); bool is_float = is_val(left_val) || is_val(right_val); if (is_float) { return mk_val(res); @@ -80,6 +95,7 @@ value binary_expression::execute(context & ctx) { return mk_val(a / b); } else if (op.value == "%") { double rem = std::fmod(a, b); + JJ_DEBUG("Modulo operation: %f %% %f = %f", a, b, rem); bool is_float = is_val(left_val) || is_val(right_val); if (is_float) { return mk_val(rem); @@ -123,6 +139,7 @@ value binary_expression::execute(context & ctx) { // String concatenation if (is_val(left_val) || is_val(right_val)) { + JJ_DEBUG("%s", "String concatenation with + operator"); if (op.value == "+") { return mk_val(left_val->as_string() + right_val->as_string()); } @@ -177,7 +194,6 @@ value filter_expression::execute(context & ctx) { } if (is_val(input)) { - auto & arr = input->as_array(); auto res = try_builtin(filter_val); if (res) { return res; @@ -222,7 +238,12 @@ value if_statement::execute(context & ctx) { auto out = mk_val(); if (test_val->as_bool()) { for (auto & stmt : body) { - JJ_DEBUG("Executing if body statement of type %s", stmt->type().c_str()); + JJ_DEBUG("IF --> Executing THEN body, current block: %s", stmt->type().c_str()); + out->val_arr->push_back(stmt->execute(ctx)); + } + } else { + for (auto & stmt : alternate) { + JJ_DEBUG("IF --> Executing ELSE body, current block: %s", stmt->type().c_str()); out->val_arr->push_back(stmt->execute(ctx)); } } @@ -230,19 +251,171 @@ value if_statement::execute(context & ctx) { } value for_statement::execute(context & ctx) { - throw std::runtime_error("for_statement::execute not implemented"); -} + context scope(ctx); // new scope for loop variables -value break_statement::execute(context & ctx) { - throw std::runtime_error("break_statement::execute not implemented"); -} + statement_ptr iter_expr = std::move(iterable); + statement_ptr test_expr = nullptr; -value continue_statement::execute(context & ctx) { - throw std::runtime_error("continue_statement::execute not implemented"); + if (is_stmt(iterable)) { + JJ_DEBUG("%s", "For loop has test expression"); + auto select = dynamic_cast(iterable.get()); + iter_expr = std::move(select->lhs); + test_expr = std::move(select->test); + } + + JJ_DEBUG("Executing for statement, iterable type: %s", iter_expr->type().c_str()); + + value iterable_val = iter_expr->execute(scope); + if (!is_val(iterable_val) && !is_val(iterable_val)) { + throw std::runtime_error("Expected iterable or object type in for loop: got " + iterable_val->type()); + } + + std::vector items; + if (is_val(iterable_val)) { + auto & obj = iterable_val->as_object(); + for (auto & p : obj) { + items.push_back(mk_val(p.first)); + } + } else { + auto & arr = iterable_val->as_array(); + for (const auto & item : arr) { + items.push_back(item->clone()); + } + } + + std::vector> scope_update_fns; + + std::vector filtered_items; + for (size_t i = 0; i < items.size(); ++i) { + context loop_scope(scope); + + const value & current = items[i]; + + std::function scope_update_fn = [](context &) { /* no-op */}; + if (is_stmt(loopvar)) { + auto id = dynamic_cast(loopvar.get())->val; + scope_update_fn = [id, &items, i](context & ctx) { + ctx.var[id] = items[i]->clone(); + }; + } else if (is_stmt(loopvar)) { + auto tuple = dynamic_cast(loopvar.get()); + if (!is_val(current)) { + throw std::runtime_error("Cannot unpack non-iterable type: " + current->type()); + } + auto & c_arr = current->as_array(); + if (tuple->val.size() != c_arr.size()) { + throw std::runtime_error(std::string("Too ") + (tuple->val.size() > c_arr.size() ? "few" : "many") + " items to unpack"); + } + scope_update_fn = [tuple, &items, i](context & ctx) { + auto & c_arr = items[i]->as_array(); + for (size_t j = 0; j < tuple->val.size(); ++j) { + if (!is_stmt(tuple->val[j])) { + throw std::runtime_error("Cannot unpack non-identifier type: " + tuple->val[j]->type()); + } + auto id = dynamic_cast(tuple->val[j].get())->val; + ctx.var[id] = c_arr[j]->clone(); + } + }; + } else { + throw std::runtime_error("Invalid loop variable(s): " + loopvar->type()); + } + if (test_expr) { + scope_update_fn(loop_scope); + value test_val = test_expr->execute(loop_scope); + if (!test_val->as_bool()) { + continue; + } + } + filtered_items.push_back(current->clone()); + scope_update_fns.push_back(scope_update_fn); + } + + auto result = mk_val(); + + bool noIteration = true; + for (size_t i = 0; i < filtered_items.size(); ++i) { + JJ_DEBUG("For loop iteration %zu/%zu", i + 1, filtered_items.size()); + value_object loop_obj = mk_val(); + loop_obj->insert("index", mk_val(i + 1)); + loop_obj->insert("index0", mk_val(i)); + loop_obj->insert("revindex", mk_val(filtered_items.size() - i)); + loop_obj->insert("revindex0", mk_val(filtered_items.size() - i - 1)); + loop_obj->insert("first", mk_val(i == 0)); + loop_obj->insert("last", mk_val(i == filtered_items.size() - 1)); + loop_obj->insert("length", mk_val(filtered_items.size())); + loop_obj->insert("previtem", i > 0 ? filtered_items[i - 1]->clone() : mk_val()); + loop_obj->insert("nextitem", i < filtered_items.size() - 1 ? filtered_items[i + 1]->clone() : mk_val()); + ctx.var["loop"] = loop_obj->clone(); + scope_update_fns[i](ctx); + try { + for (auto & stmt : body) { + value val = stmt->execute(ctx); + result->push_back(val); + } + } catch (const continue_statement::exception &) { + continue; + } catch (const break_statement::exception &) { + break; + } + noIteration = false; + } + if (noIteration) { + for (auto & stmt : default_block) { + value val = stmt->execute(ctx); + result->push_back(val); + } + } + + return result; } value set_statement::execute(context & ctx) { - throw std::runtime_error("set_statement::execute not implemented"); + auto rhs = val ? val->execute(ctx) : exec_statements(body, ctx); + + if (is_stmt(assignee)) { + auto var_name = dynamic_cast(assignee.get())->val; + JJ_DEBUG("Setting variable '%s'", var_name.c_str()); + ctx.var[var_name] = rhs->clone(); + + } else if (is_stmt(assignee)) { + auto tuple = dynamic_cast(assignee.get()); + if (!is_val(rhs)) { + throw std::runtime_error("Cannot unpack non-iterable type in set: " + rhs->type()); + } + auto & arr = rhs->as_array(); + if (arr.size() != tuple->val.size()) { + throw std::runtime_error(std::string("Too ") + (tuple->val.size() > arr.size() ? "few" : "many") + " items to unpack in set"); + } + for (size_t i = 0; i < tuple->val.size(); ++i) { + auto & elem = tuple->val[i]; + if (!is_stmt(elem)) { + throw std::runtime_error("Cannot unpack to non-identifier in set: " + elem->type()); + } + auto var_name = dynamic_cast(elem.get())->val; + ctx.var[var_name] = arr[i]->clone(); + } + + } else if (is_stmt(assignee)) { + auto member = dynamic_cast(assignee.get()); + value object = member->object->execute(ctx); + if (!is_val(object)) { + throw std::runtime_error("Cannot assign to member of non-object"); + } + if (member->computed) { + throw std::runtime_error("Cannot assign to computed member"); + } + if (!is_stmt(member->property)) { + throw std::runtime_error("Cannot assign to member with non-identifier property"); + } + auto prop_name = dynamic_cast(member->property.get())->val; + auto obj_ptr = dynamic_cast(object.get()); + JJ_DEBUG("Setting object property '%s'", prop_name.c_str()); + obj_ptr->get()->insert(prop_name, rhs->clone()); + + } else { + throw std::runtime_error("Invalid LHS inside assignment expression: " + assignee->type()); + } + return mk_val(); } value member_expression::execute(context & ctx) { @@ -279,6 +452,7 @@ value member_expression::execute(context & ctx) { } else if (is_val(object) || is_val(object)) { if (is_val(property)) { int64_t index = property->as_int(); + JJ_DEBUG("Accessing %s index %lld", is_val(object) ? "array" : "string", index); if (is_val(object)) { auto & arr = object->as_array(); if (index >= 0 && index < static_cast(arr.size())) { @@ -292,6 +466,7 @@ value member_expression::execute(context & ctx) { } } else if (is_val(property)) { auto key = property->as_string(); + JJ_DEBUG("Accessing %s built-in '%s'", is_val(object) ? "array" : "string", key.c_str()); auto builtins = object->get_builtins(); auto bit = builtins.find(key); if (bit != builtins.end()) { @@ -320,4 +495,55 @@ value member_expression::execute(context & ctx) { return val; } +static func_args gather_call_args(const statements & arg_stmts, context & ctx) { + func_args args; + for (auto & arg_stmt : arg_stmts) { + args.args.push_back(arg_stmt->execute(ctx)); + } + return args; +} + +value call_expression::execute(context & ctx) { + auto args = gather_call_args(this->args, ctx); + value callee_val = callee->execute(ctx); + JJ_DEBUG("Calling function of type %s with %zu arguments", callee_val->type().c_str(), args.args.size()); + if (!is_val(callee_val)) { + throw std::runtime_error("Callee is not a function: got " + callee_val->type()); + } + return callee_val->invoke(args); +} + +// compare operator for value_t +bool value_compare(const value & a, const value & b) { + JJ_DEBUG("Comparing types: %s and %s", a->type().c_str(), b->type().c_str()); + // compare numeric types + if ((is_val(a) || is_val(a)) && + (is_val(b) || is_val(b))){ + try { + return a->as_float() == b->as_float(); + } catch (...) {} + } + // compare string and number + // TODO: not sure if this is the right behavior + if ((is_val(b) && (is_val(a) || is_val(a))) || + (is_val(a) && (is_val(b) || is_val(b)))) { + try { + return a->as_string() == b->as_string(); + } catch (...) {} + } + // compare boolean simple + if (is_val(a) && is_val(b)) { + return a->as_bool() == b->as_bool(); + } + // compare string simple + if (is_val(a) && is_val(b)) { + return a->as_string() == b->as_string(); + } + // compare by type + if (a->type() != b->type()) { + return false; + } + return false; +} + } // namespace jinja diff --git a/common/jinja/jinja-vm.h b/common/jinja/jinja-vm.h index d2e763b13b..7c431cd47e 100644 --- a/common/jinja/jinja-vm.h +++ b/common/jinja/jinja-vm.h @@ -32,7 +32,7 @@ struct context { struct statement { virtual ~statement() = default; virtual std::string type() const { return "Statement"; } - virtual value execute(context & ctx) { throw std::runtime_error("cannot exec " + type()); } + virtual value execute(context &) { throw std::runtime_error("cannot exec " + type()); } }; using statement_ptr = std::unique_ptr; @@ -68,7 +68,7 @@ struct program : public statement { explicit program(statements && body) : body(std::move(body)) {} std::string type() const override { return "Program"; } - value execute(context & ctx) override { + value execute(context &) override { throw std::runtime_error("Cannot execute program directly, use jinja::vm instead"); } }; @@ -113,12 +113,30 @@ struct for_statement : public statement { struct break_statement : public statement { std::string type() const override { return "Break"; } - value execute(context & ctx) override; + + struct exception : public std::exception { + const char* what() const noexcept override { + return "Break statement executed"; + } + }; + + value execute(context &) override { + throw break_statement::exception(); + } }; struct continue_statement : public statement { std::string type() const override { return "Continue"; } - value execute(context & ctx) override; + + struct exception : public std::exception { + const char* what() const noexcept override { + return "Continue statement executed"; + } + }; + + value execute(context &) override { + throw continue_statement::exception(); + } }; struct set_statement : public statement { @@ -148,14 +166,12 @@ struct macro_statement : public statement { } std::string type() const override { return "Macro"; } - value execute(context & ctx) override {} }; struct comment_statement : public statement { std::string val; explicit comment_statement(const std::string & v) : val(v) {} std::string type() const override { return "Comment"; } - value execute(context & ctx) override {} }; // Expressions @@ -184,6 +200,7 @@ struct call_expression : public expression { for (const auto& arg : this->args) chk_type(arg); } std::string type() const override { return "CallExpression"; } + value execute(context & ctx) override; }; /** @@ -202,7 +219,7 @@ struct integer_literal : public expression { int64_t val; explicit integer_literal(int64_t val) : val(val) {} std::string type() const override { return "IntegerLiteral"; } - value execute(context & ctx) override { + value execute(context &) override { return std::make_unique(val); } }; @@ -211,7 +228,7 @@ struct float_literal : public expression { double val; explicit float_literal(double val) : val(val) {} std::string type() const override { return "FloatLiteral"; } - value execute(context & ctx) override { + value execute(context &) override { return std::make_unique(val); } }; @@ -220,7 +237,7 @@ struct string_literal : public expression { std::string val; explicit string_literal(const std::string & val) : val(val) {} std::string type() const override { return "StringLiteral"; } - value execute(context & ctx) override { + value execute(context &) override { return std::make_unique(val); } }; @@ -300,7 +317,6 @@ struct filter_statement : public statement { chk_type(this->filter); } std::string type() const override { return "FilterStatement"; } - value execute(context & ctx) override {} }; /** @@ -396,7 +412,6 @@ struct call_statement : public statement { for (const auto& arg : this->caller_args) chk_type(arg); } std::string type() const override { return "CallStatement"; } - value execute(context & ctx) override {} }; struct ternary_expression : public expression { @@ -413,6 +428,14 @@ struct ternary_expression : public expression { std::string type() const override { return "Ternary"; } }; +struct raised_exception : public std::exception { + std::string message; + raised_exception(const std::string & msg) : message(msg) {} + const char* what() const noexcept override { + return message.c_str(); + } +}; + ////////////////////// struct vm { diff --git a/tests/test-chat-jinja.cpp b/tests/test-chat-jinja.cpp index 63048841c3..085531a673 100644 --- a/tests/test-chat-jinja.cpp +++ b/tests/test-chat-jinja.cpp @@ -11,11 +11,11 @@ #include "jinja/jinja-lexer.h" int main(void) { - //std::string contents = "{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '' + role + '\\n' + message['content'] | trim + '\\n' }}{% endfor %}{% if add_generation_prompt %}{{'model\\n'}}{% endif %}"; + std::string contents = "{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '' + role + '\\n' + message['content'] | trim + '\\n' }}{% endfor %}{% if add_generation_prompt %}{{'model\\n'}}{% endif %}"; //std::string contents = "{% if messages[0]['role'] != 'system' %}nice {{ messages[0]['content'] }}{% endif %}"; - std::string contents = " {{ messages[0]['content'] }} "; + //std::string contents = " {{ messages[0]['content'] }} "; std::cout << "=== INPUT ===\n" << contents << "\n\n"; @@ -34,11 +34,11 @@ int main(void) { std::cout << "stmt type: " << stmt->type() << "\n"; } - std::cout << "\n=== OUTPUT ===\n"; + std::cout << "\n=== RUN ===\n"; jinja::context ctx; auto make_non_special_string = [](const std::string & s) { - jinja::value_string str_val = std::make_unique(s); + jinja::value_string str_val = jinja::mk_val(s); str_val->is_user_input = true; return str_val; }; @@ -57,7 +57,12 @@ int main(void) { jinja::vm vm(ctx); auto results = vm.execute(ast); + + std::cout << "\n=== RESULTS ===\n"; for (const auto & res : results) { + if (res->is_null()) { + continue; + } auto str_ptr = dynamic_cast(res.get()); std::string is_user_input = "false"; if (str_ptr) { From 4ca114b09539e8c43a31132a9cf3dc8f61a4c859 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Sun, 28 Dec 2025 12:48:35 +0100 Subject: [PATCH 17/47] track input string even after transformations --- common/jinja/jinja-string.h | 166 +++++++++++++++++++++++++++++ common/jinja/jinja-value.h | 33 ++++-- common/jinja/jinja-vm-builtins.cpp | 103 ++++++++---------- common/jinja/jinja-vm.cpp | 37 +++---- tests/test-chat-jinja.cpp | 9 +- 5 files changed, 252 insertions(+), 96 deletions(-) create mode 100644 common/jinja/jinja-string.h diff --git a/common/jinja/jinja-string.h b/common/jinja/jinja-string.h new file mode 100644 index 0000000000..fb3371271f --- /dev/null +++ b/common/jinja/jinja-string.h @@ -0,0 +1,166 @@ +#pragma once + +#include +#include +#include +#include + + +namespace jinja { + +// allow differentiate between user input strings and template strings +// transformations should handle this information as follows: +// - one-to-one (e.g., uppercase, lowercase): preserve is_input flag +// - one-to-many (e.g., strip): if input string is marked as is_input, all resulting parts should be marked as is_input +// - many-to-one (e.g., concat): if ALL input parts are marked as is_input, resulting part should be marked as is_input +struct string_part { + bool is_input = false; // may skip parsing special tokens if true + std::string val; +}; + +struct string { + using transform_fn = std::function; + + std::vector parts; + string() = default; + string(const std::string & v, bool user_input = false) { + parts.push_back({user_input, v}); + } + string(int v) { + parts.push_back({false, std::to_string(v)}); + } + string(double v) { + parts.push_back({false, std::to_string(v)}); + } + + void mark_input() { + for (auto & part : parts) { + part.is_input = true; + } + } + + std::string str() const { + if (parts.size() == 1) { + return parts[0].val; + } + std::ostringstream oss; + for (const auto & part : parts) { + oss << part.val; + } + return oss.str(); + } + + size_t length() const { + size_t len = 0; + for (const auto & part : parts) { + len += part.val.length(); + } + return len; + } + + bool all_parts_are_input() const { + for (const auto & part : parts) { + if (!part.is_input) { + return false; + } + } + return true; + } + + // mark this string as input if other has ALL parts as input + void mark_input_based_on(const string & other) { + if (other.all_parts_are_input()) { + for (auto & part : parts) { + part.is_input = true; + } + } + } + + string append(const string & other) { + for (const auto & part : other.parts) { + parts.push_back(part); + } + return *this; + } + + // in-place transformation + + string apply_transform(const transform_fn & fn) { + for (auto & part : parts) { + part.val = fn(part.val); + } + return *this; + } + string uppercase() { + return apply_transform([](const std::string & s) { + std::string res = s; + std::transform(res.begin(), res.end(), res.begin(), ::toupper); + return res; + }); + } + string lowercase() { + return apply_transform([](const std::string & s) { + std::string res = s; + std::transform(res.begin(), res.end(), res.begin(), ::tolower); + return res; + }); + } + string capitalize() { + return apply_transform([](const std::string & s) { + if (s.empty()) return s; + std::string res = s; + res[0] = ::toupper(static_cast(res[0])); + std::transform(res.begin() + 1, res.end(), res.begin() + 1, ::tolower); + return res; + }); + } + string titlecase() { + return apply_transform([](const std::string & s) { + std::string res = s; + bool capitalize_next = true; + for (char &c : res) { + if (isspace(static_cast(c))) { + capitalize_next = true; + } else if (capitalize_next) { + c = ::toupper(static_cast(c)); + capitalize_next = false; + } else { + c = ::tolower(static_cast(c)); + } + } + return res; + }); + } + string strip(bool left, bool right) { + // TODO: what if leading/trailing continue in multiple parts? + + static auto strip_part = [](const std::string & s, bool left, bool right) -> std::string { + size_t start = 0; + size_t end = s.length(); + if (left) { + while (start < end && isspace(static_cast(s[start]))) { + ++start; + } + } + if (right) { + while (end > start && isspace(static_cast(s[end - 1]))) { + --end; + } + } + return s.substr(start, end - start); + }; + if (parts.empty()) { + return *this; + } + if (left) { + parts[0].val = strip_part(parts[0].val, true, false); + } + if (right) { + auto & last = parts[parts.size() - 1]; + last.val = strip_part(last.val, false, true); + } + return *this; + } +}; + +} // namespace jinja diff --git a/common/jinja/jinja-value.h b/common/jinja/jinja-value.h index 8b2d74ae35..74366de9ba 100644 --- a/common/jinja/jinja-value.h +++ b/common/jinja/jinja-value.h @@ -7,6 +7,7 @@ #include #include +#include "jinja-string.h" namespace jinja { @@ -80,7 +81,7 @@ bool value_compare(const value & a, const value & b); struct value_t { int64_t val_int; double val_flt; - std::string val_str; + string val_str; bool val_bool; // array and object are stored as shared_ptr to allow reference access @@ -102,7 +103,7 @@ struct value_t { virtual int64_t as_int() const { throw std::runtime_error(type() + " is not an int value"); } virtual double as_float() const { throw std::runtime_error(type() + " is not a float value"); } - virtual std::string as_string() const { throw std::runtime_error(type() + " is not a string value"); } + virtual string as_string() const { throw std::runtime_error(type() + " is not a string value"); } virtual bool as_bool() const { throw std::runtime_error(type() + " is not a bool value"); } virtual const std::vector & as_array() const { throw std::runtime_error(type() + " is not an array value"); } virtual const std::map & as_object() const { throw std::runtime_error(type() + " is not an object value"); } @@ -113,7 +114,7 @@ struct value_t { throw std::runtime_error("No builtins available for type " + type()); } - virtual std::string as_repr() const { return as_string(); } + virtual std::string as_repr() const { return as_string().str(); } virtual value clone() const { return std::make_unique(*this); @@ -126,7 +127,7 @@ struct value_int_t : public value_t { virtual std::string type() const override { return "Integer"; } virtual int64_t as_int() const override { return val_int; } virtual double as_float() const override { return static_cast(val_int); } - virtual std::string as_string() const override { return std::to_string(val_int); } + virtual string as_string() const override { return std::to_string(val_int); } virtual value clone() const override { return std::make_unique(*this); } virtual const func_builtins & get_builtins() const override; }; @@ -138,7 +139,7 @@ struct value_float_t : public value_t { virtual std::string type() const override { return "Float"; } virtual double as_float() const override { return val_flt; } virtual int64_t as_int() const override { return static_cast(val_flt); } - virtual std::string as_string() const override { return std::to_string(val_flt); } + virtual string as_string() const override { return std::to_string(val_flt); } virtual value clone() const override { return std::make_unique(*this); } virtual const func_builtins & get_builtins() const override; }; @@ -146,13 +147,23 @@ using value_float = std::unique_ptr; struct value_string_t : public value_t { - bool is_user_input = false; // may skip parsing special tokens if true - - value_string_t(const std::string & v) { val_str = v; } + value_string_t() { val_str = string(); } + value_string_t(const std::string & v) { val_str = string(v); } + value_string_t(const string & v) { val_str = v; } virtual std::string type() const override { return "String"; } - virtual std::string as_string() const override { return val_str; } + virtual string as_string() const override { return val_str; } + virtual std::string as_repr() const override { + std::ostringstream ss; + for (const auto & part : val_str.parts) { + ss << (part.is_input ? "INPUT: " : "TMPL: ") << part.val << "\n"; + } + return ss.str(); + } virtual value clone() const override { return std::make_unique(*this); } virtual const func_builtins & get_builtins() const override; + void mark_input() { + val_str.mark_input(); + } }; using value_string = std::unique_ptr; @@ -161,7 +172,7 @@ struct value_bool_t : public value_t { value_bool_t(bool v) { val_bool = v; } virtual std::string type() const override { return "Boolean"; } virtual bool as_bool() const override { return val_bool; } - virtual std::string as_string() const override { return val_bool ? "True" : "False"; } + virtual string as_string() const override { return std::string(val_bool ? "True" : "False"); } virtual value clone() const override { return std::make_unique(*this); } virtual const func_builtins & get_builtins() const override; }; @@ -200,7 +211,7 @@ struct value_array_t : public value_t { tmp->val_arr = this->val_arr; return tmp; } - virtual std::string as_string() const override { + virtual string as_string() const override { std::ostringstream ss; ss << "["; for (size_t i = 0; i < val_arr->size(); i++) { diff --git a/common/jinja/jinja-vm-builtins.cpp b/common/jinja/jinja-vm-builtins.cpp index 493c71e25e..e8c8eee993 100644 --- a/common/jinja/jinja-vm-builtins.cpp +++ b/common/jinja/jinja-vm-builtins.cpp @@ -12,7 +12,7 @@ const func_builtins & global_builtins() { static const func_builtins builtins = { {"raise_exception", [](const func_args & args) -> value { args.ensure_count(1); - std::string msg = args.args[0]->as_string(); + std::string msg = args.args[0]->as_string().str(); throw raised_exception("Jinja Exception: " + msg); }}, }; @@ -54,21 +54,21 @@ const func_builtins & value_float_t::get_builtins() const { } -static std::string string_strip(const std::string & str, bool left, bool right) { - size_t start = 0; - size_t end = str.length(); - if (left) { - while (start < end && isspace(static_cast(str[start]))) { - ++start; - } - } - if (right) { - while (end > start && isspace(static_cast(str[end - 1]))) { - --end; - } - } - return str.substr(start, end - start); -} +// static std::string string_strip(const std::string & str, bool left, bool right) { +// size_t start = 0; +// size_t end = str.length(); +// if (left) { +// while (start < end && isspace(static_cast(str[start]))) { +// ++start; +// } +// } +// if (right) { +// while (end > start && isspace(static_cast(str[end - 1]))) { +// --end; +// } +// } +// return str.substr(start, end - start); +// } static bool string_startswith(const std::string & str, const std::string & prefix) { if (str.length() < prefix.length()) return false; @@ -84,77 +84,60 @@ const func_builtins & value_string_t::get_builtins() const { static const func_builtins builtins = { {"upper", [](const func_args & args) -> value { args.ensure_vals(); - std::string str = args.args[0]->as_string(); - std::transform(str.begin(), str.end(), str.begin(), ::toupper); + jinja::string str = args.args[0]->as_string().uppercase(); return mk_val(str); }}, {"lower", [](const func_args & args) -> value { args.ensure_vals(); - std::string str = args.args[0]->as_string(); - std::transform(str.begin(), str.end(), str.begin(), ::tolower); + jinja::string str = args.args[0]->as_string().lowercase(); return mk_val(str); }}, {"strip", [](const func_args & args) -> value { args.ensure_vals(); - std::string str = args.args[0]->as_string(); - return mk_val(string_strip(str, true, true)); + jinja::string str = args.args[0]->as_string().strip(true, true); + return mk_val(str); }}, {"rstrip", [](const func_args & args) -> value { args.ensure_vals(); - std::string str = args.args[0]->as_string(); - return mk_val(string_strip(str, false, true)); + jinja::string str = args.args[0]->as_string().strip(false, true); + return mk_val(str); }}, {"lstrip", [](const func_args & args) -> value { args.ensure_vals(); - std::string str = args.args[0]->as_string(); - return mk_val(string_strip(str, true, false)); + jinja::string str = args.args[0]->as_string().strip(true, false); + return mk_val(str); }}, {"title", [](const func_args & args) -> value { args.ensure_vals(); - std::string str = args.args[0]->as_string(); - bool capitalize_next = true; - for (char &c : str) { - if (isspace(static_cast(c))) { - capitalize_next = true; - } else if (capitalize_next) { - c = ::toupper(static_cast(c)); - capitalize_next = false; - } else { - c = ::tolower(static_cast(c)); - } - } + jinja::string str = args.args[0]->as_string().titlecase(); return mk_val(str); }}, {"capitalize", [](const func_args & args) -> value { args.ensure_vals(); - std::string str = args.args[0]->as_string(); - if (!str.empty()) { - str[0] = ::toupper(static_cast(str[0])); - std::transform(str.begin() + 1, str.end(), str.begin() + 1, ::tolower); - } + jinja::string str = args.args[0]->as_string().capitalize(); return mk_val(str); }}, {"length", [](const func_args & args) -> value { args.ensure_vals(); - std::string str = args.args[0]->as_string(); + jinja::string str = args.args[0]->as_string(); return mk_val(str.length()); }}, {"startswith", [](const func_args & args) -> value { args.ensure_vals(); - std::string str = args.args[0]->as_string(); - std::string prefix = args.args[1]->as_string(); + std::string str = args.args[0]->as_string().str(); + std::string prefix = args.args[1]->as_string().str(); return mk_val(string_startswith(str, prefix)); }}, {"endswith", [](const func_args & args) -> value { args.ensure_vals(); - std::string str = args.args[0]->as_string(); - std::string suffix = args.args[1]->as_string(); + std::string str = args.args[0]->as_string().str(); + std::string suffix = args.args[1]->as_string().str(); return mk_val(string_endswith(str, suffix)); }}, {"split", [](const func_args & args) -> value { args.ensure_vals(); - std::string str = args.args[0]->as_string(); - std::string delim = (args.args.size() > 1) ? args.args[1]->as_string() : " "; + std::string str = args.args[0]->as_string().str(); + std::string delim = (args.args.size() > 1) ? args.args[1]->as_string().str() : " "; auto result = mk_val(); size_t pos = 0; std::string token; @@ -163,24 +146,28 @@ const func_builtins & value_string_t::get_builtins() const { result->val_arr->push_back(mk_val(token)); str.erase(0, pos + delim.length()); } - result->val_arr->push_back(mk_val(str)); + auto res = mk_val(str); + res->val_str.mark_input_based_on(args.args[0]->val_str); + result->val_arr->push_back(std::move(res)); return std::move(result); }}, {"replace", [](const func_args & args) -> value { args.ensure_vals(); - std::string str = args.args[0]->as_string(); - std::string old_str = args.args[1]->as_string(); - std::string new_str = args.args[2]->as_string(); + std::string str = args.args[0]->as_string().str(); + std::string old_str = args.args[1]->as_string().str(); + std::string new_str = args.args[2]->as_string().str(); size_t pos = 0; while ((pos = str.find(old_str, pos)) != std::string::npos) { str.replace(pos, old_str.length(), new_str); pos += new_str.length(); } - return mk_val(str); + auto res = mk_val(str); + res->val_str.mark_input_based_on(args.args[0]->val_str); + return res; }}, {"int", [](const func_args & args) -> value { args.ensure_vals(); - std::string str = args.args[0]->as_string(); + std::string str = args.args[0]->as_string().str(); try { return mk_val(std::stoi(str)); } catch (...) { @@ -189,7 +176,7 @@ const func_builtins & value_string_t::get_builtins() const { }}, {"float", [](const func_args & args) -> value { args.ensure_vals(); - std::string str = args.args[0]->as_string(); + std::string str = args.args[0]->as_string().str(); try { return mk_val(std::stod(str)); } catch (...) { @@ -277,7 +264,7 @@ const func_builtins & value_object_t::get_builtins() const { {"get", [](const func_args & args) -> value { args.ensure_vals(); // TODO: add default value const auto & obj = args.args[0]->as_object(); - std::string key = args.args[1]->as_string(); + std::string key = args.args[1]->as_string().str(); auto it = obj.find(key); if (it != obj.end()) { return it->second->clone(); diff --git a/common/jinja/jinja-vm.cpp b/common/jinja/jinja-vm.cpp index 7fb323c58b..c6861eeb39 100644 --- a/common/jinja/jinja-vm.cpp +++ b/common/jinja/jinja-vm.cpp @@ -72,11 +72,6 @@ value binary_expression::execute(context & ctx) { throw std::runtime_error("Cannot perform operation on null values"); } - // String concatenation with ~ - if (op.value == "~") { - return mk_val(left_val->as_string() + right_val->as_string()); - } - // Float operations if ((is_val(left_val) || is_val(left_val)) && (is_val(right_val) || is_val(right_val))) { @@ -137,18 +132,20 @@ value binary_expression::execute(context & ctx) { } } - // String concatenation - if (is_val(left_val) || is_val(right_val)) { - JJ_DEBUG("%s", "String concatenation with + operator"); - if (op.value == "+") { - return mk_val(left_val->as_string() + right_val->as_string()); - } + // String concatenation with ~ and + + if ((is_val(left_val) || is_val(right_val)) && + (op.value == "~" || op.value == "+")) { + JJ_DEBUG("String concatenation with %s operator", op.value.c_str()); + auto output = left_val->as_string().append(right_val->as_string()); + auto res = mk_val(); + res->val_str = std::move(output); + return res; } // String membership if (is_val(left_val) && is_val(right_val)) { - auto left_str = left_val->as_string(); - auto right_str = right_val->as_string(); + auto left_str = left_val->as_string().str(); + auto right_str = right_val->as_string().str(); if (op.value == "in") { return mk_val(right_str.find(left_str) != std::string::npos); } else if (op.value == "not in") { @@ -158,7 +155,7 @@ value binary_expression::execute(context & ctx) { // String in object if (is_val(left_val) && is_val(right_val)) { - auto key = left_val->as_string(); + auto key = left_val->as_string().str(); auto & obj = right_val->as_object(); bool has_key = obj.find(key) != obj.end(); if (op.value == "in") { @@ -434,7 +431,7 @@ value member_expression::execute(context & ctx) { if (!is_val(property)) { throw std::runtime_error("Cannot access object with non-string: got " + property->type()); } - auto key = property->as_string(); + auto key = property->as_string().str(); auto & obj = object->as_object(); auto it = obj.find(key); if (it != obj.end()) { @@ -459,13 +456,13 @@ value member_expression::execute(context & ctx) { val = arr[index]->clone(); } } else { // value_string - auto str = object->as_string(); + auto str = object->as_string().str(); if (index >= 0 && index < static_cast(str.size())) { val = mk_val(std::string(1, str[index])); } } } else if (is_val(property)) { - auto key = property->as_string(); + auto key = property->as_string().str(); JJ_DEBUG("Accessing %s built-in '%s'", is_val(object) ? "array" : "string", key.c_str()); auto builtins = object->get_builtins(); auto bit = builtins.find(key); @@ -482,7 +479,7 @@ value member_expression::execute(context & ctx) { if (!is_val(property)) { throw std::runtime_error("Cannot access property with non-string: got " + property->type()); } - auto key = property->as_string(); + auto key = property->as_string().str(); auto builtins = object->get_builtins(); auto bit = builtins.find(key); if (bit != builtins.end()) { @@ -528,7 +525,7 @@ bool value_compare(const value & a, const value & b) { if ((is_val(b) && (is_val(a) || is_val(a))) || (is_val(a) && (is_val(b) || is_val(b)))) { try { - return a->as_string() == b->as_string(); + return a->as_string().str() == b->as_string().str(); } catch (...) {} } // compare boolean simple @@ -537,7 +534,7 @@ bool value_compare(const value & a, const value & b) { } // compare string simple if (is_val(a) && is_val(b)) { - return a->as_string() == b->as_string(); + return a->as_string().str() == b->as_string().str(); } // compare by type if (a->type() != b->type()) { diff --git a/tests/test-chat-jinja.cpp b/tests/test-chat-jinja.cpp index 085531a673..acbf7daf2a 100644 --- a/tests/test-chat-jinja.cpp +++ b/tests/test-chat-jinja.cpp @@ -39,7 +39,7 @@ int main(void) { auto make_non_special_string = [](const std::string & s) { jinja::value_string str_val = jinja::mk_val(s); - str_val->is_user_input = true; + str_val->mark_input(); return str_val; }; @@ -63,12 +63,7 @@ int main(void) { if (res->is_null()) { continue; } - auto str_ptr = dynamic_cast(res.get()); - std::string is_user_input = "false"; - if (str_ptr) { - is_user_input = str_ptr->is_user_input ? "true" : "false"; - } - std::cout << "result type: " << res->type() << " | value: " << res->as_string() << " | is_user_input: " << is_user_input << "\n"; + std::cout << "result type: " << res->type() << " | value: " << res->as_repr(); } return 0; From 45c194622efbd32660cc4fdf83ac8c32dcd20c3c Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Sun, 28 Dec 2025 15:33:14 +0100 Subject: [PATCH 18/47] support binded functions --- common/jinja/jinja-string.h | 36 +++++++ common/jinja/jinja-value.h | 36 ++++++- common/jinja/jinja-vm-builtins.cpp | 58 ++++++++++- common/jinja/jinja-vm.cpp | 151 +++++++++++++++++------------ common/jinja/jinja-vm.h | 33 ++++++- tests/test-chat-jinja.cpp | 16 +-- 6 files changed, 254 insertions(+), 76 deletions(-) diff --git a/common/jinja/jinja-string.h b/common/jinja/jinja-string.h index fb3371271f..d26bb1e20c 100644 --- a/common/jinja/jinja-string.h +++ b/common/jinja/jinja-string.h @@ -16,6 +16,24 @@ namespace jinja { struct string_part { bool is_input = false; // may skip parsing special tokens if true std::string val; + + bool is_uppercase() const { + for (char c : val) { + if (std::islower(static_cast(c))) { + return false; + } + } + return true; + } + + bool is_lowercase() const { + for (char c : val) { + if (std::isupper(static_cast(c))) { + return false; + } + } + return true; + } }; struct string { @@ -67,6 +85,24 @@ struct string { return true; } + bool is_uppercase() const { + for (const auto & part : parts) { + if (!part.is_uppercase()) { + return false; + } + } + return true; + } + + bool is_lowercase() const { + for (const auto & part : parts) { + if (!part.is_lowercase()) { + return false; + } + } + return true; + } + // mark this string as input if other has ALL parts as input void mark_input_based_on(const string & other) { if (other.all_parts_are_input()) { diff --git a/common/jinja/jinja-value.h b/common/jinja/jinja-value.h index 74366de9ba..787cec46b3 100644 --- a/common/jinja/jinja-value.h +++ b/common/jinja/jinja-value.h @@ -107,7 +107,7 @@ struct value_t { virtual bool as_bool() const { throw std::runtime_error(type() + " is not a bool value"); } virtual const std::vector & as_array() const { throw std::runtime_error(type() + " is not an array value"); } virtual const std::map & as_object() const { throw std::runtime_error(type() + " is not an object value"); } - virtual value invoke(const func_args &) const { throw std::runtime_error("Not a function value"); } + virtual value invoke(const func_args &) const { throw std::runtime_error(type() + " is not a function value"); } virtual bool is_null() const { return false; } virtual bool is_undefined() const { return false; } virtual const func_builtins & get_builtins() const { @@ -221,6 +221,9 @@ struct value_array_t : public value_t { ss << "]"; return ss.str(); } + virtual bool as_bool() const override { + return !val_arr->empty(); + } virtual const func_builtins & get_builtins() const override; }; using value_array = std::unique_ptr; @@ -251,17 +254,44 @@ struct value_object_t : public value_t { tmp->val_obj = this->val_obj; return tmp; } + virtual bool as_bool() const override { + return !val_obj->empty(); + } virtual const func_builtins & get_builtins() const override; }; using value_object = std::unique_ptr; struct value_func_t : public value_t { - value_func_t(func_handler & func) { + std::string name; // for debugging + value arg0; // bound "this" argument, if any + value_func_t(const value_func_t & other) { + val_func = other.val_func; + name = other.name; + if (other.arg0) { + arg0 = other.arg0->clone(); + } + } + value_func_t(const func_handler & func, std::string func_name = "") { val_func = func; + name = func_name; + } + value_func_t(const func_handler & func, const value & arg_this, std::string func_name = "") { + val_func = func; + name = func_name; + arg0 = arg_this->clone(); } virtual value invoke(const func_args & args) const override { - return val_func(args); + if (arg0) { + func_args new_args; + new_args.args.push_back(arg0->clone()); + for (const auto & a : args.args) { + new_args.args.push_back(a->clone()); + } + return val_func(new_args); + } else { + return val_func(args); + } } virtual std::string type() const override { return "Function"; } virtual std::string as_repr() const override { return type(); } diff --git a/common/jinja/jinja-vm-builtins.cpp b/common/jinja/jinja-vm-builtins.cpp index e8c8eee993..160001e522 100644 --- a/common/jinja/jinja-vm-builtins.cpp +++ b/common/jinja/jinja-vm-builtins.cpp @@ -8,13 +8,69 @@ namespace jinja { +template +static value test_type_fn(const func_args & args) { + args.ensure_count(1); + bool is_type = is_val(args.args[0]); + return mk_val(is_type); +} +template +static value test_type_fn(const func_args & args) { + args.ensure_count(1); + bool is_type = is_val(args.args[0]) || is_val(args.args[0]); + return mk_val(is_type); +} + const func_builtins & global_builtins() { static const func_builtins builtins = { {"raise_exception", [](const func_args & args) -> value { - args.ensure_count(1); + args.ensure_vals(); std::string msg = args.args[0]->as_string().str(); throw raised_exception("Jinja Exception: " + msg); }}, + + // tests + {"test_is_boolean", test_type_fn}, + {"test_is_callable", test_type_fn}, + {"test_is_odd", [](const func_args & args) -> value { + args.ensure_vals(); + int64_t val = args.args[0]->as_int(); + return mk_val(val % 2 != 0); + }}, + {"test_is_even", [](const func_args & args) -> value { + args.ensure_vals(); + int64_t val = args.args[0]->as_int(); + return mk_val(val % 2 == 0); + }}, + {"test_is_false", [](const func_args & args) -> value { + args.ensure_count(1); + bool val = is_val(args.args[0]) && !args.args[0]->as_bool(); + return mk_val(val); + }}, + {"test_is_true", [](const func_args & args) -> value { + args.ensure_count(1); + bool val = is_val(args.args[0]) && args.args[0]->as_bool(); + return mk_val(val); + }}, + {"test_is_string", test_type_fn}, + {"test_is_integer", test_type_fn}, + {"test_is_number", test_type_fn}, + {"test_is_iterable", test_type_fn}, + {"test_is_mapping", test_type_fn}, + {"test_is_lower", [](const func_args & args) -> value { + args.ensure_vals(); + return mk_val(args.args[0]->val_str.is_lowercase()); + }}, + {"test_is_upper", [](const func_args & args) -> value { + args.ensure_vals(); + return mk_val(args.args[0]->val_str.is_uppercase()); + }}, + {"test_is_none", test_type_fn}, + {"test_is_defined", [](const func_args & args) -> value { + args.ensure_count(1); + return mk_val(!is_val(args.args[0])); + }}, + {"test_is_undefined", test_type_fn}, }; return builtins; } diff --git a/common/jinja/jinja-vm.cpp b/common/jinja/jinja-vm.cpp index c6861eeb39..bd4d53bded 100644 --- a/common/jinja/jinja-vm.cpp +++ b/common/jinja/jinja-vm.cpp @@ -8,7 +8,7 @@ #include #include -#define JJ_DEBUG(msg, ...) printf("jinja-vm: " msg "\n", __VA_ARGS__) +#define JJ_DEBUG(msg, ...) printf("jinja-vm:%3d : " msg "\n", __LINE__, __VA_ARGS__) //#define JJ_DEBUG(msg, ...) // no-op namespace jinja { @@ -44,7 +44,7 @@ value identifier::execute(context & ctx) { value binary_expression::execute(context & ctx) { value left_val = left->execute(ctx); - JJ_DEBUG("Executing binary expression with operator '%s'", op.value.c_str()); + JJ_DEBUG("Executing binary expression %s '%s' %s", left_val->type().c_str(), op.value.c_str(), right->type().c_str()); // Logical operators if (op.value == "and") { @@ -168,20 +168,19 @@ value binary_expression::execute(context & ctx) { throw std::runtime_error("Unknown operator \"" + op.value + "\" between " + left_val->type() + " and " + right_val->type()); } +static value try_builtin_func(const std::string & name, const value & input) { + auto builtins = input->get_builtins(); + auto it = builtins.find(name); + if (it != builtins.end()) { + JJ_DEBUG("Binding built-in '%s'", name.c_str()); + return mk_val(it->second, input, name); + } + throw std::runtime_error("Unknown (built-in) filter '" + name + "' for type " + input->type()); +} + value filter_expression::execute(context & ctx) { value input = operand->execute(ctx); - auto try_builtin = [&](const std::string & name) -> value { - auto builtins = input->get_builtins(); - auto it = builtins.find(name); - if (it != builtins.end()) { - func_args args; - args.args.push_back(input->clone()); - return it->second(args); - } - throw std::runtime_error("Unknown (built-in) filter '" + name + "' for type " + input->type()); - }; - if (is_stmt(filter)) { auto filter_val = dynamic_cast(filter.get())->val; @@ -190,35 +189,12 @@ value filter_expression::execute(context & ctx) { throw std::runtime_error("to_json filter not implemented"); } - if (is_val(input)) { - auto res = try_builtin(filter_val); - if (res) { - return res; - } - throw std::runtime_error("Unknown filter '" + filter_val + "' for array"); - - } else if (is_val(input)) { - auto str = input->as_string(); - auto builtins = input->get_builtins(); - if (filter_val == "trim") { - filter_val = "strip"; // alias - } - auto res = try_builtin(filter_val); - if (res) { - return res; - } - throw std::runtime_error("Unknown filter '" + filter_val + "' for string"); - - } else if (is_val(input) || is_val(input)) { - auto res = try_builtin(filter_val); - if (res) { - return res; - } - throw std::runtime_error("Unknown filter '" + filter_val + "' for number"); - - } else { - throw std::runtime_error("Filters not supported for type " + input->type()); + auto str = input->as_string(); + if (filter_val == "trim") { + filter_val = "strip"; // alias } + JJ_DEBUG("Applying filter '%s' to %s", filter_val.c_str(), input->type().c_str()); + return try_builtin_func(filter_val, input); } else if (is_stmt(filter)) { // TODO @@ -230,6 +206,44 @@ value filter_expression::execute(context & ctx) { } } +value test_expression::execute(context & ctx) { + // NOTE: "value is something" translates to function call "test_is_something(value)" + const auto & builtins = global_builtins(); + if (!is_stmt(test)) { + throw std::runtime_error("Invalid test expression"); + } + + auto test_id = dynamic_cast(test.get())->val; + auto it = builtins.find("test_is_" + test_id); + JJ_DEBUG("Test expression %s '%s'", operand->type().c_str(), test_id.c_str()); + if (it == builtins.end()) { + throw std::runtime_error("Unknown test '" + test_id + "'"); + } + + func_args args; + args.args.push_back(operand->execute(ctx)); + return it->second(args); +} + +value unary_expression::execute(context & ctx) { + value operand_val = argument->execute(ctx); + JJ_DEBUG("Executing unary expression with operator '%s'", op.value.c_str()); + + if (op.value == "not") { + return mk_val(!operand_val->as_bool()); + } else if (op.value == "-") { + if (is_val(operand_val)) { + return mk_val(-operand_val->as_int()); + } else if (is_val(operand_val)) { + return mk_val(-operand_val->as_float()); + } else { + throw std::runtime_error("Unary - operator requires numeric operand"); + } + } + + throw std::runtime_error("Unknown unary operator '" + op.value + "'"); +} + value if_statement::execute(context & ctx) { value test_val = test->execute(ctx); auto out = mk_val(); @@ -415,16 +429,46 @@ value set_statement::execute(context & ctx) { return mk_val(); } +value macro_statement::execute(context & ctx) { + std::string name = dynamic_cast(this->name.get())->val; + const func_handler func = [this, &ctx, name](const func_args & args) -> value { + JJ_DEBUG("Invoking macro '%s' with %zu arguments", name.c_str(), args.args.size()); + context macro_ctx(ctx); // new scope for macro execution + + // bind parameters + size_t param_count = this->args.size(); + size_t arg_count = args.args.size(); + for (size_t i = 0; i < param_count; ++i) { + std::string param_name = dynamic_cast(this->args[i].get())->val; + if (i < arg_count) { + macro_ctx.var[param_name] = args.args[i]->clone(); + } else { + macro_ctx.var[param_name] = mk_val(); + } + } + + // execute macro body + return exec_statements(this->body, macro_ctx); + }; + + JJ_DEBUG("Defining macro '%s' with %zu parameters", name.c_str(), args.size()); + ctx.var[name] = mk_val(func); + return mk_val(); +} + value member_expression::execute(context & ctx) { value object = this->object->execute(ctx); value property; if (this->computed) { + JJ_DEBUG("Member expression, computing property type %s", this->property->type().c_str()); property = this->property->execute(ctx); } else { property = mk_val(dynamic_cast(this->property.get())->val); } + JJ_DEBUG("Member expression on object type %s, property type %s", object->type().c_str(), property->type().c_str()); + value val = mk_val(); if (is_val(object)) { @@ -432,18 +476,13 @@ value member_expression::execute(context & ctx) { throw std::runtime_error("Cannot access object with non-string: got " + property->type()); } auto key = property->as_string().str(); + JJ_DEBUG("Accessing object property '%s'", key.c_str()); auto & obj = object->as_object(); auto it = obj.find(key); if (it != obj.end()) { val = it->second->clone(); } else { - auto builtins = object->get_builtins(); - auto bit = builtins.find(key); - if (bit != builtins.end()) { - func_args args; - args.args.push_back(object->clone()); - val = bit->second(args); - } + val = try_builtin_func(key, object); } } else if (is_val(object) || is_val(object)) { @@ -464,13 +503,7 @@ value member_expression::execute(context & ctx) { } else if (is_val(property)) { auto key = property->as_string().str(); JJ_DEBUG("Accessing %s built-in '%s'", is_val(object) ? "array" : "string", key.c_str()); - auto builtins = object->get_builtins(); - auto bit = builtins.find(key); - if (bit != builtins.end()) { - func_args args; - args.args.push_back(object->clone()); - val = bit->second(args); - } + val = try_builtin_func(key, object); } else { throw std::runtime_error("Cannot access property with non-string/non-number: got " + property->type()); } @@ -480,13 +513,7 @@ value member_expression::execute(context & ctx) { throw std::runtime_error("Cannot access property with non-string: got " + property->type()); } auto key = property->as_string().str(); - auto builtins = object->get_builtins(); - auto bit = builtins.find(key); - if (bit != builtins.end()) { - func_args args; - args.args.push_back(object->clone()); - val = bit->second(args); - } + val = try_builtin_func(key, object); } return val; diff --git a/common/jinja/jinja-vm.h b/common/jinja/jinja-vm.h index 7c431cd47e..786d49bad1 100644 --- a/common/jinja/jinja-vm.h +++ b/common/jinja/jinja-vm.h @@ -166,12 +166,16 @@ struct macro_statement : public statement { } std::string type() const override { return "Macro"; } + value execute(context & ctx) override; }; struct comment_statement : public statement { std::string val; explicit comment_statement(const std::string & v) : val(v) {} std::string type() const override { return "Comment"; } + value execute(context &) override { + return mk_val(); + } }; // Expressions @@ -339,6 +343,7 @@ struct select_expression : public expression { /** * An operation with two sides, separated by the "is" operator. + * NOTE: "value is something" translates to function call "test_is_something(value)" */ struct test_expression : public expression { statement_ptr operand; @@ -351,6 +356,7 @@ struct test_expression : public expression { chk_type(this->test); } std::string type() const override { return "TestExpression"; } + value execute(context & ctx) override; }; /** @@ -365,6 +371,7 @@ struct unary_expression : public expression { chk_type(this->argument); } std::string type() const override { return "UnaryExpression"; } + value execute(context & ctx) override; }; struct slice_expression : public expression { @@ -442,14 +449,34 @@ struct vm { context & ctx; explicit vm(context & ctx) : ctx(ctx) {} - std::vector execute(program & prog) { - std::vector results; + value_array execute(program & prog) { + value_array results = mk_val(); for (auto & stmt : prog.body) { value res = stmt->execute(ctx); - results.push_back(std::move(res)); + results->val_arr->push_back(std::move(res)); } return results; } + + std::vector gather_string_parts(const value & val) { + std::vector parts; + gather_string_parts_recursive(val, parts); + return parts; + } + + void gather_string_parts_recursive(const value & val, std::vector & parts) { + if (is_val(val)) { + const auto & str_val = dynamic_cast(val.get())->val_str; + for (const auto & part : str_val.parts) { + parts.push_back(part); + } + } else if (is_val(val)) { + auto items = dynamic_cast(val.get())->val_arr.get(); + for (const auto & item : *items) { + gather_string_parts_recursive(item, parts); + } + } + } }; } // namespace jinja diff --git a/tests/test-chat-jinja.cpp b/tests/test-chat-jinja.cpp index acbf7daf2a..87ac00fca1 100644 --- a/tests/test-chat-jinja.cpp +++ b/tests/test-chat-jinja.cpp @@ -3,6 +3,7 @@ #include #include #include +#include #undef NDEBUG #include @@ -11,12 +12,15 @@ #include "jinja/jinja-lexer.h" int main(void) { - std::string contents = "{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '' + role + '\\n' + message['content'] | trim + '\\n' }}{% endfor %}{% if add_generation_prompt %}{{'model\\n'}}{% endif %}"; + //std::string contents = "{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '' + role + '\\n' + message['content'] | trim + '\\n' }}{% endfor %}{% if add_generation_prompt %}{{'model\\n'}}{% endif %}"; //std::string contents = "{% if messages[0]['role'] != 'system' %}nice {{ messages[0]['content'] }}{% endif %}"; //std::string contents = " {{ messages[0]['content'] }} "; + std::ifstream infile("models/templates/moonshotai-Kimi-K2.jinja"); + std::string contents((std::istreambuf_iterator(infile)), std::istreambuf_iterator()); + std::cout << "=== INPUT ===\n" << contents << "\n\n"; jinja::lexer lexer; @@ -56,14 +60,12 @@ int main(void) { ctx.var["messages"] = std::move(messages); jinja::vm vm(ctx); - auto results = vm.execute(ast); + const jinja::value results = vm.execute(ast); + auto parts = vm.gather_string_parts(results); std::cout << "\n=== RESULTS ===\n"; - for (const auto & res : results) { - if (res->is_null()) { - continue; - } - std::cout << "result type: " << res->type() << " | value: " << res->as_repr(); + for (const auto & part : parts) { + std::cout << (part.is_input ? "DATA" : "TMPL") << ": " << part.val << "\n"; } return 0; From 4331e9c8e979bedff396f4a4e5764fa50df8df92 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Sun, 28 Dec 2025 17:23:29 +0100 Subject: [PATCH 19/47] keyword arguments and slicing array --- common/jinja/jinja-value.h | 29 +++++----- common/jinja/jinja-vm-builtins.cpp | 85 +++++++++++++++++++++++++++++ common/jinja/jinja-vm.cpp | 86 +++++++++++++++++++++--------- common/jinja/jinja-vm.h | 34 +++++++----- tests/test-chat-jinja.cpp | 2 +- 5 files changed, 184 insertions(+), 52 deletions(-) diff --git a/common/jinja/jinja-value.h b/common/jinja/jinja-value.h index 787cec46b3..2bb600c1b9 100644 --- a/common/jinja/jinja-value.h +++ b/common/jinja/jinja-value.h @@ -55,6 +55,7 @@ struct func_args { throw std::runtime_error("Expected " + std::to_string(count) + " arguments, got " + std::to_string(args.size())); } } + // TODO: add support for get kwargs // utility functions template void ensure_vals() const { ensure_count(1); @@ -187,19 +188,6 @@ struct value_array_t : public value_t { // point to the same underlying data val_arr = v->val_arr; } - value_array_t(value_array_t & other, size_t start = 0, size_t end = -1) { - val_arr = std::make_shared>(); - size_t sz = other.val_arr->size(); - if (end == static_cast(-1) || end > sz) { - end = sz; - } - if (start > end || start >= sz) { - return; - } - for (size_t i = start; i < end; i++) { - val_arr->push_back(other.val_arr->at(i)->clone()); - } - } void push_back(const value & val) { val_arr->push_back(val->clone()); } @@ -319,6 +307,21 @@ struct value_undefined_t : public value_t { }; using value_undefined = std::unique_ptr; +// special value for kwarg +struct value_kwarg_t : public value_t { + std::string key; + value val; + value_kwarg_t(const value_kwarg_t & other) { + key = other.key; + val = other.val->clone(); + } + value_kwarg_t(const std::string & k, const value & v) : key(k), val(v->clone()) {} + virtual std::string type() const override { return "KwArg"; } + virtual std::string as_repr() const override { return type(); } + virtual value clone() const override { return std::make_unique(*this); } +}; +using value_kwarg = std::unique_ptr; + const func_builtins & global_builtins(); diff --git a/common/jinja/jinja-vm-builtins.cpp b/common/jinja/jinja-vm-builtins.cpp index 160001e522..feb7ffb5d2 100644 --- a/common/jinja/jinja-vm-builtins.cpp +++ b/common/jinja/jinja-vm-builtins.cpp @@ -5,9 +5,62 @@ #include #include +#include +#include +#include namespace jinja { +/** + * Function that mimics Python's array slicing. + */ +template +static T slice(const T & array, std::optional start = std::nullopt, std::optional stop = std::nullopt, int64_t step = 1) { + int64_t len = static_cast(array.size()); + int64_t direction = (step > 0) ? 1 : ((step < 0) ? -1 : 0); + int64_t start_val; + int64_t stop_val; + if (direction >= 0) { + start_val = start.value_or(0); + if (start_val < 0) { + start_val = std::max(len + start_val, (int64_t)0); + } else { + start_val = std::min(start_val, len); + } + + stop_val = stop.value_or(len); + if (stop_val < 0) { + stop_val = std::max(len + stop_val, (int64_t)0); + } else { + stop_val = std::min(stop_val, len); + } + } else { + start_val = start.value_or(len - 1); + if (start_val < 0) { + start_val = std::max(len + start_val, (int64_t)-1); + } else { + start_val = std::min(start_val, len - 1); + } + + stop_val = stop.value_or(-1); + if (stop_val < -1) { + stop_val = std::max(len + stop_val, (int64_t)-1); + } else { + stop_val = std::min(stop_val, len - 1); + } + } + T result; + if (direction == 0) { + return result; + } + for (int64_t i = start_val; direction * i < direction * stop_val; i += step) { + if (i >= 0 && i < len) { + result.push_back(std::move(array[static_cast(i)]->clone())); + } + } + return result; +} + template static value test_type_fn(const func_args & args) { args.ensure_count(1); @@ -28,6 +81,17 @@ const func_builtins & global_builtins() { std::string msg = args.args[0]->as_string().str(); throw raised_exception("Jinja Exception: " + msg); }}, + {"namespace", [](const func_args & args) -> value { + auto out = mk_val(); + for (const auto & arg : args.args) { + if (!is_val(arg)) { + throw raised_exception("namespace() arguments must be kwargs"); + } + auto kwarg = dynamic_cast(arg.get()); + out->insert(kwarg->key, kwarg->val); + } + return out; + }}, // tests {"test_is_boolean", test_type_fn}, @@ -126,6 +190,8 @@ const func_builtins & value_float_t::get_builtins() const { // return str.substr(start, end - start); // } + + static bool string_startswith(const std::string & str, const std::string & prefix) { if (str.length() < prefix.length()) return false; return str.compare(0, prefix.length(), prefix) == 0; @@ -250,6 +316,9 @@ const func_builtins & value_string_t::get_builtins() const { {"join", [](const func_args &) -> value { throw std::runtime_error("join builtin not implemented"); }}, + {"slice", [](const func_args &) -> value { + throw std::runtime_error("slice builtin not implemented"); + }}, }; return builtins; } @@ -309,6 +378,22 @@ const func_builtins & value_array_t::get_builtins() const { const auto & arr = args.args[0]->as_array(); return mk_val(static_cast(arr.size())); }}, + {"slice", [](const func_args & args) -> value { + args.ensure_count(4); + int64_t start = is_val(args.args[1]) ? args.args[1]->as_int() : 0; + int64_t stop = is_val(args.args[2]) ? args.args[2]->as_int() : -1; + int64_t step = is_val(args.args[3]) ? args.args[3]->as_int() : 1; + if (!is_val(args.args[0])) { + throw raised_exception("slice() first argument must be an array"); + } + if (step == 0) { + throw raised_exception("slice step cannot be zero"); + } + auto arr = slice(args.args[0]->as_array(), start, stop, step); + auto res = mk_val(); + res->val_arr = std::make_shared>(std::move(arr)); + return res; + }}, // TODO: reverse, sort, join, string, unique }; return builtins; diff --git a/common/jinja/jinja-vm.cpp b/common/jinja/jinja-vm.cpp index bd4d53bded..f39321fa00 100644 --- a/common/jinja/jinja-vm.cpp +++ b/common/jinja/jinja-vm.cpp @@ -35,7 +35,7 @@ value identifier::execute(context & ctx) { return it->second->clone(); } else if (builtins.find(val) != builtins.end()) { JJ_DEBUG("Identifier '%s' found in builtins", val.c_str()); - return mk_val(builtins.at(val)); + return mk_val(builtins.at(val), val); } else { JJ_DEBUG("Identifier '%s' not found, returning undefined", val.c_str()); return mk_val(); @@ -168,13 +168,16 @@ value binary_expression::execute(context & ctx) { throw std::runtime_error("Unknown operator \"" + op.value + "\" between " + left_val->type() + " and " + right_val->type()); } -static value try_builtin_func(const std::string & name, const value & input) { +static value try_builtin_func(const std::string & name, const value & input, bool undef_on_missing = true) { auto builtins = input->get_builtins(); auto it = builtins.find(name); if (it != builtins.end()) { JJ_DEBUG("Binding built-in '%s'", name.c_str()); return mk_val(it->second, input, name); } + if (undef_on_missing) { + return mk_val(); + } throw std::runtime_error("Unknown (built-in) filter '" + name + "' for type " + input->type()); } @@ -189,12 +192,11 @@ value filter_expression::execute(context & ctx) { throw std::runtime_error("to_json filter not implemented"); } - auto str = input->as_string(); if (filter_val == "trim") { filter_val = "strip"; // alias } JJ_DEBUG("Applying filter '%s' to %s", filter_val.c_str(), input->type().c_str()); - return try_builtin_func(filter_val, input); + return try_builtin_func(filter_val, input)->invoke({}); } else if (is_stmt(filter)) { // TODO @@ -385,7 +387,7 @@ value set_statement::execute(context & ctx) { if (is_stmt(assignee)) { auto var_name = dynamic_cast(assignee.get())->val; - JJ_DEBUG("Setting variable '%s'", var_name.c_str()); + JJ_DEBUG("Setting variable '%s' with value type %s", var_name.c_str(), rhs->type().c_str()); ctx.var[var_name] = rhs->clone(); } else if (is_stmt(assignee)) { @@ -408,10 +410,6 @@ value set_statement::execute(context & ctx) { } else if (is_stmt(assignee)) { auto member = dynamic_cast(assignee.get()); - value object = member->object->execute(ctx); - if (!is_val(object)) { - throw std::runtime_error("Cannot assign to member of non-object"); - } if (member->computed) { throw std::runtime_error("Cannot assign to computed member"); } @@ -419,9 +417,14 @@ value set_statement::execute(context & ctx) { throw std::runtime_error("Cannot assign to member with non-identifier property"); } auto prop_name = dynamic_cast(member->property.get())->val; - auto obj_ptr = dynamic_cast(object.get()); + + value object = member->object->execute(ctx); + if (!is_val(object)) { + throw std::runtime_error("Cannot assign to member of non-object"); + } + auto obj_ptr = dynamic_cast(object.get()); JJ_DEBUG("Setting object property '%s'", prop_name.c_str()); - obj_ptr->get()->insert(prop_name, rhs->clone()); + obj_ptr->insert(prop_name, rhs->clone()); } else { throw std::runtime_error("Invalid LHS inside assignment expression: " + assignee->type()); @@ -462,7 +465,26 @@ value member_expression::execute(context & ctx) { value property; if (this->computed) { JJ_DEBUG("Member expression, computing property type %s", this->property->type().c_str()); - property = this->property->execute(ctx); + if (is_stmt(this->property)) { + auto s = dynamic_cast(this->property.get()); + value start_val = s->start_expr ? s->start_expr->execute(ctx) : mk_val(); + value stop_val = s->stop_expr ? s->stop_expr->execute(ctx) : mk_val(); + value step_val = s->step_expr ? s->step_expr->execute(ctx) : mk_val(); + + // translate to function call: obj.slice(start, stop, step) + JJ_DEBUG("Member expression is a slice: start %s, stop %s, step %s", + start_val->as_repr().c_str(), + stop_val->as_repr().c_str(), + step_val->as_repr().c_str()); + auto slice_func = try_builtin_func("slice", object); + func_args args; + args.args.push_back(start_val->clone()); + args.args.push_back(stop_val->clone()); + args.args.push_back(step_val->clone()); + return slice_func->invoke(args); + } else { + property = this->property->execute(ctx); + } } else { property = mk_val(dynamic_cast(this->property.get())->val); } @@ -482,7 +504,7 @@ value member_expression::execute(context & ctx) { if (it != obj.end()) { val = it->second->clone(); } else { - val = try_builtin_func(key, object); + val = try_builtin_func(key, object, true); } } else if (is_val(object) || is_val(object)) { @@ -519,22 +541,22 @@ value member_expression::execute(context & ctx) { return val; } -static func_args gather_call_args(const statements & arg_stmts, context & ctx) { - func_args args; - for (auto & arg_stmt : arg_stmts) { - args.args.push_back(arg_stmt->execute(ctx)); - } - return args; -} - value call_expression::execute(context & ctx) { - auto args = gather_call_args(this->args, ctx); + // gather arguments + func_args args; + for (auto & arg_stmt : this->args) { + auto arg_val = arg_stmt->execute(ctx); + JJ_DEBUG(" Argument type: %s", arg_val->type().c_str()); + args.args.push_back(std::move(arg_val)); + } + // execute callee value callee_val = callee->execute(ctx); - JJ_DEBUG("Calling function of type %s with %zu arguments", callee_val->type().c_str(), args.args.size()); - if (!is_val(callee_val)) { + if (!is_val(callee_val)) { throw std::runtime_error("Callee is not a function: got " + callee_val->type()); } - return callee_val->invoke(args); + auto * callee_func = dynamic_cast(callee_val.get()); + JJ_DEBUG("Calling function '%s' with %zu arguments", callee_func->name.c_str(), args.args.size()); + return callee_func->invoke(args); } // compare operator for value_t @@ -570,4 +592,18 @@ bool value_compare(const value & a, const value & b) { return false; } +value keyword_argument_expression::execute(context & ctx) { + if (!is_stmt(key)) { + throw std::runtime_error("Keyword argument key must be identifiers"); + } + + std::string k = dynamic_cast(key.get())->val; + JJ_DEBUG("Keyword argument expression key: %s, value: %s", k.c_str(), val->type().c_str()); + + value v = val->execute(ctx); + JJ_DEBUG("Keyword argument value executed, type: %s", v->type().c_str()); + + return mk_val(k, v); +} + } // namespace jinja diff --git a/common/jinja/jinja-vm.h b/common/jinja/jinja-vm.h index 786d49bad1..a931bc1ea8 100644 --- a/common/jinja/jinja-vm.h +++ b/common/jinja/jinja-vm.h @@ -15,7 +15,11 @@ namespace jinja { struct context { std::map var; - context() = default; + context() { + var["true"] = mk_val(true); + var["false"] = mk_val(false); + var["none"] = mk_val(); + } ~context() = default; context(const context & parent) { @@ -375,29 +379,33 @@ struct unary_expression : public expression { }; struct slice_expression : public expression { - statement_ptr start; - statement_ptr stop; - statement_ptr step; + statement_ptr start_expr; + statement_ptr stop_expr; + statement_ptr step_expr; - slice_expression(statement_ptr && start, statement_ptr && stop, statement_ptr && step) - : start(std::move(start)), stop(std::move(stop)), step(std::move(step)) { - chk_type(this->start); - chk_type(this->stop); - chk_type(this->step); + slice_expression(statement_ptr && start_expr, statement_ptr && stop_expr, statement_ptr && step_expr) + : start_expr(std::move(start_expr)), stop_expr(std::move(stop_expr)), step_expr(std::move(step_expr)) { + chk_type(this->start_expr); + chk_type(this->stop_expr); + chk_type(this->step_expr); } std::string type() const override { return "SliceExpression"; } + value execute(context &) override { + throw std::runtime_error("must be handled by MemberExpression"); + } }; struct keyword_argument_expression : public expression { statement_ptr key; - statement_ptr value; + statement_ptr val; - keyword_argument_expression(statement_ptr && key, statement_ptr && value) - : key(std::move(key)), value(std::move(value)) { + keyword_argument_expression(statement_ptr && key, statement_ptr && val) + : key(std::move(key)), val(std::move(val)) { chk_type(this->key); - chk_type(this->value); + chk_type(this->val); } std::string type() const override { return "KeywordArgumentExpression"; } + value execute(context & ctx) override; }; struct spread_expression : public expression { diff --git a/tests/test-chat-jinja.cpp b/tests/test-chat-jinja.cpp index 87ac00fca1..ce17df5b1d 100644 --- a/tests/test-chat-jinja.cpp +++ b/tests/test-chat-jinja.cpp @@ -18,7 +18,7 @@ int main(void) { //std::string contents = " {{ messages[0]['content'] }} "; - std::ifstream infile("models/templates/moonshotai-Kimi-K2.jinja"); + std::ifstream infile("models/templates/Qwen-Qwen3-0.6B.jinja"); std::string contents((std::istreambuf_iterator(infile)), std::istreambuf_iterator()); std::cout << "=== INPUT ===\n" << contents << "\n\n"; From 7f17608ea433729e47751d452eb7545768ed45d9 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Sun, 28 Dec 2025 17:46:25 +0100 Subject: [PATCH 20/47] use shared_ptr for values --- common/jinja/jinja-value.h | 113 +++++++++++------------------ common/jinja/jinja-vm-builtins.cpp | 28 +++---- common/jinja/jinja-vm.cpp | 85 ++++++++++------------ common/jinja/jinja-vm.h | 40 ++++++++-- tests/test-chat-jinja.cpp | 18 ++--- 5 files changed, 137 insertions(+), 147 deletions(-) diff --git a/common/jinja/jinja-value.h b/common/jinja/jinja-value.h index 2bb600c1b9..6c6f4a30d6 100644 --- a/common/jinja/jinja-value.h +++ b/common/jinja/jinja-value.h @@ -12,7 +12,7 @@ namespace jinja { struct value_t; -using value = std::unique_ptr; +using value = std::shared_ptr; // Helper to check the type of a value @@ -21,7 +21,7 @@ struct extract_pointee { using type = T; }; template -struct extract_pointee> { +struct extract_pointee> { using type = U; }; template @@ -35,9 +35,19 @@ bool is_val(const value_t * ptr) { return dynamic_cast(ptr) != nullptr; } template -std::unique_ptr::type> mk_val(Args&&... args) { +std::shared_ptr::type> mk_val(Args&&... args) { using PointeeType = typename extract_pointee::type; - return std::make_unique(std::forward(args)...); + return std::make_shared(std::forward(args)...); +} +template +const typename extract_pointee::type * cast_val(const value & ptr) { + using PointeeType = typename extract_pointee::type; + return dynamic_cast(ptr.get()); +} +template +typename extract_pointee::type * cast_val(value & ptr) { + using PointeeType = typename extract_pointee::type; + return dynamic_cast(ptr.get()); } template void ensure_val(const value & ptr) { @@ -91,8 +101,8 @@ struct value_t { // my_arr = [my_obj] // my_obj["a"] = 3 // print(my_arr[0]["a"]) # should print 3 - std::shared_ptr> val_arr; - std::shared_ptr> val_obj; + std::vector val_arr; + std::map val_obj; func_handler val_func; @@ -116,10 +126,6 @@ struct value_t { } virtual std::string as_repr() const { return as_string().str(); } - - virtual value clone() const { - return std::make_unique(*this); - } }; @@ -129,10 +135,9 @@ struct value_int_t : public value_t { virtual int64_t as_int() const override { return val_int; } virtual double as_float() const override { return static_cast(val_int); } virtual string as_string() const override { return std::to_string(val_int); } - virtual value clone() const override { return std::make_unique(*this); } virtual const func_builtins & get_builtins() const override; }; -using value_int = std::unique_ptr; +using value_int = std::shared_ptr; struct value_float_t : public value_t { @@ -141,10 +146,9 @@ struct value_float_t : public value_t { virtual double as_float() const override { return val_flt; } virtual int64_t as_int() const override { return static_cast(val_flt); } virtual string as_string() const override { return std::to_string(val_flt); } - virtual value clone() const override { return std::make_unique(*this); } virtual const func_builtins & get_builtins() const override; }; -using value_float = std::unique_ptr; +using value_float = std::shared_ptr; struct value_string_t : public value_t { @@ -160,13 +164,12 @@ struct value_string_t : public value_t { } return ss.str(); } - virtual value clone() const override { return std::make_unique(*this); } virtual const func_builtins & get_builtins() const override; void mark_input() { val_str.mark_input(); } }; -using value_string = std::unique_ptr; +using value_string = std::shared_ptr; struct value_bool_t : public value_t { @@ -174,92 +177,68 @@ struct value_bool_t : public value_t { virtual std::string type() const override { return "Boolean"; } virtual bool as_bool() const override { return val_bool; } virtual string as_string() const override { return std::string(val_bool ? "True" : "False"); } - virtual value clone() const override { return std::make_unique(*this); } virtual const func_builtins & get_builtins() const override; }; -using value_bool = std::unique_ptr; +using value_bool = std::shared_ptr; struct value_array_t : public value_t { - value_array_t() { - val_arr = std::make_shared>(); - } + value_array_t() = default; value_array_t(value & v) { // point to the same underlying data val_arr = v->val_arr; } void push_back(const value & val) { - val_arr->push_back(val->clone()); + val_arr.push_back(val); } virtual std::string type() const override { return "Array"; } - virtual const std::vector & as_array() const override { return *val_arr; } - // clone will also share the underlying data (point to the same vector) - virtual value clone() const override { - auto tmp = std::make_unique(); - tmp->val_arr = this->val_arr; - return tmp; - } + virtual const std::vector & as_array() const override { return val_arr; } virtual string as_string() const override { std::ostringstream ss; ss << "["; - for (size_t i = 0; i < val_arr->size(); i++) { + for (size_t i = 0; i < val_arr.size(); i++) { if (i > 0) ss << ", "; - ss << val_arr->at(i)->as_repr(); + ss << val_arr.at(i)->as_repr(); } ss << "]"; return ss.str(); } virtual bool as_bool() const override { - return !val_arr->empty(); + return !val_arr.empty(); } virtual const func_builtins & get_builtins() const override; }; -using value_array = std::unique_ptr; +using value_array = std::shared_ptr; struct value_object_t : public value_t { - value_object_t() { - val_obj = std::make_shared>(); - } + value_object_t() = default; value_object_t(value & v) { // point to the same underlying data val_obj = v->val_obj; } value_object_t(const std::map & obj) { - val_obj = std::make_shared>(); + val_obj = std::map(); for (const auto & pair : obj) { - (*val_obj)[pair.first] = pair.second->clone(); + val_obj[pair.first] = pair.second; } } void insert(const std::string & key, const value & val) { - (*val_obj)[key] = val->clone(); + val_obj[key] = val; } virtual std::string type() const override { return "Object"; } - virtual const std::map & as_object() const override { return *val_obj; } - // clone will also share the underlying data (point to the same map) - virtual value clone() const override { - auto tmp = std::make_unique(); - tmp->val_obj = this->val_obj; - return tmp; - } + virtual const std::map & as_object() const override { return val_obj; } virtual bool as_bool() const override { - return !val_obj->empty(); + return !val_obj.empty(); } virtual const func_builtins & get_builtins() const override; }; -using value_object = std::unique_ptr; +using value_object = std::shared_ptr; struct value_func_t : public value_t { std::string name; // for debugging value arg0; // bound "this" argument, if any - value_func_t(const value_func_t & other) { - val_func = other.val_func; - name = other.name; - if (other.arg0) { - arg0 = other.arg0->clone(); - } - } value_func_t(const func_handler & func, std::string func_name = "") { val_func = func; name = func_name; @@ -267,14 +246,14 @@ struct value_func_t : public value_t { value_func_t(const func_handler & func, const value & arg_this, std::string func_name = "") { val_func = func; name = func_name; - arg0 = arg_this->clone(); + arg0 = arg_this; } virtual value invoke(const func_args & args) const override { if (arg0) { func_args new_args; - new_args.args.push_back(arg0->clone()); + new_args.args.push_back(arg0); for (const auto & a : args.args) { - new_args.args.push_back(a->clone()); + new_args.args.push_back(a); } return val_func(new_args); } else { @@ -283,9 +262,8 @@ struct value_func_t : public value_t { } virtual std::string type() const override { return "Function"; } virtual std::string as_repr() const override { return type(); } - virtual value clone() const override { return std::make_unique(*this); } }; -using value_func = std::unique_ptr; +using value_func = std::shared_ptr; struct value_null_t : public value_t { @@ -293,9 +271,8 @@ struct value_null_t : public value_t { virtual bool is_null() const override { return true; } virtual bool as_bool() const override { return false; } virtual std::string as_repr() const override { return type(); } - virtual value clone() const override { return std::make_unique(*this); } }; -using value_null = std::unique_ptr; +using value_null = std::shared_ptr; struct value_undefined_t : public value_t { @@ -303,24 +280,18 @@ struct value_undefined_t : public value_t { virtual bool is_undefined() const override { return true; } virtual bool as_bool() const override { return false; } virtual std::string as_repr() const override { return type(); } - virtual value clone() const override { return std::make_unique(*this); } }; -using value_undefined = std::unique_ptr; +using value_undefined = std::shared_ptr; // special value for kwarg struct value_kwarg_t : public value_t { std::string key; value val; - value_kwarg_t(const value_kwarg_t & other) { - key = other.key; - val = other.val->clone(); - } - value_kwarg_t(const std::string & k, const value & v) : key(k), val(v->clone()) {} + value_kwarg_t(const std::string & k, const value & v) : key(k), val(v) {} virtual std::string type() const override { return "KwArg"; } virtual std::string as_repr() const override { return type(); } - virtual value clone() const override { return std::make_unique(*this); } }; -using value_kwarg = std::unique_ptr; +using value_kwarg = std::shared_ptr; const func_builtins & global_builtins(); diff --git a/common/jinja/jinja-vm-builtins.cpp b/common/jinja/jinja-vm-builtins.cpp index feb7ffb5d2..ed601eb9b1 100644 --- a/common/jinja/jinja-vm-builtins.cpp +++ b/common/jinja/jinja-vm-builtins.cpp @@ -55,7 +55,7 @@ static T slice(const T & array, std::optional start = std::nullopt, std } for (int64_t i = start_val; direction * i < direction * stop_val; i += step) { if (i >= 0 && i < len) { - result.push_back(std::move(array[static_cast(i)]->clone())); + result.push_back(array[static_cast(i)]); } } return result; @@ -87,7 +87,7 @@ const func_builtins & global_builtins() { if (!is_val(arg)) { throw raised_exception("namespace() arguments must be kwargs"); } - auto kwarg = dynamic_cast(arg.get()); + auto kwarg = cast_val(arg); out->insert(kwarg->key, kwarg->val); } return out; @@ -265,12 +265,12 @@ const func_builtins & value_string_t::get_builtins() const { std::string token; while ((pos = str.find(delim)) != std::string::npos) { token = str.substr(0, pos); - result->val_arr->push_back(mk_val(token)); + result->push_back(mk_val(token)); str.erase(0, pos + delim.length()); } auto res = mk_val(str); res->val_str.mark_input_based_on(args.args[0]->val_str); - result->val_arr->push_back(std::move(res)); + result->push_back(std::move(res)); return std::move(result); }}, {"replace", [](const func_args & args) -> value { @@ -353,7 +353,7 @@ const func_builtins & value_array_t::get_builtins() const { const auto & arr = args.args[0]->as_array(); auto result = mk_val(); for (const auto& v : arr) { - result->val_arr->push_back(v->clone()); + result->push_back(v); } return result; }}, @@ -363,7 +363,7 @@ const func_builtins & value_array_t::get_builtins() const { if (arr.empty()) { return mk_val(); } - return arr[0]->clone(); + return arr[0]; }}, {"last", [](const func_args & args) -> value { args.ensure_vals(); @@ -371,7 +371,7 @@ const func_builtins & value_array_t::get_builtins() const { if (arr.empty()) { return mk_val(); } - return arr[arr.size() - 1]->clone(); + return arr[arr.size() - 1]; }}, {"length", [](const func_args & args) -> value { args.ensure_vals(); @@ -391,7 +391,7 @@ const func_builtins & value_array_t::get_builtins() const { } auto arr = slice(args.args[0]->as_array(), start, stop, step); auto res = mk_val(); - res->val_arr = std::make_shared>(std::move(arr)); + res->val_arr = std::move(arr); return res; }}, // TODO: reverse, sort, join, string, unique @@ -408,7 +408,7 @@ const func_builtins & value_object_t::get_builtins() const { std::string key = args.args[1]->as_string().str(); auto it = obj.find(key); if (it != obj.end()) { - return it->second->clone(); + return it->second; } else { return mk_val(); } @@ -418,7 +418,7 @@ const func_builtins & value_object_t::get_builtins() const { const auto & obj = args.args[0]->as_object(); auto result = mk_val(); for (const auto & pair : obj) { - result->val_arr->push_back(mk_val(pair.first)); + result->push_back(mk_val(pair.first)); } return result; }}, @@ -427,7 +427,7 @@ const func_builtins & value_object_t::get_builtins() const { const auto & obj = args.args[0]->as_object(); auto result = mk_val(); for (const auto & pair : obj) { - result->val_arr->push_back(pair.second->clone()); + result->push_back(pair.second); } return result; }}, @@ -437,9 +437,9 @@ const func_builtins & value_object_t::get_builtins() const { auto result = mk_val(); for (const auto & pair : obj) { auto item = mk_val(); - item->val_arr->push_back(mk_val(pair.first)); - item->val_arr->push_back(pair.second->clone()); - result->val_arr->push_back(std::move(item)); + item->push_back(mk_val(pair.first)); + item->push_back(pair.second); + result->push_back(std::move(item)); } return result; }}, diff --git a/common/jinja/jinja-vm.cpp b/common/jinja/jinja-vm.cpp index f39321fa00..fea7c75f06 100644 --- a/common/jinja/jinja-vm.cpp +++ b/common/jinja/jinja-vm.cpp @@ -13,16 +13,11 @@ namespace jinja { -template -static bool is_stmt(const statement_ptr & ptr) { - return dynamic_cast(ptr.get()) != nullptr; -} - static value_array exec_statements(const statements & stmts, context & ctx) { auto result = mk_val(); for (const auto & stmt : stmts) { JJ_DEBUG("Executing statement of type %s", stmt->type().c_str()); - result->val_arr->push_back(stmt->execute(ctx)); + result->push_back(stmt->execute(ctx)); } return result; } @@ -32,7 +27,7 @@ value identifier::execute(context & ctx) { auto builtins = global_builtins(); if (it != ctx.var.end()) { JJ_DEBUG("Identifier '%s' found", val.c_str()); - return it->second->clone(); + return it->second; } else if (builtins.find(val) != builtins.end()) { JJ_DEBUG("Identifier '%s' found in builtins", val.c_str()); return mk_val(builtins.at(val), val); @@ -115,10 +110,10 @@ value binary_expression::execute(context & ctx) { auto & right_arr = right_val->as_array(); auto result = mk_val(); for (const auto & item : left_arr) { - result->val_arr->push_back(item->clone()); + result->push_back(item); } for (const auto & item : right_arr) { - result->val_arr->push_back(item->clone()); + result->push_back(item); } return result; } @@ -185,7 +180,7 @@ value filter_expression::execute(context & ctx) { value input = operand->execute(ctx); if (is_stmt(filter)) { - auto filter_val = dynamic_cast(filter.get())->val; + auto filter_val = cast_stmt(filter)->val; if (filter_val == "to_json") { // TODO: Implement to_json filter @@ -215,7 +210,7 @@ value test_expression::execute(context & ctx) { throw std::runtime_error("Invalid test expression"); } - auto test_id = dynamic_cast(test.get())->val; + auto test_id = cast_stmt(test)->val; auto it = builtins.find("test_is_" + test_id); JJ_DEBUG("Test expression %s '%s'", operand->type().c_str(), test_id.c_str()); if (it == builtins.end()) { @@ -252,12 +247,12 @@ value if_statement::execute(context & ctx) { if (test_val->as_bool()) { for (auto & stmt : body) { JJ_DEBUG("IF --> Executing THEN body, current block: %s", stmt->type().c_str()); - out->val_arr->push_back(stmt->execute(ctx)); + out->push_back(stmt->execute(ctx)); } } else { for (auto & stmt : alternate) { JJ_DEBUG("IF --> Executing ELSE body, current block: %s", stmt->type().c_str()); - out->val_arr->push_back(stmt->execute(ctx)); + out->push_back(stmt->execute(ctx)); } } return out; @@ -271,7 +266,7 @@ value for_statement::execute(context & ctx) { if (is_stmt(iterable)) { JJ_DEBUG("%s", "For loop has test expression"); - auto select = dynamic_cast(iterable.get()); + auto select = cast_stmt(iterable); iter_expr = std::move(select->lhs); test_expr = std::move(select->test); } @@ -292,7 +287,7 @@ value for_statement::execute(context & ctx) { } else { auto & arr = iterable_val->as_array(); for (const auto & item : arr) { - items.push_back(item->clone()); + items.push_back(item); } } @@ -306,12 +301,12 @@ value for_statement::execute(context & ctx) { std::function scope_update_fn = [](context &) { /* no-op */}; if (is_stmt(loopvar)) { - auto id = dynamic_cast(loopvar.get())->val; + auto id = cast_stmt(loopvar)->val; scope_update_fn = [id, &items, i](context & ctx) { - ctx.var[id] = items[i]->clone(); + ctx.var[id] = items[i]; }; } else if (is_stmt(loopvar)) { - auto tuple = dynamic_cast(loopvar.get()); + auto tuple = cast_stmt(loopvar); if (!is_val(current)) { throw std::runtime_error("Cannot unpack non-iterable type: " + current->type()); } @@ -325,8 +320,8 @@ value for_statement::execute(context & ctx) { if (!is_stmt(tuple->val[j])) { throw std::runtime_error("Cannot unpack non-identifier type: " + tuple->val[j]->type()); } - auto id = dynamic_cast(tuple->val[j].get())->val; - ctx.var[id] = c_arr[j]->clone(); + auto id = cast_stmt(tuple->val[j])->val; + ctx.var[id] = c_arr[j]; } }; } else { @@ -339,7 +334,7 @@ value for_statement::execute(context & ctx) { continue; } } - filtered_items.push_back(current->clone()); + filtered_items.push_back(current); scope_update_fns.push_back(scope_update_fn); } @@ -356,9 +351,9 @@ value for_statement::execute(context & ctx) { loop_obj->insert("first", mk_val(i == 0)); loop_obj->insert("last", mk_val(i == filtered_items.size() - 1)); loop_obj->insert("length", mk_val(filtered_items.size())); - loop_obj->insert("previtem", i > 0 ? filtered_items[i - 1]->clone() : mk_val()); - loop_obj->insert("nextitem", i < filtered_items.size() - 1 ? filtered_items[i + 1]->clone() : mk_val()); - ctx.var["loop"] = loop_obj->clone(); + loop_obj->insert("previtem", i > 0 ? filtered_items[i - 1] : mk_val()); + loop_obj->insert("nextitem", i < filtered_items.size() - 1 ? filtered_items[i + 1] : mk_val()); + ctx.var["loop"] = loop_obj; scope_update_fns[i](ctx); try { for (auto & stmt : body) { @@ -386,12 +381,12 @@ value set_statement::execute(context & ctx) { auto rhs = val ? val->execute(ctx) : exec_statements(body, ctx); if (is_stmt(assignee)) { - auto var_name = dynamic_cast(assignee.get())->val; + auto var_name = cast_stmt(assignee)->val; JJ_DEBUG("Setting variable '%s' with value type %s", var_name.c_str(), rhs->type().c_str()); - ctx.var[var_name] = rhs->clone(); + ctx.var[var_name] = rhs; } else if (is_stmt(assignee)) { - auto tuple = dynamic_cast(assignee.get()); + auto tuple = cast_stmt(assignee); if (!is_val(rhs)) { throw std::runtime_error("Cannot unpack non-iterable type in set: " + rhs->type()); } @@ -404,27 +399,27 @@ value set_statement::execute(context & ctx) { if (!is_stmt(elem)) { throw std::runtime_error("Cannot unpack to non-identifier in set: " + elem->type()); } - auto var_name = dynamic_cast(elem.get())->val; - ctx.var[var_name] = arr[i]->clone(); + auto var_name = cast_stmt(elem)->val; + ctx.var[var_name] = arr[i]; } } else if (is_stmt(assignee)) { - auto member = dynamic_cast(assignee.get()); + auto member = cast_stmt(assignee); if (member->computed) { throw std::runtime_error("Cannot assign to computed member"); } if (!is_stmt(member->property)) { throw std::runtime_error("Cannot assign to member with non-identifier property"); } - auto prop_name = dynamic_cast(member->property.get())->val; + auto prop_name = cast_stmt(member->property)->val; value object = member->object->execute(ctx); if (!is_val(object)) { throw std::runtime_error("Cannot assign to member of non-object"); } - auto obj_ptr = dynamic_cast(object.get()); + auto obj_ptr = cast_val(object); JJ_DEBUG("Setting object property '%s'", prop_name.c_str()); - obj_ptr->insert(prop_name, rhs->clone()); + obj_ptr->insert(prop_name, rhs); } else { throw std::runtime_error("Invalid LHS inside assignment expression: " + assignee->type()); @@ -433,7 +428,7 @@ value set_statement::execute(context & ctx) { } value macro_statement::execute(context & ctx) { - std::string name = dynamic_cast(this->name.get())->val; + std::string name = cast_stmt(this->name)->val; const func_handler func = [this, &ctx, name](const func_args & args) -> value { JJ_DEBUG("Invoking macro '%s' with %zu arguments", name.c_str(), args.args.size()); context macro_ctx(ctx); // new scope for macro execution @@ -442,9 +437,9 @@ value macro_statement::execute(context & ctx) { size_t param_count = this->args.size(); size_t arg_count = args.args.size(); for (size_t i = 0; i < param_count; ++i) { - std::string param_name = dynamic_cast(this->args[i].get())->val; + std::string param_name = cast_stmt(this->args[i])->val; if (i < arg_count) { - macro_ctx.var[param_name] = args.args[i]->clone(); + macro_ctx.var[param_name] = args.args[i]; } else { macro_ctx.var[param_name] = mk_val(); } @@ -466,7 +461,7 @@ value member_expression::execute(context & ctx) { if (this->computed) { JJ_DEBUG("Member expression, computing property type %s", this->property->type().c_str()); if (is_stmt(this->property)) { - auto s = dynamic_cast(this->property.get()); + auto s = cast_stmt(this->property); value start_val = s->start_expr ? s->start_expr->execute(ctx) : mk_val(); value stop_val = s->stop_expr ? s->stop_expr->execute(ctx) : mk_val(); value step_val = s->step_expr ? s->step_expr->execute(ctx) : mk_val(); @@ -478,15 +473,15 @@ value member_expression::execute(context & ctx) { step_val->as_repr().c_str()); auto slice_func = try_builtin_func("slice", object); func_args args; - args.args.push_back(start_val->clone()); - args.args.push_back(stop_val->clone()); - args.args.push_back(step_val->clone()); + args.args.push_back(start_val); + args.args.push_back(stop_val); + args.args.push_back(step_val); return slice_func->invoke(args); } else { property = this->property->execute(ctx); } } else { - property = mk_val(dynamic_cast(this->property.get())->val); + property = mk_val(cast_stmt(this->property)->val); } JJ_DEBUG("Member expression on object type %s, property type %s", object->type().c_str(), property->type().c_str()); @@ -502,7 +497,7 @@ value member_expression::execute(context & ctx) { auto & obj = object->as_object(); auto it = obj.find(key); if (it != obj.end()) { - val = it->second->clone(); + val = it->second; } else { val = try_builtin_func(key, object, true); } @@ -514,7 +509,7 @@ value member_expression::execute(context & ctx) { if (is_val(object)) { auto & arr = object->as_array(); if (index >= 0 && index < static_cast(arr.size())) { - val = arr[index]->clone(); + val = arr[index]; } } else { // value_string auto str = object->as_string().str(); @@ -554,7 +549,7 @@ value call_expression::execute(context & ctx) { if (!is_val(callee_val)) { throw std::runtime_error("Callee is not a function: got " + callee_val->type()); } - auto * callee_func = dynamic_cast(callee_val.get()); + auto * callee_func = cast_val(callee_val); JJ_DEBUG("Calling function '%s' with %zu arguments", callee_func->name.c_str(), args.args.size()); return callee_func->invoke(args); } @@ -597,7 +592,7 @@ value keyword_argument_expression::execute(context & ctx) { throw std::runtime_error("Keyword argument key must be identifiers"); } - std::string k = dynamic_cast(key.get())->val; + std::string k = cast_stmt(key)->val; JJ_DEBUG("Keyword argument expression key: %s, value: %s", k.c_str(), val->type().c_str()); value v = val->execute(ctx); diff --git a/common/jinja/jinja-vm.h b/common/jinja/jinja-vm.h index a931bc1ea8..3cfc4b81df 100644 --- a/common/jinja/jinja-vm.h +++ b/common/jinja/jinja-vm.h @@ -12,6 +12,33 @@ namespace jinja { +struct statement; +using statement_ptr = std::unique_ptr; +using statements = std::vector; + +// Helpers for dynamic casting and type checking +template +struct extract_pointee_unique { + using type = T; +}; +template +struct extract_pointee_unique> { + using type = U; +}; +template +bool is_stmt(const statement_ptr & ptr) { + return dynamic_cast(ptr.get()) != nullptr; +} +template +T * cast_stmt(statement_ptr & ptr) { + return dynamic_cast(ptr.get()); +} +template +const T * cast_stmt(const statement_ptr & ptr) { + return dynamic_cast(ptr.get()); +} +// End Helpers + struct context { std::map var; @@ -25,7 +52,7 @@ struct context { context(const context & parent) { // inherit variables (for example, when entering a new scope) for (const auto & pair : parent.var) { - var[pair.first] = pair.second->clone(); + var[pair.first] = pair.second; } } }; @@ -39,9 +66,6 @@ struct statement { virtual value execute(context &) { throw std::runtime_error("cannot exec " + type()); } }; -using statement_ptr = std::unique_ptr; -using statements = std::vector; - // Type Checking Utilities template @@ -461,7 +485,7 @@ struct vm { value_array results = mk_val(); for (auto & stmt : prog.body) { value res = stmt->execute(ctx); - results->val_arr->push_back(std::move(res)); + results->push_back(std::move(res)); } return results; } @@ -474,13 +498,13 @@ struct vm { void gather_string_parts_recursive(const value & val, std::vector & parts) { if (is_val(val)) { - const auto & str_val = dynamic_cast(val.get())->val_str; + const auto & str_val = cast_val(val)->val_str; for (const auto & part : str_val.parts) { parts.push_back(part); } } else if (is_val(val)) { - auto items = dynamic_cast(val.get())->val_arr.get(); - for (const auto & item : *items) { + auto items = cast_val(val)->as_array(); + for (const auto & item : items) { gather_string_parts_recursive(item, parts); } } diff --git a/tests/test-chat-jinja.cpp b/tests/test-chat-jinja.cpp index ce17df5b1d..eff9831ff4 100644 --- a/tests/test-chat-jinja.cpp +++ b/tests/test-chat-jinja.cpp @@ -47,15 +47,15 @@ int main(void) { return str_val; }; - jinja::value messages = jinja::mk_val(); - jinja::value msg1 = jinja::mk_val(); - (*msg1->val_obj)["role"] = make_non_special_string("user"); - (*msg1->val_obj)["content"] = make_non_special_string("Hello, how are you?"); - messages->val_arr->push_back(std::move(msg1)); - jinja::value msg2 = jinja::mk_val(); - (*msg2->val_obj)["role"] = make_non_special_string("assistant"); - (*msg2->val_obj)["content"] = make_non_special_string("I am fine, thank you!"); - messages->val_arr->push_back(std::move(msg2)); + jinja::value_array messages = jinja::mk_val(); + jinja::value_object msg1 = jinja::mk_val(); + msg1->insert("role", make_non_special_string("user")); + msg1->insert("content", make_non_special_string("Hello, how are you?")); + messages->push_back(std::move(msg1)); + jinja::value_object msg2 = jinja::mk_val(); + msg2->insert("role", make_non_special_string("assistant")); + msg2->insert("content", make_non_special_string("I am fine, thank you!")); + messages->push_back(std::move(msg2)); ctx.var["messages"] = std::move(messages); From 64e29a5848d4b87736ccfb989c8cfcca55b9b73f Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Sun, 28 Dec 2025 17:48:14 +0100 Subject: [PATCH 21/47] add mk_stmt --- common/jinja/jinja-parser.cpp | 66 +++++++++++++++++------------------ common/jinja/jinja-vm.h | 4 +++ 2 files changed, 37 insertions(+), 33 deletions(-) diff --git a/common/jinja/jinja-parser.cpp b/common/jinja/jinja-parser.cpp index 8b7058b8fa..c375d545ef 100644 --- a/common/jinja/jinja-parser.cpp +++ b/common/jinja/jinja-parser.cpp @@ -76,9 +76,9 @@ private: statement_ptr parse_any() { switch (peek().t) { case token::comment: - return std::make_unique(tokens[current++].value); + return mk_stmt(tokens[current++].value); case token::text: - return std::make_unique(tokens[current++].value); + return mk_stmt(tokens[current++].value); case token::open_statement: return parse_jinja_statement(); case token::open_expression: @@ -134,11 +134,11 @@ private: } else if (name == "break") { expect(token::close_statement, "Expected %}"); - result = std::make_unique(); + result = mk_stmt(); } else if (name == "continue") { expect(token::close_statement, "Expected %}"); - result = std::make_unique(); + result = mk_stmt(); } else if (name == "call") { statements caller_args; @@ -163,8 +163,8 @@ private: expect_identifier("endcall"); expect(token::close_statement, "Expected %}"); - auto call_expr = std::make_unique(std::move(callee), std::move(call_args)); - result = std::make_unique(std::move(call_expr), std::move(caller_args), std::move(body)); + auto call_expr = mk_stmt(std::move(callee), std::move(call_args)); + result = mk_stmt(std::move(call_expr), std::move(caller_args), std::move(body)); } else if (name == "filter") { auto filter_node = parse_primary_expression(); @@ -181,7 +181,7 @@ private: expect(token::open_statement, "Expected {%"); expect_identifier("endfilter"); expect(token::close_statement, "Expected %}"); - result = std::make_unique(std::move(filter_node), std::move(body)); + result = mk_stmt(std::move(filter_node), std::move(body)); } else { throw std::runtime_error("Unknown statement: " + name); @@ -208,7 +208,7 @@ private: expect_identifier("endset"); } expect(token::close_statement, "Expected %}"); - return std::make_unique(std::move(left), std::move(value), std::move(body)); + return mk_stmt(std::move(left), std::move(value), std::move(body)); } statement_ptr parse_if_statement() { @@ -237,7 +237,7 @@ private: alternate.push_back(parse_any()); } } - return std::make_unique(std::move(test), std::move(body), std::move(alternate)); + return mk_stmt(std::move(test), std::move(body), std::move(alternate)); } statement_ptr parse_macro_statement() { @@ -249,7 +249,7 @@ private: while (!is_statement({"endmacro"})) { body.push_back(parse_any()); } - return std::make_unique(std::move(name), std::move(args), std::move(body)); + return mk_stmt(std::move(name), std::move(args), std::move(body)); } statement_ptr parse_expression_sequence(bool primary = false) { @@ -261,7 +261,7 @@ private: exprs.push_back(primary ? parse_primary_expression() : parse_expression()); if (!is(token::comma)) break; } - return is_tuple ? std::make_unique(std::move(exprs)) : std::move(exprs[0]); + return is_tuple ? mk_stmt(std::move(exprs)) : std::move(exprs[0]); } statement_ptr parse_for_statement() { @@ -289,7 +289,7 @@ private: alternate.push_back(parse_any()); } } - return std::make_unique( + return mk_stmt( std::move(loop_var), std::move(iterable), std::move(body), std::move(alternate)); } @@ -309,10 +309,10 @@ private: // Ternary expression with else ++current; // consume 'else' auto false_expr = parse_if_expression(); // recurse to support chained ternaries - return std::make_unique(std::move(test), std::move(a), std::move(false_expr)); + return mk_stmt(std::move(test), std::move(a), std::move(false_expr)); } else { // Select expression on iterable - return std::make_unique(std::move(a), std::move(test)); + return mk_stmt(std::move(a), std::move(test)); } } return a; @@ -322,7 +322,7 @@ private: auto left = parse_logical_and_expression(); while (is_identifier("or")) { token op = tokens[current++]; - left = std::make_unique(op, std::move(left), parse_logical_and_expression()); + left = mk_stmt(op, std::move(left), parse_logical_and_expression()); } return left; } @@ -331,7 +331,7 @@ private: auto left = parse_logical_negation_expression(); while (is_identifier("and")) { auto op = tokens[current++]; - left = std::make_unique(op, std::move(left), parse_logical_negation_expression()); + left = mk_stmt(op, std::move(left), parse_logical_negation_expression()); } return left; } @@ -341,7 +341,7 @@ private: if (is_identifier("not")) { auto op = tokens[current]; ++current; // consume 'not' - return std::make_unique(op, parse_logical_negation_expression()); + return mk_stmt(op, parse_logical_negation_expression()); } return parse_comparison_expression(); } @@ -360,7 +360,7 @@ private: } else if (is(token::comparison_binary_operator)) { op = tokens[current++]; } else break; - left = std::make_unique(op, std::move(left), parse_additive_expression()); + left = mk_stmt(op, std::move(left), parse_additive_expression()); } return left; } @@ -369,7 +369,7 @@ private: auto left = parse_multiplicative_expression(); while (is(token::additive_binary_operator)) { auto op = tokens[current++]; - left = std::make_unique(op, std::move(left), parse_multiplicative_expression()); + left = mk_stmt(op, std::move(left), parse_multiplicative_expression()); } return left; } @@ -378,7 +378,7 @@ private: auto left = parse_test_expression(); while (is(token::multiplicative_binary_operator)) { auto op = tokens[current++]; - left = std::make_unique(op, std::move(left), parse_test_expression()); + left = mk_stmt(op, std::move(left), parse_test_expression()); } return left; } @@ -390,7 +390,7 @@ private: bool negate = false; if (is_identifier("not")) { current++; negate = true; } auto test_id = parse_primary_expression(); - operand = std::make_unique(std::move(operand), negate, std::move(test_id)); + operand = mk_stmt(std::move(operand), negate, std::move(test_id)); } return operand; } @@ -401,7 +401,7 @@ private: current++; auto filter = parse_primary_expression(); if (is(token::open_paren)) filter = parse_call_expression(std::move(filter)); - operand = std::make_unique(std::move(operand), std::move(filter)); + operand = mk_stmt(std::move(operand), std::move(filter)); } return operand; } @@ -415,7 +415,7 @@ private: } statement_ptr parse_call_expression(statement_ptr callee) { - auto expr = std::make_unique(std::move(callee), parse_args()); + auto expr = mk_stmt(std::move(callee), parse_args()); auto member = parse_member_expression(std::move(expr)); // foo.x().y return is(token::open_paren) ? parse_call_expression(std::move(member)) // foo.x()() @@ -431,14 +431,14 @@ private: // unpacking: *expr if (peek().t == token::multiplicative_binary_operator && peek().value == "*") { ++current; // consume * - arg = std::make_unique(parse_expression()); + arg = mk_stmt(parse_expression()); } else { arg = parse_expression(); if (is(token::equals)) { // keyword argument // e.g., func(x = 5, y = a or b) ++current; // consume equals - arg = std::make_unique(std::move(arg), parse_expression()); + arg = mk_stmt(std::move(arg), parse_expression()); } } args.push_back(std::move(arg)); @@ -461,7 +461,7 @@ private: } else { prop = parse_primary_expression(); } - object = std::make_unique(std::move(object), std::move(prop), computed); + object = mk_stmt(std::move(object), std::move(prop), computed); } return object; } @@ -490,7 +490,7 @@ private: statement_ptr start = slices.size() > 0 ? std::move(slices[0]) : nullptr; statement_ptr stop = slices.size() > 1 ? std::move(slices[1]) : nullptr; statement_ptr step = slices.size() > 2 ? std::move(slices[2]) : nullptr; - return std::make_unique(std::move(start), std::move(stop), std::move(step)); + return mk_stmt(std::move(start), std::move(stop), std::move(step)); } return std::move(slices[0]); } @@ -499,15 +499,15 @@ private: auto t = tokens[current++]; switch (t.t) { case token::numeric_literal: - if (t.value.find('.') != std::string::npos) return std::make_unique(std::stod(t.value)); - return std::make_unique(std::stoll(t.value)); + if (t.value.find('.') != std::string::npos) return mk_stmt(std::stod(t.value)); + return mk_stmt(std::stoll(t.value)); case token::string_literal: { std::string val = t.value; while (is(token::string_literal)) val += tokens[current++].value; - return std::make_unique(val); + return mk_stmt(val); } case token::identifier: - return std::make_unique(t.value); + return mk_stmt(t.value); case token::open_paren: { auto expr = parse_expression_sequence(); expect(token::close_paren, "Expected )"); @@ -520,7 +520,7 @@ private: if (is(token::comma)) current++; } current++; - return std::make_unique(std::move(vals)); + return mk_stmt(std::move(vals)); } case token::open_curly_bracket: { std::vector> pairs; @@ -531,7 +531,7 @@ private: if (is(token::comma)) current++; } current++; - return std::make_unique(std::move(pairs)); + return mk_stmt(std::move(pairs)); } default: throw std::runtime_error("Unexpected token: " + t.value); diff --git a/common/jinja/jinja-vm.h b/common/jinja/jinja-vm.h index 3cfc4b81df..165bfafd96 100644 --- a/common/jinja/jinja-vm.h +++ b/common/jinja/jinja-vm.h @@ -37,6 +37,10 @@ template const T * cast_stmt(const statement_ptr & ptr) { return dynamic_cast(ptr.get()); } +template +std::unique_ptr mk_stmt(Args&&... args) { + return std::make_unique(std::forward(args)...); +} // End Helpers struct context { From acb0effa251675df825e611a9c4eab24bcdcf7ad Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Sun, 28 Dec 2025 18:45:41 +0100 Subject: [PATCH 22/47] allow print source on exception --- common/jinja/jinja-lexer.cpp | 26 +++++++---- common/jinja/jinja-lexer.h | 8 +++- common/jinja/jinja-parser.cpp | 47 +++++++++++++++++-- common/jinja/jinja-value.h | 3 ++ common/jinja/jinja-vm-builtins.cpp | 19 -------- common/jinja/jinja-vm.cpp | 73 ++++++++++++++++++++++++------ common/jinja/jinja-vm.h | 54 ++++++++++++---------- tests/test-chat-jinja.cpp | 14 +++--- 8 files changed, 167 insertions(+), 77 deletions(-) diff --git a/common/jinja/jinja-lexer.cpp b/common/jinja/jinja-lexer.cpp index a5ce7af9e1..541452f3fe 100644 --- a/common/jinja/jinja-lexer.cpp +++ b/common/jinja/jinja-lexer.cpp @@ -54,12 +54,13 @@ std::string lexer::preprocess(const std::string & template_str, const preprocess return result; } -std::vector lexer::tokenize(const std::string & input, const preprocess_options & options) { +lexer_result lexer::tokenize(const std::string & input, const preprocess_options & options) { std::vector tokens; std::string src = preprocess(input, options); JJ_DEBUG("preprocessed input: '%s'", src.c_str()); size_t pos = 0; + size_t start_pos = 0; size_t curly_bracket_depth = 0; using pred = std::function; @@ -101,6 +102,7 @@ std::vector lexer::tokenize(const std::string & input, const preprocess_o }; while (pos < src.size()) { + start_pos = pos; JJ_DEBUG("lexer main loop at pos %zu: '%s...'", pos, src.substr(pos, 10).c_str()); // First, consume all text that is outside of a Jinja statement or expression @@ -122,13 +124,14 @@ std::vector lexer::tokenize(const std::string & input, const preprocess_o } JJ_DEBUG("consumed text: '%s'", text.c_str()); if (!text.empty()) { - tokens.push_back({token::text, text}); + tokens.push_back({token::text, text, start_pos}); continue; } } // Possibly consume a comment if (src[pos] == '{' && next_pos_is( {'#'} )) { + start_pos = pos; pos += 2; // Skip the opening {# std::string comment; while (!(src[pos] == '#' && next_pos_is( {'}'} ))) { @@ -138,7 +141,7 @@ std::vector lexer::tokenize(const std::string & input, const preprocess_o comment += src[pos++]; } JJ_DEBUG("consumed comment: '%s'", comment.c_str()); - tokens.push_back({token::comment, comment}); + tokens.push_back({token::comment, comment, start_pos}); pos += 2; // Skip the closing #} continue; } @@ -152,6 +155,7 @@ std::vector lexer::tokenize(const std::string & input, const preprocess_o // Check for unary operators if (ch == '-' || ch == '+') { + start_pos = pos; token::type last_token_type = tokens.empty() ? token::undefined : tokens.back().t; if (last_token_type == token::text || last_token_type == token::undefined) { throw std::runtime_error(std::string("lexer: unexpected character: ") + ch); @@ -176,7 +180,7 @@ std::vector lexer::tokenize(const std::string & input, const preprocess_o std::string value = std::string(1, ch) + num; token::type t = num.empty() ? token::unary_operator : token::numeric_literal; JJ_DEBUG("consumed unary operator or numeric literal: '%s'", value.c_str()); - tokens.push_back({t, value}); + tokens.push_back({t, value, start_pos}); continue; } } @@ -185,12 +189,13 @@ std::vector lexer::tokenize(const std::string & input, const preprocess_o // Try to match one of the tokens in the mapping table bool matched = false; for (const auto & [seq, typ] : ordered_mapping_table) { + start_pos = pos; // Inside an object literal, don't treat "}}" as expression-end if (seq == "}}" && curly_bracket_depth > 0) { continue; } if (pos + seq.size() <= src.size() && src.substr(pos, seq.size()) == seq) { - tokens.push_back({typ, seq}); + tokens.push_back({typ, seq, start_pos}); if (typ == token::open_expression) { curly_bracket_depth = 0; } else if (typ == token::open_curly_bracket) { @@ -207,36 +212,39 @@ std::vector lexer::tokenize(const std::string & input, const preprocess_o // Strings if (ch == '\'' || ch == '"') { + start_pos = pos; ++pos; // Skip opening quote std::string str = consume_while([ch](char c) { return c != ch; }); - tokens.push_back({token::string_literal, str}); + tokens.push_back({token::string_literal, str, start_pos}); ++pos; // Skip closing quote continue; } // Numbers if (is_integer(ch)) { + start_pos = pos; std::string num = consume_while(is_integer); if (pos < src.size() && src[pos] == '.' && pos + 1 < src.size() && is_integer(src[pos + 1])) { ++pos; // Consume '.' std::string frac = consume_while(is_integer); num += "." + frac; } - tokens.push_back({token::numeric_literal, num}); + tokens.push_back({token::numeric_literal, num, start_pos}); continue; } // Identifiers if (is_word(ch)) { + start_pos = pos; std::string word = consume_while(is_word); - tokens.push_back({token::identifier, word}); + tokens.push_back({token::identifier, word, start_pos}); continue; } throw std::runtime_error(std::string("lexer: unexpected character: ") + ch); } - return tokens; + return {std::move(tokens), std::move(src)}; } } // namespace jinja diff --git a/common/jinja/jinja-lexer.h b/common/jinja/jinja-lexer.h index 3ed173a4f0..f9bbe0a991 100644 --- a/common/jinja/jinja-lexer.h +++ b/common/jinja/jinja-lexer.h @@ -48,6 +48,7 @@ struct token { }; type t; std::string value; + size_t pos; }; static std::string type_to_string(token::type t) { @@ -82,6 +83,11 @@ static std::string type_to_string(token::type t) { } } +struct lexer_result { + std::vector tokens; + std::string preprocessed_source; +}; + struct lexer { const std::map escape_chars = { {'n', '\n'}, @@ -140,7 +146,7 @@ struct lexer { std::string preprocess(const std::string& template_str, const preprocess_options& options) const; - std::vector tokenize(const std::string & input, const preprocess_options & options); + lexer_result tokenize(const std::string & input, const preprocess_options & options); }; } // namespace jinja diff --git a/common/jinja/jinja-parser.cpp b/common/jinja/jinja-parser.cpp index c375d545ef..5f42b0bd89 100644 --- a/common/jinja/jinja-parser.cpp +++ b/common/jinja/jinja-parser.cpp @@ -8,6 +8,8 @@ #include #include +#define FILENAME "jinja-parser" + namespace jinja { // Helper to check type without asserting (useful for logic) @@ -19,9 +21,18 @@ static bool is_type(const statement_ptr & ptr) { class parser { const std::vector & tokens; size_t current = 0; + size_t prev_cur = 0; + + // for debugging; a token can be multiple chars in source + std::vector tok_pos_to_src_pos; public: - parser(const std::vector & t) : tokens(t) {} + parser(const std::vector & t) : tokens(t) { + tok_pos_to_src_pos.resize(tokens.size()); + for (size_t i = 0; i < tokens.size(); i++) { + tok_pos_to_src_pos[i] = tokens[i].pos; + } + } program parse() { statements body; @@ -31,10 +42,18 @@ public: return program(std::move(body)); } + template + std::unique_ptr mk_stmt(Args&&... args) { + auto ptr = std::make_unique(std::forward(args)...); + ptr->pos = tok_pos_to_src_pos[prev_cur]; + JJ_DEBUG("Created %s statement at src pos %zu", ptr->type().c_str(), ptr->pos); + return ptr; + } + private: const token & peek(size_t offset = 0) const { if (current + offset >= tokens.size()) { - static const token end_token{token::undefined, ""}; + static const token end_token{token::undefined, "", 0}; return end_token; } return tokens[current + offset]; @@ -74,6 +93,7 @@ private: } statement_ptr parse_any() { + prev_cur = current; switch (peek().t) { case token::comment: return mk_stmt(tokens[current++].value); @@ -90,6 +110,7 @@ private: statement_ptr parse_jinja_expression() { // Consume {{ }} tokens + prev_cur = current; expect(token::open_expression, "Expected {{"); auto result = parse_expression(); expect(token::close_expression, "Expected }}"); @@ -98,6 +119,7 @@ private: statement_ptr parse_jinja_statement() { // Consume {% token + prev_cur = current; expect(token::open_statement, "Expected {%"); if (peek().t != token::identifier) { @@ -194,6 +216,8 @@ private: auto left = parse_expression_sequence(); statement_ptr value = nullptr; statements body; + + prev_cur = current; if (is(token::equals)) { current++; @@ -218,6 +242,8 @@ private: statements body; statements alternate; + prev_cur = current; + // Keep parsing 'if' body until we reach the first {% elif %} or {% else %} or {% endif %} while (!is_statement({"elif", "else", "endif"})) { body.push_back(parse_any()); @@ -257,6 +283,7 @@ private: exprs.push_back(primary ? parse_primary_expression() : parse_expression()); bool is_tuple = is(token::comma); while (is(token::comma)) { + prev_cur = current; current++; // consume comma exprs.push_back(primary ? parse_primary_expression() : parse_expression()); if (!is(token::comma)) break; @@ -283,6 +310,7 @@ private: } if (is_statement({"else"})) { + prev_cur = current; current += 2; expect(token::close_statement, "Expected %}"); while (!is_statement({"endfor"})) { @@ -303,10 +331,12 @@ private: auto a = parse_logical_or_expression(); if (is_identifier("if")) { // Ternary expression + prev_cur = current; ++current; // consume 'if' auto test = parse_logical_or_expression(); if (is_identifier("else")) { // Ternary expression with else + prev_cur = current; ++current; // consume 'else' auto false_expr = parse_if_expression(); // recurse to support chained ternaries return mk_stmt(std::move(test), std::move(a), std::move(false_expr)); @@ -321,6 +351,7 @@ private: statement_ptr parse_logical_or_expression() { auto left = parse_logical_and_expression(); while (is_identifier("or")) { + prev_cur = current; token op = tokens[current++]; left = mk_stmt(op, std::move(left), parse_logical_and_expression()); } @@ -330,6 +361,7 @@ private: statement_ptr parse_logical_and_expression() { auto left = parse_logical_negation_expression(); while (is_identifier("and")) { + prev_cur = current; auto op = tokens[current++]; left = mk_stmt(op, std::move(left), parse_logical_negation_expression()); } @@ -339,6 +371,7 @@ private: statement_ptr parse_logical_negation_expression() { // Try parse unary operators if (is_identifier("not")) { + prev_cur = current; auto op = tokens[current]; ++current; // consume 'not' return mk_stmt(op, parse_logical_negation_expression()); @@ -352,8 +385,9 @@ private: auto left = parse_additive_expression(); while (true) { token op; + prev_cur = current; if (is_identifier("not") && peek(1).t == token::identifier && peek(1).value == "in") { - op = {token::identifier, "not in"}; + op = {token::identifier, "not in", tokens[current].pos}; current += 2; } else if (is_identifier("in")) { op = tokens[current++]; @@ -368,6 +402,7 @@ private: statement_ptr parse_additive_expression() { auto left = parse_multiplicative_expression(); while (is(token::additive_binary_operator)) { + prev_cur = current; auto op = tokens[current++]; left = mk_stmt(op, std::move(left), parse_multiplicative_expression()); } @@ -377,6 +412,7 @@ private: statement_ptr parse_multiplicative_expression() { auto left = parse_test_expression(); while (is(token::multiplicative_binary_operator)) { + prev_cur = current; auto op = tokens[current++]; left = mk_stmt(op, std::move(left), parse_test_expression()); } @@ -386,6 +422,7 @@ private: statement_ptr parse_test_expression() { auto operand = parse_filter_expression(); while (is_identifier("is")) { + prev_cur = current; current++; bool negate = false; if (is_identifier("not")) { current++; negate = true; } @@ -398,6 +435,7 @@ private: statement_ptr parse_filter_expression() { auto operand = parse_call_member_expression(); while (is(token::pipe)) { + prev_cur = current; current++; auto filter = parse_primary_expression(); if (is(token::open_paren)) filter = parse_call_expression(std::move(filter)); @@ -428,6 +466,7 @@ private: statements args; while (!is(token::close_paren)) { statement_ptr arg; + prev_cur = current; // unpacking: *expr if (peek().t == token::multiplicative_binary_operator && peek().value == "*") { ++current; // consume * @@ -472,6 +511,7 @@ private: statements slices; bool is_slice = false; while (!is(token::close_square_bracket)) { + prev_cur = current; if (is(token::colon)) { // A case where a default is used // e.g., [:2] will be parsed as [undefined, 2] @@ -496,6 +536,7 @@ private: } statement_ptr parse_primary_expression() { + prev_cur = current; auto t = tokens[current++]; switch (t.t) { case token::numeric_literal: diff --git a/common/jinja/jinja-value.h b/common/jinja/jinja-value.h index 6c6f4a30d6..94c638eab2 100644 --- a/common/jinja/jinja-value.h +++ b/common/jinja/jinja-value.h @@ -164,6 +164,9 @@ struct value_string_t : public value_t { } return ss.str(); } + virtual bool as_bool() const override { + return val_str.length() > 0; + } virtual const func_builtins & get_builtins() const override; void mark_input() { val_str.mark_input(); diff --git a/common/jinja/jinja-vm-builtins.cpp b/common/jinja/jinja-vm-builtins.cpp index ed601eb9b1..5802253a3e 100644 --- a/common/jinja/jinja-vm-builtins.cpp +++ b/common/jinja/jinja-vm-builtins.cpp @@ -173,25 +173,6 @@ const func_builtins & value_float_t::get_builtins() const { return builtins; } - -// static std::string string_strip(const std::string & str, bool left, bool right) { -// size_t start = 0; -// size_t end = str.length(); -// if (left) { -// while (start < end && isspace(static_cast(str[start]))) { -// ++start; -// } -// } -// if (right) { -// while (end > start && isspace(static_cast(str[end - 1]))) { -// --end; -// } -// } -// return str.substr(start, end - start); -// } - - - static bool string_startswith(const std::string & str, const std::string & prefix) { if (str.length() < prefix.length()) return false; return str.compare(0, prefix.length(), prefix) == 0; diff --git a/common/jinja/jinja-vm.cpp b/common/jinja/jinja-vm.cpp index fea7c75f06..ca213b0462 100644 --- a/common/jinja/jinja-vm.cpp +++ b/common/jinja/jinja-vm.cpp @@ -8,8 +8,9 @@ #include #include -#define JJ_DEBUG(msg, ...) printf("jinja-vm:%3d : " msg "\n", __LINE__, __VA_ARGS__) -//#define JJ_DEBUG(msg, ...) // no-op +#define FILENAME "jinja-vm" + +bool g_jinja_debug = true; namespace jinja { @@ -22,7 +23,51 @@ static value_array exec_statements(const statements & stmts, context & ctx) { return result; } -value identifier::execute(context & ctx) { +static void string_replace_all(std::string & s, const std::string & search, const std::string & replace) { + if (search.empty()) { + return; + } + std::string builder; + builder.reserve(s.length()); + size_t pos = 0; + size_t last_pos = 0; + while ((pos = s.find(search, last_pos)) != std::string::npos) { + builder.append(s, last_pos, pos - last_pos); + builder.append(replace); + last_pos = pos + search.length(); + } + builder.append(s, last_pos, std::string::npos); + s = std::move(builder); +} + +// execute with error handling +value statement::execute(context & ctx) { + try { + return execute_impl(ctx); + } catch (const std::exception & e) { + if (ctx.source.empty()) { + std::ostringstream oss; + oss << "\nError executing " << type() << " at position " << pos << ": " << e.what(); + throw raised_exception(oss.str()); + } else { + std::ostringstream oss; + constexpr int max_peak_chars = 40; + oss << "\n------------\n"; + oss << "While executing " << type() << " at position " << pos << " in source:\n"; + size_t start = (pos >= max_peak_chars) ? (pos - max_peak_chars) : 0; + size_t end = std::min(pos + max_peak_chars, ctx.source.length()); + std::string substr = ctx.source.substr(start, end - start); + string_replace_all(substr, "\n", "\\n"); + oss << "..." << substr << "...\n"; + std::string spaces(pos - start + 3, ' '); + oss << spaces << "^\n"; + oss << "Error: " << e.what(); + throw raised_exception(oss.str()); + } + } +} + +value identifier::execute_impl(context & ctx) { auto it = ctx.var.find(val); auto builtins = global_builtins(); if (it != ctx.var.end()) { @@ -37,7 +82,7 @@ value identifier::execute(context & ctx) { } } -value binary_expression::execute(context & ctx) { +value binary_expression::execute_impl(context & ctx) { value left_val = left->execute(ctx); JJ_DEBUG("Executing binary expression %s '%s' %s", left_val->type().c_str(), op.value.c_str(), right->type().c_str()); @@ -176,7 +221,7 @@ static value try_builtin_func(const std::string & name, const value & input, boo throw std::runtime_error("Unknown (built-in) filter '" + name + "' for type " + input->type()); } -value filter_expression::execute(context & ctx) { +value filter_expression::execute_impl(context & ctx) { value input = operand->execute(ctx); if (is_stmt(filter)) { @@ -203,7 +248,7 @@ value filter_expression::execute(context & ctx) { } } -value test_expression::execute(context & ctx) { +value test_expression::execute_impl(context & ctx) { // NOTE: "value is something" translates to function call "test_is_something(value)" const auto & builtins = global_builtins(); if (!is_stmt(test)) { @@ -222,7 +267,7 @@ value test_expression::execute(context & ctx) { return it->second(args); } -value unary_expression::execute(context & ctx) { +value unary_expression::execute_impl(context & ctx) { value operand_val = argument->execute(ctx); JJ_DEBUG("Executing unary expression with operator '%s'", op.value.c_str()); @@ -241,7 +286,7 @@ value unary_expression::execute(context & ctx) { throw std::runtime_error("Unknown unary operator '" + op.value + "'"); } -value if_statement::execute(context & ctx) { +value if_statement::execute_impl(context & ctx) { value test_val = test->execute(ctx); auto out = mk_val(); if (test_val->as_bool()) { @@ -258,7 +303,7 @@ value if_statement::execute(context & ctx) { return out; } -value for_statement::execute(context & ctx) { +value for_statement::execute_impl(context & ctx) { context scope(ctx); // new scope for loop variables statement_ptr iter_expr = std::move(iterable); @@ -377,7 +422,7 @@ value for_statement::execute(context & ctx) { return result; } -value set_statement::execute(context & ctx) { +value set_statement::execute_impl(context & ctx) { auto rhs = val ? val->execute(ctx) : exec_statements(body, ctx); if (is_stmt(assignee)) { @@ -427,7 +472,7 @@ value set_statement::execute(context & ctx) { return mk_val(); } -value macro_statement::execute(context & ctx) { +value macro_statement::execute_impl(context & ctx) { std::string name = cast_stmt(this->name)->val; const func_handler func = [this, &ctx, name](const func_args & args) -> value { JJ_DEBUG("Invoking macro '%s' with %zu arguments", name.c_str(), args.args.size()); @@ -454,7 +499,7 @@ value macro_statement::execute(context & ctx) { return mk_val(); } -value member_expression::execute(context & ctx) { +value member_expression::execute_impl(context & ctx) { value object = this->object->execute(ctx); value property; @@ -536,7 +581,7 @@ value member_expression::execute(context & ctx) { return val; } -value call_expression::execute(context & ctx) { +value call_expression::execute_impl(context & ctx) { // gather arguments func_args args; for (auto & arg_stmt : this->args) { @@ -587,7 +632,7 @@ bool value_compare(const value & a, const value & b) { return false; } -value keyword_argument_expression::execute(context & ctx) { +value keyword_argument_expression::execute_impl(context & ctx) { if (!is_stmt(key)) { throw std::runtime_error("Keyword argument key must be identifiers"); } diff --git a/common/jinja/jinja-vm.h b/common/jinja/jinja-vm.h index 165bfafd96..639fba9d03 100644 --- a/common/jinja/jinja-vm.h +++ b/common/jinja/jinja-vm.h @@ -9,6 +9,9 @@ #include #include +#define JJ_DEBUG(msg, ...) if (g_jinja_debug) printf("%s:%3d : " msg "\n", FILENAME, __LINE__, __VA_ARGS__) + +extern bool g_jinja_debug; namespace jinja { @@ -37,14 +40,11 @@ template const T * cast_stmt(const statement_ptr & ptr) { return dynamic_cast(ptr.get()); } -template -std::unique_ptr mk_stmt(Args&&... args) { - return std::make_unique(std::forward(args)...); -} // End Helpers struct context { std::map var; + std::string source; // for debugging context() { var["true"] = mk_val(true); @@ -65,9 +65,13 @@ struct context { * Base class for all nodes in the AST. */ struct statement { + size_t pos; // position in source, for debugging virtual ~statement() = default; virtual std::string type() const { return "Statement"; } - virtual value execute(context &) { throw std::runtime_error("cannot exec " + type()); } + // execute_impl must be overridden by derived classes + virtual value execute_impl(context &) { throw std::runtime_error("cannot exec " + type()); } + // execute is the public method to execute a statement with error handling + virtual value execute(context &); }; // Type Checking Utilities @@ -100,7 +104,7 @@ struct program : public statement { explicit program(statements && body) : body(std::move(body)) {} std::string type() const override { return "Program"; } - value execute(context &) override { + value execute_impl(context &) override { throw std::runtime_error("Cannot execute program directly, use jinja::vm instead"); } }; @@ -116,7 +120,7 @@ struct if_statement : public statement { } std::string type() const override { return "If"; } - value execute(context & ctx) override; + value execute_impl(context & ctx) override; }; struct identifier; @@ -140,7 +144,7 @@ struct for_statement : public statement { } std::string type() const override { return "For"; } - value execute(context & ctx) override; + value execute_impl(context & ctx) override; }; struct break_statement : public statement { @@ -152,7 +156,7 @@ struct break_statement : public statement { } }; - value execute(context &) override { + value execute_impl(context &) override { throw break_statement::exception(); } }; @@ -166,7 +170,7 @@ struct continue_statement : public statement { } }; - value execute(context &) override { + value execute_impl(context &) override { throw continue_statement::exception(); } }; @@ -183,7 +187,7 @@ struct set_statement : public statement { } std::string type() const override { return "Set"; } - value execute(context & ctx) override; + value execute_impl(context & ctx) override; }; struct macro_statement : public statement { @@ -198,14 +202,14 @@ struct macro_statement : public statement { } std::string type() const override { return "Macro"; } - value execute(context & ctx) override; + value execute_impl(context & ctx) override; }; struct comment_statement : public statement { std::string val; explicit comment_statement(const std::string & v) : val(v) {} std::string type() const override { return "Comment"; } - value execute(context &) override { + value execute_impl(context &) override { return mk_val(); } }; @@ -223,7 +227,7 @@ struct member_expression : public expression { chk_type(this->property); } std::string type() const override { return "MemberExpression"; } - value execute(context & ctx) override; + value execute_impl(context & ctx) override; }; struct call_expression : public expression { @@ -236,7 +240,7 @@ struct call_expression : public expression { for (const auto& arg : this->args) chk_type(arg); } std::string type() const override { return "CallExpression"; } - value execute(context & ctx) override; + value execute_impl(context & ctx) override; }; /** @@ -246,7 +250,7 @@ struct identifier : public expression { std::string val; explicit identifier(const std::string & val) : val(val) {} std::string type() const override { return "Identifier"; } - value execute(context & ctx) override; + value execute_impl(context & ctx) override; }; // Literals @@ -255,7 +259,7 @@ struct integer_literal : public expression { int64_t val; explicit integer_literal(int64_t val) : val(val) {} std::string type() const override { return "IntegerLiteral"; } - value execute(context &) override { + value execute_impl(context &) override { return std::make_unique(val); } }; @@ -264,7 +268,7 @@ struct float_literal : public expression { double val; explicit float_literal(double val) : val(val) {} std::string type() const override { return "FloatLiteral"; } - value execute(context &) override { + value execute_impl(context &) override { return std::make_unique(val); } }; @@ -273,7 +277,7 @@ struct string_literal : public expression { std::string val; explicit string_literal(const std::string & val) : val(val) {} std::string type() const override { return "StringLiteral"; } - value execute(context &) override { + value execute_impl(context &) override { return std::make_unique(val); } }; @@ -324,7 +328,7 @@ struct binary_expression : public expression { chk_type(this->right); } std::string type() const override { return "BinaryExpression"; } - value execute(context & ctx) override; + value execute_impl(context & ctx) override; }; /** @@ -341,7 +345,7 @@ struct filter_expression : public expression { chk_type(this->filter); } std::string type() const override { return "FilterExpression"; } - value execute(context & ctx) override; + value execute_impl(context & ctx) override; }; struct filter_statement : public statement { @@ -388,7 +392,7 @@ struct test_expression : public expression { chk_type(this->test); } std::string type() const override { return "TestExpression"; } - value execute(context & ctx) override; + value execute_impl(context & ctx) override; }; /** @@ -403,7 +407,7 @@ struct unary_expression : public expression { chk_type(this->argument); } std::string type() const override { return "UnaryExpression"; } - value execute(context & ctx) override; + value execute_impl(context & ctx) override; }; struct slice_expression : public expression { @@ -418,7 +422,7 @@ struct slice_expression : public expression { chk_type(this->step_expr); } std::string type() const override { return "SliceExpression"; } - value execute(context &) override { + value execute_impl(context &) override { throw std::runtime_error("must be handled by MemberExpression"); } }; @@ -433,7 +437,7 @@ struct keyword_argument_expression : public expression { chk_type(this->val); } std::string type() const override { return "KeywordArgumentExpression"; } - value execute(context & ctx) override; + value execute_impl(context & ctx) override; }; struct spread_expression : public expression { diff --git a/tests/test-chat-jinja.cpp b/tests/test-chat-jinja.cpp index eff9831ff4..36cfde7c5f 100644 --- a/tests/test-chat-jinja.cpp +++ b/tests/test-chat-jinja.cpp @@ -16,9 +16,10 @@ int main(void) { //std::string contents = "{% if messages[0]['role'] != 'system' %}nice {{ messages[0]['content'] }}{% endif %}"; - //std::string contents = " {{ messages[0]['content'] }} "; + //std::string contents = " {{ messages[a]['content'] }} "; + //std::string contents = "{{ aaa[bbb] }}"; - std::ifstream infile("models/templates/Qwen-Qwen3-0.6B.jinja"); + std::ifstream infile("models/templates/mistralai-Ministral-3-14B-Reasoning-2512.jinja"); std::string contents((std::istreambuf_iterator(infile)), std::istreambuf_iterator()); std::cout << "=== INPUT ===\n" << contents << "\n\n"; @@ -27,19 +28,20 @@ int main(void) { jinja::preprocess_options options; options.trim_blocks = true; options.lstrip_blocks = false; - auto tokens = lexer.tokenize(contents, options); - for (const auto & tok : tokens) { - std::cout << "token: type=" << static_cast(tok.t) << " text='" << tok.value << "'\n"; + auto lexer_res = lexer.tokenize(contents, options); + for (const auto & tok : lexer_res.tokens) { + std::cout << "token: type=" << static_cast(tok.t) << " text='" << tok.value << "' pos=" << tok.pos << "\n"; } std::cout << "\n=== AST ===\n"; - jinja::program ast = jinja::parse_from_tokens(tokens); + jinja::program ast = jinja::parse_from_tokens(lexer_res.tokens); for (const auto & stmt : ast.body) { std::cout << "stmt type: " << stmt->type() << "\n"; } std::cout << "\n=== RUN ===\n"; jinja::context ctx; + ctx.source = lexer_res.preprocessed_source; auto make_non_special_string = [](const std::string & s) { jinja::value_string str_val = jinja::mk_val(s); From db09a7468d849cb40c56a8916f27250c193435af Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Sun, 28 Dec 2025 19:07:01 +0100 Subject: [PATCH 23/47] fix negate test --- common/jinja/jinja-vm-builtins.cpp | 7 ++++++- common/jinja/jinja-vm.cpp | 12 +++++++++--- tests/test-chat-jinja.cpp | 5 ++--- 3 files changed, 17 insertions(+), 7 deletions(-) diff --git a/common/jinja/jinja-vm-builtins.cpp b/common/jinja/jinja-vm-builtins.cpp index 5802253a3e..cf9de3636e 100644 --- a/common/jinja/jinja-vm-builtins.cpp +++ b/common/jinja/jinja-vm-builtins.cpp @@ -9,6 +9,8 @@ #include #include +#define FILENAME "jinja-vm-builtins" + namespace jinja { /** @@ -88,6 +90,7 @@ const func_builtins & global_builtins() { throw raised_exception("namespace() arguments must be kwargs"); } auto kwarg = cast_val(arg); + JJ_DEBUG("namespace: adding key '%s'", kwarg->key.c_str()); out->insert(kwarg->key, kwarg->val); } return out; @@ -132,7 +135,9 @@ const func_builtins & global_builtins() { {"test_is_none", test_type_fn}, {"test_is_defined", [](const func_args & args) -> value { args.ensure_count(1); - return mk_val(!is_val(args.args[0])); + bool res = !args.args[0]->is_undefined(); + JJ_DEBUG("test_is_defined: result=%d", res ? 1 : 0); + return mk_val(res); }}, {"test_is_undefined", test_type_fn}, }; diff --git a/common/jinja/jinja-vm.cpp b/common/jinja/jinja-vm.cpp index ca213b0462..7aef38cfbd 100644 --- a/common/jinja/jinja-vm.cpp +++ b/common/jinja/jinja-vm.cpp @@ -257,14 +257,20 @@ value test_expression::execute_impl(context & ctx) { auto test_id = cast_stmt(test)->val; auto it = builtins.find("test_is_" + test_id); - JJ_DEBUG("Test expression %s '%s'", operand->type().c_str(), test_id.c_str()); + JJ_DEBUG("Test expression %s '%s' %s", operand->type().c_str(), test_id.c_str(), negate ? "(negate)" : ""); if (it == builtins.end()) { throw std::runtime_error("Unknown test '" + test_id + "'"); } func_args args; args.args.push_back(operand->execute(ctx)); - return it->second(args); + auto res = it->second(args); + + if (negate) { + return mk_val(!res->as_bool()); + } else { + return res; + } } value unary_expression::execute_impl(context & ctx) { @@ -538,7 +544,6 @@ value member_expression::execute_impl(context & ctx) { throw std::runtime_error("Cannot access object with non-string: got " + property->type()); } auto key = property->as_string().str(); - JJ_DEBUG("Accessing object property '%s'", key.c_str()); auto & obj = object->as_object(); auto it = obj.find(key); if (it != obj.end()) { @@ -546,6 +551,7 @@ value member_expression::execute_impl(context & ctx) { } else { val = try_builtin_func(key, object, true); } + JJ_DEBUG("Accessed property '%s' value, got type: %s", key.c_str(), val->type().c_str()); } else if (is_val(object) || is_val(object)) { if (is_val(property)) { diff --git a/tests/test-chat-jinja.cpp b/tests/test-chat-jinja.cpp index 36cfde7c5f..097c60a543 100644 --- a/tests/test-chat-jinja.cpp +++ b/tests/test-chat-jinja.cpp @@ -17,10 +17,9 @@ int main(void) { //std::string contents = "{% if messages[0]['role'] != 'system' %}nice {{ messages[0]['content'] }}{% endif %}"; //std::string contents = " {{ messages[a]['content'] }} "; - //std::string contents = "{{ aaa[bbb] }}"; + //std::string contents = "{% if a is not defined %}hello{% endif %}"; - std::ifstream infile("models/templates/mistralai-Ministral-3-14B-Reasoning-2512.jinja"); - std::string contents((std::istreambuf_iterator(infile)), std::istreambuf_iterator()); + std::ifstream infile("models/templates/mistralai-Ministral-3-14B-Reasoning-2512.jinja"); std::string contents((std::istreambuf_iterator(infile)), std::istreambuf_iterator()); std::cout << "=== INPUT ===\n" << contents << "\n\n"; From 45df0c91e7427b9def621c4995c48fbdb232c42c Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Sun, 28 Dec 2025 19:50:09 +0100 Subject: [PATCH 24/47] testing more templates --- common/jinja/jinja-vm-builtins.cpp | 5 +++++ common/jinja/jinja-vm.cpp | 20 ++++++++++++++++++-- common/jinja/jinja-vm.h | 1 + tests/test-chat-jinja.cpp | 28 +++++++++++++++++++++++++--- 4 files changed, 49 insertions(+), 5 deletions(-) diff --git a/common/jinja/jinja-vm-builtins.cpp b/common/jinja/jinja-vm-builtins.cpp index cf9de3636e..ecc2cfea52 100644 --- a/common/jinja/jinja-vm-builtins.cpp +++ b/common/jinja/jinja-vm-builtins.cpp @@ -429,6 +429,11 @@ const func_builtins & value_object_t::get_builtins() const { } return result; }}, + {{"dictsort"}, [](const func_args & args) -> value { + // no-op + args.ensure_vals(); + return args.args[0]; + }}, }; return builtins; } diff --git a/common/jinja/jinja-vm.cpp b/common/jinja/jinja-vm.cpp index 7aef38cfbd..276c79156c 100644 --- a/common/jinja/jinja-vm.cpp +++ b/common/jinja/jinja-vm.cpp @@ -82,6 +82,17 @@ value identifier::execute_impl(context & ctx) { } } +value object_literal::execute_impl(context & ctx) { + auto obj = mk_val(); + for (const auto & pair : val) { + std::string key = pair.first->execute(ctx)->as_string().str(); + value val = pair.second->execute(ctx); + JJ_DEBUG("Object literal: setting key '%s' of type %s", key.c_str(), val->type().c_str()); + obj->val_obj[key] = val; + } + return obj; +} + value binary_expression::execute_impl(context & ctx) { value left_val = left->execute(ctx); JJ_DEBUG("Executing binary expression %s '%s' %s", left_val->type().c_str(), op.value.c_str(), right->type().c_str()); @@ -208,7 +219,7 @@ value binary_expression::execute_impl(context & ctx) { throw std::runtime_error("Unknown operator \"" + op.value + "\" between " + left_val->type() + " and " + right_val->type()); } -static value try_builtin_func(const std::string & name, const value & input, bool undef_on_missing = true) { +static value try_builtin_func(const std::string & name, const value & input, bool undef_on_missing = false) { auto builtins = input->get_builtins(); auto it = builtins.find(name); if (it != builtins.end()) { @@ -331,11 +342,16 @@ value for_statement::execute_impl(context & ctx) { std::vector items; if (is_val(iterable_val)) { + JJ_DEBUG("%s", "For loop over object keys"); auto & obj = iterable_val->as_object(); for (auto & p : obj) { - items.push_back(mk_val(p.first)); + auto tuple = mk_val(); + tuple->push_back(mk_val(p.first)); + tuple->push_back(p.second); + items.push_back(tuple); } } else { + JJ_DEBUG("%s", "For loop over array items"); auto & arr = iterable_val->as_array(); for (const auto & item : arr) { items.push_back(item); diff --git a/common/jinja/jinja-vm.h b/common/jinja/jinja-vm.h index 639fba9d03..647da3a72b 100644 --- a/common/jinja/jinja-vm.h +++ b/common/jinja/jinja-vm.h @@ -308,6 +308,7 @@ struct object_literal : public expression { } } std::string type() const override { return "ObjectLiteral"; } + value execute_impl(context & ctx) override; }; // Complex Expressions diff --git a/tests/test-chat-jinja.cpp b/tests/test-chat-jinja.cpp index 097c60a543..0bf15bed91 100644 --- a/tests/test-chat-jinja.cpp +++ b/tests/test-chat-jinja.cpp @@ -4,6 +4,7 @@ #include #include #include +#include #undef NDEBUG #include @@ -11,6 +12,8 @@ #include "jinja/jinja-parser.h" #include "jinja/jinja-lexer.h" +void run(std::string contents); + int main(void) { //std::string contents = "{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '' + role + '\\n' + message['content'] | trim + '\\n' }}{% endfor %}{% if add_generation_prompt %}{{'model\\n'}}{% endif %}"; @@ -19,8 +22,29 @@ int main(void) { //std::string contents = " {{ messages[a]['content'] }} "; //std::string contents = "{% if a is not defined %}hello{% endif %}"; - std::ifstream infile("models/templates/mistralai-Ministral-3-14B-Reasoning-2512.jinja"); std::string contents((std::istreambuf_iterator(infile)), std::istreambuf_iterator()); + //std::ifstream infile("models/templates/mistralai-Ministral-3-14B-Reasoning-2512.jinja"); std::string contents((std::istreambuf_iterator(infile)), std::istreambuf_iterator()); + // list all files in models/templates/ and run each + std::string dir_path = "models/templates/"; + for (const auto & entry : std::filesystem::directory_iterator(dir_path)) { + if (entry.is_regular_file()) { + std::cout << "\n\n=== RUNNING TEMPLATE FILE: " << entry.path().string() << " ===\n"; + std::ifstream infile(entry.path()); + std::string contents((std::istreambuf_iterator(infile)), std::istreambuf_iterator()); + try { + run(contents); + } catch (const std::exception & e) { + std::cout << "Exception: " << e.what() << "\n"; + std::cout << "=== CURRENT TEMPLATE FILE: " << entry.path().string() << " ===\n"; + exit(1); + } + } + } + return 0; +} + + +void run(std::string contents) { std::cout << "=== INPUT ===\n" << contents << "\n\n"; jinja::lexer lexer; @@ -68,6 +92,4 @@ int main(void) { for (const auto & part : parts) { std::cout << (part.is_input ? "DATA" : "TMPL") << ": " << part.val << "\n"; } - - return 0; } From 9a8a45ff3bb51eeed117b7305264833758039849 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Sun, 28 Dec 2025 21:32:55 +0100 Subject: [PATCH 25/47] mostly works --- common/jinja/jinja-utils.h | 26 ++++++ common/jinja/jinja-value.h | 2 +- common/jinja/jinja-vm-builtins.cpp | 69 ++++++++++++++ common/jinja/jinja-vm.cpp | 139 +++++++++++++++++------------ common/jinja/jinja-vm.h | 31 +++++-- tests/test-chat-jinja.cpp | 2 + 6 files changed, 206 insertions(+), 63 deletions(-) create mode 100644 common/jinja/jinja-utils.h diff --git a/common/jinja/jinja-utils.h b/common/jinja/jinja-utils.h new file mode 100644 index 0000000000..a7d3bea5a8 --- /dev/null +++ b/common/jinja/jinja-utils.h @@ -0,0 +1,26 @@ +#pragma once + +#include +#include +#include + +namespace jinja { + +static void string_replace_all(std::string & s, const std::string & search, const std::string & replace) { + if (search.empty()) { + return; + } + std::string builder; + builder.reserve(s.length()); + size_t pos = 0; + size_t last_pos = 0; + while ((pos = s.find(search, last_pos)) != std::string::npos) { + builder.append(s, last_pos, pos - last_pos); + builder.append(replace); + last_pos = pos + search.length(); + } + builder.append(s, last_pos, std::string::npos); + s = std::move(builder); +} + +} // namespace jinja diff --git a/common/jinja/jinja-value.h b/common/jinja/jinja-value.h index 94c638eab2..a5eafda2dd 100644 --- a/common/jinja/jinja-value.h +++ b/common/jinja/jinja-value.h @@ -65,7 +65,7 @@ struct func_args { throw std::runtime_error("Expected " + std::to_string(count) + " arguments, got " + std::to_string(args.size())); } } - // TODO: add support for get kwargs + value get_kwarg(const std::string & key) const; // utility functions template void ensure_vals() const { ensure_count(1); diff --git a/common/jinja/jinja-vm-builtins.cpp b/common/jinja/jinja-vm-builtins.cpp index ecc2cfea52..39ae955e79 100644 --- a/common/jinja/jinja-vm-builtins.cpp +++ b/common/jinja/jinja-vm-builtins.cpp @@ -67,12 +67,14 @@ template static value test_type_fn(const func_args & args) { args.ensure_count(1); bool is_type = is_val(args.args[0]); + JJ_DEBUG("test_type_fn: type=%s result=%d", typeid(T).name(), is_type ? 1 : 0); return mk_val(is_type); } template static value test_type_fn(const func_args & args) { args.ensure_count(1); bool is_type = is_val(args.args[0]) || is_val(args.args[0]); + JJ_DEBUG("test_type_fn: type=%s or %s result=%d", typeid(T).name(), typeid(U).name(), is_type ? 1 : 0); return mk_val(is_type); } @@ -95,6 +97,20 @@ const func_builtins & global_builtins() { } return out; }}, + {"strftime_now", [](const func_args & args) -> value { + args.ensure_count(1); + args.ensure_vals(); + std::string format = args.args[0]->as_string().str(); + // get current time + // TODO: make sure this is the same behavior as Python's strftime + std::time_t t = std::time(nullptr); + char buf[100]; + if (std::strftime(buf, sizeof(buf), format.c_str(), std::localtime(&t))) { + return mk_val(std::string(buf)); + } else { + throw raised_exception("strftime_now: failed to format time"); + } + }}, // tests {"test_is_boolean", test_type_fn}, @@ -296,6 +312,25 @@ const func_builtins & value_string_t::get_builtins() const { args.ensure_vals(); return mk_val(args.args[0]->as_string()); }}, + {"default", [](const func_args & args) -> value { + value input = args.args[0]; + if (!is_val(input)) { + throw raised_exception("default() first argument must be a string"); + } + value default_val = mk_val(""); + if (args.args.size() > 1 && !args.args[1]->is_undefined()) { + default_val = args.args[1]; + } + value boolean_val = mk_val(false); + if (args.args.size() > 1) { + boolean_val = args.args[1]; + } + if (input->is_undefined() || (boolean_val->as_bool() && !input->as_bool())) { + return default_val; + } else { + return input; + } + }}, {"indent", [](const func_args &) -> value { throw std::runtime_error("indent builtin not implemented"); }}, @@ -380,6 +415,40 @@ const func_builtins & value_array_t::get_builtins() const { res->val_arr = std::move(arr); return res; }}, + {"selectattr", [](const func_args & args) -> value { + value input = args.args[0]; + if (!is_val(input)) { + throw raised_exception("selectattr() first argument must be an array, got " + input->type()); + } + std::vector selected; + for (size_t i = 1; i < args.args.size(); ++i) { + const auto & v = args.args[i]; + if (!is_val(v)) { + throw raised_exception("selectattr() attributes must be strings, got " + v->type()); + } + JJ_DEBUG("selectattr: selecting attribute '%s'", v->as_string().str().c_str()); + selected.push_back(v->as_string().str()); + } + auto result = mk_val(); + for (const auto & item : input->as_array()) { + if (!is_val(item)) { + continue; + } + const auto & obj = item->as_object(); + bool match = true; + for (const auto & attr : selected) { + auto it = obj.find(attr); + if (it == obj.end() || it->second->is_undefined() || (is_val(it->second) && !it->second->as_bool())) { + match = false; + break; + } + } + if (match) { + result->push_back(item); + } + } + return result; + }}, // TODO: reverse, sort, join, string, unique }; return builtins; diff --git a/common/jinja/jinja-vm.cpp b/common/jinja/jinja-vm.cpp index 276c79156c..844dcdef7d 100644 --- a/common/jinja/jinja-vm.cpp +++ b/common/jinja/jinja-vm.cpp @@ -2,6 +2,7 @@ #include "jinja-vm.h" #include "jinja-parser.h" #include "jinja-value.h" +#include "jinja-utils.h" #include #include @@ -14,6 +15,22 @@ bool g_jinja_debug = true; namespace jinja { +// func_args method implementations + +value func_args::get_kwarg(const std::string & key) const { + for (const auto & arg : args) { + if (is_val(arg)) { + auto * kwarg = cast_val(arg); + if (kwarg->key == key) { + return kwarg->val; + } + } + } + return mk_val(); +} + +// utils + static value_array exec_statements(const statements & stmts, context & ctx) { auto result = mk_val(); for (const auto & stmt : stmts) { @@ -23,23 +40,6 @@ static value_array exec_statements(const statements & stmts, context & ctx) { return result; } -static void string_replace_all(std::string & s, const std::string & search, const std::string & replace) { - if (search.empty()) { - return; - } - std::string builder; - builder.reserve(s.length()); - size_t pos = 0; - size_t last_pos = 0; - while ((pos = s.find(search, last_pos)) != std::string::npos) { - builder.append(s, last_pos, pos - last_pos); - builder.append(replace); - last_pos = pos + search.length(); - } - builder.append(s, last_pos, std::string::npos); - s = std::move(builder); -} - // execute with error handling value statement::execute(context & ctx) { try { @@ -138,6 +138,7 @@ value binary_expression::execute_impl(context & ctx) { return mk_val(static_cast(res)); } } else if (op.value == "/") { + JJ_DEBUG("Division operation: %f / %f", a, b); return mk_val(a / b); } else if (op.value == "%") { double rem = std::fmod(a, b); @@ -149,12 +150,16 @@ value binary_expression::execute_impl(context & ctx) { return mk_val(static_cast(rem)); } } else if (op.value == "<") { + JJ_DEBUG("Comparison operation: %f < %f is %d", a, b, a < b); return mk_val(a < b); } else if (op.value == ">") { + JJ_DEBUG("Comparison operation: %f > %f is %d", a, b, a > b); return mk_val(a > b); } else if (op.value == ">=") { + JJ_DEBUG("Comparison operation: %f >= %f is %d", a, b, a >= b); return mk_val(a >= b); } else if (op.value == "<=") { + JJ_DEBUG("Comparison operation: %f <= %f is %d", a, b, a <= b); return mk_val(a <= b); } } @@ -235,24 +240,33 @@ static value try_builtin_func(const std::string & name, const value & input, boo value filter_expression::execute_impl(context & ctx) { value input = operand->execute(ctx); - if (is_stmt(filter)) { - auto filter_val = cast_stmt(filter)->val; + JJ_DEBUG("Applying filter to %s", input->type().c_str()); - if (filter_val == "to_json") { + if (is_stmt(filter)) { + auto filter_id = cast_stmt(filter)->val; + + if (filter_id == "to_json") { // TODO: Implement to_json filter throw std::runtime_error("to_json filter not implemented"); } - if (filter_val == "trim") { - filter_val = "strip"; // alias + if (filter_id == "trim") { + filter_id = "strip"; // alias } - JJ_DEBUG("Applying filter '%s' to %s", filter_val.c_str(), input->type().c_str()); - return try_builtin_func(filter_val, input)->invoke({}); + JJ_DEBUG("Applying filter '%s' to %s", filter_id.c_str(), input->type().c_str()); + return try_builtin_func(filter_id, input)->invoke({}); } else if (is_stmt(filter)) { - // TODO - // value filter_func = filter->execute(ctx); - throw std::runtime_error("Filter with arguments not implemented"); + auto call = cast_stmt(filter); + auto filter_id = cast_stmt(call->callee)->val; + + JJ_DEBUG("Applying filter '%s' with arguments to %s", filter_id.c_str(), input->type().c_str()); + func_args args; + for (const auto & arg_expr : call->args) { + args.args.push_back(arg_expr->execute(ctx)); + } + + return try_builtin_func(filter_id, input)->invoke(args); } else { throw std::runtime_error("Invalid filter expression"); @@ -268,7 +282,7 @@ value test_expression::execute_impl(context & ctx) { auto test_id = cast_stmt(test)->val; auto it = builtins.find("test_is_" + test_id); - JJ_DEBUG("Test expression %s '%s' %s", operand->type().c_str(), test_id.c_str(), negate ? "(negate)" : ""); + JJ_DEBUG("Test expression %s '%s' %s (using function 'test_is_%s')", operand->type().c_str(), test_id.c_str(), negate ? "(negate)" : "", test_id.c_str()); if (it == builtins.end()) { throw std::runtime_error("Unknown test '" + test_id + "'"); } @@ -336,6 +350,12 @@ value for_statement::execute_impl(context & ctx) { JJ_DEBUG("Executing for statement, iterable type: %s", iter_expr->type().c_str()); value iterable_val = iter_expr->execute(scope); + + if (iterable_val->is_undefined()) { + JJ_DEBUG("%s", "For loop iterable is undefined, skipping loop"); + iterable_val = mk_val(); + } + if (!is_val(iterable_val) && !is_val(iterable_val)) { throw std::runtime_error("Expected iterable or object type in for loop: got " + iterable_val->type()); } @@ -555,7 +575,10 @@ value member_expression::execute_impl(context & ctx) { value val = mk_val(); - if (is_val(object)) { + if (is_val(object)) { + JJ_DEBUG("%s", "Accessing property on undefined object, returning undefined"); + return val; + } else if (is_val(object)) { if (!is_val(property)) { throw std::runtime_error("Cannot access object with non-string: got " + property->type()); } @@ -623,35 +646,39 @@ value call_expression::execute_impl(context & ctx) { // compare operator for value_t bool value_compare(const value & a, const value & b) { - JJ_DEBUG("Comparing types: %s and %s", a->type().c_str(), b->type().c_str()); - // compare numeric types - if ((is_val(a) || is_val(a)) && - (is_val(b) || is_val(b))){ - try { - return a->as_float() == b->as_float(); - } catch (...) {} - } - // compare string and number - // TODO: not sure if this is the right behavior - if ((is_val(b) && (is_val(a) || is_val(a))) || - (is_val(a) && (is_val(b) || is_val(b)))) { - try { + auto cmp = [&]() { + // compare numeric types + if ((is_val(a) || is_val(a)) && + (is_val(b) || is_val(b))){ + try { + return a->as_float() == b->as_float(); + } catch (...) {} + } + // compare string and number + // TODO: not sure if this is the right behavior + if ((is_val(b) && (is_val(a) || is_val(a))) || + (is_val(a) && (is_val(b) || is_val(b)))) { + try { + return a->as_string().str() == b->as_string().str(); + } catch (...) {} + } + // compare boolean simple + if (is_val(a) && is_val(b)) { + return a->as_bool() == b->as_bool(); + } + // compare string simple + if (is_val(a) && is_val(b)) { return a->as_string().str() == b->as_string().str(); - } catch (...) {} - } - // compare boolean simple - if (is_val(a) && is_val(b)) { - return a->as_bool() == b->as_bool(); - } - // compare string simple - if (is_val(a) && is_val(b)) { - return a->as_string().str() == b->as_string().str(); - } - // compare by type - if (a->type() != b->type()) { + } + // compare by type + if (a->type() != b->type()) { + return false; + } return false; - } - return false; + }; + auto result = cmp(); + JJ_DEBUG("Comparing types: %s and %s result=%d", a->type().c_str(), b->type().c_str(), result); + return result; } value keyword_argument_expression::execute_impl(context & ctx) { diff --git a/common/jinja/jinja-vm.h b/common/jinja/jinja-vm.h index 647da3a72b..5172969a9d 100644 --- a/common/jinja/jinja-vm.h +++ b/common/jinja/jinja-vm.h @@ -71,7 +71,7 @@ struct statement { // execute_impl must be overridden by derived classes virtual value execute_impl(context &) { throw std::runtime_error("cannot exec " + type()); } // execute is the public method to execute a statement with error handling - virtual value execute(context &); + value execute(context &); }; // Type Checking Utilities @@ -288,13 +288,17 @@ struct array_literal : public expression { for (const auto& item : this->val) chk_type(item); } std::string type() const override { return "ArrayLiteral"; } + value execute_impl(context & ctx) override { + auto arr = mk_val(); + for (const auto & item_stmt : val) { + arr->push_back(item_stmt->execute(ctx)); + } + return arr; + } }; -struct tuple_literal : public expression { - statements val; - explicit tuple_literal(statements && val) : val(std::move(val)) { - for (const auto & item : this->val) chk_type(item); - } +struct tuple_literal : public array_literal { + explicit tuple_literal(statements && val) : array_literal(std::move(val)) {} std::string type() const override { return "TupleLiteral"; } }; @@ -376,6 +380,13 @@ struct select_expression : public expression { chk_type(this->test); } std::string type() const override { return "SelectExpression"; } + value execute_impl(context & ctx) override { + auto predicate = test->execute_impl(ctx); + if (!predicate->as_bool()) { + return mk_val(); + } + return lhs->execute_impl(ctx); + } }; /** @@ -474,6 +485,14 @@ struct ternary_expression : public expression { chk_type(this->false_expr); } std::string type() const override { return "Ternary"; } + value execute_impl(context & ctx) override { + value cond_val = condition->execute(ctx); + if (cond_val->as_bool()) { + return true_expr->execute(ctx); + } else { + return false_expr->execute(ctx); + } + } }; struct raised_exception : public std::exception { diff --git a/tests/test-chat-jinja.cpp b/tests/test-chat-jinja.cpp index 0bf15bed91..64777a3495 100644 --- a/tests/test-chat-jinja.cpp +++ b/tests/test-chat-jinja.cpp @@ -83,6 +83,8 @@ void run(std::string contents) { messages->push_back(std::move(msg2)); ctx.var["messages"] = std::move(messages); + ctx.var["eos_token"] = jinja::mk_val(""); + // ctx.var["tools"] = jinja::mk_val(); jinja::vm vm(ctx); const jinja::value results = vm.execute(ast); From adad34f64d2e4b6493df57d2a2a01eeb3ebbb911 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Sun, 28 Dec 2025 22:02:22 +0100 Subject: [PATCH 26/47] add filter_statement --- common/jinja/jinja-vm.cpp | 17 +++++++++++- common/jinja/jinja-vm.h | 58 ++++++++++++++++++++++++++------------- tests/test-chat-jinja.cpp | 8 +++--- 3 files changed, 59 insertions(+), 24 deletions(-) diff --git a/common/jinja/jinja-vm.cpp b/common/jinja/jinja-vm.cpp index 844dcdef7d..8ec8e742f0 100644 --- a/common/jinja/jinja-vm.cpp +++ b/common/jinja/jinja-vm.cpp @@ -11,10 +11,14 @@ #define FILENAME "jinja-vm" -bool g_jinja_debug = true; +bool g_jinja_debug = false; namespace jinja { +void enable_debug(bool enable) { + g_jinja_debug = enable; +} + // func_args method implementations value func_args::get_kwarg(const std::string & key) const { @@ -273,6 +277,17 @@ value filter_expression::execute_impl(context & ctx) { } } +value filter_statement::execute_impl(context & ctx) { + // eval body as string, then apply filter + auto body_val = exec_statements(body, ctx); + value_string parts = mk_val(); + gather_string_parts_recursive(body_val, parts); + + JJ_DEBUG("FilterStatement: applying filter to body string of length %zu", parts->val_str.length()); + filter_expression filter_expr(std::move(parts), std::move(filter)); + return filter_expr.execute(ctx); +} + value test_expression::execute_impl(context & ctx) { // NOTE: "value is something" translates to function call "test_is_something(value)" const auto & builtins = global_builtins(); diff --git a/common/jinja/jinja-vm.h b/common/jinja/jinja-vm.h index 5172969a9d..d67bc2d5c1 100644 --- a/common/jinja/jinja-vm.h +++ b/common/jinja/jinja-vm.h @@ -42,6 +42,10 @@ const T * cast_stmt(const statement_ptr & ptr) { } // End Helpers + +// not thread-safe +void enable_debug(bool enable); + struct context { std::map var; std::string source; // for debugging @@ -260,7 +264,7 @@ struct integer_literal : public expression { explicit integer_literal(int64_t val) : val(val) {} std::string type() const override { return "IntegerLiteral"; } value execute_impl(context &) override { - return std::make_unique(val); + return mk_val(val); } }; @@ -269,7 +273,7 @@ struct float_literal : public expression { explicit float_literal(double val) : val(val) {} std::string type() const override { return "FloatLiteral"; } value execute_impl(context &) override { - return std::make_unique(val); + return mk_val(val); } }; @@ -278,7 +282,7 @@ struct string_literal : public expression { explicit string_literal(const std::string & val) : val(val) {} std::string type() const override { return "StringLiteral"; } value execute_impl(context &) override { - return std::make_unique(val); + return mk_val(val); } }; @@ -341,7 +345,10 @@ struct binary_expression : public expression { * Operator precedence: https://github.com/pallets/jinja/issues/379#issuecomment-168076202 */ struct filter_expression : public expression { + // either an expression or a value is allowed statement_ptr operand; + value_string val; // will be set by filter_statement + statement_ptr filter; filter_expression(statement_ptr && operand, statement_ptr && filter) @@ -349,6 +356,12 @@ struct filter_expression : public expression { chk_type(this->operand); chk_type(this->filter); } + + filter_expression(value_string && val, statement_ptr && filter) + : val(std::move(val)), filter(std::move(filter)) { + chk_type(this->filter); + } + std::string type() const override { return "FilterExpression"; } value execute_impl(context & ctx) override; }; @@ -362,6 +375,7 @@ struct filter_statement : public statement { chk_type(this->filter); } std::string type() const override { return "FilterStatement"; } + value execute_impl(context & ctx) override; }; /** @@ -505,6 +519,26 @@ struct raised_exception : public std::exception { ////////////////////// +static void gather_string_parts_recursive(const value & val, value_string & parts) { + if (is_val(val)) { + const auto & str_val = cast_val(val)->val_str; + parts->val_str.append(str_val); + } else if (is_val(val)) { + auto items = cast_val(val)->as_array(); + for (const auto & item : items) { + gather_string_parts_recursive(item, parts); + } + } +} + +static std::string render_string_parts(const value_string & parts) { + std::ostringstream oss; + for (const auto & part : parts->val_str.parts) { + oss << part.val; + } + return oss.str(); +} + struct vm { context & ctx; explicit vm(context & ctx) : ctx(ctx) {} @@ -518,25 +552,11 @@ struct vm { return results; } - std::vector gather_string_parts(const value & val) { - std::vector parts; + value_string gather_string_parts(const value & val) { + value_string parts = mk_val(); gather_string_parts_recursive(val, parts); return parts; } - - void gather_string_parts_recursive(const value & val, std::vector & parts) { - if (is_val(val)) { - const auto & str_val = cast_val(val)->val_str; - for (const auto & part : str_val.parts) { - parts.push_back(part); - } - } else if (is_val(val)) { - auto items = cast_val(val)->as_array(); - for (const auto & item : items) { - gather_string_parts_recursive(item, parts); - } - } - } }; } // namespace jinja diff --git a/tests/test-chat-jinja.cpp b/tests/test-chat-jinja.cpp index 64777a3495..1f9dedb1e4 100644 --- a/tests/test-chat-jinja.cpp +++ b/tests/test-chat-jinja.cpp @@ -45,7 +45,7 @@ int main(void) { void run(std::string contents) { - std::cout << "=== INPUT ===\n" << contents << "\n\n"; + // jinja::enable_debug(true); jinja::lexer lexer; jinja::preprocess_options options; @@ -53,13 +53,13 @@ void run(std::string contents) { options.lstrip_blocks = false; auto lexer_res = lexer.tokenize(contents, options); for (const auto & tok : lexer_res.tokens) { - std::cout << "token: type=" << static_cast(tok.t) << " text='" << tok.value << "' pos=" << tok.pos << "\n"; + //std::cout << "token: type=" << static_cast(tok.t) << " text='" << tok.value << "' pos=" << tok.pos << "\n"; } std::cout << "\n=== AST ===\n"; jinja::program ast = jinja::parse_from_tokens(lexer_res.tokens); for (const auto & stmt : ast.body) { - std::cout << "stmt type: " << stmt->type() << "\n"; + //std::cout << "stmt type: " << stmt->type() << "\n"; } std::cout << "\n=== RUN ===\n"; @@ -91,7 +91,7 @@ void run(std::string contents) { auto parts = vm.gather_string_parts(results); std::cout << "\n=== RESULTS ===\n"; - for (const auto & part : parts) { + for (const auto & part : parts.get()->val_str.parts) { std::cout << (part.is_input ? "DATA" : "TMPL") << ": " << part.val << "\n"; } } From c7f246e7a5c2934fc1a0d25497a1638c7bcd0f9a Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Sun, 28 Dec 2025 22:15:10 +0100 Subject: [PATCH 27/47] allow func to access ctx --- common/jinja/jinja-value.h | 5 ++++- common/jinja/jinja-vm-builtins.cpp | 3 +-- common/jinja/jinja-vm.cpp | 10 +++++----- common/jinja/jinja-vm.h | 3 +++ 4 files changed, 13 insertions(+), 8 deletions(-) diff --git a/common/jinja/jinja-value.h b/common/jinja/jinja-value.h index a5eafda2dd..b5ce893162 100644 --- a/common/jinja/jinja-value.h +++ b/common/jinja/jinja-value.h @@ -57,9 +57,12 @@ void ensure_val(const value & ptr) { } // End Helper +struct context; // forward declaration struct func_args { std::vector args; + context & ctx; + func_args(context & ctx) : ctx(ctx) {} void ensure_count(size_t count) const { if (args.size() != count) { throw std::runtime_error("Expected " + std::to_string(count) + " arguments, got " + std::to_string(args.size())); @@ -253,7 +256,7 @@ struct value_func_t : public value_t { } virtual value invoke(const func_args & args) const override { if (arg0) { - func_args new_args; + func_args new_args(args.ctx); new_args.args.push_back(arg0); for (const auto & a : args.args) { new_args.args.push_back(a); diff --git a/common/jinja/jinja-vm-builtins.cpp b/common/jinja/jinja-vm-builtins.cpp index 39ae955e79..258d0da487 100644 --- a/common/jinja/jinja-vm-builtins.cpp +++ b/common/jinja/jinja-vm-builtins.cpp @@ -103,9 +103,8 @@ const func_builtins & global_builtins() { std::string format = args.args[0]->as_string().str(); // get current time // TODO: make sure this is the same behavior as Python's strftime - std::time_t t = std::time(nullptr); char buf[100]; - if (std::strftime(buf, sizeof(buf), format.c_str(), std::localtime(&t))) { + if (std::strftime(buf, sizeof(buf), format.c_str(), std::localtime(&args.ctx.current_time))) { return mk_val(std::string(buf)); } else { throw raised_exception("strftime_now: failed to format time"); diff --git a/common/jinja/jinja-vm.cpp b/common/jinja/jinja-vm.cpp index 8ec8e742f0..f1f252108f 100644 --- a/common/jinja/jinja-vm.cpp +++ b/common/jinja/jinja-vm.cpp @@ -258,14 +258,14 @@ value filter_expression::execute_impl(context & ctx) { filter_id = "strip"; // alias } JJ_DEBUG("Applying filter '%s' to %s", filter_id.c_str(), input->type().c_str()); - return try_builtin_func(filter_id, input)->invoke({}); + return try_builtin_func(filter_id, input)->invoke(func_args(ctx)); } else if (is_stmt(filter)) { auto call = cast_stmt(filter); auto filter_id = cast_stmt(call->callee)->val; JJ_DEBUG("Applying filter '%s' with arguments to %s", filter_id.c_str(), input->type().c_str()); - func_args args; + func_args args(ctx); for (const auto & arg_expr : call->args) { args.args.push_back(arg_expr->execute(ctx)); } @@ -302,7 +302,7 @@ value test_expression::execute_impl(context & ctx) { throw std::runtime_error("Unknown test '" + test_id + "'"); } - func_args args; + func_args args(ctx); args.args.push_back(operand->execute(ctx)); auto res = it->second(args); @@ -574,7 +574,7 @@ value member_expression::execute_impl(context & ctx) { stop_val->as_repr().c_str(), step_val->as_repr().c_str()); auto slice_func = try_builtin_func("slice", object); - func_args args; + func_args args(ctx); args.args.push_back(start_val); args.args.push_back(stop_val); args.args.push_back(step_val); @@ -643,7 +643,7 @@ value member_expression::execute_impl(context & ctx) { value call_expression::execute_impl(context & ctx) { // gather arguments - func_args args; + func_args args(ctx); for (auto & arg_stmt : this->args) { auto arg_val = arg_stmt->execute(ctx); JJ_DEBUG(" Argument type: %s", arg_val->type().c_str()); diff --git a/common/jinja/jinja-vm.h b/common/jinja/jinja-vm.h index d67bc2d5c1..596f325194 100644 --- a/common/jinja/jinja-vm.h +++ b/common/jinja/jinja-vm.h @@ -50,10 +50,13 @@ struct context { std::map var; std::string source; // for debugging + std::time_t current_time; // for functions that need current time + context() { var["true"] = mk_val(true); var["false"] = mk_val(false); var["none"] = mk_val(); + current_time = std::time(nullptr); } ~context() = default; From 55fe96a9dfe6aadcea42577e7318997295f3b2f4 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Sun, 28 Dec 2025 22:49:31 +0100 Subject: [PATCH 28/47] add jinja-value.cpp --- common/CMakeLists.txt | 4 +++- .../{jinja-vm-builtins.cpp => jinja-value.cpp} | 14 ++++++++++++++ common/jinja/jinja-vm.cpp | 16 ---------------- 3 files changed, 17 insertions(+), 17 deletions(-) rename common/jinja/{jinja-vm-builtins.cpp => jinja-value.cpp} (98%) diff --git a/common/CMakeLists.txt b/common/CMakeLists.txt index 4ed0df100f..b270bebbcc 100644 --- a/common/CMakeLists.txt +++ b/common/CMakeLists.txt @@ -89,7 +89,9 @@ add_library(${TARGET} STATIC jinja/jinja-parser.h jinja/jinja-vm.cpp jinja/jinja-vm.h - jinja/jinja-vm-builtins.cpp + jinja/jinja-value.cpp + jinja/jinja-value.h + jinja/jinja-string.h ) target_include_directories(${TARGET} PUBLIC . ../vendor) diff --git a/common/jinja/jinja-vm-builtins.cpp b/common/jinja/jinja-value.cpp similarity index 98% rename from common/jinja/jinja-vm-builtins.cpp rename to common/jinja/jinja-value.cpp index 258d0da487..cdf39a8f66 100644 --- a/common/jinja/jinja-vm-builtins.cpp +++ b/common/jinja/jinja-value.cpp @@ -13,6 +13,20 @@ namespace jinja { +// func_args method implementations + +value func_args::get_kwarg(const std::string & key) const { + for (const auto & arg : args) { + if (is_val(arg)) { + auto * kwarg = cast_val(arg); + if (kwarg->key == key) { + return kwarg->val; + } + } + } + return mk_val(); +} + /** * Function that mimics Python's array slicing. */ diff --git a/common/jinja/jinja-vm.cpp b/common/jinja/jinja-vm.cpp index f1f252108f..edb9363123 100644 --- a/common/jinja/jinja-vm.cpp +++ b/common/jinja/jinja-vm.cpp @@ -19,22 +19,6 @@ void enable_debug(bool enable) { g_jinja_debug = enable; } -// func_args method implementations - -value func_args::get_kwarg(const std::string & key) const { - for (const auto & arg : args) { - if (is_val(arg)) { - auto * kwarg = cast_val(arg); - if (kwarg->key == key) { - return kwarg->val; - } - } - } - return mk_val(); -} - -// utils - static value_array exec_statements(const statements & stmts, context & ctx) { auto result = mk_val(); for (const auto & stmt : stmts) { From 1784a57e7bec130c51a4175ba94adbb2ce136eb6 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Sun, 28 Dec 2025 23:15:48 +0100 Subject: [PATCH 29/47] impl global_from_json --- common/jinja/jinja-value.cpp | 49 +++++++++++++++++++++++++++++++++++ common/jinja/jinja-value.h | 33 ++++++++++++++++++++++++ common/jinja/jinja-vm.cpp | 2 +- tests/test-chat-jinja.cpp | 50 +++++++++++++++++++++--------------- 4 files changed, 112 insertions(+), 22 deletions(-) diff --git a/common/jinja/jinja-value.cpp b/common/jinja/jinja-value.cpp index cdf39a8f66..9461901c6d 100644 --- a/common/jinja/jinja-value.cpp +++ b/common/jinja/jinja-value.cpp @@ -3,6 +3,9 @@ #include "jinja-parser.h" #include "jinja-value.h" +// for converting from JSON to jinja values +#include + #include #include #include @@ -520,4 +523,50 @@ const func_builtins & value_object_t::get_builtins() const { return builtins; } +static value from_json(const nlohmann::json & j) { + if (j.is_null()) { + return mk_val(); + } else if (j.is_boolean()) { + return mk_val(j.get()); + } else if (j.is_number_integer()) { + return mk_val(j.get()); + } else if (j.is_number_float()) { + return mk_val(j.get()); + } else if (j.is_string()) { + return mk_val(j.get()); + } else if (j.is_array()) { + auto arr = mk_val(); + for (const auto & item : j) { + arr->push_back(from_json(item)); + } + return arr; + } else if (j.is_object()) { + if (j.contains("__input__")) { + // handle input marking + auto str = mk_val(j.at("__input__").get()); + str->mark_input(); + return str; + } else { + // normal object + auto obj = mk_val(); + for (auto it = j.begin(); it != j.end(); ++it) { + obj->insert(it.key(), from_json(it.value())); + } + return obj; + } + } else { + throw std::runtime_error("Unsupported JSON value type"); + } +} + +template<> +void global_from_json(context & ctx, const nlohmann::json & json_obj) { + if (json_obj.is_null() || !json_obj.is_object()) { + throw std::runtime_error("global_from_json: input JSON value must be an object"); + } + for (auto it = json_obj.begin(); it != json_obj.end(); ++it) { + ctx.var[it.key()] = from_json(it.value()); + } +} + } // namespace jinja diff --git a/common/jinja/jinja-value.h b/common/jinja/jinja-value.h index b5ce893162..04c6c6da28 100644 --- a/common/jinja/jinja-value.h +++ b/common/jinja/jinja-value.h @@ -57,8 +57,41 @@ void ensure_val(const value & ptr) { } // End Helper + struct context; // forward declaration + +// for converting from JSON to jinja values +// example input JSON: +// { +// "messages": [ +// {"role": "user", "content": "Hello!"}, +// {"role": "assistant", "content": "Hi there!"} +// ], +// "bos_token": "", +// "eos_token": "", +// } +// +// to mark strings as user input, wrap them in a special object: +// { +// "messages": [ +// { +// "role": "user", +// "content": {"__input__": "Hello!"} // this string is user input +// }, +// ... +// ], +// } +// +// marking input can be useful for tracking data provenance +// and preventing template injection attacks +// +// Note: T_JSON can be nlohmann::json or similar types +template +void global_from_json(context & ctx, const T_JSON & json_obj); + + + struct func_args { std::vector args; context & ctx; diff --git a/common/jinja/jinja-vm.cpp b/common/jinja/jinja-vm.cpp index edb9363123..4c38ebde7d 100644 --- a/common/jinja/jinja-vm.cpp +++ b/common/jinja/jinja-vm.cpp @@ -226,7 +226,7 @@ static value try_builtin_func(const std::string & name, const value & input, boo } value filter_expression::execute_impl(context & ctx) { - value input = operand->execute(ctx); + value input = operand ? operand->execute(ctx) : val; JJ_DEBUG("Applying filter to %s", input->type().c_str()); diff --git a/tests/test-chat-jinja.cpp b/tests/test-chat-jinja.cpp index 1f9dedb1e4..997d463061 100644 --- a/tests/test-chat-jinja.cpp +++ b/tests/test-chat-jinja.cpp @@ -6,6 +6,8 @@ #include #include +#include + #undef NDEBUG #include @@ -24,10 +26,14 @@ int main(void) { //std::ifstream infile("models/templates/mistralai-Ministral-3-14B-Reasoning-2512.jinja"); std::string contents((std::istreambuf_iterator(infile)), std::istreambuf_iterator()); + std::vector failed_tests; + // list all files in models/templates/ and run each + size_t test_count = 0; std::string dir_path = "models/templates/"; for (const auto & entry : std::filesystem::directory_iterator(dir_path)) { if (entry.is_regular_file()) { + test_count++; std::cout << "\n\n=== RUNNING TEMPLATE FILE: " << entry.path().string() << " ===\n"; std::ifstream infile(entry.path()); std::string contents((std::istreambuf_iterator(infile)), std::istreambuf_iterator()); @@ -35,11 +41,18 @@ int main(void) { run(contents); } catch (const std::exception & e) { std::cout << "Exception: " << e.what() << "\n"; - std::cout << "=== CURRENT TEMPLATE FILE: " << entry.path().string() << " ===\n"; - exit(1); + std::cout << "=== ERROR WITH TEMPLATE FILE: " << entry.path().string() << " ===\n"; + failed_tests.push_back(entry.path().string()); } } } + + std::cout << "\n\n=== TEST SUMMARY ===\n"; + std::cout << "Total tests run: " << test_count << "\n"; + std::cout << "Total failed tests: " << failed_tests.size() << "\n"; + for (const auto & test : failed_tests) { + std::cout << "FAILED TEST: " << test << "\n"; + } return 0; } @@ -66,25 +79,20 @@ void run(std::string contents) { jinja::context ctx; ctx.source = lexer_res.preprocessed_source; - auto make_non_special_string = [](const std::string & s) { - jinja::value_string str_val = jinja::mk_val(s); - str_val->mark_input(); - return str_val; - }; - - jinja::value_array messages = jinja::mk_val(); - jinja::value_object msg1 = jinja::mk_val(); - msg1->insert("role", make_non_special_string("user")); - msg1->insert("content", make_non_special_string("Hello, how are you?")); - messages->push_back(std::move(msg1)); - jinja::value_object msg2 = jinja::mk_val(); - msg2->insert("role", make_non_special_string("assistant")); - msg2->insert("content", make_non_special_string("I am fine, thank you!")); - messages->push_back(std::move(msg2)); - - ctx.var["messages"] = std::move(messages); - ctx.var["eos_token"] = jinja::mk_val(""); - // ctx.var["tools"] = jinja::mk_val(); + std::string json_inp = R"({ + "messages": [ + { + "role": "user", + "content": {"__input__": "Hello, how are you?"} + }, + { + "role": "assistant", + "content": {"__input__": "I am fine, thank you!"} + } + ], + "eos_token": "" + })"; + jinja::global_from_json(ctx, nlohmann::json::parse(json_inp)); jinja::vm vm(ctx); const jinja::value results = vm.execute(ast); From 2a31c9a30cf984f39a3ba71e66f3efee1bc59aa7 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Mon, 29 Dec 2025 00:38:29 +0100 Subject: [PATCH 30/47] a lot of fixes --- common/jinja/jinja-value.cpp | 134 ++++++++++++++++++++++++++++++++++- common/jinja/jinja-value.h | 12 ++-- common/jinja/jinja-vm.cpp | 80 +++++++++++++++------ tests/test-chat-jinja.cpp | 7 +- 4 files changed, 202 insertions(+), 31 deletions(-) diff --git a/common/jinja/jinja-value.cpp b/common/jinja/jinja-value.cpp index 9461901c6d..f382a64a86 100644 --- a/common/jinja/jinja-value.cpp +++ b/common/jinja/jinja-value.cpp @@ -127,6 +127,44 @@ const func_builtins & global_builtins() { throw raised_exception("strftime_now: failed to format time"); } }}, + {"range", [](const func_args & args) -> value { + if (args.args.size() < 1 || args.args.size() > 3) { + throw raised_exception("slice() takes between 1 and 3 arguments"); + } + int64_t arg0 = is_val(args.args[0]) ? args.args[0]->as_int() : 0; + int64_t arg1 = is_val(args.args[1]) ? args.args[1]->as_int() : -1; + int64_t arg2 = is_val(args.args[2]) ? args.args[2]->as_int() : 1; + + int64_t start, stop, step; + if (args.args.size() == 1) { + start = 0; + stop = arg0; + step = 1; + } else if (args.args.size() == 2) { + start = arg0; + stop = arg1; + step = 1; + } else { + start = arg0; + stop = arg1; + step = arg2; + } + + auto out = mk_val(); + if (step == 0) { + throw raised_exception("range() step argument must not be zero"); + } + if (step > 0) { + for (int64_t i = start; i < stop; i += step) { + out->push_back(mk_val(i)); + } + } else { + for (int64_t i = start; i > stop; i += step) { + out->push_back(mk_val(i)); + } + } + return out; + }}, // tests {"test_is_boolean", test_type_fn}, @@ -416,7 +454,9 @@ const func_builtins & value_array_t::get_builtins() const { return mk_val(static_cast(arr.size())); }}, {"slice", [](const func_args & args) -> value { - args.ensure_count(4); + if (args.args.size() < 1 || args.args.size() > 4) { + throw raised_exception("slice() takes between 1 and 4 arguments"); + } int64_t start = is_val(args.args[1]) ? args.args[1]->as_int() : 0; int64_t stop = is_val(args.args[2]) ? args.args[2]->as_int() : -1; int64_t step = is_val(args.args[3]) ? args.args[3]->as_int() : 1; @@ -465,7 +505,77 @@ const func_builtins & value_array_t::get_builtins() const { } return result; }}, - // TODO: reverse, sort, join, string, unique + {"rejectattr", [](const func_args & args) -> value { + value input = args.args[0]; + if (!is_val(input)) { + throw raised_exception("rejectattr() first argument must be an array, got " + input->type()); + } + std::vector rejected; + for (size_t i = 1; i < args.args.size(); ++i) { + const auto & v = args.args[i]; + if (!is_val(v)) { + throw raised_exception("rejectattr() attributes must be strings, got " + v->type()); + } + JJ_DEBUG("rejectattr: rejecting attribute '%s'", v->as_string().str().c_str()); + rejected.push_back(v->as_string().str()); + } + auto result = mk_val(); + for (const auto & item : input->as_array()) { + if (!is_val(item)) { + result->push_back(item); + continue; + } + const auto & obj = item->as_object(); + bool match = false; + for (const auto & attr : rejected) { + auto it = obj.find(attr); + if (it != obj.end() && !it->second->is_undefined() && (!is_val(it->second) || it->second->as_bool())) { + match = true; + break; + } + } + if (!match) { + result->push_back(item); + } + } + return result; + }}, + {"join", [](const func_args & args) -> value { + if (args.args.size() < 1 || args.args.size() > 2) { + throw raised_exception("join() takes one or two arguments"); + } + if (!is_val(args.args[0])) { + throw raised_exception("join() first argument must be an array"); + } + const auto & arr = args.args[0]->as_array(); + std::string delim = (args.args.size() > 1 && is_val(args.args[1])) ? args.args[1]->as_string().str() : ""; + std::string result; + for (size_t i = 0; i < arr.size(); ++i) { + if (!is_val(arr[i])) { + throw raised_exception("join() can only join arrays of strings"); + } + result += arr[i]->as_string().str(); + if (i < arr.size() - 1) { + result += delim; + } + } + return mk_val(result); + }}, + {"string", [](const func_args & args) -> value { + args.ensure_vals(); + auto str = mk_val(); + gather_string_parts_recursive(args.args[0], str); + return str; + }}, + {"sort", [](const func_args &) -> value { + throw std::runtime_error("Array sort builtin not implemented"); + }}, + {"reverse", [](const func_args &) -> value { + throw std::runtime_error("Array reverse builtin not implemented"); + }}, + {"unique", [](const func_args &) -> value { + throw std::runtime_error("Array unique builtin not implemented"); + }}, }; return builtins; } @@ -523,6 +633,26 @@ const func_builtins & value_object_t::get_builtins() const { return builtins; } +const func_builtins & value_null_t::get_builtins() const { + static const func_builtins builtins = { + {"list", [](const func_args &) -> value { + // fix for meetkai-functionary-medium-v3.1.jinja + // TODO: hide under a flag? + return mk_val(); + }}, + {"selectattr", [](const func_args &) -> value { + // fix for meetkai-functionary-medium-v3.1.jinja + // TODO: hide under a flag? + return mk_val(); + }}, + }; + return builtins; +} + + +////////////////////////////////// + + static value from_json(const nlohmann::json & j) { if (j.is_null()) { return mk_val(); diff --git a/common/jinja/jinja-value.h b/common/jinja/jinja-value.h index 04c6c6da28..3289a0de59 100644 --- a/common/jinja/jinja-value.h +++ b/common/jinja/jinja-value.h @@ -96,13 +96,14 @@ struct func_args { std::vector args; context & ctx; func_args(context & ctx) : ctx(ctx) {} - void ensure_count(size_t count) const { - if (args.size() != count) { - throw std::runtime_error("Expected " + std::to_string(count) + " arguments, got " + std::to_string(args.size())); + void ensure_count(size_t min, size_t max = 999) const { + if (args.size() < min || args.size() > max) { + throw std::runtime_error("Expected between " + std::to_string(min) + " and " + std::to_string(max) + " arguments, got " + std::to_string(args.size())); } } value get_kwarg(const std::string & key) const; // utility functions + // TODO: allow optional arguments template void ensure_vals() const { ensure_count(1); ensure_val(args[0]); @@ -310,12 +311,15 @@ struct value_null_t : public value_t { virtual bool is_null() const override { return true; } virtual bool as_bool() const override { return false; } virtual std::string as_repr() const override { return type(); } + virtual const func_builtins & get_builtins() const override; }; using value_null = std::shared_ptr; struct value_undefined_t : public value_t { - virtual std::string type() const override { return "Undefined"; } + std::string hint; // for debugging, to indicate where undefined came from + value_undefined_t(const std::string & h = "") : hint(h) {} + virtual std::string type() const override { return hint.empty() ? "Undefined" : "Undefined (hint: '" + hint + "')"; } virtual bool is_undefined() const override { return true; } virtual bool as_bool() const override { return false; } virtual std::string as_repr() const override { return type(); } diff --git a/common/jinja/jinja-vm.cpp b/common/jinja/jinja-vm.cpp index 4c38ebde7d..0211ef9013 100644 --- a/common/jinja/jinja-vm.cpp +++ b/common/jinja/jinja-vm.cpp @@ -19,13 +19,16 @@ void enable_debug(bool enable) { g_jinja_debug = enable; } -static value_array exec_statements(const statements & stmts, context & ctx) { +static value_string exec_statements(const statements & stmts, context & ctx) { auto result = mk_val(); for (const auto & stmt : stmts) { JJ_DEBUG("Executing statement of type %s", stmt->type().c_str()); result->push_back(stmt->execute(ctx)); } - return result; + // convert to string parts + value_string str = mk_val(); + gather_string_parts_recursive(result, str); + return str; } // execute with error handling @@ -66,7 +69,7 @@ value identifier::execute_impl(context & ctx) { return mk_val(builtins.at(val), val); } else { JJ_DEBUG("Identifier '%s' not found, returning undefined", val.c_str()); - return mk_val(); + return mk_val(val); } } @@ -83,7 +86,6 @@ value object_literal::execute_impl(context & ctx) { value binary_expression::execute_impl(context & ctx) { value left_val = left->execute(ctx); - JJ_DEBUG("Executing binary expression %s '%s' %s", left_val->type().c_str(), op.value.c_str(), right->type().c_str()); // Logical operators if (op.value == "and") { @@ -94,6 +96,7 @@ value binary_expression::execute_impl(context & ctx) { // Equality operators value right_val = right->execute(ctx); + JJ_DEBUG("Executing binary expression %s '%s' %s", left_val->type().c_str(), op.value.c_str(), right_val->type().c_str()); if (op.value == "==") { return mk_val(value_compare(left_val, right_val)); } else if (op.value == "!=") { @@ -168,10 +171,18 @@ value binary_expression::execute_impl(context & ctx) { } } else if (is_val(right_val)) { auto & arr = right_val->as_array(); - bool member = std::find_if(arr.begin(), arr.end(), [&](const value& v) { return v == left_val; }) != arr.end(); + bool member = false; + for (const auto & item : arr) { + if (value_compare(left_val, item)) { + member = true; + break; + } + } if (op.value == "in") { + JJ_DEBUG("Checking membership: %s in Array is %d", left_val->type().c_str(), member); return mk_val(member); } else if (op.value == "not in") { + JJ_DEBUG("Checking non-membership: %s not in Array is %d", left_val->type().c_str(), !member); return mk_val(!member); } } @@ -220,7 +231,7 @@ static value try_builtin_func(const std::string & name, const value & input, boo return mk_val(it->second, input, name); } if (undef_on_missing) { - return mk_val(); + return mk_val(name); } throw std::runtime_error("Unknown (built-in) filter '" + name + "' for type " + input->type()); } @@ -330,7 +341,10 @@ value if_statement::execute_impl(context & ctx) { out->push_back(stmt->execute(ctx)); } } - return out; + // convert to string parts + value_string str = mk_val(); + gather_string_parts_recursive(out, str); + return str; } value for_statement::execute_impl(context & ctx) { @@ -437,8 +451,8 @@ value for_statement::execute_impl(context & ctx) { loop_obj->insert("first", mk_val(i == 0)); loop_obj->insert("last", mk_val(i == filtered_items.size() - 1)); loop_obj->insert("length", mk_val(filtered_items.size())); - loop_obj->insert("previtem", i > 0 ? filtered_items[i - 1] : mk_val()); - loop_obj->insert("nextitem", i < filtered_items.size() - 1 ? filtered_items[i + 1] : mk_val()); + loop_obj->insert("previtem", i > 0 ? filtered_items[i - 1] : mk_val("previtem")); + loop_obj->insert("nextitem", i < filtered_items.size() - 1 ? filtered_items[i + 1] : mk_val("nextitem")); ctx.var["loop"] = loop_obj; scope_update_fns[i](ctx); try { @@ -460,7 +474,10 @@ value for_statement::execute_impl(context & ctx) { } } - return result; + // convert to string parts + value_string str = mk_val(); + gather_string_parts_recursive(result, str); + return str; } value set_statement::execute_impl(context & ctx) { @@ -515,24 +532,41 @@ value set_statement::execute_impl(context & ctx) { value macro_statement::execute_impl(context & ctx) { std::string name = cast_stmt(this->name)->val; - const func_handler func = [this, &ctx, name](const func_args & args) -> value { - JJ_DEBUG("Invoking macro '%s' with %zu arguments", name.c_str(), args.args.size()); + + const func_handler func = [this, name, &ctx](const func_args & args) -> value { + size_t expected_count = this->args.size(); + size_t input_count = args.args.size(); + + JJ_DEBUG("Invoking macro '%s' with %zu input arguments (expected %zu)", name.c_str(), input_count, expected_count); context macro_ctx(ctx); // new scope for macro execution // bind parameters - size_t param_count = this->args.size(); - size_t arg_count = args.args.size(); - for (size_t i = 0; i < param_count; ++i) { - std::string param_name = cast_stmt(this->args[i])->val; - if (i < arg_count) { + for (size_t i = 0; i < expected_count; ++i) { + if (i < input_count) { + std::string param_name = cast_stmt(this->args[i])->val; + JJ_DEBUG(" Binding parameter '%s' to argument of type %s", param_name.c_str(), args.args[i]->type().c_str()); macro_ctx.var[param_name] = args.args[i]; } else { - macro_ctx.var[param_name] = mk_val(); + auto & default_arg = this->args[i]; + if (is_stmt(default_arg)) { + auto kwarg = cast_stmt(default_arg); + std::string param_name = cast_stmt(kwarg->key)->val; + JJ_DEBUG(" Binding parameter '%s' to default argument of type %s", param_name.c_str(), kwarg->val->type().c_str()); + macro_ctx.var[param_name] = kwarg->val->execute(ctx); + } else { + throw std::runtime_error("Not enough arguments provided to macro '" + name + "'"); + } + //std::string param_name = cast_stmt(default_args[i])->val; + //JJ_DEBUG(" Binding parameter '%s' to default", param_name.c_str()); + //macro_ctx.var[param_name] = default_args[i]->execute(ctx); } } // execute macro body - return exec_statements(this->body, macro_ctx); + JJ_DEBUG("Executing macro '%s' body with %zu statements", name.c_str(), this->body.size()); + auto res = exec_statements(this->body, macro_ctx); + JJ_DEBUG("Macro '%s' execution complete, result: %s", name.c_str(), res->val_str.str().c_str()); + return res; }; JJ_DEBUG("Defining macro '%s' with %zu parameters", name.c_str(), args.size()); @@ -548,9 +582,9 @@ value member_expression::execute_impl(context & ctx) { JJ_DEBUG("Member expression, computing property type %s", this->property->type().c_str()); if (is_stmt(this->property)) { auto s = cast_stmt(this->property); - value start_val = s->start_expr ? s->start_expr->execute(ctx) : mk_val(); - value stop_val = s->stop_expr ? s->stop_expr->execute(ctx) : mk_val(); - value step_val = s->step_expr ? s->step_expr->execute(ctx) : mk_val(); + value start_val = s->start_expr ? s->start_expr->execute(ctx) : mk_val("start"); + value stop_val = s->stop_expr ? s->stop_expr->execute(ctx) : mk_val("stop"); + value step_val = s->step_expr ? s->step_expr->execute(ctx) : mk_val("step"); // translate to function call: obj.slice(start, stop, step) JJ_DEBUG("Member expression is a slice: start %s, stop %s, step %s", @@ -572,7 +606,7 @@ value member_expression::execute_impl(context & ctx) { JJ_DEBUG("Member expression on object type %s, property type %s", object->type().c_str(), property->type().c_str()); - value val = mk_val(); + value val = mk_val("object_property"); if (is_val(object)) { JJ_DEBUG("%s", "Accessing property on undefined object, returning undefined"); diff --git a/tests/test-chat-jinja.cpp b/tests/test-chat-jinja.cpp index 997d463061..72f3ee9822 100644 --- a/tests/test-chat-jinja.cpp +++ b/tests/test-chat-jinja.cpp @@ -58,7 +58,7 @@ int main(void) { void run(std::string contents) { - // jinja::enable_debug(true); + jinja::enable_debug(true); jinja::lexer lexer; jinja::preprocess_options options; @@ -90,7 +90,10 @@ void run(std::string contents) { "content": {"__input__": "I am fine, thank you!"} } ], - "eos_token": "" + "bos_token": "", + "eos_token": "", + "functions": "", + "datetime": "" })"; jinja::global_from_json(ctx, nlohmann::json::parse(json_inp)); From 1cf25734a981d6d173c2d13621ee9b233f114ad1 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Mon, 29 Dec 2025 10:53:32 +0100 Subject: [PATCH 31/47] more tests --- common/jinja/jinja-value.cpp | 44 ++++++++++++++++++++++++--------- common/jinja/jinja-vm.cpp | 14 ++++++++++- common/jinja/jinja-vm.h | 3 +++ common/jinja/jinja-workaround.h | 20 +++++++++++++++ tests/test-chat-jinja.cpp | 26 ++++++++++++++++++- 5 files changed, 94 insertions(+), 13 deletions(-) create mode 100644 common/jinja/jinja-workaround.h diff --git a/common/jinja/jinja-value.cpp b/common/jinja/jinja-value.cpp index f382a64a86..70cca62cff 100644 --- a/common/jinja/jinja-value.cpp +++ b/common/jinja/jinja-value.cpp @@ -385,14 +385,30 @@ const func_builtins & value_string_t::get_builtins() const { return input; } }}, + {"slice", [](const func_args & args) -> value { + auto & input = args.args[0]; + if (!is_val(input)) { + throw raised_exception("slice() first argument must be a string"); + } + if (args.args.size() < 1 || args.args.size() > 4) { + throw raised_exception("slice() takes between 1 and 4 arguments"); + } + int64_t start = is_val(args.args[1]) ? args.args[1]->as_int() : 0; + int64_t stop = is_val(args.args[2]) ? args.args[2]->as_int() : -1; + int64_t step = is_val(args.args[3]) ? args.args[3]->as_int() : 1; + if (step == 0) { + throw raised_exception("slice step cannot be zero"); + } + auto sliced = slice(input->as_string().str(), start, stop, step); + auto res = mk_val(sliced); + res->val_str.mark_input_based_on(input->as_string()); + return res; + }}, {"indent", [](const func_args &) -> value { - throw std::runtime_error("indent builtin not implemented"); + throw std::runtime_error("String indent builtin not implemented"); }}, {"join", [](const func_args &) -> value { - throw std::runtime_error("join builtin not implemented"); - }}, - {"slice", [](const func_args &) -> value { - throw std::runtime_error("slice builtin not implemented"); + throw std::runtime_error("String join builtin not implemented"); }}, }; return builtins; @@ -635,15 +651,21 @@ const func_builtins & value_object_t::get_builtins() const { const func_builtins & value_null_t::get_builtins() const { static const func_builtins builtins = { - {"list", [](const func_args &) -> value { + {"list", [](const func_args & args) -> value { // fix for meetkai-functionary-medium-v3.1.jinja - // TODO: hide under a flag? - return mk_val(); + if (args.ctx.wrk_around.none_has_builtins) { + return mk_val(); + } else { + throw raised_exception("'list' builtin not supported for none type"); + } }}, - {"selectattr", [](const func_args &) -> value { + {"selectattr", [](const func_args & args) -> value { // fix for meetkai-functionary-medium-v3.1.jinja - // TODO: hide under a flag? - return mk_val(); + if (args.ctx.wrk_around.none_has_builtins) { + return mk_val(); + } else { + throw raised_exception("'selectattr' builtin not supported for none type"); + } }}, }; return builtins; diff --git a/common/jinja/jinja-vm.cpp b/common/jinja/jinja-vm.cpp index 0211ef9013..94ee370029 100644 --- a/common/jinja/jinja-vm.cpp +++ b/common/jinja/jinja-vm.cpp @@ -109,6 +109,15 @@ value binary_expression::execute_impl(context & ctx) { // Special case: `anything in undefined` is `false` and `anything not in undefined` is `true` return mk_val(op.value == "not in"); } + if (ctx.wrk_around.string_plus_undefined_is_string && (op.value == "+" || op.value == "~")) { + JJ_DEBUG("%s", "Workaround: treating undefined as empty string for string concatenation"); + auto left_str = left_val->is_undefined() ? string() : left_val->as_string(); + auto right_str = right_val->is_undefined() ? string() : right_val->as_string(); + auto output = left_str.append(right_str); + auto res = mk_val(); + res->val_str = std::move(output); + return res; + } throw std::runtime_error("Cannot perform operation " + op.value + " on undefined values"); } else if (is_val(left_val) || is_val(right_val)) { throw std::runtime_error("Cannot perform operation on null values"); @@ -628,9 +637,12 @@ value member_expression::execute_impl(context & ctx) { } else if (is_val(object) || is_val(object)) { if (is_val(property)) { int64_t index = property->as_int(); - JJ_DEBUG("Accessing %s index %lld", is_val(object) ? "array" : "string", index); + JJ_DEBUG("Accessing %s index %lld", object->type().c_str(), index); if (is_val(object)) { auto & arr = object->as_array(); + if (index < 0) { + index += static_cast(arr.size()); + } if (index >= 0 && index < static_cast(arr.size())) { val = arr[index]; } diff --git a/common/jinja/jinja-vm.h b/common/jinja/jinja-vm.h index 596f325194..045d45d980 100644 --- a/common/jinja/jinja-vm.h +++ b/common/jinja/jinja-vm.h @@ -2,6 +2,7 @@ #include "jinja-lexer.h" #include "jinja-value.h" +#include "jinja-workaround.h" #include #include @@ -52,6 +53,8 @@ struct context { std::time_t current_time; // for functions that need current time + workarounds wrk_around; // workarounds for non-standard jinja behavior + context() { var["true"] = mk_val(true); var["false"] = mk_val(false); diff --git a/common/jinja/jinja-workaround.h b/common/jinja/jinja-workaround.h new file mode 100644 index 0000000000..766132c0ca --- /dev/null +++ b/common/jinja/jinja-workaround.h @@ -0,0 +1,20 @@ +#pragma once + +#include "jinja-value.h" + +#include +#include + +namespace jinja { + +// containing workarounds for Jinja templates that rely on non-standard behavior + +struct workarounds { + // meetkai-functionary-medium-v3.1.jinja call filter on None type + bool none_has_builtins = true; + + // Olmo calls operation + between string and undefined + bool string_plus_undefined_is_string = true; +}; + +} // namespace jinja diff --git a/tests/test-chat-jinja.cpp b/tests/test-chat-jinja.cpp index 72f3ee9822..61ce80d8ac 100644 --- a/tests/test-chat-jinja.cpp +++ b/tests/test-chat-jinja.cpp @@ -28,11 +28,32 @@ int main(void) { std::vector failed_tests; + auto is_ignored_file = [](const std::string & filename) -> bool { + std::vector ignored_files = { + "Apriel-", + "Olmo-3-7B-Instruct-Heretic-GGUF", + }; + for (const auto & ignored : ignored_files) { + if (filename.find(ignored) != std::string::npos) { + return true; + } + } + return false; + }; + // list all files in models/templates/ and run each size_t test_count = 0; - std::string dir_path = "models/templates/"; + size_t skip_count = 0; + //std::string dir_path = "models/templates/"; + std::string dir_path = "../test-jinja/templates/"; for (const auto & entry : std::filesystem::directory_iterator(dir_path)) { if (entry.is_regular_file()) { + if (is_ignored_file(entry.path().filename().string())) { + std::cout << "=== SKIPPING TEMPLATE FILE: " << entry.path().string() << " ===\n"; + skip_count++; + continue; + } + test_count++; std::cout << "\n\n=== RUNNING TEMPLATE FILE: " << entry.path().string() << " ===\n"; std::ifstream infile(entry.path()); @@ -43,6 +64,7 @@ int main(void) { std::cout << "Exception: " << e.what() << "\n"; std::cout << "=== ERROR WITH TEMPLATE FILE: " << entry.path().string() << " ===\n"; failed_tests.push_back(entry.path().string()); + exit(1); } } } @@ -50,6 +72,7 @@ int main(void) { std::cout << "\n\n=== TEST SUMMARY ===\n"; std::cout << "Total tests run: " << test_count << "\n"; std::cout << "Total failed tests: " << failed_tests.size() << "\n"; + std::cout << "Total skipped tests: " << skip_count << "\n"; for (const auto & test : failed_tests) { std::cout << "FAILED TEST: " << test << "\n"; } @@ -92,6 +115,7 @@ void run(std::string contents) { ], "bos_token": "", "eos_token": "", + "tools": [], "functions": "", "datetime": "" })"; From 026730e8e3b029c45748421e5ae06c06f42e2321 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Mon, 29 Dec 2025 12:53:31 +0100 Subject: [PATCH 32/47] more fix, more tests --- common/jinja/jinja-lexer.cpp | 74 ++++++++++++++++++++++++++++++----- common/jinja/jinja-parser.cpp | 27 ++++++++++--- common/jinja/jinja-parser.h | 2 + common/jinja/jinja-value.cpp | 18 ++++----- common/jinja/jinja-vm.h | 2 +- tests/test-chat-jinja.cpp | 10 +++-- 6 files changed, 106 insertions(+), 27 deletions(-) diff --git a/common/jinja/jinja-lexer.cpp b/common/jinja/jinja-lexer.cpp index 541452f3fe..285ccc0151 100644 --- a/common/jinja/jinja-lexer.cpp +++ b/common/jinja/jinja-lexer.cpp @@ -1,4 +1,5 @@ #include "jinja-lexer.h" +#include "jinja-vm.h" #include #include @@ -7,13 +8,73 @@ #include #include #include +#include - -// #define JJ_DEBUG(msg, ...) printf("jinja-lexer: " msg "\n", __VA_ARGS__) -#define JJ_DEBUG(msg, ...) // no-op +#define FILENAME "jinja-lexer" namespace jinja { +// Trim template markers with '-' for whitespace control +// Example: [spaces]{%- ... -%} --> {% ... %} +#include +#include + +static void trim_template_markers_inplace(std::string & s) { + // i = head ; j = tail (i <= j) + size_t j = 0; // Write pointer + const size_t len = s.length(); + + for (size_t i = 0; i < len; ) { + bool handled = false; + + // We need at least 3 characters for any marker: {X- or -X} + if (i + 2 < len) { + const char c1 = s[i]; + const char c2 = s[i + 1]; + const char c3 = s[i + 2]; + + // 1. Closing trim: -X} where X = %, }, # + // Example: [content]-%} [spaces] -> [content]%} + if (c1 == '-' && c3 == '}' && (c2 == '%' || c2 == '}' || c2 == '#')) { + s[j++] = c2; + s[j++] = '}'; + i += 3; + // Strip leading whitespace AFTER the tag + while (i < len && std::isspace(static_cast(s[i]))) { + i++; + } + handled = true; + } + // 2. Opening trim: {X- where X = %, {, # + // Example: [spaces]{%- [content] -> {% [content] + else if (c1 == '{' && c3 == '-' && (c2 == '%' || c2 == '{' || c2 == '#')) { + // Trim trailing whitespace BEFORE the tag by moving write pointer back + while (j > 0 && std::isspace(static_cast(s[j - 1]))) { + j--; + } + + // Safety: Prevent merging '{' with tag start (avoid creating '{{%' or '{{{') + // if the character immediately before our new tag is a literal '{'. + if (j > 0 && s[j - 1] == '{') { + s[j++] = ' '; + } + + s[j++] = '{'; + s[j++] = c2; + i += 3; + handled = true; + } + } + + if (!handled) { + // Note: j is always <= i here, so this is safe. + s[j++] = s[i++]; + } + } + + s.resize(j); +} + std::string lexer::preprocess(const std::string & template_str, const preprocess_options & options) const { std::string result = template_str; // According to https://jinja.palletsprojects.com/en/3.0.x/templates/#whitespace-control @@ -40,12 +101,7 @@ std::string lexer::preprocess(const std::string & template_str, const preprocess } // Handle whitespace control with - in tags - result = std::regex_replace(result, std::regex(R"(-%\}\s*)"), "%}"); - result = std::regex_replace(result, std::regex(R"(\s*\{%-)"), "{%"); - result = std::regex_replace(result, std::regex(R"(-\}\}\s*)"), "}}"); - result = std::regex_replace(result, std::regex(R"(\s*\{\{-)"), "{{"); - result = std::regex_replace(result, std::regex(R"(-#\}\s*)"), "#}"); - result = std::regex_replace(result, std::regex(R"(\s*\{\#-)"), "{#"); + trim_template_markers_inplace(result); // Handle custom transformers-specific `generation` tag // See https://github.com/huggingface/transformers/pull/30650 for more information. diff --git a/common/jinja/jinja-parser.cpp b/common/jinja/jinja-parser.cpp index 5f42b0bd89..8cbb41eca6 100644 --- a/common/jinja/jinja-parser.cpp +++ b/common/jinja/jinja-parser.cpp @@ -26,8 +26,10 @@ class parser { // for debugging; a token can be multiple chars in source std::vector tok_pos_to_src_pos; + std::string source; // for error reporting + public: - parser(const std::vector & t) : tokens(t) { + parser(const std::vector & t, const std::string & src) : tokens(t), source(src) { tok_pos_to_src_pos.resize(tokens.size()); for (size_t i = 0; i < tokens.size(); i++) { tok_pos_to_src_pos[i] = tokens[i].pos; @@ -46,7 +48,16 @@ public: std::unique_ptr mk_stmt(Args&&... args) { auto ptr = std::make_unique(std::forward(args)...); ptr->pos = tok_pos_to_src_pos[prev_cur]; - JJ_DEBUG("Created %s statement at src pos %zu", ptr->type().c_str(), ptr->pos); + + std::string snippet = "no source"; + if (!source.empty()) { + size_t start_pos = ptr->pos; + size_t end_pos = start_pos + 20; + if (end_pos > source.size()) end_pos = source.size(); + snippet = source.substr(start_pos, end_pos - start_pos); + } + JJ_DEBUG("Created %-20s statement at src pos %-4zu (%s)", ptr->type().c_str(), ptr->pos, snippet.c_str()); + return ptr; } @@ -544,7 +555,9 @@ private: return mk_stmt(std::stoll(t.value)); case token::string_literal: { std::string val = t.value; - while (is(token::string_literal)) val += tokens[current++].value; + while (is(token::string_literal)) { + val += tokens[current++].value; + } return mk_stmt(val); } case token::identifier: @@ -575,13 +588,17 @@ private: return mk_stmt(std::move(pairs)); } default: - throw std::runtime_error("Unexpected token: " + t.value); + throw std::runtime_error("Unexpected token: " + t.value + " of type " + std::to_string(t.t)); } } }; program parse_from_tokens(const std::vector & tokens) { - return parser(tokens).parse(); + return parser(tokens, "").parse(); +} + +program parse_from_tokens(const lexer_result & lexer_res) { + return parser(lexer_res.tokens, lexer_res.preprocessed_source).parse(); } } // namespace jinja diff --git a/common/jinja/jinja-parser.h b/common/jinja/jinja-parser.h index ea212ad181..14ce135432 100644 --- a/common/jinja/jinja-parser.h +++ b/common/jinja/jinja-parser.h @@ -13,4 +13,6 @@ namespace jinja { program parse_from_tokens(const std::vector & tokens); +program parse_from_tokens(const lexer_result & lexer_res); + } // namespace jinja diff --git a/common/jinja/jinja-value.cpp b/common/jinja/jinja-value.cpp index 70cca62cff..218d893e26 100644 --- a/common/jinja/jinja-value.cpp +++ b/common/jinja/jinja-value.cpp @@ -131,23 +131,23 @@ const func_builtins & global_builtins() { if (args.args.size() < 1 || args.args.size() > 3) { throw raised_exception("slice() takes between 1 and 3 arguments"); } - int64_t arg0 = is_val(args.args[0]) ? args.args[0]->as_int() : 0; - int64_t arg1 = is_val(args.args[1]) ? args.args[1]->as_int() : -1; - int64_t arg2 = is_val(args.args[2]) ? args.args[2]->as_int() : 1; + auto & arg0 = args.args[0]; + auto & arg1 = args.args[1]; + auto & arg2 = args.args[2]; int64_t start, stop, step; if (args.args.size() == 1) { start = 0; - stop = arg0; + stop = arg0->as_int(); step = 1; } else if (args.args.size() == 2) { - start = arg0; - stop = arg1; + start = arg0->as_int(); + stop = arg1->as_int(); step = 1; } else { - start = arg0; - stop = arg1; - step = arg2; + start = arg0->as_int(); + stop = arg1->as_int(); + step = arg2->as_int(); } auto out = mk_val(); diff --git a/common/jinja/jinja-vm.h b/common/jinja/jinja-vm.h index 045d45d980..02790945a9 100644 --- a/common/jinja/jinja-vm.h +++ b/common/jinja/jinja-vm.h @@ -10,7 +10,7 @@ #include #include -#define JJ_DEBUG(msg, ...) if (g_jinja_debug) printf("%s:%3d : " msg "\n", FILENAME, __LINE__, __VA_ARGS__) +#define JJ_DEBUG(msg, ...) if (g_jinja_debug) printf("%s:%-3d : " msg "\n", FILENAME, __LINE__, __VA_ARGS__) extern bool g_jinja_debug; diff --git a/tests/test-chat-jinja.cpp b/tests/test-chat-jinja.cpp index 61ce80d8ac..f16ebb9e07 100644 --- a/tests/test-chat-jinja.cpp +++ b/tests/test-chat-jinja.cpp @@ -28,6 +28,8 @@ int main(void) { std::vector failed_tests; + bool stop_on_first_failure = false; + auto is_ignored_file = [](const std::string & filename) -> bool { std::vector ignored_files = { "Apriel-", @@ -64,7 +66,9 @@ int main(void) { std::cout << "Exception: " << e.what() << "\n"; std::cout << "=== ERROR WITH TEMPLATE FILE: " << entry.path().string() << " ===\n"; failed_tests.push_back(entry.path().string()); - exit(1); + if (stop_on_first_failure) { + break; + } } } } @@ -85,7 +89,7 @@ void run(std::string contents) { jinja::lexer lexer; jinja::preprocess_options options; - options.trim_blocks = true; + options.trim_blocks = false; options.lstrip_blocks = false; auto lexer_res = lexer.tokenize(contents, options); for (const auto & tok : lexer_res.tokens) { @@ -93,7 +97,7 @@ void run(std::string contents) { } std::cout << "\n=== AST ===\n"; - jinja::program ast = jinja::parse_from_tokens(lexer_res.tokens); + jinja::program ast = jinja::parse_from_tokens(lexer_res); for (const auto & stmt : ast.body) { //std::cout << "stmt type: " << stmt->type() << "\n"; } From 9e9a70f72f2361875cbe494c61b467b17ecc6df6 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Mon, 29 Dec 2025 15:07:18 +0100 Subject: [PATCH 33/47] more fixes --- common/jinja/jinja-lexer.cpp | 3 ++- common/jinja/jinja-value.cpp | 11 +++++++++++ common/jinja/jinja-vm.cpp | 34 +++++++++++++++++++++------------ common/jinja/jinja-vm.h | 8 ++++---- common/jinja/jinja-workaround.h | 4 ++++ tests/test-chat-jinja.cpp | 19 ++++++++++++++---- 6 files changed, 58 insertions(+), 21 deletions(-) diff --git a/common/jinja/jinja-lexer.cpp b/common/jinja/jinja-lexer.cpp index 285ccc0151..189f8f5b10 100644 --- a/common/jinja/jinja-lexer.cpp +++ b/common/jinja/jinja-lexer.cpp @@ -105,7 +105,8 @@ std::string lexer::preprocess(const std::string & template_str, const preprocess // Handle custom transformers-specific `generation` tag // See https://github.com/huggingface/transformers/pull/30650 for more information. - // result = std::regex_replace(result, std::regex(R"((?s)\{%\s*generation\s*%\}.+?\{%\s*endgeneration\s*%\})"), ""); + result = std::regex_replace(result, std::regex(R"(\{%\s*generation\s*%\})"), ""); + result = std::regex_replace(result, std::regex(R"(\{%\s*endgeneration\s*%\})"), ""); return result; } diff --git a/common/jinja/jinja-value.cpp b/common/jinja/jinja-value.cpp index 218d893e26..688f6cdb0f 100644 --- a/common/jinja/jinja-value.cpp +++ b/common/jinja/jinja-value.cpp @@ -404,6 +404,17 @@ const func_builtins & value_string_t::get_builtins() const { res->val_str.mark_input_based_on(input->as_string()); return res; }}, + {"selectattr", [](const func_args & args) -> value { + if (args.ctx.wrk_around.string_has_selectattr) { + // no-op, return an array containing the original string + args.ensure_vals(); + auto result = mk_val(); + result->push_back(args.args[0]); + return result; + } else { + throw raised_exception("String selectattr builtin not supported"); + } + }}, {"indent", [](const func_args &) -> value { throw std::runtime_error("String indent builtin not implemented"); }}, diff --git a/common/jinja/jinja-vm.cpp b/common/jinja/jinja-vm.cpp index 94ee370029..8797b866f4 100644 --- a/common/jinja/jinja-vm.cpp +++ b/common/jinja/jinja-vm.cpp @@ -35,6 +35,10 @@ static value_string exec_statements(const statements & stmts, context & ctx) { value statement::execute(context & ctx) { try { return execute_impl(ctx); + } catch (const continue_statement::signal & ex) { + throw ex; + } catch (const break_statement::signal & ex) { + throw ex; } catch (const std::exception & e) { if (ctx.source.empty()) { std::ostringstream oss; @@ -359,15 +363,17 @@ value if_statement::execute_impl(context & ctx) { value for_statement::execute_impl(context & ctx) { context scope(ctx); // new scope for loop variables - statement_ptr iter_expr = std::move(iterable); - statement_ptr test_expr = nullptr; + jinja::select_expression * select_expr = cast_stmt(iterable); + statement_ptr test_expr_nullptr; - if (is_stmt(iterable)) { - JJ_DEBUG("%s", "For loop has test expression"); - auto select = cast_stmt(iterable); - iter_expr = std::move(select->lhs); - test_expr = std::move(select->test); - } + statement_ptr & iter_expr = [&]() -> statement_ptr & { + auto tmp = cast_stmt(iterable); + return tmp ? tmp->lhs : iterable; + }(); + statement_ptr & test_expr = [&]() -> statement_ptr & { + auto tmp = cast_stmt(iterable); + return tmp ? tmp->test : test_expr_nullptr; + }(); JJ_DEBUG("Executing for statement, iterable type: %s", iter_expr->type().c_str()); @@ -436,21 +442,23 @@ value for_statement::execute_impl(context & ctx) { } else { throw std::runtime_error("Invalid loop variable(s): " + loopvar->type()); } - if (test_expr) { + if (select_expr && test_expr) { scope_update_fn(loop_scope); value test_val = test_expr->execute(loop_scope); if (!test_val->as_bool()) { continue; } } + JJ_DEBUG("For loop: adding item type %s at index %zu", current->type().c_str(), i); filtered_items.push_back(current); scope_update_fns.push_back(scope_update_fn); } + JJ_DEBUG("For loop: %zu items after filtering", filtered_items.size()); auto result = mk_val(); bool noIteration = true; - for (size_t i = 0; i < filtered_items.size(); ++i) { + for (size_t i = 0; i < filtered_items.size(); i++) { JJ_DEBUG("For loop iteration %zu/%zu", i + 1, filtered_items.size()); value_object loop_obj = mk_val(); loop_obj->insert("index", mk_val(i + 1)); @@ -469,13 +477,15 @@ value for_statement::execute_impl(context & ctx) { value val = stmt->execute(ctx); result->push_back(val); } - } catch (const continue_statement::exception &) { + } catch (const continue_statement::signal &) { continue; - } catch (const break_statement::exception &) { + } catch (const break_statement::signal &) { break; } noIteration = false; } + + JJ_DEBUG("For loop complete, total iterations: %zu", filtered_items.size()); if (noIteration) { for (auto & stmt : default_block) { value val = stmt->execute(ctx); diff --git a/common/jinja/jinja-vm.h b/common/jinja/jinja-vm.h index 02790945a9..1526a365a1 100644 --- a/common/jinja/jinja-vm.h +++ b/common/jinja/jinja-vm.h @@ -160,28 +160,28 @@ struct for_statement : public statement { struct break_statement : public statement { std::string type() const override { return "Break"; } - struct exception : public std::exception { + struct signal : public std::exception { const char* what() const noexcept override { return "Break statement executed"; } }; value execute_impl(context &) override { - throw break_statement::exception(); + throw break_statement::signal(); } }; struct continue_statement : public statement { std::string type() const override { return "Continue"; } - struct exception : public std::exception { + struct signal : public std::exception { const char* what() const noexcept override { return "Continue statement executed"; } }; value execute_impl(context &) override { - throw continue_statement::exception(); + throw continue_statement::signal(); } }; diff --git a/common/jinja/jinja-workaround.h b/common/jinja/jinja-workaround.h index 766132c0ca..ed7e92df45 100644 --- a/common/jinja/jinja-workaround.h +++ b/common/jinja/jinja-workaround.h @@ -8,6 +8,7 @@ namespace jinja { // containing workarounds for Jinja templates that rely on non-standard behavior +// NOTE: this is kept as a dedicated file for better documentation struct workarounds { // meetkai-functionary-medium-v3.1.jinja call filter on None type @@ -15,6 +16,9 @@ struct workarounds { // Olmo calls operation + between string and undefined bool string_plus_undefined_is_string = true; + + // sheldonrobinson-Llama-Guard call selectattr on string + bool string_has_selectattr = true; }; } // namespace jinja diff --git a/tests/test-chat-jinja.cpp b/tests/test-chat-jinja.cpp index f16ebb9e07..0e2f5e4faa 100644 --- a/tests/test-chat-jinja.cpp +++ b/tests/test-chat-jinja.cpp @@ -34,6 +34,10 @@ int main(void) { std::vector ignored_files = { "Apriel-", "Olmo-3-7B-Instruct-Heretic-GGUF", + "sheldonrobinson-Llama-Guard", + "deepseek-community-Janus-Pro-1B", + "bitshrine-gemma-2-2B-function-calling", + "PaddlePaddle-PaddleOCR-VL", }; for (const auto & ignored : ignored_files) { if (filename.find(ignored) != std::string::npos) { @@ -119,11 +123,18 @@ void run(std::string contents) { ], "bos_token": "", "eos_token": "", - "tools": [], - "functions": "", - "datetime": "" + "tools": [] })"; - jinja::global_from_json(ctx, nlohmann::json::parse(json_inp)); + auto input_json = nlohmann::json::parse(json_inp); + + // workaround for functionary models + input_json["functions"] = ""; + input_json["datetime"] = ""; + + // workaround for Llama Guard models + input_json["excluded_category_keys"] = nlohmann::json::array(); + + jinja::global_from_json(ctx, input_json); jinja::vm vm(ctx); const jinja::value results = vm.execute(ast); From 9c0fa6f81001e14b1d2224da7f9b6094c8520845 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Tue, 30 Dec 2025 16:07:23 +0100 Subject: [PATCH 34/47] rm workarounds --- common/jinja/jinja-value.cpp | 32 ++++++-------------------------- common/jinja/jinja-value.h | 6 ------ common/jinja/jinja-vm.cpp | 18 +++++++++--------- common/jinja/jinja-vm.h | 3 --- common/jinja/jinja-workaround.h | 24 ------------------------ tests/test-chat-jinja.cpp | 27 +++++++++++++++------------ 6 files changed, 30 insertions(+), 80 deletions(-) delete mode 100644 common/jinja/jinja-workaround.h diff --git a/common/jinja/jinja-value.cpp b/common/jinja/jinja-value.cpp index 688f6cdb0f..2c9ce6c76c 100644 --- a/common/jinja/jinja-value.cpp +++ b/common/jinja/jinja-value.cpp @@ -404,16 +404,11 @@ const func_builtins & value_string_t::get_builtins() const { res->val_str.mark_input_based_on(input->as_string()); return res; }}, - {"selectattr", [](const func_args & args) -> value { - if (args.ctx.wrk_around.string_has_selectattr) { - // no-op, return an array containing the original string - args.ensure_vals(); - auto result = mk_val(); - result->push_back(args.args[0]); - return result; - } else { - throw raised_exception("String selectattr builtin not supported"); - } + {"selectattr", [](const func_args &) -> value { + throw std::runtime_error("String selectattr builtin not supported"); + }}, + {"rejectattr", [](const func_args &) -> value { + throw std::runtime_error("String rejectattr builtin not supported"); }}, {"indent", [](const func_args &) -> value { throw std::runtime_error("String indent builtin not implemented"); @@ -662,22 +657,7 @@ const func_builtins & value_object_t::get_builtins() const { const func_builtins & value_null_t::get_builtins() const { static const func_builtins builtins = { - {"list", [](const func_args & args) -> value { - // fix for meetkai-functionary-medium-v3.1.jinja - if (args.ctx.wrk_around.none_has_builtins) { - return mk_val(); - } else { - throw raised_exception("'list' builtin not supported for none type"); - } - }}, - {"selectattr", [](const func_args & args) -> value { - // fix for meetkai-functionary-medium-v3.1.jinja - if (args.ctx.wrk_around.none_has_builtins) { - return mk_val(); - } else { - throw raised_exception("'selectattr' builtin not supported for none type"); - } - }}, + // TODO: may need to implement this, idk }; return builtins; } diff --git a/common/jinja/jinja-value.h b/common/jinja/jinja-value.h index 3289a0de59..7c7d98d932 100644 --- a/common/jinja/jinja-value.h +++ b/common/jinja/jinja-value.h @@ -132,12 +132,6 @@ struct value_t { string val_str; bool val_bool; - // array and object are stored as shared_ptr to allow reference access - // example: - // my_obj = {"a": 1, "b": 2} - // my_arr = [my_obj] - // my_obj["a"] = 3 - // print(my_arr[0]["a"]) # should print 3 std::vector val_arr; std::map val_obj; diff --git a/common/jinja/jinja-vm.cpp b/common/jinja/jinja-vm.cpp index 8797b866f4..b99fc605f0 100644 --- a/common/jinja/jinja-vm.cpp +++ b/common/jinja/jinja-vm.cpp @@ -113,15 +113,15 @@ value binary_expression::execute_impl(context & ctx) { // Special case: `anything in undefined` is `false` and `anything not in undefined` is `true` return mk_val(op.value == "not in"); } - if (ctx.wrk_around.string_plus_undefined_is_string && (op.value == "+" || op.value == "~")) { - JJ_DEBUG("%s", "Workaround: treating undefined as empty string for string concatenation"); - auto left_str = left_val->is_undefined() ? string() : left_val->as_string(); - auto right_str = right_val->is_undefined() ? string() : right_val->as_string(); - auto output = left_str.append(right_str); - auto res = mk_val(); - res->val_str = std::move(output); - return res; - } + // if (ctx.wrk_around.string_plus_undefined_is_string && (op.value == "+" || op.value == "~")) { + // JJ_DEBUG("%s", "Workaround: treating undefined as empty string for string concatenation"); + // auto left_str = left_val->is_undefined() ? string() : left_val->as_string(); + // auto right_str = right_val->is_undefined() ? string() : right_val->as_string(); + // auto output = left_str.append(right_str); + // auto res = mk_val(); + // res->val_str = std::move(output); + // return res; + // } throw std::runtime_error("Cannot perform operation " + op.value + " on undefined values"); } else if (is_val(left_val) || is_val(right_val)) { throw std::runtime_error("Cannot perform operation on null values"); diff --git a/common/jinja/jinja-vm.h b/common/jinja/jinja-vm.h index 1526a365a1..1095d71870 100644 --- a/common/jinja/jinja-vm.h +++ b/common/jinja/jinja-vm.h @@ -2,7 +2,6 @@ #include "jinja-lexer.h" #include "jinja-value.h" -#include "jinja-workaround.h" #include #include @@ -53,8 +52,6 @@ struct context { std::time_t current_time; // for functions that need current time - workarounds wrk_around; // workarounds for non-standard jinja behavior - context() { var["true"] = mk_val(true); var["false"] = mk_val(false); diff --git a/common/jinja/jinja-workaround.h b/common/jinja/jinja-workaround.h deleted file mode 100644 index ed7e92df45..0000000000 --- a/common/jinja/jinja-workaround.h +++ /dev/null @@ -1,24 +0,0 @@ -#pragma once - -#include "jinja-value.h" - -#include -#include - -namespace jinja { - -// containing workarounds for Jinja templates that rely on non-standard behavior -// NOTE: this is kept as a dedicated file for better documentation - -struct workarounds { - // meetkai-functionary-medium-v3.1.jinja call filter on None type - bool none_has_builtins = true; - - // Olmo calls operation + between string and undefined - bool string_plus_undefined_is_string = true; - - // sheldonrobinson-Llama-Guard call selectattr on string - bool string_has_selectattr = true; -}; - -} // namespace jinja diff --git a/tests/test-chat-jinja.cpp b/tests/test-chat-jinja.cpp index 0e2f5e4faa..0ab18c0f4f 100644 --- a/tests/test-chat-jinja.cpp +++ b/tests/test-chat-jinja.cpp @@ -14,7 +14,8 @@ #include "jinja/jinja-parser.h" #include "jinja/jinja-lexer.h" -void run(std::string contents); +void run_multiple(); +void run_single(std::string contents); int main(void) { //std::string contents = "{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '' + role + '\\n' + message['content'] | trim + '\\n' }}{% endfor %}{% if add_generation_prompt %}{{'model\\n'}}{% endif %}"; @@ -24,8 +25,16 @@ int main(void) { //std::string contents = " {{ messages[a]['content'] }} "; //std::string contents = "{% if a is not defined %}hello{% endif %}"; - //std::ifstream infile("models/templates/mistralai-Ministral-3-14B-Reasoning-2512.jinja"); std::string contents((std::istreambuf_iterator(infile)), std::istreambuf_iterator()); + std::ifstream infile("models/templates/Qwen-Qwen3-0.6B.jinja"); std::string contents((std::istreambuf_iterator(infile)), std::istreambuf_iterator()); + run_single(contents); + + //run_multiple(); + + return 0; +} + +void run_multiple(void) { std::vector failed_tests; bool stop_on_first_failure = false; @@ -65,7 +74,7 @@ int main(void) { std::ifstream infile(entry.path()); std::string contents((std::istreambuf_iterator(infile)), std::istreambuf_iterator()); try { - run(contents); + run_single(contents); } catch (const std::exception & e) { std::cout << "Exception: " << e.what() << "\n"; std::cout << "=== ERROR WITH TEMPLATE FILE: " << entry.path().string() << " ===\n"; @@ -84,27 +93,21 @@ int main(void) { for (const auto & test : failed_tests) { std::cout << "FAILED TEST: " << test << "\n"; } - return 0; } -void run(std::string contents) { +void run_single(std::string contents) { jinja::enable_debug(true); + // lexing jinja::lexer lexer; jinja::preprocess_options options; options.trim_blocks = false; options.lstrip_blocks = false; auto lexer_res = lexer.tokenize(contents, options); - for (const auto & tok : lexer_res.tokens) { - //std::cout << "token: type=" << static_cast(tok.t) << " text='" << tok.value << "' pos=" << tok.pos << "\n"; - } - std::cout << "\n=== AST ===\n"; + // compile to AST jinja::program ast = jinja::parse_from_tokens(lexer_res); - for (const auto & stmt : ast.body) { - //std::cout << "stmt type: " << stmt->type() << "\n"; - } std::cout << "\n=== RUN ===\n"; jinja::context ctx; From 4479c382ce611eec159bd3d529d854fa1c5df864 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Tue, 30 Dec 2025 17:26:23 +0100 Subject: [PATCH 35/47] demo: type inferrence --- common/jinja/jinja-type-infer.h | 38 +++++++++++++++++++ common/jinja/jinja-value.cpp | 2 +- common/jinja/jinja-value.h | 28 ++++++++++++++ common/jinja/jinja-vm.cpp | 35 ++++++++++------- common/jinja/jinja-vm.h | 67 +++++++++++++++++++++++++++++---- tests/test-chat-jinja.cpp | 19 ++++++++++ 6 files changed, 167 insertions(+), 22 deletions(-) create mode 100644 common/jinja/jinja-type-infer.h diff --git a/common/jinja/jinja-type-infer.h b/common/jinja/jinja-type-infer.h new file mode 100644 index 0000000000..3f7508787f --- /dev/null +++ b/common/jinja/jinja-type-infer.h @@ -0,0 +1,38 @@ +#pragma once + +#include +#include + +#include "jinja-value.h" + +namespace jinja { + +struct value_t; +using value = std::shared_ptr; + +// this is used as a hint for chat parsing +// it is not a 1-to-1 mapping to value_t derived types +enum class inferred_type { + numeric, // int, float + string, + boolean, + array, + object, + optional, // null, undefined + unknown, +}; + +static std::string inferred_type_to_string(inferred_type type) { + switch (type) { + case inferred_type::numeric: return "numeric"; + case inferred_type::string: return "string"; + case inferred_type::boolean: return "boolean"; + case inferred_type::array: return "array"; + case inferred_type::object: return "object"; + case inferred_type::optional: return "optional"; + case inferred_type::unknown: return "unknown"; + default: return "invalid"; + } +} + +} // namespace jinja diff --git a/common/jinja/jinja-value.cpp b/common/jinja/jinja-value.cpp index 2c9ce6c76c..5a515fc8e4 100644 --- a/common/jinja/jinja-value.cpp +++ b/common/jinja/jinja-value.cpp @@ -708,7 +708,7 @@ void global_from_json(context & ctx, const nlohmann::json & json_obj) { throw std::runtime_error("global_from_json: input JSON value must be an object"); } for (auto it = json_obj.begin(); it != json_obj.end(); ++it) { - ctx.var[it.key()] = from_json(it.value()); + ctx.set_val(it.key(), from_json(it.value())); } } diff --git a/common/jinja/jinja-value.h b/common/jinja/jinja-value.h index 7c7d98d932..77d30c82f7 100644 --- a/common/jinja/jinja-value.h +++ b/common/jinja/jinja-value.h @@ -6,8 +6,10 @@ #include #include #include +#include #include "jinja-string.h" +#include "jinja-type-infer.h" namespace jinja { @@ -137,6 +139,10 @@ struct value_t { func_handler val_func; + // for type inference + std::set inf_types; + std::vector inf_vals; + value_t() = default; value_t(const value_t &) = default; virtual ~value_t() = default; @@ -333,4 +339,26 @@ using value_kwarg = std::shared_ptr; const func_builtins & global_builtins(); + +// utils + +static inferred_type value_to_inferred_type(const value & val) { + if (is_val(val) || is_val(val)) { + return inferred_type::numeric; + } else if (is_val(val)) { + return inferred_type::string; + } else if (is_val(val)) { + return inferred_type::boolean; + } else if (is_val(val)) { + return inferred_type::array; + } else if (is_val(val)) { + return inferred_type::object; + } else if (is_val(val) || is_val(val)) { + return inferred_type::optional; + } else { + return inferred_type::unknown; + } +} + + } // namespace jinja diff --git a/common/jinja/jinja-vm.cpp b/common/jinja/jinja-vm.cpp index b99fc605f0..ed98f1d050 100644 --- a/common/jinja/jinja-vm.cpp +++ b/common/jinja/jinja-vm.cpp @@ -63,11 +63,11 @@ value statement::execute(context & ctx) { } value identifier::execute_impl(context & ctx) { - auto it = ctx.var.find(val); + auto it = ctx.get_val(val); auto builtins = global_builtins(); - if (it != ctx.var.end()) { + if (!it->is_undefined()) { JJ_DEBUG("Identifier '%s' found", val.c_str()); - return it->second; + return it; } else if (builtins.find(val) != builtins.end()) { JJ_DEBUG("Identifier '%s' found in builtins", val.c_str()); return mk_val(builtins.at(val), val); @@ -102,6 +102,8 @@ value binary_expression::execute_impl(context & ctx) { value right_val = right->execute(ctx); JJ_DEBUG("Executing binary expression %s '%s' %s", left_val->type().c_str(), op.value.c_str(), right_val->type().c_str()); if (op.value == "==") { + ctx.mark_known_type(left_val, right_val); + ctx.mark_known_type(right_val, left_val); return mk_val(value_compare(left_val, right_val)); } else if (op.value == "!=") { return mk_val(!value_compare(left_val, right_val)); @@ -342,6 +344,10 @@ value unary_expression::execute_impl(context & ctx) { value if_statement::execute_impl(context & ctx) { value test_val = test->execute(ctx); + + ctx.mark_known_type(test_val, inferred_type::boolean); + ctx.mark_known_type(test_val, inferred_type::optional); + auto out = mk_val(); if (test_val->as_bool()) { for (auto & stmt : body) { @@ -384,6 +390,9 @@ value for_statement::execute_impl(context & ctx) { iterable_val = mk_val(); } + ctx.mark_known_type(iterable_val, inferred_type::array); + ctx.mark_known_type(iterable_val, inferred_type::object); + if (!is_val(iterable_val) && !is_val(iterable_val)) { throw std::runtime_error("Expected iterable or object type in for loop: got " + iterable_val->type()); } @@ -418,7 +427,7 @@ value for_statement::execute_impl(context & ctx) { if (is_stmt(loopvar)) { auto id = cast_stmt(loopvar)->val; scope_update_fn = [id, &items, i](context & ctx) { - ctx.var[id] = items[i]; + ctx.set_val(id, items[i]); }; } else if (is_stmt(loopvar)) { auto tuple = cast_stmt(loopvar); @@ -436,7 +445,7 @@ value for_statement::execute_impl(context & ctx) { throw std::runtime_error("Cannot unpack non-identifier type: " + tuple->val[j]->type()); } auto id = cast_stmt(tuple->val[j])->val; - ctx.var[id] = c_arr[j]; + ctx.set_val(id, c_arr[j]); } }; } else { @@ -470,11 +479,11 @@ value for_statement::execute_impl(context & ctx) { loop_obj->insert("length", mk_val(filtered_items.size())); loop_obj->insert("previtem", i > 0 ? filtered_items[i - 1] : mk_val("previtem")); loop_obj->insert("nextitem", i < filtered_items.size() - 1 ? filtered_items[i + 1] : mk_val("nextitem")); - ctx.var["loop"] = loop_obj; - scope_update_fns[i](ctx); + scope.set_val("loop", loop_obj); + scope_update_fns[i](scope); try { for (auto & stmt : body) { - value val = stmt->execute(ctx); + value val = stmt->execute(scope); result->push_back(val); } } catch (const continue_statement::signal &) { @@ -505,7 +514,7 @@ value set_statement::execute_impl(context & ctx) { if (is_stmt(assignee)) { auto var_name = cast_stmt(assignee)->val; JJ_DEBUG("Setting variable '%s' with value type %s", var_name.c_str(), rhs->type().c_str()); - ctx.var[var_name] = rhs; + ctx.set_val(var_name, rhs); } else if (is_stmt(assignee)) { auto tuple = cast_stmt(assignee); @@ -522,7 +531,7 @@ value set_statement::execute_impl(context & ctx) { throw std::runtime_error("Cannot unpack to non-identifier in set: " + elem->type()); } auto var_name = cast_stmt(elem)->val; - ctx.var[var_name] = arr[i]; + ctx.set_val(var_name, arr[i]); } } else if (is_stmt(assignee)) { @@ -564,14 +573,14 @@ value macro_statement::execute_impl(context & ctx) { if (i < input_count) { std::string param_name = cast_stmt(this->args[i])->val; JJ_DEBUG(" Binding parameter '%s' to argument of type %s", param_name.c_str(), args.args[i]->type().c_str()); - macro_ctx.var[param_name] = args.args[i]; + macro_ctx.set_val(param_name, args.args[i]); } else { auto & default_arg = this->args[i]; if (is_stmt(default_arg)) { auto kwarg = cast_stmt(default_arg); std::string param_name = cast_stmt(kwarg->key)->val; JJ_DEBUG(" Binding parameter '%s' to default argument of type %s", param_name.c_str(), kwarg->val->type().c_str()); - macro_ctx.var[param_name] = kwarg->val->execute(ctx); + macro_ctx.set_val(param_name, kwarg->val->execute(ctx)); } else { throw std::runtime_error("Not enough arguments provided to macro '" + name + "'"); } @@ -589,7 +598,7 @@ value macro_statement::execute_impl(context & ctx) { }; JJ_DEBUG("Defining macro '%s' with %zu parameters", name.c_str(), args.size()); - ctx.var[name] = mk_val(func); + ctx.set_val(name, mk_val(func)); return mk_val(); } diff --git a/common/jinja/jinja-vm.h b/common/jinja/jinja-vm.h index 1095d71870..bb24abad96 100644 --- a/common/jinja/jinja-vm.h +++ b/common/jinja/jinja-vm.h @@ -47,23 +47,74 @@ const T * cast_stmt(const statement_ptr & ptr) { void enable_debug(bool enable); struct context { - std::map var; std::string source; // for debugging - std::time_t current_time; // for functions that need current time context() { - var["true"] = mk_val(true); - var["false"] = mk_val(false); - var["none"] = mk_val(); + global = mk_val(); + global->insert("true", mk_val(true)); + global->insert("false", mk_val(false)); + global->insert("none", mk_val()); current_time = std::time(nullptr); } ~context() = default; - context(const context & parent) { + context(const context & parent) : context() { // inherit variables (for example, when entering a new scope) - for (const auto & pair : parent.var) { - var[pair.first] = pair.second; + auto & pvar = parent.global->as_object(); + for (const auto & pair : pvar) { + set_val(pair.first, pair.second); + } + } + + value get_val(const std::string & name) { + auto it = global->val_obj.find(name); + if (it != global->val_obj.end()) { + return it->second; + } else { + return mk_val(name); + } + } + + void set_val(const std::string & name, const value & val) { + global->insert(name, val); + set_flattened_global_recursively(name, val); + } + + void mark_known_type(value & val, inferred_type type) { + val->inf_types.insert(type); + } + + void mark_known_type(value & val, value & known_val) { + mark_known_type(val, value_to_inferred_type(known_val)); + val->inf_vals.push_back(known_val); + } + + // FOR TESTING ONLY + const value_object & get_global_object() const { + return global; + } + +private: + value_object global; + +public: + std::map flatten_globals; // for debugging + void set_flattened_global_recursively(std::string path, const value & val) { + flatten_globals[path] = val; + if (is_val(val)) { + auto & obj = val->as_object(); + for (const auto & pair : obj) { + flatten_globals[pair.first] = pair.second; + set_flattened_global_recursively(pair.first, pair.second); + } + } else if (is_val(val)) { + auto & arr = val->as_array(); + for (size_t i = 0; i < arr.size(); ++i) { + std::string idx_path = path + "[" + std::to_string(i) + "]"; + flatten_globals[idx_path] = arr[i]; + set_flattened_global_recursively(idx_path, arr[i]); + } } } }; diff --git a/tests/test-chat-jinja.cpp b/tests/test-chat-jinja.cpp index 0ab18c0f4f..39ce9fed00 100644 --- a/tests/test-chat-jinja.cpp +++ b/tests/test-chat-jinja.cpp @@ -13,6 +13,7 @@ #include "jinja/jinja-parser.h" #include "jinja/jinja-lexer.h" +#include "jinja/jinja-type-infer.h" void run_multiple(); void run_single(std::string contents); @@ -147,4 +148,22 @@ void run_single(std::string contents) { for (const auto & part : parts.get()->val_str.parts) { std::cout << (part.is_input ? "DATA" : "TMPL") << ": " << part.val << "\n"; } + + std::cout << "\n=== TYPES ===\n"; + auto & global_obj = ctx.flatten_globals; + for (const auto & pair : global_obj) { + std::string name = pair.first; + std::string inf_types; + for (const auto & t : pair.second->inf_types) { + inf_types += inferred_type_to_string(t) + " "; + } + if (inf_types.empty()) { + continue; + } + std::string inf_vals; + for (const auto & v : pair.second->inf_vals) { + inf_vals += v->as_string().str() + " ; "; + } + printf("Var: %-20s | Types: %-10s | Vals: %s\n", name.c_str(), inf_types.c_str(), inf_vals.c_str()); + } } From 1b213ae5e78ff3e7c80ee061accaa099cde79465 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Tue, 30 Dec 2025 21:52:47 +0100 Subject: [PATCH 36/47] add placeholder for tojson --- common/jinja/jinja-value.cpp | 12 +++++++++++- common/jinja/jinja-vm.cpp | 11 ++++++++++- common/jinja/jinja-vm.h | 5 +++-- tests/test-chat-jinja.cpp | 15 +++++++++++++++ 4 files changed, 39 insertions(+), 4 deletions(-) diff --git a/common/jinja/jinja-value.cpp b/common/jinja/jinja-value.cpp index 5a515fc8e4..6c3d9249b3 100644 --- a/common/jinja/jinja-value.cpp +++ b/common/jinja/jinja-value.cpp @@ -165,6 +165,11 @@ const func_builtins & global_builtins() { } return out; }}, + {"tojson", [](const func_args & args) -> value { + args.ensure_count(1); + // placeholder implementation + return mk_val("TODO: to_json output"); + }}, // tests {"test_is_boolean", test_type_fn}, @@ -646,7 +651,12 @@ const func_builtins & value_object_t::get_builtins() const { } return result; }}, - {{"dictsort"}, [](const func_args & args) -> value { + {"tojson", [](const func_args & args) -> value { + args.ensure_vals(); + // use global to_json + return global_builtins().at("tojson")(args); + }}, + {"dictsort", [](const func_args & args) -> value { // no-op args.ensure_vals(); return args.args[0]; diff --git a/common/jinja/jinja-vm.cpp b/common/jinja/jinja-vm.cpp index ed98f1d050..d6958a54c9 100644 --- a/common/jinja/jinja-vm.cpp +++ b/common/jinja/jinja-vm.cpp @@ -312,10 +312,19 @@ value test_expression::execute_impl(context & ctx) { throw std::runtime_error("Unknown test '" + test_id + "'"); } + value input = operand->execute(ctx); + func_args args(ctx); - args.args.push_back(operand->execute(ctx)); + args.args.push_back(input); auto res = it->second(args); + // hack: allow type inference + if (test_id == "defined" || test_id == "undefined" || test_id == "none") { + ctx.mark_known_type(input, inferred_type::optional); + } else if (test_id == "string") { + ctx.mark_known_type(input, inferred_type::string); + } + if (negate) { return mk_val(!res->as_bool()); } else { diff --git a/common/jinja/jinja-vm.h b/common/jinja/jinja-vm.h index bb24abad96..0ac2e5f16a 100644 --- a/common/jinja/jinja-vm.h +++ b/common/jinja/jinja-vm.h @@ -105,8 +105,9 @@ public: if (is_val(val)) { auto & obj = val->as_object(); for (const auto & pair : obj) { - flatten_globals[pair.first] = pair.second; - set_flattened_global_recursively(pair.first, pair.second); + std::string child_path = path + "." + pair.first; + flatten_globals[child_path] = pair.second; + set_flattened_global_recursively(child_path, pair.second); } } else if (is_val(val)) { auto & arr = val->as_array(); diff --git a/tests/test-chat-jinja.cpp b/tests/test-chat-jinja.cpp index 39ce9fed00..c205b150cf 100644 --- a/tests/test-chat-jinja.cpp +++ b/tests/test-chat-jinja.cpp @@ -123,6 +123,21 @@ void run_single(std::string contents) { { "role": "assistant", "content": {"__input__": "I am fine, thank you!"} + }, + { + "role": "assistant", + "content": "Calling weather tool.", + "tool_calls": [ + { + "function": { + "name": "get_weather", + "arguments": { + "location": "New York", + "unit": "celsius" + } + } + } + ] } ], "bos_token": "", From cbb37dd4cda2891cdf61367546cc98d1875f29fe Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Wed, 31 Dec 2025 11:29:40 +0100 Subject: [PATCH 37/47] improve function args handling --- common/jinja/jinja-value.cpp | 7 +- common/jinja/jinja-value.h | 137 ++++++++++++++++++----------------- common/jinja/jinja-vm.cpp | 6 +- tests/test-chat-jinja.cpp | 4 +- 4 files changed, 78 insertions(+), 76 deletions(-) diff --git a/common/jinja/jinja-value.cpp b/common/jinja/jinja-value.cpp index 6c3d9249b3..270caafede 100644 --- a/common/jinja/jinja-value.cpp +++ b/common/jinja/jinja-value.cpp @@ -115,7 +115,6 @@ const func_builtins & global_builtins() { return out; }}, {"strftime_now", [](const func_args & args) -> value { - args.ensure_count(1); args.ensure_vals(); std::string format = args.args[0]->as_string().str(); // get current time @@ -128,9 +127,9 @@ const func_builtins & global_builtins() { } }}, {"range", [](const func_args & args) -> value { - if (args.args.size() < 1 || args.args.size() > 3) { - throw raised_exception("slice() takes between 1 and 3 arguments"); - } + args.ensure_count(1, 3); + args.ensure_vals(true, false, false); + auto & arg0 = args.args[0]; auto & arg1 = args.args[1]; auto & arg2 = args.args[2]; diff --git a/common/jinja/jinja-value.h b/common/jinja/jinja-value.h index 77d30c82f7..6be5160a89 100644 --- a/common/jinja/jinja-value.h +++ b/common/jinja/jinja-value.h @@ -51,12 +51,6 @@ typename extract_pointee::type * cast_val(value & ptr) { using PointeeType = typename extract_pointee::type; return dynamic_cast(ptr.get()); } -template -void ensure_val(const value & ptr) { - if (!is_val(ptr)) { - throw std::runtime_error("Expected value of type " + std::string(typeid(T).name())); - } -} // End Helper @@ -92,36 +86,11 @@ struct context; // forward declaration template void global_from_json(context & ctx, const T_JSON & json_obj); +// +// base value type +// - -struct func_args { - std::vector args; - context & ctx; - func_args(context & ctx) : ctx(ctx) {} - void ensure_count(size_t min, size_t max = 999) const { - if (args.size() < min || args.size() > max) { - throw std::runtime_error("Expected between " + std::to_string(min) + " and " + std::to_string(max) + " arguments, got " + std::to_string(args.size())); - } - } - value get_kwarg(const std::string & key) const; - // utility functions - // TODO: allow optional arguments - template void ensure_vals() const { - ensure_count(1); - ensure_val(args[0]); - } - template void ensure_vals() const { - ensure_count(2); - ensure_val(args[0]); - ensure_val(args[1]); - } - template void ensure_vals() const { - ensure_count(3); - ensure_val(args[0]); - ensure_val(args[1]); - ensure_val(args[2]); - } -}; +struct func_args; // function argument values using func_handler = std::function; using func_builtins = std::map; @@ -165,6 +134,9 @@ struct value_t { virtual std::string as_repr() const { return as_string().str(); } }; +// +// primitive value types +// struct value_int_t : public value_t { value_int_t(int64_t v) { val_int = v; } @@ -275,36 +247,9 @@ struct value_object_t : public value_t { }; using value_object = std::shared_ptr; - -struct value_func_t : public value_t { - std::string name; // for debugging - value arg0; // bound "this" argument, if any - value_func_t(const func_handler & func, std::string func_name = "") { - val_func = func; - name = func_name; - } - value_func_t(const func_handler & func, const value & arg_this, std::string func_name = "") { - val_func = func; - name = func_name; - arg0 = arg_this; - } - virtual value invoke(const func_args & args) const override { - if (arg0) { - func_args new_args(args.ctx); - new_args.args.push_back(arg0); - for (const auto & a : args.args) { - new_args.args.push_back(a); - } - return val_func(new_args); - } else { - return val_func(args); - } - } - virtual std::string type() const override { return "Function"; } - virtual std::string as_repr() const override { return type(); } -}; -using value_func = std::shared_ptr; - +// +// null and undefined types +// struct value_null_t : public value_t { virtual std::string type() const override { return "Null"; } @@ -326,6 +271,63 @@ struct value_undefined_t : public value_t { }; using value_undefined = std::shared_ptr; +// +// function type +// + +struct func_args { + std::string func_name; // for error messages + std::vector args; + context & ctx; + func_args(context & ctx) : ctx(ctx) {} + value get_kwarg(const std::string & key) const; + void ensure_count(size_t min, size_t max = 999) const { + size_t n = args.size(); + if (n < min || n > max) { + throw std::runtime_error("Function '" + func_name + "' expected between " + std::to_string(min) + " and " + std::to_string(max) + " arguments, got " + std::to_string(n)); + } + } + template void ensure_val(const value & ptr) const { + if (!is_val(ptr)) { + throw std::runtime_error("Function '" + func_name + "' expected value of type " + std::string(typeid(T).name()) + ", got " + ptr->type()); + } + } + template void ensure_vals(bool required0 = true) const { + if (required0 && args.size() > 0) ensure_val(args[0]); + } + template void ensure_vals(bool required0 = true, bool required1 = true) const { + if (required0 && args.size() > 0) ensure_val(args[0]); + if (required1 && args.size() > 1) ensure_val(args[1]); + } + template void ensure_vals(bool required0 = true, bool required1 = true, bool required2 = true) const { + if (required0 && args.size() > 0) ensure_val(args[0]); + if (required1 && args.size() > 1) ensure_val(args[1]); + if (required2 && args.size() > 2) ensure_val(args[2]); + } +}; + +struct value_func_t : public value_t { + std::string name; + value arg0; // bound "this" argument, if any + value_func_t(const std::string & name, const func_handler & func) : name(name) { + val_func = func; + } + value_func_t(const std::string & name, const func_handler & func, const value & arg_this) : name(name), arg0(arg_this) { + val_func = func; + } + virtual value invoke(const func_args & args) const override { + func_args new_args(args); // copy + new_args.func_name = name; + if (arg0) { + new_args.args.insert(new_args.args.begin(), arg0); + } + return val_func(new_args); + } + virtual std::string type() const override { return "Function"; } + virtual std::string as_repr() const override { return type(); } +}; +using value_func = std::shared_ptr; + // special value for kwarg struct value_kwarg_t : public value_t { std::string key; @@ -337,11 +339,10 @@ struct value_kwarg_t : public value_t { using value_kwarg = std::shared_ptr; -const func_builtins & global_builtins(); - - // utils +const func_builtins & global_builtins(); + static inferred_type value_to_inferred_type(const value & val) { if (is_val(val) || is_val(val)) { return inferred_type::numeric; diff --git a/common/jinja/jinja-vm.cpp b/common/jinja/jinja-vm.cpp index d6958a54c9..89dd49ed0a 100644 --- a/common/jinja/jinja-vm.cpp +++ b/common/jinja/jinja-vm.cpp @@ -70,7 +70,7 @@ value identifier::execute_impl(context & ctx) { return it; } else if (builtins.find(val) != builtins.end()) { JJ_DEBUG("Identifier '%s' found in builtins", val.c_str()); - return mk_val(builtins.at(val), val); + return mk_val(val, builtins.at(val)); } else { JJ_DEBUG("Identifier '%s' not found, returning undefined", val.c_str()); return mk_val(val); @@ -243,7 +243,7 @@ static value try_builtin_func(const std::string & name, const value & input, boo auto it = builtins.find(name); if (it != builtins.end()) { JJ_DEBUG("Binding built-in '%s'", name.c_str()); - return mk_val(it->second, input, name); + return mk_val(name, it->second, input); } if (undef_on_missing) { return mk_val(name); @@ -607,7 +607,7 @@ value macro_statement::execute_impl(context & ctx) { }; JJ_DEBUG("Defining macro '%s' with %zu parameters", name.c_str(), args.size()); - ctx.set_val(name, mk_val(func)); + ctx.set_val(name, mk_val(name, func)); return mk_val(); } diff --git a/tests/test-chat-jinja.cpp b/tests/test-chat-jinja.cpp index c205b150cf..b6a9a4a766 100644 --- a/tests/test-chat-jinja.cpp +++ b/tests/test-chat-jinja.cpp @@ -26,7 +26,9 @@ int main(void) { //std::string contents = " {{ messages[a]['content'] }} "; //std::string contents = "{% if a is not defined %}hello{% endif %}"; - std::ifstream infile("models/templates/Qwen-Qwen3-0.6B.jinja"); std::string contents((std::istreambuf_iterator(infile)), std::istreambuf_iterator()); + std::ifstream infile("models/templates/Qwen-Qwen3-0.6B.jinja"); + //std::ifstream infile("models/templates/Kimi-K2-Thinking.jinja"); + std::string contents((std::istreambuf_iterator(infile)), std::istreambuf_iterator()); run_single(contents); From d34efd9626230900af68df29e9e764b6a1e84feb Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Wed, 31 Dec 2025 11:43:53 +0100 Subject: [PATCH 38/47] rm type inference --- common/jinja/jinja-type-infer.h | 38 --------------------------------- common/jinja/jinja-value.h | 23 -------------------- common/jinja/jinja-vm.cpp | 15 ------------- common/jinja/jinja-vm.h | 36 ------------------------------- tests/test-chat-jinja.cpp | 19 ----------------- 5 files changed, 131 deletions(-) delete mode 100644 common/jinja/jinja-type-infer.h diff --git a/common/jinja/jinja-type-infer.h b/common/jinja/jinja-type-infer.h deleted file mode 100644 index 3f7508787f..0000000000 --- a/common/jinja/jinja-type-infer.h +++ /dev/null @@ -1,38 +0,0 @@ -#pragma once - -#include -#include - -#include "jinja-value.h" - -namespace jinja { - -struct value_t; -using value = std::shared_ptr; - -// this is used as a hint for chat parsing -// it is not a 1-to-1 mapping to value_t derived types -enum class inferred_type { - numeric, // int, float - string, - boolean, - array, - object, - optional, // null, undefined - unknown, -}; - -static std::string inferred_type_to_string(inferred_type type) { - switch (type) { - case inferred_type::numeric: return "numeric"; - case inferred_type::string: return "string"; - case inferred_type::boolean: return "boolean"; - case inferred_type::array: return "array"; - case inferred_type::object: return "object"; - case inferred_type::optional: return "optional"; - case inferred_type::unknown: return "unknown"; - default: return "invalid"; - } -} - -} // namespace jinja diff --git a/common/jinja/jinja-value.h b/common/jinja/jinja-value.h index 6be5160a89..6483d460a3 100644 --- a/common/jinja/jinja-value.h +++ b/common/jinja/jinja-value.h @@ -9,7 +9,6 @@ #include #include "jinja-string.h" -#include "jinja-type-infer.h" namespace jinja { @@ -108,10 +107,6 @@ struct value_t { func_handler val_func; - // for type inference - std::set inf_types; - std::vector inf_vals; - value_t() = default; value_t(const value_t &) = default; virtual ~value_t() = default; @@ -343,23 +338,5 @@ using value_kwarg = std::shared_ptr; const func_builtins & global_builtins(); -static inferred_type value_to_inferred_type(const value & val) { - if (is_val(val) || is_val(val)) { - return inferred_type::numeric; - } else if (is_val(val)) { - return inferred_type::string; - } else if (is_val(val)) { - return inferred_type::boolean; - } else if (is_val(val)) { - return inferred_type::array; - } else if (is_val(val)) { - return inferred_type::object; - } else if (is_val(val) || is_val(val)) { - return inferred_type::optional; - } else { - return inferred_type::unknown; - } -} - } // namespace jinja diff --git a/common/jinja/jinja-vm.cpp b/common/jinja/jinja-vm.cpp index 89dd49ed0a..2a679517e8 100644 --- a/common/jinja/jinja-vm.cpp +++ b/common/jinja/jinja-vm.cpp @@ -102,8 +102,6 @@ value binary_expression::execute_impl(context & ctx) { value right_val = right->execute(ctx); JJ_DEBUG("Executing binary expression %s '%s' %s", left_val->type().c_str(), op.value.c_str(), right_val->type().c_str()); if (op.value == "==") { - ctx.mark_known_type(left_val, right_val); - ctx.mark_known_type(right_val, left_val); return mk_val(value_compare(left_val, right_val)); } else if (op.value == "!=") { return mk_val(!value_compare(left_val, right_val)); @@ -318,13 +316,6 @@ value test_expression::execute_impl(context & ctx) { args.args.push_back(input); auto res = it->second(args); - // hack: allow type inference - if (test_id == "defined" || test_id == "undefined" || test_id == "none") { - ctx.mark_known_type(input, inferred_type::optional); - } else if (test_id == "string") { - ctx.mark_known_type(input, inferred_type::string); - } - if (negate) { return mk_val(!res->as_bool()); } else { @@ -354,9 +345,6 @@ value unary_expression::execute_impl(context & ctx) { value if_statement::execute_impl(context & ctx) { value test_val = test->execute(ctx); - ctx.mark_known_type(test_val, inferred_type::boolean); - ctx.mark_known_type(test_val, inferred_type::optional); - auto out = mk_val(); if (test_val->as_bool()) { for (auto & stmt : body) { @@ -399,9 +387,6 @@ value for_statement::execute_impl(context & ctx) { iterable_val = mk_val(); } - ctx.mark_known_type(iterable_val, inferred_type::array); - ctx.mark_known_type(iterable_val, inferred_type::object); - if (!is_val(iterable_val) && !is_val(iterable_val)) { throw std::runtime_error("Expected iterable or object type in for loop: got " + iterable_val->type()); } diff --git a/common/jinja/jinja-vm.h b/common/jinja/jinja-vm.h index 0ac2e5f16a..3817e7f535 100644 --- a/common/jinja/jinja-vm.h +++ b/common/jinja/jinja-vm.h @@ -78,46 +78,10 @@ struct context { void set_val(const std::string & name, const value & val) { global->insert(name, val); - set_flattened_global_recursively(name, val); - } - - void mark_known_type(value & val, inferred_type type) { - val->inf_types.insert(type); - } - - void mark_known_type(value & val, value & known_val) { - mark_known_type(val, value_to_inferred_type(known_val)); - val->inf_vals.push_back(known_val); - } - - // FOR TESTING ONLY - const value_object & get_global_object() const { - return global; } private: value_object global; - -public: - std::map flatten_globals; // for debugging - void set_flattened_global_recursively(std::string path, const value & val) { - flatten_globals[path] = val; - if (is_val(val)) { - auto & obj = val->as_object(); - for (const auto & pair : obj) { - std::string child_path = path + "." + pair.first; - flatten_globals[child_path] = pair.second; - set_flattened_global_recursively(child_path, pair.second); - } - } else if (is_val(val)) { - auto & arr = val->as_array(); - for (size_t i = 0; i < arr.size(); ++i) { - std::string idx_path = path + "[" + std::to_string(i) + "]"; - flatten_globals[idx_path] = arr[i]; - set_flattened_global_recursively(idx_path, arr[i]); - } - } - } }; /** diff --git a/tests/test-chat-jinja.cpp b/tests/test-chat-jinja.cpp index b6a9a4a766..7f588a8878 100644 --- a/tests/test-chat-jinja.cpp +++ b/tests/test-chat-jinja.cpp @@ -13,7 +13,6 @@ #include "jinja/jinja-parser.h" #include "jinja/jinja-lexer.h" -#include "jinja/jinja-type-infer.h" void run_multiple(); void run_single(std::string contents); @@ -165,22 +164,4 @@ void run_single(std::string contents) { for (const auto & part : parts.get()->val_str.parts) { std::cout << (part.is_input ? "DATA" : "TMPL") << ": " << part.val << "\n"; } - - std::cout << "\n=== TYPES ===\n"; - auto & global_obj = ctx.flatten_globals; - for (const auto & pair : global_obj) { - std::string name = pair.first; - std::string inf_types; - for (const auto & t : pair.second->inf_types) { - inf_types += inferred_type_to_string(t) + " "; - } - if (inf_types.empty()) { - continue; - } - std::string inf_vals; - for (const auto & v : pair.second->inf_vals) { - inf_vals += v->as_string().str() + " ; "; - } - printf("Var: %-20s | Types: %-10s | Vals: %s\n", name.c_str(), inf_types.c_str(), inf_vals.c_str()); - } } From a10fbc77a391da139b8f729eae13c45f7fb772aa Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Thu, 1 Jan 2026 22:48:17 +0100 Subject: [PATCH 39/47] no more std::regex --- common/jinja/jinja-lexer.cpp | 36 +++++++++++++++++++++++++++++++----- 1 file changed, 31 insertions(+), 5 deletions(-) diff --git a/common/jinja/jinja-lexer.cpp b/common/jinja/jinja-lexer.cpp index 189f8f5b10..32f6ac909a 100644 --- a/common/jinja/jinja-lexer.cpp +++ b/common/jinja/jinja-lexer.cpp @@ -4,7 +4,6 @@ #include #include #include -#include #include #include #include @@ -23,7 +22,7 @@ static void trim_template_markers_inplace(std::string & s) { // i = head ; j = tail (i <= j) size_t j = 0; // Write pointer const size_t len = s.length(); - + for (size_t i = 0; i < len; ) { bool handled = false; @@ -75,6 +74,32 @@ static void trim_template_markers_inplace(std::string & s) { s.resize(j); } +static void trim_newline_after_tag_inplace(std::string & s) { + // i = head ; j = tail (i <= j) + size_t j = 0; // Write pointer + const size_t len = s.length(); + + for (size_t i = 0; i < len; ) { + s[j++] = s[i++]; + + if (i < len && (s[j-1] == '}' || s[j-1] == '%' || s[j-1] == '#' || s[j-1] == '-')) { + if (s[i] == '}') { + // We have a potential tag closer like %} or -} or #} or }} + // Now check if the next character is a newline + if (i + 1 < len && s[i + 1] == '\n') { + // Skip the } and the following \n + ++i; // skip the } + ++i; // skip the \n + // Do not advance j, we effectively removed the \n + continue; + } + } + } + } + + s.resize(j); +} + std::string lexer::preprocess(const std::string & template_str, const preprocess_options & options) const { std::string result = template_str; // According to https://jinja.palletsprojects.com/en/3.0.x/templates/#whitespace-control @@ -97,7 +122,8 @@ std::string lexer::preprocess(const std::string & template_str, const preprocess if (options.trim_blocks) { // If an application configures Jinja to trim_blocks, the first newline after // a template tag is removed automatically (like in PHP). - result = std::regex_replace(result, std::regex(R"(([#%-]\})\n)"), "$1"); + // Equivalent JS code: template.replace(/^[ \t]*({[#%-])/gm, "$1") + trim_newline_after_tag_inplace(result); } // Handle whitespace control with - in tags @@ -105,8 +131,8 @@ std::string lexer::preprocess(const std::string & template_str, const preprocess // Handle custom transformers-specific `generation` tag // See https://github.com/huggingface/transformers/pull/30650 for more information. - result = std::regex_replace(result, std::regex(R"(\{%\s*generation\s*%\})"), ""); - result = std::regex_replace(result, std::regex(R"(\{%\s*endgeneration\s*%\})"), ""); + // result = std::regex_replace(result, std::regex(R"(\{%\s*generation\s*%\})"), ""); + // result = std::regex_replace(result, std::regex(R"(\{%\s*endgeneration\s*%\})"), ""); return result; } From 61c25c3fbf8c73052e782744937d49ca00edc907 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Thu, 1 Jan 2026 22:48:42 +0100 Subject: [PATCH 40/47] trailing spaces --- common/jinja/jinja-parser.cpp | 20 ++++++++++---------- common/jinja/jinja-vm.cpp | 4 ++-- common/jinja/jinja-vm.h | 6 +++--- 3 files changed, 15 insertions(+), 15 deletions(-) diff --git a/common/jinja/jinja-parser.cpp b/common/jinja/jinja-parser.cpp index 8cbb41eca6..ed3604ea95 100644 --- a/common/jinja/jinja-parser.cpp +++ b/common/jinja/jinja-parser.cpp @@ -132,7 +132,7 @@ private: // Consume {% token prev_cur = current; expect(token::open_statement, "Expected {%"); - + if (peek().t != token::identifier) { throw std::runtime_error("Unknown statement"); } @@ -183,15 +183,15 @@ private: } auto callee = parse_primary_expression(); if (!is_type(callee)) throw std::runtime_error("Expected identifier"); - + auto call_args = parse_args(); expect(token::close_statement, "Expected %}"); - + statements body; while (!is_statement({"endcall"})) { body.push_back(parse_any()); } - + expect(token::open_statement, "Expected {%"); expect_identifier("endcall"); expect(token::close_statement, "Expected %}"); @@ -205,12 +205,12 @@ private: filter_node = parse_call_expression(std::move(filter_node)); } expect(token::close_statement, "Expected %}"); - + statements body; while (!is_statement({"endfilter"})) { body.push_back(parse_any()); } - + expect(token::open_statement, "Expected {%"); expect_identifier("endfilter"); expect(token::close_statement, "Expected %}"); @@ -227,7 +227,7 @@ private: auto left = parse_expression_sequence(); statement_ptr value = nullptr; statements body; - + prev_cur = current; if (is(token::equals)) { @@ -311,7 +311,7 @@ private: // `messages` in `for message in messages` auto iterable = parse_expression(); expect(token::close_statement, "Expected %}"); - + statements body; statements alternate; @@ -486,7 +486,7 @@ private: arg = parse_expression(); if (is(token::equals)) { // keyword argument - // e.g., func(x = 5, y = a or b) + // e.g., func(x = 5, y = a or b) ++current; // consume equals arg = mk_stmt(std::move(arg), parse_expression()); } @@ -525,7 +525,7 @@ private: prev_cur = current; if (is(token::colon)) { // A case where a default is used - // e.g., [:2] will be parsed as [undefined, 2] + // e.g., [:2] will be parsed as [undefined, 2] slices.push_back(nullptr); ++current; // consume colon is_slice = true; diff --git a/common/jinja/jinja-vm.cpp b/common/jinja/jinja-vm.cpp index 2a679517e8..4df50c5132 100644 --- a/common/jinja/jinja-vm.cpp +++ b/common/jinja/jinja-vm.cpp @@ -457,7 +457,7 @@ value for_statement::execute_impl(context & ctx) { scope_update_fns.push_back(scope_update_fn); } JJ_DEBUG("For loop: %zu items after filtering", filtered_items.size()); - + auto result = mk_val(); bool noIteration = true; @@ -558,7 +558,7 @@ value macro_statement::execute_impl(context & ctx) { const func_handler func = [this, name, &ctx](const func_args & args) -> value { size_t expected_count = this->args.size(); size_t input_count = args.args.size(); - + JJ_DEBUG("Invoking macro '%s' with %zu input arguments (expected %zu)", name.c_str(), input_count, expected_count); context macro_ctx(ctx); // new scope for macro execution diff --git a/common/jinja/jinja-vm.h b/common/jinja/jinja-vm.h index 3817e7f535..5b697eb949 100644 --- a/common/jinja/jinja-vm.h +++ b/common/jinja/jinja-vm.h @@ -160,7 +160,7 @@ struct for_statement : public statement { statements default_block; // if no iteration took place for_statement(statement_ptr && loopvar, statement_ptr && iterable, statements && body, statements && default_block) - : loopvar(std::move(loopvar)), iterable(std::move(iterable)), + : loopvar(std::move(loopvar)), iterable(std::move(iterable)), body(std::move(body)), default_block(std::move(default_block)) { chk_type(this->loopvar); chk_type(this->iterable); @@ -278,7 +278,7 @@ struct identifier : public expression { // Literals -struct integer_literal : public expression { +struct integer_literal : public expression { int64_t val; explicit integer_literal(int64_t val) : val(val) {} std::string type() const override { return "IntegerLiteral"; } @@ -327,7 +327,7 @@ struct tuple_literal : public array_literal { struct object_literal : public expression { std::vector> val; - explicit object_literal(std::vector> && val) + explicit object_literal(std::vector> && val) : val(std::move(val)) { for (const auto & pair : this->val) { chk_type(pair.first); From b23b5e3c0196993fca4bd4fc061e346adb5b023a Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Thu, 1 Jan 2026 23:02:30 +0100 Subject: [PATCH 41/47] make testing more flexible --- tests/test-chat-jinja.cpp | 192 +++++++++++++++++++++----------------- 1 file changed, 106 insertions(+), 86 deletions(-) diff --git a/tests/test-chat-jinja.cpp b/tests/test-chat-jinja.cpp index 7f588a8878..50401b56bb 100644 --- a/tests/test-chat-jinja.cpp +++ b/tests/test-chat-jinja.cpp @@ -14,74 +14,134 @@ #include "jinja/jinja-parser.h" #include "jinja/jinja-lexer.h" -void run_multiple(); -void run_single(std::string contents); +using json = nlohmann::json; -int main(void) { - //std::string contents = "{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '' + role + '\\n' + message['content'] | trim + '\\n' }}{% endfor %}{% if add_generation_prompt %}{{'model\\n'}}{% endif %}"; +void run_multiple(std::string dir_path, bool stop_on_first_failure, json input); +void run_single(std::string contents, json input); - //std::string contents = "{% if messages[0]['role'] != 'system' %}nice {{ messages[0]['content'] }}{% endif %}"; +std::string HELP = R"( +Usage: test-chat-jinja [OPTIONS] PATH_TO_TEMPLATE +Options: + --json Path to the JSON input file. + --stop-on-first-fail Stop testing on the first failure (default: false). +If PATH_TO_TEMPLATE is a file, runs that single template. +If PATH_TO_TEMPLATE is a directory, runs all .jinja files in that directory. +)"; - //std::string contents = " {{ messages[a]['content'] }} "; - //std::string contents = "{% if a is not defined %}hello{% endif %}"; +std::string DEFAULT_JSON = R"({ + "messages": [ + { + "role": "user", + "content": {"__input__": "Hello, how are you?"} + }, + { + "role": "assistant", + "content": {"__input__": "I am fine, thank you!"} + }, + { + "role": "assistant", + "content": "Calling weather tool.", + "tool_calls": [ + { + "function": { + "name": "get_weather", + "arguments": { + "location": "New York", + "unit": "celsius" + } + } + } + ] + } + ], + "bos_token": "", + "eos_token": "", + "tools": [] +})"; - std::ifstream infile("models/templates/Qwen-Qwen3-0.6B.jinja"); - //std::ifstream infile("models/templates/Kimi-K2-Thinking.jinja"); - std::string contents((std::istreambuf_iterator(infile)), std::istreambuf_iterator()); +int main(int argc, char ** argv) { + std::vector args(argv, argv + argc); - run_single(contents); + std::string tmpl_path; + std::string json_path; + bool stop_on_first_fail = false; - //run_multiple(); + for (size_t i = 1; i < args.size(); i++) { + if (args[i] == "--help" || args[i] == "-h") { + std::cout << HELP << "\n"; + return 0; + } else if (args[i] == "--json" && i + 1 < args.size()) { + json_path = args[i + 1]; + i++; + } else if (args[i] == "--stop-on-first-fail") { + stop_on_first_fail = true; + } else if (tmpl_path.empty()) { + tmpl_path = args[i]; + } else { + std::cerr << "Unknown argument: " << args[i] << "\n"; + std::cout << HELP << "\n"; + return 1; + } + } + + if (tmpl_path.empty()) { + std::cerr << "Error: PATH_TO_TEMPLATE is required.\n"; + std::cout << HELP << "\n"; + return 1; + } + + json input_json; + if (!json_path.empty()) { + std::ifstream json_file(json_path); + if (!json_file) { + std::cerr << "Error: Could not open JSON file: " << json_path << "\n"; + return 1; + } + std::string content = std::string( + std::istreambuf_iterator(json_file), + std::istreambuf_iterator()); + input_json = json::parse(content); + } else { + input_json = json::parse(DEFAULT_JSON); + } + + std::filesystem::path p(tmpl_path); + if (std::filesystem::is_directory(p)) { + run_multiple(tmpl_path, stop_on_first_fail, input_json); + } else if (std::filesystem::is_regular_file(p)) { + std::ifstream infile(tmpl_path); + std::string contents = std::string( + std::istreambuf_iterator(infile), + std::istreambuf_iterator()); + run_single(contents, input_json); + } else { + std::cerr << "Error: PATH_TO_TEMPLATE is not a valid file or directory: " << tmpl_path << "\n"; + return 1; + } return 0; } -void run_multiple(void) { +void run_multiple(std::string dir_path, bool stop_on_first_fail, json input) { std::vector failed_tests; - bool stop_on_first_failure = false; - - auto is_ignored_file = [](const std::string & filename) -> bool { - std::vector ignored_files = { - "Apriel-", - "Olmo-3-7B-Instruct-Heretic-GGUF", - "sheldonrobinson-Llama-Guard", - "deepseek-community-Janus-Pro-1B", - "bitshrine-gemma-2-2B-function-calling", - "PaddlePaddle-PaddleOCR-VL", - }; - for (const auto & ignored : ignored_files) { - if (filename.find(ignored) != std::string::npos) { - return true; - } - } - return false; - }; - // list all files in models/templates/ and run each size_t test_count = 0; - size_t skip_count = 0; - //std::string dir_path = "models/templates/"; - std::string dir_path = "../test-jinja/templates/"; - for (const auto & entry : std::filesystem::directory_iterator(dir_path)) { - if (entry.is_regular_file()) { - if (is_ignored_file(entry.path().filename().string())) { - std::cout << "=== SKIPPING TEMPLATE FILE: " << entry.path().string() << " ===\n"; - skip_count++; - continue; - } + for (const auto & entry : std::filesystem::directory_iterator(dir_path)) { + // only process .jinja files + if (entry.path().extension() == ".jinja" && entry.is_regular_file()) { test_count++; std::cout << "\n\n=== RUNNING TEMPLATE FILE: " << entry.path().string() << " ===\n"; std::ifstream infile(entry.path()); std::string contents((std::istreambuf_iterator(infile)), std::istreambuf_iterator()); try { - run_single(contents); + run_single(contents, input); } catch (const std::exception & e) { std::cout << "Exception: " << e.what() << "\n"; std::cout << "=== ERROR WITH TEMPLATE FILE: " << entry.path().string() << " ===\n"; failed_tests.push_back(entry.path().string()); - if (stop_on_first_failure) { + if (stop_on_first_fail) { break; } } @@ -91,14 +151,13 @@ void run_multiple(void) { std::cout << "\n\n=== TEST SUMMARY ===\n"; std::cout << "Total tests run: " << test_count << "\n"; std::cout << "Total failed tests: " << failed_tests.size() << "\n"; - std::cout << "Total skipped tests: " << skip_count << "\n"; for (const auto & test : failed_tests) { std::cout << "FAILED TEST: " << test << "\n"; } } -void run_single(std::string contents) { +void run_single(std::string contents, json input) { jinja::enable_debug(true); // lexing @@ -115,46 +174,7 @@ void run_single(std::string contents) { jinja::context ctx; ctx.source = lexer_res.preprocessed_source; - std::string json_inp = R"({ - "messages": [ - { - "role": "user", - "content": {"__input__": "Hello, how are you?"} - }, - { - "role": "assistant", - "content": {"__input__": "I am fine, thank you!"} - }, - { - "role": "assistant", - "content": "Calling weather tool.", - "tool_calls": [ - { - "function": { - "name": "get_weather", - "arguments": { - "location": "New York", - "unit": "celsius" - } - } - } - ] - } - ], - "bos_token": "", - "eos_token": "", - "tools": [] - })"; - auto input_json = nlohmann::json::parse(json_inp); - - // workaround for functionary models - input_json["functions"] = ""; - input_json["datetime"] = ""; - - // workaround for Llama Guard models - input_json["excluded_category_keys"] = nlohmann::json::array(); - - jinja::global_from_json(ctx, input_json); + jinja::global_from_json(ctx, input); jinja::vm vm(ctx); const jinja::value results = vm.execute(ast); From a66e4a4f5de300fb72c557abc48ec2c46abcdaad Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Thu, 1 Jan 2026 23:07:45 +0100 Subject: [PATCH 42/47] make output a bit cleaner --- common/jinja/jinja-vm.h | 10 ++++++++++ tests/test-chat-jinja.cpp | 5 +++-- 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/common/jinja/jinja-vm.h b/common/jinja/jinja-vm.h index 5b697eb949..faee1559cf 100644 --- a/common/jinja/jinja-vm.h +++ b/common/jinja/jinja-vm.h @@ -574,6 +574,16 @@ struct vm { value_string gather_string_parts(const value & val) { value_string parts = mk_val(); gather_string_parts_recursive(val, parts); + // join consecutive parts with the same type + auto & p = parts->val_str.parts; + for (size_t i = 1; i < p.size(); ) { + if (p[i].is_input == p[i - 1].is_input) { + p[i - 1].val += p[i].val; + p.erase(p.begin() + i); + } else { + i++; + } + } return parts; } }; diff --git a/tests/test-chat-jinja.cpp b/tests/test-chat-jinja.cpp index 50401b56bb..86fe8f1f15 100644 --- a/tests/test-chat-jinja.cpp +++ b/tests/test-chat-jinja.cpp @@ -56,7 +56,8 @@ std::string DEFAULT_JSON = R"({ ], "bos_token": "", "eos_token": "", - "tools": [] + "tools": [], + "add_generation_prompt": true })"; int main(int argc, char ** argv) { @@ -181,7 +182,7 @@ void run_single(std::string contents, json input) { auto parts = vm.gather_string_parts(results); std::cout << "\n=== RESULTS ===\n"; - for (const auto & part : parts.get()->val_str.parts) { + for (const auto & part : parts->as_string().parts) { std::cout << (part.is_input ? "DATA" : "TMPL") << ": " << part.val << "\n"; } } From 4b71c285dbfefb22f1e2a0b86609351f3bfa2333 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Thu, 1 Jan 2026 23:33:23 +0100 Subject: [PATCH 43/47] (wip) redirect minja calls --- common/chat.cpp | 111 +++++++++++++++++++++++++++++----------- common/jinja/jinja-vm.h | 3 +- 2 files changed, 82 insertions(+), 32 deletions(-) diff --git a/common/chat.cpp b/common/chat.cpp index 0a426f4478..82c742ee18 100644 --- a/common/chat.cpp +++ b/common/chat.cpp @@ -7,8 +7,12 @@ #include "log.h" #include "regex-partial.h" -#include -#include +// #include +// #include + +#include "jinja/jinja-parser.h" +#include "jinja/jinja-value.h" +#include "jinja/jinja-vm.h" #include #include @@ -135,7 +139,46 @@ std::vector common_chat_msg_diff::compute_diffs(const comm return diffs; } -typedef minja::chat_template common_chat_template; +struct common_chat_template { + jinja::program prog; + std::string bos_tok; + std::string eos_tok; + std::string src; + common_chat_template(const std::string & src, const std::string & bos_token, const std::string & eos_token) { + jinja::lexer lexer; + jinja::preprocess_options options; + options.trim_blocks = false; + options.lstrip_blocks = false; + auto lexer_res = lexer.tokenize(src, options); + prog = jinja::parse_from_tokens(lexer_res); + + this->src = lexer_res.preprocessed_source; + this->bos_tok = bos_token; + this->eos_tok = eos_token; + } + + const std::string & source() const { return src; } + const std::string & bos_token() const { return bos_tok; } + const std::string & eos_token() const { return eos_tok; } + static json add_system(const json &, const std::string &) { + throw std::runtime_error("common_chat_template::add_system not implemented"); + } + + + // this is just for testing. it will be removed later + struct chat_template_caps { + bool supports_tools = true; + bool supports_tool_calls = true; + bool supports_tool_responses = true; + bool supports_system_role = true; + bool supports_parallel_tool_calls = true; + bool requires_typed_content = true; + }; + chat_template_caps original_caps() const { + return chat_template_caps(); + } + +}; struct common_chat_templates { bool add_bos; @@ -627,14 +670,14 @@ common_chat_templates_ptr common_chat_templates_init( tmpls->add_bos = add_bos; tmpls->add_eos = add_eos; try { - tmpls->template_default = std::make_unique(default_template_src, token_bos, token_eos); + tmpls->template_default = std::make_unique(default_template_src, token_bos, token_eos); } catch (const std::exception & e) { LOG_ERR("%s: failed to parse chat template (defaulting to chatml): %s \n", __func__, e.what()); - tmpls->template_default = std::make_unique(CHATML_TEMPLATE_SRC, token_bos, token_eos); + tmpls->template_default = std::make_unique(CHATML_TEMPLATE_SRC, token_bos, token_eos); } if (!template_tool_use_src.empty()) { try { - tmpls->template_tool_use = std::make_unique(template_tool_use_src, token_bos, token_eos); + tmpls->template_tool_use = std::make_unique(template_tool_use_src, token_bos, token_eos); } catch (const std::exception & e) { LOG_ERR("%s: failed to parse tool use chat template (ignoring it): %s\n", __func__, e.what()); } @@ -737,34 +780,40 @@ static std::string apply( const std::optional & tools_override = std::nullopt, const std::optional & additional_context = std::nullopt) { - minja::chat_template_inputs tmpl_inputs; - tmpl_inputs.messages = messages_override ? *messages_override : inputs.messages; - if (tools_override) { - tmpl_inputs.tools = *tools_override; - } else { - tmpl_inputs.tools = inputs.tools.empty() ? json() : inputs.tools; - } - tmpl_inputs.add_generation_prompt = inputs.add_generation_prompt; - tmpl_inputs.extra_context = inputs.extra_context; - tmpl_inputs.extra_context["enable_thinking"] = inputs.enable_thinking; - if (additional_context) { - tmpl_inputs.extra_context.merge_patch(*additional_context); - } - // TODO: add flag to control date/time, if only for testing purposes. - // tmpl_inputs.now = std::chrono::system_clock::now(); + // TODO IMPORTANT: IMPORVE THIS - minja::chat_template_options tmpl_opts; - // To avoid double BOS / EOS tokens, we're manually removing begining / trailing tokens - // instead of using `chat_template_options.use_bos_token = false`, since these tokens - // may be needed inside the template / between messages too. - auto result = tmpl.apply(tmpl_inputs, tmpl_opts); - if (inputs.add_bos && string_starts_with(result, tmpl.bos_token())) { - result = result.substr(tmpl.bos_token().size()); + jinja::context ctx; + ctx.source = tmpl.source(); // for debugging + + nlohmann::json inp = nlohmann::json{ + {"messages", messages_override.has_value() ? *messages_override : inputs.messages}, + {"tools", tools_override.has_value() ? *tools_override : inputs.tools}, + }; + if (additional_context.has_value()) { + // TODO: merge properly instead of overwriting + for (const auto & [k, v] : additional_context->items()) { + inp[k] = v; + } } - if (inputs.add_eos && string_ends_with(result, tmpl.eos_token())) { - result = result.substr(0, result.size() - tmpl.eos_token().size()); + if (inputs.add_generation_prompt) { + inp["add_generation_prompt"] = true; } - return result; + if (inputs.add_bos) { + inp["bos_token"] = tmpl.bos_token(); + } + if (inputs.add_eos) { + inp["eos_token"] = tmpl.eos_token(); + } + // TODO: more inputs? + + jinja::global_from_json(ctx, inp); + + // render + jinja::vm vm(ctx); + const jinja::value results = vm.execute(tmpl.prog); + auto parts = vm.gather_string_parts(results); + + return parts->as_string().str(); } static common_chat_params common_chat_params_init_generic(const common_chat_template & tmpl, const struct templates_params & inputs) { diff --git a/common/jinja/jinja-vm.h b/common/jinja/jinja-vm.h index faee1559cf..c1f91dd81f 100644 --- a/common/jinja/jinja-vm.h +++ b/common/jinja/jinja-vm.h @@ -125,6 +125,7 @@ struct expression : public statement { struct program : public statement { statements body; + program() = default; explicit program(statements && body) : body(std::move(body)) {} std::string type() const override { return "Program"; } value execute_impl(context &) override { @@ -562,7 +563,7 @@ struct vm { context & ctx; explicit vm(context & ctx) : ctx(ctx) {} - value_array execute(program & prog) { + value_array execute(const program & prog) { value_array results = mk_val(); for (auto & stmt : prog.body) { value res = stmt->execute(ctx); From 0f9f986acec5b7e2a2fd1275fbe07b302bb7620f Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Fri, 2 Jan 2026 11:33:42 +0100 Subject: [PATCH 44/47] test: add --output --- tests/test-chat-jinja.cpp | 23 ++++++++++++++++++++--- 1 file changed, 20 insertions(+), 3 deletions(-) diff --git a/tests/test-chat-jinja.cpp b/tests/test-chat-jinja.cpp index 86fe8f1f15..b22c8a56d5 100644 --- a/tests/test-chat-jinja.cpp +++ b/tests/test-chat-jinja.cpp @@ -17,13 +17,15 @@ using json = nlohmann::json; void run_multiple(std::string dir_path, bool stop_on_first_failure, json input); -void run_single(std::string contents, json input); +void run_single(std::string contents, json input, const std::string & output_path = ""); std::string HELP = R"( Usage: test-chat-jinja [OPTIONS] PATH_TO_TEMPLATE Options: + -h, --help Show this help message and exit. --json Path to the JSON input file. --stop-on-first-fail Stop testing on the first failure (default: false). + --output Path to output results (only for single template runs). If PATH_TO_TEMPLATE is a file, runs that single template. If PATH_TO_TEMPLATE is a directory, runs all .jinja files in that directory. )"; @@ -65,6 +67,7 @@ int main(int argc, char ** argv) { std::string tmpl_path; std::string json_path; + std::string output_path; bool stop_on_first_fail = false; for (size_t i = 1; i < args.size(); i++) { @@ -76,6 +79,9 @@ int main(int argc, char ** argv) { i++; } else if (args[i] == "--stop-on-first-fail") { stop_on_first_fail = true; + } else if (args[i] == "--output" && i + 1 < args.size()) { + output_path = args[i + 1]; + i++; } else if (tmpl_path.empty()) { tmpl_path = args[i]; } else { @@ -114,7 +120,7 @@ int main(int argc, char ** argv) { std::string contents = std::string( std::istreambuf_iterator(infile), std::istreambuf_iterator()); - run_single(contents, input_json); + run_single(contents, input_json, output_path); } else { std::cerr << "Error: PATH_TO_TEMPLATE is not a valid file or directory: " << tmpl_path << "\n"; return 1; @@ -158,7 +164,7 @@ void run_multiple(std::string dir_path, bool stop_on_first_fail, json input) { } -void run_single(std::string contents, json input) { +void run_single(std::string contents, json input, const std::string & output_path) { jinja::enable_debug(true); // lexing @@ -185,4 +191,15 @@ void run_single(std::string contents, json input) { for (const auto & part : parts->as_string().parts) { std::cout << (part.is_input ? "DATA" : "TMPL") << ": " << part.val << "\n"; } + + if (!output_path.empty()) { + std::ofstream outfile(output_path); + if (!outfile) { + throw std::runtime_error("Could not open output file: " + output_path); + } + for (const auto & part : parts->as_string().parts) { + outfile << part.val; + } + std::cout << "\n=== OUTPUT WRITTEN TO " << output_path << " ===\n"; + } } From dce256cf4051b595c9cc25363738d46b948e10ef Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Fri, 2 Jan 2026 11:50:48 +0100 Subject: [PATCH 45/47] fix crash on macro kwargs --- common/jinja/jinja-vm.cpp | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/common/jinja/jinja-vm.cpp b/common/jinja/jinja-vm.cpp index 4df50c5132..076e041ef4 100644 --- a/common/jinja/jinja-vm.cpp +++ b/common/jinja/jinja-vm.cpp @@ -565,9 +565,20 @@ value macro_statement::execute_impl(context & ctx) { // bind parameters for (size_t i = 0; i < expected_count; ++i) { if (i < input_count) { - std::string param_name = cast_stmt(this->args[i])->val; - JJ_DEBUG(" Binding parameter '%s' to argument of type %s", param_name.c_str(), args.args[i]->type().c_str()); - macro_ctx.set_val(param_name, args.args[i]); + if (is_stmt(this->args[i])) { + // normal parameter + std::string param_name = cast_stmt(this->args[i])->val; + JJ_DEBUG(" Binding parameter '%s' to argument of type %s", param_name.c_str(), args.args[i]->type().c_str()); + macro_ctx.set_val(param_name, args.args[i]); + } else if (is_stmt(this->args[i])) { + // default argument used as normal parameter + auto kwarg = cast_stmt(this->args[i]); + std::string param_name = cast_stmt(kwarg->key)->val; + JJ_DEBUG(" Binding parameter '%s' to argument of type %s", param_name.c_str(), args.args[i]->type().c_str()); + macro_ctx.set_val(param_name, args.args[i]); + } else { + throw std::runtime_error("Invalid parameter type in macro '" + name + "'"); + } } else { auto & default_arg = this->args[i]; if (is_stmt(default_arg)) { From e858b7a0a30fc4f2cb2b5e6ee5adc7210c87d8b1 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Fri, 2 Jan 2026 16:28:04 +0100 Subject: [PATCH 46/47] add minimal caps system --- common/jinja/jinja-caps.h | 159 +++++++++++++++++++++++++++++++++++ common/jinja/jinja-value.cpp | 7 +- common/jinja/jinja-value.h | 10 +++ common/jinja/jinja-vm.cpp | 41 +++++++-- common/jinja/jinja-vm.h | 4 + tests/test-chat-jinja.cpp | 11 +-- 6 files changed, 219 insertions(+), 13 deletions(-) create mode 100644 common/jinja/jinja-caps.h diff --git a/common/jinja/jinja-caps.h b/common/jinja/jinja-caps.h new file mode 100644 index 0000000000..eca5782903 --- /dev/null +++ b/common/jinja/jinja-caps.h @@ -0,0 +1,159 @@ +#pragma once + +#include + +#include "jinja-value.h" +#include "jinja-vm.h" + +#define FILENAME "jinja-caps" + +namespace jinja { + +struct caps { + bool content_string = true; + bool content_array = true; +}; + +using caps_messages_fn = std::function; +using caps_analyze_fn = std::function; +static void caps_try_execute(jinja::program & prog, + caps_messages_fn messages_fn, + caps_messages_fn tools_fn, + caps_analyze_fn analyze_fn) { + context ctx; + ctx.is_get_stats = true; + + value messages = messages_fn(); + value tools = tools_fn(); + + ctx.set_val("messages", messages); + ctx.set_val("tools", tools); + ctx.set_val("add_generation_prompt", mk_val(true)); + + bool success = false; + try { + jinja::vm vm(ctx); + vm.execute(prog); + success = true; + } catch (const std::exception & e) { + JJ_DEBUG("Exception during execution: %s", e.what()); + // ignore exceptions during capability analysis + } + return analyze_fn(success, messages, tools); +} + +// for debugging only +static void caps_print_stats(value & v, std::string path) { + std::string ops; + for (const auto & name : v->stats.ops) { + ops += name + " "; + } + JJ_DEBUG("Value %s, type: %s %s, ops: %s", + path.c_str(), + v->type().c_str(), + v->stats.used ? "(used)" : "", + ops.c_str()); +} + +static caps caps_get(jinja::program & prog) { + caps result; + + static const auto has_op = [](value & v, const std::string & op_name) { + return v->stats.ops.find(op_name) != v->stats.ops.end(); + }; + + // case: given content as string, check if it's accessed as array + caps_try_execute( + prog, + [&]() { + auto messages = mk_val(); + { + value_object msg = mk_val(); + msg->insert("role", mk_val("user")); + msg->insert("content", mk_val("User message")); + messages->push_back(msg); + } + return messages; + }, + [&]() { + return mk_val(); + }, + [&](bool, value & messages, value &) { + auto & content = messages->at(0)->at("content"); + caps_print_stats(content, "messages[0].content"); + if (has_op(content, "selectattr") || has_op(content, "array_access")) { + // accessed as an array + JJ_DEBUG("%s", "Force content as array"); + result.content_string = false; + result.content_array = true; + } + } + ); + + // case: given content as array, check if it's supported or not + caps_try_execute( + prog, + [&]() { + auto messages = mk_val(); + { + value_object msg = mk_val(); + msg->insert("role", mk_val("user")); + value_array content_arr = mk_val(); + { + value_object content_part = mk_val(); + content_part->insert("type", mk_val("text")); + content_part->insert("text", mk_val("User message")); + content_arr->push_back(content_part); + } + msg->insert("content", content_arr); + messages->push_back(msg); + } + return messages; + }, + [&]() { + return mk_val(); + }, + [&](bool success, value & messages, value &) { + auto & content = messages->at(0)->at("content"); + caps_print_stats(content, "messages[0].content"); + if (!success) { + JJ_DEBUG("%s", "Cannot handle content as array"); + result.content_array = false; + } + } + ); + + return result; +} + +static void caps_apply_workarounds(context & ctx, const caps & c) { + auto messages = ctx.get_val("messages"); + + if (!is_val(messages)) { + throw std::runtime_error("Expected messages to be an array"); + } + + if (!c.content_string) { + for (auto & msg : messages->val_arr) { + if (!is_val(msg)) { + throw std::runtime_error("Expected messages[i] to be an object"); + } + auto obj_ptr = cast_val(msg); + auto & content = obj_ptr->at("content"); + if (!is_val(content)) { + JJ_DEBUG("%s", "Converting message content to array"); + auto str_content = content->as_string(); + value_array arr_content = mk_val(); + value_object content_part = mk_val(); + content_part->insert("type", mk_val("text")); + content_part->insert("text", mk_val(str_content)); + arr_content->push_back(content_part); + obj_ptr->insert("content", arr_content); + } + } + } + + ctx.set_val("messages", messages); +} + +} // namespace jinja diff --git a/common/jinja/jinja-value.cpp b/common/jinja/jinja-value.cpp index 270caafede..4da4584e23 100644 --- a/common/jinja/jinja-value.cpp +++ b/common/jinja/jinja-value.cpp @@ -12,7 +12,7 @@ #include #include -#define FILENAME "jinja-vm-builtins" +#define FILENAME "jinja-value" namespace jinja { @@ -408,6 +408,11 @@ const func_builtins & value_string_t::get_builtins() const { res->val_str.mark_input_based_on(input->as_string()); return res; }}, + {"safe", [](const func_args & args) -> value { + // no-op for now + args.ensure_vals(); + return args.args[0]; + }}, {"selectattr", [](const func_args &) -> value { throw std::runtime_error("String selectattr builtin not supported"); }}, diff --git a/common/jinja/jinja-value.h b/common/jinja/jinja-value.h index 6483d460a3..9cb57f90f3 100644 --- a/common/jinja/jinja-value.h +++ b/common/jinja/jinja-value.h @@ -107,6 +107,13 @@ struct value_t { func_handler val_func; + // only used if ctx.is_get_stats = true + struct stats_t { + bool used = false; + // ops can be builtin calls or operators: "array_access", "object_access" + std::set ops; + } stats; + value_t() = default; value_t(const value_t &) = default; virtual ~value_t() = default; @@ -126,6 +133,9 @@ struct value_t { throw std::runtime_error("No builtins available for type " + type()); } + virtual value & at(const std::string & key) { return val_obj[key]; } + virtual value & at(size_t index) { return val_arr.at(index); } + virtual std::string as_repr() const { return as_string().str(); } }; diff --git a/common/jinja/jinja-vm.cpp b/common/jinja/jinja-vm.cpp index 076e041ef4..0728054c13 100644 --- a/common/jinja/jinja-vm.cpp +++ b/common/jinja/jinja-vm.cpp @@ -66,6 +66,9 @@ value identifier::execute_impl(context & ctx) { auto it = ctx.get_val(val); auto builtins = global_builtins(); if (!it->is_undefined()) { + if (ctx.is_get_stats) { + it->stats.used = true; + } JJ_DEBUG("Identifier '%s' found", val.c_str()); return it; } else if (builtins.find(val) != builtins.end()) { @@ -236,7 +239,12 @@ value binary_expression::execute_impl(context & ctx) { throw std::runtime_error("Unknown operator \"" + op.value + "\" between " + left_val->type() + " and " + right_val->type()); } -static value try_builtin_func(const std::string & name, const value & input, bool undef_on_missing = false) { +static value try_builtin_func(context & ctx, const std::string & name, value & input, bool undef_on_missing = false) { + JJ_DEBUG("Trying built-in function '%s' for type %s", name.c_str(), input->type().c_str()); + if (ctx.is_get_stats) { + input->stats.used = true; + input->stats.ops.insert(name); + } auto builtins = input->get_builtins(); auto it = builtins.find(name); if (it != builtins.end()) { @@ -266,7 +274,7 @@ value filter_expression::execute_impl(context & ctx) { filter_id = "strip"; // alias } JJ_DEBUG("Applying filter '%s' to %s", filter_id.c_str(), input->type().c_str()); - return try_builtin_func(filter_id, input)->invoke(func_args(ctx)); + return try_builtin_func(ctx, filter_id, input)->invoke(func_args(ctx)); } else if (is_stmt(filter)) { auto call = cast_stmt(filter); @@ -278,7 +286,7 @@ value filter_expression::execute_impl(context & ctx) { args.args.push_back(arg_expr->execute(ctx)); } - return try_builtin_func(filter_id, input)->invoke(args); + return try_builtin_func(ctx, filter_id, input)->invoke(args); } else { throw std::runtime_error("Invalid filter expression"); @@ -401,12 +409,20 @@ value for_statement::execute_impl(context & ctx) { tuple->push_back(p.second); items.push_back(tuple); } + if (ctx.is_get_stats) { + iterable_val->stats.used = true; + iterable_val->stats.ops.insert("object_access"); + } } else { JJ_DEBUG("%s", "For loop over array items"); auto & arr = iterable_val->as_array(); for (const auto & item : arr) { items.push_back(item); } + if (ctx.is_get_stats) { + iterable_val->stats.used = true; + iterable_val->stats.ops.insert("array_access"); + } } std::vector> scope_update_fns; @@ -624,7 +640,7 @@ value member_expression::execute_impl(context & ctx) { start_val->as_repr().c_str(), stop_val->as_repr().c_str(), step_val->as_repr().c_str()); - auto slice_func = try_builtin_func("slice", object); + auto slice_func = try_builtin_func(ctx, "slice", object); func_args args(ctx); args.args.push_back(start_val); args.args.push_back(stop_val); @@ -654,7 +670,7 @@ value member_expression::execute_impl(context & ctx) { if (it != obj.end()) { val = it->second; } else { - val = try_builtin_func(key, object, true); + val = try_builtin_func(ctx, key, object, true); } JJ_DEBUG("Accessed property '%s' value, got type: %s", key.c_str(), val->type().c_str()); @@ -676,10 +692,11 @@ value member_expression::execute_impl(context & ctx) { val = mk_val(std::string(1, str[index])); } } + } else if (is_val(property)) { auto key = property->as_string().str(); JJ_DEBUG("Accessing %s built-in '%s'", is_val(object) ? "array" : "string", key.c_str()); - val = try_builtin_func(key, object); + val = try_builtin_func(ctx, key, object); } else { throw std::runtime_error("Cannot access property with non-string/non-number: got " + property->type()); } @@ -689,7 +706,17 @@ value member_expression::execute_impl(context & ctx) { throw std::runtime_error("Cannot access property with non-string: got " + property->type()); } auto key = property->as_string().str(); - val = try_builtin_func(key, object); + val = try_builtin_func(ctx, key, object); + } + + if (ctx.is_get_stats && val && object && property) { + val->stats.used = true; + object->stats.used = true; + if (is_val(property)) { + object->stats.ops.insert("array_access"); + } else if (is_val(property)) { + object->stats.ops.insert("object_access"); + } } return val; diff --git a/common/jinja/jinja-vm.h b/common/jinja/jinja-vm.h index c1f91dd81f..099111db46 100644 --- a/common/jinja/jinja-vm.h +++ b/common/jinja/jinja-vm.h @@ -50,6 +50,8 @@ struct context { std::string source; // for debugging std::time_t current_time; // for functions that need current time + bool is_get_stats = false; // whether to collect stats + context() { global = mk_val(); global->insert("true", mk_val(true)); @@ -65,6 +67,8 @@ struct context { for (const auto & pair : pvar) { set_val(pair.first, pair.second); } + current_time = parent.current_time; + is_get_stats = parent.is_get_stats; } value get_val(const std::string & name) { diff --git a/tests/test-chat-jinja.cpp b/tests/test-chat-jinja.cpp index b22c8a56d5..91a7b3ff87 100644 --- a/tests/test-chat-jinja.cpp +++ b/tests/test-chat-jinja.cpp @@ -13,6 +13,7 @@ #include "jinja/jinja-parser.h" #include "jinja/jinja-lexer.h" +#include "jinja/jinja-caps.h" using json = nlohmann::json; @@ -38,11 +39,7 @@ std::string DEFAULT_JSON = R"({ }, { "role": "assistant", - "content": {"__input__": "I am fine, thank you!"} - }, - { - "role": "assistant", - "content": "Calling weather tool.", + "content": {"__input__": "I am fine, thank you!"}, "tool_calls": [ { "function": { @@ -177,11 +174,15 @@ void run_single(std::string contents, json input, const std::string & output_pat // compile to AST jinja::program ast = jinja::parse_from_tokens(lexer_res); + // check caps for workarounds + auto caps = jinja::caps_get(ast); + std::cout << "\n=== RUN ===\n"; jinja::context ctx; ctx.source = lexer_res.preprocessed_source; jinja::global_from_json(ctx, input); + jinja::caps_apply_workarounds(ctx, caps); jinja::vm vm(ctx); const jinja::value results = vm.execute(ast); From 9b79863da3e69e30e41fbd74578b8268e4e4e5b8 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Fri, 2 Jan 2026 16:49:42 +0100 Subject: [PATCH 47/47] add some workarounds --- common/jinja/jinja-caps.h | 22 ++++++++++++++++++++++ common/jinja/jinja-parser.cpp | 6 ++++++ common/jinja/jinja-value.cpp | 5 +++++ common/jinja/jinja-vm.h | 8 ++++++++ 4 files changed, 41 insertions(+) diff --git a/common/jinja/jinja-caps.h b/common/jinja/jinja-caps.h index eca5782903..a8e9c4a559 100644 --- a/common/jinja/jinja-caps.h +++ b/common/jinja/jinja-caps.h @@ -154,6 +154,28 @@ static void caps_apply_workarounds(context & ctx, const caps & c) { } ctx.set_val("messages", messages); + + // + // per-model workarounds + // + + // workaround for shieldgemma-2b-Q2_K + if (ctx.get_val("guideline")->is_undefined()) { + ctx.set_val("guideline", mk_val("")); + } + + // workaround for functionary models + if (ctx.get_val("functions")->is_undefined()) { + ctx.set_val("functions", mk_val("")); + } + if (ctx.get_val("datetime")->is_undefined()) { + ctx.set_val("datetime", mk_val("")); + } + + // workaround for Llama-3-5B-Sheard + if (ctx.get_val("system_message")->is_undefined()) { + ctx.set_val("system_message", mk_val("")); + } } } // namespace jinja diff --git a/common/jinja/jinja-parser.cpp b/common/jinja/jinja-parser.cpp index ed3604ea95..25dacfefa0 100644 --- a/common/jinja/jinja-parser.cpp +++ b/common/jinja/jinja-parser.cpp @@ -216,6 +216,12 @@ private: expect(token::close_statement, "Expected %}"); result = mk_stmt(std::move(filter_node), std::move(body)); + } else if (name == "generation" || name == "endgeneration") { + // Ignore generation blocks (transformers-specific) + // See https://github.com/huggingface/transformers/pull/30650 for more information. + result = mk_stmt(); + current++; + } else { throw std::runtime_error("Unknown statement: " + name); } diff --git a/common/jinja/jinja-value.cpp b/common/jinja/jinja-value.cpp index 4da4584e23..1e7ef96e04 100644 --- a/common/jinja/jinja-value.cpp +++ b/common/jinja/jinja-value.cpp @@ -197,6 +197,7 @@ const func_builtins & global_builtins() { {"test_is_integer", test_type_fn}, {"test_is_number", test_type_fn}, {"test_is_iterable", test_type_fn}, + {"test_is_sequence", test_type_fn}, {"test_is_mapping", test_type_fn}, {"test_is_lower", [](const func_args & args) -> value { args.ensure_vals(); @@ -655,6 +656,10 @@ const func_builtins & value_object_t::get_builtins() const { } return result; }}, + {"string", [](const func_args & args) -> value { + args.ensure_vals(); + return mk_val("TO BE IMPLEMENTED"); + }}, {"tojson", [](const func_args & args) -> value { args.ensure_vals(); // use global to_json diff --git a/common/jinja/jinja-vm.h b/common/jinja/jinja-vm.h index 099111db46..93c3ca91a5 100644 --- a/common/jinja/jinja-vm.h +++ b/common/jinja/jinja-vm.h @@ -203,6 +203,14 @@ struct continue_statement : public statement { } }; +// do nothing +struct noop_statement : public statement { + std::string type() const override { return "Noop"; } + value execute_impl(context &) override { + return mk_val(); + } +}; + struct set_statement : public statement { statement_ptr assignee; statement_ptr val;