diff --git a/common/jinja/jinja-string.h b/common/jinja/jinja-string.h new file mode 100644 index 0000000000..fb3371271f --- /dev/null +++ b/common/jinja/jinja-string.h @@ -0,0 +1,166 @@ +#pragma once + +#include +#include +#include +#include + + +namespace jinja { + +// allow differentiate between user input strings and template strings +// transformations should handle this information as follows: +// - one-to-one (e.g., uppercase, lowercase): preserve is_input flag +// - one-to-many (e.g., strip): if input string is marked as is_input, all resulting parts should be marked as is_input +// - many-to-one (e.g., concat): if ALL input parts are marked as is_input, resulting part should be marked as is_input +struct string_part { + bool is_input = false; // may skip parsing special tokens if true + std::string val; +}; + +struct string { + using transform_fn = std::function; + + std::vector parts; + string() = default; + string(const std::string & v, bool user_input = false) { + parts.push_back({user_input, v}); + } + string(int v) { + parts.push_back({false, std::to_string(v)}); + } + string(double v) { + parts.push_back({false, std::to_string(v)}); + } + + void mark_input() { + for (auto & part : parts) { + part.is_input = true; + } + } + + std::string str() const { + if (parts.size() == 1) { + return parts[0].val; + } + std::ostringstream oss; + for (const auto & part : parts) { + oss << part.val; + } + return oss.str(); + } + + size_t length() const { + size_t len = 0; + for (const auto & part : parts) { + len += part.val.length(); + } + return len; + } + + bool all_parts_are_input() const { + for (const auto & part : parts) { + if (!part.is_input) { + return false; + } + } + return true; + } + + // mark this string as input if other has ALL parts as input + void mark_input_based_on(const string & other) { + if (other.all_parts_are_input()) { + for (auto & part : parts) { + part.is_input = true; + } + } + } + + string append(const string & other) { + for (const auto & part : other.parts) { + parts.push_back(part); + } + return *this; + } + + // in-place transformation + + string apply_transform(const transform_fn & fn) { + for (auto & part : parts) { + part.val = fn(part.val); + } + return *this; + } + string uppercase() { + return apply_transform([](const std::string & s) { + std::string res = s; + std::transform(res.begin(), res.end(), res.begin(), ::toupper); + return res; + }); + } + string lowercase() { + return apply_transform([](const std::string & s) { + std::string res = s; + std::transform(res.begin(), res.end(), res.begin(), ::tolower); + return res; + }); + } + string capitalize() { + return apply_transform([](const std::string & s) { + if (s.empty()) return s; + std::string res = s; + res[0] = ::toupper(static_cast(res[0])); + std::transform(res.begin() + 1, res.end(), res.begin() + 1, ::tolower); + return res; + }); + } + string titlecase() { + return apply_transform([](const std::string & s) { + std::string res = s; + bool capitalize_next = true; + for (char &c : res) { + if (isspace(static_cast(c))) { + capitalize_next = true; + } else if (capitalize_next) { + c = ::toupper(static_cast(c)); + capitalize_next = false; + } else { + c = ::tolower(static_cast(c)); + } + } + return res; + }); + } + string strip(bool left, bool right) { + // TODO: what if leading/trailing continue in multiple parts? + + static auto strip_part = [](const std::string & s, bool left, bool right) -> std::string { + size_t start = 0; + size_t end = s.length(); + if (left) { + while (start < end && isspace(static_cast(s[start]))) { + ++start; + } + } + if (right) { + while (end > start && isspace(static_cast(s[end - 1]))) { + --end; + } + } + return s.substr(start, end - start); + }; + if (parts.empty()) { + return *this; + } + if (left) { + parts[0].val = strip_part(parts[0].val, true, false); + } + if (right) { + auto & last = parts[parts.size() - 1]; + last.val = strip_part(last.val, false, true); + } + return *this; + } +}; + +} // namespace jinja diff --git a/common/jinja/jinja-value.h b/common/jinja/jinja-value.h index 8b2d74ae35..74366de9ba 100644 --- a/common/jinja/jinja-value.h +++ b/common/jinja/jinja-value.h @@ -7,6 +7,7 @@ #include #include +#include "jinja-string.h" namespace jinja { @@ -80,7 +81,7 @@ bool value_compare(const value & a, const value & b); struct value_t { int64_t val_int; double val_flt; - std::string val_str; + string val_str; bool val_bool; // array and object are stored as shared_ptr to allow reference access @@ -102,7 +103,7 @@ struct value_t { virtual int64_t as_int() const { throw std::runtime_error(type() + " is not an int value"); } virtual double as_float() const { throw std::runtime_error(type() + " is not a float value"); } - virtual std::string as_string() const { throw std::runtime_error(type() + " is not a string value"); } + virtual string as_string() const { throw std::runtime_error(type() + " is not a string value"); } virtual bool as_bool() const { throw std::runtime_error(type() + " is not a bool value"); } virtual const std::vector & as_array() const { throw std::runtime_error(type() + " is not an array value"); } virtual const std::map & as_object() const { throw std::runtime_error(type() + " is not an object value"); } @@ -113,7 +114,7 @@ struct value_t { throw std::runtime_error("No builtins available for type " + type()); } - virtual std::string as_repr() const { return as_string(); } + virtual std::string as_repr() const { return as_string().str(); } virtual value clone() const { return std::make_unique(*this); @@ -126,7 +127,7 @@ struct value_int_t : public value_t { virtual std::string type() const override { return "Integer"; } virtual int64_t as_int() const override { return val_int; } virtual double as_float() const override { return static_cast(val_int); } - virtual std::string as_string() const override { return std::to_string(val_int); } + virtual string as_string() const override { return std::to_string(val_int); } virtual value clone() const override { return std::make_unique(*this); } virtual const func_builtins & get_builtins() const override; }; @@ -138,7 +139,7 @@ struct value_float_t : public value_t { virtual std::string type() const override { return "Float"; } virtual double as_float() const override { return val_flt; } virtual int64_t as_int() const override { return static_cast(val_flt); } - virtual std::string as_string() const override { return std::to_string(val_flt); } + virtual string as_string() const override { return std::to_string(val_flt); } virtual value clone() const override { return std::make_unique(*this); } virtual const func_builtins & get_builtins() const override; }; @@ -146,13 +147,23 @@ using value_float = std::unique_ptr; struct value_string_t : public value_t { - bool is_user_input = false; // may skip parsing special tokens if true - - value_string_t(const std::string & v) { val_str = v; } + value_string_t() { val_str = string(); } + value_string_t(const std::string & v) { val_str = string(v); } + value_string_t(const string & v) { val_str = v; } virtual std::string type() const override { return "String"; } - virtual std::string as_string() const override { return val_str; } + virtual string as_string() const override { return val_str; } + virtual std::string as_repr() const override { + std::ostringstream ss; + for (const auto & part : val_str.parts) { + ss << (part.is_input ? "INPUT: " : "TMPL: ") << part.val << "\n"; + } + return ss.str(); + } virtual value clone() const override { return std::make_unique(*this); } virtual const func_builtins & get_builtins() const override; + void mark_input() { + val_str.mark_input(); + } }; using value_string = std::unique_ptr; @@ -161,7 +172,7 @@ struct value_bool_t : public value_t { value_bool_t(bool v) { val_bool = v; } virtual std::string type() const override { return "Boolean"; } virtual bool as_bool() const override { return val_bool; } - virtual std::string as_string() const override { return val_bool ? "True" : "False"; } + virtual string as_string() const override { return std::string(val_bool ? "True" : "False"); } virtual value clone() const override { return std::make_unique(*this); } virtual const func_builtins & get_builtins() const override; }; @@ -200,7 +211,7 @@ struct value_array_t : public value_t { tmp->val_arr = this->val_arr; return tmp; } - virtual std::string as_string() const override { + virtual string as_string() const override { std::ostringstream ss; ss << "["; for (size_t i = 0; i < val_arr->size(); i++) { diff --git a/common/jinja/jinja-vm-builtins.cpp b/common/jinja/jinja-vm-builtins.cpp index 493c71e25e..e8c8eee993 100644 --- a/common/jinja/jinja-vm-builtins.cpp +++ b/common/jinja/jinja-vm-builtins.cpp @@ -12,7 +12,7 @@ const func_builtins & global_builtins() { static const func_builtins builtins = { {"raise_exception", [](const func_args & args) -> value { args.ensure_count(1); - std::string msg = args.args[0]->as_string(); + std::string msg = args.args[0]->as_string().str(); throw raised_exception("Jinja Exception: " + msg); }}, }; @@ -54,21 +54,21 @@ const func_builtins & value_float_t::get_builtins() const { } -static std::string string_strip(const std::string & str, bool left, bool right) { - size_t start = 0; - size_t end = str.length(); - if (left) { - while (start < end && isspace(static_cast(str[start]))) { - ++start; - } - } - if (right) { - while (end > start && isspace(static_cast(str[end - 1]))) { - --end; - } - } - return str.substr(start, end - start); -} +// static std::string string_strip(const std::string & str, bool left, bool right) { +// size_t start = 0; +// size_t end = str.length(); +// if (left) { +// while (start < end && isspace(static_cast(str[start]))) { +// ++start; +// } +// } +// if (right) { +// while (end > start && isspace(static_cast(str[end - 1]))) { +// --end; +// } +// } +// return str.substr(start, end - start); +// } static bool string_startswith(const std::string & str, const std::string & prefix) { if (str.length() < prefix.length()) return false; @@ -84,77 +84,60 @@ const func_builtins & value_string_t::get_builtins() const { static const func_builtins builtins = { {"upper", [](const func_args & args) -> value { args.ensure_vals(); - std::string str = args.args[0]->as_string(); - std::transform(str.begin(), str.end(), str.begin(), ::toupper); + jinja::string str = args.args[0]->as_string().uppercase(); return mk_val(str); }}, {"lower", [](const func_args & args) -> value { args.ensure_vals(); - std::string str = args.args[0]->as_string(); - std::transform(str.begin(), str.end(), str.begin(), ::tolower); + jinja::string str = args.args[0]->as_string().lowercase(); return mk_val(str); }}, {"strip", [](const func_args & args) -> value { args.ensure_vals(); - std::string str = args.args[0]->as_string(); - return mk_val(string_strip(str, true, true)); + jinja::string str = args.args[0]->as_string().strip(true, true); + return mk_val(str); }}, {"rstrip", [](const func_args & args) -> value { args.ensure_vals(); - std::string str = args.args[0]->as_string(); - return mk_val(string_strip(str, false, true)); + jinja::string str = args.args[0]->as_string().strip(false, true); + return mk_val(str); }}, {"lstrip", [](const func_args & args) -> value { args.ensure_vals(); - std::string str = args.args[0]->as_string(); - return mk_val(string_strip(str, true, false)); + jinja::string str = args.args[0]->as_string().strip(true, false); + return mk_val(str); }}, {"title", [](const func_args & args) -> value { args.ensure_vals(); - std::string str = args.args[0]->as_string(); - bool capitalize_next = true; - for (char &c : str) { - if (isspace(static_cast(c))) { - capitalize_next = true; - } else if (capitalize_next) { - c = ::toupper(static_cast(c)); - capitalize_next = false; - } else { - c = ::tolower(static_cast(c)); - } - } + jinja::string str = args.args[0]->as_string().titlecase(); return mk_val(str); }}, {"capitalize", [](const func_args & args) -> value { args.ensure_vals(); - std::string str = args.args[0]->as_string(); - if (!str.empty()) { - str[0] = ::toupper(static_cast(str[0])); - std::transform(str.begin() + 1, str.end(), str.begin() + 1, ::tolower); - } + jinja::string str = args.args[0]->as_string().capitalize(); return mk_val(str); }}, {"length", [](const func_args & args) -> value { args.ensure_vals(); - std::string str = args.args[0]->as_string(); + jinja::string str = args.args[0]->as_string(); return mk_val(str.length()); }}, {"startswith", [](const func_args & args) -> value { args.ensure_vals(); - std::string str = args.args[0]->as_string(); - std::string prefix = args.args[1]->as_string(); + std::string str = args.args[0]->as_string().str(); + std::string prefix = args.args[1]->as_string().str(); return mk_val(string_startswith(str, prefix)); }}, {"endswith", [](const func_args & args) -> value { args.ensure_vals(); - std::string str = args.args[0]->as_string(); - std::string suffix = args.args[1]->as_string(); + std::string str = args.args[0]->as_string().str(); + std::string suffix = args.args[1]->as_string().str(); return mk_val(string_endswith(str, suffix)); }}, {"split", [](const func_args & args) -> value { args.ensure_vals(); - std::string str = args.args[0]->as_string(); - std::string delim = (args.args.size() > 1) ? args.args[1]->as_string() : " "; + std::string str = args.args[0]->as_string().str(); + std::string delim = (args.args.size() > 1) ? args.args[1]->as_string().str() : " "; auto result = mk_val(); size_t pos = 0; std::string token; @@ -163,24 +146,28 @@ const func_builtins & value_string_t::get_builtins() const { result->val_arr->push_back(mk_val(token)); str.erase(0, pos + delim.length()); } - result->val_arr->push_back(mk_val(str)); + auto res = mk_val(str); + res->val_str.mark_input_based_on(args.args[0]->val_str); + result->val_arr->push_back(std::move(res)); return std::move(result); }}, {"replace", [](const func_args & args) -> value { args.ensure_vals(); - std::string str = args.args[0]->as_string(); - std::string old_str = args.args[1]->as_string(); - std::string new_str = args.args[2]->as_string(); + std::string str = args.args[0]->as_string().str(); + std::string old_str = args.args[1]->as_string().str(); + std::string new_str = args.args[2]->as_string().str(); size_t pos = 0; while ((pos = str.find(old_str, pos)) != std::string::npos) { str.replace(pos, old_str.length(), new_str); pos += new_str.length(); } - return mk_val(str); + auto res = mk_val(str); + res->val_str.mark_input_based_on(args.args[0]->val_str); + return res; }}, {"int", [](const func_args & args) -> value { args.ensure_vals(); - std::string str = args.args[0]->as_string(); + std::string str = args.args[0]->as_string().str(); try { return mk_val(std::stoi(str)); } catch (...) { @@ -189,7 +176,7 @@ const func_builtins & value_string_t::get_builtins() const { }}, {"float", [](const func_args & args) -> value { args.ensure_vals(); - std::string str = args.args[0]->as_string(); + std::string str = args.args[0]->as_string().str(); try { return mk_val(std::stod(str)); } catch (...) { @@ -277,7 +264,7 @@ const func_builtins & value_object_t::get_builtins() const { {"get", [](const func_args & args) -> value { args.ensure_vals(); // TODO: add default value const auto & obj = args.args[0]->as_object(); - std::string key = args.args[1]->as_string(); + std::string key = args.args[1]->as_string().str(); auto it = obj.find(key); if (it != obj.end()) { return it->second->clone(); diff --git a/common/jinja/jinja-vm.cpp b/common/jinja/jinja-vm.cpp index 7fb323c58b..c6861eeb39 100644 --- a/common/jinja/jinja-vm.cpp +++ b/common/jinja/jinja-vm.cpp @@ -72,11 +72,6 @@ value binary_expression::execute(context & ctx) { throw std::runtime_error("Cannot perform operation on null values"); } - // String concatenation with ~ - if (op.value == "~") { - return mk_val(left_val->as_string() + right_val->as_string()); - } - // Float operations if ((is_val(left_val) || is_val(left_val)) && (is_val(right_val) || is_val(right_val))) { @@ -137,18 +132,20 @@ value binary_expression::execute(context & ctx) { } } - // String concatenation - if (is_val(left_val) || is_val(right_val)) { - JJ_DEBUG("%s", "String concatenation with + operator"); - if (op.value == "+") { - return mk_val(left_val->as_string() + right_val->as_string()); - } + // String concatenation with ~ and + + if ((is_val(left_val) || is_val(right_val)) && + (op.value == "~" || op.value == "+")) { + JJ_DEBUG("String concatenation with %s operator", op.value.c_str()); + auto output = left_val->as_string().append(right_val->as_string()); + auto res = mk_val(); + res->val_str = std::move(output); + return res; } // String membership if (is_val(left_val) && is_val(right_val)) { - auto left_str = left_val->as_string(); - auto right_str = right_val->as_string(); + auto left_str = left_val->as_string().str(); + auto right_str = right_val->as_string().str(); if (op.value == "in") { return mk_val(right_str.find(left_str) != std::string::npos); } else if (op.value == "not in") { @@ -158,7 +155,7 @@ value binary_expression::execute(context & ctx) { // String in object if (is_val(left_val) && is_val(right_val)) { - auto key = left_val->as_string(); + auto key = left_val->as_string().str(); auto & obj = right_val->as_object(); bool has_key = obj.find(key) != obj.end(); if (op.value == "in") { @@ -434,7 +431,7 @@ value member_expression::execute(context & ctx) { if (!is_val(property)) { throw std::runtime_error("Cannot access object with non-string: got " + property->type()); } - auto key = property->as_string(); + auto key = property->as_string().str(); auto & obj = object->as_object(); auto it = obj.find(key); if (it != obj.end()) { @@ -459,13 +456,13 @@ value member_expression::execute(context & ctx) { val = arr[index]->clone(); } } else { // value_string - auto str = object->as_string(); + auto str = object->as_string().str(); if (index >= 0 && index < static_cast(str.size())) { val = mk_val(std::string(1, str[index])); } } } else if (is_val(property)) { - auto key = property->as_string(); + auto key = property->as_string().str(); JJ_DEBUG("Accessing %s built-in '%s'", is_val(object) ? "array" : "string", key.c_str()); auto builtins = object->get_builtins(); auto bit = builtins.find(key); @@ -482,7 +479,7 @@ value member_expression::execute(context & ctx) { if (!is_val(property)) { throw std::runtime_error("Cannot access property with non-string: got " + property->type()); } - auto key = property->as_string(); + auto key = property->as_string().str(); auto builtins = object->get_builtins(); auto bit = builtins.find(key); if (bit != builtins.end()) { @@ -528,7 +525,7 @@ bool value_compare(const value & a, const value & b) { if ((is_val(b) && (is_val(a) || is_val(a))) || (is_val(a) && (is_val(b) || is_val(b)))) { try { - return a->as_string() == b->as_string(); + return a->as_string().str() == b->as_string().str(); } catch (...) {} } // compare boolean simple @@ -537,7 +534,7 @@ bool value_compare(const value & a, const value & b) { } // compare string simple if (is_val(a) && is_val(b)) { - return a->as_string() == b->as_string(); + return a->as_string().str() == b->as_string().str(); } // compare by type if (a->type() != b->type()) { diff --git a/tests/test-chat-jinja.cpp b/tests/test-chat-jinja.cpp index 085531a673..acbf7daf2a 100644 --- a/tests/test-chat-jinja.cpp +++ b/tests/test-chat-jinja.cpp @@ -39,7 +39,7 @@ int main(void) { auto make_non_special_string = [](const std::string & s) { jinja::value_string str_val = jinja::mk_val(s); - str_val->is_user_input = true; + str_val->mark_input(); return str_val; }; @@ -63,12 +63,7 @@ int main(void) { if (res->is_null()) { continue; } - auto str_ptr = dynamic_cast(res.get()); - std::string is_user_input = "false"; - if (str_ptr) { - is_user_input = str_ptr->is_user_input ? "true" : "false"; - } - std::cout << "result type: " << res->type() << " | value: " << res->as_string() << " | is_user_input: " << is_user_input << "\n"; + std::cout << "result type: " << res->type() << " | value: " << res->as_repr(); } return 0;