# Mirror of https://github.com/tiangolo/fastapi.git (509 lines, 16 KiB, Python)
from typing import Any, Dict, List, Union
|
|
from unittest.mock import patch
|
|
|
|
import pytest
|
|
from fastapi import FastAPI, Response
|
|
from fastapi.responses import JSONResponse
|
|
from fastapi.testclient import TestClient
|
|
from fastapi.utils import infer_response_model_from_ast
|
|
|
|
# Application under test: every endpoint below is registered here, and the
# OpenAPI assertions inspect the schema this app generates.
app = FastAPI()

# Fixture endpoint: returns a flat dict held in an *annotated* local variable.
# test_openapi_schema_ast_inference expects a $ref schema whose "id" is an
# integer, "username" a string and "is_active" a boolean, so the literal
# values and the annotated-assignment shape are part of the test input.
@app.get("/users/{user_id}")
async def get_user(user_id: int) -> Dict[str, Any]:
    user: Dict[str, Any] = {
        "id": user_id,
        "username": "example",
        "email": "user@example.com",
        "age": 25,
        "is_active": True,
    }
    return user

# Fixture endpoint: deeply nested payload. The OpenAPI test expects
# "customer_info" (a nested dict) to become a $ref model and "items"
# (a list of dicts) to become an array whose items are a $ref model.
# "metadata" is None to exercise a value with no useful type.
@app.get("/orders/{order_id}")
async def get_order_details(order_id: str) -> Dict[str, Any]:
    order_data: Dict[str, Any] = {
        "order_id": order_id,
        "status": "processing",
        "total_amount": 150.50,
        "tags": ["urgent", "new_customer"],
        "customer_info": {
            "name": "John Doe",
            "vip_status": False,
            "preferences": {"notifications": True, "theme": "dark"},
        },
        "items": [
            {
                "item_id": 1,
                "name": "Laptop Stand",
                "price": 45.00,
                "in_stock": True,
            },
        ],
        "metadata": None,
    }
    return order_data

# Fixture endpoint: a heterogeneous list (int, str, float). The OpenAPI test
# only requires "mixed_list" to be typed as an array.
@app.get("/edge_cases/mixed_types")
async def get_mixed_types() -> Dict[str, Any]:
    return {"mixed_list": [1, "two", 3.0], "description": "List starting with int"}

# Fixture endpoint: values are expressions rather than literals (int + int,
# str + str, a comparison). The OpenAPI test asserts "calc_int" is an integer
# and "calc_bool" a boolean; "calc_str" is not asserted.
@app.get("/edge_cases/expressions")
async def get_expressions() -> Dict[str, Any]:
    return {"calc_int": 10 + 5, "calc_str": "foo" + "bar", "calc_bool": 5 > 3}

# Fixture endpoint: empty containers give the inference no element type.
# The OpenAPI test expects an array without an item schema and an untyped
# (or plain "object") dict.
@app.get("/edge_cases/empty_structures")
async def get_empty_structures() -> Dict[str, Any]:
    return {"empty_list": [], "empty_dict": {}}

# Fixture endpoint: returns an *unannotated* local variable, so inference has
# to follow the assignment. The nested dict is expected to become a $ref.
@app.get("/edge_cases/local_variable")
async def get_local_variable() -> Dict[str, Any]:
    response_data = {"status": "ok", "nested": {"check": True}}
    return response_data

# Fixture endpoint: returns a Response subclass directly. The OpenAPI test
# asserts no model ($ref) is inferred for it, even though the JSONResponse
# is built from a dict literal.
@app.get("/edge_cases/explicit_response")
def get_explicit_response() -> JSONResponse:
    return JSONResponse({"should_not_be_inferred": True})

# Fixture endpoint: contains a nested function with its own `return`.
# The OpenAPI test asserts the inferred model has "outer" but not "inner",
# i.e. returns inside nested scopes must be ignored.
@app.get("/edge_cases/nested_function")
async def get_nested_function() -> Dict[str, Any]:
    def inner_function():
        return {"inner": "value"}

    inner_function()  # Call to ensure coverage
    return {"outer": "value"}

# Fixture endpoint: one dict key is an int, which cannot be a field name.
# The OpenAPI test asserts no model ($ref) is inferred.
@app.get("/edge_cases/invalid_keys")
def get_invalid_keys() -> Dict[Any, Any]:
    return {1: "value", "valid": "key"}

class FakeDB:
    """In-memory stand-in for a data source used by the ``/db/*`` endpoints."""

    def get_user(self) -> Dict[str, Any]:
        """Produce a fresh copy of a fixed user record on every call."""
        record: Dict[str, Any] = dict(id=1, username="db_user")
        return record

# Shared instance read by the /db/* endpoints; their return values come from
# a method call, which the AST inference cannot look through.
fake_db = FakeDB()

# Fixture endpoint: returns a call result directly. The OpenAPI test asserts
# no model ($ref) is inferred.
@app.get("/db/direct_return")
def get_db_direct() -> Dict[str, Any]:
    return fake_db.get_user()

# Fixture endpoint: a dict literal mixing a value of unknown type
# (`data["id"]`, subscripting a call result) with a string literal.
# The OpenAPI test expects a model where "source" is a string and "db_id"
# carries no "type".
@app.get("/db/dict_construction")
def get_db_constructed() -> Dict[str, Any]:
    data = fake_db.get_user()
    return {"db_id": data["id"], "source": "database"}

# Fixture endpoint: lists whose elements all share one type; both fields are
# expected to be typed as arrays.
@app.get("/edge_cases/homogeneous_list")
def get_homogeneous_list() -> Dict[str, Any]:
    return {"numbers": [1, 2, 3], "strings": ["a", "b", "c"]}

# Fixture endpoint: int + float is expected to infer "number", while
# int + int infers "integer".
@app.get("/edge_cases/int_float_binop")
def get_int_float_binop() -> Dict[str, Any]:
    return {"result": 10 + 5.5, "int_result": 10 + 5}

# Fixture endpoint: field values come straight from annotated parameters, so
# each field's type should mirror the parameter annotation
# (int -> integer, str -> string, bool -> boolean, float -> number).
@app.get("/edge_cases/arg_types/{a}")
def get_arg_types(a: int, b: str, c: bool, d: float) -> Dict[str, Any]:
    return {"int_val": a, "str_val": b, "bool_val": c, "float_val": d}

|
|
|
|
|
|
# Module-level test client shared by the OpenAPI and endpoint-execution tests.
client = TestClient(app)

# Inference fixture: the body has no `return` statement at all, so inference
# is expected to return None. The source shape is the test input, so this
# helper is documented with comments (not a docstring, which would add an
# AST statement).
def _test_no_return_func() -> Dict[str, Any]:
    x = {"a": 1}  # noqa: F841

# Inference fixture: returns the result of a method call rather than a
# literal; expected to infer nothing (None).
def _test_returns_call() -> Dict[str, Any]:
    return {}.copy()

# Inference fixture: an empty dict literal has no fields; expected None.
def _test_returns_empty_dict() -> Dict[str, Any]:
    return {}

# Inference fixture: a plain dict literal; expected fields "name" and "value".
def _test_returns_dict_literal() -> Dict[str, Any]:
    return {"name": "test", "value": 123}

# Inference fixture: returns an unannotated local variable; inference must
# follow the assignment to find field "status".
def _test_returns_variable() -> Dict[str, Any]:
    data = {"status": "ok"}
    return data

# Inference fixture: returns an annotated local variable; expected fields
# "status" and "count".
def _test_returns_annotated_var() -> Dict[str, Any]:
    data: Dict[str, Any] = {"status": "ok", "count": 42}
    return data

# Inference fixture: one value comes from an annotated parameter, the other is
# a literal; expected fields "typed_field" and "literal_field".
def _test_func_mixed(item: int) -> Dict[str, Any]:
    return {"typed_field": item, "literal_field": "hello"}

# Inference fixture: a list whose only element is an `Any`-annotated
# parameter; expected field "items".
def _test_list_with_any_elements(x: Any) -> Dict[str, Any]:
    return {"items": [x]}

# Inference fixture: one key is a variable, not a constant. Only "static" is
# expected in the inferred model.
def _test_non_constant_key() -> Dict[str, Any]:
    key = "dynamic"
    return {key: "value", "static": "ok"}

# Inference fixture: value typed by a bare builtin `list` annotation; expected
# field "items_val". (Contrast: a subscripted `List[int]` annotation is
# expected to yield None; see _test_arg_complex_annotation.)
def _test_list_arg(items: list) -> Dict[str, Any]:
    return {"items_val": items}

# Inference fixture: value typed by a bare builtin `dict` annotation; expected
# field "data_val".
def _test_dict_arg(data: dict) -> Dict[str, Any]:
    return {"data_val": data}

# Inference fixture: a nested dict literal; expected field "nested".
def _test_nested_dict() -> Dict[str, Any]:
    return {"nested": {"inner": "value"}}

# Inference fixture: the nested dict contains a variable key. The outer field
# "nested" is still expected to be inferred.
def _test_nested_dict_with_var_key() -> Dict[str, Any]:
    key = "dynamic"
    return {"nested": {key: "value", "static": "ok"}}

# Module-level names referenced by _test_all_any_fields. That helper is
# expected to infer no model, presumably because inference only sees the
# names, not these values (TODO confirm against the implementation).
some_global_var = "global"
another_global = 123

# Inference fixture: every value is a bare name (one local, two module
# globals) with no literal or annotation to type it; expected None.
def _test_all_any_fields() -> Dict[str, Any]:
    local_var = "local"
    return {"field1": local_var, "field2": some_global_var, "field3": another_global}

# Inference fixture: the dunder key "__class__" is not a usable model field
# name; inference must skip it or give up (return None) without raising.
def _test_invalid_field_name() -> Dict[str, Any]:
    return {"__class__": "invalid", "normal": "ok"}

def _response_schema(paths: Dict[str, Any], path: str) -> Dict[str, Any]:
    """Return the JSON schema of the 200 ``application/json`` response of ``GET path``."""
    return paths[path]["get"]["responses"]["200"]["content"]["application/json"][
        "schema"
    ]


def _ref_properties(
    openapi: Dict[str, Any], response_schema: Dict[str, Any]
) -> Dict[str, Any]:
    """Resolve a ``$ref`` response schema and return the referenced ``properties``.

    The ``$ref`` is asserted first so a missing inferred model fails with a
    clear assertion instead of a bare ``KeyError``.
    """
    assert "$ref" in response_schema
    ref_name = response_schema["$ref"].split("/")[-1]
    return openapi["components"]["schemas"][ref_name]["properties"]


def test_openapi_schema_ast_inference():
    """Check the OpenAPI schemas produced from AST-inferred response models."""
    response = client.get("/openapi.json")
    assert response.status_code == 200
    schema = response.json()
    paths = schema["paths"]

    # Flat dict held in an annotated local variable.
    user_props = _ref_properties(
        schema, _response_schema(paths, "/users/{user_id}")
    )
    assert user_props["id"]["type"] == "integer"
    assert user_props["username"]["type"] == "string"
    assert user_props["is_active"]["type"] == "boolean"

    # Nested dicts become models; a list of dicts becomes an array of models.
    order_props = _ref_properties(
        schema, _response_schema(paths, "/orders/{order_id}")
    )
    items_prop = order_props["items"]
    assert items_prop["type"] == "array"
    assert "$ref" in items_prop["items"]
    assert "$ref" in order_props["customer_info"]

    mixed_props = _ref_properties(
        schema, _response_schema(paths, "/edge_cases/mixed_types")
    )
    assert mixed_props["mixed_list"]["type"] == "array"

    expr_props = _ref_properties(
        schema, _response_schema(paths, "/edge_cases/expressions")
    )
    assert expr_props["calc_int"]["type"] == "integer"
    assert expr_props["calc_bool"]["type"] == "boolean"

    # An explicit Response return must not be turned into a model.
    assert "$ref" not in _response_schema(paths, "/edge_cases/explicit_response")

    # Returns inside nested functions are ignored.
    nested_props = _ref_properties(
        schema, _response_schema(paths, "/edge_cases/nested_function")
    )
    assert "outer" in nested_props
    assert "inner" not in nested_props

    # Non-string keys and direct call results cannot be inferred.
    assert "$ref" not in _response_schema(paths, "/edge_cases/invalid_keys")
    assert "$ref" not in _response_schema(paths, "/db/direct_return")

    db_constructed_props = _ref_properties(
        schema, _response_schema(paths, "/db/dict_construction")
    )
    assert db_constructed_props["source"]["type"] == "string"
    # `data["id"]` has no inferable type, so the field schema carries no "type".
    assert "type" not in db_constructed_props["db_id"]

    homogeneous_props = _ref_properties(
        schema, _response_schema(paths, "/edge_cases/homogeneous_list")
    )
    assert homogeneous_props["numbers"]["type"] == "array"
    assert homogeneous_props["strings"]["type"] == "array"

    binop_props = _ref_properties(
        schema, _response_schema(paths, "/edge_cases/int_float_binop")
    )
    assert binop_props["result"]["type"] == "number"
    assert binop_props["int_result"]["type"] == "integer"

    # Field types mirror the endpoint's parameter annotations.
    arg_types_props = _ref_properties(
        schema, _response_schema(paths, "/edge_cases/arg_types/{a}")
    )
    assert arg_types_props["int_val"]["type"] == "integer"
    assert arg_types_props["str_val"]["type"] == "string"
    assert arg_types_props["bool_val"]["type"] == "boolean"
    assert arg_types_props["float_val"]["type"] == "number"

    # Empty containers: an array with no item schema, and an untyped or
    # plain-object dict.
    empty_props = _ref_properties(
        schema, _response_schema(paths, "/edge_cases/empty_structures")
    )
    assert empty_props["empty_list"]["type"] == "array"
    assert not empty_props["empty_list"].get("items")
    empty_dict_prop = empty_props["empty_dict"]
    assert "type" not in empty_dict_prop or empty_dict_prop["type"] == "object"

    # Inference follows an unannotated local variable, including nested dicts.
    local_props = _ref_properties(
        schema, _response_schema(paths, "/edge_cases/local_variable")
    )
    assert local_props["status"]["type"] == "string"
    assert "$ref" in local_props["nested"]

@pytest.mark.parametrize(
    "func",
    [
        _test_no_return_func,
        _test_returns_call,
        _test_returns_empty_dict,
        _test_all_any_fields,
    ],
)
def test_infer_response_model_returns_none(func):
    """AST inference produces no model for these helpers.

    Covered shapes: no ``return`` at all, a returned call result, an empty
    dict literal, and a dict whose values are all bare names.
    """
    inferred = infer_response_model_from_ast(func)
    assert inferred is None

def test_infer_response_model_returns_none_for_lambdas_and_builtins():
    """Test cases where AST inference cannot get source code."""
    # For a lambda, inspect.getsource returns the whole source line it sits
    # on, not a function definition, so the exact text of the next line is
    # part of what is being tested. Keep the lambda inline in the assert.
    assert infer_response_model_from_ast(lambda: {"a": 1}) is None
    # Builtins such as `len` have no Python source at all.
    assert infer_response_model_from_ast(len) is None

@pytest.mark.parametrize(
    "func,expected_fields",
    [
        (_test_returns_dict_literal, ["name", "value"]),
        (_test_returns_variable, ["status"]),
        (_test_returns_annotated_var, ["status", "count"]),
        (_test_func_mixed, ["typed_field", "literal_field"]),
        (_test_list_with_any_elements, ["items"]),
        (_test_non_constant_key, ["static"]),
        (_test_list_arg, ["items_val"]),
        (_test_dict_arg, ["data_val"]),
        (_test_nested_dict, ["nested"]),
        (_test_nested_dict_with_var_key, ["nested"]),
    ],
)
def test_infer_response_model_success(func, expected_fields):
    """Inference builds a model that declares every expected field."""
    model = infer_response_model_from_ast(func)
    assert model is not None
    missing = [name for name in expected_fields if name not in model.__annotations__]
    assert not missing

def test_infer_response_model_invalid_field_name():
    """Invalid field names are handled gracefully (skipped, or inference gives up).

    ``"__class__"`` is not a usable model field name. Inference may either
    drop it and keep the valid ``"normal"`` field, or return ``None``. It must
    never raise, and must never expose ``__class__`` as a model field.
    """
    _test_invalid_field_name()  # execute the helper body for coverage
    # The previous version only called the helper and never exercised
    # inference, so it could not fail.
    result = infer_response_model_from_ast(_test_invalid_field_name)
    if result is not None:
        assert "__class__" not in result.__annotations__
        assert "normal" in result.__annotations__

# Inference fixture: returns an unannotated parameter's value, leaving the
# field with no type source; expected to yield None.
def _test_arg_no_annotation(a):
    return {"x": a}

def test_infer_type_from_ast_arg_no_annotation():
    """Returning an unannotated parameter's value yields no inferred model."""
    _test_arg_no_annotation(1)  # run the helper body for coverage
    inferred = infer_response_model_from_ast(_test_arg_no_annotation)
    assert inferred is None

# Inference fixture: the parameter uses a subscripted generic annotation
# (`List[int]`); expected to yield None, unlike the bare `list` annotation
# in _test_list_arg, which is inferred.
def _test_arg_complex_annotation(a: List[int]):
    return {"x": a}

def test_infer_type_from_ast_arg_complex_annotation():
    """A subscripted generic parameter annotation (``List[int]``) yields no model."""
    _test_arg_complex_annotation([1])  # run the helper body for coverage
    inferred = infer_response_model_from_ast(_test_arg_complex_annotation)
    assert inferred is None

def test_infer_response_model_getsource_error():
    """If ``inspect.getsource`` raises ``OSError``, inference returns ``None``.

    ``func`` has no ``return`` of its own, so a ``None`` result alone would
    not prove the patched path ran. ``assert_called`` guards against a
    vacuous pass.
    """

    def func():
        pass

    func()

    with patch("inspect.getsource", side_effect=OSError) as mocked:
        assert infer_response_model_from_ast(func) is None
    mocked.assert_called()

def test_infer_response_model_syntax_error():
    """Unparseable source yields ``None`` instead of propagating ``SyntaxError``.

    ``assert_called`` confirms the patched source was actually consulted,
    because ``func`` itself would yield ``None`` anyway (it has no ``return``).
    """

    def func():
        pass

    func()

    with patch("inspect.getsource", return_value="def func( invalid") as mocked:
        assert infer_response_model_from_ast(func) is None
    mocked.assert_called()

def test_infer_response_model_no_body():
    """Source that parses to an empty module (comments only) yields ``None``.

    ``assert_called`` confirms the patched source was actually consulted,
    because ``func`` itself would yield ``None`` anyway (it has no ``return``).
    """

    def func():
        pass

    func()

    with patch("inspect.getsource", return_value="# comments") as mocked:
        assert infer_response_model_from_ast(func) is None
    mocked.assert_called()

def test_infer_response_model_not_function_def():
    """Source that is a class definition rather than a function yields ``None``.

    ``assert_called`` confirms the patched source was actually consulted,
    because ``func`` itself would yield ``None`` anyway (it has no ``return``).
    """

    def func():
        pass

    func()

    with patch("inspect.getsource", return_value="class A: pass") as mocked:
        assert infer_response_model_from_ast(func) is None
    mocked.assert_called()

def test_infer_response_model_create_model_error():
    """If building the model raises, inference swallows it and returns ``None``."""
    from fastapi.utils import PYDANTIC_V2

    # `func` must keep a dict-literal return so inference reaches model creation.
    def func():
        return {"a": 1}

    func()

    if PYDANTIC_V2:
        target = "pydantic.create_model"
    else:
        target = "fastapi._compat.v1.create_model"

    with patch(target, side_effect=Exception("Boom")):
        assert infer_response_model_from_ast(func) is None

def test_contains_response() -> None:
    """``_contains_response`` detects Response classes, including inside (nested) unions."""
    from fastapi.routing import _contains_response

    expectations = [
        (Response, True),
        (JSONResponse, True),
        (str, False),
        (Dict[str, Any], False),
        (Union[Response, dict], True),
        (Union[str, int], False),
        (Union[str, Union[Response, int]], True),
        (List[str], False),
    ]
    for annotation, expected in expectations:
        assert _contains_response(annotation) is expected

def test_execution_of_endpoints():
    """
    Call all endpoints so their bodies are executed (coverage).

    Each call must also succeed with HTTP 200. Previously the responses were
    discarded, so a rejected request (e.g. a 422 for bad parameters) went
    unnoticed.
    """
    calls = [
        ("/users/1", None),
        ("/orders/order1", None),
        ("/edge_cases/mixed_types", None),
        ("/edge_cases/expressions", None),
        ("/edge_cases/empty_structures", None),
        ("/edge_cases/local_variable", None),
        ("/edge_cases/explicit_response", None),
        ("/edge_cases/nested_function", None),
        ("/edge_cases/invalid_keys", None),
        ("/db/direct_return", None),
        ("/db/dict_construction", None),
        ("/edge_cases/homogeneous_list", None),
        ("/edge_cases/int_float_binop", None),
        ("/edge_cases/arg_types/1", {"b": "foo", "c": "true", "d": "1.5"}),
    ]
    for url, params in calls:
        response = client.get(url, params=params)
        assert response.status_code == 200, url

def test_execution_of_helpers():
    """
    Run every helper fixture once, in declaration order, so its body is executed (coverage).
    """
    invocations = [
        (_test_no_return_func, ()),
        (_test_returns_call, ()),
        (_test_returns_empty_dict, ()),
        (_test_returns_dict_literal, ()),
        (_test_returns_variable, ()),
        (_test_returns_annotated_var, ()),
        (_test_func_mixed, (1,)),
        (_test_list_with_any_elements, (1,)),
        (_test_non_constant_key, ()),
        (_test_list_arg, ([],)),
        (_test_dict_arg, ({},)),
        (_test_nested_dict, ()),
        (_test_nested_dict_with_var_key, ()),
        (_test_all_any_fields, ()),
        (_test_invalid_field_name, ()),
    ]
    for helper, args in invocations:
        helper(*args)