use fooocus inpaint control model (#429)

lllyasviel 2023-09-19 04:52:22 -07:00 committed by GitHub
parent 6ae7de377d
commit 0927445492
9 changed files with 304 additions and 19 deletions

.gitignore

@@ -3,6 +3,7 @@ __pycache__
 *.safetensors
 *.pth
 *.bin
+*.patch
 lena.png
 lena_result.png
 lena_test.py

fooocus_version.py

@@ -1 +1 @@
-version = '2.0.62'
+version = '2.0.65'

modules/async_worker.py

@@ -195,6 +195,10 @@ def worker():
             # outputs.append(['results', inpaint_worker.current_task.visualize_mask_processing()])
             # return
+
+            progressbar(0, 'Downloading inpainter ...')
+            inpaint_head_model_path, inpaint_patch_model_path = modules.path.downloading_inpaint_models()
+            loras += [(inpaint_patch_model_path, 1.0)]

             inpaint_pixels = core.numpy_to_pytorch(inpaint_worker.current_task.image_ready)
             progressbar(0, 'VAE encoding ...')
             initial_latent = core.encode_vae(vae=pipeline.xl_base_patched.vae, pixels=inpaint_pixels)
@@ -207,6 +211,18 @@ def worker():
             height = H * 8

             inpaint_worker.current_task.load_latent(latent=inpaint_latent, mask=inpaint_mask)

+            progressbar(0, 'VAE inpaint encoding ...')
+            inpaint_mask = (inpaint_worker.current_task.mask_ready > 0).astype(np.float32)
+            inpaint_mask = torch.tensor(inpaint_mask).float()
+            vae_dict = core.encode_vae_inpaint(
+                mask=inpaint_mask, vae=pipeline.xl_base_patched.vae, pixels=inpaint_pixels)
+            inpaint_latent = vae_dict['samples']
+            inpaint_mask = vae_dict['noise_mask']
+            inpaint_worker.current_task.load_inpaint_guidance(latent=inpaint_latent, mask=inpaint_mask, model_path=inpaint_head_model_path)
+
             progressbar(1, 'Initializing ...')
             raw_prompt = prompt
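
Note: the mask handed to core.encode_vae_inpaint above is a hard 0/1 float tensor. A minimal standalone sketch of that binarization step, assuming mask_ready is a uint8 HxW numpy array (values 0..255) as produced by the inpaint worker:

import numpy as np
import torch

mask_ready = np.zeros((64, 64), dtype=np.uint8)     # hypothetical user-painted mask
mask_ready[16:48, 16:48] = 255                      # region to repaint

inpaint_mask = (mask_ready > 0).astype(np.float32)  # hard 0/1 float mask, as in the worker
inpaint_mask = torch.tensor(inpaint_mask).float()   # torch tensor for the VAE encoder

print(inpaint_mask.shape, inpaint_mask.unique())    # torch.Size([64, 64]) tensor([0., 1.])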

modules/core.py

@@ -8,9 +8,10 @@ import comfy.model_management
 import comfy.utils
 from comfy.sd import load_checkpoint_guess_config
-from nodes import VAEDecode, EmptyLatentImage, VAEEncode, VAEEncodeTiled, VAEDecodeTiled
+from nodes import VAEDecode, EmptyLatentImage, VAEEncode, VAEEncodeTiled, VAEDecodeTiled, VAEEncodeForInpaint
 from comfy.sample import prepare_mask, broadcast_cond, load_additional_models, cleanup_additional_models
 from comfy.model_base import SDXLRefiner
+from comfy.sd import model_lora_keys_unet, model_lora_keys_clip, load_lora
 from modules.samplers_advanced import KSamplerBasic, KSamplerWithRefiner
 from modules.patch import patch_all
@@ -21,6 +22,7 @@ opVAEDecode = VAEDecode()
 opVAEEncode = VAEEncode()
 opVAEDecodeTiled = VAEDecodeTiled()
 opVAEEncodeTiled = VAEEncodeTiled()
+opVAEEncodeForInpaint = VAEEncodeForInpaint()


 class StableDiffusionModel:
@@ -56,12 +58,32 @@ def load_model(ckpt_filename):
 @torch.no_grad()
 @torch.inference_mode()
-def load_lora(model, lora_filename, strength_model=1.0, strength_clip=1.0):
+def load_sd_lora(model, lora_filename, strength_model=1.0, strength_clip=1.0):
     if strength_model == 0 and strength_clip == 0:
         return model
-    lora = comfy.utils.load_torch_file(lora_filename, safe_load=True)
-    unet, clip = comfy.sd.load_lora_for_models(model.unet, model.clip, lora, strength_model, strength_clip)
+    lora = comfy.utils.load_torch_file(lora_filename, safe_load=False)
+
+    if lora_filename.lower().endswith('.fooocus.patch'):
+        loaded = lora
+    else:
+        key_map = model_lora_keys_unet(model.unet.model)
+        key_map = model_lora_keys_clip(model.clip.cond_stage_model, key_map)
+        loaded = load_lora(lora, key_map)
+
+    new_modelpatcher = model.unet.clone()
+    k = new_modelpatcher.add_patches(loaded, strength_model)
+
+    new_clip = model.clip.clone()
+    k1 = new_clip.add_patches(loaded, strength_clip)
+
+    k = set(k)
+    k1 = set(k1)
+
+    for x in loaded:
+        if (x not in k) and (x not in k1):
+            print("Lora missed: ", x)
+
+    unet, clip = new_modelpatcher, new_clip
     return StableDiffusionModel(unet=unet, clip=clip, vae=model.vae, clip_vision=model.clip_vision)
@@ -83,6 +105,12 @@ def encode_vae(vae, pixels, tiled=False):
     return (opVAEEncodeTiled if tiled else opVAEEncode).encode(pixels=pixels, vae=vae)[0]


+@torch.no_grad()
+@torch.inference_mode()
+def encode_vae_inpaint(vae, pixels, mask):
+    return opVAEEncodeForInpaint.encode(pixels=pixels, vae=vae, mask=mask)[0]
+
+
 class VAEApprox(torch.nn.Module):
     def __init__(self):
         super(VAEApprox, self).__init__()
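
Note: load_sd_lora now accepts both ordinary LoRA files and the new .fooocus.patch format; a .fooocus.patch is already keyed by model weight names, so the key-map translation is skipped and the patches are applied directly. A hypothetical usage sketch (the checkpoint and LoRA file names are placeholders):

from modules import core

# Placeholder file names; any SDXL checkpoint and LoRA will do.
model = core.load_model('sd_xl_base_1.0.safetensors')

# Ordinary LoRA: keys are remapped via model_lora_keys_unet / model_lora_keys_clip.
model = core.load_sd_lora(model, 'my_style.safetensors', strength_model=0.5, strength_clip=0.5)

# Fooocus patch: already keyed by model weight names, applied as-is via add_patches().
model = core.load_sd_lora(model, 'inpaint.fooocus.patch', strength_model=1.0, strength_clip=1.0)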

modules/default_pipeline.py

@@ -101,8 +101,14 @@ def refresh_loras(loras):
         if name == 'None':
             continue
-        filename = os.path.join(modules.path.lorafile_path, name)
-        model = core.load_lora(model, filename, strength_model=weight, strength_clip=weight)
+        if os.path.exists(name):
+            filename = name
+        else:
+            filename = os.path.join(modules.path.lorafile_path, name)
+        assert os.path.exists(filename), 'Lora file not found!'
+        model = core.load_sd_lora(model, filename, strength_model=weight, strength_clip=weight)

     xl_base_patched = model
     xl_base_patched_hash = str(loras)
     print(f'LoRAs loaded: {xl_base_patched_hash}')
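
Note: the new resolution rule lets a LoRA "name" be an existing path, which is how the worker injects the downloaded inpaint.fooocus.patch at weight 1.0. A standalone sketch of just that rule (the default lorafile_path value here is illustrative):

import os

def resolve_lora(name, lorafile_path='models/loras'):
    # an existing (e.g. absolute) path wins; otherwise fall back to the loras folder
    filename = name if os.path.exists(name) else os.path.join(lorafile_path, name)
    assert os.path.exists(filename), 'Lora file not found!'
    return filename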

modules/inpaint_worker.py

@@ -1,7 +1,25 @@
-import numpy as np
+import os.path
+from PIL import Image, ImageFilter, ImageOps
+import torch
+import numpy as np
+import modules.default_pipeline as pipeline
-from PIL import Image, ImageFilter
 from modules.util import resample_image
+from modules.path import inpaint_models_path
+
+
+inpaint_head = None
+
+
+class InpaintHead(torch.nn.Module):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.head = torch.nn.Parameter(torch.empty(size=(320, 5, 3, 3), device='cpu'))
+
+    def __call__(self, x):
+        x = torch.nn.functional.pad(x, (1, 1, 1, 1), "replicate")
+        return torch.nn.functional.conv2d(input=x, weight=self.head)
+
+
 current_task = None
@@ -81,7 +99,12 @@ def solve_abcd(x, a, b, c, d, k, outpaint):
     if outpaint:
         return 0, H, 0, W
     min_area = H * W * k
-    while area_abcd(a, b, c, d) < min_area:
+    max_area = H * W
+    while True:
+        if area_abcd(a, b, c, d) > min_area and abs((b - a) - (d - c)) < 16:
+            break
+        if area_abcd(a, b, c, d) >= max_area:
+            break
         if (b - a) < (d - c):
             a -= 1
             b += 1
@@ -148,7 +171,27 @@ class InpaintWorker:
         # ending
         self.latent = None
         self.latent_mask = None
-        self.uc_guidance = None
+        self.inpaint_head_feature = None
         return

+    def load_inpaint_guidance(self, latent, mask, model_path):
+        global inpaint_head
+        if inpaint_head is None:
+            inpaint_head = InpaintHead()
+            sd = torch.load(model_path, map_location='cpu')
+            inpaint_head.load_state_dict(sd)
+
+        process_latent_in = pipeline.xl_base_patched.unet.model.process_latent_in
+        latent = process_latent_in(latent)
+
+        B, C, H, W = latent.shape
+        mask = torch.nn.functional.interpolate(mask, size=(H, W), mode="bilinear")
+        mask = mask.round()
+
+        feed = torch.cat([mask, latent], dim=1)
+        inpaint_head.to(device=feed.device, dtype=feed.dtype)
+        self.inpaint_head_feature = inpaint_head(feed)
+        return
+
     def load_latent(self, latent, mask):
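
Note: InpaintHead is a single 3x3 convolution from 5 channels (1 mask + 4 latent) to 320 channels, the width of the first SDXL UNet input block; replicate padding keeps the spatial size unchanged. A shape-check sketch (the init call stands in for loading fooocus_inpaint_head.pth, since head.head is created with torch.empty):

import torch
from modules.inpaint_worker import InpaintHead

head = InpaintHead()
torch.nn.init.normal_(head.head, std=0.02)  # stand-in for the downloaded weights

latent = torch.randn(1, 4, 128, 128)        # a 1024x1024 image becomes a 128x128 SDXL latent
mask = torch.ones(1, 1, 128, 128)           # 1 = repaint this latent position
feed = torch.cat([mask, latent], dim=1)     # (1, 5, 128, 128), as in load_inpaint_guidance

print(head(feed).shape)                     # torch.Size([1, 320, 128, 128])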

modules/patch.py

@@ -9,9 +9,12 @@ import comfy.ldm.modules.attention
 import comfy.k_diffusion.sampling
 import comfy.sd1_clip
 import modules.inpaint_worker as inpaint_worker
+import comfy.ldm.modules.diffusionmodules.openaimodel
+import comfy.sd

 from comfy.k_diffusion import utils
 from comfy.k_diffusion.sampling import BrownianTreeNoiseSampler, trange
+from comfy.ldm.modules.diffusionmodules.openaimodel import timestep_embedding, forward_timestep_embed


 sharpness = 2.0
@@ -22,6 +25,112 @@ cfg_s = 1.0
 cfg_cin = 1.0


+def calculate_weight_patched(self, patches, weight, key):
+    for p in patches:
+        alpha = p[0]
+        v = p[1]
+        strength_model = p[2]
+
+        if strength_model != 1.0:
+            weight *= strength_model
+
+        if isinstance(v, list):
+            v = (self.calculate_weight(v[1:], v[0].clone(), key),)
+
+        if len(v) == 1:
+            w1 = v[0]
+            if alpha != 0.0:
+                if w1.shape != weight.shape:
+                    print("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape))
+                else:
+                    weight += alpha * w1.type(weight.dtype).to(weight.device)
+        elif len(v) == 3:
+            # fooocus
+            w1 = v[0].float()
+            w_min = v[1].float()
+            w_max = v[2].float()
+            w1 = (w1 / 255.0) * (w_max - w_min) + w_min
+            if alpha != 0.0:
+                if w1.shape != weight.shape:
+                    print("WARNING SHAPE MISMATCH {} FOOOCUS WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape))
+                else:
+                    weight += alpha * w1.type(weight.dtype).to(weight.device)
+        elif len(v) == 4:  # lora/locon
+            mat1 = v[0].float().to(weight.device)
+            mat2 = v[1].float().to(weight.device)
+            if v[2] is not None:
+                alpha *= v[2] / mat2.shape[0]
+            if v[3] is not None:
+                mat3 = v[3].float().to(weight.device)
+                final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]]
+                mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1),
+                                mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1)
+            try:
+                weight += (alpha * torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1))).reshape(
+                    weight.shape).type(weight.dtype)
+            except Exception as e:
+                print("ERROR", key, e)
+        elif len(v) == 8:  # lokr
+            w1 = v[0]
+            w2 = v[1]
+            w1_a = v[3]
+            w1_b = v[4]
+            w2_a = v[5]
+            w2_b = v[6]
+            t2 = v[7]
+            dim = None
+
+            if w1 is None:
+                dim = w1_b.shape[0]
+                w1 = torch.mm(w1_a.float(), w1_b.float())
+            else:
+                w1 = w1.float().to(weight.device)
+
+            if w2 is None:
+                dim = w2_b.shape[0]
+                if t2 is None:
+                    w2 = torch.mm(w2_a.float().to(weight.device), w2_b.float().to(weight.device))
+                else:
+                    w2 = torch.einsum('i j k l, j r, i p -> p r k l', t2.float().to(weight.device),
+                                      w2_b.float().to(weight.device), w2_a.float().to(weight.device))
+            else:
+                w2 = w2.float().to(weight.device)
+
+            if len(w2.shape) == 4:
+                w1 = w1.unsqueeze(2).unsqueeze(2)
+            if v[2] is not None and dim is not None:
+                alpha *= v[2] / dim
+
+            try:
+                weight += alpha * torch.kron(w1, w2).reshape(weight.shape).type(weight.dtype)
+            except Exception as e:
+                print("ERROR", key, e)
+        else:  # loha
+            w1a = v[0]
+            w1b = v[1]
+            if v[2] is not None:
+                alpha *= v[2] / w1b.shape[0]
+            w2a = v[3]
+            w2b = v[4]
+            if v[5] is not None:  # cp decomposition
+                t1 = v[5]
+                t2 = v[6]
+                m1 = torch.einsum('i j k l, j r, i p -> p r k l', t1.float().to(weight.device),
+                                  w1b.float().to(weight.device), w1a.float().to(weight.device))
+                m2 = torch.einsum('i j k l, j r, i p -> p r k l', t2.float().to(weight.device),
+                                  w2b.float().to(weight.device), w2a.float().to(weight.device))
+            else:
+                m1 = torch.mm(w1a.float().to(weight.device), w1b.float().to(weight.device))
+                m2 = torch.mm(w2a.float().to(weight.device), w2b.float().to(weight.device))
+
+            try:
+                weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype)
+            except Exception as e:
+                print("ERROR", key, e)
+
+    return weight
+
+
 def cfg_patched(args):
     global cfg_x0, cfg_s
     positive_eps = args['cond'].clone()
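
Note: the len(v) == 3 branch (marked "# fooocus") is the storage format of inpaint.fooocus.patch: each patched tensor is kept as uint8 plus per-tensor w_min/w_max, and dequantized on merge. A round-trip sketch; the dequantize step matches the branch above, while the quantize helper is an assumed inverse, not code from this commit:

import torch

def quantize(w):
    w_min, w_max = w.min(), w.max()
    q = ((w - w_min) / (w_max - w_min) * 255.0).round().to(torch.uint8)
    return q, w_min, w_max

def dequantize(q, w_min, w_max):
    return (q.float() / 255.0) * (w_max - w_min) + w_min  # as in calculate_weight_patched

w = torch.randn(320, 320)
q, w_min, w_max = quantize(w)
err = (dequantize(q, w_min, w_max) - w).abs().max()
print(err)  # bounded by half a quantization step, (w_max - w_min) / 510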
@@ -55,10 +164,7 @@ def patched_model_function(func, args):
     x = args['input']
     t = args['timestep']
     c = args['c']
-    is_uncond = torch.tensor(args['cond_or_uncond'])[:, None, None, None].to(x) * 5e-3
-    if inpaint_worker.current_task is not None:
-        p = inpaint_worker.current_task.uc_guidance * cfg_cin
-        x = p * is_uncond + x * (1 - is_uncond ** 2.0) ** 0.5
+    # is_uncond = torch.tensor(args['cond_or_uncond'])[:, None, None, None].to(x) * 5e-3
     return func(x, t, **c)
@@ -166,7 +272,6 @@ def sample_dpmpp_fooocus_2m_sde_inpaint_seamless(model, x, sigmas, extra_args=No
         if inpaint_latent is None:
             denoised = model(x, sigmas[i] * s_in, **extra_args)
         else:
-            inpaint_worker.current_task.uc_guidance = x.detach().clone()
             energy = get_energy() * sigmas[i] + inpaint_latent
             x_prime = blend_latent(x, energy, inpaint_mask)
             denoised = model(x_prime, sigmas[i] * s_in, **extra_args)
@@ -194,7 +299,71 @@ def sample_dpmpp_fooocus_2m_sde_inpaint_seamless(model, x, sigmas, extra_args=No
     return x


+def patched_unet_forward(self, x, timesteps=None, context=None, y=None, control=None, transformer_options={}, **kwargs):
+    inpaint_fix = None
+    if inpaint_worker.current_task is not None:
+        inpaint_fix = inpaint_worker.current_task.inpaint_head_feature
+
+    transformer_options["original_shape"] = list(x.shape)
+    transformer_options["current_index"] = 0
+
+    assert (y is not None) == (
+        self.num_classes is not None
+    ), "must specify y if and only if the model is class-conditional"
+    hs = []
+    t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(self.dtype)
+    emb = self.time_embed(t_emb)
+
+    if self.num_classes is not None:
+        assert y.shape[0] == x.shape[0]
+        emb = emb + self.label_emb(y)
+
+    h = x.type(self.dtype)
+    for id, module in enumerate(self.input_blocks):
+        transformer_options["block"] = ("input", id)
+        h = forward_timestep_embed(module, h, emb, context, transformer_options)
+
+        if inpaint_fix is not None:
+            if int(h.shape[1]) == int(inpaint_fix.shape[1]):
+                h = h + inpaint_fix.to(h)
+                inpaint_fix = None
+
+        if control is not None and 'input' in control and len(control['input']) > 0:
+            ctrl = control['input'].pop()
+            if ctrl is not None:
+                h += ctrl
+        hs.append(h)
+
+    transformer_options["block"] = ("middle", 0)
+    h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options)
+    if control is not None and 'middle' in control and len(control['middle']) > 0:
+        h += control['middle'].pop()
+
+    for id, module in enumerate(self.output_blocks):
+        transformer_options["block"] = ("output", id)
+        hsp = hs.pop()
+        if control is not None and 'output' in control and len(control['output']) > 0:
+            ctrl = control['output'].pop()
+            if ctrl is not None:
+                hsp += ctrl
+
+        h = torch.cat([h, hsp], dim=1)
+        del hsp
+        if len(hs) > 0:
+            output_shape = hs[-1].shape
+        else:
+            output_shape = None
+        h = forward_timestep_embed(module, h, emb, context, transformer_options, output_shape)
+    h = h.type(x.dtype)
+    if self.predict_codebook_ids:
+        return self.id_predictor(h)
+    else:
+        return self.out(h)
+
+
 def patch_all():
+    comfy.sd.ModelPatcher.calculate_weight = calculate_weight_patched
+    comfy.ldm.modules.diffusionmodules.openaimodel.UNetModel.forward = patched_unet_forward
     comfy.ldm.modules.attention.print = lambda x: None
     comfy.k_diffusion.sampling.sample_dpmpp_fooocus_2m_sde_inpaint_seamless = sample_dpmpp_fooocus_2m_sde_inpaint_seamless
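
Note: patched_unet_forward adds the 320-channel inpaint head feature to the output of the first input block whose channel count matches, then clears it so the injection happens exactly once. That rule in isolation; the channel widths and spatial size below are illustrative stand-ins, not values read from the model:

import torch

inpaint_fix = torch.randn(1, 320, 32, 32)   # stand-in for inpaint_head_feature
block_channels = [320, 320, 320, 320, 640, 640, 640, 1280, 1280]  # illustrative SDXL widths

for ch in block_channels:
    h = torch.randn(1, ch, 32, 32)          # stand-in for forward_timestep_embed output
    if inpaint_fix is not None and int(h.shape[1]) == int(inpaint_fix.shape[1]):
        h = h + inpaint_fix.to(h)           # inject at the first matching block...
        inpaint_fix = None                  # ...then disarm for the rest of the pass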

modules/path.py

@@ -1,9 +1,12 @@
 import os

 from modules.model_loader import load_file_from_url

 modelfile_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '../models/checkpoints/'))
 lorafile_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '../models/loras/'))
 vae_approx_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '../models/vae_approx/'))
 upscale_models_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '../models/upscale_models/'))
+inpaint_models_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '../models/inpaint/'))
 temp_outputs_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '../outputs/'))
 fooocus_expansion_path = os.path.abspath(os.path.join(os.path.dirname(__file__),
@@ -27,9 +30,10 @@ def get_model_filenames(folder_path):
     filenames = []
     for filename in os.listdir(folder_path):
         if os.path.isfile(os.path.join(folder_path, filename)):
-            _, file_extension = os.path.splitext(filename)
-            if file_extension.lower() in ['.pth', '.ckpt', '.bin', '.safetensors']:
-                filenames.append(filename)
+            for ends in ['.pth', '.ckpt', '.bin', '.safetensors', '.fooocus.patch']:
+                if filename.lower().endswith(ends):
+                    filenames.append(filename)
+                    break
     return filenames
@@ -41,4 +45,18 @@ def update_all_model_names():
     return


+def downloading_inpaint_models():
+    load_file_from_url(
+        url='https://huggingface.co/lllyasviel/fooocus_inpaint/resolve/main/fooocus_inpaint_head.pth',
+        model_dir=inpaint_models_path,
+        file_name='fooocus_inpaint_head.pth'
+    )
+    load_file_from_url(
+        url='https://huggingface.co/lllyasviel/fooocus_inpaint/resolve/main/inpaint.fooocus.patch',
+        model_dir=inpaint_models_path,
+        file_name='inpaint.fooocus.patch'
+    )
+    return os.path.join(inpaint_models_path, 'fooocus_inpaint_head.pth'), os.path.join(inpaint_models_path, 'inpaint.fooocus.patch')
+
+
 update_all_model_names()
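
Note: the worker calls this once per inpaint job; the first returned path feeds InpaintHead loading, the second is appended to the LoRA list at weight 1.0. A hypothetical call site mirroring the worker code in this commit (assuming load_file_from_url skips files that already exist, as its use here implies):

import modules.path

inpaint_head_model_path, inpaint_patch_model_path = modules.path.downloading_inpaint_models()
# inpaint_head_model_path  -> models/inpaint/fooocus_inpaint_head.pth (InpaintHead weights)
# inpaint_patch_model_path -> models/inpaint/inpaint.fooocus.patch    (applied like a LoRA)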

update_log.md

@@ -1,3 +1,7 @@
+### 2.0.65
+
+* Inpaint model released.
+
 ### 2.0.50

 * Variation/Upscale (Midjourney Toolbar) implemented.