🚸 Improve tracebacks by adding endpoint metadata (#14306)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Sebastián Ramírez <tiangolo@gmail.com>
This commit is contained in:
Savannah Ostrowski 2025-12-06 04:21:57 -08:00 committed by GitHub
parent 08b09e5236
commit e1117f7550
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 289 additions and 16 deletions

View File

@ -1,4 +1,4 @@
from typing import Any, Dict, Optional, Sequence, Type, Union
from typing import Any, Dict, Optional, Sequence, Type, TypedDict, Union
from annotated_doc import Doc
from pydantic import BaseModel, create_model
@ -7,6 +7,13 @@ from starlette.exceptions import WebSocketException as StarletteWebSocketExcepti
from typing_extensions import Annotated
class EndpointContext(TypedDict, total=False):
function: str
path: str
file: str
line: int
class HTTPException(StarletteHTTPException):
"""
An HTTP exception you can raise in your own code to show errors to the client.
@ -155,30 +162,72 @@ class DependencyScopeError(FastAPIError):
class ValidationException(Exception):
def __init__(self, errors: Sequence[Any]) -> None:
def __init__(
self,
errors: Sequence[Any],
*,
endpoint_ctx: Optional[EndpointContext] = None,
) -> None:
self._errors = errors
self.endpoint_ctx = endpoint_ctx
ctx = endpoint_ctx or {}
self.endpoint_function = ctx.get("function")
self.endpoint_path = ctx.get("path")
self.endpoint_file = ctx.get("file")
self.endpoint_line = ctx.get("line")
def errors(self) -> Sequence[Any]:
return self._errors
def _format_endpoint_context(self) -> str:
if not (self.endpoint_file and self.endpoint_line and self.endpoint_function):
if self.endpoint_path:
return f"\n Endpoint: {self.endpoint_path}"
return ""
context = f'\n File "{self.endpoint_file}", line {self.endpoint_line}, in {self.endpoint_function}'
if self.endpoint_path:
context += f"\n {self.endpoint_path}"
return context
def __str__(self) -> str:
message = f"{len(self._errors)} validation error{'s' if len(self._errors) != 1 else ''}:\n"
for err in self._errors:
message += f" {err}\n"
message += self._format_endpoint_context()
return message.rstrip()
class RequestValidationError(ValidationException):
def __init__(self, errors: Sequence[Any], *, body: Any = None) -> None:
super().__init__(errors)
def __init__(
self,
errors: Sequence[Any],
*,
body: Any = None,
endpoint_ctx: Optional[EndpointContext] = None,
) -> None:
super().__init__(errors, endpoint_ctx=endpoint_ctx)
self.body = body
class WebSocketRequestValidationError(ValidationException):
pass
def __init__(
self,
errors: Sequence[Any],
*,
endpoint_ctx: Optional[EndpointContext] = None,
) -> None:
super().__init__(errors, endpoint_ctx=endpoint_ctx)
class ResponseValidationError(ValidationException):
def __init__(self, errors: Sequence[Any], *, body: Any = None) -> None:
super().__init__(errors)
def __init__(
self,
errors: Sequence[Any],
*,
body: Any = None,
endpoint_ctx: Optional[EndpointContext] = None,
) -> None:
super().__init__(errors, endpoint_ctx=endpoint_ctx)
self.body = body
def __str__(self) -> str:
message = f"{len(self._errors)} validation errors:\n"
for err in self._errors:
message += f" {err}\n"
return message

View File

@ -46,6 +46,7 @@ from fastapi.dependencies.utils import (
)
from fastapi.encoders import jsonable_encoder
from fastapi.exceptions import (
EndpointContext,
FastAPIError,
RequestValidationError,
ResponseValidationError,
@ -212,6 +213,33 @@ def _merge_lifespan_context(
return merged_lifespan # type: ignore[return-value]
# Cache for endpoint context to avoid re-extracting on every request
_endpoint_context_cache: Dict[int, EndpointContext] = {}
def _extract_endpoint_context(func: Any) -> EndpointContext:
"""Extract endpoint context with caching to avoid repeated file I/O."""
func_id = id(func)
if func_id in _endpoint_context_cache:
return _endpoint_context_cache[func_id]
try:
ctx: EndpointContext = {}
if (source_file := inspect.getsourcefile(func)) is not None:
ctx["file"] = source_file
if (line_number := inspect.getsourcelines(func)[1]) is not None:
ctx["line"] = line_number
if (func_name := getattr(func, "__name__", None)) is not None:
ctx["function"] = func_name
except Exception:
ctx = EndpointContext()
_endpoint_context_cache[func_id] = ctx
return ctx
async def serialize_response(
*,
field: Optional[ModelField] = None,
@ -223,6 +251,7 @@ async def serialize_response(
exclude_defaults: bool = False,
exclude_none: bool = False,
is_coroutine: bool = True,
endpoint_ctx: Optional[EndpointContext] = None,
) -> Any:
if field:
errors = []
@ -245,8 +274,11 @@ async def serialize_response(
elif errors_:
errors.append(errors_)
if errors:
ctx = endpoint_ctx or EndpointContext()
raise ResponseValidationError(
errors=_normalize_errors(errors), body=response_content
errors=_normalize_errors(errors),
body=response_content,
endpoint_ctx=ctx,
)
if hasattr(field, "serialize"):
@ -318,6 +350,18 @@ def get_request_handler(
"fastapi_middleware_astack not found in request scope"
)
# Extract endpoint context for error messages
endpoint_ctx = (
_extract_endpoint_context(dependant.call)
if dependant.call
else EndpointContext()
)
if dependant.path:
# For mounted sub-apps, include the mount path prefix
mount_path = request.scope.get("root_path", "").rstrip("/")
endpoint_ctx["path"] = f"{request.method} {mount_path}{dependant.path}"
# Read body and auto-close files
try:
body: Any = None
@ -355,6 +399,7 @@ def get_request_handler(
}
],
body=e.doc,
endpoint_ctx=endpoint_ctx,
)
raise validation_error from e
except HTTPException:
@ -414,6 +459,7 @@ def get_request_handler(
exclude_defaults=response_model_exclude_defaults,
exclude_none=response_model_exclude_none,
is_coroutine=is_coroutine,
endpoint_ctx=endpoint_ctx,
)
response = actual_response_class(content, **response_args)
if not is_body_allowed_for_status_code(response.status_code):
@ -421,7 +467,7 @@ def get_request_handler(
response.headers.raw.extend(solved_result.response.headers.raw)
if errors:
validation_error = RequestValidationError(
_normalize_errors(errors), body=body
_normalize_errors(errors), body=body, endpoint_ctx=endpoint_ctx
)
raise validation_error
@ -438,6 +484,15 @@ def get_websocket_app(
embed_body_fields: bool = False,
) -> Callable[[WebSocket], Coroutine[Any, Any, Any]]:
async def app(websocket: WebSocket) -> None:
endpoint_ctx = (
_extract_endpoint_context(dependant.call)
if dependant.call
else EndpointContext()
)
if dependant.path:
# For mounted sub-apps, include the mount path prefix
mount_path = websocket.scope.get("root_path", "").rstrip("/")
endpoint_ctx["path"] = f"WS {mount_path}{dependant.path}"
async_exit_stack = websocket.scope.get("fastapi_inner_astack")
assert isinstance(async_exit_stack, AsyncExitStack), (
"fastapi_inner_astack not found in request scope"
@ -451,7 +506,8 @@ def get_websocket_app(
)
if solved_result.errors:
raise WebSocketRequestValidationError(
_normalize_errors(solved_result.errors)
_normalize_errors(solved_result.errors),
endpoint_ctx=endpoint_ctx,
)
assert dependant.call is not None, "dependant.call must be a function"
await dependant.call(**solved_result.values)

View File

@ -0,0 +1,168 @@
from fastapi import FastAPI, Request, WebSocket
from fastapi.exceptions import (
RequestValidationError,
ResponseValidationError,
WebSocketRequestValidationError,
)
from fastapi.testclient import TestClient
from pydantic import BaseModel
class Item(BaseModel):
id: int
name: str
class ExceptionCapture:
def __init__(self):
self.exception = None
def capture(self, exc):
self.exception = exc
return exc
app = FastAPI()
sub_app = FastAPI()
captured_exception = ExceptionCapture()
app.mount(path="/sub", app=sub_app)
@app.exception_handler(RequestValidationError)
@sub_app.exception_handler(RequestValidationError)
async def request_validation_handler(request: Request, exc: RequestValidationError):
captured_exception.capture(exc)
raise exc
@app.exception_handler(ResponseValidationError)
@sub_app.exception_handler(ResponseValidationError)
async def response_validation_handler(_: Request, exc: ResponseValidationError):
captured_exception.capture(exc)
raise exc
@app.exception_handler(WebSocketRequestValidationError)
@sub_app.exception_handler(WebSocketRequestValidationError)
async def websocket_validation_handler(
websocket: WebSocket, exc: WebSocketRequestValidationError
):
captured_exception.capture(exc)
raise exc
@app.get("/users/{user_id}")
def get_user(user_id: int):
return {"user_id": user_id} # pragma: no cover
@app.get("/items/", response_model=Item)
def get_item():
return {"name": "Widget"}
@sub_app.get("/items/", response_model=Item)
def get_sub_item():
return {"name": "Widget"} # pragma: no cover
@app.websocket("/ws/{item_id}")
async def websocket_endpoint(websocket: WebSocket, item_id: int):
await websocket.accept() # pragma: no cover
await websocket.send_text(f"Item: {item_id}") # pragma: no cover
await websocket.close() # pragma: no cover
@sub_app.websocket("/ws/{item_id}")
async def subapp_websocket_endpoint(websocket: WebSocket, item_id: int):
await websocket.accept() # pragma: no cover
await websocket.send_text(f"Item: {item_id}") # pragma: no cover
await websocket.close() # pragma: no cover
client = TestClient(app)
def test_request_validation_error_includes_endpoint_context():
captured_exception.exception = None
try:
client.get("/users/invalid")
except Exception:
pass
assert captured_exception.exception is not None
error_str = str(captured_exception.exception)
assert "get_user" in error_str
assert "/users/" in error_str
def test_response_validation_error_includes_endpoint_context():
captured_exception.exception = None
try:
client.get("/items/")
except Exception:
pass
assert captured_exception.exception is not None
error_str = str(captured_exception.exception)
assert "get_item" in error_str
assert "/items/" in error_str
def test_websocket_validation_error_includes_endpoint_context():
captured_exception.exception = None
try:
with client.websocket_connect("/ws/invalid"):
pass # pragma: no cover
except Exception:
pass
assert captured_exception.exception is not None
error_str = str(captured_exception.exception)
assert "websocket_endpoint" in error_str
assert "/ws/" in error_str
def test_subapp_request_validation_error_includes_endpoint_context():
captured_exception.exception = None
try:
client.get("/sub/items/")
except Exception:
pass
assert captured_exception.exception is not None
error_str = str(captured_exception.exception)
assert "get_sub_item" in error_str
assert "/sub/items/" in error_str
def test_subapp_websocket_validation_error_includes_endpoint_context():
captured_exception.exception = None
try:
with client.websocket_connect("/sub/ws/invalid"):
pass # pragma: no cover
except Exception:
pass
assert captured_exception.exception is not None
error_str = str(captured_exception.exception)
assert "subapp_websocket_endpoint" in error_str
assert "/sub/ws/" in error_str
def test_validation_error_with_only_path():
errors = [{"type": "missing", "loc": ("body", "name"), "msg": "Field required"}]
exc = RequestValidationError(errors, endpoint_ctx={"path": "GET /api/test"})
error_str = str(exc)
assert "Endpoint: GET /api/test" in error_str
assert 'File "' not in error_str
def test_validation_error_with_no_context():
errors = [{"type": "missing", "loc": ("body", "name"), "msg": "Field required"}]
exc = RequestValidationError(errors, endpoint_ctx={})
error_str = str(exc)
assert "1 validation error:" in error_str
assert "Endpoint" not in error_str
assert 'File "' not in error_str