mirror of https://github.com/tiangolo/fastapi.git
Enhance response model inference and improve error handling
- Updated the response model inference logic in "APIRoute" to check for "None" before validating against "BaseModel". - Refined the "infer_response_model_from_ast" function to handle nested function definitions and invalid dictionary keys, ensuring robust schema generation. - Added new test cases for edge scenarios to validate the inference behavior.
This commit is contained in:
parent
72b210209b
commit
40af1b90c3
|
|
@ -546,7 +546,7 @@ class APIRoute(routing.Route):
|
||||||
else:
|
else:
|
||||||
response_model = return_annotation
|
response_model = return_annotation
|
||||||
|
|
||||||
if not lenient_issubclass(response_model, BaseModel):
|
if response_model is not None and not lenient_issubclass(response_model, BaseModel):
|
||||||
inferred = infer_response_model_from_ast(endpoint)
|
inferred = infer_response_model_from_ast(endpoint)
|
||||||
if inferred:
|
if inferred:
|
||||||
response_model = inferred
|
response_model = inferred
|
||||||
|
|
|
||||||
|
|
@ -366,11 +366,22 @@ def infer_response_model_from_ast(
|
||||||
return None
|
return None
|
||||||
|
|
||||||
return_stmt = None
|
return_stmt = None
|
||||||
for node in ast.walk(func_def):
|
|
||||||
|
|
||||||
|
nodes_to_visit = list(func_def.body)
|
||||||
|
while nodes_to_visit:
|
||||||
|
node = nodes_to_visit.pop(0)
|
||||||
|
|
||||||
if isinstance(node, ast.Return):
|
if isinstance(node, ast.Return):
|
||||||
return_stmt = node
|
return_stmt = node
|
||||||
break
|
break
|
||||||
|
|
||||||
|
if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
|
||||||
|
continue
|
||||||
|
|
||||||
|
for child in ast.iter_child_nodes(node):
|
||||||
|
nodes_to_visit.append(child)
|
||||||
|
|
||||||
if not return_stmt:
|
if not return_stmt:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
@ -401,6 +412,10 @@ def infer_response_model_from_ast(
|
||||||
for key, value in zip(dict_node.keys, dict_node.values):
|
for key, value in zip(dict_node.keys, dict_node.values):
|
||||||
if not isinstance(key, ast.Constant):
|
if not isinstance(key, ast.Constant):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
if not isinstance(key.value, str):
|
||||||
|
return None
|
||||||
|
|
||||||
field_name = key.value
|
field_name = key.value
|
||||||
|
|
||||||
field_type = _infer_type_from_ast(
|
field_type = _infer_type_from_ast(
|
||||||
|
|
@ -418,4 +433,7 @@ def infer_response_model_from_ast(
|
||||||
from pydantic import create_model
|
from pydantic import create_model
|
||||||
|
|
||||||
model_name = f"ResponseModel_{endpoint_function.__name__}"
|
model_name = f"ResponseModel_{endpoint_function.__name__}"
|
||||||
return create_model(model_name, **fields)
|
try:
|
||||||
|
return create_model(model_name, **fields)
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,13 @@
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
import uvicorn
|
||||||
|
|
||||||
|
# Добавляем корень проекта в sys.path, чтобы Python видел пакет fastapi
|
||||||
|
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
|
||||||
|
|
||||||
from typing import Any, Dict, List
|
from typing import Any, Dict, List
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
|
from fastapi.responses import JSONResponse
|
||||||
from fastapi.testclient import TestClient
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
|
|
@ -71,6 +79,21 @@ async def get_local_variable() -> Dict[str, Any]:
|
||||||
}
|
}
|
||||||
return response_data
|
return response_data
|
||||||
|
|
||||||
|
@app.get("/edge_cases/explicit_response")
|
||||||
|
def get_explicit_response() -> JSONResponse:
|
||||||
|
return JSONResponse({"should_not_be_inferred": True})
|
||||||
|
|
||||||
|
@app.get("/edge_cases/nested_function")
|
||||||
|
async def get_nested_function() -> Dict[str, Any]:
|
||||||
|
def inner_function():
|
||||||
|
return {"inner": "value"}
|
||||||
|
|
||||||
|
return {"outer": "value"}
|
||||||
|
|
||||||
|
@app.get("/edge_cases/invalid_keys")
|
||||||
|
def get_invalid_keys() -> Dict[Any, Any]:
|
||||||
|
return {1: "value", "valid": "key"}
|
||||||
|
|
||||||
client = TestClient(app)
|
client = TestClient(app)
|
||||||
|
|
||||||
def test_openapi_schema_ast_inference():
|
def test_openapi_schema_ast_inference():
|
||||||
|
|
@ -112,3 +135,21 @@ def test_openapi_schema_ast_inference():
|
||||||
assert expr_props["calc_int"]["type"] == "integer"
|
assert expr_props["calc_int"]["type"] == "integer"
|
||||||
assert expr_props["calc_bool"]["type"] == "boolean"
|
assert expr_props["calc_bool"]["type"] == "boolean"
|
||||||
|
|
||||||
|
explicit_schema = paths["/edge_cases/explicit_response"]["get"]["responses"]["200"]["content"]["application/json"]["schema"]
|
||||||
|
assert "$ref" not in explicit_schema
|
||||||
|
|
||||||
|
nested_schema = paths["/edge_cases/nested_function"]["get"]["responses"]["200"]["content"]["application/json"]["schema"]
|
||||||
|
assert "$ref" in nested_schema
|
||||||
|
nested_ref = nested_schema["$ref"].split("/")[-1]
|
||||||
|
nested_props = schema["components"]["schemas"][nested_ref]["properties"]
|
||||||
|
assert "outer" in nested_props
|
||||||
|
assert "inner" not in nested_props
|
||||||
|
|
||||||
|
invalid_keys_schema = paths["/edge_cases/invalid_keys"]["get"]["responses"]["200"]["content"]["application/json"]["schema"]
|
||||||
|
assert "$ref" not in invalid_keys_schema
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# test_openapi_schema_ast_inference()
|
||||||
|
print("Запуск сервера для проверки Swagger UI...")
|
||||||
|
print("Откройте в браузере: http://127.0.0.1:8000/docs")
|
||||||
|
uvicorn.run(app, host="127.0.0.1", port=8000)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue