diff --git a/docs/de/docs/project-generation.md b/docs/de/docs/project-generation.md index f830f0f4d..dd3c3b427 100644 --- a/docs/de/docs/project-generation.md +++ b/docs/de/docs/project-generation.md @@ -14,7 +14,7 @@ GitHub-Repository: Any: return model.model_config +def _has_computed_fields(field: ModelField) -> bool: + computed_fields = field._type_adapter.core_schema.get("schema", {}).get( + "computed_fields", [] + ) + return len(computed_fields) > 0 + + def get_schema_from_model_field( *, field: ModelField, @@ -180,12 +187,9 @@ def get_schema_from_model_field( ], separate_input_output_schemas: bool = True, ) -> Dict[str, Any]: - computed_fields = field._type_adapter.core_schema.get("schema", {}).get( - "computed_fields", [] - ) override_mode: Union[Literal["validation"], None] = ( None - if (separate_input_output_schemas or len(computed_fields) > 0) + if (separate_input_output_schemas or _has_computed_fields(field)) else "validation" ) # This expects that GenerateJsonSchema was already used to generate the definitions @@ -208,15 +212,7 @@ def get_definitions( Dict[Tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue], Dict[str, Dict[str, Any]], ]: - has_computed_fields: bool = any( - field._type_adapter.core_schema.get("schema", {}).get("computed_fields", []) - for field in fields - ) - schema_generator = GenerateJsonSchema(ref_template=REF_TEMPLATE) - override_mode: Union[Literal["validation"], None] = ( - None if (separate_input_output_schemas or has_computed_fields) else "validation" - ) validation_fields = [field for field in fields if field.mode == "validation"] serialization_fields = [field for field in fields if field.mode == "serialization"] flat_validation_models = get_flat_models_from_fields( @@ -246,9 +242,16 @@ def get_definitions( unique_flat_model_fields = { f for f in flat_model_fields if f.type_ not in input_types } - inputs = [ - (field, override_mode or field.mode, field._type_adapter.core_schema) + ( + field, + ( + field.mode + if (separate_input_output_schemas or _has_computed_fields(field)) + else "validation" + ), + field._type_adapter.core_schema, + ) for field in list(fields) + list(unique_flat_model_fields) ] field_mapping, definitions = schema_generator.generate_definitions(inputs=inputs) diff --git a/fastapi/dependencies/models.py b/fastapi/dependencies/models.py index 7e1b15a5e..a4bc22e65 100644 --- a/fastapi/dependencies/models.py +++ b/fastapi/dependencies/models.py @@ -110,6 +110,8 @@ class Dependant: _impartial(self.call) ) or inspect.isgeneratorfunction(_unwrapped_call(self.call)): return True + if inspect.isclass(_unwrapped_call(self.call)): + return False dunder_call = getattr(_impartial(self.call), "__call__", None) # noqa: B004 if dunder_call is None: return False # pragma: no cover @@ -134,6 +136,8 @@ class Dependant: _impartial(self.call) ) or inspect.isasyncgenfunction(_unwrapped_call(self.call)): return True + if inspect.isclass(_unwrapped_call(self.call)): + return False dunder_call = getattr(_impartial(self.call), "__call__", None) # noqa: B004 if dunder_call is None: return False # pragma: no cover @@ -162,6 +166,8 @@ class Dependant: _unwrapped_call(self.call) ): return True + if inspect.isclass(_unwrapped_call(self.call)): + return False dunder_call = getattr(_impartial(self.call), "__call__", None) # noqa: B004 if dunder_call is None: return False # pragma: no cover @@ -176,7 +182,6 @@ class Dependant: _impartial(dunder_unwrapped_call) ) or iscoroutinefunction(_unwrapped_call(dunder_unwrapped_call)): return True - # if inspect.isclass(self.call): False, covered by default return return False @cached_property diff --git a/fastapi/exceptions.py b/fastapi/exceptions.py index 0620428be..a46e82350 100644 --- a/fastapi/exceptions.py +++ b/fastapi/exceptions.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Optional, Sequence, Type, Union +from typing import Any, Dict, Optional, Sequence, Type, TypedDict, Union from annotated_doc import Doc from pydantic import BaseModel, create_model @@ -7,6 +7,13 @@ from starlette.exceptions import WebSocketException as StarletteWebSocketExcepti from typing_extensions import Annotated +class EndpointContext(TypedDict, total=False): + function: str + path: str + file: str + line: int + + class HTTPException(StarletteHTTPException): """ An HTTP exception you can raise in your own code to show errors to the client. @@ -155,30 +162,72 @@ class DependencyScopeError(FastAPIError): class ValidationException(Exception): - def __init__(self, errors: Sequence[Any]) -> None: + def __init__( + self, + errors: Sequence[Any], + *, + endpoint_ctx: Optional[EndpointContext] = None, + ) -> None: self._errors = errors + self.endpoint_ctx = endpoint_ctx + + ctx = endpoint_ctx or {} + self.endpoint_function = ctx.get("function") + self.endpoint_path = ctx.get("path") + self.endpoint_file = ctx.get("file") + self.endpoint_line = ctx.get("line") def errors(self) -> Sequence[Any]: return self._errors + def _format_endpoint_context(self) -> str: + if not (self.endpoint_file and self.endpoint_line and self.endpoint_function): + if self.endpoint_path: + return f"\n Endpoint: {self.endpoint_path}" + return "" + + context = f'\n File "{self.endpoint_file}", line {self.endpoint_line}, in {self.endpoint_function}' + if self.endpoint_path: + context += f"\n {self.endpoint_path}" + return context + + def __str__(self) -> str: + message = f"{len(self._errors)} validation error{'s' if len(self._errors) != 1 else ''}:\n" + for err in self._errors: + message += f" {err}\n" + message += self._format_endpoint_context() + return message.rstrip() + class RequestValidationError(ValidationException): - def __init__(self, errors: Sequence[Any], *, body: Any = None) -> None: - super().__init__(errors) + def __init__( + self, + errors: Sequence[Any], + *, + body: Any = None, + endpoint_ctx: Optional[EndpointContext] = None, + ) -> None: + super().__init__(errors, endpoint_ctx=endpoint_ctx) self.body = body class WebSocketRequestValidationError(ValidationException): - pass + def __init__( + self, + errors: Sequence[Any], + *, + endpoint_ctx: Optional[EndpointContext] = None, + ) -> None: + super().__init__(errors, endpoint_ctx=endpoint_ctx) class ResponseValidationError(ValidationException): - def __init__(self, errors: Sequence[Any], *, body: Any = None) -> None: - super().__init__(errors) + def __init__( + self, + errors: Sequence[Any], + *, + body: Any = None, + endpoint_ctx: Optional[EndpointContext] = None, + ) -> None: + super().__init__(errors, endpoint_ctx=endpoint_ctx) self.body = body - - def __str__(self) -> str: - message = f"{len(self._errors)} validation errors:\n" - for err in self._errors: - message += f" {err}\n" - return message diff --git a/fastapi/routing.py b/fastapi/routing.py index c10175b16..9be2b44bc 100644 --- a/fastapi/routing.py +++ b/fastapi/routing.py @@ -46,6 +46,7 @@ from fastapi.dependencies.utils import ( ) from fastapi.encoders import jsonable_encoder from fastapi.exceptions import ( + EndpointContext, FastAPIError, RequestValidationError, ResponseValidationError, @@ -212,6 +213,33 @@ def _merge_lifespan_context( return merged_lifespan # type: ignore[return-value] +# Cache for endpoint context to avoid re-extracting on every request +_endpoint_context_cache: Dict[int, EndpointContext] = {} + + +def _extract_endpoint_context(func: Any) -> EndpointContext: + """Extract endpoint context with caching to avoid repeated file I/O.""" + func_id = id(func) + + if func_id in _endpoint_context_cache: + return _endpoint_context_cache[func_id] + + try: + ctx: EndpointContext = {} + + if (source_file := inspect.getsourcefile(func)) is not None: + ctx["file"] = source_file + if (line_number := inspect.getsourcelines(func)[1]) is not None: + ctx["line"] = line_number + if (func_name := getattr(func, "__name__", None)) is not None: + ctx["function"] = func_name + except Exception: + ctx = EndpointContext() + + _endpoint_context_cache[func_id] = ctx + return ctx + + async def serialize_response( *, field: Optional[ModelField] = None, @@ -223,6 +251,7 @@ async def serialize_response( exclude_defaults: bool = False, exclude_none: bool = False, is_coroutine: bool = True, + endpoint_ctx: Optional[EndpointContext] = None, ) -> Any: if field: errors = [] @@ -245,8 +274,11 @@ async def serialize_response( elif errors_: errors.append(errors_) if errors: + ctx = endpoint_ctx or EndpointContext() raise ResponseValidationError( - errors=_normalize_errors(errors), body=response_content + errors=_normalize_errors(errors), + body=response_content, + endpoint_ctx=ctx, ) if hasattr(field, "serialize"): @@ -318,6 +350,18 @@ def get_request_handler( "fastapi_middleware_astack not found in request scope" ) + # Extract endpoint context for error messages + endpoint_ctx = ( + _extract_endpoint_context(dependant.call) + if dependant.call + else EndpointContext() + ) + + if dependant.path: + # For mounted sub-apps, include the mount path prefix + mount_path = request.scope.get("root_path", "").rstrip("/") + endpoint_ctx["path"] = f"{request.method} {mount_path}{dependant.path}" + # Read body and auto-close files try: body: Any = None @@ -355,6 +399,7 @@ def get_request_handler( } ], body=e.doc, + endpoint_ctx=endpoint_ctx, ) raise validation_error from e except HTTPException: @@ -414,6 +459,7 @@ def get_request_handler( exclude_defaults=response_model_exclude_defaults, exclude_none=response_model_exclude_none, is_coroutine=is_coroutine, + endpoint_ctx=endpoint_ctx, ) response = actual_response_class(content, **response_args) if not is_body_allowed_for_status_code(response.status_code): @@ -421,7 +467,7 @@ def get_request_handler( response.headers.raw.extend(solved_result.response.headers.raw) if errors: validation_error = RequestValidationError( - _normalize_errors(errors), body=body + _normalize_errors(errors), body=body, endpoint_ctx=endpoint_ctx ) raise validation_error @@ -438,6 +484,15 @@ def get_websocket_app( embed_body_fields: bool = False, ) -> Callable[[WebSocket], Coroutine[Any, Any, Any]]: async def app(websocket: WebSocket) -> None: + endpoint_ctx = ( + _extract_endpoint_context(dependant.call) + if dependant.call + else EndpointContext() + ) + if dependant.path: + # For mounted sub-apps, include the mount path prefix + mount_path = websocket.scope.get("root_path", "").rstrip("/") + endpoint_ctx["path"] = f"WS {mount_path}{dependant.path}" async_exit_stack = websocket.scope.get("fastapi_inner_astack") assert isinstance(async_exit_stack, AsyncExitStack), ( "fastapi_inner_astack not found in request scope" @@ -451,7 +506,8 @@ def get_websocket_app( ) if solved_result.errors: raise WebSocketRequestValidationError( - _normalize_errors(solved_result.errors) + _normalize_errors(solved_result.errors), + endpoint_ctx=endpoint_ctx, ) assert dependant.call is not None, "dependant.call must be a function" await dependant.call(**solved_result.values) diff --git a/scripts/mkdocs_hooks.py b/scripts/mkdocs_hooks.py index b9e4ff59e..09cfa99e3 100644 --- a/scripts/mkdocs_hooks.py +++ b/scripts/mkdocs_hooks.py @@ -132,7 +132,7 @@ def on_pre_page(page: Page, *, config: MkDocsConfig, files: Files) -> Page: def on_page_markdown( markdown: str, *, page: Page, config: MkDocsConfig, files: Files ) -> str: - # Set matadata["social"]["cards_layout_options"]["title"] to clean title (without + # Set metadata["social"]["cards_layout_options"]["title"] to clean title (without # permalink) title = page.title clean_title = title.split("{ #")[0] diff --git a/tests/test_dependency_class.py b/tests/test_dependency_class.py index 0233492e6..75241b467 100644 --- a/tests/test_dependency_class.py +++ b/tests/test_dependency_class.py @@ -48,6 +48,34 @@ async_callable_gen_dependency = AsyncCallableGenDependency() methods_dependency = MethodsDependency() +@app.get("/callable-dependency-class") +async def get_callable_dependency_class( + value: str, instance: CallableDependency = Depends() +): + return instance(value) + + +@app.get("/callable-gen-dependency-class") +async def get_callable_gen_dependency_class( + value: str, instance: CallableGenDependency = Depends() +): + return next(instance(value)) + + +@app.get("/async-callable-dependency-class") +async def get_async_callable_dependency_class( + value: str, instance: AsyncCallableDependency = Depends() +): + return await instance(value) + + +@app.get("/async-callable-gen-dependency-class") +async def get_async_callable_gen_dependency_class( + value: str, instance: AsyncCallableGenDependency = Depends() +): + return await instance(value).__anext__() + + @app.get("/callable-dependency") async def get_callable_dependency(value: str = Depends(callable_dependency)): return value @@ -114,6 +142,10 @@ client = TestClient(app) ("/synchronous-method-gen-dependency", "synchronous-method-gen-dependency"), ("/asynchronous-method-dependency", "asynchronous-method-dependency"), ("/asynchronous-method-gen-dependency", "asynchronous-method-gen-dependency"), + ("/callable-dependency-class", "callable-dependency-class"), + ("/callable-gen-dependency-class", "callable-gen-dependency-class"), + ("/async-callable-dependency-class", "async-callable-dependency-class"), + ("/async-callable-gen-dependency-class", "async-callable-gen-dependency-class"), ], ) def test_class_dependency(route, value): diff --git a/tests/test_openapi_separate_input_output_schemas.py b/tests/test_openapi_separate_input_output_schemas.py index fa73620ea..c9a05418b 100644 --- a/tests/test_openapi_separate_input_output_schemas.py +++ b/tests/test_openapi_separate_input_output_schemas.py @@ -24,6 +24,18 @@ class Item(BaseModel): model_config = {"json_schema_serialization_defaults_required": True} +if PYDANTIC_V2: + from pydantic import computed_field + + class WithComputedField(BaseModel): + name: str + + @computed_field + @property + def computed_field(self) -> str: + return f"computed {self.name}" + + def get_app_client(separate_input_output_schemas: bool = True) -> TestClient: app = FastAPI(separate_input_output_schemas=separate_input_output_schemas) @@ -46,6 +58,14 @@ def get_app_client(separate_input_output_schemas: bool = True) -> TestClient: Item(name="Plumbus"), ] + if PYDANTIC_V2: + + @app.post("/with-computed-field/") + def create_with_computed_field( + with_computed_field: WithComputedField, + ) -> WithComputedField: + return with_computed_field + client = TestClient(app) return client @@ -131,6 +151,23 @@ def test_read_items(): ) +@needs_pydanticv2 +def test_with_computed_field(): + client = get_app_client() + client_no = get_app_client(separate_input_output_schemas=False) + response = client.post("/with-computed-field/", json={"name": "example"}) + response2 = client_no.post("/with-computed-field/", json={"name": "example"}) + assert response.status_code == response2.status_code == 200, response.text + assert ( + response.json() + == response2.json() + == { + "name": "example", + "computed_field": "computed example", + } + ) + + @needs_pydanticv2 def test_openapi_schema(): client = get_app_client() @@ -245,6 +282,44 @@ def test_openapi_schema(): }, } }, + "/with-computed-field/": { + "post": { + "summary": "Create With Computed Field", + "operationId": "create_with_computed_field_with_computed_field__post", + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/WithComputedField-Input" + } + } + }, + "required": True, + }, + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/WithComputedField-Output" + } + } + }, + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + }, + }, + }, + }, }, "components": { "schemas": { @@ -333,6 +408,25 @@ def test_openapi_schema(): "required": ["subname", "sub_description", "tags"], "title": "SubItem", }, + "WithComputedField-Input": { + "properties": {"name": {"type": "string", "title": "Name"}}, + "type": "object", + "required": ["name"], + "title": "WithComputedField", + }, + "WithComputedField-Output": { + "properties": { + "name": {"type": "string", "title": "Name"}, + "computed_field": { + "type": "string", + "title": "Computed Field", + "readOnly": True, + }, + }, + "type": "object", + "required": ["name", "computed_field"], + "title": "WithComputedField", + }, "ValidationError": { "properties": { "loc": { @@ -458,6 +552,44 @@ def test_openapi_schema_no_separate(): }, } }, + "/with-computed-field/": { + "post": { + "summary": "Create With Computed Field", + "operationId": "create_with_computed_field_with_computed_field__post", + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/WithComputedField-Input" + } + } + }, + "required": True, + }, + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/WithComputedField-Output" + } + } + }, + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + }, + }, + }, + }, }, "components": { "schemas": { @@ -508,6 +640,25 @@ def test_openapi_schema_no_separate(): "required": ["subname"], "title": "SubItem", }, + "WithComputedField-Input": { + "properties": {"name": {"type": "string", "title": "Name"}}, + "type": "object", + "required": ["name"], + "title": "WithComputedField", + }, + "WithComputedField-Output": { + "properties": { + "name": {"type": "string", "title": "Name"}, + "computed_field": { + "type": "string", + "title": "Computed Field", + "readOnly": True, + }, + }, + "type": "object", + "required": ["name", "computed_field"], + "title": "WithComputedField", + }, "ValidationError": { "properties": { "loc": { diff --git a/tests/test_validation_error_context.py b/tests/test_validation_error_context.py new file mode 100644 index 000000000..844b8a64f --- /dev/null +++ b/tests/test_validation_error_context.py @@ -0,0 +1,168 @@ +from fastapi import FastAPI, Request, WebSocket +from fastapi.exceptions import ( + RequestValidationError, + ResponseValidationError, + WebSocketRequestValidationError, +) +from fastapi.testclient import TestClient +from pydantic import BaseModel + + +class Item(BaseModel): + id: int + name: str + + +class ExceptionCapture: + def __init__(self): + self.exception = None + + def capture(self, exc): + self.exception = exc + return exc + + +app = FastAPI() +sub_app = FastAPI() +captured_exception = ExceptionCapture() + +app.mount(path="/sub", app=sub_app) + + +@app.exception_handler(RequestValidationError) +@sub_app.exception_handler(RequestValidationError) +async def request_validation_handler(request: Request, exc: RequestValidationError): + captured_exception.capture(exc) + raise exc + + +@app.exception_handler(ResponseValidationError) +@sub_app.exception_handler(ResponseValidationError) +async def response_validation_handler(_: Request, exc: ResponseValidationError): + captured_exception.capture(exc) + raise exc + + +@app.exception_handler(WebSocketRequestValidationError) +@sub_app.exception_handler(WebSocketRequestValidationError) +async def websocket_validation_handler( + websocket: WebSocket, exc: WebSocketRequestValidationError +): + captured_exception.capture(exc) + raise exc + + +@app.get("/users/{user_id}") +def get_user(user_id: int): + return {"user_id": user_id} # pragma: no cover + + +@app.get("/items/", response_model=Item) +def get_item(): + return {"name": "Widget"} + + +@sub_app.get("/items/", response_model=Item) +def get_sub_item(): + return {"name": "Widget"} # pragma: no cover + + +@app.websocket("/ws/{item_id}") +async def websocket_endpoint(websocket: WebSocket, item_id: int): + await websocket.accept() # pragma: no cover + await websocket.send_text(f"Item: {item_id}") # pragma: no cover + await websocket.close() # pragma: no cover + + +@sub_app.websocket("/ws/{item_id}") +async def subapp_websocket_endpoint(websocket: WebSocket, item_id: int): + await websocket.accept() # pragma: no cover + await websocket.send_text(f"Item: {item_id}") # pragma: no cover + await websocket.close() # pragma: no cover + + +client = TestClient(app) + + +def test_request_validation_error_includes_endpoint_context(): + captured_exception.exception = None + try: + client.get("/users/invalid") + except Exception: + pass + + assert captured_exception.exception is not None + error_str = str(captured_exception.exception) + assert "get_user" in error_str + assert "/users/" in error_str + + +def test_response_validation_error_includes_endpoint_context(): + captured_exception.exception = None + try: + client.get("/items/") + except Exception: + pass + + assert captured_exception.exception is not None + error_str = str(captured_exception.exception) + assert "get_item" in error_str + assert "/items/" in error_str + + +def test_websocket_validation_error_includes_endpoint_context(): + captured_exception.exception = None + try: + with client.websocket_connect("/ws/invalid"): + pass # pragma: no cover + except Exception: + pass + + assert captured_exception.exception is not None + error_str = str(captured_exception.exception) + assert "websocket_endpoint" in error_str + assert "/ws/" in error_str + + +def test_subapp_request_validation_error_includes_endpoint_context(): + captured_exception.exception = None + try: + client.get("/sub/items/") + except Exception: + pass + + assert captured_exception.exception is not None + error_str = str(captured_exception.exception) + assert "get_sub_item" in error_str + assert "/sub/items/" in error_str + + +def test_subapp_websocket_validation_error_includes_endpoint_context(): + captured_exception.exception = None + try: + with client.websocket_connect("/sub/ws/invalid"): + pass # pragma: no cover + except Exception: + pass + + assert captured_exception.exception is not None + error_str = str(captured_exception.exception) + assert "subapp_websocket_endpoint" in error_str + assert "/sub/ws/" in error_str + + +def test_validation_error_with_only_path(): + errors = [{"type": "missing", "loc": ("body", "name"), "msg": "Field required"}] + exc = RequestValidationError(errors, endpoint_ctx={"path": "GET /api/test"}) + error_str = str(exc) + assert "Endpoint: GET /api/test" in error_str + assert 'File "' not in error_str + + +def test_validation_error_with_no_context(): + errors = [{"type": "missing", "loc": ("body", "name"), "msg": "Field required"}] + exc = RequestValidationError(errors, endpoint_ctx={}) + error_str = str(exc) + assert "1 validation error:" in error_str + assert "Endpoint" not in error_str + assert 'File "' not in error_str