mirror of https://github.com/tiangolo/fastapi.git
✨ Add support for multi-file uploads (#158)
This commit is contained in:
parent
e40e87c662
commit
aad6b123f7
|
|
@ -0,0 +1,33 @@
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
from fastapi import FastAPI, File, UploadFile
|
||||||
|
from starlette.responses import HTMLResponse
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/files/")
|
||||||
|
async def create_files(files: List[bytes] = File(...)):
|
||||||
|
return {"file_sizes": [len(file) for file in files]}
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/uploadfiles/")
|
||||||
|
async def create_upload_files(files: List[UploadFile] = File(...)):
|
||||||
|
return {"filenames": [file.filename for file in files]}
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/")
|
||||||
|
async def main():
|
||||||
|
content = """
|
||||||
|
<body>
|
||||||
|
<form action="/files/" enctype="multipart/form-data" method="post">
|
||||||
|
<input name="files" type="file" multiple>
|
||||||
|
<input type="submit">
|
||||||
|
</form>
|
||||||
|
<form action="/uploadfiles/" enctype="multipart/form-data" method="post">
|
||||||
|
<input name="files" type="file" multiple>
|
||||||
|
<input type="submit">
|
||||||
|
</form>
|
||||||
|
</body>
|
||||||
|
"""
|
||||||
|
return HTMLResponse(content=content)
|
||||||
|
|
@ -43,7 +43,7 @@ Using `UploadFile` has several advantages over `bytes`:
|
||||||
|
|
||||||
* It uses a "spooled" file:
|
* It uses a "spooled" file:
|
||||||
* A file stored in memory up to a maximum size limit, and after passing this limit it will be stored in disk.
|
* A file stored in memory up to a maximum size limit, and after passing this limit it will be stored in disk.
|
||||||
* This means that it will work well for large files like images, videos, large binaries, etc. All without consuming all the memory.
|
* This means that it will work well for large files like images, videos, large binaries, etc. without consuming all the memory.
|
||||||
* You can get metadata from the uploaded file.
|
* You can get metadata from the uploaded file.
|
||||||
* It has a <a href="https://docs.python.org/3/glossary.html#term-file-like-object" target="_blank">file-like</a> `async` interface.
|
* It has a <a href="https://docs.python.org/3/glossary.html#term-file-like-object" target="_blank">file-like</a> `async` interface.
|
||||||
* It exposes an actual Python <a href="https://docs.python.org/3/library/tempfile.html#tempfile.SpooledTemporaryFile" target="_blank">`SpooledTemporaryFile`</a> object that you can pass directly to other libraries that expect a file-like object.
|
* It exposes an actual Python <a href="https://docs.python.org/3/library/tempfile.html#tempfile.SpooledTemporaryFile" target="_blank">`SpooledTemporaryFile`</a> object that you can pass directly to other libraries that expect a file-like object.
|
||||||
|
|
@ -107,6 +107,20 @@ The way HTML forms (`<form></form>`) sends the data to the server normally uses
|
||||||
|
|
||||||
This is not a limitation of **FastAPI**, it's part of the HTTP protocol.
|
This is not a limitation of **FastAPI**, it's part of the HTTP protocol.
|
||||||
|
|
||||||
|
## Multiple file uploads
|
||||||
|
|
||||||
|
It's possible to upload several files at the same time.
|
||||||
|
|
||||||
|
They would be associated to the same "form field" sent using "form data".
|
||||||
|
|
||||||
|
To use that, declare a `List` of `bytes` or `UploadFile`:
|
||||||
|
|
||||||
|
```Python hl_lines="10 15"
|
||||||
|
{!./src/request_files/tutorial002.py!}
|
||||||
|
```
|
||||||
|
|
||||||
|
You will receive, as declared, a `list` of `bytes` or `UploadFile`s.
|
||||||
|
|
||||||
## Recap
|
## Recap
|
||||||
|
|
||||||
Use `File` to declare files to be uploaded as input parameters (as form data).
|
Use `File` to declare files to be uploaded as input parameters (as form data).
|
||||||
|
|
|
||||||
|
|
@ -31,8 +31,8 @@ from pydantic.schema import get_annotation_from_schema
|
||||||
from pydantic.utils import lenient_issubclass
|
from pydantic.utils import lenient_issubclass
|
||||||
from starlette.background import BackgroundTasks
|
from starlette.background import BackgroundTasks
|
||||||
from starlette.concurrency import run_in_threadpool
|
from starlette.concurrency import run_in_threadpool
|
||||||
from starlette.datastructures import UploadFile
|
from starlette.datastructures import FormData, Headers, QueryParams, UploadFile
|
||||||
from starlette.requests import Headers, QueryParams, Request
|
from starlette.requests import Request
|
||||||
|
|
||||||
param_supported_types = (
|
param_supported_types = (
|
||||||
str,
|
str,
|
||||||
|
|
@ -47,6 +47,10 @@ param_supported_types = (
|
||||||
Decimal,
|
Decimal,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
sequence_shapes = {Shape.LIST, Shape.SET, Shape.TUPLE}
|
||||||
|
sequence_types = (list, set, tuple)
|
||||||
|
sequence_shape_to_type = {Shape.LIST: list, Shape.SET: set, Shape.TUPLE: tuple}
|
||||||
|
|
||||||
|
|
||||||
def get_sub_dependant(
|
def get_sub_dependant(
|
||||||
*, param: inspect.Parameter, path: str, security_scopes: List[str] = None
|
*, param: inspect.Parameter, path: str, security_scopes: List[str] = None
|
||||||
|
|
@ -318,7 +322,7 @@ def request_params_to_args(
|
||||||
values = {}
|
values = {}
|
||||||
errors = []
|
errors = []
|
||||||
for field in required_params:
|
for field in required_params:
|
||||||
if field.shape in {Shape.LIST, Shape.SET, Shape.TUPLE} and isinstance(
|
if field.shape in sequence_shapes and isinstance(
|
||||||
received_params, (QueryParams, Headers)
|
received_params, (QueryParams, Headers)
|
||||||
):
|
):
|
||||||
value = received_params.getlist(field.alias)
|
value = received_params.getlist(field.alias)
|
||||||
|
|
@ -358,11 +362,20 @@ async def request_body_to_args(
|
||||||
embed = getattr(field.schema, "embed", None)
|
embed = getattr(field.schema, "embed", None)
|
||||||
if len(required_params) == 1 and not embed:
|
if len(required_params) == 1 and not embed:
|
||||||
received_body = {field.alias: received_body}
|
received_body = {field.alias: received_body}
|
||||||
elif received_body is None:
|
|
||||||
received_body = {}
|
|
||||||
for field in required_params:
|
for field in required_params:
|
||||||
value = received_body.get(field.alias)
|
if field.shape in sequence_shapes and isinstance(received_body, FormData):
|
||||||
if value is None or (isinstance(field.schema, params.Form) and value == ""):
|
value = received_body.getlist(field.alias)
|
||||||
|
else:
|
||||||
|
value = received_body.get(field.alias)
|
||||||
|
if (
|
||||||
|
value is None
|
||||||
|
or (isinstance(field.schema, params.Form) and value == "")
|
||||||
|
or (
|
||||||
|
isinstance(field.schema, params.Form)
|
||||||
|
and field.shape in sequence_shapes
|
||||||
|
and len(value) == 0
|
||||||
|
)
|
||||||
|
):
|
||||||
if field.required:
|
if field.required:
|
||||||
errors.append(
|
errors.append(
|
||||||
ErrorWrapper(
|
ErrorWrapper(
|
||||||
|
|
@ -380,6 +393,15 @@ async def request_body_to_args(
|
||||||
and isinstance(value, UploadFile)
|
and isinstance(value, UploadFile)
|
||||||
):
|
):
|
||||||
value = await value.read()
|
value = await value.read()
|
||||||
|
elif (
|
||||||
|
field.shape in sequence_shapes
|
||||||
|
and isinstance(field.schema, params.File)
|
||||||
|
and lenient_issubclass(field.type_, bytes)
|
||||||
|
and isinstance(value, sequence_types)
|
||||||
|
):
|
||||||
|
awaitables = [sub_value.read() for sub_value in value]
|
||||||
|
contents = await asyncio.gather(*awaitables)
|
||||||
|
value = sequence_shape_to_type[field.shape](contents)
|
||||||
v_, errors_ = field.validate(value, values, loc=("body", field.alias))
|
v_, errors_ = field.validate(value, values, loc=("body", field.alias))
|
||||||
if isinstance(errors_, ErrorWrapper):
|
if isinstance(errors_, ErrorWrapper):
|
||||||
errors.append(errors_)
|
errors.append(errors_)
|
||||||
|
|
@ -391,10 +413,14 @@ async def request_body_to_args(
|
||||||
|
|
||||||
|
|
||||||
def get_schema_compatible_field(*, field: Field) -> Field:
|
def get_schema_compatible_field(*, field: Field) -> Field:
|
||||||
|
out_field = field
|
||||||
if lenient_issubclass(field.type_, UploadFile):
|
if lenient_issubclass(field.type_, UploadFile):
|
||||||
return Field(
|
use_type: type = bytes
|
||||||
|
if field.shape in sequence_shapes:
|
||||||
|
use_type = List[bytes]
|
||||||
|
out_field = Field(
|
||||||
name=field.name,
|
name=field.name,
|
||||||
type_=bytes,
|
type_=use_type,
|
||||||
class_validators=field.class_validators,
|
class_validators=field.class_validators,
|
||||||
model_config=field.model_config,
|
model_config=field.model_config,
|
||||||
default=field.default,
|
default=field.default,
|
||||||
|
|
@ -402,10 +428,10 @@ def get_schema_compatible_field(*, field: Field) -> Field:
|
||||||
alias=field.alias,
|
alias=field.alias,
|
||||||
schema=field.schema,
|
schema=field.schema,
|
||||||
)
|
)
|
||||||
return field
|
return out_field
|
||||||
|
|
||||||
|
|
||||||
def get_body_field(*, dependant: Dependant, name: str) -> Field:
|
def get_body_field(*, dependant: Dependant, name: str) -> Optional[Field]:
|
||||||
flat_dependant = get_flat_dependant(dependant)
|
flat_dependant = get_flat_dependant(dependant)
|
||||||
if not flat_dependant.body_params:
|
if not flat_dependant.body_params:
|
||||||
return None
|
return None
|
||||||
|
|
|
||||||
|
|
@ -53,12 +53,7 @@ def get_app(
|
||||||
body = None
|
body = None
|
||||||
if body_field:
|
if body_field:
|
||||||
if is_body_form:
|
if is_body_form:
|
||||||
raw_body = await request.form()
|
body = await request.form()
|
||||||
form_fields = {}
|
|
||||||
for field, value in raw_body.items():
|
|
||||||
form_fields[field] = value
|
|
||||||
if form_fields:
|
|
||||||
body = form_fields
|
|
||||||
else:
|
else:
|
||||||
body_bytes = await request.body()
|
body_bytes = await request.body()
|
||||||
if body_bytes:
|
if body_bytes:
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,219 @@
|
||||||
|
import os
|
||||||
|
|
||||||
|
from starlette.testclient import TestClient
|
||||||
|
|
||||||
|
from request_files.tutorial002 import app
|
||||||
|
|
||||||
|
client = TestClient(app)
|
||||||
|
|
||||||
|
openapi_schema = {
|
||||||
|
"openapi": "3.0.2",
|
||||||
|
"info": {"title": "Fast API", "version": "0.1.0"},
|
||||||
|
"paths": {
|
||||||
|
"/files/": {
|
||||||
|
"post": {
|
||||||
|
"responses": {
|
||||||
|
"200": {
|
||||||
|
"description": "Successful Response",
|
||||||
|
"content": {"application/json": {"schema": {}}},
|
||||||
|
},
|
||||||
|
"422": {
|
||||||
|
"description": "Validation Error",
|
||||||
|
"content": {
|
||||||
|
"application/json": {
|
||||||
|
"schema": {
|
||||||
|
"$ref": "#/components/schemas/HTTPValidationError"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"summary": "Create Files",
|
||||||
|
"operationId": "create_files_files__post",
|
||||||
|
"requestBody": {
|
||||||
|
"content": {
|
||||||
|
"multipart/form-data": {
|
||||||
|
"schema": {"$ref": "#/components/schemas/Body_create_files"}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": True,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"/uploadfiles/": {
|
||||||
|
"post": {
|
||||||
|
"responses": {
|
||||||
|
"200": {
|
||||||
|
"description": "Successful Response",
|
||||||
|
"content": {"application/json": {"schema": {}}},
|
||||||
|
},
|
||||||
|
"422": {
|
||||||
|
"description": "Validation Error",
|
||||||
|
"content": {
|
||||||
|
"application/json": {
|
||||||
|
"schema": {
|
||||||
|
"$ref": "#/components/schemas/HTTPValidationError"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"summary": "Create Upload Files",
|
||||||
|
"operationId": "create_upload_files_uploadfiles__post",
|
||||||
|
"requestBody": {
|
||||||
|
"content": {
|
||||||
|
"multipart/form-data": {
|
||||||
|
"schema": {
|
||||||
|
"$ref": "#/components/schemas/Body_create_upload_files"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": True,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"/": {
|
||||||
|
"get": {
|
||||||
|
"responses": {
|
||||||
|
"200": {
|
||||||
|
"description": "Successful Response",
|
||||||
|
"content": {"application/json": {"schema": {}}},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"summary": "Main",
|
||||||
|
"operationId": "main__get",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"components": {
|
||||||
|
"schemas": {
|
||||||
|
"Body_create_files": {
|
||||||
|
"title": "Body_create_files",
|
||||||
|
"required": ["files"],
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"files": {
|
||||||
|
"title": "Files",
|
||||||
|
"type": "array",
|
||||||
|
"items": {"type": "string", "format": "binary"},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"Body_create_upload_files": {
|
||||||
|
"title": "Body_create_upload_files",
|
||||||
|
"required": ["files"],
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"files": {
|
||||||
|
"title": "Files",
|
||||||
|
"type": "array",
|
||||||
|
"items": {"type": "string", "format": "binary"},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"ValidationError": {
|
||||||
|
"title": "ValidationError",
|
||||||
|
"required": ["loc", "msg", "type"],
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"loc": {
|
||||||
|
"title": "Location",
|
||||||
|
"type": "array",
|
||||||
|
"items": {"type": "string"},
|
||||||
|
},
|
||||||
|
"msg": {"title": "Message", "type": "string"},
|
||||||
|
"type": {"title": "Error Type", "type": "string"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"HTTPValidationError": {
|
||||||
|
"title": "HTTPValidationError",
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"detail": {
|
||||||
|
"title": "Detail",
|
||||||
|
"type": "array",
|
||||||
|
"items": {"$ref": "#/components/schemas/ValidationError"},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def test_openapi_schema():
|
||||||
|
response = client.get("/openapi.json")
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json() == openapi_schema
|
||||||
|
|
||||||
|
|
||||||
|
file_required = {
|
||||||
|
"detail": [
|
||||||
|
{
|
||||||
|
"loc": ["body", "files"],
|
||||||
|
"msg": "field required",
|
||||||
|
"type": "value_error.missing",
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def test_post_form_no_body():
|
||||||
|
response = client.post("/files/")
|
||||||
|
assert response.status_code == 422
|
||||||
|
assert response.json() == file_required
|
||||||
|
|
||||||
|
|
||||||
|
def test_post_body_json():
|
||||||
|
response = client.post("/files/", json={"file": "Foo"})
|
||||||
|
print(response)
|
||||||
|
print(response.content)
|
||||||
|
assert response.status_code == 422
|
||||||
|
assert response.json() == file_required
|
||||||
|
|
||||||
|
|
||||||
|
def test_post_files(tmpdir):
|
||||||
|
path = os.path.join(tmpdir, "test.txt")
|
||||||
|
with open(path, "wb") as file:
|
||||||
|
file.write(b"<file content>")
|
||||||
|
path2 = os.path.join(tmpdir, "test2.txt")
|
||||||
|
with open(path2, "wb") as file:
|
||||||
|
file.write(b"<file content2>")
|
||||||
|
|
||||||
|
client = TestClient(app)
|
||||||
|
response = client.post(
|
||||||
|
"/files/",
|
||||||
|
files=(
|
||||||
|
("files", ("test.txt", open(path, "rb"))),
|
||||||
|
("files", ("test2.txt", open(path2, "rb"))),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json() == {"file_sizes": [14, 15]}
|
||||||
|
|
||||||
|
|
||||||
|
def test_post_upload_file(tmpdir):
|
||||||
|
path = os.path.join(tmpdir, "test.txt")
|
||||||
|
with open(path, "wb") as file:
|
||||||
|
file.write(b"<file content>")
|
||||||
|
path2 = os.path.join(tmpdir, "test2.txt")
|
||||||
|
with open(path2, "wb") as file:
|
||||||
|
file.write(b"<file content2>")
|
||||||
|
|
||||||
|
client = TestClient(app)
|
||||||
|
response = client.post(
|
||||||
|
"/uploadfiles/",
|
||||||
|
files=(
|
||||||
|
("files", ("test.txt", open(path, "rb"))),
|
||||||
|
("files", ("test2.txt", open(path2, "rb"))),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json() == {"filenames": ["test.txt", "test2.txt"]}
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_root():
|
||||||
|
client = TestClient(app)
|
||||||
|
response = client.get("/")
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert b"<form" in response.content
|
||||||
Loading…
Reference in New Issue