fix: Add NVIDIA Blackwell (RTX 50xx, sm_120) support

- Use bfloat16 dtype for UNet on Blackwell GPUs (compute major >= 12)
  which have native bf16 tensor core support
- Skip manual_cast for bfloat16 weights to avoid unnecessary casting
- Fix numpy TypeError with bfloat16 tensors in patch.py and
  ip_adapter.py by converting to float32 before .numpy() calls

Tested on RTX 5070 (sm_120, CUDA 12.8) with PyTorch nightly (cu128).
Generates images at ~3.2 it/s including Image Prompt (IP-Adapter) mode.

Fixes #3862, #4123, #4141
This commit is contained in:
Hasham Vakani ⚡ 2026-03-04 05:43:44 +05:00
parent ae05379cc9
commit 15547a8720
3 changed files with 13 additions and 2 deletions

View File

@ -207,7 +207,7 @@ def patch_model(model, tasks):
def make_attn_patcher(ip_index):
def patcher(n, context_attn2, value_attn2, extra_options):
org_dtype = n.dtype
current_step = float(model.model.diffusion_model.current_step.detach().cpu().numpy()[0])
current_step = float(model.model.diffusion_model.current_step.detach().cpu().float().numpy()[0])
cond_or_uncond = extra_options['cond_or_uncond']
q = n

View File

@ -491,6 +491,14 @@ def unet_inital_load_device(parameters, dtype):
return cpu_dev
def unet_dtype(device=None, model_params=0):
# RTX 50xx Blackwell (sm_120, major>=12): use bf16 for native tensor core support
if torch.cuda.is_available():
_dev = device if device is not None else "cuda"
try:
if torch.cuda.get_device_properties(_dev).major >= 12:
return torch.bfloat16
except Exception:
pass
if args.unet_in_bf16:
return torch.bfloat16
if args.unet_in_fp16:
@ -508,6 +516,9 @@ def unet_manual_cast(weight_dtype, inference_device):
if weight_dtype == torch.float32:
return None
if weight_dtype == torch.bfloat16:
return None
fp16_supported = ldm_patched.modules.model_management.should_use_fp16(inference_device, prioritize_performance=False)
if fp16_supported and weight_dtype == torch.float16:
return None

View File

@ -375,7 +375,7 @@ def patched_cldm_forward(self, x, hint, timesteps, context, y=None, **kwargs):
def patched_unet_forward(self, x, timesteps=None, context=None, y=None, control=None, transformer_options={}, **kwargs):
self.current_step = 1.0 - timesteps.to(x) / 999.0
patch_settings[os.getpid()].global_diffusion_progress = float(self.current_step.detach().cpu().numpy().tolist()[0])
patch_settings[os.getpid()].global_diffusion_progress = float(self.current_step.detach().cpu().float().flatten()[0].item())
y = timed_adm(y, timesteps)