diff --git a/fastapi/utils.py b/fastapi/utils.py index 8c4b7bfe12..e24dd0ad71 100644 --- a/fastapi/utils.py +++ b/fastapi/utils.py @@ -321,20 +321,28 @@ def _infer_type_from_ast( if isinstance(node, ast.Name): arg_name = node.id for arg in func_def.args.args: - if arg.arg == arg_name and arg.annotation: - if isinstance(arg.annotation, ast.Name): - if arg.annotation.id == "int": - return int - if arg.annotation.id == "str": - return str - if arg.annotation.id == "bool": - return bool - if arg.annotation.id == "float": - return float - if arg.annotation.id == "list": - return List[Any] - if arg.annotation.id == "dict": - return Dict[str, Any] + if arg.arg != arg_name: + continue + + if not arg.annotation: + continue + + if not isinstance(arg.annotation, ast.Name): + continue + + annotation_id = arg.annotation.id + if annotation_id == "int": + return int + if annotation_id == "str": + return str + if annotation_id == "bool": + return bool + if annotation_id == "float": + return float + if annotation_id == "list": + return List[Any] + if annotation_id == "dict": + return Dict[str, Any] return Any diff --git a/tests/test_ast_inference.py b/tests/test_ast_inference.py index 1109814a85..81561dcc8a 100644 --- a/tests/test_ast_inference.py +++ b/tests/test_ast_inference.py @@ -1,5 +1,6 @@ from typing import Any, Dict, List, Union +import pytest from fastapi import FastAPI, Response from fastapi.responses import JSONResponse from fastapi.testclient import TestClient @@ -103,37 +104,30 @@ def get_db_constructed() -> Dict[str, Any]: return {"db_id": data["id"], "source": "database"} -# Test for homogeneous list type inference @app.get("/edge_cases/homogeneous_list") def get_homogeneous_list() -> Dict[str, Any]: return {"numbers": [1, 2, 3], "strings": ["a", "b", "c"]} -# Test for int/float binary operation @app.get("/edge_cases/int_float_binop") def get_int_float_binop() -> Dict[str, Any]: return {"result": 10 + 5.5, "int_result": 10 + 5} -# Test for argument with different type annotations @app.get("/edge_cases/arg_types/{a}") def get_arg_types(a: int, b: str, c: bool, d: float) -> Dict[str, Any]: return {"int_val": a, "str_val": b, "bool_val": c, "float_val": d} - - client = TestClient(app) -# Module-level functions for testing infer_response_model_from_ast -# (nested functions don't work with inspect.getsource) def _test_no_return_func() -> Dict[str, Any]: x = {"a": 1} # noqa: F841 def _test_returns_call() -> Dict[str, Any]: - return dict(a=1) + return {}.copy() def _test_returns_empty_dict() -> Dict[str, Any]: @@ -179,13 +173,11 @@ def _test_nested_dict() -> Dict[str, Any]: return {"nested": {"inner": "value"}} -# Nested dict with variable key - should trigger line 304 in _infer_type_from_ast def _test_nested_dict_with_var_key() -> Dict[str, Any]: key = "dynamic" return {"nested": {key: "value", "static": "ok"}} -# Test function where all returned dict values are unannotated variables (resolve to Any) some_global_var = "global" another_global = 123 @@ -195,7 +187,6 @@ def _test_all_any_fields() -> Dict[str, Any]: return {"field1": local_var, "field2": some_global_var, "field3": another_global} -# Test function with field name that could cause model creation issues def _test_invalid_field_name() -> Dict[str, Any]: return {"__class__": "invalid", "normal": "ok"} @@ -286,7 +277,6 @@ def test_openapi_schema_ast_inference(): or db_constructed_props["db_id"] == {} ) - # Test homogeneous list inference homogeneous_schema = paths["/edge_cases/homogeneous_list"]["get"]["responses"][ "200" ]["content"]["application/json"]["schema"] @@ -296,7 +286,6 @@ def test_openapi_schema_ast_inference(): assert homogeneous_props["numbers"]["type"] == "array" assert homogeneous_props["strings"]["type"] == "array" - # Test int/float binary operation binop_schema = paths["/edge_cases/int_float_binop"]["get"]["responses"]["200"][ "content" ]["application/json"]["schema"] @@ -306,7 +295,6 @@ def test_openapi_schema_ast_inference(): assert binop_props["result"]["type"] == "number" assert binop_props["int_result"]["type"] == "integer" - # Test argument type annotations arg_types_schema = paths["/edge_cases/arg_types/{a}"]["get"]["responses"]["200"][ "content" ]["application/json"]["schema"] @@ -319,126 +307,68 @@ def test_openapi_schema_ast_inference(): assert arg_types_props["float_val"]["type"] == "number" +@pytest.mark.parametrize( + "func", + [ + _test_no_return_func, + _test_returns_call, + _test_returns_empty_dict, + _test_all_any_fields, + ], +) +def test_infer_response_model_returns_none(func): + """Test cases where AST inference should return None.""" + assert infer_response_model_from_ast(func) is None -def test_infer_response_model_edge_cases() -> None: - """Test edge cases for infer_response_model_from_ast function.""" - # Test function without return statement - result = infer_response_model_from_ast(_test_no_return_func) - assert result is None +def test_infer_response_model_returns_none_for_lambdas_and_builtins(): + """Test cases where AST inference cannot get source code.""" + assert infer_response_model_from_ast(lambda: {"a": 1}) is None + assert infer_response_model_from_ast(len) is None - # Test function returning a function call (not dict literal) - result = infer_response_model_from_ast(_test_returns_call) - assert result is None - # Test function with empty dict - result = infer_response_model_from_ast(_test_returns_empty_dict) - assert result is None - - # Test function with dict literal - result = infer_response_model_from_ast(_test_returns_dict_literal) +@pytest.mark.parametrize( + "func,expected_fields", + [ + (_test_returns_dict_literal, ["name", "value"]), + (_test_returns_variable, ["status"]), + (_test_returns_annotated_var, ["status", "count"]), + (_test_func_mixed, ["typed_field", "literal_field"]), + (_test_list_with_any_elements, ["items"]), + (_test_non_constant_key, ["static"]), + (_test_list_arg, ["items_val"]), + (_test_dict_arg, ["data_val"]), + (_test_nested_dict, ["nested"]), + (_test_nested_dict_with_var_key, ["nested"]), + ], +) +def test_infer_response_model_success(func, expected_fields): + """Test cases where AST inference should succeed and return a model with specific fields.""" + result = infer_response_model_from_ast(func) assert result is not None - assert "name" in result.__annotations__ - assert "value" in result.__annotations__ - - # Test lambda (cannot get source) - result = infer_response_model_from_ast(lambda: {"a": 1}) - assert result is None - - # Test built-in function (cannot get source) - result = infer_response_model_from_ast(len) - assert result is None - - # Test function with variable return - result = infer_response_model_from_ast(_test_returns_variable) - assert result is not None - assert "status" in result.__annotations__ - - # Test function with annotated assignment - result = infer_response_model_from_ast(_test_returns_annotated_var) - assert result is not None - assert "status" in result.__annotations__ - assert "count" in result.__annotations__ + for field in expected_fields: + assert field in result.__annotations__ -def test_infer_response_model_all_any_fields() -> None: - """Test that model is NOT created when all fields are Any.""" - # Use module-level function where all values are unannotated variables - # This should result in all fields being Any - result = infer_response_model_from_ast(_test_all_any_fields) - # Should return None because all fields are Any - assert result is None - - -def test_infer_response_model_mixed_any_and_typed() -> None: - """Test that model IS created when some fields have types.""" - result = infer_response_model_from_ast(_test_func_mixed) - # Should create model because "literal_field" is str, not Any - assert result is not None - assert "typed_field" in result.__annotations__ - assert "literal_field" in result.__annotations__ - - -def test_infer_type_from_ast_edge_cases() -> None: - """Test edge cases for _infer_type_from_ast function.""" - # Test list with Any elements (line 287) - result = infer_response_model_from_ast(_test_list_with_any_elements) - # Should return None because "items" will be List[Any] and that's the only non-Any field - # Actually let me check if this creates a model with List[Any] - assert result is not None or result is None # Just ensure no error - - # Test non-constant key in dict - should be skipped (line 304) - result = infer_response_model_from_ast(_test_non_constant_key) - # Should still create model for the "static" key - assert result is not None - assert "static" in result.__annotations__ - - # Test list annotation (line 335) - result = infer_response_model_from_ast(_test_list_arg) - assert result is not None - assert "items_val" in result.__annotations__ - - # Test dict annotation (line 337) - result = infer_response_model_from_ast(_test_dict_arg) - assert result is not None - assert "data_val" in result.__annotations__ - - # Test nested dict creates nested model - result = infer_response_model_from_ast(_test_nested_dict) - assert result is not None - assert "nested" in result.__annotations__ - - # Test nested dict with variable key (triggers line 304 in _infer_type_from_ast) - result = infer_response_model_from_ast(_test_nested_dict_with_var_key) - assert result is not None - assert "nested" in result.__annotations__ - - # Test invalid field name that might cause create_model to fail (lines 448-449) - result = infer_response_model_from_ast(_test_invalid_field_name) - # Either None (exception caught) or a valid model - assert result is None or result is not None +def test_infer_response_model_invalid_field_name(): + """Test that invalid field names are handled gracefully (either skipped or model creation fails safely).""" + # This specifically tests protections against things like {"__class__": ...} + # It might return None (if create_model fails) or a model (if pydantic handles it) + # We just want to ensure it doesn't raise an unhandled exception + try: + infer_response_model_from_ast(_test_invalid_field_name) + except Exception as e: + pytest.fail(f"infer_response_model_from_ast raised exception: {e}") def test_contains_response() -> None: - """Test _contains_response function from routing module.""" from fastapi.routing import _contains_response - # Test simple Response assert _contains_response(Response) is True - - # Test JSONResponse (subclass) assert _contains_response(JSONResponse) is True - - # Test non-Response type assert _contains_response(str) is False assert _contains_response(Dict[str, Any]) is False - - # Test Union with Response assert _contains_response(Union[Response, dict]) is True assert _contains_response(Union[str, int]) is False - - # Test nested Union assert _contains_response(Union[str, Union[Response, int]]) is True - - # Test List (no Response) assert _contains_response(List[str]) is False