From 56c8a342ec96f50fe09710be266167f3c01cbca0 Mon Sep 17 00:00:00 2001 From: lllyasviel Date: Sat, 28 Oct 2023 14:49:18 -0700 Subject: [PATCH] Fooocus GitHub Bot Commit This commit is generated by a GitHub bot of Fooocus --- backend/headless/fcbh/cldm/cldm.py | 47 ++++---- .../modules/diffusionmodules/openaimodel.py | 42 +++---- backend/headless/fcbh/lora.py | 6 +- backend/headless/fcbh/model_detection.py | 112 ++++++++++++++---- backend/headless/fcbh/sd.py | 8 +- backend/headless/fcbh/sd1_clip.py | 49 +++++++- backend/headless/fcbh/sd2_clip.py | 12 +- backend/headless/fcbh/sdxl_clip.py | 29 ++--- backend/headless/fcbh/supported_models.py | 27 ++++- backend/headless/fcbh/utils.py | 30 ++--- fooocus_version.py | 2 +- 11 files changed, 234 insertions(+), 130 deletions(-) diff --git a/backend/headless/fcbh/cldm/cldm.py b/backend/headless/fcbh/cldm/cldm.py index b177e924..d464a462 100644 --- a/backend/headless/fcbh/cldm/cldm.py +++ b/backend/headless/fcbh/cldm/cldm.py @@ -27,7 +27,6 @@ class ControlNet(nn.Module): model_channels, hint_channels, num_res_blocks, - attention_resolutions, dropout=0, channel_mult=(1, 2, 4, 8), conv_resample=True, @@ -52,6 +51,7 @@ class ControlNet(nn.Module): use_linear_in_transformer=False, adm_in_channels=None, transformer_depth_middle=None, + transformer_depth_output=None, device=None, operations=fcbh.ops, ): @@ -79,10 +79,7 @@ class ControlNet(nn.Module): self.image_size = image_size self.in_channels = in_channels self.model_channels = model_channels - if isinstance(transformer_depth, int): - transformer_depth = len(channel_mult) * [transformer_depth] - if transformer_depth_middle is None: - transformer_depth_middle = transformer_depth[-1] + if isinstance(num_res_blocks, int): self.num_res_blocks = len(channel_mult) * [num_res_blocks] else: @@ -90,18 +87,16 @@ class ControlNet(nn.Module): raise ValueError("provide num_res_blocks either as an int (globally constant) or " "as a list/tuple (per-level) with the same length as channel_mult") self.num_res_blocks = num_res_blocks + if disable_self_attentions is not None: # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not assert len(disable_self_attentions) == len(channel_mult) if num_attention_blocks is not None: assert len(num_attention_blocks) == len(self.num_res_blocks) assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks)))) - print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. " - f"This option has LESS priority than attention_resolutions {attention_resolutions}, " - f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, " - f"attention will still not be set.") - self.attention_resolutions = attention_resolutions + transformer_depth = transformer_depth[:] + self.dropout = dropout self.channel_mult = channel_mult self.conv_resample = conv_resample @@ -180,11 +175,14 @@ class ControlNet(nn.Module): dims=dims, use_checkpoint=use_checkpoint, use_scale_shift_norm=use_scale_shift_norm, - operations=operations + dtype=self.dtype, + device=device, + operations=operations, ) ] ch = mult * model_channels - if ds in attention_resolutions: + num_transformers = transformer_depth.pop(0) + if num_transformers > 0: if num_head_channels == -1: dim_head = ch // num_heads else: @@ -201,9 +199,9 @@ class ControlNet(nn.Module): if not exists(num_attention_blocks) or nr < num_attention_blocks[level]: layers.append( SpatialTransformer( - ch, num_heads, dim_head, depth=transformer_depth[level], context_dim=context_dim, + ch, num_heads, dim_head, depth=num_transformers, context_dim=context_dim, disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer, - use_checkpoint=use_checkpoint, operations=operations + use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations ) ) self.input_blocks.append(TimestepEmbedSequential(*layers)) @@ -223,11 +221,13 @@ class ControlNet(nn.Module): use_checkpoint=use_checkpoint, use_scale_shift_norm=use_scale_shift_norm, down=True, + dtype=self.dtype, + device=device, operations=operations ) if resblock_updown else Downsample( - ch, conv_resample, dims=dims, out_channels=out_ch, operations=operations + ch, conv_resample, dims=dims, out_channels=out_ch, dtype=self.dtype, device=device, operations=operations ) ) ) @@ -245,7 +245,7 @@ class ControlNet(nn.Module): if legacy: #num_heads = 1 dim_head = ch // num_heads if use_spatial_transformer else num_head_channels - self.middle_block = TimestepEmbedSequential( + mid_block = [ ResBlock( ch, time_embed_dim, @@ -253,12 +253,15 @@ class ControlNet(nn.Module): dims=dims, use_checkpoint=use_checkpoint, use_scale_shift_norm=use_scale_shift_norm, + dtype=self.dtype, + device=device, operations=operations - ), - SpatialTransformer( # always uses a self-attn + )] + if transformer_depth_middle >= 0: + mid_block += [SpatialTransformer( # always uses a self-attn ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim, disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer, - use_checkpoint=use_checkpoint, operations=operations + use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations ), ResBlock( ch, @@ -267,9 +270,11 @@ class ControlNet(nn.Module): dims=dims, use_checkpoint=use_checkpoint, use_scale_shift_norm=use_scale_shift_norm, + dtype=self.dtype, + device=device, operations=operations - ), - ) + )] + self.middle_block = TimestepEmbedSequential(*mid_block) self.middle_block_out = self.make_zero_conv(ch, operations=operations) self._feature_size += ch diff --git a/backend/headless/fcbh/ldm/modules/diffusionmodules/openaimodel.py b/backend/headless/fcbh/ldm/modules/diffusionmodules/openaimodel.py index d8ec0a62..9c7cfb8e 100644 --- a/backend/headless/fcbh/ldm/modules/diffusionmodules/openaimodel.py +++ b/backend/headless/fcbh/ldm/modules/diffusionmodules/openaimodel.py @@ -259,10 +259,6 @@ class UNetModel(nn.Module): :param model_channels: base channel count for the model. :param out_channels: channels in the output Tensor. :param num_res_blocks: number of residual blocks per downsample. - :param attention_resolutions: a collection of downsample rates at which - attention will take place. May be a set, list, or tuple. - For example, if this contains 4, then at 4x downsampling, attention - will be used. :param dropout: the dropout probability. :param channel_mult: channel multiplier for each level of the UNet. :param conv_resample: if True, use learned convolutions for upsampling and @@ -289,7 +285,6 @@ class UNetModel(nn.Module): model_channels, out_channels, num_res_blocks, - attention_resolutions, dropout=0, channel_mult=(1, 2, 4, 8), conv_resample=True, @@ -314,6 +309,7 @@ class UNetModel(nn.Module): use_linear_in_transformer=False, adm_in_channels=None, transformer_depth_middle=None, + transformer_depth_output=None, device=None, operations=fcbh.ops, ): @@ -341,10 +337,7 @@ class UNetModel(nn.Module): self.in_channels = in_channels self.model_channels = model_channels self.out_channels = out_channels - if isinstance(transformer_depth, int): - transformer_depth = len(channel_mult) * [transformer_depth] - if transformer_depth_middle is None: - transformer_depth_middle = transformer_depth[-1] + if isinstance(num_res_blocks, int): self.num_res_blocks = len(channel_mult) * [num_res_blocks] else: @@ -352,18 +345,16 @@ class UNetModel(nn.Module): raise ValueError("provide num_res_blocks either as an int (globally constant) or " "as a list/tuple (per-level) with the same length as channel_mult") self.num_res_blocks = num_res_blocks + if disable_self_attentions is not None: # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not assert len(disable_self_attentions) == len(channel_mult) if num_attention_blocks is not None: assert len(num_attention_blocks) == len(self.num_res_blocks) - assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks)))) - print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. " - f"This option has LESS priority than attention_resolutions {attention_resolutions}, " - f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, " - f"attention will still not be set.") - self.attention_resolutions = attention_resolutions + transformer_depth = transformer_depth[:] + transformer_depth_output = transformer_depth_output[:] + self.dropout = dropout self.channel_mult = channel_mult self.conv_resample = conv_resample @@ -428,7 +419,8 @@ class UNetModel(nn.Module): ) ] ch = mult * model_channels - if ds in attention_resolutions: + num_transformers = transformer_depth.pop(0) + if num_transformers > 0: if num_head_channels == -1: dim_head = ch // num_heads else: @@ -444,7 +436,7 @@ class UNetModel(nn.Module): if not exists(num_attention_blocks) or nr < num_attention_blocks[level]: layers.append(SpatialTransformer( - ch, num_heads, dim_head, depth=transformer_depth[level], context_dim=context_dim, + ch, num_heads, dim_head, depth=num_transformers, context_dim=context_dim, disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer, use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations ) @@ -488,7 +480,7 @@ class UNetModel(nn.Module): if legacy: #num_heads = 1 dim_head = ch // num_heads if use_spatial_transformer else num_head_channels - self.middle_block = TimestepEmbedSequential( + mid_block = [ ResBlock( ch, time_embed_dim, @@ -499,8 +491,9 @@ class UNetModel(nn.Module): dtype=self.dtype, device=device, operations=operations - ), - SpatialTransformer( # always uses a self-attn + )] + if transformer_depth_middle >= 0: + mid_block += [SpatialTransformer( # always uses a self-attn ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim, disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer, use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations @@ -515,8 +508,8 @@ class UNetModel(nn.Module): dtype=self.dtype, device=device, operations=operations - ), - ) + )] + self.middle_block = TimestepEmbedSequential(*mid_block) self._feature_size += ch self.output_blocks = nn.ModuleList([]) @@ -538,7 +531,8 @@ class UNetModel(nn.Module): ) ] ch = model_channels * mult - if ds in attention_resolutions: + num_transformers = transformer_depth_output.pop() + if num_transformers > 0: if num_head_channels == -1: dim_head = ch // num_heads else: @@ -555,7 +549,7 @@ class UNetModel(nn.Module): if not exists(num_attention_blocks) or i < num_attention_blocks[level]: layers.append( SpatialTransformer( - ch, num_heads, dim_head, depth=transformer_depth[level], context_dim=context_dim, + ch, num_heads, dim_head, depth=num_transformers, context_dim=context_dim, disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer, use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations ) diff --git a/backend/headless/fcbh/lora.py b/backend/headless/fcbh/lora.py index 4c1c5684..3bec26b5 100644 --- a/backend/headless/fcbh/lora.py +++ b/backend/headless/fcbh/lora.py @@ -141,9 +141,9 @@ def model_lora_keys_clip(model, key_map={}): text_model_lora_key = "lora_te_text_model_encoder_layers_{}_{}" clip_l_present = False - for b in range(32): + for b in range(32): #TODO: clean up for c in LORA_CLIP_MAP: - k = "transformer.text_model.encoder.layers.{}.{}.weight".format(b, c) + k = "clip_h.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c) if k in sdk: lora_key = text_model_lora_key.format(b, LORA_CLIP_MAP[c]) key_map[lora_key] = k @@ -154,6 +154,8 @@ def model_lora_keys_clip(model, key_map={}): k = "clip_l.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c) if k in sdk: + lora_key = text_model_lora_key.format(b, LORA_CLIP_MAP[c]) + key_map[lora_key] = k lora_key = "lora_te1_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #SDXL base key_map[lora_key] = k clip_l_present = True diff --git a/backend/headless/fcbh/model_detection.py b/backend/headless/fcbh/model_detection.py index cc3d10ee..53851274 100644 --- a/backend/headless/fcbh/model_detection.py +++ b/backend/headless/fcbh/model_detection.py @@ -14,6 +14,19 @@ def count_blocks(state_dict_keys, prefix_string): count += 1 return count +def calculate_transformer_depth(prefix, state_dict_keys, state_dict): + context_dim = None + use_linear_in_transformer = False + + transformer_prefix = prefix + "1.transformer_blocks." + transformer_keys = sorted(list(filter(lambda a: a.startswith(transformer_prefix), state_dict_keys))) + if len(transformer_keys) > 0: + last_transformer_depth = count_blocks(state_dict_keys, transformer_prefix + '{}') + context_dim = state_dict['{}0.attn2.to_k.weight'.format(transformer_prefix)].shape[1] + use_linear_in_transformer = len(state_dict['{}1.proj_in.weight'.format(prefix)].shape) == 2 + return last_transformer_depth, context_dim, use_linear_in_transformer + return None + def detect_unet_config(state_dict, key_prefix, dtype): state_dict_keys = list(state_dict.keys()) @@ -40,6 +53,7 @@ def detect_unet_config(state_dict, key_prefix, dtype): channel_mult = [] attention_resolutions = [] transformer_depth = [] + transformer_depth_output = [] context_dim = None use_linear_in_transformer = False @@ -48,60 +62,67 @@ def detect_unet_config(state_dict, key_prefix, dtype): count = 0 last_res_blocks = 0 - last_transformer_depth = 0 last_channel_mult = 0 - while True: + input_block_count = count_blocks(state_dict_keys, '{}input_blocks'.format(key_prefix) + '.{}.') + for count in range(input_block_count): prefix = '{}input_blocks.{}.'.format(key_prefix, count) + prefix_output = '{}output_blocks.{}.'.format(key_prefix, input_block_count - count - 1) + block_keys = sorted(list(filter(lambda a: a.startswith(prefix), state_dict_keys))) if len(block_keys) == 0: break + block_keys_output = sorted(list(filter(lambda a: a.startswith(prefix_output), state_dict_keys))) + if "{}0.op.weight".format(prefix) in block_keys: #new layer - if last_transformer_depth > 0: - attention_resolutions.append(current_res) - transformer_depth.append(last_transformer_depth) num_res_blocks.append(last_res_blocks) channel_mult.append(last_channel_mult) current_res *= 2 last_res_blocks = 0 - last_transformer_depth = 0 last_channel_mult = 0 + out = calculate_transformer_depth(prefix_output, state_dict_keys, state_dict) + if out is not None: + transformer_depth_output.append(out[0]) + else: + transformer_depth_output.append(0) else: res_block_prefix = "{}0.in_layers.0.weight".format(prefix) if res_block_prefix in block_keys: last_res_blocks += 1 last_channel_mult = state_dict["{}0.out_layers.3.weight".format(prefix)].shape[0] // model_channels - transformer_prefix = prefix + "1.transformer_blocks." - transformer_keys = sorted(list(filter(lambda a: a.startswith(transformer_prefix), state_dict_keys))) - if len(transformer_keys) > 0: - last_transformer_depth = count_blocks(state_dict_keys, transformer_prefix + '{}') - if context_dim is None: - context_dim = state_dict['{}0.attn2.to_k.weight'.format(transformer_prefix)].shape[1] - use_linear_in_transformer = len(state_dict['{}1.proj_in.weight'.format(prefix)].shape) == 2 + out = calculate_transformer_depth(prefix, state_dict_keys, state_dict) + if out is not None: + transformer_depth.append(out[0]) + if context_dim is None: + context_dim = out[1] + use_linear_in_transformer = out[2] + else: + transformer_depth.append(0) + + res_block_prefix = "{}0.in_layers.0.weight".format(prefix_output) + if res_block_prefix in block_keys_output: + out = calculate_transformer_depth(prefix_output, state_dict_keys, state_dict) + if out is not None: + transformer_depth_output.append(out[0]) + else: + transformer_depth_output.append(0) - count += 1 - if last_transformer_depth > 0: - attention_resolutions.append(current_res) - transformer_depth.append(last_transformer_depth) num_res_blocks.append(last_res_blocks) channel_mult.append(last_channel_mult) - transformer_depth_middle = count_blocks(state_dict_keys, '{}middle_block.1.transformer_blocks.'.format(key_prefix) + '{}') - - if len(set(num_res_blocks)) == 1: - num_res_blocks = num_res_blocks[0] - - if len(set(transformer_depth)) == 1: - transformer_depth = transformer_depth[0] + if "{}middle_block.1.proj_in.weight".format(key_prefix) in state_dict_keys: + transformer_depth_middle = count_blocks(state_dict_keys, '{}middle_block.1.transformer_blocks.'.format(key_prefix) + '{}') + else: + transformer_depth_middle = -1 unet_config["in_channels"] = in_channels unet_config["model_channels"] = model_channels unet_config["num_res_blocks"] = num_res_blocks - unet_config["attention_resolutions"] = attention_resolutions unet_config["transformer_depth"] = transformer_depth + unet_config["transformer_depth_output"] = transformer_depth_output unet_config["channel_mult"] = channel_mult unet_config["transformer_depth_middle"] = transformer_depth_middle unet_config['use_linear_in_transformer'] = use_linear_in_transformer @@ -124,6 +145,45 @@ def model_config_from_unet(state_dict, unet_key_prefix, dtype, use_base_if_no_ma else: return model_config +def convert_config(unet_config): + new_config = unet_config.copy() + num_res_blocks = new_config.get("num_res_blocks", None) + channel_mult = new_config.get("channel_mult", None) + + if isinstance(num_res_blocks, int): + num_res_blocks = len(channel_mult) * [num_res_blocks] + + if "attention_resolutions" in new_config: + attention_resolutions = new_config.pop("attention_resolutions") + transformer_depth = new_config.get("transformer_depth", None) + transformer_depth_middle = new_config.get("transformer_depth_middle", None) + + if isinstance(transformer_depth, int): + transformer_depth = len(channel_mult) * [transformer_depth] + if transformer_depth_middle is None: + transformer_depth_middle = transformer_depth[-1] + t_in = [] + t_out = [] + s = 1 + for i in range(len(num_res_blocks)): + res = num_res_blocks[i] + d = 0 + if s in attention_resolutions: + d = transformer_depth[i] + + t_in += [d] * res + t_out += [d] * (res + 1) + s *= 2 + transformer_depth = t_in + transformer_depth_output = t_out + new_config["transformer_depth"] = t_in + new_config["transformer_depth_output"] = t_out + new_config["transformer_depth_middle"] = transformer_depth_middle + + new_config["num_res_blocks"] = num_res_blocks + return new_config + + def unet_config_from_diffusers_unet(state_dict, dtype): match = {} attention_resolutions = [] @@ -200,7 +260,7 @@ def unet_config_from_diffusers_unet(state_dict, dtype): matches = False break if matches: - return unet_config + return convert_config(unet_config) return None def model_config_from_diffusers_unet(state_dict, dtype): diff --git a/backend/headless/fcbh/sd.py b/backend/headless/fcbh/sd.py index 5f1f0c6b..0982446b 100644 --- a/backend/headless/fcbh/sd.py +++ b/backend/headless/fcbh/sd.py @@ -360,7 +360,7 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl from . import latent_formats model_config.latent_format = latent_formats.SD15(scale_factor=scale_factor) - model_config.unet_config = unet_config + model_config.unet_config = model_detection.convert_config(unet_config) if config['model']["target"].endswith("ImageEmbeddingConditionedLatentDiffusion"): model = model_base.SD21UNCLIP(model_config, noise_aug_config["params"], model_type=model_type) @@ -388,11 +388,13 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl if clip_config["target"].endswith("FrozenOpenCLIPEmbedder"): clip_target.clip = sd2_clip.SD2ClipModel clip_target.tokenizer = sd2_clip.SD2Tokenizer + clip = CLIP(clip_target, embedding_directory=embedding_directory) + w.cond_stage_model = clip.cond_stage_model.clip_h elif clip_config["target"].endswith("FrozenCLIPEmbedder"): clip_target.clip = sd1_clip.SD1ClipModel clip_target.tokenizer = sd1_clip.SD1Tokenizer - clip = CLIP(clip_target, embedding_directory=embedding_directory) - w.cond_stage_model = clip.cond_stage_model + clip = CLIP(clip_target, embedding_directory=embedding_directory) + w.cond_stage_model = clip.cond_stage_model.clip_l load_clip_weights(w, state_dict) return (fcbh.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device), clip, vae) diff --git a/backend/headless/fcbh/sd1_clip.py b/backend/headless/fcbh/sd1_clip.py index 45382b00..56beb81c 100644 --- a/backend/headless/fcbh/sd1_clip.py +++ b/backend/headless/fcbh/sd1_clip.py @@ -35,7 +35,7 @@ class ClipTokenWeightEncoder: return z_empty.cpu(), first_pooled.cpu() return torch.cat(output, dim=-2).cpu(), first_pooled.cpu() -class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder): +class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): """Uses the CLIP transformer encoder for text (from huggingface)""" LAYERS = [ "last", @@ -278,7 +278,13 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No valid_file = None for embed_dir in embedding_directory: - embed_path = os.path.join(embed_dir, embedding_name) + embed_path = os.path.abspath(os.path.join(embed_dir, embedding_name)) + embed_dir = os.path.abspath(embed_dir) + try: + if os.path.commonpath((embed_dir, embed_path)) != embed_dir: + continue + except: + continue if not os.path.isfile(embed_path): extensions = ['.safetensors', '.pt', '.bin'] for x in extensions: @@ -336,7 +342,7 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No embed_out = next(iter(values)) return embed_out -class SD1Tokenizer: +class SDTokenizer: def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l'): if tokenizer_path is None: tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer") @@ -448,3 +454,40 @@ class SD1Tokenizer: def untokenize(self, token_weight_pair): return list(map(lambda a: (a, self.inv_vocab[a[0]]), token_weight_pair)) + + +class SD1Tokenizer: + def __init__(self, embedding_directory=None, clip_name="l", tokenizer=SDTokenizer): + self.clip_name = clip_name + self.clip = "clip_{}".format(self.clip_name) + setattr(self, self.clip, tokenizer(embedding_directory=embedding_directory)) + + def tokenize_with_weights(self, text:str, return_word_ids=False): + out = {} + out[self.clip_name] = getattr(self, self.clip).tokenize_with_weights(text, return_word_ids) + return out + + def untokenize(self, token_weight_pair): + return getattr(self, self.clip).untokenize(token_weight_pair) + + +class SD1ClipModel(torch.nn.Module): + def __init__(self, device="cpu", dtype=None, clip_name="l", clip_model=SDClipModel, **kwargs): + super().__init__() + self.clip_name = clip_name + self.clip = "clip_{}".format(self.clip_name) + setattr(self, self.clip, clip_model(device=device, dtype=dtype, **kwargs)) + + def clip_layer(self, layer_idx): + getattr(self, self.clip).clip_layer(layer_idx) + + def reset_clip_layer(self): + getattr(self, self.clip).reset_clip_layer() + + def encode_token_weights(self, token_weight_pairs): + token_weight_pairs = token_weight_pairs[self.clip_name] + out, pooled = getattr(self, self.clip).encode_token_weights(token_weight_pairs) + return out, pooled + + def load_sd(self, sd): + return getattr(self, self.clip).load_sd(sd) diff --git a/backend/headless/fcbh/sd2_clip.py b/backend/headless/fcbh/sd2_clip.py index e5cac64b..052fe9ba 100644 --- a/backend/headless/fcbh/sd2_clip.py +++ b/backend/headless/fcbh/sd2_clip.py @@ -2,7 +2,7 @@ from fcbh import sd1_clip import torch import os -class SD2ClipModel(sd1_clip.SD1ClipModel): +class SD2ClipHModel(sd1_clip.SDClipModel): def __init__(self, arch="ViT-H-14", device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, textmodel_path=None, dtype=None): if layer == "penultimate": layer="hidden" @@ -12,6 +12,14 @@ class SD2ClipModel(sd1_clip.SD1ClipModel): super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, textmodel_path=textmodel_path, dtype=dtype) self.empty_tokens = [[49406] + [49407] + [0] * 75] -class SD2Tokenizer(sd1_clip.SD1Tokenizer): +class SD2ClipHTokenizer(sd1_clip.SDTokenizer): def __init__(self, tokenizer_path=None, embedding_directory=None): super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=1024) + +class SD2Tokenizer(sd1_clip.SD1Tokenizer): + def __init__(self, embedding_directory=None): + super().__init__(embedding_directory=embedding_directory, clip_name="h", tokenizer=SD2ClipHTokenizer) + +class SD2ClipModel(sd1_clip.SD1ClipModel): + def __init__(self, device="cpu", dtype=None, **kwargs): + super().__init__(device=device, dtype=dtype, clip_name="h", clip_model=SD2ClipHModel, **kwargs) diff --git a/backend/headless/fcbh/sdxl_clip.py b/backend/headless/fcbh/sdxl_clip.py index 2064ba41..b05005c4 100644 --- a/backend/headless/fcbh/sdxl_clip.py +++ b/backend/headless/fcbh/sdxl_clip.py @@ -2,7 +2,7 @@ from fcbh import sd1_clip import torch import os -class SDXLClipG(sd1_clip.SD1ClipModel): +class SDXLClipG(sd1_clip.SDClipModel): def __init__(self, device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, textmodel_path=None, dtype=None): if layer == "penultimate": layer="hidden" @@ -16,14 +16,14 @@ class SDXLClipG(sd1_clip.SD1ClipModel): def load_sd(self, sd): return super().load_sd(sd) -class SDXLClipGTokenizer(sd1_clip.SD1Tokenizer): +class SDXLClipGTokenizer(sd1_clip.SDTokenizer): def __init__(self, tokenizer_path=None, embedding_directory=None): super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=1280, embedding_key='clip_g') -class SDXLTokenizer(sd1_clip.SD1Tokenizer): +class SDXLTokenizer: def __init__(self, embedding_directory=None): - self.clip_l = sd1_clip.SD1Tokenizer(embedding_directory=embedding_directory) + self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory) self.clip_g = SDXLClipGTokenizer(embedding_directory=embedding_directory) def tokenize_with_weights(self, text:str, return_word_ids=False): @@ -38,7 +38,7 @@ class SDXLTokenizer(sd1_clip.SD1Tokenizer): class SDXLClipModel(torch.nn.Module): def __init__(self, device="cpu", dtype=None): super().__init__() - self.clip_l = sd1_clip.SD1ClipModel(layer="hidden", layer_idx=11, device=device, dtype=dtype) + self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=11, device=device, dtype=dtype) self.clip_l.layer_norm_hidden_state = False self.clip_g = SDXLClipG(device=device, dtype=dtype) @@ -63,21 +63,6 @@ class SDXLClipModel(torch.nn.Module): else: return self.clip_l.load_sd(sd) -class SDXLRefinerClipModel(torch.nn.Module): +class SDXLRefinerClipModel(sd1_clip.SD1ClipModel): def __init__(self, device="cpu", dtype=None): - super().__init__() - self.clip_g = SDXLClipG(device=device, dtype=dtype) - - def clip_layer(self, layer_idx): - self.clip_g.clip_layer(layer_idx) - - def reset_clip_layer(self): - self.clip_g.reset_clip_layer() - - def encode_token_weights(self, token_weight_pairs): - token_weight_pairs_g = token_weight_pairs["g"] - g_out, g_pooled = self.clip_g.encode_token_weights(token_weight_pairs_g) - return g_out, g_pooled - - def load_sd(self, sd): - return self.clip_g.load_sd(sd) + super().__init__(device=device, dtype=dtype, clip_name="g", clip_model=SDXLClipG) diff --git a/backend/headless/fcbh/supported_models.py b/backend/headless/fcbh/supported_models.py index bb8ae214..fdd4ea4f 100644 --- a/backend/headless/fcbh/supported_models.py +++ b/backend/headless/fcbh/supported_models.py @@ -38,8 +38,15 @@ class SD15(supported_models_base.BASE): if ids.dtype == torch.float32: state_dict['cond_stage_model.transformer.text_model.embeddings.position_ids'] = ids.round() + replace_prefix = {} + replace_prefix["cond_stage_model."] = "cond_stage_model.clip_l." + state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix) return state_dict + def process_clip_state_dict_for_saving(self, state_dict): + replace_prefix = {"clip_l.": "cond_stage_model."} + return utils.state_dict_prefix_replace(state_dict, replace_prefix) + def clip_target(self): return supported_models_base.ClipTarget(sd1_clip.SD1Tokenizer, sd1_clip.SD1ClipModel) @@ -62,12 +69,12 @@ class SD20(supported_models_base.BASE): return model_base.ModelType.EPS def process_clip_state_dict(self, state_dict): - state_dict = utils.transformers_convert(state_dict, "cond_stage_model.model.", "cond_stage_model.transformer.text_model.", 24) + state_dict = utils.transformers_convert(state_dict, "cond_stage_model.model.", "cond_stage_model.clip_h.transformer.text_model.", 24) return state_dict def process_clip_state_dict_for_saving(self, state_dict): replace_prefix = {} - replace_prefix[""] = "cond_stage_model.model." + replace_prefix["clip_h"] = "cond_stage_model.model" state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix) state_dict = diffusers_convert.convert_text_enc_state_dict_v20(state_dict) return state_dict @@ -104,7 +111,7 @@ class SDXLRefiner(supported_models_base.BASE): "use_linear_in_transformer": True, "context_dim": 1280, "adm_in_channels": 2560, - "transformer_depth": [0, 4, 4, 0], + "transformer_depth": [0, 0, 4, 4, 4, 4, 0, 0], } latent_format = latent_formats.SDXL @@ -139,7 +146,7 @@ class SDXL(supported_models_base.BASE): unet_config = { "model_channels": 320, "use_linear_in_transformer": True, - "transformer_depth": [0, 2, 10], + "transformer_depth": [0, 0, 2, 2, 10, 10], "context_dim": 2048, "adm_in_channels": 2816 } @@ -165,6 +172,7 @@ class SDXL(supported_models_base.BASE): replace_prefix["conditioner.embedders.0.transformer.text_model"] = "cond_stage_model.clip_l.transformer.text_model" state_dict = utils.transformers_convert(state_dict, "conditioner.embedders.1.model.", "cond_stage_model.clip_g.transformer.text_model.", 32) keys_to_replace["conditioner.embedders.1.model.text_projection"] = "cond_stage_model.clip_g.text_projection" + keys_to_replace["conditioner.embedders.1.model.text_projection.weight"] = "cond_stage_model.clip_g.text_projection" keys_to_replace["conditioner.embedders.1.model.logit_scale"] = "cond_stage_model.clip_g.logit_scale" state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix) @@ -189,5 +197,14 @@ class SDXL(supported_models_base.BASE): def clip_target(self): return supported_models_base.ClipTarget(sdxl_clip.SDXLTokenizer, sdxl_clip.SDXLClipModel) +class SSD1B(SDXL): + unet_config = { + "model_channels": 320, + "use_linear_in_transformer": True, + "transformer_depth": [0, 0, 2, 2, 4, 4], + "context_dim": 2048, + "adm_in_channels": 2816 + } -models = [SD15, SD20, SD21UnclipL, SD21UnclipH, SDXLRefiner, SDXL] + +models = [SD15, SD20, SD21UnclipL, SD21UnclipH, SDXLRefiner, SDXL, SSD1B] diff --git a/backend/headless/fcbh/utils.py b/backend/headless/fcbh/utils.py index 2f50c821..5a694b14 100644 --- a/backend/headless/fcbh/utils.py +++ b/backend/headless/fcbh/utils.py @@ -170,25 +170,12 @@ UNET_MAP_BASIC = { def unet_to_diffusers(unet_config): num_res_blocks = unet_config["num_res_blocks"] - attention_resolutions = unet_config["attention_resolutions"] channel_mult = unet_config["channel_mult"] - transformer_depth = unet_config["transformer_depth"] + transformer_depth = unet_config["transformer_depth"][:] + transformer_depth_output = unet_config["transformer_depth_output"][:] num_blocks = len(channel_mult) - if isinstance(num_res_blocks, int): - num_res_blocks = [num_res_blocks] * num_blocks - if isinstance(transformer_depth, int): - transformer_depth = [transformer_depth] * num_blocks - transformers_per_layer = [] - res = 1 - for i in range(num_blocks): - transformers = 0 - if res in attention_resolutions: - transformers = transformer_depth[i] - transformers_per_layer.append(transformers) - res *= 2 - - transformers_mid = unet_config.get("transformer_depth_middle", transformer_depth[-1]) + transformers_mid = unet_config.get("transformer_depth_middle", None) diffusers_unet_map = {} for x in range(num_blocks): @@ -196,10 +183,11 @@ def unet_to_diffusers(unet_config): for i in range(num_res_blocks[x]): for b in UNET_MAP_RESNET: diffusers_unet_map["down_blocks.{}.resnets.{}.{}".format(x, i, UNET_MAP_RESNET[b])] = "input_blocks.{}.0.{}".format(n, b) - if transformers_per_layer[x] > 0: + num_transformers = transformer_depth.pop(0) + if num_transformers > 0: for b in UNET_MAP_ATTENTIONS: diffusers_unet_map["down_blocks.{}.attentions.{}.{}".format(x, i, b)] = "input_blocks.{}.1.{}".format(n, b) - for t in range(transformers_per_layer[x]): + for t in range(num_transformers): for b in TRANSFORMER_BLOCKS: diffusers_unet_map["down_blocks.{}.attentions.{}.transformer_blocks.{}.{}".format(x, i, t, b)] = "input_blocks.{}.1.transformer_blocks.{}.{}".format(n, t, b) n += 1 @@ -218,7 +206,6 @@ def unet_to_diffusers(unet_config): diffusers_unet_map["mid_block.resnets.{}.{}".format(i, UNET_MAP_RESNET[b])] = "middle_block.{}.{}".format(n, b) num_res_blocks = list(reversed(num_res_blocks)) - transformers_per_layer = list(reversed(transformers_per_layer)) for x in range(num_blocks): n = (num_res_blocks[x] + 1) * x l = num_res_blocks[x] + 1 @@ -227,11 +214,12 @@ def unet_to_diffusers(unet_config): for b in UNET_MAP_RESNET: diffusers_unet_map["up_blocks.{}.resnets.{}.{}".format(x, i, UNET_MAP_RESNET[b])] = "output_blocks.{}.0.{}".format(n, b) c += 1 - if transformers_per_layer[x] > 0: + num_transformers = transformer_depth_output.pop() + if num_transformers > 0: c += 1 for b in UNET_MAP_ATTENTIONS: diffusers_unet_map["up_blocks.{}.attentions.{}.{}".format(x, i, b)] = "output_blocks.{}.1.{}".format(n, b) - for t in range(transformers_per_layer[x]): + for t in range(num_transformers): for b in TRANSFORMER_BLOCKS: diffusers_unet_map["up_blocks.{}.attentions.{}.transformer_blocks.{}.{}".format(x, i, t, b)] = "output_blocks.{}.1.transformer_blocks.{}.{}".format(n, t, b) if i == l - 1: diff --git a/fooocus_version.py b/fooocus_version.py index 68d233e4..d534bf7a 100644 --- a/fooocus_version.py +++ b/fooocus_version.py @@ -1 +1 @@ -version = '2.1.752' +version = '2.1.753'