From 4331e9c8e979bedff396f4a4e5764fa50df8df92 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Sun, 28 Dec 2025 17:23:29 +0100 Subject: [PATCH] keyword arguments and slicing array --- common/jinja/jinja-value.h | 29 +++++----- common/jinja/jinja-vm-builtins.cpp | 85 +++++++++++++++++++++++++++++ common/jinja/jinja-vm.cpp | 86 +++++++++++++++++++++--------- common/jinja/jinja-vm.h | 34 +++++++----- tests/test-chat-jinja.cpp | 2 +- 5 files changed, 184 insertions(+), 52 deletions(-) diff --git a/common/jinja/jinja-value.h b/common/jinja/jinja-value.h index 787cec46b3..2bb600c1b9 100644 --- a/common/jinja/jinja-value.h +++ b/common/jinja/jinja-value.h @@ -55,6 +55,7 @@ struct func_args { throw std::runtime_error("Expected " + std::to_string(count) + " arguments, got " + std::to_string(args.size())); } } + // TODO: add support for get kwargs // utility functions template void ensure_vals() const { ensure_count(1); @@ -187,19 +188,6 @@ struct value_array_t : public value_t { // point to the same underlying data val_arr = v->val_arr; } - value_array_t(value_array_t & other, size_t start = 0, size_t end = -1) { - val_arr = std::make_shared>(); - size_t sz = other.val_arr->size(); - if (end == static_cast(-1) || end > sz) { - end = sz; - } - if (start > end || start >= sz) { - return; - } - for (size_t i = start; i < end; i++) { - val_arr->push_back(other.val_arr->at(i)->clone()); - } - } void push_back(const value & val) { val_arr->push_back(val->clone()); } @@ -319,6 +307,21 @@ struct value_undefined_t : public value_t { }; using value_undefined = std::unique_ptr; +// special value for kwarg +struct value_kwarg_t : public value_t { + std::string key; + value val; + value_kwarg_t(const value_kwarg_t & other) { + key = other.key; + val = other.val->clone(); + } + value_kwarg_t(const std::string & k, const value & v) : key(k), val(v->clone()) {} + virtual std::string type() const override { return "KwArg"; } + virtual std::string as_repr() const override { return type(); } + virtual value clone() const override { return std::make_unique(*this); } +}; +using value_kwarg = std::unique_ptr; + const func_builtins & global_builtins(); diff --git a/common/jinja/jinja-vm-builtins.cpp b/common/jinja/jinja-vm-builtins.cpp index 160001e522..feb7ffb5d2 100644 --- a/common/jinja/jinja-vm-builtins.cpp +++ b/common/jinja/jinja-vm-builtins.cpp @@ -5,9 +5,62 @@ #include #include +#include +#include +#include namespace jinja { +/** + * Function that mimics Python's array slicing. + */ +template +static T slice(const T & array, std::optional start = std::nullopt, std::optional stop = std::nullopt, int64_t step = 1) { + int64_t len = static_cast(array.size()); + int64_t direction = (step > 0) ? 1 : ((step < 0) ? -1 : 0); + int64_t start_val; + int64_t stop_val; + if (direction >= 0) { + start_val = start.value_or(0); + if (start_val < 0) { + start_val = std::max(len + start_val, (int64_t)0); + } else { + start_val = std::min(start_val, len); + } + + stop_val = stop.value_or(len); + if (stop_val < 0) { + stop_val = std::max(len + stop_val, (int64_t)0); + } else { + stop_val = std::min(stop_val, len); + } + } else { + start_val = start.value_or(len - 1); + if (start_val < 0) { + start_val = std::max(len + start_val, (int64_t)-1); + } else { + start_val = std::min(start_val, len - 1); + } + + stop_val = stop.value_or(-1); + if (stop_val < -1) { + stop_val = std::max(len + stop_val, (int64_t)-1); + } else { + stop_val = std::min(stop_val, len - 1); + } + } + T result; + if (direction == 0) { + return result; + } + for (int64_t i = start_val; direction * i < direction * stop_val; i += step) { + if (i >= 0 && i < len) { + result.push_back(std::move(array[static_cast(i)]->clone())); + } + } + return result; +} + template static value test_type_fn(const func_args & args) { args.ensure_count(1); @@ -28,6 +81,17 @@ const func_builtins & global_builtins() { std::string msg = args.args[0]->as_string().str(); throw raised_exception("Jinja Exception: " + msg); }}, + {"namespace", [](const func_args & args) -> value { + auto out = mk_val(); + for (const auto & arg : args.args) { + if (!is_val(arg)) { + throw raised_exception("namespace() arguments must be kwargs"); + } + auto kwarg = dynamic_cast(arg.get()); + out->insert(kwarg->key, kwarg->val); + } + return out; + }}, // tests {"test_is_boolean", test_type_fn}, @@ -126,6 +190,8 @@ const func_builtins & value_float_t::get_builtins() const { // return str.substr(start, end - start); // } + + static bool string_startswith(const std::string & str, const std::string & prefix) { if (str.length() < prefix.length()) return false; return str.compare(0, prefix.length(), prefix) == 0; @@ -250,6 +316,9 @@ const func_builtins & value_string_t::get_builtins() const { {"join", [](const func_args &) -> value { throw std::runtime_error("join builtin not implemented"); }}, + {"slice", [](const func_args &) -> value { + throw std::runtime_error("slice builtin not implemented"); + }}, }; return builtins; } @@ -309,6 +378,22 @@ const func_builtins & value_array_t::get_builtins() const { const auto & arr = args.args[0]->as_array(); return mk_val(static_cast(arr.size())); }}, + {"slice", [](const func_args & args) -> value { + args.ensure_count(4); + int64_t start = is_val(args.args[1]) ? args.args[1]->as_int() : 0; + int64_t stop = is_val(args.args[2]) ? args.args[2]->as_int() : -1; + int64_t step = is_val(args.args[3]) ? args.args[3]->as_int() : 1; + if (!is_val(args.args[0])) { + throw raised_exception("slice() first argument must be an array"); + } + if (step == 0) { + throw raised_exception("slice step cannot be zero"); + } + auto arr = slice(args.args[0]->as_array(), start, stop, step); + auto res = mk_val(); + res->val_arr = std::make_shared>(std::move(arr)); + return res; + }}, // TODO: reverse, sort, join, string, unique }; return builtins; diff --git a/common/jinja/jinja-vm.cpp b/common/jinja/jinja-vm.cpp index bd4d53bded..f39321fa00 100644 --- a/common/jinja/jinja-vm.cpp +++ b/common/jinja/jinja-vm.cpp @@ -35,7 +35,7 @@ value identifier::execute(context & ctx) { return it->second->clone(); } else if (builtins.find(val) != builtins.end()) { JJ_DEBUG("Identifier '%s' found in builtins", val.c_str()); - return mk_val(builtins.at(val)); + return mk_val(builtins.at(val), val); } else { JJ_DEBUG("Identifier '%s' not found, returning undefined", val.c_str()); return mk_val(); @@ -168,13 +168,16 @@ value binary_expression::execute(context & ctx) { throw std::runtime_error("Unknown operator \"" + op.value + "\" between " + left_val->type() + " and " + right_val->type()); } -static value try_builtin_func(const std::string & name, const value & input) { +static value try_builtin_func(const std::string & name, const value & input, bool undef_on_missing = true) { auto builtins = input->get_builtins(); auto it = builtins.find(name); if (it != builtins.end()) { JJ_DEBUG("Binding built-in '%s'", name.c_str()); return mk_val(it->second, input, name); } + if (undef_on_missing) { + return mk_val(); + } throw std::runtime_error("Unknown (built-in) filter '" + name + "' for type " + input->type()); } @@ -189,12 +192,11 @@ value filter_expression::execute(context & ctx) { throw std::runtime_error("to_json filter not implemented"); } - auto str = input->as_string(); if (filter_val == "trim") { filter_val = "strip"; // alias } JJ_DEBUG("Applying filter '%s' to %s", filter_val.c_str(), input->type().c_str()); - return try_builtin_func(filter_val, input); + return try_builtin_func(filter_val, input)->invoke({}); } else if (is_stmt(filter)) { // TODO @@ -385,7 +387,7 @@ value set_statement::execute(context & ctx) { if (is_stmt(assignee)) { auto var_name = dynamic_cast(assignee.get())->val; - JJ_DEBUG("Setting variable '%s'", var_name.c_str()); + JJ_DEBUG("Setting variable '%s' with value type %s", var_name.c_str(), rhs->type().c_str()); ctx.var[var_name] = rhs->clone(); } else if (is_stmt(assignee)) { @@ -408,10 +410,6 @@ value set_statement::execute(context & ctx) { } else if (is_stmt(assignee)) { auto member = dynamic_cast(assignee.get()); - value object = member->object->execute(ctx); - if (!is_val(object)) { - throw std::runtime_error("Cannot assign to member of non-object"); - } if (member->computed) { throw std::runtime_error("Cannot assign to computed member"); } @@ -419,9 +417,14 @@ value set_statement::execute(context & ctx) { throw std::runtime_error("Cannot assign to member with non-identifier property"); } auto prop_name = dynamic_cast(member->property.get())->val; - auto obj_ptr = dynamic_cast(object.get()); + + value object = member->object->execute(ctx); + if (!is_val(object)) { + throw std::runtime_error("Cannot assign to member of non-object"); + } + auto obj_ptr = dynamic_cast(object.get()); JJ_DEBUG("Setting object property '%s'", prop_name.c_str()); - obj_ptr->get()->insert(prop_name, rhs->clone()); + obj_ptr->insert(prop_name, rhs->clone()); } else { throw std::runtime_error("Invalid LHS inside assignment expression: " + assignee->type()); @@ -462,7 +465,26 @@ value member_expression::execute(context & ctx) { value property; if (this->computed) { JJ_DEBUG("Member expression, computing property type %s", this->property->type().c_str()); - property = this->property->execute(ctx); + if (is_stmt(this->property)) { + auto s = dynamic_cast(this->property.get()); + value start_val = s->start_expr ? s->start_expr->execute(ctx) : mk_val(); + value stop_val = s->stop_expr ? s->stop_expr->execute(ctx) : mk_val(); + value step_val = s->step_expr ? s->step_expr->execute(ctx) : mk_val(); + + // translate to function call: obj.slice(start, stop, step) + JJ_DEBUG("Member expression is a slice: start %s, stop %s, step %s", + start_val->as_repr().c_str(), + stop_val->as_repr().c_str(), + step_val->as_repr().c_str()); + auto slice_func = try_builtin_func("slice", object); + func_args args; + args.args.push_back(start_val->clone()); + args.args.push_back(stop_val->clone()); + args.args.push_back(step_val->clone()); + return slice_func->invoke(args); + } else { + property = this->property->execute(ctx); + } } else { property = mk_val(dynamic_cast(this->property.get())->val); } @@ -482,7 +504,7 @@ value member_expression::execute(context & ctx) { if (it != obj.end()) { val = it->second->clone(); } else { - val = try_builtin_func(key, object); + val = try_builtin_func(key, object, true); } } else if (is_val(object) || is_val(object)) { @@ -519,22 +541,22 @@ value member_expression::execute(context & ctx) { return val; } -static func_args gather_call_args(const statements & arg_stmts, context & ctx) { - func_args args; - for (auto & arg_stmt : arg_stmts) { - args.args.push_back(arg_stmt->execute(ctx)); - } - return args; -} - value call_expression::execute(context & ctx) { - auto args = gather_call_args(this->args, ctx); + // gather arguments + func_args args; + for (auto & arg_stmt : this->args) { + auto arg_val = arg_stmt->execute(ctx); + JJ_DEBUG(" Argument type: %s", arg_val->type().c_str()); + args.args.push_back(std::move(arg_val)); + } + // execute callee value callee_val = callee->execute(ctx); - JJ_DEBUG("Calling function of type %s with %zu arguments", callee_val->type().c_str(), args.args.size()); - if (!is_val(callee_val)) { + if (!is_val(callee_val)) { throw std::runtime_error("Callee is not a function: got " + callee_val->type()); } - return callee_val->invoke(args); + auto * callee_func = dynamic_cast(callee_val.get()); + JJ_DEBUG("Calling function '%s' with %zu arguments", callee_func->name.c_str(), args.args.size()); + return callee_func->invoke(args); } // compare operator for value_t @@ -570,4 +592,18 @@ bool value_compare(const value & a, const value & b) { return false; } +value keyword_argument_expression::execute(context & ctx) { + if (!is_stmt(key)) { + throw std::runtime_error("Keyword argument key must be identifiers"); + } + + std::string k = dynamic_cast(key.get())->val; + JJ_DEBUG("Keyword argument expression key: %s, value: %s", k.c_str(), val->type().c_str()); + + value v = val->execute(ctx); + JJ_DEBUG("Keyword argument value executed, type: %s", v->type().c_str()); + + return mk_val(k, v); +} + } // namespace jinja diff --git a/common/jinja/jinja-vm.h b/common/jinja/jinja-vm.h index 786d49bad1..a931bc1ea8 100644 --- a/common/jinja/jinja-vm.h +++ b/common/jinja/jinja-vm.h @@ -15,7 +15,11 @@ namespace jinja { struct context { std::map var; - context() = default; + context() { + var["true"] = mk_val(true); + var["false"] = mk_val(false); + var["none"] = mk_val(); + } ~context() = default; context(const context & parent) { @@ -375,29 +379,33 @@ struct unary_expression : public expression { }; struct slice_expression : public expression { - statement_ptr start; - statement_ptr stop; - statement_ptr step; + statement_ptr start_expr; + statement_ptr stop_expr; + statement_ptr step_expr; - slice_expression(statement_ptr && start, statement_ptr && stop, statement_ptr && step) - : start(std::move(start)), stop(std::move(stop)), step(std::move(step)) { - chk_type(this->start); - chk_type(this->stop); - chk_type(this->step); + slice_expression(statement_ptr && start_expr, statement_ptr && stop_expr, statement_ptr && step_expr) + : start_expr(std::move(start_expr)), stop_expr(std::move(stop_expr)), step_expr(std::move(step_expr)) { + chk_type(this->start_expr); + chk_type(this->stop_expr); + chk_type(this->step_expr); } std::string type() const override { return "SliceExpression"; } + value execute(context &) override { + throw std::runtime_error("must be handled by MemberExpression"); + } }; struct keyword_argument_expression : public expression { statement_ptr key; - statement_ptr value; + statement_ptr val; - keyword_argument_expression(statement_ptr && key, statement_ptr && value) - : key(std::move(key)), value(std::move(value)) { + keyword_argument_expression(statement_ptr && key, statement_ptr && val) + : key(std::move(key)), val(std::move(val)) { chk_type(this->key); - chk_type(this->value); + chk_type(this->val); } std::string type() const override { return "KeywordArgumentExpression"; } + value execute(context & ctx) override; }; struct spread_expression : public expression { diff --git a/tests/test-chat-jinja.cpp b/tests/test-chat-jinja.cpp index 87ac00fca1..ce17df5b1d 100644 --- a/tests/test-chat-jinja.cpp +++ b/tests/test-chat-jinja.cpp @@ -18,7 +18,7 @@ int main(void) { //std::string contents = " {{ messages[0]['content'] }} "; - std::ifstream infile("models/templates/moonshotai-Kimi-K2.jinja"); + std::ifstream infile("models/templates/Qwen-Qwen3-0.6B.jinja"); std::string contents((std::istreambuf_iterator(infile)), std::istreambuf_iterator()); std::cout << "=== INPUT ===\n" << contents << "\n\n";