improve function args handling
This commit is contained in:
parent
1b213ae5e7
commit
cbb37dd4cd
|
|
@ -115,7 +115,6 @@ const func_builtins & global_builtins() {
|
|||
return out;
|
||||
}},
|
||||
{"strftime_now", [](const func_args & args) -> value {
|
||||
args.ensure_count(1);
|
||||
args.ensure_vals<value_string>();
|
||||
std::string format = args.args[0]->as_string().str();
|
||||
// get current time
|
||||
|
|
@ -128,9 +127,9 @@ const func_builtins & global_builtins() {
|
|||
}
|
||||
}},
|
||||
{"range", [](const func_args & args) -> value {
|
||||
if (args.args.size() < 1 || args.args.size() > 3) {
|
||||
throw raised_exception("slice() takes between 1 and 3 arguments");
|
||||
}
|
||||
args.ensure_count(1, 3);
|
||||
args.ensure_vals<value_int, value_int, value_int>(true, false, false);
|
||||
|
||||
auto & arg0 = args.args[0];
|
||||
auto & arg1 = args.args[1];
|
||||
auto & arg2 = args.args[2];
|
||||
|
|
|
|||
|
|
@ -51,12 +51,6 @@ typename extract_pointee<T>::type * cast_val(value & ptr) {
|
|||
using PointeeType = typename extract_pointee<T>::type;
|
||||
return dynamic_cast<PointeeType*>(ptr.get());
|
||||
}
|
||||
template<typename T>
|
||||
void ensure_val(const value & ptr) {
|
||||
if (!is_val<T>(ptr)) {
|
||||
throw std::runtime_error("Expected value of type " + std::string(typeid(T).name()));
|
||||
}
|
||||
}
|
||||
// End Helper
|
||||
|
||||
|
||||
|
|
@ -92,36 +86,11 @@ struct context; // forward declaration
|
|||
template<typename T_JSON>
|
||||
void global_from_json(context & ctx, const T_JSON & json_obj);
|
||||
|
||||
//
|
||||
// base value type
|
||||
//
|
||||
|
||||
|
||||
struct func_args {
|
||||
std::vector<value> args;
|
||||
context & ctx;
|
||||
func_args(context & ctx) : ctx(ctx) {}
|
||||
void ensure_count(size_t min, size_t max = 999) const {
|
||||
if (args.size() < min || args.size() > max) {
|
||||
throw std::runtime_error("Expected between " + std::to_string(min) + " and " + std::to_string(max) + " arguments, got " + std::to_string(args.size()));
|
||||
}
|
||||
}
|
||||
value get_kwarg(const std::string & key) const;
|
||||
// utility functions
|
||||
// TODO: allow optional arguments
|
||||
template<typename T> void ensure_vals() const {
|
||||
ensure_count(1);
|
||||
ensure_val<T>(args[0]);
|
||||
}
|
||||
template<typename T, typename U> void ensure_vals() const {
|
||||
ensure_count(2);
|
||||
ensure_val<T>(args[0]);
|
||||
ensure_val<U>(args[1]);
|
||||
}
|
||||
template<typename T, typename U, typename V> void ensure_vals() const {
|
||||
ensure_count(3);
|
||||
ensure_val<T>(args[0]);
|
||||
ensure_val<U>(args[1]);
|
||||
ensure_val<V>(args[2]);
|
||||
}
|
||||
};
|
||||
struct func_args; // function argument values
|
||||
|
||||
using func_handler = std::function<value(const func_args &)>;
|
||||
using func_builtins = std::map<std::string, func_handler>;
|
||||
|
|
@ -165,6 +134,9 @@ struct value_t {
|
|||
virtual std::string as_repr() const { return as_string().str(); }
|
||||
};
|
||||
|
||||
//
|
||||
// primitive value types
|
||||
//
|
||||
|
||||
struct value_int_t : public value_t {
|
||||
value_int_t(int64_t v) { val_int = v; }
|
||||
|
|
@ -275,36 +247,9 @@ struct value_object_t : public value_t {
|
|||
};
|
||||
using value_object = std::shared_ptr<value_object_t>;
|
||||
|
||||
|
||||
struct value_func_t : public value_t {
|
||||
std::string name; // for debugging
|
||||
value arg0; // bound "this" argument, if any
|
||||
value_func_t(const func_handler & func, std::string func_name = "") {
|
||||
val_func = func;
|
||||
name = func_name;
|
||||
}
|
||||
value_func_t(const func_handler & func, const value & arg_this, std::string func_name = "") {
|
||||
val_func = func;
|
||||
name = func_name;
|
||||
arg0 = arg_this;
|
||||
}
|
||||
virtual value invoke(const func_args & args) const override {
|
||||
if (arg0) {
|
||||
func_args new_args(args.ctx);
|
||||
new_args.args.push_back(arg0);
|
||||
for (const auto & a : args.args) {
|
||||
new_args.args.push_back(a);
|
||||
}
|
||||
return val_func(new_args);
|
||||
} else {
|
||||
return val_func(args);
|
||||
}
|
||||
}
|
||||
virtual std::string type() const override { return "Function"; }
|
||||
virtual std::string as_repr() const override { return type(); }
|
||||
};
|
||||
using value_func = std::shared_ptr<value_func_t>;
|
||||
|
||||
//
|
||||
// null and undefined types
|
||||
//
|
||||
|
||||
struct value_null_t : public value_t {
|
||||
virtual std::string type() const override { return "Null"; }
|
||||
|
|
@ -326,6 +271,63 @@ struct value_undefined_t : public value_t {
|
|||
};
|
||||
using value_undefined = std::shared_ptr<value_undefined_t>;
|
||||
|
||||
//
|
||||
// function type
|
||||
//
|
||||
|
||||
struct func_args {
|
||||
std::string func_name; // for error messages
|
||||
std::vector<value> args;
|
||||
context & ctx;
|
||||
func_args(context & ctx) : ctx(ctx) {}
|
||||
value get_kwarg(const std::string & key) const;
|
||||
void ensure_count(size_t min, size_t max = 999) const {
|
||||
size_t n = args.size();
|
||||
if (n < min || n > max) {
|
||||
throw std::runtime_error("Function '" + func_name + "' expected between " + std::to_string(min) + " and " + std::to_string(max) + " arguments, got " + std::to_string(n));
|
||||
}
|
||||
}
|
||||
template<typename T> void ensure_val(const value & ptr) const {
|
||||
if (!is_val<T>(ptr)) {
|
||||
throw std::runtime_error("Function '" + func_name + "' expected value of type " + std::string(typeid(T).name()) + ", got " + ptr->type());
|
||||
}
|
||||
}
|
||||
template<typename T0> void ensure_vals(bool required0 = true) const {
|
||||
if (required0 && args.size() > 0) ensure_val<T0>(args[0]);
|
||||
}
|
||||
template<typename T0, typename T1> void ensure_vals(bool required0 = true, bool required1 = true) const {
|
||||
if (required0 && args.size() > 0) ensure_val<T0>(args[0]);
|
||||
if (required1 && args.size() > 1) ensure_val<T1>(args[1]);
|
||||
}
|
||||
template<typename T0, typename T1, typename T2> void ensure_vals(bool required0 = true, bool required1 = true, bool required2 = true) const {
|
||||
if (required0 && args.size() > 0) ensure_val<T0>(args[0]);
|
||||
if (required1 && args.size() > 1) ensure_val<T1>(args[1]);
|
||||
if (required2 && args.size() > 2) ensure_val<T2>(args[2]);
|
||||
}
|
||||
};
|
||||
|
||||
struct value_func_t : public value_t {
|
||||
std::string name;
|
||||
value arg0; // bound "this" argument, if any
|
||||
value_func_t(const std::string & name, const func_handler & func) : name(name) {
|
||||
val_func = func;
|
||||
}
|
||||
value_func_t(const std::string & name, const func_handler & func, const value & arg_this) : name(name), arg0(arg_this) {
|
||||
val_func = func;
|
||||
}
|
||||
virtual value invoke(const func_args & args) const override {
|
||||
func_args new_args(args); // copy
|
||||
new_args.func_name = name;
|
||||
if (arg0) {
|
||||
new_args.args.insert(new_args.args.begin(), arg0);
|
||||
}
|
||||
return val_func(new_args);
|
||||
}
|
||||
virtual std::string type() const override { return "Function"; }
|
||||
virtual std::string as_repr() const override { return type(); }
|
||||
};
|
||||
using value_func = std::shared_ptr<value_func_t>;
|
||||
|
||||
// special value for kwarg
|
||||
struct value_kwarg_t : public value_t {
|
||||
std::string key;
|
||||
|
|
@ -337,11 +339,10 @@ struct value_kwarg_t : public value_t {
|
|||
using value_kwarg = std::shared_ptr<value_kwarg_t>;
|
||||
|
||||
|
||||
const func_builtins & global_builtins();
|
||||
|
||||
|
||||
// utils
|
||||
|
||||
const func_builtins & global_builtins();
|
||||
|
||||
static inferred_type value_to_inferred_type(const value & val) {
|
||||
if (is_val<value_int>(val) || is_val<value_float>(val)) {
|
||||
return inferred_type::numeric;
|
||||
|
|
|
|||
|
|
@ -70,7 +70,7 @@ value identifier::execute_impl(context & ctx) {
|
|||
return it;
|
||||
} else if (builtins.find(val) != builtins.end()) {
|
||||
JJ_DEBUG("Identifier '%s' found in builtins", val.c_str());
|
||||
return mk_val<value_func>(builtins.at(val), val);
|
||||
return mk_val<value_func>(val, builtins.at(val));
|
||||
} else {
|
||||
JJ_DEBUG("Identifier '%s' not found, returning undefined", val.c_str());
|
||||
return mk_val<value_undefined>(val);
|
||||
|
|
@ -243,7 +243,7 @@ static value try_builtin_func(const std::string & name, const value & input, boo
|
|||
auto it = builtins.find(name);
|
||||
if (it != builtins.end()) {
|
||||
JJ_DEBUG("Binding built-in '%s'", name.c_str());
|
||||
return mk_val<value_func>(it->second, input, name);
|
||||
return mk_val<value_func>(name, it->second, input);
|
||||
}
|
||||
if (undef_on_missing) {
|
||||
return mk_val<value_undefined>(name);
|
||||
|
|
@ -607,7 +607,7 @@ value macro_statement::execute_impl(context & ctx) {
|
|||
};
|
||||
|
||||
JJ_DEBUG("Defining macro '%s' with %zu parameters", name.c_str(), args.size());
|
||||
ctx.set_val(name, mk_val<value_func>(func));
|
||||
ctx.set_val(name, mk_val<value_func>(name, func));
|
||||
return mk_val<value_null>();
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -26,7 +26,9 @@ int main(void) {
|
|||
//std::string contents = "<some_tokens> {{ messages[a]['content'] }} <another_token>";
|
||||
//std::string contents = "{% if a is not defined %}hello{% endif %}";
|
||||
|
||||
std::ifstream infile("models/templates/Qwen-Qwen3-0.6B.jinja"); std::string contents((std::istreambuf_iterator<char>(infile)), std::istreambuf_iterator<char>());
|
||||
std::ifstream infile("models/templates/Qwen-Qwen3-0.6B.jinja");
|
||||
//std::ifstream infile("models/templates/Kimi-K2-Thinking.jinja");
|
||||
std::string contents((std::istreambuf_iterator<char>(infile)), std::istreambuf_iterator<char>());
|
||||
|
||||
run_single(contents);
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue