diff --git a/.gitignore b/.gitignore
index eeaeb1ed..fb1bffdd 100644
--- a/.gitignore
+++ b/.gitignore
@@ -6,6 +6,7 @@ __pycache__
 lena.png
 lena_result.png
 lena_test.py
+/modules/*.png
 /repositories
 /venv
 /tmp
diff --git a/fooocus_version.py b/fooocus_version.py
index 1d8c93b8..40dfeb3b 100644
--- a/fooocus_version.py
+++ b/fooocus_version.py
@@ -1 +1 @@
-version = '2.0.54'
+version = '2.0.60'
diff --git a/modules/async_worker.py b/modules/async_worker.py
index a6a42b1a..ec001e92 100644
--- a/modules/async_worker.py
+++ b/modules/async_worker.py
@@ -1,4 +1,6 @@
 import threading
+
+import numpy as np
 import torch
 
 buffer = []
@@ -19,6 +21,7 @@ def worker():
     import modules.patch
     import modules.virtual_memory as virtual_memory
     import comfy.model_management
+    import modules.inpaint_worker as inpaint_worker
 
     from modules.sdxl_styles import apply_style, aspect_ratios, fooocus_expansion
     from modules.private_logger import log
@@ -46,8 +49,10 @@ def worker():
             aspect_ratios_selction, image_number, image_seed, sharpness, \
             base_model_name, refiner_model_name, \
             l1, w1, l2, w2, l3, w3, l4, w4, l5, w5, \
-            input_image_checkbox, \
-            uov_method, uov_input_image = task
+            input_image_checkbox, current_tab, \
+            uov_method, uov_input_image, outpaint_selections, inpaint_input_image = task
+
+        outpaint_selections = [o.lower() for o in outpaint_selections]
 
         loras = [(l1, w1), (l2, w2), (l3, w3), (l4, w4), (l5, w5)]
 
@@ -63,9 +68,11 @@ def worker():
         use_style = len(style_selections) > 0
 
         modules.patch.sharpness = sharpness
+        modules.patch.negative_adm = True
         initial_latent = None
         denoising_strength = 1.0
         tiled = False
+        inpaint_worker.current_task = None
 
         if performance_selction == 'Speed':
             steps = 30
@@ -80,7 +87,7 @@ def worker():
 
         if input_image_checkbox:
             progressbar(0, 'Image processing ...')
-            if uov_method != flags.disabled and uov_input_image is not None:
+            if current_tab == 'uov' and uov_method != flags.disabled and uov_input_image is not None:
                 uov_input_image = HWC3(uov_input_image)
                 if 'vary' in uov_method:
                     if not image_is_generated_in_current_ui(uov_input_image, ui_width=width, ui_height=height):
@@ -156,6 +163,49 @@ def worker():
                     width = W * 8
                     height = H * 8
                     print(f'Final resolution is {str((height, width))}.')
+            if current_tab == 'inpaint' and isinstance(inpaint_input_image, dict):
+                inpaint_image = inpaint_input_image['image']
+                inpaint_mask = inpaint_input_image['mask'][:, :, 0]
+                if isinstance(inpaint_image, np.ndarray) and isinstance(inpaint_mask, np.ndarray) \
+                        and (np.any(inpaint_mask > 127) or len(outpaint_selections) > 0):
+                    if len(outpaint_selections) > 0:
+                        H, W, C = inpaint_image.shape
+                        if 'top' in outpaint_selections:
+                            inpaint_image = np.pad(inpaint_image, [[int(H * 0.3), 0], [0, 0], [0, 0]], mode='edge')
+                            inpaint_mask = np.pad(inpaint_mask, [[int(H * 0.3), 0], [0, 0]], mode='constant', constant_values=255)
+                        if 'bottom' in outpaint_selections:
+                            inpaint_image = np.pad(inpaint_image, [[0, int(H * 0.3)], [0, 0], [0, 0]], mode='edge')
+                            inpaint_mask = np.pad(inpaint_mask, [[0, int(H * 0.3)], [0, 0]], mode='constant', constant_values=255)
+
+                        H, W, C = inpaint_image.shape
+                        if 'left' in outpaint_selections:
+                            inpaint_image = np.pad(inpaint_image, [[0, 0], [int(W * 0.3), 0], [0, 0]], mode='edge')
+                            inpaint_mask = np.pad(inpaint_mask, [[0, 0], [int(W * 0.3), 0]], mode='constant', constant_values=255)
+                        if 'right' in outpaint_selections:
+                            inpaint_image = np.pad(inpaint_image, [[0, 0], [0, int(W * 0.3)], [0, 0]], mode='edge')
+                            inpaint_mask = np.pad(inpaint_mask, [[0, 0], [0, int(W * 0.3)]], mode='constant', constant_values=255)
+
+                    inpaint_image = np.ascontiguousarray(inpaint_image.copy())
+                    inpaint_mask = np.ascontiguousarray(inpaint_mask.copy())
+
+                    inpaint_worker.current_task = inpaint_worker.InpaintWorker(image=inpaint_image, mask=inpaint_mask,
+                                                                               is_outpaint=len(outpaint_selections) > 0)
+
+                    # print(f'Inpaint task: {str((height, width))}')
+                    # outputs.append(['results', inpaint_worker.current_task.visualize_mask_processing()])
+                    # return
+
+                    inpaint_pixels = core.numpy_to_pytorch(inpaint_worker.current_task.image_ready)
+                    progressbar(0, 'VAE encoding ...')
+                    initial_latent = core.encode_vae(vae=pipeline.xl_base_patched.vae, pixels=inpaint_pixels)
+                    inpaint_latent = initial_latent['samples']
+                    B, C, H, W = inpaint_latent.shape
+                    inpaint_mask = core.numpy_to_pytorch(inpaint_worker.current_task.mask_ready[None])
+                    inpaint_mask = torch.nn.functional.avg_pool2d(inpaint_mask, (8, 8))
+                    inpaint_mask = torch.nn.functional.interpolate(inpaint_mask, (H, W), mode='bilinear')
+                    width = W * 8
+                    height = H * 8
+                    inpaint_worker.current_task.load_latent(latent=inpaint_latent, mask=inpaint_mask)
 
         progressbar(1, 'Initializing ...')
@@ -262,6 +312,8 @@ def worker():
                                    f'Step {step}/{total_steps} in the {current_task_id + 1}-th Sampling',
                                    y)])
 
+        print(f'[ADM] Negative ADM = {modules.patch.negative_adm}')
+
         outputs.append(['preview', (13, 'Starting tasks ...', None)])
         for current_task_id, task in enumerate(tasks):
             try:
@@ -279,6 +331,9 @@ def worker():
                     tiled=tiled
                 )
 
+                if inpaint_worker.current_task is not None:
+                    imgs = [inpaint_worker.current_task.post_process(x) for x in imgs]
+
                 for x in imgs:
                     d = [
                         ('Prompt', raw_prompt),
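A note on the inpaint branch added above: the pixel-space mask is average-pooled by the VAE's 8x factor and then resized so it can gate latent tensors directly. A standalone sketch of that shape contract (sizes and tensor values here are illustrative, not from the patch):

```python
import torch

# A 1024x1024 pixel-space mask becomes a 128x128 latent-space mask, matching
# the SDXL VAE's 8x spatial compression (latent H, W = pixel H, W // 8).
pixel_mask = torch.rand(1, 1, 1024, 1024)  # values in [0, 1]
latent_mask = torch.nn.functional.avg_pool2d(pixel_mask, (8, 8))
latent_mask = torch.nn.functional.interpolate(latent_mask, (128, 128), mode='bilinear')
print(latent_mask.shape)  # torch.Size([1, 1, 128, 128])
```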
diff --git a/modules/core.py b/modules/core.py
index d36e906f..cd3c68df 100644
--- a/modules/core.py
+++ b/modules/core.py
@@ -11,7 +11,7 @@ from comfy.sd import load_checkpoint_guess_config
 from nodes import VAEDecode, EmptyLatentImage, VAEEncode, VAEEncodeTiled, VAEDecodeTiled
 from comfy.sample import prepare_mask, broadcast_cond, load_additional_models, cleanup_additional_models
 from comfy.model_base import SDXLRefiner
-from modules.samplers_advanced import KSampler, KSamplerWithRefiner
+from modules.samplers_advanced import KSamplerBasic, KSamplerWithRefiner
 from modules.patch import patch_all
 
 
@@ -147,7 +147,7 @@ def get_previewer(device, latent_format):
 
 @torch.no_grad()
 @torch.inference_mode()
-def ksampler(model, positive, negative, latent, seed=None, steps=30, cfg=7.0, sampler_name='dpmpp_2m_sde_gpu',
+def ksampler(model, positive, negative, latent, seed=None, steps=30, cfg=7.0, sampler_name='dpmpp_fooocus_2m_sde_inpaint_seamless',
              scheduler='karras', denoise=1.0, disable_noise=False, start_step=None, last_step=None,
             force_full_denoise=False, callback_function=None):
     # SCHEDULERS = ["normal", "karras", "exponential", "simple", "ddim_uniform"]
@@ -199,7 +199,7 @@ def ksampler(model, positive, negative, latent, seed=None, steps=30, cfg=7.0, sa
 
     models = load_additional_models(positive, negative, model.model_dtype())
 
-    sampler = KSampler(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler,
+    sampler = KSamplerBasic(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler,
                        denoise=denoise, model_options=model.model_options)
 
     samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image,
@@ -220,7 +220,7 @@ def ksampler(model, positive, negative, latent, seed=None, steps=30, cfg=7.0, sa
 @torch.no_grad()
 @torch.inference_mode()
 def ksampler_with_refiner(model, positive, negative, refiner, refiner_positive, refiner_negative, latent,
-                          seed=None, steps=30, refiner_switch_step=20, cfg=7.0, sampler_name='dpmpp_2m_sde_gpu',
+                          seed=None, steps=30, refiner_switch_step=20, cfg=7.0, sampler_name='dpmpp_fooocus_2m_sde_inpaint_seamless',
                           scheduler='karras', denoise=1.0, disable_noise=False, start_step=None, last_step=None,
                           force_full_denoise=False, callback_function=None):
     # SCHEDULERS = ["normal", "karras", "exponential", "simple", "ddim_uniform"]
diff --git a/modules/default_pipeline.py b/modules/default_pipeline.py
index 6ad7c3f1..ed45d42c 100644
--- a/modules/default_pipeline.py
+++ b/modules/default_pipeline.py
@@ -5,7 +5,7 @@ import modules.path
 import modules.virtual_memory as virtual_memory
 
 from comfy.model_base import SDXL, SDXLRefiner
-from modules.patch import cfg_patched
+from modules.patch import cfg_patched, patched_model_function
 from modules.expansion import FooocusExpansion
 
 
@@ -201,10 +201,14 @@ def patch_all_models():
     assert xl_base_patched is not None
 
     xl_base.unet.model_options['sampler_cfg_function'] = cfg_patched
+    xl_base.unet.model_options['model_function_wrapper'] = patched_model_function
+
     xl_base_patched.unet.model_options['sampler_cfg_function'] = cfg_patched
+    xl_base_patched.unet.model_options['model_function_wrapper'] = patched_model_function
 
     if xl_refiner is not None:
         xl_refiner.unet.model_options['sampler_cfg_function'] = cfg_patched
+        xl_refiner.unet.model_options['model_function_wrapper'] = patched_model_function
 
     return
diff --git a/modules/flags.py b/modules/flags.py
index 3b8ca025..cb47e617 100644
--- a/modules/flags.py
+++ b/modules/flags.py
@@ -1,4 +1,5 @@
 disabled = 'Disabled'
+enabled = 'Enabled'
 subtle_variation = 'Vary (Subtle)'
 strong_variation = 'Vary (Strong)'
 upscale_15 = 'Upscale (1.5x)'
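For context on the `model_function_wrapper` hook registered above: ComfyUI invokes it with the raw UNet apply function plus a dict of arguments, which is why `patched_model_function` in modules/patch.py below reads `args['input']`, `args['timestep']`, `args['c']` and `args['cond_or_uncond']`. A minimal pass-through wrapper under that convention (the function name here is illustrative):

```python
def noop_model_function_wrapper(model_function, args):
    # `model_function` is the raw UNet forward; `args` bundles the latent batch,
    # timesteps, conditioning dict, and cond/uncond batch tags.
    x = args['input']
    t = args['timestep']
    c = args['c']
    # A real wrapper (e.g. patched_model_function) may modify x before the call.
    return model_function(x, t, **c)

# Registration mirrors patch_all_models() above:
# unet.model_options['model_function_wrapper'] = noop_model_function_wrapper
```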
diff --git a/modules/inpaint_worker.py b/modules/inpaint_worker.py
new file mode 100644
index 00000000..5ab4ab0d
--- /dev/null
+++ b/modules/inpaint_worker.py
@@ -0,0 +1,180 @@
+import numpy as np
+
+from PIL import Image, ImageFilter, ImageOps
+from modules.util import resample_image
+
+
+current_task = None
+
+
+def morphological_soft_open(x):
+    k = 12
+    x = Image.fromarray(x)
+    for _ in range(k):
+        x = x.filter(ImageFilter.MaxFilter(3))
+    x = x.filter(ImageFilter.BoxBlur(k * 2 + 1))
+    x = np.array(x)
+    return x
+
+
+def box_blur(x, k):
+    x = Image.fromarray(x)
+    x = x.filter(ImageFilter.BoxBlur(k))
+    return np.array(x)
+
+
+def threshold_0_255(x):
+    y = np.zeros_like(x)
+    y[x > 127] = 255
+    return y
+
+
+def morphological_hard_open(x):
+    y = threshold_0_255(x)
+    z = morphological_soft_open(x)
+    z[y > 127] = 255
+    return z
+
+
+def imsave(x, path):
+    x = Image.fromarray(x)
+    x.save(path)
+
+
+def regulate_abcd(x, a, b, c, d):
+    H, W = x.shape[:2]
+    if a < 0:
+        a = 0
+    if a > H:
+        a = H
+    if b < 0:
+        b = 0
+    if b > H:
+        b = H
+    if c < 0:
+        c = 0
+    if c > W:
+        c = W
+    if d < 0:
+        d = 0
+    if d > W:
+        d = W
+    return int(a), int(b), int(c), int(d)
+
+
+def compute_initial_abcd(x):
+    indices = np.where(x)
+    a = np.min(indices[0]) - 64
+    b = np.max(indices[0]) + 65
+    c = np.min(indices[1]) - 64
+    d = np.max(indices[1]) + 65
+    a, b, c, d = regulate_abcd(x, a, b, c, d)
+    return a, b, c, d
+
+
+def area_abcd(a, b, c, d):
+    return (b - a) * (d - c)
+
+
+def solve_abcd(x, a, b, c, d, k, outpaint):
+    H, W = x.shape[:2]
+    if outpaint:
+        return 0, H, 0, W
+    min_area = H * W * k
+    while area_abcd(a, b, c, d) < min_area:
+        if (b - a) < (d - c):
+            a -= 1
+            b += 1
+        else:
+            c -= 1
+            d += 1
+        a, b, c, d = regulate_abcd(x, a, b, c, d)
+    return a, b, c, d
+
+
+def fooocus_fill(image, mask):
+    current_image = image.copy()
+    raw_image = image.copy()
+    area = np.where(mask < 127)
+    store = raw_image[area]
+
+    for k, repeats in [(64, 4), (32, 4), (16, 4), (4, 4), (2, 4)]:
+        for _ in range(repeats):
+            current_image = box_blur(current_image, k)
+            current_image[area] = store
+
+    return current_image
+
+
+class InpaintWorker:
+    def __init__(self, image, mask, is_outpaint):
+        # mask processing
+        self.image_raw = fooocus_fill(image, mask)
+        self.mask_raw_user_input = mask
+        self.mask_raw_soft = morphological_hard_open(mask)
+        self.mask_raw_fg = (self.mask_raw_soft == 255).astype(np.uint8) * 255
+        self.mask_raw_bg = (self.mask_raw_soft == 0).astype(np.uint8) * 255
+        self.mask_raw_trim = 255 - np.maximum(self.mask_raw_fg, self.mask_raw_bg)
+        self.mask_raw_error = (self.mask_raw_user_input > self.mask_raw_fg).astype(np.uint8) * 255
+
+        # log all images
+        # imsave(self.mask_raw_user_input, 'mask_raw_user_input.png')
+        # imsave(self.mask_raw_soft, 'mask_raw_soft.png')
+        # imsave(self.mask_raw_fg, 'mask_raw_fg.png')
+        # imsave(self.mask_raw_bg, 'mask_raw_bg.png')
+        # imsave(self.mask_raw_trim, 'mask_raw_trim.png')
+        # imsave(self.mask_raw_error, 'mask_raw_error.png')
+
+        # compute abcd
+        a, b, c, d = compute_initial_abcd(self.mask_raw_bg < 127)
+        a, b, c, d = solve_abcd(self.mask_raw_bg, a, b, c, d, k=0.618, outpaint=is_outpaint)
+
+        # interested area
+        self.interested_area = (a, b, c, d)
+        self.mask_interested_soft = self.mask_raw_soft[a:b, c:d]
+        self.mask_interested_fg = self.mask_raw_fg[a:b, c:d]
+        self.mask_interested_bg = self.mask_raw_bg[a:b, c:d]
+        self.mask_interested_trim = self.mask_raw_trim[a:b, c:d]
+        self.image_interested = self.image_raw[a:b, c:d]
+
+        # resize to make images ready for diffusion
+        H, W, C = self.image_interested.shape
+        k = (1024.0 ** 2.0 / float(H * W)) ** 0.5
+        H = int(np.ceil(float(H) * k / 16.0)) * 16
+        W = int(np.ceil(float(W) * k / 16.0)) * 16
+        self.image_ready = resample_image(self.image_interested, W, H)
+        self.mask_ready = resample_image(self.mask_interested_soft, W, H)
+
+        # ending
+        self.latent = None
+        self.latent_mask = None
+        self.uc_guidance = None
+        return
+
+    def load_latent(self, latent, mask):
+        self.latent = latent
+        self.latent_mask = mask
+
+    def color_correction(self, img):
+        fg = img.astype(np.float32)
+        bg = self.image_raw.copy().astype(np.float32)
+        w = self.mask_raw_soft[:, :, None].astype(np.float32) / 255.0
+        y = fg * w + bg * (1 - w)
+        return y.clip(0, 255).astype(np.uint8)
+
+    def post_process(self, img):
+        a, b, c, d = self.interested_area
+        content = resample_image(img, d - c, b - a)
+        result = self.image_raw.copy()
+        result[a:b, c:d] = content
+        result = self.color_correction(result)
+        return result
+
+    def visualize_mask_processing(self):
+        result = self.image_raw // 4
+        a, b, c, d = self.interested_area
+        result[a:b, c:d] += 64
+        result[self.mask_raw_trim > 127] += 64
+        result[self.mask_raw_fg > 127] += 128
+        return [result, self.mask_raw_soft, self.image_ready, self.mask_ready]
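A minimal usage sketch for the new `InpaintWorker` (shapes and the mask rectangle are illustrative; the mask is a single-channel uint8 array where white marks the region to repaint):

```python
import numpy as np
from modules.inpaint_worker import InpaintWorker

image = np.zeros((768, 512, 3), dtype=np.uint8)  # H x W x C, RGB
mask = np.zeros((768, 512), dtype=np.uint8)
mask[200:400, 100:300] = 255  # white = repaint this region

task = InpaintWorker(image=image, mask=mask, is_outpaint=False)
print(task.image_ready.shape, task.mask_ready.shape)  # diffusion-ready crop, multiples of 16
# After sampling, task.post_process(generated_image) resizes the result back into
# the interested area and blends it over the original with soft-mask color correction.
```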
diff --git a/modules/patch.py b/modules/patch.py
index 63a913a7..695cbe80 100644
--- a/modules/patch.py
+++ b/modules/patch.py
@@ -6,15 +6,20 @@ import comfy.k_diffusion.external
 import comfy.model_management
 import modules.anisotropic as anisotropic
 import comfy.ldm.modules.attention
+import comfy.k_diffusion.sampling
 import comfy.sd1_clip
+import modules.inpaint_worker as inpaint_worker
 
 from comfy.k_diffusion import utils
+from comfy.k_diffusion.sampling import BrownianTreeNoiseSampler, trange
 
 
 sharpness = 2.0
+negative_adm = True
 
 cfg_x0 = 0.0
 cfg_s = 1.0
+cfg_cin = 1.0
 
 
 def cfg_patched(args):
@@ -37,14 +42,29 @@ def cfg_patched(args):
 
 
 def patched_discrete_eps_ddpm_denoiser_forward(self, input, sigma, **kwargs):
-    global cfg_x0, cfg_s
+    global cfg_x0, cfg_s, cfg_cin
     c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
     cfg_x0 = input
     cfg_s = c_out
+    cfg_cin = c_in
     return self.get_eps(input * c_in, self.sigma_to_t(sigma), **kwargs)
 
 
+def patched_model_function(func, args):
+    global cfg_cin
+    x = args['input']
+    t = args['timestep']
+    c = args['c']
+    is_uncond = torch.tensor(args['cond_or_uncond'])[:, None, None, None].to(x) * 5e-3
+    if inpaint_worker.current_task is not None:
+        p = inpaint_worker.current_task.uc_guidance * cfg_cin
+        x = p * is_uncond + x * (1 - is_uncond ** 2.0) ** 0.5
+    return func(x, t, **c)
+
+
 def sdxl_encode_adm_patched(self, **kwargs):
+    global negative_adm
+
     clip_pooled = kwargs["pooled_output"]
     width = kwargs.get("width", 768)
     height = kwargs.get("height", 768)
@@ -53,12 +73,13 @@ def sdxl_encode_adm_patched(self, **kwargs):
     target_width = kwargs.get("target_width", width)
     target_height = kwargs.get("target_height", height)
 
-    if kwargs.get("prompt_type", "") == "negative":
-        width *= 0.8
-        height *= 0.8
-    elif kwargs.get("prompt_type", "") == "positive":
-        width *= 1.5
-        height *= 1.5
+    if negative_adm:
+        if kwargs.get("prompt_type", "") == "negative":
+            width *= 0.8
+            height *= 0.8
+        elif kwargs.get("prompt_type", "") == "positive":
+            width *= 1.5
+            height *= 1.5
 
     out = []
     out.append(self.embedder(torch.Tensor([height])))
@@ -71,35 +92,6 @@ def sdxl_encode_adm_patched(self, **kwargs):
     return torch.cat((clip_pooled.to(flat.device), flat), dim=1)
 
 
-def sdxl_refiner_encode_adm_patched(self, **kwargs):
-    clip_pooled = kwargs["pooled_output"]
-    width = kwargs.get("width", 768)
-    height = kwargs.get("height", 768)
-    crop_w = kwargs.get("crop_w", 0)
-    crop_h = kwargs.get("crop_h", 0)
-
-    if kwargs.get("prompt_type", "") == "negative":
-        aesthetic_score = kwargs.get("aesthetic_score", 2.5)
-    else:
-        aesthetic_score = kwargs.get("aesthetic_score", 7.0)
-
-    if kwargs.get("prompt_type", "") == "negative":
-        width *= 0.8
-        height *= 0.8
-    elif kwargs.get("prompt_type", "") == "positive":
-        width *= 1.5
-        height *= 1.5
-
-    out = []
-    out.append(self.embedder(torch.Tensor([height])))
-    out.append(self.embedder(torch.Tensor([width])))
-    out.append(self.embedder(torch.Tensor([crop_h])))
-    out.append(self.embedder(torch.Tensor([crop_w])))
-    out.append(self.embedder(torch.Tensor([aesthetic_score])))
-    flat = torch.flatten(torch.cat(out))[None,]
-    return torch.cat((clip_pooled.to(flat.device), flat), dim=1)
-
-
 def text_encoder_device_patched():
     # Fooocus's style system uses text encoder much more times than comfy so this makes things much faster.
     return comfy.model_management.get_torch_device()
@@ -138,15 +130,79 @@ def encode_token_weights_patched_with_a1111_method(self, token_weight_pairs):
     return torch.cat(output, dim=-2).cpu(), first_pooled.cpu()
 
 
+@torch.no_grad()
+def sample_dpmpp_fooocus_2m_sde_inpaint_seamless(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, **kwargs):
+    sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
+    noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler
+
+    seed = extra_args.get("seed", None)
+    assert isinstance(seed, int)
+
+    energy_generator = torch.Generator(device='cpu')
+    energy_generator.manual_seed(seed + 1)  # avoid bad results by using different seeds.
+
+    def get_energy():
+        return torch.randn(x.size(), dtype=x.dtype, generator=energy_generator, device="cpu").to(x)
+
+    sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
+    noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler
+    extra_args = {} if extra_args is None else extra_args
+    s_in = x.new_ones([x.shape[0]])
+
+    old_denoised, h_last, h = None, None, None
+
+    latent_processor = model.inner_model.inner_model.inner_model.process_latent_in
+    inpaint_latent = None
+    inpaint_mask = None
+
+    if inpaint_worker.current_task is not None:
+        inpaint_latent = latent_processor(inpaint_worker.current_task.latent).to(x)
+        inpaint_mask = inpaint_worker.current_task.latent_mask.to(x)
+
+    def blend_latent(a, b, w):
+        return a * w + b * (1 - w)
+
+    for i in trange(len(sigmas) - 1, disable=disable):
+        if inpaint_latent is None:
+            denoised = model(x, sigmas[i] * s_in, **extra_args)
+        else:
+            inpaint_worker.current_task.uc_guidance = x.detach().clone()
+            energy = get_energy() * sigmas[i] + inpaint_latent
+            x_prime = blend_latent(x, energy, inpaint_mask)
+            denoised = model(x_prime, sigmas[i] * s_in, **extra_args)
+            denoised = blend_latent(denoised, inpaint_latent, inpaint_mask)
+        if callback is not None:
+            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
+        if sigmas[i + 1] == 0:
+            x = denoised
+        else:
+            t, s = -sigmas[i].log(), -sigmas[i + 1].log()
+            h = s - t
+            eta_h = eta * h
+
+            x = sigmas[i + 1] / sigmas[i] * (-eta_h).exp() * x + (-h - eta_h).expm1().neg() * denoised
+            if old_denoised is not None:
+                r = h_last / h
+                x = x + 0.5 * (-h - eta_h).expm1().neg() * (1 / r) * (denoised - old_denoised)
+
+            x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * eta_h).expm1().neg().sqrt() * s_noise
+
+            old_denoised = denoised
+            h_last = h
+
+    return x
+
+
 def patch_all():
     comfy.ldm.modules.attention.print = lambda x: None
+    comfy.k_diffusion.sampling.sample_dpmpp_fooocus_2m_sde_inpaint_seamless = sample_dpmpp_fooocus_2m_sde_inpaint_seamless
+
     comfy.model_management.text_encoder_device = text_encoder_device_patched
     print(f'Fooocus Text Processing Pipelines are retargeted to {str(comfy.model_management.text_encoder_device())}')
 
     comfy.k_diffusion.external.DiscreteEpsDDPMDenoiser.forward = patched_discrete_eps_ddpm_denoiser_forward
     comfy.model_base.SDXL.encode_adm = sdxl_encode_adm_patched
-    # comfy.model_base.SDXLRefiner.encode_adm = sdxl_refiner_encode_adm_patched
     comfy.sd1_clip.ClipTokenWeightEncoder.encode_token_weights = encode_token_weights_patched_with_a1111_method
 
     return
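The "seamless" behavior of the sampler added above comes from the `blend_latent` step: at every iteration the known region of `x` is rebuilt from the original latent re-noised to the current sigma, so only the masked region follows the sampler's own trajectory. A self-contained sketch of that rule with stand-in tensors:

```python
import torch

def blend_latent(a, b, w):
    # w == 1 keeps `a` (the region being repainted); w == 0 keeps `b` (the known region).
    return a * w + b * (1 - w)

x = torch.randn(1, 4, 64, 64)      # current sampler state
known = torch.randn(1, 4, 64, 64)  # VAE-encoded latent of the original image
sigma = 7.0
mask = torch.zeros(1, 1, 64, 64)
mask[..., 16:48, 16:48] = 1.0      # region being repainted

# Outside the mask, the state is replaced by the original latent plus noise at
# the current sigma, which is what keeps the seam consistent step after step.
energy = torch.randn_like(known) * sigma + known
x_prime = blend_latent(x, energy, mask)
```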
diff --git a/modules/samplers_advanced.py b/modules/samplers_advanced.py
index 2d382f64..56850ebb 100644
--- a/modules/samplers_advanced.py
+++ b/modules/samplers_advanced.py
@@ -4,11 +4,209 @@ import comfy.model_management
 import modules.virtual_memory
 
 
+class KSamplerBasic:
+    SCHEDULERS = ["normal", "karras", "exponential", "simple", "ddim_uniform"]
+    SAMPLERS = ["euler", "euler_ancestral", "heun", "dpm_2", "dpm_2_ancestral",
+                "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "dpmpp_sde_gpu",
+                "dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "ddim", "uni_pc", "uni_pc_bh2", "dpmpp_fooocus_2m_sde_inpaint_seamless"]
+
+    def __init__(self, model, steps, device, sampler=None, scheduler=None, denoise=None, model_options={}):
+        self.model = model
+        self.model_denoise = CFGNoisePredictor(self.model)
+        if self.model.model_type == model_base.ModelType.V_PREDICTION:
+            self.model_wrap = CompVisVDenoiser(self.model_denoise, quantize=True)
+        else:
+            self.model_wrap = k_diffusion_external.CompVisDenoiser(self.model_denoise, quantize=True)
+
+        self.model_k = KSamplerX0Inpaint(self.model_wrap)
+        self.device = device
+        if scheduler not in self.SCHEDULERS:
+            scheduler = self.SCHEDULERS[0]
+        if sampler not in self.SAMPLERS:
+            sampler = self.SAMPLERS[0]
+        self.scheduler = scheduler
+        self.sampler = sampler
+        self.sigma_min=float(self.model_wrap.sigma_min)
+        self.sigma_max=float(self.model_wrap.sigma_max)
+        self.set_steps(steps, denoise)
+        self.denoise = denoise
+        self.model_options = model_options
+
+    def calculate_sigmas(self, steps):
+        sigmas = None
+
+        discard_penultimate_sigma = False
+        if self.sampler in ['dpm_2', 'dpm_2_ancestral']:
+            steps += 1
+            discard_penultimate_sigma = True
+
+        if self.scheduler == "karras":
+            sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=self.sigma_min, sigma_max=self.sigma_max)
+        elif self.scheduler == "exponential":
+            sigmas = k_diffusion_sampling.get_sigmas_exponential(n=steps, sigma_min=self.sigma_min, sigma_max=self.sigma_max)
+        elif self.scheduler == "normal":
+            sigmas = self.model_wrap.get_sigmas(steps)
+        elif self.scheduler == "simple":
+            sigmas = simple_scheduler(self.model_wrap, steps)
+        elif self.scheduler == "ddim_uniform":
+            sigmas = ddim_scheduler(self.model_wrap, steps)
+        else:
+            print("error invalid scheduler", self.scheduler)
+
+        if discard_penultimate_sigma:
+            sigmas = torch.cat([sigmas[:-2], sigmas[-1:]])
+        return sigmas
+
+    def set_steps(self, steps, denoise=None):
+        self.steps = steps
+        if denoise is None or denoise > 0.9999:
+            self.sigmas = self.calculate_sigmas(steps).to(self.device)
+        else:
+            new_steps = int(steps/denoise)
+            sigmas = self.calculate_sigmas(new_steps).to(self.device)
+            self.sigmas = sigmas[-(steps + 1):]
+
+    def sample(self, noise, positive, negative, cfg, latent_image=None, start_step=None, last_step=None, force_full_denoise=False, denoise_mask=None, sigmas=None, callback=None, disable_pbar=False, seed=None):
+        if sigmas is None:
+            sigmas = self.sigmas
+        sigma_min = self.sigma_min
+
+        if last_step is not None and last_step < (len(sigmas) - 1):
+            sigma_min = sigmas[last_step]
+            sigmas = sigmas[:last_step + 1]
+            if force_full_denoise:
+                sigmas[-1] = 0
+
+        if start_step is not None:
+            if start_step < (len(sigmas) - 1):
+                sigmas = sigmas[start_step:]
+            else:
+                if latent_image is not None:
+                    return latent_image
+                else:
+                    return torch.zeros_like(noise)
+
+        positive = positive[:]
+        negative = negative[:]
+
+        resolve_cond_masks(positive, noise.shape[2], noise.shape[3], self.device)
+        resolve_cond_masks(negative, noise.shape[2], noise.shape[3], self.device)
+
+        calculate_start_end_timesteps(self.model_wrap, negative)
+        calculate_start_end_timesteps(self.model_wrap, positive)
+
+        #make sure each cond area has an opposite one with the same area
+        for c in positive:
+            create_cond_with_same_area_if_none(negative, c)
+        for c in negative:
+            create_cond_with_same_area_if_none(positive, c)
+
+        pre_run_control(self.model_wrap, negative + positive)
+
+        apply_empty_x_to_equal_area(list(filter(lambda c: c[1].get('control_apply_to_uncond', False) == True, positive)), negative, 'control', lambda cond_cnets, x: cond_cnets[x])
+        apply_empty_x_to_equal_area(positive, negative, 'gligen', lambda cond_cnets, x: cond_cnets[x])
+
+        if self.model.is_adm():
+            positive = encode_adm(self.model, positive, noise.shape[0], noise.shape[3], noise.shape[2], self.device, "positive")
+            negative = encode_adm(self.model, negative, noise.shape[0], noise.shape[3], noise.shape[2], self.device, "negative")
+
+        if latent_image is not None:
+            latent_image = self.model.process_latent_in(latent_image)
+
+        extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg, "model_options": self.model_options, "seed":seed}
+
+        cond_concat = None
+        if hasattr(self.model, 'concat_keys'): #inpaint
+            cond_concat = []
+            for ck in self.model.concat_keys:
+                if denoise_mask is not None:
+                    if ck == "mask":
+                        cond_concat.append(denoise_mask[:,:1])
+                    elif ck == "masked_image":
+                        cond_concat.append(latent_image) #NOTE: the latent_image should be masked by the mask in pixel space
+                else:
+                    if ck == "mask":
+                        cond_concat.append(torch.ones_like(noise)[:,:1])
+                    elif ck == "masked_image":
+                        cond_concat.append(blank_inpaint_image_like(noise))
+            extra_args["cond_concat"] = cond_concat
+
+        if sigmas[0] != self.sigmas[0] or (self.denoise is not None and self.denoise < 1.0):
+            max_denoise = False
+        else:
+            max_denoise = True
+
+
+        if self.sampler == "uni_pc":
+            samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, max_denoise=max_denoise, extra_args=extra_args, noise_mask=denoise_mask, callback=callback, disable=disable_pbar)
+        elif self.sampler == "uni_pc_bh2":
+            samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, max_denoise=max_denoise, extra_args=extra_args, noise_mask=denoise_mask, callback=callback, variant='bh2', disable=disable_pbar)
+        elif self.sampler == "ddim":
+            timesteps = []
+            for s in range(sigmas.shape[0]):
+                timesteps.insert(0, self.model_wrap.sigma_to_discrete_timestep(sigmas[s]))
+            noise_mask = None
+            if denoise_mask is not None:
+                noise_mask = 1.0 - denoise_mask
+
+            ddim_callback = None
+            if callback is not None:
+                total_steps = len(timesteps) - 1
+                ddim_callback = lambda pred_x0, i: callback(i, pred_x0, None, total_steps)
+
+            sampler = DDIMSampler(self.model, device=self.device)
+            sampler.make_schedule_timesteps(ddim_timesteps=timesteps, verbose=False)
+            z_enc = sampler.stochastic_encode(latent_image, torch.tensor([len(timesteps) - 1] * noise.shape[0]).to(self.device), noise=noise, max_denoise=max_denoise)
+            samples, _ = sampler.sample_custom(ddim_timesteps=timesteps,
+                                               conditioning=positive,
+                                               batch_size=noise.shape[0],
+                                               shape=noise.shape[1:],
+                                               verbose=False,
+                                               unconditional_guidance_scale=cfg,
+                                               unconditional_conditioning=negative,
+                                               eta=0.0,
+                                               x_T=z_enc,
+                                               x0=latent_image,
+                                               img_callback=ddim_callback,
+                                               denoise_function=self.model_wrap.predict_eps_discrete_timestep,
+                                               extra_args=extra_args,
+                                               mask=noise_mask,
+                                               to_zero=sigmas[-1]==0,
+                                               end_step=sigmas.shape[0] - 1,
+                                               disable_pbar=disable_pbar)
+
+        else:
+            extra_args["denoise_mask"] = denoise_mask
+            self.model_k.latent_image = latent_image
+            self.model_k.noise = noise
+
+            if max_denoise:
+                noise = noise * torch.sqrt(1.0 + sigmas[0] ** 2.0)
+            else:
+                noise = noise * sigmas[0]
+
+            k_callback = None
+            total_steps = len(sigmas) - 1
+            if callback is not None:
+                k_callback = lambda x: callback(x["i"], x["denoised"], x["x"], total_steps)
+
+            if latent_image is not None:
+                noise += latent_image
+            if self.sampler == "dpm_fast":
+                samples = k_diffusion_sampling.sample_dpm_fast(self.model_k, noise, sigma_min, sigmas[0], total_steps, extra_args=extra_args, callback=k_callback, disable=disable_pbar)
+            elif self.sampler == "dpm_adaptive":
+                samples = k_diffusion_sampling.sample_dpm_adaptive(self.model_k, noise, sigma_min, sigmas[0], extra_args=extra_args, callback=k_callback, disable=disable_pbar)
+            else:
+                samples = getattr(k_diffusion_sampling, "sample_{}".format(self.sampler))(self.model_k, noise, sigmas, extra_args=extra_args, callback=k_callback, disable=disable_pbar)
+
+        return self.model.process_latent_out(samples.to(torch.float32))
+
+
 class KSamplerWithRefiner:
     SCHEDULERS = ["normal", "karras", "exponential", "simple", "ddim_uniform"]
     SAMPLERS = ["euler", "euler_ancestral", "heun", "dpm_2", "dpm_2_ancestral",
                 "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "dpmpp_sde_gpu",
-                "dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "ddim", "uni_pc", "uni_pc_bh2"]
+                "dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "ddim", "uni_pc", "uni_pc_bh2", "dpmpp_fooocus_2m_sde_inpaint_seamless"]
 
     def __init__(self, model, refiner_model, steps, device, sampler=None, scheduler=None, denoise=None, model_options={}):
         self.model_patcher = model
diff --git a/modules/util.py b/modules/util.py
index d8b8e1e5..c8a4f13a 100644
--- a/modules/util.py
+++ b/modules/util.py
@@ -28,6 +28,12 @@ def image_is_generated_in_current_ui(image, ui_width, ui_height):
 LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
 
 
+def resample_image(im, width, height):
+    im = Image.fromarray(im)
+    im = im.resize((width, height), resample=LANCZOS)
+    return np.array(im)
+
+
 def resize_image(im, width, height, resize_mode=1):
     """
     Resizes an image with the specified resize_mode, width, and height.
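A quick check of the new `resample_image` helper; note that PIL's `resize` takes `(width, height)` while numpy shapes report `(height, width, channels)`:

```python
import numpy as np
from modules.util import resample_image

im = np.zeros((600, 400, 3), dtype=np.uint8)  # (H, W, C)
out = resample_image(im, width=512, height=768)
print(out.shape)  # (768, 512, 3)
```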
diff --git a/webui.py b/webui.py
index a36a0ea7..bab52ed3 100644
--- a/webui.py
+++ b/webui.py
@@ -61,14 +61,31 @@ with shared.gradio_root:
                 input_image_checkbox = gr.Checkbox(label='Input Image', value=False, container=False, elem_classes='min_check')
                 advanced_checkbox = gr.Checkbox(label='Advanced', value=False, container=False, elem_classes='min_check')
         with gr.Row(visible=False) as image_input_panel:
-            with gr.Column(scale=0.5):
-                with gr.Accordion(label='Upscale or Variation', open=True):
-                    uov_input_image = gr.Image(label='Drag above image to here', source='upload', type='numpy')
-                    uov_method = gr.Radio(label='Method', choices=flags.uov_list, value=flags.disabled, show_label=False, container=False)
-                gr.HTML('\U0001F4D4 Document')
+            with gr.Tabs():
+                with gr.TabItem(label='Upscale or Variation') as uov_tab:
+                    with gr.Row():
+                        with gr.Column():
+                            uov_input_image = gr.Image(label='Drag above image to here', source='upload', type='numpy')
+                        with gr.Column():
+                            uov_method = gr.Radio(label='Upscale or Variation:', choices=flags.uov_list, value=flags.disabled)
+                            gr.HTML('\U0001F4D4 Document')
+                with gr.TabItem(label='Inpaint or Outpaint') as inpaint_tab:
+                    inpaint_input_image = gr.Image(label='Drag above image to here', source='upload', type='numpy', tool='sketch', height=500, brush_color="#FFFFFF")
+                    gr.HTML('Outpaint Expansion (\U0001F4D4 Document):')
+                    outpaint_selections = gr.CheckboxGroup(choices=['Left', 'Right', 'Top', 'Bottom'], value=[], label='Outpaint', show_label=False, container=False)
+                    gr.HTML('* \"Inpaint or Outpaint\" is powered by the sampler \"DPMPP Fooocus Seamless 2M SDE Karras Inpaint Sampler\" (beta)')
+
         input_image_checkbox.change(lambda x: gr.update(visible=x), inputs=input_image_checkbox, outputs=image_input_panel, queue=False, _js="(x) => {if(x){setTimeout(() => window.scrollTo({ top: window.scrollY + 500, behavior: 'smooth' }), 50);}else{setTimeout(() => window.scrollTo({ top: 0, behavior: 'smooth' }), 50);} return x}")
+
+        current_tab = gr.Textbox(value='uov', visible=False)
+        uov_tab.select(lambda: 'uov', outputs=current_tab, queue=False)
+        inpaint_tab.select(lambda: 'inpaint', outputs=current_tab, queue=False)
+
+        uov_input_image.upload(lambda x: x, inputs=[uov_input_image], outputs=[inpaint_input_image])
+        inpaint_input_image.upload(lambda: None).\
+            then(lambda x: x['image'], inputs=[inpaint_input_image], outputs=[uov_input_image])
+
         # def get_select_index(g, evt: gr.SelectData):
         #     return g[evt.index]['name']
         # gallery.select(get_select_index, gallery, uov_input_image)
@@ -132,8 +149,9 @@ with shared.gradio_root:
                 performance_selction, aspect_ratios_selction, image_number, image_seed, sharpness
             ]
         ctrls += [base_model, refiner_model] + lora_ctrls
-        ctrls += [input_image_checkbox]
+        ctrls += [input_image_checkbox, current_tab]
         ctrls += [uov_method, uov_input_image]
+        ctrls += [outpaint_selections, inpaint_input_image]
 
         run_button.click(lambda: (gr.update(visible=True, interactive=True), gr.update(visible=False), []), outputs=[stop_button, run_button, gallery])\
            .then(fn=refresh_seed, inputs=[seed_random, image_seed], outputs=image_seed)\
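One coupling worth keeping in mind when touching these files: Gradio passes `ctrls` to the click handler positionally, so the list order above must match the tuple unpacking at the top of `worker()` in modules/async_worker.py exactly. A stand-in illustration of that contract (values are placeholders):

```python
# The tail of ctrls, in the order webui.py appends them:
task = [True, 'inpaint', 'Disabled', None, ['Left'], None]

# async_worker.py must unpack in the same order, or every field shifts by one.
input_image_checkbox, current_tab, uov_method, uov_input_image, \
    outpaint_selections, inpaint_input_image = task
assert current_tab == 'inpaint' and outpaint_selections == ['Left']
```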