keyword arguments and slicing array
This commit is contained in:
parent
45c194622e
commit
4331e9c8e9
|
|
@ -55,6 +55,7 @@ struct func_args {
|
|||
throw std::runtime_error("Expected " + std::to_string(count) + " arguments, got " + std::to_string(args.size()));
|
||||
}
|
||||
}
|
||||
// TODO: add support for get kwargs
|
||||
// utility functions
|
||||
template<typename T> void ensure_vals() const {
|
||||
ensure_count(1);
|
||||
|
|
@ -187,19 +188,6 @@ struct value_array_t : public value_t {
|
|||
// point to the same underlying data
|
||||
val_arr = v->val_arr;
|
||||
}
|
||||
value_array_t(value_array_t & other, size_t start = 0, size_t end = -1) {
|
||||
val_arr = std::make_shared<std::vector<value>>();
|
||||
size_t sz = other.val_arr->size();
|
||||
if (end == static_cast<size_t>(-1) || end > sz) {
|
||||
end = sz;
|
||||
}
|
||||
if (start > end || start >= sz) {
|
||||
return;
|
||||
}
|
||||
for (size_t i = start; i < end; i++) {
|
||||
val_arr->push_back(other.val_arr->at(i)->clone());
|
||||
}
|
||||
}
|
||||
void push_back(const value & val) {
|
||||
val_arr->push_back(val->clone());
|
||||
}
|
||||
|
|
@ -319,6 +307,21 @@ struct value_undefined_t : public value_t {
|
|||
};
|
||||
using value_undefined = std::unique_ptr<value_undefined_t>;
|
||||
|
||||
// special value for kwarg
|
||||
struct value_kwarg_t : public value_t {
|
||||
std::string key;
|
||||
value val;
|
||||
value_kwarg_t(const value_kwarg_t & other) {
|
||||
key = other.key;
|
||||
val = other.val->clone();
|
||||
}
|
||||
value_kwarg_t(const std::string & k, const value & v) : key(k), val(v->clone()) {}
|
||||
virtual std::string type() const override { return "KwArg"; }
|
||||
virtual std::string as_repr() const override { return type(); }
|
||||
virtual value clone() const override { return std::make_unique<value_kwarg_t>(*this); }
|
||||
};
|
||||
using value_kwarg = std::unique_ptr<value_kwarg_t>;
|
||||
|
||||
|
||||
const func_builtins & global_builtins();
|
||||
|
||||
|
|
|
|||
|
|
@ -5,9 +5,62 @@
|
|||
|
||||
#include <string>
|
||||
#include <cctype>
|
||||
#include <vector>
|
||||
#include <optional>
|
||||
#include <algorithm>
|
||||
|
||||
namespace jinja {
|
||||
|
||||
/**
 * Function that mimics Python's array slicing (`array[start:stop:step]`).
 *
 * T must provide size(), operator[] and push_back(); its elements must be
 * pointer-like with a clone() method. Every selected element is cloned into
 * the result, so `array` is left untouched.
 *
 * @param array  source sequence
 * @param start  first index; std::nullopt means "omitted". Negative values
 *               count from the end.
 * @param stop   exclusive end index; std::nullopt means "omitted". Negative
 *               values count from the end, including -1 with a negative step
 *               (so `a[3:-1:-1]` is empty, as in Python).
 * @param step   stride; negative walks backwards.
 * @return the selected elements, or an empty T when step == 0 (Python raises
 *         in that case; this function does not).
 */
template<typename T>
static T slice(const T & array, std::optional<int64_t> start = std::nullopt, std::optional<int64_t> stop = std::nullopt, int64_t step = 1) {
    T result;
    if (step == 0) {
        return result;
    }

    const int64_t len = static_cast<int64_t>(array.size());
    const bool forward = step > 0;

    // Valid range for a resolved bound: [0, len] going forward, [-1, len - 1]
    // going backward (-1 is the "one before the first element" position).
    const int64_t lo = forward ? 0 : -1;
    const int64_t hi = forward ? len : len - 1;

    // Turns a user-supplied bound into a clamped absolute position. Only a
    // missing bound takes the fallback; an explicit -1 is a real index.
    const auto resolve = [len, lo, hi](const std::optional<int64_t> & bound, int64_t fallback) -> int64_t {
        if (!bound.has_value()) {
            return fallback;
        }
        const int64_t raw = *bound;
        return raw < 0 ? std::max(raw + len, lo) : std::min(raw, hi);
    };

    const int64_t first = resolve(start, forward ? lo : hi);
    const int64_t last  = resolve(stop,  forward ? hi : lo);

    for (int64_t i = first; forward ? (i < last) : (i > last); i += step) {
        result.push_back(array[static_cast<size_t>(i)]->clone());
        // Stop before `i += step` would reach or pass `last`; this also keeps
        // the addition from overflowing when |step| is huge.
        const int64_t remaining = last - i;
        if (forward ? (step >= remaining) : (step <= remaining)) {
            break;
        }
    }
    return result;
}
|
||||
|
||||
template<typename T>
|
||||
static value test_type_fn(const func_args & args) {
|
||||
args.ensure_count(1);
|
||||
|
|
@ -28,6 +81,17 @@ const func_builtins & global_builtins() {
|
|||
std::string msg = args.args[0]->as_string().str();
|
||||
throw raised_exception("Jinja Exception: " + msg);
|
||||
}},
|
||||
{"namespace", [](const func_args & args) -> value {
    // namespace(a=..., b=...): every argument must be a kwarg; each one
    // becomes an entry of a freshly created object.
    auto ns = mk_val<value_object>();
    for (const auto & item : args.args) {
        if (is_val<value_kwarg>(item)) {
            auto * kw = dynamic_cast<value_kwarg_t*>(item.get());
            ns->insert(kw->key, kw->val);
            continue;
        }
        throw raised_exception("namespace() arguments must be kwargs");
    }
    return ns;
}},
|
||||
|
||||
// tests
|
||||
{"test_is_boolean", test_type_fn<value_bool>},
|
||||
|
|
@ -126,6 +190,8 @@ const func_builtins & value_float_t::get_builtins() const {
|
|||
// return str.substr(start, end - start);
|
||||
// }
|
||||
|
||||
|
||||
|
||||
static bool string_startswith(const std::string & str, const std::string & prefix) {
|
||||
if (str.length() < prefix.length()) return false;
|
||||
return str.compare(0, prefix.length(), prefix) == 0;
|
||||
|
|
@ -250,6 +316,9 @@ const func_builtins & value_string_t::get_builtins() const {
|
|||
{"join", [](const func_args &) -> value {
|
||||
throw std::runtime_error("join builtin not implemented");
|
||||
}},
|
||||
{"slice", [](const func_args &) -> value {
|
||||
throw std::runtime_error("slice builtin not implemented");
|
||||
}},
|
||||
};
|
||||
return builtins;
|
||||
}
|
||||
|
|
@ -309,6 +378,22 @@ const func_builtins & value_array_t::get_builtins() const {
|
|||
const auto & arr = args.args[0]->as_array();
|
||||
return mk_val<value_int>(static_cast<int64_t>(arr.size()));
|
||||
}},
|
||||
{"slice", [](const func_args & args) -> value {
    // Expected arguments: [array, start, stop, step].
    args.ensure_count(4);
    if (!is_val<value_array>(args.args[0])) {
        throw raised_exception("slice() first argument must be an array");
    }

    // A bound that is not an integer is treated as "not given".
    const auto as_bound = [](const value & v) -> std::optional<int64_t> {
        if (is_val<value_int>(v)) {
            return static_cast<int64_t>(v->as_int());
        }
        return std::nullopt;
    };

    const std::optional<int64_t> start = as_bound(args.args[1]);
    const std::optional<int64_t> stop  = as_bound(args.args[2]);
    const int64_t step = as_bound(args.args[3]).value_or(1);
    if (step == 0) {
        throw raised_exception("slice step cannot be zero");
    }

    // Missing bounds must reach slice() as nullopt so it can choose defaults
    // that depend on the step direction. Substituting literals here is wrong:
    // a stop of -1 drops the last element, and a start of 0 breaks reversed
    // slices.
    auto items = slice(args.args[0]->as_array(), start, stop, step);
    auto res = mk_val<value_array>();
    res->val_arr = std::make_shared<std::vector<value>>(std::move(items));
    return res;
}},
|
||||
// TODO: reverse, sort, join, string, unique
|
||||
};
|
||||
return builtins;
|
||||
|
|
|
|||
|
|
@ -35,7 +35,7 @@ value identifier::execute(context & ctx) {
|
|||
return it->second->clone();
|
||||
} else if (builtins.find(val) != builtins.end()) {
|
||||
JJ_DEBUG("Identifier '%s' found in builtins", val.c_str());
|
||||
return mk_val<value_func>(builtins.at(val));
|
||||
return mk_val<value_func>(builtins.at(val), val);
|
||||
} else {
|
||||
JJ_DEBUG("Identifier '%s' not found, returning undefined", val.c_str());
|
||||
return mk_val<value_undefined>();
|
||||
|
|
@ -168,13 +168,16 @@ value binary_expression::execute(context & ctx) {
|
|||
throw std::runtime_error("Unknown operator \"" + op.value + "\" between " + left_val->type() + " and " + right_val->type());
|
||||
}
|
||||
|
||||
static value try_builtin_func(const std::string & name, const value & input) {
|
||||
// Looks up the built-in called `name` in the table exposed by `input` and
// returns it as a function value created from the handler, `input` and `name`.
//
// When no such built-in exists: returns an undefined value if
// `undef_on_missing` is true, otherwise throws std::runtime_error.
static value try_builtin_func(const std::string & name, const value & input, bool undef_on_missing = true) {
    // Bind by reference: a plain `auto` would copy the whole builtin table on
    // every lookup.
    const auto & builtins = input->get_builtins();
    const auto it = builtins.find(name);
    if (it != builtins.end()) {
        JJ_DEBUG("Binding built-in '%s'", name.c_str());
        return mk_val<value_func>(it->second, input, name);
    }
    if (undef_on_missing) {
        return mk_val<value_undefined>();
    }
    throw std::runtime_error("Unknown (built-in) filter '" + name + "' for type " + input->type());
}
|
||||
|
||||
|
|
@ -189,12 +192,11 @@ value filter_expression::execute(context & ctx) {
|
|||
throw std::runtime_error("to_json filter not implemented");
|
||||
}
|
||||
|
||||
auto str = input->as_string();
|
||||
if (filter_val == "trim") {
|
||||
filter_val = "strip"; // alias
|
||||
}
|
||||
JJ_DEBUG("Applying filter '%s' to %s", filter_val.c_str(), input->type().c_str());
|
||||
return try_builtin_func(filter_val, input);
|
||||
return try_builtin_func(filter_val, input)->invoke({});
|
||||
|
||||
} else if (is_stmt<call_expression>(filter)) {
|
||||
// TODO
|
||||
|
|
@ -385,7 +387,7 @@ value set_statement::execute(context & ctx) {
|
|||
|
||||
if (is_stmt<identifier>(assignee)) {
|
||||
auto var_name = dynamic_cast<identifier*>(assignee.get())->val;
|
||||
JJ_DEBUG("Setting variable '%s'", var_name.c_str());
|
||||
JJ_DEBUG("Setting variable '%s' with value type %s", var_name.c_str(), rhs->type().c_str());
|
||||
ctx.var[var_name] = rhs->clone();
|
||||
|
||||
} else if (is_stmt<tuple_literal>(assignee)) {
|
||||
|
|
@ -408,10 +410,6 @@ value set_statement::execute(context & ctx) {
|
|||
|
||||
} else if (is_stmt<member_expression>(assignee)) {
|
||||
auto member = dynamic_cast<member_expression*>(assignee.get());
|
||||
value object = member->object->execute(ctx);
|
||||
if (!is_val<value_object>(object)) {
|
||||
throw std::runtime_error("Cannot assign to member of non-object");
|
||||
}
|
||||
if (member->computed) {
|
||||
throw std::runtime_error("Cannot assign to computed member");
|
||||
}
|
||||
|
|
@ -419,9 +417,14 @@ value set_statement::execute(context & ctx) {
|
|||
throw std::runtime_error("Cannot assign to member with non-identifier property");
|
||||
}
|
||||
auto prop_name = dynamic_cast<identifier*>(member->property.get())->val;
|
||||
auto obj_ptr = dynamic_cast<value_object*>(object.get());
|
||||
|
||||
value object = member->object->execute(ctx);
|
||||
if (!is_val<value_object>(object)) {
|
||||
throw std::runtime_error("Cannot assign to member of non-object");
|
||||
}
|
||||
auto obj_ptr = dynamic_cast<value_object_t*>(object.get());
|
||||
JJ_DEBUG("Setting object property '%s'", prop_name.c_str());
|
||||
obj_ptr->get()->insert(prop_name, rhs->clone());
|
||||
obj_ptr->insert(prop_name, rhs->clone());
|
||||
|
||||
} else {
|
||||
throw std::runtime_error("Invalid LHS inside assignment expression: " + assignee->type());
|
||||
|
|
@ -462,7 +465,26 @@ value member_expression::execute(context & ctx) {
|
|||
value property;
|
||||
if (this->computed) {
|
||||
JJ_DEBUG("Member expression, computing property type %s", this->property->type().c_str());
|
||||
property = this->property->execute(ctx);
|
||||
if (is_stmt<slice_expression>(this->property)) {
|
||||
auto s = dynamic_cast<slice_expression*>(this->property.get());
|
||||
value start_val = s->start_expr ? s->start_expr->execute(ctx) : mk_val<value_undefined>();
|
||||
value stop_val = s->stop_expr ? s->stop_expr->execute(ctx) : mk_val<value_undefined>();
|
||||
value step_val = s->step_expr ? s->step_expr->execute(ctx) : mk_val<value_undefined>();
|
||||
|
||||
// translate to function call: obj.slice(start, stop, step)
|
||||
JJ_DEBUG("Member expression is a slice: start %s, stop %s, step %s",
|
||||
start_val->as_repr().c_str(),
|
||||
stop_val->as_repr().c_str(),
|
||||
step_val->as_repr().c_str());
|
||||
auto slice_func = try_builtin_func("slice", object);
|
||||
func_args args;
|
||||
args.args.push_back(start_val->clone());
|
||||
args.args.push_back(stop_val->clone());
|
||||
args.args.push_back(step_val->clone());
|
||||
return slice_func->invoke(args);
|
||||
} else {
|
||||
property = this->property->execute(ctx);
|
||||
}
|
||||
} else {
|
||||
property = mk_val<value_string>(dynamic_cast<identifier*>(this->property.get())->val);
|
||||
}
|
||||
|
|
@ -482,7 +504,7 @@ value member_expression::execute(context & ctx) {
|
|||
if (it != obj.end()) {
|
||||
val = it->second->clone();
|
||||
} else {
|
||||
val = try_builtin_func(key, object);
|
||||
val = try_builtin_func(key, object, true);
|
||||
}
|
||||
|
||||
} else if (is_val<value_array>(object) || is_val<value_string>(object)) {
|
||||
|
|
@ -519,22 +541,22 @@ value member_expression::execute(context & ctx) {
|
|||
return val;
|
||||
}
|
||||
|
||||
static func_args gather_call_args(const statements & arg_stmts, context & ctx) {
|
||||
func_args args;
|
||||
for (auto & arg_stmt : arg_stmts) {
|
||||
args.args.push_back(arg_stmt->execute(ctx));
|
||||
}
|
||||
return args;
|
||||
}
|
||||
|
||||
value call_expression::execute(context & ctx) {
|
||||
auto args = gather_call_args(this->args, ctx);
|
||||
// gather arguments
|
||||
func_args args;
|
||||
for (auto & arg_stmt : this->args) {
|
||||
auto arg_val = arg_stmt->execute(ctx);
|
||||
JJ_DEBUG(" Argument type: %s", arg_val->type().c_str());
|
||||
args.args.push_back(std::move(arg_val));
|
||||
}
|
||||
// execute callee
|
||||
value callee_val = callee->execute(ctx);
|
||||
JJ_DEBUG("Calling function of type %s with %zu arguments", callee_val->type().c_str(), args.args.size());
|
||||
if (!is_val<value_t>(callee_val)) {
|
||||
if (!is_val<value_func>(callee_val)) {
|
||||
throw std::runtime_error("Callee is not a function: got " + callee_val->type());
|
||||
}
|
||||
return callee_val->invoke(args);
|
||||
auto * callee_func = dynamic_cast<value_func_t*>(callee_val.get());
|
||||
JJ_DEBUG("Calling function '%s' with %zu arguments", callee_func->name.c_str(), args.args.size());
|
||||
return callee_func->invoke(args);
|
||||
}
|
||||
|
||||
// compare operator for value_t
|
||||
|
|
@ -570,4 +592,18 @@ bool value_compare(const value & a, const value & b) {
|
|||
return false;
|
||||
}
|
||||
|
||||
// Evaluates a `key=val` argument: the key must be an identifier statement, the
// value expression is executed in `ctx`, and the pair is returned as a kwarg
// value. Throws std::runtime_error when the key is not an identifier.
value keyword_argument_expression::execute(context & ctx) {
    const bool key_is_identifier = is_stmt<identifier>(key);
    if (!key_is_identifier) {
        throw std::runtime_error("Keyword argument key must be identifiers");
    }

    const std::string name = dynamic_cast<identifier*>(key.get())->val;
    JJ_DEBUG("Keyword argument expression key: %s, value: %s", name.c_str(), val->type().c_str());

    value evaluated = val->execute(ctx);
    JJ_DEBUG("Keyword argument value executed, type: %s", evaluated->type().c_str());

    return mk_val<value_kwarg>(name, evaluated);
}
|
||||
|
||||
} // namespace jinja
|
||||
|
|
|
|||
|
|
@ -15,7 +15,11 @@ namespace jinja {
|
|||
struct context {
|
||||
std::map<std::string, value> var;
|
||||
|
||||
context() = default;
|
||||
context() {
|
||||
var["true"] = mk_val<value_bool>(true);
|
||||
var["false"] = mk_val<value_bool>(false);
|
||||
var["none"] = mk_val<value_null>();
|
||||
}
|
||||
~context() = default;
|
||||
|
||||
context(const context & parent) {
|
||||
|
|
@ -375,29 +379,33 @@ struct unary_expression : public expression {
|
|||
};
|
||||
|
||||
struct slice_expression : public expression {
|
||||
statement_ptr start;
|
||||
statement_ptr stop;
|
||||
statement_ptr step;
|
||||
statement_ptr start_expr;
|
||||
statement_ptr stop_expr;
|
||||
statement_ptr step_expr;
|
||||
|
||||
slice_expression(statement_ptr && start, statement_ptr && stop, statement_ptr && step)
|
||||
: start(std::move(start)), stop(std::move(stop)), step(std::move(step)) {
|
||||
chk_type<expression>(this->start);
|
||||
chk_type<expression>(this->stop);
|
||||
chk_type<expression>(this->step);
|
||||
slice_expression(statement_ptr && start_expr, statement_ptr && stop_expr, statement_ptr && step_expr)
|
||||
: start_expr(std::move(start_expr)), stop_expr(std::move(stop_expr)), step_expr(std::move(step_expr)) {
|
||||
chk_type<expression>(this->start_expr);
|
||||
chk_type<expression>(this->stop_expr);
|
||||
chk_type<expression>(this->step_expr);
|
||||
}
|
||||
std::string type() const override { return "SliceExpression"; }
|
||||
value execute(context &) override {
|
||||
throw std::runtime_error("must be handled by MemberExpression");
|
||||
}
|
||||
};
|
||||
|
||||
struct keyword_argument_expression : public expression {
|
||||
statement_ptr key;
|
||||
statement_ptr value;
|
||||
statement_ptr val;
|
||||
|
||||
keyword_argument_expression(statement_ptr && key, statement_ptr && value)
|
||||
: key(std::move(key)), value(std::move(value)) {
|
||||
keyword_argument_expression(statement_ptr && key, statement_ptr && val)
|
||||
: key(std::move(key)), val(std::move(val)) {
|
||||
chk_type<identifier>(this->key);
|
||||
chk_type<expression>(this->value);
|
||||
chk_type<expression>(this->val);
|
||||
}
|
||||
std::string type() const override { return "KeywordArgumentExpression"; }
|
||||
value execute(context & ctx) override;
|
||||
};
|
||||
|
||||
struct spread_expression : public expression {
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ int main(void) {
|
|||
|
||||
//std::string contents = "<some_tokens> {{ messages[0]['content'] }} <another_token>";
|
||||
|
||||
std::ifstream infile("models/templates/moonshotai-Kimi-K2.jinja");
|
||||
std::ifstream infile("models/templates/Qwen-Qwen3-0.6B.jinja");
|
||||
std::string contents((std::istreambuf_iterator<char>(infile)), std::istreambuf_iterator<char>());
|
||||
|
||||
std::cout << "=== INPUT ===\n" << contents << "\n\n";
|
||||
|
|
|
|||
Loading…
Reference in New Issue