jinja : implement mixed type object keys (#18955)

* implement mixed type object keys

* add tests

* refactor

* minor fixes

* massive refactor

* add more tests

* forgotten tuples

* fix array/object is_hashable

* correct (albeit broken) jinja responses

verified with transformers

* improved hashing and equality

* refactor hash function

* more exhausive test case

* clean up

* cont

* cont (2)

* missing cstring

---------

Co-authored-by: Xuan Son Nguyen <son@huggingface.co>
This commit is contained in:
Sigbjørn Skjæret 2026-01-27 19:50:42 +01:00 committed by GitHub
parent 68ac3acb43
commit 2b4cbd2834
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 802 additions and 168 deletions

View File

@ -44,6 +44,12 @@ static std::string get_line_col(const std::string & source, size_t pos) {
return "line " + std::to_string(line) + ", column " + std::to_string(col);
}
static void ensure_key_type_allowed(const value & val) {
if (!val->is_hashable()) {
throw std::runtime_error("Type: " + val->type() + " is not allowed as object key");
}
}
// execute with error handling
value statement::execute(context & ctx) {
try {
@ -95,20 +101,10 @@ value identifier::execute_impl(context & ctx) {
value object_literal::execute_impl(context & ctx) {
auto obj = mk_val<value_object>();
for (const auto & pair : val) {
value key_val = pair.first->execute(ctx);
if (!is_val<value_string>(key_val) && !is_val<value_int>(key_val)) {
throw std::runtime_error("Object literal: keys must be string or int values, got " + key_val->type());
}
std::string key = key_val->as_string().str();
value key = pair.first->execute(ctx);
value val = pair.second->execute(ctx);
JJ_DEBUG("Object literal: setting key '%s' with value type %s", key.c_str(), val->type().c_str());
JJ_DEBUG("Object literal: setting key '%s' with value type %s", key->as_string().str().c_str(), val->type().c_str());
obj->insert(key, val);
if (is_val<value_int>(key_val)) {
obj->val_obj.is_key_numeric = true;
} else if (obj->val_obj.is_key_numeric) {
throw std::runtime_error("Object literal: cannot mix numeric and non-numeric keys");
}
}
return obj;
}
@ -127,9 +123,9 @@ value binary_expression::execute_impl(context & ctx) {
value right_val = right->execute(ctx);
JJ_DEBUG("Executing binary expression %s '%s' %s", left_val->type().c_str(), op.value.c_str(), right_val->type().c_str());
if (op.value == "==") {
return mk_val<value_bool>(value_compare(left_val, right_val, value_compare_op::eq));
return mk_val<value_bool>(*left_val == *right_val);
} else if (op.value == "!=") {
return mk_val<value_bool>(!value_compare(left_val, right_val, value_compare_op::eq));
return mk_val<value_bool>(!(*left_val == *right_val));
}
auto workaround_concat_null_with_str = [&](value & res) -> bool {
@ -230,7 +226,7 @@ value binary_expression::execute_impl(context & ctx) {
auto & arr = right_val->as_array();
bool member = false;
for (const auto & item : arr) {
if (value_compare(left_val, item, value_compare_op::eq)) {
if (*left_val == *item) {
member = true;
break;
}
@ -265,10 +261,9 @@ value binary_expression::execute_impl(context & ctx) {
}
}
// String in object
if (is_val<value_string>(left_val) && is_val<value_object>(right_val)) {
auto key = left_val->as_string().str();
bool has_key = right_val->has_key(key);
// Value key in object
if (is_val<value_object>(right_val)) {
bool has_key = right_val->has_key(left_val);
if (op.value == "in") {
return mk_val<value_bool>(has_key);
} else if (op.value == "not in") {
@ -465,14 +460,8 @@ value for_statement::execute_impl(context & ctx) {
JJ_DEBUG("%s", "For loop over object keys");
auto & obj = iterable_val->as_ordered_object();
for (auto & p : obj) {
auto tuple = mk_val<value_array>();
if (iterable_val->val_obj.is_key_numeric) {
tuple->push_back(mk_val<value_int>(std::stoll(p.first)));
} else {
tuple->push_back(mk_val<value_string>(p.first));
}
tuple->push_back(p.second);
items.push_back(tuple);
auto tuple = mk_val<value_tuple>(p);
items.push_back(std::move(tuple));
}
if (ctx.is_get_stats) {
iterable_val->stats.used = true;
@ -602,11 +591,13 @@ value set_statement::execute_impl(context & ctx) {
auto rhs = val ? val->execute(ctx) : exec_statements(body, ctx);
if (is_stmt<identifier>(assignee)) {
// case: {% set my_var = value %}
auto var_name = cast_stmt<identifier>(assignee)->val;
JJ_DEBUG("Setting global variable '%s' with value type %s", var_name.c_str(), rhs->type().c_str());
ctx.set_val(var_name, rhs);
} else if (is_stmt<tuple_literal>(assignee)) {
// case: {% set a, b = value %}
auto tuple = cast_stmt<tuple_literal>(assignee);
if (!is_val<value_array>(rhs)) {
throw std::runtime_error("Cannot unpack non-iterable type in set: " + rhs->type());
@ -625,6 +616,7 @@ value set_statement::execute_impl(context & ctx) {
}
} else if (is_stmt<member_expression>(assignee)) {
// case: {% set ns.my_var = value %}
auto member = cast_stmt<member_expression>(assignee);
if (member->computed) {
throw std::runtime_error("Cannot assign to computed member");
@ -767,22 +759,22 @@ value member_expression::execute_impl(context & ctx) {
}
JJ_DEBUG("Member expression on object type %s, property type %s", object->type().c_str(), property->type().c_str());
ensure_key_type_allowed(property);
value val = mk_val<value_undefined>("object_property");
if (is_val<value_undefined>(object)) {
JJ_DEBUG("%s", "Accessing property on undefined object, returning undefined");
return val;
} else if (is_val<value_object>(object)) {
if (!is_val<value_string>(property)) {
throw std::runtime_error("Cannot access object with non-string: got " + property->type());
}
auto key = property->as_string().str();
val = object->at(key, val);
val = object->at(property, val);
if (is_val<value_undefined>(val)) {
val = try_builtin_func(ctx, key, object, true);
}
JJ_DEBUG("Accessed property '%s' value, got type: %s", key.c_str(), val->type().c_str());
} else if (is_val<value_array>(object) || is_val<value_string>(object)) {
if (is_val<value_int>(property)) {
int64_t index = property->as_int();
@ -806,6 +798,7 @@ value member_expression::execute_impl(context & ctx) {
auto key = property->as_string().str();
JJ_DEBUG("Accessing %s built-in '%s'", is_val<value_array>(object) ? "array" : "string", key.c_str());
val = try_builtin_func(ctx, key, object, true);
} else {
throw std::runtime_error("Cannot access property with non-string/non-number: got " + property->type());
}

View File

@ -79,18 +79,18 @@ struct context {
}
value get_val(const std::string & name) {
auto it = env->val_obj.unordered.find(name);
if (it != env->val_obj.unordered.end()) {
return it->second;
} else {
return mk_val<value_undefined>(name);
}
value default_val = mk_val<value_undefined>(name);
return env->at(name, default_val);
}
void set_val(const std::string & name, const value & val) {
env->insert(name, val);
}
void set_val(const value & name, const value & val) {
env->insert(name, val);
}
void print_vars() const {
printf("Context Variables:\n%s\n", value_to_json(env, 2).c_str());
}
@ -344,9 +344,19 @@ struct array_literal : public expression {
}
};
struct tuple_literal : public array_literal {
explicit tuple_literal(statements && val) : array_literal(std::move(val)) {}
struct tuple_literal : public expression {
statements val;
explicit tuple_literal(statements && val) : val(std::move(val)) {
for (const auto& item : this->val) chk_type<expression>(item);
}
std::string type() const override { return "TupleLiteral"; }
value execute_impl(context & ctx) override {
auto arr = mk_val<value_array>();
for (const auto & item_stmt : val) {
arr->push_back(item_stmt->execute(ctx));
}
return mk_val<value_tuple>(std::move(arr->as_array()));
}
};
struct object_literal : public expression {

View File

@ -61,6 +61,12 @@ size_t string::length() const {
return len;
}
void string::hash_update(hasher & hash) const noexcept {
for (const auto & part : parts) {
hash.update(part.val.data(), part.val.length());
}
}
bool string::all_parts_are_input() const {
for (const auto & part : parts) {
if (!part.is_input) {

View File

@ -4,6 +4,8 @@
#include <string>
#include <vector>
#include "utils.h"
namespace jinja {
// allow differentiate between user input strings and template strings
@ -37,6 +39,7 @@ struct string {
std::string str() const;
size_t length() const;
void hash_update(hasher & hash) const noexcept;
bool all_parts_are_input() const;
bool is_uppercase() const;
bool is_lowercase() const;

View File

@ -3,6 +3,8 @@
#include <string>
#include <sstream>
#include <algorithm>
#include <cstdint>
#include <cstring>
namespace jinja {
@ -46,4 +48,102 @@ static std::string fmt_error_with_source(const std::string & tag, const std::str
return oss.str();
}
// Note: this is a simple hasher, not cryptographically secure, just for hash table usage
struct hasher {
static constexpr auto size_t_digits = sizeof(size_t) * 8;
static constexpr size_t prime = size_t_digits == 64 ? 0x100000001b3 : 0x01000193;
static constexpr size_t seed = size_t_digits == 64 ? 0xcbf29ce484222325 : 0x811c9dc5;
static constexpr auto block_size = sizeof(size_t); // in bytes; allowing the compiler to vectorize the computation
static_assert(size_t_digits == 64 || size_t_digits == 32);
static_assert(block_size == 8 || block_size == 4);
uint8_t buffer[block_size];
size_t idx = 0; // current index in buffer
size_t state = seed;
hasher() = default;
hasher(const std::type_info & type_inf) noexcept {
const auto type_hash = type_inf.hash_code();
update(&type_hash, sizeof(type_hash));
}
// Properties:
// - update is not associative: update(a).update(b) != update(b).update(a)
// - update(a ~ b) == update(a).update(b) with ~ as concatenation operator --> useful for streaming
// - update("", 0) --> state unchanged with empty input
hasher& update(void const * bytes, size_t len) noexcept {
const uint8_t * c = static_cast<uint8_t const *>(bytes);
if (len == 0) {
return *this;
}
size_t processed = 0;
// first, fill the existing buffer if it's partial
if (idx > 0) {
size_t to_fill = block_size - idx;
if (to_fill > len) {
to_fill = len;
}
std::memcpy(buffer + idx, c, to_fill);
idx += to_fill;
processed += to_fill;
if (idx == block_size) {
update_block(buffer);
idx = 0;
}
}
// process full blocks from the remaining input
for (; processed + block_size <= len; processed += block_size) {
update_block(c + processed);
}
// buffer any remaining bytes
size_t remaining = len - processed;
if (remaining > 0) {
std::memcpy(buffer, c + processed, remaining);
idx = remaining;
}
return *this;
}
// convenience function for testing only
hasher& update(const std::string & s) noexcept {
return update(s.data(), s.size());
}
// finalize and get the hash value
// note: after calling digest, the hasher state is modified, do not call update() again
size_t digest() noexcept {
// if there are remaining bytes in buffer, fill the rest with zeros and process
if (idx > 0) {
for (size_t i = idx; i < block_size; ++i) {
buffer[i] = 0;
}
update_block(buffer);
idx = 0;
}
return state;
}
private:
// IMPORTANT: block must have at least block_size bytes
void update_block(const uint8_t * block) noexcept {
size_t blk = static_cast<uint32_t>(block[0])
| (static_cast<uint32_t>(block[1]) << 8)
| (static_cast<uint32_t>(block[2]) << 16)
| (static_cast<uint32_t>(block[3]) << 24);
if constexpr (block_size == 8) {
blk = blk | (static_cast<uint64_t>(block[4]) << 32)
| (static_cast<uint64_t>(block[5]) << 40)
| (static_cast<uint64_t>(block[6]) << 48)
| (static_cast<uint64_t>(block[7]) << 56);
}
state ^= blk;
state *= prime;
}
};
} // namespace jinja

View File

@ -163,7 +163,7 @@ static value selectattr(const func_args & args) {
args.ensure_vals<value_array, value_string, value_string, value_string>(true, true, false, false);
auto arr = args.get_pos(0)->as_array();
auto attr_name = args.get_pos(1)->as_string().str();
auto attribute = args.get_pos(1);
auto out = mk_val<value_array>();
value val_default = mk_val<value_undefined>();
@ -173,7 +173,7 @@ static value selectattr(const func_args & args) {
if (!is_val<value_object>(item)) {
throw raised_exception("selectattr: item is not an object");
}
value attr_val = item->at(attr_name, val_default);
value attr_val = item->at(attribute, val_default);
bool is_selected = attr_val->as_bool();
if constexpr (is_reject) is_selected = !is_selected;
if (is_selected) out->push_back(item);
@ -217,7 +217,7 @@ static value selectattr(const func_args & args) {
if (!is_val<value_object>(item)) {
throw raised_exception("selectattr: item is not an object");
}
value attr_val = item->at(attr_name, val_default);
value attr_val = item->at(attribute, val_default);
func_args test_args(args.ctx);
test_args.push_back(attr_val); // attribute value
test_args.push_back(extra_arg); // extra argument
@ -741,6 +741,7 @@ const func_builtins & value_array_t::get_builtins() const {
args.ensure_count(1, 4);
args.ensure_vals<value_array, value_int, value_int, value_int>(true, true, false, false);
auto val = args.get_pos(0);
auto arg0 = args.get_pos(1);
auto arg1 = args.get_pos(2, mk_val<value_undefined>());
auto arg2 = args.get_pos(3, mk_val<value_undefined>());
@ -762,10 +763,8 @@ const func_builtins & value_array_t::get_builtins() const {
if (step == 0) {
throw raised_exception("slice step cannot be zero");
}
auto arr = slice(args.get_pos(0)->as_array(), start, stop, step);
auto res = mk_val<value_array>();
res->val_arr = std::move(arr);
return res;
auto arr = slice(val->as_array(), start, stop, step);
return is_val<value_tuple>(val) ? mk_val<value_tuple>(std::move(arr)) : mk_val<value_array>(std::move(arr));
}},
{"selectattr", selectattr<false>},
{"select", selectattr<false>},
@ -785,15 +784,14 @@ const func_builtins & value_array_t::get_builtins() const {
}
const int64_t attr_int = attr_is_int ? attribute->as_int() : 0;
const std::string delim = val_delim->is_undefined() ? "" : val_delim->as_string().str();
const std::string attr_name = attribute->is_undefined() ? "" : attribute->as_string().str();
std::string result;
for (size_t i = 0; i < arr.size(); ++i) {
value val_arr = arr[i];
if (!attribute->is_undefined()) {
if (attr_is_int && is_val<value_array>(val_arr)) {
val_arr = val_arr->at(attr_int);
} else if (!attr_is_int && !attr_name.empty() && is_val<value_object>(val_arr)) {
val_arr = val_arr->at(attr_name);
} else if (!attr_is_int && is_val<value_object>(val_arr)) {
val_arr = val_arr->at(attribute);
}
}
if (!is_val<value_string>(val_arr) && !is_val<value_int>(val_arr) && !is_val<value_float>(val_arr)) {
@ -808,9 +806,7 @@ const func_builtins & value_array_t::get_builtins() const {
}},
{"string", [](const func_args & args) -> value {
args.ensure_vals<value_array>();
auto str = mk_val<value_string>();
gather_string_parts_recursive(args.get_pos(0), str);
return str;
return mk_val<value_string>(args.get_pos(0)->as_string());
}},
{"tojson", tojson},
{"map", [](const func_args & args) -> value {
@ -821,26 +817,26 @@ const func_builtins & value_array_t::get_builtins() const {
if (!is_val<value_kwarg>(args.get_args().at(1))) {
throw not_implemented_exception("map: filter-mapping not implemented");
}
value val = args.get_pos(0);
value attribute = args.get_kwarg_or_pos("attribute", 1);
const bool attr_is_int = is_val<value_int>(attribute);
if (!is_val<value_string>(attribute) && !attr_is_int) {
throw raised_exception("map: attribute must be string or integer");
}
const int64_t attr_int = attr_is_int ? attribute->as_int() : 0;
const std::string attr_name = attribute->as_string().str();
value default_val = args.get_kwarg("default", mk_val<value_undefined>());
auto out = mk_val<value_array>();
auto arr = args.get_pos(0)->as_array();
auto arr = val->as_array();
for (const auto & item : arr) {
value attr_val;
if (attr_is_int) {
attr_val = is_val<value_array>(item) ? item->at(attr_int, default_val) : default_val;
} else {
attr_val = is_val<value_object>(item) ? item->at(attr_name, default_val) : default_val;
attr_val = is_val<value_object>(item) ? item->at(attribute, default_val) : default_val;
}
out->push_back(attr_val);
}
return out;
return is_val<value_tuple>(val) ? mk_val<value_tuple>(std::move(out->as_array())) : out;
}},
{"append", [](const func_args & args) -> value {
args.ensure_count(2);
@ -867,6 +863,7 @@ const func_builtins & value_array_t::get_builtins() const {
if (!is_val<value_array>(args.get_pos(0))) {
throw raised_exception("sort: first argument must be an array");
}
value val = args.get_pos(0);
value val_reverse = args.get_kwarg_or_pos("reverse", 1);
value val_case = args.get_kwarg_or_pos("case_sensitive", 2);
value attribute = args.get_kwarg_or_pos("attribute", 3);
@ -875,8 +872,7 @@ const func_builtins & value_array_t::get_builtins() const {
const bool reverse = val_reverse->as_bool(); // undefined == false
const bool attr_is_int = is_val<value_int>(attribute);
const int64_t attr_int = attr_is_int ? attribute->as_int() : 0;
const std::string attr_name = attribute->is_undefined() ? "" : attribute->as_string().str();
std::vector<value> arr = cast_val<value_array>(args.get_pos(0))->as_array(); // copy
std::vector<value> arr = val->as_array(); // copy
std::sort(arr.begin(), arr.end(),[&](const value & a, const value & b) {
value val_a = a;
value val_b = b;
@ -884,22 +880,23 @@ const func_builtins & value_array_t::get_builtins() const {
if (attr_is_int && is_val<value_array>(a) && is_val<value_array>(b)) {
val_a = a->at(attr_int);
val_b = b->at(attr_int);
} else if (!attr_is_int && !attr_name.empty() && is_val<value_object>(a) && is_val<value_object>(b)) {
val_a = a->at(attr_name);
val_b = b->at(attr_name);
} else if (!attr_is_int && is_val<value_object>(a) && is_val<value_object>(b)) {
val_a = a->at(attribute);
val_b = b->at(attribute);
} else {
throw raised_exception("sort: unsupported object attribute comparison");
throw raised_exception("sort: unsupported object attribute comparison between " + a->type() + " and " + b->type());
}
}
return value_compare(val_a, val_b, reverse ? value_compare_op::gt : value_compare_op::lt);
});
return mk_val<value_array>(arr);
return is_val<value_tuple>(val) ? mk_val<value_tuple>(std::move(arr)) : mk_val<value_array>(std::move(arr));
}},
{"reverse", [](const func_args & args) -> value {
args.ensure_vals<value_array>();
std::vector<value> arr = cast_val<value_array>(args.get_pos(0))->as_array(); // copy
value val = args.get_pos(0);
std::vector<value> arr = val->as_array(); // copy
std::reverse(arr.begin(), arr.end());
return mk_val<value_array>(arr);
return is_val<value_tuple>(val) ? mk_val<value_tuple>(std::move(arr)) : mk_val<value_array>(std::move(arr));
}},
{"unique", [](const func_args &) -> value {
throw not_implemented_exception("Array unique builtin not implemented");
@ -930,7 +927,7 @@ const func_builtins & value_object_t::get_builtins() const {
default_val = args.get_pos(2);
}
const value obj = args.get_pos(0);
std::string key = args.get_pos(1)->as_string().str();
const value key = args.get_pos(1);
return obj->at(key, default_val);
}},
{"keys", [](const func_args & args) -> value {
@ -938,7 +935,7 @@ const func_builtins & value_object_t::get_builtins() const {
const auto & obj = args.get_pos(0)->as_ordered_object();
auto result = mk_val<value_array>();
for (const auto & pair : obj) {
result->push_back(mk_val<value_string>(pair.first));
result->push_back(pair.first);
}
return result;
}},
@ -956,15 +953,16 @@ const func_builtins & value_object_t::get_builtins() const {
const auto & obj = args.get_pos(0)->as_ordered_object();
auto result = mk_val<value_array>();
for (const auto & pair : obj) {
auto item = mk_val<value_array>();
item->push_back(mk_val<value_string>(pair.first));
item->push_back(pair.second);
auto item = mk_val<value_tuple>(pair);
result->push_back(std::move(item));
}
return result;
}},
{"tojson", tojson},
{"string", tojson},
{"string", [](const func_args & args) -> value {
args.ensure_vals<value_object>();
return mk_val<value_string>(args.get_pos(0)->as_string());
}},
{"length", [](const func_args & args) -> value {
args.ensure_vals<value_object>();
const auto & obj = args.get_pos(0)->as_ordered_object();
@ -985,11 +983,11 @@ const func_builtins & value_object_t::get_builtins() const {
const bool reverse = val_reverse->as_bool(); // undefined == false
const bool by_value = is_val<value_string>(val_by) && val_by->as_string().str() == "value" ? true : false;
auto result = mk_val<value_object>(val_input); // copy
std::sort(result->val_obj.ordered.begin(), result->val_obj.ordered.end(), [&](const auto & a, const auto & b) {
std::sort(result->val_obj.begin(), result->val_obj.end(), [&](const auto & a, const auto & b) {
if (by_value) {
return value_compare(a.second, b.second, reverse ? value_compare_op::gt : value_compare_op::lt);
} else {
return reverse ? a.first > b.first : a.first < b.first;
return value_compare(a.first, b.first, reverse ? value_compare_op::gt : value_compare_op::lt);
}
});
return result;
@ -1134,6 +1132,8 @@ void global_from_json(context & ctx, const nlohmann::ordered_json & json_obj, bo
}
}
// recursively convert value to JSON string
// TODO: avoid circular references
static void value_to_json_internal(std::ostringstream & oss, const value & val, int curr_lvl, int indent, const std::string_view item_sep, const std::string_view key_sep) {
auto indent_str = [indent, curr_lvl]() -> std::string {
return (indent > 0) ? std::string(curr_lvl * indent, ' ') : "";
@ -1196,7 +1196,8 @@ static void value_to_json_internal(std::ostringstream & oss, const value & val,
size_t i = 0;
for (const auto & pair : obj) {
oss << indent_str() << (indent > 0 ? std::string(indent, ' ') : "");
oss << "\"" << pair.first << "\"" << key_sep;
value_to_json_internal(oss, mk_val<value_string>(pair.first->as_string().str()), curr_lvl + 1, indent, item_sep, key_sep);
oss << key_sep;
value_to_json_internal(oss, pair.second, curr_lvl + 1, indent, item_sep, key_sep);
if (i < obj.size() - 1) {
oss << item_sep;
@ -1219,4 +1220,19 @@ std::string value_to_json(const value & val, int indent, const std::string_view
return oss.str();
}
// TODO: avoid circular references
std::string value_to_string_repr(const value & val) {
if (is_val<value_string>(val)) {
const std::string val_str = val->as_string().str();
if (val_str.find('\'') != std::string::npos) {
return value_to_json(val);
} else {
return "'" + val_str + "'";
}
} else {
return val->as_repr();
}
}
} // namespace jinja

View File

@ -1,8 +1,10 @@
#pragma once
#include "string.h"
#include "utils.h"
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <functional>
#include <map>
@ -93,7 +95,8 @@ void global_from_json(context & ctx, const T_JSON & json_obj, bool mark_input);
struct func_args; // function argument values
using func_handler = std::function<value(const func_args &)>;
using func_hptr = value(const func_args &);
using func_handler = std::function<func_hptr>;
using func_builtins = std::map<std::string, func_handler>;
enum value_compare_op { eq, ge, gt, lt, ne };
@ -103,28 +106,9 @@ struct value_t {
int64_t val_int;
double val_flt;
string val_str;
bool val_bool;
std::vector<value> val_arr;
struct map {
// once set to true, all keys must be numeric
// caveat: we only allow either all numeric keys or all non-numeric keys
// for now, this only applied to for_statement in case of iterating over object keys/items
bool is_key_numeric = false;
std::map<std::string, value> unordered;
std::vector<std::pair<std::string, value>> ordered;
void insert(const std::string & key, const value & val) {
if (unordered.find(key) != unordered.end()) {
// if key exists, remove from ordered list
ordered.erase(std::remove_if(ordered.begin(), ordered.end(),
[&](const std::pair<std::string, value> & p) { return p.first == key; }),
ordered.end());
}
unordered[key] = val;
ordered.push_back({key, val});
}
} val_obj;
std::vector<std::pair<value, value>> val_obj;
func_handler val_func;
@ -139,6 +123,7 @@ struct value_t {
value_t(const value_t &) = default;
virtual ~value_t() = default;
// Note: only for debugging and error reporting purposes
virtual std::string type() const { return ""; }
virtual int64_t as_int() const { throw std::runtime_error(type() + " is not an int value"); }
@ -146,7 +131,7 @@ struct value_t {
virtual string as_string() const { throw std::runtime_error(type() + " is not a string value"); }
virtual bool as_bool() const { throw std::runtime_error(type() + " is not a bool value"); }
virtual const std::vector<value> & as_array() const { throw std::runtime_error(type() + " is not an array value"); }
virtual const std::vector<std::pair<std::string, value>> & as_ordered_object() const { throw std::runtime_error(type() + " is not an object value"); }
virtual const std::vector<std::pair<value, value>> & as_ordered_object() const { throw std::runtime_error(type() + " is not an object value"); }
virtual value invoke(const func_args &) const { throw std::runtime_error(type() + " is not a function value"); }
virtual bool is_none() const { return false; }
virtual bool is_undefined() const { return false; }
@ -154,43 +139,66 @@ struct value_t {
throw std::runtime_error("No builtins available for type " + type());
}
virtual bool has_key(const std::string & key) {
return val_obj.unordered.find(key) != val_obj.unordered.end();
}
virtual value & at(const std::string & key, value & default_val) {
auto it = val_obj.unordered.find(key);
if (it == val_obj.unordered.end()) {
return default_val;
}
return val_obj.unordered.at(key);
}
virtual value & at(const std::string & key) {
auto it = val_obj.unordered.find(key);
if (it == val_obj.unordered.end()) {
throw std::runtime_error("Key '" + key + "' not found in value of type " + type());
}
return val_obj.unordered.at(key);
}
virtual value & at(int64_t index, value & default_val) {
if (index < 0) {
index += val_arr.size();
}
if (index < 0 || static_cast<size_t>(index) >= val_arr.size()) {
return default_val;
}
return val_arr[index];
}
virtual value & at(int64_t index) {
if (index < 0) {
index += val_arr.size();
}
if (index < 0 || static_cast<size_t>(index) >= val_arr.size()) {
throw std::runtime_error("Index " + std::to_string(index) + " out of bounds for array of size " + std::to_string(val_arr.size()));
}
return val_arr[index];
}
virtual bool has_key(const value &) { throw std::runtime_error(type() + " is not an object value"); }
virtual void insert(const value & /* key */, const value & /* val */) { throw std::runtime_error(type() + " is not an object value"); }
virtual value & at(const value & /* key */, value & /* default_val */) { throw std::runtime_error(type() + " is not an object value"); }
virtual value & at(const value & /* key */) { throw std::runtime_error(type() + " is not an object value"); }
virtual value & at(const std::string & /* key */, value & /* default_val */) { throw std::runtime_error(type() + " is not an object value"); }
virtual value & at(const std::string & /* key */) { throw std::runtime_error(type() + " is not an object value"); }
virtual value & at(int64_t /* idx */, value & /* default_val */) { throw std::runtime_error(type() + " is not an array value"); }
virtual value & at(int64_t /* idx */) { throw std::runtime_error(type() + " is not an array value"); }
virtual bool is_numeric() const { return false; }
virtual bool is_hashable() const { return false; }
virtual bool is_immutable() const { return true; }
virtual hasher unique_hash() const noexcept = 0;
// TODO: C++20 <=> operator
// NOTE: We are treating == as equivalent (for normal comparisons) and != as strict nonequal (for strict (is) comparisons)
virtual bool operator==(const value_t & other) const { return equivalent(other); }
virtual bool operator!=(const value_t & other) const { return nonequal(other); }
// Note: only for debugging purposes
virtual std::string as_repr() const { return as_string().str(); }
protected:
virtual bool equivalent(const value_t &) const = 0;
virtual bool nonequal(const value_t & other) const { return !equivalent(other); }
};
//
// utils
//
const func_builtins & global_builtins();
std::string value_to_json(const value & val, int indent = -1, const std::string_view item_sep = ", ", const std::string_view key_sep = ": ");
// Note: only used for debugging purposes
std::string value_to_string_repr(const value & val);
struct not_implemented_exception : public std::runtime_error {
not_implemented_exception(const std::string & msg) : std::runtime_error("NotImplemented: " + msg) {}
};
struct value_hasher {
size_t operator()(const value & val) const noexcept {
return val->unique_hash().digest();
}
};
struct value_equivalence {
bool operator()(const value & lhs, const value & rhs) const {
return *lhs == *rhs;
}
bool operator()(const std::pair<value, value> & lhs, const std::pair<value, value> & rhs) const {
return *(lhs.first) == *(rhs.first) && *(lhs.second) == *(rhs.second);
}
};
struct value_equality {
bool operator()(const value & lhs, const value & rhs) const {
return !(*lhs != *rhs);
}
};
//
@ -198,24 +206,49 @@ struct value_t {
//
struct value_int_t : public value_t {
value_int_t(int64_t v) { val_int = v; }
value_int_t(int64_t v) {
val_int = v;
val_flt = static_cast<double>(v);
if (static_cast<int64_t>(val_flt) != v) {
val_flt = v < 0 ? -INFINITY : INFINITY;
}
}
virtual std::string type() const override { return "Integer"; }
virtual int64_t as_int() const override { return val_int; }
virtual double as_float() const override { return static_cast<double>(val_int); }
virtual double as_float() const override { return val_flt; }
virtual string as_string() const override { return std::to_string(val_int); }
virtual bool as_bool() const override {
return val_int != 0;
}
virtual const func_builtins & get_builtins() const override;
virtual bool is_numeric() const override { return true; }
virtual bool is_hashable() const override { return true; }
virtual hasher unique_hash() const noexcept override {
return hasher(typeid(*this))
.update(&val_int, sizeof(val_int))
.update(&val_flt, sizeof(val_flt));
}
protected:
virtual bool equivalent(const value_t & other) const override {
return other.is_numeric() && val_int == other.val_int && val_flt == other.val_flt;
}
virtual bool nonequal(const value_t & other) const override {
return !(typeid(*this) == typeid(other) && val_int == other.val_int);
}
};
using value_int = std::shared_ptr<value_int_t>;
struct value_float_t : public value_t {
value_float_t(double v) { val_flt = v; }
value val;
value_float_t(double v) {
val_flt = v;
val_int = std::isfinite(v) ? static_cast<int64_t>(v) : 0;
val = mk_val<value_int>(val_int);
}
virtual std::string type() const override { return "Float"; }
virtual double as_float() const override { return val_flt; }
virtual int64_t as_int() const override { return static_cast<int64_t>(val_flt); }
virtual int64_t as_int() const override { return val_int; }
virtual string as_string() const override {
std::string out = std::to_string(val_flt);
out.erase(out.find_last_not_of('0') + 1, std::string::npos); // remove trailing zeros
@ -226,6 +259,24 @@ struct value_float_t : public value_t {
return val_flt != 0.0;
}
virtual const func_builtins & get_builtins() const override;
virtual bool is_numeric() const override { return true; }
virtual bool is_hashable() const override { return true; }
virtual hasher unique_hash() const noexcept override {
if (static_cast<double>(val_int) == val_flt) {
return val->unique_hash();
} else {
return hasher(typeid(*this))
.update(&val_int, sizeof(val_int))
.update(&val_flt, sizeof(val_flt));
}
}
protected:
virtual bool equivalent(const value_t & other) const override {
return other.is_numeric() && val_int == other.val_int && val_flt == other.val_flt;
}
virtual bool nonequal(const value_t & other) const override {
return !(typeid(*this) == typeid(other) && val_flt == other.val_flt);
}
};
using value_float = std::shared_ptr<value_float_t>;
@ -247,19 +298,49 @@ struct value_string_t : public value_t {
return val_str.length() > 0;
}
virtual const func_builtins & get_builtins() const override;
virtual bool is_hashable() const override { return true; }
virtual hasher unique_hash() const noexcept override {
const auto type_hash = typeid(*this).hash_code();
auto hash = hasher();
hash.update(&type_hash, sizeof(type_hash));
val_str.hash_update(hash);
return hash;
}
void mark_input() {
val_str.mark_input();
}
protected:
virtual bool equivalent(const value_t & other) const override {
return typeid(*this) == typeid(other) && val_str.str() == other.val_str.str();
}
};
using value_string = std::shared_ptr<value_string_t>;
struct value_bool_t : public value_t {
value_bool_t(bool v) { val_bool = v; }
value val;
value_bool_t(bool v) {
val_int = static_cast<int64_t>(v);
val_flt = static_cast<double>(v);
val = mk_val<value_int>(val_int);
}
virtual std::string type() const override { return "Boolean"; }
virtual bool as_bool() const override { return val_bool; }
virtual string as_string() const override { return std::string(val_bool ? "True" : "False"); }
virtual int64_t as_int() const override { return val_int; }
virtual bool as_bool() const override { return val_int; }
virtual string as_string() const override { return std::string(val_int ? "True" : "False"); }
virtual const func_builtins & get_builtins() const override;
virtual bool is_numeric() const override { return true; }
virtual bool is_hashable() const override { return true; }
virtual hasher unique_hash() const noexcept override {
return val->unique_hash();
}
protected:
virtual bool equivalent(const value_t & other) const override {
return other.is_numeric() && val_int == other.val_int && val_flt == other.val_flt;
}
virtual bool nonequal(const value_t & other) const override {
return !(typeid(*this) == typeid(other) && val_int == other.val_int);
}
};
using value_bool = std::shared_ptr<value_bool_t>;
@ -269,13 +350,34 @@ struct value_array_t : public value_t {
value_array_t(value & v) {
val_arr = v->val_arr;
}
value_array_t(std::vector<value> && arr) {
val_arr = arr;
}
value_array_t(const std::vector<value> & arr) {
val_arr = arr;
}
void reverse() { std::reverse(val_arr.begin(), val_arr.end()); }
void push_back(const value & val) { val_arr.push_back(val); }
void push_back(value && val) { val_arr.push_back(std::move(val)); }
void reverse() {
if (is_immutable()) {
throw std::runtime_error("Attempting to modify immutable type");
}
std::reverse(val_arr.begin(), val_arr.end());
}
void push_back(const value & val) {
if (is_immutable()) {
throw std::runtime_error("Attempting to modify immutable type");
}
val_arr.push_back(val);
}
void push_back(value && val) {
if (is_immutable()) {
throw std::runtime_error("Attempting to modify immutable type");
}
val_arr.push_back(std::move(val));
}
value pop_at(int64_t index) {
if (is_immutable()) {
throw std::runtime_error("Attempting to modify immutable type");
}
if (index < 0) {
index = static_cast<int64_t>(val_arr.size()) + index;
}
@ -287,64 +389,225 @@ struct value_array_t : public value_t {
return val;
}
virtual std::string type() const override { return "Array"; }
virtual bool is_immutable() const override { return false; }
virtual const std::vector<value> & as_array() const override { return val_arr; }
virtual string as_string() const override {
const bool immutable = is_immutable();
std::ostringstream ss;
ss << "[";
ss << (immutable ? "(" : "[");
for (size_t i = 0; i < val_arr.size(); i++) {
if (i > 0) ss << ", ";
ss << val_arr.at(i)->as_repr();
value val = val_arr.at(i);
ss << value_to_string_repr(val);
}
ss << "]";
if (immutable && val_arr.size() == 1) {
ss << ",";
}
ss << (immutable ? ")" : "]");
return ss.str();
}
virtual bool as_bool() const override {
return !val_arr.empty();
}
virtual value & at(int64_t index, value & default_val) override {
if (index < 0) {
index += val_arr.size();
}
if (index < 0 || static_cast<size_t>(index) >= val_arr.size()) {
return default_val;
}
return val_arr[index];
}
virtual value & at(int64_t index) override {
if (index < 0) {
index += val_arr.size();
}
if (index < 0 || static_cast<size_t>(index) >= val_arr.size()) {
throw std::runtime_error("Index " + std::to_string(index) + " out of bounds for array of size " + std::to_string(val_arr.size()));
}
return val_arr[index];
}
virtual const func_builtins & get_builtins() const override;
virtual bool is_hashable() const override {
if (std::all_of(val_arr.begin(), val_arr.end(), [&](auto & val) -> bool {
return val->is_immutable() && val->is_hashable();
})) {
return true;
}
return false;
}
virtual hasher unique_hash() const noexcept override {
auto hash = hasher(typeid(*this));
for (const auto & val : val_arr) {
// must use digest to prevent problems from "concatenation" property of hasher
// for ex. hash of [ "ab", "c" ] should be different from [ "a", "bc" ]
const size_t val_hash = val->unique_hash().digest();
hash.update(&val_hash, sizeof(size_t));
}
return hash;
}
protected:
virtual bool equivalent(const value_t & other) const override {
return typeid(*this) == typeid(other) && is_hashable() && other.is_hashable() && std::equal(val_arr.begin(), val_arr.end(), other.val_arr.begin(), value_equivalence());
}
};
using value_array = std::shared_ptr<value_array_t>;
struct value_tuple_t : public value_array_t {
value_tuple_t(value & v) {
val_arr = v->val_arr;
}
value_tuple_t(std::vector<value> && arr) {
val_arr = arr;
}
value_tuple_t(const std::vector<value> & arr) {
val_arr = arr;
}
value_tuple_t(const std::pair<value, value> & pair) {
val_arr.push_back(pair.first);
val_arr.push_back(pair.second);
}
virtual std::string type() const override { return "Tuple"; }
virtual bool is_immutable() const override { return true; }
};
using value_tuple = std::shared_ptr<value_tuple_t>;
struct value_object_t : public value_t {
std::unordered_map<value, value, value_hasher, value_equivalence> unordered;
bool has_builtins = true; // context and loop objects do not have builtins
value_object_t() = default;
value_object_t(value & v) {
val_obj = v->val_obj;
}
value_object_t(const std::map<std::string, value> & obj) {
for (const auto & pair : obj) {
val_obj.insert(pair.first, pair.second);
for (const auto & pair : val_obj) {
unordered[pair.first] = pair.second;
}
}
value_object_t(const std::vector<std::pair<std::string, value>> & obj) {
value_object_t(const std::map<value, value> & obj) {
for (const auto & pair : obj) {
val_obj.insert(pair.first, pair.second);
insert(pair.first, pair.second);
}
}
value_object_t(const std::vector<std::pair<value, value>> & obj) {
for (const auto & pair : obj) {
insert(pair.first, pair.second);
}
}
void insert(const std::string & key, const value & val) {
val_obj.insert(key, val);
insert(mk_val<value_string>(key), val);
}
virtual std::string type() const override { return "Object"; }
virtual const std::vector<std::pair<std::string, value>> & as_ordered_object() const override { return val_obj.ordered; }
virtual bool is_immutable() const override { return false; }
virtual const std::vector<std::pair<value, value>> & as_ordered_object() const override { return val_obj; }
virtual string as_string() const override {
std::ostringstream ss;
ss << "{";
for (size_t i = 0; i < val_obj.size(); i++) {
if (i > 0) ss << ", ";
auto & [key, val] = val_obj.at(i);
ss << value_to_string_repr(key) << ": " << value_to_string_repr(val);
}
ss << "}";
return ss.str();
}
virtual bool as_bool() const override {
return !val_obj.unordered.empty();
return !unordered.empty();
}
virtual bool has_key(const value & key) override {
if (!key->is_immutable() || !key->is_hashable()) {
throw std::runtime_error("Object key of unhashable type: " + key->type());
}
return unordered.find(key) != unordered.end();
}
virtual void insert(const value & key, const value & val) override {
bool replaced = false;
if (is_immutable()) {
throw std::runtime_error("Attempting to modify immutable type");
}
if (has_key(key)) {
// if key exists, replace value in ordered list instead of appending
for (auto & pair : val_obj) {
if (*(pair.first) == *key) {
pair.second = val;
replaced = true;
break;
}
}
}
unordered[key] = val;
if (!replaced) {
val_obj.push_back({key, val});
}
}
virtual value & at(const value & key, value & default_val) override {
if (!has_key(key)) {
return default_val;
}
return unordered.at(key);
}
virtual value & at(const value & key) override {
if (!has_key(key)) {
throw std::runtime_error("Key '" + key->as_string().str() + "' not found in value of type " + type());
}
return unordered.at(key);
}
virtual value & at(const std::string & key, value & default_val) override {
value key_val = mk_val<value_string>(key);
return at(key_val, default_val);
}
virtual value & at(const std::string & key) override {
value key_val = mk_val<value_string>(key);
return at(key_val);
}
virtual const func_builtins & get_builtins() const override;
virtual bool is_hashable() const override {
if (std::all_of(val_obj.begin(), val_obj.end(), [&](auto & pair) -> bool {
const auto & val = pair.second;
return val->is_immutable() && val->is_hashable();
})) {
return true;
}
return false;
}
virtual hasher unique_hash() const noexcept override {
auto hash = hasher(typeid(*this));
for (const auto & [key, val] : val_obj) {
// must use digest to prevent problems from "concatenation" property of hasher
// for ex. hash of key="ab", value="c" should be different from key="a", value="bc"
const size_t key_hash = key->unique_hash().digest();
const size_t val_hash = val->unique_hash().digest();
hash.update(&key_hash, sizeof(key_hash));
hash.update(&val_hash, sizeof(val_hash));
}
return hash;
}
protected:
virtual bool equivalent(const value_t & other) const override {
return typeid(*this) == typeid(other) && is_hashable() && other.is_hashable() && std::equal(val_obj.begin(), val_obj.end(), other.val_obj.begin(), value_equivalence());
}
};
using value_object = std::shared_ptr<value_object_t>;
//
// null and undefined types
// none and undefined types
//
struct value_none_t : public value_t {
virtual std::string type() const override { return "None"; }
virtual bool is_none() const override { return true; }
virtual bool as_bool() const override { return false; }
virtual string as_string() const override { return string("None"); }
virtual string as_string() const override { return string(type()); }
virtual std::string as_repr() const override { return type(); }
virtual const func_builtins & get_builtins() const override;
virtual bool is_hashable() const override { return true; }
virtual hasher unique_hash() const noexcept override {
return hasher(typeid(*this));
}
protected:
virtual bool equivalent(const value_t & other) const override {
return typeid(*this) == typeid(other);
}
};
using value_none = std::shared_ptr<value_none_t>;
@ -356,6 +619,13 @@ struct value_undefined_t : public value_t {
virtual bool as_bool() const override { return false; }
virtual std::string as_repr() const override { return type(); }
virtual const func_builtins & get_builtins() const override;
virtual hasher unique_hash() const noexcept override {
return hasher(typeid(*this));
}
protected:
virtual bool equivalent(const value_t & other) const override {
return is_undefined() == other.is_undefined();
}
};
using value_undefined = std::shared_ptr<value_undefined_t>;
@ -436,7 +706,23 @@ struct value_func_t : public value_t {
return val_func(new_args);
}
virtual std::string type() const override { return "Function"; }
virtual std::string as_repr() const override { return type(); }
virtual std::string as_repr() const override { return type() + "<" + name + ">(" + (arg0 ? arg0->as_repr() : "") + ")"; }
virtual bool is_hashable() const override { return false; }
virtual hasher unique_hash() const noexcept override {
// Note: this is unused for now, we don't support function as object keys
// use function pointer as unique identifier
const auto target = val_func.target<func_hptr>();
return hasher(typeid(*this)).update(&target, sizeof(target));
}
protected:
virtual bool equivalent(const value_t & other) const override {
// Note: this is unused for now, we don't support function as object keys
// compare function pointers
// (val_func == other.val_func does not work as std::function::operator== is only used for nullptr check)
const auto target_this = this->val_func.target<func_hptr>();
const auto target_other = other.val_func.target<func_hptr>();
return typeid(*this) == typeid(other) && target_this == target_other;
}
};
using value_func = std::shared_ptr<value_func_t>;
@ -447,18 +733,21 @@ struct value_kwarg_t : public value_t {
value_kwarg_t(const std::string & k, const value & v) : key(k), val(v) {}
virtual std::string type() const override { return "KwArg"; }
virtual std::string as_repr() const override { return type(); }
virtual bool is_hashable() const override { return true; }
virtual hasher unique_hash() const noexcept override {
const auto type_hash = typeid(*this).hash_code();
auto hash = val->unique_hash();
hash.update(&type_hash, sizeof(type_hash))
.update(key.data(), key.size());
return hash;
}
protected:
virtual bool equivalent(const value_t & other) const override {
const value_kwarg_t & other_val = static_cast<const value_kwarg_t &>(other);
return typeid(*this) == typeid(other) && key == other_val.key && val == other_val.val;
}
};
using value_kwarg = std::shared_ptr<value_kwarg_t>;
// utils
const func_builtins & global_builtins();
std::string value_to_json(const value & val, int indent = -1, const std::string_view item_sep = ", ", const std::string_view key_sep = ": ");
struct not_implemented_exception : public std::runtime_error {
not_implemented_exception(const std::string & msg) : std::runtime_error("NotImplemented: " + msg) {}
};
} // namespace jinja

View File

@ -481,7 +481,7 @@ int main_automated_tests(void) {
/* .name= */ "Mistral-Large-Instruct-2407 (mistralai 'v3' template; modified to have system prompt at start)",
/* .template_str= */ "{%- if messages[0][\"role\"] == \"system\" %}\n {%- set system_message = messages[0][\"content\"] %}\n {%- set loop_messages = messages[1:] %}\n{%- else %}\n {%- set loop_messages = messages %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n{%- set user_messages = loop_messages | selectattr(\"role\", \"equalto\", \"user\") | list %}\n\n{#- This block checks for alternating user/assistant messages, skipping tool calling messages #}\n{%- set ns = namespace() %}\n{%- set ns.index = 0 %}\n{%- for message in loop_messages %}\n {%- if not (message.role == \"tool\" or message.role == \"tool_results\" or (message.tool_calls is defined and message.tool_calls is not none)) %}\n {%- if (message[\"role\"] == \"user\") != (ns.index % 2 == 0) %}\n {{- raise_exception(\"After the optional system message, conversation roles must alternate user/assistant/user/assistant/...\") }}\n {%- endif %}\n {%- set ns.index = ns.index + 1 %}\n {%- endif %}\n{%- endfor %}\n\n{{- bos_token }}\n{%- for message in loop_messages %}\n {%- if message[\"role\"] == \"user\" %}\n {%- if tools is not none and (message == user_messages[-1]) %}\n {{- \"[AVAILABLE_TOOLS] [\" }}\n {%- for tool in tools %}\n {%- set tool = tool.function %}\n {{- '{\"type\": \"function\", \"function\": {' }}\n {%- for key, val in tool.items() if key != \"return\" %}\n {%- if val is string %}\n {{- '\"' + key + '\": \"' + val + '\"' }}\n {%- else %}\n {{- '\"' + key + '\": ' + val|tojson }}\n {%- endif %}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- endif %}\n {%- endfor %}\n {{- \"}}\" }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- else %}\n {{- \"]\" }}\n {%- endif %}\n {%- endfor %}\n {{- \"[/AVAILABLE_TOOLS]\" }}\n {%- endif %}\n {%- if loop.last and system_message is defined %}\n {{- \"[INST] \" + system_message + \"\\n\\n\" + message[\"content\"] + \"[/INST]\" }}\n {%- else %}\n {{- \"[INST] \" + message[\"content\"] + \"[/INST]\" }}\n {%- endif %}\n {%- elif message.tool_calls is defined and message.tool_calls is not none %}\n {{- \"[TOOL_CALLS] [\" }}\n {%- for tool_call in message.tool_calls %}\n {%- set out = tool_call.function|tojson %}\n {{- out[:-1] }}\n {%- if not tool_call.id is defined or tool_call.id|length != 9 %}\n {{- raise_exception(\"Tool call IDs should be alphanumeric strings with length 9!\") }}\n {%- endif %}\n {{- ', \"id\": \"' + tool_call.id + '\"}' }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- else %}\n {{- \"]\" + eos_token }}\n {%- endif %}\n {%- endfor %}\n {%- elif message[\"role\"] == \"assistant\" %}\n {{- \" \" + message[\"content\"]|trim + eos_token}}\n {%- elif message[\"role\"] == \"tool_results\" or message[\"role\"] == \"tool\" %}\n {%- if message.content is defined and message.content.content is defined %}\n {%- set content = message.content.content %}\n {%- else %}\n {%- set content = message.content %}\n {%- endif %}\n {{- '[TOOL_RESULTS] {\"content\": ' + content|string + \", \" }}\n {%- if not message.tool_call_id is defined or message.tool_call_id|length != 9 %}\n {{- raise_exception(\"Tool call IDs should be alphanumeric strings with length 9!\") }}\n {%- endif %}\n {{- '\"call_id\": \"' + message.tool_call_id + '\"}[/TOOL_RESULTS]' }}\n {%- else %}\n {{- raise_exception(\"Only user and assistant roles are supported, with the exception of an initial optional system message!\") }}\n {%- endif %}\n{%- endfor %}\n",
/* .expected_output= */ "[INST] You are a helpful assistant\n\nHello[/INST] Hi there</s>[INST] Who are you[/INST] I am an assistant</s>[INST] Another question[/INST]",
/* .expected_output_jinja= */ "[INST] Hello[/INST] Hi there</s>[INST] Who are you[/INST] I am an assistant</s>[INST] You are a helpful assistant\n\nAnother question[/INST]",
/* .expected_output_jinja= */ "[INST] Hello[/INST] Hi there</s>[INST] Who are you[/INST] I am an assistant</s>[AVAILABLE_TOOLS] [[/AVAILABLE_TOOLS][INST] You are a helpful assistant\n\nAnother question[/INST]",
/* .bos_token= */ "",
/* .eos_token= */ "</s>",
},
@ -489,7 +489,7 @@ int main_automated_tests(void) {
/* .name= */ "Mistral-Nemo-Instruct-2407 (mistralai 'v3-tekken' template; modified to have system prompt at start)",
/* .template_str= */ "{%- if messages[0][\"role\"] == \"system\" %}\n {%- set system_message = messages[0][\"content\"] %}\n {%- set loop_messages = messages[1:] %}\n{%- else %}\n {%- set loop_messages = messages %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n{%- set user_messages = loop_messages | selectattr(\"role\", \"equalto\", \"user\") | list %}\n\n{#- This block checks for alternating user/assistant messages, skipping tool calling messages #}\n{%- set ns = namespace() %}\n{%- set ns.index = 0 %}\n{%- for message in loop_messages %}\n {%- if not (message.role == \"tool\" or message.role == \"tool_results\" or (message.tool_calls is defined and message.tool_calls is not none)) %}\n {%- if (message[\"role\"] == \"user\") != (ns.index % 2 == 0) %}\n {{- raise_exception(\"After the optional system message, conversation roles must alternate user/assistant/user/assistant/...\") }}\n {%- endif %}\n {%- set ns.index = ns.index + 1 %}\n {%- endif %}\n{%- endfor %}\n\n{{- bos_token }}\n{%- for message in loop_messages %}\n {%- if message[\"role\"] == \"user\" %}\n {%- if tools is not none and (message == user_messages[-1]) %}\n {{- \"[AVAILABLE_TOOLS][\" }}\n {%- for tool in tools %}\n {%- set tool = tool.function %}\n {{- '{\"type\": \"function\", \"function\": {' }}\n {%- for key, val in tool.items() if key != \"return\" %}\n {%- if val is string %}\n {{- '\"' + key + '\": \"' + val + '\"' }}\n {%- else %}\n {{- '\"' + key + '\": ' + val|tojson }}\n {%- endif %}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- endif %}\n {%- endfor %}\n {{- \"}}\" }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- else %}\n {{- \"]\" }}\n {%- endif %}\n {%- endfor %}\n {{- \"[/AVAILABLE_TOOLS]\" }}\n {%- endif %}\n {%- if loop.last and system_message is defined %}\n {{- \"[INST]\" + system_message + \"\\n\\n\" + message[\"content\"] + \"[/INST]\" }}\n {%- else %}\n {{- \"[INST]\" + message[\"content\"] + \"[/INST]\" }}\n {%- endif %}\n {%- elif (message.tool_calls is defined and message.tool_calls is not none) %}\n {{- \"[TOOL_CALLS][\" }}\n {%- for tool_call in message.tool_calls %}\n {%- set out = tool_call.function|tojson %}\n {{- out[:-1] }}\n {%- if not tool_call.id is defined or tool_call.id|length != 9 %}\n {{- raise_exception(\"Tool call IDs should be alphanumeric strings with length 9!\") }}\n {%- endif %}\n {{- ', \"id\": \"' + tool_call.id + '\"}' }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- else %}\n {{- \"]\" + eos_token }}\n {%- endif %}\n {%- endfor %}\n {%- elif message[\"role\"] == \"assistant\" %}\n {{- message[\"content\"] + eos_token}}\n {%- elif message[\"role\"] == \"tool_results\" or message[\"role\"] == \"tool\" %}\n {%- if message.content is defined and message.content.content is defined %}\n {%- set content = message.content.content %}\n {%- else %}\n {%- set content = message.content %}\n {%- endif %}\n {{- '[TOOL_RESULTS]{\"content\": ' + content|string + \", \" }}\n {%- if not message.tool_call_id is defined or message.tool_call_id|length != 9 %}\n {{- raise_exception(\"Tool call IDs should be alphanumeric strings with length 9!\") }}\n {%- endif %}\n {{- '\"call_id\": \"' + message.tool_call_id + '\"}[/TOOL_RESULTS]' }}\n {%- else %}\n {{- raise_exception(\"Only user and assistant roles are supported, with the exception of an initial optional system message!\") }}\n {%- endif %}\n{%- endfor %}\n",
/* .expected_output= */ "[INST]You are a helpful assistant\n\nHello[/INST]Hi there</s>[INST]Who are you[/INST] I am an assistant </s>[INST]Another question[/INST]",
/* .expected_output_jinja= */ "[INST]Hello[/INST]Hi there</s>[INST]Who are you[/INST] I am an assistant </s>[INST]You are a helpful assistant\n\nAnother question[/INST]",
/* .expected_output_jinja= */ "[INST]Hello[/INST]Hi there</s>[INST]Who are you[/INST] I am an assistant </s>[AVAILABLE_TOOLS][[/AVAILABLE_TOOLS][INST]You are a helpful assistant\n\nAnother question[/INST]",
/* .bos_token= */ "",
/* .eos_token= */ "</s>",
},

View File

@ -9,6 +9,7 @@
#include "jinja/runtime.h"
#include "jinja/parser.h"
#include "jinja/lexer.h"
#include "jinja/utils.h"
#include "testing.h"
@ -30,6 +31,7 @@ static void test_tests(testing & t);
static void test_string_methods(testing & t);
static void test_array_methods(testing & t);
static void test_object_methods(testing & t);
static void test_hasher(testing & t);
static void test_fuzzing(testing & t);
static bool g_python_mode = false;
@ -67,6 +69,7 @@ int main(int argc, char *argv[]) {
t.test("array methods", test_array_methods);
t.test("object methods", test_object_methods);
if (!g_python_mode) {
t.test("hasher", test_hasher);
t.test("fuzzing", test_fuzzing);
}
@ -156,6 +159,18 @@ static void test_conditionals(testing & t) {
"big"
);
test_template(t, "object comparison",
"{% if {0: 1, none: 2, 1.0: 3, '0': 4, true: 5} == {false: 1, none: 2, 1: 5, '0': 4} %}equal{% endif %}",
json::object(),
"equal"
);
test_template(t, "array comparison",
"{% if [0, 1.0, false] == [false, 1, 0.0] %}equal{% endif %}",
json::object(),
"equal"
);
test_template(t, "logical and",
"{% if a and b %}both{% endif %}",
{{"a", true}, {"b", true}},
@ -358,6 +373,30 @@ static void test_expressions(testing & t) {
"b"
);
test_template(t, "array negative access",
"{{ items[-1] }}",
{{"items", json::array({"a", "b", "c"})}},
"c"
);
test_template(t, "array slice",
"{{ items[1:-1]|string }}",
{{"items", json::array({"a", "b", "c"})}},
"['b']"
);
test_template(t, "array slice step",
"{{ items[::2]|string }}",
{{"items", json::array({"a", "b", "c"})}},
"['a', 'c']"
);
test_template(t, "tuple slice",
"{{ ('a', 'b', 'c')[::-1]|string }}",
json::object(),
"('c', 'b', 'a')"
);
test_template(t, "arithmetic",
"{{ (a + b) * c }}",
{{"a", 2}, {"b", 3}, {"c", 4}},
@ -401,6 +440,36 @@ static void test_set_statement(testing & t) {
json::object(),
"1"
);
test_template(t, "set dict with mixed type keys",
"{% set d = {0: 1, none: 2, 1.0: 3, '0': 4, (0, 0): 5, false: 6, 1: 7} %}{{ d[(0, 0)] + d[0] + d[none] + d['0'] + d[false] + d[1.0] + d[1] }}",
json::object(),
"37"
);
test_template(t, "print dict with mixed type keys",
"{% set d = {0: 1, none: 2, 1.0: 3, '0': 4, (0, 0): 5, true: 6} %}{{ d|string }}",
json::object(),
"{0: 1, None: 2, 1.0: 6, '0': 4, (0, 0): 5}"
);
test_template(t, "print array with mixed types",
"{% set d = [0, none, 1.0, '0', true, (0, 0)] %}{{ d|string }}",
json::object(),
"[0, None, 1.0, '0', True, (0, 0)]"
);
test_template(t, "object member assignment with mixed key types",
"{% set d = namespace() %}{% set d.a = 123 %}{{ d['a'] == 123 }}",
json::object(),
"True"
);
test_template(t, "tuple unpacking",
"{% set t = (1, 2, 3) %}{% set a, b, c = t %}{{ a + b + c }}",
json::object(),
"6"
);
}
static void test_filters(testing & t) {
@ -1312,6 +1381,154 @@ static void test_object_methods(testing & t) {
{{"obj", {{"a", "b"}}}},
"True True"
);
test_template(t, "expression as object key",
"{% set d = {'ab': 123} %}{{ d['a' + 'b'] == 123 }}",
json::object(),
"True"
);
test_template(t, "numeric as object key (template: Seed-OSS)",
"{% set d = {1: 'a', 2: 'b'} %}{{ d[1] == 'a' and d[2] == 'b' }}",
json::object(),
"True"
);
}
static void test_hasher(testing & t) {
static const std::vector<std::pair<size_t, size_t>> chunk_sizes = {
{1, 2},
{1, 16},
{8, 1},
{1, 1024},
{5, 512},
{16, 256},
{45, 122},
{70, 634},
};
static auto random_bytes = [](size_t length) -> std::string {
std::string data;
data.resize(length);
for (size_t i = 0; i < length; ++i) {
data[i] = static_cast<char>(rand() % 256);
}
return data;
};
t.test("state unchanged with empty input", [](testing & t) {
jinja::hasher hasher;
hasher.update("some data");
size_t initial_state = hasher.digest();
hasher.update("", 0);
size_t final_state = hasher.digest();
t.assert_true("Hasher state should remain unchanged", initial_state == final_state);
});
t.test("different inputs produce different hashes", [](testing & t) {
jinja::hasher hasher1;
hasher1.update("data one");
size_t hash1 = hasher1.digest();
jinja::hasher hasher2;
hasher2.update("data two");
size_t hash2 = hasher2.digest();
t.assert_true("Different inputs should produce different hashes", hash1 != hash2);
});
t.test("same inputs produce same hashes", [](testing & t) {
jinja::hasher hasher1;
hasher1.update("consistent data");
size_t hash1 = hasher1.digest();
jinja::hasher hasher2;
hasher2.update("consistent data");
size_t hash2 = hasher2.digest();
t.assert_true("Same inputs should produce same hashes", hash1 == hash2);
});
t.test("property: update(a ~ b) == update(a).update(b)", [](testing & t) {
for (const auto & [size1, size2] : chunk_sizes) {
std::string data1 = random_bytes(size1);
std::string data2 = random_bytes(size2);
jinja::hasher hasher1;
hasher1.update(data1);
hasher1.update(data2);
size_t hash1 = hasher1.digest();
jinja::hasher hasher2;
hasher2.update(data1 + data2);
size_t hash2 = hasher2.digest();
t.assert_true(
"Hashing in multiple updates should match single update (" + std::to_string(size1) + ", " + std::to_string(size2) + ")",
hash1 == hash2);
}
});
t.test("property: update(a ~ b) == update(a).update(b) with more update passes", [](testing & t) {
static const std::vector<size_t> sizes = {3, 732, 131, 13, 17, 256, 436, 99, 4};
jinja::hasher hasher1;
jinja::hasher hasher2;
std::string combined_data;
for (size_t size : sizes) {
std::string data = random_bytes(size);
hasher1.update(data);
combined_data += data;
}
hasher2.update(combined_data);
size_t hash1 = hasher1.digest();
size_t hash2 = hasher2.digest();
t.assert_true(
"Hashing in multiple updates should match single update with many chunks",
hash1 == hash2);
});
t.test("property: non associativity of update", [](testing & t) {
for (const auto & [size1, size2] : chunk_sizes) {
std::string data1 = random_bytes(size1);
std::string data2 = random_bytes(size2);
jinja::hasher hasher1;
hasher1.update(data1);
hasher1.update(data2);
size_t hash1 = hasher1.digest();
jinja::hasher hasher2;
hasher2.update(data2);
hasher2.update(data1);
size_t hash2 = hasher2.digest();
t.assert_true(
"Hashing order should matter (" + std::to_string(size1) + ", " + std::to_string(size2) + ")",
hash1 != hash2);
}
});
t.test("property: different lengths produce different hashes (padding block size)", [](testing & t) {
std::string random_data = random_bytes(64);
jinja::hasher hasher1;
hasher1.update(random_data);
size_t hash1 = hasher1.digest();
for (int i = 0; i < 16; ++i) {
random_data.push_back('A'); // change length
jinja::hasher hasher2;
hasher2.update(random_data);
size_t hash2 = hasher2.digest();
t.assert_true("Different lengths should produce different hashes (length " + std::to_string(random_data.size()) + ")", hash1 != hash2);
hash1 = hash2;
}
});
}
static void test_template_cpp(testing & t, const std::string & name, const std::string & tmpl, const json & vars, const std::string & expect) {