Merge branch 'develop' into feature/add-nsfw-filter
This commit is contained in:
commit
2d327bbd28
|
|
@ -31,6 +31,9 @@ args_parser.parser.add_argument("--disable-metadata", action='store_true',
|
|||
args_parser.parser.add_argument("--disable-preset-download", action='store_true',
|
||||
help="Disables downloading models for presets", default=False)
|
||||
|
||||
args_parser.parser.add_argument("--enable-describe-uov-image", action='store_true',
|
||||
help="Disables automatic description of uov images when prompt is empty", default=False)
|
||||
|
||||
args_parser.parser.add_argument("--always-download-new-model", action='store_true',
|
||||
help="Always download newer models ", default=False)
|
||||
|
||||
|
|
|
|||
|
|
@ -391,6 +391,6 @@ progress::after {
|
|||
background-color: #fff8;
|
||||
font-family: monospace;
|
||||
text-align: center;
|
||||
border-radius-top: 5px;
|
||||
border-radius: 5px 5px 0px 0px;
|
||||
display: none; /* remove this to enable tooltip in preview image */
|
||||
}
|
||||
|
|
@ -54,6 +54,7 @@ Docker specified environments are there. They are used by 'entrypoint.sh'
|
|||
|CMDARGS|Arguments for [entry_with_update.py](entry_with_update.py) which is called by [entrypoint.sh](entrypoint.sh)|
|
||||
|config_path|'config.txt' location|
|
||||
|config_example_path|'config_modification_tutorial.txt' location|
|
||||
|HF_MIRROR| huggingface mirror site domain|
|
||||
|
||||
You can also use the same json key names and values explained in the 'config_modification_tutorial.txt' as the environments.
|
||||
See examples in the [docker-compose.yml](docker-compose.yml)
|
||||
|
|
|
|||
|
|
@ -1,69 +1,85 @@
|
|||
# https://github.com/city96/SD-Latent-Interposer/blob/main/interposer.py
|
||||
|
||||
import os
|
||||
import torch
|
||||
import safetensors.torch as sf
|
||||
import torch.nn as nn
|
||||
import ldm_patched.modules.model_management
|
||||
|
||||
import safetensors.torch as sf
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
import ldm_patched.modules.model_management
|
||||
from ldm_patched.modules.model_patcher import ModelPatcher
|
||||
from modules.config import path_vae_approx
|
||||
|
||||
|
||||
class Block(nn.Module):
|
||||
def __init__(self, size):
|
||||
class ResBlock(nn.Module):
|
||||
"""Block with residuals"""
|
||||
|
||||
def __init__(self, ch):
|
||||
super().__init__()
|
||||
self.join = nn.ReLU()
|
||||
self.norm = nn.BatchNorm2d(ch)
|
||||
self.long = nn.Sequential(
|
||||
nn.Conv2d(size, size, kernel_size=3, stride=1, padding=1),
|
||||
nn.LeakyReLU(0.1),
|
||||
nn.Conv2d(size, size, kernel_size=3, stride=1, padding=1),
|
||||
nn.LeakyReLU(0.1),
|
||||
nn.Conv2d(size, size, kernel_size=3, stride=1, padding=1),
|
||||
nn.Conv2d(ch, ch, kernel_size=3, stride=1, padding=1),
|
||||
nn.SiLU(),
|
||||
nn.Conv2d(ch, ch, kernel_size=3, stride=1, padding=1),
|
||||
nn.SiLU(),
|
||||
nn.Conv2d(ch, ch, kernel_size=3, stride=1, padding=1),
|
||||
nn.Dropout(0.1)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
y = self.long(x)
|
||||
z = self.join(y + x)
|
||||
return z
|
||||
x = self.norm(x)
|
||||
return self.join(self.long(x) + x)
|
||||
|
||||
|
||||
class Interposer(nn.Module):
|
||||
def __init__(self):
|
||||
class ExtractBlock(nn.Module):
|
||||
"""Increase no. of channels by [out/in]"""
|
||||
|
||||
def __init__(self, ch_in, ch_out):
|
||||
super().__init__()
|
||||
self.chan = 4
|
||||
self.hid = 128
|
||||
|
||||
self.head_join = nn.ReLU()
|
||||
self.head_short = nn.Conv2d(self.chan, self.hid, kernel_size=3, stride=1, padding=1)
|
||||
self.head_long = nn.Sequential(
|
||||
nn.Conv2d(self.chan, self.hid, kernel_size=3, stride=1, padding=1),
|
||||
nn.LeakyReLU(0.1),
|
||||
nn.Conv2d(self.hid, self.hid, kernel_size=3, stride=1, padding=1),
|
||||
nn.LeakyReLU(0.1),
|
||||
nn.Conv2d(self.hid, self.hid, kernel_size=3, stride=1, padding=1),
|
||||
)
|
||||
self.core = nn.Sequential(
|
||||
Block(self.hid),
|
||||
Block(self.hid),
|
||||
Block(self.hid),
|
||||
)
|
||||
self.tail = nn.Sequential(
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(self.hid, self.chan, kernel_size=3, stride=1, padding=1)
|
||||
self.join = nn.ReLU()
|
||||
self.short = nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1)
|
||||
self.long = nn.Sequential(
|
||||
nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1),
|
||||
nn.SiLU(),
|
||||
nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1),
|
||||
nn.SiLU(),
|
||||
nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1),
|
||||
nn.Dropout(0.1)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
y = self.head_join(
|
||||
self.head_long(x) +
|
||||
self.head_short(x)
|
||||
return self.join(self.long(x) + self.short(x))
|
||||
|
||||
|
||||
class InterposerModel(nn.Module):
|
||||
"""Main neural network"""
|
||||
|
||||
def __init__(self, ch_in=4, ch_out=4, ch_mid=64, scale=1.0, blocks=12):
|
||||
super().__init__()
|
||||
self.ch_in = ch_in
|
||||
self.ch_out = ch_out
|
||||
self.ch_mid = ch_mid
|
||||
self.blocks = blocks
|
||||
self.scale = scale
|
||||
|
||||
self.head = ExtractBlock(self.ch_in, self.ch_mid)
|
||||
self.core = nn.Sequential(
|
||||
nn.Upsample(scale_factor=self.scale, mode="nearest"),
|
||||
*[ResBlock(self.ch_mid) for _ in range(blocks)],
|
||||
nn.BatchNorm2d(self.ch_mid),
|
||||
nn.SiLU(),
|
||||
)
|
||||
self.tail = nn.Conv2d(self.ch_mid, self.ch_out, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
def forward(self, x):
|
||||
y = self.head(x)
|
||||
z = self.core(y)
|
||||
return self.tail(z)
|
||||
|
||||
|
||||
vae_approx_model = None
|
||||
vae_approx_filename = os.path.join(path_vae_approx, 'xl-to-v1_interposer-v3.1.safetensors')
|
||||
vae_approx_filename = os.path.join(path_vae_approx, 'xl-to-v1_interposer-v4.0.safetensors')
|
||||
|
||||
|
||||
def parse(x):
|
||||
|
|
@ -72,7 +88,7 @@ def parse(x):
|
|||
x_origin = x.clone()
|
||||
|
||||
if vae_approx_model is None:
|
||||
model = Interposer()
|
||||
model = InterposerModel()
|
||||
model.eval()
|
||||
sd = sf.load_file(vae_approx_filename)
|
||||
model.load_state_dict(sd)
|
||||
|
|
|
|||
|
|
@ -122,6 +122,43 @@ document.addEventListener("DOMContentLoaded", function() {
|
|||
initStylePreviewOverlay();
|
||||
});
|
||||
|
||||
var onAppend = function(elem, f) {
|
||||
var observer = new MutationObserver(function(mutations) {
|
||||
mutations.forEach(function(m) {
|
||||
if (m.addedNodes.length) {
|
||||
f(m.addedNodes);
|
||||
}
|
||||
});
|
||||
});
|
||||
observer.observe(elem, {childList: true});
|
||||
}
|
||||
|
||||
function addObserverIfDesiredNodeAvailable(querySelector, callback) {
|
||||
var elem = document.querySelector(querySelector);
|
||||
if (!elem) {
|
||||
window.setTimeout(() => addObserverIfDesiredNodeAvailable(querySelector, callback), 1000);
|
||||
return;
|
||||
}
|
||||
|
||||
onAppend(elem, callback);
|
||||
}
|
||||
|
||||
/**
|
||||
* Show reset button on toast "Connection errored out."
|
||||
*/
|
||||
addObserverIfDesiredNodeAvailable(".toast-wrap", function(added) {
|
||||
added.forEach(function(element) {
|
||||
if (element.innerText.includes("Connection errored out.")) {
|
||||
window.setTimeout(function() {
|
||||
document.getElementById("reset_button").classList.remove("hidden");
|
||||
document.getElementById("generate_button").classList.add("hidden");
|
||||
document.getElementById("skip_button").classList.add("hidden");
|
||||
document.getElementById("stop_button").classList.add("hidden");
|
||||
});
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
/**
|
||||
* Add a ctrl+enter as a shortcut to start a generation
|
||||
*/
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@
|
|||
"Generate": "Generate",
|
||||
"Skip": "Skip",
|
||||
"Stop": "Stop",
|
||||
"Reconnect": "Reconnect",
|
||||
"Input Image": "Input Image",
|
||||
"Advanced": "Advanced",
|
||||
"Upscale or Variation": "Upscale or Variation",
|
||||
|
|
@ -59,6 +60,7 @@
|
|||
"\ud83d\udcda History Log": "\uD83D\uDCDA History Log",
|
||||
"Image Style": "Image Style",
|
||||
"Fooocus V2": "Fooocus V2",
|
||||
"Random Style": "Random Style",
|
||||
"Default (Slightly Cinematic)": "Default (Slightly Cinematic)",
|
||||
"Fooocus Masterpiece": "Fooocus Masterpiece",
|
||||
"Fooocus Photograph": "Fooocus Photograph",
|
||||
|
|
@ -341,6 +343,8 @@
|
|||
"sgm_uniform": "sgm_uniform",
|
||||
"simple": "simple",
|
||||
"ddim_uniform": "ddim_uniform",
|
||||
"VAE": "VAE",
|
||||
"Default (model)": "Default (model)",
|
||||
"Forced Overwrite of Sampling Step": "Forced Overwrite of Sampling Step",
|
||||
"Set as -1 to disable. For developer debugging.": "Set as -1 to disable. For developer debugging.",
|
||||
"Forced Overwrite of Refiner Switch Step": "Forced Overwrite of Refiner Switch Step",
|
||||
|
|
|
|||
|
|
@ -62,8 +62,8 @@ def prepare_environment():
|
|||
vae_approx_filenames = [
|
||||
('xlvaeapp.pth', 'https://huggingface.co/lllyasviel/misc/resolve/main/xlvaeapp.pth'),
|
||||
('vaeapp_sd15.pth', 'https://huggingface.co/lllyasviel/misc/resolve/main/vaeapp_sd15.pt'),
|
||||
('xl-to-v1_interposer-v3.1.safetensors',
|
||||
'https://huggingface.co/lllyasviel/misc/resolve/main/xl-to-v1_interposer-v3.1.safetensors')
|
||||
('xl-to-v1_interposer-v4.0.safetensors',
|
||||
'https://huggingface.co/mashb1t/misc/resolve/main/xl-to-v1_interposer-v4.0.safetensors')
|
||||
]
|
||||
|
||||
|
||||
|
|
@ -80,6 +80,10 @@ if args.gpu_device_id is not None:
|
|||
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_device_id)
|
||||
print("Set device to:", args.gpu_device_id)
|
||||
|
||||
if args.hf_mirror is not None :
|
||||
os.environ['HF_MIRROR'] = str(args.hf_mirror)
|
||||
print("Set hf_mirror to:", args.hf_mirror)
|
||||
|
||||
from modules import config
|
||||
|
||||
os.environ['GRADIO_TEMP_DIR'] = config.temp_path
|
||||
|
|
|
|||
|
|
@ -37,6 +37,7 @@ parser.add_argument("--listen", type=str, default="127.0.0.1", metavar="IP", nar
|
|||
parser.add_argument("--port", type=int, default=8188)
|
||||
parser.add_argument("--disable-header-check", type=str, default=None, metavar="ORIGIN", nargs="?", const="*")
|
||||
parser.add_argument("--web-upload-size", type=float, default=100)
|
||||
parser.add_argument("--hf-mirror", type=str, default=None)
|
||||
|
||||
parser.add_argument("--external-working-path", type=str, default=None, metavar="PATH", nargs='+', action='append')
|
||||
parser.add_argument("--output-path", type=str, default=None)
|
||||
|
|
|
|||
|
|
@ -427,12 +427,13 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
|
|||
|
||||
return (ldm_patched.modules.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device), clip, vae)
|
||||
|
||||
def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True):
|
||||
def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, vae_filename_param=None):
|
||||
sd = ldm_patched.modules.utils.load_torch_file(ckpt_path)
|
||||
sd_keys = sd.keys()
|
||||
clip = None
|
||||
clipvision = None
|
||||
vae = None
|
||||
vae_filename = None
|
||||
model = None
|
||||
model_patcher = None
|
||||
clip_target = None
|
||||
|
|
@ -462,8 +463,12 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
|
|||
model.load_model_weights(sd, "model.diffusion_model.")
|
||||
|
||||
if output_vae:
|
||||
vae_sd = ldm_patched.modules.utils.state_dict_prefix_replace(sd, {"first_stage_model.": ""}, filter_keys=True)
|
||||
vae_sd = model_config.process_vae_state_dict(vae_sd)
|
||||
if vae_filename_param is None:
|
||||
vae_sd = ldm_patched.modules.utils.state_dict_prefix_replace(sd, {"first_stage_model.": ""}, filter_keys=True)
|
||||
vae_sd = model_config.process_vae_state_dict(vae_sd)
|
||||
else:
|
||||
vae_sd = ldm_patched.modules.utils.load_torch_file(vae_filename_param)
|
||||
vae_filename = vae_filename_param
|
||||
vae = VAE(sd=vae_sd)
|
||||
|
||||
if output_clip:
|
||||
|
|
@ -485,7 +490,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
|
|||
print("loaded straight to GPU")
|
||||
model_management.load_model_gpu(model_patcher)
|
||||
|
||||
return (model_patcher, clip, vae, clipvision)
|
||||
return model_patcher, clip, vae, vae_filename, clipvision
|
||||
|
||||
|
||||
def load_unet_state_dict(sd): #load unet in diffusers format
|
||||
|
|
|
|||
|
|
@ -44,7 +44,7 @@ def worker():
|
|||
import args_manager
|
||||
|
||||
from extras.censor import censor_batch, censor_single
|
||||
from modules.sdxl_styles import apply_style, apply_wildcards, fooocus_expansion, apply_arrays
|
||||
from modules.sdxl_styles import apply_style, get_random_style, apply_wildcards, fooocus_expansion, apply_arrays, random_style_name
|
||||
from modules.private_logger import log
|
||||
from extras.expansion import safe_str
|
||||
from modules.util import remove_empty_str, HWC3, resize_image, get_image_shape_ceil, set_image_shape_ceil, \
|
||||
|
|
@ -172,6 +172,7 @@ def worker():
|
|||
adaptive_cfg = args.pop()
|
||||
sampler_name = args.pop()
|
||||
scheduler_name = args.pop()
|
||||
vae_name = args.pop()
|
||||
overwrite_step = args.pop()
|
||||
overwrite_switch = args.pop()
|
||||
overwrite_width = args.pop()
|
||||
|
|
@ -434,7 +435,7 @@ def worker():
|
|||
progressbar(async_task, 3, 'Loading models ...')
|
||||
pipeline.refresh_everything(refiner_model_name=refiner_model_name, base_model_name=base_model_name,
|
||||
loras=loras, base_model_additional_loras=base_model_additional_loras,
|
||||
use_synthetic_refiner=use_synthetic_refiner)
|
||||
use_synthetic_refiner=use_synthetic_refiner, vae_name=vae_name)
|
||||
|
||||
progressbar(async_task, 3, 'Processing prompts ...')
|
||||
tasks = []
|
||||
|
|
@ -455,8 +456,12 @@ def worker():
|
|||
positive_basic_workloads = []
|
||||
negative_basic_workloads = []
|
||||
|
||||
task_styles = style_selections.copy()
|
||||
if use_style:
|
||||
for s in style_selections:
|
||||
for i, s in enumerate(task_styles):
|
||||
if s == random_style_name:
|
||||
s = get_random_style(task_rng)
|
||||
task_styles[i] = s
|
||||
p, n = apply_style(s, positive=task_prompt)
|
||||
positive_basic_workloads = positive_basic_workloads + p
|
||||
negative_basic_workloads = negative_basic_workloads + n
|
||||
|
|
@ -484,6 +489,7 @@ def worker():
|
|||
negative_top_k=len(negative_basic_workloads),
|
||||
log_positive_prompt='\n'.join([task_prompt] + task_extra_positive_prompts),
|
||||
log_negative_prompt='\n'.join([task_negative_prompt] + task_extra_negative_prompts),
|
||||
styles=task_styles
|
||||
))
|
||||
|
||||
if use_expansion:
|
||||
|
|
@ -856,7 +862,7 @@ def worker():
|
|||
d = [('Prompt', 'prompt', task['log_positive_prompt']),
|
||||
('Negative Prompt', 'negative_prompt', task['log_negative_prompt']),
|
||||
('Fooocus V2 Expansion', 'prompt_expansion', task['expansion']),
|
||||
('Styles', 'styles', str(raw_style_selections)),
|
||||
('Styles', 'styles', str(task['styles'] if not use_expansion else [fooocus_expansion] + task['styles'])),
|
||||
('Performance', 'performance', performance_selection.value)]
|
||||
|
||||
if performance_selection.steps() != steps:
|
||||
|
|
@ -883,6 +889,7 @@ def worker():
|
|||
|
||||
d.append(('Sampler', 'sampler', sampler_name))
|
||||
d.append(('Scheduler', 'scheduler', scheduler_name))
|
||||
d.append(('VAE', 'vae', vae_name))
|
||||
d.append(('Seed', 'seed', str(task['task_seed'])))
|
||||
|
||||
if freeu_enabled:
|
||||
|
|
@ -897,10 +904,10 @@ def worker():
|
|||
metadata_parser = modules.meta_parser.get_metadata_parser(metadata_scheme)
|
||||
metadata_parser.set_data(task['log_positive_prompt'], task['positive'],
|
||||
task['log_negative_prompt'], task['negative'],
|
||||
steps, base_model_name, refiner_model_name, loras)
|
||||
steps, base_model_name, refiner_model_name, loras, vae_name)
|
||||
d.append(('Metadata Scheme', 'metadata_scheme', metadata_scheme.value if save_metadata_to_images else save_metadata_to_images))
|
||||
d.append(('Version', 'version', 'Fooocus v' + fooocus_version.version))
|
||||
img_paths.append(log(x, d, metadata_parser, output_format))
|
||||
img_paths.append(log(x, d, metadata_parser, output_format, task))
|
||||
|
||||
yield_result(async_task, img_paths, black_out_nsfw, False, do_not_show_finished_images=len(tasks) == 1 or disable_intermediate_results)
|
||||
except ldm_patched.modules.model_management.InterruptProcessingException as e:
|
||||
|
|
|
|||
|
|
@ -189,6 +189,7 @@ paths_checkpoints = get_dir_or_set_default('path_checkpoints', ['../models/check
|
|||
paths_loras = get_dir_or_set_default('path_loras', ['../models/loras/'], True)
|
||||
path_embeddings = get_dir_or_set_default('path_embeddings', '../models/embeddings/')
|
||||
path_vae_approx = get_dir_or_set_default('path_vae_approx', '../models/vae_approx/')
|
||||
path_vae = get_dir_or_set_default('path_vae', '../models/vae/')
|
||||
path_upscale_models = get_dir_or_set_default('path_upscale_models', '../models/upscale_models/')
|
||||
path_inpaint = get_dir_or_set_default('path_inpaint', '../models/inpaint/')
|
||||
path_controlnet = get_dir_or_set_default('path_controlnet', '../models/controlnet/')
|
||||
|
|
@ -347,6 +348,11 @@ default_scheduler = get_config_item_or_set_default(
|
|||
default_value='karras',
|
||||
validator=lambda x: x in modules.flags.scheduler_list
|
||||
)
|
||||
default_vae = get_config_item_or_set_default(
|
||||
key='default_vae',
|
||||
default_value=modules.flags.default_vae,
|
||||
validator=lambda x: isinstance(x, str)
|
||||
)
|
||||
default_styles = get_config_item_or_set_default(
|
||||
key='default_styles',
|
||||
default_value=[
|
||||
|
|
@ -541,6 +547,7 @@ with open(config_example_path, "w", encoding="utf-8") as json_file:
|
|||
|
||||
model_filenames = []
|
||||
lora_filenames = []
|
||||
vae_filenames = []
|
||||
wildcard_filenames = []
|
||||
|
||||
sdxl_lcm_lora = 'sdxl_lcm_lora.safetensors'
|
||||
|
|
@ -552,15 +559,20 @@ def get_model_filenames(folder_paths, extensions=None, name_filter=None):
|
|||
if extensions is None:
|
||||
extensions = ['.pth', '.ckpt', '.bin', '.safetensors', '.fooocus.patch']
|
||||
files = []
|
||||
|
||||
if not isinstance(folder_paths, list):
|
||||
folder_paths = [folder_paths]
|
||||
for folder in folder_paths:
|
||||
files += get_files_from_folder(folder, extensions, name_filter)
|
||||
|
||||
return files
|
||||
|
||||
|
||||
def update_files():
|
||||
global model_filenames, lora_filenames, wildcard_filenames, available_presets
|
||||
global model_filenames, lora_filenames, vae_filenames, wildcard_filenames, available_presets
|
||||
model_filenames = get_model_filenames(paths_checkpoints)
|
||||
lora_filenames = get_model_filenames(paths_loras)
|
||||
vae_filenames = get_model_filenames(path_vae)
|
||||
wildcard_filenames = get_files_from_folder(path_wildcards, ['.txt'])
|
||||
available_presets = get_presets()
|
||||
return
|
||||
|
|
|
|||
|
|
@ -35,12 +35,13 @@ opModelSamplingDiscrete = ModelSamplingDiscrete()
|
|||
|
||||
|
||||
class StableDiffusionModel:
|
||||
def __init__(self, unet=None, vae=None, clip=None, clip_vision=None, filename=None):
|
||||
def __init__(self, unet=None, vae=None, clip=None, clip_vision=None, filename=None, vae_filename=None):
|
||||
self.unet = unet
|
||||
self.vae = vae
|
||||
self.clip = clip
|
||||
self.clip_vision = clip_vision
|
||||
self.filename = filename
|
||||
self.vae_filename = vae_filename
|
||||
self.unet_with_lora = unet
|
||||
self.clip_with_lora = clip
|
||||
self.visited_loras = ''
|
||||
|
|
@ -142,9 +143,10 @@ def apply_controlnet(positive, negative, control_net, image, strength, start_per
|
|||
|
||||
@torch.no_grad()
|
||||
@torch.inference_mode()
|
||||
def load_model(ckpt_filename):
|
||||
unet, clip, vae, clip_vision = load_checkpoint_guess_config(ckpt_filename, embedding_directory=path_embeddings)
|
||||
return StableDiffusionModel(unet=unet, clip=clip, vae=vae, clip_vision=clip_vision, filename=ckpt_filename)
|
||||
def load_model(ckpt_filename, vae_filename=None):
|
||||
unet, clip, vae, vae_filename, clip_vision = load_checkpoint_guess_config(ckpt_filename, embedding_directory=path_embeddings,
|
||||
vae_filename_param=vae_filename)
|
||||
return StableDiffusionModel(unet=unet, clip=clip, vae=vae, clip_vision=clip_vision, filename=ckpt_filename, vae_filename=vae_filename)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ import os
|
|||
import torch
|
||||
import modules.patch
|
||||
import modules.config
|
||||
import modules.flags
|
||||
import ldm_patched.modules.model_management
|
||||
import ldm_patched.modules.latent_formats
|
||||
import modules.inpaint_worker
|
||||
|
|
@ -58,17 +59,21 @@ def assert_model_integrity():
|
|||
|
||||
@torch.no_grad()
|
||||
@torch.inference_mode()
|
||||
def refresh_base_model(name):
|
||||
def refresh_base_model(name, vae_name=None):
|
||||
global model_base
|
||||
|
||||
filename = get_file_from_folder_list(name, modules.config.paths_checkpoints)
|
||||
|
||||
if model_base.filename == filename:
|
||||
vae_filename = None
|
||||
if vae_name is not None and vae_name != modules.flags.default_vae:
|
||||
vae_filename = get_file_from_folder_list(vae_name, modules.config.path_vae)
|
||||
|
||||
if model_base.filename == filename and model_base.vae_filename == vae_filename:
|
||||
return
|
||||
|
||||
model_base = core.StableDiffusionModel()
|
||||
model_base = core.load_model(filename)
|
||||
model_base = core.load_model(filename, vae_filename)
|
||||
print(f'Base model loaded: {model_base.filename}')
|
||||
print(f'VAE loaded: {model_base.vae_filename}')
|
||||
return
|
||||
|
||||
|
||||
|
|
@ -216,7 +221,7 @@ def prepare_text_encoder(async_call=True):
|
|||
@torch.no_grad()
|
||||
@torch.inference_mode()
|
||||
def refresh_everything(refiner_model_name, base_model_name, loras,
|
||||
base_model_additional_loras=None, use_synthetic_refiner=False):
|
||||
base_model_additional_loras=None, use_synthetic_refiner=False, vae_name=None):
|
||||
global final_unet, final_clip, final_vae, final_refiner_unet, final_refiner_vae, final_expansion
|
||||
|
||||
final_unet = None
|
||||
|
|
@ -227,11 +232,11 @@ def refresh_everything(refiner_model_name, base_model_name, loras,
|
|||
|
||||
if use_synthetic_refiner and refiner_model_name == 'None':
|
||||
print('Synthetic Refiner Activated')
|
||||
refresh_base_model(base_model_name)
|
||||
refresh_base_model(base_model_name, vae_name)
|
||||
synthesize_refiner_model()
|
||||
else:
|
||||
refresh_refiner_model(refiner_model_name)
|
||||
refresh_base_model(base_model_name)
|
||||
refresh_base_model(base_model_name, vae_name)
|
||||
|
||||
refresh_loras(loras, base_model_additional_loras=base_model_additional_loras)
|
||||
assert_model_integrity()
|
||||
|
|
@ -254,7 +259,8 @@ def refresh_everything(refiner_model_name, base_model_name, loras,
|
|||
refresh_everything(
|
||||
refiner_model_name=modules.config.default_refiner_model_name,
|
||||
base_model_name=modules.config.default_base_model_name,
|
||||
loras=get_enabled_loras(modules.config.default_loras)
|
||||
loras=get_enabled_loras(modules.config.default_loras),
|
||||
vae_name=modules.config.default_vae,
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -53,6 +53,8 @@ SAMPLER_NAMES = KSAMPLER_NAMES + list(SAMPLER_EXTRA.keys())
|
|||
sampler_list = SAMPLER_NAMES
|
||||
scheduler_list = SCHEDULER_NAMES
|
||||
|
||||
default_vae = 'Default (model)'
|
||||
|
||||
refiner_swap_method = 'joint'
|
||||
|
||||
cn_ip = "ImagePrompt"
|
||||
|
|
|
|||
|
|
@ -46,6 +46,7 @@ def load_parameter_button_click(raw_metadata: dict | str, is_generating: bool):
|
|||
get_float('refiner_switch', 'Refiner Switch', loaded_parameter_dict, results)
|
||||
get_str('sampler', 'Sampler', loaded_parameter_dict, results)
|
||||
get_str('scheduler', 'Scheduler', loaded_parameter_dict, results)
|
||||
get_str('vae', 'VAE', loaded_parameter_dict, results)
|
||||
get_seed('seed', 'Seed', loaded_parameter_dict, results)
|
||||
|
||||
if is_generating:
|
||||
|
|
@ -253,6 +254,7 @@ class MetadataParser(ABC):
|
|||
self.refiner_model_name: str = ''
|
||||
self.refiner_model_hash: str = ''
|
||||
self.loras: list = []
|
||||
self.vae_name: str = ''
|
||||
|
||||
@abstractmethod
|
||||
def get_scheme(self) -> MetadataScheme:
|
||||
|
|
@ -267,7 +269,7 @@ class MetadataParser(ABC):
|
|||
raise NotImplementedError
|
||||
|
||||
def set_data(self, raw_prompt, full_prompt, raw_negative_prompt, full_negative_prompt, steps, base_model_name,
|
||||
refiner_model_name, loras):
|
||||
refiner_model_name, loras, vae_name):
|
||||
self.raw_prompt = raw_prompt
|
||||
self.full_prompt = full_prompt
|
||||
self.raw_negative_prompt = raw_negative_prompt
|
||||
|
|
@ -289,6 +291,7 @@ class MetadataParser(ABC):
|
|||
lora_path = get_file_from_folder_list(lora_name, modules.config.paths_loras)
|
||||
lora_hash = get_sha256(lora_path)
|
||||
self.loras.append((Path(lora_name).stem, lora_weight, lora_hash))
|
||||
self.vae_name = Path(vae_name).stem
|
||||
|
||||
@staticmethod
|
||||
def remove_special_loras(lora_filenames):
|
||||
|
|
@ -310,6 +313,7 @@ class A1111MetadataParser(MetadataParser):
|
|||
'steps': 'Steps',
|
||||
'sampler': 'Sampler',
|
||||
'scheduler': 'Scheduler',
|
||||
'vae': 'VAE',
|
||||
'guidance_scale': 'CFG scale',
|
||||
'seed': 'Seed',
|
||||
'resolution': 'Size',
|
||||
|
|
@ -397,13 +401,12 @@ class A1111MetadataParser(MetadataParser):
|
|||
data['sampler'] = k
|
||||
break
|
||||
|
||||
for key in ['base_model', 'refiner_model']:
|
||||
for key in ['base_model', 'refiner_model', 'vae']:
|
||||
if key in data:
|
||||
for filename in modules.config.model_filenames:
|
||||
path = Path(filename)
|
||||
if data[key] == path.stem:
|
||||
data[key] = filename
|
||||
break
|
||||
if key == 'vae':
|
||||
self.add_extension_to_filename(data, modules.config.vae_filenames, 'vae')
|
||||
else:
|
||||
self.add_extension_to_filename(data, modules.config.model_filenames, key)
|
||||
|
||||
lora_data = ''
|
||||
if 'lora_weights' in data and data['lora_weights'] != '':
|
||||
|
|
@ -433,6 +436,7 @@ class A1111MetadataParser(MetadataParser):
|
|||
|
||||
sampler = data['sampler']
|
||||
scheduler = data['scheduler']
|
||||
|
||||
if sampler in SAMPLERS and SAMPLERS[sampler] != '':
|
||||
sampler = SAMPLERS[sampler]
|
||||
if sampler not in CIVITAI_NO_KARRAS and scheduler == 'karras':
|
||||
|
|
@ -451,6 +455,7 @@ class A1111MetadataParser(MetadataParser):
|
|||
|
||||
self.fooocus_to_a1111['performance']: data['performance'],
|
||||
self.fooocus_to_a1111['scheduler']: scheduler,
|
||||
self.fooocus_to_a1111['vae']: Path(data['vae']).stem,
|
||||
# workaround for multiline prompts
|
||||
self.fooocus_to_a1111['raw_prompt']: self.raw_prompt,
|
||||
self.fooocus_to_a1111['raw_negative_prompt']: self.raw_negative_prompt,
|
||||
|
|
@ -491,6 +496,14 @@ class A1111MetadataParser(MetadataParser):
|
|||
negative_prompt_text = f"\nNegative prompt: {negative_prompt_resolved}" if negative_prompt_resolved else ""
|
||||
return f"{positive_prompt_resolved}{negative_prompt_text}\n{generation_params_text}".strip()
|
||||
|
||||
@staticmethod
|
||||
def add_extension_to_filename(data, filenames, key):
|
||||
for filename in filenames:
|
||||
path = Path(filename)
|
||||
if data[key] == path.stem:
|
||||
data[key] = filename
|
||||
break
|
||||
|
||||
|
||||
class FooocusMetadataParser(MetadataParser):
|
||||
def get_scheme(self) -> MetadataScheme:
|
||||
|
|
@ -499,6 +512,7 @@ class FooocusMetadataParser(MetadataParser):
|
|||
def parse_json(self, metadata: dict) -> dict:
|
||||
model_filenames = modules.config.model_filenames.copy()
|
||||
lora_filenames = modules.config.lora_filenames.copy()
|
||||
vae_filenames = modules.config.vae_filenames.copy()
|
||||
self.remove_special_loras(lora_filenames)
|
||||
for key, value in metadata.items():
|
||||
if value in ['', 'None']:
|
||||
|
|
@ -507,6 +521,8 @@ class FooocusMetadataParser(MetadataParser):
|
|||
metadata[key] = self.replace_value_with_filename(key, value, model_filenames)
|
||||
elif key.startswith('lora_combined_'):
|
||||
metadata[key] = self.replace_value_with_filename(key, value, lora_filenames)
|
||||
elif key == 'vae':
|
||||
metadata[key] = self.replace_value_with_filename(key, value, vae_filenames)
|
||||
else:
|
||||
continue
|
||||
|
||||
|
|
@ -533,6 +549,7 @@ class FooocusMetadataParser(MetadataParser):
|
|||
res['refiner_model'] = self.refiner_model_name
|
||||
res['refiner_model_hash'] = self.refiner_model_hash
|
||||
|
||||
res['vae'] = self.vae_name
|
||||
res['loras'] = self.loras
|
||||
|
||||
if modules.config.metadata_created_by != '':
|
||||
|
|
|
|||
|
|
@ -14,6 +14,8 @@ def load_file_from_url(
|
|||
|
||||
Returns the path to the downloaded file.
|
||||
"""
|
||||
domain = os.environ.get("HF_MIRROR", "https://huggingface.co").rstrip('/')
|
||||
url = str.replace(url, "https://huggingface.co", domain, 1)
|
||||
os.makedirs(model_dir, exist_ok=True)
|
||||
if not file_name:
|
||||
parts = urlparse(url)
|
||||
|
|
|
|||
|
|
@ -21,7 +21,7 @@ def get_current_html_path(output_format=None):
|
|||
return html_name
|
||||
|
||||
|
||||
def log(img, metadata, metadata_parser: MetadataParser | None = None, output_format=None) -> str:
|
||||
def log(img, metadata, metadata_parser: MetadataParser | None = None, output_format=None, task=None) -> str:
|
||||
path_outputs = modules.config.temp_path if args_manager.args.disable_image_log else modules.config.path_outputs
|
||||
output_format = output_format if output_format else modules.config.default_output_format
|
||||
date_string, local_temp_filename, only_name = generate_temp_filename(folder=path_outputs, extension=output_format)
|
||||
|
|
@ -111,9 +111,15 @@ def log(img, metadata, metadata_parser: MetadataParser | None = None, output_for
|
|||
for label, key, value in metadata:
|
||||
value_txt = str(value).replace('\n', ' </br> ')
|
||||
item += f"<tr><td class='label'>{label}</td><td class='value'>{value_txt}</td></tr>\n"
|
||||
|
||||
if task is not None and 'positive' in task and 'negative' in task:
|
||||
full_prompt_details = f"""<details><summary>Positive</summary>{', '.join(task['positive'])}</details>
|
||||
<details><summary>Negative</summary>{', '.join(task['negative'])}</details>"""
|
||||
item += f"<tr><td class='label'>Full raw prompt</td><td class='value'>{full_prompt_details}</td></tr>\n"
|
||||
|
||||
item += "</table>"
|
||||
|
||||
js_txt = urllib.parse.quote(json.dumps({k: v for _, k, v in metadata}, indent=0), safe='')
|
||||
js_txt = urllib.parse.quote(json.dumps({k: v for _, k, v, in metadata}, indent=0), safe='')
|
||||
item += f"</br><button onclick=\"to_clipboard('{js_txt}')\">Copy to Clipboard</button>"
|
||||
|
||||
item += "</td>"
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ import math
|
|||
import modules.config
|
||||
|
||||
from modules.util import get_files_from_folder
|
||||
from random import Random
|
||||
|
||||
# cannot use modules.config - validators causing circular imports
|
||||
styles_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '../sdxl_styles/'))
|
||||
|
|
@ -50,8 +51,13 @@ for styles_file in styles_files:
|
|||
print(f'Failed to load style file {styles_file}')
|
||||
|
||||
style_keys = list(styles.keys())
|
||||
fooocus_expansion = "Fooocus V2"
|
||||
legal_style_names = [fooocus_expansion] + style_keys
|
||||
fooocus_expansion = 'Fooocus V2'
|
||||
random_style_name = 'Random Style'
|
||||
legal_style_names = [fooocus_expansion, random_style_name] + style_keys
|
||||
|
||||
|
||||
def get_random_style(rng: Random) -> str:
|
||||
return rng.choice(list(styles.items()))[0]
|
||||
|
||||
|
||||
def apply_style(style, positive):
|
||||
|
|
|
|||
|
|
@ -39,7 +39,7 @@ def javascript_html():
|
|||
head += f'<script type="text/javascript" src="{edit_attention_js_path}"></script>\n'
|
||||
head += f'<script type="text/javascript" src="{viewer_js_path}"></script>\n'
|
||||
head += f'<script type="text/javascript" src="{image_viewer_js_path}"></script>\n'
|
||||
head += f'<meta name="samples-path" content="{samples_path}"></meta>\n'
|
||||
head += f'<meta name="samples-path" content="{samples_path}">\n'
|
||||
|
||||
if args_manager.args.theme:
|
||||
head += f'<script type="text/javascript">set_theme(\"{args_manager.args.theme}\");</script>\n'
|
||||
|
|
|
|||
|
|
@ -371,6 +371,9 @@ def is_json(data: str) -> bool:
|
|||
|
||||
|
||||
def get_file_from_folder_list(name, folders):
|
||||
if not isinstance(folders, list):
|
||||
folders = [folders]
|
||||
|
||||
for folder in folders:
|
||||
filename = os.path.abspath(os.path.realpath(os.path.join(folder, name)))
|
||||
if os.path.isfile(filename):
|
||||
|
|
|
|||
|
|
@ -368,6 +368,7 @@ A safer way is just to try "run_anime.bat" or "run_realistic.bat" - they should
|
|||
entry_with_update.py [-h] [--listen [IP]] [--port PORT]
|
||||
[--disable-header-check [ORIGIN]]
|
||||
[--web-upload-size WEB_UPLOAD_SIZE]
|
||||
[--hf-mirror HF_MIRROR]
|
||||
[--external-working-path PATH [PATH ...]]
|
||||
[--output-path OUTPUT_PATH] [--temp-path TEMP_PATH]
|
||||
[--cache-path CACHE_PATH] [--in-browser]
|
||||
|
|
|
|||
Binary file not shown.
|
After Width: | Height: | Size: 1.4 KiB |
31
webui.py
31
webui.py
|
|
@ -123,8 +123,9 @@ with shared.gradio_root:
|
|||
|
||||
with gr.Column(scale=3, min_width=0):
|
||||
generate_button = gr.Button(label="Generate", value="Generate", elem_classes='type_row', elem_id='generate_button', visible=True)
|
||||
reset_button = gr.Button(label="Reconnect", value="Reconnect", elem_classes='type_row', elem_id='reset_button', visible=False)
|
||||
load_parameter_button = gr.Button(label="Load Parameters", value="Load Parameters", elem_classes='type_row', elem_id='load_parameter_button', visible=False)
|
||||
skip_button = gr.Button(label="Skip", value="Skip", elem_classes='type_row_half', visible=False)
|
||||
skip_button = gr.Button(label="Skip", value="Skip", elem_classes='type_row_half', elem_id='skip_button', visible=False)
|
||||
stop_button = gr.Button(label="Stop", value="Stop", elem_classes='type_row_half', elem_id='stop_button', visible=False)
|
||||
|
||||
def stop_clicked(currentTask):
|
||||
|
|
@ -406,6 +407,8 @@ with shared.gradio_root:
|
|||
value=modules.config.default_sampler)
|
||||
scheduler_name = gr.Dropdown(label='Scheduler', choices=flags.scheduler_list,
|
||||
value=modules.config.default_scheduler)
|
||||
vae_name = gr.Dropdown(label='VAE', choices=[modules.flags.default_vae] + modules.config.vae_filenames,
|
||||
value=modules.config.default_vae, show_label=True)
|
||||
|
||||
generate_image_grid = gr.Checkbox(label='Generate Image Grid for Each Batch',
|
||||
info='(Experimental) This may cause performance problems on some computers and certain internet conditions.',
|
||||
|
|
@ -538,6 +541,7 @@ with shared.gradio_root:
|
|||
modules.config.update_files()
|
||||
results = [gr.update(choices=modules.config.model_filenames)]
|
||||
results += [gr.update(choices=['None'] + modules.config.model_filenames)]
|
||||
results += [gr.update(choices=['None'] + modules.config.vae_filenames)]
|
||||
if not args_manager.args.disable_preset_selection:
|
||||
results += [gr.update(choices=modules.config.available_presets)]
|
||||
for i in range(modules.config.default_max_lora_number):
|
||||
|
|
@ -545,7 +549,7 @@ with shared.gradio_root:
|
|||
gr.update(choices=['None'] + modules.config.lora_filenames), gr.update()]
|
||||
return results
|
||||
|
||||
refresh_files_output = [base_model, refiner_model]
|
||||
refresh_files_output = [base_model, refiner_model, vae_name]
|
||||
if not args_manager.args.disable_preset_selection:
|
||||
refresh_files_output += [preset_selection]
|
||||
refresh_files.click(refresh_files_clicked, [], refresh_files_output + lora_ctrls,
|
||||
|
|
@ -557,8 +561,8 @@ with shared.gradio_root:
|
|||
performance_selection, overwrite_step, overwrite_switch, aspect_ratios_selection,
|
||||
overwrite_width, overwrite_height, guidance_scale, sharpness, adm_scaler_positive,
|
||||
adm_scaler_negative, adm_scaler_end, refiner_swap_method, adaptive_cfg, base_model,
|
||||
refiner_model, refiner_switch, sampler_name, scheduler_name, seed_random, image_seed,
|
||||
generate_button, load_parameter_button] + freeu_ctrls + lora_ctrls
|
||||
refiner_model, refiner_switch, sampler_name, scheduler_name, vae_name, seed_random,
|
||||
image_seed, generate_button, load_parameter_button] + freeu_ctrls + lora_ctrls
|
||||
|
||||
if not args_manager.args.disable_preset_selection:
|
||||
def preset_selection_change(preset, is_generating):
|
||||
|
|
@ -644,7 +648,7 @@ with shared.gradio_root:
|
|||
ctrls += [outpaint_selections, inpaint_input_image, inpaint_additional_prompt, inpaint_mask_image]
|
||||
ctrls += [disable_preview, disable_intermediate_results, disable_seed_increment, black_out_nsfw]
|
||||
ctrls += [adm_scaler_positive, adm_scaler_negative, adm_scaler_end, adaptive_cfg]
|
||||
ctrls += [sampler_name, scheduler_name]
|
||||
ctrls += [sampler_name, scheduler_name, vae_name]
|
||||
ctrls += [overwrite_step, overwrite_switch, overwrite_width, overwrite_height, overwrite_vary_strength]
|
||||
ctrls += [overwrite_upscale_strength, mixing_image_prompt_and_vary_upscale, mixing_image_prompt_and_inpaint]
|
||||
ctrls += [debugging_cn_preprocessor, skipping_cn_preprocessor, canny_low_threshold, canny_high_threshold]
|
||||
|
|
@ -698,6 +702,14 @@ with shared.gradio_root:
|
|||
.then(fn=update_history_link, outputs=history_link) \
|
||||
.then(fn=lambda: None, _js='playNotification').then(fn=lambda: None, _js='refresh_grid_delayed')
|
||||
|
||||
reset_button.click(lambda: [worker.AsyncTask(args=[]), False, gr.update(visible=True, interactive=True)] +
|
||||
[gr.update(visible=False)] * 6 +
|
||||
[gr.update(visible=True, value=[])],
|
||||
outputs=[currentTask, state_is_generating, generate_button,
|
||||
reset_button, stop_button, skip_button,
|
||||
progress_html, progress_window, progress_gallery, gallery],
|
||||
queue=False)
|
||||
|
||||
for notification_file in ['notification.ogg', 'notification.mp3']:
|
||||
if os.path.exists(notification_file):
|
||||
gr.Audio(interactive=False, value=notification_file, elem_id='audio_notification', visible=False)
|
||||
|
|
@ -715,6 +727,15 @@ with shared.gradio_root:
|
|||
desc_btn.click(trigger_describe, inputs=[desc_method, desc_input_image],
|
||||
outputs=[prompt, style_selections], show_progress=True, queue=True)
|
||||
|
||||
if args_manager.args.enable_describe_uov_image:
|
||||
def trigger_uov_describe(mode, img, prompt):
|
||||
# keep prompt if not empty
|
||||
if prompt == '':
|
||||
return trigger_describe(mode, img)
|
||||
return gr.update(), gr.update()
|
||||
|
||||
uov_input_image.upload(trigger_uov_describe, inputs=[desc_method, uov_input_image, prompt],
|
||||
outputs=[prompt, style_selections], show_progress=True, queue=True)
|
||||
|
||||
def dump_default_english_config():
|
||||
from modules.localization import dump_english_config
|
||||
|
|
|
|||
Loading…
Reference in New Issue