[BE] supports more than 77 tokens

HappyZ 2023-05-24 14:14:18 -07:00
parent 313da1b1e9
commit f6b970692a
4 changed files with 216 additions and 12 deletions
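The core change works around the 77-token context window of Stable Diffusion's CLIP text encoder. Rather than passing raw prompt strings, which the pipelines silently truncate at 77 tokens, long prompts are tokenized without truncation, encoded in windows of tokenizer.model_max_length, and the per-window hidden states are concatenated and passed as prompt_embeds. Because diffusers pipelines accept either prompt or prompt_embeds (never both), the helper returns whichever one is unused as None. A minimal standalone sketch of the idea, assuming a stock diffusers pipeline; the model id and function name below are illustrative, not part of this repo:

import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
tokenizer, text_encoder = pipe.tokenizer, pipe.text_encoder
max_length = tokenizer.model_max_length  # 77 for CLIP

def encode_long_prompt(prompt: str) -> torch.Tensor:
    # Tokenize without truncation, encode in 77-token windows,
    # then concatenate hidden states along the sequence axis.
    ids = tokenizer(prompt, return_tensors="pt", truncation=False).input_ids
    chunks = [
        text_encoder(ids[:, i : i + max_length])[0]
        for i in range(0, ids.shape[-1], max_length)
    ]
    return torch.cat(chunks, dim=1)

In the changes below, the shorter of prompt/negative prompt is additionally padded to the longer one's tokenized length so that prompt_embeds and negative_prompt_embeds come out the same shape, which the pipelines require for classifier-free guidance.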

View File

@@ -89,10 +89,10 @@ def backend(model, is_debugging: bool):
         negative_prompt = next_job[KEY_NEG_PROMPT]
         if KEY_LANGUAGE in next_job:
-            logger.info(
-                f"found {next_job[KEY_LANGUAGE]}, translate prompt and negative prompt first"
-            )
             if VALUE_LANGUAGE_EN != next_job[KEY_LANGUAGE]:
+                logger.info(
+                    f"found {next_job[KEY_LANGUAGE]}, translate prompt and negative prompt first"
+                )
                 prompt_en = translate_prompt(prompt, next_job[KEY_LANGUAGE])
                 logger.info(f"translated {prompt} to {prompt_en}")
                 prompt = prompt_en

View File

@@ -1,4 +1,5 @@
 import torch
+import re
 from typing import Union
 from PIL import Image
@@ -37,7 +38,61 @@ class Img2Img:
         self.lunch(prompt, negative_prompt)

     def breakfast(self):
-        pass
+        self.__max_length = self.model.img2img_pipeline.tokenizer.model_max_length
+        self.__logger.info(f"model has max length of {self.__max_length}")

+    def __token_limit_workaround(self, prompt: str, negative_prompt: str = ""):
+        count_prompt = len(re.split("[ ,]+", prompt))
+        count_negative_prompt = len(re.split("[ ,]+", negative_prompt))
+        if count_prompt < 77 and count_negative_prompt < 77:
+            return prompt, None, negative_prompt, None
+        self.__logger.info(
+            "using workaround to generate embeds instead of direct string"
+        )
+        if count_prompt >= count_negative_prompt:
+            input_ids = self.model.img2img_pipeline.tokenizer(
+                prompt, return_tensors="pt", truncation=False
+            ).input_ids.to(self.__device)
+            shape_max_length = input_ids.shape[-1]
+            negative_ids = self.model.img2img_pipeline.tokenizer(
+                negative_prompt,
+                truncation=False,
+                padding="max_length",
+                max_length=shape_max_length,
+                return_tensors="pt",
+            ).input_ids.to(self.__device)
+        else:
+            negative_ids = self.model.img2img_pipeline.tokenizer(
+                negative_prompt, return_tensors="pt", truncation=False
+            ).input_ids.to(self.__device)
+            shape_max_length = negative_ids.shape[-1]
+            input_ids = self.model.img2img_pipeline.tokenizer(
+                prompt,
+                return_tensors="pt",
+                truncation=False,
+                padding="max_length",
+                max_length=shape_max_length,
+            ).input_ids.to(self.__device)
+        concat_embeds = []
+        neg_embeds = []
+        for i in range(0, shape_max_length, self.__max_length):
+            concat_embeds.append(
+                self.model.img2img_pipeline.text_encoder(
+                    input_ids[:, i : i + self.__max_length]
+                )[0]
+            )
+            neg_embeds.append(
+                self.model.img2img_pipeline.text_encoder(
+                    negative_ids[:, i : i + self.__max_length]
+                )[0]
+            )
+        return None, torch.cat(concat_embeds, dim=1), None, torch.cat(neg_embeds, dim=1)

     def lunch(
         self,
@@ -63,10 +118,19 @@ class Img2Img:
         reference_image = base64_to_image(reference_image).convert("RGB")
         reference_image.thumbnail((config.get_width(), config.get_height()))
+        (
+            prompt,
+            prompt_embeds,
+            negative_prompt,
+            negative_prompt_embeds,
+        ) = self.__token_limit_workaround(prompt, negative_prompt)
         result = self.model.img2img_pipeline(
             prompt=prompt,
-            image=reference_image,
             negative_prompt=negative_prompt,
+            prompt_embeds=prompt_embeds,
+            negative_prompt_embeds=negative_prompt_embeds,
+            image=reference_image,
             guidance_scale=config.get_guidance_scale(),
             strength=config.get_strength(),
             num_inference_steps=config.get_steps(),
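One caveat in the gate at the top of __token_limit_workaround: it counts whitespace- and comma-separated words, not CLIP tokens, and a single word frequently tokenizes into several sub-tokens, so a prompt can pass the < 77 word check and still get silently truncated. An exact check is cheap since the tokenizer is already on hand; a hypothetical helper, not part of this commit:

def exceeds_token_limit(tokenizer, text: str) -> bool:
    # Count actual CLIP tokens (including special tokens) rather than words.
    ids = tokenizer(text, return_tensors="pt", truncation=False).input_ids
    return ids.shape[-1] > tokenizer.model_max_length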

View File

@@ -1,4 +1,5 @@
 import torch
+import re
 from typing import Union
 from PIL import Image
@@ -37,7 +38,61 @@ class Inpainting:
         self.lunch(prompt, negative_prompt)

     def breakfast(self):
-        pass
+        self.__max_length = self.model.inpaint_pipeline.tokenizer.model_max_length
+        self.__logger.info(f"model has max length of {self.__max_length}")

+    def __token_limit_workaround(self, prompt: str, negative_prompt: str = ""):
+        count_prompt = len(re.split("[ ,]+", prompt))
+        count_negative_prompt = len(re.split("[ ,]+", negative_prompt))
+        if count_prompt < 77 and count_negative_prompt < 77:
+            return prompt, None, negative_prompt, None
+        self.__logger.info(
+            "using workaround to generate embeds instead of direct string"
+        )
+        if count_prompt >= count_negative_prompt:
+            input_ids = self.model.inpaint_pipeline.tokenizer(
+                prompt, return_tensors="pt", truncation=False
+            ).input_ids.to(self.__device)
+            shape_max_length = input_ids.shape[-1]
+            negative_ids = self.model.inpaint_pipeline.tokenizer(
+                negative_prompt,
+                truncation=False,
+                padding="max_length",
+                max_length=shape_max_length,
+                return_tensors="pt",
+            ).input_ids.to(self.__device)
+        else:
+            negative_ids = self.model.inpaint_pipeline.tokenizer(
+                negative_prompt, return_tensors="pt", truncation=False
+            ).input_ids.to(self.__device)
+            shape_max_length = negative_ids.shape[-1]
+            input_ids = self.model.inpaint_pipeline.tokenizer(
+                prompt,
+                return_tensors="pt",
+                truncation=False,
+                padding="max_length",
+                max_length=shape_max_length,
+            ).input_ids.to(self.__device)
+        concat_embeds = []
+        neg_embeds = []
+        for i in range(0, shape_max_length, self.__max_length):
+            concat_embeds.append(
+                self.model.inpaint_pipeline.text_encoder(
+                    input_ids[:, i : i + self.__max_length]
+                )[0]
+            )
+            neg_embeds.append(
+                self.model.inpaint_pipeline.text_encoder(
+                    negative_ids[:, i : i + self.__max_length]
+                )[0]
+            )
+        return None, torch.cat(concat_embeds, dim=1), None, torch.cat(neg_embeds, dim=1)

     def lunch(
         self,
@@ -72,13 +127,28 @@ class Inpainting:
         if mask_image.size[0] < reference_image.size[0]:
             mask_image = mask_image.resize(reference_image.size)
         elif mask_image.size[0] > reference_image.size[0]:
-            mask_image = mask_image.resize(reference_image.size, resample=Image.LANCZOS)
+            mask_image = mask_image.resize(
+                reference_image.size, resample=Image.LANCZOS
+            )
+        (
+            prompt,
+            prompt_embeds,
+            negative_prompt,
+            negative_prompt_embeds,
+        ) = self.__token_limit_workaround(prompt, negative_prompt)
         result = self.model.inpaint_pipeline(
             prompt=prompt,
-            image=reference_image.resize((512, 512)),  # must use size 512 for inpaint model
-            mask_image=mask_image.convert("L").resize((512, 512)),  # must use size 512 for inpaint model
             negative_prompt=negative_prompt,
+            prompt_embeds=prompt_embeds,
+            negative_prompt_embeds=negative_prompt_embeds,
+            image=reference_image.resize(
+                (512, 512)
+            ),  # must use size 512 for inpaint model
+            mask_image=mask_image.convert("L").resize(
+                (512, 512)
+            ),  # must use size 512 for inpaint model
             guidance_scale=config.get_guidance_scale(),
             num_inference_steps=config.get_steps(),
             generator=generator,
@ -87,7 +157,9 @@ class Inpainting:
)
# resize it back based on ratio (keep width 512)
result_img = result.images[0].resize((512, int(512 * reference_image.size[1] / reference_image.size[0])))
result_img = result.images[0].resize(
(512, int(512 * reference_image.size[1] / reference_image.size[0]))
)
if self.__output_folder:
out_filepath = "{}/{}.png".format(self.__output_folder, t)
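The resize at the end restores the reference image's aspect ratio while keeping the width at 512: a 768x1024 reference, for example, gives int(512 * 1024 / 768) = 682, so the 512x512 inpaint output is resized to 512x682.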

View File

@@ -1,4 +1,5 @@
 import torch
+import re
 from typing import Union
 from utilities.constants import BASE64IMAGE
@@ -35,11 +36,69 @@ class Text2Img:
         self.lunch(prompt, negative_prompt)

     def breakfast(self):
-        pass
+        self.__max_length = self.model.txt2img_pipeline.tokenizer.model_max_length
+        self.__logger.info(f"model has max length of {self.__max_length}")

+    def __token_limit_workaround(self, prompt: str, negative_prompt: str = ""):
+        count_prompt = len(re.split("[ ,]+", prompt))
+        count_negative_prompt = len(re.split("[ ,]+", negative_prompt))
+        if count_prompt < 77 and count_negative_prompt < 77:
+            return prompt, None, negative_prompt, None
+        self.__logger.info(
+            "using workaround to generate embeds instead of direct string"
+        )
+        if count_prompt >= count_negative_prompt:
+            input_ids = self.model.txt2img_pipeline.tokenizer(
+                prompt, return_tensors="pt", truncation=False
+            ).input_ids.to(self.__device)
+            shape_max_length = input_ids.shape[-1]
+            negative_ids = self.model.txt2img_pipeline.tokenizer(
+                negative_prompt,
+                truncation=False,
+                padding="max_length",
+                max_length=shape_max_length,
+                return_tensors="pt",
+            ).input_ids.to(self.__device)
+        else:
+            negative_ids = self.model.txt2img_pipeline.tokenizer(
+                negative_prompt, return_tensors="pt", truncation=False
+            ).input_ids.to(self.__device)
+            shape_max_length = negative_ids.shape[-1]
+            input_ids = self.model.txt2img_pipeline.tokenizer(
+                prompt,
+                return_tensors="pt",
+                truncation=False,
+                padding="max_length",
+                max_length=shape_max_length,
+            ).input_ids.to(self.__device)
+        concat_embeds = []
+        neg_embeds = []
+        for i in range(0, shape_max_length, self.__max_length):
+            concat_embeds.append(
+                self.model.txt2img_pipeline.text_encoder(
+                    input_ids[:, i : i + self.__max_length]
+                )[0]
+            )
+            neg_embeds.append(
+                self.model.txt2img_pipeline.text_encoder(
+                    negative_ids[:, i : i + self.__max_length]
+                )[0]
+            )
+        return None, torch.cat(concat_embeds, dim=1), None, torch.cat(neg_embeds, dim=1)

     def lunch(
         self, prompt: str, negative_prompt: str = "", config: Config = Config()
     ) -> dict:
         if not prompt:
             self.__logger.error("no prompt provided, won't proceed")
             return {}
         self.model.set_txt2img_scheduler(config.get_scheduler())
         t = get_epoch_now()
@@ -47,11 +106,20 @@
         generator = torch.Generator(self.__device).manual_seed(seed)
         self.__logger.info("current seed: {}".format(seed))
+        (
+            prompt,
+            prompt_embeds,
+            negative_prompt,
+            negative_prompt_embeds,
+        ) = self.__token_limit_workaround(prompt, negative_prompt)
         result = self.model.txt2img_pipeline(
             prompt=prompt,
+            negative_prompt=negative_prompt,
+            prompt_embeds=prompt_embeds,
+            negative_prompt_embeds=negative_prompt_embeds,
             width=config.get_width(),
             height=config.get_height(),
-            negative_prompt=negative_prompt,
             guidance_scale=config.get_guidance_scale(),
             num_inference_steps=config.get_steps(),
             generator=generator,
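Design note: __token_limit_workaround is duplicated verbatim in Img2Img, Inpainting, and Text2Img, differing only in which pipeline attribute it reads, so the three copies can silently drift apart. A module-level version parameterized by the pipeline would avoid that; a hypothetical sketch following the same logic as this commit (the function and its location are assumed, not part of the change):

import re
import torch

def token_limit_workaround(pipeline, device, prompt: str, negative_prompt: str = ""):
    # Works with any diffusers pipeline exposing .tokenizer and .text_encoder.
    tok, enc = pipeline.tokenizer, pipeline.text_encoder
    max_length = tok.model_max_length
    if len(re.split("[ ,]+", prompt)) < 77 and len(re.split("[ ,]+", negative_prompt)) < 77:
        return prompt, None, negative_prompt, None
    # Tokenize both without truncation; pad the shorter to the longer one's
    # length so the two embedding tensors end up the same shape.
    if len(re.split("[ ,]+", prompt)) >= len(re.split("[ ,]+", negative_prompt)):
        input_ids = tok(prompt, return_tensors="pt", truncation=False).input_ids.to(device)
        negative_ids = tok(
            negative_prompt, truncation=False, padding="max_length",
            max_length=input_ids.shape[-1], return_tensors="pt",
        ).input_ids.to(device)
    else:
        negative_ids = tok(negative_prompt, return_tensors="pt", truncation=False).input_ids.to(device)
        input_ids = tok(
            prompt, truncation=False, padding="max_length",
            max_length=negative_ids.shape[-1], return_tensors="pt",
        ).input_ids.to(device)
    embeds, neg_embeds = [], []
    for i in range(0, input_ids.shape[-1], max_length):
        embeds.append(enc(input_ids[:, i : i + max_length])[0])
        neg_embeds.append(enc(negative_ids[:, i : i + max_length])[0])
    return None, torch.cat(embeds, dim=1), None, torch.cat(neg_embeds, dim=1)

Each lunch() would then call it as, e.g., token_limit_workaround(self.model.txt2img_pipeline, self.__device, prompt, negative_prompt).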