[BE] supports more than 77 tokens
This commit is contained in:
parent 313da1b1e9
commit f6b970692a
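The change below works around CLIP's 77-token limit: prompts are tokenized without truncation, the shorter of prompt and negative prompt is padded to the longer one's tokenized length, and the ids are encoded in windows of tokenizer.model_max_length whose embeddings are concatenated and passed to the pipelines as prompt_embeds / negative_prompt_embeds instead of raw strings. A minimal sketch of the idea for a single prompt, assuming a diffusers-style pipeline object `pipe` (an illustrative helper, not code from this commit):

import torch

def encode_long_prompt(pipe, prompt: str, device: str = "cpu"):
    # CLIP's tokenizer reports model_max_length == 77
    max_length = pipe.tokenizer.model_max_length
    # tokenize without truncation so tokens beyond 77 are kept
    input_ids = pipe.tokenizer(
        prompt, return_tensors="pt", truncation=False
    ).input_ids.to(device)
    # encode each window separately, then concatenate along
    # the sequence dimension
    chunks = [
        pipe.text_encoder(input_ids[:, i : i + max_length])[0]
        for i in range(0, input_ids.shape[-1], max_length)
    ]
    return torch.cat(chunks, dim=1)

Prompt and negative embeddings must end up with the same sequence length, which is why the commit pads the shorter input to the longer one's tokenized length before encoding.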
@@ -89,10 +89,10 @@ def backend(model, is_debugging: bool):
        negative_prompt = next_job[KEY_NEG_PROMPT]

        if KEY_LANGUAGE in next_job:
            if VALUE_LANGUAGE_EN != next_job[KEY_LANGUAGE]:
                logger.info(
                    f"found {next_job[KEY_LANGUAGE]}, translate prompt and negative prompt first"
                )
                prompt_en = translate_prompt(prompt, next_job[KEY_LANGUAGE])
                logger.info(f"translated {prompt} to {prompt_en}")
                prompt = prompt_en
@@ -1,4 +1,5 @@
import torch
import re
from typing import Union
from PIL import Image
@@ -37,7 +38,61 @@ class Img2Img:
        self.lunch(prompt, negative_prompt)

    def breakfast(self):
        self.__max_length = self.model.img2img_pipeline.tokenizer.model_max_length
        self.__logger.info(f"model has max length of {self.__max_length}")

    def __token_limit_workaround(self, prompt: str, negative_prompt: str = ""):
        count_prompt = len(re.split("[ ,]+", prompt))
        count_negative_prompt = len(re.split("[ ,]+", negative_prompt))

        if count_prompt < 77 and count_negative_prompt < 77:
            return prompt, None, negative_prompt, None

        self.__logger.info(
            "using workaround to generate embeds instead of direct string"
        )

        if count_prompt >= count_negative_prompt:
            input_ids = self.model.img2img_pipeline.tokenizer(
                prompt, return_tensors="pt", truncation=False
            ).input_ids.to(self.__device)
            shape_max_length = input_ids.shape[-1]
            negative_ids = self.model.img2img_pipeline.tokenizer(
                negative_prompt,
                truncation=False,
                padding="max_length",
                max_length=shape_max_length,
                return_tensors="pt",
            ).input_ids.to(self.__device)
        else:
            negative_ids = self.model.img2img_pipeline.tokenizer(
                negative_prompt, return_tensors="pt", truncation=False
            ).input_ids.to(self.__device)
            shape_max_length = negative_ids.shape[-1]
            input_ids = self.model.img2img_pipeline.tokenizer(
                prompt,
                return_tensors="pt",
                truncation=False,
                padding="max_length",
                max_length=shape_max_length,
            ).input_ids.to(self.__device)

        concat_embeds = []
        neg_embeds = []
        for i in range(0, shape_max_length, self.__max_length):
            concat_embeds.append(
                self.model.img2img_pipeline.text_encoder(
                    input_ids[:, i : i + self.__max_length]
                )[0]
            )
            neg_embeds.append(
                self.model.img2img_pipeline.text_encoder(
                    negative_ids[:, i : i + self.__max_length]
                )[0]
            )

        return None, torch.cat(concat_embeds, dim=1), None, torch.cat(neg_embeds, dim=1)

    def lunch(
        self,
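Note the return contract of __token_limit_workaround for the call sites below: when both inputs are under 77 words it returns the raw strings with None embeddings, otherwise None strings with the concatenated embeddings, so each pipeline call receives exactly one of prompt / prompt_embeds (and likewise for the negative pair).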
@@ -63,10 +118,19 @@ class Img2Img:
        reference_image = base64_to_image(reference_image).convert("RGB")
        reference_image.thumbnail((config.get_width(), config.get_height()))

        (
            prompt,
            prompt_embeds,
            negative_prompt,
            negative_prompt_embeds,
        ) = self.__token_limit_workaround(prompt, negative_prompt)

        result = self.model.img2img_pipeline(
            prompt=prompt,
            negative_prompt=negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            image=reference_image,
            guidance_scale=config.get_guidance_scale(),
            strength=config.get_strength(),
            num_inference_steps=config.get_steps(),
@@ -1,4 +1,5 @@
import torch
import re
from typing import Union
from PIL import Image
@@ -37,7 +38,61 @@ class Inpainting:
        self.lunch(prompt, negative_prompt)

    def breakfast(self):
        self.__max_length = self.model.inpaint_pipeline.tokenizer.model_max_length
        self.__logger.info(f"model has max length of {self.__max_length}")

    def __token_limit_workaround(self, prompt: str, negative_prompt: str = ""):
        count_prompt = len(re.split("[ ,]+", prompt))
        count_negative_prompt = len(re.split("[ ,]+", negative_prompt))

        if count_prompt < 77 and count_negative_prompt < 77:
            return prompt, None, negative_prompt, None

        self.__logger.info(
            "using workaround to generate embeds instead of direct string"
        )

        if count_prompt >= count_negative_prompt:
            input_ids = self.model.inpaint_pipeline.tokenizer(
                prompt, return_tensors="pt", truncation=False
            ).input_ids.to(self.__device)
            shape_max_length = input_ids.shape[-1]
            negative_ids = self.model.inpaint_pipeline.tokenizer(
                negative_prompt,
                truncation=False,
                padding="max_length",
                max_length=shape_max_length,
                return_tensors="pt",
            ).input_ids.to(self.__device)
        else:
            negative_ids = self.model.inpaint_pipeline.tokenizer(
                negative_prompt, return_tensors="pt", truncation=False
            ).input_ids.to(self.__device)
            shape_max_length = negative_ids.shape[-1]
            input_ids = self.model.inpaint_pipeline.tokenizer(
                prompt,
                return_tensors="pt",
                truncation=False,
                padding="max_length",
                max_length=shape_max_length,
            ).input_ids.to(self.__device)

        concat_embeds = []
        neg_embeds = []
        for i in range(0, shape_max_length, self.__max_length):
            concat_embeds.append(
                self.model.inpaint_pipeline.text_encoder(
                    input_ids[:, i : i + self.__max_length]
                )[0]
            )
            neg_embeds.append(
                self.model.inpaint_pipeline.text_encoder(
                    negative_ids[:, i : i + self.__max_length]
                )[0]
            )

        return None, torch.cat(concat_embeds, dim=1), None, torch.cat(neg_embeds, dim=1)

    def lunch(
        self,
@@ -72,13 +127,28 @@ class Inpainting:
        if mask_image.size[0] < reference_image.size[0]:
            mask_image = mask_image.resize(reference_image.size)
        elif mask_image.size[0] > reference_image.size[0]:
            mask_image = mask_image.resize(
                reference_image.size, resample=Image.LANCZOS
            )

        (
            prompt,
            prompt_embeds,
            negative_prompt,
            negative_prompt_embeds,
        ) = self.__token_limit_workaround(prompt, negative_prompt)

        result = self.model.inpaint_pipeline(
            prompt=prompt,
            negative_prompt=negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            image=reference_image.resize(
                (512, 512)
            ),  # must use size 512 for inpaint model
            mask_image=mask_image.convert("L").resize(
                (512, 512)
            ),  # must use size 512 for inpaint model
            guidance_scale=config.get_guidance_scale(),
            num_inference_steps=config.get_steps(),
            generator=generator,
@@ -87,7 +157,9 @@ class Inpainting:
        )

        # resize it back based on ratio (keep width 512)
        result_img = result.images[0].resize(
            (512, int(512 * reference_image.size[1] / reference_image.size[0]))
        )

        if self.__output_folder:
            out_filepath = "{}/{}.png".format(self.__output_folder, t)
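For example, a 1024×768 reference image comes back as (512, int(512 * 768 / 1024)) = (512, 384), preserving the original aspect ratio at width 512.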
@@ -1,4 +1,5 @@
import torch
import re
from typing import Union

from utilities.constants import BASE64IMAGE
@@ -35,11 +36,69 @@ class Text2Img:
        self.lunch(prompt, negative_prompt)

    def breakfast(self):
        self.__max_length = self.model.txt2img_pipeline.tokenizer.model_max_length
        self.__logger.info(f"model has max length of {self.__max_length}")

    def __token_limit_workaround(self, prompt: str, negative_prompt: str = ""):
        count_prompt = len(re.split("[ ,]+", prompt))
        count_negative_prompt = len(re.split("[ ,]+", negative_prompt))

        if count_prompt < 77 and count_negative_prompt < 77:
            return prompt, None, negative_prompt, None

        self.__logger.info(
            "using workaround to generate embeds instead of direct string"
        )

        if count_prompt >= count_negative_prompt:
            input_ids = self.model.txt2img_pipeline.tokenizer(
                prompt, return_tensors="pt", truncation=False
            ).input_ids.to(self.__device)
            shape_max_length = input_ids.shape[-1]
            negative_ids = self.model.txt2img_pipeline.tokenizer(
                negative_prompt,
                truncation=False,
                padding="max_length",
                max_length=shape_max_length,
                return_tensors="pt",
            ).input_ids.to(self.__device)
        else:
            negative_ids = self.model.txt2img_pipeline.tokenizer(
                negative_prompt, return_tensors="pt", truncation=False
            ).input_ids.to(self.__device)
            shape_max_length = negative_ids.shape[-1]
            input_ids = self.model.txt2img_pipeline.tokenizer(
                prompt,
                return_tensors="pt",
                truncation=False,
                padding="max_length",
                max_length=shape_max_length,
            ).input_ids.to(self.__device)

        concat_embeds = []
        neg_embeds = []
        for i in range(0, shape_max_length, self.__max_length):
            concat_embeds.append(
                self.model.txt2img_pipeline.text_encoder(
                    input_ids[:, i : i + self.__max_length]
                )[0]
            )
            neg_embeds.append(
                self.model.txt2img_pipeline.text_encoder(
                    negative_ids[:, i : i + self.__max_length]
                )[0]
            )

        return None, torch.cat(concat_embeds, dim=1), None, torch.cat(neg_embeds, dim=1)

    def lunch(
        self, prompt: str, negative_prompt: str = "", config: Config = Config()
    ) -> dict:
        if not prompt:
            self.__logger.error("no prompt provided, won't proceed")
            return {}

        self.model.set_txt2img_scheduler(config.get_scheduler())

        t = get_epoch_now()
@@ -47,11 +106,20 @@ class Text2Img:
        generator = torch.Generator(self.__device).manual_seed(seed)
        self.__logger.info("current seed: {}".format(seed))

        (
            prompt,
            prompt_embeds,
            negative_prompt,
            negative_prompt_embeds,
        ) = self.__token_limit_workaround(prompt, negative_prompt)

        result = self.model.txt2img_pipeline(
            prompt=prompt,
            negative_prompt=negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            width=config.get_width(),
            height=config.get_height(),
            guidance_scale=config.get_guidance_scale(),
            num_inference_steps=config.get_steps(),
            generator=generator,