wip: add upscale or variation to enhance

This commit is contained in:
Manuel Schmid 2024-06-19 23:53:15 +02:00
parent 87b3cec7d4
commit 4e575b9eb1
No known key found for this signature in database
GPG Key ID: 32C4F7569B40B84B
2 changed files with 165 additions and 98 deletions

View File

@ -113,6 +113,7 @@ class AsyncTask:
self.debugging_enhance_masks_checkbox = args.pop()
self.enhance_checkbox = args.pop()
self.enhance_uov_method = args.pop()
self.enhance_ctrls = []
for _ in range(modules.config.default_enhance_tabs):
enhance_enabled = args.pop()
@ -270,7 +271,7 @@ def worker():
return
def process_task(all_steps, async_task, callback, controlnet_canny_path, controlnet_cpds_path, current_task_id,
denoising_strength, final_scheduler_name, goals, initial_latent, switch, positive_cond,
denoising_strength, final_scheduler_name, goals, initial_latent, steps, switch, positive_cond,
negative_cond, task, tasks, tiled, use_expansion, width, height, base_progress, preparation_steps,
total_count):
if async_task.last_stop is not False:
@ -287,7 +288,7 @@ def worker():
imgs = pipeline.process_diffusion(
positive_cond=positive_cond,
negative_cond=negative_cond,
steps=async_task.steps,
steps=steps,
switch=switch,
width=width,
height=height,
@ -305,7 +306,7 @@ def worker():
del positive_cond, negative_cond # Save memory
if inpaint_worker.current_task is not None:
imgs = [inpaint_worker.current_task.post_process(x) for x in imgs]
current_progress = int(base_progress + (100 - preparation_steps) * float((current_task_id + 1) * async_task.steps) / float(all_steps))
current_progress = int(base_progress + (100 - preparation_steps) * float((current_task_id + 1) * steps) / float(all_steps))
if modules.config.default_black_out_nsfw or async_task.black_out_nsfw:
progressbar(async_task, current_progress, 'Checking for NSFW content ...')
imgs = default_censor(imgs)
@ -440,22 +441,22 @@ def worker():
if len(all_ip_tasks) > 0:
pipeline.final_unet = ip_adapter.patch_model(pipeline.final_unet, all_ip_tasks)
def apply_vary(async_task, denoising_strength, switch, current_progress, advance_progress=False):
if 'subtle' in async_task.uov_method:
def apply_vary(async_task, uov_method, denoising_strength, uov_input_image, switch, current_progress, advance_progress=False):
if 'subtle' in uov_method:
denoising_strength = 0.5
if 'strong' in async_task.uov_method:
if 'strong' in uov_method:
denoising_strength = 0.85
if async_task.overwrite_vary_strength > 0:
denoising_strength = async_task.overwrite_vary_strength
shape_ceil = get_image_shape_ceil(async_task.uov_input_image)
shape_ceil = get_image_shape_ceil(uov_input_image)
if shape_ceil < 1024:
print(f'[Vary] Image is resized because it is too small.')
shape_ceil = 1024
elif shape_ceil > 2048:
print(f'[Vary] Image is resized because it is too big.')
shape_ceil = 2048
async_task.uov_input_image = set_image_shape_ceil(async_task.uov_input_image, shape_ceil)
initial_pixels = core.numpy_to_pytorch(async_task.uov_input_image)
uov_input_image = set_image_shape_ceil(uov_input_image, shape_ceil)
initial_pixels = core.numpy_to_pytorch(uov_input_image)
if advance_progress:
current_progress += 1
progressbar(async_task, current_progress, 'VAE encoding ...')
@ -470,7 +471,7 @@ def worker():
width = W * 8
height = H * 8
print(f'Final resolution is {str((width, height))}.')
return denoising_strength, initial_latent, width, height, current_progress
return uov_input_image, denoising_strength, initial_latent, width, height, current_progress
def apply_inpaint(async_task, initial_latent, inpaint_head_model_path, inpaint_image,
inpaint_mask, inpaint_parameterized, denoising_strength, inpaint_respective_field, switch,
@ -566,28 +567,28 @@ def worker():
async_task.inpaint_respective_field = 1.0
return inpaint_image, inpaint_mask
def apply_upscale(async_task, switch, current_progress, advance_progress=False):
H, W, C = async_task.uov_input_image.shape
def apply_upscale(async_task, uov_input_image, uov_method, switch, current_progress, advance_progress=False):
H, W, C = uov_input_image.shape
if advance_progress:
current_progress += 1
progressbar(async_task, current_progress, f'Upscaling image from {str((H, W))} ...')
async_task.uov_input_image = perform_upscale(async_task.uov_input_image)
uov_input_image = perform_upscale(uov_input_image)
print(f'Image upscaled.')
if '1.5x' in async_task.uov_method:
if '1.5x' in uov_method:
f = 1.5
elif '2x' in async_task.uov_method:
elif '2x' in uov_method:
f = 2.0
else:
f = 1.0
shape_ceil = get_shape_ceil(H * f, W * f)
if shape_ceil < 1024:
print(f'[Upscale] Image is resized because it is too small.')
async_task.uov_input_image = set_image_shape_ceil(async_task.uov_input_image, 1024)
uov_input_image = set_image_shape_ceil(uov_input_image, 1024)
shape_ceil = 1024
else:
async_task.uov_input_image = resample_image(async_task.uov_input_image, width=W * f, height=H * f)
uov_input_image = resample_image(uov_input_image, width=W * f, height=H * f)
image_is_super_large = shape_ceil > 2800
if 'fast' in async_task.uov_method:
if 'fast' in uov_method:
direct_return = True
elif image_is_super_large:
print('Image is too large. Directly returned the SR image. '
@ -597,21 +598,13 @@ def worker():
else:
direct_return = False
if direct_return:
d = [('Upscale (Fast)', 'upscale_fast', '2x')]
if modules.config.default_black_out_nsfw or async_task.black_out_nsfw:
progressbar(async_task, 100, 'Checking for NSFW content ...')
async_task.uov_input_image = default_censor(async_task.uov_input_image)
progressbar(async_task, 100, 'Saving image to system ...')
uov_input_image_path = log(async_task.uov_input_image, d, output_format=async_task.output_format)
yield_result(async_task, uov_input_image_path, 100, async_task.black_out_nsfw, False,
do_not_show_finished_images=True)
raise EarlyReturnException
return direct_return, uov_input_image, None, None, None, None, None, current_progress
tiled = True
denoising_strength = 0.382
if async_task.overwrite_upscale_strength > 0:
denoising_strength = async_task.overwrite_upscale_strength
initial_pixels = core.numpy_to_pytorch(async_task.uov_input_image)
initial_pixels = core.numpy_to_pytorch(uov_input_image)
if advance_progress:
current_progress += 1
progressbar(async_task, current_progress, 'VAE encoding ...')
@ -628,7 +621,7 @@ def worker():
width = W * 8
height = H * 8
print(f'Final resolution is {str((width, height))}.')
return denoising_strength, initial_latent, tiled, width, height, current_progress
return direct_return, uov_input_image, denoising_strength, initial_latent, tiled, width, height, current_progress
def apply_overrides(async_task, height, width):
if async_task.overwrite_step > 0:
@ -853,18 +846,9 @@ def worker():
if (async_task.current_tab == 'uov' or (
async_task.current_tab == 'ip' and async_task.mixing_image_prompt_and_vary_upscale)) \
and async_task.uov_method != flags.disabled and async_task.uov_input_image is not None:
async_task.uov_input_image = HWC3(async_task.uov_input_image)
if 'vary' in async_task.uov_method:
goals.append('vary')
elif 'upscale' in async_task.uov_method:
goals.append('upscale')
if 'fast' in async_task.uov_method:
skip_prompt_processing = True
else:
async_task.steps = async_task.performance_selection.steps_uov()
progressbar(async_task, 1, 'Downloading upscale models ...')
modules.config.downloading_upscale_model()
async_task.uov_input_image, skip_prompt_processing, async_task.steps = prepare_upscale(
async_task, goals, async_task.uov_input_image, async_task.uov_method, async_task.performance_selection,
async_task.steps, 1, skip_prompt_processing=skip_prompt_processing)
if (async_task.current_tab == 'inpaint' or (
async_task.current_tab == 'ip' and async_task.mixing_image_prompt_and_inpaint)) \
and isinstance(async_task.inpaint_input_image, dict):
@ -934,6 +918,24 @@ def worker():
'face')
return base_model_additional_loras, clip_vision_path, controlnet_canny_path, controlnet_cpds_path, inpaint_head_model_path, inpaint_image, inpaint_mask, ip_adapter_face_path, ip_adapter_path, ip_negative_path, skip_prompt_processing, use_synthetic_refiner
def prepare_upscale(async_task, goals, uov_input_image, uov_method, performance, steps, current_progress,
                    advance_progress=False, skip_prompt_processing=False):
    """Register the vary/upscale goal for ``uov_method`` and prepare its inputs.

    The image is passed through ``HWC3``. ``'vary'`` or ``'upscale'`` is
    appended to ``goals`` (in place) depending on which word occurs in
    ``uov_method``; ``'vary'`` takes precedence. For an upscale, the fast
    variant sets ``skip_prompt_processing`` while every other variant replaces
    ``steps`` with ``performance.steps_uov()``, and the upscale model download
    is triggered.

    Returns:
        ``(uov_input_image, skip_prompt_processing, steps)``. The progress
        value is only reported to the progress bar; it is not returned.
    """
    # NOTE(review): block nesting was reconstructed from a view without
    # indentation - confirm that the fast/steps branch and the model download
    # belong to the upscale branch only.
    uov_input_image = HWC3(uov_input_image)

    wants_vary = 'vary' in uov_method
    wants_upscale = not wants_vary and 'upscale' in uov_method

    if wants_vary:
        goals.append('vary')

    if wants_upscale:
        goals.append('upscale')

        if 'fast' not in uov_method:
            steps = performance.steps_uov()
        else:
            skip_prompt_processing = True

        # The advanced value is only used for the progress bar message below.
        shown_progress = current_progress + 1 if advance_progress else current_progress
        progressbar(async_task, shown_progress, 'Downloading upscale models ...')
        modules.config.downloading_upscale_model()

    return uov_input_image, skip_prompt_processing, steps
def prepare_enhance_prompt(prompt: str, fallback_prompt: str, translate: bool, prompt_type: str):
if safe_str(prompt) == '' or len(remove_empty_str([safe_str(p) for p in prompt.splitlines()], default='')) == 0:
prompt = fallback_prompt
@ -948,6 +950,64 @@ def worker():
processing_time = time.perf_counter() - processing_start_time
print(f'Processing time (total): {processing_time:.2f} seconds')
def process_enhance(all_steps, async_task, base_progress, callback, controlnet_canny_path, controlnet_cpds_path,
                    current_progress, current_task_id, denoising_strength, inpaint_disable_initial_latent,
                    inpaint_engine, inpaint_respective_field, inpaint_strength,
                    negative_prompt, prompt, final_scheduler_name, goals, height, img, mask,
                    preparation_steps, steps, switch, tiled, total_count, use_expansion, use_style,
                    use_synthetic_refiner, width):
    """Run one enhancement pass on ``img`` and return ``(current_progress, image)``.

    ``goals`` selects what is applied: ``'vary'`` and ``'upscale'`` use
    ``async_task.enhance_uov_method``; ``'inpaint'`` uses ``mask`` together
    with the ``inpaint_*`` arguments. The prompts are processed again through
    ``process_prompt`` (falling back to the task's own prompts when empty), and
    the image is then generated with ``process_task``.

    When ``apply_upscale`` reports a direct return, the upscaled image is
    returned immediately and no diffusion pass is run.

    Exceptions raised by the called helpers (e.g. on user interruption) are
    not handled here and propagate to the caller.
    """
    # NOTE(review): block nesting was reconstructed from a view without
    # indentation - confirm against the repository.
    base_model_additional_loras = []
    inpaint_head_model_path = None
    inpaint_parameterized = inpaint_engine != 'None'  # inpaint_engine = None, improve detail
    initial_latent = None

    if 'vary' in goals:
        # apply_vary expects (async_task, uov_method, denoising_strength, uov_input_image, ...):
        # the strength comes before the image.
        img, denoising_strength, initial_latent, width, height, current_progress = apply_vary(
            async_task, async_task.enhance_uov_method, denoising_strength, img, switch, current_progress)

    if 'upscale' in goals:
        direct_return, img, denoising_strength, initial_latent, tiled, width, height, current_progress = apply_upscale(
            async_task, img, async_task.enhance_uov_method, switch, current_progress,
            advance_progress=True)
        if direct_return:
            # Nothing left to diffuse; hand back the upscaled image as is.
            return current_progress, img

    if 'inpaint' in goals and inpaint_parameterized:
        progressbar(async_task, current_progress, 'Downloading inpainter ...')
        inpaint_head_model_path, inpaint_patch_model_path = modules.config.downloading_inpaint_models(
            inpaint_engine)
        if inpaint_patch_model_path not in base_model_additional_loras:
            base_model_additional_loras += [(inpaint_patch_model_path, 1.0)]

    progressbar(async_task, current_progress, 'Preparing enhance prompts ...')
    prompt = prepare_enhance_prompt(prompt, async_task.prompt, async_task.translate_prompts, 'prompt')
    negative_prompt = prepare_enhance_prompt(negative_prompt, async_task.negative_prompt,
                                             async_task.translate_prompts, 'negative prompt')

    # positive and negative conditioning aren't available here anymore, process prompt again
    tasks_enhance, use_expansion, _loras, current_progress = process_prompt(
        async_task, prompt, negative_prompt, base_model_additional_loras, 1, True,
        use_expansion, use_style, use_synthetic_refiner, current_progress)
    task_enhance = tasks_enhance[0]

    # TODO could support CN in the future
    # if 'cn' in goals:
    #     apply_control_nets(async_task, height, ip_adapter_face_path, ip_adapter_path, width)
    if async_task.freeu_enabled:
        apply_freeu(async_task)
    patch_samplers(async_task)

    if 'inpaint' in goals:
        denoising_strength, initial_latent, width, height, current_progress = apply_inpaint(
            async_task, None, inpaint_head_model_path, img, mask,
            inpaint_parameterized, inpaint_strength,
            inpaint_respective_field, switch, inpaint_disable_initial_latent,
            current_progress, True)

    imgs, img_paths, current_progress = process_task(all_steps, async_task, callback, controlnet_canny_path,
                                                     controlnet_cpds_path, current_task_id, denoising_strength,
                                                     final_scheduler_name, goals, initial_latent, steps, switch,
                                                     task_enhance['c'], task_enhance['uc'], task_enhance,
                                                     tasks_enhance, tiled, use_expansion, width, height,
                                                     base_progress, preparation_steps, total_count)

    del task_enhance['c'], task_enhance['uc']  # Save memory
    return current_progress, imgs[0]
@torch.no_grad()
@torch.inference_mode()
def handler(async_task: AsyncTask):
@ -957,6 +1017,7 @@ def worker():
async_task.outpaint_selections = [o.lower() for o in async_task.outpaint_selections]
base_model_additional_loras = []
async_task.uov_method = async_task.uov_method.lower()
async_task.enhance_uov_method = async_task.enhance_uov_method.lower()
if fooocus_expansion in async_task.style_selections:
use_expansion = True
@ -1051,12 +1112,23 @@ def worker():
progressbar(async_task, current_progress, 'Image processing ...')
if 'vary' in goals:
denoising_strength, initial_latent, width, height, current_progress = apply_vary(async_task, denoising_strength, switch, current_progress)
async_task.uov_input_image, denoising_strength, initial_latent, width, height, current_progress = apply_vary(
async_task, async_task.uov_method, async_task.uov_input_image, denoising_strength, switch,
current_progress)
if 'upscale' in goals:
try:
denoising_strength, initial_latent, tiled, width, height, current_progress = apply_upscale(async_task, switch, current_progress, advance_progress=True)
except EarlyReturnException:
direct_return, async_task.uov_input_image, denoising_strength, initial_latent, tiled, width, height, current_progress = apply_upscale(
async_task, async_task.uov_input_image, async_task.uov_method, switch, current_progress,
advance_progress=True)
if direct_return:
d = [('Upscale (Fast)', 'upscale_fast', '2x')]
if modules.config.default_black_out_nsfw or async_task.black_out_nsfw:
progressbar(async_task, 100, 'Checking for NSFW content ...')
async_task.uov_input_image = default_censor(async_task.uov_input_image)
progressbar(async_task, 100, 'Saving image to system ...')
uov_input_image_path = log(async_task.uov_input_image, d, output_format=async_task.output_format)
yield_result(async_task, uov_input_image_path, 100, async_task.black_out_nsfw, False,
do_not_show_finished_images=True)
return
if 'inpaint' in goals:
@ -1131,7 +1203,7 @@ def worker():
controlnet_cpds_path,
current_task_id, denoising_strength,
final_scheduler_name, goals, initial_latent,
switch, task['c'], task['uc'], task,
async_task.steps, switch, task['c'], task['uc'], task,
tasks, tiled, use_expansion, width, height,
preparation_steps, preparation_steps,
async_task.image_number)
@ -1151,7 +1223,7 @@ def worker():
execution_time = time.perf_counter() - execution_start_time
print(f'Generating and saving time: {execution_time:.2f} seconds')
if not async_task.enhance_checkbox or len(async_task.enhance_ctrls) == 0:
if not async_task.enhance_checkbox or ('upscale' in goals and async_task.enhance_uov_method != flags.disabled) and len(async_task.enhance_ctrls) == 0:
print(f'[Enhance] Skipping, preconditions aren\'t met')
stop_processing(async_task, processing_start_time)
return
@ -1162,6 +1234,36 @@ def worker():
current_task_id = -1
for imgs in generated_imgs.values():
for img in imgs:
enhancement_image_start_time = time.perf_counter()
# run the enhance upscale/variation unless it is disabled or 'upscale' is already in goals
if 'upscale' not in goals and async_task.enhance_uov_method != flags.disabled:
current_task_id += 1
goals_enhance = []
img, skip_prompt_processing, steps = prepare_upscale(async_task, goals_enhance, img,
async_task.enhance_uov_method,
async_task.performance_selection,
async_task.steps, current_progress)
if len(goals_enhance) > 0:
try:
current_progress, img = process_enhance(
all_steps, async_task, base_progress, callback, controlnet_canny_path,
controlnet_cpds_path, current_progress, current_task_id, denoising_strength, False,
'None', 0.0, 0.0, async_task.negative_prompt, async_task.prompt, final_scheduler_name,
goals_enhance, height, img, None, preparation_steps, steps, switch, tiled, total_count,
use_expansion, use_style, use_synthetic_refiner, width)
# TODO check steps in progress bar, 100% wrong
except ldm_patched.modules.model_management.InterruptProcessingException:
if async_task.last_stop == 'skip':
print('User skipped')
async_task.last_stop = False
continue
else:
print('User stopped')
break
# inpaint for all other tabs
for enhance_mask_dino_prompt_text, enhance_prompt, enhance_negative_prompt, enhance_mask_model, enhance_mask_sam_model, enhance_mask_text_threshold, enhance_mask_box_threshold, enhance_mask_sam_max_detections, enhance_inpaint_disable_initial_latent, enhance_inpaint_engine, enhance_inpaint_strength, enhance_inpaint_respective_field, enhance_inpaint_erode_or_dilate, enhance_mask_invert in async_task.enhance_ctrls:
current_task_id += 1
current_progress = int(base_progress + (100 - preparation_steps) * float(current_task_id * async_task.steps) / float(all_steps))
@ -1204,58 +1306,16 @@ def worker():
print(f'[Enhance] No "{enhance_mask_dino_prompt_text}" detected, skipping')
continue
base_model_additional_loras_enhance = []
inpaint_head_model_path_enhance = None
inpaint_parameterized_enhance = enhance_inpaint_engine != 'None' # inpaint_engine = None, improve detail
if inpaint_parameterized_enhance:
progressbar(async_task, current_progress, 'Downloading inpainter ...')
inpaint_head_model_path_enhance, inpaint_patch_model_path_enhance = modules.config.downloading_inpaint_models(
enhance_inpaint_engine)
if inpaint_patch_model_path_enhance not in base_model_additional_loras_enhance:
base_model_additional_loras_enhance += [(inpaint_patch_model_path_enhance, 1.0)]
progressbar(async_task, current_progress, 'Preparing enhance prompts ...')
enhance_prompt = prepare_enhance_prompt(enhance_prompt, async_task.prompt, async_task.translate_prompts,
'prompt')
enhance_negative_prompt = prepare_enhance_prompt(enhance_negative_prompt, async_task.negative_prompt,
async_task.translate_prompts, 'negative prompt')
# positive and negative conditioning aren't available here anymore, process prompt again
tasks_enhance, use_expansion, loras, current_progress = process_prompt(async_task, enhance_prompt,
enhance_negative_prompt,
base_model_additional_loras_enhance,
1, True,
use_expansion, use_style,
use_synthetic_refiner,
current_progress)
task_enhance = tasks_enhance[0]
# TODO could support vary, upscale and CN in the future
# if 'cn' in goals:
# apply_control_nets(async_task, height, ip_adapter_face_path, ip_adapter_path, width)
if async_task.freeu_enabled:
apply_freeu(async_task)
patch_samplers(async_task)
goals_enhance = ['inpaint']
enhance_inpaint_strength, initial_latent_enhance, width_enhance, height_enhance, current_progress = apply_inpaint(
async_task, None, inpaint_head_model_path_enhance, img, mask,
inpaint_parameterized_enhance, enhance_inpaint_strength,
enhance_inpaint_respective_field, switch, enhance_inpaint_disable_initial_latent,
current_progress, True)
try:
imgs2, img_paths, current_progress = process_task(all_steps, async_task, callback,
controlnet_canny_path, controlnet_cpds_path,
current_task_id, enhance_inpaint_strength,
final_scheduler_name, goals_enhance,
initial_latent_enhance, switch,
task_enhance['c'], task_enhance['uc'],
task_enhance, tasks_enhance, tiled,
use_expansion, width_enhance, height_enhance,
base_progress, preparation_steps, total_count)
img = imgs2[0]
current_progress, img = process_enhance(
all_steps, async_task, base_progress, callback, controlnet_canny_path, controlnet_cpds_path,
current_progress, current_task_id, denoising_strength,
enhance_inpaint_disable_initial_latent, enhance_inpaint_engine,
enhance_inpaint_respective_field, enhance_inpaint_strength, enhance_negative_prompt,
enhance_prompt, final_scheduler_name, goals_enhance, height, img, mask, preparation_steps,
async_task.steps, switch, tiled, total_count, use_expansion, use_style, use_synthetic_refiner, width)
except ldm_patched.modules.model_management.InterruptProcessingException:
if async_task.last_stop == 'skip':
@ -1266,10 +1326,12 @@ def worker():
print('User stopped')
break
del task_enhance['c'], task_enhance['uc'] # Save memory
enhancement_task_time = time.perf_counter() - enhancement_task_start_time
print(f'Enhancement time: {enhancement_task_time:.2f} seconds')
enhancement_image_time = time.perf_counter() - enhancement_image_start_time
print(f'Enhancement image time: {enhancement_image_time:.2f} seconds')
stop_processing(async_task, processing_start_time)
return

View File

@ -348,6 +348,11 @@ with shared.gradio_root:
with gr.Row(visible=False) as enhance_input_panel:
with gr.Tabs():
with gr.TabItem(label='Upscale or Variation'):
with gr.Row():
with gr.Column():
enhance_uov_method = gr.Radio(label='', show_label=False, choices=flags.uov_list, value=flags.disabled)
gr.HTML('<a href="https://github.com/lllyasviel/Fooocus/discussions/390" target="_blank">\U0001F4D4 Document</a>')
enhance_ctrls = []
for index in range(modules.config.default_enhance_tabs):
with gr.TabItem(label=f'#{index + 1}') as enhance_tab_item:
@ -925,7 +930,7 @@ with shared.gradio_root:
ctrls += [save_metadata_to_images, metadata_scheme]
ctrls += ip_ctrls
ctrls += [debugging_dino, dino_erode_or_dilate, debugging_enhance_masks_checkbox, enhance_checkbox]
ctrls += [debugging_dino, dino_erode_or_dilate, debugging_enhance_masks_checkbox, enhance_checkbox, enhance_uov_method]
ctrls += enhance_ctrls
def parse_meta(raw_prompt_txt, is_generating):