From 7f17608ea433729e47751d452eb7545768ed45d9 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Sun, 28 Dec 2025 17:46:25 +0100 Subject: [PATCH] use shared_ptr for values --- common/jinja/jinja-value.h | 113 +++++++++++------------------ common/jinja/jinja-vm-builtins.cpp | 28 +++---- common/jinja/jinja-vm.cpp | 85 ++++++++++------------ common/jinja/jinja-vm.h | 40 ++++++++-- tests/test-chat-jinja.cpp | 18 ++--- 5 files changed, 137 insertions(+), 147 deletions(-) diff --git a/common/jinja/jinja-value.h b/common/jinja/jinja-value.h index 2bb600c1b9..6c6f4a30d6 100644 --- a/common/jinja/jinja-value.h +++ b/common/jinja/jinja-value.h @@ -12,7 +12,7 @@ namespace jinja { struct value_t; -using value = std::unique_ptr; +using value = std::shared_ptr; // Helper to check the type of a value @@ -21,7 +21,7 @@ struct extract_pointee { using type = T; }; template -struct extract_pointee> { +struct extract_pointee> { using type = U; }; template @@ -35,9 +35,19 @@ bool is_val(const value_t * ptr) { return dynamic_cast(ptr) != nullptr; } template -std::unique_ptr::type> mk_val(Args&&... args) { +std::shared_ptr::type> mk_val(Args&&... args) { using PointeeType = typename extract_pointee::type; - return std::make_unique(std::forward(args)...); + return std::make_shared(std::forward(args)...); +} +template +const typename extract_pointee::type * cast_val(const value & ptr) { + using PointeeType = typename extract_pointee::type; + return dynamic_cast(ptr.get()); +} +template +typename extract_pointee::type * cast_val(value & ptr) { + using PointeeType = typename extract_pointee::type; + return dynamic_cast(ptr.get()); } template void ensure_val(const value & ptr) { @@ -91,8 +101,8 @@ struct value_t { // my_arr = [my_obj] // my_obj["a"] = 3 // print(my_arr[0]["a"]) # should print 3 - std::shared_ptr> val_arr; - std::shared_ptr> val_obj; + std::vector val_arr; + std::map val_obj; func_handler val_func; @@ -116,10 +126,6 @@ struct value_t { } virtual std::string as_repr() const { return as_string().str(); } - - virtual value clone() const { - return std::make_unique(*this); - } }; @@ -129,10 +135,9 @@ struct value_int_t : public value_t { virtual int64_t as_int() const override { return val_int; } virtual double as_float() const override { return static_cast(val_int); } virtual string as_string() const override { return std::to_string(val_int); } - virtual value clone() const override { return std::make_unique(*this); } virtual const func_builtins & get_builtins() const override; }; -using value_int = std::unique_ptr; +using value_int = std::shared_ptr; struct value_float_t : public value_t { @@ -141,10 +146,9 @@ struct value_float_t : public value_t { virtual double as_float() const override { return val_flt; } virtual int64_t as_int() const override { return static_cast(val_flt); } virtual string as_string() const override { return std::to_string(val_flt); } - virtual value clone() const override { return std::make_unique(*this); } virtual const func_builtins & get_builtins() const override; }; -using value_float = std::unique_ptr; +using value_float = std::shared_ptr; struct value_string_t : public value_t { @@ -160,13 +164,12 @@ struct value_string_t : public value_t { } return ss.str(); } - virtual value clone() const override { return std::make_unique(*this); } virtual const func_builtins & get_builtins() const override; void mark_input() { val_str.mark_input(); } }; -using value_string = std::unique_ptr; +using value_string = std::shared_ptr; struct value_bool_t : public value_t { @@ -174,92 +177,68 @@ struct value_bool_t : public value_t { virtual std::string type() const override { return "Boolean"; } virtual bool as_bool() const override { return val_bool; } virtual string as_string() const override { return std::string(val_bool ? "True" : "False"); } - virtual value clone() const override { return std::make_unique(*this); } virtual const func_builtins & get_builtins() const override; }; -using value_bool = std::unique_ptr; +using value_bool = std::shared_ptr; struct value_array_t : public value_t { - value_array_t() { - val_arr = std::make_shared>(); - } + value_array_t() = default; value_array_t(value & v) { // point to the same underlying data val_arr = v->val_arr; } void push_back(const value & val) { - val_arr->push_back(val->clone()); + val_arr.push_back(val); } virtual std::string type() const override { return "Array"; } - virtual const std::vector & as_array() const override { return *val_arr; } - // clone will also share the underlying data (point to the same vector) - virtual value clone() const override { - auto tmp = std::make_unique(); - tmp->val_arr = this->val_arr; - return tmp; - } + virtual const std::vector & as_array() const override { return val_arr; } virtual string as_string() const override { std::ostringstream ss; ss << "["; - for (size_t i = 0; i < val_arr->size(); i++) { + for (size_t i = 0; i < val_arr.size(); i++) { if (i > 0) ss << ", "; - ss << val_arr->at(i)->as_repr(); + ss << val_arr.at(i)->as_repr(); } ss << "]"; return ss.str(); } virtual bool as_bool() const override { - return !val_arr->empty(); + return !val_arr.empty(); } virtual const func_builtins & get_builtins() const override; }; -using value_array = std::unique_ptr; +using value_array = std::shared_ptr; struct value_object_t : public value_t { - value_object_t() { - val_obj = std::make_shared>(); - } + value_object_t() = default; value_object_t(value & v) { // point to the same underlying data val_obj = v->val_obj; } value_object_t(const std::map & obj) { - val_obj = std::make_shared>(); + val_obj = std::map(); for (const auto & pair : obj) { - (*val_obj)[pair.first] = pair.second->clone(); + val_obj[pair.first] = pair.second; } } void insert(const std::string & key, const value & val) { - (*val_obj)[key] = val->clone(); + val_obj[key] = val; } virtual std::string type() const override { return "Object"; } - virtual const std::map & as_object() const override { return *val_obj; } - // clone will also share the underlying data (point to the same map) - virtual value clone() const override { - auto tmp = std::make_unique(); - tmp->val_obj = this->val_obj; - return tmp; - } + virtual const std::map & as_object() const override { return val_obj; } virtual bool as_bool() const override { - return !val_obj->empty(); + return !val_obj.empty(); } virtual const func_builtins & get_builtins() const override; }; -using value_object = std::unique_ptr; +using value_object = std::shared_ptr; struct value_func_t : public value_t { std::string name; // for debugging value arg0; // bound "this" argument, if any - value_func_t(const value_func_t & other) { - val_func = other.val_func; - name = other.name; - if (other.arg0) { - arg0 = other.arg0->clone(); - } - } value_func_t(const func_handler & func, std::string func_name = "") { val_func = func; name = func_name; @@ -267,14 +246,14 @@ struct value_func_t : public value_t { value_func_t(const func_handler & func, const value & arg_this, std::string func_name = "") { val_func = func; name = func_name; - arg0 = arg_this->clone(); + arg0 = arg_this; } virtual value invoke(const func_args & args) const override { if (arg0) { func_args new_args; - new_args.args.push_back(arg0->clone()); + new_args.args.push_back(arg0); for (const auto & a : args.args) { - new_args.args.push_back(a->clone()); + new_args.args.push_back(a); } return val_func(new_args); } else { @@ -283,9 +262,8 @@ struct value_func_t : public value_t { } virtual std::string type() const override { return "Function"; } virtual std::string as_repr() const override { return type(); } - virtual value clone() const override { return std::make_unique(*this); } }; -using value_func = std::unique_ptr; +using value_func = std::shared_ptr; struct value_null_t : public value_t { @@ -293,9 +271,8 @@ struct value_null_t : public value_t { virtual bool is_null() const override { return true; } virtual bool as_bool() const override { return false; } virtual std::string as_repr() const override { return type(); } - virtual value clone() const override { return std::make_unique(*this); } }; -using value_null = std::unique_ptr; +using value_null = std::shared_ptr; struct value_undefined_t : public value_t { @@ -303,24 +280,18 @@ struct value_undefined_t : public value_t { virtual bool is_undefined() const override { return true; } virtual bool as_bool() const override { return false; } virtual std::string as_repr() const override { return type(); } - virtual value clone() const override { return std::make_unique(*this); } }; -using value_undefined = std::unique_ptr; +using value_undefined = std::shared_ptr; // special value for kwarg struct value_kwarg_t : public value_t { std::string key; value val; - value_kwarg_t(const value_kwarg_t & other) { - key = other.key; - val = other.val->clone(); - } - value_kwarg_t(const std::string & k, const value & v) : key(k), val(v->clone()) {} + value_kwarg_t(const std::string & k, const value & v) : key(k), val(v) {} virtual std::string type() const override { return "KwArg"; } virtual std::string as_repr() const override { return type(); } - virtual value clone() const override { return std::make_unique(*this); } }; -using value_kwarg = std::unique_ptr; +using value_kwarg = std::shared_ptr; const func_builtins & global_builtins(); diff --git a/common/jinja/jinja-vm-builtins.cpp b/common/jinja/jinja-vm-builtins.cpp index feb7ffb5d2..ed601eb9b1 100644 --- a/common/jinja/jinja-vm-builtins.cpp +++ b/common/jinja/jinja-vm-builtins.cpp @@ -55,7 +55,7 @@ static T slice(const T & array, std::optional start = std::nullopt, std } for (int64_t i = start_val; direction * i < direction * stop_val; i += step) { if (i >= 0 && i < len) { - result.push_back(std::move(array[static_cast(i)]->clone())); + result.push_back(array[static_cast(i)]); } } return result; @@ -87,7 +87,7 @@ const func_builtins & global_builtins() { if (!is_val(arg)) { throw raised_exception("namespace() arguments must be kwargs"); } - auto kwarg = dynamic_cast(arg.get()); + auto kwarg = cast_val(arg); out->insert(kwarg->key, kwarg->val); } return out; @@ -265,12 +265,12 @@ const func_builtins & value_string_t::get_builtins() const { std::string token; while ((pos = str.find(delim)) != std::string::npos) { token = str.substr(0, pos); - result->val_arr->push_back(mk_val(token)); + result->push_back(mk_val(token)); str.erase(0, pos + delim.length()); } auto res = mk_val(str); res->val_str.mark_input_based_on(args.args[0]->val_str); - result->val_arr->push_back(std::move(res)); + result->push_back(std::move(res)); return std::move(result); }}, {"replace", [](const func_args & args) -> value { @@ -353,7 +353,7 @@ const func_builtins & value_array_t::get_builtins() const { const auto & arr = args.args[0]->as_array(); auto result = mk_val(); for (const auto& v : arr) { - result->val_arr->push_back(v->clone()); + result->push_back(v); } return result; }}, @@ -363,7 +363,7 @@ const func_builtins & value_array_t::get_builtins() const { if (arr.empty()) { return mk_val(); } - return arr[0]->clone(); + return arr[0]; }}, {"last", [](const func_args & args) -> value { args.ensure_vals(); @@ -371,7 +371,7 @@ const func_builtins & value_array_t::get_builtins() const { if (arr.empty()) { return mk_val(); } - return arr[arr.size() - 1]->clone(); + return arr[arr.size() - 1]; }}, {"length", [](const func_args & args) -> value { args.ensure_vals(); @@ -391,7 +391,7 @@ const func_builtins & value_array_t::get_builtins() const { } auto arr = slice(args.args[0]->as_array(), start, stop, step); auto res = mk_val(); - res->val_arr = std::make_shared>(std::move(arr)); + res->val_arr = std::move(arr); return res; }}, // TODO: reverse, sort, join, string, unique @@ -408,7 +408,7 @@ const func_builtins & value_object_t::get_builtins() const { std::string key = args.args[1]->as_string().str(); auto it = obj.find(key); if (it != obj.end()) { - return it->second->clone(); + return it->second; } else { return mk_val(); } @@ -418,7 +418,7 @@ const func_builtins & value_object_t::get_builtins() const { const auto & obj = args.args[0]->as_object(); auto result = mk_val(); for (const auto & pair : obj) { - result->val_arr->push_back(mk_val(pair.first)); + result->push_back(mk_val(pair.first)); } return result; }}, @@ -427,7 +427,7 @@ const func_builtins & value_object_t::get_builtins() const { const auto & obj = args.args[0]->as_object(); auto result = mk_val(); for (const auto & pair : obj) { - result->val_arr->push_back(pair.second->clone()); + result->push_back(pair.second); } return result; }}, @@ -437,9 +437,9 @@ const func_builtins & value_object_t::get_builtins() const { auto result = mk_val(); for (const auto & pair : obj) { auto item = mk_val(); - item->val_arr->push_back(mk_val(pair.first)); - item->val_arr->push_back(pair.second->clone()); - result->val_arr->push_back(std::move(item)); + item->push_back(mk_val(pair.first)); + item->push_back(pair.second); + result->push_back(std::move(item)); } return result; }}, diff --git a/common/jinja/jinja-vm.cpp b/common/jinja/jinja-vm.cpp index f39321fa00..fea7c75f06 100644 --- a/common/jinja/jinja-vm.cpp +++ b/common/jinja/jinja-vm.cpp @@ -13,16 +13,11 @@ namespace jinja { -template -static bool is_stmt(const statement_ptr & ptr) { - return dynamic_cast(ptr.get()) != nullptr; -} - static value_array exec_statements(const statements & stmts, context & ctx) { auto result = mk_val(); for (const auto & stmt : stmts) { JJ_DEBUG("Executing statement of type %s", stmt->type().c_str()); - result->val_arr->push_back(stmt->execute(ctx)); + result->push_back(stmt->execute(ctx)); } return result; } @@ -32,7 +27,7 @@ value identifier::execute(context & ctx) { auto builtins = global_builtins(); if (it != ctx.var.end()) { JJ_DEBUG("Identifier '%s' found", val.c_str()); - return it->second->clone(); + return it->second; } else if (builtins.find(val) != builtins.end()) { JJ_DEBUG("Identifier '%s' found in builtins", val.c_str()); return mk_val(builtins.at(val), val); @@ -115,10 +110,10 @@ value binary_expression::execute(context & ctx) { auto & right_arr = right_val->as_array(); auto result = mk_val(); for (const auto & item : left_arr) { - result->val_arr->push_back(item->clone()); + result->push_back(item); } for (const auto & item : right_arr) { - result->val_arr->push_back(item->clone()); + result->push_back(item); } return result; } @@ -185,7 +180,7 @@ value filter_expression::execute(context & ctx) { value input = operand->execute(ctx); if (is_stmt(filter)) { - auto filter_val = dynamic_cast(filter.get())->val; + auto filter_val = cast_stmt(filter)->val; if (filter_val == "to_json") { // TODO: Implement to_json filter @@ -215,7 +210,7 @@ value test_expression::execute(context & ctx) { throw std::runtime_error("Invalid test expression"); } - auto test_id = dynamic_cast(test.get())->val; + auto test_id = cast_stmt(test)->val; auto it = builtins.find("test_is_" + test_id); JJ_DEBUG("Test expression %s '%s'", operand->type().c_str(), test_id.c_str()); if (it == builtins.end()) { @@ -252,12 +247,12 @@ value if_statement::execute(context & ctx) { if (test_val->as_bool()) { for (auto & stmt : body) { JJ_DEBUG("IF --> Executing THEN body, current block: %s", stmt->type().c_str()); - out->val_arr->push_back(stmt->execute(ctx)); + out->push_back(stmt->execute(ctx)); } } else { for (auto & stmt : alternate) { JJ_DEBUG("IF --> Executing ELSE body, current block: %s", stmt->type().c_str()); - out->val_arr->push_back(stmt->execute(ctx)); + out->push_back(stmt->execute(ctx)); } } return out; @@ -271,7 +266,7 @@ value for_statement::execute(context & ctx) { if (is_stmt(iterable)) { JJ_DEBUG("%s", "For loop has test expression"); - auto select = dynamic_cast(iterable.get()); + auto select = cast_stmt(iterable); iter_expr = std::move(select->lhs); test_expr = std::move(select->test); } @@ -292,7 +287,7 @@ value for_statement::execute(context & ctx) { } else { auto & arr = iterable_val->as_array(); for (const auto & item : arr) { - items.push_back(item->clone()); + items.push_back(item); } } @@ -306,12 +301,12 @@ value for_statement::execute(context & ctx) { std::function scope_update_fn = [](context &) { /* no-op */}; if (is_stmt(loopvar)) { - auto id = dynamic_cast(loopvar.get())->val; + auto id = cast_stmt(loopvar)->val; scope_update_fn = [id, &items, i](context & ctx) { - ctx.var[id] = items[i]->clone(); + ctx.var[id] = items[i]; }; } else if (is_stmt(loopvar)) { - auto tuple = dynamic_cast(loopvar.get()); + auto tuple = cast_stmt(loopvar); if (!is_val(current)) { throw std::runtime_error("Cannot unpack non-iterable type: " + current->type()); } @@ -325,8 +320,8 @@ value for_statement::execute(context & ctx) { if (!is_stmt(tuple->val[j])) { throw std::runtime_error("Cannot unpack non-identifier type: " + tuple->val[j]->type()); } - auto id = dynamic_cast(tuple->val[j].get())->val; - ctx.var[id] = c_arr[j]->clone(); + auto id = cast_stmt(tuple->val[j])->val; + ctx.var[id] = c_arr[j]; } }; } else { @@ -339,7 +334,7 @@ value for_statement::execute(context & ctx) { continue; } } - filtered_items.push_back(current->clone()); + filtered_items.push_back(current); scope_update_fns.push_back(scope_update_fn); } @@ -356,9 +351,9 @@ value for_statement::execute(context & ctx) { loop_obj->insert("first", mk_val(i == 0)); loop_obj->insert("last", mk_val(i == filtered_items.size() - 1)); loop_obj->insert("length", mk_val(filtered_items.size())); - loop_obj->insert("previtem", i > 0 ? filtered_items[i - 1]->clone() : mk_val()); - loop_obj->insert("nextitem", i < filtered_items.size() - 1 ? filtered_items[i + 1]->clone() : mk_val()); - ctx.var["loop"] = loop_obj->clone(); + loop_obj->insert("previtem", i > 0 ? filtered_items[i - 1] : mk_val()); + loop_obj->insert("nextitem", i < filtered_items.size() - 1 ? filtered_items[i + 1] : mk_val()); + ctx.var["loop"] = loop_obj; scope_update_fns[i](ctx); try { for (auto & stmt : body) { @@ -386,12 +381,12 @@ value set_statement::execute(context & ctx) { auto rhs = val ? val->execute(ctx) : exec_statements(body, ctx); if (is_stmt(assignee)) { - auto var_name = dynamic_cast(assignee.get())->val; + auto var_name = cast_stmt(assignee)->val; JJ_DEBUG("Setting variable '%s' with value type %s", var_name.c_str(), rhs->type().c_str()); - ctx.var[var_name] = rhs->clone(); + ctx.var[var_name] = rhs; } else if (is_stmt(assignee)) { - auto tuple = dynamic_cast(assignee.get()); + auto tuple = cast_stmt(assignee); if (!is_val(rhs)) { throw std::runtime_error("Cannot unpack non-iterable type in set: " + rhs->type()); } @@ -404,27 +399,27 @@ value set_statement::execute(context & ctx) { if (!is_stmt(elem)) { throw std::runtime_error("Cannot unpack to non-identifier in set: " + elem->type()); } - auto var_name = dynamic_cast(elem.get())->val; - ctx.var[var_name] = arr[i]->clone(); + auto var_name = cast_stmt(elem)->val; + ctx.var[var_name] = arr[i]; } } else if (is_stmt(assignee)) { - auto member = dynamic_cast(assignee.get()); + auto member = cast_stmt(assignee); if (member->computed) { throw std::runtime_error("Cannot assign to computed member"); } if (!is_stmt(member->property)) { throw std::runtime_error("Cannot assign to member with non-identifier property"); } - auto prop_name = dynamic_cast(member->property.get())->val; + auto prop_name = cast_stmt(member->property)->val; value object = member->object->execute(ctx); if (!is_val(object)) { throw std::runtime_error("Cannot assign to member of non-object"); } - auto obj_ptr = dynamic_cast(object.get()); + auto obj_ptr = cast_val(object); JJ_DEBUG("Setting object property '%s'", prop_name.c_str()); - obj_ptr->insert(prop_name, rhs->clone()); + obj_ptr->insert(prop_name, rhs); } else { throw std::runtime_error("Invalid LHS inside assignment expression: " + assignee->type()); @@ -433,7 +428,7 @@ value set_statement::execute(context & ctx) { } value macro_statement::execute(context & ctx) { - std::string name = dynamic_cast(this->name.get())->val; + std::string name = cast_stmt(this->name)->val; const func_handler func = [this, &ctx, name](const func_args & args) -> value { JJ_DEBUG("Invoking macro '%s' with %zu arguments", name.c_str(), args.args.size()); context macro_ctx(ctx); // new scope for macro execution @@ -442,9 +437,9 @@ value macro_statement::execute(context & ctx) { size_t param_count = this->args.size(); size_t arg_count = args.args.size(); for (size_t i = 0; i < param_count; ++i) { - std::string param_name = dynamic_cast(this->args[i].get())->val; + std::string param_name = cast_stmt(this->args[i])->val; if (i < arg_count) { - macro_ctx.var[param_name] = args.args[i]->clone(); + macro_ctx.var[param_name] = args.args[i]; } else { macro_ctx.var[param_name] = mk_val(); } @@ -466,7 +461,7 @@ value member_expression::execute(context & ctx) { if (this->computed) { JJ_DEBUG("Member expression, computing property type %s", this->property->type().c_str()); if (is_stmt(this->property)) { - auto s = dynamic_cast(this->property.get()); + auto s = cast_stmt(this->property); value start_val = s->start_expr ? s->start_expr->execute(ctx) : mk_val(); value stop_val = s->stop_expr ? s->stop_expr->execute(ctx) : mk_val(); value step_val = s->step_expr ? s->step_expr->execute(ctx) : mk_val(); @@ -478,15 +473,15 @@ value member_expression::execute(context & ctx) { step_val->as_repr().c_str()); auto slice_func = try_builtin_func("slice", object); func_args args; - args.args.push_back(start_val->clone()); - args.args.push_back(stop_val->clone()); - args.args.push_back(step_val->clone()); + args.args.push_back(start_val); + args.args.push_back(stop_val); + args.args.push_back(step_val); return slice_func->invoke(args); } else { property = this->property->execute(ctx); } } else { - property = mk_val(dynamic_cast(this->property.get())->val); + property = mk_val(cast_stmt(this->property)->val); } JJ_DEBUG("Member expression on object type %s, property type %s", object->type().c_str(), property->type().c_str()); @@ -502,7 +497,7 @@ value member_expression::execute(context & ctx) { auto & obj = object->as_object(); auto it = obj.find(key); if (it != obj.end()) { - val = it->second->clone(); + val = it->second; } else { val = try_builtin_func(key, object, true); } @@ -514,7 +509,7 @@ value member_expression::execute(context & ctx) { if (is_val(object)) { auto & arr = object->as_array(); if (index >= 0 && index < static_cast(arr.size())) { - val = arr[index]->clone(); + val = arr[index]; } } else { // value_string auto str = object->as_string().str(); @@ -554,7 +549,7 @@ value call_expression::execute(context & ctx) { if (!is_val(callee_val)) { throw std::runtime_error("Callee is not a function: got " + callee_val->type()); } - auto * callee_func = dynamic_cast(callee_val.get()); + auto * callee_func = cast_val(callee_val); JJ_DEBUG("Calling function '%s' with %zu arguments", callee_func->name.c_str(), args.args.size()); return callee_func->invoke(args); } @@ -597,7 +592,7 @@ value keyword_argument_expression::execute(context & ctx) { throw std::runtime_error("Keyword argument key must be identifiers"); } - std::string k = dynamic_cast(key.get())->val; + std::string k = cast_stmt(key)->val; JJ_DEBUG("Keyword argument expression key: %s, value: %s", k.c_str(), val->type().c_str()); value v = val->execute(ctx); diff --git a/common/jinja/jinja-vm.h b/common/jinja/jinja-vm.h index a931bc1ea8..3cfc4b81df 100644 --- a/common/jinja/jinja-vm.h +++ b/common/jinja/jinja-vm.h @@ -12,6 +12,33 @@ namespace jinja { +struct statement; +using statement_ptr = std::unique_ptr; +using statements = std::vector; + +// Helpers for dynamic casting and type checking +template +struct extract_pointee_unique { + using type = T; +}; +template +struct extract_pointee_unique> { + using type = U; +}; +template +bool is_stmt(const statement_ptr & ptr) { + return dynamic_cast(ptr.get()) != nullptr; +} +template +T * cast_stmt(statement_ptr & ptr) { + return dynamic_cast(ptr.get()); +} +template +const T * cast_stmt(const statement_ptr & ptr) { + return dynamic_cast(ptr.get()); +} +// End Helpers + struct context { std::map var; @@ -25,7 +52,7 @@ struct context { context(const context & parent) { // inherit variables (for example, when entering a new scope) for (const auto & pair : parent.var) { - var[pair.first] = pair.second->clone(); + var[pair.first] = pair.second; } } }; @@ -39,9 +66,6 @@ struct statement { virtual value execute(context &) { throw std::runtime_error("cannot exec " + type()); } }; -using statement_ptr = std::unique_ptr; -using statements = std::vector; - // Type Checking Utilities template @@ -461,7 +485,7 @@ struct vm { value_array results = mk_val(); for (auto & stmt : prog.body) { value res = stmt->execute(ctx); - results->val_arr->push_back(std::move(res)); + results->push_back(std::move(res)); } return results; } @@ -474,13 +498,13 @@ struct vm { void gather_string_parts_recursive(const value & val, std::vector & parts) { if (is_val(val)) { - const auto & str_val = dynamic_cast(val.get())->val_str; + const auto & str_val = cast_val(val)->val_str; for (const auto & part : str_val.parts) { parts.push_back(part); } } else if (is_val(val)) { - auto items = dynamic_cast(val.get())->val_arr.get(); - for (const auto & item : *items) { + auto items = cast_val(val)->as_array(); + for (const auto & item : items) { gather_string_parts_recursive(item, parts); } } diff --git a/tests/test-chat-jinja.cpp b/tests/test-chat-jinja.cpp index ce17df5b1d..eff9831ff4 100644 --- a/tests/test-chat-jinja.cpp +++ b/tests/test-chat-jinja.cpp @@ -47,15 +47,15 @@ int main(void) { return str_val; }; - jinja::value messages = jinja::mk_val(); - jinja::value msg1 = jinja::mk_val(); - (*msg1->val_obj)["role"] = make_non_special_string("user"); - (*msg1->val_obj)["content"] = make_non_special_string("Hello, how are you?"); - messages->val_arr->push_back(std::move(msg1)); - jinja::value msg2 = jinja::mk_val(); - (*msg2->val_obj)["role"] = make_non_special_string("assistant"); - (*msg2->val_obj)["content"] = make_non_special_string("I am fine, thank you!"); - messages->val_arr->push_back(std::move(msg2)); + jinja::value_array messages = jinja::mk_val(); + jinja::value_object msg1 = jinja::mk_val(); + msg1->insert("role", make_non_special_string("user")); + msg1->insert("content", make_non_special_string("Hello, how are you?")); + messages->push_back(std::move(msg1)); + jinja::value_object msg2 = jinja::mk_val(); + msg2->insert("role", make_non_special_string("assistant")); + msg2->insert("content", make_non_special_string("I am fine, thank you!")); + messages->push_back(std::move(msg2)); ctx.var["messages"] = std::move(messages);