keyword arguments and slicing array

This commit is contained in:
Xuan Son Nguyen 2025-12-28 17:23:29 +01:00
parent 45c194622e
commit 4331e9c8e9
5 changed files with 184 additions and 52 deletions

View File

@ -55,6 +55,7 @@ struct func_args {
throw std::runtime_error("Expected " + std::to_string(count) + " arguments, got " + std::to_string(args.size()));
}
}
// TODO: add support for get kwargs
// utility functions
template<typename T> void ensure_vals() const {
ensure_count(1);
@ -187,19 +188,6 @@ struct value_array_t : public value_t {
// point to the same underlying data
val_arr = v->val_arr;
}
value_array_t(value_array_t & other, size_t start = 0, size_t end = -1) {
val_arr = std::make_shared<std::vector<value>>();
size_t sz = other.val_arr->size();
if (end == static_cast<size_t>(-1) || end > sz) {
end = sz;
}
if (start > end || start >= sz) {
return;
}
for (size_t i = start; i < end; i++) {
val_arr->push_back(other.val_arr->at(i)->clone());
}
}
void push_back(const value & val) {
val_arr->push_back(val->clone());
}
@ -319,6 +307,21 @@ struct value_undefined_t : public value_t {
};
using value_undefined = std::unique_ptr<value_undefined_t>;
// special value for kwarg
struct value_kwarg_t : public value_t {
std::string key;
value val;
value_kwarg_t(const value_kwarg_t & other) {
key = other.key;
val = other.val->clone();
}
value_kwarg_t(const std::string & k, const value & v) : key(k), val(v->clone()) {}
virtual std::string type() const override { return "KwArg"; }
virtual std::string as_repr() const override { return type(); }
virtual value clone() const override { return std::make_unique<value_kwarg_t>(*this); }
};
using value_kwarg = std::unique_ptr<value_kwarg_t>;
const func_builtins & global_builtins();

View File

@ -5,9 +5,62 @@
#include <string>
#include <cctype>
#include <vector>
#include <optional>
#include <algorithm>
namespace jinja {
/**
* Function that mimics Python's array slicing.
*/
template<typename T>
static T slice(const T & array, std::optional<int64_t> start = std::nullopt, std::optional<int64_t> stop = std::nullopt, int64_t step = 1) {
int64_t len = static_cast<int64_t>(array.size());
int64_t direction = (step > 0) ? 1 : ((step < 0) ? -1 : 0);
int64_t start_val;
int64_t stop_val;
if (direction >= 0) {
start_val = start.value_or(0);
if (start_val < 0) {
start_val = std::max(len + start_val, (int64_t)0);
} else {
start_val = std::min(start_val, len);
}
stop_val = stop.value_or(len);
if (stop_val < 0) {
stop_val = std::max(len + stop_val, (int64_t)0);
} else {
stop_val = std::min(stop_val, len);
}
} else {
start_val = start.value_or(len - 1);
if (start_val < 0) {
start_val = std::max(len + start_val, (int64_t)-1);
} else {
start_val = std::min(start_val, len - 1);
}
stop_val = stop.value_or(-1);
if (stop_val < -1) {
stop_val = std::max(len + stop_val, (int64_t)-1);
} else {
stop_val = std::min(stop_val, len - 1);
}
}
T result;
if (direction == 0) {
return result;
}
for (int64_t i = start_val; direction * i < direction * stop_val; i += step) {
if (i >= 0 && i < len) {
result.push_back(std::move(array[static_cast<size_t>(i)]->clone()));
}
}
return result;
}
template<typename T>
static value test_type_fn(const func_args & args) {
args.ensure_count(1);
@ -28,6 +81,17 @@ const func_builtins & global_builtins() {
std::string msg = args.args[0]->as_string().str();
throw raised_exception("Jinja Exception: " + msg);
}},
{"namespace", [](const func_args & args) -> value {
auto out = mk_val<value_object>();
for (const auto & arg : args.args) {
if (!is_val<value_kwarg>(arg)) {
throw raised_exception("namespace() arguments must be kwargs");
}
auto kwarg = dynamic_cast<value_kwarg_t*>(arg.get());
out->insert(kwarg->key, kwarg->val);
}
return out;
}},
// tests
{"test_is_boolean", test_type_fn<value_bool>},
@ -126,6 +190,8 @@ const func_builtins & value_float_t::get_builtins() const {
// return str.substr(start, end - start);
// }
static bool string_startswith(const std::string & str, const std::string & prefix) {
if (str.length() < prefix.length()) return false;
return str.compare(0, prefix.length(), prefix) == 0;
@ -250,6 +316,9 @@ const func_builtins & value_string_t::get_builtins() const {
{"join", [](const func_args &) -> value {
throw std::runtime_error("join builtin not implemented");
}},
{"slice", [](const func_args &) -> value {
throw std::runtime_error("slice builtin not implemented");
}},
};
return builtins;
}
@ -309,6 +378,22 @@ const func_builtins & value_array_t::get_builtins() const {
const auto & arr = args.args[0]->as_array();
return mk_val<value_int>(static_cast<int64_t>(arr.size()));
}},
{"slice", [](const func_args & args) -> value {
args.ensure_count(4);
int64_t start = is_val<value_int>(args.args[1]) ? args.args[1]->as_int() : 0;
int64_t stop = is_val<value_int>(args.args[2]) ? args.args[2]->as_int() : -1;
int64_t step = is_val<value_int>(args.args[3]) ? args.args[3]->as_int() : 1;
if (!is_val<value_array>(args.args[0])) {
throw raised_exception("slice() first argument must be an array");
}
if (step == 0) {
throw raised_exception("slice step cannot be zero");
}
auto arr = slice(args.args[0]->as_array(), start, stop, step);
auto res = mk_val<value_array>();
res->val_arr = std::make_shared<std::vector<value>>(std::move(arr));
return res;
}},
// TODO: reverse, sort, join, string, unique
};
return builtins;

View File

@ -35,7 +35,7 @@ value identifier::execute(context & ctx) {
return it->second->clone();
} else if (builtins.find(val) != builtins.end()) {
JJ_DEBUG("Identifier '%s' found in builtins", val.c_str());
return mk_val<value_func>(builtins.at(val));
return mk_val<value_func>(builtins.at(val), val);
} else {
JJ_DEBUG("Identifier '%s' not found, returning undefined", val.c_str());
return mk_val<value_undefined>();
@ -168,13 +168,16 @@ value binary_expression::execute(context & ctx) {
throw std::runtime_error("Unknown operator \"" + op.value + "\" between " + left_val->type() + " and " + right_val->type());
}
static value try_builtin_func(const std::string & name, const value & input) {
static value try_builtin_func(const std::string & name, const value & input, bool undef_on_missing = true) {
auto builtins = input->get_builtins();
auto it = builtins.find(name);
if (it != builtins.end()) {
JJ_DEBUG("Binding built-in '%s'", name.c_str());
return mk_val<value_func>(it->second, input, name);
}
if (undef_on_missing) {
return mk_val<value_undefined>();
}
throw std::runtime_error("Unknown (built-in) filter '" + name + "' for type " + input->type());
}
@ -189,12 +192,11 @@ value filter_expression::execute(context & ctx) {
throw std::runtime_error("to_json filter not implemented");
}
auto str = input->as_string();
if (filter_val == "trim") {
filter_val = "strip"; // alias
}
JJ_DEBUG("Applying filter '%s' to %s", filter_val.c_str(), input->type().c_str());
return try_builtin_func(filter_val, input);
return try_builtin_func(filter_val, input)->invoke({});
} else if (is_stmt<call_expression>(filter)) {
// TODO
@ -385,7 +387,7 @@ value set_statement::execute(context & ctx) {
if (is_stmt<identifier>(assignee)) {
auto var_name = dynamic_cast<identifier*>(assignee.get())->val;
JJ_DEBUG("Setting variable '%s'", var_name.c_str());
JJ_DEBUG("Setting variable '%s' with value type %s", var_name.c_str(), rhs->type().c_str());
ctx.var[var_name] = rhs->clone();
} else if (is_stmt<tuple_literal>(assignee)) {
@ -408,10 +410,6 @@ value set_statement::execute(context & ctx) {
} else if (is_stmt<member_expression>(assignee)) {
auto member = dynamic_cast<member_expression*>(assignee.get());
value object = member->object->execute(ctx);
if (!is_val<value_object>(object)) {
throw std::runtime_error("Cannot assign to member of non-object");
}
if (member->computed) {
throw std::runtime_error("Cannot assign to computed member");
}
@ -419,9 +417,14 @@ value set_statement::execute(context & ctx) {
throw std::runtime_error("Cannot assign to member with non-identifier property");
}
auto prop_name = dynamic_cast<identifier*>(member->property.get())->val;
auto obj_ptr = dynamic_cast<value_object*>(object.get());
value object = member->object->execute(ctx);
if (!is_val<value_object>(object)) {
throw std::runtime_error("Cannot assign to member of non-object");
}
auto obj_ptr = dynamic_cast<value_object_t*>(object.get());
JJ_DEBUG("Setting object property '%s'", prop_name.c_str());
obj_ptr->get()->insert(prop_name, rhs->clone());
obj_ptr->insert(prop_name, rhs->clone());
} else {
throw std::runtime_error("Invalid LHS inside assignment expression: " + assignee->type());
@ -462,7 +465,26 @@ value member_expression::execute(context & ctx) {
value property;
if (this->computed) {
JJ_DEBUG("Member expression, computing property type %s", this->property->type().c_str());
property = this->property->execute(ctx);
if (is_stmt<slice_expression>(this->property)) {
auto s = dynamic_cast<slice_expression*>(this->property.get());
value start_val = s->start_expr ? s->start_expr->execute(ctx) : mk_val<value_undefined>();
value stop_val = s->stop_expr ? s->stop_expr->execute(ctx) : mk_val<value_undefined>();
value step_val = s->step_expr ? s->step_expr->execute(ctx) : mk_val<value_undefined>();
// translate to function call: obj.slice(start, stop, step)
JJ_DEBUG("Member expression is a slice: start %s, stop %s, step %s",
start_val->as_repr().c_str(),
stop_val->as_repr().c_str(),
step_val->as_repr().c_str());
auto slice_func = try_builtin_func("slice", object);
func_args args;
args.args.push_back(start_val->clone());
args.args.push_back(stop_val->clone());
args.args.push_back(step_val->clone());
return slice_func->invoke(args);
} else {
property = this->property->execute(ctx);
}
} else {
property = mk_val<value_string>(dynamic_cast<identifier*>(this->property.get())->val);
}
@ -482,7 +504,7 @@ value member_expression::execute(context & ctx) {
if (it != obj.end()) {
val = it->second->clone();
} else {
val = try_builtin_func(key, object);
val = try_builtin_func(key, object, true);
}
} else if (is_val<value_array>(object) || is_val<value_string>(object)) {
@ -519,22 +541,22 @@ value member_expression::execute(context & ctx) {
return val;
}
static func_args gather_call_args(const statements & arg_stmts, context & ctx) {
func_args args;
for (auto & arg_stmt : arg_stmts) {
args.args.push_back(arg_stmt->execute(ctx));
}
return args;
}
value call_expression::execute(context & ctx) {
auto args = gather_call_args(this->args, ctx);
// gather arguments
func_args args;
for (auto & arg_stmt : this->args) {
auto arg_val = arg_stmt->execute(ctx);
JJ_DEBUG(" Argument type: %s", arg_val->type().c_str());
args.args.push_back(std::move(arg_val));
}
// execute callee
value callee_val = callee->execute(ctx);
JJ_DEBUG("Calling function of type %s with %zu arguments", callee_val->type().c_str(), args.args.size());
if (!is_val<value_t>(callee_val)) {
if (!is_val<value_func>(callee_val)) {
throw std::runtime_error("Callee is not a function: got " + callee_val->type());
}
return callee_val->invoke(args);
auto * callee_func = dynamic_cast<value_func_t*>(callee_val.get());
JJ_DEBUG("Calling function '%s' with %zu arguments", callee_func->name.c_str(), args.args.size());
return callee_func->invoke(args);
}
// compare operator for value_t
@ -570,4 +592,18 @@ bool value_compare(const value & a, const value & b) {
return false;
}
value keyword_argument_expression::execute(context & ctx) {
if (!is_stmt<identifier>(key)) {
throw std::runtime_error("Keyword argument key must be identifiers");
}
std::string k = dynamic_cast<identifier*>(key.get())->val;
JJ_DEBUG("Keyword argument expression key: %s, value: %s", k.c_str(), val->type().c_str());
value v = val->execute(ctx);
JJ_DEBUG("Keyword argument value executed, type: %s", v->type().c_str());
return mk_val<value_kwarg>(k, v);
}
} // namespace jinja

View File

@ -15,7 +15,11 @@ namespace jinja {
struct context {
std::map<std::string, value> var;
context() = default;
context() {
var["true"] = mk_val<value_bool>(true);
var["false"] = mk_val<value_bool>(false);
var["none"] = mk_val<value_null>();
}
~context() = default;
context(const context & parent) {
@ -375,29 +379,33 @@ struct unary_expression : public expression {
};
struct slice_expression : public expression {
statement_ptr start;
statement_ptr stop;
statement_ptr step;
statement_ptr start_expr;
statement_ptr stop_expr;
statement_ptr step_expr;
slice_expression(statement_ptr && start, statement_ptr && stop, statement_ptr && step)
: start(std::move(start)), stop(std::move(stop)), step(std::move(step)) {
chk_type<expression>(this->start);
chk_type<expression>(this->stop);
chk_type<expression>(this->step);
slice_expression(statement_ptr && start_expr, statement_ptr && stop_expr, statement_ptr && step_expr)
: start_expr(std::move(start_expr)), stop_expr(std::move(stop_expr)), step_expr(std::move(step_expr)) {
chk_type<expression>(this->start_expr);
chk_type<expression>(this->stop_expr);
chk_type<expression>(this->step_expr);
}
std::string type() const override { return "SliceExpression"; }
value execute(context &) override {
throw std::runtime_error("must be handled by MemberExpression");
}
};
struct keyword_argument_expression : public expression {
statement_ptr key;
statement_ptr value;
statement_ptr val;
keyword_argument_expression(statement_ptr && key, statement_ptr && value)
: key(std::move(key)), value(std::move(value)) {
keyword_argument_expression(statement_ptr && key, statement_ptr && val)
: key(std::move(key)), val(std::move(val)) {
chk_type<identifier>(this->key);
chk_type<expression>(this->value);
chk_type<expression>(this->val);
}
std::string type() const override { return "KeywordArgumentExpression"; }
value execute(context & ctx) override;
};
struct spread_expression : public expression {

View File

@ -18,7 +18,7 @@ int main(void) {
//std::string contents = "<some_tokens> {{ messages[0]['content'] }} <another_token>";
std::ifstream infile("models/templates/moonshotai-Kimi-K2.jinja");
std::ifstream infile("models/templates/Qwen-Qwen3-0.6B.jinja");
std::string contents((std::istreambuf_iterator<char>(infile)), std::istreambuf_iterator<char>());
std::cout << "=== INPUT ===\n" << contents << "\n\n";