mirror of https://github.com/tiangolo/fastapi.git
♻️ Refactor and simplify internal data from `solve_dependencies()` using dataclasses (#12100)
This commit is contained in:
parent
8d7d89e8c6
commit
5b7fa3900e
|
|
@ -529,6 +529,15 @@ async def solve_generator(
|
||||||
return await stack.enter_async_context(cm)
|
return await stack.enter_async_context(cm)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class SolvedDependency:
|
||||||
|
values: Dict[str, Any]
|
||||||
|
errors: List[Any]
|
||||||
|
background_tasks: Optional[StarletteBackgroundTasks]
|
||||||
|
response: Response
|
||||||
|
dependency_cache: Dict[Tuple[Callable[..., Any], Tuple[str]], Any]
|
||||||
|
|
||||||
|
|
||||||
async def solve_dependencies(
|
async def solve_dependencies(
|
||||||
*,
|
*,
|
||||||
request: Union[Request, WebSocket],
|
request: Union[Request, WebSocket],
|
||||||
|
|
@ -539,13 +548,7 @@ async def solve_dependencies(
|
||||||
dependency_overrides_provider: Optional[Any] = None,
|
dependency_overrides_provider: Optional[Any] = None,
|
||||||
dependency_cache: Optional[Dict[Tuple[Callable[..., Any], Tuple[str]], Any]] = None,
|
dependency_cache: Optional[Dict[Tuple[Callable[..., Any], Tuple[str]], Any]] = None,
|
||||||
async_exit_stack: AsyncExitStack,
|
async_exit_stack: AsyncExitStack,
|
||||||
) -> Tuple[
|
) -> SolvedDependency:
|
||||||
Dict[str, Any],
|
|
||||||
List[Any],
|
|
||||||
Optional[StarletteBackgroundTasks],
|
|
||||||
Response,
|
|
||||||
Dict[Tuple[Callable[..., Any], Tuple[str]], Any],
|
|
||||||
]:
|
|
||||||
values: Dict[str, Any] = {}
|
values: Dict[str, Any] = {}
|
||||||
errors: List[Any] = []
|
errors: List[Any] = []
|
||||||
if response is None:
|
if response is None:
|
||||||
|
|
@ -587,27 +590,21 @@ async def solve_dependencies(
|
||||||
dependency_cache=dependency_cache,
|
dependency_cache=dependency_cache,
|
||||||
async_exit_stack=async_exit_stack,
|
async_exit_stack=async_exit_stack,
|
||||||
)
|
)
|
||||||
(
|
background_tasks = solved_result.background_tasks
|
||||||
sub_values,
|
dependency_cache.update(solved_result.dependency_cache)
|
||||||
sub_errors,
|
if solved_result.errors:
|
||||||
background_tasks,
|
errors.extend(solved_result.errors)
|
||||||
_, # the subdependency returns the same response we have
|
|
||||||
sub_dependency_cache,
|
|
||||||
) = solved_result
|
|
||||||
dependency_cache.update(sub_dependency_cache)
|
|
||||||
if sub_errors:
|
|
||||||
errors.extend(sub_errors)
|
|
||||||
continue
|
continue
|
||||||
if sub_dependant.use_cache and sub_dependant.cache_key in dependency_cache:
|
if sub_dependant.use_cache and sub_dependant.cache_key in dependency_cache:
|
||||||
solved = dependency_cache[sub_dependant.cache_key]
|
solved = dependency_cache[sub_dependant.cache_key]
|
||||||
elif is_gen_callable(call) or is_async_gen_callable(call):
|
elif is_gen_callable(call) or is_async_gen_callable(call):
|
||||||
solved = await solve_generator(
|
solved = await solve_generator(
|
||||||
call=call, stack=async_exit_stack, sub_values=sub_values
|
call=call, stack=async_exit_stack, sub_values=solved_result.values
|
||||||
)
|
)
|
||||||
elif is_coroutine_callable(call):
|
elif is_coroutine_callable(call):
|
||||||
solved = await call(**sub_values)
|
solved = await call(**solved_result.values)
|
||||||
else:
|
else:
|
||||||
solved = await run_in_threadpool(call, **sub_values)
|
solved = await run_in_threadpool(call, **solved_result.values)
|
||||||
if sub_dependant.name is not None:
|
if sub_dependant.name is not None:
|
||||||
values[sub_dependant.name] = solved
|
values[sub_dependant.name] = solved
|
||||||
if sub_dependant.cache_key not in dependency_cache:
|
if sub_dependant.cache_key not in dependency_cache:
|
||||||
|
|
@ -654,7 +651,13 @@ async def solve_dependencies(
|
||||||
values[dependant.security_scopes_param_name] = SecurityScopes(
|
values[dependant.security_scopes_param_name] = SecurityScopes(
|
||||||
scopes=dependant.security_scopes
|
scopes=dependant.security_scopes
|
||||||
)
|
)
|
||||||
return values, errors, background_tasks, response, dependency_cache
|
return SolvedDependency(
|
||||||
|
values=values,
|
||||||
|
errors=errors,
|
||||||
|
background_tasks=background_tasks,
|
||||||
|
response=response,
|
||||||
|
dependency_cache=dependency_cache,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def request_params_to_args(
|
def request_params_to_args(
|
||||||
|
|
|
||||||
|
|
@ -292,26 +292,34 @@ def get_request_handler(
|
||||||
dependency_overrides_provider=dependency_overrides_provider,
|
dependency_overrides_provider=dependency_overrides_provider,
|
||||||
async_exit_stack=async_exit_stack,
|
async_exit_stack=async_exit_stack,
|
||||||
)
|
)
|
||||||
values, errors, background_tasks, sub_response, _ = solved_result
|
errors = solved_result.errors
|
||||||
if not errors:
|
if not errors:
|
||||||
raw_response = await run_endpoint_function(
|
raw_response = await run_endpoint_function(
|
||||||
dependant=dependant, values=values, is_coroutine=is_coroutine
|
dependant=dependant,
|
||||||
|
values=solved_result.values,
|
||||||
|
is_coroutine=is_coroutine,
|
||||||
)
|
)
|
||||||
if isinstance(raw_response, Response):
|
if isinstance(raw_response, Response):
|
||||||
if raw_response.background is None:
|
if raw_response.background is None:
|
||||||
raw_response.background = background_tasks
|
raw_response.background = solved_result.background_tasks
|
||||||
response = raw_response
|
response = raw_response
|
||||||
else:
|
else:
|
||||||
response_args: Dict[str, Any] = {"background": background_tasks}
|
response_args: Dict[str, Any] = {
|
||||||
|
"background": solved_result.background_tasks
|
||||||
|
}
|
||||||
# If status_code was set, use it, otherwise use the default from the
|
# If status_code was set, use it, otherwise use the default from the
|
||||||
# response class, in the case of redirect it's 307
|
# response class, in the case of redirect it's 307
|
||||||
current_status_code = (
|
current_status_code = (
|
||||||
status_code if status_code else sub_response.status_code
|
status_code
|
||||||
|
if status_code
|
||||||
|
else solved_result.response.status_code
|
||||||
)
|
)
|
||||||
if current_status_code is not None:
|
if current_status_code is not None:
|
||||||
response_args["status_code"] = current_status_code
|
response_args["status_code"] = current_status_code
|
||||||
if sub_response.status_code:
|
if solved_result.response.status_code:
|
||||||
response_args["status_code"] = sub_response.status_code
|
response_args["status_code"] = (
|
||||||
|
solved_result.response.status_code
|
||||||
|
)
|
||||||
content = await serialize_response(
|
content = await serialize_response(
|
||||||
field=response_field,
|
field=response_field,
|
||||||
response_content=raw_response,
|
response_content=raw_response,
|
||||||
|
|
@ -326,7 +334,7 @@ def get_request_handler(
|
||||||
response = actual_response_class(content, **response_args)
|
response = actual_response_class(content, **response_args)
|
||||||
if not is_body_allowed_for_status_code(response.status_code):
|
if not is_body_allowed_for_status_code(response.status_code):
|
||||||
response.body = b""
|
response.body = b""
|
||||||
response.headers.raw.extend(sub_response.headers.raw)
|
response.headers.raw.extend(solved_result.response.headers.raw)
|
||||||
if errors:
|
if errors:
|
||||||
validation_error = RequestValidationError(
|
validation_error = RequestValidationError(
|
||||||
_normalize_errors(errors), body=body
|
_normalize_errors(errors), body=body
|
||||||
|
|
@ -360,11 +368,12 @@ def get_websocket_app(
|
||||||
dependency_overrides_provider=dependency_overrides_provider,
|
dependency_overrides_provider=dependency_overrides_provider,
|
||||||
async_exit_stack=async_exit_stack,
|
async_exit_stack=async_exit_stack,
|
||||||
)
|
)
|
||||||
values, errors, _, _2, _3 = solved_result
|
if solved_result.errors:
|
||||||
if errors:
|
raise WebSocketRequestValidationError(
|
||||||
raise WebSocketRequestValidationError(_normalize_errors(errors))
|
_normalize_errors(solved_result.errors)
|
||||||
|
)
|
||||||
assert dependant.call is not None, "dependant.call must be a function"
|
assert dependant.call is not None, "dependant.call must be a function"
|
||||||
await dependant.call(**values)
|
await dependant.call(**solved_result.values)
|
||||||
|
|
||||||
return app
|
return app
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue