From 9a8a45ff3bb51eeed117b7305264833758039849 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Sun, 28 Dec 2025 21:32:55 +0100 Subject: [PATCH] mostly works --- common/jinja/jinja-utils.h | 26 ++++++ common/jinja/jinja-value.h | 2 +- common/jinja/jinja-vm-builtins.cpp | 69 ++++++++++++++ common/jinja/jinja-vm.cpp | 139 +++++++++++++++++------------ common/jinja/jinja-vm.h | 31 +++++-- tests/test-chat-jinja.cpp | 2 + 6 files changed, 206 insertions(+), 63 deletions(-) create mode 100644 common/jinja/jinja-utils.h diff --git a/common/jinja/jinja-utils.h b/common/jinja/jinja-utils.h new file mode 100644 index 0000000000..a7d3bea5a8 --- /dev/null +++ b/common/jinja/jinja-utils.h @@ -0,0 +1,26 @@ +#pragma once + +#include +#include +#include + +namespace jinja { + +static void string_replace_all(std::string & s, const std::string & search, const std::string & replace) { + if (search.empty()) { + return; + } + std::string builder; + builder.reserve(s.length()); + size_t pos = 0; + size_t last_pos = 0; + while ((pos = s.find(search, last_pos)) != std::string::npos) { + builder.append(s, last_pos, pos - last_pos); + builder.append(replace); + last_pos = pos + search.length(); + } + builder.append(s, last_pos, std::string::npos); + s = std::move(builder); +} + +} // namespace jinja diff --git a/common/jinja/jinja-value.h b/common/jinja/jinja-value.h index 94c638eab2..a5eafda2dd 100644 --- a/common/jinja/jinja-value.h +++ b/common/jinja/jinja-value.h @@ -65,7 +65,7 @@ struct func_args { throw std::runtime_error("Expected " + std::to_string(count) + " arguments, got " + std::to_string(args.size())); } } - // TODO: add support for get kwargs + value get_kwarg(const std::string & key) const; // utility functions template void ensure_vals() const { ensure_count(1); diff --git a/common/jinja/jinja-vm-builtins.cpp b/common/jinja/jinja-vm-builtins.cpp index ecc2cfea52..39ae955e79 100644 --- a/common/jinja/jinja-vm-builtins.cpp +++ b/common/jinja/jinja-vm-builtins.cpp @@ -67,12 +67,14 @@ template static value test_type_fn(const func_args & args) { args.ensure_count(1); bool is_type = is_val(args.args[0]); + JJ_DEBUG("test_type_fn: type=%s result=%d", typeid(T).name(), is_type ? 1 : 0); return mk_val(is_type); } template static value test_type_fn(const func_args & args) { args.ensure_count(1); bool is_type = is_val(args.args[0]) || is_val(args.args[0]); + JJ_DEBUG("test_type_fn: type=%s or %s result=%d", typeid(T).name(), typeid(U).name(), is_type ? 1 : 0); return mk_val(is_type); } @@ -95,6 +97,20 @@ const func_builtins & global_builtins() { } return out; }}, + {"strftime_now", [](const func_args & args) -> value { + args.ensure_count(1); + args.ensure_vals(); + std::string format = args.args[0]->as_string().str(); + // get current time + // TODO: make sure this is the same behavior as Python's strftime + std::time_t t = std::time(nullptr); + char buf[100]; + if (std::strftime(buf, sizeof(buf), format.c_str(), std::localtime(&t))) { + return mk_val(std::string(buf)); + } else { + throw raised_exception("strftime_now: failed to format time"); + } + }}, // tests {"test_is_boolean", test_type_fn}, @@ -296,6 +312,25 @@ const func_builtins & value_string_t::get_builtins() const { args.ensure_vals(); return mk_val(args.args[0]->as_string()); }}, + {"default", [](const func_args & args) -> value { + value input = args.args[0]; + if (!is_val(input)) { + throw raised_exception("default() first argument must be a string"); + } + value default_val = mk_val(""); + if (args.args.size() > 1 && !args.args[1]->is_undefined()) { + default_val = args.args[1]; + } + value boolean_val = mk_val(false); + if (args.args.size() > 1) { + boolean_val = args.args[1]; + } + if (input->is_undefined() || (boolean_val->as_bool() && !input->as_bool())) { + return default_val; + } else { + return input; + } + }}, {"indent", [](const func_args &) -> value { throw std::runtime_error("indent builtin not implemented"); }}, @@ -380,6 +415,40 @@ const func_builtins & value_array_t::get_builtins() const { res->val_arr = std::move(arr); return res; }}, + {"selectattr", [](const func_args & args) -> value { + value input = args.args[0]; + if (!is_val(input)) { + throw raised_exception("selectattr() first argument must be an array, got " + input->type()); + } + std::vector selected; + for (size_t i = 1; i < args.args.size(); ++i) { + const auto & v = args.args[i]; + if (!is_val(v)) { + throw raised_exception("selectattr() attributes must be strings, got " + v->type()); + } + JJ_DEBUG("selectattr: selecting attribute '%s'", v->as_string().str().c_str()); + selected.push_back(v->as_string().str()); + } + auto result = mk_val(); + for (const auto & item : input->as_array()) { + if (!is_val(item)) { + continue; + } + const auto & obj = item->as_object(); + bool match = true; + for (const auto & attr : selected) { + auto it = obj.find(attr); + if (it == obj.end() || it->second->is_undefined() || (is_val(it->second) && !it->second->as_bool())) { + match = false; + break; + } + } + if (match) { + result->push_back(item); + } + } + return result; + }}, // TODO: reverse, sort, join, string, unique }; return builtins; diff --git a/common/jinja/jinja-vm.cpp b/common/jinja/jinja-vm.cpp index 276c79156c..844dcdef7d 100644 --- a/common/jinja/jinja-vm.cpp +++ b/common/jinja/jinja-vm.cpp @@ -2,6 +2,7 @@ #include "jinja-vm.h" #include "jinja-parser.h" #include "jinja-value.h" +#include "jinja-utils.h" #include #include @@ -14,6 +15,22 @@ bool g_jinja_debug = true; namespace jinja { +// func_args method implementations + +value func_args::get_kwarg(const std::string & key) const { + for (const auto & arg : args) { + if (is_val(arg)) { + auto * kwarg = cast_val(arg); + if (kwarg->key == key) { + return kwarg->val; + } + } + } + return mk_val(); +} + +// utils + static value_array exec_statements(const statements & stmts, context & ctx) { auto result = mk_val(); for (const auto & stmt : stmts) { @@ -23,23 +40,6 @@ static value_array exec_statements(const statements & stmts, context & ctx) { return result; } -static void string_replace_all(std::string & s, const std::string & search, const std::string & replace) { - if (search.empty()) { - return; - } - std::string builder; - builder.reserve(s.length()); - size_t pos = 0; - size_t last_pos = 0; - while ((pos = s.find(search, last_pos)) != std::string::npos) { - builder.append(s, last_pos, pos - last_pos); - builder.append(replace); - last_pos = pos + search.length(); - } - builder.append(s, last_pos, std::string::npos); - s = std::move(builder); -} - // execute with error handling value statement::execute(context & ctx) { try { @@ -138,6 +138,7 @@ value binary_expression::execute_impl(context & ctx) { return mk_val(static_cast(res)); } } else if (op.value == "/") { + JJ_DEBUG("Division operation: %f / %f", a, b); return mk_val(a / b); } else if (op.value == "%") { double rem = std::fmod(a, b); @@ -149,12 +150,16 @@ value binary_expression::execute_impl(context & ctx) { return mk_val(static_cast(rem)); } } else if (op.value == "<") { + JJ_DEBUG("Comparison operation: %f < %f is %d", a, b, a < b); return mk_val(a < b); } else if (op.value == ">") { + JJ_DEBUG("Comparison operation: %f > %f is %d", a, b, a > b); return mk_val(a > b); } else if (op.value == ">=") { + JJ_DEBUG("Comparison operation: %f >= %f is %d", a, b, a >= b); return mk_val(a >= b); } else if (op.value == "<=") { + JJ_DEBUG("Comparison operation: %f <= %f is %d", a, b, a <= b); return mk_val(a <= b); } } @@ -235,24 +240,33 @@ static value try_builtin_func(const std::string & name, const value & input, boo value filter_expression::execute_impl(context & ctx) { value input = operand->execute(ctx); - if (is_stmt(filter)) { - auto filter_val = cast_stmt(filter)->val; + JJ_DEBUG("Applying filter to %s", input->type().c_str()); - if (filter_val == "to_json") { + if (is_stmt(filter)) { + auto filter_id = cast_stmt(filter)->val; + + if (filter_id == "to_json") { // TODO: Implement to_json filter throw std::runtime_error("to_json filter not implemented"); } - if (filter_val == "trim") { - filter_val = "strip"; // alias + if (filter_id == "trim") { + filter_id = "strip"; // alias } - JJ_DEBUG("Applying filter '%s' to %s", filter_val.c_str(), input->type().c_str()); - return try_builtin_func(filter_val, input)->invoke({}); + JJ_DEBUG("Applying filter '%s' to %s", filter_id.c_str(), input->type().c_str()); + return try_builtin_func(filter_id, input)->invoke({}); } else if (is_stmt(filter)) { - // TODO - // value filter_func = filter->execute(ctx); - throw std::runtime_error("Filter with arguments not implemented"); + auto call = cast_stmt(filter); + auto filter_id = cast_stmt(call->callee)->val; + + JJ_DEBUG("Applying filter '%s' with arguments to %s", filter_id.c_str(), input->type().c_str()); + func_args args; + for (const auto & arg_expr : call->args) { + args.args.push_back(arg_expr->execute(ctx)); + } + + return try_builtin_func(filter_id, input)->invoke(args); } else { throw std::runtime_error("Invalid filter expression"); @@ -268,7 +282,7 @@ value test_expression::execute_impl(context & ctx) { auto test_id = cast_stmt(test)->val; auto it = builtins.find("test_is_" + test_id); - JJ_DEBUG("Test expression %s '%s' %s", operand->type().c_str(), test_id.c_str(), negate ? "(negate)" : ""); + JJ_DEBUG("Test expression %s '%s' %s (using function 'test_is_%s')", operand->type().c_str(), test_id.c_str(), negate ? "(negate)" : "", test_id.c_str()); if (it == builtins.end()) { throw std::runtime_error("Unknown test '" + test_id + "'"); } @@ -336,6 +350,12 @@ value for_statement::execute_impl(context & ctx) { JJ_DEBUG("Executing for statement, iterable type: %s", iter_expr->type().c_str()); value iterable_val = iter_expr->execute(scope); + + if (iterable_val->is_undefined()) { + JJ_DEBUG("%s", "For loop iterable is undefined, skipping loop"); + iterable_val = mk_val(); + } + if (!is_val(iterable_val) && !is_val(iterable_val)) { throw std::runtime_error("Expected iterable or object type in for loop: got " + iterable_val->type()); } @@ -555,7 +575,10 @@ value member_expression::execute_impl(context & ctx) { value val = mk_val(); - if (is_val(object)) { + if (is_val(object)) { + JJ_DEBUG("%s", "Accessing property on undefined object, returning undefined"); + return val; + } else if (is_val(object)) { if (!is_val(property)) { throw std::runtime_error("Cannot access object with non-string: got " + property->type()); } @@ -623,35 +646,39 @@ value call_expression::execute_impl(context & ctx) { // compare operator for value_t bool value_compare(const value & a, const value & b) { - JJ_DEBUG("Comparing types: %s and %s", a->type().c_str(), b->type().c_str()); - // compare numeric types - if ((is_val(a) || is_val(a)) && - (is_val(b) || is_val(b))){ - try { - return a->as_float() == b->as_float(); - } catch (...) {} - } - // compare string and number - // TODO: not sure if this is the right behavior - if ((is_val(b) && (is_val(a) || is_val(a))) || - (is_val(a) && (is_val(b) || is_val(b)))) { - try { + auto cmp = [&]() { + // compare numeric types + if ((is_val(a) || is_val(a)) && + (is_val(b) || is_val(b))){ + try { + return a->as_float() == b->as_float(); + } catch (...) {} + } + // compare string and number + // TODO: not sure if this is the right behavior + if ((is_val(b) && (is_val(a) || is_val(a))) || + (is_val(a) && (is_val(b) || is_val(b)))) { + try { + return a->as_string().str() == b->as_string().str(); + } catch (...) {} + } + // compare boolean simple + if (is_val(a) && is_val(b)) { + return a->as_bool() == b->as_bool(); + } + // compare string simple + if (is_val(a) && is_val(b)) { return a->as_string().str() == b->as_string().str(); - } catch (...) {} - } - // compare boolean simple - if (is_val(a) && is_val(b)) { - return a->as_bool() == b->as_bool(); - } - // compare string simple - if (is_val(a) && is_val(b)) { - return a->as_string().str() == b->as_string().str(); - } - // compare by type - if (a->type() != b->type()) { + } + // compare by type + if (a->type() != b->type()) { + return false; + } return false; - } - return false; + }; + auto result = cmp(); + JJ_DEBUG("Comparing types: %s and %s result=%d", a->type().c_str(), b->type().c_str(), result); + return result; } value keyword_argument_expression::execute_impl(context & ctx) { diff --git a/common/jinja/jinja-vm.h b/common/jinja/jinja-vm.h index 647da3a72b..5172969a9d 100644 --- a/common/jinja/jinja-vm.h +++ b/common/jinja/jinja-vm.h @@ -71,7 +71,7 @@ struct statement { // execute_impl must be overridden by derived classes virtual value execute_impl(context &) { throw std::runtime_error("cannot exec " + type()); } // execute is the public method to execute a statement with error handling - virtual value execute(context &); + value execute(context &); }; // Type Checking Utilities @@ -288,13 +288,17 @@ struct array_literal : public expression { for (const auto& item : this->val) chk_type(item); } std::string type() const override { return "ArrayLiteral"; } + value execute_impl(context & ctx) override { + auto arr = mk_val(); + for (const auto & item_stmt : val) { + arr->push_back(item_stmt->execute(ctx)); + } + return arr; + } }; -struct tuple_literal : public expression { - statements val; - explicit tuple_literal(statements && val) : val(std::move(val)) { - for (const auto & item : this->val) chk_type(item); - } +struct tuple_literal : public array_literal { + explicit tuple_literal(statements && val) : array_literal(std::move(val)) {} std::string type() const override { return "TupleLiteral"; } }; @@ -376,6 +380,13 @@ struct select_expression : public expression { chk_type(this->test); } std::string type() const override { return "SelectExpression"; } + value execute_impl(context & ctx) override { + auto predicate = test->execute_impl(ctx); + if (!predicate->as_bool()) { + return mk_val(); + } + return lhs->execute_impl(ctx); + } }; /** @@ -474,6 +485,14 @@ struct ternary_expression : public expression { chk_type(this->false_expr); } std::string type() const override { return "Ternary"; } + value execute_impl(context & ctx) override { + value cond_val = condition->execute(ctx); + if (cond_val->as_bool()) { + return true_expr->execute(ctx); + } else { + return false_expr->execute(ctx); + } + } }; struct raised_exception : public std::exception { diff --git a/tests/test-chat-jinja.cpp b/tests/test-chat-jinja.cpp index 0bf15bed91..64777a3495 100644 --- a/tests/test-chat-jinja.cpp +++ b/tests/test-chat-jinja.cpp @@ -83,6 +83,8 @@ void run(std::string contents) { messages->push_back(std::move(msg2)); ctx.var["messages"] = std::move(messages); + ctx.var["eos_token"] = jinja::mk_val(""); + // ctx.var["tools"] = jinja::mk_val(); jinja::vm vm(ctx); const jinja::value results = vm.execute(ast);