Fix many precision problems

Many users reported that image quality is different from 2.1.824. We reviewed all codes and fixed several precision problems in 2.1.846.
This commit is contained in:
lllyasviel 2023-12-16 15:55:53 -08:00 committed by GitHub
parent 3a727fd240
commit ec5dd950a2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 363 additions and 44 deletions

View File

@ -167,14 +167,26 @@ def preprocess(img, ip_adapter_path):
ldm_patched.modules.model_management.load_model_gpu(clip_vision.patcher)
pixel_values = clip_preprocess(numpy_to_pytorch(img).to(clip_vision.load_device))
outputs = clip_vision.model(pixel_values=pixel_values, intermediate_output=-2)
if clip_vision.dtype != torch.float32:
precision_scope = torch.autocast
else:
precision_scope = lambda a, b: contextlib.nullcontext(a)
with precision_scope(ldm_patched.modules.model_management.get_autocast_device(clip_vision.load_device), torch.float32):
outputs = clip_vision.model(pixel_values=pixel_values, output_hidden_states=True)
ip_adapter = entry['ip_adapter']
ip_layers = entry['ip_layers']
image_proj_model = entry['image_proj_model']
ip_unconds = entry['ip_unconds']
cond = outputs[1].to(device=ip_adapter.load_device, dtype=ip_adapter.dtype)
if ip_adapter.plus:
cond = outputs.hidden_states[-2]
else:
cond = outputs.image_embeds
cond = cond.to(device=ip_adapter.load_device, dtype=ip_adapter.dtype)
ldm_patched.modules.model_management.load_model_gpu(image_proj_model)
cond = image_proj_model.model(cond).to(device=ip_adapter.load_device, dtype=ip_adapter.dtype)

View File

@ -1 +1 @@
version = '2.1.844'
version = '2.1.846'

View File

@ -25,6 +25,8 @@ import modules.constants as constants
from ldm_patched.modules.samplers import calc_cond_uncond_batch
from ldm_patched.k_diffusion.sampling import BatchedBrownianTree
from ldm_patched.ldm.modules.diffusionmodules.openaimodel import forward_timestep_embed, apply_control
from modules.patch_precision import patch_all_precision
from modules.patch_clip import patch_all_clip
sharpness = 2.0
@ -286,46 +288,6 @@ def sdxl_encode_adm_patched(self, **kwargs):
return final_adm
def encode_token_weights_patched_with_a1111_method(self, token_weight_pairs):
to_encode = list()
max_token_len = 0
has_weights = False
for x in token_weight_pairs:
tokens = list(map(lambda a: a[0], x))
max_token_len = max(len(tokens), max_token_len)
has_weights = has_weights or not all(map(lambda a: a[1] == 1.0, x))
to_encode.append(tokens)
sections = len(to_encode)
if has_weights or sections == 0:
to_encode.append(ldm_patched.modules.sd1_clip.gen_empty_tokens(self.special_tokens, max_token_len))
out, pooled = self.encode(to_encode)
if pooled is not None:
first_pooled = pooled[0:1].to(ldm_patched.modules.model_management.intermediate_device())
else:
first_pooled = pooled
output = []
for k in range(0, sections):
z = out[k:k + 1]
if has_weights:
original_mean = z.mean()
z_empty = out[-1]
for i in range(len(z)):
for j in range(len(z[i])):
weight = token_weight_pairs[k][j][1]
if weight != 1.0:
z[i][j] = (z[i][j] - z_empty[j]) * weight + z_empty[j]
new_mean = z.mean()
z = z * (original_mean / new_mean)
output.append(z)
if len(output) == 0:
return out[-1:].to(ldm_patched.modules.model_management.intermediate_device()), first_pooled
return torch.cat(output, dim=-2).to(ldm_patched.modules.model_management.intermediate_device()), first_pooled
def patched_KSamplerX0Inpaint_forward(self, x, sigma, uncond, cond, cond_scale, denoise_mask, model_options={}, seed=None):
if inpaint_worker.current_task is not None:
latent_processor = self.inner_model.inner_model.process_latent_in
@ -519,6 +481,9 @@ def build_loaded(module, loader_name):
def patch_all():
patch_all_precision()
patch_all_clip()
if not hasattr(ldm_patched.modules.model_management, 'load_models_gpu_origin'):
ldm_patched.modules.model_management.load_models_gpu_origin = ldm_patched.modules.model_management.load_models_gpu
@ -527,7 +492,6 @@ def patch_all():
ldm_patched.controlnet.cldm.ControlNet.forward = patched_cldm_forward
ldm_patched.ldm.modules.diffusionmodules.openaimodel.UNetModel.forward = patched_unet_forward
ldm_patched.modules.model_base.SDXL.encode_adm = sdxl_encode_adm_patched
ldm_patched.modules.sd1_clip.ClipTokenWeightEncoder.encode_token_weights = encode_token_weights_patched_with_a1111_method
ldm_patched.modules.samplers.KSamplerX0Inpaint.forward = patched_KSamplerX0Inpaint_forward
ldm_patched.k_diffusion.sampling.BrownianTreeNoiseSampler = BrownianTreeNoiseSamplerPatched
ldm_patched.modules.samplers.sampling_function = patched_sampling_function

279
modules/patch_clip.py Normal file
View File

@ -0,0 +1,279 @@
# Consistent with Kohya/A1111 to reduce differences between model training and inference.
import os
import torch
import ldm_patched.controlnet.cldm
import ldm_patched.k_diffusion.sampling
import ldm_patched.ldm.modules.attention
import ldm_patched.ldm.modules.diffusionmodules.model
import ldm_patched.ldm.modules.diffusionmodules.openaimodel
import ldm_patched.ldm.modules.diffusionmodules.openaimodel
import ldm_patched.modules.args_parser
import ldm_patched.modules.model_base
import ldm_patched.modules.model_management
import ldm_patched.modules.model_patcher
import ldm_patched.modules.ops
import ldm_patched.modules.samplers
import ldm_patched.modules.sd
import ldm_patched.modules.sd1_clip
import ldm_patched.modules.clip_vision
import ldm_patched.modules.model_management as model_management
import contextlib
from transformers import CLIPTextModel, CLIPTextConfig, modeling_utils, CLIPVisionConfig, CLIPVisionModelWithProjection
@contextlib.contextmanager
def use_disable_weight_init_linear_ops(device=None, dtype=None):
old_torch_nn_linear = torch.nn.Linear
force_device = device
force_dtype = dtype
def linear_with_dtype(in_features: int, out_features: int, bias: bool = True, device=None, dtype=None):
if force_device is not None:
device = force_device
if force_dtype is not None:
dtype = force_dtype
return ldm_patched.modules.ops.disable_weight_init.Linear(in_features, out_features, bias=bias, device=device,
dtype=dtype)
torch.nn.Linear = linear_with_dtype
try:
yield
finally:
torch.nn.Linear = old_torch_nn_linear
return
def encode_token_weights_fooocus(self, token_weight_pairs):
to_encode = list()
max_token_len = 0
has_weights = False
for x in token_weight_pairs:
tokens = list(map(lambda a: a[0], x))
max_token_len = max(len(tokens), max_token_len)
has_weights = has_weights or not all(map(lambda a: a[1] == 1.0, x))
to_encode.append(tokens)
sections = len(to_encode)
if has_weights or sections == 0:
to_encode.append(ldm_patched.modules.sd1_clip.gen_empty_tokens(self.special_tokens, max_token_len))
out, pooled = self.encode(to_encode)
if pooled is not None:
first_pooled = pooled[0:1].to(ldm_patched.modules.model_management.intermediate_device())
else:
first_pooled = pooled
output = []
for k in range(0, sections):
z = out[k:k + 1]
if has_weights:
original_mean = z.mean()
z_empty = out[-1]
for i in range(len(z)):
for j in range(len(z[i])):
weight = token_weight_pairs[k][j][1]
if weight != 1.0:
z[i][j] = (z[i][j] - z_empty[j]) * weight + z_empty[j]
new_mean = z.mean()
z = z * (original_mean / new_mean)
output.append(z)
if len(output) == 0:
return out[-1:].to(ldm_patched.modules.model_management.intermediate_device()), first_pooled
return torch.cat(output, dim=-2).to(ldm_patched.modules.model_management.intermediate_device()), first_pooled
class SDClipModelFooocus(torch.nn.Module, ldm_patched.modules.sd1_clip.ClipTokenWeightEncoder):
"""Uses the CLIP transformer encoder for text (from huggingface)"""
LAYERS = [
"last",
"pooled",
"hidden"
]
def __init__(self, version="openai/clip-vit-large-patch14", device="cpu", max_length=77,
freeze=True, layer="last", layer_idx=None, textmodel_json_config=None, dtype=None, model_class=ldm_patched.modules.clip_model.CLIPTextModel,
special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=True): # clip-vit-base-patch32
super().__init__()
assert layer in self.LAYERS
if textmodel_json_config is None:
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(ldm_patched.modules.sd1_clip.__file__)), "sd1_clip_config.json")
config = CLIPTextConfig.from_json_file(textmodel_json_config)
self.num_layers = config.num_hidden_layers
with use_disable_weight_init_linear_ops(device, dtype):
with modeling_utils.no_init_weights():
self.transformer = CLIPTextModel(config)
self.inner_name = "text_model"
if dtype is not None:
self.transformer.to(dtype)
inner_model = getattr(self.transformer, self.inner_name)
if hasattr(inner_model, "embeddings"):
inner_model.embeddings.to(torch.float32)
else:
self.transformer.set_input_embeddings(self.transformer.get_input_embeddings().to(torch.float32))
self.max_length = max_length
if freeze:
self.freeze()
self.layer = layer
self.layer_idx = None
self.special_tokens = special_tokens
self.text_projection = torch.nn.Parameter(torch.eye(self.transformer.get_input_embeddings().weight.shape[1]))
self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055))
self.enable_attention_masks = False
self.layer_norm_hidden_state = layer_norm_hidden_state
if layer == "hidden":
assert layer_idx is not None
assert abs(layer_idx) < self.num_layers
self.clip_layer(layer_idx)
self.layer_default = (self.layer, self.layer_idx)
def freeze(self):
self.transformer = self.transformer.eval()
# self.train = disabled_train
for param in self.parameters():
param.requires_grad = False
def clip_layer(self, layer_idx):
if abs(layer_idx) > self.num_layers:
self.layer = "last"
else:
self.layer = "hidden"
self.layer_idx = layer_idx
def reset_clip_layer(self):
self.layer = self.layer_default[0]
self.layer_idx = self.layer_default[1]
def set_up_textual_embeddings(self, tokens, current_embeds):
out_tokens = []
next_new_token = token_dict_size = current_embeds.weight.shape[0] - 1
embedding_weights = []
for x in tokens:
tokens_temp = []
for y in x:
if isinstance(y, int):
if y == token_dict_size: # EOS token
y = -1
tokens_temp += [y]
else:
if y.shape[0] == current_embeds.weight.shape[1]:
embedding_weights += [y]
tokens_temp += [next_new_token]
next_new_token += 1
else:
print("WARNING: shape mismatch when trying to apply embedding, embedding will be ignored",
y.shape[0], current_embeds.weight.shape[1])
while len(tokens_temp) < len(x):
tokens_temp += [self.special_tokens["pad"]]
out_tokens += [tokens_temp]
n = token_dict_size
if len(embedding_weights) > 0:
new_embedding = torch.nn.Embedding(next_new_token + 1, current_embeds.weight.shape[1],
device=current_embeds.weight.device, dtype=current_embeds.weight.dtype)
new_embedding.weight[:token_dict_size] = current_embeds.weight[:-1]
for x in embedding_weights:
new_embedding.weight[n] = x
n += 1
new_embedding.weight[n] = current_embeds.weight[-1] # EOS embedding
self.transformer.set_input_embeddings(new_embedding)
processed_tokens = []
for x in out_tokens:
processed_tokens += [
list(map(lambda a: n if a == -1 else a, x))] # The EOS token should always be the largest one
return processed_tokens
def forward(self, tokens):
backup_embeds = self.transformer.get_input_embeddings()
device = backup_embeds.weight.device
tokens = self.set_up_textual_embeddings(tokens, backup_embeds)
tokens = torch.LongTensor(tokens).to(device)
if getattr(self.transformer, self.inner_name).final_layer_norm.weight.dtype != torch.float32:
precision_scope = torch.autocast
else:
precision_scope = lambda a, dtype: contextlib.nullcontext(a)
with precision_scope(model_management.get_autocast_device(device), dtype=torch.float32):
attention_mask = None
if self.enable_attention_masks:
attention_mask = torch.zeros_like(tokens)
max_token = self.transformer.get_input_embeddings().weight.shape[0] - 1
for x in range(attention_mask.shape[0]):
for y in range(attention_mask.shape[1]):
attention_mask[x, y] = 1
if tokens[x, y] == max_token:
break
outputs = self.transformer(input_ids=tokens, attention_mask=attention_mask,
output_hidden_states=self.layer == "hidden")
self.transformer.set_input_embeddings(backup_embeds)
if self.layer == "last":
z = outputs.last_hidden_state
elif self.layer == "pooled":
z = outputs.pooler_output[:, None, :]
else:
z = outputs.hidden_states[self.layer_idx]
if self.layer_norm_hidden_state:
z = getattr(self.transformer, self.inner_name).final_layer_norm(z)
if hasattr(outputs, "pooler_output"):
pooled_output = outputs.pooler_output.float()
else:
pooled_output = None
if self.text_projection is not None and pooled_output is not None:
pooled_output = pooled_output.float().to(self.text_projection.device) @ self.text_projection.float()
return z.float(), pooled_output
def encode(self, tokens):
return self(tokens)
def load_sd(self, sd):
if "text_projection" in sd:
self.text_projection[:] = sd.pop("text_projection")
if "text_projection.weight" in sd:
self.text_projection[:] = sd.pop("text_projection.weight").transpose(0, 1)
return self.transformer.load_state_dict(sd, strict=False)
class ClipVisionModelFooocus:
def __init__(self, json_config):
config = CLIPVisionConfig.from_json_file(json_config)
self.load_device = ldm_patched.modules.model_management.text_encoder_device()
offload_device = ldm_patched.modules.model_management.text_encoder_offload_device()
self.dtype = torch.float32
if ldm_patched.modules.model_management.should_use_fp16(self.load_device, prioritize_performance=False):
self.dtype = torch.float16
with use_disable_weight_init_linear_ops(offload_device, self.dtype):
with modeling_utils.no_init_weights():
self.model = CLIPVisionModelWithProjection(config)
self.model.to(self.dtype)
self.patcher = ldm_patched.modules.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
def load_sd(self, sd):
return self.model.load_state_dict(sd, strict=False)
def encode_image(self, image):
raise NotImplementedError('wrong clip vision call!')
def patch_all_clip():
ldm_patched.modules.sd1_clip.ClipTokenWeightEncoder.encode_token_weights = encode_token_weights_fooocus
ldm_patched.modules.sd1_clip.SDClipModel = SDClipModelFooocus
ldm_patched.modules.clip_vision.ClipVisionModel = ClipVisionModelFooocus
return

View File

@ -0,0 +1,60 @@
# Consistent with Kohya to reduce differences between model training and inference.
import torch
import math
import einops
import numpy as np
import ldm_patched.ldm.modules.diffusionmodules.openaimodel
import ldm_patched.modules.model_sampling
import ldm_patched.modules.sd1_clip
from ldm_patched.ldm.modules.diffusionmodules.util import make_beta_schedule
def patched_timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
# Consistent with Kohya to reduce differences between model training and inference.
if not repeat_only:
half = dim // 2
freqs = torch.exp(
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
).to(device=timesteps.device)
args = timesteps[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
else:
embedding = einops.repeat(timesteps, 'b -> b d', d=dim)
return embedding
def patched_register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
# Consistent with Kohya to reduce differences between model training and inference.
if given_betas is not None:
betas = given_betas
else:
betas = make_beta_schedule(
beta_schedule,
timesteps,
linear_start=linear_start,
linear_end=linear_end,
cosine_s=cosine_s)
alphas = 1. - betas
alphas_cumprod = np.cumprod(alphas, axis=0)
timesteps, = betas.shape
self.num_timesteps = int(timesteps)
self.linear_start = linear_start
self.linear_end = linear_end
sigmas = torch.tensor(((1 - alphas_cumprod) / alphas_cumprod) ** 0.5, dtype=torch.float32)
self.set_sigmas(sigmas)
return
def patch_all_precision():
ldm_patched.ldm.modules.diffusionmodules.openaimodel.timestep_embedding = patched_timestep_embedding
ldm_patched.modules.model_sampling.ModelSamplingDiscrete._register_schedule = patched_register_schedule
return

View File

@ -1,3 +1,7 @@
# 2.1.846
* Many users reported that image quality is different from 2.1.824. We reviewed all codes and fixed several precision problems in 2.1.846.
# 2.1.843
* Many improvements to Canvas. Thanks CanvasZoom author!