diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md new file mode 100644 index 00000000..3ac19a0a --- /dev/null +++ b/.github/ISSUE_TEMPLATE/bug_report.md @@ -0,0 +1,14 @@ +--- +name: Bug report +about: Describe a problem +title: '' +labels: '' +assignees: '' + +--- + +**Describe the problem** +A clear and concise description of what the bug is. + +**Full Console Log** +Paste **full** console log here. You will make our job easier if you give a **full** log. diff --git a/.github/ISSUE_TEMPLATE/feature_request.md b/.github/ISSUE_TEMPLATE/feature_request.md new file mode 100644 index 00000000..8101bc36 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/feature_request.md @@ -0,0 +1,14 @@ +--- +name: Feature request +about: Suggest an idea for this project +title: '' +labels: '' +assignees: '' + +--- + +**Is your feature request related to a problem? Please describe.** +A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] + +**Describe the idea you'd like** +A clear and concise description of what you want to happen. diff --git a/.gitignore b/.gitignore index ce656b04..30b6ca3b 100644 --- a/.gitignore +++ b/.gitignore @@ -45,3 +45,4 @@ experiment.py /node_modules /package-lock.json /.coverage* +/auth.json diff --git a/auth-example.json b/auth-example.json new file mode 100644 index 00000000..59e321d0 --- /dev/null +++ b/auth-example.json @@ -0,0 +1,6 @@ +[ + { + "user": "sitting-duck-1", + "pass": "very-bad-publicly-known-password-change-it" + } +] diff --git a/backend/headless/fcbh/cldm/cldm.py b/backend/headless/fcbh/cldm/cldm.py index b177e924..d464a462 100644 --- a/backend/headless/fcbh/cldm/cldm.py +++ b/backend/headless/fcbh/cldm/cldm.py @@ -27,7 +27,6 @@ class ControlNet(nn.Module): model_channels, hint_channels, num_res_blocks, - attention_resolutions, dropout=0, channel_mult=(1, 2, 4, 8), conv_resample=True, @@ -52,6 +51,7 @@ class ControlNet(nn.Module): use_linear_in_transformer=False, adm_in_channels=None, transformer_depth_middle=None, + transformer_depth_output=None, device=None, operations=fcbh.ops, ): @@ -79,10 +79,7 @@ class ControlNet(nn.Module): self.image_size = image_size self.in_channels = in_channels self.model_channels = model_channels - if isinstance(transformer_depth, int): - transformer_depth = len(channel_mult) * [transformer_depth] - if transformer_depth_middle is None: - transformer_depth_middle = transformer_depth[-1] + if isinstance(num_res_blocks, int): self.num_res_blocks = len(channel_mult) * [num_res_blocks] else: @@ -90,18 +87,16 @@ class ControlNet(nn.Module): raise ValueError("provide num_res_blocks either as an int (globally constant) or " "as a list/tuple (per-level) with the same length as channel_mult") self.num_res_blocks = num_res_blocks + if disable_self_attentions is not None: # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not assert len(disable_self_attentions) == len(channel_mult) if num_attention_blocks is not None: assert len(num_attention_blocks) == len(self.num_res_blocks) assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks)))) - print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. 
" - f"This option has LESS priority than attention_resolutions {attention_resolutions}, " - f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, " - f"attention will still not be set.") - self.attention_resolutions = attention_resolutions + transformer_depth = transformer_depth[:] + self.dropout = dropout self.channel_mult = channel_mult self.conv_resample = conv_resample @@ -180,11 +175,14 @@ class ControlNet(nn.Module): dims=dims, use_checkpoint=use_checkpoint, use_scale_shift_norm=use_scale_shift_norm, - operations=operations + dtype=self.dtype, + device=device, + operations=operations, ) ] ch = mult * model_channels - if ds in attention_resolutions: + num_transformers = transformer_depth.pop(0) + if num_transformers > 0: if num_head_channels == -1: dim_head = ch // num_heads else: @@ -201,9 +199,9 @@ class ControlNet(nn.Module): if not exists(num_attention_blocks) or nr < num_attention_blocks[level]: layers.append( SpatialTransformer( - ch, num_heads, dim_head, depth=transformer_depth[level], context_dim=context_dim, + ch, num_heads, dim_head, depth=num_transformers, context_dim=context_dim, disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer, - use_checkpoint=use_checkpoint, operations=operations + use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations ) ) self.input_blocks.append(TimestepEmbedSequential(*layers)) @@ -223,11 +221,13 @@ class ControlNet(nn.Module): use_checkpoint=use_checkpoint, use_scale_shift_norm=use_scale_shift_norm, down=True, + dtype=self.dtype, + device=device, operations=operations ) if resblock_updown else Downsample( - ch, conv_resample, dims=dims, out_channels=out_ch, operations=operations + ch, conv_resample, dims=dims, out_channels=out_ch, dtype=self.dtype, device=device, operations=operations ) ) ) @@ -245,7 +245,7 @@ class ControlNet(nn.Module): if legacy: #num_heads = 1 dim_head = ch // num_heads if use_spatial_transformer else num_head_channels - self.middle_block = TimestepEmbedSequential( + mid_block = [ ResBlock( ch, time_embed_dim, @@ -253,12 +253,15 @@ class ControlNet(nn.Module): dims=dims, use_checkpoint=use_checkpoint, use_scale_shift_norm=use_scale_shift_norm, + dtype=self.dtype, + device=device, operations=operations - ), - SpatialTransformer( # always uses a self-attn + )] + if transformer_depth_middle >= 0: + mid_block += [SpatialTransformer( # always uses a self-attn ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim, disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer, - use_checkpoint=use_checkpoint, operations=operations + use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations ), ResBlock( ch, @@ -267,9 +270,11 @@ class ControlNet(nn.Module): dims=dims, use_checkpoint=use_checkpoint, use_scale_shift_norm=use_scale_shift_norm, + dtype=self.dtype, + device=device, operations=operations - ), - ) + )] + self.middle_block = TimestepEmbedSequential(*mid_block) self.middle_block_out = self.make_zero_conv(ch, operations=operations) self._feature_size += ch diff --git a/backend/headless/fcbh/cli_args.py b/backend/headless/fcbh/cli_args.py index 0b072376..85134e90 100644 --- a/backend/headless/fcbh/cli_args.py +++ b/backend/headless/fcbh/cli_args.py @@ -36,6 +36,8 @@ parser = argparse.ArgumentParser() parser.add_argument("--listen", type=str, default="127.0.0.1", metavar="IP", nargs="?", const="0.0.0.0", help="Specify the IP address to listen on (default: 127.0.0.1). 
If --listen is provided without an argument, it defaults to 0.0.0.0. (listens on all)") parser.add_argument("--port", type=int, default=8188, help="Set the listen port.") parser.add_argument("--enable-cors-header", type=str, default=None, metavar="ORIGIN", nargs="?", const="*", help="Enable CORS (Cross-Origin Resource Sharing) with optional origin or allow all with default '*'.") +parser.add_argument("--max-upload-size", type=float, default=100, help="Set the maximum upload size in MB.") + parser.add_argument("--extra-model-paths-config", type=str, default=None, metavar="PATH", nargs='+', action='append', help="Load one or more extra_model_paths.yaml files.") parser.add_argument("--output-directory", type=str, default=None, help="Set the fcbh_backend output directory.") parser.add_argument("--temp-directory", type=str, default=None, help="Set the fcbh_backend temp directory (default is in the fcbh_backend directory).") diff --git a/backend/headless/fcbh/clip_vision.py b/backend/headless/fcbh/clip_vision.py index b93b0da7..f3c4bb6c 100644 --- a/backend/headless/fcbh/clip_vision.py +++ b/backend/headless/fcbh/clip_vision.py @@ -1,5 +1,5 @@ -from transformers import CLIPVisionModelWithProjection, CLIPVisionConfig, CLIPImageProcessor, modeling_utils -from .utils import load_torch_file, transformers_convert +from transformers import CLIPVisionModelWithProjection, CLIPVisionConfig, modeling_utils +from .utils import load_torch_file, transformers_convert, common_upscale import os import torch import contextlib @@ -7,6 +7,18 @@ import contextlib import fcbh.ops import fcbh.model_patcher import fcbh.model_management +import fcbh.utils + +def clip_preprocess(image, size=224): + mean = torch.tensor([ 0.48145466,0.4578275,0.40821073], device=image.device, dtype=image.dtype) + std = torch.tensor([0.26862954,0.26130258,0.27577711], device=image.device, dtype=image.dtype) + scale = (size / min(image.shape[1], image.shape[2])) + image = torch.nn.functional.interpolate(image.movedim(-1, 1), size=(round(scale * image.shape[1]), round(scale * image.shape[2])), mode="bicubic", antialias=True) + h = (image.shape[2] - size)//2 + w = (image.shape[3] - size)//2 + image = image[:,:,h:h+size,w:w+size] + image = torch.clip((255. * image), 0, 255).round() / 255.0 + return (image - mean.view([3,1,1])) / std.view([3,1,1]) class ClipVisionModel(): def __init__(self, json_config): @@ -23,25 +35,12 @@ class ClipVisionModel(): self.model.to(self.dtype) self.patcher = fcbh.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device) - self.processor = CLIPImageProcessor(crop_size=224, - do_center_crop=True, - do_convert_rgb=True, - do_normalize=True, - do_resize=True, - image_mean=[ 0.48145466,0.4578275,0.40821073], - image_std=[0.26862954,0.26130258,0.27577711], - resample=3, #bicubic - size=224) - def load_sd(self, sd): return self.model.load_state_dict(sd, strict=False) def encode_image(self, image): - img = torch.clip((255. 
* image), 0, 255).round().int() - img = list(map(lambda a: a, img)) - inputs = self.processor(images=img, return_tensors="pt") fcbh.model_management.load_model_gpu(self.patcher) - pixel_values = inputs['pixel_values'].to(self.load_device) + pixel_values = clip_preprocess(image.to(self.load_device)) if self.dtype != torch.float32: precision_scope = torch.autocast diff --git a/backend/headless/fcbh/conds.py b/backend/headless/fcbh/conds.py new file mode 100644 index 00000000..252bb869 --- /dev/null +++ b/backend/headless/fcbh/conds.py @@ -0,0 +1,64 @@ +import enum +import torch +import math +import fcbh.utils + + +def lcm(a, b): #TODO: eventually replace by math.lcm (added in python3.9) + return abs(a*b) // math.gcd(a, b) + +class CONDRegular: + def __init__(self, cond): + self.cond = cond + + def _copy_with(self, cond): + return self.__class__(cond) + + def process_cond(self, batch_size, device, **kwargs): + return self._copy_with(fcbh.utils.repeat_to_batch_size(self.cond, batch_size).to(device)) + + def can_concat(self, other): + if self.cond.shape != other.cond.shape: + return False + return True + + def concat(self, others): + conds = [self.cond] + for x in others: + conds.append(x.cond) + return torch.cat(conds) + +class CONDNoiseShape(CONDRegular): + def process_cond(self, batch_size, device, area, **kwargs): + data = self.cond[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]] + return self._copy_with(fcbh.utils.repeat_to_batch_size(data, batch_size).to(device)) + + +class CONDCrossAttn(CONDRegular): + def can_concat(self, other): + s1 = self.cond.shape + s2 = other.cond.shape + if s1 != s2: + if s1[0] != s2[0] or s1[2] != s2[2]: #these 2 cases should not happen + return False + + mult_min = lcm(s1[1], s2[1]) + diff = mult_min // min(s1[1], s2[1]) + if diff > 4: #arbitrary limit on the padding because it's probably going to impact performance negatively if it's too much + return False + return True + + def concat(self, others): + conds = [self.cond] + crossattn_max_len = self.cond.shape[1] + for x in others: + c = x.cond + crossattn_max_len = lcm(crossattn_max_len, c.shape[1]) + conds.append(c) + + out = [] + for c in conds: + if c.shape[1] < crossattn_max_len: + c = c.repeat(1, crossattn_max_len // c.shape[1], 1) #padding with repeat doesn't change result + out.append(c) + return torch.cat(out) diff --git a/backend/headless/fcbh/controlnet.py b/backend/headless/fcbh/controlnet.py index a0858399..ab6c38f6 100644 --- a/backend/headless/fcbh/controlnet.py +++ b/backend/headless/fcbh/controlnet.py @@ -156,7 +156,7 @@ class ControlNet(ControlBase): context = cond['c_crossattn'] - y = cond.get('c_adm', None) + y = cond.get('y', None) if y is not None: y = y.to(self.control_model.dtype) control = self.control_model(x=x_noisy.to(self.control_model.dtype), hint=self.cond_hint, timesteps=t, context=context.to(self.control_model.dtype), y=y) @@ -416,7 +416,7 @@ class T2IAdapter(ControlBase): if control_prev is not None: return control_prev else: - return {} + return None if self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]: if self.cond_hint is not None: diff --git a/backend/headless/fcbh/extra_samplers/uni_pc.py b/backend/headless/fcbh/extra_samplers/uni_pc.py index 58e030d0..9d5f0c60 100644 --- a/backend/headless/fcbh/extra_samplers/uni_pc.py +++ b/backend/headless/fcbh/extra_samplers/uni_pc.py @@ -881,7 +881,7 @@ def sample_unipc(model, noise, image, sigmas, sampling_function, max_denoise, ex 
model_kwargs=extra_args, ) - order = min(3, len(timesteps) - 1) + order = min(3, len(timesteps) - 2) uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False, noise_mask=noise_mask, masked_image=image, noise=noise, variant=variant) x = uni_pc.sample(img, timesteps=timesteps, skip_type="time_uniform", method="multistep", order=order, lower_order_final=True, callback=callback, disable_pbar=disable) x /= ns.marginal_alpha(timesteps[-1]) diff --git a/backend/headless/fcbh/ldm/modules/attention.py b/backend/headless/fcbh/ldm/modules/attention.py index d00a9af4..c0383556 100644 --- a/backend/headless/fcbh/ldm/modules/attention.py +++ b/backend/headless/fcbh/ldm/modules/attention.py @@ -95,9 +95,19 @@ def Normalize(in_channels, dtype=None, device=None): return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device) def attention_basic(q, k, v, heads, mask=None): + b, _, dim_head = q.shape + dim_head //= heads + scale = dim_head ** -0.5 + h = heads - scale = (q.shape[-1] // heads) ** -0.5 - q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) + q, k, v = map( + lambda t: t.unsqueeze(3) + .reshape(b, -1, heads, dim_head) + .permute(0, 2, 1, 3) + .reshape(b * heads, -1, dim_head) + .contiguous(), + (q, k, v), + ) # force cast to fp32 to avoid overflowing if _ATTN_PRECISION =="fp32": @@ -119,16 +129,24 @@ def attention_basic(q, k, v, heads, mask=None): sim = sim.softmax(dim=-1) out = einsum('b i j, b j d -> b i d', sim.to(v.dtype), v) - out = rearrange(out, '(b h) n d -> b n (h d)', h=h) + out = ( + out.unsqueeze(0) + .reshape(b, heads, -1, dim_head) + .permute(0, 2, 1, 3) + .reshape(b, -1, heads * dim_head) + ) return out def attention_sub_quad(query, key, value, heads, mask=None): - scale = (query.shape[-1] // heads) ** -0.5 - query = query.unflatten(-1, (heads, -1)).transpose(1,2).flatten(end_dim=1) - key_t = key.transpose(1,2).unflatten(1, (heads, -1)).flatten(end_dim=1) - del key - value = value.unflatten(-1, (heads, -1)).transpose(1,2).flatten(end_dim=1) + b, _, dim_head = query.shape + dim_head //= heads + + scale = dim_head ** -0.5 + query = query.unsqueeze(3).reshape(b, -1, heads, dim_head).permute(0, 2, 1, 3).reshape(b * heads, -1, dim_head) + value = value.unsqueeze(3).reshape(b, -1, heads, dim_head).permute(0, 2, 1, 3).reshape(b * heads, -1, dim_head) + + key = key.unsqueeze(3).reshape(b, -1, heads, dim_head).permute(0, 2, 3, 1).reshape(b * heads, dim_head, -1) dtype = query.dtype upcast_attention = _ATTN_PRECISION =="fp32" and query.dtype != torch.float32 @@ -137,41 +155,28 @@ def attention_sub_quad(query, key, value, heads, mask=None): else: bytes_per_token = torch.finfo(query.dtype).bits//8 batch_x_heads, q_tokens, _ = query.shape - _, _, k_tokens = key_t.shape + _, _, k_tokens = key.shape qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens mem_free_total, mem_free_torch = model_management.get_free_memory(query.device, True) - chunk_threshold_bytes = mem_free_torch * 0.5 #Using only this seems to work better on AMD - kv_chunk_size_min = None + kv_chunk_size = None + query_chunk_size = None - #not sure at all about the math here - #TODO: tweak this - if mem_free_total > 8192 * 1024 * 1024 * 1.3: - query_chunk_size_x = 1024 * 4 - elif mem_free_total > 4096 * 1024 * 1024 * 1.3: - query_chunk_size_x = 1024 * 2 - else: - query_chunk_size_x = 1024 - kv_chunk_size_min_x = None - kv_chunk_size_x = (int((chunk_threshold_bytes // (batch_x_heads * bytes_per_token * 
query_chunk_size_x)) * 2.0) // 1024) * 1024 - if kv_chunk_size_x < 1024: - kv_chunk_size_x = None + for x in [4096, 2048, 1024, 512, 256]: + count = mem_free_total / (batch_x_heads * bytes_per_token * x * 4.0) + if count >= k_tokens: + kv_chunk_size = k_tokens + query_chunk_size = x + break - if chunk_threshold_bytes is not None and qk_matmul_size_bytes <= chunk_threshold_bytes: - # the big matmul fits into our memory limit; do everything in 1 chunk, - # i.e. send it down the unchunked fast-path - query_chunk_size = q_tokens - kv_chunk_size = k_tokens - else: - query_chunk_size = query_chunk_size_x - kv_chunk_size = kv_chunk_size_x - kv_chunk_size_min = kv_chunk_size_min_x + if query_chunk_size is None: + query_chunk_size = 512 hidden_states = efficient_dot_product_attention( query, - key_t, + key, value, query_chunk_size=query_chunk_size, kv_chunk_size=kv_chunk_size, @@ -186,17 +191,32 @@ def attention_sub_quad(query, key, value, heads, mask=None): return hidden_states def attention_split(q, k, v, heads, mask=None): - scale = (q.shape[-1] // heads) ** -0.5 + b, _, dim_head = q.shape + dim_head //= heads + scale = dim_head ** -0.5 + h = heads - q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) + q, k, v = map( + lambda t: t.unsqueeze(3) + .reshape(b, -1, heads, dim_head) + .permute(0, 2, 1, 3) + .reshape(b * heads, -1, dim_head) + .contiguous(), + (q, k, v), + ) r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype) mem_free_total = model_management.get_free_memory(q.device) + if _ATTN_PRECISION =="fp32": + element_size = 4 + else: + element_size = q.element_size() + gb = 1024 ** 3 - tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() - modifier = 3 if q.element_size() == 2 else 2.5 + tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * element_size + modifier = 3 mem_required = tensor_size * modifier steps = 1 @@ -224,10 +244,10 @@ def attention_split(q, k, v, heads, mask=None): s1 = einsum('b i d, b j d -> b i j', q[:, i:end].float(), k.float()) * scale else: s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * scale - first_op_done = True s2 = s1.softmax(dim=-1).to(v.dtype) del s1 + first_op_done = True r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v) del s2 @@ -248,17 +268,23 @@ def attention_split(q, k, v, heads, mask=None): del q, k, v - r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h) - del r1 - return r2 + r1 = ( + r1.unsqueeze(0) + .reshape(b, heads, -1, dim_head) + .permute(0, 2, 1, 3) + .reshape(b, -1, heads * dim_head) + ) + return r1 def attention_xformers(q, k, v, heads, mask=None): - b, _, _ = q.shape + b, _, dim_head = q.shape + dim_head //= heads + q, k, v = map( lambda t: t.unsqueeze(3) - .reshape(b, t.shape[1], heads, -1) + .reshape(b, -1, heads, dim_head) .permute(0, 2, 1, 3) - .reshape(b * heads, t.shape[1], -1) + .reshape(b * heads, -1, dim_head) .contiguous(), (q, k, v), ) @@ -270,9 +296,9 @@ def attention_xformers(q, k, v, heads, mask=None): raise NotImplementedError out = ( out.unsqueeze(0) - .reshape(b, heads, out.shape[1], -1) + .reshape(b, heads, -1, dim_head) .permute(0, 2, 1, 3) - .reshape(b, out.shape[1], -1) + .reshape(b, -1, heads * dim_head) ) return out diff --git a/backend/headless/fcbh/ldm/modules/diffusionmodules/openaimodel.py b/backend/headless/fcbh/ldm/modules/diffusionmodules/openaimodel.py index d8ec0a62..9c7cfb8e 100644 --- a/backend/headless/fcbh/ldm/modules/diffusionmodules/openaimodel.py +++ b/backend/headless/fcbh/ldm/modules/diffusionmodules/openaimodel.py 
@@ -259,10 +259,6 @@ class UNetModel(nn.Module): :param model_channels: base channel count for the model. :param out_channels: channels in the output Tensor. :param num_res_blocks: number of residual blocks per downsample. - :param attention_resolutions: a collection of downsample rates at which - attention will take place. May be a set, list, or tuple. - For example, if this contains 4, then at 4x downsampling, attention - will be used. :param dropout: the dropout probability. :param channel_mult: channel multiplier for each level of the UNet. :param conv_resample: if True, use learned convolutions for upsampling and @@ -289,7 +285,6 @@ class UNetModel(nn.Module): model_channels, out_channels, num_res_blocks, - attention_resolutions, dropout=0, channel_mult=(1, 2, 4, 8), conv_resample=True, @@ -314,6 +309,7 @@ class UNetModel(nn.Module): use_linear_in_transformer=False, adm_in_channels=None, transformer_depth_middle=None, + transformer_depth_output=None, device=None, operations=fcbh.ops, ): @@ -341,10 +337,7 @@ class UNetModel(nn.Module): self.in_channels = in_channels self.model_channels = model_channels self.out_channels = out_channels - if isinstance(transformer_depth, int): - transformer_depth = len(channel_mult) * [transformer_depth] - if transformer_depth_middle is None: - transformer_depth_middle = transformer_depth[-1] + if isinstance(num_res_blocks, int): self.num_res_blocks = len(channel_mult) * [num_res_blocks] else: @@ -352,18 +345,16 @@ class UNetModel(nn.Module): raise ValueError("provide num_res_blocks either as an int (globally constant) or " "as a list/tuple (per-level) with the same length as channel_mult") self.num_res_blocks = num_res_blocks + if disable_self_attentions is not None: # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not assert len(disable_self_attentions) == len(channel_mult) if num_attention_blocks is not None: assert len(num_attention_blocks) == len(self.num_res_blocks) - assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks)))) - print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. 
" - f"This option has LESS priority than attention_resolutions {attention_resolutions}, " - f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, " - f"attention will still not be set.") - self.attention_resolutions = attention_resolutions + transformer_depth = transformer_depth[:] + transformer_depth_output = transformer_depth_output[:] + self.dropout = dropout self.channel_mult = channel_mult self.conv_resample = conv_resample @@ -428,7 +419,8 @@ class UNetModel(nn.Module): ) ] ch = mult * model_channels - if ds in attention_resolutions: + num_transformers = transformer_depth.pop(0) + if num_transformers > 0: if num_head_channels == -1: dim_head = ch // num_heads else: @@ -444,7 +436,7 @@ class UNetModel(nn.Module): if not exists(num_attention_blocks) or nr < num_attention_blocks[level]: layers.append(SpatialTransformer( - ch, num_heads, dim_head, depth=transformer_depth[level], context_dim=context_dim, + ch, num_heads, dim_head, depth=num_transformers, context_dim=context_dim, disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer, use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations ) @@ -488,7 +480,7 @@ class UNetModel(nn.Module): if legacy: #num_heads = 1 dim_head = ch // num_heads if use_spatial_transformer else num_head_channels - self.middle_block = TimestepEmbedSequential( + mid_block = [ ResBlock( ch, time_embed_dim, @@ -499,8 +491,9 @@ class UNetModel(nn.Module): dtype=self.dtype, device=device, operations=operations - ), - SpatialTransformer( # always uses a self-attn + )] + if transformer_depth_middle >= 0: + mid_block += [SpatialTransformer( # always uses a self-attn ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim, disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer, use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations @@ -515,8 +508,8 @@ class UNetModel(nn.Module): dtype=self.dtype, device=device, operations=operations - ), - ) + )] + self.middle_block = TimestepEmbedSequential(*mid_block) self._feature_size += ch self.output_blocks = nn.ModuleList([]) @@ -538,7 +531,8 @@ class UNetModel(nn.Module): ) ] ch = model_channels * mult - if ds in attention_resolutions: + num_transformers = transformer_depth_output.pop() + if num_transformers > 0: if num_head_channels == -1: dim_head = ch // num_heads else: @@ -555,7 +549,7 @@ class UNetModel(nn.Module): if not exists(num_attention_blocks) or i < num_attention_blocks[level]: layers.append( SpatialTransformer( - ch, num_heads, dim_head, depth=transformer_depth[level], context_dim=context_dim, + ch, num_heads, dim_head, depth=num_transformers, context_dim=context_dim, disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer, use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations ) diff --git a/backend/headless/fcbh/ldm/modules/sub_quadratic_attention.py b/backend/headless/fcbh/ldm/modules/sub_quadratic_attention.py index 1f07431a..11d1dd45 100644 --- a/backend/headless/fcbh/ldm/modules/sub_quadratic_attention.py +++ b/backend/headless/fcbh/ldm/modules/sub_quadratic_attention.py @@ -83,7 +83,8 @@ def _summarize_chunk( ) max_score, _ = torch.max(attn_weights, -1, keepdim=True) max_score = max_score.detach() - torch.exp(attn_weights - max_score, out=attn_weights) + attn_weights -= max_score + torch.exp(attn_weights, out=attn_weights) exp_weights = attn_weights.to(value.dtype) exp_values = torch.bmm(exp_weights, value) 
max_score = max_score.squeeze(-1) diff --git a/backend/headless/fcbh/lora.py b/backend/headless/fcbh/lora.py index 4c1c5684..3bec26b5 100644 --- a/backend/headless/fcbh/lora.py +++ b/backend/headless/fcbh/lora.py @@ -141,9 +141,9 @@ def model_lora_keys_clip(model, key_map={}): text_model_lora_key = "lora_te_text_model_encoder_layers_{}_{}" clip_l_present = False - for b in range(32): + for b in range(32): #TODO: clean up for c in LORA_CLIP_MAP: - k = "transformer.text_model.encoder.layers.{}.{}.weight".format(b, c) + k = "clip_h.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c) if k in sdk: lora_key = text_model_lora_key.format(b, LORA_CLIP_MAP[c]) key_map[lora_key] = k @@ -154,6 +154,8 @@ def model_lora_keys_clip(model, key_map={}): k = "clip_l.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c) if k in sdk: + lora_key = text_model_lora_key.format(b, LORA_CLIP_MAP[c]) + key_map[lora_key] = k lora_key = "lora_te1_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #SDXL base key_map[lora_key] = k clip_l_present = True diff --git a/backend/headless/fcbh/model_base.py b/backend/headless/fcbh/model_base.py index f3f708f7..86525d99 100644 --- a/backend/headless/fcbh/model_base.py +++ b/backend/headless/fcbh/model_base.py @@ -4,6 +4,7 @@ from fcbh.ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmen from fcbh.ldm.modules.diffusionmodules.util import make_beta_schedule from fcbh.ldm.modules.diffusionmodules.openaimodel import Timestep import fcbh.model_management +import fcbh.conds import numpy as np from enum import Enum from . import utils @@ -49,7 +50,7 @@ class BaseModel(torch.nn.Module): self.register_buffer('alphas_cumprod', torch.tensor(alphas_cumprod, dtype=torch.float32)) self.register_buffer('alphas_cumprod_prev', torch.tensor(alphas_cumprod_prev, dtype=torch.float32)) - def apply_model(self, x, t, c_concat=None, c_crossattn=None, c_adm=None, control=None, transformer_options={}): + def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs): if c_concat is not None: xc = torch.cat([x] + [c_concat], dim=1) else: @@ -59,9 +60,10 @@ class BaseModel(torch.nn.Module): xc = xc.to(dtype) t = t.to(dtype) context = context.to(dtype) - if c_adm is not None: - c_adm = c_adm.to(dtype) - return self.diffusion_model(xc, t, context=context, y=c_adm, control=control, transformer_options=transformer_options).float() + extra_conds = {} + for o in kwargs: + extra_conds[o] = kwargs[o].to(dtype) + return self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds).float() def get_dtype(self): return self.diffusion_model.dtype @@ -72,7 +74,8 @@ class BaseModel(torch.nn.Module): def encode_adm(self, **kwargs): return None - def cond_concat(self, **kwargs): + def extra_conds(self, **kwargs): + out = {} if self.inpaint_model: concat_keys = ("mask", "masked_image") cond_concat = [] @@ -101,8 +104,12 @@ class BaseModel(torch.nn.Module): cond_concat.append(torch.ones_like(noise)[:,:1]) elif ck == "masked_image": cond_concat.append(blank_inpaint_image_like(noise)) - return cond_concat - return None + data = torch.cat(cond_concat, dim=1) + out['c_concat'] = fcbh.conds.CONDNoiseShape(data) + adm = self.encode_adm(**kwargs) + if adm is not None: + out['y'] = fcbh.conds.CONDRegular(adm) + return out def load_model_weights(self, sd, unet_prefix=""): to_load = {} diff --git a/backend/headless/fcbh/model_detection.py b/backend/headless/fcbh/model_detection.py 
index cc3d10ee..53851274 100644 --- a/backend/headless/fcbh/model_detection.py +++ b/backend/headless/fcbh/model_detection.py @@ -14,6 +14,19 @@ def count_blocks(state_dict_keys, prefix_string): count += 1 return count +def calculate_transformer_depth(prefix, state_dict_keys, state_dict): + context_dim = None + use_linear_in_transformer = False + + transformer_prefix = prefix + "1.transformer_blocks." + transformer_keys = sorted(list(filter(lambda a: a.startswith(transformer_prefix), state_dict_keys))) + if len(transformer_keys) > 0: + last_transformer_depth = count_blocks(state_dict_keys, transformer_prefix + '{}') + context_dim = state_dict['{}0.attn2.to_k.weight'.format(transformer_prefix)].shape[1] + use_linear_in_transformer = len(state_dict['{}1.proj_in.weight'.format(prefix)].shape) == 2 + return last_transformer_depth, context_dim, use_linear_in_transformer + return None + def detect_unet_config(state_dict, key_prefix, dtype): state_dict_keys = list(state_dict.keys()) @@ -40,6 +53,7 @@ def detect_unet_config(state_dict, key_prefix, dtype): channel_mult = [] attention_resolutions = [] transformer_depth = [] + transformer_depth_output = [] context_dim = None use_linear_in_transformer = False @@ -48,60 +62,67 @@ def detect_unet_config(state_dict, key_prefix, dtype): count = 0 last_res_blocks = 0 - last_transformer_depth = 0 last_channel_mult = 0 - while True: + input_block_count = count_blocks(state_dict_keys, '{}input_blocks'.format(key_prefix) + '.{}.') + for count in range(input_block_count): prefix = '{}input_blocks.{}.'.format(key_prefix, count) + prefix_output = '{}output_blocks.{}.'.format(key_prefix, input_block_count - count - 1) + block_keys = sorted(list(filter(lambda a: a.startswith(prefix), state_dict_keys))) if len(block_keys) == 0: break + block_keys_output = sorted(list(filter(lambda a: a.startswith(prefix_output), state_dict_keys))) + if "{}0.op.weight".format(prefix) in block_keys: #new layer - if last_transformer_depth > 0: - attention_resolutions.append(current_res) - transformer_depth.append(last_transformer_depth) num_res_blocks.append(last_res_blocks) channel_mult.append(last_channel_mult) current_res *= 2 last_res_blocks = 0 - last_transformer_depth = 0 last_channel_mult = 0 + out = calculate_transformer_depth(prefix_output, state_dict_keys, state_dict) + if out is not None: + transformer_depth_output.append(out[0]) + else: + transformer_depth_output.append(0) else: res_block_prefix = "{}0.in_layers.0.weight".format(prefix) if res_block_prefix in block_keys: last_res_blocks += 1 last_channel_mult = state_dict["{}0.out_layers.3.weight".format(prefix)].shape[0] // model_channels - transformer_prefix = prefix + "1.transformer_blocks." 
- transformer_keys = sorted(list(filter(lambda a: a.startswith(transformer_prefix), state_dict_keys))) - if len(transformer_keys) > 0: - last_transformer_depth = count_blocks(state_dict_keys, transformer_prefix + '{}') - if context_dim is None: - context_dim = state_dict['{}0.attn2.to_k.weight'.format(transformer_prefix)].shape[1] - use_linear_in_transformer = len(state_dict['{}1.proj_in.weight'.format(prefix)].shape) == 2 + out = calculate_transformer_depth(prefix, state_dict_keys, state_dict) + if out is not None: + transformer_depth.append(out[0]) + if context_dim is None: + context_dim = out[1] + use_linear_in_transformer = out[2] + else: + transformer_depth.append(0) + + res_block_prefix = "{}0.in_layers.0.weight".format(prefix_output) + if res_block_prefix in block_keys_output: + out = calculate_transformer_depth(prefix_output, state_dict_keys, state_dict) + if out is not None: + transformer_depth_output.append(out[0]) + else: + transformer_depth_output.append(0) - count += 1 - if last_transformer_depth > 0: - attention_resolutions.append(current_res) - transformer_depth.append(last_transformer_depth) num_res_blocks.append(last_res_blocks) channel_mult.append(last_channel_mult) - transformer_depth_middle = count_blocks(state_dict_keys, '{}middle_block.1.transformer_blocks.'.format(key_prefix) + '{}') - - if len(set(num_res_blocks)) == 1: - num_res_blocks = num_res_blocks[0] - - if len(set(transformer_depth)) == 1: - transformer_depth = transformer_depth[0] + if "{}middle_block.1.proj_in.weight".format(key_prefix) in state_dict_keys: + transformer_depth_middle = count_blocks(state_dict_keys, '{}middle_block.1.transformer_blocks.'.format(key_prefix) + '{}') + else: + transformer_depth_middle = -1 unet_config["in_channels"] = in_channels unet_config["model_channels"] = model_channels unet_config["num_res_blocks"] = num_res_blocks - unet_config["attention_resolutions"] = attention_resolutions unet_config["transformer_depth"] = transformer_depth + unet_config["transformer_depth_output"] = transformer_depth_output unet_config["channel_mult"] = channel_mult unet_config["transformer_depth_middle"] = transformer_depth_middle unet_config['use_linear_in_transformer'] = use_linear_in_transformer @@ -124,6 +145,45 @@ def model_config_from_unet(state_dict, unet_key_prefix, dtype, use_base_if_no_ma else: return model_config +def convert_config(unet_config): + new_config = unet_config.copy() + num_res_blocks = new_config.get("num_res_blocks", None) + channel_mult = new_config.get("channel_mult", None) + + if isinstance(num_res_blocks, int): + num_res_blocks = len(channel_mult) * [num_res_blocks] + + if "attention_resolutions" in new_config: + attention_resolutions = new_config.pop("attention_resolutions") + transformer_depth = new_config.get("transformer_depth", None) + transformer_depth_middle = new_config.get("transformer_depth_middle", None) + + if isinstance(transformer_depth, int): + transformer_depth = len(channel_mult) * [transformer_depth] + if transformer_depth_middle is None: + transformer_depth_middle = transformer_depth[-1] + t_in = [] + t_out = [] + s = 1 + for i in range(len(num_res_blocks)): + res = num_res_blocks[i] + d = 0 + if s in attention_resolutions: + d = transformer_depth[i] + + t_in += [d] * res + t_out += [d] * (res + 1) + s *= 2 + transformer_depth = t_in + transformer_depth_output = t_out + new_config["transformer_depth"] = t_in + new_config["transformer_depth_output"] = t_out + new_config["transformer_depth_middle"] = transformer_depth_middle + + 
new_config["num_res_blocks"] = num_res_blocks + return new_config + + def unet_config_from_diffusers_unet(state_dict, dtype): match = {} attention_resolutions = [] @@ -200,7 +260,7 @@ def unet_config_from_diffusers_unet(state_dict, dtype): matches = False break if matches: - return unet_config + return convert_config(unet_config) return None def model_config_from_diffusers_unet(state_dict, dtype): diff --git a/backend/headless/fcbh/model_management.py b/backend/headless/fcbh/model_management.py index d6cee6c3..75108eed 100644 --- a/backend/headless/fcbh/model_management.py +++ b/backend/headless/fcbh/model_management.py @@ -339,7 +339,11 @@ def free_memory(memory_required, device, keep_loaded=[]): if unloaded_model: soft_empty_cache() - + else: + if vram_state != VRAMState.HIGH_VRAM: + mem_free_total, mem_free_torch = get_free_memory(device, torch_free_too=True) + if mem_free_torch > mem_free_total * 0.25: + soft_empty_cache() def load_models_gpu(models, memory_required=0): global vram_state diff --git a/backend/headless/fcbh/sample.py b/backend/headless/fcbh/sample.py index b6e0fddc..55946160 100644 --- a/backend/headless/fcbh/sample.py +++ b/backend/headless/fcbh/sample.py @@ -1,6 +1,7 @@ import torch import fcbh.model_management import fcbh.samplers +import fcbh.conds import fcbh.utils import math import numpy as np @@ -33,22 +34,24 @@ def prepare_mask(noise_mask, shape, device): noise_mask = noise_mask.to(device) return noise_mask -def broadcast_cond(cond, batch, device): - """broadcasts conditioning to the batch size""" - copy = [] - for p in cond: - t = fcbh.utils.repeat_to_batch_size(p[0], batch) - t = t.to(device) - copy += [[t] + p[1:]] - return copy - def get_models_from_cond(cond, model_type): models = [] for c in cond: - if model_type in c[1]: - models += [c[1][model_type]] + if model_type in c: + models += [c[model_type]] return models +def convert_cond(cond): + out = [] + for c in cond: + temp = c[1].copy() + model_conds = temp.get("model_conds", {}) + if c[0] is not None: + model_conds["c_crossattn"] = fcbh.conds.CONDCrossAttn(c[0]) + temp["model_conds"] = model_conds + out.append(temp) + return out + def get_additional_models(positive, negative, dtype): """loads additional models in positive and negative conditioning""" control_nets = set(get_models_from_cond(positive, "control") + get_models_from_cond(negative, "control")) @@ -72,6 +75,8 @@ def cleanup_additional_models(models): def prepare_sampling(model, noise_shape, positive, negative, noise_mask): device = model.load_device + positive = convert_cond(positive) + negative = convert_cond(negative) if noise_mask is not None: noise_mask = prepare_mask(noise_mask, noise_shape, device) @@ -81,9 +86,7 @@ def prepare_sampling(model, noise_shape, positive, negative, noise_mask): fcbh.model_management.load_models_gpu([model] + models, fcbh.model_management.batch_area_memory(noise_shape[0] * noise_shape[2] * noise_shape[3]) + inference_memory) real_model = model.model - positive_copy = broadcast_cond(positive, noise_shape[0], device) - negative_copy = broadcast_cond(negative, noise_shape[0], device) - return real_model, positive_copy, negative_copy, noise_mask, models + return real_model, positive, negative, noise_mask, models def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False, noise_mask=None, sigmas=None, callback=None, disable_pbar=False, seed=None): diff --git 
a/backend/headless/fcbh/samplers.py b/backend/headless/fcbh/samplers.py index fe414995..91050a4e 100644 --- a/backend/headless/fcbh/samplers.py +++ b/backend/headless/fcbh/samplers.py @@ -2,47 +2,44 @@ from .k_diffusion import sampling as k_diffusion_sampling from .k_diffusion import external as k_diffusion_external from .extra_samplers import uni_pc import torch +import enum from fcbh import model_management from .ldm.models.diffusion.ddim import DDIMSampler from .ldm.modules.diffusionmodules.util import make_ddim_timesteps import math from fcbh import model_base import fcbh.utils +import fcbh.conds -def lcm(a, b): #TODO: eventually replace by math.lcm (added in python3.9) - return abs(a*b) // math.gcd(a, b) #The main sampling function shared by all the samplers #Returns predicted noise def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, model_options={}, seed=None): - def get_area_and_mult(cond, x_in, timestep_in): + def get_area_and_mult(conds, x_in, timestep_in): area = (x_in.shape[2], x_in.shape[3], 0, 0) strength = 1.0 - if 'timestep_start' in cond[1]: - timestep_start = cond[1]['timestep_start'] + + if 'timestep_start' in conds: + timestep_start = conds['timestep_start'] if timestep_in[0] > timestep_start: return None - if 'timestep_end' in cond[1]: - timestep_end = cond[1]['timestep_end'] + if 'timestep_end' in conds: + timestep_end = conds['timestep_end'] if timestep_in[0] < timestep_end: return None - if 'area' in cond[1]: - area = cond[1]['area'] - if 'strength' in cond[1]: - strength = cond[1]['strength'] - - adm_cond = None - if 'adm_encoded' in cond[1]: - adm_cond = cond[1]['adm_encoded'] + if 'area' in conds: + area = conds['area'] + if 'strength' in conds: + strength = conds['strength'] input_x = x_in[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]] - if 'mask' in cond[1]: + if 'mask' in conds: # Scale the mask to the size of the input # The mask should have been resized as we began the sampling process mask_strength = 1.0 - if "mask_strength" in cond[1]: - mask_strength = cond[1]["mask_strength"] - mask = cond[1]['mask'] + if "mask_strength" in conds: + mask_strength = conds["mask_strength"] + mask = conds['mask'] assert(mask.shape[1] == x_in.shape[2]) assert(mask.shape[2] == x_in.shape[3]) mask = mask[:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]] * mask_strength @@ -51,7 +48,7 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, mod mask = torch.ones_like(input_x) mult = mask * strength - if 'mask' not in cond[1]: + if 'mask' not in conds: rr = 8 if area[2] != 0: for t in range(rr): @@ -67,27 +64,17 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, mod mult[:,:,:,area[1] - 1 - t:area[1] - t] *= ((1.0/rr) * (t + 1)) conditionning = {} - conditionning['c_crossattn'] = cond[0] - - if 'concat' in cond[1]: - cond_concat_in = cond[1]['concat'] - if cond_concat_in is not None and len(cond_concat_in) > 0: - cropped = [] - for x in cond_concat_in: - cr = x[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]] - cropped.append(cr) - conditionning['c_concat'] = torch.cat(cropped, dim=1) - - if adm_cond is not None: - conditionning['c_adm'] = adm_cond + model_conds = conds["model_conds"] + for c in model_conds: + conditionning[c] = model_conds[c].process_cond(batch_size=x_in.shape[0], device=x_in.device, area=area) control = None - if 'control' in cond[1]: - control = cond[1]['control'] + if 'control' in conds: + control = conds['control'] patches = None - if 'gligen' in cond[1]: - 
gligen = cond[1]['gligen'] + if 'gligen' in conds: + gligen = conds['gligen'] patches = {} gligen_type = gligen[0] gligen_model = gligen[1] @@ -105,22 +92,8 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, mod return True if c1.keys() != c2.keys(): return False - if 'c_crossattn' in c1: - s1 = c1['c_crossattn'].shape - s2 = c2['c_crossattn'].shape - if s1 != s2: - if s1[0] != s2[0] or s1[2] != s2[2]: #these 2 cases should not happen - return False - - mult_min = lcm(s1[1], s2[1]) - diff = mult_min // min(s1[1], s2[1]) - if diff > 4: #arbitrary limit on the padding because it's probably going to impact performance negatively if it's too much - return False - if 'c_concat' in c1: - if c1['c_concat'].shape != c2['c_concat'].shape: - return False - if 'c_adm' in c1: - if c1['c_adm'].shape != c2['c_adm'].shape: + for k in c1: + if not c1[k].can_concat(c2[k]): return False return True @@ -149,31 +122,19 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, mod c_concat = [] c_adm = [] crossattn_max_len = 0 - for x in c_list: - if 'c_crossattn' in x: - c = x['c_crossattn'] - if crossattn_max_len == 0: - crossattn_max_len = c.shape[1] - else: - crossattn_max_len = lcm(crossattn_max_len, c.shape[1]) - c_crossattn.append(c) - if 'c_concat' in x: - c_concat.append(x['c_concat']) - if 'c_adm' in x: - c_adm.append(x['c_adm']) - out = {} - c_crossattn_out = [] - for c in c_crossattn: - if c.shape[1] < crossattn_max_len: - c = c.repeat(1, crossattn_max_len // c.shape[1], 1) #padding with repeat doesn't change result - c_crossattn_out.append(c) - if len(c_crossattn_out) > 0: - out['c_crossattn'] = torch.cat(c_crossattn_out) - if len(c_concat) > 0: - out['c_concat'] = torch.cat(c_concat) - if len(c_adm) > 0: - out['c_adm'] = torch.cat(c_adm) + temp = {} + for x in c_list: + for k in x: + cur = temp.get(k, []) + cur.append(x[k]) + temp[k] = cur + + out = {} + for k in temp: + conds = temp[k] + out[k] = conds[0].concat(conds[1:]) + return out def calc_cond_uncond_batch(model_function, cond, uncond, x_in, timestep, max_total_area, model_options): @@ -389,19 +350,19 @@ def resolve_areas_and_cond_masks(conditions, h, w, device): # While we're doing this, we can also resolve the mask device and scaling for performance reasons for i in range(len(conditions)): c = conditions[i] - if 'area' in c[1]: - area = c[1]['area'] + if 'area' in c: + area = c['area'] if area[0] == "percentage": - modified = c[1].copy() + modified = c.copy() area = (max(1, round(area[1] * h)), max(1, round(area[2] * w)), round(area[3] * h), round(area[4] * w)) modified['area'] = area - c = [c[0], modified] + c = modified conditions[i] = c - if 'mask' in c[1]: - mask = c[1]['mask'] + if 'mask' in c: + mask = c['mask'] mask = mask.to(device=device) - modified = c[1].copy() + modified = c.copy() if len(mask.shape) == 2: mask = mask.unsqueeze(0) if mask.shape[1] != h or mask.shape[2] != w: @@ -422,37 +383,39 @@ def resolve_areas_and_cond_masks(conditions, h, w, device): modified['area'] = area modified['mask'] = mask - conditions[i] = [c[0], modified] + conditions[i] = modified def create_cond_with_same_area_if_none(conds, c): - if 'area' not in c[1]: + if 'area' not in c: return - c_area = c[1]['area'] + c_area = c['area'] smallest = None for x in conds: - if 'area' in x[1]: - a = x[1]['area'] + if 'area' in x: + a = x['area'] if c_area[2] >= a[2] and c_area[3] >= a[3]: if a[0] + a[2] >= c_area[0] + c_area[2]: if a[1] + a[3] >= c_area[1] + c_area[3]: if smallest is None: smallest = x - 
elif 'area' not in smallest[1]: + elif 'area' not in smallest: smallest = x else: - if smallest[1]['area'][0] * smallest[1]['area'][1] > a[0] * a[1]: + if smallest['area'][0] * smallest['area'][1] > a[0] * a[1]: smallest = x else: if smallest is None: smallest = x if smallest is None: return - if 'area' in smallest[1]: - if smallest[1]['area'] == c_area: + if 'area' in smallest: + if smallest['area'] == c_area: return - n = c[1].copy() - conds += [[smallest[0], n]] + + out = c.copy() + out['model_conds'] = smallest['model_conds'].copy() #TODO: which fields should be copied? + conds += [out] def calculate_start_end_timesteps(model, conds): for t in range(len(conds)): @@ -460,18 +423,18 @@ def calculate_start_end_timesteps(model, conds): timestep_start = None timestep_end = None - if 'start_percent' in x[1]: - timestep_start = model.sigma_to_t(model.t_to_sigma(torch.tensor(x[1]['start_percent'] * 999.0))) - if 'end_percent' in x[1]: - timestep_end = model.sigma_to_t(model.t_to_sigma(torch.tensor(x[1]['end_percent'] * 999.0))) + if 'start_percent' in x: + timestep_start = model.sigma_to_t(model.t_to_sigma(torch.tensor(x['start_percent'] * 999.0))) + if 'end_percent' in x: + timestep_end = model.sigma_to_t(model.t_to_sigma(torch.tensor(x['end_percent'] * 999.0))) if (timestep_start is not None) or (timestep_end is not None): - n = x[1].copy() + n = x.copy() if (timestep_start is not None): n['timestep_start'] = timestep_start if (timestep_end is not None): n['timestep_end'] = timestep_end - conds[t] = [x[0], n] + conds[t] = n def pre_run_control(model, conds): for t in range(len(conds)): @@ -480,8 +443,8 @@ def pre_run_control(model, conds): timestep_start = None timestep_end = None percent_to_timestep_function = lambda a: model.sigma_to_t(model.t_to_sigma(torch.tensor(a) * 999.0)) - if 'control' in x[1]: - x[1]['control'].pre_run(model.inner_model.inner_model, percent_to_timestep_function) + if 'control' in x: + x['control'].pre_run(model.inner_model.inner_model, percent_to_timestep_function) def apply_empty_x_to_equal_area(conds, uncond, name, uncond_fill_func): cond_cnets = [] @@ -490,16 +453,16 @@ def apply_empty_x_to_equal_area(conds, uncond, name, uncond_fill_func): uncond_other = [] for t in range(len(conds)): x = conds[t] - if 'area' not in x[1]: - if name in x[1] and x[1][name] is not None: - cond_cnets.append(x[1][name]) + if 'area' not in x: + if name in x and x[name] is not None: + cond_cnets.append(x[name]) else: cond_other.append((x, t)) for t in range(len(uncond)): x = uncond[t] - if 'area' not in x[1]: - if name in x[1] and x[1][name] is not None: - uncond_cnets.append(x[1][name]) + if 'area' not in x: + if name in x and x[name] is not None: + uncond_cnets.append(x[name]) else: uncond_other.append((x, t)) @@ -509,47 +472,35 @@ def apply_empty_x_to_equal_area(conds, uncond, name, uncond_fill_func): for x in range(len(cond_cnets)): temp = uncond_other[x % len(uncond_other)] o = temp[0] - if name in o[1] and o[1][name] is not None: - n = o[1].copy() + if name in o and o[name] is not None: + n = o.copy() n[name] = uncond_fill_func(cond_cnets, x) - uncond += [[o[0], n]] + uncond += [n] else: - n = o[1].copy() + n = o.copy() n[name] = uncond_fill_func(cond_cnets, x) - uncond[temp[1]] = [o[0], n] + uncond[temp[1]] = n -def encode_adm(model, conds, batch_size, width, height, device, prompt_type): +def encode_model_conds(model_function, conds, noise, device, prompt_type, **kwargs): for t in range(len(conds)): x = conds[t] - adm_out = None - if 'adm' in x[1]: - adm_out = x[1]["adm"] - 
else: - params = x[1].copy() - params["width"] = params.get("width", width * 8) - params["height"] = params.get("height", height * 8) - params["prompt_type"] = params.get("prompt_type", prompt_type) - adm_out = model.encode_adm(device=device, **params) - - if adm_out is not None: - x[1] = x[1].copy() - x[1]["adm_encoded"] = fcbh.utils.repeat_to_batch_size(adm_out, batch_size).to(device) - - return conds - -def encode_cond(model_function, key, conds, device, **kwargs): - for t in range(len(conds)): - x = conds[t] - params = x[1].copy() + params = x.copy() params["device"] = device + params["noise"] = noise + params["width"] = params.get("width", noise.shape[3] * 8) + params["height"] = params.get("height", noise.shape[2] * 8) + params["prompt_type"] = params.get("prompt_type", prompt_type) for k in kwargs: if k not in params: params[k] = kwargs[k] out = model_function(**params) - if out is not None: - x[1] = x[1].copy() - x[1][key] = out + x = x.copy() + model_conds = x['model_conds'].copy() + for k in out: + model_conds[k] = out[k] + x['model_conds'] = model_conds + conds[t] = x return conds class Sampler: @@ -667,19 +618,15 @@ def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model pre_run_control(model_wrap, negative + positive) - apply_empty_x_to_equal_area(list(filter(lambda c: c[1].get('control_apply_to_uncond', False) == True, positive)), negative, 'control', lambda cond_cnets, x: cond_cnets[x]) + apply_empty_x_to_equal_area(list(filter(lambda c: c.get('control_apply_to_uncond', False) == True, positive)), negative, 'control', lambda cond_cnets, x: cond_cnets[x]) apply_empty_x_to_equal_area(positive, negative, 'gligen', lambda cond_cnets, x: cond_cnets[x]) if latent_image is not None: latent_image = model.process_latent_in(latent_image) - if model.is_adm(): - positive = encode_adm(model, positive, noise.shape[0], noise.shape[3], noise.shape[2], device, "positive") - negative = encode_adm(model, negative, noise.shape[0], noise.shape[3], noise.shape[2], device, "negative") - - if hasattr(model, 'cond_concat'): - positive = encode_cond(model.cond_concat, "concat", positive, device, noise=noise, latent_image=latent_image, denoise_mask=denoise_mask) - negative = encode_cond(model.cond_concat, "concat", negative, device, noise=noise, latent_image=latent_image, denoise_mask=denoise_mask) + if hasattr(model, 'extra_conds'): + positive = encode_model_conds(model.extra_conds, positive, noise, device, "positive", latent_image=latent_image, denoise_mask=denoise_mask) + negative = encode_model_conds(model.extra_conds, negative, noise, device, "negative", latent_image=latent_image, denoise_mask=denoise_mask) extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg, "model_options": model_options, "seed":seed} diff --git a/backend/headless/fcbh/sd.py b/backend/headless/fcbh/sd.py index 5f1f0c6b..0982446b 100644 --- a/backend/headless/fcbh/sd.py +++ b/backend/headless/fcbh/sd.py @@ -360,7 +360,7 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl from . 
import latent_formats model_config.latent_format = latent_formats.SD15(scale_factor=scale_factor) - model_config.unet_config = unet_config + model_config.unet_config = model_detection.convert_config(unet_config) if config['model']["target"].endswith("ImageEmbeddingConditionedLatentDiffusion"): model = model_base.SD21UNCLIP(model_config, noise_aug_config["params"], model_type=model_type) @@ -388,11 +388,13 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl if clip_config["target"].endswith("FrozenOpenCLIPEmbedder"): clip_target.clip = sd2_clip.SD2ClipModel clip_target.tokenizer = sd2_clip.SD2Tokenizer + clip = CLIP(clip_target, embedding_directory=embedding_directory) + w.cond_stage_model = clip.cond_stage_model.clip_h elif clip_config["target"].endswith("FrozenCLIPEmbedder"): clip_target.clip = sd1_clip.SD1ClipModel clip_target.tokenizer = sd1_clip.SD1Tokenizer - clip = CLIP(clip_target, embedding_directory=embedding_directory) - w.cond_stage_model = clip.cond_stage_model + clip = CLIP(clip_target, embedding_directory=embedding_directory) + w.cond_stage_model = clip.cond_stage_model.clip_l load_clip_weights(w, state_dict) return (fcbh.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device), clip, vae) diff --git a/backend/headless/fcbh/sd1_clip.py b/backend/headless/fcbh/sd1_clip.py index 45382b00..56beb81c 100644 --- a/backend/headless/fcbh/sd1_clip.py +++ b/backend/headless/fcbh/sd1_clip.py @@ -35,7 +35,7 @@ class ClipTokenWeightEncoder: return z_empty.cpu(), first_pooled.cpu() return torch.cat(output, dim=-2).cpu(), first_pooled.cpu() -class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder): +class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): """Uses the CLIP transformer encoder for text (from huggingface)""" LAYERS = [ "last", @@ -278,7 +278,13 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No valid_file = None for embed_dir in embedding_directory: - embed_path = os.path.join(embed_dir, embedding_name) + embed_path = os.path.abspath(os.path.join(embed_dir, embedding_name)) + embed_dir = os.path.abspath(embed_dir) + try: + if os.path.commonpath((embed_dir, embed_path)) != embed_dir: + continue + except: + continue if not os.path.isfile(embed_path): extensions = ['.safetensors', '.pt', '.bin'] for x in extensions: @@ -336,7 +342,7 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No embed_out = next(iter(values)) return embed_out -class SD1Tokenizer: +class SDTokenizer: def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l'): if tokenizer_path is None: tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer") @@ -448,3 +454,40 @@ class SD1Tokenizer: def untokenize(self, token_weight_pair): return list(map(lambda a: (a, self.inv_vocab[a[0]]), token_weight_pair)) + + +class SD1Tokenizer: + def __init__(self, embedding_directory=None, clip_name="l", tokenizer=SDTokenizer): + self.clip_name = clip_name + self.clip = "clip_{}".format(self.clip_name) + setattr(self, self.clip, tokenizer(embedding_directory=embedding_directory)) + + def tokenize_with_weights(self, text:str, return_word_ids=False): + out = {} + out[self.clip_name] = getattr(self, self.clip).tokenize_with_weights(text, return_word_ids) + return out + + def untokenize(self, token_weight_pair): + return getattr(self, 
self.clip).untokenize(token_weight_pair) + + +class SD1ClipModel(torch.nn.Module): + def __init__(self, device="cpu", dtype=None, clip_name="l", clip_model=SDClipModel, **kwargs): + super().__init__() + self.clip_name = clip_name + self.clip = "clip_{}".format(self.clip_name) + setattr(self, self.clip, clip_model(device=device, dtype=dtype, **kwargs)) + + def clip_layer(self, layer_idx): + getattr(self, self.clip).clip_layer(layer_idx) + + def reset_clip_layer(self): + getattr(self, self.clip).reset_clip_layer() + + def encode_token_weights(self, token_weight_pairs): + token_weight_pairs = token_weight_pairs[self.clip_name] + out, pooled = getattr(self, self.clip).encode_token_weights(token_weight_pairs) + return out, pooled + + def load_sd(self, sd): + return getattr(self, self.clip).load_sd(sd) diff --git a/backend/headless/fcbh/sd2_clip.py b/backend/headless/fcbh/sd2_clip.py index e5cac64b..052fe9ba 100644 --- a/backend/headless/fcbh/sd2_clip.py +++ b/backend/headless/fcbh/sd2_clip.py @@ -2,7 +2,7 @@ from fcbh import sd1_clip import torch import os -class SD2ClipModel(sd1_clip.SD1ClipModel): +class SD2ClipHModel(sd1_clip.SDClipModel): def __init__(self, arch="ViT-H-14", device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, textmodel_path=None, dtype=None): if layer == "penultimate": layer="hidden" @@ -12,6 +12,14 @@ class SD2ClipModel(sd1_clip.SD1ClipModel): super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, textmodel_path=textmodel_path, dtype=dtype) self.empty_tokens = [[49406] + [49407] + [0] * 75] -class SD2Tokenizer(sd1_clip.SD1Tokenizer): +class SD2ClipHTokenizer(sd1_clip.SDTokenizer): def __init__(self, tokenizer_path=None, embedding_directory=None): super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=1024) + +class SD2Tokenizer(sd1_clip.SD1Tokenizer): + def __init__(self, embedding_directory=None): + super().__init__(embedding_directory=embedding_directory, clip_name="h", tokenizer=SD2ClipHTokenizer) + +class SD2ClipModel(sd1_clip.SD1ClipModel): + def __init__(self, device="cpu", dtype=None, **kwargs): + super().__init__(device=device, dtype=dtype, clip_name="h", clip_model=SD2ClipHModel, **kwargs) diff --git a/backend/headless/fcbh/sdxl_clip.py b/backend/headless/fcbh/sdxl_clip.py index 2064ba41..b05005c4 100644 --- a/backend/headless/fcbh/sdxl_clip.py +++ b/backend/headless/fcbh/sdxl_clip.py @@ -2,7 +2,7 @@ from fcbh import sd1_clip import torch import os -class SDXLClipG(sd1_clip.SD1ClipModel): +class SDXLClipG(sd1_clip.SDClipModel): def __init__(self, device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, textmodel_path=None, dtype=None): if layer == "penultimate": layer="hidden" @@ -16,14 +16,14 @@ class SDXLClipG(sd1_clip.SD1ClipModel): def load_sd(self, sd): return super().load_sd(sd) -class SDXLClipGTokenizer(sd1_clip.SD1Tokenizer): +class SDXLClipGTokenizer(sd1_clip.SDTokenizer): def __init__(self, tokenizer_path=None, embedding_directory=None): super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=1280, embedding_key='clip_g') -class SDXLTokenizer(sd1_clip.SD1Tokenizer): +class SDXLTokenizer: def __init__(self, embedding_directory=None): - self.clip_l = sd1_clip.SD1Tokenizer(embedding_directory=embedding_directory) + self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory) self.clip_g = 
SDXLClipGTokenizer(embedding_directory=embedding_directory) def tokenize_with_weights(self, text:str, return_word_ids=False): @@ -38,7 +38,7 @@ class SDXLTokenizer(sd1_clip.SD1Tokenizer): class SDXLClipModel(torch.nn.Module): def __init__(self, device="cpu", dtype=None): super().__init__() - self.clip_l = sd1_clip.SD1ClipModel(layer="hidden", layer_idx=11, device=device, dtype=dtype) + self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=11, device=device, dtype=dtype) self.clip_l.layer_norm_hidden_state = False self.clip_g = SDXLClipG(device=device, dtype=dtype) @@ -63,21 +63,6 @@ class SDXLClipModel(torch.nn.Module): else: return self.clip_l.load_sd(sd) -class SDXLRefinerClipModel(torch.nn.Module): +class SDXLRefinerClipModel(sd1_clip.SD1ClipModel): def __init__(self, device="cpu", dtype=None): - super().__init__() - self.clip_g = SDXLClipG(device=device, dtype=dtype) - - def clip_layer(self, layer_idx): - self.clip_g.clip_layer(layer_idx) - - def reset_clip_layer(self): - self.clip_g.reset_clip_layer() - - def encode_token_weights(self, token_weight_pairs): - token_weight_pairs_g = token_weight_pairs["g"] - g_out, g_pooled = self.clip_g.encode_token_weights(token_weight_pairs_g) - return g_out, g_pooled - - def load_sd(self, sd): - return self.clip_g.load_sd(sd) + super().__init__(device=device, dtype=dtype, clip_name="g", clip_model=SDXLClipG) diff --git a/backend/headless/fcbh/supported_models.py b/backend/headless/fcbh/supported_models.py index bb8ae214..fdd4ea4f 100644 --- a/backend/headless/fcbh/supported_models.py +++ b/backend/headless/fcbh/supported_models.py @@ -38,8 +38,15 @@ class SD15(supported_models_base.BASE): if ids.dtype == torch.float32: state_dict['cond_stage_model.transformer.text_model.embeddings.position_ids'] = ids.round() + replace_prefix = {} + replace_prefix["cond_stage_model."] = "cond_stage_model.clip_l." + state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix) return state_dict + def process_clip_state_dict_for_saving(self, state_dict): + replace_prefix = {"clip_l.": "cond_stage_model."} + return utils.state_dict_prefix_replace(state_dict, replace_prefix) + def clip_target(self): return supported_models_base.ClipTarget(sd1_clip.SD1Tokenizer, sd1_clip.SD1ClipModel) @@ -62,12 +69,12 @@ class SD20(supported_models_base.BASE): return model_base.ModelType.EPS def process_clip_state_dict(self, state_dict): - state_dict = utils.transformers_convert(state_dict, "cond_stage_model.model.", "cond_stage_model.transformer.text_model.", 24) + state_dict = utils.transformers_convert(state_dict, "cond_stage_model.model.", "cond_stage_model.clip_h.transformer.text_model.", 24) return state_dict def process_clip_state_dict_for_saving(self, state_dict): replace_prefix = {} - replace_prefix[""] = "cond_stage_model.model." 
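Side note on the CLIP refactor in these hunks: every text encoder now lives behind a named attribute ("clip_l", "clip_h", "clip_g") on a thin wrapper class, so checkpoint loading and saving rely on renaming state-dict key prefixes rather than on hard-coded module paths. Below is a minimal sketch of that kind of prefix rewrite; `rename_prefixes` is an illustrative stand-in, not the actual `utils.state_dict_prefix_replace` from fcbh.

```python
# Illustrative sketch: rename state-dict keys by prefix, as the loading and
# saving hunks around here do for the new clip_l / clip_h wrappers.
def rename_prefixes(state_dict, replace_prefix):
    out = {}
    for key, value in state_dict.items():
        new_key = key
        for old, new in replace_prefix.items():
            if key.startswith(old):
                new_key = new + key[len(old):]
                break
        out[new_key] = value
    return out

sd = {"cond_stage_model.transformer.text_model.final_layer_norm.weight": 0}
print(rename_prefixes(sd, {"cond_stage_model.": "cond_stage_model.clip_l."}))
# {'cond_stage_model.clip_l.transformer.text_model.final_layer_norm.weight': 0}
```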
+ replace_prefix["clip_h"] = "cond_stage_model.model" state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix) state_dict = diffusers_convert.convert_text_enc_state_dict_v20(state_dict) return state_dict @@ -104,7 +111,7 @@ class SDXLRefiner(supported_models_base.BASE): "use_linear_in_transformer": True, "context_dim": 1280, "adm_in_channels": 2560, - "transformer_depth": [0, 4, 4, 0], + "transformer_depth": [0, 0, 4, 4, 4, 4, 0, 0], } latent_format = latent_formats.SDXL @@ -139,7 +146,7 @@ class SDXL(supported_models_base.BASE): unet_config = { "model_channels": 320, "use_linear_in_transformer": True, - "transformer_depth": [0, 2, 10], + "transformer_depth": [0, 0, 2, 2, 10, 10], "context_dim": 2048, "adm_in_channels": 2816 } @@ -165,6 +172,7 @@ class SDXL(supported_models_base.BASE): replace_prefix["conditioner.embedders.0.transformer.text_model"] = "cond_stage_model.clip_l.transformer.text_model" state_dict = utils.transformers_convert(state_dict, "conditioner.embedders.1.model.", "cond_stage_model.clip_g.transformer.text_model.", 32) keys_to_replace["conditioner.embedders.1.model.text_projection"] = "cond_stage_model.clip_g.text_projection" + keys_to_replace["conditioner.embedders.1.model.text_projection.weight"] = "cond_stage_model.clip_g.text_projection" keys_to_replace["conditioner.embedders.1.model.logit_scale"] = "cond_stage_model.clip_g.logit_scale" state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix) @@ -189,5 +197,14 @@ class SDXL(supported_models_base.BASE): def clip_target(self): return supported_models_base.ClipTarget(sdxl_clip.SDXLTokenizer, sdxl_clip.SDXLClipModel) +class SSD1B(SDXL): + unet_config = { + "model_channels": 320, + "use_linear_in_transformer": True, + "transformer_depth": [0, 0, 2, 2, 4, 4], + "context_dim": 2048, + "adm_in_channels": 2816 + } -models = [SD15, SD20, SD21UnclipL, SD21UnclipH, SDXLRefiner, SDXL] + +models = [SD15, SD20, SD21UnclipL, SD21UnclipH, SDXLRefiner, SDXL, SSD1B] diff --git a/backend/headless/fcbh/utils.py b/backend/headless/fcbh/utils.py index 2f50c821..5a694b14 100644 --- a/backend/headless/fcbh/utils.py +++ b/backend/headless/fcbh/utils.py @@ -170,25 +170,12 @@ UNET_MAP_BASIC = { def unet_to_diffusers(unet_config): num_res_blocks = unet_config["num_res_blocks"] - attention_resolutions = unet_config["attention_resolutions"] channel_mult = unet_config["channel_mult"] - transformer_depth = unet_config["transformer_depth"] + transformer_depth = unet_config["transformer_depth"][:] + transformer_depth_output = unet_config["transformer_depth_output"][:] num_blocks = len(channel_mult) - if isinstance(num_res_blocks, int): - num_res_blocks = [num_res_blocks] * num_blocks - if isinstance(transformer_depth, int): - transformer_depth = [transformer_depth] * num_blocks - transformers_per_layer = [] - res = 1 - for i in range(num_blocks): - transformers = 0 - if res in attention_resolutions: - transformers = transformer_depth[i] - transformers_per_layer.append(transformers) - res *= 2 - - transformers_mid = unet_config.get("transformer_depth_middle", transformer_depth[-1]) + transformers_mid = unet_config.get("transformer_depth_middle", None) diffusers_unet_map = {} for x in range(num_blocks): @@ -196,10 +183,11 @@ def unet_to_diffusers(unet_config): for i in range(num_res_blocks[x]): for b in UNET_MAP_RESNET: diffusers_unet_map["down_blocks.{}.resnets.{}.{}".format(x, i, UNET_MAP_RESNET[b])] = "input_blocks.{}.0.{}".format(n, b) - if transformers_per_layer[x] > 0: + num_transformers = transformer_depth.pop(0) + 
if num_transformers > 0: for b in UNET_MAP_ATTENTIONS: diffusers_unet_map["down_blocks.{}.attentions.{}.{}".format(x, i, b)] = "input_blocks.{}.1.{}".format(n, b) - for t in range(transformers_per_layer[x]): + for t in range(num_transformers): for b in TRANSFORMER_BLOCKS: diffusers_unet_map["down_blocks.{}.attentions.{}.transformer_blocks.{}.{}".format(x, i, t, b)] = "input_blocks.{}.1.transformer_blocks.{}.{}".format(n, t, b) n += 1 @@ -218,7 +206,6 @@ def unet_to_diffusers(unet_config): diffusers_unet_map["mid_block.resnets.{}.{}".format(i, UNET_MAP_RESNET[b])] = "middle_block.{}.{}".format(n, b) num_res_blocks = list(reversed(num_res_blocks)) - transformers_per_layer = list(reversed(transformers_per_layer)) for x in range(num_blocks): n = (num_res_blocks[x] + 1) * x l = num_res_blocks[x] + 1 @@ -227,11 +214,12 @@ def unet_to_diffusers(unet_config): for b in UNET_MAP_RESNET: diffusers_unet_map["up_blocks.{}.resnets.{}.{}".format(x, i, UNET_MAP_RESNET[b])] = "output_blocks.{}.0.{}".format(n, b) c += 1 - if transformers_per_layer[x] > 0: + num_transformers = transformer_depth_output.pop() + if num_transformers > 0: c += 1 for b in UNET_MAP_ATTENTIONS: diffusers_unet_map["up_blocks.{}.attentions.{}.{}".format(x, i, b)] = "output_blocks.{}.1.{}".format(n, b) - for t in range(transformers_per_layer[x]): + for t in range(num_transformers): for b in TRANSFORMER_BLOCKS: diffusers_unet_map["up_blocks.{}.attentions.{}.transformer_blocks.{}.{}".format(x, i, t, b)] = "output_blocks.{}.1.transformer_blocks.{}.{}".format(n, t, b) if i == l - 1: diff --git a/backend/headless/fcbh_extras/nodes_hypertile.py b/backend/headless/fcbh_extras/nodes_hypertile.py new file mode 100644 index 00000000..0d7d4c95 --- /dev/null +++ b/backend/headless/fcbh_extras/nodes_hypertile.py @@ -0,0 +1,83 @@ +#Taken from: https://github.com/tfernd/HyperTile/ + +import math +from einops import rearrange +import random + +def random_divisor(value: int, min_value: int, /, max_options: int = 1, counter = 0) -> int: + min_value = min(min_value, value) + + # All big divisors of value (inclusive) + divisors = [i for i in range(min_value, value + 1) if value % i == 0] + + ns = [value // i for i in divisors[:max_options]] # has at least 1 element + + random.seed(counter) + idx = random.randint(0, len(ns) - 1) + + return ns[idx] + +class HyperTile: + @classmethod + def INPUT_TYPES(s): + return {"required": { "model": ("MODEL",), + "tile_size": ("INT", {"default": 256, "min": 1, "max": 2048}), + "swap_size": ("INT", {"default": 2, "min": 1, "max": 128}), + "max_depth": ("INT", {"default": 0, "min": 0, "max": 10}), + "scale_depth": ("BOOLEAN", {"default": False}), + }} + RETURN_TYPES = ("MODEL",) + FUNCTION = "patch" + + CATEGORY = "_for_testing" + + def patch(self, model, tile_size, swap_size, max_depth, scale_depth): + model_channels = model.model.model_config.unet_config["model_channels"] + + apply_to = set() + temp = model_channels + for x in range(max_depth + 1): + apply_to.add(temp) + temp *= 2 + + latent_tile_size = max(32, tile_size) // 8 + self.temp = None + self.counter = 1 + + def hypertile_in(q, k, v, extra_options): + if q.shape[-1] in apply_to: + shape = extra_options["original_shape"] + aspect_ratio = shape[-1] / shape[-2] + + hw = q.size(1) + h, w = round(math.sqrt(hw * aspect_ratio)), round(math.sqrt(hw / aspect_ratio)) + + factor = 2**((q.shape[-1] // model_channels) - 1) if scale_depth else 1 + nh = random_divisor(h, latent_tile_size * factor, swap_size, self.counter) + self.counter += 1 + nw = random_divisor(w, 
latent_tile_size * factor, swap_size, self.counter) + self.counter += 1 + + if nh * nw > 1: + q = rearrange(q, "b (nh h nw w) c -> (b nh nw) (h w) c", h=h // nh, w=w // nw, nh=nh, nw=nw) + self.temp = (nh, nw, h, w) + return q, k, v + + return q, k, v + def hypertile_out(out, extra_options): + if self.temp is not None: + nh, nw, h, w = self.temp + self.temp = None + out = rearrange(out, "(b nh nw) hw c -> b nh nw hw c", nh=nh, nw=nw) + out = rearrange(out, "b nh nw (h w) c -> b (nh h nw w) c", h=h // nh, w=w // nw) + return out + + + m = model.clone() + m.set_model_attn1_patch(hypertile_in) + m.set_model_attn1_output_patch(hypertile_out) + return (m, ) + +NODE_CLASS_MAPPINGS = { + "HyperTile": HyperTile, +} diff --git a/backend/headless/latent_preview.py b/backend/headless/latent_preview.py index 5b07078e..798c3aad 100644 --- a/backend/headless/latent_preview.py +++ b/backend/headless/latent_preview.py @@ -22,7 +22,7 @@ class TAESDPreviewerImpl(LatentPreviewer): self.taesd = taesd def decode_latent_to_preview(self, x0): - x_sample = self.taesd.decoder(x0)[0].detach() + x_sample = self.taesd.decoder(x0[:1])[0].detach() # x_sample = self.taesd.unscale_latents(x_sample).div(4).add(0.5) # returns value in [-2, 2] x_sample = x_sample.sub(0.5).mul(2) diff --git a/backend/headless/nodes.py b/backend/headless/nodes.py index 2c1a25f0..b57cd823 100644 --- a/backend/headless/nodes.py +++ b/backend/headless/nodes.py @@ -1796,7 +1796,8 @@ def init_custom_nodes(): "nodes_clip_sdxl.py", "nodes_canny.py", "nodes_freelunch.py", - "nodes_custom_sampler.py" + "nodes_custom_sampler.py", + "nodes_hypertile.py", ] for node_file in extras_files: diff --git a/colab_fix.txt b/colab_fix.txt deleted file mode 100644 index 7b2445ce..00000000 --- a/colab_fix.txt +++ /dev/null @@ -1 +0,0 @@ -{"default_refiner": ""} \ No newline at end of file diff --git a/expansion_experiments.py b/expansion_experiments.py new file mode 100644 index 00000000..5a2a946a --- /dev/null +++ b/expansion_experiments.py @@ -0,0 +1,8 @@ +from modules.expansion import FooocusExpansion + +expansion = FooocusExpansion() + +text = 'a handsome man' + +for i in range(64): + print(expansion(text, seed=i)) diff --git a/fooocus_colab.ipynb b/fooocus_colab.ipynb index c3497c68..205dac55 100644 --- a/fooocus_colab.ipynb +++ b/fooocus_colab.ipynb @@ -10,9 +10,8 @@ "source": [ "!pip install pygit2==1.12.2\n", "%cd /content\n", - "!git clone https://github.com/lllyasviel/Fooocus\n", + "!git clone https://github.com/lllyasviel/Fooocus.git\n", "%cd /content/Fooocus\n", - "!cp colab_fix.txt user_path_config.txt\n", "!python entry_with_update.py --share\n" ] } diff --git a/fooocus_extras/ip_adapter.py b/fooocus_extras/ip_adapter.py index ac2bed22..aeb7de2d 100644 --- a/fooocus_extras/ip_adapter.py +++ b/fooocus_extras/ip_adapter.py @@ -7,6 +7,7 @@ import fcbh.ldm.modules.attention as attention from fooocus_extras.resampler import Resampler from fcbh.model_patcher import ModelPatcher +from modules.core import numpy_to_pytorch SD_V12_CHANNELS = [320] * 4 + [640] * 4 + [1280] * 4 + [1280] * 6 + [640] * 6 + [320] * 6 + [1280] * 2 @@ -144,14 +145,27 @@ def load_ip_adapter(clip_vision_path, ip_negative_path, ip_adapter_path): return +@torch.no_grad() +@torch.inference_mode() +def clip_preprocess(image): + mean = torch.tensor([0.48145466, 0.4578275, 0.40821073], device=image.device, dtype=image.dtype).view([1, 3, 1, 1]) + std = torch.tensor([0.26862954, 0.26130258, 0.27577711], device=image.device, dtype=image.dtype).view([1, 3, 1, 1]) + image = image.movedim(-1, 1) + 
+ # https://github.com/tencent-ailab/IP-Adapter/blob/d580c50a291566bbf9fc7ac0f760506607297e6d/README.md?plain=1#L75 + B, C, H, W = image.shape + assert H == 224 and W == 224 + + return (image - mean) / std + + @torch.no_grad() @torch.inference_mode() def preprocess(img): global ip_unconds - inputs = clip_vision.processor(images=img, return_tensors="pt") fcbh.model_management.load_model_gpu(clip_vision.patcher) - pixel_values = inputs['pixel_values'].to(clip_vision.load_device) + pixel_values = clip_preprocess(numpy_to_pytorch(img).to(clip_vision.load_device)) if clip_vision.dtype != torch.float32: precision_scope = torch.autocast @@ -162,9 +176,11 @@ def preprocess(img): outputs = clip_vision.model(pixel_values=pixel_values, output_hidden_states=True) if ip_adapter.plus: - cond = outputs.hidden_states[-2].to(ip_adapter.dtype) + cond = outputs.hidden_states[-2] else: - cond = outputs.image_embeds.to(ip_adapter.dtype) + cond = outputs.image_embeds + + cond = cond.to(device=ip_adapter.load_device, dtype=ip_adapter.dtype) fcbh.model_management.load_model_gpu(image_proj_model) cond = image_proj_model.model(cond).to(device=ip_adapter.load_device, dtype=ip_adapter.dtype) diff --git a/fooocus_extras/vae_interpose.py b/fooocus_extras/vae_interpose.py index 41f81927..b069b2ff 100644 --- a/fooocus_extras/vae_interpose.py +++ b/fooocus_extras/vae_interpose.py @@ -69,7 +69,7 @@ vae_approx_filename = os.path.join(vae_approx_path, 'xl-to-v1_interposer-v3.1.sa def parse(x): global vae_approx_model - x_origin = x['samples'].clone() + x_origin = x.clone() if vae_approx_model is None: model = Interposer() @@ -89,6 +89,5 @@ def parse(x): fcbh.model_management.load_model_gpu(vae_approx_model) x = x_origin.to(device=vae_approx_model.load_device, dtype=vae_approx_model.dtype) - x = vae_approx_model.model(x) - - return {'samples': x.to(x_origin)} + x = vae_approx_model.model(x).to(x_origin) + return x diff --git a/fooocus_version.py b/fooocus_version.py index 8d20af60..80f1f10e 100644 --- a/fooocus_version.py +++ b/fooocus_version.py @@ -1 +1 @@ -version = '2.1.722' +version = '2.1.774' diff --git a/models/prompt_expansion/fooocus_expansion/positive.txt b/models/prompt_expansion/fooocus_expansion/positive.txt new file mode 100644 index 00000000..9be121d8 --- /dev/null +++ b/models/prompt_expansion/fooocus_expansion/positive.txt @@ -0,0 +1,642 @@ +abundant +accelerated +accepted +accepting +acclaimed +accomplished +acknowledged +activated +adapted +adjusted +admirable +adorable +adorned +advanced +adventurous +advocated +aesthetic +affirmed +affluent +agile +aimed +aligned +alive +altered +amazing +ambient +amplified +analytical +animated +appealing +applauded +appreciated +ardent +aromatic +arranged +arresting +articulate +artistic +associated +assured +astonishing +astounding +atmosphere +attempted +attentive +attractive +authentic +authoritative +awarded +awesome +backed +background +baked +balance +balanced +balancing +beaten +beautiful +beloved +beneficial +benevolent +best +bestowed +blazing +blended +blessed +boosted +borne +brave +breathtaking +brewed +bright +brilliant +brought +built +burning +calm +calmed +candid +caring +carried +catchy +celebrated +celestial +certain +championed +changed +charismatic +charming +chased +cheered +cheerful +cherished +chic +chosen +cinematic +clad +classic +classy +clear +coached +coherent +collected +color +colorful +colors +colossal +combined +comforting +commanding +committed +compassionate +compatible +complete +complex +complimentary +composed +composition 
+comprehensive +conceived +conferred +confident +connected +considerable +considered +consistent +conspicuous +constructed +constructive +contemplated +contemporary +content +contrasted +conveyed +cooked +cool +coordinated +coupled +courageous +coveted +cozy +created +creative +credited +crisp +critical +cultivated +cured +curious +current +customized +cute +daring +darling +dazzling +decorated +decorative +dedicated +deep +defended +definitive +delicate +delightful +delivered +depicted +designed +desirable +desired +destined +detail +detailed +determined +developed +devoted +devout +diligent +direct +directed +discovered +dispatched +displayed +distilled +distinct +distinctive +distinguished +diverse +divine +dramatic +draped +dreamed +driven +dynamic +earnest +eased +ecstatic +educated +effective +elaborate +elegant +elevated +elite +eminent +emotional +empowered +empowering +enchanted +encouraged +endorsed +endowed +enduring +energetic +engaging +enhanced +enigmatic +enlightened +enormous +enticing +envisioned +epic +esteemed +eternal +everlasting +evolved +exalted +examining +excellent +exceptional +exciting +exclusive +exemplary +exotic +expansive +exposed +expressive +exquisite +extended +extraordinary +extremely +fabulous +facilitated +fair +faithful +famous +fancy +fantastic +fascinating +fashionable +fashioned +favorable +favored +fearless +fermented +fertile +festive +fiery +fine +finest +firm +fixed +flaming +flashing +flashy +flavored +flawless +flourishing +flowing +focus +focused +formal +formed +fortunate +fostering +frank +fresh +fried +friendly +fruitful +fulfilled +full +futuristic +generous +gentle +genuine +gifted +gigantic +glamorous +glorious +glossy +glowing +gorgeous +graceful +gracious +grand +granted +grateful +great +grilled +grounded +grown +guarded +guided +hailed +handsome +healing +healthy +heartfelt +heavenly +heroic +highly +historic +holistic +holy +honest +honored +hoped +hopeful +iconic +ideal +illuminated +illuminating +illumination +illustrious +imaginative +imagined +immense +immortal +imposing +impressive +improved +incredible +infinite +informed +ingenious +innocent +innovative +insightful +inspirational +inspired +inspiring +instructed +integrated +intense +intricate +intriguing +invaluable +invented +investigative +invincible +inviting +irresistible +joined +joyful +keen +kindly +kinetic +knockout +laced +lasting +lauded +lavish +legendary +lifted +light +limited +linked +lively +located +logical +loved +lovely +loving +loyal +lucid +lucky +lush +luxurious +luxury +magic +magical +magnificent +majestic +marked +marvelous +massive +matched +matured +meaningful +memorable +merged +merry +meticulous +mindful +miraculous +modern +modified +monstrous +monumental +motivated +motivational +moved +moving +mystical +mythical +naive +neat +new +nice +nifty +noble +notable +noteworthy +novel +nuanced +offered +open +optimal +optimistic +orderly +organized +original +originated +outstanding +overwhelming +paired +palpable +passionate +peaceful +perfect +perfected +perpetual +persistent +phenomenal +pious +pivotal +placed +planned +pleasant +pleased +pleasing +plentiful +plotted +plush +poetic +poignant +polished +positive +praised +precious +precise +premier +premium +presented +preserved +prestigious +pretty +priceless +prime +pristine +probing +productive +professional +profound +progressed +progressive +prominent +promoted +pronounced +propelled +proportional +prosperous +protected +provided +provocative +pure +pursued +pushed +quaint +quality +questioning 
+quiet +radiant +rare +rational +real +reborn +reclaimed +recognized +recovered +refined +reflected +refreshed +refreshing +related +relaxed +relentless +reliable +relieved +remarkable +renewed +renowned +representative +rescued +resilient +respected +respectful +restored +retrieved +revealed +revealing +revered +revived +rewarded +rich +roasted +robust +romantic +royal +sacred +salient +satisfied +satisfying +saturated +saved +scenic +scientific +select +sensational +serious +set +shaped +sharp +shielded +shining +shiny +shown +significant +silent +sincere +singular +situated +sleek +slick +smart +snug +solemn +solid +soothing +sophisticated +sought +sparkling +special +spectacular +sped +spirited +spiritual +splendid +spread +stable +steady +still +stimulated +stimulating +stirred +straightforward +striking +strong +structured +stunning +sturdy +stylish +sublime +successful +sunny +superb +superior +supplied +supported +supportive +supreme +sure +surreal +sweet +symbolic +symmetry +synchronized +systematic +tailored +taking +targeted +taught +tempting +tender +terrific +thankful +theatrical +thought +thoughtful +thrilled +thrilling +thriving +tidy +timeless +touching +tough +trained +tranquil +transformed +translucent +transparent +transported +tremendous +trendy +tried +trim +true +trustworthy +unbelievable +unconditional +uncovered +unified +unique +united +universal +unmatched +unparalleled +upheld +valiant +valued +varied +very +vibrant +virtuous +vivid +warm +wealthy +whole +winning +wished +witty +wonderful +worshipped +worthy diff --git a/modules/async_worker.py b/modules/async_worker.py index b7c3c18d..8b96b80d 100644 --- a/modules/async_worker.py +++ b/modules/async_worker.py @@ -3,12 +3,14 @@ import threading buffer = [] outputs = [] +global_results = [] def worker(): - global buffer, outputs + global buffer, outputs, global_results import traceback + import math import numpy as np import torch import time @@ -23,14 +25,15 @@ def worker(): import fcbh.model_management import fooocus_extras.preprocessors as preprocessors import modules.inpaint_worker as inpaint_worker + import modules.constants as constants import modules.advanced_parameters as advanced_parameters import fooocus_extras.ip_adapter as ip_adapter - from modules.sdxl_styles import apply_style, apply_wildcards, aspect_ratios, fooocus_expansion + from modules.sdxl_styles import apply_style, apply_wildcards, fooocus_expansion from modules.private_logger import log from modules.expansion import safe_str from modules.util import join_prompts, remove_empty_str, HWC3, resize_image, \ - get_image_shape_ceil, set_image_shape_ceil, get_shape_ceil + get_image_shape_ceil, set_image_shape_ceil, get_shape_ceil, resample_image from modules.upscaler import perform_upscale try: @@ -46,6 +49,60 @@ def worker(): print(f'[Fooocus] {text}') outputs.append(['preview', (number, text, None)]) + def yield_result(imgs, do_not_show_finished_images=False): + global global_results + + if not isinstance(imgs, list): + imgs = [imgs] + + global_results = global_results + imgs + + if do_not_show_finished_images: + return + + outputs.append(['results', global_results]) + return + + def build_image_wall(): + global global_results + + if len(global_results) < 2: + return + + for img in global_results: + if not isinstance(img, np.ndarray): + return + if img.ndim != 3: + return + + H, W, C = global_results[0].shape + + for img in global_results: + Hn, Wn, Cn = img.shape + if H != Hn: + return + if W != Wn: + return + if C != Cn: + return + + cols = 
float(len(global_results)) ** 0.5 + cols = int(math.ceil(cols)) + rows = float(len(global_results)) / float(cols) + rows = int(math.ceil(rows)) + + wall = np.zeros(shape=(H * rows, W * cols, C), dtype=np.uint8) + + for y in range(rows): + for x in range(cols): + if y * cols + x < len(global_results): + img = global_results[y * cols + x] + wall[y * H:y * H + H, x * W:x * W + W, :] = img + + # must use deep copy otherwise gradio is super laggy. Do not use list.append() . + global_results = global_results + [wall] + return + @torch.no_grad() @torch.inference_mode() def handler(args): @@ -64,6 +121,7 @@ def worker(): guidance_scale = args.pop() base_model_name = args.pop() refiner_model_name = args.pop() + refiner_switch = args.pop() loras = [(args.pop(), args.pop()) for _ in range(5)] input_image_checkbox = args.pop() current_tab = args.pop() @@ -112,7 +170,10 @@ def worker(): denoising_strength = 1.0 tiled = False inpaint_worker.current_task = None - width, height = aspect_ratios[aspect_ratios_selection] + + width, height = aspect_ratios_selection.split('×') + width, height = int(width), int(height) + skip_prompt_processing = False refiner_swap_method = advanced_parameters.refiner_swap_method @@ -123,20 +184,13 @@ def worker(): controlnet_cpds_path = None clip_vision_path, ip_negative_path, ip_adapter_path = None, None, None - seed = image_seed - max_seed = int(1024 * 1024 * 1024) - if not isinstance(seed, int): - seed = random.randint(1, max_seed) - if seed < 0: - seed = - seed - seed = seed % max_seed + seed = int(image_seed) + print(f'[Parameters] Seed = {seed}') if performance_selection == 'Speed': steps = 30 - switch = 20 else: steps = 60 - switch = 40 sampler_name = advanced_parameters.sampler_name scheduler_name = advanced_parameters.scheduler_name @@ -157,10 +211,8 @@ def worker(): else: if performance_selection == 'Speed': steps = 18 - switch = 12 else: steps = 36 - switch = 24 progressbar(1, 'Downloading upscale models ...') modules.path.downloading_upscale_model() if (current_tab == 'inpaint' or (current_tab == 'ip' and advanced_parameters.mixing_image_prompt_and_inpaint))\ @@ -175,7 +227,6 @@ def worker(): loras += [(inpaint_patch_model_path, 1.0)] print(f'[Inpaint] Current inpaint model is {inpaint_patch_model_path}') goals.append('inpaint') - sampler_name = 'dpmpp_2m_sde_gpu' # only support the patched dpmpp_2m_sde_gpu if current_tab == 'ip' or \ advanced_parameters.mixing_image_prompt_and_inpaint or \ advanced_parameters.mixing_image_prompt_and_vary_upscale: @@ -193,6 +244,8 @@ def worker(): pipeline.refresh_controlnets([controlnet_canny_path, controlnet_cpds_path]) ip_adapter.load_ip_adapter(clip_vision_path, ip_negative_path, ip_adapter_path) + switch = int(round(steps * refiner_switch)) + if advanced_parameters.overwrite_step > 0: steps = advanced_parameters.overwrite_step @@ -212,12 +265,16 @@ def worker(): if not skip_prompt_processing: - prompts = remove_empty_str([safe_str(p) for p in prompt.split('\n')], default='') - negative_prompts = remove_empty_str([safe_str(p) for p in negative_prompt.split('\n')], default='') + prompts = remove_empty_str([safe_str(p) for p in prompt.splitlines()], default='') + negative_prompts = remove_empty_str([safe_str(p) for p in negative_prompt.splitlines()], default='') prompt = prompts[0] negative_prompt = negative_prompts[0] + if prompt == '': + # disable expansion when empty since it is not meaningful and influences image prompt + use_expansion = False + extra_positive_prompts = prompts[1:] if len(prompts) > 1 else [] 
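For readers skimming the async_worker changes: `build_image_wall` above lays the finished images out on a near-square grid, taking the column count from the square root of the image count and adding rows to fit the remainder. A condensed, standalone restatement of that arithmetic, assuming equally sized uint8 images as the worker already checks; `grid_layout` and `make_wall` are illustrative names only.

```python
import math
import numpy as np

def grid_layout(n):
    # Same arithmetic as build_image_wall: near-square grid, columns first.
    cols = int(math.ceil(n ** 0.5))
    rows = int(math.ceil(n / cols))
    return rows, cols

def make_wall(images):
    # images: list of equally sized (H, W, C) uint8 arrays.
    H, W, C = images[0].shape
    rows, cols = grid_layout(len(images))
    wall = np.zeros((H * rows, W * cols, C), dtype=np.uint8)
    for idx, img in enumerate(images):
        y, x = divmod(idx, cols)
        wall[y * H:(y + 1) * H, x * W:(x + 1) * W] = img
    return wall

print(grid_layout(5))  # (2, 3): five images become a 3-wide, 2-high wall
```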
extra_negative_prompts = negative_prompts[1:] if len(negative_prompts) > 1 else [] @@ -227,7 +284,7 @@ def worker(): progressbar(3, 'Processing prompts ...') tasks = [] for i in range(image_number): - task_seed = seed + i + task_seed = (seed + i) % (constants.MAX_SEED + 1) # randint is inclusive, % is not task_rng = random.Random(task_seed) # may bind to inpaint noise in the future task_prompt = apply_wildcards(prompt, task_rng) @@ -241,8 +298,8 @@ def worker(): if use_style: for s in style_selections: p, n = apply_style(s, positive=task_prompt) - positive_basic_workloads.append(p) - negative_basic_workloads.append(n) + positive_basic_workloads = positive_basic_workloads + p + negative_basic_workloads = negative_basic_workloads + n else: positive_basic_workloads.append(task_prompt) @@ -273,9 +330,9 @@ def worker(): for i, t in enumerate(tasks): progressbar(5, f'Preparing Fooocus text #{i + 1} ...') expansion = pipeline.final_expansion(t['task_prompt'], t['task_seed']) - print(f'[Prompt Expansion] New suffix: {expansion}') + print(f'[Prompt Expansion] {expansion}') t['expansion'] = expansion - t['positive'] = copy.deepcopy(t['positive']) + [join_prompts(t['task_prompt'], expansion)] # Deep copy. + t['positive'] = copy.deepcopy(t['positive']) + [expansion] # Deep copy. for i, t in enumerate(tasks): progressbar(7, f'Encoding positive #{i + 1} ...') @@ -331,10 +388,14 @@ def worker(): f = 1.0 shape_ceil = get_shape_ceil(H * f, W * f) + if shape_ceil < 1024: print(f'[Upscale] Image is resized because it is too small.') + uov_input_image = set_image_shape_ceil(uov_input_image, 1024) shape_ceil = 1024 - uov_input_image = set_image_shape_ceil(uov_input_image, shape_ceil) + else: + uov_input_image = resample_image(uov_input_image, width=W * f, height=H * f) + image_is_super_large = shape_ceil > 2800 if 'fast' in uov_method: @@ -350,7 +411,7 @@ def worker(): if direct_return: d = [('Upscale (Fast)', '2x')] log(uov_input_image, d, single_line_number=1) - outputs.append(['results', [uov_input_image]]) + yield_result(uov_input_image, do_not_show_finished_images=True) return tiled = True @@ -402,7 +463,7 @@ def worker(): pipeline.final_unet.model.diffusion_model.in_inpaint = True if advanced_parameters.debugging_cn_preprocessor: - outputs.append(['results', inpaint_worker.current_task.visualize_mask_processing()]) + yield_result(inpaint_worker.current_task.visualize_mask_processing(), do_not_show_finished_images=True) return progressbar(13, 'VAE Inpaint encoding ...') @@ -448,7 +509,7 @@ def worker(): cn_img = HWC3(cn_img) task[0] = core.numpy_to_pytorch(cn_img) if advanced_parameters.debugging_cn_preprocessor: - outputs.append(['results', [cn_img]]) + yield_result(cn_img, do_not_show_finished_images=True) return for task in cn_tasks[flags.cn_cpds]: cn_img, cn_stop, cn_weight = task @@ -457,7 +518,7 @@ def worker(): cn_img = HWC3(cn_img) task[0] = core.numpy_to_pytorch(cn_img) if advanced_parameters.debugging_cn_preprocessor: - outputs.append(['results', [cn_img]]) + yield_result(cn_img, do_not_show_finished_images=True) return for task in cn_tasks[flags.cn_ip]: cn_img, cn_stop, cn_weight = task @@ -468,7 +529,7 @@ def worker(): task[0] = ip_adapter.preprocess(cn_img) if advanced_parameters.debugging_cn_preprocessor: - outputs.append(['results', [cn_img]]) + yield_result(cn_img, do_not_show_finished_images=True) return if len(cn_tasks[flags.cn_ip]) > 0: @@ -484,7 +545,6 @@ def worker(): advanced_parameters.freeu_s2 ) - results = [] all_steps = steps * image_number preparation_time = time.perf_counter() 
- execution_start_time @@ -560,7 +620,7 @@ def worker(): d.append((f'LoRA [{n}] weight', w)) log(x, d, single_line_number=3) - results += imgs + yield_result(imgs, do_not_show_finished_images=len(tasks) == 1) except fcbh.model_management.InterruptProcessingException as e: if shared.last_stop == 'skip': print('User skipped') @@ -572,9 +632,6 @@ def worker(): execution_time = time.perf_counter() - execution_start_time print(f'Generating and saving time: {execution_time:.2f} seconds') - outputs.append(['results', results]) - - pipeline.prepare_text_encoder(async_call=True) return while True: @@ -585,7 +642,11 @@ def worker(): handler(task) except: traceback.print_exc() - outputs.append(['results', []]) + if len(buffer) == 0: + build_image_wall() + outputs.append(['finish', global_results]) + global_results = [] + pipeline.prepare_text_encoder(async_call=True) pass diff --git a/modules/auth.py b/modules/auth.py new file mode 100644 index 00000000..3ba11142 --- /dev/null +++ b/modules/auth.py @@ -0,0 +1,41 @@ +import json +import hashlib +import modules.constants as constants + +from os.path import exists + + +def auth_list_to_dict(auth_list): + auth_dict = {} + for auth_data in auth_list: + if 'user' in auth_data: + if 'hash' in auth_data: + auth_dict |= {auth_data['user']: auth_data['hash']} + elif 'pass' in auth_data: + auth_dict |= {auth_data['user']: hashlib.sha256(bytes(auth_data['pass'], encoding='utf-8')).hexdigest()} + return auth_dict + + +def load_auth_data(filename=None): + auth_dict = None + if filename != None and exists(filename): + with open(filename, encoding='utf-8') as auth_file: + try: + auth_obj = json.load(auth_file) + if isinstance(auth_obj, list) and len(auth_obj) > 0: + auth_dict = auth_list_to_dict(auth_obj) + except Exception as e: + print('load_auth_data, e: ' + str(e)) + return auth_dict + + +auth_dict = load_auth_data(constants.AUTH_FILENAME) + +auth_enabled = auth_dict != None + + +def check_auth(user, password): + if user not in auth_dict: + return False + else: + return hashlib.sha256(bytes(password, encoding='utf-8')).hexdigest() == auth_dict[user] diff --git a/modules/constants.py b/modules/constants.py new file mode 100644 index 00000000..667fa868 --- /dev/null +++ b/modules/constants.py @@ -0,0 +1,5 @@ +# as in k-diffusion (sampling.py) +MIN_SEED = 0 +MAX_SEED = 2**63 - 1 + +AUTH_FILENAME = 'auth.json' diff --git a/modules/core.py b/modules/core.py index c58b0fa7..60ebdf65 100644 --- a/modules/core.py +++ b/modules/core.py @@ -218,19 +218,21 @@ def get_previewer(model): def ksampler(model, positive, negative, latent, seed=None, steps=30, cfg=7.0, sampler_name='dpmpp_2m_sde_gpu', scheduler='karras', denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False, callback_function=None, refiner=None, refiner_switch=-1, - previewer_start=None, previewer_end=None, sigmas=None, noise=None): + previewer_start=None, previewer_end=None, sigmas=None, noise_mean=None): if sigmas is not None: sigmas = sigmas.clone().to(fcbh.model_management.get_torch_device()) latent_image = latent["samples"] - if noise is None: - if disable_noise: - noise = torch.zeros(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, device="cpu") - else: - batch_inds = latent["batch_index"] if "batch_index" in latent else None - noise = fcbh.sample.prepare_noise(latent_image, seed, batch_inds) + if disable_noise: + noise = torch.zeros(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, device="cpu") + else: + batch_inds = 
latent["batch_index"] if "batch_index" in latent else None + noise = fcbh.sample.prepare_noise(latent_image, seed, batch_inds) + + if isinstance(noise_mean, torch.Tensor): + noise = noise + noise_mean - torch.mean(noise, dim=1, keepdim=True) noise_mask = None if "noise_mask" in latent: diff --git a/modules/default_pipeline.py b/modules/default_pipeline.py index 5322e840..8a4d2556 100644 --- a/modules/default_pipeline.py +++ b/modules/default_pipeline.py @@ -6,12 +6,11 @@ import modules.path import fcbh.model_management import fcbh.latent_formats import modules.inpaint_worker -import modules.sample_hijack as sample_hijack +import fooocus_extras.vae_interpose as vae_interpose from fcbh.model_base import SDXL, SDXLRefiner from modules.expansion import FooocusExpansion from modules.sample_hijack import clip_separate -from fcbh.k_diffusion.sampling import BrownianTreeNoiseSampler xl_base: core.StableDiffusionModel = None @@ -270,22 +269,12 @@ refresh_everything( @torch.no_grad() @torch.inference_mode() -def vae_parse(x, tiled=False, use_interpose=True): - if final_vae is None or final_refiner_vae is None: - return x +def vae_parse(latent): + if final_refiner_vae is None: + return latent - if use_interpose: - print('VAE interposing ...') - import fooocus_extras.vae_interpose - x = fooocus_extras.vae_interpose.parse(x) - print('VAE interposed ...') - else: - print('VAE parsing ...') - x = core.decode_vae(vae=final_vae, latent_image=x, tiled=tiled) - x = core.encode_vae(vae=final_refiner_vae, pixels=x, tiled=tiled) - print('VAE parsed ...') - - return x + result = vae_interpose.parse(latent["samples"]) + return {'samples': result} @torch.no_grad() @@ -352,7 +341,7 @@ def process_diffusion(positive_cond, negative_cond, steps, switch, width, height sigma_max = float(sigma_max.cpu().numpy()) print(f'[Sampler] sigma_min = {sigma_min}, sigma_max = {sigma_max}') - modules.patch.globalBrownianTreeNoiseSampler = BrownianTreeNoiseSampler( + modules.patch.BrownianTreeNoiseSamplerPatched.global_init( empty_latent['samples'].to(fcbh.model_management.get_torch_device()), sigma_min, sigma_max, seed=image_seed, cpu=False) @@ -441,11 +430,12 @@ def process_diffusion(positive_cond, negative_cond, steps, switch, width, height decoded_latent = core.decode_vae(vae=target_model, latent_image=sampled_latent, tiled=tiled) if refiner_swap_method == 'vae': + modules.patch.eps_record = 'vae' + if modules.inpaint_worker.current_task is not None: modules.inpaint_worker.current_task.unswap() - sample_hijack.history_record = [] - core.ksampler( + sampled_latent = core.ksampler( model=final_unet, positive=positive_cond, negative=negative_cond, @@ -467,33 +457,17 @@ def process_diffusion(positive_cond, negative_cond, steps, switch, width, height target_model = final_unet print('Use base model to refine itself - this may because of developer mode.') + sampled_latent = vae_parse(sampled_latent) + + k_sigmas = 1.4 sigmas = calculate_sigmas(sampler=sampler_name, scheduler=scheduler_name, model=target_model.model, steps=steps, - denoise=denoise)[switch:] - k1 = target_model.model.latent_format.scale_factor - k2 = final_unet.model.latent_format.scale_factor - k_sigmas = float(k1) / float(k2) - sigmas = sigmas * k_sigmas + denoise=denoise)[switch:] * k_sigmas len_sigmas = len(sigmas) - 1 - last_step, last_clean_latent, last_noisy_latent = sample_hijack.history_record[-1] - last_clean_latent = final_unet.model.process_latent_out(last_clean_latent.cpu().to(torch.float32)) - last_noisy_latent = 
final_unet.model.process_latent_out(last_noisy_latent.cpu().to(torch.float32)) - last_noise = last_noisy_latent - last_clean_latent - last_noise = last_noise / last_noise.std() - - noise_mean = torch.mean(last_noise, dim=1, keepdim=True).repeat(1, 4, 1, 1) / k_sigmas - - refiner_noise = torch.normal( - mean=noise_mean, - std=torch.ones_like(noise_mean), - generator=torch.manual_seed(image_seed+1) # Avoid artifacts - ).to(last_noise) - - sampled_latent = {'samples': last_clean_latent} - sampled_latent = vae_parse(sampled_latent) + noise_mean = torch.mean(modules.patch.eps_record, dim=1, keepdim=True) if modules.inpaint_worker.current_task is not None: modules.inpaint_worker.current_task.swap() @@ -504,7 +478,7 @@ def process_diffusion(positive_cond, negative_cond, steps, switch, width, height negative=clip_separate(negative_cond, target_model=target_model.model, target_clip=final_clip), latent=sampled_latent, steps=len_sigmas, start_step=0, last_step=len_sigmas, disable_noise=False, force_full_denoise=True, - seed=image_seed+2, # Avoid artifacts + seed=image_seed+1, denoise=denoise, callback_function=callback, cfg=cfg_scale, @@ -513,7 +487,7 @@ def process_diffusion(positive_cond, negative_cond, steps, switch, width, height previewer_start=switch, previewer_end=steps, sigmas=sigmas, - noise=refiner_noise + noise_mean=noise_mean ) target_model = final_refiner_vae @@ -522,5 +496,5 @@ def process_diffusion(positive_cond, negative_cond, steps, switch, width, height decoded_latent = core.decode_vae(vae=target_model, latent_image=sampled_latent, tiled=tiled) images = core.pytorch_to_numpy(decoded_latent) - sample_hijack.history_record = None + modules.patch.eps_record = None return images diff --git a/modules/expansion.py b/modules/expansion.py index 2145a701..8f9507c8 100644 --- a/modules/expansion.py +++ b/modules/expansion.py @@ -1,17 +1,24 @@ -import torch +# Fooocus GPT2 Expansion +# Algorithm created by Lvmin Zhang at 2023, Stanford +# If used inside Fooocus, any use is permitted. +# If used outside Fooocus, only non-commercial use is permitted (CC-By NC 4.0). +# This applies to the word list, vocab, model, and algorithm. 
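The rewritten FooocusExpansion that follows replaces the old "magic split" suffixes with constrained GPT-2 sampling: a precomputed logits bias allows only tokens from positive.txt (plus the comma token, id 11 in this tokenizer per the commented-out check), tokens already generated are forbidden so words do not repeat, the seed is wrapped to numpy's 32-bit limit, and max_new_tokens is chosen so the expanded prompt fills out a 75-token chunk. A simplified sketch of the bias step only (not the full class; `build_bias`, `biased_scores`, `NEG_INF` and the toy vocab are illustrative):

```python
import torch

NEG_INF = -8192.0  # same sentinel value the new expansion module uses

def build_bias(vocab, positive_words):
    # vocab: token string -> id, as in tokenizer.vocab; GPT-2 marks
    # word-initial tokens with a leading 'Ġ'.
    bias = torch.full((1, len(vocab)), NEG_INF)
    allowed = {'Ġ' + w.lower() for w in positive_words}
    for token, idx in vocab.items():
        if token in allowed:
            bias[0, idx] = 0.0
    return bias

def biased_scores(scores, bias, input_ids, comma_id=11):
    # Mirrors the logits_processor: forbid tokens already generated,
    # always keep the comma available, then add the bias to the raw scores.
    b = bias.clone()
    b[0, input_ids[0].long()] = NEG_INF
    b[0, comma_id] = 0.0
    return scores + b

# Tiny hypothetical vocabulary just to show the effect.
vocab = {'Ġcinematic': 0, 'Ġugly': 1, ',': 2}
bias = build_bias(vocab, ['cinematic'])
scores = torch.zeros(1, 3)
print(biased_scores(scores, bias, torch.tensor([[1]]), comma_id=2))
# 'Ġugly' is pushed down to -8192; 'Ġcinematic' and ',' stay at 0.
```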
+ +import os +import torch +import math import fcbh.model_management as model_management +from transformers.generation.logits_process import LogitsProcessorList from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed from modules.path import fooocus_expansion_path from fcbh.model_patcher import ModelPatcher -fooocus_magic_split = [ - ', extremely', - ', intricate,', -] -dangrous_patterns = '[]【】()()|::' +# limitation of np.random.seed(), called from transformers.set_seed() +SEED_LIMIT_NUMPY = 2**32 +neg_inf = - 8192.0 def safe_str(x): @@ -30,6 +37,28 @@ def remove_pattern(x, pattern): class FooocusExpansion: def __init__(self): self.tokenizer = AutoTokenizer.from_pretrained(fooocus_expansion_path) + + positive_words = open(os.path.join(fooocus_expansion_path, 'positive.txt'), + encoding='utf-8').read().splitlines() + positive_words = ['Ġ' + x.lower() for x in positive_words if x != ''] + + self.logits_bias = torch.zeros((1, len(self.tokenizer.vocab)), dtype=torch.float32) + neg_inf + + debug_list = [] + for k, v in self.tokenizer.vocab.items(): + if k in positive_words: + self.logits_bias[0, v] = 0 + debug_list.append(k[1:]) + + print(f'Fooocus V2 Expansion: Vocab with {len(debug_list)} words.') + + # debug_list = '\n'.join(sorted(debug_list)) + # print(debug_list) + + # t11 = self.tokenizer(',', return_tensors="np") + # t198 = self.tokenizer('\n', return_tensors="np") + # eos = self.tokenizer.eos_token_id + self.model = AutoModelForCausalLM.from_pretrained(fooocus_expansion_path) self.model.eval() @@ -49,29 +78,49 @@ class FooocusExpansion: self.patcher = ModelPatcher(self.model, load_device=load_device, offload_device=offload_device) print(f'Fooocus Expansion engine loaded for {load_device}, use_fp16 = {use_fp16}.') + @torch.no_grad() + @torch.inference_mode() + def logits_processor(self, input_ids, scores): + assert scores.ndim == 2 and scores.shape[0] == 1 + self.logits_bias = self.logits_bias.to(scores) + + bias = self.logits_bias.clone() + bias[0, input_ids[0].to(bias.device).long()] = neg_inf + bias[0, 11] = 0 + + return scores + bias + + @torch.no_grad() + @torch.inference_mode() def __call__(self, prompt, seed): + if prompt == '': + return '' + if self.patcher.current_device != self.patcher.load_device: print('Fooocus Expansion loaded by itself.') model_management.load_model_gpu(self.patcher) - seed = int(seed) + seed = int(seed) % SEED_LIMIT_NUMPY set_seed(seed) - origin = safe_str(prompt) - prompt = origin + fooocus_magic_split[seed % len(fooocus_magic_split)] + prompt = safe_str(prompt) + ',' tokenized_kwargs = self.tokenizer(prompt, return_tensors="pt") tokenized_kwargs.data['input_ids'] = tokenized_kwargs.data['input_ids'].to(self.patcher.load_device) tokenized_kwargs.data['attention_mask'] = tokenized_kwargs.data['attention_mask'].to(self.patcher.load_device) + current_token_length = int(tokenized_kwargs.data['input_ids'].shape[1]) + max_token_length = 75 * int(math.ceil(float(current_token_length) / 75.0)) + max_new_tokens = max_token_length - current_token_length + # https://huggingface.co/blog/introducing-csearch # https://huggingface.co/docs/transformers/generation_strategies features = self.model.generate(**tokenized_kwargs, - num_beams=1, - max_new_tokens=256, - do_sample=True) + top_k=100, + max_new_tokens=max_new_tokens, + do_sample=True, + logits_processor=LogitsProcessorList([self.logits_processor])) response = self.tokenizer.batch_decode(features, skip_special_tokens=True) - result = response[0][len(origin):] - result = safe_str(result) - result = 
remove_pattern(result, dangrous_patterns) + result = safe_str(response[0]) + return result diff --git a/modules/gradio_hijack.py b/modules/gradio_hijack.py index 4fd2db56..181429ec 100644 --- a/modules/gradio_hijack.py +++ b/modules/gradio_hijack.py @@ -9,6 +9,9 @@ from typing import Any, Literal import numpy as np import PIL import PIL.ImageOps +import gradio.routes +import importlib + from gradio_client import utils as client_utils from gradio_client.documentation import document, set_documentation_group from gradio_client.serializing import ImgSerializable @@ -461,3 +464,17 @@ def blk_ini(self, *args, **kwargs): Block.__init__ = blk_ini + +gradio.routes.asyncio = importlib.reload(gradio.routes.asyncio) + +if not hasattr(gradio.routes.asyncio, 'original_wait_for'): + gradio.routes.asyncio.original_wait_for = gradio.routes.asyncio.wait_for + + +def patched_wait_for(fut, timeout): + del timeout + return gradio.routes.asyncio.original_wait_for(fut, timeout=65535) + + +gradio.routes.asyncio.wait_for = patched_wait_for + diff --git a/modules/launch_util.py b/modules/launch_util.py index d4641aa2..00fff8ae 100644 --- a/modules/launch_util.py +++ b/modules/launch_util.py @@ -1,16 +1,12 @@ import os import importlib import importlib.util -import shutil import subprocess import sys import re import logging -import pygit2 -pygit2.option(pygit2.GIT_OPT_SET_OWNER_VALIDATION, 0) - logging.getLogger("torch.distributed.nn").setLevel(logging.ERROR) # sshh... logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage()) diff --git a/modules/patch.py b/modules/patch.py index 7b6a38cd..98b56d5e 100644 --- a/modules/patch.py +++ b/modules/patch.py @@ -1,3 +1,4 @@ +import contextlib import os import torch import time @@ -22,9 +23,10 @@ import args_manager import modules.advanced_parameters as advanced_parameters import warnings import safetensors.torch +import modules.constants as constants from fcbh.k_diffusion import utils -from fcbh.k_diffusion.sampling import trange +from fcbh.k_diffusion.sampling import BatchedBrownianTree from fcbh.ldm.modules.diffusionmodules.openaimodel import timestep_embedding, forward_timestep_embed @@ -38,6 +40,7 @@ cfg_x0 = 0.0 cfg_s = 1.0 cfg_cin = 1.0 adaptive_cfg = 0.7 +eps_record = None def calculate_weight_patched(self, patches, weight, key): @@ -192,10 +195,12 @@ def patched_sampler_cfg_function(args): def patched_discrete_eps_ddpm_denoiser_forward(self, input, sigma, **kwargs): - global cfg_x0, cfg_s, cfg_cin + global cfg_x0, cfg_s, cfg_cin, eps_record c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)] cfg_x0, cfg_s, cfg_cin = input, c_out, c_in eps = self.get_eps(input * c_in, self.sigma_to_t(sigma), **kwargs) + if eps_record is not None: + eps_record = eps.clone().cpu() return input + eps * c_out @@ -276,70 +281,58 @@ def encode_token_weights_patched_with_a1111_method(self, token_weight_pairs): return torch.cat(output, dim=-2).cpu(), first_pooled.cpu() -globalBrownianTreeNoiseSampler = None - - -@torch.no_grad() -def sample_dpmpp_fooocus_2m_sde_inpaint_seamless(model, x, sigmas, extra_args=None, callback=None, - disable=None, eta=1., s_noise=1., **kwargs): - global sigma_min, sigma_max - - print('[Sampler] Fooocus sampler is activated.') - - seed = extra_args.get("seed", None) - assert isinstance(seed, int) - - energy_generator = torch.Generator(device='cpu') - energy_generator.manual_seed(seed + 1) # avoid bad results by using different seeds. 
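Context for the patch.py changes here and the core.py/default_pipeline.py hunks earlier: the 'vae' refiner-swap path now records the base model's last predicted epsilon (`eps_record` in patched_discrete_eps_ddpm_denoiser_forward) and re-centers the refiner's freshly drawn noise on its mean across latent channels (the `noise_mean` argument of ksampler). A condensed sketch combining those two hunks; `recenter_noise` is an illustrative helper, not a function in the codebase.

```python
import torch

def recenter_noise(noise, eps_record):
    # Shift fresh noise so its mean across latent channels matches the mean
    # of the epsilon recorded from the last base-model step; the variance of
    # the fresh noise itself is left untouched.
    noise_mean = torch.mean(eps_record, dim=1, keepdim=True)
    return noise + noise_mean - torch.mean(noise, dim=1, keepdim=True)

noise = torch.randn(1, 4, 8, 8)
eps = torch.randn(1, 4, 8, 8)
shifted = recenter_noise(noise, eps)
print(torch.allclose(shifted.mean(dim=1), eps.mean(dim=1), atol=1e-5))  # True
```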
- - def get_energy(): - return torch.randn(x.size(), dtype=x.dtype, generator=energy_generator, device="cpu").to(x) - - extra_args = {} if extra_args is None else extra_args - s_in = x.new_ones([x.shape[0]]) - - old_denoised, h_last, h = None, None, None - - latent_processor = model.inner_model.inner_model.inner_model.process_latent_in - inpaint_latent = None - inpaint_mask = None - +def patched_KSamplerX0Inpaint_forward(self, x, sigma, uncond, cond, cond_scale, denoise_mask, model_options={}, seed=None): if inpaint_worker.current_task is not None: + if getattr(self, 'energy_generator', None) is None: + # avoid bad results by using different seeds. + self.energy_generator = torch.Generator(device='cpu').manual_seed((seed + 1) % constants.MAX_SEED) + + latent_processor = self.inner_model.inner_model.inner_model.process_latent_in inpaint_latent = latent_processor(inpaint_worker.current_task.latent).to(x) inpaint_mask = inpaint_worker.current_task.latent_mask.to(x) + energy_sigma = sigma.reshape([sigma.shape[0]] + [1] * (len(x.shape) - 1)) + current_energy = torch.randn(x.size(), dtype=x.dtype, generator=self.energy_generator, device="cpu").to(x) * energy_sigma + x = x * inpaint_mask + (inpaint_latent + current_energy) * (1.0 - inpaint_mask) - def blend_latent(a, b, w): - return a * w + b * (1 - w) + out = self.inner_model(x, sigma, + cond=cond, + uncond=uncond, + cond_scale=cond_scale, + model_options=model_options, + seed=seed) - for i in trange(len(sigmas) - 1, disable=disable): - if inpaint_latent is None: - denoised = model(x, sigmas[i] * s_in, **extra_args) - else: - energy = get_energy() * sigmas[i] + inpaint_latent - x_prime = blend_latent(x, energy, inpaint_mask) - denoised = model(x_prime, sigmas[i] * s_in, **extra_args) - denoised = blend_latent(denoised, inpaint_latent, inpaint_mask) - if callback is not None: - callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) - if sigmas[i + 1] == 0: - x = denoised - else: - t, s = -sigmas[i].log(), -sigmas[i + 1].log() - h = s - t - eta_h = eta * h + out = out * inpaint_mask + inpaint_latent * (1.0 - inpaint_mask) + else: + out = self.inner_model(x, sigma, + cond=cond, + uncond=uncond, + cond_scale=cond_scale, + model_options=model_options, + seed=seed) + return out - x = sigmas[i + 1] / sigmas[i] * (-eta_h).exp() * x + (-h - eta_h).expm1().neg() * denoised - if old_denoised is not None: - r = h_last / h - x = x + 0.5 * (-h - eta_h).expm1().neg() * (1 / r) * (denoised - old_denoised) - x = x + globalBrownianTreeNoiseSampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * ( - -2 * eta_h).expm1().neg().sqrt() * s_noise +class BrownianTreeNoiseSamplerPatched: + transform = None + tree = None - old_denoised = denoised - h_last = h + @staticmethod + def global_init(x, sigma_min, sigma_max, seed=None, transform=lambda x: x, cpu=False): + t0, t1 = transform(torch.as_tensor(sigma_min)), transform(torch.as_tensor(sigma_max)) - return x + BrownianTreeNoiseSamplerPatched.transform = transform + BrownianTreeNoiseSamplerPatched.tree = BatchedBrownianTree(x, t0, t1, seed, cpu=cpu) + + def __init__(self, *args, **kwargs): + pass + + @staticmethod + def __call__(sigma, sigma_next): + transform = BrownianTreeNoiseSamplerPatched.transform + tree = BrownianTreeNoiseSamplerPatched.tree + + t0, t1 = transform(torch.as_tensor(sigma)), transform(torch.as_tensor(sigma_next)) + return tree(t0, t1) / (t1 - t0).abs().sqrt() def timed_adm(y, timesteps): @@ -457,23 +450,6 @@ def patched_unet_forward(self, x, timesteps=None, 
context=None, y=None, control= return self.out(h) -def text_encoder_device_patched(): - # Fooocus's style system uses text encoder much more times than fcbh so this makes things much faster. - return fcbh.model_management.get_torch_device() - - -def patched_get_autocast_device(dev): - # https://github.com/lllyasviel/Fooocus/discussions/571 - # https://github.com/lllyasviel/Fooocus/issues/620 - result = '' - if hasattr(dev, 'type'): - result = str(dev.type) - if 'cuda' in result: - return 'cuda' - else: - return 'cpu' - - def patched_load_models_gpu(*args, **kwargs): execution_start_time = time.perf_counter() y = fcbh.model_management.load_models_gpu_origin(*args, **kwargs) @@ -535,15 +511,14 @@ def patch_all(): fcbh.model_management.load_models_gpu_origin = fcbh.model_management.load_models_gpu fcbh.model_management.load_models_gpu = patched_load_models_gpu - fcbh.model_management.get_autocast_device = patched_get_autocast_device - fcbh.model_management.text_encoder_device = text_encoder_device_patched fcbh.model_patcher.ModelPatcher.calculate_weight = calculate_weight_patched fcbh.cldm.cldm.ControlNet.forward = patched_cldm_forward fcbh.ldm.modules.diffusionmodules.openaimodel.UNetModel.forward = patched_unet_forward - fcbh.k_diffusion.sampling.sample_dpmpp_2m_sde_gpu = sample_dpmpp_fooocus_2m_sde_inpaint_seamless fcbh.k_diffusion.external.DiscreteEpsDDPMDenoiser.forward = patched_discrete_eps_ddpm_denoiser_forward fcbh.model_base.SDXL.encode_adm = sdxl_encode_adm_patched fcbh.sd1_clip.ClipTokenWeightEncoder.encode_token_weights = encode_token_weights_patched_with_a1111_method + fcbh.samplers.KSamplerX0Inpaint.forward = patched_KSamplerX0Inpaint_forward + fcbh.k_diffusion.sampling.BrownianTreeNoiseSampler = BrownianTreeNoiseSamplerPatched warnings.filterwarnings(action='ignore', module='torchsde') diff --git a/modules/path.py b/modules/path.py index e23b3dc4..0722468a 100644 --- a/modules/path.py +++ b/modules/path.py @@ -9,6 +9,7 @@ from modules.util import get_files_from_folder config_path = "user_path_config.txt" config_dict = {} +visited_keys = [] try: if os.path.exists(config_path): @@ -37,7 +38,8 @@ if preset is not None: def get_dir_or_set_default(key, default_value): - global config_dict + global config_dict, visited_keys + visited_keys.append(key) v = config_dict.get(key, None) if isinstance(v, str) and os.path.exists(v) and os.path.isdir(v): return v @@ -62,7 +64,8 @@ temp_outputs_path = get_dir_or_set_default('temp_outputs_path', '../outputs/') def get_config_item_or_set_default(key, default_value, validator, disable_empty_as_none=False): - global config_dict + global config_dict, visited_keys + visited_keys.append(key) if key not in config_dict: config_dict[key] = default_value return default_value @@ -80,14 +83,19 @@ def get_config_item_or_set_default(key, default_value, validator, disable_empty_ default_base_model_name = get_config_item_or_set_default( key='default_model', - default_value='sd_xl_base_1.0_0.9vae.safetensors', + default_value='juggernautXL_version6Rundiffusion.safetensors', validator=lambda x: isinstance(x, str) ) default_refiner_model_name = get_config_item_or_set_default( key='default_refiner', - default_value='sd_xl_refiner_1.0_0.9vae.safetensors', + default_value='None', validator=lambda x: isinstance(x, str) ) +default_refiner_switch = get_config_item_or_set_default( + key='default_refiner_switch', + default_value=0.8, + validator=lambda x: isinstance(x, float) +) default_lora_name = get_config_item_or_set_default( key='default_lora', 
default_value='sd_xl_offset_example-lora_1.0.safetensors', @@ -95,12 +103,17 @@ default_lora_name = get_config_item_or_set_default( ) default_lora_weight = get_config_item_or_set_default( key='default_lora_weight', - default_value=0.5, + default_value=0.1, validator=lambda x: isinstance(x, float) ) default_cfg_scale = get_config_item_or_set_default( key='default_cfg_scale', - default_value=7.0, + default_value=4.0, + validator=lambda x: isinstance(x, float) +) +default_sample_sharpness = get_config_item_or_set_default( + key='default_sample_sharpness', + default_value=2, validator=lambda x: isinstance(x, float) ) default_sampler = get_config_item_or_set_default( @@ -115,17 +128,17 @@ default_scheduler = get_config_item_or_set_default( ) default_styles = get_config_item_or_set_default( key='default_styles', - default_value=['Fooocus V2', 'Default (Slightly Cinematic)'], + default_value=['Fooocus V2', 'Fooocus Enhance', 'Fooocus Sharp'], validator=lambda x: isinstance(x, list) and all(y in modules.sdxl_styles.legal_style_names for y in x) ) -default_negative_prompt = get_config_item_or_set_default( - key='default_negative_prompt', - default_value='low quality, bad hands, bad eyes, cropped, missing fingers, extra digit', +default_prompt_negative = get_config_item_or_set_default( + key='default_prompt_negative', + default_value='', validator=lambda x: isinstance(x, str), disable_empty_as_none=True ) -default_positive_prompt = get_config_item_or_set_default( - key='default_positive_prompt', +default_prompt = get_config_item_or_set_default( + key='default_prompt', default_value='', validator=lambda x: isinstance(x, str), disable_empty_as_none=True @@ -143,10 +156,8 @@ default_image_number = get_config_item_or_set_default( checkpoint_downloads = get_config_item_or_set_default( key='checkpoint_downloads', default_value={ - 'sd_xl_base_1.0_0.9vae.safetensors': - 'https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/resolve/main/sd_xl_base_1.0_0.9vae.safetensors', - 'sd_xl_refiner_1.0_0.9vae.safetensors': - 'https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-1.0/resolve/main/sd_xl_refiner_1.0_0.9vae.safetensors' + 'juggernautXL_version6Rundiffusion.safetensors': + 'https://huggingface.co/lllyasviel/fav_models/resolve/main/fav/juggernautXL_version6Rundiffusion.safetensors' }, validator=lambda x: isinstance(x, dict) and all(isinstance(k, str) and isinstance(v, str) for k, v in x.items()) ) @@ -163,22 +174,30 @@ embeddings_downloads = get_config_item_or_set_default( default_value={}, validator=lambda x: isinstance(x, dict) and all(isinstance(k, str) and isinstance(v, str) for k, v in x.items()) ) +available_aspect_ratios = get_config_item_or_set_default( + key='available_aspect_ratios', + default_value=['704*1408', '704*1344', '768*1344', '768*1280', '832*1216', '832*1152', '896*1152', '896*1088', '960*1088', '960*1024', '1024*1024', '1024*960', '1088*960', '1088*896', '1152*896', '1152*832', '1216*832', '1280*768', '1344*768', '1344*704', '1408*704', '1472*704', '1536*640', '1600*640', '1664*576', '1728*576'], + validator=lambda x: isinstance(x, list) and all('*' in v for v in x) and len(x) > 1 +) default_aspect_ratio = get_config_item_or_set_default( key='default_aspect_ratio', - default_value='1152*896', - validator=lambda x: x.replace('*', '×') in modules.sdxl_styles.aspect_ratios -).replace('*', '×') + default_value='1152*896' if '1152*896' in available_aspect_ratios else available_aspect_ratios[0], + validator=lambda x: x in available_aspect_ratios +) if preset is None: # 
Do not overwrite user config if preset is applied. with open(config_path, "w", encoding="utf-8") as json_file: - json.dump(config_dict, json_file, indent=4) + json.dump({k: config_dict[k] for k in visited_keys}, json_file, indent=4) os.makedirs(temp_outputs_path, exist_ok=True) model_filenames = [] lora_filenames = [] +available_aspect_ratios = [x.replace('*', '×') for x in available_aspect_ratios] +default_aspect_ratio = default_aspect_ratio.replace('*', '×') + def get_model_filenames(folder_path, name_filter=None): return get_files_from_folder(folder_path, ['.pth', '.ckpt', '.bin', '.safetensors', '.fooocus.patch'], name_filter) diff --git a/modules/sample_hijack.py b/modules/sample_hijack.py index ed184dd6..30e47b65 100644 --- a/modules/sample_hijack.py +++ b/modules/sample_hijack.py @@ -3,28 +3,25 @@ import fcbh.samplers import fcbh.model_management from fcbh.model_base import SDXLRefiner, SDXL +from fcbh.conds import CONDRegular from fcbh.sample import get_additional_models, get_models_from_cond, cleanup_additional_models from fcbh.samplers import resolve_areas_and_cond_masks, wrap_model, calculate_start_end_timesteps, \ - create_cond_with_same_area_if_none, pre_run_control, apply_empty_x_to_equal_area, encode_adm, \ - encode_cond + create_cond_with_same_area_if_none, pre_run_control, apply_empty_x_to_equal_area, encode_model_conds current_refiner = None refiner_switch_step = -1 -history_record = None @torch.no_grad() @torch.inference_mode() -def clip_separate(cond, target_model=None, target_clip=None): - c, p = cond[0] +def clip_separate_inner(c, p, target_model=None, target_clip=None): if target_model is None or isinstance(target_model, SDXLRefiner): c = c[..., -1280:].clone() - p = {"pooled_output": p["pooled_output"].clone()} elif isinstance(target_model, SDXL): c = c.clone() - p = {"pooled_output": p["pooled_output"].clone()} else: + p = None c = c[..., :768].clone() final_layer_norm = target_clip.cond_stage_model.clip_l.transformer.text_model.final_layer_norm @@ -44,9 +41,42 @@ def clip_separate(cond, target_model=None, target_clip=None): final_layer_norm.to(device=final_layer_norm_origin_device, dtype=final_layer_norm_origin_dtype) c = c.to(device=c_origin_device, dtype=c_origin_dtype) + return c, p - p = {} - return [[c, p]] + +@torch.no_grad() +@torch.inference_mode() +def clip_separate(cond, target_model=None, target_clip=None): + results = [] + + for c, px in cond: + p = px.get('pooled_output', None) + c, p = clip_separate_inner(c, p, target_model=target_model, target_clip=target_clip) + p = {} if p is None else {'pooled_output': p.clone()} + results.append([c, p]) + + return results + + +@torch.no_grad() +@torch.inference_mode() +def clip_separate_after_preparation(cond, target_model=None, target_clip=None): + results = [] + + for x in cond: + p = x.get('pooled_output', None) + c = x['model_conds']['c_crossattn'].cond + + c, p = clip_separate_inner(c, p, target_model=target_model, target_clip=target_clip) + + result = {'model_conds': {'c_crossattn': CONDRegular(c)}} + + if p is not None: + result['pooled_output'] = p.clone() + + results.append(result) + + return results @torch.no_grad() @@ -74,31 +104,24 @@ def sample_hacked(model, noise, positive, negative, cfg, device, sampler, sigmas # pre_run_control(model_wrap, negative + positive) pre_run_control(model_wrap, positive) # negative is not necessary in Fooocus, 0.5s faster. 
- apply_empty_x_to_equal_area(list(filter(lambda c: c[1].get('control_apply_to_uncond', False) == True, positive)), negative, 'control', lambda cond_cnets, x: cond_cnets[x]) + apply_empty_x_to_equal_area(list(filter(lambda c: c.get('control_apply_to_uncond', False) == True, positive)), negative, 'control', lambda cond_cnets, x: cond_cnets[x]) apply_empty_x_to_equal_area(positive, negative, 'gligen', lambda cond_cnets, x: cond_cnets[x]) if latent_image is not None: latent_image = model.process_latent_in(latent_image) - if model.is_adm(): - positive = encode_adm(model, positive, noise.shape[0], noise.shape[3], noise.shape[2], device, "positive") - negative = encode_adm(model, negative, noise.shape[0], noise.shape[3], noise.shape[2], device, "negative") - - if hasattr(model, 'cond_concat'): - positive = encode_cond(model.cond_concat, "concat", positive, device, noise=noise, latent_image=latent_image, denoise_mask=denoise_mask) - negative = encode_cond(model.cond_concat, "concat", negative, device, noise=noise, latent_image=latent_image, denoise_mask=denoise_mask) + if hasattr(model, 'extra_conds'): + positive = encode_model_conds(model.extra_conds, positive, noise, device, "positive", latent_image=latent_image, denoise_mask=denoise_mask) + negative = encode_model_conds(model.extra_conds, negative, noise, device, "negative", latent_image=latent_image, denoise_mask=denoise_mask) extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg, "model_options": model_options, "seed":seed} - if current_refiner is not None and current_refiner.model.is_adm(): - positive_refiner = clip_separate(positive, target_model=current_refiner.model) - negative_refiner = clip_separate(negative, target_model=current_refiner.model) + if current_refiner is not None and hasattr(current_refiner.model, 'extra_conds'): + positive_refiner = clip_separate_after_preparation(positive, target_model=current_refiner.model) + negative_refiner = clip_separate_after_preparation(negative, target_model=current_refiner.model) - positive_refiner = encode_adm(current_refiner.model, positive_refiner, noise.shape[0], noise.shape[3], noise.shape[2], device, "positive") - negative_refiner = encode_adm(current_refiner.model, negative_refiner, noise.shape[0], noise.shape[3], noise.shape[2], device, "negative") - - positive_refiner[0][1]['adm_encoded'].to(positive[0][1]['adm_encoded']) - negative_refiner[0][1]['adm_encoded'].to(negative[0][1]['adm_encoded']) + positive_refiner = encode_model_conds(current_refiner.model.extra_conds, positive_refiner, noise, device, "positive", latent_image=latent_image, denoise_mask=denoise_mask) + negative_refiner = encode_model_conds(current_refiner.model.extra_conds, negative_refiner, noise, device, "negative", latent_image=latent_image, denoise_mask=denoise_mask) def refiner_switch(): cleanup_additional_models(set(get_models_from_cond(positive, "control") + get_models_from_cond(negative, "control"))) @@ -118,9 +141,6 @@ def sample_hacked(model, noise, positive, negative, cfg, device, sampler, sigmas return def callback_wrap(step, x0, x, total_steps): - global history_record - if isinstance(history_record, list): - history_record.append((step, x0, x)) if step == refiner_switch_step and current_refiner is not None: refiner_switch() if callback is not None: diff --git a/modules/sdxl_styles.py b/modules/sdxl_styles.py index 87af5123..14a4ff11 100644 --- a/modules/sdxl_styles.py +++ b/modules/sdxl_styles.py @@ -40,7 +40,9 @@ for styles_file in styles_files: try: with open(os.path.join(styles_path, 
styles_file), encoding='utf-8') as f: for entry in json.load(f): - name, prompt, negative_prompt = normalize_key(entry['name']), entry['prompt'], entry['negative_prompt'] + name = normalize_key(entry['name']) + prompt = entry['prompt'] if 'prompt' in entry else '' + negative_prompt = entry['negative_prompt'] if 'negative_prompt' in entry else '' styles[name] = (prompt, negative_prompt) except Exception as e: print(str(e)) @@ -51,51 +53,9 @@ fooocus_expansion = "Fooocus V2" legal_style_names = [fooocus_expansion] + style_keys -SD_XL_BASE_RATIOS = { - "0.5": (704, 1408), - "0.52": (704, 1344), - "0.57": (768, 1344), - "0.6": (768, 1280), - "0.68": (832, 1216), - "0.72": (832, 1152), - "0.78": (896, 1152), - "0.82": (896, 1088), - "0.88": (960, 1088), - "0.94": (960, 1024), - "1.0": (1024, 1024), - "1.07": (1024, 960), - "1.13": (1088, 960), - "1.21": (1088, 896), - "1.29": (1152, 896), - "1.38": (1152, 832), - "1.46": (1216, 832), - "1.67": (1280, 768), - "1.75": (1344, 768), - "1.91": (1344, 704), - "2.0": (1408, 704), - "2.09": (1472, 704), - "2.4": (1536, 640), - "2.5": (1600, 640), - "2.89": (1664, 576), - "3.0": (1728, 576), -} - -aspect_ratios = {} - -# import math - -for k, (w, h) in SD_XL_BASE_RATIOS.items(): - txt = f'{w}×{h}' - - # gcd = math.gcd(w, h) - # txt += f' {w//gcd}:{h//gcd}' - - aspect_ratios[txt] = (w, h) - - def apply_style(style, positive): p, n = styles[style] - return p.replace('{prompt}', positive), n + return p.replace('{prompt}', positive).splitlines(), n.splitlines() def apply_wildcards(wildcard_text, rng, directory=wildcards_path): diff --git a/modules/util.py b/modules/util.py index 2bd46fd6..1601f1fe 100644 --- a/modules/util.py +++ b/modules/util.py @@ -12,7 +12,7 @@ LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.L def resample_image(im, width, height): im = Image.fromarray(im) - im = im.resize((width, height), resample=LANCZOS) + im = im.resize((int(width), int(height)), resample=LANCZOS) return np.array(im) @@ -84,11 +84,22 @@ def get_image_shape_ceil(im): def set_image_shape_ceil(im, shape_ceil): - H, W, _ = im.shape - shape_ceil_before = get_shape_ceil(H, W) - k = float(shape_ceil) / shape_ceil_before - H = int(round(float(H) * k / 64.0) * 64) - W = int(round(float(W) * k / 64.0) * 64) + shape_ceil = float(shape_ceil) + + H_origin, W_origin, _ = im.shape + H, W = H_origin, W_origin + + for _ in range(256): + current_shape_ceil = get_shape_ceil(H, W) + if abs(current_shape_ceil - shape_ceil) < 0.1: + break + k = shape_ceil / current_shape_ceil + H = int(round(float(H) * k / 64.0) * 64) + W = int(round(float(W) * k / 64.0) * 64) + + if H == H_origin and W == W_origin: + return im + return resample_image(im, width=W, height=H) diff --git a/presets/anime.json b/presets/anime.json index a75ca189..c7f84cdc 100644 --- a/presets/anime.json +++ b/presets/anime.json @@ -2,6 +2,7 @@ "default_model": "bluePencilXL_v050.safetensors", "default_refiner": "DreamShaper_8_pruned.safetensors", "default_lora": "sd_xl_offset_example-lora_1.0.safetensors", + "default_refiner_switch": 0.667, "default_lora_weight": 0.5, "default_cfg_scale": 7.0, "default_sampler": "dpmpp_2m_sde_gpu", @@ -14,8 +15,8 @@ "SAI Enhance", "SAI Fantasy Art" ], - "default_negative_prompt": "(embedding:unaestheticXLv31:0.8), low quality, watermark", - "default_positive_prompt": "1girl, ", + "default_prompt_negative": "(embedding:unaestheticXLv31:0.8), low quality, watermark", + "default_prompt": "1girl, ", "checkpoint_downloads": { "bluePencilXL_v050.safetensors": 
"https://huggingface.co/lllyasviel/fav_models/resolve/main/fav/bluePencilXL_v050.safetensors", "DreamShaper_8_pruned.safetensors": "https://huggingface.co/lllyasviel/fav_models/resolve/main/fav/DreamShaper_8_pruned.safetensors" diff --git a/presets/realistic.json b/presets/realistic.json index 14b19795..caf78a91 100644 --- a/presets/realistic.json +++ b/presets/realistic.json @@ -11,8 +11,8 @@ "Fooocus Photograph", "Fooocus Negative" ], - "default_negative_prompt": "unrealistic, saturated, high contrast, big nose, painting, drawing, sketch, cartoon, anime, manga, render, CG, 3d, watermark, signature, label", - "default_positive_prompt": "", + "default_prompt_negative": "unrealistic, saturated, high contrast, big nose, painting, drawing, sketch, cartoon, anime, manga, render, CG, 3d, watermark, signature, label", + "default_prompt": "", "checkpoint_downloads": { "realisticStockPhoto_v10.safetensors": "https://huggingface.co/lllyasviel/fav_models/resolve/main/fav/realisticStockPhoto_v10.safetensors" }, diff --git a/presets/sdxl.json b/presets/sdxl.json deleted file mode 100644 index 141ca04c..00000000 --- a/presets/sdxl.json +++ /dev/null @@ -1,24 +0,0 @@ -{ - "default_model": "sd_xl_base_1.0_0.9vae.safetensors", - "default_refiner": "sd_xl_refiner_1.0_0.9vae.safetensors", - "default_lora": "sd_xl_offset_example-lora_1.0.safetensors", - "default_lora_weight": 0.5, - "default_cfg_scale": 7.0, - "default_sampler": "dpmpp_2m_sde_gpu", - "default_scheduler": "karras", - "default_styles": [ - "Fooocus V2", - "Default (Slightly Cinematic)" - ], - "default_negative_prompt": "low quality, bad hands, bad eyes, cropped, missing fingers, extra digit", - "default_positive_prompt": "", - "checkpoint_downloads": { - "sd_xl_base_1.0_0.9vae.safetensors": "https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/resolve/main/sd_xl_base_1.0_0.9vae.safetensors", - "sd_xl_refiner_1.0_0.9vae.safetensors": "https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-1.0/resolve/main/sd_xl_refiner_1.0_0.9vae.safetensors" - }, - "embeddings_downloads": {}, - "default_aspect_ratio": "1152*896", - "lora_downloads": { - "sd_xl_offset_example-lora_1.0.safetensors": "https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/resolve/main/sd_xl_offset_example-lora_1.0.safetensors" - } -} diff --git a/readme.md b/readme.md index 06465ed3..055292c4 100644 --- a/readme.md +++ b/readme.md @@ -1,9 +1,17 @@