diff --git a/fastapi/applications.py b/fastapi/applications.py index 02193312b..da3d4e9f9 100644 --- a/fastapi/applications.py +++ b/fastapi/applications.py @@ -2624,6 +2624,36 @@ class FastAPI(Starlette): """ ), ] = Default(generate_unique_id), + form_max_fields: Annotated[ + int, + Doc( + """ + Maximum number of form fields to accept. + + This limits the number of fields in a form submission to prevent + potential denial-of-service attacks. + """ + ), + ] = 1000, + form_max_files: Annotated[ + int, + Doc( + """ + Maximum number of files to accept in a form submission. + + This limits the number of files in a form submission to prevent + potential denial-of-service attacks. + """ + ), + ] = 1000, + max_part_size: Annotated[ + int, + Doc( + """ + Maximum size (in bytes) for each part in a multipart form submission. + """ + ), + ] = 1024 * 1024, ) -> Callable[[DecoratedCallable], DecoratedCallable]: """ Add a *path operation* using an HTTP POST operation. @@ -2669,6 +2699,9 @@ class FastAPI(Starlette): callbacks=callbacks, openapi_extra=openapi_extra, generate_unique_id_function=generate_unique_id_function, + form_max_fields=form_max_fields, + form_max_files=form_max_files, + max_part_size=max_part_size, ) def delete( diff --git a/fastapi/routing.py b/fastapi/routing.py index 9be2b44bc..f06a18f9e 100644 --- a/fastapi/routing.py +++ b/fastapi/routing.py @@ -332,6 +332,9 @@ def get_request_handler( response_model_exclude_none: bool = False, dependency_overrides_provider: Optional[Any] = None, embed_body_fields: bool = False, + form_max_fields: int = 1000, + form_max_files: int = 1000, + max_part_size: int = 1024 * 1024, ) -> Callable[[Request], Coroutine[Any, Any, Response]]: assert dependant.call is not None, "dependant.call must be a function" is_coroutine = dependant.is_coroutine_callable @@ -367,7 +370,11 @@ def get_request_handler( body: Any = None if body_field: if is_body_form: - body = await request.form() + body = await request.form( + max_fields=form_max_fields, + max_files=form_max_files, + max_part_size=max_part_size, + ) file_stack.push_async_callback(body.close) else: body_bytes = await request.body() @@ -591,6 +598,9 @@ class APIRoute(routing.Route): generate_unique_id_function: Union[ Callable[["APIRoute"], str], DefaultPlaceholder ] = Default(generate_unique_id), + form_max_fields: int = 1000, + form_max_files: int = 1000, + max_part_size: int = 1024 * 1024, ) -> None: self.path = path self.endpoint = endpoint @@ -621,6 +631,9 @@ class APIRoute(routing.Route): self.responses = responses or {} self.name = get_name(endpoint) if name is None else name self.path_regex, self.path_format, self.param_convertors = compile_path(path) + self.form_max_fields = form_max_fields + self.form_max_files = form_max_files + self.max_part_size = max_part_size if methods is None: methods = ["GET"] self.methods: Set[str] = {method.upper() for method in methods} @@ -717,6 +730,9 @@ class APIRoute(routing.Route): response_model_exclude_none=self.response_model_exclude_none, dependency_overrides_provider=self.dependency_overrides_provider, embed_body_fields=self._embed_body_fields, + form_max_fields=self.form_max_fields, + form_max_files=self.form_max_files, + max_part_size=self.max_part_size, ) def matches(self, scope: Scope) -> Tuple[Match, Scope]: @@ -1045,6 +1061,9 @@ class APIRouter(routing.Router): generate_unique_id_function: Union[ Callable[[APIRoute], str], DefaultPlaceholder ] = Default(generate_unique_id), + form_max_fields: int = 1000, + form_max_files: int = 1000, + form_max_part_size: int = 1024 * 1024, ) -> None: route_class = route_class_override or self.route_class responses = responses or {} @@ -1091,6 +1110,9 @@ class APIRouter(routing.Router): callbacks=current_callbacks, openapi_extra=openapi_extra, generate_unique_id_function=current_generate_unique_id, + form_max_fields=form_max_fields, + form_max_files=form_max_files, + max_part_size=form_max_part_size, ) self.routes.append(route) @@ -1123,6 +1145,9 @@ class APIRouter(routing.Router): generate_unique_id_function: Callable[[APIRoute], str] = Default( generate_unique_id ), + form_max_fields: int = 1000, + form_max_files: int = 1000, + form_max_part_size: int = 1024 * 1024, ) -> Callable[[DecoratedCallable], DecoratedCallable]: def decorator(func: DecoratedCallable) -> DecoratedCallable: self.add_api_route( @@ -1151,6 +1176,9 @@ class APIRouter(routing.Router): callbacks=callbacks, openapi_extra=openapi_extra, generate_unique_id_function=generate_unique_id_function, + form_max_fields=form_max_fields, + form_max_files=form_max_files, + form_max_part_size=form_max_part_size, ) return func @@ -2587,6 +2615,36 @@ class APIRouter(routing.Router): """ ), ] = Default(generate_unique_id), + form_max_fields: Annotated[ + int, + Doc( + """ + Maximum number of form fields to accept. + + This limits the number of fields in a form submission to prevent + potential denial-of-service attacks. + """ + ), + ] = 1000, + form_max_files: Annotated[ + int, + Doc( + """ + Maximum number of files to accept in a form submission. + + This limits the number of files in a form submission to prevent + potential denial-of-service attacks. + """ + ), + ] = 1000, + max_part_size: Annotated[ + int, + Doc( + """ + Maximum size (in bytes) for each part in a multipart form submission. + """ + ), + ] = 1024 * 1024, ) -> Callable[[DecoratedCallable], DecoratedCallable]: """ Add a *path operation* using an HTTP POST operation. @@ -2636,6 +2694,9 @@ class APIRouter(routing.Router): callbacks=callbacks, openapi_extra=openapi_extra, generate_unique_id_function=generate_unique_id_function, + form_max_fields=form_max_fields, + form_max_files=form_max_files, + form_max_part_size=max_part_size, ) def delete( diff --git a/tests/test_form_max_fields_files_part_size.py b/tests/test_form_max_fields_files_part_size.py new file mode 100644 index 000000000..9a5056011 --- /dev/null +++ b/tests/test_form_max_fields_files_part_size.py @@ -0,0 +1,89 @@ +from typing import List + +from fastapi import FastAPI, File, UploadFile +from fastapi.testclient import TestClient + +app = FastAPI() + + +@app.post("/", form_max_files=2, max_part_size=1024, form_max_fields=2) +async def upload_files(files: List[UploadFile] = File(...)): + return {"filenames": [file.filename for file in files]} + + +def test_form_max_files_send_one(): + client = TestClient(app) + + response = client.post( + "/", + files=[ + ("files", ("file1.txt", b"file1 content", "text/plain")), + ], + ) + + assert response.status_code == 200, response.text + assert response.json() == {"filenames": ["file1.txt"]} + + +def test_form_max_files_send_too_many(): + client = TestClient(app) + + response = client.post( + "/", + files=[ + ("files", ("file1.txt", b"file1 content", "text/plain")), + ("files", ("file2.txt", b"file2 content", "text/plain")), + ("files", ("file3.txt", b"file3 content", "text/plain")), + ], + ) + + assert response.status_code == 400, response.text + assert response.json() == { + "detail": "Too many files. Maximum number of files is 2." + } + + +def test_max_part_size_exceeds_custom_limit(): + client = TestClient(app) + + boundary = "------------------------4K1ON9fZkj9uCUmqLHRbbR" + + multipart_data = ( + f"--{boundary}\r\n" + f'Content-Disposition: form-data; name="small"\r\n\r\n' + "small content\r\n" + f"--{boundary}\r\n" + f'Content-Disposition: form-data; name="large"\r\n\r\n' + + ("x" * 1024 + "x") # 1KB + 1 byte of data + + "\r\n" + f"--{boundary}--\r\n" + ).encode("utf-8") + + headers = { + "Content-Type": f"multipart/form-data; boundary={boundary}", + "Transfer-Encoding": "chunked", + } + + response = client.post("/", content=multipart_data, headers=headers) + assert response.status_code == 400 + assert response.text == '{"detail":"Part exceeded maximum size of 1KB."}' + + +def test_form_max_fields_exceeds_limit(): + client = TestClient(app) + + response = client.post( + "/", + files=[("files", ("file1.txt", b"file1 content", "text/plain"))], + data={ + "field1": "value1", + "field2": "value2", + "field3": "value3", + "field4": "value4", + }, + ) + + assert response.status_code == 400, response.text + assert response.json() == { + "detail": "Too many fields. Maximum number of fields is 2." + }