diff --git a/common/CMakeLists.txt b/common/CMakeLists.txt index f7b99159e3..b270bebbcc 100644 --- a/common/CMakeLists.txt +++ b/common/CMakeLists.txt @@ -83,6 +83,15 @@ add_library(${TARGET} STATIC speculative.h unicode.cpp unicode.h + jinja/jinja-lexer.cpp + jinja/jinja-lexer.h + jinja/jinja-parser.cpp + jinja/jinja-parser.h + jinja/jinja-vm.cpp + jinja/jinja-vm.h + jinja/jinja-value.cpp + jinja/jinja-value.h + jinja/jinja-string.h ) target_include_directories(${TARGET} PUBLIC . ../vendor) diff --git a/common/chat.cpp b/common/chat.cpp index b98ab21ce1..f792fde0fd 100644 --- a/common/chat.cpp +++ b/common/chat.cpp @@ -7,8 +7,12 @@ #include "log.h" #include "regex-partial.h" -#include -#include +// #include +// #include + +#include "jinja/jinja-parser.h" +#include "jinja/jinja-value.h" +#include "jinja/jinja-vm.h" #include #include @@ -135,7 +139,46 @@ std::vector common_chat_msg_diff::compute_diffs(const comm return diffs; } -typedef minja::chat_template common_chat_template; +struct common_chat_template { + jinja::program prog; + std::string bos_tok; + std::string eos_tok; + std::string src; + common_chat_template(const std::string & src, const std::string & bos_token, const std::string & eos_token) { + jinja::lexer lexer; + jinja::preprocess_options options; + options.trim_blocks = false; + options.lstrip_blocks = false; + auto lexer_res = lexer.tokenize(src, options); + prog = jinja::parse_from_tokens(lexer_res); + + this->src = lexer_res.preprocessed_source; + this->bos_tok = bos_token; + this->eos_tok = eos_token; + } + + const std::string & source() const { return src; } + const std::string & bos_token() const { return bos_tok; } + const std::string & eos_token() const { return eos_tok; } + static json add_system(const json &, const std::string &) { + throw std::runtime_error("common_chat_template::add_system not implemented"); + } + + + // this is just for testing. 
it will be removed later + struct chat_template_caps { + bool supports_tools = true; + bool supports_tool_calls = true; + bool supports_tool_responses = true; + bool supports_system_role = true; + bool supports_parallel_tool_calls = true; + bool requires_typed_content = true; + }; + chat_template_caps original_caps() const { + return chat_template_caps(); + } + +}; struct common_chat_templates { bool add_bos; @@ -627,14 +670,14 @@ common_chat_templates_ptr common_chat_templates_init( tmpls->add_bos = add_bos; tmpls->add_eos = add_eos; try { - tmpls->template_default = std::make_unique(default_template_src, token_bos, token_eos); + tmpls->template_default = std::make_unique(default_template_src, token_bos, token_eos); } catch (const std::exception & e) { LOG_ERR("%s: failed to parse chat template (defaulting to chatml): %s \n", __func__, e.what()); - tmpls->template_default = std::make_unique(CHATML_TEMPLATE_SRC, token_bos, token_eos); + tmpls->template_default = std::make_unique(CHATML_TEMPLATE_SRC, token_bos, token_eos); } if (!template_tool_use_src.empty()) { try { - tmpls->template_tool_use = std::make_unique(template_tool_use_src, token_bos, token_eos); + tmpls->template_tool_use = std::make_unique(template_tool_use_src, token_bos, token_eos); } catch (const std::exception & e) { LOG_ERR("%s: failed to parse tool use chat template (ignoring it): %s\n", __func__, e.what()); } @@ -738,34 +781,40 @@ static std::string apply( const std::optional & tools_override = std::nullopt, const std::optional & additional_context = std::nullopt) { - minja::chat_template_inputs tmpl_inputs; - tmpl_inputs.messages = messages_override ? *messages_override : inputs.messages; - if (tools_override) { - tmpl_inputs.tools = *tools_override; - } else { - tmpl_inputs.tools = inputs.tools.empty() ? 
json() : inputs.tools; - } - tmpl_inputs.add_generation_prompt = inputs.add_generation_prompt; - tmpl_inputs.extra_context = inputs.extra_context; - tmpl_inputs.extra_context["enable_thinking"] = inputs.enable_thinking; - if (additional_context) { - tmpl_inputs.extra_context.merge_patch(*additional_context); - } - // TODO: add flag to control date/time, if only for testing purposes. - // tmpl_inputs.now = std::chrono::system_clock::now();
-    minja::chat_template_options tmpl_opts;
-    // To avoid double BOS / EOS tokens, we're manually removing begining / trailing tokens
-    // instead of using `chat_template_options.use_bos_token = false`, since these tokens
-    // may be needed inside the template / between messages too.
-    auto result = tmpl.apply(tmpl_inputs, tmpl_opts);
+    // Renders `tmpl` with the jinja VM and returns the resulting prompt.
+    //
+    // Globals exposed to the template:
+    //   messages              - messages_override if given, else inputs.messages
+    //   tools                 - tools_override if given, else inputs.tools (null when empty)
+    //   add_generation_prompt - always set, as a boolean
+    //   bos_token / eos_token - always set
+    //   inputs.extra_context, "enable_thinking" and additional_context
+    //                         - merged in that order; they override the keys above
+    jinja::context ctx;
+    ctx.source = tmpl.source(); // for debugging
+
+    nlohmann::json inp = nlohmann::json::object();
+    inp["messages"] = messages_override.has_value() ? *messages_override : inputs.messages;
+    if (tools_override.has_value()) {
+        inp["tools"] = *tools_override;
+    } else {
+        // an empty tool list is passed as null, not as an empty array
+        inp["tools"] = inputs.tools.empty() ? json() : inputs.tools;
+    }
+    inp["add_generation_prompt"] = inputs.add_generation_prompt;
+    // The special tokens are always exposed: they may be needed inside the template /
+    // between messages too, not only at the very beginning or end of the prompt.
+    inp["bos_token"] = tmpl.bos_token();
+    inp["eos_token"] = tmpl.eos_token();
+
+    // Extra context: caller-provided values first, then enable_thinking, then the
+    // per-call additions (JSON merge patch, so nested objects are merged, not replaced).
+    auto extra_context = inputs.extra_context;
+    extra_context["enable_thinking"] = inputs.enable_thinking;
+    if (additional_context.has_value()) {
+        extra_context.merge_patch(*additional_context);
+    }
+    for (const auto & [k, v] : extra_context.items()) {
+        inp[k] = v;
+    }
+    // TODO: more inputs? (e.g. a controllable current date/time)
+
+    jinja::global_from_json(ctx, inp);
+
+    // render
+    jinja::vm vm(ctx);
+    const jinja::value results = vm.execute(tmpl.prog);
+    auto parts = vm.gather_string_parts(results);
+    std::string result = parts->as_string().str();
+
+    // To avoid double BOS / EOS tokens when the tokenizer adds them as well, the leading
+    // BOS / trailing EOS is removed from the rendered prompt here instead of hiding the
+    // tokens from the template.
     if (inputs.add_bos && string_starts_with(result, tmpl.bos_token())) {
         result = result.substr(tmpl.bos_token().size());
     }
     if (inputs.add_eos && string_ends_with(result, tmpl.eos_token())) {
         result = result.substr(0, result.size() - tmpl.eos_token().size());
     }
     return result;
 }
 static common_chat_params common_chat_params_init_generic(const common_chat_template & tmpl, const struct templates_params & inputs) { diff --git a/common/jinja/jinja-caps.h b/common/jinja/jinja-caps.h new file mode 100644 index 0000000000..a8e9c4a559 --- /dev/null +++ b/common/jinja/jinja-caps.h @@ -0,0 +1,181 @@ +#pragma once + +#include + +#include "jinja-value.h" +#include "jinja-vm.h" + +#define FILENAME "jinja-caps" + +namespace jinja { + +struct caps { + bool content_string = true; + bool content_array = true; +}; + +using caps_messages_fn = std::function; +using caps_analyze_fn = std::function; +static void caps_try_execute(jinja::program & prog, + caps_messages_fn messages_fn, + caps_messages_fn tools_fn, + caps_analyze_fn analyze_fn) { + context ctx; + ctx.is_get_stats = true; + + value messages = messages_fn(); + value tools = tools_fn(); + + ctx.set_val("messages", messages); + ctx.set_val("tools", tools); + ctx.set_val("add_generation_prompt", mk_val(true)); + + bool success = false; + try { + jinja::vm vm(ctx); + vm.execute(prog); + success = true; + } catch (const std::exception & e) { + JJ_DEBUG("Exception during execution: %s", e.what()); + // ignore exceptions during capability analysis + } + return analyze_fn(success, messages, tools); +} + +// for debugging only +static void caps_print_stats(value & v, std::string path) { + std::string ops; + for (const auto & name : v->stats.ops) { + ops += name + " "; + } + JJ_DEBUG("Value %s, type: %s %s, ops: %s", + path.c_str(), + v->type().c_str(), + v->stats.used ? 
"(used)" : "", + ops.c_str()); +} + +static caps caps_get(jinja::program & prog) { + caps result; + + static const auto has_op = [](value & v, const std::string & op_name) { + return v->stats.ops.find(op_name) != v->stats.ops.end(); + }; + + // case: given content as string, check if it's accessed as array + caps_try_execute( + prog, + [&]() { + auto messages = mk_val(); + { + value_object msg = mk_val(); + msg->insert("role", mk_val("user")); + msg->insert("content", mk_val("User message")); + messages->push_back(msg); + } + return messages; + }, + [&]() { + return mk_val(); + }, + [&](bool, value & messages, value &) { + auto & content = messages->at(0)->at("content"); + caps_print_stats(content, "messages[0].content"); + if (has_op(content, "selectattr") || has_op(content, "array_access")) { + // accessed as an array + JJ_DEBUG("%s", "Force content as array"); + result.content_string = false; + result.content_array = true; + } + } + ); + + // case: given content as array, check if it's supported or not + caps_try_execute( + prog, + [&]() { + auto messages = mk_val(); + { + value_object msg = mk_val(); + msg->insert("role", mk_val("user")); + value_array content_arr = mk_val(); + { + value_object content_part = mk_val(); + content_part->insert("type", mk_val("text")); + content_part->insert("text", mk_val("User message")); + content_arr->push_back(content_part); + } + msg->insert("content", content_arr); + messages->push_back(msg); + } + return messages; + }, + [&]() { + return mk_val(); + }, + [&](bool success, value & messages, value &) { + auto & content = messages->at(0)->at("content"); + caps_print_stats(content, "messages[0].content"); + if (!success) { + JJ_DEBUG("%s", "Cannot handle content as array"); + result.content_array = false; + } + } + ); + + return result; +} + +static void caps_apply_workarounds(context & ctx, const caps & c) { + auto messages = ctx.get_val("messages"); + + if (!is_val(messages)) { + throw std::runtime_error("Expected messages 
to be an array"); + } + + if (!c.content_string) { + for (auto & msg : messages->val_arr) { + if (!is_val(msg)) { + throw std::runtime_error("Expected messages[i] to be an object"); + } + auto obj_ptr = cast_val(msg); + auto & content = obj_ptr->at("content"); + if (!is_val(content)) { + JJ_DEBUG("%s", "Converting message content to array"); + auto str_content = content->as_string(); + value_array arr_content = mk_val(); + value_object content_part = mk_val(); + content_part->insert("type", mk_val("text")); + content_part->insert("text", mk_val(str_content)); + arr_content->push_back(content_part); + obj_ptr->insert("content", arr_content); + } + } + } + + ctx.set_val("messages", messages); + + // + // per-model workarounds + // + + // workaround for shieldgemma-2b-Q2_K + if (ctx.get_val("guideline")->is_undefined()) { + ctx.set_val("guideline", mk_val("")); + } + + // workaround for functionary models + if (ctx.get_val("functions")->is_undefined()) { + ctx.set_val("functions", mk_val("")); + } + if (ctx.get_val("datetime")->is_undefined()) { + ctx.set_val("datetime", mk_val("")); + } + + // workaround for Llama-3-5B-Sheard + if (ctx.get_val("system_message")->is_undefined()) { + ctx.set_val("system_message", mk_val("")); + } +} + +} // namespace jinja diff --git a/common/jinja/jinja-lexer.cpp b/common/jinja/jinja-lexer.cpp new file mode 100644 index 0000000000..32f6ac909a --- /dev/null +++ b/common/jinja/jinja-lexer.cpp @@ -0,0 +1,333 @@ +#include "jinja-lexer.h" +#include "jinja-vm.h" + +#include +#include +#include +#include +#include +#include +#include + +#define FILENAME "jinja-lexer" + +namespace jinja { + +// Trim template markers with '-' for whitespace control +// Example: [spaces]{%- ... -%} --> {% ... 
%} +#include +#include + +static void trim_template_markers_inplace(std::string & s) { + // i = head ; j = tail (i <= j) + size_t j = 0; // Write pointer + const size_t len = s.length(); + + for (size_t i = 0; i < len; ) { + bool handled = false; + + // We need at least 3 characters for any marker: {X- or -X} + if (i + 2 < len) { + const char c1 = s[i]; + const char c2 = s[i + 1]; + const char c3 = s[i + 2]; + + // 1. Closing trim: -X} where X = %, }, # + // Example: [content]-%} [spaces] -> [content]%} + if (c1 == '-' && c3 == '}' && (c2 == '%' || c2 == '}' || c2 == '#')) { + s[j++] = c2; + s[j++] = '}'; + i += 3; + // Strip leading whitespace AFTER the tag + while (i < len && std::isspace(static_cast(s[i]))) { + i++; + } + handled = true; + } + // 2. Opening trim: {X- where X = %, {, # + // Example: [spaces]{%- [content] -> {% [content] + else if (c1 == '{' && c3 == '-' && (c2 == '%' || c2 == '{' || c2 == '#')) { + // Trim trailing whitespace BEFORE the tag by moving write pointer back + while (j > 0 && std::isspace(static_cast(s[j - 1]))) { + j--; + } + + // Safety: Prevent merging '{' with tag start (avoid creating '{{%' or '{{{') + // if the character immediately before our new tag is a literal '{'. + if (j > 0 && s[j - 1] == '{') { + s[j++] = ' '; + } + + s[j++] = '{'; + s[j++] = c2; + i += 3; + handled = true; + } + } + + if (!handled) { + // Note: j is always <= i here, so this is safe. 
+ s[j++] = s[i++]; + } + } + + s.resize(j); +} +
+// Implements the `trim_blocks` option: removes the single newline that directly
+// follows a tag closer, keeping the closer itself intact.
+//
+// A closer is recognised textually as '}' preceded by '%', '#' or '-', i.e. the end
+// of a statement (`%}`) or comment (`#}`). The end of an expression (`}}`) is not
+// affected: in Jinja, trim_blocks applies to blocks, not to variable tags.
+// Only the first newline after the closer is removed; the string is edited in place.
+// NOTE(review): Jinja's `+%}` opt-out is not handled here — confirm whether it is needed.
+static void trim_newline_after_tag_inplace(std::string & s) {
+    // i = read position ; j = write position (j <= i)
+    size_t j = 0;
+    const size_t len = s.length();
+
+    for (size_t i = 0; i < len; ) {
+        const char prev = s[i];
+        s[j++] = s[i++];
+
+        const bool maybe_closer = prev == '%' || prev == '#' || prev == '-';
+        if (maybe_closer && i + 1 < len && s[i] == '}' && s[i + 1] == '\n') {
+            s[j++] = s[i++]; // keep the closing '}'
+            ++i;             // drop the newline that follows it
+        }
+    }
+
+    s.resize(j);
+}
+
+// Applies the whitespace-control preprocessing to `template_str` and returns the
+// resulting source: strips one trailing newline, applies trim_blocks when enabled,
+// then resolves the `-` whitespace markers inside tags.
+// Throws std::runtime_error when lstrip_blocks is requested (not implemented yet).
+std::string lexer::preprocess(const std::string & template_str, const preprocess_options & options) const {
+    std::string result = template_str;
+    // According to https://jinja.palletsprojects.com/en/3.0.x/templates/#whitespace-control
+
+    // In the default configuration:
+    //  - a single trailing newline is stripped if present
+    //  - other whitespace (spaces, tabs, newlines etc.) is returned unchanged
+    if (!result.empty() && result.back() == '\n') {
+        result.pop_back();
+    }
+
+    if (options.lstrip_blocks) {
+        // The lstrip_blocks option can also be set to strip tabs and spaces from the
+        // beginning of a line to the start of a block. (Nothing will be stripped if
+        // there are other characters before the start of the block.)
+        // Equivalent JS code: template.replace(/^[ \t]*({[#%-])/gm, "$1")
+        // result = std::regex_replace(result, std::regex(R"((?m)^[ \t]*(\{[#%-]))"), "$1");
+        throw std::runtime_error("lstrip_blocks option is not implemented yet");
+    }
+
+    if (options.trim_blocks) {
+        // If an application configures Jinja to trim_blocks, the first newline after
+        // a template tag is removed automatically (like in PHP).
+        trim_newline_after_tag_inplace(result);
+    }
+
+    // Handle whitespace control with - in tags
+    trim_template_markers_inplace(result);
+
+    // Handle custom transformers-specific `generation` tag
+    // See https://github.com/huggingface/transformers/pull/30650 for more information.
+    // result = std::regex_replace(result, std::regex(R"(\{%\s*generation\s*%\})"), "");
+    // result = std::regex_replace(result, std::regex(R"(\{%\s*endgeneration\s*%\})"), "");
+
+    return result;
+}
+
+lexer_result lexer::tokenize(const std::string & input, const preprocess_options & options) { + std::vector tokens; + std::string src = preprocess(input, options); + JJ_DEBUG("preprocessed input: '%s'", src.c_str()); + + size_t pos = 0; + size_t start_pos = 0; + size_t curly_bracket_depth = 0; + + using pred = std::function; + auto consume_while = [&](pred predicate) -> std::string { + std::string str; + while (predicate(src[pos])) { + // check for escape char + if (src[pos] == '\\') { + // consume backslash + ++pos; + // check for end of input + if (pos >= src.size()) { + throw std::runtime_error("lexer: unexpected end of input after escape character"); + } + // add escaped char + char escaped_char = src[pos++]; + if (escape_chars.find(escaped_char) == escape_chars.end()) { + throw std::runtime_error(std::string("lexer: unknown escape character \\") + escaped_char); + } + char unescaped_char = escape_chars.at(escaped_char); + str += unescaped_char; + continue; + } + + str += src[pos++]; + if (pos > src.size()) { + throw std::runtime_error("lexer: unexpected end of input during consume_while"); + } + } + return str; + }; + + auto next_pos_is = [&](std::initializer_list chars) -> bool { + if (pos + 1 >= src.size()) return false; + for (char c : chars) { + if (src[pos + 1] == c) return true; + } + return false; + }; + + while (pos < src.size()) { + start_pos = pos; + JJ_DEBUG("lexer main loop at pos %zu: '%s...'", pos, 
src.substr(pos, 10).c_str()); + + // First, consume all text that is outside of a Jinja statement or expression + token::type last_token_type = tokens.empty() + ? token::undefined + : tokens.back().t; + if (last_token_type == token::undefined || + last_token_type == token::close_statement || + last_token_type == token::close_expression || + last_token_type == token::comment) { + std::string text; + while (pos < src.size() && + // Keep going until we hit the next Jinja statement or expression + !( + src[pos] == '{' && + next_pos_is( {'%', '{', '#'} ) + )) { + text += src[pos++]; + } + JJ_DEBUG("consumed text: '%s'", text.c_str()); + if (!text.empty()) { + tokens.push_back({token::text, text, start_pos}); + continue; + } + } + + // Possibly consume a comment + if (src[pos] == '{' && next_pos_is( {'#'} )) { + start_pos = pos; + pos += 2; // Skip the opening {# + std::string comment; + while (!(src[pos] == '#' && next_pos_is( {'}'} ))) { + if (pos + 2 >= src.size()) { + throw std::runtime_error("lexer: missing end of comment tag"); + } + comment += src[pos++]; + } + JJ_DEBUG("consumed comment: '%s'", comment.c_str()); + tokens.push_back({token::comment, comment, start_pos}); + pos += 2; // Skip the closing #} + continue; + } + + // Consume (and ignore) all whitespace inside Jinja statements or expressions + consume_while([](char c) { return std::isspace(static_cast(c)); }); + + if (pos >= src.size()) break; + + char ch = src[pos]; + + // Check for unary operators + if (ch == '-' || ch == '+') { + start_pos = pos; + token::type last_token_type = tokens.empty() ? 
token::undefined : tokens.back().t; + if (last_token_type == token::text || last_token_type == token::undefined) { + throw std::runtime_error(std::string("lexer: unexpected character: ") + ch); + } + switch (last_token_type) { + case token::identifier: + case token::numeric_literal: + case token::string_literal: + case token::close_paren: + case token::close_square_bracket: + // Part of a binary operator + // a - 1, 1 - 1, true - 1, "apple" - 1, (1) - 1, a[1] - 1 + // Continue parsing normally + break; + default: { + // Is part of a unary operator + // (-1), [-1], (1 + -1), not -1, -apple + ++pos; // Consume the operator + + // Check for numbers following the unary operator + std::string num = consume_while(is_integer); + std::string value = std::string(1, ch) + num; + token::type t = num.empty() ? token::unary_operator : token::numeric_literal; + JJ_DEBUG("consumed unary operator or numeric literal: '%s'", value.c_str()); + tokens.push_back({t, value, start_pos}); + continue; + } + } + } + + // Try to match one of the tokens in the mapping table + bool matched = false; + for (const auto & [seq, typ] : ordered_mapping_table) { + start_pos = pos; + // Inside an object literal, don't treat "}}" as expression-end + if (seq == "}}" && curly_bracket_depth > 0) { + continue; + } + if (pos + seq.size() <= src.size() && src.substr(pos, seq.size()) == seq) { + tokens.push_back({typ, seq, start_pos}); + if (typ == token::open_expression) { + curly_bracket_depth = 0; + } else if (typ == token::open_curly_bracket) { + ++curly_bracket_depth; + } else if (typ == token::close_curly_bracket) { + --curly_bracket_depth; + } + pos += seq.size(); + matched = true; + break; // continue main loop + } + } + if (matched) continue; // continue main loop + + // Strings + if (ch == '\'' || ch == '"') { + start_pos = pos; + ++pos; // Skip opening quote + std::string str = consume_while([ch](char c) { return c != ch; }); + tokens.push_back({token::string_literal, str, start_pos}); + ++pos; // 
Skip closing quote + continue; + } + + // Numbers + if (is_integer(ch)) { + start_pos = pos; + std::string num = consume_while(is_integer); + if (pos < src.size() && src[pos] == '.' && pos + 1 < src.size() && is_integer(src[pos + 1])) { + ++pos; // Consume '.' + std::string frac = consume_while(is_integer); + num += "." + frac; + } + tokens.push_back({token::numeric_literal, num, start_pos}); + continue; + } + + // Identifiers + if (is_word(ch)) { + start_pos = pos; + std::string word = consume_while(is_word); + tokens.push_back({token::identifier, word, start_pos}); + continue; + } + + throw std::runtime_error(std::string("lexer: unexpected character: ") + ch); + } + + return {std::move(tokens), std::move(src)}; +} + +} // namespace jinja diff --git a/common/jinja/jinja-lexer.h b/common/jinja/jinja-lexer.h new file mode 100644 index 0000000000..f9bbe0a991 --- /dev/null +++ b/common/jinja/jinja-lexer.h @@ -0,0 +1,152 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +namespace jinja { + +struct preprocess_options { + bool trim_blocks = false; + bool lstrip_blocks = false; +}; + +struct token { + enum type { + undefined, + text, // The text between Jinja statements or expressions + + numeric_literal, // e.g., 123, 1.0 + string_literal, // 'string' + identifier, // Variables, functions, statements, booleans, etc. + equals, // = + open_paren, // ( + close_paren, // ) + open_statement, // {% + close_statement, // %} + open_expression, // {{ + close_expression, // }} + open_square_bracket, // [ + close_square_bracket, // ] + open_curly_bracket, // { + close_curly_bracket, // } + comma, // , + dot, // . + colon, // : + pipe, // | + + call_operator, // () + additive_binary_operator, // + - ~ + multiplicative_binary_operator, // * / % + comparison_binary_operator, // < > <= >= == != + unary_operator, // ! - + + comment, // {# ... 
#} + }; + type t; + std::string value; + size_t pos; +}; + +static std::string type_to_string(token::type t) { + switch (t) { + case token::undefined: return "undefined"; + case token::text: return "text"; + case token::numeric_literal: return "numeric_literal"; + case token::string_literal: return "string_literal"; + case token::identifier: return "identifier"; + case token::equals: return "equals"; + case token::open_paren: return "open_paren"; + case token::close_paren: return "close_paren"; + case token::open_statement: return "open_statement"; + case token::close_statement: return "close_statement"; + case token::open_expression: return "open_expression"; + case token::close_expression: return "close_expression"; + case token::open_square_bracket: return "open_square_bracket"; + case token::close_square_bracket: return "close_square_bracket"; + case token::open_curly_bracket: return "open_curly_bracket"; + case token::close_curly_bracket: return "close_curly_bracket"; + case token::comma: return "comma"; + case token::dot: return "dot"; + case token::colon: return "colon"; + case token::pipe: return "pipe"; + case token::call_operator: return "call_operator"; + case token::additive_binary_operator: return "additive_binary_operator"; + case token::multiplicative_binary_operator: return "multiplicative_binary_operator"; + case token::comparison_binary_operator: return "comparison_binary_operator"; + case token::unary_operator: return "unary_operator"; + case token::comment: return "comment"; + default: return "unknown"; + } +} + +struct lexer_result { + std::vector tokens; + std::string preprocessed_source; +}; + +struct lexer { + const std::map escape_chars = { + {'n', '\n'}, + {'t', '\t'}, + {'r', '\r'}, + {'b', '\b'}, + {'f', '\f'}, + {'v', '\v'}, + {'\\', '\\'}, + {'\'', '\''}, + {'\"', '\"'}, + }; + + static bool is_word(char c) { + return std::isalnum(static_cast(c)) || c == '_'; + } + + static bool is_integer(char c) { + return 
std::isdigit(static_cast(c)); + } + + const std::vector> ordered_mapping_table = { + // Control sequences + {"{%", token::open_statement}, + {"%}", token::close_statement}, + {"{{", token::open_expression}, + {"}}", token::close_expression}, + // Single character tokens + {"(", token::open_paren}, + {")", token::close_paren}, + {"{", token::open_curly_bracket}, + {"}", token::close_curly_bracket}, + {"[", token::open_square_bracket}, + {"]", token::close_square_bracket}, + {",", token::comma}, + {".", token::dot}, + {":", token::colon}, + {"|", token::pipe}, + // Comparison operators + {"<=", token::comparison_binary_operator}, + {">=", token::comparison_binary_operator}, + {"==", token::comparison_binary_operator}, + {"!=", token::comparison_binary_operator}, + {"<", token::comparison_binary_operator}, + {">", token::comparison_binary_operator}, + // Arithmetic operators + {"+", token::additive_binary_operator}, + {"-", token::additive_binary_operator}, + {"~", token::additive_binary_operator}, + {"*", token::multiplicative_binary_operator}, + {"/", token::multiplicative_binary_operator}, + {"%", token::multiplicative_binary_operator}, + // Assignment operator + {"=", token::equals}, + }; + + std::string preprocess(const std::string& template_str, const preprocess_options& options) const; + + lexer_result tokenize(const std::string & input, const preprocess_options & options); +}; + +} // namespace jinja diff --git a/common/jinja/jinja-parser.cpp b/common/jinja/jinja-parser.cpp new file mode 100644 index 0000000000..25dacfefa0 --- /dev/null +++ b/common/jinja/jinja-parser.cpp @@ -0,0 +1,610 @@ +#include "jinja-lexer.h" +#include "jinja-vm.h" +#include "jinja-parser.h" + +#include +#include +#include +#include +#include + +#define FILENAME "jinja-parser" + +namespace jinja { + +// Helper to check type without asserting (useful for logic) +template +static bool is_type(const statement_ptr & ptr) { + return dynamic_cast(ptr.get()) != nullptr; +} + +class parser { + 
const std::vector & tokens; + size_t current = 0; + size_t prev_cur = 0; + + // for debugging; a token can be multiple chars in source + std::vector tok_pos_to_src_pos; + + std::string source; // for error reporting + +public: + parser(const std::vector & t, const std::string & src) : tokens(t), source(src) { + tok_pos_to_src_pos.resize(tokens.size()); + for (size_t i = 0; i < tokens.size(); i++) { + tok_pos_to_src_pos[i] = tokens[i].pos; + } + } + + program parse() { + statements body; + while (current < tokens.size()) { + body.push_back(parse_any()); + } + return program(std::move(body)); + } + + template + std::unique_ptr mk_stmt(Args&&... args) { + auto ptr = std::make_unique(std::forward(args)...); + ptr->pos = tok_pos_to_src_pos[prev_cur]; + + std::string snippet = "no source"; + if (!source.empty()) { + size_t start_pos = ptr->pos; + size_t end_pos = start_pos + 20; + if (end_pos > source.size()) end_pos = source.size(); + snippet = source.substr(start_pos, end_pos - start_pos); + } + JJ_DEBUG("Created %-20s statement at src pos %-4zu (%s)", ptr->type().c_str(), ptr->pos, snippet.c_str()); + + return ptr; + } + +private: + const token & peek(size_t offset = 0) const { + if (current + offset >= tokens.size()) { + static const token end_token{token::undefined, "", 0}; + return end_token; + } + return tokens[current + offset]; + } + + token expect(token::type type, const std::string& error) { + const auto & t = peek(); + if (t.t != type) { + throw std::runtime_error("Parser Error: " + error + " (Got " + t.value + ")"); + } + current++; + return t; + } + + void expect_identifier(const std::string& name) { + const auto & t = peek(); + if (t.t != token::identifier || t.value != name) { + throw std::runtime_error("Expected identifier: " + name); + } + current++; + } + + bool is(token::type type) const { + return peek().t == type; + } + + bool is_identifier(const std::string& name) const { + return peek().t == token::identifier && peek().value == name; + } + + bool 
is_statement(const std::vector& names) const { + if (peek(0).t != token::open_statement || peek(1).t != token::identifier) { + return false; + } + std::string val = peek(1).value; + return std::find(names.begin(), names.end(), val) != names.end(); + } + + statement_ptr parse_any() { + prev_cur = current; + switch (peek().t) { + case token::comment: + return mk_stmt(tokens[current++].value); + case token::text: + return mk_stmt(tokens[current++].value); + case token::open_statement: + return parse_jinja_statement(); + case token::open_expression: + return parse_jinja_expression(); + default: + throw std::runtime_error("Unexpected token type"); + } + } + + statement_ptr parse_jinja_expression() { + // Consume {{ }} tokens + prev_cur = current; + expect(token::open_expression, "Expected {{"); + auto result = parse_expression(); + expect(token::close_expression, "Expected }}"); + return result; + } + + statement_ptr parse_jinja_statement() { + // Consume {% token + prev_cur = current; + expect(token::open_statement, "Expected {%"); + + if (peek().t != token::identifier) { + throw std::runtime_error("Unknown statement"); + } + + std::string name = peek().value; + current++; // consume identifier + + statement_ptr result; + if (name == "set") { + result = parse_set_statement(); + + } else if (name == "if") { + result = parse_if_statement(); + // expect {% endif %} + expect(token::open_statement, "Expected {%"); + expect_identifier("endif"); + expect(token::close_statement, "Expected %}"); + + } else if (name == "macro") { + result = parse_macro_statement(); + // expect {% endmacro %} + expect(token::open_statement, "Expected {%"); + expect_identifier("endmacro"); + expect(token::close_statement, "Expected %}"); + + } else if (name == "for") { + result = parse_for_statement(); + // expect {% endfor %} + expect(token::open_statement, "Expected {%"); + expect_identifier("endfor"); + expect(token::close_statement, "Expected %}"); + + } else if (name == "break") { + 
expect(token::close_statement, "Expected %}"); + result = mk_stmt(); + + } else if (name == "continue") { + expect(token::close_statement, "Expected %}"); + result = mk_stmt(); + + } else if (name == "call") { + statements caller_args; + // bool has_caller_args = false; + if (is(token::open_paren)) { + // Optional caller arguments, e.g. {% call(user) dump_users(...) %} + caller_args = parse_args(); + // has_caller_args = true; + } + auto callee = parse_primary_expression(); + if (!is_type(callee)) throw std::runtime_error("Expected identifier"); + + auto call_args = parse_args(); + expect(token::close_statement, "Expected %}"); + + statements body; + while (!is_statement({"endcall"})) { + body.push_back(parse_any()); + } + + expect(token::open_statement, "Expected {%"); + expect_identifier("endcall"); + expect(token::close_statement, "Expected %}"); + + auto call_expr = mk_stmt(std::move(callee), std::move(call_args)); + result = mk_stmt(std::move(call_expr), std::move(caller_args), std::move(body)); + + } else if (name == "filter") { + auto filter_node = parse_primary_expression(); + if (is_type(filter_node) && is(token::open_paren)) { + filter_node = parse_call_expression(std::move(filter_node)); + } + expect(token::close_statement, "Expected %}"); + + statements body; + while (!is_statement({"endfilter"})) { + body.push_back(parse_any()); + } + + expect(token::open_statement, "Expected {%"); + expect_identifier("endfilter"); + expect(token::close_statement, "Expected %}"); + result = mk_stmt(std::move(filter_node), std::move(body)); + + } else if (name == "generation" || name == "endgeneration") { + // Ignore generation blocks (transformers-specific) + // See https://github.com/huggingface/transformers/pull/30650 for more information. 
+ result = mk_stmt(); + current++; + + } else { + throw std::runtime_error("Unknown statement: " + name); + } + return result; + } + + statement_ptr parse_set_statement() { + // NOTE: `set` acts as both declaration statement and assignment expression + auto left = parse_expression_sequence(); + statement_ptr value = nullptr; + statements body; + + prev_cur = current; + + if (is(token::equals)) { + current++; + value = parse_expression_sequence(); + } else { + // parsing multiline set here + expect(token::close_statement, "Expected %}"); + while (!is_statement({"endset"})) { + body.push_back(parse_any()); + } + expect(token::open_statement, "Expected {%"); + expect_identifier("endset"); + } + expect(token::close_statement, "Expected %}"); + return mk_stmt(std::move(left), std::move(value), std::move(body)); + } + + statement_ptr parse_if_statement() { + auto test = parse_expression(); + expect(token::close_statement, "Expected %}"); + + statements body; + statements alternate; + + prev_cur = current; + + // Keep parsing 'if' body until we reach the first {% elif %} or {% else %} or {% endif %} + while (!is_statement({"elif", "else", "endif"})) { + body.push_back(parse_any()); + } + + if (is_statement({"elif"})) { + ++current; // consume {% + ++current; // consume 'elif' + alternate.push_back(parse_if_statement()); // nested If + } else if (is_statement({"else"})) { + ++current; // consume {% + ++current; // consume 'else' + expect(token::close_statement, "Expected %}"); + + // keep going until we hit {% endif %} + while (!is_statement({"endif"})) { + alternate.push_back(parse_any()); + } + } + return mk_stmt(std::move(test), std::move(body), std::move(alternate)); + } + + statement_ptr parse_macro_statement() { + auto name = parse_primary_expression(); + auto args = parse_args(); + expect(token::close_statement, "Expected %}"); + statements body; + // Keep going until we hit {% endmacro + while (!is_statement({"endmacro"})) { + body.push_back(parse_any()); + } + 
return mk_stmt(std::move(name), std::move(args), std::move(body)); + } + + statement_ptr parse_expression_sequence(bool primary = false) { + statements exprs; + exprs.push_back(primary ? parse_primary_expression() : parse_expression()); + bool is_tuple = is(token::comma); + while (is(token::comma)) { + prev_cur = current; + current++; // consume comma + exprs.push_back(primary ? parse_primary_expression() : parse_expression()); + if (!is(token::comma)) break; + } + return is_tuple ? mk_stmt(std::move(exprs)) : std::move(exprs[0]); + } + + statement_ptr parse_for_statement() { + // e.g., `message` in `for message in messages` + auto loop_var = parse_expression_sequence(true); // should be an identifier/tuple + if (!is_identifier("in")) throw std::runtime_error("Expected 'in'"); + current++; + + // `messages` in `for message in messages` + auto iterable = parse_expression(); + expect(token::close_statement, "Expected %}"); + + statements body; + statements alternate; + + // Keep going until we hit {% endfor or {% else + while (!is_statement({"endfor", "else"})) { + body.push_back(parse_any()); + } + + if (is_statement({"else"})) { + prev_cur = current; + current += 2; + expect(token::close_statement, "Expected %}"); + while (!is_statement({"endfor"})) { + alternate.push_back(parse_any()); + } + } + return mk_stmt( + std::move(loop_var), std::move(iterable), + std::move(body), std::move(alternate)); + } + + statement_ptr parse_expression() { + // Choose parse function with lowest precedence + return parse_if_expression(); + } + + statement_ptr parse_if_expression() { + auto a = parse_logical_or_expression(); + if (is_identifier("if")) { + // Ternary expression + prev_cur = current; + ++current; // consume 'if' + auto test = parse_logical_or_expression(); + if (is_identifier("else")) { + // Ternary expression with else + prev_cur = current; + ++current; // consume 'else' + auto false_expr = parse_if_expression(); // recurse to support chained ternaries + return 
mk_stmt(std::move(test), std::move(a), std::move(false_expr)); + } else { + // Select expression on iterable + return mk_stmt(std::move(a), std::move(test)); + } + } + return a; + } + + statement_ptr parse_logical_or_expression() { + auto left = parse_logical_and_expression(); + while (is_identifier("or")) { + prev_cur = current; + token op = tokens[current++]; + left = mk_stmt(op, std::move(left), parse_logical_and_expression()); + } + return left; + } + + statement_ptr parse_logical_and_expression() { + auto left = parse_logical_negation_expression(); + while (is_identifier("and")) { + prev_cur = current; + auto op = tokens[current++]; + left = mk_stmt(op, std::move(left), parse_logical_negation_expression()); + } + return left; + } + + statement_ptr parse_logical_negation_expression() { + // Try parse unary operators + if (is_identifier("not")) { + prev_cur = current; + auto op = tokens[current]; + ++current; // consume 'not' + return mk_stmt(op, parse_logical_negation_expression()); + } + return parse_comparison_expression(); + } + + statement_ptr parse_comparison_expression() { + // NOTE: membership has same precedence as comparison + // e.g., ('a' in 'apple' == 'b' in 'banana') evaluates as ('a' in ('apple' == ('b' in 'banana'))) + auto left = parse_additive_expression(); + while (true) { + token op; + prev_cur = current; + if (is_identifier("not") && peek(1).t == token::identifier && peek(1).value == "in") { + op = {token::identifier, "not in", tokens[current].pos}; + current += 2; + } else if (is_identifier("in")) { + op = tokens[current++]; + } else if (is(token::comparison_binary_operator)) { + op = tokens[current++]; + } else break; + left = mk_stmt(op, std::move(left), parse_additive_expression()); + } + return left; + } + + statement_ptr parse_additive_expression() { + auto left = parse_multiplicative_expression(); + while (is(token::additive_binary_operator)) { + prev_cur = current; + auto op = tokens[current++]; + left = mk_stmt(op, std::move(left), 
parse_multiplicative_expression()); + } + return left; + } + + statement_ptr parse_multiplicative_expression() { + auto left = parse_test_expression(); + while (is(token::multiplicative_binary_operator)) { + prev_cur = current; + auto op = tokens[current++]; + left = mk_stmt(op, std::move(left), parse_test_expression()); + } + return left; + } + + statement_ptr parse_test_expression() { + auto operand = parse_filter_expression(); + while (is_identifier("is")) { + prev_cur = current; + current++; + bool negate = false; + if (is_identifier("not")) { current++; negate = true; } + auto test_id = parse_primary_expression(); + operand = mk_stmt(std::move(operand), negate, std::move(test_id)); + } + return operand; + } + + statement_ptr parse_filter_expression() { + auto operand = parse_call_member_expression(); + while (is(token::pipe)) { + prev_cur = current; + current++; + auto filter = parse_primary_expression(); + if (is(token::open_paren)) filter = parse_call_expression(std::move(filter)); + operand = mk_stmt(std::move(operand), std::move(filter)); + } + return operand; + } + + statement_ptr parse_call_member_expression() { + // Handle member expressions recursively + auto member = parse_member_expression(parse_primary_expression()); + return is(token::open_paren) + ? parse_call_expression(std::move(member)) // foo.x() + : std::move(member); + } + + statement_ptr parse_call_expression(statement_ptr callee) { + auto expr = mk_stmt(std::move(callee), parse_args()); + auto member = parse_member_expression(std::move(expr)); // foo.x().y + return is(token::open_paren) + ? 
parse_call_expression(std::move(member)) // foo.x()() + : std::move(member); + } + + statements parse_args() { + // comma-separated arguments list + expect(token::open_paren, "Expected ("); + statements args; + while (!is(token::close_paren)) { + statement_ptr arg; + prev_cur = current; + // unpacking: *expr + if (peek().t == token::multiplicative_binary_operator && peek().value == "*") { + ++current; // consume * + arg = mk_stmt(parse_expression()); + } else { + arg = parse_expression(); + if (is(token::equals)) { + // keyword argument + // e.g., func(x = 5, y = a or b) + ++current; // consume equals + arg = mk_stmt(std::move(arg), parse_expression()); + } + } + args.push_back(std::move(arg)); + if (is(token::comma)) { + ++current; // consume comma + } + } + expect(token::close_paren, "Expected )"); + return args; + } + + statement_ptr parse_member_expression(statement_ptr object) { + while (is(token::dot) || is(token::open_square_bracket)) { + auto op = tokens[current++]; + bool computed = op.t == token::open_square_bracket; + statement_ptr prop; + if (computed) { + prop = parse_member_expression_arguments(); + expect(token::close_square_bracket, "Expected ]"); + } else { + prop = parse_primary_expression(); + } + object = mk_stmt(std::move(object), std::move(prop), computed); + } + return object; + } + + statement_ptr parse_member_expression_arguments() { + // NOTE: This also handles slice expressions colon-separated arguments list + // e.g., ['test'], [0], [:2], [1:], [1:2], [1:2:3] + statements slices; + bool is_slice = false; + while (!is(token::close_square_bracket)) { + prev_cur = current; + if (is(token::colon)) { + // A case where a default is used + // e.g., [:2] will be parsed as [undefined, 2] + slices.push_back(nullptr); + ++current; // consume colon + is_slice = true; + } else { + slices.push_back(parse_expression()); + if (is(token::colon)) { + ++current; // consume colon after expression, if it exists + is_slice = true; + } + } + } + if (is_slice) 
{ + statement_ptr start = slices.size() > 0 ? std::move(slices[0]) : nullptr; + statement_ptr stop = slices.size() > 1 ? std::move(slices[1]) : nullptr; + statement_ptr step = slices.size() > 2 ? std::move(slices[2]) : nullptr; + return mk_stmt(std::move(start), std::move(stop), std::move(step)); + } + return std::move(slices[0]); + } + + statement_ptr parse_primary_expression() { + prev_cur = current; + auto t = tokens[current++]; + switch (t.t) { + case token::numeric_literal: + if (t.value.find('.') != std::string::npos) return mk_stmt(std::stod(t.value)); + return mk_stmt(std::stoll(t.value)); + case token::string_literal: { + std::string val = t.value; + while (is(token::string_literal)) { + val += tokens[current++].value; + } + return mk_stmt(val); + } + case token::identifier: + return mk_stmt(t.value); + case token::open_paren: { + auto expr = parse_expression_sequence(); + expect(token::close_paren, "Expected )"); + return expr; + } + case token::open_square_bracket: { + statements vals; + while (!is(token::close_square_bracket)) { + vals.push_back(parse_expression()); + if (is(token::comma)) current++; + } + current++; + return mk_stmt(std::move(vals)); + } + case token::open_curly_bracket: { + std::vector> pairs; + while (!is(token::close_curly_bracket)) { + auto key = parse_expression(); + expect(token::colon, "Expected :"); + pairs.push_back({std::move(key), parse_expression()}); + if (is(token::comma)) current++; + } + current++; + return mk_stmt(std::move(pairs)); + } + default: + throw std::runtime_error("Unexpected token: " + t.value + " of type " + std::to_string(t.t)); + } + } +}; + +program parse_from_tokens(const std::vector & tokens) { + return parser(tokens, "").parse(); +} + +program parse_from_tokens(const lexer_result & lexer_res) { + return parser(lexer_res.tokens, lexer_res.preprocessed_source).parse(); +} + +} // namespace jinja diff --git a/common/jinja/jinja-parser.h b/common/jinja/jinja-parser.h new file mode 100644 index 
// Allows differentiating between user input strings and template strings.
// Transformations should handle this information as follows:
// - one-to-one (e.g., uppercase, lowercase): preserve the is_input flag
// - one-to-many (e.g., strip): if the input string is marked as is_input, all resulting parts should be marked as is_input
// - many-to-one (e.g., concat): if ALL input parts are marked as is_input, the resulting part should be marked as is_input
struct string_part {
    bool is_input = false; // may skip parsing special tokens if true
    std::string val;

    // True when the part holds no lowercase character. This is vacuously true
    // for empty text and for text without cased characters.
    bool is_uppercase() const {
        for (char c : val) {
            if (std::islower(static_cast<unsigned char>(c))) {
                return false;
            }
        }
        return true;
    }

    // True when the part holds no uppercase character (same caveat as above).
    bool is_lowercase() const {
        for (char c : val) {
            if (std::isupper(static_cast<unsigned char>(c))) {
                return false;
            }
        }
        return true;
    }
};

// A string stored as an ordered list of parts, each carrying its own
// "came from user input" flag.
//
// The mutating helpers (append, apply_transform, uppercase, ..., strip) modify
// the object in place and additionally return a copy of the updated value.
struct string {
    using transform_fn = std::function<std::string(const std::string &)>;

    std::vector<string_part> parts;
    string() = default;
    string(const std::string & v, bool user_input = false) {
        parts.push_back({user_input, v});
    }
    string(int v) {
        parts.push_back({false, std::to_string(v)});
    }
    string(double v) {
        parts.push_back({false, std::to_string(v)});
    }

    // Flags every part as user input.
    void mark_input() {
        for (auto & part : parts) {
            part.is_input = true;
        }
    }

    // Concatenation of all parts, in order.
    std::string str() const {
        std::string out;
        out.reserve(length());
        for (const auto & part : parts) {
            out += part.val;
        }
        return out;
    }

    // Total number of bytes across all parts.
    size_t length() const {
        size_t len = 0;
        for (const auto & part : parts) {
            len += part.val.length();
        }
        return len;
    }

    // True when every part is flagged as input (also true when there are no parts).
    bool all_parts_are_input() const {
        for (const auto & part : parts) {
            if (!part.is_input) {
                return false;
            }
        }
        return true;
    }

    bool is_uppercase() const {
        for (const auto & part : parts) {
            if (!part.is_uppercase()) {
                return false;
            }
        }
        return true;
    }

    bool is_lowercase() const {
        for (const auto & part : parts) {
            if (!part.is_lowercase()) {
                return false;
            }
        }
        return true;
    }

    // mark this string as input if other has ALL parts as input
    void mark_input_based_on(const string & other) {
        if (other.all_parts_are_input()) {
            mark_input();
        }
    }

    // Appends copies of other's parts, keeping their individual flags.
    string append(const string & other) {
        for (const auto & part : other.parts) {
            parts.push_back(part);
        }
        return *this;
    }

    // in-place transformation

    // Replaces the text of each part with fn(text); flags are left untouched.
    string apply_transform(const transform_fn & fn) {
        for (auto & part : parts) {
            part.val = fn(part.val);
        }
        return *this;
    }
    string uppercase() {
        return apply_transform([](const std::string & s) {
            std::string res = s;
            // go through unsigned char: passing a negative char to toupper is undefined
            std::transform(res.begin(), res.end(), res.begin(),
                           [](unsigned char ch) { return static_cast<char>(::toupper(ch)); });
            return res;
        });
    }
    string lowercase() {
        return apply_transform([](const std::string & s) {
            std::string res = s;
            std::transform(res.begin(), res.end(), res.begin(),
                           [](unsigned char ch) { return static_cast<char>(::tolower(ch)); });
            return res;
        });
    }
    // Uppercases the first character of each part and lowercases the rest of it.
    string capitalize() {
        return apply_transform([](const std::string & s) {
            if (s.empty()) return s;
            std::string res = s;
            res[0] = static_cast<char>(::toupper(static_cast<unsigned char>(res[0])));
            std::transform(res.begin() + 1, res.end(), res.begin() + 1,
                           [](unsigned char ch) { return static_cast<char>(::tolower(ch)); });
            return res;
        });
    }
    // Uppercases the first character after whitespace (and at the start of
    // each part) and lowercases everything else.
    string titlecase() {
        return apply_transform([](const std::string & s) {
            std::string res = s;
            bool capitalize_next = true;
            for (char &c : res) {
                if (isspace(static_cast<unsigned char>(c))) {
                    capitalize_next = true;
                } else if (capitalize_next) {
                    c = static_cast<char>(::toupper(static_cast<unsigned char>(c)));
                    capitalize_next = false;
                } else {
                    c = static_cast<char>(::tolower(static_cast<unsigned char>(c)));
                }
            }
            return res;
        });
    }
    // Removes leading (left) and/or trailing (right) whitespace.
    //
    // Whitespace that spans several parts is handled: when a part becomes
    // empty, stripping continues with its neighbour. Emptied parts are kept
    // (with their flags) so the number of parts does not change.
    string strip(bool left, bool right) {
        const auto is_space = [](char c) {
            return isspace(static_cast<unsigned char>(c)) != 0;
        };
        if (left) {
            for (auto & part : parts) {
                size_t n = 0;
                while (n < part.val.size() && is_space(part.val[n])) {
                    ++n;
                }
                part.val.erase(0, n);
                if (!part.val.empty()) {
                    break; // reached the first non-whitespace character
                }
            }
        }
        if (right) {
            for (auto it = parts.rbegin(); it != parts.rend(); ++it) {
                size_t end = it->val.size();
                while (end > 0 && is_space(it->val[end - 1])) {
                    --end;
                }
                it->val.erase(end);
                if (!it->val.empty()) {
                    break; // reached the last non-whitespace character
                }
            }
        }
        return *this;
    }
};
"jinja-parser.h" +#include "jinja-value.h" + +// for converting from JSON to jinja values +#include + +#include +#include +#include +#include +#include + +#define FILENAME "jinja-value" + +namespace jinja { + +// func_args method implementations + +value func_args::get_kwarg(const std::string & key) const { + for (const auto & arg : args) { + if (is_val(arg)) { + auto * kwarg = cast_val(arg); + if (kwarg->key == key) { + return kwarg->val; + } + } + } + return mk_val(); +} + +/** + * Function that mimics Python's array slicing. + */ +template +static T slice(const T & array, std::optional start = std::nullopt, std::optional stop = std::nullopt, int64_t step = 1) { + int64_t len = static_cast(array.size()); + int64_t direction = (step > 0) ? 1 : ((step < 0) ? -1 : 0); + int64_t start_val; + int64_t stop_val; + if (direction >= 0) { + start_val = start.value_or(0); + if (start_val < 0) { + start_val = std::max(len + start_val, (int64_t)0); + } else { + start_val = std::min(start_val, len); + } + + stop_val = stop.value_or(len); + if (stop_val < 0) { + stop_val = std::max(len + stop_val, (int64_t)0); + } else { + stop_val = std::min(stop_val, len); + } + } else { + start_val = start.value_or(len - 1); + if (start_val < 0) { + start_val = std::max(len + start_val, (int64_t)-1); + } else { + start_val = std::min(start_val, len - 1); + } + + stop_val = stop.value_or(-1); + if (stop_val < -1) { + stop_val = std::max(len + stop_val, (int64_t)-1); + } else { + stop_val = std::min(stop_val, len - 1); + } + } + T result; + if (direction == 0) { + return result; + } + for (int64_t i = start_val; direction * i < direction * stop_val; i += step) { + if (i >= 0 && i < len) { + result.push_back(array[static_cast(i)]); + } + } + return result; +} + +template +static value test_type_fn(const func_args & args) { + args.ensure_count(1); + bool is_type = is_val(args.args[0]); + JJ_DEBUG("test_type_fn: type=%s result=%d", typeid(T).name(), is_type ? 
1 : 0); + return mk_val(is_type); +} +template +static value test_type_fn(const func_args & args) { + args.ensure_count(1); + bool is_type = is_val(args.args[0]) || is_val(args.args[0]); + JJ_DEBUG("test_type_fn: type=%s or %s result=%d", typeid(T).name(), typeid(U).name(), is_type ? 1 : 0); + return mk_val(is_type); +} + +const func_builtins & global_builtins() { + static const func_builtins builtins = { + {"raise_exception", [](const func_args & args) -> value { + args.ensure_vals(); + std::string msg = args.args[0]->as_string().str(); + throw raised_exception("Jinja Exception: " + msg); + }}, + {"namespace", [](const func_args & args) -> value { + auto out = mk_val(); + for (const auto & arg : args.args) { + if (!is_val(arg)) { + throw raised_exception("namespace() arguments must be kwargs"); + } + auto kwarg = cast_val(arg); + JJ_DEBUG("namespace: adding key '%s'", kwarg->key.c_str()); + out->insert(kwarg->key, kwarg->val); + } + return out; + }}, + {"strftime_now", [](const func_args & args) -> value { + args.ensure_vals(); + std::string format = args.args[0]->as_string().str(); + // get current time + // TODO: make sure this is the same behavior as Python's strftime + char buf[100]; + if (std::strftime(buf, sizeof(buf), format.c_str(), std::localtime(&args.ctx.current_time))) { + return mk_val(std::string(buf)); + } else { + throw raised_exception("strftime_now: failed to format time"); + } + }}, + {"range", [](const func_args & args) -> value { + args.ensure_count(1, 3); + args.ensure_vals(true, false, false); + + auto & arg0 = args.args[0]; + auto & arg1 = args.args[1]; + auto & arg2 = args.args[2]; + + int64_t start, stop, step; + if (args.args.size() == 1) { + start = 0; + stop = arg0->as_int(); + step = 1; + } else if (args.args.size() == 2) { + start = arg0->as_int(); + stop = arg1->as_int(); + step = 1; + } else { + start = arg0->as_int(); + stop = arg1->as_int(); + step = arg2->as_int(); + } + + auto out = mk_val(); + if (step == 0) { + throw 
raised_exception("range() step argument must not be zero"); + } + if (step > 0) { + for (int64_t i = start; i < stop; i += step) { + out->push_back(mk_val(i)); + } + } else { + for (int64_t i = start; i > stop; i += step) { + out->push_back(mk_val(i)); + } + } + return out; + }}, + {"tojson", [](const func_args & args) -> value { + args.ensure_count(1); + // placeholder implementation + return mk_val("TODO: to_json output"); + }}, + + // tests + {"test_is_boolean", test_type_fn}, + {"test_is_callable", test_type_fn}, + {"test_is_odd", [](const func_args & args) -> value { + args.ensure_vals(); + int64_t val = args.args[0]->as_int(); + return mk_val(val % 2 != 0); + }}, + {"test_is_even", [](const func_args & args) -> value { + args.ensure_vals(); + int64_t val = args.args[0]->as_int(); + return mk_val(val % 2 == 0); + }}, + {"test_is_false", [](const func_args & args) -> value { + args.ensure_count(1); + bool val = is_val(args.args[0]) && !args.args[0]->as_bool(); + return mk_val(val); + }}, + {"test_is_true", [](const func_args & args) -> value { + args.ensure_count(1); + bool val = is_val(args.args[0]) && args.args[0]->as_bool(); + return mk_val(val); + }}, + {"test_is_string", test_type_fn}, + {"test_is_integer", test_type_fn}, + {"test_is_number", test_type_fn}, + {"test_is_iterable", test_type_fn}, + {"test_is_sequence", test_type_fn}, + {"test_is_mapping", test_type_fn}, + {"test_is_lower", [](const func_args & args) -> value { + args.ensure_vals(); + return mk_val(args.args[0]->val_str.is_lowercase()); + }}, + {"test_is_upper", [](const func_args & args) -> value { + args.ensure_vals(); + return mk_val(args.args[0]->val_str.is_uppercase()); + }}, + {"test_is_none", test_type_fn}, + {"test_is_defined", [](const func_args & args) -> value { + args.ensure_count(1); + bool res = !args.args[0]->is_undefined(); + JJ_DEBUG("test_is_defined: result=%d", res ? 
1 : 0); + return mk_val(res); + }}, + {"test_is_undefined", test_type_fn}, + }; + return builtins; +} + + +const func_builtins & value_int_t::get_builtins() const { + static const func_builtins builtins = { + {"abs", [](const func_args & args) -> value { + args.ensure_vals(); + int64_t val = args.args[0]->as_int(); + return mk_val(val < 0 ? -val : val); + }}, + {"float", [](const func_args & args) -> value { + args.ensure_vals(); + double val = static_cast(args.args[0]->as_int()); + return mk_val(val); + }}, + }; + return builtins; +} + + +const func_builtins & value_float_t::get_builtins() const { + static const func_builtins builtins = { + {"abs", [](const func_args & args) -> value { + args.ensure_vals(); + double val = args.args[0]->as_float(); + return mk_val(val < 0.0 ? -val : val); + }}, + {"int", [](const func_args & args) -> value { + args.ensure_vals(); + int64_t val = static_cast(args.args[0]->as_float()); + return mk_val(val); + }}, + }; + return builtins; +} + +static bool string_startswith(const std::string & str, const std::string & prefix) { + if (str.length() < prefix.length()) return false; + return str.compare(0, prefix.length(), prefix) == 0; +} + +static bool string_endswith(const std::string & str, const std::string & suffix) { + if (str.length() < suffix.length()) return false; + return str.compare(str.length() - suffix.length(), suffix.length(), suffix) == 0; +} + +const func_builtins & value_string_t::get_builtins() const { + static const func_builtins builtins = { + {"upper", [](const func_args & args) -> value { + args.ensure_vals(); + jinja::string str = args.args[0]->as_string().uppercase(); + return mk_val(str); + }}, + {"lower", [](const func_args & args) -> value { + args.ensure_vals(); + jinja::string str = args.args[0]->as_string().lowercase(); + return mk_val(str); + }}, + {"strip", [](const func_args & args) -> value { + args.ensure_vals(); + jinja::string str = args.args[0]->as_string().strip(true, true); + return mk_val(str); + 
}}, + {"rstrip", [](const func_args & args) -> value { + args.ensure_vals(); + jinja::string str = args.args[0]->as_string().strip(false, true); + return mk_val(str); + }}, + {"lstrip", [](const func_args & args) -> value { + args.ensure_vals(); + jinja::string str = args.args[0]->as_string().strip(true, false); + return mk_val(str); + }}, + {"title", [](const func_args & args) -> value { + args.ensure_vals(); + jinja::string str = args.args[0]->as_string().titlecase(); + return mk_val(str); + }}, + {"capitalize", [](const func_args & args) -> value { + args.ensure_vals(); + jinja::string str = args.args[0]->as_string().capitalize(); + return mk_val(str); + }}, + {"length", [](const func_args & args) -> value { + args.ensure_vals(); + jinja::string str = args.args[0]->as_string(); + return mk_val(str.length()); + }}, + {"startswith", [](const func_args & args) -> value { + args.ensure_vals(); + std::string str = args.args[0]->as_string().str(); + std::string prefix = args.args[1]->as_string().str(); + return mk_val(string_startswith(str, prefix)); + }}, + {"endswith", [](const func_args & args) -> value { + args.ensure_vals(); + std::string str = args.args[0]->as_string().str(); + std::string suffix = args.args[1]->as_string().str(); + return mk_val(string_endswith(str, suffix)); + }}, + {"split", [](const func_args & args) -> value { + args.ensure_vals(); + std::string str = args.args[0]->as_string().str(); + std::string delim = (args.args.size() > 1) ? 
args.args[1]->as_string().str() : " "; + auto result = mk_val(); + size_t pos = 0; + std::string token; + while ((pos = str.find(delim)) != std::string::npos) { + token = str.substr(0, pos); + result->push_back(mk_val(token)); + str.erase(0, pos + delim.length()); + } + auto res = mk_val(str); + res->val_str.mark_input_based_on(args.args[0]->val_str); + result->push_back(std::move(res)); + return std::move(result); + }}, + {"replace", [](const func_args & args) -> value { + args.ensure_vals(); + std::string str = args.args[0]->as_string().str(); + std::string old_str = args.args[1]->as_string().str(); + std::string new_str = args.args[2]->as_string().str(); + size_t pos = 0; + while ((pos = str.find(old_str, pos)) != std::string::npos) { + str.replace(pos, old_str.length(), new_str); + pos += new_str.length(); + } + auto res = mk_val(str); + res->val_str.mark_input_based_on(args.args[0]->val_str); + return res; + }}, + {"int", [](const func_args & args) -> value { + args.ensure_vals(); + std::string str = args.args[0]->as_string().str(); + try { + return mk_val(std::stoi(str)); + } catch (...) { + throw std::runtime_error("Cannot convert string '" + str + "' to int"); + } + }}, + {"float", [](const func_args & args) -> value { + args.ensure_vals(); + std::string str = args.args[0]->as_string().str(); + try { + return mk_val(std::stod(str)); + } catch (...) 
{ + throw std::runtime_error("Cannot convert string '" + str + "' to float"); + } + }}, + {"string", [](const func_args & args) -> value { + // no-op + args.ensure_vals(); + return mk_val(args.args[0]->as_string()); + }}, + {"default", [](const func_args & args) -> value { + value input = args.args[0]; + if (!is_val(input)) { + throw raised_exception("default() first argument must be a string"); + } + value default_val = mk_val(""); + if (args.args.size() > 1 && !args.args[1]->is_undefined()) { + default_val = args.args[1]; + } + value boolean_val = mk_val(false); + if (args.args.size() > 1) { + boolean_val = args.args[1]; + } + if (input->is_undefined() || (boolean_val->as_bool() && !input->as_bool())) { + return default_val; + } else { + return input; + } + }}, + {"slice", [](const func_args & args) -> value { + auto & input = args.args[0]; + if (!is_val(input)) { + throw raised_exception("slice() first argument must be a string"); + } + if (args.args.size() < 1 || args.args.size() > 4) { + throw raised_exception("slice() takes between 1 and 4 arguments"); + } + int64_t start = is_val(args.args[1]) ? args.args[1]->as_int() : 0; + int64_t stop = is_val(args.args[2]) ? args.args[2]->as_int() : -1; + int64_t step = is_val(args.args[3]) ? 
args.args[3]->as_int() : 1; + if (step == 0) { + throw raised_exception("slice step cannot be zero"); + } + auto sliced = slice(input->as_string().str(), start, stop, step); + auto res = mk_val(sliced); + res->val_str.mark_input_based_on(input->as_string()); + return res; + }}, + {"safe", [](const func_args & args) -> value { + // no-op for now + args.ensure_vals(); + return args.args[0]; + }}, + {"selectattr", [](const func_args &) -> value { + throw std::runtime_error("String selectattr builtin not supported"); + }}, + {"rejectattr", [](const func_args &) -> value { + throw std::runtime_error("String rejectattr builtin not supported"); + }}, + {"indent", [](const func_args &) -> value { + throw std::runtime_error("String indent builtin not implemented"); + }}, + {"join", [](const func_args &) -> value { + throw std::runtime_error("String join builtin not implemented"); + }}, + }; + return builtins; +} + + +const func_builtins & value_bool_t::get_builtins() const { + static const func_builtins builtins = { + {"int", [](const func_args & args) -> value { + args.ensure_vals(); + bool val = args.args[0]->as_bool(); + return mk_val(val ? 1 : 0); + }}, + {"float", [](const func_args & args) -> value { + args.ensure_vals(); + bool val = args.args[0]->as_bool(); + return mk_val(val ? 1.0 : 0.0); + }}, + {"string", [](const func_args & args) -> value { + args.ensure_vals(); + bool val = args.args[0]->as_bool(); + return mk_val(val ? 
"True" : "False"); + }}, + }; + return builtins; +} + + +const func_builtins & value_array_t::get_builtins() const { + static const func_builtins builtins = { + {"list", [](const func_args & args) -> value { + args.ensure_vals(); + const auto & arr = args.args[0]->as_array(); + auto result = mk_val(); + for (const auto& v : arr) { + result->push_back(v); + } + return result; + }}, + {"first", [](const func_args & args) -> value { + args.ensure_vals(); + const auto & arr = args.args[0]->as_array(); + if (arr.empty()) { + return mk_val(); + } + return arr[0]; + }}, + {"last", [](const func_args & args) -> value { + args.ensure_vals(); + const auto & arr = args.args[0]->as_array(); + if (arr.empty()) { + return mk_val(); + } + return arr[arr.size() - 1]; + }}, + {"length", [](const func_args & args) -> value { + args.ensure_vals(); + const auto & arr = args.args[0]->as_array(); + return mk_val(static_cast(arr.size())); + }}, + {"slice", [](const func_args & args) -> value { + if (args.args.size() < 1 || args.args.size() > 4) { + throw raised_exception("slice() takes between 1 and 4 arguments"); + } + int64_t start = is_val(args.args[1]) ? args.args[1]->as_int() : 0; + int64_t stop = is_val(args.args[2]) ? args.args[2]->as_int() : -1; + int64_t step = is_val(args.args[3]) ? 
args.args[3]->as_int() : 1; + if (!is_val(args.args[0])) { + throw raised_exception("slice() first argument must be an array"); + } + if (step == 0) { + throw raised_exception("slice step cannot be zero"); + } + auto arr = slice(args.args[0]->as_array(), start, stop, step); + auto res = mk_val(); + res->val_arr = std::move(arr); + return res; + }}, + {"selectattr", [](const func_args & args) -> value { + value input = args.args[0]; + if (!is_val(input)) { + throw raised_exception("selectattr() first argument must be an array, got " + input->type()); + } + std::vector selected; + for (size_t i = 1; i < args.args.size(); ++i) { + const auto & v = args.args[i]; + if (!is_val(v)) { + throw raised_exception("selectattr() attributes must be strings, got " + v->type()); + } + JJ_DEBUG("selectattr: selecting attribute '%s'", v->as_string().str().c_str()); + selected.push_back(v->as_string().str()); + } + auto result = mk_val(); + for (const auto & item : input->as_array()) { + if (!is_val(item)) { + continue; + } + const auto & obj = item->as_object(); + bool match = true; + for (const auto & attr : selected) { + auto it = obj.find(attr); + if (it == obj.end() || it->second->is_undefined() || (is_val(it->second) && !it->second->as_bool())) { + match = false; + break; + } + } + if (match) { + result->push_back(item); + } + } + return result; + }}, + {"rejectattr", [](const func_args & args) -> value { + value input = args.args[0]; + if (!is_val(input)) { + throw raised_exception("rejectattr() first argument must be an array, got " + input->type()); + } + std::vector rejected; + for (size_t i = 1; i < args.args.size(); ++i) { + const auto & v = args.args[i]; + if (!is_val(v)) { + throw raised_exception("rejectattr() attributes must be strings, got " + v->type()); + } + JJ_DEBUG("rejectattr: rejecting attribute '%s'", v->as_string().str().c_str()); + rejected.push_back(v->as_string().str()); + } + auto result = mk_val(); + for (const auto & item : input->as_array()) { + if 
(!is_val(item)) { + result->push_back(item); + continue; + } + const auto & obj = item->as_object(); + bool match = false; + for (const auto & attr : rejected) { + auto it = obj.find(attr); + if (it != obj.end() && !it->second->is_undefined() && (!is_val(it->second) || it->second->as_bool())) { + match = true; + break; + } + } + if (!match) { + result->push_back(item); + } + } + return result; + }}, + {"join", [](const func_args & args) -> value { + if (args.args.size() < 1 || args.args.size() > 2) { + throw raised_exception("join() takes one or two arguments"); + } + if (!is_val(args.args[0])) { + throw raised_exception("join() first argument must be an array"); + } + const auto & arr = args.args[0]->as_array(); + std::string delim = (args.args.size() > 1 && is_val(args.args[1])) ? args.args[1]->as_string().str() : ""; + std::string result; + for (size_t i = 0; i < arr.size(); ++i) { + if (!is_val(arr[i])) { + throw raised_exception("join() can only join arrays of strings"); + } + result += arr[i]->as_string().str(); + if (i < arr.size() - 1) { + result += delim; + } + } + return mk_val(result); + }}, + {"string", [](const func_args & args) -> value { + args.ensure_vals(); + auto str = mk_val(); + gather_string_parts_recursive(args.args[0], str); + return str; + }}, + {"sort", [](const func_args &) -> value { + throw std::runtime_error("Array sort builtin not implemented"); + }}, + {"reverse", [](const func_args &) -> value { + throw std::runtime_error("Array reverse builtin not implemented"); + }}, + {"unique", [](const func_args &) -> value { + throw std::runtime_error("Array unique builtin not implemented"); + }}, + }; + return builtins; +} + + +const func_builtins & value_object_t::get_builtins() const { + static const func_builtins builtins = { + {"get", [](const func_args & args) -> value { + args.ensure_vals(); // TODO: add default value + const auto & obj = args.args[0]->as_object(); + std::string key = args.args[1]->as_string().str(); + auto it = 
obj.find(key); + if (it != obj.end()) { + return it->second; + } else { + return mk_val(); + } + }}, + {"keys", [](const func_args & args) -> value { + args.ensure_vals(); + const auto & obj = args.args[0]->as_object(); + auto result = mk_val(); + for (const auto & pair : obj) { + result->push_back(mk_val(pair.first)); + } + return result; + }}, + {"values", [](const func_args & args) -> value { + args.ensure_vals(); + const auto & obj = args.args[0]->as_object(); + auto result = mk_val(); + for (const auto & pair : obj) { + result->push_back(pair.second); + } + return result; + }}, + {"items", [](const func_args & args) -> value { + args.ensure_vals(); + const auto & obj = args.args[0]->as_object(); + auto result = mk_val(); + for (const auto & pair : obj) { + auto item = mk_val(); + item->push_back(mk_val(pair.first)); + item->push_back(pair.second); + result->push_back(std::move(item)); + } + return result; + }}, + {"string", [](const func_args & args) -> value { + args.ensure_vals(); + return mk_val("TO BE IMPLEMENTED"); + }}, + {"tojson", [](const func_args & args) -> value { + args.ensure_vals(); + // use global to_json + return global_builtins().at("tojson")(args); + }}, + {"dictsort", [](const func_args & args) -> value { + // no-op + args.ensure_vals(); + return args.args[0]; + }}, + }; + return builtins; +} + +const func_builtins & value_null_t::get_builtins() const { + static const func_builtins builtins = { + // TODO: may need to implement this, idk + }; + return builtins; +} + + +////////////////////////////////// + + +static value from_json(const nlohmann::json & j) { + if (j.is_null()) { + return mk_val(); + } else if (j.is_boolean()) { + return mk_val(j.get()); + } else if (j.is_number_integer()) { + return mk_val(j.get()); + } else if (j.is_number_float()) { + return mk_val(j.get()); + } else if (j.is_string()) { + return mk_val(j.get()); + } else if (j.is_array()) { + auto arr = mk_val(); + for (const auto & item : j) { + 
arr->push_back(from_json(item)); + } + return arr; + } else if (j.is_object()) { + if (j.contains("__input__")) { + // handle input marking + auto str = mk_val(j.at("__input__").get()); + str->mark_input(); + return str; + } else { + // normal object + auto obj = mk_val(); + for (auto it = j.begin(); it != j.end(); ++it) { + obj->insert(it.key(), from_json(it.value())); + } + return obj; + } + } else { + throw std::runtime_error("Unsupported JSON value type"); + } +} + +template<> +void global_from_json(context & ctx, const nlohmann::json & json_obj) { + if (json_obj.is_null() || !json_obj.is_object()) { + throw std::runtime_error("global_from_json: input JSON value must be an object"); + } + for (auto it = json_obj.begin(); it != json_obj.end(); ++it) { + ctx.set_val(it.key(), from_json(it.value())); + } +} + +} // namespace jinja diff --git a/common/jinja/jinja-value.h b/common/jinja/jinja-value.h new file mode 100644 index 0000000000..9cb57f90f3 --- /dev/null +++ b/common/jinja/jinja-value.h @@ -0,0 +1,352 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include "jinja-string.h" + +namespace jinja { + +struct value_t; +using value = std::shared_ptr; + + +// Helper to check the type of a value +template +struct extract_pointee { + using type = T; +}; +template +struct extract_pointee> { + using type = U; +}; +template +bool is_val(const value & ptr) { + using PointeeType = typename extract_pointee::type; + return dynamic_cast(ptr.get()) != nullptr; +} +template +bool is_val(const value_t * ptr) { + using PointeeType = typename extract_pointee::type; + return dynamic_cast(ptr) != nullptr; +} +template +std::shared_ptr::type> mk_val(Args&&... 
args) { + using PointeeType = typename extract_pointee::type; + return std::make_shared(std::forward(args)...); +} +template +const typename extract_pointee::type * cast_val(const value & ptr) { + using PointeeType = typename extract_pointee::type; + return dynamic_cast(ptr.get()); +} +template +typename extract_pointee::type * cast_val(value & ptr) { + using PointeeType = typename extract_pointee::type; + return dynamic_cast(ptr.get()); +} +// End Helper + + +struct context; // forward declaration + + +// for converting from JSON to jinja values +// example input JSON: +// { +// "messages": [ +// {"role": "user", "content": "Hello!"}, +// {"role": "assistant", "content": "Hi there!"} +// ], +// "bos_token": "", +// "eos_token": "", +// } +// +// to mark strings as user input, wrap them in a special object: +// { +// "messages": [ +// { +// "role": "user", +// "content": {"__input__": "Hello!"} // this string is user input +// }, +// ... +// ], +// } +// +// marking input can be useful for tracking data provenance +// and preventing template injection attacks +// +// Note: T_JSON can be nlohmann::json or similar types +template +void global_from_json(context & ctx, const T_JSON & json_obj); + +// +// base value type +// + +struct func_args; // function argument values + +using func_handler = std::function; +using func_builtins = std::map; + +bool value_compare(const value & a, const value & b); + +struct value_t { + int64_t val_int; + double val_flt; + string val_str; + bool val_bool; + + std::vector val_arr; + std::map val_obj; + + func_handler val_func; + + // only used if ctx.is_get_stats = true + struct stats_t { + bool used = false; + // ops can be builtin calls or operators: "array_access", "object_access" + std::set ops; + } stats; + + value_t() = default; + value_t(const value_t &) = default; + virtual ~value_t() = default; + + virtual std::string type() const { return ""; } + + virtual int64_t as_int() const { throw std::runtime_error(type() + " is not an 
int value"); } + virtual double as_float() const { throw std::runtime_error(type() + " is not a float value"); } + virtual string as_string() const { throw std::runtime_error(type() + " is not a string value"); } + virtual bool as_bool() const { throw std::runtime_error(type() + " is not a bool value"); } + virtual const std::vector & as_array() const { throw std::runtime_error(type() + " is not an array value"); } + virtual const std::map & as_object() const { throw std::runtime_error(type() + " is not an object value"); } + virtual value invoke(const func_args &) const { throw std::runtime_error(type() + " is not a function value"); } + virtual bool is_null() const { return false; } + virtual bool is_undefined() const { return false; } + virtual const func_builtins & get_builtins() const { + throw std::runtime_error("No builtins available for type " + type()); + } + + virtual value & at(const std::string & key) { return val_obj[key]; } + virtual value & at(size_t index) { return val_arr.at(index); } + + virtual std::string as_repr() const { return as_string().str(); } +}; + +// +// primitive value types +// + +struct value_int_t : public value_t { + value_int_t(int64_t v) { val_int = v; } + virtual std::string type() const override { return "Integer"; } + virtual int64_t as_int() const override { return val_int; } + virtual double as_float() const override { return static_cast(val_int); } + virtual string as_string() const override { return std::to_string(val_int); } + virtual const func_builtins & get_builtins() const override; +}; +using value_int = std::shared_ptr; + + +struct value_float_t : public value_t { + value_float_t(double v) { val_flt = v; } + virtual std::string type() const override { return "Float"; } + virtual double as_float() const override { return val_flt; } + virtual int64_t as_int() const override { return static_cast(val_flt); } + virtual string as_string() const override { return std::to_string(val_flt); } + virtual const func_builtins & 
get_builtins() const override; +}; +using value_float = std::shared_ptr; + + +struct value_string_t : public value_t { + value_string_t() { val_str = string(); } + value_string_t(const std::string & v) { val_str = string(v); } + value_string_t(const string & v) { val_str = v; } + virtual std::string type() const override { return "String"; } + virtual string as_string() const override { return val_str; } + virtual std::string as_repr() const override { + std::ostringstream ss; + for (const auto & part : val_str.parts) { + ss << (part.is_input ? "INPUT: " : "TMPL: ") << part.val << "\n"; + } + return ss.str(); + } + virtual bool as_bool() const override { + return val_str.length() > 0; + } + virtual const func_builtins & get_builtins() const override; + void mark_input() { + val_str.mark_input(); + } +}; +using value_string = std::shared_ptr; + + +struct value_bool_t : public value_t { + value_bool_t(bool v) { val_bool = v; } + virtual std::string type() const override { return "Boolean"; } + virtual bool as_bool() const override { return val_bool; } + virtual string as_string() const override { return std::string(val_bool ? 
"True" : "False"); } + virtual const func_builtins & get_builtins() const override; +}; +using value_bool = std::shared_ptr; + + +struct value_array_t : public value_t { + value_array_t() = default; + value_array_t(value & v) { + // point to the same underlying data + val_arr = v->val_arr; + } + void push_back(const value & val) { + val_arr.push_back(val); + } + virtual std::string type() const override { return "Array"; } + virtual const std::vector & as_array() const override { return val_arr; } + virtual string as_string() const override { + std::ostringstream ss; + ss << "["; + for (size_t i = 0; i < val_arr.size(); i++) { + if (i > 0) ss << ", "; + ss << val_arr.at(i)->as_repr(); + } + ss << "]"; + return ss.str(); + } + virtual bool as_bool() const override { + return !val_arr.empty(); + } + virtual const func_builtins & get_builtins() const override; +}; +using value_array = std::shared_ptr; + + +struct value_object_t : public value_t { + value_object_t() = default; + value_object_t(value & v) { + // point to the same underlying data + val_obj = v->val_obj; + } + value_object_t(const std::map & obj) { + val_obj = std::map(); + for (const auto & pair : obj) { + val_obj[pair.first] = pair.second; + } + } + void insert(const std::string & key, const value & val) { + val_obj[key] = val; + } + virtual std::string type() const override { return "Object"; } + virtual const std::map & as_object() const override { return val_obj; } + virtual bool as_bool() const override { + return !val_obj.empty(); + } + virtual const func_builtins & get_builtins() const override; +}; +using value_object = std::shared_ptr; + +// +// null and undefined types +// + +struct value_null_t : public value_t { + virtual std::string type() const override { return "Null"; } + virtual bool is_null() const override { return true; } + virtual bool as_bool() const override { return false; } + virtual std::string as_repr() const override { return type(); } + virtual const func_builtins & 
get_builtins() const override; +}; +using value_null = std::shared_ptr; + + +struct value_undefined_t : public value_t { + std::string hint; // for debugging, to indicate where undefined came from + value_undefined_t(const std::string & h = "") : hint(h) {} + virtual std::string type() const override { return hint.empty() ? "Undefined" : "Undefined (hint: '" + hint + "')"; } + virtual bool is_undefined() const override { return true; } + virtual bool as_bool() const override { return false; } + virtual std::string as_repr() const override { return type(); } +}; +using value_undefined = std::shared_ptr; + +// +// function type +// + +struct func_args { + std::string func_name; // for error messages + std::vector args; + context & ctx; + func_args(context & ctx) : ctx(ctx) {} + value get_kwarg(const std::string & key) const; + void ensure_count(size_t min, size_t max = 999) const { + size_t n = args.size(); + if (n < min || n > max) { + throw std::runtime_error("Function '" + func_name + "' expected between " + std::to_string(min) + " and " + std::to_string(max) + " arguments, got " + std::to_string(n)); + } + } + template void ensure_val(const value & ptr) const { + if (!is_val(ptr)) { + throw std::runtime_error("Function '" + func_name + "' expected value of type " + std::string(typeid(T).name()) + ", got " + ptr->type()); + } + } + template void ensure_vals(bool required0 = true) const { + if (required0 && args.size() > 0) ensure_val(args[0]); + } + template void ensure_vals(bool required0 = true, bool required1 = true) const { + if (required0 && args.size() > 0) ensure_val(args[0]); + if (required1 && args.size() > 1) ensure_val(args[1]); + } + template void ensure_vals(bool required0 = true, bool required1 = true, bool required2 = true) const { + if (required0 && args.size() > 0) ensure_val(args[0]); + if (required1 && args.size() > 1) ensure_val(args[1]); + if (required2 && args.size() > 2) ensure_val(args[2]); + } +}; + +struct value_func_t : public value_t { 
+ std::string name; + value arg0; // bound "this" argument, if any + value_func_t(const std::string & name, const func_handler & func) : name(name) { + val_func = func; + } + value_func_t(const std::string & name, const func_handler & func, const value & arg_this) : name(name), arg0(arg_this) { + val_func = func; + } + virtual value invoke(const func_args & args) const override { + func_args new_args(args); // copy + new_args.func_name = name; + if (arg0) { + new_args.args.insert(new_args.args.begin(), arg0); + } + return val_func(new_args); + } + virtual std::string type() const override { return "Function"; } + virtual std::string as_repr() const override { return type(); } +}; +using value_func = std::shared_ptr; + +// special value for kwarg +struct value_kwarg_t : public value_t { + std::string key; + value val; + value_kwarg_t(const std::string & k, const value & v) : key(k), val(v) {} + virtual std::string type() const override { return "KwArg"; } + virtual std::string as_repr() const override { return type(); } +}; +using value_kwarg = std::shared_ptr; + + +// utils + +const func_builtins & global_builtins(); + + +} // namespace jinja diff --git a/common/jinja/jinja-vm.cpp b/common/jinja/jinja-vm.cpp new file mode 100644 index 0000000000..0728054c13 --- /dev/null +++ b/common/jinja/jinja-vm.cpp @@ -0,0 +1,794 @@ +#include "jinja-lexer.h" +#include "jinja-vm.h" +#include "jinja-parser.h" +#include "jinja-value.h" +#include "jinja-utils.h" + +#include +#include +#include +#include + +#define FILENAME "jinja-vm" + +bool g_jinja_debug = false; + +namespace jinja { + +void enable_debug(bool enable) { + g_jinja_debug = enable; +} + +static value_string exec_statements(const statements & stmts, context & ctx) { + auto result = mk_val(); + for (const auto & stmt : stmts) { + JJ_DEBUG("Executing statement of type %s", stmt->type().c_str()); + result->push_back(stmt->execute(ctx)); + } + // convert to string parts + value_string str = mk_val(); + 
gather_string_parts_recursive(result, str); + return str; +} + +// execute with error handling +value statement::execute(context & ctx) { + try { + return execute_impl(ctx); + } catch (const continue_statement::signal & ex) { + throw ex; + } catch (const break_statement::signal & ex) { + throw ex; + } catch (const std::exception & e) { + if (ctx.source.empty()) { + std::ostringstream oss; + oss << "\nError executing " << type() << " at position " << pos << ": " << e.what(); + throw raised_exception(oss.str()); + } else { + std::ostringstream oss; + constexpr int max_peak_chars = 40; + oss << "\n------------\n"; + oss << "While executing " << type() << " at position " << pos << " in source:\n"; + size_t start = (pos >= max_peak_chars) ? (pos - max_peak_chars) : 0; + size_t end = std::min(pos + max_peak_chars, ctx.source.length()); + std::string substr = ctx.source.substr(start, end - start); + string_replace_all(substr, "\n", "\\n"); + oss << "..." << substr << "...\n"; + std::string spaces(pos - start + 3, ' '); + oss << spaces << "^\n"; + oss << "Error: " << e.what(); + throw raised_exception(oss.str()); + } + } +} + +value identifier::execute_impl(context & ctx) { + auto it = ctx.get_val(val); + auto builtins = global_builtins(); + if (!it->is_undefined()) { + if (ctx.is_get_stats) { + it->stats.used = true; + } + JJ_DEBUG("Identifier '%s' found", val.c_str()); + return it; + } else if (builtins.find(val) != builtins.end()) { + JJ_DEBUG("Identifier '%s' found in builtins", val.c_str()); + return mk_val(val, builtins.at(val)); + } else { + JJ_DEBUG("Identifier '%s' not found, returning undefined", val.c_str()); + return mk_val(val); + } +} + +value object_literal::execute_impl(context & ctx) { + auto obj = mk_val(); + for (const auto & pair : val) { + std::string key = pair.first->execute(ctx)->as_string().str(); + value val = pair.second->execute(ctx); + JJ_DEBUG("Object literal: setting key '%s' of type %s", key.c_str(), val->type().c_str()); + obj->val_obj[key] 
= val; + } + return obj; +} + +value binary_expression::execute_impl(context & ctx) { + value left_val = left->execute(ctx); + + // Logical operators + if (op.value == "and") { + return left_val->as_bool() ? right->execute(ctx) : std::move(left_val); + } else if (op.value == "or") { + return left_val->as_bool() ? std::move(left_val) : right->execute(ctx); + } + + // Equality operators + value right_val = right->execute(ctx); + JJ_DEBUG("Executing binary expression %s '%s' %s", left_val->type().c_str(), op.value.c_str(), right_val->type().c_str()); + if (op.value == "==") { + return mk_val(value_compare(left_val, right_val)); + } else if (op.value == "!=") { + return mk_val(!value_compare(left_val, right_val)); + } + + // Handle undefined and null values + if (is_val(left_val) || is_val(right_val)) { + if (is_val(right_val) && (op.value == "in" || op.value == "not in")) { + // Special case: `anything in undefined` is `false` and `anything not in undefined` is `true` + return mk_val(op.value == "not in"); + } + // if (ctx.wrk_around.string_plus_undefined_is_string && (op.value == "+" || op.value == "~")) { + // JJ_DEBUG("%s", "Workaround: treating undefined as empty string for string concatenation"); + // auto left_str = left_val->is_undefined() ? string() : left_val->as_string(); + // auto right_str = right_val->is_undefined() ? 
string() : right_val->as_string(); + // auto output = left_str.append(right_str); + // auto res = mk_val(); + // res->val_str = std::move(output); + // return res; + // } + throw std::runtime_error("Cannot perform operation " + op.value + " on undefined values"); + } else if (is_val(left_val) || is_val(right_val)) { + throw std::runtime_error("Cannot perform operation on null values"); + } + + // Float operations + if ((is_val(left_val) || is_val(left_val)) && + (is_val(right_val) || is_val(right_val))) { + double a = left_val->as_float(); + double b = right_val->as_float(); + if (op.value == "+" || op.value == "-" || op.value == "*") { + double res = (op.value == "+") ? a + b : (op.value == "-") ? a - b : a * b; + JJ_DEBUG("Arithmetic operation: %f %s %f = %f", a, op.value.c_str(), b, res); + bool is_float = is_val(left_val) || is_val(right_val); + if (is_float) { + return mk_val(res); + } else { + return mk_val(static_cast(res)); + } + } else if (op.value == "/") { + JJ_DEBUG("Division operation: %f / %f", a, b); + return mk_val(a / b); + } else if (op.value == "%") { + double rem = std::fmod(a, b); + JJ_DEBUG("Modulo operation: %f %% %f = %f", a, b, rem); + bool is_float = is_val(left_val) || is_val(right_val); + if (is_float) { + return mk_val(rem); + } else { + return mk_val(static_cast(rem)); + } + } else if (op.value == "<") { + JJ_DEBUG("Comparison operation: %f < %f is %d", a, b, a < b); + return mk_val(a < b); + } else if (op.value == ">") { + JJ_DEBUG("Comparison operation: %f > %f is %d", a, b, a > b); + return mk_val(a > b); + } else if (op.value == ">=") { + JJ_DEBUG("Comparison operation: %f >= %f is %d", a, b, a >= b); + return mk_val(a >= b); + } else if (op.value == "<=") { + JJ_DEBUG("Comparison operation: %f <= %f is %d", a, b, a <= b); + return mk_val(a <= b); + } + } + + // Array operations + if (is_val(left_val) && is_val(right_val)) { + if (op.value == "+") { + auto & left_arr = left_val->as_array(); + auto & right_arr = 
right_val->as_array(); + auto result = mk_val(); + for (const auto & item : left_arr) { + result->push_back(item); + } + for (const auto & item : right_arr) { + result->push_back(item); + } + return result; + } + } else if (is_val(right_val)) { + auto & arr = right_val->as_array(); + bool member = false; + for (const auto & item : arr) { + if (value_compare(left_val, item)) { + member = true; + break; + } + } + if (op.value == "in") { + JJ_DEBUG("Checking membership: %s in Array is %d", left_val->type().c_str(), member); + return mk_val(member); + } else if (op.value == "not in") { + JJ_DEBUG("Checking non-membership: %s not in Array is %d", left_val->type().c_str(), !member); + return mk_val(!member); + } + } + + // String concatenation with ~ and + + if ((is_val(left_val) || is_val(right_val)) && + (op.value == "~" || op.value == "+")) { + JJ_DEBUG("String concatenation with %s operator", op.value.c_str()); + auto output = left_val->as_string().append(right_val->as_string()); + auto res = mk_val(); + res->val_str = std::move(output); + return res; + } + + // String membership + if (is_val(left_val) && is_val(right_val)) { + auto left_str = left_val->as_string().str(); + auto right_str = right_val->as_string().str(); + if (op.value == "in") { + return mk_val(right_str.find(left_str) != std::string::npos); + } else if (op.value == "not in") { + return mk_val(right_str.find(left_str) == std::string::npos); + } + } + + // String in object + if (is_val(left_val) && is_val(right_val)) { + auto key = left_val->as_string().str(); + auto & obj = right_val->as_object(); + bool has_key = obj.find(key) != obj.end(); + if (op.value == "in") { + return mk_val(has_key); + } else if (op.value == "not in") { + return mk_val(!has_key); + } + } + + throw std::runtime_error("Unknown operator \"" + op.value + "\" between " + left_val->type() + " and " + right_val->type()); +} + +static value try_builtin_func(context & ctx, const std::string & name, value & input, bool 
undef_on_missing = false) { + JJ_DEBUG("Trying built-in function '%s' for type %s", name.c_str(), input->type().c_str()); + if (ctx.is_get_stats) { + input->stats.used = true; + input->stats.ops.insert(name); + } + auto builtins = input->get_builtins(); + auto it = builtins.find(name); + if (it != builtins.end()) { + JJ_DEBUG("Binding built-in '%s'", name.c_str()); + return mk_val(name, it->second, input); + } + if (undef_on_missing) { + return mk_val(name); + } + throw std::runtime_error("Unknown (built-in) filter '" + name + "' for type " + input->type()); +} + +value filter_expression::execute_impl(context & ctx) { + value input = operand ? operand->execute(ctx) : val; + + JJ_DEBUG("Applying filter to %s", input->type().c_str()); + + if (is_stmt(filter)) { + auto filter_id = cast_stmt(filter)->val; + + if (filter_id == "to_json") { + // TODO: Implement to_json filter + throw std::runtime_error("to_json filter not implemented"); + } + + if (filter_id == "trim") { + filter_id = "strip"; // alias + } + JJ_DEBUG("Applying filter '%s' to %s", filter_id.c_str(), input->type().c_str()); + return try_builtin_func(ctx, filter_id, input)->invoke(func_args(ctx)); + + } else if (is_stmt(filter)) { + auto call = cast_stmt(filter); + auto filter_id = cast_stmt(call->callee)->val; + + JJ_DEBUG("Applying filter '%s' with arguments to %s", filter_id.c_str(), input->type().c_str()); + func_args args(ctx); + for (const auto & arg_expr : call->args) { + args.args.push_back(arg_expr->execute(ctx)); + } + + return try_builtin_func(ctx, filter_id, input)->invoke(args); + + } else { + throw std::runtime_error("Invalid filter expression"); + } +} + +value filter_statement::execute_impl(context & ctx) { + // eval body as string, then apply filter + auto body_val = exec_statements(body, ctx); + value_string parts = mk_val(); + gather_string_parts_recursive(body_val, parts); + + JJ_DEBUG("FilterStatement: applying filter to body string of length %zu", parts->val_str.length()); + 
filter_expression filter_expr(std::move(parts), std::move(filter)); + return filter_expr.execute(ctx); +} + +value test_expression::execute_impl(context & ctx) { + // NOTE: "value is something" translates to function call "test_is_something(value)" + const auto & builtins = global_builtins(); + if (!is_stmt(test)) { + throw std::runtime_error("Invalid test expression"); + } + + auto test_id = cast_stmt(test)->val; + auto it = builtins.find("test_is_" + test_id); + JJ_DEBUG("Test expression %s '%s' %s (using function 'test_is_%s')", operand->type().c_str(), test_id.c_str(), negate ? "(negate)" : "", test_id.c_str()); + if (it == builtins.end()) { + throw std::runtime_error("Unknown test '" + test_id + "'"); + } + + value input = operand->execute(ctx); + + func_args args(ctx); + args.args.push_back(input); + auto res = it->second(args); + + if (negate) { + return mk_val(!res->as_bool()); + } else { + return res; + } +} + +value unary_expression::execute_impl(context & ctx) { + value operand_val = argument->execute(ctx); + JJ_DEBUG("Executing unary expression with operator '%s'", op.value.c_str()); + + if (op.value == "not") { + return mk_val(!operand_val->as_bool()); + } else if (op.value == "-") { + if (is_val(operand_val)) { + return mk_val(-operand_val->as_int()); + } else if (is_val(operand_val)) { + return mk_val(-operand_val->as_float()); + } else { + throw std::runtime_error("Unary - operator requires numeric operand"); + } + } + + throw std::runtime_error("Unknown unary operator '" + op.value + "'"); +} + +value if_statement::execute_impl(context & ctx) { + value test_val = test->execute(ctx); + + auto out = mk_val(); + if (test_val->as_bool()) { + for (auto & stmt : body) { + JJ_DEBUG("IF --> Executing THEN body, current block: %s", stmt->type().c_str()); + out->push_back(stmt->execute(ctx)); + } + } else { + for (auto & stmt : alternate) { + JJ_DEBUG("IF --> Executing ELSE body, current block: %s", stmt->type().c_str()); + 
out->push_back(stmt->execute(ctx)); + } + } + // convert to string parts + value_string str = mk_val(); + gather_string_parts_recursive(out, str); + return str; +} + +value for_statement::execute_impl(context & ctx) { + context scope(ctx); // new scope for loop variables + + jinja::select_expression * select_expr = cast_stmt(iterable); + statement_ptr test_expr_nullptr; + + statement_ptr & iter_expr = [&]() -> statement_ptr & { + auto tmp = cast_stmt(iterable); + return tmp ? tmp->lhs : iterable; + }(); + statement_ptr & test_expr = [&]() -> statement_ptr & { + auto tmp = cast_stmt(iterable); + return tmp ? tmp->test : test_expr_nullptr; + }(); + + JJ_DEBUG("Executing for statement, iterable type: %s", iter_expr->type().c_str()); + + value iterable_val = iter_expr->execute(scope); + + if (iterable_val->is_undefined()) { + JJ_DEBUG("%s", "For loop iterable is undefined, skipping loop"); + iterable_val = mk_val(); + } + + if (!is_val(iterable_val) && !is_val(iterable_val)) { + throw std::runtime_error("Expected iterable or object type in for loop: got " + iterable_val->type()); + } + + std::vector items; + if (is_val(iterable_val)) { + JJ_DEBUG("%s", "For loop over object keys"); + auto & obj = iterable_val->as_object(); + for (auto & p : obj) { + auto tuple = mk_val(); + tuple->push_back(mk_val(p.first)); + tuple->push_back(p.second); + items.push_back(tuple); + } + if (ctx.is_get_stats) { + iterable_val->stats.used = true; + iterable_val->stats.ops.insert("object_access"); + } + } else { + JJ_DEBUG("%s", "For loop over array items"); + auto & arr = iterable_val->as_array(); + for (const auto & item : arr) { + items.push_back(item); + } + if (ctx.is_get_stats) { + iterable_val->stats.used = true; + iterable_val->stats.ops.insert("array_access"); + } + } + + std::vector> scope_update_fns; + + std::vector filtered_items; + for (size_t i = 0; i < items.size(); ++i) { + context loop_scope(scope); + + const value & current = items[i]; + + std::function scope_update_fn = 
[](context &) { /* no-op */}; + if (is_stmt(loopvar)) { + auto id = cast_stmt(loopvar)->val; + scope_update_fn = [id, &items, i](context & ctx) { + ctx.set_val(id, items[i]); + }; + } else if (is_stmt(loopvar)) { + auto tuple = cast_stmt(loopvar); + if (!is_val(current)) { + throw std::runtime_error("Cannot unpack non-iterable type: " + current->type()); + } + auto & c_arr = current->as_array(); + if (tuple->val.size() != c_arr.size()) { + throw std::runtime_error(std::string("Too ") + (tuple->val.size() > c_arr.size() ? "few" : "many") + " items to unpack"); + } + scope_update_fn = [tuple, &items, i](context & ctx) { + auto & c_arr = items[i]->as_array(); + for (size_t j = 0; j < tuple->val.size(); ++j) { + if (!is_stmt(tuple->val[j])) { + throw std::runtime_error("Cannot unpack non-identifier type: " + tuple->val[j]->type()); + } + auto id = cast_stmt(tuple->val[j])->val; + ctx.set_val(id, c_arr[j]); + } + }; + } else { + throw std::runtime_error("Invalid loop variable(s): " + loopvar->type()); + } + if (select_expr && test_expr) { + scope_update_fn(loop_scope); + value test_val = test_expr->execute(loop_scope); + if (!test_val->as_bool()) { + continue; + } + } + JJ_DEBUG("For loop: adding item type %s at index %zu", current->type().c_str(), i); + filtered_items.push_back(current); + scope_update_fns.push_back(scope_update_fn); + } + JJ_DEBUG("For loop: %zu items after filtering", filtered_items.size()); + + auto result = mk_val(); + + bool noIteration = true; + for (size_t i = 0; i < filtered_items.size(); i++) { + JJ_DEBUG("For loop iteration %zu/%zu", i + 1, filtered_items.size()); + value_object loop_obj = mk_val(); + loop_obj->insert("index", mk_val(i + 1)); + loop_obj->insert("index0", mk_val(i)); + loop_obj->insert("revindex", mk_val(filtered_items.size() - i)); + loop_obj->insert("revindex0", mk_val(filtered_items.size() - i - 1)); + loop_obj->insert("first", mk_val(i == 0)); + loop_obj->insert("last", mk_val(i == filtered_items.size() - 1)); + 
loop_obj->insert("length", mk_val(filtered_items.size())); + loop_obj->insert("previtem", i > 0 ? filtered_items[i - 1] : mk_val("previtem")); + loop_obj->insert("nextitem", i < filtered_items.size() - 1 ? filtered_items[i + 1] : mk_val("nextitem")); + scope.set_val("loop", loop_obj); + scope_update_fns[i](scope); + try { + for (auto & stmt : body) { + value val = stmt->execute(scope); + result->push_back(val); + } + } catch (const continue_statement::signal &) { + continue; + } catch (const break_statement::signal &) { + break; + } + noIteration = false; + } + + JJ_DEBUG("For loop complete, total iterations: %zu", filtered_items.size()); + if (noIteration) { + for (auto & stmt : default_block) { + value val = stmt->execute(ctx); + result->push_back(val); + } + } + + // convert to string parts + value_string str = mk_val(); + gather_string_parts_recursive(result, str); + return str; +} + +value set_statement::execute_impl(context & ctx) { + auto rhs = val ? val->execute(ctx) : exec_statements(body, ctx); + + if (is_stmt(assignee)) { + auto var_name = cast_stmt(assignee)->val; + JJ_DEBUG("Setting variable '%s' with value type %s", var_name.c_str(), rhs->type().c_str()); + ctx.set_val(var_name, rhs); + + } else if (is_stmt(assignee)) { + auto tuple = cast_stmt(assignee); + if (!is_val(rhs)) { + throw std::runtime_error("Cannot unpack non-iterable type in set: " + rhs->type()); + } + auto & arr = rhs->as_array(); + if (arr.size() != tuple->val.size()) { + throw std::runtime_error(std::string("Too ") + (tuple->val.size() > arr.size() ? 
"few" : "many") + " items to unpack in set"); + } + for (size_t i = 0; i < tuple->val.size(); ++i) { + auto & elem = tuple->val[i]; + if (!is_stmt(elem)) { + throw std::runtime_error("Cannot unpack to non-identifier in set: " + elem->type()); + } + auto var_name = cast_stmt(elem)->val; + ctx.set_val(var_name, arr[i]); + } + + } else if (is_stmt(assignee)) { + auto member = cast_stmt(assignee); + if (member->computed) { + throw std::runtime_error("Cannot assign to computed member"); + } + if (!is_stmt(member->property)) { + throw std::runtime_error("Cannot assign to member with non-identifier property"); + } + auto prop_name = cast_stmt(member->property)->val; + + value object = member->object->execute(ctx); + if (!is_val(object)) { + throw std::runtime_error("Cannot assign to member of non-object"); + } + auto obj_ptr = cast_val(object); + JJ_DEBUG("Setting object property '%s'", prop_name.c_str()); + obj_ptr->insert(prop_name, rhs); + + } else { + throw std::runtime_error("Invalid LHS inside assignment expression: " + assignee->type()); + } + return mk_val(); +} + +value macro_statement::execute_impl(context & ctx) { + std::string name = cast_stmt(this->name)->val; + + const func_handler func = [this, name, &ctx](const func_args & args) -> value { + size_t expected_count = this->args.size(); + size_t input_count = args.args.size(); + + JJ_DEBUG("Invoking macro '%s' with %zu input arguments (expected %zu)", name.c_str(), input_count, expected_count); + context macro_ctx(ctx); // new scope for macro execution + + // bind parameters + for (size_t i = 0; i < expected_count; ++i) { + if (i < input_count) { + if (is_stmt(this->args[i])) { + // normal parameter + std::string param_name = cast_stmt(this->args[i])->val; + JJ_DEBUG(" Binding parameter '%s' to argument of type %s", param_name.c_str(), args.args[i]->type().c_str()); + macro_ctx.set_val(param_name, args.args[i]); + } else if (is_stmt(this->args[i])) { + // default argument used as normal parameter + auto kwarg 
= cast_stmt(this->args[i]); + std::string param_name = cast_stmt(kwarg->key)->val; + JJ_DEBUG(" Binding parameter '%s' to argument of type %s", param_name.c_str(), args.args[i]->type().c_str()); + macro_ctx.set_val(param_name, args.args[i]); + } else { + throw std::runtime_error("Invalid parameter type in macro '" + name + "'"); + } + } else { + auto & default_arg = this->args[i]; + if (is_stmt(default_arg)) { + auto kwarg = cast_stmt(default_arg); + std::string param_name = cast_stmt(kwarg->key)->val; + JJ_DEBUG(" Binding parameter '%s' to default argument of type %s", param_name.c_str(), kwarg->val->type().c_str()); + macro_ctx.set_val(param_name, kwarg->val->execute(ctx)); + } else { + throw std::runtime_error("Not enough arguments provided to macro '" + name + "'"); + } + //std::string param_name = cast_stmt(default_args[i])->val; + //JJ_DEBUG(" Binding parameter '%s' to default", param_name.c_str()); + //macro_ctx.var[param_name] = default_args[i]->execute(ctx); + } + } + + // execute macro body + JJ_DEBUG("Executing macro '%s' body with %zu statements", name.c_str(), this->body.size()); + auto res = exec_statements(this->body, macro_ctx); + JJ_DEBUG("Macro '%s' execution complete, result: %s", name.c_str(), res->val_str.str().c_str()); + return res; + }; + + JJ_DEBUG("Defining macro '%s' with %zu parameters", name.c_str(), args.size()); + ctx.set_val(name, mk_val(name, func)); + return mk_val(); +} + +value member_expression::execute_impl(context & ctx) { + value object = this->object->execute(ctx); + + value property; + if (this->computed) { + JJ_DEBUG("Member expression, computing property type %s", this->property->type().c_str()); + if (is_stmt(this->property)) { + auto s = cast_stmt(this->property); + value start_val = s->start_expr ? s->start_expr->execute(ctx) : mk_val("start"); + value stop_val = s->stop_expr ? s->stop_expr->execute(ctx) : mk_val("stop"); + value step_val = s->step_expr ? 
s->step_expr->execute(ctx) : mk_val("step"); + + // translate to function call: obj.slice(start, stop, step) + JJ_DEBUG("Member expression is a slice: start %s, stop %s, step %s", + start_val->as_repr().c_str(), + stop_val->as_repr().c_str(), + step_val->as_repr().c_str()); + auto slice_func = try_builtin_func(ctx, "slice", object); + func_args args(ctx); + args.args.push_back(start_val); + args.args.push_back(stop_val); + args.args.push_back(step_val); + return slice_func->invoke(args); + } else { + property = this->property->execute(ctx); + } + } else { + property = mk_val(cast_stmt(this->property)->val); + } + + JJ_DEBUG("Member expression on object type %s, property type %s", object->type().c_str(), property->type().c_str()); + + value val = mk_val("object_property"); + + if (is_val(object)) { + JJ_DEBUG("%s", "Accessing property on undefined object, returning undefined"); + return val; + } else if (is_val(object)) { + if (!is_val(property)) { + throw std::runtime_error("Cannot access object with non-string: got " + property->type()); + } + auto key = property->as_string().str(); + auto & obj = object->as_object(); + auto it = obj.find(key); + if (it != obj.end()) { + val = it->second; + } else { + val = try_builtin_func(ctx, key, object, true); + } + JJ_DEBUG("Accessed property '%s' value, got type: %s", key.c_str(), val->type().c_str()); + + } else if (is_val(object) || is_val(object)) { + if (is_val(property)) { + int64_t index = property->as_int(); + JJ_DEBUG("Accessing %s index %lld", object->type().c_str(), index); + if (is_val(object)) { + auto & arr = object->as_array(); + if (index < 0) { + index += static_cast(arr.size()); + } + if (index >= 0 && index < static_cast(arr.size())) { + val = arr[index]; + } + } else { // value_string + auto str = object->as_string().str(); + if (index >= 0 && index < static_cast(str.size())) { + val = mk_val(std::string(1, str[index])); + } + } + + } else if (is_val(property)) { + auto key = property->as_string().str(); 
+ JJ_DEBUG("Accessing %s built-in '%s'", is_val(object) ? "array" : "string", key.c_str()); + val = try_builtin_func(ctx, key, object); + } else { + throw std::runtime_error("Cannot access property with non-string/non-number: got " + property->type()); + } + + } else { + if (!is_val(property)) { + throw std::runtime_error("Cannot access property with non-string: got " + property->type()); + } + auto key = property->as_string().str(); + val = try_builtin_func(ctx, key, object); + } + + if (ctx.is_get_stats && val && object && property) { + val->stats.used = true; + object->stats.used = true; + if (is_val(property)) { + object->stats.ops.insert("array_access"); + } else if (is_val(property)) { + object->stats.ops.insert("object_access"); + } + } + + return val; +} + +value call_expression::execute_impl(context & ctx) { + // gather arguments + func_args args(ctx); + for (auto & arg_stmt : this->args) { + auto arg_val = arg_stmt->execute(ctx); + JJ_DEBUG(" Argument type: %s", arg_val->type().c_str()); + args.args.push_back(std::move(arg_val)); + } + // execute callee + value callee_val = callee->execute(ctx); + if (!is_val(callee_val)) { + throw std::runtime_error("Callee is not a function: got " + callee_val->type()); + } + auto * callee_func = cast_val(callee_val); + JJ_DEBUG("Calling function '%s' with %zu arguments", callee_func->name.c_str(), args.args.size()); + return callee_func->invoke(args); +} + +// compare operator for value_t +bool value_compare(const value & a, const value & b) { + auto cmp = [&]() { + // compare numeric types + if ((is_val(a) || is_val(a)) && + (is_val(b) || is_val(b))){ + try { + return a->as_float() == b->as_float(); + } catch (...) {} + } + // compare string and number + // TODO: not sure if this is the right behavior + if ((is_val(b) && (is_val(a) || is_val(a))) || + (is_val(a) && (is_val(b) || is_val(b)))) { + try { + return a->as_string().str() == b->as_string().str(); + } catch (...) 
{} + } + // compare boolean simple + if (is_val(a) && is_val(b)) { + return a->as_bool() == b->as_bool(); + } + // compare string simple + if (is_val(a) && is_val(b)) { + return a->as_string().str() == b->as_string().str(); + } + // compare by type + if (a->type() != b->type()) { + return false; + } + return false; + }; + auto result = cmp(); + JJ_DEBUG("Comparing types: %s and %s result=%d", a->type().c_str(), b->type().c_str(), result); + return result; +} + +value keyword_argument_expression::execute_impl(context & ctx) { + if (!is_stmt(key)) { + throw std::runtime_error("Keyword argument key must be identifiers"); + } + + std::string k = cast_stmt(key)->val; + JJ_DEBUG("Keyword argument expression key: %s, value: %s", k.c_str(), val->type().c_str()); + + value v = val->execute(ctx); + JJ_DEBUG("Keyword argument value executed, type: %s", v->type().c_str()); + + return mk_val(k, v); +} + +} // namespace jinja diff --git a/common/jinja/jinja-vm.h b/common/jinja/jinja-vm.h new file mode 100644 index 0000000000..93c3ca91a5 --- /dev/null +++ b/common/jinja/jinja-vm.h @@ -0,0 +1,604 @@ +#pragma once + +#include "jinja-lexer.h" +#include "jinja-value.h" + +#include +#include +#include +#include +#include + +#define JJ_DEBUG(msg, ...) 
if (g_jinja_debug) printf("%s:%-3d : " msg "\n", FILENAME, __LINE__, __VA_ARGS__) + +extern bool g_jinja_debug; + +namespace jinja { + +struct statement; +using statement_ptr = std::unique_ptr; +using statements = std::vector; + +// Helpers for dynamic casting and type checking +template +struct extract_pointee_unique { + using type = T; +}; +template +struct extract_pointee_unique> { + using type = U; +}; +template +bool is_stmt(const statement_ptr & ptr) { + return dynamic_cast(ptr.get()) != nullptr; +} +template +T * cast_stmt(statement_ptr & ptr) { + return dynamic_cast(ptr.get()); +} +template +const T * cast_stmt(const statement_ptr & ptr) { + return dynamic_cast(ptr.get()); +} +// End Helpers + + +// not thread-safe +void enable_debug(bool enable); + +struct context { + std::string source; // for debugging + std::time_t current_time; // for functions that need current time + + bool is_get_stats = false; // whether to collect stats + + context() { + global = mk_val(); + global->insert("true", mk_val(true)); + global->insert("false", mk_val(false)); + global->insert("none", mk_val()); + current_time = std::time(nullptr); + } + ~context() = default; + + context(const context & parent) : context() { + // inherit variables (for example, when entering a new scope) + auto & pvar = parent.global->as_object(); + for (const auto & pair : pvar) { + set_val(pair.first, pair.second); + } + current_time = parent.current_time; + is_get_stats = parent.is_get_stats; + } + + value get_val(const std::string & name) { + auto it = global->val_obj.find(name); + if (it != global->val_obj.end()) { + return it->second; + } else { + return mk_val(name); + } + } + + void set_val(const std::string & name, const value & val) { + global->insert(name, val); + } + +private: + value_object global; +}; + +/** + * Base class for all nodes in the AST. 
+ */ +struct statement { + size_t pos; // position in source, for debugging + virtual ~statement() = default; + virtual std::string type() const { return "Statement"; } + // execute_impl must be overridden by derived classes + virtual value execute_impl(context &) { throw std::runtime_error("cannot exec " + type()); } + // execute is the public method to execute a statement with error handling + value execute(context &); +}; + +// Type Checking Utilities + +template +static void chk_type(const statement_ptr & ptr) { + if (!ptr) return; // Allow null for optional fields + assert(dynamic_cast(ptr.get()) != nullptr); +} + +template +static void chk_type(const statement_ptr & ptr) { + if (!ptr) return; + assert(dynamic_cast(ptr.get()) != nullptr || dynamic_cast(ptr.get()) != nullptr); +} + +// Base Types + +/** + * Expressions will result in a value at runtime (unlike statements). + */ +struct expression : public statement { + std::string type() const override { return "Expression"; } +}; + +// Statements + +struct program : public statement { + statements body; + + program() = default; + explicit program(statements && body) : body(std::move(body)) {} + std::string type() const override { return "Program"; } + value execute_impl(context &) override { + throw std::runtime_error("Cannot execute program directly, use jinja::vm instead"); + } +}; + +struct if_statement : public statement { + statement_ptr test; + statements body; + statements alternate; + + if_statement(statement_ptr && test, statements && body, statements && alternate) + : test(std::move(test)), body(std::move(body)), alternate(std::move(alternate)) { + chk_type(this->test); + } + + std::string type() const override { return "If"; } + value execute_impl(context & ctx) override; +}; + +struct identifier; +struct tuple_literal; + +/** + * Loop over each item in a sequence + * https://jinja.palletsprojects.com/en/3.0.x/templates/#for + */ +struct for_statement : public statement { + statement_ptr loopvar; // 
Identifier | TupleLiteral + statement_ptr iterable; + statements body; + statements default_block; // if no iteration took place + + for_statement(statement_ptr && loopvar, statement_ptr && iterable, statements && body, statements && default_block) + : loopvar(std::move(loopvar)), iterable(std::move(iterable)), + body(std::move(body)), default_block(std::move(default_block)) { + chk_type(this->loopvar); + chk_type(this->iterable); + } + + std::string type() const override { return "For"; } + value execute_impl(context & ctx) override; +}; + +struct break_statement : public statement { + std::string type() const override { return "Break"; } + + struct signal : public std::exception { + const char* what() const noexcept override { + return "Break statement executed"; + } + }; + + value execute_impl(context &) override { + throw break_statement::signal(); + } +}; + +struct continue_statement : public statement { + std::string type() const override { return "Continue"; } + + struct signal : public std::exception { + const char* what() const noexcept override { + return "Continue statement executed"; + } + }; + + value execute_impl(context &) override { + throw continue_statement::signal(); + } +}; + +// do nothing +struct noop_statement : public statement { + std::string type() const override { return "Noop"; } + value execute_impl(context &) override { + return mk_val(); + } +}; + +struct set_statement : public statement { + statement_ptr assignee; + statement_ptr val; + statements body; + + set_statement(statement_ptr && assignee, statement_ptr && value, statements && body) + : assignee(std::move(assignee)), val(std::move(value)), body(std::move(body)) { + chk_type(this->assignee); + chk_type(this->val); + } + + std::string type() const override { return "Set"; } + value execute_impl(context & ctx) override; +}; + +struct macro_statement : public statement { + statement_ptr name; + statements args; + statements body; + + macro_statement(statement_ptr && name, 
statements && args, statements && body) + : name(std::move(name)), args(std::move(args)), body(std::move(body)) { + chk_type(this->name); + for (const auto& arg : this->args) chk_type(arg); + } + + std::string type() const override { return "Macro"; } + value execute_impl(context & ctx) override; +}; + +struct comment_statement : public statement { + std::string val; + explicit comment_statement(const std::string & v) : val(v) {} + std::string type() const override { return "Comment"; } + value execute_impl(context &) override { + return mk_val(); + } +}; + +// Expressions + +struct member_expression : public expression { + statement_ptr object; + statement_ptr property; + bool computed; + + member_expression(statement_ptr && object, statement_ptr && property, bool computed) + : object(std::move(object)), property(std::move(property)), computed(computed) { + chk_type(this->object); + chk_type(this->property); + } + std::string type() const override { return "MemberExpression"; } + value execute_impl(context & ctx) override; +}; + +struct call_expression : public expression { + statement_ptr callee; + statements args; + + call_expression(statement_ptr && callee, statements && args) + : callee(std::move(callee)), args(std::move(args)) { + chk_type(this->callee); + for (const auto& arg : this->args) chk_type(arg); + } + std::string type() const override { return "CallExpression"; } + value execute_impl(context & ctx) override; +}; + +/** + * Represents a user-defined variable or symbol in the template. 
+ */ +struct identifier : public expression { + std::string val; + explicit identifier(const std::string & val) : val(val) {} + std::string type() const override { return "Identifier"; } + value execute_impl(context & ctx) override; +}; + +// Literals + +struct integer_literal : public expression { + int64_t val; + explicit integer_literal(int64_t val) : val(val) {} + std::string type() const override { return "IntegerLiteral"; } + value execute_impl(context &) override { + return mk_val(val); + } +}; + +struct float_literal : public expression { + double val; + explicit float_literal(double val) : val(val) {} + std::string type() const override { return "FloatLiteral"; } + value execute_impl(context &) override { + return mk_val(val); + } +}; + +struct string_literal : public expression { + std::string val; + explicit string_literal(const std::string & val) : val(val) {} + std::string type() const override { return "StringLiteral"; } + value execute_impl(context &) override { + return mk_val(val); + } +}; + +struct array_literal : public expression { + statements val; + explicit array_literal(statements && val) : val(std::move(val)) { + for (const auto& item : this->val) chk_type(item); + } + std::string type() const override { return "ArrayLiteral"; } + value execute_impl(context & ctx) override { + auto arr = mk_val(); + for (const auto & item_stmt : val) { + arr->push_back(item_stmt->execute(ctx)); + } + return arr; + } +}; + +struct tuple_literal : public array_literal { + explicit tuple_literal(statements && val) : array_literal(std::move(val)) {} + std::string type() const override { return "TupleLiteral"; } +}; + +struct object_literal : public expression { + std::vector> val; + explicit object_literal(std::vector> && val) + : val(std::move(val)) { + for (const auto & pair : this->val) { + chk_type(pair.first); + chk_type(pair.second); + } + } + std::string type() const override { return "ObjectLiteral"; } + value execute_impl(context & ctx) override; +}; + 
+// Complex Expressions + +/** + * An operation with two sides, separated by an operator. + * Note: Either side can be a Complex Expression, with order + * of operations being determined by the operator. + */ +struct binary_expression : public expression { + token op; + statement_ptr left; + statement_ptr right; + + binary_expression(token op, statement_ptr && left, statement_ptr && right) + : op(op), left(std::move(left)), right(std::move(right)) { + chk_type(this->left); + chk_type(this->right); + } + std::string type() const override { return "BinaryExpression"; } + value execute_impl(context & ctx) override; +}; + +/** + * An operation with two sides, separated by the | operator. + * Operator precedence: https://github.com/pallets/jinja/issues/379#issuecomment-168076202 + */ +struct filter_expression : public expression { + // either an expression or a value is allowed + statement_ptr operand; + value_string val; // will be set by filter_statement + + statement_ptr filter; + + filter_expression(statement_ptr && operand, statement_ptr && filter) + : operand(std::move(operand)), filter(std::move(filter)) { + chk_type(this->operand); + chk_type(this->filter); + } + + filter_expression(value_string && val, statement_ptr && filter) + : val(std::move(val)), filter(std::move(filter)) { + chk_type(this->filter); + } + + std::string type() const override { return "FilterExpression"; } + value execute_impl(context & ctx) override; +}; + +struct filter_statement : public statement { + statement_ptr filter; + statements body; + + filter_statement(statement_ptr && filter, statements && body) + : filter(std::move(filter)), body(std::move(body)) { + chk_type(this->filter); + } + std::string type() const override { return "FilterStatement"; } + value execute_impl(context & ctx) override; +}; + +/** + * An operation which filters a sequence of objects by applying a test to each object, + * and only selecting the objects with the test succeeding. 
+ * + * It may also be used as a shortcut for a ternary operator. + */ +struct select_expression : public expression { + statement_ptr lhs; + statement_ptr test; + + select_expression(statement_ptr && lhs, statement_ptr && test) + : lhs(std::move(lhs)), test(std::move(test)) { + chk_type(this->lhs); + chk_type(this->test); + } + std::string type() const override { return "SelectExpression"; } + value execute_impl(context & ctx) override { + auto predicate = test->execute_impl(ctx); + if (!predicate->as_bool()) { + return mk_val(); + } + return lhs->execute_impl(ctx); + } +}; + +/** + * An operation with two sides, separated by the "is" operator. + * NOTE: "value is something" translates to function call "test_is_something(value)" + */ +struct test_expression : public expression { + statement_ptr operand; + bool negate; + statement_ptr test; + + test_expression(statement_ptr && operand, bool negate, statement_ptr && test) + : operand(std::move(operand)), negate(negate), test(std::move(test)) { + chk_type(this->operand); + chk_type(this->test); + } + std::string type() const override { return "TestExpression"; } + value execute_impl(context & ctx) override; +}; + +/** + * An operation with one side (operator on the left). 
+ */ +struct unary_expression : public expression { + token op; + statement_ptr argument; + + unary_expression(token op, statement_ptr && argument) + : op(std::move(op)), argument(std::move(argument)) { + chk_type(this->argument); + } + std::string type() const override { return "UnaryExpression"; } + value execute_impl(context & ctx) override; +}; + +struct slice_expression : public expression { + statement_ptr start_expr; + statement_ptr stop_expr; + statement_ptr step_expr; + + slice_expression(statement_ptr && start_expr, statement_ptr && stop_expr, statement_ptr && step_expr) + : start_expr(std::move(start_expr)), stop_expr(std::move(stop_expr)), step_expr(std::move(step_expr)) { + chk_type(this->start_expr); + chk_type(this->stop_expr); + chk_type(this->step_expr); + } + std::string type() const override { return "SliceExpression"; } + value execute_impl(context &) override { + throw std::runtime_error("must be handled by MemberExpression"); + } +}; + +struct keyword_argument_expression : public expression { + statement_ptr key; + statement_ptr val; + + keyword_argument_expression(statement_ptr && key, statement_ptr && val) + : key(std::move(key)), val(std::move(val)) { + chk_type(this->key); + chk_type(this->val); + } + std::string type() const override { return "KeywordArgumentExpression"; } + value execute_impl(context & ctx) override; +}; + +struct spread_expression : public expression { + statement_ptr argument; + explicit spread_expression(statement_ptr && argument) : argument(std::move(argument)) { + chk_type(this->argument); + } + std::string type() const override { return "SpreadExpression"; } +}; + +struct call_statement : public statement { + statement_ptr call; + statements caller_args; + statements body; + + call_statement(statement_ptr && call, statements && caller_args, statements && body) + : call(std::move(call)), caller_args(std::move(caller_args)), body(std::move(body)) { + chk_type(this->call); + for (const auto& arg : this->caller_args) 
chk_type(arg); + } + std::string type() const override { return "CallStatement"; } +}; + +struct ternary_expression : public expression { + statement_ptr condition; + statement_ptr true_expr; + statement_ptr false_expr; + + ternary_expression(statement_ptr && condition, statement_ptr && true_expr, statement_ptr && false_expr) + : condition(std::move(condition)), true_expr(std::move(true_expr)), false_expr(std::move(false_expr)) { + chk_type(this->condition); + chk_type(this->true_expr); + chk_type(this->false_expr); + } + std::string type() const override { return "Ternary"; } + value execute_impl(context & ctx) override { + value cond_val = condition->execute(ctx); + if (cond_val->as_bool()) { + return true_expr->execute(ctx); + } else { + return false_expr->execute(ctx); + } + } +}; + +struct raised_exception : public std::exception { + std::string message; + raised_exception(const std::string & msg) : message(msg) {} + const char* what() const noexcept override { + return message.c_str(); + } +}; + +////////////////////// + +static void gather_string_parts_recursive(const value & val, value_string & parts) { + if (is_val(val)) { + const auto & str_val = cast_val(val)->val_str; + parts->val_str.append(str_val); + } else if (is_val(val)) { + auto items = cast_val(val)->as_array(); + for (const auto & item : items) { + gather_string_parts_recursive(item, parts); + } + } +} + +static std::string render_string_parts(const value_string & parts) { + std::ostringstream oss; + for (const auto & part : parts->val_str.parts) { + oss << part.val; + } + return oss.str(); +} + +struct vm { + context & ctx; + explicit vm(context & ctx) : ctx(ctx) {} + + value_array execute(const program & prog) { + value_array results = mk_val(); + for (auto & stmt : prog.body) { + value res = stmt->execute(ctx); + results->push_back(std::move(res)); + } + return results; + } + + value_string gather_string_parts(const value & val) { + value_string parts = mk_val(); + 
gather_string_parts_recursive(val, parts); + // join consecutive parts with the same type + auto & p = parts->val_str.parts; + for (size_t i = 1; i < p.size(); ) { + if (p[i].is_input == p[i - 1].is_input) { + p[i - 1].val += p[i].val; + p.erase(p.begin() + i); + } else { + i++; + } + } + return parts; + } +}; + +} // namespace jinja diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index c3d9f9c324..f86a5b6657 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -186,6 +186,7 @@ endif() llama_build_and_test(test-chat-parser.cpp) llama_build_and_test(test-chat-peg-parser.cpp peg-parser/simple-tokenize.cpp) llama_build_and_test(test-chat-template.cpp) +llama_build_and_test(test-chat-jinja.cpp) llama_build_and_test(test-json-partial.cpp) llama_build_and_test(test-log.cpp) llama_build_and_test( diff --git a/tests/test-chat-jinja.cpp b/tests/test-chat-jinja.cpp new file mode 100644 index 0000000000..91a7b3ff87 --- /dev/null +++ b/tests/test-chat-jinja.cpp @@ -0,0 +1,206 @@ +#include +#include +#include +#include +#include +#include +#include + +#include + +#undef NDEBUG +#include + +#include "jinja/jinja-parser.h" +#include "jinja/jinja-lexer.h" +#include "jinja/jinja-caps.h" + +using json = nlohmann::json; + +void run_multiple(std::string dir_path, bool stop_on_first_failure, json input); +void run_single(std::string contents, json input, const std::string & output_path = ""); + +std::string HELP = R"( +Usage: test-chat-jinja [OPTIONS] PATH_TO_TEMPLATE +Options: + -h, --help Show this help message and exit. + --json Path to the JSON input file. + --stop-on-first-fail Stop testing on the first failure (default: false). + --output Path to output results (only for single template runs). +If PATH_TO_TEMPLATE is a file, runs that single template. +If PATH_TO_TEMPLATE is a directory, runs all .jinja files in that directory. 
+)"; + +std::string DEFAULT_JSON = R"({ + "messages": [ + { + "role": "user", + "content": {"__input__": "Hello, how are you?"} + }, + { + "role": "assistant", + "content": {"__input__": "I am fine, thank you!"}, + "tool_calls": [ + { + "function": { + "name": "get_weather", + "arguments": { + "location": "New York", + "unit": "celsius" + } + } + } + ] + } + ], + "bos_token": "", + "eos_token": "", + "tools": [], + "add_generation_prompt": true +})"; + +int main(int argc, char ** argv) { + std::vector args(argv, argv + argc); + + std::string tmpl_path; + std::string json_path; + std::string output_path; + bool stop_on_first_fail = false; + + for (size_t i = 1; i < args.size(); i++) { + if (args[i] == "--help" || args[i] == "-h") { + std::cout << HELP << "\n"; + return 0; + } else if (args[i] == "--json" && i + 1 < args.size()) { + json_path = args[i + 1]; + i++; + } else if (args[i] == "--stop-on-first-fail") { + stop_on_first_fail = true; + } else if (args[i] == "--output" && i + 1 < args.size()) { + output_path = args[i + 1]; + i++; + } else if (tmpl_path.empty()) { + tmpl_path = args[i]; + } else { + std::cerr << "Unknown argument: " << args[i] << "\n"; + std::cout << HELP << "\n"; + return 1; + } + } + + if (tmpl_path.empty()) { + std::cerr << "Error: PATH_TO_TEMPLATE is required.\n"; + std::cout << HELP << "\n"; + return 1; + } + + json input_json; + if (!json_path.empty()) { + std::ifstream json_file(json_path); + if (!json_file) { + std::cerr << "Error: Could not open JSON file: " << json_path << "\n"; + return 1; + } + std::string content = std::string( + std::istreambuf_iterator(json_file), + std::istreambuf_iterator()); + input_json = json::parse(content); + } else { + input_json = json::parse(DEFAULT_JSON); + } + + std::filesystem::path p(tmpl_path); + if (std::filesystem::is_directory(p)) { + run_multiple(tmpl_path, stop_on_first_fail, input_json); + } else if (std::filesystem::is_regular_file(p)) { + std::ifstream infile(tmpl_path); + std::string 
contents = std::string( + std::istreambuf_iterator(infile), + std::istreambuf_iterator()); + run_single(contents, input_json, output_path); + } else { + std::cerr << "Error: PATH_TO_TEMPLATE is not a valid file or directory: " << tmpl_path << "\n"; + return 1; + } + + return 0; +} + +void run_multiple(std::string dir_path, bool stop_on_first_fail, json input) { + std::vector failed_tests; + + // list all files in models/templates/ and run each + size_t test_count = 0; + + for (const auto & entry : std::filesystem::directory_iterator(dir_path)) { + // only process .jinja files + if (entry.path().extension() == ".jinja" && entry.is_regular_file()) { + test_count++; + std::cout << "\n\n=== RUNNING TEMPLATE FILE: " << entry.path().string() << " ===\n"; + std::ifstream infile(entry.path()); + std::string contents((std::istreambuf_iterator(infile)), std::istreambuf_iterator()); + try { + run_single(contents, input); + } catch (const std::exception & e) { + std::cout << "Exception: " << e.what() << "\n"; + std::cout << "=== ERROR WITH TEMPLATE FILE: " << entry.path().string() << " ===\n"; + failed_tests.push_back(entry.path().string()); + if (stop_on_first_fail) { + break; + } + } + } + } + + std::cout << "\n\n=== TEST SUMMARY ===\n"; + std::cout << "Total tests run: " << test_count << "\n"; + std::cout << "Total failed tests: " << failed_tests.size() << "\n"; + for (const auto & test : failed_tests) { + std::cout << "FAILED TEST: " << test << "\n"; + } +} + + +void run_single(std::string contents, json input, const std::string & output_path) { + jinja::enable_debug(true); + + // lexing + jinja::lexer lexer; + jinja::preprocess_options options; + options.trim_blocks = false; + options.lstrip_blocks = false; + auto lexer_res = lexer.tokenize(contents, options); + + // compile to AST + jinja::program ast = jinja::parse_from_tokens(lexer_res); + + // check caps for workarounds + auto caps = jinja::caps_get(ast); + + std::cout << "\n=== RUN ===\n"; + jinja::context ctx; + 
ctx.source = lexer_res.preprocessed_source; + + jinja::global_from_json(ctx, input); + jinja::caps_apply_workarounds(ctx, caps); + + jinja::vm vm(ctx); + const jinja::value results = vm.execute(ast); + auto parts = vm.gather_string_parts(results); + + std::cout << "\n=== RESULTS ===\n"; + for (const auto & part : parts->as_string().parts) { + std::cout << (part.is_input ? "DATA" : "TMPL") << ": " << part.val << "\n"; + } + + if (!output_path.empty()) { + std::ofstream outfile(output_path); + if (!outfile) { + throw std::runtime_error("Could not open output file: " + output_path); + } + for (const auto & part : parts->as_string().parts) { + outfile << part.val; + } + std::cout << "\n=== OUTPUT WRITTEN TO " << output_path << " ===\n"; + } +}