# Patch summary (commentary placed ahead of the first `diff --git` header).
#
# entry.py
#   * Drops the `model.to(torch.float16)` call that followed model instantiation.
#   * Comments out `model.conditioner.cuda()` and `model.conditioner.cpu()`, and
#     sets the `device` attribute of `model.conditioner.embedders[0]` and
#     `model.conditioner.embedders[1]` to 'cpu'.
#   * After `get_unconditional_conditioning`, rebuilds `c` and `uc` with every
#     value converted via `.to(torch.float16)`.
#   * Converts `model.model` and `model.denoiser` to float16 immediately before
#     moving them to CUDA for sampling.
#
# sgm/modules/diffusionmodules/util.py
#   * `GroupNorm32.forward` no longer calls `x.float()` before the parent
#     forward nor `.type(x.dtype)` after it; `x` is passed through unchanged.
#
# sgm/modules/encoders/modules.py
#   * In `GeneralConditioner`, both operands of the `torch.cat` call are moved
#     with `.cuda()` before concatenation.
#
# NOTE(review): this patch text appears to have lost its line breaks (each hunk
#   body is collapsed onto one line), so the `@@` line counts cannot be checked
#   against the hunk bodies here - restore the original newlines before applying.
# NOTE(review): the `.cuda()` calls in the `torch.cat` change are unconditional;
#   presumably this fails on a machine without CUDA. Consider moving `emb` to
#   `output[out_key].device` instead - confirm with the author.
# NOTE(review): with the float upcast removed, `GroupNorm32` runs in whatever
#   dtype `x` has; with float16 inputs this may reduce numerical precision -
#   confirm that output quality is still acceptable.
# NOTE(review): the `c` / `uc` dict comprehensions call `.to(...)` on every value;
#   this assumes all values are tensors - TODO confirm.
# NOTE(review): the two commented-out `model.conditioner` lines in entry.py could
#   be deleted rather than left as commented-out code.
diff --git a/entry.py b/entry.py index 61f8eda3..574e8f96 100644 --- a/entry.py +++ b/entry.py @@ -91,11 +91,13 @@ sampler = EulerAncestralSampler( config_path = './sd_xl_base.yaml' config = OmegaConf.load(config_path) model = instantiate_from_config(config.model).cpu() -model.to(torch.float16) model.eval() model.load_state_dict(safetensors.torch.load_file('./sd_xl_base_1.0.safetensors'), strict=False) -model.conditioner.cuda() +# model.conditioner.cuda() + +model.conditioner.embedders[0].device = 'cpu' +model.conditioner.embedders[1].device = 'cpu' value_dict = { "prompt": "a handsome man in forest", "negative_prompt": "ugly, bad", "orig_height": 1024, "orig_width": 1024, @@ -112,7 +114,10 @@ batch, batch_uc = get_batch( c, uc = model.conditioner.get_unconditional_conditioning( batch, batch_uc=batch_uc) -model.conditioner.cpu() +# model.conditioner.cpu() + +c = {a: b.to(torch.float16) for a, b in c.items()} +uc = {a: b.to(torch.float16) for a, b in uc.items()} torch.cuda.empty_cache() torch.cuda.ipc_collect() @@ -125,8 +130,8 @@ def denoiser(input, sigma, c): return model.denoiser(model.model, input, sigma, c) -model.model.cuda() -model.denoiser.cuda() +model.model.to(torch.float16).cuda() +model.denoiser.to(torch.float16).cuda() samples_z = sampler(denoiser, randn, cond=c, uc=uc) model.model.cpu() model.denoiser.cpu() diff --git a/sgm/modules/diffusionmodules/util.py b/sgm/modules/diffusionmodules/util.py index 069ff131..1acae630 100644 --- a/sgm/modules/diffusionmodules/util.py +++ b/sgm/modules/diffusionmodules/util.py @@ -272,7 +272,7 @@ class SiLU(nn.Module): class GroupNorm32(nn.GroupNorm): def forward(self, x): - return super().forward(x.float()).type(x.dtype) + return super().forward(x) def conv_nd(dims, *args, **kwargs): diff --git a/sgm/modules/encoders/modules.py b/sgm/modules/encoders/modules.py index 99fc68c8..9fce2e10 100644 --- a/sgm/modules/encoders/modules.py +++ b/sgm/modules/encoders/modules.py @@ -166,7 +166,7 @@ class 
GeneralConditioner(nn.Module): emb = torch.zeros_like(emb) if out_key in output: output[out_key] = torch.cat( - (output[out_key], emb), self.KEY2CATDIM[out_key] + (output[out_key].cuda(), emb.cuda()), self.KEY2CATDIM[out_key] ) else: output[out_key] = emb