From e858b7a0a30fc4f2cb2b5e6ee5adc7210c87d8b1 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Fri, 2 Jan 2026 16:28:04 +0100 Subject: [PATCH] add minimal caps system --- common/jinja/jinja-caps.h | 159 +++++++++++++++++++++++++++++++++++ common/jinja/jinja-value.cpp | 7 +- common/jinja/jinja-value.h | 10 +++ common/jinja/jinja-vm.cpp | 41 +++++++-- common/jinja/jinja-vm.h | 4 + tests/test-chat-jinja.cpp | 11 +-- 6 files changed, 219 insertions(+), 13 deletions(-) create mode 100644 common/jinja/jinja-caps.h diff --git a/common/jinja/jinja-caps.h b/common/jinja/jinja-caps.h new file mode 100644 index 0000000000..eca5782903 --- /dev/null +++ b/common/jinja/jinja-caps.h @@ -0,0 +1,159 @@ +#pragma once + +#include + +#include "jinja-value.h" +#include "jinja-vm.h" + +#define FILENAME "jinja-caps" + +namespace jinja { + +struct caps { + bool content_string = true; + bool content_array = true; +}; + +using caps_messages_fn = std::function; +using caps_analyze_fn = std::function; +static void caps_try_execute(jinja::program & prog, + caps_messages_fn messages_fn, + caps_messages_fn tools_fn, + caps_analyze_fn analyze_fn) { + context ctx; + ctx.is_get_stats = true; + + value messages = messages_fn(); + value tools = tools_fn(); + + ctx.set_val("messages", messages); + ctx.set_val("tools", tools); + ctx.set_val("add_generation_prompt", mk_val(true)); + + bool success = false; + try { + jinja::vm vm(ctx); + vm.execute(prog); + success = true; + } catch (const std::exception & e) { + JJ_DEBUG("Exception during execution: %s", e.what()); + // ignore exceptions during capability analysis + } + return analyze_fn(success, messages, tools); +} + +// for debugging only +static void caps_print_stats(value & v, std::string path) { + std::string ops; + for (const auto & name : v->stats.ops) { + ops += name + " "; + } + JJ_DEBUG("Value %s, type: %s %s, ops: %s", + path.c_str(), + v->type().c_str(), + v->stats.used ? "(used)" : "", + ops.c_str()); +} + +static caps caps_get(jinja::program & prog) { + caps result; + + static const auto has_op = [](value & v, const std::string & op_name) { + return v->stats.ops.find(op_name) != v->stats.ops.end(); + }; + + // case: given content as string, check if it's accessed as array + caps_try_execute( + prog, + [&]() { + auto messages = mk_val(); + { + value_object msg = mk_val(); + msg->insert("role", mk_val("user")); + msg->insert("content", mk_val("User message")); + messages->push_back(msg); + } + return messages; + }, + [&]() { + return mk_val(); + }, + [&](bool, value & messages, value &) { + auto & content = messages->at(0)->at("content"); + caps_print_stats(content, "messages[0].content"); + if (has_op(content, "selectattr") || has_op(content, "array_access")) { + // accessed as an array + JJ_DEBUG("%s", "Force content as array"); + result.content_string = false; + result.content_array = true; + } + } + ); + + // case: given content as array, check if it's supported or not + caps_try_execute( + prog, + [&]() { + auto messages = mk_val(); + { + value_object msg = mk_val(); + msg->insert("role", mk_val("user")); + value_array content_arr = mk_val(); + { + value_object content_part = mk_val(); + content_part->insert("type", mk_val("text")); + content_part->insert("text", mk_val("User message")); + content_arr->push_back(content_part); + } + msg->insert("content", content_arr); + messages->push_back(msg); + } + return messages; + }, + [&]() { + return mk_val(); + }, + [&](bool success, value & messages, value &) { + auto & content = messages->at(0)->at("content"); + caps_print_stats(content, "messages[0].content"); + if (!success) { + JJ_DEBUG("%s", "Cannot handle content as array"); + result.content_array = false; + } + } + ); + + return result; +} + +static void caps_apply_workarounds(context & ctx, const caps & c) { + auto messages = ctx.get_val("messages"); + + if (!is_val(messages)) { + throw std::runtime_error("Expected messages to be an array"); + } + + if (!c.content_string) { + for (auto & msg : messages->val_arr) { + if (!is_val(msg)) { + throw std::runtime_error("Expected messages[i] to be an object"); + } + auto obj_ptr = cast_val(msg); + auto & content = obj_ptr->at("content"); + if (!is_val(content)) { + JJ_DEBUG("%s", "Converting message content to array"); + auto str_content = content->as_string(); + value_array arr_content = mk_val(); + value_object content_part = mk_val(); + content_part->insert("type", mk_val("text")); + content_part->insert("text", mk_val(str_content)); + arr_content->push_back(content_part); + obj_ptr->insert("content", arr_content); + } + } + } + + ctx.set_val("messages", messages); +} + +} // namespace jinja diff --git a/common/jinja/jinja-value.cpp b/common/jinja/jinja-value.cpp index 270caafede..4da4584e23 100644 --- a/common/jinja/jinja-value.cpp +++ b/common/jinja/jinja-value.cpp @@ -12,7 +12,7 @@ #include #include -#define FILENAME "jinja-vm-builtins" +#define FILENAME "jinja-value" namespace jinja { @@ -408,6 +408,11 @@ const func_builtins & value_string_t::get_builtins() const { res->val_str.mark_input_based_on(input->as_string()); return res; }}, + {"safe", [](const func_args & args) -> value { + // no-op for now + args.ensure_vals(); + return args.args[0]; + }}, {"selectattr", [](const func_args &) -> value { throw std::runtime_error("String selectattr builtin not supported"); }}, diff --git a/common/jinja/jinja-value.h b/common/jinja/jinja-value.h index 6483d460a3..9cb57f90f3 100644 --- a/common/jinja/jinja-value.h +++ b/common/jinja/jinja-value.h @@ -107,6 +107,13 @@ struct value_t { func_handler val_func; + // only used if ctx.is_get_stats = true + struct stats_t { + bool used = false; + // ops can be builtin calls or operators: "array_access", "object_access" + std::set ops; + } stats; + value_t() = default; value_t(const value_t &) = default; virtual ~value_t() = default; @@ -126,6 +133,9 @@ struct value_t { throw std::runtime_error("No builtins available for type " + type()); } + virtual value & at(const std::string & key) { return val_obj[key]; } + virtual value & at(size_t index) { return val_arr.at(index); } + virtual std::string as_repr() const { return as_string().str(); } }; diff --git a/common/jinja/jinja-vm.cpp b/common/jinja/jinja-vm.cpp index 076e041ef4..0728054c13 100644 --- a/common/jinja/jinja-vm.cpp +++ b/common/jinja/jinja-vm.cpp @@ -66,6 +66,9 @@ value identifier::execute_impl(context & ctx) { auto it = ctx.get_val(val); auto builtins = global_builtins(); if (!it->is_undefined()) { + if (ctx.is_get_stats) { + it->stats.used = true; + } JJ_DEBUG("Identifier '%s' found", val.c_str()); return it; } else if (builtins.find(val) != builtins.end()) { @@ -236,7 +239,12 @@ value binary_expression::execute_impl(context & ctx) { throw std::runtime_error("Unknown operator \"" + op.value + "\" between " + left_val->type() + " and " + right_val->type()); } -static value try_builtin_func(const std::string & name, const value & input, bool undef_on_missing = false) { +static value try_builtin_func(context & ctx, const std::string & name, value & input, bool undef_on_missing = false) { + JJ_DEBUG("Trying built-in function '%s' for type %s", name.c_str(), input->type().c_str()); + if (ctx.is_get_stats) { + input->stats.used = true; + input->stats.ops.insert(name); + } auto builtins = input->get_builtins(); auto it = builtins.find(name); if (it != builtins.end()) { @@ -266,7 +274,7 @@ value filter_expression::execute_impl(context & ctx) { filter_id = "strip"; // alias } JJ_DEBUG("Applying filter '%s' to %s", filter_id.c_str(), input->type().c_str()); - return try_builtin_func(filter_id, input)->invoke(func_args(ctx)); + return try_builtin_func(ctx, filter_id, input)->invoke(func_args(ctx)); } else if (is_stmt(filter)) { auto call = cast_stmt(filter); @@ -278,7 +286,7 @@ value filter_expression::execute_impl(context & ctx) { args.args.push_back(arg_expr->execute(ctx)); } - return try_builtin_func(filter_id, input)->invoke(args); + return try_builtin_func(ctx, filter_id, input)->invoke(args); } else { throw std::runtime_error("Invalid filter expression"); @@ -401,12 +409,20 @@ value for_statement::execute_impl(context & ctx) { tuple->push_back(p.second); items.push_back(tuple); } + if (ctx.is_get_stats) { + iterable_val->stats.used = true; + iterable_val->stats.ops.insert("object_access"); + } } else { JJ_DEBUG("%s", "For loop over array items"); auto & arr = iterable_val->as_array(); for (const auto & item : arr) { items.push_back(item); } + if (ctx.is_get_stats) { + iterable_val->stats.used = true; + iterable_val->stats.ops.insert("array_access"); + } } std::vector> scope_update_fns; @@ -624,7 +640,7 @@ value member_expression::execute_impl(context & ctx) { start_val->as_repr().c_str(), stop_val->as_repr().c_str(), step_val->as_repr().c_str()); - auto slice_func = try_builtin_func("slice", object); + auto slice_func = try_builtin_func(ctx, "slice", object); func_args args(ctx); args.args.push_back(start_val); args.args.push_back(stop_val); @@ -654,7 +670,7 @@ value member_expression::execute_impl(context & ctx) { if (it != obj.end()) { val = it->second; } else { - val = try_builtin_func(key, object, true); + val = try_builtin_func(ctx, key, object, true); } JJ_DEBUG("Accessed property '%s' value, got type: %s", key.c_str(), val->type().c_str()); @@ -676,10 +692,11 @@ value member_expression::execute_impl(context & ctx) { val = mk_val(std::string(1, str[index])); } } + } else if (is_val(property)) { auto key = property->as_string().str(); JJ_DEBUG("Accessing %s built-in '%s'", is_val(object) ? "array" : "string", key.c_str()); - val = try_builtin_func(key, object); + val = try_builtin_func(ctx, key, object); } else { throw std::runtime_error("Cannot access property with non-string/non-number: got " + property->type()); } @@ -689,7 +706,17 @@ value member_expression::execute_impl(context & ctx) { throw std::runtime_error("Cannot access property with non-string: got " + property->type()); } auto key = property->as_string().str(); - val = try_builtin_func(key, object); + val = try_builtin_func(ctx, key, object); + } + + if (ctx.is_get_stats && val && object && property) { + val->stats.used = true; + object->stats.used = true; + if (is_val(property)) { + object->stats.ops.insert("array_access"); + } else if (is_val(property)) { + object->stats.ops.insert("object_access"); + } } return val; diff --git a/common/jinja/jinja-vm.h b/common/jinja/jinja-vm.h index c1f91dd81f..099111db46 100644 --- a/common/jinja/jinja-vm.h +++ b/common/jinja/jinja-vm.h @@ -50,6 +50,8 @@ struct context { std::string source; // for debugging std::time_t current_time; // for functions that need current time + bool is_get_stats = false; // whether to collect stats + context() { global = mk_val(); global->insert("true", mk_val(true)); @@ -65,6 +67,8 @@ struct context { for (const auto & pair : pvar) { set_val(pair.first, pair.second); } + current_time = parent.current_time; + is_get_stats = parent.is_get_stats; } value get_val(const std::string & name) { diff --git a/tests/test-chat-jinja.cpp b/tests/test-chat-jinja.cpp index b22c8a56d5..91a7b3ff87 100644 --- a/tests/test-chat-jinja.cpp +++ b/tests/test-chat-jinja.cpp @@ -13,6 +13,7 @@ #include "jinja/jinja-parser.h" #include "jinja/jinja-lexer.h" +#include "jinja/jinja-caps.h" using json = nlohmann::json; @@ -38,11 +39,7 @@ std::string DEFAULT_JSON = R"({ }, { "role": "assistant", - "content": {"__input__": "I am fine, thank you!"} - }, - { - "role": "assistant", - "content": "Calling weather tool.", + "content": {"__input__": "I am fine, thank you!"}, "tool_calls": [ { "function": { @@ -177,11 +174,15 @@ void run_single(std::string contents, json input, const std::string & output_pat // compile to AST jinja::program ast = jinja::parse_from_tokens(lexer_res); + // check caps for workarounds + auto caps = jinja::caps_get(ast); + std::cout << "\n=== RUN ===\n"; jinja::context ctx; ctx.source = lexer_res.preprocessed_source; jinja::global_from_json(ctx, input); + jinja::caps_apply_workarounds(ctx, caps); jinja::vm vm(ctx); const jinja::value results = vm.execute(ast);