parent
bb965067e0
commit
38e70cebcc
|
|
@ -0,0 +1,64 @@
|
|||
import enum
|
||||
import torch
|
||||
import math
|
||||
import fcbh.utils
|
||||
|
||||
|
||||
def lcm(a, b): #TODO: eventually replace by math.lcm (added in python3.9)
|
||||
return abs(a*b) // math.gcd(a, b)
|
||||
|
||||
class CONDRegular:
|
||||
def __init__(self, cond):
|
||||
self.cond = cond
|
||||
|
||||
def _copy_with(self, cond):
|
||||
return self.__class__(cond)
|
||||
|
||||
def process_cond(self, batch_size, device, **kwargs):
|
||||
return self._copy_with(fcbh.utils.repeat_to_batch_size(self.cond, batch_size).to(device))
|
||||
|
||||
def can_concat(self, other):
|
||||
if self.cond.shape != other.cond.shape:
|
||||
return False
|
||||
return True
|
||||
|
||||
def concat(self, others):
|
||||
conds = [self.cond]
|
||||
for x in others:
|
||||
conds.append(x.cond)
|
||||
return torch.cat(conds)
|
||||
|
||||
class CONDNoiseShape(CONDRegular):
|
||||
def process_cond(self, batch_size, device, area, **kwargs):
|
||||
data = self.cond[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]]
|
||||
return self._copy_with(fcbh.utils.repeat_to_batch_size(data, batch_size).to(device))
|
||||
|
||||
|
||||
class CONDCrossAttn(CONDRegular):
|
||||
def can_concat(self, other):
|
||||
s1 = self.cond.shape
|
||||
s2 = other.cond.shape
|
||||
if s1 != s2:
|
||||
if s1[0] != s2[0] or s1[2] != s2[2]: #these 2 cases should not happen
|
||||
return False
|
||||
|
||||
mult_min = lcm(s1[1], s2[1])
|
||||
diff = mult_min // min(s1[1], s2[1])
|
||||
if diff > 4: #arbitrary limit on the padding because it's probably going to impact performance negatively if it's too much
|
||||
return False
|
||||
return True
|
||||
|
||||
def concat(self, others):
|
||||
conds = [self.cond]
|
||||
crossattn_max_len = self.cond.shape[1]
|
||||
for x in others:
|
||||
c = x.cond
|
||||
crossattn_max_len = lcm(crossattn_max_len, c.shape[1])
|
||||
conds.append(c)
|
||||
|
||||
out = []
|
||||
for c in conds:
|
||||
if c.shape[1] < crossattn_max_len:
|
||||
c = c.repeat(1, crossattn_max_len // c.shape[1], 1) #padding with repeat doesn't change result
|
||||
out.append(c)
|
||||
return torch.cat(out)
|
||||
|
|
@ -156,7 +156,7 @@ class ControlNet(ControlBase):
|
|||
|
||||
|
||||
context = cond['c_crossattn']
|
||||
y = cond.get('c_adm', None)
|
||||
y = cond.get('y', None)
|
||||
if y is not None:
|
||||
y = y.to(self.control_model.dtype)
|
||||
control = self.control_model(x=x_noisy.to(self.control_model.dtype), hint=self.cond_hint, timesteps=t, context=context.to(self.control_model.dtype), y=y)
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ from fcbh.ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmen
|
|||
from fcbh.ldm.modules.diffusionmodules.util import make_beta_schedule
|
||||
from fcbh.ldm.modules.diffusionmodules.openaimodel import Timestep
|
||||
import fcbh.model_management
|
||||
import fcbh.conds
|
||||
import numpy as np
|
||||
from enum import Enum
|
||||
from . import utils
|
||||
|
|
@ -49,7 +50,7 @@ class BaseModel(torch.nn.Module):
|
|||
self.register_buffer('alphas_cumprod', torch.tensor(alphas_cumprod, dtype=torch.float32))
|
||||
self.register_buffer('alphas_cumprod_prev', torch.tensor(alphas_cumprod_prev, dtype=torch.float32))
|
||||
|
||||
def apply_model(self, x, t, c_concat=None, c_crossattn=None, c_adm=None, control=None, transformer_options={}):
|
||||
def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
|
||||
if c_concat is not None:
|
||||
xc = torch.cat([x] + [c_concat], dim=1)
|
||||
else:
|
||||
|
|
@ -59,9 +60,10 @@ class BaseModel(torch.nn.Module):
|
|||
xc = xc.to(dtype)
|
||||
t = t.to(dtype)
|
||||
context = context.to(dtype)
|
||||
if c_adm is not None:
|
||||
c_adm = c_adm.to(dtype)
|
||||
return self.diffusion_model(xc, t, context=context, y=c_adm, control=control, transformer_options=transformer_options).float()
|
||||
extra_conds = {}
|
||||
for o in kwargs:
|
||||
extra_conds[o] = kwargs[o].to(dtype)
|
||||
return self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds).float()
|
||||
|
||||
def get_dtype(self):
|
||||
return self.diffusion_model.dtype
|
||||
|
|
@ -72,7 +74,8 @@ class BaseModel(torch.nn.Module):
|
|||
def encode_adm(self, **kwargs):
|
||||
return None
|
||||
|
||||
def cond_concat(self, **kwargs):
|
||||
def extra_conds(self, **kwargs):
|
||||
out = {}
|
||||
if self.inpaint_model:
|
||||
concat_keys = ("mask", "masked_image")
|
||||
cond_concat = []
|
||||
|
|
@ -101,8 +104,12 @@ class BaseModel(torch.nn.Module):
|
|||
cond_concat.append(torch.ones_like(noise)[:,:1])
|
||||
elif ck == "masked_image":
|
||||
cond_concat.append(blank_inpaint_image_like(noise))
|
||||
return cond_concat
|
||||
return None
|
||||
data = torch.cat(cond_concat, dim=1)
|
||||
out['c_concat'] = fcbh.conds.CONDNoiseShape(data)
|
||||
adm = self.encode_adm(**kwargs)
|
||||
if adm is not None:
|
||||
out['y'] = fcbh.conds.CONDRegular(adm)
|
||||
return out
|
||||
|
||||
def load_model_weights(self, sd, unet_prefix=""):
|
||||
to_load = {}
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
import torch
|
||||
import fcbh.model_management
|
||||
import fcbh.samplers
|
||||
import fcbh.conds
|
||||
import fcbh.utils
|
||||
import math
|
||||
import numpy as np
|
||||
|
|
@ -33,22 +34,24 @@ def prepare_mask(noise_mask, shape, device):
|
|||
noise_mask = noise_mask.to(device)
|
||||
return noise_mask
|
||||
|
||||
def broadcast_cond(cond, batch, device):
|
||||
"""broadcasts conditioning to the batch size"""
|
||||
copy = []
|
||||
for p in cond:
|
||||
t = fcbh.utils.repeat_to_batch_size(p[0], batch)
|
||||
t = t.to(device)
|
||||
copy += [[t] + p[1:]]
|
||||
return copy
|
||||
|
||||
def get_models_from_cond(cond, model_type):
|
||||
models = []
|
||||
for c in cond:
|
||||
if model_type in c[1]:
|
||||
models += [c[1][model_type]]
|
||||
if model_type in c:
|
||||
models += [c[model_type]]
|
||||
return models
|
||||
|
||||
def convert_cond(cond):
|
||||
out = []
|
||||
for c in cond:
|
||||
temp = c[1].copy()
|
||||
model_conds = temp.get("model_conds", {})
|
||||
if c[0] is not None:
|
||||
model_conds["c_crossattn"] = fcbh.conds.CONDCrossAttn(c[0])
|
||||
temp["model_conds"] = model_conds
|
||||
out.append(temp)
|
||||
return out
|
||||
|
||||
def get_additional_models(positive, negative, dtype):
|
||||
"""loads additional models in positive and negative conditioning"""
|
||||
control_nets = set(get_models_from_cond(positive, "control") + get_models_from_cond(negative, "control"))
|
||||
|
|
@ -72,6 +75,8 @@ def cleanup_additional_models(models):
|
|||
|
||||
def prepare_sampling(model, noise_shape, positive, negative, noise_mask):
|
||||
device = model.load_device
|
||||
positive = convert_cond(positive)
|
||||
negative = convert_cond(negative)
|
||||
|
||||
if noise_mask is not None:
|
||||
noise_mask = prepare_mask(noise_mask, noise_shape, device)
|
||||
|
|
@ -81,9 +86,7 @@ def prepare_sampling(model, noise_shape, positive, negative, noise_mask):
|
|||
fcbh.model_management.load_models_gpu([model] + models, fcbh.model_management.batch_area_memory(noise_shape[0] * noise_shape[2] * noise_shape[3]) + inference_memory)
|
||||
real_model = model.model
|
||||
|
||||
positive_copy = broadcast_cond(positive, noise_shape[0], device)
|
||||
negative_copy = broadcast_cond(negative, noise_shape[0], device)
|
||||
return real_model, positive_copy, negative_copy, noise_mask, models
|
||||
return real_model, positive, negative, noise_mask, models
|
||||
|
||||
|
||||
def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False, noise_mask=None, sigmas=None, callback=None, disable_pbar=False, seed=None):
|
||||
|
|
|
|||
|
|
@ -2,47 +2,44 @@ from .k_diffusion import sampling as k_diffusion_sampling
|
|||
from .k_diffusion import external as k_diffusion_external
|
||||
from .extra_samplers import uni_pc
|
||||
import torch
|
||||
import enum
|
||||
from fcbh import model_management
|
||||
from .ldm.models.diffusion.ddim import DDIMSampler
|
||||
from .ldm.modules.diffusionmodules.util import make_ddim_timesteps
|
||||
import math
|
||||
from fcbh import model_base
|
||||
import fcbh.utils
|
||||
import fcbh.conds
|
||||
|
||||
def lcm(a, b): #TODO: eventually replace by math.lcm (added in python3.9)
|
||||
return abs(a*b) // math.gcd(a, b)
|
||||
|
||||
#The main sampling function shared by all the samplers
|
||||
#Returns predicted noise
|
||||
def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, model_options={}, seed=None):
|
||||
def get_area_and_mult(cond, x_in, timestep_in):
|
||||
def get_area_and_mult(conds, x_in, timestep_in):
|
||||
area = (x_in.shape[2], x_in.shape[3], 0, 0)
|
||||
strength = 1.0
|
||||
if 'timestep_start' in cond[1]:
|
||||
timestep_start = cond[1]['timestep_start']
|
||||
|
||||
if 'timestep_start' in conds:
|
||||
timestep_start = conds['timestep_start']
|
||||
if timestep_in[0] > timestep_start:
|
||||
return None
|
||||
if 'timestep_end' in cond[1]:
|
||||
timestep_end = cond[1]['timestep_end']
|
||||
if 'timestep_end' in conds:
|
||||
timestep_end = conds['timestep_end']
|
||||
if timestep_in[0] < timestep_end:
|
||||
return None
|
||||
if 'area' in cond[1]:
|
||||
area = cond[1]['area']
|
||||
if 'strength' in cond[1]:
|
||||
strength = cond[1]['strength']
|
||||
|
||||
adm_cond = None
|
||||
if 'adm_encoded' in cond[1]:
|
||||
adm_cond = cond[1]['adm_encoded']
|
||||
if 'area' in conds:
|
||||
area = conds['area']
|
||||
if 'strength' in conds:
|
||||
strength = conds['strength']
|
||||
|
||||
input_x = x_in[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]]
|
||||
if 'mask' in cond[1]:
|
||||
if 'mask' in conds:
|
||||
# Scale the mask to the size of the input
|
||||
# The mask should have been resized as we began the sampling process
|
||||
mask_strength = 1.0
|
||||
if "mask_strength" in cond[1]:
|
||||
mask_strength = cond[1]["mask_strength"]
|
||||
mask = cond[1]['mask']
|
||||
if "mask_strength" in conds:
|
||||
mask_strength = conds["mask_strength"]
|
||||
mask = conds['mask']
|
||||
assert(mask.shape[1] == x_in.shape[2])
|
||||
assert(mask.shape[2] == x_in.shape[3])
|
||||
mask = mask[:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]] * mask_strength
|
||||
|
|
@ -51,7 +48,7 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, mod
|
|||
mask = torch.ones_like(input_x)
|
||||
mult = mask * strength
|
||||
|
||||
if 'mask' not in cond[1]:
|
||||
if 'mask' not in conds:
|
||||
rr = 8
|
||||
if area[2] != 0:
|
||||
for t in range(rr):
|
||||
|
|
@ -67,27 +64,17 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, mod
|
|||
mult[:,:,:,area[1] - 1 - t:area[1] - t] *= ((1.0/rr) * (t + 1))
|
||||
|
||||
conditionning = {}
|
||||
conditionning['c_crossattn'] = cond[0]
|
||||
|
||||
if 'concat' in cond[1]:
|
||||
cond_concat_in = cond[1]['concat']
|
||||
if cond_concat_in is not None and len(cond_concat_in) > 0:
|
||||
cropped = []
|
||||
for x in cond_concat_in:
|
||||
cr = x[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]]
|
||||
cropped.append(cr)
|
||||
conditionning['c_concat'] = torch.cat(cropped, dim=1)
|
||||
|
||||
if adm_cond is not None:
|
||||
conditionning['c_adm'] = adm_cond
|
||||
model_conds = conds["model_conds"]
|
||||
for c in model_conds:
|
||||
conditionning[c] = model_conds[c].process_cond(batch_size=x_in.shape[0], device=x_in.device, area=area)
|
||||
|
||||
control = None
|
||||
if 'control' in cond[1]:
|
||||
control = cond[1]['control']
|
||||
if 'control' in conds:
|
||||
control = conds['control']
|
||||
|
||||
patches = None
|
||||
if 'gligen' in cond[1]:
|
||||
gligen = cond[1]['gligen']
|
||||
if 'gligen' in conds:
|
||||
gligen = conds['gligen']
|
||||
patches = {}
|
||||
gligen_type = gligen[0]
|
||||
gligen_model = gligen[1]
|
||||
|
|
@ -105,22 +92,8 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, mod
|
|||
return True
|
||||
if c1.keys() != c2.keys():
|
||||
return False
|
||||
if 'c_crossattn' in c1:
|
||||
s1 = c1['c_crossattn'].shape
|
||||
s2 = c2['c_crossattn'].shape
|
||||
if s1 != s2:
|
||||
if s1[0] != s2[0] or s1[2] != s2[2]: #these 2 cases should not happen
|
||||
return False
|
||||
|
||||
mult_min = lcm(s1[1], s2[1])
|
||||
diff = mult_min // min(s1[1], s2[1])
|
||||
if diff > 4: #arbitrary limit on the padding because it's probably going to impact performance negatively if it's too much
|
||||
return False
|
||||
if 'c_concat' in c1:
|
||||
if c1['c_concat'].shape != c2['c_concat'].shape:
|
||||
return False
|
||||
if 'c_adm' in c1:
|
||||
if c1['c_adm'].shape != c2['c_adm'].shape:
|
||||
for k in c1:
|
||||
if not c1[k].can_concat(c2[k]):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
|
@ -149,31 +122,19 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, mod
|
|||
c_concat = []
|
||||
c_adm = []
|
||||
crossattn_max_len = 0
|
||||
for x in c_list:
|
||||
if 'c_crossattn' in x:
|
||||
c = x['c_crossattn']
|
||||
if crossattn_max_len == 0:
|
||||
crossattn_max_len = c.shape[1]
|
||||
else:
|
||||
crossattn_max_len = lcm(crossattn_max_len, c.shape[1])
|
||||
c_crossattn.append(c)
|
||||
if 'c_concat' in x:
|
||||
c_concat.append(x['c_concat'])
|
||||
if 'c_adm' in x:
|
||||
c_adm.append(x['c_adm'])
|
||||
out = {}
|
||||
c_crossattn_out = []
|
||||
for c in c_crossattn:
|
||||
if c.shape[1] < crossattn_max_len:
|
||||
c = c.repeat(1, crossattn_max_len // c.shape[1], 1) #padding with repeat doesn't change result
|
||||
c_crossattn_out.append(c)
|
||||
|
||||
if len(c_crossattn_out) > 0:
|
||||
out['c_crossattn'] = torch.cat(c_crossattn_out)
|
||||
if len(c_concat) > 0:
|
||||
out['c_concat'] = torch.cat(c_concat)
|
||||
if len(c_adm) > 0:
|
||||
out['c_adm'] = torch.cat(c_adm)
|
||||
temp = {}
|
||||
for x in c_list:
|
||||
for k in x:
|
||||
cur = temp.get(k, [])
|
||||
cur.append(x[k])
|
||||
temp[k] = cur
|
||||
|
||||
out = {}
|
||||
for k in temp:
|
||||
conds = temp[k]
|
||||
out[k] = conds[0].concat(conds[1:])
|
||||
|
||||
return out
|
||||
|
||||
def calc_cond_uncond_batch(model_function, cond, uncond, x_in, timestep, max_total_area, model_options):
|
||||
|
|
@ -389,19 +350,19 @@ def resolve_areas_and_cond_masks(conditions, h, w, device):
|
|||
# While we're doing this, we can also resolve the mask device and scaling for performance reasons
|
||||
for i in range(len(conditions)):
|
||||
c = conditions[i]
|
||||
if 'area' in c[1]:
|
||||
area = c[1]['area']
|
||||
if 'area' in c:
|
||||
area = c['area']
|
||||
if area[0] == "percentage":
|
||||
modified = c[1].copy()
|
||||
modified = c.copy()
|
||||
area = (max(1, round(area[1] * h)), max(1, round(area[2] * w)), round(area[3] * h), round(area[4] * w))
|
||||
modified['area'] = area
|
||||
c = [c[0], modified]
|
||||
c = modified
|
||||
conditions[i] = c
|
||||
|
||||
if 'mask' in c[1]:
|
||||
mask = c[1]['mask']
|
||||
if 'mask' in c:
|
||||
mask = c['mask']
|
||||
mask = mask.to(device=device)
|
||||
modified = c[1].copy()
|
||||
modified = c.copy()
|
||||
if len(mask.shape) == 2:
|
||||
mask = mask.unsqueeze(0)
|
||||
if mask.shape[1] != h or mask.shape[2] != w:
|
||||
|
|
@ -422,37 +383,39 @@ def resolve_areas_and_cond_masks(conditions, h, w, device):
|
|||
modified['area'] = area
|
||||
|
||||
modified['mask'] = mask
|
||||
conditions[i] = [c[0], modified]
|
||||
conditions[i] = modified
|
||||
|
||||
def create_cond_with_same_area_if_none(conds, c):
|
||||
if 'area' not in c[1]:
|
||||
if 'area' not in c:
|
||||
return
|
||||
|
||||
c_area = c[1]['area']
|
||||
c_area = c['area']
|
||||
smallest = None
|
||||
for x in conds:
|
||||
if 'area' in x[1]:
|
||||
a = x[1]['area']
|
||||
if 'area' in x:
|
||||
a = x['area']
|
||||
if c_area[2] >= a[2] and c_area[3] >= a[3]:
|
||||
if a[0] + a[2] >= c_area[0] + c_area[2]:
|
||||
if a[1] + a[3] >= c_area[1] + c_area[3]:
|
||||
if smallest is None:
|
||||
smallest = x
|
||||
elif 'area' not in smallest[1]:
|
||||
elif 'area' not in smallest:
|
||||
smallest = x
|
||||
else:
|
||||
if smallest[1]['area'][0] * smallest[1]['area'][1] > a[0] * a[1]:
|
||||
if smallest['area'][0] * smallest['area'][1] > a[0] * a[1]:
|
||||
smallest = x
|
||||
else:
|
||||
if smallest is None:
|
||||
smallest = x
|
||||
if smallest is None:
|
||||
return
|
||||
if 'area' in smallest[1]:
|
||||
if smallest[1]['area'] == c_area:
|
||||
if 'area' in smallest:
|
||||
if smallest['area'] == c_area:
|
||||
return
|
||||
n = c[1].copy()
|
||||
conds += [[smallest[0], n]]
|
||||
|
||||
out = c.copy()
|
||||
out['model_conds'] = smallest['model_conds'].copy() #TODO: which fields should be copied?
|
||||
conds += [out]
|
||||
|
||||
def calculate_start_end_timesteps(model, conds):
|
||||
for t in range(len(conds)):
|
||||
|
|
@ -460,18 +423,18 @@ def calculate_start_end_timesteps(model, conds):
|
|||
|
||||
timestep_start = None
|
||||
timestep_end = None
|
||||
if 'start_percent' in x[1]:
|
||||
timestep_start = model.sigma_to_t(model.t_to_sigma(torch.tensor(x[1]['start_percent'] * 999.0)))
|
||||
if 'end_percent' in x[1]:
|
||||
timestep_end = model.sigma_to_t(model.t_to_sigma(torch.tensor(x[1]['end_percent'] * 999.0)))
|
||||
if 'start_percent' in x:
|
||||
timestep_start = model.sigma_to_t(model.t_to_sigma(torch.tensor(x['start_percent'] * 999.0)))
|
||||
if 'end_percent' in x:
|
||||
timestep_end = model.sigma_to_t(model.t_to_sigma(torch.tensor(x['end_percent'] * 999.0)))
|
||||
|
||||
if (timestep_start is not None) or (timestep_end is not None):
|
||||
n = x[1].copy()
|
||||
n = x.copy()
|
||||
if (timestep_start is not None):
|
||||
n['timestep_start'] = timestep_start
|
||||
if (timestep_end is not None):
|
||||
n['timestep_end'] = timestep_end
|
||||
conds[t] = [x[0], n]
|
||||
conds[t] = n
|
||||
|
||||
def pre_run_control(model, conds):
|
||||
for t in range(len(conds)):
|
||||
|
|
@ -480,8 +443,8 @@ def pre_run_control(model, conds):
|
|||
timestep_start = None
|
||||
timestep_end = None
|
||||
percent_to_timestep_function = lambda a: model.sigma_to_t(model.t_to_sigma(torch.tensor(a) * 999.0))
|
||||
if 'control' in x[1]:
|
||||
x[1]['control'].pre_run(model.inner_model.inner_model, percent_to_timestep_function)
|
||||
if 'control' in x:
|
||||
x['control'].pre_run(model.inner_model.inner_model, percent_to_timestep_function)
|
||||
|
||||
def apply_empty_x_to_equal_area(conds, uncond, name, uncond_fill_func):
|
||||
cond_cnets = []
|
||||
|
|
@ -490,16 +453,16 @@ def apply_empty_x_to_equal_area(conds, uncond, name, uncond_fill_func):
|
|||
uncond_other = []
|
||||
for t in range(len(conds)):
|
||||
x = conds[t]
|
||||
if 'area' not in x[1]:
|
||||
if name in x[1] and x[1][name] is not None:
|
||||
cond_cnets.append(x[1][name])
|
||||
if 'area' not in x:
|
||||
if name in x and x[name] is not None:
|
||||
cond_cnets.append(x[name])
|
||||
else:
|
||||
cond_other.append((x, t))
|
||||
for t in range(len(uncond)):
|
||||
x = uncond[t]
|
||||
if 'area' not in x[1]:
|
||||
if name in x[1] and x[1][name] is not None:
|
||||
uncond_cnets.append(x[1][name])
|
||||
if 'area' not in x:
|
||||
if name in x and x[name] is not None:
|
||||
uncond_cnets.append(x[name])
|
||||
else:
|
||||
uncond_other.append((x, t))
|
||||
|
||||
|
|
@ -509,47 +472,35 @@ def apply_empty_x_to_equal_area(conds, uncond, name, uncond_fill_func):
|
|||
for x in range(len(cond_cnets)):
|
||||
temp = uncond_other[x % len(uncond_other)]
|
||||
o = temp[0]
|
||||
if name in o[1] and o[1][name] is not None:
|
||||
n = o[1].copy()
|
||||
if name in o and o[name] is not None:
|
||||
n = o.copy()
|
||||
n[name] = uncond_fill_func(cond_cnets, x)
|
||||
uncond += [[o[0], n]]
|
||||
uncond += [n]
|
||||
else:
|
||||
n = o[1].copy()
|
||||
n = o.copy()
|
||||
n[name] = uncond_fill_func(cond_cnets, x)
|
||||
uncond[temp[1]] = [o[0], n]
|
||||
uncond[temp[1]] = n
|
||||
|
||||
def encode_adm(model, conds, batch_size, width, height, device, prompt_type):
|
||||
def encode_model_conds(model_function, conds, noise, device, prompt_type, **kwargs):
|
||||
for t in range(len(conds)):
|
||||
x = conds[t]
|
||||
adm_out = None
|
||||
if 'adm' in x[1]:
|
||||
adm_out = x[1]["adm"]
|
||||
else:
|
||||
params = x[1].copy()
|
||||
params["width"] = params.get("width", width * 8)
|
||||
params["height"] = params.get("height", height * 8)
|
||||
params["prompt_type"] = params.get("prompt_type", prompt_type)
|
||||
adm_out = model.encode_adm(device=device, **params)
|
||||
|
||||
if adm_out is not None:
|
||||
x[1] = x[1].copy()
|
||||
x[1]["adm_encoded"] = fcbh.utils.repeat_to_batch_size(adm_out, batch_size).to(device)
|
||||
|
||||
return conds
|
||||
|
||||
def encode_cond(model_function, key, conds, device, **kwargs):
|
||||
for t in range(len(conds)):
|
||||
x = conds[t]
|
||||
params = x[1].copy()
|
||||
params = x.copy()
|
||||
params["device"] = device
|
||||
params["noise"] = noise
|
||||
params["width"] = params.get("width", noise.shape[3] * 8)
|
||||
params["height"] = params.get("height", noise.shape[2] * 8)
|
||||
params["prompt_type"] = params.get("prompt_type", prompt_type)
|
||||
for k in kwargs:
|
||||
if k not in params:
|
||||
params[k] = kwargs[k]
|
||||
|
||||
out = model_function(**params)
|
||||
if out is not None:
|
||||
x[1] = x[1].copy()
|
||||
x[1][key] = out
|
||||
x = x.copy()
|
||||
model_conds = x['model_conds'].copy()
|
||||
for k in out:
|
||||
model_conds[k] = out[k]
|
||||
x['model_conds'] = model_conds
|
||||
conds[t] = x
|
||||
return conds
|
||||
|
||||
class Sampler:
|
||||
|
|
@ -667,19 +618,15 @@ def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model
|
|||
|
||||
pre_run_control(model_wrap, negative + positive)
|
||||
|
||||
apply_empty_x_to_equal_area(list(filter(lambda c: c[1].get('control_apply_to_uncond', False) == True, positive)), negative, 'control', lambda cond_cnets, x: cond_cnets[x])
|
||||
apply_empty_x_to_equal_area(list(filter(lambda c: c.get('control_apply_to_uncond', False) == True, positive)), negative, 'control', lambda cond_cnets, x: cond_cnets[x])
|
||||
apply_empty_x_to_equal_area(positive, negative, 'gligen', lambda cond_cnets, x: cond_cnets[x])
|
||||
|
||||
if latent_image is not None:
|
||||
latent_image = model.process_latent_in(latent_image)
|
||||
|
||||
if model.is_adm():
|
||||
positive = encode_adm(model, positive, noise.shape[0], noise.shape[3], noise.shape[2], device, "positive")
|
||||
negative = encode_adm(model, negative, noise.shape[0], noise.shape[3], noise.shape[2], device, "negative")
|
||||
|
||||
if hasattr(model, 'cond_concat'):
|
||||
positive = encode_cond(model.cond_concat, "concat", positive, device, noise=noise, latent_image=latent_image, denoise_mask=denoise_mask)
|
||||
negative = encode_cond(model.cond_concat, "concat", negative, device, noise=noise, latent_image=latent_image, denoise_mask=denoise_mask)
|
||||
if hasattr(model, 'extra_conds'):
|
||||
positive = encode_model_conds(model.extra_conds, positive, noise, device, "positive", latent_image=latent_image, denoise_mask=denoise_mask)
|
||||
negative = encode_model_conds(model.extra_conds, negative, noise, device, "negative", latent_image=latent_image, denoise_mask=denoise_mask)
|
||||
|
||||
extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg, "model_options": model_options, "seed":seed}
|
||||
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
version = '2.1.739'
|
||||
version = '2.1.740'
|
||||
|
|
|
|||
|
|
@ -174,7 +174,6 @@ def worker():
|
|||
loras += [(inpaint_patch_model_path, 1.0)]
|
||||
print(f'[Inpaint] Current inpaint model is {inpaint_patch_model_path}')
|
||||
goals.append('inpaint')
|
||||
sampler_name = 'dpmpp_2m_sde_gpu' # only support the patched dpmpp_2m_sde_gpu
|
||||
if current_tab == 'ip' or \
|
||||
advanced_parameters.mixing_image_prompt_and_inpaint or \
|
||||
advanced_parameters.mixing_image_prompt_and_vary_upscale:
|
||||
|
|
|
|||
|
|
@ -342,7 +342,7 @@ def process_diffusion(positive_cond, negative_cond, steps, switch, width, height
|
|||
sigma_max = float(sigma_max.cpu().numpy())
|
||||
print(f'[Sampler] sigma_min = {sigma_min}, sigma_max = {sigma_max}')
|
||||
|
||||
modules.patch.globalBrownianTreeNoiseSampler = BrownianTreeNoiseSampler(
|
||||
modules.patch.BrownianTreeNoiseSamplerPatched.global_init(
|
||||
empty_latent['samples'].to(fcbh.model_management.get_torch_device()),
|
||||
sigma_min, sigma_max, seed=image_seed, cpu=False)
|
||||
|
||||
|
|
|
|||
102
modules/patch.py
102
modules/patch.py
|
|
@ -23,9 +23,10 @@ import args_manager
|
|||
import modules.advanced_parameters as advanced_parameters
|
||||
import warnings
|
||||
import safetensors.torch
|
||||
import modules.constants as constants
|
||||
|
||||
from fcbh.k_diffusion import utils
|
||||
from fcbh.k_diffusion.sampling import trange
|
||||
from fcbh.k_diffusion.sampling import BatchedBrownianTree
|
||||
from fcbh.ldm.modules.diffusionmodules.openaimodel import timestep_embedding, forward_timestep_embed
|
||||
|
||||
|
||||
|
|
@ -280,68 +281,58 @@ def encode_token_weights_patched_with_a1111_method(self, token_weight_pairs):
|
|||
return torch.cat(output, dim=-2).cpu(), first_pooled.cpu()
|
||||
|
||||
|
||||
globalBrownianTreeNoiseSampler = None
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_dpmpp_fooocus_2m_sde_inpaint_seamless(model, x, sigmas, extra_args=None, callback=None,
|
||||
disable=None, eta=1., s_noise=1., **kwargs):
|
||||
print('[Sampler] Fooocus sampler is activated.')
|
||||
|
||||
seed = extra_args.get("seed", None)
|
||||
assert isinstance(seed, int)
|
||||
|
||||
energy_generator = torch.Generator(device='cpu')
|
||||
energy_generator.manual_seed(seed + 1) # avoid bad results by using different seeds.
|
||||
|
||||
def get_energy():
|
||||
return torch.randn(x.size(), dtype=x.dtype, generator=energy_generator, device="cpu").to(x)
|
||||
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
|
||||
old_denoised, h_last, h = None, None, None
|
||||
|
||||
latent_processor = model.inner_model.inner_model.inner_model.process_latent_in
|
||||
inpaint_latent = None
|
||||
inpaint_mask = None
|
||||
|
||||
def patched_KSamplerX0Inpaint_forward(self, x, sigma, uncond, cond, cond_scale, denoise_mask, model_options={}, seed=None):
|
||||
if inpaint_worker.current_task is not None:
|
||||
if getattr(self, 'energy_generator', None) is None:
|
||||
# avoid bad results by using different seeds.
|
||||
self.energy_generator = torch.Generator(device='cpu').manual_seed((seed + 1) % constants.MAX_SEED)
|
||||
|
||||
latent_processor = self.inner_model.inner_model.inner_model.process_latent_in
|
||||
inpaint_latent = latent_processor(inpaint_worker.current_task.latent).to(x)
|
||||
inpaint_mask = inpaint_worker.current_task.latent_mask.to(x)
|
||||
energy_sigma = sigma.reshape([sigma.shape[0]] + [1] * (len(x.shape) - 1))
|
||||
current_energy = torch.randn(x.size(), dtype=x.dtype, generator=self.energy_generator, device="cpu").to(x) * energy_sigma
|
||||
x = x * inpaint_mask + (inpaint_latent + current_energy) * (1.0 - inpaint_mask)
|
||||
|
||||
def blend_latent(a, b, w):
|
||||
return a * w + b * (1 - w)
|
||||
out = self.inner_model(x, sigma,
|
||||
cond=cond,
|
||||
uncond=uncond,
|
||||
cond_scale=cond_scale,
|
||||
model_options=model_options,
|
||||
seed=seed)
|
||||
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
if inpaint_latent is None:
|
||||
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||
else:
|
||||
energy = get_energy() * sigmas[i] + inpaint_latent
|
||||
x_prime = blend_latent(x, energy, inpaint_mask)
|
||||
denoised = model(x_prime, sigmas[i] * s_in, **extra_args)
|
||||
denoised = blend_latent(denoised, inpaint_latent, inpaint_mask)
|
||||
if callback is not None:
|
||||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
||||
if sigmas[i + 1] == 0:
|
||||
x = denoised
|
||||
else:
|
||||
t, s = -sigmas[i].log(), -sigmas[i + 1].log()
|
||||
h = s - t
|
||||
eta_h = eta * h
|
||||
out = out * inpaint_mask + inpaint_latent * (1.0 - inpaint_mask)
|
||||
else:
|
||||
out = self.inner_model(x, sigma,
|
||||
cond=cond,
|
||||
uncond=uncond,
|
||||
cond_scale=cond_scale,
|
||||
model_options=model_options,
|
||||
seed=seed)
|
||||
return out
|
||||
|
||||
x = sigmas[i + 1] / sigmas[i] * (-eta_h).exp() * x + (-h - eta_h).expm1().neg() * denoised
|
||||
if old_denoised is not None:
|
||||
r = h_last / h
|
||||
x = x + 0.5 * (-h - eta_h).expm1().neg() * (1 / r) * (denoised - old_denoised)
|
||||
|
||||
x = x + globalBrownianTreeNoiseSampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (
|
||||
-2 * eta_h).expm1().neg().sqrt() * s_noise
|
||||
class BrownianTreeNoiseSamplerPatched:
|
||||
transform = None
|
||||
tree = None
|
||||
|
||||
old_denoised = denoised
|
||||
h_last = h
|
||||
@staticmethod
|
||||
def global_init(x, sigma_min, sigma_max, seed=None, transform=lambda x: x, cpu=False):
|
||||
t0, t1 = transform(torch.as_tensor(sigma_min)), transform(torch.as_tensor(sigma_max))
|
||||
|
||||
return x
|
||||
BrownianTreeNoiseSamplerPatched.transform = transform
|
||||
BrownianTreeNoiseSamplerPatched.tree = BatchedBrownianTree(x, t0, t1, seed, cpu=cpu)
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def __call__(sigma, sigma_next):
|
||||
transform = BrownianTreeNoiseSamplerPatched.transform
|
||||
tree = BrownianTreeNoiseSamplerPatched.tree
|
||||
|
||||
t0, t1 = transform(torch.as_tensor(sigma)), transform(torch.as_tensor(sigma_next))
|
||||
return tree(t0, t1) / (t1 - t0).abs().sqrt()
|
||||
|
||||
|
||||
def timed_adm(y, timesteps):
|
||||
|
|
@ -523,10 +514,11 @@ def patch_all():
|
|||
fcbh.model_patcher.ModelPatcher.calculate_weight = calculate_weight_patched
|
||||
fcbh.cldm.cldm.ControlNet.forward = patched_cldm_forward
|
||||
fcbh.ldm.modules.diffusionmodules.openaimodel.UNetModel.forward = patched_unet_forward
|
||||
fcbh.k_diffusion.sampling.sample_dpmpp_2m_sde_gpu = sample_dpmpp_fooocus_2m_sde_inpaint_seamless
|
||||
fcbh.k_diffusion.external.DiscreteEpsDDPMDenoiser.forward = patched_discrete_eps_ddpm_denoiser_forward
|
||||
fcbh.model_base.SDXL.encode_adm = sdxl_encode_adm_patched
|
||||
fcbh.sd1_clip.ClipTokenWeightEncoder.encode_token_weights = encode_token_weights_patched_with_a1111_method
|
||||
fcbh.samplers.KSamplerX0Inpaint.forward = patched_KSamplerX0Inpaint_forward
|
||||
fcbh.k_diffusion.sampling.BrownianTreeNoiseSampler = BrownianTreeNoiseSamplerPatched
|
||||
|
||||
warnings.filterwarnings(action='ignore', module='torchsde')
|
||||
|
||||
|
|
|
|||
|
|
@ -3,10 +3,10 @@ import fcbh.samplers
|
|||
import fcbh.model_management
|
||||
|
||||
from fcbh.model_base import SDXLRefiner, SDXL
|
||||
from fcbh.conds import CONDRegular
|
||||
from fcbh.sample import get_additional_models, get_models_from_cond, cleanup_additional_models
|
||||
from fcbh.samplers import resolve_areas_and_cond_masks, wrap_model, calculate_start_end_timesteps, \
|
||||
create_cond_with_same_area_if_none, pre_run_control, apply_empty_x_to_equal_area, encode_adm, \
|
||||
encode_cond
|
||||
create_cond_with_same_area_if_none, pre_run_control, apply_empty_x_to_equal_area, encode_model_conds
|
||||
|
||||
|
||||
current_refiner = None
|
||||
|
|
@ -15,15 +15,13 @@ refiner_switch_step = -1
|
|||
|
||||
@torch.no_grad()
|
||||
@torch.inference_mode()
|
||||
def clip_separate(cond, target_model=None, target_clip=None):
|
||||
c, p = cond[0]
|
||||
def clip_separate_inner(c, p, target_model=None, target_clip=None):
|
||||
if target_model is None or isinstance(target_model, SDXLRefiner):
|
||||
c = c[..., -1280:].clone()
|
||||
p = {"pooled_output": p["pooled_output"].clone()}
|
||||
elif isinstance(target_model, SDXL):
|
||||
c = c.clone()
|
||||
p = {"pooled_output": p["pooled_output"].clone()}
|
||||
else:
|
||||
p = None
|
||||
c = c[..., :768].clone()
|
||||
|
||||
final_layer_norm = target_clip.cond_stage_model.clip_l.transformer.text_model.final_layer_norm
|
||||
|
|
@ -43,9 +41,42 @@ def clip_separate(cond, target_model=None, target_clip=None):
|
|||
|
||||
final_layer_norm.to(device=final_layer_norm_origin_device, dtype=final_layer_norm_origin_dtype)
|
||||
c = c.to(device=c_origin_device, dtype=c_origin_dtype)
|
||||
return c, p
|
||||
|
||||
p = {}
|
||||
return [[c, p]]
|
||||
|
||||
@torch.no_grad()
|
||||
@torch.inference_mode()
|
||||
def clip_separate(cond, target_model=None, target_clip=None):
|
||||
results = []
|
||||
|
||||
for c, px in cond:
|
||||
p = px.get('pooled_output', None)
|
||||
c, p = clip_separate_inner(c, p, target_model=target_model, target_clip=target_clip)
|
||||
p = {} if p is None else {'pooled_output': p.clone()}
|
||||
results.append([c, p])
|
||||
|
||||
return results
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
@torch.inference_mode()
|
||||
def clip_separate_after_preparation(cond, target_model=None, target_clip=None):
|
||||
results = []
|
||||
|
||||
for x in cond:
|
||||
p = x.get('pooled_output', None)
|
||||
c = x['model_conds']['c_crossattn'].cond
|
||||
|
||||
c, p = clip_separate_inner(c, p, target_model=target_model, target_clip=target_clip)
|
||||
|
||||
result = {'model_conds': {'c_crossattn': CONDRegular(c)}}
|
||||
|
||||
if p is not None:
|
||||
result['pooled_output'] = p.clone()
|
||||
|
||||
results.append(result)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
|
|
@ -73,31 +104,24 @@ def sample_hacked(model, noise, positive, negative, cfg, device, sampler, sigmas
|
|||
# pre_run_control(model_wrap, negative + positive)
|
||||
pre_run_control(model_wrap, positive) # negative is not necessary in Fooocus, 0.5s faster.
|
||||
|
||||
apply_empty_x_to_equal_area(list(filter(lambda c: c[1].get('control_apply_to_uncond', False) == True, positive)), negative, 'control', lambda cond_cnets, x: cond_cnets[x])
|
||||
apply_empty_x_to_equal_area(list(filter(lambda c: c.get('control_apply_to_uncond', False) == True, positive)), negative, 'control', lambda cond_cnets, x: cond_cnets[x])
|
||||
apply_empty_x_to_equal_area(positive, negative, 'gligen', lambda cond_cnets, x: cond_cnets[x])
|
||||
|
||||
if latent_image is not None:
|
||||
latent_image = model.process_latent_in(latent_image)
|
||||
|
||||
if model.is_adm():
|
||||
positive = encode_adm(model, positive, noise.shape[0], noise.shape[3], noise.shape[2], device, "positive")
|
||||
negative = encode_adm(model, negative, noise.shape[0], noise.shape[3], noise.shape[2], device, "negative")
|
||||
|
||||
if hasattr(model, 'cond_concat'):
|
||||
positive = encode_cond(model.cond_concat, "concat", positive, device, noise=noise, latent_image=latent_image, denoise_mask=denoise_mask)
|
||||
negative = encode_cond(model.cond_concat, "concat", negative, device, noise=noise, latent_image=latent_image, denoise_mask=denoise_mask)
|
||||
if hasattr(model, 'extra_conds'):
|
||||
positive = encode_model_conds(model.extra_conds, positive, noise, device, "positive", latent_image=latent_image, denoise_mask=denoise_mask)
|
||||
negative = encode_model_conds(model.extra_conds, negative, noise, device, "negative", latent_image=latent_image, denoise_mask=denoise_mask)
|
||||
|
||||
extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg, "model_options": model_options, "seed":seed}
|
||||
|
||||
if current_refiner is not None and current_refiner.model.is_adm():
|
||||
positive_refiner = clip_separate(positive, target_model=current_refiner.model)
|
||||
negative_refiner = clip_separate(negative, target_model=current_refiner.model)
|
||||
if current_refiner is not None and hasattr(current_refiner.model, 'extra_conds'):
|
||||
positive_refiner = clip_separate_after_preparation(positive, target_model=current_refiner.model)
|
||||
negative_refiner = clip_separate_after_preparation(negative, target_model=current_refiner.model)
|
||||
|
||||
positive_refiner = encode_adm(current_refiner.model, positive_refiner, noise.shape[0], noise.shape[3], noise.shape[2], device, "positive")
|
||||
negative_refiner = encode_adm(current_refiner.model, negative_refiner, noise.shape[0], noise.shape[3], noise.shape[2], device, "negative")
|
||||
|
||||
positive_refiner[0][1]['adm_encoded'].to(positive[0][1]['adm_encoded'])
|
||||
negative_refiner[0][1]['adm_encoded'].to(negative[0][1]['adm_encoded'])
|
||||
positive_refiner = encode_model_conds(current_refiner.model.extra_conds, positive_refiner, noise, device, "positive", latent_image=latent_image, denoise_mask=denoise_mask)
|
||||
negative_refiner = encode_model_conds(current_refiner.model.extra_conds, negative_refiner, noise, device, "negative", latent_image=latent_image, denoise_mask=denoise_mask)
|
||||
|
||||
def refiner_switch():
|
||||
cleanup_additional_models(set(get_models_from_cond(positive, "control") + get_models_from_cond(negative, "control")))
|
||||
|
|
|
|||
4
webui.py
4
webui.py
|
|
@ -148,9 +148,9 @@ with shared.gradio_root:
|
|||
|
||||
with gr.TabItem(label='Inpaint or Outpaint (beta)') as inpaint_tab:
|
||||
inpaint_input_image = grh.Image(label='Drag above image to here', source='upload', type='numpy', tool='sketch', height=500, brush_color="#FFFFFF", elem_id='inpaint_canvas')
|
||||
gr.HTML('Outpaint Expansion (<a href="https://github.com/lllyasviel/Fooocus/discussions/414" target="_blank">\U0001F4D4 Document</a>):')
|
||||
gr.HTML('Outpaint Expansion Direction:')
|
||||
outpaint_selections = gr.CheckboxGroup(choices=['Left', 'Right', 'Top', 'Bottom'], value=[], label='Outpaint', show_label=False, container=False)
|
||||
gr.HTML('* \"Inpaint or Outpaint\" is powered by the sampler \"DPMPP Fooocus Seamless 2M SDE Karras Inpaint Sampler\" (beta)')
|
||||
gr.HTML('* Powered by Fooocus Inpaint Engine (beta) <a href="https://github.com/lllyasviel/Fooocus/discussions/414" target="_blank">\U0001F4D4 Document</a>')
|
||||
|
||||
switch_js = "(x) => {if(x){setTimeout(() => window.scrollTo({ top: 850, behavior: 'smooth' }), 50);}else{setTimeout(() => window.scrollTo({ top: 0, behavior: 'smooth' }), 50);} return x}"
|
||||
down_js = "() => {setTimeout(() => window.scrollTo({ top: 850, behavior: 'smooth' }), 50);}"
|
||||
|
|
|
|||
Loading…
Reference in New Issue