Refactor type inference logic in utils for improved clarity

- Simplified the `_infer_type_from_ast` function by consolidating conditional checks for argument annotations, enhancing readability and maintainability.
- Updated test cases in `tests/test_ast_inference.py` to utilize parameterization for better organization and coverage of edge cases in response model inference.
This commit is contained in:
g7azazlo 2025-12-04 00:28:00 +03:00
parent e03e232c39
commit bf90082191
2 changed files with 69 additions and 131 deletions

View File

@ -321,20 +321,28 @@ def _infer_type_from_ast(
if isinstance(node, ast.Name):
arg_name = node.id
for arg in func_def.args.args:
if arg.arg == arg_name and arg.annotation:
if isinstance(arg.annotation, ast.Name):
if arg.annotation.id == "int":
return int
if arg.annotation.id == "str":
return str
if arg.annotation.id == "bool":
return bool
if arg.annotation.id == "float":
return float
if arg.annotation.id == "list":
return List[Any]
if arg.annotation.id == "dict":
return Dict[str, Any]
if arg.arg != arg_name:
continue
if not arg.annotation:
continue
if not isinstance(arg.annotation, ast.Name):
continue
annotation_id = arg.annotation.id
if annotation_id == "int":
return int
if annotation_id == "str":
return str
if annotation_id == "bool":
return bool
if annotation_id == "float":
return float
if annotation_id == "list":
return List[Any]
if annotation_id == "dict":
return Dict[str, Any]
return Any

View File

@ -1,5 +1,6 @@
from typing import Any, Dict, List, Union
import pytest
from fastapi import FastAPI, Response
from fastapi.responses import JSONResponse
from fastapi.testclient import TestClient
@ -103,37 +104,30 @@ def get_db_constructed() -> Dict[str, Any]:
return {"db_id": data["id"], "source": "database"}
# Test for homogeneous list type inference
@app.get("/edge_cases/homogeneous_list")
def get_homogeneous_list() -> Dict[str, Any]:
return {"numbers": [1, 2, 3], "strings": ["a", "b", "c"]}
# Test for int/float binary operation
@app.get("/edge_cases/int_float_binop")
def get_int_float_binop() -> Dict[str, Any]:
return {"result": 10 + 5.5, "int_result": 10 + 5}
# Test for argument with different type annotations
@app.get("/edge_cases/arg_types/{a}")
def get_arg_types(a: int, b: str, c: bool, d: float) -> Dict[str, Any]:
return {"int_val": a, "str_val": b, "bool_val": c, "float_val": d}
client = TestClient(app)
# Module-level functions for testing infer_response_model_from_ast
# (nested functions don't work with inspect.getsource)
def _test_no_return_func() -> Dict[str, Any]:
x = {"a": 1} # noqa: F841
def _test_returns_call() -> Dict[str, Any]:
return dict(a=1)
return {}.copy()
def _test_returns_empty_dict() -> Dict[str, Any]:
@ -179,13 +173,11 @@ def _test_nested_dict() -> Dict[str, Any]:
return {"nested": {"inner": "value"}}
# Nested dict with variable key - should trigger line 304 in _infer_type_from_ast
def _test_nested_dict_with_var_key() -> Dict[str, Any]:
key = "dynamic"
return {"nested": {key: "value", "static": "ok"}}
# Test function where all returned dict values are unannotated variables (resolve to Any)
some_global_var = "global"
another_global = 123
@ -195,7 +187,6 @@ def _test_all_any_fields() -> Dict[str, Any]:
return {"field1": local_var, "field2": some_global_var, "field3": another_global}
# Test function with field name that could cause model creation issues
def _test_invalid_field_name() -> Dict[str, Any]:
return {"__class__": "invalid", "normal": "ok"}
@ -286,7 +277,6 @@ def test_openapi_schema_ast_inference():
or db_constructed_props["db_id"] == {}
)
# Test homogeneous list inference
homogeneous_schema = paths["/edge_cases/homogeneous_list"]["get"]["responses"][
"200"
]["content"]["application/json"]["schema"]
@ -296,7 +286,6 @@ def test_openapi_schema_ast_inference():
assert homogeneous_props["numbers"]["type"] == "array"
assert homogeneous_props["strings"]["type"] == "array"
# Test int/float binary operation
binop_schema = paths["/edge_cases/int_float_binop"]["get"]["responses"]["200"][
"content"
]["application/json"]["schema"]
@ -306,7 +295,6 @@ def test_openapi_schema_ast_inference():
assert binop_props["result"]["type"] == "number"
assert binop_props["int_result"]["type"] == "integer"
# Test argument type annotations
arg_types_schema = paths["/edge_cases/arg_types/{a}"]["get"]["responses"]["200"][
"content"
]["application/json"]["schema"]
@ -319,126 +307,68 @@ def test_openapi_schema_ast_inference():
assert arg_types_props["float_val"]["type"] == "number"
@pytest.mark.parametrize(
"func",
[
_test_no_return_func,
_test_returns_call,
_test_returns_empty_dict,
_test_all_any_fields,
],
)
def test_infer_response_model_returns_none(func):
"""Test cases where AST inference should return None."""
assert infer_response_model_from_ast(func) is None
def test_infer_response_model_edge_cases() -> None:
"""Test edge cases for infer_response_model_from_ast function."""
# Test function without return statement
result = infer_response_model_from_ast(_test_no_return_func)
assert result is None
def test_infer_response_model_returns_none_for_lambdas_and_builtins():
"""Test cases where AST inference cannot get source code."""
assert infer_response_model_from_ast(lambda: {"a": 1}) is None
assert infer_response_model_from_ast(len) is None
# Test function returning a function call (not dict literal)
result = infer_response_model_from_ast(_test_returns_call)
assert result is None
# Test function with empty dict
result = infer_response_model_from_ast(_test_returns_empty_dict)
assert result is None
# Test function with dict literal
result = infer_response_model_from_ast(_test_returns_dict_literal)
@pytest.mark.parametrize(
"func,expected_fields",
[
(_test_returns_dict_literal, ["name", "value"]),
(_test_returns_variable, ["status"]),
(_test_returns_annotated_var, ["status", "count"]),
(_test_func_mixed, ["typed_field", "literal_field"]),
(_test_list_with_any_elements, ["items"]),
(_test_non_constant_key, ["static"]),
(_test_list_arg, ["items_val"]),
(_test_dict_arg, ["data_val"]),
(_test_nested_dict, ["nested"]),
(_test_nested_dict_with_var_key, ["nested"]),
],
)
def test_infer_response_model_success(func, expected_fields):
"""Test cases where AST inference should succeed and return a model with specific fields."""
result = infer_response_model_from_ast(func)
assert result is not None
assert "name" in result.__annotations__
assert "value" in result.__annotations__
# Test lambda (cannot get source)
result = infer_response_model_from_ast(lambda: {"a": 1})
assert result is None
# Test built-in function (cannot get source)
result = infer_response_model_from_ast(len)
assert result is None
# Test function with variable return
result = infer_response_model_from_ast(_test_returns_variable)
assert result is not None
assert "status" in result.__annotations__
# Test function with annotated assignment
result = infer_response_model_from_ast(_test_returns_annotated_var)
assert result is not None
assert "status" in result.__annotations__
assert "count" in result.__annotations__
for field in expected_fields:
assert field in result.__annotations__
def test_infer_response_model_all_any_fields() -> None:
"""Test that model is NOT created when all fields are Any."""
# Use module-level function where all values are unannotated variables
# This should result in all fields being Any
result = infer_response_model_from_ast(_test_all_any_fields)
# Should return None because all fields are Any
assert result is None
def test_infer_response_model_mixed_any_and_typed() -> None:
"""Test that model IS created when some fields have types."""
result = infer_response_model_from_ast(_test_func_mixed)
# Should create model because "literal_field" is str, not Any
assert result is not None
assert "typed_field" in result.__annotations__
assert "literal_field" in result.__annotations__
def test_infer_type_from_ast_edge_cases() -> None:
"""Test edge cases for _infer_type_from_ast function."""
# Test list with Any elements (line 287)
result = infer_response_model_from_ast(_test_list_with_any_elements)
# Should return None because "items" will be List[Any] and that's the only non-Any field
# Actually let me check if this creates a model with List[Any]
assert result is not None or result is None # Just ensure no error
# Test non-constant key in dict - should be skipped (line 304)
result = infer_response_model_from_ast(_test_non_constant_key)
# Should still create model for the "static" key
assert result is not None
assert "static" in result.__annotations__
# Test list annotation (line 335)
result = infer_response_model_from_ast(_test_list_arg)
assert result is not None
assert "items_val" in result.__annotations__
# Test dict annotation (line 337)
result = infer_response_model_from_ast(_test_dict_arg)
assert result is not None
assert "data_val" in result.__annotations__
# Test nested dict creates nested model
result = infer_response_model_from_ast(_test_nested_dict)
assert result is not None
assert "nested" in result.__annotations__
# Test nested dict with variable key (triggers line 304 in _infer_type_from_ast)
result = infer_response_model_from_ast(_test_nested_dict_with_var_key)
assert result is not None
assert "nested" in result.__annotations__
# Test invalid field name that might cause create_model to fail (lines 448-449)
result = infer_response_model_from_ast(_test_invalid_field_name)
# Either None (exception caught) or a valid model
assert result is None or result is not None
def test_infer_response_model_invalid_field_name():
"""Test that invalid field names are handled gracefully (either skipped or model creation fails safely)."""
# This specifically tests protections against things like {"__class__": ...}
# It might return None (if create_model fails) or a model (if pydantic handles it)
# We just want to ensure it doesn't raise an unhandled exception
try:
infer_response_model_from_ast(_test_invalid_field_name)
except Exception as e:
pytest.fail(f"infer_response_model_from_ast raised exception: {e}")
def test_contains_response() -> None:
"""Test _contains_response function from routing module."""
from fastapi.routing import _contains_response
# Test simple Response
assert _contains_response(Response) is True
# Test JSONResponse (subclass)
assert _contains_response(JSONResponse) is True
# Test non-Response type
assert _contains_response(str) is False
assert _contains_response(Dict[str, Any]) is False
# Test Union with Response
assert _contains_response(Union[Response, dict]) is True
assert _contains_response(Union[str, int]) is False
# Test nested Union
assert _contains_response(Union[str, Union[Response, int]]) is True
# Test List (no Response)
assert _contains_response(List[str]) is False