diff --git a/docs/en/docs/advanced/router-exception-handlers.md b/docs/en/docs/advanced/router-exception-handlers.md new file mode 100644 index 0000000000..fd6eebb482 --- /dev/null +++ b/docs/en/docs/advanced/router-exception-handlers.md @@ -0,0 +1,79 @@ +# Router-Level Exception Handlers { #router-level-exception-handlers } + +In the [Tutorial - Handling Errors](../tutorial/handling-errors.md){.internal-link target=_blank} you learned how to add custom exception handlers to your **FastAPI** application using `@app.exception_handler()`. + +Those handlers apply **globally** to all routes in the application. + +But sometimes you want to handle exceptions differently depending on which part of your application raised them. For example, a group of routes related to payments might need different error handling than routes for user profiles. + +You can do this by adding **exception handlers directly to an `APIRouter`**. + +## Add Exception Handlers to a Router { #add-exception-handlers-to-a-router } + +You can pass `exception_handlers` when creating an `APIRouter`, using the same format as with the `FastAPI` app — a dictionary where keys are exception classes (or status codes) and values are handler functions: + +{* ../../docs_src/handling_errors/tutorial007_py310.py hl[10:20] *} + +Now, if a `UnicornException` is raised in any route within this router, the router's handler will catch it. + +Routes outside this router are **not affected** by the router's exception handlers. + +## Router Handlers Override App Handlers { #router-handlers-override-app-handlers } + +If both the app and a router define a handler for the same exception, the **router's handler takes priority** for routes within that router. The app-level handler still applies to all other routes. + +{* ../../docs_src/handling_errors/tutorial008_py310.py hl[13:21,44:49] *} + +In this example: + +* A request to `/magic/unicorns/yolo` uses the **global** handler (the `magic_router` doesn't define its own). +* A request to `/special/unicorns/yolo` uses the **router-level** handler (defined on the `special_router`). + +This lets you customize error handling per section of your API while keeping a sensible default at the app level. + +## Using `add_exception_handler` { #using-add-exception-handler } + +You can also add exception handlers to a router after creation, using the `add_exception_handler()` method: + +```python +router = APIRouter() +router.add_exception_handler(UnicornException, unicorn_exception_handler) +``` + +This works the same as passing `exception_handlers` in the constructor. + +## Nested Router Precedence { #nested-router-precedence } + +When routers are nested (a router includes another router), exception handlers follow this precedence order (highest to lowest): + +1. The **innermost (child) router**'s handlers +2. The **parent router**'s handlers +3. The **app-level** handlers + +This means a child router can override its parent's handlers for its own routes. + +## Status Code Handlers { #status-code-handlers } + +Just like app-level handlers, you can use **integer status codes** as keys to handle specific HTTP error responses: + +```python +from starlette.exceptions import HTTPException + +def not_found_handler(request, exc): + return JSONResponse( + status_code=404, + content={"message": "Custom 404: resource not found"}, + ) + +router = APIRouter( + exception_handlers={404: not_found_handler} +) +``` + +## Recap { #recap } + +* Pass `exception_handlers` to `APIRouter()` to scope handlers to that router's routes. +* Use `router.add_exception_handler()` to add handlers after creation. +* Router-level handlers **override** app-level handlers for the same exception type. +* Routes **outside** the router are unaffected. +* Nested routers follow **child > parent > app** precedence. diff --git a/docs/en/docs/reference/apirouter.md b/docs/en/docs/reference/apirouter.md index d77364e45e..e10de61e13 100644 --- a/docs/en/docs/reference/apirouter.md +++ b/docs/en/docs/reference/apirouter.md @@ -12,6 +12,7 @@ from fastapi import APIRouter options: members: - websocket + - add_exception_handler - include_router - get - put diff --git a/docs/en/docs/tutorial/handling-errors.md b/docs/en/docs/tutorial/handling-errors.md index 78a5f1f20a..ff745c84aa 100644 --- a/docs/en/docs/tutorial/handling-errors.md +++ b/docs/en/docs/tutorial/handling-errors.md @@ -242,3 +242,11 @@ If you want to use the exception along with the same default exception handlers {* ../../docs_src/handling_errors/tutorial006_py310.py hl[2:5,15,21] *} In this example you are just printing the error with a very expressive message, but you get the idea. You can use the exception and then just reuse the default exception handlers. + +## Router-Level Exception Handlers { #router-level-exception-handlers } + +The examples above add exception handlers to the **application**, so they apply to all routes. + +You can also add exception handlers scoped to a specific **`APIRouter`**, so they only apply to the routes in that router. + +Read more about it in [Advanced User Guide - Router-Level Exception Handlers](../advanced/router-exception-handlers.md){.internal-link target=_blank}. diff --git a/docs/en/mkdocs.yml b/docs/en/mkdocs.yml index 0db3e7a95b..00f2ef419b 100644 --- a/docs/en/mkdocs.yml +++ b/docs/en/mkdocs.yml @@ -173,6 +173,7 @@ nav: - advanced/response-headers.md - advanced/response-change-status-code.md - advanced/advanced-dependencies.md + - advanced/router-exception-handlers.md - Advanced Security: - advanced/security/index.md - advanced/security/oauth2-scopes.md diff --git a/docs_src/handling_errors/tutorial007_py310.py b/docs_src/handling_errors/tutorial007_py310.py new file mode 100644 index 0000000000..a458f94df7 --- /dev/null +++ b/docs_src/handling_errors/tutorial007_py310.py @@ -0,0 +1,31 @@ +from fastapi import APIRouter, FastAPI +from fastapi.responses import JSONResponse + + +class UnicornException(Exception): + def __init__(self, name: str): + self.name = name + + +router = APIRouter( + prefix="/unicorns", + exception_handlers={ + UnicornException: lambda request, exc: JSONResponse( + status_code=418, + content={ + "message": f"Oops! {exc.name} did something. There goes a rainbow..." + }, + ) + }, +) + + +@router.get("/{name}") +async def read_unicorn(name: str): + if name == "yolo": + raise UnicornException(name=name) + return {"unicorn_name": name} + + +app = FastAPI() +app.include_router(router) diff --git a/docs_src/handling_errors/tutorial008_py310.py b/docs_src/handling_errors/tutorial008_py310.py new file mode 100644 index 0000000000..3a0e4f85c4 --- /dev/null +++ b/docs_src/handling_errors/tutorial008_py310.py @@ -0,0 +1,52 @@ +from fastapi import APIRouter, FastAPI, Request +from fastapi.responses import JSONResponse + + +class UnicornException(Exception): + def __init__(self, name: str): + self.name = name + + +app = FastAPI() + + +@app.exception_handler(UnicornException) +async def global_unicorn_exception_handler(request: Request, exc: UnicornException): + return JSONResponse( + status_code=418, + content={"message": f"Global handler: {exc.name} did something."}, + ) + + +magic_router = APIRouter(prefix="/magic") + + +@magic_router.get("/unicorns/{name}") +async def read_magic_unicorn(name: str): + if name == "yolo": + raise UnicornException(name=name) + return {"unicorn_name": name} + + +def custom_unicorn_handler(request: Request, exc: UnicornException): + return JSONResponse( + status_code=418, + content={"message": f"Special handler: {exc.name} did something magical!"}, + ) + + +special_router = APIRouter( + prefix="/special", + exception_handlers={UnicornException: custom_unicorn_handler}, +) + + +@special_router.get("/unicorns/{name}") +async def read_special_unicorn(name: str): + if name == "yolo": + raise UnicornException(name=name) + return {"unicorn_name": name} + + +app.include_router(magic_router) +app.include_router(special_router) diff --git a/fastapi/routing.py b/fastapi/routing.py index 36acb6b89d..eef1e3fc56 100644 --- a/fastapi/routing.py +++ b/fastapi/routing.py @@ -96,6 +96,7 @@ from typing_extensions import deprecated # dependencies' AsyncExitStack def request_response( func: Callable[[Request], Awaitable[Response] | Response], + exception_handlers: dict[int | type[Exception], Callable[..., Any]] | None = None, ) -> ASGIApp: """ Takes a function or coroutine `func(request) -> response`, @@ -108,6 +109,18 @@ def request_response( ) # ty: ignore[invalid-assignment] async def app(scope: Scope, receive: Receive, send: Send) -> None: + if exception_handlers: + existing_exc, existing_status = scope.get( + "starlette.exception_handlers", ({}, {}) + ) + merged_exc = {**existing_exc} + merged_status = {**existing_status} + for key, handler in exception_handlers.items(): + if isinstance(key, int): + merged_status[key] = handler + else: + merged_exc[key] = handler + scope["starlette.exception_handlers"] = (merged_exc, merged_status) request = Request(scope, receive, send) async def app(scope: Scope, receive: Receive, send: Send) -> None: @@ -140,6 +153,7 @@ def request_response( # dependencies' AsyncExitStack def websocket_session( func: Callable[[WebSocket], Awaitable[None]], + exception_handlers: dict[int | type[Exception], Callable[..., Any]] | None = None, ) -> ASGIApp: """ Takes a coroutine `func(session)`, and returns an ASGI application. @@ -147,6 +161,18 @@ def websocket_session( # assert asyncio.iscoroutinefunction(func), "WebSocket endpoints must be async" async def app(scope: Scope, receive: Receive, send: Send) -> None: + if exception_handlers: + existing_exc, existing_status = scope.get( + "starlette.exception_handlers", ({}, {}) + ) + merged_exc = {**existing_exc} + merged_status = {**existing_status} + for key, handler in exception_handlers.items(): + if isinstance(key, int): + merged_status[key] = handler + else: + merged_exc[key] = handler + scope["starlette.exception_handlers"] = (merged_exc, merged_status) session = WebSocket(scope, receive=receive, send=send) async def app(scope: Scope, receive: Receive, send: Send) -> None: @@ -775,11 +801,14 @@ class APIWebSocketRoute(routing.WebSocketRoute): name: str | None = None, dependencies: Sequence[params.Depends] | None = None, dependency_overrides_provider: Any | None = None, + exception_handlers: dict[int | type[Exception], Callable[..., Any]] + | None = None, ) -> None: self.path = path self.endpoint = endpoint self.name = get_name(endpoint) if name is None else name self.dependencies = list(dependencies or []) + self.exception_handlers = exception_handlers or {} self.path_regex, self.path_format, self.param_convertors = compile_path(path) self.dependant = get_dependant( path=self.path_format, call=self.endpoint, scope="function" @@ -798,7 +827,8 @@ class APIWebSocketRoute(routing.WebSocketRoute): dependant=self.dependant, dependency_overrides_provider=dependency_overrides_provider, embed_body_fields=self._embed_body_fields, - ) + ), + exception_handlers=self.exception_handlers or None, ) def matches(self, scope: Scope) -> tuple[Match, Scope]: @@ -840,9 +870,12 @@ class APIRoute(routing.Route): generate_unique_id_function: Callable[["APIRoute"], str] | DefaultPlaceholder = Default(generate_unique_id), strict_content_type: bool | DefaultPlaceholder = Default(True), + exception_handlers: dict[int | type[Exception], Callable[..., Any]] + | None = None, ) -> None: self.path = path self.endpoint = endpoint + self.exception_handlers = exception_handlers or {} self.stream_item_type: Any | None = None if isinstance(response_model, DefaultPlaceholder): return_annotation = get_typed_return_annotation(endpoint) @@ -973,7 +1006,10 @@ class APIRoute(routing.Route): self.is_json_stream = is_generator and isinstance( response_class, DefaultPlaceholder ) - self.app = request_response(self.get_route_handler()) + self.app = request_response( + self.get_route_handler(), + exception_handlers=self.exception_handlers or None, + ) def get_route_handler(self) -> Callable[[Request], Coroutine[Any, Any, Response]]: return get_request_handler( @@ -1266,6 +1302,18 @@ class APIRouter(routing.Router): """ ), ] = Default(True), + exception_handlers: Annotated[ + dict[int | type[Exception], Callable[..., Any]] | None, + Doc( + """ + A dictionary of exception handlers scoped to this router. + + Keys can be exception classes or HTTP status codes. Handlers + defined here will override app-level exception handlers for + routes within this router. + """ + ), + ] = None, ) -> None: # Determine the lifespan context to use if lifespan is None: @@ -1313,6 +1361,16 @@ class APIRouter(routing.Router): self.default_response_class = default_response_class self.generate_unique_id_function = generate_unique_id_function self.strict_content_type = strict_content_type + self.exception_handlers: dict[int | type[Exception], Callable[..., Any]] = ( + {} if exception_handlers is None else dict(exception_handlers) + ) + + def add_exception_handler( + self, + exc_class_or_status_code: int | type[Exception], + handler: Callable[..., Any], + ) -> None: + self.exception_handlers[exc_class_or_status_code] = handler def route( self, @@ -1364,6 +1422,8 @@ class APIRouter(routing.Router): generate_unique_id_function: Callable[[APIRoute], str] | DefaultPlaceholder = Default(generate_unique_id), strict_content_type: bool | DefaultPlaceholder = Default(True), + exception_handlers: dict[int | type[Exception], Callable[..., Any]] + | None = None, ) -> None: route_class = route_class_override or self.route_class responses = responses or {} @@ -1383,6 +1443,11 @@ class APIRouter(routing.Router): current_generate_unique_id = get_value_or_default( generate_unique_id_function, self.generate_unique_id_function ) + current_exception_handlers: dict[int | type[Exception], Callable[..., Any]] = ( + dict(self.exception_handlers) + ) + if exception_handlers: + current_exception_handlers.update(exception_handlers) route = route_class( self.prefix + path, endpoint=endpoint, @@ -1413,6 +1478,7 @@ class APIRouter(routing.Router): strict_content_type=get_value_or_default( strict_content_type, self.strict_content_type ), + exception_handlers=current_exception_handlers or None, ) self.routes.append(route) @@ -1485,10 +1551,17 @@ class APIRouter(routing.Router): name: str | None = None, *, dependencies: Sequence[params.Depends] | None = None, + exception_handlers: dict[int | type[Exception], Callable[..., Any]] + | None = None, ) -> None: current_dependencies = self.dependencies.copy() if dependencies: current_dependencies.extend(dependencies) + current_exception_handlers: dict[int | type[Exception], Callable[..., Any]] = ( + dict(self.exception_handlers) + ) + if exception_handlers: + current_exception_handlers.update(exception_handlers) route = APIWebSocketRoute( self.prefix + path, @@ -1496,6 +1569,7 @@ class APIRouter(routing.Router): name=name, dependencies=current_dependencies, dependency_overrides_provider=self.dependency_overrides_provider, + exception_handlers=current_exception_handlers or None, ) self.routes.append(route) @@ -1759,6 +1833,13 @@ class APIRouter(routing.Router): generate_unique_id_function, self.generate_unique_id_function, ) + current_exception_handlers: dict[ + int | type[Exception], Callable[..., Any] + ] = {} + if router.exception_handlers: + current_exception_handlers.update(router.exception_handlers) + if route.exception_handlers: + current_exception_handlers.update(route.exception_handlers) self.add_api_route( prefix + route.path, route.endpoint, @@ -1793,6 +1874,7 @@ class APIRouter(routing.Router): router.strict_content_type, self.strict_content_type, ), + exception_handlers=current_exception_handlers or None, ) elif isinstance(route, routing.Route): methods = list(route.methods or []) @@ -1809,11 +1891,17 @@ class APIRouter(routing.Router): current_dependencies.extend(dependencies) if route.dependencies: current_dependencies.extend(route.dependencies) + current_exception_handlers = {} + if router.exception_handlers: + current_exception_handlers.update(router.exception_handlers) + if route.exception_handlers: + current_exception_handlers.update(route.exception_handlers) self.add_api_websocket_route( prefix + route.path, route.endpoint, dependencies=current_dependencies, name=route.name, + exception_handlers=current_exception_handlers or None, ) elif isinstance(route, routing.WebSocketRoute): self.add_websocket_route( diff --git a/tests/test_router_exception_handlers.py b/tests/test_router_exception_handlers.py new file mode 100644 index 0000000000..ecad9e716b --- /dev/null +++ b/tests/test_router_exception_handlers.py @@ -0,0 +1,321 @@ +from fastapi import APIRouter, FastAPI +from fastapi.responses import JSONResponse +from fastapi.testclient import TestClient +from fastapi.websockets import WebSocket +from starlette.requests import Request + + +class CustomExcA(Exception): + pass + + +class CustomExcB(Exception): + pass + + +def exc_a_handler(request: Request, exc: CustomExcA) -> JSONResponse: + return JSONResponse({"error": "handled_a"}, status_code=400) + + +def exc_b_handler(request: Request, exc: CustomExcB) -> JSONResponse: + return JSONResponse({"error": "handled_b"}, status_code=400) + + +def test_basic_router_exception_handler(): + app = FastAPI() + router = APIRouter(exception_handlers={CustomExcA: exc_a_handler}) + + @router.get("/fail") + def fail(): + raise CustomExcA() + + app.include_router(router) + client = TestClient(app) + resp = client.get("/fail") + assert resp.status_code == 400 + assert resp.json() == {"error": "handled_a"} + + +def test_isolation_between_routers(): + app = FastAPI() + router1 = APIRouter(prefix="/r1", exception_handlers={CustomExcA: exc_a_handler}) + router2 = APIRouter(prefix="/r2") + + @router1.get("/fail") + def fail1(): + raise CustomExcA() + + @router2.get("/fail") + def fail2(): + raise CustomExcA() + + app.include_router(router1) + app.include_router(router2) + client = TestClient(app, raise_server_exceptions=False) + + resp1 = client.get("/r1/fail") + assert resp1.status_code == 400 + assert resp1.json() == {"error": "handled_a"} + + resp2 = client.get("/r2/fail") + assert resp2.status_code == 500 + + +def test_router_overrides_app(): + def app_handler(request: Request, exc: CustomExcA) -> JSONResponse: + return JSONResponse({"error": "app_handled"}, status_code=400) + + def router_handler(request: Request, exc: CustomExcA) -> JSONResponse: + return JSONResponse({"error": "router_handled"}, status_code=400) + + app = FastAPI(exception_handlers={CustomExcA: app_handler}) + router = APIRouter(prefix="/r", exception_handlers={CustomExcA: router_handler}) + + @router.get("/fail") + def fail(): + raise CustomExcA() + + @app.get("/app-fail") + def app_fail(): + raise CustomExcA() + + app.include_router(router) + client = TestClient(app) + + resp = client.get("/r/fail") + assert resp.status_code == 400 + assert resp.json() == {"error": "router_handled"} + + resp_app = client.get("/app-fail") + assert resp_app.status_code == 400 + assert resp_app.json() == {"error": "app_handled"} + + +def test_app_fallback(): + def app_handler(request: Request, exc: CustomExcA) -> JSONResponse: + return JSONResponse({"error": "app_handled"}, status_code=400) + + app = FastAPI(exception_handlers={CustomExcA: app_handler}) + router = APIRouter(prefix="/r") + + @router.get("/fail") + def fail(): + raise CustomExcA() + + app.include_router(router) + client = TestClient(app) + + resp = client.get("/r/fail") + assert resp.status_code == 400 + assert resp.json() == {"error": "app_handled"} + + +def test_nested_router_precedence(): + def handler_a(request: Request, exc: CustomExcA) -> JSONResponse: + return JSONResponse({"error": "A"}, status_code=400) + + def handler_b(request: Request, exc: CustomExcA) -> JSONResponse: + return JSONResponse({"error": "B"}, status_code=400) + + router_a = APIRouter(prefix="/a", exception_handlers={CustomExcA: handler_a}) + router_b = APIRouter(prefix="/b", exception_handlers={CustomExcA: handler_b}) + + @router_a.get("/fail") + def fail_a(): + raise CustomExcA() + + @router_b.get("/fail") + def fail_b(): + raise CustomExcA() + + router_a.include_router(router_b) + app = FastAPI() + app.include_router(router_a) + client = TestClient(app) + + resp_b = client.get("/a/b/fail") + assert resp_b.status_code == 400 + assert resp_b.json() == {"error": "B"} + + resp_a = client.get("/a/fail") + assert resp_a.status_code == 400 + assert resp_a.json() == {"error": "A"} + + +def test_deep_nesting(): + def handler_a(request: Request, exc: CustomExcA) -> JSONResponse: + return JSONResponse({"error": "A"}, status_code=400) + + def handler_b(request: Request, exc: CustomExcA) -> JSONResponse: + return JSONResponse({"error": "B"}, status_code=400) + + def handler_c(request: Request, exc: CustomExcA) -> JSONResponse: + return JSONResponse({"error": "C"}, status_code=400) + + router_a = APIRouter(prefix="/a", exception_handlers={CustomExcA: handler_a}) + router_b = APIRouter(prefix="/b", exception_handlers={CustomExcA: handler_b}) + router_c = APIRouter(prefix="/c", exception_handlers={CustomExcA: handler_c}) + + @router_a.get("/fail") + def fail_a(): + raise CustomExcA() + + @router_b.get("/fail") + def fail_b(): + raise CustomExcA() + + @router_c.get("/fail") + def fail_c(): + raise CustomExcA() + + router_b.include_router(router_c) + router_a.include_router(router_b) + app = FastAPI() + app.include_router(router_a) + client = TestClient(app) + + resp_c = client.get("/a/b/c/fail") + assert resp_c.status_code == 400 + assert resp_c.json() == {"error": "C"} + + resp_b = client.get("/a/b/fail") + assert resp_b.status_code == 400 + assert resp_b.json() == {"error": "B"} + + resp_a = client.get("/a/fail") + assert resp_a.status_code == 400 + assert resp_a.json() == {"error": "A"} + + +def test_add_exception_handler_method(): + app = FastAPI() + router = APIRouter() + router.add_exception_handler(CustomExcA, exc_a_handler) + + @router.get("/fail") + def fail(): + raise CustomExcA() + + app.include_router(router) + client = TestClient(app) + + resp = client.get("/fail") + assert resp.status_code == 400 + assert resp.json() == {"error": "handled_a"} + + +def test_status_code_handler(): + from starlette.exceptions import HTTPException + + def not_found_handler(request: Request, exc: HTTPException) -> JSONResponse: + return JSONResponse({"error": "custom_404"}, status_code=404) + + app = FastAPI() + router = APIRouter(prefix="/r", exception_handlers={404: not_found_handler}) + + @router.get("/fail") + def fail(): + raise HTTPException(status_code=404, detail="not found") + + app.include_router(router) + client = TestClient(app) + + resp = client.get("/r/fail") + assert resp.status_code == 404 + assert resp.json() == {"error": "custom_404"} + + +def test_app_routes_unaffected(): + app = FastAPI() + router = APIRouter(prefix="/r", exception_handlers={CustomExcA: exc_a_handler}) + + @router.get("/fail") + def router_fail(): + raise CustomExcA() + + @app.get("/fail") + def app_fail(): + raise CustomExcA() + + app.include_router(router) + client = TestClient(app, raise_server_exceptions=False) + + resp_router = client.get("/r/fail") + assert resp_router.status_code == 400 + assert resp_router.json() == {"error": "handled_a"} + + resp_app = client.get("/fail") + assert resp_app.status_code == 500 + + +def test_multiple_exception_types(): + app = FastAPI() + router = APIRouter( + exception_handlers={ + CustomExcA: exc_a_handler, + CustomExcB: exc_b_handler, + } + ) + + @router.get("/fail-a") + def fail_a(): + raise CustomExcA() + + @router.get("/fail-b") + def fail_b(): + raise CustomExcB() + + app.include_router(router) + client = TestClient(app) + + resp_a = client.get("/fail-a") + assert resp_a.status_code == 400 + assert resp_a.json() == {"error": "handled_a"} + + resp_b = client.get("/fail-b") + assert resp_b.status_code == 400 + assert resp_b.json() == {"error": "handled_b"} + + +def test_websocket_router_exception_handler(): + app = FastAPI() + router = APIRouter( + exception_handlers={CustomExcA: exc_a_handler, 404: exc_a_handler} + ) + + @router.websocket("/ws") + async def ws_endpoint(websocket: WebSocket): + await websocket.accept() + raise CustomExcA() + + app.include_router(router) + client = TestClient(app) + + with client.websocket_connect("/ws") as ws: + # The exception handler returns a JSON response, but for websockets + # the connection should close with an error + try: + ws.receive_text() + except Exception: + pass + + +def test_websocket_nested_router_exception_handler(): + app = FastAPI() + router_a = APIRouter(prefix="/a", exception_handlers={CustomExcA: exc_a_handler}) + router_b = APIRouter(prefix="/b", exception_handlers={CustomExcA: exc_a_handler}) + + @router_b.websocket("/ws") + async def ws_endpoint(websocket: WebSocket): + await websocket.accept() + raise CustomExcA() + + router_a.include_router(router_b) + app.include_router(router_a) + client = TestClient(app) + + with client.websocket_connect("/a/b/ws") as ws: + try: + ws.receive_text() + except Exception: + pass diff --git a/tests/test_tutorial/test_handling_errors/test_tutorial007.py b/tests/test_tutorial/test_handling_errors/test_tutorial007.py new file mode 100644 index 0000000000..8043e67733 --- /dev/null +++ b/tests/test_tutorial/test_handling_errors/test_tutorial007.py @@ -0,0 +1,101 @@ +from fastapi.testclient import TestClient +from inline_snapshot import snapshot + +from docs_src.handling_errors.tutorial007_py310 import app + +client = TestClient(app) + + +def test_get(): + response = client.get("/unicorns/shinny") + assert response.status_code == 200, response.text + assert response.json() == {"unicorn_name": "shinny"} + + +def test_get_exception(): + response = client.get("/unicorns/yolo") + assert response.status_code == 418, response.text + assert response.json() == { + "message": "Oops! yolo did something. There goes a rainbow..." + } + + +def test_openapi_schema(): + response = client.get("/openapi.json") + assert response.status_code == 200, response.text + assert response.json() == snapshot( + { + "openapi": "3.1.0", + "info": {"title": "FastAPI", "version": "0.1.0"}, + "paths": { + "/unicorns/{name}": { + "get": { + "responses": { + "200": { + "description": "Successful Response", + "content": {"application/json": {"schema": {}}}, + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + }, + }, + "summary": "Read Unicorn", + "operationId": "read_unicorn_unicorns__name__get", + "parameters": [ + { + "required": True, + "schema": {"title": "Name", "type": "string"}, + "name": "name", + "in": "path", + } + ], + } + } + }, + "components": { + "schemas": { + "ValidationError": { + "title": "ValidationError", + "required": ["loc", "msg", "type"], + "type": "object", + "properties": { + "loc": { + "title": "Location", + "type": "array", + "items": { + "anyOf": [ + {"type": "string"}, + {"type": "integer"}, + ] + }, + }, + "msg": {"title": "Message", "type": "string"}, + "type": {"title": "Error Type", "type": "string"}, + "input": {"title": "Input"}, + "ctx": {"title": "Context", "type": "object"}, + }, + }, + "HTTPValidationError": { + "title": "HTTPValidationError", + "type": "object", + "properties": { + "detail": { + "title": "Detail", + "type": "array", + "items": { + "$ref": "#/components/schemas/ValidationError" + }, + } + }, + }, + } + }, + } + ) diff --git a/tests/test_tutorial/test_handling_errors/test_tutorial008.py b/tests/test_tutorial/test_handling_errors/test_tutorial008.py new file mode 100644 index 0000000000..1a090ee3b1 --- /dev/null +++ b/tests/test_tutorial/test_handling_errors/test_tutorial008.py @@ -0,0 +1,143 @@ +from fastapi.testclient import TestClient +from inline_snapshot import snapshot + +from docs_src.handling_errors.tutorial008_py310 import app + +client = TestClient(app) + + +def test_get_magic(): + response = client.get("/magic/unicorns/shinny") + assert response.status_code == 200, response.text + assert response.json() == {"unicorn_name": "shinny"} + + +def test_get_magic_exception(): + response = client.get("/magic/unicorns/yolo") + assert response.status_code == 418, response.text + assert response.json() == {"message": "Global handler: yolo did something."} + + +def test_get_special(): + response = client.get("/special/unicorns/shinny") + assert response.status_code == 200, response.text + assert response.json() == {"unicorn_name": "shinny"} + + +def test_get_special_exception(): + response = client.get("/special/unicorns/yolo") + assert response.status_code == 418, response.text + assert response.json() == { + "message": "Special handler: yolo did something magical!" + } + + +def test_openapi_schema(): + response = client.get("/openapi.json") + assert response.status_code == 200, response.text + assert response.json() == snapshot( + { + "openapi": "3.1.0", + "info": {"title": "FastAPI", "version": "0.1.0"}, + "paths": { + "/magic/unicorns/{name}": { + "get": { + "responses": { + "200": { + "description": "Successful Response", + "content": {"application/json": {"schema": {}}}, + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + }, + }, + "summary": "Read Magic Unicorn", + "operationId": "read_magic_unicorn_magic_unicorns__name__get", + "parameters": [ + { + "required": True, + "schema": {"title": "Name", "type": "string"}, + "name": "name", + "in": "path", + } + ], + } + }, + "/special/unicorns/{name}": { + "get": { + "responses": { + "200": { + "description": "Successful Response", + "content": {"application/json": {"schema": {}}}, + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + }, + }, + "summary": "Read Special Unicorn", + "operationId": "read_special_unicorn_special_unicorns__name__get", + "parameters": [ + { + "required": True, + "schema": {"title": "Name", "type": "string"}, + "name": "name", + "in": "path", + } + ], + } + }, + }, + "components": { + "schemas": { + "ValidationError": { + "title": "ValidationError", + "required": ["loc", "msg", "type"], + "type": "object", + "properties": { + "loc": { + "title": "Location", + "type": "array", + "items": { + "anyOf": [ + {"type": "string"}, + {"type": "integer"}, + ] + }, + }, + "msg": {"title": "Message", "type": "string"}, + "type": {"title": "Error Type", "type": "string"}, + "input": {"title": "Input"}, + "ctx": {"title": "Context", "type": "object"}, + }, + }, + "HTTPValidationError": { + "title": "HTTPValidationError", + "type": "object", + "properties": { + "detail": { + "title": "Detail", + "type": "array", + "items": { + "$ref": "#/components/schemas/ValidationError" + }, + } + }, + }, + } + }, + } + )