diff --git a/common/jinja/jinja-vm.cpp b/common/jinja/jinja-vm.cpp index 844dcdef7d..8ec8e742f0 100644 --- a/common/jinja/jinja-vm.cpp +++ b/common/jinja/jinja-vm.cpp @@ -11,10 +11,14 @@ #define FILENAME "jinja-vm" -bool g_jinja_debug = true; +bool g_jinja_debug = false; namespace jinja { +void enable_debug(bool enable) { + g_jinja_debug = enable; +} + // func_args method implementations value func_args::get_kwarg(const std::string & key) const { @@ -273,6 +277,17 @@ value filter_expression::execute_impl(context & ctx) { } } +value filter_statement::execute_impl(context & ctx) { + // eval body as string, then apply filter + auto body_val = exec_statements(body, ctx); + value_string parts = mk_val(); + gather_string_parts_recursive(body_val, parts); + + JJ_DEBUG("FilterStatement: applying filter to body string of length %zu", parts->val_str.length()); + filter_expression filter_expr(std::move(parts), std::move(filter)); + return filter_expr.execute(ctx); +} + value test_expression::execute_impl(context & ctx) { // NOTE: "value is something" translates to function call "test_is_something(value)" const auto & builtins = global_builtins(); diff --git a/common/jinja/jinja-vm.h b/common/jinja/jinja-vm.h index 5172969a9d..d67bc2d5c1 100644 --- a/common/jinja/jinja-vm.h +++ b/common/jinja/jinja-vm.h @@ -42,6 +42,10 @@ const T * cast_stmt(const statement_ptr & ptr) { } // End Helpers + +// not thread-safe +void enable_debug(bool enable); + struct context { std::map var; std::string source; // for debugging @@ -260,7 +264,7 @@ struct integer_literal : public expression { explicit integer_literal(int64_t val) : val(val) {} std::string type() const override { return "IntegerLiteral"; } value execute_impl(context &) override { - return std::make_unique(val); + return mk_val(val); } }; @@ -269,7 +273,7 @@ struct float_literal : public expression { explicit float_literal(double val) : val(val) {} std::string type() const override { return "FloatLiteral"; } value execute_impl(context &) override { - return std::make_unique(val); + return mk_val(val); } }; @@ -278,7 +282,7 @@ struct string_literal : public expression { explicit string_literal(const std::string & val) : val(val) {} std::string type() const override { return "StringLiteral"; } value execute_impl(context &) override { - return std::make_unique(val); + return mk_val(val); } }; @@ -341,7 +345,10 @@ struct binary_expression : public expression { * Operator precedence: https://github.com/pallets/jinja/issues/379#issuecomment-168076202 */ struct filter_expression : public expression { + // either an expression or a value is allowed statement_ptr operand; + value_string val; // will be set by filter_statement + statement_ptr filter; filter_expression(statement_ptr && operand, statement_ptr && filter) @@ -349,6 +356,12 @@ struct filter_expression : public expression { chk_type(this->operand); chk_type(this->filter); } + + filter_expression(value_string && val, statement_ptr && filter) + : val(std::move(val)), filter(std::move(filter)) { + chk_type(this->filter); + } + std::string type() const override { return "FilterExpression"; } value execute_impl(context & ctx) override; }; @@ -362,6 +375,7 @@ struct filter_statement : public statement { chk_type(this->filter); } std::string type() const override { return "FilterStatement"; } + value execute_impl(context & ctx) override; }; /** @@ -505,6 +519,26 @@ struct raised_exception : public std::exception { ////////////////////// +static void gather_string_parts_recursive(const value & val, value_string & parts) { + if (is_val(val)) { + const auto & str_val = cast_val(val)->val_str; + parts->val_str.append(str_val); + } else if (is_val(val)) { + auto items = cast_val(val)->as_array(); + for (const auto & item : items) { + gather_string_parts_recursive(item, parts); + } + } +} + +static std::string render_string_parts(const value_string & parts) { + std::ostringstream oss; + for (const auto & part : parts->val_str.parts) { + oss << part.val; + } + return oss.str(); +} + struct vm { context & ctx; explicit vm(context & ctx) : ctx(ctx) {} @@ -518,25 +552,11 @@ struct vm { return results; } - std::vector gather_string_parts(const value & val) { - std::vector parts; + value_string gather_string_parts(const value & val) { + value_string parts = mk_val(); gather_string_parts_recursive(val, parts); return parts; } - - void gather_string_parts_recursive(const value & val, std::vector & parts) { - if (is_val(val)) { - const auto & str_val = cast_val(val)->val_str; - for (const auto & part : str_val.parts) { - parts.push_back(part); - } - } else if (is_val(val)) { - auto items = cast_val(val)->as_array(); - for (const auto & item : items) { - gather_string_parts_recursive(item, parts); - } - } - } }; } // namespace jinja diff --git a/tests/test-chat-jinja.cpp b/tests/test-chat-jinja.cpp index 64777a3495..1f9dedb1e4 100644 --- a/tests/test-chat-jinja.cpp +++ b/tests/test-chat-jinja.cpp @@ -45,7 +45,7 @@ int main(void) { void run(std::string contents) { - std::cout << "=== INPUT ===\n" << contents << "\n\n"; + // jinja::enable_debug(true); jinja::lexer lexer; jinja::preprocess_options options; @@ -53,13 +53,13 @@ void run(std::string contents) { options.lstrip_blocks = false; auto lexer_res = lexer.tokenize(contents, options); for (const auto & tok : lexer_res.tokens) { - std::cout << "token: type=" << static_cast(tok.t) << " text='" << tok.value << "' pos=" << tok.pos << "\n"; + //std::cout << "token: type=" << static_cast(tok.t) << " text='" << tok.value << "' pos=" << tok.pos << "\n"; } std::cout << "\n=== AST ===\n"; jinja::program ast = jinja::parse_from_tokens(lexer_res.tokens); for (const auto & stmt : ast.body) { - std::cout << "stmt type: " << stmt->type() << "\n"; + //std::cout << "stmt type: " << stmt->type() << "\n"; } std::cout << "\n=== RUN ===\n"; @@ -91,7 +91,7 @@ void run(std::string contents) { auto parts = vm.gather_string_parts(results); std::cout << "\n=== RESULTS ===\n"; - for (const auto & part : parts) { + for (const auto & part : parts.get()->val_str.parts) { std::cout << (part.is_input ? "DATA" : "TMPL") << ": " << part.val << "\n"; } }