This commit is contained in:
Hasham Vakani 2026-03-04 05:44:45 +05:00 committed by GitHub
commit 24ac1597cd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 13 additions and 2 deletions

View File

@ -207,7 +207,7 @@ def patch_model(model, tasks):
def make_attn_patcher(ip_index):
def patcher(n, context_attn2, value_attn2, extra_options):
org_dtype = n.dtype
current_step = float(model.model.diffusion_model.current_step.detach().cpu().numpy()[0])
current_step = float(model.model.diffusion_model.current_step.detach().cpu().float().numpy()[0])
cond_or_uncond = extra_options['cond_or_uncond']
q = n

View File

@ -491,6 +491,14 @@ def unet_inital_load_device(parameters, dtype):
return cpu_dev
def unet_dtype(device=None, model_params=0):
# RTX 50xx Blackwell (sm_120, major>=12): use bf16 for native tensor core support
if torch.cuda.is_available():
_dev = device if device is not None else "cuda"
try:
if torch.cuda.get_device_properties(_dev).major >= 12:
return torch.bfloat16
except Exception:
pass
if args.unet_in_bf16:
return torch.bfloat16
if args.unet_in_fp16:
@ -508,6 +516,9 @@ def unet_manual_cast(weight_dtype, inference_device):
if weight_dtype == torch.float32:
return None
if weight_dtype == torch.bfloat16:
return None
fp16_supported = ldm_patched.modules.model_management.should_use_fp16(inference_device, prioritize_performance=False)
if fp16_supported and weight_dtype == torch.float16:
return None

View File

@ -375,7 +375,7 @@ def patched_cldm_forward(self, x, hint, timesteps, context, y=None, **kwargs):
def patched_unet_forward(self, x, timesteps=None, context=None, y=None, control=None, transformer_options={}, **kwargs):
self.current_step = 1.0 - timesteps.to(x) / 999.0
patch_settings[os.getpid()].global_diffusion_progress = float(self.current_step.detach().cpu().numpy().tolist()[0])
patch_settings[os.getpid()].global_diffusion_progress = float(self.current_step.detach().cpu().float().flatten()[0].item())
y = timed_adm(y, timesteps)