impl global_from_json

This commit is contained in:
Xuan Son Nguyen 2025-12-28 23:15:48 +01:00
parent 55fe96a9df
commit 1784a57e7b
4 changed files with 112 additions and 22 deletions

View File

@ -3,6 +3,9 @@
#include "jinja-parser.h"
#include "jinja-value.h"
// for converting from JSON to jinja values
#include <nlohmann/json.hpp>
#include <string>
#include <cctype>
#include <vector>
@ -520,4 +523,50 @@ const func_builtins & value_object_t::get_builtins() const {
return builtins;
}
static value from_json(const nlohmann::json & j) {
if (j.is_null()) {
return mk_val<value_null>();
} else if (j.is_boolean()) {
return mk_val<value_bool>(j.get<bool>());
} else if (j.is_number_integer()) {
return mk_val<value_int>(j.get<int64_t>());
} else if (j.is_number_float()) {
return mk_val<value_float>(j.get<double>());
} else if (j.is_string()) {
return mk_val<value_string>(j.get<std::string>());
} else if (j.is_array()) {
auto arr = mk_val<value_array>();
for (const auto & item : j) {
arr->push_back(from_json(item));
}
return arr;
} else if (j.is_object()) {
if (j.contains("__input__")) {
// handle input marking
auto str = mk_val<value_string>(j.at("__input__").get<std::string>());
str->mark_input();
return str;
} else {
// normal object
auto obj = mk_val<value_object>();
for (auto it = j.begin(); it != j.end(); ++it) {
obj->insert(it.key(), from_json(it.value()));
}
return obj;
}
} else {
throw std::runtime_error("Unsupported JSON value type");
}
}
template<>
void global_from_json(context & ctx, const nlohmann::json & json_obj) {
if (json_obj.is_null() || !json_obj.is_object()) {
throw std::runtime_error("global_from_json: input JSON value must be an object");
}
for (auto it = json_obj.begin(); it != json_obj.end(); ++it) {
ctx.var[it.key()] = from_json(it.value());
}
}
} // namespace jinja

View File

@ -57,8 +57,41 @@ void ensure_val(const value & ptr) {
}
// End Helper
struct context; // forward declaration
// for converting from JSON to jinja values
// example input JSON:
// {
// "messages": [
// {"role": "user", "content": "Hello!"},
// {"role": "assistant", "content": "Hi there!"}
// ],
// "bos_token": "<s>",
// "eos_token": "</s>",
// }
//
// to mark strings as user input, wrap them in a special object:
// {
// "messages": [
// {
// "role": "user",
// "content": {"__input__": "Hello!"} // this string is user input
// },
// ...
// ],
// }
//
// marking input can be useful for tracking data provenance
// and preventing template injection attacks
//
// Note: T_JSON can be nlohmann::json or similar types
template<typename T_JSON>
void global_from_json(context & ctx, const T_JSON & json_obj);
struct func_args {
std::vector<value> args;
context & ctx;

View File

@ -226,7 +226,7 @@ static value try_builtin_func(const std::string & name, const value & input, boo
}
value filter_expression::execute_impl(context & ctx) {
value input = operand->execute(ctx);
value input = operand ? operand->execute(ctx) : val;
JJ_DEBUG("Applying filter to %s", input->type().c_str());

View File

@ -6,6 +6,8 @@
#include <fstream>
#include <filesystem>
#include <nlohmann/json.hpp>
#undef NDEBUG
#include <cassert>
@ -24,10 +26,14 @@ int main(void) {
//std::ifstream infile("models/templates/mistralai-Ministral-3-14B-Reasoning-2512.jinja"); std::string contents((std::istreambuf_iterator<char>(infile)), std::istreambuf_iterator<char>());
std::vector<std::string> failed_tests;
// list all files in models/templates/ and run each
size_t test_count = 0;
std::string dir_path = "models/templates/";
for (const auto & entry : std::filesystem::directory_iterator(dir_path)) {
if (entry.is_regular_file()) {
test_count++;
std::cout << "\n\n=== RUNNING TEMPLATE FILE: " << entry.path().string() << " ===\n";
std::ifstream infile(entry.path());
std::string contents((std::istreambuf_iterator<char>(infile)), std::istreambuf_iterator<char>());
@ -35,11 +41,18 @@ int main(void) {
run(contents);
} catch (const std::exception & e) {
std::cout << "Exception: " << e.what() << "\n";
std::cout << "=== CURRENT TEMPLATE FILE: " << entry.path().string() << " ===\n";
exit(1);
std::cout << "=== ERROR WITH TEMPLATE FILE: " << entry.path().string() << " ===\n";
failed_tests.push_back(entry.path().string());
}
}
}
std::cout << "\n\n=== TEST SUMMARY ===\n";
std::cout << "Total tests run: " << test_count << "\n";
std::cout << "Total failed tests: " << failed_tests.size() << "\n";
for (const auto & test : failed_tests) {
std::cout << "FAILED TEST: " << test << "\n";
}
return 0;
}
@ -66,25 +79,20 @@ void run(std::string contents) {
jinja::context ctx;
ctx.source = lexer_res.preprocessed_source;
auto make_non_special_string = [](const std::string & s) {
jinja::value_string str_val = jinja::mk_val<jinja::value_string>(s);
str_val->mark_input();
return str_val;
};
jinja::value_array messages = jinja::mk_val<jinja::value_array>();
jinja::value_object msg1 = jinja::mk_val<jinja::value_object>();
msg1->insert("role", make_non_special_string("user"));
msg1->insert("content", make_non_special_string("Hello, how are you?"));
messages->push_back(std::move(msg1));
jinja::value_object msg2 = jinja::mk_val<jinja::value_object>();
msg2->insert("role", make_non_special_string("assistant"));
msg2->insert("content", make_non_special_string("I am fine, thank you!"));
messages->push_back(std::move(msg2));
ctx.var["messages"] = std::move(messages);
ctx.var["eos_token"] = jinja::mk_val<jinja::value_string>("</s>");
// ctx.var["tools"] = jinja::mk_val<jinja::value_null>();
std::string json_inp = R"({
"messages": [
{
"role": "user",
"content": {"__input__": "Hello, how are you?"}
},
{
"role": "assistant",
"content": {"__input__": "I am fine, thank you!"}
}
],
"eos_token": "</s>"
})";
jinja::global_from_json(ctx, nlohmann::json::parse(json_inp));
jinja::vm vm(ctx);
const jinja::value results = vm.execute(ast);