From 480a7222c518bc4df18d63d890698d3e55962d8d Mon Sep 17 00:00:00 2001 From: lllyasviel Date: Tue, 3 Oct 2023 14:05:14 -0700 Subject: [PATCH] Improved the scheduling of ADM guidance and CFG mimicking for better visual quality in high frequency domain and small objects. Improved the scheduling of ADM guidance and CFG mimicking for better visual quality in high frequency domain and small objects. --- fooocus_version.py | 2 +- modules/patch.py | 68 ++++++++++++++++++++++------------------------ update_log.md | 4 +++ webui.py | 5 ++-- 4 files changed, 40 insertions(+), 39 deletions(-) diff --git a/fooocus_version.py b/fooocus_version.py index 7230dbec..ad909872 100644 --- a/fooocus_version.py +++ b/fooocus_version.py @@ -1 +1 @@ -version = '2.0.81' +version = '2.0.82' diff --git a/modules/patch.py b/modules/patch.py index 374eb9d6..55badb43 100644 --- a/modules/patch.py +++ b/modules/patch.py @@ -147,27 +147,20 @@ def calculate_weight_patched(self, patches, weight, key): return weight -def get_adaptive_weight_k(cfg_scale): - w = float(cfg_scale) - w -= 7.0 - w /= 3.0 - w = max(w, 0.01) - w = min(w, 0.99) - return w - - -def compute_cfg(uncond, cond, cfg_scale): +def compute_cfg(uncond, cond, cfg_scale, t): global adaptive_cfg - k = adaptive_cfg * get_adaptive_weight_k(cfg_scale) - x_cfg = uncond + cfg_scale * (cond - uncond) - ro_pos = torch.std(cond, dim=(1, 2, 3), keepdim=True) - ro_cfg = torch.std(x_cfg, dim=(1, 2, 3), keepdim=True) + mimic_cfg = float(adaptive_cfg) + real_cfg = float(cfg_scale) - x_rescaled = x_cfg * (ro_pos / ro_cfg) - x_final = k * x_rescaled + (1.0 - k) * x_cfg + real_eps = uncond + real_cfg * (cond - uncond) - return x_final + if cfg_scale < adaptive_cfg: + return real_eps + + mimicked_eps = uncond + mimic_cfg * (cond - uncond) + + return real_eps * t + mimicked_eps * (1 - t) def patched_sampler_cfg_function(args): @@ -184,7 +177,7 @@ def patched_sampler_cfg_function(args): positive_eps_degraded = anisotropic.adaptive_anisotropic_filter(x=positive_eps, g=positive_x0) positive_eps_degraded_weighted = positive_eps_degraded * alpha + positive_eps * (1.0 - alpha) - return compute_cfg(uncond=negative_eps, cond=positive_eps_degraded_weighted, cfg_scale=cfg_scale) + return compute_cfg(uncond=negative_eps, cond=positive_eps_degraded_weighted, cfg_scale=cfg_scale, t=t) def patched_discrete_eps_ddpm_denoiser_forward(self, input, sigma, **kwargs): @@ -210,10 +203,8 @@ def sdxl_encode_adm_patched(self, **kwargs): clip_pooled = comfy.model_base.sdxl_pooled(kwargs, self.noise_augmentor) width = kwargs.get("width", 768) height = kwargs.get("height", 768) - crop_w = kwargs.get("crop_w", 0) - crop_h = kwargs.get("crop_h", 0) - target_width = kwargs.get("target_width", width) - target_height = kwargs.get("target_height", height) + target_width = width + target_height = height if kwargs.get("prompt_type", "") == "negative": width = float(width) * negative_adm_scale @@ -225,20 +216,22 @@ def sdxl_encode_adm_patched(self, **kwargs): # Avoid artifacts width = int(width) height = int(height) - crop_w = int(crop_w) - crop_h = int(crop_h) + crop_w = 0 + crop_h = 0 target_width = int(target_width) target_height = int(target_height) - out = [] - out.append(self.embedder(torch.Tensor([height]))) - out.append(self.embedder(torch.Tensor([width]))) - out.append(self.embedder(torch.Tensor([crop_h]))) - out.append(self.embedder(torch.Tensor([crop_w]))) - out.append(self.embedder(torch.Tensor([target_height]))) - out.append(self.embedder(torch.Tensor([target_width]))) - flat = torch.flatten(torch.cat(out)).unsqueeze(dim=0).repeat(clip_pooled.shape[0], 1) - return torch.cat((clip_pooled.to(flat.device), flat), dim=1) + out_a = [self.embedder(torch.Tensor([height])), self.embedder(torch.Tensor([width])), + self.embedder(torch.Tensor([crop_h])), self.embedder(torch.Tensor([crop_w])), + self.embedder(torch.Tensor([target_height])), self.embedder(torch.Tensor([target_width]))] + flat_a = torch.flatten(torch.cat(out_a)).unsqueeze(dim=0).repeat(clip_pooled.shape[0], 1) + + out_b = [self.embedder(torch.Tensor([target_height])), self.embedder(torch.Tensor([target_width])), + self.embedder(torch.Tensor([crop_h])), self.embedder(torch.Tensor([crop_w])), + self.embedder(torch.Tensor([target_height])), self.embedder(torch.Tensor([target_width]))] + flat_b = torch.flatten(torch.cat(out_b)).unsqueeze(dim=0).repeat(clip_pooled.shape[0], 1) + + return torch.cat((clip_pooled.to(flat_a.device), flat_a, clip_pooled.to(flat_b.device), flat_b), dim=1) def encode_token_weights_patched_with_a1111_method(self, token_weight_pairs): @@ -346,9 +339,12 @@ def patched_unet_forward(self, x, timesteps=None, context=None, y=None, control= transformer_options["current_index"] = 0 transformer_patches = transformer_options.get("patches", {}) - assert (y is not None) == ( - self.num_classes is not None - ), "must specify y if and only if the model is class-conditional" + if isinstance(y, torch.Tensor) and int(y.dim()) == 2 and int(y.shape[1]) == 5632: + t = (timesteps / 999.0)[:, None].clone().to(x) ** 2.0 + ya = y[..., :2816].clone() + yb = y[..., 2816:].clone() + y = t * ya + (1 - t) * yb + hs = [] t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(self.dtype) emb = self.time_embed(t_emb) diff --git a/update_log.md b/update_log.md index 6fcb3f89..8df10931 100644 --- a/update_log.md +++ b/update_log.md @@ -1,5 +1,9 @@ # 2.0.80 +* Improved the scheduling of ADM guidance and CFG mimicking for better visual quality in high frequency domain and small objects. + +# 2.0.80 + * Rework many patches and some UI details. * Speed up processing. * Move Colab to independent branch. diff --git a/webui.py b/webui.py index 2f7f1d5c..dd33442e 100644 --- a/webui.py +++ b/webui.py @@ -165,8 +165,9 @@ with shared.gradio_root: step=0.001, value=1.5, info='The scaler multiplied to positive ADM (use 1.0 to disable). ') adm_scaler_negative = gr.Slider(label='Negative ADM Guidance Scaler', minimum=0.1, maximum=3.0, step=0.001, value=0.8, info='The scaler multiplied to negative ADM (use 1.0 to disable). ') - adaptive_cfg = gr.Slider(label='CFG Rescale from TSNR', minimum=0.0, maximum=1.0, - step=0.001, value=0.3, info='Enabling Fooocus\'s implementation of CFG re-weighting for TSNR (use 0 to disable, more effective when CFG > 7).') + adaptive_cfg = gr.Slider(label='CFG Mimicking from TSNR', minimum=1.0, maximum=30.0, step=0.01, value=5.0, + info='Enabling Fooocus\'s implementation of CFG mimicking for TSNR ' + '(effective when real CFG > mimicked CFG).') sampler_name = gr.Dropdown(label='Sampler', choices=flags.sampler_list, value=flags.default_sampler, info='Only effective in non-inpaint mode.') def dev_mode_checked(r):