🐛 Fix `convert_underscores=False` for header Pydantic models (#13515)

This commit is contained in:
Sebastián Ramírez 2025-03-23 20:48:54 +00:00 committed by GitHub
parent c08a3e8f22
commit 2537d9d1c2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 457 additions and 8 deletions

View File

@ -51,6 +51,22 @@ For example, if the client tries to send a `tool` header with a value of `plumbu
}
```
## Disable Convert Underscores
The same way as with regular header parameters, when you have underscore characters in the parameter names, they are **automatically converted to hyphens**.
For example, if you have a header parameter `save_data` in the code, the expected HTTP header will be `save-data`, and it will show up like that in the docs.
If for some reason you need to disable this automatic conversion, you can do it as well for Pydantic models for header parameters.
{* ../../docs_src/header_param_models/tutorial003_an_py310.py hl[19] *}
/// warning
Before setting `convert_underscores` to `False`, bear in mind that some HTTP proxies and servers disallow the usage of headers with underscores.
///
## Summary
You can use **Pydantic models** to declare **headers** in **FastAPI**. 😎

View File

@ -0,0 +1,19 @@
from typing import List, Union
from fastapi import FastAPI, Header
from pydantic import BaseModel
app = FastAPI()
class CommonHeaders(BaseModel):
host: str
save_data: bool
if_modified_since: Union[str, None] = None
traceparent: Union[str, None] = None
x_tag: List[str] = []
@app.get("/items/")
async def read_items(headers: CommonHeaders = Header(convert_underscores=False)):
return headers

View File

@ -0,0 +1,22 @@
from typing import List, Union
from fastapi import FastAPI, Header
from pydantic import BaseModel
from typing_extensions import Annotated
app = FastAPI()
class CommonHeaders(BaseModel):
host: str
save_data: bool
if_modified_since: Union[str, None] = None
traceparent: Union[str, None] = None
x_tag: List[str] = []
@app.get("/items/")
async def read_items(
headers: Annotated[CommonHeaders, Header(convert_underscores=False)],
):
return headers

View File

@ -0,0 +1,21 @@
from typing import Annotated
from fastapi import FastAPI, Header
from pydantic import BaseModel
app = FastAPI()
class CommonHeaders(BaseModel):
host: str
save_data: bool
if_modified_since: str | None = None
traceparent: str | None = None
x_tag: list[str] = []
@app.get("/items/")
async def read_items(
headers: Annotated[CommonHeaders, Header(convert_underscores=False)],
):
return headers

View File

@ -0,0 +1,21 @@
from typing import Annotated, Union
from fastapi import FastAPI, Header
from pydantic import BaseModel
app = FastAPI()
class CommonHeaders(BaseModel):
host: str
save_data: bool
if_modified_since: Union[str, None] = None
traceparent: Union[str, None] = None
x_tag: list[str] = []
@app.get("/items/")
async def read_items(
headers: Annotated[CommonHeaders, Header(convert_underscores=False)],
):
return headers

View File

@ -0,0 +1,17 @@
from fastapi import FastAPI, Header
from pydantic import BaseModel
app = FastAPI()
class CommonHeaders(BaseModel):
host: str
save_data: bool
if_modified_since: str | None = None
traceparent: str | None = None
x_tag: list[str] = []
@app.get("/items/")
async def read_items(headers: CommonHeaders = Header(convert_underscores=False)):
return headers

View File

@ -0,0 +1,19 @@
from typing import Union
from fastapi import FastAPI, Header
from pydantic import BaseModel
app = FastAPI()
class CommonHeaders(BaseModel):
host: str
save_data: bool
if_modified_since: Union[str, None] = None
traceparent: Union[str, None] = None
x_tag: list[str] = []
@app.get("/items/")
async def read_items(headers: CommonHeaders = Header(convert_underscores=False)):
return headers

View File

@ -750,9 +750,15 @@ def request_params_to_args(
first_field = fields[0]
fields_to_extract = fields
single_not_embedded_field = False
default_convert_underscores = True
if len(fields) == 1 and lenient_issubclass(first_field.type_, BaseModel):
fields_to_extract = get_cached_model_fields(first_field.type_)
single_not_embedded_field = True
# If headers are in a Pydantic model, the way to disable convert_underscores
# would be with Header(convert_underscores=False) at the Pydantic model level
default_convert_underscores = getattr(
first_field.field_info, "convert_underscores", True
)
params_to_process: Dict[str, Any] = {}
@ -763,7 +769,9 @@ def request_params_to_args(
if isinstance(received_params, Headers):
# Handle fields extracted from a Pydantic Model for a header, each field
# doesn't have a FieldInfo of type Header with the default convert_underscores=True
convert_underscores = getattr(field.field_info, "convert_underscores", True)
convert_underscores = getattr(
field.field_info, "convert_underscores", default_convert_underscores
)
if convert_underscores:
alias = (
field.alias

View File

@ -32,6 +32,7 @@ from fastapi.utils import (
generate_operation_id_for_path,
is_body_allowed_for_status_code,
)
from pydantic import BaseModel
from starlette.responses import JSONResponse
from starlette.routing import BaseRoute
from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY
@ -113,6 +114,13 @@ def _get_openapi_operation_parameters(
(ParamTypes.header, header_params),
(ParamTypes.cookie, cookie_params),
]
default_convert_underscores = True
if len(flat_dependant.header_params) == 1:
first_field = flat_dependant.header_params[0]
if lenient_issubclass(first_field.type_, BaseModel):
default_convert_underscores = getattr(
first_field.field_info, "convert_underscores", True
)
for param_type, param_group in parameter_groups:
for param in param_group:
field_info = param.field_info
@ -126,8 +134,21 @@ def _get_openapi_operation_parameters(
field_mapping=field_mapping,
separate_input_output_schemas=separate_input_output_schemas,
)
name = param.alias
convert_underscores = getattr(
param.field_info,
"convert_underscores",
default_convert_underscores,
)
if (
param_type == ParamTypes.header
and param.alias == param.name
and convert_underscores
):
name = param.name.replace("_", "-")
parameter = {
"name": param.alias,
"name": name,
"in": param_type.value,
"required": param.required,
"schema": param_schema,

View File

@ -129,13 +129,13 @@ def test_openapi_schema(client: TestClient):
"schema": {"type": "string", "title": "Host"},
},
{
"name": "save_data",
"name": "save-data",
"in": "header",
"required": True,
"schema": {"type": "boolean", "title": "Save Data"},
},
{
"name": "if_modified_since",
"name": "if-modified-since",
"in": "header",
"required": False,
"schema": IsDict(
@ -171,7 +171,7 @@ def test_openapi_schema(client: TestClient):
),
},
{
"name": "x_tag",
"name": "x-tag",
"in": "header",
"required": False,
"schema": {

View File

@ -140,13 +140,13 @@ def test_openapi_schema(client: TestClient):
"schema": {"type": "string", "title": "Host"},
},
{
"name": "save_data",
"name": "save-data",
"in": "header",
"required": True,
"schema": {"type": "boolean", "title": "Save Data"},
},
{
"name": "if_modified_since",
"name": "if-modified-since",
"in": "header",
"required": False,
"schema": IsDict(
@ -182,7 +182,7 @@ def test_openapi_schema(client: TestClient):
),
},
{
"name": "x_tag",
"name": "x-tag",
"in": "header",
"required": False,
"schema": {

View File

@ -0,0 +1,285 @@
import importlib
import pytest
from dirty_equals import IsDict
from fastapi.testclient import TestClient
from inline_snapshot import snapshot
from tests.utils import needs_py39, needs_py310
@pytest.fixture(
name="client",
params=[
"tutorial003",
pytest.param("tutorial003_py39", marks=needs_py39),
pytest.param("tutorial003_py310", marks=needs_py310),
"tutorial003_an",
pytest.param("tutorial003_an_py39", marks=needs_py39),
pytest.param("tutorial003_an_py310", marks=needs_py310),
],
)
def get_client(request: pytest.FixtureRequest):
mod = importlib.import_module(f"docs_src.header_param_models.{request.param}")
client = TestClient(mod.app)
return client
def test_header_param_model(client: TestClient):
response = client.get(
"/items/",
headers=[
("save_data", "true"),
("if_modified_since", "yesterday"),
("traceparent", "123"),
("x_tag", "one"),
("x_tag", "two"),
],
)
assert response.status_code == 200
assert response.json() == {
"host": "testserver",
"save_data": True,
"if_modified_since": "yesterday",
"traceparent": "123",
"x_tag": ["one", "two"],
}
def test_header_param_model_no_underscore(client: TestClient):
response = client.get(
"/items/",
headers=[
("save-data", "true"),
("if-modified-since", "yesterday"),
("traceparent", "123"),
("x-tag", "one"),
("x-tag", "two"),
],
)
assert response.status_code == 422
assert response.json() == snapshot(
{
"detail": [
IsDict(
{
"type": "missing",
"loc": ["header", "save_data"],
"msg": "Field required",
"input": {
"host": "testserver",
"traceparent": "123",
"x_tag": [],
"accept": "*/*",
"accept-encoding": "gzip, deflate",
"connection": "keep-alive",
"user-agent": "testclient",
"save-data": "true",
"if-modified-since": "yesterday",
"x-tag": "two",
},
}
)
| IsDict(
# TODO: remove when deprecating Pydantic v1
{
"type": "value_error.missing",
"loc": ["header", "save_data"],
"msg": "field required",
}
)
]
}
)
def test_header_param_model_defaults(client: TestClient):
response = client.get("/items/", headers=[("save_data", "true")])
assert response.status_code == 200
assert response.json() == {
"host": "testserver",
"save_data": True,
"if_modified_since": None,
"traceparent": None,
"x_tag": [],
}
def test_header_param_model_invalid(client: TestClient):
response = client.get("/items/")
assert response.status_code == 422
assert response.json() == snapshot(
{
"detail": [
IsDict(
{
"type": "missing",
"loc": ["header", "save_data"],
"msg": "Field required",
"input": {
"x_tag": [],
"host": "testserver",
"accept": "*/*",
"accept-encoding": "gzip, deflate",
"connection": "keep-alive",
"user-agent": "testclient",
},
}
)
| IsDict(
# TODO: remove when deprecating Pydantic v1
{
"type": "value_error.missing",
"loc": ["header", "save_data"],
"msg": "field required",
}
)
]
}
)
def test_header_param_model_extra(client: TestClient):
response = client.get(
"/items/", headers=[("save_data", "true"), ("tool", "plumbus")]
)
assert response.status_code == 200, response.text
assert response.json() == snapshot(
{
"host": "testserver",
"save_data": True,
"if_modified_since": None,
"traceparent": None,
"x_tag": [],
}
)
def test_openapi_schema(client: TestClient):
response = client.get("/openapi.json")
assert response.status_code == 200, response.text
assert response.json() == snapshot(
{
"openapi": "3.1.0",
"info": {"title": "FastAPI", "version": "0.1.0"},
"paths": {
"/items/": {
"get": {
"summary": "Read Items",
"operationId": "read_items_items__get",
"parameters": [
{
"name": "host",
"in": "header",
"required": True,
"schema": {"type": "string", "title": "Host"},
},
{
"name": "save_data",
"in": "header",
"required": True,
"schema": {"type": "boolean", "title": "Save Data"},
},
{
"name": "if_modified_since",
"in": "header",
"required": False,
"schema": IsDict(
{
"anyOf": [{"type": "string"}, {"type": "null"}],
"title": "If Modified Since",
}
)
| IsDict(
# TODO: remove when deprecating Pydantic v1
{
"type": "string",
"title": "If Modified Since",
}
),
},
{
"name": "traceparent",
"in": "header",
"required": False,
"schema": IsDict(
{
"anyOf": [{"type": "string"}, {"type": "null"}],
"title": "Traceparent",
}
)
| IsDict(
# TODO: remove when deprecating Pydantic v1
{
"type": "string",
"title": "Traceparent",
}
),
},
{
"name": "x_tag",
"in": "header",
"required": False,
"schema": {
"type": "array",
"items": {"type": "string"},
"default": [],
"title": "X Tag",
},
},
],
"responses": {
"200": {
"description": "Successful Response",
"content": {"application/json": {"schema": {}}},
},
"422": {
"description": "Validation Error",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/HTTPValidationError"
}
}
},
},
},
}
}
},
"components": {
"schemas": {
"HTTPValidationError": {
"properties": {
"detail": {
"items": {
"$ref": "#/components/schemas/ValidationError"
},
"type": "array",
"title": "Detail",
}
},
"type": "object",
"title": "HTTPValidationError",
},
"ValidationError": {
"properties": {
"loc": {
"items": {
"anyOf": [{"type": "string"}, {"type": "integer"}]
},
"type": "array",
"title": "Location",
},
"msg": {"type": "string", "title": "Message"},
"type": {"type": "string", "title": "Error Type"},
},
"type": "object",
"required": ["loc", "msg", "type"],
"title": "ValidationError",
},
}
},
}
)