feat: add metadata handling for all non-img2img parameters

This commit is contained in:
Manuel Schmid 2024-01-31 01:18:09 +01:00
parent 7772eb7965
commit 9bdb65ec5d
No known key found for this signature in database
GPG Key ID: 32C4F7569B40B84B
4 changed files with 130 additions and 62 deletions

View File

@ -274,7 +274,7 @@ def worker():
and isinstance(inpaint_input_image, dict):
inpaint_image = inpaint_input_image['image']
inpaint_mask = inpaint_input_image['mask'][:, :, 0]
if advanced_parameters.inpaint_mask_upload_checkbox:
if isinstance(inpaint_mask_image_upload, np.ndarray):
if inpaint_mask_image_upload.ndim == 3:
@ -777,38 +777,57 @@ def worker():
imgs = [inpaint_worker.current_task.post_process(x) for x in imgs]
for x in imgs:
d = [
('Prompt', 'prompt', task['log_positive_prompt'], True, True),
('Full Positive Prompt', 'full_prompt', task['positive'], False, False),
('Negative Prompt', 'negative_prompt', task['log_negative_prompt'], True, True),
('Full Negative Prompt', 'full_negative_prompt', task['negative'], False, False),
('Fooocus V2 Expansion', 'prompt_expansion', task['expansion'], True, True),
('Styles', 'styles', str(raw_style_selections), True, True),
('Performance', 'performance', performance_selection.value, True, True),
('Steps', 'steps', steps, False, False),
('Resolution', 'resolution', str((width, height)), True, True),
('Sharpness', 'sharpness', sharpness, True, True),
('Guidance Scale', 'guidance_scale', guidance_scale, True, True),
('ADM Guidance', 'adm_guidance', str((
modules.patch.positive_adm_scale,
modules.patch.negative_adm_scale,
modules.patch.adm_scaler_end)), True, True),
('Base Model', 'base_model', base_model_name, True, True),
('Base Model Hash', 'base_model_hash', base_model_hash, False, False),
('Refiner Model', 'refiner_model', refiner_model_name, True, True),
('Refiner Model Hash', 'refiner_model_hash', refiner_model_hash, False, False),
('Refiner Switch', 'refiner_switch', refiner_switch, True, True),
('Sampler', 'sampler', sampler_name, True, True),
('Scheduler', 'scheduler', scheduler_name, True, True),
('Seed', 'seed', task['task_seed'], True, True)
]
d = [('Prompt', 'prompt', task['log_positive_prompt'], True, True),
('Full Positive Prompt', 'full_prompt', task['positive'], False, False),
('Negative Prompt', 'negative_prompt', task['log_negative_prompt'], True, True),
('Full Negative Prompt', 'full_negative_prompt', task['negative'], False, False),
('Fooocus V2 Expansion', 'prompt_expansion', task['expansion'], True, True),
('Styles', 'styles', str(raw_style_selections), True, True),
('Performance', 'performance', performance_selection.value, True, True),
('Steps', 'steps', steps, False, False),
('Resolution', 'resolution', str((width, height)), True, True),
('Guidance Scale', 'guidance_scale', guidance_scale, True, True),
('Sharpness', 'sharpness', sharpness, True, True),
('ADM Guidance', 'adm_guidance', str((
modules.patch.positive_adm_scale,
modules.patch.negative_adm_scale,
modules.patch.adm_scaler_end)), True, True),
('Base Model', 'base_model', base_model_name, True, True),
('Base Model Hash', 'base_model_hash', base_model_hash, False, False), # TODO move to metadata and use cache
('Refiner Model', 'refiner_model', refiner_model_name, True, True),
('Refiner Model Hash', 'refiner_model_hash', refiner_model_hash, False, False), # TODO move to metadata and use cache
('Refiner Switch', 'refiner_switch', refiner_switch, True, True)]
# TODO evaluate if this should always be added
if refiner_model_name != 'None':
if advanced_parameters.overwrite_switch > 0:
d.append(('Overwrite Switch', 'overwrite_switch', advanced_parameters.overwrite_switch, True, True))
if refiner_swap_method != flags.refiner_swap_method:
d.append(('Refiner Swap Method', 'refiner_swap_method', refiner_swap_method, True, True))
if advanced_parameters.adaptive_cfg != modules.config.default_cfg_tsnr:
d.append(('CFG Mimicking from TSNR', 'adaptive_cfg', advanced_parameters.adaptive_cfg, True, True))
d.append(('Sampler', 'sampler', sampler_name, True, True))
d.append(('Scheduler', 'scheduler', scheduler_name, True, True))
d.append(('Seed', 'seed', task['task_seed'], True, True))
if advanced_parameters.freeu_enabled:
d.append(('FreeU', 'freeu', str((
advanced_parameters.freeu_b1,
advanced_parameters.freeu_b2,
advanced_parameters.freeu_s1,
advanced_parameters.freeu_s2)), True, True))
for li, (n, w) in enumerate(loras):
if n != 'None':
d.append((f'LoRA {li + 1}', f'lora_combined_{li + 1}', f'{n} : {w}', True, True))
d.append((f'LoRA {li + 1} Name', f'lora_name_{li + 1}', n, False, False))
d.append((f'LoRA {li + 1} Weight', f'lora_weight_{li + 1}', w, False, False))
# TODO move hashes to metadata handling
d.append((f'LoRA {li + 1} Hash', f'lora_hash_{li + 1}', lora_hashes[li], False, False))
d.append(('Version', 'version', 'Fooocus v' + fooocus_version.version, True, True))
if modules.config.metadata_created_by != '':
d.append(('Created By', 'created_by', modules.config.metadata_created_by, False, False))

View File

@ -3,7 +3,7 @@ import json
import gradio as gr
import modules.config
from modules.flags import lora_count
from modules.flags import lora_count, Steps
def load_parameter_button_click(raw_metadata: dict | str, is_generating: bool):
@ -18,10 +18,14 @@ def load_parameter_button_click(raw_metadata: dict | str, is_generating: bool):
get_str('negative_prompt', 'Negative Prompt', loaded_parameter_dict, results)
get_list('styles', 'Styles', loaded_parameter_dict, results)
get_str('performance', 'Performance', loaded_parameter_dict, results)
get_steps('steps', 'Steps', loaded_parameter_dict, results)
get_float('overwrite_switch', 'Overwrite Switch', loaded_parameter_dict, results)
get_resolution('resolution', 'Resolution', loaded_parameter_dict, results)
get_float('sharpness', 'Sharpness', loaded_parameter_dict, results)
get_float('guidance_scale', 'Guidance Scale', loaded_parameter_dict, results)
get_float('sharpness', 'Sharpness', loaded_parameter_dict, results)
get_adm_guidance('adm_guidance', 'ADM Guidance', loaded_parameter_dict, results)
get_str('refiner_swap_method', 'Refiner Swap Method', loaded_parameter_dict, results)
get_str('adaptive_cfg', 'CFG Mimicking from TSNR', loaded_parameter_dict, results)
get_str('base_model', 'Base Model', loaded_parameter_dict, results)
get_str('refiner_model', 'Refiner Model', loaded_parameter_dict, results)
get_float('refiner_switch', 'Refiner Switch', loaded_parameter_dict, results)
@ -36,6 +40,8 @@ def load_parameter_button_click(raw_metadata: dict | str, is_generating: bool):
results.append(gr.update(visible=False))
get_freeu('freeu', 'FreeU', loaded_parameter_dict, results)
for i in range(lora_count):
get_lora(f'lora_combined_{i + 1}', f'LoRA {i + 1}', loaded_parameter_dict, results)
@ -44,45 +50,51 @@ def load_parameter_button_click(raw_metadata: dict | str, is_generating: bool):
def get_str(key: str, fallback: str | None, source_dict: dict, results: list, default=None):
try:
h = source_dict.get(key, default)
h = source_dict.get(key, source_dict.get(fallback, default))
assert isinstance(h, str)
results.append(h)
except:
if fallback is not None:
get_str(fallback, None, source_dict, results, default)
return
results.append(gr.update())
def get_list(key: str, fallback: str | None, source_dict: dict, results: list, default=None):
try:
h = source_dict.get(key, default)
h = source_dict.get(key, source_dict.get(fallback, default))
h = eval(h)
assert isinstance(h, list)
results.append(h)
except:
if fallback is not None:
get_list(fallback, None, source_dict, results, default)
return
results.append(gr.update())
# TODO try get generic
def get_float(key: str, fallback: str | None, source_dict: dict, results: list, default=None):
try:
h = source_dict.get(key, default)
h = source_dict.get(key, source_dict.get(fallback, default))
assert h is not None
h = float(h)
results.append(h)
except:
if fallback is not None:
get_float(fallback, None, source_dict, results, default)
return
results.append(gr.update())
def get_steps(key: str, fallback: str | None, source_dict: dict, results: list, default=None):
try:
h = source_dict.get(key, source_dict.get(fallback, default))
assert h is not None
h = int(h)
if h not in set(item.value for item in Steps):
results.append(h)
return
results.append(-1)
except:
results.append(-1)
def get_resolution(key: str, fallback: str | None, source_dict: dict, results: list, default=None):
try:
h = source_dict.get(key, default)
h = source_dict.get(key, source_dict.get(fallback, default))
width, height = eval(h)
formatted = modules.config.add_ratio(f'{width}*{height}')
if formatted in modules.config.available_aspect_ratios:
@ -94,9 +106,6 @@ def get_resolution(key: str, fallback: str | None, source_dict: dict, results: l
results.append(width)
results.append(height)
except:
if fallback is not None:
get_resolution(fallback, None, source_dict, results, default)
return
results.append(gr.update())
results.append(gr.update())
results.append(gr.update())
@ -104,30 +113,41 @@ def get_resolution(key: str, fallback: str | None, source_dict: dict, results: l
def get_seed(key: str, fallback: str | None, source_dict: dict, results: list, default=None):
try:
h = source_dict.get(key, default)
h = source_dict.get(key, source_dict.get(fallback, default))
assert h is not None
h = int(h)
results.append(False)
results.append(h)
except:
if fallback is not None:
get_seed(fallback, None, source_dict, results, default)
return
results.append(gr.update())
results.append(gr.update())
def get_adm_guidance(key: str, fallback: str | None, source_dict: dict, results: list, default=None):
try:
h = source_dict.get(key, default)
h = source_dict.get(key, source_dict.get(fallback, default))
p, n, e = eval(h)
results.append(float(p))
results.append(float(n))
results.append(float(e))
except:
if fallback is not None:
get_adm_guidance(fallback, None, source_dict, results, default)
return
results.append(gr.update())
results.append(gr.update())
results.append(gr.update())
def get_freeu(key: str, fallback: str | None, source_dict: dict, results: list, default=None):
try:
h = source_dict.get(key, source_dict.get(fallback, default))
b1, b2, s1, s2 = eval(h)
results.append(True)
results.append(float(b1))
results.append(float(b2))
results.append(float(s1))
results.append(float(s2))
except:
results.append(False)
results.append(gr.update())
results.append(gr.update())
results.append(gr.update())
results.append(gr.update())
@ -135,13 +155,10 @@ def get_adm_guidance(key: str, fallback: str | None, source_dict: dict, results:
def get_lora(key: str, fallback: str | None, source_dict: dict, results: list, default=None):
try:
n, w = source_dict.get(key).split(' : ')
n, w = source_dict.get(key, source_dict.get(fallback)).split(' : ')
w = float(w)
results.append(n)
results.append(w)
except:
if fallback is not None:
get_lora(fallback, None, source_dict, results, default)
return
results.append('None')
results.append(1)

View File

@ -31,11 +31,18 @@ class A1111MetadataParser(MetadataParser):
fooocus_to_a1111 = {
'negative_prompt': 'Negative prompt',
'styles': 'Styles',
'performance': 'Performance',
'steps': 'Steps',
'sampler': 'Sampler',
'guidance_scale': 'CFG scale',
'seed': 'Seed',
'resolution': 'Size',
'sharpness': 'Sharpness',
'adm_guidance': 'ADM Guidance',
'refiner_swap_method': 'Refiner Swap Method',
'adaptive_cfg': 'Adaptive CFG',
'overwrite_switch': 'Overwrite Switch',
'freeu': 'FreeU',
'base_model': 'Model',
'base_model_hash': 'Model hash',
'refiner_model': 'Refiner',
@ -87,8 +94,8 @@ class A1111MetadataParser(MetadataParser):
except Exception:
print(f"Error parsing \"{k}: {v}\"")
# try to load performance based on steps
if 'steps' in data:
# try to load performance based on steps, fallback for direct A1111 imports
if 'steps' in data and 'performance' not in data:
try:
data['performance'] = Performance[Steps(int(data['steps'])).name].value
except Exception:
@ -132,10 +139,14 @@ class A1111MetadataParser(MetadataParser):
lora_hashes_string = ', '.join(lora_hashes)
generation_params = {
self.fooocus_to_a1111['performance']: data['performance'],
self.fooocus_to_a1111['steps']: data['steps'],
self.fooocus_to_a1111['sampler']: data['sampler'],
self.fooocus_to_a1111['seed']: data['seed'],
self.fooocus_to_a1111['resolution']: f'{width}x{heigth}',
self.fooocus_to_a1111['guidance_scale']: data['guidance_scale'],
self.fooocus_to_a1111['sharpness']: data['sharpness'],
self.fooocus_to_a1111['adm_guidance']: data['adm_guidance'],
# TODO load model by name / hash
self.fooocus_to_a1111['base_model']: Path(data['base_model']).stem,
self.fooocus_to_a1111['base_model_hash']: data['base_model_hash']
@ -147,6 +158,26 @@ class A1111MetadataParser(MetadataParser):
self.fooocus_to_a1111['refiner_model_hash']: data['refiner_model_hash']
}
if 'refiner_swap_method' in data:
generation_params |= {
self.fooocus_to_a1111['refiner_swap_method']: data['refiner_swap_method'],
}
# TODO unify with for and call with key
if 'freeu' in data:
generation_params |= {
self.fooocus_to_a1111['freeu']: data['freeu'],
}
if 'adaptive_cfg' in data:
generation_params |= {
self.fooocus_to_a1111['adaptive_cfg']: data['adaptive_cfg'],
}
if 'overwrite_switch' in data:
generation_params |= {
self.fooocus_to_a1111['overwrite_switch']: data['overwrite_switch'],
}
generation_params |= {
self.fooocus_to_a1111['lora_hashes']: lora_hashes_string,
self.fooocus_to_a1111['version']: data['version']

View File

@ -588,10 +588,11 @@ with shared.gradio_root:
prompt.input(parse_meta, inputs=[prompt, state_is_generating], outputs=[prompt, generate_button, load_parameter_button], queue=False, show_progress=False)
load_data_outputs = [advanced_checkbox, image_number, prompt, negative_prompt, style_selections,
performance_selection, aspect_ratios_selection, overwrite_width, overwrite_height,
sharpness, guidance_scale, adm_scaler_positive, adm_scaler_negative, adm_scaler_end,
base_model, refiner_model, refiner_switch, sampler_name, scheduler_name, seed_random,
image_seed, generate_button, load_parameter_button] + lora_ctrls
performance_selection, overwrite_step, overwrite_switch, aspect_ratios_selection,
overwrite_width, overwrite_height, guidance_scale, sharpness, adm_scaler_positive,
adm_scaler_negative, adm_scaler_end, refiner_swap_method, adaptive_cfg, base_model,
refiner_model, refiner_switch, sampler_name, scheduler_name, seed_random, image_seed,
generate_button, load_parameter_button] + freeu_ctrls + lora_ctrls
load_parameter_button.click(modules.meta_parser.load_parameter_button_click, inputs=[prompt, state_is_generating], outputs=load_data_outputs, queue=False, show_progress=False)