From 72b210209bd3437c086baf4ee8ef73406bc61225 Mon Sep 17 00:00:00 2001 From: g7azazlo Date: Wed, 3 Dec 2025 21:44:57 +0300 Subject: [PATCH] Implement response model inference from endpoint function source code - Added "infer_response_model_from_ast" function to analyze endpoint functions and infer Pydantic models from returned dictionary literals or variable assignments. - Updated "APIRoute" to utilize the new inference method when the specified response model is not a subclass of "BaseModel" --- fastapi/routing.py | 7 ++ fastapi/utils.py | 161 ++++++++++++++++++++++++++++++++++++ tests/test_ast_inference.py | 114 +++++++++++++++++++++++++ 3 files changed, 282 insertions(+) create mode 100644 tests/test_ast_inference.py diff --git a/fastapi/routing.py b/fastapi/routing.py index c10175b16..1098c1d03 100644 --- a/fastapi/routing.py +++ b/fastapi/routing.py @@ -57,6 +57,7 @@ from fastapi.utils import ( create_model_field, generate_unique_id, get_value_or_default, + infer_response_model_from_ast, is_body_allowed_for_status_code, ) from pydantic import BaseModel @@ -544,6 +545,12 @@ class APIRoute(routing.Route): response_model = None else: response_model = return_annotation + + if not lenient_issubclass(response_model, BaseModel): + inferred = infer_response_model_from_ast(endpoint) + if inferred: + response_model = inferred + self.response_model = response_model self.summary = summary self.response_description = response_description diff --git a/fastapi/utils.py b/fastapi/utils.py index b3b89ed2b..e5831c893 100644 --- a/fastapi/utils.py +++ b/fastapi/utils.py @@ -1,10 +1,14 @@ +import ast +import inspect import re + import warnings from dataclasses import is_dataclass from typing import ( TYPE_CHECKING, Any, Dict, + List, MutableMapping, Optional, Set, @@ -258,3 +262,160 @@ def get_value_or_default( if not isinstance(item, DefaultPlaceholder): return item return first_item + + +def _infer_type_from_ast( + node: ast.AST, + func_def: Union[ast.FunctionDef, ast.AsyncFunctionDef], + context_name: str, +) -> Any: + if isinstance(node, ast.Constant): + return type(node.value) + + if isinstance(node, ast.List): + if not node.elts: + return List[Any] + + first_type = _infer_type_from_ast(node.elts[0], func_def, context_name + "Item") + + for elt in node.elts[1:]: + current_type = _infer_type_from_ast(elt, func_def, context_name + "Item") + if current_type != first_type: + return List[Any] + + if first_type is not Any: + return List[first_type] + return List[Any] + + if isinstance(node, ast.BinOp): + left_type = _infer_type_from_ast(node.left, func_def, context_name) + right_type = _infer_type_from_ast(node.right, func_def, context_name) + if left_type == right_type and left_type in (int, float, str): + return left_type + if {left_type, right_type} == {int, float}: + return float + + if isinstance(node, ast.Compare): + return bool + + if isinstance(node, ast.Dict): + fields = {} + for key, value in zip(node.keys, node.values): + if not isinstance(key, ast.Constant): + continue + field_name = key.value + field_type = _infer_type_from_ast( + value, func_def, context_name + "_" + str(field_name) + ) + fields[field_name] = (field_type, ...) + + if not fields: + return Dict[str, Any] + + if PYDANTIC_V2: + from pydantic import create_model + else: + from pydantic import create_model + + return create_model(f"Model_{context_name}", **fields) + + if isinstance(node, ast.Name): + arg_name = node.id + for arg in func_def.args.args: + if arg.arg == arg_name and arg.annotation: + if isinstance(arg.annotation, ast.Name): + if arg.annotation.id == "int": + return int + if arg.annotation.id == "str": + return str + if arg.annotation.id == "bool": + return bool + if arg.annotation.id == "float": + return float + if arg.annotation.id == "list": + return List[Any] + if arg.annotation.id == "dict": + return Dict[str, Any] + + return Any + + +def infer_response_model_from_ast( + endpoint_function: Any, +) -> Optional[Type[BaseModel]]: + """ + Analyze the endpoint function's source code to infer a Pydantic model + from a returned dictionary literal or variable assignment. + """ + try: + source = inspect.getsource(endpoint_function) + except (OSError, TypeError): + return None + + source = inspect.cleandoc(source) + try: + tree = ast.parse(source) + except SyntaxError: + return None + + if not tree.body: + return None + + func_def = tree.body[0] + if not isinstance(func_def, (ast.FunctionDef, ast.AsyncFunctionDef)): + return None + + return_stmt = None + for node in ast.walk(func_def): + if isinstance(node, ast.Return): + return_stmt = node + break + + if not return_stmt: + return None + + returned_value = return_stmt.value + dict_node = None + + if isinstance(returned_value, ast.Dict): + dict_node = returned_value + elif isinstance(returned_value, ast.Name): + variable_name = returned_value.id + # Find assignment + for node in func_def.body: + if isinstance(node, ast.AnnAssign) and isinstance(node.target, ast.Name) and node.target.id == variable_name: + if isinstance(node.value, ast.Dict): + dict_node = node.value + break + elif isinstance(node, ast.Assign): + for target in node.targets: + if isinstance(target, ast.Name) and target.id == variable_name: + if isinstance(node.value, ast.Dict): + dict_node = node.value + break + + if not dict_node: + return None + + fields = {} + for key, value in zip(dict_node.keys, dict_node.values): + if not isinstance(key, ast.Constant): + continue + field_name = key.value + + field_type = _infer_type_from_ast( + value, func_def, f"{endpoint_function.__name__}_{field_name}" + ) + + fields[field_name] = (field_type, ...) + + if not fields: + return None + + if PYDANTIC_V2: + from pydantic import create_model + else: + from pydantic import create_model + + model_name = f"ResponseModel_{endpoint_function.__name__}" + return create_model(model_name, **fields) diff --git a/tests/test_ast_inference.py b/tests/test_ast_inference.py new file mode 100644 index 000000000..5867385ec --- /dev/null +++ b/tests/test_ast_inference.py @@ -0,0 +1,114 @@ +from typing import Any, Dict, List +from fastapi import FastAPI +from fastapi.testclient import TestClient + +app = FastAPI() + +@app.get("/users/{user_id}") +async def get_user(user_id: int) -> Dict[str, Any]: + user: Dict[str, Any] = { + "id": user_id, + "username": "example", + "email": "user@example.com", + "age": 25, + "is_active": True, + } + return user + +@app.get("/orders/{order_id}") +async def get_order_details(order_id: str) -> Dict[str, Any]: + order_data: Dict[str, Any] = { + "order_id": order_id, + "status": "processing", + "total_amount": 150.50, + "tags": ["urgent", "new_customer"], + "customer_info": { + "name": "John Doe", + "vip_status": False, + "preferences": {"notifications": True, "theme": "dark"}, + }, + "items": [ + { + "item_id": 1, + "name": "Laptop Stand", + "price": 45.00, + "in_stock": True, + }, + ], + "metadata": None, + } + return order_data + +@app.get("/edge_cases/mixed_types") +async def get_mixed_types() -> Dict[str, Any]: + return { + "mixed_list": [1, "two", 3.0], + "description": "List starting with int" + } + +@app.get("/edge_cases/expressions") +async def get_expressions() -> Dict[str, Any]: + return { + "calc_int": 10 + 5, + "calc_str": "foo" + "bar", + "calc_bool": 5 > 3 + } + +@app.get("/edge_cases/empty_structures") +async def get_empty_structures() -> Dict[str, Any]: + return { + "empty_list": [], + "empty_dict": {} + } + +@app.get("/edge_cases/local_variable") +async def get_local_variable() -> Dict[str, Any]: + response_data = { + "status": "ok", + "nested": { + "check": True + } + } + return response_data + +client = TestClient(app) + +def test_openapi_schema_ast_inference(): + response = client.get("/openapi.json") + assert response.status_code == 200 + schema = response.json() + paths = schema["paths"] + + user_schema = paths["/users/{user_id}"]["get"]["responses"]["200"]["content"]["application/json"]["schema"] + assert "$ref" in user_schema + ref_name = user_schema["$ref"].split("/")[-1] + user_props = schema["components"]["schemas"][ref_name]["properties"] + + assert user_props["id"]["type"] == "integer" + assert user_props["username"]["type"] == "string" + assert user_props["is_active"]["type"] == "boolean" + + order_schema = paths["/orders/{order_id}"]["get"]["responses"]["200"]["content"]["application/json"]["schema"] + assert "$ref" in order_schema + order_ref = order_schema["$ref"].split("/")[-1] + order_props = schema["components"]["schemas"][order_ref]["properties"] + + items_prop = order_props["items"] + assert items_prop["type"] == "array" + assert "$ref" in items_prop["items"] + + customer_prop = order_props["customer_info"] + assert "$ref" in customer_prop + + mixed_schema = paths["/edge_cases/mixed_types"]["get"]["responses"]["200"]["content"]["application/json"]["schema"] + mixed_ref = mixed_schema["$ref"].split("/")[-1] + mixed_props = schema["components"]["schemas"][mixed_ref]["properties"] + assert mixed_props["mixed_list"]["type"] == "array" + + expr_schema = paths["/edge_cases/expressions"]["get"]["responses"]["200"]["content"]["application/json"]["schema"] + expr_ref = expr_schema["$ref"].split("/")[-1] + expr_props = schema["components"]["schemas"][expr_ref]["properties"] + + assert expr_props["calc_int"]["type"] == "integer" + assert expr_props["calc_bool"]["type"] == "boolean" + \ No newline at end of file