add minimal caps system

This commit is contained in:
Xuan Son Nguyen 2026-01-02 16:28:04 +01:00
parent dce256cf40
commit e858b7a0a3
6 changed files with 219 additions and 13 deletions

159
common/jinja/jinja-caps.h Normal file
View File

@ -0,0 +1,159 @@
#pragma once
#include <vector>
#include "jinja-value.h"
#include "jinja-vm.h"
#define FILENAME "jinja-caps"
namespace jinja {
struct caps {
bool content_string = true;
bool content_array = true;
};
using caps_messages_fn = std::function<value()>;
using caps_analyze_fn = std::function<void(bool, value &, value &)>;
static void caps_try_execute(jinja::program & prog,
caps_messages_fn messages_fn,
caps_messages_fn tools_fn,
caps_analyze_fn analyze_fn) {
context ctx;
ctx.is_get_stats = true;
value messages = messages_fn();
value tools = tools_fn();
ctx.set_val("messages", messages);
ctx.set_val("tools", tools);
ctx.set_val("add_generation_prompt", mk_val<value_bool>(true));
bool success = false;
try {
jinja::vm vm(ctx);
vm.execute(prog);
success = true;
} catch (const std::exception & e) {
JJ_DEBUG("Exception during execution: %s", e.what());
// ignore exceptions during capability analysis
}
return analyze_fn(success, messages, tools);
}
// for debugging only
static void caps_print_stats(value & v, std::string path) {
std::string ops;
for (const auto & name : v->stats.ops) {
ops += name + " ";
}
JJ_DEBUG("Value %s, type: %s %s, ops: %s",
path.c_str(),
v->type().c_str(),
v->stats.used ? "(used)" : "",
ops.c_str());
}
static caps caps_get(jinja::program & prog) {
caps result;
static const auto has_op = [](value & v, const std::string & op_name) {
return v->stats.ops.find(op_name) != v->stats.ops.end();
};
// case: given content as string, check if it's accessed as array
caps_try_execute(
prog,
[&]() {
auto messages = mk_val<value_array>();
{
value_object msg = mk_val<value_object>();
msg->insert("role", mk_val<value_string>("user"));
msg->insert("content", mk_val<value_string>("User message"));
messages->push_back(msg);
}
return messages;
},
[&]() {
return mk_val<value_array>();
},
[&](bool, value & messages, value &) {
auto & content = messages->at(0)->at("content");
caps_print_stats(content, "messages[0].content");
if (has_op(content, "selectattr") || has_op(content, "array_access")) {
// accessed as an array
JJ_DEBUG("%s", "Force content as array");
result.content_string = false;
result.content_array = true;
}
}
);
// case: given content as array, check if it's supported or not
caps_try_execute(
prog,
[&]() {
auto messages = mk_val<value_array>();
{
value_object msg = mk_val<value_object>();
msg->insert("role", mk_val<value_string>("user"));
value_array content_arr = mk_val<value_array>();
{
value_object content_part = mk_val<value_object>();
content_part->insert("type", mk_val<value_string>("text"));
content_part->insert("text", mk_val<value_string>("User message"));
content_arr->push_back(content_part);
}
msg->insert("content", content_arr);
messages->push_back(msg);
}
return messages;
},
[&]() {
return mk_val<value_array>();
},
[&](bool success, value & messages, value &) {
auto & content = messages->at(0)->at("content");
caps_print_stats(content, "messages[0].content");
if (!success) {
JJ_DEBUG("%s", "Cannot handle content as array");
result.content_array = false;
}
}
);
return result;
}
static void caps_apply_workarounds(context & ctx, const caps & c) {
auto messages = ctx.get_val("messages");
if (!is_val<value_array>(messages)) {
throw std::runtime_error("Expected messages to be an array");
}
if (!c.content_string) {
for (auto & msg : messages->val_arr) {
if (!is_val<value_object>(msg)) {
throw std::runtime_error("Expected messages[i] to be an object");
}
auto obj_ptr = cast_val<value_object>(msg);
auto & content = obj_ptr->at("content");
if (!is_val<value_array>(content)) {
JJ_DEBUG("%s", "Converting message content to array");
auto str_content = content->as_string();
value_array arr_content = mk_val<value_array>();
value_object content_part = mk_val<value_object>();
content_part->insert("type", mk_val<value_string>("text"));
content_part->insert("text", mk_val<value_string>(str_content));
arr_content->push_back(content_part);
obj_ptr->insert("content", arr_content);
}
}
}
ctx.set_val("messages", messages);
}
} // namespace jinja

View File

@ -12,7 +12,7 @@
#include <optional>
#include <algorithm>
#define FILENAME "jinja-vm-builtins"
#define FILENAME "jinja-value"
namespace jinja {
@ -408,6 +408,11 @@ const func_builtins & value_string_t::get_builtins() const {
res->val_str.mark_input_based_on(input->as_string());
return res;
}},
{"safe", [](const func_args & args) -> value {
// no-op for now
args.ensure_vals<value_string>();
return args.args[0];
}},
{"selectattr", [](const func_args &) -> value {
throw std::runtime_error("String selectattr builtin not supported");
}},

View File

@ -107,6 +107,13 @@ struct value_t {
func_handler val_func;
// only used if ctx.is_get_stats = true
struct stats_t {
bool used = false;
// ops can be builtin calls or operators: "array_access", "object_access"
std::set<std::string> ops;
} stats;
value_t() = default;
value_t(const value_t &) = default;
virtual ~value_t() = default;
@ -126,6 +133,9 @@ struct value_t {
throw std::runtime_error("No builtins available for type " + type());
}
virtual value & at(const std::string & key) { return val_obj[key]; }
virtual value & at(size_t index) { return val_arr.at(index); }
virtual std::string as_repr() const { return as_string().str(); }
};

View File

@ -66,6 +66,9 @@ value identifier::execute_impl(context & ctx) {
auto it = ctx.get_val(val);
auto builtins = global_builtins();
if (!it->is_undefined()) {
if (ctx.is_get_stats) {
it->stats.used = true;
}
JJ_DEBUG("Identifier '%s' found", val.c_str());
return it;
} else if (builtins.find(val) != builtins.end()) {
@ -236,7 +239,12 @@ value binary_expression::execute_impl(context & ctx) {
throw std::runtime_error("Unknown operator \"" + op.value + "\" between " + left_val->type() + " and " + right_val->type());
}
static value try_builtin_func(const std::string & name, const value & input, bool undef_on_missing = false) {
static value try_builtin_func(context & ctx, const std::string & name, value & input, bool undef_on_missing = false) {
JJ_DEBUG("Trying built-in function '%s' for type %s", name.c_str(), input->type().c_str());
if (ctx.is_get_stats) {
input->stats.used = true;
input->stats.ops.insert(name);
}
auto builtins = input->get_builtins();
auto it = builtins.find(name);
if (it != builtins.end()) {
@ -266,7 +274,7 @@ value filter_expression::execute_impl(context & ctx) {
filter_id = "strip"; // alias
}
JJ_DEBUG("Applying filter '%s' to %s", filter_id.c_str(), input->type().c_str());
return try_builtin_func(filter_id, input)->invoke(func_args(ctx));
return try_builtin_func(ctx, filter_id, input)->invoke(func_args(ctx));
} else if (is_stmt<call_expression>(filter)) {
auto call = cast_stmt<call_expression>(filter);
@ -278,7 +286,7 @@ value filter_expression::execute_impl(context & ctx) {
args.args.push_back(arg_expr->execute(ctx));
}
return try_builtin_func(filter_id, input)->invoke(args);
return try_builtin_func(ctx, filter_id, input)->invoke(args);
} else {
throw std::runtime_error("Invalid filter expression");
@ -401,12 +409,20 @@ value for_statement::execute_impl(context & ctx) {
tuple->push_back(p.second);
items.push_back(tuple);
}
if (ctx.is_get_stats) {
iterable_val->stats.used = true;
iterable_val->stats.ops.insert("object_access");
}
} else {
JJ_DEBUG("%s", "For loop over array items");
auto & arr = iterable_val->as_array();
for (const auto & item : arr) {
items.push_back(item);
}
if (ctx.is_get_stats) {
iterable_val->stats.used = true;
iterable_val->stats.ops.insert("array_access");
}
}
std::vector<std::function<void(context &)>> scope_update_fns;
@ -624,7 +640,7 @@ value member_expression::execute_impl(context & ctx) {
start_val->as_repr().c_str(),
stop_val->as_repr().c_str(),
step_val->as_repr().c_str());
auto slice_func = try_builtin_func("slice", object);
auto slice_func = try_builtin_func(ctx, "slice", object);
func_args args(ctx);
args.args.push_back(start_val);
args.args.push_back(stop_val);
@ -654,7 +670,7 @@ value member_expression::execute_impl(context & ctx) {
if (it != obj.end()) {
val = it->second;
} else {
val = try_builtin_func(key, object, true);
val = try_builtin_func(ctx, key, object, true);
}
JJ_DEBUG("Accessed property '%s' value, got type: %s", key.c_str(), val->type().c_str());
@ -676,10 +692,11 @@ value member_expression::execute_impl(context & ctx) {
val = mk_val<value_string>(std::string(1, str[index]));
}
}
} else if (is_val<value_string>(property)) {
auto key = property->as_string().str();
JJ_DEBUG("Accessing %s built-in '%s'", is_val<value_array>(object) ? "array" : "string", key.c_str());
val = try_builtin_func(key, object);
val = try_builtin_func(ctx, key, object);
} else {
throw std::runtime_error("Cannot access property with non-string/non-number: got " + property->type());
}
@ -689,7 +706,17 @@ value member_expression::execute_impl(context & ctx) {
throw std::runtime_error("Cannot access property with non-string: got " + property->type());
}
auto key = property->as_string().str();
val = try_builtin_func(key, object);
val = try_builtin_func(ctx, key, object);
}
if (ctx.is_get_stats && val && object && property) {
val->stats.used = true;
object->stats.used = true;
if (is_val<value_int>(property)) {
object->stats.ops.insert("array_access");
} else if (is_val<value_string>(property)) {
object->stats.ops.insert("object_access");
}
}
return val;

View File

@ -50,6 +50,8 @@ struct context {
std::string source; // for debugging
std::time_t current_time; // for functions that need current time
bool is_get_stats = false; // whether to collect stats
context() {
global = mk_val<value_object>();
global->insert("true", mk_val<value_bool>(true));
@ -65,6 +67,8 @@ struct context {
for (const auto & pair : pvar) {
set_val(pair.first, pair.second);
}
current_time = parent.current_time;
is_get_stats = parent.is_get_stats;
}
value get_val(const std::string & name) {

View File

@ -13,6 +13,7 @@
#include "jinja/jinja-parser.h"
#include "jinja/jinja-lexer.h"
#include "jinja/jinja-caps.h"
using json = nlohmann::json;
@ -38,11 +39,7 @@ std::string DEFAULT_JSON = R"({
},
{
"role": "assistant",
"content": {"__input__": "I am fine, thank you!"}
},
{
"role": "assistant",
"content": "Calling weather tool.",
"content": {"__input__": "I am fine, thank you!"},
"tool_calls": [
{
"function": {
@ -177,11 +174,15 @@ void run_single(std::string contents, json input, const std::string & output_pat
// compile to AST
jinja::program ast = jinja::parse_from_tokens(lexer_res);
// check caps for workarounds
auto caps = jinja::caps_get(ast);
std::cout << "\n=== RUN ===\n";
jinja::context ctx;
ctx.source = lexer_res.preprocessed_source;
jinja::global_from_json(ctx, input);
jinja::caps_apply_workarounds(ctx, caps);
jinja::vm vm(ctx);
const jinja::value results = vm.execute(ast);