diff --git a/fastapi/routing.py b/fastapi/routing.py index 1098c1d03..6485b1ad5 100644 --- a/fastapi/routing.py +++ b/fastapi/routing.py @@ -546,7 +546,7 @@ class APIRoute(routing.Route): else: response_model = return_annotation - if not lenient_issubclass(response_model, BaseModel): + if response_model is not None and not lenient_issubclass(response_model, BaseModel): inferred = infer_response_model_from_ast(endpoint) if inferred: response_model = inferred diff --git a/fastapi/utils.py b/fastapi/utils.py index e5831c893..873c6343a 100644 --- a/fastapi/utils.py +++ b/fastapi/utils.py @@ -366,11 +366,22 @@ def infer_response_model_from_ast( return None return_stmt = None - for node in ast.walk(func_def): + + + nodes_to_visit = list(func_def.body) + while nodes_to_visit: + node = nodes_to_visit.pop(0) + if isinstance(node, ast.Return): return_stmt = node break + if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)): + continue + + for child in ast.iter_child_nodes(node): + nodes_to_visit.append(child) + if not return_stmt: return None @@ -401,6 +412,10 @@ def infer_response_model_from_ast( for key, value in zip(dict_node.keys, dict_node.values): if not isinstance(key, ast.Constant): continue + + if not isinstance(key.value, str): + return None + field_name = key.value field_type = _infer_type_from_ast( @@ -418,4 +433,7 @@ def infer_response_model_from_ast( from pydantic import create_model model_name = f"ResponseModel_{endpoint_function.__name__}" - return create_model(model_name, **fields) + try: + return create_model(model_name, **fields) + except Exception: + return None diff --git a/tests/test_ast_inference.py b/tests/test_ast_inference.py index 5867385ec..5e11fb09f 100644 --- a/tests/test_ast_inference.py +++ b/tests/test_ast_inference.py @@ -1,5 +1,13 @@ +import sys +import os +import uvicorn + +# Добавляем корень проекта в sys.path, чтобы Python видел пакет fastapi +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) + from typing import Any, Dict, List from fastapi import FastAPI +from fastapi.responses import JSONResponse from fastapi.testclient import TestClient app = FastAPI() @@ -71,6 +79,21 @@ async def get_local_variable() -> Dict[str, Any]: } return response_data +@app.get("/edge_cases/explicit_response") +def get_explicit_response() -> JSONResponse: + return JSONResponse({"should_not_be_inferred": True}) + +@app.get("/edge_cases/nested_function") +async def get_nested_function() -> Dict[str, Any]: + def inner_function(): + return {"inner": "value"} + + return {"outer": "value"} + +@app.get("/edge_cases/invalid_keys") +def get_invalid_keys() -> Dict[Any, Any]: + return {1: "value", "valid": "key"} + client = TestClient(app) def test_openapi_schema_ast_inference(): @@ -111,4 +134,22 @@ def test_openapi_schema_ast_inference(): assert expr_props["calc_int"]["type"] == "integer" assert expr_props["calc_bool"]["type"] == "boolean" - \ No newline at end of file + + explicit_schema = paths["/edge_cases/explicit_response"]["get"]["responses"]["200"]["content"]["application/json"]["schema"] + assert "$ref" not in explicit_schema + + nested_schema = paths["/edge_cases/nested_function"]["get"]["responses"]["200"]["content"]["application/json"]["schema"] + assert "$ref" in nested_schema + nested_ref = nested_schema["$ref"].split("/")[-1] + nested_props = schema["components"]["schemas"][nested_ref]["properties"] + assert "outer" in nested_props + assert "inner" not in nested_props + + invalid_keys_schema = paths["/edge_cases/invalid_keys"]["get"]["responses"]["200"]["content"]["application/json"]["schema"] + assert "$ref" not in invalid_keys_schema + +if __name__ == "__main__": + # test_openapi_schema_ast_inference() + print("Запуск сервера для проверки Swagger UI...") + print("Откройте в браузере: http://127.0.0.1:8000/docs") + uvicorn.run(app, host="127.0.0.1", port=8000)