mirror of https://github.com/tiangolo/fastapi.git
✨ Add support for strings and __future__ type annotations (#451)
* Add support for strings and __future__ annotations * Add comments indicating reason for string annotations * Fix ignores (including removing some unused ignores)
This commit is contained in:
parent
580cf8f4e2
commit
d8fe307d61
|
|
@ -26,7 +26,7 @@ from pydantic.error_wrappers import ErrorWrapper
|
||||||
from pydantic.errors import MissingError
|
from pydantic.errors import MissingError
|
||||||
from pydantic.fields import Field, Required, Shape
|
from pydantic.fields import Field, Required, Shape
|
||||||
from pydantic.schema import get_annotation_from_schema
|
from pydantic.schema import get_annotation_from_schema
|
||||||
from pydantic.utils import lenient_issubclass
|
from pydantic.utils import ForwardRef, evaluate_forwardref, lenient_issubclass
|
||||||
from starlette.background import BackgroundTasks
|
from starlette.background import BackgroundTasks
|
||||||
from starlette.concurrency import run_in_threadpool
|
from starlette.concurrency import run_in_threadpool
|
||||||
from starlette.datastructures import FormData, Headers, QueryParams, UploadFile
|
from starlette.datastructures import FormData, Headers, QueryParams, UploadFile
|
||||||
|
|
@ -171,6 +171,30 @@ def is_scalar_sequence_field(field: Field) -> bool:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def get_typed_signature(call: Callable) -> inspect.Signature:
|
||||||
|
signature = inspect.signature(call)
|
||||||
|
globalns = getattr(call, "__globals__", {})
|
||||||
|
typed_params = [
|
||||||
|
inspect.Parameter(
|
||||||
|
name=param.name,
|
||||||
|
kind=param.kind,
|
||||||
|
default=param.default,
|
||||||
|
annotation=get_typed_annotation(param, globalns),
|
||||||
|
)
|
||||||
|
for param in signature.parameters.values()
|
||||||
|
]
|
||||||
|
typed_signature = inspect.Signature(typed_params)
|
||||||
|
return typed_signature
|
||||||
|
|
||||||
|
|
||||||
|
def get_typed_annotation(param: inspect.Parameter, globalns: Dict[str, Any]) -> Any:
|
||||||
|
annotation = param.annotation
|
||||||
|
if isinstance(annotation, str):
|
||||||
|
annotation = ForwardRef(annotation)
|
||||||
|
annotation = evaluate_forwardref(annotation, globalns, globalns)
|
||||||
|
return annotation
|
||||||
|
|
||||||
|
|
||||||
def get_dependant(
|
def get_dependant(
|
||||||
*,
|
*,
|
||||||
path: str,
|
path: str,
|
||||||
|
|
@ -180,7 +204,7 @@ def get_dependant(
|
||||||
use_cache: bool = True,
|
use_cache: bool = True,
|
||||||
) -> Dependant:
|
) -> Dependant:
|
||||||
path_param_names = get_path_param_names(path)
|
path_param_names = get_path_param_names(path)
|
||||||
endpoint_signature = inspect.signature(call)
|
endpoint_signature = get_typed_signature(call)
|
||||||
signature_params = endpoint_signature.parameters
|
signature_params = endpoint_signature.parameters
|
||||||
dependant = Dependant(call=call, name=name, path=path, use_cache=use_cache)
|
dependant = Dependant(call=call, name=name, path=path, use_cache=use_cache)
|
||||||
for param_name, param in signature_params.items():
|
for param_name, param in signature_params.items():
|
||||||
|
|
@ -329,8 +353,12 @@ async def solve_dependencies(
|
||||||
]:
|
]:
|
||||||
values: Dict[str, Any] = {}
|
values: Dict[str, Any] = {}
|
||||||
errors: List[ErrorWrapper] = []
|
errors: List[ErrorWrapper] = []
|
||||||
response = response or Response( # type: ignore
|
response = response or Response(
|
||||||
content=None, status_code=None, headers=None, media_type=None, background=None
|
content=None,
|
||||||
|
status_code=None, # type: ignore
|
||||||
|
headers=None,
|
||||||
|
media_type=None,
|
||||||
|
background=None,
|
||||||
)
|
)
|
||||||
dependency_cache = dependency_cache or {}
|
dependency_cache = dependency_cache or {}
|
||||||
sub_dependant: Dependant
|
sub_dependant: Dependant
|
||||||
|
|
@ -405,7 +433,7 @@ async def solve_dependencies(
|
||||||
values.update(cookie_values)
|
values.update(cookie_values)
|
||||||
errors += path_errors + query_errors + header_errors + cookie_errors
|
errors += path_errors + query_errors + header_errors + cookie_errors
|
||||||
if dependant.body_params:
|
if dependant.body_params:
|
||||||
body_values, body_errors = await request_body_to_args( # type: ignore # body_params checked above
|
body_values, body_errors = await request_body_to_args( # body_params checked above
|
||||||
required_params=dependant.body_params, received_body=body
|
required_params=dependant.body_params, received_body=body
|
||||||
)
|
)
|
||||||
values.update(body_values)
|
values.update(body_values)
|
||||||
|
|
|
||||||
|
|
@ -11,7 +11,7 @@ try:
|
||||||
import email_validator
|
import email_validator
|
||||||
|
|
||||||
assert email_validator # make autoflake ignore the unused import
|
assert email_validator # make autoflake ignore the unused import
|
||||||
from pydantic.types import EmailStr # type: ignore
|
from pydantic.types import EmailStr
|
||||||
except ImportError: # pragma: no cover
|
except ImportError: # pragma: no cover
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"email-validator not installed, email fields will be treated as str.\n"
|
"email-validator not installed, email fields will be treated as str.\n"
|
||||||
|
|
|
||||||
|
|
@ -58,10 +58,10 @@ def create_cloned_field(field: Field) -> Field:
|
||||||
use_type = original_type
|
use_type = original_type
|
||||||
if lenient_issubclass(original_type, BaseModel):
|
if lenient_issubclass(original_type, BaseModel):
|
||||||
original_type = cast(Type[BaseModel], original_type)
|
original_type = cast(Type[BaseModel], original_type)
|
||||||
use_type = create_model( # type: ignore
|
use_type = create_model(
|
||||||
original_type.__name__,
|
original_type.__name__,
|
||||||
__config__=original_type.__config__,
|
__config__=original_type.__config__,
|
||||||
__validators__=original_type.__validators__,
|
__validators__=original_type.__validators__, # type: ignore
|
||||||
)
|
)
|
||||||
for f in original_type.__fields__.values():
|
for f in original_type.__fields__.values():
|
||||||
use_type.__fields__[f.name] = f
|
use_type.__fields__[f.name] = f
|
||||||
|
|
|
||||||
|
|
@ -21,18 +21,21 @@ class User(BaseModel):
|
||||||
username: str
|
username: str
|
||||||
|
|
||||||
|
|
||||||
def get_current_user(oauth_header: str = Security(reusable_oauth2)):
|
# Here we use string annotations to test them
|
||||||
|
def get_current_user(oauth_header: "str" = Security(reusable_oauth2)):
|
||||||
user = User(username=oauth_header)
|
user = User(username=oauth_header)
|
||||||
return user
|
return user
|
||||||
|
|
||||||
|
|
||||||
@app.post("/login")
|
@app.post("/login")
|
||||||
def read_current_user(form_data: OAuth2PasswordRequestFormStrict = Depends()):
|
# Here we use string annotations to test them
|
||||||
|
def read_current_user(form_data: "OAuth2PasswordRequestFormStrict" = Depends()):
|
||||||
return form_data
|
return form_data
|
||||||
|
|
||||||
|
|
||||||
@app.get("/users/me")
|
@app.get("/users/me")
|
||||||
def read_current_user(current_user: User = Depends(get_current_user)):
|
# Here we use string annotations to test them
|
||||||
|
def read_current_user(current_user: "User" = Depends(get_current_user)):
|
||||||
return current_user
|
return current_user
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue