mirror of https://github.com/tiangolo/fastapi.git
Merge branch 'feat/ast-response-inference' of https://github.com/g7AzaZLO/fastapi into feat/ast-response-inference
This commit is contained in:
commit
6201559a1c
|
|
@ -545,16 +545,15 @@ class APIRoute(routing.Route):
|
|||
response_model = None
|
||||
else:
|
||||
response_model = return_annotation
|
||||
if (
|
||||
response_model is None
|
||||
or (
|
||||
not lenient_issubclass(response_model, BaseModel)
|
||||
and not dataclasses.is_dataclass(response_model)
|
||||
)
|
||||
):
|
||||
inferred = infer_response_model_from_ast(endpoint)
|
||||
if inferred:
|
||||
response_model = inferred
|
||||
|
||||
if (
|
||||
response_model is not None
|
||||
and not lenient_issubclass(response_model, BaseModel)
|
||||
and not dataclasses.is_dataclass(response_model)
|
||||
):
|
||||
inferred = infer_response_model_from_ast(endpoint)
|
||||
if inferred:
|
||||
response_model = inferred
|
||||
|
||||
self.response_model = response_model
|
||||
self.summary = summary
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
import ast
|
||||
import inspect
|
||||
import re
|
||||
|
||||
import warnings
|
||||
from dataclasses import is_dataclass
|
||||
from typing import (
|
||||
|
|
@ -277,12 +276,12 @@ def _infer_type_from_ast(
|
|||
return List[Any]
|
||||
|
||||
first_type = _infer_type_from_ast(node.elts[0], func_def, context_name + "Item")
|
||||
|
||||
|
||||
for elt in node.elts[1:]:
|
||||
current_type = _infer_type_from_ast(elt, func_def, context_name + "Item")
|
||||
if current_type != first_type:
|
||||
return List[Any]
|
||||
|
||||
|
||||
if first_type is not Any:
|
||||
return List[first_type]
|
||||
return List[Any]
|
||||
|
|
@ -291,7 +290,7 @@ def _infer_type_from_ast(
|
|||
left_type = _infer_type_from_ast(node.left, func_def, context_name)
|
||||
right_type = _infer_type_from_ast(node.right, func_def, context_name)
|
||||
if left_type == right_type and left_type in (int, float, str):
|
||||
return left_type
|
||||
return left_type
|
||||
if {left_type, right_type} == {int, float}:
|
||||
return float
|
||||
|
||||
|
|
@ -351,60 +350,63 @@ def infer_response_model_from_ast(
|
|||
source = inspect.getsource(endpoint_function)
|
||||
except (OSError, TypeError):
|
||||
return None
|
||||
|
||||
|
||||
source = inspect.cleandoc(source)
|
||||
try:
|
||||
tree = ast.parse(source)
|
||||
except SyntaxError:
|
||||
return None
|
||||
|
||||
|
||||
if not tree.body:
|
||||
return None
|
||||
|
||||
|
||||
func_def = tree.body[0]
|
||||
if not isinstance(func_def, (ast.FunctionDef, ast.AsyncFunctionDef)):
|
||||
return None
|
||||
|
||||
|
||||
return_stmt = None
|
||||
|
||||
|
||||
nodes_to_visit = list(func_def.body)
|
||||
while nodes_to_visit:
|
||||
node = nodes_to_visit.pop(0)
|
||||
|
||||
|
||||
if isinstance(node, ast.Return):
|
||||
return_stmt = node
|
||||
break
|
||||
|
||||
|
||||
if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
|
||||
continue
|
||||
|
||||
|
||||
for child in ast.iter_child_nodes(node):
|
||||
nodes_to_visit.append(child)
|
||||
|
||||
|
||||
if not return_stmt:
|
||||
return None
|
||||
|
||||
returned_value = return_stmt.value
|
||||
dict_node = None
|
||||
|
||||
|
||||
if isinstance(returned_value, ast.Dict):
|
||||
dict_node = returned_value
|
||||
elif isinstance(returned_value, ast.Name):
|
||||
variable_name = returned_value.id
|
||||
# Find assignment
|
||||
for node in func_def.body:
|
||||
if isinstance(node, ast.AnnAssign) and isinstance(node.target, ast.Name) and node.target.id == variable_name:
|
||||
if isinstance(node.value, ast.Dict):
|
||||
dict_node = node.value
|
||||
break
|
||||
if (
|
||||
isinstance(node, ast.AnnAssign)
|
||||
and isinstance(node.target, ast.Name)
|
||||
and node.target.id == variable_name
|
||||
):
|
||||
if isinstance(node.value, ast.Dict):
|
||||
dict_node = node.value
|
||||
break
|
||||
elif isinstance(node, ast.Assign):
|
||||
for target in node.targets:
|
||||
if isinstance(target, ast.Name) and target.id == variable_name:
|
||||
if isinstance(node.value, ast.Dict):
|
||||
dict_node = node.value
|
||||
break
|
||||
|
||||
for target in node.targets:
|
||||
if isinstance(target, ast.Name) and target.id == variable_name:
|
||||
if isinstance(node.value, ast.Dict):
|
||||
dict_node = node.value
|
||||
break
|
||||
|
||||
if not dict_node:
|
||||
return None
|
||||
|
||||
|
|
@ -412,16 +414,16 @@ def infer_response_model_from_ast(
|
|||
for key, value in zip(dict_node.keys, dict_node.values):
|
||||
if not isinstance(key, ast.Constant):
|
||||
continue
|
||||
|
||||
|
||||
if not isinstance(key.value, str):
|
||||
return None
|
||||
|
||||
|
||||
field_name = key.value
|
||||
|
||||
|
||||
field_type = _infer_type_from_ast(
|
||||
value, func_def, f"{endpoint_function.__name__}_{field_name}"
|
||||
)
|
||||
|
||||
|
||||
fields[field_name] = (field_type, ...)
|
||||
|
||||
if not fields:
|
||||
|
|
|
|||
|
|
@ -1,10 +1,12 @@
|
|||
from typing import Any, Dict, List
|
||||
from typing import Any, Dict
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
|
||||
@app.get("/users/{user_id}")
|
||||
async def get_user(user_id: int) -> Dict[str, Any]:
|
||||
user: Dict[str, Any] = {
|
||||
|
|
@ -16,6 +18,7 @@ async def get_user(user_id: int) -> Dict[str, Any]:
|
|||
}
|
||||
return user
|
||||
|
||||
|
||||
@app.get("/orders/{order_id}")
|
||||
async def get_order_details(order_id: str) -> Dict[str, Any]:
|
||||
order_data: Dict[str, Any] = {
|
||||
|
|
@ -40,49 +43,41 @@ async def get_order_details(order_id: str) -> Dict[str, Any]:
|
|||
}
|
||||
return order_data
|
||||
|
||||
|
||||
@app.get("/edge_cases/mixed_types")
|
||||
async def get_mixed_types() -> Dict[str, Any]:
|
||||
return {
|
||||
"mixed_list": [1, "two", 3.0],
|
||||
"description": "List starting with int"
|
||||
}
|
||||
return {"mixed_list": [1, "two", 3.0], "description": "List starting with int"}
|
||||
|
||||
|
||||
@app.get("/edge_cases/expressions")
|
||||
async def get_expressions() -> Dict[str, Any]:
|
||||
return {
|
||||
"calc_int": 10 + 5,
|
||||
"calc_str": "foo" + "bar",
|
||||
"calc_bool": 5 > 3
|
||||
}
|
||||
return {"calc_int": 10 + 5, "calc_str": "foo" + "bar", "calc_bool": 5 > 3}
|
||||
|
||||
|
||||
@app.get("/edge_cases/empty_structures")
|
||||
async def get_empty_structures() -> Dict[str, Any]:
|
||||
return {
|
||||
"empty_list": [],
|
||||
"empty_dict": {}
|
||||
}
|
||||
return {"empty_list": [], "empty_dict": {}}
|
||||
|
||||
|
||||
@app.get("/edge_cases/local_variable")
|
||||
async def get_local_variable() -> Dict[str, Any]:
|
||||
response_data = {
|
||||
"status": "ok",
|
||||
"nested": {
|
||||
"check": True
|
||||
}
|
||||
}
|
||||
response_data = {"status": "ok", "nested": {"check": True}}
|
||||
return response_data
|
||||
|
||||
|
||||
@app.get("/edge_cases/explicit_response")
|
||||
def get_explicit_response() -> JSONResponse:
|
||||
return JSONResponse({"should_not_be_inferred": True})
|
||||
|
||||
|
||||
@app.get("/edge_cases/nested_function")
|
||||
async def get_nested_function() -> Dict[str, Any]:
|
||||
def inner_function():
|
||||
return {"inner": "value"}
|
||||
|
||||
|
||||
return {"outer": "value"}
|
||||
|
||||
|
||||
@app.get("/edge_cases/invalid_keys")
|
||||
def get_invalid_keys() -> Dict[Any, Any]:
|
||||
return {1: "value", "valid": "key"}
|
||||
|
|
@ -92,38 +87,44 @@ class FakeDB:
|
|||
def get_user(self) -> Dict[str, Any]:
|
||||
return {"id": 1, "username": "db_user"}
|
||||
|
||||
|
||||
fake_db = FakeDB()
|
||||
|
||||
|
||||
@app.get("/db/direct_return")
|
||||
def get_db_direct() -> Dict[str, Any]:
|
||||
return fake_db.get_user()
|
||||
|
||||
|
||||
@app.get("/db/dict_construction")
|
||||
def get_db_constructed() -> Dict[str, Any]:
|
||||
data = fake_db.get_user()
|
||||
return {
|
||||
"db_id": data["id"],
|
||||
"source": "database"
|
||||
}
|
||||
return {"db_id": data["id"], "source": "database"}
|
||||
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
|
||||
def test_openapi_schema_ast_inference():
|
||||
response = client.get("/openapi.json")
|
||||
assert response.status_code == 200
|
||||
schema = response.json()
|
||||
paths = schema["paths"]
|
||||
|
||||
user_schema = paths["/users/{user_id}"]["get"]["responses"]["200"]["content"]["application/json"]["schema"]
|
||||
user_schema = paths["/users/{user_id}"]["get"]["responses"]["200"]["content"][
|
||||
"application/json"
|
||||
]["schema"]
|
||||
assert "$ref" in user_schema
|
||||
ref_name = user_schema["$ref"].split("/")[-1]
|
||||
user_props = schema["components"]["schemas"][ref_name]["properties"]
|
||||
|
||||
|
||||
assert user_props["id"]["type"] == "integer"
|
||||
assert user_props["username"]["type"] == "string"
|
||||
assert user_props["is_active"]["type"] == "boolean"
|
||||
|
||||
order_schema = paths["/orders/{order_id}"]["get"]["responses"]["200"]["content"]["application/json"]["schema"]
|
||||
order_schema = paths["/orders/{order_id}"]["get"]["responses"]["200"]["content"][
|
||||
"application/json"
|
||||
]["schema"]
|
||||
assert "$ref" in order_schema
|
||||
order_ref = order_schema["$ref"].split("/")[-1]
|
||||
order_props = schema["components"]["schemas"][order_ref]["properties"]
|
||||
|
|
@ -135,39 +136,57 @@ def test_openapi_schema_ast_inference():
|
|||
customer_prop = order_props["customer_info"]
|
||||
assert "$ref" in customer_prop
|
||||
|
||||
mixed_schema = paths["/edge_cases/mixed_types"]["get"]["responses"]["200"]["content"]["application/json"]["schema"]
|
||||
mixed_schema = paths["/edge_cases/mixed_types"]["get"]["responses"]["200"][
|
||||
"content"
|
||||
]["application/json"]["schema"]
|
||||
mixed_ref = mixed_schema["$ref"].split("/")[-1]
|
||||
mixed_props = schema["components"]["schemas"][mixed_ref]["properties"]
|
||||
assert mixed_props["mixed_list"]["type"] == "array"
|
||||
|
||||
expr_schema = paths["/edge_cases/expressions"]["get"]["responses"]["200"]["content"]["application/json"]["schema"]
|
||||
expr_schema = paths["/edge_cases/expressions"]["get"]["responses"]["200"][
|
||||
"content"
|
||||
]["application/json"]["schema"]
|
||||
expr_ref = expr_schema["$ref"].split("/")[-1]
|
||||
expr_props = schema["components"]["schemas"][expr_ref]["properties"]
|
||||
|
||||
|
||||
assert expr_props["calc_int"]["type"] == "integer"
|
||||
assert expr_props["calc_bool"]["type"] == "boolean"
|
||||
|
||||
explicit_schema = paths["/edge_cases/explicit_response"]["get"]["responses"]["200"]["content"]["application/json"]["schema"]
|
||||
explicit_schema = paths["/edge_cases/explicit_response"]["get"]["responses"]["200"][
|
||||
"content"
|
||||
]["application/json"]["schema"]
|
||||
assert "$ref" not in explicit_schema
|
||||
|
||||
nested_schema = paths["/edge_cases/nested_function"]["get"]["responses"]["200"]["content"]["application/json"]["schema"]
|
||||
nested_schema = paths["/edge_cases/nested_function"]["get"]["responses"]["200"][
|
||||
"content"
|
||||
]["application/json"]["schema"]
|
||||
assert "$ref" in nested_schema
|
||||
nested_ref = nested_schema["$ref"].split("/")[-1]
|
||||
nested_props = schema["components"]["schemas"][nested_ref]["properties"]
|
||||
assert "outer" in nested_props
|
||||
assert "inner" not in nested_props
|
||||
|
||||
invalid_keys_schema = paths["/edge_cases/invalid_keys"]["get"]["responses"]["200"]["content"]["application/json"]["schema"]
|
||||
invalid_keys_schema = paths["/edge_cases/invalid_keys"]["get"]["responses"]["200"][
|
||||
"content"
|
||||
]["application/json"]["schema"]
|
||||
assert "$ref" not in invalid_keys_schema
|
||||
|
||||
db_direct_schema = paths["/db/direct_return"]["get"]["responses"]["200"]["content"]["application/json"]["schema"]
|
||||
db_direct_schema = paths["/db/direct_return"]["get"]["responses"]["200"]["content"][
|
||||
"application/json"
|
||||
]["schema"]
|
||||
assert "$ref" not in db_direct_schema
|
||||
|
||||
db_constructed_schema = paths["/db/dict_construction"]["get"]["responses"]["200"]["content"]["application/json"]["schema"]
|
||||
db_constructed_schema = paths["/db/dict_construction"]["get"]["responses"]["200"][
|
||||
"content"
|
||||
]["application/json"]["schema"]
|
||||
assert "$ref" in db_constructed_schema
|
||||
db_constructed_ref = db_constructed_schema["$ref"].split("/")[-1]
|
||||
db_constructed_props = schema["components"]["schemas"][db_constructed_ref]["properties"]
|
||||
|
||||
assert db_constructed_props["source"]["type"] == "string"
|
||||
assert "type" not in db_constructed_props["db_id"] or db_constructed_props["db_id"] == {}
|
||||
db_constructed_props = schema["components"]["schemas"][db_constructed_ref][
|
||||
"properties"
|
||||
]
|
||||
|
||||
assert db_constructed_props["source"]["type"] == "string"
|
||||
assert (
|
||||
"type" not in db_constructed_props["db_id"]
|
||||
or db_constructed_props["db_id"] == {}
|
||||
)
|
||||
|
|
|
|||
Loading…
Reference in New Issue