diff --git a/extras/ip_adapter.py b/extras/ip_adapter.py
index cb1d366f..22527d24 100644
--- a/extras/ip_adapter.py
+++ b/extras/ip_adapter.py
@@ -2,12 +2,13 @@ import torch
 import ldm_patched.modules.clip_vision
 import safetensors.torch as sf
 import ldm_patched.modules.model_management as model_management
-import contextlib
 import ldm_patched.ldm.modules.attention as attention
 
 from extras.resampler import Resampler
 from ldm_patched.modules.model_patcher import ModelPatcher
 from modules.core import numpy_to_pytorch
+from modules.ops import use_patched_ops
+from ldm_patched.modules.ops import manual_cast
 
 
 SD_V12_CHANNELS = [320] * 4 + [640] * 4 + [1280] * 4 + [1280] * 6 + [640] * 6 + [320] * 6 + [1280] * 2
@@ -116,14 +117,16 @@ def load_ip_adapter(clip_vision_path, ip_negative_path, ip_adapter_path):
         clip_extra_context_tokens = ip_state_dict["image_proj"]["proj.weight"].shape[0] // cross_attention_dim
         clip_embeddings_dim = None
 
-    ip_adapter = IPAdapterModel(
-        ip_state_dict,
-        plus=plus,
-        cross_attention_dim=cross_attention_dim,
-        clip_embeddings_dim=clip_embeddings_dim,
-        clip_extra_context_tokens=clip_extra_context_tokens,
-        sdxl_plus=sdxl_plus
-    )
+    with use_patched_ops(manual_cast):
+        ip_adapter = IPAdapterModel(
+            ip_state_dict,
+            plus=plus,
+            cross_attention_dim=cross_attention_dim,
+            clip_embeddings_dim=clip_embeddings_dim,
+            clip_extra_context_tokens=clip_extra_context_tokens,
+            sdxl_plus=sdxl_plus
+        )
+
     ip_adapter.sdxl = sdxl
     ip_adapter.load_device = load_device
     ip_adapter.offload_device = offload_device
diff --git a/extras/resampler.py b/extras/resampler.py
index 4521c8c3..539f309d 100644
--- a/extras/resampler.py
+++ b/extras/resampler.py
@@ -108,8 +108,7 @@ class Resampler(nn.Module):
         )
 
     def forward(self, x):
-
-        latents = self.latents.repeat(x.size(0), 1, 1)
+        latents = self.latents.repeat(x.size(0), 1, 1).to(x)
 
         x = self.proj_in(x)
 
@@ -118,4 +117,4 @@ class Resampler(nn.Module):
             latents = ff(latents) + latents
 
         latents = self.proj_out(latents)
-        return self.norm_out(latents)
\ No newline at end of file
+        return self.norm_out(latents)
diff --git a/fooocus_version.py b/fooocus_version.py
index 2511cfc7..482cc12c 100644
--- a/fooocus_version.py
+++ b/fooocus_version.py
@@ -1 +1 @@
-version = '2.1.855'
+version = '2.1.860'
diff --git a/ldm_patched/contrib/external.py b/ldm_patched/contrib/external.py
index 7f95f084..9d2238df 100644
--- a/ldm_patched/contrib/external.py
+++ b/ldm_patched/contrib/external.py
@@ -11,7 +11,7 @@ import math
 import time
 import random
 
-from PIL import Image, ImageOps
+from PIL import Image, ImageOps, ImageSequence
 from PIL.PngImagePlugin import PngInfo
 import numpy as np
 import safetensors.torch
@@ -1412,17 +1412,30 @@ class LoadImage:
     FUNCTION = "load_image"
     def load_image(self, image):
         image_path = ldm_patched.utils.path_utils.get_annotated_filepath(image)
-        i = Image.open(image_path)
-        i = ImageOps.exif_transpose(i)
-        image = i.convert("RGB")
-        image = np.array(image).astype(np.float32) / 255.0
-        image = torch.from_numpy(image)[None,]
-        if 'A' in i.getbands():
-            mask = np.array(i.getchannel('A')).astype(np.float32) / 255.0
-            mask = 1. - torch.from_numpy(mask)
+        img = Image.open(image_path)
+        output_images = []
+        output_masks = []
+        for i in ImageSequence.Iterator(img):
+            i = ImageOps.exif_transpose(i)
+            image = i.convert("RGB")
+            image = np.array(image).astype(np.float32) / 255.0
+            image = torch.from_numpy(image)[None,]
+            if 'A' in i.getbands():
+                mask = np.array(i.getchannel('A')).astype(np.float32) / 255.0
+                mask = 1. - torch.from_numpy(mask)
+            else:
+                mask = torch.zeros((64,64), dtype=torch.float32, device="cpu")
+            output_images.append(image)
+            output_masks.append(mask.unsqueeze(0))
+
+        if len(output_images) > 1:
+            output_image = torch.cat(output_images, dim=0)
+            output_mask = torch.cat(output_masks, dim=0)
         else:
-            mask = torch.zeros((64,64), dtype=torch.float32, device="cpu")
-        return (image, mask.unsqueeze(0))
+            output_image = output_images[0]
+            output_mask = output_masks[0]
+
+        return (output_image, output_mask)
 
     @classmethod
     def IS_CHANGED(s, image):
@@ -1480,13 +1493,10 @@ class LoadImageMask:
         return m.digest().hex()
 
     @classmethod
-    def VALIDATE_INPUTS(s, image, channel):
+    def VALIDATE_INPUTS(s, image):
         if not ldm_patched.utils.path_utils.exists_annotated_filepath(image):
             return "Invalid image file: {}".format(image)
 
-        if channel not in s._color_channels:
-            return "Invalid color channel: {}".format(channel)
-
         return True
 
 class ImageScale:
@@ -1871,6 +1881,7 @@ def init_custom_nodes():
         "nodes_video_model.py",
         "nodes_sag.py",
         "nodes_perpneg.py",
+        "nodes_stable3d.py",
    ]
 
     for node_file in extras_files:
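Note on the LoadImage change above: the node now iterates every frame of the file, so animated GIFs and multi-frame TIFFs load as a batched tensor instead of only the first frame. A minimal standalone sketch of the same pattern, using only Pillow and PyTorch (the input path is hypothetical):

```python
import numpy as np
import torch
from PIL import Image, ImageOps, ImageSequence

img = Image.open("animation.gif")            # hypothetical multi-frame input
frames = []
for frame in ImageSequence.Iterator(img):
    frame = ImageOps.exif_transpose(frame).convert("RGB")
    frames.append(torch.from_numpy(np.array(frame).astype(np.float32) / 255.0)[None,])
batch = torch.cat(frames, dim=0) if len(frames) > 1 else frames[0]
print(batch.shape)                           # [T, H, W, C] for a T-frame file
```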
diff --git a/ldm_patched/contrib/external_custom_sampler.py b/ldm_patched/contrib/external_custom_sampler.py
index 9413a58f..6e5a769b 100644
--- a/ldm_patched/contrib/external_custom_sampler.py
+++ b/ldm_patched/contrib/external_custom_sampler.py
@@ -89,6 +89,7 @@ class SDTurboScheduler:
         return {"required":
                     {"model": ("MODEL",),
                      "steps": ("INT", {"default": 1, "min": 1, "max": 10}),
+                     "denoise": ("FLOAT", {"default": 1.0, "min": 0, "max": 1.0, "step": 0.01}),
                     }
                }
     RETURN_TYPES = ("SIGMAS",)
@@ -96,8 +97,9 @@ class SDTurboScheduler:
 
     FUNCTION = "get_sigmas"
 
-    def get_sigmas(self, model, steps):
-        timesteps = torch.flip(torch.arange(1, 11) * 100 - 1, (0,))[:steps]
+    def get_sigmas(self, model, steps, denoise):
+        start_step = 10 - int(10 * denoise)
+        timesteps = torch.flip(torch.arange(1, 11) * 100 - 1, (0,))[start_step:start_step + steps]
         sigmas = model.model.model_sampling.sigma(timesteps)
         sigmas = torch.cat([sigmas, sigmas.new_zeros([1])])
         return (sigmas, )
diff --git a/ldm_patched/contrib/external_mask.py b/ldm_patched/contrib/external_mask.py
index ab1da4c6..a86a7fe6 100644
--- a/ldm_patched/contrib/external_mask.py
+++ b/ldm_patched/contrib/external_mask.py
@@ -8,6 +8,7 @@ import ldm_patched.modules.utils
 from ldm_patched.contrib.external import MAX_RESOLUTION
 
 def composite(destination, source, x, y, mask = None, multiplier = 8, resize_source = False):
+    source = source.to(destination.device)
     if resize_source:
         source = torch.nn.functional.interpolate(source, size=(destination.shape[2], destination.shape[3]), mode="bilinear")
 
@@ -22,7 +23,7 @@ def composite(destination, source, x, y, mask = None, multiplier = 8, resize_source = False):
     if mask is None:
         mask = torch.ones_like(source)
     else:
-        mask = mask.clone()
+        mask = mask.to(destination.device, copy=True)
         mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(source.shape[2], source.shape[3]), mode="bilinear")
         mask = ldm_patched.modules.utils.repeat_to_batch_size(mask, source.shape[0])
 
diff --git a/ldm_patched/contrib/external_rebatch.py b/ldm_patched/contrib/external_rebatch.py
index 607c7feb..c24cc8c3 100644
--- a/ldm_patched/contrib/external_rebatch.py
+++ b/ldm_patched/contrib/external_rebatch.py
@@ -101,10 +101,40 @@ class LatentRebatch:
 
         return (output_list,)
 
+class ImageRebatch:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": { "images": ("IMAGE",),
+                              "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
+                              }}
+    RETURN_TYPES = ("IMAGE",)
+    INPUT_IS_LIST = True
+    OUTPUT_IS_LIST = (True, )
+
+    FUNCTION = "rebatch"
+
+    CATEGORY = "image/batch"
+
+    def rebatch(self, images, batch_size):
+        batch_size = batch_size[0]
+
+        output_list = []
+        all_images = []
+        for img in images:
+            for i in range(img.shape[0]):
+                all_images.append(img[i:i+1])
+
+        for i in range(0, len(all_images), batch_size):
+            output_list.append(torch.cat(all_images[i:i+batch_size], dim=0))
+
+        return (output_list,)
+
 NODE_CLASS_MAPPINGS = {
     "RebatchLatents": LatentRebatch,
+    "RebatchImages": ImageRebatch,
 }
 
 NODE_DISPLAY_NAME_MAPPINGS = {
     "RebatchLatents": "Rebatch Latents",
-}
\ No newline at end of file
+    "RebatchImages": "Rebatch Images",
+}
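The new `denoise` parameter on SDTurboScheduler selects where on the 10-entry turbo timestep grid sampling starts. A quick worked check of the arithmetic from the hunk above:

```python
import torch

steps, denoise = 2, 0.5
start_step = 10 - int(10 * denoise)                      # -> 5
grid = torch.flip(torch.arange(1, 11) * 100 - 1, (0,))   # [999, 899, ..., 99]
print(grid[start_step:start_step + steps])               # tensor([499, 399])
# denoise=1.0 gives start_step=0 and reproduces the old behavior (grid[:steps]).
```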
diff --git a/ldm_patched/contrib/external_sag.py b/ldm_patched/contrib/external_sag.py
index 06ca67fa..9cffe879 100644
--- a/ldm_patched/contrib/external_sag.py
+++ b/ldm_patched/contrib/external_sag.py
@@ -153,7 +153,7 @@ class SelfAttentionGuidance:
             (sag, _) = ldm_patched.modules.samplers.calc_cond_uncond_batch(model, uncond, None, degraded_noised, sigma, model_options)
             return cfg_result + (degraded - sag) * sag_scale
 
-        m.set_model_sampler_post_cfg_function(post_cfg_function)
+        m.set_model_sampler_post_cfg_function(post_cfg_function, disable_cfg1_optimization=True)
 
         # from diffusers:
         # unet.mid_block.attentions[0].transformer_blocks[0].attn1.patch
diff --git a/ldm_patched/contrib/external_stable3d.py b/ldm_patched/contrib/external_stable3d.py
new file mode 100644
index 00000000..2913a3d0
--- /dev/null
+++ b/ldm_patched/contrib/external_stable3d.py
@@ -0,0 +1,60 @@
+# https://github.com/comfyanonymous/ComfyUI/blob/master/nodes.py
+
+import torch
+import ldm_patched.contrib.external
+import ldm_patched.modules.utils
+
+def camera_embeddings(elevation, azimuth):
+    elevation = torch.as_tensor([elevation])
+    azimuth = torch.as_tensor([azimuth])
+    embeddings = torch.stack(
+        [
+            torch.deg2rad(
+                (90 - elevation) - (90)
+            ),  # Zero123 polar is 90-elevation
+            torch.sin(torch.deg2rad(azimuth)),
+            torch.cos(torch.deg2rad(azimuth)),
+            torch.deg2rad(
+                90 - torch.full_like(elevation, 0)
+            ),
+        ], dim=-1).unsqueeze(1)
+
+    return embeddings
+
+
+class StableZero123_Conditioning:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": { "clip_vision": ("CLIP_VISION",),
+                              "init_image": ("IMAGE",),
+                              "vae": ("VAE",),
+                              "width": ("INT", {"default": 256, "min": 16, "max": ldm_patched.contrib.external.MAX_RESOLUTION, "step": 8}),
+                              "height": ("INT", {"default": 256, "min": 16, "max": ldm_patched.contrib.external.MAX_RESOLUTION, "step": 8}),
+                              "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
+                              "elevation": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0}),
+                              "azimuth": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0}),
+                             }}
+    RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
+    RETURN_NAMES = ("positive", "negative", "latent")
+
+    FUNCTION = "encode"
+
+    CATEGORY = "conditioning/3d_models"
+
+    def encode(self, clip_vision, init_image, vae, width, height, batch_size, elevation, azimuth):
+        output = clip_vision.encode_image(init_image)
+        pooled = output.image_embeds.unsqueeze(0)
+        pixels = ldm_patched.modules.utils.common_upscale(init_image.movedim(-1,1), width, height, "bilinear", "center").movedim(1,-1)
+        encode_pixels = pixels[:,:,:,:3]
+        t = vae.encode(encode_pixels)
+        cam_embeds = camera_embeddings(elevation, azimuth)
+        cond = torch.cat([pooled, cam_embeds.repeat((pooled.shape[0], 1, 1))], dim=-1)
+
+        positive = [[cond, {"concat_latent_image": t}]]
+        negative = [[torch.zeros_like(pooled), {"concat_latent_image": torch.zeros_like(t)}]]
+        latent = torch.zeros([batch_size, 4, height // 8, width // 8])
+        return (positive, negative, {"samples":latent})
+
+NODE_CLASS_MAPPINGS = {
+    "StableZero123_Conditioning": StableZero123_Conditioning,
+}
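For orientation, `camera_embeddings` packs four numbers per view: the (negated) elevation in radians, sine and cosine of the azimuth, and a constant deg2rad(90). Worked values for elevation 10° and azimuth 30°, computed from the function above:

```python
import torch

emb = camera_embeddings(10.0, 30.0)   # function from the new file above
print(emb.shape)   # torch.Size([1, 1, 4])
print(emb)         # ≈ [[[-0.1745, 0.5000, 0.8660, 1.5708]]]
# i.e. [deg2rad(-elevation), sin(azimuth), cos(azimuth), deg2rad(90)]
```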
diff --git a/ldm_patched/ldm/models/autoencoder.py b/ldm_patched/ldm/models/autoencoder.py
index 14224ad3..c809a0c3 100644
--- a/ldm_patched/ldm/models/autoencoder.py
+++ b/ldm_patched/ldm/models/autoencoder.py
@@ -8,6 +8,7 @@ from ldm_patched.ldm.modules.distributions.distributions import DiagonalGaussianDistribution
 from ldm_patched.ldm.util import instantiate_from_config
 from ldm_patched.ldm.modules.ema import LitEma
+import ldm_patched.modules.ops
 
 
 class DiagonalGaussianRegularizer(torch.nn.Module):
     def __init__(self, sample: bool = True):
@@ -161,12 +162,12 @@ class AutoencodingEngineLegacy(AutoencodingEngine):
             },
             **kwargs,
         )
-        self.quant_conv = torch.nn.Conv2d(
+        self.quant_conv = ldm_patched.modules.ops.disable_weight_init.Conv2d(
             (1 + ddconfig["double_z"]) * ddconfig["z_channels"],
             (1 + ddconfig["double_z"]) * embed_dim,
             1,
         )
-        self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
+        self.post_quant_conv = ldm_patched.modules.ops.disable_weight_init.Conv2d(embed_dim, ddconfig["z_channels"], 1)
         self.embed_dim = embed_dim
 
     def get_autoencoder_params(self) -> list:
diff --git a/ldm_patched/ldm/modules/diffusionmodules/model.py b/ldm_patched/ldm/modules/diffusionmodules/model.py
index 9c898639..1901145c 100644
--- a/ldm_patched/ldm/modules/diffusionmodules/model.py
+++ b/ldm_patched/ldm/modules/diffusionmodules/model.py
@@ -41,7 +41,7 @@ def nonlinearity(x):
 
 
 def Normalize(in_channels, num_groups=32):
-    return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
+    return ops.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
 
 
 class Upsample(nn.Module):
diff --git a/ldm_patched/ldm/modules/diffusionmodules/upscaling.py b/ldm_patched/ldm/modules/diffusionmodules/upscaling.py
index 59d4d3cc..2cde80c5 100644
--- a/ldm_patched/ldm/modules/diffusionmodules/upscaling.py
+++ b/ldm_patched/ldm/modules/diffusionmodules/upscaling.py
@@ -43,8 +43,8 @@ class AbstractLowScaleModel(nn.Module):
 
     def q_sample(self, x_start, t, noise=None):
         noise = default(noise, lambda: torch.randn_like(x_start))
-        return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
-                extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)
+        return (extract_into_tensor(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start +
+                extract_into_tensor(self.sqrt_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) * noise)
 
     def forward(self, x):
         return x, None
diff --git a/ldm_patched/ldm/modules/diffusionmodules/util.py b/ldm_patched/ldm/modules/diffusionmodules/util.py
index ca0f4b99..e261e06a 100644
--- a/ldm_patched/ldm/modules/diffusionmodules/util.py
+++ b/ldm_patched/ldm/modules/diffusionmodules/util.py
@@ -51,9 +51,9 @@ class AlphaBlender(nn.Module):
         if self.merge_strategy == "fixed":
             # make shape compatible
             # alpha = repeat(self.mix_factor, '1 -> b () t () ()', t=t, b=bs)
-            alpha = self.mix_factor
+            alpha = self.mix_factor.to(image_only_indicator.device)
         elif self.merge_strategy == "learned":
-            alpha = torch.sigmoid(self.mix_factor)
+            alpha = torch.sigmoid(self.mix_factor.to(image_only_indicator.device))
             # make shape compatible
             # alpha = repeat(alpha, '1 -> s () ()', s = t * bs)
         elif self.merge_strategy == "learned_with_images":
@@ -61,7 +61,7 @@ class AlphaBlender(nn.Module):
             alpha = torch.where(
                 image_only_indicator.bool(),
                 torch.ones(1, 1, device=image_only_indicator.device),
-                rearrange(torch.sigmoid(self.mix_factor), "... -> ... 1"),
+                rearrange(torch.sigmoid(self.mix_factor.to(image_only_indicator.device)), "... -> ... 1"),
             )
             alpha = rearrange(alpha, self.rearrange_pattern)
             # make shape compatible
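The AlphaBlender edits above are device moves only; the blend math is untouched: alpha = sigmoid(mix_factor), out = alpha·x + (1 − alpha)·x_mix. A tiny illustration of that blend:

```python
import torch

mix_factor = torch.tensor([0.0])              # a freshly initialized learned scalar
alpha = torch.sigmoid(mix_factor)             # -> 0.5: equal spatial/temporal mix
x, x_mix = torch.ones(3), torch.zeros(3)
print(alpha * x + (1.0 - alpha) * x_mix)      # tensor([0.5000, 0.5000, 0.5000])
```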
diff --git a/ldm_patched/ldm/modules/encoders/noise_aug_modules.py b/ldm_patched/ldm/modules/encoders/noise_aug_modules.py
index b59bf204..66767b58 100644
--- a/ldm_patched/ldm/modules/encoders/noise_aug_modules.py
+++ b/ldm_patched/ldm/modules/encoders/noise_aug_modules.py
@@ -15,12 +15,12 @@ class CLIPEmbeddingNoiseAugmentation(ImageConcatWithNoiseAugmentation):
 
     def scale(self, x):
         # re-normalize to centered mean and unit variance
-        x = (x - self.data_mean) * 1. / self.data_std
+        x = (x - self.data_mean.to(x.device)) * 1. / self.data_std.to(x.device)
         return x
 
     def unscale(self, x):
         # back to original data stats
-        x = (x * self.data_std) + self.data_mean
+        x = (x * self.data_std.to(x.device)) + self.data_mean.to(x.device)
         return x
 
     def forward(self, x, noise_level=None):
diff --git a/ldm_patched/ldm/modules/temporal_ae.py b/ldm_patched/ldm/modules/temporal_ae.py
index 248d850b..ee851921 100644
--- a/ldm_patched/ldm/modules/temporal_ae.py
+++ b/ldm_patched/ldm/modules/temporal_ae.py
@@ -82,14 +82,14 @@ class VideoResBlock(ResnetBlock):
 
         x = self.time_stack(x, temb)
 
-        alpha = self.get_alpha(bs=b // timesteps)
+        alpha = self.get_alpha(bs=b // timesteps).to(x.device)
         x = alpha * x + (1.0 - alpha) * x_mix
 
         x = rearrange(x, "b c t h w -> (b t) c h w")
         return x
 
 
-class AE3DConv(torch.nn.Conv2d):
+class AE3DConv(ops.Conv2d):
     def __init__(self, in_channels, out_channels, video_kernel_size=3, *args, **kwargs):
         super().__init__(in_channels, out_channels, *args, **kwargs)
         if isinstance(video_kernel_size, Iterable):
@@ -97,7 +97,7 @@ class AE3DConv(torch.nn.Conv2d):
         else:
             padding = int(video_kernel_size // 2)
 
-        self.time_mix_conv = torch.nn.Conv3d(
+        self.time_mix_conv = ops.Conv3d(
             in_channels=out_channels,
             out_channels=out_channels,
             kernel_size=video_kernel_size,
@@ -167,7 +167,7 @@ class AttnVideoBlock(AttnBlock):
             emb = emb[:, None, :]
             x_mix = x_mix + emb
 
-        alpha = self.get_alpha()
+        alpha = self.get_alpha().to(x.device)
         x_mix = self.time_mix_block(x_mix, timesteps=timesteps)
         x = alpha * x + (1.0 - alpha) * x_mix  # alpha merge
diff --git a/ldm_patched/modules/args_parser.py b/ldm_patched/modules/args_parser.py
index 7957783e..7ffc4a81 100644
--- a/ldm_patched/modules/args_parser.py
+++ b/ldm_patched/modules/args_parser.py
@@ -66,6 +66,8 @@ fpvae_group.add_argument("--vae-in-fp16", action="store_true")
 fpvae_group.add_argument("--vae-in-fp32", action="store_true")
 fpvae_group.add_argument("--vae-in-bf16", action="store_true")
 
+parser.add_argument("--vae-in-cpu", action="store_true")
+
 fpte_group = parser.add_mutually_exclusive_group()
 fpte_group.add_argument("--clip-in-fp8-e4m3fn", action="store_true")
 fpte_group.add_argument("--clip-in-fp8-e5m2", action="store_true")
diff --git a/ldm_patched/modules/clip_model.py b/ldm_patched/modules/clip_model.py
index e7f3fb2d..4c4588c3 100644
--- a/ldm_patched/modules/clip_model.py
+++ b/ldm_patched/modules/clip_model.py
@@ -151,7 +151,7 @@ class CLIPVisionEmbeddings(torch.nn.Module):
 
     def forward(self, pixel_values):
         embeds = self.patch_embedding(pixel_values).flatten(2).transpose(1, 2)
-        return torch.cat([self.class_embedding.expand(pixel_values.shape[0], 1, -1), embeds], dim=1) + self.position_embedding.weight
+        return torch.cat([self.class_embedding.to(embeds.device).expand(pixel_values.shape[0], 1, -1), embeds], dim=1) + self.position_embedding.weight.to(embeds.device)
 
 
 class CLIPVision(torch.nn.Module):
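The noise-aug `scale()`/`unscale()` pair above is plain per-channel standardization and its inverse; the patch only moves the stored stats onto x's device. Round-trip check with toy values:

```python
import torch

data_mean, data_std = torch.tensor([0.5]), torch.tensor([2.0])
x = torch.tensor([3.0])
x_hat = (x - data_mean) / data_std            # scale():   (3.0 - 0.5) / 2.0 = 1.25
x_back = x_hat * data_std + data_mean         # unscale(): 1.25 * 2.0 + 0.5 = 3.0
assert torch.allclose(x, x_back)
```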
diff --git a/ldm_patched/modules/controlnet.py b/ldm_patched/modules/controlnet.py
index e478e221..a7224660 100644
--- a/ldm_patched/modules/controlnet.py
+++ b/ldm_patched/modules/controlnet.py
@@ -283,7 +283,7 @@ class ControlLora(ControlNet):
         cm = self.control_model.state_dict()
 
         for k in sd:
-            weight = ldm_patched.modules.model_management.resolve_lowvram_weight(sd[k], diffusion_model, k)
+            weight = sd[k]
             try:
                 ldm_patched.modules.utils.set_attr(self.control_model, k, weight)
             except:
diff --git a/ldm_patched/modules/model_base.py b/ldm_patched/modules/model_base.py
index 1374a669..c04ccb3e 100644
--- a/ldm_patched/modules/model_base.py
+++ b/ldm_patched/modules/model_base.py
@@ -126,9 +126,15 @@ class BaseModel(torch.nn.Module):
                 cond_concat.append(blank_inpaint_image_like(noise))
             data = torch.cat(cond_concat, dim=1)
             out['c_concat'] = ldm_patched.modules.conds.CONDNoiseShape(data)
+
         adm = self.encode_adm(**kwargs)
         if adm is not None:
             out['y'] = ldm_patched.modules.conds.CONDRegular(adm)
+
+        cross_attn = kwargs.get("cross_attn", None)
+        if cross_attn is not None:
+            out['c_crossattn'] = ldm_patched.modules.conds.CONDCrossAttn(cross_attn)
+
         return out
 
     def load_model_weights(self, sd, unet_prefix=""):
@@ -156,11 +162,7 @@ class BaseModel(torch.nn.Module):
 
     def state_dict_for_saving(self, clip_state_dict, vae_state_dict):
         clip_state_dict = self.model_config.process_clip_state_dict_for_saving(clip_state_dict)
-        unet_sd = self.diffusion_model.state_dict()
-        unet_state_dict = {}
-        for k in unet_sd:
-            unet_state_dict[k] = ldm_patched.modules.model_management.resolve_lowvram_weight(unet_sd[k], self.diffusion_model, k)
-
+        unet_state_dict = self.diffusion_model.state_dict()
         unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict)
         vae_state_dict = self.model_config.process_vae_state_dict_for_saving(vae_state_dict)
         if self.get_dtype() == torch.float16:
@@ -322,9 +324,43 @@ class SVD_img2vid(BaseModel):
 
         out['c_concat'] = ldm_patched.modules.conds.CONDNoiseShape(latent_image)
 
+        cross_attn = kwargs.get("cross_attn", None)
+        if cross_attn is not None:
+            out['c_crossattn'] = ldm_patched.modules.conds.CONDCrossAttn(cross_attn)
+
         if "time_conditioning" in kwargs:
             out["time_context"] = ldm_patched.modules.conds.CONDCrossAttn(kwargs["time_conditioning"])
 
         out['image_only_indicator'] = ldm_patched.modules.conds.CONDConstant(torch.zeros((1,), device=device))
         out['num_video_frames'] = ldm_patched.modules.conds.CONDConstant(noise.shape[0])
         return out
+
+class Stable_Zero123(BaseModel):
+    def __init__(self, model_config, model_type=ModelType.EPS, device=None, cc_projection_weight=None, cc_projection_bias=None):
+        super().__init__(model_config, model_type, device=device)
+        self.cc_projection = ldm_patched.modules.ops.manual_cast.Linear(cc_projection_weight.shape[1], cc_projection_weight.shape[0], dtype=self.get_dtype(), device=device)
+        self.cc_projection.weight.copy_(cc_projection_weight)
+        self.cc_projection.bias.copy_(cc_projection_bias)
+
+    def extra_conds(self, **kwargs):
+        out = {}
+
+        latent_image = kwargs.get("concat_latent_image", None)
+        noise = kwargs.get("noise", None)
+
+        if latent_image is None:
+            latent_image = torch.zeros_like(noise)
+
+        if latent_image.shape[1:] != noise.shape[1:]:
+            latent_image = utils.common_upscale(latent_image, noise.shape[-1], noise.shape[-2], "bilinear", "center")
+
+        latent_image = utils.resize_to_batch_size(latent_image, noise.shape[0])
+
+        out['c_concat'] = ldm_patched.modules.conds.CONDNoiseShape(latent_image)
+
+        cross_attn = kwargs.get("cross_attn", None)
+        if cross_attn is not None:
+            if cross_attn.shape[-1] != 768:
+                cross_attn = self.cc_projection(cross_attn)
+            out['c_crossattn'] = ldm_patched.modules.conds.CONDCrossAttn(cross_attn)
+        return out
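Stable Zero123's conditioning is the CLIP-vision image embedding with the four camera numbers appended, projected back to the 768-dim context an SD1.x UNet expects. A shape walk-through under assumed dimensions (only the 768 figure is explicit in the code, via the `!= 768` check; the 772 follows from 768 + 4):

```python
import torch

pooled = torch.zeros(1, 1, 768)                # CLIP-vision embedding (assumed size)
cam = torch.zeros(1, 1, 4)                     # camera_embeddings output
cond = torch.cat([pooled, cam], dim=-1)        # [1, 1, 772]
cc_projection = torch.nn.Linear(772, 768)      # stand-in for the checkpoint's cc_projection
print(cc_projection(cond).shape)               # torch.Size([1, 1, 768]) -> cross-attention context
```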
diff --git a/ldm_patched/modules/model_management.py b/ldm_patched/modules/model_management.py
index 31cf95da..59f0f3d0 100644
--- a/ldm_patched/modules/model_management.py
+++ b/ldm_patched/modules/model_management.py
@@ -186,6 +186,9 @@ except:
 if is_intel_xpu():
     VAE_DTYPE = torch.bfloat16
 
+if args.vae_in_cpu:
+    VAE_DTYPE = torch.float32
+
 if args.vae_in_fp16:
     VAE_DTYPE = torch.float16
 elif args.vae_in_bf16:
@@ -218,15 +221,8 @@ if args.all_in_fp16:
     FORCE_FP16 = True
 
 if lowvram_available:
-    try:
-        import accelerate
-        if set_vram_to in (VRAMState.LOW_VRAM, VRAMState.NO_VRAM):
-            vram_state = set_vram_to
-    except Exception as e:
-        import traceback
-        print(traceback.format_exc())
-        print("ERROR: LOW VRAM MODE NEEDS accelerate.")
-        lowvram_available = False
+    if set_vram_to in (VRAMState.LOW_VRAM, VRAMState.NO_VRAM):
+        vram_state = set_vram_to
 
 
 if cpu_state != CPUState.GPU:
@@ -266,6 +262,14 @@ print("VAE dtype:", VAE_DTYPE)
 
 current_loaded_models = []
 
+def module_size(module):
+    module_mem = 0
+    sd = module.state_dict()
+    for k in sd:
+        t = sd[k]
+        module_mem += t.nelement() * t.element_size()
+    return module_mem
+
 class LoadedModel:
     def __init__(self, model):
         self.model = model
@@ -298,8 +302,20 @@ class LoadedModel:
 
         if lowvram_model_memory > 0:
             print("loading in lowvram mode", lowvram_model_memory/(1024 * 1024))
-            device_map = accelerate.infer_auto_device_map(self.real_model, max_memory={0: "{}MiB".format(lowvram_model_memory // (1024 * 1024)), "cpu": "16GiB"})
-            accelerate.dispatch_model(self.real_model, device_map=device_map, main_device=self.device)
+            mem_counter = 0
+            for m in self.real_model.modules():
+                if hasattr(m, "ldm_patched_cast_weights"):
+                    m.prev_ldm_patched_cast_weights = m.ldm_patched_cast_weights
+                    m.ldm_patched_cast_weights = True
+                    module_mem = module_size(m)
+                    if mem_counter + module_mem < lowvram_model_memory:
+                        m.to(self.device)
+                        mem_counter += module_mem
+                elif hasattr(m, "weight"): #only modules with ldm_patched_cast_weights can be set to lowvram mode
+                    m.to(self.device)
+                    mem_counter += module_size(m)
+                    print("lowvram: loaded module regularly", m)
+
             self.model_accelerated = True
 
         if is_intel_xpu() and not args.disable_ipex_hijack:
@@ -309,7 +325,11 @@ class LoadedModel:
 
     def model_unload(self):
         if self.model_accelerated:
-            accelerate.hooks.remove_hook_from_submodules(self.real_model)
+            for m in self.real_model.modules():
+                if hasattr(m, "prev_ldm_patched_cast_weights"):
+                    m.ldm_patched_cast_weights = m.prev_ldm_patched_cast_weights
+                    del m.prev_ldm_patched_cast_weights
+
             self.model_accelerated = False
 
         self.model.unpatch_model(self.model.offload_device)
@@ -402,14 +422,14 @@ def load_models_gpu(models, memory_required=0):
         if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM):
             model_size = loaded_model.model_memory_required(torch_dev)
             current_free_mem = get_free_memory(torch_dev)
-            lowvram_model_memory = int(max(256 * (1024 * 1024), (current_free_mem - 1024 * (1024 * 1024)) / 1.3 ))
+            lowvram_model_memory = int(max(64 * (1024 * 1024), (current_free_mem - 1024 * (1024 * 1024)) / 1.3 ))
             if model_size > (current_free_mem - inference_memory): #only switch to lowvram if really necessary
                 vram_set_state = VRAMState.LOW_VRAM
         else:
             lowvram_model_memory = 0
 
         if vram_set_state == VRAMState.NO_VRAM:
-            lowvram_model_memory = 256 * 1024 * 1024
+            lowvram_model_memory = 64 * 1024 * 1024
 
         cur_loaded_model = loaded_model.model_load(lowvram_model_memory)
         current_loaded_models.insert(0, loaded_model)
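The accelerate dependency is replaced by a hand-rolled lowvram loader: module_size() sums state_dict tensor bytes, and the loader walks submodules, moving each to the GPU until the budget (roughly free VRAM minus 1 GiB, divided by 1.3, floored at 64 MiB) runs out; the rest stays offloaded and is cast per call. Sanity check of the size accounting:

```python
import torch

def module_size(module):   # same definition as in the diff above
    return sum(t.nelement() * t.element_size() for t in module.state_dict().values())

m = torch.nn.Linear(1024, 1024)   # fp32 weight + bias
print(module_size(m))             # (1024*1024 + 1024) * 4 = 4198400 bytes ≈ 4 MiB
```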
@@ -538,6 +558,8 @@ def intermediate_device():
         return torch.device("cpu")
 
 def vae_device():
+    if args.vae_in_cpu:
+        return torch.device("cpu")
     return get_torch_device()
 
 def vae_offload_device():
@@ -566,6 +588,11 @@ def supports_dtype(device, dtype): #TODO
         return True
     return False
 
+def device_supports_non_blocking(device):
+    if is_device_mps(device):
+        return False #pytorch bug? mps doesn't support non blocking
+    return True
+
 def cast_to_device(tensor, device, dtype, copy=False):
     device_supports_cast = False
     if tensor.dtype == torch.float32 or tensor.dtype == torch.float16:
@@ -576,9 +603,7 @@ def cast_to_device(tensor, device, dtype, copy=False):
     elif is_intel_xpu():
         device_supports_cast = True
 
-    non_blocking = True
-    if is_device_mps(device):
-        non_blocking = False #pytorch bug? mps doesn't support non blocking
+    non_blocking = device_supports_non_blocking(device)
 
     if device_supports_cast:
         if copy:
@@ -742,11 +767,11 @@ def soft_empty_cache(force=False):
             torch.cuda.empty_cache()
             torch.cuda.ipc_collect()
 
-def resolve_lowvram_weight(weight, model, key):
-    if weight.device == torch.device("meta"): #lowvram NOTE: this depends on the inner working of the accelerate library so it might break.
-        key_split = key.split('.')            # I have no idea why they don't just leave the weight there instead of using the meta device.
-        op = ldm_patched.modules.utils.get_attr(model, '.'.join(key_split[:-1]))
-        weight = op._hf_hook.weights_map[key_split[-1]]
+def unload_all_models():
+    free_memory(1e30, get_torch_device())
+
+
+def resolve_lowvram_weight(weight, model, key): #TODO: remove
     return weight
 
 #TODO: might be cleaner to put this somewhere else
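device_supports_non_blocking() centralizes the MPS workaround so the new cast_bias_weight in ops.py (below) can reuse it. A sketch of the intended call pattern, with the MPS check reduced to a .type comparison (an assumption about what is_device_mps tests):

```python
import torch

def device_supports_non_blocking(device):
    return device.type != "mps"   # assumed equivalent to the is_device_mps check

t = torch.randn(4)
dev = torch.device("cpu")
t = t.to(device=dev, dtype=torch.float16, non_blocking=device_supports_non_blocking(dev))
```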
diff --git a/ldm_patched/modules/model_patcher.py b/ldm_patched/modules/model_patcher.py
index ae795ca9..0945a13c 100644
--- a/ldm_patched/modules/model_patcher.py
+++ b/ldm_patched/modules/model_patcher.py
@@ -28,13 +28,9 @@ class ModelPatcher:
         if self.size > 0:
             return self.size
         model_sd = self.model.state_dict()
-        size = 0
-        for k in model_sd:
-            t = model_sd[k]
-            size += t.nelement() * t.element_size()
-        self.size = size
+        self.size = ldm_patched.modules.model_management.module_size(self.model)
         self.model_keys = set(model_sd.keys())
-        return size
+        return self.size
 
     def clone(self):
         n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size, self.current_device, weight_inplace_update=self.weight_inplace_update)
@@ -55,14 +51,18 @@ class ModelPatcher:
     def memory_required(self, input_shape):
         return self.model.memory_required(input_shape=input_shape)
 
-    def set_model_sampler_cfg_function(self, sampler_cfg_function):
+    def set_model_sampler_cfg_function(self, sampler_cfg_function, disable_cfg1_optimization=False):
         if len(inspect.signature(sampler_cfg_function).parameters) == 3:
             self.model_options["sampler_cfg_function"] = lambda args: sampler_cfg_function(args["cond"], args["uncond"], args["cond_scale"]) #Old way
         else:
             self.model_options["sampler_cfg_function"] = sampler_cfg_function
+        if disable_cfg1_optimization:
+            self.model_options["disable_cfg1_optimization"] = True
 
-    def set_model_sampler_post_cfg_function(self, post_cfg_function):
+    def set_model_sampler_post_cfg_function(self, post_cfg_function, disable_cfg1_optimization=False):
         self.model_options["sampler_post_cfg_function"] = self.model_options.get("sampler_post_cfg_function", []) + [post_cfg_function]
+        if disable_cfg1_optimization:
+            self.model_options["disable_cfg1_optimization"] = True
 
     def set_model_unet_function_wrapper(self, unet_wrapper_function):
         self.model_options["model_function_wrapper"] = unet_wrapper_function
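The disable_cfg1_optimization flag exists because at cond_scale == 1 the samplers normally skip the uncond pass entirely; post-CFG hooks such as SelfAttentionGuidance need that batch, so they register with the flag set. The decision the patched samplers now make, in isolation:

```python
import math

def needs_uncond(cond_scale, model_options):
    # mirrors the checks added in samplers.py and modules/patch.py below
    skip = math.isclose(cond_scale, 1.0) and not model_options.get("disable_cfg1_optimization", False)
    return not skip

print(needs_uncond(1.0, {}))                                    # False -> uncond pass skipped
print(needs_uncond(1.0, {"disable_cfg1_optimization": True}))   # True  -> SAG et al. keep it
print(needs_uncond(7.0, {}))                                    # True  -> normal CFG
```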
diff --git a/ldm_patched/modules/ops.py b/ldm_patched/modules/ops.py
index 08c63384..435aba57 100644
--- a/ldm_patched/modules/ops.py
+++ b/ldm_patched/modules/ops.py
@@ -1,27 +1,93 @@
 import torch
 from contextlib import contextmanager
+import ldm_patched.modules.model_management
+
+def cast_bias_weight(s, input):
+    bias = None
+    non_blocking = ldm_patched.modules.model_management.device_supports_non_blocking(input.device)
+    if s.bias is not None:
+        bias = s.bias.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking)
+    weight = s.weight.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking)
+    return weight, bias
+
 
 class disable_weight_init:
     class Linear(torch.nn.Linear):
+        ldm_patched_cast_weights = False
         def reset_parameters(self):
             return None
 
+        def forward_ldm_patched_cast_weights(self, input):
+            weight, bias = cast_bias_weight(self, input)
+            return torch.nn.functional.linear(input, weight, bias)
+
+        def forward(self, *args, **kwargs):
+            if self.ldm_patched_cast_weights:
+                return self.forward_ldm_patched_cast_weights(*args, **kwargs)
+            else:
+                return super().forward(*args, **kwargs)
+
     class Conv2d(torch.nn.Conv2d):
+        ldm_patched_cast_weights = False
         def reset_parameters(self):
             return None
 
+        def forward_ldm_patched_cast_weights(self, input):
+            weight, bias = cast_bias_weight(self, input)
+            return self._conv_forward(input, weight, bias)
+
+        def forward(self, *args, **kwargs):
+            if self.ldm_patched_cast_weights:
+                return self.forward_ldm_patched_cast_weights(*args, **kwargs)
+            else:
+                return super().forward(*args, **kwargs)
+
     class Conv3d(torch.nn.Conv3d):
+        ldm_patched_cast_weights = False
         def reset_parameters(self):
             return None
 
+        def forward_ldm_patched_cast_weights(self, input):
+            weight, bias = cast_bias_weight(self, input)
+            return self._conv_forward(input, weight, bias)
+
+        def forward(self, *args, **kwargs):
+            if self.ldm_patched_cast_weights:
+                return self.forward_ldm_patched_cast_weights(*args, **kwargs)
+            else:
+                return super().forward(*args, **kwargs)
+
     class GroupNorm(torch.nn.GroupNorm):
+        ldm_patched_cast_weights = False
         def reset_parameters(self):
             return None
 
+        def forward_ldm_patched_cast_weights(self, input):
+            weight, bias = cast_bias_weight(self, input)
+            return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)
+
+        def forward(self, *args, **kwargs):
+            if self.ldm_patched_cast_weights:
+                return self.forward_ldm_patched_cast_weights(*args, **kwargs)
+            else:
+                return super().forward(*args, **kwargs)
+
+
     class LayerNorm(torch.nn.LayerNorm):
+        ldm_patched_cast_weights = False
         def reset_parameters(self):
             return None
 
+        def forward_ldm_patched_cast_weights(self, input):
+            weight, bias = cast_bias_weight(self, input)
+            return torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps)
+
+        def forward(self, *args, **kwargs):
+            if self.ldm_patched_cast_weights:
+                return self.forward_ldm_patched_cast_weights(*args, **kwargs)
+            else:
+                return super().forward(*args, **kwargs)
+
     @classmethod
     def conv_nd(s, dims, *args, **kwargs):
         if dims == 2:
@@ -31,35 +97,19 @@ class disable_weight_init:
         else:
             raise ValueError(f"unsupported dimensions: {dims}")
 
-def cast_bias_weight(s, input):
-    bias = None
-    if s.bias is not None:
-        bias = s.bias.to(device=input.device, dtype=input.dtype)
-    weight = s.weight.to(device=input.device, dtype=input.dtype)
-    return weight, bias
 
 class manual_cast(disable_weight_init):
     class Linear(disable_weight_init.Linear):
-        def forward(self, input):
-            weight, bias = cast_bias_weight(self, input)
-            return torch.nn.functional.linear(input, weight, bias)
+        ldm_patched_cast_weights = True
 
     class Conv2d(disable_weight_init.Conv2d):
-        def forward(self, input):
-            weight, bias = cast_bias_weight(self, input)
-            return self._conv_forward(input, weight, bias)
+        ldm_patched_cast_weights = True
 
     class Conv3d(disable_weight_init.Conv3d):
-        def forward(self, input):
-            weight, bias = cast_bias_weight(self, input)
-            return self._conv_forward(input, weight, bias)
+        ldm_patched_cast_weights = True
 
     class GroupNorm(disable_weight_init.GroupNorm):
-        def forward(self, input):
-            weight, bias = cast_bias_weight(self, input)
-            return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)
+        ldm_patched_cast_weights = True
 
     class LayerNorm(disable_weight_init.LayerNorm):
-        def forward(self, input):
-            weight, bias = cast_bias_weight(self, input)
-            return torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps)
+        ldm_patched_cast_weights = True
diff --git a/ldm_patched/modules/sample.py b/ldm_patched/modules/sample.py
index 7a7e3092..b5576cee 100644
--- a/ldm_patched/modules/sample.py
+++ b/ldm_patched/modules/sample.py
@@ -47,7 +47,8 @@ def convert_cond(cond):
         temp = c[1].copy()
         model_conds = temp.get("model_conds", {})
         if c[0] is not None:
-            model_conds["c_crossattn"] = ldm_patched.modules.conds.CONDCrossAttn(c[0])
+            model_conds["c_crossattn"] = ldm_patched.modules.conds.CONDCrossAttn(c[0]) #TODO: remove
+            temp["cross_attn"] = c[0]
         temp["model_conds"] = model_conds
         out.append(temp)
     return out
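After this refactor, the manual_cast classes differ from their disable_weight_init bases only by the ldm_patched_cast_weights class attribute, which is exactly what lets the new lowvram loader flip casting on per instance. Small usage sketch against the module above:

```python
from ldm_patched.modules.ops import disable_weight_init, manual_cast

layer = disable_weight_init.Linear(8, 8)   # behaves like a plain nn.Linear (minus weight init)
layer.ldm_patched_cast_weights = True      # from now on, weight/bias are cast to the input's
                                           # device/dtype on every forward, like manual_cast.Linear
assert manual_cast.Linear.ldm_patched_cast_weights is True
```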
diff --git a/ldm_patched/modules/samplers.py b/ldm_patched/modules/samplers.py
index bfcb3f56..fc17ef4d 100644
--- a/ldm_patched/modules/samplers.py
+++ b/ldm_patched/modules/samplers.py
@@ -244,7 +244,7 @@ def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options):
 #The main sampling function shared by all the samplers
 #Returns denoised
 def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_options={}, seed=None):
-    if math.isclose(cond_scale, 1.0):
+    if math.isclose(cond_scale, 1.0) and model_options.get("disable_cfg1_optimization", False) == False:
         uncond_ = None
     else:
         uncond_ = uncond
@@ -599,6 +599,13 @@ def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model_options={}, latent_image=None, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
     calculate_start_end_timesteps(model, negative)
     calculate_start_end_timesteps(model, positive)
 
+    if latent_image is not None:
+        latent_image = model.process_latent_in(latent_image)
+
+    if hasattr(model, 'extra_conds'):
+        positive = encode_model_conds(model.extra_conds, positive, noise, device, "positive", latent_image=latent_image, denoise_mask=denoise_mask)
+        negative = encode_model_conds(model.extra_conds, negative, noise, device, "negative", latent_image=latent_image, denoise_mask=denoise_mask)
+
     #make sure each cond area has an opposite one with the same area
     for c in positive:
         create_cond_with_same_area_if_none(negative, c)
@@ -610,13 +617,6 @@ def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model_options={}, latent_image=None, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
     apply_empty_x_to_equal_area(list(filter(lambda c: c.get('control_apply_to_uncond', False) == True, positive)), negative, 'control', lambda cond_cnets, x: cond_cnets[x])
     apply_empty_x_to_equal_area(positive, negative, 'gligen', lambda cond_cnets, x: cond_cnets[x])
 
-    if latent_image is not None:
-        latent_image = model.process_latent_in(latent_image)
-
-    if hasattr(model, 'extra_conds'):
-        positive = encode_model_conds(model.extra_conds, positive, noise, device, "positive", latent_image=latent_image, denoise_mask=denoise_mask)
-        negative = encode_model_conds(model.extra_conds, negative, noise, device, "negative", latent_image=latent_image, denoise_mask=denoise_mask)
-
     extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg, "model_options": model_options, "seed":seed}
 
     samples = sampler.sample(model_wrap, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar)
diff --git a/ldm_patched/modules/supported_models.py b/ldm_patched/modules/supported_models.py
index 2f2dee87..251bf6ac 100644
--- a/ldm_patched/modules/supported_models.py
+++ b/ldm_patched/modules/supported_models.py
@@ -252,5 +252,32 @@ class SVD_img2vid(supported_models_base.BASE):
     def clip_target(self):
         return None
 
-models = [SD15, SD20, SD21UnclipL, SD21UnclipH, SDXLRefiner, SDXL, SSD1B, Segmind_Vega]
+class Stable_Zero123(supported_models_base.BASE):
+    unet_config = {
+        "context_dim": 768,
+        "model_channels": 320,
+        "use_linear_in_transformer": False,
+        "adm_in_channels": None,
+        "use_temporal_attention": False,
+        "in_channels": 8,
+    }
+
+    unet_extra_config = {
+        "num_heads": 8,
+        "num_head_channels": -1,
+    }
+
+    clip_vision_prefix = "cond_stage_model.model.visual."
+
+    latent_format = latent_formats.SD15
+
+    def get_model(self, state_dict, prefix="", device=None):
+        out = model_base.Stable_Zero123(self, device=device, cc_projection_weight=state_dict["cc_projection.weight"], cc_projection_bias=state_dict["cc_projection.bias"])
+        return out
+
+    def clip_target(self):
+        return None
+
+
+models = [Stable_Zero123, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXLRefiner, SDXL, SSD1B, Segmind_Vega]
 models += [SVD_img2vid]
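Stable_Zero123 is detected purely from UNet config keys — in_channels=8 (4 latent + 4 concat channels) is the main thing separating it from plain SD1.5 — so it is prepended to `models` to let the stricter match win. A hypothetical sketch of first-match detection, not the repo's actual matcher:

```python
def detect(detected, candidates):
    # first candidate whose declared unet_config keys all match wins, so the
    # stricter Stable_Zero123 entry must be tried before the looser SD15 one
    for name, cfg in candidates:
        if all(detected.get(k) == v for k, v in cfg.items()):
            return name
    return None

candidates = [("Stable_Zero123", {"context_dim": 768, "in_channels": 8}),
              ("SD15", {"context_dim": 768})]
print(detect({"context_dim": 768, "in_channels": 8}, candidates))  # Stable_Zero123
print(detect({"context_dim": 768, "in_channels": 4}, candidates))  # SD15
```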
diff --git a/ldm_patched/taesd/taesd.py b/ldm_patched/taesd/taesd.py
index ac88e594..0b4b885f 100644
--- a/ldm_patched/taesd/taesd.py
+++ b/ldm_patched/taesd/taesd.py
@@ -7,9 +7,10 @@ import torch
 import torch.nn as nn
 
 import ldm_patched.modules.utils
+import ldm_patched.modules.ops
 
 def conv(n_in, n_out, **kwargs):
-    return nn.Conv2d(n_in, n_out, 3, padding=1, **kwargs)
+    return ldm_patched.modules.ops.disable_weight_init.Conv2d(n_in, n_out, 3, padding=1, **kwargs)
 
 class Clamp(nn.Module):
     def forward(self, x):
@@ -19,7 +20,7 @@ class Block(nn.Module):
     def __init__(self, n_in, n_out):
         super().__init__()
         self.conv = nn.Sequential(conv(n_in, n_out), nn.ReLU(), conv(n_out, n_out), nn.ReLU(), conv(n_out, n_out))
-        self.skip = nn.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity()
+        self.skip = ldm_patched.modules.ops.disable_weight_init.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity()
         self.fuse = nn.ReLU()
     def forward(self, x):
         return self.fuse(self.conv(x) + self.skip(x))
diff --git a/ldm_patched/utils/path_utils.py b/ldm_patched/utils/path_utils.py
index 34cd52c9..d21b6485 100644
--- a/ldm_patched/utils/path_utils.py
+++ b/ldm_patched/utils/path_utils.py
@@ -184,8 +184,7 @@ def cached_filename_list_(folder_name):
     if folder_name not in filename_list_cache:
         return None
     out = filename_list_cache[folder_name]
-    if time.perf_counter() < (out[2] + 0.5):
-        return out
+
     for x in out[1]:
         time_modified = out[1][x]
         folder = x
diff --git a/modules/advanced_parameters.py b/modules/advanced_parameters.py
index ea04db6c..0caa3eec 100644
--- a/modules/advanced_parameters.py
+++ b/modules/advanced_parameters.py
@@ -5,7 +5,8 @@ disable_preview, adm_scaler_positive, adm_scaler_negative, adm_scaler_end, adaptive_cfg, sampler_name, \
     debugging_cn_preprocessor, skipping_cn_preprocessor, controlnet_softness, canny_low_threshold, canny_high_threshold, \
     refiner_swap_method, \
     freeu_enabled, freeu_b1, freeu_b2, freeu_s1, freeu_s2, \
-    debugging_inpaint_preprocessor, inpaint_disable_initial_latent, inpaint_engine, inpaint_strength, inpaint_respective_field = [None] * 32
+    debugging_inpaint_preprocessor, inpaint_disable_initial_latent, inpaint_engine, inpaint_strength, inpaint_respective_field, \
+    inpaint_mask_upload_checkbox, invert_mask_checkbox, inpaint_erode_or_dilate = [None] * 35
 
 
 def set_all_advanced_parameters(*args):
@@ -16,7 +17,8 @@ def set_all_advanced_parameters(*args):
         debugging_cn_preprocessor, skipping_cn_preprocessor, controlnet_softness, canny_low_threshold, canny_high_threshold, \
         refiner_swap_method, \
         freeu_enabled, freeu_b1, freeu_b2, freeu_s1, freeu_s2, \
-        debugging_inpaint_preprocessor, inpaint_disable_initial_latent, inpaint_engine, inpaint_strength, inpaint_respective_field
+        debugging_inpaint_preprocessor, inpaint_disable_initial_latent, inpaint_engine, inpaint_strength, inpaint_respective_field, \
+        inpaint_mask_upload_checkbox, invert_mask_checkbox, inpaint_erode_or_dilate
 
     disable_preview, adm_scaler_positive, adm_scaler_negative, adm_scaler_end, adaptive_cfg, sampler_name, \
         scheduler_name, generate_image_grid, overwrite_step, overwrite_switch, overwrite_width, overwrite_height, \
@@ -25,6 +27,7 @@ def set_all_advanced_parameters(*args):
        debugging_cn_preprocessor, skipping_cn_preprocessor, controlnet_softness, canny_low_threshold, canny_high_threshold, \
         refiner_swap_method, \
         freeu_enabled, freeu_b1, freeu_b2, freeu_s1, freeu_s2, \
-        debugging_inpaint_preprocessor, inpaint_disable_initial_latent, inpaint_engine, inpaint_strength, inpaint_respective_field = args
+        debugging_inpaint_preprocessor, inpaint_disable_initial_latent, inpaint_engine, inpaint_strength, inpaint_respective_field, \
+        inpaint_mask_upload_checkbox, invert_mask_checkbox, inpaint_erode_or_dilate = args
 
     return
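advanced_parameters keeps every advanced option as a module-level global assigned from one flat *args tuple, so adding the three inpaint controls means bumping the placeholder count from 32 to 35 in all three listings above, in lockstep with the UI. The pattern in miniature:

```python
a, b, c = [None] * 3          # placeholders, one per UI control

def set_all(*args):
    global a, b, c
    a, b, c = args            # count and order must match the UI exactly

set_all(True, False, 8)
print(a, b, c)                # True False 8
```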
diff --git a/modules/async_worker.py b/modules/async_worker.py
index 421098d3..bf21efa5 100644
--- a/modules/async_worker.py
+++ b/modules/async_worker.py
@@ -42,7 +42,7 @@ def worker():
         from modules.private_logger import log
         from extras.expansion import safe_str
         from modules.util import remove_empty_str, HWC3, resize_image, \
-            get_image_shape_ceil, set_image_shape_ceil, get_shape_ceil, resample_image
+            get_image_shape_ceil, set_image_shape_ceil, get_shape_ceil, resample_image, erode_or_dilate
         from modules.upscaler import perform_upscale
 
     try:
@@ -142,6 +142,7 @@ def worker():
         outpaint_selections = args.pop()
         inpaint_input_image = args.pop()
         inpaint_additional_prompt = args.pop()
+        inpaint_mask_image_upload = args.pop()
 
         cn_tasks = {x: [] for x in flags.ip_list}
         for _ in range(4):
@@ -277,6 +278,22 @@ def worker():
                 and isinstance(inpaint_input_image, dict):
             inpaint_image = inpaint_input_image['image']
             inpaint_mask = inpaint_input_image['mask'][:, :, 0]
+
+            if advanced_parameters.inpaint_mask_upload_checkbox:
+                if isinstance(inpaint_mask_image_upload, np.ndarray):
+                    if inpaint_mask_image_upload.ndim == 3:
+                        H, W, C = inpaint_image.shape
+                        inpaint_mask_image_upload = resample_image(inpaint_mask_image_upload, width=W, height=H)
+                        inpaint_mask_image_upload = np.mean(inpaint_mask_image_upload, axis=2)
+                        inpaint_mask_image_upload = (inpaint_mask_image_upload > 127).astype(np.uint8) * 255
+                        inpaint_mask = np.maximum(inpaint_mask, inpaint_mask_image_upload)
+
+            if int(advanced_parameters.inpaint_erode_or_dilate) != 0:
+                inpaint_mask = erode_or_dilate(inpaint_mask, advanced_parameters.inpaint_erode_or_dilate)
+
+            if advanced_parameters.invert_mask_checkbox:
+                inpaint_mask = 255 - inpaint_mask
+
             inpaint_image = HWC3(inpaint_image)
             if isinstance(inpaint_image, np.ndarray) and isinstance(inpaint_mask, np.ndarray) \
                     and (np.any(inpaint_mask > 127) or len(outpaint_selections) > 0):
diff --git a/modules/config.py b/modules/config.py
index 9cc61788..c7af33db 100644
--- a/modules/config.py
+++ b/modules/config.py
@@ -243,10 +243,15 @@ default_advanced_checkbox = get_config_item_or_set_default(
     default_value=False,
     validator=lambda x: isinstance(x, bool)
 )
+default_max_image_number = get_config_item_or_set_default(
+    key='default_max_image_number',
+    default_value=32,
+    validator=lambda x: isinstance(x, int) and x >= 1
+)
 default_image_number = get_config_item_or_set_default(
     key='default_image_number',
     default_value=2,
-    validator=lambda x: isinstance(x, int) and 1 <= x <= 32
+    validator=lambda x: isinstance(x, int) and 1 <= x <= default_max_image_number
 )
 checkpoint_downloads = get_config_item_or_set_default(
     key='checkpoint_downloads',
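Condensed view of the new inpaint-mask path added to async_worker above, reusing the diff's own helpers (`k` and `invert` stand in for the advanced_parameters values):

```python
import numpy as np

# uploaded mask -> canvas-sized hard mask, merged with the painted one
m = resample_image(inpaint_mask_image_upload, width=W, height=H)   # W, H from the canvas image
m = (np.mean(m, axis=2) > 127).astype(np.uint8) * 255              # grayscale + hard threshold
inpaint_mask = np.maximum(inpaint_mask, m)                         # union with the painted mask
if int(k) != 0:
    inpaint_mask = erode_or_dilate(inpaint_mask, k)                # k > 0 grows, k < 0 shrinks
if invert:
    inpaint_mask = 255 - inpaint_mask
```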
diff --git a/modules/meta_parser.py b/modules/meta_parser.py
index 78d73978..07b42a16 100644
--- a/modules/meta_parser.py
+++ b/modules/meta_parser.py
@@ -3,7 +3,7 @@ import gradio as gr
 import modules.config
 
 
-def load_parameter_button_click(raw_prompt_txt):
+def load_parameter_button_click(raw_prompt_txt, is_generating):
     loaded_parameter_dict = json.loads(raw_prompt_txt)
     assert isinstance(loaded_parameter_dict, dict)
 
@@ -128,7 +128,11 @@ def load_parameter_button_click(raw_prompt_txt):
             results.append(gr.update())
             results.append(gr.update())
 
-    results.append(gr.update(visible=True))
+    if is_generating:
+        results.append(gr.update())
+    else:
+        results.append(gr.update(visible=True))
+
     results.append(gr.update(visible=False))
 
     for i in range(1, 6):
diff --git a/modules/ops.py b/modules/ops.py
new file mode 100644
index 00000000..ee0e7756
--- /dev/null
+++ b/modules/ops.py
@@ -0,0 +1,19 @@
+import torch
+import contextlib
+
+
+@contextlib.contextmanager
+def use_patched_ops(operations):
+    op_names = ['Linear', 'Conv2d', 'Conv3d', 'GroupNorm', 'LayerNorm']
+    backups = {op_name: getattr(torch.nn, op_name) for op_name in op_names}
+
+    try:
+        for op_name in op_names:
+            setattr(torch.nn, op_name, getattr(operations, op_name))
+
+        yield
+
+    finally:
+        for op_name in op_names:
+            setattr(torch.nn, op_name, backups[op_name])
+        return
diff --git a/modules/patch.py b/modules/patch.py
index 66b243cb..2e2409c5 100644
--- a/modules/patch.py
+++ b/modules/patch.py
@@ -218,7 +218,7 @@ def compute_cfg(uncond, cond, cfg_scale, t):
 def patched_sampling_function(model, x, timestep, uncond, cond, cond_scale, model_options=None, seed=None):
     global eps_record
 
-    if math.isclose(cond_scale, 1.0):
+    if math.isclose(cond_scale, 1.0) and not model_options.get("disable_cfg1_optimization", False):
         final_x0 = calc_cond_uncond_batch(model, cond, None, x, timestep, model_options)[0]
 
         if eps_record is not None:
@@ -480,6 +480,10 @@ def build_loaded(module, loader_name):
 
 
 def patch_all():
+    if ldm_patched.modules.model_management.directml_enabled:
+        ldm_patched.modules.model_management.lowvram_available = True
+        ldm_patched.modules.model_management.OOM_EXCEPTION = Exception
+
     patch_all_precision()
     patch_all_clip()
 
diff --git a/modules/patch_clip.py b/modules/patch_clip.py
index 74ee436a..06b7f01b 100644
--- a/modules/patch_clip.py
+++ b/modules/patch_clip.py
@@ -16,30 +16,12 @@ import ldm_patched.modules.samplers
 import ldm_patched.modules.sd
 import ldm_patched.modules.sd1_clip
 import ldm_patched.modules.clip_vision
-import ldm_patched.modules.model_management as model_management
 import ldm_patched.modules.ops as ops
-import contextlib
+from modules.ops import use_patched_ops
 
 from transformers import CLIPTextModel, CLIPTextConfig, modeling_utils, CLIPVisionConfig, CLIPVisionModelWithProjection
 
 
-@contextlib.contextmanager
-def use_patched_ops(operations):
-    op_names = ['Linear', 'Conv2d', 'Conv3d', 'GroupNorm', 'LayerNorm']
-    backups = {op_name: getattr(torch.nn, op_name) for op_name in op_names}
-
-    try:
-        for op_name in op_names:
-            setattr(torch.nn, op_name, getattr(operations, op_name))
-
-        yield
-
-    finally:
-        for op_name in op_names:
-            setattr(torch.nn, op_name, backups[op_name])
-        return
-
-
 def patched_encode_token_weights(self, token_weight_pairs):
     to_encode = list()
     max_token_len = 0
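use_patched_ops (now in its own modules/ops.py above, moved out of patch_clip.py) temporarily swaps the torch.nn layer classes while a model is constructed, then restores them — this is what load_ip_adapter relies on at the top of this diff. Usage sketch:

```python
import torch
from modules.ops import use_patched_ops
from ldm_patched.modules.ops import manual_cast

original_linear = torch.nn.Linear
with use_patched_ops(manual_cast):
    layer = torch.nn.Linear(4, 4)              # actually constructs manual_cast.Linear
assert isinstance(layer, manual_cast.Linear)
assert torch.nn.Linear is original_linear      # originals restored on exit
```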
diff --git a/modules/private_logger.py b/modules/private_logger.py
index 83ba9e36..968bd4f5 100644
--- a/modules/private_logger.py
+++ b/modules/private_logger.py
@@ -44,13 +44,28 @@ def log(img, dic):
     )
 
     js = (
-        ""
+        """"""
     )
 
     begin_part = f"Fooocus Log {date_string} (private)\nAll images are clean, without any hidden data/meta, and safe to share with others.\n\n"
diff --git a/modules/sample_hijack.py b/modules/sample_hijack.py
index 7d8f757b..5936a096 100644
--- a/modules/sample_hijack.py
+++ b/modules/sample_hijack.py
@@ -99,6 +99,13 @@ def sample_hacked(model, noise, positive, negative, cfg, device, sampler, sigmas, model_options={}, latent_image=None, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
     calculate_start_end_timesteps(model, negative)
     calculate_start_end_timesteps(model, positive)
 
+    if latent_image is not None:
+        latent_image = model.process_latent_in(latent_image)
+
+    if hasattr(model, 'extra_conds'):
+        positive = encode_model_conds(model.extra_conds, positive, noise, device, "positive", latent_image=latent_image, denoise_mask=denoise_mask)
+        negative = encode_model_conds(model.extra_conds, negative, noise, device, "negative", latent_image=latent_image, denoise_mask=denoise_mask)
+
     #make sure each cond area has an opposite one with the same area
     for c in positive:
         create_cond_with_same_area_if_none(negative, c)
@@ -111,13 +118,6 @@ def sample_hacked(model, noise, positive, negative, cfg, device, sampler, sigmas, model_options={}, latent_image=None, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
     apply_empty_x_to_equal_area(list(filter(lambda c: c.get('control_apply_to_uncond', False) == True, positive)), negative, 'control', lambda cond_cnets, x: cond_cnets[x])
     apply_empty_x_to_equal_area(positive, negative, 'gligen', lambda cond_cnets, x: cond_cnets[x])
 
-    if latent_image is not None:
-        latent_image = model.process_latent_in(latent_image)
-
-    if hasattr(model, 'extra_conds'):
-        positive = encode_model_conds(model.extra_conds, positive, noise, device, "positive", latent_image=latent_image, denoise_mask=denoise_mask)
-        negative = encode_model_conds(model.extra_conds, negative, noise, device, "negative", latent_image=latent_image, denoise_mask=denoise_mask)
-
     extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg, "model_options": model_options, "seed":seed}
 
     if current_refiner is not None and hasattr(current_refiner.model, 'extra_conds'):
@@ -174,7 +174,7 @@ def calculate_sigmas_scheduler_hacked(model, scheduler_name, steps):
     elif scheduler_name == "sgm_uniform":
         sigmas = normal_scheduler(model, steps, sgm=True)
     elif scheduler_name == "turbo":
-        sigmas = SDTurboScheduler().get_sigmas(namedtuple('Patcher', ['model'])(model=model), steps)[0]
+        sigmas = SDTurboScheduler().get_sigmas(namedtuple('Patcher', ['model'])(model=model), steps=steps, denoise=1.0)[0]
     else:
         raise TypeError("error invalid scheduler")
     return sigmas
diff --git a/modules/util.py b/modules/util.py
index fce7efd7..052b746b 100644
--- a/modules/util.py
+++ b/modules/util.py
@@ -3,6 +3,7 @@ import datetime
 import random
 import math
 import os
+import cv2
 
 from PIL import Image
 
@@ -10,6 +11,15 @@ from PIL import Image
 LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
 
 
+def erode_or_dilate(x, k):
+    k = int(k)
+    if k > 0:
+        return cv2.dilate(x, kernel=np.ones(shape=(3, 3), dtype=np.uint8), iterations=k)
+    if k < 0:
+        return cv2.erode(x, kernel=np.ones(shape=(3, 3), dtype=np.uint8), iterations=-k)
+    return x
+
+
 def resample_image(im, width, height):
     im = Image.fromarray(im)
     im = im.resize((int(width), int(height)), resample=LANCZOS)
diff --git a/readme.md b/readme.md
index 6b458e74..87c44b83 100644
--- a/readme.md
+++ b/readme.md
@@ -38,7 +38,7 @@ Using Fooocus is as easy as (probably easier than) Midjourney – but this does
 
 | Midjourney | Fooocus |
 | - | - |
-| High-quality text-to-image without needing much prompt engineering or parameter tuning.