feat: add metadata handling

This commit is contained in:
Manuel Schmid 2024-05-05 00:33:24 +02:00
parent af33e930d3
commit ab76a26806
No known key found for this signature in database
GPG Key ID: 32C4F7569B40B84B
2 changed files with 26 additions and 8 deletions

View File

@ -870,6 +870,7 @@ def worker():
d.append(('Sampler', 'sampler', sampler_name))
d.append(('Scheduler', 'scheduler', scheduler_name))
d.append(('VAE', 'vae', vae_name))
d.append(('Seed', 'seed', str(task['task_seed'])))
if freeu_enabled:
@ -884,7 +885,7 @@ def worker():
metadata_parser = modules.meta_parser.get_metadata_parser(metadata_scheme)
metadata_parser.set_data(task['log_positive_prompt'], task['positive'],
task['log_negative_prompt'], task['negative'],
steps, base_model_name, refiner_model_name, loras)
steps, base_model_name, refiner_model_name, loras, vae_name)
d.append(('Metadata Scheme', 'metadata_scheme', metadata_scheme.value if save_metadata_to_images else save_metadata_to_images))
d.append(('Version', 'version', 'Fooocus v' + fooocus_version.version))
img_paths.append(log(x, d, metadata_parser, output_format))

View File

@ -46,6 +46,7 @@ def load_parameter_button_click(raw_metadata: dict | str, is_generating: bool):
get_float('refiner_switch', 'Refiner Switch', loaded_parameter_dict, results)
get_str('sampler', 'Sampler', loaded_parameter_dict, results)
get_str('scheduler', 'Scheduler', loaded_parameter_dict, results)
get_str('vae', 'VAE', loaded_parameter_dict, results)
get_seed('seed', 'Seed', loaded_parameter_dict, results)
if is_generating:
@ -253,6 +254,7 @@ class MetadataParser(ABC):
self.refiner_model_name: str = ''
self.refiner_model_hash: str = ''
self.loras: list = []
self.vae_name: str = ''
@abstractmethod
def get_scheme(self) -> MetadataScheme:
@ -267,7 +269,7 @@ class MetadataParser(ABC):
raise NotImplementedError
def set_data(self, raw_prompt, full_prompt, raw_negative_prompt, full_negative_prompt, steps, base_model_name,
refiner_model_name, loras):
refiner_model_name, loras, vae_name):
self.raw_prompt = raw_prompt
self.full_prompt = full_prompt
self.raw_negative_prompt = raw_negative_prompt
@ -289,6 +291,7 @@ class MetadataParser(ABC):
lora_path = get_file_from_folder_list(lora_name, modules.config.paths_loras)
lora_hash = get_sha256(lora_path)
self.loras.append((Path(lora_name).stem, lora_weight, lora_hash))
self.vae_name = Path(vae_name).stem
@staticmethod
def remove_special_loras(lora_filenames):
@ -310,6 +313,7 @@ class A1111MetadataParser(MetadataParser):
'steps': 'Steps',
'sampler': 'Sampler',
'scheduler': 'Scheduler',
'vae': 'VAE',
'guidance_scale': 'CFG scale',
'seed': 'Seed',
'resolution': 'Size',
@ -397,13 +401,12 @@ class A1111MetadataParser(MetadataParser):
data['sampler'] = k
break
for key in ['base_model', 'refiner_model']:
for key in ['base_model', 'refiner_model', 'vae']:
if key in data:
for filename in modules.config.model_filenames:
path = Path(filename)
if data[key] == path.stem:
data[key] = filename
break
if key == 'vae':
self.add_extension_to_filename(data, modules.config.vae_filenames, 'vae')
else:
self.add_extension_to_filename(data, modules.config.model_filenames, key)
lora_data = ''
if 'lora_weights' in data and data['lora_weights'] != '':
@ -433,6 +436,7 @@ class A1111MetadataParser(MetadataParser):
sampler = data['sampler']
scheduler = data['scheduler']
if sampler in SAMPLERS and SAMPLERS[sampler] != '':
sampler = SAMPLERS[sampler]
if sampler not in CIVITAI_NO_KARRAS and scheduler == 'karras':
@ -451,6 +455,7 @@ class A1111MetadataParser(MetadataParser):
self.fooocus_to_a1111['performance']: data['performance'],
self.fooocus_to_a1111['scheduler']: scheduler,
self.fooocus_to_a1111['vae']: Path(data['vae']).stem,
# workaround for multiline prompts
self.fooocus_to_a1111['raw_prompt']: self.raw_prompt,
self.fooocus_to_a1111['raw_negative_prompt']: self.raw_negative_prompt,
@ -491,6 +496,14 @@ class A1111MetadataParser(MetadataParser):
negative_prompt_text = f"\nNegative prompt: {negative_prompt_resolved}" if negative_prompt_resolved else ""
return f"{positive_prompt_resolved}{negative_prompt_text}\n{generation_params_text}".strip()
@staticmethod
def add_extension_to_filename(data, filenames, key):
for filename in filenames:
path = Path(filename)
if data[key] == path.stem:
data[key] = filename
break
class FooocusMetadataParser(MetadataParser):
def get_scheme(self) -> MetadataScheme:
@ -499,6 +512,7 @@ class FooocusMetadataParser(MetadataParser):
def parse_json(self, metadata: dict) -> dict:
model_filenames = modules.config.model_filenames.copy()
lora_filenames = modules.config.lora_filenames.copy()
vae_filenames = modules.config.vae_filenames.copy()
self.remove_special_loras(lora_filenames)
for key, value in metadata.items():
if value in ['', 'None']:
@ -507,6 +521,8 @@ class FooocusMetadataParser(MetadataParser):
metadata[key] = self.replace_value_with_filename(key, value, model_filenames)
elif key.startswith('lora_combined_'):
metadata[key] = self.replace_value_with_filename(key, value, lora_filenames)
elif key == 'vae':
metadata[key] = self.replace_value_with_filename(key, value, vae_filenames)
else:
continue
@ -533,6 +549,7 @@ class FooocusMetadataParser(MetadataParser):
res['refiner_model'] = self.refiner_model_name
res['refiner_model_hash'] = self.refiner_model_hash
res['vae'] = self.vae_name
res['loras'] = self.loras
if modules.config.metadata_created_by != '':