diff --git a/fooocus_version.py b/fooocus_version.py
index 709af32d..e1578ebb 100644
--- a/fooocus_version.py
+++ b/fooocus_version.py
@@ -1 +1 @@
-version = '2.1.848'
+version = '2.1.849'
diff --git a/modules/patch.py b/modules/patch.py
index 6a7111a6..66b243cb 100644
--- a/modules/patch.py
+++ b/modules/patch.py
@@ -271,12 +271,11 @@ def sdxl_encode_adm_patched(self, **kwargs):
         height = float(height) * positive_adm_scale

     def embedder(number_list):
-        h = [self.embedder(torch.tensor([x], dtype=torch.float32)) for x in number_list]
-        h = torch.cat(h)
+        h = self.embedder(torch.tensor(number_list, dtype=torch.float32))
         h = torch.flatten(h).unsqueeze(dim=0).repeat(clip_pooled.shape[0], 1)
         return h

-    width, height = round_to_64(width), round_to_64(height)
+    width, height = int(width), int(height)
     target_width, target_height = round_to_64(target_width), round_to_64(target_height)

     adm_emphasized = embedder([height, width, 0, 0, target_height, target_width])
diff --git a/modules/patch_clip.py b/modules/patch_clip.py
index 4a1e0307..0ef22e8b 100644
--- a/modules/patch_clip.py
+++ b/modules/patch_clip.py
@@ -63,172 +63,94 @@ def encode_token_weights_fooocus(self, token_weight_pairs):
     return torch.cat(output, dim=-2).to(ldm_patched.modules.model_management.intermediate_device()), first_pooled


-class SDClipModelFooocus(torch.nn.Module, ldm_patched.modules.sd1_clip.ClipTokenWeightEncoder):
-    """Uses the CLIP transformer encoder for text (from huggingface)"""
-    LAYERS = [
-        "last",
-        "pooled",
-        "hidden"
-    ]
+def patched_SDClipModel__init__(self, max_length=77, freeze=True, layer="last", layer_idx=None,
+                                textmodel_json_config=None, dtype=None, special_tokens=None,
+                                layer_norm_hidden_state=True, **kwargs):
+    torch.nn.Module.__init__(self)
+    assert layer in self.LAYERS

-    def __init__(self,
-                 max_length=77,
-                 freeze=True,
-                 layer="last",
-                 layer_idx=None,
-                 textmodel_json_config=None,
-                 dtype=None,
-                 special_tokens=None,
-                 layer_norm_hidden_state=True,
-                 **kwargs):
-        super().__init__()
-        assert layer in self.LAYERS
+    if special_tokens is None:
+        special_tokens = {"start": 49406, "end": 49407, "pad": 49407}

-        if special_tokens is None:
-            special_tokens = {"start": 49406, "end": 49407, "pad": 49407}
+    if textmodel_json_config is None:
+        textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(ldm_patched.modules.sd1_clip.__file__)),
+                                             "sd1_clip_config.json")

-        if textmodel_json_config is None:
-            textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(ldm_patched.modules.sd1_clip.__file__)), "sd1_clip_config.json")
+    config = CLIPTextConfig.from_json_file(textmodel_json_config)
+    self.num_layers = config.num_hidden_layers

-        config = CLIPTextConfig.from_json_file(textmodel_json_config)
-        self.num_layers = config.num_hidden_layers
+    with modeling_utils.no_init_weights():
+        self.transformer = CLIPTextModel(config)

-        with modeling_utils.no_init_weights():
-            self.transformer = CLIPTextModel(config)
+    if 'cuda' not in model_management.text_encoder_device().type:
+        dtype = torch.float32

-        if 'cuda' not in model_management.text_encoder_device().type:
-            dtype = torch.float32
+    if dtype is not None:
+        self.transformer.to(dtype)
+        self.transformer.text_model.embeddings.to(torch.float32)

-        if dtype is not None:
-            self.transformer.to(dtype)
-            self.transformer.text_model.embeddings.to(torch.float32)
+    if freeze:
+        self.freeze()

-        if freeze:
-            self.freeze()
+    self.max_length = max_length
+    self.layer = layer
+    self.layer_idx = None
+    self.special_tokens = special_tokens
+    self.text_projection = torch.nn.Parameter(torch.eye(self.transformer.get_input_embeddings().weight.shape[1]))
+    self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055))
+    self.enable_attention_masks = False

-        self.max_length = max_length
-        self.layer = layer
-        self.layer_idx = None
-        self.special_tokens = special_tokens
-        self.text_projection = torch.nn.Parameter(torch.eye(self.transformer.get_input_embeddings().weight.shape[1]))
-        self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055))
-        self.enable_attention_masks = False
+    self.layer_norm_hidden_state = layer_norm_hidden_state
+    if layer == "hidden":
+        assert layer_idx is not None
+        assert abs(layer_idx) < self.num_layers
+        self.clip_layer(layer_idx)
+    self.layer_default = (self.layer, self.layer_idx)

-        self.layer_norm_hidden_state = layer_norm_hidden_state
-        if layer == "hidden":
-            assert layer_idx is not None
-            self.clip_layer(layer_idx)
-        self.layer_default = (self.layer, self.layer_idx)

-    def freeze(self):
-        self.transformer = self.transformer.eval()
-        # self.train = disabled_train
-        for param in self.parameters():
-            param.requires_grad = False
+def patched_SDClipModel_forward(self, tokens):
+    backup_embeds = self.transformer.get_input_embeddings()
+    device = backup_embeds.weight.device
+    tokens = self.set_up_textual_embeddings(tokens, backup_embeds)
+    tokens = torch.LongTensor(tokens).to(device)

-    def clip_layer(self, layer_idx):
-        self.layer = "hidden"
-        self.layer_idx = layer_idx
+    if self.transformer.text_model.final_layer_norm.weight.dtype != torch.float32:
+        precision_scope = torch.autocast
+    else:
+        precision_scope = lambda a, dtype: contextlib.nullcontext(a)

-    def reset_clip_layer(self):
-        self.layer = self.layer_default[0]
-        self.layer_idx = self.layer_default[1]
+    with precision_scope(model_management.get_autocast_device(device), dtype=torch.float32):
+        attention_mask = None
+        if self.enable_attention_masks:
+            attention_mask = torch.zeros_like(tokens)
+            max_token = self.transformer.get_input_embeddings().weight.shape[0] - 1
+            for x in range(attention_mask.shape[0]):
+                for y in range(attention_mask.shape[1]):
+                    attention_mask[x, y] = 1
+                    if tokens[x, y] == max_token:
+                        break

-    def set_up_textual_embeddings(self, tokens, current_embeds):
-        out_tokens = []
-        next_new_token = token_dict_size = current_embeds.weight.shape[0] - 1
-        embedding_weights = []
+        outputs = self.transformer(input_ids=tokens, attention_mask=attention_mask,
+                                   output_hidden_states=self.layer == "hidden")
+        self.transformer.set_input_embeddings(backup_embeds)

-        for x in tokens:
-            tokens_temp = []
-            for y in x:
-                if isinstance(y, int):
-                    if y == token_dict_size:  # EOS token
-                        y = -1
-                    tokens_temp += [y]
-                else:
-                    if y.shape[0] == current_embeds.weight.shape[1]:
-                        embedding_weights += [y]
-                        tokens_temp += [next_new_token]
-                        next_new_token += 1
-                    else:
-                        print("WARNING: shape mismatch when trying to apply embedding, embedding will be ignored",
-                              y.shape[0], current_embeds.weight.shape[1])
-            while len(tokens_temp) < len(x):
-                tokens_temp += [self.special_tokens["pad"]]
-            out_tokens += [tokens_temp]
-
-        n = token_dict_size
-        if len(embedding_weights) > 0:
-            new_embedding = torch.nn.Embedding(next_new_token + 1, current_embeds.weight.shape[1],
-                                               device=current_embeds.weight.device, dtype=current_embeds.weight.dtype)
-            new_embedding.weight[:token_dict_size] = current_embeds.weight[:-1]
-            for x in embedding_weights:
-                new_embedding.weight[n] = x
-                n += 1
-            new_embedding.weight[n] = current_embeds.weight[-1]  # EOS embedding
-            self.transformer.set_input_embeddings(new_embedding)
-
-        processed_tokens = []
-        for x in out_tokens:
-            processed_tokens += [
-                list(map(lambda a: n if a == -1 else a, x))]  # The EOS token should always be the largest one
-
-        return processed_tokens
-
-    def forward(self, tokens):
-        backup_embeds = self.transformer.get_input_embeddings()
-        device = backup_embeds.weight.device
-        tokens = self.set_up_textual_embeddings(tokens, backup_embeds)
-        tokens = torch.LongTensor(tokens).to(device)
-
-        if self.transformer.text_model.final_layer_norm.weight.dtype != torch.float32:
-            precision_scope = torch.autocast
+        if self.layer == "last":
+            z = outputs.last_hidden_state
+        elif self.layer == "pooled":
+            z = outputs.pooler_output[:, None, :]
         else:
-            precision_scope = lambda a, dtype: contextlib.nullcontext(a)
+            z = outputs.hidden_states[self.layer_idx]
+            if self.layer_norm_hidden_state:
+                z = self.transformer.text_model.final_layer_norm(z)

-        with precision_scope(model_management.get_autocast_device(device), dtype=torch.float32):
-            attention_mask = None
-            if self.enable_attention_masks:
-                attention_mask = torch.zeros_like(tokens)
-                max_token = self.transformer.get_input_embeddings().weight.shape[0] - 1
-                for x in range(attention_mask.shape[0]):
-                    for y in range(attention_mask.shape[1]):
-                        attention_mask[x, y] = 1
-                        if tokens[x, y] == max_token:
-                            break
+        if hasattr(outputs, "pooler_output"):
+            pooled_output = outputs.pooler_output.float()
+        else:
+            pooled_output = None

-            outputs = self.transformer(input_ids=tokens, attention_mask=attention_mask,
-                                       output_hidden_states=self.layer == "hidden")
-            self.transformer.set_input_embeddings(backup_embeds)
-
-            if self.layer == "last":
-                z = outputs.last_hidden_state
-            elif self.layer == "pooled":
-                z = outputs.pooler_output[:, None, :]
-            else:
-                z = outputs.hidden_states[self.layer_idx]
-                if self.layer_norm_hidden_state:
-                    z = self.transformer.text_model.final_layer_norm(z)
-
-            if hasattr(outputs, "pooler_output"):
-                pooled_output = outputs.pooler_output.float()
-            else:
-                pooled_output = None
-
-            if self.text_projection is not None and pooled_output is not None:
-                pooled_output = pooled_output.float().to(self.text_projection.device) @ self.text_projection.float()
-        return z.float(), pooled_output
-
-    def encode(self, tokens):
-        return self(tokens)
-
-    def load_sd(self, sd):
-        if "text_projection" in sd:
-            self.text_projection[:] = sd.pop("text_projection")
-        if "text_projection.weight" in sd:
-            self.text_projection[:] = sd.pop("text_projection.weight").transpose(0, 1)
-        return self.transformer.load_state_dict(sd, strict=False)
+        if self.text_projection is not None and pooled_output is not None:
+            pooled_output = pooled_output.float().to(self.text_projection.device) @ self.text_projection.float()
+    return z.float(), pooled_output


 class ClipVisionModelFooocus:
@@ -262,6 +184,7 @@ class ClipVisionModelFooocus:

 def patch_all_clip():
     ldm_patched.modules.sd1_clip.ClipTokenWeightEncoder.encode_token_weights = encode_token_weights_fooocus
-    ldm_patched.modules.sd1_clip.SDClipModel = SDClipModelFooocus
+    ldm_patched.modules.sd1_clip.SDClipModel.__init__ = patched_SDClipModel__init__
+    ldm_patched.modules.sd1_clip.SDClipModel.forward = patched_SDClipModel_forward
     ldm_patched.modules.clip_vision.ClipVisionModel = ClipVisionModelFooocus
     return
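
A note on the `modules/patch.py` hunk: `embedder` previously called `self.embedder` once per value and concatenated the six per-value results, while the patched version passes all six values as a single 1-D tensor in one call. The sketch below illustrates why the two paths produce the same `[6, 256]` tensor. It uses a stand-in sinusoidal embedder, which is an assumption for illustration only; the real `self.embedder` (SDXL's `Timestep` embedding) may order its sin/cos halves differently, but that does not affect the batched-vs-looped comparison.

```python
import torch

def timestep_embedding(t: torch.Tensor, dim: int = 256) -> torch.Tensor:
    # Stand-in sinusoidal embedder, shape [N] -> [N, dim]; hypothetical, not the
    # actual SDXL Timestep module.
    half = dim // 2
    freqs = torch.exp(-torch.log(torch.tensor(10000.0)) * torch.arange(half) / half)
    args = t.float()[:, None] * freqs[None, :]
    return torch.cat([torch.cos(args), torch.sin(args)], dim=-1)

number_list = [1024.0, 1024.0, 0.0, 0.0, 1024.0, 1024.0]  # height, width, crop, target sizes

# Old path: one embedder call per number, then concatenate -> [6, 256]
h_old = torch.cat([timestep_embedding(torch.tensor([x], dtype=torch.float32)) for x in number_list])

# New path: a single call on the whole list -> [6, 256]
h_new = timestep_embedding(torch.tensor(number_list, dtype=torch.float32))

assert torch.allclose(h_old, h_new)  # same values, one forward pass instead of six
```

The surrounding flatten/unsqueeze/repeat logic is unchanged, so the only effect is fewer module calls per ADM conditioning.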
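
The second `patch.py` change stops snapping the emphasized width/height to the 64-pixel grid and only casts them to `int`, while the target size is still passed through `round_to_64`. A tiny sketch of the behavioral difference, using a hypothetical `round_to_64` (the real helper in `modules/patch.py` may be implemented differently):

```python
# Hypothetical rounding helper for illustration; not the actual Fooocus code.
def round_to_64(x: float) -> int:
    return int(round(x / 64.0) * 64)

width = 1152 * 1.1  # e.g. a width after applying a positive ADM scale of 1.1

print(round_to_64(width))  # 1280 -> old behaviour: snapped to a multiple of 64
print(int(width))          # 1267 -> new behaviour: the scaled size is kept as-is
```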
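
On the `patch_clip.py` side, the standalone `SDClipModelFooocus` class is removed; `patch_all_clip()` now assigns `patched_SDClipModel__init__` and `patched_SDClipModel_forward` directly onto `ldm_patched.modules.sd1_clip.SDClipModel`. A minimal, self-contained sketch of that monkey-patching pattern, with stand-in classes rather than the ldm_patched API:

```python
# Stand-ins playing the roles of SDClipModel and a subclass of it; illustration only.
class BaseTextEncoder:
    def forward(self, tokens):
        return f"original({tokens})"

class BigTextEncoder(BaseTextEncoder):  # e.g. an SDXL-style subclass
    pass

def patched_forward(self, tokens):
    # `self` is bound at call time, so the replacement applies to the base class,
    # to every subclass that does not override forward, and to existing instances.
    return f"patched({tokens})"

BaseTextEncoder.forward = patched_forward

assert BaseTextEncoder().forward("abc") == "patched(abc)"
assert BigTextEncoder().forward("abc") == "patched(abc)"
```

Compared with rebinding the whole class (the old `SDClipModel = SDClipModelFooocus` assignment), patching the two methods leaves subclasses and the remaining helpers (`freeze`, `clip_layer`, `set_up_textual_embeddings`, `load_sd`) on the original class untouched, which is also why `patched_SDClipModel__init__` can still rely on class attributes such as `LAYERS` and calls `torch.nn.Module.__init__(self)` instead of `super().__init__()`.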