From e19596c2df7201c94cf81039570bb397e56fe59a Mon Sep 17 00:00:00 2001 From: Manuel Schmid Date: Sun, 28 Jan 2024 18:04:40 +0100 Subject: [PATCH] feat: map basic information for scheme A1111 --- modules/async_worker.py | 25 +++++---- modules/metadata.py | 120 ++++++++++++++++++++++++++++++++++------ modules/util.py | 10 ++++ 3 files changed, 128 insertions(+), 27 deletions(-) diff --git a/modules/async_worker.py b/modules/async_worker.py index ddef06c1..ea4018ed 100644 --- a/modules/async_worker.py +++ b/modules/async_worker.py @@ -203,16 +203,16 @@ def worker(): modules.patch.adm_scaler_end = advanced_parameters.adm_scaler_end = 0.0 steps = 8 - if save_metadata_to_images: - base_model_path = os.path.join(modules.config.path_checkpoints, base_model_name) - base_model_hash = calculate_sha256(base_model_path)[0:10] + base_model_path = os.path.join(modules.config.path_checkpoints, base_model_name) + base_model_hash = calculate_sha256(base_model_path)[0:10] - lora_hashes = [] - for (n, w) in loras: - if n != 'None': - lora_path = os.path.join(modules.config.path_loras, n) - lora_hashes.append(f'{n.split(".")[0]}: {calculate_sha256(lora_path)[0:10]}') - lora_hashes_string = ", ".join(lora_hashes) + refiner_model_path = os.path.join(modules.config.path_checkpoints, refiner_model_name) + refiner_model_hash = calculate_sha256(refiner_model_path)[0:10] if refiner_model_name != 'None' else '' + + lora_hashes = [] + for (n, w) in loras: + lora_path = os.path.join(modules.config.path_loras, n) if n != 'None' else '' + lora_hashes.append(calculate_sha256(lora_path)[0:10] if n != 'None' else '') modules.patch.adaptive_cfg = advanced_parameters.adaptive_cfg print(f'[Parameters] Adaptive CFG = {modules.patch.adaptive_cfg}') @@ -812,7 +812,9 @@ def worker(): modules.patch.negative_adm_scale, modules.patch.adm_scaler_end)), True, True), ('Base Model', 'base_model', base_model_name, True, True), + ('Base Model Hash', 'base_model_hash', base_model_hash, False, False), ('Refiner Model', 'refiner_model', refiner_model_name, True, True), + ('Refiner Model Hash', 'refiner_model_hash', refiner_model_hash, False, False), ('Refiner Switch', 'refiner_switch', refiner_switch, True, True), ('Sampler', 'sampler', sampler_name, True, True), ('Scheduler', 'scheduler', scheduler_name, True, True), @@ -821,8 +823,9 @@ def worker(): for li, (n, w) in enumerate(loras): if n != 'None': d.append((f'LoRA {li + 1}', f'lora{li + 1}_combined', f'{n} : {w}', True, True)) - # d.append((f'LoRA {li + 1} Name', f'lora{li + 1}_name', n, False, False)) - # d.append((f'LoRA {li + 1} Weight', f'lora{li + 1}_weight', n, False, False)) + d.append((f'LoRA {li + 1} Name', f'lora_name_{li + 1}', n, False, False)) + d.append((f'LoRA {li + 1} Weight', f'lora_weight_{li + 1}', w, False, False)) + d.append((f'LoRA {li + 1} Hash', f'lora_hash_{li + 1}', lora_hashes[li], False, False)) d.append(('Version', 'version', 'v' + fooocus_version.version, True, True)) log(x, d, save_metadata_to_images, metadata_scheme) diff --git a/modules/metadata.py b/modules/metadata.py index 397f7333..4cdf534a 100644 --- a/modules/metadata.py +++ b/modules/metadata.py @@ -1,4 +1,5 @@ import json +import re from abc import ABC, abstractmethod from enum import Enum from PIL import Image @@ -6,7 +7,11 @@ from PIL import Image import modules.config import fooocus_version # import advanced_parameters -from modules.util import quote, is_json +from modules.util import quote, unquote, is_json + +re_param_code = r'\s*(\w[\w \-/]+):\s*("(?:\\.|[^\\"])+"|[^,]*)(?:,|$)' +re_param = re.compile(re_param_code) +re_imagesize = re.compile(r"^(\d+)x(\d+)$") class MetadataScheme(Enum): @@ -16,7 +21,7 @@ class MetadataScheme(Enum): class MetadataParser(ABC): @abstractmethod - def parse_json(self, metadata: dict): + def parse_json(self, metadata: dict) -> dict: raise NotImplementedError # TODO add data to parse @@ -27,9 +32,67 @@ class MetadataParser(ABC): class A1111MetadataParser(MetadataParser): - def parse_json(self, metadata: dict): + fooocus_to_a1111 = { + 'negative_prompt': 'Negative prompt', + 'steps': 'Steps', + 'sampler': 'Sampler', + 'guidance_scale': 'CFG scale', + 'seed': 'Seed', + 'resolution': 'Size', + 'base_model': 'Model', + 'base_model_hash': 'Model hash', + 'refiner_model': 'Refiner', + 'refiner_model_hash': 'Refiner hash', + 'lora_hashes': 'Lora hashes', + 'version': 'Version' + } + + def parse_json(self, metadata: str) -> dict: # TODO add correct mapping - pass + + prompt = '' + negative_prompt = '' + + done_with_prompt = False + + *lines, lastline = metadata.strip().split("\n") + if len(re_param.findall(lastline)) < 3: + lines.append(lastline) + lastline = '' + + for line in lines: + line = line.strip() + if line.startswith(f"{self.fooocus_to_a1111['negative_prompt']}:"): + done_with_prompt = True + line = line[len(f"{self.fooocus_to_a1111['negative_prompt']}:"):].strip() + if done_with_prompt: + negative_prompt += ('' if negative_prompt == '' else "\n") + line + else: + prompt += ('' if prompt == '' else "\n") + line + + + data = { + 'prompt': prompt, + 'negative_prompt': negative_prompt + } + + for k, v in re_param.findall(lastline): + try: + if v[0] == '"' and v[-1] == '"': + v = unquote(v) + + m = re_imagesize.match(v) + if m is not None: + # TODO check + data[f"{k}-1"] = m.group(1) + data[f"{k}-2"] = m.group(2) + else: + key = list(self.fooocus_to_a1111.keys())[list(self.fooocus_to_a1111.values()).index(k)] + data[key] = v + except Exception: + print(f"Error parsing \"{k}: {v}\"") + + return data def parse_string(self, metadata: dict) -> str: # TODO add correct mapping @@ -39,30 +102,54 @@ class A1111MetadataParser(MetadataParser): # TODO check if correct width, heigth = data['resolution'].split(', ') + lora_hashes = [] + for index in range(5): + name = f'lora_name_{index + 1}' + if name in data: + # weight = f'lora_weight_{index}' + hash = data[f'lora_hash_{index + 1}'] + lora_hashes.append(f'{name.split(".")[0]}: {hash}') + lora_hashes_string = ", ".join(lora_hashes) + + # set static defaults generation_params = { - "Steps": data['steps'], - "Sampler": data['sampler'], - "CFG scale": data['guidance_scale'], - "Seed": data['seed'], - "Size": f"{width}x{heigth}", - # "Model hash": base_model_hash, - "Model": data['base_model'].split('.')[0], - # "Lora hashes": lora_hashes_string, + 'styles': [], + } + + generation_params |= { + self.fooocus_to_a1111['steps']: data['steps'], + self.fooocus_to_a1111['sampler']: data['sampler'], + self.fooocus_to_a1111['guidance_scale']: data['guidance_scale'], + self.fooocus_to_a1111['seed']: data['seed'], + self.fooocus_to_a1111['resolution']: f'{width}x{heigth}', + self.fooocus_to_a1111['base_model']: data['base_model'].split('.')[0], + self.fooocus_to_a1111['base_model_hash']: data['base_model_hash'] + } + + if 'refiner_model' in data and data['refiner_model'] != 'None' and 'refiner_model_hash' in data: + generation_params |= { + self.fooocus_to_a1111['refiner_model']: data['refiner_model'].split('.')[0], + self.fooocus_to_a1111['refiner_model_hash']: data['refiner_model_hash'], + } + + generation_params |= { + self.fooocus_to_a1111['lora_hashes']: lora_hashes_string, # "Denoising strength": data['denoising_strength'], - "Version": f"Fooocus {data['version']}" + self.fooocus_to_a1111['version']: f"Fooocus {data['version']}" } generation_params_text = ", ".join( [k if k == v else f'{k}: {quote(v)}' for k, v in generation_params.items() if v is not None]) - positive_prompt_resolved = ', '.join(data['full_prompt']) - negative_prompt_resolved = ', '.join(data['full_negative_prompt']) + # TODO check if multiline positive prompt is correctly processed + positive_prompt_resolved = ', '.join(data['full_prompt']) #TODO add loras to positive prompt if even possible + negative_prompt_resolved = ', '.join(data['full_negative_prompt']) #TODO add loras to positive prompt if even possible negative_prompt_text = f"\nNegative prompt: {negative_prompt_resolved}" if negative_prompt_resolved else "" return f"{positive_prompt_resolved}{negative_prompt_text}\n{generation_params_text}".strip() class FooocusMetadataParser(MetadataParser): - def parse_json(self, metadata: dict): + def parse_json(self, metadata: dict) -> dict: # TODO add mapping if necessary return metadata @@ -140,6 +227,7 @@ class FooocusMetadataParser(MetadataParser): # metadata |= { # 'software': f'Fooocus v{fooocus_version.version}', # } + # TODO add metadata_created_by # if modules.config.metadata_created_by != '': # metadata |= { # 'created_by': modules.config.metadata_created_by diff --git a/modules/util.py b/modules/util.py index 8e304d6c..f7fcc4e7 100644 --- a/modules/util.py +++ b/modules/util.py @@ -197,6 +197,16 @@ def quote(text): return json.dumps(text, ensure_ascii=False) +def unquote(text): + if len(text) == 0 or text[0] != '"' or text[-1] != '"': + return text + + try: + return json.loads(text) + except Exception: + return text + + def is_json(data: str) -> bool: try: loaded_json = json.loads(data)