mirror of https://github.com/tiangolo/fastapi.git
Merge 6543e75d03 into 31bbb38074
This commit is contained in:
commit
f2476d4b78
|
|
@ -994,6 +994,7 @@ class FastAPI(Starlette):
|
|||
responses=responses,
|
||||
generate_unique_id_function=generate_unique_id_function,
|
||||
strict_content_type=strict_content_type,
|
||||
defer_init=False,
|
||||
)
|
||||
self.exception_handlers: dict[
|
||||
Any, Callable[[Request, Any], Response | Awaitable[Response]]
|
||||
|
|
|
|||
|
|
@ -22,6 +22,7 @@ from contextlib import (
|
|||
asynccontextmanager,
|
||||
)
|
||||
from enum import Enum, IntEnum
|
||||
from functools import cached_property
|
||||
from typing import (
|
||||
Annotated,
|
||||
Any,
|
||||
|
|
@ -840,6 +841,7 @@ class APIRoute(routing.Route):
|
|||
generate_unique_id_function: Callable[["APIRoute"], str]
|
||||
| DefaultPlaceholder = Default(generate_unique_id),
|
||||
strict_content_type: bool | DefaultPlaceholder = Default(True),
|
||||
defer_init: bool = True,
|
||||
) -> None:
|
||||
self.path = path
|
||||
self.endpoint = endpoint
|
||||
|
|
@ -905,29 +907,43 @@ class APIRoute(routing.Route):
|
|||
assert is_body_allowed_for_status_code(status_code), (
|
||||
f"Status code {status_code} must not have a response body"
|
||||
)
|
||||
response_name = "Response_" + self.unique_id
|
||||
self.response_field = create_model_field(
|
||||
name=response_name,
|
||||
type_=self.response_model,
|
||||
mode="serialization",
|
||||
)
|
||||
else:
|
||||
self.response_field = None # type: ignore # ty: ignore[unused-ignore-comment]
|
||||
if self.stream_item_type:
|
||||
stream_item_name = "StreamItem_" + self.unique_id
|
||||
self.stream_item_field: ModelField | None = create_model_field(
|
||||
name=stream_item_name,
|
||||
type_=self.stream_item_type,
|
||||
mode="serialization",
|
||||
)
|
||||
else:
|
||||
self.stream_item_field = None
|
||||
|
||||
self.dependencies = list(dependencies or [])
|
||||
self.description = description or inspect.cleandoc(self.endpoint.__doc__ or "")
|
||||
# if a "form feed" character (page break) is found in the description text,
|
||||
# truncate description text to the content preceding the first "form feed"
|
||||
self.description = self.description.split("\f")[0].strip()
|
||||
response_fields = {}
|
||||
|
||||
assert callable(endpoint), "An endpoint must be a callable"
|
||||
|
||||
if not defer_init:
|
||||
self.init_attributes()
|
||||
|
||||
@cached_property
|
||||
def response_field(self) -> ModelField | None:
|
||||
if not self.response_model:
|
||||
return None
|
||||
response_name = "Response_" + self.unique_id
|
||||
return create_model_field(
|
||||
name=response_name,
|
||||
type_=self.response_model,
|
||||
mode="serialization",
|
||||
)
|
||||
|
||||
@cached_property
|
||||
def stream_item_field(self) -> ModelField | None:
|
||||
if not self.stream_item_type:
|
||||
return None
|
||||
stream_item_name = "StreamItem_" + self.unique_id
|
||||
return create_model_field(
|
||||
name=stream_item_name,
|
||||
type_=self.stream_item_type,
|
||||
mode="serialization",
|
||||
)
|
||||
|
||||
@cached_property
|
||||
def response_fields(self) -> dict[int | str, ModelField]:
|
||||
response_fields: dict[int | str, ModelField] = {}
|
||||
for additional_status_code, response in self.responses.items():
|
||||
assert isinstance(response, dict), "An additional response must be a dict"
|
||||
model = response.get("model")
|
||||
|
|
@ -940,40 +956,64 @@ class APIRoute(routing.Route):
|
|||
name=response_name, type_=model, mode="serialization"
|
||||
)
|
||||
response_fields[additional_status_code] = response_field
|
||||
if response_fields:
|
||||
self.response_fields: dict[int | str, ModelField] = response_fields
|
||||
else:
|
||||
self.response_fields = {}
|
||||
|
||||
assert callable(endpoint), "An endpoint must be a callable"
|
||||
self.dependant = get_dependant(
|
||||
return response_fields
|
||||
|
||||
@cached_property
|
||||
def dependant(self) -> Dependant:
|
||||
dependant = get_dependant(
|
||||
path=self.path_format, call=self.endpoint, scope="function"
|
||||
)
|
||||
for depends in self.dependencies[::-1]:
|
||||
self.dependant.dependencies.insert(
|
||||
dependant.dependencies.insert(
|
||||
0,
|
||||
get_parameterless_sub_dependant(depends=depends, path=self.path_format),
|
||||
)
|
||||
self._flat_dependant = get_flat_dependant(self.dependant)
|
||||
self._embed_body_fields = _should_embed_body_fields(
|
||||
self._flat_dependant.body_params
|
||||
return dependant
|
||||
|
||||
@property
|
||||
def is_generator(self) -> bool:
|
||||
return self.dependant.is_async_gen_callable or self.dependant.is_gen_callable
|
||||
|
||||
@property
|
||||
def is_sse_stream(self) -> bool:
|
||||
return self.is_generator and lenient_issubclass(
|
||||
self.response_class, EventSourceResponse
|
||||
)
|
||||
self.body_field = get_body_field(
|
||||
|
||||
@property
|
||||
def is_json_stream(self) -> bool:
|
||||
return self.is_generator and isinstance(self.response_class, DefaultPlaceholder)
|
||||
|
||||
@cached_property
|
||||
def _flat_dependant(self) -> Dependant:
|
||||
return get_flat_dependant(self.dependant)
|
||||
|
||||
@cached_property
|
||||
def _embed_body_fields(self) -> bool:
|
||||
return _should_embed_body_fields(self._flat_dependant.body_params)
|
||||
|
||||
@cached_property
|
||||
def body_field(self) -> ModelField | None:
|
||||
return get_body_field(
|
||||
flat_dependant=self._flat_dependant,
|
||||
name=self.unique_id,
|
||||
embed_body_fields=self._embed_body_fields,
|
||||
)
|
||||
# Detect generator endpoints that should stream as JSONL or SSE
|
||||
is_generator = (
|
||||
self.dependant.is_async_gen_callable or self.dependant.is_gen_callable
|
||||
)
|
||||
self.is_sse_stream = is_generator and lenient_issubclass(
|
||||
response_class, EventSourceResponse
|
||||
)
|
||||
self.is_json_stream = is_generator and isinstance(
|
||||
response_class, DefaultPlaceholder
|
||||
)
|
||||
self.app = request_response(self.get_route_handler())
|
||||
|
||||
@cached_property
|
||||
def app(self) -> ASGIApp: # type: ignore
|
||||
return request_response(self.get_route_handler())
|
||||
|
||||
def init_attributes(self) -> None:
|
||||
_ = self._embed_body_fields
|
||||
_ = self._flat_dependant
|
||||
_ = self.app
|
||||
_ = self.body_field
|
||||
_ = self.dependant
|
||||
_ = self.response_field
|
||||
_ = self.response_fields
|
||||
_ = self.stream_item_field
|
||||
|
||||
def get_route_handler(self) -> Callable[[Request], Coroutine[Any, Any, Response]]:
|
||||
return get_request_handler(
|
||||
|
|
@ -1266,6 +1306,16 @@ class APIRouter(routing.Router):
|
|||
"""
|
||||
),
|
||||
] = Default(True),
|
||||
defer_init: Annotated[
|
||||
bool,
|
||||
Doc(
|
||||
"""
|
||||
By default, every route will defer its initialization until the first call.
|
||||
This flag can be used to deactivate this behavior for the routes defined in this router,
|
||||
causing the routes to initialize immediately when they are defined.
|
||||
"""
|
||||
),
|
||||
] = True,
|
||||
) -> None:
|
||||
# Determine the lifespan context to use
|
||||
if lifespan is None:
|
||||
|
|
@ -1313,6 +1363,10 @@ class APIRouter(routing.Router):
|
|||
self.default_response_class = default_response_class
|
||||
self.generate_unique_id_function = generate_unique_id_function
|
||||
self.strict_content_type = strict_content_type
|
||||
self.defer_init = defer_init
|
||||
|
||||
if not self.defer_init:
|
||||
self.init_routes()
|
||||
|
||||
def route(
|
||||
self,
|
||||
|
|
@ -1413,6 +1467,7 @@ class APIRouter(routing.Router):
|
|||
strict_content_type=get_value_or_default(
|
||||
strict_content_type, self.strict_content_type
|
||||
),
|
||||
defer_init=self.defer_init,
|
||||
)
|
||||
self.routes.append(route)
|
||||
|
||||
|
|
@ -1575,6 +1630,11 @@ class APIRouter(routing.Router):
|
|||
|
||||
return decorator
|
||||
|
||||
def init_routes(self) -> None:
|
||||
for route in self.routes:
|
||||
if isinstance(route, APIRoute):
|
||||
route.init_attributes()
|
||||
|
||||
def include_router(
|
||||
self,
|
||||
router: Annotated["APIRouter", Doc("The `APIRouter` to include.")],
|
||||
|
|
|
|||
|
|
@ -0,0 +1,114 @@
|
|||
from itertools import chain
|
||||
|
||||
from fastapi import APIRouter, Depends, FastAPI
|
||||
from fastapi.routing import APIRoute
|
||||
from fastapi.testclient import TestClient
|
||||
from pydantic import BaseModel
|
||||
from starlette.routing import BaseRoute
|
||||
|
||||
deferred_keys = [
|
||||
"app",
|
||||
"response_fields",
|
||||
"body_field",
|
||||
"response_field",
|
||||
"stream_item_field",
|
||||
"dependant",
|
||||
"_flat_dependant",
|
||||
"_embed_body_fields",
|
||||
]
|
||||
|
||||
|
||||
def check_if_initialized(route: APIRoute, should_not: bool = False):
|
||||
for key in deferred_keys:
|
||||
if should_not:
|
||||
assert key not in route.__dict__
|
||||
else:
|
||||
assert key in route.__dict__
|
||||
|
||||
|
||||
def create_test_router(routes: list[BaseRoute] | None = None, defer_init: bool = True):
|
||||
router = APIRouter(routes=routes or [], defer_init=defer_init)
|
||||
|
||||
class UserIdBody(BaseModel):
|
||||
user_id: int
|
||||
|
||||
@router.get("/user_id", dependencies=[Depends(lambda: True)])
|
||||
async def get_user_id(user_id: int = Depends(lambda: 1)) -> UserIdBody:
|
||||
return {"user_id": user_id}
|
||||
|
||||
return router
|
||||
|
||||
|
||||
def test_route_defers():
|
||||
app = FastAPI()
|
||||
router = create_test_router(routes=app.router.routes)
|
||||
|
||||
for route in router.routes:
|
||||
if not isinstance(route, APIRoute):
|
||||
continue
|
||||
check_if_initialized(route, should_not=True)
|
||||
|
||||
app.router = router
|
||||
client = TestClient(app)
|
||||
response = client.get("/user_id")
|
||||
assert response.status_code == 200
|
||||
response = client.get("/openapi.json")
|
||||
assert response.status_code == 200
|
||||
|
||||
for route in router.routes:
|
||||
if not isinstance(route, APIRoute):
|
||||
continue
|
||||
check_if_initialized(route)
|
||||
|
||||
|
||||
def test_route_manual_init():
|
||||
router = create_test_router()
|
||||
for route in router.routes:
|
||||
check_if_initialized(route, should_not=True)
|
||||
route.init_attributes()
|
||||
check_if_initialized(route)
|
||||
|
||||
router = create_test_router()
|
||||
router.init_routes()
|
||||
for route in router.routes:
|
||||
check_if_initialized(route)
|
||||
|
||||
|
||||
def test_router_defer_init_flag():
|
||||
route = APIRoute("/test", lambda: {"test": True}, defer_init=False)
|
||||
check_if_initialized(route)
|
||||
|
||||
deferring_router = create_test_router()
|
||||
router = create_test_router(routes=deferring_router.routes, defer_init=False)
|
||||
|
||||
for route in router.routes:
|
||||
check_if_initialized(route)
|
||||
|
||||
|
||||
def test_root_router_always_initialized():
|
||||
app = FastAPI()
|
||||
|
||||
@app.get("/test")
|
||||
async def test_get():
|
||||
return {"test": 1}
|
||||
|
||||
router = create_test_router()
|
||||
app.include_router(router)
|
||||
for route in app.router.routes:
|
||||
if not isinstance(route, APIRoute):
|
||||
continue
|
||||
check_if_initialized(route)
|
||||
|
||||
client = TestClient(app)
|
||||
response = client.get("/test")
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
def test_include_router_no_init():
|
||||
router1 = create_test_router()
|
||||
|
||||
router2 = create_test_router()
|
||||
router2.include_router(router1)
|
||||
|
||||
for route in chain(router1.routes, router2.routes):
|
||||
check_if_initialized(route, should_not=True)
|
||||
Loading…
Reference in New Issue