improve function args handling

This commit is contained in:
Xuan Son Nguyen 2025-12-31 11:29:40 +01:00
parent 1b213ae5e7
commit cbb37dd4cd
4 changed files with 78 additions and 76 deletions

View File

@ -115,7 +115,6 @@ const func_builtins & global_builtins() {
return out;
}},
{"strftime_now", [](const func_args & args) -> value {
args.ensure_count(1);
args.ensure_vals<value_string>();
std::string format = args.args[0]->as_string().str();
// get current time
@ -128,9 +127,9 @@ const func_builtins & global_builtins() {
}
}},
{"range", [](const func_args & args) -> value {
if (args.args.size() < 1 || args.args.size() > 3) {
throw raised_exception("slice() takes between 1 and 3 arguments");
}
args.ensure_count(1, 3);
args.ensure_vals<value_int, value_int, value_int>(true, false, false);
auto & arg0 = args.args[0];
auto & arg1 = args.args[1];
auto & arg2 = args.args[2];

View File

@ -51,12 +51,6 @@ typename extract_pointee<T>::type * cast_val(value & ptr) {
using PointeeType = typename extract_pointee<T>::type;
return dynamic_cast<PointeeType*>(ptr.get());
}
template<typename T>
void ensure_val(const value & ptr) {
if (!is_val<T>(ptr)) {
throw std::runtime_error("Expected value of type " + std::string(typeid(T).name()));
}
}
// End Helper
@ -92,36 +86,11 @@ struct context; // forward declaration
template<typename T_JSON>
void global_from_json(context & ctx, const T_JSON & json_obj);
//
// base value type
//
struct func_args {
std::vector<value> args;
context & ctx;
func_args(context & ctx) : ctx(ctx) {}
void ensure_count(size_t min, size_t max = 999) const {
if (args.size() < min || args.size() > max) {
throw std::runtime_error("Expected between " + std::to_string(min) + " and " + std::to_string(max) + " arguments, got " + std::to_string(args.size()));
}
}
value get_kwarg(const std::string & key) const;
// utility functions
// TODO: allow optional arguments
template<typename T> void ensure_vals() const {
ensure_count(1);
ensure_val<T>(args[0]);
}
template<typename T, typename U> void ensure_vals() const {
ensure_count(2);
ensure_val<T>(args[0]);
ensure_val<U>(args[1]);
}
template<typename T, typename U, typename V> void ensure_vals() const {
ensure_count(3);
ensure_val<T>(args[0]);
ensure_val<U>(args[1]);
ensure_val<V>(args[2]);
}
};
struct func_args; // function argument values
using func_handler = std::function<value(const func_args &)>;
using func_builtins = std::map<std::string, func_handler>;
@ -165,6 +134,9 @@ struct value_t {
virtual std::string as_repr() const { return as_string().str(); }
};
//
// primitive value types
//
struct value_int_t : public value_t {
value_int_t(int64_t v) { val_int = v; }
@ -275,36 +247,9 @@ struct value_object_t : public value_t {
};
using value_object = std::shared_ptr<value_object_t>;
struct value_func_t : public value_t {
std::string name; // for debugging
value arg0; // bound "this" argument, if any
value_func_t(const func_handler & func, std::string func_name = "") {
val_func = func;
name = func_name;
}
value_func_t(const func_handler & func, const value & arg_this, std::string func_name = "") {
val_func = func;
name = func_name;
arg0 = arg_this;
}
virtual value invoke(const func_args & args) const override {
if (arg0) {
func_args new_args(args.ctx);
new_args.args.push_back(arg0);
for (const auto & a : args.args) {
new_args.args.push_back(a);
}
return val_func(new_args);
} else {
return val_func(args);
}
}
virtual std::string type() const override { return "Function"; }
virtual std::string as_repr() const override { return type(); }
};
using value_func = std::shared_ptr<value_func_t>;
//
// null and undefined types
//
struct value_null_t : public value_t {
virtual std::string type() const override { return "Null"; }
@ -326,6 +271,63 @@ struct value_undefined_t : public value_t {
};
using value_undefined = std::shared_ptr<value_undefined_t>;
//
// function type
//
struct func_args {
std::string func_name; // for error messages
std::vector<value> args;
context & ctx;
func_args(context & ctx) : ctx(ctx) {}
value get_kwarg(const std::string & key) const;
void ensure_count(size_t min, size_t max = 999) const {
size_t n = args.size();
if (n < min || n > max) {
throw std::runtime_error("Function '" + func_name + "' expected between " + std::to_string(min) + " and " + std::to_string(max) + " arguments, got " + std::to_string(n));
}
}
template<typename T> void ensure_val(const value & ptr) const {
if (!is_val<T>(ptr)) {
throw std::runtime_error("Function '" + func_name + "' expected value of type " + std::string(typeid(T).name()) + ", got " + ptr->type());
}
}
template<typename T0> void ensure_vals(bool required0 = true) const {
if (required0 && args.size() > 0) ensure_val<T0>(args[0]);
}
template<typename T0, typename T1> void ensure_vals(bool required0 = true, bool required1 = true) const {
if (required0 && args.size() > 0) ensure_val<T0>(args[0]);
if (required1 && args.size() > 1) ensure_val<T1>(args[1]);
}
template<typename T0, typename T1, typename T2> void ensure_vals(bool required0 = true, bool required1 = true, bool required2 = true) const {
if (required0 && args.size() > 0) ensure_val<T0>(args[0]);
if (required1 && args.size() > 1) ensure_val<T1>(args[1]);
if (required2 && args.size() > 2) ensure_val<T2>(args[2]);
}
};
struct value_func_t : public value_t {
std::string name;
value arg0; // bound "this" argument, if any
value_func_t(const std::string & name, const func_handler & func) : name(name) {
val_func = func;
}
value_func_t(const std::string & name, const func_handler & func, const value & arg_this) : name(name), arg0(arg_this) {
val_func = func;
}
virtual value invoke(const func_args & args) const override {
func_args new_args(args); // copy
new_args.func_name = name;
if (arg0) {
new_args.args.insert(new_args.args.begin(), arg0);
}
return val_func(new_args);
}
virtual std::string type() const override { return "Function"; }
virtual std::string as_repr() const override { return type(); }
};
using value_func = std::shared_ptr<value_func_t>;
// special value for kwarg
struct value_kwarg_t : public value_t {
std::string key;
@ -337,11 +339,10 @@ struct value_kwarg_t : public value_t {
using value_kwarg = std::shared_ptr<value_kwarg_t>;
const func_builtins & global_builtins();
// utils
const func_builtins & global_builtins();
static inferred_type value_to_inferred_type(const value & val) {
if (is_val<value_int>(val) || is_val<value_float>(val)) {
return inferred_type::numeric;

View File

@ -70,7 +70,7 @@ value identifier::execute_impl(context & ctx) {
return it;
} else if (builtins.find(val) != builtins.end()) {
JJ_DEBUG("Identifier '%s' found in builtins", val.c_str());
return mk_val<value_func>(builtins.at(val), val);
return mk_val<value_func>(val, builtins.at(val));
} else {
JJ_DEBUG("Identifier '%s' not found, returning undefined", val.c_str());
return mk_val<value_undefined>(val);
@ -243,7 +243,7 @@ static value try_builtin_func(const std::string & name, const value & input, boo
auto it = builtins.find(name);
if (it != builtins.end()) {
JJ_DEBUG("Binding built-in '%s'", name.c_str());
return mk_val<value_func>(it->second, input, name);
return mk_val<value_func>(name, it->second, input);
}
if (undef_on_missing) {
return mk_val<value_undefined>(name);
@ -607,7 +607,7 @@ value macro_statement::execute_impl(context & ctx) {
};
JJ_DEBUG("Defining macro '%s' with %zu parameters", name.c_str(), args.size());
ctx.set_val(name, mk_val<value_func>(func));
ctx.set_val(name, mk_val<value_func>(name, func));
return mk_val<value_null>();
}

View File

@ -26,7 +26,9 @@ int main(void) {
//std::string contents = "<some_tokens> {{ messages[a]['content'] }} <another_token>";
//std::string contents = "{% if a is not defined %}hello{% endif %}";
std::ifstream infile("models/templates/Qwen-Qwen3-0.6B.jinja"); std::string contents((std::istreambuf_iterator<char>(infile)), std::istreambuf_iterator<char>());
std::ifstream infile("models/templates/Qwen-Qwen3-0.6B.jinja");
//std::ifstream infile("models/templates/Kimi-K2-Thinking.jinja");
std::string contents((std::istreambuf_iterator<char>(infile)), std::istreambuf_iterator<char>());
run_single(contents);