feat: add optional model VAE select (#2867)

* Revert "fix: use LF as line breaks for Docker entrypoint.sh (#2843)" (#2865)

False alarm, worked as intended before. Sorry for the fuzz.
This reverts commit d16a54edd6.

* feat: add VAE select

* feat: use different default label, add translation

* fix: do not reload model when VAE stays the same

* refactor: code cleanup

* feat: add metadata handling
This commit is contained in:
Manuel Schmid 2024-05-09 18:59:35 +02:00 committed by GitHub
parent 121f1e0a15
commit c32bc5e199
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 84 additions and 30 deletions

View File

@ -340,6 +340,8 @@
"sgm_uniform": "sgm_uniform", "sgm_uniform": "sgm_uniform",
"simple": "simple", "simple": "simple",
"ddim_uniform": "ddim_uniform", "ddim_uniform": "ddim_uniform",
"VAE": "VAE",
"Default (model)": "Default (model)",
"Forced Overwrite of Sampling Step": "Forced Overwrite of Sampling Step", "Forced Overwrite of Sampling Step": "Forced Overwrite of Sampling Step",
"Set as -1 to disable. For developer debugging.": "Set as -1 to disable. For developer debugging.", "Set as -1 to disable. For developer debugging.": "Set as -1 to disable. For developer debugging.",
"Forced Overwrite of Refiner Switch Step": "Forced Overwrite of Refiner Switch Step", "Forced Overwrite of Refiner Switch Step": "Forced Overwrite of Refiner Switch Step",

View File

@ -427,12 +427,13 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
return (ldm_patched.modules.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device), clip, vae) return (ldm_patched.modules.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device), clip, vae)
def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True): def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, vae_filename_param=None):
sd = ldm_patched.modules.utils.load_torch_file(ckpt_path) sd = ldm_patched.modules.utils.load_torch_file(ckpt_path)
sd_keys = sd.keys() sd_keys = sd.keys()
clip = None clip = None
clipvision = None clipvision = None
vae = None vae = None
vae_filename = None
model = None model = None
model_patcher = None model_patcher = None
clip_target = None clip_target = None
@ -462,8 +463,12 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
model.load_model_weights(sd, "model.diffusion_model.") model.load_model_weights(sd, "model.diffusion_model.")
if output_vae: if output_vae:
vae_sd = ldm_patched.modules.utils.state_dict_prefix_replace(sd, {"first_stage_model.": ""}, filter_keys=True) if vae_filename_param is None:
vae_sd = model_config.process_vae_state_dict(vae_sd) vae_sd = ldm_patched.modules.utils.state_dict_prefix_replace(sd, {"first_stage_model.": ""}, filter_keys=True)
vae_sd = model_config.process_vae_state_dict(vae_sd)
else:
vae_sd = ldm_patched.modules.utils.load_torch_file(vae_filename_param)
vae_filename = vae_filename_param
vae = VAE(sd=vae_sd) vae = VAE(sd=vae_sd)
if output_clip: if output_clip:
@ -485,7 +490,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
print("loaded straight to GPU") print("loaded straight to GPU")
model_management.load_model_gpu(model_patcher) model_management.load_model_gpu(model_patcher)
return (model_patcher, clip, vae, clipvision) return model_patcher, clip, vae, vae_filename, clipvision
def load_unet_state_dict(sd): #load unet in diffusers format def load_unet_state_dict(sd): #load unet in diffusers format

View File

@ -166,6 +166,7 @@ def worker():
adaptive_cfg = args.pop() adaptive_cfg = args.pop()
sampler_name = args.pop() sampler_name = args.pop()
scheduler_name = args.pop() scheduler_name = args.pop()
vae_name = args.pop()
overwrite_step = args.pop() overwrite_step = args.pop()
overwrite_switch = args.pop() overwrite_switch = args.pop()
overwrite_width = args.pop() overwrite_width = args.pop()
@ -428,7 +429,7 @@ def worker():
progressbar(async_task, 3, 'Loading models ...') progressbar(async_task, 3, 'Loading models ...')
pipeline.refresh_everything(refiner_model_name=refiner_model_name, base_model_name=base_model_name, pipeline.refresh_everything(refiner_model_name=refiner_model_name, base_model_name=base_model_name,
loras=loras, base_model_additional_loras=base_model_additional_loras, loras=loras, base_model_additional_loras=base_model_additional_loras,
use_synthetic_refiner=use_synthetic_refiner) use_synthetic_refiner=use_synthetic_refiner, vae_name=vae_name)
progressbar(async_task, 3, 'Processing prompts ...') progressbar(async_task, 3, 'Processing prompts ...')
tasks = [] tasks = []
@ -869,6 +870,7 @@ def worker():
d.append(('Sampler', 'sampler', sampler_name)) d.append(('Sampler', 'sampler', sampler_name))
d.append(('Scheduler', 'scheduler', scheduler_name)) d.append(('Scheduler', 'scheduler', scheduler_name))
d.append(('VAE', 'vae', vae_name))
d.append(('Seed', 'seed', str(task['task_seed']))) d.append(('Seed', 'seed', str(task['task_seed'])))
if freeu_enabled: if freeu_enabled:
@ -883,7 +885,7 @@ def worker():
metadata_parser = modules.meta_parser.get_metadata_parser(metadata_scheme) metadata_parser = modules.meta_parser.get_metadata_parser(metadata_scheme)
metadata_parser.set_data(task['log_positive_prompt'], task['positive'], metadata_parser.set_data(task['log_positive_prompt'], task['positive'],
task['log_negative_prompt'], task['negative'], task['log_negative_prompt'], task['negative'],
steps, base_model_name, refiner_model_name, loras) steps, base_model_name, refiner_model_name, loras, vae_name)
d.append(('Metadata Scheme', 'metadata_scheme', metadata_scheme.value if save_metadata_to_images else save_metadata_to_images)) d.append(('Metadata Scheme', 'metadata_scheme', metadata_scheme.value if save_metadata_to_images else save_metadata_to_images))
d.append(('Version', 'version', 'Fooocus v' + fooocus_version.version)) d.append(('Version', 'version', 'Fooocus v' + fooocus_version.version))
img_paths.append(log(x, d, metadata_parser, output_format)) img_paths.append(log(x, d, metadata_parser, output_format))

View File

@ -189,6 +189,7 @@ paths_checkpoints = get_dir_or_set_default('path_checkpoints', ['../models/check
paths_loras = get_dir_or_set_default('path_loras', ['../models/loras/'], True) paths_loras = get_dir_or_set_default('path_loras', ['../models/loras/'], True)
path_embeddings = get_dir_or_set_default('path_embeddings', '../models/embeddings/') path_embeddings = get_dir_or_set_default('path_embeddings', '../models/embeddings/')
path_vae_approx = get_dir_or_set_default('path_vae_approx', '../models/vae_approx/') path_vae_approx = get_dir_or_set_default('path_vae_approx', '../models/vae_approx/')
path_vae = get_dir_or_set_default('path_vae', '../models/vae/')
path_upscale_models = get_dir_or_set_default('path_upscale_models', '../models/upscale_models/') path_upscale_models = get_dir_or_set_default('path_upscale_models', '../models/upscale_models/')
path_inpaint = get_dir_or_set_default('path_inpaint', '../models/inpaint/') path_inpaint = get_dir_or_set_default('path_inpaint', '../models/inpaint/')
path_controlnet = get_dir_or_set_default('path_controlnet', '../models/controlnet/') path_controlnet = get_dir_or_set_default('path_controlnet', '../models/controlnet/')
@ -346,6 +347,11 @@ default_scheduler = get_config_item_or_set_default(
default_value='karras', default_value='karras',
validator=lambda x: x in modules.flags.scheduler_list validator=lambda x: x in modules.flags.scheduler_list
) )
default_vae = get_config_item_or_set_default(
key='default_vae',
default_value=modules.flags.default_vae,
validator=lambda x: isinstance(x, str)
)
default_styles = get_config_item_or_set_default( default_styles = get_config_item_or_set_default(
key='default_styles', key='default_styles',
default_value=[ default_value=[
@ -535,6 +541,7 @@ with open(config_example_path, "w", encoding="utf-8") as json_file:
model_filenames = [] model_filenames = []
lora_filenames = [] lora_filenames = []
vae_filenames = []
wildcard_filenames = [] wildcard_filenames = []
sdxl_lcm_lora = 'sdxl_lcm_lora.safetensors' sdxl_lcm_lora = 'sdxl_lcm_lora.safetensors'
@ -546,15 +553,20 @@ def get_model_filenames(folder_paths, extensions=None, name_filter=None):
if extensions is None: if extensions is None:
extensions = ['.pth', '.ckpt', '.bin', '.safetensors', '.fooocus.patch'] extensions = ['.pth', '.ckpt', '.bin', '.safetensors', '.fooocus.patch']
files = [] files = []
if not isinstance(folder_paths, list):
folder_paths = [folder_paths]
for folder in folder_paths: for folder in folder_paths:
files += get_files_from_folder(folder, extensions, name_filter) files += get_files_from_folder(folder, extensions, name_filter)
return files return files
def update_files(): def update_files():
global model_filenames, lora_filenames, wildcard_filenames, available_presets global model_filenames, lora_filenames, vae_filenames, wildcard_filenames, available_presets
model_filenames = get_model_filenames(paths_checkpoints) model_filenames = get_model_filenames(paths_checkpoints)
lora_filenames = get_model_filenames(paths_loras) lora_filenames = get_model_filenames(paths_loras)
vae_filenames = get_model_filenames(path_vae)
wildcard_filenames = get_files_from_folder(path_wildcards, ['.txt']) wildcard_filenames = get_files_from_folder(path_wildcards, ['.txt'])
available_presets = get_presets() available_presets = get_presets()
return return

View File

@ -35,12 +35,13 @@ opModelSamplingDiscrete = ModelSamplingDiscrete()
class StableDiffusionModel: class StableDiffusionModel:
def __init__(self, unet=None, vae=None, clip=None, clip_vision=None, filename=None): def __init__(self, unet=None, vae=None, clip=None, clip_vision=None, filename=None, vae_filename=None):
self.unet = unet self.unet = unet
self.vae = vae self.vae = vae
self.clip = clip self.clip = clip
self.clip_vision = clip_vision self.clip_vision = clip_vision
self.filename = filename self.filename = filename
self.vae_filename = vae_filename
self.unet_with_lora = unet self.unet_with_lora = unet
self.clip_with_lora = clip self.clip_with_lora = clip
self.visited_loras = '' self.visited_loras = ''
@ -142,9 +143,10 @@ def apply_controlnet(positive, negative, control_net, image, strength, start_per
@torch.no_grad() @torch.no_grad()
@torch.inference_mode() @torch.inference_mode()
def load_model(ckpt_filename): def load_model(ckpt_filename, vae_filename=None):
unet, clip, vae, clip_vision = load_checkpoint_guess_config(ckpt_filename, embedding_directory=path_embeddings) unet, clip, vae, vae_filename, clip_vision = load_checkpoint_guess_config(ckpt_filename, embedding_directory=path_embeddings,
return StableDiffusionModel(unet=unet, clip=clip, vae=vae, clip_vision=clip_vision, filename=ckpt_filename) vae_filename_param=vae_filename)
return StableDiffusionModel(unet=unet, clip=clip, vae=vae, clip_vision=clip_vision, filename=ckpt_filename, vae_filename=vae_filename)
@torch.no_grad() @torch.no_grad()

View File

@ -3,6 +3,7 @@ import os
import torch import torch
import modules.patch import modules.patch
import modules.config import modules.config
import modules.flags
import ldm_patched.modules.model_management import ldm_patched.modules.model_management
import ldm_patched.modules.latent_formats import ldm_patched.modules.latent_formats
import modules.inpaint_worker import modules.inpaint_worker
@ -58,17 +59,21 @@ def assert_model_integrity():
@torch.no_grad() @torch.no_grad()
@torch.inference_mode() @torch.inference_mode()
def refresh_base_model(name): def refresh_base_model(name, vae_name=None):
global model_base global model_base
filename = get_file_from_folder_list(name, modules.config.paths_checkpoints) filename = get_file_from_folder_list(name, modules.config.paths_checkpoints)
if model_base.filename == filename: vae_filename = None
if vae_name is not None and vae_name != modules.flags.default_vae:
vae_filename = get_file_from_folder_list(vae_name, modules.config.path_vae)
if model_base.filename == filename and model_base.vae_filename == vae_filename:
return return
model_base = core.StableDiffusionModel() model_base = core.load_model(filename, vae_filename)
model_base = core.load_model(filename)
print(f'Base model loaded: {model_base.filename}') print(f'Base model loaded: {model_base.filename}')
print(f'VAE loaded: {model_base.vae_filename}')
return return
@ -216,7 +221,7 @@ def prepare_text_encoder(async_call=True):
@torch.no_grad() @torch.no_grad()
@torch.inference_mode() @torch.inference_mode()
def refresh_everything(refiner_model_name, base_model_name, loras, def refresh_everything(refiner_model_name, base_model_name, loras,
base_model_additional_loras=None, use_synthetic_refiner=False): base_model_additional_loras=None, use_synthetic_refiner=False, vae_name=None):
global final_unet, final_clip, final_vae, final_refiner_unet, final_refiner_vae, final_expansion global final_unet, final_clip, final_vae, final_refiner_unet, final_refiner_vae, final_expansion
final_unet = None final_unet = None
@ -227,11 +232,11 @@ def refresh_everything(refiner_model_name, base_model_name, loras,
if use_synthetic_refiner and refiner_model_name == 'None': if use_synthetic_refiner and refiner_model_name == 'None':
print('Synthetic Refiner Activated') print('Synthetic Refiner Activated')
refresh_base_model(base_model_name) refresh_base_model(base_model_name, vae_name)
synthesize_refiner_model() synthesize_refiner_model()
else: else:
refresh_refiner_model(refiner_model_name) refresh_refiner_model(refiner_model_name)
refresh_base_model(base_model_name) refresh_base_model(base_model_name, vae_name)
refresh_loras(loras, base_model_additional_loras=base_model_additional_loras) refresh_loras(loras, base_model_additional_loras=base_model_additional_loras)
assert_model_integrity() assert_model_integrity()
@ -254,7 +259,8 @@ def refresh_everything(refiner_model_name, base_model_name, loras,
refresh_everything( refresh_everything(
refiner_model_name=modules.config.default_refiner_model_name, refiner_model_name=modules.config.default_refiner_model_name,
base_model_name=modules.config.default_base_model_name, base_model_name=modules.config.default_base_model_name,
loras=get_enabled_loras(modules.config.default_loras) loras=get_enabled_loras(modules.config.default_loras),
vae_name=modules.config.default_vae,
) )

View File

@ -53,6 +53,8 @@ SAMPLER_NAMES = KSAMPLER_NAMES + list(SAMPLER_EXTRA.keys())
sampler_list = SAMPLER_NAMES sampler_list = SAMPLER_NAMES
scheduler_list = SCHEDULER_NAMES scheduler_list = SCHEDULER_NAMES
default_vae = 'Default (model)'
refiner_swap_method = 'joint' refiner_swap_method = 'joint'
cn_ip = "ImagePrompt" cn_ip = "ImagePrompt"

View File

@ -46,6 +46,7 @@ def load_parameter_button_click(raw_metadata: dict | str, is_generating: bool):
get_float('refiner_switch', 'Refiner Switch', loaded_parameter_dict, results) get_float('refiner_switch', 'Refiner Switch', loaded_parameter_dict, results)
get_str('sampler', 'Sampler', loaded_parameter_dict, results) get_str('sampler', 'Sampler', loaded_parameter_dict, results)
get_str('scheduler', 'Scheduler', loaded_parameter_dict, results) get_str('scheduler', 'Scheduler', loaded_parameter_dict, results)
get_str('vae', 'VAE', loaded_parameter_dict, results)
get_seed('seed', 'Seed', loaded_parameter_dict, results) get_seed('seed', 'Seed', loaded_parameter_dict, results)
if is_generating: if is_generating:
@ -253,6 +254,7 @@ class MetadataParser(ABC):
self.refiner_model_name: str = '' self.refiner_model_name: str = ''
self.refiner_model_hash: str = '' self.refiner_model_hash: str = ''
self.loras: list = [] self.loras: list = []
self.vae_name: str = ''
@abstractmethod @abstractmethod
def get_scheme(self) -> MetadataScheme: def get_scheme(self) -> MetadataScheme:
@ -267,7 +269,7 @@ class MetadataParser(ABC):
raise NotImplementedError raise NotImplementedError
def set_data(self, raw_prompt, full_prompt, raw_negative_prompt, full_negative_prompt, steps, base_model_name, def set_data(self, raw_prompt, full_prompt, raw_negative_prompt, full_negative_prompt, steps, base_model_name,
refiner_model_name, loras): refiner_model_name, loras, vae_name):
self.raw_prompt = raw_prompt self.raw_prompt = raw_prompt
self.full_prompt = full_prompt self.full_prompt = full_prompt
self.raw_negative_prompt = raw_negative_prompt self.raw_negative_prompt = raw_negative_prompt
@ -289,6 +291,7 @@ class MetadataParser(ABC):
lora_path = get_file_from_folder_list(lora_name, modules.config.paths_loras) lora_path = get_file_from_folder_list(lora_name, modules.config.paths_loras)
lora_hash = get_sha256(lora_path) lora_hash = get_sha256(lora_path)
self.loras.append((Path(lora_name).stem, lora_weight, lora_hash)) self.loras.append((Path(lora_name).stem, lora_weight, lora_hash))
self.vae_name = Path(vae_name).stem
@staticmethod @staticmethod
def remove_special_loras(lora_filenames): def remove_special_loras(lora_filenames):
@ -310,6 +313,7 @@ class A1111MetadataParser(MetadataParser):
'steps': 'Steps', 'steps': 'Steps',
'sampler': 'Sampler', 'sampler': 'Sampler',
'scheduler': 'Scheduler', 'scheduler': 'Scheduler',
'vae': 'VAE',
'guidance_scale': 'CFG scale', 'guidance_scale': 'CFG scale',
'seed': 'Seed', 'seed': 'Seed',
'resolution': 'Size', 'resolution': 'Size',
@ -397,13 +401,12 @@ class A1111MetadataParser(MetadataParser):
data['sampler'] = k data['sampler'] = k
break break
for key in ['base_model', 'refiner_model']: for key in ['base_model', 'refiner_model', 'vae']:
if key in data: if key in data:
for filename in modules.config.model_filenames: if key == 'vae':
path = Path(filename) self.add_extension_to_filename(data, modules.config.vae_filenames, 'vae')
if data[key] == path.stem: else:
data[key] = filename self.add_extension_to_filename(data, modules.config.model_filenames, key)
break
lora_data = '' lora_data = ''
if 'lora_weights' in data and data['lora_weights'] != '': if 'lora_weights' in data and data['lora_weights'] != '':
@ -433,6 +436,7 @@ class A1111MetadataParser(MetadataParser):
sampler = data['sampler'] sampler = data['sampler']
scheduler = data['scheduler'] scheduler = data['scheduler']
if sampler in SAMPLERS and SAMPLERS[sampler] != '': if sampler in SAMPLERS and SAMPLERS[sampler] != '':
sampler = SAMPLERS[sampler] sampler = SAMPLERS[sampler]
if sampler not in CIVITAI_NO_KARRAS and scheduler == 'karras': if sampler not in CIVITAI_NO_KARRAS and scheduler == 'karras':
@ -451,6 +455,7 @@ class A1111MetadataParser(MetadataParser):
self.fooocus_to_a1111['performance']: data['performance'], self.fooocus_to_a1111['performance']: data['performance'],
self.fooocus_to_a1111['scheduler']: scheduler, self.fooocus_to_a1111['scheduler']: scheduler,
self.fooocus_to_a1111['vae']: Path(data['vae']).stem,
# workaround for multiline prompts # workaround for multiline prompts
self.fooocus_to_a1111['raw_prompt']: self.raw_prompt, self.fooocus_to_a1111['raw_prompt']: self.raw_prompt,
self.fooocus_to_a1111['raw_negative_prompt']: self.raw_negative_prompt, self.fooocus_to_a1111['raw_negative_prompt']: self.raw_negative_prompt,
@ -491,6 +496,14 @@ class A1111MetadataParser(MetadataParser):
negative_prompt_text = f"\nNegative prompt: {negative_prompt_resolved}" if negative_prompt_resolved else "" negative_prompt_text = f"\nNegative prompt: {negative_prompt_resolved}" if negative_prompt_resolved else ""
return f"{positive_prompt_resolved}{negative_prompt_text}\n{generation_params_text}".strip() return f"{positive_prompt_resolved}{negative_prompt_text}\n{generation_params_text}".strip()
@staticmethod
def add_extension_to_filename(data, filenames, key):
for filename in filenames:
path = Path(filename)
if data[key] == path.stem:
data[key] = filename
break
class FooocusMetadataParser(MetadataParser): class FooocusMetadataParser(MetadataParser):
def get_scheme(self) -> MetadataScheme: def get_scheme(self) -> MetadataScheme:
@ -499,6 +512,7 @@ class FooocusMetadataParser(MetadataParser):
def parse_json(self, metadata: dict) -> dict: def parse_json(self, metadata: dict) -> dict:
model_filenames = modules.config.model_filenames.copy() model_filenames = modules.config.model_filenames.copy()
lora_filenames = modules.config.lora_filenames.copy() lora_filenames = modules.config.lora_filenames.copy()
vae_filenames = modules.config.vae_filenames.copy()
self.remove_special_loras(lora_filenames) self.remove_special_loras(lora_filenames)
for key, value in metadata.items(): for key, value in metadata.items():
if value in ['', 'None']: if value in ['', 'None']:
@ -507,6 +521,8 @@ class FooocusMetadataParser(MetadataParser):
metadata[key] = self.replace_value_with_filename(key, value, model_filenames) metadata[key] = self.replace_value_with_filename(key, value, model_filenames)
elif key.startswith('lora_combined_'): elif key.startswith('lora_combined_'):
metadata[key] = self.replace_value_with_filename(key, value, lora_filenames) metadata[key] = self.replace_value_with_filename(key, value, lora_filenames)
elif key == 'vae':
metadata[key] = self.replace_value_with_filename(key, value, vae_filenames)
else: else:
continue continue
@ -533,6 +549,7 @@ class FooocusMetadataParser(MetadataParser):
res['refiner_model'] = self.refiner_model_name res['refiner_model'] = self.refiner_model_name
res['refiner_model_hash'] = self.refiner_model_hash res['refiner_model_hash'] = self.refiner_model_hash
res['vae'] = self.vae_name
res['loras'] = self.loras res['loras'] = self.loras
if modules.config.metadata_created_by != '': if modules.config.metadata_created_by != '':

View File

@ -371,6 +371,9 @@ def is_json(data: str) -> bool:
def get_file_from_folder_list(name, folders): def get_file_from_folder_list(name, folders):
if not isinstance(folders, list):
folders = [folders]
for folder in folders: for folder in folders:
filename = os.path.abspath(os.path.realpath(os.path.join(folder, name))) filename = os.path.abspath(os.path.realpath(os.path.join(folder, name)))
if os.path.isfile(filename): if os.path.isfile(filename):

View File

@ -407,6 +407,8 @@ with shared.gradio_root:
value=modules.config.default_sampler) value=modules.config.default_sampler)
scheduler_name = gr.Dropdown(label='Scheduler', choices=flags.scheduler_list, scheduler_name = gr.Dropdown(label='Scheduler', choices=flags.scheduler_list,
value=modules.config.default_scheduler) value=modules.config.default_scheduler)
vae_name = gr.Dropdown(label='VAE', choices=[modules.flags.default_vae] + modules.config.vae_filenames,
value=modules.config.default_vae, show_label=True)
generate_image_grid = gr.Checkbox(label='Generate Image Grid for Each Batch', generate_image_grid = gr.Checkbox(label='Generate Image Grid for Each Batch',
info='(Experimental) This may cause performance problems on some computers and certain internet conditions.', info='(Experimental) This may cause performance problems on some computers and certain internet conditions.',
@ -529,6 +531,7 @@ with shared.gradio_root:
modules.config.update_files() modules.config.update_files()
results = [gr.update(choices=modules.config.model_filenames)] results = [gr.update(choices=modules.config.model_filenames)]
results += [gr.update(choices=['None'] + modules.config.model_filenames)] results += [gr.update(choices=['None'] + modules.config.model_filenames)]
results += [gr.update(choices=['None'] + modules.config.vae_filenames)]
if not args_manager.args.disable_preset_selection: if not args_manager.args.disable_preset_selection:
results += [gr.update(choices=modules.config.available_presets)] results += [gr.update(choices=modules.config.available_presets)]
for i in range(modules.config.default_max_lora_number): for i in range(modules.config.default_max_lora_number):
@ -536,7 +539,7 @@ with shared.gradio_root:
gr.update(choices=['None'] + modules.config.lora_filenames), gr.update()] gr.update(choices=['None'] + modules.config.lora_filenames), gr.update()]
return results return results
refresh_files_output = [base_model, refiner_model] refresh_files_output = [base_model, refiner_model, vae_name]
if not args_manager.args.disable_preset_selection: if not args_manager.args.disable_preset_selection:
refresh_files_output += [preset_selection] refresh_files_output += [preset_selection]
refresh_files.click(refresh_files_clicked, [], refresh_files_output + lora_ctrls, refresh_files.click(refresh_files_clicked, [], refresh_files_output + lora_ctrls,
@ -548,8 +551,8 @@ with shared.gradio_root:
performance_selection, overwrite_step, overwrite_switch, aspect_ratios_selection, performance_selection, overwrite_step, overwrite_switch, aspect_ratios_selection,
overwrite_width, overwrite_height, guidance_scale, sharpness, adm_scaler_positive, overwrite_width, overwrite_height, guidance_scale, sharpness, adm_scaler_positive,
adm_scaler_negative, adm_scaler_end, refiner_swap_method, adaptive_cfg, base_model, adm_scaler_negative, adm_scaler_end, refiner_swap_method, adaptive_cfg, base_model,
refiner_model, refiner_switch, sampler_name, scheduler_name, seed_random, image_seed, refiner_model, refiner_switch, sampler_name, scheduler_name, vae_name, seed_random,
generate_button, load_parameter_button] + freeu_ctrls + lora_ctrls image_seed, generate_button, load_parameter_button] + freeu_ctrls + lora_ctrls
if not args_manager.args.disable_preset_selection: if not args_manager.args.disable_preset_selection:
def preset_selection_change(preset, is_generating): def preset_selection_change(preset, is_generating):
@ -635,7 +638,7 @@ with shared.gradio_root:
ctrls += [outpaint_selections, inpaint_input_image, inpaint_additional_prompt, inpaint_mask_image] ctrls += [outpaint_selections, inpaint_input_image, inpaint_additional_prompt, inpaint_mask_image]
ctrls += [disable_preview, disable_intermediate_results, disable_seed_increment] ctrls += [disable_preview, disable_intermediate_results, disable_seed_increment]
ctrls += [adm_scaler_positive, adm_scaler_negative, adm_scaler_end, adaptive_cfg] ctrls += [adm_scaler_positive, adm_scaler_negative, adm_scaler_end, adaptive_cfg]
ctrls += [sampler_name, scheduler_name] ctrls += [sampler_name, scheduler_name, vae_name]
ctrls += [overwrite_step, overwrite_switch, overwrite_width, overwrite_height, overwrite_vary_strength] ctrls += [overwrite_step, overwrite_switch, overwrite_width, overwrite_height, overwrite_vary_strength]
ctrls += [overwrite_upscale_strength, mixing_image_prompt_and_vary_upscale, mixing_image_prompt_and_inpaint] ctrls += [overwrite_upscale_strength, mixing_image_prompt_and_vary_upscale, mixing_image_prompt_and_inpaint]
ctrls += [debugging_cn_preprocessor, skipping_cn_preprocessor, canny_low_threshold, canny_high_threshold] ctrls += [debugging_cn_preprocessor, skipping_cn_preprocessor, canny_low_threshold, canny_high_threshold]