diff --git a/modules/core.py b/modules/core.py index c240674e..f4db8ed0 100644 --- a/modules/core.py +++ b/modules/core.py @@ -142,6 +142,70 @@ def ksampler(model, positive, negative, latent, seed=None, steps=30, cfg=9.0, sa return out +@torch.no_grad() +def ksampler_with_refiner(model, positive, negative, refiner, refiner_positive, refiner_negative, latent, + seed=None, steps=30, refiner_switch_step=20, cfg=9.0, sampler_name='dpmpp_2m_sde', + scheduler='karras', denoise=1.0, disable_noise=False, start_step=None, last_step=None, + force_full_denoise=False): + seed = seed if isinstance(seed, int) else random.randint(1, 2 ** 64) + + device = comfy.model_management.get_torch_device() + latent_image = latent["samples"] + + if disable_noise: + noise = torch.zeros(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, device="cpu") + else: + batch_inds = latent["batch_index"] if "batch_index" in latent else None + noise = comfy.sample.prepare_noise(latent_image, seed, batch_inds) + + noise_mask = None + if "noise_mask" in latent: + noise_mask = latent["noise_mask"] + + previewer = get_previewer(device, model.model.latent_format) + + pbar = comfy.utils.ProgressBar(steps) + + def callback(step, x0, x, total_steps): + if previewer and step % 3 == 0: + previewer.preview(x0, step, total_steps) + pbar.update_absolute(step + 1, total_steps, None) + + sigmas = None + disable_pbar = False + + if noise_mask is not None: + noise_mask = prepare_mask(noise_mask, noise.shape, device) + + comfy.model_management.load_model_gpu(model) + real_model = model.model + + noise = noise.to(device) + latent_image = latent_image.to(device) + + positive_copy = broadcast_cond(positive, noise.shape[0], device) + negative_copy = broadcast_cond(negative, noise.shape[0], device) + + models = load_additional_models(positive, negative, model.model_dtype()) + + sampler = KSamplerWithRefiner(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, + denoise=denoise, model_options=model.model_options) + + samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image, + start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, + denoise_mask=noise_mask, sigmas=sigmas, callback=callback, disable_pbar=disable_pbar, + seed=seed) + + samples = samples.cpu() + + cleanup_additional_models(models) + + out = latent.copy() + out["samples"] = samples + + return out + + @torch.no_grad() def image_to_numpy(x): return [np.clip(255. * y.cpu().numpy(), 0, 255).astype(np.uint8) for y in x] diff --git a/modules/default_pipeline.py b/modules/default_pipeline.py index 3bcc01ee..3aba6651 100644 --- a/modules/default_pipeline.py +++ b/modules/default_pipeline.py @@ -23,20 +23,16 @@ def process(positive_prompt, negative_prompt, width=1024, height=1024, batch_siz empty_latent = core.generate_empty_latent(width=width, height=height, batch_size=batch_size) - sampled_latent = core.ksampler( + sampled_latent = core.ksampler_with_refiner( model=xl_base.unet, positive=positive_conditions, negative=negative_conditions, + refiner=xl_refiner, + refiner_positive=positive_conditions_refiner, + refiner_negative=negative_conditions_refiner, + refiner_switch_step=20, latent=empty_latent, - steps=30, start_step=0, last_step=20, disable_noise=False, force_full_denoise=False - ) - - sampled_latent = core.ksampler( - model=xl_refiner.unet, - positive=positive_conditions_refiner, - negative=negative_conditions_refiner, - latent=sampled_latent, - steps=30, start_step=20, last_step=30, disable_noise=True, force_full_denoise=True + steps=30, start_step=0, last_step=30, disable_noise=False, force_full_denoise=True ) decoded_latent = core.decode_vae(vae=xl_refiner.vae, latent_image=sampled_latent)