diff --git a/backend/headless/fcbh/conds.py b/backend/headless/fcbh/conds.py
new file mode 100644
index 00000000..252bb869
--- /dev/null
+++ b/backend/headless/fcbh/conds.py
@@ -0,0 +1,64 @@
+import enum
+import torch
+import math
+import fcbh.utils
+
+
+def lcm(a, b): #TODO: eventually replace by math.lcm (added in python3.9)
+ return abs(a*b) // math.gcd(a, b)
+
+class CONDRegular:
+ def __init__(self, cond):
+ self.cond = cond
+
+ def _copy_with(self, cond):
+ return self.__class__(cond)
+
+ def process_cond(self, batch_size, device, **kwargs):
+ return self._copy_with(fcbh.utils.repeat_to_batch_size(self.cond, batch_size).to(device))
+
+ def can_concat(self, other):
+ if self.cond.shape != other.cond.shape:
+ return False
+ return True
+
+ def concat(self, others):
+ conds = [self.cond]
+ for x in others:
+ conds.append(x.cond)
+ return torch.cat(conds)
+
+class CONDNoiseShape(CONDRegular):
+ def process_cond(self, batch_size, device, area, **kwargs):
+ data = self.cond[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]]
+ return self._copy_with(fcbh.utils.repeat_to_batch_size(data, batch_size).to(device))
+
+
+class CONDCrossAttn(CONDRegular):
+ def can_concat(self, other):
+ s1 = self.cond.shape
+ s2 = other.cond.shape
+ if s1 != s2:
+ if s1[0] != s2[0] or s1[2] != s2[2]: #these 2 cases should not happen
+ return False
+
+ mult_min = lcm(s1[1], s2[1])
+ diff = mult_min // min(s1[1], s2[1])
+ if diff > 4: #arbitrary limit on the padding because it's probably going to impact performance negatively if it's too much
+ return False
+ return True
+
+ def concat(self, others):
+ conds = [self.cond]
+ crossattn_max_len = self.cond.shape[1]
+ for x in others:
+ c = x.cond
+ crossattn_max_len = lcm(crossattn_max_len, c.shape[1])
+ conds.append(c)
+
+ out = []
+ for c in conds:
+ if c.shape[1] < crossattn_max_len:
+ c = c.repeat(1, crossattn_max_len // c.shape[1], 1) #padding with repeat doesn't change result
+ out.append(c)
+ return torch.cat(out)
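
For context on the new conds.py wrappers above: CONDRegular just repeats its tensor to the batch size, while CONDCrossAttn pads prompt embeddings of different token lengths to their least common multiple before concatenating, so 77- and 154-token conds can still be batched. A minimal sketch of that behavior (shapes are placeholders; run from the repository root so fcbh is importable):

import torch
import fcbh.conds

# CONDRegular: broadcast a single adm/pooled vector to the batch size.
adm = fcbh.conds.CONDRegular(torch.randn(1, 2816))
print(adm.process_cond(batch_size=4, device="cpu").cond.shape)   # torch.Size([4, 2816])

# CONDCrossAttn: a 77-token and a 154-token prompt can be concatenated because the
# lcm-based padding factor is small; the shorter one is repeated, which (as the
# comment in conds.py notes) does not change the cross-attention result.
a = fcbh.conds.CONDCrossAttn(torch.randn(1, 77, 2048))
b = fcbh.conds.CONDCrossAttn(torch.randn(1, 154, 2048))
print(a.can_concat(b))        # True
print(a.concat([b]).shape)    # torch.Size([2, 154, 2048])
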
diff --git a/backend/headless/fcbh/controlnet.py b/backend/headless/fcbh/controlnet.py
index dcdd0c1f..ab6c38f6 100644
--- a/backend/headless/fcbh/controlnet.py
+++ b/backend/headless/fcbh/controlnet.py
@@ -156,7 +156,7 @@ class ControlNet(ControlBase):
context = cond['c_crossattn']
- y = cond.get('c_adm', None)
+ y = cond.get('y', None)
if y is not None:
y = y.to(self.control_model.dtype)
control = self.control_model(x=x_noisy.to(self.control_model.dtype), hint=self.cond_hint, timesteps=t, context=context.to(self.control_model.dtype), y=y)
diff --git a/backend/headless/fcbh/model_base.py b/backend/headless/fcbh/model_base.py
index f3f708f7..86525d99 100644
--- a/backend/headless/fcbh/model_base.py
+++ b/backend/headless/fcbh/model_base.py
@@ -4,6 +4,7 @@ from fcbh.ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmen
from fcbh.ldm.modules.diffusionmodules.util import make_beta_schedule
from fcbh.ldm.modules.diffusionmodules.openaimodel import Timestep
import fcbh.model_management
+import fcbh.conds
import numpy as np
from enum import Enum
from . import utils
@@ -49,7 +50,7 @@ class BaseModel(torch.nn.Module):
self.register_buffer('alphas_cumprod', torch.tensor(alphas_cumprod, dtype=torch.float32))
self.register_buffer('alphas_cumprod_prev', torch.tensor(alphas_cumprod_prev, dtype=torch.float32))
- def apply_model(self, x, t, c_concat=None, c_crossattn=None, c_adm=None, control=None, transformer_options={}):
+ def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
if c_concat is not None:
xc = torch.cat([x] + [c_concat], dim=1)
else:
@@ -59,9 +60,10 @@ class BaseModel(torch.nn.Module):
xc = xc.to(dtype)
t = t.to(dtype)
context = context.to(dtype)
- if c_adm is not None:
- c_adm = c_adm.to(dtype)
- return self.diffusion_model(xc, t, context=context, y=c_adm, control=control, transformer_options=transformer_options).float()
+ extra_conds = {}
+ for o in kwargs:
+ extra_conds[o] = kwargs[o].to(dtype)
+ return self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds).float()
def get_dtype(self):
return self.diffusion_model.dtype
@@ -72,7 +74,8 @@ class BaseModel(torch.nn.Module):
def encode_adm(self, **kwargs):
return None
- def cond_concat(self, **kwargs):
+ def extra_conds(self, **kwargs):
+ out = {}
if self.inpaint_model:
concat_keys = ("mask", "masked_image")
cond_concat = []
@@ -101,8 +104,12 @@ class BaseModel(torch.nn.Module):
cond_concat.append(torch.ones_like(noise)[:,:1])
elif ck == "masked_image":
cond_concat.append(blank_inpaint_image_like(noise))
- return cond_concat
- return None
+ data = torch.cat(cond_concat, dim=1)
+ out['c_concat'] = fcbh.conds.CONDNoiseShape(data)
+ adm = self.encode_adm(**kwargs)
+ if adm is not None:
+ out['y'] = fcbh.conds.CONDRegular(adm)
+ return out
def load_model_weights(self, sd, unet_prefix=""):
to_load = {}
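
The model_base.py change above replaces the fixed c_adm / cond_concat plumbing with an open-ended dict: extra_conds() returns named fcbh.conds wrappers, and apply_model() forwards any extra keys to the UNet as keyword arguments. A purely hypothetical sketch of what that extensibility allows (the subclass and the style_embedding key are illustrative, not part of this patch, and the UNet's forward would need a matching parameter):

import fcbh.conds
from fcbh.model_base import BaseModel

class HypotheticalModel(BaseModel):  # illustrative subclass only
    def extra_conds(self, **kwargs):
        out = super().extra_conds(**kwargs)   # keeps the inpaint c_concat and adm 'y' handling
        emb = kwargs.get("style_embedding", None)
        if emb is not None:
            # apply_model() will cast this tensor to the model dtype and pass it
            # straight through as a keyword argument; no sampler changes needed.
            out["style_embedding"] = fcbh.conds.CONDRegular(emb)
        return out
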
diff --git a/backend/headless/fcbh/sample.py b/backend/headless/fcbh/sample.py
index b6e0fddc..55946160 100644
--- a/backend/headless/fcbh/sample.py
+++ b/backend/headless/fcbh/sample.py
@@ -1,6 +1,7 @@
import torch
import fcbh.model_management
import fcbh.samplers
+import fcbh.conds
import fcbh.utils
import math
import numpy as np
@@ -33,22 +34,24 @@ def prepare_mask(noise_mask, shape, device):
noise_mask = noise_mask.to(device)
return noise_mask
-def broadcast_cond(cond, batch, device):
- """broadcasts conditioning to the batch size"""
- copy = []
- for p in cond:
- t = fcbh.utils.repeat_to_batch_size(p[0], batch)
- t = t.to(device)
- copy += [[t] + p[1:]]
- return copy
-
def get_models_from_cond(cond, model_type):
models = []
for c in cond:
- if model_type in c[1]:
- models += [c[1][model_type]]
+ if model_type in c:
+ models += [c[model_type]]
return models
+def convert_cond(cond):
+ out = []
+ for c in cond:
+ temp = c[1].copy()
+ model_conds = temp.get("model_conds", {})
+ if c[0] is not None:
+ model_conds["c_crossattn"] = fcbh.conds.CONDCrossAttn(c[0])
+ temp["model_conds"] = model_conds
+ out.append(temp)
+ return out
+
def get_additional_models(positive, negative, dtype):
"""loads additional models in positive and negative conditioning"""
control_nets = set(get_models_from_cond(positive, "control") + get_models_from_cond(negative, "control"))
@@ -72,6 +75,8 @@ def cleanup_additional_models(models):
def prepare_sampling(model, noise_shape, positive, negative, noise_mask):
device = model.load_device
+ positive = convert_cond(positive)
+ negative = convert_cond(negative)
if noise_mask is not None:
noise_mask = prepare_mask(noise_mask, noise_shape, device)
@@ -81,9 +86,7 @@ def prepare_sampling(model, noise_shape, positive, negative, noise_mask):
fcbh.model_management.load_models_gpu([model] + models, fcbh.model_management.batch_area_memory(noise_shape[0] * noise_shape[2] * noise_shape[3]) + inference_memory)
real_model = model.model
- positive_copy = broadcast_cond(positive, noise_shape[0], device)
- negative_copy = broadcast_cond(negative, noise_shape[0], device)
- return real_model, positive_copy, negative_copy, noise_mask, models
+ return real_model, positive, negative, noise_mask, models
def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False, noise_mask=None, sigmas=None, callback=None, disable_pbar=False, seed=None):
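
convert_cond() above is the bridge between the old and new conditioning formats: each [cross_attn_tensor, options_dict] pair coming from the CLIP encode step becomes a flat dict whose tensors live under 'model_conds'. A small sketch of that transformation (tensor shapes are placeholders):

import torch
import fcbh.sample

# Old-style conditioning entry: [cross-attention tensor, options dict]
old_positive = [[torch.randn(1, 77, 2048), {"pooled_output": torch.randn(1, 1280)}]]

new_positive = fcbh.sample.convert_cond(old_positive)
entry = new_positive[0]
print(sorted(entry.keys()))                          # ['model_conds', 'pooled_output']
print(type(entry["model_conds"]["c_crossattn"]))     # <class 'fcbh.conds.CONDCrossAttn'>
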
diff --git a/backend/headless/fcbh/samplers.py b/backend/headless/fcbh/samplers.py
index fe414995..91050a4e 100644
--- a/backend/headless/fcbh/samplers.py
+++ b/backend/headless/fcbh/samplers.py
@@ -2,47 +2,44 @@ from .k_diffusion import sampling as k_diffusion_sampling
from .k_diffusion import external as k_diffusion_external
from .extra_samplers import uni_pc
import torch
+import enum
from fcbh import model_management
from .ldm.models.diffusion.ddim import DDIMSampler
from .ldm.modules.diffusionmodules.util import make_ddim_timesteps
import math
from fcbh import model_base
import fcbh.utils
+import fcbh.conds
-def lcm(a, b): #TODO: eventually replace by math.lcm (added in python3.9)
- return abs(a*b) // math.gcd(a, b)
#The main sampling function shared by all the samplers
#Returns predicted noise
def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, model_options={}, seed=None):
- def get_area_and_mult(cond, x_in, timestep_in):
+ def get_area_and_mult(conds, x_in, timestep_in):
area = (x_in.shape[2], x_in.shape[3], 0, 0)
strength = 1.0
- if 'timestep_start' in cond[1]:
- timestep_start = cond[1]['timestep_start']
+
+ if 'timestep_start' in conds:
+ timestep_start = conds['timestep_start']
if timestep_in[0] > timestep_start:
return None
- if 'timestep_end' in cond[1]:
- timestep_end = cond[1]['timestep_end']
+ if 'timestep_end' in conds:
+ timestep_end = conds['timestep_end']
if timestep_in[0] < timestep_end:
return None
- if 'area' in cond[1]:
- area = cond[1]['area']
- if 'strength' in cond[1]:
- strength = cond[1]['strength']
-
- adm_cond = None
- if 'adm_encoded' in cond[1]:
- adm_cond = cond[1]['adm_encoded']
+ if 'area' in conds:
+ area = conds['area']
+ if 'strength' in conds:
+ strength = conds['strength']
input_x = x_in[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]]
- if 'mask' in cond[1]:
+ if 'mask' in conds:
# Scale the mask to the size of the input
# The mask should have been resized as we began the sampling process
mask_strength = 1.0
- if "mask_strength" in cond[1]:
- mask_strength = cond[1]["mask_strength"]
- mask = cond[1]['mask']
+ if "mask_strength" in conds:
+ mask_strength = conds["mask_strength"]
+ mask = conds['mask']
assert(mask.shape[1] == x_in.shape[2])
assert(mask.shape[2] == x_in.shape[3])
mask = mask[:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]] * mask_strength
@@ -51,7 +48,7 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, mod
mask = torch.ones_like(input_x)
mult = mask * strength
- if 'mask' not in cond[1]:
+ if 'mask' not in conds:
rr = 8
if area[2] != 0:
for t in range(rr):
@@ -67,27 +64,17 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, mod
mult[:,:,:,area[1] - 1 - t:area[1] - t] *= ((1.0/rr) * (t + 1))
conditionning = {}
- conditionning['c_crossattn'] = cond[0]
-
- if 'concat' in cond[1]:
- cond_concat_in = cond[1]['concat']
- if cond_concat_in is not None and len(cond_concat_in) > 0:
- cropped = []
- for x in cond_concat_in:
- cr = x[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]]
- cropped.append(cr)
- conditionning['c_concat'] = torch.cat(cropped, dim=1)
-
- if adm_cond is not None:
- conditionning['c_adm'] = adm_cond
+ model_conds = conds["model_conds"]
+ for c in model_conds:
+ conditionning[c] = model_conds[c].process_cond(batch_size=x_in.shape[0], device=x_in.device, area=area)
control = None
- if 'control' in cond[1]:
- control = cond[1]['control']
+ if 'control' in conds:
+ control = conds['control']
patches = None
- if 'gligen' in cond[1]:
- gligen = cond[1]['gligen']
+ if 'gligen' in conds:
+ gligen = conds['gligen']
patches = {}
gligen_type = gligen[0]
gligen_model = gligen[1]
@@ -105,22 +92,8 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, mod
return True
if c1.keys() != c2.keys():
return False
- if 'c_crossattn' in c1:
- s1 = c1['c_crossattn'].shape
- s2 = c2['c_crossattn'].shape
- if s1 != s2:
- if s1[0] != s2[0] or s1[2] != s2[2]: #these 2 cases should not happen
- return False
-
- mult_min = lcm(s1[1], s2[1])
- diff = mult_min // min(s1[1], s2[1])
- if diff > 4: #arbitrary limit on the padding because it's probably going to impact performance negatively if it's too much
- return False
- if 'c_concat' in c1:
- if c1['c_concat'].shape != c2['c_concat'].shape:
- return False
- if 'c_adm' in c1:
- if c1['c_adm'].shape != c2['c_adm'].shape:
+ for k in c1:
+ if not c1[k].can_concat(c2[k]):
return False
return True
@@ -149,31 +122,19 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, mod
c_concat = []
c_adm = []
crossattn_max_len = 0
- for x in c_list:
- if 'c_crossattn' in x:
- c = x['c_crossattn']
- if crossattn_max_len == 0:
- crossattn_max_len = c.shape[1]
- else:
- crossattn_max_len = lcm(crossattn_max_len, c.shape[1])
- c_crossattn.append(c)
- if 'c_concat' in x:
- c_concat.append(x['c_concat'])
- if 'c_adm' in x:
- c_adm.append(x['c_adm'])
- out = {}
- c_crossattn_out = []
- for c in c_crossattn:
- if c.shape[1] < crossattn_max_len:
- c = c.repeat(1, crossattn_max_len // c.shape[1], 1) #padding with repeat doesn't change result
- c_crossattn_out.append(c)
- if len(c_crossattn_out) > 0:
- out['c_crossattn'] = torch.cat(c_crossattn_out)
- if len(c_concat) > 0:
- out['c_concat'] = torch.cat(c_concat)
- if len(c_adm) > 0:
- out['c_adm'] = torch.cat(c_adm)
+ temp = {}
+ for x in c_list:
+ for k in x:
+ cur = temp.get(k, [])
+ cur.append(x[k])
+ temp[k] = cur
+
+ out = {}
+ for k in temp:
+ conds = temp[k]
+ out[k] = conds[0].concat(conds[1:])
+
return out
def calc_cond_uncond_batch(model_function, cond, uncond, x_in, timestep, max_total_area, model_options):
@@ -389,19 +350,19 @@ def resolve_areas_and_cond_masks(conditions, h, w, device):
# While we're doing this, we can also resolve the mask device and scaling for performance reasons
for i in range(len(conditions)):
c = conditions[i]
- if 'area' in c[1]:
- area = c[1]['area']
+ if 'area' in c:
+ area = c['area']
if area[0] == "percentage":
- modified = c[1].copy()
+ modified = c.copy()
area = (max(1, round(area[1] * h)), max(1, round(area[2] * w)), round(area[3] * h), round(area[4] * w))
modified['area'] = area
- c = [c[0], modified]
+ c = modified
conditions[i] = c
- if 'mask' in c[1]:
- mask = c[1]['mask']
+ if 'mask' in c:
+ mask = c['mask']
mask = mask.to(device=device)
- modified = c[1].copy()
+ modified = c.copy()
if len(mask.shape) == 2:
mask = mask.unsqueeze(0)
if mask.shape[1] != h or mask.shape[2] != w:
@@ -422,37 +383,39 @@ def resolve_areas_and_cond_masks(conditions, h, w, device):
modified['area'] = area
modified['mask'] = mask
- conditions[i] = [c[0], modified]
+ conditions[i] = modified
def create_cond_with_same_area_if_none(conds, c):
- if 'area' not in c[1]:
+ if 'area' not in c:
return
- c_area = c[1]['area']
+ c_area = c['area']
smallest = None
for x in conds:
- if 'area' in x[1]:
- a = x[1]['area']
+ if 'area' in x:
+ a = x['area']
if c_area[2] >= a[2] and c_area[3] >= a[3]:
if a[0] + a[2] >= c_area[0] + c_area[2]:
if a[1] + a[3] >= c_area[1] + c_area[3]:
if smallest is None:
smallest = x
- elif 'area' not in smallest[1]:
+ elif 'area' not in smallest:
smallest = x
else:
- if smallest[1]['area'][0] * smallest[1]['area'][1] > a[0] * a[1]:
+ if smallest['area'][0] * smallest['area'][1] > a[0] * a[1]:
smallest = x
else:
if smallest is None:
smallest = x
if smallest is None:
return
- if 'area' in smallest[1]:
- if smallest[1]['area'] == c_area:
+ if 'area' in smallest:
+ if smallest['area'] == c_area:
return
- n = c[1].copy()
- conds += [[smallest[0], n]]
+
+ out = c.copy()
+ out['model_conds'] = smallest['model_conds'].copy() #TODO: which fields should be copied?
+ conds += [out]
def calculate_start_end_timesteps(model, conds):
for t in range(len(conds)):
@@ -460,18 +423,18 @@ def calculate_start_end_timesteps(model, conds):
timestep_start = None
timestep_end = None
- if 'start_percent' in x[1]:
- timestep_start = model.sigma_to_t(model.t_to_sigma(torch.tensor(x[1]['start_percent'] * 999.0)))
- if 'end_percent' in x[1]:
- timestep_end = model.sigma_to_t(model.t_to_sigma(torch.tensor(x[1]['end_percent'] * 999.0)))
+ if 'start_percent' in x:
+ timestep_start = model.sigma_to_t(model.t_to_sigma(torch.tensor(x['start_percent'] * 999.0)))
+ if 'end_percent' in x:
+ timestep_end = model.sigma_to_t(model.t_to_sigma(torch.tensor(x['end_percent'] * 999.0)))
if (timestep_start is not None) or (timestep_end is not None):
- n = x[1].copy()
+ n = x.copy()
if (timestep_start is not None):
n['timestep_start'] = timestep_start
if (timestep_end is not None):
n['timestep_end'] = timestep_end
- conds[t] = [x[0], n]
+ conds[t] = n
def pre_run_control(model, conds):
for t in range(len(conds)):
@@ -480,8 +443,8 @@ def pre_run_control(model, conds):
timestep_start = None
timestep_end = None
percent_to_timestep_function = lambda a: model.sigma_to_t(model.t_to_sigma(torch.tensor(a) * 999.0))
- if 'control' in x[1]:
- x[1]['control'].pre_run(model.inner_model.inner_model, percent_to_timestep_function)
+ if 'control' in x:
+ x['control'].pre_run(model.inner_model.inner_model, percent_to_timestep_function)
def apply_empty_x_to_equal_area(conds, uncond, name, uncond_fill_func):
cond_cnets = []
@@ -490,16 +453,16 @@ def apply_empty_x_to_equal_area(conds, uncond, name, uncond_fill_func):
uncond_other = []
for t in range(len(conds)):
x = conds[t]
- if 'area' not in x[1]:
- if name in x[1] and x[1][name] is not None:
- cond_cnets.append(x[1][name])
+ if 'area' not in x:
+ if name in x and x[name] is not None:
+ cond_cnets.append(x[name])
else:
cond_other.append((x, t))
for t in range(len(uncond)):
x = uncond[t]
- if 'area' not in x[1]:
- if name in x[1] and x[1][name] is not None:
- uncond_cnets.append(x[1][name])
+ if 'area' not in x:
+ if name in x and x[name] is not None:
+ uncond_cnets.append(x[name])
else:
uncond_other.append((x, t))
@@ -509,47 +472,35 @@ def apply_empty_x_to_equal_area(conds, uncond, name, uncond_fill_func):
for x in range(len(cond_cnets)):
temp = uncond_other[x % len(uncond_other)]
o = temp[0]
- if name in o[1] and o[1][name] is not None:
- n = o[1].copy()
+ if name in o and o[name] is not None:
+ n = o.copy()
n[name] = uncond_fill_func(cond_cnets, x)
- uncond += [[o[0], n]]
+ uncond += [n]
else:
- n = o[1].copy()
+ n = o.copy()
n[name] = uncond_fill_func(cond_cnets, x)
- uncond[temp[1]] = [o[0], n]
+ uncond[temp[1]] = n
-def encode_adm(model, conds, batch_size, width, height, device, prompt_type):
+def encode_model_conds(model_function, conds, noise, device, prompt_type, **kwargs):
for t in range(len(conds)):
x = conds[t]
- adm_out = None
- if 'adm' in x[1]:
- adm_out = x[1]["adm"]
- else:
- params = x[1].copy()
- params["width"] = params.get("width", width * 8)
- params["height"] = params.get("height", height * 8)
- params["prompt_type"] = params.get("prompt_type", prompt_type)
- adm_out = model.encode_adm(device=device, **params)
-
- if adm_out is not None:
- x[1] = x[1].copy()
- x[1]["adm_encoded"] = fcbh.utils.repeat_to_batch_size(adm_out, batch_size).to(device)
-
- return conds
-
-def encode_cond(model_function, key, conds, device, **kwargs):
- for t in range(len(conds)):
- x = conds[t]
- params = x[1].copy()
+ params = x.copy()
params["device"] = device
+ params["noise"] = noise
+ params["width"] = params.get("width", noise.shape[3] * 8)
+ params["height"] = params.get("height", noise.shape[2] * 8)
+ params["prompt_type"] = params.get("prompt_type", prompt_type)
for k in kwargs:
if k not in params:
params[k] = kwargs[k]
out = model_function(**params)
- if out is not None:
- x[1] = x[1].copy()
- x[1][key] = out
+ x = x.copy()
+ model_conds = x['model_conds'].copy()
+ for k in out:
+ model_conds[k] = out[k]
+ x['model_conds'] = model_conds
+ conds[t] = x
return conds
class Sampler:
@@ -667,19 +618,15 @@ def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model
pre_run_control(model_wrap, negative + positive)
- apply_empty_x_to_equal_area(list(filter(lambda c: c[1].get('control_apply_to_uncond', False) == True, positive)), negative, 'control', lambda cond_cnets, x: cond_cnets[x])
+ apply_empty_x_to_equal_area(list(filter(lambda c: c.get('control_apply_to_uncond', False) == True, positive)), negative, 'control', lambda cond_cnets, x: cond_cnets[x])
apply_empty_x_to_equal_area(positive, negative, 'gligen', lambda cond_cnets, x: cond_cnets[x])
if latent_image is not None:
latent_image = model.process_latent_in(latent_image)
- if model.is_adm():
- positive = encode_adm(model, positive, noise.shape[0], noise.shape[3], noise.shape[2], device, "positive")
- negative = encode_adm(model, negative, noise.shape[0], noise.shape[3], noise.shape[2], device, "negative")
-
- if hasattr(model, 'cond_concat'):
- positive = encode_cond(model.cond_concat, "concat", positive, device, noise=noise, latent_image=latent_image, denoise_mask=denoise_mask)
- negative = encode_cond(model.cond_concat, "concat", negative, device, noise=noise, latent_image=latent_image, denoise_mask=denoise_mask)
+ if hasattr(model, 'extra_conds'):
+ positive = encode_model_conds(model.extra_conds, positive, noise, device, "positive", latent_image=latent_image, denoise_mask=denoise_mask)
+ negative = encode_model_conds(model.extra_conds, negative, noise, device, "negative", latent_image=latent_image, denoise_mask=denoise_mask)
extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg, "model_options": model_options, "seed":seed}
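
With the samplers.py changes above, get_area_and_mult() no longer special-cases 'concat' and 'adm_encoded'; it calls process_cond() on every entry in model_conds, passing batch size, device and area. For c_concat the per-area crop therefore happens inside CONDNoiseShape. A minimal sketch (area follows the sampler's (height, width, y, x) convention; values are placeholders):

import torch
import fcbh.conds

# e.g. a mask + masked-image latent used as an inpaint concat cond
inpaint_concat = fcbh.conds.CONDNoiseShape(torch.randn(1, 5, 64, 64))

area = (32, 32, 16, 16)   # 32x32 region at offset (16, 16), in latent coordinates
cropped = inpaint_concat.process_cond(batch_size=2, device="cpu", area=area)
print(cropped.cond.shape)  # torch.Size([2, 5, 32, 32])
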
diff --git a/fooocus_version.py b/fooocus_version.py
index 0eab2d2d..5f627952 100644
--- a/fooocus_version.py
+++ b/fooocus_version.py
@@ -1 +1 @@
-version = '2.1.739'
+version = '2.1.740'
diff --git a/modules/async_worker.py b/modules/async_worker.py
index 820bdc3c..5fc925ea 100644
--- a/modules/async_worker.py
+++ b/modules/async_worker.py
@@ -174,7 +174,6 @@ def worker():
loras += [(inpaint_patch_model_path, 1.0)]
print(f'[Inpaint] Current inpaint model is {inpaint_patch_model_path}')
goals.append('inpaint')
- sampler_name = 'dpmpp_2m_sde_gpu' # only support the patched dpmpp_2m_sde_gpu
if current_tab == 'ip' or \
advanced_parameters.mixing_image_prompt_and_inpaint or \
advanced_parameters.mixing_image_prompt_and_vary_upscale:
diff --git a/modules/default_pipeline.py b/modules/default_pipeline.py
index 8557ac2b..b32f8111 100644
--- a/modules/default_pipeline.py
+++ b/modules/default_pipeline.py
@@ -342,7 +342,7 @@ def process_diffusion(positive_cond, negative_cond, steps, switch, width, height
sigma_max = float(sigma_max.cpu().numpy())
print(f'[Sampler] sigma_min = {sigma_min}, sigma_max = {sigma_max}')
- modules.patch.globalBrownianTreeNoiseSampler = BrownianTreeNoiseSampler(
+ modules.patch.BrownianTreeNoiseSamplerPatched.global_init(
empty_latent['samples'].to(fcbh.model_management.get_torch_device()),
sigma_min, sigma_max, seed=image_seed, cpu=False)
diff --git a/modules/patch.py b/modules/patch.py
index 8303c22a..98b56d5e 100644
--- a/modules/patch.py
+++ b/modules/patch.py
@@ -23,9 +23,10 @@ import args_manager
import modules.advanced_parameters as advanced_parameters
import warnings
import safetensors.torch
+import modules.constants as constants
from fcbh.k_diffusion import utils
-from fcbh.k_diffusion.sampling import trange
+from fcbh.k_diffusion.sampling import BatchedBrownianTree
from fcbh.ldm.modules.diffusionmodules.openaimodel import timestep_embedding, forward_timestep_embed
@@ -280,68 +281,58 @@ def encode_token_weights_patched_with_a1111_method(self, token_weight_pairs):
return torch.cat(output, dim=-2).cpu(), first_pooled.cpu()
-globalBrownianTreeNoiseSampler = None
-
-
-@torch.no_grad()
-def sample_dpmpp_fooocus_2m_sde_inpaint_seamless(model, x, sigmas, extra_args=None, callback=None,
- disable=None, eta=1., s_noise=1., **kwargs):
- print('[Sampler] Fooocus sampler is activated.')
-
- seed = extra_args.get("seed", None)
- assert isinstance(seed, int)
-
- energy_generator = torch.Generator(device='cpu')
- energy_generator.manual_seed(seed + 1) # avoid bad results by using different seeds.
-
- def get_energy():
- return torch.randn(x.size(), dtype=x.dtype, generator=energy_generator, device="cpu").to(x)
-
- extra_args = {} if extra_args is None else extra_args
- s_in = x.new_ones([x.shape[0]])
-
- old_denoised, h_last, h = None, None, None
-
- latent_processor = model.inner_model.inner_model.inner_model.process_latent_in
- inpaint_latent = None
- inpaint_mask = None
-
+def patched_KSamplerX0Inpaint_forward(self, x, sigma, uncond, cond, cond_scale, denoise_mask, model_options={}, seed=None):
if inpaint_worker.current_task is not None:
+ if getattr(self, 'energy_generator', None) is None:
+ # avoid bad results by using different seeds.
+ self.energy_generator = torch.Generator(device='cpu').manual_seed((seed + 1) % constants.MAX_SEED)
+
+ latent_processor = self.inner_model.inner_model.inner_model.process_latent_in
inpaint_latent = latent_processor(inpaint_worker.current_task.latent).to(x)
inpaint_mask = inpaint_worker.current_task.latent_mask.to(x)
+ energy_sigma = sigma.reshape([sigma.shape[0]] + [1] * (len(x.shape) - 1))
+ current_energy = torch.randn(x.size(), dtype=x.dtype, generator=self.energy_generator, device="cpu").to(x) * energy_sigma
+ x = x * inpaint_mask + (inpaint_latent + current_energy) * (1.0 - inpaint_mask)
- def blend_latent(a, b, w):
- return a * w + b * (1 - w)
+ out = self.inner_model(x, sigma,
+ cond=cond,
+ uncond=uncond,
+ cond_scale=cond_scale,
+ model_options=model_options,
+ seed=seed)
- for i in trange(len(sigmas) - 1, disable=disable):
- if inpaint_latent is None:
- denoised = model(x, sigmas[i] * s_in, **extra_args)
- else:
- energy = get_energy() * sigmas[i] + inpaint_latent
- x_prime = blend_latent(x, energy, inpaint_mask)
- denoised = model(x_prime, sigmas[i] * s_in, **extra_args)
- denoised = blend_latent(denoised, inpaint_latent, inpaint_mask)
- if callback is not None:
- callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
- if sigmas[i + 1] == 0:
- x = denoised
- else:
- t, s = -sigmas[i].log(), -sigmas[i + 1].log()
- h = s - t
- eta_h = eta * h
+ out = out * inpaint_mask + inpaint_latent * (1.0 - inpaint_mask)
+ else:
+ out = self.inner_model(x, sigma,
+ cond=cond,
+ uncond=uncond,
+ cond_scale=cond_scale,
+ model_options=model_options,
+ seed=seed)
+ return out
- x = sigmas[i + 1] / sigmas[i] * (-eta_h).exp() * x + (-h - eta_h).expm1().neg() * denoised
- if old_denoised is not None:
- r = h_last / h
- x = x + 0.5 * (-h - eta_h).expm1().neg() * (1 / r) * (denoised - old_denoised)
- x = x + globalBrownianTreeNoiseSampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (
- -2 * eta_h).expm1().neg().sqrt() * s_noise
+class BrownianTreeNoiseSamplerPatched:
+ transform = None
+ tree = None
- old_denoised = denoised
- h_last = h
+ @staticmethod
+ def global_init(x, sigma_min, sigma_max, seed=None, transform=lambda x: x, cpu=False):
+ t0, t1 = transform(torch.as_tensor(sigma_min)), transform(torch.as_tensor(sigma_max))
- return x
+ BrownianTreeNoiseSamplerPatched.transform = transform
+ BrownianTreeNoiseSamplerPatched.tree = BatchedBrownianTree(x, t0, t1, seed, cpu=cpu)
+
+ def __init__(self, *args, **kwargs):
+ pass
+
+ @staticmethod
+ def __call__(sigma, sigma_next):
+ transform = BrownianTreeNoiseSamplerPatched.transform
+ tree = BrownianTreeNoiseSamplerPatched.tree
+
+ t0, t1 = transform(torch.as_tensor(sigma)), transform(torch.as_tensor(sigma_next))
+ return tree(t0, t1) / (t1 - t0).abs().sqrt()
def timed_adm(y, timesteps):
@@ -523,10 +514,11 @@ def patch_all():
fcbh.model_patcher.ModelPatcher.calculate_weight = calculate_weight_patched
fcbh.cldm.cldm.ControlNet.forward = patched_cldm_forward
fcbh.ldm.modules.diffusionmodules.openaimodel.UNetModel.forward = patched_unet_forward
- fcbh.k_diffusion.sampling.sample_dpmpp_2m_sde_gpu = sample_dpmpp_fooocus_2m_sde_inpaint_seamless
fcbh.k_diffusion.external.DiscreteEpsDDPMDenoiser.forward = patched_discrete_eps_ddpm_denoiser_forward
fcbh.model_base.SDXL.encode_adm = sdxl_encode_adm_patched
fcbh.sd1_clip.ClipTokenWeightEncoder.encode_token_weights = encode_token_weights_patched_with_a1111_method
+ fcbh.samplers.KSamplerX0Inpaint.forward = patched_KSamplerX0Inpaint_forward
+ fcbh.k_diffusion.sampling.BrownianTreeNoiseSampler = BrownianTreeNoiseSamplerPatched
warnings.filterwarnings(action='ignore', module='torchsde')
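
The patch.py rewrite above drops the bespoke Fooocus sampler in favour of two smaller hooks: KSamplerX0Inpaint.forward now does the inpaint blending inside the regular denoiser wrapper for every sampler, and BrownianTreeNoiseSampler is replaced by a class whose Brownian tree is created once by global_init() and then shared by every instance the k-diffusion samplers construct. A rough usage sketch of the shared noise sampler (sigma values are placeholders, cpu=True is used only so the sketch runs without a GPU, and importing modules.patch assumes the Fooocus environment is set up):

import torch
from modules.patch import BrownianTreeNoiseSamplerPatched

x = torch.zeros(1, 4, 64, 64)
BrownianTreeNoiseSamplerPatched.global_init(x, sigma_min=0.03, sigma_max=14.6,
                                             seed=12345, cpu=True)

# Constructor arguments are ignored; __call__ is static and reads the shared tree,
# so every sampler that builds a "new" noise sampler sees the same noise sequence.
noise_sampler = BrownianTreeNoiseSamplerPatched(x, 0.03, 14.6)
noise = noise_sampler(14.6, 10.0)
print(noise.shape)   # torch.Size([1, 4, 64, 64])
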
diff --git a/modules/sample_hijack.py b/modules/sample_hijack.py
index bf7ea096..30e47b65 100644
--- a/modules/sample_hijack.py
+++ b/modules/sample_hijack.py
@@ -3,10 +3,10 @@ import fcbh.samplers
import fcbh.model_management
from fcbh.model_base import SDXLRefiner, SDXL
+from fcbh.conds import CONDRegular
from fcbh.sample import get_additional_models, get_models_from_cond, cleanup_additional_models
from fcbh.samplers import resolve_areas_and_cond_masks, wrap_model, calculate_start_end_timesteps, \
- create_cond_with_same_area_if_none, pre_run_control, apply_empty_x_to_equal_area, encode_adm, \
- encode_cond
+ create_cond_with_same_area_if_none, pre_run_control, apply_empty_x_to_equal_area, encode_model_conds
current_refiner = None
@@ -15,15 +15,13 @@ refiner_switch_step = -1
@torch.no_grad()
@torch.inference_mode()
-def clip_separate(cond, target_model=None, target_clip=None):
- c, p = cond[0]
+def clip_separate_inner(c, p, target_model=None, target_clip=None):
if target_model is None or isinstance(target_model, SDXLRefiner):
c = c[..., -1280:].clone()
- p = {"pooled_output": p["pooled_output"].clone()}
elif isinstance(target_model, SDXL):
c = c.clone()
- p = {"pooled_output": p["pooled_output"].clone()}
else:
+ p = None
c = c[..., :768].clone()
final_layer_norm = target_clip.cond_stage_model.clip_l.transformer.text_model.final_layer_norm
@@ -43,9 +41,42 @@ def clip_separate(cond, target_model=None, target_clip=None):
final_layer_norm.to(device=final_layer_norm_origin_device, dtype=final_layer_norm_origin_dtype)
c = c.to(device=c_origin_device, dtype=c_origin_dtype)
+ return c, p
- p = {}
- return [[c, p]]
+
+@torch.no_grad()
+@torch.inference_mode()
+def clip_separate(cond, target_model=None, target_clip=None):
+ results = []
+
+ for c, px in cond:
+ p = px.get('pooled_output', None)
+ c, p = clip_separate_inner(c, p, target_model=target_model, target_clip=target_clip)
+ p = {} if p is None else {'pooled_output': p.clone()}
+ results.append([c, p])
+
+ return results
+
+
+@torch.no_grad()
+@torch.inference_mode()
+def clip_separate_after_preparation(cond, target_model=None, target_clip=None):
+ results = []
+
+ for x in cond:
+ p = x.get('pooled_output', None)
+ c = x['model_conds']['c_crossattn'].cond
+
+ c, p = clip_separate_inner(c, p, target_model=target_model, target_clip=target_clip)
+
+ result = {'model_conds': {'c_crossattn': CONDRegular(c)}}
+
+ if p is not None:
+ result['pooled_output'] = p.clone()
+
+ results.append(result)
+
+ return results
@torch.no_grad()
@@ -73,31 +104,24 @@ def sample_hacked(model, noise, positive, negative, cfg, device, sampler, sigmas
# pre_run_control(model_wrap, negative + positive)
pre_run_control(model_wrap, positive) # negative is not necessary in Fooocus, 0.5s faster.
- apply_empty_x_to_equal_area(list(filter(lambda c: c[1].get('control_apply_to_uncond', False) == True, positive)), negative, 'control', lambda cond_cnets, x: cond_cnets[x])
+ apply_empty_x_to_equal_area(list(filter(lambda c: c.get('control_apply_to_uncond', False) == True, positive)), negative, 'control', lambda cond_cnets, x: cond_cnets[x])
apply_empty_x_to_equal_area(positive, negative, 'gligen', lambda cond_cnets, x: cond_cnets[x])
if latent_image is not None:
latent_image = model.process_latent_in(latent_image)
- if model.is_adm():
- positive = encode_adm(model, positive, noise.shape[0], noise.shape[3], noise.shape[2], device, "positive")
- negative = encode_adm(model, negative, noise.shape[0], noise.shape[3], noise.shape[2], device, "negative")
-
- if hasattr(model, 'cond_concat'):
- positive = encode_cond(model.cond_concat, "concat", positive, device, noise=noise, latent_image=latent_image, denoise_mask=denoise_mask)
- negative = encode_cond(model.cond_concat, "concat", negative, device, noise=noise, latent_image=latent_image, denoise_mask=denoise_mask)
+ if hasattr(model, 'extra_conds'):
+ positive = encode_model_conds(model.extra_conds, positive, noise, device, "positive", latent_image=latent_image, denoise_mask=denoise_mask)
+ negative = encode_model_conds(model.extra_conds, negative, noise, device, "negative", latent_image=latent_image, denoise_mask=denoise_mask)
extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg, "model_options": model_options, "seed":seed}
- if current_refiner is not None and current_refiner.model.is_adm():
- positive_refiner = clip_separate(positive, target_model=current_refiner.model)
- negative_refiner = clip_separate(negative, target_model=current_refiner.model)
+ if current_refiner is not None and hasattr(current_refiner.model, 'extra_conds'):
+ positive_refiner = clip_separate_after_preparation(positive, target_model=current_refiner.model)
+ negative_refiner = clip_separate_after_preparation(negative, target_model=current_refiner.model)
- positive_refiner = encode_adm(current_refiner.model, positive_refiner, noise.shape[0], noise.shape[3], noise.shape[2], device, "positive")
- negative_refiner = encode_adm(current_refiner.model, negative_refiner, noise.shape[0], noise.shape[3], noise.shape[2], device, "negative")
-
- positive_refiner[0][1]['adm_encoded'].to(positive[0][1]['adm_encoded'])
- negative_refiner[0][1]['adm_encoded'].to(negative[0][1]['adm_encoded'])
+ positive_refiner = encode_model_conds(current_refiner.model.extra_conds, positive_refiner, noise, device, "positive", latent_image=latent_image, denoise_mask=denoise_mask)
+ negative_refiner = encode_model_conds(current_refiner.model.extra_conds, negative_refiner, noise, device, "negative", latent_image=latent_image, denoise_mask=denoise_mask)
def refiner_switch():
cleanup_additional_models(set(get_models_from_cond(positive, "control") + get_models_from_cond(negative, "control")))
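
clip_separate() still handles raw [tensor, dict] pairs, while the new clip_separate_after_preparation() handles conds that already went through convert_cond and therefore carry 'model_conds'; both share clip_separate_inner. A rough sketch of the second path with placeholder tensors (target_model=None takes the same branch as SDXLRefiner and keeps the last 1280 channels):

import torch
from fcbh.conds import CONDRegular
from modules.sample_hijack import clip_separate_after_preparation

prepared = [{
    "pooled_output": torch.randn(1, 1280),
    "model_conds": {"c_crossattn": CONDRegular(torch.randn(1, 77, 2048))},
}]

out = clip_separate_after_preparation(prepared, target_model=None)
print(out[0]["model_conds"]["c_crossattn"].cond.shape)   # torch.Size([1, 77, 1280])
print(out[0]["pooled_output"].shape)                     # torch.Size([1, 1280])
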
diff --git a/webui.py b/webui.py
index 180825d8..ee0975a9 100644
--- a/webui.py
+++ b/webui.py
@@ -148,9 +148,9 @@ with shared.gradio_root:
with gr.TabItem(label='Inpaint or Outpaint (beta)') as inpaint_tab:
inpaint_input_image = grh.Image(label='Drag above image to here', source='upload', type='numpy', tool='sketch', height=500, brush_color="#FFFFFF", elem_id='inpaint_canvas')
- gr.HTML('Outpaint Expansion (\U0001F4D4 Document):')
+ gr.HTML('Outpaint Expansion Direction:')
outpaint_selections = gr.CheckboxGroup(choices=['Left', 'Right', 'Top', 'Bottom'], value=[], label='Outpaint', show_label=False, container=False)
- gr.HTML('* \"Inpaint or Outpaint\" is powered by the sampler \"DPMPP Fooocus Seamless 2M SDE Karras Inpaint Sampler\" (beta)')
+ gr.HTML('* Powered by Fooocus Inpaint Engine (beta) \U0001F4D4 Document')
switch_js = "(x) => {if(x){setTimeout(() => window.scrollTo({ top: 850, behavior: 'smooth' }), 50);}else{setTimeout(() => window.scrollTo({ top: 0, behavior: 'smooth' }), 50);} return x}"
down_js = "() => {setTimeout(() => window.scrollTo({ top: 850, behavior: 'smooth' }), 50);}"