This commit is contained in:
lvmin 2023-08-10 10:42:36 -07:00
parent 9038f954f2
commit 669314dff1
3 changed files with 16 additions and 5 deletions

View File

@ -178,7 +178,6 @@ def ksampler_with_refiner(model, positive, negative, refiner, refiner_positive,
noise_mask = prepare_mask(noise_mask, noise.shape, device)
comfy.model_management.load_model_gpu(model)
real_model = model.model
noise = noise.to(device)
latent_image = latent_image.to(device)
@ -188,8 +187,9 @@ def ksampler_with_refiner(model, positive, negative, refiner, refiner_positive,
models = load_additional_models(positive, negative, model.model_dtype())
sampler = KSamplerWithRefiner(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler,
denoise=denoise, model_options=model.model_options)
sampler = KSamplerWithRefiner(model=model.model, refiner_model=refiner.model, steps=steps, device=device,
sampler=sampler_name, scheduler=scheduler,
denoise=denoise, model_options=model.model_options)
samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image,
start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise,

View File

@ -27,7 +27,7 @@ def process(positive_prompt, negative_prompt, width=1024, height=1024, batch_siz
model=xl_base.unet,
positive=positive_conditions,
negative=negative_conditions,
refiner=xl_refiner,
refiner=xl_refiner.unet,
refiner_positive=positive_conditions_refiner,
refiner_negative=negative_conditions_refiner,
refiner_switch_step=20,

View File

@ -7,15 +7,26 @@ class KSamplerWithRefiner:
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "dpmpp_sde_gpu",
"dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "ddim", "uni_pc", "uni_pc_bh2"]
def __init__(self, model, steps, device, sampler=None, scheduler=None, denoise=None, model_options={}):
def __init__(self, model, refiner_model, steps, device, sampler=None, scheduler=None, denoise=None, model_options={}):
self.model = model
self.refiner_model = refiner_model
self.model_denoise = CFGNoisePredictor(self.model)
self.refiner_model_denoise = CFGNoisePredictor(self.refiner_model)
if self.model.model_type == model_base.ModelType.V_PREDICTION:
self.model_wrap = CompVisVDenoiser(self.model_denoise, quantize=True)
else:
self.model_wrap = k_diffusion_external.CompVisDenoiser(self.model_denoise, quantize=True)
if self.refiner_model.model_type == model_base.ModelType.V_PREDICTION:
self.refiner_model_wrap = CompVisVDenoiser(self.refiner_model_denoise, quantize=True)
else:
self.refiner_model_wrap = k_diffusion_external.CompVisDenoiser(self.refiner_model_denoise, quantize=True)
self.model_k = KSamplerX0Inpaint(self.model_wrap)
self.refiner_model_k = KSamplerX0Inpaint(self.refiner_model_wrap)
self.device = device
if scheduler not in self.SCHEDULERS:
scheduler = self.SCHEDULERS[0]