Merge branch 'feature/add-metadata-to-files'

# Conflicts:
#	language/en.json
#	modules/async_worker.py
#	modules/config.py
#	modules/flags.py
#	modules/meta_parser.py
#	modules/private_logger.py
#	modules/util.py
#	webui.py
This commit is contained in:
Manuel Schmid 2024-02-04 21:09:24 +01:00
commit ceefba9b69
No known key found for this signature in database
GPG Key ID: 32C4F7569B40B84B
8 changed files with 832 additions and 340 deletions

View File

@ -379,4 +379,11 @@
"Box Threshold": "Box Threshold",
"Text Threshold": "Text Threshold",
"Generate mask from image": "Generate mask from image"
"Drag any image generated by Fooocus here": "Drag any image generated by Fooocus here",
"Metadata": "Metadata",
"Apply Metadata": "Apply Metadata",
"Metadata Scheme": "Metadata Scheme",
"Image Prompt parameters are not included. Use a1111 for compatibility with Civitai.": "Image Prompt parameters are not included. Use a1111 for compatibility with Civitai.",
"fooocus (json)": "fooocus (json)",
"a1111 (plain text)": "a1111 (plain text)"
}

View File

@ -23,7 +23,6 @@ def worker():
import os
import traceback
import math
import json
import numpy as np
import torch
import time
@ -50,8 +49,10 @@ def worker():
from modules.private_logger import log
from extras.expansion import safe_str
from modules.util import remove_empty_str, HWC3, resize_image, \
get_image_shape_ceil, set_image_shape_ceil, get_shape_ceil, resample_image, erode_or_dilate, calculate_sha256, quote
get_image_shape_ceil, set_image_shape_ceil, get_shape_ceil, resample_image, erode_or_dilate
from modules.upscaler import perform_upscale
from modules.flags import Performance, lora_count
from modules.meta_parser import get_metadata_parser, MetadataScheme
pid = os.getpid()
print(f'Started worker with PID {pid}')
@ -134,7 +135,7 @@ def worker():
negative_prompt = args.pop()
translate_prompts = args.pop()
style_selections = args.pop()
performance_selection = args.pop()
performance_selection = Performance(args.pop())
aspect_ratios_selection = args.pop()
image_number = args.pop()
output_format = args.pop()
@ -144,7 +145,7 @@ def worker():
base_model_name = args.pop()
refiner_model_name = args.pop()
refiner_switch = args.pop()
loras = [[str(args.pop()), float(args.pop())] for _ in range(5)]
loras = [[str(args.pop()), float(args.pop())] for _ in range(lora_count)]
input_image_checkbox = args.pop()
current_tab = args.pop()
uov_method = args.pop()
@ -192,10 +193,10 @@ def worker():
inpaint_erode_or_dilate = args.pop()
save_metadata_to_images = args.pop() if not args_manager.args.disable_metadata else False
metadata_scheme = args.pop() if not args_manager.args.disable_metadata else 'fooocus'
metadata_scheme = MetadataScheme(args.pop()) if not args_manager.args.disable_metadata else MetadataScheme.FOOOCUS
cn_tasks = {x: [] for x in flags.ip_list}
for _ in range(4):
for _ in range(flags.controlnet_image_count):
cn_img = args.pop()
cn_stop = args.pop()
cn_weight = args.pop()
@ -220,17 +221,9 @@ def worker():
print(f'Refiner disabled because base model and refiner are same.')
refiner_model_name = 'None'
assert performance_selection in ['Speed', 'Quality', 'Extreme Speed']
steps = performance_selection.steps()
steps = 30
if performance_selection == 'Speed':
steps = 30
if performance_selection == 'Quality':
steps = 60
if performance_selection == 'Extreme Speed':
if performance_selection == Performance.EXTREME_SPEED:
print('Enter LCM mode.')
progressbar(async_task, 1, 'Downloading LCM components ...')
loras += [(modules.config.downloading_sdxl_lcm_lora(), 1.0)]
@ -248,24 +241,12 @@ def worker():
adm_scaler_positive = 1.0
adm_scaler_negative = 1.0
adm_scaler_end = 0.0
steps = 8
if translate_prompts:
from modules.translator import translate2en
prompt = translate2en(prompt, 'prompt')
negative_prompt = translate2en(negative_prompt, 'negative prompt')
if not args_manager.args.disable_metadata:
base_model_path = os.path.join(modules.config.path_checkpoints, base_model_name)
base_model_hash = calculate_sha256(base_model_path)[0:10]
lora_hashes = []
for (n, w) in loras:
if n != 'None':
lora_path = os.path.join(modules.config.path_loras, n)
lora_hashes.append(f'{n.split(".")[0]}: {calculate_sha256(lora_path)[0:10]}')
lora_hashes_string = ", ".join(lora_hashes)
print(f'[Parameters] Adaptive CFG = {adaptive_cfg}')
print(f'[Parameters] Sharpness = {sharpness}')
print(f'[Parameters] ControlNet Softness = {controlnet_softness}')
@ -325,16 +306,7 @@ def worker():
if 'fast' in uov_method:
skip_prompt_processing = True
else:
steps = 18
if performance_selection == 'Speed':
steps = 18
if performance_selection == 'Quality':
steps = 36
if performance_selection == 'Extreme Speed':
steps = 8
steps = performance_selection.steps_uov()
progressbar(async_task, 1, 'Downloading upscale models ...')
modules.config.downloading_upscale_model()
@ -422,9 +394,6 @@ def worker():
progressbar(async_task, 1, 'Initializing ...')
raw_prompt = prompt
raw_negative_prompt = negative_prompt
if not skip_prompt_processing:
prompts = remove_empty_str([safe_str(p) for p in prompt.splitlines()], default='')
@ -850,130 +819,52 @@ def worker():
imgs = [inpaint_worker.current_task.post_process(x) for x in imgs]
img_paths = []
metadata_string = ''
if save_metadata_to_images and metadata_scheme == 'fooocus':
metadata = {
'prompt': raw_prompt, 'negative_prompt': raw_negative_prompt, 'styles': str(raw_style_selections),
'real_prompt': task['log_positive_prompt'], 'real_negative_prompt': task['log_negative_prompt'],
'seed': task['task_seed'], 'width': width, 'height': height,
'sampler': sampler_name, 'scheduler': scheduler_name, 'performance': performance_selection,
'steps': steps, 'refiner_switch': refiner_switch, 'sharpness': sharpness, 'cfg': cfg_scale,
'base_model': base_model_name, 'refiner_model': refiner_model_name,
'denoising_strength': denoising_strength,
'freeu': freeu_enabled,
'img2img': input_image_checkbox,
'prompt_expansion': task['expansion']
}
if freeu_enabled:
metadata |= {
'freeu_b1': freeu_b1, 'freeu_b2': freeu_b2, 'freeu_s1': freeu_s1, 'freeu_s2': freeu_s2
}
if 'vary' in goals:
metadata |= {
'uov_method': uov_method
}
if 'upscale' in goals:
metadata |= {
'uov_method': uov_method, 'scale': f
}
if 'inpaint' in goals:
if len(outpaint_selections) > 0:
metadata |= {
'outpaint_selections': outpaint_selections
}
else:
metadata |= {
'inpaint_additional_prompt': inpaint_additional_prompt, 'inpaint_mask_upload': inpaint_mask_upload_checkbox, 'invert_mask': invert_mask_checkbox,
'inpaint_disable_initial_latent': inpaint_disable_initial_latent, 'inpaint_engine': inpaint_engine,
'inpaint_strength': inpaint_strength, 'inpaint_respective_field': inpaint_respective_field,
}
if 'cn' in goals:
metadata |= {
'canny_low_threshold': canny_low_threshold, 'canny_high_threshold': canny_high_threshold,
}
ip_list = {x: [] for x in flags.ip_list}
cn_task_index = 1
for cn_type in ip_list:
for cn_task in cn_tasks[cn_type]:
cn_img, cn_stop, cn_weight = cn_task
metadata |= {
f'image_prompt_{cn_task_index}': {
'cn_type': cn_type, 'cn_stop': cn_stop, 'cn_weight': cn_weight,
}
}
cn_task_index += 1
metadata |= {
'software': f'Fooocus v{fooocus_version.version}',
}
if modules.config.metadata_created_by != 'None':
metadata |= {
'created_by': modules.config.metadata_created_by
}
metadata_string = json.dumps(metadata, ensure_ascii=False)
elif save_metadata_to_images and metadata_scheme == 'a1111':
generation_params = {
"Steps": steps,
"Sampler": sampler_name,
"CFG scale": cfg_scale,
"Seed": task['task_seed'],
"Size": f"{width}x{height}",
"Model hash": base_model_hash,
"Model": base_model_name.split('.')[0],
"Lora hashes": lora_hashes_string,
"Denoising strength": denoising_strength,
"Version": f'Fooocus v{fooocus_version.version}'
}
if modules.config.metadata_created_by != 'None':
generation_params |= {
'Created By': f'{modules.config.metadata_created_by}'
}
generation_params_text = ", ".join([k if k == v else f'{k}: {quote(v)}' for k, v in generation_params.items() if v is not None])
positive_prompt_resolved = ', '.join(task['positive'])
negative_prompt_resolved = ', '.join(task['negative'])
negative_prompt_text = f"\nNegative prompt: {negative_prompt_resolved}" if negative_prompt_resolved else ""
metadata_string = f"{positive_prompt_resolved}{negative_prompt_text}\n{generation_params_text}".strip()
if modules.config.default_black_out_nsfw or black_out_nsfw:
progressbar_index = int(15.0 + 85.0 * float((current_task_id + 1) * steps) / float(all_steps))
progressbar(async_task, progressbar_index, 'Checking for NSFW content ...')
imgs = censor_batch(imgs)
for x in imgs:
d = [
('Prompt', task['log_positive_prompt']),
('Negative Prompt', task['log_negative_prompt']),
('Fooocus V2 Expansion', task['expansion']),
('Styles', str(raw_style_selections)),
('Performance', performance_selection),
('Resolution', str((width, height))),
('Sharpness', sharpness),
('Guidance Scale', guidance_scale),
('ADM Guidance', str((
modules.patch.patch_settings[pid].positive_adm_scale,
modules.patch.patch_settings[pid].negative_adm_scale,
modules.patch.patch_settings[pid].adm_scaler_end))),
('Base Model', base_model_name),
('Refiner Model', refiner_model_name),
('Refiner Switch', refiner_switch),
('Sampler', sampler_name),
('Scheduler', scheduler_name),
('Sampling Steps Override', overwrite_step),
('Seed', task['task_seed']),
]
d = [('Prompt', 'prompt', task['log_positive_prompt']),
('Negative Prompt', 'negative_prompt', task['log_negative_prompt']),
('Fooocus V2 Expansion', 'prompt_expansion', task['expansion']),
('Styles', 'styles', str(raw_style_selections)),
('Performance', 'performance', performance_selection.value),
('Resolution', 'resolution', str((width, height))),
('Guidance Scale', 'guidance_scale', guidance_scale),
('Sharpness', 'sharpness', modules.patch.patch_settings[pid].sharpness),
('ADM Guidance', 'adm_guidance', str((
modules.patch.patch_settings[pid].positive_adm_scale,
modules.patch.patch_settings[pid].negative_adm_scale,
modules.patch.patch_settings[pid].adm_scaler_end))),
('Base Model', 'base_model', base_model_name),
('Refiner Model', 'refiner_model', refiner_model_name),
('Refiner Switch', 'refiner_switch', refiner_switch)]
if refiner_model_name != 'None':
if overwrite_switch > 0:
d.append(('Overwrite Switch', 'overwrite_switch', overwrite_switch))
if refiner_swap_method != flags.refiner_swap_method:
d.append(('Refiner Swap Method', 'refiner_swap_method', refiner_swap_method))
if modules.patch.patch_settings[pid].adaptive_cfg != modules.config.default_cfg_tsnr:
d.append(('CFG Mimicking from TSNR', 'adaptive_cfg', modules.patch.patch_settings[pid].adaptive_cfg))
d.append(('Sampler', 'sampler', sampler_name))
d.append(('Scheduler', 'scheduler', scheduler_name))
d.append(('Seed', 'seed', task['task_seed']))
if freeu_enabled:
d.append(('FreeU', 'freeu', str((freeu_b1, freeu_b2, freeu_s1, freeu_s2))))
metadata_parser = None
if save_metadata_to_images:
metadata_parser = modules.meta_parser.get_metadata_parser(metadata_scheme)
metadata_parser.set_data(task['log_positive_prompt'], task['positive'],
task['log_negative_prompt'], task['negative'],
steps, base_model_name, refiner_model_name, loras)
for li, (n, w) in enumerate(loras):
if n != 'None':
d.append((f'LoRA {li + 1}', f'{n} : {w}'))
d.append(('Version', 'v' + fooocus_version.version))
img_paths.append(log(x, d, metadata_string, save_metadata_to_images, output_format))
d.append((f'LoRA {li + 1}', f'lora_combined_{li + 1}', f'{n} : {w}'))
d.append(('Version', 'version', 'Fooocus v' + fooocus_version.version))
img_paths.append(log(x, d, metadata_parser, output_format))
yield_result(async_task, img_paths, black_out_nsfw, do_not_show_finished_images=len(tasks) == 1
or disable_intermediate_results or sampler_name == 'lcm')

View File

@ -8,7 +8,7 @@ import modules.sdxl_styles
from modules.model_loader import load_file_from_url
from modules.util import get_files_from_folder
from modules.flags import Performance, MetadataScheme, lora_count
config_path = os.path.abspath("./config.txt")
config_example_path = os.path.abspath("config_modification_tutorial.txt")
@ -289,7 +289,7 @@ default_prompt = get_config_item_or_set_default(
)
default_performance = get_config_item_or_set_default(
key='default_performance',
default_value='Speed',
default_value=Performance.SPEED.value,
validator=lambda x: x in [y[1] for y in modules.flags.performance_selections if y[1] == x]
)
default_advanced_checkbox = get_config_item_or_set_default(
@ -377,7 +377,7 @@ default_save_metadata_to_images = get_config_item_or_set_default(
)
default_metadata_scheme = get_config_item_or_set_default(
key='default_metadata_scheme',
default_value='fooocus',
default_value=MetadataScheme.FOOOCUS.value,
validator=lambda x: x in [y[1] for y in modules.flags.metadata_scheme if y[1] == x]
)
metadata_created_by = get_config_item_or_set_default(
@ -411,25 +411,25 @@ default_inpaint_mask_sam_model = get_config_item_or_set_default(
validator=lambda x: x in modules.flags.inpaint_mask_sam_model
)
config_dict["default_loras"] = default_loras = default_loras[:5] + [['None', 1.0] for _ in range(5 - len(default_loras))]
config_dict["default_loras"] = default_loras = default_loras[:lora_count] + [['None', 1.0] for _ in range(lora_count - len(default_loras))]
# mapping config to meta parameter
possible_preset_keys = {
"default_model": "Base Model",
"default_refiner": "Refiner Model",
"default_refiner_switch": "Refiner Switch",
"default_model": "base_model",
"default_refiner": "refiner_model",
"default_refiner_switch": "refiner_switch",
"previous_default_models": "previous_default_models",
"default_loras": "<processed>",
"default_cfg_scale": "Guidance Scale",
"default_sample_sharpness": "Sharpness",
"default_sampler": "Sampler",
"default_scheduler": "Scheduler",
"default_overwrite_step": "Sampling Steps Override",
"default_performance": "Performance",
"default_prompt": "Prompt",
"default_prompt_negative": "Negative Prompt",
"default_styles": "Styles",
"default_aspect_ratio": "Resolution",
"default_cfg_scale": "guidance_scale",
"default_sample_sharpness": "sharpness",
"default_sampler": "sampler",
"default_scheduler": "scheduler",
"default_overwrite_step": "steps",
"default_performance": "performance",
"default_prompt": "prompt",
"default_prompt_negative": "negative_prompt",
"default_styles": "styles",
"default_aspect_ratio": "resolution",
"checkpoint_downloads": "checkpoint_downloads",
"embeddings_downloads": "embeddings_downloads",
"lora_downloads": "lora_downloads"

View File

@ -1,3 +1,5 @@
from enum import Enum
disabled = 'Disabled'
enabled = 'Enabled'
subtle_variation = 'Vary (Subtle)'
@ -10,16 +12,49 @@ uov_list = [
disabled, subtle_variation, strong_variation, upscale_15, upscale_2, upscale_fast
]
KSAMPLER_NAMES = ["euler", "euler_ancestral", "heun", "heunpp2","dpm_2", "dpm_2_ancestral",
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "dpmpp_sde_gpu",
"dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm"]
CIVITAI_NO_KARRAS = ["euler", "euler_ancestral", "heun", "dpm_fast", "dpm_adaptive", "ddim", "uni_pc"]
# fooocus: a1111 (Civitai)
KSAMPLER = {
"euler": "Euler",
"euler_ancestral": "Euler a",
"heun": "Heun",
"heunpp2": "",
"dpm_2": "DPM2",
"dpm_2_ancestral": "DPM2 a",
"lms": "LMS",
"dpm_fast": "DPM fast",
"dpm_adaptive": "DPM adaptive",
"dpmpp_2s_ancestral": "DPM++ 2S a",
"dpmpp_sde": "DPM++ SDE",
"dpmpp_sde_gpu": "DPM++ SDE",
"dpmpp_2m": "DPM++ 2M",
"dpmpp_2m_sde": "DPM++ 2M SDE",
"dpmpp_2m_sde_gpu": "DPM++ 2M SDE",
"dpmpp_3m_sde": "",
"dpmpp_3m_sde_gpu": "",
"ddpm": "",
"lcm": "LCM"
}
SAMPLER_EXTRA = {
"ddim": "DDIM",
"uni_pc": "UniPC",
"uni_pc_bh2": ""
}
SAMPLERS = KSAMPLER | SAMPLER_EXTRA
KSAMPLER_NAMES = list(KSAMPLER.keys())
SCHEDULER_NAMES = ["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_uniform", "lcm", "turbo"]
SAMPLER_NAMES = KSAMPLER_NAMES + ["ddim", "uni_pc", "uni_pc_bh2"]
SAMPLER_NAMES = KSAMPLER_NAMES + list(SAMPLER_EXTRA.keys())
sampler_list = SAMPLER_NAMES
scheduler_list = SCHEDULER_NAMES
refiner_swap_method = 'joint'
cn_ip = "ImagePrompt"
cn_ip_face = "FaceSwap"
cn_canny = "PyraCanny"
@ -32,19 +67,8 @@ default_parameters = {
cn_ip: (0.5, 0.6), cn_ip_face: (0.9, 0.75), cn_canny: (0.5, 1.0), cn_cpds: (0.5, 1.0)
} # stop, weight
metadata_scheme =[
('Fooocus (json)', 'fooocus'),
('A1111 (plain text)', 'a1111'),
]
inpaint_engine_versions = ['None', 'v1', 'v2.5', 'v2.6']
performance_selections = [
('Quality <span style="color: grey;"> \U00002223 60 steps</span>', 'Quality'),
('Speed <span style="color: grey;"> \U00002223 30 steps</span>', 'Speed'),
('Extreme Speed (LCM) <span style="color: grey;"> \U00002223 8 steps, intermediate results disabled</span>', 'Extreme Speed')
]
output_formats = ['png', 'jpg', 'webp']
inpaint_mask_models = [
@ -62,3 +86,54 @@ inpaint_options = [inpaint_option_default, inpaint_option_detail, inpaint_option
desc_type_photo = 'Photograph'
desc_type_anime = 'Art/Anime'
class MetadataScheme(Enum):
FOOOCUS = 'fooocus'
A1111 = 'a1111'
metadata_scheme = [
(f'{MetadataScheme.FOOOCUS.value} (json)', MetadataScheme.FOOOCUS.value),
(f'{MetadataScheme.A1111.value} (plain text)', MetadataScheme.A1111.value),
]
lora_count = 5
lora_count_with_lcm = lora_count + 1
controlnet_image_count = 4
class Steps(Enum):
QUALITY = 60
SPEED = 30
EXTREME_SPEED = 8
class StepsUOV(Enum):
QUALITY = 36
SPEED = 18
EXTREME_SPEED = 8
class Performance(Enum):
QUALITY = 'Quality'
SPEED = 'Speed'
EXTREME_SPEED = 'Extreme Speed'
@classmethod
def list(cls) -> list:
return list(map(lambda c: c.value, cls))
def steps(self) -> int | None:
return Steps[self.name].value if Steps[self.name] else None
def steps_uov(self) -> int | None:
return StepsUOV[self.name].value if Steps[self.name] else None
performance_selections = [
('Quality <span style="color: grey;"> \U00002223 60 steps</span>', Performance.QUALITY.value),
('Speed <span style="color: grey;"> \U00002223 30 steps</span>', Performance.SPEED.value),
('Extreme Speed (LCM) <span style="color: grey;"> \U00002223 8 steps, intermediate results disabled</span>', Performance.EXTREME_SPEED.value)
]

View File

@ -1,45 +1,112 @@
import json
import os
import re
from abc import ABC, abstractmethod
from pathlib import Path
import gradio as gr
from PIL import Image
import modules.config
import modules.sdxl_styles
from modules.flags import MetadataScheme, Performance, Steps
from modules.flags import lora_count, SAMPLERS, CIVITAI_NO_KARRAS
from modules.util import quote, unquote, extract_styles_from_prompt, is_json, calculate_sha256
re_param_code = r'\s*(\w[\w \-/]+):\s*("(?:\\.|[^\\"])+"|[^,]*)(?:,|$)'
re_param = re.compile(re_param_code)
re_imagesize = re.compile(r"^(\d+)x(\d+)$")
hash_cache = {}
def load_parameter_button_click(raw_prompt_txt, is_generating):
loaded_parameter_dict = json.loads(raw_prompt_txt)
def load_parameter_button_click(raw_metadata: dict | str, is_generating: bool):
loaded_parameter_dict = raw_metadata
if isinstance(raw_metadata, str):
loaded_parameter_dict = json.loads(raw_metadata)
assert isinstance(loaded_parameter_dict, dict)
results = [True, 1]
results = [len(loaded_parameter_dict) > 0, 1]
get_str('prompt', 'Prompt', loaded_parameter_dict, results)
get_str('negative_prompt', 'Negative Prompt', loaded_parameter_dict, results)
get_list('styles', 'Styles', loaded_parameter_dict, results)
get_str('performance', 'Performance', loaded_parameter_dict, results)
get_steps('steps', 'Steps', loaded_parameter_dict, results)
get_float('overwrite_switch', 'Overwrite Switch', loaded_parameter_dict, results)
get_resolution('resolution', 'Resolution', loaded_parameter_dict, results)
get_float('guidance_scale', 'Guidance Scale', loaded_parameter_dict, results)
get_float('sharpness', 'Sharpness', loaded_parameter_dict, results)
get_adm_guidance('adm_guidance', 'ADM Guidance', loaded_parameter_dict, results)
get_str('refiner_swap_method', 'Refiner Swap Method', loaded_parameter_dict, results)
get_float('adaptive_cfg', 'CFG Mimicking from TSNR', loaded_parameter_dict, results)
get_str('base_model', 'Base Model', loaded_parameter_dict, results)
get_str('refiner_model', 'Refiner Model', loaded_parameter_dict, results)
get_float('refiner_switch', 'Refiner Switch', loaded_parameter_dict, results)
get_str('sampler', 'Sampler', loaded_parameter_dict, results)
get_str('scheduler', 'Scheduler', loaded_parameter_dict, results)
get_seed('seed', 'Seed', loaded_parameter_dict, results)
if is_generating:
results.append(gr.update())
else:
results.append(gr.update(visible=True))
results.append(gr.update(visible=False))
get_freeu('freeu', 'FreeU', loaded_parameter_dict, results)
for i in range(lora_count):
get_lora(f'lora_combined_{i + 1}', f'LoRA {i + 1}', loaded_parameter_dict, results)
return results
def get_str(key: str, fallback: str | None, source_dict: dict, results: list, default=None):
try:
h = loaded_parameter_dict.get('Prompt', None)
h = source_dict.get(key, source_dict.get(fallback, default))
assert isinstance(h, str)
results.append(h)
except:
results.append(gr.update())
try:
h = loaded_parameter_dict.get('Negative Prompt', None)
assert isinstance(h, str)
results.append(h)
except:
results.append(gr.update())
def get_list(key: str, fallback: str | None, source_dict: dict, results: list, default=None):
try:
h = loaded_parameter_dict.get('Styles', None)
h = source_dict.get(key, source_dict.get(fallback, default))
h = eval(h)
assert isinstance(h, list)
results.append(h)
except:
results.append(gr.update())
def get_float(key: str, fallback: str | None, source_dict: dict, results: list, default=None):
try:
h = loaded_parameter_dict.get('Performance', None)
assert isinstance(h, str)
h = source_dict.get(key, source_dict.get(fallback, default))
assert h is not None
h = float(h)
results.append(h)
except:
results.append(gr.update())
def get_steps(key: str, fallback: str | None, source_dict: dict, results: list, default=None):
try:
h = loaded_parameter_dict.get('Resolution', None)
h = source_dict.get(key, source_dict.get(fallback, default))
assert h is not None
h = int(h)
if h not in set(item.value for item in Steps):
results.append(h)
return
results.append(-1)
except:
results.append(-1)
def get_resolution(key: str, fallback: str | None, source_dict: dict, results: list, default=None):
try:
h = source_dict.get(key, source_dict.get(fallback, default))
width, height = eval(h)
formatted = modules.config.add_ratio(f'{width}*{height}')
if formatted in modules.config.available_aspect_ratios:
@ -55,24 +122,22 @@ def load_parameter_button_click(raw_prompt_txt, is_generating):
results.append(gr.update())
results.append(gr.update())
def get_seed(key: str, fallback: str | None, source_dict: dict, results: list, default=None):
try:
h = loaded_parameter_dict.get('Sharpness', None)
h = source_dict.get(key, source_dict.get(fallback, default))
assert h is not None
h = float(h)
h = int(h)
results.append(False)
results.append(h)
except:
results.append(gr.update())
try:
h = loaded_parameter_dict.get('Guidance Scale', None)
assert h is not None
h = float(h)
results.append(h)
except:
results.append(gr.update())
def get_adm_guidance(key: str, fallback: str | None, source_dict: dict, results: list, default=None):
try:
h = loaded_parameter_dict.get('ADM Guidance', None)
h = source_dict.get(key, source_dict.get(fallback, default))
p, n, e = eval(h)
results.append(float(p))
results.append(float(n))
@ -82,78 +147,41 @@ def load_parameter_button_click(raw_prompt_txt, is_generating):
results.append(gr.update())
results.append(gr.update())
try:
h = loaded_parameter_dict.get('Base Model', None)
assert isinstance(h, str)
results.append(h)
except:
results.append(gr.update())
def get_freeu(key: str, fallback: str | None, source_dict: dict, results: list, default=None):
try:
h = loaded_parameter_dict.get('Refiner Model', None)
assert isinstance(h, str)
results.append(h)
h = source_dict.get(key, source_dict.get(fallback, default))
b1, b2, s1, s2 = eval(h)
results.append(True)
results.append(float(b1))
results.append(float(b2))
results.append(float(s1))
results.append(float(s2))
except:
results.append(gr.update())
try:
h = loaded_parameter_dict.get('Refiner Switch', None)
assert h is not None
h = float(h)
results.append(h)
except:
results.append(gr.update())
try:
h = loaded_parameter_dict.get('Sampler', None)
assert isinstance(h, str)
results.append(h)
except:
results.append(gr.update())
try:
h = loaded_parameter_dict.get('Scheduler', None)
assert isinstance(h, str)
results.append(h)
except:
results.append(gr.update())
try:
h = loaded_parameter_dict.get('Sampling Steps Override', None)
assert h is not None
h = float(h)
results.append(h)
except:
results.append(gr.update())
try:
h = loaded_parameter_dict.get('Seed', None)
assert h is not None
h = int(h)
results.append(False)
results.append(h)
results.append(gr.update())
results.append(gr.update())
results.append(gr.update())
results.append(gr.update())
def get_lora(key: str, fallback: str | None, source_dict: dict, results: list):
try:
n, w = source_dict.get(key, source_dict.get(fallback)).split(' : ')
w = float(w)
results.append(n)
results.append(w)
except:
results.append(gr.update())
results.append(gr.update())
results.append('None')
results.append(1)
if is_generating:
results.append(gr.update())
else:
results.append(gr.update(visible=True))
results.append(gr.update(visible=False))
for i in range(1, 6):
try:
n, w = loaded_parameter_dict.get(f'LoRA {i}').split(' : ')
w = float(w)
results.append(n)
results.append(w)
except:
results.append(gr.update())
results.append(gr.update())
def get_sha256(filepath):
global hash_cache
if filepath not in hash_cache:
hash_cache[filepath] = calculate_sha256(filepath)
return results
return hash_cache[filepath]
def parse_meta_from_preset(preset_content):
@ -167,7 +195,7 @@ def parse_meta_from_preset(preset_content):
if settings_key in items:
loras = items[settings_key]
for index, lora in enumerate(loras[:5]):
preset_prepared[f'LoRA {index + 1}'] = ' : '.join(map(str, lora))
preset_prepared[f'lora_combined_{index + 1}'] = ' : '.join(map(str, lora))
elif settings_key == "default_aspect_ratio":
if settings_key in items and items[settings_key] is not None:
default_aspect_ratio = items[settings_key]
@ -184,3 +212,330 @@ def parse_meta_from_preset(preset_content):
preset_prepared[meta_key] = str(preset_prepared[meta_key])
return preset_prepared
class MetadataParser(ABC):
def __init__(self):
self.raw_prompt: str = ''
self.full_prompt: str = ''
self.raw_negative_prompt: str = ''
self.full_negative_prompt: str = ''
self.steps: int = 30
self.base_model_name: str = ''
self.base_model_hash: str = ''
self.refiner_model_name: str = ''
self.refiner_model_hash: str = ''
self.loras: list = []
@abstractmethod
def get_scheme(self) -> MetadataScheme:
raise NotImplementedError
@abstractmethod
def parse_json(self, metadata: dict | str) -> dict:
raise NotImplementedError
@abstractmethod
def parse_string(self, metadata: dict) -> str:
raise NotImplementedError
def set_data(self, raw_prompt, full_prompt, raw_negative_prompt, full_negative_prompt, steps, base_model_name, refiner_model_name, loras):
self.raw_prompt = raw_prompt
self.full_prompt = full_prompt
self.raw_negative_prompt = raw_negative_prompt
self.full_negative_prompt = full_negative_prompt
self.steps = steps
self.base_model_name = Path(base_model_name).stem
base_model_path = os.path.join(modules.config.path_checkpoints, base_model_name)
self.base_model_hash = get_sha256(base_model_path)
if refiner_model_name not in ['', 'None']:
self.refiner_model_name = Path(refiner_model_name).stem
refiner_model_path = os.path.join(modules.config.path_checkpoints, refiner_model_name)
self.refiner_model_hash = get_sha256(refiner_model_path)
self.loras = []
for (lora_name, lora_weight) in loras:
if lora_name != 'None':
lora_path = os.path.join(modules.config.path_loras, lora_name)
lora_hash = get_sha256(lora_path)
self.loras.append((Path(lora_name).stem, lora_weight, lora_hash))
class A1111MetadataParser(MetadataParser):
def get_scheme(self) -> MetadataScheme:
return MetadataScheme.A1111
fooocus_to_a1111 = {
'raw_prompt': 'Raw prompt',
'raw_negative_prompt': 'Raw negative prompt',
'negative_prompt': 'Negative prompt',
'styles': 'Styles',
'performance': 'Performance',
'steps': 'Steps',
'sampler': 'Sampler',
'scheduler': 'Scheduler',
'guidance_scale': 'CFG scale',
'seed': 'Seed',
'resolution': 'Size',
'sharpness': 'Sharpness',
'adm_guidance': 'ADM Guidance',
'refiner_swap_method': 'Refiner Swap Method',
'adaptive_cfg': 'Adaptive CFG',
'overwrite_switch': 'Overwrite Switch',
'freeu': 'FreeU',
'base_model': 'Model',
'base_model_hash': 'Model hash',
'refiner_model': 'Refiner',
'refiner_model_hash': 'Refiner hash',
'lora_hashes': 'Lora hashes',
'lora_weights': 'Lora weights',
'created_by': 'User',
'version': 'Version'
}
def parse_json(self, metadata: str) -> dict:
metadata_prompt = ''
metadata_negative_prompt = ''
done_with_prompt = False
*lines, lastline = metadata.strip().split("\n")
if len(re_param.findall(lastline)) < 3:
lines.append(lastline)
lastline = ''
for line in lines:
line = line.strip()
if line.startswith(f"{self.fooocus_to_a1111['negative_prompt']}:"):
done_with_prompt = True
line = line[len(f"{self.fooocus_to_a1111['negative_prompt']}:"):].strip()
if done_with_prompt:
metadata_negative_prompt += ('' if metadata_negative_prompt == '' else "\n") + line
else:
metadata_prompt += ('' if metadata_prompt == '' else "\n") + line
found_styles, prompt, negative_prompt = extract_styles_from_prompt(metadata_prompt, metadata_negative_prompt)
data = {
'prompt': prompt,
'negative_prompt': negative_prompt
}
for k, v in re_param.findall(lastline):
try:
if v[0] == '"' and v[-1] == '"':
v = unquote(v)
m = re_imagesize.match(v)
if m is not None:
data['resolution'] = str((m.group(1), m.group(2)))
else:
data[list(self.fooocus_to_a1111.keys())[list(self.fooocus_to_a1111.values()).index(k)]] = v
except Exception:
print(f"Error parsing \"{k}: {v}\"")
# workaround for multiline prompts
if 'raw_prompt' in data:
data['prompt'] = data['raw_prompt']
raw_prompt = data['raw_prompt'].replace("\n", ', ')
if metadata_prompt != raw_prompt and modules.sdxl_styles.fooocus_expansion not in found_styles:
found_styles.append(modules.sdxl_styles.fooocus_expansion)
if 'raw_negative_prompt' in data:
data['negative_prompt'] = data['raw_negative_prompt']
data['styles'] = str(found_styles)
# try to load performance based on steps, fallback for direct A1111 imports
if 'steps' in data and 'performance' not in data:
try:
data['performance'] = Performance[Steps(int(data['steps'])).name].value
except ValueError | KeyError:
pass
if 'sampler' in data:
data['sampler'] = data['sampler'].replace(' Karras', '')
# get key
for k, v in SAMPLERS.items():
if v == data['sampler']:
data['sampler'] = k
break
for key in ['base_model', 'refiner_model']:
if key in data:
for filename in modules.config.model_filenames:
path = Path(filename)
if data[key] == path.stem:
data[key] = filename
break
if 'lora_hashes' in data:
lora_filenames = modules.config.lora_filenames.copy()
lora_filenames.remove(modules.config.downloading_sdxl_lcm_lora())
for li, lora in enumerate(data['lora_hashes'].split(', ')):
lora_name, lora_hash, lora_weight = lora.split(': ')
for filename in lora_filenames:
path = Path(filename)
if lora_name == path.stem:
data[f'lora_combined_{li + 1}'] = f'{filename} : {lora_weight}'
break
return data
def parse_string(self, metadata: dict) -> str:
data = {k: v for _, k, v in metadata}
width, height = eval(data['resolution'])
sampler = data['sampler']
scheduler = data['scheduler']
if sampler in SAMPLERS and SAMPLERS[sampler] != '':
sampler = SAMPLERS[sampler]
if sampler not in CIVITAI_NO_KARRAS and scheduler == 'karras':
sampler += f' Karras'
generation_params = {
self.fooocus_to_a1111['steps']: self.steps,
self.fooocus_to_a1111['sampler']: sampler,
self.fooocus_to_a1111['seed']: data['seed'],
self.fooocus_to_a1111['resolution']: f'{width}x{height}',
self.fooocus_to_a1111['guidance_scale']: data['guidance_scale'],
self.fooocus_to_a1111['sharpness']: data['sharpness'],
self.fooocus_to_a1111['adm_guidance']: data['adm_guidance'],
self.fooocus_to_a1111['base_model']: Path(data['base_model']).stem,
self.fooocus_to_a1111['base_model_hash']: self.base_model_hash,
self.fooocus_to_a1111['performance']: data['performance'],
self.fooocus_to_a1111['scheduler']: scheduler,
# workaround for multiline prompts
self.fooocus_to_a1111['raw_prompt']: self.raw_prompt,
self.fooocus_to_a1111['raw_negative_prompt']: self.raw_negative_prompt,
}
if self.refiner_model_name not in ['', 'None']:
generation_params |= {
self.fooocus_to_a1111['refiner_model']: self.refiner_model_name,
self.fooocus_to_a1111['refiner_model_hash']: self.refiner_model_hash
}
for key in ['adaptive_cfg', 'overwrite_switch', 'refiner_swap_method', 'freeu']:
if key in data:
generation_params[self.fooocus_to_a1111[key]] = data[key]
lora_hashes = []
for index, (lora_name, lora_weight, lora_hash) in enumerate(self.loras):
# workaround for Fooocus not knowing LoRA name in LoRA metadata
lora_hashes.append(f'{lora_name}: {lora_hash}: {lora_weight}')
lora_hashes_string = ', '.join(lora_hashes)
generation_params |= {
self.fooocus_to_a1111['lora_hashes']: lora_hashes_string,
self.fooocus_to_a1111['version']: data['version']
}
if modules.config.metadata_created_by != '':
generation_params[self.fooocus_to_a1111['created_by']] = modules.config.metadata_created_by
generation_params_text = ", ".join(
[k if k == v else f'{k}: {quote(v)}' for k, v in generation_params.items() if
v is not None])
positive_prompt_resolved = ', '.join(self.full_prompt)
negative_prompt_resolved = ', '.join(self.full_negative_prompt)
negative_prompt_text = f"\nNegative prompt: {negative_prompt_resolved}" if negative_prompt_resolved else ""
return f"{positive_prompt_resolved}{negative_prompt_text}\n{generation_params_text}".strip()
class FooocusMetadataParser(MetadataParser):
def get_scheme(self) -> MetadataScheme:
return MetadataScheme.FOOOCUS
def parse_json(self, metadata: dict) -> dict:
model_filenames = modules.config.model_filenames.copy()
lora_filenames = modules.config.lora_filenames.copy()
lora_filenames.remove(modules.config.downloading_sdxl_lcm_lora())
for key, value in metadata.items():
if value in ['', 'None']:
continue
if key in ['base_model', 'refiner_model']:
metadata[key] = self.replace_value_with_filename(key, value, model_filenames)
elif key.startswith('lora_combined_'):
metadata[key] = self.replace_value_with_filename(key, value, lora_filenames)
else:
continue
return metadata
def parse_string(self, metadata: list) -> str:
for li, (label, key, value) in enumerate(metadata):
# remove model folder paths from metadata
if key.startswith('lora_combined_'):
name, weight = value.split(' : ')
name = Path(name).stem
value = f'{name} : {weight}'
metadata[li] = (label, key, value)
res = {k: v for _, k, v in metadata}
res['full_prompt'] = self.full_prompt
res['full_negative_prompt'] = self.full_negative_prompt
res['steps'] = self.steps
res['base_model'] = self.base_model_name
res['base_model_hash'] = self.base_model_hash
if self.refiner_model_name not in ['', 'None']:
res['refiner_model'] = self.refiner_model_name
res['refiner_model_hash'] = self.refiner_model_hash
res['loras'] = self.loras
if modules.config.metadata_created_by != '':
res['created_by'] = modules.config.metadata_created_by
return json.dumps(dict(sorted(res.items())))
@staticmethod
def replace_value_with_filename(key, value, filenames):
for filename in filenames:
path = Path(filename)
if key.startswith('lora_combined_'):
name, weight = value.split(' : ')
if name == path.stem:
return f'{filename} : {weight}'
elif value == path.stem:
return filename
def get_metadata_parser(metadata_scheme: MetadataScheme) -> MetadataParser:
match metadata_scheme:
case MetadataScheme.FOOOCUS:
return FooocusMetadataParser()
case MetadataScheme.A1111:
return A1111MetadataParser()
case _:
raise NotImplementedError
def read_info_from_image(filepath) -> tuple[str | None, dict, MetadataScheme | None]:
with Image.open(filepath) as image:
items = (image.info or {}).copy()
parameters = items.pop('parameters', None)
if parameters is not None and is_json(parameters):
parameters = json.loads(parameters)
try:
metadata_scheme = MetadataScheme(items.pop('fooocus_scheme', None))
except ValueError:
metadata_scheme = None
# broad fallback
if isinstance(parameters, dict):
metadata_scheme = MetadataScheme.FOOOCUS
if isinstance(parameters, str):
metadata_scheme = MetadataScheme.A1111
return parameters, items, metadata_scheme

View File

@ -6,8 +6,9 @@ import urllib.parse
from PIL import Image
from PIL.PngImagePlugin import PngInfo
from fooocus_version import version
from modules.util import generate_temp_filename
from tempfile import gettempdir
from modules.meta_parser import MetadataParser
log_cache = {}
@ -20,24 +21,26 @@ def get_current_html_path(output_format=None):
return html_name
def log(img, dic, metadata=None, save_metadata_to_image=False, output_format=None) -> str:
def log(img, metadata, metadata_parser: MetadataParser | None = None, output_format=None) -> str:
path_outputs = args_manager.args.temp_path if args_manager.args.disable_image_log else modules.config.path_outputs
output_format = output_format if output_format else modules.config.default_output_format
date_string, local_temp_filename, only_name = generate_temp_filename(folder=path_outputs, extension=output_format)
os.makedirs(os.path.dirname(local_temp_filename), exist_ok=True)
parsed_parameters = metadata_parser.parse_string(metadata) if metadata_parser else None
if output_format == 'png':
if save_metadata_to_image:
if parsed_parameters != '':
pnginfo = PngInfo()
pnginfo.add_text("parameters", metadata)
pnginfo.add_text('parameters', parsed_parameters)
pnginfo.add_text('fooocus_scheme', metadata_parser.get_scheme().value)
else:
pnginfo = None
Image.fromarray(img).save(local_temp_filename, pnginfo=pnginfo)
elif output_format == 'jpg':
# TODO check if metadata works correctly here
Image.fromarray(img).save(local_temp_filename, quality=95, optimize=True, progressive=True, comment=metadata if save_metadata_to_image else None)
Image.fromarray(img).save(local_temp_filename, quality=95, optimize=True, progressive=True)
elif output_format == 'webp':
# TODO test exif handling
Image.fromarray(img).save(local_temp_filename, quality=95, lossless=False)
else:
Image.fromarray(img).save(local_temp_filename)
@ -52,7 +55,7 @@ def log(img, dic, metadata=None, save_metadata_to_image=False, output_format=Non
"body { background-color: #121212; color: #E0E0E0; } "
"a { color: #BB86FC; } "
".metadata { border-collapse: collapse; width: 100%; } "
".metadata .key { width: 15%; } "
".metadata .label { width: 15%; } "
".metadata .value { width: 85%; font-weight: bold; } "
".metadata th, .metadata td { border: 1px solid #4d4d4d; padding: 4px; } "
".image-container img { height: auto; max-width: 512px; display: block; padding-right:10px; } "
@ -105,12 +108,12 @@ def log(img, dic, metadata=None, save_metadata_to_image=False, output_format=Non
item = f"<div id=\"{div_name}\" class=\"image-container\"><hr><table><tr>\n"
item += f"<td><a href=\"{only_name}\" target=\"_blank\"><img src='{only_name}' onerror=\"this.closest('.image-container').style.display='none';\" loading='lazy'></img></a><div>{only_name}</div></td>"
item += "<td><table class='metadata'>"
for key, value in dic:
for label, key, value in metadata:
value_txt = str(value).replace('\n', ' </br> ')
item += f"<tr><td class='key'>{key}</td><td class='value'>{value_txt}</td></tr>\n"
item += f"<tr><td class='label'>{label}</td><td class='value'>{value_txt}</td></tr>\n"
item += "</table>"
js_txt = urllib.parse.quote(json.dumps({k: v for k, v in dic}, indent=0), safe='')
js_txt = urllib.parse.quote(json.dumps({k: v for _, k, v in metadata}, indent=0), safe='')
item += f"</br><button onclick=\"to_clipboard('{js_txt}')\">Copy to Clipboard</button>"
item += "</td>"

View File

@ -1,3 +1,5 @@
import typing
import numpy as np
import datetime
import random
@ -9,9 +11,10 @@ import json
from PIL import Image
from hashlib import sha256
import modules.sdxl_styles
LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
HASH_SHA256_LENGTH = 10
def erode_or_dilate(x, k):
k = int(k)
@ -172,13 +175,14 @@ def get_files_from_folder(folder_path, exensions=None, name_filter=None):
relative_path = ""
for filename in sorted(files):
_, file_extension = os.path.splitext(filename)
if (exensions == None or file_extension.lower() in exensions) and (name_filter == None or name_filter in _):
if (exensions is None or file_extension.lower() in exensions) and (name_filter is None or name_filter in _):
path = os.path.join(relative_path, filename)
filenames.append(path)
return sorted(filenames, key=lambda x: -1 if os.sep in x else 1)
def calculate_sha256(filename):
def calculate_sha256(filename, length=HASH_SHA256_LENGTH) -> str:
hash_sha256 = sha256()
blksize = 1024 * 1024
@ -186,10 +190,153 @@ def calculate_sha256(filename):
for chunk in iter(lambda: f.read(blksize), b""):
hash_sha256.update(chunk)
return hash_sha256.hexdigest()
res = hash_sha256.hexdigest()
return res[:length] if length else res
def quote(text):
if ',' not in str(text) and '\n' not in str(text) and ':' not in str(text):
return text
return json.dumps(text, ensure_ascii=False)
def unquote(text):
if len(text) == 0 or text[0] != '"' or text[-1] != '"':
return text
try:
return json.loads(text)
except Exception:
return text
def unwrap_style_text_from_prompt(style_text, prompt):
"""
Checks the prompt to see if the style text is wrapped around it. If so,
returns True plus the prompt text without the style text. Otherwise, returns
False with the original prompt.
Note that the "cleaned" version of the style text is only used for matching
purposes here. It isn't returned; the original style text is not modified.
"""
stripped_prompt = prompt
stripped_style_text = style_text
if "{prompt}" in stripped_style_text:
# Work out whether the prompt is wrapped in the style text. If so, we
# return True and the "inner" prompt text that isn't part of the style.
try:
left, right = stripped_style_text.split("{prompt}", 2)
except ValueError as e:
# If the style text has multple "{prompt}"s, we can't split it into
# two parts. This is an error, but we can't do anything about it.
print(f"Unable to compare style text to prompt:\n{style_text}")
print(f"Error: {e}")
return False, prompt, ''
left_pos = stripped_prompt.find(left)
right_pos = stripped_prompt.find(right)
if 0 <= left_pos < right_pos:
real_prompt = stripped_prompt[left_pos + len(left):right_pos]
prompt = stripped_prompt.replace(left + real_prompt + right, '', 1)
if prompt.startswith(", "):
prompt = prompt[2:]
if prompt.endswith(", "):
prompt = prompt[:-2]
return True, prompt, real_prompt
else:
# Work out whether the given prompt starts with the style text. If so, we
# return True and the prompt text up to where the style text starts.
if stripped_prompt.endswith(stripped_style_text):
prompt = stripped_prompt[: len(stripped_prompt) - len(stripped_style_text)]
if prompt.endswith(", "):
prompt = prompt[:-2]
return True, prompt, prompt
return False, prompt, ''
def extract_original_prompts(style, prompt, negative_prompt):
"""
Takes a style and compares it to the prompt and negative prompt. If the style
matches, returns True plus the prompt and negative prompt with the style text
removed. Otherwise, returns False with the original prompt and negative prompt.
"""
if not style.prompt and not style.negative_prompt:
return False, prompt, negative_prompt
match_positive, extracted_positive, real_prompt = unwrap_style_text_from_prompt(
style.prompt, prompt
)
if not match_positive:
return False, prompt, negative_prompt, ''
match_negative, extracted_negative, _ = unwrap_style_text_from_prompt(
style.negative_prompt, negative_prompt
)
if not match_negative:
return False, prompt, negative_prompt, ''
return True, extracted_positive, extracted_negative, real_prompt
def extract_styles_from_prompt(prompt, negative_prompt):
extracted = []
applicable_styles = []
for style_name, (style_prompt, style_negative_prompt) in modules.sdxl_styles.styles.items():
applicable_styles.append(PromptStyle(name=style_name, prompt=style_prompt, negative_prompt=style_negative_prompt))
real_prompt = ''
while True:
found_style = None
for style in applicable_styles:
is_match, new_prompt, new_neg_prompt, new_real_prompt = extract_original_prompts(
style, prompt, negative_prompt
)
if is_match:
found_style = style
prompt = new_prompt
negative_prompt = new_neg_prompt
if real_prompt == '' and new_real_prompt != '' and new_real_prompt != prompt:
real_prompt = new_real_prompt
break
if not found_style:
break
applicable_styles.remove(found_style)
extracted.append(found_style.name)
# add prompt expansion if not all styles could be resolved
if prompt != '':
if real_prompt != '':
extracted.append(modules.sdxl_styles.fooocus_expansion)
else:
# find real_prompt when only prompt expansion is selected
first_word = prompt.split(', ')[0]
first_word_positions = [i for i in range(len(prompt)) if prompt.startswith(first_word, i)]
if len(first_word_positions) > 1:
real_prompt = prompt[:first_word_positions[-1]]
extracted.append(modules.sdxl_styles.fooocus_expansion)
if real_prompt.endswith(', '):
real_prompt = real_prompt[:-2]
return list(reversed(extracted)), real_prompt, negative_prompt
class PromptStyle(typing.NamedTuple):
name: str
prompt: str
negative_prompt: str
def is_json(data: str) -> bool:
try:
loaded_json = json.loads(data)
assert isinstance(loaded_json, dict)
except (ValueError, AssertionError):
return False
return True

100
webui.py
View File

@ -21,6 +21,7 @@ from modules.sdxl_styles import legal_style_names
from modules.private_logger import get_current_html_path
from modules.ui_gradio_extensions import reload_javascript
from modules.auth import auth_enabled, check_auth
from modules.util import is_json
def get_task(*args):
args = list(args)
@ -159,7 +160,7 @@ with shared.gradio_root:
ip_weights = []
ip_ctrls = []
ip_ad_cols = []
for _ in range(4):
for _ in range(flags.controlnet_image_count):
with gr.Column():
ip_image = grh.Image(label='Image', source='upload', type='numpy', show_label=False, height=300)
ip_images.append(ip_image)
@ -270,6 +271,30 @@ with shared.gradio_root:
value=flags.desc_type_photo)
desc_btn = gr.Button(value='Describe this Image into Prompt')
gr.HTML('<a href="https://github.com/lllyasviel/Fooocus/discussions/1363" target="_blank">\U0001F4D4 Document</a>')
with gr.TabItem(label='Metadata') as load_tab:
with gr.Column():
metadata_input_image = grh.Image(label='Drag any image generated by Fooocus here', source='upload', type='filepath')
metadata_json = gr.JSON(label='Metadata')
metadata_import_button = gr.Button(value='Apply Metadata')
def trigger_metadata_preview(filepath):
parameters, items, metadata_scheme = modules.meta_parser.read_info_from_image(filepath)
results = {}
if parameters is not None:
results['parameters'] = parameters
if items:
results['items'] = items
if isinstance(metadata_scheme, flags.MetadataScheme):
results['metadata_scheme'] = metadata_scheme.value
return results
metadata_input_image.upload(trigger_metadata_preview, inputs=metadata_input_image,
outputs=metadata_json, queue=False, show_progress=True)
switch_js = "(x) => {if(x){viewer_to_bottom(100);viewer_to_bottom(500);}else{viewer_to_top();} return x;}"
down_js = "() => {viewer_to_bottom();}"
@ -301,7 +326,8 @@ with shared.gradio_root:
output_format = gr.Radio(label='Output Format',
choices=modules.flags.output_formats,
value=modules.config.default_output_format)
value=modules.config.default_output_format,
info='Metadata support has only been implemented for png.')
negative_prompt = gr.Textbox(label='Negative Prompt', show_label=True, placeholder="Type prompt here.",
info='Describing what you do not want to see.', lines=2,
@ -423,7 +449,7 @@ with shared.gradio_root:
step=0.001, value=0.3,
info='When to end the guidance from positive/negative ADM. ')
refiner_swap_method = gr.Dropdown(label='Refiner swap method', value='joint',
refiner_swap_method = gr.Dropdown(label='Refiner swap method', value=flags.refiner_swap_method,
choices=['joint', 'separate', 'vae'])
adaptive_cfg = gr.Slider(label='CFG Mimicking from TSNR', minimum=1.0, maximum=30.0, step=0.01,
@ -481,7 +507,7 @@ with shared.gradio_root:
save_metadata_to_images = gr.Checkbox(label='Save Metadata to Images', value=modules.config.default_save_metadata_to_images,
info='Adds parameters to generated images allowing manual regeneration.')
metadata_scheme = gr.Radio(label='Metadata Scheme', choices=flags.metadata_scheme, value=modules.config.default_metadata_scheme,
info='Use A1111 for compatibility with Civitai.',
info='Image Prompt parameters are not included. Use a1111 for compatibility with Civitai.',
visible=modules.config.default_save_metadata_to_images)
save_metadata_to_images.change(lambda x: gr.update(visible=x), inputs=[save_metadata_to_images], outputs=[metadata_scheme],
@ -563,11 +589,11 @@ with shared.gradio_root:
modules.config.update_all_model_names()
modules.config.update_presets()
results = []
results += [gr.update(choices=modules.config.model_filenames),
results += [gr.update(choices=modules.config.model_filenames),
gr.update(choices=['None'] + modules.config.model_filenames)]
if not args_manager.args.disable_preset_selection:
results += [gr.update(choices=modules.config.available_presets)]
for i in range(5):
for i in range(flags.lora_count):
results += [gr.update(choices=['None'] + modules.config.lora_filenames), gr.update()]
return results
@ -597,39 +623,19 @@ with shared.gradio_root:
state_is_generating = gr.State(False)
load_parameter_outputs = [
advanced_checkbox,
image_number,
prompt,
negative_prompt,
style_selections,
performance_selection,
aspect_ratios_selection,
overwrite_width,
overwrite_height,
sharpness,
guidance_scale,
adm_scaler_positive,
adm_scaler_negative,
adm_scaler_end,
base_model,
refiner_model,
refiner_switch,
sampler_name,
scheduler_name,
overwrite_step,
seed_random,
image_seed,
generate_button,
load_parameter_button
] + lora_ctrls
load_data_outputs = [advanced_checkbox, image_number, prompt, negative_prompt, style_selections,
performance_selection, overwrite_step, overwrite_switch, aspect_ratios_selection,
overwrite_width, overwrite_height, guidance_scale, sharpness, adm_scaler_positive,
adm_scaler_negative, adm_scaler_end, refiner_swap_method, adaptive_cfg, base_model,
refiner_model, refiner_switch, sampler_name, scheduler_name, seed_random, image_seed,
generate_button, load_parameter_button] + freeu_ctrls + lora_ctrls
if not args_manager.args.disable_preset_selection:
def preset_selection_change(preset, is_generating):
preset_content = modules.config.try_get_preset_content(preset) if preset != 'initial' else {}
preset_prepared = modules.meta_parser.parse_meta_from_preset(preset_content)
default_model = preset_prepared['Base Model']
default_model = preset_prepared['base_model']
previous_default_models = preset_prepared['previous_default_models']
checkpoint_downloads = preset_prepared['checkpoint_downloads']
embeddings_downloads = preset_prepared['embeddings_downloads']
@ -640,7 +646,7 @@ with shared.gradio_root:
return modules.meta_parser.load_parameter_button_click(json.dumps(preset_prepared), is_generating)
preset_selection.change(preset_selection_change, inputs=[preset_selection, state_is_generating], outputs=load_parameter_outputs, queue=False, show_progress=True) \
preset_selection.change(preset_selection_change, inputs=[preset_selection, state_is_generating], outputs=load_data_outputs, queue=False, show_progress=True) \
.then(fn=style_sorter.sort_styles, inputs=style_selections, outputs=style_selections, queue=False, show_progress=False) \
.then(lambda: None, _js='()=>{refresh_style_localization();}')
@ -721,14 +727,8 @@ with shared.gradio_root:
def parse_meta(raw_prompt_txt, is_generating):
loaded_json = None
try:
if '{' in raw_prompt_txt:
if '}' in raw_prompt_txt:
if ':' in raw_prompt_txt:
loaded_json = json.loads(raw_prompt_txt)
assert isinstance(loaded_json, dict)
except:
loaded_json = None
if is_json(raw_prompt_txt):
loaded_json = json.loads(raw_prompt_txt)
if loaded_json is None:
if is_generating:
@ -740,7 +740,21 @@ with shared.gradio_root:
prompt.input(parse_meta, inputs=[prompt, state_is_generating], outputs=[prompt, generate_button, load_parameter_button], queue=False, show_progress=False)
load_parameter_button.click(modules.meta_parser.load_parameter_button_click, inputs=[prompt, state_is_generating], outputs=load_parameter_outputs, queue=False, show_progress=False)
load_parameter_button.click(modules.meta_parser.load_parameter_button_click, inputs=[prompt, state_is_generating], outputs=load_data_outputs, queue=False, show_progress=False)
def trigger_metadata_import(filepath, state_is_generating):
parameters, items, metadata_scheme = modules.meta_parser.read_info_from_image(filepath)
if parameters is None:
print('Could not find metadata in the image!')
parsed_parameters = {}
else:
metadata_parser = modules.meta_parser.get_metadata_parser(metadata_scheme)
parsed_parameters = metadata_parser.parse_json(parameters)
return modules.meta_parser.load_parameter_button_click(parsed_parameters, state_is_generating)
metadata_import_button.click(trigger_metadata_import, inputs=[metadata_input_image, state_is_generating], outputs=load_data_outputs, queue=False, show_progress=True) \
.then(style_sorter.sort_styles, inputs=style_selections, outputs=style_selections, queue=False, show_progress=False)
generate_button.click(lambda: (gr.update(visible=True, interactive=True), gr.update(visible=True, interactive=True), gr.update(visible=False, interactive=False), [], True),
outputs=[stop_button, skip_button, generate_button, gallery, state_is_generating]) \