use fooocus inpaint control model (#429)

parent 6ae7de377d
commit 0927445492
.gitignore
@@ -3,6 +3,7 @@ __pycache__
 *.safetensors
 *.pth
 *.bin
+*.patch
 lena.png
 lena_result.png
 lena_test.py
fooocus_version.py
@@ -1 +1 @@
-version = '2.0.62'
+version = '2.0.65'
modules/async_worker.py
@@ -195,6 +195,10 @@ def worker():
 # outputs.append(['results', inpaint_worker.current_task.visualize_mask_processing()])
 # return

+progressbar(0, 'Downloading inpainter ...')
+inpaint_head_model_path, inpaint_patch_model_path = modules.path.downloading_inpaint_models()
+loras += [(inpaint_patch_model_path, 1.0)]
+
 inpaint_pixels = core.numpy_to_pytorch(inpaint_worker.current_task.image_ready)
 progressbar(0, 'VAE encoding ...')
 initial_latent = core.encode_vae(vae=pipeline.xl_base_patched.vae, pixels=inpaint_pixels)
@@ -207,6 +211,18 @@ def worker():
 height = H * 8
 inpaint_worker.current_task.load_latent(latent=inpaint_latent, mask=inpaint_mask)

+progressbar(0, 'VAE inpaint encoding ...')
+
+inpaint_mask = (inpaint_worker.current_task.mask_ready > 0).astype(np.float32)
+inpaint_mask = torch.tensor(inpaint_mask).float()
+
+vae_dict = core.encode_vae_inpaint(
+    mask=inpaint_mask, vae=pipeline.xl_base_patched.vae, pixels=inpaint_pixels)
+
+inpaint_latent = vae_dict['samples']
+inpaint_mask = vae_dict['noise_mask']
+inpaint_worker.current_task.load_inpaint_guidance(latent=inpaint_latent, mask=inpaint_mask, model_path=inpaint_head_model_path)
+
 progressbar(1, 'Initializing ...')

 raw_prompt = prompt
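The new worker code binarizes the task's pixel-space mask before the VAE inpaint encode. A standalone sketch of that preparation (not part of the diff; the 1024x1024 uint8 mask is a hypothetical stand-in for inpaint_worker.current_task.mask_ready):

import numpy as np
import torch

# Hypothetical stand-in for inpaint_worker.current_task.mask_ready.
mask_ready = np.zeros((1024, 1024), dtype=np.uint8)
mask_ready[256:768, 256:768] = 255  # region to inpaint

inpaint_mask = (mask_ready > 0).astype(np.float32)  # binarize to {0.0, 1.0}
inpaint_mask = torch.tensor(inpaint_mask).float()
print(inpaint_mask.shape, inpaint_mask.unique())    # [1024, 1024], values 0. and 1.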
modules/core.py
@@ -8,9 +8,10 @@ import comfy.model_management
 import comfy.utils

 from comfy.sd import load_checkpoint_guess_config
-from nodes import VAEDecode, EmptyLatentImage, VAEEncode, VAEEncodeTiled, VAEDecodeTiled
+from nodes import VAEDecode, EmptyLatentImage, VAEEncode, VAEEncodeTiled, VAEDecodeTiled, VAEEncodeForInpaint
 from comfy.sample import prepare_mask, broadcast_cond, load_additional_models, cleanup_additional_models
 from comfy.model_base import SDXLRefiner
+from comfy.sd import model_lora_keys_unet, model_lora_keys_clip, load_lora
 from modules.samplers_advanced import KSamplerBasic, KSamplerWithRefiner
 from modules.patch import patch_all
@@ -21,6 +22,7 @@ opVAEDecode = VAEDecode()
 opVAEEncode = VAEEncode()
 opVAEDecodeTiled = VAEDecodeTiled()
 opVAEEncodeTiled = VAEEncodeTiled()
+opVAEEncodeForInpaint = VAEEncodeForInpaint()


 class StableDiffusionModel:
@@ -56,12 +58,32 @@ def load_model(ckpt_filename):

 @torch.no_grad()
 @torch.inference_mode()
-def load_lora(model, lora_filename, strength_model=1.0, strength_clip=1.0):
+def load_sd_lora(model, lora_filename, strength_model=1.0, strength_clip=1.0):
     if strength_model == 0 and strength_clip == 0:
         return model

-    lora = comfy.utils.load_torch_file(lora_filename, safe_load=True)
-    unet, clip = comfy.sd.load_lora_for_models(model.unet, model.clip, lora, strength_model, strength_clip)
+    lora = comfy.utils.load_torch_file(lora_filename, safe_load=False)
+
+    if lora_filename.lower().endswith('.fooocus.patch'):
+        loaded = lora
+    else:
+        key_map = model_lora_keys_unet(model.unet.model)
+        key_map = model_lora_keys_clip(model.clip.cond_stage_model, key_map)
+        loaded = load_lora(lora, key_map)
+
+    new_modelpatcher = model.unet.clone()
+    k = new_modelpatcher.add_patches(loaded, strength_model)
+
+    new_clip = model.clip.clone()
+    k1 = new_clip.add_patches(loaded, strength_clip)
+
+    k = set(k)
+    k1 = set(k1)
+    for x in loaded:
+        if (x not in k) and (x not in k1):
+            print("Lora missed: ", x)
+
+    unet, clip = new_modelpatcher, new_clip
     return StableDiffusionModel(unet=unet, clip=clip, vae=model.vae, clip_vision=model.clip_vision)
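The renamed load_sd_lora treats a '.fooocus.patch' state dict as ready-made patches and remaps everything else through the usual LoRA key maps; keys that neither patcher consumes are reported. A standalone sketch of that bookkeeping (the dicts below are hypothetical stand-ins for add_patches() return values):

loaded = {'unet.block1': (0.5,), 'clip.layer3': (0.5,), 'orphan.key': (0.5,)}
k = set(['unet.block1'])   # keys the UNet patcher accepted
k1 = set(['clip.layer3'])  # keys the CLIP patcher accepted
for x in loaded:
    if (x not in k) and (x not in k1):
        print("Lora missed: ", x)  # -> orphan.key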
@@ -83,6 +105,12 @@ def encode_vae(vae, pixels, tiled=False):
     return (opVAEEncodeTiled if tiled else opVAEEncode).encode(pixels=pixels, vae=vae)[0]


+@torch.no_grad()
+@torch.inference_mode()
+def encode_vae_inpaint(vae, pixels, mask):
+    return opVAEEncodeForInpaint.encode(pixels=pixels, vae=vae, mask=mask)[0]
+
+
 class VAEApprox(torch.nn.Module):
     def __init__(self):
         super(VAEApprox, self).__init__()
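encode_vae_inpaint wraps ComfyUI's VAEEncodeForInpaint node, which returns a latent dict rather than a bare tensor. A sketch of what the worker unpacks from it (the shapes are assumptions for a 1024px image, not taken from the diff):

import torch

vae_dict = {                                     # stand-in for encode_vae_inpaint(...)
    'samples': torch.zeros(1, 4, 128, 128),      # latent of the masked image
    'noise_mask': torch.ones(1, 1, 1024, 1024),  # mask the sampler should respect
}
inpaint_latent = vae_dict['samples']
inpaint_mask = vae_dict['noise_mask']
print(inpaint_latent.shape, inpaint_mask.shape)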
modules/default_pipeline.py
@@ -101,8 +101,14 @@ def refresh_loras(loras):
         if name == 'None':
             continue

-        filename = os.path.join(modules.path.lorafile_path, name)
-        model = core.load_lora(model, filename, strength_model=weight, strength_clip=weight)
+        if os.path.exists(name):
+            filename = name
+        else:
+            filename = os.path.join(modules.path.lorafile_path, name)
+
+        assert os.path.exists(filename), 'Lora file not found!'
+
+        model = core.load_sd_lora(model, filename, strength_model=weight, strength_clip=weight)
     xl_base_patched = model
     xl_base_patched_hash = str(loras)
     print(f'LoRAs loaded: {xl_base_patched_hash}')
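refresh_loras now also accepts absolute paths, which is what lets the worker append the downloaded inpaint patch directly to the loras list. A minimal sketch of the resolution rule (the folder name is hypothetical):

import os

def resolve_lora(name, lorafile_path='/models/loras'):
    # existing/absolute paths win; bare names resolve inside the LoRA folder
    return name if os.path.exists(name) else os.path.join(lorafile_path, name)

print(resolve_lora('style.safetensors'))  # -> /models/loras/style.safetensors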
modules/inpaint_worker.py
@@ -1,7 +1,25 @@
-import numpy as np
+import os.path
+import torch
+import numpy as np
+import modules.default_pipeline as pipeline

-from PIL import Image, ImageFilter, ImageOps
+from PIL import Image, ImageFilter
 from modules.util import resample_image
+from modules.path import inpaint_models_path
+
+
+inpaint_head = None
+
+
+class InpaintHead(torch.nn.Module):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.head = torch.nn.Parameter(torch.empty(size=(320, 5, 3, 3), device='cpu'))
+
+    def __call__(self, x):
+        x = torch.nn.functional.pad(x, (1, 1, 1, 1), "replicate")
+        return torch.nn.functional.conv2d(input=x, weight=self.head)


 current_task = None
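InpaintHead's weight of shape (320, 5, 3, 3) maps a 5-channel feed (1 mask channel concatenated with 4 latent channels) to 320 channels, matching the width of SDXL's first input block. A standalone shape check (zero weights stand in for the trained fooocus_inpaint_head.pth):

import torch

head = torch.zeros(320, 5, 3, 3)    # stand-in for the trained parameter
feed = torch.zeros(1, 5, 128, 128)  # torch.cat([mask, latent], dim=1) for a 1024px image

x = torch.nn.functional.pad(feed, (1, 1, 1, 1), "replicate")
out = torch.nn.functional.conv2d(input=x, weight=head)
print(out.shape)  # torch.Size([1, 320, 128, 128])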
@@ -81,7 +99,12 @@ def solve_abcd(x, a, b, c, d, k, outpaint):
     if outpaint:
         return 0, H, 0, W
     min_area = H * W * k
-    while area_abcd(a, b, c, d) < min_area:
+    max_area = H * W
+    while True:
+        if area_abcd(a, b, c, d) > min_area and abs((b - a) - (d - c)) < 16:
+            break
+        if area_abcd(a, b, c, d) >= max_area:
+            break
         if (b - a) < (d - c):
             a -= 1
             b += 1
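The rewritten loop grows the crop box until it both clears min_area and is roughly square (side difference under 16), with max_area as a hard stop. A standalone sketch of those exit conditions; area_abcd and the else branch are reconstructed assumptions, not shown in the diff:

def area_abcd(a, b, c, d):
    return max(b - a, 0) * max(d - c, 0)

H = W = 512
a, b, c, d = 200, 220, 100, 400           # hypothetical initial box (thin strip)
min_area, max_area = H * W * 0.25, H * W
while True:
    if area_abcd(a, b, c, d) > min_area and abs((b - a) - (d - c)) < 16:
        break
    if area_abcd(a, b, c, d) >= max_area:
        break
    if (b - a) < (d - c):
        a -= 1
        b += 1
    else:                                 # assumed symmetric growth of the other axis
        c -= 1
        d += 1
print(a, b, c, d, area_abcd(a, b, c, d))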
@@ -148,7 +171,27 @@ class InpaintWorker:
         # ending
         self.latent = None
         self.latent_mask = None
+        self.uc_guidance = None
+        self.inpaint_head_feature = None
         return

+    def load_inpaint_guidance(self, latent, mask, model_path):
+        global inpaint_head
+        if inpaint_head is None:
+            inpaint_head = InpaintHead()
+            sd = torch.load(model_path, map_location='cpu')
+            inpaint_head.load_state_dict(sd)
+        process_latent_in = pipeline.xl_base_patched.unet.model.process_latent_in
+
+        latent = process_latent_in(latent)
+        B, C, H, W = latent.shape
+
+        mask = torch.nn.functional.interpolate(mask, size=(H, W), mode="bilinear")
+        mask = mask.round()
+
+        feed = torch.cat([mask, latent], dim=1)
+
+        inpaint_head.to(device=feed.device, dtype=feed.dtype)
+        self.inpaint_head_feature = inpaint_head(feed)
+        return
+
     def load_latent(self, latent, mask):
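load_inpaint_guidance aligns the pixel mask with the latent grid before concatenation: bilinear resize, then round() to make it hard-binary again. A standalone sketch (sizes assume a 1024px image):

import torch

mask = torch.rand(1, 1, 1024, 1024).round()  # binary pixel-space mask
mask = torch.nn.functional.interpolate(mask, size=(128, 128), mode="bilinear")
mask = mask.round()                          # back to hard {0, 1} at latent size
print(mask.shape, mask.unique())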
modules/patch.py
@@ -9,9 +9,12 @@ import comfy.ldm.modules.attention
 import comfy.k_diffusion.sampling
 import comfy.sd1_clip
 import modules.inpaint_worker as inpaint_worker
+import comfy.ldm.modules.diffusionmodules.openaimodel
+import comfy.sd

 from comfy.k_diffusion import utils
 from comfy.k_diffusion.sampling import BrownianTreeNoiseSampler, trange
+from comfy.ldm.modules.diffusionmodules.openaimodel import timestep_embedding, forward_timestep_embed


 sharpness = 2.0
@@ -22,6 +25,112 @@ cfg_s = 1.0
 cfg_cin = 1.0


+def calculate_weight_patched(self, patches, weight, key):
+    for p in patches:
+        alpha = p[0]
+        v = p[1]
+        strength_model = p[2]
+
+        if strength_model != 1.0:
+            weight *= strength_model
+
+        if isinstance(v, list):
+            v = (self.calculate_weight(v[1:], v[0].clone(), key),)
+
+        if len(v) == 1:
+            w1 = v[0]
+            if alpha != 0.0:
+                if w1.shape != weight.shape:
+                    print("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape))
+                else:
+                    weight += alpha * w1.type(weight.dtype).to(weight.device)
+        elif len(v) == 3:
+            # fooocus
+            w1 = v[0].float()
+            w_min = v[1].float()
+            w_max = v[2].float()
+            w1 = (w1 / 255.0) * (w_max - w_min) + w_min
+            if alpha != 0.0:
+                if w1.shape != weight.shape:
+                    print("WARNING SHAPE MISMATCH {} FOOOCUS WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape))
+                else:
+                    weight += alpha * w1.type(weight.dtype).to(weight.device)
+        elif len(v) == 4:  # lora/locon
+            mat1 = v[0].float().to(weight.device)
+            mat2 = v[1].float().to(weight.device)
+            if v[2] is not None:
+                alpha *= v[2] / mat2.shape[0]
+            if v[3] is not None:
+                mat3 = v[3].float().to(weight.device)
+                final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]]
+                mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1),
+                                mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1)
+            try:
+                weight += (alpha * torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1))).reshape(
+                    weight.shape).type(weight.dtype)
+            except Exception as e:
+                print("ERROR", key, e)
+        elif len(v) == 8:  # lokr
+            w1 = v[0]
+            w2 = v[1]
+            w1_a = v[3]
+            w1_b = v[4]
+            w2_a = v[5]
+            w2_b = v[6]
+            t2 = v[7]
+            dim = None
+
+            if w1 is None:
+                dim = w1_b.shape[0]
+                w1 = torch.mm(w1_a.float(), w1_b.float())
+            else:
+                w1 = w1.float().to(weight.device)
+
+            if w2 is None:
+                dim = w2_b.shape[0]
+                if t2 is None:
+                    w2 = torch.mm(w2_a.float().to(weight.device), w2_b.float().to(weight.device))
+                else:
+                    w2 = torch.einsum('i j k l, j r, i p -> p r k l', t2.float().to(weight.device),
+                                      w2_b.float().to(weight.device), w2_a.float().to(weight.device))
+            else:
+                w2 = w2.float().to(weight.device)
+
+            if len(w2.shape) == 4:
+                w1 = w1.unsqueeze(2).unsqueeze(2)
+            if v[2] is not None and dim is not None:
+                alpha *= v[2] / dim
+
+            try:
+                weight += alpha * torch.kron(w1, w2).reshape(weight.shape).type(weight.dtype)
+            except Exception as e:
+                print("ERROR", key, e)
+        else:  # loha
+            w1a = v[0]
+            w1b = v[1]
+            if v[2] is not None:
+                alpha *= v[2] / w1b.shape[0]
+            w2a = v[3]
+            w2b = v[4]
+            if v[5] is not None:  # cp decomposition
+                t1 = v[5]
+                t2 = v[6]
+                m1 = torch.einsum('i j k l, j r, i p -> p r k l', t1.float().to(weight.device),
+                                  w1b.float().to(weight.device), w1a.float().to(weight.device))
+                m2 = torch.einsum('i j k l, j r, i p -> p r k l', t2.float().to(weight.device),
+                                  w2b.float().to(weight.device), w2a.float().to(weight.device))
+            else:
+                m1 = torch.mm(w1a.float().to(weight.device), w1b.float().to(weight.device))
+                m2 = torch.mm(w2a.float().to(weight.device), w2b.float().to(weight.device))
+
+            try:
+                weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype)
+            except Exception as e:
+                print("ERROR", key, e)
+
+    return weight
+
+
 def cfg_patched(args):
     global cfg_x0, cfg_s
     positive_eps = args['cond'].clone()
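The len(v) == 3 branch is the new "fooocus" patch format: weights ship as uint8 plus per-tensor (w_min, w_max) and are dequantized before merging. A standalone sketch of the decode (values illustrative):

import torch

q = torch.randint(0, 256, (4, 4), dtype=torch.uint8)  # quantized patch weight
w_min, w_max = torch.tensor(-0.05), torch.tensor(0.05)

w1 = (q.float() / 255.0) * (w_max - w_min) + w_min    # same formula as the branch above
print(w1.min().item(), w1.max().item())               # roughly w_min .. w_max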
@@ -55,10 +164,7 @@ def patched_model_function(func, args):
     x = args['input']
     t = args['timestep']
     c = args['c']
-    is_uncond = torch.tensor(args['cond_or_uncond'])[:, None, None, None].to(x) * 5e-3
-    if inpaint_worker.current_task is not None:
-        p = inpaint_worker.current_task.uc_guidance * cfg_cin
-        x = p * is_uncond + x * (1 - is_uncond ** 2.0) ** 0.5
+    # is_uncond = torch.tensor(args['cond_or_uncond'])[:, None, None, None].to(x) * 5e-3
     return func(x, t, **c)
@@ -166,7 +272,6 @@ def sample_dpmpp_fooocus_2m_sde_inpaint_seamless(model, x, sigmas, extra_args=No
         if inpaint_latent is None:
             denoised = model(x, sigmas[i] * s_in, **extra_args)
         else:
-            inpaint_worker.current_task.uc_guidance = x.detach().clone()
             energy = get_energy() * sigmas[i] + inpaint_latent
             x_prime = blend_latent(x, energy, inpaint_mask)
             denoised = model(x_prime, sigmas[i] * s_in, **extra_args)
@@ -194,7 +299,71 @@ def sample_dpmpp_fooocus_2m_sde_inpaint_seamless(model, x, sigmas, extra_args=No
     return x


+def patched_unet_forward(self, x, timesteps=None, context=None, y=None, control=None, transformer_options={}, **kwargs):
+    inpaint_fix = None
+    if inpaint_worker.current_task is not None:
+        inpaint_fix = inpaint_worker.current_task.inpaint_head_feature
+
+    transformer_options["original_shape"] = list(x.shape)
+    transformer_options["current_index"] = 0
+
+    assert (y is not None) == (
+        self.num_classes is not None
+    ), "must specify y if and only if the model is class-conditional"
+    hs = []
+    t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(self.dtype)
+    emb = self.time_embed(t_emb)
+
+    if self.num_classes is not None:
+        assert y.shape[0] == x.shape[0]
+        emb = emb + self.label_emb(y)
+
+    h = x.type(self.dtype)
+    for id, module in enumerate(self.input_blocks):
+        transformer_options["block"] = ("input", id)
+        h = forward_timestep_embed(module, h, emb, context, transformer_options)
+
+        if inpaint_fix is not None:
+            if int(h.shape[1]) == int(inpaint_fix.shape[1]):
+                h = h + inpaint_fix.to(h)
+                inpaint_fix = None
+
+        if control is not None and 'input' in control and len(control['input']) > 0:
+            ctrl = control['input'].pop()
+            if ctrl is not None:
+                h += ctrl
+        hs.append(h)
+    transformer_options["block"] = ("middle", 0)
+    h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options)
+    if control is not None and 'middle' in control and len(control['middle']) > 0:
+        h += control['middle'].pop()
+
+    for id, module in enumerate(self.output_blocks):
+        transformer_options["block"] = ("output", id)
+        hsp = hs.pop()
+        if control is not None and 'output' in control and len(control['output']) > 0:
+            ctrl = control['output'].pop()
+            if ctrl is not None:
+                hsp += ctrl
+
+        h = torch.cat([h, hsp], dim=1)
+        del hsp
+        if len(hs) > 0:
+            output_shape = hs[-1].shape
+        else:
+            output_shape = None
+        h = forward_timestep_embed(module, h, emb, context, transformer_options, output_shape)
+    h = h.type(x.dtype)
+    if self.predict_codebook_ids:
+        return self.id_predictor(h)
+    else:
+        return self.out(h)
+
+
 def patch_all():
+    comfy.sd.ModelPatcher.calculate_weight = calculate_weight_patched
+    comfy.ldm.modules.diffusionmodules.openaimodel.UNetModel.forward = patched_unet_forward
+
     comfy.ldm.modules.attention.print = lambda x: None
     comfy.k_diffusion.sampling.sample_dpmpp_fooocus_2m_sde_inpaint_seamless = sample_dpmpp_fooocus_2m_sde_inpaint_seamless
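patched_unet_forward injects the inpaint head feature exactly once: into the first input-block activation whose channel count matches, after which inpaint_fix is cleared. A standalone sketch of that one-shot condition (block widths are illustrative):

import torch

inpaint_fix = torch.ones(1, 320, 128, 128)
for channels in (320, 320, 640):          # hypothetical input-block widths
    h = torch.zeros(1, channels, 128, 128)
    if inpaint_fix is not None and int(h.shape[1]) == int(inpaint_fix.shape[1]):
        h = h + inpaint_fix.to(h)
        inpaint_fix = None                # only the first matching block is patched
    print(channels, h.abs().sum().item())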
modules/path.py
@@ -1,9 +1,12 @@
 import os
+from modules.model_loader import load_file_from_url


 modelfile_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '../models/checkpoints/'))
 lorafile_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '../models/loras/'))
 vae_approx_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '../models/vae_approx/'))
 upscale_models_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '../models/upscale_models/'))
+inpaint_models_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '../models/inpaint/'))
 temp_outputs_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '../outputs/'))

 fooocus_expansion_path = os.path.abspath(os.path.join(os.path.dirname(__file__),
@@ -27,9 +30,10 @@ def get_model_filenames(folder_path):
     filenames = []
     for filename in os.listdir(folder_path):
         if os.path.isfile(os.path.join(folder_path, filename)):
-            _, file_extension = os.path.splitext(filename)
-            if file_extension.lower() in ['.pth', '.ckpt', '.bin', '.safetensors']:
-                filenames.append(filename)
+            for ends in ['.pth', '.ckpt', '.bin', '.safetensors', '.fooocus.patch']:
+                if filename.lower().endswith(ends):
+                    filenames.append(filename)
+                    break

     return filenames
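The switch from splitext() to endswith() is what makes the compound '.fooocus.patch' suffix recognizable; splitext() only ever sees the final dot. A quick standalone check:

import os

name = 'inpaint.fooocus.patch'
print(os.path.splitext(name)[1])                # '.patch' -- drops 'fooocus'
print(name.lower().endswith('.fooocus.patch'))  # True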
@@ -41,4 +45,18 @@ def update_all_model_names():
     return


+def downloading_inpaint_models():
+    load_file_from_url(
+        url='https://huggingface.co/lllyasviel/fooocus_inpaint/resolve/main/fooocus_inpaint_head.pth',
+        model_dir=inpaint_models_path,
+        file_name='fooocus_inpaint_head.pth'
+    )
+    load_file_from_url(
+        url='https://huggingface.co/lllyasviel/fooocus_inpaint/resolve/main/inpaint.fooocus.patch',
+        model_dir=inpaint_models_path,
+        file_name='inpaint.fooocus.patch'
+    )
+    return os.path.join(inpaint_models_path, 'fooocus_inpaint_head.pth'), os.path.join(inpaint_models_path, 'inpaint.fooocus.patch')
+
+
 update_all_model_names()
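downloading_inpaint_models relies on modules.model_loader.load_file_from_url, called with (url, model_dir, file_name). A minimal stand-in with the same call shape, assuming the real helper is an idempotent fetch that skips files already on disk (this sketch is not the project's implementation):

import os
import urllib.request

def load_file_from_url(url, model_dir, file_name):
    os.makedirs(model_dir, exist_ok=True)
    target = os.path.join(model_dir, file_name)
    if not os.path.exists(target):
        urllib.request.urlretrieve(url, target)  # assumed plain HTTP download
    return target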
update_log.md
@@ -1,3 +1,7 @@
+### 2.0.65
+
+* Inpaint model released.
+
 ### 2.0.50

 * Variation/Upscale (Midjourney Toolbar) implemented.