feat: map basic information for scheme A1111
parent ee21c2b6bc
commit e19596c2df
@@ -203,16 +203,16 @@ def worker():
         modules.patch.adm_scaler_end = advanced_parameters.adm_scaler_end = 0.0
         steps = 8

     if save_metadata_to_images:
         base_model_path = os.path.join(modules.config.path_checkpoints, base_model_name)
         base_model_hash = calculate_sha256(base_model_path)[0:10]
-
-        lora_hashes = []
-        for (n, w) in loras:
-            if n != 'None':
-                lora_path = os.path.join(modules.config.path_loras, n)
-                lora_hashes.append(f'{n.split(".")[0]}: {calculate_sha256(lora_path)[0:10]}')
-        lora_hashes_string = ", ".join(lora_hashes)
+        refiner_model_path = os.path.join(modules.config.path_checkpoints, refiner_model_name)
+        refiner_model_hash = calculate_sha256(refiner_model_path)[0:10] if refiner_model_name != 'None' else ''
+
+        lora_hashes = []
+        for (n, w) in loras:
+            lora_path = os.path.join(modules.config.path_loras, n) if n != 'None' else ''
+            lora_hashes.append(calculate_sha256(lora_path)[0:10] if n != 'None' else '')

     modules.patch.adaptive_cfg = advanced_parameters.adaptive_cfg
     print(f'[Parameters] Adaptive CFG = {modules.patch.adaptive_cfg}')
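For context, a minimal sketch of the hashing convention used above. This is not part of the commit: it assumes calculate_sha256 streams the file and returns the full hex digest, which callers then truncate with [0:10] to match A1111's short model-hash format.

    import hashlib

    def calculate_sha256(filename):
        # Assumed behaviour: hash the checkpoint/LoRA file in chunks and
        # return the full hex digest; callers keep only the first 10 chars.
        sha256 = hashlib.sha256()
        with open(filename, 'rb') as f:
            for chunk in iter(lambda: f.read(1 << 20), b''):  # 1 MiB chunks
                sha256.update(chunk)
        return sha256.hexdigest()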
@@ -812,7 +812,9 @@ def worker():
                     modules.patch.negative_adm_scale,
                     modules.patch.adm_scaler_end)), True, True),
                 ('Base Model', 'base_model', base_model_name, True, True),
+                ('Base Model Hash', 'base_model_hash', base_model_hash, False, False),
                 ('Refiner Model', 'refiner_model', refiner_model_name, True, True),
+                ('Refiner Model Hash', 'refiner_model_hash', refiner_model_hash, False, False),
                 ('Refiner Switch', 'refiner_switch', refiner_switch, True, True),
                 ('Sampler', 'sampler', sampler_name, True, True),
                 ('Scheduler', 'scheduler', scheduler_name, True, True),
@@ -821,8 +823,9 @@ def worker():
             for li, (n, w) in enumerate(loras):
                 if n != 'None':
                     d.append((f'LoRA {li + 1}', f'lora{li + 1}_combined', f'{n} : {w}', True, True))
-                    # d.append((f'LoRA {li + 1} Name', f'lora{li + 1}_name', n, False, False))
-                    # d.append((f'LoRA {li + 1} Weight', f'lora{li + 1}_weight', n, False, False))
+                    d.append((f'LoRA {li + 1} Name', f'lora_name_{li + 1}', n, False, False))
+                    d.append((f'LoRA {li + 1} Weight', f'lora_weight_{li + 1}', w, False, False))
+                    d.append((f'LoRA {li + 1} Hash', f'lora_hash_{li + 1}', lora_hashes[li], False, False))
             d.append(('Version', 'version', 'v' + fooocus_version.version, True, True))
             log(x, d, save_metadata_to_images, metadata_scheme)
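The added lines index lora_hashes by slot position, which relies on the worker (first hunk) appending an empty-string placeholder for every 'None' slot so the list stays index-aligned with loras. A small sketch with invented values:

    loras = [('None', 1.0), ('example-lora.safetensors', 0.5)]  # hypothetical slots
    lora_hashes = ['', '0ba33ac3bf']  # '' placeholder keeps indices aligned
    for li, (n, w) in enumerate(loras):
        if n != 'None':
            print(f'lora_hash_{li + 1} = {lora_hashes[li]}')  # -> lora_hash_2 = 0ba33ac3bf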
@@ -1,4 +1,5 @@
 import json
+import re
 from abc import ABC, abstractmethod
 from enum import Enum
 from PIL import Image
@@ -6,7 +7,11 @@ from PIL import Image
 import modules.config
 import fooocus_version
 # import advanced_parameters
-from modules.util import quote, is_json
+from modules.util import quote, unquote, is_json

+re_param_code = r'\s*(\w[\w \-/]+):\s*("(?:\\.|[^\\"])+"|[^,]*)(?:,|$)'
+re_param = re.compile(re_param_code)
+re_imagesize = re.compile(r"^(\d+)x(\d+)$")
+

 class MetadataScheme(Enum):
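To illustrate what the new regexes capture, here is a standalone demo on a hypothetical A1111-style parameter line (the sample values are invented, not from the commit):

    import re

    re_param_code = r'\s*(\w[\w \-/]+):\s*("(?:\\.|[^\\"])+"|[^,]*)(?:,|$)'
    re_param = re.compile(re_param_code)
    re_imagesize = re.compile(r"^(\d+)x(\d+)$")

    lastline = 'Steps: 30, Sampler: dpmpp_2m_sde_gpu, CFG scale: 4.0, Size: 1152x896'
    print(re_param.findall(lastline))
    # [('Steps', '30'), ('Sampler', 'dpmpp_2m_sde_gpu'), ('CFG scale', '4.0'), ('Size', '1152x896')]
    print(re_imagesize.match('1152x896').groups())  # ('1152', '896')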
@@ -16,7 +21,7 @@ class MetadataScheme(Enum):

 class MetadataParser(ABC):
     @abstractmethod
-    def parse_json(self, metadata: dict):
+    def parse_json(self, metadata: dict) -> dict:
         raise NotImplementedError

     # TODO add data to parse
@@ -27,9 +32,67 @@ class MetadataParser(ABC):

 class A1111MetadataParser(MetadataParser):

-    def parse_json(self, metadata: dict):
-        # TODO add correct mapping
-        pass
+    fooocus_to_a1111 = {
+        'negative_prompt': 'Negative prompt',
+        'steps': 'Steps',
+        'sampler': 'Sampler',
+        'guidance_scale': 'CFG scale',
+        'seed': 'Seed',
+        'resolution': 'Size',
+        'base_model': 'Model',
+        'base_model_hash': 'Model hash',
+        'refiner_model': 'Refiner',
+        'refiner_model_hash': 'Refiner hash',
+        'lora_hashes': 'Lora hashes',
+        'version': 'Version'
+    }
+
+    def parse_json(self, metadata: str) -> dict:
+        prompt = ''
+        negative_prompt = ''
+
+        done_with_prompt = False
+
+        *lines, lastline = metadata.strip().split("\n")
+        if len(re_param.findall(lastline)) < 3:
+            lines.append(lastline)
+            lastline = ''
+
+        for line in lines:
+            line = line.strip()
+            if line.startswith(f"{self.fooocus_to_a1111['negative_prompt']}:"):
+                done_with_prompt = True
+                line = line[len(f"{self.fooocus_to_a1111['negative_prompt']}:"):].strip()
+            if done_with_prompt:
+                negative_prompt += ('' if negative_prompt == '' else "\n") + line
+            else:
+                prompt += ('' if prompt == '' else "\n") + line
+
+        data = {
+            'prompt': prompt,
+            'negative_prompt': negative_prompt
+        }
+
+        for k, v in re_param.findall(lastline):
+            try:
+                if v[0] == '"' and v[-1] == '"':
+                    v = unquote(v)
+
+                m = re_imagesize.match(v)
+                if m is not None:
+                    # TODO check
+                    data[f"{k}-1"] = m.group(1)
+                    data[f"{k}-2"] = m.group(2)
+                else:
+                    key = list(self.fooocus_to_a1111.keys())[list(self.fooocus_to_a1111.values()).index(k)]
+                    data[key] = v
+            except Exception:
+                print(f"Error parsing \"{k}: {v}\"")
+
+        return data

     def parse_string(self, metadata: dict) -> str:
         # TODO add correct mapping
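The new parse_json accumulates prompt lines until it sees the "Negative prompt:" marker, then accumulates negative-prompt lines, and finally maps the comma-separated tail line back to Fooocus keys through the inverted fooocus_to_a1111 dict. A hedged end-to-end sketch with invented metadata:

    metadata = '\n'.join([
        'cinematic photo of a lighthouse',       # positive prompt (may span lines)
        'Negative prompt: blurry, low quality',  # switches the accumulator
        'Steps: 30, Sampler: dpmpp_2m_sde_gpu, CFG scale: 4.0, Seed: 1234, Size: 1152x896, Model: some_model, Version: Fooocus v2.1.0',
    ])

    data = A1111MetadataParser().parse_json(metadata)
    # data['prompt']          -> 'cinematic photo of a lighthouse'
    # data['negative_prompt'] -> 'blurry, low quality'
    # data['steps'] -> '30', data['base_model'] -> 'some_model', ...
    # 'Size' matches re_imagesize, so it is stored as data['Size-1'], data['Size-2']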
@@ -39,30 +102,54 @@ class A1111MetadataParser(MetadataParser):
         # TODO check if correct
         width, heigth = data['resolution'].split(', ')

+        lora_hashes = []
+        for index in range(5):
+            name = f'lora_name_{index + 1}'
+            if name in data:
+                # weight = f'lora_weight_{index}'
+                hash = data[f'lora_hash_{index + 1}']
+                lora_hashes.append(f'{name.split(".")[0]}: {hash}')
+        lora_hashes_string = ", ".join(lora_hashes)
+
         # set static defaults
-        generation_params = {
-            "Steps": data['steps'],
-            "Sampler": data['sampler'],
-            "CFG scale": data['guidance_scale'],
-            "Seed": data['seed'],
-            "Size": f"{width}x{heigth}",
-            # "Model hash": base_model_hash,
-            "Model": data['base_model'].split('.')[0],
-            # "Lora hashes": lora_hashes_string,
-            'styles': [],
-        }
+        generation_params = {
+            self.fooocus_to_a1111['steps']: data['steps'],
+            self.fooocus_to_a1111['sampler']: data['sampler'],
+            self.fooocus_to_a1111['guidance_scale']: data['guidance_scale'],
+            self.fooocus_to_a1111['seed']: data['seed'],
+            self.fooocus_to_a1111['resolution']: f'{width}x{heigth}',
+            self.fooocus_to_a1111['base_model']: data['base_model'].split('.')[0],
+            self.fooocus_to_a1111['base_model_hash']: data['base_model_hash']
+        }
+
+        if 'refiner_model' in data and data['refiner_model'] != 'None' and 'refiner_model_hash' in data:
+            generation_params |= {
+                self.fooocus_to_a1111['refiner_model']: data['refiner_model'].split('.')[0],
+                self.fooocus_to_a1111['refiner_model_hash']: data['refiner_model_hash'],
+            }

         generation_params |= {
+            self.fooocus_to_a1111['lora_hashes']: lora_hashes_string,
             # "Denoising strength": data['denoising_strength'],
-            "Version": f"Fooocus {data['version']}"
+            self.fooocus_to_a1111['version']: f"Fooocus {data['version']}"
         }

         generation_params_text = ", ".join(
             [k if k == v else f'{k}: {quote(v)}' for k, v in generation_params.items() if v is not None])
-        positive_prompt_resolved = ', '.join(data['full_prompt'])
-        negative_prompt_resolved = ', '.join(data['full_negative_prompt'])
+        # TODO check if multiline positive prompt is correctly processed
+        positive_prompt_resolved = ', '.join(data['full_prompt']) #TODO add loras to positive prompt if even possible
+        negative_prompt_resolved = ', '.join(data['full_negative_prompt']) #TODO add loras to positive prompt if even possible
         negative_prompt_text = f"\nNegative prompt: {negative_prompt_resolved}" if negative_prompt_resolved else ""
         return f"{positive_prompt_resolved}{negative_prompt_text}\n{generation_params_text}".strip()


 class FooocusMetadataParser(MetadataParser):

-    def parse_json(self, metadata: dict):
+    def parse_json(self, metadata: dict) -> dict:
         # TODO add mapping if necessary
         return metadata
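Conversely, parse_string is meant to emit the familiar A1111 layout: the positive prompt, an optional "Negative prompt:" line, then one comma-separated parameter line that parse_json can read back. With the same hypothetical values as above (hashes invented, and assuming quote() JSON-quotes values that contain separators), the output would look roughly like:

    cinematic photo of a lighthouse
    Negative prompt: blurry, low quality
    Steps: 30, Sampler: dpmpp_2m_sde_gpu, CFG scale: 4.0, Seed: 1234, Size: 1152x896, Model: some_model, Model hash: 91e7d7b78e, Lora hashes: "lora_name_1: 0ba33ac3bf", Version: Fooocus v2.1.0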
@@ -140,6 +227,7 @@ class FooocusMetadataParser(MetadataParser):
         # metadata |= {
         #     'software': f'Fooocus v{fooocus_version.version}',
         # }
+        # TODO add metadata_created_by
         # if modules.config.metadata_created_by != '':
         #     metadata |= {
         #         'created_by': modules.config.metadata_created_by
@@ -197,6 +197,16 @@ def quote(text):
     return json.dumps(text, ensure_ascii=False)


+def unquote(text):
+    if len(text) == 0 or text[0] != '"' or text[-1] != '"':
+        return text
+
+    try:
+        return json.loads(text)
+    except Exception:
+        return text
+
+
 def is_json(data: str) -> bool:
     try:
         loaded_json = json.loads(data)
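A short round-trip note on the new helper: unquote is the inverse of quote for values that were JSON-encoded because they contain separators (an assumption based on quote's json.dumps fallback shown above). Anything that is not a well-formed quoted string passes through unchanged:

    assert unquote('"a value, with commas"') == 'a value, with commas'
    assert unquote('plain') == 'plain'      # unquoted values pass through
    assert unquote('"broken') == '"broken'  # malformed input is returned as-is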