eval with is_user_input
This commit is contained in:
parent
c08f4ddf01
commit
10835f2720
|
|
@ -5,6 +5,7 @@
|
|||
#include <map>
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
#include <sstream>
|
||||
|
||||
|
||||
namespace jinja {
|
||||
|
|
@ -144,6 +145,8 @@ using value_float = std::unique_ptr<value_float_t>;
|
|||
|
||||
|
||||
struct value_string_t : public value_t {
|
||||
bool is_user_input = false; // may skip parsing special tokens if true
|
||||
|
||||
value_string_t(const std::string & v) { val_str = v; }
|
||||
virtual std::string type() const override { return "String"; }
|
||||
virtual std::string as_string() const override { return val_str; }
|
||||
|
|
@ -192,6 +195,16 @@ struct value_array_t : public value_t {
|
|||
tmp->val_arr = this->val_arr;
|
||||
return tmp;
|
||||
}
|
||||
// Render the array as "[a, b, c]": each element contributes its own
// as_string() text, separated by ", ". An empty array renders as "[]".
virtual std::string as_string() const override {
    std::ostringstream out;
    out << "[";
    const char * separator = "";
    for (const auto & element : *val_arr) {
        out << separator << element->as_string();
        separator = ", "; // every element after the first is preceded by ", "
    }
    out << "]";
    return out.str();
}
|
||||
virtual const func_builtins & get_builtins() const override;
|
||||
};
|
||||
using value_array = std::unique_ptr<value_array_t>;
|
||||
|
|
|
|||
|
|
@ -197,7 +197,7 @@ const func_builtins & value_string_t::get_builtins() const {
|
|||
}},
|
||||
};
|
||||
return builtins;
|
||||
};
|
||||
}
|
||||
|
||||
|
||||
const func_builtins & value_bool_t::get_builtins() const {
|
||||
|
|
|
|||
|
|
@ -8,6 +8,9 @@
|
|||
#include <memory>
|
||||
#include <algorithm>
|
||||
|
||||
// Debug trace helper: prints "jinja-vm: <formatted message>\n" to stdout.
// The printf-style format string is the first variadic argument, so the macro
// also works when no extra arguments are supplied; the previous `msg, ...`
// form expanded to `printf("...", )` in that case, which does not compile.
// Wrapped in do/while(0) so it behaves as a single statement.
#define JJ_DEBUG(...) do { printf("jinja-vm: "); printf(__VA_ARGS__); printf("\n"); } while (0)
//#define JJ_DEBUG(...) do { } while (0) // no-op
|
||||
|
||||
namespace jinja {
|
||||
|
||||
template<typename T>
|
||||
|
|
@ -15,6 +18,17 @@ static bool is_stmt(const statement_ptr & ptr) {
|
|||
return dynamic_cast<const T*>(ptr.get()) != nullptr;
|
||||
}
|
||||
|
||||
// Resolve this identifier against the context's variable table.
// Returns a clone of the stored value when a variable named `val` exists,
// otherwise an undefined value.
value identifier::execute(context & ctx) {
    auto entry = ctx.var.find(val);
    if (entry == ctx.var.end()) {
        JJ_DEBUG("Identifier '%s' not found, returning undefined", val.c_str());
        return mk_val<value_undefined>();
    }
    JJ_DEBUG("Identifier '%s' found", val.c_str());
    return entry->second->clone();
}
|
||||
|
||||
value binary_expression::execute(context & ctx) {
|
||||
value left_val = left->execute(ctx);
|
||||
|
||||
|
|
@ -151,11 +165,11 @@ value filter_expression::execute(context & ctx) {
|
|||
args.args.push_back(input->clone());
|
||||
return it->second(args);
|
||||
}
|
||||
return nullptr;
|
||||
throw std::runtime_error("Unknown (built-in) filter '" + name + "' for type " + input->type());
|
||||
};
|
||||
|
||||
if (is_stmt<identifier>(filter)) {
|
||||
auto filter_val = dynamic_cast<identifier*>(filter.get())->value;
|
||||
auto filter_val = dynamic_cast<identifier*>(filter.get())->val;
|
||||
|
||||
if (filter_val == "to_json") {
|
||||
// TODO: Implement to_json filter
|
||||
|
|
@ -204,7 +218,15 @@ value filter_expression::execute(context & ctx) {
|
|||
}
|
||||
|
||||
// Evaluate the `if` test and, when it is truthy, execute every statement of
// the body in order.
//
// Returns an array value holding one entry per executed body statement; the
// array is empty when the test is falsy.
//
// The leftover unconditional "not implemented" throw was removed: it made the
// implementation below unreachable.
value if_statement::execute(context & ctx) {
    value test_val = test->execute(ctx);
    auto out = mk_val<value_array>();
    if (test_val->as_bool()) {
        for (auto & stmt : body) {
            JJ_DEBUG("Executing if body statement of type %s", stmt->type().c_str());
            out->val_arr->push_back(stmt->execute(ctx));
        }
    }
    // NOTE(review): no else/elif branch is evaluated here - confirm whether
    // if_statement carries an alternate body that should run when the test fails.
    return out;
}
|
||||
|
||||
value for_statement::execute(context & ctx) {
|
||||
|
|
@ -223,4 +245,79 @@ value set_statement::execute(context & ctx) {
|
|||
throw std::runtime_error("set_statement::execute not implemented");
|
||||
}
|
||||
|
||||
// Evaluate a member access (`object.property` or `object[property]`).
//
// Lookup rules, as implemented below:
//  - object value: the property must be a string; the key is looked up among
//    the object's own entries first, then among its built-ins.
//  - array or string value: an integer property indexes into it (an index
//    outside [0, size) yields undefined); a string property names a built-in.
//  - any other value: the property must be a string naming a built-in.
//
// Returns a clone of the member, or the built-in's result, or an undefined
// value when nothing matches.
// Throws std::runtime_error when the property has an unsupported type, or when
// a non-computed property node is not an identifier.
value member_expression::execute(context & ctx) {
    value object = this->object->execute(ctx);

    value property;
    if (this->computed) {
        property = this->property->execute(ctx);
    } else {
        // Non-computed access uses the property node as a literal name. Check
        // the downcast so an unexpected node type raises instead of
        // dereferencing a null pointer.
        auto * prop_id = dynamic_cast<identifier*>(this->property.get());
        if (prop_id == nullptr) {
            throw std::runtime_error("Non-computed member property must be an identifier");
        }
        property = mk_val<value_string>(prop_id->val);
    }

    value val = mk_val<value_undefined>();

    // Call the built-in named `key` with `object` as its only argument and
    // store the result in `val`; leaves `val` untouched when no such built-in
    // exists. The table is bound by const reference because get_builtins()
    // returns one - copying the whole table per access is unnecessary.
    auto apply_builtin = [&object, &val](const std::string & key) {
        const auto & builtins = object->get_builtins();
        auto bit = builtins.find(key);
        if (bit != builtins.end()) {
            func_args args;
            args.args.push_back(object->clone());
            val = bit->second(args);
        }
    };

    if (is_val<value_object>(object)) {
        if (!is_val<value_string>(property)) {
            throw std::runtime_error("Cannot access object with non-string: got " + property->type());
        }
        auto key = property->as_string();
        auto & obj = object->as_object();
        auto it = obj.find(key);
        if (it != obj.end()) {
            val = it->second->clone();
        } else {
            // Own entries take precedence; fall back to built-ins.
            apply_builtin(key);
        }

    } else if (is_val<value_array>(object) || is_val<value_string>(object)) {
        if (is_val<value_int>(property)) {
            int64_t index = property->as_int();
            if (is_val<value_array>(object)) {
                auto & arr = object->as_array();
                if (index >= 0 && index < static_cast<int64_t>(arr.size())) {
                    val = arr[index]->clone();
                }
            } else { // value_string
                auto str = object->as_string();
                if (index >= 0 && index < static_cast<int64_t>(str.size())) {
                    // Indexing a string yields a one-character string.
                    val = mk_val<value_string>(std::string(1, str[index]));
                }
            }
        } else if (is_val<value_string>(property)) {
            apply_builtin(property->as_string());
        } else {
            throw std::runtime_error("Cannot access property with non-string/non-number: got " + property->type());
        }

    } else {
        if (!is_val<value_string>(property)) {
            throw std::runtime_error("Cannot access property with non-string: got " + property->type());
        }
        apply_builtin(property->as_string());
    }

    return val;
}
|
||||
|
||||
} // namespace jinja
|
||||
|
|
|
|||
|
|
@ -171,6 +171,7 @@ struct member_expression : public expression {
|
|||
chk_type<expression>(this->property);
|
||||
}
|
||||
std::string type() const override { return "MemberExpression"; }
|
||||
value execute(context & ctx) override;
|
||||
};
|
||||
|
||||
struct call_expression : public expression {
|
||||
|
|
@ -189,9 +190,10 @@ struct call_expression : public expression {
|
|||
* Represents a user-defined variable or symbol in the template.
|
||||
*/
|
||||
struct identifier : public expression {
|
||||
std::string value;
|
||||
explicit identifier(const std::string & value) : value(value) {}
|
||||
std::string val;
|
||||
explicit identifier(const std::string & val) : val(val) {}
|
||||
std::string type() const override { return "Identifier"; }
|
||||
value execute(context & ctx) override;
|
||||
};
|
||||
|
||||
// Literals
|
||||
|
|
|
|||
|
|
@ -11,9 +11,11 @@
|
|||
#include "jinja/jinja-lexer.h"
|
||||
|
||||
int main(void) {
|
||||
std::string contents = "{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '<start_of_turn>' + role + '\\n' + message['content'] | trim + '<end_of_turn>\\n' }}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\\n'}}{% endif %}";
|
||||
//std::string contents = "{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '<start_of_turn>' + role + '\\n' + message['content'] | trim + '<end_of_turn>\\n' }}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\\n'}}{% endif %}";
|
||||
|
||||
//std::string contents = "{{ ('hi' + 'fi') | upper }}";
|
||||
//std::string contents = "{% if messages[0]['role'] != 'system' %}nice {{ messages[0]['content'] }}{% endif %}";
|
||||
|
||||
std::string contents = "<some_tokens> {{ messages[0]['content'] }} <another_token>";
|
||||
|
||||
std::cout << "=== INPUT ===\n" << contents << "\n\n";
|
||||
|
||||
|
|
@ -34,11 +36,34 @@ int main(void) {
|
|||
|
||||
std::cout << "\n=== OUTPUT ===\n";
|
||||
jinja::context ctx;
|
||||
|
||||
auto make_non_special_string = [](const std::string & s) {
|
||||
jinja::value_string str_val = std::make_unique<jinja::value_string_t>(s);
|
||||
str_val->is_user_input = true;
|
||||
return str_val;
|
||||
};
|
||||
|
||||
jinja::value messages = jinja::mk_val<jinja::value_array>();
|
||||
jinja::value msg1 = jinja::mk_val<jinja::value_object>();
|
||||
(*msg1->val_obj)["role"] = make_non_special_string("user");
|
||||
(*msg1->val_obj)["content"] = make_non_special_string("Hello, how are you?");
|
||||
messages->val_arr->push_back(std::move(msg1));
|
||||
jinja::value msg2 = jinja::mk_val<jinja::value_object>();
|
||||
(*msg2->val_obj)["role"] = make_non_special_string("assistant");
|
||||
(*msg2->val_obj)["content"] = make_non_special_string("I am fine, thank you!");
|
||||
messages->val_arr->push_back(std::move(msg2));
|
||||
|
||||
ctx.var["messages"] = std::move(messages);
|
||||
|
||||
jinja::vm vm(ctx);
|
||||
auto results = vm.execute(ast);
|
||||
for (const auto & res : results) {
|
||||
std::cout << "result type: " << res->type() << "\n";
|
||||
std::cout << "result value: " << res->as_string() << "\n";
|
||||
auto str_ptr = dynamic_cast<jinja::value_string_t*>(res.get());
|
||||
std::string is_user_input = "false";
|
||||
if (str_ptr) {
|
||||
is_user_input = str_ptr->is_user_input ? "true" : "false";
|
||||
}
|
||||
std::cout << "result type: " << res->type() << " | value: " << res->as_string() << " | is_user_input: " << is_user_input << "\n";
|
||||
}
|
||||
|
||||
return 0;
|
||||
|
|
|
|||
Loading…
Reference in New Issue