diff --git a/fooocus_version.py b/fooocus_version.py
index 7230dbec..ad909872 100644
--- a/fooocus_version.py
+++ b/fooocus_version.py
@@ -1 +1 @@
-version = '2.0.81'
+version = '2.0.82'
diff --git a/modules/patch.py b/modules/patch.py
index 374eb9d6..55badb43 100644
--- a/modules/patch.py
+++ b/modules/patch.py
@@ -147,27 +147,20 @@ def calculate_weight_patched(self, patches, weight, key):
     return weight
 
 
-def get_adaptive_weight_k(cfg_scale):
-    w = float(cfg_scale)
-    w -= 7.0
-    w /= 3.0
-    w = max(w, 0.01)
-    w = min(w, 0.99)
-    return w
-
-
-def compute_cfg(uncond, cond, cfg_scale):
+def compute_cfg(uncond, cond, cfg_scale, t):
     global adaptive_cfg
 
-    k = adaptive_cfg * get_adaptive_weight_k(cfg_scale)
-    x_cfg = uncond + cfg_scale * (cond - uncond)
-    ro_pos = torch.std(cond, dim=(1, 2, 3), keepdim=True)
-    ro_cfg = torch.std(x_cfg, dim=(1, 2, 3), keepdim=True)
+    mimic_cfg = float(adaptive_cfg)
+    real_cfg = float(cfg_scale)
 
-    x_rescaled = x_cfg * (ro_pos / ro_cfg)
-    x_final = k * x_rescaled + (1.0 - k) * x_cfg
+    real_eps = uncond + real_cfg * (cond - uncond)
 
-    return x_final
+    if cfg_scale < adaptive_cfg:
+        return real_eps
+
+    mimicked_eps = uncond + mimic_cfg * (cond - uncond)
+
+    return real_eps * t + mimicked_eps * (1 - t)
 
 
 def patched_sampler_cfg_function(args):
@@ -184,7 +177,7 @@ def patched_sampler_cfg_function(args):
     positive_eps_degraded = anisotropic.adaptive_anisotropic_filter(x=positive_eps, g=positive_x0)
     positive_eps_degraded_weighted = positive_eps_degraded * alpha + positive_eps * (1.0 - alpha)
 
-    return compute_cfg(uncond=negative_eps, cond=positive_eps_degraded_weighted, cfg_scale=cfg_scale)
+    return compute_cfg(uncond=negative_eps, cond=positive_eps_degraded_weighted, cfg_scale=cfg_scale, t=t)
 
 
 def patched_discrete_eps_ddpm_denoiser_forward(self, input, sigma, **kwargs):
@@ -210,10 +203,8 @@ def sdxl_encode_adm_patched(self, **kwargs):
     clip_pooled = comfy.model_base.sdxl_pooled(kwargs, self.noise_augmentor)
     width = kwargs.get("width", 768)
     height = kwargs.get("height", 768)
-    crop_w = kwargs.get("crop_w", 0)
-    crop_h = kwargs.get("crop_h", 0)
-    target_width = kwargs.get("target_width", width)
-    target_height = kwargs.get("target_height", height)
+    target_width = width
+    target_height = height
 
     if kwargs.get("prompt_type", "") == "negative":
         width = float(width) * negative_adm_scale
@@ -225,20 +216,22 @@ def sdxl_encode_adm_patched(self, **kwargs):
     # Avoid artifacts
     width = int(width)
     height = int(height)
-    crop_w = int(crop_w)
-    crop_h = int(crop_h)
+    crop_w = 0
+    crop_h = 0
     target_width = int(target_width)
     target_height = int(target_height)
 
-    out = []
-    out.append(self.embedder(torch.Tensor([height])))
-    out.append(self.embedder(torch.Tensor([width])))
-    out.append(self.embedder(torch.Tensor([crop_h])))
-    out.append(self.embedder(torch.Tensor([crop_w])))
-    out.append(self.embedder(torch.Tensor([target_height])))
-    out.append(self.embedder(torch.Tensor([target_width])))
-    flat = torch.flatten(torch.cat(out)).unsqueeze(dim=0).repeat(clip_pooled.shape[0], 1)
-    return torch.cat((clip_pooled.to(flat.device), flat), dim=1)
+    out_a = [self.embedder(torch.Tensor([height])), self.embedder(torch.Tensor([width])),
+             self.embedder(torch.Tensor([crop_h])), self.embedder(torch.Tensor([crop_w])),
+             self.embedder(torch.Tensor([target_height])), self.embedder(torch.Tensor([target_width]))]
+    flat_a = torch.flatten(torch.cat(out_a)).unsqueeze(dim=0).repeat(clip_pooled.shape[0], 1)
+
+    out_b = [self.embedder(torch.Tensor([target_height])), self.embedder(torch.Tensor([target_width])),
+             self.embedder(torch.Tensor([crop_h])), self.embedder(torch.Tensor([crop_w])),
+             self.embedder(torch.Tensor([target_height])), self.embedder(torch.Tensor([target_width]))]
+    flat_b = torch.flatten(torch.cat(out_b)).unsqueeze(dim=0).repeat(clip_pooled.shape[0], 1)
+
+    return torch.cat((clip_pooled.to(flat_a.device), flat_a, clip_pooled.to(flat_b.device), flat_b), dim=1)
 
 
 def encode_token_weights_patched_with_a1111_method(self, token_weight_pairs):
@@ -346,9 +339,12 @@ def patched_unet_forward(self, x, timesteps=None, context=None, y=None, control=
     transformer_options["current_index"] = 0
     transformer_patches = transformer_options.get("patches", {})
 
-    assert (y is not None) == (
-        self.num_classes is not None
-    ), "must specify y if and only if the model is class-conditional"
+    if isinstance(y, torch.Tensor) and int(y.dim()) == 2 and int(y.shape[1]) == 5632:
+        t = (timesteps / 999.0)[:, None].clone().to(x) ** 2.0
+        ya = y[..., :2816].clone()
+        yb = y[..., 2816:].clone()
+        y = t * ya + (1 - t) * yb
+
     hs = []
     t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(self.dtype)
     emb = self.time_embed(t_emb)
diff --git a/update_log.md b/update_log.md
index 6fcb3f89..8df10931 100644
--- a/update_log.md
+++ b/update_log.md
@@ -1,5 +1,9 @@
 # 2.0.80
 
+* Improved the scheduling of ADM guidance and CFG mimicking for better visual quality in high frequency domain and small objects.
+
+# 2.0.80
+
 * Rework many patches and some UI details.
 * Speed up processing.
 * Move Colab to independent branch.
diff --git a/webui.py b/webui.py
index 2f7f1d5c..dd33442e 100644
--- a/webui.py
+++ b/webui.py
@@ -165,8 +165,9 @@ with shared.gradio_root:
                                                 step=0.001, value=1.5, info='The scaler multiplied to positive ADM (use 1.0 to disable). ')
                 adm_scaler_negative = gr.Slider(label='Negative ADM Guidance Scaler', minimum=0.1, maximum=3.0,
                                                 step=0.001, value=0.8, info='The scaler multiplied to negative ADM (use 1.0 to disable). ')
-                adaptive_cfg = gr.Slider(label='CFG Rescale from TSNR', minimum=0.0, maximum=1.0,
-                                         step=0.001, value=0.3, info='Enabling Fooocus\'s implementation of CFG re-weighting for TSNR (use 0 to disable, more effective when CFG > 7).')
+                adaptive_cfg = gr.Slider(label='CFG Mimicking from TSNR', minimum=1.0, maximum=30.0, step=0.01, value=5.0,
+                                         info='Enabling Fooocus\'s implementation of CFG mimicking for TSNR '
+                                              '(effective when real CFG > mimicked CFG).')
                 sampler_name = gr.Dropdown(label='Sampler', choices=flags.sampler_list, value=flags.default_sampler,
                                            info='Only effective in non-inpaint mode.')
                 def dev_mode_checked(r):
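
The core of the change is that `compute_cfg` no longer rescales the guided prediction by TSNR statistics; it mimics a lower CFG scale and blends it with the real one along the sampling trajectory. A minimal standalone sketch of that behaviour, assuming `uncond`/`cond` are the negative/positive eps tensors and `t` is a per-sample blend weight in [0, 1] supplied by the caller (the helper name below is illustrative, not part of the patch):

```python
import torch


def mimicked_cfg_eps(uncond: torch.Tensor, cond: torch.Tensor,
                     real_cfg: float, mimic_cfg: float, t) -> torch.Tensor:
    # Classic CFG prediction at the user's real CFG scale.
    real_eps = uncond + real_cfg * (cond - uncond)

    if real_cfg < mimic_cfg:
        # Mimicking only applies once the real CFG exceeds the mimicked CFG.
        return real_eps

    # The same prediction recomputed at the lower, mimicked CFG scale.
    mimicked_eps = uncond + mimic_cfg * (cond - uncond)

    # Blend scheduled by t: full-strength CFG where t == 1, mimicked CFG where t == 0.
    return real_eps * t + mimicked_eps * (1.0 - t)
```

With the slider's new default of 5.0 for the mimicked CFG, a real CFG of, say, 10 is applied fully where `t` is 1 and relaxes toward CFG-5 behaviour where `t` is 0, while any real CFG below the mimicked value passes through unchanged.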
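
On the ADM side, `sdxl_encode_adm_patched` now emits two 2816-wide conditioning vectors back to back (the ADM-scaled sizes first, the unscaled target sizes second), and `patched_unet_forward` recognises the resulting 5632-wide `y` and blends the two halves on a quadratic timestep schedule. A rough sketch of that blend, assuming `timesteps` lies in [0, 999] as in the patch (the helper name is illustrative):

```python
import torch


def schedule_adm(y: torch.Tensor, timesteps: torch.Tensor, half: int = 2816) -> torch.Tensor:
    # y: (batch, 2 * half) -- [scaled ADM | unscaled ADM] as built by the patched encoder.
    t = (timesteps / 999.0)[:, None].to(y) ** 2.0  # ~1 at high noise, ~0 near the end of sampling
    ya, yb = y[..., :half], y[..., half:]
    # Scaled ADM guidance dominates the early, noisy steps and fades to the plain target sizes later.
    return t * ya + (1.0 - t) * yb
```

This is the "scheduling of ADM guidance" mentioned in the changelog: the positive/negative ADM scalers mainly shape the early denoising steps and fade out toward the end of sampling.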