feat: add support for enhance prompts

This commit is contained in:
Manuel Schmid 2024-06-16 15:58:27 +02:00
parent b585d9dfa7
commit 541fb2d445
No known key found for this signature in database
GPG Key ID: 32C4F7569B40B84B
1 changed files with 69 additions and 32 deletions

View File

@ -114,20 +114,30 @@ class AsyncTask:
self.enhance_ctrls = []
for _ in range(modules.config.default_enhance_tabs):
enhance_enabled = args.pop()
# enhance_mode = args.pop()
enhance_mask_dino_prompt_text = args.pop()
enhance_mask_box_threshold = args.pop()
enhance_mask_text_threshold = args.pop()
enhance_mask_sam_max_num_boxes = args.pop()
enhance_prompt = args.pop()
enhance_negative_prompt = args.pop()
enhance_mask_sam_model = args.pop()
enhance_mask_text_threshold = args.pop()
enhance_mask_box_threshold = args.pop()
enhance_mask_sam_max_num_boxes = args.pop()
enhance_inpaint_disable_initial_latent = args.pop()
enhance_inpaint_engine = args.pop()
enhance_inpaint_strength = args.pop()
enhance_inpaint_respective_field = args.pop()
if enhance_enabled:
self.enhance_ctrls.append([
# enhance_mode,
enhance_mask_dino_prompt_text,
enhance_mask_box_threshold,
enhance_mask_text_threshold,
enhance_mask_sam_max_num_boxes,
enhance_prompt,
enhance_negative_prompt,
enhance_mask_sam_model,
enhance_mask_text_threshold,
enhance_mask_box_threshold,
enhance_mask_sam_max_num_boxes,
enhance_inpaint_disable_initial_latent,
enhance_inpaint_engine,
enhance_inpaint_strength,
enhance_inpaint_respective_field
])
@ -252,11 +262,10 @@ def worker():
return
def process_task(all_steps, async_task, callback, controlnet_canny_path, controlnet_cpds_path, current_task_id,
denoising_strength, final_scheduler_name, goals, initial_latent, switch, task, tasks,
tiled, use_expansion, width, height):
denoising_strength, final_scheduler_name, goals, initial_latent, switch, positive_cond,
negative_cond, task, tasks, tiled, use_expansion, width, height):
if async_task.last_stop is not False:
ldm_patched.modules.model_management.interrupt_current_processing()
positive_cond, negative_cond = task['c'], task['uc']
if 'cn' in goals:
for cn_flag, cn_path in [
(flags.cn_canny, controlnet_canny_path),
@ -455,7 +464,7 @@ def worker():
def apply_inpaint(async_task, initial_latent, inpaint_head_model_path, inpaint_image,
inpaint_mask, inpaint_parameterized, denoising_strength, inpaint_respective_field, switch,
current_progress, skip_apply_outpaint=False):
inpaint_disable_initial_latent, current_progress, skip_apply_outpaint=False):
if not skip_apply_outpaint:
inpaint_image, inpaint_mask = apply_outpaint(async_task, inpaint_image, inpaint_mask)
@ -503,7 +512,7 @@ def worker():
inpaint_latent_mask=latent_mask,
model=pipeline.final_unet
)
if not async_task.inpaint_disable_initial_latent:
if not inpaint_disable_initial_latent:
initial_latent = {'samples': latent_fill}
B, C, H, W = latent_fill.shape
height, width = H * 8, W * 8
@ -708,7 +717,7 @@ def worker():
else:
progressbar(async_task, 6, f'Encoding negative #{i + 1} ...')
t['uc'] = pipeline.clip_encode(texts=t['negative'], pool_top_k=t['negative_top_k'])
return tasks, use_expansion
return tasks, use_expansion, loras
def apply_freeu(async_task):
print(f'FreeU is enabled!')
@ -965,7 +974,7 @@ def worker():
progressbar(async_task, 1, 'Initializing ...')
if not skip_prompt_processing:
tasks, use_expansion = process_prompt(async_task, base_model_additional_loras, use_expansion, use_style,
tasks, use_expansion, loras = process_prompt(async_task, base_model_additional_loras, use_expansion, use_style,
use_synthetic_refiner)
if len(goals) > 0:
@ -983,11 +992,14 @@ def worker():
if 'inpaint' in goals:
try:
denoising_strength, initial_latent, width, height = apply_inpaint(async_task, initial_latent,
inpaint_head_model_path, inpaint_image,
inpaint_head_model_path,
inpaint_image,
inpaint_mask, inpaint_parameterized,
async_task.inpaint_strength,
async_task.inpaint_respective_field,
switch, 11)
switch,
async_task.inpaint_disable_initial_latent,
11)
except EarlyReturnException:
return
@ -1034,9 +1046,12 @@ def worker():
execution_start_time = time.perf_counter()
try:
imgs, img_paths, current_progress = process_task(all_steps, async_task, callback, controlnet_canny_path, controlnet_cpds_path,
current_task_id, denoising_strength, final_scheduler_name, goals, initial_latent,
switch, task, tasks, tiled, use_expansion, width, height)
imgs, img_paths, current_progress = process_task(all_steps, async_task, callback, controlnet_canny_path,
controlnet_cpds_path,
current_task_id, denoising_strength,
final_scheduler_name, goals, initial_latent,
switch, task['c'], task['uc'], task,
tasks, tiled, use_expansion, width, height)
# enhance
progressbar(async_task, current_progress, 'Processing enhance ...')
@ -1046,7 +1061,7 @@ def worker():
continue
for img in imgs:
for enhance_mask_dino_prompt_text, enhance_mask_box_threshold, enhance_mask_text_threshold, enhance_mask_sam_max_num_boxes, enhance_mask_sam_model in async_task.enhance_ctrls:
for enhance_mask_dino_prompt_text, enhance_prompt, enhance_negative_prompt, enhance_mask_sam_model, enhance_mask_text_threshold, enhance_mask_box_threshold, enhance_mask_sam_max_num_boxes, enhance_inpaint_disable_initial_latent, enhance_inpaint_engine, enhance_inpaint_strength, enhance_inpaint_respective_field in async_task.enhance_ctrls:
print(f'[Enhance] Searching for "{enhance_mask_dino_prompt_text}"')
mask, dino_detection_count, sam_detection_count, sam_detection_on_mask_count = generate_mask_from_image(img, sam_options=SAMOptions(
dino_prompt=enhance_mask_dino_prompt_text,
@ -1093,22 +1108,44 @@ def worker():
# apply_freeu(async_task)
# patch_samplers(async_task)
# defaults from inpaint mode improve details
denoising_strength_enhance = 0.5
inpaint_respective_field_enhance = 0.0
positive_cond = task['c']
if enhance_prompt is not '':
progressbar(async_task, current_progress, f'Encoding positive ...')
positive_cond = pipeline.clip_encode(texts=[enhance_prompt], pool_top_k=1)
negative_cond = task['uc']
if abs(float(async_task.cfg_scale) - 1.0) < 1e-4:
negative_cond = pipeline.clone_cond(positive_cond)
elif enhance_negative_prompt is not '':
progressbar(async_task, current_progress, f'Encoding negative ...')
negative_cond = pipeline.clip_encode(texts=[enhance_negative_prompt], pool_top_k=1)
inpaint_head_model_path_enhance = None
inpaint_parameterized_enhance = False # inpaint_engine = None, improve detail
inpaint_parameterized_enhance = enhance_inpaint_engine != 'None' # inpaint_engine = None, improve detail
if inpaint_parameterized_enhance:
progressbar(async_task, current_progress, 'Downloading inpainter ...')
inpaint_head_model_path_enhance, inpaint_patch_model_path_enhance = modules.config.downloading_inpaint_models(
async_task.inpaint_engine)
if inpaint_patch_model_path_enhance not in base_model_additional_loras:
base_model_additional_loras += [(inpaint_patch_model_path_enhance, 1.0)]
pipeline.refresh_loras(loras, base_model_additional_loras=base_model_additional_loras)
goals_enhance = ['inpaint']
denoising_strength_enhance, initial_latent_enhance, width_enhance, height_enhance = apply_inpaint(
enhance_inpaint_strength, initial_latent_enhance, width_enhance, height_enhance = apply_inpaint(
async_task, None, inpaint_head_model_path_enhance, img, mask,
inpaint_parameterized_enhance, denoising_strength_enhance,
inpaint_respective_field_enhance, switch, current_progress, True)
inpaint_parameterized_enhance, enhance_inpaint_strength,
enhance_inpaint_respective_field, switch, enhance_inpaint_disable_initial_latent,
current_progress, True)
imgs2, img_paths, current_progress = process_task(all_steps, async_task, callback, controlnet_canny_path, controlnet_cpds_path,
current_task_id, denoising_strength_enhance, final_scheduler_name, goals_enhance,
initial_latent_enhance, switch, task, tasks, tiled, use_expansion, width_enhance,
height_enhance)
imgs2, img_paths, current_progress = process_task(all_steps, async_task, callback,
controlnet_canny_path, controlnet_cpds_path,
current_task_id, enhance_inpaint_strength,
final_scheduler_name, goals_enhance,
initial_latent_enhance, switch, positive_cond,
negative_cond, task, tasks,
tiled, use_expansion, width_enhance,
height_enhance)
# reset and prepare next iteration
img = imgs2[0]