From 95d5b080a2fc169fe69faf8f9b8c149935bac7e4 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci-lite[bot]" <117423508+pre-commit-ci-lite[bot]@users.noreply.github.com> Date: Wed, 3 Dec 2025 19:19:46 +0000 Subject: [PATCH] =?UTF-8?q?=F0=9F=8E=A8=20Auto=20format?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastapi/routing.py | 6 ++- fastapi/utils.py | 58 +++++++++++----------- tests/test_ast_inference.py | 99 ++++++++++++++++++++++--------------- 3 files changed, 93 insertions(+), 70 deletions(-) diff --git a/fastapi/routing.py b/fastapi/routing.py index 6485b1ad5..abcf9dab7 100644 --- a/fastapi/routing.py +++ b/fastapi/routing.py @@ -545,8 +545,10 @@ class APIRoute(routing.Route): response_model = None else: response_model = return_annotation - - if response_model is not None and not lenient_issubclass(response_model, BaseModel): + + if response_model is not None and not lenient_issubclass( + response_model, BaseModel + ): inferred = infer_response_model_from_ast(endpoint) if inferred: response_model = inferred diff --git a/fastapi/utils.py b/fastapi/utils.py index 873c6343a..233bac748 100644 --- a/fastapi/utils.py +++ b/fastapi/utils.py @@ -1,7 +1,6 @@ import ast import inspect import re - import warnings from dataclasses import is_dataclass from typing import ( @@ -277,12 +276,12 @@ def _infer_type_from_ast( return List[Any] first_type = _infer_type_from_ast(node.elts[0], func_def, context_name + "Item") - + for elt in node.elts[1:]: current_type = _infer_type_from_ast(elt, func_def, context_name + "Item") if current_type != first_type: return List[Any] - + if first_type is not Any: return List[first_type] return List[Any] @@ -291,7 +290,7 @@ def _infer_type_from_ast( left_type = _infer_type_from_ast(node.left, func_def, context_name) right_type = _infer_type_from_ast(node.right, func_def, context_name) if left_type == right_type and left_type in (int, float, str): - return left_type + return left_type if {left_type, right_type} == {int, float}: return float @@ -351,60 +350,63 @@ def infer_response_model_from_ast( source = inspect.getsource(endpoint_function) except (OSError, TypeError): return None - + source = inspect.cleandoc(source) try: tree = ast.parse(source) except SyntaxError: return None - + if not tree.body: return None - + func_def = tree.body[0] if not isinstance(func_def, (ast.FunctionDef, ast.AsyncFunctionDef)): return None - + return_stmt = None - nodes_to_visit = list(func_def.body) while nodes_to_visit: node = nodes_to_visit.pop(0) - + if isinstance(node, ast.Return): return_stmt = node break - + if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)): continue - + for child in ast.iter_child_nodes(node): nodes_to_visit.append(child) - + if not return_stmt: return None returned_value = return_stmt.value dict_node = None - + if isinstance(returned_value, ast.Dict): dict_node = returned_value elif isinstance(returned_value, ast.Name): variable_name = returned_value.id # Find assignment for node in func_def.body: - if isinstance(node, ast.AnnAssign) and isinstance(node.target, ast.Name) and node.target.id == variable_name: - if isinstance(node.value, ast.Dict): - dict_node = node.value - break + if ( + isinstance(node, ast.AnnAssign) + and isinstance(node.target, ast.Name) + and node.target.id == variable_name + ): + if isinstance(node.value, ast.Dict): + dict_node = node.value + break elif isinstance(node, ast.Assign): - for target in node.targets: - if isinstance(target, ast.Name) and target.id == variable_name: - if isinstance(node.value, ast.Dict): - dict_node = node.value - break - + for target in node.targets: + if isinstance(target, ast.Name) and target.id == variable_name: + if isinstance(node.value, ast.Dict): + dict_node = node.value + break + if not dict_node: return None @@ -412,16 +414,16 @@ def infer_response_model_from_ast( for key, value in zip(dict_node.keys, dict_node.values): if not isinstance(key, ast.Constant): continue - + if not isinstance(key.value, str): return None - + field_name = key.value - + field_type = _infer_type_from_ast( value, func_def, f"{endpoint_function.__name__}_{field_name}" ) - + fields[field_name] = (field_type, ...) if not fields: diff --git a/tests/test_ast_inference.py b/tests/test_ast_inference.py index e24cf6b54..1ece8dbb9 100644 --- a/tests/test_ast_inference.py +++ b/tests/test_ast_inference.py @@ -1,10 +1,12 @@ -from typing import Any, Dict, List +from typing import Any, Dict + from fastapi import FastAPI from fastapi.responses import JSONResponse from fastapi.testclient import TestClient app = FastAPI() + @app.get("/users/{user_id}") async def get_user(user_id: int) -> Dict[str, Any]: user: Dict[str, Any] = { @@ -16,6 +18,7 @@ async def get_user(user_id: int) -> Dict[str, Any]: } return user + @app.get("/orders/{order_id}") async def get_order_details(order_id: str) -> Dict[str, Any]: order_data: Dict[str, Any] = { @@ -40,49 +43,41 @@ async def get_order_details(order_id: str) -> Dict[str, Any]: } return order_data + @app.get("/edge_cases/mixed_types") async def get_mixed_types() -> Dict[str, Any]: - return { - "mixed_list": [1, "two", 3.0], - "description": "List starting with int" - } + return {"mixed_list": [1, "two", 3.0], "description": "List starting with int"} + @app.get("/edge_cases/expressions") async def get_expressions() -> Dict[str, Any]: - return { - "calc_int": 10 + 5, - "calc_str": "foo" + "bar", - "calc_bool": 5 > 3 - } + return {"calc_int": 10 + 5, "calc_str": "foo" + "bar", "calc_bool": 5 > 3} + @app.get("/edge_cases/empty_structures") async def get_empty_structures() -> Dict[str, Any]: - return { - "empty_list": [], - "empty_dict": {} - } + return {"empty_list": [], "empty_dict": {}} + @app.get("/edge_cases/local_variable") async def get_local_variable() -> Dict[str, Any]: - response_data = { - "status": "ok", - "nested": { - "check": True - } - } + response_data = {"status": "ok", "nested": {"check": True}} return response_data + @app.get("/edge_cases/explicit_response") def get_explicit_response() -> JSONResponse: return JSONResponse({"should_not_be_inferred": True}) + @app.get("/edge_cases/nested_function") async def get_nested_function() -> Dict[str, Any]: def inner_function(): return {"inner": "value"} - + return {"outer": "value"} + @app.get("/edge_cases/invalid_keys") def get_invalid_keys() -> Dict[Any, Any]: return {1: "value", "valid": "key"} @@ -92,38 +87,44 @@ class FakeDB: def get_user(self) -> Dict[str, Any]: return {"id": 1, "username": "db_user"} + fake_db = FakeDB() + @app.get("/db/direct_return") def get_db_direct() -> Dict[str, Any]: return fake_db.get_user() + @app.get("/db/dict_construction") def get_db_constructed() -> Dict[str, Any]: data = fake_db.get_user() - return { - "db_id": data["id"], - "source": "database" - } + return {"db_id": data["id"], "source": "database"} + client = TestClient(app) + def test_openapi_schema_ast_inference(): response = client.get("/openapi.json") assert response.status_code == 200 schema = response.json() paths = schema["paths"] - user_schema = paths["/users/{user_id}"]["get"]["responses"]["200"]["content"]["application/json"]["schema"] + user_schema = paths["/users/{user_id}"]["get"]["responses"]["200"]["content"][ + "application/json" + ]["schema"] assert "$ref" in user_schema ref_name = user_schema["$ref"].split("/")[-1] user_props = schema["components"]["schemas"][ref_name]["properties"] - + assert user_props["id"]["type"] == "integer" assert user_props["username"]["type"] == "string" assert user_props["is_active"]["type"] == "boolean" - order_schema = paths["/orders/{order_id}"]["get"]["responses"]["200"]["content"]["application/json"]["schema"] + order_schema = paths["/orders/{order_id}"]["get"]["responses"]["200"]["content"][ + "application/json" + ]["schema"] assert "$ref" in order_schema order_ref = order_schema["$ref"].split("/")[-1] order_props = schema["components"]["schemas"][order_ref]["properties"] @@ -135,39 +136,57 @@ def test_openapi_schema_ast_inference(): customer_prop = order_props["customer_info"] assert "$ref" in customer_prop - mixed_schema = paths["/edge_cases/mixed_types"]["get"]["responses"]["200"]["content"]["application/json"]["schema"] + mixed_schema = paths["/edge_cases/mixed_types"]["get"]["responses"]["200"][ + "content" + ]["application/json"]["schema"] mixed_ref = mixed_schema["$ref"].split("/")[-1] mixed_props = schema["components"]["schemas"][mixed_ref]["properties"] assert mixed_props["mixed_list"]["type"] == "array" - expr_schema = paths["/edge_cases/expressions"]["get"]["responses"]["200"]["content"]["application/json"]["schema"] + expr_schema = paths["/edge_cases/expressions"]["get"]["responses"]["200"][ + "content" + ]["application/json"]["schema"] expr_ref = expr_schema["$ref"].split("/")[-1] expr_props = schema["components"]["schemas"][expr_ref]["properties"] - + assert expr_props["calc_int"]["type"] == "integer" assert expr_props["calc_bool"]["type"] == "boolean" - explicit_schema = paths["/edge_cases/explicit_response"]["get"]["responses"]["200"]["content"]["application/json"]["schema"] + explicit_schema = paths["/edge_cases/explicit_response"]["get"]["responses"]["200"][ + "content" + ]["application/json"]["schema"] assert "$ref" not in explicit_schema - nested_schema = paths["/edge_cases/nested_function"]["get"]["responses"]["200"]["content"]["application/json"]["schema"] + nested_schema = paths["/edge_cases/nested_function"]["get"]["responses"]["200"][ + "content" + ]["application/json"]["schema"] assert "$ref" in nested_schema nested_ref = nested_schema["$ref"].split("/")[-1] nested_props = schema["components"]["schemas"][nested_ref]["properties"] assert "outer" in nested_props assert "inner" not in nested_props - invalid_keys_schema = paths["/edge_cases/invalid_keys"]["get"]["responses"]["200"]["content"]["application/json"]["schema"] + invalid_keys_schema = paths["/edge_cases/invalid_keys"]["get"]["responses"]["200"][ + "content" + ]["application/json"]["schema"] assert "$ref" not in invalid_keys_schema - db_direct_schema = paths["/db/direct_return"]["get"]["responses"]["200"]["content"]["application/json"]["schema"] + db_direct_schema = paths["/db/direct_return"]["get"]["responses"]["200"]["content"][ + "application/json" + ]["schema"] assert "$ref" not in db_direct_schema - db_constructed_schema = paths["/db/dict_construction"]["get"]["responses"]["200"]["content"]["application/json"]["schema"] + db_constructed_schema = paths["/db/dict_construction"]["get"]["responses"]["200"][ + "content" + ]["application/json"]["schema"] assert "$ref" in db_constructed_schema db_constructed_ref = db_constructed_schema["$ref"].split("/")[-1] - db_constructed_props = schema["components"]["schemas"][db_constructed_ref]["properties"] - - assert db_constructed_props["source"]["type"] == "string" - assert "type" not in db_constructed_props["db_id"] or db_constructed_props["db_id"] == {} + db_constructed_props = schema["components"]["schemas"][db_constructed_ref][ + "properties" + ] + assert db_constructed_props["source"]["type"] == "string" + assert ( + "type" not in db_constructed_props["db_id"] + or db_constructed_props["db_id"] == {} + )