diff --git a/common/jinja/jinja-value.h b/common/jinja/jinja-value.h index b06f465a1d..01cfffe529 100644 --- a/common/jinja/jinja-value.h +++ b/common/jinja/jinja-value.h @@ -15,12 +15,22 @@ struct value_t { double val_flt; std::string val_str; bool val_bool; - std::vector val_arr; - std::map val_obj; + + // array and object are stored as shared_ptr to allow reference access + // example: + // my_obj = {"a": 1, "b": 2} + // my_arr = [my_obj] + // my_obj["a"] = 3 + // print(my_arr[0]["a"]) # should print 3 + std::shared_ptr> val_arr; + std::shared_ptr> val_obj; + + value_t() = default; + value_t(const value_t &) = default; + virtual ~value_t() = default; virtual std::string type() const { return ""; } - virtual ~value_t() = default; virtual int64_t as_int() const { throw std::runtime_error("Not an int value"); } virtual double as_float() const { throw std::runtime_error("Not a float value"); } virtual std::string as_string() const { throw std::runtime_error("Not a string value"); } @@ -30,6 +40,10 @@ struct value_t { virtual bool is_null() const { return false; } virtual bool is_undefined() const { return false; } + virtual value clone() const { + return std::make_unique(*this); + } + virtual bool operator==(const value & other) const { // TODO return false; @@ -44,6 +58,8 @@ struct value_int_t : public value_t { virtual std::string type() const override { return "Integer"; } virtual int64_t as_int() const override { return val_int; } virtual double as_float() const override { return static_cast(val_int); } + virtual std::string as_string() const override { return std::to_string(val_int); } + virtual value clone() const override { return std::make_unique(*this); } }; using value_int = std::unique_ptr; @@ -52,6 +68,8 @@ struct value_float_t : public value_t { virtual std::string type() const override { return "Float"; } virtual double as_float() const override { return val_flt; } virtual int64_t as_int() const override { return static_cast(val_flt); } + virtual std::string as_string() const override { return std::to_string(val_flt); } + virtual value clone() const override { return std::make_unique(*this); } }; using value_float = std::unique_ptr; @@ -59,6 +77,7 @@ struct value_string_t : public value_t { value_string_t(const std::string & v) { val_str = v; } virtual std::string type() const override { return "String"; } virtual std::string as_string() const override { return val_str; } + virtual value clone() const override { return std::make_unique(*this); } }; using value_string = std::unique_ptr; @@ -66,32 +85,81 @@ struct value_bool_t : public value_t { value_bool_t(bool v) { val_bool = v; } virtual std::string type() const override { return "Boolean"; } virtual bool as_bool() const override { return val_bool; } + virtual std::string as_string() const override { return val_bool ? "True" : "False"; } + virtual value clone() const override { return std::make_unique(*this); } }; using value_bool = std::unique_ptr; struct value_array_t : public value_t { - value_array_t(const std::vector && v) { val_arr = std::move(v); } + value_array_t() { + val_arr = std::make_shared>(); + } + value_array_t(value & v) { + // point to the same underlying data + val_arr = v->val_arr; + } + value_array_t(value_array_t & other, size_t start = 0, size_t end = -1) { + val_arr = std::make_shared>(); + size_t sz = other.val_arr->size(); + if (end == static_cast(-1) || end > sz) { + end = sz; + } + if (start > end || start >= sz) { + return; + } + for (size_t i = start; i < end; i++) { + val_arr->push_back(other.val_arr->at(i)->clone()); + } + } virtual std::string type() const override { return "Array"; } - virtual const std::vector & as_array() const override { return val_arr; } + virtual const std::vector & as_array() const override { return *val_arr; } + virtual value clone() const override { + auto tmp = std::make_unique(); + tmp->val_arr = this->val_arr; + return tmp; + } }; using value_array = std::unique_ptr; -struct value_object_t : public value_t { - value_object_t(const std::map & v) { val_obj = v; } +/*struct value_object_t : public value_t { + value_object_t() { + val_obj = std::make_shared>(); + } + value_object_t(value & v) { + // point to the same underlying data + val_obj = v->val_obj; + } + value_object_t(const std::map & obj) { + val_obj = std::make_shared>(obj); + } virtual std::string type() const override { return "Object"; } - virtual const std::map & as_object() const override { return val_obj; } + virtual const std::map & as_object() const override { return *val_obj; } + virtual value clone() const override { + auto tmp = std::make_unique(); + tmp->val_obj = this->val_obj; + return tmp; + } +}; +using value_object = std::unique_ptr;*/ + +struct value_object_t : public value_t { + virtual std::string type() const override { return "TEST"; } + virtual bool is_null() const override { return true; } + virtual value clone() const override { return std::make_unique(*this); } }; using value_object = std::unique_ptr; struct value_null_t : public value_t { virtual std::string type() const override { return "Null"; } virtual bool is_null() const override { return true; } + virtual value clone() const override { return std::make_unique(*this); } }; using value_null = std::unique_ptr; struct value_undefined_t : public value_t { virtual std::string type() const override { return "Undefined"; } virtual bool is_undefined() const override { return true; } + virtual value clone() const override { return std::make_unique(*this); } }; using value_undefined = std::unique_ptr; diff --git a/common/jinja/jinja-vm.cpp b/common/jinja/jinja-vm.cpp index 1c3ec49013..aff6e90603 100644 --- a/common/jinja/jinja-vm.cpp +++ b/common/jinja/jinja-vm.cpp @@ -9,22 +9,27 @@ namespace jinja { -// Helper to check type without asserting (useful for logic) +// Helper to extract the inner type if T is unique_ptr, else T itself template -static bool is_type(const value & ptr) { - return dynamic_cast(ptr.get()) != nullptr; +struct extract_pointee { + using type = T; +}; + +template +struct extract_pointee> { + using type = U; +}; + +template +static bool is_type(const value& ptr) { + using PointeeType = typename extract_pointee::type; + return dynamic_cast(ptr.get()) != nullptr; } -struct vm { - context & ctx; - explicit vm(context & ctx) : ctx(ctx) {} - - void execute(program & prog) { - for (auto & stmt : prog.body) { - stmt->execute(ctx); - } - } -}; +template +static bool is_stmt(const statement_ptr & ptr) { + return dynamic_cast(ptr.get()) != nullptr; +} value binary_expression::execute(context & ctx) { value left_val = left->execute(ctx); @@ -97,13 +102,16 @@ value binary_expression::execute(context & ctx) { // Array operations if (is_type(left_val) && is_type(right_val)) { if (op.value == "+") { - auto& left_arr = left_val->as_array(); - auto& right_arr = right_val->as_array(); - std::vector result = left_arr; - for (auto & v : right_arr) { - result.push_back(std::move(v)); + auto & left_arr = left_val->as_array(); + auto & right_arr = right_val->as_array(); + auto result = std::make_unique(); + for (const auto & item : left_arr) { + result->val_arr->push_back(item->clone()); } - return std::make_unique(result); + for (const auto & item : right_arr) { + result->val_arr->push_back(item->clone()); + } + return result; } } else if (is_type(right_val)) { auto & arr = right_val->as_array(); @@ -148,4 +156,52 @@ value binary_expression::execute(context & ctx) { throw std::runtime_error("Unknown operator \"" + op.value + "\" between " + left_val->type() + " and " + right_val->type()); } +value filter_expression::execute(context & ctx) { + value input = operand->execute(ctx); + value filter_func = filter->execute(ctx); + + if (is_stmt(filter)) { + auto filter_val = dynamic_cast(filter.get())->value; + + if (filter_val == "to_json") { + // TODO: Implement to_json filter + throw std::runtime_error("to_json filter not implemented"); + } + + if (is_type(input)) { + auto & arr = input->as_array(); + if (filter_val == "list") { + return std::make_unique(input); + } else if (filter_val == "first") { + if (arr.empty()) { + return std::make_unique(); + } + return arr[0]->clone(); + } else if (filter_val == "last") { + if (arr.empty()) { + return std::make_unique(); + } + return arr[arr.size() - 1]->clone(); + } else if (filter_val == "length") { + return std::make_unique(static_cast(arr.size())); + } else { + // TODO: reverse, sort, join, string, unique + throw std::runtime_error("Unknown filter '" + filter_val + "' for array"); + } + + } else if (is_type(input)) { + auto str = input->as_string(); + // TODO + throw std::runtime_error("Unknown filter '" + filter_val + "' for string"); + + } else if (is_type(input) || is_type(input)) { + // TODO + throw std::runtime_error("Unknown filter '" + filter_val + "' for number"); + + } else { + throw std::runtime_error("Filters not supported for type " + input->type()); + } + } +} + } // namespace jinja diff --git a/common/jinja/jinja-vm.h b/common/jinja/jinja-vm.h index 58a71abe24..2c547294a8 100644 --- a/common/jinja/jinja-vm.h +++ b/common/jinja/jinja-vm.h @@ -23,7 +23,7 @@ struct context { struct statement { virtual ~statement() = default; virtual std::string type() const { return "Statement"; } - virtual value execute(context & ctx) = 0; + virtual value execute(context & ctx) { throw std::runtime_error("cannot exec " + type()); }; }; using statement_ptr = std::unique_ptr; @@ -186,44 +186,53 @@ struct identifier : public expression { // Literals struct integer_literal : public expression { - int64_t value; - explicit integer_literal(int64_t value) : value(value) {} + int64_t val; + explicit integer_literal(int64_t val) : val(val) {} std::string type() const override { return "IntegerLiteral"; } + value execute(context & ctx) override { + return std::make_unique(val); + } }; struct float_literal : public expression { - double value; - explicit float_literal(double value) : value(value) {} + double val; + explicit float_literal(double val) : val(val) {} std::string type() const override { return "FloatLiteral"; } + value execute(context & ctx) override { + return std::make_unique(val); + } }; struct string_literal : public expression { - std::string value; - explicit string_literal(const std::string & value) : value(value) {} + std::string val; + explicit string_literal(const std::string & val) : val(val) {} std::string type() const override { return "StringLiteral"; } + value execute(context & ctx) override { + return std::make_unique(val); + } }; struct array_literal : public expression { - statements value; - explicit array_literal(statements && value) : value(std::move(value)) { - for (const auto& item : this->value) chk_type(item); + statements val; + explicit array_literal(statements && val) : val(std::move(val)) { + for (const auto& item : this->val) chk_type(item); } std::string type() const override { return "ArrayLiteral"; } }; struct tuple_literal : public expression { - statements value; - explicit tuple_literal(statements && value) : value(std::move(value)) { - for (const auto& item : this->value) chk_type(item); + statements val; + explicit tuple_literal(statements && val) : val(std::move(val)) { + for (const auto & item : this->val) chk_type(item); } std::string type() const override { return "TupleLiteral"; } }; struct object_literal : public expression { - std::vector> value; - explicit object_literal(std::vector> && value) - : value(std::move(value)) { - for (const auto & pair : this->value) { + std::vector> val; + explicit object_literal(std::vector> && val) + : val(std::move(val)) { + for (const auto & pair : this->val) { chk_type(pair.first); chk_type(pair.second); } @@ -391,4 +400,20 @@ struct ternary_expression : public expression { std::string type() const override { return "Ternary"; } }; +////////////////////// + +struct vm { + context & ctx; + explicit vm(context & ctx) : ctx(ctx) {} + + std::vector execute(program & prog) { + std::vector results; + for (auto & stmt : prog.body) { + value res = stmt->execute(ctx); + results.push_back(std::move(res)); + } + return results; + } +}; + } // namespace jinja diff --git a/tests/test-chat-jinja.cpp b/tests/test-chat-jinja.cpp index ebebba37b1..e0b5d8f8d9 100644 --- a/tests/test-chat-jinja.cpp +++ b/tests/test-chat-jinja.cpp @@ -11,7 +11,9 @@ #include "jinja/jinja-lexer.h" int main(void) { - std::string contents = "{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '' + role + '\\n' + message['content'] | trim + '\\n' }}{% endfor %}{% if add_generation_prompt %}{{'model\\n'}}{% endif %}"; + //std::string contents = "{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '' + role + '\\n' + message['content'] | trim + '\\n' }}{% endfor %}{% if add_generation_prompt %}{{'model\\n'}}{% endif %}"; + + std::string contents = "{{ 'hi' + 'fi' }}"; std::cout << "=== INPUT ===\n" << contents << "\n\n"; @@ -24,11 +26,20 @@ int main(void) { std::cout << "token: type=" << static_cast(tok.t) << " text='" << tok.value << "'\n"; } - jinja::program ast = jinja::parse_from_tokens(tokens); std::cout << "\n=== AST ===\n"; + jinja::program ast = jinja::parse_from_tokens(tokens); for (const auto & stmt : ast.body) { std::cout << "stmt type: " << stmt->type() << "\n"; } + std::cout << "\n=== OUTPUT ===\n"; + jinja::context ctx; + jinja::vm vm(ctx); + auto results = vm.execute(ast); + for (const auto & res : results) { + std::cout << "result type: " << res->type() << "\n"; + std::cout << "result value: " << res->as_string() << "\n"; + } + return 0; }