demo: type inferrence

This commit is contained in:
Xuan Son Nguyen 2025-12-30 17:26:23 +01:00
parent 9c0fa6f810
commit 4479c382ce
6 changed files with 167 additions and 22 deletions

View File

@ -0,0 +1,38 @@
#pragma once
#include <memory>
#include <string>
#include "jinja-value.h"
namespace jinja {
struct value_t;
using value = std::shared_ptr<value_t>;
// this is used as a hint for chat parsing
// it is not a 1-to-1 mapping to value_t derived types
enum class inferred_type {
numeric, // int, float
string,
boolean,
array,
object,
optional, // null, undefined
unknown,
};
static std::string inferred_type_to_string(inferred_type type) {
switch (type) {
case inferred_type::numeric: return "numeric";
case inferred_type::string: return "string";
case inferred_type::boolean: return "boolean";
case inferred_type::array: return "array";
case inferred_type::object: return "object";
case inferred_type::optional: return "optional";
case inferred_type::unknown: return "unknown";
default: return "invalid";
}
}
} // namespace jinja

View File

@ -708,7 +708,7 @@ void global_from_json(context & ctx, const nlohmann::json & json_obj) {
throw std::runtime_error("global_from_json: input JSON value must be an object");
}
for (auto it = json_obj.begin(); it != json_obj.end(); ++it) {
ctx.var[it.key()] = from_json(it.value());
ctx.set_val(it.key(), from_json(it.value()));
}
}

View File

@ -6,8 +6,10 @@
#include <functional>
#include <memory>
#include <sstream>
#include <set>
#include "jinja-string.h"
#include "jinja-type-infer.h"
namespace jinja {
@ -137,6 +139,10 @@ struct value_t {
func_handler val_func;
// for type inference
std::set<inferred_type> inf_types;
std::vector<value> inf_vals;
value_t() = default;
value_t(const value_t &) = default;
virtual ~value_t() = default;
@ -333,4 +339,26 @@ using value_kwarg = std::shared_ptr<value_kwarg_t>;
const func_builtins & global_builtins();
// utils
static inferred_type value_to_inferred_type(const value & val) {
if (is_val<value_int>(val) || is_val<value_float>(val)) {
return inferred_type::numeric;
} else if (is_val<value_string>(val)) {
return inferred_type::string;
} else if (is_val<value_bool>(val)) {
return inferred_type::boolean;
} else if (is_val<value_array>(val)) {
return inferred_type::array;
} else if (is_val<value_object>(val)) {
return inferred_type::object;
} else if (is_val<value_null>(val) || is_val<value_undefined>(val)) {
return inferred_type::optional;
} else {
return inferred_type::unknown;
}
}
} // namespace jinja

View File

@ -63,11 +63,11 @@ value statement::execute(context & ctx) {
}
value identifier::execute_impl(context & ctx) {
auto it = ctx.var.find(val);
auto it = ctx.get_val(val);
auto builtins = global_builtins();
if (it != ctx.var.end()) {
if (!it->is_undefined()) {
JJ_DEBUG("Identifier '%s' found", val.c_str());
return it->second;
return it;
} else if (builtins.find(val) != builtins.end()) {
JJ_DEBUG("Identifier '%s' found in builtins", val.c_str());
return mk_val<value_func>(builtins.at(val), val);
@ -102,6 +102,8 @@ value binary_expression::execute_impl(context & ctx) {
value right_val = right->execute(ctx);
JJ_DEBUG("Executing binary expression %s '%s' %s", left_val->type().c_str(), op.value.c_str(), right_val->type().c_str());
if (op.value == "==") {
ctx.mark_known_type(left_val, right_val);
ctx.mark_known_type(right_val, left_val);
return mk_val<value_bool>(value_compare(left_val, right_val));
} else if (op.value == "!=") {
return mk_val<value_bool>(!value_compare(left_val, right_val));
@ -342,6 +344,10 @@ value unary_expression::execute_impl(context & ctx) {
value if_statement::execute_impl(context & ctx) {
value test_val = test->execute(ctx);
ctx.mark_known_type(test_val, inferred_type::boolean);
ctx.mark_known_type(test_val, inferred_type::optional);
auto out = mk_val<value_array>();
if (test_val->as_bool()) {
for (auto & stmt : body) {
@ -384,6 +390,9 @@ value for_statement::execute_impl(context & ctx) {
iterable_val = mk_val<value_array>();
}
ctx.mark_known_type(iterable_val, inferred_type::array);
ctx.mark_known_type(iterable_val, inferred_type::object);
if (!is_val<value_array>(iterable_val) && !is_val<value_object>(iterable_val)) {
throw std::runtime_error("Expected iterable or object type in for loop: got " + iterable_val->type());
}
@ -418,7 +427,7 @@ value for_statement::execute_impl(context & ctx) {
if (is_stmt<identifier>(loopvar)) {
auto id = cast_stmt<identifier>(loopvar)->val;
scope_update_fn = [id, &items, i](context & ctx) {
ctx.var[id] = items[i];
ctx.set_val(id, items[i]);
};
} else if (is_stmt<tuple_literal>(loopvar)) {
auto tuple = cast_stmt<tuple_literal>(loopvar);
@ -436,7 +445,7 @@ value for_statement::execute_impl(context & ctx) {
throw std::runtime_error("Cannot unpack non-identifier type: " + tuple->val[j]->type());
}
auto id = cast_stmt<identifier>(tuple->val[j])->val;
ctx.var[id] = c_arr[j];
ctx.set_val(id, c_arr[j]);
}
};
} else {
@ -470,11 +479,11 @@ value for_statement::execute_impl(context & ctx) {
loop_obj->insert("length", mk_val<value_int>(filtered_items.size()));
loop_obj->insert("previtem", i > 0 ? filtered_items[i - 1] : mk_val<value_undefined>("previtem"));
loop_obj->insert("nextitem", i < filtered_items.size() - 1 ? filtered_items[i + 1] : mk_val<value_undefined>("nextitem"));
ctx.var["loop"] = loop_obj;
scope_update_fns[i](ctx);
scope.set_val("loop", loop_obj);
scope_update_fns[i](scope);
try {
for (auto & stmt : body) {
value val = stmt->execute(ctx);
value val = stmt->execute(scope);
result->push_back(val);
}
} catch (const continue_statement::signal &) {
@ -505,7 +514,7 @@ value set_statement::execute_impl(context & ctx) {
if (is_stmt<identifier>(assignee)) {
auto var_name = cast_stmt<identifier>(assignee)->val;
JJ_DEBUG("Setting variable '%s' with value type %s", var_name.c_str(), rhs->type().c_str());
ctx.var[var_name] = rhs;
ctx.set_val(var_name, rhs);
} else if (is_stmt<tuple_literal>(assignee)) {
auto tuple = cast_stmt<tuple_literal>(assignee);
@ -522,7 +531,7 @@ value set_statement::execute_impl(context & ctx) {
throw std::runtime_error("Cannot unpack to non-identifier in set: " + elem->type());
}
auto var_name = cast_stmt<identifier>(elem)->val;
ctx.var[var_name] = arr[i];
ctx.set_val(var_name, arr[i]);
}
} else if (is_stmt<member_expression>(assignee)) {
@ -564,14 +573,14 @@ value macro_statement::execute_impl(context & ctx) {
if (i < input_count) {
std::string param_name = cast_stmt<identifier>(this->args[i])->val;
JJ_DEBUG(" Binding parameter '%s' to argument of type %s", param_name.c_str(), args.args[i]->type().c_str());
macro_ctx.var[param_name] = args.args[i];
macro_ctx.set_val(param_name, args.args[i]);
} else {
auto & default_arg = this->args[i];
if (is_stmt<keyword_argument_expression>(default_arg)) {
auto kwarg = cast_stmt<keyword_argument_expression>(default_arg);
std::string param_name = cast_stmt<identifier>(kwarg->key)->val;
JJ_DEBUG(" Binding parameter '%s' to default argument of type %s", param_name.c_str(), kwarg->val->type().c_str());
macro_ctx.var[param_name] = kwarg->val->execute(ctx);
macro_ctx.set_val(param_name, kwarg->val->execute(ctx));
} else {
throw std::runtime_error("Not enough arguments provided to macro '" + name + "'");
}
@ -589,7 +598,7 @@ value macro_statement::execute_impl(context & ctx) {
};
JJ_DEBUG("Defining macro '%s' with %zu parameters", name.c_str(), args.size());
ctx.var[name] = mk_val<value_func>(func);
ctx.set_val(name, mk_val<value_func>(func));
return mk_val<value_null>();
}

View File

@ -47,23 +47,74 @@ const T * cast_stmt(const statement_ptr & ptr) {
void enable_debug(bool enable);
struct context {
std::map<std::string, value> var;
std::string source; // for debugging
std::time_t current_time; // for functions that need current time
context() {
var["true"] = mk_val<value_bool>(true);
var["false"] = mk_val<value_bool>(false);
var["none"] = mk_val<value_null>();
global = mk_val<value_object>();
global->insert("true", mk_val<value_bool>(true));
global->insert("false", mk_val<value_bool>(false));
global->insert("none", mk_val<value_null>());
current_time = std::time(nullptr);
}
~context() = default;
context(const context & parent) {
context(const context & parent) : context() {
// inherit variables (for example, when entering a new scope)
for (const auto & pair : parent.var) {
var[pair.first] = pair.second;
auto & pvar = parent.global->as_object();
for (const auto & pair : pvar) {
set_val(pair.first, pair.second);
}
}
value get_val(const std::string & name) {
auto it = global->val_obj.find(name);
if (it != global->val_obj.end()) {
return it->second;
} else {
return mk_val<value_undefined>(name);
}
}
void set_val(const std::string & name, const value & val) {
global->insert(name, val);
set_flattened_global_recursively(name, val);
}
void mark_known_type(value & val, inferred_type type) {
val->inf_types.insert(type);
}
void mark_known_type(value & val, value & known_val) {
mark_known_type(val, value_to_inferred_type(known_val));
val->inf_vals.push_back(known_val);
}
// FOR TESTING ONLY
const value_object & get_global_object() const {
return global;
}
private:
value_object global;
public:
std::map<std::string, value> flatten_globals; // for debugging
void set_flattened_global_recursively(std::string path, const value & val) {
flatten_globals[path] = val;
if (is_val<value_object>(val)) {
auto & obj = val->as_object();
for (const auto & pair : obj) {
flatten_globals[pair.first] = pair.second;
set_flattened_global_recursively(pair.first, pair.second);
}
} else if (is_val<value_array>(val)) {
auto & arr = val->as_array();
for (size_t i = 0; i < arr.size(); ++i) {
std::string idx_path = path + "[" + std::to_string(i) + "]";
flatten_globals[idx_path] = arr[i];
set_flattened_global_recursively(idx_path, arr[i]);
}
}
}
};

View File

@ -13,6 +13,7 @@
#include "jinja/jinja-parser.h"
#include "jinja/jinja-lexer.h"
#include "jinja/jinja-type-infer.h"
void run_multiple();
void run_single(std::string contents);
@ -147,4 +148,22 @@ void run_single(std::string contents) {
for (const auto & part : parts.get()->val_str.parts) {
std::cout << (part.is_input ? "DATA" : "TMPL") << ": " << part.val << "\n";
}
std::cout << "\n=== TYPES ===\n";
auto & global_obj = ctx.flatten_globals;
for (const auto & pair : global_obj) {
std::string name = pair.first;
std::string inf_types;
for (const auto & t : pair.second->inf_types) {
inf_types += inferred_type_to_string(t) + " ";
}
if (inf_types.empty()) {
continue;
}
std::string inf_vals;
for (const auto & v : pair.second->inf_vals) {
inf_vals += v->as_string().str() + " ; ";
}
printf("Var: %-20s | Types: %-10s | Vals: %s\n", name.c_str(), inf_types.c_str(), inf_vals.c_str());
}
}