mirror of https://github.com/tiangolo/fastapi.git
♻️ Refactor include_router to mount sub-routers
This commit is contained in:
parent
26f725d259
commit
edcae918e2
|
|
@ -137,6 +137,10 @@ class FastAPI(Starlette):
|
|||
self.middleware_stack: ASGIApp = self.build_middleware_stack()
|
||||
self.setup()
|
||||
|
||||
@property
|
||||
def routes(self) -> List[BaseRoute]:
|
||||
return list(self.router.iter_all_routes())
|
||||
|
||||
def build_middleware_stack(self) -> ASGIApp:
|
||||
# Duplicate/override from Starlette to add AsyncExitStackMiddleware
|
||||
# inside of ExceptionMiddleware, inside of custom user middlewares
|
||||
|
|
|
|||
|
|
@ -152,7 +152,7 @@ def generate_operation_id(
|
|||
)
|
||||
if route.operation_id:
|
||||
return route.operation_id
|
||||
path: str = route.path_format
|
||||
path: str = route._route_full_path_format
|
||||
return generate_operation_id_for_path(name=route.name, path=path, method=method)
|
||||
|
||||
|
||||
|
|
@ -243,7 +243,7 @@ def get_openapi_path(
|
|||
model_name_map=model_name_map,
|
||||
operation_ids=operation_ids,
|
||||
)
|
||||
callbacks[callback.name] = {callback.path: cb_path}
|
||||
callbacks[callback.name] = {callback._route_full_path: cb_path}
|
||||
operation["callbacks"] = callbacks
|
||||
if route.status_code is not None:
|
||||
status_code = str(route.status_code)
|
||||
|
|
@ -422,7 +422,7 @@ def get_openapi(
|
|||
if result:
|
||||
path, security_schemes, path_definitions = result
|
||||
if path:
|
||||
paths.setdefault(route.path_format, {}).update(path)
|
||||
paths.setdefault(route._route_full_path_format, {}).update(path)
|
||||
if security_schemes:
|
||||
components.setdefault("securitySchemes", {}).update(
|
||||
security_schemes
|
||||
|
|
|
|||
|
|
@ -9,12 +9,14 @@ from typing import (
|
|||
Callable,
|
||||
Coroutine,
|
||||
Dict,
|
||||
Iterator,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
Set,
|
||||
Tuple,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
|
||||
|
|
@ -57,6 +59,10 @@ from starlette.status import WS_1008_POLICY_VIOLATION
|
|||
from starlette.types import ASGIApp, Scope
|
||||
from starlette.websockets import WebSocket
|
||||
|
||||
APIRouteType = TypeVar("APIRouteType", bound="APIRoute")
|
||||
APIRouterType = TypeVar("APIRouterType", bound="APIRouter")
|
||||
APIMountType = TypeVar("APIMountType", bound="APIMount")
|
||||
|
||||
|
||||
def _prepare_response_content(
|
||||
res: Any,
|
||||
|
|
@ -338,13 +344,13 @@ class APIRoute(routing.Route):
|
|||
generate_unique_id_function: Union[
|
||||
Callable[["APIRoute"], str], DefaultPlaceholder
|
||||
] = Default(generate_unique_id),
|
||||
router: Optional["APIRouter"] = None,
|
||||
) -> None:
|
||||
self.path = path
|
||||
self.endpoint = endpoint
|
||||
self.response_model = response_model
|
||||
self.summary = summary
|
||||
self.response_description = response_description
|
||||
self.deprecated = deprecated
|
||||
self.operation_id = operation_id
|
||||
self.response_model_include = response_model_include
|
||||
self.response_model_exclude = response_model_exclude
|
||||
|
|
@ -352,34 +358,128 @@ class APIRoute(routing.Route):
|
|||
self.response_model_exclude_unset = response_model_exclude_unset
|
||||
self.response_model_exclude_defaults = response_model_exclude_defaults
|
||||
self.response_model_exclude_none = response_model_exclude_none
|
||||
self.include_in_schema = include_in_schema
|
||||
self.response_class = response_class
|
||||
self.dependency_overrides_provider = dependency_overrides_provider
|
||||
self.callbacks = callbacks
|
||||
self.openapi_extra = openapi_extra
|
||||
self.generate_unique_id_function = generate_unique_id_function
|
||||
self.tags = tags or []
|
||||
self.responses = responses or {}
|
||||
self.router = router
|
||||
|
||||
self.name = get_name(endpoint) if name is None else name
|
||||
self.path_regex, self.path_format, self.param_convertors = compile_path(path)
|
||||
if methods is None:
|
||||
methods = ["GET"]
|
||||
self.methods: Set[str] = set([method.upper() for method in methods])
|
||||
if isinstance(generate_unique_id_function, DefaultPlaceholder):
|
||||
current_generate_unique_id: Callable[
|
||||
["APIRoute"], str
|
||||
] = generate_unique_id_function.value
|
||||
else:
|
||||
current_generate_unique_id = generate_unique_id_function
|
||||
self.unique_id = self.operation_id or current_generate_unique_id(self)
|
||||
# normalize enums e.g. http.HTTPStatus
|
||||
if isinstance(status_code, IntEnum):
|
||||
status_code = int(status_code)
|
||||
self.status_code = status_code
|
||||
if methods is None:
|
||||
methods = ["GET"]
|
||||
self.methods: Set[str] = set([method.upper() for method in methods])
|
||||
|
||||
self.description = description or inspect.cleandoc(self.endpoint.__doc__ or "")
|
||||
# if a "form feed" character (page break) is found in the description text,
|
||||
# truncate description text to the content preceding the first "form feed"
|
||||
self.description = self.description.split("\f")[0]
|
||||
|
||||
assert callable(endpoint), "An endpoint must be a callable"
|
||||
|
||||
self.path_regex, self.path_format, self.param_convertors = compile_path(
|
||||
self.path
|
||||
)
|
||||
|
||||
# Attributes set in route used to compute resolved attributes
|
||||
self._route_deprecated = deprecated
|
||||
self._route_include_in_schema = include_in_schema
|
||||
self._route_response_class = response_class
|
||||
self._route_callbacks = callbacks
|
||||
self._route_generate_unique_id_function = generate_unique_id_function
|
||||
self._route_tags = tags or []
|
||||
self._route_responses = responses or {}
|
||||
if dependencies:
|
||||
self._route_dependencies = dependencies
|
||||
else:
|
||||
self._route_dependencies = []
|
||||
|
||||
self.setup()
|
||||
|
||||
def setup(self) -> None:
|
||||
# setup full path
|
||||
self._route_full_path = self.path
|
||||
if self.router:
|
||||
self._route_full_path = self.router._router_full_path + self.path
|
||||
|
||||
# setup dependencies
|
||||
self.dependencies: List[params.Depends] = []
|
||||
if self.router:
|
||||
self.dependencies.extend(self.router.dependencies)
|
||||
self.dependencies.extend(self._route_dependencies)
|
||||
|
||||
# setup generate_unique_id
|
||||
generate_unique_id_functions: List[
|
||||
Union[Callable[[APIRoute], str], DefaultPlaceholder]
|
||||
] = [self._route_generate_unique_id_function]
|
||||
if self.router:
|
||||
generate_unique_id_functions.append(self.router.generate_unique_id_function)
|
||||
current_generate_unique_id_function = get_value_or_default(
|
||||
*generate_unique_id_functions
|
||||
)
|
||||
self.generate_unique_id_function: Union[
|
||||
Callable[[APIRoute], str], DefaultPlaceholder
|
||||
] = current_generate_unique_id_function
|
||||
|
||||
# setup responses
|
||||
responses: Dict[Union[int, str], Dict[str, Any]] = {}
|
||||
if self.router:
|
||||
responses.update(self.router.responses)
|
||||
responses.update(self._route_responses)
|
||||
self.responses: Dict[Union[int, str], Dict[str, Any]] = responses
|
||||
|
||||
# setup default_response_class
|
||||
default_response_classes: List[Union[Type[Response], DefaultPlaceholder]] = [
|
||||
self._route_response_class
|
||||
]
|
||||
if self.router:
|
||||
default_response_classes.append(self.router.default_response_class)
|
||||
current_default_response_class = get_value_or_default(*default_response_classes)
|
||||
self.response_class: Union[
|
||||
Type[Response], DefaultPlaceholder
|
||||
] = current_default_response_class
|
||||
|
||||
# setup tags
|
||||
self.tags: List[Union[str, Enum]] = []
|
||||
if self.router:
|
||||
self.tags.extend(self.router.tags)
|
||||
self.tags.extend(self._route_tags)
|
||||
|
||||
# setup callbacks
|
||||
callbacks: List[BaseRoute] = []
|
||||
if self.router:
|
||||
callbacks.extend(self.router.callbacks)
|
||||
if self._route_callbacks:
|
||||
callbacks.extend(self._route_callbacks)
|
||||
self.callbacks = callbacks
|
||||
|
||||
# setup deprecated
|
||||
self.deprecated = self._route_deprecated
|
||||
if self.router:
|
||||
self.deprecated = self._route_deprecated or self.router.deprecated
|
||||
|
||||
# setup include_in_schema
|
||||
self.include_in_schema = self._route_include_in_schema
|
||||
if self.router:
|
||||
self.include_in_schema = (
|
||||
self._route_include_in_schema and self.router.include_in_schema
|
||||
)
|
||||
|
||||
_, self._route_full_path_format, _ = compile_path(self._route_full_path)
|
||||
|
||||
if isinstance(self.generate_unique_id_function, DefaultPlaceholder):
|
||||
resolved_generate_unique_id: Callable[
|
||||
["APIRoute"], str
|
||||
] = self.generate_unique_id_function.value
|
||||
else:
|
||||
resolved_generate_unique_id = self.generate_unique_id_function
|
||||
self.unique_id = self.operation_id or resolved_generate_unique_id(self)
|
||||
|
||||
if self.response_model:
|
||||
assert (
|
||||
status_code not in STATUS_CODES_WITH_NO_BODY
|
||||
), f"Status code {status_code} must not have a response body"
|
||||
self.status_code not in STATUS_CODES_WITH_NO_BODY
|
||||
), f"Status code {self.status_code} must not have a response body"
|
||||
response_name = "Response_" + self.unique_id
|
||||
self.response_field = create_response_field(
|
||||
name=response_name, type_=self.response_model
|
||||
|
|
@ -397,14 +497,7 @@ class APIRoute(routing.Route):
|
|||
else:
|
||||
self.response_field = None # type: ignore
|
||||
self.secure_cloned_response_field = None
|
||||
if dependencies:
|
||||
self.dependencies = list(dependencies)
|
||||
else:
|
||||
self.dependencies = []
|
||||
self.description = description or inspect.cleandoc(self.endpoint.__doc__ or "")
|
||||
# if a "form feed" character (page break) is found in the description text,
|
||||
# truncate description text to the content preceding the first "form feed"
|
||||
self.description = self.description.split("\f")[0]
|
||||
|
||||
response_fields = {}
|
||||
for additional_status_code, response in self.responses.items():
|
||||
assert isinstance(response, dict), "An additional response must be a dict"
|
||||
|
|
@ -421,16 +514,50 @@ class APIRoute(routing.Route):
|
|||
else:
|
||||
self.response_fields = {}
|
||||
|
||||
assert callable(endpoint), "An endpoint must be a callable"
|
||||
self.dependant = get_dependant(path=self.path_format, call=self.endpoint)
|
||||
self.dependant = get_dependant(
|
||||
path=self._route_full_path_format, call=self.endpoint
|
||||
)
|
||||
for depends in self.dependencies[::-1]:
|
||||
self.dependant.dependencies.insert(
|
||||
0,
|
||||
get_parameterless_sub_dependant(depends=depends, path=self.path_format),
|
||||
get_parameterless_sub_dependant(
|
||||
depends=depends, path=self._route_full_path_format
|
||||
),
|
||||
)
|
||||
self.body_field = get_body_field(dependant=self.dependant, name=self.unique_id)
|
||||
self.app = request_response(self.get_route_handler())
|
||||
|
||||
def copy(self: APIRouteType) -> APIRouteType:
|
||||
return type(self)(
|
||||
path=self.path,
|
||||
endpoint=self.endpoint,
|
||||
response_model=self.response_model,
|
||||
status_code=self.status_code,
|
||||
tags=self._route_tags,
|
||||
dependencies=self._route_dependencies,
|
||||
summary=self.summary,
|
||||
description=self.description,
|
||||
response_description=self.response_description,
|
||||
responses=self._route_responses,
|
||||
deprecated=self._route_deprecated,
|
||||
name=self.name,
|
||||
methods=self.methods,
|
||||
operation_id=self.operation_id,
|
||||
response_model_include=self.response_model_include,
|
||||
response_model_exclude=self.response_model_exclude,
|
||||
response_model_by_alias=self.response_model_by_alias,
|
||||
response_model_exclude_unset=self.response_model_exclude_unset,
|
||||
response_model_exclude_defaults=self.response_model_exclude_defaults,
|
||||
response_model_exclude_none=self.response_model_exclude_none,
|
||||
include_in_schema=self._route_include_in_schema,
|
||||
response_class=self._route_response_class,
|
||||
dependency_overrides_provider=self.dependency_overrides_provider,
|
||||
callbacks=self._route_callbacks,
|
||||
openapi_extra=self.openapi_extra,
|
||||
generate_unique_id_function=self._route_generate_unique_id_function,
|
||||
router=self.router,
|
||||
)
|
||||
|
||||
def get_route_handler(self) -> Callable[[Request], Coroutine[Any, Any, Response]]:
|
||||
return get_request_handler(
|
||||
dependant=self.dependant,
|
||||
|
|
@ -476,6 +603,7 @@ class APIRouter(routing.Router):
|
|||
generate_unique_id_function: Callable[[APIRoute], str] = Default(
|
||||
generate_unique_id
|
||||
),
|
||||
parent_router: Optional["APIRouter"] = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
routes=routes, # type: ignore # in Starlette
|
||||
|
|
@ -490,16 +618,151 @@ class APIRouter(routing.Router):
|
|||
"/"
|
||||
), "A path prefix must not end with '/', as the routes will start with '/'"
|
||||
self.prefix = prefix
|
||||
self.tags: List[Union[str, Enum]] = tags or []
|
||||
self.dependencies = list(dependencies or []) or []
|
||||
self.deprecated = deprecated
|
||||
self.include_in_schema = include_in_schema
|
||||
self.responses = responses or {}
|
||||
self.callbacks = callbacks or []
|
||||
self.dependency_overrides_provider = dependency_overrides_provider
|
||||
self.route_class = route_class
|
||||
self.default_response_class = default_response_class
|
||||
self.generate_unique_id_function = generate_unique_id_function
|
||||
|
||||
self.parent_router = parent_router
|
||||
|
||||
# Attributes set in router used to compute resolved attributes
|
||||
self._router_dependencies = list(dependencies or []) or []
|
||||
self._router_generate_unique_id_function = generate_unique_id_function
|
||||
self._router_responses = responses or {}
|
||||
self._router_default_response_class = default_response_class
|
||||
self._router_tags: List[Union[str, Enum]] = tags or []
|
||||
self._router_callbacks = callbacks or []
|
||||
self._router_deprecated = deprecated
|
||||
self._router_include_in_schema = include_in_schema
|
||||
self._router_has_empty_route = False
|
||||
self._router_has_root_route = False
|
||||
self.setup()
|
||||
|
||||
def setup(self) -> None:
|
||||
# setup full path
|
||||
self._router_full_path = self.prefix
|
||||
if self.parent_router:
|
||||
self._router_full_path = self.parent_router._router_full_path + self.prefix
|
||||
# setup dependencies
|
||||
self.dependencies: List[params.Depends] = []
|
||||
if self.parent_router:
|
||||
self.dependencies.extend(self.parent_router.dependencies)
|
||||
self.dependencies.extend(self._router_dependencies)
|
||||
|
||||
# setup generate_unique_id
|
||||
generate_unique_id_functions: List[
|
||||
Union[Callable[[APIRoute], str], DefaultPlaceholder]
|
||||
] = [self._router_generate_unique_id_function]
|
||||
if self.parent_router:
|
||||
generate_unique_id_functions.append(
|
||||
self.parent_router.generate_unique_id_function
|
||||
)
|
||||
current_generate_unique_id_function = get_value_or_default(
|
||||
*generate_unique_id_functions
|
||||
)
|
||||
self.generate_unique_id_function: Union[
|
||||
Callable[[APIRoute], str], DefaultPlaceholder
|
||||
] = current_generate_unique_id_function
|
||||
|
||||
# setup responses
|
||||
responses: Dict[Union[int, str], Dict[str, Any]] = {}
|
||||
if self.parent_router:
|
||||
responses.update(self.parent_router.responses)
|
||||
responses.update(self._router_responses)
|
||||
self.responses: Dict[Union[int, str], Dict[str, Any]] = responses
|
||||
|
||||
# setup default_response_class
|
||||
default_response_classes: List[Union[Type[Response], DefaultPlaceholder]] = [
|
||||
self._router_default_response_class
|
||||
]
|
||||
if self.parent_router:
|
||||
default_response_classes.append(self.parent_router.default_response_class)
|
||||
current_default_response_class = get_value_or_default(*default_response_classes)
|
||||
self.default_response_class: Union[
|
||||
Type[Response], DefaultPlaceholder
|
||||
] = current_default_response_class
|
||||
|
||||
# setup tags
|
||||
self.tags: List[Union[str, Enum]] = []
|
||||
if self.parent_router:
|
||||
self.tags.extend(self.parent_router.tags)
|
||||
self.tags.extend(self._router_tags)
|
||||
|
||||
# setup callbacks
|
||||
self.callbacks: List[BaseRoute] = []
|
||||
if self.parent_router:
|
||||
self.callbacks.extend(self.parent_router.callbacks)
|
||||
self.callbacks.extend(self._router_callbacks)
|
||||
|
||||
# setup deprecated
|
||||
self.deprecated = self._router_deprecated
|
||||
if self.parent_router:
|
||||
self.deprecated = self._router_deprecated or self.parent_router.deprecated
|
||||
|
||||
# setup include_in_schema
|
||||
self.include_in_schema = self._router_include_in_schema
|
||||
if self.parent_router:
|
||||
self.include_in_schema = (
|
||||
self._router_include_in_schema and self.parent_router.include_in_schema
|
||||
)
|
||||
|
||||
# setup routes
|
||||
for route in self.routes:
|
||||
if isinstance(route, APIRoute):
|
||||
route.router = self
|
||||
route.setup()
|
||||
elif isinstance(route, APIMount):
|
||||
route.parent_router = self
|
||||
route.setup()
|
||||
|
||||
def copy(self: APIRouterType) -> APIRouterType:
|
||||
routes: List[routing.BaseRoute] = []
|
||||
for route in self.routes:
|
||||
if isinstance(route, APIRoute):
|
||||
routes.append(route.copy())
|
||||
elif isinstance(route, APIMount):
|
||||
routes.append(route.copy())
|
||||
else:
|
||||
routes.append(route)
|
||||
copied_router = type(self)(
|
||||
prefix=self.prefix,
|
||||
tags=self._router_tags,
|
||||
dependencies=self._router_dependencies,
|
||||
default_response_class=self._router_default_response_class,
|
||||
responses=self._router_responses,
|
||||
callbacks=self._router_callbacks,
|
||||
routes=routes,
|
||||
redirect_slashes=self.redirect_slashes,
|
||||
default=self.default,
|
||||
dependency_overrides_provider=self.dependency_overrides_provider,
|
||||
route_class=self.route_class,
|
||||
on_startup=self.on_startup,
|
||||
on_shutdown=self.on_shutdown,
|
||||
deprecated=self._router_deprecated,
|
||||
include_in_schema=self._router_include_in_schema,
|
||||
generate_unique_id_function=self._router_generate_unique_id_function,
|
||||
parent_router=self.parent_router,
|
||||
)
|
||||
copied_router._router_has_empty_route = self._router_has_empty_route
|
||||
copied_router._router_has_root_route = self._router_has_root_route
|
||||
for route in copied_router.routes:
|
||||
if isinstance(route, APIRoute):
|
||||
route.router = copied_router
|
||||
route.setup()
|
||||
elif isinstance(route, Mount):
|
||||
if isinstance(route.app, APIRouter):
|
||||
route.app.setup()
|
||||
return copied_router
|
||||
|
||||
def iter_all_routes(self) -> Iterator[routing.BaseRoute]:
|
||||
for route in self.routes:
|
||||
if isinstance(route, Mount):
|
||||
if isinstance(route.app, APIRouter):
|
||||
yield from route.app.iter_all_routes()
|
||||
else:
|
||||
yield route
|
||||
|
||||
def api_mount(self, router: "APIRouter", name: Optional[str] = None) -> None:
|
||||
route = APIMount(router=router, name=name, parent_router=self)
|
||||
self.routes.append(route)
|
||||
|
||||
def add_api_route(
|
||||
self,
|
||||
|
|
@ -537,34 +800,18 @@ class APIRouter(routing.Router):
|
|||
) -> None:
|
||||
route_class = route_class_override or self.route_class
|
||||
responses = responses or {}
|
||||
combined_responses = {**self.responses, **responses}
|
||||
current_response_class = get_value_or_default(
|
||||
response_class, self.default_response_class
|
||||
)
|
||||
current_tags = self.tags.copy()
|
||||
if tags:
|
||||
current_tags.extend(tags)
|
||||
current_dependencies = self.dependencies.copy()
|
||||
if dependencies:
|
||||
current_dependencies.extend(dependencies)
|
||||
current_callbacks = self.callbacks.copy()
|
||||
if callbacks:
|
||||
current_callbacks.extend(callbacks)
|
||||
current_generate_unique_id = get_value_or_default(
|
||||
generate_unique_id_function, self.generate_unique_id_function
|
||||
)
|
||||
route = route_class(
|
||||
self.prefix + path,
|
||||
path,
|
||||
endpoint=endpoint,
|
||||
response_model=response_model,
|
||||
status_code=status_code,
|
||||
tags=current_tags,
|
||||
dependencies=current_dependencies,
|
||||
tags=tags,
|
||||
dependencies=dependencies,
|
||||
summary=summary,
|
||||
description=description,
|
||||
response_description=response_description,
|
||||
responses=combined_responses,
|
||||
deprecated=deprecated or self.deprecated,
|
||||
responses=responses,
|
||||
deprecated=deprecated,
|
||||
methods=methods,
|
||||
operation_id=operation_id,
|
||||
response_model_include=response_model_include,
|
||||
|
|
@ -573,15 +820,20 @@ class APIRouter(routing.Router):
|
|||
response_model_exclude_unset=response_model_exclude_unset,
|
||||
response_model_exclude_defaults=response_model_exclude_defaults,
|
||||
response_model_exclude_none=response_model_exclude_none,
|
||||
include_in_schema=include_in_schema and self.include_in_schema,
|
||||
response_class=current_response_class,
|
||||
include_in_schema=include_in_schema,
|
||||
response_class=response_class,
|
||||
name=name,
|
||||
dependency_overrides_provider=self.dependency_overrides_provider,
|
||||
callbacks=current_callbacks,
|
||||
callbacks=callbacks,
|
||||
openapi_extra=openapi_extra,
|
||||
generate_unique_id_function=current_generate_unique_id,
|
||||
generate_unique_id_function=generate_unique_id_function,
|
||||
router=self,
|
||||
)
|
||||
self.routes.append(route)
|
||||
if not path:
|
||||
self._router_has_empty_route = True
|
||||
if path == "/":
|
||||
self._router_has_root_route = True
|
||||
|
||||
def api_route(
|
||||
self,
|
||||
|
|
@ -680,103 +932,197 @@ class APIRouter(routing.Router):
|
|||
generate_unique_id_function: Callable[[APIRoute], str] = Default(
|
||||
generate_unique_id
|
||||
),
|
||||
copy_flat_routes: Optional[bool] = None,
|
||||
) -> None:
|
||||
if prefix:
|
||||
assert prefix.startswith("/"), "A path prefix must start with '/'"
|
||||
assert not prefix.endswith(
|
||||
"/"
|
||||
), "A path prefix must not end with '/', as the routes will start with '/'"
|
||||
else:
|
||||
for r in router.routes:
|
||||
path = getattr(r, "path")
|
||||
name = getattr(r, "name", "unknown")
|
||||
if path is not None and not path:
|
||||
raise Exception(
|
||||
f"Prefix and path cannot be both empty (path operation: {name})"
|
||||
resolved_copy_flat_routes = copy_flat_routes
|
||||
if resolved_copy_flat_routes is None:
|
||||
resolved_copy_flat_routes = not (prefix or router.prefix)
|
||||
if not resolved_copy_flat_routes:
|
||||
included_router = router.copy()
|
||||
if (
|
||||
prefix
|
||||
or tags
|
||||
or dependencies
|
||||
or not isinstance(default_response_class, DefaultPlaceholder)
|
||||
or responses
|
||||
or callbacks
|
||||
or deprecated is not None
|
||||
or include_in_schema is not True
|
||||
or not isinstance(generate_unique_id_function, DefaultPlaceholder)
|
||||
):
|
||||
current_router = type(self)(
|
||||
prefix=prefix,
|
||||
tags=tags,
|
||||
dependencies=dependencies,
|
||||
default_response_class=default_response_class,
|
||||
responses=responses,
|
||||
callbacks=callbacks,
|
||||
deprecated=deprecated,
|
||||
include_in_schema=include_in_schema,
|
||||
generate_unique_id_function=generate_unique_id_function,
|
||||
parent_router=self,
|
||||
)
|
||||
# current_router.api_mount(included_router)
|
||||
current_router.include_router(included_router)
|
||||
if included_router._router_has_empty_route and not self.prefix:
|
||||
current_router._router_has_empty_route = True
|
||||
current_router._router_has_root_route = (
|
||||
included_router._router_has_root_route
|
||||
)
|
||||
if responses is None:
|
||||
responses = {}
|
||||
for route in router.routes:
|
||||
if isinstance(route, APIRoute):
|
||||
combined_responses = {**responses, **route.responses}
|
||||
use_response_class = get_value_or_default(
|
||||
route.response_class,
|
||||
router.default_response_class,
|
||||
default_response_class,
|
||||
self.default_response_class,
|
||||
)
|
||||
current_tags = []
|
||||
if tags:
|
||||
current_tags.extend(tags)
|
||||
if route.tags:
|
||||
current_tags.extend(route.tags)
|
||||
current_dependencies: List[params.Depends] = []
|
||||
if dependencies:
|
||||
current_dependencies.extend(dependencies)
|
||||
if route.dependencies:
|
||||
current_dependencies.extend(route.dependencies)
|
||||
current_callbacks = []
|
||||
if callbacks:
|
||||
current_callbacks.extend(callbacks)
|
||||
if route.callbacks:
|
||||
current_callbacks.extend(route.callbacks)
|
||||
current_generate_unique_id = get_value_or_default(
|
||||
route.generate_unique_id_function,
|
||||
router.generate_unique_id_function,
|
||||
generate_unique_id_function,
|
||||
self.generate_unique_id_function,
|
||||
)
|
||||
self.add_api_route(
|
||||
prefix + route.path,
|
||||
route.endpoint,
|
||||
response_model=route.response_model,
|
||||
status_code=route.status_code,
|
||||
tags=current_tags,
|
||||
dependencies=current_dependencies,
|
||||
summary=route.summary,
|
||||
description=route.description,
|
||||
response_description=route.response_description,
|
||||
responses=combined_responses,
|
||||
deprecated=route.deprecated or deprecated or self.deprecated,
|
||||
methods=route.methods,
|
||||
operation_id=route.operation_id,
|
||||
response_model_include=route.response_model_include,
|
||||
response_model_exclude=route.response_model_exclude,
|
||||
response_model_by_alias=route.response_model_by_alias,
|
||||
response_model_exclude_unset=route.response_model_exclude_unset,
|
||||
response_model_exclude_defaults=route.response_model_exclude_defaults,
|
||||
response_model_exclude_none=route.response_model_exclude_none,
|
||||
include_in_schema=route.include_in_schema
|
||||
and self.include_in_schema
|
||||
and include_in_schema,
|
||||
response_class=use_response_class,
|
||||
name=route.name,
|
||||
route_class_override=type(route),
|
||||
callbacks=current_callbacks,
|
||||
openapi_extra=route.openapi_extra,
|
||||
generate_unique_id_function=current_generate_unique_id,
|
||||
)
|
||||
elif isinstance(route, routing.Route):
|
||||
methods = list(route.methods or []) # type: ignore # in Starlette
|
||||
self.add_route(
|
||||
prefix + route.path,
|
||||
route.endpoint,
|
||||
methods=methods,
|
||||
include_in_schema=route.include_in_schema,
|
||||
name=route.name,
|
||||
)
|
||||
elif isinstance(route, APIWebSocketRoute):
|
||||
self.add_api_websocket_route(
|
||||
prefix + route.path, route.endpoint, name=route.name
|
||||
)
|
||||
elif isinstance(route, routing.WebSocketRoute):
|
||||
self.add_websocket_route(
|
||||
prefix + route.path, route.endpoint, name=route.name
|
||||
)
|
||||
for handler in router.on_startup:
|
||||
self.add_event_handler("startup", handler)
|
||||
for handler in router.on_shutdown:
|
||||
self.add_event_handler("shutdown", handler)
|
||||
self.api_mount(current_router)
|
||||
included_router.parent_router = current_router
|
||||
else:
|
||||
self.api_mount(included_router)
|
||||
included_router.parent_router = self
|
||||
|
||||
included_router.setup()
|
||||
else:
|
||||
# TODO: remove this and its test, as a subrouter can mount another
|
||||
# subrouter (done automatically of other things are overwritten) and both
|
||||
# can omit a prefix, this would error out
|
||||
# for r in router.routes:
|
||||
# path = getattr(r, "path")
|
||||
# name = getattr(r, "name", "unknown")
|
||||
# if path is not None and not path:
|
||||
# raise Exception(
|
||||
# f"Prefix and path cannot be both empty (path operation: {name})"
|
||||
# )
|
||||
if responses is None:
|
||||
responses = {}
|
||||
for route in router.routes:
|
||||
if isinstance(route, APIRoute):
|
||||
combined_responses = {}
|
||||
if route.router:
|
||||
combined_responses.update(route.router.responses)
|
||||
combined_responses.update(responses)
|
||||
combined_responses.update(route.responses)
|
||||
|
||||
response_classes: List[
|
||||
Union[Type[Response], DefaultPlaceholder]
|
||||
] = []
|
||||
if route.router:
|
||||
response_classes.append(route.router.default_response_class)
|
||||
response_classes.extend(
|
||||
[
|
||||
route.response_class,
|
||||
router.default_response_class,
|
||||
default_response_class,
|
||||
self.default_response_class,
|
||||
]
|
||||
)
|
||||
use_response_class = get_value_or_default(*response_classes)
|
||||
current_tags = []
|
||||
if route.router:
|
||||
current_tags.extend(route.router.tags)
|
||||
if tags:
|
||||
current_tags.extend(tags)
|
||||
if route.tags:
|
||||
current_tags.extend(route.tags)
|
||||
current_dependencies: List[params.Depends] = []
|
||||
if route.router:
|
||||
current_dependencies.extend(route.router.dependencies)
|
||||
if dependencies:
|
||||
current_dependencies.extend(dependencies)
|
||||
if route.dependencies:
|
||||
current_dependencies.extend(route.dependencies)
|
||||
current_callbacks = []
|
||||
if route.router:
|
||||
current_callbacks.extend(route.router.callbacks)
|
||||
if callbacks:
|
||||
current_callbacks.extend(callbacks)
|
||||
if route.callbacks:
|
||||
current_callbacks.extend(route.callbacks)
|
||||
|
||||
generate_unique_id_functions: List[
|
||||
Union[Callable[[APIRoute], str], DefaultPlaceholder]
|
||||
] = []
|
||||
if route.router:
|
||||
generate_unique_id_functions.append(
|
||||
route.router.generate_unique_id_function
|
||||
)
|
||||
generate_unique_id_functions.extend(
|
||||
[
|
||||
route.generate_unique_id_function,
|
||||
router.generate_unique_id_function,
|
||||
generate_unique_id_function,
|
||||
self.generate_unique_id_function,
|
||||
]
|
||||
)
|
||||
current_generate_unique_id_function = get_value_or_default(
|
||||
*generate_unique_id_functions
|
||||
)
|
||||
path = prefix + route.path
|
||||
if route.router:
|
||||
path = prefix + route.router.prefix + path
|
||||
self.add_api_route(
|
||||
path,
|
||||
route.endpoint,
|
||||
response_model=route.response_model,
|
||||
status_code=route.status_code,
|
||||
tags=current_tags,
|
||||
dependencies=current_dependencies,
|
||||
summary=route.summary,
|
||||
description=route.description,
|
||||
response_description=route.response_description,
|
||||
responses=combined_responses,
|
||||
deprecated=route.deprecated or deprecated or self.deprecated,
|
||||
methods=route.methods,
|
||||
operation_id=route.operation_id,
|
||||
response_model_include=route.response_model_include,
|
||||
response_model_exclude=route.response_model_exclude,
|
||||
response_model_by_alias=route.response_model_by_alias,
|
||||
response_model_exclude_unset=route.response_model_exclude_unset,
|
||||
response_model_exclude_defaults=route.response_model_exclude_defaults,
|
||||
response_model_exclude_none=route.response_model_exclude_none,
|
||||
include_in_schema=route.include_in_schema
|
||||
and self.include_in_schema
|
||||
and include_in_schema,
|
||||
response_class=use_response_class,
|
||||
name=route.name,
|
||||
route_class_override=type(route),
|
||||
callbacks=current_callbacks,
|
||||
openapi_extra=route.openapi_extra,
|
||||
generate_unique_id_function=current_generate_unique_id_function,
|
||||
)
|
||||
elif isinstance(route, APIMount):
|
||||
self.include_router(
|
||||
route.app,
|
||||
prefix=prefix,
|
||||
tags=tags,
|
||||
dependencies=dependencies,
|
||||
default_response_class=default_response_class,
|
||||
responses=responses,
|
||||
callbacks=callbacks,
|
||||
deprecated=deprecated,
|
||||
include_in_schema=include_in_schema,
|
||||
generate_unique_id_function=generate_unique_id_function,
|
||||
)
|
||||
elif isinstance(route, routing.Route):
|
||||
methods = list(route.methods or []) # type: ignore # in Starlette
|
||||
self.add_route(
|
||||
prefix + route.path,
|
||||
route.endpoint,
|
||||
methods=methods,
|
||||
include_in_schema=route.include_in_schema,
|
||||
name=route.name,
|
||||
)
|
||||
elif isinstance(route, APIWebSocketRoute):
|
||||
self.add_api_websocket_route(
|
||||
prefix + route.path, route.endpoint, name=route.name
|
||||
)
|
||||
elif isinstance(route, routing.WebSocketRoute):
|
||||
self.add_websocket_route(
|
||||
prefix + route.path, route.endpoint, name=route.name
|
||||
)
|
||||
for handler in router.on_startup:
|
||||
self.add_event_handler("startup", handler)
|
||||
for handler in router.on_shutdown:
|
||||
self.add_event_handler("shutdown", handler)
|
||||
|
||||
def get(
|
||||
self,
|
||||
|
|
@ -1226,3 +1572,100 @@ class APIRouter(routing.Router):
|
|||
openapi_extra=openapi_extra,
|
||||
generate_unique_id_function=generate_unique_id_function,
|
||||
)
|
||||
|
||||
|
||||
class APIMount(routing.Mount):
|
||||
def __init__(
|
||||
self,
|
||||
router: APIRouter,
|
||||
*,
|
||||
name: Optional[str] = None,
|
||||
parent_router: Optional[APIRouter] = None,
|
||||
) -> None:
|
||||
self.name = name # type: ignore # in Starlette
|
||||
self.parent_router = parent_router
|
||||
self.router = router
|
||||
|
||||
self.setup()
|
||||
|
||||
def setup(self) -> None:
|
||||
self.app: APIRouter = self.router.copy()
|
||||
if self.parent_router:
|
||||
self.app.parent_router = self.parent_router
|
||||
self.app.setup()
|
||||
self.path = self.app.prefix
|
||||
self.path_regex, self.path_format, self.param_convertors = compile_path(
|
||||
self.path + "/{path:path}"
|
||||
)
|
||||
|
||||
# Add custom additional root without trailing slash for compatibility with
|
||||
# include_router and possibly app migrations
|
||||
# Ref: https://github.com/tiangolo/fastapi/issues/414
|
||||
(
|
||||
self._root_path_regex,
|
||||
self._root_path_format,
|
||||
self._root_param_convertors,
|
||||
) = compile_path(self.path)
|
||||
(
|
||||
self._root_path_regex_trailing,
|
||||
self._root_path_format_trailing,
|
||||
self._root_param_convertors_trailing,
|
||||
) = compile_path(self.path + "/")
|
||||
|
||||
def copy(self: APIMountType) -> APIMountType:
|
||||
return type(self)(
|
||||
router=self.router.copy(),
|
||||
name=self.name,
|
||||
parent_router=self.parent_router,
|
||||
)
|
||||
|
||||
def matches(self, scope: Scope) -> Tuple[Match, Scope]:
|
||||
if scope["type"] in ("http", "websocket"):
|
||||
path = scope["path"]
|
||||
if self.app._router_has_empty_route:
|
||||
# Custom logic to support paths without trailing slash
|
||||
# Ref: https://github.com/tiangolo/fastapi/issues/414
|
||||
# This mixes the code in
|
||||
# starlette.routing.Route.matches() and starlette.routing.Mount.matches()
|
||||
match = self._root_path_regex.match(path)
|
||||
if match:
|
||||
matched_params = match.groupdict()
|
||||
for key, value in matched_params.items():
|
||||
matched_params[key] = self.param_convertors[key].convert(value)
|
||||
path_params = dict(scope.get("path_params", {}))
|
||||
path_params.update(matched_params)
|
||||
root_path = scope.get("root_path", "")
|
||||
child_scope = {
|
||||
"path_params": path_params,
|
||||
"app_root_path": scope.get("app_root_path", root_path),
|
||||
"root_path": root_path,
|
||||
"path": "",
|
||||
"endpoint": self.app,
|
||||
}
|
||||
return Match.FULL, child_scope
|
||||
if not self.app._router_has_root_route:
|
||||
match = self._root_path_regex_trailing.match(path)
|
||||
if match:
|
||||
return Match.NONE, {}
|
||||
# End of custom logic
|
||||
# Duplicated code from Starlette
|
||||
match = self.path_regex.match(path)
|
||||
if match:
|
||||
matched_params = match.groupdict()
|
||||
for key, value in matched_params.items():
|
||||
matched_params[key] = self.param_convertors[key].convert(value)
|
||||
remaining_path = "/" + matched_params.pop("path")
|
||||
matched_path = path[: -len(remaining_path)]
|
||||
path_params = dict(scope.get("path_params", {}))
|
||||
path_params.update(matched_params)
|
||||
root_path = scope.get("root_path", "")
|
||||
child_scope = {
|
||||
"path_params": path_params,
|
||||
"app_root_path": scope.get("app_root_path", root_path),
|
||||
"root_path": root_path + matched_path,
|
||||
"path": remaining_path,
|
||||
"endpoint": self.app,
|
||||
}
|
||||
return Match.FULL, child_scope
|
||||
return Match.NONE, {}
|
||||
# End of duplicated code from Starlette
|
||||
|
|
|
|||
|
|
@ -139,7 +139,7 @@ def generate_operation_id_for_path(
|
|||
|
||||
|
||||
def generate_unique_id(route: "APIRoute") -> str:
|
||||
operation_id = route.name + route.path_format
|
||||
operation_id = route.name + route._route_full_path_format
|
||||
operation_id = re.sub("[^0-9a-zA-Z_]", "_", operation_id)
|
||||
assert route.methods
|
||||
operation_id = operation_id + "_" + list(route.methods)[0].lower()
|
||||
|
|
|
|||
|
|
@ -107,9 +107,9 @@ def test_get_path(path, expected_status, expected_response):
|
|||
|
||||
def test_route_classes():
|
||||
routes = {}
|
||||
for r in app.router.routes:
|
||||
assert isinstance(r, Route)
|
||||
routes[r.path] = r
|
||||
for r in app.router.iter_all_routes():
|
||||
if isinstance(r, APIRoute):
|
||||
routes[r._route_full_path_format] = r
|
||||
assert getattr(routes["/a/"], "x_type") == "A"
|
||||
assert getattr(routes["/a/b/"], "x_type") == "B"
|
||||
assert getattr(routes["/a/b/c/"], "x_type") == "C"
|
||||
|
|
|
|||
Loading…
Reference in New Issue