diff --git a/backend/headless/fcbh/conds.py b/backend/headless/fcbh/conds.py
new file mode 100644
index 00000000..252bb869
--- /dev/null
+++ b/backend/headless/fcbh/conds.py
@@ -0,0 +1,64 @@
+import enum
+import torch
+import math
+import fcbh.utils
+
+
+def lcm(a, b): #TODO: eventually replace by math.lcm (added in python3.9)
+    return abs(a*b) // math.gcd(a, b)
+
+class CONDRegular:
+    def __init__(self, cond):
+        self.cond = cond
+
+    def _copy_with(self, cond):
+        return self.__class__(cond)
+
+    def process_cond(self, batch_size, device, **kwargs):
+        return self._copy_with(fcbh.utils.repeat_to_batch_size(self.cond, batch_size).to(device))
+
+    def can_concat(self, other):
+        if self.cond.shape != other.cond.shape:
+            return False
+        return True
+
+    def concat(self, others):
+        conds = [self.cond]
+        for x in others:
+            conds.append(x.cond)
+        return torch.cat(conds)
+
+class CONDNoiseShape(CONDRegular):
+    def process_cond(self, batch_size, device, area, **kwargs):
+        data = self.cond[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]]
+        return self._copy_with(fcbh.utils.repeat_to_batch_size(data, batch_size).to(device))
+
+
+class CONDCrossAttn(CONDRegular):
+    def can_concat(self, other):
+        s1 = self.cond.shape
+        s2 = other.cond.shape
+        if s1 != s2:
+            if s1[0] != s2[0] or s1[2] != s2[2]: #these 2 cases should not happen
+                return False
+
+            mult_min = lcm(s1[1], s2[1])
+            diff = mult_min // min(s1[1], s2[1])
+            if diff > 4: #arbitrary limit on the padding because it's probably going to impact performance negatively if it's too much
+                return False
+        return True
+
+    def concat(self, others):
+        conds = [self.cond]
+        crossattn_max_len = self.cond.shape[1]
+        for x in others:
+            c = x.cond
+            crossattn_max_len = lcm(crossattn_max_len, c.shape[1])
+            conds.append(c)
+
+        out = []
+        for c in conds:
+            if c.shape[1] < crossattn_max_len:
+                c = c.repeat(1, crossattn_max_len // c.shape[1], 1) #padding with repeat doesn't change result
+            out.append(c)
+        return torch.cat(out)
diff --git a/backend/headless/fcbh/controlnet.py b/backend/headless/fcbh/controlnet.py
index dcdd0c1f..ab6c38f6 100644
--- a/backend/headless/fcbh/controlnet.py
+++ b/backend/headless/fcbh/controlnet.py
@@ -156,7 +156,7 @@ class ControlNet(ControlBase):
 
         context = cond['c_crossattn']
-        y = cond.get('c_adm', None)
+        y = cond.get('y', None)
         if y is not None:
             y = y.to(self.control_model.dtype)
         control = self.control_model(x=x_noisy.to(self.control_model.dtype), hint=self.cond_hint, timesteps=t, context=context.to(self.control_model.dtype), y=y)
diff --git a/backend/headless/fcbh/model_base.py b/backend/headless/fcbh/model_base.py
index f3f708f7..86525d99 100644
--- a/backend/headless/fcbh/model_base.py
+++ b/backend/headless/fcbh/model_base.py
@@ -4,6 +4,7 @@ from fcbh.ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmen
 from fcbh.ldm.modules.diffusionmodules.util import make_beta_schedule
 from fcbh.ldm.modules.diffusionmodules.openaimodel import Timestep
 import fcbh.model_management
+import fcbh.conds
 import numpy as np
 from enum import Enum
 from . import utils
@@ -49,7 +50,7 @@ class BaseModel(torch.nn.Module):
         self.register_buffer('alphas_cumprod', torch.tensor(alphas_cumprod, dtype=torch.float32))
         self.register_buffer('alphas_cumprod_prev', torch.tensor(alphas_cumprod_prev, dtype=torch.float32))
 
-    def apply_model(self, x, t, c_concat=None, c_crossattn=None, c_adm=None, control=None, transformer_options={}):
+    def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
         if c_concat is not None:
             xc = torch.cat([x] + [c_concat], dim=1)
         else:
@@ -59,9 +60,10 @@ class BaseModel(torch.nn.Module):
         xc = xc.to(dtype)
         t = t.to(dtype)
         context = context.to(dtype)
-        if c_adm is not None:
-            c_adm = c_adm.to(dtype)
-        return self.diffusion_model(xc, t, context=context, y=c_adm, control=control, transformer_options=transformer_options).float()
+        extra_conds = {}
+        for o in kwargs:
+            extra_conds[o] = kwargs[o].to(dtype)
+        return self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds).float()
 
     def get_dtype(self):
         return self.diffusion_model.dtype
@@ -72,7 +74,8 @@ class BaseModel(torch.nn.Module):
     def encode_adm(self, **kwargs):
         return None
 
-    def cond_concat(self, **kwargs):
+    def extra_conds(self, **kwargs):
+        out = {}
         if self.inpaint_model:
             concat_keys = ("mask", "masked_image")
             cond_concat = []
@@ -101,8 +104,12 @@ class BaseModel(torch.nn.Module):
                         cond_concat.append(torch.ones_like(noise)[:,:1])
                     elif ck == "masked_image":
                         cond_concat.append(blank_inpaint_image_like(noise))
-            return cond_concat
-        return None
+            data = torch.cat(cond_concat, dim=1)
+            out['c_concat'] = fcbh.conds.CONDNoiseShape(data)
+        adm = self.encode_adm(**kwargs)
+        if adm is not None:
+            out['y'] = fcbh.conds.CONDRegular(adm)
+        return out
 
     def load_model_weights(self, sd, unet_prefix=""):
         to_load = {}
diff --git a/backend/headless/fcbh/sample.py b/backend/headless/fcbh/sample.py
index b6e0fddc..55946160 100644
--- a/backend/headless/fcbh/sample.py
+++ b/backend/headless/fcbh/sample.py
@@ -1,6 +1,7 @@
 import torch
 import fcbh.model_management
 import fcbh.samplers
+import fcbh.conds
 import fcbh.utils
 import math
 import numpy as np
@@ -33,22 +34,24 @@ def prepare_mask(noise_mask, shape, device):
     noise_mask = noise_mask.to(device)
     return noise_mask
 
-def broadcast_cond(cond, batch, device):
-    """broadcasts conditioning to the batch size"""
-    copy = []
-    for p in cond:
-        t = fcbh.utils.repeat_to_batch_size(p[0], batch)
-        t = t.to(device)
-        copy += [[t] + p[1:]]
-    return copy
-
 def get_models_from_cond(cond, model_type):
     models = []
     for c in cond:
-        if model_type in c[1]:
-            models += [c[1][model_type]]
+        if model_type in c:
+            models += [c[model_type]]
     return models
 
+def convert_cond(cond):
+    out = []
+    for c in cond:
+        temp = c[1].copy()
+        model_conds = temp.get("model_conds", {})
+        if c[0] is not None:
+            model_conds["c_crossattn"] = fcbh.conds.CONDCrossAttn(c[0])
+        temp["model_conds"] = model_conds
+        out.append(temp)
+    return out
+
 def get_additional_models(positive, negative, dtype):
     """loads additional models in positive and negative conditioning"""
     control_nets = set(get_models_from_cond(positive, "control") + get_models_from_cond(negative, "control"))
@@ -72,6 +75,8 @@ def cleanup_additional_models(models):
 
 def prepare_sampling(model, noise_shape, positive, negative, noise_mask):
     device = model.load_device
+    positive = convert_cond(positive)
+    negative = convert_cond(negative)
 
     if noise_mask is not None:
         noise_mask = prepare_mask(noise_mask, noise_shape, device)
@@ -81,9 +86,7 @@ def prepare_sampling(model, noise_shape, positive, negative, noise_mask):
     fcbh.model_management.load_models_gpu([model] + models, fcbh.model_management.batch_area_memory(noise_shape[0] * noise_shape[2] * noise_shape[3]) + inference_memory)
     real_model = model.model
 
-    positive_copy = broadcast_cond(positive, noise_shape[0], device)
-    negative_copy = broadcast_cond(negative, noise_shape[0], device)
-    return real_model, positive_copy, negative_copy, noise_mask, models
+    return real_model, positive, negative, noise_mask, models
 
 def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False, noise_mask=None, sigmas=None, callback=None, disable_pbar=False, seed=None):
diff --git a/backend/headless/fcbh/samplers.py b/backend/headless/fcbh/samplers.py
index fe414995..91050a4e 100644
--- a/backend/headless/fcbh/samplers.py
+++ b/backend/headless/fcbh/samplers.py
@@ -2,47 +2,44 @@ from .k_diffusion import sampling as k_diffusion_sampling
 from .k_diffusion import external as k_diffusion_external
 from .extra_samplers import uni_pc
 import torch
+import enum
 from fcbh import model_management
 from .ldm.models.diffusion.ddim import DDIMSampler
 from .ldm.modules.diffusionmodules.util import make_ddim_timesteps
 import math
 from fcbh import model_base
 import fcbh.utils
+import fcbh.conds
 
-def lcm(a, b): #TODO: eventually replace by math.lcm (added in python3.9)
-    return abs(a*b) // math.gcd(a, b)
 
 #The main sampling function shared by all the samplers
 #Returns predicted noise
 def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, model_options={}, seed=None):
-    def get_area_and_mult(cond, x_in, timestep_in):
+    def get_area_and_mult(conds, x_in, timestep_in):
         area = (x_in.shape[2], x_in.shape[3], 0, 0)
         strength = 1.0
-        if 'timestep_start' in cond[1]:
-            timestep_start = cond[1]['timestep_start']
+
+        if 'timestep_start' in conds:
+            timestep_start = conds['timestep_start']
             if timestep_in[0] > timestep_start:
                 return None
-        if 'timestep_end' in cond[1]:
-            timestep_end = cond[1]['timestep_end']
+        if 'timestep_end' in conds:
+            timestep_end = conds['timestep_end']
             if timestep_in[0] < timestep_end:
                 return None
-        if 'area' in cond[1]:
-            area = cond[1]['area']
-        if 'strength' in cond[1]:
-            strength = cond[1]['strength']
-
-        adm_cond = None
-        if 'adm_encoded' in cond[1]:
-            adm_cond = cond[1]['adm_encoded']
+        if 'area' in conds:
+            area = conds['area']
+        if 'strength' in conds:
+            strength = conds['strength']
 
         input_x = x_in[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]]
-        if 'mask' in cond[1]:
+        if 'mask' in conds:
             # Scale the mask to the size of the input
             # The mask should have been resized as we began the sampling process
             mask_strength = 1.0
-            if "mask_strength" in cond[1]:
-                mask_strength = cond[1]["mask_strength"]
-            mask = cond[1]['mask']
+            if "mask_strength" in conds:
+                mask_strength = conds["mask_strength"]
+            mask = conds['mask']
             assert(mask.shape[1] == x_in.shape[2])
             assert(mask.shape[2] == x_in.shape[3])
             mask = mask[:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]] * mask_strength
@@ -51,7 +48,7 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, mod
             mask = torch.ones_like(input_x)
         mult = mask * strength
 
-        if 'mask' not in cond[1]:
+        if 'mask' not in conds:
             rr = 8
             if area[2] != 0:
                 for t in range(rr):
@@ -67,27 +64,17 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, mod
                     mult[:,:,:,area[1] - 1 - t:area[1] - t] *= ((1.0/rr) * (t + 1))
 
         conditionning = {}
-        conditionning['c_crossattn'] = cond[0]
-
-        if 'concat' in cond[1]:
-            cond_concat_in = cond[1]['concat']
-            if cond_concat_in is not None and len(cond_concat_in) > 0:
-                cropped = []
-                for x in cond_concat_in:
-                    cr = x[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]]
-                    cropped.append(cr)
-                conditionning['c_concat'] = torch.cat(cropped, dim=1)
-
-        if adm_cond is not None:
-            conditionning['c_adm'] = adm_cond
+        model_conds = conds["model_conds"]
+        for c in model_conds:
+            conditionning[c] = model_conds[c].process_cond(batch_size=x_in.shape[0], device=x_in.device, area=area)
 
         control = None
-        if 'control' in cond[1]:
-            control = cond[1]['control']
+        if 'control' in conds:
+            control = conds['control']
 
         patches = None
-        if 'gligen' in cond[1]:
-            gligen = cond[1]['gligen']
+        if 'gligen' in conds:
+            gligen = conds['gligen']
             patches = {}
             gligen_type = gligen[0]
             gligen_model = gligen[1]
@@ -105,22 +92,8 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, mod
            return True
        if c1.keys() != c2.keys():
            return False
-        if 'c_crossattn' in c1:
-            s1 = c1['c_crossattn'].shape
-            s2 = c2['c_crossattn'].shape
-            if s1 != s2:
-                if s1[0] != s2[0] or s1[2] != s2[2]: #these 2 cases should not happen
-                    return False
-
-                mult_min = lcm(s1[1], s2[1])
-                diff = mult_min // min(s1[1], s2[1])
-                if diff > 4: #arbitrary limit on the padding because it's probably going to impact performance negatively if it's too much
-                    return False
-        if 'c_concat' in c1:
-            if c1['c_concat'].shape != c2['c_concat'].shape:
-                return False
-        if 'c_adm' in c1:
-            if c1['c_adm'].shape != c2['c_adm'].shape:
+        for k in c1:
+            if not c1[k].can_concat(c2[k]):
                 return False
         return True
 
@@ -149,31 +122,19 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, mod
        c_concat = []
        c_adm = []
        crossattn_max_len = 0
-        for x in c_list:
-            if 'c_crossattn' in x:
-                c = x['c_crossattn']
-                if crossattn_max_len == 0:
-                    crossattn_max_len = c.shape[1]
-                else:
-                    crossattn_max_len = lcm(crossattn_max_len, c.shape[1])
-                c_crossattn.append(c)
-            if 'c_concat' in x:
-                c_concat.append(x['c_concat'])
-            if 'c_adm' in x:
-                c_adm.append(x['c_adm'])
-        out = {}
-        c_crossattn_out = []
-        for c in c_crossattn:
-            if c.shape[1] < crossattn_max_len:
-                c = c.repeat(1, crossattn_max_len // c.shape[1], 1) #padding with repeat doesn't change result
-            c_crossattn_out.append(c)
-        if len(c_crossattn_out) > 0:
-            out['c_crossattn'] = torch.cat(c_crossattn_out)
-        if len(c_concat) > 0:
-            out['c_concat'] = torch.cat(c_concat)
-        if len(c_adm) > 0:
-            out['c_adm'] = torch.cat(c_adm)
+        temp = {}
+        for x in c_list:
+            for k in x:
+                cur = temp.get(k, [])
+                cur.append(x[k])
+                temp[k] = cur
+
+        out = {}
+        for k in temp:
+            conds = temp[k]
+            out[k] = conds[0].concat(conds[1:])
+
        return out
 
    def calc_cond_uncond_batch(model_function, cond, uncond, x_in, timestep, max_total_area, model_options):
@@ -389,19 +350,19 @@ def resolve_areas_and_cond_masks(conditions, h, w, device):
    # While we're doing this, we can also resolve the mask device and scaling for performance reasons
    for i in range(len(conditions)):
        c = conditions[i]
-        if 'area' in c[1]:
-            area = c[1]['area']
+        if 'area' in c:
+            area = c['area']
            if area[0] == "percentage":
-                modified = c[1].copy()
+                modified = c.copy()
                area = (max(1, round(area[1] * h)), max(1, round(area[2] * w)), round(area[3] * h), round(area[4] * w))
                modified['area'] = area
-                c = [c[0], modified]
+                c = modified
                conditions[i] = c
 
-        if 'mask' in c[1]:
-            mask = c[1]['mask']
+        if 'mask' in c:
+            mask = c['mask']
            mask = mask.to(device=device)
-            modified = c[1].copy()
+            modified = c.copy()
            if len(mask.shape) == 2:
                mask = mask.unsqueeze(0)
            if mask.shape[1] != h or mask.shape[2] != w:
@@ -422,37 +383,39 @@ def resolve_areas_and_cond_masks(conditions, h, w, device):
                modified['area'] = area
 
            modified['mask'] = mask
-            conditions[i] = [c[0], modified]
+            conditions[i] = modified
 
 def create_cond_with_same_area_if_none(conds, c):
-    if 'area' not in c[1]:
+    if 'area' not in c:
        return
-    c_area = c[1]['area']
+    c_area = c['area']
    smallest = None
    for x in conds:
-        if 'area' in x[1]:
-            a = x[1]['area']
+        if 'area' in x:
+            a = x['area']
            if c_area[2] >= a[2] and c_area[3] >= a[3]:
                if a[0] + a[2] >= c_area[0] + c_area[2]:
                    if a[1] + a[3] >= c_area[1] + c_area[3]:
                        if smallest is None:
                            smallest = x
-                        elif 'area' not in smallest[1]:
+                        elif 'area' not in smallest:
                            smallest = x
                        else:
-                            if smallest[1]['area'][0] * smallest[1]['area'][1] > a[0] * a[1]:
+                            if smallest['area'][0] * smallest['area'][1] > a[0] * a[1]:
                                smallest = x
        else:
            if smallest is None:
                smallest = x
    if smallest is None:
        return
-    if 'area' in smallest[1]:
-        if smallest[1]['area'] == c_area:
+    if 'area' in smallest:
+        if smallest['area'] == c_area:
            return
-    n = c[1].copy()
-    conds += [[smallest[0], n]]
+
+    out = c.copy()
+    out['model_conds'] = smallest['model_conds'].copy() #TODO: which fields should be copied?
+    conds += [out]
 
 def calculate_start_end_timesteps(model, conds):
    for t in range(len(conds)):
@@ -460,18 +423,18 @@ def calculate_start_end_timesteps(model, conds):
        x = conds[t]
 
        timestep_start = None
        timestep_end = None
-        if 'start_percent' in x[1]:
-            timestep_start = model.sigma_to_t(model.t_to_sigma(torch.tensor(x[1]['start_percent'] * 999.0)))
-        if 'end_percent' in x[1]:
-            timestep_end = model.sigma_to_t(model.t_to_sigma(torch.tensor(x[1]['end_percent'] * 999.0)))
+        if 'start_percent' in x:
+            timestep_start = model.sigma_to_t(model.t_to_sigma(torch.tensor(x['start_percent'] * 999.0)))
+        if 'end_percent' in x:
+            timestep_end = model.sigma_to_t(model.t_to_sigma(torch.tensor(x['end_percent'] * 999.0)))
 
        if (timestep_start is not None) or (timestep_end is not None):
-            n = x[1].copy()
+            n = x.copy()
            if (timestep_start is not None):
                n['timestep_start'] = timestep_start
            if (timestep_end is not None):
                n['timestep_end'] = timestep_end
-            conds[t] = [x[0], n]
+            conds[t] = n
 
 def pre_run_control(model, conds):
    for t in range(len(conds)):
@@ -480,8 +443,8 @@ def pre_run_control(model, conds):
        x = conds[t]
 
        timestep_start = None
        timestep_end = None
        percent_to_timestep_function = lambda a: model.sigma_to_t(model.t_to_sigma(torch.tensor(a) * 999.0))
-        if 'control' in x[1]:
-            x[1]['control'].pre_run(model.inner_model.inner_model, percent_to_timestep_function)
+        if 'control' in x:
+            x['control'].pre_run(model.inner_model.inner_model, percent_to_timestep_function)
 
 def apply_empty_x_to_equal_area(conds, uncond, name, uncond_fill_func):
    cond_cnets = []
@@ -490,16 +453,16 @@ def apply_empty_x_to_equal_area(conds, uncond, name, uncond_fill_func):
    uncond_cnets = []
    cond_other = []
    uncond_other = []
    for t in range(len(conds)):
        x = conds[t]
-        if 'area' not in x[1]:
-            if name in x[1] and x[1][name] is not None:
-                cond_cnets.append(x[1][name])
+        if 'area' not in x:
+            if name in x and x[name] is not None:
+                cond_cnets.append(x[name])
        else:
            cond_other.append((x, t))
    for t in range(len(uncond)):
        x = uncond[t]
-        if 'area' not in x[1]:
-            if name in x[1] and x[1][name] is not None:
-                uncond_cnets.append(x[1][name])
+        if 'area' not in x:
+            if name in x and x[name] is not None:
+                uncond_cnets.append(x[name])
        else:
            uncond_other.append((x, t))
@@ -509,47 +472,35 @@ def apply_empty_x_to_equal_area(conds, uncond, name, uncond_fill_func):
        for x in range(len(cond_cnets)):
            temp = uncond_other[x % len(uncond_other)]
            o = temp[0]
-            if name in o[1] and o[1][name] is not None:
-                n = o[1].copy()
+            if name in o and o[name] is not None:
+                n = o.copy()
                n[name] = uncond_fill_func(cond_cnets, x)
-                uncond += [[o[0], n]]
+                uncond += [n]
            else:
-                n = o[1].copy()
+                n = o.copy()
                n[name] = uncond_fill_func(cond_cnets, x)
-                uncond[temp[1]] = [o[0], n]
+                uncond[temp[1]] = n
 
-def encode_adm(model, conds, batch_size, width, height, device, prompt_type):
+def encode_model_conds(model_function, conds, noise, device, prompt_type, **kwargs):
    for t in range(len(conds)):
        x = conds[t]
-        adm_out = None
-        if 'adm' in x[1]:
-            adm_out = x[1]["adm"]
-        else:
-            params = x[1].copy()
-            params["width"] = params.get("width", width * 8)
-            params["height"] = params.get("height", height * 8)
-            params["prompt_type"] = params.get("prompt_type", prompt_type)
-            adm_out = model.encode_adm(device=device, **params)
-
-        if adm_out is not None:
-            x[1] = x[1].copy()
-            x[1]["adm_encoded"] = fcbh.utils.repeat_to_batch_size(adm_out, batch_size).to(device)
-
-    return conds
-
-def encode_cond(model_function, key, conds, device, **kwargs):
-    for t in range(len(conds)):
-        x = conds[t]
-        params = x[1].copy()
+        params = x.copy()
        params["device"] = device
+        params["noise"] = noise
+        params["width"] = params.get("width", noise.shape[3] * 8)
+        params["height"] = params.get("height", noise.shape[2] * 8)
+        params["prompt_type"] = params.get("prompt_type", prompt_type)
        for k in kwargs:
            if k not in params:
                params[k] = kwargs[k]
        out = model_function(**params)
-        if out is not None:
-            x[1] = x[1].copy()
-            x[1][key] = out
+        x = x.copy()
+        model_conds = x['model_conds'].copy()
+        for k in out:
+            model_conds[k] = out[k]
+        x['model_conds'] = model_conds
+        conds[t] = x
    return conds
 
 class Sampler:
@@ -667,19 +618,15 @@ def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model
 
    pre_run_control(model_wrap, negative + positive)
 
-    apply_empty_x_to_equal_area(list(filter(lambda c: c[1].get('control_apply_to_uncond', False) == True, positive)), negative, 'control', lambda cond_cnets, x: cond_cnets[x])
+    apply_empty_x_to_equal_area(list(filter(lambda c: c.get('control_apply_to_uncond', False) == True, positive)), negative, 'control', lambda cond_cnets, x: cond_cnets[x])
    apply_empty_x_to_equal_area(positive, negative, 'gligen', lambda cond_cnets, x: cond_cnets[x])
 
    if latent_image is not None:
        latent_image = model.process_latent_in(latent_image)
 
-    if model.is_adm():
-        positive = encode_adm(model, positive, noise.shape[0], noise.shape[3], noise.shape[2], device, "positive")
-        negative = encode_adm(model, negative, noise.shape[0], noise.shape[3], noise.shape[2], device, "negative")
-
-    if hasattr(model, 'cond_concat'):
-        positive = encode_cond(model.cond_concat, "concat", positive, device, noise=noise, latent_image=latent_image, denoise_mask=denoise_mask)
-        negative = encode_cond(model.cond_concat, "concat", negative, device, noise=noise, latent_image=latent_image, denoise_mask=denoise_mask)
+    if hasattr(model, 'extra_conds'):
+        positive = encode_model_conds(model.extra_conds, positive, noise, device, "positive", latent_image=latent_image, denoise_mask=denoise_mask)
+        negative = encode_model_conds(model.extra_conds, negative, noise, device, "negative", latent_image=latent_image, denoise_mask=denoise_mask)
 
    extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg, "model_options": model_options, "seed":seed}
diff --git a/fooocus_version.py b/fooocus_version.py
index 0eab2d2d..5f627952 100644
--- a/fooocus_version.py
+++ b/fooocus_version.py
@@ -1 +1 @@
-version = '2.1.739'
+version = '2.1.740'
diff --git a/modules/async_worker.py b/modules/async_worker.py
index 820bdc3c..5fc925ea 100644
--- a/modules/async_worker.py
+++ b/modules/async_worker.py
@@ -174,7 +174,6 @@ def worker():
                loras += [(inpaint_patch_model_path, 1.0)]
                print(f'[Inpaint] Current inpaint model is {inpaint_patch_model_path}')
                goals.append('inpaint')
-                sampler_name = 'dpmpp_2m_sde_gpu'  # only support the patched dpmpp_2m_sde_gpu
            if current_tab == 'ip' or \
                    advanced_parameters.mixing_image_prompt_and_inpaint or \
                    advanced_parameters.mixing_image_prompt_and_vary_upscale:
diff --git a/modules/default_pipeline.py b/modules/default_pipeline.py
index 8557ac2b..b32f8111 100644
--- a/modules/default_pipeline.py
+++ b/modules/default_pipeline.py
@@ -342,7 +342,7 @@ def process_diffusion(positive_cond, negative_cond, steps, switch, width, height
    sigma_max = float(sigma_max.cpu().numpy())
    print(f'[Sampler] sigma_min = {sigma_min}, sigma_max = {sigma_max}')
 
-    modules.patch.globalBrownianTreeNoiseSampler = BrownianTreeNoiseSampler(
+    modules.patch.BrownianTreeNoiseSamplerPatched.global_init(
        empty_latent['samples'].to(fcbh.model_management.get_torch_device()),
        sigma_min, sigma_max, seed=image_seed, cpu=False)
diff --git a/modules/patch.py b/modules/patch.py
index 8303c22a..98b56d5e 100644
--- a/modules/patch.py
+++ b/modules/patch.py
@@ -23,9 +23,10 @@ import args_manager
 import modules.advanced_parameters as advanced_parameters
 import warnings
 import safetensors.torch
+import modules.constants as constants
 
 from fcbh.k_diffusion import utils
-from fcbh.k_diffusion.sampling import trange
+from fcbh.k_diffusion.sampling import BatchedBrownianTree
 from fcbh.ldm.modules.diffusionmodules.openaimodel import timestep_embedding, forward_timestep_embed
@@ -280,68 +281,58 @@ def encode_token_weights_patched_with_a1111_method(self, token_weight_pairs):
    return torch.cat(output, dim=-2).cpu(), first_pooled.cpu()
 
 
-globalBrownianTreeNoiseSampler = None
-
-
-@torch.no_grad()
-def sample_dpmpp_fooocus_2m_sde_inpaint_seamless(model, x, sigmas, extra_args=None, callback=None,
-                                                 disable=None, eta=1., s_noise=1., **kwargs):
-    print('[Sampler] Fooocus sampler is activated.')
-
-    seed = extra_args.get("seed", None)
-    assert isinstance(seed, int)
-
-    energy_generator = torch.Generator(device='cpu')
-    energy_generator.manual_seed(seed + 1)  # avoid bad results by using different seeds.
-
-    def get_energy():
-        return torch.randn(x.size(), dtype=x.dtype, generator=energy_generator, device="cpu").to(x)
-
-    extra_args = {} if extra_args is None else extra_args
-    s_in = x.new_ones([x.shape[0]])
-
-    old_denoised, h_last, h = None, None, None
-
-    latent_processor = model.inner_model.inner_model.inner_model.process_latent_in
-    inpaint_latent = None
-    inpaint_mask = None
-
+def patched_KSamplerX0Inpaint_forward(self, x, sigma, uncond, cond, cond_scale, denoise_mask, model_options={}, seed=None):
    if inpaint_worker.current_task is not None:
+        if getattr(self, 'energy_generator', None) is None:
+            # avoid bad results by using different seeds.
+            self.energy_generator = torch.Generator(device='cpu').manual_seed((seed + 1) % constants.MAX_SEED)
+
+        latent_processor = self.inner_model.inner_model.inner_model.process_latent_in
        inpaint_latent = latent_processor(inpaint_worker.current_task.latent).to(x)
        inpaint_mask = inpaint_worker.current_task.latent_mask.to(x)
+        energy_sigma = sigma.reshape([sigma.shape[0]] + [1] * (len(x.shape) - 1))
+        current_energy = torch.randn(x.size(), dtype=x.dtype, generator=self.energy_generator, device="cpu").to(x) * energy_sigma
+        x = x * inpaint_mask + (inpaint_latent + current_energy) * (1.0 - inpaint_mask)
 
-    def blend_latent(a, b, w):
-        return a * w + b * (1 - w)
+        out = self.inner_model(x, sigma,
+                               cond=cond,
+                               uncond=uncond,
+                               cond_scale=cond_scale,
+                               model_options=model_options,
+                               seed=seed)
 
-    for i in trange(len(sigmas) - 1, disable=disable):
-        if inpaint_latent is None:
-            denoised = model(x, sigmas[i] * s_in, **extra_args)
-        else:
-            energy = get_energy() * sigmas[i] + inpaint_latent
-            x_prime = blend_latent(x, energy, inpaint_mask)
-            denoised = model(x_prime, sigmas[i] * s_in, **extra_args)
-            denoised = blend_latent(denoised, inpaint_latent, inpaint_mask)
-        if callback is not None:
-            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
-        if sigmas[i + 1] == 0:
-            x = denoised
-        else:
-            t, s = -sigmas[i].log(), -sigmas[i + 1].log()
-            h = s - t
-            eta_h = eta * h
+        out = out * inpaint_mask + inpaint_latent * (1.0 - inpaint_mask)
+    else:
+        out = self.inner_model(x, sigma,
+                               cond=cond,
+                               uncond=uncond,
+                               cond_scale=cond_scale,
+                               model_options=model_options,
+                               seed=seed)
+    return out
 
-            x = sigmas[i + 1] / sigmas[i] * (-eta_h).exp() * x + (-h - eta_h).expm1().neg() * denoised
-            if old_denoised is not None:
-                r = h_last / h
-                x = x + 0.5 * (-h - eta_h).expm1().neg() * (1 / r) * (denoised - old_denoised)
-            x = x + globalBrownianTreeNoiseSampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (
-                -2 * eta_h).expm1().neg().sqrt() * s_noise
 
+class BrownianTreeNoiseSamplerPatched:
+    transform = None
+    tree = None
 
-        old_denoised = denoised
-        h_last = h
+    @staticmethod
+    def global_init(x, sigma_min, sigma_max, seed=None, transform=lambda x: x, cpu=False):
+        t0, t1 = transform(torch.as_tensor(sigma_min)), transform(torch.as_tensor(sigma_max))
 
-    return x
+        BrownianTreeNoiseSamplerPatched.transform = transform
+        BrownianTreeNoiseSamplerPatched.tree = BatchedBrownianTree(x, t0, t1, seed, cpu=cpu)
+
+    def __init__(self, *args, **kwargs):
+        pass
+
+    @staticmethod
+    def __call__(sigma, sigma_next):
+        transform = BrownianTreeNoiseSamplerPatched.transform
+        tree = BrownianTreeNoiseSamplerPatched.tree
+
+        t0, t1 = transform(torch.as_tensor(sigma)), transform(torch.as_tensor(sigma_next))
+        return tree(t0, t1) / (t1 - t0).abs().sqrt()
 
 
 def timed_adm(y, timesteps):
@@ -523,10 +514,11 @@ def patch_all():
    fcbh.model_patcher.ModelPatcher.calculate_weight = calculate_weight_patched
    fcbh.cldm.cldm.ControlNet.forward = patched_cldm_forward
    fcbh.ldm.modules.diffusionmodules.openaimodel.UNetModel.forward = patched_unet_forward
-    fcbh.k_diffusion.sampling.sample_dpmpp_2m_sde_gpu = sample_dpmpp_fooocus_2m_sde_inpaint_seamless
    fcbh.k_diffusion.external.DiscreteEpsDDPMDenoiser.forward = patched_discrete_eps_ddpm_denoiser_forward
    fcbh.model_base.SDXL.encode_adm = sdxl_encode_adm_patched
    fcbh.sd1_clip.ClipTokenWeightEncoder.encode_token_weights = encode_token_weights_patched_with_a1111_method
+    fcbh.samplers.KSamplerX0Inpaint.forward = patched_KSamplerX0Inpaint_forward
+    fcbh.k_diffusion.sampling.BrownianTreeNoiseSampler = BrownianTreeNoiseSamplerPatched
 
    warnings.filterwarnings(action='ignore', module='torchsde')
diff --git a/modules/sample_hijack.py b/modules/sample_hijack.py
index bf7ea096..30e47b65 100644
--- a/modules/sample_hijack.py
+++ b/modules/sample_hijack.py
@@ -3,10 +3,10 @@
 import fcbh.samplers
 import fcbh.model_management
 
 from fcbh.model_base import SDXLRefiner, SDXL
+from fcbh.conds import CONDRegular
 from fcbh.sample import get_additional_models, get_models_from_cond, cleanup_additional_models
 from fcbh.samplers import resolve_areas_and_cond_masks, wrap_model, calculate_start_end_timesteps, \
-    create_cond_with_same_area_if_none, pre_run_control, apply_empty_x_to_equal_area, encode_adm, \
-    encode_cond
+    create_cond_with_same_area_if_none, pre_run_control, apply_empty_x_to_equal_area, encode_model_conds
 
 current_refiner = None
@@ -15,15 +15,13 @@
 refiner_switch_step = -1
 
 @torch.no_grad()
 @torch.inference_mode()
-def clip_separate(cond, target_model=None, target_clip=None):
-    c, p = cond[0]
+def clip_separate_inner(c, p, target_model=None, target_clip=None):
    if target_model is None or isinstance(target_model, SDXLRefiner):
        c = c[..., -1280:].clone()
-        p = {"pooled_output": p["pooled_output"].clone()}
    elif isinstance(target_model, SDXL):
        c = c.clone()
-        p = {"pooled_output": p["pooled_output"].clone()}
    else:
+        p = None
        c = c[..., :768].clone()
 
        final_layer_norm = target_clip.cond_stage_model.clip_l.transformer.text_model.final_layer_norm
@@ -43,9 +41,42 @@ def clip_separate(cond, target_model=None, target_clip=None):
        final_layer_norm.to(device=final_layer_norm_origin_device, dtype=final_layer_norm_origin_dtype)
        c = c.to(device=c_origin_device, dtype=c_origin_dtype)
+    return c, p
 
-    p = {}
-    return [[c, p]]
+
+@torch.no_grad()
+@torch.inference_mode()
+def clip_separate(cond, target_model=None, target_clip=None):
+    results = []
+
+    for c, px in cond:
+        p = px.get('pooled_output', None)
+        c, p = clip_separate_inner(c, p, target_model=target_model, target_clip=target_clip)
+        p = {} if p is None else {'pooled_output': p.clone()}
+        results.append([c, p])
+
+    return results
+
+
+@torch.no_grad()
+@torch.inference_mode()
+def clip_separate_after_preparation(cond, target_model=None, target_clip=None):
+    results = []
+
+    for x in cond:
+        p = x.get('pooled_output', None)
+        c = x['model_conds']['c_crossattn'].cond
+
+        c, p = clip_separate_inner(c, p, target_model=target_model, target_clip=target_clip)
+
+        result = {'model_conds': {'c_crossattn': CONDRegular(c)}}
+
+        if p is not None:
+            result['pooled_output'] = p.clone()
+
+        results.append(result)
+
+    return results
 
 
 @torch.no_grad()
@@ -73,31 +104,24 @@ def sample_hacked(model, noise, positive, negative, cfg, device, sampler, sigmas
    # pre_run_control(model_wrap, negative + positive)
    pre_run_control(model_wrap, positive)  # negative is not necessary in Fooocus, 0.5s faster.
 
-    apply_empty_x_to_equal_area(list(filter(lambda c: c[1].get('control_apply_to_uncond', False) == True, positive)), negative, 'control', lambda cond_cnets, x: cond_cnets[x])
+    apply_empty_x_to_equal_area(list(filter(lambda c: c.get('control_apply_to_uncond', False) == True, positive)), negative, 'control', lambda cond_cnets, x: cond_cnets[x])
    apply_empty_x_to_equal_area(positive, negative, 'gligen', lambda cond_cnets, x: cond_cnets[x])
 
    if latent_image is not None:
        latent_image = model.process_latent_in(latent_image)
 
-    if model.is_adm():
-        positive = encode_adm(model, positive, noise.shape[0], noise.shape[3], noise.shape[2], device, "positive")
-        negative = encode_adm(model, negative, noise.shape[0], noise.shape[3], noise.shape[2], device, "negative")
-
-    if hasattr(model, 'cond_concat'):
-        positive = encode_cond(model.cond_concat, "concat", positive, device, noise=noise, latent_image=latent_image, denoise_mask=denoise_mask)
-        negative = encode_cond(model.cond_concat, "concat", negative, device, noise=noise, latent_image=latent_image, denoise_mask=denoise_mask)
+    if hasattr(model, 'extra_conds'):
+        positive = encode_model_conds(model.extra_conds, positive, noise, device, "positive", latent_image=latent_image, denoise_mask=denoise_mask)
+        negative = encode_model_conds(model.extra_conds, negative, noise, device, "negative", latent_image=latent_image, denoise_mask=denoise_mask)
 
    extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg, "model_options": model_options, "seed":seed}
 
-    if current_refiner is not None and current_refiner.model.is_adm():
-        positive_refiner = clip_separate(positive, target_model=current_refiner.model)
-        negative_refiner = clip_separate(negative, target_model=current_refiner.model)
+    if current_refiner is not None and hasattr(current_refiner.model, 'extra_conds'):
+        positive_refiner = clip_separate_after_preparation(positive, target_model=current_refiner.model)
+        negative_refiner = clip_separate_after_preparation(negative, target_model=current_refiner.model)
 
-        positive_refiner = encode_adm(current_refiner.model, positive_refiner, noise.shape[0], noise.shape[3], noise.shape[2], device, "positive")
-        negative_refiner = encode_adm(current_refiner.model, negative_refiner, noise.shape[0], noise.shape[3], noise.shape[2], device, "negative")
-
-        positive_refiner[0][1]['adm_encoded'].to(positive[0][1]['adm_encoded'])
-        negative_refiner[0][1]['adm_encoded'].to(negative[0][1]['adm_encoded'])
+        positive_refiner = encode_model_conds(current_refiner.model.extra_conds, positive_refiner, noise, device, "positive", latent_image=latent_image, denoise_mask=denoise_mask)
+        negative_refiner = encode_model_conds(current_refiner.model.extra_conds, negative_refiner, noise, device, "negative", latent_image=latent_image, denoise_mask=denoise_mask)
 
    def refiner_switch():
        cleanup_additional_models(set(get_models_from_cond(positive, "control") + get_models_from_cond(negative, "control")))
diff --git a/webui.py b/webui.py
index 180825d8..ee0975a9 100644
--- a/webui.py
+++ b/webui.py
@@ -148,9 +148,9 @@ with shared.gradio_root:
                    with gr.TabItem(label='Inpaint or Outpaint (beta)') as inpaint_tab:
                        inpaint_input_image = grh.Image(label='Drag above image to here', source='upload', type='numpy', tool='sketch', height=500, brush_color="#FFFFFF", elem_id='inpaint_canvas')
-                        gr.HTML('Outpaint Expansion (\U0001F4D4 Document):')
+                        gr.HTML('Outpaint Expansion Direction:')
                        outpaint_selections = gr.CheckboxGroup(choices=['Left', 'Right', 'Top', 'Bottom'], value=[], label='Outpaint', show_label=False, container=False)
-                        gr.HTML('* \"Inpaint or Outpaint\" is powered by the sampler \"DPMPP Fooocus Seamless 2M SDE Karras Inpaint Sampler\" (beta)')
+                        gr.HTML('* Powered by Fooocus Inpaint Engine (beta) \U0001F4D4 Document')
 
            switch_js = "(x) => {if(x){setTimeout(() => window.scrollTo({ top: 850, behavior: 'smooth' }), 50);}else{setTimeout(() => window.scrollTo({ top: 0, behavior: 'smooth' }), 50);} return x}"
            down_js = "() => {setTimeout(() => window.scrollTo({ top: 850, behavior: 'smooth' }), 50);}"
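
Notes on the patch (reviewer sketches, not part of the diff):

1. The core change replaces the old conditioning format, a [cross_attn_tensor, options_dict]
   pair per cond, with a flat dict whose tensors live under 'model_conds' as COND* wrappers
   (see fcbh.sample.convert_cond above). A minimal before/after sketch; the cross_attn and
   pooled tensors are made-up stand-ins for real encoder outputs:

    import torch
    import fcbh.conds  # assumes the Fooocus backend is on sys.path

    cross_attn = torch.randn(1, 77, 2048)  # stand-in for a CLIP text encoder output
    pooled = torch.randn(1, 1280)          # stand-in for a pooled embedding

    # Old format, consumed by the removed broadcast_cond()/encode_adm():
    old_cond = [cross_attn, {'pooled_output': pooled, 'strength': 0.8}]

    # New format, as produced by fcbh.sample.convert_cond():
    new_cond = {
        'pooled_output': pooled,
        'strength': 0.8,
        'model_conds': {'c_crossattn': fcbh.conds.CONDCrossAttn(cross_attn)},
    }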
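
2. CONDCrossAttn.concat batches prompts of different token lengths by repeating the shorter
   cond along the token axis up to the lcm of the lengths; repeating a cross-attention cond
   does not change attention results, so the conds become concatenable. The trick can be
   reproduced standalone; the 77/154 token lengths below are illustrative:

    import math
    import torch

    def lcm(a, b):  # mirrors the helper in conds.py; math.lcm exists from Python 3.9
        return abs(a * b) // math.gcd(a, b)

    c1 = torch.randn(1, 77, 2048)   # e.g. a single-chunk prompt
    c2 = torch.randn(1, 154, 2048)  # e.g. a two-chunk prompt

    max_len = lcm(c1.shape[1], c2.shape[1])  # 154
    padded = [c.repeat(1, max_len // c.shape[1], 1) if c.shape[1] < max_len else c
              for c in (c1, c2)]
    batch = torch.cat(padded)  # shape [2, 154, 2048]: both conds now run in one pass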
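
3. In the new sampling path, get_area_and_mult no longer special-cases c_crossattn, c_concat
   and c_adm: it calls process_cond on everything under 'model_conds', which broadcasts each
   tensor to the batch size and, for CONDNoiseShape, crops it to the active area first. A
   condensed sketch of that loop; the shapes and area tuple are illustrative:

    import torch
    import fcbh.conds  # assumes the Fooocus backend is on sys.path

    x_in = torch.randn(2, 4, 128, 128)
    area = (64, 64, 0, 0)  # (height, width, y, x), as in samplers.py

    model_conds = {
        'c_crossattn': fcbh.conds.CONDCrossAttn(torch.randn(1, 77, 2048)),
        'c_concat': fcbh.conds.CONDNoiseShape(torch.randn(1, 5, 128, 128)),
        'y': fcbh.conds.CONDRegular(torch.randn(1, 2816)),
    }

    conditioning = {}
    for k in model_conds:
        # each call returns a new COND* wrapper holding the batched (and, for
        # CONDNoiseShape, area-cropped) tensor
        conditioning[k] = model_conds[k].process_cond(batch_size=x_in.shape[0],
                                                      device=x_in.device, area=area)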
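
4. patched_KSamplerX0Inpaint_forward retires the bespoke dpmpp_2m_sde_gpu sampler: at every
   step the known region is swapped for its clean latent re-noised to the current sigma, the
   model denoises the blend, and the output is blended back. The blending arithmetic in
   isolation (all tensors are toy stand-ins; the patch itself draws the noise from a seeded
   CPU generator):

    import torch

    x = torch.randn(1, 4, 64, 64)               # current noisy latent
    inpaint_latent = torch.randn(1, 4, 64, 64)  # clean latent of the known pixels
    mask = (torch.rand(1, 1, 64, 64) > 0.5).float()  # 1 = region being repainted
    sigma = torch.tensor([7.0])

    energy_sigma = sigma.reshape([sigma.shape[0]] + [1] * (len(x.shape) - 1))
    current_energy = torch.randn_like(x) * energy_sigma

    # the known region (mask == 0) is replaced by its clean latent re-noised to sigma
    x = x * mask + (inpaint_latent + current_energy) * (1.0 - mask)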
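
5. BrownianTreeNoiseSamplerPatched turns the noise sampler into a class-level singleton:
   global_init builds one BatchedBrownianTree for the whole generation, and any instance
   k-diffusion constructs afterwards ignores its arguments and samples from that shared
   tree. A usage sketch under the assumption that the repo (and torchsde) is importable;
   the latent shape, sigma values and seed are made up:

    import torch
    from modules.patch import BrownianTreeNoiseSamplerPatched

    latent = torch.zeros(1, 4, 128, 128)

    # called once per generation (see modules/default_pipeline.py above)
    BrownianTreeNoiseSamplerPatched.global_init(latent, sigma_min=0.0292, sigma_max=14.6146,
                                                seed=12345, cpu=True)

    # k-diffusion later instantiates the sampler per call; the patched constructor is a no-op
    sampler = BrownianTreeNoiseSamplerPatched(latent, 0.0292, 14.6146)
    noise = sampler(torch.tensor(14.6146), torch.tensor(10.0))  # same shape as latent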