🚸 Improve tracebacks by adding endpoint metadata (#14306)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Sebastián Ramírez <tiangolo@gmail.com>
This commit is contained in:
Savannah Ostrowski 2025-12-06 04:21:57 -08:00 committed by GitHub
parent 08b09e5236
commit e1117f7550
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 289 additions and 16 deletions

View File

@ -1,4 +1,4 @@
from typing import Any, Dict, Optional, Sequence, Type, Union from typing import Any, Dict, Optional, Sequence, Type, TypedDict, Union
from annotated_doc import Doc from annotated_doc import Doc
from pydantic import BaseModel, create_model from pydantic import BaseModel, create_model
@ -7,6 +7,13 @@ from starlette.exceptions import WebSocketException as StarletteWebSocketExcepti
from typing_extensions import Annotated from typing_extensions import Annotated
class EndpointContext(TypedDict, total=False):
function: str
path: str
file: str
line: int
class HTTPException(StarletteHTTPException): class HTTPException(StarletteHTTPException):
""" """
An HTTP exception you can raise in your own code to show errors to the client. An HTTP exception you can raise in your own code to show errors to the client.
@ -155,30 +162,72 @@ class DependencyScopeError(FastAPIError):
class ValidationException(Exception): class ValidationException(Exception):
def __init__(self, errors: Sequence[Any]) -> None: def __init__(
self,
errors: Sequence[Any],
*,
endpoint_ctx: Optional[EndpointContext] = None,
) -> None:
self._errors = errors self._errors = errors
self.endpoint_ctx = endpoint_ctx
ctx = endpoint_ctx or {}
self.endpoint_function = ctx.get("function")
self.endpoint_path = ctx.get("path")
self.endpoint_file = ctx.get("file")
self.endpoint_line = ctx.get("line")
def errors(self) -> Sequence[Any]: def errors(self) -> Sequence[Any]:
return self._errors return self._errors
def _format_endpoint_context(self) -> str:
if not (self.endpoint_file and self.endpoint_line and self.endpoint_function):
if self.endpoint_path:
return f"\n Endpoint: {self.endpoint_path}"
return ""
context = f'\n File "{self.endpoint_file}", line {self.endpoint_line}, in {self.endpoint_function}'
if self.endpoint_path:
context += f"\n {self.endpoint_path}"
return context
def __str__(self) -> str:
message = f"{len(self._errors)} validation error{'s' if len(self._errors) != 1 else ''}:\n"
for err in self._errors:
message += f" {err}\n"
message += self._format_endpoint_context()
return message.rstrip()
class RequestValidationError(ValidationException): class RequestValidationError(ValidationException):
def __init__(self, errors: Sequence[Any], *, body: Any = None) -> None: def __init__(
super().__init__(errors) self,
errors: Sequence[Any],
*,
body: Any = None,
endpoint_ctx: Optional[EndpointContext] = None,
) -> None:
super().__init__(errors, endpoint_ctx=endpoint_ctx)
self.body = body self.body = body
class WebSocketRequestValidationError(ValidationException): class WebSocketRequestValidationError(ValidationException):
pass def __init__(
self,
errors: Sequence[Any],
*,
endpoint_ctx: Optional[EndpointContext] = None,
) -> None:
super().__init__(errors, endpoint_ctx=endpoint_ctx)
class ResponseValidationError(ValidationException): class ResponseValidationError(ValidationException):
def __init__(self, errors: Sequence[Any], *, body: Any = None) -> None: def __init__(
super().__init__(errors) self,
errors: Sequence[Any],
*,
body: Any = None,
endpoint_ctx: Optional[EndpointContext] = None,
) -> None:
super().__init__(errors, endpoint_ctx=endpoint_ctx)
self.body = body self.body = body
def __str__(self) -> str:
message = f"{len(self._errors)} validation errors:\n"
for err in self._errors:
message += f" {err}\n"
return message

View File

@ -46,6 +46,7 @@ from fastapi.dependencies.utils import (
) )
from fastapi.encoders import jsonable_encoder from fastapi.encoders import jsonable_encoder
from fastapi.exceptions import ( from fastapi.exceptions import (
EndpointContext,
FastAPIError, FastAPIError,
RequestValidationError, RequestValidationError,
ResponseValidationError, ResponseValidationError,
@ -212,6 +213,33 @@ def _merge_lifespan_context(
return merged_lifespan # type: ignore[return-value] return merged_lifespan # type: ignore[return-value]
# Cache for endpoint context to avoid re-extracting on every request
_endpoint_context_cache: Dict[int, EndpointContext] = {}
def _extract_endpoint_context(func: Any) -> EndpointContext:
"""Extract endpoint context with caching to avoid repeated file I/O."""
func_id = id(func)
if func_id in _endpoint_context_cache:
return _endpoint_context_cache[func_id]
try:
ctx: EndpointContext = {}
if (source_file := inspect.getsourcefile(func)) is not None:
ctx["file"] = source_file
if (line_number := inspect.getsourcelines(func)[1]) is not None:
ctx["line"] = line_number
if (func_name := getattr(func, "__name__", None)) is not None:
ctx["function"] = func_name
except Exception:
ctx = EndpointContext()
_endpoint_context_cache[func_id] = ctx
return ctx
async def serialize_response( async def serialize_response(
*, *,
field: Optional[ModelField] = None, field: Optional[ModelField] = None,
@ -223,6 +251,7 @@ async def serialize_response(
exclude_defaults: bool = False, exclude_defaults: bool = False,
exclude_none: bool = False, exclude_none: bool = False,
is_coroutine: bool = True, is_coroutine: bool = True,
endpoint_ctx: Optional[EndpointContext] = None,
) -> Any: ) -> Any:
if field: if field:
errors = [] errors = []
@ -245,8 +274,11 @@ async def serialize_response(
elif errors_: elif errors_:
errors.append(errors_) errors.append(errors_)
if errors: if errors:
ctx = endpoint_ctx or EndpointContext()
raise ResponseValidationError( raise ResponseValidationError(
errors=_normalize_errors(errors), body=response_content errors=_normalize_errors(errors),
body=response_content,
endpoint_ctx=ctx,
) )
if hasattr(field, "serialize"): if hasattr(field, "serialize"):
@ -318,6 +350,18 @@ def get_request_handler(
"fastapi_middleware_astack not found in request scope" "fastapi_middleware_astack not found in request scope"
) )
# Extract endpoint context for error messages
endpoint_ctx = (
_extract_endpoint_context(dependant.call)
if dependant.call
else EndpointContext()
)
if dependant.path:
# For mounted sub-apps, include the mount path prefix
mount_path = request.scope.get("root_path", "").rstrip("/")
endpoint_ctx["path"] = f"{request.method} {mount_path}{dependant.path}"
# Read body and auto-close files # Read body and auto-close files
try: try:
body: Any = None body: Any = None
@ -355,6 +399,7 @@ def get_request_handler(
} }
], ],
body=e.doc, body=e.doc,
endpoint_ctx=endpoint_ctx,
) )
raise validation_error from e raise validation_error from e
except HTTPException: except HTTPException:
@ -414,6 +459,7 @@ def get_request_handler(
exclude_defaults=response_model_exclude_defaults, exclude_defaults=response_model_exclude_defaults,
exclude_none=response_model_exclude_none, exclude_none=response_model_exclude_none,
is_coroutine=is_coroutine, is_coroutine=is_coroutine,
endpoint_ctx=endpoint_ctx,
) )
response = actual_response_class(content, **response_args) response = actual_response_class(content, **response_args)
if not is_body_allowed_for_status_code(response.status_code): if not is_body_allowed_for_status_code(response.status_code):
@ -421,7 +467,7 @@ def get_request_handler(
response.headers.raw.extend(solved_result.response.headers.raw) response.headers.raw.extend(solved_result.response.headers.raw)
if errors: if errors:
validation_error = RequestValidationError( validation_error = RequestValidationError(
_normalize_errors(errors), body=body _normalize_errors(errors), body=body, endpoint_ctx=endpoint_ctx
) )
raise validation_error raise validation_error
@ -438,6 +484,15 @@ def get_websocket_app(
embed_body_fields: bool = False, embed_body_fields: bool = False,
) -> Callable[[WebSocket], Coroutine[Any, Any, Any]]: ) -> Callable[[WebSocket], Coroutine[Any, Any, Any]]:
async def app(websocket: WebSocket) -> None: async def app(websocket: WebSocket) -> None:
endpoint_ctx = (
_extract_endpoint_context(dependant.call)
if dependant.call
else EndpointContext()
)
if dependant.path:
# For mounted sub-apps, include the mount path prefix
mount_path = websocket.scope.get("root_path", "").rstrip("/")
endpoint_ctx["path"] = f"WS {mount_path}{dependant.path}"
async_exit_stack = websocket.scope.get("fastapi_inner_astack") async_exit_stack = websocket.scope.get("fastapi_inner_astack")
assert isinstance(async_exit_stack, AsyncExitStack), ( assert isinstance(async_exit_stack, AsyncExitStack), (
"fastapi_inner_astack not found in request scope" "fastapi_inner_astack not found in request scope"
@ -451,7 +506,8 @@ def get_websocket_app(
) )
if solved_result.errors: if solved_result.errors:
raise WebSocketRequestValidationError( raise WebSocketRequestValidationError(
_normalize_errors(solved_result.errors) _normalize_errors(solved_result.errors),
endpoint_ctx=endpoint_ctx,
) )
assert dependant.call is not None, "dependant.call must be a function" assert dependant.call is not None, "dependant.call must be a function"
await dependant.call(**solved_result.values) await dependant.call(**solved_result.values)

View File

@ -0,0 +1,168 @@
from fastapi import FastAPI, Request, WebSocket
from fastapi.exceptions import (
RequestValidationError,
ResponseValidationError,
WebSocketRequestValidationError,
)
from fastapi.testclient import TestClient
from pydantic import BaseModel
class Item(BaseModel):
id: int
name: str
class ExceptionCapture:
def __init__(self):
self.exception = None
def capture(self, exc):
self.exception = exc
return exc
app = FastAPI()
sub_app = FastAPI()
captured_exception = ExceptionCapture()
app.mount(path="/sub", app=sub_app)
@app.exception_handler(RequestValidationError)
@sub_app.exception_handler(RequestValidationError)
async def request_validation_handler(request: Request, exc: RequestValidationError):
captured_exception.capture(exc)
raise exc
@app.exception_handler(ResponseValidationError)
@sub_app.exception_handler(ResponseValidationError)
async def response_validation_handler(_: Request, exc: ResponseValidationError):
captured_exception.capture(exc)
raise exc
@app.exception_handler(WebSocketRequestValidationError)
@sub_app.exception_handler(WebSocketRequestValidationError)
async def websocket_validation_handler(
websocket: WebSocket, exc: WebSocketRequestValidationError
):
captured_exception.capture(exc)
raise exc
@app.get("/users/{user_id}")
def get_user(user_id: int):
return {"user_id": user_id} # pragma: no cover
@app.get("/items/", response_model=Item)
def get_item():
return {"name": "Widget"}
@sub_app.get("/items/", response_model=Item)
def get_sub_item():
return {"name": "Widget"} # pragma: no cover
@app.websocket("/ws/{item_id}")
async def websocket_endpoint(websocket: WebSocket, item_id: int):
await websocket.accept() # pragma: no cover
await websocket.send_text(f"Item: {item_id}") # pragma: no cover
await websocket.close() # pragma: no cover
@sub_app.websocket("/ws/{item_id}")
async def subapp_websocket_endpoint(websocket: WebSocket, item_id: int):
await websocket.accept() # pragma: no cover
await websocket.send_text(f"Item: {item_id}") # pragma: no cover
await websocket.close() # pragma: no cover
client = TestClient(app)
def test_request_validation_error_includes_endpoint_context():
captured_exception.exception = None
try:
client.get("/users/invalid")
except Exception:
pass
assert captured_exception.exception is not None
error_str = str(captured_exception.exception)
assert "get_user" in error_str
assert "/users/" in error_str
def test_response_validation_error_includes_endpoint_context():
captured_exception.exception = None
try:
client.get("/items/")
except Exception:
pass
assert captured_exception.exception is not None
error_str = str(captured_exception.exception)
assert "get_item" in error_str
assert "/items/" in error_str
def test_websocket_validation_error_includes_endpoint_context():
captured_exception.exception = None
try:
with client.websocket_connect("/ws/invalid"):
pass # pragma: no cover
except Exception:
pass
assert captured_exception.exception is not None
error_str = str(captured_exception.exception)
assert "websocket_endpoint" in error_str
assert "/ws/" in error_str
def test_subapp_request_validation_error_includes_endpoint_context():
captured_exception.exception = None
try:
client.get("/sub/items/")
except Exception:
pass
assert captured_exception.exception is not None
error_str = str(captured_exception.exception)
assert "get_sub_item" in error_str
assert "/sub/items/" in error_str
def test_subapp_websocket_validation_error_includes_endpoint_context():
captured_exception.exception = None
try:
with client.websocket_connect("/sub/ws/invalid"):
pass # pragma: no cover
except Exception:
pass
assert captured_exception.exception is not None
error_str = str(captured_exception.exception)
assert "subapp_websocket_endpoint" in error_str
assert "/sub/ws/" in error_str
def test_validation_error_with_only_path():
errors = [{"type": "missing", "loc": ("body", "name"), "msg": "Field required"}]
exc = RequestValidationError(errors, endpoint_ctx={"path": "GET /api/test"})
error_str = str(exc)
assert "Endpoint: GET /api/test" in error_str
assert 'File "' not in error_str
def test_validation_error_with_no_context():
errors = [{"type": "missing", "loc": ("body", "name"), "msg": "Field required"}]
exc = RequestValidationError(errors, endpoint_ctx={})
error_str = str(exc)
assert "1 validation error:" in error_str
assert "Endpoint" not in error_str
assert 'File "' not in error_str