From 2b4cbd2834e427024bc7f935a1f232aecac6679b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sigbj=C3=B8rn=20Skj=C3=A6ret?= Date: Tue, 27 Jan 2026 19:50:42 +0100 Subject: [PATCH] jinja : implement mixed type object keys (#18955) * implement mixed type object keys * add tests * refactor * minor fixes * massive refactor * add more tests * forgotten tuples * fix array/object is_hashable * correct (albeit broken) jinja responses verified with transformers * improved hashing and equality * refactor hash function * more exhausive test case * clean up * cont * cont (2) * missing cstring --------- Co-authored-by: Xuan Son Nguyen --- common/jinja/runtime.cpp | 55 ++-- common/jinja/runtime.h | 26 +- common/jinja/string.cpp | 6 + common/jinja/string.h | 3 + common/jinja/utils.h | 100 ++++++++ common/jinja/value.cpp | 86 ++++--- common/jinja/value.h | 473 ++++++++++++++++++++++++++++------- tests/test-chat-template.cpp | 4 +- tests/test-jinja.cpp | 217 ++++++++++++++++ 9 files changed, 802 insertions(+), 168 deletions(-) diff --git a/common/jinja/runtime.cpp b/common/jinja/runtime.cpp index e3e4ebf1ec..f234d9284f 100644 --- a/common/jinja/runtime.cpp +++ b/common/jinja/runtime.cpp @@ -44,6 +44,12 @@ static std::string get_line_col(const std::string & source, size_t pos) { return "line " + std::to_string(line) + ", column " + std::to_string(col); } +static void ensure_key_type_allowed(const value & val) { + if (!val->is_hashable()) { + throw std::runtime_error("Type: " + val->type() + " is not allowed as object key"); + } +} + // execute with error handling value statement::execute(context & ctx) { try { @@ -95,20 +101,10 @@ value identifier::execute_impl(context & ctx) { value object_literal::execute_impl(context & ctx) { auto obj = mk_val(); for (const auto & pair : val) { - value key_val = pair.first->execute(ctx); - if (!is_val(key_val) && !is_val(key_val)) { - throw std::runtime_error("Object literal: keys must be string or int values, got " + key_val->type()); - } - std::string key = key_val->as_string().str(); + value key = pair.first->execute(ctx); value val = pair.second->execute(ctx); - JJ_DEBUG("Object literal: setting key '%s' with value type %s", key.c_str(), val->type().c_str()); + JJ_DEBUG("Object literal: setting key '%s' with value type %s", key->as_string().str().c_str(), val->type().c_str()); obj->insert(key, val); - - if (is_val(key_val)) { - obj->val_obj.is_key_numeric = true; - } else if (obj->val_obj.is_key_numeric) { - throw std::runtime_error("Object literal: cannot mix numeric and non-numeric keys"); - } } return obj; } @@ -127,9 +123,9 @@ value binary_expression::execute_impl(context & ctx) { value right_val = right->execute(ctx); JJ_DEBUG("Executing binary expression %s '%s' %s", left_val->type().c_str(), op.value.c_str(), right_val->type().c_str()); if (op.value == "==") { - return mk_val(value_compare(left_val, right_val, value_compare_op::eq)); + return mk_val(*left_val == *right_val); } else if (op.value == "!=") { - return mk_val(!value_compare(left_val, right_val, value_compare_op::eq)); + return mk_val(!(*left_val == *right_val)); } auto workaround_concat_null_with_str = [&](value & res) -> bool { @@ -230,7 +226,7 @@ value binary_expression::execute_impl(context & ctx) { auto & arr = right_val->as_array(); bool member = false; for (const auto & item : arr) { - if (value_compare(left_val, item, value_compare_op::eq)) { + if (*left_val == *item) { member = true; break; } @@ -265,10 +261,9 @@ value binary_expression::execute_impl(context & ctx) { } } - // String in object - if (is_val(left_val) && is_val(right_val)) { - auto key = left_val->as_string().str(); - bool has_key = right_val->has_key(key); + // Value key in object + if (is_val(right_val)) { + bool has_key = right_val->has_key(left_val); if (op.value == "in") { return mk_val(has_key); } else if (op.value == "not in") { @@ -465,14 +460,8 @@ value for_statement::execute_impl(context & ctx) { JJ_DEBUG("%s", "For loop over object keys"); auto & obj = iterable_val->as_ordered_object(); for (auto & p : obj) { - auto tuple = mk_val(); - if (iterable_val->val_obj.is_key_numeric) { - tuple->push_back(mk_val(std::stoll(p.first))); - } else { - tuple->push_back(mk_val(p.first)); - } - tuple->push_back(p.second); - items.push_back(tuple); + auto tuple = mk_val(p); + items.push_back(std::move(tuple)); } if (ctx.is_get_stats) { iterable_val->stats.used = true; @@ -602,11 +591,13 @@ value set_statement::execute_impl(context & ctx) { auto rhs = val ? val->execute(ctx) : exec_statements(body, ctx); if (is_stmt(assignee)) { + // case: {% set my_var = value %} auto var_name = cast_stmt(assignee)->val; JJ_DEBUG("Setting global variable '%s' with value type %s", var_name.c_str(), rhs->type().c_str()); ctx.set_val(var_name, rhs); } else if (is_stmt(assignee)) { + // case: {% set a, b = value %} auto tuple = cast_stmt(assignee); if (!is_val(rhs)) { throw std::runtime_error("Cannot unpack non-iterable type in set: " + rhs->type()); @@ -625,6 +616,7 @@ value set_statement::execute_impl(context & ctx) { } } else if (is_stmt(assignee)) { + // case: {% set ns.my_var = value %} auto member = cast_stmt(assignee); if (member->computed) { throw std::runtime_error("Cannot assign to computed member"); @@ -767,22 +759,22 @@ value member_expression::execute_impl(context & ctx) { } JJ_DEBUG("Member expression on object type %s, property type %s", object->type().c_str(), property->type().c_str()); + ensure_key_type_allowed(property); value val = mk_val("object_property"); if (is_val(object)) { JJ_DEBUG("%s", "Accessing property on undefined object, returning undefined"); return val; + } else if (is_val(object)) { - if (!is_val(property)) { - throw std::runtime_error("Cannot access object with non-string: got " + property->type()); - } auto key = property->as_string().str(); - val = object->at(key, val); + val = object->at(property, val); if (is_val(val)) { val = try_builtin_func(ctx, key, object, true); } JJ_DEBUG("Accessed property '%s' value, got type: %s", key.c_str(), val->type().c_str()); + } else if (is_val(object) || is_val(object)) { if (is_val(property)) { int64_t index = property->as_int(); @@ -806,6 +798,7 @@ value member_expression::execute_impl(context & ctx) { auto key = property->as_string().str(); JJ_DEBUG("Accessing %s built-in '%s'", is_val(object) ? "array" : "string", key.c_str()); val = try_builtin_func(ctx, key, object, true); + } else { throw std::runtime_error("Cannot access property with non-string/non-number: got " + property->type()); } diff --git a/common/jinja/runtime.h b/common/jinja/runtime.h index dc7f4e471c..17a6dff5aa 100644 --- a/common/jinja/runtime.h +++ b/common/jinja/runtime.h @@ -79,18 +79,18 @@ struct context { } value get_val(const std::string & name) { - auto it = env->val_obj.unordered.find(name); - if (it != env->val_obj.unordered.end()) { - return it->second; - } else { - return mk_val(name); - } + value default_val = mk_val(name); + return env->at(name, default_val); } void set_val(const std::string & name, const value & val) { env->insert(name, val); } + void set_val(const value & name, const value & val) { + env->insert(name, val); + } + void print_vars() const { printf("Context Variables:\n%s\n", value_to_json(env, 2).c_str()); } @@ -344,9 +344,19 @@ struct array_literal : public expression { } }; -struct tuple_literal : public array_literal { - explicit tuple_literal(statements && val) : array_literal(std::move(val)) {} +struct tuple_literal : public expression { + statements val; + explicit tuple_literal(statements && val) : val(std::move(val)) { + for (const auto& item : this->val) chk_type(item); + } std::string type() const override { return "TupleLiteral"; } + value execute_impl(context & ctx) override { + auto arr = mk_val(); + for (const auto & item_stmt : val) { + arr->push_back(item_stmt->execute(ctx)); + } + return mk_val(std::move(arr->as_array())); + } }; struct object_literal : public expression { diff --git a/common/jinja/string.cpp b/common/jinja/string.cpp index 21ebde39e3..8087e15b35 100644 --- a/common/jinja/string.cpp +++ b/common/jinja/string.cpp @@ -61,6 +61,12 @@ size_t string::length() const { return len; } +void string::hash_update(hasher & hash) const noexcept { + for (const auto & part : parts) { + hash.update(part.val.data(), part.val.length()); + } +} + bool string::all_parts_are_input() const { for (const auto & part : parts) { if (!part.is_input) { diff --git a/common/jinja/string.h b/common/jinja/string.h index 78457f9e41..c4963000ad 100644 --- a/common/jinja/string.h +++ b/common/jinja/string.h @@ -4,6 +4,8 @@ #include #include +#include "utils.h" + namespace jinja { // allow differentiate between user input strings and template strings @@ -37,6 +39,7 @@ struct string { std::string str() const; size_t length() const; + void hash_update(hasher & hash) const noexcept; bool all_parts_are_input() const; bool is_uppercase() const; bool is_lowercase() const; diff --git a/common/jinja/utils.h b/common/jinja/utils.h index 1e9f2a12a1..de6947fc28 100644 --- a/common/jinja/utils.h +++ b/common/jinja/utils.h @@ -3,6 +3,8 @@ #include #include #include +#include +#include namespace jinja { @@ -46,4 +48,102 @@ static std::string fmt_error_with_source(const std::string & tag, const std::str return oss.str(); } +// Note: this is a simple hasher, not cryptographically secure, just for hash table usage +struct hasher { + static constexpr auto size_t_digits = sizeof(size_t) * 8; + static constexpr size_t prime = size_t_digits == 64 ? 0x100000001b3 : 0x01000193; + static constexpr size_t seed = size_t_digits == 64 ? 0xcbf29ce484222325 : 0x811c9dc5; + static constexpr auto block_size = sizeof(size_t); // in bytes; allowing the compiler to vectorize the computation + + static_assert(size_t_digits == 64 || size_t_digits == 32); + static_assert(block_size == 8 || block_size == 4); + + uint8_t buffer[block_size]; + size_t idx = 0; // current index in buffer + size_t state = seed; + + hasher() = default; + hasher(const std::type_info & type_inf) noexcept { + const auto type_hash = type_inf.hash_code(); + update(&type_hash, sizeof(type_hash)); + } + + // Properties: + // - update is not associative: update(a).update(b) != update(b).update(a) + // - update(a ~ b) == update(a).update(b) with ~ as concatenation operator --> useful for streaming + // - update("", 0) --> state unchanged with empty input + hasher& update(void const * bytes, size_t len) noexcept { + const uint8_t * c = static_cast(bytes); + if (len == 0) { + return *this; + } + size_t processed = 0; + + // first, fill the existing buffer if it's partial + if (idx > 0) { + size_t to_fill = block_size - idx; + if (to_fill > len) { + to_fill = len; + } + std::memcpy(buffer + idx, c, to_fill); + idx += to_fill; + processed += to_fill; + if (idx == block_size) { + update_block(buffer); + idx = 0; + } + } + + // process full blocks from the remaining input + for (; processed + block_size <= len; processed += block_size) { + update_block(c + processed); + } + + // buffer any remaining bytes + size_t remaining = len - processed; + if (remaining > 0) { + std::memcpy(buffer, c + processed, remaining); + idx = remaining; + } + return *this; + } + + // convenience function for testing only + hasher& update(const std::string & s) noexcept { + return update(s.data(), s.size()); + } + + // finalize and get the hash value + // note: after calling digest, the hasher state is modified, do not call update() again + size_t digest() noexcept { + // if there are remaining bytes in buffer, fill the rest with zeros and process + if (idx > 0) { + for (size_t i = idx; i < block_size; ++i) { + buffer[i] = 0; + } + update_block(buffer); + idx = 0; + } + + return state; + } + +private: + // IMPORTANT: block must have at least block_size bytes + void update_block(const uint8_t * block) noexcept { + size_t blk = static_cast(block[0]) + | (static_cast(block[1]) << 8) + | (static_cast(block[2]) << 16) + | (static_cast(block[3]) << 24); + if constexpr (block_size == 8) { + blk = blk | (static_cast(block[4]) << 32) + | (static_cast(block[5]) << 40) + | (static_cast(block[6]) << 48) + | (static_cast(block[7]) << 56); + } + state ^= blk; + state *= prime; + } +}; + } // namespace jinja diff --git a/common/jinja/value.cpp b/common/jinja/value.cpp index d2ed824269..2d77068143 100644 --- a/common/jinja/value.cpp +++ b/common/jinja/value.cpp @@ -163,7 +163,7 @@ static value selectattr(const func_args & args) { args.ensure_vals(true, true, false, false); auto arr = args.get_pos(0)->as_array(); - auto attr_name = args.get_pos(1)->as_string().str(); + auto attribute = args.get_pos(1); auto out = mk_val(); value val_default = mk_val(); @@ -173,7 +173,7 @@ static value selectattr(const func_args & args) { if (!is_val(item)) { throw raised_exception("selectattr: item is not an object"); } - value attr_val = item->at(attr_name, val_default); + value attr_val = item->at(attribute, val_default); bool is_selected = attr_val->as_bool(); if constexpr (is_reject) is_selected = !is_selected; if (is_selected) out->push_back(item); @@ -217,7 +217,7 @@ static value selectattr(const func_args & args) { if (!is_val(item)) { throw raised_exception("selectattr: item is not an object"); } - value attr_val = item->at(attr_name, val_default); + value attr_val = item->at(attribute, val_default); func_args test_args(args.ctx); test_args.push_back(attr_val); // attribute value test_args.push_back(extra_arg); // extra argument @@ -741,6 +741,7 @@ const func_builtins & value_array_t::get_builtins() const { args.ensure_count(1, 4); args.ensure_vals(true, true, false, false); + auto val = args.get_pos(0); auto arg0 = args.get_pos(1); auto arg1 = args.get_pos(2, mk_val()); auto arg2 = args.get_pos(3, mk_val()); @@ -762,10 +763,8 @@ const func_builtins & value_array_t::get_builtins() const { if (step == 0) { throw raised_exception("slice step cannot be zero"); } - auto arr = slice(args.get_pos(0)->as_array(), start, stop, step); - auto res = mk_val(); - res->val_arr = std::move(arr); - return res; + auto arr = slice(val->as_array(), start, stop, step); + return is_val(val) ? mk_val(std::move(arr)) : mk_val(std::move(arr)); }}, {"selectattr", selectattr}, {"select", selectattr}, @@ -785,15 +784,14 @@ const func_builtins & value_array_t::get_builtins() const { } const int64_t attr_int = attr_is_int ? attribute->as_int() : 0; const std::string delim = val_delim->is_undefined() ? "" : val_delim->as_string().str(); - const std::string attr_name = attribute->is_undefined() ? "" : attribute->as_string().str(); std::string result; for (size_t i = 0; i < arr.size(); ++i) { value val_arr = arr[i]; if (!attribute->is_undefined()) { if (attr_is_int && is_val(val_arr)) { val_arr = val_arr->at(attr_int); - } else if (!attr_is_int && !attr_name.empty() && is_val(val_arr)) { - val_arr = val_arr->at(attr_name); + } else if (!attr_is_int && is_val(val_arr)) { + val_arr = val_arr->at(attribute); } } if (!is_val(val_arr) && !is_val(val_arr) && !is_val(val_arr)) { @@ -808,9 +806,7 @@ const func_builtins & value_array_t::get_builtins() const { }}, {"string", [](const func_args & args) -> value { args.ensure_vals(); - auto str = mk_val(); - gather_string_parts_recursive(args.get_pos(0), str); - return str; + return mk_val(args.get_pos(0)->as_string()); }}, {"tojson", tojson}, {"map", [](const func_args & args) -> value { @@ -821,26 +817,26 @@ const func_builtins & value_array_t::get_builtins() const { if (!is_val(args.get_args().at(1))) { throw not_implemented_exception("map: filter-mapping not implemented"); } + value val = args.get_pos(0); value attribute = args.get_kwarg_or_pos("attribute", 1); const bool attr_is_int = is_val(attribute); if (!is_val(attribute) && !attr_is_int) { throw raised_exception("map: attribute must be string or integer"); } const int64_t attr_int = attr_is_int ? attribute->as_int() : 0; - const std::string attr_name = attribute->as_string().str(); value default_val = args.get_kwarg("default", mk_val()); auto out = mk_val(); - auto arr = args.get_pos(0)->as_array(); + auto arr = val->as_array(); for (const auto & item : arr) { value attr_val; if (attr_is_int) { attr_val = is_val(item) ? item->at(attr_int, default_val) : default_val; } else { - attr_val = is_val(item) ? item->at(attr_name, default_val) : default_val; + attr_val = is_val(item) ? item->at(attribute, default_val) : default_val; } out->push_back(attr_val); } - return out; + return is_val(val) ? mk_val(std::move(out->as_array())) : out; }}, {"append", [](const func_args & args) -> value { args.ensure_count(2); @@ -867,6 +863,7 @@ const func_builtins & value_array_t::get_builtins() const { if (!is_val(args.get_pos(0))) { throw raised_exception("sort: first argument must be an array"); } + value val = args.get_pos(0); value val_reverse = args.get_kwarg_or_pos("reverse", 1); value val_case = args.get_kwarg_or_pos("case_sensitive", 2); value attribute = args.get_kwarg_or_pos("attribute", 3); @@ -875,8 +872,7 @@ const func_builtins & value_array_t::get_builtins() const { const bool reverse = val_reverse->as_bool(); // undefined == false const bool attr_is_int = is_val(attribute); const int64_t attr_int = attr_is_int ? attribute->as_int() : 0; - const std::string attr_name = attribute->is_undefined() ? "" : attribute->as_string().str(); - std::vector arr = cast_val(args.get_pos(0))->as_array(); // copy + std::vector arr = val->as_array(); // copy std::sort(arr.begin(), arr.end(),[&](const value & a, const value & b) { value val_a = a; value val_b = b; @@ -884,22 +880,23 @@ const func_builtins & value_array_t::get_builtins() const { if (attr_is_int && is_val(a) && is_val(b)) { val_a = a->at(attr_int); val_b = b->at(attr_int); - } else if (!attr_is_int && !attr_name.empty() && is_val(a) && is_val(b)) { - val_a = a->at(attr_name); - val_b = b->at(attr_name); + } else if (!attr_is_int && is_val(a) && is_val(b)) { + val_a = a->at(attribute); + val_b = b->at(attribute); } else { - throw raised_exception("sort: unsupported object attribute comparison"); + throw raised_exception("sort: unsupported object attribute comparison between " + a->type() + " and " + b->type()); } } return value_compare(val_a, val_b, reverse ? value_compare_op::gt : value_compare_op::lt); }); - return mk_val(arr); + return is_val(val) ? mk_val(std::move(arr)) : mk_val(std::move(arr)); }}, {"reverse", [](const func_args & args) -> value { args.ensure_vals(); - std::vector arr = cast_val(args.get_pos(0))->as_array(); // copy + value val = args.get_pos(0); + std::vector arr = val->as_array(); // copy std::reverse(arr.begin(), arr.end()); - return mk_val(arr); + return is_val(val) ? mk_val(std::move(arr)) : mk_val(std::move(arr)); }}, {"unique", [](const func_args &) -> value { throw not_implemented_exception("Array unique builtin not implemented"); @@ -930,7 +927,7 @@ const func_builtins & value_object_t::get_builtins() const { default_val = args.get_pos(2); } const value obj = args.get_pos(0); - std::string key = args.get_pos(1)->as_string().str(); + const value key = args.get_pos(1); return obj->at(key, default_val); }}, {"keys", [](const func_args & args) -> value { @@ -938,7 +935,7 @@ const func_builtins & value_object_t::get_builtins() const { const auto & obj = args.get_pos(0)->as_ordered_object(); auto result = mk_val(); for (const auto & pair : obj) { - result->push_back(mk_val(pair.first)); + result->push_back(pair.first); } return result; }}, @@ -956,15 +953,16 @@ const func_builtins & value_object_t::get_builtins() const { const auto & obj = args.get_pos(0)->as_ordered_object(); auto result = mk_val(); for (const auto & pair : obj) { - auto item = mk_val(); - item->push_back(mk_val(pair.first)); - item->push_back(pair.second); + auto item = mk_val(pair); result->push_back(std::move(item)); } return result; }}, {"tojson", tojson}, - {"string", tojson}, + {"string", [](const func_args & args) -> value { + args.ensure_vals(); + return mk_val(args.get_pos(0)->as_string()); + }}, {"length", [](const func_args & args) -> value { args.ensure_vals(); const auto & obj = args.get_pos(0)->as_ordered_object(); @@ -985,11 +983,11 @@ const func_builtins & value_object_t::get_builtins() const { const bool reverse = val_reverse->as_bool(); // undefined == false const bool by_value = is_val(val_by) && val_by->as_string().str() == "value" ? true : false; auto result = mk_val(val_input); // copy - std::sort(result->val_obj.ordered.begin(), result->val_obj.ordered.end(), [&](const auto & a, const auto & b) { + std::sort(result->val_obj.begin(), result->val_obj.end(), [&](const auto & a, const auto & b) { if (by_value) { return value_compare(a.second, b.second, reverse ? value_compare_op::gt : value_compare_op::lt); } else { - return reverse ? a.first > b.first : a.first < b.first; + return value_compare(a.first, b.first, reverse ? value_compare_op::gt : value_compare_op::lt); } }); return result; @@ -1134,6 +1132,8 @@ void global_from_json(context & ctx, const nlohmann::ordered_json & json_obj, bo } } +// recursively convert value to JSON string +// TODO: avoid circular references static void value_to_json_internal(std::ostringstream & oss, const value & val, int curr_lvl, int indent, const std::string_view item_sep, const std::string_view key_sep) { auto indent_str = [indent, curr_lvl]() -> std::string { return (indent > 0) ? std::string(curr_lvl * indent, ' ') : ""; @@ -1196,7 +1196,8 @@ static void value_to_json_internal(std::ostringstream & oss, const value & val, size_t i = 0; for (const auto & pair : obj) { oss << indent_str() << (indent > 0 ? std::string(indent, ' ') : ""); - oss << "\"" << pair.first << "\"" << key_sep; + value_to_json_internal(oss, mk_val(pair.first->as_string().str()), curr_lvl + 1, indent, item_sep, key_sep); + oss << key_sep; value_to_json_internal(oss, pair.second, curr_lvl + 1, indent, item_sep, key_sep); if (i < obj.size() - 1) { oss << item_sep; @@ -1219,4 +1220,19 @@ std::string value_to_json(const value & val, int indent, const std::string_view return oss.str(); } +// TODO: avoid circular references +std::string value_to_string_repr(const value & val) { + if (is_val(val)) { + const std::string val_str = val->as_string().str(); + + if (val_str.find('\'') != std::string::npos) { + return value_to_json(val); + } else { + return "'" + val_str + "'"; + } + } else { + return val->as_repr(); + } +} + } // namespace jinja diff --git a/common/jinja/value.h b/common/jinja/value.h index ccb05c6fd4..a2f92d2c69 100644 --- a/common/jinja/value.h +++ b/common/jinja/value.h @@ -1,8 +1,10 @@ #pragma once #include "string.h" +#include "utils.h" #include +#include #include #include #include @@ -93,7 +95,8 @@ void global_from_json(context & ctx, const T_JSON & json_obj, bool mark_input); struct func_args; // function argument values -using func_handler = std::function; +using func_hptr = value(const func_args &); +using func_handler = std::function; using func_builtins = std::map; enum value_compare_op { eq, ge, gt, lt, ne }; @@ -103,28 +106,9 @@ struct value_t { int64_t val_int; double val_flt; string val_str; - bool val_bool; std::vector val_arr; - - struct map { - // once set to true, all keys must be numeric - // caveat: we only allow either all numeric keys or all non-numeric keys - // for now, this only applied to for_statement in case of iterating over object keys/items - bool is_key_numeric = false; - std::map unordered; - std::vector> ordered; - void insert(const std::string & key, const value & val) { - if (unordered.find(key) != unordered.end()) { - // if key exists, remove from ordered list - ordered.erase(std::remove_if(ordered.begin(), ordered.end(), - [&](const std::pair & p) { return p.first == key; }), - ordered.end()); - } - unordered[key] = val; - ordered.push_back({key, val}); - } - } val_obj; + std::vector> val_obj; func_handler val_func; @@ -139,6 +123,7 @@ struct value_t { value_t(const value_t &) = default; virtual ~value_t() = default; + // Note: only for debugging and error reporting purposes virtual std::string type() const { return ""; } virtual int64_t as_int() const { throw std::runtime_error(type() + " is not an int value"); } @@ -146,7 +131,7 @@ struct value_t { virtual string as_string() const { throw std::runtime_error(type() + " is not a string value"); } virtual bool as_bool() const { throw std::runtime_error(type() + " is not a bool value"); } virtual const std::vector & as_array() const { throw std::runtime_error(type() + " is not an array value"); } - virtual const std::vector> & as_ordered_object() const { throw std::runtime_error(type() + " is not an object value"); } + virtual const std::vector> & as_ordered_object() const { throw std::runtime_error(type() + " is not an object value"); } virtual value invoke(const func_args &) const { throw std::runtime_error(type() + " is not a function value"); } virtual bool is_none() const { return false; } virtual bool is_undefined() const { return false; } @@ -154,43 +139,66 @@ struct value_t { throw std::runtime_error("No builtins available for type " + type()); } - virtual bool has_key(const std::string & key) { - return val_obj.unordered.find(key) != val_obj.unordered.end(); - } - virtual value & at(const std::string & key, value & default_val) { - auto it = val_obj.unordered.find(key); - if (it == val_obj.unordered.end()) { - return default_val; - } - return val_obj.unordered.at(key); - } - virtual value & at(const std::string & key) { - auto it = val_obj.unordered.find(key); - if (it == val_obj.unordered.end()) { - throw std::runtime_error("Key '" + key + "' not found in value of type " + type()); - } - return val_obj.unordered.at(key); - } - virtual value & at(int64_t index, value & default_val) { - if (index < 0) { - index += val_arr.size(); - } - if (index < 0 || static_cast(index) >= val_arr.size()) { - return default_val; - } - return val_arr[index]; - } - virtual value & at(int64_t index) { - if (index < 0) { - index += val_arr.size(); - } - if (index < 0 || static_cast(index) >= val_arr.size()) { - throw std::runtime_error("Index " + std::to_string(index) + " out of bounds for array of size " + std::to_string(val_arr.size())); - } - return val_arr[index]; - } + virtual bool has_key(const value &) { throw std::runtime_error(type() + " is not an object value"); } + virtual void insert(const value & /* key */, const value & /* val */) { throw std::runtime_error(type() + " is not an object value"); } + virtual value & at(const value & /* key */, value & /* default_val */) { throw std::runtime_error(type() + " is not an object value"); } + virtual value & at(const value & /* key */) { throw std::runtime_error(type() + " is not an object value"); } + virtual value & at(const std::string & /* key */, value & /* default_val */) { throw std::runtime_error(type() + " is not an object value"); } + virtual value & at(const std::string & /* key */) { throw std::runtime_error(type() + " is not an object value"); } + virtual value & at(int64_t /* idx */, value & /* default_val */) { throw std::runtime_error(type() + " is not an array value"); } + virtual value & at(int64_t /* idx */) { throw std::runtime_error(type() + " is not an array value"); } + virtual bool is_numeric() const { return false; } + virtual bool is_hashable() const { return false; } + virtual bool is_immutable() const { return true; } + virtual hasher unique_hash() const noexcept = 0; + // TODO: C++20 <=> operator + // NOTE: We are treating == as equivalent (for normal comparisons) and != as strict nonequal (for strict (is) comparisons) + virtual bool operator==(const value_t & other) const { return equivalent(other); } + virtual bool operator!=(const value_t & other) const { return nonequal(other); } + + // Note: only for debugging purposes virtual std::string as_repr() const { return as_string().str(); } + +protected: + virtual bool equivalent(const value_t &) const = 0; + virtual bool nonequal(const value_t & other) const { return !equivalent(other); } +}; + +// +// utils +// + +const func_builtins & global_builtins(); + +std::string value_to_json(const value & val, int indent = -1, const std::string_view item_sep = ", ", const std::string_view key_sep = ": "); + +// Note: only used for debugging purposes +std::string value_to_string_repr(const value & val); + +struct not_implemented_exception : public std::runtime_error { + not_implemented_exception(const std::string & msg) : std::runtime_error("NotImplemented: " + msg) {} +}; + +struct value_hasher { + size_t operator()(const value & val) const noexcept { + return val->unique_hash().digest(); + } +}; + +struct value_equivalence { + bool operator()(const value & lhs, const value & rhs) const { + return *lhs == *rhs; + } + bool operator()(const std::pair & lhs, const std::pair & rhs) const { + return *(lhs.first) == *(rhs.first) && *(lhs.second) == *(rhs.second); + } +}; + +struct value_equality { + bool operator()(const value & lhs, const value & rhs) const { + return !(*lhs != *rhs); + } }; // @@ -198,24 +206,49 @@ struct value_t { // struct value_int_t : public value_t { - value_int_t(int64_t v) { val_int = v; } + value_int_t(int64_t v) { + val_int = v; + val_flt = static_cast(v); + if (static_cast(val_flt) != v) { + val_flt = v < 0 ? -INFINITY : INFINITY; + } + } virtual std::string type() const override { return "Integer"; } virtual int64_t as_int() const override { return val_int; } - virtual double as_float() const override { return static_cast(val_int); } + virtual double as_float() const override { return val_flt; } virtual string as_string() const override { return std::to_string(val_int); } virtual bool as_bool() const override { return val_int != 0; } virtual const func_builtins & get_builtins() const override; + virtual bool is_numeric() const override { return true; } + virtual bool is_hashable() const override { return true; } + virtual hasher unique_hash() const noexcept override { + return hasher(typeid(*this)) + .update(&val_int, sizeof(val_int)) + .update(&val_flt, sizeof(val_flt)); + } +protected: + virtual bool equivalent(const value_t & other) const override { + return other.is_numeric() && val_int == other.val_int && val_flt == other.val_flt; + } + virtual bool nonequal(const value_t & other) const override { + return !(typeid(*this) == typeid(other) && val_int == other.val_int); + } }; using value_int = std::shared_ptr; struct value_float_t : public value_t { - value_float_t(double v) { val_flt = v; } + value val; + value_float_t(double v) { + val_flt = v; + val_int = std::isfinite(v) ? static_cast(v) : 0; + val = mk_val(val_int); + } virtual std::string type() const override { return "Float"; } virtual double as_float() const override { return val_flt; } - virtual int64_t as_int() const override { return static_cast(val_flt); } + virtual int64_t as_int() const override { return val_int; } virtual string as_string() const override { std::string out = std::to_string(val_flt); out.erase(out.find_last_not_of('0') + 1, std::string::npos); // remove trailing zeros @@ -226,6 +259,24 @@ struct value_float_t : public value_t { return val_flt != 0.0; } virtual const func_builtins & get_builtins() const override; + virtual bool is_numeric() const override { return true; } + virtual bool is_hashable() const override { return true; } + virtual hasher unique_hash() const noexcept override { + if (static_cast(val_int) == val_flt) { + return val->unique_hash(); + } else { + return hasher(typeid(*this)) + .update(&val_int, sizeof(val_int)) + .update(&val_flt, sizeof(val_flt)); + } + } +protected: + virtual bool equivalent(const value_t & other) const override { + return other.is_numeric() && val_int == other.val_int && val_flt == other.val_flt; + } + virtual bool nonequal(const value_t & other) const override { + return !(typeid(*this) == typeid(other) && val_flt == other.val_flt); + } }; using value_float = std::shared_ptr; @@ -247,19 +298,49 @@ struct value_string_t : public value_t { return val_str.length() > 0; } virtual const func_builtins & get_builtins() const override; + virtual bool is_hashable() const override { return true; } + virtual hasher unique_hash() const noexcept override { + const auto type_hash = typeid(*this).hash_code(); + auto hash = hasher(); + hash.update(&type_hash, sizeof(type_hash)); + val_str.hash_update(hash); + return hash; + } void mark_input() { val_str.mark_input(); } +protected: + virtual bool equivalent(const value_t & other) const override { + return typeid(*this) == typeid(other) && val_str.str() == other.val_str.str(); + } }; using value_string = std::shared_ptr; struct value_bool_t : public value_t { - value_bool_t(bool v) { val_bool = v; } + value val; + value_bool_t(bool v) { + val_int = static_cast(v); + val_flt = static_cast(v); + val = mk_val(val_int); + } virtual std::string type() const override { return "Boolean"; } - virtual bool as_bool() const override { return val_bool; } - virtual string as_string() const override { return std::string(val_bool ? "True" : "False"); } + virtual int64_t as_int() const override { return val_int; } + virtual bool as_bool() const override { return val_int; } + virtual string as_string() const override { return std::string(val_int ? "True" : "False"); } virtual const func_builtins & get_builtins() const override; + virtual bool is_numeric() const override { return true; } + virtual bool is_hashable() const override { return true; } + virtual hasher unique_hash() const noexcept override { + return val->unique_hash(); + } +protected: + virtual bool equivalent(const value_t & other) const override { + return other.is_numeric() && val_int == other.val_int && val_flt == other.val_flt; + } + virtual bool nonequal(const value_t & other) const override { + return !(typeid(*this) == typeid(other) && val_int == other.val_int); + } }; using value_bool = std::shared_ptr; @@ -269,13 +350,34 @@ struct value_array_t : public value_t { value_array_t(value & v) { val_arr = v->val_arr; } + value_array_t(std::vector && arr) { + val_arr = arr; + } value_array_t(const std::vector & arr) { val_arr = arr; } - void reverse() { std::reverse(val_arr.begin(), val_arr.end()); } - void push_back(const value & val) { val_arr.push_back(val); } - void push_back(value && val) { val_arr.push_back(std::move(val)); } + void reverse() { + if (is_immutable()) { + throw std::runtime_error("Attempting to modify immutable type"); + } + std::reverse(val_arr.begin(), val_arr.end()); + } + void push_back(const value & val) { + if (is_immutable()) { + throw std::runtime_error("Attempting to modify immutable type"); + } + val_arr.push_back(val); + } + void push_back(value && val) { + if (is_immutable()) { + throw std::runtime_error("Attempting to modify immutable type"); + } + val_arr.push_back(std::move(val)); + } value pop_at(int64_t index) { + if (is_immutable()) { + throw std::runtime_error("Attempting to modify immutable type"); + } if (index < 0) { index = static_cast(val_arr.size()) + index; } @@ -287,64 +389,225 @@ struct value_array_t : public value_t { return val; } virtual std::string type() const override { return "Array"; } + virtual bool is_immutable() const override { return false; } virtual const std::vector & as_array() const override { return val_arr; } virtual string as_string() const override { + const bool immutable = is_immutable(); std::ostringstream ss; - ss << "["; + ss << (immutable ? "(" : "["); for (size_t i = 0; i < val_arr.size(); i++) { if (i > 0) ss << ", "; - ss << val_arr.at(i)->as_repr(); + value val = val_arr.at(i); + ss << value_to_string_repr(val); } - ss << "]"; + if (immutable && val_arr.size() == 1) { + ss << ","; + } + ss << (immutable ? ")" : "]"); return ss.str(); } virtual bool as_bool() const override { return !val_arr.empty(); } + virtual value & at(int64_t index, value & default_val) override { + if (index < 0) { + index += val_arr.size(); + } + if (index < 0 || static_cast(index) >= val_arr.size()) { + return default_val; + } + return val_arr[index]; + } + virtual value & at(int64_t index) override { + if (index < 0) { + index += val_arr.size(); + } + if (index < 0 || static_cast(index) >= val_arr.size()) { + throw std::runtime_error("Index " + std::to_string(index) + " out of bounds for array of size " + std::to_string(val_arr.size())); + } + return val_arr[index]; + } virtual const func_builtins & get_builtins() const override; + virtual bool is_hashable() const override { + if (std::all_of(val_arr.begin(), val_arr.end(), [&](auto & val) -> bool { + return val->is_immutable() && val->is_hashable(); + })) { + return true; + } + return false; + } + virtual hasher unique_hash() const noexcept override { + auto hash = hasher(typeid(*this)); + for (const auto & val : val_arr) { + // must use digest to prevent problems from "concatenation" property of hasher + // for ex. hash of [ "ab", "c" ] should be different from [ "a", "bc" ] + const size_t val_hash = val->unique_hash().digest(); + hash.update(&val_hash, sizeof(size_t)); + } + return hash; + } +protected: + virtual bool equivalent(const value_t & other) const override { + return typeid(*this) == typeid(other) && is_hashable() && other.is_hashable() && std::equal(val_arr.begin(), val_arr.end(), other.val_arr.begin(), value_equivalence()); + } }; using value_array = std::shared_ptr; +struct value_tuple_t : public value_array_t { + value_tuple_t(value & v) { + val_arr = v->val_arr; + } + value_tuple_t(std::vector && arr) { + val_arr = arr; + } + value_tuple_t(const std::vector & arr) { + val_arr = arr; + } + value_tuple_t(const std::pair & pair) { + val_arr.push_back(pair.first); + val_arr.push_back(pair.second); + } + virtual std::string type() const override { return "Tuple"; } + virtual bool is_immutable() const override { return true; } +}; +using value_tuple = std::shared_ptr; + + struct value_object_t : public value_t { + std::unordered_map unordered; bool has_builtins = true; // context and loop objects do not have builtins value_object_t() = default; value_object_t(value & v) { val_obj = v->val_obj; - } - value_object_t(const std::map & obj) { - for (const auto & pair : obj) { - val_obj.insert(pair.first, pair.second); + for (const auto & pair : val_obj) { + unordered[pair.first] = pair.second; } } - value_object_t(const std::vector> & obj) { + value_object_t(const std::map & obj) { for (const auto & pair : obj) { - val_obj.insert(pair.first, pair.second); + insert(pair.first, pair.second); + } + } + value_object_t(const std::vector> & obj) { + for (const auto & pair : obj) { + insert(pair.first, pair.second); } } void insert(const std::string & key, const value & val) { - val_obj.insert(key, val); + insert(mk_val(key), val); } virtual std::string type() const override { return "Object"; } - virtual const std::vector> & as_ordered_object() const override { return val_obj.ordered; } + virtual bool is_immutable() const override { return false; } + virtual const std::vector> & as_ordered_object() const override { return val_obj; } + virtual string as_string() const override { + std::ostringstream ss; + ss << "{"; + for (size_t i = 0; i < val_obj.size(); i++) { + if (i > 0) ss << ", "; + auto & [key, val] = val_obj.at(i); + ss << value_to_string_repr(key) << ": " << value_to_string_repr(val); + } + ss << "}"; + return ss.str(); + } virtual bool as_bool() const override { - return !val_obj.unordered.empty(); + return !unordered.empty(); + } + virtual bool has_key(const value & key) override { + if (!key->is_immutable() || !key->is_hashable()) { + throw std::runtime_error("Object key of unhashable type: " + key->type()); + } + return unordered.find(key) != unordered.end(); + } + virtual void insert(const value & key, const value & val) override { + bool replaced = false; + if (is_immutable()) { + throw std::runtime_error("Attempting to modify immutable type"); + } + if (has_key(key)) { + // if key exists, replace value in ordered list instead of appending + for (auto & pair : val_obj) { + if (*(pair.first) == *key) { + pair.second = val; + replaced = true; + break; + } + } + } + unordered[key] = val; + if (!replaced) { + val_obj.push_back({key, val}); + } + } + virtual value & at(const value & key, value & default_val) override { + if (!has_key(key)) { + return default_val; + } + return unordered.at(key); + } + virtual value & at(const value & key) override { + if (!has_key(key)) { + throw std::runtime_error("Key '" + key->as_string().str() + "' not found in value of type " + type()); + } + return unordered.at(key); + } + virtual value & at(const std::string & key, value & default_val) override { + value key_val = mk_val(key); + return at(key_val, default_val); + } + virtual value & at(const std::string & key) override { + value key_val = mk_val(key); + return at(key_val); } virtual const func_builtins & get_builtins() const override; + virtual bool is_hashable() const override { + if (std::all_of(val_obj.begin(), val_obj.end(), [&](auto & pair) -> bool { + const auto & val = pair.second; + return val->is_immutable() && val->is_hashable(); + })) { + return true; + } + return false; + } + virtual hasher unique_hash() const noexcept override { + auto hash = hasher(typeid(*this)); + for (const auto & [key, val] : val_obj) { + // must use digest to prevent problems from "concatenation" property of hasher + // for ex. hash of key="ab", value="c" should be different from key="a", value="bc" + const size_t key_hash = key->unique_hash().digest(); + const size_t val_hash = val->unique_hash().digest(); + hash.update(&key_hash, sizeof(key_hash)); + hash.update(&val_hash, sizeof(val_hash)); + } + return hash; + } +protected: + virtual bool equivalent(const value_t & other) const override { + return typeid(*this) == typeid(other) && is_hashable() && other.is_hashable() && std::equal(val_obj.begin(), val_obj.end(), other.val_obj.begin(), value_equivalence()); + } }; using value_object = std::shared_ptr; // -// null and undefined types +// none and undefined types // struct value_none_t : public value_t { virtual std::string type() const override { return "None"; } virtual bool is_none() const override { return true; } virtual bool as_bool() const override { return false; } - virtual string as_string() const override { return string("None"); } + virtual string as_string() const override { return string(type()); } virtual std::string as_repr() const override { return type(); } virtual const func_builtins & get_builtins() const override; + virtual bool is_hashable() const override { return true; } + virtual hasher unique_hash() const noexcept override { + return hasher(typeid(*this)); + } +protected: + virtual bool equivalent(const value_t & other) const override { + return typeid(*this) == typeid(other); + } }; using value_none = std::shared_ptr; @@ -356,6 +619,13 @@ struct value_undefined_t : public value_t { virtual bool as_bool() const override { return false; } virtual std::string as_repr() const override { return type(); } virtual const func_builtins & get_builtins() const override; + virtual hasher unique_hash() const noexcept override { + return hasher(typeid(*this)); + } +protected: + virtual bool equivalent(const value_t & other) const override { + return is_undefined() == other.is_undefined(); + } }; using value_undefined = std::shared_ptr; @@ -436,7 +706,23 @@ struct value_func_t : public value_t { return val_func(new_args); } virtual std::string type() const override { return "Function"; } - virtual std::string as_repr() const override { return type(); } + virtual std::string as_repr() const override { return type() + "<" + name + ">(" + (arg0 ? arg0->as_repr() : "") + ")"; } + virtual bool is_hashable() const override { return false; } + virtual hasher unique_hash() const noexcept override { + // Note: this is unused for now, we don't support function as object keys + // use function pointer as unique identifier + const auto target = val_func.target(); + return hasher(typeid(*this)).update(&target, sizeof(target)); + } +protected: + virtual bool equivalent(const value_t & other) const override { + // Note: this is unused for now, we don't support function as object keys + // compare function pointers + // (val_func == other.val_func does not work as std::function::operator== is only used for nullptr check) + const auto target_this = this->val_func.target(); + const auto target_other = other.val_func.target(); + return typeid(*this) == typeid(other) && target_this == target_other; + } }; using value_func = std::shared_ptr; @@ -447,18 +733,21 @@ struct value_kwarg_t : public value_t { value_kwarg_t(const std::string & k, const value & v) : key(k), val(v) {} virtual std::string type() const override { return "KwArg"; } virtual std::string as_repr() const override { return type(); } + virtual bool is_hashable() const override { return true; } + virtual hasher unique_hash() const noexcept override { + const auto type_hash = typeid(*this).hash_code(); + auto hash = val->unique_hash(); + hash.update(&type_hash, sizeof(type_hash)) + .update(key.data(), key.size()); + return hash; + } +protected: + virtual bool equivalent(const value_t & other) const override { + const value_kwarg_t & other_val = static_cast(other); + return typeid(*this) == typeid(other) && key == other_val.key && val == other_val.val; + } }; using value_kwarg = std::shared_ptr; -// utils - -const func_builtins & global_builtins(); -std::string value_to_json(const value & val, int indent = -1, const std::string_view item_sep = ", ", const std::string_view key_sep = ": "); - -struct not_implemented_exception : public std::runtime_error { - not_implemented_exception(const std::string & msg) : std::runtime_error("NotImplemented: " + msg) {} -}; - - } // namespace jinja diff --git a/tests/test-chat-template.cpp b/tests/test-chat-template.cpp index e142900723..d2a1437ca4 100644 --- a/tests/test-chat-template.cpp +++ b/tests/test-chat-template.cpp @@ -481,7 +481,7 @@ int main_automated_tests(void) { /* .name= */ "Mistral-Large-Instruct-2407 (mistralai 'v3' template; modified to have system prompt at start)", /* .template_str= */ "{%- if messages[0][\"role\"] == \"system\" %}\n {%- set system_message = messages[0][\"content\"] %}\n {%- set loop_messages = messages[1:] %}\n{%- else %}\n {%- set loop_messages = messages %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n{%- set user_messages = loop_messages | selectattr(\"role\", \"equalto\", \"user\") | list %}\n\n{#- This block checks for alternating user/assistant messages, skipping tool calling messages #}\n{%- set ns = namespace() %}\n{%- set ns.index = 0 %}\n{%- for message in loop_messages %}\n {%- if not (message.role == \"tool\" or message.role == \"tool_results\" or (message.tool_calls is defined and message.tool_calls is not none)) %}\n {%- if (message[\"role\"] == \"user\") != (ns.index % 2 == 0) %}\n {{- raise_exception(\"After the optional system message, conversation roles must alternate user/assistant/user/assistant/...\") }}\n {%- endif %}\n {%- set ns.index = ns.index + 1 %}\n {%- endif %}\n{%- endfor %}\n\n{{- bos_token }}\n{%- for message in loop_messages %}\n {%- if message[\"role\"] == \"user\" %}\n {%- if tools is not none and (message == user_messages[-1]) %}\n {{- \"[AVAILABLE_TOOLS] [\" }}\n {%- for tool in tools %}\n {%- set tool = tool.function %}\n {{- '{\"type\": \"function\", \"function\": {' }}\n {%- for key, val in tool.items() if key != \"return\" %}\n {%- if val is string %}\n {{- '\"' + key + '\": \"' + val + '\"' }}\n {%- else %}\n {{- '\"' + key + '\": ' + val|tojson }}\n {%- endif %}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- endif %}\n {%- endfor %}\n {{- \"}}\" }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- else %}\n {{- \"]\" }}\n {%- endif %}\n {%- endfor %}\n {{- \"[/AVAILABLE_TOOLS]\" }}\n {%- endif %}\n {%- if loop.last and system_message is defined %}\n {{- \"[INST] \" + system_message + \"\\n\\n\" + message[\"content\"] + \"[/INST]\" }}\n {%- else %}\n {{- \"[INST] \" + message[\"content\"] + \"[/INST]\" }}\n {%- endif %}\n {%- elif message.tool_calls is defined and message.tool_calls is not none %}\n {{- \"[TOOL_CALLS] [\" }}\n {%- for tool_call in message.tool_calls %}\n {%- set out = tool_call.function|tojson %}\n {{- out[:-1] }}\n {%- if not tool_call.id is defined or tool_call.id|length != 9 %}\n {{- raise_exception(\"Tool call IDs should be alphanumeric strings with length 9!\") }}\n {%- endif %}\n {{- ', \"id\": \"' + tool_call.id + '\"}' }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- else %}\n {{- \"]\" + eos_token }}\n {%- endif %}\n {%- endfor %}\n {%- elif message[\"role\"] == \"assistant\" %}\n {{- \" \" + message[\"content\"]|trim + eos_token}}\n {%- elif message[\"role\"] == \"tool_results\" or message[\"role\"] == \"tool\" %}\n {%- if message.content is defined and message.content.content is defined %}\n {%- set content = message.content.content %}\n {%- else %}\n {%- set content = message.content %}\n {%- endif %}\n {{- '[TOOL_RESULTS] {\"content\": ' + content|string + \", \" }}\n {%- if not message.tool_call_id is defined or message.tool_call_id|length != 9 %}\n {{- raise_exception(\"Tool call IDs should be alphanumeric strings with length 9!\") }}\n {%- endif %}\n {{- '\"call_id\": \"' + message.tool_call_id + '\"}[/TOOL_RESULTS]' }}\n {%- else %}\n {{- raise_exception(\"Only user and assistant roles are supported, with the exception of an initial optional system message!\") }}\n {%- endif %}\n{%- endfor %}\n", /* .expected_output= */ "[INST] You are a helpful assistant\n\nHello[/INST] Hi there[INST] Who are you[/INST] I am an assistant[INST] Another question[/INST]", - /* .expected_output_jinja= */ "[INST] Hello[/INST] Hi there[INST] Who are you[/INST] I am an assistant[INST] You are a helpful assistant\n\nAnother question[/INST]", + /* .expected_output_jinja= */ "[INST] Hello[/INST] Hi there[INST] Who are you[/INST] I am an assistant[AVAILABLE_TOOLS] [[/AVAILABLE_TOOLS][INST] You are a helpful assistant\n\nAnother question[/INST]", /* .bos_token= */ "", /* .eos_token= */ "", }, @@ -489,7 +489,7 @@ int main_automated_tests(void) { /* .name= */ "Mistral-Nemo-Instruct-2407 (mistralai 'v3-tekken' template; modified to have system prompt at start)", /* .template_str= */ "{%- if messages[0][\"role\"] == \"system\" %}\n {%- set system_message = messages[0][\"content\"] %}\n {%- set loop_messages = messages[1:] %}\n{%- else %}\n {%- set loop_messages = messages %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n{%- set user_messages = loop_messages | selectattr(\"role\", \"equalto\", \"user\") | list %}\n\n{#- This block checks for alternating user/assistant messages, skipping tool calling messages #}\n{%- set ns = namespace() %}\n{%- set ns.index = 0 %}\n{%- for message in loop_messages %}\n {%- if not (message.role == \"tool\" or message.role == \"tool_results\" or (message.tool_calls is defined and message.tool_calls is not none)) %}\n {%- if (message[\"role\"] == \"user\") != (ns.index % 2 == 0) %}\n {{- raise_exception(\"After the optional system message, conversation roles must alternate user/assistant/user/assistant/...\") }}\n {%- endif %}\n {%- set ns.index = ns.index + 1 %}\n {%- endif %}\n{%- endfor %}\n\n{{- bos_token }}\n{%- for message in loop_messages %}\n {%- if message[\"role\"] == \"user\" %}\n {%- if tools is not none and (message == user_messages[-1]) %}\n {{- \"[AVAILABLE_TOOLS][\" }}\n {%- for tool in tools %}\n {%- set tool = tool.function %}\n {{- '{\"type\": \"function\", \"function\": {' }}\n {%- for key, val in tool.items() if key != \"return\" %}\n {%- if val is string %}\n {{- '\"' + key + '\": \"' + val + '\"' }}\n {%- else %}\n {{- '\"' + key + '\": ' + val|tojson }}\n {%- endif %}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- endif %}\n {%- endfor %}\n {{- \"}}\" }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- else %}\n {{- \"]\" }}\n {%- endif %}\n {%- endfor %}\n {{- \"[/AVAILABLE_TOOLS]\" }}\n {%- endif %}\n {%- if loop.last and system_message is defined %}\n {{- \"[INST]\" + system_message + \"\\n\\n\" + message[\"content\"] + \"[/INST]\" }}\n {%- else %}\n {{- \"[INST]\" + message[\"content\"] + \"[/INST]\" }}\n {%- endif %}\n {%- elif (message.tool_calls is defined and message.tool_calls is not none) %}\n {{- \"[TOOL_CALLS][\" }}\n {%- for tool_call in message.tool_calls %}\n {%- set out = tool_call.function|tojson %}\n {{- out[:-1] }}\n {%- if not tool_call.id is defined or tool_call.id|length != 9 %}\n {{- raise_exception(\"Tool call IDs should be alphanumeric strings with length 9!\") }}\n {%- endif %}\n {{- ', \"id\": \"' + tool_call.id + '\"}' }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- else %}\n {{- \"]\" + eos_token }}\n {%- endif %}\n {%- endfor %}\n {%- elif message[\"role\"] == \"assistant\" %}\n {{- message[\"content\"] + eos_token}}\n {%- elif message[\"role\"] == \"tool_results\" or message[\"role\"] == \"tool\" %}\n {%- if message.content is defined and message.content.content is defined %}\n {%- set content = message.content.content %}\n {%- else %}\n {%- set content = message.content %}\n {%- endif %}\n {{- '[TOOL_RESULTS]{\"content\": ' + content|string + \", \" }}\n {%- if not message.tool_call_id is defined or message.tool_call_id|length != 9 %}\n {{- raise_exception(\"Tool call IDs should be alphanumeric strings with length 9!\") }}\n {%- endif %}\n {{- '\"call_id\": \"' + message.tool_call_id + '\"}[/TOOL_RESULTS]' }}\n {%- else %}\n {{- raise_exception(\"Only user and assistant roles are supported, with the exception of an initial optional system message!\") }}\n {%- endif %}\n{%- endfor %}\n", /* .expected_output= */ "[INST]You are a helpful assistant\n\nHello[/INST]Hi there[INST]Who are you[/INST] I am an assistant [INST]Another question[/INST]", - /* .expected_output_jinja= */ "[INST]Hello[/INST]Hi there[INST]Who are you[/INST] I am an assistant [INST]You are a helpful assistant\n\nAnother question[/INST]", + /* .expected_output_jinja= */ "[INST]Hello[/INST]Hi there[INST]Who are you[/INST] I am an assistant [AVAILABLE_TOOLS][[/AVAILABLE_TOOLS][INST]You are a helpful assistant\n\nAnother question[/INST]", /* .bos_token= */ "", /* .eos_token= */ "", }, diff --git a/tests/test-jinja.cpp b/tests/test-jinja.cpp index 54d3a0923b..7c6eeb311c 100644 --- a/tests/test-jinja.cpp +++ b/tests/test-jinja.cpp @@ -9,6 +9,7 @@ #include "jinja/runtime.h" #include "jinja/parser.h" #include "jinja/lexer.h" +#include "jinja/utils.h" #include "testing.h" @@ -30,6 +31,7 @@ static void test_tests(testing & t); static void test_string_methods(testing & t); static void test_array_methods(testing & t); static void test_object_methods(testing & t); +static void test_hasher(testing & t); static void test_fuzzing(testing & t); static bool g_python_mode = false; @@ -67,6 +69,7 @@ int main(int argc, char *argv[]) { t.test("array methods", test_array_methods); t.test("object methods", test_object_methods); if (!g_python_mode) { + t.test("hasher", test_hasher); t.test("fuzzing", test_fuzzing); } @@ -156,6 +159,18 @@ static void test_conditionals(testing & t) { "big" ); + test_template(t, "object comparison", + "{% if {0: 1, none: 2, 1.0: 3, '0': 4, true: 5} == {false: 1, none: 2, 1: 5, '0': 4} %}equal{% endif %}", + json::object(), + "equal" + ); + + test_template(t, "array comparison", + "{% if [0, 1.0, false] == [false, 1, 0.0] %}equal{% endif %}", + json::object(), + "equal" + ); + test_template(t, "logical and", "{% if a and b %}both{% endif %}", {{"a", true}, {"b", true}}, @@ -358,6 +373,30 @@ static void test_expressions(testing & t) { "b" ); + test_template(t, "array negative access", + "{{ items[-1] }}", + {{"items", json::array({"a", "b", "c"})}}, + "c" + ); + + test_template(t, "array slice", + "{{ items[1:-1]|string }}", + {{"items", json::array({"a", "b", "c"})}}, + "['b']" + ); + + test_template(t, "array slice step", + "{{ items[::2]|string }}", + {{"items", json::array({"a", "b", "c"})}}, + "['a', 'c']" + ); + + test_template(t, "tuple slice", + "{{ ('a', 'b', 'c')[::-1]|string }}", + json::object(), + "('c', 'b', 'a')" + ); + test_template(t, "arithmetic", "{{ (a + b) * c }}", {{"a", 2}, {"b", 3}, {"c", 4}}, @@ -401,6 +440,36 @@ static void test_set_statement(testing & t) { json::object(), "1" ); + + test_template(t, "set dict with mixed type keys", + "{% set d = {0: 1, none: 2, 1.0: 3, '0': 4, (0, 0): 5, false: 6, 1: 7} %}{{ d[(0, 0)] + d[0] + d[none] + d['0'] + d[false] + d[1.0] + d[1] }}", + json::object(), + "37" + ); + + test_template(t, "print dict with mixed type keys", + "{% set d = {0: 1, none: 2, 1.0: 3, '0': 4, (0, 0): 5, true: 6} %}{{ d|string }}", + json::object(), + "{0: 1, None: 2, 1.0: 6, '0': 4, (0, 0): 5}" + ); + + test_template(t, "print array with mixed types", + "{% set d = [0, none, 1.0, '0', true, (0, 0)] %}{{ d|string }}", + json::object(), + "[0, None, 1.0, '0', True, (0, 0)]" + ); + + test_template(t, "object member assignment with mixed key types", + "{% set d = namespace() %}{% set d.a = 123 %}{{ d['a'] == 123 }}", + json::object(), + "True" + ); + + test_template(t, "tuple unpacking", + "{% set t = (1, 2, 3) %}{% set a, b, c = t %}{{ a + b + c }}", + json::object(), + "6" + ); } static void test_filters(testing & t) { @@ -1312,6 +1381,154 @@ static void test_object_methods(testing & t) { {{"obj", {{"a", "b"}}}}, "True True" ); + + test_template(t, "expression as object key", + "{% set d = {'ab': 123} %}{{ d['a' + 'b'] == 123 }}", + json::object(), + "True" + ); + + test_template(t, "numeric as object key (template: Seed-OSS)", + "{% set d = {1: 'a', 2: 'b'} %}{{ d[1] == 'a' and d[2] == 'b' }}", + json::object(), + "True" + ); +} + +static void test_hasher(testing & t) { + static const std::vector> chunk_sizes = { + {1, 2}, + {1, 16}, + {8, 1}, + {1, 1024}, + {5, 512}, + {16, 256}, + {45, 122}, + {70, 634}, + }; + + static auto random_bytes = [](size_t length) -> std::string { + std::string data; + data.resize(length); + for (size_t i = 0; i < length; ++i) { + data[i] = static_cast(rand() % 256); + } + return data; + }; + + t.test("state unchanged with empty input", [](testing & t) { + jinja::hasher hasher; + hasher.update("some data"); + size_t initial_state = hasher.digest(); + hasher.update("", 0); + size_t final_state = hasher.digest(); + t.assert_true("Hasher state should remain unchanged", initial_state == final_state); + }); + + t.test("different inputs produce different hashes", [](testing & t) { + jinja::hasher hasher1; + hasher1.update("data one"); + size_t hash1 = hasher1.digest(); + + jinja::hasher hasher2; + hasher2.update("data two"); + size_t hash2 = hasher2.digest(); + + t.assert_true("Different inputs should produce different hashes", hash1 != hash2); + }); + + t.test("same inputs produce same hashes", [](testing & t) { + jinja::hasher hasher1; + hasher1.update("consistent data"); + size_t hash1 = hasher1.digest(); + + jinja::hasher hasher2; + hasher2.update("consistent data"); + size_t hash2 = hasher2.digest(); + + t.assert_true("Same inputs should produce same hashes", hash1 == hash2); + }); + + t.test("property: update(a ~ b) == update(a).update(b)", [](testing & t) { + for (const auto & [size1, size2] : chunk_sizes) { + std::string data1 = random_bytes(size1); + std::string data2 = random_bytes(size2); + + jinja::hasher hasher1; + hasher1.update(data1); + hasher1.update(data2); + size_t hash1 = hasher1.digest(); + + jinja::hasher hasher2; + hasher2.update(data1 + data2); + size_t hash2 = hasher2.digest(); + + t.assert_true( + "Hashing in multiple updates should match single update (" + std::to_string(size1) + ", " + std::to_string(size2) + ")", + hash1 == hash2); + } + }); + + t.test("property: update(a ~ b) == update(a).update(b) with more update passes", [](testing & t) { + static const std::vector sizes = {3, 732, 131, 13, 17, 256, 436, 99, 4}; + + jinja::hasher hasher1; + jinja::hasher hasher2; + + std::string combined_data; + for (size_t size : sizes) { + std::string data = random_bytes(size); + hasher1.update(data); + combined_data += data; + } + + hasher2.update(combined_data); + size_t hash1 = hasher1.digest(); + size_t hash2 = hasher2.digest(); + t.assert_true( + "Hashing in multiple updates should match single update with many chunks", + hash1 == hash2); + }); + + t.test("property: non associativity of update", [](testing & t) { + for (const auto & [size1, size2] : chunk_sizes) { + std::string data1 = random_bytes(size1); + std::string data2 = random_bytes(size2); + + jinja::hasher hasher1; + hasher1.update(data1); + hasher1.update(data2); + size_t hash1 = hasher1.digest(); + + jinja::hasher hasher2; + hasher2.update(data2); + hasher2.update(data1); + size_t hash2 = hasher2.digest(); + + t.assert_true( + "Hashing order should matter (" + std::to_string(size1) + ", " + std::to_string(size2) + ")", + hash1 != hash2); + } + }); + + t.test("property: different lengths produce different hashes (padding block size)", [](testing & t) { + std::string random_data = random_bytes(64); + + jinja::hasher hasher1; + hasher1.update(random_data); + size_t hash1 = hasher1.digest(); + + for (int i = 0; i < 16; ++i) { + random_data.push_back('A'); // change length + jinja::hasher hasher2; + hasher2.update(random_data); + size_t hash2 = hasher2.digest(); + + t.assert_true("Different lengths should produce different hashes (length " + std::to_string(random_data.size()) + ")", hash1 != hash2); + + hash1 = hash2; + } + }); } static void test_template_cpp(testing & t, const std::string & name, const std::string & tmpl, const json & vars, const std::string & expect) {