From ec5dd950a25cbbc2c6a9462101712402bf0920c7 Mon Sep 17 00:00:00 2001 From: lllyasviel Date: Sat, 16 Dec 2023 15:55:53 -0800 Subject: [PATCH] Fix many precision problems Many users reported that image quality is different from 2.1.824. We reviewed all codes and fixed several precision problems in 2.1.846. --- extras/ip_adapter.py | 16 ++- fooocus_version.py | 2 +- modules/patch.py | 46 +----- modules/patch_clip.py | 279 +++++++++++++++++++++++++++++++++++++ modules/patch_precision.py | 60 ++++++++ update_log.md | 4 + 6 files changed, 363 insertions(+), 44 deletions(-) create mode 100644 modules/patch_clip.py create mode 100644 modules/patch_precision.py diff --git a/extras/ip_adapter.py b/extras/ip_adapter.py index b18f0dfc..ac64ef1d 100644 --- a/extras/ip_adapter.py +++ b/extras/ip_adapter.py @@ -167,14 +167,26 @@ def preprocess(img, ip_adapter_path): ldm_patched.modules.model_management.load_model_gpu(clip_vision.patcher) pixel_values = clip_preprocess(numpy_to_pytorch(img).to(clip_vision.load_device)) - outputs = clip_vision.model(pixel_values=pixel_values, intermediate_output=-2) + + if clip_vision.dtype != torch.float32: + precision_scope = torch.autocast + else: + precision_scope = lambda a, b: contextlib.nullcontext(a) + + with precision_scope(ldm_patched.modules.model_management.get_autocast_device(clip_vision.load_device), torch.float32): + outputs = clip_vision.model(pixel_values=pixel_values, output_hidden_states=True) ip_adapter = entry['ip_adapter'] ip_layers = entry['ip_layers'] image_proj_model = entry['image_proj_model'] ip_unconds = entry['ip_unconds'] - cond = outputs[1].to(device=ip_adapter.load_device, dtype=ip_adapter.dtype) + if ip_adapter.plus: + cond = outputs.hidden_states[-2] + else: + cond = outputs.image_embeds + + cond = cond.to(device=ip_adapter.load_device, dtype=ip_adapter.dtype) ldm_patched.modules.model_management.load_model_gpu(image_proj_model) cond = image_proj_model.model(cond).to(device=ip_adapter.load_device, dtype=ip_adapter.dtype) diff --git a/fooocus_version.py b/fooocus_version.py index a7dac990..26b2cf04 100644 --- a/fooocus_version.py +++ b/fooocus_version.py @@ -1 +1 @@ -version = '2.1.844' +version = '2.1.846' diff --git a/modules/patch.py b/modules/patch.py index 0a04bafb..0ae53585 100644 --- a/modules/patch.py +++ b/modules/patch.py @@ -25,6 +25,8 @@ import modules.constants as constants from ldm_patched.modules.samplers import calc_cond_uncond_batch from ldm_patched.k_diffusion.sampling import BatchedBrownianTree from ldm_patched.ldm.modules.diffusionmodules.openaimodel import forward_timestep_embed, apply_control +from modules.patch_precision import patch_all_precision +from modules.patch_clip import patch_all_clip sharpness = 2.0 @@ -286,46 +288,6 @@ def sdxl_encode_adm_patched(self, **kwargs): return final_adm -def encode_token_weights_patched_with_a1111_method(self, token_weight_pairs): - to_encode = list() - max_token_len = 0 - has_weights = False - for x in token_weight_pairs: - tokens = list(map(lambda a: a[0], x)) - max_token_len = max(len(tokens), max_token_len) - has_weights = has_weights or not all(map(lambda a: a[1] == 1.0, x)) - to_encode.append(tokens) - - sections = len(to_encode) - if has_weights or sections == 0: - to_encode.append(ldm_patched.modules.sd1_clip.gen_empty_tokens(self.special_tokens, max_token_len)) - - out, pooled = self.encode(to_encode) - if pooled is not None: - first_pooled = pooled[0:1].to(ldm_patched.modules.model_management.intermediate_device()) - else: - first_pooled = pooled - - output = [] - for k in range(0, sections): - z = out[k:k + 1] - if has_weights: - original_mean = z.mean() - z_empty = out[-1] - for i in range(len(z)): - for j in range(len(z[i])): - weight = token_weight_pairs[k][j][1] - if weight != 1.0: - z[i][j] = (z[i][j] - z_empty[j]) * weight + z_empty[j] - new_mean = z.mean() - z = z * (original_mean / new_mean) - output.append(z) - - if len(output) == 0: - return out[-1:].to(ldm_patched.modules.model_management.intermediate_device()), first_pooled - return torch.cat(output, dim=-2).to(ldm_patched.modules.model_management.intermediate_device()), first_pooled - - def patched_KSamplerX0Inpaint_forward(self, x, sigma, uncond, cond, cond_scale, denoise_mask, model_options={}, seed=None): if inpaint_worker.current_task is not None: latent_processor = self.inner_model.inner_model.process_latent_in @@ -519,6 +481,9 @@ def build_loaded(module, loader_name): def patch_all(): + patch_all_precision() + patch_all_clip() + if not hasattr(ldm_patched.modules.model_management, 'load_models_gpu_origin'): ldm_patched.modules.model_management.load_models_gpu_origin = ldm_patched.modules.model_management.load_models_gpu @@ -527,7 +492,6 @@ def patch_all(): ldm_patched.controlnet.cldm.ControlNet.forward = patched_cldm_forward ldm_patched.ldm.modules.diffusionmodules.openaimodel.UNetModel.forward = patched_unet_forward ldm_patched.modules.model_base.SDXL.encode_adm = sdxl_encode_adm_patched - ldm_patched.modules.sd1_clip.ClipTokenWeightEncoder.encode_token_weights = encode_token_weights_patched_with_a1111_method ldm_patched.modules.samplers.KSamplerX0Inpaint.forward = patched_KSamplerX0Inpaint_forward ldm_patched.k_diffusion.sampling.BrownianTreeNoiseSampler = BrownianTreeNoiseSamplerPatched ldm_patched.modules.samplers.sampling_function = patched_sampling_function diff --git a/modules/patch_clip.py b/modules/patch_clip.py new file mode 100644 index 00000000..157f051b --- /dev/null +++ b/modules/patch_clip.py @@ -0,0 +1,279 @@ +# Consistent with Kohya/A1111 to reduce differences between model training and inference. + +import os +import torch +import ldm_patched.controlnet.cldm +import ldm_patched.k_diffusion.sampling +import ldm_patched.ldm.modules.attention +import ldm_patched.ldm.modules.diffusionmodules.model +import ldm_patched.ldm.modules.diffusionmodules.openaimodel +import ldm_patched.ldm.modules.diffusionmodules.openaimodel +import ldm_patched.modules.args_parser +import ldm_patched.modules.model_base +import ldm_patched.modules.model_management +import ldm_patched.modules.model_patcher +import ldm_patched.modules.ops +import ldm_patched.modules.samplers +import ldm_patched.modules.sd +import ldm_patched.modules.sd1_clip +import ldm_patched.modules.clip_vision +import ldm_patched.modules.model_management as model_management +import contextlib + +from transformers import CLIPTextModel, CLIPTextConfig, modeling_utils, CLIPVisionConfig, CLIPVisionModelWithProjection + + +@contextlib.contextmanager +def use_disable_weight_init_linear_ops(device=None, dtype=None): + old_torch_nn_linear = torch.nn.Linear + force_device = device + force_dtype = dtype + + def linear_with_dtype(in_features: int, out_features: int, bias: bool = True, device=None, dtype=None): + if force_device is not None: + device = force_device + if force_dtype is not None: + dtype = force_dtype + return ldm_patched.modules.ops.disable_weight_init.Linear(in_features, out_features, bias=bias, device=device, + dtype=dtype) + + torch.nn.Linear = linear_with_dtype + try: + yield + finally: + torch.nn.Linear = old_torch_nn_linear + return + + +def encode_token_weights_fooocus(self, token_weight_pairs): + to_encode = list() + max_token_len = 0 + has_weights = False + for x in token_weight_pairs: + tokens = list(map(lambda a: a[0], x)) + max_token_len = max(len(tokens), max_token_len) + has_weights = has_weights or not all(map(lambda a: a[1] == 1.0, x)) + to_encode.append(tokens) + + sections = len(to_encode) + if has_weights or sections == 0: + to_encode.append(ldm_patched.modules.sd1_clip.gen_empty_tokens(self.special_tokens, max_token_len)) + + out, pooled = self.encode(to_encode) + if pooled is not None: + first_pooled = pooled[0:1].to(ldm_patched.modules.model_management.intermediate_device()) + else: + first_pooled = pooled + + output = [] + for k in range(0, sections): + z = out[k:k + 1] + if has_weights: + original_mean = z.mean() + z_empty = out[-1] + for i in range(len(z)): + for j in range(len(z[i])): + weight = token_weight_pairs[k][j][1] + if weight != 1.0: + z[i][j] = (z[i][j] - z_empty[j]) * weight + z_empty[j] + new_mean = z.mean() + z = z * (original_mean / new_mean) + output.append(z) + + if len(output) == 0: + return out[-1:].to(ldm_patched.modules.model_management.intermediate_device()), first_pooled + return torch.cat(output, dim=-2).to(ldm_patched.modules.model_management.intermediate_device()), first_pooled + + +class SDClipModelFooocus(torch.nn.Module, ldm_patched.modules.sd1_clip.ClipTokenWeightEncoder): + """Uses the CLIP transformer encoder for text (from huggingface)""" + LAYERS = [ + "last", + "pooled", + "hidden" + ] + + def __init__(self, version="openai/clip-vit-large-patch14", device="cpu", max_length=77, + freeze=True, layer="last", layer_idx=None, textmodel_json_config=None, dtype=None, model_class=ldm_patched.modules.clip_model.CLIPTextModel, + special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=True): # clip-vit-base-patch32 + super().__init__() + assert layer in self.LAYERS + + if textmodel_json_config is None: + textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(ldm_patched.modules.sd1_clip.__file__)), "sd1_clip_config.json") + + config = CLIPTextConfig.from_json_file(textmodel_json_config) + + self.num_layers = config.num_hidden_layers + with use_disable_weight_init_linear_ops(device, dtype): + with modeling_utils.no_init_weights(): + self.transformer = CLIPTextModel(config) + + self.inner_name = "text_model" + if dtype is not None: + self.transformer.to(dtype) + inner_model = getattr(self.transformer, self.inner_name) + if hasattr(inner_model, "embeddings"): + inner_model.embeddings.to(torch.float32) + else: + self.transformer.set_input_embeddings(self.transformer.get_input_embeddings().to(torch.float32)) + + self.max_length = max_length + if freeze: + self.freeze() + self.layer = layer + self.layer_idx = None + self.special_tokens = special_tokens + self.text_projection = torch.nn.Parameter(torch.eye(self.transformer.get_input_embeddings().weight.shape[1])) + self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055)) + self.enable_attention_masks = False + + self.layer_norm_hidden_state = layer_norm_hidden_state + if layer == "hidden": + assert layer_idx is not None + assert abs(layer_idx) < self.num_layers + self.clip_layer(layer_idx) + self.layer_default = (self.layer, self.layer_idx) + + def freeze(self): + self.transformer = self.transformer.eval() + # self.train = disabled_train + for param in self.parameters(): + param.requires_grad = False + + def clip_layer(self, layer_idx): + if abs(layer_idx) > self.num_layers: + self.layer = "last" + else: + self.layer = "hidden" + self.layer_idx = layer_idx + + def reset_clip_layer(self): + self.layer = self.layer_default[0] + self.layer_idx = self.layer_default[1] + + def set_up_textual_embeddings(self, tokens, current_embeds): + out_tokens = [] + next_new_token = token_dict_size = current_embeds.weight.shape[0] - 1 + embedding_weights = [] + + for x in tokens: + tokens_temp = [] + for y in x: + if isinstance(y, int): + if y == token_dict_size: # EOS token + y = -1 + tokens_temp += [y] + else: + if y.shape[0] == current_embeds.weight.shape[1]: + embedding_weights += [y] + tokens_temp += [next_new_token] + next_new_token += 1 + else: + print("WARNING: shape mismatch when trying to apply embedding, embedding will be ignored", + y.shape[0], current_embeds.weight.shape[1]) + while len(tokens_temp) < len(x): + tokens_temp += [self.special_tokens["pad"]] + out_tokens += [tokens_temp] + + n = token_dict_size + if len(embedding_weights) > 0: + new_embedding = torch.nn.Embedding(next_new_token + 1, current_embeds.weight.shape[1], + device=current_embeds.weight.device, dtype=current_embeds.weight.dtype) + new_embedding.weight[:token_dict_size] = current_embeds.weight[:-1] + for x in embedding_weights: + new_embedding.weight[n] = x + n += 1 + new_embedding.weight[n] = current_embeds.weight[-1] # EOS embedding + self.transformer.set_input_embeddings(new_embedding) + + processed_tokens = [] + for x in out_tokens: + processed_tokens += [ + list(map(lambda a: n if a == -1 else a, x))] # The EOS token should always be the largest one + + return processed_tokens + + def forward(self, tokens): + backup_embeds = self.transformer.get_input_embeddings() + device = backup_embeds.weight.device + tokens = self.set_up_textual_embeddings(tokens, backup_embeds) + tokens = torch.LongTensor(tokens).to(device) + + if getattr(self.transformer, self.inner_name).final_layer_norm.weight.dtype != torch.float32: + precision_scope = torch.autocast + else: + precision_scope = lambda a, dtype: contextlib.nullcontext(a) + + with precision_scope(model_management.get_autocast_device(device), dtype=torch.float32): + attention_mask = None + if self.enable_attention_masks: + attention_mask = torch.zeros_like(tokens) + max_token = self.transformer.get_input_embeddings().weight.shape[0] - 1 + for x in range(attention_mask.shape[0]): + for y in range(attention_mask.shape[1]): + attention_mask[x, y] = 1 + if tokens[x, y] == max_token: + break + + outputs = self.transformer(input_ids=tokens, attention_mask=attention_mask, + output_hidden_states=self.layer == "hidden") + self.transformer.set_input_embeddings(backup_embeds) + + if self.layer == "last": + z = outputs.last_hidden_state + elif self.layer == "pooled": + z = outputs.pooler_output[:, None, :] + else: + z = outputs.hidden_states[self.layer_idx] + if self.layer_norm_hidden_state: + z = getattr(self.transformer, self.inner_name).final_layer_norm(z) + + if hasattr(outputs, "pooler_output"): + pooled_output = outputs.pooler_output.float() + else: + pooled_output = None + + if self.text_projection is not None and pooled_output is not None: + pooled_output = pooled_output.float().to(self.text_projection.device) @ self.text_projection.float() + return z.float(), pooled_output + + def encode(self, tokens): + return self(tokens) + + def load_sd(self, sd): + if "text_projection" in sd: + self.text_projection[:] = sd.pop("text_projection") + if "text_projection.weight" in sd: + self.text_projection[:] = sd.pop("text_projection.weight").transpose(0, 1) + return self.transformer.load_state_dict(sd, strict=False) + + +class ClipVisionModelFooocus: + def __init__(self, json_config): + config = CLIPVisionConfig.from_json_file(json_config) + self.load_device = ldm_patched.modules.model_management.text_encoder_device() + offload_device = ldm_patched.modules.model_management.text_encoder_offload_device() + self.dtype = torch.float32 + if ldm_patched.modules.model_management.should_use_fp16(self.load_device, prioritize_performance=False): + self.dtype = torch.float16 + + with use_disable_weight_init_linear_ops(offload_device, self.dtype): + with modeling_utils.no_init_weights(): + self.model = CLIPVisionModelWithProjection(config) + self.model.to(self.dtype) + + self.patcher = ldm_patched.modules.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device) + + def load_sd(self, sd): + return self.model.load_state_dict(sd, strict=False) + + def encode_image(self, image): + raise NotImplementedError('wrong clip vision call!') + + +def patch_all_clip(): + ldm_patched.modules.sd1_clip.ClipTokenWeightEncoder.encode_token_weights = encode_token_weights_fooocus + ldm_patched.modules.sd1_clip.SDClipModel = SDClipModelFooocus + ldm_patched.modules.clip_vision.ClipVisionModel = ClipVisionModelFooocus + return diff --git a/modules/patch_precision.py b/modules/patch_precision.py new file mode 100644 index 00000000..83569bdd --- /dev/null +++ b/modules/patch_precision.py @@ -0,0 +1,60 @@ +# Consistent with Kohya to reduce differences between model training and inference. + +import torch +import math +import einops +import numpy as np + +import ldm_patched.ldm.modules.diffusionmodules.openaimodel +import ldm_patched.modules.model_sampling +import ldm_patched.modules.sd1_clip + +from ldm_patched.ldm.modules.diffusionmodules.util import make_beta_schedule + + +def patched_timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): + # Consistent with Kohya to reduce differences between model training and inference. + + if not repeat_only: + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half + ).to(device=timesteps.device) + args = timesteps[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + else: + embedding = einops.repeat(timesteps, 'b -> b d', d=dim) + return embedding + + +def patched_register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000, + linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): + # Consistent with Kohya to reduce differences between model training and inference. + + if given_betas is not None: + betas = given_betas + else: + betas = make_beta_schedule( + beta_schedule, + timesteps, + linear_start=linear_start, + linear_end=linear_end, + cosine_s=cosine_s) + + alphas = 1. - betas + alphas_cumprod = np.cumprod(alphas, axis=0) + timesteps, = betas.shape + self.num_timesteps = int(timesteps) + self.linear_start = linear_start + self.linear_end = linear_end + sigmas = torch.tensor(((1 - alphas_cumprod) / alphas_cumprod) ** 0.5, dtype=torch.float32) + self.set_sigmas(sigmas) + return + + +def patch_all_precision(): + ldm_patched.ldm.modules.diffusionmodules.openaimodel.timestep_embedding = patched_timestep_embedding + ldm_patched.modules.model_sampling.ModelSamplingDiscrete._register_schedule = patched_register_schedule + return diff --git a/update_log.md b/update_log.md index 640b5665..39e4c9e5 100644 --- a/update_log.md +++ b/update_log.md @@ -1,3 +1,7 @@ +# 2.1.846 + +* Many users reported that image quality is different from 2.1.824. We reviewed all codes and fixed several precision problems in 2.1.846. + # 2.1.843 * Many improvements to Canvas. Thanks CanvasZoom author!