Implement response model inference from endpoint function source code

- Added "infer_response_model_from_ast" function to analyze endpoint functions and infer Pydantic models from returned dictionary literals or variable assignments.

- Updated "APIRoute" to utilize the new inference method when the specified response model is not a subclass of "BaseModel"
This commit is contained in:
g7azazlo 2025-12-03 21:44:57 +03:00
parent c57ac7bdf3
commit 72b210209b
3 changed files with 282 additions and 0 deletions

View File

@ -57,6 +57,7 @@ from fastapi.utils import (
create_model_field,
generate_unique_id,
get_value_or_default,
infer_response_model_from_ast,
is_body_allowed_for_status_code,
)
from pydantic import BaseModel
@ -544,6 +545,12 @@ class APIRoute(routing.Route):
response_model = None
else:
response_model = return_annotation
if not lenient_issubclass(response_model, BaseModel):
inferred = infer_response_model_from_ast(endpoint)
if inferred:
response_model = inferred
self.response_model = response_model
self.summary = summary
self.response_description = response_description

View File

@ -1,10 +1,14 @@
import ast
import inspect
import re
import warnings
from dataclasses import is_dataclass
from typing import (
TYPE_CHECKING,
Any,
Dict,
List,
MutableMapping,
Optional,
Set,
@ -258,3 +262,160 @@ def get_value_or_default(
if not isinstance(item, DefaultPlaceholder):
return item
return first_item
def _infer_type_from_ast(
node: ast.AST,
func_def: Union[ast.FunctionDef, ast.AsyncFunctionDef],
context_name: str,
) -> Any:
if isinstance(node, ast.Constant):
return type(node.value)
if isinstance(node, ast.List):
if not node.elts:
return List[Any]
first_type = _infer_type_from_ast(node.elts[0], func_def, context_name + "Item")
for elt in node.elts[1:]:
current_type = _infer_type_from_ast(elt, func_def, context_name + "Item")
if current_type != first_type:
return List[Any]
if first_type is not Any:
return List[first_type]
return List[Any]
if isinstance(node, ast.BinOp):
left_type = _infer_type_from_ast(node.left, func_def, context_name)
right_type = _infer_type_from_ast(node.right, func_def, context_name)
if left_type == right_type and left_type in (int, float, str):
return left_type
if {left_type, right_type} == {int, float}:
return float
if isinstance(node, ast.Compare):
return bool
if isinstance(node, ast.Dict):
fields = {}
for key, value in zip(node.keys, node.values):
if not isinstance(key, ast.Constant):
continue
field_name = key.value
field_type = _infer_type_from_ast(
value, func_def, context_name + "_" + str(field_name)
)
fields[field_name] = (field_type, ...)
if not fields:
return Dict[str, Any]
if PYDANTIC_V2:
from pydantic import create_model
else:
from pydantic import create_model
return create_model(f"Model_{context_name}", **fields)
if isinstance(node, ast.Name):
arg_name = node.id
for arg in func_def.args.args:
if arg.arg == arg_name and arg.annotation:
if isinstance(arg.annotation, ast.Name):
if arg.annotation.id == "int":
return int
if arg.annotation.id == "str":
return str
if arg.annotation.id == "bool":
return bool
if arg.annotation.id == "float":
return float
if arg.annotation.id == "list":
return List[Any]
if arg.annotation.id == "dict":
return Dict[str, Any]
return Any
def infer_response_model_from_ast(
endpoint_function: Any,
) -> Optional[Type[BaseModel]]:
"""
Analyze the endpoint function's source code to infer a Pydantic model
from a returned dictionary literal or variable assignment.
"""
try:
source = inspect.getsource(endpoint_function)
except (OSError, TypeError):
return None
source = inspect.cleandoc(source)
try:
tree = ast.parse(source)
except SyntaxError:
return None
if not tree.body:
return None
func_def = tree.body[0]
if not isinstance(func_def, (ast.FunctionDef, ast.AsyncFunctionDef)):
return None
return_stmt = None
for node in ast.walk(func_def):
if isinstance(node, ast.Return):
return_stmt = node
break
if not return_stmt:
return None
returned_value = return_stmt.value
dict_node = None
if isinstance(returned_value, ast.Dict):
dict_node = returned_value
elif isinstance(returned_value, ast.Name):
variable_name = returned_value.id
# Find assignment
for node in func_def.body:
if isinstance(node, ast.AnnAssign) and isinstance(node.target, ast.Name) and node.target.id == variable_name:
if isinstance(node.value, ast.Dict):
dict_node = node.value
break
elif isinstance(node, ast.Assign):
for target in node.targets:
if isinstance(target, ast.Name) and target.id == variable_name:
if isinstance(node.value, ast.Dict):
dict_node = node.value
break
if not dict_node:
return None
fields = {}
for key, value in zip(dict_node.keys, dict_node.values):
if not isinstance(key, ast.Constant):
continue
field_name = key.value
field_type = _infer_type_from_ast(
value, func_def, f"{endpoint_function.__name__}_{field_name}"
)
fields[field_name] = (field_type, ...)
if not fields:
return None
if PYDANTIC_V2:
from pydantic import create_model
else:
from pydantic import create_model
model_name = f"ResponseModel_{endpoint_function.__name__}"
return create_model(model_name, **fields)

114
tests/test_ast_inference.py Normal file
View File

@ -0,0 +1,114 @@
from typing import Any, Dict, List
from fastapi import FastAPI
from fastapi.testclient import TestClient
app = FastAPI()
@app.get("/users/{user_id}")
async def get_user(user_id: int) -> Dict[str, Any]:
user: Dict[str, Any] = {
"id": user_id,
"username": "example",
"email": "user@example.com",
"age": 25,
"is_active": True,
}
return user
@app.get("/orders/{order_id}")
async def get_order_details(order_id: str) -> Dict[str, Any]:
order_data: Dict[str, Any] = {
"order_id": order_id,
"status": "processing",
"total_amount": 150.50,
"tags": ["urgent", "new_customer"],
"customer_info": {
"name": "John Doe",
"vip_status": False,
"preferences": {"notifications": True, "theme": "dark"},
},
"items": [
{
"item_id": 1,
"name": "Laptop Stand",
"price": 45.00,
"in_stock": True,
},
],
"metadata": None,
}
return order_data
@app.get("/edge_cases/mixed_types")
async def get_mixed_types() -> Dict[str, Any]:
return {
"mixed_list": [1, "two", 3.0],
"description": "List starting with int"
}
@app.get("/edge_cases/expressions")
async def get_expressions() -> Dict[str, Any]:
return {
"calc_int": 10 + 5,
"calc_str": "foo" + "bar",
"calc_bool": 5 > 3
}
@app.get("/edge_cases/empty_structures")
async def get_empty_structures() -> Dict[str, Any]:
return {
"empty_list": [],
"empty_dict": {}
}
@app.get("/edge_cases/local_variable")
async def get_local_variable() -> Dict[str, Any]:
response_data = {
"status": "ok",
"nested": {
"check": True
}
}
return response_data
client = TestClient(app)
def test_openapi_schema_ast_inference():
response = client.get("/openapi.json")
assert response.status_code == 200
schema = response.json()
paths = schema["paths"]
user_schema = paths["/users/{user_id}"]["get"]["responses"]["200"]["content"]["application/json"]["schema"]
assert "$ref" in user_schema
ref_name = user_schema["$ref"].split("/")[-1]
user_props = schema["components"]["schemas"][ref_name]["properties"]
assert user_props["id"]["type"] == "integer"
assert user_props["username"]["type"] == "string"
assert user_props["is_active"]["type"] == "boolean"
order_schema = paths["/orders/{order_id}"]["get"]["responses"]["200"]["content"]["application/json"]["schema"]
assert "$ref" in order_schema
order_ref = order_schema["$ref"].split("/")[-1]
order_props = schema["components"]["schemas"][order_ref]["properties"]
items_prop = order_props["items"]
assert items_prop["type"] == "array"
assert "$ref" in items_prop["items"]
customer_prop = order_props["customer_info"]
assert "$ref" in customer_prop
mixed_schema = paths["/edge_cases/mixed_types"]["get"]["responses"]["200"]["content"]["application/json"]["schema"]
mixed_ref = mixed_schema["$ref"].split("/")[-1]
mixed_props = schema["components"]["schemas"][mixed_ref]["properties"]
assert mixed_props["mixed_list"]["type"] == "array"
expr_schema = paths["/edge_cases/expressions"]["get"]["responses"]["200"]["content"]["application/json"]["schema"]
expr_ref = expr_schema["$ref"].split("/")[-1]
expr_props = schema["components"]["schemas"][expr_ref]["properties"]
assert expr_props["calc_int"]["type"] == "integer"
assert expr_props["calc_bool"]["type"] == "boolean"