feat: refresh the whole pipeline, allows usage of inpaint and enhancement prompts

This commit is contained in:
Manuel Schmid 2024-06-16 19:27:31 +02:00
parent e1be3fa37a
commit ff3418876d
No known key found for this signature in database
GPG Key ID: 32C4F7569B40B84B
2 changed files with 133 additions and 129 deletions

View File

@ -622,10 +622,10 @@ def worker():
height = async_task.overwrite_height
return height, switch, width
def process_prompt(async_task, base_model_additional_loras, use_expansion, use_style,
def process_prompt(async_task, prompt, negative_prompt, base_model_additional_loras, image_number, disable_seed_increment, use_expansion, use_style,
use_synthetic_refiner):
prompts = remove_empty_str([safe_str(p) for p in async_task.prompt.splitlines()], default='')
negative_prompts = remove_empty_str([safe_str(p) for p in async_task.negative_prompt.splitlines()], default='')
prompts = remove_empty_str([safe_str(p) for p in prompt.splitlines()], default='')
negative_prompts = remove_empty_str([safe_str(p) for p in negative_prompt.splitlines()], default='')
prompt = prompts[0]
negative_prompt = negative_prompts[0]
if prompt == '':
@ -647,8 +647,8 @@ def worker():
pipeline.set_clip_skip(async_task.clip_skip)
progressbar(async_task, 3, 'Processing prompts ...')
tasks = []
for i in range(async_task.image_number):
if async_task.disable_seed_increment:
for i in range(image_number):
if disable_seed_increment:
task_seed = async_task.seed % (constants.MAX_SEED + 1)
else:
task_seed = (async_task.seed + i) % (constants.MAX_SEED + 1) # randint is inclusive, % is not
@ -811,6 +811,95 @@ def worker():
async_task.adm_scaler_negative = 1.0
async_task.adm_scaler_end = 0.0
def apply_image_input(async_task, base_model_additional_loras, clip_vision_path, controlnet_canny_path,
controlnet_cpds_path, goals, inpaint_head_model_path, inpaint_mask, inpaint_parameterized,
ip_adapter_face_path, ip_adapter_path, ip_negative_path, skip_prompt_processing,
use_synthetic_refiner):
if (async_task.current_tab == 'uov' or (
async_task.current_tab == 'ip' and async_task.mixing_image_prompt_and_vary_upscale)) \
and async_task.uov_method != flags.disabled and async_task.uov_input_image is not None:
async_task.uov_input_image = HWC3(async_task.uov_input_image)
if 'vary' in async_task.uov_method:
goals.append('vary')
elif 'upscale' in async_task.uov_method:
goals.append('upscale')
if 'fast' in async_task.uov_method:
skip_prompt_processing = True
else:
async_task.steps = async_task.performance_selection.steps_uov()
progressbar(async_task, 1, 'Downloading upscale models ...')
modules.config.downloading_upscale_model()
if (async_task.current_tab == 'inpaint' or (
async_task.current_tab == 'ip' and async_task.mixing_image_prompt_and_inpaint)) \
and isinstance(async_task.inpaint_input_image, dict):
inpaint_image = async_task.inpaint_input_image['image']
inpaint_mask = async_task.inpaint_input_image['mask'][:, :, 0]
if async_task.inpaint_mask_upload_checkbox:
if isinstance(async_task.inpaint_mask_image_upload, dict):
if (isinstance(async_task.inpaint_mask_image_upload['image'], np.ndarray)
and isinstance(async_task.inpaint_mask_image_upload['mask'], np.ndarray)
and async_task.inpaint_mask_image_upload['image'].ndim == 3):
async_task.inpaint_mask_image_upload = np.maximum(
async_task.inpaint_mask_image_upload['image'],
async_task.inpaint_mask_image_upload['mask'])
if isinstance(async_task.inpaint_mask_image_upload,
np.ndarray) and async_task.inpaint_mask_image_upload.ndim == 3:
H, W, C = inpaint_image.shape
async_task.inpaint_mask_image_upload = resample_image(async_task.inpaint_mask_image_upload,
width=W, height=H)
async_task.inpaint_mask_image_upload = np.mean(async_task.inpaint_mask_image_upload, axis=2)
async_task.inpaint_mask_image_upload = (async_task.inpaint_mask_image_upload > 127).astype(
np.uint8) * 255
inpaint_mask = np.maximum(inpaint_mask, async_task.inpaint_mask_image_upload)
if int(async_task.inpaint_erode_or_dilate) != 0:
inpaint_mask = erode_or_dilate(inpaint_mask, async_task.inpaint_erode_or_dilate)
if async_task.invert_mask_checkbox:
inpaint_mask = 255 - inpaint_mask
inpaint_image = HWC3(inpaint_image)
if isinstance(inpaint_image, np.ndarray) and isinstance(inpaint_mask, np.ndarray) \
and (np.any(inpaint_mask > 127) or len(async_task.outpaint_selections) > 0):
progressbar(async_task, 1, 'Downloading upscale models ...')
modules.config.downloading_upscale_model()
if inpaint_parameterized:
progressbar(async_task, 1, 'Downloading inpainter ...')
inpaint_head_model_path, inpaint_patch_model_path = modules.config.downloading_inpaint_models(
async_task.inpaint_engine)
base_model_additional_loras += [(inpaint_patch_model_path, 1.0)]
print(f'[Inpaint] Current inpaint model is {inpaint_patch_model_path}')
if async_task.refiner_model_name == 'None':
use_synthetic_refiner = True
async_task.refiner_switch = 0.8
else:
inpaint_head_model_path, inpaint_patch_model_path = None, None
print(f'[Inpaint] Parameterized inpaint is disabled.')
if async_task.inpaint_additional_prompt != '':
if async_task.prompt == '':
async_task.prompt = async_task.inpaint_additional_prompt
else:
async_task.prompt = async_task.inpaint_additional_prompt + '\n' + async_task.prompt
goals.append('inpaint')
if async_task.current_tab == 'ip' or \
async_task.mixing_image_prompt_and_vary_upscale or \
async_task.mixing_image_prompt_and_inpaint:
goals.append('cn')
progressbar(async_task, 1, 'Downloading control models ...')
if len(async_task.cn_tasks[flags.cn_canny]) > 0:
controlnet_canny_path = modules.config.downloading_controlnet_canny()
if len(async_task.cn_tasks[flags.cn_cpds]) > 0:
controlnet_cpds_path = modules.config.downloading_controlnet_cpds()
if len(async_task.cn_tasks[flags.cn_ip]) > 0:
clip_vision_path, ip_negative_path, ip_adapter_path = modules.config.downloading_ip_adapters('ip')
if len(async_task.cn_tasks[flags.cn_ip_face]) > 0:
clip_vision_path, ip_negative_path, ip_adapter_face_path = modules.config.downloading_ip_adapters(
'face')
return base_model_additional_loras, clip_vision_path, controlnet_canny_path, controlnet_cpds_path, inpaint_head_model_path, inpaint_image, inpaint_mask, ip_adapter_face_path, ip_adapter_path, ip_negative_path, skip_prompt_processing, use_synthetic_refiner
@torch.no_grad()
@torch.inference_mode()
def handler(async_task: AsyncTask):
@ -882,83 +971,10 @@ def worker():
tasks = []
if async_task.input_image_checkbox:
if (async_task.current_tab == 'uov' or (
async_task.current_tab == 'ip' and async_task.mixing_image_prompt_and_vary_upscale)) \
and async_task.uov_method != flags.disabled and async_task.uov_input_image is not None:
async_task.uov_input_image = HWC3(async_task.uov_input_image)
if 'vary' in async_task.uov_method:
goals.append('vary')
elif 'upscale' in async_task.uov_method:
goals.append('upscale')
if 'fast' in async_task.uov_method:
skip_prompt_processing = True
else:
async_task.steps = async_task.performance_selection.steps_uov()
progressbar(async_task, 1, 'Downloading upscale models ...')
modules.config.downloading_upscale_model()
if (async_task.current_tab == 'inpaint' or (
async_task.current_tab == 'ip' and async_task.mixing_image_prompt_and_inpaint)) \
and isinstance(async_task.inpaint_input_image, dict):
inpaint_image = async_task.inpaint_input_image['image']
inpaint_mask = async_task.inpaint_input_image['mask'][:, :, 0]
if async_task.inpaint_mask_upload_checkbox:
if isinstance(async_task.inpaint_mask_image_upload, dict):
if (isinstance(async_task.inpaint_mask_image_upload['image'], np.ndarray)
and isinstance(async_task.inpaint_mask_image_upload['mask'], np.ndarray)
and async_task.inpaint_mask_image_upload['image'].ndim == 3):
async_task.inpaint_mask_image_upload = np.maximum(async_task.inpaint_mask_image_upload['image'], async_task.inpaint_mask_image_upload['mask'])
if isinstance(async_task.inpaint_mask_image_upload, np.ndarray) and async_task.inpaint_mask_image_upload.ndim == 3:
H, W, C = inpaint_image.shape
async_task.inpaint_mask_image_upload = resample_image(async_task.inpaint_mask_image_upload, width=W, height=H)
async_task.inpaint_mask_image_upload = np.mean(async_task.inpaint_mask_image_upload, axis=2)
async_task.inpaint_mask_image_upload = (async_task.inpaint_mask_image_upload > 127).astype(np.uint8) * 255
inpaint_mask = np.maximum(inpaint_mask, async_task.inpaint_mask_image_upload)
if int(async_task.inpaint_erode_or_dilate) != 0:
inpaint_mask = erode_or_dilate(inpaint_mask, async_task.inpaint_erode_or_dilate)
if async_task.invert_mask_checkbox:
inpaint_mask = 255 - inpaint_mask
inpaint_image = HWC3(inpaint_image)
if isinstance(inpaint_image, np.ndarray) and isinstance(inpaint_mask, np.ndarray) \
and (np.any(inpaint_mask > 127) or len(async_task.outpaint_selections) > 0):
progressbar(async_task, 1, 'Downloading upscale models ...')
modules.config.downloading_upscale_model()
if inpaint_parameterized:
progressbar(async_task, 1, 'Downloading inpainter ...')
inpaint_head_model_path, inpaint_patch_model_path = modules.config.downloading_inpaint_models(
async_task.inpaint_engine)
base_model_additional_loras += [(inpaint_patch_model_path, 1.0)]
print(f'[Inpaint] Current inpaint model is {inpaint_patch_model_path}')
if async_task.refiner_model_name == 'None':
use_synthetic_refiner = True
async_task.refiner_switch = 0.8
else:
inpaint_head_model_path, inpaint_patch_model_path = None, None
print(f'[Inpaint] Parameterized inpaint is disabled.')
if async_task.inpaint_additional_prompt != '':
if async_task.prompt == '':
async_task.prompt = async_task.inpaint_additional_prompt
else:
async_task.prompt = async_task.inpaint_additional_prompt + '\n' + async_task.prompt
goals.append('inpaint')
if async_task.current_tab == 'ip' or \
async_task.mixing_image_prompt_and_vary_upscale or \
async_task.mixing_image_prompt_and_inpaint:
goals.append('cn')
progressbar(async_task, 1, 'Downloading control models ...')
if len(async_task.cn_tasks[flags.cn_canny]) > 0:
controlnet_canny_path = modules.config.downloading_controlnet_canny()
if len(async_task.cn_tasks[flags.cn_cpds]) > 0:
controlnet_cpds_path = modules.config.downloading_controlnet_cpds()
if len(async_task.cn_tasks[flags.cn_ip]) > 0:
clip_vision_path, ip_negative_path, ip_adapter_path = modules.config.downloading_ip_adapters('ip')
if len(async_task.cn_tasks[flags.cn_ip_face]) > 0:
clip_vision_path, ip_negative_path, ip_adapter_face_path = modules.config.downloading_ip_adapters(
'face')
base_model_additional_loras, clip_vision_path, controlnet_canny_path, controlnet_cpds_path, inpaint_head_model_path, inpaint_image, inpaint_mask, ip_adapter_face_path, ip_adapter_path, ip_negative_path, skip_prompt_processing, use_synthetic_refiner = apply_image_input(
async_task, base_model_additional_loras, clip_vision_path, controlnet_canny_path, controlnet_cpds_path,
goals, inpaint_head_model_path, inpaint_mask, inpaint_parameterized, ip_adapter_face_path, ip_adapter_path,
ip_negative_path, skip_prompt_processing, use_synthetic_refiner)
# Load or unload CNs
@ -975,8 +991,10 @@ def worker():
progressbar(async_task, 1, 'Initializing ...')
if not skip_prompt_processing:
tasks, use_expansion, loras = process_prompt(async_task, base_model_additional_loras, use_expansion, use_style,
use_synthetic_refiner)
tasks, use_expansion, loras = process_prompt(async_task, async_task.prompt, async_task.negative_prompt,
base_model_additional_loras, async_task.image_number,
async_task.disable_seed_increment, use_expansion, use_style,
use_synthetic_refiner)
if len(goals) > 0:
progressbar(async_task, 7, 'Image processing ...')
@ -1054,13 +1072,12 @@ def worker():
switch, task['c'], task['uc'], task,
tasks, tiled, use_expansion, width, height)
if not async_task.enhance_checkbox or len(async_task.enhance_ctrls) == 0 or 'inpaint' in goals:
if not async_task.enhance_checkbox or len(async_task.enhance_ctrls) == 0:
print(f'[Enhance] Skipping, preconditions aren\'t met')
continue
# enhance
progressbar(async_task, current_progress, 'Processing enhance ...')
final_unet = pipeline.final_unet.clone()
for img in imgs:
for enhance_mask_dino_prompt_text, enhance_prompt, enhance_negative_prompt, enhance_mask_sam_model, enhance_mask_text_threshold, enhance_mask_box_threshold, enhance_mask_sam_max_num_boxes, enhance_inpaint_disable_initial_latent, enhance_inpaint_engine, enhance_inpaint_strength, enhance_inpaint_respective_field in async_task.enhance_ctrls:
@ -1076,6 +1093,9 @@ def worker():
))
mask = mask[:, :, 0]
if int(async_task.inpaint_erode_or_dilate) != 0:
mask = erode_or_dilate(mask, async_task.inpaint_erode_or_dilate)
async_task.yields.append(['preview', (current_progress, 'Loading ...', mask)])
# TODO also show do_not_show_finished_images=len(tasks) == 1
yield_result(async_task, mask, async_task.black_out_nsfw, False,
@ -1090,38 +1110,7 @@ def worker():
print(f'[Enhance] No "{enhance_mask_dino_prompt_text}" detected, skipping')
continue
# TODO make configurable
# # do not apply loras / controlnets / etc. twice (samplers are needed though)
# pipeline.final_unet = pipeline.model_base.unet.clone()
# pipeline.refresh_everything(refiner_model_name=async_task.refiner_model_name,
# base_model_name=async_task.base_model_name,
# loras=[],
# base_model_additional_loras=[],
# use_synthetic_refiner=use_synthetic_refiner,
# vae_name=async_task.vae_name)
# pipeline.set_clip_skip(async_task.clip_skip)
#
# # patch everything again except original inpainting
# if 'cn' in goals:
# apply_control_nets(async_task, height, ip_adapter_face_path, ip_adapter_path, width)
# if async_task.freeu_enabled:
# apply_freeu(async_task)
# patch_samplers(async_task)
positive_cond = task['c']
if enhance_prompt is not '':
progressbar(async_task, current_progress, f'Encoding positive ...')
positive_cond = pipeline.clip_encode(texts=[enhance_prompt], pool_top_k=1)
negative_cond = task['uc']
if abs(float(async_task.cfg_scale) - 1.0) < 1e-4:
negative_cond = pipeline.clone_cond(positive_cond)
elif enhance_negative_prompt != '':
progressbar(async_task, current_progress, f'Encoding negative ...')
negative_cond = pipeline.clip_encode(texts=[enhance_negative_prompt], pool_top_k=1)
base_model_additional_loras_enhance = []
inpaint_head_model_path_enhance = None
inpaint_parameterized_enhance = enhance_inpaint_engine != 'None' # inpaint_engine = None, improve detail
@ -1129,9 +1118,28 @@ def worker():
progressbar(async_task, current_progress, 'Downloading inpainter ...')
inpaint_head_model_path_enhance, inpaint_patch_model_path_enhance = modules.config.downloading_inpaint_models(
async_task.inpaint_engine)
if inpaint_patch_model_path_enhance not in base_model_additional_loras:
base_model_additional_loras += [(inpaint_patch_model_path_enhance, 1.0)]
pipeline.refresh_loras(loras, base_model_additional_loras=base_model_additional_loras)
if inpaint_patch_model_path_enhance not in base_model_additional_loras_enhance:
base_model_additional_loras_enhance += [(inpaint_patch_model_path_enhance, 1.0)]
if len(remove_empty_str([safe_str(p) for p in enhance_prompt.splitlines()], default='')) == 0:
enhance_prompt = async_task.prompt
if len(remove_empty_str([safe_str(p) for p in enhance_negative_prompt.splitlines()], default='')) == 0:
enhance_negative_prompt = async_task.negative_prompt
tasks_enhance, use_expansion, loras = process_prompt(async_task, enhance_prompt,
enhance_negative_prompt,
base_model_additional_loras_enhance,
1, True,
use_expansion, use_style,
use_synthetic_refiner)
task_enhance = tasks_enhance[0]
# TODO could support vary, upscale and CN in the future
# if 'cn' in goals:
# apply_control_nets(async_task, height, ip_adapter_face_path, ip_adapter_path, width)
if async_task.freeu_enabled:
apply_freeu(async_task)
patch_samplers(async_task)
goals_enhance = ['inpaint']
enhance_inpaint_strength, initial_latent_enhance, width_enhance, height_enhance = apply_inpaint(
@ -1144,14 +1152,13 @@ def worker():
controlnet_canny_path, controlnet_cpds_path,
current_task_id, enhance_inpaint_strength,
final_scheduler_name, goals_enhance,
initial_latent_enhance, switch, positive_cond,
negative_cond, task, tasks,
tiled, use_expansion, width_enhance,
height_enhance)
initial_latent_enhance, switch,
task_enhance['c'], task_enhance['uc'],
task_enhance, tasks_enhance, tiled,
use_expansion, width_enhance, height_enhance)
# reset and prepare next iteration
img = imgs2[0]
pipeline.final_unet = final_unet
except ldm_patched.modules.model_management.InterruptProcessingException:
if async_task.last_stop == 'skip':
@ -1171,7 +1178,6 @@ def worker():
processing_time = time.perf_counter() - processing_start_time
print(f'Processing time (total): {processing_time:.2f} seconds')
while True:
time.sleep(0.01)
if len(async_tasks) > 0:

View File

@ -336,8 +336,6 @@ with shared.gradio_root:
with gr.TabItem(label=f'#{index + 1}') as enhance_tab_item:
enhance_enabled = gr.Checkbox(label='Enable', value=False, elem_classes='min_check',
container=False)
gr.HTML(
'DISCLAIMER: The enhance feature does not work with Inpaint or Outpaint and will be skipped.')
enhance_mask_dino_prompt_text = gr.Textbox(label='Detection prompt',
info='Use singular whenever possible',