This commit is contained in:
lvmin 2023-08-09 13:54:43 -07:00
parent 6070c4a1e5
commit ebfe47d02d
1 changed files with 19 additions and 3 deletions

View File

@ -2,6 +2,7 @@ import os
import math
import numpy as np
import torch
import gc
import safetensors.torch
from omegaconf import OmegaConf
@ -91,9 +92,7 @@ config_path = './sd_xl_base.yaml'
config = OmegaConf.load(config_path)
model = instantiate_from_config(config.model).cpu()
model.eval()
sd = safetensors.torch.load_file('./sd_xl_base_1.0.safetensors')
model.load_state_dict(sd, strict=False)
model.load_state_dict(safetensors.torch.load_file('./sd_xl_base_1.0.safetensors'), strict=False)
model.conditioner.cuda()
@ -114,4 +113,21 @@ c, uc = model.conditioner.get_unconditional_conditioning(
batch_uc=batch_uc)
model.conditioner.cpu()
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
shape = (1, 4, 128, 128)
randn = torch.randn(shape).cuda()
def denoiser(input, sigma, c):
return model.denoiser(model.model, input, sigma, c)
model.model.cuda()
model.denoiser.cuda()
samples_z = sampler(denoiser, randn, cond=c, uc=uc)
model.model.cpu()
model.denoiser.cpu()
a = 0