feat: refresh the whole pipeline, allows usage of inpaint and enhancement prompts
This commit is contained in:
parent
e1be3fa37a
commit
ff3418876d
|
|
@ -622,10 +622,10 @@ def worker():
|
|||
height = async_task.overwrite_height
|
||||
return height, switch, width
|
||||
|
||||
def process_prompt(async_task, base_model_additional_loras, use_expansion, use_style,
|
||||
def process_prompt(async_task, prompt, negative_prompt, base_model_additional_loras, image_number, disable_seed_increment, use_expansion, use_style,
|
||||
use_synthetic_refiner):
|
||||
prompts = remove_empty_str([safe_str(p) for p in async_task.prompt.splitlines()], default='')
|
||||
negative_prompts = remove_empty_str([safe_str(p) for p in async_task.negative_prompt.splitlines()], default='')
|
||||
prompts = remove_empty_str([safe_str(p) for p in prompt.splitlines()], default='')
|
||||
negative_prompts = remove_empty_str([safe_str(p) for p in negative_prompt.splitlines()], default='')
|
||||
prompt = prompts[0]
|
||||
negative_prompt = negative_prompts[0]
|
||||
if prompt == '':
|
||||
|
|
@ -647,8 +647,8 @@ def worker():
|
|||
pipeline.set_clip_skip(async_task.clip_skip)
|
||||
progressbar(async_task, 3, 'Processing prompts ...')
|
||||
tasks = []
|
||||
for i in range(async_task.image_number):
|
||||
if async_task.disable_seed_increment:
|
||||
for i in range(image_number):
|
||||
if disable_seed_increment:
|
||||
task_seed = async_task.seed % (constants.MAX_SEED + 1)
|
||||
else:
|
||||
task_seed = (async_task.seed + i) % (constants.MAX_SEED + 1) # randint is inclusive, % is not
|
||||
|
|
@ -811,6 +811,95 @@ def worker():
|
|||
async_task.adm_scaler_negative = 1.0
|
||||
async_task.adm_scaler_end = 0.0
|
||||
|
||||
def apply_image_input(async_task, base_model_additional_loras, clip_vision_path, controlnet_canny_path,
|
||||
controlnet_cpds_path, goals, inpaint_head_model_path, inpaint_mask, inpaint_parameterized,
|
||||
ip_adapter_face_path, ip_adapter_path, ip_negative_path, skip_prompt_processing,
|
||||
use_synthetic_refiner):
|
||||
if (async_task.current_tab == 'uov' or (
|
||||
async_task.current_tab == 'ip' and async_task.mixing_image_prompt_and_vary_upscale)) \
|
||||
and async_task.uov_method != flags.disabled and async_task.uov_input_image is not None:
|
||||
async_task.uov_input_image = HWC3(async_task.uov_input_image)
|
||||
if 'vary' in async_task.uov_method:
|
||||
goals.append('vary')
|
||||
elif 'upscale' in async_task.uov_method:
|
||||
goals.append('upscale')
|
||||
if 'fast' in async_task.uov_method:
|
||||
skip_prompt_processing = True
|
||||
else:
|
||||
async_task.steps = async_task.performance_selection.steps_uov()
|
||||
|
||||
progressbar(async_task, 1, 'Downloading upscale models ...')
|
||||
modules.config.downloading_upscale_model()
|
||||
if (async_task.current_tab == 'inpaint' or (
|
||||
async_task.current_tab == 'ip' and async_task.mixing_image_prompt_and_inpaint)) \
|
||||
and isinstance(async_task.inpaint_input_image, dict):
|
||||
inpaint_image = async_task.inpaint_input_image['image']
|
||||
inpaint_mask = async_task.inpaint_input_image['mask'][:, :, 0]
|
||||
|
||||
if async_task.inpaint_mask_upload_checkbox:
|
||||
if isinstance(async_task.inpaint_mask_image_upload, dict):
|
||||
if (isinstance(async_task.inpaint_mask_image_upload['image'], np.ndarray)
|
||||
and isinstance(async_task.inpaint_mask_image_upload['mask'], np.ndarray)
|
||||
and async_task.inpaint_mask_image_upload['image'].ndim == 3):
|
||||
async_task.inpaint_mask_image_upload = np.maximum(
|
||||
async_task.inpaint_mask_image_upload['image'],
|
||||
async_task.inpaint_mask_image_upload['mask'])
|
||||
if isinstance(async_task.inpaint_mask_image_upload,
|
||||
np.ndarray) and async_task.inpaint_mask_image_upload.ndim == 3:
|
||||
H, W, C = inpaint_image.shape
|
||||
async_task.inpaint_mask_image_upload = resample_image(async_task.inpaint_mask_image_upload,
|
||||
width=W, height=H)
|
||||
async_task.inpaint_mask_image_upload = np.mean(async_task.inpaint_mask_image_upload, axis=2)
|
||||
async_task.inpaint_mask_image_upload = (async_task.inpaint_mask_image_upload > 127).astype(
|
||||
np.uint8) * 255
|
||||
inpaint_mask = np.maximum(inpaint_mask, async_task.inpaint_mask_image_upload)
|
||||
|
||||
if int(async_task.inpaint_erode_or_dilate) != 0:
|
||||
inpaint_mask = erode_or_dilate(inpaint_mask, async_task.inpaint_erode_or_dilate)
|
||||
|
||||
if async_task.invert_mask_checkbox:
|
||||
inpaint_mask = 255 - inpaint_mask
|
||||
|
||||
inpaint_image = HWC3(inpaint_image)
|
||||
if isinstance(inpaint_image, np.ndarray) and isinstance(inpaint_mask, np.ndarray) \
|
||||
and (np.any(inpaint_mask > 127) or len(async_task.outpaint_selections) > 0):
|
||||
progressbar(async_task, 1, 'Downloading upscale models ...')
|
||||
modules.config.downloading_upscale_model()
|
||||
if inpaint_parameterized:
|
||||
progressbar(async_task, 1, 'Downloading inpainter ...')
|
||||
inpaint_head_model_path, inpaint_patch_model_path = modules.config.downloading_inpaint_models(
|
||||
async_task.inpaint_engine)
|
||||
base_model_additional_loras += [(inpaint_patch_model_path, 1.0)]
|
||||
print(f'[Inpaint] Current inpaint model is {inpaint_patch_model_path}')
|
||||
if async_task.refiner_model_name == 'None':
|
||||
use_synthetic_refiner = True
|
||||
async_task.refiner_switch = 0.8
|
||||
else:
|
||||
inpaint_head_model_path, inpaint_patch_model_path = None, None
|
||||
print(f'[Inpaint] Parameterized inpaint is disabled.')
|
||||
if async_task.inpaint_additional_prompt != '':
|
||||
if async_task.prompt == '':
|
||||
async_task.prompt = async_task.inpaint_additional_prompt
|
||||
else:
|
||||
async_task.prompt = async_task.inpaint_additional_prompt + '\n' + async_task.prompt
|
||||
goals.append('inpaint')
|
||||
if async_task.current_tab == 'ip' or \
|
||||
async_task.mixing_image_prompt_and_vary_upscale or \
|
||||
async_task.mixing_image_prompt_and_inpaint:
|
||||
goals.append('cn')
|
||||
progressbar(async_task, 1, 'Downloading control models ...')
|
||||
if len(async_task.cn_tasks[flags.cn_canny]) > 0:
|
||||
controlnet_canny_path = modules.config.downloading_controlnet_canny()
|
||||
if len(async_task.cn_tasks[flags.cn_cpds]) > 0:
|
||||
controlnet_cpds_path = modules.config.downloading_controlnet_cpds()
|
||||
if len(async_task.cn_tasks[flags.cn_ip]) > 0:
|
||||
clip_vision_path, ip_negative_path, ip_adapter_path = modules.config.downloading_ip_adapters('ip')
|
||||
if len(async_task.cn_tasks[flags.cn_ip_face]) > 0:
|
||||
clip_vision_path, ip_negative_path, ip_adapter_face_path = modules.config.downloading_ip_adapters(
|
||||
'face')
|
||||
return base_model_additional_loras, clip_vision_path, controlnet_canny_path, controlnet_cpds_path, inpaint_head_model_path, inpaint_image, inpaint_mask, ip_adapter_face_path, ip_adapter_path, ip_negative_path, skip_prompt_processing, use_synthetic_refiner
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
@torch.inference_mode()
|
||||
def handler(async_task: AsyncTask):
|
||||
|
|
@ -882,83 +971,10 @@ def worker():
|
|||
tasks = []
|
||||
|
||||
if async_task.input_image_checkbox:
|
||||
if (async_task.current_tab == 'uov' or (
|
||||
async_task.current_tab == 'ip' and async_task.mixing_image_prompt_and_vary_upscale)) \
|
||||
and async_task.uov_method != flags.disabled and async_task.uov_input_image is not None:
|
||||
async_task.uov_input_image = HWC3(async_task.uov_input_image)
|
||||
if 'vary' in async_task.uov_method:
|
||||
goals.append('vary')
|
||||
elif 'upscale' in async_task.uov_method:
|
||||
goals.append('upscale')
|
||||
if 'fast' in async_task.uov_method:
|
||||
skip_prompt_processing = True
|
||||
else:
|
||||
async_task.steps = async_task.performance_selection.steps_uov()
|
||||
|
||||
progressbar(async_task, 1, 'Downloading upscale models ...')
|
||||
modules.config.downloading_upscale_model()
|
||||
if (async_task.current_tab == 'inpaint' or (
|
||||
async_task.current_tab == 'ip' and async_task.mixing_image_prompt_and_inpaint)) \
|
||||
and isinstance(async_task.inpaint_input_image, dict):
|
||||
inpaint_image = async_task.inpaint_input_image['image']
|
||||
inpaint_mask = async_task.inpaint_input_image['mask'][:, :, 0]
|
||||
|
||||
if async_task.inpaint_mask_upload_checkbox:
|
||||
if isinstance(async_task.inpaint_mask_image_upload, dict):
|
||||
if (isinstance(async_task.inpaint_mask_image_upload['image'], np.ndarray)
|
||||
and isinstance(async_task.inpaint_mask_image_upload['mask'], np.ndarray)
|
||||
and async_task.inpaint_mask_image_upload['image'].ndim == 3):
|
||||
async_task.inpaint_mask_image_upload = np.maximum(async_task.inpaint_mask_image_upload['image'], async_task.inpaint_mask_image_upload['mask'])
|
||||
if isinstance(async_task.inpaint_mask_image_upload, np.ndarray) and async_task.inpaint_mask_image_upload.ndim == 3:
|
||||
H, W, C = inpaint_image.shape
|
||||
async_task.inpaint_mask_image_upload = resample_image(async_task.inpaint_mask_image_upload, width=W, height=H)
|
||||
async_task.inpaint_mask_image_upload = np.mean(async_task.inpaint_mask_image_upload, axis=2)
|
||||
async_task.inpaint_mask_image_upload = (async_task.inpaint_mask_image_upload > 127).astype(np.uint8) * 255
|
||||
inpaint_mask = np.maximum(inpaint_mask, async_task.inpaint_mask_image_upload)
|
||||
|
||||
if int(async_task.inpaint_erode_or_dilate) != 0:
|
||||
inpaint_mask = erode_or_dilate(inpaint_mask, async_task.inpaint_erode_or_dilate)
|
||||
|
||||
if async_task.invert_mask_checkbox:
|
||||
inpaint_mask = 255 - inpaint_mask
|
||||
|
||||
inpaint_image = HWC3(inpaint_image)
|
||||
if isinstance(inpaint_image, np.ndarray) and isinstance(inpaint_mask, np.ndarray) \
|
||||
and (np.any(inpaint_mask > 127) or len(async_task.outpaint_selections) > 0):
|
||||
progressbar(async_task, 1, 'Downloading upscale models ...')
|
||||
modules.config.downloading_upscale_model()
|
||||
if inpaint_parameterized:
|
||||
progressbar(async_task, 1, 'Downloading inpainter ...')
|
||||
inpaint_head_model_path, inpaint_patch_model_path = modules.config.downloading_inpaint_models(
|
||||
async_task.inpaint_engine)
|
||||
base_model_additional_loras += [(inpaint_patch_model_path, 1.0)]
|
||||
print(f'[Inpaint] Current inpaint model is {inpaint_patch_model_path}')
|
||||
if async_task.refiner_model_name == 'None':
|
||||
use_synthetic_refiner = True
|
||||
async_task.refiner_switch = 0.8
|
||||
else:
|
||||
inpaint_head_model_path, inpaint_patch_model_path = None, None
|
||||
print(f'[Inpaint] Parameterized inpaint is disabled.')
|
||||
if async_task.inpaint_additional_prompt != '':
|
||||
if async_task.prompt == '':
|
||||
async_task.prompt = async_task.inpaint_additional_prompt
|
||||
else:
|
||||
async_task.prompt = async_task.inpaint_additional_prompt + '\n' + async_task.prompt
|
||||
goals.append('inpaint')
|
||||
if async_task.current_tab == 'ip' or \
|
||||
async_task.mixing_image_prompt_and_vary_upscale or \
|
||||
async_task.mixing_image_prompt_and_inpaint:
|
||||
goals.append('cn')
|
||||
progressbar(async_task, 1, 'Downloading control models ...')
|
||||
if len(async_task.cn_tasks[flags.cn_canny]) > 0:
|
||||
controlnet_canny_path = modules.config.downloading_controlnet_canny()
|
||||
if len(async_task.cn_tasks[flags.cn_cpds]) > 0:
|
||||
controlnet_cpds_path = modules.config.downloading_controlnet_cpds()
|
||||
if len(async_task.cn_tasks[flags.cn_ip]) > 0:
|
||||
clip_vision_path, ip_negative_path, ip_adapter_path = modules.config.downloading_ip_adapters('ip')
|
||||
if len(async_task.cn_tasks[flags.cn_ip_face]) > 0:
|
||||
clip_vision_path, ip_negative_path, ip_adapter_face_path = modules.config.downloading_ip_adapters(
|
||||
'face')
|
||||
base_model_additional_loras, clip_vision_path, controlnet_canny_path, controlnet_cpds_path, inpaint_head_model_path, inpaint_image, inpaint_mask, ip_adapter_face_path, ip_adapter_path, ip_negative_path, skip_prompt_processing, use_synthetic_refiner = apply_image_input(
|
||||
async_task, base_model_additional_loras, clip_vision_path, controlnet_canny_path, controlnet_cpds_path,
|
||||
goals, inpaint_head_model_path, inpaint_mask, inpaint_parameterized, ip_adapter_face_path, ip_adapter_path,
|
||||
ip_negative_path, skip_prompt_processing, use_synthetic_refiner)
|
||||
|
||||
|
||||
# Load or unload CNs
|
||||
|
|
@ -975,8 +991,10 @@ def worker():
|
|||
progressbar(async_task, 1, 'Initializing ...')
|
||||
|
||||
if not skip_prompt_processing:
|
||||
tasks, use_expansion, loras = process_prompt(async_task, base_model_additional_loras, use_expansion, use_style,
|
||||
use_synthetic_refiner)
|
||||
tasks, use_expansion, loras = process_prompt(async_task, async_task.prompt, async_task.negative_prompt,
|
||||
base_model_additional_loras, async_task.image_number,
|
||||
async_task.disable_seed_increment, use_expansion, use_style,
|
||||
use_synthetic_refiner)
|
||||
|
||||
if len(goals) > 0:
|
||||
progressbar(async_task, 7, 'Image processing ...')
|
||||
|
|
@ -1054,13 +1072,12 @@ def worker():
|
|||
switch, task['c'], task['uc'], task,
|
||||
tasks, tiled, use_expansion, width, height)
|
||||
|
||||
if not async_task.enhance_checkbox or len(async_task.enhance_ctrls) == 0 or 'inpaint' in goals:
|
||||
if not async_task.enhance_checkbox or len(async_task.enhance_ctrls) == 0:
|
||||
print(f'[Enhance] Skipping, preconditions aren\'t met')
|
||||
continue
|
||||
|
||||
# enhance
|
||||
progressbar(async_task, current_progress, 'Processing enhance ...')
|
||||
final_unet = pipeline.final_unet.clone()
|
||||
|
||||
for img in imgs:
|
||||
for enhance_mask_dino_prompt_text, enhance_prompt, enhance_negative_prompt, enhance_mask_sam_model, enhance_mask_text_threshold, enhance_mask_box_threshold, enhance_mask_sam_max_num_boxes, enhance_inpaint_disable_initial_latent, enhance_inpaint_engine, enhance_inpaint_strength, enhance_inpaint_respective_field in async_task.enhance_ctrls:
|
||||
|
|
@ -1076,6 +1093,9 @@ def worker():
|
|||
))
|
||||
mask = mask[:, :, 0]
|
||||
|
||||
if int(async_task.inpaint_erode_or_dilate) != 0:
|
||||
mask = erode_or_dilate(mask, async_task.inpaint_erode_or_dilate)
|
||||
|
||||
async_task.yields.append(['preview', (current_progress, 'Loading ...', mask)])
|
||||
# TODO also show do_not_show_finished_images=len(tasks) == 1
|
||||
yield_result(async_task, mask, async_task.black_out_nsfw, False,
|
||||
|
|
@ -1090,38 +1110,7 @@ def worker():
|
|||
print(f'[Enhance] No "{enhance_mask_dino_prompt_text}" detected, skipping')
|
||||
continue
|
||||
|
||||
# TODO make configurable
|
||||
|
||||
# # do not apply loras / controlnets / etc. twice (samplers are needed though)
|
||||
# pipeline.final_unet = pipeline.model_base.unet.clone()
|
||||
|
||||
# pipeline.refresh_everything(refiner_model_name=async_task.refiner_model_name,
|
||||
# base_model_name=async_task.base_model_name,
|
||||
# loras=[],
|
||||
# base_model_additional_loras=[],
|
||||
# use_synthetic_refiner=use_synthetic_refiner,
|
||||
# vae_name=async_task.vae_name)
|
||||
# pipeline.set_clip_skip(async_task.clip_skip)
|
||||
#
|
||||
# # patch everything again except original inpainting
|
||||
# if 'cn' in goals:
|
||||
# apply_control_nets(async_task, height, ip_adapter_face_path, ip_adapter_path, width)
|
||||
# if async_task.freeu_enabled:
|
||||
# apply_freeu(async_task)
|
||||
# patch_samplers(async_task)
|
||||
|
||||
positive_cond = task['c']
|
||||
if enhance_prompt is not '':
|
||||
progressbar(async_task, current_progress, f'Encoding positive ...')
|
||||
positive_cond = pipeline.clip_encode(texts=[enhance_prompt], pool_top_k=1)
|
||||
|
||||
negative_cond = task['uc']
|
||||
if abs(float(async_task.cfg_scale) - 1.0) < 1e-4:
|
||||
negative_cond = pipeline.clone_cond(positive_cond)
|
||||
elif enhance_negative_prompt != '':
|
||||
progressbar(async_task, current_progress, f'Encoding negative ...')
|
||||
negative_cond = pipeline.clip_encode(texts=[enhance_negative_prompt], pool_top_k=1)
|
||||
|
||||
base_model_additional_loras_enhance = []
|
||||
inpaint_head_model_path_enhance = None
|
||||
inpaint_parameterized_enhance = enhance_inpaint_engine != 'None' # inpaint_engine = None, improve detail
|
||||
|
||||
|
|
@ -1129,9 +1118,28 @@ def worker():
|
|||
progressbar(async_task, current_progress, 'Downloading inpainter ...')
|
||||
inpaint_head_model_path_enhance, inpaint_patch_model_path_enhance = modules.config.downloading_inpaint_models(
|
||||
async_task.inpaint_engine)
|
||||
if inpaint_patch_model_path_enhance not in base_model_additional_loras:
|
||||
base_model_additional_loras += [(inpaint_patch_model_path_enhance, 1.0)]
|
||||
pipeline.refresh_loras(loras, base_model_additional_loras=base_model_additional_loras)
|
||||
if inpaint_patch_model_path_enhance not in base_model_additional_loras_enhance:
|
||||
base_model_additional_loras_enhance += [(inpaint_patch_model_path_enhance, 1.0)]
|
||||
|
||||
if len(remove_empty_str([safe_str(p) for p in enhance_prompt.splitlines()], default='')) == 0:
|
||||
enhance_prompt = async_task.prompt
|
||||
if len(remove_empty_str([safe_str(p) for p in enhance_negative_prompt.splitlines()], default='')) == 0:
|
||||
enhance_negative_prompt = async_task.negative_prompt
|
||||
|
||||
tasks_enhance, use_expansion, loras = process_prompt(async_task, enhance_prompt,
|
||||
enhance_negative_prompt,
|
||||
base_model_additional_loras_enhance,
|
||||
1, True,
|
||||
use_expansion, use_style,
|
||||
use_synthetic_refiner)
|
||||
task_enhance = tasks_enhance[0]
|
||||
|
||||
# TODO could support vary, upscale and CN in the future
|
||||
# if 'cn' in goals:
|
||||
# apply_control_nets(async_task, height, ip_adapter_face_path, ip_adapter_path, width)
|
||||
if async_task.freeu_enabled:
|
||||
apply_freeu(async_task)
|
||||
patch_samplers(async_task)
|
||||
|
||||
goals_enhance = ['inpaint']
|
||||
enhance_inpaint_strength, initial_latent_enhance, width_enhance, height_enhance = apply_inpaint(
|
||||
|
|
@ -1144,14 +1152,13 @@ def worker():
|
|||
controlnet_canny_path, controlnet_cpds_path,
|
||||
current_task_id, enhance_inpaint_strength,
|
||||
final_scheduler_name, goals_enhance,
|
||||
initial_latent_enhance, switch, positive_cond,
|
||||
negative_cond, task, tasks,
|
||||
tiled, use_expansion, width_enhance,
|
||||
height_enhance)
|
||||
initial_latent_enhance, switch,
|
||||
task_enhance['c'], task_enhance['uc'],
|
||||
task_enhance, tasks_enhance, tiled,
|
||||
use_expansion, width_enhance, height_enhance)
|
||||
|
||||
# reset and prepare next iteration
|
||||
img = imgs2[0]
|
||||
pipeline.final_unet = final_unet
|
||||
|
||||
except ldm_patched.modules.model_management.InterruptProcessingException:
|
||||
if async_task.last_stop == 'skip':
|
||||
|
|
@ -1171,7 +1178,6 @@ def worker():
|
|||
processing_time = time.perf_counter() - processing_start_time
|
||||
print(f'Processing time (total): {processing_time:.2f} seconds')
|
||||
|
||||
|
||||
while True:
|
||||
time.sleep(0.01)
|
||||
if len(async_tasks) > 0:
|
||||
|
|
|
|||
2
webui.py
2
webui.py
|
|
@ -336,8 +336,6 @@ with shared.gradio_root:
|
|||
with gr.TabItem(label=f'#{index + 1}') as enhance_tab_item:
|
||||
enhance_enabled = gr.Checkbox(label='Enable', value=False, elem_classes='min_check',
|
||||
container=False)
|
||||
gr.HTML(
|
||||
'DISCLAIMER: The enhance feature does not work with Inpaint or Outpaint and will be skipped.')
|
||||
|
||||
enhance_mask_dino_prompt_text = gr.Textbox(label='Detection prompt',
|
||||
info='Use singular whenever possible',
|
||||
|
|
|
|||
Loading…
Reference in New Issue