diff --git a/extras/ip_adapter.py b/extras/ip_adapter.py index d29f1de2..c08d79de 100644 --- a/extras/ip_adapter.py +++ b/extras/ip_adapter.py @@ -207,7 +207,7 @@ def patch_model(model, tasks): def make_attn_patcher(ip_index): def patcher(n, context_attn2, value_attn2, extra_options): org_dtype = n.dtype - current_step = float(model.model.diffusion_model.current_step.detach().cpu().numpy()[0]) + current_step = float(model.model.diffusion_model.current_step.detach().cpu().float().numpy()[0]) cond_or_uncond = extra_options['cond_or_uncond'] q = n diff --git a/ldm_patched/modules/model_management.py b/ldm_patched/modules/model_management.py index 840d79a0..46f39f64 100644 --- a/ldm_patched/modules/model_management.py +++ b/ldm_patched/modules/model_management.py @@ -491,6 +491,14 @@ def unet_inital_load_device(parameters, dtype): return cpu_dev def unet_dtype(device=None, model_params=0): + # RTX 50xx Blackwell (sm_120, major>=12): use bf16 for native tensor core support + if torch.cuda.is_available(): + _dev = device if device is not None else "cuda" + try: + if torch.cuda.get_device_properties(_dev).major >= 12: + return torch.bfloat16 + except Exception: + pass if args.unet_in_bf16: return torch.bfloat16 if args.unet_in_fp16: @@ -508,6 +516,9 @@ def unet_manual_cast(weight_dtype, inference_device): if weight_dtype == torch.float32: return None + if weight_dtype == torch.bfloat16: + return None + fp16_supported = ldm_patched.modules.model_management.should_use_fp16(inference_device, prioritize_performance=False) if fp16_supported and weight_dtype == torch.float16: return None diff --git a/modules/patch.py b/modules/patch.py index 3c2dd8f4..991c9c97 100644 --- a/modules/patch.py +++ b/modules/patch.py @@ -375,7 +375,7 @@ def patched_cldm_forward(self, x, hint, timesteps, context, y=None, **kwargs): def patched_unet_forward(self, x, timesteps=None, context=None, y=None, control=None, transformer_options={}, **kwargs): self.current_step = 1.0 - timesteps.to(x) / 999.0 - patch_settings[os.getpid()].global_diffusion_progress = float(self.current_step.detach().cpu().numpy().tolist()[0]) + patch_settings[os.getpid()].global_diffusion_progress = float(self.current_step.detach().cpu().float().flatten()[0].item()) y = timed_adm(y, timesteps)