support binded functions

This commit is contained in:
Xuan Son Nguyen 2025-12-28 15:33:14 +01:00
parent 4ca114b095
commit 45c194622e
6 changed files with 254 additions and 76 deletions

View File

@ -16,6 +16,24 @@ namespace jinja {
struct string_part {
bool is_input = false; // may skip parsing special tokens if true
std::string val;
bool is_uppercase() const {
for (char c : val) {
if (std::islower(static_cast<unsigned char>(c))) {
return false;
}
}
return true;
}
bool is_lowercase() const {
for (char c : val) {
if (std::isupper(static_cast<unsigned char>(c))) {
return false;
}
}
return true;
}
};
struct string {
@ -67,6 +85,24 @@ struct string {
return true;
}
bool is_uppercase() const {
for (const auto & part : parts) {
if (!part.is_uppercase()) {
return false;
}
}
return true;
}
bool is_lowercase() const {
for (const auto & part : parts) {
if (!part.is_lowercase()) {
return false;
}
}
return true;
}
// mark this string as input if other has ALL parts as input
void mark_input_based_on(const string & other) {
if (other.all_parts_are_input()) {

View File

@ -107,7 +107,7 @@ struct value_t {
virtual bool as_bool() const { throw std::runtime_error(type() + " is not a bool value"); }
virtual const std::vector<value> & as_array() const { throw std::runtime_error(type() + " is not an array value"); }
virtual const std::map<std::string, value> & as_object() const { throw std::runtime_error(type() + " is not an object value"); }
virtual value invoke(const func_args &) const { throw std::runtime_error("Not a function value"); }
virtual value invoke(const func_args &) const { throw std::runtime_error(type() + " is not a function value"); }
virtual bool is_null() const { return false; }
virtual bool is_undefined() const { return false; }
virtual const func_builtins & get_builtins() const {
@ -221,6 +221,9 @@ struct value_array_t : public value_t {
ss << "]";
return ss.str();
}
virtual bool as_bool() const override {
return !val_arr->empty();
}
virtual const func_builtins & get_builtins() const override;
};
using value_array = std::unique_ptr<value_array_t>;
@ -251,17 +254,44 @@ struct value_object_t : public value_t {
tmp->val_obj = this->val_obj;
return tmp;
}
virtual bool as_bool() const override {
return !val_obj->empty();
}
virtual const func_builtins & get_builtins() const override;
};
using value_object = std::unique_ptr<value_object_t>;
struct value_func_t : public value_t {
value_func_t(func_handler & func) {
std::string name; // for debugging
value arg0; // bound "this" argument, if any
value_func_t(const value_func_t & other) {
val_func = other.val_func;
name = other.name;
if (other.arg0) {
arg0 = other.arg0->clone();
}
}
value_func_t(const func_handler & func, std::string func_name = "") {
val_func = func;
name = func_name;
}
value_func_t(const func_handler & func, const value & arg_this, std::string func_name = "") {
val_func = func;
name = func_name;
arg0 = arg_this->clone();
}
virtual value invoke(const func_args & args) const override {
return val_func(args);
if (arg0) {
func_args new_args;
new_args.args.push_back(arg0->clone());
for (const auto & a : args.args) {
new_args.args.push_back(a->clone());
}
return val_func(new_args);
} else {
return val_func(args);
}
}
virtual std::string type() const override { return "Function"; }
virtual std::string as_repr() const override { return type(); }

View File

@ -8,13 +8,69 @@
namespace jinja {
template<typename T>
static value test_type_fn(const func_args & args) {
args.ensure_count(1);
bool is_type = is_val<T>(args.args[0]);
return mk_val<value_bool>(is_type);
}
template<typename T, typename U>
static value test_type_fn(const func_args & args) {
args.ensure_count(1);
bool is_type = is_val<T>(args.args[0]) || is_val<U>(args.args[0]);
return mk_val<value_bool>(is_type);
}
const func_builtins & global_builtins() {
static const func_builtins builtins = {
{"raise_exception", [](const func_args & args) -> value {
args.ensure_count(1);
args.ensure_vals<value_string>();
std::string msg = args.args[0]->as_string().str();
throw raised_exception("Jinja Exception: " + msg);
}},
// tests
{"test_is_boolean", test_type_fn<value_bool>},
{"test_is_callable", test_type_fn<value_func>},
{"test_is_odd", [](const func_args & args) -> value {
args.ensure_vals<value_int>();
int64_t val = args.args[0]->as_int();
return mk_val<value_bool>(val % 2 != 0);
}},
{"test_is_even", [](const func_args & args) -> value {
args.ensure_vals<value_int>();
int64_t val = args.args[0]->as_int();
return mk_val<value_bool>(val % 2 == 0);
}},
{"test_is_false", [](const func_args & args) -> value {
args.ensure_count(1);
bool val = is_val<value_bool>(args.args[0]) && !args.args[0]->as_bool();
return mk_val<value_bool>(val);
}},
{"test_is_true", [](const func_args & args) -> value {
args.ensure_count(1);
bool val = is_val<value_bool>(args.args[0]) && args.args[0]->as_bool();
return mk_val<value_bool>(val);
}},
{"test_is_string", test_type_fn<value_string>},
{"test_is_integer", test_type_fn<value_int>},
{"test_is_number", test_type_fn<value_int, value_float>},
{"test_is_iterable", test_type_fn<value_array, value_string>},
{"test_is_mapping", test_type_fn<value_object>},
{"test_is_lower", [](const func_args & args) -> value {
args.ensure_vals<value_string>();
return mk_val<value_bool>(args.args[0]->val_str.is_lowercase());
}},
{"test_is_upper", [](const func_args & args) -> value {
args.ensure_vals<value_string>();
return mk_val<value_bool>(args.args[0]->val_str.is_uppercase());
}},
{"test_is_none", test_type_fn<value_null>},
{"test_is_defined", [](const func_args & args) -> value {
args.ensure_count(1);
return mk_val<value_bool>(!is_val<value_undefined>(args.args[0]));
}},
{"test_is_undefined", test_type_fn<value_undefined>},
};
return builtins;
}

View File

@ -8,7 +8,7 @@
#include <memory>
#include <algorithm>
#define JJ_DEBUG(msg, ...) printf("jinja-vm: " msg "\n", __VA_ARGS__)
#define JJ_DEBUG(msg, ...) printf("jinja-vm:%3d : " msg "\n", __LINE__, __VA_ARGS__)
//#define JJ_DEBUG(msg, ...) // no-op
namespace jinja {
@ -44,7 +44,7 @@ value identifier::execute(context & ctx) {
value binary_expression::execute(context & ctx) {
value left_val = left->execute(ctx);
JJ_DEBUG("Executing binary expression with operator '%s'", op.value.c_str());
JJ_DEBUG("Executing binary expression %s '%s' %s", left_val->type().c_str(), op.value.c_str(), right->type().c_str());
// Logical operators
if (op.value == "and") {
@ -168,20 +168,19 @@ value binary_expression::execute(context & ctx) {
throw std::runtime_error("Unknown operator \"" + op.value + "\" between " + left_val->type() + " and " + right_val->type());
}
static value try_builtin_func(const std::string & name, const value & input) {
auto builtins = input->get_builtins();
auto it = builtins.find(name);
if (it != builtins.end()) {
JJ_DEBUG("Binding built-in '%s'", name.c_str());
return mk_val<value_func>(it->second, input, name);
}
throw std::runtime_error("Unknown (built-in) filter '" + name + "' for type " + input->type());
}
value filter_expression::execute(context & ctx) {
value input = operand->execute(ctx);
auto try_builtin = [&](const std::string & name) -> value {
auto builtins = input->get_builtins();
auto it = builtins.find(name);
if (it != builtins.end()) {
func_args args;
args.args.push_back(input->clone());
return it->second(args);
}
throw std::runtime_error("Unknown (built-in) filter '" + name + "' for type " + input->type());
};
if (is_stmt<identifier>(filter)) {
auto filter_val = dynamic_cast<identifier*>(filter.get())->val;
@ -190,35 +189,12 @@ value filter_expression::execute(context & ctx) {
throw std::runtime_error("to_json filter not implemented");
}
if (is_val<value_array>(input)) {
auto res = try_builtin(filter_val);
if (res) {
return res;
}
throw std::runtime_error("Unknown filter '" + filter_val + "' for array");
} else if (is_val<value_string>(input)) {
auto str = input->as_string();
auto builtins = input->get_builtins();
if (filter_val == "trim") {
filter_val = "strip"; // alias
}
auto res = try_builtin(filter_val);
if (res) {
return res;
}
throw std::runtime_error("Unknown filter '" + filter_val + "' for string");
} else if (is_val<value_int>(input) || is_val<value_float>(input)) {
auto res = try_builtin(filter_val);
if (res) {
return res;
}
throw std::runtime_error("Unknown filter '" + filter_val + "' for number");
} else {
throw std::runtime_error("Filters not supported for type " + input->type());
auto str = input->as_string();
if (filter_val == "trim") {
filter_val = "strip"; // alias
}
JJ_DEBUG("Applying filter '%s' to %s", filter_val.c_str(), input->type().c_str());
return try_builtin_func(filter_val, input);
} else if (is_stmt<call_expression>(filter)) {
// TODO
@ -230,6 +206,44 @@ value filter_expression::execute(context & ctx) {
}
}
value test_expression::execute(context & ctx) {
// NOTE: "value is something" translates to function call "test_is_something(value)"
const auto & builtins = global_builtins();
if (!is_stmt<identifier>(test)) {
throw std::runtime_error("Invalid test expression");
}
auto test_id = dynamic_cast<identifier*>(test.get())->val;
auto it = builtins.find("test_is_" + test_id);
JJ_DEBUG("Test expression %s '%s'", operand->type().c_str(), test_id.c_str());
if (it == builtins.end()) {
throw std::runtime_error("Unknown test '" + test_id + "'");
}
func_args args;
args.args.push_back(operand->execute(ctx));
return it->second(args);
}
value unary_expression::execute(context & ctx) {
value operand_val = argument->execute(ctx);
JJ_DEBUG("Executing unary expression with operator '%s'", op.value.c_str());
if (op.value == "not") {
return mk_val<value_bool>(!operand_val->as_bool());
} else if (op.value == "-") {
if (is_val<value_int>(operand_val)) {
return mk_val<value_int>(-operand_val->as_int());
} else if (is_val<value_float>(operand_val)) {
return mk_val<value_float>(-operand_val->as_float());
} else {
throw std::runtime_error("Unary - operator requires numeric operand");
}
}
throw std::runtime_error("Unknown unary operator '" + op.value + "'");
}
value if_statement::execute(context & ctx) {
value test_val = test->execute(ctx);
auto out = mk_val<value_array>();
@ -415,16 +429,46 @@ value set_statement::execute(context & ctx) {
return mk_val<value_null>();
}
value macro_statement::execute(context & ctx) {
std::string name = dynamic_cast<identifier*>(this->name.get())->val;
const func_handler func = [this, &ctx, name](const func_args & args) -> value {
JJ_DEBUG("Invoking macro '%s' with %zu arguments", name.c_str(), args.args.size());
context macro_ctx(ctx); // new scope for macro execution
// bind parameters
size_t param_count = this->args.size();
size_t arg_count = args.args.size();
for (size_t i = 0; i < param_count; ++i) {
std::string param_name = dynamic_cast<identifier*>(this->args[i].get())->val;
if (i < arg_count) {
macro_ctx.var[param_name] = args.args[i]->clone();
} else {
macro_ctx.var[param_name] = mk_val<value_undefined>();
}
}
// execute macro body
return exec_statements(this->body, macro_ctx);
};
JJ_DEBUG("Defining macro '%s' with %zu parameters", name.c_str(), args.size());
ctx.var[name] = mk_val<value_func>(func);
return mk_val<value_null>();
}
value member_expression::execute(context & ctx) {
value object = this->object->execute(ctx);
value property;
if (this->computed) {
JJ_DEBUG("Member expression, computing property type %s", this->property->type().c_str());
property = this->property->execute(ctx);
} else {
property = mk_val<value_string>(dynamic_cast<identifier*>(this->property.get())->val);
}
JJ_DEBUG("Member expression on object type %s, property type %s", object->type().c_str(), property->type().c_str());
value val = mk_val<value_undefined>();
if (is_val<value_object>(object)) {
@ -432,18 +476,13 @@ value member_expression::execute(context & ctx) {
throw std::runtime_error("Cannot access object with non-string: got " + property->type());
}
auto key = property->as_string().str();
JJ_DEBUG("Accessing object property '%s'", key.c_str());
auto & obj = object->as_object();
auto it = obj.find(key);
if (it != obj.end()) {
val = it->second->clone();
} else {
auto builtins = object->get_builtins();
auto bit = builtins.find(key);
if (bit != builtins.end()) {
func_args args;
args.args.push_back(object->clone());
val = bit->second(args);
}
val = try_builtin_func(key, object);
}
} else if (is_val<value_array>(object) || is_val<value_string>(object)) {
@ -464,13 +503,7 @@ value member_expression::execute(context & ctx) {
} else if (is_val<value_string>(property)) {
auto key = property->as_string().str();
JJ_DEBUG("Accessing %s built-in '%s'", is_val<value_array>(object) ? "array" : "string", key.c_str());
auto builtins = object->get_builtins();
auto bit = builtins.find(key);
if (bit != builtins.end()) {
func_args args;
args.args.push_back(object->clone());
val = bit->second(args);
}
val = try_builtin_func(key, object);
} else {
throw std::runtime_error("Cannot access property with non-string/non-number: got " + property->type());
}
@ -480,13 +513,7 @@ value member_expression::execute(context & ctx) {
throw std::runtime_error("Cannot access property with non-string: got " + property->type());
}
auto key = property->as_string().str();
auto builtins = object->get_builtins();
auto bit = builtins.find(key);
if (bit != builtins.end()) {
func_args args;
args.args.push_back(object->clone());
val = bit->second(args);
}
val = try_builtin_func(key, object);
}
return val;

View File

@ -166,12 +166,16 @@ struct macro_statement : public statement {
}
std::string type() const override { return "Macro"; }
value execute(context & ctx) override;
};
struct comment_statement : public statement {
std::string val;
explicit comment_statement(const std::string & v) : val(v) {}
std::string type() const override { return "Comment"; }
value execute(context &) override {
return mk_val<value_null>();
}
};
// Expressions
@ -339,6 +343,7 @@ struct select_expression : public expression {
/**
* An operation with two sides, separated by the "is" operator.
* NOTE: "value is something" translates to function call "test_is_something(value)"
*/
struct test_expression : public expression {
statement_ptr operand;
@ -351,6 +356,7 @@ struct test_expression : public expression {
chk_type<identifier>(this->test);
}
std::string type() const override { return "TestExpression"; }
value execute(context & ctx) override;
};
/**
@ -365,6 +371,7 @@ struct unary_expression : public expression {
chk_type<expression>(this->argument);
}
std::string type() const override { return "UnaryExpression"; }
value execute(context & ctx) override;
};
struct slice_expression : public expression {
@ -442,14 +449,34 @@ struct vm {
context & ctx;
explicit vm(context & ctx) : ctx(ctx) {}
std::vector<value> execute(program & prog) {
std::vector<value> results;
value_array execute(program & prog) {
value_array results = mk_val<value_array>();
for (auto & stmt : prog.body) {
value res = stmt->execute(ctx);
results.push_back(std::move(res));
results->val_arr->push_back(std::move(res));
}
return results;
}
std::vector<jinja::string_part> gather_string_parts(const value & val) {
std::vector<jinja::string_part> parts;
gather_string_parts_recursive(val, parts);
return parts;
}
void gather_string_parts_recursive(const value & val, std::vector<jinja::string_part> & parts) {
if (is_val<value_string>(val)) {
const auto & str_val = dynamic_cast<value_string_t*>(val.get())->val_str;
for (const auto & part : str_val.parts) {
parts.push_back(part);
}
} else if (is_val<value_array>(val)) {
auto items = dynamic_cast<value_array_t*>(val.get())->val_arr.get();
for (const auto & item : *items) {
gather_string_parts_recursive(item, parts);
}
}
}
};
} // namespace jinja

View File

@ -3,6 +3,7 @@
#include <sstream>
#include <regex>
#include <iostream>
#include <fstream>
#undef NDEBUG
#include <cassert>
@ -11,12 +12,15 @@
#include "jinja/jinja-lexer.h"
int main(void) {
std::string contents = "{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '<start_of_turn>' + role + '\\n' + message['content'] | trim + '<end_of_turn>\\n' }}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\\n'}}{% endif %}";
//std::string contents = "{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '<start_of_turn>' + role + '\\n' + message['content'] | trim + '<end_of_turn>\\n' }}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\\n'}}{% endif %}";
//std::string contents = "{% if messages[0]['role'] != 'system' %}nice {{ messages[0]['content'] }}{% endif %}";
//std::string contents = "<some_tokens> {{ messages[0]['content'] }} <another_token>";
std::ifstream infile("models/templates/moonshotai-Kimi-K2.jinja");
std::string contents((std::istreambuf_iterator<char>(infile)), std::istreambuf_iterator<char>());
std::cout << "=== INPUT ===\n" << contents << "\n\n";
jinja::lexer lexer;
@ -56,14 +60,12 @@ int main(void) {
ctx.var["messages"] = std::move(messages);
jinja::vm vm(ctx);
auto results = vm.execute(ast);
const jinja::value results = vm.execute(ast);
auto parts = vm.gather_string_parts(results);
std::cout << "\n=== RESULTS ===\n";
for (const auto & res : results) {
if (res->is_null()) {
continue;
}
std::cout << "result type: " << res->type() << " | value: " << res->as_repr();
for (const auto & part : parts) {
std::cout << (part.is_input ? "DATA" : "TMPL") << ": " << part.val << "\n";
}
return 0;