diff --git a/tests/test-chat-jinja.cpp b/tests/test-chat-jinja.cpp index 7f588a8878..50401b56bb 100644 --- a/tests/test-chat-jinja.cpp +++ b/tests/test-chat-jinja.cpp @@ -14,74 +14,134 @@ #include "jinja/jinja-parser.h" #include "jinja/jinja-lexer.h" -void run_multiple(); -void run_single(std::string contents); +using json = nlohmann::json; -int main(void) { - //std::string contents = "{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '' + role + '\\n' + message['content'] | trim + '\\n' }}{% endfor %}{% if add_generation_prompt %}{{'model\\n'}}{% endif %}"; +void run_multiple(std::string dir_path, bool stop_on_first_failure, json input); +void run_single(std::string contents, json input); - //std::string contents = "{% if messages[0]['role'] != 'system' %}nice {{ messages[0]['content'] }}{% endif %}"; +std::string HELP = R"( +Usage: test-chat-jinja [OPTIONS] PATH_TO_TEMPLATE +Options: + --json Path to the JSON input file. + --stop-on-first-fail Stop testing on the first failure (default: false). +If PATH_TO_TEMPLATE is a file, runs that single template. +If PATH_TO_TEMPLATE is a directory, runs all .jinja files in that directory. +)"; - //std::string contents = " {{ messages[a]['content'] }} "; - //std::string contents = "{% if a is not defined %}hello{% endif %}"; +std::string DEFAULT_JSON = R"({ + "messages": [ + { + "role": "user", + "content": {"__input__": "Hello, how are you?"} + }, + { + "role": "assistant", + "content": {"__input__": "I am fine, thank you!"} + }, + { + "role": "assistant", + "content": "Calling weather tool.", + "tool_calls": [ + { + "function": { + "name": "get_weather", + "arguments": { + "location": "New York", + "unit": "celsius" + } + } + } + ] + } + ], + "bos_token": "", + "eos_token": "", + "tools": [] +})"; - std::ifstream infile("models/templates/Qwen-Qwen3-0.6B.jinja"); - //std::ifstream infile("models/templates/Kimi-K2-Thinking.jinja"); - std::string contents((std::istreambuf_iterator(infile)), std::istreambuf_iterator()); +int main(int argc, char ** argv) { + std::vector args(argv, argv + argc); - run_single(contents); + std::string tmpl_path; + std::string json_path; + bool stop_on_first_fail = false; - //run_multiple(); + for (size_t i = 1; i < args.size(); i++) { + if (args[i] == "--help" || args[i] == "-h") { + std::cout << HELP << "\n"; + return 0; + } else if (args[i] == "--json" && i + 1 < args.size()) { + json_path = args[i + 1]; + i++; + } else if (args[i] == "--stop-on-first-fail") { + stop_on_first_fail = true; + } else if (tmpl_path.empty()) { + tmpl_path = args[i]; + } else { + std::cerr << "Unknown argument: " << args[i] << "\n"; + std::cout << HELP << "\n"; + return 1; + } + } + + if (tmpl_path.empty()) { + std::cerr << "Error: PATH_TO_TEMPLATE is required.\n"; + std::cout << HELP << "\n"; + return 1; + } + + json input_json; + if (!json_path.empty()) { + std::ifstream json_file(json_path); + if (!json_file) { + std::cerr << "Error: Could not open JSON file: " << json_path << "\n"; + return 1; + } + std::string content = std::string( + std::istreambuf_iterator(json_file), + std::istreambuf_iterator()); + input_json = json::parse(content); + } else { + input_json = json::parse(DEFAULT_JSON); + } + + std::filesystem::path p(tmpl_path); + if (std::filesystem::is_directory(p)) { + run_multiple(tmpl_path, stop_on_first_fail, input_json); + } else if (std::filesystem::is_regular_file(p)) { + std::ifstream infile(tmpl_path); + std::string contents = std::string( + std::istreambuf_iterator(infile), + std::istreambuf_iterator()); + run_single(contents, input_json); + } else { + std::cerr << "Error: PATH_TO_TEMPLATE is not a valid file or directory: " << tmpl_path << "\n"; + return 1; + } return 0; } -void run_multiple(void) { +void run_multiple(std::string dir_path, bool stop_on_first_fail, json input) { std::vector failed_tests; - bool stop_on_first_failure = false; - - auto is_ignored_file = [](const std::string & filename) -> bool { - std::vector ignored_files = { - "Apriel-", - "Olmo-3-7B-Instruct-Heretic-GGUF", - "sheldonrobinson-Llama-Guard", - "deepseek-community-Janus-Pro-1B", - "bitshrine-gemma-2-2B-function-calling", - "PaddlePaddle-PaddleOCR-VL", - }; - for (const auto & ignored : ignored_files) { - if (filename.find(ignored) != std::string::npos) { - return true; - } - } - return false; - }; - // list all files in models/templates/ and run each size_t test_count = 0; - size_t skip_count = 0; - //std::string dir_path = "models/templates/"; - std::string dir_path = "../test-jinja/templates/"; - for (const auto & entry : std::filesystem::directory_iterator(dir_path)) { - if (entry.is_regular_file()) { - if (is_ignored_file(entry.path().filename().string())) { - std::cout << "=== SKIPPING TEMPLATE FILE: " << entry.path().string() << " ===\n"; - skip_count++; - continue; - } + for (const auto & entry : std::filesystem::directory_iterator(dir_path)) { + // only process .jinja files + if (entry.path().extension() == ".jinja" && entry.is_regular_file()) { test_count++; std::cout << "\n\n=== RUNNING TEMPLATE FILE: " << entry.path().string() << " ===\n"; std::ifstream infile(entry.path()); std::string contents((std::istreambuf_iterator(infile)), std::istreambuf_iterator()); try { - run_single(contents); + run_single(contents, input); } catch (const std::exception & e) { std::cout << "Exception: " << e.what() << "\n"; std::cout << "=== ERROR WITH TEMPLATE FILE: " << entry.path().string() << " ===\n"; failed_tests.push_back(entry.path().string()); - if (stop_on_first_failure) { + if (stop_on_first_fail) { break; } } @@ -91,14 +151,13 @@ void run_multiple(void) { std::cout << "\n\n=== TEST SUMMARY ===\n"; std::cout << "Total tests run: " << test_count << "\n"; std::cout << "Total failed tests: " << failed_tests.size() << "\n"; - std::cout << "Total skipped tests: " << skip_count << "\n"; for (const auto & test : failed_tests) { std::cout << "FAILED TEST: " << test << "\n"; } } -void run_single(std::string contents) { +void run_single(std::string contents, json input) { jinja::enable_debug(true); // lexing @@ -115,46 +174,7 @@ void run_single(std::string contents) { jinja::context ctx; ctx.source = lexer_res.preprocessed_source; - std::string json_inp = R"({ - "messages": [ - { - "role": "user", - "content": {"__input__": "Hello, how are you?"} - }, - { - "role": "assistant", - "content": {"__input__": "I am fine, thank you!"} - }, - { - "role": "assistant", - "content": "Calling weather tool.", - "tool_calls": [ - { - "function": { - "name": "get_weather", - "arguments": { - "location": "New York", - "unit": "celsius" - } - } - } - ] - } - ], - "bos_token": "", - "eos_token": "", - "tools": [] - })"; - auto input_json = nlohmann::json::parse(json_inp); - - // workaround for functionary models - input_json["functions"] = ""; - input_json["datetime"] = ""; - - // workaround for Llama Guard models - input_json["excluded_category_keys"] = nlohmann::json::array(); - - jinja::global_from_json(ctx, input_json); + jinja::global_from_json(ctx, input); jinja::vm vm(ctx); const jinja::value results = vm.execute(ast);