This commit is contained in:
lvmin 2023-08-09 14:27:40 -07:00
parent 8807b2dff8
commit 69f57c0454
3 changed files with 12 additions and 7 deletions

View File

@ -91,11 +91,13 @@ sampler = EulerAncestralSampler(
config_path = './sd_xl_base.yaml' config_path = './sd_xl_base.yaml'
config = OmegaConf.load(config_path) config = OmegaConf.load(config_path)
model = instantiate_from_config(config.model).cpu() model = instantiate_from_config(config.model).cpu()
model.to(torch.float16)
model.eval() model.eval()
model.load_state_dict(safetensors.torch.load_file('./sd_xl_base_1.0.safetensors'), strict=False) model.load_state_dict(safetensors.torch.load_file('./sd_xl_base_1.0.safetensors'), strict=False)
model.conditioner.cuda() # model.conditioner.cuda()
model.conditioner.embedders[0].device = 'cpu'
model.conditioner.embedders[1].device = 'cpu'
value_dict = { value_dict = {
"prompt": "a handsome man in forest", "negative_prompt": "ugly, bad", "orig_height": 1024, "orig_width": 1024, "prompt": "a handsome man in forest", "negative_prompt": "ugly, bad", "orig_height": 1024, "orig_width": 1024,
@ -112,7 +114,10 @@ batch, batch_uc = get_batch(
c, uc = model.conditioner.get_unconditional_conditioning( c, uc = model.conditioner.get_unconditional_conditioning(
batch, batch,
batch_uc=batch_uc) batch_uc=batch_uc)
model.conditioner.cpu() # model.conditioner.cpu()
c = {a: b.to(torch.float16) for a, b in c.items()}
uc = {a: b.to(torch.float16) for a, b in uc.items()}
torch.cuda.empty_cache() torch.cuda.empty_cache()
torch.cuda.ipc_collect() torch.cuda.ipc_collect()
@ -125,8 +130,8 @@ def denoiser(input, sigma, c):
return model.denoiser(model.model, input, sigma, c) return model.denoiser(model.model, input, sigma, c)
model.model.cuda() model.model.to(torch.float16).cuda()
model.denoiser.cuda() model.denoiser.to(torch.float16).cuda()
samples_z = sampler(denoiser, randn, cond=c, uc=uc) samples_z = sampler(denoiser, randn, cond=c, uc=uc)
model.model.cpu() model.model.cpu()
model.denoiser.cpu() model.denoiser.cpu()

View File

@ -272,7 +272,7 @@ class SiLU(nn.Module):
class GroupNorm32(nn.GroupNorm): class GroupNorm32(nn.GroupNorm):
def forward(self, x): def forward(self, x):
return super().forward(x.float()).type(x.dtype) return super().forward(x)
def conv_nd(dims, *args, **kwargs): def conv_nd(dims, *args, **kwargs):

View File

@ -166,7 +166,7 @@ class GeneralConditioner(nn.Module):
emb = torch.zeros_like(emb) emb = torch.zeros_like(emb)
if out_key in output: if out_key in output:
output[out_key] = torch.cat( output[out_key] = torch.cat(
(output[out_key], emb), self.KEY2CATDIM[out_key] (output[out_key].cuda(), emb.cuda()), self.KEY2CATDIM[out_key]
) )
else: else:
output[out_key] = emb output[out_key] = emb