Enhance response model inference and improve error handling

- Updated the response model inference logic in "APIRoute" to check for "None" before validating against "BaseModel".
- Refined the "infer_response_model_from_ast" function to handle nested function definitions and invalid dictionary keys, ensuring robust schema generation.
- Added new test cases for edge scenarios to validate the inference behavior.
This commit is contained in:
g7azazlo 2025-12-03 22:06:15 +03:00
parent 72b210209b
commit 40af1b90c3
3 changed files with 63 additions and 4 deletions

View File

@ -546,7 +546,7 @@ class APIRoute(routing.Route):
else: else:
response_model = return_annotation response_model = return_annotation
if not lenient_issubclass(response_model, BaseModel): if response_model is not None and not lenient_issubclass(response_model, BaseModel):
inferred = infer_response_model_from_ast(endpoint) inferred = infer_response_model_from_ast(endpoint)
if inferred: if inferred:
response_model = inferred response_model = inferred

View File

@ -366,11 +366,22 @@ def infer_response_model_from_ast(
return None return None
return_stmt = None return_stmt = None
for node in ast.walk(func_def):
nodes_to_visit = list(func_def.body)
while nodes_to_visit:
node = nodes_to_visit.pop(0)
if isinstance(node, ast.Return): if isinstance(node, ast.Return):
return_stmt = node return_stmt = node
break break
if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
continue
for child in ast.iter_child_nodes(node):
nodes_to_visit.append(child)
if not return_stmt: if not return_stmt:
return None return None
@ -401,6 +412,10 @@ def infer_response_model_from_ast(
for key, value in zip(dict_node.keys, dict_node.values): for key, value in zip(dict_node.keys, dict_node.values):
if not isinstance(key, ast.Constant): if not isinstance(key, ast.Constant):
continue continue
if not isinstance(key.value, str):
return None
field_name = key.value field_name = key.value
field_type = _infer_type_from_ast( field_type = _infer_type_from_ast(
@ -418,4 +433,7 @@ def infer_response_model_from_ast(
from pydantic import create_model from pydantic import create_model
model_name = f"ResponseModel_{endpoint_function.__name__}" model_name = f"ResponseModel_{endpoint_function.__name__}"
try:
return create_model(model_name, **fields) return create_model(model_name, **fields)
except Exception:
return None

View File

@ -1,5 +1,13 @@
import sys
import os
import uvicorn
# Добавляем корень проекта в sys.path, чтобы Python видел пакет fastapi
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from typing import Any, Dict, List from typing import Any, Dict, List
from fastapi import FastAPI from fastapi import FastAPI
from fastapi.responses import JSONResponse
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
app = FastAPI() app = FastAPI()
@ -71,6 +79,21 @@ async def get_local_variable() -> Dict[str, Any]:
} }
return response_data return response_data
@app.get("/edge_cases/explicit_response")
def get_explicit_response() -> JSONResponse:
return JSONResponse({"should_not_be_inferred": True})
@app.get("/edge_cases/nested_function")
async def get_nested_function() -> Dict[str, Any]:
def inner_function():
return {"inner": "value"}
return {"outer": "value"}
@app.get("/edge_cases/invalid_keys")
def get_invalid_keys() -> Dict[Any, Any]:
return {1: "value", "valid": "key"}
client = TestClient(app) client = TestClient(app)
def test_openapi_schema_ast_inference(): def test_openapi_schema_ast_inference():
@ -112,3 +135,21 @@ def test_openapi_schema_ast_inference():
assert expr_props["calc_int"]["type"] == "integer" assert expr_props["calc_int"]["type"] == "integer"
assert expr_props["calc_bool"]["type"] == "boolean" assert expr_props["calc_bool"]["type"] == "boolean"
explicit_schema = paths["/edge_cases/explicit_response"]["get"]["responses"]["200"]["content"]["application/json"]["schema"]
assert "$ref" not in explicit_schema
nested_schema = paths["/edge_cases/nested_function"]["get"]["responses"]["200"]["content"]["application/json"]["schema"]
assert "$ref" in nested_schema
nested_ref = nested_schema["$ref"].split("/")[-1]
nested_props = schema["components"]["schemas"][nested_ref]["properties"]
assert "outer" in nested_props
assert "inner" not in nested_props
invalid_keys_schema = paths["/edge_cases/invalid_keys"]["get"]["responses"]["200"]["content"]["application/json"]["schema"]
assert "$ref" not in invalid_keys_schema
if __name__ == "__main__":
# test_openapi_schema_ast_inference()
print("Запуск сервера для проверки Swagger UI...")
print("Откройте в браузере: http://127.0.0.1:8000/docs")
uvicorn.run(app, host="127.0.0.1", port=8000)