feat: add enhance image input

use this so you don't have to modify an image before enhancement
This commit is contained in:
Manuel Schmid 2024-06-21 23:51:00 +02:00
parent c7a411a8c7
commit 40e1c82b74
No known key found for this signature in database
GPG Key ID: 32C4F7569B40B84B
2 changed files with 122 additions and 105 deletions

View File

@ -113,6 +113,7 @@ class AsyncTask:
self.dino_erode_or_dilate = args.pop()
self.debugging_enhance_masks_checkbox = args.pop()
self.enhance_input_image = args.pop()
self.enhance_checkbox = args.pop()
self.enhance_uov_method = args.pop()
self.enhance_ctrls = []
@ -569,7 +570,7 @@ def worker():
H, W, C = uov_input_image.shape
if advance_progress:
current_progress += 1
progressbar(async_task, current_progress, f'Upscaling image from {str((H, W))} ...')
progressbar(async_task, current_progress, f'Upscaling image from {str((W, H))} ...')
uov_input_image = perform_upscale(uov_input_image)
print(f'Image upscaled.')
if '1.5x' in uov_method:
@ -843,7 +844,7 @@ def worker():
skip_prompt_processing, use_synthetic_refiner):
if (async_task.current_tab == 'uov' or (
async_task.current_tab == 'ip' and async_task.mixing_image_prompt_and_vary_upscale)) \
and async_task.uov_method != flags.disabled and async_task.uov_input_image is not None:
and async_task.uov_method != flags.disabled.lower() and async_task.uov_input_image is not None:
async_task.uov_input_image, skip_prompt_processing, async_task.steps = prepare_upscale(
async_task, goals, async_task.uov_input_image, async_task.uov_method, async_task.performance_selection,
async_task.steps, 1, skip_prompt_processing=skip_prompt_processing)
@ -914,6 +915,10 @@ def worker():
if len(async_task.cn_tasks[flags.cn_ip_face]) > 0:
clip_vision_path, ip_negative_path, ip_adapter_face_path = modules.config.downloading_ip_adapters(
'face')
if async_task.current_tab == 'enhance' and async_task.enhance_input_image is not None:
goals.append('enhance')
skip_prompt_processing = True
async_task.enhance_input_image = HWC3(async_task.enhance_input_image)
return base_model_additional_loras, clip_vision_path, controlnet_canny_path, controlnet_cpds_path, inpaint_head_model_path, inpaint_image, inpaint_mask, ip_adapter_face_path, ip_adapter_path, ip_negative_path, skip_prompt_processing, use_synthetic_refiner
def prepare_upscale(async_task, goals, uov_input_image, uov_method, performance, steps, current_progress,
@ -1161,18 +1166,26 @@ def worker():
if async_task.freeu_enabled:
apply_freeu(async_task)
# async_task.steps can have value of uov steps here when upscale has been applied
steps, _, _, _ = apply_overrides(async_task, async_task.steps, height, width)
images_to_enhance = []
if 'enhance' in goals:
images_to_enhance += [async_task.enhance_input_image]
height, width, _ = async_task.enhance_input_image.shape
# input image already provided, processing is skipped
steps = 0
all_steps = steps * async_task.image_number
# enhance_upscale_steps = 0
# enhance_upscale_steps_total = 0
if async_task.enhance_checkbox and async_task.enhance_uov_method != flags.disabled:
if async_task.enhance_checkbox and async_task.enhance_uov_method != flags.disabled.lower():
enhance_upscale_steps, _, _, _ = apply_overrides(async_task, async_task.performance_selection.steps_uov(), height, width)
enhance_upscale_steps_total = async_task.image_number * enhance_upscale_steps
all_steps += enhance_upscale_steps_total
if async_task.enhance_checkbox and len(async_task.enhance_ctrls) != 0:
all_steps += async_task.image_number * len(async_task.enhance_ctrls) * steps
enhance_steps, _, _, _ = apply_overrides(async_task, async_task.original_steps, height, width)
all_steps += async_task.image_number * len(async_task.enhance_ctrls) * enhance_steps
print(f'[Parameters] Denoising Strength = {denoising_strength}')
@ -1205,8 +1218,6 @@ def worker():
int(current_progress + async_task.callback_steps),
f'Sampling step {step + 1}/{total_steps}, image {current_task_id + 1}/{total_count} ...', y)])
generated_imgs = {}
for current_task_id, task in enumerate(tasks):
progressbar(async_task, current_progress,
f'Preparing task {current_task_id + 1}/{async_task.image_number} ...')
@ -1222,7 +1233,7 @@ def worker():
preparation_steps, preparation_steps,
async_task.image_number)
generated_imgs[current_task_id] = imgs
images_to_enhance += imgs
except ldm_patched.modules.model_management.InterruptProcessingException:
if async_task.last_stop == 'skip':
@ -1237,7 +1248,7 @@ def worker():
execution_time = time.perf_counter() - execution_start_time
print(f'Generating and saving time: {execution_time:.2f} seconds')
if not async_task.enhance_checkbox or (async_task.enhance_uov_method == flags.disabled and len(async_task.enhance_ctrls) == 0):
if not async_task.enhance_checkbox or (async_task.enhance_uov_method == flags.disabled.lower() and len(async_task.enhance_ctrls) == 0):
print(f'[Enhance] Skipping, preconditions aren\'t met')
stop_processing(async_task, processing_start_time)
return
@ -1245,124 +1256,123 @@ def worker():
progressbar(async_task, current_progress, 'Processing enhance ...')
active_enhance_tabs = len(async_task.enhance_ctrls)
should_process_uov = async_task.enhance_uov_method != flags.disabled
if should_process_uov:
should_process_enhance_uov = async_task.enhance_uov_method != flags.disabled.lower()
if should_process_enhance_uov:
active_enhance_tabs += 1
total_count = sum([len(imgs) for _, imgs in generated_imgs.items()]) * active_enhance_tabs
total_count = len(images_to_enhance) * active_enhance_tabs
base_progress = current_progress
current_task_id = -1
done_steps_upscaling = 0
done_steps_inpainting = 0
enhance_steps, _, _, _ = apply_overrides(async_task, async_task.original_steps, height, width)
for imgs in generated_imgs.values():
for img in imgs:
enhancement_image_start_time = time.perf_counter()
for img in images_to_enhance:
enhancement_image_start_time = time.perf_counter()
# upscale if not disabled or already in goals
if should_process_uov:
current_task_id += 1
current_progress = int(base_progress + (100 - preparation_steps) / float(all_steps) * (done_steps_upscaling + done_steps_inpainting))
goals_enhance = []
img, skip_prompt_processing, steps = prepare_upscale(async_task, goals_enhance, img,
async_task.enhance_uov_method,
async_task.performance_selection,
enhance_steps, current_progress)
# upscale if not disabled or already in goals
if should_process_enhance_uov:
current_task_id += 1
current_progress = int(base_progress + (100 - preparation_steps) / float(all_steps) * (done_steps_upscaling + done_steps_inpainting))
goals_enhance = []
img, skip_prompt_processing, steps = prepare_upscale(async_task, goals_enhance, img,
async_task.enhance_uov_method,
async_task.performance_selection,
enhance_steps, current_progress)
steps, _, _, _ = apply_overrides(async_task, async_task.original_steps, height, width)
if len(goals_enhance) > 0:
try:
current_progress, img = process_enhance(
all_steps, async_task, callback, controlnet_canny_path,
controlnet_cpds_path, current_progress, current_task_id, denoising_strength, False,
'None', 0.0, 0.0, async_task.negative_prompt, async_task.prompt, final_scheduler_name,
goals_enhance, height, img, None, preparation_steps, steps, switch, tiled, total_count,
use_expansion, use_style, use_synthetic_refiner, width)
except ldm_patched.modules.model_management.InterruptProcessingException:
if async_task.last_stop == 'skip':
print('User skipped')
async_task.last_stop = False
# also skip all enhance steps for this image, but add the steps to the progress bar
done_steps_inpainting += len(async_task.enhance_ctrls) * enhance_steps
continue
else:
print('User stopped')
break
finally:
done_steps_upscaling += steps
# inpaint for all other tabs
for enhance_mask_dino_prompt_text, enhance_prompt, enhance_negative_prompt, enhance_mask_model, enhance_mask_sam_model, enhance_mask_text_threshold, enhance_mask_box_threshold, enhance_mask_sam_max_detections, enhance_inpaint_disable_initial_latent, enhance_inpaint_engine, enhance_inpaint_strength, enhance_inpaint_respective_field, enhance_inpaint_erode_or_dilate, enhance_mask_invert in async_task.enhance_ctrls:
current_task_id += 1
current_progress = int(base_progress + (100 - preparation_steps) / float(all_steps) * (done_steps_upscaling + done_steps_inpainting))
progressbar(async_task, current_progress, f'Preparing enhancement {current_task_id + 1}/{total_count} ...')
enhancement_task_start_time = time.perf_counter()
if enhance_mask_model == 'sam':
print(f'[Enhance] Searching for "{enhance_mask_dino_prompt_text}"')
mask, dino_detection_count, sam_detection_count, sam_detection_on_mask_count = generate_mask_from_image(
img, mask_model=enhance_mask_model, sam_options=SAMOptions(
dino_prompt=enhance_mask_dino_prompt_text,
dino_box_threshold=enhance_mask_box_threshold,
dino_text_threshold=enhance_mask_text_threshold,
dino_erode_or_dilate=async_task.dino_erode_or_dilate,
dino_debug=async_task.debugging_dino,
max_detections=enhance_mask_sam_max_detections,
model_type=enhance_mask_sam_model,
))
if len(mask.shape) == 3:
mask = mask[:, :, 0]
if int(enhance_inpaint_erode_or_dilate) != 0:
mask = erode_or_dilate(mask, enhance_inpaint_erode_or_dilate)
if enhance_mask_invert:
mask = 255 - mask
if async_task.debugging_enhance_masks_checkbox:
async_task.yields.append(['preview', (current_progress, 'Loading ...', mask)])
yield_result(async_task, mask, current_progress, async_task.black_out_nsfw, False,
async_task.disable_intermediate_results)
print(f'[Enhance] {dino_detection_count} boxes detected')
print(f'[Enhance] {sam_detection_count} segments detected in boxes')
print(f'[Enhance] {sam_detection_on_mask_count} segments applied to mask')
if enhance_mask_model == 'sam' and (
dino_detection_count == 0 or not async_task.debugging_dino and sam_detection_on_mask_count == 0):
print(f'[Enhance] No "{enhance_mask_dino_prompt_text}" detected, skipping')
continue
goals_enhance = ['inpaint']
steps, _, _, _ = apply_overrides(async_task, steps, height, width)
if len(goals_enhance) > 0:
try:
current_progress, img = process_enhance(
all_steps, async_task, callback, controlnet_canny_path, controlnet_cpds_path,
current_progress, current_task_id, denoising_strength,
enhance_inpaint_disable_initial_latent, enhance_inpaint_engine,
enhance_inpaint_respective_field, enhance_inpaint_strength, enhance_negative_prompt,
enhance_prompt, final_scheduler_name, goals_enhance, height, img, mask, preparation_steps,
enhance_steps, switch, tiled, total_count, use_expansion, use_style, use_synthetic_refiner, width)
all_steps, async_task, callback, controlnet_canny_path,
controlnet_cpds_path, current_progress, current_task_id, denoising_strength, False,
'None', 0.0, 0.0, async_task.negative_prompt, async_task.prompt, final_scheduler_name,
goals_enhance, height, img, None, preparation_steps, steps, switch, tiled, total_count,
use_expansion, use_style, use_synthetic_refiner, width)
except ldm_patched.modules.model_management.InterruptProcessingException:
if async_task.last_stop == 'skip':
print('User skipped')
async_task.last_stop = False
# also skip all enhance steps for this image, but add the steps to the progress bar
done_steps_inpainting += len(async_task.enhance_ctrls) * enhance_steps
continue
else:
print('User stopped')
break
finally:
done_steps_inpainting += enhance_steps
done_steps_upscaling += steps
enhancement_task_time = time.perf_counter() - enhancement_task_start_time
print(f'Enhancement time: {enhancement_task_time:.2f} seconds')
# inpaint for all other tabs
for enhance_mask_dino_prompt_text, enhance_prompt, enhance_negative_prompt, enhance_mask_model, enhance_mask_sam_model, enhance_mask_text_threshold, enhance_mask_box_threshold, enhance_mask_sam_max_detections, enhance_inpaint_disable_initial_latent, enhance_inpaint_engine, enhance_inpaint_strength, enhance_inpaint_respective_field, enhance_inpaint_erode_or_dilate, enhance_mask_invert in async_task.enhance_ctrls:
current_task_id += 1
current_progress = int(base_progress + (100 - preparation_steps) / float(all_steps) * (done_steps_upscaling + done_steps_inpainting))
progressbar(async_task, current_progress, f'Preparing enhancement {current_task_id + 1}/{total_count} ...')
enhancement_task_start_time = time.perf_counter()
enhancement_image_time = time.perf_counter() - enhancement_image_start_time
print(f'Enhancement image time: {enhancement_image_time:.2f} seconds')
if enhance_mask_model == 'sam':
print(f'[Enhance] Searching for "{enhance_mask_dino_prompt_text}"')
mask, dino_detection_count, sam_detection_count, sam_detection_on_mask_count = generate_mask_from_image(
img, mask_model=enhance_mask_model, sam_options=SAMOptions(
dino_prompt=enhance_mask_dino_prompt_text,
dino_box_threshold=enhance_mask_box_threshold,
dino_text_threshold=enhance_mask_text_threshold,
dino_erode_or_dilate=async_task.dino_erode_or_dilate,
dino_debug=async_task.debugging_dino,
max_detections=enhance_mask_sam_max_detections,
model_type=enhance_mask_sam_model,
))
if len(mask.shape) == 3:
mask = mask[:, :, 0]
if int(enhance_inpaint_erode_or_dilate) != 0:
mask = erode_or_dilate(mask, enhance_inpaint_erode_or_dilate)
if enhance_mask_invert:
mask = 255 - mask
if async_task.debugging_enhance_masks_checkbox:
async_task.yields.append(['preview', (current_progress, 'Loading ...', mask)])
yield_result(async_task, mask, current_progress, async_task.black_out_nsfw, False,
async_task.disable_intermediate_results)
print(f'[Enhance] {dino_detection_count} boxes detected')
print(f'[Enhance] {sam_detection_count} segments detected in boxes')
print(f'[Enhance] {sam_detection_on_mask_count} segments applied to mask')
if enhance_mask_model == 'sam' and (
dino_detection_count == 0 or not async_task.debugging_dino and sam_detection_on_mask_count == 0):
print(f'[Enhance] No "{enhance_mask_dino_prompt_text}" detected, skipping')
continue
goals_enhance = ['inpaint']
try:
current_progress, img = process_enhance(
all_steps, async_task, callback, controlnet_canny_path, controlnet_cpds_path,
current_progress, current_task_id, denoising_strength,
enhance_inpaint_disable_initial_latent, enhance_inpaint_engine,
enhance_inpaint_respective_field, enhance_inpaint_strength, enhance_negative_prompt,
enhance_prompt, final_scheduler_name, goals_enhance, height, img, mask, preparation_steps,
enhance_steps, switch, tiled, total_count, use_expansion, use_style, use_synthetic_refiner, width)
except ldm_patched.modules.model_management.InterruptProcessingException:
if async_task.last_stop == 'skip':
print('User skipped')
async_task.last_stop = False
continue
else:
print('User stopped')
break
finally:
done_steps_inpainting += enhance_steps
enhancement_task_time = time.perf_counter() - enhancement_task_start_time
print(f'Enhancement time: {enhancement_task_time:.2f} seconds')
enhancement_image_time = time.perf_counter() - enhancement_image_start_time
print(f'Enhancement image time: {enhancement_image_time:.2f} seconds')
stop_processing(async_task, processing_start_time)
return

View File

@ -325,6 +325,11 @@ with shared.gradio_root:
desc_input_image.upload(trigger_show_image_properties, inputs=desc_input_image,
outputs=desc_image_size, show_progress=False, queue=False)
with gr.TabItem(label='Enhance') as enhance_tab:
with gr.Row():
with gr.Column():
enhance_input_image = grh.Image(label='Image to enhance', source='upload', type='numpy')
with gr.TabItem(label='Metadata') as metadata_tab:
with gr.Column():
metadata_input_image = grh.Image(label='For images created by Fooocus', source='upload', type='filepath')
@ -488,6 +493,7 @@ with shared.gradio_root:
inpaint_tab.select(lambda: 'inpaint', outputs=current_tab, queue=False, _js=down_js, show_progress=False)
ip_tab.select(lambda: 'ip', outputs=current_tab, queue=False, _js=down_js, show_progress=False)
desc_tab.select(lambda: 'desc', outputs=current_tab, queue=False, _js=down_js, show_progress=False)
enhance_tab.select(lambda: 'enhance', outputs=current_tab, queue=False, _js=down_js, show_progress=False)
metadata_tab.select(lambda: 'metadata', outputs=current_tab, queue=False, _js=down_js, show_progress=False)
enhance_checkbox.change(lambda x: gr.update(visible=x), inputs=enhance_checkbox,
@ -930,7 +936,8 @@ with shared.gradio_root:
ctrls += [save_metadata_to_images, metadata_scheme]
ctrls += ip_ctrls
ctrls += [debugging_dino, dino_erode_or_dilate, debugging_enhance_masks_checkbox, enhance_checkbox, enhance_uov_method]
ctrls += [debugging_dino, dino_erode_or_dilate, debugging_enhance_masks_checkbox,
enhance_input_image, enhance_checkbox, enhance_uov_method]
ctrls += enhance_ctrls
def parse_meta(raw_prompt_txt, is_generating):