diff --git a/fastapi/utils.py b/fastapi/utils.py index becc3bb6d..3683c4a4e 100644 --- a/fastapi/utils.py +++ b/fastapi/utils.py @@ -31,6 +31,7 @@ from fastapi._compat import ( may_v1, ) from fastapi.datastructures import DefaultPlaceholder, DefaultType +from fastapi.logger import logger from pydantic import BaseModel from pydantic.fields import FieldInfo from typing_extensions import Literal @@ -356,15 +357,24 @@ def infer_response_model_from_ast( """ import textwrap + func_name = getattr(endpoint_function, "__name__", "") + try: source = inspect.getsource(endpoint_function) except (OSError, TypeError): + logger.debug( + f"AST inference skipped for '{func_name}': " + "could not retrieve source code" + ) return None source = textwrap.dedent(source) try: tree = ast.parse(source) except SyntaxError: + logger.debug( + f"AST inference skipped for '{func_name}': " "syntax error in source code" + ) return None if not tree.body: @@ -374,15 +384,16 @@ def infer_response_model_from_ast( if not isinstance(func_def, (ast.FunctionDef, ast.AsyncFunctionDef)): return None - return_stmt = None + # Collect ALL return statements (not just the first one) + return_statements: List[ast.Return] = [] nodes_to_visit: List[ast.AST] = list(func_def.body) while nodes_to_visit: node = nodes_to_visit.pop(0) if isinstance(node, ast.Return): - return_stmt = node - break + return_statements.append(node) + # Don't break - continue to find all returns if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)): continue @@ -390,9 +401,22 @@ def infer_response_model_from_ast( for child in ast.iter_child_nodes(node): nodes_to_visit.append(child) - if not return_stmt: + if not return_statements: + logger.debug( + f"AST inference skipped for '{func_name}': " "no return statement found" + ) return None + # If there are multiple return statements, skip inference to avoid + # misleading documentation (we can't reliably determine the structure) + if len(return_statements) > 1: + logger.debug( + f"AST inference skipped for '{func_name}': " + f"multiple return statements detected ({len(return_statements)})" + ) + return None + + return_stmt = return_statements[0] returned_value = return_stmt.value dict_node = None @@ -418,6 +442,10 @@ def infer_response_model_from_ast( break if not dict_node: + logger.debug( + f"AST inference skipped for '{func_name}': " + "return value is not a dict literal or assigned variable" + ) return None fields = {} @@ -426,23 +454,34 @@ def infer_response_model_from_ast( continue if not isinstance(key.value, str): + logger.debug( + f"AST inference skipped for '{func_name}': " + "non-string key found in dict" + ) return None field_name = key.value field_type = _infer_type_from_ast( - value, func_def, f"{endpoint_function.__name__}_{field_name}" + value, func_def, f"{func_name}_{field_name}" ) fields[field_name] = (field_type, ...) if not fields: + logger.debug( + f"AST inference skipped for '{func_name}': " "no fields could be inferred" + ) return None # Don't create a model if all fields are Any - this provides no additional # type information compared to Dict[str, Any] and would override explicit # type annotations unnecessarily if all(field_type is Any for field_type, _ in fields.values()): + logger.debug( + f"AST inference skipped for '{func_name}': " + "all fields resolved to Any type" + ) return None if PYDANTIC_V2: @@ -450,8 +489,12 @@ def infer_response_model_from_ast( else: from fastapi._compat.v1 import create_model - model_name = f"ResponseModel_{endpoint_function.__name__}" + model_name = f"ResponseModel_{func_name}" try: return create_model(model_name, **fields) # type: ignore[call-overload,no-any-return] - except Exception: + except Exception as e: + logger.debug( + f"AST inference skipped for '{func_name}': " + f"failed to create model: {e}" + ) return None diff --git a/tests/test_ast_inference.py b/tests/test_ast_inference.py index 70270f4d1..65305076a 100644 --- a/tests/test_ast_inference.py +++ b/tests/test_ast_inference.py @@ -385,6 +385,20 @@ def test_infer_response_model_invalid_field_name(): _test_invalid_field_name() +def _test_multiple_returns(flag: bool): + if flag: + return {"spam": "spam"} + else: + return {"eggs": "eggs"} + + +def test_infer_response_model_multiple_returns(): + """Test that multiple return statements result in None (no inference).""" + _test_multiple_returns(True) + _test_multiple_returns(False) + assert infer_response_model_from_ast(_test_multiple_returns) is None + + def _test_arg_no_annotation(a): return {"x": a}