diff --git a/fooocus_extras/ip_adapter.py b/fooocus_extras/ip_adapter.py new file mode 100644 index 00000000..3fdc56b3 --- /dev/null +++ b/fooocus_extras/ip_adapter.py @@ -0,0 +1,304 @@ +import torch +import comfy.clip_vision +import safetensors.torch as sf +import comfy.model_management as model_management +import contextlib + +from fooocus_extras.resampler import Resampler +from comfy.model_patcher import ModelPatcher + + +if model_management.xformers_enabled(): + import xformers + import xformers.ops + + +SD_V12_CHANNELS = [320] * 4 + [640] * 4 + [1280] * 4 + [1280] * 6 + [640] * 6 + [320] * 6 + [1280] * 2 +SD_XL_CHANNELS = [640] * 8 + [1280] * 40 + [1280] * 60 + [640] * 12 + [1280] * 20 + + +def sdp(q, k, v, extra_options): + if model_management.xformers_enabled(): + b, _, _ = q.shape + q, k, v = map( + lambda t: t.unsqueeze(3) + .reshape(b, t.shape[1], extra_options["n_heads"], extra_options["dim_head"]) + .permute(0, 2, 1, 3) + .reshape(b * extra_options["n_heads"], t.shape[1], extra_options["dim_head"]) + .contiguous(), + (q, k, v), + ) + out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=None) + out = ( + out.unsqueeze(0) + .reshape(b, extra_options["n_heads"], out.shape[1], extra_options["dim_head"]) + .permute(0, 2, 1, 3) + .reshape(b, out.shape[1], extra_options["n_heads"] * extra_options["dim_head"]) + ) + else: + b, _, _ = q.shape + q, k, v = map( + lambda t: t.view(b, -1, extra_options["n_heads"], extra_options["dim_head"]).transpose(1, 2), + (q, k, v), + ) + out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False) + out = out.transpose(1, 2).reshape(b, -1, extra_options["n_heads"] * extra_options["dim_head"]) + return out + + +class ImageProjModel(torch.nn.Module): + def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4): + super().__init__() + + self.cross_attention_dim = cross_attention_dim + self.clip_extra_context_tokens = clip_extra_context_tokens + self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim) + self.norm = torch.nn.LayerNorm(cross_attention_dim) + + def forward(self, image_embeds): + embeds = image_embeds + clip_extra_context_tokens = self.proj(embeds).reshape(-1, self.clip_extra_context_tokens, + self.cross_attention_dim) + clip_extra_context_tokens = self.norm(clip_extra_context_tokens) + return clip_extra_context_tokens + + +class To_KV(torch.nn.Module): + def __init__(self, cross_attention_dim): + super().__init__() + + channels = SD_XL_CHANNELS if cross_attention_dim == 2048 else SD_V12_CHANNELS + self.to_kvs = torch.nn.ModuleList( + [torch.nn.Linear(cross_attention_dim, channel, bias=False) for channel in channels]) + + def load_state_dict_ordered(self, sd): + state_dict = [] + for i in range(4096): + for k in ['k', 'v']: + key = f'{i}.to_{k}_ip.weight' + if key in sd: + state_dict.append(sd[key]) + for i, v in enumerate(state_dict): + self.to_kvs[i].weight = torch.nn.Parameter(v, requires_grad=False) + + +class IPAdapterModel(torch.nn.Module): + def __init__(self, state_dict, plus, cross_attention_dim=768, clip_embeddings_dim=1024, clip_extra_context_tokens=4, + sdxl_plus=False): + super().__init__() + self.plus = plus + if self.plus: + self.image_proj_model = Resampler( + dim=1280 if sdxl_plus else cross_attention_dim, + depth=4, + dim_head=64, + heads=20 if sdxl_plus else 12, + num_queries=clip_extra_context_tokens, + embedding_dim=clip_embeddings_dim, + 
output_dim=cross_attention_dim, + ff_mult=4 + ) + else: + self.image_proj_model = ImageProjModel( + cross_attention_dim=cross_attention_dim, + clip_embeddings_dim=clip_embeddings_dim, + clip_extra_context_tokens=clip_extra_context_tokens + ) + + self.image_proj_model.load_state_dict(state_dict["image_proj"]) + self.ip_layers = To_KV(cross_attention_dim) + self.ip_layers.load_state_dict_ordered(state_dict["ip_adapter"]) + + +clip_vision: comfy.clip_vision.ClipVisionModel = None +ip_negative: torch.Tensor = None +image_proj_model: ModelPatcher = None +ip_layers: ModelPatcher = None +ip_adapter: IPAdapterModel = None + + +def load_ip_adapter(clip_vision_path, ip_negative_path, ip_adapter_path): + global clip_vision, image_proj_model, ip_layers, ip_negative, ip_adapter + + if clip_vision_path is None: + return + if ip_negative_path is None: + return + if ip_adapter_path is None: + return + if clip_vision is not None and image_proj_model is not None and ip_layers is not None and ip_negative is not None: + return + + ip_negative = sf.load_file(ip_negative_path)['data'] + clip_vision = comfy.clip_vision.load(clip_vision_path) + + load_device = model_management.get_torch_device() + offload_device = torch.device('cpu') + + use_fp16 = model_management.should_use_fp16(device=load_device) + ip_state_dict = torch.load(ip_adapter_path, map_location="cpu") + plus = "latents" in ip_state_dict["image_proj"] + cross_attention_dim = ip_state_dict["ip_adapter"]["1.to_k_ip.weight"].shape[1] + sdxl = cross_attention_dim == 2048 + sdxl_plus = sdxl and plus + + if plus: + clip_extra_context_tokens = ip_state_dict["image_proj"]["latents"].shape[1] + clip_embeddings_dim = ip_state_dict["image_proj"]["latents"].shape[2] + else: + clip_extra_context_tokens = ip_state_dict["image_proj"]["proj.weight"].shape[0] // cross_attention_dim + clip_embeddings_dim = None + + ip_adapter = IPAdapterModel( + ip_state_dict, + plus=plus, + cross_attention_dim=cross_attention_dim, + clip_embeddings_dim=clip_embeddings_dim, + clip_extra_context_tokens=clip_extra_context_tokens, + sdxl_plus=sdxl_plus + ) + ip_adapter.sdxl = sdxl + ip_adapter.load_device = load_device + ip_adapter.offload_device = offload_device + ip_adapter.dtype = torch.float16 if use_fp16 else torch.float32 + ip_adapter.to(offload_device, dtype=ip_adapter.dtype) + + image_proj_model = ModelPatcher(model=ip_adapter.image_proj_model, load_device=load_device, + offload_device=offload_device) + ip_layers = ModelPatcher(model=ip_adapter.ip_layers, load_device=load_device, + offload_device=offload_device) + + return + + +@torch.no_grad() +@torch.inference_mode() +def preprocess(img): + inputs = clip_vision.processor(images=img, return_tensors="pt") + comfy.model_management.load_models_gpu([clip_vision.patcher, image_proj_model]) + pixel_values = inputs['pixel_values'].to(clip_vision.load_device) + + if clip_vision.dtype != torch.float32: + precision_scope = torch.autocast + else: + precision_scope = lambda a, b: contextlib.nullcontext(a) + + with precision_scope(comfy.model_management.get_autocast_device(clip_vision.load_device), torch.float32): + outputs = clip_vision.model(pixel_values=pixel_values, output_hidden_states=True) + + if ip_adapter.plus: + cond = outputs.hidden_states[-2].to(ip_adapter.dtype) + else: + cond = outputs.image_embeds.to(ip_adapter.dtype) + + outputs = image_proj_model.model(cond) + + return outputs + + +@torch.no_grad() +@torch.inference_mode() +def patch_model(model, ip_tasks): + new_model = model.clone() + + tasks = [] + for cn_img, cn_stop, 
cn_weight in ip_tasks: + tasks.append((cn_img, cn_stop, cn_weight, {})) + + def make_attn_patcher(ip_index): + ip_model_k = ip_layers.model.to_kvs[ip_index * 2] + ip_model_v = ip_layers.model.to_kvs[ip_index * 2 + 1] + + def patcher(n, context_attn2, value_attn2, extra_options): + org_dtype = n.dtype + current_step = float(model.model.diffusion_model.current_step.detach().cpu().numpy()[0]) + cond_or_uncond = extra_options['cond_or_uncond'] + + with torch.autocast("cuda", dtype=ip_adapter.dtype): + q = n + k = [context_attn2] + v = [value_attn2] + b, _, _ = q.shape + batch_prompt = b // len(cond_or_uncond) + + for cn_img, cn_stop, cn_weight, cache in tasks: + if current_step < cn_stop: + if ip_index in cache: + ip_k, ip_v = cache[ip_index] + else: + ip_model_k.to(device=ip_adapter.load_device, dtype=ip_adapter.dtype) + ip_model_v.to(device=ip_adapter.load_device, dtype=ip_adapter.dtype) + cond = cn_img.to(device=ip_adapter.load_device, dtype=ip_adapter.dtype).repeat(batch_prompt, 1, 1) + uncond = ip_negative.to(device=ip_adapter.load_device, dtype=ip_adapter.dtype).repeat(batch_prompt, 1, 1) + uncond_cond = torch.cat([(cond, uncond)[i] for i in cond_or_uncond], dim=0) + ip_k = ip_model_k(uncond_cond) + ip_v = ip_model_v(uncond_cond) + + # Midjourney's attention formulation of image prompt (non-official reimplementation) + # Written by Lvmin Zhang at Stanford University, 2023 Dec + # For non-commercial use only - if you use this in commercial project then + # probably it has some intellectual property issues. + # Contact lvminzhang@acm.org if you are not sure. + + # Below is the sensitive part with potential intellectual property issues. + + ip_v_mean = torch.mean(ip_v, dim=1, keepdim=True) + ip_v_offset = ip_v - ip_v_mean + + B, F, C = ip_k.shape + channel_penalty = float(C) / 1280.0 + weight = cn_weight * channel_penalty + + ip_k = ip_k * weight + ip_v = ip_v_offset + ip_v_mean * weight + + # The sensitive part ends here. 
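+                                # Summary of the step above (a restatement of the code, not a change): with
+                                # w = cn_weight * C / 1280, the image-prompt keys are scaled to ip_k * w,
+                                # while the values keep their per-token offset at full strength and only the
+                                # mean is rescaled: ip_v = (ip_v - mean) + mean * w.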
+ + cache[ip_index] = ip_k, ip_v + ip_model_k.to(device=ip_adapter.offload_device, dtype=ip_adapter.dtype) + ip_model_v.to(device=ip_adapter.offload_device, dtype=ip_adapter.dtype) + + k.append(ip_k) + v.append(ip_v) + + k = torch.cat(k, dim=1) + v = torch.cat(v, dim=1) + out = sdp(q, k, v, extra_options) + + return out.to(dtype=org_dtype) + return patcher + + def set_model_patch_replace(model, number, key): + to = model.model_options["transformer_options"] + if "patches_replace" not in to: + to["patches_replace"] = {} + if "attn2" not in to["patches_replace"]: + to["patches_replace"]["attn2"] = {} + if key not in to["patches_replace"]["attn2"]: + to["patches_replace"]["attn2"][key] = make_attn_patcher(number) + + number = 0 + if not ip_adapter.sdxl: + for id in [1, 2, 4, 5, 7, 8]: # id of input_blocks that have cross attention + set_model_patch_replace(new_model, number, ("input", id)) + number += 1 + for id in [3, 4, 5, 6, 7, 8, 9, 10, 11]: # id of output_blocks that have cross attention + set_model_patch_replace(new_model, number, ("output", id)) + number += 1 + set_model_patch_replace(new_model, number, ("middle", 0)) + else: + for id in [4, 5, 7, 8]: # id of input_blocks that have cross attention + block_indices = range(2) if id in [4, 5] else range(10) # transformer_depth + for index in block_indices: + set_model_patch_replace(new_model, number, ("input", id, index)) + number += 1 + for id in range(6): # id of output_blocks that have cross attention + block_indices = range(2) if id in [3, 4, 5] else range(10) # transformer_depth + for index in block_indices: + set_model_patch_replace(new_model, number, ("output", id, index)) + number += 1 + for index in range(10): + set_model_patch_replace(new_model, number, ("middle", 0, index)) + number += 1 + + return new_model diff --git a/fooocus_extras/preprocessors.py b/fooocus_extras/preprocessors.py new file mode 100644 index 00000000..ce31a7af --- /dev/null +++ b/fooocus_extras/preprocessors.py @@ -0,0 +1,42 @@ +import cv2 +import numpy as np + + +def canny_k(x, k=0.5): + import cv2 + H, W, C = x.shape + Hs, Ws = int(H * k), int(W * k) + small = cv2.resize(x, (Ws, Hs), interpolation=cv2.INTER_AREA) + return cv2.Canny(small, 100, 200).astype(np.float32) / 255.0 + + +def canny_pyramid(x): + # For some reasons, SAI's Control-lora Canny seems to be trained on canny maps with non-standard resolutions. + # Then we use pyramid to use all resolutions to avoid missing any structure in specific resolutions. 
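+    # Sketch of what follows: Canny is computed at scales 0.2 through 1.0; the accumulated map is
+    # resized up to each finer scale and blended as 0.75 * accumulated + 0.25 * current, so the
+    # finest edges carry the largest weight while coarser structure is still retained, and the
+    # result is scaled by 400 and clipped to uint8.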
+ + ks = [0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0] + cs = [canny_k(x, k) for k in ks] + cur = None + + for c in cs: + if cur is None: + cur = c + else: + H, W = c.shape + cur = cv2.resize(cur, (W, H), interpolation=cv2.INTER_LINEAR) + cur = cur * 0.75 + c * 0.25 + + cur *= 400.0 + + return cur.clip(0, 255).astype(np.uint8) + + +def cpds(x): + import cv2 + # cv2.decolor is not "decolor", it is Cewu Lu's method + # See http://www.cse.cuhk.edu.hk/leojia/projects/color2gray/index.html + # See https://docs.opencv.org/3.0-beta/modules/photo/doc/decolor.html + + y = np.ascontiguousarray(x[:, :, ::-1].copy()) + y = cv2.decolor(y)[0] + return y diff --git a/fooocus_extras/resampler.py b/fooocus_extras/resampler.py new file mode 100644 index 00000000..4521c8c3 --- /dev/null +++ b/fooocus_extras/resampler.py @@ -0,0 +1,121 @@ +# modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py +import math + +import torch +import torch.nn as nn + + +# FFN +def FeedForward(dim, mult=4): + inner_dim = int(dim * mult) + return nn.Sequential( + nn.LayerNorm(dim), + nn.Linear(dim, inner_dim, bias=False), + nn.GELU(), + nn.Linear(inner_dim, dim, bias=False), + ) + + +def reshape_tensor(x, heads): + bs, length, width = x.shape + #(bs, length, width) --> (bs, length, n_heads, dim_per_head) + x = x.view(bs, length, heads, -1) + # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head) + x = x.transpose(1, 2) + # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head) + x = x.reshape(bs, heads, length, -1) + return x + + +class PerceiverAttention(nn.Module): + def __init__(self, *, dim, dim_head=64, heads=8): + super().__init__() + self.scale = dim_head**-0.5 + self.dim_head = dim_head + self.heads = heads + inner_dim = dim_head * heads + + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + + self.to_q = nn.Linear(dim, inner_dim, bias=False) + self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False) + self.to_out = nn.Linear(inner_dim, dim, bias=False) + + + def forward(self, x, latents): + """ + Args: + x (torch.Tensor): image features + shape (b, n1, D) + latent (torch.Tensor): latent features + shape (b, n2, D) + """ + x = self.norm1(x) + latents = self.norm2(latents) + + b, l, _ = latents.shape + + q = self.to_q(latents) + kv_input = torch.cat((x, latents), dim=-2) + k, v = self.to_kv(kv_input).chunk(2, dim=-1) + + q = reshape_tensor(q, self.heads) + k = reshape_tensor(k, self.heads) + v = reshape_tensor(v, self.heads) + + # attention + scale = 1 / math.sqrt(math.sqrt(self.dim_head)) + weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards + weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) + out = weight @ v + + out = out.permute(0, 2, 1, 3).reshape(b, l, -1) + + return self.to_out(out) + + +class Resampler(nn.Module): + def __init__( + self, + dim=1024, + depth=8, + dim_head=64, + heads=16, + num_queries=8, + embedding_dim=768, + output_dim=1024, + ff_mult=4, + ): + super().__init__() + + self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5) + + self.proj_in = nn.Linear(embedding_dim, dim) + + self.proj_out = nn.Linear(dim, output_dim) + self.norm_out = nn.LayerNorm(output_dim) + + self.layers = nn.ModuleList([]) + for _ in range(depth): + self.layers.append( + nn.ModuleList( + [ + PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads), + FeedForward(dim=dim, mult=ff_mult), + ] + ) + ) + + def forward(self, x): + + latents = 
self.latents.repeat(x.size(0), 1, 1) + + x = self.proj_in(x) + + for attn, ff in self.layers: + latents = attn(x, latents) + latents + latents = ff(latents) + latents + + latents = self.proj_out(latents) + return self.norm_out(latents) \ No newline at end of file diff --git a/fooocus_version.py b/fooocus_version.py index 32594f9a..dc58ef8a 100644 --- a/fooocus_version.py +++ b/fooocus_version.py @@ -1 +1 @@ -version = '2.0.93' +version = '2.1.0' diff --git a/launch.py b/launch.py index 8fd2cd26..39a03620 100644 --- a/launch.py +++ b/launch.py @@ -22,7 +22,7 @@ def prepare_environment(): xformers_package = os.environ.get('XFORMERS_PACKAGE', 'xformers==0.0.20') comfy_repo = os.environ.get('COMFY_REPO', "https://github.com/comfyanonymous/ComfyUI") - comfy_commit_hash = os.environ.get('COMFY_COMMIT_HASH', "9bfec2bdbf0b0d778087a9b32f79e57e2d15b913") + comfy_commit_hash = os.environ.get('COMFY_COMMIT_HASH', "1c5d6663faf1a33e00ec67240167b174a9cac655") print(f"Python {sys.version}") print(f"Fooocus version: {fooocus_version.version}") diff --git a/modules/advanced_parameters.py b/modules/advanced_parameters.py new file mode 100644 index 00000000..46884eac --- /dev/null +++ b/modules/advanced_parameters.py @@ -0,0 +1,21 @@ +adm_scaler_positive, adm_scaler_negative, adm_scaler_end, adaptive_cfg, sampler_name, \ + scheduler_name, overwrite_step, overwrite_switch, overwrite_width, overwrite_height, \ + overwrite_vary_strength, overwrite_upscale_strength, \ + mixing_image_prompt_and_vary_upscale, mixing_image_prompt_and_inpaint, \ + debugging_cn_preprocessor, disable_soft_cn = [None] * 16 + + +def set_all_advanced_parameters(*args): + global adm_scaler_positive, adm_scaler_negative, adm_scaler_end, adaptive_cfg, sampler_name, \ + scheduler_name, overwrite_step, overwrite_switch, overwrite_width, overwrite_height, \ + overwrite_vary_strength, overwrite_upscale_strength, \ + mixing_image_prompt_and_vary_upscale, mixing_image_prompt_and_inpaint, \ + debugging_cn_preprocessor, disable_soft_cn + + adm_scaler_positive, adm_scaler_negative, adm_scaler_end, adaptive_cfg, sampler_name, \ + scheduler_name, overwrite_step, overwrite_switch, overwrite_width, overwrite_height, \ + overwrite_vary_strength, overwrite_upscale_strength, \ + mixing_image_prompt_and_vary_upscale, mixing_image_prompt_and_inpaint, \ + debugging_cn_preprocessor, disable_soft_cn = args + + return diff --git a/modules/async_worker.py b/modules/async_worker.py index 45a906a5..f99c376b 100644 --- a/modules/async_worker.py +++ b/modules/async_worker.py @@ -21,7 +21,10 @@ def worker(): import modules.path import modules.patch import comfy.model_management + import fooocus_extras.preprocessors as preprocessors import modules.inpaint_worker as inpaint_worker + import modules.advanced_parameters as advanced_parameters + import fooocus_extras.ip_adapter as ip_adapter from modules.sdxl_styles import apply_style, aspect_ratios, fooocus_expansion from modules.private_logger import log @@ -44,24 +47,42 @@ def worker(): @torch.no_grad() @torch.inference_mode() - def handler(task): + def handler(args): execution_start_time = time.perf_counter() - prompt, negative_prompt, style_selections, performance_selection, \ - aspect_ratios_selection, image_number, image_seed, sharpness, adm_scaler_positive, adm_scaler_negative, adm_scaler_end, guidance_scale, adaptive_cfg, sampler_name, scheduler_name, \ - overwrite_step, overwrite_switch, overwrite_width, overwrite_height, overwrite_vary_strength, overwrite_upscale_strength, \ - base_model_name, refiner_model_name, 
\ - l1, w1, l2, w2, l3, w3, l4, w4, l5, w5, \ - input_image_checkbox, current_tab, \ - uov_method, uov_input_image, outpaint_selections, inpaint_input_image = task + args.reverse() + + prompt = args.pop() + negative_prompt = args.pop() + style_selections = args.pop() + performance_selection = args.pop() + aspect_ratios_selection = args.pop() + image_number = args.pop() + image_seed = args.pop() + sharpness = args.pop() + guidance_scale = args.pop() + base_model_name = args.pop() + refiner_model_name = args.pop() + loras = [(args.pop(), args.pop()) for _ in range(5)] + input_image_checkbox = args.pop() + current_tab = args.pop() + uov_method = args.pop() + uov_input_image = args.pop() + outpaint_selections = args.pop() + inpaint_input_image = args.pop() + + cn_tasks = {flags.cn_ip: [], flags.cn_canny: [], flags.cn_cpds: []} + for _ in range(4): + cn_img = args.pop() + cn_stop = args.pop() + cn_weight = args.pop() + cn_type = args.pop() + if cn_img is not None: + cn_tasks[cn_type].append([cn_img, cn_stop, cn_weight]) outpaint_selections = [o.lower() for o in outpaint_selections] - - loras = [(l1, w1), (l2, w2), (l3, w3), (l4, w4), (l5, w5)] - loras_user_raw_input = copy.deepcopy(loras) - + loras_raw = copy.deepcopy(loras) raw_style_selections = copy.deepcopy(style_selections) - uov_method = uov_method.lower() if fooocus_expansion in style_selections: @@ -72,15 +93,15 @@ def worker(): use_style = len(style_selections) > 0 - modules.patch.adaptive_cfg = adaptive_cfg + modules.patch.adaptive_cfg = advanced_parameters.adaptive_cfg print(f'[Parameters] Adaptive CFG = {modules.patch.adaptive_cfg}') modules.patch.sharpness = sharpness print(f'[Parameters] Sharpness = {modules.patch.sharpness}') - modules.patch.positive_adm_scale = adm_scaler_positive - modules.patch.negative_adm_scale = adm_scaler_negative - modules.patch.adm_scaler_end = adm_scaler_end + modules.patch.positive_adm_scale = advanced_parameters.adm_scaler_positive + modules.patch.negative_adm_scale = advanced_parameters.adm_scaler_negative + modules.patch.adm_scaler_end = advanced_parameters.adm_scaler_end print(f'[Parameters] ADM Scale = {modules.patch.positive_adm_scale} : {modules.patch.negative_adm_scale} : {modules.patch.adm_scaler_end}') cfg_scale = float(guidance_scale) @@ -90,197 +111,18 @@ def worker(): denoising_strength = 1.0 tiled = False inpaint_worker.current_task = None - - if performance_selection == 'Speed': - steps = 30 - switch = 20 - else: - steps = 60 - switch = 40 - - if overwrite_step > 0: - steps = overwrite_step - - if overwrite_switch > 0: - switch = overwrite_switch - - pipeline.clear_all_caches() # save memory - width, height = aspect_ratios[aspect_ratios_selection] - - if overwrite_width > 0: - width = overwrite_width - - if overwrite_height > 0: - height = overwrite_height - - if input_image_checkbox: - progressbar(0, 'Image processing ...') - if current_tab == 'uov' and uov_method != flags.disabled and uov_input_image is not None: - uov_input_image = HWC3(uov_input_image) - if 'vary' in uov_method: - if not image_is_generated_in_current_ui(uov_input_image, ui_width=width, ui_height=height): - uov_input_image = resize_image(uov_input_image, width=width, height=height) - print(f'Resolution corrected - users are uploading their own images.') - else: - print(f'Processing images generated by Fooocus.') - if 'subtle' in uov_method: - denoising_strength = 0.5 - if 'strong' in uov_method: - denoising_strength = 0.85 - if overwrite_vary_strength > 0: - denoising_strength = overwrite_vary_strength - 
initial_pixels = core.numpy_to_pytorch(uov_input_image) - progressbar(0, 'VAE encoding ...') - initial_latent = core.encode_vae(vae=pipeline.xl_base_patched.vae, pixels=initial_pixels) - B, C, H, W = initial_latent['samples'].shape - width = W * 8 - height = H * 8 - print(f'Final resolution is {str((height, width))}.') - elif 'upscale' in uov_method: - H, W, C = uov_input_image.shape - progressbar(0, f'Upscaling image from {str((H, W))} ...') - - uov_input_image = core.numpy_to_pytorch(uov_input_image) - uov_input_image = perform_upscale(uov_input_image) - uov_input_image = core.pytorch_to_numpy(uov_input_image)[0] - print(f'Image upscaled.') - - if '1.5x' in uov_method: - f = 1.5 - elif '2x' in uov_method: - f = 2.0 - else: - f = 1.0 - - width_f = int(width * f) - height_f = int(height * f) - - if image_is_generated_in_current_ui(uov_input_image, ui_width=width_f, ui_height=height_f): - uov_input_image = resize_image(uov_input_image, width=int(W * f), height=int(H * f)) - print(f'Processing images generated by Fooocus.') - else: - uov_input_image = resize_image(uov_input_image, width=width_f, height=height_f) - print(f'Resolution corrected - users are uploading their own images.') - - H, W, C = uov_input_image.shape - image_is_super_large = H * W > 2800 * 2800 - - if 'fast' in uov_method: - direct_return = True - elif image_is_super_large: - print('Image is too large. Directly returned the SR image. ' - 'Usually directly return SR image at 4K resolution ' - 'yields better results than SDXL diffusion.') - direct_return = True - else: - direct_return = False - - if direct_return: - d = [('Upscale (Fast)', '2x')] - log(uov_input_image, d, single_line_number=1) - outputs.append(['results', [uov_input_image]]) - return - - tiled = True - denoising_strength = 1.0 - 0.618 - steps = int(steps * 0.618) - switch = int(steps * 0.67) - - if overwrite_upscale_strength > 0: - denoising_strength = overwrite_upscale_strength - if overwrite_step > 0: - steps = overwrite_step - if overwrite_switch > 0: - switch = overwrite_switch - - initial_pixels = core.numpy_to_pytorch(uov_input_image) - progressbar(0, 'VAE encoding ...') - - initial_latent = core.encode_vae(vae=pipeline.xl_base_patched.vae, pixels=initial_pixels, tiled=True) - B, C, H, W = initial_latent['samples'].shape - width = W * 8 - height = H * 8 - print(f'Final resolution is {str((height, width))}.') - if current_tab == 'inpaint' and isinstance(inpaint_input_image, dict): - inpaint_image = inpaint_input_image['image'] - inpaint_mask = inpaint_input_image['mask'][:, :, 0] - if isinstance(inpaint_image, np.ndarray) and isinstance(inpaint_mask, np.ndarray) \ - and (np.any(inpaint_mask > 127) or len(outpaint_selections) > 0): - if len(outpaint_selections) > 0: - H, W, C = inpaint_image.shape - if 'top' in outpaint_selections: - inpaint_image = np.pad(inpaint_image, [[int(H * 0.3), 0], [0, 0], [0, 0]], mode='edge') - inpaint_mask = np.pad(inpaint_mask, [[int(H * 0.3), 0], [0, 0]], mode='constant', constant_values=255) - if 'bottom' in outpaint_selections: - inpaint_image = np.pad(inpaint_image, [[0, int(H * 0.3)], [0, 0], [0, 0]], mode='edge') - inpaint_mask = np.pad(inpaint_mask, [[0, int(H * 0.3)], [0, 0]], mode='constant', constant_values=255) - - H, W, C = inpaint_image.shape - if 'left' in outpaint_selections: - inpaint_image = np.pad(inpaint_image, [[0, 0], [int(H * 0.3), 0], [0, 0]], mode='edge') - inpaint_mask = np.pad(inpaint_mask, [[0, 0], [int(H * 0.3), 0]], mode='constant', constant_values=255) - if 'right' in outpaint_selections: - 
inpaint_image = np.pad(inpaint_image, [[0, 0], [0, int(H * 0.3)], [0, 0]], mode='edge') - inpaint_mask = np.pad(inpaint_mask, [[0, 0], [0, int(H * 0.3)]], mode='constant', constant_values=255) - - inpaint_image = np.ascontiguousarray(inpaint_image.copy()) - inpaint_mask = np.ascontiguousarray(inpaint_mask.copy()) - - inpaint_worker.current_task = inpaint_worker.InpaintWorker(image=inpaint_image, mask=inpaint_mask, - is_outpaint=len(outpaint_selections) > 0) - - # print(f'Inpaint task: {str((height, width))}') - # outputs.append(['results', inpaint_worker.current_task.visualize_mask_processing()]) - # return - - progressbar(0, 'Downloading inpainter ...') - inpaint_head_model_path, inpaint_patch_model_path = modules.path.downloading_inpaint_models() - loras += [(inpaint_patch_model_path, 1.0)] - - inpaint_pixels = core.numpy_to_pytorch(inpaint_worker.current_task.image_ready) - progressbar(0, 'VAE encoding ...') - initial_latent = core.encode_vae(vae=pipeline.xl_base_patched.vae, pixels=inpaint_pixels) - inpaint_latent = initial_latent['samples'] - B, C, H, W = inpaint_latent.shape - inpaint_mask = core.numpy_to_pytorch(inpaint_worker.current_task.mask_ready[None]) - inpaint_mask = torch.nn.functional.avg_pool2d(inpaint_mask, (8, 8)) - inpaint_mask = torch.nn.functional.interpolate(inpaint_mask, (H, W), mode='bilinear') - inpaint_worker.current_task.load_latent(latent=inpaint_latent, mask=inpaint_mask) - - progressbar(0, 'VAE inpaint encoding ...') - - inpaint_mask = (inpaint_worker.current_task.mask_ready > 0).astype(np.float32) - inpaint_mask = torch.tensor(inpaint_mask).float() - - vae_dict = core.encode_vae_inpaint( - mask=inpaint_mask, vae=pipeline.xl_base_patched.vae, pixels=inpaint_pixels) - - inpaint_latent = vae_dict['samples'] - inpaint_mask = vae_dict['noise_mask'] - inpaint_worker.current_task.load_inpaint_guidance(latent=inpaint_latent, mask=inpaint_mask, model_path=inpaint_head_model_path) - - B, C, H, W = inpaint_latent.shape - height, width = inpaint_worker.current_task.image_raw.shape[:2] - print(f'Final resolution is {str((height, width))}, latent is {str((H * 8, W * 8))}.') - - sampler_name = 'dpmpp_fooocus_2m_sde_inpaint_seamless' - - print(f'[Parameters] Sampler = {sampler_name} - {scheduler_name}') - - progressbar(1, 'Initializing ...') + skip_prompt_processing = False raw_prompt = prompt raw_negative_prompt = negative_prompt - prompts = remove_empty_str([safe_str(p) for p in prompt.split('\n')], default='') - negative_prompts = remove_empty_str([safe_str(p) for p in negative_prompt.split('\n')], default='') - - prompt = prompts[0] - negative_prompt = negative_prompts[0] - - extra_positive_prompts = prompts[1:] if len(prompts) > 1 else [] - extra_negative_prompts = negative_prompts[1:] if len(negative_prompts) > 1 else [] + inpaint_image = None + inpaint_mask = None + inpaint_head_model_path = None + controlnet_canny_path = None + controlnet_cpds_path = None + clip_vision_path, ip_negative_path, ip_adapter_path = None, None, None seed = image_seed max_seed = int(1024 * 1024 * 1024) @@ -290,77 +132,329 @@ def worker(): seed = - seed seed = seed % max_seed - progressbar(3, 'Loading models ...') - - pipeline.refresh_everything( - refiner_model_name=refiner_model_name, - base_model_name=base_model_name, - loras=loras) - pipeline.prepare_text_encoder(async_call=False) - - progressbar(3, 'Processing prompts ...') - - positive_basic_workloads = [] - negative_basic_workloads = [] - - if use_style: - for s in style_selections: - p, n = apply_style(s, positive=prompt) - 
positive_basic_workloads.append(p) - negative_basic_workloads.append(n) + if performance_selection == 'Speed': + steps = 30 + switch = 20 else: - positive_basic_workloads.append(prompt) + steps = 60 + switch = 40 - negative_basic_workloads.append(negative_prompt) # Always use independent workload for negative. + sampler_name = advanced_parameters.sampler_name + scheduler_name = advanced_parameters.scheduler_name - positive_basic_workloads = positive_basic_workloads + extra_positive_prompts - negative_basic_workloads = negative_basic_workloads + extra_negative_prompts + goals = [] + tasks = [] - positive_basic_workloads = remove_empty_str(positive_basic_workloads, default=prompt) - negative_basic_workloads = remove_empty_str(negative_basic_workloads, default=negative_prompt) + if input_image_checkbox: + progressbar(13, 'Image processing ...') + if (current_tab == 'uov' or (current_tab == 'ip' and advanced_parameters.mixing_image_prompt_and_vary_upscale)) \ + and uov_method != flags.disabled and uov_input_image is not None: + uov_input_image = HWC3(uov_input_image) + if 'vary' in uov_method: + goals.append('vary') + elif 'upscale' in uov_method: + goals.append('upscale') + if 'fast' in uov_method: + skip_prompt_processing = True + else: + if performance_selection == 'Speed': + steps = 18 + switch = 12 + else: + steps = 36 + switch = 24 + if (current_tab == 'inpaint' or (current_tab == 'ip' and advanced_parameters.mixing_image_prompt_and_inpaint))\ + and isinstance(inpaint_input_image, dict): + inpaint_image = inpaint_input_image['image'] + inpaint_mask = inpaint_input_image['mask'][:, :, 0] + inpaint_image = HWC3(inpaint_image) + if isinstance(inpaint_image, np.ndarray) and isinstance(inpaint_mask, np.ndarray) \ + and (np.any(inpaint_mask > 127) or len(outpaint_selections) > 0): + progressbar(1, 'Downloading inpainter ...') + inpaint_head_model_path, inpaint_patch_model_path = modules.path.downloading_inpaint_models() + loras += [(inpaint_patch_model_path, 1.0)] + goals.append('inpaint') + sampler_name = 'dpmpp_fooocus_2m_sde_inpaint_seamless' + if current_tab == 'ip' or \ + advanced_parameters.mixing_image_prompt_and_inpaint or \ + advanced_parameters.mixing_image_prompt_and_vary_upscale: + goals.append('cn') + progressbar(1, 'Downloading control models ...') + if len(cn_tasks[flags.cn_canny]) > 0: + controlnet_canny_path = modules.path.downloading_controlnet_canny() + if len(cn_tasks[flags.cn_cpds]) > 0: + controlnet_cpds_path = modules.path.downloading_controlnet_cpds() + if len(cn_tasks[flags.cn_ip]) > 0: + clip_vision_path, ip_negative_path, ip_adapter_path = modules.path.downloading_ip_adapters() + progressbar(1, 'Loading control models ...') - positive_top_k = len(positive_basic_workloads) - negative_top_k = len(negative_basic_workloads) + # Load or unload CNs + pipeline.refresh_controlnets([controlnet_canny_path, controlnet_cpds_path]) + ip_adapter.load_ip_adapter(clip_vision_path, ip_negative_path, ip_adapter_path) - tasks = [dict( - task_seed=seed + i, - positive=positive_basic_workloads, - negative=negative_basic_workloads, - expansion='', - c=[None, None], - uc=[None, None], - ) for i in range(image_number)] + if advanced_parameters.overwrite_step > 0: + steps = advanced_parameters.overwrite_step - if use_expansion: - for i, t in enumerate(tasks): - progressbar(5, f'Preparing Fooocus text #{i + 1} ...') - expansion = pipeline.expansion(prompt, t['task_seed']) - print(f'[Prompt Expansion] New suffix: {expansion}') - t['expansion'] = expansion - t['positive'] = 
copy.deepcopy(t['positive']) + [join_prompts(prompt, expansion)] # Deep copy. + if advanced_parameters.overwrite_switch > 0: + switch = advanced_parameters.overwrite_switch - for i, t in enumerate(tasks): - progressbar(7, f'Encoding base positive #{i + 1} ...') - t['c'][0] = pipeline.clip_encode(sd=pipeline.xl_base_patched, texts=t['positive'], - pool_top_k=positive_top_k) + if advanced_parameters.overwrite_width > 0: + width = advanced_parameters.overwrite_width - for i, t in enumerate(tasks): - progressbar(9, f'Encoding base negative #{i + 1} ...') - t['uc'][0] = pipeline.clip_encode(sd=pipeline.xl_base_patched, texts=t['negative'], - pool_top_k=negative_top_k) + if advanced_parameters.overwrite_height > 0: + height = advanced_parameters.overwrite_height - if pipeline.xl_refiner is not None: - for i, t in enumerate(tasks): - progressbar(11, f'Encoding refiner positive #{i + 1} ...') - t['c'][1] = pipeline.clip_separate(t['c'][0]) + print(f'[Parameters] Sampler = {sampler_name} - {scheduler_name}') + print(f'[Parameters] Steps = {steps} - {switch}') + + progressbar(1, 'Initializing ...') + + if not skip_prompt_processing: + + prompts = remove_empty_str([safe_str(p) for p in prompt.split('\n')], default='') + negative_prompts = remove_empty_str([safe_str(p) for p in negative_prompt.split('\n')], default='') + + prompt = prompts[0] + negative_prompt = negative_prompts[0] + + extra_positive_prompts = prompts[1:] if len(prompts) > 1 else [] + extra_negative_prompts = negative_prompts[1:] if len(negative_prompts) > 1 else [] + + progressbar(3, 'Loading models ...') + pipeline.refresh_everything(refiner_model_name=refiner_model_name, base_model_name=base_model_name, loras=loras) + + progressbar(3, 'Processing prompts ...') + positive_basic_workloads = [] + negative_basic_workloads = [] + + if use_style: + for s in style_selections: + p, n = apply_style(s, positive=prompt) + positive_basic_workloads.append(p) + negative_basic_workloads.append(n) + else: + positive_basic_workloads.append(prompt) + + negative_basic_workloads.append(negative_prompt) # Always use independent workload for negative. + + positive_basic_workloads = positive_basic_workloads + extra_positive_prompts + negative_basic_workloads = negative_basic_workloads + extra_negative_prompts + + positive_basic_workloads = remove_empty_str(positive_basic_workloads, default=prompt) + negative_basic_workloads = remove_empty_str(negative_basic_workloads, default=negative_prompt) + + positive_top_k = len(positive_basic_workloads) + negative_top_k = len(negative_basic_workloads) + + tasks = [dict( + task_seed=seed + i, + positive=positive_basic_workloads, + negative=negative_basic_workloads, + expansion='', + c=None, + uc=None, + ) for i in range(image_number)] + + if use_expansion: + for i, t in enumerate(tasks): + progressbar(5, f'Preparing Fooocus text #{i + 1} ...') + expansion = pipeline.final_expansion(prompt, t['task_seed']) + print(f'[Prompt Expansion] New suffix: {expansion}') + t['expansion'] = expansion + t['positive'] = copy.deepcopy(t['positive']) + [join_prompts(prompt, expansion)] # Deep copy. 
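+            # Note: each task keeps its own deep copy of the positive prompt list above, so the
+            # per-seed expansion suffix presumably only affects that task and the shared base
+            # prompts used by the other tasks stay untouched.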
for i, t in enumerate(tasks): - progressbar(13, f'Encoding refiner negative #{i + 1} ...') - t['uc'][1] = pipeline.clip_separate(t['uc'][0]) + progressbar(7, f'Encoding positive #{i + 1} ...') + t['c'] = pipeline.clip_encode(texts=t['positive'], pool_top_k=positive_top_k) + + for i, t in enumerate(tasks): + progressbar(10, f'Encoding negative #{i + 1} ...') + t['uc'] = pipeline.clip_encode(texts=t['negative'], pool_top_k=negative_top_k) + + if len(goals) > 0: + progressbar(13, 'Image processing ...') + + if 'vary' in goals: + if not image_is_generated_in_current_ui(uov_input_image, ui_width=width, ui_height=height): + uov_input_image = resize_image(uov_input_image, width=width, height=height) + print(f'Resolution corrected - users are uploading their own images.') + else: + print(f'Processing images generated by Fooocus.') + if 'subtle' in uov_method: + denoising_strength = 0.5 + if 'strong' in uov_method: + denoising_strength = 0.85 + if advanced_parameters.overwrite_vary_strength > 0: + denoising_strength = advanced_parameters.overwrite_vary_strength + initial_pixels = core.numpy_to_pytorch(uov_input_image) + progressbar(13, 'VAE encoding ...') + initial_latent = core.encode_vae(vae=pipeline.final_vae, pixels=initial_pixels) + B, C, H, W = initial_latent['samples'].shape + width = W * 8 + height = H * 8 + print(f'Final resolution is {str((height, width))}.') + + if 'upscale' in goals: + H, W, C = uov_input_image.shape + progressbar(13, f'Upscaling image from {str((H, W))} ...') + + uov_input_image = core.numpy_to_pytorch(uov_input_image) + uov_input_image = perform_upscale(uov_input_image) + uov_input_image = core.pytorch_to_numpy(uov_input_image)[0] + print(f'Image upscaled.') + + if '1.5x' in uov_method: + f = 1.5 + elif '2x' in uov_method: + f = 2.0 + else: + f = 1.0 + + width_f = int(width * f) + height_f = int(height * f) + + if image_is_generated_in_current_ui(uov_input_image, ui_width=width_f, ui_height=height_f): + uov_input_image = resize_image(uov_input_image, width=int(W * f), height=int(H * f)) + print(f'Processing images generated by Fooocus.') + else: + uov_input_image = resize_image(uov_input_image, width=width_f, height=height_f) + print(f'Resolution corrected - users are uploading their own images.') + + H, W, C = uov_input_image.shape + image_is_super_large = H * W > 2800 * 2800 + + if 'fast' in uov_method: + direct_return = True + elif image_is_super_large: + print('Image is too large. Directly returned the SR image. 
' + 'Usually directly return SR image at 4K resolution ' + 'yields better results than SDXL diffusion.') + direct_return = True + else: + direct_return = False + + if direct_return: + d = [('Upscale (Fast)', '2x')] + log(uov_input_image, d, single_line_number=1) + outputs.append(['results', [uov_input_image]]) + return + + tiled = True + denoising_strength = 0.382 + + if advanced_parameters.overwrite_upscale_strength > 0: + denoising_strength = advanced_parameters.overwrite_upscale_strength + + initial_pixels = core.numpy_to_pytorch(uov_input_image) + progressbar(13, 'VAE encoding ...') + + initial_latent = core.encode_vae(vae=pipeline.final_vae, pixels=initial_pixels, tiled=True) + B, C, H, W = initial_latent['samples'].shape + width = W * 8 + height = H * 8 + print(f'Final resolution is {str((height, width))}.') + + if 'inpaint' in goals: + if len(outpaint_selections) > 0: + H, W, C = inpaint_image.shape + if 'top' in outpaint_selections: + inpaint_image = np.pad(inpaint_image, [[int(H * 0.3), 0], [0, 0], [0, 0]], mode='edge') + inpaint_mask = np.pad(inpaint_mask, [[int(H * 0.3), 0], [0, 0]], mode='constant', + constant_values=255) + if 'bottom' in outpaint_selections: + inpaint_image = np.pad(inpaint_image, [[0, int(H * 0.3)], [0, 0], [0, 0]], mode='edge') + inpaint_mask = np.pad(inpaint_mask, [[0, int(H * 0.3)], [0, 0]], mode='constant', + constant_values=255) + + H, W, C = inpaint_image.shape + if 'left' in outpaint_selections: + inpaint_image = np.pad(inpaint_image, [[0, 0], [int(H * 0.3), 0], [0, 0]], mode='edge') + inpaint_mask = np.pad(inpaint_mask, [[0, 0], [int(H * 0.3), 0]], mode='constant', + constant_values=255) + if 'right' in outpaint_selections: + inpaint_image = np.pad(inpaint_image, [[0, 0], [0, int(H * 0.3)], [0, 0]], mode='edge') + inpaint_mask = np.pad(inpaint_mask, [[0, 0], [0, int(H * 0.3)]], mode='constant', + constant_values=255) + + inpaint_image = np.ascontiguousarray(inpaint_image.copy()) + inpaint_mask = np.ascontiguousarray(inpaint_mask.copy()) + + inpaint_worker.current_task = inpaint_worker.InpaintWorker(image=inpaint_image, mask=inpaint_mask, + is_outpaint=len(outpaint_selections) > 0) + + # print(f'Inpaint task: {str((height, width))}') + # outputs.append(['results', inpaint_worker.current_task.visualize_mask_processing()]) + # return + + progressbar(13, 'VAE encoding ...') + inpaint_pixels = core.numpy_to_pytorch(inpaint_worker.current_task.image_ready) + initial_latent = core.encode_vae(vae=pipeline.final_vae, pixels=inpaint_pixels) + inpaint_latent = initial_latent['samples'] + B, C, H, W = inpaint_latent.shape + inpaint_mask = core.numpy_to_pytorch(inpaint_worker.current_task.mask_ready[None]) + inpaint_mask = torch.nn.functional.avg_pool2d(inpaint_mask, (8, 8)) + inpaint_mask = torch.nn.functional.interpolate(inpaint_mask, (H, W), mode='bilinear') + inpaint_worker.current_task.load_latent(latent=inpaint_latent, mask=inpaint_mask) + + progressbar(13, 'VAE inpaint encoding ...') + + inpaint_mask = (inpaint_worker.current_task.mask_ready > 0).astype(np.float32) + inpaint_mask = torch.tensor(inpaint_mask).float() + + vae_dict = core.encode_vae_inpaint( + mask=inpaint_mask, vae=pipeline.final_vae, pixels=inpaint_pixels) + + inpaint_latent = vae_dict['samples'] + inpaint_mask = vae_dict['noise_mask'] + inpaint_worker.current_task.load_inpaint_guidance(latent=inpaint_latent, mask=inpaint_mask, + model_path=inpaint_head_model_path) + + B, C, H, W = inpaint_latent.shape + final_height, final_width = inpaint_worker.current_task.image_raw.shape[:2] + height, 
width = H * 8, W * 8 + print(f'Final resolution is {str((final_height, final_width))}, latent is {str((height, width))}.') + + if 'cn' in goals: + for task in cn_tasks[flags.cn_canny]: + cn_img, cn_stop, cn_weight = task + cn_img = resize_image(HWC3(cn_img), width=width, height=height) + cn_img = preprocessors.canny_pyramid(cn_img) + cn_img = HWC3(cn_img) + task[0] = core.numpy_to_pytorch(cn_img) + if advanced_parameters.debugging_cn_preprocessor: + outputs.append(['results', [cn_img]]) + return + for task in cn_tasks[flags.cn_cpds]: + cn_img, cn_stop, cn_weight = task + cn_img = resize_image(HWC3(cn_img), width=width, height=height) + cn_img = preprocessors.cpds(cn_img) + cn_img = HWC3(cn_img) + task[0] = core.numpy_to_pytorch(cn_img) + if advanced_parameters.debugging_cn_preprocessor: + outputs.append(['results', [cn_img]]) + return + for task in cn_tasks[flags.cn_ip]: + cn_img, cn_stop, cn_weight = task + cn_img = HWC3(cn_img) + task[0] = ip_adapter.preprocess(cn_img) + + if len(cn_tasks[flags.cn_ip]) > 0: + pipeline.final_unet = ip_adapter.patch_model(pipeline.final_unet, cn_tasks[flags.cn_ip]) results = [] all_steps = steps * image_number + preparation_time = time.perf_counter() - execution_start_time + print(f'Preparation time: {preparation_time:.2f} seconds') + + outputs.append(['preview', (13, 'Moving model to GPU ...', None)]) + execution_start_time = time.perf_counter() + comfy.model_management.load_models_gpu([pipeline.final_unet]) + moving_time = time.perf_counter() - execution_start_time + print(f'Moving model to GPU: {moving_time:.2f} seconds') + + outputs.append(['preview', (13, 'Starting tasks ...', None)]) + def callback(step, x0, x, total_steps, y): done_steps = current_task_id * steps + step outputs.append(['preview', ( @@ -368,17 +462,25 @@ def worker(): f'Step {step}/{total_steps} in the {current_task_id + 1}-th Sampling', y)]) - preparation_time = time.perf_counter() - execution_start_time - print(f'Preparation time: {preparation_time:.2f} seconds') - - outputs.append(['preview', (13, 'Starting tasks ...', None)]) for current_task_id, task in enumerate(tasks): execution_start_time = time.perf_counter() try: + positive_cond, negative_cond = task['c'], task['uc'] + + if 'cn' in goals: + for cn_flag, cn_path in [ + (flags.cn_canny, controlnet_canny_path), + (flags.cn_cpds, controlnet_cpds_path) + ]: + for cn_img, cn_stop, cn_weight in cn_tasks[cn_flag]: + positive_cond, negative_cond = core.apply_controlnet( + positive_cond, negative_cond, + pipeline.loaded_ControlNets[cn_path], cn_img, cn_weight, 0, cn_stop) + imgs = pipeline.process_diffusion( - positive_cond=task['c'], - negative_cond=task['uc'], + positive_cond=positive_cond, + negative_cond=negative_cond, steps=steps, switch=switch, width=width, @@ -393,6 +495,8 @@ def worker(): cfg_scale=cfg_scale ) + del task['c'], task['uc'], positive_cond, negative_cond # Save memory + if inpaint_worker.current_task is not None: imgs = [inpaint_worker.current_task.post_process(x) for x in imgs] @@ -406,14 +510,14 @@ def worker(): ('Resolution', str((width, height))), ('Sharpness', sharpness), ('Guidance Scale', guidance_scale), - ('ADM Guidance', str((adm_scaler_positive, adm_scaler_negative))), + ('ADM Guidance', str((modules.patch.positive_adm_scale, modules.patch.negative_adm_scale))), ('Base Model', base_model_name), ('Refiner Model', refiner_model_name), ('Sampler', sampler_name), ('Scheduler', scheduler_name), ('Seed', task['task_seed']) ] - for n, w in loras_user_raw_input: + for n, w in loras_raw: if n != 'None': 
d.append((f'LoRA [{n}] weight', w)) log(x, d, single_line_number=3) diff --git a/modules/core.py b/modules/core.py index b5777b94..b947afbd 100644 --- a/modules/core.py +++ b/modules/core.py @@ -13,13 +13,16 @@ import comfy.model_management import comfy.model_detection import comfy.model_patcher import comfy.utils +import comfy.controlnet +import modules.sample_hijack +import comfy.samplers from comfy.sd import load_checkpoint_guess_config -from nodes import VAEDecode, EmptyLatentImage, VAEEncode, VAEEncodeTiled, VAEDecodeTiled, VAEEncodeForInpaint -from comfy.sample import prepare_mask, broadcast_cond, get_additional_models, cleanup_additional_models +from nodes import VAEDecode, EmptyLatentImage, VAEEncode, VAEEncodeTiled, VAEDecodeTiled, VAEEncodeForInpaint, \ + ControlNetApplyAdvanced +from comfy.sample import prepare_mask from modules.patch import patched_sampler_cfg_function, patched_model_function_wrapper from comfy.lora import model_lora_keys_unet, model_lora_keys_clip, load_lora -from modules.samplers_advanced import KSamplerBasic, KSamplerWithRefiner opEmptyLatentImage = EmptyLatentImage() @@ -28,6 +31,7 @@ opVAEEncode = VAEEncode() opVAEDecodeTiled = VAEDecodeTiled() opVAEEncodeTiled = VAEEncodeTiled() opVAEEncodeForInpaint = VAEEncodeForInpaint() +opControlNetApplyAdvanced = ControlNetApplyAdvanced() class StableDiffusionModel: @@ -38,6 +42,19 @@ class StableDiffusionModel: self.clip_vision = clip_vision +@torch.no_grad() +@torch.inference_mode() +def load_controlnet(ckpt_filename): + return comfy.controlnet.load_controlnet(ckpt_filename) + + +@torch.no_grad() +@torch.inference_mode() +def apply_controlnet(positive, negative, control_net, image, strength, start_percent, end_percent): + return opControlNetApplyAdvanced.apply_controlnet(positive=positive, negative=negative, control_net=control_net, + image=image, strength=strength, start_percent=start_percent, end_percent=end_percent) + + @torch.no_grad() @torch.inference_mode() def load_unet_only(unet_path): @@ -214,12 +231,8 @@ def get_previewer(): @torch.inference_mode() def ksampler(model, positive, negative, latent, seed=None, steps=30, cfg=7.0, sampler_name='dpmpp_fooocus_2m_sde_inpaint_seamless', scheduler='karras', denoise=1.0, disable_noise=False, start_step=None, last_step=None, - force_full_denoise=False, callback_function=None): - seed = seed if isinstance(seed, int) else random.randint(0, 2**63 - 1) - - device = comfy.model_management.get_torch_device() + force_full_denoise=False, callback_function=None, refiner=None, refiner_switch=-1): latent_image = latent["samples"] - if disable_noise: noise = torch.zeros(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, device="cpu") else: @@ -232,8 +245,6 @@ def ksampler(model, positive, negative, latent, seed=None, steps=30, cfg=7.0, sa previewer = get_previewer() - pbar = comfy.utils.ProgressBar(steps) - def callback(step, x0, x, total_steps): comfy.model_management.throw_exception_if_processing_interrupted() y = None @@ -241,111 +252,23 @@ def ksampler(model, positive, negative, latent, seed=None, steps=30, cfg=7.0, sa y = previewer(x0, step, total_steps) if callback_function is not None: callback_function(step, x0, x, total_steps, y) - pbar.update_absolute(step + 1, total_steps, None) - sigmas = None disable_pbar = False + modules.sample_hijack.current_refiner = refiner + modules.sample_hijack.refiner_switch_step = refiner_switch + comfy.samplers.sample = modules.sample_hijack.sample_hacked - if noise_mask is not None: - noise_mask = 
prepare_mask(noise_mask, noise.shape, device) + try: + samples = comfy.sample.sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, + denoise=denoise, disable_noise=disable_noise, start_step=start_step, + last_step=last_step, + force_full_denoise=force_full_denoise, noise_mask=noise_mask, callback=callback, + disable_pbar=disable_pbar, seed=seed) - models, inference_memory = get_additional_models(positive, negative, model.model_dtype()) - comfy.model_management.load_models_gpu([model] + models, comfy.model_management.batch_area_memory(noise.shape[0] * noise.shape[2] * noise.shape[3]) + inference_memory) - real_model = model.model - - noise = noise.to(device) - latent_image = latent_image.to(device) - - positive_copy = broadcast_cond(positive, noise.shape[0], device) - negative_copy = broadcast_cond(negative, noise.shape[0], device) - - sampler = KSamplerBasic(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, - denoise=denoise, model_options=model.model_options) - - samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image, - start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, - denoise_mask=noise_mask, sigmas=sigmas, callback=callback, disable_pbar=disable_pbar, - seed=seed) - - samples = samples.cpu() - - cleanup_additional_models(models) - - out = latent.copy() - out["samples"] = samples - - return out - - -@torch.no_grad() -@torch.inference_mode() -def ksampler_with_refiner(model, positive, negative, refiner, refiner_positive, refiner_negative, latent, - seed=None, steps=30, refiner_switch_step=20, cfg=7.0, sampler_name='dpmpp_fooocus_2m_sde_inpaint_seamless', - scheduler='karras', denoise=1.0, disable_noise=False, start_step=None, last_step=None, - force_full_denoise=False, callback_function=None): - seed = seed if isinstance(seed, int) else random.randint(0, 2**63 - 1) - - device = comfy.model_management.get_torch_device() - latent_image = latent["samples"] - - if disable_noise: - noise = torch.zeros(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, device="cpu") - else: - batch_inds = latent["batch_index"] if "batch_index" in latent else None - noise = comfy.sample.prepare_noise(latent_image, seed, batch_inds) - - noise_mask = None - if "noise_mask" in latent: - noise_mask = latent["noise_mask"] - - previewer = get_previewer() - - pbar = comfy.utils.ProgressBar(steps) - - def callback(step, x0, x, total_steps): - comfy.model_management.throw_exception_if_processing_interrupted() - y = None - if previewer is not None: - y = previewer(x0, step, total_steps) - if callback_function is not None: - callback_function(step, x0, x, total_steps, y) - pbar.update_absolute(step + 1, total_steps, None) - - sigmas = None - disable_pbar = False - - if noise_mask is not None: - noise_mask = prepare_mask(noise_mask, noise.shape, device) - - models, inference_memory = get_additional_models(positive, negative, model.model_dtype()) - comfy.model_management.load_models_gpu([model] + models, comfy.model_management.batch_area_memory(noise.shape[0] * noise.shape[2] * noise.shape[3]) + inference_memory) - - noise = noise.to(device) - latent_image = latent_image.to(device) - - positive_copy = broadcast_cond(positive, noise.shape[0], device) - negative_copy = broadcast_cond(negative, noise.shape[0], device) - - refiner_positive_copy = broadcast_cond(refiner_positive, noise.shape[0], device) - refiner_negative_copy = 
broadcast_cond(refiner_negative, noise.shape[0], device) - - sampler = KSamplerWithRefiner(model=model, refiner_model=refiner, steps=steps, device=device, - sampler=sampler_name, scheduler=scheduler, - denoise=denoise, model_options=model.model_options) - - samples = sampler.sample(noise, positive_copy, negative_copy, refiner_positive=refiner_positive_copy, - refiner_negative=refiner_negative_copy, refiner_switch_step=refiner_switch_step, - cfg=cfg, latent_image=latent_image, - start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, - denoise_mask=noise_mask, sigmas=sigmas, callback_function=callback, disable_pbar=disable_pbar, - seed=seed) - - samples = samples.cpu() - - cleanup_additional_models(models) - - out = latent.copy() - out["samples"] = samples + out = latent.copy() + out["samples"] = samples + finally: + modules.sample_hijack.current_refiner = None return out diff --git a/modules/default_pipeline.py b/modules/default_pipeline.py index 78d0f3c9..684b96b6 100644 --- a/modules/default_pipeline.py +++ b/modules/default_pipeline.py @@ -18,6 +18,29 @@ xl_base_patched_hash = '' xl_refiner: ModelPatcher = None xl_refiner_hash = '' +final_expansion = None +final_unet = None +final_clip = None +final_vae = None +final_refiner = None + +loaded_ControlNets = {} + + +@torch.no_grad() +@torch.inference_mode() +def refresh_controlnets(model_paths): + global loaded_ControlNets + cache = {} + for p in model_paths: + if p is not None: + if p in loaded_ControlNets: + cache[p] = loaded_ControlNets[p] + else: + cache[p] = core.load_controlnet(p) + loaded_ControlNets = cache + return + @torch.no_grad() @torch.inference_mode() @@ -137,31 +160,21 @@ def clip_encode_single(clip, text, verbose=False): @torch.no_grad() @torch.inference_mode() -def clip_separate(cond): - c, p = cond[0] - c = c[..., -1280:].clone() - p = p["pooled_output"].clone() - return [[c, {"pooled_output": p}]] +def clip_encode(texts, pool_top_k=1): + global final_clip - -@torch.no_grad() -@torch.inference_mode() -def clip_encode(sd, texts, pool_top_k=1): - if sd is None: - return None - if sd.clip is None: + if final_clip is None: return None if not isinstance(texts, list): return None if len(texts) == 0: return None - clip = sd.clip cond_list = [] pooled_acc = 0 for i, text in enumerate(texts): - cond, pooled = clip_encode_single(clip, text) + cond, pooled = clip_encode_single(final_clip, text) cond_list.append(cond) if i < pool_top_k: pooled_acc += pooled @@ -176,13 +189,34 @@ def clear_all_caches(): xl_base_patched.clip.fcs_cond_cache = {} +@torch.no_grad() +@torch.inference_mode() +def prepare_text_encoder(async_call=True): + if async_call: + # TODO: make sure that this is always called in an async way so that users cannot feel it. 
+ pass + assert_model_integrity() + comfy.model_management.load_models_gpu([final_clip.patcher, final_expansion.patcher]) + return + + @torch.no_grad() @torch.inference_mode() def refresh_everything(refiner_model_name, base_model_name, loras): + global final_unet, final_clip, final_vae, final_refiner, final_expansion + refresh_refiner_model(refiner_model_name) refresh_base_model(base_model_name) refresh_loras(loras) assert_model_integrity() + + final_unet, final_clip, final_vae, final_refiner = \ + xl_base_patched.unet, xl_base_patched.clip, xl_base_patched.vae, xl_refiner + + if final_expansion is None: + final_expansion = FooocusExpansion() + + prepare_text_encoder(async_call=True) clear_all_caches() return @@ -193,22 +227,6 @@ refresh_everything( loras=[(modules.path.default_lora_name, 0.5), ('None', 0.5), ('None', 0.5), ('None', 0.5), ('None', 0.5)] ) -expansion = FooocusExpansion() - - -@torch.no_grad() -@torch.inference_mode() -def prepare_text_encoder(async_call=True): - if async_call: - # TODO: make sure that this is always called in an async way so that users cannot feel it. - pass - assert_model_integrity() - comfy.model_management.load_models_gpu([xl_base_patched.clip.patcher, expansion.patcher]) - return - - -prepare_text_encoder(async_call=True) - @torch.no_grad() @torch.inference_mode() @@ -218,40 +236,23 @@ def process_diffusion(positive_cond, negative_cond, steps, switch, width, height else: empty_latent = latent - if xl_refiner is not None: - sampled_latent = core.ksampler_with_refiner( - model=xl_base_patched.unet, - positive=positive_cond[0], - negative=negative_cond[0], - refiner=xl_refiner, - refiner_positive=positive_cond[1], - refiner_negative=negative_cond[1], - refiner_switch_step=switch, - latent=empty_latent, - steps=steps, start_step=0, last_step=steps, disable_noise=False, force_full_denoise=True, - seed=image_seed, - denoise=denoise, - callback_function=callback, - cfg=cfg_scale, - sampler_name=sampler_name, - scheduler=scheduler_name - ) - else: - sampled_latent = core.ksampler( - model=xl_base_patched.unet, - positive=positive_cond[0], - negative=negative_cond[0], - latent=empty_latent, - steps=steps, start_step=0, last_step=steps, disable_noise=False, force_full_denoise=True, - seed=image_seed, - denoise=denoise, - callback_function=callback, - cfg=cfg_scale, - sampler_name=sampler_name, - scheduler=scheduler_name - ) + sampled_latent = core.ksampler( + model=final_unet, + refiner=final_refiner, + positive=positive_cond, + negative=negative_cond, + latent=empty_latent, + steps=steps, start_step=0, last_step=steps, disable_noise=False, force_full_denoise=True, + seed=image_seed, + denoise=denoise, + callback_function=callback, + cfg=cfg_scale, + sampler_name=sampler_name, + scheduler=scheduler_name, + refiner_switch=switch + ) - decoded_latent = core.decode_vae(vae=xl_base_patched.vae, latent_image=sampled_latent, tiled=tiled) + decoded_latent = core.decode_vae(vae=final_vae, latent_image=sampled_latent, tiled=tiled) images = core.pytorch_to_numpy(decoded_latent) comfy.model_management.soft_empty_cache() diff --git a/modules/flags.py b/modules/flags.py index cc0345db..9fd60adf 100644 --- a/modules/flags.py +++ b/modules/flags.py @@ -1,3 +1,6 @@ +import comfy.samplers + + disabled = 'Disabled' enabled = 'Enabled' subtle_variation = 'Vary (Subtle)' @@ -10,14 +13,19 @@ uov_list = [ disabled, subtle_variation, strong_variation, upscale_15, upscale_2, upscale_fast ] -sampler_list = ["euler", "euler_ancestral", "heun", "dpm_2", "dpm_2_ancestral", - "lms", 
"dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "dpmpp_sde_gpu", - "dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", - # "ddim", - "uni_pc", "uni_pc_bh2", - # "dpmpp_fooocus_2m_sde_inpaint_seamless" - ] +sampler_list = comfy.samplers.SAMPLER_NAMES default_sampler = 'dpmpp_2m_sde_gpu' -scheduler_list = ["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_uniform"] +scheduler_list = comfy.samplers.SCHEDULER_NAMES default_scheduler = "karras" + +cn_ip = "Image Prompt" +cn_canny = "PyraCanny" +cn_cpds = "CPDS" + +ip_list = [cn_ip, cn_canny, cn_cpds] +default_ip = cn_ip + +default_parameters = { + cn_ip: (0.4, 0.6), cn_canny: (0.4, 1.0), cn_cpds: (0.4, 1.0) +} # stop, weight diff --git a/modules/html.py b/modules/html.py index b6c8e3ea..fe0d65cd 100644 --- a/modules/html.py +++ b/modules/html.py @@ -91,6 +91,11 @@ progress::after { min-width: min(1px, 100%) !important; } +.resizable_area { + resize: vertical; + overflow: auto !important; +} + ''' progress_html = '''