feat: map basic information for scheme A1111

This commit is contained in:
Manuel Schmid 2024-01-28 18:04:40 +01:00
parent ee21c2b6bc
commit e19596c2df
No known key found for this signature in database
GPG Key ID: 32C4F7569B40B84B
3 changed files with 128 additions and 27 deletions

View File

@ -203,16 +203,16 @@ def worker():
modules.patch.adm_scaler_end = advanced_parameters.adm_scaler_end = 0.0
steps = 8
if save_metadata_to_images:
base_model_path = os.path.join(modules.config.path_checkpoints, base_model_name)
base_model_hash = calculate_sha256(base_model_path)[0:10]
base_model_path = os.path.join(modules.config.path_checkpoints, base_model_name)
base_model_hash = calculate_sha256(base_model_path)[0:10]
lora_hashes = []
for (n, w) in loras:
if n != 'None':
lora_path = os.path.join(modules.config.path_loras, n)
lora_hashes.append(f'{n.split(".")[0]}: {calculate_sha256(lora_path)[0:10]}')
lora_hashes_string = ", ".join(lora_hashes)
refiner_model_path = os.path.join(modules.config.path_checkpoints, refiner_model_name)
refiner_model_hash = calculate_sha256(refiner_model_path)[0:10] if refiner_model_name != 'None' else ''
lora_hashes = []
for (n, w) in loras:
lora_path = os.path.join(modules.config.path_loras, n) if n != 'None' else ''
lora_hashes.append(calculate_sha256(lora_path)[0:10] if n != 'None' else '')
modules.patch.adaptive_cfg = advanced_parameters.adaptive_cfg
print(f'[Parameters] Adaptive CFG = {modules.patch.adaptive_cfg}')
@ -812,7 +812,9 @@ def worker():
modules.patch.negative_adm_scale,
modules.patch.adm_scaler_end)), True, True),
('Base Model', 'base_model', base_model_name, True, True),
('Base Model Hash', 'base_model_hash', base_model_hash, False, False),
('Refiner Model', 'refiner_model', refiner_model_name, True, True),
('Refiner Model Hash', 'refiner_model_hash', refiner_model_hash, False, False),
('Refiner Switch', 'refiner_switch', refiner_switch, True, True),
('Sampler', 'sampler', sampler_name, True, True),
('Scheduler', 'scheduler', scheduler_name, True, True),
@ -821,8 +823,9 @@ def worker():
for li, (n, w) in enumerate(loras):
if n != 'None':
d.append((f'LoRA {li + 1}', f'lora{li + 1}_combined', f'{n} : {w}', True, True))
# d.append((f'LoRA {li + 1} Name', f'lora{li + 1}_name', n, False, False))
# d.append((f'LoRA {li + 1} Weight', f'lora{li + 1}_weight', n, False, False))
d.append((f'LoRA {li + 1} Name', f'lora_name_{li + 1}', n, False, False))
d.append((f'LoRA {li + 1} Weight', f'lora_weight_{li + 1}', w, False, False))
d.append((f'LoRA {li + 1} Hash', f'lora_hash_{li + 1}', lora_hashes[li], False, False))
d.append(('Version', 'version', 'v' + fooocus_version.version, True, True))
log(x, d, save_metadata_to_images, metadata_scheme)

View File

@ -1,4 +1,5 @@
import json
import re
from abc import ABC, abstractmethod
from enum import Enum
from PIL import Image
@ -6,7 +7,11 @@ from PIL import Image
import modules.config
import fooocus_version
# import advanced_parameters
from modules.util import quote, is_json
from modules.util import quote, unquote, is_json
re_param_code = r'\s*(\w[\w \-/]+):\s*("(?:\\.|[^\\"])+"|[^,]*)(?:,|$)'
re_param = re.compile(re_param_code)
re_imagesize = re.compile(r"^(\d+)x(\d+)$")
class MetadataScheme(Enum):
@ -16,7 +21,7 @@ class MetadataScheme(Enum):
class MetadataParser(ABC):
@abstractmethod
def parse_json(self, metadata: dict):
def parse_json(self, metadata: dict) -> dict:
raise NotImplementedError
# TODO add data to parse
@ -27,9 +32,67 @@ class MetadataParser(ABC):
class A1111MetadataParser(MetadataParser):
def parse_json(self, metadata: dict):
fooocus_to_a1111 = {
'negative_prompt': 'Negative prompt',
'steps': 'Steps',
'sampler': 'Sampler',
'guidance_scale': 'CFG scale',
'seed': 'Seed',
'resolution': 'Size',
'base_model': 'Model',
'base_model_hash': 'Model hash',
'refiner_model': 'Refiner',
'refiner_model_hash': 'Refiner hash',
'lora_hashes': 'Lora hashes',
'version': 'Version'
}
def parse_json(self, metadata: str) -> dict:
# TODO add correct mapping
pass
prompt = ''
negative_prompt = ''
done_with_prompt = False
*lines, lastline = metadata.strip().split("\n")
if len(re_param.findall(lastline)) < 3:
lines.append(lastline)
lastline = ''
for line in lines:
line = line.strip()
if line.startswith(f"{self.fooocus_to_a1111['negative_prompt']}:"):
done_with_prompt = True
line = line[len(f"{self.fooocus_to_a1111['negative_prompt']}:"):].strip()
if done_with_prompt:
negative_prompt += ('' if negative_prompt == '' else "\n") + line
else:
prompt += ('' if prompt == '' else "\n") + line
data = {
'prompt': prompt,
'negative_prompt': negative_prompt
}
for k, v in re_param.findall(lastline):
try:
if v[0] == '"' and v[-1] == '"':
v = unquote(v)
m = re_imagesize.match(v)
if m is not None:
# TODO check
data[f"{k}-1"] = m.group(1)
data[f"{k}-2"] = m.group(2)
else:
key = list(self.fooocus_to_a1111.keys())[list(self.fooocus_to_a1111.values()).index(k)]
data[key] = v
except Exception:
print(f"Error parsing \"{k}: {v}\"")
return data
def parse_string(self, metadata: dict) -> str:
# TODO add correct mapping
@ -39,30 +102,54 @@ class A1111MetadataParser(MetadataParser):
# TODO check if correct
width, heigth = data['resolution'].split(', ')
lora_hashes = []
for index in range(5):
name = f'lora_name_{index + 1}'
if name in data:
# weight = f'lora_weight_{index}'
hash = data[f'lora_hash_{index + 1}']
lora_hashes.append(f'{name.split(".")[0]}: {hash}')
lora_hashes_string = ", ".join(lora_hashes)
# set static defaults
generation_params = {
"Steps": data['steps'],
"Sampler": data['sampler'],
"CFG scale": data['guidance_scale'],
"Seed": data['seed'],
"Size": f"{width}x{heigth}",
# "Model hash": base_model_hash,
"Model": data['base_model'].split('.')[0],
# "Lora hashes": lora_hashes_string,
'styles': [],
}
generation_params |= {
self.fooocus_to_a1111['steps']: data['steps'],
self.fooocus_to_a1111['sampler']: data['sampler'],
self.fooocus_to_a1111['guidance_scale']: data['guidance_scale'],
self.fooocus_to_a1111['seed']: data['seed'],
self.fooocus_to_a1111['resolution']: f'{width}x{heigth}',
self.fooocus_to_a1111['base_model']: data['base_model'].split('.')[0],
self.fooocus_to_a1111['base_model_hash']: data['base_model_hash']
}
if 'refiner_model' in data and data['refiner_model'] != 'None' and 'refiner_model_hash' in data:
generation_params |= {
self.fooocus_to_a1111['refiner_model']: data['refiner_model'].split('.')[0],
self.fooocus_to_a1111['refiner_model_hash']: data['refiner_model_hash'],
}
generation_params |= {
self.fooocus_to_a1111['lora_hashes']: lora_hashes_string,
# "Denoising strength": data['denoising_strength'],
"Version": f"Fooocus {data['version']}"
self.fooocus_to_a1111['version']: f"Fooocus {data['version']}"
}
generation_params_text = ", ".join(
[k if k == v else f'{k}: {quote(v)}' for k, v in generation_params.items() if v is not None])
positive_prompt_resolved = ', '.join(data['full_prompt'])
negative_prompt_resolved = ', '.join(data['full_negative_prompt'])
# TODO check if multiline positive prompt is correctly processed
positive_prompt_resolved = ', '.join(data['full_prompt']) #TODO add loras to positive prompt if even possible
negative_prompt_resolved = ', '.join(data['full_negative_prompt']) #TODO add loras to positive prompt if even possible
negative_prompt_text = f"\nNegative prompt: {negative_prompt_resolved}" if negative_prompt_resolved else ""
return f"{positive_prompt_resolved}{negative_prompt_text}\n{generation_params_text}".strip()
class FooocusMetadataParser(MetadataParser):
def parse_json(self, metadata: dict):
def parse_json(self, metadata: dict) -> dict:
# TODO add mapping if necessary
return metadata
@ -140,6 +227,7 @@ class FooocusMetadataParser(MetadataParser):
# metadata |= {
# 'software': f'Fooocus v{fooocus_version.version}',
# }
# TODO add metadata_created_by
# if modules.config.metadata_created_by != '':
# metadata |= {
# 'created_by': modules.config.metadata_created_by

View File

@ -197,6 +197,16 @@ def quote(text):
return json.dumps(text, ensure_ascii=False)
def unquote(text):
if len(text) == 0 or text[0] != '"' or text[-1] != '"':
return text
try:
return json.loads(text)
except Exception:
return text
def is_json(data: str) -> bool:
try:
loaded_json = json.loads(data)