From 45c194622efbd32660cc4fdf83ac8c32dcd20c3c Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Sun, 28 Dec 2025 15:33:14 +0100 Subject: [PATCH] support binded functions --- common/jinja/jinja-string.h | 36 +++++++ common/jinja/jinja-value.h | 36 ++++++- common/jinja/jinja-vm-builtins.cpp | 58 ++++++++++- common/jinja/jinja-vm.cpp | 151 +++++++++++++++++------------ common/jinja/jinja-vm.h | 33 ++++++- tests/test-chat-jinja.cpp | 16 +-- 6 files changed, 254 insertions(+), 76 deletions(-) diff --git a/common/jinja/jinja-string.h b/common/jinja/jinja-string.h index fb3371271f..d26bb1e20c 100644 --- a/common/jinja/jinja-string.h +++ b/common/jinja/jinja-string.h @@ -16,6 +16,24 @@ namespace jinja { struct string_part { bool is_input = false; // may skip parsing special tokens if true std::string val; + + bool is_uppercase() const { + for (char c : val) { + if (std::islower(static_cast(c))) { + return false; + } + } + return true; + } + + bool is_lowercase() const { + for (char c : val) { + if (std::isupper(static_cast(c))) { + return false; + } + } + return true; + } }; struct string { @@ -67,6 +85,24 @@ struct string { return true; } + bool is_uppercase() const { + for (const auto & part : parts) { + if (!part.is_uppercase()) { + return false; + } + } + return true; + } + + bool is_lowercase() const { + for (const auto & part : parts) { + if (!part.is_lowercase()) { + return false; + } + } + return true; + } + // mark this string as input if other has ALL parts as input void mark_input_based_on(const string & other) { if (other.all_parts_are_input()) { diff --git a/common/jinja/jinja-value.h b/common/jinja/jinja-value.h index 74366de9ba..787cec46b3 100644 --- a/common/jinja/jinja-value.h +++ b/common/jinja/jinja-value.h @@ -107,7 +107,7 @@ struct value_t { virtual bool as_bool() const { throw std::runtime_error(type() + " is not a bool value"); } virtual const std::vector & as_array() const { throw std::runtime_error(type() + " is not an array value"); } virtual const std::map & as_object() const { throw std::runtime_error(type() + " is not an object value"); } - virtual value invoke(const func_args &) const { throw std::runtime_error("Not a function value"); } + virtual value invoke(const func_args &) const { throw std::runtime_error(type() + " is not a function value"); } virtual bool is_null() const { return false; } virtual bool is_undefined() const { return false; } virtual const func_builtins & get_builtins() const { @@ -221,6 +221,9 @@ struct value_array_t : public value_t { ss << "]"; return ss.str(); } + virtual bool as_bool() const override { + return !val_arr->empty(); + } virtual const func_builtins & get_builtins() const override; }; using value_array = std::unique_ptr; @@ -251,17 +254,44 @@ struct value_object_t : public value_t { tmp->val_obj = this->val_obj; return tmp; } + virtual bool as_bool() const override { + return !val_obj->empty(); + } virtual const func_builtins & get_builtins() const override; }; using value_object = std::unique_ptr; struct value_func_t : public value_t { - value_func_t(func_handler & func) { + std::string name; // for debugging + value arg0; // bound "this" argument, if any + value_func_t(const value_func_t & other) { + val_func = other.val_func; + name = other.name; + if (other.arg0) { + arg0 = other.arg0->clone(); + } + } + value_func_t(const func_handler & func, std::string func_name = "") { val_func = func; + name = func_name; + } + value_func_t(const func_handler & func, const value & arg_this, std::string func_name = "") { + val_func = func; + name = func_name; + arg0 = arg_this->clone(); } virtual value invoke(const func_args & args) const override { - return val_func(args); + if (arg0) { + func_args new_args; + new_args.args.push_back(arg0->clone()); + for (const auto & a : args.args) { + new_args.args.push_back(a->clone()); + } + return val_func(new_args); + } else { + return val_func(args); + } } virtual std::string type() const override { return "Function"; } virtual std::string as_repr() const override { return type(); } diff --git a/common/jinja/jinja-vm-builtins.cpp b/common/jinja/jinja-vm-builtins.cpp index e8c8eee993..160001e522 100644 --- a/common/jinja/jinja-vm-builtins.cpp +++ b/common/jinja/jinja-vm-builtins.cpp @@ -8,13 +8,69 @@ namespace jinja { +template +static value test_type_fn(const func_args & args) { + args.ensure_count(1); + bool is_type = is_val(args.args[0]); + return mk_val(is_type); +} +template +static value test_type_fn(const func_args & args) { + args.ensure_count(1); + bool is_type = is_val(args.args[0]) || is_val(args.args[0]); + return mk_val(is_type); +} + const func_builtins & global_builtins() { static const func_builtins builtins = { {"raise_exception", [](const func_args & args) -> value { - args.ensure_count(1); + args.ensure_vals(); std::string msg = args.args[0]->as_string().str(); throw raised_exception("Jinja Exception: " + msg); }}, + + // tests + {"test_is_boolean", test_type_fn}, + {"test_is_callable", test_type_fn}, + {"test_is_odd", [](const func_args & args) -> value { + args.ensure_vals(); + int64_t val = args.args[0]->as_int(); + return mk_val(val % 2 != 0); + }}, + {"test_is_even", [](const func_args & args) -> value { + args.ensure_vals(); + int64_t val = args.args[0]->as_int(); + return mk_val(val % 2 == 0); + }}, + {"test_is_false", [](const func_args & args) -> value { + args.ensure_count(1); + bool val = is_val(args.args[0]) && !args.args[0]->as_bool(); + return mk_val(val); + }}, + {"test_is_true", [](const func_args & args) -> value { + args.ensure_count(1); + bool val = is_val(args.args[0]) && args.args[0]->as_bool(); + return mk_val(val); + }}, + {"test_is_string", test_type_fn}, + {"test_is_integer", test_type_fn}, + {"test_is_number", test_type_fn}, + {"test_is_iterable", test_type_fn}, + {"test_is_mapping", test_type_fn}, + {"test_is_lower", [](const func_args & args) -> value { + args.ensure_vals(); + return mk_val(args.args[0]->val_str.is_lowercase()); + }}, + {"test_is_upper", [](const func_args & args) -> value { + args.ensure_vals(); + return mk_val(args.args[0]->val_str.is_uppercase()); + }}, + {"test_is_none", test_type_fn}, + {"test_is_defined", [](const func_args & args) -> value { + args.ensure_count(1); + return mk_val(!is_val(args.args[0])); + }}, + {"test_is_undefined", test_type_fn}, }; return builtins; } diff --git a/common/jinja/jinja-vm.cpp b/common/jinja/jinja-vm.cpp index c6861eeb39..bd4d53bded 100644 --- a/common/jinja/jinja-vm.cpp +++ b/common/jinja/jinja-vm.cpp @@ -8,7 +8,7 @@ #include #include -#define JJ_DEBUG(msg, ...) printf("jinja-vm: " msg "\n", __VA_ARGS__) +#define JJ_DEBUG(msg, ...) printf("jinja-vm:%3d : " msg "\n", __LINE__, __VA_ARGS__) //#define JJ_DEBUG(msg, ...) // no-op namespace jinja { @@ -44,7 +44,7 @@ value identifier::execute(context & ctx) { value binary_expression::execute(context & ctx) { value left_val = left->execute(ctx); - JJ_DEBUG("Executing binary expression with operator '%s'", op.value.c_str()); + JJ_DEBUG("Executing binary expression %s '%s' %s", left_val->type().c_str(), op.value.c_str(), right->type().c_str()); // Logical operators if (op.value == "and") { @@ -168,20 +168,19 @@ value binary_expression::execute(context & ctx) { throw std::runtime_error("Unknown operator \"" + op.value + "\" between " + left_val->type() + " and " + right_val->type()); } +static value try_builtin_func(const std::string & name, const value & input) { + auto builtins = input->get_builtins(); + auto it = builtins.find(name); + if (it != builtins.end()) { + JJ_DEBUG("Binding built-in '%s'", name.c_str()); + return mk_val(it->second, input, name); + } + throw std::runtime_error("Unknown (built-in) filter '" + name + "' for type " + input->type()); +} + value filter_expression::execute(context & ctx) { value input = operand->execute(ctx); - auto try_builtin = [&](const std::string & name) -> value { - auto builtins = input->get_builtins(); - auto it = builtins.find(name); - if (it != builtins.end()) { - func_args args; - args.args.push_back(input->clone()); - return it->second(args); - } - throw std::runtime_error("Unknown (built-in) filter '" + name + "' for type " + input->type()); - }; - if (is_stmt(filter)) { auto filter_val = dynamic_cast(filter.get())->val; @@ -190,35 +189,12 @@ value filter_expression::execute(context & ctx) { throw std::runtime_error("to_json filter not implemented"); } - if (is_val(input)) { - auto res = try_builtin(filter_val); - if (res) { - return res; - } - throw std::runtime_error("Unknown filter '" + filter_val + "' for array"); - - } else if (is_val(input)) { - auto str = input->as_string(); - auto builtins = input->get_builtins(); - if (filter_val == "trim") { - filter_val = "strip"; // alias - } - auto res = try_builtin(filter_val); - if (res) { - return res; - } - throw std::runtime_error("Unknown filter '" + filter_val + "' for string"); - - } else if (is_val(input) || is_val(input)) { - auto res = try_builtin(filter_val); - if (res) { - return res; - } - throw std::runtime_error("Unknown filter '" + filter_val + "' for number"); - - } else { - throw std::runtime_error("Filters not supported for type " + input->type()); + auto str = input->as_string(); + if (filter_val == "trim") { + filter_val = "strip"; // alias } + JJ_DEBUG("Applying filter '%s' to %s", filter_val.c_str(), input->type().c_str()); + return try_builtin_func(filter_val, input); } else if (is_stmt(filter)) { // TODO @@ -230,6 +206,44 @@ value filter_expression::execute(context & ctx) { } } +value test_expression::execute(context & ctx) { + // NOTE: "value is something" translates to function call "test_is_something(value)" + const auto & builtins = global_builtins(); + if (!is_stmt(test)) { + throw std::runtime_error("Invalid test expression"); + } + + auto test_id = dynamic_cast(test.get())->val; + auto it = builtins.find("test_is_" + test_id); + JJ_DEBUG("Test expression %s '%s'", operand->type().c_str(), test_id.c_str()); + if (it == builtins.end()) { + throw std::runtime_error("Unknown test '" + test_id + "'"); + } + + func_args args; + args.args.push_back(operand->execute(ctx)); + return it->second(args); +} + +value unary_expression::execute(context & ctx) { + value operand_val = argument->execute(ctx); + JJ_DEBUG("Executing unary expression with operator '%s'", op.value.c_str()); + + if (op.value == "not") { + return mk_val(!operand_val->as_bool()); + } else if (op.value == "-") { + if (is_val(operand_val)) { + return mk_val(-operand_val->as_int()); + } else if (is_val(operand_val)) { + return mk_val(-operand_val->as_float()); + } else { + throw std::runtime_error("Unary - operator requires numeric operand"); + } + } + + throw std::runtime_error("Unknown unary operator '" + op.value + "'"); +} + value if_statement::execute(context & ctx) { value test_val = test->execute(ctx); auto out = mk_val(); @@ -415,16 +429,46 @@ value set_statement::execute(context & ctx) { return mk_val(); } +value macro_statement::execute(context & ctx) { + std::string name = dynamic_cast(this->name.get())->val; + const func_handler func = [this, &ctx, name](const func_args & args) -> value { + JJ_DEBUG("Invoking macro '%s' with %zu arguments", name.c_str(), args.args.size()); + context macro_ctx(ctx); // new scope for macro execution + + // bind parameters + size_t param_count = this->args.size(); + size_t arg_count = args.args.size(); + for (size_t i = 0; i < param_count; ++i) { + std::string param_name = dynamic_cast(this->args[i].get())->val; + if (i < arg_count) { + macro_ctx.var[param_name] = args.args[i]->clone(); + } else { + macro_ctx.var[param_name] = mk_val(); + } + } + + // execute macro body + return exec_statements(this->body, macro_ctx); + }; + + JJ_DEBUG("Defining macro '%s' with %zu parameters", name.c_str(), args.size()); + ctx.var[name] = mk_val(func); + return mk_val(); +} + value member_expression::execute(context & ctx) { value object = this->object->execute(ctx); value property; if (this->computed) { + JJ_DEBUG("Member expression, computing property type %s", this->property->type().c_str()); property = this->property->execute(ctx); } else { property = mk_val(dynamic_cast(this->property.get())->val); } + JJ_DEBUG("Member expression on object type %s, property type %s", object->type().c_str(), property->type().c_str()); + value val = mk_val(); if (is_val(object)) { @@ -432,18 +476,13 @@ value member_expression::execute(context & ctx) { throw std::runtime_error("Cannot access object with non-string: got " + property->type()); } auto key = property->as_string().str(); + JJ_DEBUG("Accessing object property '%s'", key.c_str()); auto & obj = object->as_object(); auto it = obj.find(key); if (it != obj.end()) { val = it->second->clone(); } else { - auto builtins = object->get_builtins(); - auto bit = builtins.find(key); - if (bit != builtins.end()) { - func_args args; - args.args.push_back(object->clone()); - val = bit->second(args); - } + val = try_builtin_func(key, object); } } else if (is_val(object) || is_val(object)) { @@ -464,13 +503,7 @@ value member_expression::execute(context & ctx) { } else if (is_val(property)) { auto key = property->as_string().str(); JJ_DEBUG("Accessing %s built-in '%s'", is_val(object) ? "array" : "string", key.c_str()); - auto builtins = object->get_builtins(); - auto bit = builtins.find(key); - if (bit != builtins.end()) { - func_args args; - args.args.push_back(object->clone()); - val = bit->second(args); - } + val = try_builtin_func(key, object); } else { throw std::runtime_error("Cannot access property with non-string/non-number: got " + property->type()); } @@ -480,13 +513,7 @@ value member_expression::execute(context & ctx) { throw std::runtime_error("Cannot access property with non-string: got " + property->type()); } auto key = property->as_string().str(); - auto builtins = object->get_builtins(); - auto bit = builtins.find(key); - if (bit != builtins.end()) { - func_args args; - args.args.push_back(object->clone()); - val = bit->second(args); - } + val = try_builtin_func(key, object); } return val; diff --git a/common/jinja/jinja-vm.h b/common/jinja/jinja-vm.h index 7c431cd47e..786d49bad1 100644 --- a/common/jinja/jinja-vm.h +++ b/common/jinja/jinja-vm.h @@ -166,12 +166,16 @@ struct macro_statement : public statement { } std::string type() const override { return "Macro"; } + value execute(context & ctx) override; }; struct comment_statement : public statement { std::string val; explicit comment_statement(const std::string & v) : val(v) {} std::string type() const override { return "Comment"; } + value execute(context &) override { + return mk_val(); + } }; // Expressions @@ -339,6 +343,7 @@ struct select_expression : public expression { /** * An operation with two sides, separated by the "is" operator. + * NOTE: "value is something" translates to function call "test_is_something(value)" */ struct test_expression : public expression { statement_ptr operand; @@ -351,6 +356,7 @@ struct test_expression : public expression { chk_type(this->test); } std::string type() const override { return "TestExpression"; } + value execute(context & ctx) override; }; /** @@ -365,6 +371,7 @@ struct unary_expression : public expression { chk_type(this->argument); } std::string type() const override { return "UnaryExpression"; } + value execute(context & ctx) override; }; struct slice_expression : public expression { @@ -442,14 +449,34 @@ struct vm { context & ctx; explicit vm(context & ctx) : ctx(ctx) {} - std::vector execute(program & prog) { - std::vector results; + value_array execute(program & prog) { + value_array results = mk_val(); for (auto & stmt : prog.body) { value res = stmt->execute(ctx); - results.push_back(std::move(res)); + results->val_arr->push_back(std::move(res)); } return results; } + + std::vector gather_string_parts(const value & val) { + std::vector parts; + gather_string_parts_recursive(val, parts); + return parts; + } + + void gather_string_parts_recursive(const value & val, std::vector & parts) { + if (is_val(val)) { + const auto & str_val = dynamic_cast(val.get())->val_str; + for (const auto & part : str_val.parts) { + parts.push_back(part); + } + } else if (is_val(val)) { + auto items = dynamic_cast(val.get())->val_arr.get(); + for (const auto & item : *items) { + gather_string_parts_recursive(item, parts); + } + } + } }; } // namespace jinja diff --git a/tests/test-chat-jinja.cpp b/tests/test-chat-jinja.cpp index acbf7daf2a..87ac00fca1 100644 --- a/tests/test-chat-jinja.cpp +++ b/tests/test-chat-jinja.cpp @@ -3,6 +3,7 @@ #include #include #include +#include #undef NDEBUG #include @@ -11,12 +12,15 @@ #include "jinja/jinja-lexer.h" int main(void) { - std::string contents = "{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '' + role + '\\n' + message['content'] | trim + '\\n' }}{% endfor %}{% if add_generation_prompt %}{{'model\\n'}}{% endif %}"; + //std::string contents = "{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '' + role + '\\n' + message['content'] | trim + '\\n' }}{% endfor %}{% if add_generation_prompt %}{{'model\\n'}}{% endif %}"; //std::string contents = "{% if messages[0]['role'] != 'system' %}nice {{ messages[0]['content'] }}{% endif %}"; //std::string contents = " {{ messages[0]['content'] }} "; + std::ifstream infile("models/templates/moonshotai-Kimi-K2.jinja"); + std::string contents((std::istreambuf_iterator(infile)), std::istreambuf_iterator()); + std::cout << "=== INPUT ===\n" << contents << "\n\n"; jinja::lexer lexer; @@ -56,14 +60,12 @@ int main(void) { ctx.var["messages"] = std::move(messages); jinja::vm vm(ctx); - auto results = vm.execute(ast); + const jinja::value results = vm.execute(ast); + auto parts = vm.gather_string_parts(results); std::cout << "\n=== RESULTS ===\n"; - for (const auto & res : results) { - if (res->is_null()) { - continue; - } - std::cout << "result type: " << res->type() << " | value: " << res->as_repr(); + for (const auto & part : parts) { + std::cout << (part.is_input ? "DATA" : "TMPL") << ": " << part.val << "\n"; } return 0;