diff --git a/backend.py b/backend.py
index 6fc2dd5..e7b7cc3 100644
--- a/backend.py
+++ b/backend.py
@@ -89,10 +89,10 @@ def backend(model, is_debugging: bool):
             negative_prompt = next_job[KEY_NEG_PROMPT]

             if KEY_LANGUAGE in next_job:
-                logger.info(
-                    f"found {next_job[KEY_LANGUAGE]}, translate prompt and negative prompt first"
-                )
                 if VALUE_LANGUAGE_EN != next_job[KEY_LANGUAGE]:
+                    logger.info(
+                        f"found {next_job[KEY_LANGUAGE]}, translate prompt and negative prompt first"
+                    )
                     prompt_en = translate_prompt(prompt, next_job[KEY_LANGUAGE])
                     logger.info(f"translated {prompt} to {prompt_en}")
                     prompt = prompt_en
diff --git a/utilities/img2img.py b/utilities/img2img.py
index aa93873..19ee0e2 100644
--- a/utilities/img2img.py
+++ b/utilities/img2img.py
@@ -1,4 +1,5 @@
 import torch
+import re
 from typing import Union

 from PIL import Image
@@ -37,7 +38,61 @@ class Img2Img:
         self.lunch(prompt, negative_prompt)

     def breakfast(self):
-        pass
+        self.__max_length = self.model.img2img_pipeline.tokenizer.model_max_length
+        self.__logger.info(f"model has max length of {self.__max_length}")
+
+    def __token_limit_workaround(self, prompt: str, negative_prompt: str = ""):
+        count_prompt = len(re.split("[ ,]+", prompt))
+        count_negative_prompt = len(re.split("[ ,]+", negative_prompt))
+
+        if count_prompt < 77 and count_negative_prompt < 77:
+            return prompt, None, negative_prompt, None
+
+        self.__logger.info(
+            "using workaround to generate embeds instead of direct string"
+        )
+
+        if count_prompt >= count_negative_prompt:
+            input_ids = self.model.img2img_pipeline.tokenizer(
+                prompt, return_tensors="pt", truncation=False
+            ).input_ids.to(self.__device)
+            shape_max_length = input_ids.shape[-1]
+            negative_ids = self.model.img2img_pipeline.tokenizer(
+                negative_prompt,
+                truncation=False,
+                padding="max_length",
+                max_length=shape_max_length,
+                return_tensors="pt",
+            ).input_ids.to(self.__device)
+
+        else:
+            negative_ids = self.model.img2img_pipeline.tokenizer(
+                negative_prompt, return_tensors="pt", truncation=False
+            ).input_ids.to(self.__device)
+            shape_max_length = negative_ids.shape[-1]
+            input_ids = self.model.img2img_pipeline.tokenizer(
+                prompt,
+                return_tensors="pt",
+                truncation=False,
+                padding="max_length",
+                max_length=shape_max_length,
+            ).input_ids.to(self.__device)
+
+        concat_embeds = []
+        neg_embeds = []
+        for i in range(0, shape_max_length, self.__max_length):
+            concat_embeds.append(
+                self.model.img2img_pipeline.text_encoder(
+                    input_ids[:, i : i + self.__max_length]
+                )[0]
+            )
+            neg_embeds.append(
+                self.model.img2img_pipeline.text_encoder(
+                    negative_ids[:, i : i + self.__max_length]
+                )[0]
+            )
+
+        return None, torch.cat(concat_embeds, dim=1), None, torch.cat(neg_embeds, dim=1)

     def lunch(
         self,
@@ -63,10 +118,19 @@ class Img2Img:
         reference_image = base64_to_image(reference_image).convert("RGB")
         reference_image.thumbnail((config.get_width(), config.get_height()))

+        (
+            prompt,
+            prompt_embeds,
+            negative_prompt,
+            negative_prompt_embeds,
+        ) = self.__token_limit_workaround(prompt, negative_prompt)
+
         result = self.model.img2img_pipeline(
             prompt=prompt,
-            image=reference_image,
             negative_prompt=negative_prompt,
+            prompt_embeds=prompt_embeds,
+            negative_prompt_embeds=negative_prompt_embeds,
+            image=reference_image,
             guidance_scale=config.get_guidance_scale(),
             strength=config.get_strength(),
             num_inference_steps=config.get_steps(),
diff --git a/utilities/inpainting.py b/utilities/inpainting.py
index aa08ec5..632c87a 100644
--- a/utilities/inpainting.py
+++ b/utilities/inpainting.py
@@ -1,4 +1,5 @@
 import torch
+import re
 from typing import Union

 from PIL import Image
@@ -37,7 +38,61 @@ class Inpainting:
         self.lunch(prompt, negative_prompt)

     def breakfast(self):
-        pass
+        self.__max_length = self.model.inpaint_pipeline.tokenizer.model_max_length
+        self.__logger.info(f"model has max length of {self.__max_length}")
+
+    def __token_limit_workaround(self, prompt: str, negative_prompt: str = ""):
+        count_prompt = len(re.split("[ ,]+", prompt))
+        count_negative_prompt = len(re.split("[ ,]+", negative_prompt))
+
+        if count_prompt < 77 and count_negative_prompt < 77:
+            return prompt, None, negative_prompt, None
+
+        self.__logger.info(
+            "using workaround to generate embeds instead of direct string"
+        )
+
+        if count_prompt >= count_negative_prompt:
+            input_ids = self.model.inpaint_pipeline.tokenizer(
+                prompt, return_tensors="pt", truncation=False
+            ).input_ids.to(self.__device)
+            shape_max_length = input_ids.shape[-1]
+            negative_ids = self.model.inpaint_pipeline.tokenizer(
+                negative_prompt,
+                truncation=False,
+                padding="max_length",
+                max_length=shape_max_length,
+                return_tensors="pt",
+            ).input_ids.to(self.__device)
+
+        else:
+            negative_ids = self.model.inpaint_pipeline.tokenizer(
+                negative_prompt, return_tensors="pt", truncation=False
+            ).input_ids.to(self.__device)
+            shape_max_length = negative_ids.shape[-1]
+            input_ids = self.model.inpaint_pipeline.tokenizer(
+                prompt,
+                return_tensors="pt",
+                truncation=False,
+                padding="max_length",
+                max_length=shape_max_length,
+            ).input_ids.to(self.__device)
+
+        concat_embeds = []
+        neg_embeds = []
+        for i in range(0, shape_max_length, self.__max_length):
+            concat_embeds.append(
+                self.model.inpaint_pipeline.text_encoder(
+                    input_ids[:, i : i + self.__max_length]
+                )[0]
+            )
+            neg_embeds.append(
+                self.model.inpaint_pipeline.text_encoder(
+                    negative_ids[:, i : i + self.__max_length]
+                )[0]
+            )
+
+        return None, torch.cat(concat_embeds, dim=1), None, torch.cat(neg_embeds, dim=1)

     def lunch(
         self,
@@ -72,13 +127,28 @@ class Inpainting:
         if mask_image.size[0] < reference_image.size[0]:
             mask_image = mask_image.resize(reference_image.size)
         elif mask_image.size[0] > reference_image.size[0]:
-            mask_image = mask_image.resize(reference_image.size, resample=Image.LANCZOS)
+            mask_image = mask_image.resize(
+                reference_image.size, resample=Image.LANCZOS
+            )
+
+        (
+            prompt,
+            prompt_embeds,
+            negative_prompt,
+            negative_prompt_embeds,
+        ) = self.__token_limit_workaround(prompt, negative_prompt)

         result = self.model.inpaint_pipeline(
             prompt=prompt,
-            image=reference_image.resize((512, 512)),  # must use size 512 for inpaint model
-            mask_image=mask_image.convert("L").resize((512, 512)),  # must use size 512 for inpaint model
             negative_prompt=negative_prompt,
+            prompt_embeds=prompt_embeds,
+            negative_prompt_embeds=negative_prompt_embeds,
+            image=reference_image.resize(
+                (512, 512)
+            ),  # must use size 512 for inpaint model
+            mask_image=mask_image.convert("L").resize(
+                (512, 512)
+            ),  # must use size 512 for inpaint model
             guidance_scale=config.get_guidance_scale(),
             num_inference_steps=config.get_steps(),
             generator=generator,
@@ -87,7 +157,9 @@ class Inpainting:
         )

         # resize it back based on ratio (keep width 512)
-        result_img = result.images[0].resize((512, int(512 * reference_image.size[1] / reference_image.size[0])))
+        result_img = result.images[0].resize(
+            (512, int(512 * reference_image.size[1] / reference_image.size[0]))
+        )

         if self.__output_folder:
             out_filepath = "{}/{}.png".format(self.__output_folder, t)
diff --git a/utilities/text2img.py b/utilities/text2img.py
index 5f23c76..d058756 100644
--- a/utilities/text2img.py
+++ b/utilities/text2img.py
@@ -1,4 +1,5 @@
 import torch
+import re
 from typing import Union

 from utilities.constants import BASE64IMAGE
@@ -35,11 +36,69 @@ class Text2Img:
         self.lunch(prompt, negative_prompt)

     def breakfast(self):
-        pass
+        self.__max_length = self.model.txt2img_pipeline.tokenizer.model_max_length
+        self.__logger.info(f"model has max length of {self.__max_length}")
+
+    def __token_limit_workaround(self, prompt: str, negative_prompt: str = ""):
+        count_prompt = len(re.split("[ ,]+", prompt))
+        count_negative_prompt = len(re.split("[ ,]+", negative_prompt))
+
+        if count_prompt < 77 and count_negative_prompt < 77:
+            return prompt, None, negative_prompt, None
+
+        self.__logger.info(
+            "using workaround to generate embeds instead of direct string"
+        )
+
+        if count_prompt >= count_negative_prompt:
+            input_ids = self.model.txt2img_pipeline.tokenizer(
+                prompt, return_tensors="pt", truncation=False
+            ).input_ids.to(self.__device)
+            shape_max_length = input_ids.shape[-1]
+            negative_ids = self.model.txt2img_pipeline.tokenizer(
+                negative_prompt,
+                truncation=False,
+                padding="max_length",
+                max_length=shape_max_length,
+                return_tensors="pt",
+            ).input_ids.to(self.__device)
+
+        else:
+            negative_ids = self.model.txt2img_pipeline.tokenizer(
+                negative_prompt, return_tensors="pt", truncation=False
+            ).input_ids.to(self.__device)
+            shape_max_length = negative_ids.shape[-1]
+            input_ids = self.model.txt2img_pipeline.tokenizer(
+                prompt,
+                return_tensors="pt",
+                truncation=False,
+                padding="max_length",
+                max_length=shape_max_length,
+            ).input_ids.to(self.__device)
+
+        concat_embeds = []
+        neg_embeds = []
+        for i in range(0, shape_max_length, self.__max_length):
+            concat_embeds.append(
+                self.model.txt2img_pipeline.text_encoder(
+                    input_ids[:, i : i + self.__max_length]
+                )[0]
+            )
+            neg_embeds.append(
+                self.model.txt2img_pipeline.text_encoder(
+                    negative_ids[:, i : i + self.__max_length]
+                )[0]
+            )
+
+        return None, torch.cat(concat_embeds, dim=1), None, torch.cat(neg_embeds, dim=1)

     def lunch(
         self, prompt: str, negative_prompt: str = "", config: Config = Config()
     ) -> dict:
+        if not prompt:
+            self.__logger.error("no prompt provided, won't proceed")
+            return {}
+
         self.model.set_txt2img_scheduler(config.get_scheduler())

         t = get_epoch_now()
@@ -47,11 +106,20 @@ class Text2Img:
         generator = torch.Generator(self.__device).manual_seed(seed)
         self.__logger.info("current seed: {}".format(seed))

+        (
+            prompt,
+            prompt_embeds,
+            negative_prompt,
+            negative_prompt_embeds,
+        ) = self.__token_limit_workaround(prompt, negative_prompt)
+
         result = self.model.txt2img_pipeline(
             prompt=prompt,
+            negative_prompt=negative_prompt,
+            prompt_embeds=prompt_embeds,
+            negative_prompt_embeds=negative_prompt_embeds,
             width=config.get_width(),
             height=config.get_height(),
-            negative_prompt=negative_prompt,
             guidance_scale=config.get_guidance_scale(),
             num_inference_steps=config.get_steps(),
             generator=generator,
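
A note on the recurring `__token_limit_workaround`: all three pipelines gain the same workaround for CLIP's 77-token context window. Instead of passing the prompt string (which diffusers silently truncates at 77 tokens), it tokenizes without truncation, pads the shorter of prompt/negative prompt so both id tensors share one sequence length, runs the text encoder over consecutive `model_max_length`-sized windows, and concatenates the hidden states along the sequence axis. Below is a minimal standalone sketch of the same technique; `pipe`, `device`, and the example prompts are illustrative stand-ins rather than code from this PR, and only the branch where the prompt is the longer of the two strings is shown.

```python
import torch
from diffusers import StableDiffusionPipeline

device = "cuda" if torch.cuda.is_available() else "cpu"
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5"
).to(device)

prompt = "a highly detailed matte painting of a castle, " * 20  # well past 77 tokens
negative_prompt = "blurry, low quality"

max_length = pipe.tokenizer.model_max_length  # 77 for CLIP

# Tokenize without truncation; pad the (shorter) negative prompt so both
# id tensors end up with the same sequence length.
input_ids = pipe.tokenizer(
    prompt, return_tensors="pt", truncation=False
).input_ids.to(device)
shape_max_length = input_ids.shape[-1]
negative_ids = pipe.tokenizer(
    negative_prompt,
    truncation=False,
    padding="max_length",
    max_length=shape_max_length,
    return_tensors="pt",
).input_ids.to(device)

# Encode 77-token windows separately, then stitch the embeddings back
# together along the sequence dimension.
prompt_embeds = torch.cat(
    [
        pipe.text_encoder(input_ids[:, i : i + max_length])[0]
        for i in range(0, shape_max_length, max_length)
    ],
    dim=1,
)
negative_prompt_embeds = torch.cat(
    [
        pipe.text_encoder(negative_ids[:, i : i + max_length])[0]
        for i in range(0, shape_max_length, max_length)
    ],
    dim=1,
)

# Pass embeddings instead of strings; prompt/negative_prompt must then be
# left unset, which is why the workaround returns None for both in this case.
image = pipe(
    prompt_embeds=prompt_embeds,
    negative_prompt_embeds=negative_prompt_embeds,
).images[0]
```

The equal-length padding matters because `prompt_embeds` and `negative_prompt_embeds` must have matching sequence lengths for classifier-free guidance to stack them into one batch, which is exactly why each branch above pads one side to `shape_max_length` before chunking.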
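
One caveat the diff inherits from the community version of this workaround: the `re.split("[ ,]+", prompt)` guard counts whitespace/comma-separated words, not CLIP tokens, and CLIP's BPE often produces several tokens per word. A prompt under 77 words can therefore still exceed 77 tokens and get silently truncated because the workaround never triggers. A stricter gate (hypothetical, not part of this PR) would count actual token ids:

```python
# Count real CLIP tokens instead of approximating with a word count.
token_count = pipe.tokenizer(
    prompt, return_tensors="pt", truncation=False
).input_ids.shape[-1]
needs_workaround = token_count > pipe.tokenizer.model_max_length
```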