Improved the scheduling of ADM guidance and CFG mimicking for better visual quality in high frequency domain and small objects.

lllyasviel 2023-10-03 14:05:14 -07:00 committed by GitHub
parent 2f31d9e5a7
commit 480a7222c5
4 changed files with 40 additions and 39 deletions
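
At a glance, the commit replaces a fixed CFG-rescale blend with a timestep-scheduled one: both ADM (size/crop) guidance and CFG mimicking now interpolate between two variants over the sampling trajectory, using a weight derived from the current timestep. A minimal sketch of that schedule, using the same expression that appears in the UNet hunk below (the printed values are illustrative, not taken from the commit):

import torch

def schedule_weight(timesteps: torch.Tensor) -> torch.Tensor:
    # Same form as in the diff: timesteps run from ~999 (start of sampling) down to 0,
    # so the weight starts near 1.0 and decays quadratically toward 0.0.
    return (timesteps / 999.0) ** 2.0

for step in (999, 500, 100, 10):
    print(step, round(schedule_weight(torch.tensor(float(step))).item(), 4))
# 999 -> 1.0, 500 -> ~0.2505, 100 -> ~0.01, 10 -> ~0.0001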


@@ -1 +1 @@
-version = '2.0.81'
+version = '2.0.82'


@@ -147,27 +147,20 @@ def calculate_weight_patched(self, patches, weight, key):
     return weight
-def get_adaptive_weight_k(cfg_scale):
-    w = float(cfg_scale)
-    w -= 7.0
-    w /= 3.0
-    w = max(w, 0.01)
-    w = min(w, 0.99)
-    return w
-def compute_cfg(uncond, cond, cfg_scale):
+def compute_cfg(uncond, cond, cfg_scale, t):
     global adaptive_cfg
-    k = adaptive_cfg * get_adaptive_weight_k(cfg_scale)
-    x_cfg = uncond + cfg_scale * (cond - uncond)
-    ro_pos = torch.std(cond, dim=(1, 2, 3), keepdim=True)
-    ro_cfg = torch.std(x_cfg, dim=(1, 2, 3), keepdim=True)
-    x_rescaled = x_cfg * (ro_pos / ro_cfg)
-    x_final = k * x_rescaled + (1.0 - k) * x_cfg
-    return x_final
+    mimic_cfg = float(adaptive_cfg)
+    real_cfg = float(cfg_scale)
+    real_eps = uncond + real_cfg * (cond - uncond)
+    if cfg_scale < adaptive_cfg:
+        return real_eps
+    mimicked_eps = uncond + mimic_cfg * (cond - uncond)
+    return real_eps * t + mimicked_eps * (1 - t)
 def patched_sampler_cfg_function(args):
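
To see the new behaviour in isolation, here is a minimal self-contained check with dummy tensors; the function body is copied from the added lines above (minus the global statement), and the shapes and CFG values are made up for illustration:

import torch

adaptive_cfg = 7.0  # the mimicked CFG value (set elsewhere from the UI slider)

def compute_cfg(uncond, cond, cfg_scale, t):
    # Copied from the added lines above.
    mimic_cfg = float(adaptive_cfg)
    real_cfg = float(cfg_scale)
    real_eps = uncond + real_cfg * (cond - uncond)
    if cfg_scale < adaptive_cfg:
        return real_eps
    mimicked_eps = uncond + mimic_cfg * (cond - uncond)
    return real_eps * t + mimicked_eps * (1 - t)

uncond = torch.zeros(1, 4, 8, 8)
cond = torch.ones(1, 4, 8, 8)
# Early step (t ~ 1): effectively plain CFG at the real scale.
print(compute_cfg(uncond, cond, cfg_scale=10.0, t=1.0).mean().item())  # ~10.0
# Late step (t ~ 0): effectively CFG at the mimicked scale.
print(compute_cfg(uncond, cond, cfg_scale=10.0, t=0.0).mean().item())  # ~7.0
# Below the mimicked CFG, the schedule is bypassed entirely.
print(compute_cfg(uncond, cond, cfg_scale=4.0, t=0.5).mean().item())   # ~4.0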
@@ -184,7 +177,7 @@ def patched_sampler_cfg_function(args):
     positive_eps_degraded = anisotropic.adaptive_anisotropic_filter(x=positive_eps, g=positive_x0)
     positive_eps_degraded_weighted = positive_eps_degraded * alpha + positive_eps * (1.0 - alpha)
-    return compute_cfg(uncond=negative_eps, cond=positive_eps_degraded_weighted, cfg_scale=cfg_scale)
+    return compute_cfg(uncond=negative_eps, cond=positive_eps_degraded_weighted, cfg_scale=cfg_scale, t=t)
 def patched_discrete_eps_ddpm_denoiser_forward(self, input, sigma, **kwargs):
@@ -210,10 +203,8 @@ def sdxl_encode_adm_patched(self, **kwargs):
     clip_pooled = comfy.model_base.sdxl_pooled(kwargs, self.noise_augmentor)
     width = kwargs.get("width", 768)
     height = kwargs.get("height", 768)
-    crop_w = kwargs.get("crop_w", 0)
-    crop_h = kwargs.get("crop_h", 0)
-    target_width = kwargs.get("target_width", width)
-    target_height = kwargs.get("target_height", height)
+    target_width = width
+    target_height = height
     if kwargs.get("prompt_type", "") == "negative":
         width = float(width) * negative_adm_scale
@@ -225,20 +216,22 @@ def sdxl_encode_adm_patched(self, **kwargs):
     # Avoid artifacts
     width = int(width)
     height = int(height)
-    crop_w = int(crop_w)
-    crop_h = int(crop_h)
+    crop_w = 0
+    crop_h = 0
     target_width = int(target_width)
     target_height = int(target_height)
-    out = []
-    out.append(self.embedder(torch.Tensor([height])))
-    out.append(self.embedder(torch.Tensor([width])))
-    out.append(self.embedder(torch.Tensor([crop_h])))
-    out.append(self.embedder(torch.Tensor([crop_w])))
-    out.append(self.embedder(torch.Tensor([target_height])))
-    out.append(self.embedder(torch.Tensor([target_width])))
-    flat = torch.flatten(torch.cat(out)).unsqueeze(dim=0).repeat(clip_pooled.shape[0], 1)
-    return torch.cat((clip_pooled.to(flat.device), flat), dim=1)
+    out_a = [self.embedder(torch.Tensor([height])), self.embedder(torch.Tensor([width])),
+             self.embedder(torch.Tensor([crop_h])), self.embedder(torch.Tensor([crop_w])),
+             self.embedder(torch.Tensor([target_height])), self.embedder(torch.Tensor([target_width]))]
+    flat_a = torch.flatten(torch.cat(out_a)).unsqueeze(dim=0).repeat(clip_pooled.shape[0], 1)
+    out_b = [self.embedder(torch.Tensor([target_height])), self.embedder(torch.Tensor([target_width])),
+             self.embedder(torch.Tensor([crop_h])), self.embedder(torch.Tensor([crop_w])),
+             self.embedder(torch.Tensor([target_height])), self.embedder(torch.Tensor([target_width]))]
+    flat_b = torch.flatten(torch.cat(out_b)).unsqueeze(dim=0).repeat(clip_pooled.shape[0], 1)
+    return torch.cat((clip_pooled.to(flat_a.device), flat_a, clip_pooled.to(flat_b.device), flat_b), dim=1)
 def encode_token_weights_patched_with_a1111_method(self, token_weight_pairs):
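
The encoder now returns two ADM variants stacked side by side: flat_a conditions on the actual width/height, flat_b puts target_width/target_height in those slots, and crops are zeroed in both. Assuming SDXL's usual ADM layout (a 1280-dim pooled CLIP vector plus six 256-dim size embeddings, which is what ComfyUI's embedder produces), the widths line up with the constants checked in the UNet hunk below:

# Dimension accounting (assumed SDXL sizes, not stated in this hunk):
pooled = 1280          # clip_pooled
size_embeds = 6 * 256  # height, width, crop_h, crop_w, target_height, target_width
half = pooled + size_embeds
print(half)      # 2816 -> split point used in patched_unet_forward below
print(2 * half)  # 5632 -> width that triggers the new blending branch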
@@ -346,9 +339,12 @@ def patched_unet_forward(self, x, timesteps=None, context=None, y=None, control=
     transformer_options["current_index"] = 0
     transformer_patches = transformer_options.get("patches", {})
     assert (y is not None) == (
         self.num_classes is not None
     ), "must specify y if and only if the model is class-conditional"
+    if isinstance(y, torch.Tensor) and int(y.dim()) == 2 and int(y.shape[1]) == 5632:
+        t = (timesteps / 999.0)[:, None].clone().to(x) ** 2.0
+        ya = y[..., :2816].clone()
+        yb = y[..., 2816:].clone()
+        y = t * ya + (1 - t) * yb
     hs = []
     t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(self.dtype)
     emb = self.time_embed(t_emb)
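
The UNet patch then unpacks that doubled vector per step: early (noisy) steps mostly follow the first half, late steps the second half. A minimal shape check with dummy tensors (names and values are illustrative only):

import torch

batch = 2
y = torch.cat([torch.zeros(batch, 2816), torch.ones(batch, 2816)], dim=1)  # [2, 5632]
timesteps = torch.tensor([999.0, 10.0])
x = torch.zeros(batch, 4, 8, 8)

t = (timesteps / 999.0)[:, None].clone().to(x) ** 2.0  # [2, 1]
ya = y[..., :2816].clone()
yb = y[..., 2816:].clone()
y = t * ya + (1 - t) * yb
print(y.shape)                          # torch.Size([2, 2816])
print(y[0, 0].item(), y[1, 0].item())   # ~0.0 at step 999 (first half), ~1.0 at step 10 (second half)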


@@ -1,5 +1,9 @@
+# 2.0.80
+
+* Improved the scheduling of ADM guidance and CFG mimicking for better visual quality in high frequency domain and small objects.
+
 # 2.0.80
 
 * Rework many patches and some UI details.
 * Speed up processing.
 * Move Colab to independent branch.


@@ -165,8 +165,9 @@ with shared.gradio_root:
                     step=0.001, value=1.5, info='The scaler multiplied to positive ADM (use 1.0 to disable). ')
                 adm_scaler_negative = gr.Slider(label='Negative ADM Guidance Scaler', minimum=0.1, maximum=3.0,
                     step=0.001, value=0.8, info='The scaler multiplied to negative ADM (use 1.0 to disable). ')
-                adaptive_cfg = gr.Slider(label='CFG Rescale from TSNR', minimum=0.0, maximum=1.0,
-                    step=0.001, value=0.3, info='Enabling Fooocus\'s implementation of CFG re-weighting for TSNR (use 0 to disable, more effective when CFG > 7).')
+                adaptive_cfg = gr.Slider(label='CFG Mimicking from TSNR', minimum=1.0, maximum=30.0, step=0.01, value=5.0,
+                    info='Enabling Fooocus\'s implementation of CFG mimicking for TSNR '
+                         '(effective when real CFG > mimicked CFG).')
                 sampler_name = gr.Dropdown(label='Sampler', choices=flags.sampler_list, value=flags.default_sampler, info='Only effective in non-inpaint mode.')
                 def dev_mode_checked(r):
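
Behind the relabelled slider the parameter also changes meaning: the old control was a 0-1 blend weight for the TSNR rescale, while the new one is the mimicked CFG value itself (default 5.0). Tying it back to the compute_cfg hunk above, with illustrative values:

adaptive_cfg = 5.0  # new slider default: the mimicked CFG value
for cfg_scale in (4.0, 7.0, 12.0):
    if cfg_scale < adaptive_cfg:
        print(cfg_scale, '-> plain CFG, mimicking bypassed')
    else:
        print(cfg_scale, '-> scheduled blend between real CFG', cfg_scale, 'and mimicked CFG', adaptive_cfg)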