diff --git a/fastapi/dependencies/utils.py b/fastapi/dependencies/utils.py index 1e92c1ba2..7733ba1a2 100644 --- a/fastapi/dependencies/utils.py +++ b/fastapi/dependencies/utils.py @@ -54,6 +54,7 @@ from fastapi.concurrency import ( asynccontextmanager, contextmanager_in_threadpool, ) +from fastapi.datastructures import UploadFile as FastAPIUploadFile from fastapi.dependencies.models import Dependant, SecurityRequirement from fastapi.exceptions import DependencyScopeError from fastapi.logger import logger @@ -875,31 +876,39 @@ async def _extract_form_body( for field in body_fields: value = _get_multidict_value(field, received_body) field_info = field.field_info - if ( + if ( # fmt: skip isinstance(field_info, (params.File, temp_pydantic_v1_params.File)) - and is_bytes_field(field) and isinstance(value, UploadFile) ): - value = await value.read() - elif ( - is_bytes_sequence_field(field) - and isinstance(field_info, (params.File, temp_pydantic_v1_params.File)) + if is_bytes_field(field): + value = await value.read() + else: + value = FastAPIUploadFile.from_starlette(value) + elif ( # fmt: skip + isinstance(field_info, (params.File, temp_pydantic_v1_params.File)) and value_is_sequence(value) ): - # For types - assert isinstance(value, sequence_types) # type: ignore[arg-type] - results: List[Union[bytes, str]] = [] + if is_bytes_sequence_field(field): + # For types + assert isinstance(value, sequence_types) # type: ignore[arg-type] + results: List[Union[bytes, str]] = [] - async def process_fn( - fn: Callable[[], Coroutine[Any, Any, Any]], - ) -> None: - result = await fn() - results.append(result) # noqa: B023 + async def process_fn( + fn: Callable[[], Coroutine[Any, Any, Any]], + ) -> None: + result = await fn() + results.append(result) # noqa: B023 - async with anyio.create_task_group() as tg: - for sub_value in value: - tg.start_soon(process_fn, sub_value.read) - value = serialize_sequence_value(field=field, value=results) + async with anyio.create_task_group() as tg: + for sub_value in value: + tg.start_soon(process_fn, sub_value.read) + value = serialize_sequence_value(field=field, value=results) + else: + value = [ + FastAPIUploadFile.from_starlette(sub_value) + for sub_value in value + if isinstance(sub_value, UploadFile) + ] if value is not None: values[field.alias] = value for key, value in received_body.items(): diff --git a/fastapi/routing.py b/fastapi/routing.py index f46c2e89d..a8e12eb60 100644 --- a/fastapi/routing.py +++ b/fastapi/routing.py @@ -34,7 +34,7 @@ from fastapi._compat import ( _normalize_errors, lenient_issubclass, ) -from fastapi.datastructures import Default, DefaultPlaceholder, UploadFile +from fastapi.datastructures import Default, DefaultPlaceholder from fastapi.dependencies.models import Dependant from fastapi.dependencies.utils import ( _should_embed_body_fields, @@ -65,7 +65,6 @@ from starlette import routing from starlette._exception_handler import wrap_app_handling_exceptions from starlette._utils import is_async_callable from starlette.concurrency import run_in_threadpool -from starlette.datastructures import UploadFile as StarletteUploadFile from starlette.exceptions import HTTPException from starlette.requests import Request from starlette.responses import JSONResponse, Response @@ -287,19 +286,6 @@ async def run_endpoint_function( # facilitate profiling endpoints, since inner functions are harder to profile. assert dependant.call is not None, "dependant.call must be a function" - # Convert all Starlette UploadFiles to FastAPI UploadFiles - for key, value in values.items(): - if isinstance(value, StarletteUploadFile) and not isinstance(value, UploadFile): - values[key] = UploadFile.from_starlette(value) - elif isinstance(value, list): - values[key] = [ - UploadFile.from_starlette(item) - if isinstance(item, StarletteUploadFile) - and not isinstance(item, UploadFile) - else item - for item in value - ] - if is_coroutine: return await dependant.call(**values) else: diff --git a/tests/test_request_uploadfile_type.py b/tests/test_request_uploadfile_type.py index 47dec7817..6d30f818b 100644 --- a/tests/test_request_uploadfile_type.py +++ b/tests/test_request_uploadfile_type.py @@ -1,7 +1,8 @@ import io from typing import Any, Dict, List -from fastapi import FastAPI, File, UploadFile +import pytest +from fastapi import Depends, FastAPI, File, UploadFile from fastapi.testclient import TestClient from starlette.datastructures import UploadFile as StarletteUploadFile @@ -33,10 +34,48 @@ async def uploadfiles( ] -def test_uploadfile_type() -> None: +async def get_uploadfile_info(uploadfile: UploadFile = File(...)) -> Dict[str, Any]: + return { + "filename": uploadfile.filename, + "is_fastapi_uploadfile": isinstance(uploadfile, UploadFile), + "is_starlette_uploadfile": isinstance(uploadfile, StarletteUploadFile), + "class": f"{uploadfile.__class__.__module__}.{uploadfile.__class__.__name__}", + } + + +@app.post("/uploadfile-dep") +async def uploadfile_dep( + uploadfile_info: Dict[str, Any] = Depends(get_uploadfile_info), +) -> Dict[str, Any]: + return uploadfile_info + + +async def get_uploadfiles_info( + uploadfiles: List[UploadFile] = File(...), +) -> List[Dict[str, Any]]: + return [ + { + "filename": uploadfile.filename, + "is_fastapi_uploadfile": isinstance(uploadfile, UploadFile), + "is_starlette_uploadfile": isinstance(uploadfile, StarletteUploadFile), + "class": f"{uploadfile.__class__.__module__}.{uploadfile.__class__.__name__}", + } + for uploadfile in uploadfiles + ] + + +@app.post("/uploadfiles-dep") +async def uploadfiles_dep( + uploadfiles_info: List[Dict[str, Any]] = Depends(get_uploadfiles_info), +) -> List[Dict[str, Any]]: + return uploadfiles_info + + +@pytest.mark.parametrize("endpoint", ["/uploadfile", "/uploadfile-dep"]) +def test_uploadfile_type(endpoint: str) -> None: client = TestClient(app) files = {"uploadfile": ("example.txt", io.BytesIO(b"test content"), "text/plain")} - response = client.post("/uploadfile/", files=files) + response = client.post(f"{endpoint}", files=files) data = response.json() assert data["filename"] == "example.txt" @@ -45,19 +84,26 @@ def test_uploadfile_type() -> None: assert data["class"].startswith("fastapi.") -def test_uploadfiles_type() -> None: +@pytest.mark.parametrize("endpoint", ["/uploadfiles", "/uploadfiles-dep"]) +def test_uploadfiles_type(endpoint: str) -> None: client = TestClient(app) files = [ - ("uploadfiles", ("example.txt", io.BytesIO(b"test content"), "text/plain")) + ("uploadfiles", ("example.txt", io.BytesIO(b"test content"), "text/plain")), + ("uploadfiles", ("example2.txt", io.BytesIO(b"test content"), "text/plain")), ] - response = client.post("/uploadfiles/", files=files) + response = client.post(f"{endpoint}", files=files) files_data = response.json() - assert len(files_data) == 1 + assert len(files_data) == 2 - data = files_data[0] + file1 = files_data[0] + assert file1["filename"] == "example.txt" + assert file1["is_fastapi_uploadfile"] is True + assert file1["is_starlette_uploadfile"] is True + assert file1["class"].startswith("fastapi.") - assert data["filename"] == "example.txt" - assert data["is_fastapi_uploadfile"] is True - assert data["is_starlette_uploadfile"] is True - assert data["class"].startswith("fastapi.") + file2 = files_data[1] + assert file2["filename"] == "example2.txt" + assert file2["is_fastapi_uploadfile"] is True + assert file2["is_starlette_uploadfile"] is True + assert file2["class"].startswith("fastapi.")