[BE] estimate token count using a rule of thumb: one token generally corresponds to ~4 characters of text.

HappyZ 2023-05-29 23:33:21 -07:00
parent 9b7d2bd89d
commit 5a990fd550
5 changed files with 21 additions and 17 deletions
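
For context, a quick sketch of the heuristic (not code from this commit), checked against an exact tokenizer count. The CLIP tokenizer below is an assumed stand-in for the pipelines' tokenizer, typical for Stable Diffusion but not pinned by this commit:

from transformers import CLIPTokenizer

def estimate_token_count(text: str) -> float:
    # one token generally corresponds to ~4 characters of text
    return len(text) / 4

prompt = "a photo of an astronaut riding a horse on mars, highly detailed"
print(estimate_token_count(prompt))  # 15.75 (63 characters / 4)

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
print(len(tokenizer(prompt).input_ids))  # exact count, including start/end special tokens

The character-based estimate replaces the old re.split("[ ,]+", ...) word count in the diffs below, presumably because a word count undercounts tokens for prompts with long or hyphenated words.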


@@ -138,7 +138,13 @@ def backend(model, gfpgan_folderpath, is_debugging: bool):
             )
         elif next_job[KEY_JOB_TYPE] == VALUE_JOB_RESTORATION:
             ref_img_filepath = next_job[REFERENCE_IMG]
-            result_dict = gfpgan(gfpgan_folderpath, next_job[UUID], ref_img_filepath, config=config, logger=logger)
+            result_dict = gfpgan(
+                gfpgan_folderpath,
+                next_job[UUID],
+                ref_img_filepath,
+                config=config,
+                logger=logger,
+            )
             if not result_dict:
                 raise ValueError("failed to run gfpgan")
             else:


@@ -281,7 +281,7 @@
             <div class="form-row">
                 <label for="inpaint-strike-size" class="form-label" data-en_XX="Strike Size"
                     data-zh_CN="笔触大小">Strike Size</label>
-                <input type="range" class="form-range" min="1" max="30"
+                <input type="range" class="form-range" min="1" max="50"
                     id="inpaint-strike-size">
                 <output id='range-value'></output>
             </div>
@@ -689,9 +689,7 @@
             },
             error: function (xhr, status, error) {
-                // Handle error response
-                console.log(xhr.responseText);
-                $('#joblist').html("found nothing");
+                $('#joblist').html("try again later");
             }
         });
     });


@@ -44,17 +44,17 @@ class Img2Img:
         self.__logger.info(f"model has max length of {self.__max_length}")
 
     def __token_limit_workaround(self, prompt: str, negative_prompt: str = ""):
-        count_prompt = len(re.split("[ ,]+", prompt))
-        count_negative_prompt = len(re.split("[ ,]+", negative_prompt))
-        if count_prompt < 77 and count_negative_prompt < 77:
+        token_est_count_prompt = len(prompt) / 4
+        token_est_count_neg_prompt = len(negative_prompt) / 4
+        if token_est_count_prompt < 77 and token_est_count_neg_prompt < 77:
             return prompt, None, negative_prompt, None
         self.__logger.info(
             "using workaround to generate embeds instead of direct string"
         )
-        if count_prompt >= count_negative_prompt:
+        if token_est_count_prompt >= token_est_count_neg_prompt:
             input_ids = self.model.img2img_pipeline.tokenizer(
                 prompt, return_tensors="pt", truncation=False
             ).input_ids.to(self.__device)
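
The logger message above ("generate embeds instead of direct string") points at the usual diffusers workaround for prompts past the 77-token CLIP limit. A minimal sketch of that pattern follows, assuming this repo does the same; the pipeline class, model id, and helper name below are illustrative, not code from this commit:

import torch
from diffusers import StableDiffusionImg2ImgPipeline

# model id and device handling are assumptions for illustration only
pipe = StableDiffusionImg2ImgPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
device = "cuda" if torch.cuda.is_available() else "cpu"
pipe = pipe.to(device)

def embed_long_prompt(text: str) -> torch.Tensor:
    # Tokenize without truncation, encode 77-token chunks separately,
    # then concatenate the chunk embeddings along the sequence dimension.
    max_length = pipe.tokenizer.model_max_length  # 77 for CLIP
    input_ids = pipe.tokenizer(
        text, return_tensors="pt", truncation=False
    ).input_ids.to(device)
    with torch.no_grad():
        embeds = [
            pipe.text_encoder(chunk)[0]
            for chunk in input_ids.split(max_length, dim=-1)
        ]
    return torch.cat(embeds, dim=1)

# A real implementation also has to handle per-chunk start/end tokens and
# pad the shorter of prompt/negative prompt so both embedding tensors match
# in sequence length (presumably why the diff compares the two estimated
# counts before deciding which prompt to tokenize first).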


@@ -43,17 +43,17 @@ class Inpainting:
         self.__logger.info(f"model has max length of {self.__max_length}")
 
     def __token_limit_workaround(self, prompt: str, negative_prompt: str = ""):
-        count_prompt = len(re.split("[ ,]+", prompt))
-        count_negative_prompt = len(re.split("[ ,]+", negative_prompt))
-        if count_prompt < 77 and count_negative_prompt < 77:
+        token_est_count_prompt = len(prompt) / 4
+        token_est_count_neg_prompt = len(negative_prompt) / 4
+        if token_est_count_prompt < 77 and token_est_count_neg_prompt < 77:
             return prompt, None, negative_prompt, None
         self.__logger.info(
             "using workaround to generate embeds instead of direct string"
         )
-        if count_prompt >= count_negative_prompt:
+        if token_est_count_prompt >= token_est_count_neg_prompt:
             input_ids = self.model.inpaint_pipeline.tokenizer(
                 prompt, return_tensors="pt", truncation=False
             ).input_ids.to(self.__device)


@@ -41,17 +41,17 @@ class Text2Img:
         self.__logger.info(f"model has max length of {self.__max_length}")
 
     def __token_limit_workaround(self, prompt: str, negative_prompt: str = ""):
-        count_prompt = len(re.split("[ ,]+", prompt))
-        count_negative_prompt = len(re.split("[ ,]+", negative_prompt))
-        if count_prompt < 77 and count_negative_prompt < 77:
+        token_est_count_prompt = len(prompt) / 4
+        token_est_count_neg_prompt = len(negative_prompt) / 4
+        if token_est_count_prompt < 77 and token_est_count_neg_prompt < 77:
             return prompt, None, negative_prompt, None
         self.__logger.info(
             "using workaround to generate embeds instead of direct string"
         )
-        if count_prompt >= count_negative_prompt:
+        if token_est_count_prompt >= token_est_count_neg_prompt:
             input_ids = self.model.txt2img_pipeline.tokenizer(
                 prompt, return_tensors="pt", truncation=False
             ).input_ids.to(self.__device)