This commit is contained in:
Xuan-Son Nguyen 2026-01-03 07:00:30 +09:00 committed by GitHub
commit b26b18261f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
15 changed files with 4302 additions and 31 deletions

View File

@ -83,6 +83,15 @@ add_library(${TARGET} STATIC
speculative.h
unicode.cpp
unicode.h
jinja/jinja-lexer.cpp
jinja/jinja-lexer.h
jinja/jinja-parser.cpp
jinja/jinja-parser.h
jinja/jinja-vm.cpp
jinja/jinja-vm.h
jinja/jinja-value.cpp
jinja/jinja-value.h
jinja/jinja-string.h
)
target_include_directories(${TARGET} PUBLIC . ../vendor)

View File

@ -7,8 +7,12 @@
#include "log.h"
#include "regex-partial.h"
#include <minja/chat-template.hpp>
#include <minja/minja.hpp>
// #include <minja/chat-template.hpp>
// #include <minja/minja.hpp>
#include "jinja/jinja-parser.h"
#include "jinja/jinja-value.h"
#include "jinja/jinja-vm.h"
#include <algorithm>
#include <cstdio>
@ -135,7 +139,46 @@ std::vector<common_chat_msg_diff> common_chat_msg_diff::compute_diffs(const comm
return diffs;
}
typedef minja::chat_template common_chat_template;
struct common_chat_template {
jinja::program prog;
std::string bos_tok;
std::string eos_tok;
std::string src;
common_chat_template(const std::string & src, const std::string & bos_token, const std::string & eos_token) {
jinja::lexer lexer;
jinja::preprocess_options options;
options.trim_blocks = false;
options.lstrip_blocks = false;
auto lexer_res = lexer.tokenize(src, options);
prog = jinja::parse_from_tokens(lexer_res);
this->src = lexer_res.preprocessed_source;
this->bos_tok = bos_token;
this->eos_tok = eos_token;
}
const std::string & source() const { return src; }
const std::string & bos_token() const { return bos_tok; }
const std::string & eos_token() const { return eos_tok; }
static json add_system(const json &, const std::string &) {
throw std::runtime_error("common_chat_template::add_system not implemented");
}
// this is just for testing. it will be removed later
struct chat_template_caps {
bool supports_tools = true;
bool supports_tool_calls = true;
bool supports_tool_responses = true;
bool supports_system_role = true;
bool supports_parallel_tool_calls = true;
bool requires_typed_content = true;
};
chat_template_caps original_caps() const {
return chat_template_caps();
}
};
struct common_chat_templates {
bool add_bos;
@ -627,14 +670,14 @@ common_chat_templates_ptr common_chat_templates_init(
tmpls->add_bos = add_bos;
tmpls->add_eos = add_eos;
try {
tmpls->template_default = std::make_unique<minja::chat_template>(default_template_src, token_bos, token_eos);
tmpls->template_default = std::make_unique<common_chat_template>(default_template_src, token_bos, token_eos);
} catch (const std::exception & e) {
LOG_ERR("%s: failed to parse chat template (defaulting to chatml): %s \n", __func__, e.what());
tmpls->template_default = std::make_unique<minja::chat_template>(CHATML_TEMPLATE_SRC, token_bos, token_eos);
tmpls->template_default = std::make_unique<common_chat_template>(CHATML_TEMPLATE_SRC, token_bos, token_eos);
}
if (!template_tool_use_src.empty()) {
try {
tmpls->template_tool_use = std::make_unique<minja::chat_template>(template_tool_use_src, token_bos, token_eos);
tmpls->template_tool_use = std::make_unique<common_chat_template>(template_tool_use_src, token_bos, token_eos);
} catch (const std::exception & e) {
LOG_ERR("%s: failed to parse tool use chat template (ignoring it): %s\n", __func__, e.what());
}
@ -738,34 +781,40 @@ static std::string apply(
const std::optional<json> & tools_override = std::nullopt,
const std::optional<json> & additional_context = std::nullopt)
{
minja::chat_template_inputs tmpl_inputs;
tmpl_inputs.messages = messages_override ? *messages_override : inputs.messages;
if (tools_override) {
tmpl_inputs.tools = *tools_override;
} else {
tmpl_inputs.tools = inputs.tools.empty() ? json() : inputs.tools;
}
tmpl_inputs.add_generation_prompt = inputs.add_generation_prompt;
tmpl_inputs.extra_context = inputs.extra_context;
tmpl_inputs.extra_context["enable_thinking"] = inputs.enable_thinking;
if (additional_context) {
tmpl_inputs.extra_context.merge_patch(*additional_context);
}
// TODO: add flag to control date/time, if only for testing purposes.
// tmpl_inputs.now = std::chrono::system_clock::now();
// TODO IMPORTANT: IMPORVE THIS
minja::chat_template_options tmpl_opts;
// To avoid double BOS / EOS tokens, we're manually removing begining / trailing tokens
// instead of using `chat_template_options.use_bos_token = false`, since these tokens
// may be needed inside the template / between messages too.
auto result = tmpl.apply(tmpl_inputs, tmpl_opts);
if (inputs.add_bos && string_starts_with(result, tmpl.bos_token())) {
result = result.substr(tmpl.bos_token().size());
jinja::context ctx;
ctx.source = tmpl.source(); // for debugging
nlohmann::json inp = nlohmann::json{
{"messages", messages_override.has_value() ? *messages_override : inputs.messages},
{"tools", tools_override.has_value() ? *tools_override : inputs.tools},
};
if (additional_context.has_value()) {
// TODO: merge properly instead of overwriting
for (const auto & [k, v] : additional_context->items()) {
inp[k] = v;
}
}
if (inputs.add_eos && string_ends_with(result, tmpl.eos_token())) {
result = result.substr(0, result.size() - tmpl.eos_token().size());
if (inputs.add_generation_prompt) {
inp["add_generation_prompt"] = true;
}
return result;
if (inputs.add_bos) {
inp["bos_token"] = tmpl.bos_token();
}
if (inputs.add_eos) {
inp["eos_token"] = tmpl.eos_token();
}
// TODO: more inputs?
jinja::global_from_json(ctx, inp);
// render
jinja::vm vm(ctx);
const jinja::value results = vm.execute(tmpl.prog);
auto parts = vm.gather_string_parts(results);
return parts->as_string().str();
}
static common_chat_params common_chat_params_init_generic(const common_chat_template & tmpl, const struct templates_params & inputs) {

181
common/jinja/jinja-caps.h Normal file
View File

@ -0,0 +1,181 @@
#pragma once
#include <vector>
#include "jinja-value.h"
#include "jinja-vm.h"
#define FILENAME "jinja-caps"
namespace jinja {
struct caps {
bool content_string = true;
bool content_array = true;
};
using caps_messages_fn = std::function<value()>;
using caps_analyze_fn = std::function<void(bool, value &, value &)>;
static void caps_try_execute(jinja::program & prog,
caps_messages_fn messages_fn,
caps_messages_fn tools_fn,
caps_analyze_fn analyze_fn) {
context ctx;
ctx.is_get_stats = true;
value messages = messages_fn();
value tools = tools_fn();
ctx.set_val("messages", messages);
ctx.set_val("tools", tools);
ctx.set_val("add_generation_prompt", mk_val<value_bool>(true));
bool success = false;
try {
jinja::vm vm(ctx);
vm.execute(prog);
success = true;
} catch (const std::exception & e) {
JJ_DEBUG("Exception during execution: %s", e.what());
// ignore exceptions during capability analysis
}
return analyze_fn(success, messages, tools);
}
// for debugging only
static void caps_print_stats(value & v, std::string path) {
std::string ops;
for (const auto & name : v->stats.ops) {
ops += name + " ";
}
JJ_DEBUG("Value %s, type: %s %s, ops: %s",
path.c_str(),
v->type().c_str(),
v->stats.used ? "(used)" : "",
ops.c_str());
}
static caps caps_get(jinja::program & prog) {
caps result;
static const auto has_op = [](value & v, const std::string & op_name) {
return v->stats.ops.find(op_name) != v->stats.ops.end();
};
// case: given content as string, check if it's accessed as array
caps_try_execute(
prog,
[&]() {
auto messages = mk_val<value_array>();
{
value_object msg = mk_val<value_object>();
msg->insert("role", mk_val<value_string>("user"));
msg->insert("content", mk_val<value_string>("User message"));
messages->push_back(msg);
}
return messages;
},
[&]() {
return mk_val<value_array>();
},
[&](bool, value & messages, value &) {
auto & content = messages->at(0)->at("content");
caps_print_stats(content, "messages[0].content");
if (has_op(content, "selectattr") || has_op(content, "array_access")) {
// accessed as an array
JJ_DEBUG("%s", "Force content as array");
result.content_string = false;
result.content_array = true;
}
}
);
// case: given content as array, check if it's supported or not
caps_try_execute(
prog,
[&]() {
auto messages = mk_val<value_array>();
{
value_object msg = mk_val<value_object>();
msg->insert("role", mk_val<value_string>("user"));
value_array content_arr = mk_val<value_array>();
{
value_object content_part = mk_val<value_object>();
content_part->insert("type", mk_val<value_string>("text"));
content_part->insert("text", mk_val<value_string>("User message"));
content_arr->push_back(content_part);
}
msg->insert("content", content_arr);
messages->push_back(msg);
}
return messages;
},
[&]() {
return mk_val<value_array>();
},
[&](bool success, value & messages, value &) {
auto & content = messages->at(0)->at("content");
caps_print_stats(content, "messages[0].content");
if (!success) {
JJ_DEBUG("%s", "Cannot handle content as array");
result.content_array = false;
}
}
);
return result;
}
static void caps_apply_workarounds(context & ctx, const caps & c) {
auto messages = ctx.get_val("messages");
if (!is_val<value_array>(messages)) {
throw std::runtime_error("Expected messages to be an array");
}
if (!c.content_string) {
for (auto & msg : messages->val_arr) {
if (!is_val<value_object>(msg)) {
throw std::runtime_error("Expected messages[i] to be an object");
}
auto obj_ptr = cast_val<value_object>(msg);
auto & content = obj_ptr->at("content");
if (!is_val<value_array>(content)) {
JJ_DEBUG("%s", "Converting message content to array");
auto str_content = content->as_string();
value_array arr_content = mk_val<value_array>();
value_object content_part = mk_val<value_object>();
content_part->insert("type", mk_val<value_string>("text"));
content_part->insert("text", mk_val<value_string>(str_content));
arr_content->push_back(content_part);
obj_ptr->insert("content", arr_content);
}
}
}
ctx.set_val("messages", messages);
//
// per-model workarounds
//
// workaround for shieldgemma-2b-Q2_K
if (ctx.get_val("guideline")->is_undefined()) {
ctx.set_val("guideline", mk_val<value_string>(""));
}
// workaround for functionary models
if (ctx.get_val("functions")->is_undefined()) {
ctx.set_val("functions", mk_val<value_string>(""));
}
if (ctx.get_val("datetime")->is_undefined()) {
ctx.set_val("datetime", mk_val<value_string>(""));
}
// workaround for Llama-3-5B-Sheard
if (ctx.get_val("system_message")->is_undefined()) {
ctx.set_val("system_message", mk_val<value_string>(""));
}
}
} // namespace jinja

View File

@ -0,0 +1,333 @@
#include "jinja-lexer.h"
#include "jinja-vm.h"
#include <vector>
#include <string>
#include <map>
#include <stdexcept>
#include <cctype>
#include <functional>
#include <string_view>
#define FILENAME "jinja-lexer"
namespace jinja {
// Trim template markers with '-' for whitespace control
// Example: [spaces]{%- ... -%} --> {% ... %}
#include <string>
#include <cctype>
static void trim_template_markers_inplace(std::string & s) {
// i = head ; j = tail (i <= j)
size_t j = 0; // Write pointer
const size_t len = s.length();
for (size_t i = 0; i < len; ) {
bool handled = false;
// We need at least 3 characters for any marker: {X- or -X}
if (i + 2 < len) {
const char c1 = s[i];
const char c2 = s[i + 1];
const char c3 = s[i + 2];
// 1. Closing trim: -X} where X = %, }, #
// Example: [content]-%} [spaces] -> [content]%}
if (c1 == '-' && c3 == '}' && (c2 == '%' || c2 == '}' || c2 == '#')) {
s[j++] = c2;
s[j++] = '}';
i += 3;
// Strip leading whitespace AFTER the tag
while (i < len && std::isspace(static_cast<unsigned char>(s[i]))) {
i++;
}
handled = true;
}
// 2. Opening trim: {X- where X = %, {, #
// Example: [spaces]{%- [content] -> {% [content]
else if (c1 == '{' && c3 == '-' && (c2 == '%' || c2 == '{' || c2 == '#')) {
// Trim trailing whitespace BEFORE the tag by moving write pointer back
while (j > 0 && std::isspace(static_cast<unsigned char>(s[j - 1]))) {
j--;
}
// Safety: Prevent merging '{' with tag start (avoid creating '{{%' or '{{{')
// if the character immediately before our new tag is a literal '{'.
if (j > 0 && s[j - 1] == '{') {
s[j++] = ' ';
}
s[j++] = '{';
s[j++] = c2;
i += 3;
handled = true;
}
}
if (!handled) {
// Note: j is always <= i here, so this is safe.
s[j++] = s[i++];
}
}
s.resize(j);
}
static void trim_newline_after_tag_inplace(std::string & s) {
// i = head ; j = tail (i <= j)
size_t j = 0; // Write pointer
const size_t len = s.length();
for (size_t i = 0; i < len; ) {
s[j++] = s[i++];
if (i < len && (s[j-1] == '}' || s[j-1] == '%' || s[j-1] == '#' || s[j-1] == '-')) {
if (s[i] == '}') {
// We have a potential tag closer like %} or -} or #} or }}
// Now check if the next character is a newline
if (i + 1 < len && s[i + 1] == '\n') {
// Skip the } and the following \n
++i; // skip the }
++i; // skip the \n
// Do not advance j, we effectively removed the \n
continue;
}
}
}
}
s.resize(j);
}
std::string lexer::preprocess(const std::string & template_str, const preprocess_options & options) const {
std::string result = template_str;
// According to https://jinja.palletsprojects.com/en/3.0.x/templates/#whitespace-control
// In the default configuration:
// - a single trailing newline is stripped if present
// - other whitespace (spaces, tabs, newlines etc.) is returned unchanged
if (!result.empty() && result.back() == '\n') {
result.pop_back();
}
if (options.lstrip_blocks) {
// The lstrip_blocks option can also be set to strip tabs and spaces from the
// beginning of a line to the start of a block. (Nothing will be stripped if
// there are other characters before the start of the block.)
// result = std::regex_replace(result, std::regex(R"((?m)^[ \t]*(\{[#%-]))"), "$1");
throw std::runtime_error("lstrip_blocks option is not implemented yet");
}
if (options.trim_blocks) {
// If an application configures Jinja to trim_blocks, the first newline after
// a template tag is removed automatically (like in PHP).
// Equivalent JS code: template.replace(/^[ \t]*({[#%-])/gm, "$1")
trim_newline_after_tag_inplace(result);
}
// Handle whitespace control with - in tags
trim_template_markers_inplace(result);
// Handle custom transformers-specific `generation` tag
// See https://github.com/huggingface/transformers/pull/30650 for more information.
// result = std::regex_replace(result, std::regex(R"(\{%\s*generation\s*%\})"), "");
// result = std::regex_replace(result, std::regex(R"(\{%\s*endgeneration\s*%\})"), "");
return result;
}
lexer_result lexer::tokenize(const std::string & input, const preprocess_options & options) {
std::vector<token> tokens;
std::string src = preprocess(input, options);
JJ_DEBUG("preprocessed input: '%s'", src.c_str());
size_t pos = 0;
size_t start_pos = 0;
size_t curly_bracket_depth = 0;
using pred = std::function<bool(char)>;
auto consume_while = [&](pred predicate) -> std::string {
std::string str;
while (predicate(src[pos])) {
// check for escape char
if (src[pos] == '\\') {
// consume backslash
++pos;
// check for end of input
if (pos >= src.size()) {
throw std::runtime_error("lexer: unexpected end of input after escape character");
}
// add escaped char
char escaped_char = src[pos++];
if (escape_chars.find(escaped_char) == escape_chars.end()) {
throw std::runtime_error(std::string("lexer: unknown escape character \\") + escaped_char);
}
char unescaped_char = escape_chars.at(escaped_char);
str += unescaped_char;
continue;
}
str += src[pos++];
if (pos > src.size()) {
throw std::runtime_error("lexer: unexpected end of input during consume_while");
}
}
return str;
};
auto next_pos_is = [&](std::initializer_list<char> chars) -> bool {
if (pos + 1 >= src.size()) return false;
for (char c : chars) {
if (src[pos + 1] == c) return true;
}
return false;
};
while (pos < src.size()) {
start_pos = pos;
JJ_DEBUG("lexer main loop at pos %zu: '%s...'", pos, src.substr(pos, 10).c_str());
// First, consume all text that is outside of a Jinja statement or expression
token::type last_token_type = tokens.empty()
? token::undefined
: tokens.back().t;
if (last_token_type == token::undefined ||
last_token_type == token::close_statement ||
last_token_type == token::close_expression ||
last_token_type == token::comment) {
std::string text;
while (pos < src.size() &&
// Keep going until we hit the next Jinja statement or expression
!(
src[pos] == '{' &&
next_pos_is( {'%', '{', '#'} )
)) {
text += src[pos++];
}
JJ_DEBUG("consumed text: '%s'", text.c_str());
if (!text.empty()) {
tokens.push_back({token::text, text, start_pos});
continue;
}
}
// Possibly consume a comment
if (src[pos] == '{' && next_pos_is( {'#'} )) {
start_pos = pos;
pos += 2; // Skip the opening {#
std::string comment;
while (!(src[pos] == '#' && next_pos_is( {'}'} ))) {
if (pos + 2 >= src.size()) {
throw std::runtime_error("lexer: missing end of comment tag");
}
comment += src[pos++];
}
JJ_DEBUG("consumed comment: '%s'", comment.c_str());
tokens.push_back({token::comment, comment, start_pos});
pos += 2; // Skip the closing #}
continue;
}
// Consume (and ignore) all whitespace inside Jinja statements or expressions
consume_while([](char c) { return std::isspace(static_cast<unsigned char>(c)); });
if (pos >= src.size()) break;
char ch = src[pos];
// Check for unary operators
if (ch == '-' || ch == '+') {
start_pos = pos;
token::type last_token_type = tokens.empty() ? token::undefined : tokens.back().t;
if (last_token_type == token::text || last_token_type == token::undefined) {
throw std::runtime_error(std::string("lexer: unexpected character: ") + ch);
}
switch (last_token_type) {
case token::identifier:
case token::numeric_literal:
case token::string_literal:
case token::close_paren:
case token::close_square_bracket:
// Part of a binary operator
// a - 1, 1 - 1, true - 1, "apple" - 1, (1) - 1, a[1] - 1
// Continue parsing normally
break;
default: {
// Is part of a unary operator
// (-1), [-1], (1 + -1), not -1, -apple
++pos; // Consume the operator
// Check for numbers following the unary operator
std::string num = consume_while(is_integer);
std::string value = std::string(1, ch) + num;
token::type t = num.empty() ? token::unary_operator : token::numeric_literal;
JJ_DEBUG("consumed unary operator or numeric literal: '%s'", value.c_str());
tokens.push_back({t, value, start_pos});
continue;
}
}
}
// Try to match one of the tokens in the mapping table
bool matched = false;
for (const auto & [seq, typ] : ordered_mapping_table) {
start_pos = pos;
// Inside an object literal, don't treat "}}" as expression-end
if (seq == "}}" && curly_bracket_depth > 0) {
continue;
}
if (pos + seq.size() <= src.size() && src.substr(pos, seq.size()) == seq) {
tokens.push_back({typ, seq, start_pos});
if (typ == token::open_expression) {
curly_bracket_depth = 0;
} else if (typ == token::open_curly_bracket) {
++curly_bracket_depth;
} else if (typ == token::close_curly_bracket) {
--curly_bracket_depth;
}
pos += seq.size();
matched = true;
break; // continue main loop
}
}
if (matched) continue; // continue main loop
// Strings
if (ch == '\'' || ch == '"') {
start_pos = pos;
++pos; // Skip opening quote
std::string str = consume_while([ch](char c) { return c != ch; });
tokens.push_back({token::string_literal, str, start_pos});
++pos; // Skip closing quote
continue;
}
// Numbers
if (is_integer(ch)) {
start_pos = pos;
std::string num = consume_while(is_integer);
if (pos < src.size() && src[pos] == '.' && pos + 1 < src.size() && is_integer(src[pos + 1])) {
++pos; // Consume '.'
std::string frac = consume_while(is_integer);
num += "." + frac;
}
tokens.push_back({token::numeric_literal, num, start_pos});
continue;
}
// Identifiers
if (is_word(ch)) {
start_pos = pos;
std::string word = consume_while(is_word);
tokens.push_back({token::identifier, word, start_pos});
continue;
}
throw std::runtime_error(std::string("lexer: unexpected character: ") + ch);
}
return {std::move(tokens), std::move(src)};
}
} // namespace jinja

152
common/jinja/jinja-lexer.h Normal file
View File

@ -0,0 +1,152 @@
#pragma once
#include <vector>
#include <string>
#include <map>
#include <regex>
#include <stdexcept>
#include <cctype>
#include <functional>
namespace jinja {
struct preprocess_options {
bool trim_blocks = false;
bool lstrip_blocks = false;
};
struct token {
enum type {
undefined,
text, // The text between Jinja statements or expressions
numeric_literal, // e.g., 123, 1.0
string_literal, // 'string'
identifier, // Variables, functions, statements, booleans, etc.
equals, // =
open_paren, // (
close_paren, // )
open_statement, // {%
close_statement, // %}
open_expression, // {{
close_expression, // }}
open_square_bracket, // [
close_square_bracket, // ]
open_curly_bracket, // {
close_curly_bracket, // }
comma, // ,
dot, // .
colon, // :
pipe, // |
call_operator, // ()
additive_binary_operator, // + - ~
multiplicative_binary_operator, // * / %
comparison_binary_operator, // < > <= >= == !=
unary_operator, // ! - +
comment, // {# ... #}
};
type t;
std::string value;
size_t pos;
};
static std::string type_to_string(token::type t) {
switch (t) {
case token::undefined: return "undefined";
case token::text: return "text";
case token::numeric_literal: return "numeric_literal";
case token::string_literal: return "string_literal";
case token::identifier: return "identifier";
case token::equals: return "equals";
case token::open_paren: return "open_paren";
case token::close_paren: return "close_paren";
case token::open_statement: return "open_statement";
case token::close_statement: return "close_statement";
case token::open_expression: return "open_expression";
case token::close_expression: return "close_expression";
case token::open_square_bracket: return "open_square_bracket";
case token::close_square_bracket: return "close_square_bracket";
case token::open_curly_bracket: return "open_curly_bracket";
case token::close_curly_bracket: return "close_curly_bracket";
case token::comma: return "comma";
case token::dot: return "dot";
case token::colon: return "colon";
case token::pipe: return "pipe";
case token::call_operator: return "call_operator";
case token::additive_binary_operator: return "additive_binary_operator";
case token::multiplicative_binary_operator: return "multiplicative_binary_operator";
case token::comparison_binary_operator: return "comparison_binary_operator";
case token::unary_operator: return "unary_operator";
case token::comment: return "comment";
default: return "unknown";
}
}
struct lexer_result {
std::vector<token> tokens;
std::string preprocessed_source;
};
struct lexer {
const std::map<char, char> escape_chars = {
{'n', '\n'},
{'t', '\t'},
{'r', '\r'},
{'b', '\b'},
{'f', '\f'},
{'v', '\v'},
{'\\', '\\'},
{'\'', '\''},
{'\"', '\"'},
};
static bool is_word(char c) {
return std::isalnum(static_cast<unsigned char>(c)) || c == '_';
}
static bool is_integer(char c) {
return std::isdigit(static_cast<unsigned char>(c));
}
const std::vector<std::pair<std::string, token::type>> ordered_mapping_table = {
// Control sequences
{"{%", token::open_statement},
{"%}", token::close_statement},
{"{{", token::open_expression},
{"}}", token::close_expression},
// Single character tokens
{"(", token::open_paren},
{")", token::close_paren},
{"{", token::open_curly_bracket},
{"}", token::close_curly_bracket},
{"[", token::open_square_bracket},
{"]", token::close_square_bracket},
{",", token::comma},
{".", token::dot},
{":", token::colon},
{"|", token::pipe},
// Comparison operators
{"<=", token::comparison_binary_operator},
{">=", token::comparison_binary_operator},
{"==", token::comparison_binary_operator},
{"!=", token::comparison_binary_operator},
{"<", token::comparison_binary_operator},
{">", token::comparison_binary_operator},
// Arithmetic operators
{"+", token::additive_binary_operator},
{"-", token::additive_binary_operator},
{"~", token::additive_binary_operator},
{"*", token::multiplicative_binary_operator},
{"/", token::multiplicative_binary_operator},
{"%", token::multiplicative_binary_operator},
// Assignment operator
{"=", token::equals},
};
std::string preprocess(const std::string& template_str, const preprocess_options& options) const;
lexer_result tokenize(const std::string & input, const preprocess_options & options);
};
} // namespace jinja

View File

@ -0,0 +1,610 @@
#include "jinja-lexer.h"
#include "jinja-vm.h"
#include "jinja-parser.h"
#include <string>
#include <vector>
#include <memory>
#include <stdexcept>
#include <algorithm>
#define FILENAME "jinja-parser"
namespace jinja {
// Helper to check type without asserting (useful for logic)
template<typename T>
static bool is_type(const statement_ptr & ptr) {
return dynamic_cast<const T*>(ptr.get()) != nullptr;
}
class parser {
const std::vector<token> & tokens;
size_t current = 0;
size_t prev_cur = 0;
// for debugging; a token can be multiple chars in source
std::vector<size_t> tok_pos_to_src_pos;
std::string source; // for error reporting
public:
parser(const std::vector<token> & t, const std::string & src) : tokens(t), source(src) {
tok_pos_to_src_pos.resize(tokens.size());
for (size_t i = 0; i < tokens.size(); i++) {
tok_pos_to_src_pos[i] = tokens[i].pos;
}
}
program parse() {
statements body;
while (current < tokens.size()) {
body.push_back(parse_any());
}
return program(std::move(body));
}
template<typename T, typename... Args>
std::unique_ptr<T> mk_stmt(Args&&... args) {
auto ptr = std::make_unique<T>(std::forward<Args>(args)...);
ptr->pos = tok_pos_to_src_pos[prev_cur];
std::string snippet = "no source";
if (!source.empty()) {
size_t start_pos = ptr->pos;
size_t end_pos = start_pos + 20;
if (end_pos > source.size()) end_pos = source.size();
snippet = source.substr(start_pos, end_pos - start_pos);
}
JJ_DEBUG("Created %-20s statement at src pos %-4zu (%s)", ptr->type().c_str(), ptr->pos, snippet.c_str());
return ptr;
}
private:
const token & peek(size_t offset = 0) const {
if (current + offset >= tokens.size()) {
static const token end_token{token::undefined, "", 0};
return end_token;
}
return tokens[current + offset];
}
token expect(token::type type, const std::string& error) {
const auto & t = peek();
if (t.t != type) {
throw std::runtime_error("Parser Error: " + error + " (Got " + t.value + ")");
}
current++;
return t;
}
void expect_identifier(const std::string& name) {
const auto & t = peek();
if (t.t != token::identifier || t.value != name) {
throw std::runtime_error("Expected identifier: " + name);
}
current++;
}
bool is(token::type type) const {
return peek().t == type;
}
bool is_identifier(const std::string& name) const {
return peek().t == token::identifier && peek().value == name;
}
bool is_statement(const std::vector<std::string>& names) const {
if (peek(0).t != token::open_statement || peek(1).t != token::identifier) {
return false;
}
std::string val = peek(1).value;
return std::find(names.begin(), names.end(), val) != names.end();
}
statement_ptr parse_any() {
prev_cur = current;
switch (peek().t) {
case token::comment:
return mk_stmt<comment_statement>(tokens[current++].value);
case token::text:
return mk_stmt<string_literal>(tokens[current++].value);
case token::open_statement:
return parse_jinja_statement();
case token::open_expression:
return parse_jinja_expression();
default:
throw std::runtime_error("Unexpected token type");
}
}
statement_ptr parse_jinja_expression() {
// Consume {{ }} tokens
prev_cur = current;
expect(token::open_expression, "Expected {{");
auto result = parse_expression();
expect(token::close_expression, "Expected }}");
return result;
}
statement_ptr parse_jinja_statement() {
// Consume {% token
prev_cur = current;
expect(token::open_statement, "Expected {%");
if (peek().t != token::identifier) {
throw std::runtime_error("Unknown statement");
}
std::string name = peek().value;
current++; // consume identifier
statement_ptr result;
if (name == "set") {
result = parse_set_statement();
} else if (name == "if") {
result = parse_if_statement();
// expect {% endif %}
expect(token::open_statement, "Expected {%");
expect_identifier("endif");
expect(token::close_statement, "Expected %}");
} else if (name == "macro") {
result = parse_macro_statement();
// expect {% endmacro %}
expect(token::open_statement, "Expected {%");
expect_identifier("endmacro");
expect(token::close_statement, "Expected %}");
} else if (name == "for") {
result = parse_for_statement();
// expect {% endfor %}
expect(token::open_statement, "Expected {%");
expect_identifier("endfor");
expect(token::close_statement, "Expected %}");
} else if (name == "break") {
expect(token::close_statement, "Expected %}");
result = mk_stmt<break_statement>();
} else if (name == "continue") {
expect(token::close_statement, "Expected %}");
result = mk_stmt<continue_statement>();
} else if (name == "call") {
statements caller_args;
// bool has_caller_args = false;
if (is(token::open_paren)) {
// Optional caller arguments, e.g. {% call(user) dump_users(...) %}
caller_args = parse_args();
// has_caller_args = true;
}
auto callee = parse_primary_expression();
if (!is_type<identifier>(callee)) throw std::runtime_error("Expected identifier");
auto call_args = parse_args();
expect(token::close_statement, "Expected %}");
statements body;
while (!is_statement({"endcall"})) {
body.push_back(parse_any());
}
expect(token::open_statement, "Expected {%");
expect_identifier("endcall");
expect(token::close_statement, "Expected %}");
auto call_expr = mk_stmt<call_expression>(std::move(callee), std::move(call_args));
result = mk_stmt<call_statement>(std::move(call_expr), std::move(caller_args), std::move(body));
} else if (name == "filter") {
auto filter_node = parse_primary_expression();
if (is_type<identifier>(filter_node) && is(token::open_paren)) {
filter_node = parse_call_expression(std::move(filter_node));
}
expect(token::close_statement, "Expected %}");
statements body;
while (!is_statement({"endfilter"})) {
body.push_back(parse_any());
}
expect(token::open_statement, "Expected {%");
expect_identifier("endfilter");
expect(token::close_statement, "Expected %}");
result = mk_stmt<filter_statement>(std::move(filter_node), std::move(body));
} else if (name == "generation" || name == "endgeneration") {
// Ignore generation blocks (transformers-specific)
// See https://github.com/huggingface/transformers/pull/30650 for more information.
result = mk_stmt<noop_statement>();
current++;
} else {
throw std::runtime_error("Unknown statement: " + name);
}
return result;
}
statement_ptr parse_set_statement() {
// NOTE: `set` acts as both declaration statement and assignment expression
auto left = parse_expression_sequence();
statement_ptr value = nullptr;
statements body;
prev_cur = current;
if (is(token::equals)) {
current++;
value = parse_expression_sequence();
} else {
// parsing multiline set here
expect(token::close_statement, "Expected %}");
while (!is_statement({"endset"})) {
body.push_back(parse_any());
}
expect(token::open_statement, "Expected {%");
expect_identifier("endset");
}
expect(token::close_statement, "Expected %}");
return mk_stmt<set_statement>(std::move(left), std::move(value), std::move(body));
}
statement_ptr parse_if_statement() {
auto test = parse_expression();
expect(token::close_statement, "Expected %}");
statements body;
statements alternate;
prev_cur = current;
// Keep parsing 'if' body until we reach the first {% elif %} or {% else %} or {% endif %}
while (!is_statement({"elif", "else", "endif"})) {
body.push_back(parse_any());
}
if (is_statement({"elif"})) {
++current; // consume {%
++current; // consume 'elif'
alternate.push_back(parse_if_statement()); // nested If
} else if (is_statement({"else"})) {
++current; // consume {%
++current; // consume 'else'
expect(token::close_statement, "Expected %}");
// keep going until we hit {% endif %}
while (!is_statement({"endif"})) {
alternate.push_back(parse_any());
}
}
return mk_stmt<if_statement>(std::move(test), std::move(body), std::move(alternate));
}
statement_ptr parse_macro_statement() {
auto name = parse_primary_expression();
auto args = parse_args();
expect(token::close_statement, "Expected %}");
statements body;
// Keep going until we hit {% endmacro
while (!is_statement({"endmacro"})) {
body.push_back(parse_any());
}
return mk_stmt<macro_statement>(std::move(name), std::move(args), std::move(body));
}
statement_ptr parse_expression_sequence(bool primary = false) {
statements exprs;
exprs.push_back(primary ? parse_primary_expression() : parse_expression());
bool is_tuple = is(token::comma);
while (is(token::comma)) {
prev_cur = current;
current++; // consume comma
exprs.push_back(primary ? parse_primary_expression() : parse_expression());
if (!is(token::comma)) break;
}
return is_tuple ? mk_stmt<tuple_literal>(std::move(exprs)) : std::move(exprs[0]);
}
statement_ptr parse_for_statement() {
// e.g., `message` in `for message in messages`
auto loop_var = parse_expression_sequence(true); // should be an identifier/tuple
if (!is_identifier("in")) throw std::runtime_error("Expected 'in'");
current++;
// `messages` in `for message in messages`
auto iterable = parse_expression();
expect(token::close_statement, "Expected %}");
statements body;
statements alternate;
// Keep going until we hit {% endfor or {% else
while (!is_statement({"endfor", "else"})) {
body.push_back(parse_any());
}
if (is_statement({"else"})) {
prev_cur = current;
current += 2;
expect(token::close_statement, "Expected %}");
while (!is_statement({"endfor"})) {
alternate.push_back(parse_any());
}
}
return mk_stmt<for_statement>(
std::move(loop_var), std::move(iterable),
std::move(body), std::move(alternate));
}
statement_ptr parse_expression() {
// Choose parse function with lowest precedence
return parse_if_expression();
}
statement_ptr parse_if_expression() {
auto a = parse_logical_or_expression();
if (is_identifier("if")) {
// Ternary expression
prev_cur = current;
++current; // consume 'if'
auto test = parse_logical_or_expression();
if (is_identifier("else")) {
// Ternary expression with else
prev_cur = current;
++current; // consume 'else'
auto false_expr = parse_if_expression(); // recurse to support chained ternaries
return mk_stmt<ternary_expression>(std::move(test), std::move(a), std::move(false_expr));
} else {
// Select expression on iterable
return mk_stmt<select_expression>(std::move(a), std::move(test));
}
}
return a;
}
statement_ptr parse_logical_or_expression() {
auto left = parse_logical_and_expression();
while (is_identifier("or")) {
prev_cur = current;
token op = tokens[current++];
left = mk_stmt<binary_expression>(op, std::move(left), parse_logical_and_expression());
}
return left;
}
statement_ptr parse_logical_and_expression() {
auto left = parse_logical_negation_expression();
while (is_identifier("and")) {
prev_cur = current;
auto op = tokens[current++];
left = mk_stmt<binary_expression>(op, std::move(left), parse_logical_negation_expression());
}
return left;
}
statement_ptr parse_logical_negation_expression() {
// Try parse unary operators
if (is_identifier("not")) {
prev_cur = current;
auto op = tokens[current];
++current; // consume 'not'
return mk_stmt<unary_expression>(op, parse_logical_negation_expression());
}
return parse_comparison_expression();
}
statement_ptr parse_comparison_expression() {
// NOTE: membership has same precedence as comparison
// e.g., ('a' in 'apple' == 'b' in 'banana') evaluates as ('a' in ('apple' == ('b' in 'banana')))
auto left = parse_additive_expression();
while (true) {
token op;
prev_cur = current;
if (is_identifier("not") && peek(1).t == token::identifier && peek(1).value == "in") {
op = {token::identifier, "not in", tokens[current].pos};
current += 2;
} else if (is_identifier("in")) {
op = tokens[current++];
} else if (is(token::comparison_binary_operator)) {
op = tokens[current++];
} else break;
left = mk_stmt<binary_expression>(op, std::move(left), parse_additive_expression());
}
return left;
}
statement_ptr parse_additive_expression() {
auto left = parse_multiplicative_expression();
while (is(token::additive_binary_operator)) {
prev_cur = current;
auto op = tokens[current++];
left = mk_stmt<binary_expression>(op, std::move(left), parse_multiplicative_expression());
}
return left;
}
statement_ptr parse_multiplicative_expression() {
auto left = parse_test_expression();
while (is(token::multiplicative_binary_operator)) {
prev_cur = current;
auto op = tokens[current++];
left = mk_stmt<binary_expression>(op, std::move(left), parse_test_expression());
}
return left;
}
statement_ptr parse_test_expression() {
auto operand = parse_filter_expression();
while (is_identifier("is")) {
prev_cur = current;
current++;
bool negate = false;
if (is_identifier("not")) { current++; negate = true; }
auto test_id = parse_primary_expression();
operand = mk_stmt<test_expression>(std::move(operand), negate, std::move(test_id));
}
return operand;
}
statement_ptr parse_filter_expression() {
auto operand = parse_call_member_expression();
while (is(token::pipe)) {
prev_cur = current;
current++;
auto filter = parse_primary_expression();
if (is(token::open_paren)) filter = parse_call_expression(std::move(filter));
operand = mk_stmt<filter_expression>(std::move(operand), std::move(filter));
}
return operand;
}
statement_ptr parse_call_member_expression() {
// Handle member expressions recursively
auto member = parse_member_expression(parse_primary_expression());
return is(token::open_paren)
? parse_call_expression(std::move(member)) // foo.x()
: std::move(member);
}
statement_ptr parse_call_expression(statement_ptr callee) {
auto expr = mk_stmt<call_expression>(std::move(callee), parse_args());
auto member = parse_member_expression(std::move(expr)); // foo.x().y
return is(token::open_paren)
? parse_call_expression(std::move(member)) // foo.x()()
: std::move(member);
}
statements parse_args() {
// comma-separated arguments list
expect(token::open_paren, "Expected (");
statements args;
while (!is(token::close_paren)) {
statement_ptr arg;
prev_cur = current;
// unpacking: *expr
if (peek().t == token::multiplicative_binary_operator && peek().value == "*") {
++current; // consume *
arg = mk_stmt<spread_expression>(parse_expression());
} else {
arg = parse_expression();
if (is(token::equals)) {
// keyword argument
// e.g., func(x = 5, y = a or b)
++current; // consume equals
arg = mk_stmt<keyword_argument_expression>(std::move(arg), parse_expression());
}
}
args.push_back(std::move(arg));
if (is(token::comma)) {
++current; // consume comma
}
}
expect(token::close_paren, "Expected )");
return args;
}
statement_ptr parse_member_expression(statement_ptr object) {
while (is(token::dot) || is(token::open_square_bracket)) {
auto op = tokens[current++];
bool computed = op.t == token::open_square_bracket;
statement_ptr prop;
if (computed) {
prop = parse_member_expression_arguments();
expect(token::close_square_bracket, "Expected ]");
} else {
prop = parse_primary_expression();
}
object = mk_stmt<member_expression>(std::move(object), std::move(prop), computed);
}
return object;
}
statement_ptr parse_member_expression_arguments() {
// NOTE: This also handles slice expressions colon-separated arguments list
// e.g., ['test'], [0], [:2], [1:], [1:2], [1:2:3]
statements slices;
bool is_slice = false;
while (!is(token::close_square_bracket)) {
prev_cur = current;
if (is(token::colon)) {
// A case where a default is used
// e.g., [:2] will be parsed as [undefined, 2]
slices.push_back(nullptr);
++current; // consume colon
is_slice = true;
} else {
slices.push_back(parse_expression());
if (is(token::colon)) {
++current; // consume colon after expression, if it exists
is_slice = true;
}
}
}
if (is_slice) {
statement_ptr start = slices.size() > 0 ? std::move(slices[0]) : nullptr;
statement_ptr stop = slices.size() > 1 ? std::move(slices[1]) : nullptr;
statement_ptr step = slices.size() > 2 ? std::move(slices[2]) : nullptr;
return mk_stmt<slice_expression>(std::move(start), std::move(stop), std::move(step));
}
return std::move(slices[0]);
}
statement_ptr parse_primary_expression() {
prev_cur = current;
auto t = tokens[current++];
switch (t.t) {
case token::numeric_literal:
if (t.value.find('.') != std::string::npos) return mk_stmt<float_literal>(std::stod(t.value));
return mk_stmt<integer_literal>(std::stoll(t.value));
case token::string_literal: {
std::string val = t.value;
while (is(token::string_literal)) {
val += tokens[current++].value;
}
return mk_stmt<string_literal>(val);
}
case token::identifier:
return mk_stmt<identifier>(t.value);
case token::open_paren: {
auto expr = parse_expression_sequence();
expect(token::close_paren, "Expected )");
return expr;
}
case token::open_square_bracket: {
statements vals;
while (!is(token::close_square_bracket)) {
vals.push_back(parse_expression());
if (is(token::comma)) current++;
}
current++;
return mk_stmt<array_literal>(std::move(vals));
}
case token::open_curly_bracket: {
std::vector<std::pair<statement_ptr, statement_ptr>> pairs;
while (!is(token::close_curly_bracket)) {
auto key = parse_expression();
expect(token::colon, "Expected :");
pairs.push_back({std::move(key), parse_expression()});
if (is(token::comma)) current++;
}
current++;
return mk_stmt<object_literal>(std::move(pairs));
}
default:
throw std::runtime_error("Unexpected token: " + t.value + " of type " + std::to_string(t.t));
}
}
};
program parse_from_tokens(const std::vector<token> & tokens) {
return parser(tokens, "").parse();
}
program parse_from_tokens(const lexer_result & lexer_res) {
return parser(lexer_res.tokens, lexer_res.preprocessed_source).parse();
}
} // namespace jinja

View File

@ -0,0 +1,18 @@
#pragma once
#include "jinja-lexer.h"
#include "jinja-vm.h"
#include <string>
#include <vector>
#include <memory>
#include <stdexcept>
#include <algorithm>
namespace jinja {
program parse_from_tokens(const std::vector<token> & tokens);
program parse_from_tokens(const lexer_result & lexer_res);
} // namespace jinja

202
common/jinja/jinja-string.h Normal file
View File

@ -0,0 +1,202 @@
#pragma once
#include <vector>
#include <string>
#include <functional>
#include <sstream>
namespace jinja {
// allow differentiate between user input strings and template strings
// transformations should handle this information as follows:
// - one-to-one (e.g., uppercase, lowercase): preserve is_input flag
// - one-to-many (e.g., strip): if input string is marked as is_input, all resulting parts should be marked as is_input
// - many-to-one (e.g., concat): if ALL input parts are marked as is_input, resulting part should be marked as is_input
struct string_part {
bool is_input = false; // may skip parsing special tokens if true
std::string val;
bool is_uppercase() const {
for (char c : val) {
if (std::islower(static_cast<unsigned char>(c))) {
return false;
}
}
return true;
}
bool is_lowercase() const {
for (char c : val) {
if (std::isupper(static_cast<unsigned char>(c))) {
return false;
}
}
return true;
}
};
struct string {
using transform_fn = std::function<std::string(const std::string&)>;
std::vector<string_part> parts;
string() = default;
string(const std::string & v, bool user_input = false) {
parts.push_back({user_input, v});
}
string(int v) {
parts.push_back({false, std::to_string(v)});
}
string(double v) {
parts.push_back({false, std::to_string(v)});
}
void mark_input() {
for (auto & part : parts) {
part.is_input = true;
}
}
std::string str() const {
if (parts.size() == 1) {
return parts[0].val;
}
std::ostringstream oss;
for (const auto & part : parts) {
oss << part.val;
}
return oss.str();
}
size_t length() const {
size_t len = 0;
for (const auto & part : parts) {
len += part.val.length();
}
return len;
}
bool all_parts_are_input() const {
for (const auto & part : parts) {
if (!part.is_input) {
return false;
}
}
return true;
}
bool is_uppercase() const {
for (const auto & part : parts) {
if (!part.is_uppercase()) {
return false;
}
}
return true;
}
bool is_lowercase() const {
for (const auto & part : parts) {
if (!part.is_lowercase()) {
return false;
}
}
return true;
}
// mark this string as input if other has ALL parts as input
void mark_input_based_on(const string & other) {
if (other.all_parts_are_input()) {
for (auto & part : parts) {
part.is_input = true;
}
}
}
string append(const string & other) {
for (const auto & part : other.parts) {
parts.push_back(part);
}
return *this;
}
// in-place transformation
string apply_transform(const transform_fn & fn) {
for (auto & part : parts) {
part.val = fn(part.val);
}
return *this;
}
string uppercase() {
return apply_transform([](const std::string & s) {
std::string res = s;
std::transform(res.begin(), res.end(), res.begin(), ::toupper);
return res;
});
}
string lowercase() {
return apply_transform([](const std::string & s) {
std::string res = s;
std::transform(res.begin(), res.end(), res.begin(), ::tolower);
return res;
});
}
string capitalize() {
return apply_transform([](const std::string & s) {
if (s.empty()) return s;
std::string res = s;
res[0] = ::toupper(static_cast<unsigned char>(res[0]));
std::transform(res.begin() + 1, res.end(), res.begin() + 1, ::tolower);
return res;
});
}
string titlecase() {
return apply_transform([](const std::string & s) {
std::string res = s;
bool capitalize_next = true;
for (char &c : res) {
if (isspace(static_cast<unsigned char>(c))) {
capitalize_next = true;
} else if (capitalize_next) {
c = ::toupper(static_cast<unsigned char>(c));
capitalize_next = false;
} else {
c = ::tolower(static_cast<unsigned char>(c));
}
}
return res;
});
}
string strip(bool left, bool right) {
// TODO: what if leading/trailing continue in multiple parts?
static auto strip_part = [](const std::string & s, bool left, bool right) -> std::string {
size_t start = 0;
size_t end = s.length();
if (left) {
while (start < end && isspace(static_cast<unsigned char>(s[start]))) {
++start;
}
}
if (right) {
while (end > start && isspace(static_cast<unsigned char>(s[end - 1]))) {
--end;
}
}
return s.substr(start, end - start);
};
if (parts.empty()) {
return *this;
}
if (left) {
parts[0].val = strip_part(parts[0].val, true, false);
}
if (right) {
auto & last = parts[parts.size() - 1];
last.val = strip_part(last.val, false, true);
}
return *this;
}
};
} // namespace jinja

View File

@ -0,0 +1,26 @@
#pragma once
#include <string>
#include <algorithm>
#include <vector>
namespace jinja {
static void string_replace_all(std::string & s, const std::string & search, const std::string & replace) {
if (search.empty()) {
return;
}
std::string builder;
builder.reserve(s.length());
size_t pos = 0;
size_t last_pos = 0;
while ((pos = s.find(search, last_pos)) != std::string::npos) {
builder.append(s, last_pos, pos - last_pos);
builder.append(replace);
last_pos = pos + search.length();
}
builder.append(s, last_pos, std::string::npos);
s = std::move(builder);
}
} // namespace jinja

View File

@ -0,0 +1,734 @@
#include "jinja-lexer.h"
#include "jinja-vm.h"
#include "jinja-parser.h"
#include "jinja-value.h"
// for converting from JSON to jinja values
#include <nlohmann/json.hpp>
#include <string>
#include <cctype>
#include <vector>
#include <optional>
#include <algorithm>
#define FILENAME "jinja-value"
namespace jinja {
// func_args method implementations
value func_args::get_kwarg(const std::string & key) const {
for (const auto & arg : args) {
if (is_val<value_kwarg>(arg)) {
auto * kwarg = cast_val<value_kwarg>(arg);
if (kwarg->key == key) {
return kwarg->val;
}
}
}
return mk_val<value_undefined>();
}
/**
* Function that mimics Python's array slicing.
*/
template<typename T>
static T slice(const T & array, std::optional<int64_t> start = std::nullopt, std::optional<int64_t> stop = std::nullopt, int64_t step = 1) {
int64_t len = static_cast<int64_t>(array.size());
int64_t direction = (step > 0) ? 1 : ((step < 0) ? -1 : 0);
int64_t start_val;
int64_t stop_val;
if (direction >= 0) {
start_val = start.value_or(0);
if (start_val < 0) {
start_val = std::max(len + start_val, (int64_t)0);
} else {
start_val = std::min(start_val, len);
}
stop_val = stop.value_or(len);
if (stop_val < 0) {
stop_val = std::max(len + stop_val, (int64_t)0);
} else {
stop_val = std::min(stop_val, len);
}
} else {
start_val = start.value_or(len - 1);
if (start_val < 0) {
start_val = std::max(len + start_val, (int64_t)-1);
} else {
start_val = std::min(start_val, len - 1);
}
stop_val = stop.value_or(-1);
if (stop_val < -1) {
stop_val = std::max(len + stop_val, (int64_t)-1);
} else {
stop_val = std::min(stop_val, len - 1);
}
}
T result;
if (direction == 0) {
return result;
}
for (int64_t i = start_val; direction * i < direction * stop_val; i += step) {
if (i >= 0 && i < len) {
result.push_back(array[static_cast<size_t>(i)]);
}
}
return result;
}
template<typename T>
static value test_type_fn(const func_args & args) {
args.ensure_count(1);
bool is_type = is_val<T>(args.args[0]);
JJ_DEBUG("test_type_fn: type=%s result=%d", typeid(T).name(), is_type ? 1 : 0);
return mk_val<value_bool>(is_type);
}
template<typename T, typename U>
static value test_type_fn(const func_args & args) {
args.ensure_count(1);
bool is_type = is_val<T>(args.args[0]) || is_val<U>(args.args[0]);
JJ_DEBUG("test_type_fn: type=%s or %s result=%d", typeid(T).name(), typeid(U).name(), is_type ? 1 : 0);
return mk_val<value_bool>(is_type);
}
const func_builtins & global_builtins() {
static const func_builtins builtins = {
{"raise_exception", [](const func_args & args) -> value {
args.ensure_vals<value_string>();
std::string msg = args.args[0]->as_string().str();
throw raised_exception("Jinja Exception: " + msg);
}},
{"namespace", [](const func_args & args) -> value {
auto out = mk_val<value_object>();
for (const auto & arg : args.args) {
if (!is_val<value_kwarg>(arg)) {
throw raised_exception("namespace() arguments must be kwargs");
}
auto kwarg = cast_val<value_kwarg>(arg);
JJ_DEBUG("namespace: adding key '%s'", kwarg->key.c_str());
out->insert(kwarg->key, kwarg->val);
}
return out;
}},
{"strftime_now", [](const func_args & args) -> value {
args.ensure_vals<value_string>();
std::string format = args.args[0]->as_string().str();
// get current time
// TODO: make sure this is the same behavior as Python's strftime
char buf[100];
if (std::strftime(buf, sizeof(buf), format.c_str(), std::localtime(&args.ctx.current_time))) {
return mk_val<value_string>(std::string(buf));
} else {
throw raised_exception("strftime_now: failed to format time");
}
}},
{"range", [](const func_args & args) -> value {
args.ensure_count(1, 3);
args.ensure_vals<value_int, value_int, value_int>(true, false, false);
auto & arg0 = args.args[0];
auto & arg1 = args.args[1];
auto & arg2 = args.args[2];
int64_t start, stop, step;
if (args.args.size() == 1) {
start = 0;
stop = arg0->as_int();
step = 1;
} else if (args.args.size() == 2) {
start = arg0->as_int();
stop = arg1->as_int();
step = 1;
} else {
start = arg0->as_int();
stop = arg1->as_int();
step = arg2->as_int();
}
auto out = mk_val<value_array>();
if (step == 0) {
throw raised_exception("range() step argument must not be zero");
}
if (step > 0) {
for (int64_t i = start; i < stop; i += step) {
out->push_back(mk_val<value_int>(i));
}
} else {
for (int64_t i = start; i > stop; i += step) {
out->push_back(mk_val<value_int>(i));
}
}
return out;
}},
{"tojson", [](const func_args & args) -> value {
args.ensure_count(1);
// placeholder implementation
return mk_val<value_string>("TODO: to_json output");
}},
// tests
{"test_is_boolean", test_type_fn<value_bool>},
{"test_is_callable", test_type_fn<value_func>},
{"test_is_odd", [](const func_args & args) -> value {
args.ensure_vals<value_int>();
int64_t val = args.args[0]->as_int();
return mk_val<value_bool>(val % 2 != 0);
}},
{"test_is_even", [](const func_args & args) -> value {
args.ensure_vals<value_int>();
int64_t val = args.args[0]->as_int();
return mk_val<value_bool>(val % 2 == 0);
}},
{"test_is_false", [](const func_args & args) -> value {
args.ensure_count(1);
bool val = is_val<value_bool>(args.args[0]) && !args.args[0]->as_bool();
return mk_val<value_bool>(val);
}},
{"test_is_true", [](const func_args & args) -> value {
args.ensure_count(1);
bool val = is_val<value_bool>(args.args[0]) && args.args[0]->as_bool();
return mk_val<value_bool>(val);
}},
{"test_is_string", test_type_fn<value_string>},
{"test_is_integer", test_type_fn<value_int>},
{"test_is_number", test_type_fn<value_int, value_float>},
{"test_is_iterable", test_type_fn<value_array, value_string>},
{"test_is_sequence", test_type_fn<value_array, value_string>},
{"test_is_mapping", test_type_fn<value_object>},
{"test_is_lower", [](const func_args & args) -> value {
args.ensure_vals<value_string>();
return mk_val<value_bool>(args.args[0]->val_str.is_lowercase());
}},
{"test_is_upper", [](const func_args & args) -> value {
args.ensure_vals<value_string>();
return mk_val<value_bool>(args.args[0]->val_str.is_uppercase());
}},
{"test_is_none", test_type_fn<value_null>},
{"test_is_defined", [](const func_args & args) -> value {
args.ensure_count(1);
bool res = !args.args[0]->is_undefined();
JJ_DEBUG("test_is_defined: result=%d", res ? 1 : 0);
return mk_val<value_bool>(res);
}},
{"test_is_undefined", test_type_fn<value_undefined>},
};
return builtins;
}
const func_builtins & value_int_t::get_builtins() const {
static const func_builtins builtins = {
{"abs", [](const func_args & args) -> value {
args.ensure_vals<value_int>();
int64_t val = args.args[0]->as_int();
return mk_val<value_int>(val < 0 ? -val : val);
}},
{"float", [](const func_args & args) -> value {
args.ensure_vals<value_int>();
double val = static_cast<double>(args.args[0]->as_int());
return mk_val<value_float>(val);
}},
};
return builtins;
}
const func_builtins & value_float_t::get_builtins() const {
static const func_builtins builtins = {
{"abs", [](const func_args & args) -> value {
args.ensure_vals<value_float>();
double val = args.args[0]->as_float();
return mk_val<value_float>(val < 0.0 ? -val : val);
}},
{"int", [](const func_args & args) -> value {
args.ensure_vals<value_float>();
int64_t val = static_cast<int64_t>(args.args[0]->as_float());
return mk_val<value_int>(val);
}},
};
return builtins;
}
static bool string_startswith(const std::string & str, const std::string & prefix) {
if (str.length() < prefix.length()) return false;
return str.compare(0, prefix.length(), prefix) == 0;
}
static bool string_endswith(const std::string & str, const std::string & suffix) {
if (str.length() < suffix.length()) return false;
return str.compare(str.length() - suffix.length(), suffix.length(), suffix) == 0;
}
const func_builtins & value_string_t::get_builtins() const {
static const func_builtins builtins = {
{"upper", [](const func_args & args) -> value {
args.ensure_vals<value_string>();
jinja::string str = args.args[0]->as_string().uppercase();
return mk_val<value_string>(str);
}},
{"lower", [](const func_args & args) -> value {
args.ensure_vals<value_string>();
jinja::string str = args.args[0]->as_string().lowercase();
return mk_val<value_string>(str);
}},
{"strip", [](const func_args & args) -> value {
args.ensure_vals<value_string>();
jinja::string str = args.args[0]->as_string().strip(true, true);
return mk_val<value_string>(str);
}},
{"rstrip", [](const func_args & args) -> value {
args.ensure_vals<value_string>();
jinja::string str = args.args[0]->as_string().strip(false, true);
return mk_val<value_string>(str);
}},
{"lstrip", [](const func_args & args) -> value {
args.ensure_vals<value_string>();
jinja::string str = args.args[0]->as_string().strip(true, false);
return mk_val<value_string>(str);
}},
{"title", [](const func_args & args) -> value {
args.ensure_vals<value_string>();
jinja::string str = args.args[0]->as_string().titlecase();
return mk_val<value_string>(str);
}},
{"capitalize", [](const func_args & args) -> value {
args.ensure_vals<value_string>();
jinja::string str = args.args[0]->as_string().capitalize();
return mk_val<value_string>(str);
}},
{"length", [](const func_args & args) -> value {
args.ensure_vals<value_string>();
jinja::string str = args.args[0]->as_string();
return mk_val<value_int>(str.length());
}},
{"startswith", [](const func_args & args) -> value {
args.ensure_vals<value_string, value_string>();
std::string str = args.args[0]->as_string().str();
std::string prefix = args.args[1]->as_string().str();
return mk_val<value_bool>(string_startswith(str, prefix));
}},
{"endswith", [](const func_args & args) -> value {
args.ensure_vals<value_string, value_string>();
std::string str = args.args[0]->as_string().str();
std::string suffix = args.args[1]->as_string().str();
return mk_val<value_bool>(string_endswith(str, suffix));
}},
{"split", [](const func_args & args) -> value {
args.ensure_vals<value_string>();
std::string str = args.args[0]->as_string().str();
std::string delim = (args.args.size() > 1) ? args.args[1]->as_string().str() : " ";
auto result = mk_val<value_array>();
size_t pos = 0;
std::string token;
while ((pos = str.find(delim)) != std::string::npos) {
token = str.substr(0, pos);
result->push_back(mk_val<value_string>(token));
str.erase(0, pos + delim.length());
}
auto res = mk_val<value_string>(str);
res->val_str.mark_input_based_on(args.args[0]->val_str);
result->push_back(std::move(res));
return std::move(result);
}},
{"replace", [](const func_args & args) -> value {
args.ensure_vals<value_string, value_string, value_string>();
std::string str = args.args[0]->as_string().str();
std::string old_str = args.args[1]->as_string().str();
std::string new_str = args.args[2]->as_string().str();
size_t pos = 0;
while ((pos = str.find(old_str, pos)) != std::string::npos) {
str.replace(pos, old_str.length(), new_str);
pos += new_str.length();
}
auto res = mk_val<value_string>(str);
res->val_str.mark_input_based_on(args.args[0]->val_str);
return res;
}},
{"int", [](const func_args & args) -> value {
args.ensure_vals<value_string>();
std::string str = args.args[0]->as_string().str();
try {
return mk_val<value_int>(std::stoi(str));
} catch (...) {
throw std::runtime_error("Cannot convert string '" + str + "' to int");
}
}},
{"float", [](const func_args & args) -> value {
args.ensure_vals<value_string>();
std::string str = args.args[0]->as_string().str();
try {
return mk_val<value_float>(std::stod(str));
} catch (...) {
throw std::runtime_error("Cannot convert string '" + str + "' to float");
}
}},
{"string", [](const func_args & args) -> value {
// no-op
args.ensure_vals<value_string>();
return mk_val<value_string>(args.args[0]->as_string());
}},
{"default", [](const func_args & args) -> value {
value input = args.args[0];
if (!is_val<value_string>(input)) {
throw raised_exception("default() first argument must be a string");
}
value default_val = mk_val<value_string>("");
if (args.args.size() > 1 && !args.args[1]->is_undefined()) {
default_val = args.args[1];
}
value boolean_val = mk_val<value_bool>(false);
if (args.args.size() > 1) {
boolean_val = args.args[1];
}
if (input->is_undefined() || (boolean_val->as_bool() && !input->as_bool())) {
return default_val;
} else {
return input;
}
}},
{"slice", [](const func_args & args) -> value {
auto & input = args.args[0];
if (!is_val<value_string>(input)) {
throw raised_exception("slice() first argument must be a string");
}
if (args.args.size() < 1 || args.args.size() > 4) {
throw raised_exception("slice() takes between 1 and 4 arguments");
}
int64_t start = is_val<value_int>(args.args[1]) ? args.args[1]->as_int() : 0;
int64_t stop = is_val<value_int>(args.args[2]) ? args.args[2]->as_int() : -1;
int64_t step = is_val<value_int>(args.args[3]) ? args.args[3]->as_int() : 1;
if (step == 0) {
throw raised_exception("slice step cannot be zero");
}
auto sliced = slice(input->as_string().str(), start, stop, step);
auto res = mk_val<value_string>(sliced);
res->val_str.mark_input_based_on(input->as_string());
return res;
}},
{"safe", [](const func_args & args) -> value {
// no-op for now
args.ensure_vals<value_string>();
return args.args[0];
}},
{"selectattr", [](const func_args &) -> value {
throw std::runtime_error("String selectattr builtin not supported");
}},
{"rejectattr", [](const func_args &) -> value {
throw std::runtime_error("String rejectattr builtin not supported");
}},
{"indent", [](const func_args &) -> value {
throw std::runtime_error("String indent builtin not implemented");
}},
{"join", [](const func_args &) -> value {
throw std::runtime_error("String join builtin not implemented");
}},
};
return builtins;
}
const func_builtins & value_bool_t::get_builtins() const {
static const func_builtins builtins = {
{"int", [](const func_args & args) -> value {
args.ensure_vals<value_bool>();
bool val = args.args[0]->as_bool();
return mk_val<value_int>(val ? 1 : 0);
}},
{"float", [](const func_args & args) -> value {
args.ensure_vals<value_bool>();
bool val = args.args[0]->as_bool();
return mk_val<value_float>(val ? 1.0 : 0.0);
}},
{"string", [](const func_args & args) -> value {
args.ensure_vals<value_bool>();
bool val = args.args[0]->as_bool();
return mk_val<value_string>(val ? "True" : "False");
}},
};
return builtins;
}
const func_builtins & value_array_t::get_builtins() const {
static const func_builtins builtins = {
{"list", [](const func_args & args) -> value {
args.ensure_vals<value_array>();
const auto & arr = args.args[0]->as_array();
auto result = mk_val<value_array>();
for (const auto& v : arr) {
result->push_back(v);
}
return result;
}},
{"first", [](const func_args & args) -> value {
args.ensure_vals<value_array>();
const auto & arr = args.args[0]->as_array();
if (arr.empty()) {
return mk_val<value_undefined>();
}
return arr[0];
}},
{"last", [](const func_args & args) -> value {
args.ensure_vals<value_array>();
const auto & arr = args.args[0]->as_array();
if (arr.empty()) {
return mk_val<value_undefined>();
}
return arr[arr.size() - 1];
}},
{"length", [](const func_args & args) -> value {
args.ensure_vals<value_array>();
const auto & arr = args.args[0]->as_array();
return mk_val<value_int>(static_cast<int64_t>(arr.size()));
}},
{"slice", [](const func_args & args) -> value {
if (args.args.size() < 1 || args.args.size() > 4) {
throw raised_exception("slice() takes between 1 and 4 arguments");
}
int64_t start = is_val<value_int>(args.args[1]) ? args.args[1]->as_int() : 0;
int64_t stop = is_val<value_int>(args.args[2]) ? args.args[2]->as_int() : -1;
int64_t step = is_val<value_int>(args.args[3]) ? args.args[3]->as_int() : 1;
if (!is_val<value_array>(args.args[0])) {
throw raised_exception("slice() first argument must be an array");
}
if (step == 0) {
throw raised_exception("slice step cannot be zero");
}
auto arr = slice(args.args[0]->as_array(), start, stop, step);
auto res = mk_val<value_array>();
res->val_arr = std::move(arr);
return res;
}},
{"selectattr", [](const func_args & args) -> value {
value input = args.args[0];
if (!is_val<value_array>(input)) {
throw raised_exception("selectattr() first argument must be an array, got " + input->type());
}
std::vector<std::string> selected;
for (size_t i = 1; i < args.args.size(); ++i) {
const auto & v = args.args[i];
if (!is_val<value_string>(v)) {
throw raised_exception("selectattr() attributes must be strings, got " + v->type());
}
JJ_DEBUG("selectattr: selecting attribute '%s'", v->as_string().str().c_str());
selected.push_back(v->as_string().str());
}
auto result = mk_val<value_array>();
for (const auto & item : input->as_array()) {
if (!is_val<value_object>(item)) {
continue;
}
const auto & obj = item->as_object();
bool match = true;
for (const auto & attr : selected) {
auto it = obj.find(attr);
if (it == obj.end() || it->second->is_undefined() || (is_val<value_bool>(it->second) && !it->second->as_bool())) {
match = false;
break;
}
}
if (match) {
result->push_back(item);
}
}
return result;
}},
{"rejectattr", [](const func_args & args) -> value {
value input = args.args[0];
if (!is_val<value_array>(input)) {
throw raised_exception("rejectattr() first argument must be an array, got " + input->type());
}
std::vector<std::string> rejected;
for (size_t i = 1; i < args.args.size(); ++i) {
const auto & v = args.args[i];
if (!is_val<value_string>(v)) {
throw raised_exception("rejectattr() attributes must be strings, got " + v->type());
}
JJ_DEBUG("rejectattr: rejecting attribute '%s'", v->as_string().str().c_str());
rejected.push_back(v->as_string().str());
}
auto result = mk_val<value_array>();
for (const auto & item : input->as_array()) {
if (!is_val<value_object>(item)) {
result->push_back(item);
continue;
}
const auto & obj = item->as_object();
bool match = false;
for (const auto & attr : rejected) {
auto it = obj.find(attr);
if (it != obj.end() && !it->second->is_undefined() && (!is_val<value_bool>(it->second) || it->second->as_bool())) {
match = true;
break;
}
}
if (!match) {
result->push_back(item);
}
}
return result;
}},
{"join", [](const func_args & args) -> value {
if (args.args.size() < 1 || args.args.size() > 2) {
throw raised_exception("join() takes one or two arguments");
}
if (!is_val<value_array>(args.args[0])) {
throw raised_exception("join() first argument must be an array");
}
const auto & arr = args.args[0]->as_array();
std::string delim = (args.args.size() > 1 && is_val<value_string>(args.args[1])) ? args.args[1]->as_string().str() : "";
std::string result;
for (size_t i = 0; i < arr.size(); ++i) {
if (!is_val<value_string>(arr[i])) {
throw raised_exception("join() can only join arrays of strings");
}
result += arr[i]->as_string().str();
if (i < arr.size() - 1) {
result += delim;
}
}
return mk_val<value_string>(result);
}},
{"string", [](const func_args & args) -> value {
args.ensure_vals<value_array>();
auto str = mk_val<value_string>();
gather_string_parts_recursive(args.args[0], str);
return str;
}},
{"sort", [](const func_args &) -> value {
throw std::runtime_error("Array sort builtin not implemented");
}},
{"reverse", [](const func_args &) -> value {
throw std::runtime_error("Array reverse builtin not implemented");
}},
{"unique", [](const func_args &) -> value {
throw std::runtime_error("Array unique builtin not implemented");
}},
};
return builtins;
}
const func_builtins & value_object_t::get_builtins() const {
static const func_builtins builtins = {
{"get", [](const func_args & args) -> value {
args.ensure_vals<value_object, value_string>(); // TODO: add default value
const auto & obj = args.args[0]->as_object();
std::string key = args.args[1]->as_string().str();
auto it = obj.find(key);
if (it != obj.end()) {
return it->second;
} else {
return mk_val<value_undefined>();
}
}},
{"keys", [](const func_args & args) -> value {
args.ensure_vals<value_object>();
const auto & obj = args.args[0]->as_object();
auto result = mk_val<value_array>();
for (const auto & pair : obj) {
result->push_back(mk_val<value_string>(pair.first));
}
return result;
}},
{"values", [](const func_args & args) -> value {
args.ensure_vals<value_object>();
const auto & obj = args.args[0]->as_object();
auto result = mk_val<value_array>();
for (const auto & pair : obj) {
result->push_back(pair.second);
}
return result;
}},
{"items", [](const func_args & args) -> value {
args.ensure_vals<value_object>();
const auto & obj = args.args[0]->as_object();
auto result = mk_val<value_array>();
for (const auto & pair : obj) {
auto item = mk_val<value_array>();
item->push_back(mk_val<value_string>(pair.first));
item->push_back(pair.second);
result->push_back(std::move(item));
}
return result;
}},
{"string", [](const func_args & args) -> value {
args.ensure_vals<value_object>();
return mk_val<value_string>("TO BE IMPLEMENTED");
}},
{"tojson", [](const func_args & args) -> value {
args.ensure_vals<value_object>();
// use global to_json
return global_builtins().at("tojson")(args);
}},
{"dictsort", [](const func_args & args) -> value {
// no-op
args.ensure_vals<value_object>();
return args.args[0];
}},
};
return builtins;
}
const func_builtins & value_null_t::get_builtins() const {
static const func_builtins builtins = {
// TODO: may need to implement this, idk
};
return builtins;
}
//////////////////////////////////
static value from_json(const nlohmann::json & j) {
if (j.is_null()) {
return mk_val<value_null>();
} else if (j.is_boolean()) {
return mk_val<value_bool>(j.get<bool>());
} else if (j.is_number_integer()) {
return mk_val<value_int>(j.get<int64_t>());
} else if (j.is_number_float()) {
return mk_val<value_float>(j.get<double>());
} else if (j.is_string()) {
return mk_val<value_string>(j.get<std::string>());
} else if (j.is_array()) {
auto arr = mk_val<value_array>();
for (const auto & item : j) {
arr->push_back(from_json(item));
}
return arr;
} else if (j.is_object()) {
if (j.contains("__input__")) {
// handle input marking
auto str = mk_val<value_string>(j.at("__input__").get<std::string>());
str->mark_input();
return str;
} else {
// normal object
auto obj = mk_val<value_object>();
for (auto it = j.begin(); it != j.end(); ++it) {
obj->insert(it.key(), from_json(it.value()));
}
return obj;
}
} else {
throw std::runtime_error("Unsupported JSON value type");
}
}
template<>
void global_from_json(context & ctx, const nlohmann::json & json_obj) {
if (json_obj.is_null() || !json_obj.is_object()) {
throw std::runtime_error("global_from_json: input JSON value must be an object");
}
for (auto it = json_obj.begin(); it != json_obj.end(); ++it) {
ctx.set_val(it.key(), from_json(it.value()));
}
}
} // namespace jinja

352
common/jinja/jinja-value.h Normal file
View File

@ -0,0 +1,352 @@
#pragma once
#include <vector>
#include <string>
#include <map>
#include <functional>
#include <memory>
#include <sstream>
#include <set>
#include "jinja-string.h"
namespace jinja {
struct value_t;
using value = std::shared_ptr<value_t>;
// Helper to check the type of a value
template<typename T>
struct extract_pointee {
using type = T;
};
template<typename U>
struct extract_pointee<std::shared_ptr<U>> {
using type = U;
};
template<typename T>
bool is_val(const value & ptr) {
using PointeeType = typename extract_pointee<T>::type;
return dynamic_cast<const PointeeType*>(ptr.get()) != nullptr;
}
template<typename T>
bool is_val(const value_t * ptr) {
using PointeeType = typename extract_pointee<T>::type;
return dynamic_cast<const PointeeType*>(ptr) != nullptr;
}
template<typename T, typename... Args>
std::shared_ptr<typename extract_pointee<T>::type> mk_val(Args&&... args) {
using PointeeType = typename extract_pointee<T>::type;
return std::make_shared<PointeeType>(std::forward<Args>(args)...);
}
template<typename T>
const typename extract_pointee<T>::type * cast_val(const value & ptr) {
using PointeeType = typename extract_pointee<T>::type;
return dynamic_cast<const PointeeType*>(ptr.get());
}
template<typename T>
typename extract_pointee<T>::type * cast_val(value & ptr) {
using PointeeType = typename extract_pointee<T>::type;
return dynamic_cast<PointeeType*>(ptr.get());
}
// End Helper
struct context; // forward declaration
// for converting from JSON to jinja values
// example input JSON:
// {
// "messages": [
// {"role": "user", "content": "Hello!"},
// {"role": "assistant", "content": "Hi there!"}
// ],
// "bos_token": "<s>",
// "eos_token": "</s>",
// }
//
// to mark strings as user input, wrap them in a special object:
// {
// "messages": [
// {
// "role": "user",
// "content": {"__input__": "Hello!"} // this string is user input
// },
// ...
// ],
// }
//
// marking input can be useful for tracking data provenance
// and preventing template injection attacks
//
// Note: T_JSON can be nlohmann::json or similar types
template<typename T_JSON>
void global_from_json(context & ctx, const T_JSON & json_obj);
//
// base value type
//
struct func_args; // function argument values
using func_handler = std::function<value(const func_args &)>;
using func_builtins = std::map<std::string, func_handler>;
bool value_compare(const value & a, const value & b);
struct value_t {
int64_t val_int;
double val_flt;
string val_str;
bool val_bool;
std::vector<value> val_arr;
std::map<std::string, value> val_obj;
func_handler val_func;
// only used if ctx.is_get_stats = true
struct stats_t {
bool used = false;
// ops can be builtin calls or operators: "array_access", "object_access"
std::set<std::string> ops;
} stats;
value_t() = default;
value_t(const value_t &) = default;
virtual ~value_t() = default;
virtual std::string type() const { return ""; }
virtual int64_t as_int() const { throw std::runtime_error(type() + " is not an int value"); }
virtual double as_float() const { throw std::runtime_error(type() + " is not a float value"); }
virtual string as_string() const { throw std::runtime_error(type() + " is not a string value"); }
virtual bool as_bool() const { throw std::runtime_error(type() + " is not a bool value"); }
virtual const std::vector<value> & as_array() const { throw std::runtime_error(type() + " is not an array value"); }
virtual const std::map<std::string, value> & as_object() const { throw std::runtime_error(type() + " is not an object value"); }
virtual value invoke(const func_args &) const { throw std::runtime_error(type() + " is not a function value"); }
virtual bool is_null() const { return false; }
virtual bool is_undefined() const { return false; }
virtual const func_builtins & get_builtins() const {
throw std::runtime_error("No builtins available for type " + type());
}
virtual value & at(const std::string & key) { return val_obj[key]; }
virtual value & at(size_t index) { return val_arr.at(index); }
virtual std::string as_repr() const { return as_string().str(); }
};
//
// primitive value types
//
struct value_int_t : public value_t {
value_int_t(int64_t v) { val_int = v; }
virtual std::string type() const override { return "Integer"; }
virtual int64_t as_int() const override { return val_int; }
virtual double as_float() const override { return static_cast<double>(val_int); }
virtual string as_string() const override { return std::to_string(val_int); }
virtual const func_builtins & get_builtins() const override;
};
using value_int = std::shared_ptr<value_int_t>;
struct value_float_t : public value_t {
value_float_t(double v) { val_flt = v; }
virtual std::string type() const override { return "Float"; }
virtual double as_float() const override { return val_flt; }
virtual int64_t as_int() const override { return static_cast<int64_t>(val_flt); }
virtual string as_string() const override { return std::to_string(val_flt); }
virtual const func_builtins & get_builtins() const override;
};
using value_float = std::shared_ptr<value_float_t>;
struct value_string_t : public value_t {
value_string_t() { val_str = string(); }
value_string_t(const std::string & v) { val_str = string(v); }
value_string_t(const string & v) { val_str = v; }
virtual std::string type() const override { return "String"; }
virtual string as_string() const override { return val_str; }
virtual std::string as_repr() const override {
std::ostringstream ss;
for (const auto & part : val_str.parts) {
ss << (part.is_input ? "INPUT: " : "TMPL: ") << part.val << "\n";
}
return ss.str();
}
virtual bool as_bool() const override {
return val_str.length() > 0;
}
virtual const func_builtins & get_builtins() const override;
void mark_input() {
val_str.mark_input();
}
};
using value_string = std::shared_ptr<value_string_t>;
struct value_bool_t : public value_t {
value_bool_t(bool v) { val_bool = v; }
virtual std::string type() const override { return "Boolean"; }
virtual bool as_bool() const override { return val_bool; }
virtual string as_string() const override { return std::string(val_bool ? "True" : "False"); }
virtual const func_builtins & get_builtins() const override;
};
using value_bool = std::shared_ptr<value_bool_t>;
struct value_array_t : public value_t {
value_array_t() = default;
value_array_t(value & v) {
// point to the same underlying data
val_arr = v->val_arr;
}
void push_back(const value & val) {
val_arr.push_back(val);
}
virtual std::string type() const override { return "Array"; }
virtual const std::vector<value> & as_array() const override { return val_arr; }
virtual string as_string() const override {
std::ostringstream ss;
ss << "[";
for (size_t i = 0; i < val_arr.size(); i++) {
if (i > 0) ss << ", ";
ss << val_arr.at(i)->as_repr();
}
ss << "]";
return ss.str();
}
virtual bool as_bool() const override {
return !val_arr.empty();
}
virtual const func_builtins & get_builtins() const override;
};
using value_array = std::shared_ptr<value_array_t>;
struct value_object_t : public value_t {
value_object_t() = default;
value_object_t(value & v) {
// point to the same underlying data
val_obj = v->val_obj;
}
value_object_t(const std::map<std::string, value> & obj) {
val_obj = std::map<std::string, value>();
for (const auto & pair : obj) {
val_obj[pair.first] = pair.second;
}
}
void insert(const std::string & key, const value & val) {
val_obj[key] = val;
}
virtual std::string type() const override { return "Object"; }
virtual const std::map<std::string, value> & as_object() const override { return val_obj; }
virtual bool as_bool() const override {
return !val_obj.empty();
}
virtual const func_builtins & get_builtins() const override;
};
using value_object = std::shared_ptr<value_object_t>;
//
// null and undefined types
//
struct value_null_t : public value_t {
virtual std::string type() const override { return "Null"; }
virtual bool is_null() const override { return true; }
virtual bool as_bool() const override { return false; }
virtual std::string as_repr() const override { return type(); }
virtual const func_builtins & get_builtins() const override;
};
using value_null = std::shared_ptr<value_null_t>;
struct value_undefined_t : public value_t {
std::string hint; // for debugging, to indicate where undefined came from
value_undefined_t(const std::string & h = "") : hint(h) {}
virtual std::string type() const override { return hint.empty() ? "Undefined" : "Undefined (hint: '" + hint + "')"; }
virtual bool is_undefined() const override { return true; }
virtual bool as_bool() const override { return false; }
virtual std::string as_repr() const override { return type(); }
};
using value_undefined = std::shared_ptr<value_undefined_t>;
//
// function type
//
struct func_args {
std::string func_name; // for error messages
std::vector<value> args;
context & ctx;
func_args(context & ctx) : ctx(ctx) {}
value get_kwarg(const std::string & key) const;
void ensure_count(size_t min, size_t max = 999) const {
size_t n = args.size();
if (n < min || n > max) {
throw std::runtime_error("Function '" + func_name + "' expected between " + std::to_string(min) + " and " + std::to_string(max) + " arguments, got " + std::to_string(n));
}
}
template<typename T> void ensure_val(const value & ptr) const {
if (!is_val<T>(ptr)) {
throw std::runtime_error("Function '" + func_name + "' expected value of type " + std::string(typeid(T).name()) + ", got " + ptr->type());
}
}
template<typename T0> void ensure_vals(bool required0 = true) const {
if (required0 && args.size() > 0) ensure_val<T0>(args[0]);
}
template<typename T0, typename T1> void ensure_vals(bool required0 = true, bool required1 = true) const {
if (required0 && args.size() > 0) ensure_val<T0>(args[0]);
if (required1 && args.size() > 1) ensure_val<T1>(args[1]);
}
template<typename T0, typename T1, typename T2> void ensure_vals(bool required0 = true, bool required1 = true, bool required2 = true) const {
if (required0 && args.size() > 0) ensure_val<T0>(args[0]);
if (required1 && args.size() > 1) ensure_val<T1>(args[1]);
if (required2 && args.size() > 2) ensure_val<T2>(args[2]);
}
};
struct value_func_t : public value_t {
std::string name;
value arg0; // bound "this" argument, if any
value_func_t(const std::string & name, const func_handler & func) : name(name) {
val_func = func;
}
value_func_t(const std::string & name, const func_handler & func, const value & arg_this) : name(name), arg0(arg_this) {
val_func = func;
}
virtual value invoke(const func_args & args) const override {
func_args new_args(args); // copy
new_args.func_name = name;
if (arg0) {
new_args.args.insert(new_args.args.begin(), arg0);
}
return val_func(new_args);
}
virtual std::string type() const override { return "Function"; }
virtual std::string as_repr() const override { return type(); }
};
using value_func = std::shared_ptr<value_func_t>;
// special value for kwarg
struct value_kwarg_t : public value_t {
std::string key;
value val;
value_kwarg_t(const std::string & k, const value & v) : key(k), val(v) {}
virtual std::string type() const override { return "KwArg"; }
virtual std::string as_repr() const override { return type(); }
};
using value_kwarg = std::shared_ptr<value_kwarg_t>;
// utils
const func_builtins & global_builtins();
} // namespace jinja

794
common/jinja/jinja-vm.cpp Normal file
View File

@ -0,0 +1,794 @@
#include "jinja-lexer.h"
#include "jinja-vm.h"
#include "jinja-parser.h"
#include "jinja-value.h"
#include "jinja-utils.h"
#include <string>
#include <vector>
#include <memory>
#include <algorithm>
#define FILENAME "jinja-vm"
bool g_jinja_debug = false;
namespace jinja {
void enable_debug(bool enable) {
g_jinja_debug = enable;
}
static value_string exec_statements(const statements & stmts, context & ctx) {
auto result = mk_val<value_array>();
for (const auto & stmt : stmts) {
JJ_DEBUG("Executing statement of type %s", stmt->type().c_str());
result->push_back(stmt->execute(ctx));
}
// convert to string parts
value_string str = mk_val<value_string>();
gather_string_parts_recursive(result, str);
return str;
}
// execute with error handling
value statement::execute(context & ctx) {
try {
return execute_impl(ctx);
} catch (const continue_statement::signal & ex) {
throw ex;
} catch (const break_statement::signal & ex) {
throw ex;
} catch (const std::exception & e) {
if (ctx.source.empty()) {
std::ostringstream oss;
oss << "\nError executing " << type() << " at position " << pos << ": " << e.what();
throw raised_exception(oss.str());
} else {
std::ostringstream oss;
constexpr int max_peak_chars = 40;
oss << "\n------------\n";
oss << "While executing " << type() << " at position " << pos << " in source:\n";
size_t start = (pos >= max_peak_chars) ? (pos - max_peak_chars) : 0;
size_t end = std::min(pos + max_peak_chars, ctx.source.length());
std::string substr = ctx.source.substr(start, end - start);
string_replace_all(substr, "\n", "\\n");
oss << "..." << substr << "...\n";
std::string spaces(pos - start + 3, ' ');
oss << spaces << "^\n";
oss << "Error: " << e.what();
throw raised_exception(oss.str());
}
}
}
value identifier::execute_impl(context & ctx) {
auto it = ctx.get_val(val);
auto builtins = global_builtins();
if (!it->is_undefined()) {
if (ctx.is_get_stats) {
it->stats.used = true;
}
JJ_DEBUG("Identifier '%s' found", val.c_str());
return it;
} else if (builtins.find(val) != builtins.end()) {
JJ_DEBUG("Identifier '%s' found in builtins", val.c_str());
return mk_val<value_func>(val, builtins.at(val));
} else {
JJ_DEBUG("Identifier '%s' not found, returning undefined", val.c_str());
return mk_val<value_undefined>(val);
}
}
value object_literal::execute_impl(context & ctx) {
auto obj = mk_val<value_object>();
for (const auto & pair : val) {
std::string key = pair.first->execute(ctx)->as_string().str();
value val = pair.second->execute(ctx);
JJ_DEBUG("Object literal: setting key '%s' of type %s", key.c_str(), val->type().c_str());
obj->val_obj[key] = val;
}
return obj;
}
value binary_expression::execute_impl(context & ctx) {
value left_val = left->execute(ctx);
// Logical operators
if (op.value == "and") {
return left_val->as_bool() ? right->execute(ctx) : std::move(left_val);
} else if (op.value == "or") {
return left_val->as_bool() ? std::move(left_val) : right->execute(ctx);
}
// Equality operators
value right_val = right->execute(ctx);
JJ_DEBUG("Executing binary expression %s '%s' %s", left_val->type().c_str(), op.value.c_str(), right_val->type().c_str());
if (op.value == "==") {
return mk_val<value_bool>(value_compare(left_val, right_val));
} else if (op.value == "!=") {
return mk_val<value_bool>(!value_compare(left_val, right_val));
}
// Handle undefined and null values
if (is_val<value_undefined>(left_val) || is_val<value_undefined>(right_val)) {
if (is_val<value_undefined>(right_val) && (op.value == "in" || op.value == "not in")) {
// Special case: `anything in undefined` is `false` and `anything not in undefined` is `true`
return mk_val<value_bool>(op.value == "not in");
}
// if (ctx.wrk_around.string_plus_undefined_is_string && (op.value == "+" || op.value == "~")) {
// JJ_DEBUG("%s", "Workaround: treating undefined as empty string for string concatenation");
// auto left_str = left_val->is_undefined() ? string() : left_val->as_string();
// auto right_str = right_val->is_undefined() ? string() : right_val->as_string();
// auto output = left_str.append(right_str);
// auto res = mk_val<value_string>();
// res->val_str = std::move(output);
// return res;
// }
throw std::runtime_error("Cannot perform operation " + op.value + " on undefined values");
} else if (is_val<value_null>(left_val) || is_val<value_null>(right_val)) {
throw std::runtime_error("Cannot perform operation on null values");
}
// Float operations
if ((is_val<value_int>(left_val) || is_val<value_float>(left_val)) &&
(is_val<value_int>(right_val) || is_val<value_float>(right_val))) {
double a = left_val->as_float();
double b = right_val->as_float();
if (op.value == "+" || op.value == "-" || op.value == "*") {
double res = (op.value == "+") ? a + b : (op.value == "-") ? a - b : a * b;
JJ_DEBUG("Arithmetic operation: %f %s %f = %f", a, op.value.c_str(), b, res);
bool is_float = is_val<value_float>(left_val) || is_val<value_float>(right_val);
if (is_float) {
return mk_val<value_float>(res);
} else {
return mk_val<value_int>(static_cast<int64_t>(res));
}
} else if (op.value == "/") {
JJ_DEBUG("Division operation: %f / %f", a, b);
return mk_val<value_float>(a / b);
} else if (op.value == "%") {
double rem = std::fmod(a, b);
JJ_DEBUG("Modulo operation: %f %% %f = %f", a, b, rem);
bool is_float = is_val<value_float>(left_val) || is_val<value_float>(right_val);
if (is_float) {
return mk_val<value_float>(rem);
} else {
return mk_val<value_int>(static_cast<int64_t>(rem));
}
} else if (op.value == "<") {
JJ_DEBUG("Comparison operation: %f < %f is %d", a, b, a < b);
return mk_val<value_bool>(a < b);
} else if (op.value == ">") {
JJ_DEBUG("Comparison operation: %f > %f is %d", a, b, a > b);
return mk_val<value_bool>(a > b);
} else if (op.value == ">=") {
JJ_DEBUG("Comparison operation: %f >= %f is %d", a, b, a >= b);
return mk_val<value_bool>(a >= b);
} else if (op.value == "<=") {
JJ_DEBUG("Comparison operation: %f <= %f is %d", a, b, a <= b);
return mk_val<value_bool>(a <= b);
}
}
// Array operations
if (is_val<value_array>(left_val) && is_val<value_array>(right_val)) {
if (op.value == "+") {
auto & left_arr = left_val->as_array();
auto & right_arr = right_val->as_array();
auto result = mk_val<value_array>();
for (const auto & item : left_arr) {
result->push_back(item);
}
for (const auto & item : right_arr) {
result->push_back(item);
}
return result;
}
} else if (is_val<value_array>(right_val)) {
auto & arr = right_val->as_array();
bool member = false;
for (const auto & item : arr) {
if (value_compare(left_val, item)) {
member = true;
break;
}
}
if (op.value == "in") {
JJ_DEBUG("Checking membership: %s in Array is %d", left_val->type().c_str(), member);
return mk_val<value_bool>(member);
} else if (op.value == "not in") {
JJ_DEBUG("Checking non-membership: %s not in Array is %d", left_val->type().c_str(), !member);
return mk_val<value_bool>(!member);
}
}
// String concatenation with ~ and +
if ((is_val<value_string>(left_val) || is_val<value_string>(right_val)) &&
(op.value == "~" || op.value == "+")) {
JJ_DEBUG("String concatenation with %s operator", op.value.c_str());
auto output = left_val->as_string().append(right_val->as_string());
auto res = mk_val<value_string>();
res->val_str = std::move(output);
return res;
}
// String membership
if (is_val<value_string>(left_val) && is_val<value_string>(right_val)) {
auto left_str = left_val->as_string().str();
auto right_str = right_val->as_string().str();
if (op.value == "in") {
return mk_val<value_bool>(right_str.find(left_str) != std::string::npos);
} else if (op.value == "not in") {
return mk_val<value_bool>(right_str.find(left_str) == std::string::npos);
}
}
// String in object
if (is_val<value_string>(left_val) && is_val<value_object>(right_val)) {
auto key = left_val->as_string().str();
auto & obj = right_val->as_object();
bool has_key = obj.find(key) != obj.end();
if (op.value == "in") {
return mk_val<value_bool>(has_key);
} else if (op.value == "not in") {
return mk_val<value_bool>(!has_key);
}
}
throw std::runtime_error("Unknown operator \"" + op.value + "\" between " + left_val->type() + " and " + right_val->type());
}
static value try_builtin_func(context & ctx, const std::string & name, value & input, bool undef_on_missing = false) {
JJ_DEBUG("Trying built-in function '%s' for type %s", name.c_str(), input->type().c_str());
if (ctx.is_get_stats) {
input->stats.used = true;
input->stats.ops.insert(name);
}
auto builtins = input->get_builtins();
auto it = builtins.find(name);
if (it != builtins.end()) {
JJ_DEBUG("Binding built-in '%s'", name.c_str());
return mk_val<value_func>(name, it->second, input);
}
if (undef_on_missing) {
return mk_val<value_undefined>(name);
}
throw std::runtime_error("Unknown (built-in) filter '" + name + "' for type " + input->type());
}
value filter_expression::execute_impl(context & ctx) {
value input = operand ? operand->execute(ctx) : val;
JJ_DEBUG("Applying filter to %s", input->type().c_str());
if (is_stmt<identifier>(filter)) {
auto filter_id = cast_stmt<identifier>(filter)->val;
if (filter_id == "to_json") {
// TODO: Implement to_json filter
throw std::runtime_error("to_json filter not implemented");
}
if (filter_id == "trim") {
filter_id = "strip"; // alias
}
JJ_DEBUG("Applying filter '%s' to %s", filter_id.c_str(), input->type().c_str());
return try_builtin_func(ctx, filter_id, input)->invoke(func_args(ctx));
} else if (is_stmt<call_expression>(filter)) {
auto call = cast_stmt<call_expression>(filter);
auto filter_id = cast_stmt<identifier>(call->callee)->val;
JJ_DEBUG("Applying filter '%s' with arguments to %s", filter_id.c_str(), input->type().c_str());
func_args args(ctx);
for (const auto & arg_expr : call->args) {
args.args.push_back(arg_expr->execute(ctx));
}
return try_builtin_func(ctx, filter_id, input)->invoke(args);
} else {
throw std::runtime_error("Invalid filter expression");
}
}
value filter_statement::execute_impl(context & ctx) {
// eval body as string, then apply filter
auto body_val = exec_statements(body, ctx);
value_string parts = mk_val<value_string>();
gather_string_parts_recursive(body_val, parts);
JJ_DEBUG("FilterStatement: applying filter to body string of length %zu", parts->val_str.length());
filter_expression filter_expr(std::move(parts), std::move(filter));
return filter_expr.execute(ctx);
}
value test_expression::execute_impl(context & ctx) {
// NOTE: "value is something" translates to function call "test_is_something(value)"
const auto & builtins = global_builtins();
if (!is_stmt<identifier>(test)) {
throw std::runtime_error("Invalid test expression");
}
auto test_id = cast_stmt<identifier>(test)->val;
auto it = builtins.find("test_is_" + test_id);
JJ_DEBUG("Test expression %s '%s' %s (using function 'test_is_%s')", operand->type().c_str(), test_id.c_str(), negate ? "(negate)" : "", test_id.c_str());
if (it == builtins.end()) {
throw std::runtime_error("Unknown test '" + test_id + "'");
}
value input = operand->execute(ctx);
func_args args(ctx);
args.args.push_back(input);
auto res = it->second(args);
if (negate) {
return mk_val<value_bool>(!res->as_bool());
} else {
return res;
}
}
value unary_expression::execute_impl(context & ctx) {
value operand_val = argument->execute(ctx);
JJ_DEBUG("Executing unary expression with operator '%s'", op.value.c_str());
if (op.value == "not") {
return mk_val<value_bool>(!operand_val->as_bool());
} else if (op.value == "-") {
if (is_val<value_int>(operand_val)) {
return mk_val<value_int>(-operand_val->as_int());
} else if (is_val<value_float>(operand_val)) {
return mk_val<value_float>(-operand_val->as_float());
} else {
throw std::runtime_error("Unary - operator requires numeric operand");
}
}
throw std::runtime_error("Unknown unary operator '" + op.value + "'");
}
value if_statement::execute_impl(context & ctx) {
value test_val = test->execute(ctx);
auto out = mk_val<value_array>();
if (test_val->as_bool()) {
for (auto & stmt : body) {
JJ_DEBUG("IF --> Executing THEN body, current block: %s", stmt->type().c_str());
out->push_back(stmt->execute(ctx));
}
} else {
for (auto & stmt : alternate) {
JJ_DEBUG("IF --> Executing ELSE body, current block: %s", stmt->type().c_str());
out->push_back(stmt->execute(ctx));
}
}
// convert to string parts
value_string str = mk_val<value_string>();
gather_string_parts_recursive(out, str);
return str;
}
value for_statement::execute_impl(context & ctx) {
context scope(ctx); // new scope for loop variables
jinja::select_expression * select_expr = cast_stmt<select_expression>(iterable);
statement_ptr test_expr_nullptr;
statement_ptr & iter_expr = [&]() -> statement_ptr & {
auto tmp = cast_stmt<select_expression>(iterable);
return tmp ? tmp->lhs : iterable;
}();
statement_ptr & test_expr = [&]() -> statement_ptr & {
auto tmp = cast_stmt<select_expression>(iterable);
return tmp ? tmp->test : test_expr_nullptr;
}();
JJ_DEBUG("Executing for statement, iterable type: %s", iter_expr->type().c_str());
value iterable_val = iter_expr->execute(scope);
if (iterable_val->is_undefined()) {
JJ_DEBUG("%s", "For loop iterable is undefined, skipping loop");
iterable_val = mk_val<value_array>();
}
if (!is_val<value_array>(iterable_val) && !is_val<value_object>(iterable_val)) {
throw std::runtime_error("Expected iterable or object type in for loop: got " + iterable_val->type());
}
std::vector<value> items;
if (is_val<value_object>(iterable_val)) {
JJ_DEBUG("%s", "For loop over object keys");
auto & obj = iterable_val->as_object();
for (auto & p : obj) {
auto tuple = mk_val<value_array>();
tuple->push_back(mk_val<value_string>(p.first));
tuple->push_back(p.second);
items.push_back(tuple);
}
if (ctx.is_get_stats) {
iterable_val->stats.used = true;
iterable_val->stats.ops.insert("object_access");
}
} else {
JJ_DEBUG("%s", "For loop over array items");
auto & arr = iterable_val->as_array();
for (const auto & item : arr) {
items.push_back(item);
}
if (ctx.is_get_stats) {
iterable_val->stats.used = true;
iterable_val->stats.ops.insert("array_access");
}
}
std::vector<std::function<void(context &)>> scope_update_fns;
std::vector<value> filtered_items;
for (size_t i = 0; i < items.size(); ++i) {
context loop_scope(scope);
const value & current = items[i];
std::function<void(context&)> scope_update_fn = [](context &) { /* no-op */};
if (is_stmt<identifier>(loopvar)) {
auto id = cast_stmt<identifier>(loopvar)->val;
scope_update_fn = [id, &items, i](context & ctx) {
ctx.set_val(id, items[i]);
};
} else if (is_stmt<tuple_literal>(loopvar)) {
auto tuple = cast_stmt<tuple_literal>(loopvar);
if (!is_val<value_array>(current)) {
throw std::runtime_error("Cannot unpack non-iterable type: " + current->type());
}
auto & c_arr = current->as_array();
if (tuple->val.size() != c_arr.size()) {
throw std::runtime_error(std::string("Too ") + (tuple->val.size() > c_arr.size() ? "few" : "many") + " items to unpack");
}
scope_update_fn = [tuple, &items, i](context & ctx) {
auto & c_arr = items[i]->as_array();
for (size_t j = 0; j < tuple->val.size(); ++j) {
if (!is_stmt<identifier>(tuple->val[j])) {
throw std::runtime_error("Cannot unpack non-identifier type: " + tuple->val[j]->type());
}
auto id = cast_stmt<identifier>(tuple->val[j])->val;
ctx.set_val(id, c_arr[j]);
}
};
} else {
throw std::runtime_error("Invalid loop variable(s): " + loopvar->type());
}
if (select_expr && test_expr) {
scope_update_fn(loop_scope);
value test_val = test_expr->execute(loop_scope);
if (!test_val->as_bool()) {
continue;
}
}
JJ_DEBUG("For loop: adding item type %s at index %zu", current->type().c_str(), i);
filtered_items.push_back(current);
scope_update_fns.push_back(scope_update_fn);
}
JJ_DEBUG("For loop: %zu items after filtering", filtered_items.size());
auto result = mk_val<value_array>();
bool noIteration = true;
for (size_t i = 0; i < filtered_items.size(); i++) {
JJ_DEBUG("For loop iteration %zu/%zu", i + 1, filtered_items.size());
value_object loop_obj = mk_val<value_object>();
loop_obj->insert("index", mk_val<value_int>(i + 1));
loop_obj->insert("index0", mk_val<value_int>(i));
loop_obj->insert("revindex", mk_val<value_int>(filtered_items.size() - i));
loop_obj->insert("revindex0", mk_val<value_int>(filtered_items.size() - i - 1));
loop_obj->insert("first", mk_val<value_bool>(i == 0));
loop_obj->insert("last", mk_val<value_bool>(i == filtered_items.size() - 1));
loop_obj->insert("length", mk_val<value_int>(filtered_items.size()));
loop_obj->insert("previtem", i > 0 ? filtered_items[i - 1] : mk_val<value_undefined>("previtem"));
loop_obj->insert("nextitem", i < filtered_items.size() - 1 ? filtered_items[i + 1] : mk_val<value_undefined>("nextitem"));
scope.set_val("loop", loop_obj);
scope_update_fns[i](scope);
try {
for (auto & stmt : body) {
value val = stmt->execute(scope);
result->push_back(val);
}
} catch (const continue_statement::signal &) {
continue;
} catch (const break_statement::signal &) {
break;
}
noIteration = false;
}
JJ_DEBUG("For loop complete, total iterations: %zu", filtered_items.size());
if (noIteration) {
for (auto & stmt : default_block) {
value val = stmt->execute(ctx);
result->push_back(val);
}
}
// convert to string parts
value_string str = mk_val<value_string>();
gather_string_parts_recursive(result, str);
return str;
}
value set_statement::execute_impl(context & ctx) {
auto rhs = val ? val->execute(ctx) : exec_statements(body, ctx);
if (is_stmt<identifier>(assignee)) {
auto var_name = cast_stmt<identifier>(assignee)->val;
JJ_DEBUG("Setting variable '%s' with value type %s", var_name.c_str(), rhs->type().c_str());
ctx.set_val(var_name, rhs);
} else if (is_stmt<tuple_literal>(assignee)) {
auto tuple = cast_stmt<tuple_literal>(assignee);
if (!is_val<value_array>(rhs)) {
throw std::runtime_error("Cannot unpack non-iterable type in set: " + rhs->type());
}
auto & arr = rhs->as_array();
if (arr.size() != tuple->val.size()) {
throw std::runtime_error(std::string("Too ") + (tuple->val.size() > arr.size() ? "few" : "many") + " items to unpack in set");
}
for (size_t i = 0; i < tuple->val.size(); ++i) {
auto & elem = tuple->val[i];
if (!is_stmt<identifier>(elem)) {
throw std::runtime_error("Cannot unpack to non-identifier in set: " + elem->type());
}
auto var_name = cast_stmt<identifier>(elem)->val;
ctx.set_val(var_name, arr[i]);
}
} else if (is_stmt<member_expression>(assignee)) {
auto member = cast_stmt<member_expression>(assignee);
if (member->computed) {
throw std::runtime_error("Cannot assign to computed member");
}
if (!is_stmt<identifier>(member->property)) {
throw std::runtime_error("Cannot assign to member with non-identifier property");
}
auto prop_name = cast_stmt<identifier>(member->property)->val;
value object = member->object->execute(ctx);
if (!is_val<value_object>(object)) {
throw std::runtime_error("Cannot assign to member of non-object");
}
auto obj_ptr = cast_val<value_object>(object);
JJ_DEBUG("Setting object property '%s'", prop_name.c_str());
obj_ptr->insert(prop_name, rhs);
} else {
throw std::runtime_error("Invalid LHS inside assignment expression: " + assignee->type());
}
return mk_val<value_null>();
}
value macro_statement::execute_impl(context & ctx) {
std::string name = cast_stmt<identifier>(this->name)->val;
const func_handler func = [this, name, &ctx](const func_args & args) -> value {
size_t expected_count = this->args.size();
size_t input_count = args.args.size();
JJ_DEBUG("Invoking macro '%s' with %zu input arguments (expected %zu)", name.c_str(), input_count, expected_count);
context macro_ctx(ctx); // new scope for macro execution
// bind parameters
for (size_t i = 0; i < expected_count; ++i) {
if (i < input_count) {
if (is_stmt<identifier>(this->args[i])) {
// normal parameter
std::string param_name = cast_stmt<identifier>(this->args[i])->val;
JJ_DEBUG(" Binding parameter '%s' to argument of type %s", param_name.c_str(), args.args[i]->type().c_str());
macro_ctx.set_val(param_name, args.args[i]);
} else if (is_stmt<keyword_argument_expression>(this->args[i])) {
// default argument used as normal parameter
auto kwarg = cast_stmt<keyword_argument_expression>(this->args[i]);
std::string param_name = cast_stmt<identifier>(kwarg->key)->val;
JJ_DEBUG(" Binding parameter '%s' to argument of type %s", param_name.c_str(), args.args[i]->type().c_str());
macro_ctx.set_val(param_name, args.args[i]);
} else {
throw std::runtime_error("Invalid parameter type in macro '" + name + "'");
}
} else {
auto & default_arg = this->args[i];
if (is_stmt<keyword_argument_expression>(default_arg)) {
auto kwarg = cast_stmt<keyword_argument_expression>(default_arg);
std::string param_name = cast_stmt<identifier>(kwarg->key)->val;
JJ_DEBUG(" Binding parameter '%s' to default argument of type %s", param_name.c_str(), kwarg->val->type().c_str());
macro_ctx.set_val(param_name, kwarg->val->execute(ctx));
} else {
throw std::runtime_error("Not enough arguments provided to macro '" + name + "'");
}
//std::string param_name = cast_stmt<identifier>(default_args[i])->val;
//JJ_DEBUG(" Binding parameter '%s' to default", param_name.c_str());
//macro_ctx.var[param_name] = default_args[i]->execute(ctx);
}
}
// execute macro body
JJ_DEBUG("Executing macro '%s' body with %zu statements", name.c_str(), this->body.size());
auto res = exec_statements(this->body, macro_ctx);
JJ_DEBUG("Macro '%s' execution complete, result: %s", name.c_str(), res->val_str.str().c_str());
return res;
};
JJ_DEBUG("Defining macro '%s' with %zu parameters", name.c_str(), args.size());
ctx.set_val(name, mk_val<value_func>(name, func));
return mk_val<value_null>();
}
value member_expression::execute_impl(context & ctx) {
value object = this->object->execute(ctx);
value property;
if (this->computed) {
JJ_DEBUG("Member expression, computing property type %s", this->property->type().c_str());
if (is_stmt<slice_expression>(this->property)) {
auto s = cast_stmt<slice_expression>(this->property);
value start_val = s->start_expr ? s->start_expr->execute(ctx) : mk_val<value_undefined>("start");
value stop_val = s->stop_expr ? s->stop_expr->execute(ctx) : mk_val<value_undefined>("stop");
value step_val = s->step_expr ? s->step_expr->execute(ctx) : mk_val<value_undefined>("step");
// translate to function call: obj.slice(start, stop, step)
JJ_DEBUG("Member expression is a slice: start %s, stop %s, step %s",
start_val->as_repr().c_str(),
stop_val->as_repr().c_str(),
step_val->as_repr().c_str());
auto slice_func = try_builtin_func(ctx, "slice", object);
func_args args(ctx);
args.args.push_back(start_val);
args.args.push_back(stop_val);
args.args.push_back(step_val);
return slice_func->invoke(args);
} else {
property = this->property->execute(ctx);
}
} else {
property = mk_val<value_string>(cast_stmt<identifier>(this->property)->val);
}
JJ_DEBUG("Member expression on object type %s, property type %s", object->type().c_str(), property->type().c_str());
value val = mk_val<value_undefined>("object_property");
if (is_val<value_undefined>(object)) {
JJ_DEBUG("%s", "Accessing property on undefined object, returning undefined");
return val;
} else if (is_val<value_object>(object)) {
if (!is_val<value_string>(property)) {
throw std::runtime_error("Cannot access object with non-string: got " + property->type());
}
auto key = property->as_string().str();
auto & obj = object->as_object();
auto it = obj.find(key);
if (it != obj.end()) {
val = it->second;
} else {
val = try_builtin_func(ctx, key, object, true);
}
JJ_DEBUG("Accessed property '%s' value, got type: %s", key.c_str(), val->type().c_str());
} else if (is_val<value_array>(object) || is_val<value_string>(object)) {
if (is_val<value_int>(property)) {
int64_t index = property->as_int();
JJ_DEBUG("Accessing %s index %lld", object->type().c_str(), index);
if (is_val<value_array>(object)) {
auto & arr = object->as_array();
if (index < 0) {
index += static_cast<int64_t>(arr.size());
}
if (index >= 0 && index < static_cast<int64_t>(arr.size())) {
val = arr[index];
}
} else { // value_string
auto str = object->as_string().str();
if (index >= 0 && index < static_cast<int64_t>(str.size())) {
val = mk_val<value_string>(std::string(1, str[index]));
}
}
} else if (is_val<value_string>(property)) {
auto key = property->as_string().str();
JJ_DEBUG("Accessing %s built-in '%s'", is_val<value_array>(object) ? "array" : "string", key.c_str());
val = try_builtin_func(ctx, key, object);
} else {
throw std::runtime_error("Cannot access property with non-string/non-number: got " + property->type());
}
} else {
if (!is_val<value_string>(property)) {
throw std::runtime_error("Cannot access property with non-string: got " + property->type());
}
auto key = property->as_string().str();
val = try_builtin_func(ctx, key, object);
}
if (ctx.is_get_stats && val && object && property) {
val->stats.used = true;
object->stats.used = true;
if (is_val<value_int>(property)) {
object->stats.ops.insert("array_access");
} else if (is_val<value_string>(property)) {
object->stats.ops.insert("object_access");
}
}
return val;
}
value call_expression::execute_impl(context & ctx) {
// gather arguments
func_args args(ctx);
for (auto & arg_stmt : this->args) {
auto arg_val = arg_stmt->execute(ctx);
JJ_DEBUG(" Argument type: %s", arg_val->type().c_str());
args.args.push_back(std::move(arg_val));
}
// execute callee
value callee_val = callee->execute(ctx);
if (!is_val<value_func>(callee_val)) {
throw std::runtime_error("Callee is not a function: got " + callee_val->type());
}
auto * callee_func = cast_val<value_func>(callee_val);
JJ_DEBUG("Calling function '%s' with %zu arguments", callee_func->name.c_str(), args.args.size());
return callee_func->invoke(args);
}
// compare operator for value_t
bool value_compare(const value & a, const value & b) {
auto cmp = [&]() {
// compare numeric types
if ((is_val<value_int>(a) || is_val<value_float>(a)) &&
(is_val<value_int>(b) || is_val<value_float>(b))){
try {
return a->as_float() == b->as_float();
} catch (...) {}
}
// compare string and number
// TODO: not sure if this is the right behavior
if ((is_val<value_string>(b) && (is_val<value_int>(a) || is_val<value_float>(a))) ||
(is_val<value_string>(a) && (is_val<value_int>(b) || is_val<value_float>(b)))) {
try {
return a->as_string().str() == b->as_string().str();
} catch (...) {}
}
// compare boolean simple
if (is_val<value_bool>(a) && is_val<value_bool>(b)) {
return a->as_bool() == b->as_bool();
}
// compare string simple
if (is_val<value_string>(a) && is_val<value_string>(b)) {
return a->as_string().str() == b->as_string().str();
}
// compare by type
if (a->type() != b->type()) {
return false;
}
return false;
};
auto result = cmp();
JJ_DEBUG("Comparing types: %s and %s result=%d", a->type().c_str(), b->type().c_str(), result);
return result;
}
value keyword_argument_expression::execute_impl(context & ctx) {
if (!is_stmt<identifier>(key)) {
throw std::runtime_error("Keyword argument key must be identifiers");
}
std::string k = cast_stmt<identifier>(key)->val;
JJ_DEBUG("Keyword argument expression key: %s, value: %s", k.c_str(), val->type().c_str());
value v = val->execute(ctx);
JJ_DEBUG("Keyword argument value executed, type: %s", v->type().c_str());
return mk_val<value_kwarg>(k, v);
}
} // namespace jinja

604
common/jinja/jinja-vm.h Normal file
View File

@ -0,0 +1,604 @@
#pragma once
#include "jinja-lexer.h"
#include "jinja-value.h"
#include <string>
#include <vector>
#include <cassert>
#include <memory>
#include <sstream>
#define JJ_DEBUG(msg, ...) if (g_jinja_debug) printf("%s:%-3d : " msg "\n", FILENAME, __LINE__, __VA_ARGS__)
extern bool g_jinja_debug;
namespace jinja {
struct statement;
using statement_ptr = std::unique_ptr<statement>;
using statements = std::vector<statement_ptr>;
// Helpers for dynamic casting and type checking
template<typename T>
struct extract_pointee_unique {
using type = T;
};
template<typename U>
struct extract_pointee_unique<std::unique_ptr<U>> {
using type = U;
};
template<typename T>
bool is_stmt(const statement_ptr & ptr) {
return dynamic_cast<const T*>(ptr.get()) != nullptr;
}
template<typename T>
T * cast_stmt(statement_ptr & ptr) {
return dynamic_cast<T*>(ptr.get());
}
template<typename T>
const T * cast_stmt(const statement_ptr & ptr) {
return dynamic_cast<const T*>(ptr.get());
}
// End Helpers
// not thread-safe
void enable_debug(bool enable);
struct context {
std::string source; // for debugging
std::time_t current_time; // for functions that need current time
bool is_get_stats = false; // whether to collect stats
context() {
global = mk_val<value_object>();
global->insert("true", mk_val<value_bool>(true));
global->insert("false", mk_val<value_bool>(false));
global->insert("none", mk_val<value_null>());
current_time = std::time(nullptr);
}
~context() = default;
context(const context & parent) : context() {
// inherit variables (for example, when entering a new scope)
auto & pvar = parent.global->as_object();
for (const auto & pair : pvar) {
set_val(pair.first, pair.second);
}
current_time = parent.current_time;
is_get_stats = parent.is_get_stats;
}
value get_val(const std::string & name) {
auto it = global->val_obj.find(name);
if (it != global->val_obj.end()) {
return it->second;
} else {
return mk_val<value_undefined>(name);
}
}
void set_val(const std::string & name, const value & val) {
global->insert(name, val);
}
private:
value_object global;
};
/**
* Base class for all nodes in the AST.
*/
struct statement {
size_t pos; // position in source, for debugging
virtual ~statement() = default;
virtual std::string type() const { return "Statement"; }
// execute_impl must be overridden by derived classes
virtual value execute_impl(context &) { throw std::runtime_error("cannot exec " + type()); }
// execute is the public method to execute a statement with error handling
value execute(context &);
};
// Type Checking Utilities
template<typename T>
static void chk_type(const statement_ptr & ptr) {
if (!ptr) return; // Allow null for optional fields
assert(dynamic_cast<T *>(ptr.get()) != nullptr);
}
template<typename T, typename U>
static void chk_type(const statement_ptr & ptr) {
if (!ptr) return;
assert(dynamic_cast<T *>(ptr.get()) != nullptr || dynamic_cast<U *>(ptr.get()) != nullptr);
}
// Base Types
/**
* Expressions will result in a value at runtime (unlike statements).
*/
struct expression : public statement {
std::string type() const override { return "Expression"; }
};
// Statements
struct program : public statement {
statements body;
program() = default;
explicit program(statements && body) : body(std::move(body)) {}
std::string type() const override { return "Program"; }
value execute_impl(context &) override {
throw std::runtime_error("Cannot execute program directly, use jinja::vm instead");
}
};
struct if_statement : public statement {
statement_ptr test;
statements body;
statements alternate;
if_statement(statement_ptr && test, statements && body, statements && alternate)
: test(std::move(test)), body(std::move(body)), alternate(std::move(alternate)) {
chk_type<expression>(this->test);
}
std::string type() const override { return "If"; }
value execute_impl(context & ctx) override;
};
struct identifier;
struct tuple_literal;
/**
* Loop over each item in a sequence
* https://jinja.palletsprojects.com/en/3.0.x/templates/#for
*/
struct for_statement : public statement {
statement_ptr loopvar; // Identifier | TupleLiteral
statement_ptr iterable;
statements body;
statements default_block; // if no iteration took place
for_statement(statement_ptr && loopvar, statement_ptr && iterable, statements && body, statements && default_block)
: loopvar(std::move(loopvar)), iterable(std::move(iterable)),
body(std::move(body)), default_block(std::move(default_block)) {
chk_type<identifier, tuple_literal>(this->loopvar);
chk_type<expression>(this->iterable);
}
std::string type() const override { return "For"; }
value execute_impl(context & ctx) override;
};
struct break_statement : public statement {
std::string type() const override { return "Break"; }
struct signal : public std::exception {
const char* what() const noexcept override {
return "Break statement executed";
}
};
value execute_impl(context &) override {
throw break_statement::signal();
}
};
struct continue_statement : public statement {
std::string type() const override { return "Continue"; }
struct signal : public std::exception {
const char* what() const noexcept override {
return "Continue statement executed";
}
};
value execute_impl(context &) override {
throw continue_statement::signal();
}
};
// do nothing
struct noop_statement : public statement {
std::string type() const override { return "Noop"; }
value execute_impl(context &) override {
return mk_val<value_null>();
}
};
struct set_statement : public statement {
statement_ptr assignee;
statement_ptr val;
statements body;
set_statement(statement_ptr && assignee, statement_ptr && value, statements && body)
: assignee(std::move(assignee)), val(std::move(value)), body(std::move(body)) {
chk_type<expression>(this->assignee);
chk_type<expression>(this->val);
}
std::string type() const override { return "Set"; }
value execute_impl(context & ctx) override;
};
struct macro_statement : public statement {
statement_ptr name;
statements args;
statements body;
macro_statement(statement_ptr && name, statements && args, statements && body)
: name(std::move(name)), args(std::move(args)), body(std::move(body)) {
chk_type<identifier>(this->name);
for (const auto& arg : this->args) chk_type<expression>(arg);
}
std::string type() const override { return "Macro"; }
value execute_impl(context & ctx) override;
};
struct comment_statement : public statement {
std::string val;
explicit comment_statement(const std::string & v) : val(v) {}
std::string type() const override { return "Comment"; }
value execute_impl(context &) override {
return mk_val<value_null>();
}
};
// Expressions
struct member_expression : public expression {
statement_ptr object;
statement_ptr property;
bool computed;
member_expression(statement_ptr && object, statement_ptr && property, bool computed)
: object(std::move(object)), property(std::move(property)), computed(computed) {
chk_type<expression>(this->object);
chk_type<expression>(this->property);
}
std::string type() const override { return "MemberExpression"; }
value execute_impl(context & ctx) override;
};
struct call_expression : public expression {
statement_ptr callee;
statements args;
call_expression(statement_ptr && callee, statements && args)
: callee(std::move(callee)), args(std::move(args)) {
chk_type<expression>(this->callee);
for (const auto& arg : this->args) chk_type<expression>(arg);
}
std::string type() const override { return "CallExpression"; }
value execute_impl(context & ctx) override;
};
/**
* Represents a user-defined variable or symbol in the template.
*/
struct identifier : public expression {
std::string val;
explicit identifier(const std::string & val) : val(val) {}
std::string type() const override { return "Identifier"; }
value execute_impl(context & ctx) override;
};
// Literals
struct integer_literal : public expression {
int64_t val;
explicit integer_literal(int64_t val) : val(val) {}
std::string type() const override { return "IntegerLiteral"; }
value execute_impl(context &) override {
return mk_val<value_int>(val);
}
};
struct float_literal : public expression {
double val;
explicit float_literal(double val) : val(val) {}
std::string type() const override { return "FloatLiteral"; }
value execute_impl(context &) override {
return mk_val<value_float>(val);
}
};
struct string_literal : public expression {
std::string val;
explicit string_literal(const std::string & val) : val(val) {}
std::string type() const override { return "StringLiteral"; }
value execute_impl(context &) override {
return mk_val<value_string>(val);
}
};
struct array_literal : public expression {
statements val;
explicit array_literal(statements && val) : val(std::move(val)) {
for (const auto& item : this->val) chk_type<expression>(item);
}
std::string type() const override { return "ArrayLiteral"; }
value execute_impl(context & ctx) override {
auto arr = mk_val<value_array>();
for (const auto & item_stmt : val) {
arr->push_back(item_stmt->execute(ctx));
}
return arr;
}
};
struct tuple_literal : public array_literal {
explicit tuple_literal(statements && val) : array_literal(std::move(val)) {}
std::string type() const override { return "TupleLiteral"; }
};
struct object_literal : public expression {
std::vector<std::pair<statement_ptr, statement_ptr>> val;
explicit object_literal(std::vector<std::pair<statement_ptr, statement_ptr>> && val)
: val(std::move(val)) {
for (const auto & pair : this->val) {
chk_type<expression>(pair.first);
chk_type<expression>(pair.second);
}
}
std::string type() const override { return "ObjectLiteral"; }
value execute_impl(context & ctx) override;
};
// Complex Expressions
/**
* An operation with two sides, separated by an operator.
* Note: Either side can be a Complex Expression, with order
* of operations being determined by the operator.
*/
struct binary_expression : public expression {
token op;
statement_ptr left;
statement_ptr right;
binary_expression(token op, statement_ptr && left, statement_ptr && right)
: op(op), left(std::move(left)), right(std::move(right)) {
chk_type<expression>(this->left);
chk_type<expression>(this->right);
}
std::string type() const override { return "BinaryExpression"; }
value execute_impl(context & ctx) override;
};
/**
* An operation with two sides, separated by the | operator.
* Operator precedence: https://github.com/pallets/jinja/issues/379#issuecomment-168076202
*/
struct filter_expression : public expression {
// either an expression or a value is allowed
statement_ptr operand;
value_string val; // will be set by filter_statement
statement_ptr filter;
filter_expression(statement_ptr && operand, statement_ptr && filter)
: operand(std::move(operand)), filter(std::move(filter)) {
chk_type<expression>(this->operand);
chk_type<identifier, call_expression>(this->filter);
}
filter_expression(value_string && val, statement_ptr && filter)
: val(std::move(val)), filter(std::move(filter)) {
chk_type<identifier, call_expression>(this->filter);
}
std::string type() const override { return "FilterExpression"; }
value execute_impl(context & ctx) override;
};
struct filter_statement : public statement {
statement_ptr filter;
statements body;
filter_statement(statement_ptr && filter, statements && body)
: filter(std::move(filter)), body(std::move(body)) {
chk_type<identifier, call_expression>(this->filter);
}
std::string type() const override { return "FilterStatement"; }
value execute_impl(context & ctx) override;
};
/**
* An operation which filters a sequence of objects by applying a test to each object,
* and only selecting the objects with the test succeeding.
*
* It may also be used as a shortcut for a ternary operator.
*/
struct select_expression : public expression {
statement_ptr lhs;
statement_ptr test;
select_expression(statement_ptr && lhs, statement_ptr && test)
: lhs(std::move(lhs)), test(std::move(test)) {
chk_type<expression>(this->lhs);
chk_type<expression>(this->test);
}
std::string type() const override { return "SelectExpression"; }
value execute_impl(context & ctx) override {
auto predicate = test->execute_impl(ctx);
if (!predicate->as_bool()) {
return mk_val<value_undefined>();
}
return lhs->execute_impl(ctx);
}
};
/**
* An operation with two sides, separated by the "is" operator.
* NOTE: "value is something" translates to function call "test_is_something(value)"
*/
struct test_expression : public expression {
statement_ptr operand;
bool negate;
statement_ptr test;
test_expression(statement_ptr && operand, bool negate, statement_ptr && test)
: operand(std::move(operand)), negate(negate), test(std::move(test)) {
chk_type<expression>(this->operand);
chk_type<identifier>(this->test);
}
std::string type() const override { return "TestExpression"; }
value execute_impl(context & ctx) override;
};
/**
* An operation with one side (operator on the left).
*/
struct unary_expression : public expression {
token op;
statement_ptr argument;
unary_expression(token op, statement_ptr && argument)
: op(std::move(op)), argument(std::move(argument)) {
chk_type<expression>(this->argument);
}
std::string type() const override { return "UnaryExpression"; }
value execute_impl(context & ctx) override;
};
struct slice_expression : public expression {
statement_ptr start_expr;
statement_ptr stop_expr;
statement_ptr step_expr;
slice_expression(statement_ptr && start_expr, statement_ptr && stop_expr, statement_ptr && step_expr)
: start_expr(std::move(start_expr)), stop_expr(std::move(stop_expr)), step_expr(std::move(step_expr)) {
chk_type<expression>(this->start_expr);
chk_type<expression>(this->stop_expr);
chk_type<expression>(this->step_expr);
}
std::string type() const override { return "SliceExpression"; }
value execute_impl(context &) override {
throw std::runtime_error("must be handled by MemberExpression");
}
};
struct keyword_argument_expression : public expression {
statement_ptr key;
statement_ptr val;
keyword_argument_expression(statement_ptr && key, statement_ptr && val)
: key(std::move(key)), val(std::move(val)) {
chk_type<identifier>(this->key);
chk_type<expression>(this->val);
}
std::string type() const override { return "KeywordArgumentExpression"; }
value execute_impl(context & ctx) override;
};
struct spread_expression : public expression {
statement_ptr argument;
explicit spread_expression(statement_ptr && argument) : argument(std::move(argument)) {
chk_type<expression>(this->argument);
}
std::string type() const override { return "SpreadExpression"; }
};
struct call_statement : public statement {
statement_ptr call;
statements caller_args;
statements body;
call_statement(statement_ptr && call, statements && caller_args, statements && body)
: call(std::move(call)), caller_args(std::move(caller_args)), body(std::move(body)) {
chk_type<call_expression>(this->call);
for (const auto& arg : this->caller_args) chk_type<expression>(arg);
}
std::string type() const override { return "CallStatement"; }
};
struct ternary_expression : public expression {
statement_ptr condition;
statement_ptr true_expr;
statement_ptr false_expr;
ternary_expression(statement_ptr && condition, statement_ptr && true_expr, statement_ptr && false_expr)
: condition(std::move(condition)), true_expr(std::move(true_expr)), false_expr(std::move(false_expr)) {
chk_type<expression>(this->condition);
chk_type<expression>(this->true_expr);
chk_type<expression>(this->false_expr);
}
std::string type() const override { return "Ternary"; }
value execute_impl(context & ctx) override {
value cond_val = condition->execute(ctx);
if (cond_val->as_bool()) {
return true_expr->execute(ctx);
} else {
return false_expr->execute(ctx);
}
}
};
struct raised_exception : public std::exception {
std::string message;
raised_exception(const std::string & msg) : message(msg) {}
const char* what() const noexcept override {
return message.c_str();
}
};
//////////////////////
static void gather_string_parts_recursive(const value & val, value_string & parts) {
if (is_val<value_string>(val)) {
const auto & str_val = cast_val<value_string>(val)->val_str;
parts->val_str.append(str_val);
} else if (is_val<value_array>(val)) {
auto items = cast_val<value_array>(val)->as_array();
for (const auto & item : items) {
gather_string_parts_recursive(item, parts);
}
}
}
static std::string render_string_parts(const value_string & parts) {
std::ostringstream oss;
for (const auto & part : parts->val_str.parts) {
oss << part.val;
}
return oss.str();
}
struct vm {
context & ctx;
explicit vm(context & ctx) : ctx(ctx) {}
value_array execute(const program & prog) {
value_array results = mk_val<value_array>();
for (auto & stmt : prog.body) {
value res = stmt->execute(ctx);
results->push_back(std::move(res));
}
return results;
}
value_string gather_string_parts(const value & val) {
value_string parts = mk_val<value_string>();
gather_string_parts_recursive(val, parts);
// join consecutive parts with the same type
auto & p = parts->val_str.parts;
for (size_t i = 1; i < p.size(); ) {
if (p[i].is_input == p[i - 1].is_input) {
p[i - 1].val += p[i].val;
p.erase(p.begin() + i);
} else {
i++;
}
}
return parts;
}
};
} // namespace jinja

View File

@ -186,6 +186,7 @@ endif()
llama_build_and_test(test-chat-parser.cpp)
llama_build_and_test(test-chat-peg-parser.cpp peg-parser/simple-tokenize.cpp)
llama_build_and_test(test-chat-template.cpp)
llama_build_and_test(test-chat-jinja.cpp)
llama_build_and_test(test-json-partial.cpp)
llama_build_and_test(test-log.cpp)
llama_build_and_test(

206
tests/test-chat-jinja.cpp Normal file
View File

@ -0,0 +1,206 @@
#include <string>
#include <vector>
#include <sstream>
#include <regex>
#include <iostream>
#include <fstream>
#include <filesystem>
#include <nlohmann/json.hpp>
#undef NDEBUG
#include <cassert>
#include "jinja/jinja-parser.h"
#include "jinja/jinja-lexer.h"
#include "jinja/jinja-caps.h"
using json = nlohmann::json;
void run_multiple(std::string dir_path, bool stop_on_first_failure, json input);
void run_single(std::string contents, json input, const std::string & output_path = "");
std::string HELP = R"(
Usage: test-chat-jinja [OPTIONS] PATH_TO_TEMPLATE
Options:
-h, --help Show this help message and exit.
--json <path> Path to the JSON input file.
--stop-on-first-fail Stop testing on the first failure (default: false).
--output <path> Path to output results (only for single template runs).
If PATH_TO_TEMPLATE is a file, runs that single template.
If PATH_TO_TEMPLATE is a directory, runs all .jinja files in that directory.
)";
std::string DEFAULT_JSON = R"({
"messages": [
{
"role": "user",
"content": {"__input__": "Hello, how are you?"}
},
{
"role": "assistant",
"content": {"__input__": "I am fine, thank you!"},
"tool_calls": [
{
"function": {
"name": "get_weather",
"arguments": {
"location": "New York",
"unit": "celsius"
}
}
}
]
}
],
"bos_token": "<s>",
"eos_token": "</s>",
"tools": [],
"add_generation_prompt": true
})";
int main(int argc, char ** argv) {
std::vector<std::string> args(argv, argv + argc);
std::string tmpl_path;
std::string json_path;
std::string output_path;
bool stop_on_first_fail = false;
for (size_t i = 1; i < args.size(); i++) {
if (args[i] == "--help" || args[i] == "-h") {
std::cout << HELP << "\n";
return 0;
} else if (args[i] == "--json" && i + 1 < args.size()) {
json_path = args[i + 1];
i++;
} else if (args[i] == "--stop-on-first-fail") {
stop_on_first_fail = true;
} else if (args[i] == "--output" && i + 1 < args.size()) {
output_path = args[i + 1];
i++;
} else if (tmpl_path.empty()) {
tmpl_path = args[i];
} else {
std::cerr << "Unknown argument: " << args[i] << "\n";
std::cout << HELP << "\n";
return 1;
}
}
if (tmpl_path.empty()) {
std::cerr << "Error: PATH_TO_TEMPLATE is required.\n";
std::cout << HELP << "\n";
return 1;
}
json input_json;
if (!json_path.empty()) {
std::ifstream json_file(json_path);
if (!json_file) {
std::cerr << "Error: Could not open JSON file: " << json_path << "\n";
return 1;
}
std::string content = std::string(
std::istreambuf_iterator<char>(json_file),
std::istreambuf_iterator<char>());
input_json = json::parse(content);
} else {
input_json = json::parse(DEFAULT_JSON);
}
std::filesystem::path p(tmpl_path);
if (std::filesystem::is_directory(p)) {
run_multiple(tmpl_path, stop_on_first_fail, input_json);
} else if (std::filesystem::is_regular_file(p)) {
std::ifstream infile(tmpl_path);
std::string contents = std::string(
std::istreambuf_iterator<char>(infile),
std::istreambuf_iterator<char>());
run_single(contents, input_json, output_path);
} else {
std::cerr << "Error: PATH_TO_TEMPLATE is not a valid file or directory: " << tmpl_path << "\n";
return 1;
}
return 0;
}
void run_multiple(std::string dir_path, bool stop_on_first_fail, json input) {
std::vector<std::string> failed_tests;
// list all files in models/templates/ and run each
size_t test_count = 0;
for (const auto & entry : std::filesystem::directory_iterator(dir_path)) {
// only process .jinja files
if (entry.path().extension() == ".jinja" && entry.is_regular_file()) {
test_count++;
std::cout << "\n\n=== RUNNING TEMPLATE FILE: " << entry.path().string() << " ===\n";
std::ifstream infile(entry.path());
std::string contents((std::istreambuf_iterator<char>(infile)), std::istreambuf_iterator<char>());
try {
run_single(contents, input);
} catch (const std::exception & e) {
std::cout << "Exception: " << e.what() << "\n";
std::cout << "=== ERROR WITH TEMPLATE FILE: " << entry.path().string() << " ===\n";
failed_tests.push_back(entry.path().string());
if (stop_on_first_fail) {
break;
}
}
}
}
std::cout << "\n\n=== TEST SUMMARY ===\n";
std::cout << "Total tests run: " << test_count << "\n";
std::cout << "Total failed tests: " << failed_tests.size() << "\n";
for (const auto & test : failed_tests) {
std::cout << "FAILED TEST: " << test << "\n";
}
}
void run_single(std::string contents, json input, const std::string & output_path) {
jinja::enable_debug(true);
// lexing
jinja::lexer lexer;
jinja::preprocess_options options;
options.trim_blocks = false;
options.lstrip_blocks = false;
auto lexer_res = lexer.tokenize(contents, options);
// compile to AST
jinja::program ast = jinja::parse_from_tokens(lexer_res);
// check caps for workarounds
auto caps = jinja::caps_get(ast);
std::cout << "\n=== RUN ===\n";
jinja::context ctx;
ctx.source = lexer_res.preprocessed_source;
jinja::global_from_json(ctx, input);
jinja::caps_apply_workarounds(ctx, caps);
jinja::vm vm(ctx);
const jinja::value results = vm.execute(ast);
auto parts = vm.gather_string_parts(results);
std::cout << "\n=== RESULTS ===\n";
for (const auto & part : parts->as_string().parts) {
std::cout << (part.is_input ? "DATA" : "TMPL") << ": " << part.val << "\n";
}
if (!output_path.empty()) {
std::ofstream outfile(output_path);
if (!outfile) {
throw std::runtime_error("Could not open output file: " + output_path);
}
for (const auto & part : parts->as_string().parts) {
outfile << part.val;
}
std::cout << "\n=== OUTPUT WRITTEN TO " << output_path << " ===\n";
}
}