From 15b3dbab05a85a892c2d0ebaf6f3b6913d3ea24e Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Sat, 27 Dec 2025 21:52:50 +0100 Subject: [PATCH] add string builtins --- common/CMakeLists.txt | 4 + common/jinja/jinja-value.h | 78 ++++++++++++++++ common/jinja/jinja-vm-builtins.cpp | 139 +++++++++++++++++++++++++++++ common/jinja/jinja-vm.cpp | 58 +++++------- common/jinja/jinja-vm.h | 2 +- tests/test-chat-jinja.cpp | 2 +- 6 files changed, 247 insertions(+), 36 deletions(-) create mode 100644 common/jinja/jinja-vm-builtins.cpp diff --git a/common/CMakeLists.txt b/common/CMakeLists.txt index 49ce25a842..4ed0df100f 100644 --- a/common/CMakeLists.txt +++ b/common/CMakeLists.txt @@ -84,8 +84,12 @@ add_library(${TARGET} STATIC unicode.cpp unicode.h jinja/jinja-lexer.cpp + jinja/jinja-lexer.h jinja/jinja-parser.cpp + jinja/jinja-parser.h jinja/jinja-vm.cpp + jinja/jinja-vm.h + jinja/jinja-vm-builtins.cpp ) target_include_directories(${TARGET} PUBLIC . ../vendor) diff --git a/common/jinja/jinja-value.h b/common/jinja/jinja-value.h index 87ee91f693..289acb1c7d 100644 --- a/common/jinja/jinja-value.h +++ b/common/jinja/jinja-value.h @@ -3,6 +3,8 @@ #include #include #include +#include +#include namespace jinja { @@ -10,6 +12,63 @@ namespace jinja { struct value_t; using value = std::unique_ptr; + +// Helper to check the type of a value +template +struct extract_pointee { + using type = T; +}; +template +struct extract_pointee> { + using type = U; +}; +template +bool is_val(const value & ptr) { + using PointeeType = typename extract_pointee::type; + return dynamic_cast(ptr.get()) != nullptr; +} +template +bool mk_val(Args&&... args) { + using PointeeType = typename extract_pointee::type; + return std::make_unique(std::forward(args)...); +} +template +void ensure_val(const value & ptr) { + if (!is_val(ptr)) { + throw std::runtime_error("Expected value of type " + std::string(typeid(T).name())); + } +} +// End Helper + + +struct func_args { + std::vector args; + void ensure_count(size_t count) const { + if (args.size() != count) { + throw std::runtime_error("Expected " + std::to_string(count) + " arguments, got " + std::to_string(args.size())); + } + } + // utility functions + template void ensure_vals() const { + ensure_count(1); + ensure_val(args[0]); + } + template void ensure_vals() const { + ensure_count(2); + ensure_val(args[0]); + ensure_val(args[1]); + } + template void ensure_vals() const { + ensure_count(3); + ensure_val(args[0]); + ensure_val(args[1]); + ensure_val(args[2]); + } +}; + +using func_handler = std::function; +using func_builtins = std::map; + struct value_t { int64_t val_int; double val_flt; @@ -25,6 +84,8 @@ struct value_t { std::shared_ptr> val_arr; std::shared_ptr> val_obj; + func_handler val_func; + value_t() = default; value_t(const value_t &) = default; virtual ~value_t() = default; @@ -37,8 +98,12 @@ struct value_t { virtual bool as_bool() const { throw std::runtime_error("Not a bool value"); } virtual const std::vector & as_array() const { throw std::runtime_error("Not an array value"); } virtual const std::map & as_object() const { throw std::runtime_error("Not an object value"); } + virtual value invoke(const func_args &) const { throw std::runtime_error("Not a function value"); } virtual bool is_null() const { return false; } virtual bool is_undefined() const { return false; } + virtual const func_builtins & get_builtins() const { + throw std::runtime_error("No builtins available for type " + type()); + } virtual value clone() const { return std::make_unique(*this); @@ -78,6 +143,7 @@ struct value_string_t : public value_t { virtual std::string type() const override { return "String"; } virtual std::string as_string() const override { return val_str; } virtual value clone() const override { return std::make_unique(*this); } + const func_builtins & get_builtins() const override; }; using value_string = std::unique_ptr; @@ -145,6 +211,18 @@ struct value_object_t : public value_t { }; using value_object = std::unique_ptr; +struct value_func_t : public value_t { + value_func_t(func_handler & func) { + val_func = func; + } + virtual value invoke(const func_args & args) const override { + return val_func(args); + } + virtual std::string type() const override { return "Function"; } + virtual value clone() const override { return std::make_unique(*this); } +}; +using value_func = std::unique_ptr; + struct value_null_t : public value_t { virtual std::string type() const override { return "Null"; } virtual bool is_null() const override { return true; } diff --git a/common/jinja/jinja-vm-builtins.cpp b/common/jinja/jinja-vm-builtins.cpp new file mode 100644 index 0000000000..85d0681867 --- /dev/null +++ b/common/jinja/jinja-vm-builtins.cpp @@ -0,0 +1,139 @@ +#include "jinja-lexer.h" +#include "jinja-vm.h" +#include "jinja-parser.h" +#include "jinja-value.h" + +#include +#include + +namespace jinja { + +static std::string string_strip(const std::string & str, bool left, bool right) { + size_t start = 0; + size_t end = str.length(); + if (left) { + while (start < end && isspace(static_cast(str[start]))) { + ++start; + } + } + if (right) { + while (end > start && isspace(static_cast(str[end - 1]))) { + --end; + } + } + return str.substr(start, end - start); +} + +static bool string_startswith(const std::string & str, const std::string & prefix) { + if (str.length() < prefix.length()) return false; + return str.compare(0, prefix.length(), prefix) == 0; +} + +static bool string_endswith(const std::string & str, const std::string & suffix) { + if (str.length() < suffix.length()) return false; + return str.compare(str.length() - suffix.length(), suffix.length(), suffix) == 0; +} + +const func_builtins & value_string_t::get_builtins() const { + static const func_builtins builtins = { + {"upper", [](const func_args & args) -> value { + args.ensure_vals(); + std::string str = args.args[0]->as_string(); + std::transform(str.begin(), str.end(), str.begin(), ::toupper); + return std::make_unique(str); + }}, + {"lower", [](const func_args & args) -> value { + args.ensure_vals(); + std::string str = args.args[0]->as_string(); + std::transform(str.begin(), str.end(), str.begin(), ::tolower); + return std::make_unique(str); + }}, + {"strip", [](const func_args & args) -> value { + args.ensure_vals(); + std::string str = args.args[0]->as_string(); + return std::make_unique(string_strip(str, true, true)); + }}, + {"rstrip", [](const func_args & args) -> value { + args.ensure_vals(); + std::string str = args.args[0]->as_string(); + return std::make_unique(string_strip(str, false, true)); + }}, + {"lstrip", [](const func_args & args) -> value { + args.ensure_vals(); + std::string str = args.args[0]->as_string(); + return std::make_unique(string_strip(str, true, false)); + }}, + {"title", [](const func_args & args) -> value { + args.ensure_vals(); + std::string str = args.args[0]->as_string(); + bool capitalize_next = true; + for (char &c : str) { + if (isspace(static_cast(c))) { + capitalize_next = true; + } else if (capitalize_next) { + c = ::toupper(static_cast(c)); + capitalize_next = false; + } else { + c = ::tolower(static_cast(c)); + } + } + return std::make_unique(str); + }}, + {"capitalize", [](const func_args & args) -> value { + args.ensure_vals(); + std::string str = args.args[0]->as_string(); + if (!str.empty()) { + str[0] = ::toupper(static_cast(str[0])); + std::transform(str.begin() + 1, str.end(), str.begin() + 1, ::tolower); + } + return std::make_unique(str); + }}, + {"length", [](const func_args & args) -> value { + args.ensure_vals(); + std::string str = args.args[0]->as_string(); + return std::make_unique(str.length()); + }}, + {"startswith", [](const func_args & args) -> value { + args.ensure_vals(); + std::string str = args.args[0]->as_string(); + std::string prefix = args.args[1]->as_string(); + return std::make_unique(string_startswith(str, prefix)); + }}, + {"endswith", [](const func_args & args) -> value { + args.ensure_vals(); + std::string str = args.args[0]->as_string(); + std::string suffix = args.args[1]->as_string(); + return std::make_unique(string_endswith(str, suffix)); + }}, + {"split", [](const func_args & args) -> value { + args.ensure_vals(); + std::string str = args.args[0]->as_string(); + std::string delim = (args.args.size() > 1) ? args.args[1]->as_string() : " "; + auto result = std::make_unique(); + size_t pos = 0; + std::string token; + while ((pos = str.find(delim)) != std::string::npos) { + token = str.substr(0, pos); + result->val_arr->push_back(std::make_unique(token)); + str.erase(0, pos + delim.length()); + } + result->val_arr->push_back(std::make_unique(str)); + return std::move(result); + }}, + {"replace", [](const func_args & args) -> value { + args.ensure_vals(); + std::string str = args.args[0]->as_string(); + std::string old_str = args.args[1]->as_string(); + std::string new_str = args.args[2]->as_string(); + size_t pos = 0; + while ((pos = str.find(old_str, pos)) != std::string::npos) { + str.replace(pos, old_str.length(), new_str); + pos += new_str.length(); + } + return std::make_unique(str); + }}, + }; + return builtins; +}; + +} // namespace jinja diff --git a/common/jinja/jinja-vm.cpp b/common/jinja/jinja-vm.cpp index aff6e90603..25106f1e4a 100644 --- a/common/jinja/jinja-vm.cpp +++ b/common/jinja/jinja-vm.cpp @@ -1,6 +1,7 @@ #include "jinja-lexer.h" #include "jinja-vm.h" #include "jinja-parser.h" +#include "jinja-value.h" #include #include @@ -9,23 +10,6 @@ namespace jinja { -// Helper to extract the inner type if T is unique_ptr, else T itself -template -struct extract_pointee { - using type = T; -}; - -template -struct extract_pointee> { - using type = U; -}; - -template -static bool is_type(const value& ptr) { - using PointeeType = typename extract_pointee::type; - return dynamic_cast(ptr.get()) != nullptr; -} - template static bool is_stmt(const statement_ptr & ptr) { return dynamic_cast(ptr.get()) != nullptr; @@ -50,13 +34,13 @@ value binary_expression::execute(context & ctx) { } // Handle undefined and null values - if (is_type(left_val) || is_type(right_val)) { - if (is_type(right_val) && (op.value == "in" || op.value == "not in")) { + if (is_val(left_val) || is_val(right_val)) { + if (is_val(right_val) && (op.value == "in" || op.value == "not in")) { // Special case: `anything in undefined` is `false` and `anything not in undefined` is `true` return std::make_unique(op.value == "not in"); } throw std::runtime_error("Cannot perform operation " + op.value + " on undefined values"); - } else if (is_type(left_val) || is_type(right_val)) { + } else if (is_val(left_val) || is_val(right_val)) { throw std::runtime_error("Cannot perform operation on null values"); } @@ -66,13 +50,13 @@ value binary_expression::execute(context & ctx) { } // Float operations - if ((is_type(left_val) || is_type(left_val)) && - (is_type(right_val) || is_type(right_val))) { + if ((is_val(left_val) || is_val(left_val)) && + (is_val(right_val) || is_val(right_val))) { double a = left_val->as_float(); double b = right_val->as_float(); if (op.value == "+" || op.value == "-" || op.value == "*") { double res = (op.value == "+") ? a + b : (op.value == "-") ? a - b : a * b; - bool is_float = is_type(left_val) || is_type(right_val); + bool is_float = is_val(left_val) || is_val(right_val); if (is_float) { return std::make_unique(res); } else { @@ -82,7 +66,7 @@ value binary_expression::execute(context & ctx) { return std::make_unique(a / b); } else if (op.value == "%") { double rem = std::fmod(a, b); - bool is_float = is_type(left_val) || is_type(right_val); + bool is_float = is_val(left_val) || is_val(right_val); if (is_float) { return std::make_unique(rem); } else { @@ -100,7 +84,7 @@ value binary_expression::execute(context & ctx) { } // Array operations - if (is_type(left_val) && is_type(right_val)) { + if (is_val(left_val) && is_val(right_val)) { if (op.value == "+") { auto & left_arr = left_val->as_array(); auto & right_arr = right_val->as_array(); @@ -113,7 +97,7 @@ value binary_expression::execute(context & ctx) { } return result; } - } else if (is_type(right_val)) { + } else if (is_val(right_val)) { auto & arr = right_val->as_array(); bool member = std::find_if(arr.begin(), arr.end(), [&](const value& v) { return v == left_val; }) != arr.end(); if (op.value == "in") { @@ -124,14 +108,14 @@ value binary_expression::execute(context & ctx) { } // String concatenation - if (is_type(left_val) || is_type(right_val)) { + if (is_val(left_val) || is_val(right_val)) { if (op.value == "+") { return std::make_unique(left_val->as_string() + right_val->as_string()); } } // String membership - if (is_type(left_val) && is_type(right_val)) { + if (is_val(left_val) && is_val(right_val)) { auto left_str = left_val->as_string(); auto right_str = right_val->as_string(); if (op.value == "in") { @@ -142,7 +126,7 @@ value binary_expression::execute(context & ctx) { } // String in object - if (is_type(left_val) && is_type(right_val)) { + if (is_val(left_val) && is_val(right_val)) { auto key = left_val->as_string(); auto & obj = right_val->as_object(); bool has_key = obj.find(key) != obj.end(); @@ -158,7 +142,7 @@ value binary_expression::execute(context & ctx) { value filter_expression::execute(context & ctx) { value input = operand->execute(ctx); - value filter_func = filter->execute(ctx); + // value filter_func = filter->execute(ctx); if (is_stmt(filter)) { auto filter_val = dynamic_cast(filter.get())->value; @@ -168,7 +152,7 @@ value filter_expression::execute(context & ctx) { throw std::runtime_error("to_json filter not implemented"); } - if (is_type(input)) { + if (is_val(input)) { auto & arr = input->as_array(); if (filter_val == "list") { return std::make_unique(input); @@ -189,12 +173,18 @@ value filter_expression::execute(context & ctx) { throw std::runtime_error("Unknown filter '" + filter_val + "' for array"); } - } else if (is_type(input)) { + } else if (is_val(input)) { auto str = input->as_string(); - // TODO + auto builtins = input->get_builtins(); + auto it = builtins.find(filter_val); + if (it != builtins.end()) { + func_args args; + args.args.push_back(input->clone()); + return it->second(args); + } throw std::runtime_error("Unknown filter '" + filter_val + "' for string"); - } else if (is_type(input) || is_type(input)) { + } else if (is_val(input) || is_val(input)) { // TODO throw std::runtime_error("Unknown filter '" + filter_val + "' for number"); diff --git a/common/jinja/jinja-vm.h b/common/jinja/jinja-vm.h index 2c547294a8..ac5d679e88 100644 --- a/common/jinja/jinja-vm.h +++ b/common/jinja/jinja-vm.h @@ -23,7 +23,7 @@ struct context { struct statement { virtual ~statement() = default; virtual std::string type() const { return "Statement"; } - virtual value execute(context & ctx) { throw std::runtime_error("cannot exec " + type()); }; + virtual value execute(context & ctx) { throw std::runtime_error("cannot exec " + type()); } }; using statement_ptr = std::unique_ptr; diff --git a/tests/test-chat-jinja.cpp b/tests/test-chat-jinja.cpp index e0b5d8f8d9..3a8fc0cd87 100644 --- a/tests/test-chat-jinja.cpp +++ b/tests/test-chat-jinja.cpp @@ -13,7 +13,7 @@ int main(void) { //std::string contents = "{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '' + role + '\\n' + message['content'] | trim + '\\n' }}{% endfor %}{% if add_generation_prompt %}{{'model\\n'}}{% endif %}"; - std::string contents = "{{ 'hi' + 'fi' }}"; + std::string contents = "{{ ('hi' + 'fi') | upper }}"; std::cout << "=== INPUT ===\n" << contents << "\n\n";