From 4479c382ce611eec159bd3d529d854fa1c5df864 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Tue, 30 Dec 2025 17:26:23 +0100 Subject: [PATCH] demo: type inference --- common/jinja/jinja-type-infer.h | 38 +++++++++++++++++++ common/jinja/jinja-value.cpp | 2 +- common/jinja/jinja-value.h | 28 ++++++++++++++ common/jinja/jinja-vm.cpp | 35 ++++++++++------- common/jinja/jinja-vm.h | 67 +++++++++++++++++++++++++++++---- tests/test-chat-jinja.cpp | 19 ++++++++++ 6 files changed, 167 insertions(+), 22 deletions(-) create mode 100644 common/jinja/jinja-type-infer.h diff --git a/common/jinja/jinja-type-infer.h b/common/jinja/jinja-type-infer.h new file mode 100644 index 0000000000..3f7508787f --- /dev/null +++ b/common/jinja/jinja-type-infer.h @@ -0,0 +1,38 @@ +#pragma once + +#include +#include + +#include "jinja-value.h" + +namespace jinja { + +struct value_t; +using value = std::shared_ptr; + +// this is used as a hint for chat parsing +// it is not a 1-to-1 mapping to value_t derived types +enum class inferred_type { + numeric, // int, float + string, + boolean, + array, + object, + optional, // null, undefined + unknown, +}; + +static std::string inferred_type_to_string(inferred_type type) { + switch (type) { + case inferred_type::numeric: return "numeric"; + case inferred_type::string: return "string"; + case inferred_type::boolean: return "boolean"; + case inferred_type::array: return "array"; + case inferred_type::object: return "object"; + case inferred_type::optional: return "optional"; + case inferred_type::unknown: return "unknown"; + default: return "invalid"; + } +} + +} // namespace jinja diff --git a/common/jinja/jinja-value.cpp b/common/jinja/jinja-value.cpp index 2c9ce6c76c..5a515fc8e4 100644 --- a/common/jinja/jinja-value.cpp +++ b/common/jinja/jinja-value.cpp @@ -708,7 +708,7 @@ void global_from_json(context & ctx, const nlohmann::json & json_obj) { throw std::runtime_error("global_from_json: input JSON value must be an object"); } for (auto it = 
json_obj.begin(); it != json_obj.end(); ++it) { - ctx.var[it.key()] = from_json(it.value()); + ctx.set_val(it.key(), from_json(it.value())); } } diff --git a/common/jinja/jinja-value.h b/common/jinja/jinja-value.h index 7c7d98d932..77d30c82f7 100644 --- a/common/jinja/jinja-value.h +++ b/common/jinja/jinja-value.h @@ -6,8 +6,10 @@ #include #include #include +#include #include "jinja-string.h" +#include "jinja-type-infer.h" namespace jinja { @@ -137,6 +139,10 @@ struct value_t { func_handler val_func; + // for type inference + std::set inf_types; + std::vector inf_vals; + value_t() = default; value_t(const value_t &) = default; virtual ~value_t() = default; @@ -333,4 +339,26 @@ using value_kwarg = std::shared_ptr; const func_builtins & global_builtins(); + +// utils + +static inferred_type value_to_inferred_type(const value & val) { + if (is_val(val) || is_val(val)) { + return inferred_type::numeric; + } else if (is_val(val)) { + return inferred_type::string; + } else if (is_val(val)) { + return inferred_type::boolean; + } else if (is_val(val)) { + return inferred_type::array; + } else if (is_val(val)) { + return inferred_type::object; + } else if (is_val(val) || is_val(val)) { + return inferred_type::optional; + } else { + return inferred_type::unknown; + } +} + + } // namespace jinja diff --git a/common/jinja/jinja-vm.cpp b/common/jinja/jinja-vm.cpp index b99fc605f0..ed98f1d050 100644 --- a/common/jinja/jinja-vm.cpp +++ b/common/jinja/jinja-vm.cpp @@ -63,11 +63,11 @@ value statement::execute(context & ctx) { } value identifier::execute_impl(context & ctx) { - auto it = ctx.var.find(val); + auto it = ctx.get_val(val); auto builtins = global_builtins(); - if (it != ctx.var.end()) { + if (!it->is_undefined()) { JJ_DEBUG("Identifier '%s' found", val.c_str()); - return it->second; + return it; } else if (builtins.find(val) != builtins.end()) { JJ_DEBUG("Identifier '%s' found in builtins", val.c_str()); return mk_val(builtins.at(val), val); @@ -102,6 +102,8 @@ value 
binary_expression::execute_impl(context & ctx) { value right_val = right->execute(ctx); JJ_DEBUG("Executing binary expression %s '%s' %s", left_val->type().c_str(), op.value.c_str(), right_val->type().c_str()); if (op.value == "==") { + ctx.mark_known_type(left_val, right_val); + ctx.mark_known_type(right_val, left_val); return mk_val(value_compare(left_val, right_val)); } else if (op.value == "!=") { return mk_val(!value_compare(left_val, right_val)); @@ -342,6 +344,10 @@ value unary_expression::execute_impl(context & ctx) { value if_statement::execute_impl(context & ctx) { value test_val = test->execute(ctx); + + ctx.mark_known_type(test_val, inferred_type::boolean); + ctx.mark_known_type(test_val, inferred_type::optional); + auto out = mk_val(); if (test_val->as_bool()) { for (auto & stmt : body) { @@ -384,6 +390,9 @@ value for_statement::execute_impl(context & ctx) { iterable_val = mk_val(); } + ctx.mark_known_type(iterable_val, inferred_type::array); + ctx.mark_known_type(iterable_val, inferred_type::object); + if (!is_val(iterable_val) && !is_val(iterable_val)) { throw std::runtime_error("Expected iterable or object type in for loop: got " + iterable_val->type()); } @@ -418,7 +427,7 @@ value for_statement::execute_impl(context & ctx) { if (is_stmt(loopvar)) { auto id = cast_stmt(loopvar)->val; scope_update_fn = [id, &items, i](context & ctx) { - ctx.var[id] = items[i]; + ctx.set_val(id, items[i]); }; } else if (is_stmt(loopvar)) { auto tuple = cast_stmt(loopvar); @@ -436,7 +445,7 @@ value for_statement::execute_impl(context & ctx) { throw std::runtime_error("Cannot unpack non-identifier type: " + tuple->val[j]->type()); } auto id = cast_stmt(tuple->val[j])->val; - ctx.var[id] = c_arr[j]; + ctx.set_val(id, c_arr[j]); } }; } else { @@ -470,11 +479,11 @@ value for_statement::execute_impl(context & ctx) { loop_obj->insert("length", mk_val(filtered_items.size())); loop_obj->insert("previtem", i > 0 ? 
filtered_items[i - 1] : mk_val("previtem")); loop_obj->insert("nextitem", i < filtered_items.size() - 1 ? filtered_items[i + 1] : mk_val("nextitem")); - ctx.var["loop"] = loop_obj; - scope_update_fns[i](ctx); + scope.set_val("loop", loop_obj); + scope_update_fns[i](scope); try { for (auto & stmt : body) { - value val = stmt->execute(ctx); + value val = stmt->execute(scope); result->push_back(val); } } catch (const continue_statement::signal &) { @@ -505,7 +514,7 @@ value set_statement::execute_impl(context & ctx) { if (is_stmt(assignee)) { auto var_name = cast_stmt(assignee)->val; JJ_DEBUG("Setting variable '%s' with value type %s", var_name.c_str(), rhs->type().c_str()); - ctx.var[var_name] = rhs; + ctx.set_val(var_name, rhs); } else if (is_stmt(assignee)) { auto tuple = cast_stmt(assignee); @@ -522,7 +531,7 @@ value set_statement::execute_impl(context & ctx) { throw std::runtime_error("Cannot unpack to non-identifier in set: " + elem->type()); } auto var_name = cast_stmt(elem)->val; - ctx.var[var_name] = arr[i]; + ctx.set_val(var_name, arr[i]); } } else if (is_stmt(assignee)) { @@ -564,14 +573,14 @@ value macro_statement::execute_impl(context & ctx) { if (i < input_count) { std::string param_name = cast_stmt(this->args[i])->val; JJ_DEBUG(" Binding parameter '%s' to argument of type %s", param_name.c_str(), args.args[i]->type().c_str()); - macro_ctx.var[param_name] = args.args[i]; + macro_ctx.set_val(param_name, args.args[i]); } else { auto & default_arg = this->args[i]; if (is_stmt(default_arg)) { auto kwarg = cast_stmt(default_arg); std::string param_name = cast_stmt(kwarg->key)->val; JJ_DEBUG(" Binding parameter '%s' to default argument of type %s", param_name.c_str(), kwarg->val->type().c_str()); - macro_ctx.var[param_name] = kwarg->val->execute(ctx); + macro_ctx.set_val(param_name, kwarg->val->execute(ctx)); } else { throw std::runtime_error("Not enough arguments provided to macro '" + name + "'"); } @@ -589,7 +598,7 @@ value 
macro_statement::execute_impl(context & ctx) { }; JJ_DEBUG("Defining macro '%s' with %zu parameters", name.c_str(), args.size()); - ctx.var[name] = mk_val(func); + ctx.set_val(name, mk_val(func)); return mk_val(); } diff --git a/common/jinja/jinja-vm.h b/common/jinja/jinja-vm.h index 1095d71870..bb24abad96 100644 --- a/common/jinja/jinja-vm.h +++ b/common/jinja/jinja-vm.h @@ -47,23 +47,74 @@ const T * cast_stmt(const statement_ptr & ptr) { void enable_debug(bool enable); struct context { - std::map var; std::string source; // for debugging - std::time_t current_time; // for functions that need current time context() { - var["true"] = mk_val(true); - var["false"] = mk_val(false); - var["none"] = mk_val(); + global = mk_val(); + global->insert("true", mk_val(true)); + global->insert("false", mk_val(false)); + global->insert("none", mk_val()); current_time = std::time(nullptr); } ~context() = default; - context(const context & parent) { + context(const context & parent) : context() { // inherit variables (for example, when entering a new scope) - for (const auto & pair : parent.var) { - var[pair.first] = pair.second; + auto & pvar = parent.global->as_object(); + for (const auto & pair : pvar) { + set_val(pair.first, pair.second); + } + } + + value get_val(const std::string & name) { + auto it = global->val_obj.find(name); + if (it != global->val_obj.end()) { + return it->second; + } else { + return mk_val(name); + } + } + + void set_val(const std::string & name, const value & val) { + global->insert(name, val); + set_flattened_global_recursively(name, val); + } + + void mark_known_type(value & val, inferred_type type) { + val->inf_types.insert(type); + } + + void mark_known_type(value & val, value & known_val) { + mark_known_type(val, value_to_inferred_type(known_val)); + val->inf_vals.push_back(known_val); + } + + // FOR TESTING ONLY + const value_object & get_global_object() const { + return global; + } + +private: + value_object global; + +public: + std::map 
flatten_globals; // for debugging + void set_flattened_global_recursively(std::string path, const value & val) { + flatten_globals[path] = val; + if (is_val(val)) { + auto & obj = val->as_object(); + for (const auto & pair : obj) { + flatten_globals[pair.first] = pair.second; + set_flattened_global_recursively(pair.first, pair.second); + } + } else if (is_val(val)) { + auto & arr = val->as_array(); + for (size_t i = 0; i < arr.size(); ++i) { + std::string idx_path = path + "[" + std::to_string(i) + "]"; + flatten_globals[idx_path] = arr[i]; + set_flattened_global_recursively(idx_path, arr[i]); + } } } }; diff --git a/tests/test-chat-jinja.cpp b/tests/test-chat-jinja.cpp index 0ab18c0f4f..39ce9fed00 100644 --- a/tests/test-chat-jinja.cpp +++ b/tests/test-chat-jinja.cpp @@ -13,6 +13,7 @@ #include "jinja/jinja-parser.h" #include "jinja/jinja-lexer.h" +#include "jinja/jinja-type-infer.h" void run_multiple(); void run_single(std::string contents); @@ -147,4 +148,22 @@ void run_single(std::string contents) { for (const auto & part : parts.get()->val_str.parts) { std::cout << (part.is_input ? "DATA" : "TMPL") << ": " << part.val << "\n"; } + + std::cout << "\n=== TYPES ===\n"; + auto & global_obj = ctx.flatten_globals; + for (const auto & pair : global_obj) { + std::string name = pair.first; + std::string inf_types; + for (const auto & t : pair.second->inf_types) { + inf_types += inferred_type_to_string(t) + " "; + } + if (inf_types.empty()) { + continue; + } + std::string inf_vals; + for (const auto & v : pair.second->inf_vals) { + inf_vals += v->as_string().str() + " ; "; + } + printf("Var: %-20s | Types: %-10s | Vals: %s\n", name.c_str(), inf_types.c_str(), inf_vals.c_str()); + } }