diff --git a/fastapi/datastructures.py b/fastapi/datastructures.py index 8ad9aa11a..c4ea83021 100644 --- a/fastapi/datastructures.py +++ b/fastapi/datastructures.py @@ -71,6 +71,17 @@ class UploadFile(StarletteUploadFile): Optional[str], Doc("The content type of the request, from the headers.") ] + @classmethod + def from_starlette( + cls: Type["UploadFile"], starlette_uploadfile: StarletteUploadFile + ) -> "UploadFile": + return cls( + file=starlette_uploadfile.file, + size=starlette_uploadfile.size, + filename=starlette_uploadfile.filename, + headers=starlette_uploadfile.headers, + ) + async def write( self, data: Annotated[ diff --git a/fastapi/dependencies/utils.py b/fastapi/dependencies/utils.py index cc7e55b4b..fc01c4ea6 100644 --- a/fastapi/dependencies/utils.py +++ b/fastapi/dependencies/utils.py @@ -55,6 +55,7 @@ from fastapi.concurrency import ( asynccontextmanager, contextmanager_in_threadpool, ) +from fastapi.datastructures import UploadFile as FastAPIUploadFile from fastapi.dependencies.models import Dependant from fastapi.exceptions import DependencyScopeError from fastapi.logger import logger @@ -908,31 +909,39 @@ async def _extract_form_body( for field in body_fields: value = _get_multidict_value(field, received_body) field_info = field.field_info - if ( + if ( # fmt: skip isinstance(field_info, (params.File, temp_pydantic_v1_params.File)) - and is_bytes_field(field) and isinstance(value, UploadFile) ): - value = await value.read() - elif ( - is_bytes_sequence_field(field) - and isinstance(field_info, (params.File, temp_pydantic_v1_params.File)) + if is_bytes_field(field): + value = await value.read() + else: + value = FastAPIUploadFile.from_starlette(value) + elif ( # fmt: skip + isinstance(field_info, (params.File, temp_pydantic_v1_params.File)) and value_is_sequence(value) ): - # For types - assert isinstance(value, sequence_types) # type: ignore[arg-type] - results: List[Union[bytes, str]] = [] + if is_bytes_sequence_field(field): + # For types + assert isinstance(value, sequence_types) # type: ignore[arg-type] + results: List[Union[bytes, str]] = [] - async def process_fn( - fn: Callable[[], Coroutine[Any, Any, Any]], - ) -> None: - result = await fn() - results.append(result) # noqa: B023 + async def process_fn( + fn: Callable[[], Coroutine[Any, Any, Any]], + ) -> None: + result = await fn() + results.append(result) # noqa: B023 - async with anyio.create_task_group() as tg: - for sub_value in value: - tg.start_soon(process_fn, sub_value.read) - value = serialize_sequence_value(field=field, value=results) + async with anyio.create_task_group() as tg: + for sub_value in value: + tg.start_soon(process_fn, sub_value.read) + value = serialize_sequence_value(field=field, value=results) + else: + value = [ + FastAPIUploadFile.from_starlette(sub_value) + for sub_value in value + if isinstance(sub_value, UploadFile) + ] if value is not None: values[get_validation_alias(field)] = value field_aliases = {get_validation_alias(field) for field in body_fields} diff --git a/tests/test_request_uploadfile_type.py b/tests/test_request_uploadfile_type.py new file mode 100644 index 000000000..6d30f818b --- /dev/null +++ b/tests/test_request_uploadfile_type.py @@ -0,0 +1,109 @@ +import io +from typing import Any, Dict, List + +import pytest +from fastapi import Depends, FastAPI, File, UploadFile +from fastapi.testclient import TestClient +from starlette.datastructures import UploadFile as StarletteUploadFile + +app = FastAPI() + + +@app.post("/uploadfile") +async def uploadfile(uploadfile: UploadFile = File(...)) -> Dict[str, Any]: + return { + "filename": uploadfile.filename, + "is_fastapi_uploadfile": isinstance(uploadfile, UploadFile), + "is_starlette_uploadfile": isinstance(uploadfile, StarletteUploadFile), + "class": f"{uploadfile.__class__.__module__}.{uploadfile.__class__.__name__}", + } + + +@app.post("/uploadfiles") +async def uploadfiles( + uploadfiles: List[UploadFile] = File(...), +) -> List[Dict[str, Any]]: + return [ + { + "filename": uploadfile.filename, + "is_fastapi_uploadfile": isinstance(uploadfile, UploadFile), + "is_starlette_uploadfile": isinstance(uploadfile, StarletteUploadFile), + "class": f"{uploadfile.__class__.__module__}.{uploadfile.__class__.__name__}", + } + for uploadfile in uploadfiles + ] + + +async def get_uploadfile_info(uploadfile: UploadFile = File(...)) -> Dict[str, Any]: + return { + "filename": uploadfile.filename, + "is_fastapi_uploadfile": isinstance(uploadfile, UploadFile), + "is_starlette_uploadfile": isinstance(uploadfile, StarletteUploadFile), + "class": f"{uploadfile.__class__.__module__}.{uploadfile.__class__.__name__}", + } + + +@app.post("/uploadfile-dep") +async def uploadfile_dep( + uploadfile_info: Dict[str, Any] = Depends(get_uploadfile_info), +) -> Dict[str, Any]: + return uploadfile_info + + +async def get_uploadfiles_info( + uploadfiles: List[UploadFile] = File(...), +) -> List[Dict[str, Any]]: + return [ + { + "filename": uploadfile.filename, + "is_fastapi_uploadfile": isinstance(uploadfile, UploadFile), + "is_starlette_uploadfile": isinstance(uploadfile, StarletteUploadFile), + "class": f"{uploadfile.__class__.__module__}.{uploadfile.__class__.__name__}", + } + for uploadfile in uploadfiles + ] + + +@app.post("/uploadfiles-dep") +async def uploadfiles_dep( + uploadfiles_info: List[Dict[str, Any]] = Depends(get_uploadfiles_info), +) -> List[Dict[str, Any]]: + return uploadfiles_info + + +@pytest.mark.parametrize("endpoint", ["/uploadfile", "/uploadfile-dep"]) +def test_uploadfile_type(endpoint: str) -> None: + client = TestClient(app) + files = {"uploadfile": ("example.txt", io.BytesIO(b"test content"), "text/plain")} + response = client.post(f"{endpoint}", files=files) + data = response.json() + + assert data["filename"] == "example.txt" + assert data["is_fastapi_uploadfile"] is True + assert data["is_starlette_uploadfile"] is True + assert data["class"].startswith("fastapi.") + + +@pytest.mark.parametrize("endpoint", ["/uploadfiles", "/uploadfiles-dep"]) +def test_uploadfiles_type(endpoint: str) -> None: + client = TestClient(app) + files = [ + ("uploadfiles", ("example.txt", io.BytesIO(b"test content"), "text/plain")), + ("uploadfiles", ("example2.txt", io.BytesIO(b"test content"), "text/plain")), + ] + response = client.post(f"{endpoint}", files=files) + files_data = response.json() + + assert len(files_data) == 2 + + file1 = files_data[0] + assert file1["filename"] == "example.txt" + assert file1["is_fastapi_uploadfile"] is True + assert file1["is_starlette_uploadfile"] is True + assert file1["class"].startswith("fastapi.") + + file2 = files_data[1] + assert file2["filename"] == "example2.txt" + assert file2["is_fastapi_uploadfile"] is True + assert file2["is_starlette_uploadfile"] is True + assert file2["class"].startswith("fastapi.")