From 0927445492a1ce0ee4a9670282966cb7a6166cdf Mon Sep 17 00:00:00 2001 From: lllyasviel Date: Tue, 19 Sep 2023 04:52:22 -0700 Subject: [PATCH] use fooocus inpaint control model (#429) use fooocus inpaint control model (#429) --- .gitignore | 1 + fooocus_version.py | 2 +- modules/async_worker.py | 16 ++++ modules/core.py | 36 +++++++- modules/default_pipeline.py | 10 +- modules/inpaint_worker.py | 51 +++++++++- modules/patch.py | 179 +++++++++++++++++++++++++++++++++++- modules/path.py | 24 ++++- update_log.md | 4 + 9 files changed, 304 insertions(+), 19 deletions(-) diff --git a/.gitignore b/.gitignore index fb1bffdd..29c99edb 100644 --- a/.gitignore +++ b/.gitignore @@ -3,6 +3,7 @@ __pycache__ *.safetensors *.pth *.bin +*.patch lena.png lena_result.png lena_test.py diff --git a/fooocus_version.py b/fooocus_version.py index 382ba9ba..e1274470 100644 --- a/fooocus_version.py +++ b/fooocus_version.py @@ -1 +1 @@ -version = '2.0.62' +version = '2.0.65' diff --git a/modules/async_worker.py b/modules/async_worker.py index ec001e92..0ff66833 100644 --- a/modules/async_worker.py +++ b/modules/async_worker.py @@ -195,6 +195,10 @@ def worker(): # outputs.append(['results', inpaint_worker.current_task.visualize_mask_processing()]) # return + progressbar(0, 'Downloading inpainter ...') + inpaint_head_model_path, inpaint_patch_model_path = modules.path.downloading_inpaint_models() + loras += [(inpaint_patch_model_path, 1.0)] + inpaint_pixels = core.numpy_to_pytorch(inpaint_worker.current_task.image_ready) progressbar(0, 'VAE encoding ...') initial_latent = core.encode_vae(vae=pipeline.xl_base_patched.vae, pixels=inpaint_pixels) @@ -207,6 +211,18 @@ def worker(): height = H * 8 inpaint_worker.current_task.load_latent(latent=inpaint_latent, mask=inpaint_mask) + progressbar(0, 'VAE inpaint encoding ...') + + inpaint_mask = (inpaint_worker.current_task.mask_ready > 0).astype(np.float32) + inpaint_mask = torch.tensor(inpaint_mask).float() + + vae_dict = core.encode_vae_inpaint( + mask=inpaint_mask, vae=pipeline.xl_base_patched.vae, pixels=inpaint_pixels) + + inpaint_latent = vae_dict['samples'] + inpaint_mask = vae_dict['noise_mask'] + inpaint_worker.current_task.load_inpaint_guidance(latent=inpaint_latent, mask=inpaint_mask, model_path=inpaint_head_model_path) + progressbar(1, 'Initializing ...') raw_prompt = prompt diff --git a/modules/core.py b/modules/core.py index cd3c68df..e57a81a7 100644 --- a/modules/core.py +++ b/modules/core.py @@ -8,9 +8,10 @@ import comfy.model_management import comfy.utils from comfy.sd import load_checkpoint_guess_config -from nodes import VAEDecode, EmptyLatentImage, VAEEncode, VAEEncodeTiled, VAEDecodeTiled +from nodes import VAEDecode, EmptyLatentImage, VAEEncode, VAEEncodeTiled, VAEDecodeTiled, VAEEncodeForInpaint from comfy.sample import prepare_mask, broadcast_cond, load_additional_models, cleanup_additional_models from comfy.model_base import SDXLRefiner +from comfy.sd import model_lora_keys_unet, model_lora_keys_clip, load_lora from modules.samplers_advanced import KSamplerBasic, KSamplerWithRefiner from modules.patch import patch_all @@ -21,6 +22,7 @@ opVAEDecode = VAEDecode() opVAEEncode = VAEEncode() opVAEDecodeTiled = VAEDecodeTiled() opVAEEncodeTiled = VAEEncodeTiled() +opVAEEncodeForInpaint = VAEEncodeForInpaint() class StableDiffusionModel: @@ -56,12 +58,32 @@ def load_model(ckpt_filename): @torch.no_grad() @torch.inference_mode() -def load_lora(model, lora_filename, strength_model=1.0, strength_clip=1.0): +def load_sd_lora(model, lora_filename, 
strength_model=1.0, strength_clip=1.0):
     if strength_model == 0 and strength_clip == 0:
         return model
 
-    lora = comfy.utils.load_torch_file(lora_filename, safe_load=True)
-    unet, clip = comfy.sd.load_lora_for_models(model.unet, model.clip, lora, strength_model, strength_clip)
+    if lora_filename.lower().endswith('.fooocus.patch'):
+        # The inpaint patch is a plain torch pickle, so it cannot be
+        # safe-loaded; its keys already target model weights directly and
+        # need no LoRA key mapping.
+        loaded = comfy.utils.load_torch_file(lora_filename, safe_load=False)
+    else:
+        lora = comfy.utils.load_torch_file(lora_filename, safe_load=True)
+        key_map = model_lora_keys_unet(model.unet.model)
+        key_map = model_lora_keys_clip(model.clip.cond_stage_model, key_map)
+        loaded = load_lora(lora, key_map)
+
+    new_modelpatcher = model.unet.clone()
+    k = new_modelpatcher.add_patches(loaded, strength_model)
+
+    new_clip = model.clip.clone()
+    k1 = new_clip.add_patches(loaded, strength_clip)
+
+    k = set(k)
+    k1 = set(k1)
+    for x in loaded:
+        if (x not in k) and (x not in k1):
+            print('LoRA key not loaded: ', x)
+
+    unet, clip = new_modelpatcher, new_clip
     return StableDiffusionModel(unet=unet, clip=clip, vae=model.vae, clip_vision=model.clip_vision)
 
 
@@ -83,6 +105,12 @@ def encode_vae(vae, pixels, tiled=False):
     return (opVAEEncodeTiled if tiled else opVAEEncode).encode(pixels=pixels, vae=vae)[0]
 
 
+@torch.no_grad()
+@torch.inference_mode()
+def encode_vae_inpaint(vae, pixels, mask):
+    return opVAEEncodeForInpaint.encode(pixels=pixels, vae=vae, mask=mask)[0]
+
+
 class VAEApprox(torch.nn.Module):
     def __init__(self):
         super(VAEApprox, self).__init__()
diff --git a/modules/default_pipeline.py b/modules/default_pipeline.py
index ed45d42c..f19e9f76 100644
--- a/modules/default_pipeline.py
+++ b/modules/default_pipeline.py
@@ -101,8 +101,14 @@ def refresh_loras(loras):
         if name == 'None':
             continue
 
-        filename = os.path.join(modules.path.lorafile_path, name)
-        model = core.load_lora(model, filename, strength_model=weight, strength_clip=weight)
+        if os.path.exists(name):
+            filename = name
+        else:
+            filename = os.path.join(modules.path.lorafile_path, name)
+
+        assert os.path.exists(filename), f'LoRA file not found: {filename}'
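+        # `name` may be an absolute path (the auto-downloaded inpaint patch is
+        # passed this way from async_worker) or a bare filename resolved against
+        # the LoRA folder; load_sd_lora then dispatches on its file extension.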
+ + model = core.load_sd_lora(model, filename, strength_model=weight, strength_clip=weight) xl_base_patched = model xl_base_patched_hash = str(loras) print(f'LoRAs loaded: {xl_base_patched_hash}') diff --git a/modules/inpaint_worker.py b/modules/inpaint_worker.py index 5ab4ab0d..480e6358 100644 --- a/modules/inpaint_worker.py +++ b/modules/inpaint_worker.py @@ -1,7 +1,25 @@ -import numpy as np +import os.path -from PIL import Image, ImageFilter, ImageOps +import torch +import numpy as np +import modules.default_pipeline as pipeline + +from PIL import Image, ImageFilter from modules.util import resample_image +from modules.path import inpaint_models_path + + +inpaint_head = None + + +class InpaintHead(torch.nn.Module): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.head = torch.nn.Parameter(torch.empty(size=(320, 5, 3, 3), device='cpu')) + + def __call__(self, x): + x = torch.nn.functional.pad(x, (1, 1, 1, 1), "replicate") + return torch.nn.functional.conv2d(input=x, weight=self.head) current_task = None @@ -81,7 +99,12 @@ def solve_abcd(x, a, b, c, d, k, outpaint): if outpaint: return 0, H, 0, W min_area = H * W * k - while area_abcd(a, b, c, d) < min_area: + max_area = H * W + while True: + if area_abcd(a, b, c, d) > min_area and abs((b - a) - (d - c)) < 16: + break + if area_abcd(a, b, c, d) >= max_area: + break if (b - a) < (d - c): a -= 1 b += 1 @@ -148,7 +171,27 @@ class InpaintWorker: # ending self.latent = None self.latent_mask = None - self.uc_guidance = None + self.inpaint_head_feature = None + return + + def load_inpaint_guidance(self, latent, mask, model_path): + global inpaint_head + if inpaint_head is None: + inpaint_head = InpaintHead() + sd = torch.load(model_path, map_location='cpu') + inpaint_head.load_state_dict(sd) + process_latent_in = pipeline.xl_base_patched.unet.model.process_latent_in + + latent = process_latent_in(latent) + B, C, H, W = latent.shape + + mask = torch.nn.functional.interpolate(mask, size=(H, W), mode="bilinear") + mask = mask.round() + + feed = torch.cat([mask, latent], dim=1) + + inpaint_head.to(device=feed.device, dtype=feed.dtype) + self.inpaint_head_feature = inpaint_head(feed) return def load_latent(self, latent, mask): diff --git a/modules/patch.py b/modules/patch.py index 695cbe80..ceba5e30 100644 --- a/modules/patch.py +++ b/modules/patch.py @@ -9,9 +9,12 @@ import comfy.ldm.modules.attention import comfy.k_diffusion.sampling import comfy.sd1_clip import modules.inpaint_worker as inpaint_worker +import comfy.ldm.modules.diffusionmodules.openaimodel +import comfy.sd from comfy.k_diffusion import utils from comfy.k_diffusion.sampling import BrownianTreeNoiseSampler, trange +from comfy.ldm.modules.diffusionmodules.openaimodel import timestep_embedding, forward_timestep_embed sharpness = 2.0 @@ -22,6 +25,112 @@ cfg_s = 1.0 cfg_cin = 1.0 +def calculate_weight_patched(self, patches, weight, key): + for p in patches: + alpha = p[0] + v = p[1] + strength_model = p[2] + + if strength_model != 1.0: + weight *= strength_model + + if isinstance(v, list): + v = (self.calculate_weight(v[1:], v[0].clone(), key),) + + if len(v) == 1: + w1 = v[0] + if alpha != 0.0: + if w1.shape != weight.shape: + print("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape)) + else: + weight += alpha * w1.type(weight.dtype).to(weight.device) + elif len(v) == 3: + # fooocus + w1 = v[0].float() + w_min = v[1].float() + w_max = v[2].float() + w1 = (w1 / 255.0) * (w_max - w_min) + w_min + if alpha != 0.0: + if 
w1.shape != weight.shape: + print("WARNING SHAPE MISMATCH {} FOOOCUS WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape)) + else: + weight += alpha * w1.type(weight.dtype).to(weight.device) + elif len(v) == 4: # lora/locon + mat1 = v[0].float().to(weight.device) + mat2 = v[1].float().to(weight.device) + if v[2] is not None: + alpha *= v[2] / mat2.shape[0] + if v[3] is not None: + mat3 = v[3].float().to(weight.device) + final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]] + mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), + mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1) + try: + weight += (alpha * torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1))).reshape( + weight.shape).type(weight.dtype) + except Exception as e: + print("ERROR", key, e) + elif len(v) == 8: # lokr + w1 = v[0] + w2 = v[1] + w1_a = v[3] + w1_b = v[4] + w2_a = v[5] + w2_b = v[6] + t2 = v[7] + dim = None + + if w1 is None: + dim = w1_b.shape[0] + w1 = torch.mm(w1_a.float(), w1_b.float()) + else: + w1 = w1.float().to(weight.device) + + if w2 is None: + dim = w2_b.shape[0] + if t2 is None: + w2 = torch.mm(w2_a.float().to(weight.device), w2_b.float().to(weight.device)) + else: + w2 = torch.einsum('i j k l, j r, i p -> p r k l', t2.float().to(weight.device), + w2_b.float().to(weight.device), w2_a.float().to(weight.device)) + else: + w2 = w2.float().to(weight.device) + + if len(w2.shape) == 4: + w1 = w1.unsqueeze(2).unsqueeze(2) + if v[2] is not None and dim is not None: + alpha *= v[2] / dim + + try: + weight += alpha * torch.kron(w1, w2).reshape(weight.shape).type(weight.dtype) + except Exception as e: + print("ERROR", key, e) + else: # loha + w1a = v[0] + w1b = v[1] + if v[2] is not None: + alpha *= v[2] / w1b.shape[0] + w2a = v[3] + w2b = v[4] + if v[5] is not None: # cp decomposition + t1 = v[5] + t2 = v[6] + m1 = torch.einsum('i j k l, j r, i p -> p r k l', t1.float().to(weight.device), + w1b.float().to(weight.device), w1a.float().to(weight.device)) + m2 = torch.einsum('i j k l, j r, i p -> p r k l', t2.float().to(weight.device), + w2b.float().to(weight.device), w2a.float().to(weight.device)) + else: + m1 = torch.mm(w1a.float().to(weight.device), w1b.float().to(weight.device)) + m2 = torch.mm(w2a.float().to(weight.device), w2b.float().to(weight.device)) + + try: + weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype) + except Exception as e: + print("ERROR", key, e) + + return weight + + def cfg_patched(args): global cfg_x0, cfg_s positive_eps = args['cond'].clone() @@ -55,10 +164,7 @@ def patched_model_function(func, args): x = args['input'] t = args['timestep'] c = args['c'] - is_uncond = torch.tensor(args['cond_or_uncond'])[:, None, None, None].to(x) * 5e-3 - if inpaint_worker.current_task is not None: - p = inpaint_worker.current_task.uc_guidance * cfg_cin - x = p * is_uncond + x * (1 - is_uncond ** 2.0) ** 0.5 + # is_uncond = torch.tensor(args['cond_or_uncond'])[:, None, None, None].to(x) * 5e-3 return func(x, t, **c) @@ -166,7 +272,6 @@ def sample_dpmpp_fooocus_2m_sde_inpaint_seamless(model, x, sigmas, extra_args=No if inpaint_latent is None: denoised = model(x, sigmas[i] * s_in, **extra_args) else: - inpaint_worker.current_task.uc_guidance = x.detach().clone() energy = get_energy() * sigmas[i] + inpaint_latent x_prime = blend_latent(x, energy, inpaint_mask) denoised = model(x_prime, sigmas[i] * s_in, **extra_args) @@ -194,7 +299,71 @@ def sample_dpmpp_fooocus_2m_sde_inpaint_seamless(model, x, sigmas, extra_args=No 
return x +def patched_unet_forward(self, x, timesteps=None, context=None, y=None, control=None, transformer_options={}, **kwargs): + inpaint_fix = None + if inpaint_worker.current_task is not None: + inpaint_fix = inpaint_worker.current_task.inpaint_head_feature + + transformer_options["original_shape"] = list(x.shape) + transformer_options["current_index"] = 0 + + assert (y is not None) == ( + self.num_classes is not None + ), "must specify y if and only if the model is class-conditional" + hs = [] + t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(self.dtype) + emb = self.time_embed(t_emb) + + if self.num_classes is not None: + assert y.shape[0] == x.shape[0] + emb = emb + self.label_emb(y) + + h = x.type(self.dtype) + for id, module in enumerate(self.input_blocks): + transformer_options["block"] = ("input", id) + h = forward_timestep_embed(module, h, emb, context, transformer_options) + + if inpaint_fix is not None: + if int(h.shape[1]) == int(inpaint_fix.shape[1]): + h = h + inpaint_fix.to(h) + inpaint_fix = None + + if control is not None and 'input' in control and len(control['input']) > 0: + ctrl = control['input'].pop() + if ctrl is not None: + h += ctrl + hs.append(h) + transformer_options["block"] = ("middle", 0) + h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options) + if control is not None and 'middle' in control and len(control['middle']) > 0: + h += control['middle'].pop() + + for id, module in enumerate(self.output_blocks): + transformer_options["block"] = ("output", id) + hsp = hs.pop() + if control is not None and 'output' in control and len(control['output']) > 0: + ctrl = control['output'].pop() + if ctrl is not None: + hsp += ctrl + + h = torch.cat([h, hsp], dim=1) + del hsp + if len(hs) > 0: + output_shape = hs[-1].shape + else: + output_shape = None + h = forward_timestep_embed(module, h, emb, context, transformer_options, output_shape) + h = h.type(x.dtype) + if self.predict_codebook_ids: + return self.id_predictor(h) + else: + return self.out(h) + + def patch_all(): + comfy.sd.ModelPatcher.calculate_weight = calculate_weight_patched + comfy.ldm.modules.diffusionmodules.openaimodel.UNetModel.forward = patched_unet_forward + comfy.ldm.modules.attention.print = lambda x: None comfy.k_diffusion.sampling.sample_dpmpp_fooocus_2m_sde_inpaint_seamless = sample_dpmpp_fooocus_2m_sde_inpaint_seamless diff --git a/modules/path.py b/modules/path.py index b2af0896..f3ad35a5 100644 --- a/modules/path.py +++ b/modules/path.py @@ -1,9 +1,12 @@ import os +from modules.model_loader import load_file_from_url + modelfile_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '../models/checkpoints/')) lorafile_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '../models/loras/')) vae_approx_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '../models/vae_approx/')) upscale_models_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '../models/upscale_models/')) +inpaint_models_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '../models/inpaint/')) temp_outputs_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '../outputs/')) fooocus_expansion_path = os.path.abspath(os.path.join(os.path.dirname(__file__), @@ -27,9 +30,10 @@ def get_model_filenames(folder_path): filenames = [] for filename in os.listdir(folder_path): if os.path.isfile(os.path.join(folder_path, filename)): - _, file_extension = os.path.splitext(filename) - if file_extension.lower() in 
['.pth', '.ckpt', '.bin', '.safetensors']: - filenames.append(filename) + for ends in ['.pth', '.ckpt', '.bin', '.safetensors', '.fooocus.patch']: + if filename.lower().endswith(ends): + filenames.append(filename) + break return filenames @@ -41,4 +45,18 @@ def update_all_model_names(): return +def downloading_inpaint_models(): + load_file_from_url( + url='https://huggingface.co/lllyasviel/fooocus_inpaint/resolve/main/fooocus_inpaint_head.pth', + model_dir=inpaint_models_path, + file_name='fooocus_inpaint_head.pth' + ) + load_file_from_url( + url='https://huggingface.co/lllyasviel/fooocus_inpaint/resolve/main/inpaint.fooocus.patch', + model_dir=inpaint_models_path, + file_name='inpaint.fooocus.patch' + ) + return os.path.join(inpaint_models_path, 'fooocus_inpaint_head.pth'), os.path.join(inpaint_models_path, 'inpaint.fooocus.patch') + + update_all_model_names() diff --git a/update_log.md b/update_log.md index 9dd294d2..16c2a02b 100644 --- a/update_log.md +++ b/update_log.md @@ -1,3 +1,7 @@ +### 2.0.65 + +* Inpaint model released. + ### 2.0.50 * Variation/Upscale (Midjourney Toolbar) implemented.
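-- 

Note on the '.fooocus.patch' format: the file fetched by
downloading_inpaint_models() is an ordinary torch state dict whose entries are
merged by the new `len(v) == 3` branch of calculate_weight_patched. Below is a
minimal sketch of that merge step (the helper name is hypothetical; the branch
implies each entry stores a tensor quantized to [0, 255] together with its
original value range w_min..w_max):

    import torch

    def merge_fooocus_patch(weight, w1, w_min, w_max, alpha=1.0):
        # Dequantize: map the stored [0, 255] values back onto [w_min, w_max].
        w1 = (w1.float() / 255.0) * (w_max.float() - w_min.float()) + w_min.float()
        # Add the decoded offset onto the base parameter, as the patched
        # branch does for every key whose shapes match.
        return weight + alpha * w1.type(weight.dtype).to(weight.device)

    # Hypothetical usage on a generic parameter tensor:
    w = torch.zeros(64, 64)
    q = torch.randint(0, 256, (64, 64))
    merged = merge_fooocus_patch(w, q, torch.tensor(-0.05), torch.tensor(0.05))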