Merge 15547a8720 into ae05379cc9
This commit is contained in:
commit
24ac1597cd
|
|
@ -207,7 +207,7 @@ def patch_model(model, tasks):
|
|||
def make_attn_patcher(ip_index):
|
||||
def patcher(n, context_attn2, value_attn2, extra_options):
|
||||
org_dtype = n.dtype
|
||||
current_step = float(model.model.diffusion_model.current_step.detach().cpu().numpy()[0])
|
||||
current_step = float(model.model.diffusion_model.current_step.detach().cpu().float().numpy()[0])
|
||||
cond_or_uncond = extra_options['cond_or_uncond']
|
||||
|
||||
q = n
|
||||
|
|
|
|||
|
|
@ -491,6 +491,14 @@ def unet_inital_load_device(parameters, dtype):
|
|||
return cpu_dev
|
||||
|
||||
def unet_dtype(device=None, model_params=0):
|
||||
# RTX 50xx Blackwell (sm_120, major>=12): use bf16 for native tensor core support
|
||||
if torch.cuda.is_available():
|
||||
_dev = device if device is not None else "cuda"
|
||||
try:
|
||||
if torch.cuda.get_device_properties(_dev).major >= 12:
|
||||
return torch.bfloat16
|
||||
except Exception:
|
||||
pass
|
||||
if args.unet_in_bf16:
|
||||
return torch.bfloat16
|
||||
if args.unet_in_fp16:
|
||||
|
|
@ -508,6 +516,9 @@ def unet_manual_cast(weight_dtype, inference_device):
|
|||
if weight_dtype == torch.float32:
|
||||
return None
|
||||
|
||||
if weight_dtype == torch.bfloat16:
|
||||
return None
|
||||
|
||||
fp16_supported = ldm_patched.modules.model_management.should_use_fp16(inference_device, prioritize_performance=False)
|
||||
if fp16_supported and weight_dtype == torch.float16:
|
||||
return None
|
||||
|
|
|
|||
|
|
@ -375,7 +375,7 @@ def patched_cldm_forward(self, x, hint, timesteps, context, y=None, **kwargs):
|
|||
|
||||
def patched_unet_forward(self, x, timesteps=None, context=None, y=None, control=None, transformer_options={}, **kwargs):
|
||||
self.current_step = 1.0 - timesteps.to(x) / 999.0
|
||||
patch_settings[os.getpid()].global_diffusion_progress = float(self.current_step.detach().cpu().numpy().tolist()[0])
|
||||
patch_settings[os.getpid()].global_diffusion_progress = float(self.current_step.detach().cpu().float().flatten()[0].item())
|
||||
|
||||
y = timed_adm(y, timesteps)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue