Moved conversion logic to `_extract_form_body`, added test to cover case with dependency

This commit is contained in:
Yurii Motov 2025-11-24 14:15:36 +01:00
parent 689c11b535
commit fcdad3a183
3 changed files with 86 additions and 45 deletions

View File

@ -54,6 +54,7 @@ from fastapi.concurrency import (
asynccontextmanager,
contextmanager_in_threadpool,
)
from fastapi.datastructures import UploadFile as FastAPIUploadFile
from fastapi.dependencies.models import Dependant, SecurityRequirement
from fastapi.exceptions import DependencyScopeError
from fastapi.logger import logger
@ -875,31 +876,39 @@ async def _extract_form_body(
for field in body_fields:
value = _get_multidict_value(field, received_body)
field_info = field.field_info
if (
if ( # fmt: skip
isinstance(field_info, (params.File, temp_pydantic_v1_params.File))
and is_bytes_field(field)
and isinstance(value, UploadFile)
):
value = await value.read()
elif (
is_bytes_sequence_field(field)
and isinstance(field_info, (params.File, temp_pydantic_v1_params.File))
if is_bytes_field(field):
value = await value.read()
else:
value = FastAPIUploadFile.from_starlette(value)
elif ( # fmt: skip
isinstance(field_info, (params.File, temp_pydantic_v1_params.File))
and value_is_sequence(value)
):
# For types
assert isinstance(value, sequence_types) # type: ignore[arg-type]
results: List[Union[bytes, str]] = []
if is_bytes_sequence_field(field):
# For types
assert isinstance(value, sequence_types) # type: ignore[arg-type]
results: List[Union[bytes, str]] = []
async def process_fn(
fn: Callable[[], Coroutine[Any, Any, Any]],
) -> None:
result = await fn()
results.append(result) # noqa: B023
async def process_fn(
fn: Callable[[], Coroutine[Any, Any, Any]],
) -> None:
result = await fn()
results.append(result) # noqa: B023
async with anyio.create_task_group() as tg:
for sub_value in value:
tg.start_soon(process_fn, sub_value.read)
value = serialize_sequence_value(field=field, value=results)
async with anyio.create_task_group() as tg:
for sub_value in value:
tg.start_soon(process_fn, sub_value.read)
value = serialize_sequence_value(field=field, value=results)
else:
value = [
FastAPIUploadFile.from_starlette(sub_value)
for sub_value in value
if isinstance(sub_value, UploadFile)
]
if value is not None:
values[field.alias] = value
for key, value in received_body.items():

View File

@ -34,7 +34,7 @@ from fastapi._compat import (
_normalize_errors,
lenient_issubclass,
)
from fastapi.datastructures import Default, DefaultPlaceholder, UploadFile
from fastapi.datastructures import Default, DefaultPlaceholder
from fastapi.dependencies.models import Dependant
from fastapi.dependencies.utils import (
_should_embed_body_fields,
@ -65,7 +65,6 @@ from starlette import routing
from starlette._exception_handler import wrap_app_handling_exceptions
from starlette._utils import is_async_callable
from starlette.concurrency import run_in_threadpool
from starlette.datastructures import UploadFile as StarletteUploadFile
from starlette.exceptions import HTTPException
from starlette.requests import Request
from starlette.responses import JSONResponse, Response
@ -287,19 +286,6 @@ async def run_endpoint_function(
# facilitate profiling endpoints, since inner functions are harder to profile.
assert dependant.call is not None, "dependant.call must be a function"
# Convert all Starlette UploadFiles to FastAPI UploadFiles
for key, value in values.items():
if isinstance(value, StarletteUploadFile) and not isinstance(value, UploadFile):
values[key] = UploadFile.from_starlette(value)
elif isinstance(value, list):
values[key] = [
UploadFile.from_starlette(item)
if isinstance(item, StarletteUploadFile)
and not isinstance(item, UploadFile)
else item
for item in value
]
if is_coroutine:
return await dependant.call(**values)
else:

View File

@ -1,7 +1,8 @@
import io
from typing import Any, Dict, List
from fastapi import FastAPI, File, UploadFile
import pytest
from fastapi import Depends, FastAPI, File, UploadFile
from fastapi.testclient import TestClient
from starlette.datastructures import UploadFile as StarletteUploadFile
@ -33,10 +34,48 @@ async def uploadfiles(
]
def test_uploadfile_type() -> None:
async def get_uploadfile_info(uploadfile: UploadFile = File(...)) -> Dict[str, Any]:
return {
"filename": uploadfile.filename,
"is_fastapi_uploadfile": isinstance(uploadfile, UploadFile),
"is_starlette_uploadfile": isinstance(uploadfile, StarletteUploadFile),
"class": f"{uploadfile.__class__.__module__}.{uploadfile.__class__.__name__}",
}
@app.post("/uploadfile-dep")
async def uploadfile_dep(
uploadfile_info: Dict[str, Any] = Depends(get_uploadfile_info),
) -> Dict[str, Any]:
return uploadfile_info
async def get_uploadfiles_info(
uploadfiles: List[UploadFile] = File(...),
) -> List[Dict[str, Any]]:
return [
{
"filename": uploadfile.filename,
"is_fastapi_uploadfile": isinstance(uploadfile, UploadFile),
"is_starlette_uploadfile": isinstance(uploadfile, StarletteUploadFile),
"class": f"{uploadfile.__class__.__module__}.{uploadfile.__class__.__name__}",
}
for uploadfile in uploadfiles
]
@app.post("/uploadfiles-dep")
async def uploadfiles_dep(
uploadfiles_info: List[Dict[str, Any]] = Depends(get_uploadfiles_info),
) -> List[Dict[str, Any]]:
return uploadfiles_info
@pytest.mark.parametrize("endpoint", ["/uploadfile", "/uploadfile-dep"])
def test_uploadfile_type(endpoint: str) -> None:
client = TestClient(app)
files = {"uploadfile": ("example.txt", io.BytesIO(b"test content"), "text/plain")}
response = client.post("/uploadfile/", files=files)
response = client.post(f"{endpoint}", files=files)
data = response.json()
assert data["filename"] == "example.txt"
@ -45,19 +84,26 @@ def test_uploadfile_type() -> None:
assert data["class"].startswith("fastapi.")
def test_uploadfiles_type() -> None:
@pytest.mark.parametrize("endpoint", ["/uploadfiles", "/uploadfiles-dep"])
def test_uploadfiles_type(endpoint: str) -> None:
client = TestClient(app)
files = [
("uploadfiles", ("example.txt", io.BytesIO(b"test content"), "text/plain"))
("uploadfiles", ("example.txt", io.BytesIO(b"test content"), "text/plain")),
("uploadfiles", ("example2.txt", io.BytesIO(b"test content"), "text/plain")),
]
response = client.post("/uploadfiles/", files=files)
response = client.post(f"{endpoint}", files=files)
files_data = response.json()
assert len(files_data) == 1
assert len(files_data) == 2
data = files_data[0]
file1 = files_data[0]
assert file1["filename"] == "example.txt"
assert file1["is_fastapi_uploadfile"] is True
assert file1["is_starlette_uploadfile"] is True
assert file1["class"].startswith("fastapi.")
assert data["filename"] == "example.txt"
assert data["is_fastapi_uploadfile"] is True
assert data["is_starlette_uploadfile"] is True
assert data["class"].startswith("fastapi.")
file2 = files_data[1]
assert file2["filename"] == "example2.txt"
assert file2["is_fastapi_uploadfile"] is True
assert file2["is_starlette_uploadfile"] is True
assert file2["class"].startswith("fastapi.")