From c32bc5e199f7a0a45736f10c248cd1955433a609 Mon Sep 17 00:00:00 2001 From: Manuel Schmid <9307310+mashb1t@users.noreply.github.com> Date: Thu, 9 May 2024 18:59:35 +0200 Subject: [PATCH 1/2] feat: add optional model VAE select (#2867) * Revert "fix: use LF as line breaks for Docker entrypoint.sh (#2843)" (#2865) False alarm, worked as intended before. Sorry for the fuzz. This reverts commit d16a54edd69f82158ae7ffe5669618db33a01ac7. * feat: add VAE select * feat: use different default label, add translation * fix: do not reload model when VAE stays the same * refactor: code cleanup * feat: add metadata handling --- language/en.json | 2 ++ ldm_patched/modules/sd.py | 13 +++++++++---- modules/async_worker.py | 6 ++++-- modules/config.py | 14 +++++++++++++- modules/core.py | 10 ++++++---- modules/default_pipeline.py | 22 ++++++++++++++-------- modules/flags.py | 2 ++ modules/meta_parser.py | 31 ++++++++++++++++++++++++------- modules/util.py | 3 +++ webui.py | 11 +++++++---- 10 files changed, 84 insertions(+), 30 deletions(-) diff --git a/language/en.json b/language/en.json index d10c29dc..1fe78662 100644 --- a/language/en.json +++ b/language/en.json @@ -340,6 +340,8 @@ "sgm_uniform": "sgm_uniform", "simple": "simple", "ddim_uniform": "ddim_uniform", + "VAE": "VAE", + "Default (model)": "Default (model)", "Forced Overwrite of Sampling Step": "Forced Overwrite of Sampling Step", "Set as -1 to disable. For developer debugging.": "Set as -1 to disable. For developer debugging.", "Forced Overwrite of Refiner Switch Step": "Forced Overwrite of Refiner Switch Step", diff --git a/ldm_patched/modules/sd.py b/ldm_patched/modules/sd.py index e197c39c..282f2559 100644 --- a/ldm_patched/modules/sd.py +++ b/ldm_patched/modules/sd.py @@ -427,12 +427,13 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl return (ldm_patched.modules.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device), clip, vae) -def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True): +def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, vae_filename_param=None): sd = ldm_patched.modules.utils.load_torch_file(ckpt_path) sd_keys = sd.keys() clip = None clipvision = None vae = None + vae_filename = None model = None model_patcher = None clip_target = None @@ -462,8 +463,12 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o model.load_model_weights(sd, "model.diffusion_model.") if output_vae: - vae_sd = ldm_patched.modules.utils.state_dict_prefix_replace(sd, {"first_stage_model.": ""}, filter_keys=True) - vae_sd = model_config.process_vae_state_dict(vae_sd) + if vae_filename_param is None: + vae_sd = ldm_patched.modules.utils.state_dict_prefix_replace(sd, {"first_stage_model.": ""}, filter_keys=True) + vae_sd = model_config.process_vae_state_dict(vae_sd) + else: + vae_sd = ldm_patched.modules.utils.load_torch_file(vae_filename_param) + vae_filename = vae_filename_param vae = VAE(sd=vae_sd) if output_clip: @@ -485,7 +490,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o print("loaded straight to GPU") model_management.load_model_gpu(model_patcher) - return (model_patcher, clip, vae, clipvision) + return model_patcher, clip, vae, vae_filename, clipvision def load_unet_state_dict(sd): #load unet in diffusers format diff --git a/modules/async_worker.py b/modules/async_worker.py index d8a1e072..3576c4ec 100644 --- a/modules/async_worker.py +++ b/modules/async_worker.py @@ -166,6 +166,7 @@ def worker(): adaptive_cfg = args.pop() sampler_name = args.pop() scheduler_name = args.pop() + vae_name = args.pop() overwrite_step = args.pop() overwrite_switch = args.pop() overwrite_width = args.pop() @@ -428,7 +429,7 @@ def worker(): progressbar(async_task, 3, 'Loading models ...') pipeline.refresh_everything(refiner_model_name=refiner_model_name, base_model_name=base_model_name, loras=loras, base_model_additional_loras=base_model_additional_loras, - use_synthetic_refiner=use_synthetic_refiner) + use_synthetic_refiner=use_synthetic_refiner, vae_name=vae_name) progressbar(async_task, 3, 'Processing prompts ...') tasks = [] @@ -869,6 +870,7 @@ def worker(): d.append(('Sampler', 'sampler', sampler_name)) d.append(('Scheduler', 'scheduler', scheduler_name)) + d.append(('VAE', 'vae', vae_name)) d.append(('Seed', 'seed', str(task['task_seed']))) if freeu_enabled: @@ -883,7 +885,7 @@ def worker(): metadata_parser = modules.meta_parser.get_metadata_parser(metadata_scheme) metadata_parser.set_data(task['log_positive_prompt'], task['positive'], task['log_negative_prompt'], task['negative'], - steps, base_model_name, refiner_model_name, loras) + steps, base_model_name, refiner_model_name, loras, vae_name) d.append(('Metadata Scheme', 'metadata_scheme', metadata_scheme.value if save_metadata_to_images else save_metadata_to_images)) d.append(('Version', 'version', 'Fooocus v' + fooocus_version.version)) img_paths.append(log(x, d, metadata_parser, output_format)) diff --git a/modules/config.py b/modules/config.py index b81e218a..f11460c8 100644 --- a/modules/config.py +++ b/modules/config.py @@ -189,6 +189,7 @@ paths_checkpoints = get_dir_or_set_default('path_checkpoints', ['../models/check paths_loras = get_dir_or_set_default('path_loras', ['../models/loras/'], True) path_embeddings = get_dir_or_set_default('path_embeddings', '../models/embeddings/') path_vae_approx = get_dir_or_set_default('path_vae_approx', '../models/vae_approx/') +path_vae = get_dir_or_set_default('path_vae', '../models/vae/') path_upscale_models = get_dir_or_set_default('path_upscale_models', '../models/upscale_models/') path_inpaint = get_dir_or_set_default('path_inpaint', '../models/inpaint/') path_controlnet = get_dir_or_set_default('path_controlnet', '../models/controlnet/') @@ -346,6 +347,11 @@ default_scheduler = get_config_item_or_set_default( default_value='karras', validator=lambda x: x in modules.flags.scheduler_list ) +default_vae = get_config_item_or_set_default( + key='default_vae', + default_value=modules.flags.default_vae, + validator=lambda x: isinstance(x, str) +) default_styles = get_config_item_or_set_default( key='default_styles', default_value=[ @@ -535,6 +541,7 @@ with open(config_example_path, "w", encoding="utf-8") as json_file: model_filenames = [] lora_filenames = [] +vae_filenames = [] wildcard_filenames = [] sdxl_lcm_lora = 'sdxl_lcm_lora.safetensors' @@ -546,15 +553,20 @@ def get_model_filenames(folder_paths, extensions=None, name_filter=None): if extensions is None: extensions = ['.pth', '.ckpt', '.bin', '.safetensors', '.fooocus.patch'] files = [] + + if not isinstance(folder_paths, list): + folder_paths = [folder_paths] for folder in folder_paths: files += get_files_from_folder(folder, extensions, name_filter) + return files def update_files(): - global model_filenames, lora_filenames, wildcard_filenames, available_presets + global model_filenames, lora_filenames, vae_filenames, wildcard_filenames, available_presets model_filenames = get_model_filenames(paths_checkpoints) lora_filenames = get_model_filenames(paths_loras) + vae_filenames = get_model_filenames(path_vae) wildcard_filenames = get_files_from_folder(path_wildcards, ['.txt']) available_presets = get_presets() return diff --git a/modules/core.py b/modules/core.py index 38ee8e8d..3ca4cc5b 100644 --- a/modules/core.py +++ b/modules/core.py @@ -35,12 +35,13 @@ opModelSamplingDiscrete = ModelSamplingDiscrete() class StableDiffusionModel: - def __init__(self, unet=None, vae=None, clip=None, clip_vision=None, filename=None): + def __init__(self, unet=None, vae=None, clip=None, clip_vision=None, filename=None, vae_filename=None): self.unet = unet self.vae = vae self.clip = clip self.clip_vision = clip_vision self.filename = filename + self.vae_filename = vae_filename self.unet_with_lora = unet self.clip_with_lora = clip self.visited_loras = '' @@ -142,9 +143,10 @@ def apply_controlnet(positive, negative, control_net, image, strength, start_per @torch.no_grad() @torch.inference_mode() -def load_model(ckpt_filename): - unet, clip, vae, clip_vision = load_checkpoint_guess_config(ckpt_filename, embedding_directory=path_embeddings) - return StableDiffusionModel(unet=unet, clip=clip, vae=vae, clip_vision=clip_vision, filename=ckpt_filename) +def load_model(ckpt_filename, vae_filename=None): + unet, clip, vae, vae_filename, clip_vision = load_checkpoint_guess_config(ckpt_filename, embedding_directory=path_embeddings, + vae_filename_param=vae_filename) + return StableDiffusionModel(unet=unet, clip=clip, vae=vae, clip_vision=clip_vision, filename=ckpt_filename, vae_filename=vae_filename) @torch.no_grad() diff --git a/modules/default_pipeline.py b/modules/default_pipeline.py index 190601ec..38f914c5 100644 --- a/modules/default_pipeline.py +++ b/modules/default_pipeline.py @@ -3,6 +3,7 @@ import os import torch import modules.patch import modules.config +import modules.flags import ldm_patched.modules.model_management import ldm_patched.modules.latent_formats import modules.inpaint_worker @@ -58,17 +59,21 @@ def assert_model_integrity(): @torch.no_grad() @torch.inference_mode() -def refresh_base_model(name): +def refresh_base_model(name, vae_name=None): global model_base filename = get_file_from_folder_list(name, modules.config.paths_checkpoints) - if model_base.filename == filename: + vae_filename = None + if vae_name is not None and vae_name != modules.flags.default_vae: + vae_filename = get_file_from_folder_list(vae_name, modules.config.path_vae) + + if model_base.filename == filename and model_base.vae_filename == vae_filename: return - model_base = core.StableDiffusionModel() - model_base = core.load_model(filename) + model_base = core.load_model(filename, vae_filename) print(f'Base model loaded: {model_base.filename}') + print(f'VAE loaded: {model_base.vae_filename}') return @@ -216,7 +221,7 @@ def prepare_text_encoder(async_call=True): @torch.no_grad() @torch.inference_mode() def refresh_everything(refiner_model_name, base_model_name, loras, - base_model_additional_loras=None, use_synthetic_refiner=False): + base_model_additional_loras=None, use_synthetic_refiner=False, vae_name=None): global final_unet, final_clip, final_vae, final_refiner_unet, final_refiner_vae, final_expansion final_unet = None @@ -227,11 +232,11 @@ def refresh_everything(refiner_model_name, base_model_name, loras, if use_synthetic_refiner and refiner_model_name == 'None': print('Synthetic Refiner Activated') - refresh_base_model(base_model_name) + refresh_base_model(base_model_name, vae_name) synthesize_refiner_model() else: refresh_refiner_model(refiner_model_name) - refresh_base_model(base_model_name) + refresh_base_model(base_model_name, vae_name) refresh_loras(loras, base_model_additional_loras=base_model_additional_loras) assert_model_integrity() @@ -254,7 +259,8 @@ def refresh_everything(refiner_model_name, base_model_name, loras, refresh_everything( refiner_model_name=modules.config.default_refiner_model_name, base_model_name=modules.config.default_base_model_name, - loras=get_enabled_loras(modules.config.default_loras) + loras=get_enabled_loras(modules.config.default_loras), + vae_name=modules.config.default_vae, ) diff --git a/modules/flags.py b/modules/flags.py index c9d13fd8..9f2aefb3 100644 --- a/modules/flags.py +++ b/modules/flags.py @@ -53,6 +53,8 @@ SAMPLER_NAMES = KSAMPLER_NAMES + list(SAMPLER_EXTRA.keys()) sampler_list = SAMPLER_NAMES scheduler_list = SCHEDULER_NAMES +default_vae = 'Default (model)' + refiner_swap_method = 'joint' cn_ip = "ImagePrompt" diff --git a/modules/meta_parser.py b/modules/meta_parser.py index 70ab8860..84032e82 100644 --- a/modules/meta_parser.py +++ b/modules/meta_parser.py @@ -46,6 +46,7 @@ def load_parameter_button_click(raw_metadata: dict | str, is_generating: bool): get_float('refiner_switch', 'Refiner Switch', loaded_parameter_dict, results) get_str('sampler', 'Sampler', loaded_parameter_dict, results) get_str('scheduler', 'Scheduler', loaded_parameter_dict, results) + get_str('vae', 'VAE', loaded_parameter_dict, results) get_seed('seed', 'Seed', loaded_parameter_dict, results) if is_generating: @@ -253,6 +254,7 @@ class MetadataParser(ABC): self.refiner_model_name: str = '' self.refiner_model_hash: str = '' self.loras: list = [] + self.vae_name: str = '' @abstractmethod def get_scheme(self) -> MetadataScheme: @@ -267,7 +269,7 @@ class MetadataParser(ABC): raise NotImplementedError def set_data(self, raw_prompt, full_prompt, raw_negative_prompt, full_negative_prompt, steps, base_model_name, - refiner_model_name, loras): + refiner_model_name, loras, vae_name): self.raw_prompt = raw_prompt self.full_prompt = full_prompt self.raw_negative_prompt = raw_negative_prompt @@ -289,6 +291,7 @@ class MetadataParser(ABC): lora_path = get_file_from_folder_list(lora_name, modules.config.paths_loras) lora_hash = get_sha256(lora_path) self.loras.append((Path(lora_name).stem, lora_weight, lora_hash)) + self.vae_name = Path(vae_name).stem @staticmethod def remove_special_loras(lora_filenames): @@ -310,6 +313,7 @@ class A1111MetadataParser(MetadataParser): 'steps': 'Steps', 'sampler': 'Sampler', 'scheduler': 'Scheduler', + 'vae': 'VAE', 'guidance_scale': 'CFG scale', 'seed': 'Seed', 'resolution': 'Size', @@ -397,13 +401,12 @@ class A1111MetadataParser(MetadataParser): data['sampler'] = k break - for key in ['base_model', 'refiner_model']: + for key in ['base_model', 'refiner_model', 'vae']: if key in data: - for filename in modules.config.model_filenames: - path = Path(filename) - if data[key] == path.stem: - data[key] = filename - break + if key == 'vae': + self.add_extension_to_filename(data, modules.config.vae_filenames, 'vae') + else: + self.add_extension_to_filename(data, modules.config.model_filenames, key) lora_data = '' if 'lora_weights' in data and data['lora_weights'] != '': @@ -433,6 +436,7 @@ class A1111MetadataParser(MetadataParser): sampler = data['sampler'] scheduler = data['scheduler'] + if sampler in SAMPLERS and SAMPLERS[sampler] != '': sampler = SAMPLERS[sampler] if sampler not in CIVITAI_NO_KARRAS and scheduler == 'karras': @@ -451,6 +455,7 @@ class A1111MetadataParser(MetadataParser): self.fooocus_to_a1111['performance']: data['performance'], self.fooocus_to_a1111['scheduler']: scheduler, + self.fooocus_to_a1111['vae']: Path(data['vae']).stem, # workaround for multiline prompts self.fooocus_to_a1111['raw_prompt']: self.raw_prompt, self.fooocus_to_a1111['raw_negative_prompt']: self.raw_negative_prompt, @@ -491,6 +496,14 @@ class A1111MetadataParser(MetadataParser): negative_prompt_text = f"\nNegative prompt: {negative_prompt_resolved}" if negative_prompt_resolved else "" return f"{positive_prompt_resolved}{negative_prompt_text}\n{generation_params_text}".strip() + @staticmethod + def add_extension_to_filename(data, filenames, key): + for filename in filenames: + path = Path(filename) + if data[key] == path.stem: + data[key] = filename + break + class FooocusMetadataParser(MetadataParser): def get_scheme(self) -> MetadataScheme: @@ -499,6 +512,7 @@ class FooocusMetadataParser(MetadataParser): def parse_json(self, metadata: dict) -> dict: model_filenames = modules.config.model_filenames.copy() lora_filenames = modules.config.lora_filenames.copy() + vae_filenames = modules.config.vae_filenames.copy() self.remove_special_loras(lora_filenames) for key, value in metadata.items(): if value in ['', 'None']: @@ -507,6 +521,8 @@ class FooocusMetadataParser(MetadataParser): metadata[key] = self.replace_value_with_filename(key, value, model_filenames) elif key.startswith('lora_combined_'): metadata[key] = self.replace_value_with_filename(key, value, lora_filenames) + elif key == 'vae': + metadata[key] = self.replace_value_with_filename(key, value, vae_filenames) else: continue @@ -533,6 +549,7 @@ class FooocusMetadataParser(MetadataParser): res['refiner_model'] = self.refiner_model_name res['refiner_model_hash'] = self.refiner_model_hash + res['vae'] = self.vae_name res['loras'] = self.loras if modules.config.metadata_created_by != '': diff --git a/modules/util.py b/modules/util.py index 9e0fb294..d2feecb6 100644 --- a/modules/util.py +++ b/modules/util.py @@ -371,6 +371,9 @@ def is_json(data: str) -> bool: def get_file_from_folder_list(name, folders): + if not isinstance(folders, list): + folders = [folders] + for folder in folders: filename = os.path.abspath(os.path.realpath(os.path.join(folder, name))) if os.path.isfile(filename): diff --git a/webui.py b/webui.py index ababb8b0..eec6054a 100644 --- a/webui.py +++ b/webui.py @@ -407,6 +407,8 @@ with shared.gradio_root: value=modules.config.default_sampler) scheduler_name = gr.Dropdown(label='Scheduler', choices=flags.scheduler_list, value=modules.config.default_scheduler) + vae_name = gr.Dropdown(label='VAE', choices=[modules.flags.default_vae] + modules.config.vae_filenames, + value=modules.config.default_vae, show_label=True) generate_image_grid = gr.Checkbox(label='Generate Image Grid for Each Batch', info='(Experimental) This may cause performance problems on some computers and certain internet conditions.', @@ -529,6 +531,7 @@ with shared.gradio_root: modules.config.update_files() results = [gr.update(choices=modules.config.model_filenames)] results += [gr.update(choices=['None'] + modules.config.model_filenames)] + results += [gr.update(choices=['None'] + modules.config.vae_filenames)] if not args_manager.args.disable_preset_selection: results += [gr.update(choices=modules.config.available_presets)] for i in range(modules.config.default_max_lora_number): @@ -536,7 +539,7 @@ with shared.gradio_root: gr.update(choices=['None'] + modules.config.lora_filenames), gr.update()] return results - refresh_files_output = [base_model, refiner_model] + refresh_files_output = [base_model, refiner_model, vae_name] if not args_manager.args.disable_preset_selection: refresh_files_output += [preset_selection] refresh_files.click(refresh_files_clicked, [], refresh_files_output + lora_ctrls, @@ -548,8 +551,8 @@ with shared.gradio_root: performance_selection, overwrite_step, overwrite_switch, aspect_ratios_selection, overwrite_width, overwrite_height, guidance_scale, sharpness, adm_scaler_positive, adm_scaler_negative, adm_scaler_end, refiner_swap_method, adaptive_cfg, base_model, - refiner_model, refiner_switch, sampler_name, scheduler_name, seed_random, image_seed, - generate_button, load_parameter_button] + freeu_ctrls + lora_ctrls + refiner_model, refiner_switch, sampler_name, scheduler_name, vae_name, seed_random, + image_seed, generate_button, load_parameter_button] + freeu_ctrls + lora_ctrls if not args_manager.args.disable_preset_selection: def preset_selection_change(preset, is_generating): @@ -635,7 +638,7 @@ with shared.gradio_root: ctrls += [outpaint_selections, inpaint_input_image, inpaint_additional_prompt, inpaint_mask_image] ctrls += [disable_preview, disable_intermediate_results, disable_seed_increment] ctrls += [adm_scaler_positive, adm_scaler_negative, adm_scaler_end, adaptive_cfg] - ctrls += [sampler_name, scheduler_name] + ctrls += [sampler_name, scheduler_name, vae_name] ctrls += [overwrite_step, overwrite_switch, overwrite_width, overwrite_height, overwrite_vary_strength] ctrls += [overwrite_upscale_strength, mixing_image_prompt_and_vary_upscale, mixing_image_prompt_and_inpaint] ctrls += [debugging_cn_preprocessor, skipping_cn_preprocessor, canny_low_threshold, canny_high_threshold] From f54364fe4ebd737349611c1d040703b0ac7ace68 Mon Sep 17 00:00:00 2001 From: Manuel Schmid <9307310+mashb1t@users.noreply.github.com> Date: Thu, 9 May 2024 19:02:04 +0200 Subject: [PATCH 2/2] feat: add random style checkbox to styles selection (#2855) * feat: add random style * feat: rename random to random style, add translation * feat: add preview image for random style --- language/en.json | 1 + modules/async_worker.py | 11 ++++++++--- modules/sdxl_styles.py | 10 ++++++++-- sdxl_styles/samples/random_style.jpg | Bin 0 -> 1454 bytes 4 files changed, 17 insertions(+), 5 deletions(-) create mode 100644 sdxl_styles/samples/random_style.jpg diff --git a/language/en.json b/language/en.json index 1fe78662..20189b28 100644 --- a/language/en.json +++ b/language/en.json @@ -58,6 +58,7 @@ "\ud83d\udcda History Log": "\uD83D\uDCDA History Log", "Image Style": "Image Style", "Fooocus V2": "Fooocus V2", + "Random Style": "Random Style", "Default (Slightly Cinematic)": "Default (Slightly Cinematic)", "Fooocus Masterpiece": "Fooocus Masterpiece", "Fooocus Photograph": "Fooocus Photograph", diff --git a/modules/async_worker.py b/modules/async_worker.py index 3576c4ec..432bfe9b 100644 --- a/modules/async_worker.py +++ b/modules/async_worker.py @@ -43,7 +43,7 @@ def worker(): import fooocus_version import args_manager - from modules.sdxl_styles import apply_style, apply_wildcards, fooocus_expansion, apply_arrays + from modules.sdxl_styles import apply_style, get_random_style, apply_wildcards, fooocus_expansion, apply_arrays, random_style_name from modules.private_logger import log from extras.expansion import safe_str from modules.util import remove_empty_str, HWC3, resize_image, get_image_shape_ceil, set_image_shape_ceil, \ @@ -450,8 +450,12 @@ def worker(): positive_basic_workloads = [] negative_basic_workloads = [] + task_styles = style_selections.copy() if use_style: - for s in style_selections: + for i, s in enumerate(task_styles): + if s == random_style_name: + s = get_random_style(task_rng) + task_styles[i] = s p, n = apply_style(s, positive=task_prompt) positive_basic_workloads = positive_basic_workloads + p negative_basic_workloads = negative_basic_workloads + n @@ -479,6 +483,7 @@ def worker(): negative_top_k=len(negative_basic_workloads), log_positive_prompt='\n'.join([task_prompt] + task_extra_positive_prompts), log_negative_prompt='\n'.join([task_negative_prompt] + task_extra_negative_prompts), + styles=task_styles )) if use_expansion: @@ -843,7 +848,7 @@ def worker(): d = [('Prompt', 'prompt', task['log_positive_prompt']), ('Negative Prompt', 'negative_prompt', task['log_negative_prompt']), ('Fooocus V2 Expansion', 'prompt_expansion', task['expansion']), - ('Styles', 'styles', str(raw_style_selections)), + ('Styles', 'styles', str(task['styles'] if not use_expansion else [fooocus_expansion] + task['styles'])), ('Performance', 'performance', performance_selection.value)] if performance_selection.steps() != steps: diff --git a/modules/sdxl_styles.py b/modules/sdxl_styles.py index 77ad6b57..5b6afb59 100644 --- a/modules/sdxl_styles.py +++ b/modules/sdxl_styles.py @@ -5,6 +5,7 @@ import math import modules.config from modules.util import get_files_from_folder +from random import Random # cannot use modules.config - validators causing circular imports styles_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '../sdxl_styles/')) @@ -50,8 +51,13 @@ for styles_file in styles_files: print(f'Failed to load style file {styles_file}') style_keys = list(styles.keys()) -fooocus_expansion = "Fooocus V2" -legal_style_names = [fooocus_expansion] + style_keys +fooocus_expansion = 'Fooocus V2' +random_style_name = 'Random Style' +legal_style_names = [fooocus_expansion, random_style_name] + style_keys + + +def get_random_style(rng: Random) -> str: + return rng.choice(list(styles.items()))[0] def apply_style(style, positive): diff --git a/sdxl_styles/samples/random_style.jpg b/sdxl_styles/samples/random_style.jpg new file mode 100644 index 0000000000000000000000000000000000000000..9f685108fdcf78409e488d79cc2c245fec3ad06e GIT binary patch literal 1454 zcmex=ma3|jiIItm zOAI4iKMQ#V{6EAX$idLS(7?>7#K0uT$SlbC{|JLD$RAA1j3D1Y0V4+|6B|1#Gt2*5 z3>*;g^Do3AxHTA*=9V*ky#3T}oy+sk*U|b%zddvNE zjs=&Z%SwbAm^4hL-Cdy6?3sIuNrOSPVbg*AKu3xoxen?`u-ljfU~UG|c+@eYsWTH` zU}R=uW@TexVTYK&%*e#T%ErJh$RQx45GbsuWN1_<;^5fmRI=%^sH#D5V)7PfXfWC{ zJi9LSIQG`#nC#y2^(XXI?^Z;Avz7RIyRSOgMZUH3nbY2yHnV@~>-68Qc%R*KyfZ7) zap|iSSJV1uxP7e@zxd`%$m*@TPR3;IO)7b0SsH$7g0p4a`_TKo>#n^xnz>T6^Mc3K ztMX!wOizAq3>M?NQoeOZg3GQi>w~_!6zzBOk*?a%`TAR(&gnz<_Gv6wcf0tunx=T` z%O?(dZH~FsiK}b9pKm3r#&c_~!&@GY2E}vXzak8`#ytJ4lx+H0Z{^JEF0-~w-&1nH z^3bxwq3;;%3;*TE)vP$PYgxoy!BrRMIR*1Q>srbyzT?5)W0e~V{pRwmWD@n8S2jy# z1J46(TYX9CBiGe;=_|dzwdlF7xSq3n*Vn}C;+vBH8K!e-#_!fK`c;=2yJ78)lz`?| ze{t=l z^L=YTQN$6Mh5BKF!Xd}fpMLqXb$8nBVvc3(TyipVP2N5hzVyE2V&Ls<;k$ezPEQVe zvNtYE<@+)Dr)TD_wiK_snYC6fhEx2&>ah5q-*a*p*7O{n@l9!p(=tx+#azom|H|x6 zsZsxE&EM?b{Uh`5`ndBQHPYKm-dgh<`FCAb+rO57ul>^U``=1iIVS5>-B8}U=21#W zk)3Ci!jXBZ$1g_pebA?)`tFHA<+OT!! z`3ZKb7?n6@U*NcJ!mve9i6MdG=ACOixo%rmzUka+u;+JJTHvgTpV{tK z8+;Rk|$3d_P;TUH576 z^`5UqXIFY^hcCGLswaPQ^rE~Idp0gh{kqg|PDW4Z;ovSy;U#-*lAV@IG47Q*SNi<( z`#T-IE-t=;!Y+O(KUpB~{xXweGcK3SxGbe_56aGhh(ZaHO%ov5o)MS{L4_AC1;FAA Hq~In1cXT