mirror of https://github.com/tiangolo/fastapi.git
♻️ Refactor and simplify internal data from `solve_dependencies()` using dataclasses (#12100)
This commit is contained in:
parent
8d7d89e8c6
commit
5b7fa3900e
|
|
@ -529,6 +529,15 @@ async def solve_generator(
|
|||
return await stack.enter_async_context(cm)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SolvedDependency:
|
||||
values: Dict[str, Any]
|
||||
errors: List[Any]
|
||||
background_tasks: Optional[StarletteBackgroundTasks]
|
||||
response: Response
|
||||
dependency_cache: Dict[Tuple[Callable[..., Any], Tuple[str]], Any]
|
||||
|
||||
|
||||
async def solve_dependencies(
|
||||
*,
|
||||
request: Union[Request, WebSocket],
|
||||
|
|
@ -539,13 +548,7 @@ async def solve_dependencies(
|
|||
dependency_overrides_provider: Optional[Any] = None,
|
||||
dependency_cache: Optional[Dict[Tuple[Callable[..., Any], Tuple[str]], Any]] = None,
|
||||
async_exit_stack: AsyncExitStack,
|
||||
) -> Tuple[
|
||||
Dict[str, Any],
|
||||
List[Any],
|
||||
Optional[StarletteBackgroundTasks],
|
||||
Response,
|
||||
Dict[Tuple[Callable[..., Any], Tuple[str]], Any],
|
||||
]:
|
||||
) -> SolvedDependency:
|
||||
values: Dict[str, Any] = {}
|
||||
errors: List[Any] = []
|
||||
if response is None:
|
||||
|
|
@ -587,27 +590,21 @@ async def solve_dependencies(
|
|||
dependency_cache=dependency_cache,
|
||||
async_exit_stack=async_exit_stack,
|
||||
)
|
||||
(
|
||||
sub_values,
|
||||
sub_errors,
|
||||
background_tasks,
|
||||
_, # the subdependency returns the same response we have
|
||||
sub_dependency_cache,
|
||||
) = solved_result
|
||||
dependency_cache.update(sub_dependency_cache)
|
||||
if sub_errors:
|
||||
errors.extend(sub_errors)
|
||||
background_tasks = solved_result.background_tasks
|
||||
dependency_cache.update(solved_result.dependency_cache)
|
||||
if solved_result.errors:
|
||||
errors.extend(solved_result.errors)
|
||||
continue
|
||||
if sub_dependant.use_cache and sub_dependant.cache_key in dependency_cache:
|
||||
solved = dependency_cache[sub_dependant.cache_key]
|
||||
elif is_gen_callable(call) or is_async_gen_callable(call):
|
||||
solved = await solve_generator(
|
||||
call=call, stack=async_exit_stack, sub_values=sub_values
|
||||
call=call, stack=async_exit_stack, sub_values=solved_result.values
|
||||
)
|
||||
elif is_coroutine_callable(call):
|
||||
solved = await call(**sub_values)
|
||||
solved = await call(**solved_result.values)
|
||||
else:
|
||||
solved = await run_in_threadpool(call, **sub_values)
|
||||
solved = await run_in_threadpool(call, **solved_result.values)
|
||||
if sub_dependant.name is not None:
|
||||
values[sub_dependant.name] = solved
|
||||
if sub_dependant.cache_key not in dependency_cache:
|
||||
|
|
@ -654,7 +651,13 @@ async def solve_dependencies(
|
|||
values[dependant.security_scopes_param_name] = SecurityScopes(
|
||||
scopes=dependant.security_scopes
|
||||
)
|
||||
return values, errors, background_tasks, response, dependency_cache
|
||||
return SolvedDependency(
|
||||
values=values,
|
||||
errors=errors,
|
||||
background_tasks=background_tasks,
|
||||
response=response,
|
||||
dependency_cache=dependency_cache,
|
||||
)
|
||||
|
||||
|
||||
def request_params_to_args(
|
||||
|
|
|
|||
|
|
@ -292,26 +292,34 @@ def get_request_handler(
|
|||
dependency_overrides_provider=dependency_overrides_provider,
|
||||
async_exit_stack=async_exit_stack,
|
||||
)
|
||||
values, errors, background_tasks, sub_response, _ = solved_result
|
||||
errors = solved_result.errors
|
||||
if not errors:
|
||||
raw_response = await run_endpoint_function(
|
||||
dependant=dependant, values=values, is_coroutine=is_coroutine
|
||||
dependant=dependant,
|
||||
values=solved_result.values,
|
||||
is_coroutine=is_coroutine,
|
||||
)
|
||||
if isinstance(raw_response, Response):
|
||||
if raw_response.background is None:
|
||||
raw_response.background = background_tasks
|
||||
raw_response.background = solved_result.background_tasks
|
||||
response = raw_response
|
||||
else:
|
||||
response_args: Dict[str, Any] = {"background": background_tasks}
|
||||
response_args: Dict[str, Any] = {
|
||||
"background": solved_result.background_tasks
|
||||
}
|
||||
# If status_code was set, use it, otherwise use the default from the
|
||||
# response class, in the case of redirect it's 307
|
||||
current_status_code = (
|
||||
status_code if status_code else sub_response.status_code
|
||||
status_code
|
||||
if status_code
|
||||
else solved_result.response.status_code
|
||||
)
|
||||
if current_status_code is not None:
|
||||
response_args["status_code"] = current_status_code
|
||||
if sub_response.status_code:
|
||||
response_args["status_code"] = sub_response.status_code
|
||||
if solved_result.response.status_code:
|
||||
response_args["status_code"] = (
|
||||
solved_result.response.status_code
|
||||
)
|
||||
content = await serialize_response(
|
||||
field=response_field,
|
||||
response_content=raw_response,
|
||||
|
|
@ -326,7 +334,7 @@ def get_request_handler(
|
|||
response = actual_response_class(content, **response_args)
|
||||
if not is_body_allowed_for_status_code(response.status_code):
|
||||
response.body = b""
|
||||
response.headers.raw.extend(sub_response.headers.raw)
|
||||
response.headers.raw.extend(solved_result.response.headers.raw)
|
||||
if errors:
|
||||
validation_error = RequestValidationError(
|
||||
_normalize_errors(errors), body=body
|
||||
|
|
@ -360,11 +368,12 @@ def get_websocket_app(
|
|||
dependency_overrides_provider=dependency_overrides_provider,
|
||||
async_exit_stack=async_exit_stack,
|
||||
)
|
||||
values, errors, _, _2, _3 = solved_result
|
||||
if errors:
|
||||
raise WebSocketRequestValidationError(_normalize_errors(errors))
|
||||
if solved_result.errors:
|
||||
raise WebSocketRequestValidationError(
|
||||
_normalize_errors(solved_result.errors)
|
||||
)
|
||||
assert dependant.call is not None, "dependant.call must be a function"
|
||||
await dependant.call(**values)
|
||||
await dependant.call(**solved_result.values)
|
||||
|
||||
return app
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue