From 80cad7787f28a2b6057ac132bf6eede96fb9518f Mon Sep 17 00:00:00 2001 From: Manuel Schmid Date: Sat, 18 May 2024 17:18:23 +0200 Subject: [PATCH] refactor: code cleanup, move makedirs_with_log back to util --- modules/async_worker.py | 42 ++++++++++++++++++++++++++--------------- modules/config.py | 5 +++-- modules/extra_utils.py | 9 +-------- modules/util.py | 8 +++++++- 4 files changed, 38 insertions(+), 26 deletions(-) diff --git a/modules/async_worker.py b/modules/async_worker.py index e8a068db..7f0a46e3 100644 --- a/modules/async_worker.py +++ b/modules/async_worker.py @@ -4,6 +4,7 @@ from modules.patch import PatchSettings, patch_settings, patch_all patch_all() + class AsyncTask: def __init__(self, args): self.args = args @@ -44,12 +45,12 @@ def worker(): import args_manager from extras.censor import censor_batch, censor_single - from modules.sdxl_styles import apply_style, get_random_style, apply_wildcards, fooocus_expansion, apply_arrays, random_style_name + from modules.sdxl_styles import apply_style, get_random_style, fooocus_expansion, apply_arrays, random_style_name from modules.private_logger import log from extras.expansion import safe_str from modules.util import (remove_empty_str, HWC3, resize_image, get_image_shape_ceil, set_image_shape_ceil, - get_shape_ceil, resample_image, erode_or_dilate, ordinal_suffix, get_enabled_loras, - parse_lora_references_from_prompt, apply_wildcards) + get_shape_ceil, resample_image, erode_or_dilate, ordinal_suffix, get_enabled_loras, + parse_lora_references_from_prompt, apply_wildcards) from modules.upscaler import perform_upscale from modules.flags import Performance from modules.meta_parser import get_metadata_parser, MetadataScheme @@ -70,7 +71,8 @@ def worker(): print(f'[Fooocus] {text}') async_task.yields.append(['preview', (number, text, None)]) - def yield_result(async_task, imgs, black_out_nsfw, censor=True, do_not_show_finished_images=False, progressbar_index=13): + def yield_result(async_task, imgs, black_out_nsfw, censor=True, do_not_show_finished_images=False, + progressbar_index=13): if not isinstance(imgs, list): imgs = [imgs] @@ -153,7 +155,8 @@ def worker(): base_model_name = args.pop() refiner_model_name = args.pop() refiner_switch = args.pop() - loras = get_enabled_loras([(bool(args.pop()), str(args.pop()), float(args.pop())) for _ in range(modules.config.default_max_lora_number)]) + loras = get_enabled_loras([(bool(args.pop()), str(args.pop()), float(args.pop())) for _ in + range(modules.config.default_max_lora_number)]) input_image_checkbox = args.pop() current_tab = args.pop() uov_method = args.pop() @@ -203,7 +206,8 @@ def worker(): inpaint_erode_or_dilate = args.pop() save_metadata_to_images = args.pop() if not args_manager.args.disable_metadata else False - metadata_scheme = MetadataScheme(args.pop()) if not args_manager.args.disable_metadata else MetadataScheme.FOOOCUS + metadata_scheme = MetadataScheme( + args.pop()) if not args_manager.args.disable_metadata else MetadataScheme.FOOOCUS cn_tasks = {x: [] for x in flags.ip_list} for _ in range(flags.controlnet_image_count): @@ -443,7 +447,7 @@ def worker(): progressbar(async_task, 3, 'Processing prompts ...') tasks = [] - + for i in range(image_number): if disable_seed_increment: task_seed = seed % (constants.MAX_SEED + 1) @@ -454,8 +458,10 @@ def worker(): task_prompt = apply_wildcards(prompt, task_rng, i, read_wildcards_in_order) task_prompt = apply_arrays(task_prompt, i) task_negative_prompt = apply_wildcards(negative_prompt, task_rng, i, read_wildcards_in_order) - task_extra_positive_prompts = [apply_wildcards(pmt, task_rng, i, read_wildcards_in_order) for pmt in extra_positive_prompts] - task_extra_negative_prompts = [apply_wildcards(pmt, task_rng, i, read_wildcards_in_order) for pmt in extra_negative_prompts] + task_extra_positive_prompts = [apply_wildcards(pmt, task_rng, i, read_wildcards_in_order) for pmt in + extra_positive_prompts] + task_extra_negative_prompts = [apply_wildcards(pmt, task_rng, i, read_wildcards_in_order) for pmt in + extra_negative_prompts] positive_basic_workloads = [] negative_basic_workloads = [] @@ -656,7 +662,8 @@ def worker(): ) if debugging_inpaint_preprocessor: - yield_result(async_task, inpaint_worker.current_task.visualize_mask_processing(), black_out_nsfw, do_not_show_finished_images=True) + yield_result(async_task, inpaint_worker.current_task.visualize_mask_processing(), black_out_nsfw, + do_not_show_finished_images=True) return progressbar(async_task, 13, 'VAE Inpaint encoding ...') @@ -811,7 +818,8 @@ def worker(): done_steps = current_task_id * steps + step async_task.yields.append(['preview', ( int(15.0 + 85.0 * float(done_steps) / float(all_steps)), - f'Step {step}/{total_steps} in the {current_task_id + 1}{ordinal_suffix(current_task_id + 1)} Sampling', y)]) + f'Step {step}/{total_steps} in the {current_task_id + 1}{ordinal_suffix(current_task_id + 1)} Sampling', + y)]) for current_task_id, task in enumerate(tasks): execution_start_time = time.perf_counter() @@ -866,7 +874,8 @@ def worker(): d = [('Prompt', 'prompt', task['log_positive_prompt']), ('Negative Prompt', 'negative_prompt', task['log_negative_prompt']), ('Fooocus V2 Expansion', 'prompt_expansion', task['expansion']), - ('Styles', 'styles', str(task['styles'] if not use_expansion else [fooocus_expansion] + task['styles'])), + ('Styles', 'styles', + str(task['styles'] if not use_expansion else [fooocus_expansion] + task['styles'])), ('Performance', 'performance', performance_selection.value)] if performance_selection.steps() != steps: @@ -889,7 +898,8 @@ def worker(): if refiner_swap_method != flags.refiner_swap_method: d.append(('Refiner Swap Method', 'refiner_swap_method', refiner_swap_method)) if modules.patch.patch_settings[pid].adaptive_cfg != modules.config.default_cfg_tsnr: - d.append(('CFG Mimicking from TSNR', 'adaptive_cfg', modules.patch.patch_settings[pid].adaptive_cfg)) + d.append( + ('CFG Mimicking from TSNR', 'adaptive_cfg', modules.patch.patch_settings[pid].adaptive_cfg)) d.append(('Sampler', 'sampler', sampler_name)) d.append(('Scheduler', 'scheduler', scheduler_name)) @@ -909,11 +919,13 @@ def worker(): metadata_parser.set_data(task['log_positive_prompt'], task['positive'], task['log_negative_prompt'], task['negative'], steps, base_model_name, refiner_model_name, loras, vae_name) - d.append(('Metadata Scheme', 'metadata_scheme', metadata_scheme.value if save_metadata_to_images else save_metadata_to_images)) + d.append(('Metadata Scheme', 'metadata_scheme', + metadata_scheme.value if save_metadata_to_images else save_metadata_to_images)) d.append(('Version', 'version', 'Fooocus v' + fooocus_version.version)) img_paths.append(log(x, d, metadata_parser, output_format, task)) - yield_result(async_task, img_paths, black_out_nsfw, False, do_not_show_finished_images=len(tasks) == 1 or disable_intermediate_results) + yield_result(async_task, img_paths, black_out_nsfw, False, + do_not_show_finished_images=len(tasks) == 1 or disable_intermediate_results) except ldm_patched.modules.model_management.InterruptProcessingException as e: if async_task.last_stop == 'skip': print('User skipped') diff --git a/modules/config.py b/modules/config.py index 701222c8..11fe3181 100644 --- a/modules/config.py +++ b/modules/config.py @@ -1,7 +1,6 @@ import os import json import math -import re import numbers import args_manager import tempfile @@ -9,9 +8,11 @@ import modules.flags import modules.sdxl_styles from modules.model_loader import load_file_from_url -from modules.extra_utils import get_files_from_folder, makedirs_with_log +from modules.util import makedirs_with_log +from modules.extra_utils import get_files_from_folder from modules.flags import OutputFormat, Performance, MetadataScheme + def get_config_path(key, default_value): env = os.getenv(key) if env is not None and isinstance(env, str): diff --git a/modules/extra_utils.py b/modules/extra_utils.py index 6bbf7557..3e95e8b5 100644 --- a/modules/extra_utils.py +++ b/modules/extra_utils.py @@ -1,7 +1,6 @@ - import os -#TODO: Use refactor to use glob instead + def get_files_from_folder(folder_path, extensions=None, name_filter=None): if not os.path.isdir(folder_path): raise ValueError("Folder path is not a valid directory.") @@ -19,9 +18,3 @@ def get_files_from_folder(folder_path, extensions=None, name_filter=None): filenames.append(path) return filenames - -def makedirs_with_log(path): - try: - os.makedirs(path, exist_ok=True) - except OSError as error: - print(f'Directory {path} could not be created, reason: {error}') diff --git a/modules/util.py b/modules/util.py index 7d695df7..73430230 100644 --- a/modules/util.py +++ b/modules/util.py @@ -172,7 +172,6 @@ def generate_temp_filename(folder='./outputs/', extension='png'): return date_string, os.path.abspath(result), filename - def sha256(filename, use_addnet_hash=False, length=HASH_SHA256_LENGTH): print(f"Calculating sha256 for {filename}: ", end='') if use_addnet_hash: @@ -377,6 +376,13 @@ def ordinal_suffix(number: int) -> str: return 'th' if 10 <= number % 100 <= 20 else {1: 'st', 2: 'nd', 3: 'rd'}.get(number % 10, 'th') +def makedirs_with_log(path): + try: + os.makedirs(path, exist_ok=True) + except OSError as error: + print(f'Directory {path} could not be created, reason: {error}') + + def get_enabled_loras(loras: list) -> list: return [(lora[1], lora[2]) for lora in loras if lora[0]]