mirror of https://github.com/tiangolo/fastapi.git
Merge 42d4db860e into 0127069d47
This commit is contained in:
commit
7f4d0c063f
|
|
@ -0,0 +1,79 @@
|
|||
# Router-Level Exception Handlers { #router-level-exception-handlers }
|
||||
|
||||
In the [Tutorial - Handling Errors](../tutorial/handling-errors.md){.internal-link target=_blank} you learned how to add custom exception handlers to your **FastAPI** application using `@app.exception_handler()`.
|
||||
|
||||
Those handlers apply **globally** to all routes in the application.
|
||||
|
||||
But sometimes you want to handle exceptions differently depending on which part of your application raised them. For example, a group of routes related to payments might need different error handling than routes for user profiles.
|
||||
|
||||
You can do this by adding **exception handlers directly to an `APIRouter`**.
|
||||
|
||||
## Add Exception Handlers to a Router { #add-exception-handlers-to-a-router }
|
||||
|
||||
You can pass `exception_handlers` when creating an `APIRouter`, using the same format as with the `FastAPI` app — a dictionary where keys are exception classes (or status codes) and values are handler functions:
|
||||
|
||||
{* ../../docs_src/handling_errors/tutorial007_py310.py hl[10:20] *}
|
||||
|
||||
Now, if a `UnicornException` is raised in any route within this router, the router's handler will catch it.
|
||||
|
||||
Routes outside this router are **not affected** by the router's exception handlers.
|
||||
|
||||
## Router Handlers Override App Handlers { #router-handlers-override-app-handlers }
|
||||
|
||||
If both the app and a router define a handler for the same exception, the **router's handler takes priority** for routes within that router. The app-level handler still applies to all other routes.
|
||||
|
||||
{* ../../docs_src/handling_errors/tutorial008_py310.py hl[13:21,44:49] *}
|
||||
|
||||
In this example:
|
||||
|
||||
* A request to `/magic/unicorns/yolo` uses the **global** handler (the `magic_router` doesn't define its own).
|
||||
* A request to `/special/unicorns/yolo` uses the **router-level** handler (defined on the `special_router`).
|
||||
|
||||
This lets you customize error handling per section of your API while keeping a sensible default at the app level.
|
||||
|
||||
## Using `add_exception_handler` { #using-add-exception-handler }
|
||||
|
||||
You can also add exception handlers to a router after creation, using the `add_exception_handler()` method:
|
||||
|
||||
```python
|
||||
router = APIRouter()
|
||||
router.add_exception_handler(UnicornException, unicorn_exception_handler)
|
||||
```
|
||||
|
||||
This works the same as passing `exception_handlers` in the constructor.
|
||||
|
||||
## Nested Router Precedence { #nested-router-precedence }
|
||||
|
||||
When routers are nested (a router includes another router), exception handlers follow this precedence order (highest to lowest):
|
||||
|
||||
1. The **innermost (child) router**'s handlers
|
||||
2. The **parent router**'s handlers
|
||||
3. The **app-level** handlers
|
||||
|
||||
This means a child router can override its parent's handlers for its own routes.
|
||||
|
||||
## Status Code Handlers { #status-code-handlers }
|
||||
|
||||
Just like app-level handlers, you can use **integer status codes** as keys to handle specific HTTP error responses:
|
||||
|
||||
```python
|
||||
from starlette.exceptions import HTTPException
|
||||
|
||||
def not_found_handler(request, exc):
|
||||
return JSONResponse(
|
||||
status_code=404,
|
||||
content={"message": "Custom 404: resource not found"},
|
||||
)
|
||||
|
||||
router = APIRouter(
|
||||
exception_handlers={404: not_found_handler}
|
||||
)
|
||||
```
|
||||
|
||||
## Recap { #recap }
|
||||
|
||||
* Pass `exception_handlers` to `APIRouter()` to scope handlers to that router's routes.
|
||||
* Use `router.add_exception_handler()` to add handlers after creation.
|
||||
* Router-level handlers **override** app-level handlers for the same exception type.
|
||||
* Routes **outside** the router are unaffected.
|
||||
* Nested routers follow **child > parent > app** precedence.
|
||||
|
|
@ -12,6 +12,7 @@ from fastapi import APIRouter
|
|||
options:
|
||||
members:
|
||||
- websocket
|
||||
- add_exception_handler
|
||||
- include_router
|
||||
- get
|
||||
- put
|
||||
|
|
|
|||
|
|
@ -242,3 +242,11 @@ If you want to use the exception along with the same default exception handlers
|
|||
{* ../../docs_src/handling_errors/tutorial006_py310.py hl[2:5,15,21] *}
|
||||
|
||||
In this example you are just printing the error with a very expressive message, but you get the idea. You can use the exception and then just reuse the default exception handlers.
|
||||
|
||||
## Router-Level Exception Handlers { #router-level-exception-handlers }
|
||||
|
||||
The examples above add exception handlers to the **application**, so they apply to all routes.
|
||||
|
||||
You can also add exception handlers scoped to a specific **`APIRouter`**, so they only apply to the routes in that router.
|
||||
|
||||
Read more about it in [Advanced User Guide - Router-Level Exception Handlers](../advanced/router-exception-handlers.md){.internal-link target=_blank}.
|
||||
|
|
|
|||
|
|
@ -173,6 +173,7 @@ nav:
|
|||
- advanced/response-headers.md
|
||||
- advanced/response-change-status-code.md
|
||||
- advanced/advanced-dependencies.md
|
||||
- advanced/router-exception-handlers.md
|
||||
- Advanced Security:
|
||||
- advanced/security/index.md
|
||||
- advanced/security/oauth2-scopes.md
|
||||
|
|
|
|||
|
|
@ -0,0 +1,31 @@
|
|||
from fastapi import APIRouter, FastAPI
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
|
||||
class UnicornException(Exception):
|
||||
def __init__(self, name: str):
|
||||
self.name = name
|
||||
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/unicorns",
|
||||
exception_handlers={
|
||||
UnicornException: lambda request, exc: JSONResponse(
|
||||
status_code=418,
|
||||
content={
|
||||
"message": f"Oops! {exc.name} did something. There goes a rainbow..."
|
||||
},
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{name}")
|
||||
async def read_unicorn(name: str):
|
||||
if name == "yolo":
|
||||
raise UnicornException(name=name)
|
||||
return {"unicorn_name": name}
|
||||
|
||||
|
||||
app = FastAPI()
|
||||
app.include_router(router)
|
||||
|
|
@ -0,0 +1,52 @@
|
|||
from fastapi import APIRouter, FastAPI, Request
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
|
||||
class UnicornException(Exception):
|
||||
def __init__(self, name: str):
|
||||
self.name = name
|
||||
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
|
||||
@app.exception_handler(UnicornException)
|
||||
async def global_unicorn_exception_handler(request: Request, exc: UnicornException):
|
||||
return JSONResponse(
|
||||
status_code=418,
|
||||
content={"message": f"Global handler: {exc.name} did something."},
|
||||
)
|
||||
|
||||
|
||||
magic_router = APIRouter(prefix="/magic")
|
||||
|
||||
|
||||
@magic_router.get("/unicorns/{name}")
|
||||
async def read_magic_unicorn(name: str):
|
||||
if name == "yolo":
|
||||
raise UnicornException(name=name)
|
||||
return {"unicorn_name": name}
|
||||
|
||||
|
||||
def custom_unicorn_handler(request: Request, exc: UnicornException):
|
||||
return JSONResponse(
|
||||
status_code=418,
|
||||
content={"message": f"Special handler: {exc.name} did something magical!"},
|
||||
)
|
||||
|
||||
|
||||
special_router = APIRouter(
|
||||
prefix="/special",
|
||||
exception_handlers={UnicornException: custom_unicorn_handler},
|
||||
)
|
||||
|
||||
|
||||
@special_router.get("/unicorns/{name}")
|
||||
async def read_special_unicorn(name: str):
|
||||
if name == "yolo":
|
||||
raise UnicornException(name=name)
|
||||
return {"unicorn_name": name}
|
||||
|
||||
|
||||
app.include_router(magic_router)
|
||||
app.include_router(special_router)
|
||||
|
|
@ -96,6 +96,7 @@ from typing_extensions import deprecated
|
|||
# dependencies' AsyncExitStack
|
||||
def request_response(
|
||||
func: Callable[[Request], Awaitable[Response] | Response],
|
||||
exception_handlers: dict[int | type[Exception], Callable[..., Any]] | None = None,
|
||||
) -> ASGIApp:
|
||||
"""
|
||||
Takes a function or coroutine `func(request) -> response`,
|
||||
|
|
@ -108,6 +109,18 @@ def request_response(
|
|||
) # ty: ignore[invalid-assignment]
|
||||
|
||||
async def app(scope: Scope, receive: Receive, send: Send) -> None:
|
||||
if exception_handlers:
|
||||
existing_exc, existing_status = scope.get(
|
||||
"starlette.exception_handlers", ({}, {})
|
||||
)
|
||||
merged_exc = {**existing_exc}
|
||||
merged_status = {**existing_status}
|
||||
for key, handler in exception_handlers.items():
|
||||
if isinstance(key, int):
|
||||
merged_status[key] = handler
|
||||
else:
|
||||
merged_exc[key] = handler
|
||||
scope["starlette.exception_handlers"] = (merged_exc, merged_status)
|
||||
request = Request(scope, receive, send)
|
||||
|
||||
async def app(scope: Scope, receive: Receive, send: Send) -> None:
|
||||
|
|
@ -140,6 +153,7 @@ def request_response(
|
|||
# dependencies' AsyncExitStack
|
||||
def websocket_session(
|
||||
func: Callable[[WebSocket], Awaitable[None]],
|
||||
exception_handlers: dict[int | type[Exception], Callable[..., Any]] | None = None,
|
||||
) -> ASGIApp:
|
||||
"""
|
||||
Takes a coroutine `func(session)`, and returns an ASGI application.
|
||||
|
|
@ -147,6 +161,18 @@ def websocket_session(
|
|||
# assert asyncio.iscoroutinefunction(func), "WebSocket endpoints must be async"
|
||||
|
||||
async def app(scope: Scope, receive: Receive, send: Send) -> None:
|
||||
if exception_handlers:
|
||||
existing_exc, existing_status = scope.get(
|
||||
"starlette.exception_handlers", ({}, {})
|
||||
)
|
||||
merged_exc = {**existing_exc}
|
||||
merged_status = {**existing_status}
|
||||
for key, handler in exception_handlers.items():
|
||||
if isinstance(key, int):
|
||||
merged_status[key] = handler
|
||||
else:
|
||||
merged_exc[key] = handler
|
||||
scope["starlette.exception_handlers"] = (merged_exc, merged_status)
|
||||
session = WebSocket(scope, receive=receive, send=send)
|
||||
|
||||
async def app(scope: Scope, receive: Receive, send: Send) -> None:
|
||||
|
|
@ -775,11 +801,14 @@ class APIWebSocketRoute(routing.WebSocketRoute):
|
|||
name: str | None = None,
|
||||
dependencies: Sequence[params.Depends] | None = None,
|
||||
dependency_overrides_provider: Any | None = None,
|
||||
exception_handlers: dict[int | type[Exception], Callable[..., Any]]
|
||||
| None = None,
|
||||
) -> None:
|
||||
self.path = path
|
||||
self.endpoint = endpoint
|
||||
self.name = get_name(endpoint) if name is None else name
|
||||
self.dependencies = list(dependencies or [])
|
||||
self.exception_handlers = exception_handlers or {}
|
||||
self.path_regex, self.path_format, self.param_convertors = compile_path(path)
|
||||
self.dependant = get_dependant(
|
||||
path=self.path_format, call=self.endpoint, scope="function"
|
||||
|
|
@ -798,7 +827,8 @@ class APIWebSocketRoute(routing.WebSocketRoute):
|
|||
dependant=self.dependant,
|
||||
dependency_overrides_provider=dependency_overrides_provider,
|
||||
embed_body_fields=self._embed_body_fields,
|
||||
)
|
||||
),
|
||||
exception_handlers=self.exception_handlers or None,
|
||||
)
|
||||
|
||||
def matches(self, scope: Scope) -> tuple[Match, Scope]:
|
||||
|
|
@ -840,9 +870,12 @@ class APIRoute(routing.Route):
|
|||
generate_unique_id_function: Callable[["APIRoute"], str]
|
||||
| DefaultPlaceholder = Default(generate_unique_id),
|
||||
strict_content_type: bool | DefaultPlaceholder = Default(True),
|
||||
exception_handlers: dict[int | type[Exception], Callable[..., Any]]
|
||||
| None = None,
|
||||
) -> None:
|
||||
self.path = path
|
||||
self.endpoint = endpoint
|
||||
self.exception_handlers = exception_handlers or {}
|
||||
self.stream_item_type: Any | None = None
|
||||
if isinstance(response_model, DefaultPlaceholder):
|
||||
return_annotation = get_typed_return_annotation(endpoint)
|
||||
|
|
@ -973,7 +1006,10 @@ class APIRoute(routing.Route):
|
|||
self.is_json_stream = is_generator and isinstance(
|
||||
response_class, DefaultPlaceholder
|
||||
)
|
||||
self.app = request_response(self.get_route_handler())
|
||||
self.app = request_response(
|
||||
self.get_route_handler(),
|
||||
exception_handlers=self.exception_handlers or None,
|
||||
)
|
||||
|
||||
def get_route_handler(self) -> Callable[[Request], Coroutine[Any, Any, Response]]:
|
||||
return get_request_handler(
|
||||
|
|
@ -1266,6 +1302,18 @@ class APIRouter(routing.Router):
|
|||
"""
|
||||
),
|
||||
] = Default(True),
|
||||
exception_handlers: Annotated[
|
||||
dict[int | type[Exception], Callable[..., Any]] | None,
|
||||
Doc(
|
||||
"""
|
||||
A dictionary of exception handlers scoped to this router.
|
||||
|
||||
Keys can be exception classes or HTTP status codes. Handlers
|
||||
defined here will override app-level exception handlers for
|
||||
routes within this router.
|
||||
"""
|
||||
),
|
||||
] = None,
|
||||
) -> None:
|
||||
# Determine the lifespan context to use
|
||||
if lifespan is None:
|
||||
|
|
@ -1313,6 +1361,16 @@ class APIRouter(routing.Router):
|
|||
self.default_response_class = default_response_class
|
||||
self.generate_unique_id_function = generate_unique_id_function
|
||||
self.strict_content_type = strict_content_type
|
||||
self.exception_handlers: dict[int | type[Exception], Callable[..., Any]] = (
|
||||
{} if exception_handlers is None else dict(exception_handlers)
|
||||
)
|
||||
|
||||
def add_exception_handler(
|
||||
self,
|
||||
exc_class_or_status_code: int | type[Exception],
|
||||
handler: Callable[..., Any],
|
||||
) -> None:
|
||||
self.exception_handlers[exc_class_or_status_code] = handler
|
||||
|
||||
def route(
|
||||
self,
|
||||
|
|
@ -1364,6 +1422,8 @@ class APIRouter(routing.Router):
|
|||
generate_unique_id_function: Callable[[APIRoute], str]
|
||||
| DefaultPlaceholder = Default(generate_unique_id),
|
||||
strict_content_type: bool | DefaultPlaceholder = Default(True),
|
||||
exception_handlers: dict[int | type[Exception], Callable[..., Any]]
|
||||
| None = None,
|
||||
) -> None:
|
||||
route_class = route_class_override or self.route_class
|
||||
responses = responses or {}
|
||||
|
|
@ -1383,6 +1443,11 @@ class APIRouter(routing.Router):
|
|||
current_generate_unique_id = get_value_or_default(
|
||||
generate_unique_id_function, self.generate_unique_id_function
|
||||
)
|
||||
current_exception_handlers: dict[int | type[Exception], Callable[..., Any]] = (
|
||||
dict(self.exception_handlers)
|
||||
)
|
||||
if exception_handlers:
|
||||
current_exception_handlers.update(exception_handlers)
|
||||
route = route_class(
|
||||
self.prefix + path,
|
||||
endpoint=endpoint,
|
||||
|
|
@ -1413,6 +1478,7 @@ class APIRouter(routing.Router):
|
|||
strict_content_type=get_value_or_default(
|
||||
strict_content_type, self.strict_content_type
|
||||
),
|
||||
exception_handlers=current_exception_handlers or None,
|
||||
)
|
||||
self.routes.append(route)
|
||||
|
||||
|
|
@ -1485,10 +1551,17 @@ class APIRouter(routing.Router):
|
|||
name: str | None = None,
|
||||
*,
|
||||
dependencies: Sequence[params.Depends] | None = None,
|
||||
exception_handlers: dict[int | type[Exception], Callable[..., Any]]
|
||||
| None = None,
|
||||
) -> None:
|
||||
current_dependencies = self.dependencies.copy()
|
||||
if dependencies:
|
||||
current_dependencies.extend(dependencies)
|
||||
current_exception_handlers: dict[int | type[Exception], Callable[..., Any]] = (
|
||||
dict(self.exception_handlers)
|
||||
)
|
||||
if exception_handlers:
|
||||
current_exception_handlers.update(exception_handlers)
|
||||
|
||||
route = APIWebSocketRoute(
|
||||
self.prefix + path,
|
||||
|
|
@ -1496,6 +1569,7 @@ class APIRouter(routing.Router):
|
|||
name=name,
|
||||
dependencies=current_dependencies,
|
||||
dependency_overrides_provider=self.dependency_overrides_provider,
|
||||
exception_handlers=current_exception_handlers or None,
|
||||
)
|
||||
self.routes.append(route)
|
||||
|
||||
|
|
@ -1759,6 +1833,13 @@ class APIRouter(routing.Router):
|
|||
generate_unique_id_function,
|
||||
self.generate_unique_id_function,
|
||||
)
|
||||
current_exception_handlers: dict[
|
||||
int | type[Exception], Callable[..., Any]
|
||||
] = {}
|
||||
if router.exception_handlers:
|
||||
current_exception_handlers.update(router.exception_handlers)
|
||||
if route.exception_handlers:
|
||||
current_exception_handlers.update(route.exception_handlers)
|
||||
self.add_api_route(
|
||||
prefix + route.path,
|
||||
route.endpoint,
|
||||
|
|
@ -1793,6 +1874,7 @@ class APIRouter(routing.Router):
|
|||
router.strict_content_type,
|
||||
self.strict_content_type,
|
||||
),
|
||||
exception_handlers=current_exception_handlers or None,
|
||||
)
|
||||
elif isinstance(route, routing.Route):
|
||||
methods = list(route.methods or [])
|
||||
|
|
@ -1809,11 +1891,17 @@ class APIRouter(routing.Router):
|
|||
current_dependencies.extend(dependencies)
|
||||
if route.dependencies:
|
||||
current_dependencies.extend(route.dependencies)
|
||||
current_exception_handlers = {}
|
||||
if router.exception_handlers:
|
||||
current_exception_handlers.update(router.exception_handlers)
|
||||
if route.exception_handlers:
|
||||
current_exception_handlers.update(route.exception_handlers)
|
||||
self.add_api_websocket_route(
|
||||
prefix + route.path,
|
||||
route.endpoint,
|
||||
dependencies=current_dependencies,
|
||||
name=route.name,
|
||||
exception_handlers=current_exception_handlers or None,
|
||||
)
|
||||
elif isinstance(route, routing.WebSocketRoute):
|
||||
self.add_websocket_route(
|
||||
|
|
|
|||
|
|
@ -0,0 +1,321 @@
|
|||
from fastapi import APIRouter, FastAPI
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi.testclient import TestClient
|
||||
from fastapi.websockets import WebSocket
|
||||
from starlette.requests import Request
|
||||
|
||||
|
||||
class CustomExcA(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class CustomExcB(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def exc_a_handler(request: Request, exc: CustomExcA) -> JSONResponse:
|
||||
return JSONResponse({"error": "handled_a"}, status_code=400)
|
||||
|
||||
|
||||
def exc_b_handler(request: Request, exc: CustomExcB) -> JSONResponse:
|
||||
return JSONResponse({"error": "handled_b"}, status_code=400)
|
||||
|
||||
|
||||
def test_basic_router_exception_handler():
|
||||
app = FastAPI()
|
||||
router = APIRouter(exception_handlers={CustomExcA: exc_a_handler})
|
||||
|
||||
@router.get("/fail")
|
||||
def fail():
|
||||
raise CustomExcA()
|
||||
|
||||
app.include_router(router)
|
||||
client = TestClient(app)
|
||||
resp = client.get("/fail")
|
||||
assert resp.status_code == 400
|
||||
assert resp.json() == {"error": "handled_a"}
|
||||
|
||||
|
||||
def test_isolation_between_routers():
|
||||
app = FastAPI()
|
||||
router1 = APIRouter(prefix="/r1", exception_handlers={CustomExcA: exc_a_handler})
|
||||
router2 = APIRouter(prefix="/r2")
|
||||
|
||||
@router1.get("/fail")
|
||||
def fail1():
|
||||
raise CustomExcA()
|
||||
|
||||
@router2.get("/fail")
|
||||
def fail2():
|
||||
raise CustomExcA()
|
||||
|
||||
app.include_router(router1)
|
||||
app.include_router(router2)
|
||||
client = TestClient(app, raise_server_exceptions=False)
|
||||
|
||||
resp1 = client.get("/r1/fail")
|
||||
assert resp1.status_code == 400
|
||||
assert resp1.json() == {"error": "handled_a"}
|
||||
|
||||
resp2 = client.get("/r2/fail")
|
||||
assert resp2.status_code == 500
|
||||
|
||||
|
||||
def test_router_overrides_app():
|
||||
def app_handler(request: Request, exc: CustomExcA) -> JSONResponse:
|
||||
return JSONResponse({"error": "app_handled"}, status_code=400)
|
||||
|
||||
def router_handler(request: Request, exc: CustomExcA) -> JSONResponse:
|
||||
return JSONResponse({"error": "router_handled"}, status_code=400)
|
||||
|
||||
app = FastAPI(exception_handlers={CustomExcA: app_handler})
|
||||
router = APIRouter(prefix="/r", exception_handlers={CustomExcA: router_handler})
|
||||
|
||||
@router.get("/fail")
|
||||
def fail():
|
||||
raise CustomExcA()
|
||||
|
||||
@app.get("/app-fail")
|
||||
def app_fail():
|
||||
raise CustomExcA()
|
||||
|
||||
app.include_router(router)
|
||||
client = TestClient(app)
|
||||
|
||||
resp = client.get("/r/fail")
|
||||
assert resp.status_code == 400
|
||||
assert resp.json() == {"error": "router_handled"}
|
||||
|
||||
resp_app = client.get("/app-fail")
|
||||
assert resp_app.status_code == 400
|
||||
assert resp_app.json() == {"error": "app_handled"}
|
||||
|
||||
|
||||
def test_app_fallback():
|
||||
def app_handler(request: Request, exc: CustomExcA) -> JSONResponse:
|
||||
return JSONResponse({"error": "app_handled"}, status_code=400)
|
||||
|
||||
app = FastAPI(exception_handlers={CustomExcA: app_handler})
|
||||
router = APIRouter(prefix="/r")
|
||||
|
||||
@router.get("/fail")
|
||||
def fail():
|
||||
raise CustomExcA()
|
||||
|
||||
app.include_router(router)
|
||||
client = TestClient(app)
|
||||
|
||||
resp = client.get("/r/fail")
|
||||
assert resp.status_code == 400
|
||||
assert resp.json() == {"error": "app_handled"}
|
||||
|
||||
|
||||
def test_nested_router_precedence():
|
||||
def handler_a(request: Request, exc: CustomExcA) -> JSONResponse:
|
||||
return JSONResponse({"error": "A"}, status_code=400)
|
||||
|
||||
def handler_b(request: Request, exc: CustomExcA) -> JSONResponse:
|
||||
return JSONResponse({"error": "B"}, status_code=400)
|
||||
|
||||
router_a = APIRouter(prefix="/a", exception_handlers={CustomExcA: handler_a})
|
||||
router_b = APIRouter(prefix="/b", exception_handlers={CustomExcA: handler_b})
|
||||
|
||||
@router_a.get("/fail")
|
||||
def fail_a():
|
||||
raise CustomExcA()
|
||||
|
||||
@router_b.get("/fail")
|
||||
def fail_b():
|
||||
raise CustomExcA()
|
||||
|
||||
router_a.include_router(router_b)
|
||||
app = FastAPI()
|
||||
app.include_router(router_a)
|
||||
client = TestClient(app)
|
||||
|
||||
resp_b = client.get("/a/b/fail")
|
||||
assert resp_b.status_code == 400
|
||||
assert resp_b.json() == {"error": "B"}
|
||||
|
||||
resp_a = client.get("/a/fail")
|
||||
assert resp_a.status_code == 400
|
||||
assert resp_a.json() == {"error": "A"}
|
||||
|
||||
|
||||
def test_deep_nesting():
|
||||
def handler_a(request: Request, exc: CustomExcA) -> JSONResponse:
|
||||
return JSONResponse({"error": "A"}, status_code=400)
|
||||
|
||||
def handler_b(request: Request, exc: CustomExcA) -> JSONResponse:
|
||||
return JSONResponse({"error": "B"}, status_code=400)
|
||||
|
||||
def handler_c(request: Request, exc: CustomExcA) -> JSONResponse:
|
||||
return JSONResponse({"error": "C"}, status_code=400)
|
||||
|
||||
router_a = APIRouter(prefix="/a", exception_handlers={CustomExcA: handler_a})
|
||||
router_b = APIRouter(prefix="/b", exception_handlers={CustomExcA: handler_b})
|
||||
router_c = APIRouter(prefix="/c", exception_handlers={CustomExcA: handler_c})
|
||||
|
||||
@router_a.get("/fail")
|
||||
def fail_a():
|
||||
raise CustomExcA()
|
||||
|
||||
@router_b.get("/fail")
|
||||
def fail_b():
|
||||
raise CustomExcA()
|
||||
|
||||
@router_c.get("/fail")
|
||||
def fail_c():
|
||||
raise CustomExcA()
|
||||
|
||||
router_b.include_router(router_c)
|
||||
router_a.include_router(router_b)
|
||||
app = FastAPI()
|
||||
app.include_router(router_a)
|
||||
client = TestClient(app)
|
||||
|
||||
resp_c = client.get("/a/b/c/fail")
|
||||
assert resp_c.status_code == 400
|
||||
assert resp_c.json() == {"error": "C"}
|
||||
|
||||
resp_b = client.get("/a/b/fail")
|
||||
assert resp_b.status_code == 400
|
||||
assert resp_b.json() == {"error": "B"}
|
||||
|
||||
resp_a = client.get("/a/fail")
|
||||
assert resp_a.status_code == 400
|
||||
assert resp_a.json() == {"error": "A"}
|
||||
|
||||
|
||||
def test_add_exception_handler_method():
|
||||
app = FastAPI()
|
||||
router = APIRouter()
|
||||
router.add_exception_handler(CustomExcA, exc_a_handler)
|
||||
|
||||
@router.get("/fail")
|
||||
def fail():
|
||||
raise CustomExcA()
|
||||
|
||||
app.include_router(router)
|
||||
client = TestClient(app)
|
||||
|
||||
resp = client.get("/fail")
|
||||
assert resp.status_code == 400
|
||||
assert resp.json() == {"error": "handled_a"}
|
||||
|
||||
|
||||
def test_status_code_handler():
|
||||
from starlette.exceptions import HTTPException
|
||||
|
||||
def not_found_handler(request: Request, exc: HTTPException) -> JSONResponse:
|
||||
return JSONResponse({"error": "custom_404"}, status_code=404)
|
||||
|
||||
app = FastAPI()
|
||||
router = APIRouter(prefix="/r", exception_handlers={404: not_found_handler})
|
||||
|
||||
@router.get("/fail")
|
||||
def fail():
|
||||
raise HTTPException(status_code=404, detail="not found")
|
||||
|
||||
app.include_router(router)
|
||||
client = TestClient(app)
|
||||
|
||||
resp = client.get("/r/fail")
|
||||
assert resp.status_code == 404
|
||||
assert resp.json() == {"error": "custom_404"}
|
||||
|
||||
|
||||
def test_app_routes_unaffected():
|
||||
app = FastAPI()
|
||||
router = APIRouter(prefix="/r", exception_handlers={CustomExcA: exc_a_handler})
|
||||
|
||||
@router.get("/fail")
|
||||
def router_fail():
|
||||
raise CustomExcA()
|
||||
|
||||
@app.get("/fail")
|
||||
def app_fail():
|
||||
raise CustomExcA()
|
||||
|
||||
app.include_router(router)
|
||||
client = TestClient(app, raise_server_exceptions=False)
|
||||
|
||||
resp_router = client.get("/r/fail")
|
||||
assert resp_router.status_code == 400
|
||||
assert resp_router.json() == {"error": "handled_a"}
|
||||
|
||||
resp_app = client.get("/fail")
|
||||
assert resp_app.status_code == 500
|
||||
|
||||
|
||||
def test_multiple_exception_types():
|
||||
app = FastAPI()
|
||||
router = APIRouter(
|
||||
exception_handlers={
|
||||
CustomExcA: exc_a_handler,
|
||||
CustomExcB: exc_b_handler,
|
||||
}
|
||||
)
|
||||
|
||||
@router.get("/fail-a")
|
||||
def fail_a():
|
||||
raise CustomExcA()
|
||||
|
||||
@router.get("/fail-b")
|
||||
def fail_b():
|
||||
raise CustomExcB()
|
||||
|
||||
app.include_router(router)
|
||||
client = TestClient(app)
|
||||
|
||||
resp_a = client.get("/fail-a")
|
||||
assert resp_a.status_code == 400
|
||||
assert resp_a.json() == {"error": "handled_a"}
|
||||
|
||||
resp_b = client.get("/fail-b")
|
||||
assert resp_b.status_code == 400
|
||||
assert resp_b.json() == {"error": "handled_b"}
|
||||
|
||||
|
||||
def test_websocket_router_exception_handler():
|
||||
app = FastAPI()
|
||||
router = APIRouter(
|
||||
exception_handlers={CustomExcA: exc_a_handler, 404: exc_a_handler}
|
||||
)
|
||||
|
||||
@router.websocket("/ws")
|
||||
async def ws_endpoint(websocket: WebSocket):
|
||||
await websocket.accept()
|
||||
raise CustomExcA()
|
||||
|
||||
app.include_router(router)
|
||||
client = TestClient(app)
|
||||
|
||||
with client.websocket_connect("/ws") as ws:
|
||||
# The exception handler returns a JSON response, but for websockets
|
||||
# the connection should close with an error
|
||||
try:
|
||||
ws.receive_text()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def test_websocket_nested_router_exception_handler():
|
||||
app = FastAPI()
|
||||
router_a = APIRouter(prefix="/a", exception_handlers={CustomExcA: exc_a_handler})
|
||||
router_b = APIRouter(prefix="/b", exception_handlers={CustomExcA: exc_a_handler})
|
||||
|
||||
@router_b.websocket("/ws")
|
||||
async def ws_endpoint(websocket: WebSocket):
|
||||
await websocket.accept()
|
||||
raise CustomExcA()
|
||||
|
||||
router_a.include_router(router_b)
|
||||
app.include_router(router_a)
|
||||
client = TestClient(app)
|
||||
|
||||
with client.websocket_connect("/a/b/ws") as ws:
|
||||
try:
|
||||
ws.receive_text()
|
||||
except Exception:
|
||||
pass
|
||||
|
|
@ -0,0 +1,101 @@
|
|||
from fastapi.testclient import TestClient
|
||||
from inline_snapshot import snapshot
|
||||
|
||||
from docs_src.handling_errors.tutorial007_py310 import app
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
|
||||
def test_get():
|
||||
response = client.get("/unicorns/shinny")
|
||||
assert response.status_code == 200, response.text
|
||||
assert response.json() == {"unicorn_name": "shinny"}
|
||||
|
||||
|
||||
def test_get_exception():
|
||||
response = client.get("/unicorns/yolo")
|
||||
assert response.status_code == 418, response.text
|
||||
assert response.json() == {
|
||||
"message": "Oops! yolo did something. There goes a rainbow..."
|
||||
}
|
||||
|
||||
|
||||
def test_openapi_schema():
|
||||
response = client.get("/openapi.json")
|
||||
assert response.status_code == 200, response.text
|
||||
assert response.json() == snapshot(
|
||||
{
|
||||
"openapi": "3.1.0",
|
||||
"info": {"title": "FastAPI", "version": "0.1.0"},
|
||||
"paths": {
|
||||
"/unicorns/{name}": {
|
||||
"get": {
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Successful Response",
|
||||
"content": {"application/json": {"schema": {}}},
|
||||
},
|
||||
"422": {
|
||||
"description": "Validation Error",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/HTTPValidationError"
|
||||
}
|
||||
}
|
||||
},
|
||||
},
|
||||
},
|
||||
"summary": "Read Unicorn",
|
||||
"operationId": "read_unicorn_unicorns__name__get",
|
||||
"parameters": [
|
||||
{
|
||||
"required": True,
|
||||
"schema": {"title": "Name", "type": "string"},
|
||||
"name": "name",
|
||||
"in": "path",
|
||||
}
|
||||
],
|
||||
}
|
||||
}
|
||||
},
|
||||
"components": {
|
||||
"schemas": {
|
||||
"ValidationError": {
|
||||
"title": "ValidationError",
|
||||
"required": ["loc", "msg", "type"],
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"loc": {
|
||||
"title": "Location",
|
||||
"type": "array",
|
||||
"items": {
|
||||
"anyOf": [
|
||||
{"type": "string"},
|
||||
{"type": "integer"},
|
||||
]
|
||||
},
|
||||
},
|
||||
"msg": {"title": "Message", "type": "string"},
|
||||
"type": {"title": "Error Type", "type": "string"},
|
||||
"input": {"title": "Input"},
|
||||
"ctx": {"title": "Context", "type": "object"},
|
||||
},
|
||||
},
|
||||
"HTTPValidationError": {
|
||||
"title": "HTTPValidationError",
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"detail": {
|
||||
"title": "Detail",
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/components/schemas/ValidationError"
|
||||
},
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
)
|
||||
|
|
@ -0,0 +1,143 @@
|
|||
from fastapi.testclient import TestClient
|
||||
from inline_snapshot import snapshot
|
||||
|
||||
from docs_src.handling_errors.tutorial008_py310 import app
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
|
||||
def test_get_magic():
|
||||
response = client.get("/magic/unicorns/shinny")
|
||||
assert response.status_code == 200, response.text
|
||||
assert response.json() == {"unicorn_name": "shinny"}
|
||||
|
||||
|
||||
def test_get_magic_exception():
|
||||
response = client.get("/magic/unicorns/yolo")
|
||||
assert response.status_code == 418, response.text
|
||||
assert response.json() == {"message": "Global handler: yolo did something."}
|
||||
|
||||
|
||||
def test_get_special():
|
||||
response = client.get("/special/unicorns/shinny")
|
||||
assert response.status_code == 200, response.text
|
||||
assert response.json() == {"unicorn_name": "shinny"}
|
||||
|
||||
|
||||
def test_get_special_exception():
|
||||
response = client.get("/special/unicorns/yolo")
|
||||
assert response.status_code == 418, response.text
|
||||
assert response.json() == {
|
||||
"message": "Special handler: yolo did something magical!"
|
||||
}
|
||||
|
||||
|
||||
def test_openapi_schema():
|
||||
response = client.get("/openapi.json")
|
||||
assert response.status_code == 200, response.text
|
||||
assert response.json() == snapshot(
|
||||
{
|
||||
"openapi": "3.1.0",
|
||||
"info": {"title": "FastAPI", "version": "0.1.0"},
|
||||
"paths": {
|
||||
"/magic/unicorns/{name}": {
|
||||
"get": {
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Successful Response",
|
||||
"content": {"application/json": {"schema": {}}},
|
||||
},
|
||||
"422": {
|
||||
"description": "Validation Error",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/HTTPValidationError"
|
||||
}
|
||||
}
|
||||
},
|
||||
},
|
||||
},
|
||||
"summary": "Read Magic Unicorn",
|
||||
"operationId": "read_magic_unicorn_magic_unicorns__name__get",
|
||||
"parameters": [
|
||||
{
|
||||
"required": True,
|
||||
"schema": {"title": "Name", "type": "string"},
|
||||
"name": "name",
|
||||
"in": "path",
|
||||
}
|
||||
],
|
||||
}
|
||||
},
|
||||
"/special/unicorns/{name}": {
|
||||
"get": {
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Successful Response",
|
||||
"content": {"application/json": {"schema": {}}},
|
||||
},
|
||||
"422": {
|
||||
"description": "Validation Error",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/HTTPValidationError"
|
||||
}
|
||||
}
|
||||
},
|
||||
},
|
||||
},
|
||||
"summary": "Read Special Unicorn",
|
||||
"operationId": "read_special_unicorn_special_unicorns__name__get",
|
||||
"parameters": [
|
||||
{
|
||||
"required": True,
|
||||
"schema": {"title": "Name", "type": "string"},
|
||||
"name": "name",
|
||||
"in": "path",
|
||||
}
|
||||
],
|
||||
}
|
||||
},
|
||||
},
|
||||
"components": {
|
||||
"schemas": {
|
||||
"ValidationError": {
|
||||
"title": "ValidationError",
|
||||
"required": ["loc", "msg", "type"],
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"loc": {
|
||||
"title": "Location",
|
||||
"type": "array",
|
||||
"items": {
|
||||
"anyOf": [
|
||||
{"type": "string"},
|
||||
{"type": "integer"},
|
||||
]
|
||||
},
|
||||
},
|
||||
"msg": {"title": "Message", "type": "string"},
|
||||
"type": {"title": "Error Type", "type": "string"},
|
||||
"input": {"title": "Input"},
|
||||
"ctx": {"title": "Context", "type": "object"},
|
||||
},
|
||||
},
|
||||
"HTTPValidationError": {
|
||||
"title": "HTTPValidationError",
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"detail": {
|
||||
"title": "Detail",
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/components/schemas/ValidationError"
|
||||
},
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
)
|
||||
Loading…
Reference in New Issue