from typing import Any, Dict, List, Union from unittest.mock import patch import pytest from fastapi import FastAPI, Response from fastapi.responses import JSONResponse from fastapi.testclient import TestClient from fastapi.utils import infer_response_model_from_ast app = FastAPI() @app.get("/users/{user_id}") async def get_user(user_id: int) -> Dict[str, Any]: user: Dict[str, Any] = { "id": user_id, "username": "example", "email": "user@example.com", "age": 25, "is_active": True, } return user @app.get("/orders/{order_id}") async def get_order_details(order_id: str) -> Dict[str, Any]: order_data: Dict[str, Any] = { "order_id": order_id, "status": "processing", "total_amount": 150.50, "tags": ["urgent", "new_customer"], "customer_info": { "name": "John Doe", "vip_status": False, "preferences": {"notifications": True, "theme": "dark"}, }, "items": [ { "item_id": 1, "name": "Laptop Stand", "price": 45.00, "in_stock": True, }, ], "metadata": None, } return order_data @app.get("/edge_cases/mixed_types") async def get_mixed_types() -> Dict[str, Any]: return {"mixed_list": [1, "two", 3.0], "description": "List starting with int"} @app.get("/edge_cases/expressions") async def get_expressions() -> Dict[str, Any]: return {"calc_int": 10 + 5, "calc_str": "foo" + "bar", "calc_bool": 5 > 3} @app.get("/edge_cases/empty_structures") async def get_empty_structures() -> Dict[str, Any]: return {"empty_list": [], "empty_dict": {}} @app.get("/edge_cases/local_variable") async def get_local_variable() -> Dict[str, Any]: response_data = {"status": "ok", "nested": {"check": True}} return response_data @app.get("/edge_cases/explicit_response") def get_explicit_response() -> JSONResponse: return JSONResponse({"should_not_be_inferred": True}) @app.get("/edge_cases/nested_function") async def get_nested_function() -> Dict[str, Any]: def inner_function(): return {"inner": "value"} inner_function() # Call to ensure coverage return {"outer": "value"} @app.get("/edge_cases/invalid_keys") def get_invalid_keys() -> Dict[Any, Any]: return {1: "value", "valid": "key"} class FakeDB: def get_user(self) -> Dict[str, Any]: return {"id": 1, "username": "db_user"} fake_db = FakeDB() @app.get("/db/direct_return") def get_db_direct() -> Dict[str, Any]: return fake_db.get_user() @app.get("/db/dict_construction") def get_db_constructed() -> Dict[str, Any]: data = fake_db.get_user() return {"db_id": data["id"], "source": "database"} @app.get("/edge_cases/homogeneous_list") def get_homogeneous_list() -> Dict[str, Any]: return {"numbers": [1, 2, 3], "strings": ["a", "b", "c"]} @app.get("/edge_cases/int_float_binop") def get_int_float_binop() -> Dict[str, Any]: return {"result": 10 + 5.5, "int_result": 10 + 5} @app.get("/edge_cases/arg_types/{a}") def get_arg_types(a: int, b: str, c: bool, d: float) -> Dict[str, Any]: return {"int_val": a, "str_val": b, "bool_val": c, "float_val": d} client = TestClient(app) def _test_no_return_func() -> Dict[str, Any]: x = {"a": 1} # noqa: F841 def _test_returns_call() -> Dict[str, Any]: return {}.copy() def _test_returns_empty_dict() -> Dict[str, Any]: return {} def _test_returns_dict_literal() -> Dict[str, Any]: return {"name": "test", "value": 123} def _test_returns_variable() -> Dict[str, Any]: data = {"status": "ok"} return data def _test_returns_annotated_var() -> Dict[str, Any]: data: Dict[str, Any] = {"status": "ok", "count": 42} return data def _test_func_mixed(item: int) -> Dict[str, Any]: return {"typed_field": item, "literal_field": "hello"} def _test_list_with_any_elements(x: Any) -> Dict[str, Any]: return {"items": [x]} def _test_non_constant_key() -> Dict[str, Any]: key = "dynamic" return {key: "value", "static": "ok"} def _test_list_arg(items: list) -> Dict[str, Any]: return {"items_val": items} def _test_dict_arg(data: dict) -> Dict[str, Any]: return {"data_val": data} def _test_nested_dict() -> Dict[str, Any]: return {"nested": {"inner": "value"}} def _test_nested_dict_with_var_key() -> Dict[str, Any]: key = "dynamic" return {"nested": {key: "value", "static": "ok"}} some_global_var = "global" another_global = 123 def _test_all_any_fields() -> Dict[str, Any]: local_var = "local" return {"field1": local_var, "field2": some_global_var, "field3": another_global} def _test_invalid_field_name() -> Dict[str, Any]: return {"__class__": "invalid", "normal": "ok"} def test_openapi_schema_ast_inference(): response = client.get("/openapi.json") assert response.status_code == 200 schema = response.json() paths = schema["paths"] user_schema = paths["/users/{user_id}"]["get"]["responses"]["200"]["content"][ "application/json" ]["schema"] assert "$ref" in user_schema ref_name = user_schema["$ref"].split("/")[-1] user_props = schema["components"]["schemas"][ref_name]["properties"] assert user_props["id"]["type"] == "integer" assert user_props["username"]["type"] == "string" assert user_props["is_active"]["type"] == "boolean" order_schema = paths["/orders/{order_id}"]["get"]["responses"]["200"]["content"][ "application/json" ]["schema"] assert "$ref" in order_schema order_ref = order_schema["$ref"].split("/")[-1] order_props = schema["components"]["schemas"][order_ref]["properties"] items_prop = order_props["items"] assert items_prop["type"] == "array" assert "$ref" in items_prop["items"] customer_prop = order_props["customer_info"] assert "$ref" in customer_prop mixed_schema = paths["/edge_cases/mixed_types"]["get"]["responses"]["200"][ "content" ]["application/json"]["schema"] mixed_ref = mixed_schema["$ref"].split("/")[-1] mixed_props = schema["components"]["schemas"][mixed_ref]["properties"] assert mixed_props["mixed_list"]["type"] == "array" expr_schema = paths["/edge_cases/expressions"]["get"]["responses"]["200"][ "content" ]["application/json"]["schema"] expr_ref = expr_schema["$ref"].split("/")[-1] expr_props = schema["components"]["schemas"][expr_ref]["properties"] assert expr_props["calc_int"]["type"] == "integer" assert expr_props["calc_bool"]["type"] == "boolean" explicit_schema = paths["/edge_cases/explicit_response"]["get"]["responses"]["200"][ "content" ]["application/json"]["schema"] assert "$ref" not in explicit_schema nested_schema = paths["/edge_cases/nested_function"]["get"]["responses"]["200"][ "content" ]["application/json"]["schema"] assert "$ref" in nested_schema nested_ref = nested_schema["$ref"].split("/")[-1] nested_props = schema["components"]["schemas"][nested_ref]["properties"] assert "outer" in nested_props assert "inner" not in nested_props invalid_keys_schema = paths["/edge_cases/invalid_keys"]["get"]["responses"]["200"][ "content" ]["application/json"]["schema"] assert "$ref" not in invalid_keys_schema db_direct_schema = paths["/db/direct_return"]["get"]["responses"]["200"]["content"][ "application/json" ]["schema"] assert "$ref" not in db_direct_schema db_constructed_schema = paths["/db/dict_construction"]["get"]["responses"]["200"][ "content" ]["application/json"]["schema"] assert "$ref" in db_constructed_schema db_constructed_ref = db_constructed_schema["$ref"].split("/")[-1] db_constructed_props = schema["components"]["schemas"][db_constructed_ref][ "properties" ] assert db_constructed_props["source"]["type"] == "string" assert ( "type" not in db_constructed_props["db_id"] or db_constructed_props["db_id"] == {} ) homogeneous_schema = paths["/edge_cases/homogeneous_list"]["get"]["responses"][ "200" ]["content"]["application/json"]["schema"] assert "$ref" in homogeneous_schema homogeneous_ref = homogeneous_schema["$ref"].split("/")[-1] homogeneous_props = schema["components"]["schemas"][homogeneous_ref]["properties"] assert homogeneous_props["numbers"]["type"] == "array" assert homogeneous_props["strings"]["type"] == "array" binop_schema = paths["/edge_cases/int_float_binop"]["get"]["responses"]["200"][ "content" ]["application/json"]["schema"] assert "$ref" in binop_schema binop_ref = binop_schema["$ref"].split("/")[-1] binop_props = schema["components"]["schemas"][binop_ref]["properties"] assert binop_props["result"]["type"] == "number" assert binop_props["int_result"]["type"] == "integer" arg_types_schema = paths["/edge_cases/arg_types/{a}"]["get"]["responses"]["200"][ "content" ]["application/json"]["schema"] assert "$ref" in arg_types_schema arg_types_ref = arg_types_schema["$ref"].split("/")[-1] arg_types_props = schema["components"]["schemas"][arg_types_ref]["properties"] assert arg_types_props["int_val"]["type"] == "integer" assert arg_types_props["str_val"]["type"] == "string" assert arg_types_props["bool_val"]["type"] == "boolean" assert arg_types_props["float_val"]["type"] == "number" empty_structures_schema = paths["/edge_cases/empty_structures"]["get"]["responses"][ "200" ]["content"]["application/json"]["schema"] assert "$ref" in empty_structures_schema empty_structures_ref = empty_structures_schema["$ref"].split("/")[-1] empty_structures_props = schema["components"]["schemas"][empty_structures_ref][ "properties" ] assert empty_structures_props["empty_list"]["type"] == "array" assert "items" not in empty_structures_props[ "empty_list" ] or not empty_structures_props["empty_list"].get("items") assert ( "type" not in empty_structures_props["empty_dict"] or empty_structures_props["empty_dict"]["type"] == "object" ) local_variable_schema = paths["/edge_cases/local_variable"]["get"]["responses"][ "200" ]["content"]["application/json"]["schema"] assert "$ref" in local_variable_schema local_variable_ref = local_variable_schema["$ref"].split("/")[-1] local_variable_props = schema["components"]["schemas"][local_variable_ref][ "properties" ] assert local_variable_props["status"]["type"] == "string" assert "$ref" in local_variable_props["nested"] @pytest.mark.parametrize( "func", [ _test_no_return_func, _test_returns_call, _test_returns_empty_dict, _test_all_any_fields, ], ) def test_infer_response_model_returns_none(func): """Test cases where AST inference should return None.""" assert infer_response_model_from_ast(func) is None def test_infer_response_model_returns_none_for_lambdas_and_builtins(): """Test cases where AST inference cannot get source code.""" assert infer_response_model_from_ast(lambda: {"a": 1}) is None assert infer_response_model_from_ast(len) is None @pytest.mark.parametrize( "func,expected_fields", [ (_test_returns_dict_literal, ["name", "value"]), (_test_returns_variable, ["status"]), (_test_returns_annotated_var, ["status", "count"]), (_test_func_mixed, ["typed_field", "literal_field"]), (_test_list_with_any_elements, ["items"]), (_test_non_constant_key, ["static"]), (_test_list_arg, ["items_val"]), (_test_dict_arg, ["data_val"]), (_test_nested_dict, ["nested"]), (_test_nested_dict_with_var_key, ["nested"]), ], ) def test_infer_response_model_success(func, expected_fields): """Test cases where AST inference should succeed and return a model with specific fields.""" result = infer_response_model_from_ast(func) assert result is not None for field in expected_fields: assert field in result.__annotations__ def test_infer_response_model_invalid_field_name(): """Test that invalid field names are handled gracefully (either skipped or model creation fails safely).""" _test_invalid_field_name() def _test_arg_no_annotation(a): return {"x": a} def test_infer_type_from_ast_arg_no_annotation(): _test_arg_no_annotation(1) assert infer_response_model_from_ast(_test_arg_no_annotation) is None def _test_arg_complex_annotation(a: List[int]): return {"x": a} def test_infer_type_from_ast_arg_complex_annotation(): _test_arg_complex_annotation([1]) assert infer_response_model_from_ast(_test_arg_complex_annotation) is None def test_infer_response_model_getsource_error(): def func(): pass func() with patch("inspect.getsource", side_effect=OSError): assert infer_response_model_from_ast(func) is None def test_infer_response_model_syntax_error(): def func(): pass func() with patch("inspect.getsource", return_value="def func( invalid"): assert infer_response_model_from_ast(func) is None def test_infer_response_model_no_body(): def func(): pass func() with patch("inspect.getsource", return_value="# comments"): assert infer_response_model_from_ast(func) is None def test_infer_response_model_not_function_def(): def func(): pass func() with patch("inspect.getsource", return_value="class A: pass"): assert infer_response_model_from_ast(func) is None def test_infer_response_model_create_model_error(): def func(): return {"a": 1} func() from fastapi.utils import PYDANTIC_V2 target = ( "pydantic.create_model" if PYDANTIC_V2 else "fastapi._compat.v1.create_model" ) with patch(target, side_effect=Exception("Boom")): assert infer_response_model_from_ast(func) is None def test_contains_response() -> None: from fastapi.routing import _contains_response assert _contains_response(Response) is True assert _contains_response(JSONResponse) is True assert _contains_response(str) is False assert _contains_response(Dict[str, Any]) is False assert _contains_response(Union[Response, dict]) is True assert _contains_response(Union[str, int]) is False assert _contains_response(Union[str, Union[Response, int]]) is True assert _contains_response(List[str]) is False def test_execution_of_endpoints(): """ Call all endpoints to ensure their bodies are executed (coverage). """ client.get("/users/1") client.get("/orders/order1") client.get("/edge_cases/mixed_types") client.get("/edge_cases/expressions") client.get("/edge_cases/empty_structures") client.get("/edge_cases/local_variable") client.get("/edge_cases/explicit_response") client.get("/edge_cases/nested_function") client.get("/edge_cases/invalid_keys") client.get("/db/direct_return") client.get("/db/dict_construction") client.get("/edge_cases/homogeneous_list") client.get("/edge_cases/int_float_binop") client.get("/edge_cases/arg_types/1", params={"b": "foo", "c": "true", "d": "1.5"}) def test_execution_of_helpers(): """ Call all helper functions to ensure their bodies are executed (coverage). """ _test_no_return_func() _test_returns_call() _test_returns_empty_dict() _test_returns_dict_literal() _test_returns_variable() _test_returns_annotated_var() _test_func_mixed(1) _test_list_with_any_elements(1) _test_non_constant_key() _test_list_arg([]) _test_dict_arg({}) _test_nested_dict() _test_nested_dict_with_var_key() _test_all_any_fields() _test_invalid_field_name()