From 47281e52c42dcf1e3b30c58201ba199e7f523bde Mon Sep 17 00:00:00 2001
From: lllyasviel
Date: Mon, 23 Oct 2023 13:07:55 -0700
Subject: [PATCH] Fixed many autocast problems.

---
 fooocus_version.py |  2 +-
 modules/patch.py   | 49 ++++++++++++++++++++++++++++++++++++++++++++---------
 update_log.md      |  4 ++++
 3 files changed, 45 insertions(+), 10 deletions(-)

diff --git a/fooocus_version.py b/fooocus_version.py
index e634d37e..571dc632 100644
--- a/fooocus_version.py
+++ b/fooocus_version.py
@@ -1 +1 @@
-version = '2.1.734'
+version = '2.1.735'
diff --git a/modules/patch.py b/modules/patch.py
index 77958aed..f550ed82 100644
--- a/modules/patch.py
+++ b/modules/patch.py
@@ -1,3 +1,4 @@
+import contextlib
 import os
 import torch
 import time
@@ -463,16 +464,36 @@ def text_encoder_device_patched():
     return fcbh.model_management.get_torch_device()
 
 
-def patched_get_autocast_device(dev):
+def patched_autocast(device_type, dtype=None, enabled=True, cache_enabled=None):
     # https://github.com/lllyasviel/Fooocus/discussions/571
     # https://github.com/lllyasviel/Fooocus/issues/620
-    result = ''
-    if hasattr(dev, 'type'):
-        result = str(dev.type)
-    if 'cuda' in result:
-        return 'cuda'
-    else:
-        return 'cpu'
+    # https://github.com/lllyasviel/Fooocus/issues/759
+
+    supported = False
+
+    if device_type == 'cuda' and dtype == torch.float32 and enabled:
+        supported = True
+
+    if device_type == 'cuda' and dtype == torch.float16 and enabled:
+        supported = True
+
+    if device_type == 'cuda' and dtype == torch.bfloat16 and enabled:
+        supported = True
+
+    if not supported:
+        print(f'[Fooocus Autocast Warning] Requested unsupported torch autocast ['
+              f'device_type={str(device_type)}, '
+              f'dtype={str(dtype)}, '
+              f'enabled={str(enabled)}, '
+              f'cache_enabled={str(cache_enabled)}]. '
+              f'Fooocus fixed it automatically, feel free to report to Fooocus on GitHub if this may cause potential problems.')
+        return contextlib.nullcontext()
+
+    return torch.amp.autocast_mode.autocast_origin(
+        device_type=device_type,
+        dtype=dtype,
+        enabled=enabled,
+        cache_enabled=cache_enabled)
 
 
 def patched_load_models_gpu(*args, **kwargs):
@@ -535,8 +556,18 @@ def patch_all():
     if not hasattr(fcbh.model_management, 'load_models_gpu_origin'):
         fcbh.model_management.load_models_gpu_origin = fcbh.model_management.load_models_gpu
 
+    if not hasattr(torch.amp.autocast_mode, 'autocast_origin'):
+        torch.amp.autocast_mode.autocast_origin = torch.amp.autocast_mode.autocast
+
+    torch.amp.autocast_mode.autocast = patched_autocast
+    torch.amp.autocast = patched_autocast
+    torch.autocast = patched_autocast
+
+    # # Test if this will fail
+    # with torch.autocast(device_type='cpu', dtype=torch.float32):
+    #     print(torch.ones(10))
+
     fcbh.model_management.load_models_gpu = patched_load_models_gpu
-    fcbh.model_management.get_autocast_device = patched_get_autocast_device
     fcbh.model_management.text_encoder_device = text_encoder_device_patched
     fcbh.model_patcher.ModelPatcher.calculate_weight = calculate_weight_patched
     fcbh.cldm.cldm.ControlNet.forward = patched_cldm_forward
diff --git a/update_log.md b/update_log.md
index 9d33ba27..e913c12e 100644
--- a/update_log.md
+++ b/update_log.md
@@ -1,3 +1,7 @@
+# 2.1.735
+
+* Fixed many problems related to torch autocast.
+
 # 2.1.733
 
 * Increased allowed random seed range.
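
For reference, a minimal standalone sketch of the behavior this patch installs; the name
`tolerant_autocast` and the trailing usage example are illustrative only, not part of the
patch. The idea: autocast requests for combinations the patch treats as supported (CUDA
with fp32/fp16/bf16) pass through to the real torch.autocast, while anything else is
downgraded to a harmless no-op context instead of raising.

    import contextlib
    import torch

    autocast_origin = torch.autocast  # keep a handle on the real implementation

    def tolerant_autocast(device_type, dtype=None, enabled=True, cache_enabled=None):
        # Pass through only CUDA autocast with fp32/fp16/bf16, mirroring the
        # supported combinations checked in patched_autocast above.
        if device_type == 'cuda' and enabled and \
                dtype in (torch.float32, torch.float16, torch.bfloat16):
            return autocast_origin(device_type=device_type, dtype=dtype,
                                   enabled=enabled, cache_enabled=cache_enabled)
        # Everything else becomes a no-op context manager.
        return contextlib.nullcontext()

    torch.autocast = tolerant_autocast

    # On torch builds where CPU autocast rejects float32 (the case in the
    # commented-out test inside patch_all), this now runs without error:
    with torch.autocast(device_type='cpu', dtype=torch.float32):
        print(torch.ones(10))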