diff --git a/fooocus_version.py b/fooocus_version.py
index 571dc632..a70e2736 100644
--- a/fooocus_version.py
+++ b/fooocus_version.py
@@ -1 +1 @@
-version = '2.1.735'
+version = '2.1.736'
diff --git a/modules/patch.py b/modules/patch.py
index f550ed82..f4783bbd 100644
--- a/modules/patch.py
+++ b/modules/patch.py
@@ -464,36 +464,19 @@ def text_encoder_device_patched():
     return fcbh.model_management.get_torch_device()


-def patched_autocast(device_type, dtype=None, enabled=True, cache_enabled=None):
+def patched_autocast_enter(self):
     # https://github.com/lllyasviel/Fooocus/discussions/571
     # https://github.com/lllyasviel/Fooocus/issues/620
     # https://github.com/lllyasviel/Fooocus/issues/759
-    supported = False
-
-    if device_type == 'cuda' and dtype == torch.float32 and enabled:
-        supported = True
-
-    if device_type == 'cuda' and dtype == torch.float16 and enabled:
-        supported = True
-
-    if device_type == 'cuda' and dtype == torch.bfloat16 and enabled:
-        supported = True
-
-    if not supported:
-        print(f'[Fooocus Autocast Warning] Requested unsupported torch autocast ['
-              f'device_type={str(device_type)}, '
-              f'dtype={str(dtype)}, '
-              f'enabled={str(enabled)}, '
-              f'cache_enabled={str(cache_enabled)}]. '
+    try:
+        result = self.enter_origin()
+    except Exception as e:
+        result = self
+        print(f'[Fooocus Autocast Warning] {str(e)}. \n'
               f'Fooocus fixed it automatically, feel free to report to Fooocus on GitHub if this may cause potential problems.')
-        return contextlib.nullcontext()
-    return torch.amp.autocast_mode.autocast_origin(
-        device_type=device_type,
-        dtype=dtype,
-        enabled=enabled,
-        cache_enabled=cache_enabled)
+    return result


 def patched_load_models_gpu(*args, **kwargs):
@@ -556,14 +539,12 @@ def patch_all():
     if not hasattr(fcbh.model_management, 'load_models_gpu_origin'):
         fcbh.model_management.load_models_gpu_origin = fcbh.model_management.load_models_gpu

-    if not hasattr(torch.amp.autocast_mode, 'autocast_origin'):
-        torch.amp.autocast_mode.autocast_origin = torch.amp.autocast_mode.autocast
+    if not hasattr(torch.amp.autocast_mode.autocast, 'enter_origin'):
+        torch.amp.autocast_mode.autocast.enter_origin = torch.amp.autocast_mode.autocast.__enter__

-    torch.amp.autocast_mode.autocast = patched_autocast
-    torch.amp.autocast = patched_autocast
-    torch.autocast = patched_autocast
+    torch.amp.autocast_mode.autocast.__enter__ = patched_autocast_enter

-    # # Test if this will fail
+    # # Test if this would fail
     # with torch.autocast(device_type='cpu', dtype=torch.float32):
     #     print(torch.ones(10))
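
For orientation, a standalone sketch of the monkeypatch pattern the modules/patch.py hunks install follows: autocast.__enter__ is stashed once under a spare attribute, then replaced by a wrapper that delegates to the original and falls back to a warning plus a no-op entry when torch rejects the requested autocast configuration. This sketch is a reconstruction outside the Fooocus codebase, not part of the patch; the warning text and the sample with-block are illustrative assumptions, and whether that sample errors without the patch depends on the torch build.

import torch

# Keep the original __enter__ exactly once, mirroring the hasattr guard in patch_all().
if not hasattr(torch.amp.autocast_mode.autocast, 'enter_origin'):
    torch.amp.autocast_mode.autocast.enter_origin = torch.amp.autocast_mode.autocast.__enter__


def patched_autocast_enter(self):
    try:
        # Delegate to the real implementation whenever torch accepts the configuration.
        return self.enter_origin()
    except Exception as e:
        # Otherwise enter as a no-op: the with-block still runs, just without
        # autocasting, and the original error is downgraded to a console warning.
        print(f'[Autocast Warning] {e}')
        return self


torch.amp.autocast_mode.autocast.__enter__ = patched_autocast_enter

# On torch builds that reject this configuration, the block now only warns instead of raising.
with torch.autocast(device_type='cpu', dtype=torch.float32):
    print(torch.ones(3))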