diff --git a/entry.py b/entry.py index 471cd20d..9bb1c42a 100644 --- a/entry.py +++ b/entry.py @@ -2,6 +2,7 @@ import os import math import numpy as np import torch +import gc import safetensors.torch from omegaconf import OmegaConf @@ -91,9 +92,7 @@ config_path = './sd_xl_base.yaml' config = OmegaConf.load(config_path) model = instantiate_from_config(config.model).cpu() model.eval() - -sd = safetensors.torch.load_file('./sd_xl_base_1.0.safetensors') -model.load_state_dict(sd, strict=False) +model.load_state_dict(safetensors.torch.load_file('./sd_xl_base_1.0.safetensors'), strict=False) model.conditioner.cuda() @@ -114,4 +113,21 @@ c, uc = model.conditioner.get_unconditional_conditioning( batch_uc=batch_uc) model.conditioner.cpu() +torch.cuda.empty_cache() +torch.cuda.ipc_collect() + +shape = (1, 4, 128, 128) +randn = torch.randn(shape).cuda() + + +def denoiser(input, sigma, c): + return model.denoiser(model.model, input, sigma, c) + + +model.model.cuda() +model.denoiser.cuda() +samples_z = sampler(denoiser, randn, cond=c, uc=uc) +model.model.cpu() +model.denoiser.cpu() + a = 0