diff --git a/modules/async_worker.py b/modules/async_worker.py index ea4018ed..e6ec8c39 100644 --- a/modules/async_worker.py +++ b/modules/async_worker.py @@ -44,7 +44,7 @@ def worker(): from modules.util import remove_empty_str, HWC3, resize_image, \ get_image_shape_ceil, set_image_shape_ceil, get_shape_ceil, resample_image, erode_or_dilate, calculate_sha256 from modules.upscaler import perform_upscale - from modules.metadata import MetadataScheme + from modules.flags import Performance, MetadataScheme try: async_gradio_app = shared.gradio_root @@ -125,7 +125,7 @@ def worker(): prompt = args.pop() negative_prompt = args.pop() style_selections = args.pop() - performance_selection = args.pop() + performance_selection = Performance(args.pop()) aspect_ratios_selection = args.pop() image_number = args.pop() image_seed = args.pop() @@ -144,8 +144,7 @@ def worker(): inpaint_additional_prompt = args.pop() inpaint_mask_image_upload = args.pop() save_metadata_to_images = args.pop() if not args_manager.args.disable_metadata else False - metadata_scheme = args.pop() if not args_manager.args.disable_metadata else MetadataScheme.FOOOCUS.value - assert metadata_scheme in [item.value for item in MetadataScheme] + metadata_scheme = MetadataScheme(args.pop()) if not args_manager.args.disable_metadata else MetadataScheme.FOOOCUS cn_tasks = {x: [] for x in flags.ip_list} for _ in range(4): @@ -173,17 +172,9 @@ def worker(): print(f'Refiner disabled because base model and refiner are same.') refiner_model_name = 'None' - assert performance_selection in ['Speed', 'Quality', 'Extreme Speed'] + steps = performance_selection.steps() - steps = 30 - - if performance_selection == 'Speed': - steps = 30 - - if performance_selection == 'Quality': - steps = 60 - - if performance_selection == 'Extreme Speed': + if performance_selection == Performance.EXTREME_SPEED: print('Enter LCM mode.') progressbar(async_task, 1, 'Downloading LCM components ...') loras += [(modules.config.downloading_sdxl_lcm_lora(), 1.0)] @@ -201,7 +192,6 @@ def worker(): modules.patch.positive_adm_scale = advanced_parameters.adm_scaler_positive = 1.0 modules.patch.negative_adm_scale = advanced_parameters.adm_scaler_negative = 1.0 modules.patch.adm_scaler_end = advanced_parameters.adm_scaler_end = 0.0 - steps = 8 base_model_path = os.path.join(modules.config.path_checkpoints, base_model_name) base_model_hash = calculate_sha256(base_model_path)[0:10] @@ -274,16 +264,7 @@ def worker(): if 'fast' in uov_method: skip_prompt_processing = True else: - steps = 18 - - if performance_selection == 'Speed': - steps = 18 - - if performance_selection == 'Quality': - steps = 36 - - if performance_selection == 'Extreme Speed': - steps = 8 + steps = performance_selection.steps_uov() progressbar(async_task, 1, 'Downloading upscale models ...') modules.config.downloading_upscale_model() @@ -802,11 +783,12 @@ def worker(): ('Full Negative Prompt', 'full_negative_prompt', task['negative'], False, False), ('Fooocus V2 Expansion', 'prompt_expansion', task['expansion'], True, True), ('Styles', 'styles', str(raw_style_selections), True, True), - ('Performance', 'performance', performance_selection, True, True), + ('Performance', 'performance', performance_selection.value, True, True), ('Steps', 'steps', steps, False, False), ('Resolution', 'resolution', str((width, height)), True, True), ('Sharpness', 'sharpness', sharpness, True, True), ('Guidance Scale', 'guidance_scale', guidance_scale, True, True), + # ('Denoising Strength', 'denoising_strength', denoising_strength, False, False), ('ADM Guidance', 'adm_guidance', str(( modules.patch.positive_adm_scale, modules.patch.negative_adm_scale, diff --git a/modules/config.py b/modules/config.py index 4a9b6837..7b42ed62 100644 --- a/modules/config.py +++ b/modules/config.py @@ -6,9 +6,9 @@ import args_manager import modules.flags import modules.sdxl_styles -from modules.metadata import MetadataScheme from modules.model_loader import load_file_from_url from modules.util import get_files_from_folder +from modules.flags import Performance, MetadataScheme config_path = os.path.abspath("./config.txt") @@ -236,8 +236,8 @@ default_prompt = get_config_item_or_set_default( ) default_performance = get_config_item_or_set_default( key='default_performance', - default_value='Speed', - validator=lambda x: x in modules.flags.performance_selections + default_value=Performance.SPEED.value, + validator=lambda x: x in Performance.list() ) default_advanced_checkbox = get_config_item_or_set_default( key='default_advanced_checkbox', diff --git a/modules/flags.py b/modules/flags.py index abcd3f60..5b22c5ec 100644 --- a/modules/flags.py +++ b/modules/flags.py @@ -1,4 +1,4 @@ -from modules.metadata import MetadataScheme +from enum import Enum disabled = 'Disabled' enabled = 'Enabled' @@ -34,6 +34,12 @@ default_parameters = { cn_ip: (0.5, 0.6), cn_ip_face: (0.9, 0.75), cn_canny: (0.5, 1.0), cn_cpds: (0.5, 1.0) } # stop, weight + +class MetadataScheme(Enum): + FOOOCUS = 'fooocus' + A1111 = 'a1111' + + # TODO use translation here metadata_scheme = [ ('Fooocus (json)', MetadataScheme.FOOOCUS.value), @@ -41,7 +47,37 @@ metadata_scheme = [ ] inpaint_engine_versions = ['None', 'v1', 'v2.5', 'v2.6'] -performance_selections = ['Speed', 'Quality', 'Extreme Speed'] + + +class Steps(Enum): + QUALITY = 60 + SPEED = 30 + EXTREME_SPEED = 8 + + +class StepsUOV(Enum): + QUALITY = 36 + SPEED = 18 + EXTREME_SPEED = 8 + + +class Performance(Enum): + QUALITY = 'Quality' + SPEED = 'Speed' + EXTREME_SPEED = 'Extreme Speed' + + @classmethod + def list(cls) -> list: + return list(map(lambda c: c.value, cls)) + + def steps(self) -> int: + return Steps[self.name].value if Steps[self.name] else None + + def steps_uov(self) -> int: + return StepsUOV[self.name].value if Steps[self.name] else None + + +performance_selections = Performance.list() inpaint_option_default = 'Inpaint or Outpaint (default)' inpaint_option_detail = 'Improve Detail (face, hand, eyes, etc.)' diff --git a/modules/metadata.py b/modules/metadata.py index 4cdf534a..79a37719 100644 --- a/modules/metadata.py +++ b/modules/metadata.py @@ -8,17 +8,13 @@ import modules.config import fooocus_version # import advanced_parameters from modules.util import quote, unquote, is_json +from modules.flags import MetadataScheme, Performance, Steps re_param_code = r'\s*(\w[\w \-/]+):\s*("(?:\\.|[^\\"])+"|[^,]*)(?:,|$)' re_param = re.compile(re_param_code) re_imagesize = re.compile(r"^(\d+)x(\d+)$") -class MetadataScheme(Enum): - FOOOCUS = 'fooocus' - A1111 = 'a1111' - - class MetadataParser(ABC): @abstractmethod def parse_json(self, metadata: dict) -> dict: @@ -70,6 +66,14 @@ class A1111MetadataParser(MetadataParser): else: prompt += ('' if prompt == '' else "\n") + line + # if shared.opts.infotext_styles != "Ignore": + # found_styles, prompt, negative_prompt = shared.prompt_styles.extract_styles_from_prompt(prompt, + # negative_prompt) + # + # if shared.opts.infotext_styles == "Apply": + # res["Styles array"] = found_styles + # elif shared.opts.infotext_styles == "Apply if any" and found_styles: + # res["Styles array"] = found_styles data = { 'prompt': prompt, @@ -87,11 +91,17 @@ class A1111MetadataParser(MetadataParser): data[f"{k}-1"] = m.group(1) data[f"{k}-2"] = m.group(2) else: - key = list(self.fooocus_to_a1111.keys())[list(self.fooocus_to_a1111.values()).index(k)] - data[key] = v + data[list(self.fooocus_to_a1111.keys())[list(self.fooocus_to_a1111.values()).index(k)]] = v except Exception: print(f"Error parsing \"{k}: {v}\"") + # try to load performance based on steps + if 'steps' in data: + try: + data['performance'] = Performance[Steps(int(data['steps'])).name].value + except Exception: + pass + return data def parse_string(self, metadata: dict) -> str: @@ -104,9 +114,10 @@ class A1111MetadataParser(MetadataParser): lora_hashes = [] for index in range(5): - name = f'lora_name_{index + 1}' - if name in data: - # weight = f'lora_weight_{index}' + key = f'lora_name_{index + 1}' + if key in data: + name = data[f'lora_name_{index + 1}'] + # weight = data[f'lora_weight_{index + 1}'] hash = data[f'lora_hash_{index + 1}'] lora_hashes.append(f'{name.split(".")[0]}: {hash}') lora_hashes_string = ", ".join(lora_hashes) @@ -121,6 +132,7 @@ class A1111MetadataParser(MetadataParser): self.fooocus_to_a1111['sampler']: data['sampler'], self.fooocus_to_a1111['guidance_scale']: data['guidance_scale'], self.fooocus_to_a1111['seed']: data['seed'], + # TODO check resolution value, should be string self.fooocus_to_a1111['resolution']: f'{width}x{heigth}', self.fooocus_to_a1111['base_model']: data['base_model'].split('.')[0], self.fooocus_to_a1111['base_model_hash']: data['base_model_hash'] @@ -236,11 +248,11 @@ class FooocusMetadataParser(MetadataParser): # return json.dumps(metadata, ensure_ascii=False) -def get_metadata_parser(metadata_scheme: str) -> MetadataParser: +def get_metadata_parser(metadata_scheme: MetadataScheme) -> MetadataParser: match metadata_scheme: - case MetadataScheme.FOOOCUS.value: + case MetadataScheme.FOOOCUS: return FooocusMetadataParser() - case MetadataScheme.A1111.value: + case MetadataScheme.A1111: return A1111MetadataParser() case _: raise NotImplementedError @@ -252,7 +264,7 @@ def get_metadata_parser(metadata_scheme: str) -> MetadataParser: # } -def read_info_from_image(filepath) -> tuple[str | None, dict, str | None]: +def read_info_from_image(filepath) -> tuple[str | None, dict, MetadataScheme | None]: with Image.open(filepath) as image: items = (image.info or {}).copy() @@ -260,8 +272,19 @@ def read_info_from_image(filepath) -> tuple[str | None, dict, str | None]: if parameters is not None and is_json(parameters): parameters = json.loads(parameters) - metadata_scheme = items.pop('fooocus_scheme', None) + try: + metadata_scheme = MetadataScheme(items.pop('fooocus_scheme', None)) + except Exception: + metadata_scheme = None + # broad fallback + if metadata_scheme is None and isinstance(parameters, dict): + metadata_scheme = modules.metadata.MetadataScheme.FOOOCUS + + if metadata_scheme is None and isinstance(parameters, str): + metadata_scheme = modules.metadata.MetadataScheme.A1111 + + # TODO code cleanup # if "exif" in items: # exif_data = items["exif"] # try: diff --git a/modules/private_logger.py b/modules/private_logger.py index 3e186e1a..1afcaa55 100644 --- a/modules/private_logger.py +++ b/modules/private_logger.py @@ -20,9 +20,7 @@ def get_current_html_path(): return html_name -def log(img, metadata, save_metadata_to_image=False, metadata_scheme: str = MetadataScheme.FOOOCUS.value): - assert metadata_scheme in [item.value for item in MetadataScheme] - +def log(img, metadata, save_metadata_to_image=False, metadata_scheme: MetadataScheme = MetadataScheme.FOOOCUS): if args_manager.args.disable_image_log: return @@ -35,7 +33,7 @@ def log(img, metadata, save_metadata_to_image=False, metadata_scheme: str = Meta pnginfo = PngInfo() pnginfo.add_text('parameters', parsed_parameters) - pnginfo.add_text('fooocus_scheme', metadata_scheme) + pnginfo.add_text('fooocus_scheme', metadata_scheme.value) else: pnginfo = None Image.fromarray(img).save(local_temp_filename, pnginfo=pnginfo) diff --git a/webui.py b/webui.py index 32fddb98..5c2d4a8c 100644 --- a/webui.py +++ b/webui.py @@ -222,10 +222,12 @@ with shared.gradio_root: results = {} if parameters is not None: results['parameters'] = parameters + if items: results['items'] = items - if metadata_scheme is not None: - results['metadata_scheme'] = metadata_scheme + + if isinstance(metadata_scheme, flags.MetadataScheme): + results['metadata_scheme'] = metadata_scheme.value return results