From 5ada070d88f0527568d2c8c3ac77d3e12d77997b Mon Sep 17 00:00:00 2001 From: delta_lt_0 Date: Sat, 6 Apr 2024 21:25:19 +0800 Subject: [PATCH 01/13] feat: support download of huggingface files from a mirror website (#2637) * fix: load image number from preset (#2611) * fix: add default_image_number to preset handling * fix: use minimum image number of preset and config to prevent UI overflow * fix: use correct base dimensions for outpaint mask padding (#2612) * fix: add Civitai compatibility for LoRAs in a1111 metadata scheme by switching schema (#2615) * feat: update sha256 generation functions https://github.com/lllyasviel/stable-diffusion-webui-forge/blob/29be1da7cf2b5dccfc70fbdd33eb35c56a31ffb7/modules/hashes.py * feat: add compatibility for LoRAs in a1111 metadata scheme * feat: add backwards compatibility * refactor: extract remove_special_loras * fix: correctly apply LoRA weight for legacy schema * docs: bump version number to 2.3.1, add changelog (#2616) * feat:support download huggingface files from a mirror site --------- Co-authored-by: Manuel Schmid <9307310+mashb1t@users.noreply.github.com> --- docker.md | 1 + fooocus_version.py | 2 +- launch.py | 4 ++ ldm_patched/modules/args_parser.py | 1 + modules/async_worker.py | 8 ++-- modules/config.py | 2 + modules/meta_parser.py | 61 ++++++++++++++++++++++-------- modules/model_loader.py | 2 + modules/util.py | 38 ++++++++++++++++--- readme.md | 1 + update_log.md | 7 ++++ 11 files changed, 101 insertions(+), 26 deletions(-) diff --git a/docker.md b/docker.md index 36cfa632..1939d6fc 100644 --- a/docker.md +++ b/docker.md @@ -54,6 +54,7 @@ Docker specified environments are there. They are used by 'entrypoint.sh' |CMDARGS|Arguments for [entry_with_update.py](entry_with_update.py) which is called by [entrypoint.sh](entrypoint.sh)| |config_path|'config.txt' location| |config_example_path|'config_modification_tutorial.txt' location| +|HF_MIRROR| huggingface mirror site domain| You can also use the same json key names and values explained in the 'config_modification_tutorial.txt' as the environments. See examples in the [docker-compose.yml](docker-compose.yml) diff --git a/fooocus_version.py b/fooocus_version.py index a4b8895b..b2050196 100644 --- a/fooocus_version.py +++ b/fooocus_version.py @@ -1 +1 @@ -version = '2.3.0' +version = '2.3.1' diff --git a/launch.py b/launch.py index afa66705..5c865e6d 100644 --- a/launch.py +++ b/launch.py @@ -80,6 +80,10 @@ if args.gpu_device_id is not None: os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_device_id) print("Set device to:", args.gpu_device_id) +if args.hf_mirror is not None : + os.environ['HF_MIRROR'] = str(args.hf_mirror) + print("Set hf_mirror to:", args.hf_mirror) + from modules import config os.environ['GRADIO_TEMP_DIR'] = config.temp_path diff --git a/ldm_patched/modules/args_parser.py b/ldm_patched/modules/args_parser.py index 0c6165a7..bf873783 100644 --- a/ldm_patched/modules/args_parser.py +++ b/ldm_patched/modules/args_parser.py @@ -37,6 +37,7 @@ parser.add_argument("--listen", type=str, default="127.0.0.1", metavar="IP", nar parser.add_argument("--port", type=int, default=8188) parser.add_argument("--disable-header-check", type=str, default=None, metavar="ORIGIN", nargs="?", const="*") parser.add_argument("--web-upload-size", type=float, default=100) +parser.add_argument("--hf-mirror", type=str, default=None) parser.add_argument("--external-working-path", type=str, default=None, metavar="PATH", nargs='+', action='append') parser.add_argument("--output-path", type=str, default=None) diff --git a/modules/async_worker.py b/modules/async_worker.py index fa959361..d8a1e072 100644 --- a/modules/async_worker.py +++ b/modules/async_worker.py @@ -614,12 +614,12 @@ def worker(): H, W, C = inpaint_image.shape if 'left' in outpaint_selections: - inpaint_image = np.pad(inpaint_image, [[0, 0], [int(H * 0.3), 0], [0, 0]], mode='edge') - inpaint_mask = np.pad(inpaint_mask, [[0, 0], [int(H * 0.3), 0]], mode='constant', + inpaint_image = np.pad(inpaint_image, [[0, 0], [int(W * 0.3), 0], [0, 0]], mode='edge') + inpaint_mask = np.pad(inpaint_mask, [[0, 0], [int(W * 0.3), 0]], mode='constant', constant_values=255) if 'right' in outpaint_selections: - inpaint_image = np.pad(inpaint_image, [[0, 0], [0, int(H * 0.3)], [0, 0]], mode='edge') - inpaint_mask = np.pad(inpaint_mask, [[0, 0], [0, int(H * 0.3)]], mode='constant', + inpaint_image = np.pad(inpaint_image, [[0, 0], [0, int(W * 0.3)], [0, 0]], mode='edge') + inpaint_mask = np.pad(inpaint_mask, [[0, 0], [0, int(W * 0.3)]], mode='constant', constant_values=255) inpaint_image = np.ascontiguousarray(inpaint_image.copy()) diff --git a/modules/config.py b/modules/config.py index 76ffd348..b81e218a 100644 --- a/modules/config.py +++ b/modules/config.py @@ -485,6 +485,7 @@ possible_preset_keys = { "default_scheduler": "scheduler", "default_overwrite_step": "steps", "default_performance": "performance", + "default_image_number": "image_number", "default_prompt": "prompt", "default_prompt_negative": "negative_prompt", "default_styles": "styles", @@ -538,6 +539,7 @@ wildcard_filenames = [] sdxl_lcm_lora = 'sdxl_lcm_lora.safetensors' sdxl_lightning_lora = 'sdxl_lightning_4step_lora.safetensors' +loras_metadata_remove = [sdxl_lcm_lora, sdxl_lightning_lora] def get_model_filenames(folder_paths, extensions=None, name_filter=None): diff --git a/modules/meta_parser.py b/modules/meta_parser.py index 10bc6896..70ab8860 100644 --- a/modules/meta_parser.py +++ b/modules/meta_parser.py @@ -1,5 +1,4 @@ import json -import os import re from abc import ABC, abstractmethod from pathlib import Path @@ -12,7 +11,7 @@ import modules.config import modules.sdxl_styles from modules.flags import MetadataScheme, Performance, Steps from modules.flags import SAMPLERS, CIVITAI_NO_KARRAS -from modules.util import quote, unquote, extract_styles_from_prompt, is_json, get_file_from_folder_list, calculate_sha256 +from modules.util import quote, unquote, extract_styles_from_prompt, is_json, get_file_from_folder_list, sha256 re_param_code = r'\s*(\w[\w \-/]+):\s*("(?:\\.|[^\\"])+"|[^,]*)(?:,|$)' re_param = re.compile(re_param_code) @@ -27,8 +26,9 @@ def load_parameter_button_click(raw_metadata: dict | str, is_generating: bool): loaded_parameter_dict = json.loads(raw_metadata) assert isinstance(loaded_parameter_dict, dict) - results = [len(loaded_parameter_dict) > 0, 1] + results = [len(loaded_parameter_dict) > 0] + get_image_number('image_number', 'Image Number', loaded_parameter_dict, results) get_str('prompt', 'Prompt', loaded_parameter_dict, results) get_str('negative_prompt', 'Negative Prompt', loaded_parameter_dict, results) get_list('styles', 'Styles', loaded_parameter_dict, results) @@ -92,13 +92,25 @@ def get_float(key: str, fallback: str | None, source_dict: dict, results: list, results.append(gr.update()) +def get_image_number(key: str, fallback: str | None, source_dict: dict, results: list, default=None): + try: + h = source_dict.get(key, source_dict.get(fallback, default)) + assert h is not None + h = int(h) + h = min(h, modules.config.default_max_image_number) + results.append(h) + except: + results.append(1) + + def get_steps(key: str, fallback: str | None, source_dict: dict, results: list, default=None): try: h = source_dict.get(key, source_dict.get(fallback, default)) assert h is not None h = int(h) # if not in steps or in steps and performance is not the same - if h not in iter(Steps) or Steps(h).name.casefold() != source_dict.get('performance', '').replace(' ', '_').casefold(): + if h not in iter(Steps) or Steps(h).name.casefold() != source_dict.get('performance', '').replace(' ', + '_').casefold(): results.append(h) return results.append(-1) @@ -192,7 +204,8 @@ def get_lora(key: str, fallback: str | None, source_dict: dict, results: list): def get_sha256(filepath): global hash_cache if filepath not in hash_cache: - hash_cache[filepath] = calculate_sha256(filepath) + # is_safetensors = os.path.splitext(filepath)[1].lower() == '.safetensors' + hash_cache[filepath] = sha256(filepath) return hash_cache[filepath] @@ -219,8 +232,9 @@ def parse_meta_from_preset(preset_content): height = height[:height.index(" ")] preset_prepared[meta_key] = (width, height) else: - preset_prepared[meta_key] = items[settings_key] if settings_key in items and items[settings_key] is not None else getattr(modules.config, settings_key) - + preset_prepared[meta_key] = items[settings_key] if settings_key in items and items[ + settings_key] is not None else getattr(modules.config, settings_key) + if settings_key == "default_styles" or settings_key == "default_aspect_ratio": preset_prepared[meta_key] = str(preset_prepared[meta_key]) @@ -276,6 +290,12 @@ class MetadataParser(ABC): lora_hash = get_sha256(lora_path) self.loras.append((Path(lora_name).stem, lora_weight, lora_hash)) + @staticmethod + def remove_special_loras(lora_filenames): + for lora_to_remove in modules.config.loras_metadata_remove: + if lora_to_remove in lora_filenames: + lora_filenames.remove(lora_to_remove) + class A1111MetadataParser(MetadataParser): def get_scheme(self) -> MetadataScheme: @@ -385,12 +405,19 @@ class A1111MetadataParser(MetadataParser): data[key] = filename break - if 'lora_hashes' in data and data['lora_hashes'] != '': + lora_data = '' + if 'lora_weights' in data and data['lora_weights'] != '': + lora_data = data['lora_weights'] + elif 'lora_hashes' in data and data['lora_hashes'] != '' and data['lora_hashes'].split(', ')[0].count(':') == 2: + lora_data = data['lora_hashes'] + + if lora_data != '': lora_filenames = modules.config.lora_filenames.copy() - if modules.config.sdxl_lcm_lora in lora_filenames: - lora_filenames.remove(modules.config.sdxl_lcm_lora) - for li, lora in enumerate(data['lora_hashes'].split(', ')): - lora_name, lora_hash, lora_weight = lora.split(': ') + self.remove_special_loras(lora_filenames) + for li, lora in enumerate(lora_data.split(', ')): + lora_split = lora.split(': ') + lora_name = lora_split[0] + lora_weight = lora_split[2] if len(lora_split) == 3 else lora_split[1] for filename in lora_filenames: path = Path(filename) if lora_name == path.stem: @@ -441,11 +468,15 @@ class A1111MetadataParser(MetadataParser): if len(self.loras) > 0: lora_hashes = [] + lora_weights = [] for index, (lora_name, lora_weight, lora_hash) in enumerate(self.loras): # workaround for Fooocus not knowing LoRA name in LoRA metadata - lora_hashes.append(f'{lora_name}: {lora_hash}: {lora_weight}') + lora_hashes.append(f'{lora_name}: {lora_hash}') + lora_weights.append(f'{lora_name}: {lora_weight}') lora_hashes_string = ', '.join(lora_hashes) + lora_weights_string = ', '.join(lora_weights) generation_params[self.fooocus_to_a1111['lora_hashes']] = lora_hashes_string + generation_params[self.fooocus_to_a1111['lora_weights']] = lora_weights_string generation_params[self.fooocus_to_a1111['version']] = data['version'] @@ -468,9 +499,7 @@ class FooocusMetadataParser(MetadataParser): def parse_json(self, metadata: dict) -> dict: model_filenames = modules.config.model_filenames.copy() lora_filenames = modules.config.lora_filenames.copy() - if modules.config.sdxl_lcm_lora in lora_filenames: - lora_filenames.remove(modules.config.sdxl_lcm_lora) - + self.remove_special_loras(lora_filenames) for key, value in metadata.items(): if value in ['', 'None']: continue diff --git a/modules/model_loader.py b/modules/model_loader.py index 8ba336a9..1143f75e 100644 --- a/modules/model_loader.py +++ b/modules/model_loader.py @@ -14,6 +14,8 @@ def load_file_from_url( Returns the path to the downloaded file. """ + domain = os.environ.get("HF_MIRROR", "https://huggingface.co").rstrip('/') + url = str.replace(url, "https://huggingface.co", domain, 1) os.makedirs(model_dir, exist_ok=True) if not file_name: parts = urlparse(url) diff --git a/modules/util.py b/modules/util.py index 7c46d946..9e0fb294 100644 --- a/modules/util.py +++ b/modules/util.py @@ -7,9 +7,9 @@ import math import os import cv2 import json +import hashlib from PIL import Image -from hashlib import sha256 import modules.sdxl_styles @@ -182,16 +182,44 @@ def get_files_from_folder(folder_path, extensions=None, name_filter=None): return filenames -def calculate_sha256(filename, length=HASH_SHA256_LENGTH) -> str: - hash_sha256 = sha256() +def sha256(filename, use_addnet_hash=False, length=HASH_SHA256_LENGTH): + print(f"Calculating sha256 for {filename}: ", end='') + if use_addnet_hash: + with open(filename, "rb") as file: + sha256_value = addnet_hash_safetensors(file) + else: + sha256_value = calculate_sha256(filename) + print(f"{sha256_value}") + + return sha256_value[:length] if length is not None else sha256_value + + +def addnet_hash_safetensors(b): + """kohya-ss hash for safetensors from https://github.com/kohya-ss/sd-scripts/blob/main/library/train_util.py""" + hash_sha256 = hashlib.sha256() + blksize = 1024 * 1024 + + b.seek(0) + header = b.read(8) + n = int.from_bytes(header, "little") + + offset = n + 8 + b.seek(offset) + for chunk in iter(lambda: b.read(blksize), b""): + hash_sha256.update(chunk) + + return hash_sha256.hexdigest() + + +def calculate_sha256(filename) -> str: + hash_sha256 = hashlib.sha256() blksize = 1024 * 1024 with open(filename, "rb") as f: for chunk in iter(lambda: f.read(blksize), b""): hash_sha256.update(chunk) - res = hash_sha256.hexdigest() - return res[:length] if length else res + return hash_sha256.hexdigest() def quote(text): diff --git a/readme.md b/readme.md index 5f66e02a..0ec06f19 100644 --- a/readme.md +++ b/readme.md @@ -368,6 +368,7 @@ A safer way is just to try "run_anime.bat" or "run_realistic.bat" - they should entry_with_update.py [-h] [--listen [IP]] [--port PORT] [--disable-header-check [ORIGIN]] [--web-upload-size WEB_UPLOAD_SIZE] + [--hf-mirror HF_MIRROR] [--external-working-path PATH [PATH ...]] [--output-path OUTPUT_PATH] [--temp-path TEMP_PATH] [--cache-path CACHE_PATH] [--in-browser] diff --git a/update_log.md b/update_log.md index 4e22db0a..62c4882b 100644 --- a/update_log.md +++ b/update_log.md @@ -1,3 +1,10 @@ +# [2.3.1](https://github.com/lllyasviel/Fooocus/releases/tag/2.3.1) + +* Remove positive prompt from anime prefix to not reset prompt after switching presets +* Fix image number being reset to 1 when switching preset, now doesn't reset anymore +* Fix outpainting dimension calculation when extending left/right +* Fix LoRA compatibility for LoRAs in a1111 metadata scheme + # [2.3.0](https://github.com/lllyasviel/Fooocus/releases/tag/2.3.0) * Add performance "lightning" (based on [SDXL-Lightning 4 step LoRA](https://huggingface.co/ByteDance/SDXL-Lightning/blob/main/sdxl_lightning_4step_lora.safetensors)) From 1dff430d4c089fb3bee6287f9371d0926352fb54 Mon Sep 17 00:00:00 2001 From: Manuel Schmid <9307310+mashb1t@users.noreply.github.com> Date: Sat, 6 Apr 2024 15:27:35 +0200 Subject: [PATCH 02/13] feat: update interposer from v3.1 to v4.0 (#2717) * fix: load image number from preset (#2611) * fix: add default_image_number to preset handling * fix: use minimum image number of preset and config to prevent UI overflow * fix: use correct base dimensions for outpaint mask padding (#2612) * fix: add Civitai compatibility for LoRAs in a1111 metadata scheme by switching schema (#2615) * feat: update sha256 generation functions https://github.com/lllyasviel/stable-diffusion-webui-forge/blob/29be1da7cf2b5dccfc70fbdd33eb35c56a31ffb7/modules/hashes.py * feat: add compatibility for LoRAs in a1111 metadata scheme * feat: add backwards compatibility * refactor: extract remove_special_loras * fix: correctly apply LoRA weight for legacy schema * docs: bump version number to 2.3.1, add changelog (#2616) * feat: update interposer vrom v3.1 to v4.0 --- extras/vae_interpose.py | 98 ++++++++++++++++++++++++----------------- launch.py | 4 +- 2 files changed, 59 insertions(+), 43 deletions(-) diff --git a/extras/vae_interpose.py b/extras/vae_interpose.py index 72fb09a4..d407ca83 100644 --- a/extras/vae_interpose.py +++ b/extras/vae_interpose.py @@ -1,69 +1,85 @@ # https://github.com/city96/SD-Latent-Interposer/blob/main/interposer.py import os -import torch -import safetensors.torch as sf -import torch.nn as nn -import ldm_patched.modules.model_management +import safetensors.torch as sf +import torch +import torch.nn as nn + +import ldm_patched.modules.model_management from ldm_patched.modules.model_patcher import ModelPatcher from modules.config import path_vae_approx -class Block(nn.Module): - def __init__(self, size): +class ResBlock(nn.Module): + """Block with residuals""" + + def __init__(self, ch): super().__init__() self.join = nn.ReLU() + self.norm = nn.BatchNorm2d(ch) self.long = nn.Sequential( - nn.Conv2d(size, size, kernel_size=3, stride=1, padding=1), - nn.LeakyReLU(0.1), - nn.Conv2d(size, size, kernel_size=3, stride=1, padding=1), - nn.LeakyReLU(0.1), - nn.Conv2d(size, size, kernel_size=3, stride=1, padding=1), + nn.Conv2d(ch, ch, kernel_size=3, stride=1, padding=1), + nn.SiLU(), + nn.Conv2d(ch, ch, kernel_size=3, stride=1, padding=1), + nn.SiLU(), + nn.Conv2d(ch, ch, kernel_size=3, stride=1, padding=1), + nn.Dropout(0.1) ) def forward(self, x): - y = self.long(x) - z = self.join(y + x) - return z + x = self.norm(x) + return self.join(self.long(x) + x) -class Interposer(nn.Module): - def __init__(self): +class ExtractBlock(nn.Module): + """Increase no. of channels by [out/in]""" + + def __init__(self, ch_in, ch_out): super().__init__() - self.chan = 4 - self.hid = 128 - - self.head_join = nn.ReLU() - self.head_short = nn.Conv2d(self.chan, self.hid, kernel_size=3, stride=1, padding=1) - self.head_long = nn.Sequential( - nn.Conv2d(self.chan, self.hid, kernel_size=3, stride=1, padding=1), - nn.LeakyReLU(0.1), - nn.Conv2d(self.hid, self.hid, kernel_size=3, stride=1, padding=1), - nn.LeakyReLU(0.1), - nn.Conv2d(self.hid, self.hid, kernel_size=3, stride=1, padding=1), - ) - self.core = nn.Sequential( - Block(self.hid), - Block(self.hid), - Block(self.hid), - ) - self.tail = nn.Sequential( - nn.ReLU(), - nn.Conv2d(self.hid, self.chan, kernel_size=3, stride=1, padding=1) + self.join = nn.ReLU() + self.short = nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1) + self.long = nn.Sequential( + nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1), + nn.SiLU(), + nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1), + nn.SiLU(), + nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1), + nn.Dropout(0.1) ) def forward(self, x): - y = self.head_join( - self.head_long(x) + - self.head_short(x) + return self.join(self.long(x) + self.short(x)) + + +class InterposerModel(nn.Module): + """Main neural network""" + + def __init__(self, ch_in=4, ch_out=4, ch_mid=64, scale=1.0, blocks=12): + super().__init__() + self.ch_in = ch_in + self.ch_out = ch_out + self.ch_mid = ch_mid + self.blocks = blocks + self.scale = scale + + self.head = ExtractBlock(self.ch_in, self.ch_mid) + self.core = nn.Sequential( + nn.Upsample(scale_factor=self.scale, mode="nearest"), + *[ResBlock(self.ch_mid) for _ in range(blocks)], + nn.BatchNorm2d(self.ch_mid), + nn.SiLU(), ) + self.tail = nn.Conv2d(self.ch_mid, self.ch_out, kernel_size=3, stride=1, padding=1) + + def forward(self, x): + y = self.head(x) z = self.core(y) return self.tail(z) vae_approx_model = None -vae_approx_filename = os.path.join(path_vae_approx, 'xl-to-v1_interposer-v3.1.safetensors') +vae_approx_filename = os.path.join(path_vae_approx, 'xl-to-v1_interposer-v4.0.safetensors') def parse(x): @@ -72,7 +88,7 @@ def parse(x): x_origin = x.clone() if vae_approx_model is None: - model = Interposer() + model = InterposerModel() model.eval() sd = sf.load_file(vae_approx_filename) model.load_state_dict(sd) diff --git a/launch.py b/launch.py index 5c865e6d..5d40cc5b 100644 --- a/launch.py +++ b/launch.py @@ -62,8 +62,8 @@ def prepare_environment(): vae_approx_filenames = [ ('xlvaeapp.pth', 'https://huggingface.co/lllyasviel/misc/resolve/main/xlvaeapp.pth'), ('vaeapp_sd15.pth', 'https://huggingface.co/lllyasviel/misc/resolve/main/vaeapp_sd15.pt'), - ('xl-to-v1_interposer-v3.1.safetensors', - 'https://huggingface.co/lllyasviel/misc/resolve/main/xl-to-v1_interposer-v3.1.safetensors') + ('xl-to-v1_interposer-v4.0.safetensors', + 'https://huggingface.co/mashb1t/misc/resolve/main/xl-to-v1_interposer-v4.0.safetensors') ] From dbf49d323eca159499f23b2c055244144ca8fade Mon Sep 17 00:00:00 2001 From: Manuel Schmid <9307310+mashb1t@users.noreply.github.com> Date: Wed, 17 Apr 2024 22:23:18 +0200 Subject: [PATCH 03/13] feat: add button to reconnect UI without having to reload the page (#2727) * feat: add button to reconnect UI without having to reload the page * qa: add missing semicolon --- javascript/script.js | 37 +++++++++++++++++++++++++++++++++++++ language/en.json | 1 + webui.py | 11 ++++++++++- 3 files changed, 48 insertions(+), 1 deletion(-) diff --git a/javascript/script.js b/javascript/script.js index 9aa0b5c1..d379a783 100644 --- a/javascript/script.js +++ b/javascript/script.js @@ -122,6 +122,43 @@ document.addEventListener("DOMContentLoaded", function() { initStylePreviewOverlay(); }); +var onAppend = function(elem, f) { + var observer = new MutationObserver(function(mutations) { + mutations.forEach(function(m) { + if (m.addedNodes.length) { + f(m.addedNodes); + } + }); + }); + observer.observe(elem, {childList: true}); +} + +function addObserverIfDesiredNodeAvailable(querySelector, callback) { + var elem = document.querySelector(querySelector); + if (!elem) { + window.setTimeout(() => addObserverIfDesiredNodeAvailable(querySelector, callback), 1000); + return; + } + + onAppend(elem, callback); +} + +/** + * Show reset button on toast "Connection errored out." + */ +addObserverIfDesiredNodeAvailable(".toast-wrap", function(added) { + added.forEach(function(element) { + if (element.innerText.includes("Connection errored out.")) { + window.setTimeout(function() { + document.getElementById("reset_button").classList.remove("hidden"); + document.getElementById("generate_button").classList.add("hidden"); + document.getElementById("skip_button").classList.add("hidden"); + document.getElementById("stop_button").classList.add("hidden"); + }); + } + }); +}); + /** * Add a ctrl+enter as a shortcut to start a generation */ diff --git a/language/en.json b/language/en.json index fefc79c4..d10c29dc 100644 --- a/language/en.json +++ b/language/en.json @@ -4,6 +4,7 @@ "Generate": "Generate", "Skip": "Skip", "Stop": "Stop", + "Reconnect and Reset UI": "Reconnect and Reset UI", "Input Image": "Input Image", "Advanced": "Advanced", "Upscale or Variation": "Upscale or Variation", diff --git a/webui.py b/webui.py index 98780bff..ababb8b0 100644 --- a/webui.py +++ b/webui.py @@ -123,8 +123,9 @@ with shared.gradio_root: with gr.Column(scale=3, min_width=0): generate_button = gr.Button(label="Generate", value="Generate", elem_classes='type_row', elem_id='generate_button', visible=True) + reset_button = gr.Button(label="Reconnect and Reset UI", value="Reconnect and Reset UI", elem_classes='type_row', elem_id='reset_button', visible=False) load_parameter_button = gr.Button(label="Load Parameters", value="Load Parameters", elem_classes='type_row', elem_id='load_parameter_button', visible=False) - skip_button = gr.Button(label="Skip", value="Skip", elem_classes='type_row_half', visible=False) + skip_button = gr.Button(label="Skip", value="Skip", elem_classes='type_row_half', elem_id='skip_button', visible=False) stop_button = gr.Button(label="Stop", value="Stop", elem_classes='type_row_half', elem_id='stop_button', visible=False) def stop_clicked(currentTask): @@ -688,6 +689,14 @@ with shared.gradio_root: .then(fn=update_history_link, outputs=history_link) \ .then(fn=lambda: None, _js='playNotification').then(fn=lambda: None, _js='refresh_grid_delayed') + reset_button.click(lambda: [worker.AsyncTask(args=[]), False, gr.update(visible=True, interactive=True)] + + [gr.update(visible=False)] * 6 + + [gr.update(visible=True, value=[])], + outputs=[currentTask, state_is_generating, generate_button, + reset_button, stop_button, skip_button, + progress_html, progress_window, progress_gallery, gallery], + queue=False) + for notification_file in ['notification.ogg', 'notification.mp3']: if os.path.exists(notification_file): gr.Audio(interactive=False, value=notification_file, elem_id='audio_notification', visible=False) From c32bc5e199f7a0a45736f10c248cd1955433a609 Mon Sep 17 00:00:00 2001 From: Manuel Schmid <9307310+mashb1t@users.noreply.github.com> Date: Thu, 9 May 2024 18:59:35 +0200 Subject: [PATCH 04/13] feat: add optional model VAE select (#2867) * Revert "fix: use LF as line breaks for Docker entrypoint.sh (#2843)" (#2865) False alarm, worked as intended before. Sorry for the fuzz. This reverts commit d16a54edd69f82158ae7ffe5669618db33a01ac7. * feat: add VAE select * feat: use different default label, add translation * fix: do not reload model when VAE stays the same * refactor: code cleanup * feat: add metadata handling --- language/en.json | 2 ++ ldm_patched/modules/sd.py | 13 +++++++++---- modules/async_worker.py | 6 ++++-- modules/config.py | 14 +++++++++++++- modules/core.py | 10 ++++++---- modules/default_pipeline.py | 22 ++++++++++++++-------- modules/flags.py | 2 ++ modules/meta_parser.py | 31 ++++++++++++++++++++++++------- modules/util.py | 3 +++ webui.py | 11 +++++++---- 10 files changed, 84 insertions(+), 30 deletions(-) diff --git a/language/en.json b/language/en.json index d10c29dc..1fe78662 100644 --- a/language/en.json +++ b/language/en.json @@ -340,6 +340,8 @@ "sgm_uniform": "sgm_uniform", "simple": "simple", "ddim_uniform": "ddim_uniform", + "VAE": "VAE", + "Default (model)": "Default (model)", "Forced Overwrite of Sampling Step": "Forced Overwrite of Sampling Step", "Set as -1 to disable. For developer debugging.": "Set as -1 to disable. For developer debugging.", "Forced Overwrite of Refiner Switch Step": "Forced Overwrite of Refiner Switch Step", diff --git a/ldm_patched/modules/sd.py b/ldm_patched/modules/sd.py index e197c39c..282f2559 100644 --- a/ldm_patched/modules/sd.py +++ b/ldm_patched/modules/sd.py @@ -427,12 +427,13 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl return (ldm_patched.modules.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device), clip, vae) -def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True): +def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, vae_filename_param=None): sd = ldm_patched.modules.utils.load_torch_file(ckpt_path) sd_keys = sd.keys() clip = None clipvision = None vae = None + vae_filename = None model = None model_patcher = None clip_target = None @@ -462,8 +463,12 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o model.load_model_weights(sd, "model.diffusion_model.") if output_vae: - vae_sd = ldm_patched.modules.utils.state_dict_prefix_replace(sd, {"first_stage_model.": ""}, filter_keys=True) - vae_sd = model_config.process_vae_state_dict(vae_sd) + if vae_filename_param is None: + vae_sd = ldm_patched.modules.utils.state_dict_prefix_replace(sd, {"first_stage_model.": ""}, filter_keys=True) + vae_sd = model_config.process_vae_state_dict(vae_sd) + else: + vae_sd = ldm_patched.modules.utils.load_torch_file(vae_filename_param) + vae_filename = vae_filename_param vae = VAE(sd=vae_sd) if output_clip: @@ -485,7 +490,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o print("loaded straight to GPU") model_management.load_model_gpu(model_patcher) - return (model_patcher, clip, vae, clipvision) + return model_patcher, clip, vae, vae_filename, clipvision def load_unet_state_dict(sd): #load unet in diffusers format diff --git a/modules/async_worker.py b/modules/async_worker.py index d8a1e072..3576c4ec 100644 --- a/modules/async_worker.py +++ b/modules/async_worker.py @@ -166,6 +166,7 @@ def worker(): adaptive_cfg = args.pop() sampler_name = args.pop() scheduler_name = args.pop() + vae_name = args.pop() overwrite_step = args.pop() overwrite_switch = args.pop() overwrite_width = args.pop() @@ -428,7 +429,7 @@ def worker(): progressbar(async_task, 3, 'Loading models ...') pipeline.refresh_everything(refiner_model_name=refiner_model_name, base_model_name=base_model_name, loras=loras, base_model_additional_loras=base_model_additional_loras, - use_synthetic_refiner=use_synthetic_refiner) + use_synthetic_refiner=use_synthetic_refiner, vae_name=vae_name) progressbar(async_task, 3, 'Processing prompts ...') tasks = [] @@ -869,6 +870,7 @@ def worker(): d.append(('Sampler', 'sampler', sampler_name)) d.append(('Scheduler', 'scheduler', scheduler_name)) + d.append(('VAE', 'vae', vae_name)) d.append(('Seed', 'seed', str(task['task_seed']))) if freeu_enabled: @@ -883,7 +885,7 @@ def worker(): metadata_parser = modules.meta_parser.get_metadata_parser(metadata_scheme) metadata_parser.set_data(task['log_positive_prompt'], task['positive'], task['log_negative_prompt'], task['negative'], - steps, base_model_name, refiner_model_name, loras) + steps, base_model_name, refiner_model_name, loras, vae_name) d.append(('Metadata Scheme', 'metadata_scheme', metadata_scheme.value if save_metadata_to_images else save_metadata_to_images)) d.append(('Version', 'version', 'Fooocus v' + fooocus_version.version)) img_paths.append(log(x, d, metadata_parser, output_format)) diff --git a/modules/config.py b/modules/config.py index b81e218a..f11460c8 100644 --- a/modules/config.py +++ b/modules/config.py @@ -189,6 +189,7 @@ paths_checkpoints = get_dir_or_set_default('path_checkpoints', ['../models/check paths_loras = get_dir_or_set_default('path_loras', ['../models/loras/'], True) path_embeddings = get_dir_or_set_default('path_embeddings', '../models/embeddings/') path_vae_approx = get_dir_or_set_default('path_vae_approx', '../models/vae_approx/') +path_vae = get_dir_or_set_default('path_vae', '../models/vae/') path_upscale_models = get_dir_or_set_default('path_upscale_models', '../models/upscale_models/') path_inpaint = get_dir_or_set_default('path_inpaint', '../models/inpaint/') path_controlnet = get_dir_or_set_default('path_controlnet', '../models/controlnet/') @@ -346,6 +347,11 @@ default_scheduler = get_config_item_or_set_default( default_value='karras', validator=lambda x: x in modules.flags.scheduler_list ) +default_vae = get_config_item_or_set_default( + key='default_vae', + default_value=modules.flags.default_vae, + validator=lambda x: isinstance(x, str) +) default_styles = get_config_item_or_set_default( key='default_styles', default_value=[ @@ -535,6 +541,7 @@ with open(config_example_path, "w", encoding="utf-8") as json_file: model_filenames = [] lora_filenames = [] +vae_filenames = [] wildcard_filenames = [] sdxl_lcm_lora = 'sdxl_lcm_lora.safetensors' @@ -546,15 +553,20 @@ def get_model_filenames(folder_paths, extensions=None, name_filter=None): if extensions is None: extensions = ['.pth', '.ckpt', '.bin', '.safetensors', '.fooocus.patch'] files = [] + + if not isinstance(folder_paths, list): + folder_paths = [folder_paths] for folder in folder_paths: files += get_files_from_folder(folder, extensions, name_filter) + return files def update_files(): - global model_filenames, lora_filenames, wildcard_filenames, available_presets + global model_filenames, lora_filenames, vae_filenames, wildcard_filenames, available_presets model_filenames = get_model_filenames(paths_checkpoints) lora_filenames = get_model_filenames(paths_loras) + vae_filenames = get_model_filenames(path_vae) wildcard_filenames = get_files_from_folder(path_wildcards, ['.txt']) available_presets = get_presets() return diff --git a/modules/core.py b/modules/core.py index 38ee8e8d..3ca4cc5b 100644 --- a/modules/core.py +++ b/modules/core.py @@ -35,12 +35,13 @@ opModelSamplingDiscrete = ModelSamplingDiscrete() class StableDiffusionModel: - def __init__(self, unet=None, vae=None, clip=None, clip_vision=None, filename=None): + def __init__(self, unet=None, vae=None, clip=None, clip_vision=None, filename=None, vae_filename=None): self.unet = unet self.vae = vae self.clip = clip self.clip_vision = clip_vision self.filename = filename + self.vae_filename = vae_filename self.unet_with_lora = unet self.clip_with_lora = clip self.visited_loras = '' @@ -142,9 +143,10 @@ def apply_controlnet(positive, negative, control_net, image, strength, start_per @torch.no_grad() @torch.inference_mode() -def load_model(ckpt_filename): - unet, clip, vae, clip_vision = load_checkpoint_guess_config(ckpt_filename, embedding_directory=path_embeddings) - return StableDiffusionModel(unet=unet, clip=clip, vae=vae, clip_vision=clip_vision, filename=ckpt_filename) +def load_model(ckpt_filename, vae_filename=None): + unet, clip, vae, vae_filename, clip_vision = load_checkpoint_guess_config(ckpt_filename, embedding_directory=path_embeddings, + vae_filename_param=vae_filename) + return StableDiffusionModel(unet=unet, clip=clip, vae=vae, clip_vision=clip_vision, filename=ckpt_filename, vae_filename=vae_filename) @torch.no_grad() diff --git a/modules/default_pipeline.py b/modules/default_pipeline.py index 190601ec..38f914c5 100644 --- a/modules/default_pipeline.py +++ b/modules/default_pipeline.py @@ -3,6 +3,7 @@ import os import torch import modules.patch import modules.config +import modules.flags import ldm_patched.modules.model_management import ldm_patched.modules.latent_formats import modules.inpaint_worker @@ -58,17 +59,21 @@ def assert_model_integrity(): @torch.no_grad() @torch.inference_mode() -def refresh_base_model(name): +def refresh_base_model(name, vae_name=None): global model_base filename = get_file_from_folder_list(name, modules.config.paths_checkpoints) - if model_base.filename == filename: + vae_filename = None + if vae_name is not None and vae_name != modules.flags.default_vae: + vae_filename = get_file_from_folder_list(vae_name, modules.config.path_vae) + + if model_base.filename == filename and model_base.vae_filename == vae_filename: return - model_base = core.StableDiffusionModel() - model_base = core.load_model(filename) + model_base = core.load_model(filename, vae_filename) print(f'Base model loaded: {model_base.filename}') + print(f'VAE loaded: {model_base.vae_filename}') return @@ -216,7 +221,7 @@ def prepare_text_encoder(async_call=True): @torch.no_grad() @torch.inference_mode() def refresh_everything(refiner_model_name, base_model_name, loras, - base_model_additional_loras=None, use_synthetic_refiner=False): + base_model_additional_loras=None, use_synthetic_refiner=False, vae_name=None): global final_unet, final_clip, final_vae, final_refiner_unet, final_refiner_vae, final_expansion final_unet = None @@ -227,11 +232,11 @@ def refresh_everything(refiner_model_name, base_model_name, loras, if use_synthetic_refiner and refiner_model_name == 'None': print('Synthetic Refiner Activated') - refresh_base_model(base_model_name) + refresh_base_model(base_model_name, vae_name) synthesize_refiner_model() else: refresh_refiner_model(refiner_model_name) - refresh_base_model(base_model_name) + refresh_base_model(base_model_name, vae_name) refresh_loras(loras, base_model_additional_loras=base_model_additional_loras) assert_model_integrity() @@ -254,7 +259,8 @@ def refresh_everything(refiner_model_name, base_model_name, loras, refresh_everything( refiner_model_name=modules.config.default_refiner_model_name, base_model_name=modules.config.default_base_model_name, - loras=get_enabled_loras(modules.config.default_loras) + loras=get_enabled_loras(modules.config.default_loras), + vae_name=modules.config.default_vae, ) diff --git a/modules/flags.py b/modules/flags.py index c9d13fd8..9f2aefb3 100644 --- a/modules/flags.py +++ b/modules/flags.py @@ -53,6 +53,8 @@ SAMPLER_NAMES = KSAMPLER_NAMES + list(SAMPLER_EXTRA.keys()) sampler_list = SAMPLER_NAMES scheduler_list = SCHEDULER_NAMES +default_vae = 'Default (model)' + refiner_swap_method = 'joint' cn_ip = "ImagePrompt" diff --git a/modules/meta_parser.py b/modules/meta_parser.py index 70ab8860..84032e82 100644 --- a/modules/meta_parser.py +++ b/modules/meta_parser.py @@ -46,6 +46,7 @@ def load_parameter_button_click(raw_metadata: dict | str, is_generating: bool): get_float('refiner_switch', 'Refiner Switch', loaded_parameter_dict, results) get_str('sampler', 'Sampler', loaded_parameter_dict, results) get_str('scheduler', 'Scheduler', loaded_parameter_dict, results) + get_str('vae', 'VAE', loaded_parameter_dict, results) get_seed('seed', 'Seed', loaded_parameter_dict, results) if is_generating: @@ -253,6 +254,7 @@ class MetadataParser(ABC): self.refiner_model_name: str = '' self.refiner_model_hash: str = '' self.loras: list = [] + self.vae_name: str = '' @abstractmethod def get_scheme(self) -> MetadataScheme: @@ -267,7 +269,7 @@ class MetadataParser(ABC): raise NotImplementedError def set_data(self, raw_prompt, full_prompt, raw_negative_prompt, full_negative_prompt, steps, base_model_name, - refiner_model_name, loras): + refiner_model_name, loras, vae_name): self.raw_prompt = raw_prompt self.full_prompt = full_prompt self.raw_negative_prompt = raw_negative_prompt @@ -289,6 +291,7 @@ class MetadataParser(ABC): lora_path = get_file_from_folder_list(lora_name, modules.config.paths_loras) lora_hash = get_sha256(lora_path) self.loras.append((Path(lora_name).stem, lora_weight, lora_hash)) + self.vae_name = Path(vae_name).stem @staticmethod def remove_special_loras(lora_filenames): @@ -310,6 +313,7 @@ class A1111MetadataParser(MetadataParser): 'steps': 'Steps', 'sampler': 'Sampler', 'scheduler': 'Scheduler', + 'vae': 'VAE', 'guidance_scale': 'CFG scale', 'seed': 'Seed', 'resolution': 'Size', @@ -397,13 +401,12 @@ class A1111MetadataParser(MetadataParser): data['sampler'] = k break - for key in ['base_model', 'refiner_model']: + for key in ['base_model', 'refiner_model', 'vae']: if key in data: - for filename in modules.config.model_filenames: - path = Path(filename) - if data[key] == path.stem: - data[key] = filename - break + if key == 'vae': + self.add_extension_to_filename(data, modules.config.vae_filenames, 'vae') + else: + self.add_extension_to_filename(data, modules.config.model_filenames, key) lora_data = '' if 'lora_weights' in data and data['lora_weights'] != '': @@ -433,6 +436,7 @@ class A1111MetadataParser(MetadataParser): sampler = data['sampler'] scheduler = data['scheduler'] + if sampler in SAMPLERS and SAMPLERS[sampler] != '': sampler = SAMPLERS[sampler] if sampler not in CIVITAI_NO_KARRAS and scheduler == 'karras': @@ -451,6 +455,7 @@ class A1111MetadataParser(MetadataParser): self.fooocus_to_a1111['performance']: data['performance'], self.fooocus_to_a1111['scheduler']: scheduler, + self.fooocus_to_a1111['vae']: Path(data['vae']).stem, # workaround for multiline prompts self.fooocus_to_a1111['raw_prompt']: self.raw_prompt, self.fooocus_to_a1111['raw_negative_prompt']: self.raw_negative_prompt, @@ -491,6 +496,14 @@ class A1111MetadataParser(MetadataParser): negative_prompt_text = f"\nNegative prompt: {negative_prompt_resolved}" if negative_prompt_resolved else "" return f"{positive_prompt_resolved}{negative_prompt_text}\n{generation_params_text}".strip() + @staticmethod + def add_extension_to_filename(data, filenames, key): + for filename in filenames: + path = Path(filename) + if data[key] == path.stem: + data[key] = filename + break + class FooocusMetadataParser(MetadataParser): def get_scheme(self) -> MetadataScheme: @@ -499,6 +512,7 @@ class FooocusMetadataParser(MetadataParser): def parse_json(self, metadata: dict) -> dict: model_filenames = modules.config.model_filenames.copy() lora_filenames = modules.config.lora_filenames.copy() + vae_filenames = modules.config.vae_filenames.copy() self.remove_special_loras(lora_filenames) for key, value in metadata.items(): if value in ['', 'None']: @@ -507,6 +521,8 @@ class FooocusMetadataParser(MetadataParser): metadata[key] = self.replace_value_with_filename(key, value, model_filenames) elif key.startswith('lora_combined_'): metadata[key] = self.replace_value_with_filename(key, value, lora_filenames) + elif key == 'vae': + metadata[key] = self.replace_value_with_filename(key, value, vae_filenames) else: continue @@ -533,6 +549,7 @@ class FooocusMetadataParser(MetadataParser): res['refiner_model'] = self.refiner_model_name res['refiner_model_hash'] = self.refiner_model_hash + res['vae'] = self.vae_name res['loras'] = self.loras if modules.config.metadata_created_by != '': diff --git a/modules/util.py b/modules/util.py index 9e0fb294..d2feecb6 100644 --- a/modules/util.py +++ b/modules/util.py @@ -371,6 +371,9 @@ def is_json(data: str) -> bool: def get_file_from_folder_list(name, folders): + if not isinstance(folders, list): + folders = [folders] + for folder in folders: filename = os.path.abspath(os.path.realpath(os.path.join(folder, name))) if os.path.isfile(filename): diff --git a/webui.py b/webui.py index ababb8b0..eec6054a 100644 --- a/webui.py +++ b/webui.py @@ -407,6 +407,8 @@ with shared.gradio_root: value=modules.config.default_sampler) scheduler_name = gr.Dropdown(label='Scheduler', choices=flags.scheduler_list, value=modules.config.default_scheduler) + vae_name = gr.Dropdown(label='VAE', choices=[modules.flags.default_vae] + modules.config.vae_filenames, + value=modules.config.default_vae, show_label=True) generate_image_grid = gr.Checkbox(label='Generate Image Grid for Each Batch', info='(Experimental) This may cause performance problems on some computers and certain internet conditions.', @@ -529,6 +531,7 @@ with shared.gradio_root: modules.config.update_files() results = [gr.update(choices=modules.config.model_filenames)] results += [gr.update(choices=['None'] + modules.config.model_filenames)] + results += [gr.update(choices=['None'] + modules.config.vae_filenames)] if not args_manager.args.disable_preset_selection: results += [gr.update(choices=modules.config.available_presets)] for i in range(modules.config.default_max_lora_number): @@ -536,7 +539,7 @@ with shared.gradio_root: gr.update(choices=['None'] + modules.config.lora_filenames), gr.update()] return results - refresh_files_output = [base_model, refiner_model] + refresh_files_output = [base_model, refiner_model, vae_name] if not args_manager.args.disable_preset_selection: refresh_files_output += [preset_selection] refresh_files.click(refresh_files_clicked, [], refresh_files_output + lora_ctrls, @@ -548,8 +551,8 @@ with shared.gradio_root: performance_selection, overwrite_step, overwrite_switch, aspect_ratios_selection, overwrite_width, overwrite_height, guidance_scale, sharpness, adm_scaler_positive, adm_scaler_negative, adm_scaler_end, refiner_swap_method, adaptive_cfg, base_model, - refiner_model, refiner_switch, sampler_name, scheduler_name, seed_random, image_seed, - generate_button, load_parameter_button] + freeu_ctrls + lora_ctrls + refiner_model, refiner_switch, sampler_name, scheduler_name, vae_name, seed_random, + image_seed, generate_button, load_parameter_button] + freeu_ctrls + lora_ctrls if not args_manager.args.disable_preset_selection: def preset_selection_change(preset, is_generating): @@ -635,7 +638,7 @@ with shared.gradio_root: ctrls += [outpaint_selections, inpaint_input_image, inpaint_additional_prompt, inpaint_mask_image] ctrls += [disable_preview, disable_intermediate_results, disable_seed_increment] ctrls += [adm_scaler_positive, adm_scaler_negative, adm_scaler_end, adaptive_cfg] - ctrls += [sampler_name, scheduler_name] + ctrls += [sampler_name, scheduler_name, vae_name] ctrls += [overwrite_step, overwrite_switch, overwrite_width, overwrite_height, overwrite_vary_strength] ctrls += [overwrite_upscale_strength, mixing_image_prompt_and_vary_upscale, mixing_image_prompt_and_inpaint] ctrls += [debugging_cn_preprocessor, skipping_cn_preprocessor, canny_low_threshold, canny_high_threshold] From f54364fe4ebd737349611c1d040703b0ac7ace68 Mon Sep 17 00:00:00 2001 From: Manuel Schmid <9307310+mashb1t@users.noreply.github.com> Date: Thu, 9 May 2024 19:02:04 +0200 Subject: [PATCH 05/13] feat: add random style checkbox to styles selection (#2855) * feat: add random style * feat: rename random to random style, add translation * feat: add preview image for random style --- language/en.json | 1 + modules/async_worker.py | 11 ++++++++--- modules/sdxl_styles.py | 10 ++++++++-- sdxl_styles/samples/random_style.jpg | Bin 0 -> 1454 bytes 4 files changed, 17 insertions(+), 5 deletions(-) create mode 100644 sdxl_styles/samples/random_style.jpg diff --git a/language/en.json b/language/en.json index 1fe78662..20189b28 100644 --- a/language/en.json +++ b/language/en.json @@ -58,6 +58,7 @@ "\ud83d\udcda History Log": "\uD83D\uDCDA History Log", "Image Style": "Image Style", "Fooocus V2": "Fooocus V2", + "Random Style": "Random Style", "Default (Slightly Cinematic)": "Default (Slightly Cinematic)", "Fooocus Masterpiece": "Fooocus Masterpiece", "Fooocus Photograph": "Fooocus Photograph", diff --git a/modules/async_worker.py b/modules/async_worker.py index 3576c4ec..432bfe9b 100644 --- a/modules/async_worker.py +++ b/modules/async_worker.py @@ -43,7 +43,7 @@ def worker(): import fooocus_version import args_manager - from modules.sdxl_styles import apply_style, apply_wildcards, fooocus_expansion, apply_arrays + from modules.sdxl_styles import apply_style, get_random_style, apply_wildcards, fooocus_expansion, apply_arrays, random_style_name from modules.private_logger import log from extras.expansion import safe_str from modules.util import remove_empty_str, HWC3, resize_image, get_image_shape_ceil, set_image_shape_ceil, \ @@ -450,8 +450,12 @@ def worker(): positive_basic_workloads = [] negative_basic_workloads = [] + task_styles = style_selections.copy() if use_style: - for s in style_selections: + for i, s in enumerate(task_styles): + if s == random_style_name: + s = get_random_style(task_rng) + task_styles[i] = s p, n = apply_style(s, positive=task_prompt) positive_basic_workloads = positive_basic_workloads + p negative_basic_workloads = negative_basic_workloads + n @@ -479,6 +483,7 @@ def worker(): negative_top_k=len(negative_basic_workloads), log_positive_prompt='\n'.join([task_prompt] + task_extra_positive_prompts), log_negative_prompt='\n'.join([task_negative_prompt] + task_extra_negative_prompts), + styles=task_styles )) if use_expansion: @@ -843,7 +848,7 @@ def worker(): d = [('Prompt', 'prompt', task['log_positive_prompt']), ('Negative Prompt', 'negative_prompt', task['log_negative_prompt']), ('Fooocus V2 Expansion', 'prompt_expansion', task['expansion']), - ('Styles', 'styles', str(raw_style_selections)), + ('Styles', 'styles', str(task['styles'] if not use_expansion else [fooocus_expansion] + task['styles'])), ('Performance', 'performance', performance_selection.value)] if performance_selection.steps() != steps: diff --git a/modules/sdxl_styles.py b/modules/sdxl_styles.py index 77ad6b57..5b6afb59 100644 --- a/modules/sdxl_styles.py +++ b/modules/sdxl_styles.py @@ -5,6 +5,7 @@ import math import modules.config from modules.util import get_files_from_folder +from random import Random # cannot use modules.config - validators causing circular imports styles_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '../sdxl_styles/')) @@ -50,8 +51,13 @@ for styles_file in styles_files: print(f'Failed to load style file {styles_file}') style_keys = list(styles.keys()) -fooocus_expansion = "Fooocus V2" -legal_style_names = [fooocus_expansion] + style_keys +fooocus_expansion = 'Fooocus V2' +random_style_name = 'Random Style' +legal_style_names = [fooocus_expansion, random_style_name] + style_keys + + +def get_random_style(rng: Random) -> str: + return rng.choice(list(styles.items()))[0] def apply_style(style, positive): diff --git a/sdxl_styles/samples/random_style.jpg b/sdxl_styles/samples/random_style.jpg new file mode 100644 index 0000000000000000000000000000000000000000..9f685108fdcf78409e488d79cc2c245fec3ad06e GIT binary patch literal 1454 zcmex=ma3|jiIItm zOAI4iKMQ#V{6EAX$idLS(7?>7#K0uT$SlbC{|JLD$RAA1j3D1Y0V4+|6B|1#Gt2*5 z3>*;g^Do3AxHTA*=9V*ky#3T}oy+sk*U|b%zddvNE zjs=&Z%SwbAm^4hL-Cdy6?3sIuNrOSPVbg*AKu3xoxen?`u-ljfU~UG|c+@eYsWTH` zU}R=uW@TexVTYK&%*e#T%ErJh$RQx45GbsuWN1_<;^5fmRI=%^sH#D5V)7PfXfWC{ zJi9LSIQG`#nC#y2^(XXI?^Z;Avz7RIyRSOgMZUH3nbY2yHnV@~>-68Qc%R*KyfZ7) zap|iSSJV1uxP7e@zxd`%$m*@TPR3;IO)7b0SsH$7g0p4a`_TKo>#n^xnz>T6^Mc3K ztMX!wOizAq3>M?NQoeOZg3GQi>w~_!6zzBOk*?a%`TAR(&gnz<_Gv6wcf0tunx=T` z%O?(dZH~FsiK}b9pKm3r#&c_~!&@GY2E}vXzak8`#ytJ4lx+H0Z{^JEF0-~w-&1nH z^3bxwq3;;%3;*TE)vP$PYgxoy!BrRMIR*1Q>srbyzT?5)W0e~V{pRwmWD@n8S2jy# z1J46(TYX9CBiGe;=_|dzwdlF7xSq3n*Vn}C;+vBH8K!e-#_!fK`c;=2yJ78)lz`?| ze{t=l z^L=YTQN$6Mh5BKF!Xd}fpMLqXb$8nBVvc3(TyipVP2N5hzVyE2V&Ls<;k$ezPEQVe zvNtYE<@+)Dr)TD_wiK_snYC6fhEx2&>ah5q-*a*p*7O{n@l9!p(=tx+#azom|H|x6 zsZsxE&EM?b{Uh`5`ndBQHPYKm-dgh<`FCAb+rO57ul>^U``=1iIVS5>-B8}U=21#W zk)3Ci!jXBZ$1g_pebA?)`tFHA<+OT!! z`3ZKb7?n6@U*NcJ!mve9i6MdG=ACOixo%rmzUka+u;+JJTHvgTpV{tK z8+;Rk|$3d_P;TUH576 z^`5UqXIFY^hcCGLswaPQ^rE~Idp0gh{kqg|PDW4Z;ovSy;U#-*lAV@IG47Q*SNi<( z`#T-IE-t=;!Y+O(KUpB~{xXweGcK3SxGbe_56aGhh(ZaHO%ov5o)MS{L4_AC1;FAA Hq~In1cXT Date: Thu, 9 May 2024 19:03:30 +0200 Subject: [PATCH 06/13] feat: update anime from animaPencilXL_v100 to animaPencilXL_v310 (#2454) * feat: update anime from animaPencilXL_v100 to animaPencilXL_v200 * feat: update animaPencilXL from 2.0.0 to 2.6.0 * feat: update animaPencilXL from 2.6.0 to 3.1.0 * feat: reduce cfg as suggested by vendor from 3.0.0 https://civitai.com/models/261336?modelVersionId=435001 "recommend to decrease CFG scale." + all examples are in CFG 6 --- presets/anime.json | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/presets/anime.json b/presets/anime.json index 2610677c..78607edb 100644 --- a/presets/anime.json +++ b/presets/anime.json @@ -1,5 +1,5 @@ { - "default_model": "animaPencilXL_v100.safetensors", + "default_model": "animaPencilXL_v310.safetensors", "default_refiner": "None", "default_refiner_switch": 0.5, "default_loras": [ @@ -29,7 +29,7 @@ 1.0 ] ], - "default_cfg_scale": 7.0, + "default_cfg_scale": 6.0, "default_sample_sharpness": 2.0, "default_sampler": "dpmpp_2m_sde_gpu", "default_scheduler": "karras", @@ -43,9 +43,15 @@ ], "default_aspect_ratio": "896*1152", "checkpoint_downloads": { - "animaPencilXL_v100.safetensors": "https://huggingface.co/lllyasviel/fav_models/resolve/main/fav/animaPencilXL_v100.safetensors" + "animaPencilXL_v310.safetensors": "https://huggingface.co/mashb1t/fav_models/resolve/main/fav/animaPencilXL_v310.safetensors" }, "embeddings_downloads": {}, "lora_downloads": {}, - "previous_default_models": [] + "previous_default_models": [ + "animaPencilXL_v300.safetensors", + "animaPencilXL_v260.safetensors", + "animaPencilXL_v210.safetensors", + "animaPencilXL_v200.safetensors", + "animaPencilXL_v100.safetensors" + ] } \ No newline at end of file From 052393bb9bfa6fe66d1f8d3fdf8da38605998eff Mon Sep 17 00:00:00 2001 From: Manuel Schmid <9307310+mashb1t@users.noreply.github.com> Date: Thu, 9 May 2024 19:13:59 +0200 Subject: [PATCH 07/13] refactor: rename label for reconnect button (#2893) * feat: add button to reconnect UI without having to reload the page * qa: add missing semicolon * refactor: rename button label to "Reconnect" --- language/en.json | 2 +- webui.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/language/en.json b/language/en.json index 20189b28..e9cd6b73 100644 --- a/language/en.json +++ b/language/en.json @@ -4,7 +4,7 @@ "Generate": "Generate", "Skip": "Skip", "Stop": "Stop", - "Reconnect and Reset UI": "Reconnect and Reset UI", + "Reconnect": "Reconnect", "Input Image": "Input Image", "Advanced": "Advanced", "Upscale or Variation": "Upscale or Variation", diff --git a/webui.py b/webui.py index eec6054a..85b2c0df 100644 --- a/webui.py +++ b/webui.py @@ -123,7 +123,7 @@ with shared.gradio_root: with gr.Column(scale=3, min_width=0): generate_button = gr.Button(label="Generate", value="Generate", elem_classes='type_row', elem_id='generate_button', visible=True) - reset_button = gr.Button(label="Reconnect and Reset UI", value="Reconnect and Reset UI", elem_classes='type_row', elem_id='reset_button', visible=False) + reset_button = gr.Button(label="Reconnect", value="Reconnect", elem_classes='type_row', elem_id='reset_button', visible=False) load_parameter_button = gr.Button(label="Load Parameters", value="Load Parameters", elem_classes='type_row', elem_id='load_parameter_button', visible=False) skip_button = gr.Button(label="Skip", value="Skip", elem_classes='type_row_half', elem_id='skip_button', visible=False) stop_button = gr.Button(label="Stop", value="Stop", elem_classes='type_row_half', elem_id='stop_button', visible=False) From bdd6b1a9b0b182ce62c20642e4c6bd8acec0e4c3 Mon Sep 17 00:00:00 2001 From: docppp <29142757+docppp@users.noreply.github.com> Date: Thu, 9 May 2024 20:25:43 +0200 Subject: [PATCH 08/13] feat: add full raw prompt to history log (#1920) * Update async_worker.py * Update private_logger.py * refactor: only show full prompt details in logs, exclude from image metadata --------- Co-authored-by: Manuel Schmid <9307310+mashb1t@users.noreply.github.com> Co-authored-by: Manuel Schmid --- modules/async_worker.py | 2 +- modules/private_logger.py | 10 ++++++++-- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/modules/async_worker.py b/modules/async_worker.py index 432bfe9b..cde99bdc 100644 --- a/modules/async_worker.py +++ b/modules/async_worker.py @@ -893,7 +893,7 @@ def worker(): steps, base_model_name, refiner_model_name, loras, vae_name) d.append(('Metadata Scheme', 'metadata_scheme', metadata_scheme.value if save_metadata_to_images else save_metadata_to_images)) d.append(('Version', 'version', 'Fooocus v' + fooocus_version.version)) - img_paths.append(log(x, d, metadata_parser, output_format)) + img_paths.append(log(x, d, metadata_parser, output_format, task)) yield_result(async_task, img_paths, do_not_show_finished_images=len(tasks) == 1 or disable_intermediate_results) except ldm_patched.modules.model_management.InterruptProcessingException as e: diff --git a/modules/private_logger.py b/modules/private_logger.py index edd9457d..eb8f0cc5 100644 --- a/modules/private_logger.py +++ b/modules/private_logger.py @@ -21,7 +21,7 @@ def get_current_html_path(output_format=None): return html_name -def log(img, metadata, metadata_parser: MetadataParser | None = None, output_format=None) -> str: +def log(img, metadata, metadata_parser: MetadataParser | None = None, output_format=None, task=None) -> str: path_outputs = modules.config.temp_path if args_manager.args.disable_image_log else modules.config.path_outputs output_format = output_format if output_format else modules.config.default_output_format date_string, local_temp_filename, only_name = generate_temp_filename(folder=path_outputs, extension=output_format) @@ -111,9 +111,15 @@ def log(img, metadata, metadata_parser: MetadataParser | None = None, output_for for label, key, value in metadata: value_txt = str(value).replace('\n', '
') item += f"{label}{value_txt}\n" + + if task is not None and 'positive' in task and 'negative' in task: + full_prompt_details = f"""
Positive{', '.join(task['positive'])}
+
Negative{', '.join(task['negative'])}
""" + item += f"Full raw prompt{full_prompt_details}\n" + item += "" - js_txt = urllib.parse.quote(json.dumps({k: v for _, k, v in metadata}, indent=0), safe='') + js_txt = urllib.parse.quote(json.dumps({k: v for _, k, v, in metadata}, indent=0), safe='') item += f"
" item += "" From 96bf89f782376544f4f7f20492c5ae0d6a82001f Mon Sep 17 00:00:00 2001 From: Vishvesh Khanvilkar <158825962+khanvilkarvishvesh@users.noreply.github.com> Date: Fri, 17 May 2024 20:48:45 +0530 Subject: [PATCH 09/13] fix: use correct border radius css property (#2845) --- css/style.css | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/css/style.css b/css/style.css index c702a725..b9e6e2ce 100644 --- a/css/style.css +++ b/css/style.css @@ -391,6 +391,6 @@ progress::after { background-color: #fff8; font-family: monospace; text-align: center; - border-radius-top: 5px; + border-radius: 5px 5px 0px 0px; display: none; /* remove this to enable tooltip in preview image */ } \ No newline at end of file From 5e594685e1f86ffaf4b10d6ca7f11742daca4a84 Mon Sep 17 00:00:00 2001 From: e52fa787 <31095594+e52fa787@users.noreply.github.com> Date: Fri, 17 May 2024 23:25:56 +0800 Subject: [PATCH 10/13] fix: do not close meta tag in HTML header (#2740) * fixed typo in HTML (extra tag) * refactor: remove closing slash for meta tag as of specification in https://html.com/tags/meta/, meta tagas are null elements: This element must not contain any content, and does not need a closing tag. --------- Co-authored-by: Manuel Schmid <9307310+mashb1t@users.noreply.github.com> --- modules/ui_gradio_extensions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/ui_gradio_extensions.py b/modules/ui_gradio_extensions.py index bebf9f8c..409c7e33 100644 --- a/modules/ui_gradio_extensions.py +++ b/modules/ui_gradio_extensions.py @@ -39,7 +39,7 @@ def javascript_html(): head += f'\n' head += f'\n' head += f'\n' - head += f'\n' + head += f'\n' if args_manager.args.theme: head += f'\n' From 33fa175bd438041fe4ae715adc9a06d025a940b3 Mon Sep 17 00:00:00 2001 From: Manuel Schmid <9307310+mashb1t@users.noreply.github.com> Date: Fri, 17 May 2024 18:25:08 +0200 Subject: [PATCH 11/13] feat: automatically describe image on uov image upload (#1938) * feat: automatically describe image on uov image upload if prompt is empty * feat: add argument to disable automatic uov image description * feat: rename argument, disable by default this prevents computers with low hardware specifications from being unnecessary blocked --- args_manager.py | 3 +++ webui.py | 9 +++++++++ 2 files changed, 12 insertions(+) diff --git a/args_manager.py b/args_manager.py index 6a3ae9dc..e023da27 100644 --- a/args_manager.py +++ b/args_manager.py @@ -31,6 +31,9 @@ args_parser.parser.add_argument("--disable-metadata", action='store_true', args_parser.parser.add_argument("--disable-preset-download", action='store_true', help="Disables downloading models for presets", default=False) +args_parser.parser.add_argument("--enable-describe-uov-image", action='store_true', + help="Disables automatic description of uov images when prompt is empty", default=False) + args_parser.parser.add_argument("--always-download-new-model", action='store_true', help="Always download newer models ", default=False) diff --git a/webui.py b/webui.py index 85b2c0df..f99ab159 100644 --- a/webui.py +++ b/webui.py @@ -717,6 +717,15 @@ with shared.gradio_root: desc_btn.click(trigger_describe, inputs=[desc_method, desc_input_image], outputs=[prompt, style_selections], show_progress=True, queue=True) + if args_manager.args.enable_describe_uov_image: + def trigger_uov_describe(mode, img, prompt): + # keep prompt if not empty + if prompt == '': + return trigger_describe(mode, img) + return gr.update(), gr.update() + + uov_input_image.upload(trigger_uov_describe, inputs=[desc_method, uov_input_image, prompt], + outputs=[prompt, style_selections], show_progress=True, queue=True) def dump_default_english_config(): from modules.localization import dump_english_config From 00d3d1b4b31b2effa32f6eb96f8e5caf6368f8e3 Mon Sep 17 00:00:00 2001 From: Manuel Schmid <9307310+mashb1t@users.noreply.github.com> Date: Sat, 18 May 2024 15:50:28 +0200 Subject: [PATCH 12/13] feat: add nsfw image censoring via config and checkbox (#958) * add nsfw image censoring activatable via config, uses CompVis/stable-diffusion-safety-checker * fix progressbar call for nsfw output * use config to set cache dir for safety checker * add checkbox black_out_nsfw makes both enabling via config and checkbox possible, where config overrides the checkbox value * fix: add missing diffusers package * feat: extract safety checker, remove dependency to diffusers * feat: make code compatible again after merge with main * feat: move censor to extras, optimize safety checker file handling * refactor: rename folder safety_checker_models to safety_checker --- extras/censor.py | 56 ++++++ extras/safety_checker/configs/config.json | 171 ++++++++++++++++++ .../configs/preprocessor_config.json | 20 ++ .../safety_checker/models/safety_checker.py | 126 +++++++++++++ language/en.json | 2 + .../put_safety_checker_models_here | 0 modules/async_worker.py | 32 +++- modules/config.py | 14 ++ webui.py | 14 +- 9 files changed, 424 insertions(+), 11 deletions(-) create mode 100644 extras/censor.py create mode 100644 extras/safety_checker/configs/config.json create mode 100644 extras/safety_checker/configs/preprocessor_config.json create mode 100644 extras/safety_checker/models/safety_checker.py create mode 100644 models/safety_checker/put_safety_checker_models_here diff --git a/extras/censor.py b/extras/censor.py new file mode 100644 index 00000000..2047db24 --- /dev/null +++ b/extras/censor.py @@ -0,0 +1,56 @@ +# modified version of https://github.com/AUTOMATIC1111/stable-diffusion-webui-nsfw-censor/blob/master/scripts/censor.py +import numpy as np +import os + +from extras.safety_checker.models.safety_checker import StableDiffusionSafetyChecker +from transformers import CLIPFeatureExtractor, CLIPConfig +from PIL import Image +import modules.config + +safety_checker_repo_root = os.path.join(os.path.dirname(__file__), 'safety_checker') +config_path = os.path.join(safety_checker_repo_root, "configs", "config.json") +preprocessor_config_path = os.path.join(safety_checker_repo_root, "configs", "preprocessor_config.json") + +safety_feature_extractor = None +safety_checker = None + + +def numpy_to_pil(image): + image = (image * 255).round().astype("uint8") + pil_image = Image.fromarray(image) + + return pil_image + + +# check and replace nsfw content +def check_safety(x_image): + global safety_feature_extractor, safety_checker + + if safety_feature_extractor is None or safety_checker is None: + safety_checker_model = modules.config.downloading_safety_checker_model() + safety_feature_extractor = CLIPFeatureExtractor.from_json_file(preprocessor_config_path) + clip_config = CLIPConfig.from_json_file(config_path) + safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_checker_model, config=clip_config) + + safety_checker_input = safety_feature_extractor(numpy_to_pil(x_image), return_tensors="pt") + x_checked_image, has_nsfw_concept = safety_checker(images=x_image, clip_input=safety_checker_input.pixel_values) + + return x_checked_image, has_nsfw_concept + + +def censor_single(x): + x_checked_image, has_nsfw_concept = check_safety(x) + + # replace image with black pixels, keep dimensions + # workaround due to different numpy / pytorch image matrix format + if has_nsfw_concept[0]: + imageshape = x_checked_image.shape + x_checked_image = np.zeros((imageshape[0], imageshape[1], 3), dtype = np.uint8) + + return x_checked_image + + +def censor_batch(images): + images = [censor_single(image) for image in images] + + return images \ No newline at end of file diff --git a/extras/safety_checker/configs/config.json b/extras/safety_checker/configs/config.json new file mode 100644 index 00000000..aa454d22 --- /dev/null +++ b/extras/safety_checker/configs/config.json @@ -0,0 +1,171 @@ +{ + "_name_or_path": "clip-vit-large-patch14/", + "architectures": [ + "SafetyChecker" + ], + "initializer_factor": 1.0, + "logit_scale_init_value": 2.6592, + "model_type": "clip", + "projection_dim": 768, + "text_config": { + "_name_or_path": "", + "add_cross_attention": false, + "architectures": null, + "attention_dropout": 0.0, + "bad_words_ids": null, + "bos_token_id": 0, + "chunk_size_feed_forward": 0, + "cross_attention_hidden_size": null, + "decoder_start_token_id": null, + "diversity_penalty": 0.0, + "do_sample": false, + "dropout": 0.0, + "early_stopping": false, + "encoder_no_repeat_ngram_size": 0, + "eos_token_id": 2, + "exponential_decay_length_penalty": null, + "finetuning_task": null, + "forced_bos_token_id": null, + "forced_eos_token_id": null, + "hidden_act": "quick_gelu", + "hidden_size": 768, + "id2label": { + "0": "LABEL_0", + "1": "LABEL_1" + }, + "initializer_factor": 1.0, + "initializer_range": 0.02, + "intermediate_size": 3072, + "is_decoder": false, + "is_encoder_decoder": false, + "label2id": { + "LABEL_0": 0, + "LABEL_1": 1 + }, + "layer_norm_eps": 1e-05, + "length_penalty": 1.0, + "max_length": 20, + "max_position_embeddings": 77, + "min_length": 0, + "model_type": "clip_text_model", + "no_repeat_ngram_size": 0, + "num_attention_heads": 12, + "num_beam_groups": 1, + "num_beams": 1, + "num_hidden_layers": 12, + "num_return_sequences": 1, + "output_attentions": false, + "output_hidden_states": false, + "output_scores": false, + "pad_token_id": 1, + "prefix": null, + "problem_type": null, + "pruned_heads": {}, + "remove_invalid_values": false, + "repetition_penalty": 1.0, + "return_dict": true, + "return_dict_in_generate": false, + "sep_token_id": null, + "task_specific_params": null, + "temperature": 1.0, + "tie_encoder_decoder": false, + "tie_word_embeddings": true, + "tokenizer_class": null, + "top_k": 50, + "top_p": 1.0, + "torch_dtype": null, + "torchscript": false, + "transformers_version": "4.21.0.dev0", + "typical_p": 1.0, + "use_bfloat16": false, + "vocab_size": 49408 + }, + "text_config_dict": { + "hidden_size": 768, + "intermediate_size": 3072, + "num_attention_heads": 12, + "num_hidden_layers": 12 + }, + "torch_dtype": "float32", + "transformers_version": null, + "vision_config": { + "_name_or_path": "", + "add_cross_attention": false, + "architectures": null, + "attention_dropout": 0.0, + "bad_words_ids": null, + "bos_token_id": null, + "chunk_size_feed_forward": 0, + "cross_attention_hidden_size": null, + "decoder_start_token_id": null, + "diversity_penalty": 0.0, + "do_sample": false, + "dropout": 0.0, + "early_stopping": false, + "encoder_no_repeat_ngram_size": 0, + "eos_token_id": null, + "exponential_decay_length_penalty": null, + "finetuning_task": null, + "forced_bos_token_id": null, + "forced_eos_token_id": null, + "hidden_act": "quick_gelu", + "hidden_size": 1024, + "id2label": { + "0": "LABEL_0", + "1": "LABEL_1" + }, + "image_size": 224, + "initializer_factor": 1.0, + "initializer_range": 0.02, + "intermediate_size": 4096, + "is_decoder": false, + "is_encoder_decoder": false, + "label2id": { + "LABEL_0": 0, + "LABEL_1": 1 + }, + "layer_norm_eps": 1e-05, + "length_penalty": 1.0, + "max_length": 20, + "min_length": 0, + "model_type": "clip_vision_model", + "no_repeat_ngram_size": 0, + "num_attention_heads": 16, + "num_beam_groups": 1, + "num_beams": 1, + "num_hidden_layers": 24, + "num_return_sequences": 1, + "output_attentions": false, + "output_hidden_states": false, + "output_scores": false, + "pad_token_id": null, + "patch_size": 14, + "prefix": null, + "problem_type": null, + "pruned_heads": {}, + "remove_invalid_values": false, + "repetition_penalty": 1.0, + "return_dict": true, + "return_dict_in_generate": false, + "sep_token_id": null, + "task_specific_params": null, + "temperature": 1.0, + "tie_encoder_decoder": false, + "tie_word_embeddings": true, + "tokenizer_class": null, + "top_k": 50, + "top_p": 1.0, + "torch_dtype": null, + "torchscript": false, + "transformers_version": "4.21.0.dev0", + "typical_p": 1.0, + "use_bfloat16": false + }, + "vision_config_dict": { + "hidden_size": 1024, + "intermediate_size": 4096, + "num_attention_heads": 16, + "num_hidden_layers": 24, + "patch_size": 14 + } +} diff --git a/extras/safety_checker/configs/preprocessor_config.json b/extras/safety_checker/configs/preprocessor_config.json new file mode 100644 index 00000000..5294955f --- /dev/null +++ b/extras/safety_checker/configs/preprocessor_config.json @@ -0,0 +1,20 @@ +{ + "crop_size": 224, + "do_center_crop": true, + "do_convert_rgb": true, + "do_normalize": true, + "do_resize": true, + "feature_extractor_type": "CLIPFeatureExtractor", + "image_mean": [ + 0.48145466, + 0.4578275, + 0.40821073 + ], + "image_std": [ + 0.26862954, + 0.26130258, + 0.27577711 + ], + "resample": 3, + "size": 224 +} diff --git a/extras/safety_checker/models/safety_checker.py b/extras/safety_checker/models/safety_checker.py new file mode 100644 index 00000000..ea38bf03 --- /dev/null +++ b/extras/safety_checker/models/safety_checker.py @@ -0,0 +1,126 @@ +# from https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/safety_checker.py + +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import torch +import torch.nn as nn +from transformers import CLIPConfig, CLIPVisionModel, PreTrainedModel +from transformers.utils import logging + +logger = logging.get_logger(__name__) + + +def cosine_distance(image_embeds, text_embeds): + normalized_image_embeds = nn.functional.normalize(image_embeds) + normalized_text_embeds = nn.functional.normalize(text_embeds) + return torch.mm(normalized_image_embeds, normalized_text_embeds.t()) + + +class StableDiffusionSafetyChecker(PreTrainedModel): + config_class = CLIPConfig + main_input_name = "clip_input" + + _no_split_modules = ["CLIPEncoderLayer"] + + def __init__(self, config: CLIPConfig): + super().__init__(config) + + self.vision_model = CLIPVisionModel(config.vision_config) + self.visual_projection = nn.Linear(config.vision_config.hidden_size, config.projection_dim, bias=False) + + self.concept_embeds = nn.Parameter(torch.ones(17, config.projection_dim), requires_grad=False) + self.special_care_embeds = nn.Parameter(torch.ones(3, config.projection_dim), requires_grad=False) + + self.concept_embeds_weights = nn.Parameter(torch.ones(17), requires_grad=False) + self.special_care_embeds_weights = nn.Parameter(torch.ones(3), requires_grad=False) + + @torch.no_grad() + def forward(self, clip_input, images): + pooled_output = self.vision_model(clip_input)[1] # pooled_output + image_embeds = self.visual_projection(pooled_output) + + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + special_cos_dist = cosine_distance(image_embeds, self.special_care_embeds).cpu().float().numpy() + cos_dist = cosine_distance(image_embeds, self.concept_embeds).cpu().float().numpy() + + result = [] + batch_size = image_embeds.shape[0] + for i in range(batch_size): + result_img = {"special_scores": {}, "special_care": [], "concept_scores": {}, "bad_concepts": []} + + # increase this value to create a stronger `nfsw` filter + # at the cost of increasing the possibility of filtering benign images + adjustment = 0.0 + + for concept_idx in range(len(special_cos_dist[0])): + concept_cos = special_cos_dist[i][concept_idx] + concept_threshold = self.special_care_embeds_weights[concept_idx].item() + result_img["special_scores"][concept_idx] = round(concept_cos - concept_threshold + adjustment, 3) + if result_img["special_scores"][concept_idx] > 0: + result_img["special_care"].append({concept_idx, result_img["special_scores"][concept_idx]}) + adjustment = 0.01 + + for concept_idx in range(len(cos_dist[0])): + concept_cos = cos_dist[i][concept_idx] + concept_threshold = self.concept_embeds_weights[concept_idx].item() + result_img["concept_scores"][concept_idx] = round(concept_cos - concept_threshold + adjustment, 3) + if result_img["concept_scores"][concept_idx] > 0: + result_img["bad_concepts"].append(concept_idx) + + result.append(result_img) + + has_nsfw_concepts = [len(res["bad_concepts"]) > 0 for res in result] + + for idx, has_nsfw_concept in enumerate(has_nsfw_concepts): + if has_nsfw_concept: + if torch.is_tensor(images) or torch.is_tensor(images[0]): + images[idx] = torch.zeros_like(images[idx]) # black image + else: + images[idx] = np.zeros(images[idx].shape) # black image + + if any(has_nsfw_concepts): + logger.warning( + "Potential NSFW content was detected in one or more images. A black image will be returned instead." + " Try again with a different prompt and/or seed." + ) + + return images, has_nsfw_concepts + + @torch.no_grad() + def forward_onnx(self, clip_input: torch.Tensor, images: torch.Tensor): + pooled_output = self.vision_model(clip_input)[1] # pooled_output + image_embeds = self.visual_projection(pooled_output) + + special_cos_dist = cosine_distance(image_embeds, self.special_care_embeds) + cos_dist = cosine_distance(image_embeds, self.concept_embeds) + + # increase this value to create a stronger `nsfw` filter + # at the cost of increasing the possibility of filtering benign images + adjustment = 0.0 + + special_scores = special_cos_dist - self.special_care_embeds_weights + adjustment + # special_scores = special_scores.round(decimals=3) + special_care = torch.any(special_scores > 0, dim=1) + special_adjustment = special_care * 0.01 + special_adjustment = special_adjustment.unsqueeze(1).expand(-1, cos_dist.shape[1]) + + concept_scores = (cos_dist - self.concept_embeds_weights) + special_adjustment + # concept_scores = concept_scores.round(decimals=3) + has_nsfw_concepts = torch.any(concept_scores > 0, dim=1) + + images[has_nsfw_concepts] = 0.0 # black image + + return images, has_nsfw_concepts diff --git a/language/en.json b/language/en.json index e9cd6b73..3eb5d5e2 100644 --- a/language/en.json +++ b/language/en.json @@ -55,6 +55,8 @@ "Disable seed increment": "Disable seed increment", "Disable automatic seed increment when image number is > 1.": "Disable automatic seed increment when image number is > 1.", "Read wildcards in order": "Read wildcards in order", + "Black Out NSFW": "Black Out NSFW", + "Use black image if NSFW is detected.": "Use black image if NSFW is detected.", "\ud83d\udcda History Log": "\uD83D\uDCDA History Log", "Image Style": "Image Style", "Fooocus V2": "Fooocus V2", diff --git a/models/safety_checker/put_safety_checker_models_here b/models/safety_checker/put_safety_checker_models_here new file mode 100644 index 00000000..e69de29b diff --git a/modules/async_worker.py b/modules/async_worker.py index cde99bdc..6f0b30a9 100644 --- a/modules/async_worker.py +++ b/modules/async_worker.py @@ -43,6 +43,7 @@ def worker(): import fooocus_version import args_manager + from extras.censor import censor_batch, censor_single from modules.sdxl_styles import apply_style, get_random_style, apply_wildcards, fooocus_expansion, apply_arrays, random_style_name from modules.private_logger import log from extras.expansion import safe_str @@ -68,10 +69,14 @@ def worker(): print(f'[Fooocus] {text}') async_task.yields.append(['preview', (number, text, None)]) - def yield_result(async_task, imgs, do_not_show_finished_images=False): + def yield_result(async_task, imgs, black_out_nsfw, censor=True, do_not_show_finished_images=False, progressbar_index=13): if not isinstance(imgs, list): imgs = [imgs] + if censor and (modules.config.default_black_out_nsfw or black_out_nsfw): + progressbar(async_task, progressbar_index, 'Checking for NSFW content ...') + imgs = censor_batch(imgs) + async_task.results = async_task.results + imgs if do_not_show_finished_images: @@ -160,6 +165,7 @@ def worker(): disable_preview = args.pop() disable_intermediate_results = args.pop() disable_seed_increment = args.pop() + black_out_nsfw = args.pop() adm_scaler_positive = args.pop() adm_scaler_negative = args.pop() adm_scaler_end = args.pop() @@ -578,8 +584,11 @@ def worker(): if direct_return: d = [('Upscale (Fast)', 'upscale_fast', '2x')] + if modules.config.default_black_out_nsfw or black_out_nsfw: + progressbar(async_task, 100, 'Checking for NSFW content ...') + uov_input_image = censor_single(uov_input_image) uov_input_image_path = log(uov_input_image, d, output_format=output_format) - yield_result(async_task, uov_input_image_path, do_not_show_finished_images=True) + yield_result(async_task, uov_input_image_path, black_out_nsfw, False, do_not_show_finished_images=True) return tiled = True @@ -643,8 +652,7 @@ def worker(): ) if debugging_inpaint_preprocessor: - yield_result(async_task, inpaint_worker.current_task.visualize_mask_processing(), - do_not_show_finished_images=True) + yield_result(async_task, inpaint_worker.current_task.visualize_mask_processing(), black_out_nsfw, do_not_show_finished_images=True) return progressbar(async_task, 13, 'VAE Inpaint encoding ...') @@ -707,7 +715,7 @@ def worker(): cn_img = HWC3(cn_img) task[0] = core.numpy_to_pytorch(cn_img) if debugging_cn_preprocessor: - yield_result(async_task, cn_img, do_not_show_finished_images=True) + yield_result(async_task, cn_img, black_out_nsfw, do_not_show_finished_images=True) return for task in cn_tasks[flags.cn_cpds]: cn_img, cn_stop, cn_weight = task @@ -719,7 +727,7 @@ def worker(): cn_img = HWC3(cn_img) task[0] = core.numpy_to_pytorch(cn_img) if debugging_cn_preprocessor: - yield_result(async_task, cn_img, do_not_show_finished_images=True) + yield_result(async_task, cn_img, black_out_nsfw, do_not_show_finished_images=True) return for task in cn_tasks[flags.cn_ip]: cn_img, cn_stop, cn_weight = task @@ -730,7 +738,7 @@ def worker(): task[0] = ip_adapter.preprocess(cn_img, ip_adapter_path=ip_adapter_path) if debugging_cn_preprocessor: - yield_result(async_task, cn_img, do_not_show_finished_images=True) + yield_result(async_task, cn_img, black_out_nsfw, do_not_show_finished_images=True) return for task in cn_tasks[flags.cn_ip_face]: cn_img, cn_stop, cn_weight = task @@ -744,7 +752,7 @@ def worker(): task[0] = ip_adapter.preprocess(cn_img, ip_adapter_path=ip_adapter_face_path) if debugging_cn_preprocessor: - yield_result(async_task, cn_img, do_not_show_finished_images=True) + yield_result(async_task, cn_img, black_out_nsfw, do_not_show_finished_images=True) return all_ip_tasks = cn_tasks[flags.cn_ip] + cn_tasks[flags.cn_ip_face] @@ -844,6 +852,12 @@ def worker(): imgs = [inpaint_worker.current_task.post_process(x) for x in imgs] img_paths = [] + + if modules.config.default_black_out_nsfw or black_out_nsfw: + progressbar(async_task, int(15.0 + 85.0 * float((current_task_id + 1) * steps) / float(all_steps)), + 'Checking for NSFW content ...') + imgs = censor_batch(imgs) + for x in imgs: d = [('Prompt', 'prompt', task['log_positive_prompt']), ('Negative Prompt', 'negative_prompt', task['log_negative_prompt']), @@ -895,7 +909,7 @@ def worker(): d.append(('Version', 'version', 'Fooocus v' + fooocus_version.version)) img_paths.append(log(x, d, metadata_parser, output_format, task)) - yield_result(async_task, img_paths, do_not_show_finished_images=len(tasks) == 1 or disable_intermediate_results) + yield_result(async_task, img_paths, black_out_nsfw, False, do_not_show_finished_images=len(tasks) == 1 or disable_intermediate_results) except ldm_patched.modules.model_management.InterruptProcessingException as e: if async_task.last_stop == 'skip': print('User skipped') diff --git a/modules/config.py b/modules/config.py index f11460c8..ffb74a23 100644 --- a/modules/config.py +++ b/modules/config.py @@ -196,6 +196,7 @@ path_controlnet = get_dir_or_set_default('path_controlnet', '../models/controlne path_clip_vision = get_dir_or_set_default('path_clip_vision', '../models/clip_vision/') path_fooocus_expansion = get_dir_or_set_default('path_fooocus_expansion', '../models/prompt_expansion/fooocus_expansion') path_wildcards = get_dir_or_set_default('path_wildcards', '../wildcards/') +path_safety_checker = get_dir_or_set_default('path_safety_checker', '../models/safety_checker/') path_outputs = get_path_output() @@ -456,6 +457,11 @@ example_inpaint_prompts = get_config_item_or_set_default( ], validator=lambda x: isinstance(x, list) and all(isinstance(v, str) for v in x) ) +default_black_out_nsfw = get_config_item_or_set_default( + key='default_black_out_nsfw', + default_value=False, + validator=lambda x: isinstance(x, bool) +) default_save_metadata_to_images = get_config_item_or_set_default( key='default_save_metadata_to_images', default_value=False, @@ -691,5 +697,13 @@ def downloading_upscale_model(): ) return os.path.join(path_upscale_models, 'fooocus_upscaler_s409985e5.bin') +def downloading_safety_checker_model(): + load_file_from_url( + url='https://huggingface.co/mashb1t/misc/resolve/main/stable-diffusion-safety-checker.bin', + model_dir=path_safety_checker, + file_name='stable-diffusion-safety-checker.bin' + ) + return os.path.join(path_safety_checker, 'stable-diffusion-safety-checker.bin') + update_files() diff --git a/webui.py b/webui.py index f99ab159..55f3102c 100644 --- a/webui.py +++ b/webui.py @@ -436,7 +436,8 @@ with shared.gradio_root: overwrite_upscale_strength = gr.Slider(label='Forced Overwrite of Denoising Strength of "Upscale"', minimum=-1, maximum=1.0, step=0.001, value=-1, info='Set as negative number to disable. For developer debugging.') - disable_preview = gr.Checkbox(label='Disable Preview', value=False, + disable_preview = gr.Checkbox(label='Disable Preview', value=modules.config.default_black_out_nsfw, + interactive=not modules.config.default_black_out_nsfw, info='Disable preview during generation.') disable_intermediate_results = gr.Checkbox(label='Disable Intermediate Results', value=modules.config.default_performance == flags.Performance.EXTREME_SPEED.value, @@ -447,6 +448,15 @@ with shared.gradio_root: value=False) read_wildcards_in_order = gr.Checkbox(label="Read wildcards in order", value=False) + black_out_nsfw = gr.Checkbox(label='Black Out NSFW', + value=modules.config.default_black_out_nsfw, + interactive=not modules.config.default_black_out_nsfw, + info='Use black image if NSFW is detected.') + + black_out_nsfw.change(lambda x: gr.update(value=x, interactive=not x), + inputs=black_out_nsfw, outputs=disable_preview, queue=False, + show_progress=False) + if not args_manager.args.disable_metadata: save_metadata_to_images = gr.Checkbox(label='Save Metadata to Images', value=modules.config.default_save_metadata_to_images, info='Adds parameters to generated images allowing manual regeneration.') @@ -636,7 +646,7 @@ with shared.gradio_root: ctrls += [input_image_checkbox, current_tab] ctrls += [uov_method, uov_input_image] ctrls += [outpaint_selections, inpaint_input_image, inpaint_additional_prompt, inpaint_mask_image] - ctrls += [disable_preview, disable_intermediate_results, disable_seed_increment] + ctrls += [disable_preview, disable_intermediate_results, disable_seed_increment, black_out_nsfw] ctrls += [adm_scaler_positive, adm_scaler_negative, adm_scaler_end, adaptive_cfg] ctrls += [sampler_name, scheduler_name, vae_name] ctrls += [overwrite_step, overwrite_switch, overwrite_width, overwrite_height, overwrite_vary_strength] From 3a55e7e3910b8ae58f82a5a0e4c11d7d4fa3143f Mon Sep 17 00:00:00 2001 From: Manuel Schmid <9307310+mashb1t@users.noreply.github.com> Date: Sat, 18 May 2024 15:53:34 +0200 Subject: [PATCH 13/13] feat: add AlignYourStepsScheduler (#2905) --- .../contrib/external_align_your_steps.py | 55 +++++++++++++++++++ modules/flags.py | 2 +- modules/sample_hijack.py | 4 ++ 3 files changed, 60 insertions(+), 1 deletion(-) create mode 100644 ldm_patched/contrib/external_align_your_steps.py diff --git a/ldm_patched/contrib/external_align_your_steps.py b/ldm_patched/contrib/external_align_your_steps.py new file mode 100644 index 00000000..624bbce2 --- /dev/null +++ b/ldm_patched/contrib/external_align_your_steps.py @@ -0,0 +1,55 @@ +# https://github.com/comfyanonymous/ComfyUI/blob/master/nodes.py + +#from: https://research.nvidia.com/labs/toronto-ai/AlignYourSteps/howto.html +import numpy as np +import torch + +def loglinear_interp(t_steps, num_steps): + """ + Performs log-linear interpolation of a given array of decreasing numbers. + """ + xs = np.linspace(0, 1, len(t_steps)) + ys = np.log(t_steps[::-1]) + + new_xs = np.linspace(0, 1, num_steps) + new_ys = np.interp(new_xs, xs, ys) + + interped_ys = np.exp(new_ys)[::-1].copy() + return interped_ys + +NOISE_LEVELS = {"SD1": [14.6146412293, 6.4745760956, 3.8636745985, 2.6946151520, 1.8841921177, 1.3943805092, 0.9642583904, 0.6523686016, 0.3977456272, 0.1515232662, 0.0291671582], + "SDXL":[14.6146412293, 6.3184485287, 3.7681790315, 2.1811480769, 1.3405244945, 0.8620721141, 0.5550693289, 0.3798540708, 0.2332364134, 0.1114188177, 0.0291671582], + "SVD": [700.00, 54.5, 15.886, 7.977, 4.248, 1.789, 0.981, 0.403, 0.173, 0.034, 0.002]} + +class AlignYourStepsScheduler: + @classmethod + def INPUT_TYPES(s): + return {"required": + {"model_type": (["SD1", "SDXL", "SVD"], ), + "steps": ("INT", {"default": 10, "min": 10, "max": 10000}), + "denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), + } + } + RETURN_TYPES = ("SIGMAS",) + CATEGORY = "sampling/custom_sampling/schedulers" + + FUNCTION = "get_sigmas" + + def get_sigmas(self, model_type, steps, denoise): + total_steps = steps + if denoise < 1.0: + if denoise <= 0.0: + return (torch.FloatTensor([]),) + total_steps = round(steps * denoise) + + sigmas = NOISE_LEVELS[model_type][:] + if (steps + 1) != len(sigmas): + sigmas = loglinear_interp(sigmas, steps + 1) + + sigmas = sigmas[-(total_steps + 1):] + sigmas[-1] = 0 + return (torch.FloatTensor(sigmas), ) + +NODE_CLASS_MAPPINGS = { + "AlignYourStepsScheduler": AlignYourStepsScheduler, +} \ No newline at end of file diff --git a/modules/flags.py b/modules/flags.py index 9f2aefb3..0c605439 100644 --- a/modules/flags.py +++ b/modules/flags.py @@ -47,7 +47,7 @@ SAMPLERS = KSAMPLER | SAMPLER_EXTRA KSAMPLER_NAMES = list(KSAMPLER.keys()) -SCHEDULER_NAMES = ["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_uniform", "lcm", "turbo"] +SCHEDULER_NAMES = ["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_uniform", "lcm", "turbo", "align_your_steps"] SAMPLER_NAMES = KSAMPLER_NAMES + list(SAMPLER_EXTRA.keys()) sampler_list = SAMPLER_NAMES diff --git a/modules/sample_hijack.py b/modules/sample_hijack.py index 5936a096..4ab3cbbd 100644 --- a/modules/sample_hijack.py +++ b/modules/sample_hijack.py @@ -3,6 +3,7 @@ import ldm_patched.modules.samplers import ldm_patched.modules.model_management from collections import namedtuple +from ldm_patched.contrib.external_align_your_steps import AlignYourStepsScheduler from ldm_patched.contrib.external_custom_sampler import SDTurboScheduler from ldm_patched.k_diffusion import sampling as k_diffusion_sampling from ldm_patched.modules.samplers import normal_scheduler, simple_scheduler, ddim_scheduler @@ -175,6 +176,9 @@ def calculate_sigmas_scheduler_hacked(model, scheduler_name, steps): sigmas = normal_scheduler(model, steps, sgm=True) elif scheduler_name == "turbo": sigmas = SDTurboScheduler().get_sigmas(namedtuple('Patcher', ['model'])(model=model), steps=steps, denoise=1.0)[0] + elif scheduler_name == "align_your_steps": + model_type = 'SDXL' if isinstance(model.latent_format, ldm_patched.modules.latent_formats.SDXL) else 'SD1' + sigmas = AlignYourStepsScheduler().get_sigmas(model_type=model_type, steps=steps, denoise=1.0)[0] else: raise TypeError("error invalid scheduler") return sigmas