mirror of https://github.com/tiangolo/fastapi.git
Fix stringified Annotated dependency forward refs
This commit is contained in:
parent
eb6851dd4b
commit
b6e6614249
|
|
@ -1,3 +1,4 @@
|
|||
import ast
|
||||
import dataclasses
|
||||
import inspect
|
||||
import sys
|
||||
|
|
@ -60,7 +61,7 @@ from fastapi.logger import logger
|
|||
from fastapi.security.oauth2 import SecurityScopes
|
||||
from fastapi.types import DependencyCacheKey
|
||||
from fastapi.utils import create_model_field, get_path_param_names
|
||||
from pydantic import BaseModel, Json
|
||||
from pydantic import BaseModel, Json, PydanticUndefinedAnnotation
|
||||
from pydantic.fields import FieldInfo
|
||||
from starlette.background import BackgroundTasks as StarletteBackgroundTasks
|
||||
from starlette.concurrency import run_in_threadpool
|
||||
|
|
@ -242,10 +243,66 @@ def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature:
|
|||
return typed_signature
|
||||
|
||||
|
||||
def _evaluate_typed_forwardref(annotation: str, globalns: dict[str, Any]) -> Any:
|
||||
forward_ref = ForwardRef(annotation)
|
||||
try:
|
||||
return evaluate_forwardref(forward_ref, globalns, globalns) # ty: ignore[deprecated]
|
||||
except (NameError, PydanticUndefinedAnnotation):
|
||||
return forward_ref
|
||||
|
||||
|
||||
def _get_stringified_annotated(annotation: str, globalns: dict[str, Any]) -> Any | None:
|
||||
try:
|
||||
expression = ast.parse(annotation, mode="eval").body
|
||||
except SyntaxError:
|
||||
return None
|
||||
|
||||
if not isinstance(expression, ast.Subscript):
|
||||
return None
|
||||
|
||||
annotated_target = expression.value
|
||||
is_annotated = (
|
||||
isinstance(annotated_target, ast.Name)
|
||||
and annotated_target.id == "Annotated"
|
||||
) or (
|
||||
isinstance(annotated_target, ast.Attribute)
|
||||
and annotated_target.attr == "Annotated"
|
||||
)
|
||||
if not is_annotated:
|
||||
return None
|
||||
|
||||
annotated_items = (
|
||||
expression.slice.elts
|
||||
if isinstance(expression.slice, ast.Tuple)
|
||||
else [expression.slice]
|
||||
)
|
||||
if len(annotated_items) < 2:
|
||||
return None
|
||||
|
||||
first_annotation = ast.unparse(annotated_items[0])
|
||||
annotated_args = [
|
||||
get_typed_annotation(first_annotation, globalns)
|
||||
if "Annotated[" in first_annotation
|
||||
else _evaluate_typed_forwardref(first_annotation, globalns)
|
||||
]
|
||||
annotated_args.extend(
|
||||
eval(
|
||||
compile(ast.Expression(item), "<fastapi_annotation>", "eval"),
|
||||
globalns,
|
||||
globalns,
|
||||
)
|
||||
for item in annotated_items[1:]
|
||||
)
|
||||
return Annotated[tuple(annotated_args)] # ty: ignore[invalid-type-form]
|
||||
|
||||
|
||||
def get_typed_annotation(annotation: Any, globalns: dict[str, Any]) -> Any:
|
||||
if isinstance(annotation, str):
|
||||
annotation = ForwardRef(annotation)
|
||||
annotation = evaluate_forwardref(annotation, globalns, globalns) # ty: ignore[deprecated]
|
||||
if "Annotated[" in annotation:
|
||||
annotation = _get_stringified_annotated(annotation, globalns) or annotation
|
||||
if isinstance(annotation, str):
|
||||
annotation = ForwardRef(annotation)
|
||||
annotation = evaluate_forwardref(annotation, globalns, globalns) # ty: ignore[deprecated]
|
||||
if annotation is type(None):
|
||||
return None
|
||||
return annotation
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Annotated
|
||||
from typing import TYPE_CHECKING, Any, Annotated
|
||||
|
||||
import pytest
|
||||
from fastapi import Depends, FastAPI
|
||||
|
|
@ -77,3 +77,45 @@ def test_openapi_schema(client: TestClient):
|
|||
},
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def test_openapi_schema_for_dependency_with_forward_ref_defined_later():
|
||||
namespace: dict[str, Any] = {}
|
||||
exec(
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import Depends, FastAPI
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
|
||||
def get_potato() -> Potato:
|
||||
return Potato(color="red", size=10)
|
||||
|
||||
|
||||
@app.get("/")
|
||||
async def read_root(potato: Annotated[Potato, Depends(get_potato)]):
|
||||
return {"color": potato.color, "size": potato.size}
|
||||
|
||||
|
||||
@dataclass
|
||||
class Potato:
|
||||
color: str
|
||||
size: int
|
||||
""",
|
||||
namespace,
|
||||
)
|
||||
|
||||
client = TestClient(namespace["app"])
|
||||
|
||||
response = client.get("/")
|
||||
assert response.status_code == 200, response.text
|
||||
assert response.json() == {"color": "red", "size": 10}
|
||||
|
||||
response = client.get("/openapi.json")
|
||||
assert response.status_code == 200, response.text
|
||||
assert "requestBody" not in response.json()["paths"]["/"]["get"]
|
||||
|
|
|
|||
Loading…
Reference in New Issue