Merge branch 'feat/ast-response-inference' of https://github.com/g7AzaZLO/fastapi into feat/ast-response-inference

This commit is contained in:
g7azazlo 2025-12-03 22:46:32 +03:00
commit 6201559a1c
3 changed files with 98 additions and 78 deletions

View File

@ -545,12 +545,11 @@ class APIRoute(routing.Route):
response_model = None
else:
response_model = return_annotation
if (
response_model is None
or (
not lenient_issubclass(response_model, BaseModel)
response_model is not None
and not lenient_issubclass(response_model, BaseModel)
and not dataclasses.is_dataclass(response_model)
)
):
inferred = infer_response_model_from_ast(endpoint)
if inferred:

View File

@ -1,7 +1,6 @@
import ast
import inspect
import re
import warnings
from dataclasses import is_dataclass
from typing import (
@ -367,7 +366,6 @@ def infer_response_model_from_ast(
return_stmt = None
nodes_to_visit = list(func_def.body)
while nodes_to_visit:
node = nodes_to_visit.pop(0)
@ -394,7 +392,11 @@ def infer_response_model_from_ast(
variable_name = returned_value.id
# Find assignment
for node in func_def.body:
if isinstance(node, ast.AnnAssign) and isinstance(node.target, ast.Name) and node.target.id == variable_name:
if (
isinstance(node, ast.AnnAssign)
and isinstance(node.target, ast.Name)
and node.target.id == variable_name
):
if isinstance(node.value, ast.Dict):
dict_node = node.value
break

View File

@ -1,10 +1,12 @@
from typing import Any, Dict, List
from typing import Any, Dict
from fastapi import FastAPI
from fastapi.responses import JSONResponse
from fastapi.testclient import TestClient
app = FastAPI()
@app.get("/users/{user_id}")
async def get_user(user_id: int) -> Dict[str, Any]:
user: Dict[str, Any] = {
@ -16,6 +18,7 @@ async def get_user(user_id: int) -> Dict[str, Any]:
}
return user
@app.get("/orders/{order_id}")
async def get_order_details(order_id: str) -> Dict[str, Any]:
order_data: Dict[str, Any] = {
@ -40,42 +43,33 @@ async def get_order_details(order_id: str) -> Dict[str, Any]:
}
return order_data
@app.get("/edge_cases/mixed_types")
async def get_mixed_types() -> Dict[str, Any]:
return {
"mixed_list": [1, "two", 3.0],
"description": "List starting with int"
}
return {"mixed_list": [1, "two", 3.0], "description": "List starting with int"}
@app.get("/edge_cases/expressions")
async def get_expressions() -> Dict[str, Any]:
return {
"calc_int": 10 + 5,
"calc_str": "foo" + "bar",
"calc_bool": 5 > 3
}
return {"calc_int": 10 + 5, "calc_str": "foo" + "bar", "calc_bool": 5 > 3}
@app.get("/edge_cases/empty_structures")
async def get_empty_structures() -> Dict[str, Any]:
return {
"empty_list": [],
"empty_dict": {}
}
return {"empty_list": [], "empty_dict": {}}
@app.get("/edge_cases/local_variable")
async def get_local_variable() -> Dict[str, Any]:
response_data = {
"status": "ok",
"nested": {
"check": True
}
}
response_data = {"status": "ok", "nested": {"check": True}}
return response_data
@app.get("/edge_cases/explicit_response")
def get_explicit_response() -> JSONResponse:
return JSONResponse({"should_not_be_inferred": True})
@app.get("/edge_cases/nested_function")
async def get_nested_function() -> Dict[str, Any]:
def inner_function():
@ -83,6 +77,7 @@ async def get_nested_function() -> Dict[str, Any]:
return {"outer": "value"}
@app.get("/edge_cases/invalid_keys")
def get_invalid_keys() -> Dict[Any, Any]:
return {1: "value", "valid": "key"}
@ -92,29 +87,33 @@ class FakeDB:
def get_user(self) -> Dict[str, Any]:
return {"id": 1, "username": "db_user"}
fake_db = FakeDB()
@app.get("/db/direct_return")
def get_db_direct() -> Dict[str, Any]:
return fake_db.get_user()
@app.get("/db/dict_construction")
def get_db_constructed() -> Dict[str, Any]:
data = fake_db.get_user()
return {
"db_id": data["id"],
"source": "database"
}
return {"db_id": data["id"], "source": "database"}
client = TestClient(app)
def test_openapi_schema_ast_inference():
response = client.get("/openapi.json")
assert response.status_code == 200
schema = response.json()
paths = schema["paths"]
user_schema = paths["/users/{user_id}"]["get"]["responses"]["200"]["content"]["application/json"]["schema"]
user_schema = paths["/users/{user_id}"]["get"]["responses"]["200"]["content"][
"application/json"
]["schema"]
assert "$ref" in user_schema
ref_name = user_schema["$ref"].split("/")[-1]
user_props = schema["components"]["schemas"][ref_name]["properties"]
@ -123,7 +122,9 @@ def test_openapi_schema_ast_inference():
assert user_props["username"]["type"] == "string"
assert user_props["is_active"]["type"] == "boolean"
order_schema = paths["/orders/{order_id}"]["get"]["responses"]["200"]["content"]["application/json"]["schema"]
order_schema = paths["/orders/{order_id}"]["get"]["responses"]["200"]["content"][
"application/json"
]["schema"]
assert "$ref" in order_schema
order_ref = order_schema["$ref"].split("/")[-1]
order_props = schema["components"]["schemas"][order_ref]["properties"]
@ -135,39 +136,57 @@ def test_openapi_schema_ast_inference():
customer_prop = order_props["customer_info"]
assert "$ref" in customer_prop
mixed_schema = paths["/edge_cases/mixed_types"]["get"]["responses"]["200"]["content"]["application/json"]["schema"]
mixed_schema = paths["/edge_cases/mixed_types"]["get"]["responses"]["200"][
"content"
]["application/json"]["schema"]
mixed_ref = mixed_schema["$ref"].split("/")[-1]
mixed_props = schema["components"]["schemas"][mixed_ref]["properties"]
assert mixed_props["mixed_list"]["type"] == "array"
expr_schema = paths["/edge_cases/expressions"]["get"]["responses"]["200"]["content"]["application/json"]["schema"]
expr_schema = paths["/edge_cases/expressions"]["get"]["responses"]["200"][
"content"
]["application/json"]["schema"]
expr_ref = expr_schema["$ref"].split("/")[-1]
expr_props = schema["components"]["schemas"][expr_ref]["properties"]
assert expr_props["calc_int"]["type"] == "integer"
assert expr_props["calc_bool"]["type"] == "boolean"
explicit_schema = paths["/edge_cases/explicit_response"]["get"]["responses"]["200"]["content"]["application/json"]["schema"]
explicit_schema = paths["/edge_cases/explicit_response"]["get"]["responses"]["200"][
"content"
]["application/json"]["schema"]
assert "$ref" not in explicit_schema
nested_schema = paths["/edge_cases/nested_function"]["get"]["responses"]["200"]["content"]["application/json"]["schema"]
nested_schema = paths["/edge_cases/nested_function"]["get"]["responses"]["200"][
"content"
]["application/json"]["schema"]
assert "$ref" in nested_schema
nested_ref = nested_schema["$ref"].split("/")[-1]
nested_props = schema["components"]["schemas"][nested_ref]["properties"]
assert "outer" in nested_props
assert "inner" not in nested_props
invalid_keys_schema = paths["/edge_cases/invalid_keys"]["get"]["responses"]["200"]["content"]["application/json"]["schema"]
invalid_keys_schema = paths["/edge_cases/invalid_keys"]["get"]["responses"]["200"][
"content"
]["application/json"]["schema"]
assert "$ref" not in invalid_keys_schema
db_direct_schema = paths["/db/direct_return"]["get"]["responses"]["200"]["content"]["application/json"]["schema"]
db_direct_schema = paths["/db/direct_return"]["get"]["responses"]["200"]["content"][
"application/json"
]["schema"]
assert "$ref" not in db_direct_schema
db_constructed_schema = paths["/db/dict_construction"]["get"]["responses"]["200"]["content"]["application/json"]["schema"]
db_constructed_schema = paths["/db/dict_construction"]["get"]["responses"]["200"][
"content"
]["application/json"]["schema"]
assert "$ref" in db_constructed_schema
db_constructed_ref = db_constructed_schema["$ref"].split("/")[-1]
db_constructed_props = schema["components"]["schemas"][db_constructed_ref]["properties"]
db_constructed_props = schema["components"]["schemas"][db_constructed_ref][
"properties"
]
assert db_constructed_props["source"]["type"] == "string"
assert "type" not in db_constructed_props["db_id"] or db_constructed_props["db_id"] == {}
assert (
"type" not in db_constructed_props["db_id"]
or db_constructed_props["db_id"] == {}
)