diff --git a/.gitignore b/.gitignore
index eeaeb1ed..fb1bffdd 100644
--- a/.gitignore
+++ b/.gitignore
@@ -6,6 +6,7 @@ __pycache__
lena.png
lena_result.png
lena_test.py
+/modules/*.png
/repositories
/venv
/tmp
diff --git a/fooocus_version.py b/fooocus_version.py
index 1d8c93b8..40dfeb3b 100644
--- a/fooocus_version.py
+++ b/fooocus_version.py
@@ -1 +1 @@
-version = '2.0.54'
+version = '2.0.60'
diff --git a/modules/async_worker.py b/modules/async_worker.py
index a6a42b1a..ec001e92 100644
--- a/modules/async_worker.py
+++ b/modules/async_worker.py
@@ -1,4 +1,6 @@
import threading
+
+import numpy as np
import torch
buffer = []
@@ -19,6 +21,7 @@ def worker():
import modules.patch
import modules.virtual_memory as virtual_memory
import comfy.model_management
+ import modules.inpaint_worker as inpaint_worker
from modules.sdxl_styles import apply_style, aspect_ratios, fooocus_expansion
from modules.private_logger import log
@@ -46,8 +49,10 @@ def worker():
aspect_ratios_selction, image_number, image_seed, sharpness, \
base_model_name, refiner_model_name, \
l1, w1, l2, w2, l3, w3, l4, w4, l5, w5, \
- input_image_checkbox, \
- uov_method, uov_input_image = task
+ input_image_checkbox, current_tab, \
+ uov_method, uov_input_image, outpaint_selections, inpaint_input_image = task
+
+ outpaint_selections = [o.lower() for o in outpaint_selections]
loras = [(l1, w1), (l2, w2), (l3, w3), (l4, w4), (l5, w5)]
@@ -63,9 +68,11 @@ def worker():
use_style = len(style_selections) > 0
modules.patch.sharpness = sharpness
+ modules.patch.negative_adm = True
initial_latent = None
denoising_strength = 1.0
tiled = False
+ inpaint_worker.current_task = None
if performance_selction == 'Speed':
steps = 30
@@ -80,7 +87,7 @@ def worker():
if input_image_checkbox:
progressbar(0, 'Image processing ...')
- if uov_method != flags.disabled and uov_input_image is not None:
+ if current_tab == 'uov' and uov_method != flags.disabled and uov_input_image is not None:
uov_input_image = HWC3(uov_input_image)
if 'vary' in uov_method:
if not image_is_generated_in_current_ui(uov_input_image, ui_width=width, ui_height=height):
@@ -156,6 +163,49 @@ def worker():
width = W * 8
height = H * 8
print(f'Final resolution is {str((height, width))}.')
+ if current_tab == 'inpaint' and isinstance(inpaint_input_image, dict):
+ inpaint_image = inpaint_input_image['image']
+ inpaint_mask = inpaint_input_image['mask'][:, :, 0]
+ if isinstance(inpaint_image, np.ndarray) and isinstance(inpaint_mask, np.ndarray) \
+ and (np.any(inpaint_mask > 127) or len(outpaint_selections) > 0):
+ if len(outpaint_selections) > 0:
+ H, W, C = inpaint_image.shape
+ if 'top' in outpaint_selections:
+ inpaint_image = np.pad(inpaint_image, [[int(H * 0.3), 0], [0, 0], [0, 0]], mode='edge')
+ inpaint_mask = np.pad(inpaint_mask, [[int(H * 0.3), 0], [0, 0]], mode='constant', constant_values=255)
+ if 'bottom' in outpaint_selections:
+ inpaint_image = np.pad(inpaint_image, [[0, int(H * 0.3)], [0, 0], [0, 0]], mode='edge')
+ inpaint_mask = np.pad(inpaint_mask, [[0, int(H * 0.3)], [0, 0]], mode='constant', constant_values=255)
+
+ H, W, C = inpaint_image.shape
+ if 'left' in outpaint_selections:
+ inpaint_image = np.pad(inpaint_image, [[0, 0], [int(H * 0.3), 0], [0, 0]], mode='edge')
+ inpaint_mask = np.pad(inpaint_mask, [[0, 0], [int(H * 0.3), 0]], mode='constant', constant_values=255)
+ if 'right' in outpaint_selections:
+ inpaint_image = np.pad(inpaint_image, [[0, 0], [0, int(H * 0.3)], [0, 0]], mode='edge')
+ inpaint_mask = np.pad(inpaint_mask, [[0, 0], [0, int(H * 0.3)]], mode='constant', constant_values=255)
+
+ inpaint_image = np.ascontiguousarray(inpaint_image.copy())
+ inpaint_mask = np.ascontiguousarray(inpaint_mask.copy())
+
+ inpaint_worker.current_task = inpaint_worker.InpaintWorker(image=inpaint_image, mask=inpaint_mask,
+ is_outpaint=len(outpaint_selections) > 0)
+
+ # print(f'Inpaint task: {str((height, width))}')
+ # outputs.append(['results', inpaint_worker.current_task.visualize_mask_processing()])
+ # return
+
+ inpaint_pixels = core.numpy_to_pytorch(inpaint_worker.current_task.image_ready)
+ progressbar(0, 'VAE encoding ...')
+ initial_latent = core.encode_vae(vae=pipeline.xl_base_patched.vae, pixels=inpaint_pixels)
+ inpaint_latent = initial_latent['samples']
+ B, C, H, W = inpaint_latent.shape
+ inpaint_mask = core.numpy_to_pytorch(inpaint_worker.current_task.mask_ready[None])
+ inpaint_mask = torch.nn.functional.avg_pool2d(inpaint_mask, (8, 8))
+ inpaint_mask = torch.nn.functional.interpolate(inpaint_mask, (H, W), mode='bilinear')
+ width = W * 8
+ height = H * 8
+ inpaint_worker.current_task.load_latent(latent=inpaint_latent, mask=inpaint_mask)
progressbar(1, 'Initializing ...')
@@ -262,6 +312,8 @@ def worker():
f'Step {step}/{total_steps} in the {current_task_id + 1}-th Sampling',
y)])
+ print(f'[ADM] Negative ADM = {modules.patch.negative_adm}')
+
outputs.append(['preview', (13, 'Starting tasks ...', None)])
for current_task_id, task in enumerate(tasks):
try:
@@ -279,6 +331,9 @@ def worker():
tiled=tiled
)
+ if inpaint_worker.current_task is not None:
+ imgs = [inpaint_worker.current_task.post_process(x) for x in imgs]
+
for x in imgs:
d = [
('Prompt', raw_prompt),
diff --git a/modules/core.py b/modules/core.py
index d36e906f..cd3c68df 100644
--- a/modules/core.py
+++ b/modules/core.py
@@ -11,7 +11,7 @@ from comfy.sd import load_checkpoint_guess_config
from nodes import VAEDecode, EmptyLatentImage, VAEEncode, VAEEncodeTiled, VAEDecodeTiled
from comfy.sample import prepare_mask, broadcast_cond, load_additional_models, cleanup_additional_models
from comfy.model_base import SDXLRefiner
-from modules.samplers_advanced import KSampler, KSamplerWithRefiner
+from modules.samplers_advanced import KSamplerBasic, KSamplerWithRefiner
from modules.patch import patch_all
@@ -147,7 +147,7 @@ def get_previewer(device, latent_format):
@torch.no_grad()
@torch.inference_mode()
-def ksampler(model, positive, negative, latent, seed=None, steps=30, cfg=7.0, sampler_name='dpmpp_2m_sde_gpu',
+def ksampler(model, positive, negative, latent, seed=None, steps=30, cfg=7.0, sampler_name='dpmpp_fooocus_2m_sde_inpaint_seamless',
scheduler='karras', denoise=1.0, disable_noise=False, start_step=None, last_step=None,
force_full_denoise=False, callback_function=None):
# SCHEDULERS = ["normal", "karras", "exponential", "simple", "ddim_uniform"]
@@ -199,7 +199,7 @@ def ksampler(model, positive, negative, latent, seed=None, steps=30, cfg=7.0, sa
models = load_additional_models(positive, negative, model.model_dtype())
- sampler = KSampler(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler,
+ sampler = KSamplerBasic(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler,
denoise=denoise, model_options=model.model_options)
samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image,
@@ -220,7 +220,7 @@ def ksampler(model, positive, negative, latent, seed=None, steps=30, cfg=7.0, sa
@torch.no_grad()
@torch.inference_mode()
def ksampler_with_refiner(model, positive, negative, refiner, refiner_positive, refiner_negative, latent,
- seed=None, steps=30, refiner_switch_step=20, cfg=7.0, sampler_name='dpmpp_2m_sde_gpu',
+ seed=None, steps=30, refiner_switch_step=20, cfg=7.0, sampler_name='dpmpp_fooocus_2m_sde_inpaint_seamless',
scheduler='karras', denoise=1.0, disable_noise=False, start_step=None, last_step=None,
force_full_denoise=False, callback_function=None):
# SCHEDULERS = ["normal", "karras", "exponential", "simple", "ddim_uniform"]
diff --git a/modules/default_pipeline.py b/modules/default_pipeline.py
index 6ad7c3f1..ed45d42c 100644
--- a/modules/default_pipeline.py
+++ b/modules/default_pipeline.py
@@ -5,7 +5,7 @@ import modules.path
import modules.virtual_memory as virtual_memory
from comfy.model_base import SDXL, SDXLRefiner
-from modules.patch import cfg_patched
+from modules.patch import cfg_patched, patched_model_function
from modules.expansion import FooocusExpansion
@@ -201,10 +201,14 @@ def patch_all_models():
assert xl_base_patched is not None
xl_base.unet.model_options['sampler_cfg_function'] = cfg_patched
+ xl_base.unet.model_options['model_function_wrapper'] = patched_model_function
+
xl_base_patched.unet.model_options['sampler_cfg_function'] = cfg_patched
+ xl_base_patched.unet.model_options['model_function_wrapper'] = patched_model_function
if xl_refiner is not None:
xl_refiner.unet.model_options['sampler_cfg_function'] = cfg_patched
+ xl_refiner.unet.model_options['model_function_wrapper'] = patched_model_function
return
diff --git a/modules/flags.py b/modules/flags.py
index 3b8ca025..cb47e617 100644
--- a/modules/flags.py
+++ b/modules/flags.py
@@ -1,4 +1,5 @@
disabled = 'Disabled'
+enabled = 'Enabled'
subtle_variation = 'Vary (Subtle)'
strong_variation = 'Vary (Strong)'
upscale_15 = 'Upscale (1.5x)'
diff --git a/modules/inpaint_worker.py b/modules/inpaint_worker.py
new file mode 100644
index 00000000..5ab4ab0d
--- /dev/null
+++ b/modules/inpaint_worker.py
@@ -0,0 +1,180 @@
+import numpy as np
+
+from PIL import Image, ImageFilter, ImageOps
+from modules.util import resample_image
+
+
+current_task = None
+
+
+def morphological_soft_open(x):
+ k = 12
+ x = Image.fromarray(x)
+ for _ in range(k):
+ x = x.filter(ImageFilter.MaxFilter(3))
+ x = x.filter(ImageFilter.BoxBlur(k * 2 + 1))
+ x = np.array(x)
+ return x
+
+
+def box_blur(x, k):
+ x = Image.fromarray(x)
+ x = x.filter(ImageFilter.BoxBlur(k))
+ return np.array(x)
+
+
+def threshold_0_255(x):
+ y = np.zeros_like(x)
+ y[x > 127] = 255
+ return y
+
+
+def morphological_hard_open(x):
+ y = threshold_0_255(x)
+ z = morphological_soft_open(x)
+ z[y > 127] = 255
+ return z
+
+
+def imsave(x, path):
+ x = Image.fromarray(x)
+ x.save(path)
+
+
+def regulate_abcd(x, a, b, c, d):
+ H, W = x.shape[:2]
+ if a < 0:
+ a = 0
+ if a > H:
+ a = H
+ if b < 0:
+ b = 0
+ if b > H:
+ b = H
+ if c < 0:
+ c = 0
+ if c > W:
+ c = W
+ if d < 0:
+ d = 0
+ if d > W:
+ d = W
+ return int(a), int(b), int(c), int(d)
+
+
+def compute_initial_abcd(x):
+ indices = np.where(x)
+ a = np.min(indices[0]) - 64
+ b = np.max(indices[0]) + 65
+ c = np.min(indices[1]) - 64
+ d = np.max(indices[1]) + 65
+ a, b, c, d = regulate_abcd(x, a, b, c, d)
+ return a, b, c, d
+
+
+def area_abcd(a, b, c, d):
+ return (b - a) * (d - c)
+
+
+def solve_abcd(x, a, b, c, d, k, outpaint):
+ H, W = x.shape[:2]
+ if outpaint:
+ return 0, H, 0, W
+ min_area = H * W * k
+ while area_abcd(a, b, c, d) < min_area:
+ if (b - a) < (d - c):
+ a -= 1
+ b += 1
+ else:
+ c -= 1
+ d += 1
+ a, b, c, d = regulate_abcd(x, a, b, c, d)
+ return a, b, c, d
+
+
+def fooocus_fill(image, mask):
+ current_image = image.copy()
+ raw_image = image.copy()
+ area = np.where(mask < 127)
+ store = raw_image[area]
+
+ for k, repeats in [(64, 4), (32, 4), (16, 4), (4, 4), (2, 4)]:
+ for _ in range(repeats):
+ current_image = box_blur(current_image, k)
+ current_image[area] = store
+
+ return current_image
+
+
+class InpaintWorker:
+ def __init__(self, image, mask, is_outpaint):
+ # mask processing
+ self.image_raw = fooocus_fill(image, mask)
+ self.mask_raw_user_input = mask
+ self.mask_raw_soft = morphological_hard_open(mask)
+ self.mask_raw_fg = (self.mask_raw_soft == 255).astype(np.uint8) * 255
+ self.mask_raw_bg = (self.mask_raw_soft == 0).astype(np.uint8) * 255
+ self.mask_raw_trim = 255 - np.maximum(self.mask_raw_fg, self.mask_raw_bg)
+ self.mask_raw_error = (self.mask_raw_user_input > self.mask_raw_fg).astype(np.uint8) * 255
+
+ # log all images
+ # imsave(self.mask_raw_user_input, 'mask_raw_user_input.png')
+ # imsave(self.mask_raw_soft, 'mask_raw_soft.png')
+ # imsave(self.mask_raw_fg, 'mask_raw_fg.png')
+ # imsave(self.mask_raw_bg, 'mask_raw_bg.png')
+ # imsave(self.mask_raw_trim, 'mask_raw_trim.png')
+ # imsave(self.mask_raw_error, 'mask_raw_error.png')
+
+ # compute abcd
+ a, b, c, d = compute_initial_abcd(self.mask_raw_bg < 127)
+ a, b, c, d = solve_abcd(self.mask_raw_bg, a, b, c, d, k=0.618, outpaint=is_outpaint)
+
+ # interested area
+ self.interested_area = (a, b, c, d)
+ self.mask_interested_soft = self.mask_raw_soft[a:b, c:d]
+ self.mask_interested_fg = self.mask_raw_fg[a:b, c:d]
+ self.mask_interested_bg = self.mask_raw_bg[a:b, c:d]
+ self.mask_interested_trim = self.mask_raw_trim[a:b, c:d]
+ self.image_interested = self.image_raw[a:b, c:d]
+
+ # resize to make images ready for diffusion
+ H, W, C = self.image_interested.shape
+ k = (1024.0 ** 2.0 / float(H * W)) ** 0.5
+ H = int(np.ceil(float(H) * k / 16.0)) * 16
+ W = int(np.ceil(float(W) * k / 16.0)) * 16
+ self.image_ready = resample_image(self.image_interested, W, H)
+ self.mask_ready = resample_image(self.mask_interested_soft, W, H)
+
+ # ending
+ self.latent = None
+ self.latent_mask = None
+ self.uc_guidance = None
+ return
+
+ def load_latent(self, latent, mask):
+ self.latent = latent
+ self.latent_mask = mask
+
+ def color_correction(self, img):
+ fg = img.astype(np.float32)
+ bg = self.image_raw.copy().astype(np.float32)
+ w = self.mask_raw_soft[:, :, None].astype(np.float32) / 255.0
+ y = fg * w + bg * (1 - w)
+ return y.clip(0, 255).astype(np.uint8)
+
+ def post_process(self, img):
+ a, b, c, d = self.interested_area
+ content = resample_image(img, d - c, b - a)
+ result = self.image_raw.copy()
+ result[a:b, c:d] = content
+ result = self.color_correction(result)
+ return result
+
+ def visualize_mask_processing(self):
+ result = self.image_raw // 4
+ a, b, c, d = self.interested_area
+ result[a:b, c:d] += 64
+ result[self.mask_raw_trim > 127] += 64
+ result[self.mask_raw_fg > 127] += 128
+ return [result, self.mask_raw_soft, self.image_ready, self.mask_ready]
+
diff --git a/modules/patch.py b/modules/patch.py
index 63a913a7..695cbe80 100644
--- a/modules/patch.py
+++ b/modules/patch.py
@@ -6,15 +6,20 @@ import comfy.k_diffusion.external
import comfy.model_management
import modules.anisotropic as anisotropic
import comfy.ldm.modules.attention
+import comfy.k_diffusion.sampling
import comfy.sd1_clip
+import modules.inpaint_worker as inpaint_worker
from comfy.k_diffusion import utils
+from comfy.k_diffusion.sampling import BrownianTreeNoiseSampler, trange
sharpness = 2.0
+negative_adm = True
cfg_x0 = 0.0
cfg_s = 1.0
+cfg_cin = 1.0
def cfg_patched(args):
@@ -37,14 +42,29 @@ def cfg_patched(args):
def patched_discrete_eps_ddpm_denoiser_forward(self, input, sigma, **kwargs):
- global cfg_x0, cfg_s
+ global cfg_x0, cfg_s, cfg_cin
c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
cfg_x0 = input
cfg_s = c_out
+ cfg_cin = c_in
return self.get_eps(input * c_in, self.sigma_to_t(sigma), **kwargs)
+def patched_model_function(func, args):
+ global cfg_cin
+ x = args['input']
+ t = args['timestep']
+ c = args['c']
+ is_uncond = torch.tensor(args['cond_or_uncond'])[:, None, None, None].to(x) * 5e-3
+ if inpaint_worker.current_task is not None:
+ p = inpaint_worker.current_task.uc_guidance * cfg_cin
+ x = p * is_uncond + x * (1 - is_uncond ** 2.0) ** 0.5
+ return func(x, t, **c)
+
+
def sdxl_encode_adm_patched(self, **kwargs):
+ global negative_adm
+
clip_pooled = kwargs["pooled_output"]
width = kwargs.get("width", 768)
height = kwargs.get("height", 768)
@@ -53,12 +73,13 @@ def sdxl_encode_adm_patched(self, **kwargs):
target_width = kwargs.get("target_width", width)
target_height = kwargs.get("target_height", height)
- if kwargs.get("prompt_type", "") == "negative":
- width *= 0.8
- height *= 0.8
- elif kwargs.get("prompt_type", "") == "positive":
- width *= 1.5
- height *= 1.5
+ if negative_adm:
+ if kwargs.get("prompt_type", "") == "negative":
+ width *= 0.8
+ height *= 0.8
+ elif kwargs.get("prompt_type", "") == "positive":
+ width *= 1.5
+ height *= 1.5
out = []
out.append(self.embedder(torch.Tensor([height])))
@@ -71,35 +92,6 @@ def sdxl_encode_adm_patched(self, **kwargs):
return torch.cat((clip_pooled.to(flat.device), flat), dim=1)
-def sdxl_refiner_encode_adm_patched(self, **kwargs):
- clip_pooled = kwargs["pooled_output"]
- width = kwargs.get("width", 768)
- height = kwargs.get("height", 768)
- crop_w = kwargs.get("crop_w", 0)
- crop_h = kwargs.get("crop_h", 0)
-
- if kwargs.get("prompt_type", "") == "negative":
- aesthetic_score = kwargs.get("aesthetic_score", 2.5)
- else:
- aesthetic_score = kwargs.get("aesthetic_score", 7.0)
-
- if kwargs.get("prompt_type", "") == "negative":
- width *= 0.8
- height *= 0.8
- elif kwargs.get("prompt_type", "") == "positive":
- width *= 1.5
- height *= 1.5
-
- out = []
- out.append(self.embedder(torch.Tensor([height])))
- out.append(self.embedder(torch.Tensor([width])))
- out.append(self.embedder(torch.Tensor([crop_h])))
- out.append(self.embedder(torch.Tensor([crop_w])))
- out.append(self.embedder(torch.Tensor([aesthetic_score])))
- flat = torch.flatten(torch.cat(out))[None,]
- return torch.cat((clip_pooled.to(flat.device), flat), dim=1)
-
-
def text_encoder_device_patched():
# Fooocus's style system uses text encoder much more times than comfy so this makes things much faster.
return comfy.model_management.get_torch_device()
@@ -138,15 +130,79 @@ def encode_token_weights_patched_with_a1111_method(self, token_weight_pairs):
return torch.cat(output, dim=-2).cpu(), first_pooled.cpu()
+@torch.no_grad()
+def sample_dpmpp_fooocus_2m_sde_inpaint_seamless(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, **kwargs):
+ sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
+ noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler
+
+ seed = extra_args.get("seed", None)
+ assert isinstance(seed, int)
+
+ energy_generator = torch.Generator(device='cpu')
+ energy_generator.manual_seed(seed + 1) # avoid bad results by using different seeds.
+
+ def get_energy():
+ return torch.randn(x.size(), dtype=x.dtype, generator=energy_generator, device="cpu").to(x)
+
+ sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
+ noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler
+ extra_args = {} if extra_args is None else extra_args
+ s_in = x.new_ones([x.shape[0]])
+
+ old_denoised, h_last, h = None, None, None
+
+ latent_processor = model.inner_model.inner_model.inner_model.process_latent_in
+ inpaint_latent = None
+ inpaint_mask = None
+
+ if inpaint_worker.current_task is not None:
+ inpaint_latent = latent_processor(inpaint_worker.current_task.latent).to(x)
+ inpaint_mask = inpaint_worker.current_task.latent_mask.to(x)
+
+ def blend_latent(a, b, w):
+ return a * w + b * (1 - w)
+
+ for i in trange(len(sigmas) - 1, disable=disable):
+ if inpaint_latent is None:
+ denoised = model(x, sigmas[i] * s_in, **extra_args)
+ else:
+ inpaint_worker.current_task.uc_guidance = x.detach().clone()
+ energy = get_energy() * sigmas[i] + inpaint_latent
+ x_prime = blend_latent(x, energy, inpaint_mask)
+ denoised = model(x_prime, sigmas[i] * s_in, **extra_args)
+ denoised = blend_latent(denoised, inpaint_latent, inpaint_mask)
+ if callback is not None:
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
+ if sigmas[i + 1] == 0:
+ x = denoised
+ else:
+ t, s = -sigmas[i].log(), -sigmas[i + 1].log()
+ h = s - t
+ eta_h = eta * h
+
+ x = sigmas[i + 1] / sigmas[i] * (-eta_h).exp() * x + (-h - eta_h).expm1().neg() * denoised
+ if old_denoised is not None:
+ r = h_last / h
+ x = x + 0.5 * (-h - eta_h).expm1().neg() * (1 / r) * (denoised - old_denoised)
+
+ x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (
+ -2 * eta_h).expm1().neg().sqrt() * s_noise
+
+ old_denoised = denoised
+ h_last = h
+
+ return x
+
+
def patch_all():
comfy.ldm.modules.attention.print = lambda x: None
+ comfy.k_diffusion.sampling.sample_dpmpp_fooocus_2m_sde_inpaint_seamless = sample_dpmpp_fooocus_2m_sde_inpaint_seamless
comfy.model_management.text_encoder_device = text_encoder_device_patched
print(f'Fooocus Text Processing Pipelines are retargeted to {str(comfy.model_management.text_encoder_device())}')
comfy.k_diffusion.external.DiscreteEpsDDPMDenoiser.forward = patched_discrete_eps_ddpm_denoiser_forward
comfy.model_base.SDXL.encode_adm = sdxl_encode_adm_patched
- # comfy.model_base.SDXLRefiner.encode_adm = sdxl_refiner_encode_adm_patched
comfy.sd1_clip.ClipTokenWeightEncoder.encode_token_weights = encode_token_weights_patched_with_a1111_method
return
diff --git a/modules/samplers_advanced.py b/modules/samplers_advanced.py
index 2d382f64..56850ebb 100644
--- a/modules/samplers_advanced.py
+++ b/modules/samplers_advanced.py
@@ -4,11 +4,209 @@ import comfy.model_management
import modules.virtual_memory
+class KSamplerBasic:
+ SCHEDULERS = ["normal", "karras", "exponential", "simple", "ddim_uniform"]
+ SAMPLERS = ["euler", "euler_ancestral", "heun", "dpm_2", "dpm_2_ancestral",
+ "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "dpmpp_sde_gpu",
+ "dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "ddim", "uni_pc", "uni_pc_bh2", "dpmpp_fooocus_2m_sde_inpaint_seamless"]
+
+ def __init__(self, model, steps, device, sampler=None, scheduler=None, denoise=None, model_options={}):
+ self.model = model
+ self.model_denoise = CFGNoisePredictor(self.model)
+ if self.model.model_type == model_base.ModelType.V_PREDICTION:
+ self.model_wrap = CompVisVDenoiser(self.model_denoise, quantize=True)
+ else:
+ self.model_wrap = k_diffusion_external.CompVisDenoiser(self.model_denoise, quantize=True)
+
+ self.model_k = KSamplerX0Inpaint(self.model_wrap)
+ self.device = device
+ if scheduler not in self.SCHEDULERS:
+ scheduler = self.SCHEDULERS[0]
+ if sampler not in self.SAMPLERS:
+ sampler = self.SAMPLERS[0]
+ self.scheduler = scheduler
+ self.sampler = sampler
+ self.sigma_min=float(self.model_wrap.sigma_min)
+ self.sigma_max=float(self.model_wrap.sigma_max)
+ self.set_steps(steps, denoise)
+ self.denoise = denoise
+ self.model_options = model_options
+
+ def calculate_sigmas(self, steps):
+ sigmas = None
+
+ discard_penultimate_sigma = False
+ if self.sampler in ['dpm_2', 'dpm_2_ancestral']:
+ steps += 1
+ discard_penultimate_sigma = True
+
+ if self.scheduler == "karras":
+ sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=self.sigma_min, sigma_max=self.sigma_max)
+ elif self.scheduler == "exponential":
+ sigmas = k_diffusion_sampling.get_sigmas_exponential(n=steps, sigma_min=self.sigma_min, sigma_max=self.sigma_max)
+ elif self.scheduler == "normal":
+ sigmas = self.model_wrap.get_sigmas(steps)
+ elif self.scheduler == "simple":
+ sigmas = simple_scheduler(self.model_wrap, steps)
+ elif self.scheduler == "ddim_uniform":
+ sigmas = ddim_scheduler(self.model_wrap, steps)
+ else:
+ print("error invalid scheduler", self.scheduler)
+
+ if discard_penultimate_sigma:
+ sigmas = torch.cat([sigmas[:-2], sigmas[-1:]])
+ return sigmas
+
+ def set_steps(self, steps, denoise=None):
+ self.steps = steps
+ if denoise is None or denoise > 0.9999:
+ self.sigmas = self.calculate_sigmas(steps).to(self.device)
+ else:
+ new_steps = int(steps/denoise)
+ sigmas = self.calculate_sigmas(new_steps).to(self.device)
+ self.sigmas = sigmas[-(steps + 1):]
+
+ def sample(self, noise, positive, negative, cfg, latent_image=None, start_step=None, last_step=None, force_full_denoise=False, denoise_mask=None, sigmas=None, callback=None, disable_pbar=False, seed=None):
+ if sigmas is None:
+ sigmas = self.sigmas
+ sigma_min = self.sigma_min
+
+ if last_step is not None and last_step < (len(sigmas) - 1):
+ sigma_min = sigmas[last_step]
+ sigmas = sigmas[:last_step + 1]
+ if force_full_denoise:
+ sigmas[-1] = 0
+
+ if start_step is not None:
+ if start_step < (len(sigmas) - 1):
+ sigmas = sigmas[start_step:]
+ else:
+ if latent_image is not None:
+ return latent_image
+ else:
+ return torch.zeros_like(noise)
+
+ positive = positive[:]
+ negative = negative[:]
+
+ resolve_cond_masks(positive, noise.shape[2], noise.shape[3], self.device)
+ resolve_cond_masks(negative, noise.shape[2], noise.shape[3], self.device)
+
+ calculate_start_end_timesteps(self.model_wrap, negative)
+ calculate_start_end_timesteps(self.model_wrap, positive)
+
+ #make sure each cond area has an opposite one with the same area
+ for c in positive:
+ create_cond_with_same_area_if_none(negative, c)
+ for c in negative:
+ create_cond_with_same_area_if_none(positive, c)
+
+ pre_run_control(self.model_wrap, negative + positive)
+
+ apply_empty_x_to_equal_area(list(filter(lambda c: c[1].get('control_apply_to_uncond', False) == True, positive)), negative, 'control', lambda cond_cnets, x: cond_cnets[x])
+ apply_empty_x_to_equal_area(positive, negative, 'gligen', lambda cond_cnets, x: cond_cnets[x])
+
+ if self.model.is_adm():
+ positive = encode_adm(self.model, positive, noise.shape[0], noise.shape[3], noise.shape[2], self.device, "positive")
+ negative = encode_adm(self.model, negative, noise.shape[0], noise.shape[3], noise.shape[2], self.device, "negative")
+
+ if latent_image is not None:
+ latent_image = self.model.process_latent_in(latent_image)
+
+ extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg, "model_options": self.model_options, "seed":seed}
+
+ cond_concat = None
+ if hasattr(self.model, 'concat_keys'): #inpaint
+ cond_concat = []
+ for ck in self.model.concat_keys:
+ if denoise_mask is not None:
+ if ck == "mask":
+ cond_concat.append(denoise_mask[:,:1])
+ elif ck == "masked_image":
+ cond_concat.append(latent_image) #NOTE: the latent_image should be masked by the mask in pixel space
+ else:
+ if ck == "mask":
+ cond_concat.append(torch.ones_like(noise)[:,:1])
+ elif ck == "masked_image":
+ cond_concat.append(blank_inpaint_image_like(noise))
+ extra_args["cond_concat"] = cond_concat
+
+ if sigmas[0] != self.sigmas[0] or (self.denoise is not None and self.denoise < 1.0):
+ max_denoise = False
+ else:
+ max_denoise = True
+
+
+ if self.sampler == "uni_pc":
+ samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, max_denoise=max_denoise, extra_args=extra_args, noise_mask=denoise_mask, callback=callback, disable=disable_pbar)
+ elif self.sampler == "uni_pc_bh2":
+ samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, max_denoise=max_denoise, extra_args=extra_args, noise_mask=denoise_mask, callback=callback, variant='bh2', disable=disable_pbar)
+ elif self.sampler == "ddim":
+ timesteps = []
+ for s in range(sigmas.shape[0]):
+ timesteps.insert(0, self.model_wrap.sigma_to_discrete_timestep(sigmas[s]))
+ noise_mask = None
+ if denoise_mask is not None:
+ noise_mask = 1.0 - denoise_mask
+
+ ddim_callback = None
+ if callback is not None:
+ total_steps = len(timesteps) - 1
+ ddim_callback = lambda pred_x0, i: callback(i, pred_x0, None, total_steps)
+
+ sampler = DDIMSampler(self.model, device=self.device)
+ sampler.make_schedule_timesteps(ddim_timesteps=timesteps, verbose=False)
+ z_enc = sampler.stochastic_encode(latent_image, torch.tensor([len(timesteps) - 1] * noise.shape[0]).to(self.device), noise=noise, max_denoise=max_denoise)
+ samples, _ = sampler.sample_custom(ddim_timesteps=timesteps,
+ conditioning=positive,
+ batch_size=noise.shape[0],
+ shape=noise.shape[1:],
+ verbose=False,
+ unconditional_guidance_scale=cfg,
+ unconditional_conditioning=negative,
+ eta=0.0,
+ x_T=z_enc,
+ x0=latent_image,
+ img_callback=ddim_callback,
+ denoise_function=self.model_wrap.predict_eps_discrete_timestep,
+ extra_args=extra_args,
+ mask=noise_mask,
+ to_zero=sigmas[-1]==0,
+ end_step=sigmas.shape[0] - 1,
+ disable_pbar=disable_pbar)
+
+ else:
+ extra_args["denoise_mask"] = denoise_mask
+ self.model_k.latent_image = latent_image
+ self.model_k.noise = noise
+
+ if max_denoise:
+ noise = noise * torch.sqrt(1.0 + sigmas[0] ** 2.0)
+ else:
+ noise = noise * sigmas[0]
+
+ k_callback = None
+ total_steps = len(sigmas) - 1
+ if callback is not None:
+ k_callback = lambda x: callback(x["i"], x["denoised"], x["x"], total_steps)
+
+ if latent_image is not None:
+ noise += latent_image
+ if self.sampler == "dpm_fast":
+ samples = k_diffusion_sampling.sample_dpm_fast(self.model_k, noise, sigma_min, sigmas[0], total_steps, extra_args=extra_args, callback=k_callback, disable=disable_pbar)
+ elif self.sampler == "dpm_adaptive":
+ samples = k_diffusion_sampling.sample_dpm_adaptive(self.model_k, noise, sigma_min, sigmas[0], extra_args=extra_args, callback=k_callback, disable=disable_pbar)
+ else:
+ samples = getattr(k_diffusion_sampling, "sample_{}".format(self.sampler))(self.model_k, noise, sigmas, extra_args=extra_args, callback=k_callback, disable=disable_pbar)
+
+ return self.model.process_latent_out(samples.to(torch.float32))
+
+
class KSamplerWithRefiner:
SCHEDULERS = ["normal", "karras", "exponential", "simple", "ddim_uniform"]
SAMPLERS = ["euler", "euler_ancestral", "heun", "dpm_2", "dpm_2_ancestral",
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "dpmpp_sde_gpu",
- "dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "ddim", "uni_pc", "uni_pc_bh2"]
+ "dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "ddim", "uni_pc", "uni_pc_bh2", "dpmpp_fooocus_2m_sde_inpaint_seamless"]
def __init__(self, model, refiner_model, steps, device, sampler=None, scheduler=None, denoise=None, model_options={}):
self.model_patcher = model
diff --git a/modules/util.py b/modules/util.py
index d8b8e1e5..c8a4f13a 100644
--- a/modules/util.py
+++ b/modules/util.py
@@ -28,6 +28,12 @@ def image_is_generated_in_current_ui(image, ui_width, ui_height):
LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
+def resample_image(im, width, height):
+ im = Image.fromarray(im)
+ im = im.resize((width, height), resample=LANCZOS)
+ return np.array(im)
+
+
def resize_image(im, width, height, resize_mode=1):
"""
Resizes an image with the specified resize_mode, width, and height.
diff --git a/webui.py b/webui.py
index a36a0ea7..bab52ed3 100644
--- a/webui.py
+++ b/webui.py
@@ -61,14 +61,31 @@ with shared.gradio_root:
input_image_checkbox = gr.Checkbox(label='Input Image', value=False, container=False, elem_classes='min_check')
advanced_checkbox = gr.Checkbox(label='Advanced', value=False, container=False, elem_classes='min_check')
with gr.Row(visible=False) as image_input_panel:
- with gr.Column(scale=0.5):
- with gr.Accordion(label='Upscale or Variation', open=True):
- uov_input_image = gr.Image(label='Drag above image to here', source='upload', type='numpy')
- uov_method = gr.Radio(label='Method', choices=flags.uov_list, value=flags.disabled, show_label=False, container=False)
- gr.HTML('\U0001F4D4 Document')
+ with gr.Tabs():
+ with gr.TabItem(label='Upscale or Variation') as uov_tab:
+ with gr.Row():
+ with gr.Column():
+ uov_input_image = gr.Image(label='Drag above image to here', source='upload', type='numpy')
+ with gr.Column():
+ uov_method = gr.Radio(label='Upscale or Variation:', choices=flags.uov_list, value=flags.disabled)
+ gr.HTML('\U0001F4D4 Document')
+ with gr.TabItem(label='Inpaint or Outpaint') as inpaint_tab:
+ inpaint_input_image = gr.Image(label='Drag above image to here', source='upload', type='numpy', tool='sketch', height=500, brush_color="#FFFFFF")
+ gr.HTML('Outpaint Expansion (\U0001F4D4 Document):')
+ outpaint_selections = gr.CheckboxGroup(choices=['Left', 'Right', 'Top', 'Bottom'], value=[], label='Outpaint', show_label=False, container=False)
+ gr.HTML('* \"Inpaint or Outpaint\" is powered by the sampler \"DPMPP Fooocus Seamless 2M SDE Karras Inpaint Sampler\" (beta)')
+
input_image_checkbox.change(lambda x: gr.update(visible=x), inputs=input_image_checkbox, outputs=image_input_panel, queue=False,
_js="(x) => {if(x){setTimeout(() => window.scrollTo({ top: window.scrollY + 500, behavior: 'smooth' }), 50);}else{setTimeout(() => window.scrollTo({ top: 0, behavior: 'smooth' }), 50);} return x}")
+ current_tab = gr.Textbox(value='uov', visible=False)
+ uov_tab.select(lambda: 'uov', outputs=current_tab, queue=False)
+ inpaint_tab.select(lambda: 'inpaint', outputs=current_tab, queue=False)
+
+ uov_input_image.upload(lambda x: x, inputs=[uov_input_image], outputs=[inpaint_input_image])
+ inpaint_input_image.upload(lambda: None).\
+ then(lambda x: x['image'], inputs=[inpaint_input_image], outputs=[uov_input_image])
+
# def get_select_index(g, evt: gr.SelectData):
# return g[evt.index]['name']
# gallery.select(get_select_index, gallery, uov_input_image)
@@ -132,8 +149,9 @@ with shared.gradio_root:
performance_selction, aspect_ratios_selction, image_number, image_seed, sharpness
]
ctrls += [base_model, refiner_model] + lora_ctrls
- ctrls += [input_image_checkbox]
+ ctrls += [input_image_checkbox, current_tab]
ctrls += [uov_method, uov_input_image]
+ ctrls += [outpaint_selections, inpaint_input_image]
run_button.click(lambda: (gr.update(visible=True, interactive=True), gr.update(visible=False), []), outputs=[stop_button, run_button, gallery])\
.then(fn=refresh_seed, inputs=[seed_random, image_seed], outputs=image_seed)\