diff --git a/fastapi/dependencies/utils.py b/fastapi/dependencies/utils.py index 6b14dac8dc..cc347d89c6 100644 --- a/fastapi/dependencies/utils.py +++ b/fastapi/dependencies/utils.py @@ -1,3 +1,4 @@ +import ast import dataclasses import inspect import sys @@ -60,7 +61,7 @@ from fastapi.logger import logger from fastapi.security.oauth2 import SecurityScopes from fastapi.types import DependencyCacheKey from fastapi.utils import create_model_field, get_path_param_names -from pydantic import BaseModel, Json +from pydantic import BaseModel, Json, PydanticUndefinedAnnotation from pydantic.fields import FieldInfo from starlette.background import BackgroundTasks as StarletteBackgroundTasks from starlette.concurrency import run_in_threadpool @@ -242,10 +243,66 @@ def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature: return typed_signature +def _evaluate_typed_forwardref(annotation: str, globalns: dict[str, Any]) -> Any: + forward_ref = ForwardRef(annotation) + try: + return evaluate_forwardref(forward_ref, globalns, globalns) # ty: ignore[deprecated] + except (NameError, PydanticUndefinedAnnotation): + return forward_ref + + +def _get_stringified_annotated(annotation: str, globalns: dict[str, Any]) -> Any | None: + try: + expression = ast.parse(annotation, mode="eval").body + except SyntaxError: + return None + + if not isinstance(expression, ast.Subscript): + return None + + annotated_target = expression.value + is_annotated = ( + isinstance(annotated_target, ast.Name) + and annotated_target.id == "Annotated" + ) or ( + isinstance(annotated_target, ast.Attribute) + and annotated_target.attr == "Annotated" + ) + if not is_annotated: + return None + + annotated_items = ( + expression.slice.elts + if isinstance(expression.slice, ast.Tuple) + else [expression.slice] + ) + if len(annotated_items) < 2: + return None + + first_annotation = ast.unparse(annotated_items[0]) + annotated_args = [ + get_typed_annotation(first_annotation, globalns) + if "Annotated[" in first_annotation + else _evaluate_typed_forwardref(first_annotation, globalns) + ] + annotated_args.extend( + eval( + compile(ast.Expression(item), "", "eval"), + globalns, + globalns, + ) + for item in annotated_items[1:] + ) + return Annotated[tuple(annotated_args)] # ty: ignore[invalid-type-form] + + def get_typed_annotation(annotation: Any, globalns: dict[str, Any]) -> Any: if isinstance(annotation, str): - annotation = ForwardRef(annotation) - annotation = evaluate_forwardref(annotation, globalns, globalns) # ty: ignore[deprecated] + if "Annotated[" in annotation: + annotation = _get_stringified_annotated(annotation, globalns) or annotation + if isinstance(annotation, str): + annotation = ForwardRef(annotation) + annotation = evaluate_forwardref(annotation, globalns, globalns) # ty: ignore[deprecated] if annotation is type(None): return None return annotation diff --git a/tests/test_stringified_annotation_dependency.py b/tests/test_stringified_annotation_dependency.py index ce88074957..a0e921ce3f 100644 --- a/tests/test_stringified_annotation_dependency.py +++ b/tests/test_stringified_annotation_dependency.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Annotated +from typing import TYPE_CHECKING, Any, Annotated import pytest from fastapi import Depends, FastAPI @@ -77,3 +77,45 @@ def test_openapi_schema(client: TestClient): }, } ) + + +def test_openapi_schema_for_dependency_with_forward_ref_defined_later(): + namespace: dict[str, Any] = {} + exec( + """ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Annotated + +from fastapi import Depends, FastAPI + +app = FastAPI() + + +def get_potato() -> Potato: + return Potato(color="red", size=10) + + +@app.get("/") +async def read_root(potato: Annotated[Potato, Depends(get_potato)]): + return {"color": potato.color, "size": potato.size} + + +@dataclass +class Potato: + color: str + size: int +""", + namespace, + ) + + client = TestClient(namespace["app"]) + + response = client.get("/") + assert response.status_code == 200, response.text + assert response.json() == {"color": "red", "size": 10} + + response = client.get("/openapi.json") + assert response.status_code == 200, response.text + assert "requestBody" not in response.json()["paths"]["/"]["get"]