[BE] supports more than 77 tokens
parent 313da1b1e9
commit f6b970692a
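Summary of the change: the no-op breakfast() in each pipeline wrapper (Img2Img, Inpainting, Text2Img) now records the tokenizer's model_max_length, and a private __token_limit_workaround() helper is added. When a prompt or negative prompt exceeds roughly 77 whitespace/comma-separated words (an approximation of CLIP's 77-token limit), the helper tokenizes without truncation, encodes the ids in windows of model_max_length, and concatenates the hidden states, so the pipelines receive prompt_embeds / negative_prompt_embeds instead of raw strings. Below is a minimal standalone sketch of the same technique, assuming a diffusers-style pipeline object named pipe; pipe, chunked_prompt_embeds, device, and the prompt variables are illustrative names, not part of this repository.

    import torch

    def chunked_prompt_embeds(pipe, prompt: str, device: str = "cuda"):
        # Tokenize without truncation so prompts longer than 77 tokens are kept whole.
        input_ids = pipe.tokenizer(
            prompt, return_tensors="pt", truncation=False
        ).input_ids.to(device)

        max_length = pipe.tokenizer.model_max_length  # 77 for CLIP-based SD checkpoints
        chunks = []
        # Encode the ids window by window and stitch the hidden states back together.
        for i in range(0, input_ids.shape[-1], max_length):
            chunks.append(pipe.text_encoder(input_ids[:, i : i + max_length])[0])
        return torch.cat(chunks, dim=1)

    # Usage sketch (illustrative): pass embeddings instead of the raw strings.
    # As in the diff below, pad positive and negative ids to the same length before
    # encoding so the two embedding tensors end up with matching shapes.
    # embeds = chunked_prompt_embeds(pipe, long_prompt)
    # neg_embeds = chunked_prompt_embeds(pipe, long_negative_prompt)
    # image = pipe(prompt_embeds=embeds, negative_prompt_embeds=neg_embeds).images[0]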

@@ -89,10 +89,10 @@ def backend(model, is_debugging: bool):
         negative_prompt = next_job[KEY_NEG_PROMPT]
 
         if KEY_LANGUAGE in next_job:
-            logger.info(
-                f"found {next_job[KEY_LANGUAGE]}, translate prompt and negative prompt first"
-            )
             if VALUE_LANGUAGE_EN != next_job[KEY_LANGUAGE]:
+                logger.info(
+                    f"found {next_job[KEY_LANGUAGE]}, translate prompt and negative prompt first"
+                )
                 prompt_en = translate_prompt(prompt, next_job[KEY_LANGUAGE])
                 logger.info(f"translated {prompt} to {prompt_en}")
                 prompt = prompt_en

@@ -1,4 +1,5 @@
 import torch
+import re
 from typing import Union
 from PIL import Image
 
@@ -37,7 +38,61 @@ class Img2Img:
         self.lunch(prompt, negative_prompt)
 
     def breakfast(self):
-        pass
+        self.__max_length = self.model.img2img_pipeline.tokenizer.model_max_length
+        self.__logger.info(f"model has max length of {self.__max_length}")
+
+    def __token_limit_workaround(self, prompt: str, negative_prompt: str = ""):
+        count_prompt = len(re.split("[ ,]+", prompt))
+        count_negative_prompt = len(re.split("[ ,]+", negative_prompt))
+
+        if count_prompt < 77 and count_negative_prompt < 77:
+            return prompt, None, negative_prompt, None
+
+        self.__logger.info(
+            "using workaround to generate embeds instead of direct string"
+        )
+
+        if count_prompt >= count_negative_prompt:
+            input_ids = self.model.img2img_pipeline.tokenizer(
+                prompt, return_tensors="pt", truncation=False
+            ).input_ids.to(self.__device)
+            shape_max_length = input_ids.shape[-1]
+            negative_ids = self.model.img2img_pipeline.tokenizer(
+                negative_prompt,
+                truncation=False,
+                padding="max_length",
+                max_length=shape_max_length,
+                return_tensors="pt",
+            ).input_ids.to(self.__device)
+
+        else:
+            negative_ids = self.model.img2img_pipeline.tokenizer(
+                negative_prompt, return_tensors="pt", truncation=False
+            ).input_ids.to(self.__device)
+            shape_max_length = negative_ids.shape[-1]
+            input_ids = self.model.img2img_pipeline.tokenizer(
+                prompt,
+                return_tensors="pt",
+                truncation=False,
+                padding="max_length",
+                max_length=shape_max_length,
+            ).input_ids.to(self.__device)
+
+        concat_embeds = []
+        neg_embeds = []
+        for i in range(0, shape_max_length, self.__max_length):
+            concat_embeds.append(
+                self.model.img2img_pipeline.text_encoder(
+                    input_ids[:, i : i + self.__max_length]
+                )[0]
+            )
+            neg_embeds.append(
+                self.model.img2img_pipeline.text_encoder(
+                    negative_ids[:, i : i + self.__max_length]
+                )[0]
+            )
+
+        return None, torch.cat(concat_embeds, dim=1), None, torch.cat(neg_embeds, dim=1)
 
     def lunch(
         self,
@@ -63,10 +118,19 @@ class Img2Img:
         reference_image = base64_to_image(reference_image).convert("RGB")
         reference_image.thumbnail((config.get_width(), config.get_height()))
 
+        (
+            prompt,
+            prompt_embeds,
+            negative_prompt,
+            negative_prompt_embeds,
+        ) = self.__token_limit_workaround(prompt, negative_prompt)
+
         result = self.model.img2img_pipeline(
             prompt=prompt,
-            image=reference_image,
             negative_prompt=negative_prompt,
+            prompt_embeds=prompt_embeds,
+            negative_prompt_embeds=negative_prompt_embeds,
+            image=reference_image,
             guidance_scale=config.get_guidance_scale(),
             strength=config.get_strength(),
             num_inference_steps=config.get_steps(),
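A note on the helper's return value, which applies to all three wrappers: it yields either plain strings with None embeddings (short prompts) or None strings with tensor embeddings (long prompts). As far as I know, the diffusers pipelines refuse calls that pass both prompt and prompt_embeds at once, and require prompt_embeds and negative_prompt_embeds to have matching shapes; that is why the helper keeps exactly one side of each pair None and pads the shorter prompt's ids to the longer one's length before encoding. The same method is duplicated verbatim in the Inpainting and Text2Img classes below.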

@@ -1,4 +1,5 @@
 import torch
+import re
 from typing import Union
 from PIL import Image
 
@@ -37,7 +38,61 @@ class Inpainting:
         self.lunch(prompt, negative_prompt)
 
     def breakfast(self):
-        pass
+        self.__max_length = self.model.inpaint_pipeline.tokenizer.model_max_length
+        self.__logger.info(f"model has max length of {self.__max_length}")
+
+    def __token_limit_workaround(self, prompt: str, negative_prompt: str = ""):
+        count_prompt = len(re.split("[ ,]+", prompt))
+        count_negative_prompt = len(re.split("[ ,]+", negative_prompt))
+
+        if count_prompt < 77 and count_negative_prompt < 77:
+            return prompt, None, negative_prompt, None
+
+        self.__logger.info(
+            "using workaround to generate embeds instead of direct string"
+        )
+
+        if count_prompt >= count_negative_prompt:
+            input_ids = self.model.inpaint_pipeline.tokenizer(
+                prompt, return_tensors="pt", truncation=False
+            ).input_ids.to(self.__device)
+            shape_max_length = input_ids.shape[-1]
+            negative_ids = self.model.inpaint_pipeline.tokenizer(
+                negative_prompt,
+                truncation=False,
+                padding="max_length",
+                max_length=shape_max_length,
+                return_tensors="pt",
+            ).input_ids.to(self.__device)
+
+        else:
+            negative_ids = self.model.inpaint_pipeline.tokenizer(
+                negative_prompt, return_tensors="pt", truncation=False
+            ).input_ids.to(self.__device)
+            shape_max_length = negative_ids.shape[-1]
+            input_ids = self.model.inpaint_pipeline.tokenizer(
+                prompt,
+                return_tensors="pt",
+                truncation=False,
+                padding="max_length",
+                max_length=shape_max_length,
+            ).input_ids.to(self.__device)
+
+        concat_embeds = []
+        neg_embeds = []
+        for i in range(0, shape_max_length, self.__max_length):
+            concat_embeds.append(
+                self.model.inpaint_pipeline.text_encoder(
+                    input_ids[:, i : i + self.__max_length]
+                )[0]
+            )
+            neg_embeds.append(
+                self.model.inpaint_pipeline.text_encoder(
+                    negative_ids[:, i : i + self.__max_length]
+                )[0]
+            )
+
+        return None, torch.cat(concat_embeds, dim=1), None, torch.cat(neg_embeds, dim=1)
 
     def lunch(
         self,
@@ -72,13 +127,28 @@ class Inpainting:
         if mask_image.size[0] < reference_image.size[0]:
             mask_image = mask_image.resize(reference_image.size)
         elif mask_image.size[0] > reference_image.size[0]:
-            mask_image = mask_image.resize(reference_image.size, resample=Image.LANCZOS)
+            mask_image = mask_image.resize(
+                reference_image.size, resample=Image.LANCZOS
+            )
+
+        (
+            prompt,
+            prompt_embeds,
+            negative_prompt,
+            negative_prompt_embeds,
+        ) = self.__token_limit_workaround(prompt, negative_prompt)
 
         result = self.model.inpaint_pipeline(
             prompt=prompt,
-            image=reference_image.resize((512, 512)), # must use size 512 for inpaint model
-            mask_image=mask_image.convert("L").resize((512, 512)), # must use size 512 for inpaint model
             negative_prompt=negative_prompt,
+            prompt_embeds=prompt_embeds,
+            negative_prompt_embeds=negative_prompt_embeds,
+            image=reference_image.resize(
+                (512, 512)
+            ),  # must use size 512 for inpaint model
+            mask_image=mask_image.convert("L").resize(
+                (512, 512)
+            ),  # must use size 512 for inpaint model
             guidance_scale=config.get_guidance_scale(),
             num_inference_steps=config.get_steps(),
             generator=generator,
@@ -87,7 +157,9 @@ class Inpainting:
         )
 
         # resize it back based on ratio (keep width 512)
-        result_img = result.images[0].resize((512, int(512 * reference_image.size[1] / reference_image.size[0])))
+        result_img = result.images[0].resize(
+            (512, int(512 * reference_image.size[1] / reference_image.size[0]))
+        )
 
         if self.__output_folder:
             out_filepath = "{}/{}.png".format(self.__output_folder, t)

@@ -1,4 +1,5 @@
 import torch
+import re
 from typing import Union
 
 from utilities.constants import BASE64IMAGE
@@ -35,11 +36,69 @@ class Text2Img:
         self.lunch(prompt, negative_prompt)
 
     def breakfast(self):
-        pass
+        self.__max_length = self.model.txt2img_pipeline.tokenizer.model_max_length
+        self.__logger.info(f"model has max length of {self.__max_length}")
+
+    def __token_limit_workaround(self, prompt: str, negative_prompt: str = ""):
+        count_prompt = len(re.split("[ ,]+", prompt))
+        count_negative_prompt = len(re.split("[ ,]+", negative_prompt))
+
+        if count_prompt < 77 and count_negative_prompt < 77:
+            return prompt, None, negative_prompt, None
+
+        self.__logger.info(
+            "using workaround to generate embeds instead of direct string"
+        )
+
+        if count_prompt >= count_negative_prompt:
+            input_ids = self.model.txt2img_pipeline.tokenizer(
+                prompt, return_tensors="pt", truncation=False
+            ).input_ids.to(self.__device)
+            shape_max_length = input_ids.shape[-1]
+            negative_ids = self.model.txt2img_pipeline.tokenizer(
+                negative_prompt,
+                truncation=False,
+                padding="max_length",
+                max_length=shape_max_length,
+                return_tensors="pt",
+            ).input_ids.to(self.__device)
+
+        else:
+            negative_ids = self.model.txt2img_pipeline.tokenizer(
+                negative_prompt, return_tensors="pt", truncation=False
+            ).input_ids.to(self.__device)
+            shape_max_length = negative_ids.shape[-1]
+            input_ids = self.model.txt2img_pipeline.tokenizer(
+                prompt,
+                return_tensors="pt",
+                truncation=False,
+                padding="max_length",
+                max_length=shape_max_length,
+            ).input_ids.to(self.__device)
+
+        concat_embeds = []
+        neg_embeds = []
+        for i in range(0, shape_max_length, self.__max_length):
+            concat_embeds.append(
+                self.model.txt2img_pipeline.text_encoder(
+                    input_ids[:, i : i + self.__max_length]
+                )[0]
+            )
+            neg_embeds.append(
+                self.model.txt2img_pipeline.text_encoder(
+                    negative_ids[:, i : i + self.__max_length]
+                )[0]
+            )
+
+        return None, torch.cat(concat_embeds, dim=1), None, torch.cat(neg_embeds, dim=1)
 
     def lunch(
         self, prompt: str, negative_prompt: str = "", config: Config = Config()
     ) -> dict:
+        if not prompt:
+            self.__logger.error("no prompt provided, won't proceed")
+            return {}
+
         self.model.set_txt2img_scheduler(config.get_scheduler())
 
         t = get_epoch_now()
@@ -47,11 +106,20 @@ class Text2Img:
         generator = torch.Generator(self.__device).manual_seed(seed)
         self.__logger.info("current seed: {}".format(seed))
 
+        (
+            prompt,
+            prompt_embeds,
+            negative_prompt,
+            negative_prompt_embeds,
+        ) = self.__token_limit_workaround(prompt, negative_prompt)
+
         result = self.model.txt2img_pipeline(
             prompt=prompt,
+            negative_prompt=negative_prompt,
+            prompt_embeds=prompt_embeds,
+            negative_prompt_embeds=negative_prompt_embeds,
             width=config.get_width(),
             height=config.get_height(),
-            negative_prompt=negative_prompt,
             guidance_scale=config.get_guidance_scale(),
             num_inference_steps=config.get_steps(),
             generator=generator,