mirror of https://github.com/tiangolo/fastapi.git
Enhance AST inference logging and handle multiple return statements
- Added detailed logging for various failure cases in the `infer_response_model_from_ast` function to aid debugging. - Modified the logic to collect all return statements, ensuring that inference is skipped if multiple returns are detected. - Introduced a new test case to verify that functions with multiple return statements correctly return None for inference.
This commit is contained in:
parent
ce0990107d
commit
d5268d7cf5
|
|
@ -31,6 +31,7 @@ from fastapi._compat import (
|
|||
may_v1,
|
||||
)
|
||||
from fastapi.datastructures import DefaultPlaceholder, DefaultType
|
||||
from fastapi.logger import logger
|
||||
from pydantic import BaseModel
|
||||
from pydantic.fields import FieldInfo
|
||||
from typing_extensions import Literal
|
||||
|
|
@ -356,15 +357,24 @@ def infer_response_model_from_ast(
|
|||
"""
|
||||
import textwrap
|
||||
|
||||
func_name = getattr(endpoint_function, "__name__", "<unknown>")
|
||||
|
||||
try:
|
||||
source = inspect.getsource(endpoint_function)
|
||||
except (OSError, TypeError):
|
||||
logger.debug(
|
||||
f"AST inference skipped for '{func_name}': "
|
||||
"could not retrieve source code"
|
||||
)
|
||||
return None
|
||||
|
||||
source = textwrap.dedent(source)
|
||||
try:
|
||||
tree = ast.parse(source)
|
||||
except SyntaxError:
|
||||
logger.debug(
|
||||
f"AST inference skipped for '{func_name}': " "syntax error in source code"
|
||||
)
|
||||
return None
|
||||
|
||||
if not tree.body:
|
||||
|
|
@ -374,15 +384,16 @@ def infer_response_model_from_ast(
|
|||
if not isinstance(func_def, (ast.FunctionDef, ast.AsyncFunctionDef)):
|
||||
return None
|
||||
|
||||
return_stmt = None
|
||||
# Collect ALL return statements (not just the first one)
|
||||
return_statements: List[ast.Return] = []
|
||||
|
||||
nodes_to_visit: List[ast.AST] = list(func_def.body)
|
||||
while nodes_to_visit:
|
||||
node = nodes_to_visit.pop(0)
|
||||
|
||||
if isinstance(node, ast.Return):
|
||||
return_stmt = node
|
||||
break
|
||||
return_statements.append(node)
|
||||
# Don't break - continue to find all returns
|
||||
|
||||
if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
|
||||
continue
|
||||
|
|
@ -390,9 +401,22 @@ def infer_response_model_from_ast(
|
|||
for child in ast.iter_child_nodes(node):
|
||||
nodes_to_visit.append(child)
|
||||
|
||||
if not return_stmt:
|
||||
if not return_statements:
|
||||
logger.debug(
|
||||
f"AST inference skipped for '{func_name}': " "no return statement found"
|
||||
)
|
||||
return None
|
||||
|
||||
# If there are multiple return statements, skip inference to avoid
|
||||
# misleading documentation (we can't reliably determine the structure)
|
||||
if len(return_statements) > 1:
|
||||
logger.debug(
|
||||
f"AST inference skipped for '{func_name}': "
|
||||
f"multiple return statements detected ({len(return_statements)})"
|
||||
)
|
||||
return None
|
||||
|
||||
return_stmt = return_statements[0]
|
||||
returned_value = return_stmt.value
|
||||
dict_node = None
|
||||
|
||||
|
|
@ -418,6 +442,10 @@ def infer_response_model_from_ast(
|
|||
break
|
||||
|
||||
if not dict_node:
|
||||
logger.debug(
|
||||
f"AST inference skipped for '{func_name}': "
|
||||
"return value is not a dict literal or assigned variable"
|
||||
)
|
||||
return None
|
||||
|
||||
fields = {}
|
||||
|
|
@ -426,23 +454,34 @@ def infer_response_model_from_ast(
|
|||
continue
|
||||
|
||||
if not isinstance(key.value, str):
|
||||
logger.debug(
|
||||
f"AST inference skipped for '{func_name}': "
|
||||
"non-string key found in dict"
|
||||
)
|
||||
return None
|
||||
|
||||
field_name = key.value
|
||||
|
||||
field_type = _infer_type_from_ast(
|
||||
value, func_def, f"{endpoint_function.__name__}_{field_name}"
|
||||
value, func_def, f"{func_name}_{field_name}"
|
||||
)
|
||||
|
||||
fields[field_name] = (field_type, ...)
|
||||
|
||||
if not fields:
|
||||
logger.debug(
|
||||
f"AST inference skipped for '{func_name}': " "no fields could be inferred"
|
||||
)
|
||||
return None
|
||||
|
||||
# Don't create a model if all fields are Any - this provides no additional
|
||||
# type information compared to Dict[str, Any] and would override explicit
|
||||
# type annotations unnecessarily
|
||||
if all(field_type is Any for field_type, _ in fields.values()):
|
||||
logger.debug(
|
||||
f"AST inference skipped for '{func_name}': "
|
||||
"all fields resolved to Any type"
|
||||
)
|
||||
return None
|
||||
|
||||
if PYDANTIC_V2:
|
||||
|
|
@ -450,8 +489,12 @@ def infer_response_model_from_ast(
|
|||
else:
|
||||
from fastapi._compat.v1 import create_model
|
||||
|
||||
model_name = f"ResponseModel_{endpoint_function.__name__}"
|
||||
model_name = f"ResponseModel_{func_name}"
|
||||
try:
|
||||
return create_model(model_name, **fields) # type: ignore[call-overload,no-any-return]
|
||||
except Exception:
|
||||
except Exception as e:
|
||||
logger.debug(
|
||||
f"AST inference skipped for '{func_name}': "
|
||||
f"failed to create model: {e}"
|
||||
)
|
||||
return None
|
||||
|
|
|
|||
|
|
@ -385,6 +385,20 @@ def test_infer_response_model_invalid_field_name():
|
|||
_test_invalid_field_name()
|
||||
|
||||
|
||||
def _test_multiple_returns(flag: bool):
|
||||
if flag:
|
||||
return {"spam": "spam"}
|
||||
else:
|
||||
return {"eggs": "eggs"}
|
||||
|
||||
|
||||
def test_infer_response_model_multiple_returns():
|
||||
"""Test that multiple return statements result in None (no inference)."""
|
||||
_test_multiple_returns(True)
|
||||
_test_multiple_returns(False)
|
||||
assert infer_response_model_from_ast(_test_multiple_returns) is None
|
||||
|
||||
|
||||
def _test_arg_no_annotation(a):
|
||||
return {"x": a}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue