feat: add enums for Performance, Steps and StepsUOV

also move MetadataSchema enum to prevent circular dependency
This commit is contained in:
Manuel Schmid 2024-01-28 20:01:33 +01:00
parent 7ddd4e5209
commit cbc63ebba3
No known key found for this signature in database
GPG Key ID: 32C4F7569B40B84B
6 changed files with 93 additions and 52 deletions

View File

@ -44,7 +44,7 @@ def worker():
from modules.util import remove_empty_str, HWC3, resize_image, \
get_image_shape_ceil, set_image_shape_ceil, get_shape_ceil, resample_image, erode_or_dilate, calculate_sha256
from modules.upscaler import perform_upscale
from modules.metadata import MetadataScheme
from modules.flags import Performance, MetadataScheme
try:
async_gradio_app = shared.gradio_root
@ -125,7 +125,7 @@ def worker():
prompt = args.pop()
negative_prompt = args.pop()
style_selections = args.pop()
performance_selection = args.pop()
performance_selection = Performance(args.pop())
aspect_ratios_selection = args.pop()
image_number = args.pop()
image_seed = args.pop()
@ -144,8 +144,7 @@ def worker():
inpaint_additional_prompt = args.pop()
inpaint_mask_image_upload = args.pop()
save_metadata_to_images = args.pop() if not args_manager.args.disable_metadata else False
metadata_scheme = args.pop() if not args_manager.args.disable_metadata else MetadataScheme.FOOOCUS.value
assert metadata_scheme in [item.value for item in MetadataScheme]
metadata_scheme = MetadataScheme(args.pop()) if not args_manager.args.disable_metadata else MetadataScheme.FOOOCUS
cn_tasks = {x: [] for x in flags.ip_list}
for _ in range(4):
@ -173,17 +172,9 @@ def worker():
print(f'Refiner disabled because base model and refiner are same.')
refiner_model_name = 'None'
assert performance_selection in ['Speed', 'Quality', 'Extreme Speed']
steps = performance_selection.steps()
steps = 30
if performance_selection == 'Speed':
steps = 30
if performance_selection == 'Quality':
steps = 60
if performance_selection == 'Extreme Speed':
if performance_selection == Performance.EXTREME_SPEED:
print('Enter LCM mode.')
progressbar(async_task, 1, 'Downloading LCM components ...')
loras += [(modules.config.downloading_sdxl_lcm_lora(), 1.0)]
@ -201,7 +192,6 @@ def worker():
modules.patch.positive_adm_scale = advanced_parameters.adm_scaler_positive = 1.0
modules.patch.negative_adm_scale = advanced_parameters.adm_scaler_negative = 1.0
modules.patch.adm_scaler_end = advanced_parameters.adm_scaler_end = 0.0
steps = 8
base_model_path = os.path.join(modules.config.path_checkpoints, base_model_name)
base_model_hash = calculate_sha256(base_model_path)[0:10]
@ -274,16 +264,7 @@ def worker():
if 'fast' in uov_method:
skip_prompt_processing = True
else:
steps = 18
if performance_selection == 'Speed':
steps = 18
if performance_selection == 'Quality':
steps = 36
if performance_selection == 'Extreme Speed':
steps = 8
steps = performance_selection.steps_uov()
progressbar(async_task, 1, 'Downloading upscale models ...')
modules.config.downloading_upscale_model()
@ -802,11 +783,12 @@ def worker():
('Full Negative Prompt', 'full_negative_prompt', task['negative'], False, False),
('Fooocus V2 Expansion', 'prompt_expansion', task['expansion'], True, True),
('Styles', 'styles', str(raw_style_selections), True, True),
('Performance', 'performance', performance_selection, True, True),
('Performance', 'performance', performance_selection.value, True, True),
('Steps', 'steps', steps, False, False),
('Resolution', 'resolution', str((width, height)), True, True),
('Sharpness', 'sharpness', sharpness, True, True),
('Guidance Scale', 'guidance_scale', guidance_scale, True, True),
# ('Denoising Strength', 'denoising_strength', denoising_strength, False, False),
('ADM Guidance', 'adm_guidance', str((
modules.patch.positive_adm_scale,
modules.patch.negative_adm_scale,

View File

@ -6,9 +6,9 @@ import args_manager
import modules.flags
import modules.sdxl_styles
from modules.metadata import MetadataScheme
from modules.model_loader import load_file_from_url
from modules.util import get_files_from_folder
from modules.flags import Performance, MetadataScheme
config_path = os.path.abspath("./config.txt")
@ -236,8 +236,8 @@ default_prompt = get_config_item_or_set_default(
)
default_performance = get_config_item_or_set_default(
key='default_performance',
default_value='Speed',
validator=lambda x: x in modules.flags.performance_selections
default_value=Performance.SPEED.value,
validator=lambda x: x in Performance.list()
)
default_advanced_checkbox = get_config_item_or_set_default(
key='default_advanced_checkbox',

View File

@ -1,4 +1,4 @@
from modules.metadata import MetadataScheme
from enum import Enum
disabled = 'Disabled'
enabled = 'Enabled'
@ -34,6 +34,12 @@ default_parameters = {
cn_ip: (0.5, 0.6), cn_ip_face: (0.9, 0.75), cn_canny: (0.5, 1.0), cn_cpds: (0.5, 1.0)
} # stop, weight
class MetadataScheme(Enum):
FOOOCUS = 'fooocus'
A1111 = 'a1111'
# TODO use translation here
metadata_scheme = [
('Fooocus (json)', MetadataScheme.FOOOCUS.value),
@ -41,7 +47,37 @@ metadata_scheme = [
]
inpaint_engine_versions = ['None', 'v1', 'v2.5', 'v2.6']
performance_selections = ['Speed', 'Quality', 'Extreme Speed']
class Steps(Enum):
QUALITY = 60
SPEED = 30
EXTREME_SPEED = 8
class StepsUOV(Enum):
QUALITY = 36
SPEED = 18
EXTREME_SPEED = 8
class Performance(Enum):
QUALITY = 'Quality'
SPEED = 'Speed'
EXTREME_SPEED = 'Extreme Speed'
@classmethod
def list(cls) -> list:
return list(map(lambda c: c.value, cls))
def steps(self) -> int:
return Steps[self.name].value if Steps[self.name] else None
def steps_uov(self) -> int:
return StepsUOV[self.name].value if Steps[self.name] else None
performance_selections = Performance.list()
inpaint_option_default = 'Inpaint or Outpaint (default)'
inpaint_option_detail = 'Improve Detail (face, hand, eyes, etc.)'

View File

@ -8,17 +8,13 @@ import modules.config
import fooocus_version
# import advanced_parameters
from modules.util import quote, unquote, is_json
from modules.flags import MetadataScheme, Performance, Steps
re_param_code = r'\s*(\w[\w \-/]+):\s*("(?:\\.|[^\\"])+"|[^,]*)(?:,|$)'
re_param = re.compile(re_param_code)
re_imagesize = re.compile(r"^(\d+)x(\d+)$")
class MetadataScheme(Enum):
FOOOCUS = 'fooocus'
A1111 = 'a1111'
class MetadataParser(ABC):
@abstractmethod
def parse_json(self, metadata: dict) -> dict:
@ -70,6 +66,14 @@ class A1111MetadataParser(MetadataParser):
else:
prompt += ('' if prompt == '' else "\n") + line
# if shared.opts.infotext_styles != "Ignore":
# found_styles, prompt, negative_prompt = shared.prompt_styles.extract_styles_from_prompt(prompt,
# negative_prompt)
#
# if shared.opts.infotext_styles == "Apply":
# res["Styles array"] = found_styles
# elif shared.opts.infotext_styles == "Apply if any" and found_styles:
# res["Styles array"] = found_styles
data = {
'prompt': prompt,
@ -87,11 +91,17 @@ class A1111MetadataParser(MetadataParser):
data[f"{k}-1"] = m.group(1)
data[f"{k}-2"] = m.group(2)
else:
key = list(self.fooocus_to_a1111.keys())[list(self.fooocus_to_a1111.values()).index(k)]
data[key] = v
data[list(self.fooocus_to_a1111.keys())[list(self.fooocus_to_a1111.values()).index(k)]] = v
except Exception:
print(f"Error parsing \"{k}: {v}\"")
# try to load performance based on steps
if 'steps' in data:
try:
data['performance'] = Performance[Steps(int(data['steps'])).name].value
except Exception:
pass
return data
def parse_string(self, metadata: dict) -> str:
@ -104,9 +114,10 @@ class A1111MetadataParser(MetadataParser):
lora_hashes = []
for index in range(5):
name = f'lora_name_{index + 1}'
if name in data:
# weight = f'lora_weight_{index}'
key = f'lora_name_{index + 1}'
if key in data:
name = data[f'lora_name_{index + 1}']
# weight = data[f'lora_weight_{index + 1}']
hash = data[f'lora_hash_{index + 1}']
lora_hashes.append(f'{name.split(".")[0]}: {hash}')
lora_hashes_string = ", ".join(lora_hashes)
@ -121,6 +132,7 @@ class A1111MetadataParser(MetadataParser):
self.fooocus_to_a1111['sampler']: data['sampler'],
self.fooocus_to_a1111['guidance_scale']: data['guidance_scale'],
self.fooocus_to_a1111['seed']: data['seed'],
# TODO check resolution value, should be string
self.fooocus_to_a1111['resolution']: f'{width}x{heigth}',
self.fooocus_to_a1111['base_model']: data['base_model'].split('.')[0],
self.fooocus_to_a1111['base_model_hash']: data['base_model_hash']
@ -236,11 +248,11 @@ class FooocusMetadataParser(MetadataParser):
# return json.dumps(metadata, ensure_ascii=False)
def get_metadata_parser(metadata_scheme: str) -> MetadataParser:
def get_metadata_parser(metadata_scheme: MetadataScheme) -> MetadataParser:
match metadata_scheme:
case MetadataScheme.FOOOCUS.value:
case MetadataScheme.FOOOCUS:
return FooocusMetadataParser()
case MetadataScheme.A1111.value:
case MetadataScheme.A1111:
return A1111MetadataParser()
case _:
raise NotImplementedError
@ -252,7 +264,7 @@ def get_metadata_parser(metadata_scheme: str) -> MetadataParser:
# }
def read_info_from_image(filepath) -> tuple[str | None, dict, str | None]:
def read_info_from_image(filepath) -> tuple[str | None, dict, MetadataScheme | None]:
with Image.open(filepath) as image:
items = (image.info or {}).copy()
@ -260,8 +272,19 @@ def read_info_from_image(filepath) -> tuple[str | None, dict, str | None]:
if parameters is not None and is_json(parameters):
parameters = json.loads(parameters)
metadata_scheme = items.pop('fooocus_scheme', None)
try:
metadata_scheme = MetadataScheme(items.pop('fooocus_scheme', None))
except Exception:
metadata_scheme = None
# broad fallback
if metadata_scheme is None and isinstance(parameters, dict):
metadata_scheme = modules.metadata.MetadataScheme.FOOOCUS
if metadata_scheme is None and isinstance(parameters, str):
metadata_scheme = modules.metadata.MetadataScheme.A1111
# TODO code cleanup
# if "exif" in items:
# exif_data = items["exif"]
# try:

View File

@ -20,9 +20,7 @@ def get_current_html_path():
return html_name
def log(img, metadata, save_metadata_to_image=False, metadata_scheme: str = MetadataScheme.FOOOCUS.value):
assert metadata_scheme in [item.value for item in MetadataScheme]
def log(img, metadata, save_metadata_to_image=False, metadata_scheme: MetadataScheme = MetadataScheme.FOOOCUS):
if args_manager.args.disable_image_log:
return
@ -35,7 +33,7 @@ def log(img, metadata, save_metadata_to_image=False, metadata_scheme: str = Meta
pnginfo = PngInfo()
pnginfo.add_text('parameters', parsed_parameters)
pnginfo.add_text('fooocus_scheme', metadata_scheme)
pnginfo.add_text('fooocus_scheme', metadata_scheme.value)
else:
pnginfo = None
Image.fromarray(img).save(local_temp_filename, pnginfo=pnginfo)

View File

@ -222,10 +222,12 @@ with shared.gradio_root:
results = {}
if parameters is not None:
results['parameters'] = parameters
if items:
results['items'] = items
if metadata_scheme is not None:
results['metadata_scheme'] = metadata_scheme
if isinstance(metadata_scheme, flags.MetadataScheme):
results['metadata_scheme'] = metadata_scheme.value
return results