From cbb37dd4cda2891cdf61367546cc98d1875f29fe Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Wed, 31 Dec 2025 11:29:40 +0100 Subject: [PATCH] improve function args handling --- common/jinja/jinja-value.cpp | 7 +- common/jinja/jinja-value.h | 137 ++++++++++++++++++----------------- common/jinja/jinja-vm.cpp | 6 +- tests/test-chat-jinja.cpp | 4 +- 4 files changed, 78 insertions(+), 76 deletions(-) diff --git a/common/jinja/jinja-value.cpp b/common/jinja/jinja-value.cpp index 6c3d9249b3..270caafede 100644 --- a/common/jinja/jinja-value.cpp +++ b/common/jinja/jinja-value.cpp @@ -115,7 +115,6 @@ const func_builtins & global_builtins() { return out; }}, {"strftime_now", [](const func_args & args) -> value { - args.ensure_count(1); args.ensure_vals(); std::string format = args.args[0]->as_string().str(); // get current time @@ -128,9 +127,9 @@ const func_builtins & global_builtins() { } }}, {"range", [](const func_args & args) -> value { - if (args.args.size() < 1 || args.args.size() > 3) { - throw raised_exception("slice() takes between 1 and 3 arguments"); - } + args.ensure_count(1, 3); + args.ensure_vals(true, false, false); + auto & arg0 = args.args[0]; auto & arg1 = args.args[1]; auto & arg2 = args.args[2]; diff --git a/common/jinja/jinja-value.h b/common/jinja/jinja-value.h index 77d30c82f7..6be5160a89 100644 --- a/common/jinja/jinja-value.h +++ b/common/jinja/jinja-value.h @@ -51,12 +51,6 @@ typename extract_pointee::type * cast_val(value & ptr) { using PointeeType = typename extract_pointee::type; return dynamic_cast(ptr.get()); } -template -void ensure_val(const value & ptr) { - if (!is_val(ptr)) { - throw std::runtime_error("Expected value of type " + std::string(typeid(T).name())); - } -} // End Helper @@ -92,36 +86,11 @@ struct context; // forward declaration template void global_from_json(context & ctx, const T_JSON & json_obj); +// +// base value type +// - -struct func_args { - std::vector args; - context & ctx; - func_args(context & ctx) : ctx(ctx) {} - void ensure_count(size_t min, size_t max = 999) const { - if (args.size() < min || args.size() > max) { - throw std::runtime_error("Expected between " + std::to_string(min) + " and " + std::to_string(max) + " arguments, got " + std::to_string(args.size())); - } - } - value get_kwarg(const std::string & key) const; - // utility functions - // TODO: allow optional arguments - template void ensure_vals() const { - ensure_count(1); - ensure_val(args[0]); - } - template void ensure_vals() const { - ensure_count(2); - ensure_val(args[0]); - ensure_val(args[1]); - } - template void ensure_vals() const { - ensure_count(3); - ensure_val(args[0]); - ensure_val(args[1]); - ensure_val(args[2]); - } -}; +struct func_args; // function argument values using func_handler = std::function; using func_builtins = std::map; @@ -165,6 +134,9 @@ struct value_t { virtual std::string as_repr() const { return as_string().str(); } }; +// +// primitive value types +// struct value_int_t : public value_t { value_int_t(int64_t v) { val_int = v; } @@ -275,36 +247,9 @@ struct value_object_t : public value_t { }; using value_object = std::shared_ptr; - -struct value_func_t : public value_t { - std::string name; // for debugging - value arg0; // bound "this" argument, if any - value_func_t(const func_handler & func, std::string func_name = "") { - val_func = func; - name = func_name; - } - value_func_t(const func_handler & func, const value & arg_this, std::string func_name = "") { - val_func = func; - name = func_name; - arg0 = arg_this; - } - virtual value invoke(const func_args & args) const override { - if (arg0) { - func_args new_args(args.ctx); - new_args.args.push_back(arg0); - for (const auto & a : args.args) { - new_args.args.push_back(a); - } - return val_func(new_args); - } else { - return val_func(args); - } - } - virtual std::string type() const override { return "Function"; } - virtual std::string as_repr() const override { return type(); } -}; -using value_func = std::shared_ptr; - +// +// null and undefined types +// struct value_null_t : public value_t { virtual std::string type() const override { return "Null"; } @@ -326,6 +271,63 @@ struct value_undefined_t : public value_t { }; using value_undefined = std::shared_ptr; +// +// function type +// + +struct func_args { + std::string func_name; // for error messages + std::vector args; + context & ctx; + func_args(context & ctx) : ctx(ctx) {} + value get_kwarg(const std::string & key) const; + void ensure_count(size_t min, size_t max = 999) const { + size_t n = args.size(); + if (n < min || n > max) { + throw std::runtime_error("Function '" + func_name + "' expected between " + std::to_string(min) + " and " + std::to_string(max) + " arguments, got " + std::to_string(n)); + } + } + template void ensure_val(const value & ptr) const { + if (!is_val(ptr)) { + throw std::runtime_error("Function '" + func_name + "' expected value of type " + std::string(typeid(T).name()) + ", got " + ptr->type()); + } + } + template void ensure_vals(bool required0 = true) const { + if (required0 && args.size() > 0) ensure_val(args[0]); + } + template void ensure_vals(bool required0 = true, bool required1 = true) const { + if (required0 && args.size() > 0) ensure_val(args[0]); + if (required1 && args.size() > 1) ensure_val(args[1]); + } + template void ensure_vals(bool required0 = true, bool required1 = true, bool required2 = true) const { + if (required0 && args.size() > 0) ensure_val(args[0]); + if (required1 && args.size() > 1) ensure_val(args[1]); + if (required2 && args.size() > 2) ensure_val(args[2]); + } +}; + +struct value_func_t : public value_t { + std::string name; + value arg0; // bound "this" argument, if any + value_func_t(const std::string & name, const func_handler & func) : name(name) { + val_func = func; + } + value_func_t(const std::string & name, const func_handler & func, const value & arg_this) : name(name), arg0(arg_this) { + val_func = func; + } + virtual value invoke(const func_args & args) const override { + func_args new_args(args); // copy + new_args.func_name = name; + if (arg0) { + new_args.args.insert(new_args.args.begin(), arg0); + } + return val_func(new_args); + } + virtual std::string type() const override { return "Function"; } + virtual std::string as_repr() const override { return type(); } +}; +using value_func = std::shared_ptr; + // special value for kwarg struct value_kwarg_t : public value_t { std::string key; @@ -337,11 +339,10 @@ struct value_kwarg_t : public value_t { using value_kwarg = std::shared_ptr; -const func_builtins & global_builtins(); - - // utils +const func_builtins & global_builtins(); + static inferred_type value_to_inferred_type(const value & val) { if (is_val(val) || is_val(val)) { return inferred_type::numeric; diff --git a/common/jinja/jinja-vm.cpp b/common/jinja/jinja-vm.cpp index d6958a54c9..89dd49ed0a 100644 --- a/common/jinja/jinja-vm.cpp +++ b/common/jinja/jinja-vm.cpp @@ -70,7 +70,7 @@ value identifier::execute_impl(context & ctx) { return it; } else if (builtins.find(val) != builtins.end()) { JJ_DEBUG("Identifier '%s' found in builtins", val.c_str()); - return mk_val(builtins.at(val), val); + return mk_val(val, builtins.at(val)); } else { JJ_DEBUG("Identifier '%s' not found, returning undefined", val.c_str()); return mk_val(val); @@ -243,7 +243,7 @@ static value try_builtin_func(const std::string & name, const value & input, boo auto it = builtins.find(name); if (it != builtins.end()) { JJ_DEBUG("Binding built-in '%s'", name.c_str()); - return mk_val(it->second, input, name); + return mk_val(name, it->second, input); } if (undef_on_missing) { return mk_val(name); @@ -607,7 +607,7 @@ value macro_statement::execute_impl(context & ctx) { }; JJ_DEBUG("Defining macro '%s' with %zu parameters", name.c_str(), args.size()); - ctx.set_val(name, mk_val(func)); + ctx.set_val(name, mk_val(name, func)); return mk_val(); } diff --git a/tests/test-chat-jinja.cpp b/tests/test-chat-jinja.cpp index c205b150cf..b6a9a4a766 100644 --- a/tests/test-chat-jinja.cpp +++ b/tests/test-chat-jinja.cpp @@ -26,7 +26,9 @@ int main(void) { //std::string contents = " {{ messages[a]['content'] }} "; //std::string contents = "{% if a is not defined %}hello{% endif %}"; - std::ifstream infile("models/templates/Qwen-Qwen3-0.6B.jinja"); std::string contents((std::istreambuf_iterator(infile)), std::istreambuf_iterator()); + std::ifstream infile("models/templates/Qwen-Qwen3-0.6B.jinja"); + //std::ifstream infile("models/templates/Kimi-K2-Thinking.jinja"); + std::string contents((std::istreambuf_iterator(infile)), std::istreambuf_iterator()); run_single(contents);