diff --git a/docker.md b/docker.md index 36cfa632..1939d6fc 100644 --- a/docker.md +++ b/docker.md @@ -54,6 +54,7 @@ Docker specified environments are there. They are used by 'entrypoint.sh' |CMDARGS|Arguments for [entry_with_update.py](entry_with_update.py) which is called by [entrypoint.sh](entrypoint.sh)| |config_path|'config.txt' location| |config_example_path|'config_modification_tutorial.txt' location| +|HF_MIRROR| huggingface mirror site domain| You can also use the same json key names and values explained in the 'config_modification_tutorial.txt' as the environments. See examples in the [docker-compose.yml](docker-compose.yml) diff --git a/extras/vae_interpose.py b/extras/vae_interpose.py index 72fb09a4..d407ca83 100644 --- a/extras/vae_interpose.py +++ b/extras/vae_interpose.py @@ -1,69 +1,85 @@ # https://github.com/city96/SD-Latent-Interposer/blob/main/interposer.py import os -import torch -import safetensors.torch as sf -import torch.nn as nn -import ldm_patched.modules.model_management +import safetensors.torch as sf +import torch +import torch.nn as nn + +import ldm_patched.modules.model_management from ldm_patched.modules.model_patcher import ModelPatcher from modules.config import path_vae_approx -class Block(nn.Module): - def __init__(self, size): +class ResBlock(nn.Module): + """Block with residuals""" + + def __init__(self, ch): super().__init__() self.join = nn.ReLU() + self.norm = nn.BatchNorm2d(ch) self.long = nn.Sequential( - nn.Conv2d(size, size, kernel_size=3, stride=1, padding=1), - nn.LeakyReLU(0.1), - nn.Conv2d(size, size, kernel_size=3, stride=1, padding=1), - nn.LeakyReLU(0.1), - nn.Conv2d(size, size, kernel_size=3, stride=1, padding=1), + nn.Conv2d(ch, ch, kernel_size=3, stride=1, padding=1), + nn.SiLU(), + nn.Conv2d(ch, ch, kernel_size=3, stride=1, padding=1), + nn.SiLU(), + nn.Conv2d(ch, ch, kernel_size=3, stride=1, padding=1), + nn.Dropout(0.1) ) def forward(self, x): - y = self.long(x) - z = self.join(y + x) - return z + x = self.norm(x) + return self.join(self.long(x) + x) -class Interposer(nn.Module): - def __init__(self): +class ExtractBlock(nn.Module): + """Increase no. of channels by [out/in]""" + + def __init__(self, ch_in, ch_out): super().__init__() - self.chan = 4 - self.hid = 128 - - self.head_join = nn.ReLU() - self.head_short = nn.Conv2d(self.chan, self.hid, kernel_size=3, stride=1, padding=1) - self.head_long = nn.Sequential( - nn.Conv2d(self.chan, self.hid, kernel_size=3, stride=1, padding=1), - nn.LeakyReLU(0.1), - nn.Conv2d(self.hid, self.hid, kernel_size=3, stride=1, padding=1), - nn.LeakyReLU(0.1), - nn.Conv2d(self.hid, self.hid, kernel_size=3, stride=1, padding=1), - ) - self.core = nn.Sequential( - Block(self.hid), - Block(self.hid), - Block(self.hid), - ) - self.tail = nn.Sequential( - nn.ReLU(), - nn.Conv2d(self.hid, self.chan, kernel_size=3, stride=1, padding=1) + self.join = nn.ReLU() + self.short = nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1) + self.long = nn.Sequential( + nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1), + nn.SiLU(), + nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1), + nn.SiLU(), + nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1), + nn.Dropout(0.1) ) def forward(self, x): - y = self.head_join( - self.head_long(x) + - self.head_short(x) + return self.join(self.long(x) + self.short(x)) + + +class InterposerModel(nn.Module): + """Main neural network""" + + def __init__(self, ch_in=4, ch_out=4, ch_mid=64, scale=1.0, blocks=12): + super().__init__() + self.ch_in = ch_in + self.ch_out = ch_out + self.ch_mid = ch_mid + self.blocks = blocks + self.scale = scale + + self.head = ExtractBlock(self.ch_in, self.ch_mid) + self.core = nn.Sequential( + nn.Upsample(scale_factor=self.scale, mode="nearest"), + *[ResBlock(self.ch_mid) for _ in range(blocks)], + nn.BatchNorm2d(self.ch_mid), + nn.SiLU(), ) + self.tail = nn.Conv2d(self.ch_mid, self.ch_out, kernel_size=3, stride=1, padding=1) + + def forward(self, x): + y = self.head(x) z = self.core(y) return self.tail(z) vae_approx_model = None -vae_approx_filename = os.path.join(path_vae_approx, 'xl-to-v1_interposer-v3.1.safetensors') +vae_approx_filename = os.path.join(path_vae_approx, 'xl-to-v1_interposer-v4.0.safetensors') def parse(x): @@ -72,7 +88,7 @@ def parse(x): x_origin = x.clone() if vae_approx_model is None: - model = Interposer() + model = InterposerModel() model.eval() sd = sf.load_file(vae_approx_filename) model.load_state_dict(sd) diff --git a/javascript/script.js b/javascript/script.js index 9aa0b5c1..d379a783 100644 --- a/javascript/script.js +++ b/javascript/script.js @@ -122,6 +122,43 @@ document.addEventListener("DOMContentLoaded", function() { initStylePreviewOverlay(); }); +var onAppend = function(elem, f) { + var observer = new MutationObserver(function(mutations) { + mutations.forEach(function(m) { + if (m.addedNodes.length) { + f(m.addedNodes); + } + }); + }); + observer.observe(elem, {childList: true}); +} + +function addObserverIfDesiredNodeAvailable(querySelector, callback) { + var elem = document.querySelector(querySelector); + if (!elem) { + window.setTimeout(() => addObserverIfDesiredNodeAvailable(querySelector, callback), 1000); + return; + } + + onAppend(elem, callback); +} + +/** + * Show reset button on toast "Connection errored out." + */ +addObserverIfDesiredNodeAvailable(".toast-wrap", function(added) { + added.forEach(function(element) { + if (element.innerText.includes("Connection errored out.")) { + window.setTimeout(function() { + document.getElementById("reset_button").classList.remove("hidden"); + document.getElementById("generate_button").classList.add("hidden"); + document.getElementById("skip_button").classList.add("hidden"); + document.getElementById("stop_button").classList.add("hidden"); + }); + } + }); +}); + /** * Add a ctrl+enter as a shortcut to start a generation */ diff --git a/language/en.json b/language/en.json index 11816e78..eb2b6218 100644 --- a/language/en.json +++ b/language/en.json @@ -4,6 +4,7 @@ "Generate": "Generate", "Skip": "Skip", "Stop": "Stop", + "Reconnect and Reset UI": "Reconnect and Reset UI", "Input Image": "Input Image", "Advanced": "Advanced", "Upscale or Variation": "Upscale or Variation", diff --git a/launch.py b/launch.py index 79416761..e333e287 100644 --- a/launch.py +++ b/launch.py @@ -63,8 +63,8 @@ def prepare_environment(): vae_approx_filenames = [ ('xlvaeapp.pth', 'https://huggingface.co/lllyasviel/misc/resolve/main/xlvaeapp.pth'), ('vaeapp_sd15.pth', 'https://huggingface.co/lllyasviel/misc/resolve/main/vaeapp_sd15.pt'), - ('xl-to-v1_interposer-v3.1.safetensors', - 'https://huggingface.co/lllyasviel/misc/resolve/main/xl-to-v1_interposer-v3.1.safetensors') + ('xl-to-v1_interposer-v4.0.safetensors', + 'https://huggingface.co/mashb1t/misc/resolve/main/xl-to-v1_interposer-v4.0.safetensors') ] @@ -81,6 +81,10 @@ if args.gpu_device_id is not None: os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_device_id) print("Set device to:", args.gpu_device_id) +if args.hf_mirror is not None : + os.environ['HF_MIRROR'] = str(args.hf_mirror) + print("Set hf_mirror to:", args.hf_mirror) + from modules import config os.environ["U2NET_HOME"] = config.path_inpaint diff --git a/ldm_patched/modules/args_parser.py b/ldm_patched/modules/args_parser.py index 0c6165a7..bf873783 100644 --- a/ldm_patched/modules/args_parser.py +++ b/ldm_patched/modules/args_parser.py @@ -37,6 +37,7 @@ parser.add_argument("--listen", type=str, default="127.0.0.1", metavar="IP", nar parser.add_argument("--port", type=int, default=8188) parser.add_argument("--disable-header-check", type=str, default=None, metavar="ORIGIN", nargs="?", const="*") parser.add_argument("--web-upload-size", type=float, default=100) +parser.add_argument("--hf-mirror", type=str, default=None) parser.add_argument("--external-working-path", type=str, default=None, metavar="PATH", nargs='+', action='append') parser.add_argument("--output-path", type=str, default=None) diff --git a/modules/model_loader.py b/modules/model_loader.py index 8ba336a9..1143f75e 100644 --- a/modules/model_loader.py +++ b/modules/model_loader.py @@ -14,6 +14,8 @@ def load_file_from_url( Returns the path to the downloaded file. """ + domain = os.environ.get("HF_MIRROR", "https://huggingface.co").rstrip('/') + url = str.replace(url, "https://huggingface.co", domain, 1) os.makedirs(model_dir, exist_ok=True) if not file_name: parts = urlparse(url) diff --git a/readme.md b/readme.md index bde77fbd..22bef857 100644 --- a/readme.md +++ b/readme.md @@ -565,6 +565,7 @@ A safer way is just to try "run_anime.bat" or "run_realistic.bat" - they should entry_with_update.py [-h] [--listen [IP]] [--port PORT] [--disable-header-check [ORIGIN]] [--web-upload-size WEB_UPLOAD_SIZE] + [--hf-mirror HF_MIRROR] [--external-working-path PATH [PATH ...]] [--output-path OUTPUT_PATH] [--temp-path TEMP_PATH] [--cache-path CACHE_PATH] [--in-browser] diff --git a/webui.py b/webui.py index 7132608f..67738595 100644 --- a/webui.py +++ b/webui.py @@ -123,8 +123,9 @@ with shared.gradio_root: with gr.Column(scale=3, min_width=0): generate_button = gr.Button(label="Generate", value="Generate", elem_classes='type_row', elem_id='generate_button', visible=True) + reset_button = gr.Button(label="Reconnect and Reset UI", value="Reconnect and Reset UI", elem_classes='type_row', elem_id='reset_button', visible=False) load_parameter_button = gr.Button(label="Load Parameters", value="Load Parameters", elem_classes='type_row', elem_id='load_parameter_button', visible=False) - skip_button = gr.Button(label="Skip", value="Skip", elem_classes='type_row_half', visible=False) + skip_button = gr.Button(label="Skip", value="Skip", elem_classes='type_row_half', elem_id='skip_button', visible=False) stop_button = gr.Button(label="Stop", value="Stop", elem_classes='type_row_half', elem_id='stop_button', visible=False) def stop_clicked(currentTask): @@ -775,6 +776,14 @@ with shared.gradio_root: .then(fn=update_history_link, outputs=history_link) \ .then(fn=lambda: None, _js='playNotification').then(fn=lambda: None, _js='refresh_grid_delayed') + reset_button.click(lambda: [worker.AsyncTask(args=[]), False, gr.update(visible=True, interactive=True)] + + [gr.update(visible=False)] * 6 + + [gr.update(visible=True, value=[])], + outputs=[currentTask, state_is_generating, generate_button, + reset_button, stop_button, skip_button, + progress_html, progress_window, progress_gallery, gallery], + queue=False) + def trigger_describe(mode, img): if mode == flags.desc_type_photo: from extras.interrogate import default_interrogator as default_interrogator_photo