impl global_from_json
This commit is contained in:
parent
55fe96a9df
commit
1784a57e7b
|
|
@ -3,6 +3,9 @@
|
|||
#include "jinja-parser.h"
|
||||
#include "jinja-value.h"
|
||||
|
||||
// for converting from JSON to jinja values
|
||||
#include <nlohmann/json.hpp>
|
||||
|
||||
#include <string>
|
||||
#include <cctype>
|
||||
#include <vector>
|
||||
|
|
@ -520,4 +523,50 @@ const func_builtins & value_object_t::get_builtins() const {
|
|||
return builtins;
|
||||
}
|
||||
|
||||
static value from_json(const nlohmann::json & j) {
|
||||
if (j.is_null()) {
|
||||
return mk_val<value_null>();
|
||||
} else if (j.is_boolean()) {
|
||||
return mk_val<value_bool>(j.get<bool>());
|
||||
} else if (j.is_number_integer()) {
|
||||
return mk_val<value_int>(j.get<int64_t>());
|
||||
} else if (j.is_number_float()) {
|
||||
return mk_val<value_float>(j.get<double>());
|
||||
} else if (j.is_string()) {
|
||||
return mk_val<value_string>(j.get<std::string>());
|
||||
} else if (j.is_array()) {
|
||||
auto arr = mk_val<value_array>();
|
||||
for (const auto & item : j) {
|
||||
arr->push_back(from_json(item));
|
||||
}
|
||||
return arr;
|
||||
} else if (j.is_object()) {
|
||||
if (j.contains("__input__")) {
|
||||
// handle input marking
|
||||
auto str = mk_val<value_string>(j.at("__input__").get<std::string>());
|
||||
str->mark_input();
|
||||
return str;
|
||||
} else {
|
||||
// normal object
|
||||
auto obj = mk_val<value_object>();
|
||||
for (auto it = j.begin(); it != j.end(); ++it) {
|
||||
obj->insert(it.key(), from_json(it.value()));
|
||||
}
|
||||
return obj;
|
||||
}
|
||||
} else {
|
||||
throw std::runtime_error("Unsupported JSON value type");
|
||||
}
|
||||
}
|
||||
|
||||
template<>
|
||||
void global_from_json(context & ctx, const nlohmann::json & json_obj) {
|
||||
if (json_obj.is_null() || !json_obj.is_object()) {
|
||||
throw std::runtime_error("global_from_json: input JSON value must be an object");
|
||||
}
|
||||
for (auto it = json_obj.begin(); it != json_obj.end(); ++it) {
|
||||
ctx.var[it.key()] = from_json(it.value());
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace jinja
|
||||
|
|
|
|||
|
|
@ -57,8 +57,41 @@ void ensure_val(const value & ptr) {
|
|||
}
|
||||
// End Helper
|
||||
|
||||
|
||||
struct context; // forward declaration
|
||||
|
||||
|
||||
// for converting from JSON to jinja values
|
||||
// example input JSON:
|
||||
// {
|
||||
// "messages": [
|
||||
// {"role": "user", "content": "Hello!"},
|
||||
// {"role": "assistant", "content": "Hi there!"}
|
||||
// ],
|
||||
// "bos_token": "<s>",
|
||||
// "eos_token": "</s>",
|
||||
// }
|
||||
//
|
||||
// to mark strings as user input, wrap them in a special object:
|
||||
// {
|
||||
// "messages": [
|
||||
// {
|
||||
// "role": "user",
|
||||
// "content": {"__input__": "Hello!"} // this string is user input
|
||||
// },
|
||||
// ...
|
||||
// ],
|
||||
// }
|
||||
//
|
||||
// marking input can be useful for tracking data provenance
|
||||
// and preventing template injection attacks
|
||||
//
|
||||
// Note: T_JSON can be nlohmann::json or similar types
|
||||
template<typename T_JSON>
|
||||
void global_from_json(context & ctx, const T_JSON & json_obj);
|
||||
|
||||
|
||||
|
||||
struct func_args {
|
||||
std::vector<value> args;
|
||||
context & ctx;
|
||||
|
|
|
|||
|
|
@ -226,7 +226,7 @@ static value try_builtin_func(const std::string & name, const value & input, boo
|
|||
}
|
||||
|
||||
value filter_expression::execute_impl(context & ctx) {
|
||||
value input = operand->execute(ctx);
|
||||
value input = operand ? operand->execute(ctx) : val;
|
||||
|
||||
JJ_DEBUG("Applying filter to %s", input->type().c_str());
|
||||
|
||||
|
|
|
|||
|
|
@ -6,6 +6,8 @@
|
|||
#include <fstream>
|
||||
#include <filesystem>
|
||||
|
||||
#include <nlohmann/json.hpp>
|
||||
|
||||
#undef NDEBUG
|
||||
#include <cassert>
|
||||
|
||||
|
|
@ -24,10 +26,14 @@ int main(void) {
|
|||
|
||||
//std::ifstream infile("models/templates/mistralai-Ministral-3-14B-Reasoning-2512.jinja"); std::string contents((std::istreambuf_iterator<char>(infile)), std::istreambuf_iterator<char>());
|
||||
|
||||
std::vector<std::string> failed_tests;
|
||||
|
||||
// list all files in models/templates/ and run each
|
||||
size_t test_count = 0;
|
||||
std::string dir_path = "models/templates/";
|
||||
for (const auto & entry : std::filesystem::directory_iterator(dir_path)) {
|
||||
if (entry.is_regular_file()) {
|
||||
test_count++;
|
||||
std::cout << "\n\n=== RUNNING TEMPLATE FILE: " << entry.path().string() << " ===\n";
|
||||
std::ifstream infile(entry.path());
|
||||
std::string contents((std::istreambuf_iterator<char>(infile)), std::istreambuf_iterator<char>());
|
||||
|
|
@ -35,11 +41,18 @@ int main(void) {
|
|||
run(contents);
|
||||
} catch (const std::exception & e) {
|
||||
std::cout << "Exception: " << e.what() << "\n";
|
||||
std::cout << "=== CURRENT TEMPLATE FILE: " << entry.path().string() << " ===\n";
|
||||
exit(1);
|
||||
std::cout << "=== ERROR WITH TEMPLATE FILE: " << entry.path().string() << " ===\n";
|
||||
failed_tests.push_back(entry.path().string());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::cout << "\n\n=== TEST SUMMARY ===\n";
|
||||
std::cout << "Total tests run: " << test_count << "\n";
|
||||
std::cout << "Total failed tests: " << failed_tests.size() << "\n";
|
||||
for (const auto & test : failed_tests) {
|
||||
std::cout << "FAILED TEST: " << test << "\n";
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
|
@ -66,25 +79,20 @@ void run(std::string contents) {
|
|||
jinja::context ctx;
|
||||
ctx.source = lexer_res.preprocessed_source;
|
||||
|
||||
auto make_non_special_string = [](const std::string & s) {
|
||||
jinja::value_string str_val = jinja::mk_val<jinja::value_string>(s);
|
||||
str_val->mark_input();
|
||||
return str_val;
|
||||
};
|
||||
|
||||
jinja::value_array messages = jinja::mk_val<jinja::value_array>();
|
||||
jinja::value_object msg1 = jinja::mk_val<jinja::value_object>();
|
||||
msg1->insert("role", make_non_special_string("user"));
|
||||
msg1->insert("content", make_non_special_string("Hello, how are you?"));
|
||||
messages->push_back(std::move(msg1));
|
||||
jinja::value_object msg2 = jinja::mk_val<jinja::value_object>();
|
||||
msg2->insert("role", make_non_special_string("assistant"));
|
||||
msg2->insert("content", make_non_special_string("I am fine, thank you!"));
|
||||
messages->push_back(std::move(msg2));
|
||||
|
||||
ctx.var["messages"] = std::move(messages);
|
||||
ctx.var["eos_token"] = jinja::mk_val<jinja::value_string>("</s>");
|
||||
// ctx.var["tools"] = jinja::mk_val<jinja::value_null>();
|
||||
std::string json_inp = R"({
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": {"__input__": "Hello, how are you?"}
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": {"__input__": "I am fine, thank you!"}
|
||||
}
|
||||
],
|
||||
"eos_token": "</s>"
|
||||
})";
|
||||
jinja::global_from_json(ctx, nlohmann::json::parse(json_inp));
|
||||
|
||||
jinja::vm vm(ctx);
|
||||
const jinja::value results = vm.execute(ast);
|
||||
|
|
|
|||
Loading…
Reference in New Issue