add minimal caps system
This commit is contained in:
parent
dce256cf40
commit
e858b7a0a3
|
|
@ -0,0 +1,159 @@
|
|||
#pragma once
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "jinja-value.h"
|
||||
#include "jinja-vm.h"
|
||||
|
||||
#define FILENAME "jinja-caps"
|
||||
|
||||
namespace jinja {
|
||||
|
||||
struct caps {
|
||||
bool content_string = true;
|
||||
bool content_array = true;
|
||||
};
|
||||
|
||||
using caps_messages_fn = std::function<value()>;
|
||||
using caps_analyze_fn = std::function<void(bool, value &, value &)>;
|
||||
static void caps_try_execute(jinja::program & prog,
|
||||
caps_messages_fn messages_fn,
|
||||
caps_messages_fn tools_fn,
|
||||
caps_analyze_fn analyze_fn) {
|
||||
context ctx;
|
||||
ctx.is_get_stats = true;
|
||||
|
||||
value messages = messages_fn();
|
||||
value tools = tools_fn();
|
||||
|
||||
ctx.set_val("messages", messages);
|
||||
ctx.set_val("tools", tools);
|
||||
ctx.set_val("add_generation_prompt", mk_val<value_bool>(true));
|
||||
|
||||
bool success = false;
|
||||
try {
|
||||
jinja::vm vm(ctx);
|
||||
vm.execute(prog);
|
||||
success = true;
|
||||
} catch (const std::exception & e) {
|
||||
JJ_DEBUG("Exception during execution: %s", e.what());
|
||||
// ignore exceptions during capability analysis
|
||||
}
|
||||
return analyze_fn(success, messages, tools);
|
||||
}
|
||||
|
||||
// for debugging only
|
||||
static void caps_print_stats(value & v, std::string path) {
|
||||
std::string ops;
|
||||
for (const auto & name : v->stats.ops) {
|
||||
ops += name + " ";
|
||||
}
|
||||
JJ_DEBUG("Value %s, type: %s %s, ops: %s",
|
||||
path.c_str(),
|
||||
v->type().c_str(),
|
||||
v->stats.used ? "(used)" : "",
|
||||
ops.c_str());
|
||||
}
|
||||
|
||||
static caps caps_get(jinja::program & prog) {
|
||||
caps result;
|
||||
|
||||
static const auto has_op = [](value & v, const std::string & op_name) {
|
||||
return v->stats.ops.find(op_name) != v->stats.ops.end();
|
||||
};
|
||||
|
||||
// case: given content as string, check if it's accessed as array
|
||||
caps_try_execute(
|
||||
prog,
|
||||
[&]() {
|
||||
auto messages = mk_val<value_array>();
|
||||
{
|
||||
value_object msg = mk_val<value_object>();
|
||||
msg->insert("role", mk_val<value_string>("user"));
|
||||
msg->insert("content", mk_val<value_string>("User message"));
|
||||
messages->push_back(msg);
|
||||
}
|
||||
return messages;
|
||||
},
|
||||
[&]() {
|
||||
return mk_val<value_array>();
|
||||
},
|
||||
[&](bool, value & messages, value &) {
|
||||
auto & content = messages->at(0)->at("content");
|
||||
caps_print_stats(content, "messages[0].content");
|
||||
if (has_op(content, "selectattr") || has_op(content, "array_access")) {
|
||||
// accessed as an array
|
||||
JJ_DEBUG("%s", "Force content as array");
|
||||
result.content_string = false;
|
||||
result.content_array = true;
|
||||
}
|
||||
}
|
||||
);
|
||||
|
||||
// case: given content as array, check if it's supported or not
|
||||
caps_try_execute(
|
||||
prog,
|
||||
[&]() {
|
||||
auto messages = mk_val<value_array>();
|
||||
{
|
||||
value_object msg = mk_val<value_object>();
|
||||
msg->insert("role", mk_val<value_string>("user"));
|
||||
value_array content_arr = mk_val<value_array>();
|
||||
{
|
||||
value_object content_part = mk_val<value_object>();
|
||||
content_part->insert("type", mk_val<value_string>("text"));
|
||||
content_part->insert("text", mk_val<value_string>("User message"));
|
||||
content_arr->push_back(content_part);
|
||||
}
|
||||
msg->insert("content", content_arr);
|
||||
messages->push_back(msg);
|
||||
}
|
||||
return messages;
|
||||
},
|
||||
[&]() {
|
||||
return mk_val<value_array>();
|
||||
},
|
||||
[&](bool success, value & messages, value &) {
|
||||
auto & content = messages->at(0)->at("content");
|
||||
caps_print_stats(content, "messages[0].content");
|
||||
if (!success) {
|
||||
JJ_DEBUG("%s", "Cannot handle content as array");
|
||||
result.content_array = false;
|
||||
}
|
||||
}
|
||||
);
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
static void caps_apply_workarounds(context & ctx, const caps & c) {
|
||||
auto messages = ctx.get_val("messages");
|
||||
|
||||
if (!is_val<value_array>(messages)) {
|
||||
throw std::runtime_error("Expected messages to be an array");
|
||||
}
|
||||
|
||||
if (!c.content_string) {
|
||||
for (auto & msg : messages->val_arr) {
|
||||
if (!is_val<value_object>(msg)) {
|
||||
throw std::runtime_error("Expected messages[i] to be an object");
|
||||
}
|
||||
auto obj_ptr = cast_val<value_object>(msg);
|
||||
auto & content = obj_ptr->at("content");
|
||||
if (!is_val<value_array>(content)) {
|
||||
JJ_DEBUG("%s", "Converting message content to array");
|
||||
auto str_content = content->as_string();
|
||||
value_array arr_content = mk_val<value_array>();
|
||||
value_object content_part = mk_val<value_object>();
|
||||
content_part->insert("type", mk_val<value_string>("text"));
|
||||
content_part->insert("text", mk_val<value_string>(str_content));
|
||||
arr_content->push_back(content_part);
|
||||
obj_ptr->insert("content", arr_content);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ctx.set_val("messages", messages);
|
||||
}
|
||||
|
||||
} // namespace jinja
|
||||
|
|
@ -12,7 +12,7 @@
|
|||
#include <optional>
|
||||
#include <algorithm>
|
||||
|
||||
#define FILENAME "jinja-vm-builtins"
|
||||
#define FILENAME "jinja-value"
|
||||
|
||||
namespace jinja {
|
||||
|
||||
|
|
@ -408,6 +408,11 @@ const func_builtins & value_string_t::get_builtins() const {
|
|||
res->val_str.mark_input_based_on(input->as_string());
|
||||
return res;
|
||||
}},
|
||||
{"safe", [](const func_args & args) -> value {
|
||||
// no-op for now
|
||||
args.ensure_vals<value_string>();
|
||||
return args.args[0];
|
||||
}},
|
||||
{"selectattr", [](const func_args &) -> value {
|
||||
throw std::runtime_error("String selectattr builtin not supported");
|
||||
}},
|
||||
|
|
|
|||
|
|
@ -107,6 +107,13 @@ struct value_t {
|
|||
|
||||
func_handler val_func;
|
||||
|
||||
// only used if ctx.is_get_stats = true
|
||||
struct stats_t {
|
||||
bool used = false;
|
||||
// ops can be builtin calls or operators: "array_access", "object_access"
|
||||
std::set<std::string> ops;
|
||||
} stats;
|
||||
|
||||
value_t() = default;
|
||||
value_t(const value_t &) = default;
|
||||
virtual ~value_t() = default;
|
||||
|
|
@ -126,6 +133,9 @@ struct value_t {
|
|||
throw std::runtime_error("No builtins available for type " + type());
|
||||
}
|
||||
|
||||
virtual value & at(const std::string & key) { return val_obj[key]; }
|
||||
virtual value & at(size_t index) { return val_arr.at(index); }
|
||||
|
||||
virtual std::string as_repr() const { return as_string().str(); }
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -66,6 +66,9 @@ value identifier::execute_impl(context & ctx) {
|
|||
auto it = ctx.get_val(val);
|
||||
auto builtins = global_builtins();
|
||||
if (!it->is_undefined()) {
|
||||
if (ctx.is_get_stats) {
|
||||
it->stats.used = true;
|
||||
}
|
||||
JJ_DEBUG("Identifier '%s' found", val.c_str());
|
||||
return it;
|
||||
} else if (builtins.find(val) != builtins.end()) {
|
||||
|
|
@ -236,7 +239,12 @@ value binary_expression::execute_impl(context & ctx) {
|
|||
throw std::runtime_error("Unknown operator \"" + op.value + "\" between " + left_val->type() + " and " + right_val->type());
|
||||
}
|
||||
|
||||
static value try_builtin_func(const std::string & name, const value & input, bool undef_on_missing = false) {
|
||||
static value try_builtin_func(context & ctx, const std::string & name, value & input, bool undef_on_missing = false) {
|
||||
JJ_DEBUG("Trying built-in function '%s' for type %s", name.c_str(), input->type().c_str());
|
||||
if (ctx.is_get_stats) {
|
||||
input->stats.used = true;
|
||||
input->stats.ops.insert(name);
|
||||
}
|
||||
auto builtins = input->get_builtins();
|
||||
auto it = builtins.find(name);
|
||||
if (it != builtins.end()) {
|
||||
|
|
@ -266,7 +274,7 @@ value filter_expression::execute_impl(context & ctx) {
|
|||
filter_id = "strip"; // alias
|
||||
}
|
||||
JJ_DEBUG("Applying filter '%s' to %s", filter_id.c_str(), input->type().c_str());
|
||||
return try_builtin_func(filter_id, input)->invoke(func_args(ctx));
|
||||
return try_builtin_func(ctx, filter_id, input)->invoke(func_args(ctx));
|
||||
|
||||
} else if (is_stmt<call_expression>(filter)) {
|
||||
auto call = cast_stmt<call_expression>(filter);
|
||||
|
|
@ -278,7 +286,7 @@ value filter_expression::execute_impl(context & ctx) {
|
|||
args.args.push_back(arg_expr->execute(ctx));
|
||||
}
|
||||
|
||||
return try_builtin_func(filter_id, input)->invoke(args);
|
||||
return try_builtin_func(ctx, filter_id, input)->invoke(args);
|
||||
|
||||
} else {
|
||||
throw std::runtime_error("Invalid filter expression");
|
||||
|
|
@ -401,12 +409,20 @@ value for_statement::execute_impl(context & ctx) {
|
|||
tuple->push_back(p.second);
|
||||
items.push_back(tuple);
|
||||
}
|
||||
if (ctx.is_get_stats) {
|
||||
iterable_val->stats.used = true;
|
||||
iterable_val->stats.ops.insert("object_access");
|
||||
}
|
||||
} else {
|
||||
JJ_DEBUG("%s", "For loop over array items");
|
||||
auto & arr = iterable_val->as_array();
|
||||
for (const auto & item : arr) {
|
||||
items.push_back(item);
|
||||
}
|
||||
if (ctx.is_get_stats) {
|
||||
iterable_val->stats.used = true;
|
||||
iterable_val->stats.ops.insert("array_access");
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<std::function<void(context &)>> scope_update_fns;
|
||||
|
|
@ -624,7 +640,7 @@ value member_expression::execute_impl(context & ctx) {
|
|||
start_val->as_repr().c_str(),
|
||||
stop_val->as_repr().c_str(),
|
||||
step_val->as_repr().c_str());
|
||||
auto slice_func = try_builtin_func("slice", object);
|
||||
auto slice_func = try_builtin_func(ctx, "slice", object);
|
||||
func_args args(ctx);
|
||||
args.args.push_back(start_val);
|
||||
args.args.push_back(stop_val);
|
||||
|
|
@ -654,7 +670,7 @@ value member_expression::execute_impl(context & ctx) {
|
|||
if (it != obj.end()) {
|
||||
val = it->second;
|
||||
} else {
|
||||
val = try_builtin_func(key, object, true);
|
||||
val = try_builtin_func(ctx, key, object, true);
|
||||
}
|
||||
JJ_DEBUG("Accessed property '%s' value, got type: %s", key.c_str(), val->type().c_str());
|
||||
|
||||
|
|
@ -676,10 +692,11 @@ value member_expression::execute_impl(context & ctx) {
|
|||
val = mk_val<value_string>(std::string(1, str[index]));
|
||||
}
|
||||
}
|
||||
|
||||
} else if (is_val<value_string>(property)) {
|
||||
auto key = property->as_string().str();
|
||||
JJ_DEBUG("Accessing %s built-in '%s'", is_val<value_array>(object) ? "array" : "string", key.c_str());
|
||||
val = try_builtin_func(key, object);
|
||||
val = try_builtin_func(ctx, key, object);
|
||||
} else {
|
||||
throw std::runtime_error("Cannot access property with non-string/non-number: got " + property->type());
|
||||
}
|
||||
|
|
@ -689,7 +706,17 @@ value member_expression::execute_impl(context & ctx) {
|
|||
throw std::runtime_error("Cannot access property with non-string: got " + property->type());
|
||||
}
|
||||
auto key = property->as_string().str();
|
||||
val = try_builtin_func(key, object);
|
||||
val = try_builtin_func(ctx, key, object);
|
||||
}
|
||||
|
||||
if (ctx.is_get_stats && val && object && property) {
|
||||
val->stats.used = true;
|
||||
object->stats.used = true;
|
||||
if (is_val<value_int>(property)) {
|
||||
object->stats.ops.insert("array_access");
|
||||
} else if (is_val<value_string>(property)) {
|
||||
object->stats.ops.insert("object_access");
|
||||
}
|
||||
}
|
||||
|
||||
return val;
|
||||
|
|
|
|||
|
|
@ -50,6 +50,8 @@ struct context {
|
|||
std::string source; // for debugging
|
||||
std::time_t current_time; // for functions that need current time
|
||||
|
||||
bool is_get_stats = false; // whether to collect stats
|
||||
|
||||
context() {
|
||||
global = mk_val<value_object>();
|
||||
global->insert("true", mk_val<value_bool>(true));
|
||||
|
|
@ -65,6 +67,8 @@ struct context {
|
|||
for (const auto & pair : pvar) {
|
||||
set_val(pair.first, pair.second);
|
||||
}
|
||||
current_time = parent.current_time;
|
||||
is_get_stats = parent.is_get_stats;
|
||||
}
|
||||
|
||||
value get_val(const std::string & name) {
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@
|
|||
|
||||
#include "jinja/jinja-parser.h"
|
||||
#include "jinja/jinja-lexer.h"
|
||||
#include "jinja/jinja-caps.h"
|
||||
|
||||
using json = nlohmann::json;
|
||||
|
||||
|
|
@ -38,11 +39,7 @@ std::string DEFAULT_JSON = R"({
|
|||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": {"__input__": "I am fine, thank you!"}
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "Calling weather tool.",
|
||||
"content": {"__input__": "I am fine, thank you!"},
|
||||
"tool_calls": [
|
||||
{
|
||||
"function": {
|
||||
|
|
@ -177,11 +174,15 @@ void run_single(std::string contents, json input, const std::string & output_pat
|
|||
// compile to AST
|
||||
jinja::program ast = jinja::parse_from_tokens(lexer_res);
|
||||
|
||||
// check caps for workarounds
|
||||
auto caps = jinja::caps_get(ast);
|
||||
|
||||
std::cout << "\n=== RUN ===\n";
|
||||
jinja::context ctx;
|
||||
ctx.source = lexer_res.preprocessed_source;
|
||||
|
||||
jinja::global_from_json(ctx, input);
|
||||
jinja::caps_apply_workarounds(ctx, caps);
|
||||
|
||||
jinja::vm vm(ctx);
|
||||
const jinja::value results = vm.execute(ast);
|
||||
|
|
|
|||
Loading…
Reference in New Issue