Enhance AST inference logging and handle multiple return statements

- Added detailed logging for various failure cases in the `infer_response_model_from_ast` function to aid debugging.
- Modified the logic to collect all return statements, ensuring that inference is skipped if multiple returns are detected.
- Introduced a new test case to verify that functions with multiple return statements correctly return None for inference.
This commit is contained in:
g7azazlo 2025-12-05 01:46:24 +03:00
parent ce0990107d
commit d5268d7cf5
2 changed files with 64 additions and 7 deletions

View File

@ -31,6 +31,7 @@ from fastapi._compat import (
may_v1,
)
from fastapi.datastructures import DefaultPlaceholder, DefaultType
from fastapi.logger import logger
from pydantic import BaseModel
from pydantic.fields import FieldInfo
from typing_extensions import Literal
@ -356,15 +357,24 @@ def infer_response_model_from_ast(
"""
import textwrap
func_name = getattr(endpoint_function, "__name__", "<unknown>")
try:
source = inspect.getsource(endpoint_function)
except (OSError, TypeError):
logger.debug(
f"AST inference skipped for '{func_name}': "
"could not retrieve source code"
)
return None
source = textwrap.dedent(source)
try:
tree = ast.parse(source)
except SyntaxError:
logger.debug(
f"AST inference skipped for '{func_name}': " "syntax error in source code"
)
return None
if not tree.body:
@ -374,15 +384,16 @@ def infer_response_model_from_ast(
if not isinstance(func_def, (ast.FunctionDef, ast.AsyncFunctionDef)):
return None
return_stmt = None
# Collect ALL return statements (not just the first one)
return_statements: List[ast.Return] = []
nodes_to_visit: List[ast.AST] = list(func_def.body)
while nodes_to_visit:
node = nodes_to_visit.pop(0)
if isinstance(node, ast.Return):
return_stmt = node
break
return_statements.append(node)
# Don't break - continue to find all returns
if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
continue
@ -390,9 +401,22 @@ def infer_response_model_from_ast(
for child in ast.iter_child_nodes(node):
nodes_to_visit.append(child)
if not return_stmt:
if not return_statements:
logger.debug(
f"AST inference skipped for '{func_name}': " "no return statement found"
)
return None
# If there are multiple return statements, skip inference to avoid
# misleading documentation (we can't reliably determine the structure)
if len(return_statements) > 1:
logger.debug(
f"AST inference skipped for '{func_name}': "
f"multiple return statements detected ({len(return_statements)})"
)
return None
return_stmt = return_statements[0]
returned_value = return_stmt.value
dict_node = None
@ -418,6 +442,10 @@ def infer_response_model_from_ast(
break
if not dict_node:
logger.debug(
f"AST inference skipped for '{func_name}': "
"return value is not a dict literal or assigned variable"
)
return None
fields = {}
@ -426,23 +454,34 @@ def infer_response_model_from_ast(
continue
if not isinstance(key.value, str):
logger.debug(
f"AST inference skipped for '{func_name}': "
"non-string key found in dict"
)
return None
field_name = key.value
field_type = _infer_type_from_ast(
value, func_def, f"{endpoint_function.__name__}_{field_name}"
value, func_def, f"{func_name}_{field_name}"
)
fields[field_name] = (field_type, ...)
if not fields:
logger.debug(
f"AST inference skipped for '{func_name}': " "no fields could be inferred"
)
return None
# Don't create a model if all fields are Any - this provides no additional
# type information compared to Dict[str, Any] and would override explicit
# type annotations unnecessarily
if all(field_type is Any for field_type, _ in fields.values()):
logger.debug(
f"AST inference skipped for '{func_name}': "
"all fields resolved to Any type"
)
return None
if PYDANTIC_V2:
@ -450,8 +489,12 @@ def infer_response_model_from_ast(
else:
from fastapi._compat.v1 import create_model
model_name = f"ResponseModel_{endpoint_function.__name__}"
model_name = f"ResponseModel_{func_name}"
try:
return create_model(model_name, **fields) # type: ignore[call-overload,no-any-return]
except Exception:
except Exception as e:
logger.debug(
f"AST inference skipped for '{func_name}': "
f"failed to create model: {e}"
)
return None

View File

@ -385,6 +385,20 @@ def test_infer_response_model_invalid_field_name():
_test_invalid_field_name()
def _test_multiple_returns(flag: bool):
if flag:
return {"spam": "spam"}
else:
return {"eggs": "eggs"}
def test_infer_response_model_multiple_returns():
"""Test that multiple return statements result in None (no inference)."""
_test_multiple_returns(True)
_test_multiple_returns(False)
assert infer_response_model_from_ast(_test_multiple_returns) is None
def _test_arg_no_annotation(a):
return {"x": a}