add string builtins
This commit is contained in:
parent
5a041e65b8
commit
15b3dbab05
|
|
@ -84,8 +84,12 @@ add_library(${TARGET} STATIC
|
|||
unicode.cpp
|
||||
unicode.h
|
||||
jinja/jinja-lexer.cpp
|
||||
jinja/jinja-lexer.h
|
||||
jinja/jinja-parser.cpp
|
||||
jinja/jinja-parser.h
|
||||
jinja/jinja-vm.cpp
|
||||
jinja/jinja-vm.h
|
||||
jinja/jinja-vm-builtins.cpp
|
||||
)
|
||||
|
||||
target_include_directories(${TARGET} PUBLIC . ../vendor)
|
||||
|
|
|
|||
|
|
@ -3,6 +3,8 @@
|
|||
#include <vector>
|
||||
#include <string>
|
||||
#include <map>
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
|
||||
|
||||
namespace jinja {
|
||||
|
|
@ -10,6 +12,63 @@ namespace jinja {
|
|||
struct value_t;
|
||||
using value = std::unique_ptr<value_t>;
|
||||
|
||||
|
||||
// Helper to check the type of a value
|
||||
// Type trait that maps a type to the type it "points at".
// Primary template: any type that is not a std::unique_ptr maps to itself.
template<typename Plain>
struct extract_pointee {
    using type = Plain;
};

// Specialization: std::unique_ptr<Pointee> maps to Pointee, so the helpers
// built on this trait accept either a struct type or its unique_ptr alias.
template<typename Pointee>
struct extract_pointee<std::unique_ptr<Pointee>> {
    using type = Pointee;
};
|
||||
template<typename T>
|
||||
bool is_val(const value & ptr) {
|
||||
using PointeeType = typename extract_pointee<T>::type;
|
||||
return dynamic_cast<const PointeeType*>(ptr.get()) != nullptr;
|
||||
}
|
||||
template<typename T, typename... Args>
|
||||
bool mk_val(Args&&... args) {
|
||||
using PointeeType = typename extract_pointee<T>::type;
|
||||
return std::make_unique<PointeeType>(std::forward<Args>(args)...);
|
||||
}
|
||||
template<typename T>
|
||||
void ensure_val(const value & ptr) {
|
||||
if (!is_val<T>(ptr)) {
|
||||
throw std::runtime_error("Expected value of type " + std::string(typeid(T).name()));
|
||||
}
|
||||
}
|
||||
// End Helper
|
||||
|
||||
|
||||
struct func_args {
|
||||
std::vector<value> args;
|
||||
void ensure_count(size_t count) const {
|
||||
if (args.size() != count) {
|
||||
throw std::runtime_error("Expected " + std::to_string(count) + " arguments, got " + std::to_string(args.size()));
|
||||
}
|
||||
}
|
||||
// utility functions
|
||||
template<typename T> void ensure_vals() const {
|
||||
ensure_count(1);
|
||||
ensure_val<T>(args[0]);
|
||||
}
|
||||
template<typename T, typename U> void ensure_vals() const {
|
||||
ensure_count(2);
|
||||
ensure_val<T>(args[0]);
|
||||
ensure_val<U>(args[1]);
|
||||
}
|
||||
template<typename T, typename U, typename V> void ensure_vals() const {
|
||||
ensure_count(3);
|
||||
ensure_val<T>(args[0]);
|
||||
ensure_val<U>(args[1]);
|
||||
ensure_val<V>(args[2]);
|
||||
}
|
||||
};
|
||||
|
||||
using func_handler = std::function<value(const func_args &)>;
|
||||
using func_builtins = std::map<std::string, func_handler>;
|
||||
|
||||
struct value_t {
|
||||
int64_t val_int;
|
||||
double val_flt;
|
||||
|
|
@ -25,6 +84,8 @@ struct value_t {
|
|||
std::shared_ptr<std::vector<value>> val_arr;
|
||||
std::shared_ptr<std::map<std::string, value>> val_obj;
|
||||
|
||||
func_handler val_func;
|
||||
|
||||
value_t() = default;
|
||||
value_t(const value_t &) = default;
|
||||
virtual ~value_t() = default;
|
||||
|
|
@ -37,8 +98,12 @@ struct value_t {
|
|||
virtual bool as_bool() const { throw std::runtime_error("Not a bool value"); }
|
||||
virtual const std::vector<value> & as_array() const { throw std::runtime_error("Not an array value"); }
|
||||
virtual const std::map<std::string, value> & as_object() const { throw std::runtime_error("Not an object value"); }
|
||||
virtual value invoke(const func_args &) const { throw std::runtime_error("Not a function value"); }
|
||||
virtual bool is_null() const { return false; }
|
||||
virtual bool is_undefined() const { return false; }
|
||||
virtual const func_builtins & get_builtins() const {
|
||||
throw std::runtime_error("No builtins available for type " + type());
|
||||
}
|
||||
|
||||
virtual value clone() const {
|
||||
return std::make_unique<value_t>(*this);
|
||||
|
|
@ -78,6 +143,7 @@ struct value_string_t : public value_t {
|
|||
virtual std::string type() const override { return "String"; }
|
||||
virtual std::string as_string() const override { return val_str; }
|
||||
virtual value clone() const override { return std::make_unique<value_string_t>(*this); }
|
||||
const func_builtins & get_builtins() const override;
|
||||
};
|
||||
using value_string = std::unique_ptr<value_string_t>;
|
||||
|
||||
|
|
@ -145,6 +211,18 @@ struct value_object_t : public value_t {
|
|||
};
|
||||
using value_object = std::unique_ptr<value_object_t>;
|
||||
|
||||
struct value_func_t : public value_t {
|
||||
value_func_t(func_handler & func) {
|
||||
val_func = func;
|
||||
}
|
||||
virtual value invoke(const func_args & args) const override {
|
||||
return val_func(args);
|
||||
}
|
||||
virtual std::string type() const override { return "Function"; }
|
||||
virtual value clone() const override { return std::make_unique<value_func_t>(*this); }
|
||||
};
|
||||
using value_func = std::unique_ptr<value_func_t>;
|
||||
|
||||
struct value_null_t : public value_t {
|
||||
virtual std::string type() const override { return "Null"; }
|
||||
virtual bool is_null() const override { return true; }
|
||||
|
|
|
|||
|
|
@ -0,0 +1,139 @@
|
|||
#include "jinja-lexer.h"
|
||||
#include "jinja-vm.h"
|
||||
#include "jinja-parser.h"
|
||||
#include "jinja-value.h"
|
||||
|
||||
#include <algorithm>
#include <cctype>
#include <stdexcept>
#include <string>
|
||||
|
||||
namespace jinja {
|
||||
|
||||
// Returns a copy of `str` with whitespace removed from the front (when `left`
// is true) and/or from the back (when `right` is true). Whitespace is decided
// per byte by isspace(); the byte is widened through unsigned char first so
// that negative char values are never passed to it.
static std::string string_strip(const std::string & str, bool left, bool right) {
    const auto is_ws = [](char ch) {
        return std::isspace(static_cast<unsigned char>(ch)) != 0;
    };

    size_t first = 0;
    size_t last  = str.size(); // one past the final kept character

    for (; left && first < last && is_ws(str[first]); ++first) {
    }
    for (; right && last > first && is_ws(str[last - 1]); --last) {
    }

    return std::string(str, first, last - first);
}
|
||||
|
||||
// Returns true when `str` begins with `prefix`. An empty prefix always matches.
static bool string_startswith(const std::string & str, const std::string & prefix) {
    const size_t n = prefix.size();
    return str.size() >= n && str.compare(0, n, prefix) == 0;
}
|
||||
|
||||
// Returns true when `str` ends with `suffix`. An empty suffix always matches.
static bool string_endswith(const std::string & str, const std::string & suffix) {
    const size_t n = suffix.size();
    if (n > str.size()) {
        return false;
    }
    const size_t tail_start = str.size() - n;
    return str.compare(tail_start, n, suffix) == 0;
}
|
||||
|
||||
// Upper-cases a single byte. The round-trip through unsigned char is required:
// passing a negative char value to toupper() is undefined behaviour.
static char char_upper(char c) {
    return static_cast<char>(std::toupper(static_cast<unsigned char>(c)));
}

// Lower-cases a single byte; see char_upper for why unsigned char is used.
static char char_lower(char c) {
    return static_cast<char>(std::tolower(static_cast<unsigned char>(c)));
}

// Returns the table of builtin functions available on string values, keyed by
// name. The table is built once on first use and shared by all strings.
//
// Every handler receives the string itself as args[0] and any extra arguments
// after it. Handlers throw std::runtime_error (via ensure_vals / ensure_val)
// when the argument count or types do not match. All case and whitespace
// handling works byte by byte using the <cctype> classification functions.
const func_builtins & value_string_t::get_builtins() const {
    static const func_builtins builtins = {
        {"upper", [](const func_args & args) -> value {
            args.ensure_vals<value_string>();
            std::string str = args.args[0]->as_string();
            std::transform(str.begin(), str.end(), str.begin(), char_upper);
            return std::make_unique<value_string_t>(str);
        }},
        {"lower", [](const func_args & args) -> value {
            args.ensure_vals<value_string>();
            std::string str = args.args[0]->as_string();
            std::transform(str.begin(), str.end(), str.begin(), char_lower);
            return std::make_unique<value_string_t>(str);
        }},
        {"strip", [](const func_args & args) -> value {
            args.ensure_vals<value_string>();
            std::string str = args.args[0]->as_string();
            return std::make_unique<value_string_t>(string_strip(str, true, true));
        }},
        {"rstrip", [](const func_args & args) -> value {
            args.ensure_vals<value_string>();
            std::string str = args.args[0]->as_string();
            return std::make_unique<value_string_t>(string_strip(str, false, true));
        }},
        {"lstrip", [](const func_args & args) -> value {
            args.ensure_vals<value_string>();
            std::string str = args.args[0]->as_string();
            return std::make_unique<value_string_t>(string_strip(str, true, false));
        }},
        {"title", [](const func_args & args) -> value {
            args.ensure_vals<value_string>();
            std::string str = args.args[0]->as_string();
            // Upper-case the first byte after each run of whitespace and
            // lower-case every other non-whitespace byte.
            bool capitalize_next = true;
            for (char & c : str) {
                if (std::isspace(static_cast<unsigned char>(c))) {
                    capitalize_next = true;
                } else if (capitalize_next) {
                    c = char_upper(c);
                    capitalize_next = false;
                } else {
                    c = char_lower(c);
                }
            }
            return std::make_unique<value_string_t>(str);
        }},
        {"capitalize", [](const func_args & args) -> value {
            args.ensure_vals<value_string>();
            std::string str = args.args[0]->as_string();
            if (!str.empty()) {
                str[0] = char_upper(str[0]);
                std::transform(str.begin() + 1, str.end(), str.begin() + 1, char_lower);
            }
            return std::make_unique<value_string_t>(str);
        }},
        {"length", [](const func_args & args) -> value {
            args.ensure_vals<value_string>();
            std::string str = args.args[0]->as_string();
            // Length in bytes, as reported by std::string.
            return std::make_unique<value_int_t>(str.length());
        }},
        {"startswith", [](const func_args & args) -> value {
            args.ensure_vals<value_string, value_string>();
            std::string str = args.args[0]->as_string();
            std::string prefix = args.args[1]->as_string();
            return std::make_unique<value_bool_t>(string_startswith(str, prefix));
        }},
        {"endswith", [](const func_args & args) -> value {
            args.ensure_vals<value_string, value_string>();
            std::string str = args.args[0]->as_string();
            std::string suffix = args.args[1]->as_string();
            return std::make_unique<value_bool_t>(string_endswith(str, suffix));
        }},
        {"split", [](const func_args & args) -> value {
            // The delimiter is optional, so the fixed-arity ensure_vals<>()
            // helpers cannot be used: validate count and types by hand.
            const size_t n_args = args.args.size();
            if (n_args < 1 || n_args > 2) {
                throw std::runtime_error("Expected 1 or 2 arguments, got " + std::to_string(n_args));
            }
            ensure_val<value_string>(args.args[0]);
            if (n_args == 2) {
                ensure_val<value_string>(args.args[1]);
            }
            const std::string str = args.args[0]->as_string();
            const std::string delim = (n_args == 2) ? args.args[1]->as_string() : " ";
            // find("") matches at every position without advancing, which
            // would loop forever; reject it up front.
            if (delim.empty()) {
                throw std::runtime_error("split: empty separator");
            }
            auto result = std::make_unique<value_array_t>();
            // Walk the string with a moving offset instead of erasing the
            // consumed prefix on every match.
            size_t start = 0;
            size_t pos = 0;
            while ((pos = str.find(delim, start)) != std::string::npos) {
                result->val_arr->push_back(std::make_unique<value_string_t>(str.substr(start, pos - start)));
                start = pos + delim.length();
            }
            // Whatever follows the last delimiter (possibly empty) is the final token.
            result->val_arr->push_back(std::make_unique<value_string_t>(str.substr(start)));
            return std::move(result);
        }},
        {"replace", [](const func_args & args) -> value {
            args.ensure_vals<value_string, value_string, value_string>();
            std::string str = args.args[0]->as_string();
            std::string old_str = args.args[1]->as_string();
            std::string new_str = args.args[2]->as_string();
            // An empty search string matches everywhere and would never let
            // the loop terminate; in that case the input is returned unchanged.
            // NOTE(review): Python inserts new_str between every character
            // here - confirm whether that behaviour is wanted.
            if (!old_str.empty()) {
                size_t pos = 0;
                while ((pos = str.find(old_str, pos)) != std::string::npos) {
                    str.replace(pos, old_str.length(), new_str);
                    // Skip past the inserted text so it is never re-scanned.
                    pos += new_str.length();
                }
            }
            return std::make_unique<value_string_t>(str);
        }},
    };
    return builtins;
}
|
||||
|
||||
} // namespace jinja
|
||||
|
|
@ -1,6 +1,7 @@
|
|||
#include "jinja-lexer.h"
|
||||
#include "jinja-vm.h"
|
||||
#include "jinja-parser.h"
|
||||
#include "jinja-value.h"
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
|
@ -9,23 +10,6 @@
|
|||
|
||||
namespace jinja {
|
||||
|
||||
// Helper to extract the inner type if T is unique_ptr<U>, else T itself
|
||||
template<typename T>
|
||||
struct extract_pointee {
|
||||
using type = T;
|
||||
};
|
||||
|
||||
template<typename U>
|
||||
struct extract_pointee<std::unique_ptr<U>> {
|
||||
using type = U;
|
||||
};
|
||||
|
||||
template<typename T>
|
||||
static bool is_type(const value& ptr) {
|
||||
using PointeeType = typename extract_pointee<T>::type;
|
||||
return dynamic_cast<const PointeeType*>(ptr.get()) != nullptr;
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static bool is_stmt(const statement_ptr & ptr) {
|
||||
return dynamic_cast<const T*>(ptr.get()) != nullptr;
|
||||
|
|
@ -50,13 +34,13 @@ value binary_expression::execute(context & ctx) {
|
|||
}
|
||||
|
||||
// Handle undefined and null values
|
||||
if (is_type<value_undefined>(left_val) || is_type<value_undefined>(right_val)) {
|
||||
if (is_type<value_undefined>(right_val) && (op.value == "in" || op.value == "not in")) {
|
||||
if (is_val<value_undefined>(left_val) || is_val<value_undefined>(right_val)) {
|
||||
if (is_val<value_undefined>(right_val) && (op.value == "in" || op.value == "not in")) {
|
||||
// Special case: `anything in undefined` is `false` and `anything not in undefined` is `true`
|
||||
return std::make_unique<value_bool_t>(op.value == "not in");
|
||||
}
|
||||
throw std::runtime_error("Cannot perform operation " + op.value + " on undefined values");
|
||||
} else if (is_type<value_null>(left_val) || is_type<value_null>(right_val)) {
|
||||
} else if (is_val<value_null>(left_val) || is_val<value_null>(right_val)) {
|
||||
throw std::runtime_error("Cannot perform operation on null values");
|
||||
}
|
||||
|
||||
|
|
@ -66,13 +50,13 @@ value binary_expression::execute(context & ctx) {
|
|||
}
|
||||
|
||||
// Float operations
|
||||
if ((is_type<value_int>(left_val) || is_type<value_float>(left_val)) &&
|
||||
(is_type<value_int>(right_val) || is_type<value_float>(right_val))) {
|
||||
if ((is_val<value_int>(left_val) || is_val<value_float>(left_val)) &&
|
||||
(is_val<value_int>(right_val) || is_val<value_float>(right_val))) {
|
||||
double a = left_val->as_float();
|
||||
double b = right_val->as_float();
|
||||
if (op.value == "+" || op.value == "-" || op.value == "*") {
|
||||
double res = (op.value == "+") ? a + b : (op.value == "-") ? a - b : a * b;
|
||||
bool is_float = is_type<value_float>(left_val) || is_type<value_float>(right_val);
|
||||
bool is_float = is_val<value_float>(left_val) || is_val<value_float>(right_val);
|
||||
if (is_float) {
|
||||
return std::make_unique<value_float_t>(res);
|
||||
} else {
|
||||
|
|
@ -82,7 +66,7 @@ value binary_expression::execute(context & ctx) {
|
|||
return std::make_unique<value_float_t>(a / b);
|
||||
} else if (op.value == "%") {
|
||||
double rem = std::fmod(a, b);
|
||||
bool is_float = is_type<value_float>(left_val) || is_type<value_float>(right_val);
|
||||
bool is_float = is_val<value_float>(left_val) || is_val<value_float>(right_val);
|
||||
if (is_float) {
|
||||
return std::make_unique<value_float_t>(rem);
|
||||
} else {
|
||||
|
|
@ -100,7 +84,7 @@ value binary_expression::execute(context & ctx) {
|
|||
}
|
||||
|
||||
// Array operations
|
||||
if (is_type<value_array>(left_val) && is_type<value_array>(right_val)) {
|
||||
if (is_val<value_array>(left_val) && is_val<value_array>(right_val)) {
|
||||
if (op.value == "+") {
|
||||
auto & left_arr = left_val->as_array();
|
||||
auto & right_arr = right_val->as_array();
|
||||
|
|
@ -113,7 +97,7 @@ value binary_expression::execute(context & ctx) {
|
|||
}
|
||||
return result;
|
||||
}
|
||||
} else if (is_type<value_array>(right_val)) {
|
||||
} else if (is_val<value_array>(right_val)) {
|
||||
auto & arr = right_val->as_array();
|
||||
bool member = std::find_if(arr.begin(), arr.end(), [&](const value& v) { return v == left_val; }) != arr.end();
|
||||
if (op.value == "in") {
|
||||
|
|
@ -124,14 +108,14 @@ value binary_expression::execute(context & ctx) {
|
|||
}
|
||||
|
||||
// String concatenation
|
||||
if (is_type<value_string>(left_val) || is_type<value_string>(right_val)) {
|
||||
if (is_val<value_string>(left_val) || is_val<value_string>(right_val)) {
|
||||
if (op.value == "+") {
|
||||
return std::make_unique<value_string_t>(left_val->as_string() + right_val->as_string());
|
||||
}
|
||||
}
|
||||
|
||||
// String membership
|
||||
if (is_type<value_string>(left_val) && is_type<value_string>(right_val)) {
|
||||
if (is_val<value_string>(left_val) && is_val<value_string>(right_val)) {
|
||||
auto left_str = left_val->as_string();
|
||||
auto right_str = right_val->as_string();
|
||||
if (op.value == "in") {
|
||||
|
|
@ -142,7 +126,7 @@ value binary_expression::execute(context & ctx) {
|
|||
}
|
||||
|
||||
// String in object
|
||||
if (is_type<value_string>(left_val) && is_type<value_object>(right_val)) {
|
||||
if (is_val<value_string>(left_val) && is_val<value_object>(right_val)) {
|
||||
auto key = left_val->as_string();
|
||||
auto & obj = right_val->as_object();
|
||||
bool has_key = obj.find(key) != obj.end();
|
||||
|
|
@ -158,7 +142,7 @@ value binary_expression::execute(context & ctx) {
|
|||
|
||||
value filter_expression::execute(context & ctx) {
|
||||
value input = operand->execute(ctx);
|
||||
value filter_func = filter->execute(ctx);
|
||||
// value filter_func = filter->execute(ctx);
|
||||
|
||||
if (is_stmt<identifier>(filter)) {
|
||||
auto filter_val = dynamic_cast<identifier*>(filter.get())->value;
|
||||
|
|
@ -168,7 +152,7 @@ value filter_expression::execute(context & ctx) {
|
|||
throw std::runtime_error("to_json filter not implemented");
|
||||
}
|
||||
|
||||
if (is_type<value_array>(input)) {
|
||||
if (is_val<value_array>(input)) {
|
||||
auto & arr = input->as_array();
|
||||
if (filter_val == "list") {
|
||||
return std::make_unique<value_array_t>(input);
|
||||
|
|
@ -189,12 +173,18 @@ value filter_expression::execute(context & ctx) {
|
|||
throw std::runtime_error("Unknown filter '" + filter_val + "' for array");
|
||||
}
|
||||
|
||||
} else if (is_type<value_string>(input)) {
|
||||
} else if (is_val<value_string>(input)) {
|
||||
auto str = input->as_string();
|
||||
// TODO
|
||||
auto builtins = input->get_builtins();
|
||||
auto it = builtins.find(filter_val);
|
||||
if (it != builtins.end()) {
|
||||
func_args args;
|
||||
args.args.push_back(input->clone());
|
||||
return it->second(args);
|
||||
}
|
||||
throw std::runtime_error("Unknown filter '" + filter_val + "' for string");
|
||||
|
||||
} else if (is_type<value_int>(input) || is_type<value_float>(input)) {
|
||||
} else if (is_val<value_int>(input) || is_val<value_float>(input)) {
|
||||
// TODO
|
||||
throw std::runtime_error("Unknown filter '" + filter_val + "' for number");
|
||||
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@ struct context {
|
|||
struct statement {
|
||||
virtual ~statement() = default;
|
||||
virtual std::string type() const { return "Statement"; }
|
||||
virtual value execute(context & ctx) { throw std::runtime_error("cannot exec " + type()); };
|
||||
virtual value execute(context & ctx) { throw std::runtime_error("cannot exec " + type()); }
|
||||
};
|
||||
|
||||
using statement_ptr = std::unique_ptr<statement>;
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@
|
|||
int main(void) {
|
||||
//std::string contents = "{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '<start_of_turn>' + role + '\\n' + message['content'] | trim + '<end_of_turn>\\n' }}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\\n'}}{% endif %}";
|
||||
|
||||
std::string contents = "{{ 'hi' + 'fi' }}";
|
||||
std::string contents = "{{ ('hi' + 'fi') | upper }}";
|
||||
|
||||
std::cout << "=== INPUT ===\n" << contents << "\n\n";
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue