diff --git a/common/jinja/jinja-value.h b/common/jinja/jinja-value.h index 9a6acae7e2..a5362169c4 100644 --- a/common/jinja/jinja-value.h +++ b/common/jinja/jinja-value.h @@ -5,6 +5,7 @@ #include #include #include +#include namespace jinja { @@ -144,6 +145,8 @@ using value_float = std::unique_ptr; struct value_string_t : public value_t { + bool is_user_input = false; // may skip parsing special tokens if true + value_string_t(const std::string & v) { val_str = v; } virtual std::string type() const override { return "String"; } virtual std::string as_string() const override { return val_str; } @@ -192,6 +195,16 @@ struct value_array_t : public value_t { tmp->val_arr = this->val_arr; return tmp; } + virtual std::string as_string() const override { + std::ostringstream ss; + ss << "["; + for (size_t i = 0; i < val_arr->size(); i++) { + if (i > 0) ss << ", "; + ss << val_arr->at(i)->as_string(); + } + ss << "]"; + return ss.str(); + } virtual const func_builtins & get_builtins() const override; }; using value_array = std::unique_ptr; diff --git a/common/jinja/jinja-vm-builtins.cpp b/common/jinja/jinja-vm-builtins.cpp index cc2b2b39a0..860f67b629 100644 --- a/common/jinja/jinja-vm-builtins.cpp +++ b/common/jinja/jinja-vm-builtins.cpp @@ -197,7 +197,7 @@ const func_builtins & value_string_t::get_builtins() const { }}, }; return builtins; -}; +} const func_builtins & value_bool_t::get_builtins() const { diff --git a/common/jinja/jinja-vm.cpp b/common/jinja/jinja-vm.cpp index 3a28977e6b..73ad5bae0d 100644 --- a/common/jinja/jinja-vm.cpp +++ b/common/jinja/jinja-vm.cpp @@ -8,6 +8,9 @@ #include #include +#define JJ_DEBUG(msg, ...) printf("jinja-vm: " msg "\n", __VA_ARGS__) +//#define JJ_DEBUG(msg, ...) // no-op + namespace jinja { template @@ -15,6 +18,17 @@ static bool is_stmt(const statement_ptr & ptr) { return dynamic_cast(ptr.get()) != nullptr; } +value identifier::execute(context & ctx) { + auto it = ctx.var.find(val); + if (it != ctx.var.end()) { + JJ_DEBUG("Identifier '%s' found", val.c_str()); + return it->second->clone(); + } else { + JJ_DEBUG("Identifier '%s' not found, returning undefined", val.c_str()); + return mk_val(); + } +} + value binary_expression::execute(context & ctx) { value left_val = left->execute(ctx); @@ -151,11 +165,11 @@ value filter_expression::execute(context & ctx) { args.args.push_back(input->clone()); return it->second(args); } - return nullptr; + throw std::runtime_error("Unknown (built-in) filter '" + name + "' for type " + input->type()); }; if (is_stmt(filter)) { - auto filter_val = dynamic_cast(filter.get())->value; + auto filter_val = dynamic_cast(filter.get())->val; if (filter_val == "to_json") { // TODO: Implement to_json filter @@ -204,7 +218,15 @@ value filter_expression::execute(context & ctx) { } value if_statement::execute(context & ctx) { - throw std::runtime_error("if_statement::execute not implemented"); + value test_val = test->execute(ctx); + auto out = mk_val(); + if (test_val->as_bool()) { + for (auto & stmt : body) { + JJ_DEBUG("Executing if body statement of type %s", stmt->type().c_str()); + out->val_arr->push_back(stmt->execute(ctx)); + } + } + return out; } value for_statement::execute(context & ctx) { @@ -223,4 +245,79 @@ value set_statement::execute(context & ctx) { throw std::runtime_error("set_statement::execute not implemented"); } +value member_expression::execute(context & ctx) { + value object = this->object->execute(ctx); + + value property; + if (this->computed) { + property = this->property->execute(ctx); + } else { + property = mk_val(dynamic_cast(this->property.get())->val); + } + + value val = mk_val(); + + if (is_val(object)) { + if (!is_val(property)) { + throw std::runtime_error("Cannot access object with non-string: got " + property->type()); + } + auto key = property->as_string(); + auto & obj = object->as_object(); + auto it = obj.find(key); + if (it != obj.end()) { + val = it->second->clone(); + } else { + auto builtins = object->get_builtins(); + auto bit = builtins.find(key); + if (bit != builtins.end()) { + func_args args; + args.args.push_back(object->clone()); + val = bit->second(args); + } + } + + } else if (is_val(object) || is_val(object)) { + if (is_val(property)) { + int64_t index = property->as_int(); + if (is_val(object)) { + auto & arr = object->as_array(); + if (index >= 0 && index < static_cast(arr.size())) { + val = arr[index]->clone(); + } + } else { // value_string + auto str = object->as_string(); + if (index >= 0 && index < static_cast(str.size())) { + val = mk_val(std::string(1, str[index])); + } + } + } else if (is_val(property)) { + auto key = property->as_string(); + auto builtins = object->get_builtins(); + auto bit = builtins.find(key); + if (bit != builtins.end()) { + func_args args; + args.args.push_back(object->clone()); + val = bit->second(args); + } + } else { + throw std::runtime_error("Cannot access property with non-string/non-number: got " + property->type()); + } + + } else { + if (!is_val(property)) { + throw std::runtime_error("Cannot access property with non-string: got " + property->type()); + } + auto key = property->as_string(); + auto builtins = object->get_builtins(); + auto bit = builtins.find(key); + if (bit != builtins.end()) { + func_args args; + args.args.push_back(object->clone()); + val = bit->second(args); + } + } + + return val; +} + } // namespace jinja diff --git a/common/jinja/jinja-vm.h b/common/jinja/jinja-vm.h index 5b620026a2..d2e763b13b 100644 --- a/common/jinja/jinja-vm.h +++ b/common/jinja/jinja-vm.h @@ -171,6 +171,7 @@ struct member_expression : public expression { chk_type(this->property); } std::string type() const override { return "MemberExpression"; } + value execute(context & ctx) override; }; struct call_expression : public expression { @@ -189,9 +190,10 @@ struct call_expression : public expression { * Represents a user-defined variable or symbol in the template. */ struct identifier : public expression { - std::string value; - explicit identifier(const std::string & value) : value(value) {} + std::string val; + explicit identifier(const std::string & val) : val(val) {} std::string type() const override { return "Identifier"; } + value execute(context & ctx) override; }; // Literals diff --git a/tests/test-chat-jinja.cpp b/tests/test-chat-jinja.cpp index e923da4481..63048841c3 100644 --- a/tests/test-chat-jinja.cpp +++ b/tests/test-chat-jinja.cpp @@ -11,9 +11,11 @@ #include "jinja/jinja-lexer.h" int main(void) { - std::string contents = "{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '' + role + '\\n' + message['content'] | trim + '\\n' }}{% endfor %}{% if add_generation_prompt %}{{'model\\n'}}{% endif %}"; + //std::string contents = "{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '' + role + '\\n' + message['content'] | trim + '\\n' }}{% endfor %}{% if add_generation_prompt %}{{'model\\n'}}{% endif %}"; - //std::string contents = "{{ ('hi' + 'fi') | upper }}"; + //std::string contents = "{% if messages[0]['role'] != 'system' %}nice {{ messages[0]['content'] }}{% endif %}"; + + std::string contents = " {{ messages[0]['content'] }} "; std::cout << "=== INPUT ===\n" << contents << "\n\n"; @@ -34,11 +36,34 @@ int main(void) { std::cout << "\n=== OUTPUT ===\n"; jinja::context ctx; + + auto make_non_special_string = [](const std::string & s) { + jinja::value_string str_val = std::make_unique(s); + str_val->is_user_input = true; + return str_val; + }; + + jinja::value messages = jinja::mk_val(); + jinja::value msg1 = jinja::mk_val(); + (*msg1->val_obj)["role"] = make_non_special_string("user"); + (*msg1->val_obj)["content"] = make_non_special_string("Hello, how are you?"); + messages->val_arr->push_back(std::move(msg1)); + jinja::value msg2 = jinja::mk_val(); + (*msg2->val_obj)["role"] = make_non_special_string("assistant"); + (*msg2->val_obj)["content"] = make_non_special_string("I am fine, thank you!"); + messages->val_arr->push_back(std::move(msg2)); + + ctx.var["messages"] = std::move(messages); + jinja::vm vm(ctx); auto results = vm.execute(ast); for (const auto & res : results) { - std::cout << "result type: " << res->type() << "\n"; - std::cout << "result value: " << res->as_string() << "\n"; + auto str_ptr = dynamic_cast(res.get()); + std::string is_user_input = "false"; + if (str_ptr) { + is_user_input = str_ptr->is_user_input ? "true" : "false"; + } + std::cout << "result type: " << res->type() << " | value: " << res->as_string() << " | is_user_input: " << is_user_input << "\n"; } return 0;