feat: optimize prompt translation

This commit is contained in:
Manuel Schmid 2024-06-16 21:40:07 +02:00
parent 9c93c18d0b
commit eeb1b79baa
No known key found for this signature in database
GPG Key ID: 32C4F7569B40B84B
1 changed files with 16 additions and 13 deletions

View File

@ -186,6 +186,7 @@ def worker():
from modules.upscaler import perform_upscale
from modules.flags import Performance
from modules.meta_parser import get_metadata_parser
from modules.translator import translate2en
pid = os.getpid()
print(f'Started worker with PID {pid}')
@ -757,11 +758,6 @@ def worker():
return final_scheduler_name
def translate_prompts(async_task):
from modules.translator import translate2en
async_task.prompt = translate2en(async_task.prompt, 'prompt')
async_task.negative_prompt = translate2en(async_task.negative_prompt, 'negative prompt')
def set_hyper_sd_defaults(async_task):
print('Enter Hyper-SD mode.')
progressbar(async_task, 1, 'Downloading Hyper-SD components ...')
@ -901,6 +897,15 @@ def worker():
'face')
return base_model_additional_loras, clip_vision_path, controlnet_canny_path, controlnet_cpds_path, inpaint_head_model_path, inpaint_image, inpaint_mask, ip_adapter_face_path, ip_adapter_path, ip_negative_path, skip_prompt_processing, use_synthetic_refiner
def prepare_enhance_prompt(prompt: str, fallback_prompt: str, translate: bool, type: str):
if len(remove_empty_str([safe_str(p) for p in prompt.splitlines()], default='')) == 0:
prompt = fallback_prompt
else:
if translate:
prompt = translate2en(prompt, type)
prompt = prompt + '\n' + fallback_prompt
return prompt
@torch.no_grad()
@torch.inference_mode()
@ -932,7 +937,8 @@ def worker():
set_hyper_sd_defaults(async_task)
if async_task.translate_prompts:
translate_prompts(async_task)
async_task.prompt = translate2en(async_task.prompt, 'prompt')
async_task.negative_prompt = translate2en(async_task.negative_prompt, 'negative prompt')
print(f'[Parameters] Adaptive CFG = {async_task.adaptive_cfg}')
print(f'[Parameters] CLIP Skip = {async_task.clip_skip}')
@ -1103,9 +1109,7 @@ def worker():
async_task.yields.append(['preview', (current_progress, 'Loading ...', mask)])
# TODO also show do_not_show_finished_images=len(tasks) == 1
yield_result(async_task, mask, async_task.black_out_nsfw, False,
do_not_show_finished_images=len(
tasks) == 1 or async_task.disable_intermediate_results)
yield_result(async_task, mask, async_task.black_out_nsfw, False, do_not_show_finished_images=len(tasks) == 1 or async_task.disable_intermediate_results)
print(f'[Enhance] {dino_detection_count} boxes detected')
print(f'[Enhance] {sam_detection_count} segments detected in boxes')
@ -1126,10 +1130,9 @@ def worker():
if inpaint_patch_model_path_enhance not in base_model_additional_loras_enhance:
base_model_additional_loras_enhance += [(inpaint_patch_model_path_enhance, 1.0)]
if len(remove_empty_str([safe_str(p) for p in enhance_prompt.splitlines()], default='')) == 0:
enhance_prompt = async_task.prompt
if len(remove_empty_str([safe_str(p) for p in enhance_negative_prompt.splitlines()], default='')) == 0:
enhance_negative_prompt = async_task.negative_prompt
progressbar(async_task, current_progress, 'Preparing enhance prompts ...')
enhance_prompt = prepare_enhance_prompt(enhance_prompt, async_task.prompt, async_task.translate_prompts, 'prompt')
enhance_negative_prompt = prepare_enhance_prompt(enhance_negative_prompt, async_task.negative_prompt, async_task.translate_prompts, 'negative prompt')
tasks_enhance, use_expansion, loras = process_prompt(async_task, enhance_prompt,
enhance_negative_prompt,