This commit is contained in:
lvmin 2023-08-09 14:27:40 -07:00
parent 8807b2dff8
commit 69f57c0454
3 changed files with 12 additions and 7 deletions

View File

@ -91,11 +91,13 @@ sampler = EulerAncestralSampler(
config_path = './sd_xl_base.yaml'
config = OmegaConf.load(config_path)
model = instantiate_from_config(config.model).cpu()
model.to(torch.float16)
model.eval()
model.load_state_dict(safetensors.torch.load_file('./sd_xl_base_1.0.safetensors'), strict=False)
model.conditioner.cuda()
# model.conditioner.cuda()
model.conditioner.embedders[0].device = 'cpu'
model.conditioner.embedders[1].device = 'cpu'
value_dict = {
"prompt": "a handsome man in forest", "negative_prompt": "ugly, bad", "orig_height": 1024, "orig_width": 1024,
@ -112,7 +114,10 @@ batch, batch_uc = get_batch(
c, uc = model.conditioner.get_unconditional_conditioning(
batch,
batch_uc=batch_uc)
model.conditioner.cpu()
# model.conditioner.cpu()
c = {a: b.to(torch.float16) for a, b in c.items()}
uc = {a: b.to(torch.float16) for a, b in uc.items()}
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
@ -125,8 +130,8 @@ def denoiser(input, sigma, c):
return model.denoiser(model.model, input, sigma, c)
model.model.cuda()
model.denoiser.cuda()
model.model.to(torch.float16).cuda()
model.denoiser.to(torch.float16).cuda()
samples_z = sampler(denoiser, randn, cond=c, uc=uc)
model.model.cpu()
model.denoiser.cpu()

View File

@ -272,7 +272,7 @@ class SiLU(nn.Module):
class GroupNorm32(nn.GroupNorm):
def forward(self, x):
return super().forward(x.float()).type(x.dtype)
return super().forward(x)
def conv_nd(dims, *args, **kwargs):

View File

@ -166,7 +166,7 @@ class GeneralConditioner(nn.Module):
emb = torch.zeros_like(emb)
if out_key in output:
output[out_key] = torch.cat(
(output[out_key], emb), self.KEY2CATDIM[out_key]
(output[out_key].cuda(), emb.cuda()), self.KEY2CATDIM[out_key]
)
else:
output[out_key] = emb